Why does torch_geometric rgat.conv throw this error in the index_select() function?


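I need to use this relational GAT layer. I'm using it for a different task, but it keeps throwing an error even though I believe I followed the documentation correctly. I wrote a separate script that generates random tensors for x, edge_index and edge_attr, and it throws the same error. Is there a problem with the class's original source code?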

import torch
import torch.nn.functional as F
from torch_geometric.nn import RGATConv


# Two relational graph attention layers followed by a linear classifier.
class RGAT(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels,
                 num_relations):
        super().__init__()
        self.conv1 = RGATConv(in_channels, hidden_channels, num_relations)
        self.conv2 = RGATConv(hidden_channels, hidden_channels, num_relations)
        self.lin = torch.nn.Linear(hidden_channels, out_channels)

    def forward(self, x, edge_index, edge_attr):
        x = self.conv1(x, edge_index, edge_attr=edge_attr).relu()
        x = self.conv2(x, edge_index, edge_attr=edge_attr).relu()
        x = self.lin(x)
        return F.log_softmax(x, dim=-1)


# Random sample graph: 5 nodes, 7 edges, 4 relation types.
in_channels = 16
hidden_channels = 32
out_channels = 2
num_relations = 4
num_nodes = 5
x = torch.randn((num_nodes, in_channels))
num_edges = 7
edge_index = torch.randint(low=0, high=num_nodes, size=(2, num_edges))
edge_attr = torch.randn((num_edges, num_relations))

# Instantiate the model
model = RGAT(in_channels, hidden_channels, out_channels, num_relations)

# Pass the sample input through the model
output = model(x, edge_index, edge_attr)
print(output.shape)  # should output torch.Size([num_nodes, out_channels])
Running this raises:

TypeError: index_select() received an invalid combination of arguments - got (Parameter, int, NoneType), but expected one of:
 * (Tensor input, int dim, Tensor index, *, Tensor out)
 * (Tensor input, name dim, Tensor index, *, Tensor out)
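The `(Parameter, int, NoneType)` part of the message says `index_select` was handed a weight parameter and `None` as the index. That is consistent with `RGATConv` gathering one weight matrix per edge using its `edge_type` argument, which the model in the question never passes (it only passes `edge_attr`). The plain-PyTorch sketch below mimics such a per-relation gather to show both the working call and the failing one; the weight layout `(num_relations, in_channels, out_channels)` and the helper `per_edge_messages` are hypothetical illustrations, not PyG's actual implementation.

```python
import torch

# Hypothetical stand-in for a relational layer's weights:
# one (in_channels x out_channels) matrix per relation type.
num_relations, in_channels, out_channels = 4, 16, 32
weight = torch.nn.Parameter(torch.randn(num_relations, in_channels, out_channels))


def per_edge_messages(x_j, edge_type):
    """Hypothetical helper: transform each source-node feature with the
    weight matrix belonging to its edge's relation type."""
    # Pick one weight matrix per edge; edge_type must be a LongTensor of
    # shape [num_edges] with values in [0, num_relations).
    w = torch.index_select(weight, 0, edge_type)  # [num_edges, in, out]
    # Batched (1 x in) @ (in x out) -> (1 x out), one product per edge.
    return torch.bmm(x_j.unsqueeze(1), w).squeeze(1)


num_edges = 7
x_j = torch.randn(num_edges, in_channels)
edge_type = torch.randint(0, num_relations, (num_edges,))
print(per_edge_messages(x_j, edge_type).shape)  # torch.Size([7, 32])

# Passing None instead of an index tensor reproduces the TypeError.
try:
    per_edge_messages(x_j, None)
except TypeError as err:
    print("TypeError:", str(err).splitlines()[0])
```

With a LongTensor index every edge gets the matrix of its own relation; with `None` the call fails before any computation, with the same kind of TypeError as in the question.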