Showing 1 changed file with 201 additions and 28 deletions
| 1 | -import os | 1 | +import torch |
| 2 | +import pandas as pd | ||
| 2 | import numpy as np | 3 | import numpy as np |
| 3 | -from BERT_CTM import BERT_CTM_Model # Assumes the BERT_CTM model lives in this file | 4 | +from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score |
| 5 | +from torch.utils.data import DataLoader, TensorDataset | ||
| 6 | +from CNN import extract_CNN_features | ||
| 7 | +from MHA import MultiHeadAttentionLayer | ||
| 8 | +from classifier import FinalClassifier | ||
| 9 | +from BERT_CTM import BERT_CTM_Model | ||
| 10 | +import os | ||
| 11 | +from tqdm import tqdm | ||
| 12 | +from sklearn.metrics import confusion_matrix | ||
| 4 | 13 | ||
| 5 | -# BERT_CTM embedding generation and loading function | 14 | + |
| 15 | +# BERT+CTM embedding generation and loading function | ||
| 6 | def get_bert_ctm_embeddings(texts, bert_model_path, ctm_tokenizer_path, n_components=12, num_epochs=20, save_path=None): | 16 | def get_bert_ctm_embeddings(texts, bert_model_path, ctm_tokenizer_path, n_components=12, num_epochs=20, save_path=None): |
| 7 | - """ | ||
| 8 | - 获取或生成 BERT+CTM 嵌入,并保存到文件。 | ||
| 9 | - | ||
| 10 | - :param texts: 需要嵌入的文本 | ||
| 11 | - :param bert_model_path: BERT 模型的路径 | ||
| 12 | - :param ctm_tokenizer_path: CTM tokenizer 的路径 | ||
| 13 | - :param n_components: 生成的主题数量 | ||
| 14 | - :param num_epochs: 训练的epoch数 | ||
| 15 | - :param save_path: 嵌入保存路径 | ||
| 16 | - :return: 生成或加载的嵌入 | ||
| 17 | - """ | ||
| 18 | - # Check whether a saved embeddings file already exists | 17 | + # Check if saved embeddings already exist |
| 19 | if save_path and os.path.exists(save_path): | 18 | if save_path and os.path.exists(save_path): |
| 20 | - print(f"从文件 {save_path} 加载嵌入...") | 19 | + print(f"Loading embeddings from {save_path}...") |
| 21 | embeddings = np.load(save_path) | 20 | embeddings = np.load(save_path) |
| 22 | else: | 21 | else: |
| 23 | - print("生成 BERT+CTM 嵌入...") | 22 | + print("Generating BERT+CTM embeddings...") |
| 24 | bert_ctm_model = BERT_CTM_Model( | 23 | bert_ctm_model = BERT_CTM_Model( |
| 25 | bert_model_path=bert_model_path, | 24 | bert_model_path=bert_model_path, |
| 26 | ctm_tokenizer_path=ctm_tokenizer_path, | 25 | ctm_tokenizer_path=ctm_tokenizer_path, |
| 27 | n_components=n_components, | 26 | n_components=n_components, |
| 28 | num_epochs=num_epochs | 27 | num_epochs=num_epochs |
| 29 | ) | 28 | ) |
| 30 | - embeddings = bert_ctm_model.train(texts) # Generate embeddings | 29 | + embeddings = bert_ctm_model.train(texts) # Generate embeddings |
| 31 | 30 | ||
| 32 | - # Save embeddings to file | 31 | + # Save embeddings to file |
| 33 | if save_path: | 32 | if save_path: |
| 34 | - print(f"保存嵌入到文件 {save_path}...") | 33 | + print(f"Saving embeddings to file {save_path}...") |
| 35 | np.save(save_path, embeddings) | 34 | np.save(save_path, embeddings) |
| 36 | 35 | ||
| 37 | return embeddings | 36 | return embeddings |
| 38 | 37 | ||
| 39 | 38 | ||
| 39 | +# Data loading and preparation function | ||
| 40 | +def prepare_dataloader(features, labels, batch_size): | ||
| 41 | + """Create DataLoader for training, validation, and testing""" | ||
| 42 | + tensor_x = torch.tensor(features, dtype=torch.float32) | ||
| 43 | + tensor_y = torch.tensor(labels, dtype=torch.long) | ||
| 44 | + dataset = TensorDataset(tensor_x, tensor_y) | ||
| 45 | + return DataLoader(dataset, batch_size=batch_size, shuffle=True) | ||
| 46 | + | ||
| 47 | + | ||
| 48 | +# Model training function | ||
| 49 | +def train_model(train_data_path, valid_data_path, test_data_path, train_labels, valid_labels, test_labels, | ||
| 50 | + bert_model_path, ctm_tokenizer_path, num_heads=8, num_classes=2, epochs=10, batch_size=128, | ||
| 51 | + learning_rate=5e-3, model_save_path='./final_model.pt'): | ||
| 52 | + # Step 1: Get BERT+CTM embeddings | ||
| 53 | + print("Step 1: Getting BERT+CTM embeddings...") | ||
| 54 | + valid_features = get_bert_ctm_embeddings(valid_data_path, bert_model_path, ctm_tokenizer_path, | ||
| 55 | + save_path='valid_embeddings.npy') | ||
| 56 | + test_features = get_bert_ctm_embeddings(test_data_path, bert_model_path, ctm_tokenizer_path, | ||
| 57 | + save_path='test_embeddings.npy') | ||
| 58 | + train_features = get_bert_ctm_embeddings(train_data_path, bert_model_path, ctm_tokenizer_path, | ||
| 59 | + save_path='train_embeddings.npy') | ||
| 60 | + | ||
| 61 | + # Save labels to .npy file | ||
| 62 | + print("Saving labels to labels.npy file...") | ||
| 63 | + np.save('train_labels.npy', train_labels) | ||
| 64 | + np.save('valid_labels.npy', valid_labels) | ||
| 65 | + np.save('test_labels.npy', test_labels) | ||
| 66 | + | ||
| 67 | + # Step 2: Validate label correctness | ||
| 68 | + print("Step 2: Validating label correctness...") | ||
| 69 | + unique_labels_train = np.unique(train_labels) | ||
| 70 | + unique_labels_valid = np.unique(valid_labels) | ||
| 71 | + unique_labels_test = np.unique(test_labels) | ||
| 72 | + print(f"Unique train labels: {unique_labels_train}") | ||
| 73 | + print(f"Train set class distribution: {np.bincount(train_labels)}") | ||
| 74 | + print(f"Unique validation labels: {unique_labels_valid}") | ||
| 75 | + print(f"Validation set class distribution: {np.bincount(valid_labels)}") | ||
| 76 | + print(f"Unique test labels: {unique_labels_test}") | ||
| 77 | + print(f"Test set class distribution: {np.bincount(test_labels)}") | ||
| 78 | + | ||
| 79 | + if len(unique_labels_train) != num_classes or len(unique_labels_valid) != num_classes or len( | ||
| 80 | + unique_labels_test) != num_classes: | ||
| 81 | + raise ValueError(f"Number of classes in labels does not match expected: expected {num_classes}, " | ||
| 82 | + f"but found different classes in training, validation, or test sets") | ||
| 83 | + | ||
| 84 | + # Step 3: Create DataLoader | ||
| 85 | + print("Step 3: Creating DataLoader...") | ||
| 86 | + train_loader = prepare_dataloader(train_features, train_labels, batch_size) | ||
| 87 | + valid_loader = prepare_dataloader(valid_features, valid_labels, batch_size) | ||
| 88 | + test_loader = prepare_dataloader(test_features, test_labels, batch_size) | ||
| 89 | + | ||
| 90 | + # Step 4: Initialize CNN | ||
| 91 | + print("Step 4: Initializing CNN...") | ||
| 92 | + num_filters = 256 # Use 256 convolutional output channels | ||
| 93 | + kernel_sizes = [2, 3, 4] # Kernel sizes for convolution | ||
| 94 | + k = 3 * len(kernel_sizes) | ||
| 95 | + cnn_output_dim = num_filters * (k + 1) # Calculate the output feature dimension of CNN | ||
| 96 | + | ||
| 97 | + # Step 5: Initialize attention mechanism | ||
| 98 | + print("Step 5: Initializing multi-head attention...") | ||
| 99 | + attention_model = MultiHeadAttentionLayer(embed_size=768, num_heads=num_heads) # use the num_heads argument; a hardcoded 8 would silently ignore the num_heads=12 passed at the call site | ||
| 100 | + | ||
| 101 | + # Step 6: Initialize classifier | ||
| 102 | + print("Step 6: Initializing classifier...") | ||
| 103 | + classifier_model = FinalClassifier(input_dim=768, num_classes=num_classes) | ||
| 104 | + optimizer = torch.optim.Adam(list(attention_model.parameters()) + list(classifier_model.parameters()), lr=learning_rate) # optimize the attention layer too; otherwise its weights never update | ||
| 105 | + criterion = torch.nn.CrossEntropyLoss() | ||
| 106 | + | ||
| 107 | + # Step 7: Start training | ||
| 108 | + print("Starting training...") | ||
| 109 | + torch.autograd.set_detect_anomaly(True) | ||
| 110 | + for epoch in range(epochs): | ||
| 111 | + classifier_model.train() | ||
| 112 | + epoch_loss = 0 | ||
| 113 | + y_true = [] | ||
| 114 | + y_pred = [] | ||
| 115 | + | ||
| 116 | + # Use tqdm to show a progress bar over the training batches | ||
| 117 | + for batch_x, batch_y in tqdm(train_loader, desc=f"Epoch {epoch + 1}/{epochs} - Training"): | ||
| 118 | + optimizer.zero_grad() | ||
| 119 | + batch_x = torch.mean(batch_x, dim=1) | ||
| 120 | + # Extract features from CNN | ||
| 121 | + # cnn_output = extract_CNN_features(batch_x) | ||
| 122 | + # batch_x = torch.mean(batch_x, dim=1) | ||
| 123 | + # cnn_output = torch.cat((batch_x, cnn_output), dim=-1) | ||
| 124 | + attention_output = attention_model(batch_x, batch_x, batch_x) | ||
| 125 | + outputs = classifier_model(attention_output) | ||
| 126 | + outputs = torch.mean(outputs, dim=1) | ||
| 127 | + loss = criterion(outputs, batch_y) # Compute loss | ||
| 128 | + loss.backward() # Backpropagation | ||
| 129 | + optimizer.step() # Optimize | ||
| 130 | + | ||
| 131 | + epoch_loss += loss.item() | ||
| 132 | + | ||
| 133 | + _, predicted = torch.max(outputs, 1) # Get predicted class | ||
| 134 | + y_true.extend(batch_y.tolist()) | ||
| 135 | + y_pred.extend(predicted.tolist()) | ||
| 136 | + | ||
| 137 | + # Calculate training accuracy, precision, recall, and F1 score | ||
| 138 | + accuracy = accuracy_score(y_true, y_pred) | ||
| 139 | + precision = precision_score(y_true, y_pred, average='macro') | ||
| 140 | + recall = recall_score(y_true, y_pred, average='macro') | ||
| 141 | + f1 = f1_score(y_true, y_pred, average='macro') | ||
| 142 | + | ||
| 143 | + print( | ||
| 144 | + f"Epoch [{epoch + 1}/{epochs}] Loss: {epoch_loss:.4f}, Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}") | ||
| 145 | + print(confusion_matrix(y_true, y_pred)) | ||
| 146 | + | ||
| 147 | + # Save model | ||
| 148 | + torch.save(classifier_model, model_save_path) | ||
| 149 | + print(f"Trained model has been saved to {model_save_path}") | ||
| 150 | + | ||
| 151 | + # Validation set evaluation | ||
| 152 | + classifier_model.eval() | ||
| 153 | + y_true = [] | ||
| 154 | + y_pred = [] | ||
| 155 | + | ||
| 156 | + with torch.no_grad(): | ||
| 157 | + for batch_x, batch_y in valid_loader: | ||
| 158 | + batch_x = torch.mean(batch_x, dim=1) | ||
| 159 | + # cnn_output = extract_CNN_features(batch_x) | ||
| 160 | + # batch_x = torch.mean(batch_x, dim=1) | ||
| 161 | + # cnn_output = torch.cat((batch_x, cnn_output), dim=-1) | ||
| 162 | + attention_output = attention_model(batch_x, batch_x, batch_x) | ||
| 163 | + outputs = classifier_model(attention_output) | ||
| 164 | + outputs = torch.mean(outputs, dim=1) | ||
| 165 | + _, predicted = torch.max(outputs, 1) | ||
| 166 | + y_true.extend(batch_y.tolist()) | ||
| 167 | + y_pred.extend(predicted.tolist()) | ||
| 168 | + | ||
| 169 | + # Validation accuracy, precision, recall, and F1 score | ||
| 170 | + accuracy = accuracy_score(y_true, y_pred) | ||
| 171 | + precision = precision_score(y_true, y_pred, average='macro') | ||
| 172 | + recall = recall_score(y_true, y_pred, average='macro') | ||
| 173 | + f1 = f1_score(y_true, y_pred, average='macro') | ||
| 174 | + | ||
| 175 | + print(f"\nValidation - Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}") | ||
| 176 | + print(confusion_matrix(y_true, y_pred)) | ||
| 177 | + | ||
| 178 | + # Test set evaluation | ||
| 179 | + y_true = [] | ||
| 180 | + y_pred = [] | ||
| 181 | + | ||
| 182 | + with torch.no_grad(): | ||
| 183 | + for batch_x, batch_y in test_loader: | ||
| 184 | + batch_x = torch.mean(batch_x, dim=1) | ||
| 185 | + # cnn_output = extract_CNN_features(batch_x) | ||
| 186 | + # batch_x = torch.mean(batch_x, dim=1) | ||
| 187 | + # cnn_output = torch.cat((batch_x, cnn_output), dim=-1) | ||
| 188 | + attention_output = attention_model(batch_x, batch_x, batch_x) | ||
| 189 | + outputs = classifier_model(attention_output) | ||
| 190 | + outputs = torch.mean(outputs, dim=1) | ||
| 191 | + _, predicted = torch.max(outputs, 1) | ||
| 192 | + y_true.extend(batch_y.tolist()) | ||
| 193 | + y_pred.extend(predicted.tolist()) | ||
| 194 | + # Test accuracy, precision, recall, and F1 score | ||
| 195 | + accuracy = accuracy_score(y_true, y_pred) | ||
| 196 | + precision = precision_score(y_true, y_pred, average='macro') | ||
| 197 | + recall = recall_score(y_true, y_pred, average='macro') | ||
| 198 | + f1 = f1_score(y_true, y_pred, average='macro') | ||
| 199 | + | ||
| 200 | + print(f"\nTest - Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}") | ||
| 201 | + print(confusion_matrix(y_true, y_pred)) | ||
| 202 | + | ||
| 203 | + | ||
| 40 | if __name__ == "__main__": | 204 | if __name__ == "__main__": |
| 41 | - # Example invocation | ||
| 42 | - sample_texts = ["This is a test text.", "Another example of text data."] | 205 | + # Load and prepare data |
| 206 | + train_data_path = './train.csv' | ||
| 207 | + valid_data_path = './dev.csv' | ||
| 208 | + test_data_path = './test.csv' | ||
| 209 | + | ||
| 210 | + train_data = pd.read_csv(train_data_path) | ||
| 211 | + valid_data = pd.read_csv(valid_data_path) | ||
| 212 | + test_data = pd.read_csv(test_data_path) | ||
| 213 | + | ||
| 214 | + train_labels = train_data['label'].values | ||
| 215 | + valid_labels = valid_data['label'].values | ||
| 216 | + test_labels = test_data['label'].values | ||
| 217 | + | ||
| 218 | + # Paths to the pretrained models | ||
| 43 | bert_model_path = './bert_model' | 219 | bert_model_path = './bert_model' |
| 44 | ctm_tokenizer_path = './sentence_bert_model' | 220 | ctm_tokenizer_path = './sentence_bert_model' |
| 45 | - save_path = 'sample_embeddings.npy' | ||
| 46 | - | ||
| 47 | - # Generate or load BERT+CTM embeddings | ||
| 48 | - embeddings = get_bert_ctm_embeddings(sample_texts, bert_model_path, ctm_tokenizer_path, save_path=save_path) | ||
| 49 | 221 | ||
| 50 | - # Print the embedding shape | ||
| 51 | - print(f"Embedding shape: {embeddings.shape}") | 222 | + # Train model |
| 223 | + train_model(train_data_path, valid_data_path, test_data_path, train_labels, valid_labels, test_labels, | ||
| 224 | + bert_model_path, ctm_tokenizer_path, num_heads=12, num_classes=2, model_save_path='./final_model.pt') |
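
The helper modules imported at the top of the new file (`CNN.py`, `MHA.py`, `classifier.py`) are not part of this diff, so the sketches below reconstruct the interfaces the training loop appears to assume; every name, shape, and default in them is an assumption, not the project's actual code. First, `extract_CNN_features` (currently commented out in the loop). The `cnn_output_dim = num_filters * (k + 1)` formula above hints at some k-max-pooling variant, but absent `CNN.py` this sketch uses plain max-over-time pooling, which yields `num_filters * len(kernel_sizes)` features instead:

```python
# Hypothetical sketch of CNN.extract_CNN_features -- not the project's actual code.
import torch
import torch.nn as nn

# One Conv1d per kernel size, mirroring num_filters=256 and kernel_sizes=[2, 3, 4] above.
_convs = nn.ModuleList(
    [nn.Conv1d(in_channels=768, out_channels=256, kernel_size=k) for k in (2, 3, 4)]
)

def extract_CNN_features(x: torch.Tensor) -> torch.Tensor:
    """x: (batch, seq_len, 768) -> (batch, 256 * 3) via max-over-time pooling."""
    x = x.transpose(1, 2)                              # Conv1d expects (batch, channels, seq_len)
    pooled = [conv(x).amax(dim=2) for conv in _convs]  # max over time for each filter bank
    return torch.cat(pooled, dim=1)
```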
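Next, a sketch of `MultiHeadAttentionLayer` consistent with the `attention_model(batch_x, batch_x, batch_x)` call: a thin wrapper over `torch.nn.MultiheadAttention`. The real `MHA.py` may differ, and note that the loop averages `batch_x` over `dim=1` before this call, so the tensor rank the layer actually receives depends on how the saved embeddings are laid out:

```python
# Hypothetical sketch of MHA.MultiHeadAttentionLayer -- not the project's actual code.
import torch.nn as nn

class MultiHeadAttentionLayer(nn.Module):
    """Thin wrapper over nn.MultiheadAttention matching the (query, key, value) call above."""
    def __init__(self, embed_size: int, num_heads: int):
        super().__init__()
        # batch_first=True so inputs are (batch, seq_len, embed_size)
        self.attn = nn.MultiheadAttention(embed_size, num_heads, batch_first=True)

    def forward(self, query, key, value):
        out, _ = self.attn(query, key, value)  # discard the attention weights
        return out
```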
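Finally, a sketch of `FinalClassifier`, a linear head mapping `input_dim=768` features to class logits; the dropout layer and its rate are assumptions. The extra `torch.mean(outputs, dim=1)` applied after the classifier in the loop suggests the real module returns per-position logits over a sequence dimension; with a plain 2D head like this one, that extra mean would have to be dropped:

```python
# Hypothetical sketch of classifier.FinalClassifier -- not the project's actual code.
import torch.nn as nn

class FinalClassifier(nn.Module):
    """Maps input_dim-dimensional features to num_classes raw logits (for CrossEntropyLoss)."""
    def __init__(self, input_dim: int, num_classes: int, dropout: float = 0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Dropout(dropout),               # dropout rate is an assumption
            nn.Linear(input_dim, num_classes),
        )

    def forward(self, x):
        return self.net(x)
```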