Synced with the upstream GitHub repository, up to commits on Apr 18, 2025
(commit a9c36c76e569107b5a39b3de8afd6e016b24d662).
Showing 23 changed files with 1,100 additions and 196 deletions.
README-EN.md (new file, mode 0 → 100644)

Real-time interactive streaming digital human with synchronized audio and video dialogue. It can essentially achieve commercial-grade results.

[Effect of wav2lip](https://www.bilibili.com/video/BV1scwBeyELA/) | [Effect of ernerf](https://www.bilibili.com/video/BV1G1421z73r/) | [Effect of musetalk](https://www.bilibili.com/video/BV1gm421N7vQ/)

## News
- December 8, 2024: Improved multi-concurrency; video memory no longer increases with the number of concurrent connections.
- December 21, 2024: Added model warm-up for wav2lip and musetalk to fix stuttering during the first inference. Thanks to [@heimaojinzhangyz](https://github.com/heimaojinzhangyz)
- December 28, 2024: Added the digital human model Ultralight-Digital-Human. Thanks to [@lijihua2017](https://github.com/lijihua2017)
- February 7, 2025: Added fish-speech TTS
- February 21, 2025: Added the open-source model wav2lip256. Thanks to @不蠢不蠢
- March 2, 2025: Added Tencent's speech synthesis service
- March 16, 2025: Supports Mac GPU inference. Thanks to [@GcsSloop](https://github.com/GcsSloop)

## Features
1. Supports multiple digital human models: ernerf, musetalk, wav2lip, Ultralight-Digital-Human
2. Supports voice cloning
3. Supports interrupting the digital human while it is speaking
4. Supports full-body video stitching
5. Supports rtmp and webrtc
6. Supports video arrangement: plays custom videos when not speaking
7. Supports multi-concurrency

## 1. Installation

Tested on Ubuntu 20.04, Python 3.10, PyTorch 1.12 and CUDA 11.3

### 1.1 Install dependency

```bash
conda create -n nerfstream python=3.10
conda activate nerfstream
# If the CUDA version is not 11.3 (confirm with nvidia-smi), install the matching PyTorch build from https://pytorch.org/get-started/previous-versions/
conda install pytorch==1.12.1 torchvision==0.13.1 cudatoolkit=11.3 -c pytorch
pip install -r requirements.txt
# If you need to train the ernerf model, install the following libraries
# pip install "git+https://github.com/facebookresearch/pytorch3d.git"
# pip install tensorflow-gpu==2.8.0
# pip install --upgrade "protobuf<=3.20.1"
```
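
After installing, a quick sanity check (a suggested step, not part of the upstream README) confirms that PyTorch sees the GPU before you continue:

```python
# Optional environment check (suggested step, not from the upstream README).
import torch

print("torch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))
```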
Common installation issues: [FAQ](https://livetalking-doc.readthedocs.io/en/latest/faq.html)
For setting up a Linux CUDA environment, you can refer to this article: https://zhuanlan.zhihu.com/p/674972886


## 2. Quick Start
- Download the models
  Quark Cloud Disk <https://pan.quark.cn/s/83a750323ef0>
  Google Drive <https://drive.google.com/drive/folders/1FOC_MD6wdogyyX_7V1d4NDIO7P9NlSAJ?usp=sharing>
  Copy wav2lip256.pth to the models folder of this project and rename it to wav2lip.pth;
  extract wav2lip256_avatar1.tar.gz and copy the entire folder to the data/avatars folder of this project.
- Run
  `python app.py --transport webrtc --model wav2lip --avatar_id wav2lip256_avatar1`
  Open http://serverip:8010/webrtcapi.html in a browser. First click 'start' to play the digital human video; then enter any text in the text box and submit it. The digital human will read the text aloud.
  <font color=red>The server needs to open ports tcp:8010 and udp:1-65536.</font>
  If you need to purchase a high-definition wav2lip model for commercial use, see this [link](https://livetalking-doc.readthedocs.io/zh-cn/latest/service.html#wav2lip).
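
Besides typing into the web page, text can also be pushed to the digital human over HTTP. The exact route and field names are not shown in this changeset, so the snippet below is only an assumed sketch based on common LiveTalking usage (route `/human`, fields `text`, `type`, `interrupt`, `sessionid`); consult the project documentation for the authoritative interface.

```python
# Assumed sketch of driving the avatar over HTTP; the /human route and its JSON
# fields are assumptions, not confirmed by this changeset.
import json, urllib.request

payload = {"sessionid": 0, "type": "echo", "interrupt": True,
           "text": "Hello, this is the streaming digital human."}
req = urllib.request.Request(
    "http://serverip:8010/human",  # replace serverip with your host
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
print(urllib.request.urlopen(req).read().decode())
```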

- Quick experience
  <https://www.compshare.cn/images-detail?ImageID=compshareImage-18tpjhhxoq3j&referral_code=3XW3852OBmnD089hMMrtuU&ytag=GPU_GitHub_livetalking1.3> Create an instance from this image to run it.

If you cannot access huggingface, run the following before starting:
```
export HF_ENDPOINT=https://hf-mirror.com
```


## 3. More Usage
Usage instructions: <https://livetalking-doc.readthedocs.io/en/latest>

## 4. Docker Run
No need for the installation above; run directly:
```
docker run --gpus all -it --network=host --rm registry.cn-beijing.aliyuncs.com/codewithgpu2/lipku-metahuman-stream:2K9qaMBu8v
```
The code is in /root/metahuman-stream. First run git pull to get the latest code, then execute the commands as in steps 2 and 3.

The following images are provided:
- autodl image: <https://www.codewithgpu.com/i/lipku/metahuman-stream/base>
  [autodl Tutorial](https://livetalking-doc.readthedocs.io/en/latest/autodl/README.html)
- ucloud image: <https://www.compshare.cn/images-detail?ImageID=compshareImage-18tpjhhxoq3j&referral_code=3XW3852OBmnD089hMMrtuU&ytag=GPU_livetalking1.3>
  Any port can be opened, and there is no need to deploy an additional srs service.
  [ucloud Tutorial](https://livetalking-doc.readthedocs.io/en/latest/ucloud/ucloud.html)


## 5. TODO
- [x] Added chatgpt to enable digital human dialogue
- [x] Voice cloning
- [x] Replace the digital human with a video when it is silent
- [x] MuseTalk
- [x] Wav2Lip
- [x] Ultralight-Digital-Human

---
If this project helps you, please give it a star. Friends who are interested are also welcome to join in and improve it together.
* Knowledge Planet: https://t.zsxq.com/7NMyO, where high-quality FAQs, best practices, and problem solutions are accumulated.
* WeChat Official Account: 数字人技术 (Digital Human Technology)
README.md

-Real time interactive streaming digital human, realize audio video synchronous dialogue. It can basically achieve commercial effects.
+[English](./README-EN.md) | 中文版
 实时交互流式数字人,实现音视频同步对话。基本可以达到商用效果
+[wav2lip效果](https://www.bilibili.com/video/BV1scwBeyELA/) | [ernerf效果](https://www.bilibili.com/video/BV1G1421z73r/) | [musetalk效果](https://www.bilibili.com/video/BV1gm421N7vQ/)
 
-[ernerf 效果](https://www.bilibili.com/video/BV1PM4m1y7Q2/) [musetalk 效果](https://www.bilibili.com/video/BV1gm421N7vQ/) [wav2lip 效果](https://www.bilibili.com/video/BV1Bw4m1e74P/)
-
-## 为避免与 3d 数字人混淆,原项目 metahuman-stream 改名为 livetalking,原有链接地址继续可用
+## 为避免与3d数字人混淆,原项目metahuman-stream改名为livetalking,原有链接地址继续可用
 
 ## News
-
 - 2024.12.8 完善多并发,显存不随并发数增加
-- 2024.12.21 添加 wav2lip、musetalk 模型预热,解决第一次推理卡顿问题。感谢@heimaojinzhangyz
-- 2024.12.28 添加数字人模型 Ultralight-Digital-Human。 感谢@lijihua2017
-- 2025.2.7 添加 fish-speech tts
-- 2025.2.21 添加 wav2lip256 开源模型 感谢@不蠢不蠢
+- 2024.12.21 添加wav2lip、musetalk模型预热,解决第一次推理卡顿问题。感谢[@heimaojinzhangyz](https://github.com/heimaojinzhangyz)
+- 2024.12.28 添加数字人模型Ultralight-Digital-Human。 感谢[@lijihua2017](https://github.com/lijihua2017)
+- 2025.2.7 添加fish-speech tts
+- 2025.2.21 添加wav2lip256开源模型 感谢@不蠢不蠢
 - 2025.3.2 添加腾讯语音合成服务
+- 2025.3.16 支持mac gpu推理,感谢[@GcsSloop](https://github.com/GcsSloop)
 
 ## Features
-
 1. 支持多种数字人模型: ernerf、musetalk、wav2lip、Ultralight-Digital-Human
 2. 支持声音克隆
 3. 支持数字人说话被打断
 4. 支持全身视频拼接
-5. 支持 rtmp 和 webrtc
+5. 支持rtmp和webrtc
 6. 支持视频编排:不说话时播放自定义视频
 7. 支持多并发
 
@@ -33,67 +31,61 @@ Tested on Ubuntu 20.04, Python3.10, Pytorch 1.12 and CUDA 11.3
 ```bash
 conda create -n nerfstream python=3.10
 conda activate nerfstream
-#如果cuda版本不为11.3(运行nvidia-smi确认版本),根据<https://pytorch.org/get-started/previous-versions/>安装对应版本的pytorch
+#如果cuda版本不为11.3(运行nvidia-smi确认版本),根据<https://pytorch.org/get-started/previous-versions/>安装对应版本的pytorch
 conda install pytorch==1.12.1 torchvision==0.13.1 cudatoolkit=11.3 -c pytorch
 pip install -r requirements.txt
 #如果需要训练ernerf模型,安装下面的库
 # pip install "git+https://github.com/facebookresearch/pytorch3d.git"
 # pip install tensorflow-gpu==2.8.0
 # pip install --upgrade "protobuf<=3.20.1"
-```
-
+```
 安装常见问题[FAQ](https://livetalking-doc.readthedocs.io/en/latest/faq.html)
-linux cuda 环境搭建可以参考这篇文章 https://zhuanlan.zhihu.com/p/674972886
+linux cuda环境搭建可以参考这篇文章 https://zhuanlan.zhihu.com/p/674972886
 
-## 2. Quick Start
 
+## 2. Quick Start
 - 下载模型
-  百度云盘<https://pan.baidu.com/s/1yOsQ06-RIDTJd3HFCw4wtA> 密码: ltua
+  夸克云盘<https://pan.quark.cn/s/83a750323ef0>
  GoogleDriver <https://drive.google.com/drive/folders/1FOC_MD6wdogyyX_7V1d4NDIO7P9NlSAJ?usp=sharing>
-  将 wav2lip256.pth 拷到本项目的 models 下, 重命名为 wav2lip.pth;
-  将 wav2lip256_avatar1.tar.gz 解压后整个文件夹拷到本项目的 data/avatars 下
+  将wav2lip256.pth拷到本项目的models下, 重命名为wav2lip.pth;
+  将wav2lip256_avatar1.tar.gz解压后整个文件夹拷到本项目的data/avatars下
 - 运行
  python app.py --transport webrtc --model wav2lip --avatar_id wav2lip256_avatar1 --preload 2
-
 使用 GPU 启动模特 3 号:python app.py --transport webrtc --model wav2lip --avatar_id wav2lip256_avatar3 --preload 2
-用浏览器打开 http://serverip:8010/webrtcapi.html , 先点‘start',播放数字人视频;然后在文本框输入任意文字,提交。数字人播报该段文字
+
+用浏览器打开http://serverip:8010/webrtcapi.html , 先点‘start',播放数字人视频;然后在文本框输入任意文字,提交。数字人播报该段文字
 <font color=red>服务端需要开放端口 tcp:8010; udp:1-65536 </font>
-如果需要商用高清 wav2lip 模型,可以与我联系购买
+如果需要商用高清wav2lip模型,[链接](https://livetalking-doc.readthedocs.io/zh-cn/latest/service.html#wav2lip)
 
 - 快速体验
 <https://www.compshare.cn/images-detail?ImageID=compshareImage-18tpjhhxoq3j&referral_code=3XW3852OBmnD089hMMrtuU&ytag=GPU_GitHub_livetalking1.3> 用该镜像创建实例即可运行成功
 
-如果访问不了 huggingface,在运行前
-
+如果访问不了huggingface,在运行前
 ```
 export HF_ENDPOINT=https://hf-mirror.com
-```
+```
 
-## 3. More Usage
 
+## 3. More Usage
 使用说明: <https://livetalking-doc.readthedocs.io/>
 
 ## 4. Docker Run
-
 不需要前面的安装,直接运行。
-
 ```
 docker run --gpus all -it --network=host --rm registry.cn-beijing.aliyuncs.com/codewithgpu2/lipku-metahuman-stream:2K9qaMBu8v
 ```
-
-代码在/root/metahuman-stream,先 git pull 拉一下最新代码,然后执行命令同第 2、3 步
+代码在/root/metahuman-stream,先git pull拉一下最新代码,然后执行命令同第2、3步
 
 提供如下镜像
+- autodl镜像: <https://www.codewithgpu.com/i/lipku/metahuman-stream/base>
+  [autodl教程](https://livetalking-doc.readthedocs.io/en/latest/autodl/README.html)
+- ucloud镜像: <https://www.compshare.cn/images-detail?ImageID=compshareImage-18tpjhhxoq3j&referral_code=3XW3852OBmnD089hMMrtuU&ytag=GPU_livetalking1.3>
+  可以开放任意端口,不需要另外部署srs服务.
+  [ucloud教程](https://livetalking-doc.readthedocs.io/en/latest/ucloud/ucloud.html)
 
-- autodl 镜像: <https://www.codewithgpu.com/i/lipku/metahuman-stream/base>
-  [autodl 教程](https://livetalking-doc.readthedocs.io/en/latest/autodl/README.html)
-- ucloud 镜像: <https://www.compshare.cn/images-detail?ImageID=compshareImage-18tpjhhxoq3j&referral_code=3XW3852OBmnD089hMMrtuU&ytag=GPU_livetalking1.3>
-  可以开放任意端口,不需要另外部署 srs 服务.
-  [ucloud 教程](https://livetalking-doc.readthedocs.io/en/latest/ucloud/ucloud.html)
 
 ## 5. TODO
-
-- [x] 添加 chatgpt 实现数字人对话
+- [x] 添加chatgpt实现数字人对话
 - [x] 声音克隆
 - [x] 数字人静音时用一段视频代替
 - [x] MuseTalk
@@ -101,9 +93,8 @@ docker run --gpus all -it --network=host --rm registry.cn-beijing.aliyuncs.com/c
 - [x] Ultralight-Digital-Human
 
 ---
+如果本项目对你有帮助,帮忙点个star。也欢迎感兴趣的朋友一起来完善该项目.
+* 知识星球: https://t.zsxq.com/7NMyO 沉淀高质量常见问题、最佳实践经验、问题解答
+* 微信公众号:数字人技术
+*  
 
-如果本项目对你有帮助,帮忙点个 star。也欢迎感兴趣的朋友一起来完善该项目.
-
-- 知识星球: https://t.zsxq.com/7NMyO 沉淀高质量常见问题、最佳实践经验、问题解答
-- 微信公众号:数字人技术
-  

@@ -201,7 +201,7 @@ async def set_audiotype(request):
     params = await request.json()
 
     sessionid = params.get('sessionid',0)
-    nerfreals[sessionid].set_curr_state(params['audiotype'],params['reinit'])
+    nerfreals[sessionid].set_custom_state(params['audiotype'],params['reinit'])
 
     return web.Response(
         content_type="application/json",
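
The renamed method keeps the same HTTP-facing parameters that are visible in the handler above (`sessionid`, `audiotype`, `reinit`). A minimal client call could look like the sketch below; the `/set_audiotype` route name and the 8010 port are assumptions inferred from the handler name and the project's default port.

```python
# Hypothetical client for the handler above; route name and port are assumptions,
# while the JSON fields mirror the params read by set_audiotype.
import json, urllib.request

payload = {"sessionid": 0, "audiotype": 2, "reinit": True}
req = urllib.request.Request(
    "http://serverip:8010/set_audiotype",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
print(urllib.request.urlopen(req).read().decode())
```
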
| @@ -495,6 +495,8 @@ if __name__ == '__main__': | @@ -495,6 +495,8 @@ if __name__ == '__main__': | ||
     elif opt.transport=='rtcpush':
         pagename='rtcpushapi.html'
     logger.info('start http server; http://<serverip>:'+str(opt.listenport)+'/'+pagename)
+    logger.info('如果使用webrtc,推荐访问webrtc集成前端: http://<serverip>:'+str(opt.listenport)+'/dashboard.html')
+
     def run_server(runner):
         loop = asyncio.new_event_loop()
         asyncio.set_event_loop(loop)

@@ -35,7 +35,7 @@ import soundfile as sf
 import av
 from fractions import Fraction
 
-from ttsreal import EdgeTTS,VoitsTTS,XTTS,CosyVoiceTTS,FishTTS,TencentTTS
+from ttsreal import EdgeTTS,SovitsTTS,XTTS,CosyVoiceTTS,FishTTS,TencentTTS
 from logger import logger
 
 from tqdm import tqdm
@@ -57,7 +57,7 @@ class BaseReal:
         if opt.tts == "edgetts":
             self.tts = EdgeTTS(opt,self)
         elif opt.tts == "gpt-sovits":
-            self.tts = VoitsTTS(opt,self)
+            self.tts = SovitsTTS(opt,self)
         elif opt.tts == "xtts":
             self.tts = XTTS(opt,self)
         elif opt.tts == "cosyvoice":
@@ -66,7 +66,7 @@ class BaseReal:
             self.tts = FishTTS(opt,self)
         elif opt.tts == "tencent":
             self.tts = TencentTTS(opt,self)
-            
+
         self.speaking = False
 
         self.recording = False
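
The hunks above rename the GPT-SoVITS backend class from VoitsTTS to SovitsTTS; the surrounding elif chain simply maps the `--tts` option string to a backend class. A table-driven equivalent (a refactoring sketch only, not part of this patch; the "fish" key is an assumption, since its elif line falls outside the hunk) would be:

```python
# Refactoring sketch, not in the patch: same mapping as the elif chain above.
from ttsreal import EdgeTTS, SovitsTTS, XTTS, CosyVoiceTTS, FishTTS, TencentTTS

TTS_BACKENDS = {
    "edgetts": EdgeTTS,
    "gpt-sovits": SovitsTTS,   # renamed from VoitsTTS in this update
    "xtts": XTTS,
    "cosyvoice": CosyVoiceTTS,
    "fish": FishTTS,           # option string assumed; only the assignment is visible
    "tencent": TencentTTS,
}

def make_tts(opt, parent):
    """Return the TTS backend selected by opt.tts, as the elif chain does."""
    return TTS_BACKENDS[opt.tts](opt, parent)
```
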
| @@ -84,11 +84,11 @@ class BaseReal: | @@ -84,11 +84,11 @@ class BaseReal: | ||
 
     def put_msg_txt(self,msg,eventpoint=None):
         self.tts.put_msg_txt(msg,eventpoint)
-    
+
     def put_audio_frame(self,audio_chunk,eventpoint=None): #16khz 20ms pcm
         self.asr.put_audio_frame(audio_chunk,eventpoint)
 
-    def put_audio_file(self,filebyte): 
+    def put_audio_file(self,filebyte):
         input_stream = BytesIO(filebyte)
         stream = self.__create_bytes_stream(input_stream)
         streamlen = stream.shape[0]
@@ -97,7 +97,7 @@ class BaseReal:
             self.put_audio_frame(stream[idx:idx+self.chunk])
             streamlen -= self.chunk
             idx += self.chunk
-    
+
     def __create_bytes_stream(self,byte_stream):
         #byte_stream=BytesIO(buffer)
         stream, sample_rate = sf.read(byte_stream) # [T*sample_rate,] float64
@@ -107,7 +107,7 @@ class BaseReal:
         if stream.ndim > 1:
             logger.info(f'[WARN] audio has {stream.shape[1]} channels, only use the first.')
             stream = stream[:, 0]
-    
+
         if sample_rate != self.sample_rate and stream.shape[0]>0:
             logger.info(f'[WARN] audio sample rate is {sample_rate}, resampling into {self.sample_rate}.')
             stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self.sample_rate)
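
As the hunks above show, `put_audio_file` takes the raw bytes of an audio file, decodes them with soundfile, resamples them to the session sample rate, and feeds 20 ms chunks to the ASR queue via `put_audio_frame`. An assumed usage from code that already holds a session object:

```python
# Assumed usage of the helpers above; `real` stands for an existing BaseReal-derived
# session object created elsewhere in the server, and the wav path is only an example.
with open("data/example.wav", "rb") as f:
    real.put_audio_file(f.read())   # decoded, resampled, then chunked into 20 ms frames

real.put_msg_txt("你好,欢迎体验数字人")  # or queue text for TTS instead
```
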
| @@ -120,7 +120,7 @@ class BaseReal: | @@ -120,7 +120,7 @@ class BaseReal: | ||
 
     def is_speaking(self)->bool:
         return self.speaking
-    
+
     def __loadcustom(self):
         for item in self.opt.customopt:
             logger.info(item)
@@ -155,9 +155,9 @@ class BaseReal:
                    '-s', "{}x{}".format(self.width, self.height),
                    '-r', str(25),
                    '-i', '-',
-                   '-pix_fmt', 'yuv420p', 
+                   '-pix_fmt', 'yuv420p',
                    '-vcodec', "h264",
-                   #'-f' , 'flv',
+                   #'-f' , 'flv',
                    f'temp{self.opt.sessionid}.mp4']
         self._record_video_pipe = subprocess.Popen(command, shell=False, stdin=subprocess.PIPE)
 
@@ -169,7 +169,7 @@ class BaseReal:
                    '-ar', '16000',
                    '-i', '-',
                    '-acodec', 'aac',
-                   #'-f' , 'wav',
+                   #'-f' , 'wav',
                    f'temp{self.opt.sessionid}.aac']
         self._record_audio_pipe = subprocess.Popen(acommand, shell=False, stdin=subprocess.PIPE)
 
@@ -177,10 +177,10 @@ class BaseReal:
    #     self.recordq_video.queue.clear()
    #     self.recordq_audio.queue.clear()
    #     self.container = av.open(path, mode="w")
-    
+
    #     process_thread = Thread(target=self.record_frame, args=())
    #     process_thread.start()
-    
+
     def record_video_data(self,image):
         if self.width == 0:
             print("image.shape:",image.shape)
@@ -191,14 +191,14 @@ class BaseReal:
     def record_audio_data(self,frame):
         if self.recording:
             self._record_audio_pipe.stdin.write(frame.tostring())
-    
-    # def record_frame(self): 
+
+    # def record_frame(self):
    #     videostream = self.container.add_stream("libx264", rate=25)
    #     videostream.codec_context.time_base = Fraction(1, 25)
    #     audiostream = self.container.add_stream("aac")
    #     audiostream.codec_context.time_base = Fraction(1, 16000)
    #     init = True
-    #     framenum = 0 
+    #     framenum = 0
    #     while self.recording:
    #         try:
    #             videoframe = self.recordq_video.get(block=True, timeout=1)
@@ -231,18 +231,18 @@ class BaseReal:
    #     self.recordq_video.queue.clear()
    #     self.recordq_audio.queue.clear()
    #     print('record thread stop')
-    
+
     def stop_recording(self):
         """停止录制视频"""
         if not self.recording:
             return
-        self.recording = False 
-        self._record_video_pipe.stdin.close()  #wait()
+        self.recording = False
+        self._record_video_pipe.stdin.close()  #wait()
         self._record_video_pipe.wait()
         self._record_audio_pipe.stdin.close()
         self._record_audio_pipe.wait()
         cmd_combine_audio = f"ffmpeg -y -i temp{self.opt.sessionid}.aac -i temp{self.opt.sessionid}.mp4 -c:v copy -c:a copy data/record.mp4"
-        os.system(cmd_combine_audio) 
+        os.system(cmd_combine_audio)
         #os.remove(output_path)
 
     def mirror_index(self,size, index):
@@ -252,8 +252,8 @@ class BaseReal:
         if turn % 2 == 0:
             return res
         else:
-            return size - res - 1 
+            return size - res - 1
 
     def get_audio_stream(self,audiotype):
         idx = self.custom_audio_index[audiotype]
         stream = self.custom_audio_cycle[audiotype][idx:idx+self.chunk]
@@ -261,9 +261,9 @@ class BaseReal:
         if self.custom_audio_index[audiotype]>=self.custom_audio_cycle[audiotype].shape[0]:
             self.curr_state = 1  #当前视频不循环播放,切换到静音状态
         return stream
-    
-    def set_curr_state(self,audiotype, reinit):
-        print('set_curr_state:',audiotype)
+
+    def set_custom_state(self,audiotype, reinit=True):
+        print('set_custom_state:',audiotype)
         self.curr_state = audiotype
         if reinit:
             self.custom_audio_index[audiotype] = 0
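
`mirror_index` turns a monotonically growing frame counter into a ping-pong index over a fixed-length frame list, so idle video plays forward and then backward instead of jumping from the last frame back to the first. A standalone illustration (the turn/res split is the standard `index // size`, `index % size` computation, which the hunk only partially shows):

```python
# Standalone illustration of the ping-pong indexing used by mirror_index above.
def mirror_index(size: int, index: int) -> int:
    turn, res = divmod(index, size)
    return res if turn % 2 == 0 else size - res - 1

print([mirror_index(4, i) for i in range(10)])
# -> [0, 1, 2, 3, 3, 2, 1, 0, 0, 1]
```
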
| @@ -179,8 +179,11 @@ print(f'[INFO] fitting light...') | @@ -179,8 +179,11 @@ print(f'[INFO] fitting light...') | ||
 
 batch_size = 32
 
-device_default = torch.device("cuda:0")
-device_render = torch.device("cuda:0")
+device_default = torch.device("cuda:0" if torch.cuda.is_available() else (
+    "mps" if (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) else "cpu"))
+device_render = torch.device("cuda:0" if torch.cuda.is_available() else (
+    "mps" if (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) else "cpu"))
+
 renderer = Render_3DMM(arg_focal, h, w, batch_size, device_render)
 
 sel_ids = np.arange(0, num_frames, int(num_frames / batch_size))[:batch_size]
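
The same CUDA → MPS → CPU fallback expression is inlined in many files throughout this update. A small helper like the one below (a refactoring sketch only, not something the patch introduces) captures the pattern once:

```python
# Refactoring sketch: the patch inlines this expression at every call site instead.
import torch

def pick_device() -> torch.device:
    """Prefer CUDA, then Apple MPS, then CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")

device = pick_device()
```
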
| @@ -83,7 +83,7 @@ class Render_3DMM(nn.Module): | @@ -83,7 +83,7 @@ class Render_3DMM(nn.Module): | ||
         img_h=500,
         img_w=500,
         batch_size=1,
-        device=torch.device("cuda:0"),
+        device=torch.device("cuda:0" if torch.cuda.is_available() else ("mps" if (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) else "cpu")),
     ):
         super(Render_3DMM, self).__init__()
 

@@ -147,7 +147,7 @@ if __name__ == '__main__':
 
     seed_everything(opt.seed)
 
-    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    device = torch.device("cuda" if torch.cuda.is_available() else ("mps" if (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) else "cpu"))
 
     model = NeRFNetwork(opt)
 

@@ -442,7 +442,7 @@ class LPIPSMeter:
         self.N = 0
         self.net = net
 
-        self.device = device if device is not None else torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        self.device = device if device is not None else torch.device('cuda' if torch.cuda.is_available() else ('mps' if hasattr(torch.backends, "mps") and torch.backends.mps.is_available() else 'cpu'))
         self.fn = lpips.LPIPS(net=net).eval().to(self.device)
 
     def clear(self):
@@ -456,13 +456,13 @@ class LPIPSMeter:
             inp = inp.to(self.device)
             outputs.append(inp)
         return outputs
-    
+
     def update(self, preds, truths):
         preds, truths = self.prepare_inputs(preds, truths) # [B, H, W, 3] --> [B, 3, H, W], range in [0, 1]
         v = self.fn(truths, preds, normalize=True).item() # normalize=True: [0, 1] to [-1, 1]
         self.V += v
         self.N += 1
-    
+
     def measure(self):
         return self.V / self.N
 
@@ -499,7 +499,7 @@ class LMDMeter:
 
         self.V = 0
         self.N = 0
-    
+
     def get_landmarks(self, img):
 
         if self.backend == 'dlib':
@@ -515,7 +515,7 @@ class LMDMeter:
 
         else:
             lms = self.predictor.get_landmarks(img)[-1]
-        
+
         # self.vis_landmarks(img, lms)
         lms = lms.astype(np.float32)
 
@@ -537,7 +537,7 @@ class LMDMeter:
             inp = (inp * 255).astype(np.uint8)
             outputs.append(inp)
         return outputs
-    
+
     def update(self, preds, truths):
         # assert B == 1
         preds, truths = self.prepare_inputs(preds[0], truths[0]) # [H, W, 3] numpy array
@@ -553,13 +553,13 @@ class LMDMeter:
         # avarage
         lms_pred = lms_pred - lms_pred.mean(0)
         lms_truth = lms_truth - lms_truth.mean(0)
-        
+
         # distance
         dist = np.sqrt(((lms_pred - lms_truth) ** 2).sum(1)).mean(0)
-        
+
         self.V += dist
         self.N += 1
-        
+
     def measure(self):
         return self.V / self.N
 
@@ -567,14 +567,14 @@ class LMDMeter:
         writer.add_scalar(os.path.join(prefix, f"LMD ({self.backend})"), self.measure(), global_step)
 
     def report(self):
-        return f'LMD ({self.backend}) = {self.measure():.6f}'
-    
+        return f'LMD ({self.backend}) = {self.measure():.6f}'
+
 
 class Trainer(object):
-    def __init__(self, 
+    def __init__(self,
                  name, # name of this experiment
                  opt, # extra conf
-                 model, # network 
+                 model, # network
                  criterion=None, # loss function, if None, assume inline implementation in train_step
                  optimizer=None, # optimizer
                  ema_decay=None, # if use EMA, set the decay
@@ -596,7 +596,7 @@ class Trainer(object):
                  use_tensorboardX=True, # whether to use tensorboard for logging
                  scheduler_update_every_step=False, # whether to call scheduler.step() after every train step
                  ):
-        
+
         self.name = name
         self.opt = opt
         self.mute = mute
@@ -618,7 +618,11 @@ class Trainer(object):
         self.flip_init_lips = self.opt.init_lips
         self.time_stamp = time.strftime("%Y-%m-%d_%H-%M-%S")
         self.scheduler_update_every_step = scheduler_update_every_step
-        self.device = device if device is not None else torch.device(f'cuda:{local_rank}' if torch.cuda.is_available() else 'cpu')
+        self.device = device if device is not None else torch.device(
+            f'cuda:{local_rank}' if torch.cuda.is_available() else (
+                'mps' if (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) else 'cpu'
+            )
+        )
         self.console = Console()
 
         model.to(self.device)

@@ -56,10 +56,8 @@ from ultralight.unet import Model
 from ultralight.audio2feature import Audio2Feature
 from logger import logger
 
-
-device = 'cuda' if torch.cuda.is_available() else 'cpu'
-logger.info('Using {} for inference.'.format(device))
-
+device = "cuda" if torch.cuda.is_available() else ("mps" if (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) else "cpu")
+print('Using {} for inference.'.format(device))
 
 def load_model(opt):
     audio_processor = Audio2Feature()
@@ -44,8 +44,8 @@ from basereal import BaseReal
 from tqdm import tqdm
 from logger import logger
 
-device = 'cuda' if torch.cuda.is_available() else 'cpu'
-logger.info('Using {} for inference.'.format(device))
+device = "cuda" if torch.cuda.is_available() else ("mps" if (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) else "cpu")
+print('Using {} for inference.'.format(device))
 
 def _load(checkpoint_path):
     if device == 'cuda':
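
The `_load` helper visible at the end of this hunk branches on the device string when reading checkpoints; on MPS or CPU, CUDA-saved weights need to be remapped. A hedged sketch of the usual pattern (the full body of `_load` is not shown in this diff, so this is an assumption about its shape):

```python
# Sketch of the usual checkpoint-loading pattern behind a helper like _load;
# the repository's actual implementation may differ.
import torch

def _load(checkpoint_path: str, device: str):
    if device == "cuda":
        return torch.load(checkpoint_path)
    # On MPS or CPU, remap tensors that were saved from a CUDA device.
    return torch.load(checkpoint_path, map_location=torch.device(device))
```
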
| @@ -51,7 +51,7 @@ from logger import logger | @@ -51,7 +51,7 @@ from logger import logger | ||
 def load_model():
     # load model weights
     audio_processor,vae, unet, pe = load_all_model()
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    device = torch.device("cuda" if torch.cuda.is_available() else ("mps" if (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) else "cpu"))
     timesteps = torch.tensor([0], device=device)
     pe = pe.half()
     vae.vae = vae.vae.half()
@@ -64,7 +64,7 @@ def load_avatar(avatar_id):
     #self.video_path = '' #video_path
     #self.bbox_shift = opt.bbox_shift
     avatar_path = f"./data/avatars/{avatar_id}"
-    full_imgs_path = f"{avatar_path}/full_imgs" 
+    full_imgs_path = f"{avatar_path}/full_imgs"
     coords_path = f"{avatar_path}/coords.pkl"
     latents_out_path= f"{avatar_path}/latents.pt"
     video_out_path = f"{avatar_path}/vid_output/"
@@ -74,7 +74,7 @@ def load_avatar(avatar_id):
     # self.avatar_info = {
     #     "avatar_id":self.avatar_id,
     #     "video_path":self.video_path,
-    #     "bbox_shift":self.bbox_shift 
+    #     "bbox_shift":self.bbox_shift
     # }
 
     input_latent_list_cycle = torch.load(latents_out_path) #,weights_only=True

@@ -124,19 +124,19 @@ def __mirror_index(size, index):
     if turn % 2 == 0:
         return res
     else:
-        return size - res - 1 
+        return size - res - 1
 
 @torch.no_grad()
 def inference(render_event,batch_size,input_latent_list_cycle,audio_feat_queue,audio_out_queue,res_frame_queue,
               vae, unet, pe,timesteps): #vae, unet, pe,timesteps
-    
+
     # vae, unet, pe = load_diffusion_model()
     # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     # timesteps = torch.tensor([0], device=device)
     # pe = pe.half()
     # vae.vae = vae.vae.half()
     # unet.model = unet.model.half()
-    
+
     length = len(input_latent_list_cycle)
     index = 0
     count=0
@@ -169,7 +169,7 @@ def inference(render_event,batch_size,input_latent_list_cycle,audio_feat_queue,a
             latent = input_latent_list_cycle[idx]
             latent_batch.append(latent)
         latent_batch = torch.cat(latent_batch, dim=0)
-        
+
         # for i, (whisper_batch,latent_batch) in enumerate(gen):
         audio_feature_batch = torch.from_numpy(whisper_batch)
         audio_feature_batch = audio_feature_batch.to(device=unet.device,
@@ -179,8 +179,8 @@ def inference(render_event,batch_size,input_latent_list_cycle,audio_feat_queue,a
         # print('prepare time:',time.perf_counter()-t)
         # t=time.perf_counter()
 
-        pred_latents = unet.model(latent_batch, 
-                                  timesteps, 
+        pred_latents = unet.model(latent_batch,
+                                  timesteps,
                                   encoder_hidden_states=audio_feature_batch).sample
         # print('unet time:',time.perf_counter()-t)
         # t=time.perf_counter()
@@ -203,7 +203,7 @@ def inference(render_event,batch_size,input_latent_list_cycle,audio_feat_queue,a
         #self.__pushmedia(res_frame,loop,audio_track,video_track)
         res_frame_queue.put((res_frame,__mirror_index(length,index),audio_frames[i*2:i*2+2]))
         index = index + 1
-        #print('total batch time:',time.perf_counter()-starttime) 
+        #print('total batch time:',time.perf_counter()-starttime)
     logger.info('musereal inference processor stop')
 
 class MuseReal(BaseReal):

@@ -226,12 +226,12 @@ class MuseReal(BaseReal):
 
         self.asr = MuseASR(opt,self,self.audio_processor)
         self.asr.warm_up()
-        
+
         self.render_event = mp.Event()
 
     def __del__(self):
         logger.info(f'musereal({self.sessionid}) delete')
-    
+
 
     def __mirror_index(self, index):
         size = len(self.coord_list_cycle)
@@ -240,9 +240,9 @@ class MuseReal(BaseReal):
         if turn % 2 == 0:
             return res
         else:
-            return size - res - 1 
+            return size - res - 1
 
-    def __warm_up(self): 
+    def __warm_up(self):
         self.asr.run_step()
         whisper_chunks = self.asr.get_next_feat()
         whisper_batch = np.stack(whisper_chunks)

@@ -260,30 +260,57 @@ class MuseReal(BaseReal):
         audio_feature_batch = self.pe(audio_feature_batch)
         latent_batch = latent_batch.to(dtype=self.unet.model.dtype)
 
-        pred_latents = self.unet.model(latent_batch, 
-                                       self.timesteps, 
+        pred_latents = self.unet.model(latent_batch,
+                                       self.timesteps,
                                        encoder_hidden_states=audio_feature_batch).sample
         recon = self.vae.decode_latents(pred_latents)
-    
+
 
     def process_frames(self,quit_event,loop=None,audio_track=None,video_track=None):
-        
+        enable_transition = True  # 设置为False禁用过渡效果,True启用
+
+        if enable_transition:
+            self.last_speaking = False
+            self.transition_start = time.time()
+            self.transition_duration = 0.1  # 过渡时间
+            self.last_silent_frame = None  # 静音帧缓存
+            self.last_speaking_frame = None  # 说话帧缓存
+
         while not quit_event.is_set():
             try:
                 res_frame,idx,audio_frames = self.res_frame_queue.get(block=True, timeout=1)
             except queue.Empty:
                 continue
-            if audio_frames[0][1]!=0 and audio_frames[1][1]!=0: #全为静音数据,只需要取fullimg
+
+            if enable_transition:
+                # 检测状态变化
+                current_speaking = not (audio_frames[0][1]!=0 and audio_frames[1][1]!=0)
+                if current_speaking != self.last_speaking:
+                    logger.info(f"状态切换:{'说话' if self.last_speaking else '静音'} → {'说话' if current_speaking else '静音'}")
+                    self.transition_start = time.time()
+                self.last_speaking = current_speaking
+
+            if audio_frames[0][1]!=0 and audio_frames[1][1]!=0:
                 self.speaking = False
                 audiotype = audio_frames[0][1]
-                if self.custom_index.get(audiotype) is not None: #有自定义视频
+                if self.custom_index.get(audiotype) is not None:
                     mirindex = self.mirror_index(len(self.custom_img_cycle[audiotype]),self.custom_index[audiotype])
-                    combine_frame = self.custom_img_cycle[audiotype][mirindex]
+                    target_frame = self.custom_img_cycle[audiotype][mirindex]
                     self.custom_index[audiotype] += 1
-                    # if not self.custom_opt[audiotype].loop and self.custom_index[audiotype]>=len(self.custom_img_cycle[audiotype]):
-                    #     self.curr_state = 1  #当前视频不循环播放,切换到静音状态
                 else:
-                    combine_frame = self.frame_list_cycle[idx]
+                    target_frame = self.frame_list_cycle[idx]
+
+                if enable_transition:
+                    # 说话→静音过渡
+                    if time.time() - self.transition_start < self.transition_duration and self.last_speaking_frame is not None:
+                        alpha = min(1.0, (time.time() - self.transition_start) / self.transition_duration)
+                        combine_frame = cv2.addWeighted(self.last_speaking_frame, 1-alpha, target_frame, alpha, 0)
+                    else:
+                        combine_frame = target_frame
+                    # 缓存静音帧
+                    self.last_silent_frame = combine_frame.copy()
+                else:
+                    combine_frame = target_frame
             else:
                 self.speaking = True
@@ -291,20 +318,29 @@ class MuseReal(BaseReal):
                 x1, y1, x2, y2 = bbox
                 try:
                     res_frame = cv2.resize(res_frame.astype(np.uint8),(x2-x1,y2-y1))
-                except:
+                except Exception as e:
+                    logger.warning(f"resize error: {e}")
                     continue
                 mask = self.mask_list_cycle[idx]
                 mask_crop_box = self.mask_coords_list_cycle[idx]
-                #combine_frame = get_image(ori_frame,res_frame,bbox)
-                #t=time.perf_counter()
-                combine_frame = get_image_blending(ori_frame,res_frame,bbox,mask,mask_crop_box)
-                #print('blending time:',time.perf_counter()-t)
 
-                image = combine_frame #(outputs['image'] * 255).astype(np.uint8)
+                current_frame = get_image_blending(ori_frame,res_frame,bbox,mask,mask_crop_box)
+                if enable_transition:
+                    # 静音→说话过渡
+                    if time.time() - self.transition_start < self.transition_duration and self.last_silent_frame is not None:
+                        alpha = min(1.0, (time.time() - self.transition_start) / self.transition_duration)
+                        combine_frame = cv2.addWeighted(self.last_silent_frame, 1-alpha, current_frame, alpha, 0)
+                    else:
+                        combine_frame = current_frame
+                    # 缓存说话帧
+                    self.last_speaking_frame = combine_frame.copy()
+                else:
+                    combine_frame = current_frame
+
+            image = combine_frame
             new_frame = VideoFrame.from_ndarray(image, format="bgr24")
             asyncio.run_coroutine_threadsafe(video_track._queue.put((new_frame,None)), loop)
             self.record_video_data(image)
-            #self.recordq_video.put(new_frame)
 
             for audio_frame in audio_frames:
                 frame,type,eventpoint = audio_frame
@@ -312,12 +348,8 @@ class MuseReal(BaseReal):
                 new_frame = AudioFrame(format='s16', layout='mono', samples=frame.shape[0])
                 new_frame.planes[0].update(frame.tobytes())
                 new_frame.sample_rate=16000
-                # if audio_track._queue.qsize()>10:
-                #     time.sleep(0.1)
                 asyncio.run_coroutine_threadsafe(audio_track._queue.put((new_frame,eventpoint)), loop)
                 self.record_audio_data(frame)
-                #self.notify(eventpoint)
-                #self.recordq_audio.put(new_frame)
         logger.info('musereal process_frames thread stop')
 
     def render(self,quit_event,loop=None,audio_track=None,video_track=None):
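
The new `enable_transition` path above crossfades between the cached silent frame and the incoming speaking frame (and vice versa) for `transition_duration` seconds using `cv2.addWeighted`. A self-contained illustration of that blend, with synthetic frames in place of the avatar images:

```python
# Standalone demo of the alpha crossfade used in process_frames above; the two
# frames are synthetic stand-ins for the cached silent frame and the new speaking frame.
import time
import cv2
import numpy as np

transition_duration = 0.1                               # seconds, same default as the patch
last_frame = np.zeros((256, 256, 3), np.uint8)          # e.g. the cached silent frame
target_frame = np.full((256, 256, 3), 255, np.uint8)    # e.g. the new speaking frame

transition_start = time.time()
while True:
    alpha = min(1.0, (time.time() - transition_start) / transition_duration)
    blended = cv2.addWeighted(last_frame, 1 - alpha, target_frame, alpha, 0)
    if alpha >= 1.0:
        break
print("crossfade finished; final pixel value:", int(blended[0, 0, 0]))
```
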
| @@ -36,7 +36,7 @@ class UNet(): | @@ -36,7 +36,7 @@ class UNet(): | ||
             unet_config = json.load(f)
         self.model = UNet2DConditionModel(**unet_config)
         self.pe = PositionalEncoding(d_model=384)
-        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.device = torch.device("cuda" if torch.cuda.is_available() else ("mps" if (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) else "cpu"))
         weights = torch.load(model_path) if torch.cuda.is_available() else torch.load(model_path, map_location=self.device)
         self.model.load_state_dict(weights)
         if use_float16:
@@ -23,7 +23,7 @@ class VAE():
         self.model_path = model_path
         self.vae = AutoencoderKL.from_pretrained(self.model_path)
 
-        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.device = torch.device("cuda" if torch.cuda.is_available() else ("mps" if (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) else "cpu"))
         self.vae.to(self.device)
 
         if use_float16:

@@ -325,7 +325,7 @@ def create_musetalk_human(file, avatar_id):
| 325 | 325 | ||
| 326 | 326 | ||
| 327 | # initialize the mmpose model | 327 | # initialize the mmpose model |
| 328 | -device = "cuda" if torch.cuda.is_available() else "cpu" | 328 | +device = "cuda" if torch.cuda.is_available() else ("mps" if (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) else "cpu") |
| 329 | fa = FaceAlignment(1, flip_input=False, device=device) | 329 | fa = FaceAlignment(1, flip_input=False, device=device) |
| 330 | config_file = os.path.join(current_dir, 'utils/dwpose/rtmpose-l_8xb32-270e_coco-ubody-wholebody-384x288.py') | 330 | config_file = os.path.join(current_dir, 'utils/dwpose/rtmpose-l_8xb32-270e_coco-ubody-wholebody-384x288.py') |
| 331 | checkpoint_file = os.path.abspath(os.path.join(current_dir, '../models/dwpose/dw-ll_ucoco_384.pth')) | 331 | checkpoint_file = os.path.abspath(os.path.join(current_dir, '../models/dwpose/dw-ll_ucoco_384.pth')) |
| @@ -13,14 +13,14 @@ import torch | @@ -13,14 +13,14 @@ import torch | ||
| 13 | from tqdm import tqdm | 13 | from tqdm import tqdm |
| 14 | 14 | ||
| 15 | # initialize the mmpose model | 15 | # initialize the mmpose model |
| 16 | -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | 16 | +device = torch.device("cuda" if torch.cuda.is_available() else ("mps" if (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) else "cpu")) |
| 17 | config_file = './musetalk/utils/dwpose/rtmpose-l_8xb32-270e_coco-ubody-wholebody-384x288.py' | 17 | config_file = './musetalk/utils/dwpose/rtmpose-l_8xb32-270e_coco-ubody-wholebody-384x288.py' |
| 18 | checkpoint_file = './models/dwpose/dw-ll_ucoco_384.pth' | 18 | checkpoint_file = './models/dwpose/dw-ll_ucoco_384.pth' |
| 19 | model = init_model(config_file, checkpoint_file, device=device) | 19 | model = init_model(config_file, checkpoint_file, device=device) |
| 20 | 20 | ||
| 21 | # initialize the face detection model | 21 | # initialize the face detection model |
| 22 | -device = "cuda" if torch.cuda.is_available() else "cpu" | ||
| 23 | -fa = FaceAlignment(LandmarksType._2D, flip_input=False,device=device) | 22 | +device = "cuda" if torch.cuda.is_available() else ("mps" if (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) else "cpu") |
| 23 | +fa = FaceAlignment(LandmarksType._2D, flip_input=False, device=device) | ||
| 24 | 24 | ||
| 25 | # maker if the bbox is not sufficient | 25 | # maker if the bbox is not sufficient |
| 26 | coord_placeholder = (0.0,0.0,0.0,0.0) | 26 | coord_placeholder = (0.0,0.0,0.0,0.0) |
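The hunks above and below (UNet, VAE, create_musetalk_human, the dwpose and face-alignment init, NerfASR, and the NeRF loader) all inline the same cuda → mps → cpu selection expression. A small helper capturing the pattern once, as a sketch only (upstream keeps the expression inline rather than introducing a helper like this):

```python
import torch

def pick_device() -> torch.device:
    """Prefer CUDA, then Apple-Silicon MPS, then CPU (illustrative helper)."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")

# usage: device = pick_device(); model.to(device)
```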
| @@ -91,7 +91,7 @@ def load_model(name: str, device: Optional[Union[str, torch.device]] = None, dow | @@ -91,7 +91,7 @@ def load_model(name: str, device: Optional[Union[str, torch.device]] = None, dow | ||
| 91 | """ | 91 | """ |
| 92 | 92 | ||
| 93 | if device is None: | 93 | if device is None: |
| 94 | - device = "cuda" if torch.cuda.is_available() else "cpu" | 94 | + device = "cuda" if torch.cuda.is_available() else ("mps" if (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) else "cpu") |
| 95 | if download_root is None: | 95 | if download_root is None: |
| 96 | download_root = os.getenv( | 96 | download_root = os.getenv( |
| 97 | "XDG_CACHE_HOME", | 97 | "XDG_CACHE_HOME", |
| @@ -78,17 +78,19 @@ def transcribe( | @@ -78,17 +78,19 @@ def transcribe( | ||
| 78 | if dtype == torch.float16: | 78 | if dtype == torch.float16: |
| 79 | warnings.warn("FP16 is not supported on CPU; using FP32 instead") | 79 | warnings.warn("FP16 is not supported on CPU; using FP32 instead") |
| 80 | dtype = torch.float32 | 80 | dtype = torch.float32 |
| 81 | + if hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): | ||
| 82 | + warnings.warn("Performing inference on CPU when MPS is available") | ||
| 81 | 83 | ||
| 82 | if dtype == torch.float32: | 84 | if dtype == torch.float32: |
| 83 | decode_options["fp16"] = False | 85 | decode_options["fp16"] = False |
| 84 | 86 | ||
| 85 | mel = log_mel_spectrogram(audio) | 87 | mel = log_mel_spectrogram(audio) |
| 86 | - | 88 | + |
| 87 | all_segments = [] | 89 | all_segments = [] |
| 88 | def add_segment( | 90 | def add_segment( |
| 89 | *, start: float, end: float, encoder_embeddings | 91 | *, start: float, end: float, encoder_embeddings |
| 90 | ): | 92 | ): |
| 91 | - | 93 | + |
| 92 | all_segments.append( | 94 | all_segments.append( |
| 93 | { | 95 | { |
| 94 | "start": start, | 96 | "start": start, |
| @@ -100,20 +102,20 @@ def transcribe( | @@ -100,20 +102,20 @@ def transcribe( | ||
| 100 | num_frames = mel.shape[-1] | 102 | num_frames = mel.shape[-1] |
| 101 | seek = 0 | 103 | seek = 0 |
| 102 | previous_seek_value = seek | 104 | previous_seek_value = seek |
| 103 | - sample_skip = 3000 # | 105 | + sample_skip = 3000 # |
| 104 | with tqdm.tqdm(total=num_frames, unit='frames', disable=verbose is not False) as pbar: | 106 | with tqdm.tqdm(total=num_frames, unit='frames', disable=verbose is not False) as pbar: |
| 105 | while seek < num_frames: | 107 | while seek < num_frames: |
| 106 | # seek is the starting frame index | 108 | # seek is the starting frame index |
| 107 | end_seek = min(seek + sample_skip, num_frames) | 109 | end_seek = min(seek + sample_skip, num_frames) |
| 108 | segment = pad_or_trim(mel[:,seek:seek+sample_skip], N_FRAMES).to(model.device).to(dtype) | 110 | segment = pad_or_trim(mel[:,seek:seek+sample_skip], N_FRAMES).to(model.device).to(dtype) |
| 109 | - | 111 | + |
| 110 | single = segment.ndim == 2 | 112 | single = segment.ndim == 2 |
| 111 | if single: | 113 | if single: |
| 112 | segment = segment.unsqueeze(0) | 114 | segment = segment.unsqueeze(0) |
| 113 | if dtype == torch.float16: | 115 | if dtype == torch.float16: |
| 114 | segment = segment.half() | 116 | segment = segment.half() |
| 115 | audio_features, embeddings = model.encoder(segment, include_embeddings = True) | 117 | audio_features, embeddings = model.encoder(segment, include_embeddings = True) |
| 116 | - | 118 | + |
| 117 | encoder_embeddings = embeddings | 119 | encoder_embeddings = embeddings |
| 118 | #print(f"encoder_embeddings shape {encoder_embeddings.shape}") | 120 | #print(f"encoder_embeddings shape {encoder_embeddings.shape}") |
| 119 | add_segment( | 121 | add_segment( |
| @@ -124,7 +126,7 @@ def transcribe( | @@ -124,7 +126,7 @@ def transcribe( | ||
| 124 | encoder_embeddings=encoder_embeddings, | 126 | encoder_embeddings=encoder_embeddings, |
| 125 | ) | 127 | ) |
| 126 | seek+=sample_skip | 128 | seek+=sample_skip |
| 127 | - | 129 | + |
| 128 | return dict(segments=all_segments) | 130 | return dict(segments=all_segments) |
| 129 | 131 | ||
| 130 | 132 | ||
| @@ -135,7 +137,7 @@ def cli(): | @@ -135,7 +137,7 @@ def cli(): | ||
| 135 | parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe") | 137 | parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe") |
| 136 | parser.add_argument("--model", default="small", choices=available_models(), help="name of the Whisper model to use") | 138 | parser.add_argument("--model", default="small", choices=available_models(), help="name of the Whisper model to use") |
| 137 | parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default") | 139 | parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default") |
| 138 | - parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference") | 140 | + parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "mps", help="device to use for PyTorch inference") |
| 139 | parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs") | 141 | parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs") |
| 140 | parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages") | 142 | parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages") |
| 141 | 143 |
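Note that the `--device` default in this cli hunk becomes the literal string "mps" whenever CUDA is absent, which may fail on machines without MPS support. A sketch of the same fallback chain applied to the argparse default, mirroring the device expression used in the other hunks (illustrative only, not the upstream change):

```python
import argparse
import torch

# Degrade to CPU when neither CUDA nor MPS is present.
default_device = (
    "cuda" if torch.cuda.is_available()
    else "mps" if hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
    else "cpu"
)
parser = argparse.ArgumentParser()
parser.add_argument("--device", default=default_device,
                    help="device to use for PyTorch inference")
```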
| @@ -30,7 +30,7 @@ class NerfASR(BaseASR): | @@ -30,7 +30,7 @@ class NerfASR(BaseASR): | ||
| 30 | def __init__(self, opt, parent, audio_processor,audio_model): | 30 | def __init__(self, opt, parent, audio_processor,audio_model): |
| 31 | super().__init__(opt,parent) | 31 | super().__init__(opt,parent) |
| 32 | 32 | ||
| 33 | - self.device = 'cuda' if torch.cuda.is_available() else 'cpu' | 33 | + self.device = "cuda" if torch.cuda.is_available() else ("mps" if (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) else "cpu") |
| 34 | if 'esperanto' in self.opt.asr_model: | 34 | if 'esperanto' in self.opt.asr_model: |
| 35 | self.audio_dim = 44 | 35 | self.audio_dim = 44 |
| 36 | elif 'deepspeech' in self.opt.asr_model: | 36 | elif 'deepspeech' in self.opt.asr_model: |
| @@ -77,7 +77,7 @@ def load_model(opt): | @@ -77,7 +77,7 @@ def load_model(opt): | ||
| 77 | seed_everything(opt.seed) | 77 | seed_everything(opt.seed) |
| 78 | logger.info(opt) | 78 | logger.info(opt) |
| 79 | 79 | ||
| 80 | - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | 80 | + device = torch.device('cuda' if torch.cuda.is_available() else ('mps' if (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) else 'cpu')) |
| 81 | model = NeRFNetwork(opt) | 81 | model = NeRFNetwork(opt) |
| 82 | 82 | ||
| 83 | criterion = torch.nn.MSELoss(reduction='none') | 83 | criterion = torch.nn.MSELoss(reduction='none') |
| @@ -90,7 +90,7 @@ class BaseTTS: | @@ -90,7 +90,7 @@ class BaseTTS: | ||
| 90 | ########################################################################################### | 90 | ########################################################################################### |
| 91 | class EdgeTTS(BaseTTS): | 91 | class EdgeTTS(BaseTTS): |
| 92 | def txt_to_audio(self,msg): | 92 | def txt_to_audio(self,msg): |
| 93 | - voicename = "zh-CN-XiaoxiaoNeural" | 93 | + voicename = "zh-CN-YunxiaNeural" |
| 94 | text,textevent = msg | 94 | text,textevent = msg |
| 95 | t = time.time() | 95 | t = time.time() |
| 96 | asyncio.new_event_loop().run_until_complete(self.__main(voicename,text)) | 96 | asyncio.new_event_loop().run_until_complete(self.__main(voicename,text)) |
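The default EdgeTTS voice switches from zh-CN-XiaoxiaoNeural to zh-CN-YunxiaNeural here. To pick a different one, the edge_tts package used by this class can enumerate Microsoft's voice catalogue; a sketch assuming its async `list_voices()` helper and the usual ShortName/Gender/Locale fields:

```python
import asyncio
import edge_tts

async def show_chinese_voices():
    # list_voices() is assumed to return the service's voice catalogue
    for voice in await edge_tts.list_voices():
        if voice["Locale"].startswith("zh-CN"):
            print(voice["ShortName"], voice["Gender"])

asyncio.run(show_chinese_voices())
```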
| @@ -98,7 +98,7 @@ class EdgeTTS(BaseTTS): | @@ -98,7 +98,7 @@ class EdgeTTS(BaseTTS): | ||
| 98 | if self.input_stream.getbuffer().nbytes<=0: #edgetts err | 98 | if self.input_stream.getbuffer().nbytes<=0: #edgetts err |
| 99 | logger.error('edgetts err!!!!!') | 99 | logger.error('edgetts err!!!!!') |
| 100 | return | 100 | return |
| 101 | - | 101 | + |
| 102 | self.input_stream.seek(0) | 102 | self.input_stream.seek(0) |
| 103 | stream = self.__create_bytes_stream(self.input_stream) | 103 | stream = self.__create_bytes_stream(self.input_stream) |
| 104 | streamlen = stream.shape[0] | 104 | streamlen = stream.shape[0] |
| @@ -107,15 +107,15 @@ class EdgeTTS(BaseTTS): | @@ -107,15 +107,15 @@ class EdgeTTS(BaseTTS): | ||
| 107 | eventpoint=None | 107 | eventpoint=None |
| 108 | streamlen -= self.chunk | 108 | streamlen -= self.chunk |
| 109 | if idx==0: | 109 | if idx==0: |
| 110 | - eventpoint={'status':'start','text':text,'msgenvent':textevent} | 110 | + eventpoint={'status':'start','text':text,'msgevent':textevent} |
| 111 | elif streamlen<self.chunk: | 111 | elif streamlen<self.chunk: |
| 112 | - eventpoint={'status':'end','text':text,'msgenvent':textevent} | 112 | + eventpoint={'status':'end','text':text,'msgevent':textevent} |
| 113 | self.parent.put_audio_frame(stream[idx:idx+self.chunk],eventpoint) | 113 | self.parent.put_audio_frame(stream[idx:idx+self.chunk],eventpoint) |
| 114 | idx += self.chunk | 114 | idx += self.chunk |
| 115 | #if streamlen>0: #skip last frame(not 20ms) | 115 | #if streamlen>0: #skip last frame(not 20ms) |
| 116 | # self.queue.put(stream[idx:]) | 116 | # self.queue.put(stream[idx:]) |
| 117 | self.input_stream.seek(0) | 117 | self.input_stream.seek(0) |
| 118 | - self.input_stream.truncate() | 118 | + self.input_stream.truncate() |
| 119 | 119 | ||
| 120 | def __create_bytes_stream(self,byte_stream): | 120 | def __create_bytes_stream(self,byte_stream): |
| 121 | #byte_stream=BytesIO(buffer) | 121 | #byte_stream=BytesIO(buffer) |
| @@ -126,13 +126,13 @@ class EdgeTTS(BaseTTS): | @@ -126,13 +126,13 @@ class EdgeTTS(BaseTTS): | ||
| 126 | if stream.ndim > 1: | 126 | if stream.ndim > 1: |
| 127 | logger.info(f'[WARN] audio has {stream.shape[1]} channels, only use the first.') | 127 | logger.info(f'[WARN] audio has {stream.shape[1]} channels, only use the first.') |
| 128 | stream = stream[:, 0] | 128 | stream = stream[:, 0] |
| 129 | - | 129 | + |
| 130 | if sample_rate != self.sample_rate and stream.shape[0]>0: | 130 | if sample_rate != self.sample_rate and stream.shape[0]>0: |
| 131 | logger.info(f'[WARN] audio sample rate is {sample_rate}, resampling into {self.sample_rate}.') | 131 | logger.info(f'[WARN] audio sample rate is {sample_rate}, resampling into {self.sample_rate}.') |
| 132 | stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self.sample_rate) | 132 | stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self.sample_rate) |
| 133 | 133 | ||
| 134 | return stream | 134 | return stream |
| 135 | - | 135 | + |
| 136 | async def __main(self,voicename: str, text: str): | 136 | async def __main(self,voicename: str, text: str): |
| 137 | try: | 137 | try: |
| 138 | communicate = edge_tts.Communicate(text, voicename) | 138 | communicate = edge_tts.Communicate(text, voicename) |
| @@ -153,12 +153,12 @@ class EdgeTTS(BaseTTS): | @@ -153,12 +153,12 @@ class EdgeTTS(BaseTTS): | ||
| 153 | 153 | ||
| 154 | ########################################################################################### | 154 | ########################################################################################### |
| 155 | class FishTTS(BaseTTS): | 155 | class FishTTS(BaseTTS): |
| 156 | - def txt_to_audio(self,msg): | 156 | + def txt_to_audio(self,msg): |
| 157 | text,textevent = msg | 157 | text,textevent = msg |
| 158 | self.stream_tts( | 158 | self.stream_tts( |
| 159 | self.fish_speech( | 159 | self.fish_speech( |
| 160 | text, | 160 | text, |
| 161 | - self.opt.REF_FILE, | 161 | + self.opt.REF_FILE, |
| 162 | self.opt.REF_TEXT, | 162 | self.opt.REF_TEXT, |
| 163 | "zh", #en args.language, | 163 | "zh", #en args.language, |
| 164 | self.opt.TTS_SERVER, #"http://127.0.0.1:5000", #args.server_url, | 164 | self.opt.TTS_SERVER, #"http://127.0.0.1:5000", #args.server_url, |
| @@ -190,9 +190,9 @@ class FishTTS(BaseTTS): | @@ -190,9 +190,9 @@ class FishTTS(BaseTTS): | ||
| 190 | if res.status_code != 200: | 190 | if res.status_code != 200: |
| 191 | logger.error("Error:%s", res.text) | 191 | logger.error("Error:%s", res.text) |
| 192 | return | 192 | return |
| 193 | - | 193 | + |
| 194 | first = True | 194 | first = True |
| 195 | - | 195 | + |
| 196 | for chunk in res.iter_content(chunk_size=17640): # 1764 44100*20ms*2 | 196 | for chunk in res.iter_content(chunk_size=17640): # 1764 44100*20ms*2 |
| 197 | #print('chunk len:',len(chunk)) | 197 | #print('chunk len:',len(chunk)) |
| 198 | if first: | 198 | if first: |
| @@ -209,7 +209,7 @@ class FishTTS(BaseTTS): | @@ -209,7 +209,7 @@ class FishTTS(BaseTTS): | ||
| 209 | text,textevent = msg | 209 | text,textevent = msg |
| 210 | first = True | 210 | first = True |
| 211 | for chunk in audio_stream: | 211 | for chunk in audio_stream: |
| 212 | - if chunk is not None and len(chunk)>0: | 212 | + if chunk is not None and len(chunk)>0: |
| 213 | stream = np.frombuffer(chunk, dtype=np.int16).astype(np.float32) / 32767 | 213 | stream = np.frombuffer(chunk, dtype=np.int16).astype(np.float32) / 32767 |
| 214 | stream = resampy.resample(x=stream, sr_orig=44100, sr_new=self.sample_rate) | 214 | stream = resampy.resample(x=stream, sr_orig=44100, sr_new=self.sample_rate) |
| 215 | #byte_stream=BytesIO(buffer) | 215 | #byte_stream=BytesIO(buffer) |
| @@ -219,22 +219,22 @@ class FishTTS(BaseTTS): | @@ -219,22 +219,22 @@ class FishTTS(BaseTTS): | ||
| 219 | while streamlen >= self.chunk: | 219 | while streamlen >= self.chunk: |
| 220 | eventpoint=None | 220 | eventpoint=None |
| 221 | if first: | 221 | if first: |
| 222 | - eventpoint={'status':'start','text':text,'msgenvent':textevent} | 222 | + eventpoint={'status':'start','text':text,'msgevent':textevent} |
| 223 | first = False | 223 | first = False |
| 224 | self.parent.put_audio_frame(stream[idx:idx+self.chunk],eventpoint) | 224 | self.parent.put_audio_frame(stream[idx:idx+self.chunk],eventpoint) |
| 225 | streamlen -= self.chunk | 225 | streamlen -= self.chunk |
| 226 | idx += self.chunk | 226 | idx += self.chunk |
| 227 | - eventpoint={'status':'end','text':text,'msgenvent':textevent} | ||
| 228 | - self.parent.put_audio_frame(np.zeros(self.chunk,np.float32),eventpoint) | 227 | + eventpoint={'status':'end','text':text,'msgevent':textevent} |
| 228 | + self.parent.put_audio_frame(np.zeros(self.chunk,np.float32),eventpoint) | ||
| 229 | 229 | ||
| 230 | ########################################################################################### | 230 | ########################################################################################### |
| 231 | -class VoitsTTS(BaseTTS): | ||
| 232 | - def txt_to_audio(self,msg): | 231 | +class SovitsTTS(BaseTTS): |
| 232 | + def txt_to_audio(self,msg): | ||
| 233 | text,textevent = msg | 233 | text,textevent = msg |
| 234 | self.stream_tts( | 234 | self.stream_tts( |
| 235 | self.gpt_sovits( | 235 | self.gpt_sovits( |
| 236 | text, | 236 | text, |
| 237 | - self.opt.REF_FILE, | 237 | + self.opt.REF_FILE, |
| 238 | self.opt.REF_TEXT, | 238 | self.opt.REF_TEXT, |
| 239 | "zh", #en args.language, | 239 | "zh", #en args.language, |
| 240 | self.opt.TTS_SERVER, #"http://127.0.0.1:5000", #args.server_url, | 240 | self.opt.TTS_SERVER, #"http://127.0.0.1:5000", #args.server_url, |
| @@ -271,9 +271,9 @@ class VoitsTTS(BaseTTS): | @@ -271,9 +271,9 @@ class VoitsTTS(BaseTTS): | ||
| 271 | if res.status_code != 200: | 271 | if res.status_code != 200: |
| 272 | logger.error("Error:%s", res.text) | 272 | logger.error("Error:%s", res.text) |
| 273 | return | 273 | return |
| 274 | - | 274 | + |
| 275 | first = True | 275 | first = True |
| 276 | - | 276 | + |
| 277 | for chunk in res.iter_content(chunk_size=None): #12800 1280 32K*20ms*2 | 277 | for chunk in res.iter_content(chunk_size=None): #12800 1280 32K*20ms*2 |
| 278 | logger.info('chunk len:%d',len(chunk)) | 278 | logger.info('chunk len:%d',len(chunk)) |
| 279 | if first: | 279 | if first: |
| @@ -295,7 +295,7 @@ class VoitsTTS(BaseTTS): | @@ -295,7 +295,7 @@ class VoitsTTS(BaseTTS): | ||
| 295 | if stream.ndim > 1: | 295 | if stream.ndim > 1: |
| 296 | logger.info(f'[WARN] audio has {stream.shape[1]} channels, only use the first.') | 296 | logger.info(f'[WARN] audio has {stream.shape[1]} channels, only use the first.') |
| 297 | stream = stream[:, 0] | 297 | stream = stream[:, 0] |
| 298 | - | 298 | + |
| 299 | if sample_rate != self.sample_rate and stream.shape[0]>0: | 299 | if sample_rate != self.sample_rate and stream.shape[0]>0: |
| 300 | logger.info(f'[WARN] audio sample rate is {sample_rate}, resampling into {self.sample_rate}.') | 300 | logger.info(f'[WARN] audio sample rate is {sample_rate}, resampling into {self.sample_rate}.') |
| 301 | stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self.sample_rate) | 301 | stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self.sample_rate) |
| @@ -306,7 +306,7 @@ class VoitsTTS(BaseTTS): | @@ -306,7 +306,7 @@ class VoitsTTS(BaseTTS): | ||
| 306 | text,textevent = msg | 306 | text,textevent = msg |
| 307 | first = True | 307 | first = True |
| 308 | for chunk in audio_stream: | 308 | for chunk in audio_stream: |
| 309 | - if chunk is not None and len(chunk)>0: | 309 | + if chunk is not None and len(chunk)>0: |
| 310 | #stream = np.frombuffer(chunk, dtype=np.int16).astype(np.float32) / 32767 | 310 | #stream = np.frombuffer(chunk, dtype=np.int16).astype(np.float32) / 32767 |
| 311 | #stream = resampy.resample(x=stream, sr_orig=32000, sr_new=self.sample_rate) | 311 | #stream = resampy.resample(x=stream, sr_orig=32000, sr_new=self.sample_rate) |
| 312 | byte_stream=BytesIO(chunk) | 312 | byte_stream=BytesIO(chunk) |
| @@ -316,22 +316,22 @@ class VoitsTTS(BaseTTS): | @@ -316,22 +316,22 @@ class VoitsTTS(BaseTTS): | ||
| 316 | while streamlen >= self.chunk: | 316 | while streamlen >= self.chunk: |
| 317 | eventpoint=None | 317 | eventpoint=None |
| 318 | if first: | 318 | if first: |
| 319 | - eventpoint={'status':'start','text':text,'msgenvent':textevent} | 319 | + eventpoint={'status':'start','text':text,'msgevent':textevent} |
| 320 | first = False | 320 | first = False |
| 321 | self.parent.put_audio_frame(stream[idx:idx+self.chunk],eventpoint) | 321 | self.parent.put_audio_frame(stream[idx:idx+self.chunk],eventpoint) |
| 322 | streamlen -= self.chunk | 322 | streamlen -= self.chunk |
| 323 | idx += self.chunk | 323 | idx += self.chunk |
| 324 | - eventpoint={'status':'end','text':text,'msgenvent':textevent} | 324 | + eventpoint={'status':'end','text':text,'msgevent':textevent} |
| 325 | self.parent.put_audio_frame(np.zeros(self.chunk,np.float32),eventpoint) | 325 | self.parent.put_audio_frame(np.zeros(self.chunk,np.float32),eventpoint) |
| 326 | 326 | ||
| 327 | ########################################################################################### | 327 | ########################################################################################### |
| 328 | class CosyVoiceTTS(BaseTTS): | 328 | class CosyVoiceTTS(BaseTTS): |
| 329 | def txt_to_audio(self,msg): | 329 | def txt_to_audio(self,msg): |
| 330 | - text,textevent = msg | 330 | + text,textevent = msg |
| 331 | self.stream_tts( | 331 | self.stream_tts( |
| 332 | self.cosy_voice( | 332 | self.cosy_voice( |
| 333 | text, | 333 | text, |
| 334 | - self.opt.REF_FILE, | 334 | + self.opt.REF_FILE, |
| 335 | self.opt.REF_TEXT, | 335 | self.opt.REF_TEXT, |
| 336 | "zh", #en args.language, | 336 | "zh", #en args.language, |
| 337 | self.opt.TTS_SERVER, #"http://127.0.0.1:5000", #args.server_url, | 337 | self.opt.TTS_SERVER, #"http://127.0.0.1:5000", #args.server_url, |
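Every stream_tts method in this file follows the same contract: convert the provider's PCM to float32, resample it to the session rate, push fixed 20 ms chunks through self.parent.put_audio_frame, and tag the first and last chunk with an eventpoint dict, now keyed 'msgevent' instead of the old misspelling 'msgenvent'. A condensed sketch of that shared loop (the name emit_chunks is illustrative):

```python
import numpy as np

def emit_chunks(stream: np.ndarray, chunk: int, text, textevent, put_audio_frame):
    """Slice float32 PCM into fixed-size chunks and mark start/end events,
    condensed from the stream_tts methods above."""
    idx, first = 0, True
    while stream.shape[0] - idx >= chunk:
        eventpoint = None
        if first:
            eventpoint = {'status': 'start', 'text': text, 'msgevent': textevent}
            first = False
        put_audio_frame(stream[idx:idx + chunk], eventpoint)
        idx += chunk
    # a trailing silent frame carries the end marker, as in the diffs above
    put_audio_frame(np.zeros(chunk, np.float32),
                    {'status': 'end', 'text': text, 'msgevent': textevent})
```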
| @@ -348,16 +348,16 @@ class CosyVoiceTTS(BaseTTS): | @@ -348,16 +348,16 @@ class CosyVoiceTTS(BaseTTS): | ||
| 348 | try: | 348 | try: |
| 349 | files = [('prompt_wav', ('prompt_wav', open(reffile, 'rb'), 'application/octet-stream'))] | 349 | files = [('prompt_wav', ('prompt_wav', open(reffile, 'rb'), 'application/octet-stream'))] |
| 350 | res = requests.request("GET", f"{server_url}/inference_zero_shot", data=payload, files=files, stream=True) | 350 | res = requests.request("GET", f"{server_url}/inference_zero_shot", data=payload, files=files, stream=True) |
| 351 | - | 351 | + |
| 352 | end = time.perf_counter() | 352 | end = time.perf_counter() |
| 353 | logger.info(f"cosy_voice Time to make POST: {end-start}s") | 353 | logger.info(f"cosy_voice Time to make POST: {end-start}s") |
| 354 | 354 | ||
| 355 | if res.status_code != 200: | 355 | if res.status_code != 200: |
| 356 | logger.error("Error:%s", res.text) | 356 | logger.error("Error:%s", res.text) |
| 357 | return | 357 | return |
| 358 | - | 358 | + |
| 359 | first = True | 359 | first = True |
| 360 | - | 360 | + |
| 361 | for chunk in res.iter_content(chunk_size=9600): # 960 24K*20ms*2 | 361 | for chunk in res.iter_content(chunk_size=9600): # 960 24K*20ms*2 |
| 362 | if first: | 362 | if first: |
| 363 | end = time.perf_counter() | 363 | end = time.perf_counter() |
| @@ -372,7 +372,7 @@ class CosyVoiceTTS(BaseTTS): | @@ -372,7 +372,7 @@ class CosyVoiceTTS(BaseTTS): | ||
| 372 | text,textevent = msg | 372 | text,textevent = msg |
| 373 | first = True | 373 | first = True |
| 374 | for chunk in audio_stream: | 374 | for chunk in audio_stream: |
| 375 | - if chunk is not None and len(chunk)>0: | 375 | + if chunk is not None and len(chunk)>0: |
| 376 | stream = np.frombuffer(chunk, dtype=np.int16).astype(np.float32) / 32767 | 376 | stream = np.frombuffer(chunk, dtype=np.int16).astype(np.float32) / 32767 |
| 377 | stream = resampy.resample(x=stream, sr_orig=24000, sr_new=self.sample_rate) | 377 | stream = resampy.resample(x=stream, sr_orig=24000, sr_new=self.sample_rate) |
| 378 | #byte_stream=BytesIO(buffer) | 378 | #byte_stream=BytesIO(buffer) |
| @@ -382,13 +382,13 @@ class CosyVoiceTTS(BaseTTS): | @@ -382,13 +382,13 @@ class CosyVoiceTTS(BaseTTS): | ||
| 382 | while streamlen >= self.chunk: | 382 | while streamlen >= self.chunk: |
| 383 | eventpoint=None | 383 | eventpoint=None |
| 384 | if first: | 384 | if first: |
| 385 | - eventpoint={'status':'start','text':text,'msgenvent':textevent} | 385 | + eventpoint={'status':'start','text':text,'msgevent':textevent} |
| 386 | first = False | 386 | first = False |
| 387 | self.parent.put_audio_frame(stream[idx:idx+self.chunk],eventpoint) | 387 | self.parent.put_audio_frame(stream[idx:idx+self.chunk],eventpoint) |
| 388 | streamlen -= self.chunk | 388 | streamlen -= self.chunk |
| 389 | idx += self.chunk | 389 | idx += self.chunk |
| 390 | - eventpoint={'status':'end','text':text,'msgenvent':textevent} | ||
| 391 | - self.parent.put_audio_frame(np.zeros(self.chunk,np.float32),eventpoint) | 390 | + eventpoint={'status':'end','text':text,'msgevent':textevent} |
| 391 | + self.parent.put_audio_frame(np.zeros(self.chunk,np.float32),eventpoint) | ||
| 392 | 392 | ||
| 393 | ########################################################################################### | 393 | ########################################################################################### |
| 394 | _PROTOCOL = "https://" | 394 | _PROTOCOL = "https://" |
| @@ -407,7 +407,7 @@ class TencentTTS(BaseTTS): | @@ -407,7 +407,7 @@ class TencentTTS(BaseTTS): | ||
| 407 | self.sample_rate = 16000 | 407 | self.sample_rate = 16000 |
| 408 | self.volume = 0 | 408 | self.volume = 0 |
| 409 | self.speed = 0 | 409 | self.speed = 0 |
| 410 | - | 410 | + |
| 411 | def __gen_signature(self, params): | 411 | def __gen_signature(self, params): |
| 412 | sort_dict = sorted(params.keys()) | 412 | sort_dict = sorted(params.keys()) |
| 413 | sign_str = "POST" + _HOST + _PATH + "?" | 413 | sign_str = "POST" + _HOST + _PATH + "?" |
| @@ -440,11 +440,11 @@ class TencentTTS(BaseTTS): | @@ -440,11 +440,11 @@ class TencentTTS(BaseTTS): | ||
| 440 | return params | 440 | return params |
| 441 | 441 | ||
| 442 | def txt_to_audio(self,msg): | 442 | def txt_to_audio(self,msg): |
| 443 | - text,textevent = msg | 443 | + text,textevent = msg |
| 444 | self.stream_tts( | 444 | self.stream_tts( |
| 445 | self.tencent_voice( | 445 | self.tencent_voice( |
| 446 | text, | 446 | text, |
| 447 | - self.opt.REF_FILE, | 447 | + self.opt.REF_FILE, |
| 448 | self.opt.REF_TEXT, | 448 | self.opt.REF_TEXT, |
| 449 | "zh", #en args.language, | 449 | "zh", #en args.language, |
| 450 | self.opt.TTS_SERVER, #"http://127.0.0.1:5000", #args.server_url, | 450 | self.opt.TTS_SERVER, #"http://127.0.0.1:5000", #args.server_url, |
| @@ -465,12 +465,12 @@ class TencentTTS(BaseTTS): | @@ -465,12 +465,12 @@ class TencentTTS(BaseTTS): | ||
| 465 | try: | 465 | try: |
| 466 | res = requests.post(url, headers=headers, | 466 | res = requests.post(url, headers=headers, |
| 467 | data=json.dumps(params), stream=True) | 467 | data=json.dumps(params), stream=True) |
| 468 | - | 468 | + |
| 469 | end = time.perf_counter() | 469 | end = time.perf_counter() |
| 470 | logger.info(f"tencent Time to make POST: {end-start}s") | 470 | logger.info(f"tencent Time to make POST: {end-start}s") |
| 471 | - | 471 | + |
| 472 | first = True | 472 | first = True |
| 473 | - | 473 | + |
| 474 | for chunk in res.iter_content(chunk_size=6400): # 640 16K*20ms*2 | 474 | for chunk in res.iter_content(chunk_size=6400): # 640 16K*20ms*2 |
| 475 | #logger.info('chunk len:%d',len(chunk)) | 475 | #logger.info('chunk len:%d',len(chunk)) |
| 476 | if first: | 476 | if first: |
| @@ -483,7 +483,7 @@ class TencentTTS(BaseTTS): | @@ -483,7 +483,7 @@ class TencentTTS(BaseTTS): | ||
| 483 | except: | 483 | except: |
| 484 | end = time.perf_counter() | 484 | end = time.perf_counter() |
| 485 | logger.info(f"tencent Time to first chunk: {end-start}s") | 485 | logger.info(f"tencent Time to first chunk: {end-start}s") |
| 486 | - first = False | 486 | + first = False |
| 487 | if chunk and self.state==State.RUNNING: | 487 | if chunk and self.state==State.RUNNING: |
| 488 | yield chunk | 488 | yield chunk |
| 489 | except Exception as e: | 489 | except Exception as e: |
| @@ -494,7 +494,7 @@ class TencentTTS(BaseTTS): | @@ -494,7 +494,7 @@ class TencentTTS(BaseTTS): | ||
| 494 | first = True | 494 | first = True |
| 495 | last_stream = np.array([],dtype=np.float32) | 495 | last_stream = np.array([],dtype=np.float32) |
| 496 | for chunk in audio_stream: | 496 | for chunk in audio_stream: |
| 497 | - if chunk is not None and len(chunk)>0: | 497 | + if chunk is not None and len(chunk)>0: |
| 498 | stream = np.frombuffer(chunk, dtype=np.int16).astype(np.float32) / 32767 | 498 | stream = np.frombuffer(chunk, dtype=np.int16).astype(np.float32) / 32767 |
| 499 | stream = np.concatenate((last_stream,stream)) | 499 | stream = np.concatenate((last_stream,stream)) |
| 500 | #stream = resampy.resample(x=stream, sr_orig=24000, sr_new=self.sample_rate) | 500 | #stream = resampy.resample(x=stream, sr_orig=24000, sr_new=self.sample_rate) |
| @@ -505,14 +505,14 @@ class TencentTTS(BaseTTS): | @@ -505,14 +505,14 @@ class TencentTTS(BaseTTS): | ||
| 505 | while streamlen >= self.chunk: | 505 | while streamlen >= self.chunk: |
| 506 | eventpoint=None | 506 | eventpoint=None |
| 507 | if first: | 507 | if first: |
| 508 | - eventpoint={'status':'start','text':text,'msgenvent':textevent} | 508 | + eventpoint={'status':'start','text':text,'msgevent':textevent} |
| 509 | first = False | 509 | first = False |
| 510 | self.parent.put_audio_frame(stream[idx:idx+self.chunk],eventpoint) | 510 | self.parent.put_audio_frame(stream[idx:idx+self.chunk],eventpoint) |
| 511 | streamlen -= self.chunk | 511 | streamlen -= self.chunk |
| 512 | idx += self.chunk | 512 | idx += self.chunk |
| 513 | last_stream = stream[idx:] #get the remain stream | 513 | last_stream = stream[idx:] #get the remain stream |
| 514 | - eventpoint={'status':'end','text':text,'msgenvent':textevent} | ||
| 515 | - self.parent.put_audio_frame(np.zeros(self.chunk,np.float32),eventpoint) | 514 | + eventpoint={'status':'end','text':text,'msgevent':textevent} |
| 515 | + self.parent.put_audio_frame(np.zeros(self.chunk,np.float32),eventpoint) | ||
| 516 | 516 | ||
| 517 | ########################################################################################### | 517 | ########################################################################################### |
| 518 | 518 | ||
| @@ -522,7 +522,7 @@ class XTTS(BaseTTS): | @@ -522,7 +522,7 @@ class XTTS(BaseTTS): | ||
| 522 | self.speaker = self.get_speaker(opt.REF_FILE, opt.TTS_SERVER) | 522 | self.speaker = self.get_speaker(opt.REF_FILE, opt.TTS_SERVER) |
| 523 | 523 | ||
| 524 | def txt_to_audio(self,msg): | 524 | def txt_to_audio(self,msg): |
| 525 | - text,textevent = msg | 525 | + text,textevent = msg |
| 526 | self.stream_tts( | 526 | self.stream_tts( |
| 527 | self.xtts( | 527 | self.xtts( |
| 528 | text, | 528 | text, |
| @@ -558,7 +558,7 @@ class XTTS(BaseTTS): | @@ -558,7 +558,7 @@ class XTTS(BaseTTS): | ||
| 558 | return | 558 | return |
| 559 | 559 | ||
| 560 | first = True | 560 | first = True |
| 561 | - | 561 | + |
| 562 | for chunk in res.iter_content(chunk_size=9600): #24K*20ms*2 | 562 | for chunk in res.iter_content(chunk_size=9600): #24K*20ms*2 |
| 563 | if first: | 563 | if first: |
| 564 | end = time.perf_counter() | 564 | end = time.perf_counter() |
| @@ -568,12 +568,12 @@ class XTTS(BaseTTS): | @@ -568,12 +568,12 @@ class XTTS(BaseTTS): | ||
| 568 | yield chunk | 568 | yield chunk |
| 569 | except Exception as e: | 569 | except Exception as e: |
| 570 | print(e) | 570 | print(e) |
| 571 | - | 571 | + |
| 572 | def stream_tts(self,audio_stream,msg): | 572 | def stream_tts(self,audio_stream,msg): |
| 573 | text,textevent = msg | 573 | text,textevent = msg |
| 574 | first = True | 574 | first = True |
| 575 | for chunk in audio_stream: | 575 | for chunk in audio_stream: |
| 576 | - if chunk is not None and len(chunk)>0: | 576 | + if chunk is not None and len(chunk)>0: |
| 577 | stream = np.frombuffer(chunk, dtype=np.int16).astype(np.float32) / 32767 | 577 | stream = np.frombuffer(chunk, dtype=np.int16).astype(np.float32) / 32767 |
| 578 | stream = resampy.resample(x=stream, sr_orig=24000, sr_new=self.sample_rate) | 578 | stream = resampy.resample(x=stream, sr_orig=24000, sr_new=self.sample_rate) |
| 579 | #byte_stream=BytesIO(buffer) | 579 | #byte_stream=BytesIO(buffer) |
| @@ -583,10 +583,10 @@ class XTTS(BaseTTS): | @@ -583,10 +583,10 @@ class XTTS(BaseTTS): | ||
| 583 | while streamlen >= self.chunk: | 583 | while streamlen >= self.chunk: |
| 584 | eventpoint=None | 584 | eventpoint=None |
| 585 | if first: | 585 | if first: |
| 586 | - eventpoint={'status':'start','text':text,'msgenvent':textevent} | 586 | + eventpoint={'status':'start','text':text,'msgevent':textevent} |
| 587 | first = False | 587 | first = False |
| 588 | self.parent.put_audio_frame(stream[idx:idx+self.chunk],eventpoint) | 588 | self.parent.put_audio_frame(stream[idx:idx+self.chunk],eventpoint) |
| 589 | streamlen -= self.chunk | 589 | streamlen -= self.chunk |
| 590 | idx += self.chunk | 590 | idx += self.chunk |
| 591 | - eventpoint={'status':'end','text':text,'msgenvent':textevent} | 591 | + eventpoint={'status':'end','text':text,'msgevent':textevent} |
| 592 | self.parent.put_audio_frame(np.zeros(self.chunk,np.float32),eventpoint) | 592 | self.parent.put_audio_frame(np.zeros(self.chunk,np.float32),eventpoint) |
| @@ -236,7 +236,7 @@ if __name__ == '__main__': | @@ -236,7 +236,7 @@ if __name__ == '__main__': | ||
| 236 | if hasattr(module, 'reparameterize'): | 236 | if hasattr(module, 'reparameterize'): |
| 237 | module.reparameterize() | 237 | module.reparameterize() |
| 238 | return model | 238 | return model |
| 239 | - device = torch.device("cuda") | 239 | + device = torch.device("cuda" if torch.cuda.is_available() else ("mps" if (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) else "cpu")) |
| 240 | def check_onnx(torch_out, torch_in, audio): | 240 | def check_onnx(torch_out, torch_in, audio): |
| 241 | onnx_model = onnx.load(onnx_path) | 241 | onnx_model = onnx.load(onnx_path) |
| 242 | onnx.checker.check_model(onnx_model) | 242 | onnx.checker.check_model(onnx_model) |
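The new web/dashboard.html below drives the server through the two JSON endpoints its scripts call: POST /human (type 'echo' to read text aloud, 'chat' to send it through the dialogue pipeline) and POST /record (start_record / end_record). The same calls from Python, as a sketch that assumes the server listens on 127.0.0.1:8010 with session id 0 (adjust both to your deployment):

```python
import requests

BASE = "http://127.0.0.1:8010"  # assumed listen address; change to match your server
SESSION = 0

# Ask the avatar to read a sentence aloud (same payload as the dashboard's echo form)
requests.post(f"{BASE}/human", json={
    "text": "你好,欢迎体验数字人。",
    "type": "echo",
    "interrupt": True,
    "sessionid": SESSION,
})

# Start and stop server-side recording (same payloads as the record buttons)
requests.post(f"{BASE}/record", json={"type": "start_record", "sessionid": SESSION})
requests.post(f"{BASE}/record", json={"type": "end_record", "sessionid": SESSION})
```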
web/dashboard.html
0 → 100644
| 1 | +<!DOCTYPE html> | ||
| 2 | +<html lang="zh-CN"> | ||
| 3 | +<head> | ||
| 4 | + <meta charset="UTF-8"> | ||
| 5 | + <meta name="viewport" content="width=device-width, initial-scale=1.0"> | ||
| 6 | + <title>livetalking数字人交互平台</title> | ||
| 7 | + <link href="https://cdn.jsdelivr.net/npm/bootstrap@5.3.0/dist/css/bootstrap.min.css" rel="stylesheet"> | ||
| 8 | + <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/bootstrap-icons@1.10.0/font/bootstrap-icons.css"> | ||
| 9 | + <style> | ||
| 10 | + :root { | ||
| 11 | + --primary-color: #4361ee; | ||
| 12 | + --secondary-color: #3f37c9; | ||
| 13 | + --accent-color: #4895ef; | ||
| 14 | + --background-color: #f8f9fa; | ||
| 15 | + --card-bg: #ffffff; | ||
| 16 | + --text-color: #212529; | ||
| 17 | + --border-radius: 10px; | ||
| 18 | + --box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1); | ||
| 19 | + } | ||
| 20 | + | ||
| 21 | + body { | ||
| 22 | + font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; | ||
| 23 | + background-color: var(--background-color); | ||
| 24 | + color: var(--text-color); | ||
| 25 | + min-height: 100vh; | ||
| 26 | + padding-top: 20px; | ||
| 27 | + } | ||
| 28 | + | ||
| 29 | + .dashboard-container { | ||
| 30 | + max-width: 1400px; | ||
| 31 | + margin: 0 auto; | ||
| 32 | + padding: 20px; | ||
| 33 | + } | ||
| 34 | + | ||
| 35 | + .card { | ||
| 36 | + background-color: var(--card-bg); | ||
| 37 | + border-radius: var(--border-radius); | ||
| 38 | + box-shadow: var(--box-shadow); | ||
| 39 | + border: none; | ||
| 40 | + margin-bottom: 20px; | ||
| 41 | + overflow: hidden; | ||
| 42 | + } | ||
| 43 | + | ||
| 44 | + .card-header { | ||
| 45 | + background-color: var(--primary-color); | ||
| 46 | + color: white; | ||
| 47 | + font-weight: 600; | ||
| 48 | + padding: 15px 20px; | ||
| 49 | + border-bottom: none; | ||
| 50 | + } | ||
| 51 | + | ||
| 52 | + .video-container { | ||
| 53 | + position: relative; | ||
| 54 | + width: 100%; | ||
| 55 | + background-color: #000; | ||
| 56 | + border-radius: var(--border-radius); | ||
| 57 | + overflow: hidden; | ||
| 58 | + display: flex; | ||
| 59 | + justify-content: center; | ||
| 60 | + align-items: center; | ||
| 61 | + } | ||
| 62 | + | ||
| 63 | + video { | ||
| 64 | + max-width: 100%; | ||
| 65 | + max-height: 100%; | ||
| 66 | + display: block; | ||
| 67 | + border-radius: var(--border-radius); | ||
| 68 | + } | ||
| 69 | + | ||
| 70 | + .controls-container { | ||
| 71 | + padding: 20px; | ||
| 72 | + } | ||
| 73 | + | ||
| 74 | + .btn-primary { | ||
| 75 | + background-color: var(--primary-color); | ||
| 76 | + border-color: var(--primary-color); | ||
| 77 | + } | ||
| 78 | + | ||
| 79 | + .btn-primary:hover { | ||
| 80 | + background-color: var(--secondary-color); | ||
| 81 | + border-color: var(--secondary-color); | ||
| 82 | + } | ||
| 83 | + | ||
| 84 | + .btn-outline-primary { | ||
| 85 | + color: var(--primary-color); | ||
| 86 | + border-color: var(--primary-color); | ||
| 87 | + } | ||
| 88 | + | ||
| 89 | + .btn-outline-primary:hover { | ||
| 90 | + background-color: var(--primary-color); | ||
| 91 | + color: white; | ||
| 92 | + } | ||
| 93 | + | ||
| 94 | + .form-control { | ||
| 95 | + border-radius: var(--border-radius); | ||
| 96 | + padding: 10px 15px; | ||
| 97 | + border: 1px solid #ced4da; | ||
| 98 | + } | ||
| 99 | + | ||
| 100 | + .form-control:focus { | ||
| 101 | + border-color: var(--accent-color); | ||
| 102 | + box-shadow: 0 0 0 0.25rem rgba(67, 97, 238, 0.25); | ||
| 103 | + } | ||
| 104 | + | ||
| 105 | + .status-indicator { | ||
| 106 | + width: 10px; | ||
| 107 | + height: 10px; | ||
| 108 | + border-radius: 50%; | ||
| 109 | + display: inline-block; | ||
| 110 | + margin-right: 5px; | ||
| 111 | + } | ||
| 112 | + | ||
| 113 | + .status-connected { | ||
| 114 | + background-color: #28a745; | ||
| 115 | + } | ||
| 116 | + | ||
| 117 | + .status-disconnected { | ||
| 118 | + background-color: #dc3545; | ||
| 119 | + } | ||
| 120 | + | ||
| 121 | + .status-connecting { | ||
| 122 | + background-color: #ffc107; | ||
| 123 | + } | ||
| 124 | + | ||
| 125 | + .asr-container { | ||
| 126 | + height: 300px; | ||
| 127 | + overflow-y: auto; | ||
| 128 | + padding: 15px; | ||
| 129 | + background-color: #f8f9fa; | ||
| 130 | + border-radius: var(--border-radius); | ||
| 131 | + border: 1px solid #ced4da; | ||
| 132 | + } | ||
| 133 | + | ||
| 134 | + .asr-text { | ||
| 135 | + margin-bottom: 10px; | ||
| 136 | + padding: 10px; | ||
| 137 | + background-color: white; | ||
| 138 | + border-radius: var(--border-radius); | ||
| 139 | + box-shadow: 0 1px 3px rgba(0, 0, 0, 0.1); | ||
| 140 | + } | ||
| 141 | + | ||
| 142 | + .user-message { | ||
| 143 | + background-color: #e3f2fd; | ||
| 144 | + border-left: 4px solid var(--primary-color); | ||
| 145 | + } | ||
| 146 | + | ||
| 147 | + .system-message { | ||
| 148 | + background-color: #f1f8e9; | ||
| 149 | + border-left: 4px solid #8bc34a; | ||
| 150 | + } | ||
| 151 | + | ||
| 152 | + .recording-indicator { | ||
| 153 | + position: absolute; | ||
| 154 | + top: 15px; | ||
| 155 | + right: 15px; | ||
| 156 | + background-color: rgba(220, 53, 69, 0.8); | ||
| 157 | + color: white; | ||
| 158 | + padding: 5px 10px; | ||
| 159 | + border-radius: 20px; | ||
| 160 | + font-size: 0.8rem; | ||
| 161 | + display: none; | ||
| 162 | + } | ||
| 163 | + | ||
| 164 | + .recording-indicator.active { | ||
| 165 | + display: flex; | ||
| 166 | + align-items: center; | ||
| 167 | + } | ||
| 168 | + | ||
| 169 | + .recording-indicator .blink { | ||
| 170 | + width: 10px; | ||
| 171 | + height: 10px; | ||
| 172 | + background-color: #fff; | ||
| 173 | + border-radius: 50%; | ||
| 174 | + margin-right: 5px; | ||
| 175 | + animation: blink 1s infinite; | ||
| 176 | + } | ||
| 177 | + | ||
| 178 | + @keyframes blink { | ||
| 179 | + 0% { opacity: 1; } | ||
| 180 | + 50% { opacity: 0.3; } | ||
| 181 | + 100% { opacity: 1; } | ||
| 182 | + } | ||
| 183 | + | ||
| 184 | + .mode-switch { | ||
| 185 | + margin-bottom: 20px; | ||
| 186 | + } | ||
| 187 | + | ||
| 188 | + .nav-tabs .nav-link { | ||
| 189 | + color: var(--text-color); | ||
| 190 | + border: none; | ||
| 191 | + padding: 10px 20px; | ||
| 192 | + border-radius: var(--border-radius) var(--border-radius) 0 0; | ||
| 193 | + } | ||
| 194 | + | ||
| 195 | + .nav-tabs .nav-link.active { | ||
| 196 | + color: var(--primary-color); | ||
| 197 | + background-color: var(--card-bg); | ||
| 198 | + border-bottom: 3px solid var(--primary-color); | ||
| 199 | + font-weight: 600; | ||
| 200 | + } | ||
| 201 | + | ||
| 202 | + .tab-content { | ||
| 203 | + padding: 20px; | ||
| 204 | + background-color: var(--card-bg); | ||
| 205 | + border-radius: 0 0 var(--border-radius) var(--border-radius); | ||
| 206 | + } | ||
| 207 | + | ||
| 208 | + .settings-panel { | ||
| 209 | + padding: 15px; | ||
| 210 | + background-color: #f8f9fa; | ||
| 211 | + border-radius: var(--border-radius); | ||
| 212 | + margin-top: 15px; | ||
| 213 | + } | ||
| 214 | + | ||
| 215 | + .footer { | ||
| 216 | + text-align: center; | ||
| 217 | + margin-top: 30px; | ||
| 218 | + padding: 20px 0; | ||
| 219 | + color: #6c757d; | ||
| 220 | + font-size: 0.9rem; | ||
| 221 | + } | ||
| 222 | + | ||
| 223 | + .voice-record-btn { | ||
| 224 | + width: 60px; | ||
| 225 | + height: 60px; | ||
| 226 | + border-radius: 50%; | ||
| 227 | + background-color: var(--primary-color); | ||
| 228 | + color: white; | ||
| 229 | + display: flex; | ||
| 230 | + justify-content: center; | ||
| 231 | + align-items: center; | ||
| 232 | + cursor: pointer; | ||
| 233 | + transition: all 0.2s ease; | ||
| 234 | + box-shadow: 0 2px 5px rgba(0,0,0,0.2); | ||
| 235 | + margin: 0 auto; | ||
| 236 | + } | ||
| 237 | + | ||
| 238 | + .voice-record-btn:hover { | ||
| 239 | + background-color: var(--secondary-color); | ||
| 240 | + transform: scale(1.05); | ||
| 241 | + } | ||
| 242 | + | ||
| 243 | + .voice-record-btn:active { | ||
| 244 | + background-color: #dc3545; | ||
| 245 | + transform: scale(0.95); | ||
| 246 | + } | ||
| 247 | + | ||
| 248 | + .voice-record-btn i { | ||
| 249 | + font-size: 24px; | ||
| 250 | + } | ||
| 251 | + | ||
| 252 | + .voice-record-label { | ||
| 253 | + text-align: center; | ||
| 254 | + margin-top: 10px; | ||
| 255 | + font-size: 14px; | ||
| 256 | + color: #6c757d; | ||
| 257 | + } | ||
| 258 | + | ||
| 259 | + .video-size-control { | ||
| 260 | + margin-top: 15px; | ||
| 261 | + } | ||
| 262 | + | ||
| 263 | + .recording-pulse { | ||
| 264 | + animation: pulse 1.5s infinite; | ||
| 265 | + } | ||
| 266 | + | ||
| 267 | + @keyframes pulse { | ||
| 268 | + 0% { | ||
| 269 | + box-shadow: 0 0 0 0 rgba(220, 53, 69, 0.7); | ||
| 270 | + } | ||
| 271 | + 70% { | ||
| 272 | + box-shadow: 0 0 0 15px rgba(220, 53, 69, 0); | ||
| 273 | + } | ||
| 274 | + 100% { | ||
| 275 | + box-shadow: 0 0 0 0 rgba(220, 53, 69, 0); | ||
| 276 | + } | ||
| 277 | + } | ||
| 278 | + </style> | ||
| 279 | +</head> | ||
| 280 | +<body> | ||
| 281 | + <div class="dashboard-container"> | ||
| 282 | + <div class="row"> | ||
| 283 | + <div class="col-12"> | ||
| 284 | + <h1 class="text-center mb-4">livetalking数字人交互平台</h1> | ||
| 285 | + </div> | ||
| 286 | + </div> | ||
| 287 | + | ||
| 288 | + <div class="row"> | ||
| 289 | + <!-- Video area --> | ||
| 290 | + <div class="col-lg-8"> | ||
| 291 | + <div class="card"> | ||
| 292 | + <div class="card-header d-flex justify-content-between align-items-center"> | ||
| 293 | + <div> | ||
| 294 | + <span class="status-indicator status-disconnected" id="connection-status"></span> | ||
| 295 | + <span id="status-text">未连接</span> | ||
| 296 | + </div> | ||
| 297 | + </div> | ||
| 298 | + <div class="card-body p-0"> | ||
| 299 | + <div class="video-container"> | ||
| 300 | + <video id="video" autoplay playsinline></video> | ||
| 301 | + <div class="recording-indicator" id="recording-indicator"> | ||
| 302 | + <div class="blink"></div> | ||
| 303 | + <span>录制中</span> | ||
| 304 | + </div> | ||
| 305 | + </div> | ||
| 306 | + | ||
| 307 | + <div class="controls-container"> | ||
| 308 | + <div class="row"> | ||
| 309 | + <div class="col-md-6 mb-3"> | ||
| 310 | + <button class="btn btn-primary w-100" id="start"> | ||
| 311 | + <i class="bi bi-play-fill"></i> 开始连接 | ||
| 312 | + </button> | ||
| 313 | + <button class="btn btn-danger w-100" id="stop" style="display: none;"> | ||
| 314 | + <i class="bi bi-stop-fill"></i> 停止连接 | ||
| 315 | + </button> | ||
| 316 | + </div> | ||
| 317 | + <div class="col-md-6 mb-3"> | ||
| 318 | + <div class="d-flex"> | ||
| 319 | + <button class="btn btn-outline-primary flex-grow-1 me-2" id="btn_start_record"> | ||
| 320 | + <i class="bi bi-record-fill"></i> 开始录制 | ||
| 321 | + </button> | ||
| 322 | + <button class="btn btn-outline-danger flex-grow-1" id="btn_stop_record" disabled> | ||
| 323 | + <i class="bi bi-stop-fill"></i> 停止录制 | ||
| 324 | + </button> | ||
| 325 | + </div> | ||
| 326 | + </div> | ||
| 327 | + </div> | ||
| 328 | + | ||
| 329 | + <div class="row"> | ||
| 330 | + <div class="col-12"> | ||
| 331 | + <div class="video-size-control"> | ||
| 332 | + <label for="video-size-slider" class="form-label">视频大小调节: <span id="video-size-value">100%</span></label> | ||
| 333 | + <input type="range" class="form-range" id="video-size-slider" min="50" max="150" value="100"> | ||
| 334 | + </div> | ||
| 335 | + </div> | ||
| 336 | + </div> | ||
| 337 | + | ||
| 338 | + <div class="settings-panel mt-3"> | ||
| 339 | + <div class="row"> | ||
| 340 | + <div class="col-md-12"> | ||
| 341 | + <div class="form-check form-switch mb-3"> | ||
| 342 | + <input class="form-check-input" type="checkbox" id="use-stun"> | ||
| 343 | + <label class="form-check-label" for="use-stun">使用STUN服务器</label> | ||
| 344 | + </div> | ||
| 345 | + </div> | ||
| 346 | + </div> | ||
| 347 | + </div> | ||
| 348 | + </div> | ||
| 349 | + </div> | ||
| 350 | + </div> | ||
| 351 | + </div> | ||
| 352 | + | ||
| 353 | + <!-- Interaction panel (right side) --> | ||
| 354 | + <div class="col-lg-4"> | ||
| 355 | + <div class="card"> | ||
| 356 | + <div class="card-header"> | ||
| 357 | + <ul class="nav nav-tabs card-header-tabs" id="interaction-tabs" role="tablist"> | ||
| 358 | + <li class="nav-item" role="presentation"> | ||
| 359 | + <button class="nav-link active" id="chat-tab" data-bs-toggle="tab" data-bs-target="#chat" type="button" role="tab" aria-controls="chat" aria-selected="true">对话模式</button> | ||
| 360 | + </li> | ||
| 361 | + <li class="nav-item" role="presentation"> | ||
| 362 | + <button class="nav-link" id="tts-tab" data-bs-toggle="tab" data-bs-target="#tts" type="button" role="tab" aria-controls="tts" aria-selected="false">朗读模式</button> | ||
| 363 | + </li> | ||
| 364 | + </ul> | ||
| 365 | + </div> | ||
| 366 | + <div class="card-body"> | ||
| 367 | + <div class="tab-content" id="interaction-tabs-content"> | ||
| 368 | + <!-- Chat mode --> | ||
| 369 | + <div class="tab-pane fade show active" id="chat" role="tabpanel" aria-labelledby="chat-tab"> | ||
| 370 | + <div class="asr-container mb-3" id="chat-messages"> | ||
| 371 | + <div class="asr-text system-message"> | ||
| 372 | + 系统: 欢迎使用livetalking,请点击"开始连接"按钮开始对话。 | ||
| 373 | + </div> | ||
| 374 | + </div> | ||
| 375 | + | ||
| 376 | + <form id="chat-form"> | ||
| 377 | + <div class="input-group mb-3"> | ||
| 378 | + <textarea class="form-control" id="chat-message" rows="3" placeholder="输入您想对数字人说的话..."></textarea> | ||
| 379 | + <button class="btn btn-primary" type="submit"> | ||
| 380 | + <i class="bi bi-send"></i> 发送 | ||
| 381 | + </button> | ||
| 382 | + </div> | ||
| 383 | + </form> | ||
| 384 | + | ||
| 385 | + <!-- Push-to-talk button --> | ||
| 386 | + <div class="voice-record-btn" id="voice-record-btn"> | ||
| 387 | + <i class="bi bi-mic-fill"></i> | ||
| 388 | + </div> | ||
| 389 | + <div class="voice-record-label">按住说话,松开发送</div> | ||
| 390 | + </div> | ||
| 391 | + | ||
| 392 | + <!-- Read-aloud mode --> | ||
| 393 | + <div class="tab-pane fade" id="tts" role="tabpanel" aria-labelledby="tts-tab"> | ||
| 394 | + <form id="echo-form"> | ||
| 395 | + <div class="mb-3"> | ||
| 396 | + <label for="message" class="form-label">输入要朗读的文本</label> | ||
| 397 | + <textarea class="form-control" id="message" rows="6" placeholder="输入您想让数字人朗读的文字..."></textarea> | ||
| 398 | + </div> | ||
| 399 | + <button type="submit" class="btn btn-primary w-100"> | ||
| 400 | + <i class="bi bi-volume-up"></i> 朗读文本 | ||
| 401 | + </button> | ||
| 402 | + </form> | ||
| 403 | + </div> | ||
| 404 | + </div> | ||
| 405 | + </div> | ||
| 406 | + </div> | ||
| 407 | + </div> | ||
| 408 | + </div> | ||
| 409 | + | ||
| 410 | + <div class="footer"> | ||
| 411 | + <p>Made with ❤️ by Marstaos | Frontend & Performance Optimization</p> | ||
| 412 | + </div> | ||
| 413 | + </div> | ||
| 414 | + | ||
| 415 | + <!-- Hidden session id --> | ||
| 416 | + <input type="hidden" id="sessionid" value="0"> | ||
| 417 | + | ||
| 418 | + | ||
| 419 | + <script src="client.js"></script> | ||
| 420 | + <script src="srs.sdk.js"></script> | ||
| 421 | + <script src="https://cdn.jsdelivr.net/npm/bootstrap@5.3.0/dist/js/bootstrap.bundle.min.js"></script> | ||
| 422 | + <script src="https://code.jquery.com/jquery-3.6.0.min.js"></script> | ||
| 423 | + <script> | ||
| 424 | + $(document).ready(function() { | ||
| 425 | + $('#video-size-slider').on('input', function() { | ||
| 426 | + const value = $(this).val(); | ||
| 427 | + $('#video-size-value').text(value + '%'); | ||
| 428 | + $('#video').css('width', value + '%'); | ||
| 429 | + }); | ||
| 430 | + function updateConnectionStatus(status) { | ||
| 431 | + const statusIndicator = $('#connection-status'); | ||
| 432 | + const statusText = $('#status-text'); | ||
| 433 | + | ||
| 434 | + statusIndicator.removeClass('status-connected status-disconnected status-connecting'); | ||
| 435 | + | ||
| 436 | + switch(status) { | ||
| 437 | + case 'connected': | ||
| 438 | + statusIndicator.addClass('status-connected'); | ||
| 439 | + statusText.text('已连接'); | ||
| 440 | + break; | ||
| 441 | + case 'connecting': | ||
| 442 | + statusIndicator.addClass('status-connecting'); | ||
| 443 | + statusText.text('连接中...'); | ||
| 444 | + break; | ||
| 445 | + case 'disconnected': | ||
| 446 | + default: | ||
| 447 | + statusIndicator.addClass('status-disconnected'); | ||
| 448 | + statusText.text('未连接'); | ||
| 449 | + break; | ||
| 450 | + } | ||
| 451 | + } | ||
| 452 | + | ||
| 453 | + // append a chat message to the panel | ||
| 454 | + function addChatMessage(message, type = 'user') { | ||
| 455 | + const messagesContainer = $('#chat-messages'); | ||
| 456 | + const messageClass = type === 'user' ? 'user-message' : 'system-message'; | ||
| 457 | + const sender = type === 'user' ? '您' : '数字人'; | ||
| 458 | + | ||
| 459 | + const messageElement = $(` | ||
| 460 | + <div class="asr-text ${messageClass}"> | ||
| 461 | + ${sender}: ${message} | ||
| 462 | + </div> | ||
| 463 | + `); | ||
| 464 | + | ||
| 465 | + messagesContainer.append(messageElement); | ||
| 466 | + messagesContainer.scrollTop(messagesContainer[0].scrollHeight); | ||
| 467 | + } | ||
| 468 | + | ||
| 469 | + // start/stop buttons | ||
| 470 | + $('#start').click(function() { | ||
| 471 | + updateConnectionStatus('connecting'); | ||
| 472 | + start(); | ||
| 473 | + $(this).hide(); | ||
| 474 | + $('#stop').show(); | ||
| 475 | + | ||
| 476 | + // poll until the video stream has loaded | ||
| 477 | + let connectionCheckTimer = setInterval(function() { | ||
| 478 | + const video = document.getElementById('video'); | ||
| 479 | + // check whether the video element has data | ||
| 480 | + if (video.readyState >= 3 && video.videoWidth > 0) { | ||
| 481 | + updateConnectionStatus('connected'); | ||
| 482 | + clearInterval(connectionCheckTimer); | ||
| 483 | + } | ||
| 484 | + }, 2000); // check every 2 seconds | ||
| 485 | + | ||
| 486 | + // stop polling after 60 seconds if still in the connecting state | ||
| 487 | + setTimeout(function() { | ||
| 488 | + if (connectionCheckTimer) { | ||
| 489 | + clearInterval(connectionCheckTimer); | ||
| 490 | + } | ||
| 491 | + }, 60000); | ||
| 492 | + }); | ||
| 493 | + | ||
| 494 | + $('#stop').click(function() { | ||
| 495 | + stop(); | ||
| 496 | + $(this).hide(); | ||
| 497 | + $('#start').show(); | ||
| 498 | + updateConnectionStatus('disconnected'); | ||
| 499 | + }); | ||
| 500 | + | ||
| 501 | + // recording controls | ||
| 502 | + $('#btn_start_record').click(function() { | ||
| 503 | + console.log('Starting recording...'); | ||
| 504 | + fetch('/record', { | ||
| 505 | + body: JSON.stringify({ | ||
| 506 | + type: 'start_record', | ||
| 507 | + sessionid: parseInt(document.getElementById('sessionid').value), | ||
| 508 | + }), | ||
| 509 | + headers: { | ||
| 510 | + 'Content-Type': 'application/json' | ||
| 511 | + }, | ||
| 512 | + method: 'POST' | ||
| 513 | + }).then(function(response) { | ||
| 514 | + if (response.ok) { | ||
| 515 | + console.log('Recording started.'); | ||
| 516 | + $('#btn_start_record').prop('disabled', true); | ||
| 517 | + $('#btn_stop_record').prop('disabled', false); | ||
| 518 | + $('#recording-indicator').addClass('active'); | ||
| 519 | + } else { | ||
| 520 | + console.error('Failed to start recording.'); | ||
| 521 | + } | ||
| 522 | + }).catch(function(error) { | ||
| 523 | + console.error('Error:', error); | ||
| 524 | + }); | ||
| 525 | + }); | ||
| 526 | + | ||
| 527 | + $('#btn_stop_record').click(function() { | ||
| 528 | + console.log('Stopping recording...'); | ||
| 529 | + fetch('/record', { | ||
| 530 | + body: JSON.stringify({ | ||
| 531 | + type: 'end_record', | ||
| 532 | + sessionid: parseInt(document.getElementById('sessionid').value), | ||
| 533 | + }), | ||
| 534 | + headers: { | ||
| 535 | + 'Content-Type': 'application/json' | ||
| 536 | + }, | ||
| 537 | + method: 'POST' | ||
| 538 | + }).then(function(response) { | ||
| 539 | + if (response.ok) { | ||
| 540 | + console.log('Recording stopped.'); | ||
| 541 | + $('#btn_start_record').prop('disabled', false); | ||
| 542 | + $('#btn_stop_record').prop('disabled', true); | ||
| 543 | + $('#recording-indicator').removeClass('active'); | ||
| 544 | + } else { | ||
| 545 | + console.error('Failed to stop recording.'); | ||
| 546 | + } | ||
| 547 | + }).catch(function(error) { | ||
| 548 | + console.error('Error:', error); | ||
| 549 | + }); | ||
| 550 | + }); | ||
| 551 | + | ||
| 552 | + $('#echo-form').on('submit', function(e) { | ||
| 553 | + e.preventDefault(); | ||
| 554 | + var message = $('#message').val(); | ||
| 555 | + if (!message.trim()) return; | ||
| 556 | + | ||
| 557 | + console.log('Sending echo message:', message); | ||
| 558 | + | ||
| 559 | + fetch('/human', { | ||
| 560 | + body: JSON.stringify({ | ||
| 561 | + text: message, | ||
| 562 | + type: 'echo', | ||
| 563 | + interrupt: true, | ||
| 564 | + sessionid: parseInt(document.getElementById('sessionid').value), | ||
| 565 | + }), | ||
| 566 | + headers: { | ||
| 567 | + 'Content-Type': 'application/json' | ||
| 568 | + }, | ||
| 569 | + method: 'POST' | ||
| 570 | + }); | ||
| 571 | + | ||
| 572 | + $('#message').val(''); | ||
| 573 | + addChatMessage(`已发送朗读请求: "${message}"`, 'system'); | ||
| 574 | + }); | ||
| 575 | + | ||
| 576 | + // chat-mode form submit | ||
| 577 | + $('#chat-form').on('submit', function(e) { | ||
| 578 | + e.preventDefault(); | ||
| 579 | + var message = $('#chat-message').val(); | ||
| 580 | + if (!message.trim()) return; | ||
| 581 | + | ||
| 582 | + console.log('Sending chat message:', message); | ||
| 583 | + | ||
| 584 | + fetch('/human', { | ||
| 585 | + body: JSON.stringify({ | ||
| 586 | + text: message, | ||
| 587 | + type: 'chat', | ||
| 588 | + interrupt: true, | ||
| 589 | + sessionid: parseInt(document.getElementById('sessionid').value), | ||
| 590 | + }), | ||
| 591 | + headers: { | ||
| 592 | + 'Content-Type': 'application/json' | ||
| 593 | + }, | ||
| 594 | + method: 'POST' | ||
| 595 | + }); | ||
| 596 | + | ||
| 597 | + addChatMessage(message, 'user'); | ||
| 598 | + $('#chat-message').val(''); | ||
| 599 | + }); | ||
| 600 | + | ||
| 601 | + // push-to-talk | ||
| 602 | + let mediaRecorder; | ||
| 603 | + let audioChunks = []; | ||
| 604 | + let isRecording = false; | ||
| 605 | + let recognition; | ||
| 606 | + | ||
| 607 | + // check whether the browser supports speech recognition | ||
| 608 | + const isSpeechRecognitionSupported = 'webkitSpeechRecognition' in window || 'SpeechRecognition' in window; | ||
| 609 | + | ||
| 610 | + if (isSpeechRecognitionSupported) { | ||
| 611 | + recognition = new (window.SpeechRecognition || window.webkitSpeechRecognition)(); | ||
| 612 | + recognition.continuous = true; | ||
| 613 | + recognition.interimResults = true; | ||
| 614 | + recognition.lang = 'zh-CN'; | ||
| 615 | + | ||
| 616 | + recognition.onresult = function(event) { | ||
| 617 | + let interimTranscript = ''; | ||
| 618 | + let finalTranscript = ''; | ||
| 619 | + | ||
| 620 | + for (let i = event.resultIndex; i < event.results.length; ++i) { | ||
| 621 | + if (event.results[i].isFinal) { | ||
| 622 | + finalTranscript += event.results[i][0].transcript; | ||
| 623 | + } else { | ||
| 624 | + interimTranscript += event.results[i][0].transcript; | ||
| 625 | + $('#chat-message').val(interimTranscript); | ||
| 626 | + } | ||
| 627 | + } | ||
| 628 | + | ||
| 629 | + if (finalTranscript) { | ||
| 630 | + $('#chat-message').val(finalTranscript); | ||
| 631 | + } | ||
| 632 | + }; | ||
| 633 | + | ||
| 634 | + recognition.onerror = function(event) { | ||
| 635 | + console.error('Speech recognition error:', event.error); | ||
| 636 | + }; | ||
| 637 | + } | ||
| 638 | + | ||
| 639 | + // Push-to-talk button events | ||
| 640 | + $('#voice-record-btn').on('mousedown touchstart', function(e) { | ||
| 641 | + e.preventDefault(); | ||
| 642 | + startRecording(); | ||
| 643 | + }).on('mouseup mouseleave touchend', function() { | ||
| 644 | + if (isRecording) { | ||
| 645 | + stopRecording(); | ||
| 646 | + } | ||
| 647 | + }); | ||
| 648 | + | ||
| 649 | + // Start recording | ||
| 650 | + function startRecording() { | ||
| 651 | + if (isRecording) return; | ||
| 652 | + | ||
| 653 | + navigator.mediaDevices.getUserMedia({ audio: true }) | ||
| 654 | + .then(function(stream) { | ||
| 655 | + audioChunks = []; | ||
| 656 | + mediaRecorder = new MediaRecorder(stream); | ||
| 657 | + | ||
| 658 | + mediaRecorder.ondataavailable = function(e) { | ||
| 659 | + if (e.data.size > 0) { | ||
| 660 | + audioChunks.push(e.data); | ||
| 661 | + } | ||
| 662 | + }; | ||
| 663 | + | ||
| 664 | + mediaRecorder.start(); | ||
| 665 | + isRecording = true; | ||
| 666 | + | ||
| 667 | + $('#voice-record-btn').addClass('recording-pulse'); | ||
| 668 | + $('#voice-record-btn').css('background-color', '#dc3545'); | ||
| 669 | + | ||
| 670 | + if (recognition) { | ||
| 671 | + recognition.start(); | ||
| 672 | + } | ||
| 673 | + }) | ||
| 674 | + .catch(function(error) { | ||
| 675 | + console.error('Unable to access the microphone:', error); | ||
| 676 | + alert('Unable to access the microphone. Please check the browser permission settings.'); | ||
| 677 | + }); | ||
| 678 | + } | ||
| 679 | + | ||
| 680 | + function stopRecording() { | ||
| 681 | + if (!isRecording) return; | ||
| 682 | + | ||
| 683 | + mediaRecorder.stop(); | ||
| 684 | + isRecording = false; | ||
| 685 | + | ||
| 686 | + // Stop all audio tracks | ||
| 687 | + mediaRecorder.stream.getTracks().forEach(track => track.stop()); | ||
| 688 | + | ||
| 689 | + // Restore the visual feedback | ||
| 690 | + $('#voice-record-btn').removeClass('recording-pulse'); | ||
| 691 | + $('#voice-record-btn').css('background-color', ''); | ||
| 692 | + | ||
| 693 | + // Stop speech recognition | ||
| 694 | + if (recognition) { | ||
| 695 | + recognition.stop(); | ||
| 696 | + } | ||
| 697 | + | ||
| 698 | + // Fetch the recognized text and send it | ||
| 699 | + setTimeout(function() { | ||
| 700 | + const recognizedText = $('#chat-message').val().trim(); | ||
| 701 | + if (recognizedText) { | ||
| 702 | + // Send the recognized text | ||
| 703 | + fetch('/human', { | ||
| 704 | + body: JSON.stringify({ | ||
| 705 | + text: recognizedText, | ||
| 706 | + type: 'chat', | ||
| 707 | + interrupt: true, | ||
| 708 | + sessionid: parseInt(document.getElementById('sessionid').value), | ||
| 709 | + }), | ||
| 710 | + headers: { | ||
| 711 | + 'Content-Type': 'application/json' | ||
| 712 | + }, | ||
| 713 | + method: 'POST' | ||
| 714 | + }); | ||
| 715 | + | ||
| 716 | + addChatMessage(recognizedText, 'user'); | ||
| 717 | + $('#chat-message').val(''); | ||
| 718 | + } | ||
| 719 | + }, 500); | ||
| 720 | + } | ||
| 721 | + | ||
| 722 | + // WebRTC-related hooks | ||
| 723 | + if (typeof window.onWebRTCConnected === 'function') { | ||
| 724 | + const originalOnConnected = window.onWebRTCConnected; | ||
| 725 | + window.onWebRTCConnected = function() { | ||
| 726 | + updateConnectionStatus('connected'); | ||
| 727 | + if (originalOnConnected) originalOnConnected(); | ||
| 728 | + }; | ||
| 729 | + } else { | ||
| 730 | + window.onWebRTCConnected = function() { | ||
| 731 | + updateConnectionStatus('connected'); | ||
| 732 | + }; | ||
| 733 | + } | ||
| 734 | + | ||
| 735 | + // Update the status when the connection is lost | ||
| 736 | + if (typeof window.onWebRTCDisconnected === 'function') { | ||
| 737 | + const originalOnDisconnected = window.onWebRTCDisconnected; | ||
| 738 | + window.onWebRTCDisconnected = function() { | ||
| 739 | + updateConnectionStatus('disconnected'); | ||
| 740 | + if (originalOnDisconnected) originalOnDisconnected(); | ||
| 741 | + }; | ||
| 742 | + } else { | ||
| 743 | + window.onWebRTCDisconnected = function() { | ||
| 744 | + updateConnectionStatus('disconnected'); | ||
| 745 | + }; | ||
| 746 | + } | ||
| 747 | + | ||
| 748 | + // SRS WebRTC playback | ||
| 749 | + var sdk = null; // Global handle, used for cleanup when republishing | ||
| 750 | + | ||
| 751 | + function startPlay() { | ||
| 752 | + // Close any previous connection | ||
| 753 | + if (sdk) { | ||
| 754 | + sdk.close(); | ||
| 755 | + } | ||
| 756 | + | ||
| 757 | + sdk = new SrsRtcWhipWhepAsync(); | ||
| 758 | + $('#video').prop('srcObject', sdk.stream); | ||
| 759 | + | ||
| 760 | + var host = window.location.hostname; | ||
| 761 | + var url = "http://" + host + ":1985/rtc/v1/whep/?app=live&stream=livestream"; | ||
| 762 | + | ||
| 763 | + sdk.play(url).then(function(session) { | ||
| 764 | + console.log('WebRTC playback started, session ID:', session.sessionid); | ||
| 765 | + }).catch(function(reason) { | ||
| 766 | + sdk.close(); | ||
| 767 | + console.error('WebRTC playback failed:', reason); | ||
| 768 | + }); | ||
| 769 | + } | ||
| 770 | + }); | ||
| 771 | + </script> | ||
| 772 | +</body> | ||
| 773 | +</html> |
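
For reference, every request this page sends to the backend goes through the same `/human` endpoint with a JSON body of four fields: `text`, `type` (`echo` to have the avatar read the text aloud, `chat` to route it through the dialogue model), `interrupt` (stop the avatar if it is still speaking), and `sessionid` (read from the page's hidden `sessionid` field). Below is a minimal sketch of driving that endpoint outside the browser; the server address and session id are placeholder assumptions, not values defined in the diff, and must be replaced with those of your own deployment.

```python
import requests

# Placeholder assumptions: adjust the server address and session id to your deployment.
SERVER = "http://127.0.0.1:8010"
SESSION_ID = 0

def send_human(text: str, msg_type: str = "chat", interrupt: bool = True) -> None:
    """Post a message to /human, mirroring the JSON body the page's forms send."""
    payload = {
        "text": text,
        "type": msg_type,        # "echo" = read the text aloud, "chat" = answer via the dialogue model
        "interrupt": interrupt,  # cut the avatar off if it is still speaking
        "sessionid": SESSION_ID,
    }
    resp = requests.post(f"{SERVER}/human", json=payload, timeout=10)
    resp.raise_for_status()

if __name__ == "__main__":
    send_human("Hello", msg_type="echo")
```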