我正在将代码从tensorflow 1.x更新到2.1.0。
我更改了tensorflow 1.x代码
labels = tf.cast(labels, tf.int64)
predict = tf.argmax(input=logits, axis=1)
tf.metrics.accuracy(labels=labels, predictions=predict)
到tensorflow 2.1.0代码。
labels = tf.cast(labels, tf.int64)
predict = tf.argmax(input=logits, axis=1)
tf.keras.metrics.Accuracy.update_state(labels, predict) #updated code
但是,当我运行更新的代码时,出现以下错误。
TypeError: update_state() missing 1 required positional argument: 'y_pred'
所以,我检查了tensorflow 2.1.0文档,tf.keras.metrics.Accuracy.update_state()
的参数似乎是一个列表(以[,,]的形式)。然后,我寻找一种将张量转换为列表的方法,这是
labels = tf.make_tensor_proto(labels)
labels = tf.make_ndarray(labels)
运行此代码后,出现以下错误。
TypeError: List of Tensors when single Tensor expected
因此,我尝试通过以下方式将张量列表转换为张量:>
labels = tf.stack(labels) #or labels = torch.stack(labels)
[
tf.stack()
不起作用,因为它给出了相同的初始TypeError,表明更新的代码中缺少'y_pred'。
但是,torch.stack()
出现以下错误。
。但是,TypeError: stack() : argument 'tensors' (position 1) must be tuple of Tensors, not Tensor
所以,我猜测
torch.stack()
仅接受一个元组,没有列表
tf.stack()
似乎接受了一个列表,但它没有将其转换为张量?我的标签是否首先预测张量列表?如果是这样,为什么tf.stack()不能将它们变成张量?如何正确转换标签并进行预测,以便可以将其传递到tf.keras.metrics.Accuracy.update_state()
?
除非绝对必要,否则如果不使用compat.v1.
,我将不胜感激。
我正在将代码从tensorflow 1.x更新为2.1.0。我更改了tensorflow 1.x代码标签= tf.cast(标签,tf.int64)预测= tf.argmax(输入= logits,轴= 1)tf.metrics.accuracy(标签=标签,...
以这种方式尝试: