我构建了一个 TransformerEncoder 模型,如果我在评估模式期间使用“with torch.no_grad()”,它会更改输出的序列长度。
我的型号详细信息:
class TransEnc(nn.Module):
def __init__(self,ntoken: int,encoder_embedding_dim: int,max_item_count: int,encoder_num_heads: int,encoder_hidden_dim: int,encoder_num_layers: int,padding_idx: int,dropout: float = 0.2):
super().__init__()
self.encoder_embedding = nn.Embedding(ntoken, encoder_embedding_dim, padding_idx=padding_idx)
self.pos_encoder = PositionalEncoding(encoder_embedding_dim, max_item_count, dropout)
encoder_layers = nn.TransformerEncoderLayer(encoder_embedding_dim, encoder_num_heads, encoder_hidden_dim, dropout, batch_first=True)
self.transformer_encoder = nn.TransformerEncoder(encoder_layers, encoder_num_layers)
self.encoder_embedding_dim = encoder_embedding_dim
def forward(self,src: torch.Tensor,src_key_padding_mask: torch.Tensor = None) -> torch.Tensor:
src = self.encoder_embedding(src.long()) * math.sqrt(self.encoder_embedding_dim)
src = self.pos_encoder(src)
src = self.transformer_encoder(src, src_key_padding_mask=src_key_padding_mask)
与
batch_size = 32
ntoken = 4096
encoder_embedding_dim = 256
max_item_count = 64 # max sequence length with padding
encoder_num_heads = 8
encoder_hidden_dim = 256
encoder_num_layers = 4
padding_idx = 0
我有一个张量 (src),包含 32 个单词级标记化句子(具有不同的填充),形状为 (32,64)(batch_size,max_item_count)。
当我使用“model.train()”激活训练模式时,设置“src_key_padding_mask = src == tokenizer.pad_token_id”并运行“logits = model(src = src, src_key_padding_mask = src_key_padding_mask)”,我得到具有预期形状的 logits (32,64,256)(batch_size,max_item_count,encoder_embedding_dim)。
但是,当我使用“model.eval()”激活评估模式时,设置“src_key_padding_mask = src == tokenizer.pad_token_id”并使用“torch.no_grad(): logits = model(src = src, src_key_padding_mask = src_key_padding_mask) 运行”,我每次都会得到不同的 logits 形状,例如 (32,31,256)、(32,25,256) 等。我想得到形状为 (32,64,256) 的 logits。我该如何解决这个问题?
操作系统:Windows 10 x64
Python:2012 年 10 月 3 日
Torch:1.13.1+cu117和2.0.1+cu117都试过了,问题还是一样。
降级成功了!如果你也遇到这个问题,不妨试试
pip uninstall torch torchvision torchaudio
pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu116