Showing 33 changed files with 919 additions and 1200 deletions.
Too many changes to show.
To preserve performance only 33 of 33+ files are displayed.
| @@ -63,5 +63,7 @@ KEYWORD_OPTIMIZER_MODEL_NAME= | @@ -63,5 +63,7 @@ KEYWORD_OPTIMIZER_MODEL_NAME= | ||
| 63 | # ================== 网络工具配置 ==================== | 63 | # ================== 网络工具配置 ==================== |
| 64 | # Tavily API密钥,用于Tavily网络搜索。注册地址:https://www.tavily.com/ | 64 | # Tavily API密钥,用于Tavily网络搜索。注册地址:https://www.tavily.com/ |
| 65 | TAVILY_API_KEY= | 65 | TAVILY_API_KEY= |
| 66 | +# Bocha Web/AI Search BASEURL,用于Bocha搜索。注册地址:https://open.bochaai.com/ | ||
| 67 | +BOCHA_BASE_URL= | ||
| 66 | # Bocha Web Search API密钥,用于Bocha搜索。注册地址:https://open.bochaai.com/ | 68 | # Bocha Web Search API密钥,用于Bocha搜索。注册地址:https://open.bochaai.com/ |
| 67 | BOCHA_WEB_SEARCH_API_KEY= | 69 | BOCHA_WEB_SEARCH_API_KEY= |
| @@ -11,13 +11,14 @@ import re | @@ -11,13 +11,14 @@ import re | ||
| 11 | import json | 11 | import json |
| 12 | from typing import Dict, Optional, List | 12 | from typing import Dict, Optional, List |
| 13 | from threading import Lock | 13 | from threading import Lock |
| 14 | +from loguru import logger | ||
| 14 | 15 | ||
| 15 | # 导入论坛主持人模块 | 16 | # 导入论坛主持人模块 |
| 16 | try: | 17 | try: |
| 17 | from .llm_host import generate_host_speech | 18 | from .llm_host import generate_host_speech |
| 18 | HOST_AVAILABLE = True | 19 | HOST_AVAILABLE = True |
| 19 | except ImportError: | 20 | except ImportError: |
| 20 | - print("ForumEngine: 论坛主持人模块未找到,将以纯监控模式运行") | 21 | + logger.warning("ForumEngine: 论坛主持人模块未找到,将以纯监控模式运行") |
| 21 | HOST_AVAILABLE = False | 22 | HOST_AVAILABLE = False |
| 22 | 23 | ||
| 23 | class LogMonitor: | 24 | class LogMonitor: |
| @@ -76,7 +77,7 @@ class LogMonitor: | @@ -76,7 +77,7 @@ class LogMonitor: | ||
| 76 | pass # 先创建空文件 | 77 | pass # 先创建空文件 |
| 77 | self.write_to_forum_log(f"=== ForumEngine 监控开始 - {start_time} ===", "SYSTEM") | 78 | self.write_to_forum_log(f"=== ForumEngine 监控开始 - {start_time} ===", "SYSTEM") |
| 78 | 79 | ||
| 79 | - print(f"ForumEngine: forum.log 已清空并初始化") | 80 | + logger.info(f"ForumEngine: forum.log 已清空并初始化") |
| 80 | 81 | ||
| 81 | # 重置JSON捕获状态 | 82 | # 重置JSON捕获状态 |
| 82 | self.capturing_json = {} | 83 | self.capturing_json = {} |
| @@ -88,7 +89,7 @@ class LogMonitor: | @@ -88,7 +89,7 @@ class LogMonitor: | ||
| 88 | self.is_host_generating = False | 89 | self.is_host_generating = False |
| 89 | 90 | ||
| 90 | except Exception as e: | 91 | except Exception as e: |
| 91 | - print(f"ForumEngine: 清空forum.log失败: {e}") | 92 | + logger.exception(f"ForumEngine: 清空forum.log失败: {e}") |
| 92 | 93 | ||
| 93 | def write_to_forum_log(self, content: str, source: str = None): | 94 | def write_to_forum_log(self, content: str, source: str = None): |
| 94 | """写入内容到forum.log(线程安全)""" | 95 | """写入内容到forum.log(线程安全)""" |
| @@ -105,7 +106,7 @@ class LogMonitor: | @@ -105,7 +106,7 @@ class LogMonitor: | ||
| 105 | f.write(f"[{timestamp}] {content_one_line}\n") | 106 | f.write(f"[{timestamp}] {content_one_line}\n") |
| 106 | f.flush() | 107 | f.flush() |
| 107 | except Exception as e: | 108 | except Exception as e: |
| 108 | - print(f"ForumEngine: 写入forum.log失败: {e}") | 109 | + logger.exception(f"ForumEngine: 写入forum.log失败: {e}") |
| 109 | 110 | ||
| 110 | def is_target_log_line(self, line: str) -> bool: | 111 | def is_target_log_line(self, line: str) -> bool: |
| 111 | """检查是否是目标日志行(SummaryNode)""" | 112 | """检查是否是目标日志行(SummaryNode)""" |
| @@ -241,7 +242,7 @@ class LogMonitor: | @@ -241,7 +242,7 @@ class LogMonitor: | ||
| 241 | return f"清理后的输出: {json.dumps(json_obj, ensure_ascii=False, indent=2)}" | 242 | return f"清理后的输出: {json.dumps(json_obj, ensure_ascii=False, indent=2)}" |
| 242 | 243 | ||
| 243 | except Exception as e: | 244 | except Exception as e: |
| 244 | - print(f"ForumEngine: 格式化JSON时出错: {e}") | 245 | + logger.exception(f"ForumEngine: 格式化JSON时出错: {e}") |
| 245 | return f"清理后的输出: {json.dumps(json_obj, ensure_ascii=False, indent=2)}" | 246 | return f"清理后的输出: {json.dumps(json_obj, ensure_ascii=False, indent=2)}" |
| 246 | 247 | ||
| 247 | def extract_node_content(self, line: str) -> Optional[str]: | 248 | def extract_node_content(self, line: str) -> Optional[str]: |
| @@ -331,7 +332,7 @@ class LogMonitor: | @@ -331,7 +332,7 @@ class LogMonitor: | ||
| 331 | new_lines = [line.strip() for line in new_lines if line.strip()] | 332 | new_lines = [line.strip() for line in new_lines if line.strip()] |
| 332 | 333 | ||
| 333 | except Exception as e: | 334 | except Exception as e: |
| 334 | - print(f"ForumEngine: 读取{app_name}日志失败: {e}") | 335 | + logger.exception(f"ForumEngine: 读取{app_name}日志失败: {e}") |
| 335 | 336 | ||
| 336 | return new_lines | 337 | return new_lines |
| 337 | 338 | ||
| @@ -406,7 +407,7 @@ class LogMonitor: | @@ -406,7 +407,7 @@ class LogMonitor: | ||
| 406 | self.is_host_generating = False | 407 | self.is_host_generating = False |
| 407 | return | 408 | return |
| 408 | 409 | ||
| 409 | - print("ForumEngine: 正在生成主持人发言...") | 410 | + logger.info("ForumEngine: 正在生成主持人发言...") |
| 410 | 411 | ||
| 411 | # 调用主持人生成发言(传入最近5条) | 412 | # 调用主持人生成发言(传入最近5条) |
| 412 | host_speech = generate_host_speech(recent_speeches) | 413 | host_speech = generate_host_speech(recent_speeches) |
| @@ -414,18 +415,18 @@ class LogMonitor: | @@ -414,18 +415,18 @@ class LogMonitor: | ||
| 414 | if host_speech: | 415 | if host_speech: |
| 415 | # 写入主持人发言到forum.log | 416 | # 写入主持人发言到forum.log |
| 416 | self.write_to_forum_log(host_speech, "HOST") | 417 | self.write_to_forum_log(host_speech, "HOST") |
| 417 | - print(f"ForumEngine: 主持人发言已记录") | 418 | + logger.info(f"ForumEngine: 主持人发言已记录") |
| 418 | 419 | ||
| 419 | # 清空已处理的5条发言 | 420 | # 清空已处理的5条发言 |
| 420 | self.agent_speeches_buffer = self.agent_speeches_buffer[5:] | 421 | self.agent_speeches_buffer = self.agent_speeches_buffer[5:] |
| 421 | else: | 422 | else: |
| 422 | - print("ForumEngine: 主持人发言生成失败") | 423 | + logger.error("ForumEngine: 主持人发言生成失败") |
| 423 | 424 | ||
| 424 | # 重置生成标志 | 425 | # 重置生成标志 |
| 425 | self.is_host_generating = False | 426 | self.is_host_generating = False |
| 426 | 427 | ||
| 427 | except Exception as e: | 428 | except Exception as e: |
| 428 | - print(f"ForumEngine: 触发主持人发言时出错: {e}") | 429 | + logger.exception(f"ForumEngine: 触发主持人发言时出错: {e}") |
| 429 | self.is_host_generating = False | 430 | self.is_host_generating = False |
| 430 | 431 | ||
| 431 | def _clean_content_tags(self, content: str, app_name: str) -> str: | 432 | def _clean_content_tags(self, content: str, app_name: str) -> str: |
| @@ -453,7 +454,7 @@ class LogMonitor: | @@ -453,7 +454,7 @@ class LogMonitor: | ||
| 453 | 454 | ||
| 454 | def monitor_logs(self): | 455 | def monitor_logs(self): |
| 455 | """智能监控日志文件""" | 456 | """智能监控日志文件""" |
| 456 | - print("ForumEngine: 论坛创建中...") | 457 | + logger.info("ForumEngine: 论坛创建中...") |
| 457 | 458 | ||
| 458 | # 初始化文件行数和位置 - 记录当前状态作为基线 | 459 | # 初始化文件行数和位置 - 记录当前状态作为基线 |
| 459 | for app_name, log_file in self.monitored_logs.items(): | 460 | for app_name, log_file in self.monitored_logs.items(): |
| @@ -461,7 +462,7 @@ class LogMonitor: | @@ -461,7 +462,7 @@ class LogMonitor: | ||
| 461 | self.file_positions[app_name] = self.get_file_size(log_file) | 462 | self.file_positions[app_name] = self.get_file_size(log_file) |
| 462 | self.capturing_json[app_name] = False | 463 | self.capturing_json[app_name] = False |
| 463 | self.json_buffer[app_name] = [] | 464 | self.json_buffer[app_name] = [] |
| 464 | - # print(f"ForumEngine: {app_name} 基线行数: {self.file_line_counts[app_name]}") | 465 | + # logger.info(f"ForumEngine: {app_name} 基线行数: {self.file_line_counts[app_name]}") |
| 465 | 466 | ||
| 466 | while self.is_monitoring: | 467 | while self.is_monitoring: |
| 467 | try: | 468 | try: |
| @@ -484,7 +485,7 @@ class LogMonitor: | @@ -484,7 +485,7 @@ class LogMonitor: | ||
| 484 | if not self.is_searching: | 485 | if not self.is_searching: |
| 485 | for line in new_lines: | 486 | for line in new_lines: |
| 486 | if line.strip() and 'FirstSummaryNode' in line: | 487 | if line.strip() and 'FirstSummaryNode' in line: |
| 487 | - print(f"ForumEngine: 在{app_name}中检测到第一次论坛发表内容") | 488 | + logger.info(f"ForumEngine: 在{app_name}中检测到第一次论坛发表内容") |
| 488 | self.is_searching = True | 489 | self.is_searching = True |
| 489 | self.search_inactive_count = 0 | 490 | self.search_inactive_count = 0 |
| 490 | # 清空forum.log开始新会话 | 491 | # 清空forum.log开始新会话 |
| @@ -500,7 +501,7 @@ class LogMonitor: | @@ -500,7 +501,7 @@ class LogMonitor: | ||
| 500 | # 将app_name转换为大写作为标签(如 insight -> INSIGHT) | 501 | # 将app_name转换为大写作为标签(如 insight -> INSIGHT) |
| 501 | source_tag = app_name.upper() | 502 | source_tag = app_name.upper() |
| 502 | self.write_to_forum_log(content, source_tag) | 503 | self.write_to_forum_log(content, source_tag) |
| 503 | - # print(f"ForumEngine: 捕获 - {content}") | 504 | + # logger.info(f"ForumEngine: 捕获 - {content}") |
| 504 | captured_any = True | 505 | captured_any = True |
| 505 | 506 | ||
| 506 | # 将发言添加到缓冲区(格式化为完整的日志行) | 507 | # 将发言添加到缓冲区(格式化为完整的日志行) |
| @@ -515,7 +516,7 @@ class LogMonitor: | @@ -515,7 +516,7 @@ class LogMonitor: | ||
| 515 | 516 | ||
| 516 | elif current_lines < previous_lines: | 517 | elif current_lines < previous_lines: |
| 517 | any_shrink = True | 518 | any_shrink = True |
| 518 | - # print(f"ForumEngine: 检测到 {app_name} 日志缩短,将重置基线") | 519 | + # logger.info(f"ForumEngine: 检测到 {app_name} 日志缩短,将重置基线") |
| 519 | # 重置文件位置到新的文件末尾 | 520 | # 重置文件位置到新的文件末尾 |
| 520 | self.file_positions[app_name] = self.get_file_size(log_file) | 521 | self.file_positions[app_name] = self.get_file_size(log_file) |
| 521 | # 重置JSON捕获状态 | 522 | # 重置JSON捕获状态 |
| @@ -529,7 +530,7 @@ class LogMonitor: | @@ -529,7 +530,7 @@ class LogMonitor: | ||
| 529 | if self.is_searching: | 530 | if self.is_searching: |
| 530 | if any_shrink: | 531 | if any_shrink: |
| 531 | # log变短,结束当前搜索会话,重置为等待状态 | 532 | # log变短,结束当前搜索会话,重置为等待状态 |
| 532 | - # print("ForumEngine: 日志缩短,结束当前搜索会话,回到等待状态") | 533 | + # logger.info("ForumEngine: 日志缩短,结束当前搜索会话,回到等待状态") |
| 533 | self.is_searching = False | 534 | self.is_searching = False |
| 534 | self.search_inactive_count = 0 | 535 | self.search_inactive_count = 0 |
| 535 | # 重置主持人相关状态 | 536 | # 重置主持人相关状态 |
| @@ -538,12 +539,12 @@ class LogMonitor: | @@ -538,12 +539,12 @@ class LogMonitor: | ||
| 538 | # 写入结束标记 | 539 | # 写入结束标记 |
| 539 | end_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S') | 540 | end_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S') |
| 540 | self.write_to_forum_log(f"=== ForumEngine 论坛结束 - {end_time} ===", "SYSTEM") | 541 | self.write_to_forum_log(f"=== ForumEngine 论坛结束 - {end_time} ===", "SYSTEM") |
| 541 | - # print("ForumEngine: 已重置基线,等待下次FirstSummaryNode触发") | 542 | + # logger.info("ForumEngine: 已重置基线,等待下次FirstSummaryNode触发") |
| 542 | elif not any_growth and not captured_any: | 543 | elif not any_growth and not captured_any: |
| 543 | # 没有增长也没有捕获内容,增加非活跃计数 | 544 | # 没有增长也没有捕获内容,增加非活跃计数 |
| 544 | self.search_inactive_count += 1 | 545 | self.search_inactive_count += 1 |
| 545 | if self.search_inactive_count >= 900: # 15分钟无活动才结束 | 546 | if self.search_inactive_count >= 900: # 15分钟无活动才结束 |
| 546 | - print("ForumEngine: 长时间无活动,结束论坛") | 547 | + logger.info("ForumEngine: 长时间无活动,结束论坛") |
| 547 | self.is_searching = False | 548 | self.is_searching = False |
| 548 | self.search_inactive_count = 0 | 549 | self.search_inactive_count = 0 |
| 549 | # 重置主持人相关状态 | 550 | # 重置主持人相关状态 |
| @@ -559,17 +560,17 @@ class LogMonitor: | @@ -559,17 +560,17 @@ class LogMonitor: | ||
| 559 | time.sleep(1) | 560 | time.sleep(1) |
| 560 | 561 | ||
| 561 | except Exception as e: | 562 | except Exception as e: |
| 562 | - print(f"ForumEngine: 论坛记录中出错: {e}") | 563 | + logger.exception(f"ForumEngine: 论坛记录中出错: {e}") |
| 563 | import traceback | 564 | import traceback |
| 564 | traceback.print_exc() | 565 | traceback.print_exc() |
| 565 | time.sleep(2) | 566 | time.sleep(2) |
| 566 | 567 | ||
| 567 | - print("ForumEngine: 停止论坛日志文件") | 568 | + logger.info("ForumEngine: 停止论坛日志文件") |
| 568 | 569 | ||
| 569 | def start_monitoring(self): | 570 | def start_monitoring(self): |
| 570 | """开始智能监控""" | 571 | """开始智能监控""" |
| 571 | if self.is_monitoring: | 572 | if self.is_monitoring: |
| 572 | - print("ForumEngine: 论坛已经在运行中") | 573 | + logger.info("ForumEngine: 论坛已经在运行中") |
| 573 | return False | 574 | return False |
| 574 | 575 | ||
| 575 | try: | 576 | try: |
| @@ -578,18 +579,18 @@ class LogMonitor: | @@ -578,18 +579,18 @@ class LogMonitor: | ||
| 578 | self.monitor_thread = threading.Thread(target=self.monitor_logs, daemon=True) | 579 | self.monitor_thread = threading.Thread(target=self.monitor_logs, daemon=True) |
| 579 | self.monitor_thread.start() | 580 | self.monitor_thread.start() |
| 580 | 581 | ||
| 581 | - print("ForumEngine: 论坛已启动") | 582 | + logger.info("ForumEngine: 论坛已启动") |
| 582 | return True | 583 | return True |
| 583 | 584 | ||
| 584 | except Exception as e: | 585 | except Exception as e: |
| 585 | - print(f"ForumEngine: 启动论坛失败: {e}") | 586 | + logger.exception(f"ForumEngine: 启动论坛失败: {e}") |
| 586 | self.is_monitoring = False | 587 | self.is_monitoring = False |
| 587 | return False | 588 | return False |
| 588 | 589 | ||
| 589 | def stop_monitoring(self): | 590 | def stop_monitoring(self): |
| 590 | """停止监控""" | 591 | """停止监控""" |
| 591 | if not self.is_monitoring: | 592 | if not self.is_monitoring: |
| 592 | - print("ForumEngine: 论坛未运行") | 593 | + logger.info("ForumEngine: 论坛未运行") |
| 593 | return | 594 | return |
| 594 | 595 | ||
| 595 | try: | 596 | try: |
| @@ -602,10 +603,10 @@ class LogMonitor: | @@ -602,10 +603,10 @@ class LogMonitor: | ||
| 602 | end_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S') | 603 | end_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S') |
| 603 | self.write_to_forum_log(f"=== ForumEngine 论坛结束 - {end_time} ===", "SYSTEM") | 604 | self.write_to_forum_log(f"=== ForumEngine 论坛结束 - {end_time} ===", "SYSTEM") |
| 604 | 605 | ||
| 605 | - print("ForumEngine: 论坛已停止") | 606 | + logger.info("ForumEngine: 论坛已停止") |
| 606 | 607 | ||
| 607 | except Exception as e: | 608 | except Exception as e: |
| 608 | - print(f"ForumEngine: 停止论坛失败: {e}") | 609 | + logger.exception(f"ForumEngine: 停止论坛失败: {e}") |
| 609 | 610 | ||
| 610 | def get_forum_log_content(self) -> List[str]: | 611 | def get_forum_log_content(self) -> List[str]: |
| 611 | """获取forum.log的内容""" | 612 | """获取forum.log的内容""" |
| @@ -617,7 +618,7 @@ class LogMonitor: | @@ -617,7 +618,7 @@ class LogMonitor: | ||
| 617 | return [line.rstrip('\n\r') for line in f.readlines()] | 618 | return [line.rstrip('\n\r') for line in f.readlines()] |
| 618 | 619 | ||
| 619 | except Exception as e: | 620 | except Exception as e: |
| 620 | - print(f"ForumEngine: 读取forum.log失败: {e}") | 621 | + logger.exception(f"ForumEngine: 读取forum.log失败: {e}") |
| 621 | return [] | 622 | return [] |
| 622 | 623 | ||
| 623 | def fix_json_string(self, json_text: str) -> str: | 624 | def fix_json_string(self, json_text: str) -> str: |
| @@ -4,9 +4,9 @@ Deep Search Agent | @@ -4,9 +4,9 @@ Deep Search Agent | ||
| 4 | """ | 4 | """ |
| 5 | 5 | ||
| 6 | from .agent import DeepSearchAgent, create_agent | 6 | from .agent import DeepSearchAgent, create_agent |
| 7 | -from .utils.config import Config, load_config | 7 | +from .utils.config import settings, Settings |
| 8 | 8 | ||
| 9 | __version__ = "1.0.0" | 9 | __version__ = "1.0.0" |
| 10 | __author__ = "Deep Search Agent Team" | 10 | __author__ = "Deep Search Agent Team" |
| 11 | 11 | ||
| 12 | -__all__ = ["DeepSearchAgent", "create_agent", "Config", "load_config"] | 12 | +__all__ = ["DeepSearchAgent", "create_agent", "settings", "Settings"] |
| @@ -8,6 +8,7 @@ import os | @@ -8,6 +8,7 @@ import os | ||
| 8 | import re | 8 | import re |
| 9 | from datetime import datetime | 9 | from datetime import datetime |
| 10 | from typing import Optional, Dict, Any, List, Union | 10 | from typing import Optional, Dict, Any, List, Union |
| 11 | +from loguru import logger | ||
| 11 | 12 | ||
| 12 | from .llms import LLMClient | 13 | from .llms import LLMClient |
| 13 | from .nodes import ( | 14 | from .nodes import ( |
| @@ -20,32 +21,25 @@ from .nodes import ( | @@ -20,32 +21,25 @@ from .nodes import ( | ||
| 20 | ) | 21 | ) |
| 21 | from .state import State | 22 | from .state import State |
| 22 | from .tools import MediaCrawlerDB, DBResponse, keyword_optimizer, multilingual_sentiment_analyzer | 23 | from .tools import MediaCrawlerDB, DBResponse, keyword_optimizer, multilingual_sentiment_analyzer |
| 23 | -from .utils import Config, load_config, format_search_results_for_prompt | 24 | +from .utils.config import settings, Settings |
| 25 | +from .utils import format_search_results_for_prompt | ||
| 24 | 26 | ||
| 25 | 27 | ||
| 26 | class DeepSearchAgent: | 28 | class DeepSearchAgent: |
| 27 | """Deep Search Agent主类""" | 29 | """Deep Search Agent主类""" |
| 28 | 30 | ||
| 29 | - def __init__(self, config: Optional[Config] = None): | 31 | + def __init__(self, config: Optional[Settings] = None): |
| 30 | """ | 32 | """ |
| 31 | 初始化Deep Search Agent | 33 | 初始化Deep Search Agent |
| 32 | 34 | ||
| 33 | Args: | 35 | Args: |
| 34 | - config: 配置对象,如果不提供则自动加载 | 36 | + config: 可选配置对象(不填则用全局settings) |
| 35 | """ | 37 | """ |
| 36 | - # 加载配置 | ||
| 37 | - self.config = config or load_config() | 38 | + self.config = config or settings |
| 38 | 39 | ||
| 39 | # 初始化LLM客户端 | 40 | # 初始化LLM客户端 |
| 40 | self.llm_client = self._initialize_llm() | 41 | self.llm_client = self._initialize_llm() |
| 41 | 42 | ||
| 42 | - # 设置数据库环境变量 | ||
| 43 | - os.environ["DB_HOST"] = self.config.db_host or "" | ||
| 44 | - os.environ["DB_USER"] = self.config.db_user or "" | ||
| 45 | - os.environ["DB_PASSWORD"] = self.config.db_password or "" | ||
| 46 | - os.environ["DB_NAME"] = self.config.db_name or "" | ||
| 47 | - os.environ["DB_PORT"] = str(self.config.db_port) | ||
| 48 | - os.environ["DB_CHARSET"] = self.config.db_charset | ||
| 49 | 43 | ||
| 50 | # 初始化搜索工具集 | 44 | # 初始化搜索工具集 |
| 51 | self.search_agency = MediaCrawlerDB() | 45 | self.search_agency = MediaCrawlerDB() |
| @@ -60,19 +54,19 @@ class DeepSearchAgent: | @@ -60,19 +54,19 @@ class DeepSearchAgent: | ||
| 60 | self.state = State() | 54 | self.state = State() |
| 61 | 55 | ||
| 62 | # 确保输出目录存在 | 56 | # 确保输出目录存在 |
| 63 | - os.makedirs(self.config.output_dir, exist_ok=True) | 57 | + os.makedirs(self.config.OUTPUT_DIR, exist_ok=True) |
| 64 | 58 | ||
| 65 | - print(f"Insight Agent已初始化") | ||
| 66 | - print(f"使用LLM: {self.llm_client.get_model_info()}") | ||
| 67 | - print(f"搜索工具集: MediaCrawlerDB (支持5种本地数据库查询工具)") | ||
| 68 | - print(f"情感分析: WeiboMultilingualSentiment (支持22种语言的情感分析)") | 59 | + logger.info(f"Insight Agent已初始化") |
| 60 | + logger.info(f"使用LLM: {self.llm_client.get_model_info()}") | ||
| 61 | + logger.info(f"搜索工具集: MediaCrawlerDB (支持5种本地数据库查询工具)") | ||
| 62 | + logger.info(f"情感分析: WeiboMultilingualSentiment (支持22种语言的情感分析)") | ||
| 69 | 63 | ||
| 70 | def _initialize_llm(self) -> LLMClient: | 64 | def _initialize_llm(self) -> LLMClient: |
| 71 | """初始化LLM客户端""" | 65 | """初始化LLM客户端""" |
| 72 | return LLMClient( | 66 | return LLMClient( |
| 73 | - api_key=self.config.llm_api_key, | ||
| 74 | - model_name=self.config.llm_model_name, | ||
| 75 | - base_url=self.config.llm_base_url, | 67 | + api_key=self.config.INSIGHT_ENGINE_API_KEY, |
| 68 | + model_name=self.config.INSIGHT_ENGINE_MODEL_NAME, | ||
| 69 | + base_url=self.config.INSIGHT_ENGINE_BASE_URL, | ||
| 76 | ) | 70 | ) |
| 77 | 71 | ||
| 78 | def _initialize_nodes(self): | 72 | def _initialize_nodes(self): |
| @@ -127,7 +121,7 @@ class DeepSearchAgent: | @@ -127,7 +121,7 @@ class DeepSearchAgent: | ||
| 127 | Returns: | 121 | Returns: |
| 128 | DBResponse对象(可能包含情感分析结果) | 122 | DBResponse对象(可能包含情感分析结果) |
| 129 | """ | 123 | """ |
| 130 | - print(f" → 执行数据库查询工具: {tool_name}") | 124 | + logger.info(f" → 执行数据库查询工具: {tool_name}") |
| 131 | 125 | ||
| 132 | # 对于热点内容搜索,不需要关键词优化(因为不需要query参数) | 126 | # 对于热点内容搜索,不需要关键词优化(因为不需要query参数) |
| 133 | if tool_name == "search_hot_content": | 127 | if tool_name == "search_hot_content": |
| @@ -138,12 +132,12 @@ class DeepSearchAgent: | @@ -138,12 +132,12 @@ class DeepSearchAgent: | ||
| 138 | # 检查是否需要进行情感分析 | 132 | # 检查是否需要进行情感分析 |
| 139 | enable_sentiment = kwargs.get("enable_sentiment", True) | 133 | enable_sentiment = kwargs.get("enable_sentiment", True) |
| 140 | if enable_sentiment and response.results and len(response.results) > 0: | 134 | if enable_sentiment and response.results and len(response.results) > 0: |
| 141 | - print(f" 🎭 开始对热点内容进行情感分析...") | 135 | + logger.info(f" 🎭 开始对热点内容进行情感分析...") |
| 142 | sentiment_analysis = self._perform_sentiment_analysis(response.results) | 136 | sentiment_analysis = self._perform_sentiment_analysis(response.results) |
| 143 | if sentiment_analysis: | 137 | if sentiment_analysis: |
| 144 | # 将情感分析结果添加到响应的parameters中 | 138 | # 将情感分析结果添加到响应的parameters中 |
| 145 | response.parameters["sentiment_analysis"] = sentiment_analysis | 139 | response.parameters["sentiment_analysis"] = sentiment_analysis |
| 146 | - print(f" ✅ 情感分析完成") | 140 | + logger.info(f" ✅ 情感分析完成") |
| 147 | 141 | ||
| 148 | return response | 142 | return response |
| 149 | 143 | ||
| @@ -170,32 +164,32 @@ class DeepSearchAgent: | @@ -170,32 +164,32 @@ class DeepSearchAgent: | ||
| 170 | context=f"使用{tool_name}工具进行查询" | 164 | context=f"使用{tool_name}工具进行查询" |
| 171 | ) | 165 | ) |
| 172 | 166 | ||
| 173 | - print(f" 🔍 原始查询: '{query}'") | ||
| 174 | - print(f" ✨ 优化后关键词: {optimized_response.optimized_keywords}") | 167 | + logger.info(f" 🔍 原始查询: '{query}'") |
| 168 | + logger.info(f" ✨ 优化后关键词: {optimized_response.optimized_keywords}") | ||
| 175 | 169 | ||
| 176 | # 使用优化后的关键词进行多次查询并整合结果 | 170 | # 使用优化后的关键词进行多次查询并整合结果 |
| 177 | all_results = [] | 171 | all_results = [] |
| 178 | total_count = 0 | 172 | total_count = 0 |
| 179 | 173 | ||
| 180 | for keyword in optimized_response.optimized_keywords: | 174 | for keyword in optimized_response.optimized_keywords: |
| 181 | - print(f" 查询关键词: '{keyword}'") | 175 | + logger.info(f" 查询关键词: '{keyword}'") |
| 182 | 176 | ||
| 183 | try: | 177 | try: |
| 184 | if tool_name == "search_topic_globally": | 178 | if tool_name == "search_topic_globally": |
| 185 | # 使用配置文件中的默认值,忽略agent提供的limit_per_table参数 | 179 | # 使用配置文件中的默认值,忽略agent提供的limit_per_table参数 |
| 186 | - limit_per_table = self.config.default_search_topic_globally_limit_per_table | 180 | + limit_per_table = self.config.DEFAULT_SEARCH_TOPIC_GLOBALLY_LIMIT_PER_TABLE |
| 187 | response = self.search_agency.search_topic_globally(topic=keyword, limit_per_table=limit_per_table) | 181 | response = self.search_agency.search_topic_globally(topic=keyword, limit_per_table=limit_per_table) |
| 188 | elif tool_name == "search_topic_by_date": | 182 | elif tool_name == "search_topic_by_date": |
| 189 | start_date = kwargs.get("start_date") | 183 | start_date = kwargs.get("start_date") |
| 190 | end_date = kwargs.get("end_date") | 184 | end_date = kwargs.get("end_date") |
| 191 | # 使用配置文件中的默认值,忽略agent提供的limit_per_table参数 | 185 | # 使用配置文件中的默认值,忽略agent提供的limit_per_table参数 |
| 192 | - limit_per_table = self.config.default_search_topic_by_date_limit_per_table | 186 | + limit_per_table = self.config.DEFAULT_SEARCH_TOPIC_BY_DATE_LIMIT_PER_TABLE |
| 193 | if not start_date or not end_date: | 187 | if not start_date or not end_date: |
| 194 | raise ValueError("search_topic_by_date工具需要start_date和end_date参数") | 188 | raise ValueError("search_topic_by_date工具需要start_date和end_date参数") |
| 195 | response = self.search_agency.search_topic_by_date(topic=keyword, start_date=start_date, end_date=end_date, limit_per_table=limit_per_table) | 189 | response = self.search_agency.search_topic_by_date(topic=keyword, start_date=start_date, end_date=end_date, limit_per_table=limit_per_table) |
| 196 | elif tool_name == "get_comments_for_topic": | 190 | elif tool_name == "get_comments_for_topic": |
| 197 | # 使用配置文件中的默认值,按关键词数量分配,但保证最小值 | 191 | # 使用配置文件中的默认值,按关键词数量分配,但保证最小值 |
| 198 | - limit = self.config.default_get_comments_for_topic_limit // len(optimized_response.optimized_keywords) | 192 | + limit = self.config.DEFAULT_GET_COMMENTS_FOR_TOPIC_LIMIT // len(optimized_response.optimized_keywords) |
| 199 | limit = max(limit, 50) | 193 | limit = max(limit, 50) |
| 200 | response = self.search_agency.get_comments_for_topic(topic=keyword, limit=limit) | 194 | response = self.search_agency.get_comments_for_topic(topic=keyword, limit=limit) |
| 201 | elif tool_name == "search_topic_on_platform": | 195 | elif tool_name == "search_topic_on_platform": |
| @@ -203,30 +197,30 @@ class DeepSearchAgent: | @@ -203,30 +197,30 @@ class DeepSearchAgent: | ||
| 203 | start_date = kwargs.get("start_date") | 197 | start_date = kwargs.get("start_date") |
| 204 | end_date = kwargs.get("end_date") | 198 | end_date = kwargs.get("end_date") |
| 205 | # 使用配置文件中的默认值,按关键词数量分配,但保证最小值 | 199 | # 使用配置文件中的默认值,按关键词数量分配,但保证最小值 |
| 206 | - limit = self.config.default_search_topic_on_platform_limit // len(optimized_response.optimized_keywords) | 200 | + limit = self.config.DEFAULT_SEARCH_TOPIC_ON_PLATFORM_LIMIT // len(optimized_response.optimized_keywords) |
| 207 | limit = max(limit, 30) | 201 | limit = max(limit, 30) |
| 208 | if not platform: | 202 | if not platform: |
| 209 | raise ValueError("search_topic_on_platform工具需要platform参数") | 203 | raise ValueError("search_topic_on_platform工具需要platform参数") |
| 210 | response = self.search_agency.search_topic_on_platform(platform=platform, topic=keyword, start_date=start_date, end_date=end_date, limit=limit) | 204 | response = self.search_agency.search_topic_on_platform(platform=platform, topic=keyword, start_date=start_date, end_date=end_date, limit=limit) |
| 211 | else: | 205 | else: |
| 212 | - print(f" 未知的搜索工具: {tool_name},使用默认全局搜索") | ||
| 213 | - response = self.search_agency.search_topic_globally(topic=keyword, limit_per_table=self.config.default_search_topic_globally_limit_per_table) | 206 | + logger.info(f" 未知的搜索工具: {tool_name},使用默认全局搜索") |
| 207 | + response = self.search_agency.search_topic_globally(topic=keyword, limit_per_table=self.config.DEFAULT_SEARCH_TOPIC_GLOBALLY_LIMIT_PER_TABLE) | ||
| 214 | 208 | ||
| 215 | # 收集结果 | 209 | # 收集结果 |
| 216 | if response.results: | 210 | if response.results: |
| 217 | - print(f" 找到 {len(response.results)} 条结果") | 211 | + logger.info(f" 找到 {len(response.results)} 条结果") |
| 218 | all_results.extend(response.results) | 212 | all_results.extend(response.results) |
| 219 | total_count += len(response.results) | 213 | total_count += len(response.results) |
| 220 | else: | 214 | else: |
| 221 | - print(f" 未找到结果") | 215 | + logger.info(f" 未找到结果") |
| 222 | 216 | ||
| 223 | except Exception as e: | 217 | except Exception as e: |
| 224 | - print(f" 查询'{keyword}'时出错: {str(e)}") | 218 | + logger.error(f" 查询'{keyword}'时出错: {str(e)}") |
| 225 | continue | 219 | continue |
| 226 | 220 | ||
| 227 | # 去重和整合结果 | 221 | # 去重和整合结果 |
| 228 | unique_results = self._deduplicate_results(all_results) | 222 | unique_results = self._deduplicate_results(all_results) |
| 229 | - print(f" 总计找到 {total_count} 条结果,去重后 {len(unique_results)} 条") | 223 | + logger.info(f" 总计找到 {total_count} 条结果,去重后 {len(unique_results)} 条") |
| 230 | 224 | ||
| 231 | # 构建整合后的响应 | 225 | # 构建整合后的响应 |
| 232 | integrated_response = DBResponse( | 226 | integrated_response = DBResponse( |
| @@ -244,12 +238,12 @@ class DeepSearchAgent: | @@ -244,12 +238,12 @@ class DeepSearchAgent: | ||
| 244 | # 检查是否需要进行情感分析 | 238 | # 检查是否需要进行情感分析 |
| 245 | enable_sentiment = kwargs.get("enable_sentiment", True) | 239 | enable_sentiment = kwargs.get("enable_sentiment", True) |
| 246 | if enable_sentiment and unique_results and len(unique_results) > 0: | 240 | if enable_sentiment and unique_results and len(unique_results) > 0: |
| 247 | - print(f" 🎭 开始对搜索结果进行情感分析...") | 241 | + logger.info(f" 🎭 开始对搜索结果进行情感分析...") |
| 248 | sentiment_analysis = self._perform_sentiment_analysis(unique_results) | 242 | sentiment_analysis = self._perform_sentiment_analysis(unique_results) |
| 249 | if sentiment_analysis: | 243 | if sentiment_analysis: |
| 250 | # 将情感分析结果添加到响应的parameters中 | 244 | # 将情感分析结果添加到响应的parameters中 |
| 251 | integrated_response.parameters["sentiment_analysis"] = sentiment_analysis | 245 | integrated_response.parameters["sentiment_analysis"] = sentiment_analysis |
| 252 | - print(f" ✅ 情感分析完成") | 246 | + logger.info(f" ✅ 情感分析完成") |
| 253 | 247 | ||
| 254 | return integrated_response | 248 | return integrated_response |
| 255 | 249 | ||
| @@ -282,11 +276,11 @@ class DeepSearchAgent: | @@ -282,11 +276,11 @@ class DeepSearchAgent: | ||
| 282 | try: | 276 | try: |
| 283 | # 初始化情感分析器(如果尚未初始化且未被禁用) | 277 | # 初始化情感分析器(如果尚未初始化且未被禁用) |
| 284 | if not self.sentiment_analyzer.is_initialized and not self.sentiment_analyzer.is_disabled: | 278 | if not self.sentiment_analyzer.is_initialized and not self.sentiment_analyzer.is_disabled: |
| 285 | - print(" 初始化情感分析模型...") | 279 | + logger.info(" 初始化情感分析模型...") |
| 286 | if not self.sentiment_analyzer.initialize(): | 280 | if not self.sentiment_analyzer.initialize(): |
| 287 | - print(" 情感分析模型初始化失败,将直接透传原始文本") | 281 | + logger.info(" 情感分析模型初始化失败,将直接透传原始文本") |
| 288 | elif self.sentiment_analyzer.is_disabled: | 282 | elif self.sentiment_analyzer.is_disabled: |
| 289 | - print(" 情感分析功能已禁用,直接透传原始文本") | 283 | + logger.info(" 情感分析功能已禁用,直接透传原始文本") |
| 290 | 284 | ||
| 291 | # 将查询结果转换为字典格式 | 285 | # 将查询结果转换为字典格式 |
| 292 | results_dict = [] | 286 | results_dict = [] |
| @@ -310,7 +304,7 @@ class DeepSearchAgent: | @@ -310,7 +304,7 @@ class DeepSearchAgent: | ||
| 310 | return sentiment_analysis.get("sentiment_analysis") | 304 | return sentiment_analysis.get("sentiment_analysis") |
| 311 | 305 | ||
| 312 | except Exception as e: | 306 | except Exception as e: |
| 313 | - print(f" ❌ 情感分析过程中发生错误: {str(e)}") | 307 | + logger.exception(f" ❌ 情感分析过程中发生错误: {str(e)}") |
| 314 | return None | 308 | return None |
| 315 | 309 | ||
| 316 | def analyze_sentiment_only(self, texts: Union[str, List[str]]) -> Dict[str, Any]: | 310 | def analyze_sentiment_only(self, texts: Union[str, List[str]]) -> Dict[str, Any]: |
| @@ -323,16 +317,16 @@ class DeepSearchAgent: | @@ -323,16 +317,16 @@ class DeepSearchAgent: | ||
| 323 | Returns: | 317 | Returns: |
| 324 | 情感分析结果 | 318 | 情感分析结果 |
| 325 | """ | 319 | """ |
| 326 | - print(f" → 执行独立情感分析") | 320 | + logger.info(f" → 执行独立情感分析") |
| 327 | 321 | ||
| 328 | try: | 322 | try: |
| 329 | # 初始化情感分析器(如果尚未初始化且未被禁用) | 323 | # 初始化情感分析器(如果尚未初始化且未被禁用) |
| 330 | if not self.sentiment_analyzer.is_initialized and not self.sentiment_analyzer.is_disabled: | 324 | if not self.sentiment_analyzer.is_initialized and not self.sentiment_analyzer.is_disabled: |
| 331 | - print(" 初始化情感分析模型...") | 325 | + logger.info(" 初始化情感分析模型...") |
| 332 | if not self.sentiment_analyzer.initialize(): | 326 | if not self.sentiment_analyzer.initialize(): |
| 333 | - print(" 情感分析模型初始化失败,将直接透传原始文本") | 327 | + logger.info(" 情感分析模型初始化失败,将直接透传原始文本") |
| 334 | elif self.sentiment_analyzer.is_disabled: | 328 | elif self.sentiment_analyzer.is_disabled: |
| 335 | - print(" 情感分析功能已禁用,直接透传原始文本") | 329 | + logger.warning(" 情感分析功能已禁用,直接透传原始文本") |
| 336 | 330 | ||
| 337 | # 执行分析 | 331 | # 执行分析 |
| 338 | if isinstance(texts, str): | 332 | if isinstance(texts, str): |
| @@ -368,7 +362,7 @@ class DeepSearchAgent: | @@ -368,7 +362,7 @@ class DeepSearchAgent: | ||
| 368 | return response | 362 | return response |
| 369 | 363 | ||
| 370 | except Exception as e: | 364 | except Exception as e: |
| 371 | - print(f" ❌ 情感分析过程中发生错误: {str(e)}") | 365 | + logger.exception(f" ❌ 情感分析过程中发生错误: {str(e)}") |
| 372 | return { | 366 | return { |
| 373 | "success": False, | 367 | "success": False, |
| 374 | "error": str(e), | 368 | "error": str(e), |
| @@ -386,9 +380,9 @@ class DeepSearchAgent: | @@ -386,9 +380,9 @@ class DeepSearchAgent: | ||
| 386 | Returns: | 380 | Returns: |
| 387 | 最终报告内容 | 381 | 最终报告内容 |
| 388 | """ | 382 | """ |
| 389 | - print(f"\n{'='*60}") | ||
| 390 | - print(f"开始深度研究: {query}") | ||
| 391 | - print(f"{'='*60}") | 383 | + logger.info(f"\n{'='*60}") |
| 384 | + logger.info(f"开始深度研究: {query}") | ||
| 385 | + logger.info(f"{'='*60}") | ||
| 392 | 386 | ||
| 393 | try: | 387 | try: |
| 394 | # Step 1: 生成报告结构 | 388 | # Step 1: 生成报告结构 |
| @@ -403,20 +397,18 @@ class DeepSearchAgent: | @@ -403,20 +397,18 @@ class DeepSearchAgent: | ||
| 403 | # Step 4: 保存报告 | 397 | # Step 4: 保存报告 |
| 404 | if save_report: | 398 | if save_report: |
| 405 | self._save_report(final_report) | 399 | self._save_report(final_report) |
| 406 | - | ||
| 407 | - print(f"\n{'='*60}") | ||
| 408 | - print("深度研究完成!") | ||
| 409 | - print(f"{'='*60}") | 400 | + |
| 401 | + logger.info("深度研究完成!") | ||
| 410 | 402 | ||
| 411 | return final_report | 403 | return final_report |
| 412 | 404 | ||
| 413 | except Exception as e: | 405 | except Exception as e: |
| 414 | - print(f"研究过程中发生错误: {str(e)}") | 406 | + logger.exception(f"研究过程中发生错误: {str(e)}") |
| 415 | raise e | 407 | raise e |
| 416 | 408 | ||
| 417 | def _generate_report_structure(self, query: str): | 409 | def _generate_report_structure(self, query: str): |
| 418 | """生成报告结构""" | 410 | """生成报告结构""" |
| 419 | - print(f"\n[步骤 1] 生成报告结构...") | 411 | + logger.info(f"\n[步骤 1] 生成报告结构...") |
| 420 | 412 | ||
| 421 | # 创建报告结构节点 | 413 | # 创建报告结构节点 |
| 422 | report_structure_node = ReportStructureNode(self.llm_client, query) | 414 | report_structure_node = ReportStructureNode(self.llm_client, query) |
| @@ -424,17 +416,18 @@ class DeepSearchAgent: | @@ -424,17 +416,18 @@ class DeepSearchAgent: | ||
| 424 | # 生成结构并更新状态 | 416 | # 生成结构并更新状态 |
| 425 | self.state = report_structure_node.mutate_state(state=self.state) | 417 | self.state = report_structure_node.mutate_state(state=self.state) |
| 426 | 418 | ||
| 427 | - print(f"报告结构已生成,共 {len(self.state.paragraphs)} 个段落:") | 419 | + _message = f"报告结构已生成,共 {len(self.state.paragraphs)} 个段落:" |
| 428 | for i, paragraph in enumerate(self.state.paragraphs, 1): | 420 | for i, paragraph in enumerate(self.state.paragraphs, 1): |
| 429 | - print(f" {i}. {paragraph.title}") | 421 | + _message += f"\n {i}. {paragraph.title}" |
| 422 | + logger.info(_message) | ||
| 430 | 423 | ||
| 431 | def _process_paragraphs(self): | 424 | def _process_paragraphs(self): |
| 432 | """处理所有段落""" | 425 | """处理所有段落""" |
| 433 | total_paragraphs = len(self.state.paragraphs) | 426 | total_paragraphs = len(self.state.paragraphs) |
| 434 | 427 | ||
| 435 | for i in range(total_paragraphs): | 428 | for i in range(total_paragraphs): |
| 436 | - print(f"\n[步骤 2.{i+1}] 处理段落: {self.state.paragraphs[i].title}") | ||
| 437 | - print("-" * 50) | 429 | + logger.info(f"\n[步骤 2.{i+1}] 处理段落: {self.state.paragraphs[i].title}") |
| 430 | + logger.info("-" * 50) | ||
| 438 | 431 | ||
| 439 | # 初始搜索和总结 | 432 | # 初始搜索和总结 |
| 440 | self._initial_search_and_summary(i) | 433 | self._initial_search_and_summary(i) |
| @@ -446,7 +439,7 @@ class DeepSearchAgent: | @@ -446,7 +439,7 @@ class DeepSearchAgent: | ||
| 446 | self.state.paragraphs[i].research.mark_completed() | 439 | self.state.paragraphs[i].research.mark_completed() |
| 447 | 440 | ||
| 448 | progress = (i + 1) / total_paragraphs * 100 | 441 | progress = (i + 1) / total_paragraphs * 100 |
| 449 | - print(f"段落处理完成 ({progress:.1f}%)") | 442 | + logger.info(f"段落处理完成 ({progress:.1f}%)") |
| 450 | 443 | ||
| 451 | def _initial_search_and_summary(self, paragraph_index: int): | 444 | def _initial_search_and_summary(self, paragraph_index: int): |
| 452 | """执行初始搜索和总结""" | 445 | """执行初始搜索和总结""" |
| @@ -459,18 +452,18 @@ class DeepSearchAgent: | @@ -459,18 +452,18 @@ class DeepSearchAgent: | ||
| 459 | } | 452 | } |
| 460 | 453 | ||
| 461 | # 生成搜索查询和工具选择 | 454 | # 生成搜索查询和工具选择 |
| 462 | - print(" - 生成搜索查询...") | 455 | + logger.info(" - 生成搜索查询...") |
| 463 | search_output = self.first_search_node.run(search_input) | 456 | search_output = self.first_search_node.run(search_input) |
| 464 | search_query = search_output["search_query"] | 457 | search_query = search_output["search_query"] |
| 465 | search_tool = search_output.get("search_tool", "search_topic_globally") # 默认工具 | 458 | search_tool = search_output.get("search_tool", "search_topic_globally") # 默认工具 |
| 466 | reasoning = search_output["reasoning"] | 459 | reasoning = search_output["reasoning"] |
| 467 | 460 | ||
| 468 | - print(f" - 搜索查询: {search_query}") | ||
| 469 | - print(f" - 选择的工具: {search_tool}") | ||
| 470 | - print(f" - 推理: {reasoning}") | 461 | + logger.info(f" - 搜索查询: {search_query}") |
| 462 | + logger.info(f" - 选择的工具: {search_tool}") | ||
| 463 | + logger.info(f" - 推理: {reasoning}") | ||
| 471 | 464 | ||
| 472 | # 执行搜索 | 465 | # 执行搜索 |
| 473 | - print(" - 执行数据库查询...") | 466 | + logger.info(" - 执行数据库查询...") |
| 474 | 467 | ||
| 475 | # 处理特殊参数 | 468 | # 处理特殊参数 |
| 476 | search_kwargs = {} | 469 | search_kwargs = {} |
| @@ -485,13 +478,13 @@ class DeepSearchAgent: | @@ -485,13 +478,13 @@ class DeepSearchAgent: | ||
| 485 | if self._validate_date_format(start_date) and self._validate_date_format(end_date): | 478 | if self._validate_date_format(start_date) and self._validate_date_format(end_date): |
| 486 | search_kwargs["start_date"] = start_date | 479 | search_kwargs["start_date"] = start_date |
| 487 | search_kwargs["end_date"] = end_date | 480 | search_kwargs["end_date"] = end_date |
| 488 | - print(f" - 时间范围: {start_date} 到 {end_date}") | 481 | + logger.info(f" - 时间范围: {start_date} 到 {end_date}") |
| 489 | else: | 482 | else: |
| 490 | - print(f" 日期格式错误(应为YYYY-MM-DD),改用全局搜索") | ||
| 491 | - print(f" 提供的日期: start_date={start_date}, end_date={end_date}") | 483 | + logger.info(f" 日期格式错误(应为YYYY-MM-DD),改用全局搜索") |
| 484 | + logger.info(f" 提供的日期: start_date={start_date}, end_date={end_date}") | ||
| 492 | search_tool = "search_topic_globally" | 485 | search_tool = "search_topic_globally" |
| 493 | elif search_tool == "search_topic_by_date": | 486 | elif search_tool == "search_topic_by_date": |
| 494 | - print(f" search_topic_by_date工具缺少时间参数,改用全局搜索") | 487 | + logger.info(f" search_topic_by_date工具缺少时间参数,改用全局搜索") |
| 495 | search_tool = "search_topic_globally" | 488 | search_tool = "search_topic_globally" |
| 496 | 489 | ||
| 497 | # 处理需要平台参数的工具 | 490 | # 处理需要平台参数的工具 |
| @@ -499,28 +492,28 @@ class DeepSearchAgent: | @@ -499,28 +492,28 @@ class DeepSearchAgent: | ||
| 499 | platform = search_output.get("platform") | 492 | platform = search_output.get("platform") |
| 500 | if platform: | 493 | if platform: |
| 501 | search_kwargs["platform"] = platform | 494 | search_kwargs["platform"] = platform |
| 502 | - print(f" - 指定平台: {platform}") | 495 | + logger.info(f" - 指定平台: {platform}") |
| 503 | else: | 496 | else: |
| 504 | - print(f" search_topic_on_platform工具缺少平台参数,改用全局搜索") | 497 | + logger.warning(f" search_topic_on_platform工具缺少平台参数,改用全局搜索") |
| 505 | search_tool = "search_topic_globally" | 498 | search_tool = "search_topic_globally" |
| 506 | 499 | ||
| 507 | # 处理限制参数,使用配置文件中的默认值而不是agent提供的参数 | 500 | # 处理限制参数,使用配置文件中的默认值而不是agent提供的参数 |
| 508 | if search_tool == "search_hot_content": | 501 | if search_tool == "search_hot_content": |
| 509 | time_period = search_output.get("time_period", "week") | 502 | time_period = search_output.get("time_period", "week") |
| 510 | - limit = self.config.default_search_hot_content_limit | 503 | + limit = self.config.DEFAULT_SEARCH_HOT_CONTENT_LIMIT |
| 511 | search_kwargs["time_period"] = time_period | 504 | search_kwargs["time_period"] = time_period |
| 512 | search_kwargs["limit"] = limit | 505 | search_kwargs["limit"] = limit |
| 513 | elif search_tool in ["search_topic_globally", "search_topic_by_date"]: | 506 | elif search_tool in ["search_topic_globally", "search_topic_by_date"]: |
| 514 | if search_tool == "search_topic_globally": | 507 | if search_tool == "search_topic_globally": |
| 515 | - limit_per_table = self.config.default_search_topic_globally_limit_per_table | 508 | + limit_per_table = self.config.DEFAULT_SEARCH_TOPIC_GLOBALLY_LIMIT_PER_TABLE |
| 516 | else: # search_topic_by_date | 509 | else: # search_topic_by_date |
| 517 | - limit_per_table = self.config.default_search_topic_by_date_limit_per_table | 510 | + limit_per_table = self.config.DEFAULT_SEARCH_TOPIC_BY_DATE_LIMIT_PER_TABLE |
| 518 | search_kwargs["limit_per_table"] = limit_per_table | 511 | search_kwargs["limit_per_table"] = limit_per_table |
| 519 | elif search_tool in ["get_comments_for_topic", "search_topic_on_platform"]: | 512 | elif search_tool in ["get_comments_for_topic", "search_topic_on_platform"]: |
| 520 | if search_tool == "get_comments_for_topic": | 513 | if search_tool == "get_comments_for_topic": |
| 521 | - limit = self.config.default_get_comments_for_topic_limit | 514 | + limit = self.config.DEFAULT_GET_COMMENTS_FOR_TOPIC_LIMIT |
| 522 | else: # search_topic_on_platform | 515 | else: # search_topic_on_platform |
| 523 | - limit = self.config.default_search_topic_on_platform_limit | 516 | + limit = self.config.DEFAULT_SEARCH_TOPIC_ON_PLATFORM_LIMIT |
| 524 | search_kwargs["limit"] = limit | 517 | search_kwargs["limit"] = limit |
| 525 | 518 | ||
| 526 | search_response = self.execute_search_tool(search_tool, search_query, **search_kwargs) | 519 | search_response = self.execute_search_tool(search_tool, search_query, **search_kwargs) |
| @@ -529,8 +522,8 @@ class DeepSearchAgent: | @@ -529,8 +522,8 @@ class DeepSearchAgent: | ||
| 529 | search_results = [] | 522 | search_results = [] |
| 530 | if search_response and search_response.results: | 523 | if search_response and search_response.results: |
| 531 | # 使用配置文件控制传递给LLM的结果数量,0表示不限制 | 524 | # 使用配置文件控制传递给LLM的结果数量,0表示不限制 |
| 532 | - if self.config.max_search_results_for_llm > 0: | ||
| 533 | - max_results = min(len(search_response.results), self.config.max_search_results_for_llm) | 525 | + if self.config.MAX_SEARCH_RESULTS_FOR_LLM > 0: |
| 526 | + max_results = min(len(search_response.results), self.config.MAX_SEARCH_RESULTS_FOR_LLM) | ||
| 534 | else: | 527 | else: |
| 535 | max_results = len(search_response.results) # 不限制,传递所有结果 | 528 | max_results = len(search_response.results) # 不限制,传递所有结果 |
| 536 | for result in search_response.results[:max_results]: | 529 | for result in search_response.results[:max_results]: |
| @@ -548,24 +541,25 @@ class DeepSearchAgent: | @@ -548,24 +541,25 @@ class DeepSearchAgent: | ||
| 548 | }) | 541 | }) |
| 549 | 542 | ||
| 550 | if search_results: | 543 | if search_results: |
| 551 | - print(f" - 找到 {len(search_results)} 个搜索结果") | 544 | + _message = f" - 找到 {len(search_results)} 个搜索结果" |
| 552 | for j, result in enumerate(search_results, 1): | 545 | for j, result in enumerate(search_results, 1): |
| 553 | date_info = f" (发布于: {result.get('published_date', 'N/A')})" if result.get('published_date') else "" | 546 | date_info = f" (发布于: {result.get('published_date', 'N/A')})" if result.get('published_date') else "" |
| 554 | - print(f" {j}. {result['title'][:50]}...{date_info}") | 547 | + _message += f"\n {j}. {result['title'][:50]}...{date_info}" |
| 548 | + logger.info(_message) | ||
| 555 | else: | 549 | else: |
| 556 | - print(" - 未找到搜索结果") | 550 | + logger.info(" - 未找到搜索结果") |
| 557 | 551 | ||
| 558 | # 更新状态中的搜索历史 | 552 | # 更新状态中的搜索历史 |
| 559 | paragraph.research.add_search_results(search_query, search_results) | 553 | paragraph.research.add_search_results(search_query, search_results) |
| 560 | 554 | ||
| 561 | # 生成初始总结 | 555 | # 生成初始总结 |
| 562 | - print(" - 生成初始总结...") | 556 | + logger.info(" - 生成初始总结...") |
| 563 | summary_input = { | 557 | summary_input = { |
| 564 | "title": paragraph.title, | 558 | "title": paragraph.title, |
| 565 | "content": paragraph.content, | 559 | "content": paragraph.content, |
| 566 | "search_query": search_query, | 560 | "search_query": search_query, |
| 567 | "search_results": format_search_results_for_prompt( | 561 | "search_results": format_search_results_for_prompt( |
| 568 | - search_results, self.config.max_content_length | 562 | + search_results, self.config.MAX_CONTENT_LENGTH |
| 569 | ) | 563 | ) |
| 570 | } | 564 | } |
| 571 | 565 | ||
| @@ -574,14 +568,14 @@ class DeepSearchAgent: | @@ -574,14 +568,14 @@ class DeepSearchAgent: | ||
| 574 | summary_input, self.state, paragraph_index | 568 | summary_input, self.state, paragraph_index |
| 575 | ) | 569 | ) |
| 576 | 570 | ||
| 577 | - print(" - 初始总结完成") | 571 | + logger.info(" - 初始总结完成") |
| 578 | 572 | ||
| 579 | def _reflection_loop(self, paragraph_index: int): | 573 | def _reflection_loop(self, paragraph_index: int): |
| 580 | """执行反思循环""" | 574 | """执行反思循环""" |
| 581 | paragraph = self.state.paragraphs[paragraph_index] | 575 | paragraph = self.state.paragraphs[paragraph_index] |
| 582 | 576 | ||
| 583 | - for reflection_i in range(self.config.max_reflections): | ||
| 584 | - print(f" - 反思 {reflection_i + 1}/{self.config.max_reflections}...") | 577 | + for reflection_i in range(self.config.MAX_REFLECTIONS): |
| 578 | + logger.info(f" - 反思 {reflection_i + 1}/{self.config.MAX_REFLECTIONS}...") | ||
| 585 | 579 | ||
| 586 | # 准备反思输入 | 580 | # 准备反思输入 |
| 587 | reflection_input = { | 581 | reflection_input = { |
| @@ -596,9 +590,9 @@ class DeepSearchAgent: | @@ -596,9 +590,9 @@ class DeepSearchAgent: | ||
| 596 | search_tool = reflection_output.get("search_tool", "search_topic_globally") # 默认工具 | 590 | search_tool = reflection_output.get("search_tool", "search_topic_globally") # 默认工具 |
| 597 | reasoning = reflection_output["reasoning"] | 591 | reasoning = reflection_output["reasoning"] |
| 598 | 592 | ||
| 599 | - print(f" 反思查询: {search_query}") | ||
| 600 | - print(f" 选择的工具: {search_tool}") | ||
| 601 | - print(f" 反思推理: {reasoning}") | 593 | + logger.info(f" 反思查询: {search_query}") |
| 594 | + logger.info(f" 选择的工具: {search_tool}") | ||
| 595 | + logger.info(f" 反思推理: {reasoning}") | ||
| 602 | 596 | ||
| 603 | # 执行反思搜索 | 597 | # 执行反思搜索 |
| 604 | # 处理特殊参数 | 598 | # 处理特殊参数 |
| @@ -614,13 +608,13 @@ class DeepSearchAgent: | @@ -614,13 +608,13 @@ class DeepSearchAgent: | ||
| 614 | if self._validate_date_format(start_date) and self._validate_date_format(end_date): | 608 | if self._validate_date_format(start_date) and self._validate_date_format(end_date): |
| 615 | search_kwargs["start_date"] = start_date | 609 | search_kwargs["start_date"] = start_date |
| 616 | search_kwargs["end_date"] = end_date | 610 | search_kwargs["end_date"] = end_date |
| 617 | - print(f" 时间范围: {start_date} 到 {end_date}") | 611 | + logger.info(f" 时间范围: {start_date} 到 {end_date}") |
| 618 | else: | 612 | else: |
| 619 | - print(f" 日期格式错误(应为YYYY-MM-DD),改用全局搜索") | ||
| 620 | - print(f" 提供的日期: start_date={start_date}, end_date={end_date}") | 613 | + logger.info(f" 日期格式错误(应为YYYY-MM-DD),改用全局搜索") |
| 614 | + logger.info(f" 提供的日期: start_date={start_date}, end_date={end_date}") | ||
| 621 | search_tool = "search_topic_globally" | 615 | search_tool = "search_topic_globally" |
| 622 | elif search_tool == "search_topic_by_date": | 616 | elif search_tool == "search_topic_by_date": |
| 623 | - print(f" search_topic_by_date工具缺少时间参数,改用全局搜索") | 617 | + logger.warning(f" search_topic_by_date工具缺少时间参数,改用全局搜索") |
| 624 | search_tool = "search_topic_globally" | 618 | search_tool = "search_topic_globally" |
| 625 | 619 | ||
| 626 | # 处理需要平台参数的工具 | 620 | # 处理需要平台参数的工具 |
| @@ -628,31 +622,31 @@ class DeepSearchAgent: | @@ -628,31 +622,31 @@ class DeepSearchAgent: | ||
| 628 | platform = reflection_output.get("platform") | 622 | platform = reflection_output.get("platform") |
| 629 | if platform: | 623 | if platform: |
| 630 | search_kwargs["platform"] = platform | 624 | search_kwargs["platform"] = platform |
| 631 | - print(f" 指定平台: {platform}") | 625 | + logger.info(f" 指定平台: {platform}") |
| 632 | else: | 626 | else: |
| 633 | - print(f" search_topic_on_platform工具缺少平台参数,改用全局搜索") | 627 | + logger.warning(f" search_topic_on_platform工具缺少平台参数,改用全局搜索") |
| 634 | search_tool = "search_topic_globally" | 628 | search_tool = "search_topic_globally" |
| 635 | 629 | ||
| 636 | # 处理限制参数 | 630 | # 处理限制参数 |
| 637 | if search_tool == "search_hot_content": | 631 | if search_tool == "search_hot_content": |
| 638 | time_period = reflection_output.get("time_period", "week") | 632 | time_period = reflection_output.get("time_period", "week") |
| 639 | # 使用配置文件中的默认值,不允许agent控制limit参数 | 633 | # 使用配置文件中的默认值,不允许agent控制limit参数 |
| 640 | - limit = self.config.default_search_hot_content_limit | 634 | + limit = self.config.DEFAULT_SEARCH_HOT_CONTENT_LIMIT |
| 641 | search_kwargs["time_period"] = time_period | 635 | search_kwargs["time_period"] = time_period |
| 642 | search_kwargs["limit"] = limit | 636 | search_kwargs["limit"] = limit |
| 643 | elif search_tool in ["search_topic_globally", "search_topic_by_date"]: | 637 | elif search_tool in ["search_topic_globally", "search_topic_by_date"]: |
| 644 | # 使用配置文件中的默认值,不允许agent控制limit_per_table参数 | 638 | # 使用配置文件中的默认值,不允许agent控制limit_per_table参数 |
| 645 | if search_tool == "search_topic_globally": | 639 | if search_tool == "search_topic_globally": |
| 646 | - limit_per_table = self.config.default_search_topic_globally_limit_per_table | 640 | + limit_per_table = self.config.DEFAULT_SEARCH_TOPIC_GLOBALLY_LIMIT_PER_TABLE |
| 647 | else: # search_topic_by_date | 641 | else: # search_topic_by_date |
| 648 | - limit_per_table = self.config.default_search_topic_by_date_limit_per_table | 642 | + limit_per_table = self.config.DEFAULT_SEARCH_TOPIC_BY_DATE_LIMIT_PER_TABLE |
| 649 | search_kwargs["limit_per_table"] = limit_per_table | 643 | search_kwargs["limit_per_table"] = limit_per_table |
| 650 | elif search_tool in ["get_comments_for_topic", "search_topic_on_platform"]: | 644 | elif search_tool in ["get_comments_for_topic", "search_topic_on_platform"]: |
| 651 | # 使用配置文件中的默认值,不允许agent控制limit参数 | 645 | # 使用配置文件中的默认值,不允许agent控制limit参数 |
| 652 | if search_tool == "get_comments_for_topic": | 646 | if search_tool == "get_comments_for_topic": |
| 653 | - limit = self.config.default_get_comments_for_topic_limit | 647 | + limit = self.config.DEFAULT_GET_COMMENTS_FOR_TOPIC_LIMIT |
| 654 | else: # search_topic_on_platform | 648 | else: # search_topic_on_platform |
| 655 | - limit = self.config.default_search_topic_on_platform_limit | 649 | + limit = self.config.DEFAULT_SEARCH_TOPIC_ON_PLATFORM_LIMIT |
| 656 | search_kwargs["limit"] = limit | 650 | search_kwargs["limit"] = limit |
| 657 | 651 | ||
| 658 | search_response = self.execute_search_tool(search_tool, search_query, **search_kwargs) | 652 | search_response = self.execute_search_tool(search_tool, search_query, **search_kwargs) |
| @@ -661,8 +655,8 @@ class DeepSearchAgent: | @@ -661,8 +655,8 @@ class DeepSearchAgent: | ||
| 661 | search_results = [] | 655 | search_results = [] |
| 662 | if search_response and search_response.results: | 656 | if search_response and search_response.results: |
| 663 | # 使用配置文件控制传递给LLM的结果数量,0表示不限制 | 657 | # 使用配置文件控制传递给LLM的结果数量,0表示不限制 |
| 664 | - if self.config.max_search_results_for_llm > 0: | ||
| 665 | - max_results = min(len(search_response.results), self.config.max_search_results_for_llm) | 658 | + if self.config.MAX_SEARCH_RESULTS_FOR_LLM > 0: |
| 659 | + max_results = min(len(search_response.results), self.config.MAX_SEARCH_RESULTS_FOR_LLM) | ||
| 666 | else: | 660 | else: |
| 667 | max_results = len(search_response.results) # 不限制,传递所有结果 | 661 | max_results = len(search_response.results) # 不限制,传递所有结果 |
| 668 | for result in search_response.results[:max_results]: | 662 | for result in search_response.results[:max_results]: |
| @@ -680,12 +674,13 @@ class DeepSearchAgent: | @@ -680,12 +674,13 @@ class DeepSearchAgent: | ||
| 680 | }) | 674 | }) |
| 681 | 675 | ||
| 682 | if search_results: | 676 | if search_results: |
| 683 | - print(f" 找到 {len(search_results)} 个反思搜索结果") | 677 | + _message = f" 找到 {len(search_results)} 个反思搜索结果" |
| 684 | for j, result in enumerate(search_results, 1): | 678 | for j, result in enumerate(search_results, 1): |
| 685 | date_info = f" (发布于: {result.get('published_date', 'N/A')})" if result.get('published_date') else "" | 679 | date_info = f" (发布于: {result.get('published_date', 'N/A')})" if result.get('published_date') else "" |
| 686 | - print(f" {j}. {result['title'][:50]}...{date_info}") | 680 | + _message += f"\n {j}. {result['title'][:50]}...{date_info}" |
| 681 | + logger.info(_message) | ||
| 687 | else: | 682 | else: |
| 688 | - print(" 未找到反思搜索结果") | 683 | + logger.info(" 未找到反思搜索结果") |
| 689 | 684 | ||
| 690 | # 更新搜索历史 | 685 | # 更新搜索历史 |
| 691 | paragraph.research.add_search_results(search_query, search_results) | 686 | paragraph.research.add_search_results(search_query, search_results) |
| @@ -696,7 +691,7 @@ class DeepSearchAgent: | @@ -696,7 +691,7 @@ class DeepSearchAgent: | ||
| 696 | "content": paragraph.content, | 691 | "content": paragraph.content, |
| 697 | "search_query": search_query, | 692 | "search_query": search_query, |
| 698 | "search_results": format_search_results_for_prompt( | 693 | "search_results": format_search_results_for_prompt( |
| 699 | - search_results, self.config.max_content_length | 694 | + search_results, self.config.MAX_CONTENT_LENGTH |
| 700 | ), | 695 | ), |
| 701 | "paragraph_latest_state": paragraph.research.latest_summary | 696 | "paragraph_latest_state": paragraph.research.latest_summary |
| 702 | } | 697 | } |
| @@ -706,11 +701,11 @@ class DeepSearchAgent: | @@ -706,11 +701,11 @@ class DeepSearchAgent: | ||
| 706 | reflection_summary_input, self.state, paragraph_index | 701 | reflection_summary_input, self.state, paragraph_index |
| 707 | ) | 702 | ) |
| 708 | 703 | ||
| 709 | - print(f" 反思 {reflection_i + 1} 完成") | 704 | + logger.info(f" 反思 {reflection_i + 1} 完成") |
| 710 | 705 | ||
| 711 | def _generate_final_report(self) -> str: | 706 | def _generate_final_report(self) -> str: |
| 712 | """生成最终报告""" | 707 | """生成最终报告""" |
| 713 | - print(f"\n[步骤 3] 生成最终报告...") | 708 | + logger.info(f"\n[步骤 3] 生成最终报告...") |
| 714 | 709 | ||
| 715 | # 准备报告数据 | 710 | # 准备报告数据 |
| 716 | report_data = [] | 711 | report_data = [] |
| @@ -724,7 +719,7 @@ class DeepSearchAgent: | @@ -724,7 +719,7 @@ class DeepSearchAgent: | ||
| 724 | try: | 719 | try: |
| 725 | final_report = self.report_formatting_node.run(report_data) | 720 | final_report = self.report_formatting_node.run(report_data) |
| 726 | except Exception as e: | 721 | except Exception as e: |
| 727 | - print(f"LLM格式化失败,使用备用方法: {str(e)}") | 722 | + logger.exception(f"LLM格式化失败,使用备用方法: {str(e)}") |
| 728 | final_report = self.report_formatting_node.format_report_manually( | 723 | final_report = self.report_formatting_node.format_report_manually( |
| 729 | report_data, self.state.report_title | 724 | report_data, self.state.report_title |
| 730 | ) | 725 | ) |
| @@ -733,7 +728,7 @@ class DeepSearchAgent: | @@ -733,7 +728,7 @@ class DeepSearchAgent: | ||
| 733 | self.state.final_report = final_report | 728 | self.state.final_report = final_report |
| 734 | self.state.mark_completed() | 729 | self.state.mark_completed() |
| 735 | 730 | ||
| 736 | - print("最终报告生成完成") | 731 | + logger.info("最终报告生成完成") |
| 737 | return final_report | 732 | return final_report |
| 738 | 733 | ||
| 739 | def _save_report(self, report_content: str): | 734 | def _save_report(self, report_content: str): |
| @@ -744,20 +739,20 @@ class DeepSearchAgent: | @@ -744,20 +739,20 @@ class DeepSearchAgent: | ||
| 744 | query_safe = query_safe.replace(' ', '_')[:30] | 739 | query_safe = query_safe.replace(' ', '_')[:30] |
| 745 | 740 | ||
| 746 | filename = f"deep_search_report_{query_safe}_{timestamp}.md" | 741 | filename = f"deep_search_report_{query_safe}_{timestamp}.md" |
| 747 | - filepath = os.path.join(self.config.output_dir, filename) | 742 | + filepath = os.path.join(self.config.OUTPUT_DIR, filename) |
| 748 | 743 | ||
| 749 | # 保存报告 | 744 | # 保存报告 |
| 750 | with open(filepath, 'w', encoding='utf-8') as f: | 745 | with open(filepath, 'w', encoding='utf-8') as f: |
| 751 | f.write(report_content) | 746 | f.write(report_content) |
| 752 | 747 | ||
| 753 | - print(f"报告已保存到: {filepath}") | 748 | + logger.info(f"报告已保存到: {filepath}") |
| 754 | 749 | ||
| 755 | # 保存状态(如果配置允许) | 750 | # 保存状态(如果配置允许) |
| 756 | - if self.config.save_intermediate_states: | 751 | + if self.config.SAVE_INTERMEDIATE_STATES: |
| 757 | state_filename = f"state_{query_safe}_{timestamp}.json" | 752 | state_filename = f"state_{query_safe}_{timestamp}.json" |
| 758 | - state_filepath = os.path.join(self.config.output_dir, state_filename) | 753 | + state_filepath = os.path.join(self.config.OUTPUT_DIR, state_filename) |
| 759 | self.state.save_to_file(state_filepath) | 754 | self.state.save_to_file(state_filepath) |
| 760 | - print(f"状态已保存到: {state_filepath}") | 755 | + logger.info(f"状态已保存到: {state_filepath}") |
| 761 | 756 | ||
| 762 | def get_progress_summary(self) -> Dict[str, Any]: | 757 | def get_progress_summary(self) -> Dict[str, Any]: |
| 763 | """获取进度摘要""" | 758 | """获取进度摘要""" |
| @@ -766,12 +761,12 @@ class DeepSearchAgent: | @@ -766,12 +761,12 @@ class DeepSearchAgent: | ||
| 766 | def load_state(self, filepath: str): | 761 | def load_state(self, filepath: str): |
| 767 | """从文件加载状态""" | 762 | """从文件加载状态""" |
| 768 | self.state = State.load_from_file(filepath) | 763 | self.state = State.load_from_file(filepath) |
| 769 | - print(f"状态已从 {filepath} 加载") | 764 | + logger.info(f"状态已从 {filepath} 加载") |
| 770 | 765 | ||
| 771 | def save_state(self, filepath: str): | 766 | def save_state(self, filepath: str): |
| 772 | """保存状态到文件""" | 767 | """保存状态到文件""" |
| 773 | self.state.save_to_file(filepath) | 768 | self.state.save_to_file(filepath) |
| 774 | - print(f"状态已保存到 {filepath}") | 769 | + logger.info(f"状态已保存到 {filepath}") |
| 775 | 770 | ||
| 776 | 771 | ||
| 777 | def create_agent(config_file: Optional[str] = None) -> DeepSearchAgent: | 772 | def create_agent(config_file: Optional[str] = None) -> DeepSearchAgent: |
| @@ -784,5 +779,5 @@ def create_agent(config_file: Optional[str] = None) -> DeepSearchAgent: | @@ -784,5 +779,5 @@ def create_agent(config_file: Optional[str] = None) -> DeepSearchAgent: | ||
| 784 | Returns: | 779 | Returns: |
| 785 | DeepSearchAgent实例 | 780 | DeepSearchAgent实例 |
| 786 | """ | 781 | """ |
| 787 | - config = load_config(config_file) | 782 | + config = settings |
| 788 | return DeepSearchAgent(config) | 783 | return DeepSearchAgent(config) |
| @@ -31,9 +31,9 @@ class LLMClient: | @@ -31,9 +31,9 @@ class LLMClient: | ||
| 31 | 31 | ||
| 32 | def __init__(self, api_key: str, model_name: str, base_url: Optional[str] = None): | 32 | def __init__(self, api_key: str, model_name: str, base_url: Optional[str] = None): |
| 33 | if not api_key: | 33 | if not api_key: |
| 34 | - raise ValueError("Insight Engine LLM API key is required.") | 34 | + raise ValueError("Insight Engine INSIGHT_ENGINE_API_KEY is required.") |
| 35 | if not model_name: | 35 | if not model_name: |
| 36 | - raise ValueError("Insight Engine model name is required.") | 36 | + raise ValueError("Insight Engine INSIGHT_ENGINE_MODEL_NAME is required.") |
| 37 | 37 | ||
| 38 | self.api_key = api_key | 38 | self.api_key = api_key |
| 39 | self.base_url = base_url | 39 | self.base_url = base_url |
| @@ -5,6 +5,7 @@ | @@ -5,6 +5,7 @@ | ||
| 5 | 5 | ||
| 6 | from abc import ABC, abstractmethod | 6 | from abc import ABC, abstractmethod |
| 7 | from typing import Any, Dict, Optional | 7 | from typing import Any, Dict, Optional |
| 8 | +from loguru import logger | ||
| 8 | from ..llms.base import LLMClient | 9 | from ..llms.base import LLMClient |
| 9 | from ..state.state import State | 10 | from ..state.state import State |
| 10 | 11 | ||
| @@ -63,11 +64,15 @@ class BaseNode(ABC): | @@ -63,11 +64,15 @@ class BaseNode(ABC): | ||
| 63 | 64 | ||
| 64 | def log_info(self, message: str): | 65 | def log_info(self, message: str): |
| 65 | """记录信息日志""" | 66 | """记录信息日志""" |
| 66 | - print(f"[{self.node_name}] {message}") | 67 | + logger.info(f"[{self.node_name}] {message}") |
| 68 | + | ||
| 69 | + def log_warning(self, message: str): | ||
| 70 | + """记录警告日志""" | ||
| 71 | + logger.warning(f"[{self.node_name}] 警告: {message}") | ||
| 67 | 72 | ||
| 68 | def log_error(self, message: str): | 73 | def log_error(self, message: str): |
| 69 | """记录错误日志""" | 74 | """记录错误日志""" |
| 70 | - print(f"[{self.node_name}] 错误: {message}") | 75 | + logger.error(f"[{self.node_name}] 错误: {message}") |
| 71 | 76 | ||
| 72 | 77 | ||
| 73 | class StateMutationNode(BaseNode): | 78 | class StateMutationNode(BaseNode): |
| @@ -5,6 +5,7 @@ | @@ -5,6 +5,7 @@ | ||
| 5 | 5 | ||
| 6 | import json | 6 | import json |
| 7 | from typing import List, Dict, Any | 7 | from typing import List, Dict, Any |
| 8 | +from loguru import logger | ||
| 8 | 9 | ||
| 9 | from .base_node import BaseNode | 10 | from .base_node import BaseNode |
| 10 | from ..prompts import SYSTEM_PROMPT_REPORT_FORMATTING | 11 | from ..prompts import SYSTEM_PROMPT_REPORT_FORMATTING |
| @@ -14,6 +15,8 @@ from ..utils.text_processing import ( | @@ -14,6 +15,8 @@ from ..utils.text_processing import ( | ||
| 14 | ) | 15 | ) |
| 15 | 16 | ||
| 16 | 17 | ||
| 18 | + | ||
| 19 | + | ||
| 17 | class ReportFormattingNode(BaseNode): | 20 | class ReportFormattingNode(BaseNode): |
| 18 | """格式化最终报告的节点""" | 21 | """格式化最终报告的节点""" |
| 19 | 22 | ||
| @@ -65,19 +68,22 @@ class ReportFormattingNode(BaseNode): | @@ -65,19 +68,22 @@ class ReportFormattingNode(BaseNode): | ||
| 65 | else: | 68 | else: |
| 66 | message = json.dumps(input_data, ensure_ascii=False) | 69 | message = json.dumps(input_data, ensure_ascii=False) |
| 67 | 70 | ||
| 68 | - self.log_info("正在格式化最终报告") | 71 | + logger.info("正在格式化最终报告") |
| 69 | 72 | ||
| 70 | # 调用LLM | 73 | # 调用LLM |
| 71 | - response = self.llm_client.invoke(SYSTEM_PROMPT_REPORT_FORMATTING, message) | 74 | + response = self.llm_client.invoke( |
| 75 | + SYSTEM_PROMPT_REPORT_FORMATTING, | ||
| 76 | + message, | ||
| 77 | + ) | ||
| 72 | 78 | ||
| 73 | # 处理响应 | 79 | # 处理响应 |
| 74 | processed_response = self.process_output(response) | 80 | processed_response = self.process_output(response) |
| 75 | 81 | ||
| 76 | - self.log_info("成功生成格式化报告") | 82 | + logger.info("成功生成格式化报告") |
| 77 | return processed_response | 83 | return processed_response |
| 78 | 84 | ||
| 79 | except Exception as e: | 85 | except Exception as e: |
| 80 | - self.log_error(f"报告格式化失败: {str(e)}") | 86 | + logger.exception(f"报告格式化失败: {str(e)}") |
| 81 | raise e | 87 | raise e |
| 82 | 88 | ||
| 83 | def process_output(self, output: str) -> str: | 89 | def process_output(self, output: str) -> str: |
| @@ -106,7 +112,7 @@ class ReportFormattingNode(BaseNode): | @@ -106,7 +112,7 @@ class ReportFormattingNode(BaseNode): | ||
| 106 | return cleaned_output.strip() | 112 | return cleaned_output.strip() |
| 107 | 113 | ||
| 108 | except Exception as e: | 114 | except Exception as e: |
| 109 | - self.log_error(f"处理输出失败: {str(e)}") | 115 | + logger.exception(f"处理输出失败: {str(e)}") |
| 110 | return "# 报告处理失败\n\n报告格式化过程中发生错误。" | 116 | return "# 报告处理失败\n\n报告格式化过程中发生错误。" |
| 111 | 117 | ||
| 112 | def format_report_manually(self, paragraphs_data: List[Dict[str, str]], | 118 | def format_report_manually(self, paragraphs_data: List[Dict[str, str]], |
| @@ -122,7 +128,7 @@ class ReportFormattingNode(BaseNode): | @@ -122,7 +128,7 @@ class ReportFormattingNode(BaseNode): | ||
| 122 | 格式化的Markdown报告 | 128 | 格式化的Markdown报告 |
| 123 | """ | 129 | """ |
| 124 | try: | 130 | try: |
| 125 | - self.log_info("使用手动格式化方法") | 131 | + logger.info("使用手动格式化方法") |
| 126 | 132 | ||
| 127 | # 构建报告 | 133 | # 构建报告 |
| 128 | report_lines = [ | 134 | report_lines = [ |
| @@ -160,5 +166,5 @@ class ReportFormattingNode(BaseNode): | @@ -160,5 +166,5 @@ class ReportFormattingNode(BaseNode): | ||
| 160 | return "\n".join(report_lines) | 166 | return "\n".join(report_lines) |
| 161 | 167 | ||
| 162 | except Exception as e: | 168 | except Exception as e: |
| 163 | - self.log_error(f"手动格式化失败: {str(e)}") | 169 | + logger.exception(f"手动格式化失败: {str(e)}") |
| 164 | return "# 报告生成失败\n\n无法完成报告格式化。" | 170 | return "# 报告生成失败\n\n无法完成报告格式化。" |
| @@ -6,6 +6,7 @@ | @@ -6,6 +6,7 @@ | ||
| 6 | import json | 6 | import json |
| 7 | from typing import Dict, Any, List | 7 | from typing import Dict, Any, List |
| 8 | from json.decoder import JSONDecodeError | 8 | from json.decoder import JSONDecodeError |
| 9 | +from loguru import logger | ||
| 9 | 10 | ||
| 10 | from .base_node import StateMutationNode | 11 | from .base_node import StateMutationNode |
| 11 | from ..state.state import State | 12 | from ..state.state import State |
| @@ -48,7 +49,7 @@ class ReportStructureNode(StateMutationNode): | @@ -48,7 +49,7 @@ class ReportStructureNode(StateMutationNode): | ||
| 48 | 报告结构列表 | 49 | 报告结构列表 |
| 49 | """ | 50 | """ |
| 50 | try: | 51 | try: |
| 51 | - self.log_info(f"正在为查询生成报告结构: {self.query}") | 52 | + logger.info(f"正在为查询生成报告结构: {self.query}") |
| 52 | 53 | ||
| 53 | # 调用LLM | 54 | # 调用LLM |
| 54 | response = self.llm_client.invoke(SYSTEM_PROMPT_REPORT_STRUCTURE, self.query) | 55 | response = self.llm_client.invoke(SYSTEM_PROMPT_REPORT_STRUCTURE, self.query) |
| @@ -56,11 +57,11 @@ class ReportStructureNode(StateMutationNode): | @@ -56,11 +57,11 @@ class ReportStructureNode(StateMutationNode): | ||
| 56 | # 处理响应 | 57 | # 处理响应 |
| 57 | processed_response = self.process_output(response) | 58 | processed_response = self.process_output(response) |
| 58 | 59 | ||
| 59 | - self.log_info(f"成功生成 {len(processed_response)} 个段落结构") | 60 | + logger.info(f"成功生成 {len(processed_response)} 个段落结构") |
| 60 | return processed_response | 61 | return processed_response |
| 61 | 62 | ||
| 62 | except Exception as e: | 63 | except Exception as e: |
| 63 | - self.log_error(f"生成报告结构失败: {str(e)}") | 64 | + logger.exception(f"生成报告结构失败: {str(e)}") |
| 64 | raise e | 65 | raise e |
| 65 | 66 | ||
| 66 | def process_output(self, output: str) -> List[Dict[str, str]]: | 67 | def process_output(self, output: str) -> List[Dict[str, str]]: |
| @@ -79,54 +80,54 @@ class ReportStructureNode(StateMutationNode): | @@ -79,54 +80,54 @@ class ReportStructureNode(StateMutationNode): | ||
| 79 | cleaned_output = clean_json_tags(cleaned_output) | 80 | cleaned_output = clean_json_tags(cleaned_output) |
| 80 | 81 | ||
| 81 | # 记录清理后的输出用于调试 | 82 | # 记录清理后的输出用于调试 |
| 82 | - self.log_info(f"清理后的输出: {cleaned_output}") | 83 | + logger.info(f"清理后的输出: {cleaned_output}") |
| 83 | 84 | ||
| 84 | # 解析JSON | 85 | # 解析JSON |
| 85 | try: | 86 | try: |
| 86 | report_structure = json.loads(cleaned_output) | 87 | report_structure = json.loads(cleaned_output) |
| 87 | - self.log_info("JSON解析成功") | 88 | + logger.info("JSON解析成功") |
| 88 | except JSONDecodeError as e: | 89 | except JSONDecodeError as e: |
| 89 | - self.log_info(f"JSON解析失败: {str(e)}") | 90 | + logger.exception(f"JSON解析失败: {str(e)}") |
| 90 | # 使用更强大的提取方法 | 91 | # 使用更强大的提取方法 |
| 91 | report_structure = extract_clean_response(cleaned_output) | 92 | report_structure = extract_clean_response(cleaned_output) |
| 92 | if "error" in report_structure: | 93 | if "error" in report_structure: |
| 93 | - self.log_error("JSON解析失败,尝试修复...") | 94 | + logger.exception("JSON解析失败,尝试修复...") |
| 94 | # 尝试修复JSON | 95 | # 尝试修复JSON |
| 95 | fixed_json = fix_incomplete_json(cleaned_output) | 96 | fixed_json = fix_incomplete_json(cleaned_output) |
| 96 | if fixed_json: | 97 | if fixed_json: |
| 97 | try: | 98 | try: |
| 98 | report_structure = json.loads(fixed_json) | 99 | report_structure = json.loads(fixed_json) |
| 99 | - self.log_info("JSON修复成功") | 100 | + logger.info("JSON修复成功") |
| 100 | except JSONDecodeError: | 101 | except JSONDecodeError: |
| 101 | - self.log_error("JSON修复失败") | 102 | + logger.exception("JSON修复失败") |
| 102 | # 返回默认结构 | 103 | # 返回默认结构 |
| 103 | return self._generate_default_structure() | 104 | return self._generate_default_structure() |
| 104 | else: | 105 | else: |
| 105 | - self.log_error("无法修复JSON,使用默认结构") | 106 | + logger.exception("无法修复JSON,使用默认结构") |
| 106 | return self._generate_default_structure() | 107 | return self._generate_default_structure() |
| 107 | 108 | ||
| 108 | # 验证结构 | 109 | # 验证结构 |
| 109 | if not isinstance(report_structure, list): | 110 | if not isinstance(report_structure, list): |
| 110 | - self.log_info("报告结构不是列表,尝试转换...") | 111 | + logger.info("报告结构不是列表,尝试转换...") |
| 111 | if isinstance(report_structure, dict): | 112 | if isinstance(report_structure, dict): |
| 112 | # 如果是单个对象,包装成列表 | 113 | # 如果是单个对象,包装成列表 |
| 113 | report_structure = [report_structure] | 114 | report_structure = [report_structure] |
| 114 | else: | 115 | else: |
| 115 | - self.log_error("报告结构格式无效,使用默认结构") | 116 | + logger.exception("报告结构格式无效,使用默认结构") |
| 116 | return self._generate_default_structure() | 117 | return self._generate_default_structure() |
| 117 | 118 | ||
| 118 | # 验证每个段落 | 119 | # 验证每个段落 |
| 119 | validated_structure = [] | 120 | validated_structure = [] |
| 120 | for i, paragraph in enumerate(report_structure): | 121 | for i, paragraph in enumerate(report_structure): |
| 121 | if not isinstance(paragraph, dict): | 122 | if not isinstance(paragraph, dict): |
| 122 | - self.log_warning(f"段落 {i+1} 不是字典格式,跳过") | 123 | + logger.warning(f"段落 {i+1} 不是字典格式,跳过") |
| 123 | continue | 124 | continue |
| 124 | 125 | ||
| 125 | title = paragraph.get("title", f"段落 {i+1}") | 126 | title = paragraph.get("title", f"段落 {i+1}") |
| 126 | content = paragraph.get("content", "") | 127 | content = paragraph.get("content", "") |
| 127 | 128 | ||
| 128 | if not title or not content: | 129 | if not title or not content: |
| 129 | - self.log_warning(f"段落 {i+1} 缺少标题或内容,跳过") | 130 | + logger.warning(f"段落 {i+1} 缺少标题或内容,跳过") |
| 130 | continue | 131 | continue |
| 131 | 132 | ||
| 132 | validated_structure.append({ | 133 | validated_structure.append({ |
| @@ -135,14 +136,14 @@ class ReportStructureNode(StateMutationNode): | @@ -135,14 +136,14 @@ class ReportStructureNode(StateMutationNode): | ||
| 135 | }) | 136 | }) |
| 136 | 137 | ||
| 137 | if not validated_structure: | 138 | if not validated_structure: |
| 138 | - self.log_warning("没有有效的段落结构,使用默认结构") | 139 | + logger.warning("没有有效的段落结构,使用默认结构") |
| 139 | return self._generate_default_structure() | 140 | return self._generate_default_structure() |
| 140 | 141 | ||
| 141 | - self.log_info(f"成功验证 {len(validated_structure)} 个段落结构") | 142 | + logger.info(f"成功验证 {len(validated_structure)} 个段落结构") |
| 142 | return validated_structure | 143 | return validated_structure |
| 143 | 144 | ||
| 144 | except Exception as e: | 145 | except Exception as e: |
| 145 | - self.log_error(f"处理输出失败: {str(e)}") | 146 | + logger.exception(f"处理输出失败: {str(e)}") |
| 146 | return self._generate_default_structure() | 147 | return self._generate_default_structure() |
| 147 | 148 | ||
| 148 | def _generate_default_structure(self) -> List[Dict[str, str]]: | 149 | def _generate_default_structure(self) -> List[Dict[str, str]]: |
| @@ -152,7 +153,7 @@ class ReportStructureNode(StateMutationNode): | @@ -152,7 +153,7 @@ class ReportStructureNode(StateMutationNode): | ||
| 152 | Returns: | 153 | Returns: |
| 153 | 默认的报告结构列表 | 154 | 默认的报告结构列表 |
| 154 | """ | 155 | """ |
| 155 | - self.log_info("生成默认报告结构") | 156 | + logger.info("生成默认报告结构") |
| 156 | return [ | 157 | return [ |
| 157 | { | 158 | { |
| 158 | "title": "研究概述", | 159 | "title": "研究概述", |
| @@ -195,9 +196,9 @@ class ReportStructureNode(StateMutationNode): | @@ -195,9 +196,9 @@ class ReportStructureNode(StateMutationNode): | ||
| 195 | content=paragraph_data["content"] | 196 | content=paragraph_data["content"] |
| 196 | ) | 197 | ) |
| 197 | 198 | ||
| 198 | - self.log_info(f"已将 {len(report_structure)} 个段落添加到状态中") | 199 | + logger.info(f"已将 {len(report_structure)} 个段落添加到状态中") |
| 199 | return state | 200 | return state |
| 200 | 201 | ||
| 201 | except Exception as e: | 202 | except Exception as e: |
| 202 | - self.log_error(f"状态更新失败: {str(e)}") | 203 | + logger.exception(f"状态更新失败: {str(e)}") |
| 203 | raise e | 204 | raise e |
| @@ -6,6 +6,7 @@ | @@ -6,6 +6,7 @@ | ||
| 6 | import json | 6 | import json |
| 7 | from typing import Dict, Any | 7 | from typing import Dict, Any |
| 8 | from json.decoder import JSONDecodeError | 8 | from json.decoder import JSONDecodeError |
| 9 | +from loguru import logger | ||
| 9 | 10 | ||
| 10 | from .base_node import BaseNode | 11 | from .base_node import BaseNode |
| 11 | from ..prompts import SYSTEM_PROMPT_FIRST_SEARCH, SYSTEM_PROMPT_REFLECTION | 12 | from ..prompts import SYSTEM_PROMPT_FIRST_SEARCH, SYSTEM_PROMPT_REFLECTION |
| @@ -62,7 +63,7 @@ class FirstSearchNode(BaseNode): | @@ -62,7 +63,7 @@ class FirstSearchNode(BaseNode): | ||
| 62 | else: | 63 | else: |
| 63 | message = json.dumps(input_data, ensure_ascii=False) | 64 | message = json.dumps(input_data, ensure_ascii=False) |
| 64 | 65 | ||
| 65 | - self.log_info("正在生成首次搜索查询") | 66 | + logger.info("正在生成首次搜索查询") |
| 66 | 67 | ||
| 67 | # 调用LLM | 68 | # 调用LLM |
| 68 | response = self.llm_client.invoke(SYSTEM_PROMPT_FIRST_SEARCH, message) | 69 | response = self.llm_client.invoke(SYSTEM_PROMPT_FIRST_SEARCH, message) |
| @@ -70,11 +71,11 @@ class FirstSearchNode(BaseNode): | @@ -70,11 +71,11 @@ class FirstSearchNode(BaseNode): | ||
| 70 | # 处理响应 | 71 | # 处理响应 |
| 71 | processed_response = self.process_output(response) | 72 | processed_response = self.process_output(response) |
| 72 | 73 | ||
| 73 | - self.log_info(f"生成搜索查询: {processed_response.get('search_query', 'N/A')}") | 74 | + logger.info(f"生成搜索查询: {processed_response.get('search_query', 'N/A')}") |
| 74 | return processed_response | 75 | return processed_response |
| 75 | 76 | ||
| 76 | except Exception as e: | 77 | except Exception as e: |
| 77 | - self.log_error(f"生成首次搜索查询失败: {str(e)}") | 78 | + logger.exception(f"生成首次搜索查询失败: {str(e)}") |
| 78 | raise e | 79 | raise e |
| 79 | 80 | ||
| 80 | def process_output(self, output: str) -> Dict[str, str]: | 81 | def process_output(self, output: str) -> Dict[str, str]: |
| @@ -93,30 +94,30 @@ class FirstSearchNode(BaseNode): | @@ -93,30 +94,30 @@ class FirstSearchNode(BaseNode): | ||
| 93 | cleaned_output = clean_json_tags(cleaned_output) | 94 | cleaned_output = clean_json_tags(cleaned_output) |
| 94 | 95 | ||
| 95 | # 记录清理后的输出用于调试 | 96 | # 记录清理后的输出用于调试 |
| 96 | - self.log_info(f"清理后的输出: {cleaned_output}") | 97 | + logger.info(f"清理后的输出: {cleaned_output}") |
| 97 | 98 | ||
| 98 | # 解析JSON | 99 | # 解析JSON |
| 99 | try: | 100 | try: |
| 100 | result = json.loads(cleaned_output) | 101 | result = json.loads(cleaned_output) |
| 101 | - self.log_info("JSON解析成功") | 102 | + logger.info("JSON解析成功") |
| 102 | except JSONDecodeError as e: | 103 | except JSONDecodeError as e: |
| 103 | - self.log_info(f"JSON解析失败: {str(e)}") | 104 | + logger.exception(f"JSON解析失败: {str(e)}") |
| 104 | # 使用更强大的提取方法 | 105 | # 使用更强大的提取方法 |
| 105 | result = extract_clean_response(cleaned_output) | 106 | result = extract_clean_response(cleaned_output) |
| 106 | if "error" in result: | 107 | if "error" in result: |
| 107 | - self.log_error("JSON解析失败,尝试修复...") | 108 | + logger.error("JSON解析失败,尝试修复...") |
| 108 | # 尝试修复JSON | 109 | # 尝试修复JSON |
| 109 | fixed_json = fix_incomplete_json(cleaned_output) | 110 | fixed_json = fix_incomplete_json(cleaned_output) |
| 110 | if fixed_json: | 111 | if fixed_json: |
| 111 | try: | 112 | try: |
| 112 | result = json.loads(fixed_json) | 113 | result = json.loads(fixed_json) |
| 113 | - self.log_info("JSON修复成功") | 114 | + logger.info("JSON修复成功") |
| 114 | except JSONDecodeError: | 115 | except JSONDecodeError: |
| 115 | - self.log_error("JSON修复失败") | 116 | + logger.error("JSON修复失败") |
| 116 | # 返回默认查询 | 117 | # 返回默认查询 |
| 117 | return self._get_default_search_query() | 118 | return self._get_default_search_query() |
| 118 | else: | 119 | else: |
| 119 | - self.log_error("无法修复JSON,使用默认查询") | 120 | + logger.error("无法修复JSON,使用默认查询") |
| 120 | return self._get_default_search_query() | 121 | return self._get_default_search_query() |
| 121 | 122 | ||
| 122 | # 验证和清理结果 | 123 | # 验证和清理结果 |
| @@ -124,7 +125,7 @@ class FirstSearchNode(BaseNode): | @@ -124,7 +125,7 @@ class FirstSearchNode(BaseNode): | ||
| 124 | reasoning = result.get("reasoning", "") | 125 | reasoning = result.get("reasoning", "") |
| 125 | 126 | ||
| 126 | if not search_query: | 127 | if not search_query: |
| 127 | - self.log_warning("未找到搜索查询,使用默认查询") | 128 | + logger.warning("未找到搜索查询,使用默认查询") |
| 128 | return self._get_default_search_query() | 129 | return self._get_default_search_query() |
| 129 | 130 | ||
| 130 | return { | 131 | return { |
| @@ -197,7 +198,7 @@ class ReflectionNode(BaseNode): | @@ -197,7 +198,7 @@ class ReflectionNode(BaseNode): | ||
| 197 | else: | 198 | else: |
| 198 | message = json.dumps(input_data, ensure_ascii=False) | 199 | message = json.dumps(input_data, ensure_ascii=False) |
| 199 | 200 | ||
| 200 | - self.log_info("正在进行反思并生成新搜索查询") | 201 | + logger.info("正在进行反思并生成新搜索查询") |
| 201 | 202 | ||
| 202 | # 调用LLM | 203 | # 调用LLM |
| 203 | response = self.llm_client.invoke(SYSTEM_PROMPT_REFLECTION, message) | 204 | response = self.llm_client.invoke(SYSTEM_PROMPT_REFLECTION, message) |
| @@ -205,11 +206,11 @@ class ReflectionNode(BaseNode): | @@ -205,11 +206,11 @@ class ReflectionNode(BaseNode): | ||
| 205 | # 处理响应 | 206 | # 处理响应 |
| 206 | processed_response = self.process_output(response) | 207 | processed_response = self.process_output(response) |
| 207 | 208 | ||
| 208 | - self.log_info(f"反思生成搜索查询: {processed_response.get('search_query', 'N/A')}") | 209 | + logger.info(f"反思生成搜索查询: {processed_response.get('search_query', 'N/A')}") |
| 209 | return processed_response | 210 | return processed_response |
| 210 | 211 | ||
| 211 | except Exception as e: | 212 | except Exception as e: |
| 212 | - self.log_error(f"反思生成搜索查询失败: {str(e)}") | 213 | + logger.exception(f"反思生成搜索查询失败: {str(e)}") |
| 213 | raise e | 214 | raise e |
| 214 | 215 | ||
| 215 | def process_output(self, output: str) -> Dict[str, str]: | 216 | def process_output(self, output: str) -> Dict[str, str]: |
| @@ -228,30 +229,30 @@ class ReflectionNode(BaseNode): | @@ -228,30 +229,30 @@ class ReflectionNode(BaseNode): | ||
| 228 | cleaned_output = clean_json_tags(cleaned_output) | 229 | cleaned_output = clean_json_tags(cleaned_output) |
| 229 | 230 | ||
| 230 | # 记录清理后的输出用于调试 | 231 | # 记录清理后的输出用于调试 |
| 231 | - self.log_info(f"清理后的输出: {cleaned_output}") | 232 | + logger.info(f"清理后的输出: {cleaned_output}") |
| 232 | 233 | ||
| 233 | # 解析JSON | 234 | # 解析JSON |
| 234 | try: | 235 | try: |
| 235 | result = json.loads(cleaned_output) | 236 | result = json.loads(cleaned_output) |
| 236 | - self.log_info("JSON解析成功") | 237 | + logger.info("JSON解析成功") |
| 237 | except JSONDecodeError as e: | 238 | except JSONDecodeError as e: |
| 238 | - self.log_info(f"JSON解析失败: {str(e)}") | 239 | + logger.exception(f"JSON解析失败: {str(e)}") |
| 239 | # 使用更强大的提取方法 | 240 | # 使用更强大的提取方法 |
| 240 | result = extract_clean_response(cleaned_output) | 241 | result = extract_clean_response(cleaned_output) |
| 241 | if "error" in result: | 242 | if "error" in result: |
| 242 | - self.log_error("JSON解析失败,尝试修复...") | 243 | + logger.error("JSON解析失败,尝试修复...") |
| 243 | # 尝试修复JSON | 244 | # 尝试修复JSON |
| 244 | fixed_json = fix_incomplete_json(cleaned_output) | 245 | fixed_json = fix_incomplete_json(cleaned_output) |
| 245 | if fixed_json: | 246 | if fixed_json: |
| 246 | try: | 247 | try: |
| 247 | result = json.loads(fixed_json) | 248 | result = json.loads(fixed_json) |
| 248 | - self.log_info("JSON修复成功") | 249 | + logger.info("JSON修复成功") |
| 249 | except JSONDecodeError: | 250 | except JSONDecodeError: |
| 250 | - self.log_error("JSON修复失败") | 251 | + logger.error("JSON修复失败") |
| 251 | # 返回默认查询 | 252 | # 返回默认查询 |
| 252 | return self._get_default_reflection_query() | 253 | return self._get_default_reflection_query() |
| 253 | else: | 254 | else: |
| 254 | - self.log_error("无法修复JSON,使用默认查询") | 255 | + logger.error("无法修复JSON,使用默认查询") |
| 255 | return self._get_default_reflection_query() | 256 | return self._get_default_reflection_query() |
| 256 | 257 | ||
| 257 | # 验证和清理结果 | 258 | # 验证和清理结果 |
| @@ -259,7 +260,7 @@ class ReflectionNode(BaseNode): | @@ -259,7 +260,7 @@ class ReflectionNode(BaseNode): | ||
| 259 | reasoning = result.get("reasoning", "") | 260 | reasoning = result.get("reasoning", "") |
| 260 | 261 | ||
| 261 | if not search_query: | 262 | if not search_query: |
| 262 | - self.log_warning("未找到搜索查询,使用默认查询") | 263 | + logger.warning("未找到搜索查询,使用默认查询") |
| 263 | return self._get_default_reflection_query() | 264 | return self._get_default_reflection_query() |
| 264 | 265 | ||
| 265 | return { | 266 | return { |
| @@ -268,7 +269,7 @@ class ReflectionNode(BaseNode): | @@ -268,7 +269,7 @@ class ReflectionNode(BaseNode): | ||
| 268 | } | 269 | } |
| 269 | 270 | ||
| 270 | except Exception as e: | 271 | except Exception as e: |
| 271 | - self.log_error(f"处理输出失败: {str(e)}") | 272 | + logger.exception(f"处理输出失败: {str(e)}") |
| 272 | # 返回默认查询 | 273 | # 返回默认查询 |
| 273 | return self._get_default_reflection_query() | 274 | return self._get_default_reflection_query() |
| 274 | 275 |
| @@ -6,6 +6,7 @@ | @@ -6,6 +6,7 @@ | ||
| 6 | import json | 6 | import json |
| 7 | from typing import Dict, Any, List | 7 | from typing import Dict, Any, List |
| 8 | from json.decoder import JSONDecodeError | 8 | from json.decoder import JSONDecodeError |
| 9 | +from loguru import logger | ||
| 9 | 10 | ||
| 10 | from .base_node import StateMutationNode | 11 | from .base_node import StateMutationNode |
| 11 | from ..state.state import State | 12 | from ..state.state import State |
| @@ -27,7 +28,7 @@ try: | @@ -27,7 +28,7 @@ try: | ||
| 27 | FORUM_READER_AVAILABLE = True | 28 | FORUM_READER_AVAILABLE = True |
| 28 | except ImportError: | 29 | except ImportError: |
| 29 | FORUM_READER_AVAILABLE = False | 30 | FORUM_READER_AVAILABLE = False |
| 30 | - print("警告: 无法导入forum_reader模块,将跳过HOST发言读取功能") | 31 | + logger.warning("无法导入forum_reader模块,将跳过HOST发言读取功能") |
| 31 | 32 | ||
| 32 | 33 | ||
| 33 | class FirstSummaryNode(StateMutationNode): | 34 | class FirstSummaryNode(StateMutationNode): |
| @@ -84,9 +85,9 @@ class FirstSummaryNode(StateMutationNode): | @@ -84,9 +85,9 @@ class FirstSummaryNode(StateMutationNode): | ||
| 84 | if host_speech: | 85 | if host_speech: |
| 85 | # 将HOST发言添加到输入数据中 | 86 | # 将HOST发言添加到输入数据中 |
| 86 | data['host_speech'] = host_speech | 87 | data['host_speech'] = host_speech |
| 87 | - self.log_info(f"已读取HOST发言,长度: {len(host_speech)}字符") | 88 | + logger.info(f"已读取HOST发言,长度: {len(host_speech)}字符") |
| 88 | except Exception as e: | 89 | except Exception as e: |
| 89 | - self.log_info(f"读取HOST发言失败: {str(e)}") | 90 | + logger.exception(f"读取HOST发言失败: {str(e)}") |
| 90 | 91 | ||
| 91 | # 转换为JSON字符串 | 92 | # 转换为JSON字符串 |
| 92 | message = json.dumps(data, ensure_ascii=False) | 93 | message = json.dumps(data, ensure_ascii=False) |
| @@ -96,7 +97,7 @@ class FirstSummaryNode(StateMutationNode): | @@ -96,7 +97,7 @@ class FirstSummaryNode(StateMutationNode): | ||
| 96 | formatted_host = format_host_speech_for_prompt(data['host_speech']) | 97 | formatted_host = format_host_speech_for_prompt(data['host_speech']) |
| 97 | message = formatted_host + "\n" + message | 98 | message = formatted_host + "\n" + message |
| 98 | 99 | ||
| 99 | - self.log_info("正在生成首次段落总结") | 100 | + logger.info("正在生成首次段落总结") |
| 100 | 101 | ||
| 101 | # 调用LLM | 102 | # 调用LLM |
| 102 | response = self.llm_client.invoke(SYSTEM_PROMPT_FIRST_SUMMARY, message) | 103 | response = self.llm_client.invoke(SYSTEM_PROMPT_FIRST_SUMMARY, message) |
| @@ -104,11 +105,11 @@ class FirstSummaryNode(StateMutationNode): | @@ -104,11 +105,11 @@ class FirstSummaryNode(StateMutationNode): | ||
| 104 | # 处理响应 | 105 | # 处理响应 |
| 105 | processed_response = self.process_output(response) | 106 | processed_response = self.process_output(response) |
| 106 | 107 | ||
| 107 | - self.log_info("成功生成首次段落总结") | 108 | + logger.info("成功生成首次段落总结") |
| 108 | return processed_response | 109 | return processed_response |
| 109 | 110 | ||
| 110 | except Exception as e: | 111 | except Exception as e: |
| 111 | - self.log_error(f"生成首次总结失败: {str(e)}") | 112 | + logger.exception(f"生成首次总结失败: {str(e)}") |
| 112 | raise e | 113 | raise e |
| 113 | 114 | ||
| 114 | def process_output(self, output: str) -> str: | 115 | def process_output(self, output: str) -> str: |
| @@ -127,26 +128,26 @@ class FirstSummaryNode(StateMutationNode): | @@ -127,26 +128,26 @@ class FirstSummaryNode(StateMutationNode): | ||
| 127 | cleaned_output = clean_json_tags(cleaned_output) | 128 | cleaned_output = clean_json_tags(cleaned_output) |
| 128 | 129 | ||
| 129 | # 记录清理后的输出用于调试 | 130 | # 记录清理后的输出用于调试 |
| 130 | - self.log_info(f"清理后的输出: {cleaned_output}") | 131 | + logger.info(f"清理后的输出: {cleaned_output}") |
| 131 | 132 | ||
| 132 | # 解析JSON | 133 | # 解析JSON |
| 133 | try: | 134 | try: |
| 134 | result = json.loads(cleaned_output) | 135 | result = json.loads(cleaned_output) |
| 135 | - self.log_info("JSON解析成功") | 136 | + logger.info("JSON解析成功") |
| 136 | except JSONDecodeError as e: | 137 | except JSONDecodeError as e: |
| 137 | - self.log_info(f"JSON解析失败: {str(e)}") | 138 | + logger.exception(f"JSON解析失败: {str(e)}") |
| 138 | # 尝试修复JSON | 139 | # 尝试修复JSON |
| 139 | fixed_json = fix_incomplete_json(cleaned_output) | 140 | fixed_json = fix_incomplete_json(cleaned_output) |
| 140 | if fixed_json: | 141 | if fixed_json: |
| 141 | try: | 142 | try: |
| 142 | result = json.loads(fixed_json) | 143 | result = json.loads(fixed_json) |
| 143 | - self.log_info("JSON修复成功") | 144 | + logger.info("JSON修复成功") |
| 144 | except JSONDecodeError: | 145 | except JSONDecodeError: |
| 145 | - self.log_info("JSON修复失败,直接使用清理后的文本") | 146 | + logger.exception("JSON修复失败,直接使用清理后的文本") |
| 146 | # 如果不是JSON格式,直接返回清理后的文本 | 147 | # 如果不是JSON格式,直接返回清理后的文本 |
| 147 | return cleaned_output | 148 | return cleaned_output |
| 148 | else: | 149 | else: |
| 149 | - self.log_info("无法修复JSON,直接使用清理后的文本") | 150 | + logger.exception("无法修复JSON,直接使用清理后的文本") |
| 150 | # 如果不是JSON格式,直接返回清理后的文本 | 151 | # 如果不是JSON格式,直接返回清理后的文本 |
| 151 | return cleaned_output | 152 | return cleaned_output |
| 152 | 153 | ||
| @@ -160,7 +161,7 @@ class FirstSummaryNode(StateMutationNode): | @@ -160,7 +161,7 @@ class FirstSummaryNode(StateMutationNode): | ||
| 160 | return cleaned_output | 161 | return cleaned_output |
| 161 | 162 | ||
| 162 | except Exception as e: | 163 | except Exception as e: |
| 163 | - self.log_error(f"处理输出失败: {str(e)}") | 164 | + logger.exception(f"处理输出失败: {str(e)}") |
| 164 | return "段落总结生成失败" | 165 | return "段落总结生成失败" |
| 165 | 166 | ||
| 166 | def mutate_state(self, input_data: Any, state: State, paragraph_index: int, **kwargs) -> State: | 167 | def mutate_state(self, input_data: Any, state: State, paragraph_index: int, **kwargs) -> State: |
| @@ -183,7 +184,7 @@ class FirstSummaryNode(StateMutationNode): | @@ -183,7 +184,7 @@ class FirstSummaryNode(StateMutationNode): | ||
| 183 | # 更新状态 | 184 | # 更新状态 |
| 184 | if 0 <= paragraph_index < len(state.paragraphs): | 185 | if 0 <= paragraph_index < len(state.paragraphs): |
| 185 | state.paragraphs[paragraph_index].research.latest_summary = summary | 186 | state.paragraphs[paragraph_index].research.latest_summary = summary |
| 186 | - self.log_info(f"已更新段落 {paragraph_index} 的首次总结") | 187 | + logger.info(f"已更新段落 {paragraph_index} 的首次总结") |
| 187 | else: | 188 | else: |
| 188 | raise ValueError(f"段落索引 {paragraph_index} 超出范围") | 189 | raise ValueError(f"段落索引 {paragraph_index} 超出范围") |
| 189 | 190 | ||
| @@ -191,7 +192,7 @@ class FirstSummaryNode(StateMutationNode): | @@ -191,7 +192,7 @@ class FirstSummaryNode(StateMutationNode): | ||
| 191 | return state | 192 | return state |
| 192 | 193 | ||
| 193 | except Exception as e: | 194 | except Exception as e: |
| 194 | - self.log_error(f"状态更新失败: {str(e)}") | 195 | + logger.exception(f"状态更新失败: {str(e)}") |
| 195 | raise e | 196 | raise e |
| 196 | 197 | ||
| 197 | 198 | ||
| @@ -249,9 +250,9 @@ class ReflectionSummaryNode(StateMutationNode): | @@ -249,9 +250,9 @@ class ReflectionSummaryNode(StateMutationNode): | ||
| 249 | if host_speech: | 250 | if host_speech: |
| 250 | # 将HOST发言添加到输入数据中 | 251 | # 将HOST发言添加到输入数据中 |
| 251 | data['host_speech'] = host_speech | 252 | data['host_speech'] = host_speech |
| 252 | - self.log_info(f"已读取HOST发言,长度: {len(host_speech)}字符") | 253 | + logger.info(f"已读取HOST发言,长度: {len(host_speech)}字符") |
| 253 | except Exception as e: | 254 | except Exception as e: |
| 254 | - self.log_info(f"读取HOST发言失败: {str(e)}") | 255 | + logger.exception(f"读取HOST发言失败: {str(e)}") |
| 255 | 256 | ||
| 256 | # 转换为JSON字符串 | 257 | # 转换为JSON字符串 |
| 257 | message = json.dumps(data, ensure_ascii=False) | 258 | message = json.dumps(data, ensure_ascii=False) |
| @@ -261,7 +262,7 @@ class ReflectionSummaryNode(StateMutationNode): | @@ -261,7 +262,7 @@ class ReflectionSummaryNode(StateMutationNode): | ||
| 261 | formatted_host = format_host_speech_for_prompt(data['host_speech']) | 262 | formatted_host = format_host_speech_for_prompt(data['host_speech']) |
| 262 | message = formatted_host + "\n" + message | 263 | message = formatted_host + "\n" + message |
| 263 | 264 | ||
| 264 | - self.log_info("正在生成反思总结") | 265 | + logger.info("正在生成反思总结") |
| 265 | 266 | ||
| 266 | # 调用LLM | 267 | # 调用LLM |
| 267 | response = self.llm_client.invoke(SYSTEM_PROMPT_REFLECTION_SUMMARY, message) | 268 | response = self.llm_client.invoke(SYSTEM_PROMPT_REFLECTION_SUMMARY, message) |
| @@ -269,11 +270,11 @@ class ReflectionSummaryNode(StateMutationNode): | @@ -269,11 +270,11 @@ class ReflectionSummaryNode(StateMutationNode): | ||
| 269 | # 处理响应 | 270 | # 处理响应 |
| 270 | processed_response = self.process_output(response) | 271 | processed_response = self.process_output(response) |
| 271 | 272 | ||
| 272 | - self.log_info("成功生成反思总结") | 273 | + logger.info("成功生成反思总结") |
| 273 | return processed_response | 274 | return processed_response |
| 274 | 275 | ||
| 275 | except Exception as e: | 276 | except Exception as e: |
| 276 | - self.log_error(f"生成反思总结失败: {str(e)}") | 277 | + logger.exception(f"生成反思总结失败: {str(e)}") |
| 277 | raise e | 278 | raise e |
| 278 | 279 | ||
| 279 | def process_output(self, output: str) -> str: | 280 | def process_output(self, output: str) -> str: |
| @@ -292,26 +293,26 @@ class ReflectionSummaryNode(StateMutationNode): | @@ -292,26 +293,26 @@ class ReflectionSummaryNode(StateMutationNode): | ||
| 292 | cleaned_output = clean_json_tags(cleaned_output) | 293 | cleaned_output = clean_json_tags(cleaned_output) |
| 293 | 294 | ||
| 294 | # 记录清理后的输出用于调试 | 295 | # 记录清理后的输出用于调试 |
| 295 | - self.log_info(f"清理后的输出: {cleaned_output}") | 296 | + logger.info(f"清理后的输出: {cleaned_output}") |
| 296 | 297 | ||
| 297 | # 解析JSON | 298 | # 解析JSON |
| 298 | try: | 299 | try: |
| 299 | result = json.loads(cleaned_output) | 300 | result = json.loads(cleaned_output) |
| 300 | - self.log_info("JSON解析成功") | 301 | + logger.info("JSON解析成功") |
| 301 | except JSONDecodeError as e: | 302 | except JSONDecodeError as e: |
| 302 | - self.log_info(f"JSON解析失败: {str(e)}") | 303 | + logger.exception(f"JSON解析失败: {str(e)}") |
| 303 | # 尝试修复JSON | 304 | # 尝试修复JSON |
| 304 | fixed_json = fix_incomplete_json(cleaned_output) | 305 | fixed_json = fix_incomplete_json(cleaned_output) |
| 305 | if fixed_json: | 306 | if fixed_json: |
| 306 | try: | 307 | try: |
| 307 | result = json.loads(fixed_json) | 308 | result = json.loads(fixed_json) |
| 308 | - self.log_info("JSON修复成功") | 309 | + logger.info("JSON修复成功") |
| 309 | except JSONDecodeError: | 310 | except JSONDecodeError: |
| 310 | - self.log_info("JSON修复失败,直接使用清理后的文本") | 311 | + logger.info("JSON修复失败,直接使用清理后的文本") |
| 311 | # 如果不是JSON格式,直接返回清理后的文本 | 312 | # 如果不是JSON格式,直接返回清理后的文本 |
| 312 | return cleaned_output | 313 | return cleaned_output |
| 313 | else: | 314 | else: |
| 314 | - self.log_info("无法修复JSON,直接使用清理后的文本") | 315 | + logger.info("无法修复JSON,直接使用清理后的文本") |
| 315 | # 如果不是JSON格式,直接返回清理后的文本 | 316 | # 如果不是JSON格式,直接返回清理后的文本 |
| 316 | return cleaned_output | 317 | return cleaned_output |
| 317 | 318 | ||
| @@ -325,7 +326,7 @@ class ReflectionSummaryNode(StateMutationNode): | @@ -325,7 +326,7 @@ class ReflectionSummaryNode(StateMutationNode): | ||
| 325 | return cleaned_output | 326 | return cleaned_output |
| 326 | 327 | ||
| 327 | except Exception as e: | 328 | except Exception as e: |
| 328 | - self.log_error(f"处理输出失败: {str(e)}") | 329 | + logger.exception(f"处理输出失败: {str(e)}") |
| 329 | return "反思总结生成失败" | 330 | return "反思总结生成失败" |
| 330 | 331 | ||
| 331 | def mutate_state(self, input_data: Any, state: State, paragraph_index: int, **kwargs) -> State: | 332 | def mutate_state(self, input_data: Any, state: State, paragraph_index: int, **kwargs) -> State: |
| @@ -349,7 +350,7 @@ class ReflectionSummaryNode(StateMutationNode): | @@ -349,7 +350,7 @@ class ReflectionSummaryNode(StateMutationNode): | ||
| 349 | if 0 <= paragraph_index < len(state.paragraphs): | 350 | if 0 <= paragraph_index < len(state.paragraphs): |
| 350 | state.paragraphs[paragraph_index].research.latest_summary = updated_summary | 351 | state.paragraphs[paragraph_index].research.latest_summary = updated_summary |
| 351 | state.paragraphs[paragraph_index].research.increment_reflection() | 352 | state.paragraphs[paragraph_index].research.increment_reflection() |
| 352 | - self.log_info(f"已更新段落 {paragraph_index} 的反思总结") | 353 | + logger.info(f"已更新段落 {paragraph_index} 的反思总结") |
| 353 | else: | 354 | else: |
| 354 | raise ValueError(f"段落索引 {paragraph_index} 超出范围") | 355 | raise ValueError(f"段落索引 {paragraph_index} 超出范围") |
| 355 | 356 | ||
| @@ -357,5 +358,5 @@ class ReflectionSummaryNode(StateMutationNode): | @@ -357,5 +358,5 @@ class ReflectionSummaryNode(StateMutationNode): | ||
| 357 | return state | 358 | return state |
| 358 | 359 | ||
| 359 | except Exception as e: | 360 | except Exception as e: |
| 360 | - self.log_error(f"状态更新失败: {str(e)}") | 361 | + logger.exception(f"状态更新失败: {str(e)}") |
| 361 | raise e | 362 | raise e |
| @@ -12,7 +12,8 @@ from dataclasses import dataclass | @@ -12,7 +12,8 @@ from dataclasses import dataclass | ||
| 12 | 12 | ||
| 13 | # 添加项目根目录到Python路径以导入config | 13 | # 添加项目根目录到Python路径以导入config |
| 14 | sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) | 14 | sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) |
| 15 | -from config import KEYWORD_OPTIMIZER_API_KEY, KEYWORD_OPTIMIZER_BASE_URL, KEYWORD_OPTIMIZER_MODEL_NAME | 15 | +from config import settings |
| 16 | +from loguru import logger | ||
| 16 | 17 | ||
| 17 | # 添加utils目录到Python路径 | 18 | # 添加utils目录到Python路径 |
| 18 | current_dir = os.path.dirname(os.path.abspath(__file__)) | 19 | current_dir = os.path.dirname(os.path.abspath(__file__)) |
| @@ -46,18 +47,18 @@ class KeywordOptimizer: | @@ -46,18 +47,18 @@ class KeywordOptimizer: | ||
| 46 | api_key: 硅基流动API密钥,如果不提供则从配置文件读取 | 47 | api_key: 硅基流动API密钥,如果不提供则从配置文件读取 |
| 47 | base_url: 接口基础地址,默认使用配置文件提供的SiliconFlow地址 | 48 | base_url: 接口基础地址,默认使用配置文件提供的SiliconFlow地址 |
| 48 | """ | 49 | """ |
| 49 | - self.api_key = api_key or KEYWORD_OPTIMIZER_API_KEY | 50 | + self.api_key = api_key or settings.KEYWORD_OPTIMIZER_API_KEY |
| 50 | 51 | ||
| 51 | if not self.api_key: | 52 | if not self.api_key: |
| 52 | raise ValueError("未找到硅基流动API密钥,请在config.py中设置KEYWORD_OPTIMIZER_API_KEY") | 53 | raise ValueError("未找到硅基流动API密钥,请在config.py中设置KEYWORD_OPTIMIZER_API_KEY") |
| 53 | 54 | ||
| 54 | - self.base_url = base_url or KEYWORD_OPTIMIZER_BASE_URL | 55 | + self.base_url = base_url or settings.KEYWORD_OPTIMIZER_BASE_URL |
| 55 | 56 | ||
| 56 | self.client = OpenAI( | 57 | self.client = OpenAI( |
| 57 | api_key=self.api_key, | 58 | api_key=self.api_key, |
| 58 | base_url=self.base_url | 59 | base_url=self.base_url |
| 59 | ) | 60 | ) |
| 60 | - self.model = model_name or KEYWORD_OPTIMIZER_MODEL_NAME | 61 | + self.model = model_name or settings.KEYWORD_OPTIMIZER_MODEL_NAME |
| 61 | 62 | ||
| 62 | def optimize_keywords(self, original_query: str, context: str = "") -> KeywordOptimizationResponse: | 63 | def optimize_keywords(self, original_query: str, context: str = "") -> KeywordOptimizationResponse: |
| 63 | """ | 64 | """ |
| @@ -70,7 +71,7 @@ class KeywordOptimizer: | @@ -70,7 +71,7 @@ class KeywordOptimizer: | ||
| 70 | Returns: | 71 | Returns: |
| 71 | KeywordOptimizationResponse: 优化后的关键词列表 | 72 | KeywordOptimizationResponse: 优化后的关键词列表 |
| 72 | """ | 73 | """ |
| 73 | - print(f"🔍 关键词优化中间件: 处理查询 '{original_query}'") | 74 | + logger.info(f"🔍 关键词优化中间件: 处理查询 '{original_query}'") |
| 74 | 75 | ||
| 75 | try: | 76 | try: |
| 76 | # 构建优化prompt | 77 | # 构建优化prompt |
| @@ -97,9 +98,13 @@ class KeywordOptimizer: | @@ -97,9 +98,13 @@ class KeywordOptimizer: | ||
| 97 | # 验证关键词质量 | 98 | # 验证关键词质量 |
| 98 | validated_keywords = self._validate_keywords(keywords) | 99 | validated_keywords = self._validate_keywords(keywords) |
| 99 | 100 | ||
| 100 | - print(f"✅ 优化成功: {len(validated_keywords)}个关键词") | ||
| 101 | - for i, keyword in enumerate(validated_keywords, 1): | ||
| 102 | - print(f" {i}. '{keyword}'") | 101 | + logger.info( |
| 102 | + f"✅ 优化成功: {len(validated_keywords)}个关键词" + | ||
| 103 | + ("" if not validated_keywords else "\n" + | ||
| 104 | + "\n".join([f" {i}. '{k}'" for i, k in enumerate(validated_keywords, 1)])) | ||
| 105 | + ) | ||
| 106 | + | ||
| 107 | + | ||
| 103 | 108 | ||
| 104 | return KeywordOptimizationResponse( | 109 | return KeywordOptimizationResponse( |
| 105 | original_query=original_query, | 110 | original_query=original_query, |
| @@ -109,7 +114,7 @@ class KeywordOptimizer: | @@ -109,7 +114,7 @@ class KeywordOptimizer: | ||
| 109 | ) | 114 | ) |
| 110 | 115 | ||
| 111 | except Exception as e: | 116 | except Exception as e: |
| 112 | - print(f"⚠️ 解析响应失败,使用备用方案: {str(e)}") | 117 | + logger.exception(f"⚠️ 解析响应失败,使用备用方案: {str(e)}") |
| 113 | # 备用方案:从原始查询中提取关键词 | 118 | # 备用方案:从原始查询中提取关键词 |
| 114 | fallback_keywords = self._fallback_keyword_extraction(original_query) | 119 | fallback_keywords = self._fallback_keyword_extraction(original_query) |
| 115 | return KeywordOptimizationResponse( | 120 | return KeywordOptimizationResponse( |
| @@ -119,7 +124,7 @@ class KeywordOptimizer: | @@ -119,7 +124,7 @@ class KeywordOptimizer: | ||
| 119 | success=True | 124 | success=True |
| 120 | ) | 125 | ) |
| 121 | else: | 126 | else: |
| 122 | - print(f"❌ API调用失败: {response['error']}") | 127 | + logger.error(f"❌ API调用失败: {response['error']}") |
| 123 | # 使用备用方案 | 128 | # 使用备用方案 |
| 124 | fallback_keywords = self._fallback_keyword_extraction(original_query) | 129 | fallback_keywords = self._fallback_keyword_extraction(original_query) |
| 125 | return KeywordOptimizationResponse( | 130 | return KeywordOptimizationResponse( |
| @@ -131,7 +136,7 @@ class KeywordOptimizer: | @@ -131,7 +136,7 @@ class KeywordOptimizer: | ||
| 131 | ) | 136 | ) |
| 132 | 137 | ||
| 133 | except Exception as e: | 138 | except Exception as e: |
| 134 | - print(f"❌ 关键词优化失败: {str(e)}") | 139 | + logger.error(f"❌ 关键词优化失败: {str(e)}") |
| 135 | # 最终备用方案 | 140 | # 最终备用方案 |
| 136 | fallback_keywords = self._fallback_keyword_extraction(original_query) | 141 | fallback_keywords = self._fallback_keyword_extraction(original_query) |
| 137 | return KeywordOptimizationResponse( | 142 | return KeywordOptimizationResponse( |
| @@ -25,10 +25,11 @@ V3.0 核心更新: | @@ -25,10 +25,11 @@ V3.0 核心更新: | ||
| 25 | 25 | ||
| 26 | import os | 26 | import os |
| 27 | import json | 27 | import json |
| 28 | -import pymysql | ||
| 29 | -import pymysql.cursors | 28 | +from loguru import logger |
| 29 | +import asyncio | ||
| 30 | from typing import List, Dict, Any, Optional, Literal | 30 | from typing import List, Dict, Any, Optional, Literal |
| 31 | from dataclasses import dataclass, field | 31 | from dataclasses import dataclass, field |
| 32 | +from ..utils.db import fetch_all | ||
| 32 | from datetime import datetime, timedelta, date | 33 | from datetime import datetime, timedelta, date |
| 33 | 34 | ||
| 34 | # --- 1. 数据结构定义 --- | 35 | # --- 1. 数据结构定义 --- |
| @@ -69,36 +70,28 @@ class MediaCrawlerDB: | @@ -69,36 +70,28 @@ class MediaCrawlerDB: | ||
| 69 | 70 | ||
| 70 | def __init__(self): | 71 | def __init__(self): |
| 71 | """ | 72 | """ |
| 72 | - 初始化客户端。连接信息从环境变量自动读取: | ||
| 73 | - - DB_HOST, DB_USER, DB_PASSWORD, DB_NAME | ||
| 74 | - - DB_PORT (可选, 默认 3306) | ||
| 75 | - - DB_CHARSET (可选, 默认 utf8mb4) | 73 | + 初始化客户端。 |
| 76 | """ | 74 | """ |
| 77 | - self.db_config = { | ||
| 78 | - 'host': os.getenv("DB_HOST"), | ||
| 79 | - 'user': os.getenv("DB_USER"), | ||
| 80 | - 'password': os.getenv("DB_PASSWORD"), | ||
| 81 | - 'db': os.getenv("DB_NAME"), | ||
| 82 | - 'port': int(os.getenv("DB_PORT", 3306)), | ||
| 83 | - 'charset': os.getenv("DB_CHARSET", "utf8mb4"), | ||
| 84 | - 'cursorclass': pymysql.cursors.DictCursor | ||
| 85 | - } | ||
| 86 | - required = ['host', 'user', 'password', 'db'] | ||
| 87 | - if missing := [k for k in required if not self.db_config[k]]: | ||
| 88 | - raise ValueError(f"数据库配置缺失! 请设置环境变量或在代码中提供: {', '.join([f'DB_{k.upper()}' for k in missing])}") | ||
| 89 | - | 75 | + pass |
| 76 | + | ||
| 90 | def _execute_query(self, query: str, params: tuple = None) -> List[Dict[str, Any]]: | 77 | def _execute_query(self, query: str, params: tuple = None) -> List[Dict[str, Any]]: |
| 91 | - conn = None | ||
| 92 | try: | 78 | try: |
| 93 | - conn = pymysql.connect(**self.db_config) | ||
| 94 | - with conn.cursor() as cursor: | ||
| 95 | - cursor.execute(query, params or ()) | ||
| 96 | - return cursor.fetchall() | ||
| 97 | - except pymysql.Error as e: | ||
| 98 | - print(f"数据库查询时发生错误: {e}") | 79 | + # 获取或创建event loop |
| 80 | + try: | ||
| 81 | + loop = asyncio.get_event_loop() | ||
| 82 | + if loop.is_closed(): | ||
| 83 | + loop = asyncio.new_event_loop() | ||
| 84 | + asyncio.set_event_loop(loop) | ||
| 85 | + except RuntimeError: | ||
| 86 | + loop = asyncio.new_event_loop() | ||
| 87 | + asyncio.set_event_loop(loop) | ||
| 88 | + | ||
| 89 | + # 直接运行协程 | ||
| 90 | + return loop.run_until_complete(fetch_all(query, params)) | ||
| 91 | + | ||
| 92 | + except Exception as e: | ||
| 93 | + logger.exception(f"数据库查询时发生错误: {e}") | ||
| 99 | return [] | 94 | return [] |
| 100 | - finally: | ||
| 101 | - if conn: conn.close() | ||
| 102 | 95 | ||
| 103 | @staticmethod | 96 | @staticmethod |
| 104 | def _to_datetime(ts: Any) -> Optional[datetime]: | 97 | def _to_datetime(ts: Any) -> Optional[datetime]: |
| @@ -149,7 +142,7 @@ class MediaCrawlerDB: | @@ -149,7 +142,7 @@ class MediaCrawlerDB: | ||
| 149 | DBResponse: 包含按综合热度排序后的内容列表。 | 142 | DBResponse: 包含按综合热度排序后的内容列表。 |
| 150 | """ | 143 | """ |
| 151 | params_for_log = {'time_period': time_period, 'limit': limit} | 144 | params_for_log = {'time_period': time_period, 'limit': limit} |
| 152 | - print(f"--- TOOL: 查找热点内容 (params: {params_for_log}) ---") | 145 | + logger.info(f"--- TOOL: 查找热点内容 (params: {params_for_log}) ---") |
| 153 | 146 | ||
| 154 | now = datetime.now() | 147 | now = datetime.now() |
| 155 | start_time = now - timedelta(days={'24h': 1, 'week': 7}.get(time_period, 365)) | 148 | start_time = now - timedelta(days={'24h': 1, 'week': 7}.get(time_period, 365)) |
| @@ -202,22 +195,28 @@ class MediaCrawlerDB: | @@ -202,22 +195,28 @@ class MediaCrawlerDB: | ||
| 202 | DBResponse: 包含所有匹配结果的聚合列表。 | 195 | DBResponse: 包含所有匹配结果的聚合列表。 |
| 203 | """ | 196 | """ |
| 204 | params_for_log = {'topic': topic, 'limit_per_table': limit_per_table} | 197 | params_for_log = {'topic': topic, 'limit_per_table': limit_per_table} |
| 205 | - print(f"--- TOOL: 全局话题搜索 (params: {params_for_log}) ---") | 198 | + logger.info(f"--- TOOL: 全局话题搜索 (params: {params_for_log}) ---") |
| 206 | 199 | ||
| 207 | search_term, all_results = f"%{topic}%", [] | 200 | search_term, all_results = f"%{topic}%", [] |
| 208 | search_configs = { 'bilibili_video': {'fields': ['title', 'desc', 'source_keyword'], 'type': 'video'}, 'bilibili_video_comment': {'fields': ['content'], 'type': 'comment'}, 'douyin_aweme': {'fields': ['title', 'desc', 'source_keyword'], 'type': 'video'}, 'douyin_aweme_comment': {'fields': ['content'], 'type': 'comment'}, 'kuaishou_video': {'fields': ['title', 'desc', 'source_keyword'], 'type': 'video'}, 'kuaishou_video_comment': {'fields': ['content'], 'type': 'comment'}, 'weibo_note': {'fields': ['content', 'source_keyword'], 'type': 'note'}, 'weibo_note_comment': {'fields': ['content'], 'type': 'comment'}, 'xhs_note': {'fields': ['title', 'desc', 'tag_list', 'source_keyword'], 'type': 'note'}, 'xhs_note_comment': {'fields': ['content'], 'type': 'comment'}, 'zhihu_content': {'fields': ['title', 'desc', 'content_text', 'source_keyword'], 'type': 'content'}, 'zhihu_comment': {'fields': ['content'], 'type': 'comment'}, 'tieba_note': {'fields': ['title', 'desc', 'source_keyword'], 'type': 'note'}, 'tieba_comment': {'fields': ['content'], 'type': 'comment'}, 'daily_news': {'fields': ['title'], 'type': 'news'}, } | 201 | search_configs = { 'bilibili_video': {'fields': ['title', 'desc', 'source_keyword'], 'type': 'video'}, 'bilibili_video_comment': {'fields': ['content'], 'type': 'comment'}, 'douyin_aweme': {'fields': ['title', 'desc', 'source_keyword'], 'type': 'video'}, 'douyin_aweme_comment': {'fields': ['content'], 'type': 'comment'}, 'kuaishou_video': {'fields': ['title', 'desc', 'source_keyword'], 'type': 'video'}, 'kuaishou_video_comment': {'fields': ['content'], 'type': 'comment'}, 'weibo_note': {'fields': ['content', 'source_keyword'], 'type': 'note'}, 'weibo_note_comment': {'fields': ['content'], 'type': 'comment'}, 'xhs_note': {'fields': ['title', 'desc', 'tag_list', 'source_keyword'], 'type': 'note'}, 'xhs_note_comment': {'fields': ['content'], 'type': 'comment'}, 'zhihu_content': {'fields': ['title', 'desc', 'content_text', 'source_keyword'], 'type': 
'content'}, 'zhihu_comment': {'fields': ['content'], 'type': 'comment'}, 'tieba_note': {'fields': ['title', 'desc', 'source_keyword'], 'type': 'note'}, 'tieba_comment': {'fields': ['content'], 'type': 'comment'}, 'daily_news': {'fields': ['title'], 'type': 'news'}, } |
| 209 | 202 | ||
| 210 | for table, config in search_configs.items(): | 203 | for table, config in search_configs.items(): |
| 211 | - where_clause = " OR ".join([f"`{field}` LIKE %s" for field in config['fields']]) | ||
| 212 | - query = f"SELECT * FROM `{table}` WHERE {where_clause} ORDER BY id DESC LIMIT %s" | ||
| 213 | - params = (search_term,) * len(config['fields']) + (limit_per_table,) | ||
| 214 | - raw_results = self._execute_query(query, params) | 204 | + param_dict = {} |
| 205 | + where_clauses = [] | ||
| 206 | + for idx, field in enumerate(config['fields']): | ||
| 207 | + pname = f"term_{idx}" | ||
| 208 | + where_clauses.append(f'"{field}" LIKE :{pname}') | ||
| 209 | + param_dict[pname] = search_term | ||
| 210 | + param_dict['limit'] = limit_per_table | ||
| 211 | + where_clause = " OR ".join(where_clauses) | ||
| 212 | + query = f'SELECT * FROM "{table}" WHERE {where_clause} ORDER BY id DESC LIMIT :limit' | ||
| 213 | + raw_results = self._execute_query(query, param_dict) | ||
| 215 | for row in raw_results: | 214 | for row in raw_results: |
| 216 | content = (row.get('title') or row.get('content') or row.get('desc') or row.get('content_text', '')) | 215 | content = (row.get('title') or row.get('content') or row.get('desc') or row.get('content_text', '')) |
| 217 | time_key = row.get('create_time') or row.get('time') or row.get('created_time') or row.get('publish_time') or row.get('crawl_date') | 216 | time_key = row.get('create_time') or row.get('time') or row.get('created_time') or row.get('publish_time') or row.get('crawl_date') |
| 218 | all_results.append(QueryResult( | 217 | all_results.append(QueryResult( |
| 219 | platform=table.split('_')[0], content_type=config['type'], | 218 | platform=table.split('_')[0], content_type=config['type'], |
| 220 | - title_or_content=content[:500] if content else '', | 219 | + title_or_content=content if content else '', |
| 221 | author_nickname=row.get('nickname') or row.get('user_nickname') or row.get('user_name'), | 220 | author_nickname=row.get('nickname') or row.get('user_nickname') or row.get('user_name'), |
| 222 | url=row.get('video_url') or row.get('note_url') or row.get('content_url') or row.get('url') or row.get('aweme_url'), | 221 | url=row.get('video_url') or row.get('note_url') or row.get('content_url') or row.get('url') or row.get('aweme_url'), |
| 223 | publish_time=self._to_datetime(time_key), | 222 | publish_time=self._to_datetime(time_key), |
| @@ -241,7 +240,7 @@ class MediaCrawlerDB: | @@ -241,7 +240,7 @@ class MediaCrawlerDB: | ||
| 241 | DBResponse: 包含在指定日期范围内找到的结果的聚合列表。 | 240 | DBResponse: 包含在指定日期范围内找到的结果的聚合列表。 |
| 242 | """ | 241 | """ |
| 243 | params_for_log = {'topic': topic, 'start_date': start_date, 'end_date': end_date, 'limit_per_table': limit_per_table} | 242 | params_for_log = {'topic': topic, 'start_date': start_date, 'end_date': end_date, 'limit_per_table': limit_per_table} |
| 244 | - print(f"--- TOOL: 按日期搜索话题 (params: {params_for_log}) ---") | 243 | + logger.info(f"--- TOOL: 按日期搜索话题 (params: {params_for_log}) ---") |
| 245 | 244 | ||
| 246 | try: | 245 | try: |
| 247 | start_dt, end_dt = datetime.strptime(start_date, '%Y-%m-%d'), datetime.strptime(end_date, '%Y-%m-%d') + timedelta(days=1) | 246 | start_dt, end_dt = datetime.strptime(start_date, '%Y-%m-%d'), datetime.strptime(end_date, '%Y-%m-%d') + timedelta(days=1) |
| @@ -257,25 +256,25 @@ class MediaCrawlerDB: | @@ -257,25 +256,25 @@ class MediaCrawlerDB: | ||
| 257 | } | 256 | } |
| 258 | 257 | ||
| 259 | for table, config in search_configs.items(): | 258 | for table, config in search_configs.items(): |
| 260 | - topic_clause = " OR ".join([f"`{field}` LIKE %s" for field in config['fields']]) | ||
| 261 | - time_col, time_type = config['time_col'], config['time_type'] | ||
| 262 | - if time_type == 'sec': time_params = (int(start_dt.timestamp()), int(end_dt.timestamp())) | ||
| 263 | - elif time_type == 'ms': time_params = (int(start_dt.timestamp() * 1000), int(end_dt.timestamp() * 1000)) | ||
| 264 | - elif time_type in ['str', 'date_str']: time_params = (start_dt.strftime('%Y-%m-%d'), end_dt.strftime('%Y-%m-%d')) | ||
| 265 | - else: time_params = (str(int(start_dt.timestamp())), str(int(end_dt.timestamp()))) | ||
| 266 | - time_clause = f"`{time_col}` >= %s AND `{time_col}` < %s" | ||
| 267 | - if table == 'zhihu_content': time_clause = f"CAST(`{time_col}` AS UNSIGNED) >= %s AND CAST(`{time_col}` AS UNSIGNED) < %s" | ||
| 268 | - query = f"SELECT * FROM `{table}` WHERE ({topic_clause}) AND ({time_clause}) ORDER BY id DESC LIMIT %s" | ||
| 269 | - params = (search_term,) * len(config['fields']) + time_params + (limit_per_table,) | ||
| 270 | - raw_results = self._execute_query(query, params) | 259 | + param_dict = {} |
| 260 | + where_clauses = [] | ||
| 261 | + for idx, field in enumerate(config['fields']): | ||
| 262 | + pname = f"term_{idx}" | ||
| 263 | + where_clauses.append(f'"{field}" LIKE :{pname}') | ||
| 264 | + param_dict[pname] = search_term | ||
| 265 | + param_dict['limit'] = limit_per_table | ||
| 266 | + where_clause = ' OR '.join(where_clauses) | ||
| 267 | + query = f'SELECT * FROM "{table}" WHERE {where_clause} ORDER BY id DESC LIMIT :limit' | ||
| 268 | + raw_results = self._execute_query(query, param_dict) | ||
| 271 | for row in raw_results: | 269 | for row in raw_results: |
| 272 | content = (row.get('title') or row.get('content') or row.get('desc') or row.get('content_text', '')) | 270 | content = (row.get('title') or row.get('content') or row.get('desc') or row.get('content_text', '')) |
| 271 | + time_key = row.get('create_time') or row.get('time') or row.get('created_time') or row.get('publish_time') or row.get('crawl_date') | ||
| 273 | all_results.append(QueryResult( | 272 | all_results.append(QueryResult( |
| 274 | platform=table.split('_')[0], content_type=config['type'], | 273 | platform=table.split('_')[0], content_type=config['type'], |
| 275 | - title_or_content=content[:500] if content else '', | ||
| 276 | - author_nickname=row.get('nickname') or row.get('user_nickname'), | 274 | + title_or_content=content if content else '', |
| 275 | + author_nickname=row.get('nickname') or row.get('user_nickname') or row.get('user_name'), | ||
| 277 | url=row.get('video_url') or row.get('note_url') or row.get('content_url') or row.get('url') or row.get('aweme_url'), | 276 | url=row.get('video_url') or row.get('note_url') or row.get('content_url') or row.get('url') or row.get('aweme_url'), |
| 278 | - publish_time=self._to_datetime(row.get(config['time_col'])), | 277 | + publish_time=self._to_datetime(time_key), |
| 279 | engagement=self._extract_engagement(row), | 278 | engagement=self._extract_engagement(row), |
| 280 | source_keyword=row.get('source_keyword'), | 279 | source_keyword=row.get('source_keyword'), |
| 281 | source_table=table | 280 | source_table=table |
| @@ -294,7 +293,7 @@ class MediaCrawlerDB: | @@ -294,7 +293,7 @@ class MediaCrawlerDB: | ||
| 294 | DBResponse: 包含匹配的评论列表。 | 293 | DBResponse: 包含匹配的评论列表。 |
| 295 | """ | 294 | """ |
| 296 | params_for_log = {'topic': topic, 'limit': limit} | 295 | params_for_log = {'topic': topic, 'limit': limit} |
| 297 | - print(f"--- TOOL: 获取话题评论 (params: {params_for_log}) ---") | 296 | + logger.info(f"--- TOOL: 获取话题评论 (params: {params_for_log}) ---") |
| 298 | 297 | ||
| 299 | search_term = f"%{topic}%" | 298 | search_term = f"%{topic}%" |
| 300 | comment_tables = ['bilibili_video_comment', 'douyin_aweme_comment', 'kuaishou_video_comment', 'weibo_note_comment', 'xhs_note_comment', 'zhihu_comment', 'tieba_comment'] | 299 | comment_tables = ['bilibili_video_comment', 'douyin_aweme_comment', 'kuaishou_video_comment', 'weibo_note_comment', 'xhs_note_comment', 'zhihu_comment', 'tieba_comment'] |
| @@ -341,7 +340,7 @@ class MediaCrawlerDB: | @@ -341,7 +340,7 @@ class MediaCrawlerDB: | ||
| 341 | DBResponse: 包含在该平台找到的结果列表。 | 340 | DBResponse: 包含在该平台找到的结果列表。 |
| 342 | """ | 341 | """ |
| 343 | params_for_log = {'platform': platform, 'topic': topic, 'start_date': start_date, 'end_date': end_date, 'limit': limit} | 342 | params_for_log = {'platform': platform, 'topic': topic, 'start_date': start_date, 'end_date': end_date, 'limit': limit} |
| 344 | - print(f"--- TOOL: 平台定向搜索 (params: {params_for_log}) ---") | 343 | + logger.info(f"--- TOOL: 平台定向搜索 (params: {params_for_log}) ---") |
| 345 | 344 | ||
| 346 | all_configs = { 'bilibili': [{'table': 'bilibili_video', 'fields': ['title', 'desc', 'source_keyword'], 'type': 'video', 'time_col': 'create_time', 'time_type': 'sec'}, {'table': 'bilibili_video_comment', 'fields': ['content'], 'type': 'comment'}], 'douyin': [{'table': 'douyin_aweme', 'fields': ['title', 'desc', 'source_keyword'], 'type': 'video', 'time_col': 'create_time', 'time_type': 'ms'}, {'table': 'douyin_aweme_comment', 'fields': ['content'], 'type': 'comment'}], 'kuaishou': [{'table': 'kuaishou_video', 'fields': ['title', 'desc', 'source_keyword'], 'type': 'video', 'time_col': 'create_time', 'time_type': 'ms'}, {'table': 'kuaishou_video_comment', 'fields': ['content'], 'type': 'comment'}], 'weibo': [{'table': 'weibo_note', 'fields': ['content', 'source_keyword'], 'type': 'note', 'time_col': 'create_date_time', 'time_type': 'str'}, {'table': 'weibo_note_comment', 'fields': ['content'], 'type': 'comment'}], 'xhs': [{'table': 'xhs_note', 'fields': ['title', 'desc', 'tag_list', 'source_keyword'], 'type': 'note', 'time_col': 'time', 'time_type': 'ms'}, {'table': 'xhs_note_comment', 'fields': ['content'], 'type': 'comment'}], 'zhihu': [{'table': 'zhihu_content', 'fields': ['title', 'desc', 'content_text', 'source_keyword'], 'type': 'content', 'time_col': 'created_time', 'time_type': 'sec_str'}, {'table': 'zhihu_comment', 'fields': ['content'], 'type': 'comment'}], 'tieba': [{'table': 'tieba_note', 'fields': ['title', 'desc', 'source_keyword'], 'type': 'note', 'time_col': 'publish_time', 'time_type': 'str'}, {'table': 'tieba_comment', 'fields': ['content'], 'type': 'comment'}] } | 345 | all_configs = { 'bilibili': [{'table': 'bilibili_video', 'fields': ['title', 'desc', 'source_keyword'], 'type': 'video', 'time_col': 'create_time', 'time_type': 'sec'}, {'table': 'bilibili_video_comment', 'fields': ['content'], 'type': 'comment'}], 'douyin': [{'table': 'douyin_aweme', 'fields': ['title', 'desc', 'source_keyword'], 'type': 'video', 'time_col': 'create_time', 
'time_type': 'ms'}, {'table': 'douyin_aweme_comment', 'fields': ['content'], 'type': 'comment'}], 'kuaishou': [{'table': 'kuaishou_video', 'fields': ['title', 'desc', 'source_keyword'], 'type': 'video', 'time_col': 'create_time', 'time_type': 'ms'}, {'table': 'kuaishou_video_comment', 'fields': ['content'], 'type': 'comment'}], 'weibo': [{'table': 'weibo_note', 'fields': ['content', 'source_keyword'], 'type': 'note', 'time_col': 'create_date_time', 'time_type': 'str'}, {'table': 'weibo_note_comment', 'fields': ['content'], 'type': 'comment'}], 'xhs': [{'table': 'xhs_note', 'fields': ['title', 'desc', 'tag_list', 'source_keyword'], 'type': 'note', 'time_col': 'time', 'time_type': 'ms'}, {'table': 'xhs_note_comment', 'fields': ['content'], 'type': 'comment'}], 'zhihu': [{'table': 'zhihu_content', 'fields': ['title', 'desc', 'content_text', 'source_keyword'], 'type': 'content', 'time_col': 'created_time', 'time_type': 'sec_str'}, {'table': 'zhihu_comment', 'fields': ['content'], 'type': 'comment'}], 'tieba': [{'table': 'tieba_note', 'fields': ['title', 'desc', 'source_keyword'], 'type': 'note', 'time_col': 'publish_time', 'time_type': 'str'}, {'table': 'tieba_comment', 'fields': ['content'], 'type': 'comment'}] } |
| 347 | 346 | ||
| @@ -386,7 +385,7 @@ class MediaCrawlerDB: | @@ -386,7 +385,7 @@ class MediaCrawlerDB: | ||
| 386 | for row in raw_results: | 385 | for row in raw_results: |
| 387 | content = (row.get('title') or row.get('content') or row.get('desc') or row.get('content_text', '')) | 386 | content = (row.get('title') or row.get('content') or row.get('desc') or row.get('content_text', '')) |
| 388 | time_key = config.get('time_col') and row.get(config.get('time_col')) | 387 | time_key = config.get('time_col') and row.get(config.get('time_col')) |
| 389 | - all_results.append(QueryResult(platform=platform, content_type=config['type'], title_or_content=content[:500] if content else '', author_nickname=row.get('nickname') or row.get('user_nickname'), url=row.get('video_url') or row.get('note_url') or row.get('content_url') or row.get('url') or row.get('aweme_url'), publish_time=self._to_datetime(time_key), engagement=self._extract_engagement(row), source_keyword=row.get('source_keyword'), source_table=table)) | 388 | + all_results.append(QueryResult(platform=platform, content_type=config['type'], title_or_content=content if content else '', author_nickname=row.get('nickname') or row.get('user_nickname'), url=row.get('video_url') or row.get('note_url') or row.get('content_url') or row.get('url') or row.get('aweme_url'), publish_time=self._to_datetime(time_key), engagement=self._extract_engagement(row), source_keyword=row.get('source_keyword'), source_table=table)) |
| 390 | 389 | ||
| 391 | return DBResponse("search_topic_on_platform", params_for_log, results=all_results, results_count=len(all_results)) | 390 | return DBResponse("search_topic_on_platform", params_for_log, results=all_results, results_count=len(all_results)) |
| 392 | 391 | ||
| @@ -394,33 +393,41 @@ class MediaCrawlerDB: | @@ -394,33 +393,41 @@ class MediaCrawlerDB: | ||
| 394 | def print_response_summary(response: DBResponse): | 393 | def print_response_summary(response: DBResponse): |
| 395 | """简化的打印函数,用于展示测试结果""" | 394 | """简化的打印函数,用于展示测试结果""" |
| 396 | if response.error_message: | 395 | if response.error_message: |
| 397 | - print(f"工具 '{response.tool_name}' 执行出错: {response.error_message}") | ||
| 398 | - print("-" * 80) | 396 | + logger.info(f"工具 '{response.tool_name}' 执行出错: {response.error_message}") |
| 399 | return | 397 | return |
| 400 | 398 | ||
| 401 | params_str = ", ".join(f"{k}='{v}'" for k, v in response.parameters.items()) | 399 | params_str = ", ".join(f"{k}='{v}'" for k, v in response.parameters.items()) |
| 402 | - print(f"查询: 工具='{response.tool_name}', 参数=[{params_str}]") | ||
| 403 | - print(f"找到 {response.results_count} 条相关记录。") | 400 | + logger.info(f"查询: 工具='{response.tool_name}', 参数=[{params_str}]") |
| 401 | + logger.info(f"找到 {response.results_count} 条相关记录。") | ||
| 404 | 402 | ||
| 405 | - if response.results: | ||
| 406 | - print("--- 前5条结果示例 ---") | ||
| 407 | - for i, res in enumerate(response.results[:5]): | ||
| 408 | - engagement_str = ", ".join(f"{k}: {v}" for k, v in res.engagement.items() if v) | ||
| 409 | - content_preview = (res.title_or_content.replace('\n', ' ')[:70] + '...') if res.title_or_content and len(res.title_or_content) > 70 else res.title_or_content | ||
| 410 | - hotness_str = f", hotness: {res.hotness_score:.2f}" if res.hotness_score > 0 else "" | ||
| 411 | - print( | ||
| 412 | - f"{i+1}. [{res.platform.upper()}/{res.content_type}] {content_preview}\n" | ||
| 413 | - f" by: {res.author_nickname}, at: {res.publish_time.strftime('%Y-%m-%d %H:%M') if res.publish_time else 'N/A'}" | ||
| 414 | - f", src_kw: '{res.source_keyword or 'N/A'}'{hotness_str}" | ||
| 415 | - f", engagement: {{{engagement_str}}}" | 403 | + # 统一为一个消息输出 |
| 404 | + output_lines = [] | ||
| 405 | + output_lines.append("==== 查询结果预览(最多前5条) ====") | ||
| 406 | + if response.results and len(response.results) > 0: | ||
| 407 | + for idx, res in enumerate(response.results[:5], 1): | ||
| 408 | + content_preview = (res.title_or_content.replace('\n', ' ')[:70] + '...') if res.title_or_content and len(res.title_or_content) > 70 else (res.title_or_content or '') | ||
| 409 | + author_str = res.author_nickname or "N/A" | ||
| 410 | + publish_time_str = res.publish_time.strftime('%Y-%m-%d %H:%M') if res.publish_time else "N/A" | ||
| 411 | + hotness_str = f", hotness: {res.hotness_score:.2f}" if getattr(res, "hotness_score", 0) > 0 else "" | ||
| 412 | + engagement_dict = getattr(res, "engagement", {}) or {} | ||
| 413 | + engagement_str = ", ".join(f"{k}: {v}" for k, v in engagement_dict.items() if v) | ||
| 414 | + output_lines.append( | ||
| 415 | + f"{idx}. [{res.platform.upper()}/{res.content_type}] {content_preview}\n" | ||
| 416 | + f" 作者: {author_str} | 时间: {publish_time_str}" | ||
| 417 | + f"{hotness_str} | 源关键词: '{res.source_keyword or 'N/A'}'\n" | ||
| 418 | + f" 链接: {res.url or 'N/A'}\n" | ||
| 419 | + f" 互动数据: {{{engagement_str}}}" | ||
| 416 | ) | 420 | ) |
| 417 | - print("-" * 80) | 421 | + else: |
| 422 | + output_lines.append("暂无相关内容。") | ||
| 423 | + output_lines.append("=" * 60) | ||
| 424 | + logger.info('\n'.join(output_lines)) | ||
| 418 | 425 | ||
| 419 | if __name__ == "__main__": | 426 | if __name__ == "__main__": |
| 420 | 427 | ||
| 421 | try: | 428 | try: |
| 422 | db_agent_tools = MediaCrawlerDB() | 429 | db_agent_tools = MediaCrawlerDB() |
| 423 | - print("数据库工具初始化成功,开始执行测试场景...\n") | 430 | + logger.info("数据库工具初始化成功,开始执行测试场景...\n") |
| 424 | 431 | ||
| 425 | # 场景1: (新) 查找过去一周综合热度最高的内容 (不再需要sort_by) | 432 | # 场景1: (新) 查找过去一周综合热度最高的内容 (不再需要sort_by) |
| 426 | response1 = db_agent_tools.search_hot_content(time_period='week', limit=5) | 433 | response1 = db_agent_tools.search_hot_content(time_period='week', limit=5) |
| @@ -443,7 +450,7 @@ if __name__ == "__main__": | @@ -443,7 +450,7 @@ if __name__ == "__main__": | ||
| 443 | print_response_summary(response5) | 450 | print_response_summary(response5) |
| 444 | 451 | ||
| 445 | except ValueError as e: | 452 | except ValueError as e: |
| 446 | - print(f"初始化失败: {e}") | ||
| 447 | - print("请确保相关的数据库环境变量已正确设置, 或在代码中直接提供连接信息。") | 453 | + logger.exception(f"初始化失败: {e}") |
| 454 | + logger.exception("请确保相关的数据库环境变量已正确设置, 或在代码中直接提供连接信息。") | ||
| 448 | except Exception as e: | 455 | except Exception as e: |
| 449 | - print(f"测试过程中发生未知错误: {e}") | ||
| 456 | + logger.exception(f"测试过程中发生未知错误: {e}") |
| @@ -12,8 +12,6 @@ from .text_processing import ( | @@ -12,8 +12,6 @@ from .text_processing import ( | ||
| 12 | format_search_results_for_prompt | 12 | format_search_results_for_prompt |
| 13 | ) | 13 | ) |
| 14 | 14 | ||
| 15 | -from .config import Config, load_config | ||
| 16 | - | ||
| 17 | __all__ = [ | 15 | __all__ = [ |
| 18 | "clean_json_tags", | 16 | "clean_json_tags", |
| 19 | "clean_markdown_tags", | 17 | "clean_markdown_tags", |
| @@ -21,6 +19,4 @@ __all__ = [ | @@ -21,6 +19,4 @@ __all__ = [ | ||
| 21 | "extract_clean_response", | 19 | "extract_clean_response", |
| 22 | "update_state_with_search_results", | 20 | "update_state_with_search_results", |
| 23 | "format_search_results_for_prompt", | 21 | "format_search_results_for_prompt", |
| 24 | - "Config", | ||
| 25 | - "load_config" | ||
| 26 | ] | 22 | ] |
| @@ -6,218 +6,40 @@ Handles environment variables and config file parameters. | @@ -6,218 +6,40 @@ Handles environment variables and config file parameters. | ||
| 6 | import os | 6 | import os |
| 7 | from dataclasses import dataclass | 7 | from dataclasses import dataclass |
| 8 | from typing import Optional | 8 | from typing import Optional |
| 9 | - | ||
| 10 | - | ||
| 11 | -def _get_value(source, key: str, default=None): | ||
| 12 | - """ | ||
| 13 | - Helper to fetch a configuration value with environment fallback. | ||
| 14 | - """ | ||
| 15 | - value = None | ||
| 16 | - if isinstance(source, dict): | ||
| 17 | - value = source.get(key) | ||
| 18 | - else: | ||
| 19 | - value = getattr(source, key, None) | ||
| 20 | - | ||
| 21 | - if value is None: | ||
| 22 | - value = os.getenv(key, default) | ||
| 23 | - return value if value not in ("", None) else default | ||
| 24 | - | ||
| 25 | - | ||
| 26 | -@dataclass | ||
| 27 | -class Config: | ||
| 28 | - """Insight Engine configuration.""" | ||
| 29 | - | ||
| 30 | - # LLM configuration | ||
| 31 | - llm_api_key: Optional[str] = None | ||
| 32 | - llm_base_url: Optional[str] = None | ||
| 33 | - llm_model_name: Optional[str] = None | ||
| 34 | - llm_provider: Optional[str] = None # kept for backward compatibility | ||
| 35 | - | ||
| 36 | - # Database configuration | ||
| 37 | - db_host: Optional[str] = None | ||
| 38 | - db_user: Optional[str] = None | ||
| 39 | - db_password: Optional[str] = None | ||
| 40 | - db_name: Optional[str] = None | ||
| 41 | - db_port: int = 3306 | ||
| 42 | - db_charset: str = "utf8mb4" | ||
| 43 | - | ||
| 44 | - # Model behaviour configuration | ||
| 45 | - max_reflections: int = 3 | ||
| 46 | - max_paragraphs: int = 6 | ||
| 47 | - search_timeout: int = 240 | ||
| 48 | - max_content_length: int = 500000 | ||
| 49 | - | ||
| 50 | - # Search result limits | ||
| 51 | - default_search_hot_content_limit: int = 100 | ||
| 52 | - default_search_topic_globally_limit_per_table: int = 50 | ||
| 53 | - default_search_topic_by_date_limit_per_table: int = 100 | ||
| 54 | - default_get_comments_for_topic_limit: int = 500 | ||
| 55 | - default_search_topic_on_platform_limit: int = 200 | ||
| 56 | - max_search_results_for_llm: int = 0 | ||
| 57 | - max_high_confidence_sentiment_results: int = 0 | ||
| 58 | - | ||
| 59 | - # Output configuration | ||
| 60 | - output_dir: str = "reports" | ||
| 61 | - save_intermediate_states: bool = True | ||
| 62 | - | ||
| 63 | - def __post_init__(self): | ||
| 64 | - if not self.llm_provider and self.llm_model_name: | ||
| 65 | - # Provider is no longer used, but keep the attribute for compatibility. | ||
| 66 | - self.llm_provider = self.llm_model_name | ||
| 67 | - | ||
| 68 | - def validate(self) -> bool: | ||
| 69 | - """Validate configuration.""" | ||
| 70 | - if not self.llm_api_key: | ||
| 71 | - print("错误: Insight Engine LLM API Key 未设置 (INSIGHT_ENGINE_API_KEY)。") | ||
| 72 | - return False | ||
| 73 | - | ||
| 74 | - if not self.llm_model_name: | ||
| 75 | - print("错误: Insight Engine 模型名称未设置 (INSIGHT_ENGINE_MODEL_NAME)。") | ||
| 76 | - return False | ||
| 77 | - | ||
| 78 | - if not all([self.db_host, self.db_user, self.db_password, self.db_name]): | ||
| 79 | - print("错误: 数据库连接信息不完整,请检查 config.py 中的 DB_* 配置。") | ||
| 80 | - return False | ||
| 81 | - | ||
| 82 | - return True | ||
| 83 | - | ||
| 84 | - @classmethod | ||
| 85 | - def from_file(cls, config_file: str) -> "Config": | ||
| 86 | - """Create configuration from file.""" | ||
| 87 | - if config_file.endswith(".py"): | ||
| 88 | - import importlib.util | ||
| 89 | - | ||
| 90 | - spec = importlib.util.spec_from_file_location("config", config_file) | ||
| 91 | - config_module = importlib.util.module_from_spec(spec) | ||
| 92 | - spec.loader.exec_module(config_module) | ||
| 93 | - | ||
| 94 | - return cls( | ||
| 95 | - llm_api_key=_get_value(config_module, "INSIGHT_ENGINE_API_KEY"), | ||
| 96 | - llm_base_url=_get_value(config_module, "INSIGHT_ENGINE_BASE_URL"), | ||
| 97 | - llm_model_name=_get_value(config_module, "INSIGHT_ENGINE_MODEL_NAME"), | ||
| 98 | - db_host=_get_value(config_module, "DB_HOST"), | ||
| 99 | - db_user=_get_value(config_module, "DB_USER"), | ||
| 100 | - db_password=_get_value(config_module, "DB_PASSWORD"), | ||
| 101 | - db_name=_get_value(config_module, "DB_NAME"), | ||
| 102 | - db_port=int(_get_value(config_module, "DB_PORT", 3306)), | ||
| 103 | - db_charset=_get_value(config_module, "DB_CHARSET", "utf8mb4"), | ||
| 104 | - max_reflections=int(_get_value(config_module, "MAX_REFLECTIONS", 3)), | ||
| 105 | - max_paragraphs=int(_get_value(config_module, "MAX_PARAGRAPHS", 6)), | ||
| 106 | - search_timeout=int(_get_value(config_module, "SEARCH_TIMEOUT", 240)), | ||
| 107 | - max_content_length=int(_get_value(config_module, "SEARCH_CONTENT_MAX_LENGTH", 500000)), | ||
| 108 | - default_search_hot_content_limit=int( | ||
| 109 | - _get_value(config_module, "DEFAULT_SEARCH_HOT_CONTENT_LIMIT", 100) | ||
| 110 | - ), | ||
| 111 | - default_search_topic_globally_limit_per_table=int( | ||
| 112 | - _get_value(config_module, "DEFAULT_SEARCH_TOPIC_GLOBALLY_LIMIT_PER_TABLE", 50) | ||
| 113 | - ), | ||
| 114 | - default_search_topic_by_date_limit_per_table=int( | ||
| 115 | - _get_value(config_module, "DEFAULT_SEARCH_TOPIC_BY_DATE_LIMIT_PER_TABLE", 100) | ||
| 116 | - ), | ||
| 117 | - default_get_comments_for_topic_limit=int( | ||
| 118 | - _get_value(config_module, "DEFAULT_GET_COMMENTS_FOR_TOPIC_LIMIT", 500) | ||
| 119 | - ), | ||
| 120 | - default_search_topic_on_platform_limit=int( | ||
| 121 | - _get_value(config_module, "DEFAULT_SEARCH_TOPIC_ON_PLATFORM_LIMIT", 200) | ||
| 122 | - ), | ||
| 123 | - max_search_results_for_llm=int(_get_value(config_module, "MAX_SEARCH_RESULTS_FOR_LLM", 0)), | ||
| 124 | - max_high_confidence_sentiment_results=int( | ||
| 125 | - _get_value(config_module, "MAX_HIGH_CONFIDENCE_SENTIMENT_RESULTS", 0) | ||
| 126 | - ), | ||
| 127 | - output_dir=_get_value(config_module, "OUTPUT_DIR", "reports"), | ||
| 128 | - save_intermediate_states=str( | ||
| 129 | - _get_value(config_module, "SAVE_INTERMEDIATE_STATES", "true") | ||
| 130 | - ).lower() | ||
| 131 | - in ("true", "1", "yes"), | ||
| 132 | - ) | ||
| 133 | - | ||
| 134 | - # .env style configuration | ||
| 135 | - config_dict = {} | ||
| 136 | - if os.path.exists(config_file): | ||
| 137 | - with open(config_file, "r", encoding="utf-8") as f: | ||
| 138 | - for line in f: | ||
| 139 | - line = line.strip() | ||
| 140 | - if line and not line.startswith("#") and "=" in line: | ||
| 141 | - key, value = line.split("=", 1) | ||
| 142 | - config_dict[key.strip()] = value.strip() | ||
| 143 | - | ||
| 144 | - return cls( | ||
| 145 | - llm_api_key=_get_value(config_dict, "INSIGHT_ENGINE_API_KEY"), | ||
| 146 | - llm_base_url=_get_value(config_dict, "INSIGHT_ENGINE_BASE_URL"), | ||
| 147 | - llm_model_name=_get_value(config_dict, "INSIGHT_ENGINE_MODEL_NAME"), | ||
| 148 | - db_host=_get_value(config_dict, "DB_HOST"), | ||
| 149 | - db_user=_get_value(config_dict, "DB_USER"), | ||
| 150 | - db_password=_get_value(config_dict, "DB_PASSWORD"), | ||
| 151 | - db_name=_get_value(config_dict, "DB_NAME"), | ||
| 152 | - db_port=int(_get_value(config_dict, "DB_PORT", 3306)), | ||
| 153 | - db_charset=_get_value(config_dict, "DB_CHARSET", "utf8mb4"), | ||
| 154 | - max_reflections=int(_get_value(config_dict, "MAX_REFLECTIONS", 3)), | ||
| 155 | - max_paragraphs=int(_get_value(config_dict, "MAX_PARAGRAPHS", 6)), | ||
| 156 | - search_timeout=int(_get_value(config_dict, "SEARCH_TIMEOUT", 240)), | ||
| 157 | - max_content_length=int(_get_value(config_dict, "SEARCH_CONTENT_MAX_LENGTH", 500000)), | ||
| 158 | - default_search_hot_content_limit=int( | ||
| 159 | - _get_value(config_dict, "DEFAULT_SEARCH_HOT_CONTENT_LIMIT", 100) | ||
| 160 | - ), | ||
| 161 | - default_search_topic_globally_limit_per_table=int( | ||
| 162 | - _get_value(config_dict, "DEFAULT_SEARCH_TOPIC_GLOBALLY_LIMIT_PER_TABLE", 50) | ||
| 163 | - ), | ||
| 164 | - default_search_topic_by_date_limit_per_table=int( | ||
| 165 | - _get_value(config_dict, "DEFAULT_SEARCH_TOPIC_BY_DATE_LIMIT_PER_TABLE", 100) | ||
| 166 | - ), | ||
| 167 | - default_get_comments_for_topic_limit=int( | ||
| 168 | - _get_value(config_dict, "DEFAULT_GET_COMMENTS_FOR_TOPIC_LIMIT", 500) | ||
| 169 | - ), | ||
| 170 | - default_search_topic_on_platform_limit=int( | ||
| 171 | - _get_value(config_dict, "DEFAULT_SEARCH_TOPIC_ON_PLATFORM_LIMIT", 200) | ||
| 172 | - ), | ||
| 173 | - max_search_results_for_llm=int(_get_value(config_dict, "MAX_SEARCH_RESULTS_FOR_LLM", 0)), | ||
| 174 | - max_high_confidence_sentiment_results=int( | ||
| 175 | - _get_value(config_dict, "MAX_HIGH_CONFIDENCE_SENTIMENT_RESULTS", 0) | ||
| 176 | - ), | ||
| 177 | - output_dir=_get_value(config_dict, "OUTPUT_DIR", "reports"), | ||
| 178 | - save_intermediate_states=str( | ||
| 179 | - _get_value(config_dict, "SAVE_INTERMEDIATE_STATES", "true") | ||
| 180 | - ).lower() | ||
| 181 | - in ("true", "1", "yes"), | ||
| 182 | - ) | ||
| 183 | - | ||
| 184 | - | ||
| 185 | -def load_config(config_file: Optional[str] = None) -> Config: | ||
| 186 | - """ | ||
| 187 | - Load configuration. | ||
| 188 | - """ | ||
| 189 | - if config_file: | ||
| 190 | - if not os.path.exists(config_file): | ||
| 191 | - raise FileNotFoundError(f"配置文件不存在: {config_file}") | ||
| 192 | - file_to_load = config_file | ||
| 193 | - else: | ||
| 194 | - for candidate in ("config.py", "config.env", ".env"): | ||
| 195 | - if os.path.exists(candidate): | ||
| 196 | - file_to_load = candidate | ||
| 197 | - print(f"已找到配置文件: {candidate}") | ||
| 198 | - break | ||
| 199 | - else: | ||
| 200 | - raise FileNotFoundError("未找到配置文件,请创建 config.py。") | ||
| 201 | - | ||
| 202 | - config = Config.from_file(file_to_load) | ||
| 203 | - | ||
| 204 | - if not config.validate(): | ||
| 205 | - raise ValueError("配置校验失败,请检查 config.py 中的相关配置。") | ||
| 206 | - | ||
| 207 | - return config | ||
| 208 | - | ||
| 209 | - | ||
| 210 | -def print_config(config: Config): | ||
| 211 | - """Print configuration (sensitive values masked).""" | ||
| 212 | - print("\n=== Insight Engine 配置 ===") | ||
| 213 | - print(f"LLM 模型: {config.llm_model_name}") | ||
| 214 | - print(f"LLM Base URL: {config.llm_base_url or '(默认)'}") | ||
| 215 | - print(f"搜索超时: {config.search_timeout} 秒") | ||
| 216 | - print(f"最长内容长度: {config.max_content_length}") | ||
| 217 | - print(f"最大反思次数: {config.max_reflections}") | ||
| 218 | - print(f"最大段落数: {config.max_paragraphs}") | ||
| 219 | - print(f"输出目录: {config.output_dir}") | ||
| 220 | - print(f"保存中间状态: {config.save_intermediate_states}") | ||
| 221 | - print(f"LLM API Key: {'已配置' if config.llm_api_key else '未配置'}") | ||
| 222 | - print(f"数据库连接: {'已配置' if all([config.db_host, config.db_user, config.db_password, config.db_name]) else '未配置'}") | ||
| 223 | - print("========================\n") | 9 | +from pydantic_settings import BaseSettings |
| 10 | +from pydantic import Field | ||
| 11 | +from loguru import logger | ||
| 12 | + | ||
| 13 | +class Settings(BaseSettings): | ||
| 14 | + INSIGHT_ENGINE_API_KEY: Optional[str] = Field(None, description="Insight Engine LLM API密钥") | ||
| 15 | + INSIGHT_ENGINE_BASE_URL: Optional[str] = Field(None, description="Insight Engine LLM base url,可选") | ||
| 16 | + INSIGHT_ENGINE_MODEL_NAME: Optional[str] = Field(None, description="Insight Engine LLM模型名称") | ||
| 17 | + INSIGHT_ENGINE_PROVIDER: Optional[str] = Field(None, description="Insight Engine模型提供者,不再建议使用") | ||
| 18 | + DB_HOST: Optional[str] = Field(None, description="数据库主机") | ||
| 19 | + DB_USER: Optional[str] = Field(None, description="数据库用户名") | ||
| 20 | + DB_PASSWORD: Optional[str] = Field(None, description="数据库密码") | ||
| 21 | + DB_NAME: Optional[str] = Field(None, description="数据库名称") | ||
| 22 | + DB_PORT: int = Field(3306, description="数据库端口") | ||
| 23 | + DB_CHARSET: str = Field("utf8mb4", description="数据库字符集") | ||
| 24 | + DB_DIALECT: Optional[str] = Field("mysql", description="数据库方言,如mysql、postgresql等,SQLAlchemy后端选择") | ||
| 25 | + MAX_REFLECTIONS: int = Field(3, description="最大反思次数") | ||
| 26 | + MAX_PARAGRAPHS: int = Field(6, description="最大段落数") | ||
| 27 | + SEARCH_TIMEOUT: int = Field(240, description="单次搜索请求超时") | ||
| 28 | + MAX_CONTENT_LENGTH: int = Field(500000, description="搜索最大内容长度") | ||
| 29 | + DEFAULT_SEARCH_HOT_CONTENT_LIMIT: int = Field(100, description="热榜内容默认最大数") | ||
| 30 | + DEFAULT_SEARCH_TOPIC_GLOBALLY_LIMIT_PER_TABLE: int = Field(50, description="按表全局话题最大数") | ||
| 31 | + DEFAULT_SEARCH_TOPIC_BY_DATE_LIMIT_PER_TABLE: int = Field(100, description="按日期话题最大数") | ||
| 32 | + DEFAULT_GET_COMMENTS_FOR_TOPIC_LIMIT: int = Field(500, description="单话题评论最大数") | ||
| 33 | + DEFAULT_SEARCH_TOPIC_ON_PLATFORM_LIMIT: int = Field(200, description="平台搜索话题最大数") | ||
| 34 | + MAX_SEARCH_RESULTS_FOR_LLM: int = Field(0, description="供LLM用搜索结果最大数") | ||
| 35 | + MAX_HIGH_CONFIDENCE_SENTIMENT_RESULTS: int = Field(0, description="高置信度情感分析最大数") | ||
| 36 | + OUTPUT_DIR: str = Field("reports", description="输出路径") | ||
| 37 | + SAVE_INTERMEDIATE_STATES: bool = Field(True, description="是否保存中间状态") | ||
| 38 | + | ||
| 39 | + class Config: | ||
| 40 | + env_file = ".env" | ||
| 41 | + env_prefix = "" | ||
| 42 | + case_sensitive = False | ||
| 43 | + extra = "allow" | ||
| 44 | + | ||
| 45 | +settings = Settings() |
| @@ -4,9 +4,9 @@ Deep Search Agent | @@ -4,9 +4,9 @@ Deep Search Agent | ||
| 4 | """ | 4 | """ |
| 5 | 5 | ||
| 6 | from .agent import DeepSearchAgent, create_agent | 6 | from .agent import DeepSearchAgent, create_agent |
| 7 | -from .utils.config import Config, load_config | 7 | +from .utils.config import Settings |
| 8 | 8 | ||
| 9 | __version__ = "1.0.0" | 9 | __version__ = "1.0.0" |
| 10 | __author__ = "Deep Search Agent Team" | 10 | __author__ = "Deep Search Agent Team" |
| 11 | 11 | ||
| 12 | -__all__ = ["DeepSearchAgent", "create_agent", "Config", "load_config"] | 12 | +__all__ = ["DeepSearchAgent", "create_agent", "Settings"] |
| @@ -8,7 +8,7 @@ import os | @@ -8,7 +8,7 @@ import os | ||
| 8 | import re | 8 | import re |
| 9 | from datetime import datetime | 9 | from datetime import datetime |
| 10 | from typing import Optional, Dict, Any, List | 10 | from typing import Optional, Dict, Any, List |
| 11 | - | 11 | +from loguru import logger |
| 12 | from .llms import LLMClient | 12 | from .llms import LLMClient |
| 13 | from .nodes import ( | 13 | from .nodes import ( |
| 14 | ReportStructureNode, | 14 | ReportStructureNode, |
| @@ -20,29 +20,26 @@ from .nodes import ( | @@ -20,29 +20,26 @@ from .nodes import ( | ||
| 20 | ) | 20 | ) |
| 21 | from .state import State | 21 | from .state import State |
| 22 | from .tools import BochaMultimodalSearch, BochaResponse | 22 | from .tools import BochaMultimodalSearch, BochaResponse |
| 23 | -from .utils import Config, load_config, format_search_results_for_prompt | 23 | +from .utils import settings, Settings, format_search_results_for_prompt |
| 24 | 24 | ||
| 25 | 25 | ||
| 26 | class DeepSearchAgent: | 26 | class DeepSearchAgent: |
| 27 | """Deep Search Agent主类""" | 27 | """Deep Search Agent主类""" |
| 28 | 28 | ||
| 29 | - def __init__(self, config: Optional[Config] = None): | 29 | + def __init__(self, config: Optional[Settings] = None): |
| 30 | """ | 30 | """ |
| 31 | 初始化Deep Search Agent | 31 | 初始化Deep Search Agent |
| 32 | 32 | ||
| 33 | Args: | 33 | Args: |
| 34 | config: 配置对象,如果不提供则自动加载 | 34 | config: 配置对象,如果不提供则自动加载 |
| 35 | """ | 35 | """ |
| 36 | - # 加载配置 | ||
| 37 | - self.config = config or load_config() | ||
| 38 | - os.environ["BOCHA_API_KEY"] = self.config.bocha_api_key or "" | ||
| 39 | - os.environ["BOCHA_WEB_SEARCH_API_KEY"] = self.config.bocha_api_key or "" | 36 | + self.config = config or settings |
| 40 | 37 | ||
| 41 | # 初始化LLM客户端 | 38 | # 初始化LLM客户端 |
| 42 | self.llm_client = self._initialize_llm() | 39 | self.llm_client = self._initialize_llm() |
| 43 | 40 | ||
| 44 | # 初始化搜索工具集 | 41 | # 初始化搜索工具集 |
| 45 | - self.search_agency = BochaMultimodalSearch(api_key=self.config.bocha_api_key) | 42 | + self.search_agency = BochaMultimodalSearch(api_key=(self.config.BOCHA_API_KEY or self.config.BOCHA_WEB_SEARCH_API_KEY)) |
| 46 | 43 | ||
| 47 | # 初始化节点 | 44 | # 初始化节点 |
| 48 | self._initialize_nodes() | 45 | self._initialize_nodes() |
| @@ -51,18 +48,18 @@ class DeepSearchAgent: | @@ -51,18 +48,18 @@ class DeepSearchAgent: | ||
| 51 | self.state = State() | 48 | self.state = State() |
| 52 | 49 | ||
| 53 | # 确保输出目录存在 | 50 | # 确保输出目录存在 |
| 54 | - os.makedirs(self.config.output_dir, exist_ok=True) | 51 | + os.makedirs(self.config.OUTPUT_DIR, exist_ok=True) |
| 55 | 52 | ||
| 56 | - print(f"Meida Agent已初始化") | ||
| 57 | - print(f"使用LLM: {self.llm_client.get_model_info()}") | ||
| 58 | - print(f"搜索工具集: BochaMultimodalSearch (支持5种多模态搜索工具)") | 53 | + logger.info(f"Meida Agent已初始化") |
| 54 | + logger.info(f"使用LLM: {self.llm_client.get_model_info()}") | ||
| 55 | + logger.info(f"搜索工具集: BochaMultimodalSearch (支持5种多模态搜索工具)") | ||
| 59 | 56 | ||
| 60 | def _initialize_llm(self) -> LLMClient: | 57 | def _initialize_llm(self) -> LLMClient: |
| 61 | """初始化LLM客户端""" | 58 | """初始化LLM客户端""" |
| 62 | return LLMClient( | 59 | return LLMClient( |
| 63 | - api_key=self.config.llm_api_key, | ||
| 64 | - model_name=self.config.llm_model_name, | ||
| 65 | - base_url=self.config.llm_base_url, | 60 | + api_key=(self.config.MEDIA_ENGINE_API_KEY or self.config.MINDSPIDER_API_KEY), |
| 61 | + model_name=(self.config.MEDIA_ENGINE_MODEL_NAME or self.config.MINDSPIDER_MODEL_NAME), | ||
| 62 | + base_url=(self.config.MEDIA_ENGINE_BASE_URL or self.config.MINDSPIDER_BASE_URL), | ||
| 66 | ) | 63 | ) |
| 67 | 64 | ||
| 68 | def _initialize_nodes(self): | 65 | def _initialize_nodes(self): |
| @@ -115,7 +112,7 @@ class DeepSearchAgent: | @@ -115,7 +112,7 @@ class DeepSearchAgent: | ||
| 115 | Returns: | 112 | Returns: |
| 116 | BochaResponse对象 | 113 | BochaResponse对象 |
| 117 | """ | 114 | """ |
| 118 | - print(f" → 执行搜索工具: {tool_name}") | 115 | + logger.info(f" → 执行搜索工具: {tool_name}") |
| 119 | 116 | ||
| 120 | if tool_name == "comprehensive_search": | 117 | if tool_name == "comprehensive_search": |
| 121 | max_results = kwargs.get("max_results", 10) | 118 | max_results = kwargs.get("max_results", 10) |
| @@ -130,7 +127,7 @@ class DeepSearchAgent: | @@ -130,7 +127,7 @@ class DeepSearchAgent: | ||
| 130 | elif tool_name == "search_last_week": | 127 | elif tool_name == "search_last_week": |
| 131 | return self.search_agency.search_last_week(query) | 128 | return self.search_agency.search_last_week(query) |
| 132 | else: | 129 | else: |
| 133 | - print(f" ⚠️ 未知的搜索工具: {tool_name},使用默认综合搜索") | 130 | + logger.info(f" ⚠️ 未知的搜索工具: {tool_name},使用默认综合搜索") |
| 134 | return self.search_agency.comprehensive_search(query) | 131 | return self.search_agency.comprehensive_search(query) |
| 135 | 132 | ||
| 136 | def research(self, query: str, save_report: bool = True) -> str: | 133 | def research(self, query: str, save_report: bool = True) -> str: |
| @@ -144,9 +141,9 @@ class DeepSearchAgent: | @@ -144,9 +141,9 @@ class DeepSearchAgent: | ||
| 144 | Returns: | 141 | Returns: |
| 145 | 最终报告内容 | 142 | 最终报告内容 |
| 146 | """ | 143 | """ |
| 147 | - print(f"\n{'='*60}") | ||
| 148 | - print(f"开始深度研究: {query}") | ||
| 149 | - print(f"{'='*60}") | 144 | + logger.info(f"\n{'='*60}") |
| 145 | + logger.info(f"开始深度研究: {query}") | ||
| 146 | + logger.info(f"{'='*60}") | ||
| 150 | 147 | ||
| 151 | try: | 148 | try: |
| 152 | # Step 1: 生成报告结构 | 149 | # Step 1: 生成报告结构 |
| @@ -162,19 +159,21 @@ class DeepSearchAgent: | @@ -162,19 +159,21 @@ class DeepSearchAgent: | ||
| 162 | if save_report: | 159 | if save_report: |
| 163 | self._save_report(final_report) | 160 | self._save_report(final_report) |
| 164 | 161 | ||
| 165 | - print(f"\n{'='*60}") | ||
| 166 | - print("深度研究完成!") | ||
| 167 | - print(f"{'='*60}") | 162 | + logger.info(f"\n{'='*60}") |
| 163 | + logger.info("深度研究完成!") | ||
| 164 | + logger.info(f"{'='*60}") | ||
| 168 | 165 | ||
| 169 | return final_report | 166 | return final_report |
| 170 | 167 | ||
| 171 | except Exception as e: | 168 | except Exception as e: |
| 172 | - print(f"研究过程中发生错误: {str(e)}") | 169 | + import traceback |
| 170 | + error_traceback = traceback.format_exc() | ||
| 171 | + logger.error(f"研究过程中发生错误: {str(e)} \n错误堆栈: {error_traceback}") | ||
| 173 | raise e | 172 | raise e |
| 174 | 173 | ||
| 175 | def _generate_report_structure(self, query: str): | 174 | def _generate_report_structure(self, query: str): |
| 176 | """生成报告结构""" | 175 | """生成报告结构""" |
| 177 | - print(f"\n[步骤 1] 生成报告结构...") | 176 | + logger.info(f"\n[步骤 1] 生成报告结构...") |
| 178 | 177 | ||
| 179 | # 创建报告结构节点 | 178 | # 创建报告结构节点 |
| 180 | report_structure_node = ReportStructureNode(self.llm_client, query) | 179 | report_structure_node = ReportStructureNode(self.llm_client, query) |
| @@ -182,17 +181,18 @@ class DeepSearchAgent: | @@ -182,17 +181,18 @@ class DeepSearchAgent: | ||
| 182 | # 生成结构并更新状态 | 181 | # 生成结构并更新状态 |
| 183 | self.state = report_structure_node.mutate_state(state=self.state) | 182 | self.state = report_structure_node.mutate_state(state=self.state) |
| 184 | 183 | ||
| 185 | - print(f"报告结构已生成,共 {len(self.state.paragraphs)} 个段落:") | 184 | + _message = f"报告结构已生成,共 {len(self.state.paragraphs)} 个段落:" |
| 186 | for i, paragraph in enumerate(self.state.paragraphs, 1): | 185 | for i, paragraph in enumerate(self.state.paragraphs, 1): |
| 187 | - print(f" {i}. {paragraph.title}") | 186 | + _message += f"\n {i}. {paragraph.title}" |
| 187 | + logger.info(_message) | ||
| 188 | 188 | ||
| 189 | def _process_paragraphs(self): | 189 | def _process_paragraphs(self): |
| 190 | """处理所有段落""" | 190 | """处理所有段落""" |
| 191 | total_paragraphs = len(self.state.paragraphs) | 191 | total_paragraphs = len(self.state.paragraphs) |
| 192 | 192 | ||
| 193 | for i in range(total_paragraphs): | 193 | for i in range(total_paragraphs): |
| 194 | - print(f"\n[步骤 2.{i+1}] 处理段落: {self.state.paragraphs[i].title}") | ||
| 195 | - print("-" * 50) | 194 | + logger.info(f"\n[步骤 2.{i+1}] 处理段落: {self.state.paragraphs[i].title}") |
| 195 | + logger.info("-" * 50) | ||
| 196 | 196 | ||
| 197 | # 初始搜索和总结 | 197 | # 初始搜索和总结 |
| 198 | self._initial_search_and_summary(i) | 198 | self._initial_search_and_summary(i) |
| @@ -204,7 +204,7 @@ class DeepSearchAgent: | @@ -204,7 +204,7 @@ class DeepSearchAgent: | ||
| 204 | self.state.paragraphs[i].research.mark_completed() | 204 | self.state.paragraphs[i].research.mark_completed() |
| 205 | 205 | ||
| 206 | progress = (i + 1) / total_paragraphs * 100 | 206 | progress = (i + 1) / total_paragraphs * 100 |
| 207 | - print(f"段落处理完成 ({progress:.1f}%)") | 207 | + logger.info(f"段落处理完成 ({progress:.1f}%)") |
| 208 | 208 | ||
| 209 | def _initial_search_and_summary(self, paragraph_index: int): | 209 | def _initial_search_and_summary(self, paragraph_index: int): |
| 210 | """执行初始搜索和总结""" | 210 | """执行初始搜索和总结""" |
| @@ -217,18 +217,18 @@ class DeepSearchAgent: | @@ -217,18 +217,18 @@ class DeepSearchAgent: | ||
| 217 | } | 217 | } |
| 218 | 218 | ||
| 219 | # 生成搜索查询和工具选择 | 219 | # 生成搜索查询和工具选择 |
| 220 | - print(" - 生成搜索查询...") | 220 | + logger.info(" - 生成搜索查询...") |
| 221 | search_output = self.first_search_node.run(search_input) | 221 | search_output = self.first_search_node.run(search_input) |
| 222 | search_query = search_output["search_query"] | 222 | search_query = search_output["search_query"] |
| 223 | search_tool = search_output.get("search_tool", "comprehensive_search") # 默认工具 | 223 | search_tool = search_output.get("search_tool", "comprehensive_search") # 默认工具 |
| 224 | reasoning = search_output["reasoning"] | 224 | reasoning = search_output["reasoning"] |
| 225 | 225 | ||
| 226 | - print(f" - 搜索查询: {search_query}") | ||
| 227 | - print(f" - 选择的工具: {search_tool}") | ||
| 228 | - print(f" - 推理: {reasoning}") | 226 | + logger.info(f" - 搜索查询: {search_query}") |
| 227 | + logger.info(f" - 选择的工具: {search_tool}") | ||
| 228 | + logger.info(f" - 推理: {reasoning}") | ||
| 229 | 229 | ||
| 230 | # 执行搜索 | 230 | # 执行搜索 |
| 231 | - print(" - 执行网络搜索...") | 231 | + logger.info(" - 执行网络搜索...") |
| 232 | 232 | ||
| 233 | # 处理特殊参数(新的工具集不需要日期参数处理) | 233 | # 处理特殊参数(新的工具集不需要日期参数处理) |
| 234 | search_kwargs = {} | 234 | search_kwargs = {} |
| @@ -254,24 +254,25 @@ class DeepSearchAgent: | @@ -254,24 +254,25 @@ class DeepSearchAgent: | ||
| 254 | }) | 254 | }) |
| 255 | 255 | ||
| 256 | if search_results: | 256 | if search_results: |
| 257 | - print(f" - 找到 {len(search_results)} 个搜索结果") | 257 | + _message = f" - 找到 {len(search_results)} 个搜索结果" |
| 258 | for j, result in enumerate(search_results, 1): | 258 | for j, result in enumerate(search_results, 1): |
| 259 | date_info = f" (发布于: {result.get('published_date', 'N/A')})" if result.get('published_date') else "" | 259 | date_info = f" (发布于: {result.get('published_date', 'N/A')})" if result.get('published_date') else "" |
| 260 | - print(f" {j}. {result['title'][:50]}...{date_info}") | 260 | + _message += f"\n {j}. {result['title'][:50]}...{date_info}" |
| 261 | + logger.info(_message) | ||
| 261 | else: | 262 | else: |
| 262 | - print(" - 未找到搜索结果") | 263 | + logger.info(" - 未找到搜索结果") |
| 263 | 264 | ||
| 264 | # 更新状态中的搜索历史 | 265 | # 更新状态中的搜索历史 |
| 265 | paragraph.research.add_search_results(search_query, search_results) | 266 | paragraph.research.add_search_results(search_query, search_results) |
| 266 | 267 | ||
| 267 | # 生成初始总结 | 268 | # 生成初始总结 |
| 268 | - print(" - 生成初始总结...") | 269 | + logger.info(" - 生成初始总结...") |
| 269 | summary_input = { | 270 | summary_input = { |
| 270 | "title": paragraph.title, | 271 | "title": paragraph.title, |
| 271 | "content": paragraph.content, | 272 | "content": paragraph.content, |
| 272 | "search_query": search_query, | 273 | "search_query": search_query, |
| 273 | "search_results": format_search_results_for_prompt( | 274 | "search_results": format_search_results_for_prompt( |
| 274 | - search_results, self.config.max_content_length | 275 | + search_results, self.config.SEARCH_CONTENT_MAX_LENGTH |
| 275 | ) | 276 | ) |
| 276 | } | 277 | } |
| 277 | 278 | ||
| @@ -280,14 +281,14 @@ class DeepSearchAgent: | @@ -280,14 +281,14 @@ class DeepSearchAgent: | ||
| 280 | summary_input, self.state, paragraph_index | 281 | summary_input, self.state, paragraph_index |
| 281 | ) | 282 | ) |
| 282 | 283 | ||
| 283 | - print(" - 初始总结完成") | 284 | + logger.info(" - 初始总结完成") |
| 284 | 285 | ||
| 285 | def _reflection_loop(self, paragraph_index: int): | 286 | def _reflection_loop(self, paragraph_index: int): |
| 286 | """执行反思循环""" | 287 | """执行反思循环""" |
| 287 | paragraph = self.state.paragraphs[paragraph_index] | 288 | paragraph = self.state.paragraphs[paragraph_index] |
| 288 | 289 | ||
| 289 | - for reflection_i in range(self.config.max_reflections): | ||
| 290 | - print(f" - 反思 {reflection_i + 1}/{self.config.max_reflections}...") | 290 | + for reflection_i in range(self.config.MAX_REFLECTIONS): |
| 291 | + logger.info(f" - 反思 {reflection_i + 1}/{self.config.MAX_REFLECTIONS}...") | ||
| 291 | 292 | ||
| 292 | # 准备反思输入 | 293 | # 准备反思输入 |
| 293 | reflection_input = { | 294 | reflection_input = { |
| @@ -302,9 +303,9 @@ class DeepSearchAgent: | @@ -302,9 +303,9 @@ class DeepSearchAgent: | ||
| 302 | search_tool = reflection_output.get("search_tool", "comprehensive_search") # 默认工具 | 303 | search_tool = reflection_output.get("search_tool", "comprehensive_search") # 默认工具 |
| 303 | reasoning = reflection_output["reasoning"] | 304 | reasoning = reflection_output["reasoning"] |
| 304 | 305 | ||
| 305 | - print(f" 反思查询: {search_query}") | ||
| 306 | - print(f" 选择的工具: {search_tool}") | ||
| 307 | - print(f" 反思推理: {reasoning}") | 306 | + logger.info(f" 反思查询: {search_query}") |
| 307 | + logger.info(f" 选择的工具: {search_tool}") | ||
| 308 | + logger.info(f" 反思推理: {reasoning}") | ||
| 308 | 309 | ||
| 309 | # 执行反思搜索 | 310 | # 执行反思搜索 |
| 310 | # 处理特殊参数 | 311 | # 处理特殊参数 |
| @@ -331,12 +332,13 @@ class DeepSearchAgent: | @@ -331,12 +332,13 @@ class DeepSearchAgent: | ||
| 331 | }) | 332 | }) |
| 332 | 333 | ||
| 333 | if search_results: | 334 | if search_results: |
| 334 | - print(f" 找到 {len(search_results)} 个反思搜索结果") | 335 | + _message = f" 找到 {len(search_results)} 个反思搜索结果" |
| 335 | for j, result in enumerate(search_results, 1): | 336 | for j, result in enumerate(search_results, 1): |
| 336 | date_info = f" (发布于: {result.get('published_date', 'N/A')})" if result.get('published_date') else "" | 337 | date_info = f" (发布于: {result.get('published_date', 'N/A')})" if result.get('published_date') else "" |
| 337 | - print(f" {j}. {result['title'][:50]}...{date_info}") | 338 | + _message += f"\n {j}. {result['title'][:50]}...{date_info}" |
| 339 | + logger.info(_message) | ||
| 338 | else: | 340 | else: |
| 339 | - print(" 未找到反思搜索结果") | 341 | + logger.info(" 未找到反思搜索结果") |
| 340 | 342 | ||
| 341 | # 更新搜索历史 | 343 | # 更新搜索历史 |
| 342 | paragraph.research.add_search_results(search_query, search_results) | 344 | paragraph.research.add_search_results(search_query, search_results) |
| @@ -347,7 +349,7 @@ class DeepSearchAgent: | @@ -347,7 +349,7 @@ class DeepSearchAgent: | ||
| 347 | "content": paragraph.content, | 349 | "content": paragraph.content, |
| 348 | "search_query": search_query, | 350 | "search_query": search_query, |
| 349 | "search_results": format_search_results_for_prompt( | 351 | "search_results": format_search_results_for_prompt( |
| 350 | - search_results, self.config.max_content_length | 352 | + search_results, self.config.SEARCH_CONTENT_MAX_LENGTH |
| 351 | ), | 353 | ), |
| 352 | "paragraph_latest_state": paragraph.research.latest_summary | 354 | "paragraph_latest_state": paragraph.research.latest_summary |
| 353 | } | 355 | } |
| @@ -357,11 +359,11 @@ class DeepSearchAgent: | @@ -357,11 +359,11 @@ class DeepSearchAgent: | ||
| 357 | reflection_summary_input, self.state, paragraph_index | 359 | reflection_summary_input, self.state, paragraph_index |
| 358 | ) | 360 | ) |
| 359 | 361 | ||
| 360 | - print(f" 反思 {reflection_i + 1} 完成") | 362 | + logger.info(f" 反思 {reflection_i + 1} 完成") |
| 361 | 363 | ||
| 362 | def _generate_final_report(self) -> str: | 364 | def _generate_final_report(self) -> str: |
| 363 | """生成最终报告""" | 365 | """生成最终报告""" |
| 364 | - print(f"\n[步骤 3] 生成最终报告...") | 366 | + logger.info(f"\n[步骤 3] 生成最终报告...") |
| 365 | 367 | ||
| 366 | # 准备报告数据 | 368 | # 准备报告数据 |
| 367 | report_data = [] | 369 | report_data = [] |
| @@ -375,7 +377,7 @@ class DeepSearchAgent: | @@ -375,7 +377,7 @@ class DeepSearchAgent: | ||
| 375 | try: | 377 | try: |
| 376 | final_report = self.report_formatting_node.run(report_data) | 378 | final_report = self.report_formatting_node.run(report_data) |
| 377 | except Exception as e: | 379 | except Exception as e: |
| 378 | - print(f"LLM格式化失败,使用备用方法: {str(e)}") | 380 | + logger.info(f"LLM格式化失败,使用备用方法: {str(e)}") |
| 379 | final_report = self.report_formatting_node.format_report_manually( | 381 | final_report = self.report_formatting_node.format_report_manually( |
| 380 | report_data, self.state.report_title | 382 | report_data, self.state.report_title |
| 381 | ) | 383 | ) |
| @@ -384,7 +386,7 @@ class DeepSearchAgent: | @@ -384,7 +386,7 @@ class DeepSearchAgent: | ||
| 384 | self.state.final_report = final_report | 386 | self.state.final_report = final_report |
| 385 | self.state.mark_completed() | 387 | self.state.mark_completed() |
| 386 | 388 | ||
| 387 | - print("最终报告生成完成") | 389 | + logger.info("最终报告生成完成") |
| 388 | return final_report | 390 | return final_report |
| 389 | 391 | ||
| 390 | def _save_report(self, report_content: str): | 392 | def _save_report(self, report_content: str): |
| @@ -395,20 +397,20 @@ class DeepSearchAgent: | @@ -395,20 +397,20 @@ class DeepSearchAgent: | ||
| 395 | query_safe = query_safe.replace(' ', '_')[:30] | 397 | query_safe = query_safe.replace(' ', '_')[:30] |
| 396 | 398 | ||
| 397 | filename = f"deep_search_report_{query_safe}_{timestamp}.md" | 399 | filename = f"deep_search_report_{query_safe}_{timestamp}.md" |
| 398 | - filepath = os.path.join(self.config.output_dir, filename) | 400 | + filepath = os.path.join(self.config.OUTPUT_DIR, filename) |
| 399 | 401 | ||
| 400 | # 保存报告 | 402 | # 保存报告 |
| 401 | with open(filepath, 'w', encoding='utf-8') as f: | 403 | with open(filepath, 'w', encoding='utf-8') as f: |
| 402 | f.write(report_content) | 404 | f.write(report_content) |
| 403 | 405 | ||
| 404 | - print(f"报告已保存到: {filepath}") | 406 | + logger.info(f"报告已保存到: {filepath}") |
| 405 | 407 | ||
| 406 | # 保存状态(如果配置允许) | 408 | # 保存状态(如果配置允许) |
| 407 | - if self.config.save_intermediate_states: | 409 | + if self.config.SAVE_INTERMEDIATE_STATES: |
| 408 | state_filename = f"state_{query_safe}_{timestamp}.json" | 410 | state_filename = f"state_{query_safe}_{timestamp}.json" |
| 409 | - state_filepath = os.path.join(self.config.output_dir, state_filename) | 411 | + state_filepath = os.path.join(self.config.OUTPUT_DIR, state_filename) |
| 410 | self.state.save_to_file(state_filepath) | 412 | self.state.save_to_file(state_filepath) |
| 411 | - print(f"状态已保存到: {state_filepath}") | 413 | + logger.info(f"状态已保存到: {state_filepath}") |
| 412 | 414 | ||
| 413 | def get_progress_summary(self) -> Dict[str, Any]: | 415 | def get_progress_summary(self) -> Dict[str, Any]: |
| 414 | """获取进度摘要""" | 416 | """获取进度摘要""" |
| @@ -417,12 +419,12 @@ class DeepSearchAgent: | @@ -417,12 +419,12 @@ class DeepSearchAgent: | ||
| 417 | def load_state(self, filepath: str): | 419 | def load_state(self, filepath: str): |
| 418 | """从文件加载状态""" | 420 | """从文件加载状态""" |
| 419 | self.state = State.load_from_file(filepath) | 421 | self.state = State.load_from_file(filepath) |
| 420 | - print(f"状态已从 {filepath} 加载") | 422 | + logger.info(f"状态已从 {filepath} 加载") |
| 421 | 423 | ||
| 422 | def save_state(self, filepath: str): | 424 | def save_state(self, filepath: str): |
| 423 | """保存状态到文件""" | 425 | """保存状态到文件""" |
| 424 | self.state.save_to_file(filepath) | 426 | self.state.save_to_file(filepath) |
| 425 | - print(f"状态已保存到 {filepath}") | 427 | + logger.info(f"状态已保存到 {filepath}") |
| 426 | 428 | ||
| 427 | 429 | ||
| 428 | def create_agent(config_file: Optional[str] = None) -> DeepSearchAgent: | 430 | def create_agent(config_file: Optional[str] = None) -> DeepSearchAgent: |
| @@ -435,5 +437,5 @@ def create_agent(config_file: Optional[str] = None) -> DeepSearchAgent: | @@ -435,5 +437,5 @@ def create_agent(config_file: Optional[str] = None) -> DeepSearchAgent: | ||
| 435 | Returns: | 437 | Returns: |
| 436 | DeepSearchAgent实例 | 438 | DeepSearchAgent实例 |
| 437 | """ | 439 | """ |
| 438 | - config = load_config(config_file) | ||
| 439 | - return DeepSearchAgent(config) | 440 | + settings = Settings() |
| 441 | + return DeepSearchAgent(settings) |
| @@ -7,67 +7,72 @@ from abc import ABC, abstractmethod | @@ -7,67 +7,72 @@ from abc import ABC, abstractmethod | ||
| 7 | from typing import Any, Dict, Optional | 7 | from typing import Any, Dict, Optional |
| 8 | from ..llms.base import LLMClient | 8 | from ..llms.base import LLMClient |
| 9 | from ..state.state import State | 9 | from ..state.state import State |
| 10 | +from loguru import logger | ||
| 10 | 11 | ||
| 11 | 12 | ||
| 12 | class BaseNode(ABC): | 13 | class BaseNode(ABC): |
| 13 | """节点基类""" | 14 | """节点基类""" |
| 14 | - | 15 | + |
| 15 | def __init__(self, llm_client: LLMClient, node_name: str = ""): | 16 | def __init__(self, llm_client: LLMClient, node_name: str = ""): |
| 16 | """ | 17 | """ |
| 17 | 初始化节点 | 18 | 初始化节点 |
| 18 | - | 19 | + |
| 19 | Args: | 20 | Args: |
| 20 | llm_client: LLM客户端 | 21 | llm_client: LLM客户端 |
| 21 | node_name: 节点名称 | 22 | node_name: 节点名称 |
| 22 | """ | 23 | """ |
| 23 | self.llm_client = llm_client | 24 | self.llm_client = llm_client |
| 24 | self.node_name = node_name or self.__class__.__name__ | 25 | self.node_name = node_name or self.__class__.__name__ |
| 25 | - | 26 | + |
| 26 | @abstractmethod | 27 | @abstractmethod |
| 27 | def run(self, input_data: Any, **kwargs) -> Any: | 28 | def run(self, input_data: Any, **kwargs) -> Any: |
| 28 | """ | 29 | """ |
| 29 | 执行节点处理逻辑 | 30 | 执行节点处理逻辑 |
| 30 | - | 31 | + |
| 31 | Args: | 32 | Args: |
| 32 | input_data: 输入数据 | 33 | input_data: 输入数据 |
| 33 | **kwargs: 额外参数 | 34 | **kwargs: 额外参数 |
| 34 | - | 35 | + |
| 35 | Returns: | 36 | Returns: |
| 36 | 处理结果 | 37 | 处理结果 |
| 37 | """ | 38 | """ |
| 38 | pass | 39 | pass |
| 39 | - | 40 | + |
| 40 | def validate_input(self, input_data: Any) -> bool: | 41 | def validate_input(self, input_data: Any) -> bool: |
| 41 | """ | 42 | """ |
| 42 | 验证输入数据 | 43 | 验证输入数据 |
| 43 | - | 44 | + |
| 44 | Args: | 45 | Args: |
| 45 | input_data: 输入数据 | 46 | input_data: 输入数据 |
| 46 | - | 47 | + |
| 47 | Returns: | 48 | Returns: |
| 48 | 验证是否通过 | 49 | 验证是否通过 |
| 49 | """ | 50 | """ |
| 50 | return True | 51 | return True |
| 51 | - | 52 | + |
| 52 | def process_output(self, output: Any) -> Any: | 53 | def process_output(self, output: Any) -> Any: |
| 53 | """ | 54 | """ |
| 54 | 处理输出数据 | 55 | 处理输出数据 |
| 55 | - | 56 | + |
| 56 | Args: | 57 | Args: |
| 57 | output: 原始输出 | 58 | output: 原始输出 |
| 58 | - | 59 | + |
| 59 | Returns: | 60 | Returns: |
| 60 | 处理后的输出 | 61 | 处理后的输出 |
| 61 | """ | 62 | """ |
| 62 | return output | 63 | return output |
| 63 | - | 64 | + |
| 64 | def log_info(self, message: str): | 65 | def log_info(self, message: str): |
| 65 | """记录信息日志""" | 66 | """记录信息日志""" |
| 66 | - print(f"[{self.node_name}] {message}") | 67 | + logger.info(f"[{self.node_name}] {message}") |
| 67 | 68 | ||
| 69 | + def log_warning(self, message: str): | ||
| 70 | + """记录警告日志""" | ||
| 71 | + logger.warning(f"[{self.node_name}] 警告: {message}") | ||
| 72 | + | ||
| 68 | def log_error(self, message: str): | 73 | def log_error(self, message: str): |
| 69 | """记录错误日志""" | 74 | """记录错误日志""" |
| 70 | - print(f"[{self.node_name}] 错误: {message}") | 75 | + logger.error(f"[{self.node_name}] 错误: {message}") |
| 71 | 76 | ||
| 72 | 77 | ||
| 73 | class StateMutationNode(BaseNode): | 78 | class StateMutationNode(BaseNode): |
| @@ -5,6 +5,7 @@ | @@ -5,6 +5,7 @@ | ||
| 5 | 5 | ||
| 6 | import json | 6 | import json |
| 7 | from typing import List, Dict, Any | 7 | from typing import List, Dict, Any |
| 8 | +from loguru import logger | ||
| 8 | 9 | ||
| 9 | from .base_node import BaseNode | 10 | from .base_node import BaseNode |
| 10 | from ..prompts import SYSTEM_PROMPT_REPORT_FORMATTING | 11 | from ..prompts import SYSTEM_PROMPT_REPORT_FORMATTING |
| @@ -65,7 +66,7 @@ class ReportFormattingNode(BaseNode): | @@ -65,7 +66,7 @@ class ReportFormattingNode(BaseNode): | ||
| 65 | else: | 66 | else: |
| 66 | message = json.dumps(input_data, ensure_ascii=False) | 67 | message = json.dumps(input_data, ensure_ascii=False) |
| 67 | 68 | ||
| 68 | - self.log_info("正在格式化最终报告") | 69 | + logger.info("正在格式化最终报告") |
| 69 | 70 | ||
| 70 | # 调用LLM生成Markdown格式 | 71 | # 调用LLM生成Markdown格式 |
| 71 | response = self.llm_client.invoke( | 72 | response = self.llm_client.invoke( |
| @@ -76,11 +77,11 @@ class ReportFormattingNode(BaseNode): | @@ -76,11 +77,11 @@ class ReportFormattingNode(BaseNode): | ||
| 76 | # 处理响应 | 77 | # 处理响应 |
| 77 | processed_response = self.process_output(response) | 78 | processed_response = self.process_output(response) |
| 78 | 79 | ||
| 79 | - self.log_info("成功生成格式化报告") | 80 | + logger.info("成功生成格式化报告") |
| 80 | return processed_response | 81 | return processed_response |
| 81 | 82 | ||
| 82 | except Exception as e: | 83 | except Exception as e: |
| 83 | - self.log_error(f"报告格式化失败: {str(e)}") | 84 | + logger.exception(f"报告格式化失败: {str(e)}") |
| 84 | raise e | 85 | raise e |
| 85 | 86 | ||
| 86 | def process_output(self, output: str) -> str: | 87 | def process_output(self, output: str) -> str: |
| @@ -109,7 +110,7 @@ class ReportFormattingNode(BaseNode): | @@ -109,7 +110,7 @@ class ReportFormattingNode(BaseNode): | ||
| 109 | return cleaned_output.strip() | 110 | return cleaned_output.strip() |
| 110 | 111 | ||
| 111 | except Exception as e: | 112 | except Exception as e: |
| 112 | - self.log_error(f"处理输出失败: {str(e)}") | 113 | + logger.exception(f"处理输出失败: {str(e)}") |
| 113 | return "# 报告处理失败\n\n报告格式化过程中发生错误。" | 114 | return "# 报告处理失败\n\n报告格式化过程中发生错误。" |
| 114 | 115 | ||
| 115 | def format_report_manually(self, paragraphs_data: List[Dict[str, str]], | 116 | def format_report_manually(self, paragraphs_data: List[Dict[str, str]], |
| @@ -125,7 +126,7 @@ class ReportFormattingNode(BaseNode): | @@ -125,7 +126,7 @@ class ReportFormattingNode(BaseNode): | ||
| 125 | 格式化的Markdown报告 | 126 | 格式化的Markdown报告 |
| 126 | """ | 127 | """ |
| 127 | try: | 128 | try: |
| 128 | - self.log_info("使用手动格式化方法") | 129 | + logger.info("使用手动格式化方法") |
| 129 | 130 | ||
| 130 | # 构建报告 | 131 | # 构建报告 |
| 131 | report_lines = [ | 132 | report_lines = [ |
| @@ -163,5 +164,5 @@ class ReportFormattingNode(BaseNode): | @@ -163,5 +164,5 @@ class ReportFormattingNode(BaseNode): | ||
| 163 | return "\n".join(report_lines) | 164 | return "\n".join(report_lines) |
| 164 | 165 | ||
| 165 | except Exception as e: | 166 | except Exception as e: |
| 166 | - self.log_error(f"手动格式化失败: {str(e)}") | 167 | + logger.exception(f"手动格式化失败: {str(e)}") |
| 167 | return "# 报告生成失败\n\n无法完成报告格式化。" | 168 | return "# 报告生成失败\n\n无法完成报告格式化。" |
| @@ -6,6 +6,7 @@ | @@ -6,6 +6,7 @@ | ||
| 6 | import json | 6 | import json |
| 7 | from typing import Dict, Any, List | 7 | from typing import Dict, Any, List |
| 8 | from json.decoder import JSONDecodeError | 8 | from json.decoder import JSONDecodeError |
| 9 | +from loguru import logger | ||
| 9 | 10 | ||
| 10 | from .base_node import StateMutationNode | 11 | from .base_node import StateMutationNode |
| 11 | from ..state.state import State | 12 | from ..state.state import State |
| @@ -48,7 +49,7 @@ class ReportStructureNode(StateMutationNode): | @@ -48,7 +49,7 @@ class ReportStructureNode(StateMutationNode): | ||
| 48 | 报告结构列表 | 49 | 报告结构列表 |
| 49 | """ | 50 | """ |
| 50 | try: | 51 | try: |
| 51 | - self.log_info(f"正在为查询生成报告结构: {self.query}") | 52 | + logger.info(f"正在为查询生成报告结构: {self.query}") |
| 52 | 53 | ||
| 53 | # 调用LLM | 54 | # 调用LLM |
| 54 | response = self.llm_client.invoke(SYSTEM_PROMPT_REPORT_STRUCTURE, self.query) | 55 | response = self.llm_client.invoke(SYSTEM_PROMPT_REPORT_STRUCTURE, self.query) |
| @@ -56,11 +57,11 @@ class ReportStructureNode(StateMutationNode): | @@ -56,11 +57,11 @@ class ReportStructureNode(StateMutationNode): | ||
| 56 | # 处理响应 | 57 | # 处理响应 |
| 57 | processed_response = self.process_output(response) | 58 | processed_response = self.process_output(response) |
| 58 | 59 | ||
| 59 | - self.log_info(f"成功生成 {len(processed_response)} 个段落结构") | 60 | + logger.info(f"成功生成 {len(processed_response)} 个段落结构") |
| 60 | return processed_response | 61 | return processed_response |
| 61 | 62 | ||
| 62 | except Exception as e: | 63 | except Exception as e: |
| 63 | - self.log_error(f"生成报告结构失败: {str(e)}") | 64 | + logger.exception(f"生成报告结构失败: {str(e)}") |
| 64 | raise e | 65 | raise e |
| 65 | 66 | ||
| 66 | def process_output(self, output: str) -> List[Dict[str, str]]: | 67 | def process_output(self, output: str) -> List[Dict[str, str]]: |
| @@ -79,54 +80,54 @@ class ReportStructureNode(StateMutationNode): | @@ -79,54 +80,54 @@ class ReportStructureNode(StateMutationNode): | ||
| 79 | cleaned_output = clean_json_tags(cleaned_output) | 80 | cleaned_output = clean_json_tags(cleaned_output) |
| 80 | 81 | ||
| 81 | # 记录清理后的输出用于调试 | 82 | # 记录清理后的输出用于调试 |
| 82 | - self.log_info(f"清理后的输出: {cleaned_output}") | 83 | + logger.info(f"清理后的输出: {cleaned_output}") |
| 83 | 84 | ||
| 84 | # 解析JSON | 85 | # 解析JSON |
| 85 | try: | 86 | try: |
| 86 | report_structure = json.loads(cleaned_output) | 87 | report_structure = json.loads(cleaned_output) |
| 87 | - self.log_info("JSON解析成功") | 88 | + logger.info("JSON解析成功") |
| 88 | except JSONDecodeError as e: | 89 | except JSONDecodeError as e: |
| 89 | - self.log_info(f"JSON解析失败: {str(e)}") | 90 | + logger.exception(f"JSON解析失败: {str(e)}") |
| 90 | # 使用更强大的提取方法 | 91 | # 使用更强大的提取方法 |
| 91 | report_structure = extract_clean_response(cleaned_output) | 92 | report_structure = extract_clean_response(cleaned_output) |
| 92 | if "error" in report_structure: | 93 | if "error" in report_structure: |
| 93 | - self.log_error("JSON解析失败,尝试修复...") | 94 | + logger.error("JSON解析失败,尝试修复...") |
| 94 | # 尝试修复JSON | 95 | # 尝试修复JSON |
| 95 | fixed_json = fix_incomplete_json(cleaned_output) | 96 | fixed_json = fix_incomplete_json(cleaned_output) |
| 96 | if fixed_json: | 97 | if fixed_json: |
| 97 | try: | 98 | try: |
| 98 | report_structure = json.loads(fixed_json) | 99 | report_structure = json.loads(fixed_json) |
| 99 | - self.log_info("JSON修复成功") | 100 | + logger.info("JSON修复成功") |
| 100 | except JSONDecodeError: | 101 | except JSONDecodeError: |
| 101 | - self.log_error("JSON修复失败") | 102 | + logger.error("JSON修复失败") |
| 102 | # 返回默认结构 | 103 | # 返回默认结构 |
| 103 | return self._generate_default_structure() | 104 | return self._generate_default_structure() |
| 104 | else: | 105 | else: |
| 105 | - self.log_error("无法修复JSON,使用默认结构") | 106 | + logger.error("无法修复JSON,使用默认结构") |
| 106 | return self._generate_default_structure() | 107 | return self._generate_default_structure() |
| 107 | 108 | ||
| 108 | # 验证结构 | 109 | # 验证结构 |
| 109 | if not isinstance(report_structure, list): | 110 | if not isinstance(report_structure, list): |
| 110 | - self.log_info("报告结构不是列表,尝试转换...") | 111 | + logger.info("报告结构不是列表,尝试转换...") |
| 111 | if isinstance(report_structure, dict): | 112 | if isinstance(report_structure, dict): |
| 112 | # 如果是单个对象,包装成列表 | 113 | # 如果是单个对象,包装成列表 |
| 113 | report_structure = [report_structure] | 114 | report_structure = [report_structure] |
| 114 | else: | 115 | else: |
| 115 | - self.log_error("报告结构格式无效,使用默认结构") | 116 | + logger.error("报告结构格式无效,使用默认结构") |
| 116 | return self._generate_default_structure() | 117 | return self._generate_default_structure() |
| 117 | 118 | ||
| 118 | # 验证每个段落 | 119 | # 验证每个段落 |
| 119 | validated_structure = [] | 120 | validated_structure = [] |
| 120 | for i, paragraph in enumerate(report_structure): | 121 | for i, paragraph in enumerate(report_structure): |
| 121 | if not isinstance(paragraph, dict): | 122 | if not isinstance(paragraph, dict): |
| 122 | - self.log_warning(f"段落 {i+1} 不是字典格式,跳过") | 123 | + logger.warning(f"段落 {i+1} 不是字典格式,跳过") |
| 123 | continue | 124 | continue |
| 124 | 125 | ||
| 125 | title = paragraph.get("title", f"段落 {i+1}") | 126 | title = paragraph.get("title", f"段落 {i+1}") |
| 126 | content = paragraph.get("content", "") | 127 | content = paragraph.get("content", "") |
| 127 | 128 | ||
| 128 | if not title or not content: | 129 | if not title or not content: |
| 129 | - self.log_warning(f"段落 {i+1} 缺少标题或内容,跳过") | 130 | + logger.warning(f"段落 {i+1} 缺少标题或内容,跳过") |
| 130 | continue | 131 | continue |
| 131 | 132 | ||
| 132 | validated_structure.append({ | 133 | validated_structure.append({ |
| @@ -135,14 +136,14 @@ class ReportStructureNode(StateMutationNode): | @@ -135,14 +136,14 @@ class ReportStructureNode(StateMutationNode): | ||
| 135 | }) | 136 | }) |
| 136 | 137 | ||
| 137 | if not validated_structure: | 138 | if not validated_structure: |
| 138 | - self.log_warning("没有有效的段落结构,使用默认结构") | 139 | + logger.warning("没有有效的段落结构,使用默认结构") |
| 139 | return self._generate_default_structure() | 140 | return self._generate_default_structure() |
| 140 | 141 | ||
| 141 | - self.log_info(f"成功验证 {len(validated_structure)} 个段落结构") | 142 | + logger.info(f"成功验证 {len(validated_structure)} 个段落结构") |
| 142 | return validated_structure | 143 | return validated_structure |
| 143 | 144 | ||
| 144 | except Exception as e: | 145 | except Exception as e: |
| 145 | - self.log_error(f"处理输出失败: {str(e)}") | 146 | + logger.exception(f"处理输出失败: {str(e)}") |
| 146 | return self._generate_default_structure() | 147 | return self._generate_default_structure() |
| 147 | 148 | ||
| 148 | def _generate_default_structure(self) -> List[Dict[str, str]]: | 149 | def _generate_default_structure(self) -> List[Dict[str, str]]: |
| @@ -152,7 +153,7 @@ class ReportStructureNode(StateMutationNode): | @@ -152,7 +153,7 @@ class ReportStructureNode(StateMutationNode): | ||
| 152 | Returns: | 153 | Returns: |
| 153 | 默认的报告结构列表 | 154 | 默认的报告结构列表 |
| 154 | """ | 155 | """ |
| 155 | - self.log_info("生成默认报告结构") | 156 | + logger.info("生成默认报告结构") |
| 156 | return [ | 157 | return [ |
| 157 | { | 158 | { |
| 158 | "title": "研究概述", | 159 | "title": "研究概述", |
| @@ -195,9 +196,9 @@ class ReportStructureNode(StateMutationNode): | @@ -195,9 +196,9 @@ class ReportStructureNode(StateMutationNode): | ||
| 195 | content=paragraph_data["content"] | 196 | content=paragraph_data["content"] |
| 196 | ) | 197 | ) |
| 197 | 198 | ||
| 198 | - self.log_info(f"已将 {len(report_structure)} 个段落添加到状态中") | 199 | + logger.info(f"已将 {len(report_structure)} 个段落添加到状态中") |
| 199 | return state | 200 | return state |
| 200 | 201 | ||
| 201 | except Exception as e: | 202 | except Exception as e: |
| 202 | - self.log_error(f"状态更新失败: {str(e)}") | 203 | + logger.exception(f"状态更新失败: {str(e)}") |
| 203 | raise e | 204 | raise e |
| @@ -6,6 +6,7 @@ | @@ -6,6 +6,7 @@ | ||
| 6 | import json | 6 | import json |
| 7 | from typing import Dict, Any | 7 | from typing import Dict, Any |
| 8 | from json.decoder import JSONDecodeError | 8 | from json.decoder import JSONDecodeError |
| 9 | +from loguru import logger | ||
| 9 | 10 | ||
| 10 | from .base_node import BaseNode | 11 | from .base_node import BaseNode |
| 11 | from ..prompts import SYSTEM_PROMPT_FIRST_SEARCH, SYSTEM_PROMPT_REFLECTION | 12 | from ..prompts import SYSTEM_PROMPT_FIRST_SEARCH, SYSTEM_PROMPT_REFLECTION |
| @@ -62,7 +63,7 @@ class FirstSearchNode(BaseNode): | @@ -62,7 +63,7 @@ class FirstSearchNode(BaseNode): | ||
| 62 | else: | 63 | else: |
| 63 | message = json.dumps(input_data, ensure_ascii=False) | 64 | message = json.dumps(input_data, ensure_ascii=False) |
| 64 | 65 | ||
| 65 | - self.log_info("正在生成首次搜索查询") | 66 | + logger.info("正在生成首次搜索查询") |
| 66 | 67 | ||
| 67 | # 调用LLM | 68 | # 调用LLM |
| 68 | response = self.llm_client.invoke(SYSTEM_PROMPT_FIRST_SEARCH, message) | 69 | response = self.llm_client.invoke(SYSTEM_PROMPT_FIRST_SEARCH, message) |
| @@ -70,11 +71,11 @@ class FirstSearchNode(BaseNode): | @@ -70,11 +71,11 @@ class FirstSearchNode(BaseNode): | ||
| 70 | # 处理响应 | 71 | # 处理响应 |
| 71 | processed_response = self.process_output(response) | 72 | processed_response = self.process_output(response) |
| 72 | 73 | ||
| 73 | - self.log_info(f"生成搜索查询: {processed_response.get('search_query', 'N/A')}") | 74 | + logger.info(f"生成搜索查询: {processed_response.get('search_query', 'N/A')}") |
| 74 | return processed_response | 75 | return processed_response |
| 75 | 76 | ||
| 76 | except Exception as e: | 77 | except Exception as e: |
| 77 | - self.log_error(f"生成首次搜索查询失败: {str(e)}") | 78 | + logger.exception(f"生成首次搜索查询失败: {str(e)}") |
| 78 | raise e | 79 | raise e |
| 79 | 80 | ||
| 80 | def process_output(self, output: str) -> Dict[str, str]: | 81 | def process_output(self, output: str) -> Dict[str, str]: |
| @@ -93,30 +94,30 @@ class FirstSearchNode(BaseNode): | @@ -93,30 +94,30 @@ class FirstSearchNode(BaseNode): | ||
| 93 | cleaned_output = clean_json_tags(cleaned_output) | 94 | cleaned_output = clean_json_tags(cleaned_output) |
| 94 | 95 | ||
| 95 | # 记录清理后的输出用于调试 | 96 | # 记录清理后的输出用于调试 |
| 96 | - self.log_info(f"清理后的输出: {cleaned_output}") | 97 | + logger.info(f"清理后的输出: {cleaned_output}") |
| 97 | 98 | ||
| 98 | # 解析JSON | 99 | # 解析JSON |
| 99 | try: | 100 | try: |
| 100 | result = json.loads(cleaned_output) | 101 | result = json.loads(cleaned_output) |
| 101 | - self.log_info("JSON解析成功") | 102 | + logger.info("JSON解析成功") |
| 102 | except JSONDecodeError as e: | 103 | except JSONDecodeError as e: |
| 103 | - self.log_info(f"JSON解析失败: {str(e)}") | 104 | + logger.exception(f"JSON解析失败: {str(e)}") |
| 104 | # 使用更强大的提取方法 | 105 | # 使用更强大的提取方法 |
| 105 | result = extract_clean_response(cleaned_output) | 106 | result = extract_clean_response(cleaned_output) |
| 106 | if "error" in result: | 107 | if "error" in result: |
| 107 | - self.log_error("JSON解析失败,尝试修复...") | 108 | + logger.error("JSON解析失败,尝试修复...") |
| 108 | # 尝试修复JSON | 109 | # 尝试修复JSON |
| 109 | fixed_json = fix_incomplete_json(cleaned_output) | 110 | fixed_json = fix_incomplete_json(cleaned_output) |
| 110 | if fixed_json: | 111 | if fixed_json: |
| 111 | try: | 112 | try: |
| 112 | result = json.loads(fixed_json) | 113 | result = json.loads(fixed_json) |
| 113 | - self.log_info("JSON修复成功") | 114 | + logger.info("JSON修复成功") |
| 114 | except JSONDecodeError: | 115 | except JSONDecodeError: |
| 115 | - self.log_error("JSON修复失败") | 116 | + logger.error("JSON修复失败") |
| 116 | # 返回默认查询 | 117 | # 返回默认查询 |
| 117 | return self._get_default_search_query() | 118 | return self._get_default_search_query() |
| 118 | else: | 119 | else: |
| 119 | - self.log_error("无法修复JSON,使用默认查询") | 120 | + logger.error("无法修复JSON,使用默认查询") |
| 120 | return self._get_default_search_query() | 121 | return self._get_default_search_query() |
| 121 | 122 | ||
| 122 | # 验证和清理结果 | 123 | # 验证和清理结果 |
| @@ -124,7 +125,7 @@ class FirstSearchNode(BaseNode): | @@ -124,7 +125,7 @@ class FirstSearchNode(BaseNode): | ||
| 124 | reasoning = result.get("reasoning", "") | 125 | reasoning = result.get("reasoning", "") |
| 125 | 126 | ||
| 126 | if not search_query: | 127 | if not search_query: |
| 127 | - self.log_warning("未找到搜索查询,使用默认查询") | 128 | + logger.warning("未找到搜索查询,使用默认查询") |
| 128 | return self._get_default_search_query() | 129 | return self._get_default_search_query() |
| 129 | 130 | ||
| 130 | return { | 131 | return { |
| @@ -197,7 +198,7 @@ class ReflectionNode(BaseNode): | @@ -197,7 +198,7 @@ class ReflectionNode(BaseNode): | ||
| 197 | else: | 198 | else: |
| 198 | message = json.dumps(input_data, ensure_ascii=False) | 199 | message = json.dumps(input_data, ensure_ascii=False) |
| 199 | 200 | ||
| 200 | - self.log_info("正在进行反思并生成新搜索查询") | 201 | + logger.info("正在进行反思并生成新搜索查询") |
| 201 | 202 | ||
| 202 | # 调用LLM | 203 | # 调用LLM |
| 203 | response = self.llm_client.invoke(SYSTEM_PROMPT_REFLECTION, message) | 204 | response = self.llm_client.invoke(SYSTEM_PROMPT_REFLECTION, message) |
| @@ -205,11 +206,11 @@ class ReflectionNode(BaseNode): | @@ -205,11 +206,11 @@ class ReflectionNode(BaseNode): | ||
| 205 | # 处理响应 | 206 | # 处理响应 |
| 206 | processed_response = self.process_output(response) | 207 | processed_response = self.process_output(response) |
| 207 | 208 | ||
| 208 | - self.log_info(f"反思生成搜索查询: {processed_response.get('search_query', 'N/A')}") | 209 | + logger.info(f"反思生成搜索查询: {processed_response.get('search_query', 'N/A')}") |
| 209 | return processed_response | 210 | return processed_response |
| 210 | 211 | ||
| 211 | except Exception as e: | 212 | except Exception as e: |
| 212 | - self.log_error(f"反思生成搜索查询失败: {str(e)}") | 213 | + logger.exception(f"反思生成搜索查询失败: {str(e)}") |
| 213 | raise e | 214 | raise e |
| 214 | 215 | ||
| 215 | def process_output(self, output: str) -> Dict[str, str]: | 216 | def process_output(self, output: str) -> Dict[str, str]: |
| @@ -228,30 +229,30 @@ class ReflectionNode(BaseNode): | @@ -228,30 +229,30 @@ class ReflectionNode(BaseNode): | ||
| 228 | cleaned_output = clean_json_tags(cleaned_output) | 229 | cleaned_output = clean_json_tags(cleaned_output) |
| 229 | 230 | ||
| 230 | # 记录清理后的输出用于调试 | 231 | # 记录清理后的输出用于调试 |
| 231 | - self.log_info(f"清理后的输出: {cleaned_output}") | 232 | + logger.info(f"清理后的输出: {cleaned_output}") |
| 232 | 233 | ||
| 233 | # 解析JSON | 234 | # 解析JSON |
| 234 | try: | 235 | try: |
| 235 | result = json.loads(cleaned_output) | 236 | result = json.loads(cleaned_output) |
| 236 | - self.log_info("JSON解析成功") | 237 | + logger.info("JSON解析成功") |
| 237 | except JSONDecodeError as e: | 238 | except JSONDecodeError as e: |
| 238 | - self.log_info(f"JSON解析失败: {str(e)}") | 239 | + logger.exception(f"JSON解析失败: {str(e)}") |
| 239 | # 使用更强大的提取方法 | 240 | # 使用更强大的提取方法 |
| 240 | result = extract_clean_response(cleaned_output) | 241 | result = extract_clean_response(cleaned_output) |
| 241 | if "error" in result: | 242 | if "error" in result: |
| 242 | - self.log_error("JSON解析失败,尝试修复...") | 243 | + logger.error("JSON解析失败,尝试修复...") |
| 243 | # 尝试修复JSON | 244 | # 尝试修复JSON |
| 244 | fixed_json = fix_incomplete_json(cleaned_output) | 245 | fixed_json = fix_incomplete_json(cleaned_output) |
| 245 | if fixed_json: | 246 | if fixed_json: |
| 246 | try: | 247 | try: |
| 247 | result = json.loads(fixed_json) | 248 | result = json.loads(fixed_json) |
| 248 | - self.log_info("JSON修复成功") | 249 | + logger.info("JSON修复成功") |
| 249 | except JSONDecodeError: | 250 | except JSONDecodeError: |
| 250 | - self.log_error("JSON修复失败") | 251 | + logger.error("JSON修复失败") |
| 251 | # 返回默认查询 | 252 | # 返回默认查询 |
| 252 | return self._get_default_reflection_query() | 253 | return self._get_default_reflection_query() |
| 253 | else: | 254 | else: |
| 254 | - self.log_error("无法修复JSON,使用默认查询") | 255 | + logger.error("无法修复JSON,使用默认查询") |
| 255 | return self._get_default_reflection_query() | 256 | return self._get_default_reflection_query() |
| 256 | 257 | ||
| 257 | # 验证和清理结果 | 258 | # 验证和清理结果 |
| @@ -259,7 +260,7 @@ class ReflectionNode(BaseNode): | @@ -259,7 +260,7 @@ class ReflectionNode(BaseNode): | ||
| 259 | reasoning = result.get("reasoning", "") | 260 | reasoning = result.get("reasoning", "") |
| 260 | 261 | ||
| 261 | if not search_query: | 262 | if not search_query: |
| 262 | - self.log_warning("未找到搜索查询,使用默认查询") | 263 | + logger.warning("未找到搜索查询,使用默认查询") |
| 263 | return self._get_default_reflection_query() | 264 | return self._get_default_reflection_query() |
| 264 | 265 | ||
| 265 | return { | 266 | return { |
| @@ -268,7 +269,7 @@ class ReflectionNode(BaseNode): | @@ -268,7 +269,7 @@ class ReflectionNode(BaseNode): | ||
| 268 | } | 269 | } |
| 269 | 270 | ||
| 270 | except Exception as e: | 271 | except Exception as e: |
| 271 | - self.log_error(f"处理输出失败: {str(e)}") | 272 | + logger.exception(f"处理输出失败: {str(e)}") |
| 272 | # 返回默认查询 | 273 | # 返回默认查询 |
| 273 | return self._get_default_reflection_query() | 274 | return self._get_default_reflection_query() |
| 274 | 275 |
| @@ -6,6 +6,7 @@ | @@ -6,6 +6,7 @@ | ||
| 6 | import json | 6 | import json |
| 7 | from typing import Dict, Any, List | 7 | from typing import Dict, Any, List |
| 8 | from json.decoder import JSONDecodeError | 8 | from json.decoder import JSONDecodeError |
| 9 | +from loguru import logger | ||
| 9 | 10 | ||
| 10 | from .base_node import StateMutationNode | 11 | from .base_node import StateMutationNode |
| 11 | from ..state.state import State | 12 | from ..state.state import State |
| @@ -27,7 +28,7 @@ try: | @@ -27,7 +28,7 @@ try: | ||
| 27 | FORUM_READER_AVAILABLE = True | 28 | FORUM_READER_AVAILABLE = True |
| 28 | except ImportError: | 29 | except ImportError: |
| 29 | FORUM_READER_AVAILABLE = False | 30 | FORUM_READER_AVAILABLE = False |
| 30 | - print("警告: 无法导入forum_reader模块,将跳过HOST发言读取功能") | 31 | + logger.warning("无法导入forum_reader模块,将跳过HOST发言读取功能") |
| 31 | 32 | ||
| 32 | 33 | ||
| 33 | class FirstSummaryNode(StateMutationNode): | 34 | class FirstSummaryNode(StateMutationNode): |
| @@ -84,9 +85,9 @@ class FirstSummaryNode(StateMutationNode): | @@ -84,9 +85,9 @@ class FirstSummaryNode(StateMutationNode): | ||
| 84 | if host_speech: | 85 | if host_speech: |
| 85 | # 将HOST发言添加到输入数据中 | 86 | # 将HOST发言添加到输入数据中 |
| 86 | data['host_speech'] = host_speech | 87 | data['host_speech'] = host_speech |
| 87 | - self.log_info(f"已读取HOST发言,长度: {len(host_speech)}字符") | 88 | + logger.info(f"已读取HOST发言,长度: {len(host_speech)}字符") |
| 88 | except Exception as e: | 89 | except Exception as e: |
| 89 | - self.log_info(f"读取HOST发言失败: {str(e)}") | 90 | + logger.exception(f"读取HOST发言失败: {str(e)}") |
| 90 | 91 | ||
| 91 | # 转换为JSON字符串 | 92 | # 转换为JSON字符串 |
| 92 | message = json.dumps(data, ensure_ascii=False) | 93 | message = json.dumps(data, ensure_ascii=False) |
| @@ -96,7 +97,7 @@ class FirstSummaryNode(StateMutationNode): | @@ -96,7 +97,7 @@ class FirstSummaryNode(StateMutationNode): | ||
| 96 | formatted_host = format_host_speech_for_prompt(data['host_speech']) | 97 | formatted_host = format_host_speech_for_prompt(data['host_speech']) |
| 97 | message = formatted_host + "\n" + message | 98 | message = formatted_host + "\n" + message |
| 98 | 99 | ||
| 99 | - self.log_info("正在生成首次段落总结") | 100 | + logger.info("正在生成首次段落总结") |
| 100 | 101 | ||
| 101 | # 调用LLM生成总结 | 102 | # 调用LLM生成总结 |
| 102 | response = self.llm_client.invoke( | 103 | response = self.llm_client.invoke( |
| @@ -107,11 +108,11 @@ class FirstSummaryNode(StateMutationNode): | @@ -107,11 +108,11 @@ class FirstSummaryNode(StateMutationNode): | ||
| 107 | # 处理响应 | 108 | # 处理响应 |
| 108 | processed_response = self.process_output(response) | 109 | processed_response = self.process_output(response) |
| 109 | 110 | ||
| 110 | - self.log_info("成功生成首次段落总结") | 111 | + logger.info("成功生成首次段落总结") |
| 111 | return processed_response | 112 | return processed_response |
| 112 | 113 | ||
| 113 | except Exception as e: | 114 | except Exception as e: |
| 114 | - self.log_error(f"生成首次总结失败: {str(e)}") | 115 | + logger.exception(f"生成首次总结失败: {str(e)}") |
| 115 | raise e | 116 | raise e |
| 116 | 117 | ||
| 117 | def process_output(self, output: str) -> str: | 118 | def process_output(self, output: str) -> str: |
| @@ -130,26 +131,26 @@ class FirstSummaryNode(StateMutationNode): | @@ -130,26 +131,26 @@ class FirstSummaryNode(StateMutationNode): | ||
| 130 | cleaned_output = clean_json_tags(cleaned_output) | 131 | cleaned_output = clean_json_tags(cleaned_output) |
| 131 | 132 | ||
| 132 | # 记录清理后的输出用于调试 | 133 | # 记录清理后的输出用于调试 |
| 133 | - self.log_info(f"清理后的输出: {cleaned_output}") | 134 | + logger.info(f"清理后的输出: {cleaned_output}") |
| 134 | 135 | ||
| 135 | # 解析JSON | 136 | # 解析JSON |
| 136 | try: | 137 | try: |
| 137 | result = json.loads(cleaned_output) | 138 | result = json.loads(cleaned_output) |
| 138 | - self.log_info("JSON解析成功") | 139 | + logger.info("JSON解析成功") |
| 139 | except JSONDecodeError as e: | 140 | except JSONDecodeError as e: |
| 140 | - self.log_info(f"JSON解析失败: {str(e)}") | 141 | + logger.exception(f"JSON解析失败: {str(e)}") |
| 141 | # 尝试修复JSON | 142 | # 尝试修复JSON |
| 142 | fixed_json = fix_incomplete_json(cleaned_output) | 143 | fixed_json = fix_incomplete_json(cleaned_output) |
| 143 | if fixed_json: | 144 | if fixed_json: |
| 144 | try: | 145 | try: |
| 145 | result = json.loads(fixed_json) | 146 | result = json.loads(fixed_json) |
| 146 | - self.log_info("JSON修复成功") | 147 | + logger.info("JSON修复成功") |
| 147 | except JSONDecodeError: | 148 | except JSONDecodeError: |
| 148 | - self.log_info("JSON修复失败,直接使用清理后的文本") | 149 | + logger.exception("JSON修复失败,直接使用清理后的文本") |
| 149 | # 如果不是JSON格式,直接返回清理后的文本 | 150 | # 如果不是JSON格式,直接返回清理后的文本 |
| 150 | return cleaned_output | 151 | return cleaned_output |
| 151 | else: | 152 | else: |
| 152 | - self.log_info("无法修复JSON,直接使用清理后的文本") | 153 | + logger.exception("无法修复JSON,直接使用清理后的文本") |
| 153 | # 如果不是JSON格式,直接返回清理后的文本 | 154 | # 如果不是JSON格式,直接返回清理后的文本 |
| 154 | return cleaned_output | 155 | return cleaned_output |
| 155 | 156 | ||
| @@ -163,7 +164,7 @@ class FirstSummaryNode(StateMutationNode): | @@ -163,7 +164,7 @@ class FirstSummaryNode(StateMutationNode): | ||
| 163 | return cleaned_output | 164 | return cleaned_output |
| 164 | 165 | ||
| 165 | except Exception as e: | 166 | except Exception as e: |
| 166 | - self.log_error(f"处理输出失败: {str(e)}") | 167 | + logger.exception(f"处理输出失败: {str(e)}") |
| 167 | return "段落总结生成失败" | 168 | return "段落总结生成失败" |
| 168 | 169 | ||
| 169 | def mutate_state(self, input_data: Any, state: State, paragraph_index: int, **kwargs) -> State: | 170 | def mutate_state(self, input_data: Any, state: State, paragraph_index: int, **kwargs) -> State: |
| @@ -186,7 +187,7 @@ class FirstSummaryNode(StateMutationNode): | @@ -186,7 +187,7 @@ class FirstSummaryNode(StateMutationNode): | ||
| 186 | # 更新状态 | 187 | # 更新状态 |
| 187 | if 0 <= paragraph_index < len(state.paragraphs): | 188 | if 0 <= paragraph_index < len(state.paragraphs): |
| 188 | state.paragraphs[paragraph_index].research.latest_summary = summary | 189 | state.paragraphs[paragraph_index].research.latest_summary = summary |
| 189 | - self.log_info(f"已更新段落 {paragraph_index} 的首次总结") | 190 | + logger.info(f"已更新段落 {paragraph_index} 的首次总结") |
| 190 | else: | 191 | else: |
| 191 | raise ValueError(f"段落索引 {paragraph_index} 超出范围") | 192 | raise ValueError(f"段落索引 {paragraph_index} 超出范围") |
| 192 | 193 | ||
| @@ -194,7 +195,7 @@ class FirstSummaryNode(StateMutationNode): | @@ -194,7 +195,7 @@ class FirstSummaryNode(StateMutationNode): | ||
| 194 | return state | 195 | return state |
| 195 | 196 | ||
| 196 | except Exception as e: | 197 | except Exception as e: |
| 197 | - self.log_error(f"状态更新失败: {str(e)}") | 198 | + logger.exception(f"状态更新失败: {str(e)}") |
| 198 | raise e | 199 | raise e |
| 199 | 200 | ||
| 200 | 201 | ||
| @@ -252,9 +253,9 @@ class ReflectionSummaryNode(StateMutationNode): | @@ -252,9 +253,9 @@ class ReflectionSummaryNode(StateMutationNode): | ||
| 252 | if host_speech: | 253 | if host_speech: |
| 253 | # 将HOST发言添加到输入数据中 | 254 | # 将HOST发言添加到输入数据中 |
| 254 | data['host_speech'] = host_speech | 255 | data['host_speech'] = host_speech |
| 255 | - self.log_info(f"已读取HOST发言,长度: {len(host_speech)}字符") | 256 | + logger.info(f"已读取HOST发言,长度: {len(host_speech)}字符") |
| 256 | except Exception as e: | 257 | except Exception as e: |
| 257 | - self.log_info(f"读取HOST发言失败: {str(e)}") | 258 | + logger.exception(f"读取HOST发言失败: {str(e)}") |
| 258 | 259 | ||
| 259 | # 转换为JSON字符串 | 260 | # 转换为JSON字符串 |
| 260 | message = json.dumps(data, ensure_ascii=False) | 261 | message = json.dumps(data, ensure_ascii=False) |
| @@ -264,7 +265,7 @@ class ReflectionSummaryNode(StateMutationNode): | @@ -264,7 +265,7 @@ class ReflectionSummaryNode(StateMutationNode): | ||
| 264 | formatted_host = format_host_speech_for_prompt(data['host_speech']) | 265 | formatted_host = format_host_speech_for_prompt(data['host_speech']) |
| 265 | message = formatted_host + "\n" + message | 266 | message = formatted_host + "\n" + message |
| 266 | 267 | ||
| 267 | - self.log_info("正在生成反思总结") | 268 | + logger.info("正在生成反思总结") |
| 268 | 269 | ||
| 269 | # 调用LLM生成总结 | 270 | # 调用LLM生成总结 |
| 270 | response = self.llm_client.invoke( | 271 | response = self.llm_client.invoke( |
| @@ -275,11 +276,11 @@ class ReflectionSummaryNode(StateMutationNode): | @@ -275,11 +276,11 @@ class ReflectionSummaryNode(StateMutationNode): | ||
| 275 | # 处理响应 | 276 | # 处理响应 |
| 276 | processed_response = self.process_output(response) | 277 | processed_response = self.process_output(response) |
| 277 | 278 | ||
| 278 | - self.log_info("成功生成反思总结") | 279 | + logger.info("成功生成反思总结") |
| 279 | return processed_response | 280 | return processed_response |
| 280 | 281 | ||
| 281 | except Exception as e: | 282 | except Exception as e: |
| 282 | - self.log_error(f"生成反思总结失败: {str(e)}") | 283 | + logger.exception(f"生成反思总结失败: {str(e)}") |
| 283 | raise e | 284 | raise e |
| 284 | 285 | ||
| 285 | def process_output(self, output: str) -> str: | 286 | def process_output(self, output: str) -> str: |
| @@ -298,26 +299,26 @@ class ReflectionSummaryNode(StateMutationNode): | @@ -298,26 +299,26 @@ class ReflectionSummaryNode(StateMutationNode): | ||
| 298 | cleaned_output = clean_json_tags(cleaned_output) | 299 | cleaned_output = clean_json_tags(cleaned_output) |
| 299 | 300 | ||
| 300 | # 记录清理后的输出用于调试 | 301 | # 记录清理后的输出用于调试 |
| 301 | - self.log_info(f"清理后的输出: {cleaned_output}") | 302 | + logger.info(f"清理后的输出: {cleaned_output}") |
| 302 | 303 | ||
| 303 | # 解析JSON | 304 | # 解析JSON |
| 304 | try: | 305 | try: |
| 305 | result = json.loads(cleaned_output) | 306 | result = json.loads(cleaned_output) |
| 306 | - self.log_info("JSON解析成功") | 307 | + logger.info("JSON解析成功") |
| 307 | except JSONDecodeError as e: | 308 | except JSONDecodeError as e: |
| 308 | - self.log_info(f"JSON解析失败: {str(e)}") | 309 | + logger.exception(f"JSON解析失败: {str(e)}") |
| 309 | # 尝试修复JSON | 310 | # 尝试修复JSON |
| 310 | fixed_json = fix_incomplete_json(cleaned_output) | 311 | fixed_json = fix_incomplete_json(cleaned_output) |
| 311 | if fixed_json: | 312 | if fixed_json: |
| 312 | try: | 313 | try: |
| 313 | result = json.loads(fixed_json) | 314 | result = json.loads(fixed_json) |
| 314 | - self.log_info("JSON修复成功") | 315 | + logger.info("JSON修复成功") |
| 315 | except JSONDecodeError: | 316 | except JSONDecodeError: |
| 316 | - self.log_info("JSON修复失败,直接使用清理后的文本") | 317 | + logger.exception("JSON修复失败,直接使用清理后的文本") |
| 317 | # 如果不是JSON格式,直接返回清理后的文本 | 318 | # 如果不是JSON格式,直接返回清理后的文本 |
| 318 | return cleaned_output | 319 | return cleaned_output |
| 319 | else: | 320 | else: |
| 320 | - self.log_info("无法修复JSON,直接使用清理后的文本") | 321 | + logger.exception("无法修复JSON,直接使用清理后的文本") |
| 321 | # 如果不是JSON格式,直接返回清理后的文本 | 322 | # 如果不是JSON格式,直接返回清理后的文本 |
| 322 | return cleaned_output | 323 | return cleaned_output |
| 323 | 324 | ||
| @@ -331,7 +332,7 @@ class ReflectionSummaryNode(StateMutationNode): | @@ -331,7 +332,7 @@ class ReflectionSummaryNode(StateMutationNode): | ||
| 331 | return cleaned_output | 332 | return cleaned_output |
| 332 | 333 | ||
| 333 | except Exception as e: | 334 | except Exception as e: |
| 334 | - self.log_error(f"处理输出失败: {str(e)}") | 335 | + logger.exception(f"处理输出失败: {str(e)}") |
| 335 | return "反思总结生成失败" | 336 | return "反思总结生成失败" |
| 336 | 337 | ||
| 337 | def mutate_state(self, input_data: Any, state: State, paragraph_index: int, **kwargs) -> State: | 338 | def mutate_state(self, input_data: Any, state: State, paragraph_index: int, **kwargs) -> State: |
| @@ -355,7 +356,7 @@ class ReflectionSummaryNode(StateMutationNode): | @@ -355,7 +356,7 @@ class ReflectionSummaryNode(StateMutationNode): | ||
| 355 | if 0 <= paragraph_index < len(state.paragraphs): | 356 | if 0 <= paragraph_index < len(state.paragraphs): |
| 356 | state.paragraphs[paragraph_index].research.latest_summary = updated_summary | 357 | state.paragraphs[paragraph_index].research.latest_summary = updated_summary |
| 357 | state.paragraphs[paragraph_index].research.increment_reflection() | 358 | state.paragraphs[paragraph_index].research.increment_reflection() |
| 358 | - self.log_info(f"已更新段落 {paragraph_index} 的反思总结") | 359 | + logger.info(f"已更新段落 {paragraph_index} 的反思总结") |
| 359 | else: | 360 | else: |
| 360 | raise ValueError(f"段落索引 {paragraph_index} 超出范围") | 361 | raise ValueError(f"段落索引 {paragraph_index} 超出范围") |
| 361 | 362 | ||
| @@ -363,5 +364,5 @@ class ReflectionSummaryNode(StateMutationNode): | @@ -363,5 +364,5 @@ class ReflectionSummaryNode(StateMutationNode): | ||
| 363 | return state | 364 | return state |
| 364 | 365 | ||
| 365 | except Exception as e: | 366 | except Exception as e: |
| 366 | - self.log_error(f"状态更新失败: {str(e)}") | 367 | + logger.exception(f"状态更新失败: {str(e)}") |
| 367 | raise e | 368 | raise e |
| @@ -25,6 +25,9 @@ import json | @@ -25,6 +25,9 @@ import json | ||
| 25 | import sys | 25 | import sys |
| 26 | from typing import List, Dict, Any, Optional, Literal | 26 | from typing import List, Dict, Any, Optional, Literal |
| 27 | 27 | ||
| 28 | +from loguru import logger | ||
| 29 | +from config import settings | ||
| 30 | + | ||
| 28 | # 运行前请确保已安装 requests 库: pip install requests | 31 | # 运行前请确保已安装 requests 库: pip install requests |
| 29 | try: | 32 | try: |
| 30 | import requests | 33 | import requests |
| @@ -90,8 +93,8 @@ class BochaMultimodalSearch: | @@ -90,8 +93,8 @@ class BochaMultimodalSearch: | ||
| 90 | 一个包含多种专用多模态搜索工具的客户端。 | 93 | 一个包含多种专用多模态搜索工具的客户端。 |
| 91 | 每个公共方法都设计为供 AI Agent 独立调用的工具。 | 94 | 每个公共方法都设计为供 AI Agent 独立调用的工具。 |
| 92 | """ | 95 | """ |
| 93 | - | ||
| 94 | - BASE_URL = "https://api.bochaai.com/v1/ai-search" | 96 | + |
| 97 | + BOCHA_BASE_URL = settings.BOCHA_BASE_URL or "https://api.bochaai.com/v1/ai-search" | ||
| 95 | 98 | ||
| 96 | def __init__(self, api_key: Optional[str] = None): | 99 | def __init__(self, api_key: Optional[str] = None): |
| 97 | """ | 100 | """ |
| @@ -100,10 +103,10 @@ class BochaMultimodalSearch: | @@ -100,10 +103,10 @@ class BochaMultimodalSearch: | ||
| 100 | api_key: Bocha API密钥,若不提供则从环境变量 BOCHA_API_KEY 读取。 | 103 | api_key: Bocha API密钥,若不提供则从环境变量 BOCHA_API_KEY 读取。 |
| 101 | """ | 104 | """ |
| 102 | if api_key is None: | 105 | if api_key is None: |
| 103 | - api_key = os.getenv("BOCHA_API_KEY") | 106 | + api_key = settings.BOCHA_WEB_SEARCH_API_KEY |
| 104 | if not api_key: | 107 | if not api_key: |
| 105 | raise ValueError("Bocha API Key未找到!请设置 BOCHA_API_KEY 环境变量或在初始化时提供") | 108 | raise ValueError("Bocha API Key未找到!请设置 BOCHA_API_KEY 环境变量或在初始化时提供") |
| 106 | - | 109 | + |
| 107 | self._headers = { | 110 | self._headers = { |
| 108 | 'Authorization': f'Bearer {api_key}', | 111 | 'Authorization': f'Bearer {api_key}', |
| 109 | 'Content-Type': 'application/json', | 112 | 'Content-Type': 'application/json', |
| @@ -112,7 +115,7 @@ class BochaMultimodalSearch: | @@ -112,7 +115,7 @@ class BochaMultimodalSearch: | ||
| 112 | 115 | ||
| 113 | def _parse_search_response(self, response_dict: Dict[str, Any], query: str) -> BochaResponse: | 116 | def _parse_search_response(self, response_dict: Dict[str, Any], query: str) -> BochaResponse: |
| 114 | """从API的原始字典响应中解析出结构化的BochaResponse对象""" | 117 | """从API的原始字典响应中解析出结构化的BochaResponse对象""" |
| 115 | - | 118 | + |
| 116 | final_response = BochaResponse(query=query) | 119 | final_response = BochaResponse(query=query) |
| 117 | final_response.conversation_id = response_dict.get('conversation_id') | 120 | final_response.conversation_id = response_dict.get('conversation_id') |
| 118 | 121 | ||
| @@ -125,7 +128,7 @@ class BochaMultimodalSearch: | @@ -125,7 +128,7 @@ class BochaMultimodalSearch: | ||
| 125 | msg_type = msg.get('type') | 128 | msg_type = msg.get('type') |
| 126 | content_type = msg.get('content_type') | 129 | content_type = msg.get('content_type') |
| 127 | content_str = msg.get('content', '{}') | 130 | content_str = msg.get('content', '{}') |
| 128 | - | 131 | + |
| 129 | try: | 132 | try: |
| 130 | content_data = json.loads(content_str) | 133 | content_data = json.loads(content_str) |
| 131 | except json.JSONDecodeError: | 134 | except json.JSONDecodeError: |
| @@ -134,7 +137,7 @@ class BochaMultimodalSearch: | @@ -134,7 +137,7 @@ class BochaMultimodalSearch: | ||
| 134 | 137 | ||
| 135 | if msg_type == 'answer' and content_type == 'text': | 138 | if msg_type == 'answer' and content_type == 'text': |
| 136 | final_response.answer = content_data | 139 | final_response.answer = content_data |
| 137 | - | 140 | + |
| 138 | elif msg_type == 'follow_up' and content_type == 'text': | 141 | elif msg_type == 'follow_up' and content_type == 'text': |
| 139 | final_response.follow_ups.append(content_data) | 142 | final_response.follow_ups.append(content_data) |
| 140 | 143 | ||
| @@ -164,7 +167,7 @@ class BochaMultimodalSearch: | @@ -164,7 +167,7 @@ class BochaMultimodalSearch: | ||
| 164 | card_type=content_type, | 167 | card_type=content_type, |
| 165 | content=content_data | 168 | content=content_data |
| 166 | )) | 169 | )) |
| 167 | - | 170 | + |
| 168 | return final_response | 171 | return final_response |
| 169 | 172 | ||
| 170 | 173 | ||
| @@ -176,23 +179,23 @@ class BochaMultimodalSearch: | @@ -176,23 +179,23 @@ class BochaMultimodalSearch: | ||
| 176 | "stream": False, # Agent工具通常使用非流式以获取完整结果 | 179 | "stream": False, # Agent工具通常使用非流式以获取完整结果 |
| 177 | } | 180 | } |
| 178 | payload.update(kwargs) | 181 | payload.update(kwargs) |
| 179 | - | 182 | + |
| 180 | try: | 183 | try: |
| 181 | - response = requests.post(self.BASE_URL, headers=self._headers, json=payload, timeout=30) | 184 | + response = requests.post(self.BOCHA_BASE_URL, headers=self._headers, json=payload, timeout=30) |
| 182 | response.raise_for_status() # 如果HTTP状态码是4xx或5xx,则抛出异常 | 185 | response.raise_for_status() # 如果HTTP状态码是4xx或5xx,则抛出异常 |
| 183 | - | 186 | + |
| 184 | response_dict = response.json() | 187 | response_dict = response.json() |
| 185 | if response_dict.get("code") != 200: | 188 | if response_dict.get("code") != 200: |
| 186 | - print(f"API返回错误: {response_dict.get('msg', '未知错误')}") | 189 | + logger.error(f"API返回错误: {response_dict.get('msg', '未知错误')}") |
| 187 | return BochaResponse(query=query) | 190 | return BochaResponse(query=query) |
| 188 | 191 | ||
| 189 | return self._parse_search_response(response_dict, query) | 192 | return self._parse_search_response(response_dict, query) |
| 190 | 193 | ||
| 191 | except requests.exceptions.RequestException as e: | 194 | except requests.exceptions.RequestException as e: |
| 192 | - print(f"搜索时发生网络错误: {str(e)}") | 195 | + logger.exception(f"搜索时发生网络错误: {str(e)}") |
| 193 | raise e # 让重试机制捕获并处理 | 196 | raise e # 让重试机制捕获并处理 |
| 194 | except Exception as e: | 197 | except Exception as e: |
| 195 | - print(f"处理响应时发生未知错误: {str(e)}") | 198 | + logger.exception(f"处理响应时发生未知错误: {str(e)}") |
| 196 | raise e # 让重试机制捕获并处理 | 199 | raise e # 让重试机制捕获并处理 |
| 197 | 200 | ||
| 198 | # --- Agent 可用的工具方法 --- | 201 | # --- Agent 可用的工具方法 --- |
| @@ -203,19 +206,19 @@ class BochaMultimodalSearch: | @@ -203,19 +206,19 @@ class BochaMultimodalSearch: | ||
| 203 | 返回网页、图片、AI总结、追问建议和可能的模态卡。这是最常用的通用搜索工具。 | 206 | 返回网页、图片、AI总结、追问建议和可能的模态卡。这是最常用的通用搜索工具。 |
| 204 | Agent可提供搜索查询(query)和可选的最大结果数(max_results)。 | 207 | Agent可提供搜索查询(query)和可选的最大结果数(max_results)。 |
| 205 | """ | 208 | """ |
| 206 | - print(f"--- TOOL: 全面综合搜索 (query: {query}) ---") | 209 | + logger.info(f"--- TOOL: 全面综合搜索 (query: {query}) ---") |
| 207 | return self._search_internal( | 210 | return self._search_internal( |
| 208 | query=query, | 211 | query=query, |
| 209 | count=max_results, | 212 | count=max_results, |
| 210 | answer=True # 开启AI总结 | 213 | answer=True # 开启AI总结 |
| 211 | ) | 214 | ) |
| 212 | - | 215 | + |
| 213 | def web_search_only(self, query: str, max_results: int = 15) -> BochaResponse: | 216 | def web_search_only(self, query: str, max_results: int = 15) -> BochaResponse: |
| 214 | """ | 217 | """ |
| 215 | 【工具】纯网页搜索: 只获取网页链接和摘要,不请求AI生成答案。 | 218 | 【工具】纯网页搜索: 只获取网页链接和摘要,不请求AI生成答案。 |
| 216 | 适用于需要快速获取原始网页信息,而不需要AI额外分析的场景。速度更快,成本更低。 | 219 | 适用于需要快速获取原始网页信息,而不需要AI额外分析的场景。速度更快,成本更低。 |
| 217 | """ | 220 | """ |
| 218 | - print(f"--- TOOL: 纯网页搜索 (query: {query}) ---") | 221 | + logger.info(f"--- TOOL: 纯网页搜索 (query: {query}) ---") |
| 219 | return self._search_internal( | 222 | return self._search_internal( |
| 220 | query=query, | 223 | query=query, |
| 221 | count=max_results, | 224 | count=max_results, |
| @@ -228,7 +231,7 @@ class BochaMultimodalSearch: | @@ -228,7 +231,7 @@ class BochaMultimodalSearch: | ||
| 228 | 当Agent意图是查询天气、股票、汇率、百科定义、火车票、汽车参数等结构化信息时,应优先使用此工具。 | 231 | 当Agent意图是查询天气、股票、汇率、百科定义、火车票、汽车参数等结构化信息时,应优先使用此工具。 |
| 229 | 它会返回所有信息,但Agent应重点关注结果中的 `modal_cards` 部分。 | 232 | 它会返回所有信息,但Agent应重点关注结果中的 `modal_cards` 部分。 |
| 230 | """ | 233 | """ |
| 231 | - print(f"--- TOOL: 结构化数据查询 (query: {query}) ---") | 234 | + logger.info(f"--- TOOL: 结构化数据查询 (query: {query}) ---") |
| 232 | # 实现上与 comprehensive_search 相同,但通过命名和文档引导Agent的意图 | 235 | # 实现上与 comprehensive_search 相同,但通过命名和文档引导Agent的意图 |
| 233 | return self._search_internal( | 236 | return self._search_internal( |
| 234 | query=query, | 237 | query=query, |
| @@ -241,7 +244,7 @@ class BochaMultimodalSearch: | @@ -241,7 +244,7 @@ class BochaMultimodalSearch: | ||
| 241 | 【工具】搜索24小时内信息: 获取关于某个主题的最新动态。 | 244 | 【工具】搜索24小时内信息: 获取关于某个主题的最新动态。 |
| 242 | 此工具专门查找过去24小时内发布的内容。适用于追踪突发事件或最新进展。 | 245 | 此工具专门查找过去24小时内发布的内容。适用于追踪突发事件或最新进展。 |
| 243 | """ | 246 | """ |
| 244 | - print(f"--- TOOL: 搜索24小时内信息 (query: {query}) ---") | 247 | + logger.info(f"--- TOOL: 搜索24小时内信息 (query: {query}) ---") |
| 245 | return self._search_internal(query=query, freshness='oneDay', answer=True) | 248 | return self._search_internal(query=query, freshness='oneDay', answer=True) |
| 246 | 249 | ||
| 247 | def search_last_week(self, query: str) -> BochaResponse: | 250 | def search_last_week(self, query: str) -> BochaResponse: |
| @@ -249,7 +252,7 @@ class BochaMultimodalSearch: | @@ -249,7 +252,7 @@ class BochaMultimodalSearch: | ||
| 249 | 【工具】搜索本周信息: 获取关于某个主题过去一周内的主要报道。 | 252 | 【工具】搜索本周信息: 获取关于某个主题过去一周内的主要报道。 |
| 250 | 适用于进行周度舆情总结或回顾。 | 253 | 适用于进行周度舆情总结或回顾。 |
| 251 | """ | 254 | """ |
| 252 | - print(f"--- TOOL: 搜索本周信息 (query: {query}) ---") | 255 | + logger.info(f"--- TOOL: 搜索本周信息 (query: {query}) ---") |
| 253 | return self._search_internal(query=query, freshness='oneWeek', answer=True) | 256 | return self._search_internal(query=query, freshness='oneWeek', answer=True) |
| 254 | 257 | ||
| 255 | 258 | ||
| @@ -258,32 +261,32 @@ class BochaMultimodalSearch: | @@ -258,32 +261,32 @@ class BochaMultimodalSearch: | ||
| 258 | def print_response_summary(response: BochaResponse): | 261 | def print_response_summary(response: BochaResponse): |
| 259 | """简化的打印函数,用于展示测试结果""" | 262 | """简化的打印函数,用于展示测试结果""" |
| 260 | if not response or not response.query: | 263 | if not response or not response.query: |
| 261 | - print("未能获取有效响应。") | 264 | + logger.error("未能获取有效响应。") |
| 262 | return | 265 | return |
| 263 | - | ||
| 264 | - print(f"\n查询: '{response.query}' | 会话ID: {response.conversation_id}") | 266 | + |
| 267 | + logger.info(f"\n查询: '{response.query}' | 会话ID: {response.conversation_id}") | ||
| 265 | if response.answer: | 268 | if response.answer: |
| 266 | - print(f"AI摘要: {response.answer[:150]}...") | ||
| 267 | - | ||
| 268 | - print(f"找到 {len(response.webpages)} 个网页, {len(response.images)} 张图片, {len(response.modal_cards)} 个模态卡。") | 269 | + logger.info(f"AI摘要: {response.answer[:150]}...") |
| 270 | + | ||
| 271 | + logger.info(f"找到 {len(response.webpages)} 个网页, {len(response.images)} 张图片, {len(response.modal_cards)} 个模态卡。") | ||
| 269 | 272 | ||
| 270 | if response.modal_cards: | 273 | if response.modal_cards: |
| 271 | first_card = response.modal_cards[0] | 274 | first_card = response.modal_cards[0] |
| 272 | - print(f"第一个模态卡类型: {first_card.card_type}") | 275 | + logger.info(f"第一个模态卡类型: {first_card.card_type}") |
| 273 | 276 | ||
| 274 | if response.webpages: | 277 | if response.webpages: |
| 275 | first_result = response.webpages[0] | 278 | first_result = response.webpages[0] |
| 276 | - print(f"第一条网页结果: {first_result.name}") | 279 | + logger.info(f"第一条网页结果: {first_result.name}") |
| 277 | 280 | ||
| 278 | if response.follow_ups: | 281 | if response.follow_ups: |
| 279 | - print(f"建议追问: {response.follow_ups}") | 282 | + logger.info(f"建议追问: {response.follow_ups}") |
| 280 | 283 | ||
| 281 | - print("-" * 60) | 284 | + logger.info("-" * 60) |
| 282 | 285 | ||
| 283 | 286 | ||
| 284 | if __name__ == "__main__": | 287 | if __name__ == "__main__": |
| 285 | # 在运行前,请确保您已设置 BOCHA_API_KEY 环境变量 | 288 | # 在运行前,请确保您已设置 BOCHA_API_KEY 环境变量 |
| 286 | - | 289 | + |
| 287 | try: | 290 | try: |
| 288 | # 初始化多模态搜索客户端,它内部包含了所有工具 | 291 | # 初始化多模态搜索客户端,它内部包含了所有工具 |
| 289 | search_client = BochaMultimodalSearch() | 292 | search_client = BochaMultimodalSearch() |
| @@ -297,7 +300,7 @@ if __name__ == "__main__": | @@ -297,7 +300,7 @@ if __name__ == "__main__": | ||
| 297 | print_response_summary(response2) | 300 | print_response_summary(response2) |
| 298 | # 深度解析第一个模态卡 | 301 | # 深度解析第一个模态卡 |
| 299 | if response2.modal_cards and response2.modal_cards[0].card_type == 'weather_china': | 302 | if response2.modal_cards and response2.modal_cards[0].card_type == 'weather_china': |
| 300 | - print("天气模态卡详情:", json.dumps(response2.modal_cards[0].content, indent=2, ensure_ascii=False)) | 303 | + logger.info("天气模态卡详情:", json.dumps(response2.modal_cards[0].content, indent=2, ensure_ascii=False)) |
| 301 | 304 | ||
| 302 | 305 | ||
| 303 | # 场景3: Agent需要查询特定结构化信息 - 股票 | 306 | # 场景3: Agent需要查询特定结构化信息 - 股票 |
| @@ -311,11 +314,11 @@ if __name__ == "__main__": | @@ -311,11 +314,11 @@ if __name__ == "__main__": | ||
| 311 | # 场景5: Agent只需要快速获取网页信息,不需要AI总结 | 314 | # 场景5: Agent只需要快速获取网页信息,不需要AI总结 |
| 312 | response5 = search_client.web_search_only(query="Python dataclasses用法") | 315 | response5 = search_client.web_search_only(query="Python dataclasses用法") |
| 313 | print_response_summary(response5) | 316 | print_response_summary(response5) |
| 314 | - | 317 | + |
| 315 | # 场景6: Agent需要回顾一周内关于某项技术的新闻 | 318 | # 场景6: Agent需要回顾一周内关于某项技术的新闻 |
| 316 | response6 = search_client.search_last_week(query="量子计算商业化") | 319 | response6 = search_client.search_last_week(query="量子计算商业化") |
| 317 | print_response_summary(response6) | 320 | print_response_summary(response6) |
| 318 | - | 321 | + |
| 319 | '''下面是测试程序的输出: | 322 | '''下面是测试程序的输出: |
| 320 | --- TOOL: 全面综合搜索 (query: 人工智能对未来教育的影响) --- | 323 | --- TOOL: 全面综合搜索 (query: 人工智能对未来教育的影响) --- |
| 321 | 324 | ||
| @@ -381,7 +384,7 @@ AI摘要: 量子计算商业化正在逐步推进。 | @@ -381,7 +384,7 @@ AI摘要: 量子计算商业化正在逐步推进。 | ||
| 381 | ------------------------------------------------------------''' | 384 | ------------------------------------------------------------''' |
| 382 | 385 | ||
| 383 | except ValueError as e: | 386 | except ValueError as e: |
| 384 | - print(f"初始化失败: {e}") | ||
| 385 | - print("请确保 BOCHA_API_KEY 环境变量已正确设置。") | 387 | + logger.exception(f"初始化失败: {e}") |
| 388 | + logger.error("请确保 BOCHA_API_KEY 环境变量已正确设置。") | ||
| 386 | except Exception as e: | 389 | except Exception as e: |
| 387 | - print(f"测试过程中发生未知错误: {e}") | ||
| 390 | + logger.exception(f"测试过程中发生未知错误: {e}") |
| @@ -12,15 +12,15 @@ from .text_processing import ( | @@ -12,15 +12,15 @@ from .text_processing import ( | ||
| 12 | format_search_results_for_prompt | 12 | format_search_results_for_prompt |
| 13 | ) | 13 | ) |
| 14 | 14 | ||
| 15 | -from .config import Config, load_config | 15 | +from .config import Settings, settings |
| 16 | 16 | ||
| 17 | __all__ = [ | 17 | __all__ = [ |
| 18 | "clean_json_tags", | 18 | "clean_json_tags", |
| 19 | "clean_markdown_tags", | 19 | "clean_markdown_tags", |
| 20 | - "remove_reasoning_from_output", | 20 | + "remove_reasoning_from_output", |
| 21 | "extract_clean_response", | 21 | "extract_clean_response", |
| 22 | "update_state_with_search_results", | 22 | "update_state_with_search_results", |
| 23 | "format_search_results_for_prompt", | 23 | "format_search_results_for_prompt", |
| 24 | - "Config", | ||
| 25 | - "load_config" | 24 | + "Settings", |
| 25 | + "settings" | ||
| 26 | ] | 26 | ] |
| 1 | """ | 1 | """ |
| 2 | -Configuration management module for the Media Engine. | 2 | +Configuration management module for the Media Engine (pydantic_settings style). |
| 3 | """ | 3 | """ |
| 4 | 4 | ||
| 5 | -import os | ||
| 6 | -from dataclasses import dataclass | 5 | +from pathlib import Path |
| 6 | +from pydantic_settings import BaseSettings | ||
| 7 | +from pydantic import Field | ||
| 7 | from typing import Optional | 8 | from typing import Optional |
| 8 | 9 | ||
| 9 | 10 | ||
| 10 | -def _get_value(source, key: str, default=None, *fallback_keys: str): | ||
| 11 | - candidates = (key,) + fallback_keys | ||
| 12 | - value = None | ||
| 13 | - for candidate in candidates: | ||
| 14 | - if isinstance(source, dict): | ||
| 15 | - value = source.get(candidate) | ||
| 16 | - else: | ||
| 17 | - value = getattr(source, candidate, None) | ||
| 18 | - if value not in (None, ""): | ||
| 19 | - break | ||
| 20 | - if value in (None, ""): | ||
| 21 | - for candidate in candidates: | ||
| 22 | - env_val = os.getenv(candidate) | ||
| 23 | - if env_val not in (None, ""): | ||
| 24 | - value = env_val | ||
| 25 | - break | ||
| 26 | - return value if value not in (None, "") else default | ||
| 27 | - | ||
| 28 | - | ||
| 29 | -@dataclass | ||
| 30 | -class Config: | ||
| 31 | - """Media Engine configuration.""" | ||
| 32 | - | ||
| 33 | - llm_api_key: Optional[str] = None | ||
| 34 | - llm_base_url: Optional[str] = None | ||
| 35 | - llm_model_name: Optional[str] = None | ||
| 36 | - llm_provider: Optional[str] = None # compatibility | ||
| 37 | - | ||
| 38 | - bocha_api_key: Optional[str] = None | ||
| 39 | - | ||
| 40 | - search_timeout: int = 240 | ||
| 41 | - max_content_length: int = 20000 | ||
| 42 | - max_reflections: int = 2 | ||
| 43 | - max_paragraphs: int = 5 | ||
| 44 | - | ||
| 45 | - output_dir: str = "reports" | ||
| 46 | - save_intermediate_states: bool = True | ||
| 47 | - | ||
| 48 | - def __post_init__(self): | ||
| 49 | - if not self.llm_provider and self.llm_model_name: | ||
| 50 | - self.llm_provider = self.llm_model_name | ||
| 51 | - | ||
| 52 | - def validate(self) -> bool: | ||
| 53 | - if not self.llm_api_key: | ||
| 54 | - print("错误: Media Engine LLM API Key 未设置 (MEDIA_ENGINE_API_KEY)。") | ||
| 55 | - return False | ||
| 56 | - if not self.llm_model_name: | ||
| 57 | - print("错误: Media Engine 模型名称未设置 (MEDIA_ENGINE_MODEL_NAME)。") | ||
| 58 | - return False | ||
| 59 | - if not self.bocha_api_key: | ||
| 60 | - print("错误: Bocha API Key 未设置 (BOCHA_WEB_SEARCH_API_KEY)。") | ||
| 61 | - return False | ||
| 62 | - return True | ||
| 63 | - | ||
| 64 | - @classmethod | ||
| 65 | - def from_file(cls, config_file: str) -> "Config": | ||
| 66 | - if config_file.endswith(".py"): | ||
| 67 | - import importlib.util | ||
| 68 | - | ||
| 69 | - spec = importlib.util.spec_from_file_location("config", config_file) | ||
| 70 | - config_module = importlib.util.module_from_spec(spec) | ||
| 71 | - spec.loader.exec_module(config_module) | ||
| 72 | - | ||
| 73 | - return cls( | ||
| 74 | - llm_api_key=_get_value(config_module, "MEDIA_ENGINE_API_KEY"), | ||
| 75 | - llm_base_url=_get_value(config_module, "MEDIA_ENGINE_BASE_URL"), | ||
| 76 | - llm_model_name=_get_value(config_module, "MEDIA_ENGINE_MODEL_NAME"), | ||
| 77 | - bocha_api_key=_get_value( | ||
| 78 | - config_module, | ||
| 79 | - "BOCHA_WEB_SEARCH_API_KEY", | ||
| 80 | - None, | ||
| 81 | - "BOCHA_API_KEY", | ||
| 82 | - ), | ||
| 83 | - search_timeout=int(_get_value(config_module, "SEARCH_TIMEOUT", 240)), | ||
| 84 | - max_content_length=int(_get_value(config_module, "SEARCH_CONTENT_MAX_LENGTH", 20000)), | ||
| 85 | - max_reflections=int(_get_value(config_module, "MAX_REFLECTIONS", 2)), | ||
| 86 | - max_paragraphs=int(_get_value(config_module, "MAX_PARAGRAPHS", 5)), | ||
| 87 | - output_dir=_get_value(config_module, "OUTPUT_DIR", "reports"), | ||
| 88 | - save_intermediate_states=str( | ||
| 89 | - _get_value(config_module, "SAVE_INTERMEDIATE_STATES", "true") | ||
| 90 | - ).lower() | ||
| 91 | - in ("true", "1", "yes"), | ||
| 92 | - ) | ||
| 93 | - | ||
| 94 | - config_dict = {} | ||
| 95 | - if os.path.exists(config_file): | ||
| 96 | - with open(config_file, "r", encoding="utf-8") as f: | ||
| 97 | - for line in f: | ||
| 98 | - line = line.strip() | ||
| 99 | - if line and not line.startswith("#") and "=" in line: | ||
| 100 | - key, value = line.split("=", 1) | ||
| 101 | - config_dict[key.strip()] = value.strip() | ||
| 102 | - | ||
| 103 | - return cls( | ||
| 104 | - llm_api_key=_get_value(config_dict, "MEDIA_ENGINE_API_KEY"), | ||
| 105 | - llm_base_url=_get_value(config_dict, "MEDIA_ENGINE_BASE_URL"), | ||
| 106 | - llm_model_name=_get_value(config_dict, "MEDIA_ENGINE_MODEL_NAME"), | ||
| 107 | - bocha_api_key=_get_value( | ||
| 108 | - config_dict, | ||
| 109 | - "BOCHA_WEB_SEARCH_API_KEY", | ||
| 110 | - None, | ||
| 111 | - "BOCHA_API_KEY", | ||
| 112 | - ), | ||
| 113 | - search_timeout=int(_get_value(config_dict, "SEARCH_TIMEOUT", 240)), | ||
| 114 | - max_content_length=int(_get_value(config_dict, "SEARCH_CONTENT_MAX_LENGTH", 20000)), | ||
| 115 | - max_reflections=int(_get_value(config_dict, "MAX_REFLECTIONS", 2)), | ||
| 116 | - max_paragraphs=int(_get_value(config_dict, "MAX_PARAGRAPHS", 5)), | ||
| 117 | - output_dir=_get_value(config_dict, "OUTPUT_DIR", "reports"), | ||
| 118 | - save_intermediate_states=str( | ||
| 119 | - _get_value(config_dict, "SAVE_INTERMEDIATE_STATES", "true") | ||
| 120 | - ).lower() | ||
| 121 | - in ("true", "1", "yes"), | ||
| 122 | - ) | ||
| 123 | - | ||
| 124 | - | ||
| 125 | -def load_config(config_file: Optional[str] = None) -> Config: | ||
| 126 | - if config_file: | ||
| 127 | - if not os.path.exists(config_file): | ||
| 128 | - raise FileNotFoundError(f"配置文件不存在: {config_file}") | ||
| 129 | - file_to_load = config_file | ||
| 130 | - else: | ||
| 131 | - for candidate in ("config.py", "config.env", ".env"): | ||
| 132 | - if os.path.exists(candidate): | ||
| 133 | - file_to_load = candidate | ||
| 134 | - print(f"已找到配置文件: {candidate}") | ||
| 135 | - break | ||
| 136 | - else: | ||
| 137 | - raise FileNotFoundError("未找到配置文件,请创建 config.py。") | ||
| 138 | - | ||
| 139 | - config = Config.from_file(file_to_load) | ||
| 140 | - if not config.validate(): | ||
| 141 | - raise ValueError("配置校验失败,请检查 config.py 中的相关配置。") | ||
| 142 | - return config | ||
| 143 | - | ||
| 144 | - | ||
| 145 | -def print_config(config: Config): | ||
| 146 | - print("\n=== Media Engine 配置 ===") | ||
| 147 | - print(f"LLM 模型: {config.llm_model_name}") | ||
| 148 | - print(f"LLM Base URL: {config.llm_base_url or '(默认)'}") | ||
| 149 | - print(f"Bocha API Key: {'已配置' if config.bocha_api_key else '未配置'}") | ||
| 150 | - print(f"搜索超时: {config.search_timeout} 秒") | ||
| 151 | - print(f"最长内容长度: {config.max_content_length}") | ||
| 152 | - print(f"最大反思次数: {config.max_reflections}") | ||
| 153 | - print(f"最大段落数: {config.max_paragraphs}") | ||
| 154 | - print(f"输出目录: {config.output_dir}") | ||
| 155 | - print(f"保存中间状态: {config.save_intermediate_states}") | ||
| 156 | - print(f"LLM API Key: {'已配置' if config.llm_api_key else '未配置'}") | ||
| 157 | - print("========================\n") | 11 | +# 计算 .env 优先级:优先当前工作目录,其次项目根目录 |
| 12 | +PROJECT_ROOT: Path = Path(__file__).resolve().parents[2] | ||
| 13 | +CWD_ENV: Path = Path.cwd() / ".env" | ||
| 14 | +ENV_FILE: str = str(CWD_ENV if CWD_ENV.exists() else (PROJECT_ROOT / ".env")) | ||
| 15 | + | ||
| 16 | +class Settings(BaseSettings): | ||
| 17 | + """ | ||
| 18 | + 全局配置;支持 .env 和环境变量自动加载。 | ||
| 19 | + 变量名与原 config.py 大写一致,便于平滑过渡。 | ||
| 20 | + """ | ||
| 21 | + # ====================== 数据库配置 ====================== | ||
| 22 | + DB_HOST: str = Field("your_db_host", description="数据库主机,例如localhost 或 127.0.0.1。我们也提供云数据库资源便捷配置,日均10w+数据,可免费申请,联系我们:670939375@qq.com NOTE:为进行数据合规性审查与服务升级,云数据库自2025年10月1日起暂停接收新的使用申请") | ||
| 23 | + DB_PORT: int = Field(3306, description="数据库端口号,默认为3306") | ||
| 24 | + DB_USER: str = Field("your_db_user", description="数据库用户名") | ||
| 25 | + DB_PASSWORD: str = Field("your_db_password", description="数据库密码") | ||
| 26 | + DB_NAME: str = Field("your_db_name", description="数据库名称") | ||
| 27 | + DB_CHARSET: str = Field("utf8mb4", description="数据库字符集,推荐utf8mb4,兼容emoji") | ||
| 28 | + DB_DIALECT: str = Field("mysql", description="数据库类型,例如 'mysql' 或 'postgresql'。用于支持多种数据库后端(如 SQLAlchemy,请与连接信息共同配置)") | ||
| 29 | + | ||
| 30 | + # ======================= LLM 相关 ======================= | ||
| 31 | + INSIGHT_ENGINE_API_KEY: str = Field(None, description="Insight Agent(推荐Kimi,https://platform.moonshot.cn/)API密钥,用于主LLM。您可以更改每个部分LLM使用的API,🚩只要兼容OpenAI请求格式都可以,定义好KEY、BASE_URL与MODEL_NAME即可正常使用。重要提醒:我们强烈推荐您先使用推荐的配置申请API,先跑通再进行您的更改!") | ||
| 32 | + INSIGHT_ENGINE_BASE_URL: Optional[str] = Field("https://api.moonshot.cn/v1", description="Insight Agent LLM接口BaseUrl,可自定义厂商API") | ||
| 33 | + INSIGHT_ENGINE_MODEL_NAME: str = Field("kimi-k2-0711-preview", description="Insight Agent LLM模型名称,如kimi-k2-0711-preview") | ||
| 34 | + | ||
| 35 | + MEDIA_ENGINE_API_KEY: str = Field(None, description="Media Agent(推荐Gemini,这里我用了一个中转厂商,你也可以换成你自己的,申请地址:https://www.chataiapi.com/)API密钥") | ||
| 36 | + MEDIA_ENGINE_BASE_URL: Optional[str] = Field("https://www.chataiapi.com/v1", description="Media Agent LLM接口BaseUrl") | ||
| 37 | + MEDIA_ENGINE_MODEL_NAME: str = Field("gemini-2.5-pro", description="Media Agent LLM模型名称,如gemini-2.5-pro") | ||
| 38 | + | ||
| 39 | + BOCHA_WEB_SEARCH_API_KEY: Optional[str] = Field(None, description="Bocha Web Search API Key") | ||
| 40 | + BOCHA_API_KEY: Optional[str] = Field(None, description="Bocha 兼容键(别名)") | ||
| 41 | + | ||
| 42 | + SEARCH_TIMEOUT: int = Field(240, description="搜索超时(秒)") | ||
| 43 | + SEARCH_CONTENT_MAX_LENGTH: int = Field(20000, description="用于提示的最长内容长度") | ||
| 44 | + MAX_REFLECTIONS: int = Field(2, description="最大反思轮数") | ||
| 45 | + MAX_PARAGRAPHS: int = Field(5, description="最大段落数") | ||
| 46 | + | ||
| 47 | + MINDSPIDER_API_KEY: Optional[str] = Field(None, description="MindSpider API密钥") | ||
| 48 | + MINDSPIDER_BASE_URL: Optional[str] = Field("https://api.deepseek.com", description="MindSpider LLM接口BaseUrl") | ||
| 49 | + MINDSPIDER_MODEL_NAME: str = Field("deepseek-reasoner", description="MindSpider LLM模型名称,如deepseek-reasoner") | ||
| 50 | + | ||
| 51 | + OUTPUT_DIR: str = Field("reports", description="输出目录") | ||
| 52 | + SAVE_INTERMEDIATE_STATES: bool = Field(True, description="是否保存中间状态") | ||
| 53 | + | ||
| 54 | + | ||
| 55 | + QUERY_ENGINE_API_KEY: str = Field(None, description="Query Agent(推荐DeepSeek,https://www.deepseek.com/)API密钥") | ||
| 56 | + QUERY_ENGINE_BASE_URL: Optional[str] = Field("https://api.deepseek.com", description="Query Agent LLM接口BaseUrl") | ||
| 57 | + QUERY_ENGINE_MODEL_NAME: str = Field("deepseek-reasoner", description="Query Agent LLM模型,如deepseek-reasoner") | ||
| 58 | + | ||
| 59 | + REPORT_ENGINE_API_KEY: str = Field(None, description="Report Agent(推荐Gemini,这里我用了一个中转厂商,你也可以换成你自己的,申请地址:https://www.chataiapi.com/)API密钥") | ||
| 60 | + REPORT_ENGINE_BASE_URL: Optional[str] = Field("https://www.chataiapi.com/v1", description="Report Agent LLM接口BaseUrl") | ||
| 61 | + REPORT_ENGINE_MODEL_NAME: str = Field("gemini-2.5-pro", description="Report Agent LLM模型,如gemini-2.5-pro") | ||
| 62 | + | ||
| 63 | + FORUM_HOST_API_KEY: str = Field(None, description="Forum Host(Qwen3最新模型,这里我使用了硅基流动这个平台,申请地址:https://cloud.siliconflow.cn/)API密钥") | ||
| 64 | + FORUM_HOST_BASE_URL: Optional[str] = Field("https://api.siliconflow.cn/v1", description="Forum Host LLM BaseUrl") | ||
| 65 | + FORUM_HOST_MODEL_NAME: str = Field("Qwen/Qwen3-235B-A22B-Instruct-2507", description="Forum Host LLM模型名,如Qwen/Qwen3-235B-A22B-Instruct-2507") | ||
| 66 | + | ||
| 67 | + KEYWORD_OPTIMIZER_API_KEY: str = Field(None, description="SQL keyword Optimizer(小参数Qwen3模型,这里我使用了硅基流动这个平台,申请地址:https://cloud.siliconflow.cn/)API密钥") | ||
| 68 | + KEYWORD_OPTIMIZER_BASE_URL: Optional[str] = Field("https://api.siliconflow.cn/v1", description="Keyword Optimizer BaseUrl") | ||
| 69 | + KEYWORD_OPTIMIZER_MODEL_NAME: str = Field("Qwen/Qwen3-30B-A3B-Instruct-2507", description="Keyword Optimizer LLM模型名称,如Qwen/Qwen3-30B-A3B-Instruct-2507") | ||
| 70 | + | ||
| 71 | + # ================== 网络工具配置 ==================== | ||
| 72 | + TAVILY_API_KEY: str = Field(None, description="Tavily API(申请地址:https://www.tavily.com/)API密钥,用于Tavily网络搜索") | ||
| 73 | + BOCHA_BASE_URL: Optional[str] = Field("https://api.bochaai.com/v1/ai-search", description="Bocha AI 搜索BaseUrl或博查网页搜索BaseUrl") | ||
| 74 | + BOCHA_WEB_SEARCH_API_KEY: str = Field(None, description="Bocha API(申请地址:https://open.bochaai.com/)API密钥,用于Bocha搜索") | ||
| 75 | + | ||
| 76 | + class Config: | ||
| 77 | + env_file = ENV_FILE | ||
| 78 | + env_prefix = "" | ||
| 79 | + case_sensitive = False | ||
| 80 | + extra = "allow" | ||
| 81 | + | ||
| 82 | + | ||
| 83 | +settings = Settings() |
| @@ -4,9 +4,9 @@ Deep Search Agent | @@ -4,9 +4,9 @@ Deep Search Agent | ||
| 4 | """ | 4 | """ |
| 5 | 5 | ||
| 6 | from .agent import DeepSearchAgent, create_agent | 6 | from .agent import DeepSearchAgent, create_agent |
| 7 | -from .utils.config import Config, load_config | 7 | +from .utils.config import Settings |
| 8 | 8 | ||
| 9 | __version__ = "1.0.0" | 9 | __version__ = "1.0.0" |
| 10 | __author__ = "Deep Search Agent Team" | 10 | __author__ = "Deep Search Agent Team" |
| 11 | 11 | ||
| 12 | -__all__ = ["DeepSearchAgent", "create_agent", "Config", "load_config"] | 12 | +__all__ = ["DeepSearchAgent", "create_agent", "Settings"] |
| @@ -20,13 +20,13 @@ from .nodes import ( | @@ -20,13 +20,13 @@ from .nodes import ( | ||
| 20 | ) | 20 | ) |
| 21 | from .state import State | 21 | from .state import State |
| 22 | from .tools import TavilyNewsAgency, TavilyResponse | 22 | from .tools import TavilyNewsAgency, TavilyResponse |
| 23 | -from .utils import Config, load_config, format_search_results_for_prompt | ||
| 24 | - | 23 | +from .utils import Settings, format_search_results_for_prompt |
| 24 | +from loguru import logger | ||
| 25 | 25 | ||
| 26 | class DeepSearchAgent: | 26 | class DeepSearchAgent: |
| 27 | """Deep Search Agent主类""" | 27 | """Deep Search Agent主类""" |
| 28 | 28 | ||
| 29 | - def __init__(self, config: Optional[Config] = None): | 29 | + def __init__(self, config: Optional[Settings] = None): |
| 30 | """ | 30 | """ |
| 31 | 初始化Deep Search Agent | 31 | 初始化Deep Search Agent |
| 32 | 32 | ||
| @@ -34,14 +34,14 @@ class DeepSearchAgent: | @@ -34,14 +34,14 @@ class DeepSearchAgent: | ||
| 34 | config: 配置对象,如果不提供则自动加载 | 34 | config: 配置对象,如果不提供则自动加载 |
| 35 | """ | 35 | """ |
| 36 | # 加载配置 | 36 | # 加载配置 |
| 37 | - self.config = config or load_config() | ||
| 38 | - os.environ["TAVILY_API_KEY"] = self.config.tavily_api_key or "" | 37 | + from .utils.config import settings |
| 38 | + self.config = config or settings | ||
| 39 | 39 | ||
| 40 | # 初始化LLM客户端 | 40 | # 初始化LLM客户端 |
| 41 | self.llm_client = self._initialize_llm() | 41 | self.llm_client = self._initialize_llm() |
| 42 | 42 | ||
| 43 | # 初始化搜索工具集 | 43 | # 初始化搜索工具集 |
| 44 | - self.search_agency = TavilyNewsAgency(api_key=self.config.tavily_api_key) | 44 | + self.search_agency = TavilyNewsAgency(api_key=self.config.TAVILY_API_KEY) |
| 45 | 45 | ||
| 46 | # 初始化节点 | 46 | # 初始化节点 |
| 47 | self._initialize_nodes() | 47 | self._initialize_nodes() |
| @@ -50,18 +50,18 @@ class DeepSearchAgent: | @@ -50,18 +50,18 @@ class DeepSearchAgent: | ||
| 50 | self.state = State() | 50 | self.state = State() |
| 51 | 51 | ||
| 52 | # 确保输出目录存在 | 52 | # 确保输出目录存在 |
| 53 | - os.makedirs(self.config.output_dir, exist_ok=True) | 53 | + os.makedirs(self.config.OUTPUT_DIR, exist_ok=True) |
| 54 | 54 | ||
| 55 | - print(f"Query Agent已初始化") | ||
| 56 | - print(f"使用LLM: {self.llm_client.get_model_info()}") | ||
| 57 | - print(f"搜索工具集: TavilyNewsAgency (支持6种搜索工具)") | 55 | + logger.info(f"Query Agent已初始化") |
| 56 | + logger.info(f"使用LLM: {self.llm_client.get_model_info()}") | ||
| 57 | + logger.info(f"搜索工具集: TavilyNewsAgency (支持6种搜索工具)") | ||
| 58 | 58 | ||
| 59 | def _initialize_llm(self) -> LLMClient: | 59 | def _initialize_llm(self) -> LLMClient: |
| 60 | """初始化LLM客户端""" | 60 | """初始化LLM客户端""" |
| 61 | return LLMClient( | 61 | return LLMClient( |
| 62 | - api_key=self.config.llm_api_key, | ||
| 63 | - model_name=self.config.llm_model_name, | ||
| 64 | - base_url=self.config.llm_base_url, | 62 | + api_key=self.config.QUERY_ENGINE_API_KEY, |
| 63 | + model_name=self.config.QUERY_ENGINE_MODEL_NAME, | ||
| 64 | + base_url=self.config.QUERY_ENGINE_BASE_URL, | ||
| 65 | ) | 65 | ) |
| 66 | 66 | ||
| 67 | def _initialize_nodes(self): | 67 | def _initialize_nodes(self): |
| @@ -115,7 +115,7 @@ class DeepSearchAgent: | @@ -115,7 +115,7 @@ class DeepSearchAgent: | ||
| 115 | Returns: | 115 | Returns: |
| 116 | TavilyResponse对象 | 116 | TavilyResponse对象 |
| 117 | """ | 117 | """ |
| 118 | - print(f" → 执行搜索工具: {tool_name}") | 118 | + logger.info(f" → 执行搜索工具: {tool_name}") |
| 119 | 119 | ||
| 120 | if tool_name == "basic_search_news": | 120 | if tool_name == "basic_search_news": |
| 121 | max_results = kwargs.get("max_results", 7) | 121 | max_results = kwargs.get("max_results", 7) |
| @@ -135,7 +135,7 @@ class DeepSearchAgent: | @@ -135,7 +135,7 @@ class DeepSearchAgent: | ||
| 135 | raise ValueError("search_news_by_date工具需要start_date和end_date参数") | 135 | raise ValueError("search_news_by_date工具需要start_date和end_date参数") |
| 136 | return self.search_agency.search_news_by_date(query, start_date, end_date) | 136 | return self.search_agency.search_news_by_date(query, start_date, end_date) |
| 137 | else: | 137 | else: |
| 138 | - print(f" ⚠️ 未知的搜索工具: {tool_name},使用默认基础搜索") | 138 | + logger.warning(f" ⚠️ 未知的搜索工具: {tool_name},使用默认基础搜索") |
| 139 | return self.search_agency.basic_search_news(query) | 139 | return self.search_agency.basic_search_news(query) |
| 140 | 140 | ||
| 141 | def research(self, query: str, save_report: bool = True) -> str: | 141 | def research(self, query: str, save_report: bool = True) -> str: |
| @@ -149,9 +149,9 @@ class DeepSearchAgent: | @@ -149,9 +149,9 @@ class DeepSearchAgent: | ||
| 149 | Returns: | 149 | Returns: |
| 150 | 最终报告内容 | 150 | 最终报告内容 |
| 151 | """ | 151 | """ |
| 152 | - print(f"\n{'='*60}") | ||
| 153 | - print(f"开始深度研究: {query}") | ||
| 154 | - print(f"{'='*60}") | 152 | + logger.info(f"\n{'='*60}") |
| 153 | + logger.info(f"开始深度研究: {query}") | ||
| 154 | + logger.info(f"{'='*60}") | ||
| 155 | 155 | ||
| 156 | try: | 156 | try: |
| 157 | # Step 1: 生成报告结构 | 157 | # Step 1: 生成报告结构 |
| @@ -167,19 +167,21 @@ class DeepSearchAgent: | @@ -167,19 +167,21 @@ class DeepSearchAgent: | ||
| 167 | if save_report: | 167 | if save_report: |
| 168 | self._save_report(final_report) | 168 | self._save_report(final_report) |
| 169 | 169 | ||
| 170 | - print(f"\n{'='*60}") | ||
| 171 | - print("深度研究完成!") | ||
| 172 | - print(f"{'='*60}") | 170 | + logger.info(f"\n{'='*60}") |
| 171 | + logger.info("深度研究完成!") | ||
| 172 | + logger.info(f"{'='*60}") | ||
| 173 | 173 | ||
| 174 | return final_report | 174 | return final_report |
| 175 | 175 | ||
| 176 | except Exception as e: | 176 | except Exception as e: |
| 177 | - print(f"研究过程中发生错误: {str(e)}") | 177 | + import traceback |
| 178 | + error_traceback = traceback.format_exc() | ||
| 179 | + logger.error(f"研究过程中发生错误: {str(e)} \n错误堆栈: {error_traceback}") | ||
| 178 | raise e | 180 | raise e |
| 179 | 181 | ||
| 180 | def _generate_report_structure(self, query: str): | 182 | def _generate_report_structure(self, query: str): |
| 181 | """生成报告结构""" | 183 | """生成报告结构""" |
| 182 | - print(f"\n[步骤 1] 生成报告结构...") | 184 | + logger.info(f"\n[步骤 1] 生成报告结构...") |
| 183 | 185 | ||
| 184 | # 创建报告结构节点 | 186 | # 创建报告结构节点 |
| 185 | report_structure_node = ReportStructureNode(self.llm_client, query) | 187 | report_structure_node = ReportStructureNode(self.llm_client, query) |
| @@ -187,17 +189,18 @@ class DeepSearchAgent: | @@ -187,17 +189,18 @@ class DeepSearchAgent: | ||
| 187 | # 生成结构并更新状态 | 189 | # 生成结构并更新状态 |
| 188 | self.state = report_structure_node.mutate_state(state=self.state) | 190 | self.state = report_structure_node.mutate_state(state=self.state) |
| 189 | 191 | ||
| 190 | - print(f"报告结构已生成,共 {len(self.state.paragraphs)} 个段落:") | 192 | + _message = f"报告结构已生成,共 {len(self.state.paragraphs)} 个段落:" |
| 191 | for i, paragraph in enumerate(self.state.paragraphs, 1): | 193 | for i, paragraph in enumerate(self.state.paragraphs, 1): |
| 192 | - print(f" {i}. {paragraph.title}") | 194 | + _message += f"\n {i}. {paragraph.title}" |
| 195 | + logger.info(_message) | ||
| 193 | 196 | ||
| 194 | def _process_paragraphs(self): | 197 | def _process_paragraphs(self): |
| 195 | """处理所有段落""" | 198 | """处理所有段落""" |
| 196 | total_paragraphs = len(self.state.paragraphs) | 199 | total_paragraphs = len(self.state.paragraphs) |
| 197 | 200 | ||
| 198 | for i in range(total_paragraphs): | 201 | for i in range(total_paragraphs): |
| 199 | - print(f"\n[步骤 2.{i+1}] 处理段落: {self.state.paragraphs[i].title}") | ||
| 200 | - print("-" * 50) | 202 | + logger.info(f"\n[步骤 2.{i+1}] 处理段落: {self.state.paragraphs[i].title}") |
| 203 | + logger.info("-" * 50) | ||
| 201 | 204 | ||
| 202 | # 初始搜索和总结 | 205 | # 初始搜索和总结 |
| 203 | self._initial_search_and_summary(i) | 206 | self._initial_search_and_summary(i) |
| @@ -209,7 +212,7 @@ class DeepSearchAgent: | @@ -209,7 +212,7 @@ class DeepSearchAgent: | ||
| 209 | self.state.paragraphs[i].research.mark_completed() | 212 | self.state.paragraphs[i].research.mark_completed() |
| 210 | 213 | ||
| 211 | progress = (i + 1) / total_paragraphs * 100 | 214 | progress = (i + 1) / total_paragraphs * 100 |
| 212 | - print(f"段落处理完成 ({progress:.1f}%)") | 215 | + logger.info(f"段落处理完成 ({progress:.1f}%)") |
| 213 | 216 | ||
| 214 | def _initial_search_and_summary(self, paragraph_index: int): | 217 | def _initial_search_and_summary(self, paragraph_index: int): |
| 215 | """执行初始搜索和总结""" | 218 | """执行初始搜索和总结""" |
| @@ -222,18 +225,18 @@ class DeepSearchAgent: | @@ -222,18 +225,18 @@ class DeepSearchAgent: | ||
| 222 | } | 225 | } |
| 223 | 226 | ||
| 224 | # 生成搜索查询和工具选择 | 227 | # 生成搜索查询和工具选择 |
| 225 | - print(" - 生成搜索查询...") | 228 | + logger.info(" - 生成搜索查询...") |
| 226 | search_output = self.first_search_node.run(search_input) | 229 | search_output = self.first_search_node.run(search_input) |
| 227 | search_query = search_output["search_query"] | 230 | search_query = search_output["search_query"] |
| 228 | search_tool = search_output.get("search_tool", "basic_search_news") # 默认工具 | 231 | search_tool = search_output.get("search_tool", "basic_search_news") # 默认工具 |
| 229 | reasoning = search_output["reasoning"] | 232 | reasoning = search_output["reasoning"] |
| 230 | 233 | ||
| 231 | - print(f" - 搜索查询: {search_query}") | ||
| 232 | - print(f" - 选择的工具: {search_tool}") | ||
| 233 | - print(f" - 推理: {reasoning}") | 234 | + logger.info(f" - 搜索查询: {search_query}") |
| 235 | + logger.info(f" - 选择的工具: {search_tool}") | ||
| 236 | + logger.info(f" - 推理: {reasoning}") | ||
| 234 | 237 | ||
| 235 | # 执行搜索 | 238 | # 执行搜索 |
| 236 | - print(" - 执行网络搜索...") | 239 | + logger.info(" - 执行网络搜索...") |
| 237 | 240 | ||
| 238 | # 处理search_news_by_date的特殊参数 | 241 | # 处理search_news_by_date的特殊参数 |
| 239 | search_kwargs = {} | 242 | search_kwargs = {} |
| @@ -246,13 +249,13 @@ class DeepSearchAgent: | @@ -246,13 +249,13 @@ class DeepSearchAgent: | ||
| 246 | if self._validate_date_format(start_date) and self._validate_date_format(end_date): | 249 | if self._validate_date_format(start_date) and self._validate_date_format(end_date): |
| 247 | search_kwargs["start_date"] = start_date | 250 | search_kwargs["start_date"] = start_date |
| 248 | search_kwargs["end_date"] = end_date | 251 | search_kwargs["end_date"] = end_date |
| 249 | - print(f" - 时间范围: {start_date} 到 {end_date}") | 252 | + logger.info(f" - 时间范围: {start_date} 到 {end_date}") |
| 250 | else: | 253 | else: |
| 251 | - print(f" ⚠️ 日期格式错误(应为YYYY-MM-DD),改用基础搜索") | ||
| 252 | - print(f" 提供的日期: start_date={start_date}, end_date={end_date}") | 254 | + logger.info(f" ⚠️ 日期格式错误(应为YYYY-MM-DD),改用基础搜索") |
| 255 | + logger.info(f" 提供的日期: start_date={start_date}, end_date={end_date}") | ||
| 253 | search_tool = "basic_search_news" | 256 | search_tool = "basic_search_news" |
| 254 | else: | 257 | else: |
| 255 | - print(f" ⚠️ search_news_by_date工具缺少时间参数,改用基础搜索") | 258 | + logger.info(f" ⚠️ search_news_by_date工具缺少时间参数,改用基础搜索") |
| 256 | search_tool = "basic_search_news" | 259 | search_tool = "basic_search_news" |
| 257 | 260 | ||
| 258 | search_response = self.execute_search_tool(search_tool, search_query, **search_kwargs) | 261 | search_response = self.execute_search_tool(search_tool, search_query, **search_kwargs) |
| @@ -273,24 +276,24 @@ class DeepSearchAgent: | @@ -273,24 +276,24 @@ class DeepSearchAgent: | ||
| 273 | }) | 276 | }) |
| 274 | 277 | ||
| 275 | if search_results: | 278 | if search_results: |
| 276 | - print(f" - 找到 {len(search_results)} 个搜索结果") | 279 | + _message = f" - 找到 {len(search_results)} 个搜索结果" |
| 277 | for j, result in enumerate(search_results, 1): | 280 | for j, result in enumerate(search_results, 1): |
| 278 | date_info = f" (发布于: {result.get('published_date', 'N/A')})" if result.get('published_date') else "" | 281 | date_info = f" (发布于: {result.get('published_date', 'N/A')})" if result.get('published_date') else "" |
| 279 | - print(f" {j}. {result['title'][:50]}...{date_info}") | 282 | + _message += f"\n {j}. {result['title'][:50]}...{date_info}" |
| 283 | + logger.info(_message) | ||
| 280 | else: | 284 | else: |
| 281 | - print(" - 未找到搜索结果") | ||
| 282 | - | 285 | + logger.info(" - 未找到搜索结果") |
| 283 | # 更新状态中的搜索历史 | 286 | # 更新状态中的搜索历史 |
| 284 | paragraph.research.add_search_results(search_query, search_results) | 287 | paragraph.research.add_search_results(search_query, search_results) |
| 285 | 288 | ||
| 286 | # 生成初始总结 | 289 | # 生成初始总结 |
| 287 | - print(" - 生成初始总结...") | 290 | + logger.info(" - 生成初始总结...") |
| 288 | summary_input = { | 291 | summary_input = { |
| 289 | "title": paragraph.title, | 292 | "title": paragraph.title, |
| 290 | "content": paragraph.content, | 293 | "content": paragraph.content, |
| 291 | "search_query": search_query, | 294 | "search_query": search_query, |
| 292 | "search_results": format_search_results_for_prompt( | 295 | "search_results": format_search_results_for_prompt( |
| 293 | - search_results, self.config.max_content_length | 296 | + search_results, self.config.SEARCH_CONTENT_MAX_LENGTH |
| 294 | ) | 297 | ) |
| 295 | } | 298 | } |
| 296 | 299 | ||
| @@ -299,14 +302,14 @@ class DeepSearchAgent: | @@ -299,14 +302,14 @@ class DeepSearchAgent: | ||
| 299 | summary_input, self.state, paragraph_index | 302 | summary_input, self.state, paragraph_index |
| 300 | ) | 303 | ) |
| 301 | 304 | ||
| 302 | - print(" - 初始总结完成") | 305 | + logger.info(" - 初始总结完成") |
| 303 | 306 | ||
| 304 | def _reflection_loop(self, paragraph_index: int): | 307 | def _reflection_loop(self, paragraph_index: int): |
| 305 | """执行反思循环""" | 308 | """执行反思循环""" |
| 306 | paragraph = self.state.paragraphs[paragraph_index] | 309 | paragraph = self.state.paragraphs[paragraph_index] |
| 307 | 310 | ||
| 308 | - for reflection_i in range(self.config.max_reflections): | ||
| 309 | - print(f" - 反思 {reflection_i + 1}/{self.config.max_reflections}...") | 311 | + for reflection_i in range(self.config.MAX_REFLECTIONS): |
| 312 | + logger.info(f" - 反思 {reflection_i + 1}/{self.config.MAX_REFLECTIONS}...") | ||
| 310 | 313 | ||
| 311 | # 准备反思输入 | 314 | # 准备反思输入 |
| 312 | reflection_input = { | 315 | reflection_input = { |
| @@ -321,9 +324,9 @@ class DeepSearchAgent: | @@ -321,9 +324,9 @@ class DeepSearchAgent: | ||
| 321 | search_tool = reflection_output.get("search_tool", "basic_search_news") # 默认工具 | 324 | search_tool = reflection_output.get("search_tool", "basic_search_news") # 默认工具 |
| 322 | reasoning = reflection_output["reasoning"] | 325 | reasoning = reflection_output["reasoning"] |
| 323 | 326 | ||
| 324 | - print(f" 反思查询: {search_query}") | ||
| 325 | - print(f" 选择的工具: {search_tool}") | ||
| 326 | - print(f" 反思推理: {reasoning}") | 327 | + logger.info(f" 反思查询: {search_query}") |
| 328 | + logger.info(f" 选择的工具: {search_tool}") | ||
| 329 | + logger.info(f" 反思推理: {reasoning}") | ||
| 327 | 330 | ||
| 328 | # 执行反思搜索 | 331 | # 执行反思搜索 |
| 329 | # 处理search_news_by_date的特殊参数 | 332 | # 处理search_news_by_date的特殊参数 |
| @@ -337,13 +340,13 @@ class DeepSearchAgent: | @@ -337,13 +340,13 @@ class DeepSearchAgent: | ||
| 337 | if self._validate_date_format(start_date) and self._validate_date_format(end_date): | 340 | if self._validate_date_format(start_date) and self._validate_date_format(end_date): |
| 338 | search_kwargs["start_date"] = start_date | 341 | search_kwargs["start_date"] = start_date |
| 339 | search_kwargs["end_date"] = end_date | 342 | search_kwargs["end_date"] = end_date |
| 340 | - print(f" 时间范围: {start_date} 到 {end_date}") | 343 | + logger.info(f" 时间范围: {start_date} 到 {end_date}") |
| 341 | else: | 344 | else: |
| 342 | - print(f" ⚠️ 日期格式错误(应为YYYY-MM-DD),改用基础搜索") | ||
| 343 | - print(f" 提供的日期: start_date={start_date}, end_date={end_date}") | 345 | + logger.info(f" ⚠️ 日期格式错误(应为YYYY-MM-DD),改用基础搜索") |
| 346 | + logger.info(f" 提供的日期: start_date={start_date}, end_date={end_date}") | ||
| 344 | search_tool = "basic_search_news" | 347 | search_tool = "basic_search_news" |
| 345 | else: | 348 | else: |
| 346 | - print(f" ⚠️ search_news_by_date工具缺少时间参数,改用基础搜索") | 349 | + logger.info(f" ⚠️ search_news_by_date工具缺少时间参数,改用基础搜索") |
| 347 | search_tool = "basic_search_news" | 350 | search_tool = "basic_search_news" |
| 348 | 351 | ||
| 349 | search_response = self.execute_search_tool(search_tool, search_query, **search_kwargs) | 352 | search_response = self.execute_search_tool(search_tool, search_query, **search_kwargs) |
| @@ -364,12 +367,12 @@ class DeepSearchAgent: | @@ -364,12 +367,12 @@ class DeepSearchAgent: | ||
| 364 | }) | 367 | }) |
| 365 | 368 | ||
| 366 | if search_results: | 369 | if search_results: |
| 367 | - print(f" 找到 {len(search_results)} 个反思搜索结果") | 370 | + logger.info(f" 找到 {len(search_results)} 个反思搜索结果") |
| 368 | for j, result in enumerate(search_results, 1): | 371 | for j, result in enumerate(search_results, 1): |
| 369 | date_info = f" (发布于: {result.get('published_date', 'N/A')})" if result.get('published_date') else "" | 372 | date_info = f" (发布于: {result.get('published_date', 'N/A')})" if result.get('published_date') else "" |
| 370 | - print(f" {j}. {result['title'][:50]}...{date_info}") | 373 | + logger.info(f" {j}. {result['title'][:50]}...{date_info}") |
| 371 | else: | 374 | else: |
| 372 | - print(" 未找到反思搜索结果") | 375 | + logger.info(" 未找到反思搜索结果") |
| 373 | 376 | ||
| 374 | # 更新搜索历史 | 377 | # 更新搜索历史 |
| 375 | paragraph.research.add_search_results(search_query, search_results) | 378 | paragraph.research.add_search_results(search_query, search_results) |
| @@ -380,7 +383,7 @@ class DeepSearchAgent: | @@ -380,7 +383,7 @@ class DeepSearchAgent: | ||
| 380 | "content": paragraph.content, | 383 | "content": paragraph.content, |
| 381 | "search_query": search_query, | 384 | "search_query": search_query, |
| 382 | "search_results": format_search_results_for_prompt( | 385 | "search_results": format_search_results_for_prompt( |
| 383 | - search_results, self.config.max_content_length | 386 | + search_results, self.config.SEARCH_CONTENT_MAX_LENGTH |
| 384 | ), | 387 | ), |
| 385 | "paragraph_latest_state": paragraph.research.latest_summary | 388 | "paragraph_latest_state": paragraph.research.latest_summary |
| 386 | } | 389 | } |
| @@ -390,11 +393,11 @@ class DeepSearchAgent: | @@ -390,11 +393,11 @@ class DeepSearchAgent: | ||
| 390 | reflection_summary_input, self.state, paragraph_index | 393 | reflection_summary_input, self.state, paragraph_index |
| 391 | ) | 394 | ) |
| 392 | 395 | ||
| 393 | - print(f" 反思 {reflection_i + 1} 完成") | 396 | + logger.info(f" 反思 {reflection_i + 1} 完成") |
| 394 | 397 | ||
| 395 | def _generate_final_report(self) -> str: | 398 | def _generate_final_report(self) -> str: |
| 396 | """生成最终报告""" | 399 | """生成最终报告""" |
| 397 | - print(f"\n[步骤 3] 生成最终报告...") | 400 | + logger.info(f"\n[步骤 3] 生成最终报告...") |
| 398 | 401 | ||
| 399 | # 准备报告数据 | 402 | # 准备报告数据 |
| 400 | report_data = [] | 403 | report_data = [] |
| @@ -408,7 +411,7 @@ class DeepSearchAgent: | @@ -408,7 +411,7 @@ class DeepSearchAgent: | ||
| 408 | try: | 411 | try: |
| 409 | final_report = self.report_formatting_node.run(report_data) | 412 | final_report = self.report_formatting_node.run(report_data) |
| 410 | except Exception as e: | 413 | except Exception as e: |
| 411 | - print(f"LLM格式化失败,使用备用方法: {str(e)}") | 414 | + logger.error(f"LLM格式化失败,使用备用方法: {str(e)}") |
| 412 | final_report = self.report_formatting_node.format_report_manually( | 415 | final_report = self.report_formatting_node.format_report_manually( |
| 413 | report_data, self.state.report_title | 416 | report_data, self.state.report_title |
| 414 | ) | 417 | ) |
| @@ -417,7 +420,7 @@ class DeepSearchAgent: | @@ -417,7 +420,7 @@ class DeepSearchAgent: | ||
| 417 | self.state.final_report = final_report | 420 | self.state.final_report = final_report |
| 418 | self.state.mark_completed() | 421 | self.state.mark_completed() |
| 419 | 422 | ||
| 420 | - print("最终报告生成完成") | 423 | + logger.info("最终报告生成完成") |
| 421 | return final_report | 424 | return final_report |
| 422 | 425 | ||
| 423 | def _save_report(self, report_content: str): | 426 | def _save_report(self, report_content: str): |
| @@ -428,20 +431,20 @@ class DeepSearchAgent: | @@ -428,20 +431,20 @@ class DeepSearchAgent: | ||
| 428 | query_safe = query_safe.replace(' ', '_')[:30] | 431 | query_safe = query_safe.replace(' ', '_')[:30] |
| 429 | 432 | ||
| 430 | filename = f"deep_search_report_{query_safe}_{timestamp}.md" | 433 | filename = f"deep_search_report_{query_safe}_{timestamp}.md" |
| 431 | - filepath = os.path.join(self.config.output_dir, filename) | 434 | + filepath = os.path.join(self.config.OUTPUT_DIR, filename) |
| 432 | 435 | ||
| 433 | # 保存报告 | 436 | # 保存报告 |
| 434 | with open(filepath, 'w', encoding='utf-8') as f: | 437 | with open(filepath, 'w', encoding='utf-8') as f: |
| 435 | f.write(report_content) | 438 | f.write(report_content) |
| 436 | 439 | ||
| 437 | - print(f"报告已保存到: {filepath}") | 440 | + logger.info(f"报告已保存到: {filepath}") |
| 438 | 441 | ||
| 439 | # 保存状态(如果配置允许) | 442 | # 保存状态(如果配置允许) |
| 440 | - if self.config.save_intermediate_states: | 443 | + if self.config.SAVE_INTERMEDIATE_STATES: |
| 441 | state_filename = f"state_{query_safe}_{timestamp}.json" | 444 | state_filename = f"state_{query_safe}_{timestamp}.json" |
| 442 | - state_filepath = os.path.join(self.config.output_dir, state_filename) | 445 | + state_filepath = os.path.join(self.config.OUTPUT_DIR, state_filename) |
| 443 | self.state.save_to_file(state_filepath) | 446 | self.state.save_to_file(state_filepath) |
| 444 | - print(f"状态已保存到: {state_filepath}") | 447 | + logger.info(f"状态已保存到: {state_filepath}") |
| 445 | 448 | ||
| 446 | def get_progress_summary(self) -> Dict[str, Any]: | 449 | def get_progress_summary(self) -> Dict[str, Any]: |
| 447 | """获取进度摘要""" | 450 | """获取进度摘要""" |
| @@ -450,23 +453,21 @@ class DeepSearchAgent: | @@ -450,23 +453,21 @@ class DeepSearchAgent: | ||
| 450 | def load_state(self, filepath: str): | 453 | def load_state(self, filepath: str): |
| 451 | """从文件加载状态""" | 454 | """从文件加载状态""" |
| 452 | self.state = State.load_from_file(filepath) | 455 | self.state = State.load_from_file(filepath) |
| 453 | - print(f"状态已从 {filepath} 加载") | 456 | + logger.info(f"状态已从 {filepath} 加载") |
| 454 | 457 | ||
| 455 | def save_state(self, filepath: str): | 458 | def save_state(self, filepath: str): |
| 456 | """保存状态到文件""" | 459 | """保存状态到文件""" |
| 457 | self.state.save_to_file(filepath) | 460 | self.state.save_to_file(filepath) |
| 458 | - print(f"状态已保存到 {filepath}") | 461 | + logger.info(f"状态已保存到 {filepath}") |
| 459 | 462 | ||
| 460 | 463 | ||
| 461 | -def create_agent(config_file: Optional[str] = None) -> DeepSearchAgent: | 464 | +def create_agent() -> DeepSearchAgent: |
| 462 | """ | 465 | """ |
| 463 | 创建Deep Search Agent实例的便捷函数 | 466 | 创建Deep Search Agent实例的便捷函数 |
| 464 | 467 | ||
| 465 | - Args: | ||
| 466 | - config_file: 配置文件路径 | ||
| 467 | - | ||
| 468 | Returns: | 468 | Returns: |
| 469 | DeepSearchAgent实例 | 469 | DeepSearchAgent实例 |
| 470 | """ | 470 | """ |
| 471 | - config = load_config(config_file) | 471 | + from .utils.config import Settings |
| 472 | + config = Settings() | ||
| 472 | return DeepSearchAgent(config) | 473 | return DeepSearchAgent(config) |
| @@ -5,69 +5,74 @@ | @@ -5,69 +5,74 @@ | ||
| 5 | 5 | ||
| 6 | from abc import ABC, abstractmethod | 6 | from abc import ABC, abstractmethod |
| 7 | from typing import Any, Dict, Optional | 7 | from typing import Any, Dict, Optional |
| 8 | +from loguru import logger | ||
| 8 | from ..llms.base import LLMClient | 9 | from ..llms.base import LLMClient |
| 9 | from ..state.state import State | 10 | from ..state.state import State |
| 10 | 11 | ||
| 11 | 12 | ||
| 12 | class BaseNode(ABC): | 13 | class BaseNode(ABC): |
| 13 | """节点基类""" | 14 | """节点基类""" |
| 14 | - | 15 | + |
| 15 | def __init__(self, llm_client: LLMClient, node_name: str = ""): | 16 | def __init__(self, llm_client: LLMClient, node_name: str = ""): |
| 16 | """ | 17 | """ |
| 17 | 初始化节点 | 18 | 初始化节点 |
| 18 | - | 19 | + |
| 19 | Args: | 20 | Args: |
| 20 | llm_client: LLM客户端 | 21 | llm_client: LLM客户端 |
| 21 | node_name: 节点名称 | 22 | node_name: 节点名称 |
| 22 | """ | 23 | """ |
| 23 | self.llm_client = llm_client | 24 | self.llm_client = llm_client |
| 24 | self.node_name = node_name or self.__class__.__name__ | 25 | self.node_name = node_name or self.__class__.__name__ |
| 25 | - | 26 | + |
| 26 | @abstractmethod | 27 | @abstractmethod |
| 27 | def run(self, input_data: Any, **kwargs) -> Any: | 28 | def run(self, input_data: Any, **kwargs) -> Any: |
| 28 | """ | 29 | """ |
| 29 | 执行节点处理逻辑 | 30 | 执行节点处理逻辑 |
| 30 | - | 31 | + |
| 31 | Args: | 32 | Args: |
| 32 | input_data: 输入数据 | 33 | input_data: 输入数据 |
| 33 | **kwargs: 额外参数 | 34 | **kwargs: 额外参数 |
| 34 | - | 35 | + |
| 35 | Returns: | 36 | Returns: |
| 36 | 处理结果 | 37 | 处理结果 |
| 37 | """ | 38 | """ |
| 38 | pass | 39 | pass |
| 39 | - | 40 | + |
| 40 | def validate_input(self, input_data: Any) -> bool: | 41 | def validate_input(self, input_data: Any) -> bool: |
| 41 | """ | 42 | """ |
| 42 | 验证输入数据 | 43 | 验证输入数据 |
| 43 | - | 44 | + |
| 44 | Args: | 45 | Args: |
| 45 | input_data: 输入数据 | 46 | input_data: 输入数据 |
| 46 | - | 47 | + |
| 47 | Returns: | 48 | Returns: |
| 48 | 验证是否通过 | 49 | 验证是否通过 |
| 49 | """ | 50 | """ |
| 50 | return True | 51 | return True |
| 51 | - | 52 | + |
| 52 | def process_output(self, output: Any) -> Any: | 53 | def process_output(self, output: Any) -> Any: |
| 53 | """ | 54 | """ |
| 54 | 处理输出数据 | 55 | 处理输出数据 |
| 55 | - | 56 | + |
| 56 | Args: | 57 | Args: |
| 57 | output: 原始输出 | 58 | output: 原始输出 |
| 58 | - | 59 | + |
| 59 | Returns: | 60 | Returns: |
| 60 | 处理后的输出 | 61 | 处理后的输出 |
| 61 | """ | 62 | """ |
| 62 | return output | 63 | return output |
| 63 | - | 64 | + |
| 64 | def log_info(self, message: str): | 65 | def log_info(self, message: str): |
| 65 | """记录信息日志""" | 66 | """记录信息日志""" |
| 66 | - print(f"[{self.node_name}] {message}") | 67 | + logger.info(f"[{self.node_name}] {message}") |
| 67 | 68 | ||
| 69 | + def log_warning(self, message: str): | ||
| 70 | + """记录警告日志""" | ||
| 71 | + logger.warning(f"[{self.node_name}] 警告: {message}") | ||
| 72 | + | ||
| 68 | def log_error(self, message: str): | 73 | def log_error(self, message: str): |
| 69 | """记录错误日志""" | 74 | """记录错误日志""" |
| 70 | - print(f"[{self.node_name}] 错误: {message}") | 75 | + logger.error(f"[{self.node_name}] 错误: {message}") |
| 71 | 76 | ||
| 72 | 77 | ||
| 73 | class StateMutationNode(BaseNode): | 78 | class StateMutationNode(BaseNode): |
| @@ -7,6 +7,7 @@ import json | @@ -7,6 +7,7 @@ import json | ||
| 7 | from typing import List, Dict, Any | 7 | from typing import List, Dict, Any |
| 8 | 8 | ||
| 9 | from .base_node import BaseNode | 9 | from .base_node import BaseNode |
| 10 | +from loguru import logger | ||
| 10 | from ..prompts import SYSTEM_PROMPT_REPORT_FORMATTING | 11 | from ..prompts import SYSTEM_PROMPT_REPORT_FORMATTING |
| 11 | from ..utils.text_processing import ( | 12 | from ..utils.text_processing import ( |
| 12 | remove_reasoning_from_output, | 13 | remove_reasoning_from_output, |
| @@ -65,7 +66,7 @@ class ReportFormattingNode(BaseNode): | @@ -65,7 +66,7 @@ class ReportFormattingNode(BaseNode): | ||
| 65 | else: | 66 | else: |
| 66 | message = json.dumps(input_data, ensure_ascii=False) | 67 | message = json.dumps(input_data, ensure_ascii=False) |
| 67 | 68 | ||
| 68 | - self.log_info("正在格式化最终报告") | 69 | + logger.info("正在格式化最终报告") |
| 69 | 70 | ||
| 70 | # 调用LLM生成Markdown格式 | 71 | # 调用LLM生成Markdown格式 |
| 71 | response = self.llm_client.invoke( | 72 | response = self.llm_client.invoke( |
| @@ -76,11 +77,11 @@ class ReportFormattingNode(BaseNode): | @@ -76,11 +77,11 @@ class ReportFormattingNode(BaseNode): | ||
| 76 | # 处理响应 | 77 | # 处理响应 |
| 77 | processed_response = self.process_output(response) | 78 | processed_response = self.process_output(response) |
| 78 | 79 | ||
| 79 | - self.log_info("成功生成格式化报告") | 80 | + logger.info("成功生成格式化报告") |
| 80 | return processed_response | 81 | return processed_response |
| 81 | 82 | ||
| 82 | except Exception as e: | 83 | except Exception as e: |
| 83 | - self.log_error(f"报告格式化失败: {str(e)}") | 84 | + logger.exception(f"报告格式化失败: {str(e)}") |
| 84 | raise e | 85 | raise e |
| 85 | 86 | ||
| 86 | def process_output(self, output: str) -> str: | 87 | def process_output(self, output: str) -> str: |
| @@ -109,7 +110,7 @@ class ReportFormattingNode(BaseNode): | @@ -109,7 +110,7 @@ class ReportFormattingNode(BaseNode): | ||
| 109 | return cleaned_output.strip() | 110 | return cleaned_output.strip() |
| 110 | 111 | ||
| 111 | except Exception as e: | 112 | except Exception as e: |
| 112 | - self.log_error(f"处理输出失败: {str(e)}") | 113 | + logger.exception(f"处理输出失败: {str(e)}") |
| 113 | return "# 报告处理失败\n\n报告格式化过程中发生错误。" | 114 | return "# 报告处理失败\n\n报告格式化过程中发生错误。" |
| 114 | 115 | ||
| 115 | def format_report_manually(self, paragraphs_data: List[Dict[str, str]], | 116 | def format_report_manually(self, paragraphs_data: List[Dict[str, str]], |
| @@ -125,7 +126,7 @@ class ReportFormattingNode(BaseNode): | @@ -125,7 +126,7 @@ class ReportFormattingNode(BaseNode): | ||
| 125 | 格式化的Markdown报告 | 126 | 格式化的Markdown报告 |
| 126 | """ | 127 | """ |
| 127 | try: | 128 | try: |
| 128 | - self.log_info("使用手动格式化方法") | 129 | + logger.info("使用手动格式化方法") |
| 129 | 130 | ||
| 130 | # 构建报告 | 131 | # 构建报告 |
| 131 | report_lines = [ | 132 | report_lines = [ |
| @@ -163,5 +164,5 @@ class ReportFormattingNode(BaseNode): | @@ -163,5 +164,5 @@ class ReportFormattingNode(BaseNode): | ||
| 163 | return "\n".join(report_lines) | 164 | return "\n".join(report_lines) |
| 164 | 165 | ||
| 165 | except Exception as e: | 166 | except Exception as e: |
| 166 | - self.log_error(f"手动格式化失败: {str(e)}") | 167 | + logger.exception(f"手动格式化失败: {str(e)}") |
| 167 | return "# 报告生成失败\n\n无法完成报告格式化。" | 168 | return "# 报告生成失败\n\n无法完成报告格式化。" |
| @@ -6,6 +6,7 @@ | @@ -6,6 +6,7 @@ | ||
| 6 | import json | 6 | import json |
| 7 | from typing import Dict, Any, List | 7 | from typing import Dict, Any, List |
| 8 | from json.decoder import JSONDecodeError | 8 | from json.decoder import JSONDecodeError |
| 9 | +from loguru import logger | ||
| 9 | 10 | ||
| 10 | from .base_node import StateMutationNode | 11 | from .base_node import StateMutationNode |
| 11 | from ..state.state import State | 12 | from ..state.state import State |
| @@ -48,7 +49,7 @@ class ReportStructureNode(StateMutationNode): | @@ -48,7 +49,7 @@ class ReportStructureNode(StateMutationNode): | ||
| 48 | 报告结构列表 | 49 | 报告结构列表 |
| 49 | """ | 50 | """ |
| 50 | try: | 51 | try: |
| 51 | - self.log_info(f"正在为查询生成报告结构: {self.query}") | 52 | + logger.info(f"正在为查询生成报告结构: {self.query}") |
| 52 | 53 | ||
| 53 | # 调用LLM | 54 | # 调用LLM |
| 54 | response = self.llm_client.invoke(SYSTEM_PROMPT_REPORT_STRUCTURE, self.query) | 55 | response = self.llm_client.invoke(SYSTEM_PROMPT_REPORT_STRUCTURE, self.query) |
| @@ -56,11 +57,11 @@ class ReportStructureNode(StateMutationNode): | @@ -56,11 +57,11 @@ class ReportStructureNode(StateMutationNode): | ||
| 56 | # 处理响应 | 57 | # 处理响应 |
| 57 | processed_response = self.process_output(response) | 58 | processed_response = self.process_output(response) |
| 58 | 59 | ||
| 59 | - self.log_info(f"成功生成 {len(processed_response)} 个段落结构") | 60 | + logger.info(f"成功生成 {len(processed_response)} 个段落结构") |
| 60 | return processed_response | 61 | return processed_response |
| 61 | 62 | ||
| 62 | except Exception as e: | 63 | except Exception as e: |
| 63 | - self.log_error(f"生成报告结构失败: {str(e)}") | 64 | + logger.exception(f"生成报告结构失败: {str(e)}") |
| 64 | raise e | 65 | raise e |
| 65 | 66 | ||
| 66 | def process_output(self, output: str) -> List[Dict[str, str]]: | 67 | def process_output(self, output: str) -> List[Dict[str, str]]: |
| @@ -79,54 +80,54 @@ class ReportStructureNode(StateMutationNode): | @@ -79,54 +80,54 @@ class ReportStructureNode(StateMutationNode): | ||
| 79 | cleaned_output = clean_json_tags(cleaned_output) | 80 | cleaned_output = clean_json_tags(cleaned_output) |
| 80 | 81 | ||
| 81 | # 记录清理后的输出用于调试 | 82 | # 记录清理后的输出用于调试 |
| 82 | - self.log_info(f"清理后的输出: {cleaned_output}") | 83 | + logger.info(f"清理后的输出: {cleaned_output}") |
| 83 | 84 | ||
| 84 | # 解析JSON | 85 | # 解析JSON |
| 85 | try: | 86 | try: |
| 86 | report_structure = json.loads(cleaned_output) | 87 | report_structure = json.loads(cleaned_output) |
| 87 | - self.log_info("JSON解析成功") | 88 | + logger.info("JSON解析成功") |
| 88 | except JSONDecodeError as e: | 89 | except JSONDecodeError as e: |
| 89 | - self.log_info(f"JSON解析失败: {str(e)}") | 90 | + logger.exception(f"JSON解析失败: {str(e)}") |
| 90 | # 使用更强大的提取方法 | 91 | # 使用更强大的提取方法 |
| 91 | report_structure = extract_clean_response(cleaned_output) | 92 | report_structure = extract_clean_response(cleaned_output) |
| 92 | if "error" in report_structure: | 93 | if "error" in report_structure: |
| 93 | - self.log_error("JSON解析失败,尝试修复...") | 94 | + logger.error("JSON解析失败,尝试修复...") |
| 94 | # 尝试修复JSON | 95 | # 尝试修复JSON |
| 95 | fixed_json = fix_incomplete_json(cleaned_output) | 96 | fixed_json = fix_incomplete_json(cleaned_output) |
| 96 | if fixed_json: | 97 | if fixed_json: |
| 97 | try: | 98 | try: |
| 98 | report_structure = json.loads(fixed_json) | 99 | report_structure = json.loads(fixed_json) |
| 99 | - self.log_info("JSON修复成功") | 100 | + logger.info("JSON修复成功") |
| 100 | except JSONDecodeError: | 101 | except JSONDecodeError: |
| 101 | - self.log_error("JSON修复失败") | 102 | + logger.error("JSON修复失败") |
| 102 | # 返回默认结构 | 103 | # 返回默认结构 |
| 103 | return self._generate_default_structure() | 104 | return self._generate_default_structure() |
| 104 | else: | 105 | else: |
| 105 | - self.log_error("无法修复JSON,使用默认结构") | 106 | + logger.error("无法修复JSON,使用默认结构") |
| 106 | return self._generate_default_structure() | 107 | return self._generate_default_structure() |
| 107 | 108 | ||
| 108 | # 验证结构 | 109 | # 验证结构 |
| 109 | if not isinstance(report_structure, list): | 110 | if not isinstance(report_structure, list): |
| 110 | - self.log_info("报告结构不是列表,尝试转换...") | 111 | + logger.info("报告结构不是列表,尝试转换...") |
| 111 | if isinstance(report_structure, dict): | 112 | if isinstance(report_structure, dict): |
| 112 | # 如果是单个对象,包装成列表 | 113 | # 如果是单个对象,包装成列表 |
| 113 | report_structure = [report_structure] | 114 | report_structure = [report_structure] |
| 114 | else: | 115 | else: |
| 115 | - self.log_error("报告结构格式无效,使用默认结构") | 116 | + logger.error("报告结构格式无效,使用默认结构") |
| 116 | return self._generate_default_structure() | 117 | return self._generate_default_structure() |
| 117 | 118 | ||
| 118 | # 验证每个段落 | 119 | # 验证每个段落 |
| 119 | validated_structure = [] | 120 | validated_structure = [] |
| 120 | for i, paragraph in enumerate(report_structure): | 121 | for i, paragraph in enumerate(report_structure): |
| 121 | if not isinstance(paragraph, dict): | 122 | if not isinstance(paragraph, dict): |
| 122 | - self.log_warning(f"段落 {i+1} 不是字典格式,跳过") | 123 | + logger.warning(f"段落 {i+1} 不是字典格式,跳过") |
| 123 | continue | 124 | continue |
| 124 | 125 | ||
| 125 | title = paragraph.get("title", f"段落 {i+1}") | 126 | title = paragraph.get("title", f"段落 {i+1}") |
| 126 | content = paragraph.get("content", "") | 127 | content = paragraph.get("content", "") |
| 127 | 128 | ||
| 128 | if not title or not content: | 129 | if not title or not content: |
| 129 | - self.log_warning(f"段落 {i+1} 缺少标题或内容,跳过") | 130 | + logger.warning(f"段落 {i+1} 缺少标题或内容,跳过") |
| 130 | continue | 131 | continue |
| 131 | 132 | ||
| 132 | validated_structure.append({ | 133 | validated_structure.append({ |
| @@ -135,14 +136,14 @@ class ReportStructureNode(StateMutationNode): | @@ -135,14 +136,14 @@ class ReportStructureNode(StateMutationNode): | ||
| 135 | }) | 136 | }) |
| 136 | 137 | ||
| 137 | if not validated_structure: | 138 | if not validated_structure: |
| 138 | - self.log_warning("没有有效的段落结构,使用默认结构") | 139 | + logger.warning("没有有效的段落结构,使用默认结构") |
| 139 | return self._generate_default_structure() | 140 | return self._generate_default_structure() |
| 140 | 141 | ||
| 141 | - self.log_info(f"成功验证 {len(validated_structure)} 个段落结构") | 142 | + logger.info(f"成功验证 {len(validated_structure)} 个段落结构") |
| 142 | return validated_structure | 143 | return validated_structure |
| 143 | 144 | ||
| 144 | except Exception as e: | 145 | except Exception as e: |
| 145 | - self.log_error(f"处理输出失败: {str(e)}") | 146 | + logger.exception(f"处理输出失败: {str(e)}") |
| 146 | return self._generate_default_structure() | 147 | return self._generate_default_structure() |
| 147 | 148 | ||
| 148 | def _generate_default_structure(self) -> List[Dict[str, str]]: | 149 | def _generate_default_structure(self) -> List[Dict[str, str]]: |
| @@ -152,7 +153,7 @@ class ReportStructureNode(StateMutationNode): | @@ -152,7 +153,7 @@ class ReportStructureNode(StateMutationNode): | ||
| 152 | Returns: | 153 | Returns: |
| 153 | 默认的报告结构列表 | 154 | 默认的报告结构列表 |
| 154 | """ | 155 | """ |
| 155 | - self.log_info("生成默认报告结构") | 156 | + logger.info("生成默认报告结构") |
| 156 | return [ | 157 | return [ |
| 157 | { | 158 | { |
| 158 | "title": "研究概述", | 159 | "title": "研究概述", |
| @@ -195,9 +196,9 @@ class ReportStructureNode(StateMutationNode): | @@ -195,9 +196,9 @@ class ReportStructureNode(StateMutationNode): | ||
| 195 | content=paragraph_data["content"] | 196 | content=paragraph_data["content"] |
| 196 | ) | 197 | ) |
| 197 | 198 | ||
| 198 | - self.log_info(f"已将 {len(report_structure)} 个段落添加到状态中") | 199 | + logger.info(f"已将 {len(report_structure)} 个段落添加到状态中") |
| 199 | return state | 200 | return state |
| 200 | 201 | ||
| 201 | except Exception as e: | 202 | except Exception as e: |
| 202 | - self.log_error(f"状态更新失败: {str(e)}") | 203 | + logger.exception(f"状态更新失败: {str(e)}") |
| 203 | raise e | 204 | raise e |
| @@ -6,6 +6,7 @@ | @@ -6,6 +6,7 @@ | ||
| 6 | import json | 6 | import json |
| 7 | from typing import Dict, Any | 7 | from typing import Dict, Any |
| 8 | from json.decoder import JSONDecodeError | 8 | from json.decoder import JSONDecodeError |
| 9 | +from loguru import logger | ||
| 9 | 10 | ||
| 10 | from .base_node import BaseNode | 11 | from .base_node import BaseNode |
| 11 | from ..prompts import SYSTEM_PROMPT_FIRST_SEARCH, SYSTEM_PROMPT_REFLECTION | 12 | from ..prompts import SYSTEM_PROMPT_FIRST_SEARCH, SYSTEM_PROMPT_REFLECTION |
| @@ -62,7 +63,7 @@ class FirstSearchNode(BaseNode): | @@ -62,7 +63,7 @@ class FirstSearchNode(BaseNode): | ||
| 62 | else: | 63 | else: |
| 63 | message = json.dumps(input_data, ensure_ascii=False) | 64 | message = json.dumps(input_data, ensure_ascii=False) |
| 64 | 65 | ||
| 65 | - self.log_info("正在生成首次搜索查询") | 66 | + logger.info("正在生成首次搜索查询") |
| 66 | 67 | ||
| 67 | # 调用LLM | 68 | # 调用LLM |
| 68 | response = self.llm_client.invoke(SYSTEM_PROMPT_FIRST_SEARCH, message) | 69 | response = self.llm_client.invoke(SYSTEM_PROMPT_FIRST_SEARCH, message) |
| @@ -70,11 +71,11 @@ class FirstSearchNode(BaseNode): | @@ -70,11 +71,11 @@ class FirstSearchNode(BaseNode): | ||
| 70 | # 处理响应 | 71 | # 处理响应 |
| 71 | processed_response = self.process_output(response) | 72 | processed_response = self.process_output(response) |
| 72 | 73 | ||
| 73 | - self.log_info(f"生成搜索查询: {processed_response.get('search_query', 'N/A')}") | 74 | + logger.info(f"生成搜索查询: {processed_response.get('search_query', 'N/A')}") |
| 74 | return processed_response | 75 | return processed_response |
| 75 | 76 | ||
| 76 | except Exception as e: | 77 | except Exception as e: |
| 77 | - self.log_error(f"生成首次搜索查询失败: {str(e)}") | 78 | + logger.exception(f"生成首次搜索查询失败: {str(e)}") |
| 78 | raise e | 79 | raise e |
| 79 | 80 | ||
| 80 | def process_output(self, output: str) -> Dict[str, str]: | 81 | def process_output(self, output: str) -> Dict[str, str]: |
| @@ -93,30 +94,30 @@ class FirstSearchNode(BaseNode): | @@ -93,30 +94,30 @@ class FirstSearchNode(BaseNode): | ||
| 93 | cleaned_output = clean_json_tags(cleaned_output) | 94 | cleaned_output = clean_json_tags(cleaned_output) |
| 94 | 95 | ||
| 95 | # 记录清理后的输出用于调试 | 96 | # 记录清理后的输出用于调试 |
| 96 | - self.log_info(f"清理后的输出: {cleaned_output}") | 97 | + logger.info(f"清理后的输出: {cleaned_output}") |
| 97 | 98 | ||
| 98 | # 解析JSON | 99 | # 解析JSON |
| 99 | try: | 100 | try: |
| 100 | result = json.loads(cleaned_output) | 101 | result = json.loads(cleaned_output) |
| 101 | - self.log_info("JSON解析成功") | 102 | + logger.info("JSON解析成功") |
| 102 | except JSONDecodeError as e: | 103 | except JSONDecodeError as e: |
| 103 | - self.log_info(f"JSON解析失败: {str(e)}") | 104 | + logger.exception(f"JSON解析失败: {str(e)}") |
| 104 | # 使用更强大的提取方法 | 105 | # 使用更强大的提取方法 |
| 105 | result = extract_clean_response(cleaned_output) | 106 | result = extract_clean_response(cleaned_output) |
| 106 | if "error" in result: | 107 | if "error" in result: |
| 107 | - self.log_error("JSON解析失败,尝试修复...") | 108 | + logger.error("JSON解析失败,尝试修复...") |
| 108 | # 尝试修复JSON | 109 | # 尝试修复JSON |
| 109 | fixed_json = fix_incomplete_json(cleaned_output) | 110 | fixed_json = fix_incomplete_json(cleaned_output) |
| 110 | if fixed_json: | 111 | if fixed_json: |
| 111 | try: | 112 | try: |
| 112 | result = json.loads(fixed_json) | 113 | result = json.loads(fixed_json) |
| 113 | - self.log_info("JSON修复成功") | 114 | + logger.info("JSON修复成功") |
| 114 | except JSONDecodeError: | 115 | except JSONDecodeError: |
| 115 | - self.log_error("JSON修复失败") | 116 | + logger.error("JSON修复失败") |
| 116 | # 返回默认查询 | 117 | # 返回默认查询 |
| 117 | return self._get_default_search_query() | 118 | return self._get_default_search_query() |
| 118 | else: | 119 | else: |
| 119 | - self.log_error("无法修复JSON,使用默认查询") | 120 | + logger.error("无法修复JSON,使用默认查询") |
| 120 | return self._get_default_search_query() | 121 | return self._get_default_search_query() |
| 121 | 122 | ||
| 122 | # 验证和清理结果 | 123 | # 验证和清理结果 |
| @@ -124,7 +125,7 @@ class FirstSearchNode(BaseNode): | @@ -124,7 +125,7 @@ class FirstSearchNode(BaseNode): | ||
| 124 | reasoning = result.get("reasoning", "") | 125 | reasoning = result.get("reasoning", "") |
| 125 | 126 | ||
| 126 | if not search_query: | 127 | if not search_query: |
| 127 | - self.log_warning("未找到搜索查询,使用默认查询") | 128 | + logger.warning("未找到搜索查询,使用默认查询") |
| 128 | return self._get_default_search_query() | 129 | return self._get_default_search_query() |
| 129 | 130 | ||
| 130 | return { | 131 | return { |
| @@ -197,7 +198,7 @@ class ReflectionNode(BaseNode): | @@ -197,7 +198,7 @@ class ReflectionNode(BaseNode): | ||
| 197 | else: | 198 | else: |
| 198 | message = json.dumps(input_data, ensure_ascii=False) | 199 | message = json.dumps(input_data, ensure_ascii=False) |
| 199 | 200 | ||
| 200 | - self.log_info("正在进行反思并生成新搜索查询") | 201 | + logger.info("正在进行反思并生成新搜索查询") |
| 201 | 202 | ||
| 202 | # 调用LLM | 203 | # 调用LLM |
| 203 | response = self.llm_client.invoke(SYSTEM_PROMPT_REFLECTION, message) | 204 | response = self.llm_client.invoke(SYSTEM_PROMPT_REFLECTION, message) |
| @@ -205,11 +206,11 @@ class ReflectionNode(BaseNode): | @@ -205,11 +206,11 @@ class ReflectionNode(BaseNode): | ||
| 205 | # 处理响应 | 206 | # 处理响应 |
| 206 | processed_response = self.process_output(response) | 207 | processed_response = self.process_output(response) |
| 207 | 208 | ||
| 208 | - self.log_info(f"反思生成搜索查询: {processed_response.get('search_query', 'N/A')}") | 209 | + logger.info(f"反思生成搜索查询: {processed_response.get('search_query', 'N/A')}") |
| 209 | return processed_response | 210 | return processed_response |
| 210 | 211 | ||
| 211 | except Exception as e: | 212 | except Exception as e: |
| 212 | - self.log_error(f"反思生成搜索查询失败: {str(e)}") | 213 | + logger.exception(f"反思生成搜索查询失败: {str(e)}") |
| 213 | raise e | 214 | raise e |
| 214 | 215 | ||
| 215 | def process_output(self, output: str) -> Dict[str, str]: | 216 | def process_output(self, output: str) -> Dict[str, str]: |
| @@ -228,30 +229,30 @@ class ReflectionNode(BaseNode): | @@ -228,30 +229,30 @@ class ReflectionNode(BaseNode): | ||
| 228 | cleaned_output = clean_json_tags(cleaned_output) | 229 | cleaned_output = clean_json_tags(cleaned_output) |
| 229 | 230 | ||
| 230 | # 记录清理后的输出用于调试 | 231 | # 记录清理后的输出用于调试 |
| 231 | - self.log_info(f"清理后的输出: {cleaned_output}") | 232 | + logger.info(f"清理后的输出: {cleaned_output}") |
| 232 | 233 | ||
| 233 | # 解析JSON | 234 | # 解析JSON |
| 234 | try: | 235 | try: |
| 235 | result = json.loads(cleaned_output) | 236 | result = json.loads(cleaned_output) |
| 236 | - self.log_info("JSON解析成功") | 237 | + logger.info("JSON解析成功") |
| 237 | except JSONDecodeError as e: | 238 | except JSONDecodeError as e: |
| 238 | - self.log_info(f"JSON解析失败: {str(e)}") | 239 | + logger.exception(f"JSON解析失败: {str(e)}") |
| 239 | # 使用更强大的提取方法 | 240 | # 使用更强大的提取方法 |
| 240 | result = extract_clean_response(cleaned_output) | 241 | result = extract_clean_response(cleaned_output) |
| 241 | if "error" in result: | 242 | if "error" in result: |
| 242 | - self.log_error("JSON解析失败,尝试修复...") | 243 | + logger.error("JSON解析失败,尝试修复...") |
| 243 | # 尝试修复JSON | 244 | # 尝试修复JSON |
| 244 | fixed_json = fix_incomplete_json(cleaned_output) | 245 | fixed_json = fix_incomplete_json(cleaned_output) |
| 245 | if fixed_json: | 246 | if fixed_json: |
| 246 | try: | 247 | try: |
| 247 | result = json.loads(fixed_json) | 248 | result = json.loads(fixed_json) |
| 248 | - self.log_info("JSON修复成功") | 249 | + logger.info("JSON修复成功") |
| 249 | except JSONDecodeError: | 250 | except JSONDecodeError: |
| 250 | - self.log_error("JSON修复失败") | 251 | + logger.error("JSON修复失败") |
| 251 | # 返回默认查询 | 252 | # 返回默认查询 |
| 252 | return self._get_default_reflection_query() | 253 | return self._get_default_reflection_query() |
| 253 | else: | 254 | else: |
| 254 | - self.log_error("无法修复JSON,使用默认查询") | 255 | + logger.error("无法修复JSON,使用默认查询") |
| 255 | return self._get_default_reflection_query() | 256 | return self._get_default_reflection_query() |
| 256 | 257 | ||
| 257 | # 验证和清理结果 | 258 | # 验证和清理结果 |
| @@ -259,7 +260,7 @@ class ReflectionNode(BaseNode): | @@ -259,7 +260,7 @@ class ReflectionNode(BaseNode): | ||
| 259 | reasoning = result.get("reasoning", "") | 260 | reasoning = result.get("reasoning", "") |
| 260 | 261 | ||
| 261 | if not search_query: | 262 | if not search_query: |
| 262 | - self.log_warning("未找到搜索查询,使用默认查询") | 263 | + logger.warning("未找到搜索查询,使用默认查询") |
| 263 | return self._get_default_reflection_query() | 264 | return self._get_default_reflection_query() |
| 264 | 265 | ||
| 265 | return { | 266 | return { |
| @@ -268,7 +269,7 @@ class ReflectionNode(BaseNode): | @@ -268,7 +269,7 @@ class ReflectionNode(BaseNode): | ||
| 268 | } | 269 | } |
| 269 | 270 | ||
| 270 | except Exception as e: | 271 | except Exception as e: |
| 271 | - self.log_error(f"处理输出失败: {str(e)}") | 272 | + logger.exception(f"处理输出失败: {str(e)}") |
| 272 | # 返回默认查询 | 273 | # 返回默认查询 |
| 273 | return self._get_default_reflection_query() | 274 | return self._get_default_reflection_query() |
| 274 | 275 |
| @@ -6,6 +6,7 @@ | @@ -6,6 +6,7 @@ | ||
| 6 | import json | 6 | import json |
| 7 | from typing import Dict, Any, List | 7 | from typing import Dict, Any, List |
| 8 | from json.decoder import JSONDecodeError | 8 | from json.decoder import JSONDecodeError |
| 9 | +from loguru import logger | ||
| 9 | 10 | ||
| 10 | from .base_node import StateMutationNode | 11 | from .base_node import StateMutationNode |
| 11 | from ..state.state import State | 12 | from ..state.state import State |
| @@ -27,7 +28,7 @@ try: | @@ -27,7 +28,7 @@ try: | ||
| 27 | FORUM_READER_AVAILABLE = True | 28 | FORUM_READER_AVAILABLE = True |
| 28 | except ImportError: | 29 | except ImportError: |
| 29 | FORUM_READER_AVAILABLE = False | 30 | FORUM_READER_AVAILABLE = False |
| 30 | - print("警告: 无法导入forum_reader模块,将跳过HOST发言读取功能") | 31 | + logger.warning("警告: 无法导入forum_reader模块,将跳过HOST发言读取功能") |
| 31 | 32 | ||
| 32 | 33 | ||
| 33 | class FirstSummaryNode(StateMutationNode): | 34 | class FirstSummaryNode(StateMutationNode): |
| @@ -84,9 +85,9 @@ class FirstSummaryNode(StateMutationNode): | @@ -84,9 +85,9 @@ class FirstSummaryNode(StateMutationNode): | ||
| 84 | if host_speech: | 85 | if host_speech: |
| 85 | # 将HOST发言添加到输入数据中 | 86 | # 将HOST发言添加到输入数据中 |
| 86 | data['host_speech'] = host_speech | 87 | data['host_speech'] = host_speech |
| 87 | - self.log_info(f"已读取HOST发言,长度: {len(host_speech)}字符") | 88 | + logger.info(f"已读取HOST发言,长度: {len(host_speech)}字符") |
| 88 | except Exception as e: | 89 | except Exception as e: |
| 89 | - self.log_info(f"读取HOST发言失败: {str(e)}") | 90 | + logger.exception(f"读取HOST发言失败: {str(e)}") |
| 90 | 91 | ||
| 91 | # 转换为JSON字符串 | 92 | # 转换为JSON字符串 |
| 92 | message = json.dumps(data, ensure_ascii=False) | 93 | message = json.dumps(data, ensure_ascii=False) |
| @@ -96,7 +97,7 @@ class FirstSummaryNode(StateMutationNode): | @@ -96,7 +97,7 @@ class FirstSummaryNode(StateMutationNode): | ||
| 96 | formatted_host = format_host_speech_for_prompt(data['host_speech']) | 97 | formatted_host = format_host_speech_for_prompt(data['host_speech']) |
| 97 | message = formatted_host + "\n" + message | 98 | message = formatted_host + "\n" + message |
| 98 | 99 | ||
| 99 | - self.log_info("正在生成首次段落总结") | 100 | + logger.info("正在生成首次段落总结") |
| 100 | 101 | ||
| 101 | # 调用LLM生成总结 | 102 | # 调用LLM生成总结 |
| 102 | response = self.llm_client.invoke( | 103 | response = self.llm_client.invoke( |
| @@ -107,11 +108,11 @@ class FirstSummaryNode(StateMutationNode): | @@ -107,11 +108,11 @@ class FirstSummaryNode(StateMutationNode): | ||
| 107 | # 处理响应 | 108 | # 处理响应 |
| 108 | processed_response = self.process_output(response) | 109 | processed_response = self.process_output(response) |
| 109 | 110 | ||
| 110 | - self.log_info("成功生成首次段落总结") | 111 | + logger.info("成功生成首次段落总结") |
| 111 | return processed_response | 112 | return processed_response |
| 112 | 113 | ||
| 113 | except Exception as e: | 114 | except Exception as e: |
| 114 | - self.log_error(f"生成首次总结失败: {str(e)}") | 115 | + logger.exception(f"生成首次总结失败: {str(e)}") |
| 115 | raise e | 116 | raise e |
| 116 | 117 | ||
| 117 | def process_output(self, output: str) -> str: | 118 | def process_output(self, output: str) -> str: |
| @@ -130,26 +131,26 @@ class FirstSummaryNode(StateMutationNode): | @@ -130,26 +131,26 @@ class FirstSummaryNode(StateMutationNode): | ||
| 130 | cleaned_output = clean_json_tags(cleaned_output) | 131 | cleaned_output = clean_json_tags(cleaned_output) |
| 131 | 132 | ||
| 132 | # 记录清理后的输出用于调试 | 133 | # 记录清理后的输出用于调试 |
| 133 | - self.log_info(f"清理后的输出: {cleaned_output}") | 134 | + logger.info(f"清理后的输出: {cleaned_output}") |
| 134 | 135 | ||
| 135 | # 解析JSON | 136 | # 解析JSON |
| 136 | try: | 137 | try: |
| 137 | result = json.loads(cleaned_output) | 138 | result = json.loads(cleaned_output) |
| 138 | - self.log_info("JSON解析成功") | 139 | + logger.info("JSON解析成功") |
| 139 | except JSONDecodeError as e: | 140 | except JSONDecodeError as e: |
| 140 | - self.log_info(f"JSON解析失败: {str(e)}") | 141 | + logger.exception(f"JSON解析失败: {str(e)}") |
| 141 | # 尝试修复JSON | 142 | # 尝试修复JSON |
| 142 | fixed_json = fix_incomplete_json(cleaned_output) | 143 | fixed_json = fix_incomplete_json(cleaned_output) |
| 143 | if fixed_json: | 144 | if fixed_json: |
| 144 | try: | 145 | try: |
| 145 | result = json.loads(fixed_json) | 146 | result = json.loads(fixed_json) |
| 146 | - self.log_info("JSON修复成功") | 147 | + logger.info("JSON修复成功") |
| 147 | except JSONDecodeError: | 148 | except JSONDecodeError: |
| 148 | - self.log_info("JSON修复失败,直接使用清理后的文本") | 149 | + logger.exception("JSON修复失败,直接使用清理后的文本") |
| 149 | # 如果不是JSON格式,直接返回清理后的文本 | 150 | # 如果不是JSON格式,直接返回清理后的文本 |
| 150 | return cleaned_output | 151 | return cleaned_output |
| 151 | else: | 152 | else: |
| 152 | - self.log_info("无法修复JSON,直接使用清理后的文本") | 153 | + logger.exception("无法修复JSON,直接使用清理后的文本") |
| 153 | # 如果不是JSON格式,直接返回清理后的文本 | 154 | # 如果不是JSON格式,直接返回清理后的文本 |
| 154 | return cleaned_output | 155 | return cleaned_output |
| 155 | 156 | ||
| @@ -163,7 +164,7 @@ class FirstSummaryNode(StateMutationNode): | @@ -163,7 +164,7 @@ class FirstSummaryNode(StateMutationNode): | ||
| 163 | return cleaned_output | 164 | return cleaned_output |
| 164 | 165 | ||
| 165 | except Exception as e: | 166 | except Exception as e: |
| 166 | - self.log_error(f"处理输出失败: {str(e)}") | 167 | + logger.exception(f"处理输出失败: {str(e)}") |
| 167 | return "段落总结生成失败" | 168 | return "段落总结生成失败" |
| 168 | 169 | ||
| 169 | def mutate_state(self, input_data: Any, state: State, paragraph_index: int, **kwargs) -> State: | 170 | def mutate_state(self, input_data: Any, state: State, paragraph_index: int, **kwargs) -> State: |
| @@ -186,7 +187,7 @@ class FirstSummaryNode(StateMutationNode): | @@ -186,7 +187,7 @@ class FirstSummaryNode(StateMutationNode): | ||
| 186 | # 更新状态 | 187 | # 更新状态 |
| 187 | if 0 <= paragraph_index < len(state.paragraphs): | 188 | if 0 <= paragraph_index < len(state.paragraphs): |
| 188 | state.paragraphs[paragraph_index].research.latest_summary = summary | 189 | state.paragraphs[paragraph_index].research.latest_summary = summary |
| 189 | - self.log_info(f"已更新段落 {paragraph_index} 的首次总结") | 190 | + logger.info(f"已更新段落 {paragraph_index} 的首次总结") |
| 190 | else: | 191 | else: |
| 191 | raise ValueError(f"段落索引 {paragraph_index} 超出范围") | 192 | raise ValueError(f"段落索引 {paragraph_index} 超出范围") |
| 192 | 193 | ||
| @@ -194,7 +195,7 @@ class FirstSummaryNode(StateMutationNode): | @@ -194,7 +195,7 @@ class FirstSummaryNode(StateMutationNode): | ||
| 194 | return state | 195 | return state |
| 195 | 196 | ||
| 196 | except Exception as e: | 197 | except Exception as e: |
| 197 | - self.log_error(f"状态更新失败: {str(e)}") | 198 | + logger.exception(f"状态更新失败: {str(e)}") |
| 198 | raise e | 199 | raise e |
| 199 | 200 | ||
| 200 | 201 | ||
| @@ -252,9 +253,9 @@ class ReflectionSummaryNode(StateMutationNode): | @@ -252,9 +253,9 @@ class ReflectionSummaryNode(StateMutationNode): | ||
| 252 | if host_speech: | 253 | if host_speech: |
| 253 | # 将HOST发言添加到输入数据中 | 254 | # 将HOST发言添加到输入数据中 |
| 254 | data['host_speech'] = host_speech | 255 | data['host_speech'] = host_speech |
| 255 | - self.log_info(f"已读取HOST发言,长度: {len(host_speech)}字符") | 256 | + logger.info(f"已读取HOST发言,长度: {len(host_speech)}字符") |
| 256 | except Exception as e: | 257 | except Exception as e: |
| 257 | - self.log_info(f"读取HOST发言失败: {str(e)}") | 258 | + logger.exception(f"读取HOST发言失败: {str(e)}") |
| 258 | 259 | ||
| 259 | # 转换为JSON字符串 | 260 | # 转换为JSON字符串 |
| 260 | message = json.dumps(data, ensure_ascii=False) | 261 | message = json.dumps(data, ensure_ascii=False) |
| @@ -264,7 +265,7 @@ class ReflectionSummaryNode(StateMutationNode): | @@ -264,7 +265,7 @@ class ReflectionSummaryNode(StateMutationNode): | ||
| 264 | formatted_host = format_host_speech_for_prompt(data['host_speech']) | 265 | formatted_host = format_host_speech_for_prompt(data['host_speech']) |
| 265 | message = formatted_host + "\n" + message | 266 | message = formatted_host + "\n" + message |
| 266 | 267 | ||
| 267 | - self.log_info("正在生成反思总结") | 268 | + logger.info("正在生成反思总结") |
| 268 | 269 | ||
| 269 | # 调用LLM生成总结 | 270 | # 调用LLM生成总结 |
| 270 | response = self.llm_client.invoke( | 271 | response = self.llm_client.invoke( |
| @@ -275,11 +276,11 @@ class ReflectionSummaryNode(StateMutationNode): | @@ -275,11 +276,11 @@ class ReflectionSummaryNode(StateMutationNode): | ||
| 275 | # 处理响应 | 276 | # 处理响应 |
| 276 | processed_response = self.process_output(response) | 277 | processed_response = self.process_output(response) |
| 277 | 278 | ||
| 278 | - self.log_info("成功生成反思总结") | 279 | + logger.info("成功生成反思总结") |
| 279 | return processed_response | 280 | return processed_response |
| 280 | 281 | ||
| 281 | except Exception as e: | 282 | except Exception as e: |
| 282 | - self.log_error(f"生成反思总结失败: {str(e)}") | 283 | + logger.exception(f"生成反思总结失败: {str(e)}") |
| 283 | raise e | 284 | raise e |
| 284 | 285 | ||
| 285 | def process_output(self, output: str) -> str: | 286 | def process_output(self, output: str) -> str: |
| @@ -298,26 +299,26 @@ class ReflectionSummaryNode(StateMutationNode): | @@ -298,26 +299,26 @@ class ReflectionSummaryNode(StateMutationNode): | ||
| 298 | cleaned_output = clean_json_tags(cleaned_output) | 299 | cleaned_output = clean_json_tags(cleaned_output) |
| 299 | 300 | ||
| 300 | # 记录清理后的输出用于调试 | 301 | # 记录清理后的输出用于调试 |
| 301 | - self.log_info(f"清理后的输出: {cleaned_output}") | 302 | + logger.info(f"清理后的输出: {cleaned_output}") |
| 302 | 303 | ||
| 303 | # 解析JSON | 304 | # 解析JSON |
| 304 | try: | 305 | try: |
| 305 | result = json.loads(cleaned_output) | 306 | result = json.loads(cleaned_output) |
| 306 | - self.log_info("JSON解析成功") | 307 | + logger.info("JSON解析成功") |
| 307 | except JSONDecodeError as e: | 308 | except JSONDecodeError as e: |
| 308 | - self.log_info(f"JSON解析失败: {str(e)}") | 309 | + logger.exception(f"JSON解析失败: {str(e)}") |
| 309 | # 尝试修复JSON | 310 | # 尝试修复JSON |
| 310 | fixed_json = fix_incomplete_json(cleaned_output) | 311 | fixed_json = fix_incomplete_json(cleaned_output) |
| 311 | if fixed_json: | 312 | if fixed_json: |
| 312 | try: | 313 | try: |
| 313 | result = json.loads(fixed_json) | 314 | result = json.loads(fixed_json) |
| 314 | - self.log_info("JSON修复成功") | 315 | + logger.info("JSON修复成功") |
| 315 | except JSONDecodeError: | 316 | except JSONDecodeError: |
| 316 | - self.log_info("JSON修复失败,直接使用清理后的文本") | 317 | + logger.exception("JSON修复失败,直接使用清理后的文本") |
| 317 | # 如果不是JSON格式,直接返回清理后的文本 | 318 | # 如果不是JSON格式,直接返回清理后的文本 |
| 318 | return cleaned_output | 319 | return cleaned_output |
| 319 | else: | 320 | else: |
| 320 | - self.log_info("无法修复JSON,直接使用清理后的文本") | 321 | + logger.exception("无法修复JSON,直接使用清理后的文本") |
| 321 | # 如果不是JSON格式,直接返回清理后的文本 | 322 | # 如果不是JSON格式,直接返回清理后的文本 |
| 322 | return cleaned_output | 323 | return cleaned_output |
| 323 | 324 | ||
| @@ -331,7 +332,7 @@ class ReflectionSummaryNode(StateMutationNode): | @@ -331,7 +332,7 @@ class ReflectionSummaryNode(StateMutationNode): | ||
| 331 | return cleaned_output | 332 | return cleaned_output |
| 332 | 333 | ||
| 333 | except Exception as e: | 334 | except Exception as e: |
| 334 | - self.log_error(f"处理输出失败: {str(e)}") | 335 | + logger.exception(f"处理输出失败: {str(e)}") |
| 335 | return "反思总结生成失败" | 336 | return "反思总结生成失败" |
| 336 | 337 | ||
| 337 | def mutate_state(self, input_data: Any, state: State, paragraph_index: int, **kwargs) -> State: | 338 | def mutate_state(self, input_data: Any, state: State, paragraph_index: int, **kwargs) -> State: |
| @@ -355,7 +356,7 @@ class ReflectionSummaryNode(StateMutationNode): | @@ -355,7 +356,7 @@ class ReflectionSummaryNode(StateMutationNode): | ||
| 355 | if 0 <= paragraph_index < len(state.paragraphs): | 356 | if 0 <= paragraph_index < len(state.paragraphs): |
| 356 | state.paragraphs[paragraph_index].research.latest_summary = updated_summary | 357 | state.paragraphs[paragraph_index].research.latest_summary = updated_summary |
| 357 | state.paragraphs[paragraph_index].research.increment_reflection() | 358 | state.paragraphs[paragraph_index].research.increment_reflection() |
| 358 | - self.log_info(f"已更新段落 {paragraph_index} 的反思总结") | 359 | + logger.info(f"已更新段落 {paragraph_index} 的反思总结") |
| 359 | else: | 360 | else: |
| 360 | raise ValueError(f"段落索引 {paragraph_index} 超出范围") | 361 | raise ValueError(f"段落索引 {paragraph_index} 超出范围") |
| 361 | 362 | ||
| @@ -363,5 +364,5 @@ class ReflectionSummaryNode(StateMutationNode): | @@ -363,5 +364,5 @@ class ReflectionSummaryNode(StateMutationNode): | ||
| 363 | return state | 364 | return state |
| 364 | 365 | ||
| 365 | except Exception as e: | 366 | except Exception as e: |
| 366 | - self.log_error(f"状态更新失败: {str(e)}") | 367 | + logger.exception(f"状态更新失败: {str(e)}") |
| 367 | raise e | 368 | raise e |
| @@ -12,7 +12,7 @@ from .text_processing import ( | @@ -12,7 +12,7 @@ from .text_processing import ( | ||
| 12 | format_search_results_for_prompt | 12 | format_search_results_for_prompt |
| 13 | ) | 13 | ) |
| 14 | 14 | ||
| 15 | -from .config import Config, load_config | 15 | +from .config import Settings |
| 16 | 16 | ||
| 17 | __all__ = [ | 17 | __all__ = [ |
| 18 | "clean_json_tags", | 18 | "clean_json_tags", |
| @@ -21,6 +21,5 @@ __all__ = [ | @@ -21,6 +21,5 @@ __all__ = [ | ||
| 21 | "extract_clean_response", | 21 | "extract_clean_response", |
| 22 | "update_state_with_search_results", | 22 | "update_state_with_search_results", |
| 23 | "format_search_results_for_prompt", | 23 | "format_search_results_for_prompt", |
| 24 | - "Config", | ||
| 25 | - "load_config" | 24 | + "Settings", |
| 26 | ] | 25 | ] |
| 1 | """ | 1 | """ |
| 2 | -Configuration management module for the Query Engine. | 2 | +Query Engine 配置管理模块 |
| 3 | + | ||
| 4 | +此模块使用 pydantic-settings 管理 Query Engine 的配置,支持从环境变量和 .env 文件自动加载。 | ||
| 5 | +数据模型定义位置: | ||
| 6 | +- 本文件 - 配置模型定义 | ||
| 3 | """ | 7 | """ |
| 4 | 8 | ||
| 5 | -import os | ||
| 6 | -from dataclasses import dataclass | 9 | +from pathlib import Path |
| 10 | +from pydantic_settings import BaseSettings | ||
| 11 | +from pydantic import Field | ||
| 7 | from typing import Optional | 12 | from typing import Optional |
| 8 | - | ||
| 9 | - | ||
| 10 | -def _get_value(source, key: str, default=None, *fallback_keys: str): | ||
| 11 | - candidates = (key,) + fallback_keys | ||
| 12 | - value = None | ||
| 13 | - for candidate in candidates: | ||
| 14 | - if isinstance(source, dict): | ||
| 15 | - value = source.get(candidate) | ||
| 16 | - else: | ||
| 17 | - value = getattr(source, candidate, None) | ||
| 18 | - if value not in (None, ""): | ||
| 19 | - break | ||
| 20 | - if value in (None, ""): | ||
| 21 | - for candidate in candidates: | ||
| 22 | - env_val = os.getenv(candidate) | ||
| 23 | - if env_val not in (None, ""): | ||
| 24 | - value = env_val | ||
| 25 | - break | ||
| 26 | - return value if value not in (None, "") else default | ||
| 27 | - | ||
| 28 | - | ||
| 29 | -@dataclass | ||
| 30 | -class Config: | ||
| 31 | - """Query Engine configuration.""" | ||
| 32 | - | ||
| 33 | - llm_api_key: Optional[str] = None | ||
| 34 | - llm_base_url: Optional[str] = None | ||
| 35 | - llm_model_name: Optional[str] = None | ||
| 36 | - llm_provider: Optional[str] = None # compatibility | ||
| 37 | - | ||
| 38 | - tavily_api_key: Optional[str] = None | ||
| 39 | - | ||
| 40 | - search_timeout: int = 240 | ||
| 41 | - max_content_length: int = 20000 | ||
| 42 | - max_reflections: int = 2 | ||
| 43 | - max_paragraphs: int = 5 | ||
| 44 | - max_search_results: int = 20 | ||
| 45 | - | ||
| 46 | - output_dir: str = "reports" | ||
| 47 | - save_intermediate_states: bool = True | ||
| 48 | - | ||
| 49 | - def __post_init__(self): | ||
| 50 | - if not self.llm_provider and self.llm_model_name: | ||
| 51 | - self.llm_provider = self.llm_model_name | ||
| 52 | - | ||
| 53 | - def validate(self) -> bool: | ||
| 54 | - if not self.llm_api_key: | ||
| 55 | - print("错误: Query Engine LLM API Key 未设置 (QUERY_ENGINE_API_KEY)。") | ||
| 56 | - return False | ||
| 57 | - if not self.llm_model_name: | ||
| 58 | - print("错误: Query Engine 模型名称未设置 (QUERY_ENGINE_MODEL_NAME)。") | ||
| 59 | - return False | ||
| 60 | - if not self.tavily_api_key: | ||
| 61 | - print("错误: Tavily API Key 未设置 (TAVILY_API_KEY)。") | ||
| 62 | - return False | ||
| 63 | - return True | ||
| 64 | - | ||
| 65 | - @classmethod | ||
| 66 | - def from_file(cls, config_file: str) -> "Config": | ||
| 67 | - if config_file.endswith(".py"): | ||
| 68 | - import importlib.util | ||
| 69 | - | ||
| 70 | - spec = importlib.util.spec_from_file_location("config", config_file) | ||
| 71 | - config_module = importlib.util.module_from_spec(spec) | ||
| 72 | - spec.loader.exec_module(config_module) | ||
| 73 | - | ||
| 74 | - return cls( | ||
| 75 | - llm_api_key=_get_value(config_module, "QUERY_ENGINE_API_KEY"), | ||
| 76 | - llm_base_url=_get_value(config_module, "QUERY_ENGINE_BASE_URL"), | ||
| 77 | - llm_model_name=_get_value(config_module, "QUERY_ENGINE_MODEL_NAME"), | ||
| 78 | - tavily_api_key=_get_value(config_module, "TAVILY_API_KEY"), | ||
| 79 | - search_timeout=int(_get_value(config_module, "SEARCH_TIMEOUT", 240)), | ||
| 80 | - max_content_length=int(_get_value(config_module, "SEARCH_CONTENT_MAX_LENGTH", 20000)), | ||
| 81 | - max_reflections=int(_get_value(config_module, "MAX_REFLECTIONS", 2)), | ||
| 82 | - max_paragraphs=int(_get_value(config_module, "MAX_PARAGRAPHS", 5)), | ||
| 83 | - max_search_results=int(_get_value(config_module, "MAX_SEARCH_RESULTS", 20)), | ||
| 84 | - output_dir=_get_value(config_module, "OUTPUT_DIR", "reports"), | ||
| 85 | - save_intermediate_states=str( | ||
| 86 | - _get_value(config_module, "SAVE_INTERMEDIATE_STATES", "true") | ||
| 87 | - ).lower() | ||
| 88 | - in ("true", "1", "yes"), | ||
| 89 | - ) | ||
| 90 | - | ||
| 91 | - config_dict = {} | ||
| 92 | - if os.path.exists(config_file): | ||
| 93 | - with open(config_file, "r", encoding="utf-8") as f: | ||
| 94 | - for line in f: | ||
| 95 | - line = line.strip() | ||
| 96 | - if line and not line.startswith("#") and "=" in line: | ||
| 97 | - key, value = line.split("=", 1) | ||
| 98 | - config_dict[key.strip()] = value.strip() | ||
| 99 | - | ||
| 100 | - return cls( | ||
| 101 | - llm_api_key=_get_value(config_dict, "QUERY_ENGINE_API_KEY"), | ||
| 102 | - llm_base_url=_get_value(config_dict, "QUERY_ENGINE_BASE_URL"), | ||
| 103 | - llm_model_name=_get_value(config_dict, "QUERY_ENGINE_MODEL_NAME"), | ||
| 104 | - tavily_api_key=_get_value(config_dict, "TAVILY_API_KEY"), | ||
| 105 | - search_timeout=int(_get_value(config_dict, "SEARCH_TIMEOUT", 240)), | ||
| 106 | - max_content_length=int(_get_value(config_dict, "SEARCH_CONTENT_MAX_LENGTH", 20000)), | ||
| 107 | - max_reflections=int(_get_value(config_dict, "MAX_REFLECTIONS", 2)), | ||
| 108 | - max_paragraphs=int(_get_value(config_dict, "MAX_PARAGRAPHS", 5)), | ||
| 109 | - max_search_results=int(_get_value(config_dict, "MAX_SEARCH_RESULTS", 20)), | ||
| 110 | - output_dir=_get_value(config_dict, "OUTPUT_DIR", "reports"), | ||
| 111 | - save_intermediate_states=str( | ||
| 112 | - _get_value(config_dict, "SAVE_INTERMEDIATE_STATES", "true") | ||
| 113 | - ).lower() | ||
| 114 | - in ("true", "1", "yes"), | ||
| 115 | - ) | ||
| 116 | - | ||
| 117 | - | ||
| 118 | -def load_config(config_file: Optional[str] = None) -> Config: | ||
| 119 | - if config_file: | ||
| 120 | - if not os.path.exists(config_file): | ||
| 121 | - raise FileNotFoundError(f"配置文件不存在: {config_file}") | ||
| 122 | - file_to_load = config_file | ||
| 123 | - else: | ||
| 124 | - for candidate in ("config.py", "config.env", ".env"): | ||
| 125 | - if os.path.exists(candidate): | ||
| 126 | - file_to_load = candidate | ||
| 127 | - print(f"已找到配置文件: {candidate}") | ||
| 128 | - break | ||
| 129 | - else: | ||
| 130 | - raise FileNotFoundError("未找到配置文件,请创建 config.py。") | ||
| 131 | - | ||
| 132 | - config = Config.from_file(file_to_load) | ||
| 133 | - if not config.validate(): | ||
| 134 | - raise ValueError("配置校验失败,请检查 config.py 中的相关配置。") | ||
| 135 | - return config | ||
| 136 | - | ||
| 137 | - | ||
| 138 | -def print_config(config: Config): | ||
| 139 | - print("\n=== Query Engine 配置 ===") | ||
| 140 | - print(f"LLM 模型: {config.llm_model_name}") | ||
| 141 | - print(f"LLM Base URL: {config.llm_base_url or '(默认)'}") | ||
| 142 | - print(f"Tavily API Key: {'已配置' if config.tavily_api_key else '未配置'}") | ||
| 143 | - print(f"搜索超时: {config.search_timeout} 秒") | ||
| 144 | - print(f"最长内容长度: {config.max_content_length}") | ||
| 145 | - print(f"最大反思次数: {config.max_reflections}") | ||
| 146 | - print(f"最大段落数: {config.max_paragraphs}") | ||
| 147 | - print(f"最大搜索结果数: {config.max_search_results}") | ||
| 148 | - print(f"输出目录: {config.output_dir}") | ||
| 149 | - print(f"保存中间状态: {config.save_intermediate_states}") | ||
| 150 | - print(f"LLM API Key: {'已配置' if config.llm_api_key else '未配置'}") | ||
| 151 | - print("========================\n") | 13 | +from loguru import logger |
| 14 | + | ||
| 15 | + | ||
| 16 | +# 计算 .env 优先级:优先当前工作目录,其次项目根目录 | ||
| 17 | +PROJECT_ROOT: Path = Path(__file__).resolve().parents[2] | ||
| 18 | +CWD_ENV: Path = Path.cwd() / ".env" | ||
| 19 | +ENV_FILE: str = str(CWD_ENV if CWD_ENV.exists() else (PROJECT_ROOT / ".env")) | ||
| 20 | + | ||
| 21 | + | ||
| 22 | +class Settings(BaseSettings): | ||
| 23 | + """ | ||
| 24 | + Query Engine 全局配置;支持 .env 和环境变量自动加载。 | ||
| 25 | + 变量名与原 config.py 大写一致,便于平滑过渡。 | ||
| 26 | + """ | ||
| 27 | + | ||
| 28 | + # ======================= LLM 相关 ======================= | ||
| 29 | + QUERY_ENGINE_API_KEY: str = Field(..., description="Query Engine LLM API密钥,用于主LLM。您可以更改每个部分LLM使用的API,🚩只要兼容OpenAI请求格式都可以,定义好KEY、BASE_URL与MODEL_NAME即可正常使用。") | ||
| 30 | + QUERY_ENGINE_BASE_URL: Optional[str] = Field(None, description="Query Engine LLM接口BaseUrl,可自定义厂商API") | ||
| 31 | + QUERY_ENGINE_MODEL_NAME: str = Field(..., description="Query Engine LLM模型名称") | ||
| 32 | + QUERY_ENGINE_PROVIDER: Optional[str] = Field(None, description="Query Engine LLM提供商(兼容字段)") | ||
| 33 | + | ||
| 34 | + # ================== 网络工具配置 ==================== | ||
| 35 | + TAVILY_API_KEY: str = Field(..., description="Tavily API(申请地址:https://www.tavily.com/)API密钥,用于Tavily网络搜索") | ||
| 36 | + | ||
| 37 | + # ================== 搜索参数配置 ==================== | ||
| 38 | + SEARCH_TIMEOUT: int = Field(240, description="搜索超时(秒)") | ||
| 39 | + SEARCH_CONTENT_MAX_LENGTH: int = Field(20000, description="用于提示的最长内容长度") | ||
| 40 | + MAX_REFLECTIONS: int = Field(2, description="最大反思轮数") | ||
| 41 | + MAX_PARAGRAPHS: int = Field(5, description="最大段落数") | ||
| 42 | + MAX_SEARCH_RESULTS: int = Field(20, description="最大搜索结果数") | ||
| 43 | + | ||
| 44 | + # ================== 输出配置 ==================== | ||
| 45 | + OUTPUT_DIR: str = Field("reports", description="输出目录") | ||
| 46 | + SAVE_INTERMEDIATE_STATES: bool = Field(True, description="是否保存中间状态") | ||
| 47 | + | ||
| 48 | + class Config: | ||
| 49 | + env_file = ENV_FILE | ||
| 50 | + env_prefix = "" | ||
| 51 | + case_sensitive = False | ||
| 52 | + extra = "allow" | ||
| 53 | + | ||
| 54 | + | ||
| 55 | +# 创建全局配置实例 | ||
| 56 | +settings = Settings() | ||
| 57 | + | ||
| 58 | +def print_config(config: Settings): | ||
| 59 | + """ | ||
| 60 | + 打印配置信息 | ||
| 61 | + | ||
| 62 | + Args: | ||
| 63 | + config: Settings配置对象 | ||
| 64 | + """ | ||
| 65 | + message = "" | ||
| 66 | + message += "=== Query Engine 配置 ===\n" | ||
| 67 | + message += f"LLM 模型: {config.QUERY_ENGINE_MODEL_NAME}\n" | ||
| 68 | + message += f"LLM Base URL: {config.QUERY_ENGINE_BASE_URL or '(默认)'}\n" | ||
| 69 | + message += f"Tavily API Key: {'已配置' if config.TAVILY_API_KEY else '未配置'}\n" | ||
| 70 | + message += f"搜索超时: {config.SEARCH_TIMEOUT} 秒\n" | ||
| 71 | + message += f"最长内容长度: {config.SEARCH_CONTENT_MAX_LENGTH}\n" | ||
| 72 | + message += f"最大反思次数: {config.MAX_REFLECTIONS}\n" | ||
| 73 | + message += f"最大段落数: {config.MAX_PARAGRAPHS}\n" | ||
| 74 | + message += f"最大搜索结果数: {config.MAX_SEARCH_RESULTS}\n" | ||
| 75 | + message += f"输出目录: {config.OUTPUT_DIR}\n" | ||
| 76 | + message += f"保存中间状态: {config.SAVE_INTERMEDIATE_STATES}\n" | ||
| 77 | + message += f"LLM API Key: {'已配置' if config.QUERY_ENGINE_API_KEY else '未配置'}\n" | ||
| 78 | + message += "========================\n" | ||
| 79 | + logger.info(message) |
-
Please register or log in to post a comment