我尝试在pytorch中使用view()但我不能输入3个参数。我不知道为什么它一直给出这个错误?谁能帮我这个?
def forward(self, input):
lstm_out, self.hidden = self.lstm(input.view(len(input), self.batch_size, -1))
看起来你的input
是一个numpy数组,而不是火炬张量。你需要先转换它,比如input = torch.Tensor(input)
。