Why doesn't my model improve when I use nn.TransformerEncoder?


I'm working on a ViT image-classification task.

Here is my code. With a Transformer implementation taken from elsewhere, the model trains fine; but when I use nn.TransformerEncoder, the loss does not decrease as expected.

self.transformer_encoder: problematic

self.another_transformer_encoder: works well

import torch
import torch.nn as nn
from torch.nn import TransformerEncoder, TransformerEncoderLayer

class ViT(nn.Module):
    def __init__(self, image_size, patch_size, channels, num_classes, dim, depth, heads, mlp_dim):
        super(ViT, self).__init__()

        # Compute number of patches
        self.num_patches = (image_size // patch_size) ** 2

        # Patch embedding layer
        self.patch_embedding = nn.Conv2d(channels, dim, kernel_size=patch_size, stride=patch_size, bias=False)

        # nn.TransformerEncoderLayer and nn.TransformerEncoder
        encoder_layers = TransformerEncoderLayer(
            d_model=dim,
            nhead=heads,
            dim_feedforward=mlp_dim,
            dropout=0.5,
            batch_first=True
        )
        self.transformer_encoder = TransformerEncoder(
            encoder_layer=encoder_layers,
            num_layers=depth
        )
        # Alternative encoder implementation taken from elsewhere (defined outside this snippet)
        self.another_transformer_encoder = Transformer(dim, depth, heads, mlp_dim)

        # Position embedding layer
        self.pos_encoder = PositionalEncoding(dim)

        # Classification head
        self.cls = nn.Parameter(torch.randn(1, 1, dim))
        self.classification_head = nn.Linear(dim, num_classes)

    def forward(self, x):
        # Split input image into patches
        patches = self.patch_embedding(x)    # (batch_size, dim, num_patches_h, num_patches_w)
        patches = patches.flatten(2)         # (batch_size, dim, num_patches)
        patches = patches.transpose(1, 2)    # (batch_size, num_patches, dim)

        # Add position embedding to patches
        patches = self.pos_encoder(patches)

        # Add CLS token
        cls_token = self.cls.expand(x.shape[0], -1, -1)    # (batch_size, 1, dim)
        patches = torch.cat([cls_token, patches], dim=1)   # (batch_size, num_patches+1, dim)

        # Pass patches through transformer encoder layers
        patches = self.transformer_encoder(patches)  # bad: loss does not decrease
        # patches = self.another_transformer_encoder(patches)  # good: trains normally

        # Extract CLS token and pass through classification head
        cls_token = patches[:, 0]
        output = self.classification_head(cls_token)
        return output
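
One difference I can only guess at (the source of another_transformer_encoder isn't shown above): many ViT implementations use pre-LayerNorm and dropout around 0.1, whereas nn.TransformerEncoderLayer defaults to post-norm and my code sets dropout=0.5. Below is a sketch of a configuration closer to typical ViT setups; this is an assumption, not a confirmed fix.

# Hedged sketch: pre-norm, lower dropout (assumptions; may or may not match
# what another_transformer_encoder does internally)
encoder_layers = nn.TransformerEncoderLayer(
    d_model=dim,
    nhead=heads,
    dim_feedforward=mlp_dim,
    dropout=0.1,        # PyTorch default; 0.5 is unusually high for a ViT
    batch_first=True,
    norm_first=True,    # pre-norm, common in ViT implementations
)
self.transformer_encoder = nn.TransformerEncoder(encoder_layers, num_layers=depth)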

Here is where the model is instantiated:

model = ViT(image_size=28,
            patch_size=7,
            channels=1,
            num_classes=10,
            dim=64,
            depth=6,
            heads=8,
            mlp_dim=128
            ).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
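
For what it's worth, Transformer-based models are usually trained with Adam or AdamW at a much smaller learning rate; SGD with lr=0.1 can be unstable for post-norm Transformers. A sketch of a common alternative (an assumption, not a confirmed fix):

# Common optimizer choice for Transformers (sketch; lr=3e-4 is a typical
# starting point, not tuned for this model)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)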

I added batch_first=True to the TransformerEncoderLayer, but it didn't help.
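
To rule out a silent shape mix-up with batch_first=True, a quick sanity check (minimal sketch, assuming 1×28×28 inputs as in the model above):

# Shape sanity check (sketch; assumes the ViT and device defined above)
x = torch.randn(2, 1, 28, 28)    # (batch, channels, height, width)
out = model(x.to(device))
print(out.shape)                 # expected: torch.Size([2, 10])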

python deep-learning pytorch transformer-model