我正在 LIMU-BERT 上尝试适配器,这是一个用于 IMU 数据的轻量级 BERT。我在数据集 A 上预训练了 LIMU-BERT,并计划添加适配器并在数据集 B 上对其进行调整。这是我的适配器添加代码:
import adapters
class AdapterBERTClassifier(nn.Module):
def __init__(self, bert_cfg, classifier=None):
super().__init__()
self.limu_bert = LIMUBertModel4Pretrain(bert_cfg, output_embed=True)
self.classifier = classifier
# Add adapter
adapter_config = adapters.AdapterConfig(
mh_adapter=True,
output_adapter=True,
reduction_factor=16,
non_linearity="relu"
)
self.limu_bert.add_adapter("classification_adapter", config=adapter_config)
self.limu_bert.train_adapter("classification_adapter")
但是,我遇到了一个错误:
Traceback (most recent call last):
File "D:\Documents\Code\LIMU-BERT\classifier_adapter.py", line 71, in <module>
label_test, label_estimate_test = bert_classify(args, args.label_index, train_rate, label_rate, balance=balance)
File "D:\Documents\Code\LIMU-BERT\classifier_adapter.py", line 37, in bert_classify
model = AdapterBERTClassifier(model_bert_cfg, classifier=classifier)
File "D:\Documents\Code\LIMU-BERT\models.py", line 332, in __init__
adapter_config = adapters.AdapterConfig(
TypeError: AdapterConfig.__init__() got an unexpected keyword argument 'mh_adapter'
由于适配器配置的文档提到有一个名为
mh_adapter
的参数用于adapters.AdapterConfig
,谁能告诉我问题是什么以及如何解决它?谢谢您的帮助!
顺便说一句,这是我的适配器包信息:
# Name Version Build Channel
adapters 1.0.1 pypi_0 pypi
您使用的
adapters.AdapterConfig
类实际上是所有适应方法的基类。并且根据文档:“这个类没有定义具体的配置键,而只是提供了一些通用的辅助方法。”我认为这解释了整个事情的原因。
你不想使用这个基类;相反,您应该使用与您的具体用例相对应的适配器。该文档解释了一些包含
mh_adapter
的适配器,例如 adapters.BnConfig
。您一直在查看为此类定义的输入参数,并将其误认为是基类。
这就是修改后代码的样子:
adapter_config = adapters.BnConfig(
mh_adapter=True,
output_adapter=True,
reduction_factor=16,
non_linearity="relu"
)