PyTorch Lightning bug: validation after a single gradient step?


I am using PyTorch Lightning 2.0.1. The pseudocode from the documentation is shown below.

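Current: Given that

if should_check_val: val_loop()
sits inside the
for batch in train_dataloader():
loop, validation over all validation batches appears to happen after training on a single training batch.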

Expected: Validation over all validation batches happens after training on all training batches. That is,

if should_check_val: val_loop()
should be outside the
for batch in train_dataloader():
loop.

def train_on_device(model):

    setup("fit")
    configure_optimizers()
    on_fit_start()

    # the sanity check runs here

    on_train_start()
    for epoch in epochs:
        fit_loop() #! CALLS Method Below
    on_train_end()

    on_fit_end()
    teardown("fit")


def fit_loop():
    model.train()
    torch.set_grad_enabled(True)

    on_train_epoch_start()

    for batch in train_dataloader(): # TRAINING Loop
        on_train_batch_start()

        on_before_batch_transfer()
        transfer_batch_to_device()
        on_after_batch_transfer()

        out = training_step()

        on_before_zero_grad()
        optimizer_zero_grad()

        on_before_backward()
        backward()
        on_after_backward()

        on_before_optimizer_step()
        configure_gradient_clipping()
        optimizer_step()

        on_train_batch_end(out, batch, batch_idx)

        if should_check_val: #! 
            val_loop() #! VALIDATION Loop

    on_train_epoch_end()
deep-learning pytorch-lightning
1 Answer

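By default, the validation loop starts only after the training loop has gone through all training batches of the epoch. The check sits inside the batch loop so that validation can also be triggered mid-epoch, but with the default setting (val_check_interval=1.0) should_check_val is true only after the last training batch. You can configure this with the val_check_interval flag of the Trainer.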
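To illustrate why placing the check inside the batch loop does not mean validating after every batch, here is a minimal pure-Python sketch. It is not Lightning's actual implementation: should_check_val and run_epoch below are hypothetical stand-ins, and the sketch assumes val_check_interval is a float fraction of the epoch. The one real name it refers to is the val_check_interval argument of Trainer.

```python
def should_check_val(batch_idx: int, num_batches: int,
                     val_check_interval: float = 1.0) -> bool:
    """Hypothetical, simplified stand-in for Lightning's internal check."""
    # Number of training batches between two validation runs.
    every_n = max(1, int(num_batches * val_check_interval))
    return (batch_idx + 1) % every_n == 0


def run_epoch(num_batches: int, val_check_interval: float = 1.0) -> list:
    """Mimic the shape of fit_loop(): the check runs after every batch."""
    events = []
    for batch_idx in range(num_batches):
        events.append(f"train {batch_idx}")
        # Evaluated on every batch, but true only at the configured points.
        if should_check_val(batch_idx, num_batches, val_check_interval):
            events.append("validate")
    return events


# Default (comparable to Trainer(val_check_interval=1.0)):
# validation runs once, after the last training batch.
print(run_epoch(4))
# ['train 0', 'train 1', 'train 2', 'train 3', 'validate']

# Comparable to Trainer(val_check_interval=0.5): validation twice per epoch.
print(run_epoch(4, val_check_interval=0.5))
# ['train 0', 'train 1', 'validate', 'train 2', 'train 3', 'validate']
```

So the position of the check in the pseudocode is what makes mid-epoch validation possible; whether it fires on a given batch depends on the configured interval.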