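I'm trying to set up an image pipeline: load images from disk into a dataset, rescale them, generate patches, and train. Loading from disk is slow, so I want to create random data that simulates real images. I wrote a generator class that creates a random numpy array, converts it to a tensor, and returns it. I then use Dataset.from_generator, but I get an error when I try to iterate over the data. The problem first showed up when passing the dataset to model.fit, but I found I can trigger it faster with the following: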
import numpy as np
import tensorflow as tf

# batch size for dataset
full_image_batch_size = 1
batch_size = 1
image_height = 3000
image_width = 5328
image_channels = 3
tile_size = 256


class FakeImageGenerator:
    def __init__(self, number_of_images):
        self.generated = 0
        self.number_of_images = number_of_images

    def __len__(self):
        return self.number_of_images

    def __getitem__(self, index):
        arr = np.random.rand(image_height, image_width, image_channels) * 255.0
        arr = tf.convert_to_tensor(arr, dtype=tf.float32)
        return (arr, )

    def __call__(self):
        return self.__getitem__(0)


x_image_generator = FakeImageGenerator(full_image_batch_size)
x_train = tf.data.Dataset.from_generator(
    x_image_generator,
    output_signature=(
        tf.TensorSpec(shape=(image_height, image_width, image_channels), dtype=tf.float32),
    ))

for elem in iter(x_train):
    print(elem)
W tensorflow/core/framework/op_kernel.cc:1829] INVALID_ARGUMENT: TypeError: `generator` yielded an element that did not match the expected structure. The expected structure was (tf.float32,), but the yielded element was [[[ 59.90327 5.024378 118.99237 ]...
[178.26884 200.46623 15.418176]]].
Traceback (most recent call last):
File "/root/.virtualenvs/ga-python/lib/python3.12/site-packages/tensorflow/python/data/ops/from_generator_op.py", line 204, in generator_py_func
flattened_values = nest.flatten_up_to(output_types, values)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/.virtualenvs/ga-python/lib/python3.12/site-packages/tensorflow/python/data/util/nest.py", line 237, in flatten_up_to
return nest_util.flatten_up_to(
^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/.virtualenvs/ga-python/lib/python3.12/site-packages/tensorflow/python/util/nest_util.py", line 1541, in flatten_up_to
return _tf_data_flatten_up_to(shallow_tree, input_tree)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/.virtualenvs/ga-python/lib/python3.12/site-packages/tensorflow/python/util/nest_util.py", line 1570, in _tf_data_flatten_up_to
_tf_data_assert_shallow_structure(shallow_tree, input_tree)
File "/root/.virtualenvs/ga-python/lib/python3.12/site-packages/tensorflow/python/util/nest_util.py", line 1414, in _tf_data_assert_shallow_structure
raise TypeError(
TypeError: If shallow structure is a sequence, input must also be a sequence. Input has type: 'EagerTensor'.
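Note: in the stack trace, "the yielded element was ..." prints the entire numpy array, but I have included only a few slices as a representative sample.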
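Your output_signature specifies that the generator should yield tuples, as the error says:
The expected structure was (tf.float32,)
Dataset.from_generator calls the object you pass in and then iterates over whatever that call returns. Here __call__ returns the 1-tuple (arr, ), so iterating over it yields the bare tensor rather than a tuple, which does not match the signature.
Remove the tuple from output_signature and it should work: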
x_train = tf.data.Dataset.from_generator(
    x_image_generator,
    output_signature=tf.TensorSpec(shape=(image_height, image_width, image_channels), dtype=tf.float32)
)
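
If you would rather keep the tuple in output_signature (for example because you plan to add labels later), the other direction should also work: make the callable return something that yields one 1-tuple per image. The following is only a minimal sketch under some assumptions: fake_images is a hypothetical helper that is not part of the original code, and the image dimensions are shrunk so that it runs quickly.

```python
import numpy as np
import tensorflow as tf

# Hypothetical small dimensions, chosen only to keep the sketch fast.
image_height, image_width, image_channels = 4, 6, 3


def fake_images(number_of_images):
    """Hypothetical helper: yield one 1-tuple per fake image."""
    for _ in range(number_of_images):
        arr = np.random.rand(image_height, image_width, image_channels) * 255.0
        # Each yielded element is a tuple, matching the tuple-shaped signature.
        yield (arr.astype(np.float32),)


dataset = tf.data.Dataset.from_generator(
    lambda: fake_images(3),  # from_generator needs a callable returning an iterable
    output_signature=(
        tf.TensorSpec(shape=(image_height, image_width, image_channels), dtype=tf.float32),
    ),
)

# Every dataset element is now a 1-tuple containing a float32 tensor.
elements = list(dataset)
```

Note that with the original class, the callable returns a single 1-tuple, so the dataset from the fix above contains one image per pass. A generator function like this one produces as many elements as it yields.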