戒酒的李白

BCAT Preliminary

1 -import os 1 +import torch
  2 +import pandas as pd
2 import numpy as np 3 import numpy as np
3 -from BERT_CTM import BERT_CTM_Model # 假设BERT_CTM模型在这个文件中 4 +from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
  5 +from torch.utils.data import DataLoader, TensorDataset
  6 +from CNN import extract_CNN_features
  7 +from MHA import MultiHeadAttentionLayer
  8 +from classifier import FinalClassifier
  9 +from BERT_CTM import BERT_CTM_Model
  10 +import os
  11 +from tqdm import tqdm
  12 +from sklearn.metrics import confusion_matrix
4 13
5 -# BERT_CTM 嵌入生成和加载函数 14 +
  15 +# BERT_CTM embeddings generation and loading function
def get_bert_ctm_embeddings(texts, bert_model_path, ctm_tokenizer_path, n_components=12, num_epochs=20, save_path=None):
    """Load cached BERT+CTM embeddings from disk, or generate and cache them.

    :param texts: texts to embed (whatever ``BERT_CTM_Model.train`` accepts)
    :param bert_model_path: path to the BERT model
    :param ctm_tokenizer_path: path to the CTM (sentence-BERT) tokenizer
    :param n_components: number of CTM topics to learn
    :param num_epochs: number of CTM training epochs
    :param save_path: optional ``.npy`` path; when it exists the embeddings are
        loaded from it, otherwise freshly generated embeddings are saved to it
    :return: the embeddings as a numpy array
    """
    # Reuse cached embeddings from a previous run when available.
    if save_path and os.path.exists(save_path):
        print(f"Loading embeddings from {save_path}...")
        embeddings = np.load(save_path)
    else:
        print("Generating BERT+CTM embeddings...")
        bert_ctm_model = BERT_CTM_Model(
            bert_model_path=bert_model_path,
            ctm_tokenizer_path=ctm_tokenizer_path,
            n_components=n_components,
            num_epochs=num_epochs
        )
        embeddings = bert_ctm_model.train(texts)  # Generate embeddings

        # Persist only freshly generated embeddings; a cache hit above needs
        # no rewrite of the identical file.
        if save_path:
            print(f"Saving embeddings to file {save_path}...")
            np.save(save_path, embeddings)

    return embeddings
38 37
39 38
# Data loading and preparation function
def prepare_dataloader(features, labels, batch_size, shuffle=True):
    """Wrap feature/label arrays in a DataLoader.

    :param features: array-like of float feature vectors
    :param labels: array-like of integer class labels
    :param batch_size: mini-batch size
    :param shuffle: reshuffle every epoch (default True, matching previous
        behavior); pass False for deterministic validation/test iteration
    :return: DataLoader yielding (float32 features, int64 labels) batches
    """
    tensor_x = torch.tensor(features, dtype=torch.float32)
    tensor_y = torch.tensor(labels, dtype=torch.long)
    dataset = TensorDataset(tensor_x, tensor_y)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
  46 +
  47 +
def _forward_batch(attention_model, classifier_model, batch_x):
    """Shared forward pass: pool tokens, self-attend, classify, pool logits.

    NOTE(review): assumes batch_x is (batch, seq, 768) so the first mean
    produces (batch, 768) — TODO confirm against BERT_CTM output shape.
    The CNN branch (extract_CNN_features) was already disabled upstream;
    re-enable it by concatenating its output to the pooled batch_x here.
    """
    batch_x = torch.mean(batch_x, dim=1)
    attention_output = attention_model(batch_x, batch_x, batch_x)
    outputs = classifier_model(attention_output)
    # Second mean mirrors the original pipeline; presumably it collapses an
    # extra dimension emitted by FinalClassifier — verify against its code.
    return torch.mean(outputs, dim=1)


def _collect_predictions(attention_model, classifier_model, loader):
    """Run inference over a DataLoader; return (y_true, y_pred) lists."""
    y_true, y_pred = [], []
    with torch.no_grad():
        for batch_x, batch_y in loader:
            outputs = _forward_batch(attention_model, classifier_model, batch_x)
            _, predicted = torch.max(outputs, 1)
            y_true.extend(batch_y.tolist())
            y_pred.extend(predicted.tolist())
    return y_true, y_pred


def _report_metrics(tag, y_true, y_pred):
    """Print accuracy / macro precision / recall / F1 plus the confusion matrix."""
    accuracy = accuracy_score(y_true, y_pred)
    precision = precision_score(y_true, y_pred, average='macro')
    recall = recall_score(y_true, y_pred, average='macro')
    f1 = f1_score(y_true, y_pred, average='macro')
    print(f"\n{tag} - Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}")
    print(confusion_matrix(y_true, y_pred))


# Model training function
def train_model(train_data_path, valid_data_path, test_data_path, train_labels, valid_labels, test_labels,
                bert_model_path, ctm_tokenizer_path, num_heads=8, num_classes=2, epochs=10, batch_size=128,
                learning_rate=5e-3, model_save_path='./final_model.pt'):
    """Train, validate, and test an attention + classifier model on BERT+CTM embeddings.

    :param train_data_path: training texts/source handed to get_bert_ctm_embeddings
    :param valid_data_path: validation texts/source handed to get_bert_ctm_embeddings
    :param test_data_path: test texts/source handed to get_bert_ctm_embeddings
    :param train_labels: integer class labels for the training split
    :param valid_labels: integer class labels for the validation split
    :param test_labels: integer class labels for the test split
    :param bert_model_path: path to the BERT model
    :param ctm_tokenizer_path: path to the CTM tokenizer
    :param num_heads: number of attention heads (BUG FIX: previously accepted
        but ignored — the attention layer was hard-coded to 8 heads)
    :param num_classes: expected number of label classes in every split
    :param epochs: number of training epochs
    :param batch_size: mini-batch size
    :param learning_rate: Adam learning rate
    :param model_save_path: checkpoint path for the trained classifier
    :raises ValueError: if any split does not contain exactly num_classes classes
    """
    # Step 1: Get BERT+CTM embeddings (each split cached as a .npy file)
    print("Step 1: Getting BERT+CTM embeddings...")
    valid_features = get_bert_ctm_embeddings(valid_data_path, bert_model_path, ctm_tokenizer_path,
                                             save_path='valid_embeddings.npy')
    test_features = get_bert_ctm_embeddings(test_data_path, bert_model_path, ctm_tokenizer_path,
                                            save_path='test_embeddings.npy')
    train_features = get_bert_ctm_embeddings(train_data_path, bert_model_path, ctm_tokenizer_path,
                                             save_path='train_embeddings.npy')

    # Persist the labels alongside the cached embeddings for later reuse.
    print("Saving labels to labels.npy file...")
    np.save('train_labels.npy', train_labels)
    np.save('valid_labels.npy', valid_labels)
    np.save('test_labels.npy', test_labels)

    # Step 2: Validate label correctness (same diagnostics for every split)
    print("Step 2: Validating label correctness...")
    splits = (("train", "Train", train_labels),
              ("validation", "Validation", valid_labels),
              ("test", "Test", test_labels))
    for name, title, labels in splits:
        print(f"Unique {name} labels: {np.unique(labels)}")
        print(f"{title} set class distribution: {np.bincount(labels)}")

    if any(len(np.unique(labels)) != num_classes for _, _, labels in splits):
        raise ValueError(f"Number of classes in labels does not match expected: expected {num_classes}, "
                         f"but found different classes in training, validation, or test sets")

    # Step 3: Create DataLoader
    print("Step 3: Creating DataLoader...")
    train_loader = prepare_dataloader(train_features, train_labels, batch_size)
    valid_loader = prepare_dataloader(valid_features, valid_labels, batch_size)
    test_loader = prepare_dataloader(test_features, test_labels, batch_size)

    # (The CNN stage and its dimension bookkeeping — num_filters, kernel_sizes,
    # cnn_output_dim — were dead code with the CNN branch disabled; removed.)

    # Step 4: Initialize attention mechanism
    print("Step 5: Initializing multi-head attention...")
    # BUG FIX: honor the num_heads parameter instead of a hard-coded 8.
    attention_model = MultiHeadAttentionLayer(embed_size=768, num_heads=num_heads)

    # Step 5: Initialize classifier
    print("Step 6: Initializing classifier...")
    classifier_model = FinalClassifier(input_dim=768, num_classes=num_classes)
    # NOTE(review): only the classifier's parameters are optimized — the
    # attention layer stays at its initial weights. Confirm this is intended.
    optimizer = torch.optim.Adam(classifier_model.parameters(), lr=learning_rate)
    criterion = torch.nn.CrossEntropyLoss()

    # Step 6: Start training
    print("Starting training...")
    torch.autograd.set_detect_anomaly(True)
    for epoch in range(epochs):
        classifier_model.train()
        epoch_loss = 0
        y_true, y_pred = [], []

        for batch_x, batch_y in tqdm(train_loader, desc=f"Epoch {epoch + 1}/{epochs} - Training"):
            optimizer.zero_grad()
            outputs = _forward_batch(attention_model, classifier_model, batch_x)
            loss = criterion(outputs, batch_y)  # Compute loss
            loss.backward()  # Backpropagation
            optimizer.step()  # Optimize

            epoch_loss += loss.item()

            _, predicted = torch.max(outputs, 1)  # Get predicted class
            y_true.extend(batch_y.tolist())
            y_pred.extend(predicted.tolist())

        # Per-epoch training metrics
        accuracy = accuracy_score(y_true, y_pred)
        precision = precision_score(y_true, y_pred, average='macro')
        recall = recall_score(y_true, y_pred, average='macro')
        f1 = f1_score(y_true, y_pred, average='macro')
        print(
            f"Epoch [{epoch + 1}/{epochs}] Loss: {epoch_loss:.4f}, Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}")
        print(confusion_matrix(y_true, y_pred))

        # Checkpoint after every epoch (saves the whole module object).
        torch.save(classifier_model, model_save_path)
        print(f"Trained model has been saved to {model_save_path}")

    # Validation set evaluation
    classifier_model.eval()
    y_true, y_pred = _collect_predictions(attention_model, classifier_model, valid_loader)
    _report_metrics("Validation", y_true, y_pred)

    # Test set evaluation
    y_true, y_pred = _collect_predictions(attention_model, classifier_model, test_loader)
    _report_metrics("Test", y_true, y_pred)
  202 +
  203 +
if __name__ == "__main__":
    # One CSV per split; each file carries a 'label' column with integer classes.
    split_paths = {
        "train": './train.csv',
        "valid": './dev.csv',
        "test": './test.csv',
    }

    # Extract the label arrays for every split up front.
    split_labels = {name: pd.read_csv(path)['label'].values
                    for name, path in split_paths.items()}

    # Locations of the pre-trained language models.
    bert_model_path = './bert_model'
    ctm_tokenizer_path = './sentence_bert_model'

    # Run the full train / validate / test pipeline.
    train_model(split_paths["train"], split_paths["valid"], split_paths["test"],
                split_labels["train"], split_labels["valid"], split_labels["test"],
                bert_model_path, ctm_tokenizer_path, num_heads=12, num_classes=2,
                model_save_path='./final_model.pt')