我正在尝试使用 Databricks 笔记本来微调 Llama 2 模型,代码见此处(原文链接)。我在第 219-231 行遇到错误:
from trl import SFTTrainer
max_seq_length = 512
trainer = SFTTrainer(
model=model,
train_dataset=dataset,
peft_config=peft_config,
dataset_text_field="text",
max_seq_length=max_seq_length,
tokenizer=tokenizer,
args=training_arguments,
)
我收到以下错误:
ImportError: cannot import name 'override' from 'typing_extensions' (/databricks/python/lib/python3.10/site-packages/typing_extensions.py)
完整的堆栈跟踪如下。
---------------------------------------------------------------------------
ImportError Traceback (most recent call last)
File <command-3349581073491723>, line 1
----> 1 from trl import SFTTrainer
3 max_seq_length = 512
5 trainer = SFTTrainer(
6 model=model,
7 train_dataset=dataset,
(...)
12 args=training_arguments,
13 )
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-8a237d14-25f4-4066-b47f-e8a95f2342d9/lib/python3.10/site-packages/trl/__init__.py:15
8 from .import_utils import is_diffusers_available, is_peft_available, is_wandb_available, is_xpu_available
9 from .models import (
10 AutoModelForCausalLMWithValueHead,
11 AutoModelForSeq2SeqLMWithValueHead,
12 PreTrainedModelWrapper,
13 create_reference_model,
14 )
---> 15 from .trainer import (
16 DataCollatorForCompletionOnlyLM,
17 DPOTrainer,
18 IterativeSFTTrainer,
19 PPOConfig,
20 PPOTrainer,
21 RewardConfig,
22 RewardTrainer,
23 SFTTrainer,
24 )
27 if is_diffusers_available():
28 from .models import (
29 DDPOPipelineOutput,
30 DDPOSchedulerOutput,
31 DDPOStableDiffusionPipeline,
32 DefaultDDPOStableDiffusionPipeline,
33 )
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-8a237d14-25f4-4066-b47f-e8a95f2342d9/lib/python3.10/site-packages/trl/trainer/__init__.py:40
38 from .dpo_trainer import DPOTrainer
39 from .iterative_sft_trainer import IterativeSFTTrainer
---> 40 from .ppo_config import PPOConfig
41 from .ppo_trainer import PPOTrainer
42 from .reward_trainer import RewardTrainer, compute_accuracy
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-8a237d14-25f4-4066-b47f-e8a95f2342d9/lib/python3.10/site-packages/trl/trainer/ppo_config.py:22
19 from typing import Literal, Optional
21 import numpy as np
---> 22 import tyro
23 from typing_extensions import Annotated
25 from trl.trainer.utils import exact_div
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-8a237d14-25f4-4066-b47f-e8a95f2342d9/lib/python3.10/site-packages/tyro/__init__.py:4
1 from typing import TYPE_CHECKING
3 from . import conf as conf
----> 4 from . import extras as extras
5 from ._cli import cli as cli
6 from ._fields import MISSING as MISSING
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-8a237d14-25f4-4066-b47f-e8a95f2342d9/lib/python3.10/site-packages/tyro/extras/__init__.py:5
1 """The :mod:`tyro.extras` submodule contains helpers that complement :func:`tyro.cli()`.
2
3 Compared to the core interface, APIs here are more likely to be changed or deprecated. """
----> 5 from .._argparse_formatter import set_accent_color as set_accent_color
6 from .._cli import get_parser as get_parser
7 from ._base_configs import (
8 subcommand_type_from_defaults as subcommand_type_from_defaults,
9 )
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-8a237d14-25f4-4066-b47f-e8a95f2342d9/lib/python3.10/site-packages/tyro/_argparse_formatter.py:37
35 from rich.text import Text
36 from rich.theme import Theme
---> 37 from typing_extensions import override
39 from . import _arguments, _strings, conf
40 from ._parsers import ParserSpecification
ImportError: cannot import name 'override' from 'typing_extensions' (/databricks/python/lib/python3.10/site-packages/typing_extensions.py)
我尝试安装了多个版本的 typing_extensions,包括最新版本(4.8.0)和 4.7.1(后者是这篇 Stack Overflow 帖子中建议的版本)。我还尝试了此处(原文链接)发布的解决方案,并按照此处(原文链接)的建议,使用“%”而不是“!”来安装依赖项。这些方法都没有奏效。
这是我安装的软件包的完整列表:
Package Version
---------------------------- -------------
absl-py 1.0.0
accelerate 0.25.0.dev0
aiohttp 3.8.5
aiosignal 1.3.1
anyio 3.5.0
appdirs 1.4.4
argon2-cffi 21.3.0
argon2-cffi-bindings 21.2.0
astor 0.8.1
asttokens 2.2.1
astunparse 1.6.3
async-timeout 4.0.3
attrs 21.4.0
audioread 3.0.0
azure-core 1.29.1
azure-cosmos 4.3.1
azure-storage-blob 12.17.0
azure-storage-file-datalake 12.12.0
backcall 0.2.0
bcrypt 3.2.0
beautifulsoup4 4.11.1
bitsandbytes 0.41.2.post2
black 22.6.0
bleach 4.1.0
blinker 1.4
blis 0.7.10
boto3 1.24.28
botocore 1.27.28
cachetools 4.2.4
catalogue 2.0.9
category-encoders 2.6.1
certifi 2022.9.14
cffi 1.15.1
chardet 4.0.0
charset-normalizer 2.0.4
click 8.0.4
cloudpickle 2.0.0
cmdstanpy 1.1.0
confection 0.1.1
configparser 5.2.0
convertdate 2.4.0
cryptography 37.0.1
cycler 0.11.0
cymem 2.0.7
Cython 0.29.32
dacite 1.8.1
databricks-automl-runtime 0.2.17
databricks-cli 0.17.7
databricks-feature-store 0.14.1
databricks-sdk 0.1.6
dataclasses-json 0.5.14
datasets 2.13.1
dbl-tempo 0.1.23
dbus-python 1.2.18
debugpy 1.6.0
decorator 5.1.1
defusedxml 0.7.1
dill 0.3.4
diskcache 5.6.1
distlib 0.3.7
distro 1.7.0
distro-info 1.1+ubuntu0.1
docstring-parser 0.15
docstring-to-markdown 0.12
einops 0.6.1
entrypoints 0.4
ephem 4.1.4
evaluate 0.4.0
executing 1.2.0
facets-overview 1.0.3
fastapi 0.98.0
fastjsonschema 2.18.0
fasttext 0.9.2
filelock 3.6.0
flash-attn 1.0.7
Flask 1.1.2+db1
flatbuffers 23.5.26
fonttools 4.25.0
frozenlist 1.4.0
fsspec 2022.7.1
future 0.18.2
gast 0.4.0
gitdb 4.0.10
GitPython 3.1.27
google-api-core 2.8.2
google-auth 1.33.0
google-auth-oauthlib 0.4.6
google-cloud-core 2.3.3
google-cloud-storage 2.10.0
google-crc32c 1.5.0
google-pasta 0.2.0
google-resumable-media 2.5.0
googleapis-common-protos 1.56.4
greenlet 1.1.1
grpcio 1.48.1
grpcio-status 1.48.1
gunicorn 20.1.0
gviz-api 1.10.0
h11 0.14.0
h5py 3.7.0
holidays 0.27.1
horovod 0.28.1
htmlmin 0.1.12
httplib2 0.20.2
httptools 0.6.0
huggingface-hub 0.16.4
idna 3.3
ImageHash 4.3.1
imbalanced-learn 0.10.1
importlib-metadata 4.11.3
importlib-resources 6.0.1
ipykernel 6.17.1
ipython 8.10.0
ipython-genutils 0.2.0
ipywidgets 7.7.2
isodate 0.6.1
itsdangerous 2.0.1
jedi 0.18.1
jeepney 0.7.1
Jinja2 2.11.3
jmespath 0.10.0
joblib 1.2.0
joblibspark 0.5.1
jsonschema 4.16.0
jupyter-client 7.3.4
jupyter_core 4.11.2
jupyterlab-pygments 0.1.2
jupyterlab-widgets 1.0.0
keras 2.11.0
keyring 23.5.0
kiwisolver 1.4.2
langchain 0.0.217
langchainplus-sdk 0.0.20
langcodes 3.3.0
launchpadlib 1.10.16
lazr.restfulclient 0.14.4
lazr.uri 1.0.6
lazy_loader 0.3
libclang 15.0.6.1
librosa 0.10.0
lightgbm 3.3.5
llvmlite 0.38.0
LunarCalendar 0.0.9
Mako 1.2.0
Markdown 3.3.4
markdown-it-py 3.0.0
MarkupSafe 2.0.1
marshmallow 3.20.1
matplotlib 3.5.2
matplotlib-inline 0.1.6
mccabe 0.7.0
mdurl 0.1.2
mistune 0.8.4
mleap 0.20.0
mlflow-skinny 2.5.0
more-itertools 8.10.0
msgpack 1.0.5
multidict 6.0.4
multimethod 1.9.1
multiprocess 0.70.12.2
murmurhash 1.0.9
mypy-extensions 0.4.3
nbclient 0.5.13
nbconvert 6.4.4
nbformat 5.5.0
nest-asyncio 1.5.5
networkx 2.8.4
ninja 1.11.1
nltk 3.7
nodeenv 1.8.0
notebook 6.4.12
numba 0.55.1
numexpr 2.8.4
numpy 1.21.5
oauthlib 3.2.0
openai 0.27.8
openapi-schema-pydantic 1.2.4
opt-einsum 3.3.0
packaging 21.3
pandas 1.4.4
pandocfilters 1.5.0
paramiko 2.9.2
parso 0.8.3
pathspec 0.9.0
pathy 0.10.2
patsy 0.5.2
peft 0.4.0
petastorm 0.12.1
pexpect 4.8.0
phik 0.12.3
pickleshare 0.7.5
Pillow 9.2.0
pip 23.3.1
platformdirs 2.5.2
plotly 5.9.0
pluggy 1.0.0
pmdarima 2.0.3
pooch 1.7.0
preshed 3.0.8
prompt-toolkit 3.0.36
prophet 1.1.4
protobuf 3.19.4
psutil 5.9.0
psycopg2 2.9.3
ptyprocess 0.7.0
pure-eval 0.2.2
pyarrow 8.0.0
pyasn1 0.4.8
pyasn1-modules 0.2.8
pybind11 2.11.1
pycparser 2.21
pydantic 1.10.6
pyflakes 3.0.1
Pygments 2.16.1
PyGObject 3.42.1
PyJWT 2.3.0
PyMeeus 0.5.12
PyNaCl 1.5.0
pyodbc 4.0.32
pyparsing 3.0.9
pyright 1.1.294
pyrsistent 0.18.0
pytesseract 0.3.10
python-apt 2.4.0+ubuntu2
python-dateutil 2.8.2
python-dotenv 1.0.0
python-editor 1.0.4
python-lsp-jsonrpc 1.0.0
python-lsp-server 1.7.1
pytoolconfig 1.2.2
pytz 2022.1
PyWavelets 1.3.0
PyYAML 6.0
pyzmq 23.2.0
regex 2022.7.9
requests 2.28.1
requests-oauthlib 1.3.1
responses 0.18.0
rich 13.6.0
rope 1.7.0
rsa 4.9
s3transfer 0.6.0
safetensors 0.3.2
scikit-learn 1.1.1
scipy 1.9.1
seaborn 0.11.2
SecretStorage 3.3.1
Send2Trash 1.8.0
sentence-transformers 2.2.2
sentencepiece 0.1.99
setuptools 63.4.1
shap 0.41.0
shtab 1.6.4
simplejson 3.17.6
six 1.16.0
slicer 0.0.7
smart-open 5.2.1
smmap 5.0.0
sniffio 1.2.0
soundfile 0.12.1
soupsieve 2.3.1
soxr 0.3.6
spacy 3.5.3
spacy-legacy 3.0.12
spacy-loggers 1.0.4
spark-tensorflow-distributor 1.0.0
SQLAlchemy 1.4.39
sqlparse 0.4.2
srsly 2.4.7
ssh-import-id 5.11
stack-data 0.6.2
starlette 0.27.0
statsmodels 0.13.2
tabulate 0.8.10
tangled-up-in-unicode 0.2.0
tenacity 8.1.0
tensorboard 2.11.0
tensorboard-data-server 0.6.1
tensorboard-plugin-profile 2.11.2
tensorboard-plugin-wit 1.8.1
tensorflow 2.11.1
tensorflow-estimator 2.11.0
tensorflow-io-gcs-filesystem 0.33.0
termcolor 2.3.0
terminado 0.13.1
testpath 0.6.0
thinc 8.1.12
threadpoolctl 2.2.0
tiktoken 0.4.0
tokenize-rt 4.2.1
tokenizers 0.13.3
tomli 2.0.1
torch 1.13.1+cu117
torchvision 0.14.1+cu117
tornado 6.1
tqdm 4.64.1
traitlets 5.1.1
transformers 4.30.2
trl 0.7.4
typeguard 2.13.3
typer 0.7.0
typing_extensions 4.7.1
typing-inspect 0.9.0
tyro 0.5.14
ujson 5.4.0
unattended-upgrades 0.1
urllib3 1.26.11
uvicorn 0.23.2
uvloop 0.17.0
virtualenv 20.16.3
visions 0.7.5
wadllib 1.3.6
wasabi 1.1.2
watchfiles 0.19.0
wcwidth 0.2.5
webencodings 0.5.1
websocket-client 0.58.0
websockets 11.0.3
Werkzeug 2.0.3
whatthepatch 1.0.2
wheel 0.37.1
widgetsnbextension 3.6.1
wordcloud 1.9.2
wrapt 1.14.1
xgboost 1.7.6
xxhash 3.3.0
yapf 0.31.0
yarl 1.9.2
ydata-profiling 4.2.0
zipp 3.8.0
如果有人知道如何解决此问题,请告诉我!
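这可能与 Databricks 的一个 bug 有关,请参阅:https://github.com/openai/openai-python/issues/751 。值得注意的是,报错信息中的路径位于 /databricks/python/lib/python3.10/site-packages/ 下,而不是笔记本自身的 pythonEnv 环境中,这说明实际被导入的可能是集群自带的 typing_extensions,而非你用 pip 安装的版本(此为推测,请自行确认)。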