# audio_utils.py 7.1 KB
# AIfeng/2025-07-11 13:36:00
"""
Audio processing utilities for Doubao ASR.

Provides audio format detection, chunk slicing, and metadata extraction.
"""

import wave
from io import BytesIO
from typing import Tuple, Generator, Dict, Any


class AudioProcessor:
    """Stateless helpers for inspecting and chunking in-memory audio bytes.

    Every public method is a ``@staticmethod``; the class is used purely as
    a namespace and holds no instance state.
    """

    # Possible second bytes of an MPEG Layer III frame header (the first
    # byte is always 0xFF). Bit layout of the second byte: 111VVLLP, where
    # VV = MPEG version, LL = layer (01 = Layer III), P = protection bit.
    #   0xFA / 0xFB -> MPEG-1   Layer III (CRC / no CRC)
    #   0xF2 / 0xF3 -> MPEG-2   Layer III (CRC / no CRC)
    #   0xE2 / 0xE3 -> MPEG-2.5 Layer III (CRC / no CRC)
    _MP3_FRAME_SECOND_BYTES = frozenset((0xFA, 0xFB, 0xF2, 0xF3, 0xE2, 0xE3))

    @staticmethod
    def read_wav_info(audio_data: bytes) -> Tuple[int, int, int, int, bytes]:
        """
        Parse a WAV byte string and return its parameters and sample data.

        Args:
            audio_data: Complete WAV file content.

        Returns:
            Tuple: (channel count, sample width in bytes, sample rate,
            frame count, raw frame bytes).

        Raises:
            ValueError: If the data cannot be parsed by the ``wave`` module.
                The original exception is attached as ``__cause__``.
        """
        try:
            with BytesIO(audio_data) as audio_io:
                with wave.open(audio_io, 'rb') as wave_fp:
                    nchannels, sampwidth, framerate, nframes = wave_fp.getparams()[:4]
                    wave_bytes = wave_fp.readframes(nframes)
        except Exception as e:
            # Callers only need to handle ValueError; chaining keeps the
            # underlying wave/EOF error visible in tracebacks.
            raise ValueError(f"读取WAV文件失败: {e}") from e
        return nchannels, sampwidth, framerate, nframes, wave_bytes

    @staticmethod
    def is_wav_format(audio_data: bytes) -> bool:
        """
        Check whether the data starts with a RIFF/WAVE header.

        Args:
            audio_data: Audio bytes to inspect.

        Returns:
            bool: True when the data is at least 44 bytes long and carries
            the ``RIFF`` tag at offset 0 and the ``WAVE`` tag at offset 8.
        """
        if len(audio_data) < 44:
            return False
        return audio_data[0:4] == b"RIFF" and audio_data[8:12] == b"WAVE"

    @staticmethod
    def _starts_with_mp3_frame(audio_data: bytes) -> bool:
        """Return True when the first two bytes form an MPEG Layer III frame sync."""
        return (
            len(audio_data) >= 2
            and audio_data[0] == 0xFF
            and audio_data[1] in AudioProcessor._MP3_FRAME_SECOND_BYTES
        )

    @staticmethod
    def detect_audio_format(audio_data: bytes) -> str:
        """
        Guess the audio container/codec from the leading bytes.

        Args:
            audio_data: Audio bytes to inspect.

        Returns:
            str: ``'unknown'`` for inputs shorter than 4 bytes, ``'wav'`` for
            a RIFF/WAVE header, ``'mp3'`` for an ID3 tag or an MPEG Layer III
            frame header, and ``'pcm'`` for everything else.
        """
        if len(audio_data) < 4:
            return 'unknown'

        if AudioProcessor.is_wav_format(audio_data):
            return 'wav'

        # MP3 either begins with an ID3v2 tag or directly with a frame header.
        # Low-sample-rate speech MP3 (MPEG-2/2.5) does not start with FF FB,
        # so every Layer III header variant is accepted here.
        if audio_data[0:3] == b"ID3" or AudioProcessor._starts_with_mp3_frame(audio_data):
            return 'mp3'

        # Anything unrecognised is treated as headerless PCM.
        return 'pcm'

    @staticmethod
    def slice_audio_data(
        audio_data: bytes,
        chunk_size: int
    ) -> Generator[Tuple[bytes, bool], None, None]:
        """
        Split audio bytes into consecutive chunks.

        Args:
            audio_data: Audio bytes to split.
            chunk_size: Maximum size of each chunk in bytes; must be positive.

        Yields:
            Tuple[bytes, bool]: (chunk, is_last). Only the final chunk has
            ``is_last`` set to True. Empty input yields nothing.

        Raises:
            ValueError: If ``chunk_size`` is not positive. Because this is a
                generator, the error surfaces on the first iteration.
        """
        # A zero or negative step would never advance through the data and
        # the generator would spin forever.
        if chunk_size <= 0:
            raise ValueError(f"分片大小必须为正数: {chunk_size}")

        total = len(audio_data)
        for start in range(0, total, chunk_size):
            end = start + chunk_size
            yield audio_data[start:end], end >= total

    @staticmethod
    def calculate_segment_size(
        audio_format: str,
        sample_rate: int = 16000,
        channels: int = 1,
        bits: int = 16,
        segment_duration_ms: int = 200,
        mp3_seg_size: int = 1000
    ) -> int:
        """
        Compute the chunk size in bytes for a given audio format.

        Args:
            audio_format: One of ``'mp3'``, ``'wav'`` or ``'pcm'``.
            sample_rate: Sample rate used for the wav/pcm calculation.
            channels: Channel count used for the wav/pcm calculation.
            bits: Bit depth used for the wav/pcm calculation.
            segment_duration_ms: Chunk duration in milliseconds (wav/pcm).
            mp3_seg_size: Fixed chunk size returned as-is for ``'mp3'``.

        Returns:
            int: Chunk size in bytes.

        Raises:
            ValueError: If ``audio_format`` is not one of the supported values.
        """
        if audio_format == 'mp3':
            return mp3_seg_size

        if audio_format in ('wav', 'pcm'):
            # Uncompressed audio: size follows directly from the byte rate.
            bytes_per_second = sample_rate * channels * (bits // 8)
            return int(bytes_per_second * segment_duration_ms / 1000)

        raise ValueError(f"不支持的音频格式: {audio_format}")

    @staticmethod
    def extract_wav_metadata(audio_data: bytes) -> Dict[str, Any]:
        """
        Extract metadata from WAV bytes without raising on bad input.

        Args:
            audio_data: Complete WAV file content.

        Returns:
            Dict: On success, the keys ``format``, ``channels``,
            ``sample_width``, ``sample_rate``, ``frames``, ``duration``
            (frames divided by sample rate) and ``size`` (length of the input
            in bytes). On failure, only ``format``, ``error`` (the exception
            message) and ``size``.
        """
        try:
            nchannels, sampwidth, framerate, nframes, _ = AudioProcessor.read_wav_info(audio_data)
            duration = nframes / framerate
        except (ValueError, ZeroDivisionError) as e:
            # ValueError: unparsable data; ZeroDivisionError: header declares
            # a sample rate of 0. Both are reported instead of raised.
            return {
                'format': 'wav',
                'error': str(e),
                'size': len(audio_data)
            }

        return {
            'format': 'wav',
            'channels': nchannels,
            'sample_width': sampwidth,
            'sample_rate': framerate,
            'frames': nframes,
            'duration': duration,
            'size': len(audio_data)
        }

    @staticmethod
    def validate_audio_params(
        audio_format: str,
        sample_rate: int,
        channels: int,
        bits: int
    ) -> bool:
        """
        Validate a set of audio parameters.

        Args:
            audio_format: Must be ``'wav'``, ``'mp3'`` or ``'pcm'``.
            sample_rate: Must lie within 8000..48000 inclusive.
            channels: Must be 1 or 2.
            bits: Must be 8, 16, 24 or 32.

        Returns:
            bool: True only when every parameter passes its check.
        """
        return (
            audio_format in ('wav', 'mp3', 'pcm')
            and 8000 <= sample_rate <= 48000
            and 1 <= channels <= 2
            and bits in (8, 16, 24, 32)
        )

    @staticmethod
    def prepare_audio_for_recognition(
        audio_data: bytes,
        target_format: str = 'wav',
        segment_duration_ms: int = 200
    ) -> Tuple[str, int, Dict[str, Any]]:
        """
        Detect the format of the audio and derive its chunk size and metadata.

        Args:
            audio_data: Raw audio bytes.
            target_format: Accepted for interface compatibility; this method
                does not read it.
            segment_duration_ms: Chunk duration in milliseconds.

        Returns:
            Tuple: (detected format, chunk size in bytes, metadata dict).

        Raises:
            ValueError: If the detected format is ``'unknown'`` (input shorter
                than 4 bytes), propagated from ``calculate_segment_size``.
        """
        detected_format = AudioProcessor.detect_audio_format(audio_data)

        if detected_format == 'wav':
            metadata = AudioProcessor.extract_wav_metadata(audio_data)
            # When metadata extraction failed the dict lacks these keys and
            # the 16 kHz / mono / 16-bit defaults are used instead.
            segment_size = AudioProcessor.calculate_segment_size(
                detected_format,
                metadata.get('sample_rate', 16000),
                metadata.get('channels', 1),
                metadata.get('sample_width', 2) * 8,
                segment_duration_ms
            )
        else:
            # Non-WAV input carries no header we parse; rely on defaults.
            metadata = {
                'format': detected_format,
                'size': len(audio_data)
            }
            segment_size = AudioProcessor.calculate_segment_size(
                detected_format,
                segment_duration_ms=segment_duration_ms
            )

        return detected_format, segment_size, metadata