如何在自定义指标中获取会话?

问题描述 投票:0回答:1

我正在尝试在Keras中运行自定义指标。我成功了,但我不相信它的结果所以我想检查一些值。麻烦的是,一切都在张量中,我想将它们转换为ndarrays以便检查它们。要转换它们,我必须有一个会话来评估它们。当我尝试使用Keras后端进行会话时出现错误:

InvalidArgumentError(请参阅上面的回溯):您必须为占位符张量'Dense_1_target_1'提供一个值,其中dtype为float和shape [?,?] [[Node:Dense_1_target_1 = Placeholderdtype = DT_FLOAT,shape = [?,?],_ device =“ /工作:本地主机/副本:0 /任务:0 /设备:GPU:0" ]]

我唯一想要的是能够打印一些有关张量的信息:值,形状等。

from keras import backend as K

def t_zeros(Y_true, y_pred):
""" Just count # zero's in Y_true and try to print some info """
    threshold = 0.5
    true_zeros = K.less(Y_true, threshold) # element-wise True where Y_true < theshold
    true_zeros = K.cast(true_zeros, K.floatx())  # cast to 0.0 / 1.0
    n_zeros = K.sum(true_zeros)

    sess = K.get_session()
    y_t = Y_true.eval(session=sess) # <== error happens here
    print(y_t.shape)

    return n_zeros
python tensorflow keras metrics
1个回答
1
投票

请记住,tensorflow使用延迟评估。

所以你不能在你的函数中使用print的值。您需要创建一个打印节点并将其挂钩到整个图形中。

像这样的东西

def t_zeros(Y_true, y_pred):
""" Just count # zero's in Y_true and try to print some info """
    threshold = 0.5
    true_zeros = K.less(Y_true, threshold) # element-wise True where Y_true < theshold
    true_zeros = K.cast(true_zeros, K.floatx())  # cast to 0.0 / 1.0
    n_zeros = K.sum(true_zeros)

    return tf.Print(n_zeros, [n_zeros]) 

... 
my_metric = t_zeros(Y_true, y_pred)  # Returns the tensor, but we need to make sure it's evaluated
...
train_op = tf.group(train_op, my_metric) 

如果您愿意,可以将其连接到其他操作,只需确保它得到评估。

© www.soinside.com 2019 - 2024. All rights reserved.