PyTorch Lightning 2.0.1. The documentation's pseudocode is shown below.
Current behavior: since
if should_check_val: val_loop()
sits inside the
for batch in train_dataloader():
loop, validation over all validation batches runs after training on a single training batch.
Expected behavior: validation over all validation batches should run after training on all training batches, i.e.
if should_check_val: val_loop()
should sit outside the
for batch in train_dataloader():
loop.
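To make the two placements concrete, here is a minimal, hypothetical sketch (not Lightning's actual code; the function names and the `check_every` parameter are invented for illustration) that records the order of events under each placement of the validation check:

```python
def fit_loop_check_per_batch(n_batches, check_every):
    """Validation check inside the batch loop (what the docs show)."""
    events = []
    for i in range(n_batches):
        events.append(f"train:{i}")
        if (i + 1) % check_every == 0:  # plays the role of should_check_val
            events.append("val_loop")
    return events

def fit_loop_check_per_epoch(n_batches):
    """Validation check after the batch loop (the expectation above)."""
    events = [f"train:{i}" for i in range(n_batches)]
    events.append("val_loop")
    return events
```

Note that when the check fires only on the last batch of the epoch, the in-loop placement produces the same event order as the post-loop placement, so the docs' version is the more general of the two.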
def train_on_device(model):
    setup("fit")
    configure_optimizers()
    on_fit_start()
    # the sanity check runs here
    on_train_start()
    for epoch in epochs:
        fit_loop()  #! calls the method below
    on_train_end()
    on_fit_end()
    teardown("fit")
def fit_loop():
    model.train()
    torch.set_grad_enabled(True)

    on_train_epoch_start()
    for batch in train_dataloader():  # TRAINING loop
        on_train_batch_start()

        on_before_batch_transfer()
        transfer_batch_to_device()
        on_after_batch_transfer()

        out = training_step()

        on_before_zero_grad()
        optimizer_zero_grad()

        on_before_backward()
        backward()
        on_after_backward()

        on_before_optimizer_step()
        configure_gradient_clipping()
        optimizer_step()

        on_train_batch_end(out, batch, batch_idx)

        if should_check_val:  #!
            val_loop()  #! VALIDATION loop
    on_train_epoch_end()
By default, the validation loop runs after the training loop; you can configure this with the val_check_interval flag.
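A simplified model of how such a flag could drive the should_check_val decision (a hypothetical sketch, not Lightning's implementation; it assumes val_check_interval is a float fraction of an epoch, whereas the real flag also accepts an int batch count):

```python
def should_check_val(batch_idx, num_batches, val_check_interval=1.0):
    """Decide whether to run val_loop() after this training batch.

    val_check_interval: fraction of an epoch between validation checks
    (assumed here; the int-batch-count variant is omitted for brevity).
    """
    interval = max(1, int(num_batches * val_check_interval))
    return (batch_idx + 1) % interval == 0
```

With the default of 1.0, the check fires only on the last batch of the epoch, so validation effectively runs after all training batches even though the check itself lives inside the batch loop.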