Committed by BaiFu

1. Switched the LLM interface to byte-level streaming, to prevent timeout errors and to avoid UTF-8 multi-byte character concatenation errors.
Showing 19 changed files with 315 additions and 40 deletions
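Across all four engines the commit adds the same pair of methods to `LLMClient`: `stream_invoke`, a generator that yields response chunks as they arrive (so long generations keep data flowing instead of idling into a read timeout), and `stream_invoke_to_string`, which gathers the chunks as UTF-8 bytes, joins them, and decodes once. A minimal usage sketch follows; the import path and constructor call are placeholders, since each engine bundles its own configured copy of the client:

# Hypothetical import path -- each engine ships its own copy of LLMClient.
from insight_engine.llm_client import LLMClient

client = LLMClient()  # assumed to read model name, base URL, and API key from config

# Incremental consumption: forward chunks as they arrive.
for chunk in client.stream_invoke("You are a helpful assistant.", "Summarize this report."):
    print(chunk, end="", flush=True)

# One-shot consumption: chunks are encoded to UTF-8 bytes, joined, and decoded
# in a single pass, so a multi-byte character cannot be split during assembly.
full_text = client.stream_invoke_to_string(
    "You are a helpful assistant.",
    "Summarize this report.",
    temperature=0.3,  # optional sampling parameters pass through **kwargs
)
print(full_text)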
@@ -5,7 +5,8 @@ Unified OpenAI-compatible LLM client for the Insight Engine, with retry support.
 import os
 import sys
 from datetime import datetime
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Iterator, Generator
+from loguru import logger

 from openai import OpenAI

@@ -82,6 +83,76 @@ class LLMClient:
             return self.validate_response(response.choices[0].message.content)
         return ""

+    @with_retry(LLM_RETRY_CONFIG)
+    def stream_invoke(self, system_prompt: str, user_prompt: str, **kwargs) -> Generator[str, None, None]:
+        """
+        Invoke the LLM in streaming mode, yielding the response incrementally.
+
+        Args:
+            system_prompt: the system prompt
+            user_prompt: the user prompt
+            **kwargs: extra parameters (temperature, top_p, etc.)
+
+        Yields:
+            Chunks of response text (str)
+        """
+        current_time = datetime.now().strftime("%Y年%m月%d日%H时%M分")
+        time_prefix = f"今天的实际时间是{current_time}"
+        if user_prompt:
+            user_prompt = f"{time_prefix}\n{user_prompt}"
+        else:
+            user_prompt = time_prefix
+        messages = [
+            {"role": "system", "content": system_prompt},
+            {"role": "user", "content": user_prompt},
+        ]
+
+        allowed_keys = {"temperature", "top_p", "presence_penalty", "frequency_penalty"}
+        extra_params = {key: value for key, value in kwargs.items() if key in allowed_keys and value is not None}
+        # Force streaming mode
+        extra_params["stream"] = True
+
+        timeout = kwargs.pop("timeout", self.timeout)
+
+        try:
+            stream = self.client.chat.completions.create(
+                model=self.model_name,
+                messages=messages,
+                timeout=timeout,
+                **extra_params,
+            )
+
+            for chunk in stream:
+                if chunk.choices and len(chunk.choices) > 0:
+                    delta = chunk.choices[0].delta
+                    if delta and delta.content:
+                        yield delta.content
+        except Exception as e:
+            logger.error(f"Streaming request failed: {str(e)}")
+            raise e
+
+    def stream_invoke_to_string(self, system_prompt: str, user_prompt: str, **kwargs) -> str:
+        """
+        Invoke the LLM in streaming mode and safely assemble the complete string (avoids truncating multi-byte UTF-8 characters).
+
+        Args:
+            system_prompt: the system prompt
+            user_prompt: the user prompt
+            **kwargs: extra parameters (temperature, top_p, etc.)
+
+        Returns:
+            The complete response string
+        """
+        # Collect all chunks as bytes
+        byte_chunks = []
+        for chunk in self.stream_invoke(system_prompt, user_prompt, **kwargs):
+            byte_chunks.append(chunk.encode('utf-8'))
+
+        # Join all the bytes, then decode once
+        if byte_chunks:
+            return b''.join(byte_chunks).decode('utf-8', errors='replace')
+        return ""
+
     @staticmethod
     def validate_response(response: Optional[str]) -> str:
         if response is None:
@@ -70,8 +70,8 @@ class ReportFormattingNode(BaseNode):

         logger.info("Formatting the final report")

-        # Invoke the LLM
-        response = self.llm_client.invoke(
+        # Invoke the LLM (streaming, UTF-8-safe assembly)
+        response = self.llm_client.stream_invoke_to_string(
             SYSTEM_PROMPT_REPORT_FORMATTING,
             message,
         )
@@ -51,8 +51,8 @@ class ReportStructureNode(StateMutationNode):
         try:
             logger.info(f"Generating the report structure for query: {self.query}")

-            # Invoke the LLM
-            response = self.llm_client.invoke(SYSTEM_PROMPT_REPORT_STRUCTURE, self.query)
+            # Invoke the LLM (streaming, UTF-8-safe assembly)
+            response = self.llm_client.stream_invoke_to_string(SYSTEM_PROMPT_REPORT_STRUCTURE, self.query)

             # Process the response
             processed_response = self.process_output(response)
@@ -65,8 +65,8 @@ class FirstSearchNode(BaseNode):

         logger.info("Generating the initial search queries")

-        # Invoke the LLM
-        response = self.llm_client.invoke(SYSTEM_PROMPT_FIRST_SEARCH, message)
+        # Invoke the LLM (streaming, UTF-8-safe assembly)
+        response = self.llm_client.stream_invoke_to_string(SYSTEM_PROMPT_FIRST_SEARCH, message)

         # Process the response
         processed_response = self.process_output(response)
@@ -200,8 +200,8 @@ class ReflectionNode(BaseNode):

         logger.info("Reflecting and generating new search queries")

-        # Invoke the LLM
-        response = self.llm_client.invoke(SYSTEM_PROMPT_REFLECTION, message)
+        # Invoke the LLM (streaming, UTF-8-safe assembly)
+        response = self.llm_client.stream_invoke_to_string(SYSTEM_PROMPT_REFLECTION, message)

         # Process the response
         processed_response = self.process_output(response)
@@ -99,8 +99,8 @@ class FirstSummaryNode(StateMutationNode):

         logger.info("Generating the first paragraph summary")

-        # Invoke the LLM
-        response = self.llm_client.invoke(SYSTEM_PROMPT_FIRST_SUMMARY, message)
+        # Invoke the LLM (streaming, UTF-8-safe assembly)
+        response = self.llm_client.stream_invoke_to_string(SYSTEM_PROMPT_FIRST_SUMMARY, message)

         # Process the response
         processed_response = self.process_output(response)
@@ -264,8 +264,8 @@ class ReflectionSummaryNode(StateMutationNode):

         logger.info("Generating the reflection summary")

-        # Invoke the LLM
-        response = self.llm_client.invoke(SYSTEM_PROMPT_REFLECTION_SUMMARY, message)
+        # Invoke the LLM (streaming, UTF-8-safe assembly)
+        response = self.llm_client.stream_invoke_to_string(SYSTEM_PROMPT_REFLECTION_SUMMARY, message)

         # Process the response
         processed_response = self.process_output(response)
@@ -5,7 +5,8 @@ Unified OpenAI-compatible LLM client for the Media Engine, with retry support.
 import os
 import sys
 from datetime import datetime
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Generator
+from loguru import logger

 from openai import OpenAI

@@ -85,6 +86,76 @@ class LLMClient:
             return self.validate_response(response.choices[0].message.content)
         return ""

+    @with_retry(LLM_RETRY_CONFIG)
+    def stream_invoke(self, system_prompt: str, user_prompt: str, **kwargs) -> Generator[str, None, None]:
+        """
+        Invoke the LLM in streaming mode, yielding the response incrementally.
+
+        Args:
+            system_prompt: the system prompt
+            user_prompt: the user prompt
+            **kwargs: extra parameters (temperature, top_p, etc.)
+
+        Yields:
+            Chunks of response text (str)
+        """
+        current_time = datetime.now().strftime("%Y年%m月%d日%H时%M分")
+        time_prefix = f"今天的实际时间是{current_time}"
+        if user_prompt:
+            user_prompt = f"{time_prefix}\n{user_prompt}"
+        else:
+            user_prompt = time_prefix
+        messages = [
+            {"role": "system", "content": system_prompt},
+            {"role": "user", "content": user_prompt},
+        ]
+
+        allowed_keys = {"temperature", "top_p", "presence_penalty", "frequency_penalty"}
+        extra_params = {key: value for key, value in kwargs.items() if key in allowed_keys and value is not None}
+        # Force streaming mode
+        extra_params["stream"] = True
+
+        timeout = kwargs.pop("timeout", self.timeout)
+
+        try:
+            stream = self.client.chat.completions.create(
+                model=self.model_name,
+                messages=messages,
+                timeout=timeout,
+                **extra_params,
+            )
+
+            for chunk in stream:
+                if chunk.choices and len(chunk.choices) > 0:
+                    delta = chunk.choices[0].delta
+                    if delta and delta.content:
+                        yield delta.content
+        except Exception as e:
+            logger.error(f"Streaming request failed: {str(e)}")
+            raise e
+
+    def stream_invoke_to_string(self, system_prompt: str, user_prompt: str, **kwargs) -> str:
+        """
+        Invoke the LLM in streaming mode and safely assemble the complete string (avoids truncating multi-byte UTF-8 characters).
+
+        Args:
+            system_prompt: the system prompt
+            user_prompt: the user prompt
+            **kwargs: extra parameters (temperature, top_p, etc.)
+
+        Returns:
+            The complete response string
+        """
+        # Collect all chunks as bytes
+        byte_chunks = []
+        for chunk in self.stream_invoke(system_prompt, user_prompt, **kwargs):
+            byte_chunks.append(chunk.encode('utf-8'))
+
+        # Join all the bytes, then decode once
+        if byte_chunks:
+            return b''.join(byte_chunks).decode('utf-8', errors='replace')
+        return ""
+
     @staticmethod
     def validate_response(response: Optional[str]) -> str:
         if response is None:
@@ -68,8 +68,8 @@ class ReportFormattingNode(BaseNode):

         logger.info("Formatting the final report")

-        # Invoke the LLM to generate Markdown
-        response = self.llm_client.invoke(
+        # Invoke the LLM to generate Markdown (streaming, UTF-8-safe assembly)
+        response = self.llm_client.stream_invoke_to_string(
             SYSTEM_PROMPT_REPORT_FORMATTING,
             message,
         )
@@ -52,7 +52,7 @@ class ReportStructureNode(StateMutationNode):
             logger.info(f"Generating the report structure for query: {self.query}")

             # Invoke the LLM
-            response = self.llm_client.invoke(SYSTEM_PROMPT_REPORT_STRUCTURE, self.query)
+            response = self.llm_client.stream_invoke_to_string(SYSTEM_PROMPT_REPORT_STRUCTURE, self.query)

             # Process the response
             processed_response = self.process_output(response)
@@ -66,7 +66,7 @@ class FirstSearchNode(BaseNode):
         logger.info("Generating the initial search queries")

         # Invoke the LLM
-        response = self.llm_client.invoke(SYSTEM_PROMPT_FIRST_SEARCH, message)
+        response = self.llm_client.stream_invoke_to_string(SYSTEM_PROMPT_FIRST_SEARCH, message)

         # Process the response
         processed_response = self.process_output(response)
@@ -201,7 +201,7 @@ class ReflectionNode(BaseNode):
         logger.info("Reflecting and generating new search queries")

         # Invoke the LLM
-        response = self.llm_client.invoke(SYSTEM_PROMPT_REFLECTION, message)
+        response = self.llm_client.stream_invoke_to_string(SYSTEM_PROMPT_REFLECTION, message)

         # Process the response
         processed_response = self.process_output(response)
@@ -99,8 +99,8 @@ class FirstSummaryNode(StateMutationNode):

         logger.info("Generating the first paragraph summary")

-        # Invoke the LLM to generate the summary
-        response = self.llm_client.invoke(
+        # Invoke the LLM to generate the summary (streaming, UTF-8-safe assembly)
+        response = self.llm_client.stream_invoke_to_string(
             SYSTEM_PROMPT_FIRST_SUMMARY,
             message,
         )
@@ -267,8 +267,8 @@ class ReflectionSummaryNode(StateMutationNode):

         logger.info("Generating the reflection summary")

-        # Invoke the LLM to generate the summary
-        response = self.llm_client.invoke(
+        # Invoke the LLM to generate the summary (streaming, UTF-8-safe assembly)
+        response = self.llm_client.stream_invoke_to_string(
             SYSTEM_PROMPT_REFLECTION_SUMMARY,
             message,
         )
@@ -5,7 +5,8 @@ Unified OpenAI-compatible LLM client for the Query Engine, with retry support.
 import os
 import sys
 from datetime import datetime
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Generator
+from loguru import logger

 from openai import OpenAI

@@ -82,6 +83,76 @@ class LLMClient:
             return self.validate_response(response.choices[0].message.content)
         return ""

+    @with_retry(LLM_RETRY_CONFIG)
+    def stream_invoke(self, system_prompt: str, user_prompt: str, **kwargs) -> Generator[str, None, None]:
+        """
+        Invoke the LLM in streaming mode, yielding the response incrementally.
+
+        Args:
+            system_prompt: the system prompt
+            user_prompt: the user prompt
+            **kwargs: extra parameters (temperature, top_p, etc.)
+
+        Yields:
+            Chunks of response text (str)
+        """
+        current_time = datetime.now().strftime("%Y年%m月%d日%H时%M分")
+        time_prefix = f"今天的实际时间是{current_time}"
+        if user_prompt:
+            user_prompt = f"{time_prefix}\n{user_prompt}"
+        else:
+            user_prompt = time_prefix
+        messages = [
+            {"role": "system", "content": system_prompt},
+            {"role": "user", "content": user_prompt},
+        ]
+
+        allowed_keys = {"temperature", "top_p", "presence_penalty", "frequency_penalty"}
+        extra_params = {key: value for key, value in kwargs.items() if key in allowed_keys and value is not None}
+        # Force streaming mode
+        extra_params["stream"] = True
+
+        timeout = kwargs.pop("timeout", self.timeout)
+
+        try:
+            stream = self.client.chat.completions.create(
+                model=self.model_name,
+                messages=messages,
+                timeout=timeout,
+                **extra_params,
+            )
+
+            for chunk in stream:
+                if chunk.choices and len(chunk.choices) > 0:
+                    delta = chunk.choices[0].delta
+                    if delta and delta.content:
+                        yield delta.content
+        except Exception as e:
+            logger.error(f"Streaming request failed: {str(e)}")
+            raise e
+
+    def stream_invoke_to_string(self, system_prompt: str, user_prompt: str, **kwargs) -> str:
+        """
+        Invoke the LLM in streaming mode and safely assemble the complete string (avoids truncating multi-byte UTF-8 characters).
+
+        Args:
+            system_prompt: the system prompt
+            user_prompt: the user prompt
+            **kwargs: extra parameters (temperature, top_p, etc.)
+
+        Returns:
+            The complete response string
+        """
+        # Collect all chunks as bytes
+        byte_chunks = []
+        for chunk in self.stream_invoke(system_prompt, user_prompt, **kwargs):
+            byte_chunks.append(chunk.encode('utf-8'))
+
+        # Join all the bytes, then decode once
+        if byte_chunks:
+            return b''.join(byte_chunks).decode('utf-8', errors='replace')
+        return ""
+
     @staticmethod
     def validate_response(response: Optional[str]) -> str:
         if response is None:
@@ -68,8 +68,8 @@ class ReportFormattingNode(BaseNode):

         logger.info("Formatting the final report")

-        # Invoke the LLM to generate Markdown
-        response = self.llm_client.invoke(
+        # Invoke the LLM to generate Markdown (streaming, UTF-8-safe assembly)
+        response = self.llm_client.stream_invoke_to_string(
             SYSTEM_PROMPT_REPORT_FORMATTING,
             message,
         )
@@ -52,7 +52,7 @@ class ReportStructureNode(StateMutationNode):
             logger.info(f"Generating the report structure for query: {self.query}")

             # Invoke the LLM
-            response = self.llm_client.invoke(SYSTEM_PROMPT_REPORT_STRUCTURE, self.query)
+            response = self.llm_client.stream_invoke_to_string(SYSTEM_PROMPT_REPORT_STRUCTURE, self.query)

             # Process the response
             processed_response = self.process_output(response)
@@ -66,7 +66,7 @@ class FirstSearchNode(BaseNode):
         logger.info("Generating the initial search queries")

         # Invoke the LLM
-        response = self.llm_client.invoke(SYSTEM_PROMPT_FIRST_SEARCH, message)
+        response = self.llm_client.stream_invoke_to_string(SYSTEM_PROMPT_FIRST_SEARCH, message)

         # Process the response
         processed_response = self.process_output(response)
@@ -201,7 +201,7 @@ class ReflectionNode(BaseNode):
         logger.info("Reflecting and generating new search queries")

         # Invoke the LLM
-        response = self.llm_client.invoke(SYSTEM_PROMPT_REFLECTION, message)
+        response = self.llm_client.stream_invoke_to_string(SYSTEM_PROMPT_REFLECTION, message)

         # Process the response
         processed_response = self.process_output(response)
@@ -99,8 +99,8 @@ class FirstSummaryNode(StateMutationNode):

         logger.info("Generating the first paragraph summary")

-        # Invoke the LLM to generate the summary
-        response = self.llm_client.invoke(
+        # Invoke the LLM to generate the summary (streaming, UTF-8-safe assembly)
+        response = self.llm_client.stream_invoke_to_string(
             SYSTEM_PROMPT_FIRST_SUMMARY,
             message,
         )
@@ -267,8 +267,8 @@ class ReflectionSummaryNode(StateMutationNode):

         logger.info("Generating the reflection summary")

-        # Invoke the LLM to generate the summary
-        response = self.llm_client.invoke(
+        # Invoke the LLM to generate the summary (streaming, UTF-8-safe assembly)
+        response = self.llm_client.stream_invoke_to_string(
             SYSTEM_PROMPT_REFLECTION_SUMMARY,
             message,
         )
@@ -4,7 +4,8 @@ Unified OpenAI-compatible LLM client for the Report Engine, with retry support.

 import os
 import sys
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Generator
+from loguru import logger

 from openai import OpenAI

@@ -75,6 +76,70 @@ class LLMClient:
             return self.validate_response(response.choices[0].message.content)
         return ""

+    @with_retry(LLM_RETRY_CONFIG)
+    def stream_invoke(self, system_prompt: str, user_prompt: str, **kwargs) -> Generator[str, None, None]:
+        """
+        Invoke the LLM in streaming mode, yielding the response incrementally.
+
+        Args:
+            system_prompt: the system prompt
+            user_prompt: the user prompt
+            **kwargs: extra parameters (temperature, top_p, etc.)
+
+        Yields:
+            Chunks of response text (str)
+        """
+        messages = [
+            {"role": "system", "content": system_prompt},
+            {"role": "user", "content": user_prompt},
+        ]
+
+        allowed_keys = {"temperature", "top_p", "presence_penalty", "frequency_penalty"}
+        extra_params = {key: value for key, value in kwargs.items() if key in allowed_keys and value is not None}
+        # Force streaming mode
+        extra_params["stream"] = True
+
+        timeout = kwargs.pop("timeout", self.timeout)
+
+        try:
+            stream = self.client.chat.completions.create(
+                model=self.model_name,
+                messages=messages,
+                timeout=timeout,
+                **extra_params,
+            )
+
+            for chunk in stream:
+                if chunk.choices and len(chunk.choices) > 0:
+                    delta = chunk.choices[0].delta
+                    if delta and delta.content:
+                        yield delta.content
+        except Exception as e:
+            logger.error(f"Streaming request failed: {str(e)}")
+            raise e
+
+    def stream_invoke_to_string(self, system_prompt: str, user_prompt: str, **kwargs) -> str:
+        """
+        Invoke the LLM in streaming mode and safely assemble the complete string (avoids truncating multi-byte UTF-8 characters).
+
+        Args:
+            system_prompt: the system prompt
+            user_prompt: the user prompt
+            **kwargs: extra parameters (temperature, top_p, etc.)
+
+        Returns:
+            The complete response string
+        """
+        # Collect all chunks as bytes
+        byte_chunks = []
+        for chunk in self.stream_invoke(system_prompt, user_prompt, **kwargs):
+            byte_chunks.append(chunk.encode('utf-8'))
+
+        # Join all the bytes, then decode once
+        if byte_chunks:
+            return b''.join(byte_chunks).decode('utf-8', errors='replace')
+        return ""
+
     @staticmethod
     def validate_response(response: Optional[str]) -> str:
         if response is None:
@@ -60,7 +60,7 @@ class HTMLGenerationNode(StateMutationNode):
         message = json.dumps(llm_input, ensure_ascii=False, indent=2)

         # Invoke the LLM to generate HTML
-        response = self.llm_client.invoke(SYSTEM_PROMPT_HTML_GENERATION, message)
+        response = self.llm_client.stream_invoke_to_string(SYSTEM_PROMPT_HTML_GENERATION, message)

         # Process the response (simplified)
         processed_response = self.process_output(response)
@@ -115,7 +115,7 @@ class TemplateSelectionNode(BaseNode):
 请根据查询内容、报告内容和论坛日志的具体情况,选择最合适的模板。"""

         # Invoke the LLM
-        response = self.llm_client.invoke(SYSTEM_PROMPT_TEMPLATE_SELECTION, user_message)
+        response = self.llm_client.stream_invoke_to_string(SYSTEM_PROMPT_TEMPLATE_SELECTION, user_message)

         # Check whether the response is empty
         if not response or not response.strip():
@@ -6,10 +6,7 @@ Forum log reading utilities
 import re
 from pathlib import Path
 from typing import Optional, List, Dict
-import logging
-
-logger = logging.getLogger(__name__)
-
+from loguru import logger

 def get_latest_host_speech(log_dir: str = "logs") -> Optional[str]:
     """