如何获得tf.data.Dataset的特征和张量字典?

问题描述 投票:0回答:1

给出以下数据集:

import pandas as pd
import tensorflow as tf

df = pd.DataFrame({
    'feat_binomial': [5, 1, 7, 4, 6],
    'feat_normal': [5.001512, 5.346654, -0.480363,4.821558,-2.080958],
    'feat_ordinal': ['low', 'low', 'low','low','low'],
    'feat_string': ['a', 'b', 'b','b','b'],
})

dataset = tf.data.Dataset.from_tensor_slices(dict(df))

我想获得特征和张量的字典,将map函数应用于数据集时,会得到它。在这里可以看到它的印刷:

def print_input_dict(input_dict):
    features_tensors = dict(input_dict)
    print(features_tensors)
    return features_tensors

dataset.map(print_input_dict)
{'feat_binomial': <tf.Tensor 'args_0:0' shape=() dtype=int32>,
 'feat_normal': <tf.Tensor 'args_1:0' shape=() dtype=float64>,
 'feat_ordinal': <tf.Tensor 'args_2:0' shape=() dtype=string>,
 'feat_string': <tf.Tensor 'args_3:0' shape=() dtype=string>}

为了得到这本字典,我设法做到了以下几点:

def _extract_labels(input_dict):
    global features_tensors
    features_tensors = dict(input_dict)
    return features_tensors

dataset.map(print_input_dict)

因此它存储在一个全局变量中,我以后可以访问它,但这并不像是正确的方法。还有其他获取这本词典的方法吗?

谢谢

python tensorflow tensorflow-datasets
1个回答
0
投票

您可以尝试使用迭代器,

iterator = dataset.make_one_shot_iterator()
feature_dict = iterator.get_next() # Returns Dictionary of features
© www.soinside.com 2019 - 2024. All rights reserved.