Li Bai Who Quit Drinking

Divide the input into multiple heads

@@ -15,9 +15,31 @@ class MultiHeadAttentionLayer(nn.Module):
         self.k_linear = nn.Linear(embed_size, embed_size)
         self.v_linear = nn.Linear(embed_size, embed_size)
 
+    def forward(self, values, keys, query):
+        N = query.shape[0]  # batch_size
+
+        # Linear transformations for Q, K, V
+        Q = self.q_linear(query)   # shape: (N, seq_len, embed_size)
+        K = self.k_linear(keys)    # shape: (N, seq_len, embed_size)
+        V = self.v_linear(values)  # shape: (N, seq_len, embed_size)
+
+        # Reshape Q, K, V into multiple heads
+        Q = Q.reshape(N, -1, self.num_heads, self.head_dim)
+        K = K.reshape(N, -1, self.num_heads, self.head_dim)
+        V = V.reshape(N, -1, self.num_heads, self.head_dim)
+
+        return Q, K, V
+
 
 if __name__ == "__main__":
     embed_size = 512
     num_heads = 8
     mha_layer = MultiHeadAttentionLayer(embed_size, num_heads)
-    print("Linear layers for Q, K, V initialized.")
+
+    # Dummy data
+    values = torch.randn(2, 10, embed_size)
+    keys = torch.randn(2, 10, embed_size)
+    query = torch.randn(2, 10, embed_size)
+
+    Q, K, V = mha_layer(values, keys, query)
+    print(f"Q shape: {Q.shape}, K shape: {K.shape}, V shape: {V.shape}")
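For context, here is a minimal, self-contained sketch of what the module might look like with this change applied. The hunk starts at line 15, so the __init__ body shown here (the q_linear projection, self.num_heads, and self.head_dim = embed_size // num_heads) is an assumption inferred from the attributes the new forward uses, not part of the diff itself.

import torch
import torch.nn as nn


class MultiHeadAttentionLayer(nn.Module):
    def __init__(self, embed_size, num_heads):
        super().__init__()
        # Assumed setup (not shown in the hunk): split the embedding into equal heads
        assert embed_size % num_heads == 0, "embed_size must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = embed_size // num_heads
        self.q_linear = nn.Linear(embed_size, embed_size)
        self.k_linear = nn.Linear(embed_size, embed_size)
        self.v_linear = nn.Linear(embed_size, embed_size)

    def forward(self, values, keys, query):
        N = query.shape[0]  # batch_size

        # Linear projections, each keeping shape (N, seq_len, embed_size)
        Q = self.q_linear(query)
        K = self.k_linear(keys)
        V = self.v_linear(values)

        # Split the embedding dimension into num_heads heads of head_dim each
        Q = Q.reshape(N, -1, self.num_heads, self.head_dim)
        K = K.reshape(N, -1, self.num_heads, self.head_dim)
        V = V.reshape(N, -1, self.num_heads, self.head_dim)

        return Q, K, V


if __name__ == "__main__":
    embed_size = 512
    num_heads = 8
    mha_layer = MultiHeadAttentionLayer(embed_size, num_heads)

    # Dummy data: batch of 2, sequence length 10
    values = torch.randn(2, 10, embed_size)
    keys = torch.randn(2, 10, embed_size)
    query = torch.randn(2, 10, embed_size)

    Q, K, V = mha_layer(values, keys, query)
    print(f"Q shape: {Q.shape}, K shape: {K.shape}, V shape: {V.shape}")

With embed_size = 512 and num_heads = 8, each printed shape should come out as (2, 10, 8, 64): the 512-dimensional embedding split into 8 heads of 64 dimensions each.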