冯杨

Synced with the official GitHub repository, up to commits on Apr 18, 2025

a9c36c76e569107b5a39b3de8afd6e016b24d662
@@ -15,4 +15,8 @@ pretrained @@ -15,4 +15,8 @@ pretrained
15 *.mp4 15 *.mp4
16 .DS_Store 16 .DS_Store
17 workspace/log_ngp.txt 17 workspace/log_ngp.txt
18 -.idea  
  18 +.idea
  19 +
  20 +models/
  21 +*.log
  22 +dist
  1 +Real-time interactive streaming digital human with synchronized audio and video dialogue. It can essentially achieve commercial-grade results.
  2 +
  3 +[Effect of wav2lip](https://www.bilibili.com/video/BV1scwBeyELA/) | [Effect of ernerf](https://www.bilibili.com/video/BV1G1421z73r/) | [Effect of musetalk](https://www.bilibili.com/video/BV1gm421N7vQ/)
  4 +
  5 +## News
  6 +- December 8, 2024: Improved multi-concurrency; GPU memory no longer grows with the number of concurrent connections.
  7 +- December 21, 2024: Added model warm-up for wav2lip and musetalk to solve the problem of stuttering during the first inference. Thanks to [@heimaojinzhangyz](https://github.com/heimaojinzhangyz)
  8 +- December 28, 2024: Added the digital human model Ultralight-Digital-Human. Thanks to [@lijihua2017](https://github.com/lijihua2017)
  9 +- February 7, 2025: Added fish-speech TTS
  10 +- February 21, 2025: Added the open-source model wav2lip256. Thanks to @不蠢不蠢
  11 +- March 2, 2025: Added Tencent's speech synthesis service
  12 +- March 16, 2025: Supports Mac GPU inference. Thanks to [@GcsSloop](https://github.com/GcsSloop)
  13 +
  14 +## Features
  15 +1. Supports multiple digital human models: ernerf, musetalk, wav2lip, Ultralight-Digital-Human
  16 +2. Supports voice cloning
  17 +3. Supports interrupting the digital human while it is speaking
  18 +4. Supports full-body video stitching
  19 +5. Supports RTMP and WebRTC
  20 +6. Supports video arrangement: plays custom videos when not speaking
  21 +7. Supports multi-concurrency
  22 +
  23 +## 1. Installation
  24 +
  25 +Tested on Ubuntu 20.04, Python 3.10, PyTorch 1.12, and CUDA 11.3
  26 +
  27 +### 1.1 Install dependencies
  28 +
  29 +```bash
  30 +conda create -n nerfstream python=3.10
  31 +conda activate nerfstream
  32 +# If the cuda version is not 11.3 (confirm the version by running nvidia-smi), install the corresponding version of pytorch according to <https://pytorch.org/get-started/previous-versions/>
  33 +conda install pytorch==1.12.1 torchvision==0.13.1 cudatoolkit=11.3 -c pytorch
  34 +pip install -r requirements.txt
  35 +# If you need to train the ernerf model, install the following libraries
  36 +# pip install "git+https://github.com/facebookresearch/pytorch3d.git"
  37 +# pip install tensorflow-gpu==2.8.0
  38 +# pip install --upgrade "protobuf<=3.20.1"
  39 +```
  40 +For common installation issues, see the [FAQ](https://livetalking-doc.readthedocs.io/en/latest/faq.html).
  41 +For setting up the Linux CUDA environment, you can refer to this article: https://zhuanlan.zhihu.com/p/674972886
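To confirm the environment roughly matches the tested setup above, a minimal check (it assumes the nerfstream conda env created in 1.1 is active):

```bash
# On CUDA machines: confirm the driver and CUDA version that the GPU reports
nvidia-smi
# Confirm PyTorch is installed and whether it sees a CUDA device
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
# On Apple Silicon (Mac GPU inference), check the MPS backend instead:
# python -c "import torch; print(torch.backends.mps.is_available())"
```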
  42 +
  43 +
  44 +## 2. Quick Start
  45 +- Download the models
  46 +Quark Cloud Disk <https://pan.quark.cn/s/83a750323ef0>
  47 +Google Drive <https://drive.google.com/drive/folders/1FOC_MD6wdogyyX_7V1d4NDIO7P9NlSAJ?usp=sharing>
  48 +Copy wav2lip256.pth to the models folder of this project and rename it to wav2lip.pth.
  49 +Extract wav2lip256_avatar1.tar.gz and copy the entire folder into the data/avatars folder of this project (see the shell sketch after this list).
  50 +- Run
  51 +`python app.py --transport webrtc --model wav2lip --avatar_id wav2lip256_avatar1`
  52 +Open http://serverip:8010/webrtcapi.html in a browser. First click 'start' to play the digital human video; then enter any text in the text box and submit it. The digital human will speak that text.
  53 +<font color=red>The server side needs to open ports tcp:8010; udp:1-65536</font>
  54 +If you need to purchase a high-definition wav2lip model for commercial use, see this [link](https://livetalking-doc.readthedocs.io/zh-cn/latest/service.html#wav2lip).
  55 +
  56 +- Quick experience
  57 +<https://www.compshare.cn/images-detail?ImageID=compshareImage-18tpjhhxoq3j&referral_code=3XW3852OBmnD089hMMrtuU&ytag=GPU_GitHub_livetalking1.3> Create an instance with this image to run it.
  58 +
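A minimal shell sketch of the model placement and launch steps above (it assumes wav2lip256.pth and wav2lip256_avatar1.tar.gz were downloaded into the project root; adjust paths to your own setup):

```bash
# Place the wav2lip model under models/, renamed as required
mkdir -p models data/avatars
cp wav2lip256.pth models/wav2lip.pth
# Unpack the avatar so that data/avatars/wav2lip256_avatar1/ exists
tar -xzf wav2lip256_avatar1.tar.gz -C data/avatars
# Start the server with the webrtc transport and the wav2lip model
python app.py --transport webrtc --model wav2lip --avatar_id wav2lip256_avatar1
```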
  59 +If you cannot access Hugging Face, set the following before running:
  60 +```
  61 +export HF_ENDPOINT=https://hf-mirror.com
  62 +```
  63 +
  64 +
  65 +## 3. More Usage
  66 +Usage instructions: <https://livetalking-doc.readthedocs.io/en/latest>
  67 +
  68 +## 4. Docker Run
  69 +The previous installation steps are not required; just run it directly.
  70 +```
  71 +docker run --gpus all -it --network=host --rm registry.cn-beijing.aliyuncs.com/codewithgpu2/lipku-metahuman-stream:2K9qaMBu8v
  72 +```
  73 +The code is in /root/metahuman-stream. First run git pull to fetch the latest code, then execute the commands from steps 2 and 3, as sketched below.
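A short sketch of those in-container steps (the run command is the same one used in step 2; substitute your own model and avatar arguments as needed):

```bash
# Inside the container started by the docker run command above
cd /root/metahuman-stream
git pull   # update to the latest code
python app.py --transport webrtc --model wav2lip --avatar_id wav2lip256_avatar1
```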
  74 +
  75 +The following images are provided:
  76 +- autodl image: <https://www.codewithgpu.com/i/lipku/metahuman-stream/base>
  77 +[autodl Tutorial](https://livetalking-doc.readthedocs.io/en/latest/autodl/README.html)
  78 +- ucloud image: <https://www.compshare.cn/images-detail?ImageID=compshareImage-18tpjhhxoq3j&referral_code=3XW3852OBmnD089hMMrtuU&ytag=GPU_livetalking1.3>
  79 +Any port can be opened, and there is no need to deploy an additional SRS service.
  80 +[ucloud Tutorial](https://livetalking-doc.readthedocs.io/en/latest/ucloud/ucloud.html)
  81 +
  82 +
  83 +## 5. TODO
  84 +- [x] Added ChatGPT to enable digital human dialogue
  85 +- [x] Voice cloning
  86 +- [x] Replace the digital human with a video when it is silent
  87 +- [x] MuseTalk
  88 +- [x] Wav2Lip
  89 +- [x] Ultralight-Digital-Human
  90 +
  91 +---
  92 +If this project helps you, please give it a star. Anyone who is interested is also welcome to join in and help improve it.
  93 +* Knowledge Planet: https://t.zsxq.com/7NMyO, which collects high-quality FAQs, best practices, and troubleshooting notes.
  94 +* WeChat Official Account: Digital Human Technology
  95 +![](https://mmbiz.qpic.cn/sz_mmbiz_jpg/l3ZibgueFiaeyfaiaLZGuMGQXnhLWxibpJUS2gfs8Dje6JuMY8zu2tVyU9n8Zx1yaNncvKHBMibX0ocehoITy5qQEZg/640?wxfrom=12&tp=wxpic&usePicPrefetch=1&wx_fmt=jpeg&from=appmsg)
1 -Real time interactive streaming digital human, realize audio video synchronous dialogue. It can basically achieve commercial effects. 1 +[English](./README-EN.md) | 中文版
2 实时交互流式数字人,实现音视频同步对话。基本可以达到商用效果 2 实时交互流式数字人,实现音视频同步对话。基本可以达到商用效果
  3 +[wav2lip效果](https://www.bilibili.com/video/BV1scwBeyELA/) | [ernerf效果](https://www.bilibili.com/video/BV1G1421z73r/) | [musetalk效果](https://www.bilibili.com/video/BV1gm421N7vQ/)
3 4
4 -[ernerf 效果](https://www.bilibili.com/video/BV1PM4m1y7Q2/) [musetalk 效果](https://www.bilibili.com/video/BV1gm421N7vQ/) [wav2lip 效果](https://www.bilibili.com/video/BV1Bw4m1e74P/)  
5 -  
6 -## 为避免与 3d 数字人混淆,原项目 metahuman-stream 改名为 livetalking,原有链接地址继续可用 5 +## 为避免与3d数字人混淆,原项目metahuman-stream改名为livetalking,原有链接地址继续可用
7 6
8 ## News 7 ## News
9 -  
10 - 2024.12.8 完善多并发,显存不随并发数增加 8 - 2024.12.8 完善多并发,显存不随并发数增加
11 -- 2024.12.21 添加 wav2lip、musetalk 模型预热,解决第一次推理卡顿问题。感谢@heimaojinzhangyz  
12 -- 2024.12.28 添加数字人模型 Ultralight-Digital-Human。 感谢@lijihua2017  
13 -- 2025.2.7 添加 fish-speech tts  
14 -- 2025.2.21 添加 wav2lip256 开源模型 感谢@不蠢不蠢 9 +- 2024.12.21 添加wav2lip、musetalk模型预热,解决第一次推理卡顿问题。感谢[@heimaojinzhangyz](https://github.com/heimaojinzhangyz)
  10 +- 2024.12.28 添加数字人模型Ultralight-Digital-Human。 感谢[@lijihua2017](https://github.com/lijihua2017)
  11 +- 2025.2.7 添加fish-speech tts
  12 +- 2025.2.21 添加wav2lip256开源模型 感谢@不蠢不蠢
15 - 2025.3.2 添加腾讯语音合成服务 13 - 2025.3.2 添加腾讯语音合成服务
  14 +- 2025.3.16 支持mac gpu推理,感谢[@GcsSloop](https://github.com/GcsSloop)
16 15
17 ## Features 16 ## Features
18 -  
19 1. 支持多种数字人模型: ernerf、musetalk、wav2lip、Ultralight-Digital-Human 17 1. 支持多种数字人模型: ernerf、musetalk、wav2lip、Ultralight-Digital-Human
20 2. 支持声音克隆 18 2. 支持声音克隆
21 3. 支持数字人说话被打断 19 3. 支持数字人说话被打断
22 4. 支持全身视频拼接 20 4. 支持全身视频拼接
23 -5. 支持 rtmp 和 webrtc 21 +5. 支持rtmp和webrtc
24 6. 支持视频编排:不说话时播放自定义视频 22 6. 支持视频编排:不说话时播放自定义视频
25 7. 支持多并发 23 7. 支持多并发
26 24
@@ -33,67 +31,61 @@ Tested on Ubuntu 20.04, Python3.10, Pytorch 1.12 and CUDA 11.3 @@ -33,67 +31,61 @@ Tested on Ubuntu 20.04, Python3.10, Pytorch 1.12 and CUDA 11.3
33 ```bash 31 ```bash
34 conda create -n nerfstream python=3.10 32 conda create -n nerfstream python=3.10
35 conda activate nerfstream 33 conda activate nerfstream
36 -#如果cuda版本不为11.3(运行nvidia-smi确认版本),根据<https://pytorch.org/get-started/previous-versions/>安装对应版本的pytorch 34 +#如果cuda版本不为11.3(运行nvidia-smi确认版本),根据<https://pytorch.org/get-started/previous-versions/>安装对应版本的pytorch
37 conda install pytorch==1.12.1 torchvision==0.13.1 cudatoolkit=11.3 -c pytorch 35 conda install pytorch==1.12.1 torchvision==0.13.1 cudatoolkit=11.3 -c pytorch
38 pip install -r requirements.txt 36 pip install -r requirements.txt
39 #如果需要训练ernerf模型,安装下面的库 37 #如果需要训练ernerf模型,安装下面的库
40 # pip install "git+https://github.com/facebookresearch/pytorch3d.git" 38 # pip install "git+https://github.com/facebookresearch/pytorch3d.git"
41 # pip install tensorflow-gpu==2.8.0 39 # pip install tensorflow-gpu==2.8.0
42 # pip install --upgrade "protobuf<=3.20.1" 40 # pip install --upgrade "protobuf<=3.20.1"
43 -```  
44 - 41 +```
45 安装常见问题[FAQ](https://livetalking-doc.readthedocs.io/en/latest/faq.html) 42 安装常见问题[FAQ](https://livetalking-doc.readthedocs.io/en/latest/faq.html)
46 -linux cuda 环境搭建可以参考这篇文章 https://zhuanlan.zhihu.com/p/674972886 43 +linux cuda环境搭建可以参考这篇文章 https://zhuanlan.zhihu.com/p/674972886
47 44
48 -## 2. Quick Start  
49 45
  46 +## 2. Quick Start
50 - 下载模型 47 - 下载模型
51 - 百度云盘<https://pan.baidu.com/s/1yOsQ06-RIDTJd3HFCw4wtA> 密码: ltua 48 + 夸克云盘<https://pan.quark.cn/s/83a750323ef0>
52 GoogleDriver <https://drive.google.com/drive/folders/1FOC_MD6wdogyyX_7V1d4NDIO7P9NlSAJ?usp=sharing> 49 GoogleDriver <https://drive.google.com/drive/folders/1FOC_MD6wdogyyX_7V1d4NDIO7P9NlSAJ?usp=sharing>
53 - 将 wav2lip256.pth 拷到本项目的 models 下, 重命名为 wav2lip.pth;  
54 - 将 wav2lip256_avatar1.tar.gz 解压后整个文件夹拷到本项目的 data/avatars 下 50 + 将wav2lip256.pth拷到本项目的models下, 重命名为wav2lip.pth;
  51 + 将wav2lip256_avatar1.tar.gz解压后整个文件夹拷到本项目的data/avatars下
55 - 运行 52 - 运行
56 python app.py --transport webrtc --model wav2lip --avatar_id wav2lip256_avatar1 --preload 2 53 python app.py --transport webrtc --model wav2lip --avatar_id wav2lip256_avatar1 --preload 2
57 -  
58 使用 GPU 启动模特 3 号:python app.py --transport webrtc --model wav2lip --avatar_id wav2lip256_avatar3 --preload 2 54 使用 GPU 启动模特 3 号:python app.py --transport webrtc --model wav2lip --avatar_id wav2lip256_avatar3 --preload 2
59 - 用浏览器打开 http://serverip:8010/webrtcapi.html , 先点‘start',播放数字人视频;然后在文本框输入任意文字,提交。数字人播报该段文字 55 +
  56 +用浏览器打开http://serverip:8010/webrtcapi.html , 先点‘start',播放数字人视频;然后在文本框输入任意文字,提交。数字人播报该段文字
60 <font color=red>服务端需要开放端口 tcp:8010; udp:1-65536 </font> 57 <font color=red>服务端需要开放端口 tcp:8010; udp:1-65536 </font>
61 - 如果需要商用高清 wav2lip 模型,可以与我联系购买 58 + 如果需要商用高清wav2lip模型,[链接](https://livetalking-doc.readthedocs.io/zh-cn/latest/service.html#wav2lip)
62 59
63 - 快速体验 60 - 快速体验
64 <https://www.compshare.cn/images-detail?ImageID=compshareImage-18tpjhhxoq3j&referral_code=3XW3852OBmnD089hMMrtuU&ytag=GPU_GitHub_livetalking1.3> 用该镜像创建实例即可运行成功 61 <https://www.compshare.cn/images-detail?ImageID=compshareImage-18tpjhhxoq3j&referral_code=3XW3852OBmnD089hMMrtuU&ytag=GPU_GitHub_livetalking1.3> 用该镜像创建实例即可运行成功
65 62
66 -如果访问不了 huggingface,在运行前  
67 - 63 +如果访问不了huggingface,在运行前
68 ``` 64 ```
69 export HF_ENDPOINT=https://hf-mirror.com 65 export HF_ENDPOINT=https://hf-mirror.com
70 -``` 66 +```
71 67
72 -## 3. More Usage  
73 68
  69 +## 3. More Usage
74 使用说明: <https://livetalking-doc.readthedocs.io/> 70 使用说明: <https://livetalking-doc.readthedocs.io/>
75 71
76 ## 4. Docker Run 72 ## 4. Docker Run
77 -  
78 不需要前面的安装,直接运行。 73 不需要前面的安装,直接运行。
79 -  
80 ``` 74 ```
81 docker run --gpus all -it --network=host --rm registry.cn-beijing.aliyuncs.com/codewithgpu2/lipku-metahuman-stream:2K9qaMBu8v 75 docker run --gpus all -it --network=host --rm registry.cn-beijing.aliyuncs.com/codewithgpu2/lipku-metahuman-stream:2K9qaMBu8v
82 ``` 76 ```
83 -  
84 -代码在/root/metahuman-stream,先 git pull 拉一下最新代码,然后执行命令同第 2、3 步 77 +代码在/root/metahuman-stream,先git pull拉一下最新代码,然后执行命令同第2、3步
85 78
86 提供如下镜像 79 提供如下镜像
  80 +- autodl镜像: <https://www.codewithgpu.com/i/lipku/metahuman-stream/base>
  81 + [autodl教程](https://livetalking-doc.readthedocs.io/en/latest/autodl/README.html)
  82 +- ucloud镜像: <https://www.compshare.cn/images-detail?ImageID=compshareImage-18tpjhhxoq3j&referral_code=3XW3852OBmnD089hMMrtuU&ytag=GPU_livetalking1.3>
  83 + 可以开放任意端口,不需要另外部署srs服务.
  84 + [ucloud教程](https://livetalking-doc.readthedocs.io/en/latest/ucloud/ucloud.html)
87 85
88 -- autodl 镜像: <https://www.codewithgpu.com/i/lipku/metahuman-stream/base>  
89 - [autodl 教程](https://livetalking-doc.readthedocs.io/en/latest/autodl/README.html)  
90 -- ucloud 镜像: <https://www.compshare.cn/images-detail?ImageID=compshareImage-18tpjhhxoq3j&referral_code=3XW3852OBmnD089hMMrtuU&ytag=GPU_livetalking1.3>  
91 - 可以开放任意端口,不需要另外部署 srs 服务.  
92 - [ucloud 教程](https://livetalking-doc.readthedocs.io/en/latest/ucloud/ucloud.html)  
93 86
94 ## 5. TODO 87 ## 5. TODO
95 -  
96 -- [x] 添加 chatgpt 实现数字人对话 88 +- [x] 添加chatgpt实现数字人对话
97 - [x] 声音克隆 89 - [x] 声音克隆
98 - [x] 数字人静音时用一段视频代替 90 - [x] 数字人静音时用一段视频代替
99 - [x] MuseTalk 91 - [x] MuseTalk
@@ -101,9 +93,8 @@ docker run --gpus all -it --network=host --rm registry.cn-beijing.aliyuncs.com/c @@ -101,9 +93,8 @@ docker run --gpus all -it --network=host --rm registry.cn-beijing.aliyuncs.com/c
101 - [x] Ultralight-Digital-Human 93 - [x] Ultralight-Digital-Human
102 94
103 --- 95 ---
  96 +如果本项目对你有帮助,帮忙点个star。也欢迎感兴趣的朋友一起来完善该项目.
  97 +* 知识星球: https://t.zsxq.com/7NMyO 沉淀高质量常见问题、最佳实践经验、问题解答
  98 +* 微信公众号:数字人技术
  99 + ![](https://mmbiz.qpic.cn/sz_mmbiz_jpg/l3ZibgueFiaeyfaiaLZGuMGQXnhLWxibpJUS2gfs8Dje6JuMY8zu2tVyU9n8Zx1yaNncvKHBMibX0ocehoITy5qQEZg/640?wxfrom=12&tp=wxpic&usePicPrefetch=1&wx_fmt=jpeg&from=appmsg)
104 100
105 -如果本项目对你有帮助,帮忙点个 star。也欢迎感兴趣的朋友一起来完善该项目.  
106 -  
107 -- 知识星球: https://t.zsxq.com/7NMyO 沉淀高质量常见问题、最佳实践经验、问题解答  
108 -- 微信公众号:数字人技术  
109 - ![](https://mmbiz.qpic.cn/sz_mmbiz_jpg/l3ZibgueFiaeyfaiaLZGuMGQXnhLWxibpJUS2gfs8Dje6JuMY8zu2tVyU9n8Zx1yaNncvKHBMibX0ocehoITy5qQEZg/640?wxfrom=12&tp=wxpic&usePicPrefetch=1&wx_fmt=jpeg&from=appmsg)  
@@ -201,7 +201,7 @@ async def set_audiotype(request): @@ -201,7 +201,7 @@ async def set_audiotype(request):
201 params = await request.json() 201 params = await request.json()
202 202
203 sessionid = params.get('sessionid',0) 203 sessionid = params.get('sessionid',0)
204 - nerfreals[sessionid].set_curr_state(params['audiotype'],params['reinit']) 204 + nerfreals[sessionid].set_custom_state(params['audiotype'],params['reinit'])
205 205
206 return web.Response( 206 return web.Response(
207 content_type="application/json", 207 content_type="application/json",
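For context on the handler change above, a hedged sketch of how this endpoint might be exercised once the server is running (it assumes the handler is registered at /set_audiotype on the default port 8010; verify the actual route registration in app.py):

```bash
# Hypothetical call matching the params read by the handler above:
# switch session 0 to custom state 2 and reset its playback index (reinit=true).
curl -X POST http://serverip:8010/set_audiotype \
     -H 'Content-Type: application/json' \
     -d '{"sessionid": 0, "audiotype": 2, "reinit": true}'
```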
@@ -495,6 +495,8 @@ if __name__ == '__main__': @@ -495,6 +495,8 @@ if __name__ == '__main__':
495 elif opt.transport=='rtcpush': 495 elif opt.transport=='rtcpush':
496 pagename='rtcpushapi.html' 496 pagename='rtcpushapi.html'
497 logger.info('start http server; http://<serverip>:'+str(opt.listenport)+'/'+pagename) 497 logger.info('start http server; http://<serverip>:'+str(opt.listenport)+'/'+pagename)
  498 + logger.info('如果使用webrtc,推荐访问webrtc集成前端: http://<serverip>:'+str(opt.listenport)+'/dashboard.html')
  499 +
498 def run_server(runner): 500 def run_server(runner):
499 loop = asyncio.new_event_loop() 501 loop = asyncio.new_event_loop()
500 asyncio.set_event_loop(loop) 502 asyncio.set_event_loop(loop)
@@ -35,7 +35,7 @@ import soundfile as sf @@ -35,7 +35,7 @@ import soundfile as sf
35 import av 35 import av
36 from fractions import Fraction 36 from fractions import Fraction
37 37
38 -from ttsreal import EdgeTTS,VoitsTTS,XTTS,CosyVoiceTTS,FishTTS,TencentTTS 38 +from ttsreal import EdgeTTS,SovitsTTS,XTTS,CosyVoiceTTS,FishTTS,TencentTTS
39 from logger import logger 39 from logger import logger
40 40
41 from tqdm import tqdm 41 from tqdm import tqdm
@@ -57,7 +57,7 @@ class BaseReal: @@ -57,7 +57,7 @@ class BaseReal:
57 if opt.tts == "edgetts": 57 if opt.tts == "edgetts":
58 self.tts = EdgeTTS(opt,self) 58 self.tts = EdgeTTS(opt,self)
59 elif opt.tts == "gpt-sovits": 59 elif opt.tts == "gpt-sovits":
60 - self.tts = VoitsTTS(opt,self) 60 + self.tts = SovitsTTS(opt,self)
61 elif opt.tts == "xtts": 61 elif opt.tts == "xtts":
62 self.tts = XTTS(opt,self) 62 self.tts = XTTS(opt,self)
63 elif opt.tts == "cosyvoice": 63 elif opt.tts == "cosyvoice":
@@ -66,7 +66,7 @@ class BaseReal: @@ -66,7 +66,7 @@ class BaseReal:
66 self.tts = FishTTS(opt,self) 66 self.tts = FishTTS(opt,self)
67 elif opt.tts == "tencent": 67 elif opt.tts == "tencent":
68 self.tts = TencentTTS(opt,self) 68 self.tts = TencentTTS(opt,self)
69 - 69 +
70 self.speaking = False 70 self.speaking = False
71 71
72 self.recording = False 72 self.recording = False
@@ -84,11 +84,11 @@ class BaseReal: @@ -84,11 +84,11 @@ class BaseReal:
84 84
85 def put_msg_txt(self,msg,eventpoint=None): 85 def put_msg_txt(self,msg,eventpoint=None):
86 self.tts.put_msg_txt(msg,eventpoint) 86 self.tts.put_msg_txt(msg,eventpoint)
87 - 87 +
88 def put_audio_frame(self,audio_chunk,eventpoint=None): #16khz 20ms pcm 88 def put_audio_frame(self,audio_chunk,eventpoint=None): #16khz 20ms pcm
89 self.asr.put_audio_frame(audio_chunk,eventpoint) 89 self.asr.put_audio_frame(audio_chunk,eventpoint)
90 90
91 - def put_audio_file(self,filebyte): 91 + def put_audio_file(self,filebyte):
92 input_stream = BytesIO(filebyte) 92 input_stream = BytesIO(filebyte)
93 stream = self.__create_bytes_stream(input_stream) 93 stream = self.__create_bytes_stream(input_stream)
94 streamlen = stream.shape[0] 94 streamlen = stream.shape[0]
@@ -97,7 +97,7 @@ class BaseReal: @@ -97,7 +97,7 @@ class BaseReal:
97 self.put_audio_frame(stream[idx:idx+self.chunk]) 97 self.put_audio_frame(stream[idx:idx+self.chunk])
98 streamlen -= self.chunk 98 streamlen -= self.chunk
99 idx += self.chunk 99 idx += self.chunk
100 - 100 +
101 def __create_bytes_stream(self,byte_stream): 101 def __create_bytes_stream(self,byte_stream):
102 #byte_stream=BytesIO(buffer) 102 #byte_stream=BytesIO(buffer)
103 stream, sample_rate = sf.read(byte_stream) # [T*sample_rate,] float64 103 stream, sample_rate = sf.read(byte_stream) # [T*sample_rate,] float64
@@ -107,7 +107,7 @@ class BaseReal: @@ -107,7 +107,7 @@ class BaseReal:
107 if stream.ndim > 1: 107 if stream.ndim > 1:
108 logger.info(f'[WARN] audio has {stream.shape[1]} channels, only use the first.') 108 logger.info(f'[WARN] audio has {stream.shape[1]} channels, only use the first.')
109 stream = stream[:, 0] 109 stream = stream[:, 0]
110 - 110 +
111 if sample_rate != self.sample_rate and stream.shape[0]>0: 111 if sample_rate != self.sample_rate and stream.shape[0]>0:
112 logger.info(f'[WARN] audio sample rate is {sample_rate}, resampling into {self.sample_rate}.') 112 logger.info(f'[WARN] audio sample rate is {sample_rate}, resampling into {self.sample_rate}.')
113 stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self.sample_rate) 113 stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self.sample_rate)
@@ -120,7 +120,7 @@ class BaseReal: @@ -120,7 +120,7 @@ class BaseReal:
120 120
121 def is_speaking(self)->bool: 121 def is_speaking(self)->bool:
122 return self.speaking 122 return self.speaking
123 - 123 +
124 def __loadcustom(self): 124 def __loadcustom(self):
125 for item in self.opt.customopt: 125 for item in self.opt.customopt:
126 logger.info(item) 126 logger.info(item)
@@ -155,9 +155,9 @@ class BaseReal: @@ -155,9 +155,9 @@ class BaseReal:
155 '-s', "{}x{}".format(self.width, self.height), 155 '-s', "{}x{}".format(self.width, self.height),
156 '-r', str(25), 156 '-r', str(25),
157 '-i', '-', 157 '-i', '-',
158 - '-pix_fmt', 'yuv420p', 158 + '-pix_fmt', 'yuv420p',
159 '-vcodec', "h264", 159 '-vcodec', "h264",
160 - #'-f' , 'flv', 160 + #'-f' , 'flv',
161 f'temp{self.opt.sessionid}.mp4'] 161 f'temp{self.opt.sessionid}.mp4']
162 self._record_video_pipe = subprocess.Popen(command, shell=False, stdin=subprocess.PIPE) 162 self._record_video_pipe = subprocess.Popen(command, shell=False, stdin=subprocess.PIPE)
163 163
@@ -169,7 +169,7 @@ class BaseReal: @@ -169,7 +169,7 @@ class BaseReal:
169 '-ar', '16000', 169 '-ar', '16000',
170 '-i', '-', 170 '-i', '-',
171 '-acodec', 'aac', 171 '-acodec', 'aac',
172 - #'-f' , 'wav', 172 + #'-f' , 'wav',
173 f'temp{self.opt.sessionid}.aac'] 173 f'temp{self.opt.sessionid}.aac']
174 self._record_audio_pipe = subprocess.Popen(acommand, shell=False, stdin=subprocess.PIPE) 174 self._record_audio_pipe = subprocess.Popen(acommand, shell=False, stdin=subprocess.PIPE)
175 175
@@ -177,10 +177,10 @@ class BaseReal: @@ -177,10 +177,10 @@ class BaseReal:
177 # self.recordq_video.queue.clear() 177 # self.recordq_video.queue.clear()
178 # self.recordq_audio.queue.clear() 178 # self.recordq_audio.queue.clear()
179 # self.container = av.open(path, mode="w") 179 # self.container = av.open(path, mode="w")
180 - 180 +
181 # process_thread = Thread(target=self.record_frame, args=()) 181 # process_thread = Thread(target=self.record_frame, args=())
182 # process_thread.start() 182 # process_thread.start()
183 - 183 +
184 def record_video_data(self,image): 184 def record_video_data(self,image):
185 if self.width == 0: 185 if self.width == 0:
186 print("image.shape:",image.shape) 186 print("image.shape:",image.shape)
@@ -191,14 +191,14 @@ class BaseReal: @@ -191,14 +191,14 @@ class BaseReal:
191 def record_audio_data(self,frame): 191 def record_audio_data(self,frame):
192 if self.recording: 192 if self.recording:
193 self._record_audio_pipe.stdin.write(frame.tostring()) 193 self._record_audio_pipe.stdin.write(frame.tostring())
194 -  
195 - # def record_frame(self): 194 +
  195 + # def record_frame(self):
196 # videostream = self.container.add_stream("libx264", rate=25) 196 # videostream = self.container.add_stream("libx264", rate=25)
197 # videostream.codec_context.time_base = Fraction(1, 25) 197 # videostream.codec_context.time_base = Fraction(1, 25)
198 # audiostream = self.container.add_stream("aac") 198 # audiostream = self.container.add_stream("aac")
199 # audiostream.codec_context.time_base = Fraction(1, 16000) 199 # audiostream.codec_context.time_base = Fraction(1, 16000)
200 # init = True 200 # init = True
201 - # framenum = 0 201 + # framenum = 0
202 # while self.recording: 202 # while self.recording:
203 # try: 203 # try:
204 # videoframe = self.recordq_video.get(block=True, timeout=1) 204 # videoframe = self.recordq_video.get(block=True, timeout=1)
@@ -231,18 +231,18 @@ class BaseReal: @@ -231,18 +231,18 @@ class BaseReal:
231 # self.recordq_video.queue.clear() 231 # self.recordq_video.queue.clear()
232 # self.recordq_audio.queue.clear() 232 # self.recordq_audio.queue.clear()
233 # print('record thread stop') 233 # print('record thread stop')
234 - 234 +
235 def stop_recording(self): 235 def stop_recording(self):
236 """停止录制视频""" 236 """停止录制视频"""
237 if not self.recording: 237 if not self.recording:
238 return 238 return
239 - self.recording = False  
240 - self._record_video_pipe.stdin.close() #wait() 239 + self.recording = False
  240 + self._record_video_pipe.stdin.close() #wait()
241 self._record_video_pipe.wait() 241 self._record_video_pipe.wait()
242 self._record_audio_pipe.stdin.close() 242 self._record_audio_pipe.stdin.close()
243 self._record_audio_pipe.wait() 243 self._record_audio_pipe.wait()
244 cmd_combine_audio = f"ffmpeg -y -i temp{self.opt.sessionid}.aac -i temp{self.opt.sessionid}.mp4 -c:v copy -c:a copy data/record.mp4" 244 cmd_combine_audio = f"ffmpeg -y -i temp{self.opt.sessionid}.aac -i temp{self.opt.sessionid}.mp4 -c:v copy -c:a copy data/record.mp4"
245 - os.system(cmd_combine_audio) 245 + os.system(cmd_combine_audio)
246 #os.remove(output_path) 246 #os.remove(output_path)
247 247
248 def mirror_index(self,size, index): 248 def mirror_index(self,size, index):
@@ -252,8 +252,8 @@ class BaseReal: @@ -252,8 +252,8 @@ class BaseReal:
252 if turn % 2 == 0: 252 if turn % 2 == 0:
253 return res 253 return res
254 else: 254 else:
255 - return size - res - 1  
256 - 255 + return size - res - 1
  256 +
257 def get_audio_stream(self,audiotype): 257 def get_audio_stream(self,audiotype):
258 idx = self.custom_audio_index[audiotype] 258 idx = self.custom_audio_index[audiotype]
259 stream = self.custom_audio_cycle[audiotype][idx:idx+self.chunk] 259 stream = self.custom_audio_cycle[audiotype][idx:idx+self.chunk]
@@ -261,9 +261,9 @@ class BaseReal: @@ -261,9 +261,9 @@ class BaseReal:
261 if self.custom_audio_index[audiotype]>=self.custom_audio_cycle[audiotype].shape[0]: 261 if self.custom_audio_index[audiotype]>=self.custom_audio_cycle[audiotype].shape[0]:
262 self.curr_state = 1 #当前视频不循环播放,切换到静音状态 262 self.curr_state = 1 #当前视频不循环播放,切换到静音状态
263 return stream 263 return stream
264 -  
265 - def set_curr_state(self,audiotype, reinit):  
266 - print('set_curr_state:',audiotype) 264 +
  265 + def set_custom_state(self,audiotype, reinit=True):
  266 + print('set_custom_state:',audiotype)
267 self.curr_state = audiotype 267 self.curr_state = audiotype
268 if reinit: 268 if reinit:
269 self.custom_audio_index[audiotype] = 0 269 self.custom_audio_index[audiotype] = 0
@@ -179,8 +179,11 @@ print(f'[INFO] fitting light...') @@ -179,8 +179,11 @@ print(f'[INFO] fitting light...')
179 179
180 batch_size = 32 180 batch_size = 32
181 181
182 -device_default = torch.device("cuda:0")  
183 -device_render = torch.device("cuda:0") 182 +device_default = torch.device("cuda:0" if torch.cuda.is_available() else (
  183 + "mps" if (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) else "cpu"))
  184 +device_render = torch.device("cuda:0" if torch.cuda.is_available() else (
  185 + "mps" if (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) else "cpu"))
  186 +
184 renderer = Render_3DMM(arg_focal, h, w, batch_size, device_render) 187 renderer = Render_3DMM(arg_focal, h, w, batch_size, device_render)
185 188
186 sel_ids = np.arange(0, num_frames, int(num_frames / batch_size))[:batch_size] 189 sel_ids = np.arange(0, num_frames, int(num_frames / batch_size))[:batch_size]
@@ -83,7 +83,7 @@ class Render_3DMM(nn.Module): @@ -83,7 +83,7 @@ class Render_3DMM(nn.Module):
83 img_h=500, 83 img_h=500,
84 img_w=500, 84 img_w=500,
85 batch_size=1, 85 batch_size=1,
86 - device=torch.device("cuda:0"), 86 + device=torch.device("cuda:0" if torch.cuda.is_available() else ("mps" if (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) else "cpu")),
87 ): 87 ):
88 super(Render_3DMM, self).__init__() 88 super(Render_3DMM, self).__init__()
89 89
@@ -147,7 +147,7 @@ if __name__ == '__main__': @@ -147,7 +147,7 @@ if __name__ == '__main__':
147 147
148 seed_everything(opt.seed) 148 seed_everything(opt.seed)
149 149
150 - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 150 + device = torch.device("cuda" if torch.cuda.is_available() else ("mps" if (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) else "cpu"))
151 151
152 model = NeRFNetwork(opt) 152 model = NeRFNetwork(opt)
153 153
@@ -442,7 +442,7 @@ class LPIPSMeter: @@ -442,7 +442,7 @@ class LPIPSMeter:
442 self.N = 0 442 self.N = 0
443 self.net = net 443 self.net = net
444 444
445 - self.device = device if device is not None else torch.device('cuda' if torch.cuda.is_available() else 'cpu') 445 + self.device = device if device is not None else torch.device('cuda' if torch.cuda.is_available() else ('mps' if hasattr(torch.backends, "mps") and torch.backends.mps.is_available() else 'cpu'))
446 self.fn = lpips.LPIPS(net=net).eval().to(self.device) 446 self.fn = lpips.LPIPS(net=net).eval().to(self.device)
447 447
448 def clear(self): 448 def clear(self):
@@ -456,13 +456,13 @@ class LPIPSMeter: @@ -456,13 +456,13 @@ class LPIPSMeter:
456 inp = inp.to(self.device) 456 inp = inp.to(self.device)
457 outputs.append(inp) 457 outputs.append(inp)
458 return outputs 458 return outputs
459 - 459 +
460 def update(self, preds, truths): 460 def update(self, preds, truths):
461 preds, truths = self.prepare_inputs(preds, truths) # [B, H, W, 3] --> [B, 3, H, W], range in [0, 1] 461 preds, truths = self.prepare_inputs(preds, truths) # [B, H, W, 3] --> [B, 3, H, W], range in [0, 1]
462 v = self.fn(truths, preds, normalize=True).item() # normalize=True: [0, 1] to [-1, 1] 462 v = self.fn(truths, preds, normalize=True).item() # normalize=True: [0, 1] to [-1, 1]
463 self.V += v 463 self.V += v
464 self.N += 1 464 self.N += 1
465 - 465 +
466 def measure(self): 466 def measure(self):
467 return self.V / self.N 467 return self.V / self.N
468 468
@@ -499,7 +499,7 @@ class LMDMeter: @@ -499,7 +499,7 @@ class LMDMeter:
499 499
500 self.V = 0 500 self.V = 0
501 self.N = 0 501 self.N = 0
502 - 502 +
503 def get_landmarks(self, img): 503 def get_landmarks(self, img):
504 504
505 if self.backend == 'dlib': 505 if self.backend == 'dlib':
@@ -515,7 +515,7 @@ class LMDMeter: @@ -515,7 +515,7 @@ class LMDMeter:
515 515
516 else: 516 else:
517 lms = self.predictor.get_landmarks(img)[-1] 517 lms = self.predictor.get_landmarks(img)[-1]
518 - 518 +
519 # self.vis_landmarks(img, lms) 519 # self.vis_landmarks(img, lms)
520 lms = lms.astype(np.float32) 520 lms = lms.astype(np.float32)
521 521
@@ -537,7 +537,7 @@ class LMDMeter: @@ -537,7 +537,7 @@ class LMDMeter:
537 inp = (inp * 255).astype(np.uint8) 537 inp = (inp * 255).astype(np.uint8)
538 outputs.append(inp) 538 outputs.append(inp)
539 return outputs 539 return outputs
540 - 540 +
541 def update(self, preds, truths): 541 def update(self, preds, truths):
542 # assert B == 1 542 # assert B == 1
543 preds, truths = self.prepare_inputs(preds[0], truths[0]) # [H, W, 3] numpy array 543 preds, truths = self.prepare_inputs(preds[0], truths[0]) # [H, W, 3] numpy array
@@ -553,13 +553,13 @@ class LMDMeter: @@ -553,13 +553,13 @@ class LMDMeter:
553 # avarage 553 # avarage
554 lms_pred = lms_pred - lms_pred.mean(0) 554 lms_pred = lms_pred - lms_pred.mean(0)
555 lms_truth = lms_truth - lms_truth.mean(0) 555 lms_truth = lms_truth - lms_truth.mean(0)
556 - 556 +
557 # distance 557 # distance
558 dist = np.sqrt(((lms_pred - lms_truth) ** 2).sum(1)).mean(0) 558 dist = np.sqrt(((lms_pred - lms_truth) ** 2).sum(1)).mean(0)
559 - 559 +
560 self.V += dist 560 self.V += dist
561 self.N += 1 561 self.N += 1
562 - 562 +
563 def measure(self): 563 def measure(self):
564 return self.V / self.N 564 return self.V / self.N
565 565
@@ -567,14 +567,14 @@ class LMDMeter: @@ -567,14 +567,14 @@ class LMDMeter:
567 writer.add_scalar(os.path.join(prefix, f"LMD ({self.backend})"), self.measure(), global_step) 567 writer.add_scalar(os.path.join(prefix, f"LMD ({self.backend})"), self.measure(), global_step)
568 568
569 def report(self): 569 def report(self):
570 - return f'LMD ({self.backend}) = {self.measure():.6f}'  
571 - 570 + return f'LMD ({self.backend}) = {self.measure():.6f}'
  571 +
572 572
573 class Trainer(object): 573 class Trainer(object):
574 - def __init__(self, 574 + def __init__(self,
575 name, # name of this experiment 575 name, # name of this experiment
576 opt, # extra conf 576 opt, # extra conf
577 - model, # network 577 + model, # network
578 criterion=None, # loss function, if None, assume inline implementation in train_step 578 criterion=None, # loss function, if None, assume inline implementation in train_step
579 optimizer=None, # optimizer 579 optimizer=None, # optimizer
580 ema_decay=None, # if use EMA, set the decay 580 ema_decay=None, # if use EMA, set the decay
@@ -596,7 +596,7 @@ class Trainer(object): @@ -596,7 +596,7 @@ class Trainer(object):
596 use_tensorboardX=True, # whether to use tensorboard for logging 596 use_tensorboardX=True, # whether to use tensorboard for logging
597 scheduler_update_every_step=False, # whether to call scheduler.step() after every train step 597 scheduler_update_every_step=False, # whether to call scheduler.step() after every train step
598 ): 598 ):
599 - 599 +
600 self.name = name 600 self.name = name
601 self.opt = opt 601 self.opt = opt
602 self.mute = mute 602 self.mute = mute
@@ -618,7 +618,11 @@ class Trainer(object): @@ -618,7 +618,11 @@ class Trainer(object):
618 self.flip_init_lips = self.opt.init_lips 618 self.flip_init_lips = self.opt.init_lips
619 self.time_stamp = time.strftime("%Y-%m-%d_%H-%M-%S") 619 self.time_stamp = time.strftime("%Y-%m-%d_%H-%M-%S")
620 self.scheduler_update_every_step = scheduler_update_every_step 620 self.scheduler_update_every_step = scheduler_update_every_step
621 - self.device = device if device is not None else torch.device(f'cuda:{local_rank}' if torch.cuda.is_available() else 'cpu') 621 + self.device = device if device is not None else torch.device(
  622 + f'cuda:{local_rank}' if torch.cuda.is_available() else (
  623 + 'mps' if (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) else 'cpu'
  624 + )
  625 + )
622 self.console = Console() 626 self.console = Console()
623 627
624 model.to(self.device) 628 model.to(self.device)
@@ -56,10 +56,8 @@ from ultralight.unet import Model @@ -56,10 +56,8 @@ from ultralight.unet import Model
56 from ultralight.audio2feature import Audio2Feature 56 from ultralight.audio2feature import Audio2Feature
57 from logger import logger 57 from logger import logger
58 58
59 -  
60 -device = 'cuda' if torch.cuda.is_available() else 'cpu'  
61 -logger.info('Using {} for inference.'.format(device))  
62 - 59 +device = "cuda" if torch.cuda.is_available() else ("mps" if (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) else "cpu")
  60 +print('Using {} for inference.'.format(device))
63 61
64 def load_model(opt): 62 def load_model(opt):
65 audio_processor = Audio2Feature() 63 audio_processor = Audio2Feature()
@@ -44,8 +44,8 @@ from basereal import BaseReal @@ -44,8 +44,8 @@ from basereal import BaseReal
44 from tqdm import tqdm 44 from tqdm import tqdm
45 from logger import logger 45 from logger import logger
46 46
47 -device = 'cuda' if torch.cuda.is_available() else 'cpu'  
48 -logger.info('Using {} for inference.'.format(device)) 47 +device = "cuda" if torch.cuda.is_available() else ("mps" if (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) else "cpu")
  48 +print('Using {} for inference.'.format(device))
49 49
50 def _load(checkpoint_path): 50 def _load(checkpoint_path):
51 if device == 'cuda': 51 if device == 'cuda':
@@ -51,7 +51,7 @@ from logger import logger @@ -51,7 +51,7 @@ from logger import logger
51 def load_model(): 51 def load_model():
52 # load model weights 52 # load model weights
53 audio_processor,vae, unet, pe = load_all_model() 53 audio_processor,vae, unet, pe = load_all_model()
54 - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 54 + device = torch.device("cuda" if torch.cuda.is_available() else ("mps" if (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) else "cpu"))
55 timesteps = torch.tensor([0], device=device) 55 timesteps = torch.tensor([0], device=device)
56 pe = pe.half() 56 pe = pe.half()
57 vae.vae = vae.vae.half() 57 vae.vae = vae.vae.half()
@@ -64,7 +64,7 @@ def load_avatar(avatar_id): @@ -64,7 +64,7 @@ def load_avatar(avatar_id):
64 #self.video_path = '' #video_path 64 #self.video_path = '' #video_path
65 #self.bbox_shift = opt.bbox_shift 65 #self.bbox_shift = opt.bbox_shift
66 avatar_path = f"./data/avatars/{avatar_id}" 66 avatar_path = f"./data/avatars/{avatar_id}"
67 - full_imgs_path = f"{avatar_path}/full_imgs" 67 + full_imgs_path = f"{avatar_path}/full_imgs"
68 coords_path = f"{avatar_path}/coords.pkl" 68 coords_path = f"{avatar_path}/coords.pkl"
69 latents_out_path= f"{avatar_path}/latents.pt" 69 latents_out_path= f"{avatar_path}/latents.pt"
70 video_out_path = f"{avatar_path}/vid_output/" 70 video_out_path = f"{avatar_path}/vid_output/"
@@ -74,7 +74,7 @@ def load_avatar(avatar_id): @@ -74,7 +74,7 @@ def load_avatar(avatar_id):
74 # self.avatar_info = { 74 # self.avatar_info = {
75 # "avatar_id":self.avatar_id, 75 # "avatar_id":self.avatar_id,
76 # "video_path":self.video_path, 76 # "video_path":self.video_path,
77 - # "bbox_shift":self.bbox_shift 77 + # "bbox_shift":self.bbox_shift
78 # } 78 # }
79 79
80 input_latent_list_cycle = torch.load(latents_out_path) #,weights_only=True 80 input_latent_list_cycle = torch.load(latents_out_path) #,weights_only=True
@@ -124,19 +124,19 @@ def __mirror_index(size, index): @@ -124,19 +124,19 @@ def __mirror_index(size, index):
124 if turn % 2 == 0: 124 if turn % 2 == 0:
125 return res 125 return res
126 else: 126 else:
127 - return size - res - 1 127 + return size - res - 1
128 128
129 @torch.no_grad() 129 @torch.no_grad()
130 def inference(render_event,batch_size,input_latent_list_cycle,audio_feat_queue,audio_out_queue,res_frame_queue, 130 def inference(render_event,batch_size,input_latent_list_cycle,audio_feat_queue,audio_out_queue,res_frame_queue,
131 vae, unet, pe,timesteps): #vae, unet, pe,timesteps 131 vae, unet, pe,timesteps): #vae, unet, pe,timesteps
132 - 132 +
133 # vae, unet, pe = load_diffusion_model() 133 # vae, unet, pe = load_diffusion_model()
134 # device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 134 # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
135 # timesteps = torch.tensor([0], device=device) 135 # timesteps = torch.tensor([0], device=device)
136 # pe = pe.half() 136 # pe = pe.half()
137 # vae.vae = vae.vae.half() 137 # vae.vae = vae.vae.half()
138 # unet.model = unet.model.half() 138 # unet.model = unet.model.half()
139 - 139 +
140 length = len(input_latent_list_cycle) 140 length = len(input_latent_list_cycle)
141 index = 0 141 index = 0
142 count=0 142 count=0
@@ -169,7 +169,7 @@ def inference(render_event,batch_size,input_latent_list_cycle,audio_feat_queue,a @@ -169,7 +169,7 @@ def inference(render_event,batch_size,input_latent_list_cycle,audio_feat_queue,a
169 latent = input_latent_list_cycle[idx] 169 latent = input_latent_list_cycle[idx]
170 latent_batch.append(latent) 170 latent_batch.append(latent)
171 latent_batch = torch.cat(latent_batch, dim=0) 171 latent_batch = torch.cat(latent_batch, dim=0)
172 - 172 +
173 # for i, (whisper_batch,latent_batch) in enumerate(gen): 173 # for i, (whisper_batch,latent_batch) in enumerate(gen):
174 audio_feature_batch = torch.from_numpy(whisper_batch) 174 audio_feature_batch = torch.from_numpy(whisper_batch)
175 audio_feature_batch = audio_feature_batch.to(device=unet.device, 175 audio_feature_batch = audio_feature_batch.to(device=unet.device,
@@ -179,8 +179,8 @@ def inference(render_event,batch_size,input_latent_list_cycle,audio_feat_queue,a @@ -179,8 +179,8 @@ def inference(render_event,batch_size,input_latent_list_cycle,audio_feat_queue,a
179 # print('prepare time:',time.perf_counter()-t) 179 # print('prepare time:',time.perf_counter()-t)
180 # t=time.perf_counter() 180 # t=time.perf_counter()
181 181
182 - pred_latents = unet.model(latent_batch,  
183 - timesteps, 182 + pred_latents = unet.model(latent_batch,
  183 + timesteps,
184 encoder_hidden_states=audio_feature_batch).sample 184 encoder_hidden_states=audio_feature_batch).sample
185 # print('unet time:',time.perf_counter()-t) 185 # print('unet time:',time.perf_counter()-t)
186 # t=time.perf_counter() 186 # t=time.perf_counter()
@@ -203,7 +203,7 @@ def inference(render_event,batch_size,input_latent_list_cycle,audio_feat_queue,a @@ -203,7 +203,7 @@ def inference(render_event,batch_size,input_latent_list_cycle,audio_feat_queue,a
203 #self.__pushmedia(res_frame,loop,audio_track,video_track) 203 #self.__pushmedia(res_frame,loop,audio_track,video_track)
204 res_frame_queue.put((res_frame,__mirror_index(length,index),audio_frames[i*2:i*2+2])) 204 res_frame_queue.put((res_frame,__mirror_index(length,index),audio_frames[i*2:i*2+2]))
205 index = index + 1 205 index = index + 1
206 - #print('total batch time:',time.perf_counter()-starttime) 206 + #print('total batch time:',time.perf_counter()-starttime)
207 logger.info('musereal inference processor stop') 207 logger.info('musereal inference processor stop')
208 208
209 class MuseReal(BaseReal): 209 class MuseReal(BaseReal):
@@ -226,12 +226,12 @@ class MuseReal(BaseReal): @@ -226,12 +226,12 @@ class MuseReal(BaseReal):
226 226
227 self.asr = MuseASR(opt,self,self.audio_processor) 227 self.asr = MuseASR(opt,self,self.audio_processor)
228 self.asr.warm_up() 228 self.asr.warm_up()
229 - 229 +
230 self.render_event = mp.Event() 230 self.render_event = mp.Event()
231 231
232 def __del__(self): 232 def __del__(self):
233 logger.info(f'musereal({self.sessionid}) delete') 233 logger.info(f'musereal({self.sessionid}) delete')
234 - 234 +
235 235
236 def __mirror_index(self, index): 236 def __mirror_index(self, index):
237 size = len(self.coord_list_cycle) 237 size = len(self.coord_list_cycle)
@@ -240,9 +240,9 @@ class MuseReal(BaseReal): @@ -240,9 +240,9 @@ class MuseReal(BaseReal):
240 if turn % 2 == 0: 240 if turn % 2 == 0:
241 return res 241 return res
242 else: 242 else:
243 - return size - res - 1 243 + return size - res - 1
244 244
245 - def __warm_up(self): 245 + def __warm_up(self):
246 self.asr.run_step() 246 self.asr.run_step()
247 whisper_chunks = self.asr.get_next_feat() 247 whisper_chunks = self.asr.get_next_feat()
248 whisper_batch = np.stack(whisper_chunks) 248 whisper_batch = np.stack(whisper_chunks)
@@ -260,30 +260,57 @@ class MuseReal(BaseReal): @@ -260,30 +260,57 @@ class MuseReal(BaseReal):
260 audio_feature_batch = self.pe(audio_feature_batch) 260 audio_feature_batch = self.pe(audio_feature_batch)
261 latent_batch = latent_batch.to(dtype=self.unet.model.dtype) 261 latent_batch = latent_batch.to(dtype=self.unet.model.dtype)
262 262
263 - pred_latents = self.unet.model(latent_batch,  
264 - self.timesteps, 263 + pred_latents = self.unet.model(latent_batch,
  264 + self.timesteps,
265 encoder_hidden_states=audio_feature_batch).sample 265 encoder_hidden_states=audio_feature_batch).sample
266 recon = self.vae.decode_latents(pred_latents) 266 recon = self.vae.decode_latents(pred_latents)
267 - 267 +
268 268
269 def process_frames(self,quit_event,loop=None,audio_track=None,video_track=None): 269 def process_frames(self,quit_event,loop=None,audio_track=None,video_track=None):
270 - 270 + enable_transition = True # 设置为False禁用过渡效果,True启用
  271 +
  272 + if enable_transition:
  273 + self.last_speaking = False
  274 + self.transition_start = time.time()
  275 + self.transition_duration = 0.1 # 过渡时间
  276 + self.last_silent_frame = None # 静音帧缓存
  277 + self.last_speaking_frame = None # 说话帧缓存
  278 +
271 while not quit_event.is_set(): 279 while not quit_event.is_set():
272 try: 280 try:
273 res_frame,idx,audio_frames = self.res_frame_queue.get(block=True, timeout=1) 281 res_frame,idx,audio_frames = self.res_frame_queue.get(block=True, timeout=1)
274 except queue.Empty: 282 except queue.Empty:
275 continue 283 continue
276 - if audio_frames[0][1]!=0 and audio_frames[1][1]!=0: #全为静音数据,只需要取fullimg 284 +
  285 + if enable_transition:
  286 + # 检测状态变化
  287 + current_speaking = not (audio_frames[0][1]!=0 and audio_frames[1][1]!=0)
  288 + if current_speaking != self.last_speaking:
  289 + logger.info(f"状态切换:{'说话' if self.last_speaking else '静音'} → {'说话' if current_speaking else '静音'}")
  290 + self.transition_start = time.time()
  291 + self.last_speaking = current_speaking
  292 +
  293 + if audio_frames[0][1]!=0 and audio_frames[1][1]!=0:
277 self.speaking = False 294 self.speaking = False
278 audiotype = audio_frames[0][1] 295 audiotype = audio_frames[0][1]
279 - if self.custom_index.get(audiotype) is not None: #有自定义视频 296 + if self.custom_index.get(audiotype) is not None:
280 mirindex = self.mirror_index(len(self.custom_img_cycle[audiotype]),self.custom_index[audiotype]) 297 mirindex = self.mirror_index(len(self.custom_img_cycle[audiotype]),self.custom_index[audiotype])
281 - combine_frame = self.custom_img_cycle[audiotype][mirindex] 298 + target_frame = self.custom_img_cycle[audiotype][mirindex]
282 self.custom_index[audiotype] += 1 299 self.custom_index[audiotype] += 1
283 - # if not self.custom_opt[audiotype].loop and self.custom_index[audiotype]>=len(self.custom_img_cycle[audiotype]):  
284 - # self.curr_state = 1 #当前视频不循环播放,切换到静音状态  
285 else: 300 else:
286 - combine_frame = self.frame_list_cycle[idx] 301 + target_frame = self.frame_list_cycle[idx]
  302 +
  303 + if enable_transition:
  304 + # 说话→静音过渡
  305 + if time.time() - self.transition_start < self.transition_duration and self.last_speaking_frame is not None:
  306 + alpha = min(1.0, (time.time() - self.transition_start) / self.transition_duration)
  307 + combine_frame = cv2.addWeighted(self.last_speaking_frame, 1-alpha, target_frame, alpha, 0)
  308 + else:
  309 + combine_frame = target_frame
  310 + # 缓存静音帧
  311 + self.last_silent_frame = combine_frame.copy()
  312 + else:
  313 + combine_frame = target_frame
287 else: 314 else:
288 self.speaking = True 315 self.speaking = True
289 bbox = self.coord_list_cycle[idx] 316 bbox = self.coord_list_cycle[idx]
@@ -291,20 +318,29 @@ class MuseReal(BaseReal): @@ -291,20 +318,29 @@ class MuseReal(BaseReal):
291 x1, y1, x2, y2 = bbox 318 x1, y1, x2, y2 = bbox
292 try: 319 try:
293 res_frame = cv2.resize(res_frame.astype(np.uint8),(x2-x1,y2-y1)) 320 res_frame = cv2.resize(res_frame.astype(np.uint8),(x2-x1,y2-y1))
294 - except: 321 + except Exception as e:
  322 + logger.warning(f"resize error: {e}")
295 continue 323 continue
296 mask = self.mask_list_cycle[idx] 324 mask = self.mask_list_cycle[idx]
297 mask_crop_box = self.mask_coords_list_cycle[idx] 325 mask_crop_box = self.mask_coords_list_cycle[idx]
298 - #combine_frame = get_image(ori_frame,res_frame,bbox)  
299 - #t=time.perf_counter()  
300 - combine_frame = get_image_blending(ori_frame,res_frame,bbox,mask,mask_crop_box)  
301 - #print('blending time:',time.perf_counter()-t)  
302 326
303 - image = combine_frame #(outputs['image'] * 255).astype(np.uint8) 327 + current_frame = get_image_blending(ori_frame,res_frame,bbox,mask,mask_crop_box)
  328 + if enable_transition:
  329 + # 静音→说话过渡
  330 + if time.time() - self.transition_start < self.transition_duration and self.last_silent_frame is not None:
  331 + alpha = min(1.0, (time.time() - self.transition_start) / self.transition_duration)
  332 + combine_frame = cv2.addWeighted(self.last_silent_frame, 1-alpha, current_frame, alpha, 0)
  333 + else:
  334 + combine_frame = current_frame
  335 + # 缓存说话帧
  336 + self.last_speaking_frame = combine_frame.copy()
  337 + else:
  338 + combine_frame = current_frame
  339 +
  340 + image = combine_frame
304 new_frame = VideoFrame.from_ndarray(image, format="bgr24") 341 new_frame = VideoFrame.from_ndarray(image, format="bgr24")
305 asyncio.run_coroutine_threadsafe(video_track._queue.put((new_frame,None)), loop) 342 asyncio.run_coroutine_threadsafe(video_track._queue.put((new_frame,None)), loop)
306 self.record_video_data(image) 343 self.record_video_data(image)
307 - #self.recordq_video.put(new_frame)  
308 344
309 for audio_frame in audio_frames: 345 for audio_frame in audio_frames:
310 frame,type,eventpoint = audio_frame 346 frame,type,eventpoint = audio_frame
@@ -312,12 +348,8 @@ class MuseReal(BaseReal): @@ -312,12 +348,8 @@ class MuseReal(BaseReal):
312 new_frame = AudioFrame(format='s16', layout='mono', samples=frame.shape[0]) 348 new_frame = AudioFrame(format='s16', layout='mono', samples=frame.shape[0])
313 new_frame.planes[0].update(frame.tobytes()) 349 new_frame.planes[0].update(frame.tobytes())
314 new_frame.sample_rate=16000 350 new_frame.sample_rate=16000
315 - # if audio_track._queue.qsize()>10:  
316 - # time.sleep(0.1)  
317 asyncio.run_coroutine_threadsafe(audio_track._queue.put((new_frame,eventpoint)), loop) 351 asyncio.run_coroutine_threadsafe(audio_track._queue.put((new_frame,eventpoint)), loop)
318 self.record_audio_data(frame) 352 self.record_audio_data(frame)
319 - #self.notify(eventpoint)  
320 - #self.recordq_audio.put(new_frame)  
321 logger.info('musereal process_frames thread stop') 353 logger.info('musereal process_frames thread stop')
322 354
323 def render(self,quit_event,loop=None,audio_track=None,video_track=None): 355 def render(self,quit_event,loop=None,audio_track=None,video_track=None):
@@ -36,7 +36,7 @@ class UNet(): @@ -36,7 +36,7 @@ class UNet():
36 unet_config = json.load(f) 36 unet_config = json.load(f)
37 self.model = UNet2DConditionModel(**unet_config) 37 self.model = UNet2DConditionModel(**unet_config)
38 self.pe = PositionalEncoding(d_model=384) 38 self.pe = PositionalEncoding(d_model=384)
39 - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 39 + self.device = torch.device("cuda" if torch.cuda.is_available() else ("mps" if (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) else "cpu"))
40 weights = torch.load(model_path) if torch.cuda.is_available() else torch.load(model_path, map_location=self.device) 40 weights = torch.load(model_path) if torch.cuda.is_available() else torch.load(model_path, map_location=self.device)
41 self.model.load_state_dict(weights) 41 self.model.load_state_dict(weights)
42 if use_float16: 42 if use_float16:
@@ -23,7 +23,7 @@ class VAE(): @@ -23,7 +23,7 @@ class VAE():
23 self.model_path = model_path 23 self.model_path = model_path
24 self.vae = AutoencoderKL.from_pretrained(self.model_path) 24 self.vae = AutoencoderKL.from_pretrained(self.model_path)
25 25
26 - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 26 + self.device = torch.device("cuda" if torch.cuda.is_available() else ("mps" if (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) else "cpu"))
27 self.vae.to(self.device) 27 self.vae.to(self.device)
28 28
29 if use_float16: 29 if use_float16:
@@ -325,7 +325,7 @@ def create_musetalk_human(file, avatar_id): @@ -325,7 +325,7 @@ def create_musetalk_human(file, avatar_id):
325 325
326 326
327 # initialize the mmpose model 327 # initialize the mmpose model
328 -device = "cuda" if torch.cuda.is_available() else "cpu" 328 +device = "cuda" if torch.cuda.is_available() else ("mps" if (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) else "cpu")
329 fa = FaceAlignment(1, flip_input=False, device=device) 329 fa = FaceAlignment(1, flip_input=False, device=device)
330 config_file = os.path.join(current_dir, 'utils/dwpose/rtmpose-l_8xb32-270e_coco-ubody-wholebody-384x288.py') 330 config_file = os.path.join(current_dir, 'utils/dwpose/rtmpose-l_8xb32-270e_coco-ubody-wholebody-384x288.py')
331 checkpoint_file = os.path.abspath(os.path.join(current_dir, '../models/dwpose/dw-ll_ucoco_384.pth')) 331 checkpoint_file = os.path.abspath(os.path.join(current_dir, '../models/dwpose/dw-ll_ucoco_384.pth'))
@@ -13,14 +13,14 @@ import torch @@ -13,14 +13,14 @@ import torch
13 from tqdm import tqdm 13 from tqdm import tqdm
14 14
15 # initialize the mmpose model 15 # initialize the mmpose model
16 -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 16 +device = torch.device("cuda" if torch.cuda.is_available() else ("mps" if (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) else "cpu"))
17 config_file = './musetalk/utils/dwpose/rtmpose-l_8xb32-270e_coco-ubody-wholebody-384x288.py' 17 config_file = './musetalk/utils/dwpose/rtmpose-l_8xb32-270e_coco-ubody-wholebody-384x288.py'
18 checkpoint_file = './models/dwpose/dw-ll_ucoco_384.pth' 18 checkpoint_file = './models/dwpose/dw-ll_ucoco_384.pth'
19 model = init_model(config_file, checkpoint_file, device=device) 19 model = init_model(config_file, checkpoint_file, device=device)
20 20
21 # initialize the face detection model 21 # initialize the face detection model
22 -device = "cuda" if torch.cuda.is_available() else "cpu"  
23 -fa = FaceAlignment(LandmarksType._2D, flip_input=False,device=device) 22 +device = "cuda" if torch.cuda.is_available() else ("mps" if (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) else "cpu")
  23 +fa = FaceAlignment(LandmarksType._2D, flip_input=False, device=device)
24 24
25 # maker if the bbox is not sufficient 25 # maker if the bbox is not sufficient
26 coord_placeholder = (0.0,0.0,0.0,0.0) 26 coord_placeholder = (0.0,0.0,0.0,0.0)
@@ -91,7 +91,7 @@ def load_model(name: str, device: Optional[Union[str, torch.device]] = None, dow @@ -91,7 +91,7 @@ def load_model(name: str, device: Optional[Union[str, torch.device]] = None, dow
91 """ 91 """
92 92
93 if device is None: 93 if device is None:
94 - device = "cuda" if torch.cuda.is_available() else "cpu" 94 + device = "cuda" if torch.cuda.is_available() else ("mps" if (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) else "cpu")
95 if download_root is None: 95 if download_root is None:
96 download_root = os.getenv( 96 download_root = os.getenv(
97 "XDG_CACHE_HOME", 97 "XDG_CACHE_HOME",
@@ -78,17 +78,19 @@ def transcribe( @@ -78,17 +78,19 @@ def transcribe(
78 if dtype == torch.float16: 78 if dtype == torch.float16:
79 warnings.warn("FP16 is not supported on CPU; using FP32 instead") 79 warnings.warn("FP16 is not supported on CPU; using FP32 instead")
80 dtype = torch.float32 80 dtype = torch.float32
  81 + if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
  82 + warnings.warn("Performing inference on CPU when MPS is available")
81 83
82 if dtype == torch.float32: 84 if dtype == torch.float32:
83 decode_options["fp16"] = False 85 decode_options["fp16"] = False
84 86
85 mel = log_mel_spectrogram(audio) 87 mel = log_mel_spectrogram(audio)
86 - 88 +
87 all_segments = [] 89 all_segments = []
88 def add_segment( 90 def add_segment(
89 *, start: float, end: float, encoder_embeddings 91 *, start: float, end: float, encoder_embeddings
90 ): 92 ):
91 - 93 +
92 all_segments.append( 94 all_segments.append(
93 { 95 {
94 "start": start, 96 "start": start,
@@ -100,20 +102,20 @@ def transcribe( @@ -100,20 +102,20 @@ def transcribe(
100 num_frames = mel.shape[-1] 102 num_frames = mel.shape[-1]
101 seek = 0 103 seek = 0
102 previous_seek_value = seek 104 previous_seek_value = seek
103 - sample_skip = 3000 # 105 + sample_skip = 3000 #
104 with tqdm.tqdm(total=num_frames, unit='frames', disable=verbose is not False) as pbar: 106 with tqdm.tqdm(total=num_frames, unit='frames', disable=verbose is not False) as pbar:
105 while seek < num_frames: 107 while seek < num_frames:
106 # seek是开始的帧数 108 # seek是开始的帧数
107 end_seek = min(seek + sample_skip, num_frames) 109 end_seek = min(seek + sample_skip, num_frames)
108 segment = pad_or_trim(mel[:,seek:seek+sample_skip], N_FRAMES).to(model.device).to(dtype) 110 segment = pad_or_trim(mel[:,seek:seek+sample_skip], N_FRAMES).to(model.device).to(dtype)
109 - 111 +
110 single = segment.ndim == 2 112 single = segment.ndim == 2
111 if single: 113 if single:
112 segment = segment.unsqueeze(0) 114 segment = segment.unsqueeze(0)
113 if dtype == torch.float16: 115 if dtype == torch.float16:
114 segment = segment.half() 116 segment = segment.half()
115 audio_features, embeddings = model.encoder(segment, include_embeddings = True) 117 audio_features, embeddings = model.encoder(segment, include_embeddings = True)
116 - 118 +
117 encoder_embeddings = embeddings 119 encoder_embeddings = embeddings
118 #print(f"encoder_embeddings shape {encoder_embeddings.shape}") 120 #print(f"encoder_embeddings shape {encoder_embeddings.shape}")
119 add_segment( 121 add_segment(
@@ -124,7 +126,7 @@ def transcribe( @@ -124,7 +126,7 @@ def transcribe(
124 encoder_embeddings=encoder_embeddings, 126 encoder_embeddings=encoder_embeddings,
125 ) 127 )
126 seek+=sample_skip 128 seek+=sample_skip
127 - 129 +
128 return dict(segments=all_segments) 130 return dict(segments=all_segments)
129 131
130 132
@@ -135,7 +137,7 @@ def cli(): @@ -135,7 +137,7 @@ def cli():
135 parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe") 137 parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
136 parser.add_argument("--model", default="small", choices=available_models(), help="name of the Whisper model to use") 138 parser.add_argument("--model", default="small", choices=available_models(), help="name of the Whisper model to use")
137 parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default") 139 parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
138 - parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference") 140 + parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "mps", help="device to use for PyTorch inference")
139 parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs") 141 parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
140 parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages") 142 parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
141 143
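
The transcribe() changes above keep the original fixed-window walk over the mel spectrogram: the loop advances `seek` by `sample_skip = 3000` frames at a time, pads or trims each slice to `N_FRAMES`, and stores per-segment encoder embeddings instead of decoded text. A minimal sketch of that windowing arithmetic, assuming Whisper's standard mel settings (16 kHz audio, hop length 160, hence 100 frames per second and 3000 frames per 30 s window); the helper below is illustrative and not part of the repo:

```python
# Sketch of the fixed-window segmentation used by the modified transcribe().
import numpy as np

SAMPLE_RATE = 16000
HOP_LENGTH = 160
FRAMES_PER_SECOND = SAMPLE_RATE // HOP_LENGTH  # 100 mel frames per second
N_FRAMES = 3000                                # one 30 s Whisper window

def iter_windows(mel: np.ndarray, window: int = N_FRAMES):
    """Yield (start_sec, end_sec, padded_segment) for each fixed window."""
    num_frames = mel.shape[-1]
    seek = 0
    while seek < num_frames:
        end_seek = min(seek + window, num_frames)
        segment = mel[:, seek:end_seek]
        if segment.shape[-1] < window:  # pad_or_trim equivalent for the tail
            pad = window - segment.shape[-1]
            segment = np.pad(segment, ((0, 0), (0, pad)))
        yield seek / FRAMES_PER_SECOND, end_seek / FRAMES_PER_SECOND, segment
        seek += window

# mel = np.random.randn(80, 7500)   # e.g. 75 s of audio -> 3 windows
# for start, end, seg in iter_windows(mel):
#     ...  # run the encoder on `seg` and keep the embeddings
```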
@@ -30,7 +30,7 @@ class NerfASR(BaseASR): @@ -30,7 +30,7 @@ class NerfASR(BaseASR):
30 def __init__(self, opt, parent, audio_processor,audio_model): 30 def __init__(self, opt, parent, audio_processor,audio_model):
31 super().__init__(opt,parent) 31 super().__init__(opt,parent)
32 32
33 - self.device = 'cuda' if torch.cuda.is_available() else 'cpu' 33 + self.device = "cuda" if torch.cuda.is_available() else ("mps" if (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) else "cpu")
34 if 'esperanto' in self.opt.asr_model: 34 if 'esperanto' in self.opt.asr_model:
35 self.audio_dim = 44 35 self.audio_dim = 44
36 elif 'deepspeech' in self.opt.asr_model: 36 elif 'deepspeech' in self.opt.asr_model:
@@ -77,7 +77,7 @@ def load_model(opt): @@ -77,7 +77,7 @@ def load_model(opt):
77 seed_everything(opt.seed) 77 seed_everything(opt.seed)
78 logger.info(opt) 78 logger.info(opt)
79 79
80 - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 80 + device = torch.device('cuda' if torch.cuda.is_available() else ('mps' if (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) else 'cpu'))
81 model = NeRFNetwork(opt) 81 model = NeRFNetwork(opt)
82 82
83 criterion = torch.nn.MSELoss(reduction='none') 83 criterion = torch.nn.MSELoss(reduction='none')
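
The device changes in the hunks above all encode the same preference order: CUDA if present, otherwise Apple's Metal (mps) backend, otherwise CPU. A small helper capturing that fallback, as a sketch (the function name is not from the repo); note that the `--device` default in the whisper CLI hunk falls back straight to `mps`, so on a Linux machine without CUDA you would still pass `--device cpu` explicitly:

```python
# Sketch of the cuda -> mps -> cpu preference used across the edited files.
import torch

def pick_device() -> torch.device:
    if torch.cuda.is_available():
        return torch.device("cuda")
    # Guard with hasattr so older torch builds without the mps backend still work.
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")

# device = pick_device()
# model = model.to(device)
```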
@@ -90,7 +90,7 @@ class BaseTTS: @@ -90,7 +90,7 @@ class BaseTTS:
90 ########################################################################################### 90 ###########################################################################################
91 class EdgeTTS(BaseTTS): 91 class EdgeTTS(BaseTTS):
92 def txt_to_audio(self,msg): 92 def txt_to_audio(self,msg):
93 - voicename = "zh-CN-XiaoxiaoNeural" 93 + voicename = "zh-CN-YunxiaNeural"
94 text,textevent = msg 94 text,textevent = msg
95 t = time.time() 95 t = time.time()
96 asyncio.new_event_loop().run_until_complete(self.__main(voicename,text)) 96 asyncio.new_event_loop().run_until_complete(self.__main(voicename,text))
@@ -98,7 +98,7 @@ class EdgeTTS(BaseTTS): @@ -98,7 +98,7 @@ class EdgeTTS(BaseTTS):
98 if self.input_stream.getbuffer().nbytes<=0: #edgetts err 98 if self.input_stream.getbuffer().nbytes<=0: #edgetts err
99 logger.error('edgetts err!!!!!') 99 logger.error('edgetts err!!!!!')
100 return 100 return
101 - 101 +
102 self.input_stream.seek(0) 102 self.input_stream.seek(0)
103 stream = self.__create_bytes_stream(self.input_stream) 103 stream = self.__create_bytes_stream(self.input_stream)
104 streamlen = stream.shape[0] 104 streamlen = stream.shape[0]
@@ -107,15 +107,15 @@ class EdgeTTS(BaseTTS): @@ -107,15 +107,15 @@ class EdgeTTS(BaseTTS):
107 eventpoint=None 107 eventpoint=None
108 streamlen -= self.chunk 108 streamlen -= self.chunk
109 if idx==0: 109 if idx==0:
110 - eventpoint={'status':'start','text':text,'msgenvent':textevent} 110 + eventpoint={'status':'start','text':text,'msgevent':textevent}
111 elif streamlen<self.chunk: 111 elif streamlen<self.chunk:
112 - eventpoint={'status':'end','text':text,'msgenvent':textevent} 112 + eventpoint={'status':'end','text':text,'msgevent':textevent}
113 self.parent.put_audio_frame(stream[idx:idx+self.chunk],eventpoint) 113 self.parent.put_audio_frame(stream[idx:idx+self.chunk],eventpoint)
114 idx += self.chunk 114 idx += self.chunk
115 #if streamlen>0: #skip last frame(not 20ms) 115 #if streamlen>0: #skip last frame(not 20ms)
116 # self.queue.put(stream[idx:]) 116 # self.queue.put(stream[idx:])
117 self.input_stream.seek(0) 117 self.input_stream.seek(0)
118 - self.input_stream.truncate() 118 + self.input_stream.truncate()
119 119
120 def __create_bytes_stream(self,byte_stream): 120 def __create_bytes_stream(self,byte_stream):
121 #byte_stream=BytesIO(buffer) 121 #byte_stream=BytesIO(buffer)
@@ -126,13 +126,13 @@ class EdgeTTS(BaseTTS): @@ -126,13 +126,13 @@ class EdgeTTS(BaseTTS):
126 if stream.ndim > 1: 126 if stream.ndim > 1:
127 logger.info(f'[WARN] audio has {stream.shape[1]} channels, only use the first.') 127 logger.info(f'[WARN] audio has {stream.shape[1]} channels, only use the first.')
128 stream = stream[:, 0] 128 stream = stream[:, 0]
129 - 129 +
130 if sample_rate != self.sample_rate and stream.shape[0]>0: 130 if sample_rate != self.sample_rate and stream.shape[0]>0:
131 logger.info(f'[WARN] audio sample rate is {sample_rate}, resampling into {self.sample_rate}.') 131 logger.info(f'[WARN] audio sample rate is {sample_rate}, resampling into {self.sample_rate}.')
132 stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self.sample_rate) 132 stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self.sample_rate)
133 133
134 return stream 134 return stream
135 - 135 +
136 async def __main(self,voicename: str, text: str): 136 async def __main(self,voicename: str, text: str):
137 try: 137 try:
138 communicate = edge_tts.Communicate(text, voicename) 138 communicate = edge_tts.Communicate(text, voicename)
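
The EdgeTTS edits swap the default voice to `zh-CN-YunxiaNeural` and fix the event-key typo; the `__main` coroutine shown above still just streams MP3 chunks from edge-tts into `self.input_stream`. A standalone sketch of that streaming call, assuming the public `edge-tts` package API (Communicate.stream() yields dicts whose `"audio"` entries carry MP3 bytes):

```python
# Sketch: collect the MP3 byte stream for one utterance with edge-tts.
import asyncio
import edge_tts

async def synth(text: str, voice: str = "zh-CN-YunxiaNeural") -> bytes:
    communicate = edge_tts.Communicate(text, voice)
    audio = bytearray()
    async for chunk in communicate.stream():
        if chunk["type"] == "audio":
            audio.extend(chunk["data"])
    return bytes(audio)

# mp3_bytes = asyncio.run(synth("你好"))
```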
@@ -153,12 +153,12 @@ class EdgeTTS(BaseTTS): @@ -153,12 +153,12 @@ class EdgeTTS(BaseTTS):
153 153
154 ########################################################################################### 154 ###########################################################################################
155 class FishTTS(BaseTTS): 155 class FishTTS(BaseTTS):
156 - def txt_to_audio(self,msg): 156 + def txt_to_audio(self,msg):
157 text,textevent = msg 157 text,textevent = msg
158 self.stream_tts( 158 self.stream_tts(
159 self.fish_speech( 159 self.fish_speech(
160 text, 160 text,
161 - self.opt.REF_FILE, 161 + self.opt.REF_FILE,
162 self.opt.REF_TEXT, 162 self.opt.REF_TEXT,
163 "zh", #en args.language, 163 "zh", #en args.language,
164 self.opt.TTS_SERVER, #"http://127.0.0.1:5000", #args.server_url, 164 self.opt.TTS_SERVER, #"http://127.0.0.1:5000", #args.server_url,
@@ -190,9 +190,9 @@ class FishTTS(BaseTTS): @@ -190,9 +190,9 @@ class FishTTS(BaseTTS):
190 if res.status_code != 200: 190 if res.status_code != 200:
191 logger.error("Error:%s", res.text) 191 logger.error("Error:%s", res.text)
192 return 192 return
193 - 193 +
194 first = True 194 first = True
195 - 195 +
196 for chunk in res.iter_content(chunk_size=17640): # 1764 44100*20ms*2 196 for chunk in res.iter_content(chunk_size=17640): # 1764 44100*20ms*2
197 #print('chunk len:',len(chunk)) 197 #print('chunk len:',len(chunk))
198 if first: 198 if first:
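
The `chunk_size` values used by the different TTS backends follow the formula spelled out in their inline comments: one 20 ms frame of 16-bit mono PCM is `sample_rate * 0.02 * 2` bytes, and the HTTP reads pull ten such frames at a time. A quick check of the numbers that appear in these hunks:

```python
def bytes_per_20ms(sample_rate: int) -> int:
    """Bytes in one 20 ms frame of 16-bit (2-byte) mono PCM."""
    return int(sample_rate * 0.02 * 2)

assert bytes_per_20ms(44100) == 1764   # FishTTS reads chunk_size=17640 (10 frames)
assert bytes_per_20ms(24000) == 960    # CosyVoiceTTS / XTTS read chunk_size=9600
assert bytes_per_20ms(16000) == 640    # TencentTTS reads chunk_size=6400
```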
@@ -209,7 +209,7 @@ class FishTTS(BaseTTS): @@ -209,7 +209,7 @@ class FishTTS(BaseTTS):
209 text,textevent = msg 209 text,textevent = msg
210 first = True 210 first = True
211 for chunk in audio_stream: 211 for chunk in audio_stream:
212 - if chunk is not None and len(chunk)>0: 212 + if chunk is not None and len(chunk)>0:
213 stream = np.frombuffer(chunk, dtype=np.int16).astype(np.float32) / 32767 213 stream = np.frombuffer(chunk, dtype=np.int16).astype(np.float32) / 32767
214 stream = resampy.resample(x=stream, sr_orig=44100, sr_new=self.sample_rate) 214 stream = resampy.resample(x=stream, sr_orig=44100, sr_new=self.sample_rate)
215 #byte_stream=BytesIO(buffer) 215 #byte_stream=BytesIO(buffer)
@@ -219,22 +219,22 @@ class FishTTS(BaseTTS): @@ -219,22 +219,22 @@ class FishTTS(BaseTTS):
219 while streamlen >= self.chunk: 219 while streamlen >= self.chunk:
220 eventpoint=None 220 eventpoint=None
221 if first: 221 if first:
222 - eventpoint={'status':'start','text':text,'msgenvent':textevent} 222 + eventpoint={'status':'start','text':text,'msgevent':textevent}
223 first = False 223 first = False
224 self.parent.put_audio_frame(stream[idx:idx+self.chunk],eventpoint) 224 self.parent.put_audio_frame(stream[idx:idx+self.chunk],eventpoint)
225 streamlen -= self.chunk 225 streamlen -= self.chunk
226 idx += self.chunk 226 idx += self.chunk
227 - eventpoint={'status':'end','text':text,'msgenvent':textevent} 227 + eventpoint={'status':'end','text':text,'msgevent':textevent}
228 - self.parent.put_audio_frame(np.zeros(self.chunk,np.float32),eventpoint) 228 + self.parent.put_audio_frame(np.zeros(self.chunk,np.float32),eventpoint)
229 229
230 ########################################################################################### 230 ###########################################################################################
231 -class VoitsTTS(BaseTTS): 231 +class SovitsTTS(BaseTTS):
232 - def txt_to_audio(self,msg): 232 + def txt_to_audio(self,msg):
233 text,textevent = msg 233 text,textevent = msg
234 self.stream_tts( 234 self.stream_tts(
235 self.gpt_sovits( 235 self.gpt_sovits(
236 text, 236 text,
237 - self.opt.REF_FILE, 237 + self.opt.REF_FILE,
238 self.opt.REF_TEXT, 238 self.opt.REF_TEXT,
239 "zh", #en args.language, 239 "zh", #en args.language,
240 self.opt.TTS_SERVER, #"http://127.0.0.1:5000", #args.server_url, 240 self.opt.TTS_SERVER, #"http://127.0.0.1:5000", #args.server_url,
@@ -271,9 +271,9 @@ class VoitsTTS(BaseTTS): @@ -271,9 +271,9 @@ class VoitsTTS(BaseTTS):
271 if res.status_code != 200: 271 if res.status_code != 200:
272 logger.error("Error:%s", res.text) 272 logger.error("Error:%s", res.text)
273 return 273 return
274 - 274 +
275 first = True 275 first = True
276 - 276 +
277 for chunk in res.iter_content(chunk_size=None): #12800 1280 32K*20ms*2 277 for chunk in res.iter_content(chunk_size=None): #12800 1280 32K*20ms*2
278 logger.info('chunk len:%d',len(chunk)) 278 logger.info('chunk len:%d',len(chunk))
279 if first: 279 if first:
@@ -295,7 +295,7 @@ class VoitsTTS(BaseTTS): @@ -295,7 +295,7 @@ class VoitsTTS(BaseTTS):
295 if stream.ndim > 1: 295 if stream.ndim > 1:
296 logger.info(f'[WARN] audio has {stream.shape[1]} channels, only use the first.') 296 logger.info(f'[WARN] audio has {stream.shape[1]} channels, only use the first.')
297 stream = stream[:, 0] 297 stream = stream[:, 0]
298 - 298 +
299 if sample_rate != self.sample_rate and stream.shape[0]>0: 299 if sample_rate != self.sample_rate and stream.shape[0]>0:
300 logger.info(f'[WARN] audio sample rate is {sample_rate}, resampling into {self.sample_rate}.') 300 logger.info(f'[WARN] audio sample rate is {sample_rate}, resampling into {self.sample_rate}.')
301 stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self.sample_rate) 301 stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self.sample_rate)
@@ -306,7 +306,7 @@ class VoitsTTS(BaseTTS): @@ -306,7 +306,7 @@ class VoitsTTS(BaseTTS):
306 text,textevent = msg 306 text,textevent = msg
307 first = True 307 first = True
308 for chunk in audio_stream: 308 for chunk in audio_stream:
309 - if chunk is not None and len(chunk)>0: 309 + if chunk is not None and len(chunk)>0:
310 #stream = np.frombuffer(chunk, dtype=np.int16).astype(np.float32) / 32767 310 #stream = np.frombuffer(chunk, dtype=np.int16).astype(np.float32) / 32767
311 #stream = resampy.resample(x=stream, sr_orig=32000, sr_new=self.sample_rate) 311 #stream = resampy.resample(x=stream, sr_orig=32000, sr_new=self.sample_rate)
312 byte_stream=BytesIO(chunk) 312 byte_stream=BytesIO(chunk)
@@ -316,22 +316,22 @@ class VoitsTTS(BaseTTS): @@ -316,22 +316,22 @@ class VoitsTTS(BaseTTS):
316 while streamlen >= self.chunk: 316 while streamlen >= self.chunk:
317 eventpoint=None 317 eventpoint=None
318 if first: 318 if first:
319 - eventpoint={'status':'start','text':text,'msgenvent':textevent} 319 + eventpoint={'status':'start','text':text,'msgevent':textevent}
320 first = False 320 first = False
321 self.parent.put_audio_frame(stream[idx:idx+self.chunk],eventpoint) 321 self.parent.put_audio_frame(stream[idx:idx+self.chunk],eventpoint)
322 streamlen -= self.chunk 322 streamlen -= self.chunk
323 idx += self.chunk 323 idx += self.chunk
324 - eventpoint={'status':'end','text':text,'msgenvent':textevent} 324 + eventpoint={'status':'end','text':text,'msgevent':textevent}
325 self.parent.put_audio_frame(np.zeros(self.chunk,np.float32),eventpoint) 325 self.parent.put_audio_frame(np.zeros(self.chunk,np.float32),eventpoint)
326 326
327 ########################################################################################### 327 ###########################################################################################
328 class CosyVoiceTTS(BaseTTS): 328 class CosyVoiceTTS(BaseTTS):
329 def txt_to_audio(self,msg): 329 def txt_to_audio(self,msg):
330 - text,textevent = msg 330 + text,textevent = msg
331 self.stream_tts( 331 self.stream_tts(
332 self.cosy_voice( 332 self.cosy_voice(
333 text, 333 text,
334 - self.opt.REF_FILE, 334 + self.opt.REF_FILE,
335 self.opt.REF_TEXT, 335 self.opt.REF_TEXT,
336 "zh", #en args.language, 336 "zh", #en args.language,
337 self.opt.TTS_SERVER, #"http://127.0.0.1:5000", #args.server_url, 337 self.opt.TTS_SERVER, #"http://127.0.0.1:5000", #args.server_url,
@@ -348,16 +348,16 @@ class CosyVoiceTTS(BaseTTS): @@ -348,16 +348,16 @@ class CosyVoiceTTS(BaseTTS):
348 try: 348 try:
349 files = [('prompt_wav', ('prompt_wav', open(reffile, 'rb'), 'application/octet-stream'))] 349 files = [('prompt_wav', ('prompt_wav', open(reffile, 'rb'), 'application/octet-stream'))]
350 res = requests.request("GET", f"{server_url}/inference_zero_shot", data=payload, files=files, stream=True) 350 res = requests.request("GET", f"{server_url}/inference_zero_shot", data=payload, files=files, stream=True)
351 - 351 +
352 end = time.perf_counter() 352 end = time.perf_counter()
353 logger.info(f"cosy_voice Time to make POST: {end-start}s") 353 logger.info(f"cosy_voice Time to make POST: {end-start}s")
354 354
355 if res.status_code != 200: 355 if res.status_code != 200:
356 logger.error("Error:%s", res.text) 356 logger.error("Error:%s", res.text)
357 return 357 return
358 - 358 +
359 first = True 359 first = True
360 - 360 +
361 for chunk in res.iter_content(chunk_size=9600): # 960 24K*20ms*2 361 for chunk in res.iter_content(chunk_size=9600): # 960 24K*20ms*2
362 if first: 362 if first:
363 end = time.perf_counter() 363 end = time.perf_counter()
@@ -372,7 +372,7 @@ class CosyVoiceTTS(BaseTTS): @@ -372,7 +372,7 @@ class CosyVoiceTTS(BaseTTS):
372 text,textevent = msg 372 text,textevent = msg
373 first = True 373 first = True
374 for chunk in audio_stream: 374 for chunk in audio_stream:
375 - if chunk is not None and len(chunk)>0: 375 + if chunk is not None and len(chunk)>0:
376 stream = np.frombuffer(chunk, dtype=np.int16).astype(np.float32) / 32767 376 stream = np.frombuffer(chunk, dtype=np.int16).astype(np.float32) / 32767
377 stream = resampy.resample(x=stream, sr_orig=24000, sr_new=self.sample_rate) 377 stream = resampy.resample(x=stream, sr_orig=24000, sr_new=self.sample_rate)
378 #byte_stream=BytesIO(buffer) 378 #byte_stream=BytesIO(buffer)
@@ -382,13 +382,13 @@ class CosyVoiceTTS(BaseTTS): @@ -382,13 +382,13 @@ class CosyVoiceTTS(BaseTTS):
382 while streamlen >= self.chunk: 382 while streamlen >= self.chunk:
383 eventpoint=None 383 eventpoint=None
384 if first: 384 if first:
385 - eventpoint={'status':'start','text':text,'msgenvent':textevent} 385 + eventpoint={'status':'start','text':text,'msgevent':textevent}
386 first = False 386 first = False
387 self.parent.put_audio_frame(stream[idx:idx+self.chunk],eventpoint) 387 self.parent.put_audio_frame(stream[idx:idx+self.chunk],eventpoint)
388 streamlen -= self.chunk 388 streamlen -= self.chunk
389 idx += self.chunk 389 idx += self.chunk
390 - eventpoint={'status':'end','text':text,'msgenvent':textevent} 390 + eventpoint={'status':'end','text':text,'msgevent':textevent}
391 - self.parent.put_audio_frame(np.zeros(self.chunk,np.float32),eventpoint) 391 + self.parent.put_audio_frame(np.zeros(self.chunk,np.float32),eventpoint)
392 392
393 ########################################################################################### 393 ###########################################################################################
394 _PROTOCOL = "https://" 394 _PROTOCOL = "https://"
@@ -407,7 +407,7 @@ class TencentTTS(BaseTTS): @@ -407,7 +407,7 @@ class TencentTTS(BaseTTS):
407 self.sample_rate = 16000 407 self.sample_rate = 16000
408 self.volume = 0 408 self.volume = 0
409 self.speed = 0 409 self.speed = 0
410 - 410 +
411 def __gen_signature(self, params): 411 def __gen_signature(self, params):
412 sort_dict = sorted(params.keys()) 412 sort_dict = sorted(params.keys())
413 sign_str = "POST" + _HOST + _PATH + "?" 413 sign_str = "POST" + _HOST + _PATH + "?"
@@ -440,11 +440,11 @@ class TencentTTS(BaseTTS): @@ -440,11 +440,11 @@ class TencentTTS(BaseTTS):
440 return params 440 return params
441 441
442 def txt_to_audio(self,msg): 442 def txt_to_audio(self,msg):
443 - text,textevent = msg 443 + text,textevent = msg
444 self.stream_tts( 444 self.stream_tts(
445 self.tencent_voice( 445 self.tencent_voice(
446 text, 446 text,
447 - self.opt.REF_FILE, 447 + self.opt.REF_FILE,
448 self.opt.REF_TEXT, 448 self.opt.REF_TEXT,
449 "zh", #en args.language, 449 "zh", #en args.language,
450 self.opt.TTS_SERVER, #"http://127.0.0.1:5000", #args.server_url, 450 self.opt.TTS_SERVER, #"http://127.0.0.1:5000", #args.server_url,
@@ -465,12 +465,12 @@ class TencentTTS(BaseTTS): @@ -465,12 +465,12 @@ class TencentTTS(BaseTTS):
465 try: 465 try:
466 res = requests.post(url, headers=headers, 466 res = requests.post(url, headers=headers,
467 data=json.dumps(params), stream=True) 467 data=json.dumps(params), stream=True)
468 - 468 +
469 end = time.perf_counter() 469 end = time.perf_counter()
470 logger.info(f"tencent Time to make POST: {end-start}s") 470 logger.info(f"tencent Time to make POST: {end-start}s")
471 - 471 +
472 first = True 472 first = True
473 - 473 +
474 for chunk in res.iter_content(chunk_size=6400): # 640 16K*20ms*2 474 for chunk in res.iter_content(chunk_size=6400): # 640 16K*20ms*2
475 #logger.info('chunk len:%d',len(chunk)) 475 #logger.info('chunk len:%d',len(chunk))
476 if first: 476 if first:
@@ -483,7 +483,7 @@ class TencentTTS(BaseTTS): @@ -483,7 +483,7 @@ class TencentTTS(BaseTTS):
483 except: 483 except:
484 end = time.perf_counter() 484 end = time.perf_counter()
485 logger.info(f"tencent Time to first chunk: {end-start}s") 485 logger.info(f"tencent Time to first chunk: {end-start}s")
486 - first = False 486 + first = False
487 if chunk and self.state==State.RUNNING: 487 if chunk and self.state==State.RUNNING:
488 yield chunk 488 yield chunk
489 except Exception as e: 489 except Exception as e:
@@ -494,7 +494,7 @@ class TencentTTS(BaseTTS): @@ -494,7 +494,7 @@ class TencentTTS(BaseTTS):
494 first = True 494 first = True
495 last_stream = np.array([],dtype=np.float32) 495 last_stream = np.array([],dtype=np.float32)
496 for chunk in audio_stream: 496 for chunk in audio_stream:
497 - if chunk is not None and len(chunk)>0: 497 + if chunk is not None and len(chunk)>0:
498 stream = np.frombuffer(chunk, dtype=np.int16).astype(np.float32) / 32767 498 stream = np.frombuffer(chunk, dtype=np.int16).astype(np.float32) / 32767
499 stream = np.concatenate((last_stream,stream)) 499 stream = np.concatenate((last_stream,stream))
500 #stream = resampy.resample(x=stream, sr_orig=24000, sr_new=self.sample_rate) 500 #stream = resampy.resample(x=stream, sr_orig=24000, sr_new=self.sample_rate)
@@ -505,14 +505,14 @@ class TencentTTS(BaseTTS): @@ -505,14 +505,14 @@ class TencentTTS(BaseTTS):
505 while streamlen >= self.chunk: 505 while streamlen >= self.chunk:
506 eventpoint=None 506 eventpoint=None
507 if first: 507 if first:
508 - eventpoint={'status':'start','text':text,'msgenvent':textevent} 508 + eventpoint={'status':'start','text':text,'msgevent':textevent}
509 first = False 509 first = False
510 self.parent.put_audio_frame(stream[idx:idx+self.chunk],eventpoint) 510 self.parent.put_audio_frame(stream[idx:idx+self.chunk],eventpoint)
511 streamlen -= self.chunk 511 streamlen -= self.chunk
512 idx += self.chunk 512 idx += self.chunk
513 last_stream = stream[idx:] #get the remain stream 513 last_stream = stream[idx:] #get the remain stream
514 - eventpoint={'status':'end','text':text,'msgenvent':textevent} 514 + eventpoint={'status':'end','text':text,'msgevent':textevent}
515 - self.parent.put_audio_frame(np.zeros(self.chunk,np.float32),eventpoint) 515 + self.parent.put_audio_frame(np.zeros(self.chunk,np.float32),eventpoint)
516 516
517 ########################################################################################### 517 ###########################################################################################
518 518
@@ -522,7 +522,7 @@ class XTTS(BaseTTS): @@ -522,7 +522,7 @@ class XTTS(BaseTTS):
522 self.speaker = self.get_speaker(opt.REF_FILE, opt.TTS_SERVER) 522 self.speaker = self.get_speaker(opt.REF_FILE, opt.TTS_SERVER)
523 523
524 def txt_to_audio(self,msg): 524 def txt_to_audio(self,msg):
525 - text,textevent = msg 525 + text,textevent = msg
526 self.stream_tts( 526 self.stream_tts(
527 self.xtts( 527 self.xtts(
528 text, 528 text,
@@ -558,7 +558,7 @@ class XTTS(BaseTTS): @@ -558,7 +558,7 @@ class XTTS(BaseTTS):
558 return 558 return
559 559
560 first = True 560 first = True
561 - 561 +
562 for chunk in res.iter_content(chunk_size=9600): #24K*20ms*2 562 for chunk in res.iter_content(chunk_size=9600): #24K*20ms*2
563 if first: 563 if first:
564 end = time.perf_counter() 564 end = time.perf_counter()
@@ -568,12 +568,12 @@ class XTTS(BaseTTS): @@ -568,12 +568,12 @@ class XTTS(BaseTTS):
568 yield chunk 568 yield chunk
569 except Exception as e: 569 except Exception as e:
570 print(e) 570 print(e)
571 - 571 +
572 def stream_tts(self,audio_stream,msg): 572 def stream_tts(self,audio_stream,msg):
573 text,textevent = msg 573 text,textevent = msg
574 first = True 574 first = True
575 for chunk in audio_stream: 575 for chunk in audio_stream:
576 - if chunk is not None and len(chunk)>0: 576 + if chunk is not None and len(chunk)>0:
577 stream = np.frombuffer(chunk, dtype=np.int16).astype(np.float32) / 32767 577 stream = np.frombuffer(chunk, dtype=np.int16).astype(np.float32) / 32767
578 stream = resampy.resample(x=stream, sr_orig=24000, sr_new=self.sample_rate) 578 stream = resampy.resample(x=stream, sr_orig=24000, sr_new=self.sample_rate)
579 #byte_stream=BytesIO(buffer) 579 #byte_stream=BytesIO(buffer)
@@ -583,10 +583,10 @@ class XTTS(BaseTTS): @@ -583,10 +583,10 @@ class XTTS(BaseTTS):
583 while streamlen >= self.chunk: 583 while streamlen >= self.chunk:
584 eventpoint=None 584 eventpoint=None
585 if first: 585 if first:
586 - eventpoint={'status':'start','text':text,'msgenvent':textevent} 586 + eventpoint={'status':'start','text':text,'msgevent':textevent}
587 first = False 587 first = False
588 self.parent.put_audio_frame(stream[idx:idx+self.chunk],eventpoint) 588 self.parent.put_audio_frame(stream[idx:idx+self.chunk],eventpoint)
589 streamlen -= self.chunk 589 streamlen -= self.chunk
590 idx += self.chunk 590 idx += self.chunk
591 - eventpoint={'status':'end','text':text,'msgenvent':textevent} 591 + eventpoint={'status':'end','text':text,'msgevent':textevent}
592 self.parent.put_audio_frame(np.zeros(self.chunk,np.float32),eventpoint) 592 self.parent.put_audio_frame(np.zeros(self.chunk,np.float32),eventpoint)
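
Every `stream_tts` variant touched above follows the same framing pattern, and the edits mainly rename the event key from `msgenvent` to `msgevent` (plus whitespace): the first emitted chunk carries a `start` event, later chunks carry none, and a trailing silent frame carries the `end` event. A condensed sketch of that pattern, with `put_audio_frame` standing in for `self.parent.put_audio_frame`:

```python
# Sketch of the shared start/end event framing used by the TTS stream_tts methods.
import numpy as np

CHUNK = 320  # one 20 ms frame at 16 kHz

def frame_stream(stream: np.ndarray, text: str, textevent, put_audio_frame):
    first = True
    idx, streamlen = 0, stream.shape[0]
    while streamlen >= CHUNK:
        eventpoint = None
        if first:
            eventpoint = {'status': 'start', 'text': text, 'msgevent': textevent}
            first = False
        put_audio_frame(stream[idx:idx + CHUNK], eventpoint)
        streamlen -= CHUNK
        idx += CHUNK
    # A final silent frame carries the 'end' event.
    eventpoint = {'status': 'end', 'text': text, 'msgevent': textevent}
    put_audio_frame(np.zeros(CHUNK, np.float32), eventpoint)
```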
@@ -236,7 +236,7 @@ if __name__ == '__main__': @@ -236,7 +236,7 @@ if __name__ == '__main__':
236 if hasattr(module, 'reparameterize'): 236 if hasattr(module, 'reparameterize'):
237 module.reparameterize() 237 module.reparameterize()
238 return model 238 return model
239 - device = torch.device("cuda") 239 + device = torch.device("cuda" if torch.cuda.is_available() else ("mps" if (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) else "cpu"))
240 def check_onnx(torch_out, torch_in, audio): 240 def check_onnx(torch_out, torch_in, audio):
241 onnx_model = onnx.load(onnx_path) 241 onnx_model = onnx.load(onnx_path)
242 onnx.checker.check_model(onnx_model) 242 onnx.checker.check_model(onnx_model)
  1 +<!DOCTYPE html>
  2 +<html lang="zh-CN">
  3 +<head>
  4 + <meta charset="UTF-8">
  5 + <meta name="viewport" content="width=device-width, initial-scale=1.0">
  6 + <title>livetalking数字人交互平台</title>
  7 + <link href="https://cdn.jsdelivr.net/npm/bootstrap@5.3.0/dist/css/bootstrap.min.css" rel="stylesheet">
  8 + <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/bootstrap-icons@1.10.0/font/bootstrap-icons.css">
  9 + <style>
  10 + :root {
  11 + --primary-color: #4361ee;
  12 + --secondary-color: #3f37c9;
  13 + --accent-color: #4895ef;
  14 + --background-color: #f8f9fa;
  15 + --card-bg: #ffffff;
  16 + --text-color: #212529;
  17 + --border-radius: 10px;
  18 + --box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
  19 + }
  20 +
  21 + body {
  22 + font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
  23 + background-color: var(--background-color);
  24 + color: var(--text-color);
  25 + min-height: 100vh;
  26 + padding-top: 20px;
  27 + }
  28 +
  29 + .dashboard-container {
  30 + max-width: 1400px;
  31 + margin: 0 auto;
  32 + padding: 20px;
  33 + }
  34 +
  35 + .card {
  36 + background-color: var(--card-bg);
  37 + border-radius: var(--border-radius);
  38 + box-shadow: var(--box-shadow);
  39 + border: none;
  40 + margin-bottom: 20px;
  41 + overflow: hidden;
  42 + }
  43 +
  44 + .card-header {
  45 + background-color: var(--primary-color);
  46 + color: white;
  47 + font-weight: 600;
  48 + padding: 15px 20px;
  49 + border-bottom: none;
  50 + }
  51 +
  52 + .video-container {
  53 + position: relative;
  54 + width: 100%;
  55 + background-color: #000;
  56 + border-radius: var(--border-radius);
  57 + overflow: hidden;
  58 + display: flex;
  59 + justify-content: center;
  60 + align-items: center;
  61 + }
  62 +
  63 + video {
  64 + max-width: 100%;
  65 + max-height: 100%;
  66 + display: block;
  67 + border-radius: var(--border-radius);
  68 + }
  69 +
  70 + .controls-container {
  71 + padding: 20px;
  72 + }
  73 +
  74 + .btn-primary {
  75 + background-color: var(--primary-color);
  76 + border-color: var(--primary-color);
  77 + }
  78 +
  79 + .btn-primary:hover {
  80 + background-color: var(--secondary-color);
  81 + border-color: var(--secondary-color);
  82 + }
  83 +
  84 + .btn-outline-primary {
  85 + color: var(--primary-color);
  86 + border-color: var(--primary-color);
  87 + }
  88 +
  89 + .btn-outline-primary:hover {
  90 + background-color: var(--primary-color);
  91 + color: white;
  92 + }
  93 +
  94 + .form-control {
  95 + border-radius: var(--border-radius);
  96 + padding: 10px 15px;
  97 + border: 1px solid #ced4da;
  98 + }
  99 +
  100 + .form-control:focus {
  101 + border-color: var(--accent-color);
  102 + box-shadow: 0 0 0 0.25rem rgba(67, 97, 238, 0.25);
  103 + }
  104 +
  105 + .status-indicator {
  106 + width: 10px;
  107 + height: 10px;
  108 + border-radius: 50%;
  109 + display: inline-block;
  110 + margin-right: 5px;
  111 + }
  112 +
  113 + .status-connected {
  114 + background-color: #28a745;
  115 + }
  116 +
  117 + .status-disconnected {
  118 + background-color: #dc3545;
  119 + }
  120 +
  121 + .status-connecting {
  122 + background-color: #ffc107;
  123 + }
  124 +
  125 + .asr-container {
  126 + height: 300px;
  127 + overflow-y: auto;
  128 + padding: 15px;
  129 + background-color: #f8f9fa;
  130 + border-radius: var(--border-radius);
  131 + border: 1px solid #ced4da;
  132 + }
  133 +
  134 + .asr-text {
  135 + margin-bottom: 10px;
  136 + padding: 10px;
  137 + background-color: white;
  138 + border-radius: var(--border-radius);
  139 + box-shadow: 0 1px 3px rgba(0, 0, 0, 0.1);
  140 + }
  141 +
  142 + .user-message {
  143 + background-color: #e3f2fd;
  144 + border-left: 4px solid var(--primary-color);
  145 + }
  146 +
  147 + .system-message {
  148 + background-color: #f1f8e9;
  149 + border-left: 4px solid #8bc34a;
  150 + }
  151 +
  152 + .recording-indicator {
  153 + position: absolute;
  154 + top: 15px;
  155 + right: 15px;
  156 + background-color: rgba(220, 53, 69, 0.8);
  157 + color: white;
  158 + padding: 5px 10px;
  159 + border-radius: 20px;
  160 + font-size: 0.8rem;
  161 + display: none;
  162 + }
  163 +
  164 + .recording-indicator.active {
  165 + display: flex;
  166 + align-items: center;
  167 + }
  168 +
  169 + .recording-indicator .blink {
  170 + width: 10px;
  171 + height: 10px;
  172 + background-color: #fff;
  173 + border-radius: 50%;
  174 + margin-right: 5px;
  175 + animation: blink 1s infinite;
  176 + }
  177 +
  178 + @keyframes blink {
  179 + 0% { opacity: 1; }
  180 + 50% { opacity: 0.3; }
  181 + 100% { opacity: 1; }
  182 + }
  183 +
  184 + .mode-switch {
  185 + margin-bottom: 20px;
  186 + }
  187 +
  188 + .nav-tabs .nav-link {
  189 + color: var(--text-color);
  190 + border: none;
  191 + padding: 10px 20px;
  192 + border-radius: var(--border-radius) var(--border-radius) 0 0;
  193 + }
  194 +
  195 + .nav-tabs .nav-link.active {
  196 + color: var(--primary-color);
  197 + background-color: var(--card-bg);
  198 + border-bottom: 3px solid var(--primary-color);
  199 + font-weight: 600;
  200 + }
  201 +
  202 + .tab-content {
  203 + padding: 20px;
  204 + background-color: var(--card-bg);
  205 + border-radius: 0 0 var(--border-radius) var(--border-radius);
  206 + }
  207 +
  208 + .settings-panel {
  209 + padding: 15px;
  210 + background-color: #f8f9fa;
  211 + border-radius: var(--border-radius);
  212 + margin-top: 15px;
  213 + }
  214 +
  215 + .footer {
  216 + text-align: center;
  217 + margin-top: 30px;
  218 + padding: 20px 0;
  219 + color: #6c757d;
  220 + font-size: 0.9rem;
  221 + }
  222 +
  223 + .voice-record-btn {
  224 + width: 60px;
  225 + height: 60px;
  226 + border-radius: 50%;
  227 + background-color: var(--primary-color);
  228 + color: white;
  229 + display: flex;
  230 + justify-content: center;
  231 + align-items: center;
  232 + cursor: pointer;
  233 + transition: all 0.2s ease;
  234 + box-shadow: 0 2px 5px rgba(0,0,0,0.2);
  235 + margin: 0 auto;
  236 + }
  237 +
  238 + .voice-record-btn:hover {
  239 + background-color: var(--secondary-color);
  240 + transform: scale(1.05);
  241 + }
  242 +
  243 + .voice-record-btn:active {
  244 + background-color: #dc3545;
  245 + transform: scale(0.95);
  246 + }
  247 +
  248 + .voice-record-btn i {
  249 + font-size: 24px;
  250 + }
  251 +
  252 + .voice-record-label {
  253 + text-align: center;
  254 + margin-top: 10px;
  255 + font-size: 14px;
  256 + color: #6c757d;
  257 + }
  258 +
  259 + .video-size-control {
  260 + margin-top: 15px;
  261 + }
  262 +
  263 + .recording-pulse {
  264 + animation: pulse 1.5s infinite;
  265 + }
  266 +
  267 + @keyframes pulse {
  268 + 0% {
  269 + box-shadow: 0 0 0 0 rgba(220, 53, 69, 0.7);
  270 + }
  271 + 70% {
  272 + box-shadow: 0 0 0 15px rgba(220, 53, 69, 0);
  273 + }
  274 + 100% {
  275 + box-shadow: 0 0 0 0 rgba(220, 53, 69, 0);
  276 + }
  277 + }
  278 + </style>
  279 +</head>
  280 +<body>
  281 + <div class="dashboard-container">
  282 + <div class="row">
  283 + <div class="col-12">
  284 + <h1 class="text-center mb-4">livetalking数字人交互平台</h1>
  285 + </div>
  286 + </div>
  287 +
  288 + <div class="row">
  289 + <!-- 视频区域 -->
  290 + <div class="col-lg-8">
  291 + <div class="card">
  292 + <div class="card-header d-flex justify-content-between align-items-center">
  293 + <div>
  294 + <span class="status-indicator status-disconnected" id="connection-status"></span>
  295 + <span id="status-text">未连接</span>
  296 + </div>
  297 + </div>
  298 + <div class="card-body p-0">
  299 + <div class="video-container">
  300 + <video id="video" autoplay playsinline></video>
  301 + <div class="recording-indicator" id="recording-indicator">
  302 + <div class="blink"></div>
  303 + <span>录制中</span>
  304 + </div>
  305 + </div>
  306 +
  307 + <div class="controls-container">
  308 + <div class="row">
  309 + <div class="col-md-6 mb-3">
  310 + <button class="btn btn-primary w-100" id="start">
  311 + <i class="bi bi-play-fill"></i> 开始连接
  312 + </button>
  313 + <button class="btn btn-danger w-100" id="stop" style="display: none;">
  314 + <i class="bi bi-stop-fill"></i> 停止连接
  315 + </button>
  316 + </div>
  317 + <div class="col-md-6 mb-3">
  318 + <div class="d-flex">
  319 + <button class="btn btn-outline-primary flex-grow-1 me-2" id="btn_start_record">
  320 + <i class="bi bi-record-fill"></i> 开始录制
  321 + </button>
  322 + <button class="btn btn-outline-danger flex-grow-1" id="btn_stop_record" disabled>
  323 + <i class="bi bi-stop-fill"></i> 停止录制
  324 + </button>
  325 + </div>
  326 + </div>
  327 + </div>
  328 +
  329 + <div class="row">
  330 + <div class="col-12">
  331 + <div class="video-size-control">
  332 + <label for="video-size-slider" class="form-label">视频大小调节: <span id="video-size-value">100%</span></label>
  333 + <input type="range" class="form-range" id="video-size-slider" min="50" max="150" value="100">
  334 + </div>
  335 + </div>
  336 + </div>
  337 +
  338 + <div class="settings-panel mt-3">
  339 + <div class="row">
  340 + <div class="col-md-12">
  341 + <div class="form-check form-switch mb-3">
  342 + <input class="form-check-input" type="checkbox" id="use-stun">
  343 + <label class="form-check-label" for="use-stun">使用STUN服务器</label>
  344 + </div>
  345 + </div>
  346 + </div>
  347 + </div>
  348 + </div>
  349 + </div>
  350 + </div>
  351 + </div>
  352 +
  353 + <!-- 右侧交互 -->
  354 + <div class="col-lg-4">
  355 + <div class="card">
  356 + <div class="card-header">
  357 + <ul class="nav nav-tabs card-header-tabs" id="interaction-tabs" role="tablist">
  358 + <li class="nav-item" role="presentation">
  359 + <button class="nav-link active" id="chat-tab" data-bs-toggle="tab" data-bs-target="#chat" type="button" role="tab" aria-controls="chat" aria-selected="true">对话模式</button>
  360 + </li>
  361 + <li class="nav-item" role="presentation">
  362 + <button class="nav-link" id="tts-tab" data-bs-toggle="tab" data-bs-target="#tts" type="button" role="tab" aria-controls="tts" aria-selected="false">朗读模式</button>
  363 + </li>
  364 + </ul>
  365 + </div>
  366 + <div class="card-body">
  367 + <div class="tab-content" id="interaction-tabs-content">
  368 + <!-- 对话模式 -->
  369 + <div class="tab-pane fade show active" id="chat" role="tabpanel" aria-labelledby="chat-tab">
  370 + <div class="asr-container mb-3" id="chat-messages">
  371 + <div class="asr-text system-message">
  372 + 系统: 欢迎使用livetalking,请点击"开始连接"按钮开始对话。
  373 + </div>
  374 + </div>
  375 +
  376 + <form id="chat-form">
  377 + <div class="input-group mb-3">
  378 + <textarea class="form-control" id="chat-message" rows="3" placeholder="输入您想对数字人说的话..."></textarea>
  379 + <button class="btn btn-primary" type="submit">
  380 + <i class="bi bi-send"></i> 发送
  381 + </button>
  382 + </div>
  383 + </form>
  384 +
  385 + <!-- 按住说话按钮 -->
  386 + <div class="voice-record-btn" id="voice-record-btn">
  387 + <i class="bi bi-mic-fill"></i>
  388 + </div>
  389 + <div class="voice-record-label">按住说话,松开发送</div>
  390 + </div>
  391 +
  392 + <!-- 朗读模式 -->
  393 + <div class="tab-pane fade" id="tts" role="tabpanel" aria-labelledby="tts-tab">
  394 + <form id="echo-form">
  395 + <div class="mb-3">
  396 + <label for="message" class="form-label">输入要朗读的文本</label>
  397 + <textarea class="form-control" id="message" rows="6" placeholder="输入您想让数字人朗读的文字..."></textarea>
  398 + </div>
  399 + <button type="submit" class="btn btn-primary w-100">
  400 + <i class="bi bi-volume-up"></i> 朗读文本
  401 + </button>
  402 + </form>
  403 + </div>
  404 + </div>
  405 + </div>
  406 + </div>
  407 + </div>
  408 + </div>
  409 +
  410 + <div class="footer">
  411 + <p>Made with ❤️ by Marstaos | Frontend & Performance Optimization</p>
  412 + </div>
  413 + </div>
  414 +
  415 + <!-- 隐藏的会话ID -->
  416 + <input type="hidden" id="sessionid" value="0">
  417 +
  418 +
  419 + <script src="client.js"></script>
  420 + <script src="srs.sdk.js"></script>
  421 + <script src="https://cdn.jsdelivr.net/npm/bootstrap@5.3.0/dist/js/bootstrap.bundle.min.js"></script>
  422 + <script src="https://code.jquery.com/jquery-3.6.0.min.js"></script>
  423 + <script>
  424 + $(document).ready(function() {
  425 + $('#video-size-slider').on('input', function() {
  426 + const value = $(this).val();
  427 + $('#video-size-value').text(value + '%');
  428 + $('#video').css('width', value + '%');
  429 + });
  430 + function updateConnectionStatus(status) {
  431 + const statusIndicator = $('#connection-status');
  432 + const statusText = $('#status-text');
  433 +
  434 + statusIndicator.removeClass('status-connected status-disconnected status-connecting');
  435 +
  436 + switch(status) {
  437 + case 'connected':
  438 + statusIndicator.addClass('status-connected');
  439 + statusText.text('已连接');
  440 + break;
  441 + case 'connecting':
  442 + statusIndicator.addClass('status-connecting');
  443 + statusText.text('连接中...');
  444 + break;
  445 + case 'disconnected':
  446 + default:
  447 + statusIndicator.addClass('status-disconnected');
  448 + statusText.text('未连接');
  449 + break;
  450 + }
  451 + }
  452 +
  453 + // 添加聊天消息
  454 + function addChatMessage(message, type = 'user') {
  455 + const messagesContainer = $('#chat-messages');
  456 + const messageClass = type === 'user' ? 'user-message' : 'system-message';
  457 + const sender = type === 'user' ? '您' : '数字人';
  458 +
  459 + const messageElement = $(`
  460 + <div class="asr-text ${messageClass}">
  461 + ${sender}: ${message}
  462 + </div>
  463 + `);
  464 +
  465 + messagesContainer.append(messageElement);
  466 + messagesContainer.scrollTop(messagesContainer[0].scrollHeight);
  467 + }
  468 +
  469 + // 开始/停止按钮
  470 + $('#start').click(function() {
  471 + updateConnectionStatus('connecting');
  472 + start();
  473 + $(this).hide();
  474 + $('#stop').show();
  475 +
  476 + // 添加定时器检查视频流是否已加载
  477 + let connectionCheckTimer = setInterval(function() {
  478 + const video = document.getElementById('video');
  479 + // 检查视频是否有数据
  480 + if (video.readyState >= 3 && video.videoWidth > 0) {
  481 + updateConnectionStatus('connected');
  482 + clearInterval(connectionCheckTimer);
  483 + }
  484 + }, 2000); // 每2秒检查一次
  485 +
  486 + // 60秒后如果还是连接中状态,就停止检查
  487 + setTimeout(function() {
  488 + if (connectionCheckTimer) {
  489 + clearInterval(connectionCheckTimer);
  490 + }
  491 + }, 60000);
  492 + });
  493 +
  494 + $('#stop').click(function() {
  495 + stop();
  496 + $(this).hide();
  497 + $('#start').show();
  498 + updateConnectionStatus('disconnected');
  499 + });
  500 +
  501 + // 录制功能
  502 + $('#btn_start_record').click(function() {
  503 + console.log('Starting recording...');
  504 + fetch('/record', {
  505 + body: JSON.stringify({
  506 + type: 'start_record',
  507 + sessionid: parseInt(document.getElementById('sessionid').value),
  508 + }),
  509 + headers: {
  510 + 'Content-Type': 'application/json'
  511 + },
  512 + method: 'POST'
  513 + }).then(function(response) {
  514 + if (response.ok) {
  515 + console.log('Recording started.');
  516 + $('#btn_start_record').prop('disabled', true);
  517 + $('#btn_stop_record').prop('disabled', false);
  518 + $('#recording-indicator').addClass('active');
  519 + } else {
  520 + console.error('Failed to start recording.');
  521 + }
  522 + }).catch(function(error) {
  523 + console.error('Error:', error);
  524 + });
  525 + });
  526 +
  527 + $('#btn_stop_record').click(function() {
  528 + console.log('Stopping recording...');
  529 + fetch('/record', {
  530 + body: JSON.stringify({
  531 + type: 'end_record',
  532 + sessionid: parseInt(document.getElementById('sessionid').value),
  533 + }),
  534 + headers: {
  535 + 'Content-Type': 'application/json'
  536 + },
  537 + method: 'POST'
  538 + }).then(function(response) {
  539 + if (response.ok) {
  540 + console.log('Recording stopped.');
  541 + $('#btn_start_record').prop('disabled', false);
  542 + $('#btn_stop_record').prop('disabled', true);
  543 + $('#recording-indicator').removeClass('active');
  544 + } else {
  545 + console.error('Failed to stop recording.');
  546 + }
  547 + }).catch(function(error) {
  548 + console.error('Error:', error);
  549 + });
  550 + });
  551 +
  552 + $('#echo-form').on('submit', function(e) {
  553 + e.preventDefault();
  554 + var message = $('#message').val();
  555 + if (!message.trim()) return;
  556 +
  557 + console.log('Sending echo message:', message);
  558 +
  559 + fetch('/human', {
  560 + body: JSON.stringify({
  561 + text: message,
  562 + type: 'echo',
  563 + interrupt: true,
  564 + sessionid: parseInt(document.getElementById('sessionid').value),
  565 + }),
  566 + headers: {
  567 + 'Content-Type': 'application/json'
  568 + },
  569 + method: 'POST'
  570 + });
  571 +
  572 + $('#message').val('');
  573 + addChatMessage(`已发送朗读请求: "${message}"`, 'system');
  574 + });
  575 +
  576 + // 聊天模式表单提交
  577 + $('#chat-form').on('submit', function(e) {
  578 + e.preventDefault();
  579 + var message = $('#chat-message').val();
  580 + if (!message.trim()) return;
  581 +
  582 + console.log('Sending chat message:', message);
  583 +
  584 + fetch('/human', {
  585 + body: JSON.stringify({
  586 + text: message,
  587 + type: 'chat',
  588 + interrupt: true,
  589 + sessionid: parseInt(document.getElementById('sessionid').value),
  590 + }),
  591 + headers: {
  592 + 'Content-Type': 'application/json'
  593 + },
  594 + method: 'POST'
  595 + });
  596 +
  597 + addChatMessage(message, 'user');
  598 + $('#chat-message').val('');
  599 + });
  600 +
  601 + // 按住说话功能
  602 + let mediaRecorder;
  603 + let audioChunks = [];
  604 + let isRecording = false;
  605 + let recognition;
  606 +
  607 + // 检查浏览器是否支持语音识别
  608 + const isSpeechRecognitionSupported = 'webkitSpeechRecognition' in window || 'SpeechRecognition' in window;
  609 +
  610 + if (isSpeechRecognitionSupported) {
  611 + recognition = new (window.SpeechRecognition || window.webkitSpeechRecognition)();
  612 + recognition.continuous = true;
  613 + recognition.interimResults = true;
  614 + recognition.lang = 'zh-CN';
  615 +
  616 + recognition.onresult = function(event) {
  617 + let interimTranscript = '';
  618 + let finalTranscript = '';
  619 +
  620 + for (let i = event.resultIndex; i < event.results.length; ++i) {
  621 + if (event.results[i].isFinal) {
  622 + finalTranscript += event.results[i][0].transcript;
  623 + } else {
  624 + interimTranscript += event.results[i][0].transcript;
  625 + $('#chat-message').val(interimTranscript);
  626 + }
  627 + }
  628 +
  629 + if (finalTranscript) {
  630 + $('#chat-message').val(finalTranscript);
  631 + }
  632 + };
  633 +
  634 + recognition.onerror = function(event) {
  635 + console.error('语音识别错误:', event.error);
  636 + };
  637 + }
  638 +
  639 + // 按住说话按钮事件
  640 + $('#voice-record-btn').on('mousedown touchstart', function(e) {
  641 + e.preventDefault();
  642 + startRecording();
  643 + }).on('mouseup mouseleave touchend', function() {
  644 + if (isRecording) {
  645 + stopRecording();
  646 + }
  647 + });
  648 +
  649 + // 开始录音
  650 + function startRecording() {
  651 + if (isRecording) return;
  652 +
  653 + navigator.mediaDevices.getUserMedia({ audio: true })
  654 + .then(function(stream) {
  655 + audioChunks = [];
  656 + mediaRecorder = new MediaRecorder(stream);
  657 +
  658 + mediaRecorder.ondataavailable = function(e) {
  659 + if (e.data.size > 0) {
  660 + audioChunks.push(e.data);
  661 + }
  662 + };
  663 +
  664 + mediaRecorder.start();
  665 + isRecording = true;
  666 +
  667 + $('#voice-record-btn').addClass('recording-pulse');
  668 + $('#voice-record-btn').css('background-color', '#dc3545');
  669 +
  670 + if (recognition) {
  671 + recognition.start();
  672 + }
  673 + })
  674 + .catch(function(error) {
  675 + console.error('无法访问麦克风:', error);
  676 + alert('无法访问麦克风,请检查浏览器权限设置。');
  677 + });
  678 + }
  679 +
  680 + function stopRecording() {
  681 + if (!isRecording) return;
  682 +
  683 + mediaRecorder.stop();
  684 + isRecording = false;
  685 +
  686 + // 停止所有音轨
  687 + mediaRecorder.stream.getTracks().forEach(track => track.stop());
  688 +
  689 + // 视觉反馈恢复
  690 + $('#voice-record-btn').removeClass('recording-pulse');
  691 + $('#voice-record-btn').css('background-color', '');
  692 +
  693 + // 停止语音识别
  694 + if (recognition) {
  695 + recognition.stop();
  696 + }
  697 +
  698 + // 获取识别的文本并发送
  699 + setTimeout(function() {
  700 + const recognizedText = $('#chat-message').val().trim();
  701 + if (recognizedText) {
  702 + // 发送识别的文本
  703 + fetch('/human', {
  704 + body: JSON.stringify({
  705 + text: recognizedText,
  706 + type: 'chat',
  707 + interrupt: true,
  708 + sessionid: parseInt(document.getElementById('sessionid').value),
  709 + }),
  710 + headers: {
  711 + 'Content-Type': 'application/json'
  712 + },
  713 + method: 'POST'
  714 + });
  715 +
  716 + addChatMessage(recognizedText, 'user');
  717 + $('#chat-message').val('');
  718 + }
  719 + }, 500);
  720 + }
  721 +
  722 + // WebRTC 相关功能
  723 + if (typeof window.onWebRTCConnected === 'function') {
  724 + const originalOnConnected = window.onWebRTCConnected;
  725 + window.onWebRTCConnected = function() {
  726 + updateConnectionStatus('connected');
  727 + if (originalOnConnected) originalOnConnected();
  728 + };
  729 + } else {
  730 + window.onWebRTCConnected = function() {
  731 + updateConnectionStatus('connected');
  732 + };
  733 + }
  734 +
  735 + // 当连接断开时更新状态
  736 + if (typeof window.onWebRTCDisconnected === 'function') {
  737 + const originalOnDisconnected = window.onWebRTCDisconnected;
  738 + window.onWebRTCDisconnected = function() {
  739 + updateConnectionStatus('disconnected');
  740 + if (originalOnDisconnected) originalOnDisconnected();
  741 + };
  742 + } else {
  743 + window.onWebRTCDisconnected = function() {
  744 + updateConnectionStatus('disconnected');
  745 + };
  746 + }
  747 +
  748 + // SRS WebRTC播放功能
  749 + var sdk = null; // 全局处理器,用于在重新发布时进行清理
  750 +
  751 + function startPlay() {
  752 + // 关闭之前的连接
  753 + if (sdk) {
  754 + sdk.close();
  755 + }
  756 +
  757 + sdk = new SrsRtcWhipWhepAsync();
  758 + $('#video').prop('srcObject', sdk.stream);
  759 +
  760 + var host = window.location.hostname;
  761 + var url = "http://" + host + ":1985/rtc/v1/whep/?app=live&stream=livestream";
  762 +
  763 + sdk.play(url).then(function(session) {
  764 + console.log('WebRTC播放已启动,会话ID:', session.sessionid);
  765 + }).catch(function(reason) {
  766 + sdk.close();
  767 + console.error('WebRTC播放失败:', reason);
  768 + });
  769 + }
  770 + });
  771 + </script>
  772 +</body>
  773 +</html>
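
The new dashboard page drives two backend routes: `/human` with `type: 'chat'` or `type: 'echo'` (the same payload is used for typed text and for push-to-talk transcripts) and `/record` with `start_record`/`end_record`. The same requests can be issued from Python; the host, port, and session id below are assumptions matching a default single-session deployment on port 8010, so adjust them to your setup:

```python
# Sketch: drive the endpoints the dashboard page calls, from Python instead of the browser.
import requests

BASE = "http://127.0.0.1:8010"  # assumed default app.py port

def speak(text: str, sessionid: int = 0, mode: str = "echo") -> None:
    """mode='echo' has the avatar read the text verbatim; mode='chat' routes it to the dialogue pipeline."""
    requests.post(f"{BASE}/human", json={
        "text": text,
        "type": mode,
        "interrupt": True,
        "sessionid": sessionid,
    })

def record(start: bool, sessionid: int = 0) -> None:
    requests.post(f"{BASE}/record", json={
        "type": "start_record" if start else "end_record",
        "sessionid": sessionid,
    })

# speak("你好, 欢迎使用livetalking", mode="echo")
# record(True); ...; record(False)
```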