我是 Equinox 和 JAX 的新手,但想用它们来模拟动态系统。
但是当我将系统模型作为 Equinox 模块传递到 jax.lax.scan 时,我在标题中收到了不可散列的类型错误。 我知道 jax 期望函数参数是一个纯函数,但我认为 Equinox 模块会模拟这一点。
这是一个重现错误的测试脚本
import equinox as eqx
import jax
import jax.numpy as jnp
class EqxModel(eqx.Module):
A: jax.Array
B: jax.Array
C: jax.Array
D: jax.Array
def __call__(self, states, inputs):
x = states.reshape(-1, 1)
u = inputs.reshape(-1, 1)
x_next = self.A @ x + self.B @ u
y = self.C @ x + self.D @ u
return x_next.reshape(-1), y.reshape(-1)
def simulate(model, inputs, x0):
xk = x0
outputs = []
for uk in inputs:
xk, yk = model(xk, uk)
outputs.append(yk)
outputs = jnp.stack(outputs)
return xk, outputs
A = jnp.array([[0.7, 1.0], [0.0, 1.0]])
B = jnp.array([[0.0], [1.0]])
C = jnp.array([[0.3, 0.0]])
D = jnp.array([[0.0]])
model = EqxModel(A, B, C, D)
# Test simulation
inputs = jnp.array([[0.0], [1.0], [1.0], [1.0]])
x0 = jnp.zeros(2)
xk, outputs = simulate(model, inputs, x0)
assert jnp.allclose(xk, jnp.array([2.7, 3.0]))
assert jnp.allclose(outputs, jnp.array([[0.0], [0.0], [0.0], [0.3]]))
# This raises TypeError
xk, outputs = jax.lax.scan(model, x0, inputs)
unhashable type: 'ArrayImpl'
指的是什么? 是数组A、B、C、D吗? 在此模型中,这些矩阵是参数,因此在模拟期间应该是静态的。
我刚刚发现这个可能相关的问题线程:
Owen Lockwood (lockwo) 在本期主题中提供了解释和答案,我将在下面重申。
我相信您的问题正在发生,因为 jax 尝试对您正在扫描的函数进行哈希处理,但它无法对模块中的数组进行哈希处理。您可能可以做很多事情来解决这个问题,最简单的就是柯里化模型,例如或者,根据我正在使用的变量名称重写:
xk, outputs = jax.lax.scan(lambda carry, y: model(carry, y), x0, inputs)
工作正常
xk, outputs = jax.lax.scan(lambda xk, uk: model(xk, uk), x0, inputs)