我想使用DenseHashTable查找字符串张量,就像这个答案answer一样,键的类型是tf.string,值嵌入tf.float32 dtype。但是当键是多维时,就会出现错误。
keys = ["Fritz", "Franz", "Fred"]
values = [[1, 2, 3, -1], [4, 5, -1, -1], [6, 7, 8, 9]]
table = tf.lookup.experimental.DenseHashTable(key_dtype=tf.string, value_dtype=tf.float32, empty_key="0", deleted_key="-1", default_value=[-1,-1,-1,-1])
table.insert(keys, values)
table.lookup(['Franz', 'Emil']) # shape=(2,) its ok
table.lookup([['Franz', 'Emil'], ['Emil', 'Fred']]) # when lookup with 2-D tensor(shape like (batch_size, 2)), throws error.
我怎样才能让它像 tf.nn.embedding_lookup 一样工作?键不是数组索引而是 tf.string。
问题是 TensorFlow 需要一个键列表,而不是嵌套的键列表。当然,
docs中按键描述中的
Can be a tensor of any shape.
有点令人困惑。keys = [['Franz', 'Emil'], ['Emil', 'Fred']]
keys = tf.convert_to_tensor(keys) # to get the shape
key_shape = keys.shape # shape: (2, 2)
x = table.lookup(tf.reshape(keys, -1)) # shape: (4, 4) after hashing
x = tf.reshape(x, key_shape+(x.shape[-1:])) # shape: (2, 2 ,4)