Python membership operator "in" with TensorFlow Datasets

Question (votes: 1, answers: 1)

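I'm building an input pipeline with TensorFlow Datasets. The dataset has only two columns, and I want to filter it against a list of values. However, I can only filter with the equality operator "=="; when I try the membership operator "in", I get the following error: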

    OperatorNotAllowedInGraphError: using a `tf.Tensor` as a Python `bool` is not allowed: AutoGraph did not convert this function. Try decorating it directly with @tf.function.

Here is my code:

import numpy as np
import tensorflow as tf


# Load file
file_path = 'drive/My Drive/Datasets/category_catalog.csv.gz'

def get_dataset(file_path, batch_size=5, num_epochs=1, **kwargs):
    return tf.data.experimental.make_csv_dataset(
        file_path,
        batch_size=batch_size,
        na_value="?",
        num_epochs=num_epochs,
        ignore_errors=True,
        **kwargs
    )

raw_data = get_dataset(
    file_path,
    select_columns=['description', 'department'],
    compression_type='GZIP'
)

This filter works:

@tf.function
def filter_fn(features):
    return features['department'] == 'MOVEIS'

ds = raw_data.unbatch()
ds = ds.filter(filter_fn)
ds = ds.batch(2)

Output:

next(iter(ds))

OrderedDict([('description', <tf.Tensor: shape=(2,), dtype=string, numpy=
              array([b'KIT DE COZINHA KITS PARANA 8 PORTAS GOLDEN EM MDP LINHO BRANCO E LINHO PRETO',
                     b'ARMARIO AEREO PARA COZINHA 1 PORTA HORIZONTAL EXCLUSIVE ITATIAIA PRETO MATTE'],
                    dtype=object)>),
             ('department',
              <tf.Tensor: shape=(2,), dtype=string, numpy=array([b'MOVEIS', b'MOVEIS'], dtype=object)>)])

This filter does not work:

@tf.function
def filter_fn(features):
    return features['department'] in ['FERRAMENTAS', 'MERCEARIA', 'MOVEIS']

ds = raw_data.unbatch()
ds = ds.filter(filter_fn)
ds = ds.batch(2)

Error:

---------------------------------------------------------------------------
OperatorNotAllowedInGraphError            Traceback (most recent call last)
<ipython-input-52-52131b5369b6> in <module>()
      6 
      7 ds = raw_data.unbatch()
----> 8 ds = ds.filter(filter_fn)
      9 ds = ds.batch(2)

18 frames
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/func_graph.py in wrapper(*args, **kwargs)
    966           except Exception as e:  # pylint:disable=broad-except
    967             if hasattr(e, "ag_error_metadata"):
--> 968               raise e.ag_error_metadata.to_exception(e)
    969             else:
    970               raise

OperatorNotAllowedInGraphError: in user code:

    <ipython-input-52-52131b5369b6>:5 filter_fn  *
        return features['department'] in ['FERRAMENTAS', 'MERCEARIA', 'MOVEIS']
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py:778 __bool__
        self._disallow_bool_casting()
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py:545 _disallow_bool_casting
        "using a `tf.Tensor` as a Python `bool`")
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py:532 _disallow_when_autograph_enabled
        " decorating it directly with @tf.function.".format(task))

    OperatorNotAllowedInGraphError: using a `tf.Tensor` as a Python `bool` is not allowed: AutoGraph did not convert this function. Try decorating it directly with @tf.function.

Here is a link to the Colab, where you can run the code and check the error.

python tensorflow tensorflow-datasets
1 answer
1 vote

What I tried is the following filter function:

tf.reduce_any(tf.math.equal(features['department'], ['FERRAMENTAS', 'MERCEARIA', 'MOVEIS']))

It seems to work in the Colab you provided, but I'm not sure whether this is what you want.

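The logic is as follows. The math.equal operator produces a tensor of size 3 in which each entry is True or False: the first entry says whether the department is "FERRAMENTAS", the second whether it is "MERCEARIA", and so on. reduce_any then takes the logical OR of those 3 entries. So if the department is one of the 3 whitelisted values, exactly one entry of the 3-entry tensor is True, and the output of reduce_any is True. It is False in all other cases.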
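As a rough end-to-end sketch of the idea above: Python evaluates `x in some_list` by comparing `x` with each element and converting each result to a Python `bool`, which is the `__bool__` call visible in the traceback. Doing the comparison with tensor ops avoids that conversion. The in-memory dataset and the name `ALLOWED` below are hypothetical stand-ins for the CSV file in the question, and the filter assumes unbatched elements (a scalar `department` per row), as after `unbatch()`.

```python
import tensorflow as tf

# Hypothetical in-memory stand-in for the two-column CSV from the question.
rows = {
    "description": ["sofa", "hammer", "tv", "rice"],
    "department": ["MOVEIS", "FERRAMENTAS", "ELETRONICOS", "MERCEARIA"],
}
raw_data = tf.data.Dataset.from_tensor_slices(rows)

# Whitelist of departments to keep (name chosen for illustration).
ALLOWED = tf.constant(["FERRAMENTAS", "MERCEARIA", "MOVEIS"])

def filter_fn(features):
    # The scalar department is broadcast against ALLOWED,
    # giving a boolean tensor of shape (3,).
    matches = tf.math.equal(features["department"], ALLOWED)
    # Logical OR over the three comparisons gives the scalar
    # boolean that Dataset.filter expects.
    return tf.reduce_any(matches)

ds = raw_data.filter(filter_fn).batch(2)

# Collect the surviving departments as Python strings.
kept = [d.decode() for batch in ds for d in batch["department"].numpy()]
print(kept)
```

If the filter were applied to batched elements instead, `department` would no longer be a scalar and the comparison would need an extra dimension, so this sketch keeps the filter before `batch`.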
