我用参数创建了自定义丢失函数。
def w_categorical_crossentropy(weights):
def loss(y_true, y_pred):
print(weights)
print("----")
print(weights.shape)
final_mask = K.zeros_like(y_pred[:, 0])
y_pred_max = K.max(y_pred, axis=1)
y_pred_max = K.reshape(y_pred_max, (K.shape(y_pred)[0], 1))
y_pred_max_mat = K.cast(K.equal(y_pred, y_pred_max), K.floatx())
return K.categorical_crossentropy(y_pred, y_true)
return loss
现在,我需要控制权重参数值,但打印功能无法正常工作。有没有办法打印重量值?
您可以使用tf.print
来调试自定义丢失函数。 tf.print
必须包含在图表中才能实际打印。根据文档,你用tf.control_dependencies
做到这一点。
def w_categorical_crossentropy(weights):
def loss(y_true, y_pred):
print_op = tf.print("weights: ", weights, weights.shape)
final_mask = K.zeros_like(y_pred[:, 0])
y_pred_max = K.max(y_pred, axis=1)
y_pred_max = K.reshape(y_pred_max, (K.shape(y_pred)[0], 1))
y_pred_max_mat = K.cast(K.equal(y_pred, y_pred_max), K.floatx())
with tf.control_dependencies([print_op]):
return K.categorical_crossentropy(y_pred, y_true)
return loss