我试图了解使用
tqdm
的进度条是如何准确工作的。我有一些代码如下所示:
import torch
import torchvision
print(f"torch version: {torch.__version__}")
print(f"torchvision version: {torchvision.__version__}")
load_data()
manual_transforms = transforms.Compose([])
train_dataloader, test_dataloader, class_names = data_setup.create_dataloaders()
# them within the main function I have placed the train function that exists in the `engine.py` file
def main():
results = engine.train(model=model,
train_dataloader=train_dataloader,
test_dataloader=test_dataloader,
optimizer=optimizer,
loss_fn=loss_fn,
epochs=5,
device=device)
并且
engine.train()
函数包含以下代码for epoch in tqdm(range(epochs)):
然后,对每个批次进行训练以可视化训练进度。每次 tqdm 运行每个步骤时,它还会打印以下语句:
print(f"torch version: {torch.__version__}")
print(f"torchvision version: {torchvision.__version__}")
最后,我的问题是为什么会发生这种情况。主函数如何访问这些全局语句以及如何避免在每个循环中打印所有内容?
您注意到的实际上与
tqdm
无关,而是与 PyTorch 的内部工作原理(特别是 DataLoader
的 num_workers
属性)和 Python 的底层 multiprocessing
框架有关。这是一个应该重现您的问题的最小工作示例:
from contextlib import suppress
from multiprocessing import set_start_method
import torch
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
print("torch version:", torch.__version__)
class DummyData(Dataset):
def __len__(self): return 256
def __getitem__(self, i): return i
def main():
for batch in tqdm(DataLoader(DummyData(), batch_size=16, num_workers=4)):
pass # Do something
if __name__ == "__main__":
# Enforce "spawn" method (e.g. on Linux) for subprocess creation to
# reproduce problem (suppress error for reruns in same interpreter)
with suppress(RuntimeError): set_start_method("spawn")
main()
如果运行这段代码,您应该会看到 PyTorch 版本号被打印了 4 次,弄乱了您的
tqdm
进度条。这个数字与 num_workers
相同并非巧合(您可以通过更改此数字轻松检查)。
发生的情况如下:
num_workers
> 0,则为工作人员启动子流程。set_start_method()
完成了这一操作)。if __name__ == "__main__":
块保护的行。这包括您在脚本顶部的 print()
调用。该行为以及潜在的缓解措施已记录在此处。我想,对你有用的一个是:
将大部分主脚本代码包装在
块中,以确保它不会再次运行if __name__ == '__main__':
所以,要么
print()
调用移至 if __name__ == '__main__':
块的开头,print()
调用移至 main()
函数的开头,或者 print()
呼叫。或者,但这可能不是您想要的,您可以设置
num_workers=0
,这将完全禁用multiprocessing
的底层使用(但这样您也会失去并行化的好处)。请注意,您可能还应该将其他函数调用(例如 load_data()
)移至 if __name__ == '__main__':
块或 main()
函数中,以避免多次意外执行。