# streaming_recognition_manager.py 13.9 KB
# AIfeng/2025-07-07 09:34:55
# 流式识别结果管理器
# 核心功能:解决重复识别问题、管理部分和最终识别结果、智能结果合并

import time
import threading
from typing import Dict, List, Optional, Callable, Any
from dataclasses import dataclass
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from logger import get_logger

# Module-level logger used by StreamingRecognitionManager for its status, warning and debug messages.
logger = get_logger("StreamingRecognitionManager")

@dataclass
class RecognitionResult:
    """One recognition result (partial or final) produced within a session."""

    # Identifier of the session the result belongs to
    session_id: str
    # Result kind: 'partial' or 'final' (the values StreamingRecognitionManager uses)
    result_type: str
    # Recognized text (the manager strips surrounding whitespace before storing)
    text: str
    # Recognizer confidence, compared against the manager's confidence_threshold
    confidence: float
    # Creation time as returned by time.time()
    timestamp: float
    # Audio duration covered by this result, as reported by the caller (units not enforced here)
    audio_duration: float
    # Flag set by the manager when a similar final result supersedes this partial result
    is_processed: bool = False

class StreamingRecognitionManager:
    """Streaming recognition result manager.

    Responsibilities:
    1. Keep partial and final recognition results per session.
    2. Drop duplicate results: a new result is rejected when one of the 5 most
       recent results of the same type, created within ``result_merge_window``
       seconds, has text similarity above 0.9.
    3. Maintain a merged text per session and report incremental updates.
    4. Exclude results below ``confidence_threshold`` from the merged text.
    5. Notify listeners (e.g. a UI) through optional callbacks.

    Session-state and query methods acquire ``self.lock`` (a re-entrant lock).
    Callbacks run while that lock is held; an exception raised by a callback
    is logged and does not abort the operation that triggered it.

    A daemon thread periodically removes expired and completed sessions. Call
    :meth:`stop` to terminate it when the manager is no longer needed.
    """

    def __init__(self,
                 confidence_threshold: float = 0.6,
                 max_session_duration: float = 60.0,
                 result_merge_window: float = 1.0,
                 auto_cleanup_interval: float = 300.0):
        """Initialize the manager and start the background cleanup thread.

        Args:
            confidence_threshold: Minimum confidence for a result to be
                included in the merged text.
            max_session_duration: Maximum session age in seconds, measured from
                creation; older sessions are removed by the next cleanup pass.
            result_merge_window: Time window in seconds within which a similar
                result of the same type is treated as a duplicate.
            auto_cleanup_interval: Seconds between cleanup passes; a session
                idle for longer than this is also removed. A value <= 0
                disables the background cleanup thread.
        """
        self.confidence_threshold = confidence_threshold
        self.max_session_duration = max_session_duration
        self.result_merge_window = result_merge_window
        self.auto_cleanup_interval = auto_cleanup_interval

        # Result storage; every dict is keyed by session_id
        self.active_sessions: Dict[str, Dict[str, Any]] = {}  # times, metadata, status
        self.partial_results: Dict[str, List[RecognitionResult]] = {}
        self.final_results: Dict[str, List[RecognitionResult]] = {}
        self.merged_results: Dict[str, str] = {}  # text built by _update_merged_result

        # Re-entrant: callbacks run under this lock and may call back into the
        # manager from the same thread.
        self.lock = threading.RLock()

        # Optional listener callbacks (argument lists as passed below)
        self.on_partial_result: Optional[Callable] = None    # (session_id, RecognitionResult)
        self.on_final_result: Optional[Callable] = None      # (session_id, RecognitionResult)
        self.on_result_updated: Optional[Callable] = None    # (session_id, merged_text, 'partial' | 'final')
        self.on_session_complete: Optional[Callable] = None  # (session_id, merged_text)

        # Set by stop(); the worker waits on it so it can exit immediately
        # instead of sleeping out a full interval.
        self._stop_event = threading.Event()
        self.cleanup_thread = threading.Thread(target=self._auto_cleanup_worker, daemon=True)
        if auto_cleanup_interval > 0:
            self.cleanup_thread.start()
        else:
            # Waiting with a non-positive interval would turn the worker into a
            # busy loop, so automatic cleanup is disabled instead.
            logger.warning(f"自动清理已禁用 - auto_cleanup_interval:{auto_cleanup_interval}")

        logger.info(f"StreamingRecognitionManager初始化完成 - 置信度阈值:{confidence_threshold}")

    def stop(self, timeout: Optional[float] = None) -> None:
        """Stop the background cleanup thread (safe to call more than once).

        Stored sessions are left untouched. Do not call this from a callback
        with ``timeout=None``: callbacks run under ``self.lock`` and the worker
        may be waiting for that lock, so the join would never return.

        Args:
            timeout: Maximum number of seconds to wait for the worker to exit;
                ``None`` waits indefinitely.
        """
        self._stop_event.set()
        worker = self.cleanup_thread
        if worker.is_alive() and worker is not threading.current_thread():
            worker.join(timeout)

    def _invoke_callback(self, callback: Optional[Callable], *args: Any) -> None:
        """Call ``callback(*args)`` if it is set, logging instead of propagating errors."""
        if not callback:
            return
        try:
            callback(*args)
        except Exception as e:
            # The result is already stored when callbacks fire; a failing
            # listener must not make the caller see an error or skip the
            # remaining callbacks.
            logger.error(f"回调执行错误: {e}")

    def create_session(self, session_id: str, metadata: Optional[Dict[str, Any]] = None) -> bool:
        """Create a new recognition session.

        Args:
            session_id: Unique session identifier.
            metadata: Optional caller data stored with the session.

        Returns:
            True if created; False if a session with this id already exists
            (including a completed one that has not been cleaned up yet).
        """
        with self.lock:
            if session_id in self.active_sessions:
                logger.warning(f"会话已存在: {session_id}")
                return False

            now = time.time()
            self.active_sessions[session_id] = {
                'start_time': now,
                'last_update': now,
                'metadata': metadata or {},
                'status': 'active',
            }

            self.partial_results[session_id] = []
            self.final_results[session_id] = []
            self.merged_results[session_id] = ""

            logger.info(f"创建识别会话: {session_id}")
            return True

    def add_partial_result(self, session_id: str, text: str, confidence: float = 1.0,
                           audio_duration: float = 0.0) -> bool:
        """Add a partial (interim) recognition result to a session.

        Args:
            session_id: Target session; it must have been created first.
            text: Recognized text; surrounding whitespace is stripped.
            confidence: Recognizer confidence for this result.
            audio_duration: Audio duration covered by the result, as reported
                by the caller.

        Returns:
            True if the result was stored; False if the session does not exist
            or the result was rejected as a duplicate.
        """
        with self.lock:
            if session_id not in self.active_sessions:
                logger.warning(f"会话不存在: {session_id}")
                return False

            result = RecognitionResult(
                session_id=session_id,
                result_type='partial',
                text=text.strip(),
                confidence=confidence,
                timestamp=time.time(),
                audio_duration=audio_duration
            )

            if self._is_duplicate_result(session_id, result, 'partial'):
                logger.debug(f"跳过重复的部分结果 [{session_id}]: {text[:30]}...")
                return False

            self.partial_results[session_id].append(result)
            self.active_sessions[session_id]['last_update'] = time.time()

            self._update_merged_result(session_id)

            logger.debug(f"添加部分结果 [{session_id}]: {text[:50]}...")

            self._invoke_callback(self.on_partial_result, session_id, result)
            # .get(): a re-entrant callback may have reset the manager meanwhile
            self._invoke_callback(self.on_result_updated, session_id,
                                  self.merged_results.get(session_id, ""), 'partial')
            return True

    def add_final_result(self, session_id: str, text: str, confidence: float = 1.0,
                         audio_duration: float = 0.0) -> bool:
        """Add a final recognition result to a session.

        Partial results similar to the new final text (similarity > 0.7) are
        marked as processed, so they drop out of the merged text.

        Args:
            session_id: Target session; it must have been created first.
            text: Recognized text; surrounding whitespace is stripped.
            confidence: Recognizer confidence for this result.
            audio_duration: Audio duration covered by the result, as reported
                by the caller.

        Returns:
            True if the result was stored; False if the session does not exist
            or the result was rejected as a duplicate.
        """
        with self.lock:
            if session_id not in self.active_sessions:
                logger.warning(f"会话不存在: {session_id}")
                return False

            result = RecognitionResult(
                session_id=session_id,
                result_type='final',
                text=text.strip(),
                confidence=confidence,
                timestamp=time.time(),
                audio_duration=audio_duration
            )

            if self._is_duplicate_result(session_id, result, 'final'):
                logger.debug(f"跳过重复的最终结果 [{session_id}]: {text[:30]}...")
                return False

            self.final_results[session_id].append(result)
            self.active_sessions[session_id]['last_update'] = time.time()

            self._clear_related_partial_results(session_id, result)
            self._update_merged_result(session_id)

            logger.info(f"添加最终结果 [{session_id}]: {text}")

            self._invoke_callback(self.on_final_result, session_id, result)
            # .get(): a re-entrant callback may have reset the manager meanwhile
            self._invoke_callback(self.on_result_updated, session_id,
                                  self.merged_results.get(session_id, ""), 'final')
            return True

    def _is_duplicate_result(self, session_id: str, new_result: RecognitionResult,
                             result_type: str) -> bool:
        """Return True if ``new_result`` duplicates a recent result of the same type.

        Only the 5 most recent results are inspected. Those created within
        ``result_merge_window`` seconds before the new one are compared, and
        similarity above 0.9 counts as a duplicate. Called with ``self.lock`` held.
        """
        store = self.partial_results if result_type == 'partial' else self.final_results
        recent = store[session_id][-5:]

        # NOTE(review): _calculate_text_similarity scores containment as 0.95,
        # so a result that merely extends a recent one (e.g. "ab" -> "abc")
        # is treated as a duplicate and dropped. Confirm this matches how the
        # upstream recognizer emits partials (cumulative vs. incremental text).
        for existing in reversed(recent):
            if new_result.timestamp - existing.timestamp > self.result_merge_window:
                continue
            if self._calculate_text_similarity(new_result.text, existing.text) > 0.9:
                return True
        return False

    def _calculate_text_similarity(self, text1: str, text2: str) -> float:
        """Return a rough similarity score in [0, 1] between two texts.

        Scoring rules:
        * either text empty -> 0.0
        * identical texts -> 1.0
        * one text contains the other -> 0.95
        * otherwise, the number of aligned positions holding the same character
          divided by the longer length. This is a positional match ratio, not
          an edit distance, so an insertion near the start lowers it sharply.
        """
        if not text1 or not text2:
            return 0.0
        if text1 == text2:
            return 1.0
        if text1 in text2 or text2 in text1:
            return 0.95

        # Both texts are non-empty here, so max_len >= 1.
        max_len = max(len(text1), len(text2))
        matches = sum(c1 == c2 for c1, c2 in zip(text1, text2))
        return matches / max_len

    def _clear_related_partial_results(self, session_id: str, final_result: RecognitionResult):
        """Mark unprocessed partial results similar (> 0.7) to ``final_result`` as processed.

        Results are flagged, not removed. Called with ``self.lock`` held.
        """
        for partial in self.partial_results[session_id]:
            if partial.is_processed:
                continue
            if self._calculate_text_similarity(partial.text, final_result.text) > 0.7:
                partial.is_processed = True
                logger.debug(f"标记部分结果为已处理: {partial.text[:30]}...")

    def _update_merged_result(self, session_id: str):
        """Rebuild ``merged_results[session_id]``. Called with ``self.lock`` held.

        The merged text joins, with single spaces:
        * the texts of all unprocessed final results meeting the confidence
          threshold, in arrival order;
        * then the latest partial result, prefixed with "[部分] ", when it is
          unprocessed, meets the threshold, and is not similar (> 0.7) to any
          final result of the session.
        """
        finals = self.final_results[session_id]
        pieces = [r.text for r in finals
                  if not r.is_processed and r.confidence >= self.confidence_threshold]

        partials = self.partial_results[session_id]
        if partials:
            latest_partial = partials[-1]
            if not latest_partial.is_processed and latest_partial.confidence >= self.confidence_threshold:
                has_corresponding_final = any(
                    self._calculate_text_similarity(latest_partial.text, final.text) > 0.7
                    for final in finals
                )
                if not has_corresponding_final:
                    pieces.append(f"[部分] {latest_partial.text}")

        self.merged_results[session_id] = " ".join(pieces)

    def get_merged_result(self, session_id: str) -> str:
        """Return the merged text of a session, or "" for an unknown session."""
        with self.lock:
            return self.merged_results.get(session_id, "")

    def get_session_results(self, session_id: str) -> Dict[str, List[RecognitionResult]]:
        """Return snapshots of a session's results under the keys 'partial' and 'final'.

        The lists are shallow copies, so they can be iterated safely while
        other threads keep adding results. The RecognitionResult objects are
        shared with the manager. An unknown session yields empty lists.
        """
        with self.lock:
            return {
                'partial': list(self.partial_results.get(session_id, [])),
                'final': list(self.final_results.get(session_id, [])),
            }

    def complete_session(self, session_id: str) -> bool:
        """Mark a session as completed and emit its final merged text.

        The session data stays queryable until a cleanup pass (or reset)
        removes it.

        Returns:
            True if this call completed the session. False if the session does
            not exist or was already completed, so ``on_session_complete``
            fires at most once per session.
        """
        with self.lock:
            session = self.active_sessions.get(session_id)
            if session is None:
                return False
            if session['status'] == 'completed':
                logger.warning(f"会话已完成,忽略重复的完成请求: {session_id}")
                return False

            session['status'] = 'completed'
            session['end_time'] = time.time()

            # Final refresh so the emitted text reflects every stored result
            self._update_merged_result(session_id)
            final_result = self.merged_results[session_id]
            logger.info(f"完成识别会话 [{session_id}]: {final_result}")

            self._invoke_callback(self.on_session_complete, session_id, final_result)
            return True

    def _auto_cleanup_worker(self):
        """Background loop: run a cleanup pass every ``auto_cleanup_interval`` seconds until stop()."""
        # Event.wait() returns True as soon as stop() sets the event, ending
        # the loop without waiting out the remaining interval.
        while not self._stop_event.wait(self.auto_cleanup_interval):
            try:
                self._cleanup_old_sessions()
            except Exception as e:
                logger.error(f"自动清理线程错误: {e}")

    def _cleanup_old_sessions(self):
        """Remove sessions that are too old, idle too long, or completed."""
        current_time = time.time()

        with self.lock:
            sessions_to_remove = [
                session_id
                for session_id, info in self.active_sessions.items()
                if (current_time - info['start_time'] > self.max_session_duration
                    or current_time - info['last_update'] > self.auto_cleanup_interval
                    or info['status'] == 'completed')
            ]

            for session_id in sessions_to_remove:
                logger.info(f"清理过期会话: {session_id}")
                # pop(..., None) keeps cleanup going even if the stores drifted apart
                self.active_sessions.pop(session_id, None)
                self.partial_results.pop(session_id, None)
                self.final_results.pop(session_id, None)
                self.merged_results.pop(session_id, None)

    def get_active_sessions(self) -> List[str]:
        """Return the ids of sessions whose status is still 'active'."""
        with self.lock:
            return [sid for sid, info in self.active_sessions.items() if info['status'] == 'active']

    def get_status(self) -> Dict[str, Any]:
        """Return session counts and the manager's configuration values."""
        with self.lock:
            active_count = sum(1 for info in self.active_sessions.values() if info['status'] == 'active')
            return {
                'active_sessions_count': active_count,
                'total_sessions_count': len(self.active_sessions),
                'confidence_threshold': self.confidence_threshold,
                'max_session_duration': self.max_session_duration,
                'result_merge_window': self.result_merge_window
            }

    def reset(self):
        """Drop all sessions and results; callbacks and the cleanup thread are unaffected."""
        with self.lock:
            self.active_sessions.clear()
            self.partial_results.clear()
            self.final_results.clear()
            self.merged_results.clear()
            logger.info("StreamingRecognitionManager状态已重置")