我想做类似 NER 任务代码的事情,它将单词的 WordPieces 与该单词的标签对齐:
import tensorflow as tf
tokens = tf.ragged.constant([[4], [2, 5, 9]], dtype=tf.int32)
tags = tf.ragged.constant([3, 5], dtype=tf.int32)
flat_tokens = tf.reshape(tokens, [-1])
duplicated_tags = tf.repeat(tags, [tf.shape(tok)[0] for tok in tokens])
print(flat_tokens.numpy()) # -> [4 2 5 9]
print(duplicated_tags.numpy()) # -> [3 5 5 5]
但是输入
tokens
和tags
到tf.repeat
作为数据集,应该是TextLineDataset
的输出。有什么简约的方法可以做到吗?
也许是这样的:
import tensorflow as tf
tokens = tf.data.Dataset.from_tensor_slices(tf.ragged.constant([[4], [2, 5, 9]], dtype=tf.int32))
tags = tf.data.Dataset.from_tensor_slices(tf.ragged.constant([3, 5], dtype=tf.int32))
ds = tf.data.Dataset.zip((tokens, tags)).map(lambda x, y: (x, tf.repeat(y, repeats=tf.shape(x)[0])))
tokens = ds.map(lambda a, b: a).flat_map(tf.data.Dataset.from_tensor_slices)
tags = ds.map(lambda a, b: b).flat_map(tf.data.Dataset.from_tensor_slices)
print(list(tokens.as_numpy_iterator()))
print(list(tags.as_numpy_iterator()))
[4, 2, 5, 9]
[3, 5, 5, 5]