Jax 中专门在 CPU 上执行函数

问题描述 投票:0回答:2

我有一个函数,它基本上可以实例化一个巨大的数组并执行其他操作。我在 TPU 上运行我的代码,所以基本上我的内存是有限的。

如何在CPU上专门执行我的函数?

如果我这样做:

y = jax.device_put(my_function(), device=jax.devices("cpu")[0])

我猜

my_function()
首先在TPU上执行,结果放在CPU上,这给了我内存错误。

并在我的代码开头使用

jax.config.update('jax_platform_name', 'cpu')
似乎没有效果。

另请注意,我无法修改

my_function()

谢谢!

python memory cpu tpu jax
2个回答
1
投票

要直接指定应执行函数的设备,请使用

device
jax.jit
参数。例如(使用 GPU 运行时,因为它是我目前可以访问的加速器):

import jax

gpu_device = jax.devices('gpu')[0]
cpu_device = jax.devices('cpu')[0]

def my_function(x):
  return x.sum()

x = jax.numpy.arange(10)

x_gpu = jax.jit(my_function, device=gpu_device)(x)
print(x_gpu.device())
# gpu:0

x_cpu = jax.jit(my_function, device=cpu_device)(x)
print(x_cpu.device())
# TFRT_CPU_0

这也可以通过调用站点周围的

jax.default_device
装饰器来控制:

with jax.default_device(cpu_device):
  print(jax.jit(my_function)(x).device())
  # TFRT_CPU_0

with jax.default_device(gpu_device):
  print(jax.jit(my_function)(x).device())
  # gpu:0

0
投票

我在这里猜测一下。我也无法运行它,所以你可能不得不摆弄它

with jax.default_device(jax.devices("cpu")[0]):
    y = my_function()

请参阅文档此处此处

© www.soinside.com 2019 - 2024. All rights reserved.