预训练 T5 模型的 CNN DailyMail 数据集上的 rouge 指标不佳

问题描述 投票:0回答:0

我正在尝试使用以下代码在 CNN/DailyMail 数据集上微调预训练的 T5 模型:

import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

from datasets import load_dataset
from transformers import DefaultDataCollator
from transformers import TrainingArguments, Trainer
from transformers import T5Tokenizer, T5ForConditionalGeneration

import os
import evaluate

tokenizer = T5Tokenizer.from_pretrained("t5-small")
rouge = evaluate.load('rouge')

def process_data_to_model_inputs(batch):
    encoder_max_length = 512
    decoder_max_length = 128

    # tokenize the inputs and labels
    inputs = tokenizer(batch["article"], padding="max_length", truncation=True, max_length=encoder_max_length)
    outputs = tokenizer(batch["highlights"], padding="max_length", truncation=True, max_length=decoder_max_length)

    batch["input_ids"] = inputs.input_ids
    batch["attention_mask"] = inputs.attention_mask
    batch["decoder_input_ids"] = outputs.input_ids
    batch["decoder_attention_mask"] = outputs.attention_mask
    batch["labels"] = outputs.input_ids.copy()

    batch["labels"] = [[-100 if token == tokenizer.pad_token_id else token for token in labels] for labels in batch["labels"]]

    return batch

def setup_distributed_environment():
    dist.init_process_group(backend='nccl')

    torch.manual_seed(42)

def generate_summary(batch, model): 
  inputs = tokenizer(batch["article"], padding="max_length", truncation=True, max_length=512, return_tensors="pt") 
  inputs = inputs.to(model.device) # Ensure that tensors are on the same device as the model 
  summary_ids = model.generate(inputs.input_ids, num_beams=4, max_length=128, early_stopping=True) 
  batch["predicted_highlights"] = tokenizer.batch_decode(summary_ids, skip_special_tokens=True) 
  return batch

def train():
    setup_distributed_environment()
    
    cnndm = load_dataset("cnn_dailymail", "3.0.0")

    tokenized_cnndm = cnndm.map(
        process_data_to_model_inputs, 
        batched=True, 
        remove_columns=cnndm["train"].column_names
    )
    
    model = T5ForConditionalGeneration.from_pretrained("t5-small")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    local_rank = int(os.environ["LOCAL_RANK"])
    global_rank = int(os.environ["RANK"])
    model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])

    training_args = TrainingArguments(
        output_dir="./updated_squad_fine_tuned_model",
        evaluation_strategy="epoch",
        learning_rate=5.6e-05,
        lr_scheduler_type="linear",
        warmup_ratio=0.1,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        num_train_epochs=2,
        weight_decay=0.01,
        local_rank=local_rank,
        fp16=True,
        remove_unused_columns=False
    )

    data_collator = DefaultDataCollator()

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_cnndm["train"].select(range(50000)),
        eval_dataset=tokenized_cnndm["validation"].select(range(10000)),
        tokenizer=tokenizer,
        data_collator=data_collator,
    )

    trainer.train()

    if local_rank == 0:
        model.module.save_pretrained("fine_tuned_squad_model")
        tokenizer.save_pretrained("fine_tuned_squad_model")

    results = cnndm["test"].select(range(5000)).map(lambda batch: generate_summary(batch, model.module), batched=True, remove_columns=["article"], batch_size=16)

    # Compute the metric using the generated summaries and the reference summaries
    rouge_score = rouge.compute(predictions=results["predicted_highlights"], references=results["highlights"])

    print(rouge_score)

def main():
    torch.cuda.empty_cache()
    
    train()

if __name__ == '__main__':
    main()

我没有在整个数据集上运行它,而是采用前 50k 个训练示例和 10k 个验证示例。训练后,我使用前 10k 个测试示例来计算 rouge 指标。

我正在使用 huggingface 变形金刚库中的

t5-small
变体。我也在使用分布式设置,使用以下命令在 4 个节点中运行程序:

torchrun --nproc_per_node=gpu --nnodes=4 --node_rank=0 --rdzv_id=456 --rdzv_backend=c10d --rdzv_endpoint=129.82.44.119:30375 cnn_hf_test.py

训练后,我得到以下输出:

{'loss': 2.1367, 'learning_rate': 4.258706467661692e-05, 'epoch': 0.64}                                                                                                                                            
{'eval_runtime': 8.023, 'eval_samples_per_second': 1246.419, 'eval_steps_per_second': 19.569, 'epoch': 1.0}                                                                                                        
{'loss': 0.0305, 'learning_rate': 2.2686567164179102e-05, 'epoch': 1.28}                                                                                                                                           
{'loss': 0.0172, 'learning_rate': 2.7860696517412936e-06, 'epoch': 1.92}                                                                                                                                           
{'eval_runtime': 8.0265, 'eval_samples_per_second': 1245.871, 'eval_steps_per_second': 19.56, 'epoch': 2.0}                                                                                                        
{'train_runtime': 5110.103, 'train_samples_per_second': 19.569, 'train_steps_per_second': 0.306, 'train_loss': 0.6989707885800726, 'epoch': 2.0}                                                                   
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1564/1564 [1:25:08<00:00,  3.27s/it]
{'rouge1': 0.008768024824095142, 'rouge2': 0.000294696538416436, 'rougeL': 0.008527464153847374, 'rougeLsum': 0.00875863140146953}                                                                                 
WARNING:torch.distributed.elastic.rendezvous.dynamic_rendezvous:The node 'jupiter.cs.colostate.edu_805773_0' has failed to send a keep-alive heartbeat to the rendezvous '456' due to an error of type RendezvousTimeoutError.

根据我的理解,rouge 分数指标非常差,对于 ROGUE-1,它至少应该大于

0.2
,但我得到
0.008
.

我的集群设置不允许我加载更大的模型,如

t5-base
t5-large
.

你能给我一些改进 rouge score 指标的建议吗?或者这种设置和模型是否期望这种性能?非常感谢任何见解。

python deep-learning pytorch huggingface-transformers
© www.soinside.com 2019 - 2024. All rights reserved.