您好,我使用了GraphSAGE来进行节点嵌入。我选择用于聚合的函数是LSTM以及用于图神经网络的PyG库,它需要的参数如下:
输入 1:节点特征 (|V|, F_in) - 这里我使用 2D 平面中的节点坐标 x-y (V x 2),并且已经标准化为 [0, 1] 范围,例如
x y
0 0.374540 0.598658
1 0.950714 0.156019
2 0.731994 0.155995
输入 2:边索引 (2, |E|) - 邻接矩阵 (V x V),但仅从我拥有的原始邻接矩阵中检索边到 (2, |E|)
idx 0 1 2
0 [[0, 1, 1],
1 [1, 0, 1],
2 [1, 1, 0]]
上图中我们有一个具有 6 条边的形状 (V x V)。我们必须对其进行一些改造以适应 PyG 对形状 (2, |E|) 的使用,我想将其称为
edge_index
,其中边为 (0, 1), (0, 2), (1, 0 ), (1, 2), (2, 0), (2, 1)。
idx 0 1 2 3 4 5
0 [[0, 0, 1, 1, 2, 2],
1 [1, 2, 0, 2, 0, 1]]
输出:节点特征(|V|,F_out) - 与节点坐标类似,但它们不再是二维的,它们位于具有 F_out 维度的新嵌入维度中。
我的问题是,当使用 LSTM 聚合器时,它被迫排序
edge_index
(input2 中的边缘索引),否则会显示错误 ValueError: Can not perform aggregation since the 'index' tensor. is not sorted.
所以我必须做排序用以下命令给出它:
# inside def __init__()
self.graph_sage=SAGEConv(in_channels=2, out_channels=hidden_dim, aggr='lstm')
# inside def forward()
sorted_edge_index, _ = torch.sort(edge_index, dim=1) # for LSTM
x = self.graph_sage(coord.view(-1, 2), sorted_edge_index) # using GraphSAGE
排序后
sorted_edge_index
张量将如下所示。
idx 0 1 2 3 4 5
0 [[0, 0, 1, 1, 2, 2],
1 [0, 0, 1, 1, 2, 2]]
我注意到,在连接 3 个节点的全网格图中,当我对它进行排序时,边可以被重新解释为 (0, 0)、(0, 0)、(1, 1)、(1, 1)、 (2, 2), (2, 2) 这让我很好奇。我的问题是以下两件事。
edge_index
进行排序?edge_index
进行排序后,我的模型将如何知道哪些节点已连接?因为所有原来的边关系对都没有了。这就像发送图中不存在的边对作为输入。这会是一个缺点吗?我已经尝试过执行上述操作,并且运行良好。但我有一些疑问,希望有知识的人能够帮助像我这样的初学者澄清问题。我真诚地希望这个问题对其他学习 GNN 的学生也有用。
按行排序 = False
from torch_geometric.utils import sort_edge_index
sort_edge_index=sort_edge_index(edge_index, num_nodes=self.num_nodes, sort_by_row=False)
x=self.graphsage(coord.view(-1, 2), sort_edge_index)
https://github.com/pyg-team/pytorch_geometric/discussions/8908