我正在尝试创建一个基于使用 PyG 构建的异构图的数据加载器,以便在 Pytorch Lightning 中实现的模型中使用。我尝试使用 HGTLoader 和 NeighborLoader 生成数据加载器,但在这两种情况下,输出都是空对象。
为了检查数据加载器创建过程是否正常工作,我创建了一个玩具图,如下所示:
data= HeteroData(
user={x=[5, 1],},
poi={x=[10, 1] },
(user, visits, poi)={
edge_index=[2, 40],
edge_attr=[40, 2],
},
(poi, connects, poi)={
edge_index=[2, 100],
edge_attr=[100, 2],
}
)
与
poi_ids = (0,1,2,3,4,5,6,7,8,9)
user_ids = (10,11,12,13,14)
我已经检查过,所有的edge_index元素都被正确识别。
我尝试使用 HGTLoader 和 NeighborLoader 生成数据加载器,但在这两种情况下,输出都是空对象。考虑到我希望用户节点作为采样的输入节点。
下面是两个采样器的代码。
data_loader = HGTLoader(
data,
batch_size=2,
input_nodes=('user', torch.tensor([10, 11, 12, 13,14], dtype=torch.long)),
num_samples={key: [2] for key in data.node_types},
shuffle=True,
)
data_loader= NeighborLoader(
data,
batch_size=2,
input_nodes=('user', torch.tensor([10, 11, 12, 13,14], dtype=torch.long)),
num_neighbors={key: [2] for key in data.edge_types},
shuffle=True)
有人可以帮我找出问题所在吗?非常感谢!
尝试 input_nodes=('user')。 也许,当您将索引传递为 torch.tensor([10, 11, 12, 13,14], dtype=torch.long) 时,加载程序无法找到这样的用户节点,因为它们的索引为 0, 1, 2, 3 , 4.