I'm trying to run training like this, but I get an import error. How can I fix it?
from transformers import Trainer, TrainingArguments

args = TrainingArguments(
    output_dir="finetuned",
    num_train_epochs=10,
    per_device_train_batch_size=16,
    save_steps=10000,
    gradient_accumulation_steps=2,
    warmup_steps=500,
    lr_scheduler_type="polynomial",
    fp16=True,
)
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    tokenizer=tokenizer,
)
trainer.train()
But it gives me this error:
File ~\mambaforge\lib\site-packages\transformers\training_args.py:1750, in TrainingArguments.__post_init__(self)
1748 # Initialize device before we proceed
1749 if self.framework == "pt" and is_torch_available():
-> 1750 self.device
1752 if self.torchdynamo is not None:
1753 warnings.warn(
1754 "`torchdynamo` is deprecated and will be removed in version 5 of 🤗 Transformers. Use"
1755 " `torch_compile_backend` instead",
1756 FutureWarning,
1757 )
File ~\mambaforge\lib\site-packages\transformers\training_args.py:2250, in TrainingArguments.device(self)
2246 """
2247 The device used by this process.
2248 """
2249 requires_backends(self, ["torch"])
-> 2250 return self._setup_devices
File ~\mambaforge\lib\site-packages\transformers\utils\generic.py:60, in cached_property.__get__(self, obj, objtype)
58 cached = getattr(obj, attr, None)
59 if cached is None:
---> 60 cached = self.fget(obj)
61 setattr(obj, attr, cached)
62 return cached
File ~\mambaforge\lib\site-packages\transformers\training_args.py:2123, in TrainingArguments._setup_devices(self)
2121 if not is_sagemaker_mp_enabled():
2122 if not is_accelerate_available():
-> 2123 raise ImportError(
2124 f"Using the `Trainer` with `PyTorch` requires `accelerate>={ACCELERATE_MIN_VERSION}`: "
2125 "Please run `pip install transformers[torch]` or `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`"
2126 )
2127 # We delay the init of `PartialState` to the end for clarity
2128 accelerator_state_kwargs = {"enabled": True, "use_configured_state": False}
ImportError: Using the `Trainer` with `PyTorch` requires `accelerate>=0.26.0`: Please run `pip install transformers[torch]` or `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`
But I do have a recent enough version:
import accelerate
accelerate.__version__
> '1.0.1'
transformers.__version__
> transformers
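To check whether the running kernel is actually using the versions that are installed, here is a minimal sketch using only the standard library. It compares the version recorded in the installed package metadata with the `__version__` of the module already loaded in the current process. `check_package`, `meets_minimum` and the naive `parse_version` are hypothetical helpers, not part of transformers or accelerate; the 0.26.0 minimum is taken from the error message.

```python
import sys
from importlib import metadata

# Minimum accelerate version quoted in the ImportError.
ACCELERATE_MIN = "0.26.0"


def parse_version(v: str) -> tuple:
    """Naive parse of the leading numeric parts: "1.0.1" -> (1, 0, 1).

    Pre-release suffixes such as "dev0" or "rc1" are ignored; use
    packaging.version.Version if you need full PEP 440 ordering.
    """
    parts = []
    for piece in v.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
        if digits != piece:  # e.g. "0rc1": stop after the numeric prefix
            break
    # Pad so that "0.26" compares equal to "0.26.0"
    return tuple(parts + [0] * (3 - len(parts)))


def meets_minimum(version: str, minimum: str = ACCELERATE_MIN) -> bool:
    """True if `version` is at least `minimum` (numeric parts only)."""
    return parse_version(version) >= parse_version(minimum)


def check_package(name: str) -> dict:
    """Compare the version installed on disk with the one loaded in this process."""
    try:
        on_disk = metadata.version(name)  # read from installed metadata right now
    except metadata.PackageNotFoundError:
        on_disk = None
    module = sys.modules.get(name)  # None if the package was never imported
    loaded = getattr(module, "__version__", None)
    return {
        "name": name,
        "on_disk": on_disk,
        "loaded": loaded,
        # A mismatch means the kernel is still running the old code.
        "stale": None not in (on_disk, loaded) and on_disk != loaded,
    }


if __name__ == "__main__":
    for pkg in ("transformers", "accelerate"):
        print(check_package(pkg))
    acc = check_package("accelerate")["on_disk"]
    if acc is None or not meets_minimum(acc):
        print("accelerate is missing or older than", ACCELERATE_MIN)
```

If `stale` is True for either package, the kernel has not picked up the upgrade and needs to be restarted.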
Upgrading both transformers and accelerate and then restarting the kernel fixed the problem for me.
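Why the restart matters, under an assumption the post does not confirm: pip replaces the files on disk, but a running kernel serves repeated imports from `sys.modules`, and transformers appears to check for accelerate once, when it is first imported. This sketch reproduces the general effect with a throwaway, hypothetical module `fake_pkg_demo` standing in for the real packages; a subprocess plays the role of a restarted kernel.

```python
import importlib
import os
import subprocess
import sys
import tempfile
from pathlib import Path

MODULE = "fake_pkg_demo"  # hypothetical module standing in for accelerate


def demo_stale_import() -> tuple:
    """Return (first import, re-import after "upgrade", fresh interpreter)."""
    with tempfile.TemporaryDirectory() as tmp:
        path = Path(tmp) / f"{MODULE}.py"
        path.write_text('__version__ = "0.20.0"\n')
        sys.path.insert(0, tmp)
        importlib.invalidate_caches()
        try:
            first = importlib.import_module(MODULE).__version__
            # Simulate `pip install -U`: the file on disk changes...
            path.write_text('__version__ = "1.0.1"\n')
            # ...but a second import is served from sys.modules, not from disk.
            again = importlib.import_module(MODULE).__version__
            # A new interpreter (what a kernel restart gives you) reads the new file.
            fresh = subprocess.run(
                [sys.executable, "-c", f"import {MODULE}; print({MODULE}.__version__)"],
                env={**os.environ, "PYTHONPATH": tmp},
                capture_output=True,
                text=True,
                check=True,
            ).stdout.strip()
        finally:
            sys.path.remove(tmp)
            sys.modules.pop(MODULE, None)
    return first, again, fresh


if __name__ == "__main__":
    print(demo_stale_import())
```

Running it should print ('0.20.0', '0.20.0', '1.0.1'): the in-process import still reports the old version, while a fresh interpreter sees the upgraded one. `importlib.reload` exists, but reloading a large package such as transformers in place is fragile, so restarting the kernel is the safer fix.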