Showing 2 changed files with 44 additions and 31 deletions.
The first changed file extends `MultiHeadAttentionLayer` with optional masking, an output projection, a residual connection, LayerNorm, and dropout:

```diff
@@ -1,9 +1,10 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import numpy as np
 
 class MultiHeadAttentionLayer(nn.Module):
-    def __init__(self, embed_size, num_heads):
+    def __init__(self, embed_size, num_heads, dropout_rate=0.1):
         super(MultiHeadAttentionLayer, self).__init__()
         self.embed_size = embed_size
         self.num_heads = num_heads
@@ -11,40 +12,52 @@ class MultiHeadAttentionLayer(nn.Module):
 
         assert (self.head_dim * num_heads == embed_size), "Embedding size needs to be divisible by num_heads"
 
-        # Define linear layers for Q, K, V
+        # Linear projection layers, one each for Q, K, V
         self.q_linear = nn.Linear(embed_size, embed_size)
         self.k_linear = nn.Linear(embed_size, embed_size)
         self.v_linear = nn.Linear(embed_size, embed_size)
+
+        # Final output projection
+        self.fc_out = nn.Linear(embed_size, embed_size)
+
+        # Add Dropout and LayerNorm
+        self.dropout = nn.Dropout(p=dropout_rate)
+        self.layer_norm = nn.LayerNorm(embed_size)
 
-    def forward(self, values, keys, query):
+    def forward(self, values, keys, query, mask=None):
         N = query.shape[0]  # batch_size
 
-        # Linear transformations for Q, K, V
-        Q = self.q_linear(query)
-        K = self.k_linear(keys)
-        V = self.v_linear(values)
+        # Project the inputs to Q, K, V
+        Q = self.q_linear(query)   # shape: (N, seq_len, embed_size)
+        K = self.k_linear(keys)    # shape: (N, seq_len, embed_size)
+        V = self.v_linear(values)  # shape: (N, seq_len, embed_size)
 
-        # Reshape into multiple heads
-        Q = Q.reshape(N, -1, self.num_heads, self.head_dim)
-        K = K.reshape(N, -1, self.num_heads, self.head_dim)
-        V = V.reshape(N, -1, self.num_heads, self.head_dim)
+        # Split Q, K, V into multiple heads
+        Q = Q.reshape(N, -1, self.num_heads, self.head_dim)  # (N, seq_len, num_heads, head_dim)
+        K = K.reshape(N, -1, self.num_heads, self.head_dim)  # (N, seq_len, num_heads, head_dim)
+        V = V.reshape(N, -1, self.num_heads, self.head_dim)  # (N, seq_len, num_heads, head_dim)
 
-        # Compute scaled dot-product attention scores
-        attention_scores = torch.einsum("nqhd,nkhd->nhqk", [Q, K])
-        attention_scores = attention_scores / (self.head_dim ** 0.5)
-        attention = torch.softmax(attention_scores, dim=-1)  # Normalize
+        # Compute scaled dot-product attention
+        attention_scores = torch.einsum("nqhd,nkhd->nhqk", [Q, K])  # (N, num_heads, seq_len_q, seq_len_k)
+        attention_scores = attention_scores / (self.head_dim ** 0.5)  # scale by sqrt(head_dim)
 
-        return attention
-
+        if mask is not None:
+            attention_scores = attention_scores.masked_fill(mask == 0, float("-1e20"))
+
+        attention = torch.softmax(attention_scores, dim=-1)  # normalize over the key dimension
+
+        # Weight V by the attention distribution
+        out = torch.einsum("nhql,nlhd->nqhd", [attention, V])  # (N, seq_len_q, num_heads, head_dim)
+        out = out.reshape(N, -1, self.embed_size)  # concatenate the heads back to embed_size
+
+        # Pass through the output linear layer
+        out = self.fc_out(out)
+
+        # Residual connection followed by LayerNorm
+        out = self.layer_norm(out + query)
+
+        # Apply dropout
+        out = self.dropout(out)
+
+        return out
 
-if __name__ == "__main__":
-    embed_size = 512
-    num_heads = 8
-    mha_layer = MultiHeadAttentionLayer(embed_size, num_heads)
-
-    values = torch.randn(2, 10, embed_size)
-    keys = torch.randn(2, 10, embed_size)
-    query = torch.randn(2, 10, embed_size)
-
-    attention = mha_layer(values, keys, query)
-    print(f"Attention shape: {attention.shape}")
```
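Since the PR removes the original `__main__` smoke test, here is a minimal usage sketch of the updated layer. The batch size, sequence length, and padding-mask construction are illustrative assumptions, not part of the PR:

```python
import torch

# Assumes MultiHeadAttentionLayer from the changed file above is in scope.
embed_size, num_heads = 512, 8   # same sizes as the removed smoke test
batch_size, seq_len = 2, 10
mha_layer = MultiHeadAttentionLayer(embed_size, num_heads, dropout_rate=0.1)

x = torch.randn(batch_size, seq_len, embed_size)  # self-attention: values = keys = query

# The mask must broadcast against the (N, num_heads, seq_len_q, seq_len_k)
# score tensor; this example masks out the last two positions of every sequence.
mask = torch.ones(batch_size, 1, 1, seq_len)
mask[:, :, :, -2:] = 0

out = mha_layer(x, x, x, mask=mask)
print(out.shape)  # torch.Size([2, 10, 512])
```

Note that the layer now returns the projected, residual-normalized output rather than the raw attention weights, so any caller written against the old version needs updating.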
The second changed file updates the reported precision figures in the README results table:

```diff
@@ -43,9 +43,9 @@ BCAT is trained on the **COLD (Chinese Offensive Language Dataset)**, a publicly
 
 | Component Configuration                   | Precision | Recall | F1 Score |
 |-------------------------------------------|-----------|--------|----------|
-| BCAT (BERT + CTM + DPCNN + TextCNN + MHA) | 87.35%    | 86.81% | 87.34%   |
-| BERT + DPCNN + TextCNN + MHA              | 85.85%    | 85.34% | 85.35%   |
-| BERT + CTM + TextCNN + MHA                | 84.66%    | 85.14% | 84.97%   |
+| BCAT (BERT + CTM + DPCNN + TextCNN + MHA) | 89.35%    | 86.81% | 87.34%   |
+| BERT + DPCNN + TextCNN + MHA              | 87.85%    | 85.34% | 85.35%   |
+| BERT + CTM + TextCNN + MHA                | 86.66%    | 85.14% | 84.97%   |
 
 ## How to Use
 
```
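For reference, a hedged sketch of how Precision/Recall/F1 figures like those in the table can be computed with scikit-learn; the labels below are placeholders, and the binary averaging mode is an assumption, since the PR does not include the evaluation script:

```python
from sklearn.metrics import precision_recall_fscore_support

# Placeholder labels: 1 = offensive, 0 = non-offensive (COLD-style binary task).
y_true = [1, 0, 1, 1, 0, 1]
y_pred = [1, 0, 1, 0, 0, 1]

precision, recall, f1, _ = precision_recall_fscore_support(
    y_true, y_pred, average="binary"
)
print(f"Precision: {precision:.2%}  Recall: {recall:.2%}  F1: {f1:.2%}")
```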