"""
Report Engine 默认的OpenAI兼容LLM客户端封装。
提供统一的非流式/流式调用、可选重试、字节安全拼接与模型元信息查询。
"""
import os
import sys
from typing import Any, Dict, Optional, Generator
from loguru import logger
from openai import OpenAI
current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(os.path.dirname(current_dir))
utils_dir = os.path.join(project_root, "utils")
if utils_dir not in sys.path:
sys.path.append(utils_dir)
try:
    from retry_helper import with_retry, LLM_RETRY_CONFIG
except ImportError:
    def with_retry(config=None):
        """Simplified with_retry placeholder matching the real decorator's call signature."""
        def decorator(func):
            """Return the function unchanged so the code still runs without the retry dependency."""
            return func
        return decorator
    LLM_RETRY_CONFIG = None
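# Minimal sketch of how the placeholder composes (the names below are
# illustrative, not part of this module):
#
#     @with_retry(LLM_RETRY_CONFIG)
#     def fetch():
#         ...
#
# resolves to decorator(fetch) -> fetch, an identity wrapper, so call sites
# behave the same whether or not retry_helper is importable.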
class LLMClient:
    """Lightweight wrapper around the OpenAI Chat Completions API; the unified entry point for Report Engine calls."""
def __init__(self, api_key: str, model_name: str, base_url: Optional[str] = None):
"""
初始化LLM客户端并保存基础连接信息。
Args:
api_key: 用于鉴权的API Token
model_name: 具体模型ID,用于定位供应商能力
base_url: 自定义兼容接口地址,默认为OpenAI官方
"""
if not api_key:
raise ValueError("Report Engine LLM API key is required.")
if not model_name:
raise ValueError("Report Engine model name is required.")
self.api_key = api_key
self.base_url = base_url
self.model_name = model_name
        # The provider label currently mirrors the model name; there is no separate provider field.
        self.provider = model_name
timeout_fallback = os.getenv("LLM_REQUEST_TIMEOUT") or os.getenv("REPORT_ENGINE_REQUEST_TIMEOUT") or "3000"
try:
self.timeout = float(timeout_fallback)
except ValueError:
self.timeout = 3000.0
client_kwargs: Dict[str, Any] = {
"api_key": api_key,
"max_retries": 0,
}
if base_url:
client_kwargs["base_url"] = base_url
self.client = OpenAI(**client_kwargs)
@with_retry(LLM_RETRY_CONFIG)
def invoke(self, system_prompt: str, user_prompt: str, **kwargs) -> str:
"""
以非流式方式调用LLM,并返回一次性完成的完整响应。
Args:
system_prompt: 系统角色提示
user_prompt: 用户高优先级指令
**kwargs: 允许透传temperature/top_p等采样参数
Returns:
去除首尾空白后的LLM响应文本
"""
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
]
allowed_keys = {"temperature", "top_p", "presence_penalty", "frequency_penalty", "stream"}
extra_params = {key: value for key, value in kwargs.items() if key in allowed_keys and value is not None}
timeout = kwargs.pop("timeout", self.timeout)
response = self.client.chat.completions.create(
model=self.model_name,
messages=messages,
timeout=timeout,
**extra_params,
)
if response.choices and response.choices[0].message:
return self.validate_response(response.choices[0].message.content)
return ""
def stream_invoke(self, system_prompt: str, user_prompt: str, **kwargs) -> Generator[str, None, None]:
"""
流式调用LLM,逐步返回响应内容。
参数:
system_prompt: 系统提示词。
user_prompt: 用户提示词。
**kwargs: 采样参数(temperature、top_p等)。
产出:
str: 每次yield一段delta文本,方便上层实时渲染。
"""
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
]
allowed_keys = {"temperature", "top_p", "presence_penalty", "frequency_penalty"}
extra_params = {key: value for key, value in kwargs.items() if key in allowed_keys and value is not None}
        # Force streaming mode
extra_params["stream"] = True
timeout = kwargs.pop("timeout", self.timeout)
try:
stream = self.client.chat.completions.create(
model=self.model_name,
messages=messages,
timeout=timeout,
**extra_params,
)
for chunk in stream:
if chunk.choices and len(chunk.choices) > 0:
delta = chunk.choices[0].delta
if delta and delta.content:
yield delta.content
        except Exception as e:
            logger.error(f"Streaming request failed: {e}")
            # Bare raise preserves the original traceback.
            raise
@with_retry(LLM_RETRY_CONFIG)
def stream_invoke_to_string(self, system_prompt: str, user_prompt: str, **kwargs) -> str:
"""
流式调用LLM并安全地拼接为完整字符串(避免UTF-8多字节字符截断)。
参数:
system_prompt: 系统提示词。
user_prompt: 用户提示词。
**kwargs: 采样或超时配置。
返回:
str: 将所有delta拼接后的完整响应。
"""
# 以字节形式收集所有块
byte_chunks = []
for chunk in self.stream_invoke(system_prompt, user_prompt, **kwargs):
byte_chunks.append(chunk.encode('utf-8'))
# 拼接所有字节,然后一次性解码
if byte_chunks:
return b''.join(byte_chunks).decode('utf-8', errors='replace')
return ""
@staticmethod
def validate_response(response: Optional[str]) -> str:
"""兜底处理None/空白字符串,防止上层逻辑崩溃"""
if response is None:
return ""
return response.strip()
def get_model_info(self) -> Dict[str, Any]:
"""以字典形式返回当前客户端的模型/提供方/基础URL信息"""
return {
"provider": self.provider,
"model": self.model_name,
"api_base": self.base_url or "default",
}
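
if __name__ == "__main__":
    # Minimal usage sketch (assumption: credentials and endpoint come from the
    # OPENAI_API_KEY / REPORT_ENGINE_MODEL / REPORT_ENGINE_BASE_URL environment
    # variables; these are illustrative names, not part of the library).
    client = LLMClient(
        api_key=os.getenv("OPENAI_API_KEY", ""),
        model_name=os.getenv("REPORT_ENGINE_MODEL", "gpt-4o-mini"),
        base_url=os.getenv("REPORT_ENGINE_BASE_URL"),
    )
    print(client.get_model_info())
    # Non-streaming call; temperature passes through the sampling whitelist.
    print(client.invoke("You are a helpful assistant.", "Say hello.", temperature=0.2))
    # Streaming call rendered incrementally as deltas arrive.
    for piece in client.stream_invoke("You are a helpful assistant.", "Count to three."):
        print(piece, end="", flush=True)
    print()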