Starmap 与 tqdm 结合?

问题描述 投票:0回答:4

我正在做一些并行处理,如下:

with mp.Pool(8) as tmpPool:
        results = tmpPool.starmap(my_function, inputs)

输入如下: [(1,0.2312),(5,0.52) ...] 即 int 和 float 的元组。

代码运行良好,但我似乎无法将其包装在加载栏(tqdm)上,例如可以使用 imap 方法来完成,如下所示:

tqdm.tqdm(mp.imap(some_function,some_inputs))

星图也可以这样做吗?

谢谢!

python multiprocessing python-multiprocessing tqdm process-pool
4个回答
57
投票

最简单的方法可能是在输入周围应用 tqdm(),而不是映射函数。例如:

inputs = zip(param1, param2, param3)
with mp.Pool(8) as pool:
    results = pool.starmap(my_function, tqdm.tqdm(inputs, total=len(param1)))

请注意,当调用

my_function
时,而不是返回时,该栏会更新。如果这种区别很重要,您可以考虑按照其他一些答案的建议重写星图。否则,这是一个简单而有效的替代方案。


50
投票

使用

starmap()
是不可能的,但通过添加
Pool.istarmap()
的补丁是可能的。它基于
imap()
的代码。您所要做的就是创建
istarmap.py
文件并导入模块以应用补丁,然后再进行常规多处理导入。

Python <3.8

# istarmap.py for Python <3.8
import multiprocessing.pool as mpp


def istarmap(self, func, iterable, chunksize=1):
    """starmap-version of imap
    """
    if self._state != mpp.RUN:
        raise ValueError("Pool not running")

    if chunksize < 1:
        raise ValueError(
            "Chunksize must be 1+, not {0:n}".format(
                chunksize))

    task_batches = mpp.Pool._get_tasks(func, iterable, chunksize)
    result = mpp.IMapIterator(self._cache)
    self._taskqueue.put(
        (
            self._guarded_task_generation(result._job,
                                          mpp.starmapstar,
                                          task_batches),
            result._set_length
        ))
    return (item for chunk in result for item in chunk)


mpp.Pool.istarmap = istarmap

Python 3.8+

# istarmap.py for Python 3.8+
import multiprocessing.pool as mpp


def istarmap(self, func, iterable, chunksize=1):
    """starmap-version of imap
    """
    self._check_running()
    if chunksize < 1:
        raise ValueError(
            "Chunksize must be 1+, not {0:n}".format(
                chunksize))

    task_batches = mpp.Pool._get_tasks(func, iterable, chunksize)
    result = mpp.IMapIterator(self)
    self._taskqueue.put(
        (
            self._guarded_task_generation(result._job,
                                          mpp.starmapstar,
                                          task_batches),
            result._set_length
        ))
    return (item for chunk in result for item in chunk)


mpp.Pool.istarmap = istarmap

然后在你的脚本中:

import istarmap  # import to apply patch
from multiprocessing import Pool
import tqdm    


def foo(a, b):
    for _ in range(int(50e6)):
        pass
    return a, b    


if __name__ == '__main__':

    with Pool(4) as pool:
        iterable = [(i, 'x') for i in range(10)]
        for _ in tqdm.tqdm(pool.istarmap(foo, iterable),
                           total=len(iterable)):
            pass

21
投票

正如 Darkonaut 提到的,在撰写本文时,本身还没有

istarmap
可用。如果您想避免修补,可以添加一个简单的 *
_star
函数作为解决方法。 (此解决方案的灵感来自于本教程。

import tqdm
import multiprocessing

def my_function(arg1, arg2, arg3):
  return arg1 + arg2 + arg3

def my_function_star(args):
    return my_function(*args)

jobs = 4
with multiprocessing.Pool(jobs) as pool:
    args = [(i, i, i) for i in range(10000)]
    results = list(tqdm.tqdm(pool.imap(my_function_star, args), total=len(args))

一些注意事项:

我也很喜欢科里的回答。它更干净,尽管进度条的更新似乎不像我的答案那么顺利。请注意,使用我上面使用

chunksize=1
(默认)发布的代码,corey 的答案要快几个数量级。我猜测这是由于多处理序列化造成的,因为增加
chunksize
(或具有更昂贵的
my_function
)使它们的运行时间具有可比性。

我对我的应用程序给出了答案,因为我的序列化/功能成本比非常低。


0
投票
from tqdm import tqdm

...other code...


    pool = Pool(num_workers)
    tasks = []
    for task_idx, (prompt, ...) in enumerate(zip(prompts, ...)):
        task_args = (prompt, ..., api_key, model, max_tokens, task_idx) 
        tasks.append(task_args)
    print(f'{len(tasks)=}')
    # results = pool.starmap(call_to_client_api_with_retry, tasks)
    with tqdm(total=len(tasks)) as progress_bar:
        results = pool.starmap(call_to_client_api_with_retry, tqdm(tasks, total=len(tasks)))
    pool.close()
    pool.join()
    return results 

为我工作!

© www.soinside.com 2019 - 2024. All rights reserved.