我正在尝试修改sklearn源代码。特别是,我正在修改GridSearch源代码,以使评估不同模型配置的单独进程/线程在它们之间共享一个变量。我需要每个线程/进程在运行时读取/更新该变量,以便根据获得的其他线程来修改其执行。更具体地说,我要共享的参数是best,在下面的代码段中:
out = parallel(delayed(_fit_and_score)(clone(base_estimator), X, y, best, self.early,train=train, test=test,parameters=parameters,**fit_and_score_kwargs) for parameters, (train, test) in product(candidate_params, cv.split(X, y, groups)))
注意,_ fit_and_score函数位于单独的模块中。Sklearn利用joblib进行并行化,但是我无法理解如何使用外部模块有效地做到这一点。在joblib文档中提供了以下代码:
>>> shared_set = set()
>>> def collect(x):
... shared_set.add(x)
...
>>> Parallel(n_jobs=2, require='sharedmem')(
... delayed(collect)(i) for i in range(5))
[None, None, None, None, None]
>>> sorted(shared_set)
[0, 1, 2, 3, 4]
但是我无法理解如何使其在我的上下文中运行。您可以在这里找到源代码:
您可以使用python的Manager(https://docs.python.org/3/library/multiprocessing.html#multiprocessing.sharedctypes.multiprocessing.Manager),例如,简单的代码:
from joblib import Parallel, delayed
from multiprocessing import Manager
manager = Manager()
q = manager.Namespace()
q.flag = False
def test(i, q):
#update shared var in 0 process
if i == 0:
q.flag = True
# do nothing for few seconds
for n in range(100000000):
if q.flag == True:
return f'process {i} was updated'
return 'process {i} was not updated'
out = Parallel(n_jobs=4)(delayed(test)(i, q) for i in range(4))
输出:
['process 0 was updated',
'process 1 was updated',
'process 2 was updated',
'process 3 was updated']