Showing 1 changed file with 23 additions and 1 deletion
@@ -15,9 +15,31 @@ class MultiHeadAttentionLayer(nn.Module):
         self.k_linear = nn.Linear(embed_size, embed_size)
         self.v_linear = nn.Linear(embed_size, embed_size)
 
+    def forward(self, values, keys, query):
+        N = query.shape[0]  # batch_size
+
+        # Linear transformations for Q, K, V
+        Q = self.q_linear(query)   # shape: (N, seq_len, embed_size)
+        K = self.k_linear(keys)    # shape: (N, seq_len, embed_size)
+        V = self.v_linear(values)  # shape: (N, seq_len, embed_size)
+
+        # Reshape Q, K, V into multiple heads
+        Q = Q.reshape(N, -1, self.num_heads, self.head_dim)
+        K = K.reshape(N, -1, self.num_heads, self.head_dim)
+        V = V.reshape(N, -1, self.num_heads, self.head_dim)
+
+        return Q, K, V
+
 
 if __name__ == "__main__":
     embed_size = 512
     num_heads = 8
     mha_layer = MultiHeadAttentionLayer(embed_size, num_heads)
-    print("Linear layers for Q, K, V initialized.")
+
+    # Dummy data
+    values = torch.randn(2, 10, embed_size)
+    keys = torch.randn(2, 10, embed_size)
+    query = torch.randn(2, 10, embed_size)
+
+    Q, K, V = mha_layer(values, keys, query)
+    print(f"Q shape: {Q.shape}, K shape: {K.shape}, V shape: {V.shape}")