PyTorch:预期输入 batch_size (96) 以匹配目标 batch_size (24)

问题描述 投票:0回答:0

我试用了 PyTorch,想为 Gender_Classification 编写一个程序。但是,我收到错误消息: 预期输入 batch_size (96) 以匹配目标 batch_size (24) 我搜索了一个解决方案,但我不明白我的代码有什么问题。

class CNN(nn.Module):
    def __init__(self):
        super(CNN,self).__init__()

        self.conv1=nn.Conv2d(in_channels=3,out_channels=64,kernel_size=3,stride=1,padding=1)
        self.bn1=nn.BatchNorm2d(num_features=64)
        self.relu1=nn.ReLU()
        self.pool=nn.MaxPool2d(kernel_size=2)
        
        self.conv2=nn.Conv2d(in_channels=64,out_channels=128,kernel_size=3,stride=1,padding=1)
        self.relu2=nn.ReLU()
        
        self.conv3=nn.Conv2d(in_channels=128,out_channels=256,kernel_size=3,stride=1,padding=1)
        self.bn3=nn.BatchNorm2d(num_features=256)
        self.relu3=nn.ReLU() 
        self.fc=nn.Linear(in_features=256*56*56,out_features=2)
        
    def forward(self,input):
        print(input.shape)
        output=self.conv1(input)
        print("1.1: {}".format(output.shape))
        output=self.bn1(output)
        output=self.relu1(output)  
        output=self.pool(output)
          
        output=self.conv2(output) 
        output=self.relu2(output)

        output=self.conv3(output)
        output=self.bn3(output)
        output=self.relu3(output)
        print("3.3: {}".format(output.shape))
        output=output.view(-1, 256*56*56)
        print("Flatten: {}".format(output.shape))
        output=self.fc(output)
        print("Output: {}".format(output.shape))

        return output


def train_model(model, criterior, optimizer, scheduler, num_epochs):
    # since = time.time()

    # best_model_wts = copy.deepcopy(model.state_dict())
    # best_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in dataloader_dict[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)
                optimizer.zero_grad()

                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    print(outputs.shape)
                    loss = criterior(outputs, labels)
                    _, preds = torch.max(outputs, 1)

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
            if phase == 'train':
                scheduler.step()

            epoch_loss = running_loss / len(dataloader_dict[phase].dataset)
            epoch_acc = running_corrects.double() / len(dataloader_dict[phase].dataset)

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

输出应该显示纪元和一些其他信息。实际上,我打印出了张量的形状,但我不知道哪里出了问题。这是错误信息:

input: torch.Size([24, 3, 224, 224])
1.1: torch.Size([24, 64, 224, 224])
3.3: torch.Size([24, 256, 112, 112])
Flatten: torch.Size([96, 802816])
Output: torch.Size([96, 2])
torch.Size([96, 2])
Traceback (most recent call last):
  File "C:\Users\long1\OneDrive\Máy tính\Gender_Classification\main.py", line 209, in <module>
    model = train_model(model, criterior, optimizer, exp_lr_scheduler, num_epochs=10)
  File "C:\Users\long1\OneDrive\Máy tính\Gender_Classification\main.py", line 185, in train_model
    loss = criterior(outputs, labels)
  File "C:\Users\long1\anaconda3\envs\py10\lib\site-packages\torch\nn\modules\module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "C:\Users\long1\anaconda3\envs\py10\lib\site-packages\torch\nn\modules\loss.py", line 1174, in forward
    return F.cross_entropy(input, target, weight=self.weight,
  File "C:\Users\long1\anaconda3\envs\py10\lib\site-packages\torch\nn\functional.py", line 3029, in cross_entropy
    return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)
ValueError: Expected input batch_size (96) to match target batch_size (24).
python python-3.x pytorch pycharm conv-neural-network
© www.soinside.com 2019 - 2024. All rights reserved.