Showing 1 changed file with 6 additions and 1 deletion
| @@ -9,10 +9,15 @@ class MultiHeadAttentionLayer(nn.Module): | @@ -9,10 +9,15 @@ class MultiHeadAttentionLayer(nn.Module): | ||
| 9 | self.head_dim = embed_size // num_heads | 9 | self.head_dim = embed_size // num_heads |
| 10 | 10 | ||
| 11 | assert (self.head_dim * num_heads == embed_size), "Embedding size needs to be divisible by num_heads" | 11 | assert (self.head_dim * num_heads == embed_size), "Embedding size needs to be divisible by num_heads" |
| 12 | + | ||
| 13 | + # Define linear layers for Q, K, V | ||
| 14 | + self.q_linear = nn.Linear(embed_size, embed_size) | ||
| 15 | + self.k_linear = nn.Linear(embed_size, embed_size) | ||
| 16 | + self.v_linear = nn.Linear(embed_size, embed_size) | ||
| 12 | 17 | ||
| 13 | 18 | ||
| 14 | if __name__ == "__main__": | 19 | if __name__ == "__main__": |
| 15 | embed_size = 512 | 20 | embed_size = 512 |
| 16 | num_heads = 8 | 21 | num_heads = 8 |
| 17 | mha_layer = MultiHeadAttentionLayer(embed_size, num_heads) | 22 | mha_layer = MultiHeadAttentionLayer(embed_size, num_heads) |
| 18 | - print("Model initialized successfully.") | 23 | + print("Linear layers for Q, K, V initialized.") |
-
Please register or log in to post a comment