
Define the linear transformation layers for Q, K, and V

```diff
@@ -9,10 +9,15 @@ class MultiHeadAttentionLayer(nn.Module):
         self.head_dim = embed_size // num_heads
 
         assert (self.head_dim * num_heads == embed_size), "Embedding size needs to be divisible by num_heads"
+
+        # Define linear layers for Q, K, V
+        self.q_linear = nn.Linear(embed_size, embed_size)
+        self.k_linear = nn.Linear(embed_size, embed_size)
+        self.v_linear = nn.Linear(embed_size, embed_size)
 
 
 if __name__ == "__main__":
     embed_size = 512
     num_heads = 8
     mha_layer = MultiHeadAttentionLayer(embed_size, num_heads)
-    print("Model initialized successfully.")
+    print("Linear layers for Q, K, V initialized.")
```
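For context, here is a minimal runnable sketch of how these Q, K, V projections are typically consumed in a forward pass. The `forward()` method, the head-splitting logic, the `1/sqrt(head_dim)` scaling, and the `fc_out` output projection are not part of this commit; they are assumptions based on the standard multi-head attention formulation.

```python
import torch
import torch.nn as nn


class MultiHeadAttentionLayer(nn.Module):
    def __init__(self, embed_size, num_heads):
        super().__init__()
        self.embed_size = embed_size
        self.num_heads = num_heads
        self.head_dim = embed_size // num_heads

        assert (self.head_dim * num_heads == embed_size), "Embedding size needs to be divisible by num_heads"

        # Define linear layers for Q, K, V (as introduced in this commit)
        self.q_linear = nn.Linear(embed_size, embed_size)
        self.k_linear = nn.Linear(embed_size, embed_size)
        self.v_linear = nn.Linear(embed_size, embed_size)

        # Output projection -- an assumption, not present in the diff
        self.fc_out = nn.Linear(embed_size, embed_size)

    def forward(self, x):
        # x: (batch, seq_len, embed_size)
        batch, seq_len, _ = x.shape

        # Project the input, then split the embedding into num_heads sub-spaces:
        # (batch, seq_len, embed_size) -> (batch, num_heads, seq_len, head_dim)
        q = self.q_linear(x).view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_linear(x).view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_linear(x).view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Scaled dot-product attention per head
        scores = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        attn = torch.softmax(scores, dim=-1)
        out = attn @ v  # (batch, num_heads, seq_len, head_dim)

        # Recombine the heads and project back to embed_size
        out = out.transpose(1, 2).reshape(batch, seq_len, self.embed_size)
        return self.fc_out(out)


if __name__ == "__main__":
    x = torch.randn(2, 10, 512)
    mha_layer = MultiHeadAttentionLayer(embed_size=512, num_heads=8)
    print(mha_layer(x).shape)  # torch.Size([2, 10, 512])
```

Note the use of `.reshape` rather than `.view` when recombining the heads: the transposed tensor is not contiguous in memory, and `.reshape` handles that case while `.view` would raise an error.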