Doubao.py 8.62 KB
# AIfeng/2024-12-19
# 豆包大模型API实现
# 基于火山引擎豆包API: https://www.volcengine.com/docs/82379/1494384

import os
import json
import requests
from typing import Dict, List, Any, Optional
from logger import logger

class Doubao:
    """豆包大模型API客户端"""
    
    def __init__(self, config_path: str = "config/doubao_config.json"):
        """初始化豆包模型
        
        Args:
            config_path: 配置文件路径
        """
        self.config_file = config_path
        self.config = self._load_config(config_path)
        self.api_key = os.getenv("DOUBAO_API_KEY") or self.config.get("api_key")
        self.base_url = self.config.get("base_url", "https://ark.cn-beijing.volces.com/api/v3")
        self.model = self.config.get("model", "ep-20241219000000-xxxxx")
        self.character_config = self.config.get("character", {})
        
        if not self.api_key:
            raise ValueError("豆包API密钥未配置,请设置环境变量DOUBAO_API_KEY或在配置文件中设置api_key")
    
    def _load_config(self, config_path: str) -> Dict[str, Any]:
        """加载配置文件"""
        try:
            with open(config_path, 'r', encoding='utf-8') as f:
                return json.load(f)
        except FileNotFoundError:
            logger.warning(f"配置文件 {config_path} 不存在,使用默认配置")
            return {}
        except json.JSONDecodeError as e:
            logger.error(f"配置文件格式错误: {e}")
            return {}
    
    def _build_system_message(self) -> str:
        """构建系统消息"""
        character = self.character_config
        
        system_prompt = character.get("base_prompt", "你是一个AI助手")
        
        # 添加角色设定
        if character.get("name"):
            system_prompt += f",你的名字是{character['name']}"
        
        if character.get("personality"):
            system_prompt += f",性格特点:{character['personality']}"
        
        if character.get("background"):
            system_prompt += f",背景设定:{character['background']}"
        
        if character.get("speaking_style"):
            system_prompt += f",说话风格:{character['speaking_style']}"
        
        if character.get("constraints"):
            system_prompt += f",行为约束:{character['constraints']}"
        
        return system_prompt
    
    def chat(self, message: str, history: Optional[List[Dict[str, str]]] = None) -> str:
        """发送聊天请求
        
        Args:
            message: 用户消息
            history: 对话历史
            
        Returns:
            AI回复内容
        """
        url = f"{self.base_url}/chat/completions"
        
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json"
        }
        
        # 构建消息列表
        messages = []
        
        # 添加系统消息
        system_message = self._build_system_message()
        messages.append({
            "role": "system",
            "content": system_message
        })
        
        # 添加历史对话
        if history:
            messages.extend(history)
        
        # 添加当前用户消息
        messages.append({
            "role": "user",
            "content": message
        })
        
        # 构建请求数据
        data = {
            "model": self.model,
            "messages": messages,
            "stream": self.config.get("stream", True),
            "max_tokens": self.config.get("max_tokens", 1024),
            "temperature": self.config.get("temperature", 0.7),
            "top_p": self.config.get("top_p", 0.9)
        }
        
        try:
            response = requests.post(url, headers=headers, json=data, timeout=30)
            response.raise_for_status()
            
            if self.config.get("stream", True):
                return self._handle_stream_response(response)
            else:
                result = response.json()
                return result["choices"][0]["message"]["content"]
                
        except requests.exceptions.RequestException as e:
            logger.error(f"豆包API请求失败: {e}")
            return "抱歉,我现在无法回答您的问题,请稍后再试。"
        except Exception as e:
            logger.error(f"豆包API处理异常: {e}")
            return "抱歉,处理您的请求时出现了问题。"
    
    def _handle_stream_response(self, response) -> str:
        """处理流式响应"""
        result = ""
        
        try:
            for line in response.iter_lines():
                if line:
                    line = line.decode('utf-8')
                    if line.startswith('data: '):
                        data_str = line[6:]
                        if data_str.strip() == '[DONE]':
                            break
                        
                        try:
                            data = json.loads(data_str)
                            if 'choices' in data and len(data['choices']) > 0:
                                delta = data['choices'][0].get('delta', {})
                                content = delta.get('content', '')
                                if content:
                                    result += content
                        except json.JSONDecodeError:
                            continue
            
            return result
            
        except Exception as e:
            logger.error(f"处理流式响应失败: {e}")
            return "抱歉,处理响应时出现问题。"
    
    def chat_stream(self, message: str, history: Optional[List[Dict[str, str]]] = None, callback=None):
        """流式聊天,支持回调函数处理每个token
        
        Args:
            message: 用户消息
            history: 对话历史
            callback: 回调函数,接收每个生成的token
        """
        url = f"{self.base_url}/chat/completions"
        
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json"
        }
        
        # 构建消息列表
        messages = []
        
        # 添加系统消息
        system_message = self._build_system_message()
        messages.append({
            "role": "system",
            "content": system_message
        })
        
        # 添加历史对话
        if history:
            messages.extend(history)
        
        # 添加当前用户消息
        messages.append({
            "role": "user",
            "content": message
        })
        
        # 构建请求数据
        data = {
            "model": self.model,
            "messages": messages,
            "stream": True,
            "max_tokens": self.config.get("max_tokens", 1024),
            "temperature": self.config.get("temperature", 0.7),
            "top_p": self.config.get("top_p", 0.9)
        }
        
        try:
            response = requests.post(url, headers=headers, json=data, stream=True, timeout=30)
            response.raise_for_status()
            
            result = ""
            for line in response.iter_lines():
                if line:
                    line = line.decode('utf-8')
                    if line.startswith('data: '):
                        data_str = line[6:]
                        if data_str.strip() == '[DONE]':
                            break
                        
                        try:
                            data = json.loads(data_str)
                            if 'choices' in data and len(data['choices']) > 0:
                                delta = data['choices'][0].get('delta', {})
                                content = delta.get('content', '')
                                if content:
                                    result += content
                                    if callback:
                                        callback(content)
                        except json.JSONDecodeError:
                            continue
            
            return result
            
        except Exception as e:
            logger.error(f"豆包流式API请求失败: {e}")
            if callback:
                callback("抱歉,我现在无法回答您的问题,请稍后再试。")
            return "抱歉,我现在无法回答您的问题,请稍后再试。"


def test_doubao():
    """测试豆包API"""
    try:
        doubao = Doubao()
        response = doubao.chat("你好,请介绍一下自己")
        print(f"豆包回复: {response}")
    except Exception as e:
        print(f"测试失败: {e}")


if __name__ == "__main__":
    test_doubao()