JIT:部分或带有静态参数?不可哈希输入,但可哈希部分

问题描述 投票:0回答:1

我有点迷失到底发生了什么以及该选择什么选项。让我们看一个例子:

import jax
from functools import partial
from typing import List

def dummy(a: int, b: List[str]):
    return a + 1

由于

b
参数是可变的,使用静态参数名称进行抖动将会失败:

j_dummy = jax.jit(dummy, static_argnames=['b'])
j_dummy(2, ['kek'])
ValueError: Non-hashable static arguments are not supported

但是,如果我们这样做

partial
jp_dummy = jax.jit(partial(dummy, b=['kek']))
,我们就瞄准了目标。不知何故,
partial
对象确实有
__hash__
方法,所以我们可以用
hash(partial(dummy, b=['kek']))
来检查它。

所以,我在这里有点迷失:我应该如何从更大的角度进行?我应该使用任何参数生成部分函数,然后将它们合并,还是应该尝试保持我的参数可散列?在什么情况下一种方法优于另一种方法?有什么缺点吗?

python jit jax
1个回答
0
投票

当您使用

static_argnames
时,传递给函数的静态值将成为缓存键的一部分,因此如果值发生更改,函数将重新编译:

import jax
import jax.numpy as jnp

def f(x, s):
  return x * len(s)

f_jit = jax.jit(f, static_argnames=['s'])

print(f_jit(2, "abc"))  # 6
print(f_jit(2, "abcd"))  # 8

这就是为什么静态参数必须是可散列的:它们的散列用作 JIT 缓存键。

另一方面,当您通过闭包包装静态参数时,它的值不会影响缓存键,因此它不需要是可散列的。另一方面,由于它不是缓存键的一部分,如果全局值发生变化,它不会触发重新编译,因此您可能会得到意想不到的结果:

f_closure = jax.jit(lambda x: f(x, s))

s = "abc"
print(f_closure(2))  # 6
s = "abcd"
print(f_closure(2))  # 6

因此,显式静态参数可能更安全。在您的情况下,最好将列表更改为元组,因为元组是可散列的并且可以用作显式静态参数。

© www.soinside.com 2019 - 2024. All rights reserved.