戒酒的李白

JSON parsing fix.

... ... @@ -13,7 +13,8 @@ from ..prompts import SYSTEM_PROMPT_REPORT_STRUCTURE
from ..utils.text_processing import (
remove_reasoning_from_output,
clean_json_tags,
extract_clean_response
extract_clean_response,
fix_incomplete_json
)
... ... @@ -77,48 +78,91 @@ class ReportStructureNode(StateMutationNode):
cleaned_output = remove_reasoning_from_output(output)
cleaned_output = clean_json_tags(cleaned_output)
# 记录清理后的输出用于调试
self.log_info(f"清理后的输出: {cleaned_output[:200]}...")
# 解析JSON
try:
report_structure = json.loads(cleaned_output)
except JSONDecodeError:
self.log_info("JSON解析成功")
except JSONDecodeError as e:
self.log_info(f"JSON解析失败: {str(e)}")
# 使用更强大的提取方法
report_structure = extract_clean_response(cleaned_output)
if "error" in report_structure:
raise ValueError("JSON解析失败")
self.log_error("JSON解析失败,尝试修复...")
# 尝试修复JSON
fixed_json = fix_incomplete_json(cleaned_output)
if fixed_json:
try:
report_structure = json.loads(fixed_json)
self.log_info("JSON修复成功")
except JSONDecodeError:
self.log_error("JSON修复失败")
# 返回默认结构
return self._generate_default_structure()
else:
self.log_error("无法修复JSON,使用默认结构")
return self._generate_default_structure()
# 验证结构
if not isinstance(report_structure, list):
raise ValueError("报告结构应该是一个列表")
self.log_info("报告结构不是列表,尝试转换...")
if isinstance(report_structure, dict):
# 如果是单个对象,包装成列表
report_structure = [report_structure]
else:
self.log_error("报告结构格式无效,使用默认结构")
return self._generate_default_structure()
# 验证每个段落
validated_structure = []
for i, paragraph in enumerate(report_structure):
if not isinstance(paragraph, dict):
self.log_warning(f"段落 {i+1} 不是字典格式,跳过")
continue
title = paragraph.get("title", f"段落 {i+1}")
content = paragraph.get("content", "")
if not title or not content:
self.log_warning(f"段落 {i+1} 缺少标题或内容,跳过")
continue
validated_structure.append({
"title": title,
"content": content
})
if not validated_structure:
self.log_warning("没有有效的段落结构,使用默认结构")
return self._generate_default_structure()
self.log_info(f"成功验证 {len(validated_structure)} 个段落结构")
return validated_structure
except Exception as e:
self.log_error(f"处理输出失败: {str(e)}")
# 返回默认结构
return [
{
"title": "概述",
"content": f"对'{self.query}'的总体概述和背景介绍"
},
{
"title": "详细分析",
"content": f"深入分析'{self.query}'的相关内容"
}
]
return self._generate_default_structure()
def _generate_default_structure(self) -> List[Dict[str, str]]:
"""
生成默认的报告结构
Returns:
默认的报告结构列表
"""
self.log_info("生成默认报告结构")
return [
{
"title": "研究概述",
"content": "对查询主题进行总体概述和分析"
},
{
"title": "深度分析",
"content": "深入分析查询主题的各个方面"
}
]
def mutate_state(self, input_data: Any = None, state: State = None, **kwargs) -> State:
"""
... ...
... ... @@ -12,7 +12,8 @@ from ..prompts import SYSTEM_PROMPT_FIRST_SEARCH, SYSTEM_PROMPT_REFLECTION
from ..utils.text_processing import (
remove_reasoning_from_output,
clean_json_tags,
extract_clean_response
extract_clean_response,
fix_incomplete_json
)
... ... @@ -91,21 +92,40 @@ class FirstSearchNode(BaseNode):
cleaned_output = remove_reasoning_from_output(output)
cleaned_output = clean_json_tags(cleaned_output)
# 记录清理后的输出用于调试
self.log_info(f"清理后的输出: {cleaned_output[:200]}...")
# 解析JSON
try:
result = json.loads(cleaned_output)
except JSONDecodeError:
self.log_info("JSON解析成功")
except JSONDecodeError as e:
self.log_info(f"JSON解析失败: {str(e)}")
# 使用更强大的提取方法
result = extract_clean_response(cleaned_output)
if "error" in result:
raise ValueError("JSON解析失败")
self.log_error("JSON解析失败,尝试修复...")
# 尝试修复JSON
fixed_json = fix_incomplete_json(cleaned_output)
if fixed_json:
try:
result = json.loads(fixed_json)
self.log_info("JSON修复成功")
except JSONDecodeError:
self.log_error("JSON修复失败")
# 返回默认查询
return self._get_default_search_query()
else:
self.log_error("无法修复JSON,使用默认查询")
return self._get_default_search_query()
# 验证和清理结果
search_query = result.get("search_query", "")
reasoning = result.get("reasoning", "")
if not search_query:
raise ValueError("未找到搜索查询")
self.log_warning("未找到搜索查询,使用默认查询")
return self._get_default_search_query()
return {
"search_query": search_query,
... ... @@ -115,10 +135,19 @@ class FirstSearchNode(BaseNode):
except Exception as e:
self.log_error(f"处理输出失败: {str(e)}")
# 返回默认查询
return {
"search_query": "相关主题研究",
"reasoning": "由于解析失败,使用默认搜索查询"
}
return self._get_default_search_query()
def _get_default_search_query(self) -> Dict[str, str]:
"""
获取默认搜索查询
Returns:
默认的搜索查询字典
"""
return {
"search_query": "相关主题研究",
"reasoning": "由于解析失败,使用默认搜索查询"
}
class ReflectionNode(BaseNode):
... ... @@ -198,21 +227,40 @@ class ReflectionNode(BaseNode):
cleaned_output = remove_reasoning_from_output(output)
cleaned_output = clean_json_tags(cleaned_output)
# 记录清理后的输出用于调试
self.log_info(f"清理后的输出: {cleaned_output[:200]}...")
# 解析JSON
try:
result = json.loads(cleaned_output)
except JSONDecodeError:
self.log_info("JSON解析成功")
except JSONDecodeError as e:
self.log_info(f"JSON解析失败: {str(e)}")
# 使用更强大的提取方法
result = extract_clean_response(cleaned_output)
if "error" in result:
raise ValueError("JSON解析失败")
self.log_error("JSON解析失败,尝试修复...")
# 尝试修复JSON
fixed_json = fix_incomplete_json(cleaned_output)
if fixed_json:
try:
result = json.loads(fixed_json)
self.log_info("JSON修复成功")
except JSONDecodeError:
self.log_error("JSON修复失败")
# 返回默认查询
return self._get_default_reflection_query()
else:
self.log_error("无法修复JSON,使用默认查询")
return self._get_default_reflection_query()
# 验证和清理结果
search_query = result.get("search_query", "")
reasoning = result.get("reasoning", "")
if not search_query:
raise ValueError("未找到搜索查询")
self.log_warning("未找到搜索查询,使用默认查询")
return self._get_default_reflection_query()
return {
"search_query": search_query,
... ... @@ -222,7 +270,16 @@ class ReflectionNode(BaseNode):
except Exception as e:
self.log_error(f"处理输出失败: {str(e)}")
# 返回默认查询
return {
"search_query": "深度研究补充信息",
"reasoning": "由于解析失败,使用默认反思搜索查询"
}
return self._get_default_reflection_query()
def _get_default_reflection_query(self) -> Dict[str, str]:
"""
获取默认反思搜索查询
Returns:
默认的反思搜索查询字典
"""
return {
"search_query": "深度研究补充信息",
"reasoning": "由于解析失败,使用默认反思搜索查询"
}
... ...
... ... @@ -14,6 +14,7 @@ from ..utils.text_processing import (
remove_reasoning_from_output,
clean_json_tags,
extract_clean_response,
fix_incomplete_json,
format_search_results_for_prompt
)
... ... @@ -82,25 +83,42 @@ class FirstSummaryNode(StateMutationNode):
def process_output(self, output: str) -> str:
"""
处理LLM输出,提取段落总结
处理LLM输出,提取段落内容
Args:
output: LLM原始输出
Returns:
段落总结内容
段落内容
"""
try:
# 清理响应文本
cleaned_output = remove_reasoning_from_output(output)
cleaned_output = clean_json_tags(cleaned_output)
# 记录清理后的输出用于调试
self.log_info(f"清理后的输出: {cleaned_output[:200]}...")
# 解析JSON
try:
result = json.loads(cleaned_output)
except JSONDecodeError:
# 如果不是JSON格式,直接返回清理后的文本
return cleaned_output
self.log_info("JSON解析成功")
except JSONDecodeError as e:
self.log_info(f"JSON解析失败: {str(e)}")
# 尝试修复JSON
fixed_json = fix_incomplete_json(cleaned_output)
if fixed_json:
try:
result = json.loads(fixed_json)
self.log_info("JSON修复成功")
except JSONDecodeError:
self.log_info("JSON修复失败,直接使用清理后的文本")
# 如果不是JSON格式,直接返回清理后的文本
return cleaned_output
else:
self.log_info("无法修复JSON,直接使用清理后的文本")
# 如果不是JSON格式,直接返回清理后的文本
return cleaned_output
# 提取段落内容
if isinstance(result, dict):
... ... @@ -224,12 +242,29 @@ class ReflectionSummaryNode(StateMutationNode):
cleaned_output = remove_reasoning_from_output(output)
cleaned_output = clean_json_tags(cleaned_output)
# 记录清理后的输出用于调试
self.log_info(f"清理后的输出: {cleaned_output[:200]}...")
# 解析JSON
try:
result = json.loads(cleaned_output)
except JSONDecodeError:
# 如果不是JSON格式,直接返回清理后的文本
return cleaned_output
self.log_info("JSON解析成功")
except JSONDecodeError as e:
self.log_info(f"JSON解析失败: {str(e)}")
# 尝试修复JSON
fixed_json = fix_incomplete_json(cleaned_output)
if fixed_json:
try:
result = json.loads(fixed_json)
self.log_info("JSON修复成功")
except JSONDecodeError:
self.log_info("JSON修复失败,直接使用清理后的文本")
# 如果不是JSON格式,直接返回清理后的文本
return cleaned_output
else:
self.log_info("无法修复JSON,直接使用清理后的文本")
# 如果不是JSON格式,直接返回清理后的文本
return cleaned_output
# 提取更新后的段落内容
if isinstance(result, dict):
... ...
... ... @@ -55,6 +55,20 @@ def remove_reasoning_from_output(text: str) -> str:
Returns:
清理后的文本
"""
# 查找JSON开始位置
json_start = -1
# 尝试找到第一个 { 或 [
for i, char in enumerate(text):
if char in '{[':
json_start = i
break
if json_start != -1:
# 从JSON开始位置截取
return text[json_start:].strip()
# 如果没有找到JSON标记,尝试其他方法
# 移除常见的推理标识
patterns = [
r'(?:reasoning|推理|思考|分析)[::]\s*.*?(?=\{|\[)', # 移除推理部分
... ... @@ -88,6 +102,14 @@ def extract_clean_response(text: str) -> Dict[str, Any]:
except JSONDecodeError:
pass
# 尝试修复不完整的JSON
fixed_text = fix_incomplete_json(cleaned_text)
if fixed_text:
try:
return json.loads(fixed_text)
except JSONDecodeError:
pass
# 尝试查找JSON对象
json_pattern = r'\{.*\}'
match = re.search(json_pattern, cleaned_text, re.DOTALL)
... ... @@ -111,6 +133,92 @@ def extract_clean_response(text: str) -> Dict[str, Any]:
return {"error": "JSON解析失败", "raw_text": cleaned_text}
def fix_incomplete_json(text: str) -> str:
"""
修复不完整的JSON响应
Args:
text: 原始文本
Returns:
修复后的JSON文本,如果无法修复则返回空字符串
"""
# 移除多余的逗号和空白
text = re.sub(r',\s*}', '}', text)
text = re.sub(r',\s*]', ']', text)
# 检查是否已经是有效的JSON
try:
json.loads(text)
return text
except JSONDecodeError:
pass
# 检查是否缺少开头的数组符号
if text.strip().startswith('{') and not text.strip().startswith('['):
# 如果以对象开始,尝试包装成数组
if text.count('{') > 1:
# 多个对象,包装成数组
text = '[' + text + ']'
else:
# 单个对象,包装成数组
text = '[' + text + ']'
# 检查是否缺少结尾的数组符号
if text.strip().endswith('}') and not text.strip().endswith(']'):
# 如果以对象结束,尝试包装成数组
if text.count('}') > 1:
# 多个对象,包装成数组
text = '[' + text + ']'
else:
# 单个对象,包装成数组
text = '[' + text + ']'
# 检查括号是否匹配
open_braces = text.count('{')
close_braces = text.count('}')
open_brackets = text.count('[')
close_brackets = text.count(']')
# 修复不匹配的括号
if open_braces > close_braces:
text += '}' * (open_braces - close_braces)
if open_brackets > close_brackets:
text += ']' * (open_brackets - close_brackets)
# 验证修复后的JSON是否有效
try:
json.loads(text)
return text
except JSONDecodeError:
# 如果仍然无效,尝试更激进的修复
return fix_aggressive_json(text)
def fix_aggressive_json(text: str) -> str:
"""
更激进的JSON修复方法
Args:
text: 原始文本
Returns:
修复后的JSON文本
"""
# 查找所有可能的JSON对象
objects = re.findall(r'\{[^{}]*\}', text)
if len(objects) >= 2:
# 如果有多个对象,包装成数组
return '[' + ','.join(objects) + ']'
elif len(objects) == 1:
# 如果只有一个对象,包装成数组
return '[' + objects[0] + ']'
else:
# 如果没有找到对象,返回空数组
return '[]'
def update_state_with_search_results(search_results: List[Dict[str, Any]],
paragraph_index: int, state: Any) -> Any:
"""
... ...
... ... @@ -13,7 +13,7 @@ import json
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '.'))
from src import DeepSearchAgent, Config
from config import DEEPSEEK_API_KEY, DEEPSEEK_API_KEY_2, TAVILY_API_KEY
from config import DEEPSEEK_API_KEY, TAVILY_API_KEY
def main():
... ...