我正在尝试使用 Huggingface 模型生成 300 个句子的句子嵌入,但我的代码总是卡住,执行无法继续。
我在代码片段的不同位置加入了调试语句,但没有得出任何结论。相反,当我按顺序(不使用多进程)嵌入句子时,一切都运行正常。
有人可以建议问题出在哪里吗?
import torch
import multiprocessing as mp
from transformers import AutoTokenizer, AutoModel
def encode_sentence(sentence, model, tokenizer, output_list):
    """Embed one sentence and append the result to *output_list*.

    The sentence is tokenized, run through *model* with gradient tracking
    disabled, and the hidden state of the first ([CLS]) token is appended
    to *output_list*.
    """
    tokens = tokenizer(sentence, return_tensors="pt")
    with torch.no_grad():
        model_output = model(**tokens)
        cls_embedding = model_output.last_hidden_state[:, 0, :]
    output_list.append(cls_embedding)
def worker(input_queue, output_list, model, tokenizer):
    """Consume sentences from *input_queue* until a ``None`` sentinel arrives.

    Each sentence is embedded via :func:`encode_sentence`; embeddings
    accumulate in the shared *output_list*.

    Bug fix: the original wrapped the loop body in ``except queue.Empty``
    without ever importing ``queue`` — a NameError waiting to happen — and
    a blocking ``Queue.get()`` (no timeout, ``block=True``) never raises
    ``Empty`` in the first place, so the handler was dead code.  It has
    been removed.
    """
    while True:
        # Blocks until an item is available; never raises queue.Empty.
        sentence = input_queue.get()
        if sentence is None:  # sentinel: no more work, shut this worker down
            break
        encode_sentence(sentence, model, tokenizer, output_list)
if __name__ == "__main__":
    # NOTE(review): this is the code that hangs.  Likely culprits — to confirm:
    #  - the full model and tokenizer are pickled and shipped to every worker
    #    process via the Process args,
    #  - with a fork start method, locks/threads already held by torch or the
    #    HF fast tokenizer can be inherited in a bad state and deadlock,
    #  - appending tensors to a Manager().list() proxies every result through
    #    a separate manager server process.
    model_name = "distilbert-base-uncased"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name)
    num_workers = mp.cpu_count()  # one worker per CPU core
    input_queue = mp.Queue()  # work items: 300 sentences, then one sentinel per worker
    output_list = mp.Manager().list()  # cross-process result container (proxied)
    workers = [mp.Process(target=worker, args=(input_queue, output_list, model, tokenizer)) for _ in range(num_workers)]
    for w in workers:
        w.start()
    sentences = ["This is sentence {}".format(i) for i in range(300)]
    for sentence in sentences:
        input_queue.put(sentence)
    # One None sentinel per worker so every process exits its receive loop.
    for _ in range(num_workers):
        input_queue.put(None)
    for w in workers:
        w.join()
    print(output_list)
为此不需要队列和托管列表。它只会导致代码膨胀,没有明显的好处。
使用 mp.Pool 和 map_async 你可以这样做:
import torch
import multiprocessing as mp
from transformers import AutoTokenizer, AutoModel
MODEL_NAME = "distilbert-base-uncased"
def worker(sentence):
    """Return the [CLS]-token embedding for *sentence*.

    Relies on the process-global ``tokenizer`` and ``model`` that the pool
    initializer (``ipp``) installs in each worker process.
    """
    global tokenizer, model
    inputs = tokenizer(sentence, return_tensors="pt")
    with torch.no_grad():
        hidden = model(**inputs).last_hidden_state
        return hidden[:, 0, :]
def ipp(t, m):
    """Pool initializer: publish *t* and *m* as process-level globals.

    Runs once in each pool worker so that ``worker`` can reach the copied
    tokenizer and model without re-pickling them for every task.
    """
    global tokenizer, model
    tokenizer, model = t, m
if __name__ == "__main__":
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    # Bug fix: the original passed the undefined name ``model_name`` here
    # (NameError at startup); the module constant is MODEL_NAME.
    model = AutoModel.from_pretrained(MODEL_NAME)
    # Each pool worker gets its own copy of tokenizer/model via the
    # initializer, so tasks only need to pickle the sentence strings.
    with mp.Pool(initializer=ipp, initargs=(tokenizer, model)) as pool:
        sentences = [f"This is sentence {i}" for i in range(300)]
        async_result = pool.map_async(worker, sentences)
        # .get() already returns a list in input order; the original's
        # ``[e for e in r.get()]`` was a redundant copy.
        output = list(async_result.get())
        print(output)
无法判断这在您的虚拟机上是否会有不同的行为