BCAT.py 9.66 KB
import torch
import pandas as pd
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from torch.utils.data import DataLoader, TensorDataset
from CNN import extract_CNN_features
from MHA import MultiHeadAttentionLayer
from classifier import FinalClassifier
from BERT_CTM import BERT_CTM_Model
import os
from tqdm import tqdm
from sklearn.metrics import confusion_matrix


# BERT_CTM embeddings generation and loading function
def get_bert_ctm_embeddings(texts, bert_model_path, ctm_tokenizer_path, n_components=12, num_epochs=20, save_path=None):
    # Check if saved embeddings already exist
    if save_path and os.path.exists(save_path):
        print(f"Loading embeddings from {save_path}...")
        embeddings = np.load(save_path)
    else:
        print("Generating BERT+CTM embeddings...")
        bert_ctm_model = BERT_CTM_Model(
            bert_model_path=bert_model_path,
            ctm_tokenizer_path=ctm_tokenizer_path,
            n_components=n_components,
            num_epochs=num_epochs
        )
        embeddings = bert_ctm_model.train(texts)  # Generate embeddings

        # Save embeddings to file
        if save_path:
            print(f"Saving embeddings to file {save_path}...")
            np.save(save_path, embeddings)

    return embeddings


# Data loading and preparation function
def prepare_dataloader(features, labels, batch_size):
    """Create DataLoader for training, validation, and testing"""
    tensor_x = torch.tensor(features, dtype=torch.float32)
    tensor_y = torch.tensor(labels, dtype=torch.long)
    dataset = TensorDataset(tensor_x, tensor_y)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)


# Model training function
def train_model(train_data_path, valid_data_path, test_data_path, train_labels, valid_labels, test_labels,
                bert_model_path, ctm_tokenizer_path, num_heads=8, num_classes=2, epochs=10, batch_size=128,
                learning_rate=5e-3, model_save_path='./final_model.pt'):
    # Step 1: Get BERT+CTM embeddings
    print("Step 1: Getting BERT+CTM embeddings...")
    valid_features = get_bert_ctm_embeddings(valid_data_path, bert_model_path, ctm_tokenizer_path,
                                             save_path='valid_embeddings.npy')
    test_features = get_bert_ctm_embeddings(test_data_path, bert_model_path, ctm_tokenizer_path,
                                            save_path='test_embeddings.npy')
    train_features = get_bert_ctm_embeddings(train_data_path, bert_model_path, ctm_tokenizer_path,
                                             save_path='train_embeddings.npy')

    # Save labels to .npy file
    print("Saving labels to labels.npy file...")
    np.save('train_labels.npy', train_labels)
    np.save('valid_labels.npy', valid_labels)
    np.save('test_labels.npy', test_labels)

    # Step 2: Validate label correctness
    print("Step 2: Validating label correctness...")
    unique_labels_train = np.unique(train_labels)
    unique_labels_valid = np.unique(valid_labels)
    unique_labels_test = np.unique(test_labels)
    print(f"Unique train labels: {unique_labels_train}")
    print(f"Train set class distribution: {np.bincount(train_labels)}")
    print(f"Unique validation labels: {unique_labels_valid}")
    print(f"Validation set class distribution: {np.bincount(valid_labels)}")
    print(f"Unique test labels: {unique_labels_test}")
    print(f"Test set class distribution: {np.bincount(test_labels)}")

    if len(unique_labels_train) != num_classes or len(unique_labels_valid) != num_classes or len(
            unique_labels_test) != num_classes:
        raise ValueError(f"Number of classes in labels does not match expected: expected {num_classes}, "
                         f"but found different classes in training, validation, or test sets")

    # Step 3: Create DataLoader
    print("Step 3: Creating DataLoader...")
    train_loader = prepare_dataloader(train_features, train_labels, batch_size)
    valid_loader = prepare_dataloader(valid_features, valid_labels, batch_size)
    test_loader = prepare_dataloader(test_features, test_labels, batch_size)

    # Step 4: Initialize CNN
    print("Step 4: Initializing CNN...")
    num_filters = 256  # Use 256 convolutional output channels
    kernel_sizes = [2, 3, 4]  # Kernel sizes for convolution
    k = 3 * len(kernel_sizes)
    cnn_output_dim = num_filters * (k + 1)  # Calculate the output feature dimension of CNN

    # Step 5: Initialize attention mechanism
    print("Step 5: Initializing multi-head attention...")
    attention_model = MultiHeadAttentionLayer(embed_size=768, num_heads=8)

    # Step 6: Initialize classifier
    print("Step 6: Initializing classifier...")
    classifier_model = FinalClassifier(input_dim=768, num_classes=num_classes)
    optimizer = torch.optim.Adam(classifier_model.parameters(), lr=learning_rate)
    criterion = torch.nn.CrossEntropyLoss()

    # Step 7: Start training
    print("Starting training...")
    torch.autograd.set_detect_anomaly(True)
    for epoch in range(epochs):
        classifier_model.train()
        epoch_loss = 0
        y_true = []
        y_pred = []

        # Use tqdm to add progress bar for CNN feature extraction
        for batch_x, batch_y in tqdm(train_loader, desc=f"Epoch {epoch + 1}/{epochs} - Training"):
            optimizer.zero_grad()
            batch_x = torch.mean(batch_x, dim=1)
            # Extract features from CNN
            # cnn_output = extract_CNN_features(batch_x)
            # batch_x = torch.mean(batch_x, dim=1)
            # cnn_output = torch.cat((batch_x, cnn_output), dim=-1)
            attention_output = attention_model(batch_x, batch_x, batch_x)
            outputs = classifier_model(attention_output)
            outputs = torch.mean(outputs, dim=1)
            loss = criterion(outputs, batch_y)  # Compute loss
            loss.backward()  # Backpropagation
            optimizer.step()  # Optimize

            epoch_loss += loss.item()

            _, predicted = torch.max(outputs, 1)  # Get predicted class
            y_true.extend(batch_y.tolist())
            y_pred.extend(predicted.tolist())

        # Calculate training accuracy, precision, recall, and F1 score
        accuracy = accuracy_score(y_true, y_pred)
        precision = precision_score(y_true, y_pred, average='macro')
        recall = recall_score(y_true, y_pred, average='macro')
        f1 = f1_score(y_true, y_pred, average='macro')

        print(
            f"Epoch [{epoch + 1}/{epochs}] Loss: {epoch_loss:.4f}, Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}")
        print(confusion_matrix(y_true, y_pred))

    # Save model
    torch.save(classifier_model, model_save_path)
    print(f"Trained model has been saved to {model_save_path}")

    # Validation set evaluation
    classifier_model.eval()
    y_true = []
    y_pred = []

    with torch.no_grad():
        for batch_x, batch_y in valid_loader:
            batch_x = torch.mean(batch_x, dim=1)
            # cnn_output = extract_CNN_features(batch_x)
            # batch_x = torch.mean(batch_x, dim=1)
            # cnn_output = torch.cat((batch_x, cnn_output), dim=-1)
            attention_output = attention_model(batch_x, batch_x, batch_x)
            outputs = classifier_model(attention_output)
            outputs = torch.mean(outputs, dim=1)
            _, predicted = torch.max(outputs, 1)
            y_true.extend(batch_y.tolist())
            y_pred.extend(predicted.tolist())

    # Validation accuracy, precision, recall, and F1 score
    accuracy = accuracy_score(y_true, y_pred)
    precision = precision_score(y_true, y_pred, average='macro')
    recall = recall_score(y_true, y_pred, average='macro')
    f1 = f1_score(y_true, y_pred, average='macro')

    print(f"\nValidation - Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}")
    print(confusion_matrix(y_true, y_pred))

    # Test set evaluation
    y_true = []
    y_pred = []

    with torch.no_grad():
        for batch_x, batch_y in test_loader:
            batch_x = torch.mean(batch_x, dim=1)
            # cnn_output = extract_CNN_features(batch_x)
            # batch_x = torch.mean(batch_x, dim=1)
            # cnn_output = torch.cat((batch_x, cnn_output), dim=-1)
            attention_output = attention_model(batch_x, batch_x, batch_x)
            outputs = classifier_model(attention_output)
            outputs = torch.mean(outputs, dim=1)
            _, predicted = torch.max(outputs, 1)
            y_true.extend(batch_y.tolist())
            y_pred.extend(predicted.tolist())
    # Test accuracy, precision, recall, and F1 score
    accuracy = accuracy_score(y_true, y_pred)
    precision = precision_score(y_true, y_pred, average='macro')
    recall = recall_score(y_true, y_pred, average='macro')
    f1 = f1_score(y_true, y_pred, average='macro')

    print(f"\nTest - Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}")
    print(confusion_matrix(y_true, y_pred))


if __name__ == "__main__":
    # Load and prepare data
    train_data_path = './train.csv'
    valid_data_path = './dev.csv'
    test_data_path = './test.csv'

    train_data = pd.read_csv(train_data_path)
    valid_data = pd.read_csv(valid_data_path)
    test_data = pd.read_csv(test_data_path)

    train_labels = train_data['label'].values
    valid_labels = valid_data['label'].values
    test_labels = test_data['label'].values

    # Train model
    bert_model_path = './bert_model'
    ctm_tokenizer_path = './sentence_bert_model'

    # Train model
    train_model(train_data_path, valid_data_path, test_data_path, train_labels, valid_labels, test_labels,
                bert_model_path, ctm_tokenizer_path, num_heads=12, num_classes=2, model_save_path='./final_model.pt')