我无法在 pytorch 中初始化我的模型并获取:
TypeError Traceback (most recent call last)
<ipython-input-82-9bfee30a439d> in <module>()
288 dataset = News_Dataset(true_path=args.true_news_file,
fake_path=args.fake_news_file,
289 embeddings_path=args.embeddings_file)
--> 290 classifier = News_classifier_resnet_based().cuda()
291 try:
292 classifier.load_state_dict(torch.load(args.model_state_file))
/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in
__call__(self, *input, **kwargs)
548 result = self._slow_forward(*input, **kwargs)
549 else:
--> 550 result = self.forward(*input, **kwargs)
551 for hook in self._forward_hooks.values():
552 hook_result = hook(self, input, result)
TypeError: forward() missing 1 required positional argument: 'input'
这是我的代码:
class News_classifier_resnet_based(torch.nn.Module):
def __init__(self):
super().__init__()
self.activation = torch.nn.ReLU6()
self.sigmoid = torch.nn.Sigmoid()
self.positional_encodings = PositionalEncoder()
self.resnet = list(torch.hub.load('pytorch/vision:v0.6.0', 'resnet18', pretrained=True).children())
self.to_appropriate_shape = torch.nn.Conv2d(in_channels=1, out_channels=1, kernel_size=77)
self.conv1 = torch.nn.Conv2d(in_channels=1,out_channels=64,kernel_size=7,stride=2,padding=3)
self.conv1.weight = torch.nn.Parameter(self.resnet[0].weight[:,0,:,:].data)
self.center = torch.nn.Sequential(*self.resnet[1:-2])
self.conv2 = torch.nn.Conv2d(in_channels=512, out_channels=1, kernel_size=1)
self.conv3 = torch.nn.Conv2d(in_channels=1,out_channels=1,kernel_size=7)
self.title_conv = torch.nn.Sequential(
torch.nn.Conv2d(in_channels=1,out_channels=1,kernel_size=3,stride=3),
self.activation(),
torch.nn.Conv2d(in_channels=1,out_channels=1,kernel_size=2,stride=2),
self.activation(),
torch.nn.Conv2d(in_channels=1,out_channels=1,kernel_size=2,stride=2)
)
self.title_lin = torch.nn.Linear(25,1)
self.year_lin = torch.nn.Linear(10,1)
self.month_lin = torch.nn.Linear(12,1)
self.day_lin = torch.nn.Linear(31,1)
self.date_lin = torch.nn.Linear(3,1)
self.final_lin = torch.nn.Linear(3,1)
def forward(self,x_in):
#input shape - (batch_size, 3+title_len+seq_len, embedding_dim)
#output shape - (batch_size, 1)
year = x_in[:,0,:10]
month = x_in[:,1,:12]
day = x_in[:,2,:31]
title = x_in[:,3:3+args.title_len,:]
text = x_in[:,3+args.title_len:,:]
title = self.positional_encodings(title)
text = self.positional_encodings(text)
text = text.unsqueeze(1)
text = self.activation(self.to_appropriate_shape(text))
text = self.activation(self.conv1(text))
text = self.activation(self.center(text))
text = self.activation(self.conv2(text))
text = self.activation(self.conv3(text))
text = text.reshape(args.batch_size,-1)
title = title.unsqueeze(1)
title = self.activation(self.title_conv(title))
title = title.reshape(args.batch_size,-1)
title = self.activation(self.title_lin(title))
year = self.activation(self.year_lin(year))
month = self.activation(self.month_lin(month))
day = self.activation(self.day_lin(day))
date = torch.cat([year,month,day], dim=-1)
date = self.activation(self.date_lin(date))
final = torch.cat([date,title,text], dim=-1)
final = self.sigmoid(self.final_lin(final))
return final
classifier = News_classifier_resnet_based().cuda()
我该怎么办?我正在尝试使用词嵌入对文本进行分类,但问题出在最后一行。我在谷歌合作实验室工作。另外,当我在其他代码块中创建一些模型时,我没有遇到任何问题。
问题出在你的
init
函数上。当您创建 title_conv
时,而不是传递先前创建的激活对象,您将在不带参数的情况下调用激活。您可以通过更改这部分代码来修复它:
self.title_conv = torch.nn.Sequential(
torch.nn.Conv2d(in_channels=1,out_channels=1,kernel_size=3,stride=3),
self.activation, # Notice I have removed ()
torch.nn.Conv2d(in_channels=1,out_channels=1,kernel_size=2,stride=2),
self.activation, # Notice I have removed ()
torch.nn.Conv2d(in_channels=1,out_channels=1,kernel_size=2,stride=2)
)
如果通过以下任一方式更改
self.title_conv
,则不会出现该错误。
self.title_conv = torch.nn.Sequential(
torch.nn.Conv2d(in_channels=1,out_channels=1,kernel_size=3,stride=3),
torch.nn.ReLU6(),
torch.nn.Conv2d(in_channels=1,out_channels=1,kernel_size=2,stride=2),
torch.nn.ReLU6(),
torch.nn.Conv2d(in_channels=1,out_channels=1,kernel_size=2,stride=2)
)
或如:
self.title_conv = torch.nn.Sequential(
torch.nn.Conv2d(in_channels=1,out_channels=1,kernel_size=3,stride=3),
self.activation,
torch.nn.Conv2d(in_channels=1,out_channels=1,kernel_size=2,stride=2),
self.activation,
torch.nn.Conv2d(in_channels=1,out_channels=1,kernel_size=2,stride=2)
)
我不确定原因。我认为这可能是因为
self.activation
是一个函数,并且通过将其写为 self.activation()
来再次进行调用。