base_model.py 6.26 KB
# -*- coding: utf-8 -*-
"""
Qwen3模型基础类,统一接口
"""
import os
import pickle
from abc import ABC, abstractmethod
from typing import List, Tuple, Dict, Any
import pandas as pd
from sklearn.metrics import accuracy_score, f1_score, classification_report
from sklearn.model_selection import train_test_split


class BaseQwenModel(ABC):
    """Qwen3情感分析模型基类"""
    
    def __init__(self, model_name: str):
        self.model_name = model_name
        self.model = None
        self.is_trained = False
        
    @abstractmethod
    def train(self, train_data: List[Tuple[str, int]], **kwargs) -> None:
        """训练模型"""
        pass
    
    @abstractmethod
    def predict(self, texts: List[str]) -> List[int]:
        """预测文本情感"""
        pass
    
    def predict_single(self, text: str) -> Tuple[int, float]:
        """预测单条文本的情感
        
        Args:
            text: 待预测文本
            
        Returns:
            (predicted_label, confidence)
        """
        predictions = self.predict([text])
        return predictions[0], 0.0  # 默认置信度为0
    
    def evaluate(self, test_data: List[Tuple[str, int]]) -> Dict[str, float]:
        """评估模型性能"""
        if not self.is_trained:
            raise ValueError(f"模型 {self.model_name} 尚未训练,请先调用train方法")
            
        texts = [item[0] for item in test_data]
        labels = [item[1] for item in test_data]
        
        predictions = self.predict(texts)
        
        accuracy = accuracy_score(labels, predictions)
        f1 = f1_score(labels, predictions, average='weighted')
        
        print(f"\n{self.model_name} 模型评估结果:")
        print(f"准确率: {accuracy:.4f}")
        print(f"F1分数: {f1:.4f}")
        print("\n详细报告:")
        print(classification_report(labels, predictions))
        
        return {
            'accuracy': accuracy,
            'f1_score': f1,
            'classification_report': classification_report(labels, predictions)
        }
    
    @abstractmethod
    def save_model(self, model_path: str = None) -> None:
        """保存模型到文件"""
        pass
    
    @abstractmethod
    def load_model(self, model_path: str) -> None:
        """从文件加载模型"""
        pass
    
    @staticmethod
    def load_data(train_path: str = None, test_path: str = None, csv_path: str = 'dataset/weibo_senti_100k.csv') -> Tuple[List[Tuple[str, int]], List[Tuple[str, int]]]:
        """加载训练和测试数据
        
        Args:
            train_path: 训练数据txt文件路径(可选)
            test_path: 测试数据txt文件路径(可选)
            csv_path: CSV数据文件路径(默认使用)
        """
        
        # 优先尝试使用CSV文件
        if os.path.exists(csv_path):
            print(f"从CSV文件加载数据: {csv_path}")
            df = pd.read_csv(csv_path)
            
            # 检查数据格式
            if 'review' in df.columns and 'label' in df.columns:
                # 将DataFrame转换为元组列表
                data = [(row['review'], row['label']) for _, row in df.iterrows()]
                
                # 分割训练和测试数据,固定测试集为5000条
                total_samples = len(data)
                if total_samples > 5000:
                    test_size = 5000
                    train_data, test_data = train_test_split(
                        data, 
                        test_size=test_size, 
                        random_state=42, 
                        stratify=[label for _, label in data]
                    )
                else:
                    # 如果总数据不足5000条,使用20%作为测试集
                    train_data, test_data = train_test_split(
                        data, 
                        test_size=0.2, 
                        random_state=42, 
                        stratify=[label for _, label in data]
                    )
                
                print(f"训练数据量: {len(train_data)}")
                print(f"测试数据量: {len(test_data)}")
                
                return train_data, test_data
            else:
                print(f"CSV文件格式不正确,缺少'review'或'label'列")
        
        # 如果CSV不存在,尝试使用txt文件
        elif train_path and test_path and os.path.exists(train_path) and os.path.exists(test_path):
            def load_corpus(path):
                data = []
                with open(path, "r", encoding="utf8") as f:
                    for line in f:
                        parts = line.strip().split("\t")
                        if len(parts) >= 2:
                            content = parts[0]
                            sentiment = int(parts[1])
                            data.append((content, sentiment))
                return data
            
            print("从txt文件加载训练数据...")
            train_data = load_corpus(train_path)
            print(f"训练数据量: {len(train_data)}")
            
            print("从txt文件加载测试数据...")
            test_data = load_corpus(test_path)
            print(f"测试数据量: {len(test_data)}")
            
            return train_data, test_data
        
        else:
            # 如果都没有,提供样例数据创建指导
            print("未找到数据文件!")
            print("请确保以下文件之一存在:")
            print(f"1. CSV文件: {csv_path}")
            print(f"2. txt文件: {train_path} 和 {test_path}")
            print("\n数据格式要求:")
            print("CSV文件: 包含'review'和'label'列")
            print("txt文件: 每行格式为'文本内容\\t标签'")
            
            # 创建样例数据
            sample_data = [
                ("今天天气真好,心情很棒!", 1),
                ("这部电影太无聊了", 0),
                ("非常喜欢这个产品", 1),
                ("服务态度很差", 0),
                ("质量不错,值得推荐", 1)
            ]
            
            print("使用样例数据进行演示...")
            train_data = sample_data * 20  # 扩充样例数据
            test_data = sample_data * 5
            
            return train_data, test_data