我正在使用下面的脚本训练一个自定义嵌入模型。训练数据由描述及其对应的搜索查询组成。我之前一直使用 sentence-transformers 2.2.2,但更新到 3.0.1 版本后,它建议使用
`SentenceTransformerTrainer`
对象进行训练,以配合新版的 fit 方法(参见 https://github.com/UKPLab/sentence-transformers/blob/b37f470e1625878b0f31525251db74b658a26dcb/sentence_transformers/fit_mixin.py#L435)
# Fine-tune a SentenceTransformer embedding model on (description, query) pairs
# using the sentence-transformers v3 SentenceTransformerTrainer API.
import logging
import os
from datetime import datetime

import accelerate
import pandas as pd
import sentence_transformers
import transformers
from datasets import Dataset
from sentence_transformers import SentenceTransformer, InputExample, losses, LoggingHandler
from sentence_transformers.evaluation import InformationRetrievalEvaluator
from sentence_transformers.trainer import SentenceTransformerTrainer
from sentence_transformers.training_args import SentenceTransformerTrainingArguments
from transformers.utils import is_accelerate_available

print("transformers version:", transformers.__version__)
print("accelerate version:", accelerate.__version__)
print("sentence_transformers version:", sentence_transformers.__version__)

# TrainingArguments raises "requires accelerate>=..." whenever transformers'
# is_accelerate_available() is False. transformers evaluates that once, when it is
# first imported, so installing/upgrading accelerate later in the same notebook
# session leaves a stale answer even though `import accelerate` works.
# Fail early here with an actionable message instead of deep inside TrainingArguments.
if not is_accelerate_available():
    raise ImportError(
        f"transformers {transformers.__version__} reports accelerate "
        f"{accelerate.__version__} as unavailable. Check that accelerate meets "
        "transformers' minimum version; if it does and it was installed or upgraded "
        "after transformers was first imported, restart the Python/Jupyter kernel "
        "and run this script again."
    )

# Enable logging
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])

# Load your data into a pandas dataframe
df = pd.read_csv('my_data.csv')  # Replace with your dataframe loading method

# Index columns with df["..."]: `df.query` is the DataFrame.query() method, not the
# "query" column, so iterating over it raises a TypeError.
description = df["description"].tolist()
query = df["query"].tolist()

# MultipleNegativesRankingLoss reads columns by position, not by name: the first
# column is the anchor and the second is the positive.
# NOTE(review): for query -> description retrieval, putting "query" first may be
# what is intended — confirm.
train_examples = Dataset.from_dict(
    {
        "description": description,
        "query": query,
    }
)

# Define hyperparameters
train_batch_size = 16
num_epochs = 4
model_save_path = 'output/training_sts_model_' + datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
os.makedirs(model_save_path, exist_ok=True)

# Initialize a pre-trained model
model = SentenceTransformer('distilbert-base-nli-mean-tokens')

# In-batch negatives ranking loss over the (anchor, positive) pairs above
train_loss = losses.MultipleNegativesRankingLoss(model=model)

# Define training arguments
training_args = SentenceTransformerTrainingArguments(
    output_dir=model_save_path,
    overwrite_output_dir=True,
    num_train_epochs=num_epochs,
    per_device_train_batch_size=train_batch_size,
)

# Create the SentenceTransformerTrainer
trainer = SentenceTransformerTrainer(
    model=model,
    args=training_args,
    train_dataset=train_examples,
    loss=train_loss,
)

# Train the model
trainer.train()

# trainer.train() only writes periodic checkpoints (presumably into checkpoint-*
# subfolders of output_dir); save the final model at model_save_path itself so
# the message below is accurate.
model.save(model_save_path)
print("Model training complete. Model saved to:", model_save_path)
我在
training_args = SentenceTransformerTrainingArguments()
步骤中遇到错误,如下所示:
---------------------------------------------------------------------------
ImportError Traceback (most recent call last)
Cell In[49], line 49
46 train_loss = losses.MultipleNegativesRankingLoss(model=model)
48 # Define training arguments
---> 49 training_args = SentenceTransformerTrainingArguments(
50 output_dir=model_save_path,
51 overwrite_output_dir=True,
52 num_train_epochs=num_epochs,
53 per_device_train_batch_size=train_batch_size
54 )
56 # Create the SentenceTransformerTrainer
57 trainer = SentenceTransformerTrainer(
58 model=model,
59 args=training_args,
(...)
62 evaluator=evaluator,
63 )
File <string>:133, in __init__(self, output_dir, overwrite_output_dir, do_train, do_eval, do_predict, eval_strategy, prediction_loss_only, per_device_train_batch_size, per_device_eval_batch_size, per_gpu_train_batch_size, per_gpu_eval_batch_size, gradient_accumulation_steps, eval_accumulation_steps, eval_delay, torch_empty_cache_steps, learning_rate, weight_decay, adam_beta1, adam_beta2, adam_epsilon, max_grad_norm, num_train_epochs, max_steps, lr_scheduler_type, lr_scheduler_kwargs, warmup_ratio, warmup_steps, log_level, log_level_replica, log_on_each_node, logging_dir, logging_strategy, logging_first_step, logging_steps, logging_nan_inf_filter, save_strategy, save_steps, save_total_limit, save_safetensors, save_on_each_node, save_only_model, restore_callback_states_from_checkpoint, no_cuda, use_cpu, use_mps_device, seed, data_seed, jit_mode_eval, use_ipex, bf16, fp16, fp16_opt_level, half_precision_backend, bf16_full_eval, fp16_full_eval, tf32, local_rank, ddp_backend, tpu_num_cores, tpu_metrics_debug, debug, dataloader_drop_last, eval_steps, dataloader_num_workers, dataloader_prefetch_factor, past_index, run_name, disable_tqdm, remove_unused_columns, label_names, load_best_model_at_end, metric_for_best_model, greater_is_better, ignore_data_skip, fsdp, fsdp_min_num_params, fsdp_config, fsdp_transformer_layer_cls_to_wrap, accelerator_config, deepspeed, label_smoothing_factor, optim, optim_args, adafactor, group_by_length, length_column_name, report_to, ddp_find_unused_parameters, ddp_bucket_cap_mb, ddp_broadcast_buffers, dataloader_pin_memory, dataloader_persistent_workers, skip_memory_metrics, use_legacy_prediction_loop, push_to_hub, resume_from_checkpoint, hub_model_id, hub_strategy, hub_token, hub_private_repo, hub_always_push, gradient_checkpointing, gradient_checkpointing_kwargs, include_inputs_for_metrics, eval_do_concat_batches, fp16_backend, evaluation_strategy, push_to_hub_model_id, push_to_hub_organization, push_to_hub_token, mp_parameters, 
auto_find_batch_size, full_determinism, torchdynamo, ray_scope, ddp_timeout, torch_compile, torch_compile_backend, torch_compile_mode, dispatch_batches, split_batches, include_tokens_per_second, include_num_input_tokens_seen, neftune_noise_alpha, optim_target_modules, batch_eval_metrics, eval_on_start, eval_use_gather_object, batch_sampler, multi_dataset_batch_sampler)
File ~/anaconda3/envs/python3/lib/python3.10/site-packages/sentence_transformers/training_args.py:73, in SentenceTransformerTrainingArguments.__post_init__(self)
72 def __post_init__(self):
---> 73 super().__post_init__()
75 self.batch_sampler = BatchSamplers(self.batch_sampler)
76 self.multi_dataset_batch_sampler = MultiDatasetBatchSamplers(self.multi_dataset_batch_sampler)
File ~/anaconda3/envs/python3/lib/python3.10/site-packages/transformers/training_args.py:1730, in TrainingArguments.__post_init__(self)
1728 # Initialize device before we proceed
1729 if self.framework == "pt" and is_torch_available():
-> 1730 self.device
1732 if self.torchdynamo is not None:
1733 warnings.warn(
1734 "`torchdynamo` is deprecated and will be removed in version 5 of 🤗 Transformers. Use"
1735 " `torch_compile_backend` instead",
1736 FutureWarning,
1737 )
File ~/anaconda3/envs/python3/lib/python3.10/site-packages/transformers/training_args.py:2227, in TrainingArguments.device(self)
2223 """
2224 The device used by this process.
2225 """
2226 requires_backends(self, ["torch"])
-> 2227 return self._setup_devices
File ~/anaconda3/envs/python3/lib/python3.10/site-packages/transformers/utils/generic.py:60, in cached_property.__get__(self, obj, objtype)
58 cached = getattr(obj, attr, None)
59 if cached is None:
---> 60 cached = self.fget(obj)
61 setattr(obj, attr, cached)
62 return cached
File ~/anaconda3/envs/python3/lib/python3.10/site-packages/transformers/training_args.py:2103, in TrainingArguments._setup_devices(self)
2101 if not is_sagemaker_mp_enabled():
2102 if not is_accelerate_available():
-> 2103 raise ImportError(
2104 f"Using the `Trainer` with `PyTorch` requires `accelerate>={ACCELERATE_MIN_VERSION}`: "
2105 "Please run `pip install transformers[torch]` or `pip install accelerate -U`"
2106 )
2107 # We delay the init of `PartialState` to the end for clarity
2108 accelerator_state_kwargs = {"enabled": True, "use_configured_state": False}
ImportError: Using the `Trainer` with `PyTorch` requires `accelerate>=0.21.0`: Please run `pip install transformers[torch]` or `pip install accelerate -U`
我的库版本如下:
transformers version: 4.44.2
accelerate version: 0.33.0
sentence_transformers version: 3.0.1
有人能建议我该怎么做吗?我的依赖项都已是最新版本,但仍然无法实例化
`SentenceTransformerTrainingArguments`
对象。
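您提到已安装的
`accelerate`
版本是 0.33.0
,而错误消息要求版本 >=0.21.0
。您的版本明显高于最低要求,所以这个错误看起来很奇怪。实际上,该错误是在 transformers 的 `is_accelerate_available()` 返回 False 时抛出的。transformers 在首次被导入时就会检测 accelerate 是否可用,并缓存检测结果。如果在同一个 Jupyter 会话中先导入了 transformers 或 sentence-transformers,之后才安装或升级 accelerate,这个缓存结果就已过期。回溯中的 `Cell In[49]` 表明您正在使用 notebook,所以这种情况很可能发生。另外,accelerate 也可能被安装到了与当前内核不同的环境中。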
请确保将
`accelerate`
更新到最新版本,并尝试重新安装 transformers
和 sentence-transformers
,这有助于解决版本冲突。更新后请验证已安装的版本是否满足要求。更新或重新安装库之后,务必先重启 Jupyter 内核,再运行代码。仅重新运行单元格是不够的,因为 transformers 已在内存中缓存了对 accelerate 的检测结果。
如果更新未能解决问题,请务必验证
sentence-transformers
与 transformers
是否有任何特定版本要求。有时,某些版本可能仅与其他库的特定版本兼容。