使用 torch.no_grad() 在评估模式期间更改序列长度

问题描述 投票:0回答:1

我构建了一个 TransformerEncoder 模型,如果我在评估模式期间使用“with torch.no_grad()”,它会更改输出的序列长度。

我的型号详细信息:

class TransEnc(nn.Module):
    def __init__(self,ntoken: int,encoder_embedding_dim: int,max_item_count: int,encoder_num_heads: int,encoder_hidden_dim: int,encoder_num_layers: int,padding_idx: int,dropout: float = 0.2):
        super().__init__()
        self.encoder_embedding = nn.Embedding(ntoken, encoder_embedding_dim, padding_idx=padding_idx)
        self.pos_encoder = PositionalEncoding(encoder_embedding_dim, max_item_count, dropout)
        encoder_layers = nn.TransformerEncoderLayer(encoder_embedding_dim, encoder_num_heads, encoder_hidden_dim, dropout, batch_first=True) 
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, encoder_num_layers)
        self.encoder_embedding_dim = encoder_embedding_dim

def forward(self,src: torch.Tensor,src_key_padding_mask: torch.Tensor = None) -> torch.Tensor:
        src = self.encoder_embedding(src.long()) * math.sqrt(self.encoder_embedding_dim)
        src = self.pos_encoder(src)
        src = self.transformer_encoder(src, src_key_padding_mask=src_key_padding_mask)

batch_size = 32
ntoken = 4096
encoder_embedding_dim = 256
max_item_count = 64 # max sequence length with padding
encoder_num_heads = 8
encoder_hidden_dim = 256
encoder_num_layers = 4
padding_idx = 0

我有一个张量 (src),包含 32 个单词级标记化句子(具有不同的填充),形状为 (32,64)(batch_size,max_item_count)。

当我使用“model.train()”激活训练模式时,设置“src_key_padding_mask = src == tokenizer.pad_token_id”并运行“logits = model(src = src, src_key_padding_mask = src_key_padding_mask)”,我得到具有预期形状的 logits (32,64,256)(batch_size,max_item_count,encoder_embedding_dim)。

但是,当我使用“model.eval()”激活评估模式时,设置“src_key_padding_mask = src == tokenizer.pad_token_id”并使用“torch.no_grad(): logits = model(src = src, src_key_padding_mask = src_key_padding_mask) 运行”,我每次都会得到不同的 logits 形状,例如 (32,31,256)、(32,25,256) 等。我想得到形状为 (32,64,256) 的 logits。我该如何解决这个问题?

操作系统:Windows 10 x64

Python:2012 年 10 月 3 日

Torch:1.13.1+cu117和2.0.1+cu117都试过了,问题还是一样。

python deep-learning pytorch transformer-model encoder-decoder
1个回答
0
投票

降级成功了!如果你也遇到这个问题,不妨试试

pip uninstall torch torchvision torchaudio
pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu116
© www.soinside.com 2019 - 2024. All rights reserved.