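RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
I only get this error when training on mps, not on cpu.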
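For context, the same message can be triggered on purpose on any device. The following is a minimal sketch of the general case (calling .view on a non-contiguous tensor); it is an illustration of what the message means, not necessarily the code path that fails on mps. The tensor names are made up for the example.

```python
import torch

# A 2x3 tensor stored contiguously in memory
x = torch.arange(6).reshape(2, 3)

# Transposing swaps the strides without moving data, so y is non-contiguous
y = x.t()

view_failed = False
try:
    y.view(-1)  # flattening would need a layout that y's strides cannot express
except RuntimeError as err:
    view_failed = True
    print(err)

# reshape returns a view when it can and copies the data when it cannot
flat = y.reshape(-1)
print(flat.tolist())
```

Here .view fails because it never copies data, while .reshape succeeds by copying, which is why the message suggests it.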
import os

import torch
from transformers import (
    Trainer,
    TrainingArguments,
    VisionEncoderDecoderModel,
    default_data_collator,
)

fine_tuned_decoder_path = "/path/fine_tuned_decoder"
model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    encoder_pretrained_model_name_or_path="google/vit-base-patch16-224-in21k",
    decoder_pretrained_model_name_or_path=fine_tuned_decoder_path,
    tie_encoder_decoder=True,
    cache_dir="/path/datasets/" + "models",  # Directory for caching models
)
os.environ["WANDB_MODE"] = "disabled"
# Set batch size and number of training epochs
BATCH_SIZE = 16
TRAIN_EPOCHS = 5
# Define the output directory for storing training outputs
output_directory = os.path.join("path", "captioning_outputs")
# Prefer MPS, then CUDA, and fall back to CPU
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# Move your model to the correct device
model.to(device)
# Disable mixed precision
fp16 = False  # Disable fp16 entirely
mixed_precision = None  # Not passed to TrainingArguments below
# Training arguments
training_args = TrainingArguments(
    output_dir=output_directory,
    per_device_train_batch_size=BATCH_SIZE,
    do_train=True,
    num_train_epochs=TRAIN_EPOCHS,
    overwrite_output_dir=True,
    use_cpu=False,  # Do not force CPU; use the available accelerator (MPS here)
    dataloader_pin_memory=False,
    fp16=fp16,  # fp16 disabled on MPS
    bf16=False,  # bf16 disabled on MPS
    optim="adamw_torch",  # PyTorch's AdamW implementation
    gradient_checkpointing=False,  # Gradient checkpointing disabled
    logging_dir=os.path.join(output_directory, "logs"),
    report_to="none",  # Disable reporting
)
# feature_extractor and train_dataset are defined earlier (not shown)
trainer = Trainer(
    processing_class=feature_extractor,  # Image processor for the ViT encoder
    model=model,  # Model to train
    args=training_args,  # Training arguments
    train_dataset=train_dataset,  # Training dataset
    data_collator=default_data_collator,  # Data collator
)
# Start the training process
trainer.train()
I tried setting use_cpu=True and training runs fine, but it fails on mps.
This is a known issue in PyTorch 2.5.
You can read more about it here:
https://github.com/pytorch/pytorch/issues/142344
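One solution is to download the PyTorch source, modify the two files as malfet describes in that issue, and build from source. The issue is expected to be fixed in PyTorch 2.6.
However, the simplest solution I found is to install a nightly build of PyTorch: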
pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu
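After installing, it may be worth confirming which version actually gets imported, since an old 2.5.x install can linger in another environment. A small sketch, assuming the problem is limited to 2.5.x as described above; the helper name is hypothetical and it only inspects the version string:

```python
def is_affected_torch_version(version: str) -> bool:
    """Return True if the version string is a 2.5.x release (assumed affected)."""
    major, minor = version.split(".")[:2]
    return (int(major), int(minor)) == (2, 5)


# Example version strings; in practice pass torch.__version__
print(is_affected_torch_version("2.5.1"))
print(is_affected_torch_version("2.6.0.dev20241210"))
```

In a real session you would call is_affected_torch_version(torch.__version__) right after importing torch.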
Best regards,
--KC