我想问一个关于配置中数据类型的问题。 我的笔记本里有代码:
train_config = {
'opt_cls': torch.optim.Adam,
'loss_fn': nn.CrossEntropyLoss(),
'metric_fn': f1_score,
'max_epochs': 10,
'opt_params': {
'lr': 5e-4
},
'scheduler_cls': torch.optim.lr_scheduler.StepLR,
'scheduler_params': {
'step_size': 4,
'gamma': 0.55
},
'device': device
}
所以现在我想将我的代码从笔记本分发到 .py 文件。按照我的想法,写是合乎逻辑的
配置文件中的类似上面的内容,如 .yaml、.json。那么,我怎样才能轻松读取此配置而不是读取字符串,而是读取类型或函数?
我在 pyyaml 网站上找到了解决方案。但它需要制作继承自 yaml.YAMLObject
的类但这似乎很糟糕,因为如果我决定使用,例如,我自己的损失函数,我会写很多代码来启动我的实验。