使用 NeighborLoader/ HGTLoader 和异构图创建数据加载器时出现问题

问题描述 投票:0回答:1

我正在尝试创建一个基于使用 PyG 构建的异构图的数据加载器,以便在 Pytorch Lightning 中实现的模型中使用。我尝试使用 HGTLoader 和 NeighborLoader 生成数据加载器,但在这两种情况下,输出都是空对象。

为了检查数据加载器创建过程是否正常工作,我创建了一个玩具图,如下所示:

data= HeteroData(
  user={x=[5, 1],},
  poi={x=[10, 1] },
  (user, visits, poi)={
    edge_index=[2, 40],
    edge_attr=[40, 2],
  },
  (poi, connects, poi)={
    edge_index=[2, 100],
    edge_attr=[100, 2],
  }
)

poi_ids = (0,1,2,3,4,5,6,7,8,9)
user_ids = (10,11,12,13,14)

我已经检查过,所有的edge_index元素都被正确识别。 我尝试使用 HGTLoader 和 NeighborLoader 生成数据加载器,但在这两种情况下,输出都是空对象。考虑到我希望用户节点作为采样的输入节点。
下面是两个采样器的代码。

data_loader = HGTLoader(
    data, 
    batch_size=2, 
    input_nodes=('user', torch.tensor([10, 11, 12, 13,14], dtype=torch.long)),
    num_samples={key: [2]  for key in data.node_types}, 
    shuffle=True, 
)

data_loader= NeighborLoader(
          data, 
          batch_size=2, 
          input_nodes=('user', torch.tensor([10, 11, 12, 13,14], dtype=torch.long)), 
          num_neighbors={key: [2]  for key in data.edge_types}, 
          shuffle=True)

有人可以帮我找出问题所在吗?非常感谢!

graph pytorch-lightning pytorch-dataloader pytorch-geometric gnn
1个回答
0
投票

尝试 input_nodes=('user')。 也许,当您将索引传递为 torch.tensor([10, 11, 12, 13,14], dtype=torch.long) 时,加载程序无法找到这样的用户节点,因为它们的索引为 0, 1, 2, 3 , 4.

© www.soinside.com 2019 - 2024. All rights reserved.