Showing
1 changed file
with
56 additions
and
57 deletions
| @@ -8,49 +8,49 @@ from MHA import MultiHeadAttentionLayer | @@ -8,49 +8,49 @@ from MHA import MultiHeadAttentionLayer | ||
| 8 | from classifier import FinalClassifier | 8 | from classifier import FinalClassifier |
| 9 | from BERT_CTM import BERT_CTM_Model | 9 | from BERT_CTM import BERT_CTM_Model |
| 10 | import os | 10 | import os |
| 11 | -from tqdm import tqdm | 11 | +from tqdm import tqdm # 导入 tqdm 库用于进度条 |
| 12 | from sklearn.metrics import confusion_matrix | 12 | from sklearn.metrics import confusion_matrix |
| 13 | 13 | ||
| 14 | 14 | ||
| 15 | -# BERT_CTM embeddings generation and loading function | 15 | +# BERT_CTM 嵌入生成和加载函数 |
| 16 | def get_bert_ctm_embeddings(texts, bert_model_path, ctm_tokenizer_path, n_components=12, num_epochs=20, save_path=None): | 16 | def get_bert_ctm_embeddings(texts, bert_model_path, ctm_tokenizer_path, n_components=12, num_epochs=20, save_path=None): |
| 17 | - # Check if saved embeddings already exist | 17 | + # 检查是否已经存在保存的嵌入文件 |
| 18 | if save_path and os.path.exists(save_path): | 18 | if save_path and os.path.exists(save_path): |
| 19 | - print(f"Loading embeddings from {save_path}...") | 19 | + print(f"从文件 {save_path} 加载嵌入...") |
| 20 | embeddings = np.load(save_path) | 20 | embeddings = np.load(save_path) |
| 21 | else: | 21 | else: |
| 22 | - print("Generating BERT+CTM embeddings...") | 22 | + print("生成 BERT+CTM 嵌入...") |
| 23 | bert_ctm_model = BERT_CTM_Model( | 23 | bert_ctm_model = BERT_CTM_Model( |
| 24 | bert_model_path=bert_model_path, | 24 | bert_model_path=bert_model_path, |
| 25 | ctm_tokenizer_path=ctm_tokenizer_path, | 25 | ctm_tokenizer_path=ctm_tokenizer_path, |
| 26 | n_components=n_components, | 26 | n_components=n_components, |
| 27 | num_epochs=num_epochs | 27 | num_epochs=num_epochs |
| 28 | ) | 28 | ) |
| 29 | - embeddings = bert_ctm_model.train(texts) # Generate embeddings | 29 | + embeddings = bert_ctm_model.train(texts) # 生成嵌入 |
| 30 | 30 | ||
| 31 | - # Save embeddings to file | 31 | + # 保存嵌入到文件 |
| 32 | if save_path: | 32 | if save_path: |
| 33 | - print(f"Saving embeddings to file {save_path}...") | 33 | + print(f"保存嵌入到文件 {save_path}...") |
| 34 | np.save(save_path, embeddings) | 34 | np.save(save_path, embeddings) |
| 35 | 35 | ||
| 36 | return embeddings | 36 | return embeddings |
| 37 | 37 | ||
| 38 | 38 | ||
| 39 | -# Data loading and preparation function | 39 | +# 数据加载和准备函数 |
| 40 | def prepare_dataloader(features, labels, batch_size): | 40 | def prepare_dataloader(features, labels, batch_size): |
| 41 | - """Create DataLoader for training, validation, and testing""" | 41 | + """创建 DataLoader 用于训练、验证和测试""" |
| 42 | tensor_x = torch.tensor(features, dtype=torch.float32) | 42 | tensor_x = torch.tensor(features, dtype=torch.float32) |
| 43 | tensor_y = torch.tensor(labels, dtype=torch.long) | 43 | tensor_y = torch.tensor(labels, dtype=torch.long) |
| 44 | dataset = TensorDataset(tensor_x, tensor_y) | 44 | dataset = TensorDataset(tensor_x, tensor_y) |
| 45 | return DataLoader(dataset, batch_size=batch_size, shuffle=True) | 45 | return DataLoader(dataset, batch_size=batch_size, shuffle=True) |
| 46 | 46 | ||
| 47 | 47 | ||
| 48 | -# Model training function | 48 | +# 训练模型函数 |
| 49 | def train_model(train_data_path, valid_data_path, test_data_path, train_labels, valid_labels, test_labels, | 49 | def train_model(train_data_path, valid_data_path, test_data_path, train_labels, valid_labels, test_labels, |
| 50 | bert_model_path, ctm_tokenizer_path, num_heads=8, num_classes=2, epochs=10, batch_size=128, | 50 | bert_model_path, ctm_tokenizer_path, num_heads=8, num_classes=2, epochs=10, batch_size=128, |
| 51 | learning_rate=5e-3, model_save_path='./final_model.pt'): | 51 | learning_rate=5e-3, model_save_path='./final_model.pt'): |
| 52 | - # Step 1: Get BERT+CTM embeddings | ||
| 53 | - print("Step 1: Getting BERT+CTM embeddings...") | 52 | + # Step 1: 获取 BERT+CTM 嵌入 |
| 53 | + print("Step 1: 获取 BERT+CTM 嵌入...") | ||
| 54 | valid_features = get_bert_ctm_embeddings(valid_data_path, bert_model_path, ctm_tokenizer_path, | 54 | valid_features = get_bert_ctm_embeddings(valid_data_path, bert_model_path, ctm_tokenizer_path, |
| 55 | save_path='valid_embeddings.npy') | 55 | save_path='valid_embeddings.npy') |
| 56 | test_features = get_bert_ctm_embeddings(test_data_path, bert_model_path, ctm_tokenizer_path, | 56 | test_features = get_bert_ctm_embeddings(test_data_path, bert_model_path, ctm_tokenizer_path, |
| @@ -58,54 +58,53 @@ def train_model(train_data_path, valid_data_path, test_data_path, train_labels, | @@ -58,54 +58,53 @@ def train_model(train_data_path, valid_data_path, test_data_path, train_labels, | ||
| 58 | train_features = get_bert_ctm_embeddings(train_data_path, bert_model_path, ctm_tokenizer_path, | 58 | train_features = get_bert_ctm_embeddings(train_data_path, bert_model_path, ctm_tokenizer_path, |
| 59 | save_path='train_embeddings.npy') | 59 | save_path='train_embeddings.npy') |
| 60 | 60 | ||
| 61 | - # Save labels to .npy file | ||
| 62 | - print("Saving labels to labels.npy file...") | 61 | + # 保存标签到 .npy 文件 |
| 62 | + print("保存标签到 labels.npy 文件...") | ||
| 63 | np.save('train_labels.npy', train_labels) | 63 | np.save('train_labels.npy', train_labels) |
| 64 | np.save('valid_labels.npy', valid_labels) | 64 | np.save('valid_labels.npy', valid_labels) |
| 65 | np.save('test_labels.npy', test_labels) | 65 | np.save('test_labels.npy', test_labels) |
| 66 | 66 | ||
| 67 | - # Step 2: Validate label correctness | ||
| 68 | - print("Step 2: Validating label correctness...") | 67 | + # Step 2: 检查标签的合理性 |
| 68 | + print("Step 2: 检查标签的合理性...") | ||
| 69 | unique_labels_train = np.unique(train_labels) | 69 | unique_labels_train = np.unique(train_labels) |
| 70 | unique_labels_valid = np.unique(valid_labels) | 70 | unique_labels_valid = np.unique(valid_labels) |
| 71 | unique_labels_test = np.unique(test_labels) | 71 | unique_labels_test = np.unique(test_labels) |
| 72 | - print(f"Unique train labels: {unique_labels_train}") | ||
| 73 | - print(f"Train set class distribution: {np.bincount(train_labels)}") | ||
| 74 | - print(f"Unique validation labels: {unique_labels_valid}") | ||
| 75 | - print(f"Validation set class distribution: {np.bincount(valid_labels)}") | ||
| 76 | - print(f"Unique test labels: {unique_labels_test}") | ||
| 77 | - print(f"Test set class distribution: {np.bincount(test_labels)}") | 72 | + print(f"训练标签的唯一值: {unique_labels_train}") |
| 73 | + print(f"训练集类别分布: {np.bincount(train_labels)}") | ||
| 74 | + print(f"验证标签的唯一值: {unique_labels_valid}") | ||
| 75 | + print(f"验证集类别分布: {np.bincount(valid_labels)}") | ||
| 76 | + print(f"测试标签的唯一值: {unique_labels_test}") | ||
| 77 | + print(f"测试集类别分布: {np.bincount(test_labels)}") | ||
| 78 | 78 | ||
| 79 | if len(unique_labels_train) != num_classes or len(unique_labels_valid) != num_classes or len( | 79 | if len(unique_labels_train) != num_classes or len(unique_labels_valid) != num_classes or len( |
| 80 | unique_labels_test) != num_classes: | 80 | unique_labels_test) != num_classes: |
| 81 | - raise ValueError(f"Number of classes in labels does not match expected: expected {num_classes}, " | ||
| 82 | - f"but found different classes in training, validation, or test sets") | 81 | + raise ValueError(f"标签中的类别数量与期望的不符: 期望 {num_classes}, 但训练集、验证集或测试集中发现了其他类别") |
| 83 | 82 | ||
| 84 | - # Step 3: Create DataLoader | ||
| 85 | - print("Step 3: Creating DataLoader...") | 83 | + # Step 3: 创建 DataLoader |
| 84 | + print("Step 3: 创建 DataLoader...") | ||
| 86 | train_loader = prepare_dataloader(train_features, train_labels, batch_size) | 85 | train_loader = prepare_dataloader(train_features, train_labels, batch_size) |
| 87 | valid_loader = prepare_dataloader(valid_features, valid_labels, batch_size) | 86 | valid_loader = prepare_dataloader(valid_features, valid_labels, batch_size) |
| 88 | test_loader = prepare_dataloader(test_features, test_labels, batch_size) | 87 | test_loader = prepare_dataloader(test_features, test_labels, batch_size) |
| 89 | 88 | ||
| 90 | - # Step 4: Initialize CNN | ||
| 91 | - print("Step 4: Initializing CNN...") | ||
| 92 | - num_filters = 256 # Use 256 convolutional output channels | ||
| 93 | - kernel_sizes = [2, 3, 4] # Kernel sizes for convolution | 89 | + # Step 4: 初始化CNN |
| 90 | + print("Step 4: 初始化CNN...") | ||
| 91 | + num_filters = 256 # 使用256个卷积输出通道 | ||
| 92 | + kernel_sizes = [2, 3, 4] # 卷积核大小 | ||
| 94 | k = 3 * len(kernel_sizes) | 93 | k = 3 * len(kernel_sizes) |
| 95 | - cnn_output_dim = num_filters * (k + 1) # Calculate the output feature dimension of CNN | 94 | + cnn_output_dim = num_filters * (k + 1) # 计算CNN输出的特征维度 |
| 96 | 95 | ||
| 97 | - # Step 5: Initialize attention mechanism | ||
| 98 | - print("Step 5: Initializing multi-head attention...") | 96 | + # Step 5: 初始化注意力机制 |
| 97 | + print("Step 5: 初始化多头注意力机制...") | ||
| 99 | attention_model = MultiHeadAttentionLayer(embed_size=768, num_heads=8) | 98 | attention_model = MultiHeadAttentionLayer(embed_size=768, num_heads=8) |
| 100 | 99 | ||
| 101 | - # Step 6: Initialize classifier | ||
| 102 | - print("Step 6: Initializing classifier...") | 100 | + # Step 6: 初始化分类器 |
| 101 | + print("Step 6: 初始化分类器...") | ||
| 103 | classifier_model = FinalClassifier(input_dim=768, num_classes=num_classes) | 102 | classifier_model = FinalClassifier(input_dim=768, num_classes=num_classes) |
| 104 | optimizer = torch.optim.Adam(classifier_model.parameters(), lr=learning_rate) | 103 | optimizer = torch.optim.Adam(classifier_model.parameters(), lr=learning_rate) |
| 105 | criterion = torch.nn.CrossEntropyLoss() | 104 | criterion = torch.nn.CrossEntropyLoss() |
| 106 | 105 | ||
| 107 | - # Step 7: Start training | ||
| 108 | - print("Starting training...") | 106 | + # Step 7: 开始训练 |
| 107 | + print("开始训练...") | ||
| 109 | torch.autograd.set_detect_anomaly(True) | 108 | torch.autograd.set_detect_anomaly(True) |
| 110 | for epoch in range(epochs): | 109 | for epoch in range(epochs): |
| 111 | classifier_model.train() | 110 | classifier_model.train() |
| @@ -113,28 +112,28 @@ def train_model(train_data_path, valid_data_path, test_data_path, train_labels, | @@ -113,28 +112,28 @@ def train_model(train_data_path, valid_data_path, test_data_path, train_labels, | ||
| 113 | y_true = [] | 112 | y_true = [] |
| 114 | y_pred = [] | 113 | y_pred = [] |
| 115 | 114 | ||
| 116 | - # Use tqdm to add progress bar for CNN feature extraction | 115 | + # 使用 tqdm 为 CNN 特征提取添加进度条 |
| 117 | for batch_x, batch_y in tqdm(train_loader, desc=f"Epoch {epoch + 1}/{epochs} - Training"): | 116 | for batch_x, batch_y in tqdm(train_loader, desc=f"Epoch {epoch + 1}/{epochs} - Training"): |
| 118 | optimizer.zero_grad() | 117 | optimizer.zero_grad() |
| 119 | batch_x = torch.mean(batch_x, dim=1) | 118 | batch_x = torch.mean(batch_x, dim=1) |
| 120 | - # Extract features from CNN | 119 | + # 从CNN提取特征 |
| 121 | # cnn_output = extract_CNN_features(batch_x) | 120 | # cnn_output = extract_CNN_features(batch_x) |
| 122 | # batch_x = torch.mean(batch_x, dim=1) | 121 | # batch_x = torch.mean(batch_x, dim=1) |
| 123 | - # cnn_output = torch.cat((batch_x, cnn_output), dim=-1) | 122 | + # cnn_output = torch.cat((batch_x,cnn_output), dim=-1) |
| 124 | attention_output = attention_model(batch_x, batch_x, batch_x) | 123 | attention_output = attention_model(batch_x, batch_x, batch_x) |
| 125 | outputs = classifier_model(attention_output) | 124 | outputs = classifier_model(attention_output) |
| 126 | outputs = torch.mean(outputs, dim=1) | 125 | outputs = torch.mean(outputs, dim=1) |
| 127 | - loss = criterion(outputs, batch_y) # Compute loss | ||
| 128 | - loss.backward() # Backpropagation | ||
| 129 | - optimizer.step() # Optimize | 126 | + loss = criterion(outputs, batch_y) # 计算损失 |
| 127 | + loss.backward() # 反向传播 | ||
| 128 | + optimizer.step() # 优化 | ||
| 130 | 129 | ||
| 131 | epoch_loss += loss.item() | 130 | epoch_loss += loss.item() |
| 132 | 131 | ||
| 133 | - _, predicted = torch.max(outputs, 1) # Get predicted class | 132 | + _, predicted = torch.max(outputs, 1) # 获取预测类别 |
| 134 | y_true.extend(batch_y.tolist()) | 133 | y_true.extend(batch_y.tolist()) |
| 135 | y_pred.extend(predicted.tolist()) | 134 | y_pred.extend(predicted.tolist()) |
| 136 | 135 | ||
| 137 | - # Calculate training accuracy, precision, recall, and F1 score | 136 | + # 计算训练准确率、精确率、召回率和F1分数 |
| 138 | accuracy = accuracy_score(y_true, y_pred) | 137 | accuracy = accuracy_score(y_true, y_pred) |
| 139 | precision = precision_score(y_true, y_pred, average='macro') | 138 | precision = precision_score(y_true, y_pred, average='macro') |
| 140 | recall = recall_score(y_true, y_pred, average='macro') | 139 | recall = recall_score(y_true, y_pred, average='macro') |
| @@ -144,11 +143,11 @@ def train_model(train_data_path, valid_data_path, test_data_path, train_labels, | @@ -144,11 +143,11 @@ def train_model(train_data_path, valid_data_path, test_data_path, train_labels, | ||
| 144 | f"Epoch [{epoch + 1}/{epochs}] Loss: {epoch_loss:.4f}, Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}") | 143 | f"Epoch [{epoch + 1}/{epochs}] Loss: {epoch_loss:.4f}, Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}") |
| 145 | print(confusion_matrix(y_true, y_pred)) | 144 | print(confusion_matrix(y_true, y_pred)) |
| 146 | 145 | ||
| 147 | - # Save model | 146 | + # 保存模型 |
| 148 | torch.save(classifier_model, model_save_path) | 147 | torch.save(classifier_model, model_save_path) |
| 149 | - print(f"Trained model has been saved to {model_save_path}") | 148 | + print(f"训练好的模型已经保存到 {model_save_path}") |
| 150 | 149 | ||
| 151 | - # Validation set evaluation | 150 | + # 验证集评估 |
| 152 | classifier_model.eval() | 151 | classifier_model.eval() |
| 153 | y_true = [] | 152 | y_true = [] |
| 154 | y_pred = [] | 153 | y_pred = [] |
| @@ -158,7 +157,7 @@ def train_model(train_data_path, valid_data_path, test_data_path, train_labels, | @@ -158,7 +157,7 @@ def train_model(train_data_path, valid_data_path, test_data_path, train_labels, | ||
| 158 | batch_x = torch.mean(batch_x, dim=1) | 157 | batch_x = torch.mean(batch_x, dim=1) |
| 159 | # cnn_output = extract_CNN_features(batch_x) | 158 | # cnn_output = extract_CNN_features(batch_x) |
| 160 | # batch_x = torch.mean(batch_x, dim=1) | 159 | # batch_x = torch.mean(batch_x, dim=1) |
| 161 | - # cnn_output = torch.cat((batch_x, cnn_output), dim=-1) | 160 | + # cnn_output = torch.cat((batch_x,cnn_output), dim=-1) |
| 162 | attention_output = attention_model(batch_x, batch_x, batch_x) | 161 | attention_output = attention_model(batch_x, batch_x, batch_x) |
| 163 | outputs = classifier_model(attention_output) | 162 | outputs = classifier_model(attention_output) |
| 164 | outputs = torch.mean(outputs, dim=1) | 163 | outputs = torch.mean(outputs, dim=1) |
| @@ -166,7 +165,7 @@ def train_model(train_data_path, valid_data_path, test_data_path, train_labels, | @@ -166,7 +165,7 @@ def train_model(train_data_path, valid_data_path, test_data_path, train_labels, | ||
| 166 | y_true.extend(batch_y.tolist()) | 165 | y_true.extend(batch_y.tolist()) |
| 167 | y_pred.extend(predicted.tolist()) | 166 | y_pred.extend(predicted.tolist()) |
| 168 | 167 | ||
| 169 | - # Validation accuracy, precision, recall, and F1 score | 168 | + # 验证集准确率、精确率、召回率和F1分数 |
| 170 | accuracy = accuracy_score(y_true, y_pred) | 169 | accuracy = accuracy_score(y_true, y_pred) |
| 171 | precision = precision_score(y_true, y_pred, average='macro') | 170 | precision = precision_score(y_true, y_pred, average='macro') |
| 172 | recall = recall_score(y_true, y_pred, average='macro') | 171 | recall = recall_score(y_true, y_pred, average='macro') |
| @@ -175,7 +174,7 @@ def train_model(train_data_path, valid_data_path, test_data_path, train_labels, | @@ -175,7 +174,7 @@ def train_model(train_data_path, valid_data_path, test_data_path, train_labels, | ||
| 175 | print(f"\nValidation - Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}") | 174 | print(f"\nValidation - Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}") |
| 176 | print(confusion_matrix(y_true, y_pred)) | 175 | print(confusion_matrix(y_true, y_pred)) |
| 177 | 176 | ||
| 178 | - # Test set evaluation | 177 | + # 测试集评估 |
| 179 | y_true = [] | 178 | y_true = [] |
| 180 | y_pred = [] | 179 | y_pred = [] |
| 181 | 180 | ||
| @@ -184,14 +183,14 @@ def train_model(train_data_path, valid_data_path, test_data_path, train_labels, | @@ -184,14 +183,14 @@ def train_model(train_data_path, valid_data_path, test_data_path, train_labels, | ||
| 184 | batch_x = torch.mean(batch_x, dim=1) | 183 | batch_x = torch.mean(batch_x, dim=1) |
| 185 | # cnn_output = extract_CNN_features(batch_x) | 184 | # cnn_output = extract_CNN_features(batch_x) |
| 186 | # batch_x = torch.mean(batch_x, dim=1) | 185 | # batch_x = torch.mean(batch_x, dim=1) |
| 187 | - # cnn_output = torch.cat((batch_x, cnn_output), dim=-1) | 186 | + # cnn_output = torch.cat((batch_x,cnn_output), dim=-1) |
| 188 | attention_output = attention_model(batch_x, batch_x, batch_x) | 187 | attention_output = attention_model(batch_x, batch_x, batch_x) |
| 189 | outputs = classifier_model(attention_output) | 188 | outputs = classifier_model(attention_output) |
| 190 | outputs = torch.mean(outputs, dim=1) | 189 | outputs = torch.mean(outputs, dim=1) |
| 191 | _, predicted = torch.max(outputs, 1) | 190 | _, predicted = torch.max(outputs, 1) |
| 192 | y_true.extend(batch_y.tolist()) | 191 | y_true.extend(batch_y.tolist()) |
| 193 | y_pred.extend(predicted.tolist()) | 192 | y_pred.extend(predicted.tolist()) |
| 194 | - # Test accuracy, precision, recall, and F1 score | 193 | + # 测试集准确率、精确率、召回率和F1分数 |
| 195 | accuracy = accuracy_score(y_true, y_pred) | 194 | accuracy = accuracy_score(y_true, y_pred) |
| 196 | precision = precision_score(y_true, y_pred, average='macro') | 195 | precision = precision_score(y_true, y_pred, average='macro') |
| 197 | recall = recall_score(y_true, y_pred, average='macro') | 196 | recall = recall_score(y_true, y_pred, average='macro') |
| @@ -202,7 +201,7 @@ def train_model(train_data_path, valid_data_path, test_data_path, train_labels, | @@ -202,7 +201,7 @@ def train_model(train_data_path, valid_data_path, test_data_path, train_labels, | ||
| 202 | 201 | ||
| 203 | 202 | ||
| 204 | if __name__ == "__main__": | 203 | if __name__ == "__main__": |
| 205 | - # Load and prepare data | 204 | + # 加载和准备数据 |
| 206 | train_data_path = './train.csv' | 205 | train_data_path = './train.csv' |
| 207 | valid_data_path = './dev.csv' | 206 | valid_data_path = './dev.csv' |
| 208 | test_data_path = './test.csv' | 207 | test_data_path = './test.csv' |
| @@ -215,10 +214,10 @@ if __name__ == "__main__": | @@ -215,10 +214,10 @@ if __name__ == "__main__": | ||
| 215 | valid_labels = valid_data['label'].values | 214 | valid_labels = valid_data['label'].values |
| 216 | test_labels = test_data['label'].values | 215 | test_labels = test_data['label'].values |
| 217 | 216 | ||
| 218 | - # Train model | 217 | + # 训练模型 |
| 219 | bert_model_path = './bert_model' | 218 | bert_model_path = './bert_model' |
| 220 | ctm_tokenizer_path = './sentence_bert_model' | 219 | ctm_tokenizer_path = './sentence_bert_model' |
| 221 | 220 | ||
| 222 | - # Train model | 221 | + # 训练模型 |
| 223 | train_model(train_data_path, valid_data_path, test_data_path, train_labels, valid_labels, test_labels, | 222 | train_model(train_data_path, valid_data_path, test_data_path, train_labels, valid_labels, test_labels, |
| 224 | bert_model_path, ctm_tokenizer_path, num_heads=12, num_classes=2, model_save_path='./final_model.pt') | 223 | bert_model_path, ctm_tokenizer_path, num_heads=12, num_classes=2, model_save_path='./final_model.pt') |
-
Please register or login to post a comment