我有一个函数,它基本上可以实例化一个巨大的数组并执行其他操作。我在 TPU 上运行我的代码,所以基本上我的内存是有限的。
如何在CPU上专门执行我的函数?
如果我这样做:
y = jax.device_put(my_function(), device=jax.devices("cpu")[0])
我猜
my_function()
首先在TPU上执行,结果放在CPU上,这给了我内存错误。
并在我的代码开头使用
jax.config.update('jax_platform_name', 'cpu')
似乎没有效果。
另请注意,我无法修改
my_function()
谢谢!
要直接指定应执行函数的设备,请使用
device
的 jax.jit
参数。例如(使用 GPU 运行时,因为它是我目前可以访问的加速器):
import jax
gpu_device = jax.devices('gpu')[0]
cpu_device = jax.devices('cpu')[0]
def my_function(x):
return x.sum()
x = jax.numpy.arange(10)
x_gpu = jax.jit(my_function, device=gpu_device)(x)
print(x_gpu.device())
# gpu:0
x_cpu = jax.jit(my_function, device=cpu_device)(x)
print(x_cpu.device())
# TFRT_CPU_0
这也可以通过调用站点周围的
jax.default_device
装饰器来控制:
with jax.default_device(cpu_device):
print(jax.jit(my_function)(x).device())
# TFRT_CPU_0
with jax.default_device(gpu_device):
print(jax.jit(my_function)(x).device())
# gpu:0