You need to sign in or sign up before continuing.
base.py 2.25 KB
"""
Report Engine LLM基类
定义所有LLM实现的基础接口
"""

from abc import ABC, abstractmethod
from typing import Dict, Any, Optional


class BaseLLM(ABC):
    """LLM基类"""
    
    def __init__(self, api_key: str, model_name: Optional[str] = None):
        """
        初始化LLM客户端
        
        Args:
            api_key: API密钥
            model_name: 模型名称
        """
        self.api_key = api_key
        self.model_name = model_name
    
    @abstractmethod
    def invoke(self, system_prompt: str, user_prompt: str, **kwargs) -> str:
        """
        调用LLM生成回复
        
        Args:
            system_prompt: 系统提示词
            user_prompt: 用户输入
            **kwargs: 其他参数
            
        Returns:
            生成的回复文本
        """
        pass
    
    @abstractmethod
    def get_model_info(self) -> Dict[str, Any]:
        """
        获取当前模型信息
        
        Returns:
            模型信息字典
        """
        pass
    
    @abstractmethod
    def get_default_model(self) -> str:
        """
        获取默认模型名称
        
        Returns:
            默认模型名称
        """
        pass
    
    def validate_response(self, response: str) -> str:
        """
        验证和清理响应内容
        
        Args:
            response: 原始响应
            
        Returns:
            清理后的响应
        """
        if not response:
            return ""
        
        # 移除多余的空白字符
        response = response.strip()
        
        # 确保响应不为空
        if not response:
            return "抱歉,生成的内容为空。"
        
        return response
    
    def estimate_tokens(self, text: str) -> int:
        """
        估算文本的token数量(简单实现)
        
        Args:
            text: 输入文本
            
        Returns:
            估算的token数量
        """
        # 简单估算:中文字符按1.5个token计算,英文单词按1个token计算
        chinese_chars = len([c for c in text if '\u4e00' <= c <= '\u9fff'])
        english_words = len(text.split()) - chinese_chars
        
        return int(chinese_chars * 1.5 + english_words)