Showing 1 changed file with 13 additions and 8 deletions
@@ -1,5 +1,6 @@
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 
 class MultiHeadAttentionLayer(nn.Module):
     def __init__(self, embed_size, num_heads):
@@ -19,16 +20,21 @@ class MultiHeadAttentionLayer(nn.Module):
         N = query.shape[0]  # batch_size
 
         # Linear transformations for Q, K, V
-        Q = self.q_linear(query)   # shape: (N, seq_len, embed_size)
-        K = self.k_linear(keys)    # shape: (N, seq_len, embed_size)
-        V = self.v_linear(values)  # shape: (N, seq_len, embed_size)
+        Q = self.q_linear(query)
+        K = self.k_linear(keys)
+        V = self.v_linear(values)
 
-        # Reshape Q, K, V into multiple heads
+        # Reshape into multiple heads
         Q = Q.reshape(N, -1, self.num_heads, self.head_dim)
         K = K.reshape(N, -1, self.num_heads, self.head_dim)
         V = V.reshape(N, -1, self.num_heads, self.head_dim)
 
-        return Q, K, V
+        # Compute scaled dot-product attention scores
+        attention_scores = torch.einsum("nqhd,nkhd->nhqk", [Q, K])
+        attention_scores = attention_scores / (self.head_dim ** 0.5)
+        attention = torch.softmax(attention_scores, dim=-1)  # Normalize
+
+        return attention
 
 
 if __name__ == "__main__":
@@ -36,10 +42,9 @@ if __name__ == "__main__":
     num_heads = 8
     mha_layer = MultiHeadAttentionLayer(embed_size, num_heads)
 
-    # Dummy data
     values = torch.randn(2, 10, embed_size)
     keys = torch.randn(2, 10, embed_size)
     query = torch.randn(2, 10, embed_size)
 
-    Q, K, V = mha_layer(values, keys, query)
-    print(f"Q shape: {Q.shape}, K shape: {K.shape}, V shape: {V.shape}")
+    attention = mha_layer(values, keys, query)
+    print(f"Attention shape: {attention.shape}")
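After this change, forward() stops at the softmax and returns the attention weights themselves; the weighted sum over V and the merge back to embed_size are not part of the diff. (The newly added torch.nn.functional import is not used yet; the code calls torch.softmax directly.) A minimal sketch of the remaining step, using the (N, seq_len, num_heads, head_dim) layout from the code above, could look like the following; the stand-in tensors, the head_dim of 32, and the einsum subscripts are illustrative assumptions, not code from this PR.

import torch

# Shapes from the example in the diff: batch 2, sequence length 10,
# 8 heads, and an assumed head_dim of 32 (so embed_size = 256).
N, seq_len, num_heads, head_dim = 2, 10, 8, 32
embed_size = num_heads * head_dim

# Stand-ins for the tensors forward() would produce internally:
# attention has shape (N, num_heads, query_len, key_len), and V has
# the per-head layout (N, key_len, num_heads, head_dim).
attention = torch.softmax(torch.randn(N, num_heads, seq_len, seq_len), dim=-1)
V = torch.randn(N, seq_len, num_heads, head_dim)

# Weighted sum over the key dimension l, giving (N, query_len, num_heads, head_dim)
out = torch.einsum("nhql,nlhd->nqhd", attention, V)

# Concatenate the heads back into a single embed_size-wide representation
out = out.reshape(N, seq_len, embed_size)
print(out.shape)  # torch.Size([2, 10, 256])

A complete layer would typically pass out through one final linear projection; since the diff shows no such attribute on the module, it is omitted from this sketch.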