How do I save a transformers GPT-2 checkpoint to continue training?


I am retraining a GPT-2 language model, following this blog post:

https://towardsdatascience.com/train-gpt-2-in-your-own-language-fc6ad4d60171

There they train a network on top of GPT-2, and I am trying to recreate the same setup. However, my dataset is quite large (250 MB), so I want to train in intervals and pick up where I left off. In other words, I want to checkpoint the model during training. How can I do that?

tensorflow nlp gpt-2
1 Answer
from transformers import Trainer, TrainingArguments

# Checkpoints are written to output_dir during training; the Trainer
# saves them automatically according to its save settings.
training_args = TrainingArguments(
    output_dir=model_checkpoint,
    # other hyper-params
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_set,
    eval_dataset=dev_set,
    tokenizer=tokenizer
)

trainer.train()
# Save the final model (and config) to output_dir
trainer.save_model()

from transformers import AutoModelForCausalLM

def prepare_model(tokenizer, model_name_path):
    # model_name_path can be a hub model name or a local checkpoint directory
    model = AutoModelForCausalLM.from_pretrained(model_name_path)
    model.resize_token_embeddings(len(tokenizer))
    return model

# Assuming tokenizer is defined, you can simply pass the saved model directory path.
model = prepare_model(tokenizer, model_checkpoint)
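To train in intervals as the question asks, you can also let the Trainer write periodic checkpoints and resume from the latest one with `resume_from_checkpoint`. A minimal sketch, assuming `model`, `tokenizer`, `train_set`, `dev_set`, and `model_checkpoint` are defined as above (the specific `save_steps` / `save_total_limit` values here are illustrative):

```python
from transformers import Trainer, TrainingArguments

# Sketch: periodic checkpointing plus resuming an interrupted run.
# Trainer writes checkpoint-<step> subdirectories under output_dir.
training_args = TrainingArguments(
    output_dir=model_checkpoint,
    save_steps=500,        # write a checkpoint every 500 optimizer steps
    save_total_limit=2,    # keep only the two most recent checkpoints
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_set,
    eval_dataset=dev_set,
    tokenizer=tokenizer,
)

# resume_from_checkpoint=True picks up the most recent checkpoint-*
# directory inside output_dir; you can also pass an explicit path
# such as "<output_dir>/checkpoint-500".
trainer.train(resume_from_checkpoint=True)
```

With this setup, optimizer and scheduler state are restored along with the weights, so resumed training continues from the saved step rather than restarting from scratch.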