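I'm building an input pipeline with tf.data (TensorFlow's Dataset API). The dataset has only two columns, and I want to filter it against a list of values. Filtering works with the equality operator `==`, but when I try to use the membership operator `in`, I get the following error: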
OperatorNotAllowedInGraphError: using a `tf.Tensor` as a Python `bool` is not allowed: AutoGraph did not convert this function. Try decorating it directly with @tf.function.
Here is my code:
import numpy as np
import tensorflow as tf
# Load file
file_path = 'drive/My Drive/Datasets/category_catalog.csv.gz'
def get_dataset(file_path, batch_size=5, num_epochs=1, **kwargs):
    return tf.data.experimental.make_csv_dataset(
        file_path,
        batch_size=batch_size,
        na_value="?",
        num_epochs=num_epochs,
        ignore_errors=True,
        **kwargs
    )

raw_data = get_dataset(
    file_path,
    select_columns=['description', 'department'],
    compression_type='GZIP'
)
This filter works:
@tf.function
def filter_fn(features):
    return features['department'] == 'MOVEIS'
ds = raw_data.unbatch()
ds = ds.filter(filter_fn)
ds = ds.batch(2)
Output:
next(iter(ds))
OrderedDict([('description', <tf.Tensor: shape=(2,), dtype=string, numpy=
array([b'KIT DE COZINHA KITS PARANA 8 PORTAS GOLDEN EM MDP LINHO BRANCO E LINHO PRETO',
b'ARMARIO AEREO PARA COZINHA 1 PORTA HORIZONTAL EXCLUSIVE ITATIAIA PRETO MATTE'],
dtype=object)>),
('department',
<tf.Tensor: shape=(2,), dtype=string, numpy=array([b'MOVEIS', b'MOVEIS'], dtype=object)>)])
This filter does not work:
@tf.function
def filter_fn(features):
    return features['department'] in ['FERRAMENTAS', 'MERCEARIA', 'MOVEIS']
ds = raw_data.unbatch()
ds = ds.filter(filter_fn)
ds = ds.batch(2)
Error:
---------------------------------------------------------------------------
OperatorNotAllowedInGraphError Traceback (most recent call last)
<ipython-input-52-52131b5369b6> in <module>()
6
7 ds = raw_data.unbatch()
----> 8 ds = ds.filter(filter_fn)
9 ds = ds.batch(2)
18 frames
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/func_graph.py in wrapper(*args, **kwargs)
966 except Exception as e: # pylint:disable=broad-except
967 if hasattr(e, "ag_error_metadata"):
--> 968 raise e.ag_error_metadata.to_exception(e)
969 else:
970 raise
OperatorNotAllowedInGraphError: in user code:
<ipython-input-52-52131b5369b6>:5 filter_fn *
return features['department'] in ['FERRAMENTAS', 'MERCEARIA', 'MOVEIS']
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py:778 __bool__
self._disallow_bool_casting()
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py:545 _disallow_bool_casting
"using a `tf.Tensor` as a Python `bool`")
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py:532 _disallow_when_autograph_enabled
" decorating it directly with @tf.function.".format(task))
OperatorNotAllowedInGraphError: using a `tf.Tensor` as a Python `bool` is not allowed: AutoGraph did not convert this function. Try decorating it directly with @tf.function.
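The traceback ends in `Tensor.__bool__`. A likely explanation starts with how Python evaluates `x in some_list`: it compares `x` with each element using `==` and converts every comparison result to a Python `bool`. In graph mode, `==` on a tensor returns another symbolic tensor, and converting that tensor to `bool` is exactly what is disallowed. The plain `==` filter works because `filter_fn` returns the comparison tensor and never truth-tests it.

The sketch below reproduces this mechanism in plain Python, without TensorFlow. It uses a hypothetical `SymbolicTensor` class (not a TensorFlow class) whose `__bool__` raises, mimicking a graph-mode tensor.

```python
class SymbolicTensor:
    """Hypothetical stand-in for a graph-mode tf.Tensor; not part of TensorFlow."""

    def __eq__(self, other):
        # Like a graph tensor, `==` builds a new symbolic value instead of
        # returning a Python bool.
        return SymbolicTensor()

    # Defining __eq__ removes the default __hash__; restore identity hashing.
    __hash__ = object.__hash__

    def __bool__(self):
        # A graph tensor has no concrete value yet, so truth-testing fails.
        raise TypeError("using a symbolic tensor as a Python bool is not allowed")


x = SymbolicTensor()

# Works: the comparison result is kept as a value, never converted to bool.
equality_result = x == "MOVEIS"

# Fails: list membership evaluates `x is e or x == e` for each element e,
# which calls bool() on the SymbolicTensor produced by `==`.
membership_error = None
try:
    x in ["FERRAMENTAS", "MERCEARIA", "MOVEIS"]
except TypeError as err:
    membership_error = err

print(type(equality_result).__name__, "/", membership_error)
```

The same reasoning explains why decorating `filter_fn` with `@tf.function` does not help. As far as I know, AutoGraph rewrites tensor-dependent control flow such as `if` and `while`, but not the `in` operator on a Python list.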
Here is a link to the Colab notebook, where the code can be run and the error inspected.
I tried the following filter function:
tf.reduce_any(tf.math.equal(features['department'], ['FERRAMENTAS', 'MERCEARIA', 'MOVEIS']))
It seems to work in the Colab you provided, but I'm not sure it's exactly what you want.
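For reference, here is a minimal end-to-end sketch of that filter. It assumes TensorFlow 2.x and replaces `category_catalog.csv.gz` with a few hypothetical in-memory rows, so it runs without the original file. The `ALLOWED` list is the question's whitelist. The rows, including the "MODA" department, are invented for illustration.

```python
import tensorflow as tf

# Hypothetical sample rows standing in for category_catalog.csv.gz.
rows = {
    "description": ["KIT DE COZINHA", "MARTELO", "ARROZ 5KG", "CAMISETA"],
    "department": ["MOVEIS", "FERRAMENTAS", "MERCEARIA", "MODA"],
}

# Whitelist of departments to keep (from the question).
ALLOWED = ["FERRAMENTAS", "MERCEARIA", "MOVEIS"]


def filter_fn(features):
    # Compare the scalar department string with every allowed value
    # (broadcasting yields a bool tensor of shape [3]), then OR the entries.
    # The result stays a scalar tf.bool tensor; no Python bool is needed.
    return tf.reduce_any(tf.math.equal(features["department"], ALLOWED))


ds = tf.data.Dataset.from_tensor_slices(rows)  # one dict per row
ds = ds.filter(filter_fn)  # Dataset.filter traces filter_fn into a graph itself
ds = ds.batch(2)

for batch in ds:
    print(batch["department"].numpy())
```

With the real CSV pipeline, the same `filter_fn` would go between `raw_data.unbatch()` and `.batch(2)`, as in the question.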
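The logic is as follows: `tf.math.equal` returns a tensor of size 3 in which each entry is True or False. The first entry says whether the department is "FERRAMENTAS", the second whether it is "MERCEARIA", and so on. `tf.reduce_any` then performs a logical OR over those 3 entries. So if the department is one of the 3 whitelisted values, exactly one of the 3 entries will be True, and the output of `reduce_any` will be True. In all other cases it will be False. Unlike `in`, this never converts a tensor to a Python `bool`, so it can be traced into the graph.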
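To make the two steps concrete, here is a small eager-mode sketch, assuming TensorFlow 2.x with eager execution enabled (the default). The helper name `in_whitelist` is hypothetical; it just returns the intermediate mask next to the final result. For "MOVEIS" the mask is [False, False, True] and the OR is True. For "MODA", a made-up department outside the list, every entry is False.

```python
import tensorflow as tf

# The three whitelisted departments from the question.
ALLOWED = tf.constant(["FERRAMENTAS", "MERCEARIA", "MOVEIS"])


def in_whitelist(department):
    """Hypothetical helper: returns (per-entry mask, logical OR of the mask)."""
    # Broadcast the scalar string against the 3 allowed values -> bool tensor of shape [3].
    mask = tf.math.equal(tf.constant(department), ALLOWED)
    # Logical OR across the 3 entries -> scalar bool tensor.
    return mask, tf.reduce_any(mask)


for dept in ["MOVEIS", "MODA"]:
    mask, keep = in_whitelist(dept)
    print(dept, mask.numpy().tolist(), bool(keep.numpy()))
```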