predict_universal.py 14.7 KB
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Qwen3微博情感分析统一预测接口
支持0.6B、4B、8B三种规格的Embedding和LoRA模型
"""

import os
import sys
import argparse
import torch
from typing import List, Dict, Tuple, Any

# 添加当前目录到路径
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from models_config import MODEL_CONFIGS, MODEL_PATHS
from qwen3_embedding_universal import Qwen3EmbeddingUniversal
from qwen3_lora_universal import Qwen3LoRAUniversal


class Qwen3UniversalPredictor:
    """Qwen3统一预测器"""
    
    def __init__(self):
        self.models = {}  # 存储已加载的模型 {model_key: {model: obj, display_name: str}}
        
    def _get_model_key(self, model_type: str, model_size: str) -> str:
        """生成模型键值"""
        return f"{model_type}_{model_size}"
    
    def load_model(self, model_type: str, model_size: str) -> None:
        """加载指定的模型"""
        if model_type not in ['embedding', 'lora']:
            raise ValueError(f"不支持的模型类型: {model_type}")
        if model_size not in ['0.6B', '4B', '8B']:
            raise ValueError(f"不支持的模型大小: {model_size}")
            
        model_path = MODEL_PATHS[model_type][model_size]
        model_key = self._get_model_key(model_type, model_size)
        
        # 检查训练好的模型文件是否存在
        if not os.path.exists(model_path):
            print(f"训练好的模型文件不存在: {model_path}")
            print(f"请先训练 {model_type.upper()}-{model_size} 模型,或检查模型路径配置")
            return
        
        print(f"加载 {model_type.upper()}-{model_size} 模型...")
        
        try:
            if model_type == 'embedding':
                model = Qwen3EmbeddingUniversal(model_size)
                model.load_model(model_path)
            else:  # lora
                model = Qwen3LoRAUniversal(model_size)
                model.load_model(model_path)
            
            self.models[model_key] = {
                'model': model,
                'display_name': f"Qwen3-{model_type.title()}-{model_size}"
            }
            print(f"{model_type.upper()}-{model_size} 模型加载成功")
            
        except Exception as e:
            print(f"加载 {model_type.upper()}-{model_size} 模型失败: {e}")
            print(f"这可能是因为基础模型下载失败或训练好的模型文件损坏")
    
    def load_all_models(self, model_dir: str = './models') -> None:
        """加载所有可用的模型"""
        print("开始加载所有可用的Qwen3模型...")
        
        loaded_count = 0
        for model_type in ['embedding', 'lora']:
            for model_size in ['0.6B', '4B', '8B']:
                try:
                    self.load_model(model_type, model_size)
                    loaded_count += 1
                except Exception as e:
                    print(f"跳过 {model_type}-{model_size}: {e}")
        
        print(f"\n已加载 {loaded_count} 个模型")
        self._print_loaded_models()
    
    def load_specific_models(self, model_configs: List[Tuple[str, str]]) -> None:
        """加载指定的模型配置
        Args:
            model_configs: [(model_type, model_size), ...] 的列表
        """
        print("加载指定的Qwen3模型...")
        
        for model_type, model_size in model_configs:
            try:
                self.load_model(model_type, model_size)
            except Exception as e:
                print(f"跳过 {model_type}-{model_size}: {e}")
        
        print(f"\n已加载 {len(self.models)} 个模型")
        self._print_loaded_models()
    
    def _print_loaded_models(self):
        """打印已加载的模型列表"""
        if self.models:
            print("已加载模型:")
            for model_info in self.models.values():
                print(f"  - {model_info['display_name']}")
        else:
            print("没有成功加载任何模型")
    
    def predict_single(self, text: str, model_key: str = None) -> Dict[str, Tuple[int, float]]:
        """单文本预测
        Args:
            text: 要预测的文本
            model_key: 指定模型键值,None表示使用所有模型
        Returns:
            {model_name: (prediction, confidence), ...}
        """
        results = {}
        
        if model_key and model_key in self.models:
            # 使用指定模型
            model_info = self.models[model_key]
            try:
                prediction, confidence = model_info['model'].predict_single(text)
                results[model_info['display_name']] = (prediction, confidence)
            except Exception as e:
                print(f"模型 {model_info['display_name']} 预测失败: {e}")
                results[model_info['display_name']] = (0, 0.0)
        else:
            # 使用所有模型
            for model_info in self.models.values():
                try:
                    prediction, confidence = model_info['model'].predict_single(text)
                    results[model_info['display_name']] = (prediction, confidence)
                except Exception as e:
                    print(f"模型 {model_info['display_name']} 预测失败: {e}")
                    results[model_info['display_name']] = (0, 0.0)
        
        return results
    
    def predict_batch(self, texts: List[str]) -> Dict[str, List[int]]:
        """批量预测"""
        results = {}
        
        for model_info in self.models.values():
            try:
                predictions = model_info['model'].predict(texts)
                results[model_info['display_name']] = predictions
            except Exception as e:
                print(f"模型 {model_info['display_name']} 预测失败: {e}")
                results[model_info['display_name']] = [0] * len(texts)
        
        return results
    
    def ensemble_predict(self, text: str) -> Tuple[int, float]:
        """集成预测"""
        if len(self.models) < 2:
            raise ValueError("集成预测需要至少2个模型")
        
        results = self.predict_single(text)
        
        # 加权平均(这里使用简单平均,可以根据模型性能调整权重)
        total_weight = 0
        weighted_prob = 0
        
        for model_name, (pred, conf) in results.items():
            if conf > 0:  # 只考虑有效预测
                prob = conf if pred == 1 else 1 - conf
                weighted_prob += prob
                total_weight += 1
        
        if total_weight == 0:
            return 0, 0.5
        
        final_prob = weighted_prob / total_weight
        final_pred = int(final_prob > 0.5)
        final_conf = final_prob if final_pred == 1 else 1 - final_prob
        
        return final_pred, final_conf
    
    def _select_and_load_model(self):
        """让用户选择并加载模型"""
        print("Qwen3微博情感分析预测系统")
        print("="*40)
        print("请选择要使用的模型:")
        print("\n方法选择:")
        print("  1. Embedding + 分类头 (推理快速,显存占用少)")
        print("  2. LoRA微调 (效果更好,显存占用较多)")
        
        method_choice = None
        while method_choice not in ['1', '2']:
            method_choice = input("\n请选择方法 (1/2): ").strip()
            if method_choice not in ['1', '2']:
                print("无效选择,请输入 1 或 2")
        
        method_type = "embedding" if method_choice == '1' else "lora"
        method_name = "Embedding + 分类头" if method_choice == '1' else "LoRA微调"
        
        print(f"\n已选择: {method_name}")
        print("\n模型大小选择:")
        print("  1. 0.6B - 轻量级,推理快速")
        print("  2. 4B  - 中等规模,性能均衡") 
        print("  3. 8B  - 大规模,性能最佳")
        
        size_choice = None
        while size_choice not in ['1', '2', '3']:
            size_choice = input("\n请选择模型大小 (1/2/3): ").strip()
            if size_choice not in ['1', '2', '3']:
                print("无效选择,请输入 1、2 或 3")
        
        size_map = {'1': '0.6B', '2': '4B', '3': '8B'}
        model_size = size_map[size_choice]
        
        print(f"已选择: Qwen3-{method_name}-{model_size}")
        print("正在加载模型...")
        
        try:
            self.load_model(method_type, model_size)
            print(f"模型加载成功!")
        except Exception as e:
            print(f"模型加载失败: {e}")
            print("请检查模型文件是否存在,或先进行训练")
    
    def interactive_predict(self):
        """交互式预测模式"""
        if len(self.models) == 0:
            # 让用户选择要加载的模型
            self._select_and_load_model()
            if len(self.models) == 0:
                print("没有加载任何模型,退出预测")
                return
        
        print("\n" + "="*60)
        print("Qwen3微博情感分析预测系统")
        print("="*60)
        print("已加载模型:")
        for model_info in self.models.values():
            print(f"   - {model_info['display_name']}")
        print("\n命令提示:")
        print("   输入 'q' 退出程序")
        print("   输入 'switch' 切换模型")  
        print("   输入 'models' 查看已加载模型")
        print("   输入 'compare' 比较所有模型性能")
        print("-"*60)
        
        while True:
            try:
                text = input("\n请输入要分析的微博内容: ").strip()
                
                if text.lower() == 'q':
                    print("感谢使用,再见!")
                    break
                
                if text.lower() == 'models':
                    print("已加载模型:")
                    for model_info in self.models.values():
                        print(f"   - {model_info['display_name']}")
                    continue
                
                if text.lower() == 'switch':
                    print("切换模型...")
                    self.models.clear()  # 清空当前模型
                    self._select_and_load_model()
                    if len(self.models) > 0:
                        print("模型切换成功!")
                        for model_info in self.models.values():
                            print(f"   当前模型: {model_info['display_name']}")
                    continue
                
                if text.lower() == 'compare':
                    test_text = input("请输入要比较的文本: ")
                    self._compare_models(test_text)
                    continue
                
                if not text:
                    print("请输入有效内容")
                    continue
                
                # 预测
                results = self.predict_single(text)
                
                print(f"\n原文: {text}")
                print("预测结果:")
                
                # 按模型类型和大小排序显示
                sorted_results = sorted(results.items())
                for model_name, (pred, conf) in sorted_results:
                    sentiment = "正面" if pred == 1 else "负面"
                    print(f"   {model_name:20}: {sentiment} (置信度: {conf:.4f})")
                
                # 只显示单个模型的预测结果(不进行集成)
                
            except KeyboardInterrupt:
                print("\n\n程序被中断,再见!")
                break
            except Exception as e:
                print(f"预测过程中出现错误: {e}")
    
    def _compare_models(self, text: str):
        """比较不同模型的性能"""
        print(f"\n模型性能比较 - 文本: {text}")
        print("-" * 60)
        
        results = self.predict_single(text)
        
        embedding_models = []
        lora_models = []
        
        for model_name, (pred, conf) in results.items():
            sentiment = "正面" if pred == 1 else "负面"
            if "Embedding" in model_name:
                embedding_models.append((model_name, sentiment, conf))
            elif "Lora" in model_name:
                lora_models.append((model_name, sentiment, conf))
        
        if embedding_models:
            print("Embedding + 分类头方法:")
            for name, sentiment, conf in embedding_models:
                print(f"   {name}: {sentiment} ({conf:.4f})")
        
        if lora_models:
            print("LoRA微调方法:")
            for name, sentiment, conf in lora_models:
                print(f"   {name}: {sentiment} ({conf:.4f})")


def main():
    """主函数"""
    parser = argparse.ArgumentParser(description='Qwen3微博情感分析统一预测接口')
    parser.add_argument('--model_dir', type=str, default='./models',
                        help='模型文件目录')
    parser.add_argument('--model_type', type=str, choices=['embedding', 'lora'],
                        help='指定模型类型')
    parser.add_argument('--model_size', type=str, choices=['0.6B', '4B', '8B'],
                        help='指定模型大小')
    parser.add_argument('--text', type=str,
                        help='直接预测指定文本')
    parser.add_argument('--interactive', action='store_true', default=True,
                        help='交互式预测模式(默认)')
    parser.add_argument('--ensemble', action='store_true',
                        help='使用集成预测')
    parser.add_argument('--load_all', action='store_true',
                        help='加载所有可用模型')
    
    args = parser.parse_args()
    
    # 创建预测器
    predictor = Qwen3UniversalPredictor()
    
    # 加载模型
    if args.load_all:
        # 加载所有模型
        predictor.load_all_models(args.model_dir)
    elif args.model_type and args.model_size:
        # 加载指定模型
        predictor.load_model(args.model_type, args.model_size)
    # 如果没有指定模型,交互式模式会让用户选择
    
    # 如果指定了文本,直接预测
    if args.text:
        if args.ensemble and len(predictor.models) > 1:
            pred, conf = predictor.ensemble_predict(args.text)
            sentiment = "正面" if pred == 1 else "负面"
            print(f"文本: {args.text}")
            print(f"集成预测: {sentiment} (置信度: {conf:.4f})")
        else:
            results = predictor.predict_single(args.text)
            print(f"文本: {args.text}")
            for model_name, (pred, conf) in results.items():
                sentiment = "正面" if pred == 1 else "负面"
                print(f"{model_name}: {sentiment} (置信度: {conf:.4f})")
    else:
        # 进入交互式模式
        predictor.interactive_predict()


if __name__ == "__main__":
    main()