Tensorflow + google ml-engine本地预测:如何处理CSV引号?

问题描述 投票:1回答:2

我已经训练了一个用于预测0/1值的CNN模型,并使用谷歌ml-engine本地预测进行测试。我的测试文件包含2行:

some text
"some text"

我知道这应该给我1和1作为预测结果。但是输出是1和0.所以双引号因某种原因很重要。虽然训练pandas.read_csv用于词汇创建。

pd.read_csv(filename, header=None, sep=',', names=['source', 'title'],encoding='utf-8', na_filter=False,engine='python')

对于预测,使用以下命令:

gcloud ml-engine local predict --model-dir=.... --text-instances=... --format=json

我在阅读csv进行培训时是否缺少一些参数,还是谷歌ml-engine的问题?

python pandas tensorflow google-cloud-ml
2个回答
1
投票

据推测,您导出的图形包含decode_csv op来读取输入。如何处理报价将取决于参数use_quote_delim的设置。为了说明,请考虑以下事项:

import tensorflow as tf

data = ['some text', '"some text"']
with tf.Session() as sess:
  use_delim = tf.decode_csv(data, [['']], use_quote_delim=True)
  dont = tf.decode_csv(data, [['']], use_quote_delim=False)
  out = sess.run([use_delim, dont])
  print("use", out[0])
  print("dont", out[1])

>>> ('use', [array(['some text', 'some text'], dtype=object)])
>>> ('dont', [array(['some text', '"some text"'], dtype=object)])

要获得您期望的行为,您希望导出的模型设置use_quote_delim=True(这是默认值)。


0
投票

你在输入函数中使用pandas还是使用tf.decode_csv?

在您的gcloud预测调用中,您将输入格式设置为json。您的意思是将其设置为CSV吗?

© www.soinside.com 2019 - 2024. All rights reserved.