streaming_recognition_manager.py
13.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
# AIfeng/2025-07-07 09:34:55
# 流式识别结果管理器
# 核心功能:解决重复识别问题、管理部分和最终识别结果、智能结果合并
import time
import threading
from typing import Dict, List, Optional, Callable, Any
from dataclasses import dataclass
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from logger import get_logger
logger = get_logger("StreamingRecognitionManager")
@dataclass
class RecognitionResult:
    """A single recognition result (partial or final) belonging to one session."""

    # Identifier of the session this result was reported for.
    session_id: str
    # Kind of result: 'partial' or 'final'.
    result_type: str
    # Recognized text.
    text: str
    # Confidence score reported for this result.
    confidence: float
    # Creation time of the result (seconds, as produced by time.time()).
    timestamp: float
    # Audio duration associated with the result.
    audio_duration: float
    # Set once the result has been superseded/handled; new results start unprocessed.
    is_processed: bool = False
class StreamingRecognitionManager:
    """Manager for streaming speech-recognition results.

    Responsibilities:
    1. Store partial and final recognition results per session.
    2. Suppress duplicate results (similar text inside a short time window).
    3. Maintain an incrementally updated merged text per session.
    4. Filter results by a confidence threshold when merging.
    5. Notify listeners through optional callbacks.

    All public methods take ``self.lock`` (a re-entrant lock). Callbacks are
    invoked while that lock is held, on the thread that called the method.
    """

    # A new result whose similarity to a recent one exceeds this is a duplicate.
    DUPLICATE_SIMILARITY_THRESHOLD = 0.9
    # A partial whose similarity to a final exceeds this is considered covered by it.
    RELATED_SIMILARITY_THRESHOLD = 0.7
    # Number of most recent results inspected by the duplicate check.
    DUPLICATE_LOOKBACK = 5

    def __init__(self,
                 confidence_threshold: float = 0.6,
                 max_session_duration: float = 60.0,
                 result_merge_window: float = 1.0,
                 auto_cleanup_interval: float = 300.0):
        """Initialise the manager and start the background cleanup thread.

        Args:
            confidence_threshold: Minimum confidence for a result to appear
                in the merged text.
            max_session_duration: Maximum session age in seconds before the
                cleanup pass removes it.
            result_merge_window: Time window in seconds inside which similar
                results are treated as duplicates.
            auto_cleanup_interval: Seconds between cleanup passes; also the
                idle time after which a session is removed.
        """
        self.confidence_threshold = confidence_threshold
        self.max_session_duration = max_session_duration
        self.result_merge_window = result_merge_window
        self.auto_cleanup_interval = auto_cleanup_interval

        # Result storage, all keyed by session id.
        self.active_sessions: Dict[str, Dict] = {}
        self.partial_results: Dict[str, List[RecognitionResult]] = {}
        self.final_results: Dict[str, List[RecognitionResult]] = {}
        self.merged_results: Dict[str, str] = {}

        # Re-entrant so a callback may call back into the manager.
        self.lock = threading.RLock()

        # Optional listener callbacks (None means "not registered").
        self.on_partial_result: Optional[Callable] = None
        self.on_final_result: Optional[Callable] = None
        self.on_result_updated: Optional[Callable] = None
        self.on_session_complete: Optional[Callable] = None

        # Signalled by shutdown() so the cleanup thread can exit promptly.
        self._stop_event = threading.Event()
        self.cleanup_thread = threading.Thread(target=self._auto_cleanup_worker, daemon=True)
        self.cleanup_thread.start()

        logger.info(f"StreamingRecognitionManager初始化完成 - 置信度阈值:{confidence_threshold}")

    def create_session(self, session_id: str, metadata: Optional[Dict[str, Any]] = None) -> bool:
        """Create a new recognition session.

        Args:
            session_id: Identifier of the session to create.
            metadata: Optional dictionary stored with the session.

        Returns:
            True if the session was created, False if the id already exists.
        """
        with self.lock:
            if session_id in self.active_sessions:
                logger.warning(f"会话已存在: {session_id}")
                return False

            self.active_sessions[session_id] = {
                'start_time': time.time(),
                'last_update': time.time(),
                'metadata': metadata or {},
                'status': 'active'
            }
            self.partial_results[session_id] = []
            self.final_results[session_id] = []
            self.merged_results[session_id] = ""

            logger.info(f"创建识别会话: {session_id}")
            return True

    def add_partial_result(self, session_id: str, text: str, confidence: float = 1.0,
                           audio_duration: float = 0.0) -> bool:
        """Record a partial (interim) recognition result.

        Args:
            session_id: Session the result belongs to.
            text: Recognized text; surrounding whitespace is stripped.
            confidence: Confidence score of the result.
            audio_duration: Audio duration associated with the result.

        Returns:
            True if the result was stored, False if the session is unknown or
            the result was rejected as a duplicate.
        """
        with self.lock:
            if session_id not in self.active_sessions:
                logger.warning(f"会话不存在: {session_id}")
                return False

            result = RecognitionResult(
                session_id=session_id,
                result_type='partial',
                text=text.strip(),
                confidence=confidence,
                timestamp=time.time(),
                audio_duration=audio_duration
            )

            if self._is_duplicate_result(session_id, result, 'partial'):
                logger.debug(f"跳过重复的部分结果 [{session_id}]: {text[:30]}...")
                return False

            self.partial_results[session_id].append(result)
            self.active_sessions[session_id]['last_update'] = time.time()
            self._update_merged_result(session_id)

            logger.debug(f"添加部分结果 [{session_id}]: {text[:50]}...")

            # Notify listeners (still holding the lock).
            if self.on_partial_result:
                self.on_partial_result(session_id, result)
            if self.on_result_updated:
                self.on_result_updated(session_id, self.merged_results[session_id], 'partial')
            return True

    def add_final_result(self, session_id: str, text: str, confidence: float = 1.0,
                         audio_duration: float = 0.0) -> bool:
        """Record a final recognition result.

        Partial results similar to the new final result are marked as
        processed so they no longer show up in the merged text.

        Args:
            session_id: Session the result belongs to.
            text: Recognized text; surrounding whitespace is stripped.
            confidence: Confidence score of the result.
            audio_duration: Audio duration associated with the result.

        Returns:
            True if the result was stored, False if the session is unknown or
            the result was rejected as a duplicate.
        """
        with self.lock:
            if session_id not in self.active_sessions:
                logger.warning(f"会话不存在: {session_id}")
                return False

            result = RecognitionResult(
                session_id=session_id,
                result_type='final',
                text=text.strip(),
                confidence=confidence,
                timestamp=time.time(),
                audio_duration=audio_duration
            )

            if self._is_duplicate_result(session_id, result, 'final'):
                logger.debug(f"跳过重复的最终结果 [{session_id}]: {text[:30]}...")
                return False

            self.final_results[session_id].append(result)
            self.active_sessions[session_id]['last_update'] = time.time()

            # Retire the partials this final result supersedes, then re-merge.
            self._clear_related_partial_results(session_id, result)
            self._update_merged_result(session_id)

            logger.info(f"添加最终结果 [{session_id}]: {text}")

            # Notify listeners (still holding the lock).
            if self.on_final_result:
                self.on_final_result(session_id, result)
            if self.on_result_updated:
                self.on_result_updated(session_id, self.merged_results[session_id], 'final')
            return True

    def _is_duplicate_result(self, session_id: str, new_result: "RecognitionResult",
                             result_type: str) -> bool:
        """Return True if ``new_result`` duplicates a recent result of the same kind.

        Only the last ``DUPLICATE_LOOKBACK`` stored results are inspected, and
        only those no older than ``result_merge_window`` seconds relative to
        the new result.
        """
        if result_type == 'partial':
            results_list = self.partial_results[session_id]
        else:
            results_list = self.final_results[session_id]

        # NOTE(review): containment scores 0.95 in _calculate_text_similarity,
        # so a partial that merely extends the previous one inside the window
        # is treated as a duplicate - confirm this is intended for
        # incremental recognizer output.
        for existing_result in reversed(results_list[-self.DUPLICATE_LOOKBACK:]):
            time_diff = new_result.timestamp - existing_result.timestamp
            if time_diff > self.result_merge_window:
                continue
            similarity = self._calculate_text_similarity(new_result.text, existing_result.text)
            if similarity > self.DUPLICATE_SIMILARITY_THRESHOLD:
                return True
        return False

    def _calculate_text_similarity(self, text1: str, text2: str) -> float:
        """Return a simple similarity score in [0.0, 1.0] for two strings.

        Scoring: 0.0 if either string is empty, 1.0 if equal, 0.95 if one
        contains the other; otherwise the fraction of positions (over the
        longer length) holding the same character in both strings.
        """
        if not text1 or not text2:
            return 0.0
        if text1 == text2:
            return 1.0
        if text1 in text2 or text2 in text1:
            return 0.95

        # Position-wise character match ratio (not a true edit distance).
        # Both strings are non-empty here, so max_len is at least 1.
        max_len = max(len(text1), len(text2))
        common_chars = sum(1 for c1, c2 in zip(text1, text2) if c1 == c2)
        return common_chars / max_len

    def _clear_related_partial_results(self, session_id: str, final_result: "RecognitionResult"):
        """Mark unprocessed partial results similar to ``final_result`` as processed."""
        for partial_result in self.partial_results[session_id]:
            if partial_result.is_processed:
                continue
            similarity = self._calculate_text_similarity(partial_result.text, final_result.text)
            if similarity > self.RELATED_SIMILARITY_THRESHOLD:
                partial_result.is_processed = True
                logger.debug(f"标记部分结果为已处理: {partial_result.text[:30]}...")

    def _update_merged_result(self, session_id: str):
        """Rebuild the merged text of a session.

        The merged text is every confident, unprocessed, non-empty final
        result joined by spaces, followed by the latest partial result
        (prefixed with a marker) when it is confident, unprocessed, non-empty
        and not already covered by a final result.
        """
        final_list = self.final_results[session_id]
        # Empty texts are skipped so they do not add stray separators.
        merged_parts = [
            result.text
            for result in final_list
            if result.text
            and not result.is_processed
            and result.confidence >= self.confidence_threshold
        ]

        partial_list = self.partial_results[session_id]
        if partial_list:
            latest_partial = partial_list[-1]
            if (latest_partial.text
                    and not latest_partial.is_processed
                    and latest_partial.confidence >= self.confidence_threshold):
                has_corresponding_final = any(
                    self._calculate_text_similarity(latest_partial.text, final_result.text)
                    > self.RELATED_SIMILARITY_THRESHOLD
                    for final_result in final_list
                )
                if not has_corresponding_final:
                    merged_parts.append(f"[部分] {latest_partial.text}")

        self.merged_results[session_id] = " ".join(merged_parts)

    def get_merged_result(self, session_id: str) -> str:
        """Return the merged text of a session, or "" if the session is unknown."""
        with self.lock:
            return self.merged_results.get(session_id, "")

    def get_session_results(self, session_id: str) -> "Dict[str, List[RecognitionResult]]":
        """Return the stored results of a session.

        Returns:
            A dict with keys 'partial' and 'final'. The lists are shallow
            copies, so callers cannot alter the manager's internal lists;
            both are empty for an unknown session.
        """
        with self.lock:
            return {
                'partial': list(self.partial_results.get(session_id, [])),
                'final': list(self.final_results.get(session_id, []))
            }

    def complete_session(self, session_id: str) -> bool:
        """Mark a session as completed and report its merged text.

        Returns:
            True if the session existed, False otherwise.
        """
        with self.lock:
            if session_id not in self.active_sessions:
                return False

            session_info = self.active_sessions[session_id]
            session_info['status'] = 'completed'
            session_info['end_time'] = time.time()

            # Final refresh of the merged text before reporting it.
            self._update_merged_result(session_id)
            final_result = self.merged_results[session_id]

            logger.info(f"完成识别会话 [{session_id}]: {final_result}")

            if self.on_session_complete:
                self.on_session_complete(session_id, final_result)
            return True

    def _auto_cleanup_worker(self):
        """Background loop: run a cleanup pass every ``auto_cleanup_interval`` seconds.

        The loop ends once ``shutdown()`` has set the stop event.
        """
        # Event.wait returns True as soon as the stop event is set.
        while not self._stop_event.wait(self.auto_cleanup_interval):
            try:
                self._cleanup_old_sessions()
            except Exception as e:
                # Keep the worker alive; a failed pass is retried next interval.
                logger.error(f"自动清理线程错误: {e}")

    def _cleanup_old_sessions(self):
        """Remove sessions that are too old, idle for too long, or completed."""
        current_time = time.time()

        with self.lock:
            sessions_to_remove = [
                session_id
                for session_id, session_info in self.active_sessions.items()
                if (current_time - session_info['start_time'] > self.max_session_duration
                    or current_time - session_info['last_update'] > self.auto_cleanup_interval
                    or session_info['status'] == 'completed')
            ]

            for session_id in sessions_to_remove:
                logger.info(f"清理过期会话: {session_id}")
                # pop() tolerates stores that are missing the key.
                self.active_sessions.pop(session_id, None)
                self.partial_results.pop(session_id, None)
                self.final_results.pop(session_id, None)
                self.merged_results.pop(session_id, None)

    def get_active_sessions(self) -> List[str]:
        """Return the ids of all sessions whose status is 'active'."""
        with self.lock:
            return [sid for sid, info in self.active_sessions.items() if info['status'] == 'active']

    def get_status(self) -> Dict[str, Any]:
        """Return session counts and the current configuration values."""
        with self.lock:
            active_count = sum(
                1 for info in self.active_sessions.values() if info['status'] == 'active'
            )
            return {
                'active_sessions_count': active_count,
                'total_sessions_count': len(self.active_sessions),
                'confidence_threshold': self.confidence_threshold,
                'max_session_duration': self.max_session_duration,
                'result_merge_window': self.result_merge_window
            }

    def reset(self):
        """Drop all sessions and their stored results."""
        with self.lock:
            self.active_sessions.clear()
            self.partial_results.clear()
            self.final_results.clear()
            self.merged_results.clear()
            logger.info("StreamingRecognitionManager状态已重置")

    def shutdown(self, timeout: Optional[float] = None) -> None:
        """Stop the background cleanup thread.

        Args:
            timeout: Maximum number of seconds to wait for the thread to
                finish; None waits until it has exited.
        """
        self._stop_event.set()
        worker = self.cleanup_thread
        # A thread cannot join itself (e.g. when called from a cleanup path).
        if worker.is_alive() and worker is not threading.current_thread():
            worker.join(timeout)