在 pytorch 中使用多处理进行训练

问题描述 投票:0回答:1

我想尝试用任意尺寸的图像来细化模型。 写这样的东西并不难

...
y_pred = torch.cat([model(im) for im in imgs])
loss = loss_fn(y_pred, y)
...

但即使是简单的

model(batch)
也有点受CPU限制,所以这甚至更慢。

我尝试过类似的事情

pool.map(model, imgs)

这给了我错误,告诉我 autograd 不能在进程之间使用。

我简要查看了 pytorch 文档中的 dataparallel 页面并得出结论,dataparallel 不能以这种方式使用。 那么有没有办法呢?

python deep-learning pytorch python-multiprocessing dataparallel
1个回答
0
投票
from concurrent.futures import ThreadPoolExecutor
...
model = torch.jit.script(model)
...
pool = ThreadPoolExecutor()
...
y_pred = torch.cat(list(pool.map(model, imgs)))
...

但这并没有我想象的那么快。

© www.soinside.com 2019 - 2024. All rights reserved.