I am trying to create padded batches from a TensorFlow dataset, but it raises an error.
Sample data:
A = [
([[101,9385,13302,102], [1,1,1,1], [0,0,0,0]]),
([[101,9385,13302], [1,1,1], [0,0,0]]),
([[101,9385,13302,102], [1,1,1,1], [0,0,0,0]]),
([[101,9385,13302], [1,1,1], [0,0,0]])
]
Expected output:
Output = [
([[101,9385,13302,102], [1,1,1,1], [0,0,0,0]]),
([[101,9385,13302,-100], [1,1,1,-100], [0,0,0,-100]]),
([[101,9385,13302,102], [1,1,1,1], [0,0,0,0]]),
([[101,9385,13302,-100], [1,1,1,-100], [0,0,0,-100]])
]
Here is the code:
import tensorflow as tf

chk = tf.data.Dataset.from_generator(lambda: A,
                                     output_types=(tf.int32))
BATCH_SIZE = 2
chk_batched = chk.padded_batch(BATCH_SIZE,
                               padded_shapes=(4),
                               padding_values=(0))
for elem in chk_batched.as_numpy_iterator():
    print(elem)
The error is given below:
InvalidArgumentError: All elements in a batch must have the same rank as the padded shape for component0: expected rank 1 but got element with rank 2
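Each element is two-dimensional (3 rows, up to 4 columns), so padded_shapes must be rank 2 as well: padded_shapes=(3, 4). Writing padding_values=(0) is unnecessary, because the default padding value is 0. To get the -100 shown in the expected output, pass padding_values=-100 instead.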
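For reference, here is a minimal sketch of the pipeline with a rank-2 padded shape. It assumes every element has shape (3, L) with L no greater than 4, as in the sample data, and it keeps the question's from_generator call unchanged. The name `batches` is introduced here only for illustration.

```python
import tensorflow as tf

# Sample data from the question: each element has shape (3, L), with L = 3 or 4.
A = [
    [[101, 9385, 13302, 102], [1, 1, 1, 1], [0, 0, 0, 0]],
    [[101, 9385, 13302], [1, 1, 1], [0, 0, 0]],
    [[101, 9385, 13302, 102], [1, 1, 1, 1], [0, 0, 0, 0]],
    [[101, 9385, 13302], [1, 1, 1], [0, 0, 0]],
]

chk = tf.data.Dataset.from_generator(lambda: A, output_types=tf.int32)

BATCH_SIZE = 2
chk_batched = chk.padded_batch(
    BATCH_SIZE,
    padded_shapes=(3, 4),  # rank 2, matching the rank of each element
    padding_values=tf.constant(-100, dtype=tf.int32),  # dtype must match the dataset
)

# "batches" is a hypothetical helper name; each batch has shape (2, 3, 4).
batches = list(chk_batched.as_numpy_iterator())
for elem in batches:
    print(elem)
```

If the maximum length is not known in advance, padded_shapes=(3, None) pads the second dimension to the longest element in each batch.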