I'm trying to use the data_collator function in Hugging Face with the following code:

    import torch

    datasets = dataset.train_test_split(test_size=0.1)
    train_dataset = datasets["train"]
    val_dataset = datasets["test"]
    print(type(train_dataset))

    def data_collator(data):
        # Initialize lists to store pixel values and input ids
        pixel_values_list = []
        input_ids_list = []
        # Iterate over each sample in the data
        for item in data:
            pixel_values_list.append(torch.tensor(item["pixel_values"]))
            input_ids_list.append(torch.tensor(item["input_ids"]))
        return {
            "pixel_values": torch.stack(pixel_values_list),
            "labels": torch.stack(input_ids_list)
        }

train_data has 5 keys, including input_ids. However, when I print(data[0]) inside the data_collator function, I see only 1 key, and running the trainer fails with this error:

    Traceback (most recent call last):
      File "caption-code.py", line 134, in <module>
        trainer.train()
      File "C:\Users\moham\anaconda3\envs\transformer\lib\site-packages\transformers\trainer.py", line 1321, in train
        ignore_keys_for_eval=ignore_keys_for_eval,
      File "C:\Users\moham\anaconda3\envs\transformer\lib\site-packages\transformers\trainer.py", line 1528, in _inner_training_loop
        for step, inputs in enumerate(epoch_iterator):
      File "C:\Users\moham\anaconda3\envs\transformer\lib\site-packages\torch\utils\data\dataloader.py", line 521, in __next__
        data = self._next_data()
      File "C:\Users\moham\anaconda3\envs\transformer\lib\site-packages\torch\utils\data\dataloader.py", line 561, in _next_data
        data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
      File "C:\Users\moham\anaconda3\envs\transformer\lib\site-packages\torch\utils\data\_utils\fetch.py", line 52, in fetch
        return self.collate_fn(data)
      File "caption-code.py", line 102, in data_collator
        input_ids_list.append(item["input_ids"])
    KeyError: 'input_ids'

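The fetch.py frame in this traceback is where the DataLoader calls the collator. For a map-style dataset it indexes the dataset once per sample and passes the resulting list to collate_fn, so data is a list of per-sample dicts and data[0] is one row. A simplified plain-Python sketch of that step, without torch; toy_dataset, fetch_batch and key_report_collator are hypothetical names for illustration only:

```python
from typing import Any, Callable, Dict, List

# Hypothetical toy dataset: one dict per sample, like rows of a datasets.Dataset.
toy_dataset: List[Dict[str, Any]] = [
    {"pixel_values": [[0.1, 0.2]], "input_ids": [101, 7, 102]},
    {"pixel_values": [[0.3, 0.4]], "input_ids": [101, 9, 102]},
]


def fetch_batch(dataset, indices: List[int], collate_fn: Callable[[List[Any]], Any]):
    # Simplified stand-in for the map-style fetcher: index the dataset once
    # per sample, then pass the whole list of samples to collate_fn.
    data = [dataset[idx] for idx in indices]
    return collate_fn(data)


def key_report_collator(data: List[Dict[str, Any]]) -> Dict[str, Any]:
    # data is a list of per-sample dicts, so data[0] is one full row.
    return {"batch_size": len(data), "keys": sorted(data[0].keys())}


print(fetch_batch(toy_dataset, [0, 1], key_report_collator))
```

With this shape, data[0] shows exactly the columns still present in a dataset row, so seeing only one key points at columns having been removed from the dataset before batching.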
I use the trainer as follows:

    training_args = Seq2SeqTrainingArguments(
        predict_with_generate=True,
        evaluation_strategy="epoch",
        per_device_train_batch_size=4,
        per_device_eval_batch_size=4,
        output_dir="C:/Users/moham/Desktop/Euler/output",
        logging_dir="./logs",
        logging_steps=10,
        save_steps=10,
        eval_steps=10,
        warmup_steps=10,
        max_steps=100,  # adjust as needed
        overwrite_output_dir=True,
        save_total_limit=3,
    )

    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        data_collator=data_collator,
        tokenizer=tokenizer,
        compute_metrics=compute_exact_match
    )

    trainer.train()

You should change the way you append data in data_collator().
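Reason: when you loop over data in the data_collator function, you are basically looping over the keys() of the train_dataset dictionary, and your code breaks on this line:

    input_ids_list.append(item["input_ids"])

because item on the first iteration is the key 'pixel_values'.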
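One caveat when matching this explanation to the traceback: iterating a dict does yield its keys, but indexing a string key with another string raises TypeError, not KeyError. A KeyError: 'input_ids' therefore suggests that item was a dict without an 'input_ids' entry. A small plain-Python check; error_name, columnar and row are hypothetical names used only for illustration:

```python
def error_name(func):
    # Call func and return the name of the exception it raises, or None.
    try:
        func()
    except Exception as exc:
        return type(exc).__name__
    return None


# Column-oriented dict, shaped like the data assumed in the solution below.
columnar = {"pixel_values": ["img0", "img1"], "input_ids": [0, 1]}
keys_seen = [item for item in columnar]  # iterating a dict yields its keys
first_item = keys_seen[0]                # the string 'pixel_values'

# Indexing a str with a str key fails with TypeError.
str_case = error_name(lambda: first_item["input_ids"])

# Indexing a per-sample dict that lacks the key fails with KeyError.
row = {"pixel_values": "img0"}
dict_case = error_name(lambda: row["input_ids"])

print(keys_seen, str_case, dict_case)
```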
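Solution: since, at the time of writing this answer, the question does not include an example of your train_dataset (= data), I assume it looks something like this:

    data = {'pixel_values': [array, array, array],
            'input_ids': [0, 1, 2]
            }
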
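Modify your data_collator() as follows:

    import torch

    def data_collator(data):
        pixel_values_list = []
        input_ids_list = []
        for item in range(len(data['pixel_values'])):
            pixel_values_list.append(torch.tensor(data["pixel_values"][item]))
            input_ids_list.append(torch.tensor(data["input_ids"][item]))
        return {"pixel_values": torch.stack(pixel_values_list),
                "labels": torch.stack(input_ids_list)}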
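A related check, given that only one key reaches the collator: Trainer and Seq2SeqTrainer default to remove_unused_columns=True, which drops every dataset column whose name is not an argument of the model's forward method before batches reach data_collator. If the model here is an image-captioning model whose forward takes pixel_values and labels but not input_ids, the input_ids column would be dropped, which would match the single key seen in print(data[0]). The sketch below imitates that filtering with inspect.signature under those assumptions; ToyCaptionModel and keep_signature_columns are hypothetical, and the real Trainer logic handles more cases.

```python
import inspect
from typing import Any, Dict


class ToyCaptionModel:
    # Hypothetical model whose forward accepts pixel_values and labels only.
    def forward(self, pixel_values=None, labels=None):
        return None


def keep_signature_columns(model, row: Dict[str, Any]) -> Dict[str, Any]:
    # Simplified imitation of remove_unused_columns=True: keep only keys that
    # name a parameter of model.forward, plus the default label column names.
    allowed = set(inspect.signature(model.forward).parameters)
    allowed |= {"label", "label_ids"}
    return {key: value for key, value in row.items() if key in allowed}


# A row with several columns, as in the question (values are made up).
row = {"pixel_values": [[0.1, 0.2]], "input_ids": [101, 7, 102], "caption": "a cat"}
print(keep_signature_columns(ToyCaptionModel(), row))

# Renaming input_ids to labels lets that column survive the filter.
renamed = {"pixel_values": [[0.1, 0.2]], "labels": [101, 7, 102]}
print(keep_signature_columns(ToyCaptionModel(), renamed))
```

If this is the cause, passing remove_unused_columns=False to Seq2SeqTrainingArguments keeps all columns; alternatively, train_dataset.rename_column("input_ids", "labels") renames the column so it survives, and the collator then reads item["labels"].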