我正在做一个ViT图片分类任务。
这是我的代码。当我使用从别处拿来的 Transformer 实现时,模型可以正常训练;但当我改用 PyTorch 自带的 nn.TransformerEncoder 时,模型的损失无法正常下降。
self.transformer_encoder:有问题
self.another_transformer_encoder:效果很好
class ViT(nn.Module):
    """Vision Transformer image classifier.

    The image is cut into non-overlapping ``patch_size`` x ``patch_size``
    patches by a strided convolution, each patch is projected to ``dim``
    features, a learnable CLS token is prepended, the sequence is run through
    a transformer encoder, and the CLS position is fed to a linear classifier.

    Args:
        image_size: Side length of the input image; used with ``patch_size``
            to compute ``num_patches``.
        patch_size: Side length of one patch (kernel size and stride of the
            patch-embedding convolution).
        channels: Number of input image channels.
        num_classes: Number of output classes.
        dim: Embedding width of every token.
        depth: Number of encoder layers.
        heads: Number of attention heads per layer.
        mlp_dim: Hidden width of each layer's feed-forward network.
        dropout: Dropout probability inside each ``TransformerEncoderLayer``.
            Defaults to 0.5, the value that used to be hard-coded.
        norm_first: If true, build pre-norm encoder layers; the default
            (false) keeps the library's post-norm layout.

    NOTE(review): a dropout of 0.5 combined with post-norm layers is a
    plausible reason ``transformer_encoder`` trains worse than
    ``another_transformer_encoder`` -- try a smaller ``dropout`` and/or
    ``norm_first=True`` and confirm experimentally.
    """

    def __init__(self, image_size, patch_size, channels, num_classes, dim,
                 depth, heads, mlp_dim, dropout=0.5, norm_first=False):
        super().__init__()

        # Number of patches per image (patches per side, squared).
        self.num_patches = (image_size // patch_size) ** 2

        # Patch embedding: kernel == stride, so each output position sees
        # exactly one non-overlapping patch.
        self.patch_embedding = nn.Conv2d(
            channels, dim, kernel_size=patch_size, stride=patch_size, bias=False
        )

        # Encoder built from torch's own layers. ``norm_first`` is forwarded
        # only when requested so that the default construction is the same
        # call as before.
        layer_kwargs = {
            "d_model": dim,
            "nhead": heads,
            "dim_feedforward": mlp_dim,
            "dropout": dropout,
            "batch_first": True,
        }
        if norm_first:
            layer_kwargs["norm_first"] = True
        encoder_layers = TransformerEncoderLayer(**layer_kwargs)
        self.transformer_encoder = TransformerEncoder(
            encoder_layer=encoder_layers,
            num_layers=depth,
        )

        # Alternative encoder implementation taken from elsewhere; it is
        # constructed here but not used by forward().
        self.another_transformer_encoder = Transformer(dim, depth, heads, mlp_dim)

        # Position embedding layer (defined elsewhere; presumably adds
        # positional information to a (batch, seq, dim) tensor -- TODO confirm).
        self.pos_encoder = PositionalEncoding(dim)

        # Learnable CLS token and the classification head applied to it.
        self.cls = nn.Parameter(torch.randn(1, 1, dim))
        self.classification_head = nn.Linear(dim, num_classes)

    def forward(self, x):
        """Classify a batch of images.

        Args:
            x: Image batch; assumed to be (batch_size, channels, image_size,
                image_size) as required by the patch-embedding convolution.

        Returns:
            The classification head's output for the CLS position, one row of
            ``num_classes`` scores per image.
        """
        batch_size = x.shape[0]

        # Split the image into patches and embed them.
        embedded = self.patch_embedding(x)  # (batch_size, dim, num_patches_h, num_patches_w)
        tokens = embedded.flatten(2)  # (batch_size, dim, num_patches)
        tokens = tokens.transpose(1, 2)  # (batch_size, num_patches, dim)

        # Positional encoding is applied to the patch tokens only; the CLS
        # token is concatenated afterwards and does not pass through it.
        tokens = self.pos_encoder(tokens)

        cls_token = self.cls.expand(batch_size, -1, -1)  # (batch_size, 1, dim)
        tokens = torch.cat([cls_token, tokens], dim=1)  # (batch_size, num_patches+1, dim)

        # To compare against the alternative implementation, call
        # self.another_transformer_encoder(tokens) here instead.
        encoded = self.transformer_encoder(tokens)

        # Position 0 along the sequence axis is the CLS token.
        return self.classification_head(encoded[:, 0])
这是模型调用部分。
# Build the classifier for 28x28 single-channel images cut into 7x7 patches
# ((28 // 7) ** 2 = 16 patches per image) and move it to the target device.
model = ViT(
    image_size=28,
    patch_size=7,
    channels=1,
    num_classes=10,
    dim=64,
    depth=6,
    heads=8,
    mlp_dim=128,
).to(device)

# Cross-entropy loss over the 10 class scores.
criterion = nn.CrossEntropyLoss()

# NOTE(review): lr=0.1 with plain SGD may be too aggressive for the
# nn.TransformerEncoder variant -- consider a smaller learning rate or an
# adaptive optimizer and confirm experimentally.
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
我在 TransformerEncoderLayer 里加了
batch_first=True
,但是没用。