MBart50 struggles with translating long texts/documents?


I'm new to NLP and MBART, so I apologize if my question sounds silly. I'm having trouble using MBart50 to translate long Korean texts into English.

I've noticed that it works well on shorter texts (e.g., a single sentence), but on longer texts such as news articles it always fails with an "index out of range in self" error.

Here is my code:

from transformers import MBartForConditionalGeneration, MBart50Tokenizer
import streamlit as st
import csv


@st.cache_resource
def download_model():
    model_name = "facebook/mbart-large-50-many-to-many-mmt"
    model = MBartForConditionalGeneration.from_pretrained(model_name)
    tokenizer = MBart50Tokenizer.from_pretrained(model_name, src_lang="ko_KR")
    return model, tokenizer


model, tokenizer = download_model()

model_name = "facebook/mbart-large-50-many-to-many-mmt"
tokenizer.src_lang = "ko_KR"

with open('Korean_Translation.csv', 'w', newline='', encoding='UTF-8') as korean_translation:
    translation_writer = csv.writer(korean_translation)

    with open('original_text.txt', mode='r', encoding='UTF-8') as korean_original:
        original_lines = korean_original.readlines()
        for lines in original_lines:
            print(lines)
            encoded_korean_text = tokenizer(lines, return_tensors="pt")
            generated_tokens = model.generate(**encoded_korean_text,
                                              forced_bos_token_id=tokenizer.lang_code_to_id["en_XX"],
                                              max_length=99999999999999)
            out2 = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
            print(out2)
            translation_writer.writerow(out2)

The error it gives me looks like this:

2023-03-10 14:15:04.182 Uncaught app exception
Traceback (most recent call last):
  File "E:\Python 3.10.5\lib\site-packages\streamlit\runtime\scriptrunner\script_runner.py", line 565, in _run_script
    exec(code, module.__dict__)
  File "D:\Study\NLP\Multilingual_news_analysis\pythonProject\test.py", line 36, in <module>
    generated_tokens = model.generate(**encoded_korean_text,
  File "E:\Python 3.10.5\lib\site-packages\torch\autograd\grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "E:\Python 3.10.5\lib\site-packages\transformers\generation\utils.py", line 1252, in generate
    model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
  File "E:\Python 3.10.5\lib\site-packages\transformers\generation\utils.py", line 617, in _prepare_encoder_decoder_kwargs_for_generation
    model_kwargs["encoder_outputs"]: ModelOutput = encoder(**encoder_kwargs)
  File "E:\Python 3.10.5\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "E:\Python 3.10.5\lib\site-packages\transformers\models\mbart\modeling_mbart.py", line 794, in forward
    embed_pos = self.embed_positions(input)
  File "E:\Python 3.10.5\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "E:\Python 3.10.5\lib\site-packages\transformers\models\mbart\modeling_mbart.py", line 133, in forward
    return super().forward(positions + self.offset)
  File "E:\Python 3.10.5\lib\site-packages\torch\nn\modules\sparse.py", line 160, in forward
    return F.embedding(
  File "E:\Python 3.10.5\lib\site-packages\torch\nn\functional.py", line 2210, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
IndexError: index out of range in self

Why does this happen? Is it because the text is too long (roughly 600 characters)? It does not happen with shorter texts (< 200 characters). How can I fix this? Thanks!

deep-learning nlp huggingface-transformers machine-translation
2 Answers
3 votes

mBART50 has a maximum input length of 1024 subwords and uses learned position embeddings, so when the input sequence is longer than that threshold, there is simply no embedding for the extra positions. You can see in the stack trace that the error occurs in the encoder, when self.embed_positions is called.
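A quick way to confirm this is the cause (a minimal sketch, assuming the model, tokenizer, and the lines variable from the question are already in scope) is to compare the encoded input length against the size of the model's position-embedding table:

# Count the tokens the encoder will actually see for one line of input.
enc = tokenizer(lines, return_tensors="pt")
num_tokens = enc["input_ids"].shape[1]

# mBART50 stores the limit in config.max_position_embeddings (1024);
# the encoder raises IndexError once num_tokens exceeds it.
print(num_tokens, model.config.max_position_embeddings)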

You can split the text into shorter pieces that are still meaningful on their own. In the worst case, you can enable truncation to the maximum length when tokenizing, as sketched below.
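A sketch of both options, reusing the variables from the question's code. The regex-based sentence split is a crude stand-in for a proper Korean sentence splitter, and max_length=1024 reflects the limit described above:

import re

# Option 1: split a long line at sentence-final punctuation (a crude
# heuristic; a dedicated Korean sentence splitter would do better)
# and translate each piece separately.
sentences = re.split(r"(?<=[.!?\u3002])\s+", lines)

# Option 2 (worst case): let the tokenizer cut off anything past the limit.
encoded_korean_text = tokenizer(
    lines,
    return_tensors="pt",
    truncation=True,  # drop tokens beyond max_length instead of failing
    max_length=1024,  # mBART50's position-embedding limit
)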

Something similar can happen in the decoder: when you set max_length to more than 1024, the decoder can also run past its position embeddings.
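So the max_length=99999999999999 in the question's generate() call is itself a problem. A minimal sketch of a safe call, keeping the rest of the question's arguments:

generated_tokens = model.generate(
    **encoded_korean_text,
    forced_bos_token_id=tokenizer.lang_code_to_id["en_XX"],
    max_length=1024,  # keep the decoder within its learned position embeddings
)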


0 votes

I ran into the same kind of problem when translating from English to Hindi.
