# protocol.py (8.61 KB)
# AIfeng/2025-07-11 13:36:00
"""
豆包语音识别WebSocket协议处理模块
实现二进制协议的编解码、消息类型定义和数据包处理
"""

import gzip
import json
import zlib
from typing import Dict, Any, Tuple, Optional


# Protocol version and header size (the header size is counted in 4-byte words)
PROTOCOL_VERSION = 0x1
DEFAULT_HEADER_SIZE = 0x1

# Message type codes
class MessageType:
    """4-bit message-type codes, stored in the high nibble of header byte 1."""

    # Sent by the client
    FULL_CLIENT_REQUEST = 0x1
    AUDIO_ONLY_REQUEST = 0x2
    # Sent by the server
    FULL_SERVER_RESPONSE = 0x9
    SERVER_ACK = 0xB
    SERVER_ERROR_RESPONSE = 0xF

# Message-type-specific flags
class MessageFlags:
    """Flag nibble (low nibble of header byte 1).

    Bit 0x01 means a 4-byte sequence number follows the header; bit 0x02
    marks the last packet of the stream.
    """

    NO_SEQUENCE = 0x0
    POS_SEQUENCE = 0x1
    NEG_SEQUENCE = 0x2
    NEG_WITH_SEQUENCE = 0x3

# Serialization methods
class SerializationMethod:
    """Payload serialization codes (high nibble of header byte 2)."""

    NO_SERIALIZATION = 0x0
    JSON = 0x1

# Compression methods
class CompressionType:
    """Payload compression codes (low nibble of header byte 2)."""

    NO_COMPRESSION = 0x0
    GZIP = 0x1


class DoubaoProtocol:
    """Encoder/decoder for the Doubao ASR WebSocket binary protocol.

    Every packet begins with a header of ``header_size * 4`` bytes. The
    first four header bytes pack these fields:

    - byte 0: protocol version (high nibble) | header size in 4-byte words (low nibble)
    - byte 1: message type (high nibble) | message-type-specific flags (low nibble)
    - byte 2: serialization method (high nibble) | compression type (low nibble)
    - byte 3: reserved
    """

    @staticmethod
    def _check_nibble(name: str, value: int) -> int:
        """Return ``value`` unchanged if it fits in 4 bits, else raise ValueError.

        Without this check, a low-nibble value above 15 would silently
        corrupt the high nibble of the same header byte.
        """
        if not 0 <= value <= 0x0F:
            raise ValueError(f"{name} 超出4位取值范围(0-15): {value}")
        return value

    @staticmethod
    def generate_header(
        message_type: int = MessageType.FULL_CLIENT_REQUEST,
        message_type_specific_flags: int = MessageFlags.NO_SEQUENCE,
        serial_method: int = SerializationMethod.JSON,
        compression_type: int = CompressionType.GZIP,
        reserved_data: int = 0x00
    ) -> bytearray:
        """
        Build the 4-byte protocol header.

        Args:
            message_type: message type code (see MessageType), 0-15.
            message_type_specific_flags: flag bits (see MessageFlags), 0-15.
            serial_method: serialization method (see SerializationMethod), 0-15.
            compression_type: compression type (see CompressionType), 0-15.
            reserved_data: value of the reserved byte, 0-255.

        Returns:
            bytearray: the 4-byte header.

        Raises:
            ValueError: if any field is outside its allowed range.
        """
        check = DoubaoProtocol._check_nibble
        return bytearray((
            (PROTOCOL_VERSION << 4) | DEFAULT_HEADER_SIZE,
            (check("message_type", message_type) << 4)
            | check("message_type_specific_flags", message_type_specific_flags),
            (check("serial_method", serial_method) << 4)
            | check("compression_type", compression_type),
            # bytearray() itself raises ValueError for values outside 0-255.
            reserved_data,
        ))

    @staticmethod
    def generate_sequence_payload(sequence: int) -> bytearray:
        """
        Encode a sequence number as a 4-byte signed big-endian integer.

        Args:
            sequence: sequence number (negative values are allowed and used
                to mark the last audio packet).

        Returns:
            bytearray: the 4-byte encoded sequence number.

        Raises:
            OverflowError: if ``sequence`` does not fit in a signed 32-bit int.
        """
        return bytearray(sequence.to_bytes(4, 'big', signed=True))

    @staticmethod
    def parse_response(response_data: bytes) -> Dict[str, Any]:
        """
        Parse a binary packet received from the server.

        Args:
            response_data: raw packet bytes (header, optional sequence, body).

        Returns:
            Dict containing the decoded header fields plus:
              - 'is_last_package': True when flag bit 0x02 is set.
              - 'payload_sequence': present when flag bit 0x01 is set and
                at least 4 bytes follow the header.
              - 'seq' (SERVER_ACK) or 'code' (SERVER_ERROR_RESPONSE).
              - 'payload_size': the size field declared in the packet (0 if
                absent).
              - 'payload_msg': the decoded body (JSON value, str, or bytes),
                or None if there is no body, the body is empty, or decoding
                failed.
              - 'decompress_error' / 'json_parse_error' / 'decode_error':
                error text, set when the corresponding decoding step failed.

        Raises:
            ValueError: if the data is shorter than the header, or the
                header-size field is 0.
        """
        if len(response_data) < 4:
            raise ValueError("响应数据长度不足")

        header_size = response_data[0] & 0x0f
        if header_size == 0:
            # A zero size would make the header bytes look like payload.
            raise ValueError(f"响应头部长度无效: {header_size}")
        header_len = header_size * 4
        if len(response_data) < header_len:
            raise ValueError("响应数据长度不足")

        message_type = response_data[1] >> 4
        flags = response_data[1] & 0x0f
        serialization_method = response_data[2] >> 4
        message_compression = response_data[2] & 0x0f
        # Byte 3 is reserved, and bytes [4, header_len) are header
        # extensions; neither is interpreted here.
        payload = response_data[header_len:]

        result: Dict[str, Any] = {
            'protocol_version': response_data[0] >> 4,
            'header_size': header_size,
            'message_type': message_type,
            'message_type_specific_flags': flags,
            'serialization_method': serialization_method,
            'message_compression': message_compression,
            'is_last_package': bool(flags & 0x02),
            'payload_msg': None,
            'payload_size': 0
        }

        # Flag bit 0x01: a signed 4-byte sequence number precedes the body.
        if flags & 0x01 and len(payload) >= 4:
            result['payload_sequence'] = int.from_bytes(payload[:4], "big", signed=True)
            payload = payload[4:]

        # Extract the message-type-specific prefix fields and the raw body.
        payload_msg = None
        payload_size = 0
        if message_type == MessageType.FULL_SERVER_RESPONSE:
            if len(payload) >= 4:
                # Size fields are read unsigned, as in the other branches.
                payload_size = int.from_bytes(payload[:4], "big", signed=False)
                payload_msg = payload[4:]
        elif message_type == MessageType.SERVER_ACK:
            if len(payload) >= 4:
                result['seq'] = int.from_bytes(payload[:4], "big", signed=True)
                if len(payload) >= 8:
                    payload_size = int.from_bytes(payload[4:8], "big", signed=False)
                    payload_msg = payload[8:]
        elif message_type == MessageType.SERVER_ERROR_RESPONSE:
            if len(payload) >= 8:
                result['code'] = int.from_bytes(payload[:4], "big", signed=False)
                payload_size = int.from_bytes(payload[4:8], "big", signed=False)
                payload_msg = payload[8:]

        # Record the declared size even if decoding the body fails below.
        result['payload_size'] = payload_size
        if payload_msg is None:
            return result

        if message_compression == CompressionType.GZIP:
            try:
                payload_msg = gzip.decompress(payload_msg)
            except (OSError, EOFError, zlib.error) as e:
                result['decompress_error'] = str(e)
                return result

        if not payload_msg:
            # Empty body: leave payload_msg as None rather than reporting a
            # spurious JSON parse error.
            return result

        if serialization_method == SerializationMethod.JSON:
            try:
                payload_msg = json.loads(payload_msg.decode('utf-8'))
            except (ValueError, RecursionError) as e:
                # ValueError covers both JSONDecodeError and UnicodeDecodeError.
                result['json_parse_error'] = str(e)
                return result
        elif serialization_method != SerializationMethod.NO_SERIALIZATION:
            try:
                payload_msg = payload_msg.decode('utf-8')
            except UnicodeDecodeError as e:
                result['decode_error'] = str(e)
                return result

        result['payload_msg'] = payload_msg
        return result

    @staticmethod
    def _assemble_packet(header: bytearray, sequence: int, payload_bytes: bytes) -> bytearray:
        """Concatenate header, signed sequence, 4-byte body length and body."""
        packet = bytearray(header)
        packet += DoubaoProtocol.generate_sequence_payload(sequence)
        packet += len(payload_bytes).to_bytes(4, 'big')
        packet += payload_bytes
        return packet

    @staticmethod
    def build_full_request(
        request_params: Dict[str, Any],
        sequence: int = 1,
        compression: bool = True
    ) -> bytearray:
        """
        Build a full client request (JSON parameters) packet.

        Args:
            request_params: request parameters, serialized as JSON.
            sequence: sequence number written after the header.
            compression: gzip-compress the JSON body when True.

        Returns:
            bytearray: the complete request packet.
        """
        payload_bytes = json.dumps(request_params).encode('utf-8')

        if compression:
            compression_type = CompressionType.GZIP
            payload_bytes = gzip.compress(payload_bytes)
        else:
            compression_type = CompressionType.NO_COMPRESSION

        header = DoubaoProtocol.generate_header(
            message_type=MessageType.FULL_CLIENT_REQUEST,
            message_type_specific_flags=MessageFlags.POS_SEQUENCE,
            compression_type=compression_type
        )
        return DoubaoProtocol._assemble_packet(header, sequence, payload_bytes)

    @staticmethod
    def build_audio_request(
        audio_data: bytes,
        sequence: int,
        is_last: bool = False,
        compression: bool = True
    ) -> bytearray:
        """
        Build an audio-only request packet.

        Args:
            audio_data: raw audio bytes.
            sequence: sequence number for this packet.
            is_last: mark this as the final packet. The sequence is then
                written as ``-abs(sequence)`` and the flags become
                NEG_WITH_SEQUENCE.
            compression: gzip-compress the audio bytes when True.

        Returns:
            bytearray: the complete audio request packet.
        """
        if compression:
            compression_type = CompressionType.GZIP
            payload_bytes = gzip.compress(audio_data)
        else:
            compression_type = CompressionType.NO_COMPRESSION
            payload_bytes = audio_data

        if is_last:
            flags = MessageFlags.NEG_WITH_SEQUENCE
            sequence = -abs(sequence)
        else:
            flags = MessageFlags.POS_SEQUENCE

        header = DoubaoProtocol.generate_header(
            message_type=MessageType.AUDIO_ONLY_REQUEST,
            message_type_specific_flags=flags,
            compression_type=compression_type
        )
        return DoubaoProtocol._assemble_packet(header, sequence, payload_bytes)