# predict.py (11.4 KB)
# -*- coding: utf-8 -*-
"""
统一的情感分析预测程序
支持加载所有模型进行情感预测
"""
import argparse
import os
import re
import warnings
from typing import Dict, List, Optional, Tuple
warnings.filterwarnings("ignore")

# 导入所有模型类
from bayes_train import BayesModel
from svm_train import SVMModel
from xgboost_train import XGBoostModel
from lstm_train import LSTMModel
from bert_train import BertModel_Custom
from utils import processing


class SentimentPredictor:
    """Sentiment-analysis predictor wrapping several trained models.

    Loaded models are kept in ``self.models`` (model type -> model instance).
    They can be queried one at a time, all together, or combined by a
    weighted soft vote (``ensemble_predict``).
    """

    # Default saved-model file name for each model type, relative to a model directory.
    MODEL_FILES = {
        'bayes': 'bayes_model.pkl',
        'svm': 'svm_model.pkl',
        'xgboost': 'xgboost_model.pkl',
        'lstm': 'lstm_model.pth',
        'bert': 'bert_model.pth',
    }

    def __init__(self):
        # model type -> successfully loaded model instance
        self.models = {}
        # model type -> class used to construct a model of that type
        self.available_models = {
            'bayes': BayesModel,
            'svm': SVMModel,
            'xgboost': XGBoostModel,
            'lstm': LSTMModel,
            'bert': BertModel_Custom
        }

    def load_model(self, model_type: str, model_path: str, **kwargs) -> None:
        """Load one model of the given type from ``model_path``.

        A missing file or a failure while loading is reported with a printed
        message rather than an exception. This lets a caller that loads
        several models continue with whichever ones are available.

        Args:
            model_type: Model type ('bayes', 'svm', 'xgboost', 'lstm', 'bert').
            model_path: Path to the saved model file.
            **kwargs: Extra options. ``bert_path`` is the pretrained BERT
                directory (default ``'./model/chinese_wwm_pytorch'``).

        Raises:
            ValueError: If ``model_type`` is not supported.
        """
        if model_type not in self.available_models:
            raise ValueError(f"不支持的模型类型: {model_type}")

        if not os.path.exists(model_path):
            print(f"警告: 模型文件不存在: {model_path}")
            return

        print(f"加载 {model_type.upper()} 模型...")

        try:
            if model_type == 'bert':
                # BERT additionally needs the pretrained encoder directory.
                bert_path = kwargs.get('bert_path', './model/chinese_wwm_pytorch')
                model = BertModel_Custom(bert_path)
            else:
                model = self.available_models[model_type]()

            model.load_model(model_path)
            self.models[model_type] = model
            print(f"{model_type.upper()} 模型加载成功")

        except Exception as e:
            # Best-effort loading: report the failure and let other models load.
            print(f"加载 {model_type.upper()} 模型失败: {e}")

    def load_all_models(self, model_dir: str = './model', bert_path: str = './model/chinese_wwm_pytorch') -> None:
        """Try to load every model type from ``model_dir``.

        Args:
            model_dir: Directory holding the files named in ``MODEL_FILES``.
            bert_path: Pretrained BERT directory passed to the BERT model.
        """
        model_files = {
            model_type: os.path.join(model_dir, filename)
            for model_type, filename in self.MODEL_FILES.items()
        }

        print("开始加载所有可用模型...")
        for model_type, model_path in model_files.items():
            self.load_model(model_type, model_path, bert_path=bert_path)

        print(f"\n已加载 {len(self.models)} 个模型: {list(self.models.keys())}")

    def _run_models(self, processed_text: str) -> Dict[str, Optional[Tuple[int, float]]]:
        """Run every loaded model on already-preprocessed text.

        Returns:
            A dict in ``self.models`` order. Each model type maps to its
            ``(prediction, confidence)`` pair, or to ``None`` if that model
            raised. A failure message is printed for each failed model.
        """
        results = {}
        for name, model in self.models.items():
            try:
                prediction, confidence = model.predict_single(processed_text)
            except Exception as e:
                print(f"模型 {name} 预测失败: {e}")
                results[name] = None
            else:
                results[name] = (prediction, confidence)
        return results

    def predict_single(self, text: str, model_type: Optional[str] = None) -> Dict[str, Tuple[int, float]]:
        """Predict the sentiment of one text.

        Args:
            text: Raw text. It is passed through ``processing`` first.
            model_type: Model to use. If None, every loaded model is used.

        Returns:
            Dict mapping model type to ``(prediction, confidence)``. When all
            models are used, a model that raised is reported as ``(0, 0.0)``.

        Raises:
            ValueError: If ``model_type`` is given but not loaded.
        """
        processed_text = processing(text)

        if model_type:
            if model_type not in self.models:
                raise ValueError(f"模型 {model_type} 未加载")

            prediction, confidence = self.models[model_type].predict_single(processed_text)
            return {model_type: (prediction, confidence)}

        # Keep the legacy (0, 0.0) placeholder for failed models so callers still
        # receive one entry per loaded model.
        return {
            name: result if result is not None else (0, 0.0)
            for name, result in self._run_models(processed_text).items()
        }

    def predict_batch(self, texts: List[str], model_type: Optional[str] = None) -> Dict[str, List[int]]:
        """Predict the sentiment of a list of texts.

        Args:
            texts: Raw texts. Each is passed through ``processing`` first.
            model_type: Model to use. If None, every loaded model is used.

        Returns:
            Dict mapping model type to that model's ``predict`` output. When
            all models are used, a model that raised gets ``[0] * len(texts)``.

        Raises:
            ValueError: If ``model_type`` is given but not loaded.
        """
        processed_texts = [processing(text) for text in texts]

        if model_type:
            if model_type not in self.models:
                raise ValueError(f"模型 {model_type} 未加载")

            predictions = self.models[model_type].predict(processed_texts)
            return {model_type: predictions}

        results = {}
        for name, model in self.models.items():
            try:
                predictions = model.predict(processed_texts)
                results[name] = predictions
            except Exception as e:
                print(f"模型 {name} 预测失败: {e}")
                results[name] = [0] * len(texts)

        return results

    @staticmethod
    def _combine_predictions(results: Dict[str, Tuple[int, float]],
                             weights: Optional[Dict[str, float]] = None) -> Tuple[int, float]:
        """Weighted soft vote over per-model ``(prediction, confidence)`` pairs.

        Each confidence is treated as the probability of the predicted class.
        It is converted to a probability of label 1 before averaging. Models
        missing from ``weights`` are ignored. If ``weights`` is None, every
        model gets weight 1.0.

        Returns:
            ``(label, confidence)``. Returns ``(0, 0.5)`` when the total weight
            is zero (e.g. no results).
        """
        if weights is None:
            weights = {name: 1.0 for name in results}

        total_weight = 0.0
        weighted_prob = 0.0
        for name, (pred, conf) in results.items():
            if name not in weights:
                continue
            weight = weights[name]
            positive_prob = conf if pred == 1 else 1 - conf
            weighted_prob += positive_prob * weight
            total_weight += weight

        if total_weight == 0:
            return 0, 0.5

        final_prob = weighted_prob / total_weight
        final_pred = int(final_prob > 0.5)
        final_conf = final_prob if final_pred == 1 else 1 - final_prob
        return final_pred, final_conf

    def ensemble_predict(self, text: str, weights: Optional[Dict[str, float]] = None) -> Tuple[int, float]:
        """Predict one text by a weighted vote of all loaded models.

        Models that raise during prediction are excluded from the vote. They
        must not be counted as the ``(0, 0.0)`` placeholder, which would
        register as a fully confident positive vote.

        Args:
            text: Raw text to classify.
            weights: Optional per-model weights. If None, weights are equal.

        Returns:
            ``(prediction, confidence)``.

        Raises:
            ValueError: If no model is loaded.
        """
        if len(self.models) == 0:
            raise ValueError("没有加载任何模型")

        raw = self._run_models(processing(text))
        successful = {name: result for name, result in raw.items() if result is not None}
        return self._combine_predictions(successful, weights)

    def _interactive_ensemble(self) -> None:
        """Ask for one text and print its ensemble prediction (needs >= 2 models)."""
        if len(self.models) <= 1:
            print("❌ 集成预测需要至少2个模型")
            return

        # The 'ensemble' command itself is not the text to analyse; ask for it.
        text = input("请输入要进行集成预测的微博内容: ").strip()
        if not text:
            print("❌ 请输入有效内容")
            return

        pred, conf = self.ensemble_predict(text)
        sentiment = "😊 正面" if pred == 1 else "😞 负面"
        print("\n🤖 集成预测结果:")
        print(f"   情感倾向: {sentiment}")
        print(f"   置信度: {conf:.4f}")

    def _show_predictions(self, text: str) -> None:
        """Print every model's prediction for ``text``, plus the ensemble if useful."""
        # Run the models once and reuse the results for the ensemble line.
        raw = self._run_models(processing(text))

        print(f"\n📝 原文: {text}")
        print("🔍 预测结果:")

        for model_name, result in raw.items():
            if result is None:
                print(f"   {model_name.upper():8}: ❌ 预测失败")
                continue
            pred, conf = result
            sentiment = "😊 正面" if pred == 1 else "😞 负面"
            print(f"   {model_name.upper():8}: {sentiment} (置信度: {conf:.4f})")

        successful = {name: result for name, result in raw.items() if result is not None}
        if len(successful) > 1:
            ensemble_pred, ensemble_conf = self._combine_predictions(successful)
            ensemble_sentiment = "😊 正面" if ensemble_pred == 1 else "😞 负面"
            print(f"   {'集成':8}: {ensemble_sentiment} (置信度: {ensemble_conf:.4f})")

    def interactive_predict(self):
        """Run a read-predict-print loop on standard input.

        Commands: 'q' quits, 'models' lists the loaded models, and 'ensemble'
        asks for a text to classify by ensemble vote. Any other input is
        classified by every loaded model.
        """
        if len(self.models) == 0:
            print("错误: 没有加载任何模型,请先加载模型")
            return

        print("\n" + "="*50)
        print("="*50)
        print(f"已加载模型: {', '.join(self.models.keys())}")
        print("输入 'q' 退出程序")
        print("输入 'models' 查看模型列表")
        print("输入 'ensemble' 使用集成预测")
        print("-"*50)

        while True:
            try:
                text = input("\n请输入要分析的微博内容: ").strip()
                command = text.lower()

                if command == 'q':
                    print("👋 再见!")
                    break

                if command == 'models':
                    print(f"已加载模型: {list(self.models.keys())}")
                    continue

                if command == 'ensemble':
                    self._interactive_ensemble()
                    continue

                if not text:
                    print("❌ 请输入有效内容")
                    continue

                self._show_predictions(text)

            except (KeyboardInterrupt, EOFError):
                # EOFError (Ctrl-D / closed stdin) would otherwise be caught below
                # and re-raised by every input() call, looping forever.
                print("\n\n👋 程序被中断,再见!")
                break
            except Exception as e:
                print(f"❌ 预测过程中出现错误: {e}")


def main():
    """Command-line entry point.

    Loads either the single model named by ``--model_type`` or every model
    found in ``--model_dir``. If ``--text`` is given, predicts it once;
    otherwise starts the interactive prompt.
    """
    parser = argparse.ArgumentParser(description='微博情感分析统一预测程序')
    parser.add_argument('--model_dir', type=str, default='./model',
                        help='模型文件目录')
    parser.add_argument('--bert_path', type=str, default='./model/chinese_wwm_pytorch',
                        help='BERT预训练模型路径')
    parser.add_argument('--model_type', type=str, choices=['bayes', 'svm', 'xgboost', 'lstm', 'bert'],
                        help='指定单个模型类型进行预测')
    parser.add_argument('--text', type=str,
                        help='直接预测指定文本')
    # NOTE: default=True makes this flag always true. Interactive mode is simply
    # what runs when --text is not given; the flag is kept for compatibility.
    parser.add_argument('--interactive', action='store_true', default=True,
                        help='交互式预测模式(默认)')
    parser.add_argument('--ensemble', action='store_true',
                        help='使用集成预测')

    args = parser.parse_args()

    predictor = SentimentPredictor()

    if args.model_type:
        # Load only the requested model.
        model_files = {
            'bayes': 'bayes_model.pkl',
            'svm': 'svm_model.pkl',
            'xgboost': 'xgboost_model.pkl',
            'lstm': 'lstm_model.pth',
            'bert': 'bert_model.pth'
        }
        model_path = os.path.join(args.model_dir, model_files[args.model_type])
        predictor.load_model(args.model_type, model_path, bert_path=args.bert_path)
    else:
        predictor.load_all_models(args.model_dir, args.bert_path)

    if args.text:
        if not predictor.models:
            # Loading only prints warnings on failure. Stop with a clear message
            # and a non-zero exit status instead of crashing with a ValueError
            # traceback or printing an empty result.
            raise SystemExit("错误: 没有加载任何模型,请先加载模型")

        if args.ensemble:
            if len(predictor.models) > 1:
                pred, conf = predictor.ensemble_predict(args.text)
                sentiment = "正面" if pred == 1 else "负面"
                print(f"文本: {args.text}")
                print(f"集成预测: {sentiment} (置信度: {conf:.4f})")
                return
            # Do not silently ignore --ensemble; tell the user about the fallback.
            print("警告: 集成预测需要至少2个模型,将使用单模型预测")

        results = predictor.predict_single(args.text, args.model_type)
        print(f"文本: {args.text}")
        for model_name, (pred, conf) in results.items():
            sentiment = "正面" if pred == 1 else "负面"
            print(f"{model_name.upper()}: {sentiment} (置信度: {conf:.4f})")
    elif args.interactive:
        predictor.interactive_predict()


if __name__ == "__main__":
    # Run the command-line interface only when executed as a script, not on import.
    main()