
Calculates the scaled dot-product attention weights in MultiHeadAttentionLayer.forward and returns them in place of the raw Q, K, V projections.

@@ -1,5 +1,6 @@
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 
 class MultiHeadAttentionLayer(nn.Module):
     def __init__(self, embed_size, num_heads):
@@ -19,16 +20,21 @@ class MultiHeadAttentionLayer(nn.Module):
         N = query.shape[0]  # batch_size
 
         # Linear transformations for Q, K, V
-        Q = self.q_linear(query)   # shape: (N, seq_len, embed_size)
-        K = self.k_linear(keys)    # shape: (N, seq_len, embed_size)
-        V = self.v_linear(values)  # shape: (N, seq_len, embed_size)
+        Q = self.q_linear(query)
+        K = self.k_linear(keys)
+        V = self.v_linear(values)
 
-        # Reshape Q, K, V into multiple heads
+        # Reshape into multiple heads
         Q = Q.reshape(N, -1, self.num_heads, self.head_dim)
         K = K.reshape(N, -1, self.num_heads, self.head_dim)
         V = V.reshape(N, -1, self.num_heads, self.head_dim)
 
-        return Q, K, V
+        # Compute scaled dot-product attention scores
+        attention_scores = torch.einsum("nqhd,nkhd->nhqk", [Q, K])
+        attention_scores = attention_scores / (self.head_dim ** 0.5)
+        attention = torch.softmax(attention_scores, dim=-1)  # Normalize
+
+        return attention
 
 
 if __name__ == "__main__":
@@ -36,10 +42,9 @@ if __name__ == "__main__":
     num_heads = 8
     mha_layer = MultiHeadAttentionLayer(embed_size, num_heads)
 
-    # Dummy data
     values = torch.randn(2, 10, embed_size)
     keys = torch.randn(2, 10, embed_size)
     query = torch.randn(2, 10, embed_size)
 
-    Q, K, V = mha_layer(values, keys, query)
-    print(f"Q shape: {Q.shape}, K shape: {K.shape}, V shape: {V.shape}")
+    attention = mha_layer(values, keys, query)
+    print(f"Attention shape: {attention.shape}")