戒酒的李白

BCAT is basically complete.

@@ -8,49 +8,49 @@ from MHA import MultiHeadAttentionLayer @@ -8,49 +8,49 @@ from MHA import MultiHeadAttentionLayer
8 from classifier import FinalClassifier 8 from classifier import FinalClassifier
9 from BERT_CTM import BERT_CTM_Model 9 from BERT_CTM import BERT_CTM_Model
10 import os 10 import os
11 -from tqdm import tqdm 11 +from tqdm import tqdm # 导入 tqdm 库用于进度条
12 from sklearn.metrics import confusion_matrix 12 from sklearn.metrics import confusion_matrix
13 13
14 14
# BERT+CTM embedding generation / loading helper.
def get_bert_ctm_embeddings(texts, bert_model_path, ctm_tokenizer_path, n_components=12, num_epochs=20, save_path=None):
    """Return BERT+CTM embeddings for ``texts``, with optional disk caching.

    If ``save_path`` points to an existing ``.npy`` file, the embeddings are
    loaded from it and the expensive model training is skipped entirely;
    otherwise a ``BERT_CTM_Model`` is trained to produce them and, when
    ``save_path`` is set, the result is cached there for future runs.

    Args:
        texts: Iterable of input documents to embed.
        bert_model_path: Path to the pretrained BERT model directory.
        ctm_tokenizer_path: Path to the CTM (sentence-BERT) tokenizer.
        n_components: Number of CTM topic components.
        num_epochs: Number of CTM training epochs.
        save_path: Optional path of the ``.npy`` cache file.

    Returns:
        numpy.ndarray of embeddings (shape determined by the model — TODO
        confirm against ``BERT_CTM_Model.train``).
    """
    if save_path and os.path.exists(save_path):
        # Cache hit: reuse previously generated embeddings.
        print(f"从文件 {save_path} 加载嵌入...")
        embeddings = np.load(save_path)
    else:
        print("生成 BERT+CTM 嵌入...")
        bert_ctm_model = BERT_CTM_Model(
            bert_model_path=bert_model_path,
            ctm_tokenizer_path=ctm_tokenizer_path,
            n_components=n_components,
            num_epochs=num_epochs
        )
        embeddings = bert_ctm_model.train(texts)  # generate embeddings

        # Cache the freshly generated embeddings so later runs can load them.
        if save_path:
            print(f"保存嵌入到文件 {save_path}...")
            np.save(save_path, embeddings)

    return embeddings
37 37
38 38
# Data loading and preparation helper.
def prepare_dataloader(features, labels, batch_size):
    """Build a shuffling ``DataLoader`` over (features, labels) pairs.

    Used for the training, validation, and test splits alike.

    Args:
        features: Array-like of float feature vectors.
        labels: Array-like of integer class labels.
        batch_size: Number of samples per batch.

    Returns:
        torch.utils.data.DataLoader yielding ``(float32 x, int64 y)`` batches
        in shuffled order.
    """
    tensor_x = torch.tensor(features, dtype=torch.float32)
    tensor_y = torch.tensor(labels, dtype=torch.long)
    dataset = TensorDataset(tensor_x, tensor_y)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)
46 46
47 47
48 -# Model training function 48 +# 训练模型函数
49 def train_model(train_data_path, valid_data_path, test_data_path, train_labels, valid_labels, test_labels, 49 def train_model(train_data_path, valid_data_path, test_data_path, train_labels, valid_labels, test_labels,
50 bert_model_path, ctm_tokenizer_path, num_heads=8, num_classes=2, epochs=10, batch_size=128, 50 bert_model_path, ctm_tokenizer_path, num_heads=8, num_classes=2, epochs=10, batch_size=128,
51 learning_rate=5e-3, model_save_path='./final_model.pt'): 51 learning_rate=5e-3, model_save_path='./final_model.pt'):
52 - # Step 1: Get BERT+CTM embeddings  
53 - print("Step 1: Getting BERT+CTM embeddings...") 52 + # Step 1: 获取 BERT+CTM 嵌入
  53 + print("Step 1: 获取 BERT+CTM 嵌入...")
54 valid_features = get_bert_ctm_embeddings(valid_data_path, bert_model_path, ctm_tokenizer_path, 54 valid_features = get_bert_ctm_embeddings(valid_data_path, bert_model_path, ctm_tokenizer_path,
55 save_path='valid_embeddings.npy') 55 save_path='valid_embeddings.npy')
56 test_features = get_bert_ctm_embeddings(test_data_path, bert_model_path, ctm_tokenizer_path, 56 test_features = get_bert_ctm_embeddings(test_data_path, bert_model_path, ctm_tokenizer_path,
@@ -58,54 +58,53 @@ def train_model(train_data_path, valid_data_path, test_data_path, train_labels, @@ -58,54 +58,53 @@ def train_model(train_data_path, valid_data_path, test_data_path, train_labels,
58 train_features = get_bert_ctm_embeddings(train_data_path, bert_model_path, ctm_tokenizer_path, 58 train_features = get_bert_ctm_embeddings(train_data_path, bert_model_path, ctm_tokenizer_path,
59 save_path='train_embeddings.npy') 59 save_path='train_embeddings.npy')
60 60
61 - # Save labels to .npy file  
62 - print("Saving labels to labels.npy file...") 61 + # 保存标签到 .npy 文件
  62 + print("保存标签到 labels.npy 文件...")
63 np.save('train_labels.npy', train_labels) 63 np.save('train_labels.npy', train_labels)
64 np.save('valid_labels.npy', valid_labels) 64 np.save('valid_labels.npy', valid_labels)
65 np.save('test_labels.npy', test_labels) 65 np.save('test_labels.npy', test_labels)
66 66
67 - # Step 2: Validate label correctness  
68 - print("Step 2: Validating label correctness...") 67 + # Step 2: 检查标签的合理性
  68 + print("Step 2: 检查标签的合理性...")
69 unique_labels_train = np.unique(train_labels) 69 unique_labels_train = np.unique(train_labels)
70 unique_labels_valid = np.unique(valid_labels) 70 unique_labels_valid = np.unique(valid_labels)
71 unique_labels_test = np.unique(test_labels) 71 unique_labels_test = np.unique(test_labels)
72 - print(f"Unique train labels: {unique_labels_train}")  
73 - print(f"Train set class distribution: {np.bincount(train_labels)}")  
74 - print(f"Unique validation labels: {unique_labels_valid}")  
75 - print(f"Validation set class distribution: {np.bincount(valid_labels)}")  
76 - print(f"Unique test labels: {unique_labels_test}")  
77 - print(f"Test set class distribution: {np.bincount(test_labels)}") 72 + print(f"训练标签的唯一值: {unique_labels_train}")
  73 + print(f"训练集类别分布: {np.bincount(train_labels)}")
  74 + print(f"验证标签的唯一值: {unique_labels_valid}")
  75 + print(f"验证集类别分布: {np.bincount(valid_labels)}")
  76 + print(f"测试标签的唯一值: {unique_labels_test}")
  77 + print(f"测试集类别分布: {np.bincount(test_labels)}")
78 78
79 if len(unique_labels_train) != num_classes or len(unique_labels_valid) != num_classes or len( 79 if len(unique_labels_train) != num_classes or len(unique_labels_valid) != num_classes or len(
80 unique_labels_test) != num_classes: 80 unique_labels_test) != num_classes:
81 - raise ValueError(f"Number of classes in labels does not match expected: expected {num_classes}, "  
82 - f"but found different classes in training, validation, or test sets") 81 + raise ValueError(f"标签中的类别数量与期望的不符: 期望 {num_classes}, 但训练集、验证集或测试集中发现了其他类别")
83 82
84 - # Step 3: Create DataLoader  
85 - print("Step 3: Creating DataLoader...") 83 + # Step 3: 创建 DataLoader
  84 + print("Step 3: 创建 DataLoader...")
86 train_loader = prepare_dataloader(train_features, train_labels, batch_size) 85 train_loader = prepare_dataloader(train_features, train_labels, batch_size)
87 valid_loader = prepare_dataloader(valid_features, valid_labels, batch_size) 86 valid_loader = prepare_dataloader(valid_features, valid_labels, batch_size)
88 test_loader = prepare_dataloader(test_features, test_labels, batch_size) 87 test_loader = prepare_dataloader(test_features, test_labels, batch_size)
89 88
90 - # Step 4: Initialize CNN  
91 - print("Step 4: Initializing CNN...")  
92 - num_filters = 256 # Use 256 convolutional output channels  
93 - kernel_sizes = [2, 3, 4] # Kernel sizes for convolution 89 + # Step 4: 初始化CNN
  90 + print("Step 4: 初始化CNN...")
  91 + num_filters = 256 # 使用256个卷积输出通道
  92 + kernel_sizes = [2, 3, 4] # 卷积核大小
94 k = 3 * len(kernel_sizes) 93 k = 3 * len(kernel_sizes)
95 - cnn_output_dim = num_filters * (k + 1) # Calculate the output feature dimension of CNN 94 + cnn_output_dim = num_filters * (k + 1) # 计算CNN输出的特征维度
96 95
97 - # Step 5: Initialize attention mechanism  
98 - print("Step 5: Initializing multi-head attention...") 96 + # Step 5: 初始化注意力机制
  97 + print("Step 5: 初始化多头注意力机制...")
99 attention_model = MultiHeadAttentionLayer(embed_size=768, num_heads=8) 98 attention_model = MultiHeadAttentionLayer(embed_size=768, num_heads=8)
100 99
101 - # Step 6: Initialize classifier  
102 - print("Step 6: Initializing classifier...") 100 + # Step 6: 初始化分类器
  101 + print("Step 6: 初始化分类器...")
103 classifier_model = FinalClassifier(input_dim=768, num_classes=num_classes) 102 classifier_model = FinalClassifier(input_dim=768, num_classes=num_classes)
104 optimizer = torch.optim.Adam(classifier_model.parameters(), lr=learning_rate) 103 optimizer = torch.optim.Adam(classifier_model.parameters(), lr=learning_rate)
105 criterion = torch.nn.CrossEntropyLoss() 104 criterion = torch.nn.CrossEntropyLoss()
106 105
107 - # Step 7: Start training  
108 - print("Starting training...") 106 + # Step 7: 开始训练
  107 + print("开始训练...")
109 torch.autograd.set_detect_anomaly(True) 108 torch.autograd.set_detect_anomaly(True)
110 for epoch in range(epochs): 109 for epoch in range(epochs):
111 classifier_model.train() 110 classifier_model.train()
@@ -113,28 +112,28 @@ def train_model(train_data_path, valid_data_path, test_data_path, train_labels, @@ -113,28 +112,28 @@ def train_model(train_data_path, valid_data_path, test_data_path, train_labels,
113 y_true = [] 112 y_true = []
114 y_pred = [] 113 y_pred = []
115 114
116 - # Use tqdm to add progress bar for CNN feature extraction 115 + # 使用 tqdm 为 CNN 特征提取添加进度条
117 for batch_x, batch_y in tqdm(train_loader, desc=f"Epoch {epoch + 1}/{epochs} - Training"): 116 for batch_x, batch_y in tqdm(train_loader, desc=f"Epoch {epoch + 1}/{epochs} - Training"):
118 optimizer.zero_grad() 117 optimizer.zero_grad()
119 batch_x = torch.mean(batch_x, dim=1) 118 batch_x = torch.mean(batch_x, dim=1)
120 - # Extract features from CNN 119 + # 从CNN提取特征
121 # cnn_output = extract_CNN_features(batch_x) 120 # cnn_output = extract_CNN_features(batch_x)
122 # batch_x = torch.mean(batch_x, dim=1) 121 # batch_x = torch.mean(batch_x, dim=1)
123 - # cnn_output = torch.cat((batch_x, cnn_output), dim=-1) 122 + # cnn_output = torch.cat((batch_x,cnn_output), dim=-1)
124 attention_output = attention_model(batch_x, batch_x, batch_x) 123 attention_output = attention_model(batch_x, batch_x, batch_x)
125 outputs = classifier_model(attention_output) 124 outputs = classifier_model(attention_output)
126 outputs = torch.mean(outputs, dim=1) 125 outputs = torch.mean(outputs, dim=1)
127 - loss = criterion(outputs, batch_y) # Compute loss  
128 - loss.backward() # Backpropagation  
129 - optimizer.step() # Optimize 126 + loss = criterion(outputs, batch_y) # 计算损失
  127 + loss.backward() # 反向传播
  128 + optimizer.step() # 优化
130 129
131 epoch_loss += loss.item() 130 epoch_loss += loss.item()
132 131
133 - _, predicted = torch.max(outputs, 1) # Get predicted class 132 + _, predicted = torch.max(outputs, 1) # 获取预测类别
134 y_true.extend(batch_y.tolist()) 133 y_true.extend(batch_y.tolist())
135 y_pred.extend(predicted.tolist()) 134 y_pred.extend(predicted.tolist())
136 135
137 - # Calculate training accuracy, precision, recall, and F1 score 136 + # 计算训练准确率、精确率、召回率和F1分数
138 accuracy = accuracy_score(y_true, y_pred) 137 accuracy = accuracy_score(y_true, y_pred)
139 precision = precision_score(y_true, y_pred, average='macro') 138 precision = precision_score(y_true, y_pred, average='macro')
140 recall = recall_score(y_true, y_pred, average='macro') 139 recall = recall_score(y_true, y_pred, average='macro')
@@ -144,11 +143,11 @@ def train_model(train_data_path, valid_data_path, test_data_path, train_labels, @@ -144,11 +143,11 @@ def train_model(train_data_path, valid_data_path, test_data_path, train_labels,
144 f"Epoch [{epoch + 1}/{epochs}] Loss: {epoch_loss:.4f}, Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}") 143 f"Epoch [{epoch + 1}/{epochs}] Loss: {epoch_loss:.4f}, Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}")
145 print(confusion_matrix(y_true, y_pred)) 144 print(confusion_matrix(y_true, y_pred))
146 145
147 - # Save model 146 + # 保存模型
148 torch.save(classifier_model, model_save_path) 147 torch.save(classifier_model, model_save_path)
149 - print(f"Trained model has been saved to {model_save_path}") 148 + print(f"训练好的模型已经保存到 {model_save_path}")
150 149
151 - # Validation set evaluation 150 + # 验证集评估
152 classifier_model.eval() 151 classifier_model.eval()
153 y_true = [] 152 y_true = []
154 y_pred = [] 153 y_pred = []
@@ -158,7 +157,7 @@ def train_model(train_data_path, valid_data_path, test_data_path, train_labels, @@ -158,7 +157,7 @@ def train_model(train_data_path, valid_data_path, test_data_path, train_labels,
158 batch_x = torch.mean(batch_x, dim=1) 157 batch_x = torch.mean(batch_x, dim=1)
159 # cnn_output = extract_CNN_features(batch_x) 158 # cnn_output = extract_CNN_features(batch_x)
160 # batch_x = torch.mean(batch_x, dim=1) 159 # batch_x = torch.mean(batch_x, dim=1)
161 - # cnn_output = torch.cat((batch_x, cnn_output), dim=-1) 160 + # cnn_output = torch.cat((batch_x,cnn_output), dim=-1)
162 attention_output = attention_model(batch_x, batch_x, batch_x) 161 attention_output = attention_model(batch_x, batch_x, batch_x)
163 outputs = classifier_model(attention_output) 162 outputs = classifier_model(attention_output)
164 outputs = torch.mean(outputs, dim=1) 163 outputs = torch.mean(outputs, dim=1)
@@ -166,7 +165,7 @@ def train_model(train_data_path, valid_data_path, test_data_path, train_labels, @@ -166,7 +165,7 @@ def train_model(train_data_path, valid_data_path, test_data_path, train_labels,
166 y_true.extend(batch_y.tolist()) 165 y_true.extend(batch_y.tolist())
167 y_pred.extend(predicted.tolist()) 166 y_pred.extend(predicted.tolist())
168 167
169 - # Validation accuracy, precision, recall, and F1 score 168 + # 验证集准确率、精确率、召回率和F1分数
170 accuracy = accuracy_score(y_true, y_pred) 169 accuracy = accuracy_score(y_true, y_pred)
171 precision = precision_score(y_true, y_pred, average='macro') 170 precision = precision_score(y_true, y_pred, average='macro')
172 recall = recall_score(y_true, y_pred, average='macro') 171 recall = recall_score(y_true, y_pred, average='macro')
@@ -175,7 +174,7 @@ def train_model(train_data_path, valid_data_path, test_data_path, train_labels, @@ -175,7 +174,7 @@ def train_model(train_data_path, valid_data_path, test_data_path, train_labels,
175 print(f"\nValidation - Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}") 174 print(f"\nValidation - Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}")
176 print(confusion_matrix(y_true, y_pred)) 175 print(confusion_matrix(y_true, y_pred))
177 176
178 - # Test set evaluation 177 + # 测试集评估
179 y_true = [] 178 y_true = []
180 y_pred = [] 179 y_pred = []
181 180
@@ -184,14 +183,14 @@ def train_model(train_data_path, valid_data_path, test_data_path, train_labels, @@ -184,14 +183,14 @@ def train_model(train_data_path, valid_data_path, test_data_path, train_labels,
184 batch_x = torch.mean(batch_x, dim=1) 183 batch_x = torch.mean(batch_x, dim=1)
185 # cnn_output = extract_CNN_features(batch_x) 184 # cnn_output = extract_CNN_features(batch_x)
186 # batch_x = torch.mean(batch_x, dim=1) 185 # batch_x = torch.mean(batch_x, dim=1)
187 - # cnn_output = torch.cat((batch_x, cnn_output), dim=-1) 186 + # cnn_output = torch.cat((batch_x,cnn_output), dim=-1)
188 attention_output = attention_model(batch_x, batch_x, batch_x) 187 attention_output = attention_model(batch_x, batch_x, batch_x)
189 outputs = classifier_model(attention_output) 188 outputs = classifier_model(attention_output)
190 outputs = torch.mean(outputs, dim=1) 189 outputs = torch.mean(outputs, dim=1)
191 _, predicted = torch.max(outputs, 1) 190 _, predicted = torch.max(outputs, 1)
192 y_true.extend(batch_y.tolist()) 191 y_true.extend(batch_y.tolist())
193 y_pred.extend(predicted.tolist()) 192 y_pred.extend(predicted.tolist())
194 - # Test accuracy, precision, recall, and F1 score 193 + # 测试集准确率、精确率、召回率和F1分数
195 accuracy = accuracy_score(y_true, y_pred) 194 accuracy = accuracy_score(y_true, y_pred)
196 precision = precision_score(y_true, y_pred, average='macro') 195 precision = precision_score(y_true, y_pred, average='macro')
197 recall = recall_score(y_true, y_pred, average='macro') 196 recall = recall_score(y_true, y_pred, average='macro')
@@ -202,7 +201,7 @@ def train_model(train_data_path, valid_data_path, test_data_path, train_labels, @@ -202,7 +201,7 @@ def train_model(train_data_path, valid_data_path, test_data_path, train_labels,
202 201
203 202
204 if __name__ == "__main__": 203 if __name__ == "__main__":
205 - # Load and prepare data 204 + # 加载和准备数据
206 train_data_path = './train.csv' 205 train_data_path = './train.csv'
207 valid_data_path = './dev.csv' 206 valid_data_path = './dev.csv'
208 test_data_path = './test.csv' 207 test_data_path = './test.csv'
@@ -215,10 +214,10 @@ if __name__ == "__main__": @@ -215,10 +214,10 @@ if __name__ == "__main__":
215 valid_labels = valid_data['label'].values 214 valid_labels = valid_data['label'].values
216 test_labels = test_data['label'].values 215 test_labels = test_data['label'].values
217 216
218 - # Train model 217 + # 训练模型
219 bert_model_path = './bert_model' 218 bert_model_path = './bert_model'
220 ctm_tokenizer_path = './sentence_bert_model' 219 ctm_tokenizer_path = './sentence_bert_model'
221 220
222 - # Train model 221 + # 训练模型
223 train_model(train_data_path, valid_data_path, test_data_path, train_labels, valid_labels, test_labels, 222 train_model(train_data_path, valid_data_path, test_data_path, train_labels, valid_labels, test_labels,
224 bert_model_path, ctm_tokenizer_path, num_heads=12, num_classes=2, model_save_path='./final_model.pt') 223 bert_model_path, ctm_tokenizer_path, num_heads=12, num_classes=2, model_save_path='./final_model.pt')