我正在尝试实现一个DQN,该DQN在同一模型上对Estimator.train()
和Estimator.predict()
进行多次调用,每个示例都有少量示例。但是每个调用至少要花费几百毫秒到一秒以上的时间,这与小数字(例如1-20)的示例数无关。
我认为这些延迟是由于重建图表并在每次调用时保存检查点而引起的。有没有办法将相同的图形和参数保留在内存中,以进行快速的火车预测迭代或以其他方式加快速度?]
转换为tf.keras.Model
而不是Estimator
,并使用tf.keras.Model.fit()
代替Estimator.train()
。 fit()
没有train()的固定延迟。 Keras predict()
也没有。