如何使用tf.data.TextlineDataset读取多个输入模型的数据?

问题描述 投票:0回答:1

型号

我用多个输入创建了一个模型,该模型可以嵌入索引或连续数字。例如,存在三个输入,其名称分别为input1input2input3,分别为fixed length embedding indexvariable length embedding indexcontinuous numbers

数据

数据文件的格式组织如下:

input1 input2   input3   label
1       1,2    0.51,0.62   2

所有输入均由制表符(\ t)分隔。

[Variable length embedding indexcontinuous numbers输入值用逗号(,)分隔。

加载数据

现在,我想从数据文件中加载火车数据。为此,我使用了tf.data.TextLineDataset。但是如何将input2和input3的值转换为数组张量进行训练和评估?我已经尝试过Dataset的地图功能。

剪切的代码

dataset = tf.data.TextLineDataset('file.tsv')
dataset = dataset.map(labeler)

def labeler(record):
    fields = tf.decode_csv(record, record_defaults=['0', '0', '0', 0], field_delim='\t')
    label = fields[-1]
    del fields[-1]

    data = dict()
    data['input1'] = tf.cast(fields[0], dtype=int64)
    # How to do with input2 and input3??
    data['input2'] = ??
    data['input3'] = ??

    return data, label
tensorflow tensorflow-datasets
1个回答
0
投票

我会亲自回答这个问题,这里是函数labeler的代码:

def labeler(record):
    fields = tf.io.decode_csv(record,
                              record_defaults=['0'] * 4,
                              field_delim='\t',
                              select_cols=list(range(0, 4)))
    data = dict()
    data['input1'] = tf.strings.to_number(fields[0], out_type='int64')
    data['input2'] = tf.strings.to_number(tf.strings.split([fields[1]],
                                                           sep=',').values,
                                          out_type='int64')
    data['input3'] = tf.strings.to_number(tf.strings.split([fields[2]],
                                                           sep=',').values,
                                          out_type='float64')
    label = tf.strings.to_number(fields[-1], out_type='int64')

    return data, label

通知:

如果要使用batch功能批处理上述数据集,它将失败。因为数据集具有可变长度输入字段。

解决此问题的方法是使用数据集的padded_batch函数。并且由于有多个输入,您应该使用元组为每个输入设置shape,该元组将传递给padded_batch。这是代码:

shapes = ({'input1': [], 'input2': [None], 'input3': []}, [])
dataset = dataset.map(lambda ex: labeler(ex))
dataset = dataset.shuffle(1000).repeat(2).padded_batch(batch_size,
                                                       padded_shapes=shapes)

[]表示无填充,[None]表示使用0填充该批次中最长的记录。

尽管这有效,但是填充全部0是否会影响training effect仍是未知的。如果您有任何想法,很高兴听到您的声音。

© www.soinside.com 2019 - 2024. All rights reserved.