类型错误:不可哈希类型:尝试将 Equinox 模块与 jax.lax.scan 一起使用时出现“ArrayImpl”

问题描述 投票:0回答:1

我是 Equinox 和 JAX 的新手,但想用它们来模拟动态系统。

但是当我将系统模型作为 Equinox 模块传递到 jax.lax.scan 时,我在标题中收到了不可散列的类型错误。 我知道 jax 期望函数参数是一个纯函数,但我认为 Equinox 模块会模拟这一点。

这是一个重现错误的测试脚本

import equinox as eqx
import jax
import jax.numpy as jnp


class EqxModel(eqx.Module):
    A: jax.Array
    B: jax.Array
    C: jax.Array
    D: jax.Array

    def __call__(self, states, inputs):
        x = states.reshape(-1, 1)
        u = inputs.reshape(-1, 1)
        x_next = self.A @ x + self.B @ u
        y = self.C @ x + self.D @ u
        return x_next.reshape(-1), y.reshape(-1)


def simulate(model, inputs, x0):
    xk = x0
    outputs = []
    for uk in inputs:
        xk, yk = model(xk, uk)
        outputs.append(yk)
    outputs = jnp.stack(outputs)
    return xk, outputs


A = jnp.array([[0.7, 1.0], [0.0, 1.0]])
B = jnp.array([[0.0], [1.0]])
C = jnp.array([[0.3, 0.0]])
D = jnp.array([[0.0]])
model = EqxModel(A, B, C, D)

# Test simulation
inputs = jnp.array([[0.0], [1.0], [1.0], [1.0]])
x0 = jnp.zeros(2)
xk, outputs = simulate(model, inputs, x0)
assert jnp.allclose(xk, jnp.array([2.7, 3.0]))
assert jnp.allclose(outputs, jnp.array([[0.0], [0.0], [0.0], [0.3]]))

# This raises TypeError
xk, outputs = jax.lax.scan(model, x0, inputs)

unhashable type: 'ArrayImpl'
指的是什么? 是数组A、B、C、D吗? 在此模型中,这些矩阵是参数,因此在模拟期间应该是静态的。

我刚刚发现这个可能相关的问题线程:

python jax equinox computation-graph
1个回答
0
投票

Owen Lockwood (lockwo) 在本期主题中提供了解释和答案,我将在下面重申。

我相信您的问题正在发生,因为 jax 尝试对您正在扫描的函数进行哈希处理,但它无法对模块中的数组进行哈希处理。您可能可以做很多事情来解决这个问题,最简单的就是柯里化模型,例如

xk, outputs = jax.lax.scan(lambda carry, y: model(carry, y), x0, inputs)

工作正常

或者,根据我正在使用的变量名称重写:

xk, outputs = jax.lax.scan(lambda xk, uk: model(xk, uk), x0, inputs)
    
© www.soinside.com 2019 - 2024. All rights reserved.