我创建了两个函数来准备我的数据集来学习消息传递图神经网络。当我调用 MPNNDataset() 函数时,它会产生与 prepare_batch() 函数相关的错误:
TypeError: 在用户代码中: 类型错误:tf__prepare_batch() 接受 1 个位置参数,但给出了 3 个
我的代码如下:
def graphs_from_smiles(smiles_list):
# Initialize graphs
atom_features_list = []
bond_features_list = []
pair_indices_list = []
for smiles in smiles_list:
if smiles!=np.NaN:
molecule = molecule_from_smiles(smiles)
atom_features, bond_features, pair_indices = graph_from_molecule(molecule)
atom_features_list.append(atom_features)
bond_features_list.append(bond_features)
pair_indices_list.append(pair_indices)
else:
atom_features_list = np.array([0 for i in range(129)], ndmin=2)
bond_features_list = np.array([0 for i in range(7)], ndmin=2)
pair_indices_list = np.array([0, 0], ndmin=2)
# Convert lists to ragged tensors for tf.data.Dataset later on
return (
tf.ragged.constant(atom_features_list, dtype=tf.float32),
tf.ragged.constant(bond_features_list, dtype=tf.float32),
tf.ragged.constant(pair_indices_list, dtype=tf.int64),
)
train_index = permuted_indices[: int(len(smiles_list) * 0.8)]
x_train = graphs_from_smiles(df.iloc[train_index].smiles)
def prepare_batch(x_batch):
"""Merges (sub)graphs of batch into a single global (disconnected) graph
"""
atom_features, bond_features, pair_indices = x_batch
# Obtain number of atoms and bonds for each graph (molecule)
num_atoms = atom_features.row_lengths()
num_bonds = bond_features.row_lengths()
# Obtain partition indices (molecule_indicator), which will be used to
# gather (sub)graphs from global graph in model later on
molecule_indices = tf.range(len(num_atoms))
molecule_indicator = tf.repeat(molecule_indices, num_atoms)
# Merge (sub)graphs into a global (disconnected) graph. Adding 'increment' to
# 'pair_indices' (and merging ragged tensors) actualizes the global graph
gather_indices = tf.repeat(molecule_indices[:-1], num_bonds[1:])
increment = tf.cumsum(num_atoms[:-1]) # tf.cumsum([a, b, c]) # [a, a + b, a + b + c]
increment = tf.pad(tf.gather(increment, gather_indices), [(num_bonds[0], 0)])
pair_indices = pair_indices.merge_dims(outer_axis=0, inner_axis=1).to_tensor()
pair_indices = pair_indices + increment[:, tf.newaxis]
atom_features = atom_features.merge_dims(outer_axis=0, inner_axis=1).to_tensor()
bond_features = bond_features.merge_dims(outer_axis=0, inner_axis=1).to_tensor()
return (atom_features, bond_features, pair_indices, molecule_indicator)
def MPNNDataset(X, batch_size=32, shuffle=False):
dataset = tf.data.Dataset.from_tensor_slices((X))
if shuffle:
dataset = dataset.shuffle(1024)
return dataset.batch(batch_size).map(prepare_batch, -1).prefetch(-1)
train_dataset = MPNNDataset(x_train)
TypeError Traceback (most recent call last)
<ipython-input-19-f3777974a600> in <module>
----> 1 train_dataset = MPNNDataset(x_train)
2 valid_dataset = MPNNDataset(x_valid)
3 test_dataset = MPNNDataset(x_test)
4
5 history = mpnn.fit(
14 frames
/usr/local/lib/python3.9/dist-packages/tensorflow/python/autograph/impl/api.py in converted_call(f, args, kwargs, caller_fn_scope, options)
437 try:
438 if kwargs is not None:
--> 439 result = converted_f(*effective_args, **kwargs)
440 else:
441 result = converted_f(*effective_args)
TypeError: in user code:
TypeError: tf__prepare_batch() takes 1 positional argument but 3 were given
我不知道为什么会这样,在我看来我传递了 1 个参数给 prepare_batch,但是错误说我有 3 个参数。