如何处理“类型错误:tf__prepare_batch() 需要 1 个位置参数,但给出了 3 个”?

问题描述 投票:0回答:0

我创建了两个函数来准备我的数据集来学习消息传递图神经网络。当我调用 MPNNDataset() 函数时,它会产生与 prepare_batch() 函数相关的错误:

TypeError: 在用户代码中: 类型错误:tf__prepare_batch() 接受 1 个位置参数,但给出了 3 个

我的代码如下:

def graphs_from_smiles(smiles_list):
    # Initialize graphs
    atom_features_list = []
    bond_features_list = []
    pair_indices_list = []

    for smiles in smiles_list:
      if smiles!=np.NaN:
        molecule = molecule_from_smiles(smiles)
        atom_features, bond_features, pair_indices = graph_from_molecule(molecule)

        atom_features_list.append(atom_features)
        bond_features_list.append(bond_features)
        pair_indices_list.append(pair_indices)
      else:
        atom_features_list = np.array([0 for i in range(129)], ndmin=2)
        bond_features_list = np.array([0 for i in range(7)], ndmin=2)
        pair_indices_list = np.array([0, 0], ndmin=2)
    # Convert lists to ragged tensors for tf.data.Dataset later on
    return (
      tf.ragged.constant(atom_features_list, dtype=tf.float32),
      tf.ragged.constant(bond_features_list, dtype=tf.float32),
      tf.ragged.constant(pair_indices_list, dtype=tf.int64),
      )

train_index = permuted_indices[: int(len(smiles_list) * 0.8)]
x_train = graphs_from_smiles(df.iloc[train_index].smiles)

def prepare_batch(x_batch):
    """Merges (sub)graphs of batch into a single global (disconnected) graph
    """

    atom_features, bond_features, pair_indices = x_batch

    # Obtain number of atoms and bonds for each graph (molecule)
    num_atoms = atom_features.row_lengths()
    num_bonds = bond_features.row_lengths()

    # Obtain partition indices (molecule_indicator), which will be used to
    # gather (sub)graphs from global graph in model later on
    molecule_indices = tf.range(len(num_atoms))
    molecule_indicator = tf.repeat(molecule_indices, num_atoms)

    # Merge (sub)graphs into a global (disconnected) graph. Adding 'increment' to
    # 'pair_indices' (and merging ragged tensors) actualizes the global graph
    gather_indices = tf.repeat(molecule_indices[:-1], num_bonds[1:])
    increment = tf.cumsum(num_atoms[:-1]) # tf.cumsum([a, b, c])   # [a, a + b, a + b + c]
    increment = tf.pad(tf.gather(increment, gather_indices), [(num_bonds[0], 0)])
    pair_indices = pair_indices.merge_dims(outer_axis=0, inner_axis=1).to_tensor()
    pair_indices = pair_indices + increment[:, tf.newaxis]
    atom_features = atom_features.merge_dims(outer_axis=0, inner_axis=1).to_tensor()
    bond_features = bond_features.merge_dims(outer_axis=0, inner_axis=1).to_tensor()

    return (atom_features, bond_features, pair_indices, molecule_indicator)


def MPNNDataset(X, batch_size=32, shuffle=False):
    dataset = tf.data.Dataset.from_tensor_slices((X))
    if shuffle:
        dataset = dataset.shuffle(1024)
    return dataset.batch(batch_size).map(prepare_batch, -1).prefetch(-1)

train_dataset = MPNNDataset(x_train)

TypeError                                 Traceback (most recent call last)
<ipython-input-19-f3777974a600> in <module>
----> 1 train_dataset = MPNNDataset(x_train)
      2 valid_dataset = MPNNDataset(x_valid)
      3 test_dataset = MPNNDataset(x_test)
      4 
      5 history = mpnn.fit(

14 frames
/usr/local/lib/python3.9/dist-packages/tensorflow/python/autograph/impl/api.py in converted_call(f, args, kwargs, caller_fn_scope, options)
    437     try:
    438       if kwargs is not None:
--> 439         result = converted_f(*effective_args, **kwargs)
    440       else:
    441         result = converted_f(*effective_args)

TypeError: in user code:


    TypeError: tf__prepare_batch() takes 1 positional argument but 3 were given

我不知道为什么会这样,在我看来我传递了 1 个参数给 prepare_batch,但是错误说我有 3 个参数。

tensorflow batch-file dataset arguments
© www.soinside.com 2019 - 2024. All rights reserved.