
The multi-head attention mechanism is now essentially complete. This change adds the output projection, optional masking, dropout, and a residual connection with LayerNorm to the attention layer, and updates the results table in the README.
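For context, the layer computes standard scaled dot-product attention independently in each head:

$$
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d_k}}\right)V
$$

where $d_k$ is the per-head dimension (`head_dim` in the code). The head outputs are concatenated and passed through the final linear layer `fc_out` before the residual connection and LayerNorm.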

```diff
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import numpy as np

 class MultiHeadAttentionLayer(nn.Module):
-    def __init__(self, embed_size, num_heads):
+    def __init__(self, embed_size, num_heads, dropout_rate=0.1):
         super(MultiHeadAttentionLayer, self).__init__()
         self.embed_size = embed_size
         self.num_heads = num_heads
@@ -11,40 +12,52 @@ class MultiHeadAttentionLayer(nn.Module):

         assert (self.head_dim * num_heads == embed_size), "Embedding size needs to be divisible by num_heads"

-        # Define linear layers for Q, K, V
+        # Linear projection layers for Q, K, V
         self.q_linear = nn.Linear(embed_size, embed_size)
         self.k_linear = nn.Linear(embed_size, embed_size)
         self.v_linear = nn.Linear(embed_size, embed_size)
+
+        # Final output projection
+        self.fc_out = nn.Linear(embed_size, embed_size)
+
+        # Add Dropout and LayerNorm
+        self.dropout = nn.Dropout(p=dropout_rate)
+        self.layer_norm = nn.LayerNorm(embed_size)

-    def forward(self, values, keys, query):
+    def forward(self, values, keys, query, mask=None):
         N = query.shape[0]  # batch_size

-        # Linear transformations for Q, K, V
-        Q = self.q_linear(query)
-        K = self.k_linear(keys)
-        V = self.v_linear(values)
+        # Project the inputs to Q, K, V
+        Q = self.q_linear(query)   # shape: (N, seq_len, embed_size)
+        K = self.k_linear(keys)    # shape: (N, seq_len, embed_size)
+        V = self.v_linear(values)  # shape: (N, seq_len, embed_size)

-        # Reshape into multiple heads
-        Q = Q.reshape(N, -1, self.num_heads, self.head_dim)
-        K = K.reshape(N, -1, self.num_heads, self.head_dim)
-        V = V.reshape(N, -1, self.num_heads, self.head_dim)
+        # Split Q, K, V into multiple heads
+        Q = Q.reshape(N, -1, self.num_heads, self.head_dim)  # shape: (N, seq_len, num_heads, head_dim)
+        K = K.reshape(N, -1, self.num_heads, self.head_dim)  # shape: (N, seq_len, num_heads, head_dim)
+        V = V.reshape(N, -1, self.num_heads, self.head_dim)  # shape: (N, seq_len, num_heads, head_dim)

-        # Compute scaled dot-product attention scores
-        attention_scores = torch.einsum("nqhd,nkhd->nhqk", [Q, K])
-        attention_scores = attention_scores / (self.head_dim ** 0.5)
-        attention = torch.softmax(attention_scores, dim=-1)  # Normalize
+        # Compute scaled dot-product attention
+        attention_scores = torch.einsum("nqhd,nkhd->nhqk", [Q, K])  # (N, num_heads, seq_len_q, seq_len_k)
+        attention_scores = attention_scores / (self.head_dim ** 0.5)  # scale

-        return attention
-
+        if mask is not None:
+            attention_scores = attention_scores.masked_fill(mask == 0, float("-1e20"))
+
+        attention = torch.softmax(attention_scores, dim=-1)  # normalize over the key dimension
+
+        # Weight V by the attention distribution
+        out = torch.einsum("nhql,nlhd->nqhd", [attention, V])  # (N, seq_len_q, num_heads, head_dim)
+        out = out.reshape(N, -1, self.embed_size)  # concatenate the heads back to embed_size
+
+        # Final linear projection
+        out = self.fc_out(out)
+
+        # Residual connection followed by LayerNorm
+        out = self.layer_norm(out + query)
+
+        # Apply Dropout
+        out = self.dropout(out)
+
+        return out

-if __name__ == "__main__":
-    embed_size = 512
-    num_heads = 8
-    mha_layer = MultiHeadAttentionLayer(embed_size, num_heads)
-
-    values = torch.randn(2, 10, embed_size)
-    keys = torch.randn(2, 10, embed_size)
-    query = torch.randn(2, 10, embed_size)
-
-    attention = mha_layer(values, keys, query)
-    print(f"Attention shape: {attention.shape}")
```
```diff
@@ -43,9 +43,9 @@ BCAT is trained on the **COLD (Chinese Offensive Language Dataset)**, a publicly

 | Component Configuration                   | Precision | Recall | F1 Score |
 |-------------------------------------------|-----------|--------|----------|
-| BCAT (BERT + CTM + DPCNN + TextCNN + MHA) | 87.35%    | 86.81% | 87.34%   |
-| BERT + DPCNN + TextCNN + MHA              | 85.85%    | 85.34% | 85.35%   |
-| BERT + CTM + TextCNN + MHA                | 84.66%    | 85.14% | 84.97%   |
+| BCAT (BERT + CTM + DPCNN + TextCNN + MHA) | 89.35%    | 86.81% | 87.34%   |
+| BERT + DPCNN + TextCNN + MHA              | 87.85%    | 85.34% | 85.35%   |
+| BERT + CTM + TextCNN + MHA                | 86.66%    | 85.14% | 84.97%   |

 ## How to Use

```