使用 Mutiprocessing 在 PyTorch 中进行并行预测

问题描述 投票:0回答:1

我正在尝试对需要处理 GB 数据的大型回归模型进行推理。我尝试并行化预测循环,但它没有按预期工作,这是我尝试运行的代码

def predict_batch(batch):
    
    batch_idx, (inputs) = batch
    
    losses = []
    true = []
    predicted = []

    with torch.no_grad():
        y_output = model(current,calibrated)


        # making predictions here

        loss = loss_fn()
        
        overall_loss = loss.item()
            
        losses.append(overall_loss)

    # taking the outputs of each batch and making it into two list predicted and true


    return losses,true,predicted

有人有我可以使用的样板代码吗?

我厌倦了上面的内容,但它不起作用,任何进程运行都需要很多时间。

这是并行化代码:

from multiprocessing import Pool
import tqdm
import time
model.share_memory()
if __name__ == '__main__':
   with Pool(4) as p:
      r = list(tqdm.tqdm(p.imap(predict_batch,predict_batch_input), total=len(predict_batch_input)))
python deep-learning pytorch multiprocessing
1个回答
0
投票

如果我理解你的问题,你想运行四个子进程,每个子进程都对传递的批次数据进行预测。我所尝试的方法的主要问题是每个子进程将使用完全相同的预测数据集,因此每个子进程将执行完全相同的工作并且(出于所有意图和目的)具有完全相同的输出,使多处理变得毫无意义。

子进程应该使用自己的数据子集。例如,如果您的输入数据是

[0, 1, 2, 3, 4, 5, 6, 7]
,那么您可能希望子进程 0-3 分别具有输入数据
[0, 1]
[2, 3]
[4, 5]
[6, 7]
,并传递以下每一项这些到各自的流程。否则,按照您当前的方式,每个子进程都会收到
[0, 1, 2, 3, 4, 5, 6, 7]
并对其执行相同的工作。

from torch.multiprocessing import Pool, Manager


# pass shared queue to function used by child processes to store data
def predict_batch(batch, queue):
    ...
    with torch.no_grad():
        ...

    queue.put((losses, true, predicted))

    return losses,true,predicted

if __name__ == "__main__":
    pool = Pool(4)
    q = Manager().Queue()

    # ideally predict_batch_input would be something that is split between the processes and not the same passed variable for all n processes
    batch_inputs = [predict_batch_input0, predict_batch_input1, predict_batch_input2, predict_batch_input3]

    r = [pool.apply_async(predict_batch, args=(i, q)), len(i) for i in batch_inputs]
    
    # make parent process wait for child processes to finish
    pool.close()
    pool.join()

    # get data from queue
    results = [(i, j, k) for i, j, k in [q.get() for _ in range(q.qsize())]]

使用 Python 中的多处理,除非您调试或手动编写异常处理代码,否则子进程中发生的错误将终止该进程,而不显示错误消息。因此,如果您看到子进程中没有发生某些事情,但父进程已完成,则子进程中的某个地方很可能存在错误。

我不确定这是否是您正在寻找的样板,但这与我使用 Pytorch 进行多重处理的方式非常相似。

© www.soinside.com 2019 - 2024. All rights reserved.