base_node.py
2.77 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
"""
Report Engine节点基类。
所有高阶推理节点都继承于此,统一日志、输入校验与状态变更接口。
"""
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional
from ..llms.base import LLMClient
from ..state.state import ReportState
from loguru import logger
class BaseNode(ABC):
"""
节点基类。
统一实现日志工具、输入/输出钩子以及LLM客户端依赖注入,
便于所有节点只专注业务逻辑。
"""
def __init__(self, llm_client: LLMClient, node_name: str = ""):
"""
初始化节点
Args:
llm_client: LLM客户端
node_name: 节点名称
BaseNode 会保存节点名以便统一输出日志前缀。
"""
self.llm_client = llm_client
self.node_name = node_name or self.__class__.__name__
@abstractmethod
def run(self, input_data: Any, **kwargs) -> Any:
"""
执行节点处理逻辑
Args:
input_data: 输入数据
**kwargs: 额外参数
Returns:
处理结果
"""
pass
def validate_input(self, input_data: Any) -> bool:
"""
验证输入数据。
默认直接通过,子类可按需覆写实现字段检查。
Args:
input_data: 输入数据
Returns:
验证是否通过
"""
return True
def process_output(self, output: Any) -> Any:
"""
处理输出数据。
子类可覆写进行结构化或校验。
Args:
output: 原始输出
Returns:
处理后的输出
"""
return output
def log_info(self, message: str):
"""记录信息日志,并自动带上节点名作为前缀。"""
formatted_message = f"[{self.node_name}] {message}"
logger.info(formatted_message)
def log_error(self, message: str):
"""记录错误日志,便于排障。"""
formatted_message = f"[{self.node_name}] {message}"
logger.error(formatted_message)
class StateMutationNode(BaseNode):
"""
带状态修改功能的节点基类。
适用于节点需要直接写入 ReportState 的场景。
"""
@abstractmethod
def mutate_state(self, input_data: Any, state: ReportState, **kwargs) -> ReportState:
"""
修改状态。
子类需返回新的状态对象或在原地修改后回传,供流水线记录。
Args:
input_data: 输入数据
state: 当前状态
**kwargs: 额外参数
Returns:
修改后的状态
"""
pass