我正在尝试实施 DiffusionRecoveryLikelihood EBM,出现张量大小错误,但无法发现我的错误 这是代码:
def l2normalize(v, eps=1e-12):
return v / (v.norm() + eps)
class SpectralNorm(nn.Module):
def __init__(self, module, name='weight', power_iterations=1):
super(SpectralNorm, self).__init__()
self.module = module
self.name = name
self.power_iterations = power_iterations
if not self._made_params():
self._make_params()
def _update_u_v(self):
u = getattr(self.module, self.name + "_u")
v = getattr(self.module, self.name + "_v")
w = getattr(self.module, self.name + "_bar")
height = w.data.shape[0]
for _ in range(self.power_iterations):
v.data = l2normalize(torch.mv(torch.t(w.view(height, -1).data), u.data))
u.data = l2normalize(torch.mv(w.view(height, -1).data, v.data))
# sigma = torch.dot(u.data, torch.mv(w.view(height,-1).data, v.data))
sigma = u.dot(w.view(height, -1).mv(v))
setattr(self.module, self.name, w / sigma.expand_as(w))
def _made_params(self):
try:
u = getattr(self.module, self.name + "_u")
v = getattr(self.module, self.name + "_v")
w = getattr(self.module, self.name + "_bar")
return True
except AttributeError:
return False
def _make_params(self):
w = getattr(self.module, self.name)
height = w.data.shape[0]
width = w.view(height, -1).data.shape[1]
u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
u.data = l2normalize(u.data)
v.data = l2normalize(v.data)
w_bar = Parameter(w.data)
del self.module._parameters[self.name]
self.module.register_parameter(self.name + "_u", u)
self.module.register_parameter(self.name + "_v", v)
self.module.register_parameter(self.name + "_bar", w_bar)
def forward(self, *args):
self._update_u_v()
return self.module.forward(*args)
def get_timestep_embedding(timesteps, embedding_dim: int):
assert len(timesteps.shape) == 1 # and timesteps.dtype == torch.int32
half_dim = embedding_dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(0, half_dim) * -emb).to(timesteps.device)
emb = torch.matmul(1.0 * timesteps.reshape(-1, 1), emb.reshape(1, -1))
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
if embedding_dim % 2 == 1: # zero pad
emb = F.pad(emb, [0, 1, 0, 0])
assert list(emb.shape) == [timesteps.shape[0], embedding_dim]
return emb
class Identity(nn.Module):
def __init__(self):
super(Identity, self).__init__()
def forward(self, x):
return x
def conv3x3(in_planes, out_planes, stride=1):
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=True)
class wide_basic(nn.Module):
def __init__(self, in_planes, planes, dropout_rate, stride=1, norm=None, leak=.2):
super(wide_basic, self).__init__()
self.norm = norm
self.lrelu = nn.LeakyReLU(leak)
self.bn1 = Identity()
self.conv1 = SpectralNorm(nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, bias=True))
self.dropout = Identity() if dropout_rate == 0.0 else nn.Dropout(p=dropout_rate)
self.bn2 = Identity()
self.conv2 = SpectralNorm(nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=True))
self.temb_dense = SpectralNorm(nn.Linear(512, planes))
self.shortcut = nn.Sequential()
if stride != 1 or in_planes != planes:
self.shortcut = nn.Sequential(
SpectralNorm(nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=True)),
)
def forward(self, x):
x, temb = x
out = self.bn1(x)
out = self.conv1(self.lrelu(out))
if temb is not None:
# add in timestep embedding
temp_o = self.lrelu(self.temb_dense(temb))
b, l = temp_o.shape
out += temp_o.view(b, l, 1, 1)
out = self.dropout(out)
out = self.bn2(out)
out = self.conv2(self.lrelu(out))
out += self.shortcut(x)
return out, temb
class Wide_ResNet(nn.Module):
def __init__(self, depth, widen_factor, num_classes=10, input_channels=3,
sum_pool=False, norm=None, leak=.2, dropout_rate=0.0):
super(Wide_ResNet, self).__init__()
self.leak = leak
self.in_planes = 16
self.sum_pool = sum_pool
self.norm = norm
self.lrelu = nn.LeakyReLU(leak)
self.n_classes = num_classes
assert ((depth - 4) % 6 == 0), 'Wide-reSpectralNormet depth should be 6n+4'
n = (depth - 4) // 6
k = widen_factor
print('| Wide-ReSpectralNormet %dx%d, SpectralNorm time embedding' % (depth, k))
nStages = [16, 16 * k, 32 * k, 64 * k]
self.layer_one_out = None
self.conv1 = SpectralNorm(conv3x3(input_channels, nStages[0]))
self.layer1 = self._wide_layer(wide_basic, nStages[1], n, dropout_rate, stride=1, leak=leak)
self.layer2 = self._wide_layer(wide_basic, nStages[2], n, dropout_rate, stride=2, leak=leak)
self.layer3 = self._wide_layer(wide_basic, nStages[3], n, dropout_rate, stride=2, leak=leak)
self.bn1 = Identity()
self.last_dim = nStages[3]
# self.linear = SpectralNorm(nn.Linear(nStages[3], num_classes))
self.linear = SpectralNorm(nn.Conv2d(nStages[3], num_classes, kernel_size=(10,1), stride=1))
self.temb_dense_0 = SpectralNorm(nn.Linear(128, 512))
self.temb_dense_1 = SpectralNorm(nn.Linear(512, 512))
self.temb_dense_2 = SpectralNorm(nn.Linear(512, nStages[3]))
def _wide_layer(self, block, planes, num_blocks, dropout_rate, stride, leak=0.2):
strides = [stride] + [1] * (num_blocks - 1)
layers = []
for stride in strides:
layers.append(block(self.in_planes, planes, dropout_rate, stride, leak=leak, norm=self.norm))
self.in_planes = planes
return nn.Sequential(*layers)
def forward(self, x, t, logits=False, feature=True):
out = self.conv1(x)
assert x.dtype == torch.float32
if isinstance(t, int) or len(t.shape) == 0:
t = torch.ones(x.shape[0], dtype=torch.int64, device=x.device) * t
temb = get_timestep_embedding(t, 128)
temb = self.temb_dense_0(temb)
temb = self.temb_dense_1(self.lrelu(temb))
out, _ = self.layer1([out, temb])
out, _ = self.layer2([out, temb])
out, _ = self.layer3([out, temb])
out = self.lrelu(self.bn1(out))
if self.sum_pool:
out = out.view(out.size(0), out.size(1), -1).sum(2)
else:
if self.n_classes > 100:
out = F.adaptive_avg_pool2d(out, 1)
else:
out = F.avg_pool2d(out, 8)
temb = self.lrelu(temb)
temb = self.temb_dense_2(temb)
out = out.view(out.size(0),-1)
out *= temb
if logits:
out = self.linear(out)
return out
错误是:
189 # temb = temb.reshape(-1, self.feature_maps)
190 out = out.view(out.size(0), -1)
--> 191 out *= temb
192 if logits:
193 out = self.linear(out)
RuntimeError:张量 a (768) 的大小必须与非单维 1 处的张量 b (192) 的大小匹配
我试过重塑,不同的视图程序并得到相同的错误