jax没有任何用于数组或计算的内置缓存机制。我不确定您提出的修改将如何减少获取输出所需的计算量。对于实际的缓存,您需要从“外部”中通过以下路线传递“外部”计算。
def __call__(self, x, theta, x_cache=None, train: bool = False):
if x_cache is None:
x = self.compute_x(x)
else:
x = x_cache
x_concat = jnp.concatenate([x, theta_projected], axis=-1)
...
return self.output_layers(x_concat), x
result, x_cache = model(x, theta)
# and later
result, x_cache = model(x, theta, x_cache)
flax