如何获得 GCN 中的最终节点嵌入?

我正在 UPFD(假新闻检测)图数据集上构建 GCN。我的代码执行图分类。我需要获得最终的节点嵌入以便在我的项目中进一步使用。


# Root directory handed to UPFD for the dataset files.
current_file = '.'

# Build the three UPFD splits (politifact subset, 'spacy' node features), in
# train -> val -> test order, each with its own ToUndirected transform.
train_dataset, val_dataset, test_dataset = (
    UPFD(current_file, 'politifact', 'spacy', split, ToUndirected())
    for split in ('train', 'val', 'test')
)

# Only the training loader shuffles; evaluation batches keep a fixed order.
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
val_loader, test_loader = (
    DataLoader(split_ds, batch_size=128, shuffle=False)
    for split_ds in (val_dataset, test_dataset)
)

# Debug aid: inspect the raw input features of the first training graph.
# before_training = train_dataset[0].x
# print('Feature vector(node embedding) of datapoint #0 (before gtn):\n\t', train_dataset[0].x)

class GraphTransformer(torch.nn.Module):
    """Graph classifier: one GCNConv + two TransformerConv layers, max-pooled.

    Node features pass through ``conv1`` (GCNConv) and ``conv2``/``conv3``
    (TransformerConv). The resulting per-node representations are the model's
    final node embeddings. They are max-pooled per graph and classified by
    ``lin2``. With ``concat=True``, the raw features of each graph's first
    node are projected by ``lin0`` and concatenated with the pooled vector
    before ``lin1``.

    Args:
        in_channels: Number of input node features.
        hidden_channels: Width of every hidden layer (and of the node embeddings).
        out_channels: Number of target classes.
        concat: Whether to combine first-node features with the pooled vector.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, concat=False):
        # nn.Module.__init__ must run before any submodule is assigned;
        # otherwise ``self.conv1 = ...`` raises AttributeError.
        super().__init__()
        self.concat = concat

        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = TransformerConv(hidden_channels, hidden_channels)
        self.conv3 = TransformerConv(hidden_channels, hidden_channels)

        if self.concat:
            self.lin0 = Linear(in_channels, hidden_channels)
            self.lin1 = Linear(2 * hidden_channels, hidden_channels)

        self.lin2 = Linear(hidden_channels, out_channels)

    def node_embeddings(self, x, edge_index):
        """Return the node embeddings produced by the last conv layer.

        This is the representation *before* graph-level pooling, i.e. the
        "final node embeddings" (one row per node of ``x``, presumably of
        width ``hidden_channels`` -- confirm against the PyG layer docs).
        """
        h = self.conv1(x, edge_index).relu()
        h = self.conv2(h, edge_index).relu()
        return self.conv3(h, edge_index).relu()

    def forward(self, x, edge_index, batch, *, return_node_embeddings=False):
        """Classify every graph in the mini-batch.

        Args:
            x: Node feature matrix of the mini-batch.
            edge_index: Graph connectivity of the mini-batch.
            batch: Graph id of each node. The root lookup below assumes the
                nodes of one graph are contiguous.
            return_node_embeddings: If True, also return the node embeddings.

        Returns:
            Per-graph log-probabilities (``log_softmax`` over classes), or the
            tuple ``(log_probs, node_embeddings)`` when
            ``return_node_embeddings`` is True.
        """
        node_emb = self.node_embeddings(x, edge_index)
        h = global_max_pool(node_emb, batch)

        if self.concat:
            # Index of the first (root) node of every graph: positions where
            # the graph id in ``batch`` changes, plus position 0.
            root = (batch[1:] - batch[:-1]).nonzero(as_tuple=False).view(-1)
            root = torch.cat([root.new_zeros(1), root + 1], dim=0)
            news = self.lin0(x[root]).relu()
            h = self.lin1(torch.cat([news, h], dim=-1)).relu()

        log_probs = self.lin2(h).log_softmax(dim=-1)
        if return_node_embeddings:
            return log_probs, node_emb
        return log_probs

# Prefer the GPU when one is visible to PyTorch; otherwise use the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Hidden width 128; concat=True also feeds each graph's first-node features
# into the classifier head.
model = GraphTransformer(
    train_dataset.num_features,
    128,
    train_dataset.num_classes,
    concat=True,
).to(device)
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=1e-3,
    weight_decay=1e-2,
)


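在您的代码中,模型 GraphTransformer 执行的是图分类。前向传播时,各层根据输入特征计算出新的张量,作为更新后的节点表示;它们不会就地修改 train_dataset 中存储的原始节点特征(data.x)。这并不是因为 PyTorch 复制了数据,而是因为这些层的运算本身就返回新张量。因此,要获得最终的节点嵌入,需要在 forward 中把最后一个卷积层(conv3)的输出返回出来。注意它位于 global_max_pool 之前:池化之后得到的是每个图的嵌入,经过 lin2 之后得到的则是分类输出,两者都不是节点嵌入。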



class GraphTransformer(torch.nn.Module):
    """Graph classifier that also exposes its final node embeddings.

    Node features pass through ``conv1`` (GCNConv) and ``conv2``/``conv3``
    (TransformerConv). The output of ``conv3`` is the final node embedding.
    It is max-pooled per graph and classified by ``lin2``. With
    ``concat=True``, each graph's first-node features are projected by
    ``lin0`` and concatenated with the pooled vector before ``lin1``.

    Args:
        in_channels: Number of input node features.
        hidden_channels: Width of every hidden layer (and of the node embeddings).
        out_channels: Number of target classes.
        concat: Whether to combine first-node features with the pooled vector.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, concat=False):
        # nn.Module.__init__ must run before submodules are assigned;
        # otherwise ``self.conv1 = ...`` raises AttributeError.
        super().__init__()
        self.concat = concat

        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = TransformerConv(hidden_channels, hidden_channels)
        self.conv3 = TransformerConv(hidden_channels, hidden_channels)

        if self.concat:
            self.lin0 = Linear(in_channels, hidden_channels)
            self.lin1 = Linear(2 * hidden_channels, hidden_channels)

        self.lin2 = Linear(hidden_channels, out_channels)

    def node_embeddings(self, x, edge_index):
        """Return the node embeddings produced by the last conv layer (pre-pooling)."""
        h = self.conv1(x, edge_index).relu()
        h = self.conv2(h, edge_index).relu()
        return self.conv3(h, edge_index).relu()

    def forward(self, x, edge_index, batch):
        """Classify every graph and return the node embeddings alongside.

        Args:
            x: Node feature matrix of the mini-batch.
            edge_index: Graph connectivity of the mini-batch.
            batch: Graph id of each node. The root lookup below assumes the
                nodes of one graph are contiguous.

        Returns:
            ``(log_probs, node_embeddings)``: per-graph ``log_softmax`` output
            and the ``conv3`` node embeddings taken before ``global_max_pool``.
            (Previously the second element was the graph-level ``lin2``
            output, which is not a node embedding.)
        """
        node_emb = self.node_embeddings(x, edge_index)
        h = global_max_pool(node_emb, batch)

        if self.concat:
            # Index of the first (root) node of every graph: positions where
            # the graph id in ``batch`` changes, plus position 0.
            root = (batch[1:] - batch[:-1]).nonzero(as_tuple=False).view(-1)
            root = torch.cat([root.new_zeros(1), root + 1], dim=0)
            news = self.lin0(x[root]).relu()
            h = self.lin1(torch.cat([news, h], dim=-1)).relu()

        log_probs = self.lin2(h).log_softmax(dim=-1)
        return log_probs, node_emb

# Select the compute device: CUDA if available, CPU otherwise.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = GraphTransformer(
    in_channels=train_dataset.num_features,
    hidden_channels=128,
    out_channels=train_dataset.num_classes,
    concat=True,
).to(device)
# Adam with L2-style weight decay over every trainable parameter.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-2)

现在,前向传播会返回一个二元组。第一个元素是 log_softmax 的输出(对数概率,用于计算分类损失,并不是原始 logits);第二个元素是额外返回的嵌入张量。可以这样同时获取两者:

# Extract embeddings for one mini-batch. ``x``, ``edge_index`` and ``batch``
# are not defined at module level, so take them from a batch yielded by the
# DataLoader, which carries them as ``.x``, ``.edge_index`` and ``.batch``.
# Inference only: switch to eval mode and skip gradient tracking.
model.eval()
with torch.no_grad():
    data = next(iter(test_loader)).to(device)
    # NOTE: the first output is the log_softmax result (log-probabilities),
    # not raw logits; the name is kept for compatibility with later code.
    logits, final_embeddings = model(data.x, data.edge_index, data.batch)


