Tensorflow CNN形状不匹配

问题描述 投票:0回答:1
def load_data(data_path, batch_size, num_workers=2):
    t_m  = transforms.Compose(
            [transforms.Grayscale(num_output_channels=1),
             transforms.Resize((400,400)),
            transforms.ToTensor(),
            # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

            ])

    dataset = torchvision.datasets.ImageFolder(root = data_path, transform=t_m)
    # print (np.shape(dataset))
    #split
    train, test = torch.utils.data.random_split(dataset, [int( len(dataset) * 0.7 ),  len(dataset) - int( len(dataset) * 0.7 ) ])


    trainloader = torch.utils.data.DataLoader(train, batch_size=batch_size,
                                          shuffle=True, num_workers=num_workers,drop_last = True)
    testloader = torch.utils.data.DataLoader(test, batch_size=batch_size,
                                         shuffle=False, num_workers=num_workers, drop_last = False)


    return dataset,trainloader,testloader





import torch.nn as nn
model = torch.nn.Sequential(
    nn.Conv2d(1, 32, kernel_size=5, padding=2),
    nn.MaxPool2d(2, 2),
    nn.Conv2d(32, 64, kernel_size=5, padding=2),
    nn.MaxPool2d(2, 2),
    nn.Linear ( 7 * 7 * 64, 1000),
    nn.Linear(1000, 600),
    nn.Linear(600, 200),
    nn.Linear(200, 10)


)



#Training

total_epochs = 5
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adadelta(model.parameters())

for epoch in tqdm(range(total_epochs)):

  #initialize 
  batch_count = 0
  gc.collect()
  loop_loss = 0.0

  for img in (trainloader):

      input_, label_ = img

      # print (input_.shape)

      out = model(input_)
      out= nn.functional.relu(out)



      loss = criterion(out, label_)
      loss.backward()

      optimizer.zero_grad()
      optimizer.step()

      loop_loss = loop_loss + loss.item()
      batch_count = batch_count + 1

      print('batch_loss: ', str(loss.item()))

  print('Epochs completed:', epoch+1,'\n')
  print('epoch_loss = ' + loop_loss/float(batch_count))

尺寸不匹配,m1: [25600 x 100], m2: [3136 x 1000] 在 pytorchatensrcTHgenericTHTensorMath.cpp:41处。

请解释一下形状哪里出了问题?我应该如何解决这个问题?我是新来的,所以这可能不是一个好问题,但任何细节都会有帮助输入图像的大小为400,400,并从rgb转换为灰色

deep-learning data-science torch cnn
1个回答
0
投票

你的问题出在第一个线性层。总是这样编码,这样你就可以自己想办法了。

class MyModel(nn.Module):
    def __init__(self, params):
        super(MyModel, self).__init__()
        self.conv1 = nn.Conv2d(...)
        self.fc = nn.Linear(...)
    def forward(self, x):
        x = self.conv1(x)
        import pdb; pdb.set_trace()
        x = self.fc(x)
        return x

这样你就可以把pdb放在你想要的地方,你可以用x.shape命令检查形状。你的问题出在conv层输出的形状和你的第一个Linear层不匹配。

© www.soinside.com 2019 - 2024. All rights reserved.