I load a Huggingface-transformers float32 model, convert it to float16, and save it. How can I load it back as float16?
Example:
# pip install transformers
from transformers import AutoModelForTokenClassification, AutoTokenizer
# Load model
model_path = 'huawei-noah/TinyBERT_General_4L_312D'
model = AutoModelForTokenClassification.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)
# Convert the model to FP16
model.half()
# Check model dtype
def print_model_layer_dtype(model):
    print('\nModel dtypes:')
    for name, param in model.named_parameters():
        print(f"Parameter: {name}, Data type: {param.dtype}")
print_model_layer_dtype(model)
save_directory = 'temp_model_SE'
model.save_pretrained(save_directory)
model2 = AutoModelForTokenClassification.from_pretrained(save_directory, local_files_only=True)
print('\n\n##################')
print(model2)
print_model_layer_dtype(model2)
In this example, model2 is loaded as a float32 model (as print_model_layer_dtype(model2) shows), even though it was saved as float16 (as config.json shows). What is the correct way to load it as float16?
Tested with transformers==4.36.2 and Python 3.11.7 on Windows 10.
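Pass torch_dtype='auto' to from_pretrained(). Without it, the weights are loaded into a float32 model by default; with torch_dtype='auto', the model is loaded in the dtype recorded when it was saved (the torch_dtype entry in config.json). Example: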
model2 = AutoModelForTokenClassification.from_pretrained(save_directory,
                                                         local_files_only=True,
                                                         torch_dtype='auto')
Full example:
# pip install transformers
from transformers import AutoModelForTokenClassification, AutoTokenizer
import torch
# Load model
model_path = 'huawei-noah/TinyBERT_General_4L_312D'
model = AutoModelForTokenClassification.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)
# Convert the model to FP16
model.half()
# Check model dtype
def print_model_layer_dtype(model):
    print('\nModel dtypes:')
    for name, param in model.named_parameters():
        print(f"Parameter: {name}, Data type: {param.dtype}")
print_model_layer_dtype(model)
save_directory = 'temp_model_SE'
model.save_pretrained(save_directory)
model2 = AutoModelForTokenClassification.from_pretrained(save_directory, local_files_only=True, torch_dtype='auto')
print('\n\n##################')
print(model2)
print_model_layer_dtype(model2)
This loads model2 as torch.float16.