tf.data.Dataset to jax.numpy iterator

Votes: 0 · Answers: 2
The method Dataset.as_numpy_iterator() converts the TF tensors to NumPy arrays. However, I wonder whether this is good practice, since NumPy arrays live in CPU memory and that is not what I want for training (I use the GPU). The last idea I found is to recast the arrays manually by calling jnp.array, but that is not really elegant (I am worried about the copy into GPU memory). Does anyone have a better idea?
Quick code to illustrate:
import os
import jax.numpy as jnp
import tensorflow as tf

def generator():
    for _ in range(2):
        yield tf.random.uniform((1, ))

ds = tf.data.Dataset.from_generator(generator, output_types=tf.float32,
                                    output_shapes=tf.TensorShape([1]))

ds1 = ds.take(1).as_numpy_iterator()
ds2 = ds.skip(1)

for i, batch in enumerate(ds1):
    print(type(batch))

for i, batch in enumerate(ds2):
    print(type(jnp.array(batch)))

# returns:

<class 'numpy.ndarray'> # not good
<class 'jaxlib.xla_extension.DeviceArray'> # good but not elegant

python tensorflow numpy-ndarray jax

2 Answers

6 votes

TensorFlow and JAX can both convert arrays to DLPack tensors without copying memory, so you can create a JAX array from a TensorFlow array without copying the underlying data buffer by going through DLPack.
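For reference, a minimal sketch of the one-way conversion (this snippet is added here for illustration and is not from the original answer):

import tensorflow as tf
import jax.dlpack
import jax.numpy as jnp

# TF tensor -> DLPack capsule -> JAX array; the capsule hands over the
# existing buffer instead of copying it.
tf_arr = tf.random.uniform((10,))
capsule = tf.experimental.dlpack.to_dlpack(tf_arr)
jax_arr = jax.dlpack.from_dlpack(capsule)

print(jnp.allclose(tf_arr.numpy(), jax_arr))
# True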

Doing the round trip to JAX lets you compare unsafe_buffer_pointer() to confirm that the arrays point at the same buffer, rather than the buffer being copied along the way:

def tf_to_jax(arr):
  # TF tensor -> DLPack capsule -> JAX array, sharing the underlying buffer.
  return jax.dlpack.from_dlpack(tf.experimental.dlpack.to_dlpack(arr))

def jax_to_tf(arr):
  # JAX array -> DLPack capsule -> TF tensor, sharing the underlying buffer.
  return tf.experimental.dlpack.from_dlpack(jax.dlpack.to_dlpack(arr))

jax_arr = jnp.arange(20.)
tf_arr = jax_to_tf(jax_arr)
jax_arr2 = tf_to_jax(tf_arr)

print(jnp.all(jax_arr == jax_arr2))
# True
print(jax_arr.unsafe_buffer_pointer() == jax_arr2.unsafe_buffer_pointer())
# True
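
As an illustrative follow-up (this wrapper is mine, not part of the answer), the same helper can wrap the dataset from the question's snippet (ds2 below) so each batch reaches JAX without an explicit jnp.array recast:

def jax_iterator(dataset):
  # Reuses tf_to_jax from above: each TF batch is handed to JAX via DLPack.
  for batch in dataset:
    yield tf_to_jax(batch)

for batch in jax_iterator(ds2):
  print(type(batch))
# a JAX device array, without an intermediate NumPy copy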

1 vote

From the Flax ImageNet example:

https://github.com/google/flax/blob/6ae22681ef6f6c00414140c3759e717553bda55bd/examples/imagenet/train.py#L183

import jax
from flax import jax_utils

def prepare_tf_data(xs):
  local_device_count = jax.local_device_count()

  def _prepare(x):
    # TF's private _numpy() gives a zero-copy NumPy view of the tensor,
    # then the batch gets a leading device axis so it can be fed to pmap.
    x = x._numpy()
    return x.reshape((local_device_count, -1) + x.shape[1:])

  return jax.tree_util.tree_map(_prepare, xs)

it = map(prepare_tf_data, ds)
it = jax_utils.prefetch_to_device(it, 2)
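
To make the reshape concrete, here is a small sketch with made-up shapes (one toy batch, not taken from the Flax example):

import tensorflow as tf

# A toy batch: prepare_tf_data adds a leading device axis so pmap can shard it.
batch = {"image": tf.random.uniform((8, 3))}
sharded = prepare_tf_data(batch)
print(sharded["image"].shape)
# (jax.local_device_count(), 8 // jax.local_device_count(), 3)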
