The problem is that when I try to reshape the output of the linear layer after BatchNorm and ReLU (as shown in the figure; it's a Dense layer there because they used TensorFlow), it throws the error: TypeError: reshape(): argument 'input' (position 1) must be Tensor, not int
I understand the error, but I can't find a solution. Is there any way to reshape inside nn.Sequential without explicitly calling torch?
class Generator(nn.Module):
    def __init__(self, z_dim=100, im_chan=1, hidden_dim=64, rdim=9216):
        super(Generator, self).__init__()
        self.z_dim = z_dim
        self.gen = nn.Sequential(
            nn.Linear(z_dim, rdim),
            nn.BatchNorm2d(rdim, momentum=0.9),
            nn.ReLU(inplace=True),
---->       torch.reshape(rdim, (6, 6, 256)),
            self.make_gen_block(rdim, hidden_dim*2),
            self.make_gen_block(hidden_dim*2, hidden_dim),
            self.make_gen_block(hidden_dim, im_chan, final_layer=True),
        )

    def make_gen_block(self, input_channels, output_channels, kernel_size=1, stride=2, final_layer=False):
        if not final_layer:
            return nn.Sequential(
                nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride),
                nn.BatchNorm2d(output_channels),
                nn.ReLU(inplace=True)
            )
        else:
            return nn.Sequential(
                nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride),
                nn.Tanh()
            )

    def unsqueeze_noise(self, noise):
        return noise.view(len(noise), self.z_dim, 1, 1)

    def forward(self, noise):
        x = self.unsqueeze_noise(noise)
        return self.gen(x)

def get_noise(n_samples, z_dim, device='cpu'):
    return torch.randn(n_samples, z_dim, device=device)

# Testing the Gen arch
gen = Generator()
num_test = 100

# test the hidden block
test_hidden_noise = get_noise(num_test, gen.z_dim)
test_hidden_block = gen.make_gen_block(6, 6, kernel_size=1, stride=2)
test_uns_noise = gen.unsqueeze_noise(test_hidden_noise)
hidden_output = test_hidden_block(test_uns_noise)
In nn.Sequential, torch.nn.Unflatten() can help you achieve the reshape operation.
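As a minimal illustration (a sketch, not from the original post): nn.Unflatten(dim, unflattened_size) splits one dimension into the given shape, which is exactly the reshape needed between the linear part and the conv-transpose part:

import torch
from torch import nn

# nn.Unflatten splits dimension 1 of an (N, 9216) tensor into (256, 6, 6),
# since 256 * 6 * 6 = 9216.
unflatten = nn.Unflatten(1, (256, 6, 6))
x = torch.randn(100, 9216)
print(unflatten(x).shape)  # torch.Size([100, 256, 6, 6])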
For nn.Linear, the input shape is (N, *, H_in) and the output shape is (N, *, H_out). Note that the feature dimension comes last, so unsqueeze_noise() is of no use here: it turns the noise into (N, z_dim, 1, 1), whose last dimension is 1 rather than z_dim, which nn.Linear(z_dim, rdim) cannot accept.
Given the network structure, the arguments passed to make_gen_block are also wrong.
I verified the following code:
import torch
from torch import nn

class Generator(nn.Module):
    def __init__(self, z_dim=100, im_chan=1, hidden_dim=64, rdim=9216):
        super(Generator, self).__init__()
        self.z_dim = z_dim
        self.gen = nn.Sequential(
            nn.Linear(z_dim, rdim),
            nn.BatchNorm1d(rdim, momentum=0.9),  # use BN1d
            nn.ReLU(inplace=True),
            nn.Unflatten(1, (256, 6, 6)),
            self.make_gen_block(256, hidden_dim*2, kernel_size=2),  # note arguments
            self.make_gen_block(hidden_dim*2, hidden_dim, kernel_size=2),  # note kernel_size
            self.make_gen_block(hidden_dim, im_chan, kernel_size=2, final_layer=True),  # note kernel_size
        )

    def make_gen_block(self, input_channels, output_channels, kernel_size=1, stride=2, final_layer=False):
        if not final_layer:
            return nn.Sequential(
                nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride),
                nn.BatchNorm2d(output_channels),
                nn.ReLU(inplace=True)
            )
        else:
            return nn.Sequential(
                nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride),
                nn.Tanh()
            )

    def forward(self, x):
        return self.gen(x)

def get_noise(n_samples, z_dim, device='cpu'):
    return torch.randn(n_samples, z_dim, device=device)

gen = Generator()
num_test = 100
input_noise = get_noise(num_test, gen.z_dim)
output = gen(input_noise)
assert output.shape == (num_test, 1, 48, 48)
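To see why the assert expects 48×48: for nn.ConvTranspose2d with the default padding=0, the output size is (in − 1) × stride + kernel_size, so each block with kernel_size=2, stride=2 doubles the spatial size: 6 → 12 → 24 → 48. A quick sketch to confirm the arithmetic, with shapes taken from the code above:

import torch
from torch import nn

# ConvTranspose2d output size with padding=0, dilation=1:
#   out = (in - 1) * stride + kernel_size
# With kernel_size=2, stride=2 this doubles the map: 6 -> 12 -> 24 -> 48.
block = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
x = torch.randn(1, 256, 6, 6)
print(block(x).shape)  # torch.Size([1, 128, 12, 12])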
You can do it like this:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.linear = nn.Linear(Config.LATENT_DIM, 4*4*Config.LATENT_DIM)
        self.active_0 = nn.ReLU()
        self.conv_trans_1 = nn.ConvTranspose2d(Config.LATENT_DIM, 512, kernel_size=4, stride=2, padding=1)
        self.norm_1 = nn.BatchNorm2d(512)
        self.active_1 = nn.LeakyReLU(0.2)
        self.conv_trans_2 = nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1)
        self.norm_2 = nn.BatchNorm2d(256)
        self.active_2 = nn.LeakyReLU(0.2)
        self.conv_trans_3 = nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1)
        self.norm_3 = nn.BatchNorm2d(128)
        self.active_3 = nn.LeakyReLU(0.2)
        self.conv_trans_4 = nn.ConvTranspose2d(128, 3, kernel_size=4, stride=2, padding=1)
        self.norm_4 = nn.BatchNorm2d(3)
        self.active_4 = nn.Tanh()

    def forward(self, inputs):
        x = self.linear(inputs)
        x = self.active_0(x)
        x = x.view(-1, Config.LATENT_DIM, 4, 4)  # <-- reshape here
        x = self.conv_trans_1(x)
        x = self.norm_1(x)
        x = self.active_1(x)
        x = self.conv_trans_2(x)
        x = self.norm_2(x)
        x = self.active_2(x)
        x = self.conv_trans_3(x)
        x = self.norm_3(x)
        x = self.active_3(x)
        x = self.conv_trans_4(x)
        x = self.norm_4(x)
        x = self.active_4(x)
        return x
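Config is not defined in this snippet; as a smoke test under the assumption that Config.LATENT_DIM is, say, 100, the generator produces 64×64 RGB images, since each of the four kernel_size=4, stride=2, padding=1 transposed convolutions doubles the map: 4 → 8 → 16 → 32 → 64.

import torch

# Hypothetical stand-in for the post's Config; LATENT_DIM = 100 is an
# assumption, since the original answer does not define it.
class Config:
    LATENT_DIM = 100

gen = Generator()
noise = torch.randn(8, Config.LATENT_DIM)
out = gen(noise)
print(out.shape)  # torch.Size([8, 3, 64, 64])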