我正在尝试运行这个基于分数的生成建模的简单介绍。代码使用的是
flax.optim
,同时似乎已移至 optax
(https://flax.readthedocs.io/en/latest/guides/converting_and_upgrading/optax_update_guide.html)。
我已经制作了 Colab 代码的副本,其中包含我认为需要进行的更改(我只是不确定需要如何替换
optimizer = flax.jax_utils.replicate(optimizer)
)。
现在,在训练部分,我收到错误
pmap 被要求沿轴 0 映射其参数,这意味着它的等级至少应为 1,但实际上仅为 0(其形状为 ())
在
loss, params, opt_state = train_step_fn(step_rng, x, params, opt_state)
线上。这显然来自“定义损失函数”部分中的return jax.pmap(step_fn, axis_name='device')
。
如何修复此错误?我用谷歌搜索过,但不知道这里出了什么问题。
发生这种情况是因为您将标量参数传递给 pmap 函数。例如:
import jax
func = lambda x: x ** 2
pfunc = jax.pmap(func)
pfunc(1.0)
# ValueError: pmap was requested to map its argument along axis 0, which implies
# that its rank should be at least 1, but is only 0 (its shape is ())
如果你想对标量进行操作,你应该使用该函数而不将其包装在
pmap
:
func(1.0)
# 1.0
或者,如果您想使用
pmap
,您应该对前导维度与设备数量匹配的数组进行操作:
num_devices = len(jax.devices())
x = jax.numpy.arange(num_devices)
pfunc(x)
# Array([ 0, 1, 4, 9, 16, 25, 36, 49], dtype=int32)