我需要使用这个关系 GAT 层。我正在将它用于另一项任务,但它一直在抛出错误,即使我认为我正确地遵循了文档。我写了另一段代码,为 x、edge_index、edge_attr 生成随机张量,但它似乎抛出了同样的错误。类的原始源代码有问题吗?
class RGAT(torch.nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels,
num_relations):
super().__init__()
self.conv1 = RGATConv(in_channels, hidden_channels, num_relations)
self.conv2 = RGATConv(hidden_channels, hidden_channels, num_relations)
self.lin = torch.nn.Linear(hidden_channels, out_channels)
def forward(self, x, edge_index, edge_attr):
x = self.conv1(x, edge_index, edge_attr = edge_attr).relu()
x = self.conv2(x, edge_index, edge_attr = edge_attr).relu()
x = self.lin(x)
return F.log_softmax(x, dim=-1)
import torch
in_channels = 16
hidden_channels = 32
out_channels = 2
num_relations = 4
num_nodes = 5
x = torch.randn((num_nodes, in_channels))
num_edges = 7
edge_index = torch.randint(low=0, high=num_nodes, size=(2, num_edges))
edge_attr = torch.randn((num_edges, num_relations))
# Instantiate the model
model = RGAT(in_channels, hidden_channels, out_channels, num_relations)
# Pass the sample input through the model
output = model(x, edge_index, edge_attr)
print(output.shape) # should output torch.Size([num_nodes, out_channels])
TypeError: index_select() received an invalid combination of arguments - got (Parameter, int, NoneType), but expected one of:
* (Tensor input, int dim, Tensor index, *, Tensor out)
* (Tensor input, name dim, Tensor index, *, Tensor out)