马一丁

GraphRAG

... ... @@ -75,4 +75,12 @@ ANSPIRE_API_KEY=
# Bocha AI Search API(用于Bocha多模态搜索,这里密钥名称虽然是Web Search,但其实是要AI Search的,申请地址:https://open.bochaai.com/)
BOCHA_BASE_URL=https://api.bocha.cn/v1/ai-search
BOCHA_WEB_SEARCH_API_KEY=
\ No newline at end of file
BOCHA_WEB_SEARCH_API_KEY=
# ================== GraphRAG 配置 ====================
# GraphRAG 功能开关(true/false),默认关闭
# 开启后会构建知识图谱并在章节生成前进行图谱查询
GRAPHRAG_ENABLED=false
# GraphRAG 查询次数上限(每个章节生成前LLM可查询知识图谱的最大次数)
# 仅在 GRAPHRAG_ENABLED=true 时生效
GRAPHRAG_MAX_QUERIES=3
... ...
... ... @@ -39,6 +39,21 @@ from .renderers import HTMLRenderer
from .state import ReportState
from .utils.config import settings, Settings
# GraphRAG 模块导入
from .graphrag import (
StateParser,
ForumParser,
GraphBuilder,
GraphStorage,
Graph,
QueryEngine,
)
from .nodes import GraphRAGQueryNode
from .graphrag.prompts import (
SYSTEM_PROMPT_CHAPTER_GRAPH_ENHANCEMENT,
format_graph_results_for_prompt
)
class StageOutputFormatError(ValueError):
    """Controlled exception raised when a stage's output structure does not match expectations."""
... ... @@ -559,6 +574,37 @@ class ReportAgent:
self._persist_planning_artifacts(run_dir, layout_design, word_plan, template_overview)
emit('stage', {'stage': 'storage_ready', 'run_dir': str(run_dir)})
# ==================== GraphRAG 初始化 ====================
graphrag_enabled = getattr(self.config, 'GRAPHRAG_ENABLED', False)
knowledge_graph = None
graphrag_query_node = None
if graphrag_enabled:
logger.info("GraphRAG 已启用,开始构建知识图谱...")
emit('stage', {'stage': 'graphrag_building', 'message': '正在构建知识图谱'})
try:
knowledge_graph = self._build_knowledge_graph(
query, normalized_reports, forum_logs, run_dir
)
if knowledge_graph:
graphrag_query_node = GraphRAGQueryNode(self.llm_client)
graph_stats = knowledge_graph.get_stats()
emit('stage', {
'stage': 'graphrag_built',
'node_count': graph_stats.get('total_nodes', 0),
'edge_count': graph_stats.get('total_edges', 0)
})
logger.info(f"知识图谱构建完成: {graph_stats}")
else:
logger.warning("知识图谱构建失败,将使用原始流程")
graphrag_enabled = False
except Exception as graph_error:
logger.exception(f"GraphRAG 构建异常: {graph_error}")
graphrag_enabled = False
emit('stage', {'stage': 'graphrag_error', 'error': str(graph_error)})
# ==================== GraphRAG 初始化结束 ====================
chapters = []
chapter_max_attempts = max(
self._CONTENT_SPARSE_MIN_ATTEMPTS, self.config.CHAPTER_JSON_MAX_ATTEMPTS
... ... @@ -594,11 +640,47 @@ class ReportAgent:
best_sparse_candidate: Dict[str, Any] | None = None
best_sparse_score = -1
fallback_used = False
# ==================== GraphRAG 查询 ====================
graph_results = None
chapter_context = generation_context.copy()
if graphrag_enabled and knowledge_graph and graphrag_query_node:
try:
max_queries = getattr(self.config, 'GRAPHRAG_MAX_QUERIES', 3)
section_info = {
'title': section.title,
'id': section.chapter_id,
'role': section.description,
'target_words': chapter_targets.get(section.chapter_id, {}).get('targetWords', 500),
'emphasis': chapter_targets.get(section.chapter_id, {}).get('emphasisPoints', '')
}
graph_results = graphrag_query_node.run(
section_info,
{
'query': query,
'template_name': template_result.get('template_name'),
'chapters': word_plan.get('chapters', [])
},
knowledge_graph,
max_queries=max_queries
)
if graph_results and graph_results.get('total_nodes', 0) > 0:
# 将图谱结果注入生成上下文
chapter_context['graph_results'] = graph_results
chapter_context['graph_enhancement_prompt'] = format_graph_results_for_prompt(graph_results)
logger.info(f"章节 {section.title} GraphRAG 查询完成: {graph_results.get('total_nodes', 0)} 节点")
except Exception as graph_query_error:
logger.warning(f"GraphRAG 查询失败 ({section.title}): {graph_query_error}")
# ==================== GraphRAG 查询结束 ====================
while attempt <= chapter_max_attempts:
try:
chapter_payload = self.chapter_generation_node.run(
section,
generation_context,
chapter_context, # 使用包含图谱结果的上下文
run_dir,
stream_callback=chunk_callback
)
... ... @@ -796,6 +878,62 @@ class ReportAgent:
self.state.metadata.template_used = fallback_template['template_name']
return fallback_template
def _build_knowledge_graph(
    self,
    query: str,
    reports: Dict[str, str],
    forum_logs: str,
    run_dir: Path
) -> Optional[Graph]:
    """
    Build the knowledge graph for GraphRAG.

    Extracts structured data from the State JSON parsed during
    load_input_files and from the forum logs, builds the graph, and
    persists it under the run directory so later chapter generation can
    query it.

    Args:
        query: User query topic.
        reports: Normalized report mapping (currently unused here; kept
            for interface stability with the caller).
        forum_logs: Raw forum log content.
        run_dir: Run directory used to save the graph.

    Returns:
        The constructed Graph, or None on any failure (best-effort:
        the caller falls back to the non-GraphRAG flow).
    """
    try:
        # Collect State JSON parsed earlier in load_input_files.
        # _loaded_states is populated lazily there, so it may be absent.
        loaded_states = getattr(self, '_loaded_states', {})
        states = {
            engine: loaded_states[engine]
            for engine in ('insight', 'media', 'query')
            if engine in loaded_states
        }

        # Parse the forum logs into structured entries.
        forum_entries = []
        if forum_logs:
            forum_entries = ForumParser().parse(forum_logs)
            logger.info(f"解析论坛日志: {len(forum_entries)} 条记录")

        # Build the graph from the structured inputs (no LLM involved).
        graph = GraphBuilder().build(query, states, forum_entries)

        # Persist the graph alongside the run artifacts.
        graph_path = GraphStorage().save(graph, self.state.task_id, run_dir)
        logger.info(f"知识图谱已保存: {graph_path}")

        return graph
    except Exception as e:
        # Graph construction is optional; never let it break the run.
        logger.exception(f"构建知识图谱失败: {e}")
        return None
def _slice_template(self, template_markdown: str) -> List[TemplateSection]:
"""
将模板切成章节列表,若为空则提供fallback。
... ... @@ -1459,20 +1597,23 @@ class ReportAgent:
def load_input_files(self, file_paths: Dict[str, str]) -> Dict[str, Any]:
"""
加载输入文件内容
Args:
file_paths: 文件路径字典
Returns:
加载的内容字典,包含 `reports` 列表与 `forum_logs` 字符串
加载的内容字典,包含 `reports` 列表、`forum_logs` 字符串和 `states` 字典
"""
content = {
'reports': [],
'forum_logs': ''
'forum_logs': '',
'states': {} # 新增:用于 GraphRAG 的 State JSON
}
# 加载报告文件
engines = ['query', 'media', 'insight']
state_parser = StateParser()
for engine in engines:
if engine in file_paths:
try:
... ... @@ -1480,10 +1621,24 @@ class ReportAgent:
report_content = f.read()
content['reports'].append(report_content)
logger.info(f"已加载 {engine} 报告: {len(report_content)} 字符")
# 新增:尝试查找并加载对应的 State JSON(用于 GraphRAG)
if self.config.GRAPHRAG_ENABLED:
state_path = state_parser.find_state_json(file_paths[engine])
if state_path:
parsed_state = state_parser.parse_from_file(engine, state_path)
if parsed_state:
content['states'][engine] = parsed_state
# 同时保存到实例属性,供 _build_knowledge_graph 使用
if not hasattr(self, '_loaded_states'):
self._loaded_states = {}
self._loaded_states[engine] = parsed_state
logger.info(f"已加载 {engine} State JSON: {len(parsed_state.sections)} 个段落")
except Exception as e:
logger.exception(f"加载 {engine} 报告失败: {str(e)}")
content['reports'].append("")
# 加载论坛日志
if 'forum' in file_paths:
try:
... ... @@ -1492,7 +1647,7 @@ class ReportAgent:
logger.info(f"已加载论坛日志: {len(content['forum_logs'])} 字符")
except Exception as e:
logger.exception(f"加载论坛日志失败: {str(e)}")
return content
... ...
"""
GraphRAG 知识图谱模块
提供基于结构化数据的知识图谱构建、存储与查询功能。
"""
from .state_parser import StateParser, ParsedState, ParsedSection, SearchRecord
from .forum_parser import ForumParser, ForumEntry
from .graph_builder import GraphBuilder
from .graph_storage import GraphStorage, Graph, Node, Edge
from .query_engine import QueryEngine, QueryParams, QueryResult
__all__ = [
# 解析器
'StateParser',
'ParsedState',
'ParsedSection',
'SearchRecord',
'ForumParser',
'ForumEntry',
# 图谱核心
'GraphBuilder',
'GraphStorage',
'Graph',
'Node',
'Edge',
# 查询引擎
'QueryEngine',
'QueryParams',
'QueryResult',
]
... ...
"""
Forum 日志解析器
解析 forum.log 文件,提取结构化的讨论记录用于构建知识图谱。
"""
from dataclasses import dataclass
from typing import List, Optional
import re
@dataclass
class ForumEntry:
"""论坛讨论条目"""
timestamp: str
speaker: str
content: str
@property
def is_host(self) -> bool:
"""是否为主持人发言"""
return self.speaker.upper() == 'HOST'
@property
def is_system(self) -> bool:
"""是否为系统消息"""
return self.speaker.upper() == 'SYSTEM'
@property
def engine_name(self) -> Optional[str]:
"""获取对应的引擎名称(小写)"""
speaker_upper = self.speaker.upper()
if speaker_upper in ['INSIGHT', 'MEDIA', 'QUERY', 'HOST']:
return speaker_upper.lower()
return None
class ForumParser:
"""
Forum 日志解析器
解析 forum.log,提取结构化的讨论记录。
日志格式: [HH:MM:SS] [SPEAKER] content
"""
# 匹配日志行的正则表达式
PATTERN = re.compile(r'\[(\d{2}:\d{2}:\d{2})\]\s*\[(\w+)\]\s*(.+)')
# 有效的发言者
VALID_SPEAKERS = {'INSIGHT', 'MEDIA', 'QUERY', 'HOST', 'SYSTEM'}
def parse(self, forum_logs: str) -> List[ForumEntry]:
"""
解析 forum.log 内容
Args:
forum_logs: forum.log 文件内容
Returns:
ForumEntry 列表
"""
if not forum_logs:
return []
entries = []
for line in forum_logs.strip().split('\n'):
if not line.strip():
continue
match = self.PATTERN.match(line)
if match:
timestamp, speaker, content = match.groups()
speaker_upper = speaker.upper()
if speaker_upper in self.VALID_SPEAKERS:
# 处理转义的换行符
content = content.replace('\\n', '\n')
entries.append(ForumEntry(
timestamp=timestamp,
speaker=speaker_upper,
content=content
))
return entries
def get_host_insights(self, entries: List[ForumEntry]) -> List[str]:
"""
提取 Host(主持人)的发言内容
Args:
entries: ForumEntry 列表
Returns:
Host 发言内容列表
"""
return [e.content for e in entries if e.is_host]
def get_engine_entries(self, entries: List[ForumEntry],
engine: str) -> List[ForumEntry]:
"""
获取指定引擎的发言
Args:
entries: ForumEntry 列表
engine: 引擎名称 (insight/media/query/host)
Returns:
该引擎的 ForumEntry 列表
"""
engine_upper = engine.upper()
return [e for e in entries if e.speaker == engine_upper]
def get_summary_by_engine(self, entries: List[ForumEntry]) -> dict:
"""
按引擎分组统计发言
Args:
entries: ForumEntry 列表
Returns:
{engine: [contents]} 字典
"""
result = {
'insight': [],
'media': [],
'query': [],
'host': []
}
for entry in entries:
engine = entry.engine_name
if engine and engine in result:
result[engine].append(entry.content)
return result
def extract_key_points(self, entries: List[ForumEntry],
max_points: int = 10) -> List[str]:
"""
提取关键观点(优先 Host 发言)
Args:
entries: ForumEntry 列表
max_points: 最大提取数量
Returns:
关键观点列表
"""
key_points = []
# 优先提取 Host 的发言
for entry in entries:
if entry.is_host and not entry.is_system:
# 提取前 200 字作为摘要
summary = entry.content[:200]
if len(entry.content) > 200:
summary += '...'
key_points.append(f"[{entry.speaker}] {summary}")
if len(key_points) >= max_points:
break
return key_points
... ...
"""
知识图谱构建器
基于结构化的 State JSON 和 Forum 日志构建知识图谱,无需 LLM 提取实体。
"""
from typing import Dict, List, Optional
import hashlib
from .state_parser import ParsedState, ParsedSection
from .forum_parser import ForumEntry
from .graph_storage import Graph, Node
class GraphBuilder:
    """
    Knowledge-graph builder.

    Builds the graph directly from already-structured inputs (State JSON,
    Forum log); no LLM entity/relation extraction is involved.

    Node types (5):
    - topic: the user's query topic
    - engine: the four engine sources (insight/media/query/host)
    - section: report paragraphs/chapters
    - search_query: search keywords an engine executed
    - source: information-source URLs

    Relation types (4):
    - analyzed_by: Topic -> Engine
    - contains:    Engine -> Section
    - searched:    Section -> SearchQuery
    - found:       SearchQuery -> Source
    """

    def build(self, topic: str, states: Dict[str, ParsedState],
              forum_entries: Optional[List[ForumEntry]] = None) -> Graph:
        """
        Build the knowledge graph.

        Args:
            topic: User query topic.
            states: Mapping of engine name to its ParsedState.
            forum_entries: Optional list of forum-log entries.

        Returns:
            The constructed Graph.
        """
        graph = Graph()
        # Root node representing the query topic.
        root = graph.add_node(
            node_type="topic",
            name=topic,
            node_id=f"T_{self._hash(topic)}"
        )
        # One sub-tree per engine state.
        for engine_name, parsed in states.items():
            self._add_engine_nodes(graph, root, engine_name, parsed)
        # Forum log contributes a virtual "host" engine.
        if forum_entries:
            self._add_forum_nodes(graph, root, forum_entries)
        return graph

    def _add_engine_nodes(self, graph: Graph, topic_node: Node,
                          engine_name: str, state: ParsedState) -> None:
        """Add one engine node and all of its section sub-trees."""
        engine_node = graph.add_node(
            node_type="engine",
            name=engine_name,
            node_id=engine_name,
            report_title=state.report_title,
            original_query=state.query
        )
        # Topic -> Engine
        graph.add_edge(topic_node, engine_node, "analyzed_by")
        for sec in state.sections:
            self._add_section_nodes(graph, engine_node, engine_name, sec)

    def _add_section_nodes(self, graph: Graph, engine_node: Node,
                           engine_name: str, section: ParsedSection) -> None:
        """Add a section node plus its search-query and source sub-nodes."""
        sec_id = f"{engine_name}_S{section.order}"
        sec_node = graph.add_node(
            node_type="section",
            name=section.title,
            node_id=sec_id,
            title=section.title,
            order=section.order,
            summary=section.summary,
            engine=engine_name
        )
        # Engine -> Section
        graph.add_edge(engine_node, sec_node, "contains")

        seen = set()  # case-insensitive de-duplication of query texts
        for pos, record in enumerate(section.search_history):
            if not record.query:
                continue
            dedup_key = record.query.strip().lower()
            if dedup_key in seen:
                continue
            seen.add(dedup_key)
            # NOTE: pos indexes the full history, so IDs stay stable even
            # when duplicate entries are skipped.
            q_node = graph.add_node(
                node_type="search_query",
                name=record.query[:50],  # truncate very long queries
                node_id=f"{sec_id}_Q{pos}",
                query_text=record.query,
                section_ref=sec_id,
                engine=engine_name
            )
            # Section -> SearchQuery
            graph.add_edge(sec_node, q_node, "searched")
            if record.url:
                self._add_source_node(graph, q_node, record)

    def _add_source_node(self, graph: Graph, query_node: Node,
                         search) -> None:
        """Add (or reuse) the source node for a search record's URL and link it."""
        # URL hash as ID so the same source is shared across queries.
        src_id = f"SRC_{self._hash(search.url)}"
        src_node = graph.get_node(src_id)
        if src_node is None:
            src_node = graph.add_node(
                node_type="source",
                name=search.title[:50] if search.title else search.url[:50],
                node_id=src_id,
                url=search.url,
                title=search.title,
                preview=(search.content or '')[:100],
                score=search.score
            )
        # SearchQuery -> Source (linked even when the node already existed).
        graph.add_edge(query_node, src_node, "found")

    def _add_forum_nodes(self, graph: Graph, topic_node: Node,
                         entries: List[ForumEntry]) -> None:
        """Turn host statements from the forum log into a 'host' engine sub-tree."""
        host = graph.get_node('host')
        if not host:
            host = graph.add_node(
                node_type="engine",
                name="host",
                node_id="host",
                report_title="论坛主持人总结"
            )
            graph.add_edge(topic_node, host, "analyzed_by")
        # Keep at most five host statements as pseudo-sections.
        host_said = [item for item in entries if item.is_host and not item.is_system]
        for pos, entry in enumerate(host_said[:5]):
            sec = graph.add_node(
                node_type="section",
                name=f"主持人总结 {pos + 1}",
                node_id=f"host_S{pos}",
                title=f"[{entry.timestamp}] 主持人总结",
                order=pos,
                summary=entry.content[:300],
                engine="host",
                timestamp=entry.timestamp
            )
            graph.add_edge(host, sec, "contains")

    @staticmethod
    def _hash(text: str) -> str:
        """Short md5-based hash used to build stable node IDs."""
        return hashlib.md5(text.encode()).hexdigest()[:8]
... ...
"""
知识图谱存储模块
定义图谱的核心数据结构(Node、Edge、Graph)及 JSON 存储功能。
"""
from dataclasses import dataclass, field
from typing import Dict, Any, List, Optional, Set
from datetime import datetime
import json
from pathlib import Path
import hashlib
@dataclass
class Node:
    """Graph node."""
    id: str
    type: str  # topic, engine, section, search_query, source
    name: str = ""
    attributes: Dict[str, Any] = field(default_factory=dict)

    @property
    def label(self) -> str:
        """Display label (frontend-compatible alias of name)."""
        return self.name

    @property
    def properties(self) -> Dict[str, Any]:
        """Attributes (frontend-compatible alias)."""
        return self.attributes

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict, including frontend compat aliases."""
        return {
            'id': self.id,
            'type': self.type,
            'name': self.name,
            'label': self.name,  # compat field
            'attributes': self.attributes,
            'properties': self.attributes  # compat field
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'Node':
        """Create a Node from a dict, accepting compat aliases."""
        return cls(
            id=data['id'],
            type=data['type'],
            name=data.get('name', data.get('label', '')),
            attributes=data.get('attributes', data.get('properties', {}))
        )

    def get(self, key: str, default: Any = None) -> Any:
        """Uniform accessor over core fields and free-form attributes."""
        if key == 'id':
            return self.id
        if key == 'type':
            return self.type
        if key in ('name', 'label'):
            return self.name
        return self.attributes.get(key, default)


@dataclass
class Edge:
    """Graph edge."""
    from_id: str
    to_id: str
    relation: str  # analyzed_by, contains, searched, found
    weight: float = 1.0
    attributes: Dict[str, Any] = field(default_factory=dict)

    @property
    def source(self) -> str:
        """Start node ID (frontend-compatible alias)."""
        return self.from_id

    @property
    def target(self) -> str:
        """End node ID (frontend-compatible alias)."""
        return self.to_id

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict, including frontend compat aliases."""
        return {
            'from': self.from_id,
            'to': self.to_id,
            'source': self.from_id,  # compat field
            'target': self.to_id,  # compat field
            'relation': self.relation,
            'weight': self.weight,
            'attributes': self.attributes
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'Edge':
        """Create an Edge from a dict, accepting compat aliases."""
        return cls(
            from_id=data.get('from', data.get('source', '')),
            to_id=data.get('to', data.get('target', '')),
            relation=data['relation'],
            weight=data.get('weight', 1.0),
            attributes=data.get('attributes', {})
        )


class Graph:
    """In-memory knowledge graph with JSON (de)serialization."""

    def __init__(self):
        self._nodes: Dict[str, Node] = {}
        self._edges: List[Edge] = []
        self._adjacency: Dict[str, Set[str]] = {}  # undirected adjacency index

    @property
    def nodes(self) -> Dict[str, Node]:
        """All nodes as {node_id: Node} (frontend-compatible API)."""
        return self._nodes

    @property
    def node_list(self) -> List[Node]:
        """All nodes as a list."""
        return list(self._nodes.values())

    @property
    def edges(self) -> List[Edge]:
        """All edges."""
        return self._edges

    @property
    def node_count(self) -> int:
        """Number of nodes."""
        return len(self._nodes)

    @property
    def edge_count(self) -> int:
        """Number of edges."""
        return len(self._edges)

    def add_node(self, node_type: str, name: str = "",
                 node_id: Optional[str] = None, **attributes) -> Node:
        """
        Add a node.

        Args:
            node_type: Node type.
            name: Node name.
            node_id: Node ID; auto-generated from type/name when omitted.
            **attributes: Extra free-form attributes.

        Returns:
            The created node, or the existing node when the ID is taken
            (existing attributes are NOT overwritten).
        """
        if node_id is None:
            # Derive an ID from type, name and current size.
            hash_input = f"{node_type}_{name}_{len(self._nodes)}"
            node_id = f"{node_type[:3].upper()}_{hashlib.md5(hash_input.encode()).hexdigest()[:8]}"
        # Idempotent: reuse an already-registered node.
        if node_id in self._nodes:
            return self._nodes[node_id]
        node = Node(
            id=node_id,
            type=node_type,
            name=name,
            attributes=attributes
        )
        self._nodes[node_id] = node
        self._adjacency[node_id] = set()
        return node

    def get_node(self, node_id: str) -> Optional[Node]:
        """Look up a node by ID."""
        return self._nodes.get(node_id)

    def add_edge(self, from_node: Node, to_node: Node,
                 relation: str, weight: float = 1.0, **attributes) -> Edge:
        """
        Add an edge (duplicates are allowed).

        Args:
            from_node: Start node.
            to_node: End node.
            relation: Relation type.
            weight: Edge weight.
            **attributes: Extra free-form attributes.

        Returns:
            The created edge.
        """
        edge = Edge(
            from_id=from_node.id,
            to_id=to_node.id,
            relation=relation,
            weight=weight,
            attributes=attributes
        )
        self._edges.append(edge)
        # Maintain the (undirected) adjacency index in both directions.
        if from_node.id in self._adjacency:
            self._adjacency[from_node.id].add(to_node.id)
        if to_node.id in self._adjacency:
            self._adjacency[to_node.id].add(from_node.id)
        return edge

    def get_neighbors(self, node_id: str) -> List[Node]:
        """Neighbor nodes (either edge direction)."""
        neighbor_ids = self._adjacency.get(node_id, set())
        return [self._nodes[nid] for nid in neighbor_ids if nid in self._nodes]

    def get_edges_from(self, node_id: str) -> List[Edge]:
        """Edges starting at the given node."""
        return [e for e in self._edges if e.from_id == node_id]

    def get_edges_to(self, node_id: str) -> List[Edge]:
        """Edges ending at the given node."""
        return [e for e in self._edges if e.to_id == node_id]

    def get_nodes_by_type(self, node_type: str) -> List[Node]:
        """Nodes of the given type."""
        return [n for n in self._nodes.values() if n.type == node_type]

    def get_stats(self) -> Dict[str, int]:
        """Graph statistics: totals plus a per-type node count."""
        type_counts = {}
        for node in self._nodes.values():
            type_counts[node.type] = type_counts.get(node.type, 0) + 1
        return {
            'total_nodes': self.node_count,
            'total_edges': self.edge_count,
            **type_counts
        }

    def get_summary(self) -> Dict[str, Any]:
        """Graph overview (used in prompt construction)."""
        stats = self.get_stats()
        # Sample a few nodes of each interesting type.
        section_titles = [n.name for n in self.get_nodes_by_type('section')][:10]
        search_queries = [n.get('query_text', n.name)
                          for n in self.get_nodes_by_type('search_query')][:20]
        return {
            'stats': stats,
            'section_titles': section_titles,
            'sample_queries': search_queries,
            'topic': next((n.name for n in self.get_nodes_by_type('topic')), ''),
            'engines': [n.name for n in self.get_nodes_by_type('engine')]
        }

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict."""
        return {
            'nodes': [n.to_dict() for n in self.node_list],
            'edges': [e.to_dict() for e in self.edges],
            'stats': self.get_stats()
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'Graph':
        """Rebuild a Graph (nodes, edges, adjacency) from a dict."""
        graph = cls()
        # Restore nodes first so adjacency slots exist.
        for node_data in data.get('nodes', []):
            node = Node.from_dict(node_data)
            graph._nodes[node.id] = node
            graph._adjacency[node.id] = set()
        # Restore edges and re-derive the adjacency index.
        for edge_data in data.get('edges', []):
            edge = Edge.from_dict(edge_data)
            graph._edges.append(edge)
            if edge.from_id in graph._adjacency:
                graph._adjacency[edge.from_id].add(edge.to_id)
            if edge.to_id in graph._adjacency:
                graph._adjacency[edge.to_id].add(edge.from_id)
        return graph


class GraphStorage:
    """Graph persistence manager (JSON files under per-run directories)."""

    FILENAME = "graphrag.json"
    DEFAULT_CHAPTERS_DIR = Path("chapters")

    def save(self, graph: Graph, task_id: str, run_dir: Path) -> Path:
        """
        Save a graph to a JSON file inside run_dir.

        Args:
            graph: Graph object.
            task_id: Task ID recorded in the file.
            run_dir: Run directory (created if missing).

        Returns:
            Path of the written file.
        """
        run_dir = Path(run_dir)
        run_dir.mkdir(parents=True, exist_ok=True)
        output = {
            'task_id': task_id,
            'created_at': datetime.now().isoformat(),
            **graph.to_dict()
        }
        file_path = run_dir / self.FILENAME
        with open(file_path, 'w', encoding='utf-8') as f:
            json.dump(output, f, ensure_ascii=False, indent=2)
        return file_path

    def load(self, path: Path) -> Optional[Graph]:
        """
        Load a graph from a JSON file.

        Args:
            path: File path, or a run directory containing graphrag.json.

        Returns:
            Graph object, or None when missing/unreadable (best-effort).
        """
        path = Path(path)
        # Accept a directory and resolve to the canonical filename.
        if path.is_dir():
            file_path = path / self.FILENAME
        else:
            file_path = path
        if not file_path.exists():
            return None
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
            return Graph.from_dict(data)
        except Exception:
            # Deliberately best-effort: a corrupt file just means "no graph".
            return None

    def exists(self, run_dir: Path) -> bool:
        """Whether a graph file exists in run_dir."""
        return (Path(run_dir) / self.FILENAME).exists()

    def find_graph_by_report_id(self, report_id: str) -> Optional[Path]:
        """
        Find a graph file by report ID.

        Args:
            report_id: Report ID (matched as a substring of run-dir names).

        Returns:
            Path of the graph file, or None when not found.
        """
        chapters_dir = self.DEFAULT_CHAPTERS_DIR
        if not chapters_dir.exists():
            return None
        for run_dir in chapters_dir.iterdir():
            if not run_dir.is_dir():
                continue
            if report_id in run_dir.name:
                graph_path = run_dir / self.FILENAME
                if graph_path.exists():
                    return graph_path
        return None

    def find_latest_graph(self) -> Optional[Path]:
        """
        Find the most recently modified graph file.

        Returns:
            Path of the newest graph file, or None when none exist.
        """
        chapters_dir = self.DEFAULT_CHAPTERS_DIR
        if not chapters_dir.exists():
            return None
        latest_path = None
        latest_time = None
        for run_dir in chapters_dir.iterdir():
            if not run_dir.is_dir():
                continue
            graph_path = run_dir / self.FILENAME
            if graph_path.exists():
                mtime = graph_path.stat().st_mtime
                if latest_time is None or mtime > latest_time:
                    latest_time = mtime
                    latest_path = graph_path
        return latest_path

    def list_all_graphs(self) -> List[Dict[str, Any]]:
        """
        List all available graphs.

        Returns:
            Graph descriptors (path, report ID, creation time, stats),
            newest first.
        """
        chapters_dir = self.DEFAULT_CHAPTERS_DIR
        if not chapters_dir.exists():
            return []
        graphs = []
        for run_dir in chapters_dir.iterdir():
            if not run_dir.is_dir():
                continue
            graph_path = run_dir / self.FILENAME
            if graph_path.exists():
                try:
                    with open(graph_path, 'r', encoding='utf-8') as f:
                        data = json.load(f)
                    graphs.append({
                        'path': str(graph_path),
                        'report_id': data.get('task_id', run_dir.name),
                        'created_at': data.get('created_at'),
                        'stats': data.get('stats', {}),
                        'dir_name': run_dir.name
                    })
                except Exception:
                    continue
        # Sort newest first. BUGFIX: created_at may be stored as null, in
        # which case dict.get returns None and sorting None against str
        # raised TypeError; coerce falsy values to '' instead.
        graphs.sort(key=lambda x: x.get('created_at') or '', reverse=True)
        return graphs
... ...
"""
GraphRAG 提示词模块
包含查询决策和章节增强的完整提示词定义。
"""
# ================== Query-decision prompts ==================
# System prompt for the LLM call that decides how to query the knowledge graph.
GRAPHRAG_QUERY_DECISION_SYSTEM = """你是一个智能舆情分析助手,负责决定如何查询知识图谱以获取生成报告章节所需的信息。
知识图谱包含以下节点类型:
- Topic: 用户查询的主题
- Engine: 四个分析引擎(Insight/Media/Query/Host)
- Section: 各引擎报告的段落章节
- SearchQuery: 引擎执行过的搜索关键词
- Source: 搜索发现的信息来源(URL、标题、内容摘要)
你的任务是根据当前章节的需求,决定查询参数以获取最相关的信息。"""
# User prompt template for the query decision; placeholders are filled per
# chapter (chapter metadata, graph overview stats, and prior query history).
GRAPHRAG_QUERY_DECISION_USER = """
=== 当前任务 ===
正在生成报告章节: "{chapter_title}"
章节编号: {chapter_id}
章节在模板中的定位: {chapter_role}
目标字数: {target_words}字
章节要点: {chapter_emphasis}
=== 完整报告规划 ===
报告主题: {report_topic}
模板类型: {template_name}
全书章节概览:
{chapters_overview}
=== 知识图谱概览 ===
图谱统计:
- 主题节点: 1个 ({topic_name})
- 引擎节点: {engine_count}个
- 段落节点: {section_count}个
- 搜索词节点: {query_count}个
- 来源节点: {source_count}个
各引擎段落标题:
{section_titles_by_engine}
搜索关键词样例(前20个):
{sample_search_queries}
=== 查询历史记录(本章节已执行的查询) ===
{query_history_detail}
=== 请决定查询参数 ===
请输出JSON格式的查询参数:
```json
{{
"should_query": true/false,
"keywords": ["关键词1", "关键词2", ...],
"node_types": ["section", "search_query", "source"],
"engine_filter": ["insight", "media", "query", "host"],
"depth": 1-3,
"reasoning": "选择这些参数的原因,以及期望获取什么信息"
}}
```
注意事项:
1. 仔细查看查询历史,**避免重复查询相同或相似的关键词**
2. 关键词应与当前章节主题紧密相关
3. 如果查询历史已经覆盖了章节所需的主要信息,设置 should_query=false
4. depth建议:1=精确匹配,2=包含关联,3=扩展探索(信息量大但可能有噪音)
5. 可以通过 engine_filter 聚焦特定引擎的分析视角
"""
# ================== Chapter-enhancement prompts (used when GraphRAG is enabled) ==================
# System-prompt addendum injected when graph query results are available.
SYSTEM_PROMPT_CHAPTER_GRAPH_ENHANCEMENT = """
=== GraphRAG 知识图谱增强 ===
本次章节生成已通过知识图谱查询获取了跨引擎的关联信息。
在生成内容时,请特别注意:
1. **跨引擎关联**: graphResults 中包含了来自不同引擎的相关信息,
请综合利用这些多视角的分析结果,形成更全面的观点。
2. **信息溯源**: 对于重要观点,可以引用 graphResults.matched_sources
中的来源信息,增强可信度。
3. **搜索词关联**: graphResults.matched_queries 显示了各引擎为本主题
执行的相关搜索,这些搜索词本身就是重要的语义线索。
4. **避免重复**: 不同引擎可能有相似的分析,请整合而非重复。
"""
# Template rendered by format_graph_results_for_prompt() with the matched
# sections/queries/sources and the cross-engine insights.
USER_PROMPT_GRAPH_RESULTS_TEMPLATE = """
=== GraphRAG 知识图谱查询结果 ===
**查询轮次**: {query_rounds}次
**匹配的相关段落** (来自其他引擎的相关分析):
{matched_sections}
**相关搜索关键词** (各引擎执行的相关搜索):
{matched_queries}
**相关信息来源** (搜索发现的相关URL和内容):
{matched_sources}
**跨引擎关联洞察**:
{cross_engine_insights}
请在生成本章节时,充分利用以上知识图谱查询结果,
特别是跨引擎的关联信息,以丰富内容的多维度分析。
===
"""
def format_graph_results_for_prompt(graph_results: dict) -> str:
    """
    Render GraphRAG query results as a prompt section.

    Args:
        graph_results: Query-result dictionary (may be empty or None).

    Returns:
        The formatted prompt text, or "" when there are no results.
    """
    if not graph_results:
        return ""
    return USER_PROMPT_GRAPH_RESULTS_TEMPLATE.format(
        query_rounds=graph_results.get('query_rounds', 0),
        matched_sections=_format_matched_sections(
            graph_results.get('matched_sections', [])),
        matched_queries=_format_matched_queries(
            graph_results.get('matched_queries', [])),
        matched_sources=_format_matched_sources(
            graph_results.get('matched_sources', [])),
        cross_engine_insights=_format_cross_engine_insights(
            graph_results.get('cross_engine_insights', [])),
    )
def _format_matched_sections(sections: list) -> str:
"""格式化匹配的段落"""
if not sections:
return "(无匹配段落)"
lines = []
for s in sections[:10]: # 限制数量
engine = s.get('engine', 'unknown')
title = s.get('title', '未知标题')
summary = s.get('summary', '')[:100]
lines.append(f"- [{engine}] {title}: {summary}...")
return "\n".join(lines)
def _format_matched_queries(queries: list) -> str:
"""格式化匹配的搜索词"""
if not queries:
return "(无匹配搜索词)"
by_engine = {}
for q in queries:
engine = q.get('engine', 'unknown')
if engine not in by_engine:
by_engine[engine] = []
query_text = q.get('query_text', q.get('name', ''))
if query_text and query_text not in by_engine[engine]:
by_engine[engine].append(query_text)
lines = []
for engine, query_list in by_engine.items():
lines.append(f"- {engine}: {', '.join(query_list[:5])}")
return "\n".join(lines)
def _format_matched_sources(sources: list) -> str:
"""格式化匹配的来源"""
if not sources:
return "(无匹配来源)"
lines = []
for s in sources[:8]:
title = s.get('title', '未知标题')
url = s.get('url', '#')
preview = s.get('preview', '')
lines.append(f"- [{title}]({url})")
if preview:
lines.append(f" 摘要: {preview[:80]}...")
return "\n".join(lines)
def _format_cross_engine_insights(insights: list) -> str:
"""格式化跨引擎洞察"""
if not insights:
return "(无跨引擎关联发现)"
return "\n".join([f"- {insight}" for insight in insights[:5]])
... ...
"""
图查询引擎
支持基于关键词、节点类型、引擎来源和深度的知识图谱查询。
"""
from dataclasses import dataclass, field
from typing import Dict, Any, List, Optional, Set
from .graph_storage import Graph, Node
@dataclass
class QueryParams:
    """Parameters for one knowledge-graph query."""
    keywords: List[str] = field(default_factory=list)
    node_types: Optional[List[str]] = None  # None means all node types
    engine_filter: Optional[List[str]] = None  # restrict to these source engines
    depth: int = 1  # expansion depth outward from matched nodes
@dataclass
class QueryResult:
    """Result of a knowledge-graph query."""
    matched_sections: List[Dict[str, Any]] = field(default_factory=list)
    matched_queries: List[Dict[str, Any]] = field(default_factory=list)
    matched_sources: List[Dict[str, Any]] = field(default_factory=list)
    total_nodes: int = 0
    query_params: Optional[Dict[str, Any]] = None

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the result to a plain dictionary."""
        return {
            'matched_sections': self.matched_sections,
            'matched_queries': self.matched_queries,
            'matched_sources': self.matched_sources,
            'total_nodes': self.total_nodes,
            'query_params': self.query_params
        }

    def get_summary(self, max_length: int = 200) -> str:
        """One-line human-readable summary, truncated to max_length characters."""
        pieces = []
        if self.matched_sections:
            heads = [s.get('title', '')[:30] for s in self.matched_sections[:3]]
            pieces.append(f"段落({len(self.matched_sections)}): {', '.join(heads)}")
        if self.matched_queries:
            heads = [q.get('query_text', '')[:20] for q in self.matched_queries[:3]]
            pieces.append(f"搜索词({len(self.matched_queries)}): {', '.join(heads)}")
        if self.matched_sources:
            pieces.append(f"来源({len(self.matched_sources)})")
        text = "; ".join(pieces) if pieces else "无匹配结果"
        return text[:max_length]
class QueryEngine:
"""
图查询引擎
支持以下查询能力:
1. 关键词匹配:在节点名称和属性中搜索
2. 类型筛选:限定节点类型 (section/search_query/source)
3. 引擎筛选:限定来源引擎 (insight/media/query/host)
4. 深度扩展:从匹配节点向外扩展指定深度
"""
def __init__(self, graph: Graph):
"""
初始化查询引擎
Args:
graph: 知识图谱对象
"""
self.graph = graph
def query(self, params: QueryParams) -> QueryResult:
"""
执行图谱查询
Args:
params: 查询参数
Returns:
QueryResult 查询结果
"""
# 1. 关键词匹配获取初始节点
matched_nodes = self._match_keywords(params)
# 2. 深度扩展
if params.depth > 0 and matched_nodes:
expanded_nodes = self._expand_depth(matched_nodes, params.depth)
matched_nodes = matched_nodes.union(expanded_nodes)
# 3. 整理结果
result = self._organize_results(matched_nodes, params)
return result
def _match_keywords(self, params: QueryParams) -> Set[str]:
"""关键词匹配"""
matched_ids = set()
for node in self.graph.nodes:
# 类型筛选
if params.node_types and node.type not in params.node_types:
continue
# 引擎筛选
if params.engine_filter:
node_engine = node.get('engine')
if node_engine and node_engine not in params.engine_filter:
continue
# 关键词匹配
if self._matches_keywords(node, params.keywords):
matched_ids.add(node.id)
return matched_ids
def _matches_keywords(self, node: Node, keywords: List[str]) -> bool:
"""检查节点是否匹配关键词"""
if not keywords:
return True # 无关键词时全部匹配
# 构建搜索文本
search_text = f"{node.name} {node.get('title', '')} {node.get('query_text', '')} {node.get('summary', '')}"
search_text = search_text.lower()
# 任一关键词匹配即可
for keyword in keywords:
if keyword.lower() in search_text:
return True
return False
def _expand_depth(self, node_ids: Set[str], depth: int) -> Set[str]:
"""从匹配节点向外扩展指定深度"""
expanded = set()
current_layer = node_ids.copy()
for _ in range(depth):
next_layer = set()
for node_id in current_layer:
# 获取邻居节点
neighbors = self.graph.get_neighbors(node_id)
for neighbor in neighbors:
if neighbor.id not in node_ids and neighbor.id not in expanded:
next_layer.add(neighbor.id)
expanded.add(neighbor.id)
if not next_layer:
break
current_layer = next_layer
return expanded
def _organize_results(self, node_ids: Set[str],
params: QueryParams) -> QueryResult:
"""整理查询结果"""
matched_sections = []
matched_queries = []
matched_sources = []
for node_id in node_ids:
node = self.graph.get_node(node_id)
if not node:
continue
node_dict = {
'id': node.id,
'name': node.name,
'type': node.type,
**node.attributes
}
if node.type == 'section':
matched_sections.append(node_dict)
elif node.type == 'search_query':
matched_queries.append(node_dict)
elif node.type == 'source':
matched_sources.append(node_dict)
# 排序:段落按 order,其他按名称
matched_sections.sort(key=lambda x: x.get('order', 0))
matched_queries.sort(key=lambda x: x.get('query_text', ''))
matched_sources.sort(key=lambda x: x.get('title', ''))
return QueryResult(
matched_sections=matched_sections,
matched_queries=matched_queries,
matched_sources=matched_sources,
total_nodes=len(node_ids),
query_params={
'keywords': params.keywords,
'node_types': params.node_types,
'engine_filter': params.engine_filter,
'depth': params.depth
}
)
def get_node_summary(self) -> Dict[str, Any]:
"""获取图谱节点概览(用于提示词)"""
return self.graph.get_summary()
def get_section_titles_by_engine(self) -> Dict[str, List[str]]:
    """Group all section titles by the engine that produced them.

    Returns:
        Mapping of engine name -> list of section titles; sections
        without an 'engine' attribute are grouped under 'unknown', and
        a section without a 'title' falls back to its node name.
    """
    result: Dict[str, List[str]] = {}
    for node in self.graph.get_nodes_by_type('section'):
        engine = node.get('engine', 'unknown')
        # setdefault replaces the explicit membership check per node.
        result.setdefault(engine, []).append(node.get('title', node.name))
    return result
def get_sample_search_queries(self, limit: int = 20) -> List[str]:
    """Collect up to *limit* distinct, non-empty search query strings.

    Preserves first-seen order; nodes without a query_text fall back to
    their node name.
    """
    collected: List[str] = []
    for node in self.graph.get_nodes_by_type('search_query'):
        text = node.get('query_text', node.name)
        if not text or text in collected:
            # Skip empty strings and duplicates.
            continue
        collected.append(text)
        if len(collected) >= limit:
            break
    return collected
... ...
"""
State JSON 解析器
解析 Insight/Media/Query 三引擎的 State JSON 文件,
提取结构化数据用于构建知识图谱。
"""
from dataclasses import dataclass, field
from typing import Dict, Any, List, Optional
import json
from pathlib import Path
@dataclass
class SearchRecord:
    """A single search-history entry captured from an engine's state."""
    query: str = ""  # search query string that was issued
    url: str = ""  # result URL
    title: str = ""  # result title
    content: str = ""  # result snippet (truncated to 200 chars by the parser)
    score: Optional[float] = None  # relevance score, if the engine provided one
    timestamp: str = ""  # when the search was recorded
@dataclass
class ParsedSection:
    """A parsed report paragraph/section together with its research trail."""
    title: str = ""  # section heading
    order: int = 0  # position of the section within the report
    summary: str = ""  # section summary (truncated to 300 chars by the parser)
    search_history: List[SearchRecord] = field(default_factory=list)  # searches behind this section
@dataclass
class ParsedState:
    """Parsed state of one analysis engine's run."""
    engine: str = ""  # engine name (insight/media/query)
    query: str = ""  # original user query
    report_title: str = ""  # title of the generated report
    sections: List[ParsedSection] = field(default_factory=list)  # parsed sections
class StateParser:
    """Parser for the three engines' (insight/media/query) State JSON files.

    Extracts the structured pieces — sections, search history, summaries —
    that the knowledge-graph builder needs.
    """

    def parse(self, engine_name: str, state_json: Dict[str, Any]) -> ParsedState:
        """Parse one engine's State JSON dict into a ParsedState.

        Args:
            engine_name: Engine name (insight/media/query).
            state_json: The State JSON as a dict.

        Returns:
            A ParsedState describing the engine run.
        """
        paragraphs = state_json.get('paragraphs', [])
        return ParsedState(
            engine=engine_name,
            query=state_json.get('query', ''),
            report_title=state_json.get('report_title', ''),
            sections=[self._parse_paragraph(item) for item in paragraphs],
        )

    def _parse_paragraph(self, para: Dict[str, Any]) -> ParsedSection:
        """Parse one paragraph entry, truncating long text fields."""
        research = para.get('research', {})
        records = []
        for entry in research.get('search_history', []):
            content = entry.get('content')
            records.append(SearchRecord(
                query=entry.get('query', ''),
                url=entry.get('url', ''),
                title=entry.get('title', ''),
                # Keep only a short snippet of the page content.
                content=content[:200] if content else '',
                score=entry.get('score'),
                timestamp=entry.get('timestamp', ''),
            ))
        # Prefer the latest research summary; fall back to the raw content.
        summary = research.get('latest_summary', '') or para.get('content', '')
        return ParsedSection(
            title=para.get('title', ''),
            order=para.get('order', 0),
            summary=summary[:300] if summary else '',
            search_history=records,
        )

    def parse_from_file(self, engine_name: str, file_path: str) -> Optional[ParsedState]:
        """Parse a State JSON file from disk.

        Args:
            engine_name: Engine name.
            file_path: Path to the JSON file.

        Returns:
            A ParsedState, or None on any failure (missing/corrupt file) —
            deliberately best-effort.
        """
        try:
            path = Path(file_path)
            if not path.exists():
                return None
            with open(path, 'r', encoding='utf-8') as handle:
                data = json.load(handle)
            return self.parse(engine_name, data)
        except Exception:
            return None

    def find_state_json(self, md_path: str) -> Optional[str]:
        """Locate the State JSON matching a Markdown report.

        State files live in the same directory as the MD file and are named
        state_*.json. A file whose stem matches the report's stem wins;
        otherwise the most recently modified state file is returned.

        Args:
            md_path: Path to the Markdown report.

        Returns:
            The state file path, or None when nothing is found.
        """
        md_file = Path(md_path)
        if not md_file.exists():
            return None
        candidates = list(md_file.parent.glob('state_*.json'))
        if not candidates:
            return None
        md_stem = md_file.stem  # e.g., "武汉大学_20250825_180214"
        for candidate in candidates:
            cand_stem = candidate.stem  # e.g., "state_武汉大学_20250825_180214"
            # Accept either a stem containment or an exact match without the prefix.
            if md_stem in cand_stem or cand_stem.replace('state_', '') == md_stem:
                return str(candidate)
        # No stem match: fall back to the newest state file.
        candidates.sort(key=lambda f: f.stat().st_mtime, reverse=True)
        return str(candidates[0])
... ...
... ... @@ -14,6 +14,7 @@ from .chapter_generation_node import (
)
from .document_layout_node import DocumentLayoutNode
from .word_budget_node import WordBudgetNode
from .graphrag_query_node import GraphRAGQueryNode, QueryHistory
__all__ = [
"BaseNode",
... ... @@ -25,4 +26,6 @@ __all__ = [
"ChapterValidationError",
"DocumentLayoutNode",
"WordBudgetNode",
"GraphRAGQueryNode",
"QueryHistory",
]
... ...
... ... @@ -205,11 +205,15 @@ class ChapterGenerationNode(BaseNode):
llm_payload = self._build_payload(section, context)
user_message = build_chapter_user_prompt(llm_payload)
# 检查是否有GraphRAG结果,决定是否使用增强提示词
graph_enhanced = bool(context.get("graph_results"))
raw_text = self._stream_llm(
user_message,
chapter_dir,
stream_callback=stream_callback,
section_meta=chapter_meta,
graph_enhanced=graph_enhanced,
**kwargs,
)
parse_context: List[str] = []
... ... @@ -351,6 +355,22 @@ class ChapterGenerationNode(BaseNode):
"chapterPlan": chapter_plan,
"wordPlan": context.get("word_plan"),
}
# GraphRAG 增强:如果上下文中包含图谱查询结果,添加到payload
graph_results = context.get("graph_results")
if graph_results:
payload["graphResults"] = {
"totalNodes": graph_results.get("total_nodes", 0),
"queryRounds": graph_results.get("query_rounds", 0),
"matchedSections": graph_results.get("matched_sections", []),
"matchedQueries": graph_results.get("matched_queries", []),
"matchedSources": graph_results.get("matched_sources", []),
}
# 同时添加增强提示(如果有)
graph_enhancement = context.get("graph_enhancement_prompt")
if graph_enhancement:
payload["graphEnhancementPrompt"] = graph_enhancement
if chapter_plan:
constraints = payload["constraints"]
if chapter_plan.get("targetWords"):
... ... @@ -438,6 +458,7 @@ class ChapterGenerationNode(BaseNode):
chapter_dir: Path,
stream_callback: Optional[Callable[[str, Dict[str, Any]], None]] = None,
section_meta: Optional[Dict[str, Any]] = None,
graph_enhanced: bool = False,
**kwargs,
) -> str:
"""
... ... @@ -448,15 +469,23 @@ class ChapterGenerationNode(BaseNode):
chapter_dir: 章节的本地缓存目录,用于存放 stream.raw。
stream_callback: SSE流式推送的回调函数。
section_meta: 附带的章节ID/标题,用于回调payload。
graph_enhanced: 是否启用GraphRAG增强的系统提示词。
**kwargs: 透传温度、top_p等参数。
返回:
str: 将所有delta拼接后的原始文本。
"""
# 根据是否启用GraphRAG选择不同的系统提示词
if graph_enhanced:
from ..graphrag.prompts import SYSTEM_PROMPT_CHAPTER_GRAPH_ENHANCEMENT
system_prompt = SYSTEM_PROMPT_CHAPTER_JSON + "\n\n" + SYSTEM_PROMPT_CHAPTER_GRAPH_ENHANCEMENT
else:
system_prompt = SYSTEM_PROMPT_CHAPTER_JSON
chunks: List[str] = []
with self.storage.capture_stream(chapter_dir) as stream_fp:
stream = self.llm_client.stream_invoke(
SYSTEM_PROMPT_CHAPTER_JSON,
system_prompt,
user_message,
temperature=kwargs.get("temperature", 0.2),
top_p=kwargs.get("top_p", 0.95),
... ...
"""
GraphRAG 查询节点
负责与知识图谱交互,让 LLM 决定查询参数并执行多轮查询。
包含查询历史机制以防止重复查询。
"""
import json
import re
from dataclasses import dataclass, field
from typing import Dict, Any, List, Optional
from loguru import logger
from .base_node import BaseNode
from ..llms.base import LLMClient
from ..graphrag.graph_storage import Graph
from ..graphrag.query_engine import QueryEngine, QueryParams, QueryResult
from ..graphrag.prompts import (
GRAPHRAG_QUERY_DECISION_SYSTEM,
GRAPHRAG_QUERY_DECISION_USER
)
@dataclass
class QueryRound:
    """Record of a single graph-query round."""
    round: int  # 1-based round number
    params: Dict[str, Any]  # the LLM-decided query parameters for this round
    result_count: int  # number of nodes the query returned
    summary: str  # short textual summary of the result
class QueryHistory:
    """Tracks past graph queries so the LLM can avoid repeating itself.

    Each round stores the parameters used plus a short result summary;
    to_prompt() renders the history for inclusion in the decision prompt.
    """

    def __init__(self):
        self.rounds: List[QueryRound] = []

    def add(self, params: Dict[str, Any], result: QueryResult) -> None:
        """Append one completed query (its params and result) to the history.

        Args:
            params: The query parameters that were used.
            result: The query result returned by the engine.
        """
        entry = QueryRound(
            round=len(self.rounds) + 1,
            params=params,
            result_count=result.total_nodes,
            summary=result.get_summary(),
        )
        self.rounds.append(entry)

    def to_prompt(self) -> str:
        """Render the query history as context text for the LLM.

        Returns:
            A formatted history string (or a first-query placeholder).
        """
        if not self.rounds:
            return "(这是第1次查询,无历史记录)"
        lines = ["=== 已完成的查询历史 ==="]
        for entry in self.rounds:
            kw = entry.params.get('keywords', [])
            types_used = entry.params.get('node_types', ['all'])
            engines_used = entry.params.get('engine_filter', ['all'])
            lines.append(f"第{entry.round}次查询:")
            lines.append(f" 关键词: {', '.join(kw) if kw else '无'}")
            lines.append(f" 节点类型: {', '.join(types_used) if types_used else '全部'}")
            lines.append(f" 引擎筛选: {', '.join(engines_used) if engines_used else '全部'}")
            lines.append(f" 返回节点数: {entry.result_count}")
            lines.append(f" 结果摘要: {entry.summary}")
            lines.append("")
        lines.append("=== 请避免重复上述查询,探索新的角度 ===")
        return "\n".join(lines)

    def get_all_keywords(self) -> List[str]:
        """Flatten every keyword used across all recorded rounds."""
        collected: List[str] = []
        for entry in self.rounds:
            collected.extend(entry.params.get('keywords', []))
        return collected
class GraphRAGQueryNode(BaseNode):
    """
    GraphRAG query node.

    Core responsibilities:
    1. Receive the full generation context (report, chapter plan, graph overview)
    2. Maintain a query-history log to prevent duplicate queries
    3. Ask the LLM to decide the query parameters for each round
    4. Execute the GraphRAG query against the knowledge graph
    5. Allow at most max_queries rounds per chapter
    6. Merge and return the combined query results
    """
    def __init__(self, llm_client: LLMClient):
        super().__init__(llm_client, "GraphRAGQueryNode")
    def run(self, section: Dict[str, Any], context: Dict[str, Any],
            graph: Graph, max_queries: int = 3) -> Dict[str, Any]:
        """
        Run the GraphRAG query loop for one chapter.

        Args:
            section: Current chapter info (title/id/role/target_words/emphasis).
            context: Generation context (query, template_name, chapters, ...).
            graph: The knowledge graph to query.
            max_queries: Upper bound on LLM-decided query rounds.

        Returns:
            Merged, de-duplicated query results with a 'query_rounds' count.
        """
        self.log_info(f"开始 GraphRAG 查询,章节: {section.get('title', 'unknown')}")
        query_engine = QueryEngine(graph)
        history = QueryHistory()
        all_results: List[QueryResult] = []
        for round_idx in range(max_queries):
            self.log_info(f"查询轮次 {round_idx + 1}/{max_queries}")
            # 1. Build the decision prompt (graph overview + history so far)
            prompt = self._build_decision_prompt(
                section, context, query_engine, history
            )
            # 2. Ask the LLM to decide the query parameters
            decision = self._get_query_decision(prompt)
            if decision is None:
                self.log_error("LLM 返回无效决策,终止查询")
                break
            # 3. Stop early if the LLM decides no further query is needed
            if not decision.get('should_query', False):
                self.log_info(f"LLM 决定停止查询: {decision.get('reasoning', '无原因')}")
                break
            # 4. Execute the query with the decided parameters
            params = QueryParams(
                keywords=decision.get('keywords', []),
                node_types=decision.get('node_types'),
                engine_filter=decision.get('engine_filter'),
                depth=decision.get('depth', 1)
            )
            result = query_engine.query(params)
            all_results.append(result)
            self.log_info(f"查询返回 {result.total_nodes} 个节点")
            # 5. Record the round so later prompts can avoid repeating it
            history.add(decision, result)
        # 6. Merge all rounds into a single de-duplicated result
        merged = self._merge_results(all_results)
        merged['query_rounds'] = len(all_results)
        self.log_info(f"GraphRAG 查询完成,共 {len(all_results)} 轮,"
                      f"获取 {merged.get('total_nodes', 0)} 个节点")
        return merged
    def _build_decision_prompt(self, section: Dict[str, Any],
                               context: Dict[str, Any],
                               query_engine: QueryEngine,
                               history: QueryHistory) -> Dict[str, str]:
        """Build the system/user prompt pair asking the LLM for query params."""
        # Graph overview (topic, engines, per-type node counts)
        summary = query_engine.get_node_summary()
        stats = summary.get('stats', {})
        # Section titles grouped by engine (first 5 per engine)
        section_titles = query_engine.get_section_titles_by_engine()
        section_titles_text = ""
        for engine, titles in section_titles.items():
            section_titles_text += f"\n{engine}: {', '.join(titles[:5])}"
        # Sample search queries from the graph
        sample_queries = query_engine.get_sample_search_queries(20)
        # Chapter overview (first 10 chapters)
        chapters = context.get('chapters', [])
        chapters_text = "\n".join([
            f"- {c.get('id', '')}: {c.get('title', '')}"
            for c in chapters[:10]
        ])
        user_prompt = GRAPHRAG_QUERY_DECISION_USER.format(
            chapter_title=section.get('title', ''),
            chapter_id=section.get('id', ''),
            chapter_role=section.get('role', ''),
            target_words=section.get('target_words', 500),
            chapter_emphasis=section.get('emphasis', ''),
            report_topic=context.get('query', ''),
            template_name=context.get('template_name', ''),
            chapters_overview=chapters_text,
            topic_name=summary.get('topic', ''),
            engine_count=len(summary.get('engines', [])),
            section_count=stats.get('section', 0),
            query_count=stats.get('search_query', 0),
            source_count=stats.get('source', 0),
            section_titles_by_engine=section_titles_text,
            sample_search_queries=', '.join(sample_queries),
            query_history_detail=history.to_prompt()
        )
        return {
            'system': GRAPHRAG_QUERY_DECISION_SYSTEM,
            'user': user_prompt
        }
    def _get_query_decision(self, prompt: Dict[str, str]) -> Optional[Dict[str, Any]]:
        """Call the LLM for a query decision; returns None on failure."""
        try:
            response = self.llm_client.invoke(
                system=prompt['system'],
                user=prompt['user']
            )
            # Parse the (possibly fenced) JSON response
            return self._parse_json_response(response)
        except Exception as e:
            self.log_error(f"LLM 调用失败: {e}")
            return None
    def _parse_json_response(self, response: str) -> Optional[Dict[str, Any]]:
        """Best-effort extraction of a JSON object from an LLM response."""
        try:
            # 1) Try the raw response as-is
            return json.loads(response)
        except json.JSONDecodeError:
            pass
        # 2) Try a ```json fenced block
        json_match = re.search(r'```json\s*(.*?)\s*```', response, re.DOTALL)
        if json_match:
            try:
                return json.loads(json_match.group(1))
            except json.JSONDecodeError:
                pass
        # 3) Fall back to the outermost brace-delimited span (greedy match)
        brace_match = re.search(r'\{.*\}', response, re.DOTALL)
        if brace_match:
            try:
                return json.loads(brace_match.group())
            except json.JSONDecodeError:
                pass
        self.log_error(f"无法解析 JSON 响应: {response[:200]}")
        return None
    def _merge_results(self, results: List[QueryResult]) -> Dict[str, Any]:
        """Merge multi-round query results, de-duplicating by node id.

        Nodes without an 'id' are dropped silently.
        """
        merged = {
            'matched_sections': [],
            'matched_queries': [],
            'matched_sources': [],
            'total_nodes': 0,
            'cross_engine_insights': []
        }
        seen_section_ids = set()
        seen_query_ids = set()
        seen_source_ids = set()
        for result in results:
            # Merge sections (de-duplicated)
            for section in result.matched_sections:
                sid = section.get('id')
                if sid and sid not in seen_section_ids:
                    seen_section_ids.add(sid)
                    merged['matched_sections'].append(section)
            # Merge search queries (de-duplicated)
            for query in result.matched_queries:
                qid = query.get('id')
                if qid and qid not in seen_query_ids:
                    seen_query_ids.add(qid)
                    merged['matched_queries'].append(query)
            # Merge sources (de-duplicated)
            for source in result.matched_sources:
                sid = source.get('id')
                if sid and sid not in seen_source_ids:
                    seen_source_ids.add(sid)
                    merged['matched_sources'].append(source)
        merged['total_nodes'] = (
            len(merged['matched_sections']) +
            len(merged['matched_queries']) +
            len(merged['matched_sources'])
        )
        # Derive cross-engine insight strings from the merged result
        merged['cross_engine_insights'] = self._generate_cross_engine_insights(merged)
        return merged
    def _generate_cross_engine_insights(self, merged: Dict[str, Any]) -> List[str]:
        """Generate insight strings describing cross-engine overlap."""
        insights = []
        # Count matched sections per engine
        engine_sections = {}
        for section in merged['matched_sections']:
            engine = section.get('engine', 'unknown')
            engine_sections[engine] = engine_sections.get(engine, 0) + 1
        if len(engine_sections) > 1:
            engines = list(engine_sections.keys())
            insights.append(f"跨引擎信息来源: {', '.join(engines)}")
        # Count matched search queries per engine
        engine_queries = {}
        for query in merged['matched_queries']:
            engine = query.get('engine', 'unknown')
            if engine not in engine_queries:
                engine_queries[engine] = []
            engine_queries[engine].append(query.get('query_text', ''))
        if len(engine_queries) > 1:
            insights.append(f"多引擎搜索视角: {len(engine_queries)} 个引擎提供了相关搜索")
        return insights
... ...
... ... @@ -512,3 +512,129 @@ def build_document_layout_prompt(payload: dict) -> str:
def build_word_budget_prompt(payload: dict) -> str:
    """Serialize the word-budget planning input for the LLM, keeping every field exact."""
    serialized = json.dumps(payload, ensure_ascii=False, indent=2)
    return serialized
# ==================== GraphRAG enhancement prompts ====================
# Injected into the chapter user prompt when GraphRAG produced query results;
# the {graph_results} placeholder is filled with the formatted query output.
GRAPHRAG_CHAPTER_ENHANCEMENT_INTRO = """
<知识图谱查询结果>
以下是针对本章节从知识图谱中查询到的相关信息,这些信息来自对Insight/Media/Query三个分析引擎结构化数据的聚合:
{graph_results}
请在生成本章内容时:
1. 充分利用上述图谱查询结果中的具体数据点、关键发现和关联关系
2. 优先引用图谱中标注的来源(搜索关键词、数据来源等)
3. 当图谱结果与三引擎报告有重叠时,以图谱中的结构化数据为准
4. 注意图谱中节点之间的关联关系,体现因果或递进逻辑
5. 如果图谱结果中有明确的数值或时间点,务必准确引用
</知识图谱查询结果>
"""
def build_graphrag_enhanced_user_prompt(payload: dict) -> str:
    """
    Build the chapter user prompt, appending GraphRAG enhancement text when present.

    When GraphRAG is enabled and produced results, the optional
    'graph_enhancement_prompt' entry is appended after the serialized payload
    so the LLM prioritizes the graph findings during chapter generation.

    Args:
        payload: Standard chapter context, optionally carrying a
            'graph_enhancement_prompt' string.

    Returns:
        The serialized user prompt string.
    """
    # Read (don't pop) the enhancement text: the previous implementation
    # mutated the caller's dict, so any later use of the same payload
    # (e.g. a retry) silently lost the enhancement prompt.
    graph_prompt = payload.get('graph_enhancement_prompt')
    serializable = {
        key: value for key, value in payload.items()
        if key != 'graph_enhancement_prompt'
    }
    base_prompt = json.dumps(serializable, ensure_ascii=False, indent=2)
    if graph_prompt:
        return f"{base_prompt}\n\n{graph_prompt}"
    return base_prompt
def format_graph_nodes_for_prompt(nodes: list) -> str:
    """
    Render a list of graph nodes as prompt-friendly text, grouped by type.

    Args:
        nodes: Node dicts carrying id, type, label and properties.

    Returns:
        A human-readable description, or a placeholder when empty.
    """
    if not nodes:
        return "(无相关节点)"
    type_labels = {
        'topic': '主题',
        'engine': '分析引擎',
        'section': '报告段落',
        'search_query': '搜索关键词',
        'source': '数据来源'
    }
    # Property keys worth surfacing in the prompt text.
    interesting_keys = ('summary', 'content', 'headline', 'url', 'query', 'source')
    # Group nodes by type, preserving first-seen order.
    grouped = {}
    for node in nodes:
        grouped.setdefault(node.get('type', 'unknown'), []).append(node)
    parts = []
    for node_type, members in grouped.items():
        parts.append(f"\n【{type_labels.get(node_type, node_type)}】")
        for member in members[:10]:  # at most 10 nodes per type
            label = member.get('label', member.get('id', ''))
            props = member.get('properties', {})
            suffix = ''
            if props:
                picked = {k: v for k, v in props.items() if k in interesting_keys}
                if picked:
                    suffix = ' | ' + ', '.join(f"{k}:{str(v)[:100]}" for k, v in picked.items())
            parts.append(f" • {label}{suffix}")
    return '\n'.join(parts)
def format_graph_edges_for_prompt(edges: list) -> str:
    """
    Render graph edges as prompt-friendly relation lines.

    Args:
        edges: Edge dicts carrying source, target, relation.

    Returns:
        A de-duplicated bullet list (max 20 edges considered), or a placeholder.
    """
    if not edges:
        return "(无关联关系)"
    relation_labels = {
        'analyzed_by': '被分析于',
        'contains': '包含',
        'searched': '搜索了',
        'found': '发现于'
    }
    rendered = []
    emitted = set()
    for edge in edges[:20]:  # consider at most 20 edges
        src = edge.get('source', '')
        dst = edge.get('target', '')
        rel = edge.get('relation', 'related')
        dedupe_key = f"{src}-{rel}-{dst}"
        if dedupe_key in emitted:
            continue
        emitted.add(dedupe_key)
        rendered.append(f" • {src} —[{relation_labels.get(rel, rel)}]→ {dst}")
    return '\n'.join(rendered) if rendered else "(无关联关系)"
... ...
... ... @@ -66,6 +66,14 @@ class Settings(BaseSettings):
JSON_ERROR_LOG_DIR: str = Field(
"logs/json_repair_failures", description="无法修复的JSON块落盘目录"
)
# GraphRAG 配置
GRAPHRAG_ENABLED: bool = Field(
default=False, description="是否启用GraphRAG知识图谱功能"
)
GRAPHRAG_MAX_QUERIES: int = Field(
default=3, description="GraphRAG每章节查询次数上限"
)
class Config:
"""Pydantic配置:允许从.env读取并兼容大小写"""
... ...
... ... @@ -113,7 +113,9 @@ CONFIG_KEYS = [
'TAVILY_API_KEY',
'SEARCH_TOOL_TYPE',
'BOCHA_WEB_SEARCH_API_KEY',
'ANSPIRE_API_KEY'
'ANSPIRE_API_KEY',
'GRAPHRAG_ENABLED',
'GRAPHRAG_MAX_QUERIES'
]
... ... @@ -1295,6 +1297,247 @@ def shutdown_system():
logger.exception("系统关闭过程中出现异常")
return jsonify({'success': False, 'message': f'系统关闭异常: {exc}'}), 500
# ==================== GraphRAG API 端点 ====================
@app.route('/api/graph/<report_id>')
def get_graph_data(report_id):
    """
    Return the knowledge-graph data for a given report.

    The response is shaped for front-end Vis.js rendering:
    - nodes: [{id, label, group, title, properties}]
    - edges: [{from, to, label}]
    """
    try:
        from ReportEngine.graphrag import GraphStorage, Graph
        # Locate the persisted graph file for this report.
        storage = GraphStorage()
        graph_path = storage.find_graph_by_report_id(report_id)
        if not graph_path or not graph_path.exists():
            return jsonify({
                'success': False,
                'message': f'未找到报告 {report_id} 的知识图谱数据'
            }), 404
        graph = storage.load(graph_path)
        # Convert to the Vis.js node/edge shape.
        vis_nodes = [
            {
                'id': node_id,
                'label': node.label or node_id,
                'group': node.type,
                'title': _format_node_tooltip(node),
                'properties': node.properties,
            }
            for node_id, node in graph.nodes.items()
        ]
        vis_edges = [
            {
                'from': edge.source,
                'to': edge.target,
                'label': edge.relation,
                'arrows': 'to',
            }
            for edge in graph.edges
        ]
        return jsonify({
            'success': True,
            'graph': {
                'nodes': vis_nodes,
                'edges': vis_edges,
                'stats': graph.get_stats()
            }
        })
    except Exception as e:
        logger.exception(f"获取图谱数据失败: {e}")
        return jsonify({
            'success': False,
            'message': f'获取图谱数据失败: {str(e)}'
        }), 500
@app.route('/api/graph/latest')
def get_latest_graph():
    """Return the most recently generated knowledge graph in Vis.js format."""
    try:
        from ReportEngine.graphrag import GraphStorage
        storage = GraphStorage()
        latest_path = storage.find_latest_graph()
        if not latest_path or not latest_path.exists():
            return jsonify({
                'success': False,
                'message': '暂无可用的知识图谱数据'
            }), 404
        graph = storage.load(latest_path)
        # The graph file lives inside a per-report directory; use that
        # directory's name as the report id.
        report_id = latest_path.parent.name if latest_path.parent else 'unknown'
        # Convert to the Vis.js node/edge shape.
        vis_nodes = [
            {
                'id': node_id,
                'label': node.label or node_id,
                'group': node.type,
                'title': _format_node_tooltip(node),
                'properties': node.properties,
            }
            for node_id, node in graph.nodes.items()
        ]
        vis_edges = [
            {
                'from': edge.source,
                'to': edge.target,
                'label': edge.relation,
                'arrows': 'to',
            }
            for edge in graph.edges
        ]
        return jsonify({
            'success': True,
            'report_id': report_id,
            'graph': {
                'nodes': vis_nodes,
                'edges': vis_edges,
                'stats': graph.get_stats()
            }
        })
    except Exception as e:
        logger.exception(f"获取最新图谱失败: {e}")
        return jsonify({
            'success': False,
            'message': f'获取最新图谱失败: {str(e)}'
        }), 500
@app.route('/graph-viewer')
@app.route('/graph-viewer/')
@app.route('/graph-viewer/<report_id>')
def graph_viewer(report_id=None):
    """
    Knowledge-graph visualization page.

    Renders an interactive graph view supporting:
    - fullscreen mode
    - zoom and drag
    - node detail inspection
    - filtering and search

    report_id is optional; when omitted the template falls back to the
    latest available graph.
    """
    return render_template('graph_viewer.html', report_id=report_id)
@app.route('/api/graph/query', methods=['POST'])
def query_graph():
    """
    Query the knowledge graph.

    Request body:
    {
        "report_id": "xxx",            // optional, defaults to the latest graph
        "keywords": ["kw1", "kw2"],
        "node_types": ["section", "source"],
        "engine_filter": ["insight"],  // optional
        "depth": 2
    }

    The response mirrors QueryEngine's QueryResult: matched nodes grouped
    by type (sections / search queries / sources) plus totals.
    """
    try:
        from ReportEngine.graphrag import GraphStorage, QueryEngine, QueryParams
        data = request.get_json() or {}
        report_id = data.get('report_id')
        storage = GraphStorage()
        if report_id:
            graph_path = storage.find_graph_by_report_id(report_id)
        else:
            graph_path = storage.find_latest_graph()
        if not graph_path or not graph_path.exists():
            return jsonify({
                'success': False,
                'message': '未找到可用的知识图谱'
            }), 404
        graph = storage.load(graph_path)
        query_engine = QueryEngine(graph)
        params = QueryParams(
            keywords=data.get('keywords', []),
            node_types=data.get('node_types'),
            engine_filter=data.get('engine_filter'),
            depth=data.get('depth', 1)
        )
        result = query_engine.query(params)
        # BUG FIX: QueryResult exposes matched_sections / matched_queries /
        # matched_sources (already plain dicts). The previous code read
        # non-existent matched_nodes / related_edges / expanded_nodes
        # attributes, which raised AttributeError on every request.
        return jsonify({
            'success': True,
            'result': {
                'matched_sections': result.matched_sections,
                'matched_queries': result.matched_queries,
                'matched_sources': result.matched_sources,
                'total_nodes': result.total_nodes,
                'query_params': result.query_params,
            }
        })
    except Exception as e:
        logger.exception(f"图谱查询失败: {e}")
        return jsonify({
            'success': False,
            'message': f'图谱查询失败: {str(e)}'
        }), 500
def _format_node_tooltip(node) -> str:
"""格式化节点悬停提示文本。"""
lines = [f"<b>{node.label or node.id}</b>"]
lines.append(f"类型: {node.type}")
props = node.properties or {}
if 'summary' in props:
lines.append(f"摘要: {props['summary'][:100]}...")
if 'content' in props:
lines.append(f"内容: {props['content'][:80]}...")
if 'url' in props:
lines.append(f"链接: {props['url']}")
if 'query' in props:
lines.append(f"查询: {props['query']}")
return "<br>".join(lines)
# ==================== GraphRAG API 端点结束 ====================
@socketio.on('connect')
def handle_connect():
"""客户端连接"""
... ...
<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>知识图谱可视化 - BettaFish</title>
<!-- Vis.js -->
<script src="https://unpkg.com/vis-network/standalone/umd/vis-network.min.js"></script>
<style>
:root {
--primary-color: #4F46E5;
--primary-light: #818CF8;
--bg-color: #0F172A;
--card-bg: #1E293B;
--text-color: #F1F5F9;
--text-muted: #94A3B8;
--border-color: #334155;
--success-color: #10B981;
--warning-color: #F59E0B;
--error-color: #EF4444;
}
* {
margin: 0;
padding: 0;
box-sizing: border-box;
}
body {
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, 'Helvetica Neue', Arial, sans-serif;
background-color: var(--bg-color);
color: var(--text-color);
min-height: 100vh;
}
/* 顶部工具栏 */
.toolbar {
position: fixed;
top: 0;
left: 0;
right: 0;
height: 60px;
background: var(--card-bg);
border-bottom: 1px solid var(--border-color);
display: flex;
align-items: center;
padding: 0 20px;
gap: 16px;
z-index: 1000;
}
.toolbar h1 {
font-size: 1.25rem;
font-weight: 600;
display: flex;
align-items: center;
gap: 8px;
}
.toolbar h1 svg {
width: 24px;
height: 24px;
color: var(--primary-color);
}
.toolbar-divider {
width: 1px;
height: 30px;
background: var(--border-color);
}
.btn {
display: flex;
align-items: center;
gap: 6px;
padding: 8px 16px;
border: 1px solid var(--border-color);
border-radius: 6px;
background: transparent;
color: var(--text-color);
cursor: pointer;
font-size: 0.875rem;
transition: all 0.2s;
}
.btn:hover {
background: var(--primary-color);
border-color: var(--primary-color);
}
.btn-primary {
background: var(--primary-color);
border-color: var(--primary-color);
}
.btn svg {
width: 16px;
height: 16px;
}
.search-box {
flex: 1;
max-width: 400px;
position: relative;
}
.search-box input {
width: 100%;
padding: 8px 16px 8px 40px;
border: 1px solid var(--border-color);
border-radius: 6px;
background: var(--bg-color);
color: var(--text-color);
font-size: 0.875rem;
}
.search-box input:focus {
outline: none;
border-color: var(--primary-color);
}
.search-box svg {
position: absolute;
left: 12px;
top: 50%;
transform: translateY(-50%);
width: 16px;
height: 16px;
color: var(--text-muted);
}
/* 统计信息 */
.stats {
display: flex;
gap: 16px;
margin-left: auto;
}
.stat-item {
display: flex;
align-items: center;
gap: 6px;
font-size: 0.875rem;
}
.stat-item .label {
color: var(--text-muted);
}
.stat-item .value {
font-weight: 600;
color: var(--primary-light);
}
/* 左侧面板 */
.sidebar {
position: fixed;
top: 60px;
left: 0;
width: 300px;
bottom: 0;
background: var(--card-bg);
border-right: 1px solid var(--border-color);
overflow-y: auto;
padding: 16px;
transition: transform 0.3s;
z-index: 100;
}
.sidebar.collapsed {
transform: translateX(-100%);
}
.sidebar h3 {
font-size: 0.875rem;
font-weight: 600;
color: var(--text-muted);
text-transform: uppercase;
letter-spacing: 0.05em;
margin-bottom: 12px;
}
.filter-group {
margin-bottom: 20px;
}
.filter-item {
display: flex;
align-items: center;
gap: 10px;
padding: 8px 0;
cursor: pointer;
}
.filter-item input[type="checkbox"] {
width: 16px;
height: 16px;
accent-color: var(--primary-color);
}
.filter-item .color-dot {
width: 12px;
height: 12px;
border-radius: 50%;
}
.filter-item .count {
margin-left: auto;
font-size: 0.75rem;
color: var(--text-muted);
}
/* 节点详情 */
.node-detail {
margin-top: 20px;
padding-top: 20px;
border-top: 1px solid var(--border-color);
}
.node-detail .detail-title {
font-weight: 600;
margin-bottom: 8px;
color: var(--primary-light);
}
.node-detail .detail-type {
font-size: 0.75rem;
color: var(--text-muted);
margin-bottom: 12px;
}
.node-detail .detail-props {
font-size: 0.875rem;
}
.node-detail .prop-item {
padding: 6px 0;
border-bottom: 1px solid var(--border-color);
}
.node-detail .prop-key {
color: var(--text-muted);
font-size: 0.75rem;
}
.node-detail .prop-value {
margin-top: 2px;
word-break: break-all;
}
/* 图谱容器 */
.graph-container {
position: fixed;
top: 60px;
left: 300px;
right: 0;
bottom: 0;
transition: left 0.3s;
}
.graph-container.fullwidth {
left: 0;
}
#network {
width: 100%;
height: 100%;
background: var(--bg-color);
}
/* 加载状态 */
.loading-overlay {
position: absolute;
top: 0;
left: 0;
right: 0;
bottom: 0;
display: flex;
flex-direction: column;
align-items: center;
justify-content: center;
background: var(--bg-color);
z-index: 500;
}
.loading-spinner {
width: 48px;
height: 48px;
border: 4px solid var(--border-color);
border-top-color: var(--primary-color);
border-radius: 50%;
animation: spin 1s linear infinite;
}
@keyframes spin {
to { transform: rotate(360deg); }
}
.loading-text {
margin-top: 16px;
color: var(--text-muted);
}
/* 空状态 */
.empty-state {
position: absolute;
top: 50%;
left: 50%;
transform: translate(-50%, -50%);
text-align: center;
color: var(--text-muted);
}
.empty-state svg {
width: 64px;
height: 64px;
margin-bottom: 16px;
opacity: 0.5;
}
/* 提示信息 */
.toast {
position: fixed;
bottom: 20px;
right: 20px;
padding: 12px 20px;
background: var(--card-bg);
border: 1px solid var(--border-color);
border-radius: 8px;
display: none;
animation: slideIn 0.3s;
z-index: 2000;
}
@keyframes slideIn {
from {
transform: translateX(100%);
opacity: 0;
}
}
/* 图例 */
.legend {
position: fixed;
bottom: 20px;
left: 320px;
background: var(--card-bg);
border: 1px solid var(--border-color);
border-radius: 8px;
padding: 12px 16px;
display: flex;
gap: 16px;
z-index: 100;
transition: left 0.3s;
}
.legend.fullwidth {
left: 20px;
}
.legend-item {
display: flex;
align-items: center;
gap: 6px;
font-size: 0.75rem;
}
.legend-item .dot {
width: 10px;
height: 10px;
border-radius: 50%;
}
/* 全屏模式 */
.fullscreen-btn {
position: fixed;
bottom: 20px;
right: 20px;
z-index: 100;
}
/* 节点类型颜色 */
.color-topic { background-color: #EF4444; }
.color-engine { background-color: #F59E0B; }
.color-section { background-color: #10B981; }
.color-search_query { background-color: #3B82F6; }
.color-source { background-color: #8B5CF6; }
</style>
</head>
<body>
<!-- 顶部工具栏 -->
<div class="toolbar">
<h1>
<svg viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2">
<circle cx="12" cy="5" r="3"/>
<circle cx="5" cy="19" r="3"/>
<circle cx="19" cy="19" r="3"/>
<line x1="12" y1="8" x2="5" y2="16"/>
<line x1="12" y1="8" x2="19" y2="16"/>
</svg>
知识图谱
</h1>
<div class="toolbar-divider"></div>
<button class="btn" id="toggleSidebar" title="切换侧边栏">
<svg viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2">
<rect x="3" y="3" width="18" height="18" rx="2"/>
<line x1="9" y1="3" x2="9" y2="21"/>
</svg>
</button>
<button class="btn" id="fitBtn" title="适应视图">
<svg viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2">
<path d="M15 3h6v6M9 21H3v-6M21 3l-7 7M3 21l7-7"/>
</svg>
适应
</button>
<button class="btn" id="zoomInBtn" title="放大">
<svg viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2">
<circle cx="11" cy="11" r="8"/>
<line x1="21" y1="21" x2="16.65" y2="16.65"/>
<line x1="11" y1="8" x2="11" y2="14"/>
<line x1="8" y1="11" x2="14" y2="11"/>
</svg>
</button>
<button class="btn" id="zoomOutBtn" title="缩小">
<svg viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2">
<circle cx="11" cy="11" r="8"/>
<line x1="21" y1="21" x2="16.65" y2="16.65"/>
<line x1="8" y1="11" x2="14" y2="11"/>
</svg>
</button>
<div class="search-box">
<svg viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2">
<circle cx="11" cy="11" r="8"/>
<line x1="21" y1="21" x2="16.65" y2="16.65"/>
</svg>
<input type="text" id="searchInput" placeholder="搜索节点...">
</div>
<div class="stats" id="statsContainer">
<div class="stat-item">
<span class="label">节点</span>
<span class="value" id="nodeCount">0</span>
</div>
<div class="stat-item">
<span class="label">关系</span>
<span class="value" id="edgeCount">0</span>
</div>
</div>
</div>
<!-- 左侧面板 -->
<div class="sidebar" id="sidebar">
<div class="filter-group">
<h3>节点类型</h3>
<label class="filter-item">
<input type="checkbox" checked data-type="topic">
<span class="color-dot color-topic"></span>
<span>主题</span>
<span class="count" id="count-topic">0</span>
</label>
<label class="filter-item">
<input type="checkbox" checked data-type="engine">
<span class="color-dot color-engine"></span>
<span>分析引擎</span>
<span class="count" id="count-engine">0</span>
</label>
<label class="filter-item">
<input type="checkbox" checked data-type="section">
<span class="color-dot color-section"></span>
<span>报告段落</span>
<span class="count" id="count-section">0</span>
</label>
<label class="filter-item">
<input type="checkbox" checked data-type="search_query">
<span class="color-dot color-search_query"></span>
<span>搜索关键词</span>
<span class="count" id="count-search_query">0</span>
</label>
<label class="filter-item">
<input type="checkbox" checked data-type="source">
<span class="color-dot color-source"></span>
<span>数据来源</span>
<span class="count" id="count-source">0</span>
</label>
</div>
<div class="node-detail" id="nodeDetail" style="display: none;">
<h3>节点详情</h3>
<div class="detail-title" id="detailTitle"></div>
<div class="detail-type" id="detailType"></div>
<div class="detail-props" id="detailProps"></div>
</div>
</div>
<!-- 图谱容器 -->
<div class="graph-container" id="graphContainer">
<div id="network"></div>
<!-- 加载状态 -->
<div class="loading-overlay" id="loadingOverlay">
<div class="loading-spinner"></div>
<div class="loading-text">正在加载知识图谱...</div>
</div>
<!-- 空状态 -->
<div class="empty-state" id="emptyState" style="display: none;">
<svg viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="1">
<circle cx="12" cy="12" r="10"/>
<path d="M8 15h8"/>
<path d="M9 9h.01"/>
<path d="M15 9h.01"/>
</svg>
<h3>暂无图谱数据</h3>
<p>请先生成报告以创建知识图谱</p>
</div>
</div>
<!-- 图例 -->
<div class="legend" id="legend">
<div class="legend-item">
<span class="dot color-topic"></span>
<span>主题</span>
</div>
<div class="legend-item">
<span class="dot color-engine"></span>
<span>引擎</span>
</div>
<div class="legend-item">
<span class="dot color-section"></span>
<span>段落</span>
</div>
<div class="legend-item">
<span class="dot color-search_query"></span>
<span>搜索词</span>
</div>
<div class="legend-item">
<span class="dot color-source"></span>
<span>来源</span>
</div>
</div>
<!-- 全屏按钮 -->
<button class="btn fullscreen-btn" id="fullscreenBtn" title="全屏">
<svg viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2">
<path d="M8 3H5a2 2 0 00-2 2v3m18 0V5a2 2 0 00-2-2h-3m0 18h3a2 2 0 002-2v-3M3 16v3a2 2 0 002 2h3"/>
</svg>
</button>
<!-- 提示 -->
<div class="toast" id="toast"></div>
<script>
// ---- Configuration ----
// Fill/border color per node type (hex). Must stay in sync with the
// `.color-*` CSS classes used by the filter dots and the legend.
const NODE_COLORS = {
    topic: '#EF4444',
    engine: '#F59E0B',
    section: '#10B981',
    search_query: '#3B82F6',
    source: '#8B5CF6'
};
// vis-network node shape per node type.
const NODE_SHAPES = {
    topic: 'star',
    engine: 'diamond',
    section: 'dot',
    search_query: 'triangle',
    source: 'square'
};
// ---- Module-level state ----
let network = null;   // vis.Network instance; null until the first renderGraph()
let allNodes = [];    // raw node list from the API (unfiltered)
let allEdges = [];    // raw edge list from the API (unfiltered)
// Report id injected server-side via Jinja2; renders as `null` when absent,
// in which case loadGraphData() falls back to the /api/graph/latest endpoint.
let reportId = {{ report_id | tojson if report_id else 'null' }};
// ---- Bootstrap ----
document.addEventListener('DOMContentLoaded', () => {
    loadGraphData();
    setupEventListeners();
});
// Fetch the knowledge-graph data from the backend and render it.
// Uses the report-specific endpoint when `reportId` is set, otherwise the
// latest-graph endpoint. On any failure (network, HTTP error, missing graph)
// the empty state is shown and the error surfaced via a toast.
async function loadGraphData() {
    showLoading(true);
    try {
        const url = reportId
            ? `/api/graph/${reportId}`
            : '/api/graph/latest';
        const response = await fetch(url);
        // Fail fast on HTTP errors so the catch block reports a clear status
        // message instead of a JSON parse error on an HTML error page.
        if (!response.ok) {
            throw new Error(`HTTP ${response.status}`);
        }
        const data = await response.json();
        if (data.success && data.graph) {
            allNodes = data.graph.nodes;
            allEdges = data.graph.edges;
            updateStats(data.graph.stats);
            renderGraph();
            showLoading(false);
        } else {
            showEmpty(true);
            showLoading(false);
        }
    } catch (error) {
        console.error('加载图谱失败:', error);
        showToast('加载图谱失败: ' + error.message);
        showEmpty(true);
        showLoading(false);
    }
}
// Render the knowledge graph with vis-network, honoring the type filter
// checkboxes. Reads module-level allNodes/allEdges and replaces the global
// `network` instance on every call (re-invoked whenever a filter changes).
function renderGraph() {
    const container = document.getElementById('network');
    // Keep only nodes whose type checkbox is checked; edges survive only
    // when BOTH endpoints survive the filter.
    const visibleTypes = getVisibleTypes();
    const filteredNodes = allNodes.filter(n => visibleTypes.includes(n.group));
    const filteredNodeIds = new Set(filteredNodes.map(n => n.id));
    const nodes = new vis.DataSet(filteredNodes.map(node => ({
        id: node.id,
        label: truncateLabel(node.label, 20),
        title: node.title,
        group: node.group,
        color: {
            background: NODE_COLORS[node.group] || '#6B7280',
            border: NODE_COLORS[node.group] || '#6B7280',
            highlight: {
                background: lightenColor(NODE_COLORS[node.group] || '#6B7280'),
                border: NODE_COLORS[node.group] || '#6B7280'
            }
        },
        shape: NODE_SHAPES[node.group] || 'dot',
        // Topic nodes render largest, engines mid-size, everything else small.
        size: node.group === 'topic' ? 30 : (node.group === 'engine' ? 25 : 15),
        font: {
            color: '#F1F5F9',
            size: 12
        },
        // Keep the original API payload attached (not interpreted by vis).
        _data: node
    })));
    // Build the edge DataSet from edges whose endpoints are both visible.
    const edges = new vis.DataSet(allEdges
        .filter(e => filteredNodeIds.has(e.from) && filteredNodeIds.has(e.to))
        .map(edge => ({
            from: edge.from,
            to: edge.to,
            label: edge.label,
            arrows: edge.arrows || 'to',
            color: {
                color: '#475569',
                highlight: '#818CF8'
            },
            font: {
                color: '#94A3B8',
                size: 10,
                strokeWidth: 0
            },
            smooth: {
                type: 'continuous'
            }
        }))
    );
    // vis-network options: forceAtlas2-based physics with bounded
    // stabilization so large graphs settle in finite time.
    const options = {
        nodes: {
            borderWidth: 2,
            shadow: true
        },
        edges: {
            width: 1,
            shadow: true
        },
        physics: {
            enabled: true,
            solver: 'forceAtlas2Based',
            forceAtlas2Based: {
                gravitationalConstant: -100,
                centralGravity: 0.01,
                springLength: 150,
                springConstant: 0.08,
                damping: 0.5
            },
            stabilization: {
                enabled: true,
                iterations: 200
            }
        },
        interaction: {
            hover: true,
            tooltipDelay: 100,
            zoomView: true,
            dragView: true
        }
    };
    // Create (or replace) the network instance.
    network = new vis.Network(container, { nodes, edges }, options);
    // Clicking a node shows its detail panel; clicking empty space hides it.
    network.on('click', (params) => {
        if (params.nodes.length > 0) {
            const nodeId = params.nodes[0];
            const node = allNodes.find(n => n.id === nodeId);
            if (node) {
                showNodeDetail(node);
            }
        } else {
            hideNodeDetail();
        }
    });
    // Fit the viewport once physics stabilization finishes.
    network.once('stabilizationIterationsDone', () => {
        network.fit({ animation: true });
    });
}
// Show the detail panel for a clicked node.
// Security fix: property keys and values come from server-stored graph data
// (which may embed scraped/untrusted text) and were previously interpolated
// into innerHTML unescaped — an XSS vector. They are now HTML-escaped first.
function showNodeDetail(node) {
    const detailPanel = document.getElementById('nodeDetail');
    const titleEl = document.getElementById('detailTitle');
    const typeEl = document.getElementById('detailType');
    const propsEl = document.getElementById('detailProps');
    // textContent assignment is inherently safe — no escaping needed here.
    titleEl.textContent = node.label;
    const typeLabels = {
        topic: '主题',
        engine: '分析引擎',
        section: '报告段落',
        search_query: '搜索关键词',
        source: '数据来源'
    };
    typeEl.textContent = typeLabels[node.group] || node.group;
    // Escape the five HTML-significant characters for safe innerHTML use.
    const escapeHtml = (s) => String(s)
        .replace(/&/g, '&amp;')
        .replace(/</g, '&lt;')
        .replace(/>/g, '&gt;')
        .replace(/"/g, '&quot;')
        .replace(/'/g, '&#39;');
    // Render each truthy property as a key/value row.
    let propsHtml = '';
    const props = node.properties || {};
    for (const [key, value] of Object.entries(props)) {
        if (value) {
            propsHtml += `
                <div class="prop-item">
                    <div class="prop-key">${escapeHtml(key)}</div>
                    <div class="prop-value">${escapeHtml(truncateText(String(value), 200))}</div>
                </div>
            `;
        }
    }
    propsEl.innerHTML = propsHtml || '<div class="prop-item">无附加属性</div>';
    detailPanel.style.display = 'block';
}
// Hide the node detail panel.
function hideNodeDetail() {
    const panel = document.getElementById('nodeDetail');
    panel.style.display = 'none';
}
// Refresh the node/edge totals and the per-type counters in the sidebar.
// Missing/undefined stats fields display as 0.
function updateStats(stats) {
    const setCount = (elementId, value) => {
        document.getElementById(elementId).textContent = value || 0;
    };
    setCount('nodeCount', stats.total_nodes);
    setCount('edgeCount', stats.total_edges);
    // One counter per filterable node type, next to its checkbox.
    for (const type of ['topic', 'engine', 'section', 'search_query', 'source']) {
        setCount(`count-${type}`, stats[type]);
    }
}
// Return the node types whose filter checkboxes are currently checked.
function getVisibleTypes() {
    const checkboxes = document.querySelectorAll('.filter-item input[type="checkbox"]');
    return Array.from(checkboxes)
        .filter(box => box.checked)
        .map(box => box.dataset.type);
}
// Wire up all UI event handlers: sidebar toggle, view controls, fullscreen,
// node search, and type-filter checkboxes.
// Bug fix: the search handler previously selected ids from `allNodes`
// regardless of active type filters. Nodes hidden by a filter are never added
// to the vis DataSet (see renderGraph), and vis-network's selectNodes()/focus()
// throw for unknown ids — so searching with a filter active could crash.
// Matches are now restricted to currently visible types.
function setupEventListeners() {
    // Sidebar collapse/expand; graph and legend expand to full width.
    document.getElementById('toggleSidebar').addEventListener('click', () => {
        const sidebar = document.getElementById('sidebar');
        const container = document.getElementById('graphContainer');
        const legend = document.getElementById('legend');
        sidebar.classList.toggle('collapsed');
        container.classList.toggle('fullwidth');
        legend.classList.toggle('fullwidth');
    });
    // Fit the whole graph into the viewport.
    document.getElementById('fitBtn').addEventListener('click', () => {
        if (network) network.fit({ animation: true });
    });
    // Zoom in.
    document.getElementById('zoomInBtn').addEventListener('click', () => {
        if (network) {
            const scale = network.getScale() * 1.2;
            network.moveTo({ scale, animation: true });
        }
    });
    // Zoom out.
    document.getElementById('zoomOutBtn').addEventListener('click', () => {
        if (network) {
            const scale = network.getScale() / 1.2;
            network.moveTo({ scale, animation: true });
        }
    });
    // Toggle browser fullscreen.
    document.getElementById('fullscreenBtn').addEventListener('click', () => {
        if (!document.fullscreenElement) {
            document.documentElement.requestFullscreen();
        } else {
            document.exitFullscreen();
        }
    });
    // Search: select and focus nodes whose label contains the query
    // (case-insensitive). Only nodes of currently visible types are eligible.
    document.getElementById('searchInput').addEventListener('input', (e) => {
        const query = e.target.value.toLowerCase();
        if (!query) {
            if (network) network.selectNodes([]);
            return;
        }
        const visibleTypes = new Set(getVisibleTypes());
        const matchedIds = allNodes
            .filter(n => visibleTypes.has(n.group) && n.label.toLowerCase().includes(query))
            .map(n => n.id);
        if (network && matchedIds.length > 0) {
            network.selectNodes(matchedIds);
            network.focus(matchedIds[0], { animation: true, scale: 1.5 });
        }
    });
    // Re-render the graph whenever a type filter changes.
    document.querySelectorAll('.filter-item input[type="checkbox"]').forEach(cb => {
        cb.addEventListener('change', () => {
            renderGraph();
        });
    });
}
// ---- Helper functions ----
// Toggle the loading overlay.
function showLoading(show) {
    const overlay = document.getElementById('loadingOverlay');
    overlay.style.display = show ? 'flex' : 'none';
}
// Toggle the empty-state placeholder.
function showEmpty(show) {
    const placeholder = document.getElementById('emptyState');
    placeholder.style.display = show ? 'block' : 'none';
}
// Show a transient toast message that auto-hides after 3 seconds.
function showToast(message) {
    const toastEl = document.getElementById('toast');
    toastEl.textContent = message;
    toastEl.style.display = 'block';
    setTimeout(() => { toastEl.style.display = 'none'; }, 3000);
}
// Truncate `text` to at most `maxLen` characters, appending '...' when
// shortened. Falsy input yields an empty string.
function truncateLabel(text, maxLen) {
    if (!text) return '';
    if (text.length <= maxLen) return text;
    return text.slice(0, maxLen) + '...';
}
// Truncate `text` to at most `maxLen` characters, appending '...' when
// shortened. Falsy input yields an empty string.
// NOTE(review): identical contract to truncateLabel — candidates for merging.
function truncateText(text, maxLen) {
    if (!text) return '';
    if (text.length <= maxLen) return text;
    return text.slice(0, maxLen) + '...';
}
// Lighten a '#rrggbb' color by adding 40 to each channel, clamped at 255.
// Returns an 'rgb(r, g, b)' string.
function lightenColor(color) {
    const hex = color.replace('#', '');
    const [r, g, b] = [0, 2, 4].map(
        offset => Math.min(255, parseInt(hex.slice(offset, offset + 2), 16) + 40)
    );
    return `rgb(${r}, ${g}, ${b})`;
}
</script>
</body>
</html>
... ...