我想尝试用任意尺寸的图像来细化模型。 写这样的东西并不难
...
y_pred = torch.cat([model(im) for im in imgs])
loss = loss_fn(y_pred, y)
...
但即使是简单的
model(batch)
也有点受CPU限制,所以这甚至更慢。
我尝试过类似的事情
pool.map(model, imgs)
这给了我错误,告诉我 autograd 不能在进程之间使用。
我简要查看了 pytorch 文档中的 dataparallel 页面并得出结论,dataparallel 不能以这种方式使用。 那么有没有办法呢?
from concurrent.futures import ThreadPoolExecutor
...
model = torch.jit.script(model)
...
pool = ThreadPoolExecutor()
...
y_pred = torch.cat(list(pool.map(model, imgs)))
...
但这并没有我想象的那么快。