使用 torch.no_grad() 在评估模式期间更改序列长度

问题描述 投票:0回答:1

我构建了一个 TransformerEncoder 模型,它在评估模式期间更改了输出的序列长度。

我的型号详细信息:

class TransEnc(nn.Module):
    def __init__(self,ntoken: int,encoder_embedding_dim: int,max_item_count: int,encoder_num_heads: int,encoder_hidden_dim: int,encoder_num_layers: int,padding_idx: int,dropout: float = 0.2):
        super().__init__()
        self.encoder_embedding = nn.Embedding(ntoken, encoder_embedding_dim, padding_idx=padding_idx)
        self.pos_encoder = PositionalEncoding(encoder_embedding_dim, max_item_count, dropout)
        encoder_layers = nn.TransformerEncoderLayer(encoder_embedding_dim, encoder_num_heads, encoder_hidden_dim, dropout, batch_first=True) 
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, encoder_num_layers)
        self.encoder_embedding_dim = encoder_embedding_dim

def forward(self,src: torch.Tensor,src_key_padding_mask: torch.Tensor = None) -> torch.Tensor:
        src = self.encoder_embedding(src.long()) * math.sqrt(self.encoder_embedding_dim)
        src = self.pos_encoder(src)
        src = self.transformer_encoder(src, src_key_padding_mask=src_key_padding_mask)

batch_size = 32
ntoken = 4096
encoder_embedding_dim = 256
max_item_count = 64 # max sequence length with padding
encoder_num_heads = 8
encoder_hidden_dim = 256
encoder_num_layers = 4
padding_idx = tokenizer.pad_token_id

我使用 Transformers 库中的 BertTokenizer 作为 tokenizer,并且 tokenizer.pad_token_id = 0。我有一个张量 (src),其中包含 32 个单词级标记化句子(具有不同的填充),形状为 (32,64)(batch_size,max_item_count) .

当我使用 model.train() 激活训练模式,设置 src_key_padding_mask = src == tokenizer.pad_token_id 并运行 logits = model(src = src, src_key_padding_mask = src_key_padding_mask) 时,我得到预期形状为 (32,64,256) 的 logits (batch_size、max_item_count、encoder_embedding_dim)。

但是,当我使用 model.eval() 激活评估模式时,设置 src_key_padding_mask = src == tokenizer.pad_token_id 并使用 torch.no_grad() 运行: logits = model(src = src, src_key_padding_mask = src_key_padding_mask) ,我得到不同的 logits ' 每次形状都像 (32,31,256)、(32,25,256) 等。我想获得形状为 (32,64,256) 的 logits。我该如何解决这个问题?

操作系统:Windows 10 x64 蟒蛇:3.10.12 Torch:1.13.1+cu117和2.0.1+cu117都试过了,问题还是一样

python deep-learning pytorch transformer-model encoder-decoder
1个回答
0
投票

降级成功了!如果你也遇到这个问题,不妨试试

pip uninstall torch torchvision torchaudio
pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu116
© www.soinside.com 2019 - 2024. All rights reserved.