# streaming_vad.py
# AIfeng/2025-07-07 09:34:55
# 流式语音活动检测模块
# 核心功能:持续拼接的累积识别、智能语音分段、动态阈值优化

import numpy as np
import time
from typing import List, Optional, Callable
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from logger import get_logger

logger = get_logger("StreamingVAD")

class StreamingVAD:
    """Streaming voice-activity detector.

    Implements a continuous-accumulation recognition strategy:
    1. A speech segment is not closed the moment the volume drops; audio keeps
       accumulating into the current buffer.
    2. Segmentation is controlled by a silence-duration threshold plus
       minimum/maximum speech-length limits.
    3. Partial (intermediate) recognition results can be emitted periodically
       while a segment is still open.
    4. A dynamic volume threshold adapts to recent ambient levels to improve
       detection accuracy.

    Audio is expected as raw 16-bit PCM bytes (``np.int16``); each call to
    :meth:`process_audio_frame` should carry one chunk of ``chunk_size`` samples.
    """

    def __init__(self,
                 sample_rate: int = 16000,
                 chunk_size: int = 1024,
                 volume_threshold: float = 0.03,
                 silence_duration: float = 1.5,
                 min_speech_duration: float = 0.5,
                 max_speech_duration: float = 30.0,
                 pre_buffer_duration: float = 0.5,
                 dynamic_threshold_factor: float = 0.8,
                 partial_result_interval: float = 2.0):
        """Initialize the streaming VAD.

        Args:
            sample_rate: Audio sample rate in Hz.
            chunk_size: Number of samples per audio chunk (one "frame").
            volume_threshold: Baseline volume threshold on the normalized
                0-1 RMS scale; the dynamic threshold never drops below this.
            silence_duration: Continuous silence (seconds) that ends a segment.
            min_speech_duration: Minimum speech length (seconds); shorter
                segments are discarded.
            max_speech_duration: Maximum speech length (seconds) before the
                segment is force-ended.
            pre_buffer_duration: Pre-roll (seconds) of audio prepended to a
                segment when speech starts, so the onset is not clipped.
            dynamic_threshold_factor: Scale applied to the recent-volume
                percentile when deriving the dynamic threshold.
            partial_result_interval: Interval (seconds) between partial-result
                emissions while a segment is open.
        """
        self.sample_rate = sample_rate
        self.chunk_size = chunk_size
        self.volume_threshold = volume_threshold
        self.silence_duration = silence_duration
        self.min_speech_duration = min_speech_duration
        self.max_speech_duration = max_speech_duration
        self.pre_buffer_duration = pre_buffer_duration
        self.dynamic_threshold_factor = dynamic_threshold_factor
        self.partial_result_interval = partial_result_interval

        # Convert all duration parameters into frame (chunk) counts.
        self.silence_frames = int(silence_duration * sample_rate / chunk_size)
        self.min_speech_frames = int(min_speech_duration * sample_rate / chunk_size)
        self.max_speech_frames = int(max_speech_duration * sample_rate / chunk_size)
        self.pre_buffer_frames = int(pre_buffer_duration * sample_rate / chunk_size)
        self.partial_result_frames = int(partial_result_interval * sample_rate / chunk_size)

        # State-machine variables.
        self.is_speaking = False
        self.silence_counter = 0            # consecutive quiet frames while speaking
        self.speech_counter = 0             # loud frames in the current segment
        self.total_frames_since_start = 0   # frames processed since last reset()
        self.last_partial_result_frame = 0  # frame index of last partial emission

        # Audio buffers (lists of raw byte chunks).
        self.current_speech_buffer = []     # accumulated audio of the open segment
        self.pre_buffer = []                # rolling pre-roll of recent chunks

        # Dynamic-threshold bookkeeping.
        self.volume_history = []
        self.history_size = 50
        self.dynamic_threshold = volume_threshold

        # Optional event callbacks; each receives the result dict from
        # process_audio_frame().
        self.on_speech_start: Optional[Callable] = None
        self.on_speech_continue: Optional[Callable] = None
        self.on_speech_end: Optional[Callable] = None
        self.on_partial_result_ready: Optional[Callable] = None

        logger.info(f"StreamingVAD初始化完成 - 静音阈值:{silence_duration}s, 最小语音:{min_speech_duration}s, 最大语音:{max_speech_duration}s")

    def _calculate_volume(self, data: bytes) -> float:
        """Return the normalized RMS volume (0.0-1.0) of a 16-bit PCM chunk.

        Returns 0.0 for empty, invalid (NaN/inf), or undecodable input.
        """
        try:
            # Interpret the raw bytes as signed 16-bit samples.
            audio_data = np.frombuffer(data, dtype=np.int16)

            if len(audio_data) == 0:
                return 0.0

            # RMS in float64 to avoid int16 overflow when squaring.
            rms = np.sqrt(np.mean(audio_data.astype(np.float64)**2))

            # Guard against degenerate results.
            if np.isnan(rms) or np.isinf(rms):
                return 0.0

            # Normalize by the int16 full scale; clamp to 1.0.
            return min(rms / 32768.0, 1.0)

        except Exception as e:
            logger.warning(f"音量计算失败: {e}")
            return 0.0

    def _update_dynamic_threshold(self, volume: float):
        """Fold one volume sample into the rolling history and refresh the
        dynamic threshold.

        The threshold is the 75th percentile of the recent-volume window scaled
        by ``dynamic_threshold_factor``, floored at ``volume_threshold``. It is
        only updated once at least 10 samples have been observed.
        """
        self.volume_history.append(volume)
        if len(self.volume_history) > self.history_size:
            self.volume_history.pop(0)

        if len(self.volume_history) >= 10:
            percentile_75 = np.percentile(self.volume_history, 75)
            self.dynamic_threshold = max(
                self.volume_threshold,
                percentile_75 * self.dynamic_threshold_factor
            )

    def process_audio_frame(self, audio_data: bytes) -> dict:
        """Process one audio chunk and advance the VAD state machine.

        Args:
            audio_data: One chunk of raw 16-bit PCM bytes.

        Returns:
            dict: {
                'action': 'speech_start' | 'speech_continue' | 'speech_end' | 'partial_result' | 'silence',
                'audio_buffer': List[bytes],  # accumulated audio for this action
                'is_speaking': bool,
                'volume': float,
                'threshold': float,
                'speech_duration': float,     # seconds, based on counters at frame entry
                'silence_duration': float
            }

        Note: the registered callbacks are invoked synchronously with the same
        result dict; on the max-length path a 'partial_result' callback may fire
        and then the dict is re-labeled 'speech_end' for the end callback.
        """
        volume = self._calculate_volume(audio_data)
        self._update_dynamic_threshold(volume)

        # Maintain the rolling pre-roll buffer (includes the current chunk).
        self.pre_buffer.append(audio_data)
        if len(self.pre_buffer) > self.pre_buffer_frames:
            self.pre_buffer.pop(0)

        self.total_frames_since_start += 1

        result = {
            'action': 'silence',
            'audio_buffer': [],
            'is_speaking': self.is_speaking,
            'volume': volume,
            'threshold': self.dynamic_threshold,
            'speech_duration': self.speech_counter * self.chunk_size / self.sample_rate,
            'silence_duration': self.silence_counter * self.chunk_size / self.sample_rate
        }

        if volume > self.dynamic_threshold:
            # Voiced frame.
            if not self.is_speaking:
                # Speech onset: seed the segment with the pre-roll buffer.
                logger.debug(f"检测到语音开始 - 音量:{volume:.4f}, 阈值:{self.dynamic_threshold:.4f}")
                self.is_speaking = True
                self.silence_counter = 0
                self.speech_counter = 1
                self.last_partial_result_frame = self.total_frames_since_start

                self.current_speech_buffer = list(self.pre_buffer)

                result['action'] = 'speech_start'
                result['audio_buffer'] = list(self.current_speech_buffer)

                if self.on_speech_start:
                    self.on_speech_start(result)
            else:
                # Speech continues: keep accumulating.
                self.speech_counter += 1
                self.silence_counter = 0
                self.current_speech_buffer.append(audio_data)

                # Emit a partial result at the configured interval.
                frames_since_partial = self.total_frames_since_start - self.last_partial_result_frame
                if frames_since_partial >= self.partial_result_frames:
                    result['action'] = 'partial_result'
                    result['audio_buffer'] = list(self.current_speech_buffer)
                    self.last_partial_result_frame = self.total_frames_since_start

                    if self.on_partial_result_ready:
                        self.on_partial_result_ready(result)
                else:
                    result['action'] = 'speech_continue'
                    result['audio_buffer'] = list(self.current_speech_buffer)

                # Force-end the segment at the maximum speech length.
                if self.speech_counter >= self.max_speech_frames:
                    logger.debug(f"达到最大语音长度,强制结束 - 语音帧数:{self.speech_counter}")
                    result['action'] = 'speech_end'
                    result['audio_buffer'] = list(self.current_speech_buffer)

                    self._reset_speech_state()

                    if self.on_speech_end:
                        self.on_speech_end(result)
                elif self.on_speech_continue:
                    self.on_speech_continue(result)
        else:
            # Quiet frame.
            if self.is_speaking:
                # Still inside a segment: count silence but keep buffering so a
                # brief pause does not split the utterance.
                self.silence_counter += 1
                self.current_speech_buffer.append(audio_data)

                if self.silence_counter >= self.silence_frames:
                    # Silence threshold reached: close the segment.
                    logger.debug(f"检测到语音结束 - 语音帧数:{self.speech_counter}, 静音帧数:{self.silence_counter}")

                    if self.speech_counter >= self.min_speech_frames:
                        # Trim the trailing silence before handing the audio off.
                        speech_buffer = self.current_speech_buffer[:-self.silence_counter]
                        result['action'] = 'speech_end'
                        result['audio_buffer'] = speech_buffer

                        if self.on_speech_end:
                            self.on_speech_end(result)
                    else:
                        # Segment too short: drop it.
                        logger.debug(f"语音片段太短,跳过: {self.speech_counter} < {self.min_speech_frames}")
                        result['action'] = 'silence'

                    self._reset_speech_state()
                else:
                    # Quiet, but below the silence threshold: treat as continuation.
                    result['action'] = 'speech_continue'
                    result['audio_buffer'] = list(self.current_speech_buffer)

                    if self.on_speech_continue:
                        self.on_speech_continue(result)

        result['is_speaking'] = self.is_speaking
        return result

    def _reset_speech_state(self):
        """Reset the per-segment state (speaking flag, counters, buffer)."""
        self.is_speaking = False
        self.silence_counter = 0
        self.speech_counter = 0
        self.current_speech_buffer = []

    def force_end_speech(self) -> Optional[dict]:
        """Force-close the current speech segment, if one is open and long enough.

        Returns the usual result dict with action 'speech_end' (and invokes
        ``on_speech_end``), or None if no segment is open or it is shorter than
        the minimum speech length.

        Fix: the minimum-length check now uses ``speech_counter`` — the same
        rule as process_audio_frame() — instead of the buffer length, which
        also contains pre-roll and trailing-silence chunks and so overcounted.
        """
        if self.is_speaking and self.speech_counter >= self.min_speech_frames:
            logger.info("强制结束当前语音片段")

            result = {
                'action': 'speech_end',
                'audio_buffer': list(self.current_speech_buffer),
                'is_speaking': True,
                'volume': 0.0,
                'threshold': self.dynamic_threshold,
                'speech_duration': self.speech_counter * self.chunk_size / self.sample_rate,
                'silence_duration': self.silence_counter * self.chunk_size / self.sample_rate
            }

            self._reset_speech_state()

            if self.on_speech_end:
                self.on_speech_end(result)

            return result
        return None

    def get_status(self) -> dict:
        """Return a snapshot of the VAD state for monitoring/debugging."""
        return {
            'is_speaking': self.is_speaking,
            'dynamic_threshold': self.dynamic_threshold,
            'volume_threshold': self.volume_threshold,
            'silence_duration': self.silence_duration,
            'min_speech_duration': self.min_speech_duration,
            'max_speech_duration': self.max_speech_duration,
            'current_speech_frames': self.speech_counter,
            'current_silence_frames': self.silence_counter,
            'speech_buffer_size': len(self.current_speech_buffer),
            'total_frames_processed': self.total_frames_since_start
        }

    def reset(self):
        """Reset ALL state, including buffers, history and the dynamic threshold."""
        self._reset_speech_state()
        self.total_frames_since_start = 0
        self.last_partial_result_frame = 0
        self.pre_buffer = []
        self.volume_history = []
        self.dynamic_threshold = self.volume_threshold
        logger.info("StreamingVAD状态已重置")