TensorFlow估算器在每次调用预测时做出不同的预测

问题描述 投票:1回答:1

我使用TF Estimators训练了Iris数据集的分类器,但是每次预测都得到了不同的结果。我想知道我是在训练中出错还是在预测中出了问题。

我正在加载一个已经训练有素的模型,然后进行.predict调用。这是我用于预测的输入函数。

def get_predict_fn(features,batch_size):
    def predict_input_fn():
        dataset = tf.data.Dataset.from_tensor_slices(dict(features))
        dataset = dataset.batch(batch_size)
        return dataset.make_one_shot_iterator().get_next()

    return predict_input_fn

这是一个呼叫的结果

[{'logits':array([-3.5082035,-1.074667,-3.8533034],dtype = float32),“概率”:array([0.07629351,0.8696793,0.05402722],dtype = float32),'class_ids':array([1]),'classes':array([b'Iris-versicolor'],dtype = object)}]

这是另一个电话

[{'logits':array([3.0530725,-1.0889677,2.3922846],dtype = float32),“概率”:array([0.6525989,0.01037006,0.337031],dtype = float32),'class_ids':array([0]),'classes':array([b'Iris-setosa'],dtype = object)}]

两者都调用相同的模型,发送相同的示例DataFrame。

sepal_length sepal_width花瓣_长度petal_width5.7 2.5 5.0 2.0

python tensorflow tensorflow-estimator
1个回答
0
投票

我有同样的问题。我已经尝试过此代码,并且对我有用:

        checkpoint_path = model.latest_checkpoint()
        print("checkpoint_path=", checkpoint_path)
        result = model.predict(input_fn=predict_input_fn,
                               hooks=[tf.train.LoggingTensorHook([ 'user_id', 'ad_info',
                                                                   'predict_id'], every_n_iter=1000)],
                               checkpoint_path=checkpoint_path)
        prediction_res = []
        for prediction in result:
            # print("predictions=", prediction)
            user_id = prediction['user_id']
            predict_label = prediction['predict_label']

我也阅读了TF1.13估计器的预测源代码,实际上有用于读取最新检查点的代码,

   with context.graph_mode():
      hooks = _check_hooks_type(hooks)
      # Check that model has been trained.
      if not checkpoint_path:
        checkpoint_path = checkpoint_management.latest_checkpoint(
            self._model_dir)
      if not checkpoint_path:
        logging.info('Could not find trained model in model_dir: {}, running '
                     'initialization to predict.'.format(self._model_dir))

但是我仍然不知道为什么函数“ checkpoint_management.latest_checkpoint(self._model_dir)”与“ model.latest_checkpoint()”不同)

© www.soinside.com 2019 - 2024. All rights reserved.