# ASR_server.py
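"""WebSocket ASR server built on FunASR.

Clients send JSON messages over a WebSocket connection: either a file
URL/path, inline base64-encoded audio data, or a chunked upload
(audio_start / audio_chunk / audio_end) for large files. Recognition
results are returned as plain text on the same connection.
"""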
import asyncio
import websockets
import argparse
import json
import logging
from funasr import AutoModel
import os

# Configure the log level
logger = logging.getLogger(__name__)
logger.setLevel(logging.CRITICAL)

# Parse command-line arguments
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="0.0.0.0", help="host ip, localhost, 0.0.0.0")
parser.add_argument("--port", type=int, default=10197, help="grpc server port")
parser.add_argument("--ngpu", type=int, default=1, help="0 for cpu, 1 for gpu")
parser.add_argument("--gpu_id", type=int, default=0, help="specify which gpu device to use")
args = parser.parse_args()

# Initialize the ASR model
print("model loading")
try:
    asr_model = AutoModel(model="paraformer-zh", model_revision="v2.0.4",
                          vad_model="fsmn-vad", vad_model_revision="v2.0.4",
                          punc_model="ct-punc-c", punc_model_revision="v2.0.4",
                          device=f"cuda:{args.gpu_id}" if args.ngpu else "cpu", disable_update=True)
    print("model loaded")
except Exception as e:
    print(f"模型加载失败: {e}")
    import traceback
    traceback.print_exc()
    exit(1)
websocket_users = {}
task_queue = asyncio.Queue()
# Chunked-upload session management
chunk_sessions = {}  # {user_id: {filename, total_chunks, total_size, received_chunks, temp_file, temp_path, chunks_data}}

async def ws_serve(websocket, path):
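    """Accept a WebSocket connection and dispatch incoming JSON messages.

    Chunked-protocol messages are handled directly; URL and inline audio
    requests are queued for the background worker.
    """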
    global websocket_users, chunk_sessions
    user_id = id(websocket)
    websocket_users[user_id] = websocket
    try:
        async for message in websocket:
            if isinstance(message, str):
                data = json.loads(message)
                
                # Chunked-upload protocol messages
                if 'type' in data:
                    await handle_chunked_protocol(websocket, data, user_id)
                # Legacy protocol messages
                elif 'url' in data:
                    # Process a file URL / local path
                    await task_queue.put((websocket, data['url'], 'url'))
                elif 'audio_data' in data:
                    # Process inline base64 audio data
                    await task_queue.put((websocket, data, 'audio_data'))
    except websockets.exceptions.ConnectionClosed as e:
        logger.info(f"Connection closed: {e.reason}")
    except Exception as e:
        logger.error(f"Unexpected error: {e}")
    finally:
        logger.info(f"Cleaning up connection for user {user_id}")
        if user_id in websocket_users:
            del websocket_users[user_id]
        # Clean up any chunked-upload session
        if user_id in chunk_sessions:
            await cleanup_chunk_session(user_id)
        await websocket.close()
        logger.info("WebSocket closed")

async def worker():
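    """Background consumer: pull queued requests and run recognition on them."""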
    while True:
        task_data = await task_queue.get()
        websocket = task_data[0]
        
        if websocket.open:
            if len(task_data) == 3:  # new format: (websocket, data, type)
                data, data_type = task_data[1], task_data[2]
                if data_type == 'url':
                    await process_wav_file(websocket, data)
                elif data_type == 'audio_data':
                    await process_audio_data(websocket, data)
                elif data_type == 'chunked_audio':
                    await process_chunked_audio(websocket, data)
            else:  # backward-compatible old format: (websocket, url)
                await process_wav_file(websocket, task_data[1])
        else:
            logger.info("WebSocket connection is already closed when trying to process file")
        task_queue.task_done()

async def process_wav_file(websocket, url):
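    """Run ASR on an audio file referenced by a URL or local path."""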
    # Hotword configuration
    param_dict = {"sentence_timestamp": False}
    try:
        with open("data/hotword.txt", "r", encoding="utf-8") as f:
            lines = f.readlines()
            lines = [line.strip() for line in lines]
        hotword = " ".join(lines)
        print(f"Hotwords: {hotword}")
        param_dict["hotword"] = hotword
    except FileNotFoundError:
        print("Hotword file not found; skipping hotword configuration")
    wav_path = url
    try:
        res = asr_model.generate(input=wav_path, is_final=True, **param_dict)
        if res:
            if 'text' in res[0] and websocket.open:
                await websocket.send(res[0]['text'])
    except Exception as e:
        print(f"Error during model.generate: {e}")
    finally:
        # File deletion is disabled; keep the cached audio file for testing
        # if os.path.exists(wav_path):
        #     os.remove(wav_path)
        print(f"Keeping audio file for testing: {wav_path}")

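# Client-side message shapes for the chunked-upload protocol, as the handler
# below expects them (the example values are illustrative, not from the source):
#   {"type": "audio_start", "filename": "a.wav", "total_chunks": 3, "total_size": 48000}
#   {"type": "audio_chunk", "filename": "a.wav", "chunk_index": 0,
#    "chunk_data": "<base64-encoded bytes>", "is_last": false}
#   {"type": "audio_end", "filename": "a.wav"}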
async def handle_chunked_protocol(websocket, data, user_id):
    """处理分块协议消息"""
    global chunk_sessions
    
    try:
        msg_type = data.get('type')
        filename = data.get('filename', 'unknown.wav')
        
        if msg_type == 'audio_start':
            # Start a new chunked-upload session
            total_chunks = data.get('total_chunks', 0)
            total_size = data.get('total_size', 0)
            
            print(f"开始接收分块音频: {filename}, 总分块数: {total_chunks}, 总大小: {total_size} bytes")
            
            # Create a temporary file
            import tempfile
            temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.wav')
            
            chunk_sessions[user_id] = {
                'filename': filename,
                'total_chunks': total_chunks,
                'total_size': total_size,
                'received_chunks': 0,
                'temp_file': temp_file,
                'temp_path': temp_file.name,
                'chunks_data': {}  # {chunk_index: chunk_data}
            }
            
            await websocket.send(json.dumps({"status": "ready", "message": f"准备接收 {total_chunks} 个分块"}))
            
        elif msg_type == 'audio_chunk':
            # Receive an audio chunk
            if user_id not in chunk_sessions:
                await websocket.send(json.dumps({"error": "No chunk session found; send audio_start first"}))
                return
            
            session = chunk_sessions[user_id]
            chunk_index = data.get('chunk_index', -1)
            chunk_data = data.get('chunk_data', '')
            is_last = data.get('is_last', False)
            
            if chunk_index >= 0 and chunk_data:
                # Decode and store the chunk data
                import base64
                chunk_bytes = base64.b64decode(chunk_data)
                session['chunks_data'][chunk_index] = chunk_bytes
                session['received_chunks'] += 1
                
                # Progress feedback
                progress = (session['received_chunks'] / session['total_chunks']) * 100
                if session['received_chunks'] % 10 == 0 or is_last:
                    print(f"接收进度: {progress:.1f}% ({session['received_chunks']}/{session['total_chunks']})")
            
        elif msg_type == 'audio_end':
            # All chunks received; reassemble the audio
            if user_id not in chunk_sessions:
                await websocket.send(json.dumps({"error": "No chunk session found"}))
                return
            
            session = chunk_sessions[user_id]
            
            # Check that all chunks were received
            if session['received_chunks'] != session['total_chunks']:
                await websocket.send(json.dumps({
                    "error": f"分块不完整: 期望{session['total_chunks']}, 实际{session['received_chunks']}"
                }))
                await cleanup_chunk_session(user_id)
                return
            
            # Reassemble the audio data in chunk order
            print(f"Reassembling audio file: {session['filename']}")
            with open(session['temp_path'], 'wb') as f:
                for i in range(session['total_chunks']):
                    if i in session['chunks_data']:
                        f.write(session['chunks_data'][i])
                    else:
                        print(f"警告: 分块 {i} 缺失")
            
            # Submit to the processing queue
            reconstructed_data = {
                'audio_file_path': session['temp_path'],
                'filename': session['filename']
            }
            await task_queue.put((websocket, reconstructed_data, 'chunked_audio'))
            
            # Remove the session (keep the temp file for the processing function)
            del chunk_sessions[user_id]
            print(f"Chunked audio reassembly complete: {session['filename']}")
            
    except Exception as e:
        print(f"处理分块协议时出错: {e}")
        await websocket.send(json.dumps({"error": f"分块处理错误: {str(e)}"}))
        if user_id in chunk_sessions:
            await cleanup_chunk_session(user_id)

async def cleanup_chunk_session(user_id):
    """清理分块会话"""
    global chunk_sessions
    
    if user_id in chunk_sessions:
        session = chunk_sessions[user_id]
        try:
            # Close and delete the temporary file
            if 'temp_file' in session:
                session['temp_file'].close()
            if 'temp_path' in session and os.path.exists(session['temp_path']):
                os.remove(session['temp_path'])
                print(f"清理临时文件: {session['temp_path']}")
        except Exception as e:
            print(f"清理分块会话时出错: {e}")
        finally:
            del chunk_sessions[user_id]

async def process_chunked_audio(websocket, data):
    """处理分块重组后的音频文件"""
    try:
        audio_file_path = data.get('audio_file_path')
        filename = data.get('filename', 'chunked_audio.wav')
        
        if not audio_file_path or not os.path.exists(audio_file_path):
            await websocket.send(json.dumps({"error": "重组音频文件不存在"}))
            return
        
        print(f"处理分块重组音频: {filename}, 文件路径: {audio_file_path}")
        
        # Hotword configuration
        param_dict = {"sentence_timestamp": False}
        try:
            with open("data/hotword.txt", "r", encoding="utf-8") as f:
                lines = f.readlines()
                lines = [line.strip() for line in lines]
            hotword = " ".join(lines)
            print(f"热词:{hotword}")
            param_dict["hotword"] = hotword
        except FileNotFoundError:
            print("热词文件不存在,跳过热词配置")
        
        # Run speech recognition
        res = asr_model.generate(input=audio_file_path, is_final=True, **param_dict)
        if res and websocket.open:
            if 'text' in res[0]:
                result_text = res[0]['text']
                print(f"分块音频识别结果: {result_text}")
                await websocket.send(result_text)
            else:
                await websocket.send("识别失败:无法获取文本结果")
        
    except Exception as e:
        print(f"处理分块音频时出错: {e}")
        if websocket.open:
            await websocket.send(f"分块音频识别错误: {str(e)}")
    finally:
        # Temporary-file deletion is disabled; keep the file for testing
        # if 'audio_file_path' in locals() and os.path.exists(audio_file_path):
        #     os.remove(audio_file_path)
        if 'audio_file_path' in locals():
            print(f"Keeping reassembled audio file for testing: {audio_file_path}")

async def process_audio_data(websocket, data):
    """处理音频数据"""
    import base64
    import tempfile
    
    try:
        # Get the audio data
        audio_data = data.get('audio_data')
        filename = data.get('filename', 'audio.wav')
        
        if not audio_data:
            await websocket.send(json.dumps({"error": "No audio data provided"}))
            return
        
        # Decode the base64 audio data
        audio_bytes = base64.b64decode(audio_data)
        
        # Create a temporary file
        with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as temp_file:
            temp_file.write(audio_bytes)
            temp_path = temp_file.name
        
        print(f"处理音频文件: {filename}, 临时路径: {temp_path}")
        
        # Hotword configuration
        param_dict = {"sentence_timestamp": False}
        try:
            with open("data/hotword.txt", "r", encoding="utf-8") as f:
                lines = f.readlines()
                lines = [line.strip() for line in lines]
            hotword = " ".join(lines)
            print(f"热词:{hotword}")
            param_dict["hotword"] = hotword
        except FileNotFoundError:
            print("热词文件不存在,跳过热词配置")
        
        # Run speech recognition
        res = asr_model.generate(input=temp_path, is_final=True, **param_dict)
        if res and websocket.open:
            if 'text' in res[0]:
                result_text = res[0]['text']
                print(f"识别结果: {result_text}")
                await websocket.send(result_text)
            else:
                await websocket.send("识别失败:无法获取文本结果")
        
    except Exception as e:
        print(f"处理音频数据时出错: {e}")
        if websocket.open:
            await websocket.send(f"识别错误: {str(e)}")
    finally:
        # Temporary-file deletion is disabled; keep the file for testing
        # if 'temp_path' in locals() and os.path.exists(temp_path):
        #     os.remove(temp_path)
        if 'temp_path' in locals():
            print(f"Keeping temporary audio file for testing: {temp_path}")

async def main():
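    """Start the WebSocket server and the background worker, then run until interrupted."""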
    server = await websockets.serve(ws_serve, args.host, args.port, ping_interval=10)
    worker_task = asyncio.create_task(worker())
    
    try:
        # Keep the server running until manually interrupted
        print(f"ASR server started, listening on {args.host}:{args.port}")
        print("Note: automatic file deletion is disabled in this build for test analysis")
        await asyncio.Future()  # wait forever until the program is interrupted
    except asyncio.CancelledError:
        print("服务器正在关闭...")
    finally:
        # Clean up resources
        worker_task.cancel()
        try:
            await worker_task
        except asyncio.CancelledError:
            pass
        server.close()
        await server.wait_closed()

# Run the main coroutine with asyncio
if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        print("Server stopped")
    except Exception as e:
        print(f"Server failed to start: {e}")
        import traceback
        traceback.print_exc()
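
# Example invocation (the flags and defaults come from the argparse setup above):
#   python ASR_server.py --host 0.0.0.0 --port 10197 --ngpu 1 --gpu_id 0
#
# A minimal client sketch (assumes the `websockets` package; the path is a placeholder):
#   import asyncio, json, websockets
#   async def demo():
#       async with websockets.connect("ws://127.0.0.1:10197") as ws:
#           await ws.send(json.dumps({"url": "/path/to/audio.wav"}))
#           print(await ws.recv())
#   asyncio.run(demo())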