我是深度学习的新手。这几天我一直在阅读一些模型的源代码,以及像 pytorch-image-models 这样的开源仓库,却发现它们在实现方式上差别很大(例如 Transformer 中的自注意力(Self-Attention)操作)。下面是两个例子:
示例一:来自开源仓库
class Attention(nn.Module):
    """Multi-head self-attention over the second-to-last axis of the input.

    The last axis of the input is the embedding dimension ``C`` (it must equal
    ``dim``). The second-to-last axis is the sequence axis ``N`` that tokens
    attend over. Any leading axes (e.g. ``B`` or ``B, T``) are independent
    batch axes, so both ``(B, N, C)`` and ``(B, T, N, C)`` inputs are accepted.

    Args:
        dim: Embedding dimension ``C``; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        qkv_bias: Whether the fused q/k/v projection has a bias term.
        attn_drop: Dropout probability applied to the attention weights.
        proj_drop: Dropout probability applied after the output projection.

    Raises:
        ValueError: If ``dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        if dim % num_heads != 0:
            # Fail early instead of with an opaque reshape error in forward().
            raise ValueError(f"dim ({dim}) must be divisible by num_heads ({num_heads})")
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        # One fused projection producing q, k and v side by side (3 * dim channels).
        self.qkv = nn.Linear(dim, 3 * dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Apply multi-head self-attention along axis ``-2`` of ``x``.

        Args:
            x: Tensor of shape ``(*lead, N, C)``, e.g. ``(B, T, N, C)`` with
                B the batch size, T the number of image blocks, N the sequence
                length per image block and C the embedding dimension.

        Returns:
            Tensor with the same shape as ``x``.
        """
        lead = list(x.shape[:-2])
        N, C = x.shape[-2], x.shape[-1]
        head_dim = C // self.num_heads
        # (*lead, N, 3C) -> (*lead, N, 3, H, C') -> (3, *lead, N, H, C') -> (3, *lead, H, N, C')
        qkv = self.qkv(x).reshape(lead + [N, 3, self.num_heads, head_dim])
        qkv = qkv.movedim(-3, 0).transpose(-3, -2)
        q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)
        attn = (q @ k.transpose(-2, -1)) * self.scale  # (*lead, H, N, N)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        # (*lead, H, N, C') -> (*lead, N, H, C') -> (*lead, N, C).
        # Heads must remain the OUTER factor when merging, mirroring the qkv
        # reshape above; merging as (C', H) would scramble channels across heads.
        x = (attn @ v).transpose(-3, -2).reshape(lead + [N, C])
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
示例二:来自某篇论文公开发布的代码
class MultiAttention(nn.Module):
    """Global self-attention, optionally fused with per-segment local self-attention."""

    def __init__(self, input_size=1024, output_size=1024, freq=10000, pos_enc=None,
                 num_segments=None, heads=1, fusion=None, local_heads=4):
        """
        :param int input_size: The expected input feature size.
        :param int output_size: The hidden feature size of the attention mechanisms.
        :param int freq: The frequency of the sinusoidal positional encoding.
        :param None | str pos_enc: The selected positional encoding [absolute, relative].
        :param None | int num_segments: The selected number of segments to split the videos.
        :param int heads: The selected number of global heads.
        :param None | str fusion: The selected type of feature fusion [add, mult, avg, max]
            (case-insensitive).
        :param int local_heads: The number of heads of each local (per-segment) attention.
        """
        super(MultiAttention, self).__init__()
        # Global Attention, considering differences among all frames
        self.attention = SelfAttention(input_size=input_size, output_size=output_size,
                                       freq=freq, pos_enc=pos_enc, heads=heads)
        self.num_segments = num_segments
        if self.num_segments is not None:
            assert self.num_segments >= 2, "num_segments must be None or 2+"
            # Local Attention, considering differences among the same segment with
            # a reduced hidden size (output_size // num_segments); one module per segment.
            self.local_attention = nn.ModuleList(
                SelfAttention(input_size=input_size, output_size=output_size // num_segments,
                              freq=freq, pos_enc=pos_enc, heads=local_heads)
                for _ in range(self.num_segments))
        self.permitted_fusions = ["add", "mult", "avg", "max"]
        self.fusion = fusion
        if self.fusion is not None:
            self.fusion = self.fusion.lower()
            assert self.fusion in self.permitted_fusions, f"Fusion method must be: {*self.permitted_fusions,}"

    def _fuse(self, global_seg, local_seg):
        """Combine the normalized global and local features of one segment per ``self.fusion``."""
        if self.fusion == "add":
            return global_seg + local_seg
        if self.fusion == "mult":
            return global_seg * local_seg
        if self.fusion == "avg":
            return (global_seg + local_seg) / 2
        if self.fusion == "max":
            return torch.max(global_seg, local_seg)
        # Unknown fusion: keep only the normalized global features.
        return global_seg

    def forward(self, x):
        """ Compute the weighted frame features, based on the global and locals (multi-head) attention mechanisms.

        When both ``num_segments`` and ``fusion`` are set, each segment of the
        global output is L2-normalized and fused in place with the (normalized)
        output of that segment's local attention.

        :param torch.Tensor x: Tensor with shape [T, input_size] containing the frame features.
        :return: A tuple of:
            weighted_value: Tensor with shape [T, input_size] containing the weighted frame features.
            attn_weights: Tensor with shape [T, T] containing the attention weights.
        """
        weighted_value, attn_weights = self.attention(x)  # global attention
        if self.num_segments is None or self.fusion is None:
            return weighted_value, attn_weights

        segment_size = math.ceil(x.shape[0] / self.num_segments)
        for segment, local_attention in enumerate(self.local_attention):
            left_pos = segment * segment_size
            right_pos = left_pos + segment_size
            local_x = x[left_pos:right_pos]
            if local_x.shape[0] == 0:
                # Ceil-sized segments can run past the end of a short sequence
                # (e.g. T=5, num_segments=4 leaves the last segment empty).
                continue
            weighted_local_value, _ = local_attention(local_x)  # local attention
            # Normalize the feature vectors. clone() keeps the tensor autograd saves
            # for normalize separate from the slice overwritten below.
            global_seg = F.normalize(weighted_value[left_pos:right_pos].clone(), p=2, dim=1)
            local_seg = F.normalize(weighted_local_value, p=2, dim=1)
            weighted_value[left_pos:right_pos] = self._fuse(global_seg, local_seg)
        return weighted_value, attn_weights
我原本以为两者的效果应该差不多,而且第一段代码更简洁、优雅,可读性也更好。但实际情况并非如此,所以想请教:是什么导致了这种差异?请大家帮帮我!