我需要你的帮助。我有两组图形结构化数据,一组来自开放图基准 (OGB),另一组使用
torch_geometric.data.Dataset
根据我自己的数据创建。数据如下:
Data(edge_index=[2, 88], edge_attr=[88, 3], x=[39, 9], y=[1, 1]) #OGB
Data(x=[23, 9], edge_index=[2, 48], edge_attr=[48, 2], y=[1]) #PyG
我正在尝试使用使用 OGB 函数开发的框架,这不适用于使用 PyG 创建的数据。例如:框架的第一部分加载并将数据集拆分为train、val和test:
# Set the random seed
random.seed(random_seed)
np.random.seed(random_seed)
# Create data loaders
split_idx = dataset.get_idx_split() # train/val/test split
loader_dict = {}
for phase in split_idx:
batch_size = 32
loader_dict[phase] = DataLoader(dataset[split_idx[phase]], batch_size=batch_size, shuffle=False)
当我使用本机 ogb 数据集运行此代码时,没有任何问题,当我使用 PyG 数据时返回错误:
AttributeError
这很奇怪,因为它们都是 Pytorch 对象,唯一的区别是 OGB 数据集是 InMemoryDataset,而 PyG 是一个“更大”的数据集(https://pytorch-geometric.readthedocs.io/en/latest/笔记/create_dataset.html)。有什么办法可以解决这个问题而无需更改源代码吗?
如果您想使用相同的代码,您需要为您自己的数据集实现
get_idx_split
。
您可以在 OGB GitHub 中找到所需的返回结构,例如这里:
def get_idx_split(self):
< ... do something to retrieve train/test/validation set>
return {'train': train_idx, 'valid': valid_idx, 'test': test_idx}