Llama + LoRA: training loss drops straight to 0 on the full dataset (~14k), but looks fine on sample data (10 samples)


I am trying to fine-tune a LLaMA model with low-rank adaptation (LoRA), based on HuggingFace.

When I train the model on the full dataset (~14k examples), the training loss drops to 0 and stays at 0 from epoch 2 onward. [plots: train loss (full), eval loss (full)]

But when I train the model on just 10 samples, the loss curves look reasonable. [plots: train loss (sample), eval loss (sample)]

I attach the relevant config and training code below.

Config

num_epochs: int = 5,  # 40 for the 10-sample run
batch_size: int = 128,
eval_save_steps: int = 10,
micro_batch_size: int = 4,
learning_rate: float = 3e-4,
weight_decay: float = 0.001,
optim: str = "adamw_torch",
bf16: bool = True,
# lora hyperparams
lora_r: int = 8,
lora_alpha: int = 16,
lora_dropout: float = 0.05,
lora_target_modules: List[str] = ["q_proj", "v_proj"],
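
For reference, with these values the effective batch size per optimizer step works out to 128; a quick check of the arithmetic, using the gradient_accumulation_steps expression from the training code further down:

batch_size = 128
micro_batch_size = 4
gradient_accumulation_steps = batch_size // micro_batch_size           # 32
effective_batch_size = micro_batch_size * gradient_accumulation_steps
print(gradient_accumulation_steps, effective_batch_size)               # 32 128
# With ~14k training examples, that is roughly 14000 / 128 ≈ 110 optimizer steps per epoch.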

Training code:

import torch
from transformers import Trainer


def loss_fn(logits, labels):
    # Shift so that position n predicts token n + 1.
    logits = logits[..., :-1, :].contiguous()
    labels = labels[..., 1:].contiguous()

    # Flatten to (batch * seq_len, vocab_size) and (batch * seq_len,).
    logits = logits.view(-1, logits.size(-1))
    labels = labels.view(-1)

    # ignore_index=-100 skips masked/padded label positions
    # (note: ignore_index=-1 when mask_input=True, otherwise the default -100).
    loss = torch.nn.functional.cross_entropy(logits, labels, ignore_index=-100)
    return loss
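
This shift-and-ignore logic is the same thing LlamaForCausalLM does internally when labels are passed, so loss_fn can be sanity-checked against the model's own outputs.loss. A minimal standalone sketch (GPT-2 is used here purely as a stand-in; any HF causal LM applies the same internal label shift with ignore_index=-100):

from transformers import AutoModelForCausalLM, AutoTokenizer

name = "gpt2"  # stand-in model so the check runs quickly on CPU
tok = AutoTokenizer.from_pretrained(name)
lm = AutoModelForCausalLM.from_pretrained(name)

enc = tok("a quick sanity check", return_tensors="pt")
labels = enc["input_ids"].clone()

out = lm(**enc, labels=labels)
print(out.loss)                      # loss the model computes itself (it shifts labels internally)
print(loss_fn(out.logits, labels))   # should match out.loss up to floating point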


class MyTrainer(Trainer):
    # Override the default loss computation so that loss_fn above is applied to the logits.
    def compute_loss(self, model, inputs, return_outputs=False):
        outputs = model(**inputs)
        loss = loss_fn(outputs.logits, inputs.get("labels"))

        return (loss, outputs) if return_outputs else loss


# ------
# some lines in main()
# (imports used below: datasets.Dataset; transformers.LlamaTokenizer, LlamaForCausalLM,
#  TrainingArguments, DataCollatorForSeq2Seq; peft.LoraConfig, get_peft_model,
#  prepare_model_for_int8_training, get_peft_model_state_dict)

    train_data = Dataset.from_list(torch.load(train_dataset_dir))
    val_data = Dataset.from_list(torch.load(val_dataset_dir))

    tokenizer = LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf")
    tokenizer.pad_token_id = 0  # unk
    tokenizer.bos_token_id = 1
    tokenizer.eos_token_id = 2
    tokenizer.padding_side = "left"
    model = LlamaForCausalLM.from_pretrained("decapoda-research/llama-7b-hf",
                                            torch_dtype=torch.float16,
                                            device_map='auto', 
                                            llm_int8_enable_fp32_cpu_offload=True
                                            )

    model = prepare_model_for_int8_training(model)
    config = LoraConfig(
        r=lora_r,
        lora_alpha=lora_alpha,
        target_modules=lora_target_modules,
        lora_dropout=lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, config)
    model.config.pad_token_id = 0  
    model.config.bos_token_id = 1
    model.config.eos_token_id = 2
    model.print_trainable_parameters()
    model.to(DEVICE)
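
    # Optional sanity check (not in the original script): list which parameters are
    # actually trainable after wrapping with PEFT. Only the LoRA adapters attached
    # to q_proj / v_proj should show up here.
    for n, p in model.named_parameters():
        if p.requires_grad:
            print(n, tuple(p.shape), p.dtype)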


    print("\n#### Training model ...")
    training_args = TrainingArguments(
        per_device_train_batch_size=micro_batch_size,
        gradient_accumulation_steps=batch_size//micro_batch_size,
        warmup_steps=100,
        num_train_epochs=num_epochs,
        # max_steps=max_steps,
        learning_rate=learning_rate,
        weight_decay=weight_decay,
        # bf16=bf16,
        logging_strategy="epoch",
        logging_steps=10,
        optim=optim,
        evaluation_strategy="steps",
        save_strategy="steps",
        eval_steps=eval_save_steps, 
        save_steps=eval_save_steps, 
        output_dir=output_dir,
        load_best_model_at_end=True,
        report_to=["wandb"],
        run_name=wandb_run_name,
    )

    data_collator = DataCollatorForSeq2Seq(
        tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
    )

    trainer = MyTrainer(
        model=model,
        train_dataset=train_data,
        eval_dataset=val_data,
        args=training_args,
        data_collator=data_collator,
    )

    model.config.use_cache = False
    # Patch state_dict() so that checkpoints saved by the Trainer contain only the
    # LoRA (PEFT) weights instead of the full base model.
    old_state_dict = model.state_dict
    model.state_dict = (
        lambda self, *_, **__: get_peft_model_state_dict(
            self, old_state_dict()
        )
    ).__get__(model, type(model))

    model = torch.compile(model)
    trainer.train()
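
A side note on the collator above: DataCollatorForSeq2Seq pads both input_ids and labels per batch, and pads labels with -100 so those positions are ignored by the cross-entropy in loss_fn. A minimal standalone sketch of that behaviour (the GPT-2 tokenizer is used purely as a stand-in for the LLaMA tokenizer):

from transformers import AutoTokenizer, DataCollatorForSeq2Seq

tok = AutoTokenizer.from_pretrained("gpt2")   # stand-in tokenizer for the demo
tok.pad_token = tok.eos_token
tok.padding_side = "left"

collator = DataCollatorForSeq2Seq(tok, pad_to_multiple_of=8, return_tensors="pt", padding=True)
features = [
    {"input_ids": [1, 5, 6, 2], "attention_mask": [1, 1, 1, 1], "labels": [-100, 6, 2]},
    {"input_ids": [1, 7, 2], "attention_mask": [1, 1, 1], "labels": [-100, 2]},
]
batch = collator(features)
print(batch["input_ids"].shape)   # padded up to a multiple of 8 -> torch.Size([2, 8])
print(batch["labels"])            # label padding uses -100, which the loss ignores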

What could be causing this broken training loss? Thanks in advance for any help!

huggingface-transformers loss-function fine-tune llama huggingface-trainer
1 Answer

Have you solved this? I ran into the same problem when doing SFT on LLaMA 7B with LoRA: no matter how small the learning rate is, the loss quickly drops to 0. If you have any ideas, please share them~~~
