我想大多数熟悉 jax 的人都在文档中看过这个示例并且知道它不起作用:
import jax.numpy as jnp
from jax import jit
class CustomClass:
def __init__(self, x: jnp.ndarray, mul: bool):
self.x = x
self.mul = mul
@jit # <---- How to do this correctly?
def calc(self, y):
if self.mul:
return self.x * y
return y
c = CustomClass(2, True)
c.calc(3)
提到了 3 个解决方法,但似乎直接将 jit 作为函数应用,而不是装饰器也可以很好地工作。也就是说,JAX 不会抱怨不知道如何处理 CustomClass
类型的
self
:
import jax.numpy as jnp
from jax import jit
class CustomClass:
def __init__(self, x: jnp.ndarray, mul: bool):
self.x = x
self.mul = mul
# No decorator here !
def calc(self, y):
if self.mul:
return self.x * y
return y
c = CustomClass(2, True)
jitted_calc = jit(c.calc)
print(jitted_calc(3))
6 # works fine!
虽然没有记录(也许应该是?),但这似乎与通过 @partial(jax.jit, static_argnums=0)
将 self 标记为静态相同,因为更改
self
对后续调用没有任何作用,即:
c = CustomClass(2, True)
jitted_calc = jit(c.calc)
print(jitted_calc(3))
c.mul = False
print(jitted_calc(3))
6
6 # no update
所以我最初假设装饰器在直接应用它们时通常可能只是将 self 作为静态参数处理。因为该方法可能会保存到具有特定 self 实例的另一个变量中。情况似乎并非如此,因为下面的“装饰”函数很高兴地处理对 self 的更改:
def decorator(func):
def wrapper(*args, **kwargs):
x = func(*args, **kwargs)
return x
return wrapper
custom = CustomClass(2, True)
decorated_calc = decorator(custom.calc)
print(decorated_calc(3))
custom.mul = False
print(decorated_calc(3))
6
3
我看到了一些关于直接将装饰器应用为函数与装饰器样式的其他问题(例如here和here),并且提到两个版本之间存在细微差别,但这几乎不重要。
我想知道 jit 装饰器是什么让这些版本的行为如此不同,因为 JAX.jit 可以处理 self
类型,如果不是装饰风格的话。如果有人有答案,将不胜感激。
jax.jit
特有的概念。在此代码片段中,您使用 JIT 包装了绑定方法:
c = CustomClass(2, True)
jitted_calc = jit(c.calc)
当您更新 jitted_calc
的属性时
c
不更新的原因是因为 JIT 缓存其已编译的代码工件,并且该缓存基于多个量,包括 (1) 正在编译的 ID 函数,(2)任何数组参数的静态属性(例如
shape
和
dtype
),以及 (3) 标有
static_argnums
的任何静态参数的哈希值。当您稍后就地修改
c
时,它不会更改 (1) 或 (2),并且 (3) 不适用。因此,先前缓存的编译工件(具有先前的
mul
值)将再次执行。这是我没有在您链接到的文档中提到此策略的主要原因:这很少是用户想要的行为。在 JIT 中包装绑定方法与用
@partial(jit, static_argnums=0)
包装方法定义非常相似,但机制不同:在
static_argnums
版本中,类实例的哈希值用作 JIT 缓存键的一部分,但由于您尚未定义对
__hash__
的值敏感的
__eq__
和
self.mul
方法,因此更改该属性不会更改哈希值,因此将使用之前缓存的计算。您可以在链接到的文档中的策略 2 下查看如何纠正此问题的示例。 在最后一个示例中,(使用
decorator
和
wrapper
)您根本没有使用
jax.jit
,因此 JIT 缓存永远不会发挥作用。这表明您所看到的行为与装饰器语法无关,而是与 JIT 缓存的机制有关。我希望一切都清楚了!