我想用 PyTorch 构建一个联邦学习系统。我目前已经制定并编写了以下代码:
我想用 MapReduce 并行化平均步骤。在我看来,map步骤可以将所有权重除以模型总数,reduce步骤可以将所有得到的权重相加得到一个平均模型。
我想使用 PySpark 执行此操作,但是我无法弄清楚如何在 Spark 脚本中使用 PyTorch 加载模型文件(通常保存为 Pickle 文件)。我正在考虑将模型位置作为具有相同键的值传递。我在网上找到的所有 Spark 示例都只适用于文本文件,所以我很感激对此提供一些帮助。