在方法上使用 Jax Jit 作为装饰器与直接应用 jit 函数

问题描述 投票:0回答:1

我想大多数熟悉 jax 的人都在文档中看过这个示例并且知道它不起作用:

import jax.numpy as jnp from jax import jit class CustomClass: def __init__(self, x: jnp.ndarray, mul: bool): self.x = x self.mul = mul @jit # <---- How to do this correctly? def calc(self, y): if self.mul: return self.x * y return y c = CustomClass(2, True) c.calc(3)
提到了 3 个解决方法,但似乎直接将 jit 作为函数应用,而不是装饰器也可以很好地工作。也就是说,JAX 不会抱怨不知道如何处理 

CustomClass

 类型的 
self
:

import jax.numpy as jnp from jax import jit class CustomClass: def __init__(self, x: jnp.ndarray, mul: bool): self.x = x self.mul = mul # No decorator here ! def calc(self, y): if self.mul: return self.x * y return y c = CustomClass(2, True) jitted_calc = jit(c.calc) print(jitted_calc(3))
6 # works fine!
虽然没有记录(也许应该是?),但这似乎与通过 

@partial(jax.jit, static_argnums=0)

 将 self 标记为静态相同,因为更改 
self
 对后续调用没有任何作用,即:

c = CustomClass(2, True) jitted_calc = jit(c.calc) print(jitted_calc(3)) c.mul = False print(jitted_calc(3))
6
6 # no update
所以我最初假设装饰器在直接应用它们时通常可能只是将 self 作为静态参数处理。因为该方法可能会保存到具有特定 self 实例的另一个变量中。情况似乎并非如此,因为下面的“装饰”函数很高兴地处理对 self 的更改:

def decorator(func): def wrapper(*args, **kwargs): x = func(*args, **kwargs) return x return wrapper custom = CustomClass(2, True) decorated_calc = decorator(custom.calc) print(decorated_calc(3)) custom.mul = False print(decorated_calc(3))
6
3
我看到了一些关于直接将装饰器应用为函数与装饰器样式的其他问题(例如

herehere),并且提到两个版本之间存在细微差别,但这几乎不重要。 我想知道 jit 装饰器是什么让这些版本的行为如此不同,因为 JAX.jit 可以处理 self

 类型,如果不是装饰风格的话。如果有人有答案,将不胜感激。

python python-decorators jax
1个回答
0
投票
装饰器与静态参数无关:静态参数是

jax.jit

特有的概念。

在此代码片段中,您使用 JIT 包装了绑定方法:

c = CustomClass(2, True) jitted_calc = jit(c.calc)
当您更新 

jitted_calc

 的属性时 
c
 不更新的原因是因为 JIT 缓存其已编译的代码工件,并且该缓存基于多个量,包括 (1) 正在编译的 ID 函数,(2)任何数组参数的静态属性(例如 
shape
dtype
),以及 (3) 标有 
static_argnums
 的任何静态参数的哈希值。

当您稍后就地修改

c

 时,它不会更改 (1) 或 (2),并且 (3) 不适用。因此,先前缓存的编译工件(具有先前的 
mul
 值)将再次执行。这是我没有在您链接到的文档中提到此策略的主要原因:这很少是用户想要的行为。

在 JIT 中包装绑定方法与用

@partial(jit, static_argnums=0)

 包装方法定义非常相似,但机制不同:在 
static_argnums
 版本中,类实例的哈希值用作 JIT 缓存键的一部分,但由于您尚未定义对 
__hash__
 的值敏感的 
__eq__
self.mul
 方法,因此更改该属性不会更改哈希值,因此将使用之前缓存的计算。您可以在链接到的文档中的
策略 2 下查看如何纠正此问题的示例。

在最后一个示例中,(使用

decorator

wrapper
)您根本没有使用 
jax.jit
,因此 JIT 缓存永远不会发挥作用。这表明您所看到的行为与装饰器语法无关,而是与 JIT 缓存的机制有关。

我希望一切都清楚了!

© www.soinside.com 2019 - 2024. All rights reserved.