我决定从keras切换到tf.keras(建议here。因此,我安装了tf.__version__=2.0.0
和tf.keras.__version__=2.2.4-tf
。在我的代码的较旧版本中(使用一些较旧的Tensorflow版本tf.__version__=1.x.x
),我使用了一个回调在每个时期结束时对整个验证数据计算自定义指标。这样做的想法来自here。但是,似乎似乎不赞成使用“ validation_data”属性,因此以下代码不再起作用。
class ValMetrics(Callback): def on_train_begin(self, logs={}): self.val_all_mse = [] def on_epoch_end(self, epoch, logs): val_predict = np.asarray(self.model.predict(self.validation_data[0])) val_targ = self.validation_data[1] val_epoch_mse = mse_score(val_targ, val_predict) self.val_epoch_mse.append(val_epoch_mse) # Add custom metrics to the logs, so that we can use them with # EarlyStop and csvLogger callbacks logs["val_epoch_mse"] = val_epoch_mse print(f"\nEpoch: {epoch + 1}") print("-----------------") print("val_mse: {:+.6f}".format(val_epoch_mse)) return
我当前的解决方法如下。我只是给了validate_data作为
ValMetrics
类的参数:
class ValMetrics(Callback): def __init__(self, validation_data): super(Callback, self).__init__() self.X_val, self.y_val = validation_data
仍然有一些问题:是否确实不赞成使用“ validation_data”属性,或者可以在其他位置找到它?与上述解决方法相比,在每个时期结束时是否有更好的方法来访问验证数据?
非常感谢!
我决定从keras切换到tf.keras(按照此处的建议)。因此,我安装了tf .__ version __ = 2.0.0和tf.keras .__ version __ = 2.2.4-tf。在我的代码的旧版本中(使用一些较旧的...
您很正确,validation_data
参数已根据Tensorflow Callbacks Documentation弃用。