How do I get the final node embeddings from a GCN?


I am building a GCN on the UPFD (fake news detection) graph dataset. My code performs graph classification. I need to obtain the final node embeddings for further use in my project.

Here is the code so far:

import torch
from torch.nn import Linear
from torch_geometric.datasets import UPFD
from torch_geometric.transforms import ToUndirected
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, TransformerConv, global_max_pool

current_file = '.'

train_dataset = UPFD(current_file, 'politifact', 'spacy', 'train', ToUndirected())
val_dataset = UPFD(current_file, 'politifact', 'spacy', 'val', ToUndirected())
test_dataset = UPFD(current_file, 'politifact', 'spacy', 'test', ToUndirected())

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=128, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)

# before_training = train_dataset[0].x
# print('Feature vector(node embedding) of datapoint #0 (before gtn):\n\t', train_dataset[0].x)


class GraphTransformer(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels,
                 concat=False):
        super().__init__()
        self.concat = concat

        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = TransformerConv(hidden_channels, hidden_channels)
        self.conv3 = TransformerConv(hidden_channels, hidden_channels)

        if self.concat:
            self.lin0 = Linear(in_channels, hidden_channels)
            self.lin1 = Linear(2 * hidden_channels, hidden_channels)

        self.lin2 = Linear(hidden_channels, out_channels)

    def forward(self, x, edge_index, batch):
        h = self.conv1(x, edge_index).relu()
        h = self.conv2(h, edge_index).relu()
        h = self.conv3(h, edge_index).relu()
        h = global_max_pool(h, batch)

        if self.concat:
            # Get the root node (tweet) features of each graph:
            root = (batch[1:] - batch[:-1]).nonzero(as_tuple=False).view(-1)
            root = torch.cat([root.new_zeros(1), root + 1], dim=0)
            news = x[root]

            news = self.lin0(news).relu()
            h = self.lin1(torch.cat([news, h], dim=-1)).relu()

        h = self.lin2(h)
        return h.log_softmax(dim=-1)


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GraphTransformer(train_dataset.num_features, 128, train_dataset.num_classes, concat=True).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.01)

I tried printing the node embeddings after passing the data through the model, but they are identical to the embeddings before the model. Does the library operate on a copy of the data, or does it modify the original? What should I do? Is there a way to get the final node embeddings?

deep-learning pytorch transformer-model graph-neural-network
1 Answer

In your code, the GraphTransformer model performs graph classification. As the data moves forward through the model's layers, new node embeddings are computed at each step. However, these computations never affect the node features stored in train_dataset: the layers return new tensors instead of modifying the original data in place.
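As a quick sanity check (a minimal sketch; it assumes the train_dataset, model, and device already defined in your code), you can confirm that a forward pass leaves the stored features untouched:

data = train_dataset[0].to(device)
x_before = data.x.clone()  # snapshot of the stored node features

with torch.no_grad():
    # all nodes belong to graph 0, so the batch vector is all zeros
    batch = torch.zeros(data.num_nodes, dtype=torch.long, device=device)
    model(data.x, data.edge_index, batch)

print(torch.equal(x_before, data.x))  # True: the forward pass never changes data.x in place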

The embeddings appear unchanged most likely because you are not capturing the updated embeddings at the right place in your code. To obtain the final node embeddings after the model has processed the data, capture them inside the model's forward pass, specifically after all the message-passing layers have been applied.

Here is an example of how to capture the final node embeddings in the model:

class GraphTransformer(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, concat=False):
        super().__init__()
        self.concat = concat

        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = TransformerConv(hidden_channels, hidden_channels)
        self.conv3 = TransformerConv(hidden_channels, hidden_channels)

        if self.concat:
            self.lin0 = Linear(in_channels, hidden_channels)
            self.lin1 = Linear(2 * hidden_channels, hidden_channels)

        self.lin2 = Linear(hidden_channels, out_channels)

    def forward(self, x, edge_index, batch):
        h = self.conv1(x, edge_index).relu()
        h = self.conv2(h, edge_index).relu()
        h = self.conv3(h, edge_index).relu()
        node_embeddings = h  # final per-node embeddings after all message-passing layers
        h = global_max_pool(h, batch)

        if self.concat:
            # Get the root node (tweet) features of each graph:
            root = (batch[1:] - batch[:-1]).nonzero(as_tuple=False).view(-1)
            root = torch.cat([root.new_zeros(1), root + 1], dim=0)
            news = x[root]

            news = self.lin0(news).relu()
            h = self.lin1(torch.cat([news, h], dim=-1)).relu()

        h = self.lin2(h)
        return h.log_softmax(dim=-1), node_embeddings  # Return both the logits and the final node embeddings

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GraphTransformer(train_dataset.num_features, 128, train_dataset.num_classes, concat=True).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.01)

Now, when you run a forward pass (for example with a data batch taken from one of your loaders), you can capture both the logits (for classification) and the final node embeddings:

logits, final_node_embeddings = model(data.x, data.edge_index, data.batch)

This gives you access to the node embeddings as they are after the model's message-passing layers, which you can use for further analysis in your project.
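For example (a minimal sketch, assuming the two-output forward above and the test_loader defined in the question), you could collect the embeddings for every node in a split like this:

model.eval()
all_embeddings = []

with torch.no_grad():
    for data in test_loader:
        data = data.to(device)
        logits, node_embeddings = model(data.x, data.edge_index, data.batch)
        all_embeddings.append(node_embeddings.cpu())

# one row per node, in the order the loader delivered the graphs
final_node_embeddings = torch.cat(all_embeddings, dim=0)
print(final_node_embeddings.shape)  # [total_num_nodes, hidden_channels]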
