How to load a Hugging Face reranker model in 16-bit?

from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import CrossEncoderReranker
from langchain_community.cross_encoders import HuggingFaceCrossEncoder

model = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-base")
compressor = CrossEncoderReranker(model=model, top_n=4)
compression_retriever = ContextualCompressionRetriever(
    base_compressor=compressor,
    base_retriever=retriever,  # `retriever` is defined elsewhere in my code
)

The code above is how I use HuggingFaceCrossEncoder for reranking, but I want to load the model in 16-bit rather than 32-bit. Is there a way to load this model in 16-bit or 8-bit?

model = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-base", use_fp16=True)

I tried the code above, but it raised the following error: ValidationError: 1 validation error for HuggingFaceCrossEncoder

pytorch huggingface-transformers langchain sentence-transformers
1 Answer

The HuggingFaceCrossEncoder class takes a model_kwargs parameter that is passed through to the __init__ of the underlying CrossEncoder class, which in turn accepts an argument called automodel_args for forwarding keyword arguments to Hugging Face's from_pretrained method. The argument you want to pass there is torch_dtype:

import torch
from langchain.retrievers.document_compressors import CrossEncoderReranker
from langchain_community.cross_encoders import HuggingFaceCrossEncoder

# Default load: weights come in as float32
model = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-base")
print({x.dtype for x in model.client.model.parameters()})

# model_kwargs -> CrossEncoder.__init__ -> automodel_args -> from_pretrained
model2 = HuggingFaceCrossEncoder(
    model_name="BAAI/bge-reranker-base",
    model_kwargs={"automodel_args": {"torch_dtype": torch.float16}},
)
print({x.dtype for x in model2.client.model.parameters()})

Output:

{torch.float32}
{torch.float16}
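For the 8-bit half of the question, the same automodel_args hook should work: from_pretrained accepts a quantization_config, so a BitsAndBytesConfig can be forwarded through it. A minimal sketch under the assumption that bitsandbytes is installed and a CUDA GPU is available (8-bit loading is GPU-only); I have not verified this with this exact reranker:

import torch
from transformers import BitsAndBytesConfig
from langchain_community.cross_encoders import HuggingFaceCrossEncoder

# Assumption: quantization_config is forwarded through automodel_args, the same
# path used for torch_dtype above. Requires `pip install bitsandbytes` and CUDA.
model_8bit = HuggingFaceCrossEncoder(
    model_name="BAAI/bge-reranker-base",
    model_kwargs={
        "device": "cuda",
        "automodel_args": {
            "quantization_config": BitsAndBytesConfig(load_in_8bit=True),
        },
    },
)
# The linear-layer weights should now show up as torch.int8 (other parameters
# such as layer norms stay in a floating-point dtype).
print({x.dtype for x in model_8bit.client.model.parameters()})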