Showing 46 changed files with 4734 additions and 0 deletions
.github/FUNDING.yml
0 → 100644
| 1 | +github: [lipku] |
.gitignore
0 → 100644
.vscode/launch.json
0 → 100644
| 1 | +{ | ||
| 2 | + "version": "0.2.0", | ||
| 3 | + "configurations": [ | ||
| 4 | + { | ||
| 5 | + "name": "Python: Debug with args", | ||
| 6 | + "type": "python", | ||
| 7 | + "request": "launch", | ||
| 8 | + "program": "${workspaceFolder}/app.py", | ||
| 9 | + "args": [ | ||
| 10 | + "--transport", "webrtc", | ||
| 11 | + "--model", "wav2lip", | ||
| 12 | + "--avatar_id", "wav2lip256_avatar1" | ||
| 13 | + ], | ||
| 14 | + "console": "integratedTerminal", | ||
| 15 | + "justMyCode": true | ||
| 16 | + } | ||
| 17 | + ] | ||
| 18 | +} |
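For reference, the debug configuration above is equivalent to this command line:

```bash
python app.py --transport webrtc --model wav2lip --avatar_id wav2lip256_avatar1
```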
Dockerfile
0 → 100644
| 1 | +# Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. | ||
| 2 | +# | ||
| 3 | +# NVIDIA CORPORATION and its licensors retain all intellectual property | ||
| 4 | +# and proprietary rights in and to this software, related documentation | ||
| 5 | +# and any modifications thereto. Any use, reproduction, disclosure or | ||
| 6 | +# distribution of this software and related documentation without an express | ||
| 7 | +# license agreement from NVIDIA CORPORATION is strictly prohibited. | ||
| 8 | + | ||
| 9 | +ARG BASE_IMAGE=nvcr.io/nvidia/cuda:11.6.1-cudnn8-devel-ubuntu20.04 | ||
| 10 | +FROM $BASE_IMAGE | ||
| 11 | + | ||
| 12 | +RUN apt-get update -yq --fix-missing \ | ||
| 13 | + && DEBIAN_FRONTEND=noninteractive apt-get install -yq --no-install-recommends \ | ||
| 14 | + pkg-config \ | ||
| 15 | + wget \ | ||
| 16 | + cmake \ | ||
| 17 | + curl \ | ||
| 18 | + git \ | ||
| 19 | + vim | ||
| 20 | + | ||
| 21 | +#ENV PYTHONDONTWRITEBYTECODE=1 | ||
| 22 | +#ENV PYTHONUNBUFFERED=1 | ||
| 23 | + | ||
| 24 | +# nvidia-container-runtime | ||
| 25 | +#ENV NVIDIA_VISIBLE_DEVICES all | ||
| 26 | +#ENV NVIDIA_DRIVER_CAPABILITIES compute,utility,graphics | ||
| 27 | + | ||
| 28 | +RUN wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh | ||
| 29 | +RUN sh Miniconda3-latest-Linux-x86_64.sh -b -u -p ~/miniconda3 | ||
| 30 | +# conda init/activate do not persist across RUN layers; put conda and the env on PATH instead | ||
| 31 | +ENV PATH=/root/miniconda3/bin:$PATH | ||
| 32 | +RUN conda create -n nerfstream python=3.10 -y | ||
| 33 | +ENV PATH=/root/miniconda3/envs/nerfstream/bin:$PATH | ||
| 34 | + | ||
| 35 | +RUN pip config set global.index-url https://mirrors.aliyun.com/pypi/simple/ | ||
| 36 | +# install depend | ||
| 37 | +RUN conda install -n nerfstream -y pytorch==1.12.1 torchvision==0.13.1 cudatoolkit=11.3 -c pytorch | ||
| 38 | +COPY requirements.txt ./ | ||
| 39 | +RUN pip install -r requirements.txt | ||
| 40 | + | ||
| 41 | +# additional libraries | ||
| 42 | +RUN pip install "git+https://github.com/facebookresearch/pytorch3d.git" | ||
| 43 | +RUN pip install tensorflow-gpu==2.8.0 | ||
| 44 | + | ||
| 45 | +RUN pip uninstall -y protobuf | ||
| 46 | +RUN pip install protobuf==3.20.1 | ||
| 47 | + | ||
| 48 | +RUN conda install -n nerfstream -y ffmpeg | ||
| 49 | +# build from the parent directory (docker build -f nerfstream/Dockerfile .) so both projects are inside the build context | ||
| 50 | +COPY python_rtmpstream /python_rtmpstream | ||
| 51 | +WORKDIR /python_rtmpstream/python | ||
| 52 | +RUN pip install . | ||
| 53 | + | ||
| 54 | +COPY nerfstream /nerfstream | ||
| 55 | +WORKDIR /nerfstream | ||
| 56 | +CMD ["python3", "app.py"] |
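A minimal build-and-run sketch for this image, assuming this repo (`nerfstream`) and `python_rtmpstream` sit side by side and the build runs from their parent directory; the image tag is illustrative:

```bash
# the build context is the parent directory so both projects can be COPYed
docker build -f nerfstream/Dockerfile -t livetalking:dev .
docker run --gpus all -it --network=host --rm livetalking:dev
```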
LICENSE
0 → 100644
| 1 | + Apache License | ||
| 2 | + Version 2.0, January 2004 | ||
| 3 | + http://www.apache.org/licenses/ | ||
| 4 | + | ||
| 5 | + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION | ||
| 6 | + | ||
| 7 | + 1. Definitions. | ||
| 8 | + | ||
| 9 | + "License" shall mean the terms and conditions for use, reproduction, | ||
| 10 | + and distribution as defined by Sections 1 through 9 of this document. | ||
| 11 | + | ||
| 12 | + "Licensor" shall mean the copyright owner or entity authorized by | ||
| 13 | + the copyright owner that is granting the License. | ||
| 14 | + | ||
| 15 | + "Legal Entity" shall mean the union of the acting entity and all | ||
| 16 | + other entities that control, are controlled by, or are under common | ||
| 17 | + control with that entity. For the purposes of this definition, | ||
| 18 | + "control" means (i) the power, direct or indirect, to cause the | ||
| 19 | + direction or management of such entity, whether by contract or | ||
| 20 | + otherwise, or (ii) ownership of fifty percent (50%) or more of the | ||
| 21 | + outstanding shares, or (iii) beneficial ownership of such entity. | ||
| 22 | + | ||
| 23 | + "You" (or "Your") shall mean an individual or Legal Entity | ||
| 24 | + exercising permissions granted by this License. | ||
| 25 | + | ||
| 26 | + "Source" form shall mean the preferred form for making modifications, | ||
| 27 | + including but not limited to software source code, documentation | ||
| 28 | + source, and configuration files. | ||
| 29 | + | ||
| 30 | + "Object" form shall mean any form resulting from mechanical | ||
| 31 | + transformation or translation of a Source form, including but | ||
| 32 | + not limited to compiled object code, generated documentation, | ||
| 33 | + and conversions to other media types. | ||
| 34 | + | ||
| 35 | + "Work" shall mean the work of authorship, whether in Source or | ||
| 36 | + Object form, made available under the License, as indicated by a | ||
| 37 | + copyright notice that is included in or attached to the work | ||
| 38 | + (an example is provided in the Appendix below). | ||
| 39 | + | ||
| 40 | + "Derivative Works" shall mean any work, whether in Source or Object | ||
| 41 | + form, that is based on (or derived from) the Work and for which the | ||
| 42 | + editorial revisions, annotations, elaborations, or other modifications | ||
| 43 | + represent, as a whole, an original work of authorship. For the purposes | ||
| 44 | + of this License, Derivative Works shall not include works that remain | ||
| 45 | + separable from, or merely link (or bind by name) to the interfaces of, | ||
| 46 | + the Work and Derivative Works thereof. | ||
| 47 | + | ||
| 48 | + "Contribution" shall mean any work of authorship, including | ||
| 49 | + the original version of the Work and any modifications or additions | ||
| 50 | + to that Work or Derivative Works thereof, that is intentionally | ||
| 51 | + submitted to Licensor for inclusion in the Work by the copyright owner | ||
| 52 | + or by an individual or Legal Entity authorized to submit on behalf of | ||
| 53 | + the copyright owner. For the purposes of this definition, "submitted" | ||
| 54 | + means any form of electronic, verbal, or written communication sent | ||
| 55 | + to the Licensor or its representatives, including but not limited to | ||
| 56 | + communication on electronic mailing lists, source code control systems, | ||
| 57 | + and issue tracking systems that are managed by, or on behalf of, the | ||
| 58 | + Licensor for the purpose of discussing and improving the Work, but | ||
| 59 | + excluding communication that is conspicuously marked or otherwise | ||
| 60 | + designated in writing by the copyright owner as "Not a Contribution." | ||
| 61 | + | ||
| 62 | + "Contributor" shall mean Licensor and any individual or Legal Entity | ||
| 63 | + on behalf of whom a Contribution has been received by Licensor and | ||
| 64 | + subsequently incorporated within the Work. | ||
| 65 | + | ||
| 66 | + 2. Grant of Copyright License. Subject to the terms and conditions of | ||
| 67 | + this License, each Contributor hereby grants to You a perpetual, | ||
| 68 | + worldwide, non-exclusive, no-charge, royalty-free, irrevocable | ||
| 69 | + copyright license to reproduce, prepare Derivative Works of, | ||
| 70 | + publicly display, publicly perform, sublicense, and distribute the | ||
| 71 | + Work and such Derivative Works in Source or Object form. | ||
| 72 | + | ||
| 73 | + 3. Grant of Patent License. Subject to the terms and conditions of | ||
| 74 | + this License, each Contributor hereby grants to You a perpetual, | ||
| 75 | + worldwide, non-exclusive, no-charge, royalty-free, irrevocable | ||
| 76 | + (except as stated in this section) patent license to make, have made, | ||
| 77 | + use, offer to sell, sell, import, and otherwise transfer the Work, | ||
| 78 | + where such license applies only to those patent claims licensable | ||
| 79 | + by such Contributor that are necessarily infringed by their | ||
| 80 | + Contribution(s) alone or by combination of their Contribution(s) | ||
| 81 | + with the Work to which such Contribution(s) was submitted. If You | ||
| 82 | + institute patent litigation against any entity (including a | ||
| 83 | + cross-claim or counterclaim in a lawsuit) alleging that the Work | ||
| 84 | + or a Contribution incorporated within the Work constitutes direct | ||
| 85 | + or contributory patent infringement, then any patent licenses | ||
| 86 | + granted to You under this License for that Work shall terminate | ||
| 87 | + as of the date such litigation is filed. | ||
| 88 | + | ||
| 89 | + 4. Redistribution. You may reproduce and distribute copies of the | ||
| 90 | + Work or Derivative Works thereof in any medium, with or without | ||
| 91 | + modifications, and in Source or Object form, provided that You | ||
| 92 | + meet the following conditions: | ||
| 93 | + | ||
| 94 | + (a) You must give any other recipients of the Work or | ||
| 95 | + Derivative Works a copy of this License; and | ||
| 96 | + | ||
| 97 | + (b) You must cause any modified files to carry prominent notices | ||
| 98 | + stating that You changed the files; and | ||
| 99 | + | ||
| 100 | + (c) You must retain, in the Source form of any Derivative Works | ||
| 101 | + that You distribute, all copyright, patent, trademark, and | ||
| 102 | + attribution notices from the Source form of the Work, | ||
| 103 | + excluding those notices that do not pertain to any part of | ||
| 104 | + the Derivative Works; and | ||
| 105 | + | ||
| 106 | + (d) If the Work includes a "NOTICE" text file as part of its | ||
| 107 | + distribution, then any Derivative Works that You distribute must | ||
| 108 | + include a readable copy of the attribution notices contained | ||
| 109 | + within such NOTICE file, excluding those notices that do not | ||
| 110 | + pertain to any part of the Derivative Works, in at least one | ||
| 111 | + of the following places: within a NOTICE text file distributed | ||
| 112 | + as part of the Derivative Works; within the Source form or | ||
| 113 | + documentation, if provided along with the Derivative Works; or, | ||
| 114 | + within a display generated by the Derivative Works, if and | ||
| 115 | + wherever such third-party notices normally appear. The contents | ||
| 116 | + of the NOTICE file are for informational purposes only and | ||
| 117 | + do not modify the License. You may add Your own attribution | ||
| 118 | + notices within Derivative Works that You distribute, alongside | ||
| 119 | + or as an addendum to the NOTICE text from the Work, provided | ||
| 120 | + that such additional attribution notices cannot be construed | ||
| 121 | + as modifying the License. | ||
| 122 | + | ||
| 123 | + You may add Your own copyright statement to Your modifications and | ||
| 124 | + may provide additional or different license terms and conditions | ||
| 125 | + for use, reproduction, or distribution of Your modifications, or | ||
| 126 | + for any such Derivative Works as a whole, provided Your use, | ||
| 127 | + reproduction, and distribution of the Work otherwise complies with | ||
| 128 | + the conditions stated in this License. | ||
| 129 | + | ||
| 130 | + 5. Submission of Contributions. Unless You explicitly state otherwise, | ||
| 131 | + any Contribution intentionally submitted for inclusion in the Work | ||
| 132 | + by You to the Licensor shall be under the terms and conditions of | ||
| 133 | + this License, without any additional terms or conditions. | ||
| 134 | + Notwithstanding the above, nothing herein shall supersede or modify | ||
| 135 | + the terms of any separate license agreement you may have executed | ||
| 136 | + with Licensor regarding such Contributions. | ||
| 137 | + | ||
| 138 | + 6. Trademarks. This License does not grant permission to use the trade | ||
| 139 | + names, trademarks, service marks, or product names of the Licensor, | ||
| 140 | + except as required for reasonable and customary use in describing the | ||
| 141 | + origin of the Work and reproducing the content of the NOTICE file. | ||
| 142 | + | ||
| 143 | + 7. Disclaimer of Warranty. Unless required by applicable law or | ||
| 144 | + agreed to in writing, Licensor provides the Work (and each | ||
| 145 | + Contributor provides its Contributions) on an "AS IS" BASIS, | ||
| 146 | + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | ||
| 147 | + implied, including, without limitation, any warranties or conditions | ||
| 148 | + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A | ||
| 149 | + PARTICULAR PURPOSE. You are solely responsible for determining the | ||
| 150 | + appropriateness of using or redistributing the Work and assume any | ||
| 151 | + risks associated with Your exercise of permissions under this License. | ||
| 152 | + | ||
| 153 | + 8. Limitation of Liability. In no event and under no legal theory, | ||
| 154 | + whether in tort (including negligence), contract, or otherwise, | ||
| 155 | + unless required by applicable law (such as deliberate and grossly | ||
| 156 | + negligent acts) or agreed to in writing, shall any Contributor be | ||
| 157 | + liable to You for damages, including any direct, indirect, special, | ||
| 158 | + incidental, or consequential damages of any character arising as a | ||
| 159 | + result of this License or out of the use or inability to use the | ||
| 160 | + Work (including but not limited to damages for loss of goodwill, | ||
| 161 | + work stoppage, computer failure or malfunction, or any and all | ||
| 162 | + other commercial damages or losses), even if such Contributor | ||
| 163 | + has been advised of the possibility of such damages. | ||
| 164 | + | ||
| 165 | + 9. Accepting Warranty or Additional Liability. While redistributing | ||
| 166 | + the Work or Derivative Works thereof, You may choose to offer, | ||
| 167 | + and charge a fee for, acceptance of support, warranty, indemnity, | ||
| 168 | + or other liability obligations and/or rights consistent with this | ||
| 169 | + License. However, in accepting such obligations, You may act only | ||
| 170 | + on Your own behalf and on Your sole responsibility, not on behalf | ||
| 171 | + of any other Contributor, and only if You agree to indemnify, | ||
| 172 | + defend, and hold each Contributor harmless for any liability | ||
| 173 | + incurred by, or claims asserted against, such Contributor by reason | ||
| 174 | + of your accepting any such warranty or additional liability. | ||
| 175 | + | ||
| 176 | + END OF TERMS AND CONDITIONS | ||
| 177 | + | ||
| 178 | + APPENDIX: How to apply the Apache License to your work. | ||
| 179 | + | ||
| 180 | + To apply the Apache License to your work, attach the following | ||
| 181 | + boilerplate notice, with the fields enclosed by brackets "[]" | ||
| 182 | + replaced with your own identifying information. (Don't include | ||
| 183 | + the brackets!) The text should be enclosed in the appropriate | ||
| 184 | + comment syntax for the file format. We also recommend that a | ||
| 185 | + file or class name and description of purpose be included on the | ||
| 186 | + same "printed page" as the copyright notice for easier | ||
| 187 | + identification within third-party archives. | ||
| 188 | + | ||
| 189 | + Copyright [livetalking@lipku] | ||
| 190 | + | ||
| 191 | + Licensed under the Apache License, Version 2.0 (the "License"); | ||
| 192 | + you may not use this file except in compliance with the License. | ||
| 193 | + You may obtain a copy of the License at | ||
| 194 | + | ||
| 195 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
| 196 | + | ||
| 197 | + Unless required by applicable law or agreed to in writing, software | ||
| 198 | + distributed under the License is distributed on an "AS IS" BASIS, | ||
| 199 | + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| 200 | + See the License for the specific language governing permissions and | ||
| 201 | + limitations under the License. |
README.md
0 → 100644
| 1 | +Real-time interactive streaming digital human with synchronized audio-video dialogue. It can essentially achieve commercial-grade results. | ||
| 2 | + | ||
| 3 | + | ||
| 4 | +[ernerf demo](https://www.bilibili.com/video/BV1PM4m1y7Q2/) [musetalk demo](https://www.bilibili.com/video/BV1gm421N7vQ/) [wav2lip demo](https://www.bilibili.com/video/BV1Bw4m1e74P/) | ||
| 5 | + | ||
| 6 | +## To avoid confusion with 3D digital humans, the original project metahuman-stream has been renamed livetalking; existing links remain valid | ||
| 7 | + | ||
| 8 | +## News | ||
| 9 | + | ||
| 10 | +- 2024.12.8 Improved multi-session concurrency; GPU memory no longer grows with the number of sessions | ||
| 11 | +- 2024.12.21 Added model warm-up for wav2lip and musetalk to fix stuttering on first inference. Thanks to @heimaojinzhangyz | ||
| 12 | +- 2024.12.28 Added the Ultralight-Digital-Human model. Thanks to @lijihua2017 | ||
| 13 | +- 2025.2.7 Added fish-speech TTS | ||
| 14 | +- 2025.2.21 Added the open-source wav2lip256 model. Thanks to @不蠢不蠢 | ||
| 15 | +- 2025.3.2 Added the Tencent text-to-speech service | ||
| 16 | + | ||
| 17 | +## Features | ||
| 18 | + | ||
| 19 | +1. Supports multiple digital human models: ernerf, musetalk, wav2lip, Ultralight-Digital-Human | ||
| 20 | +2. Supports voice cloning | ||
| 21 | +3. Supports interrupting the digital human while it is speaking | ||
| 22 | +4. Supports full-body video stitching | ||
| 23 | +5. Supports rtmp and webrtc | ||
| 24 | +6. Supports video orchestration: plays a custom video when not speaking | ||
| 25 | +7. Supports multiple concurrent sessions | ||
| 26 | + | ||
| 27 | +## 1. Installation | ||
| 28 | + | ||
| 29 | +Tested on Ubuntu 20.04 with Python 3.10, PyTorch 1.12, and CUDA 11.3 | ||
| 30 | + | ||
| 31 | +### 1.1 Install dependency | ||
| 32 | + | ||
| 33 | +```bash | ||
| 34 | +conda create -n nerfstream python=3.10 | ||
| 35 | +conda activate nerfstream | ||
| 36 | +#If your CUDA version is not 11.3 (run nvidia-smi to check), install the matching pytorch build per <https://pytorch.org/get-started/previous-versions/> | ||
| 37 | +conda install pytorch==1.12.1 torchvision==0.13.1 cudatoolkit=11.3 -c pytorch | ||
| 38 | +pip install -r requirements.txt | ||
| 39 | +#To train the ernerf model, install the libraries below | ||
| 40 | +# pip install "git+https://github.com/facebookresearch/pytorch3d.git" | ||
| 41 | +# pip install tensorflow-gpu==2.8.0 | ||
| 42 | +# pip install --upgrade "protobuf<=3.20.1" | ||
| 43 | +``` | ||
| 44 | + | ||
| 45 | +Common installation issues: [FAQ](https://livetalking-doc.readthedocs.io/en/latest/faq.html) | ||
| 46 | +For setting up a CUDA environment on Linux, see https://zhuanlan.zhihu.com/p/674972886 | ||
| 47 | + | ||
| 48 | +## 2. Quick Start | ||
| 49 | + | ||
| 50 | +- Download the models | ||
| 51 | + Baidu Netdisk <https://pan.baidu.com/s/1yOsQ06-RIDTJd3HFCw4wtA> code: ltua | ||
| 52 | + Google Drive <https://drive.google.com/drive/folders/1FOC_MD6wdogyyX_7V1d4NDIO7P9NlSAJ?usp=sharing> | ||
| 53 | + Copy wav2lip256.pth into this project's models directory and rename it to wav2lip.pth; | ||
| 54 | + extract wav2lip256_avatar1.tar.gz and copy the whole folder into this project's data/avatars directory | ||
| 55 | +- Run | ||
| 56 | + python app.py --transport webrtc --model wav2lip --avatar_id wav2lip256_avatar1 --preload 2 | ||
| 57 | + | ||
| 58 | + To start avatar No. 3 on the GPU: python app.py --transport webrtc --model wav2lip --avatar_id wav2lip256_avatar3 --preload 2 | ||
| 59 | + Open http://serverip:8010/webrtcapi.html in a browser, click 'start' to play the digital human video, then enter any text in the text box and submit it; the digital human will read the text aloud (see the curl sketch below for a programmatic alternative) | ||
| 60 | + <font color=red>The server must open ports tcp:8010 and udp:1-65536</font> | ||
| 61 | + If you need the commercial high-definition wav2lip model, contact me to purchase it | ||
| 62 | + | ||
| 63 | +- Quick trial | ||
| 64 | + <https://www.compshare.cn/images-detail?ImageID=compshareImage-18tpjhhxoq3j&referral_code=3XW3852OBmnD089hMMrtuU&ytag=GPU_GitHub_livetalking1.3> Create an instance from this image and it runs out of the box | ||
| 65 | + | ||
| 66 | +If huggingface is unreachable, set the following before running: | ||
| 67 | + | ||
| 68 | +```bash | ||
| 69 | +export HF_ENDPOINT=https://hf-mirror.com | ||
| 70 | +``` | ||
| 71 | + | ||
| 72 | +## 3. More Usage | ||
| 73 | + | ||
| 74 | +Documentation: <https://livetalking-doc.readthedocs.io/> | ||
| 75 | + | ||
| 76 | +## 4. Docker Run | ||
| 77 | + | ||
| 78 | +No need for the installation steps above; run directly: | ||
| 79 | + | ||
| 80 | +```bash | ||
| 81 | +docker run --gpus all -it --network=host --rm registry.cn-beijing.aliyuncs.com/codewithgpu2/lipku-metahuman-stream:2K9qaMBu8v | ||
| 82 | +``` | ||
| 83 | + | ||
| 84 | +The code is in /root/metahuman-stream. Run git pull first to fetch the latest code, then execute the commands as in steps 2 and 3 | ||
| 85 | + | ||
| 86 | +The following images are provided: | ||
| 87 | + | ||
| 88 | +- autodl image: <https://www.codewithgpu.com/i/lipku/metahuman-stream/base> | ||
| 89 | + [autodl tutorial](https://livetalking-doc.readthedocs.io/en/latest/autodl/README.html) | ||
| 90 | +- ucloud image: <https://www.compshare.cn/images-detail?ImageID=compshareImage-18tpjhhxoq3j&referral_code=3XW3852OBmnD089hMMrtuU&ytag=GPU_livetalking1.3> | ||
| 91 | + Any port can be opened, and there is no need to deploy an srs service separately. | ||
| 92 | + [ucloud tutorial](https://livetalking-doc.readthedocs.io/en/latest/ucloud/ucloud.html) | ||
| 93 | + | ||
| 94 | +## 5. TODO | ||
| 95 | + | ||
| 96 | +- [x] Add ChatGPT to enable digital human dialogue | ||
| 97 | +- [x] Voice cloning | ||
| 98 | +- [x] Play a substitute video while the digital human is silent | ||
| 99 | +- [x] MuseTalk | ||
| 100 | +- [x] Wav2Lip | ||
| 101 | +- [x] Ultralight-Digital-Human | ||
| 102 | + | ||
| 103 | +--- | ||
| 104 | + | ||
| 105 | +If this project helps you, please give it a star. Anyone interested is welcome to help improve it. | ||
| 106 | + | ||
| 107 | +- Knowledge Planet (知识星球): https://t.zsxq.com/7NMyO collects high-quality FAQs, best practices, and answers | ||
| 108 | +- WeChat official account: 数字人技术 | ||
| 109 | +  |
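Once the server from the Quick Start section is running, the avatar can also be driven programmatically. A minimal sketch against the `/human` endpoint defined in `app.py` (use the `sessionid` returned by `/offer`; `type` is `echo` to speak the text verbatim, or `chat` to route it through the LLM):

```bash
curl -X POST http://serverip:8010/human \
  -H 'Content-Type: application/json' \
  -d '{"sessionid": 0, "type": "echo", "text": "hello"}'
```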
app.py
0 → 100644
| 1 | +############################################################################### | ||
| 2 | +# Copyright (C) 2024 LiveTalking@lipku https://github.com/lipku/LiveTalking | ||
| 3 | +# email: lipku@foxmail.com | ||
| 4 | +# | ||
| 5 | +# Licensed under the Apache License, Version 2.0 (the "License"); | ||
| 6 | +# you may not use this file except in compliance with the License. | ||
| 7 | +# You may obtain a copy of the License at | ||
| 8 | +# | ||
| 9 | +# http://www.apache.org/licenses/LICENSE-2.0 | ||
| 10 | +# | ||
| 11 | +# Unless required by applicable law or agreed to in writing, software | ||
| 12 | +# distributed under the License is distributed on an "AS IS" BASIS, | ||
| 13 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| 14 | +# See the License for the specific language governing permissions and | ||
| 15 | +# limitations under the License. | ||
| 16 | +############################################################################### | ||
| 17 | + | ||
| 18 | +# server.py | ||
| 19 | +from flask import Flask, render_template,send_from_directory,request, jsonify | ||
| 20 | +from flask_sockets import Sockets | ||
| 21 | +import base64 | ||
| 22 | +import json | ||
| 23 | +#import gevent | ||
| 24 | +#from gevent import pywsgi | ||
| 25 | +#from geventwebsocket.handler import WebSocketHandler | ||
| 26 | +import re | ||
| 27 | +import numpy as np | ||
| 28 | +from threading import Thread,Event | ||
| 29 | +#import multiprocessing | ||
| 30 | +import torch.multiprocessing as mp | ||
| 31 | + | ||
| 32 | +from aiohttp import web | ||
| 33 | +import aiohttp | ||
| 34 | +import aiohttp_cors | ||
| 35 | +from aiortc import RTCPeerConnection, RTCSessionDescription | ||
| 36 | +from aiortc.rtcrtpsender import RTCRtpSender | ||
| 37 | +from webrtc import HumanPlayer | ||
| 38 | +from basereal import BaseReal | ||
| 39 | +from llm import llm_response | ||
| 40 | + | ||
| 41 | +import argparse | ||
| 42 | +import random | ||
| 43 | +import shutil | ||
| 44 | +import asyncio | ||
| 45 | +import torch | ||
| 46 | +from typing import Dict | ||
| 47 | +from logger import logger | ||
| 48 | + | ||
| 49 | + | ||
| 50 | +app = Flask(__name__) | ||
| 51 | +#sockets = Sockets(app) | ||
| 52 | +nerfreals:Dict[int, BaseReal] = {} #sessionid:BaseReal | ||
| 53 | +opt = None | ||
| 54 | +model = None | ||
| 55 | +avatar = None | ||
| 56 | + | ||
| 57 | + | ||
| 58 | +#####webrtc############################### | ||
| 59 | +pcs = set() | ||
| 60 | + | ||
| 61 | +def randN(N)->int: | ||
| 62 | + '''Generate a random N-digit integer''' | ||
| 63 | + min = pow(10, N - 1) | ||
| 64 | + max = pow(10, N) | ||
| 65 | + return random.randint(min, max - 1) | ||
| 66 | + | ||
| 67 | +def build_nerfreal(sessionid:int)->BaseReal: | ||
| 68 | + opt.sessionid=sessionid | ||
| 69 | + if opt.model == 'wav2lip': | ||
| 70 | + from lipreal import LipReal | ||
| 71 | + nerfreal = LipReal(opt,model,avatar) | ||
| 72 | + elif opt.model == 'musetalk': | ||
| 73 | + from musereal import MuseReal | ||
| 74 | + nerfreal = MuseReal(opt,model,avatar) | ||
| 75 | + elif opt.model == 'ernerf': | ||
| 76 | + from nerfreal import NeRFReal | ||
| 77 | + nerfreal = NeRFReal(opt,model,avatar) | ||
| 78 | + elif opt.model == 'ultralight': | ||
| 79 | + from lightreal import LightReal | ||
| 80 | + nerfreal = LightReal(opt,model,avatar) | ||
| 81 | + return nerfreal | ||
| 82 | + | ||
| 83 | +#@app.route('/offer', methods=['POST']) | ||
| 84 | +async def offer(request): | ||
| 85 | + params = await request.json() | ||
| 86 | + offer = RTCSessionDescription(sdp=params["sdp"], type=params["type"]) | ||
| 87 | + | ||
| 88 | + if len(nerfreals) >= opt.max_session: | ||
| 89 | + logger.info('reach max session') | ||
| 90 | + return web.json_response({"code": -1, "msg": "reach max session"}) | ||
| 91 | + sessionid = randN(6) #len(nerfreals) | ||
| 92 | + logger.info('sessionid=%d',sessionid) | ||
| 93 | + nerfreals[sessionid] = None | ||
| 94 | + nerfreal = await asyncio.get_event_loop().run_in_executor(None, build_nerfreal,sessionid) | ||
| 95 | + nerfreals[sessionid] = nerfreal | ||
| 96 | + | ||
| 97 | + pc = RTCPeerConnection() | ||
| 98 | + pcs.add(pc) | ||
| 99 | + | ||
| 100 | + @pc.on("connectionstatechange") | ||
| 101 | + async def on_connectionstatechange(): | ||
| 102 | + logger.info("Connection state is %s" % pc.connectionState) | ||
| 103 | + if pc.connectionState == "failed": | ||
| 104 | + await pc.close() | ||
| 105 | + pcs.discard(pc) | ||
| 106 | + nerfreals.pop(sessionid, None) | ||
| 107 | + if pc.connectionState == "closed": | ||
| 108 | + pcs.discard(pc) | ||
| 109 | + nerfreals.pop(sessionid, None) #'closed' can fire after 'failed' has already removed the session | ||
| 110 | + | ||
| 111 | + player = HumanPlayer(nerfreals[sessionid]) | ||
| 112 | + audio_sender = pc.addTrack(player.audio) | ||
| 113 | + video_sender = pc.addTrack(player.video) | ||
| 114 | + capabilities = RTCRtpSender.getCapabilities("video") | ||
| 115 | + preferences = list(filter(lambda x: x.name == "H264", capabilities.codecs)) | ||
| 116 | + preferences += list(filter(lambda x: x.name == "VP8", capabilities.codecs)) | ||
| 117 | + preferences += list(filter(lambda x: x.name == "rtx", capabilities.codecs)) | ||
| 118 | + transceiver = pc.getTransceivers()[1] | ||
| 119 | + transceiver.setCodecPreferences(preferences) | ||
| 120 | + | ||
| 121 | + await pc.setRemoteDescription(offer) | ||
| 122 | + | ||
| 123 | + answer = await pc.createAnswer() | ||
| 124 | + await pc.setLocalDescription(answer) | ||
| 125 | + | ||
| 126 | + #return jsonify({"sdp": pc.localDescription.sdp, "type": pc.localDescription.type}) | ||
| 127 | + | ||
| 128 | + return web.Response( | ||
| 129 | + content_type="application/json", | ||
| 130 | + text=json.dumps( | ||
| 131 | + {"sdp": pc.localDescription.sdp, "type": pc.localDescription.type, "sessionid":sessionid} | ||
| 132 | + ), | ||
| 133 | + ) | ||
| 134 | + | ||
| 135 | +async def human(request): | ||
| 136 | + params = await request.json() | ||
| 137 | + | ||
| 138 | + sessionid = params.get('sessionid',0) | ||
| 139 | + if params.get('interrupt'): | ||
| 140 | + nerfreals[sessionid].flush_talk() | ||
| 141 | + | ||
| 142 | + if params['type']=='echo': | ||
| 143 | + nerfreals[sessionid].put_msg_txt(params['text']) | ||
| 144 | + elif params['type']=='chat': | ||
| 145 | + res=await asyncio.get_event_loop().run_in_executor(None, llm_response, params['text'],nerfreals[sessionid]) | ||
| 146 | + #nerfreals[sessionid].put_msg_txt(res) | ||
| 147 | + | ||
| 148 | + return web.Response( | ||
| 149 | + content_type="application/json", | ||
| 150 | + text=json.dumps( | ||
| 151 | + {"code": 0, "data":"ok"} | ||
| 152 | + ), | ||
| 153 | + ) | ||
| 154 | + | ||
| 155 | +async def humanaudio(request): | ||
| 156 | + try: | ||
| 157 | + params = await request.json() | ||
| 158 | + sessionid = int(params.get('sessionid', 0)) | ||
| 159 | + fileobj = params['file_url'] | ||
| 160 | + if fileobj.startswith("http"): | ||
| 161 | + async with aiohttp.ClientSession() as session: | ||
| 162 | + async with session.get(fileobj) as response: | ||
| 163 | + if response.status == 200: | ||
| 164 | + filebytes = await response.read() | ||
| 165 | + else: | ||
| 166 | + return web.Response( | ||
| 167 | + content_type="application/json", | ||
| 168 | + text=json.dumps({"code": -1, "msg": "Error downloading file"}) | ||
| 169 | + ) | ||
| 170 | + else: | ||
| 171 | + filename = fileobj.filename | ||
| 172 | + filebytes = fileobj.file.read() | ||
| 173 | + | ||
| 174 | + nerfreals[sessionid].put_audio_file(filebytes) | ||
| 175 | + | ||
| 176 | + return web.Response( | ||
| 177 | + content_type="application/json", | ||
| 178 | + text=json.dumps({"code": 0, "msg": "ok"}) | ||
| 179 | + ) | ||
| 180 | + | ||
| 181 | + except Exception as e: | ||
| 182 | + return web.Response( | ||
| 183 | + content_type="application/json", | ||
| 184 | + text=json.dumps({"code": -1, "msg": "err", "data": str(e)}) | ||
| 185 | + ) | ||
| 186 | + | ||
| 187 | +async def set_audiotype(request): | ||
| 188 | + params = await request.json() | ||
| 189 | + | ||
| 190 | + sessionid = params.get('sessionid',0) | ||
| 191 | + nerfreals[sessionid].set_curr_state(params['audiotype'],params['reinit']) | ||
| 192 | + | ||
| 193 | + return web.Response( | ||
| 194 | + content_type="application/json", | ||
| 195 | + text=json.dumps( | ||
| 196 | + {"code": 0, "data":"ok"} | ||
| 197 | + ), | ||
| 198 | + ) | ||
| 199 | + | ||
| 200 | +async def record(request): | ||
| 201 | + params = await request.json() | ||
| 202 | + | ||
| 203 | + sessionid = params.get('sessionid',0) | ||
| 204 | + if params['type']=='start_record': | ||
| 205 | + # nerfreals[sessionid].put_msg_txt(params['text']) | ||
| 206 | + nerfreals[sessionid].start_recording() | ||
| 207 | + elif params['type']=='end_record': | ||
| 208 | + nerfreals[sessionid].stop_recording() | ||
| 209 | + return web.Response( | ||
| 210 | + content_type="application/json", | ||
| 211 | + text=json.dumps( | ||
| 212 | + {"code": 0, "data":"ok"} | ||
| 213 | + ), | ||
| 214 | + ) | ||
| 215 | + | ||
| 216 | +async def is_speaking(request): | ||
| 217 | + params = await request.json() | ||
| 218 | + | ||
| 219 | + sessionid = params.get('sessionid',0) | ||
| 220 | + return web.Response( | ||
| 221 | + content_type="application/json", | ||
| 222 | + text=json.dumps( | ||
| 223 | + {"code": 0, "data": nerfreals[sessionid].is_speaking()} | ||
| 224 | + ), | ||
| 225 | + ) | ||
| 226 | + | ||
| 227 | + | ||
| 228 | +async def on_shutdown(app): | ||
| 229 | + # close peer connections | ||
| 230 | + coros = [pc.close() for pc in pcs] | ||
| 231 | + await asyncio.gather(*coros) | ||
| 232 | + pcs.clear() | ||
| 233 | + | ||
| 234 | +async def post(url,data): | ||
| 235 | + try: | ||
| 236 | + async with aiohttp.ClientSession() as session: | ||
| 237 | + async with session.post(url,data=data) as response: | ||
| 238 | + return await response.text() | ||
| 239 | + except aiohttp.ClientError as e: | ||
| 240 | + logger.info(f'Error: {e}') | ||
| 241 | + | ||
| 242 | +async def run(push_url,sessionid): | ||
| 243 | + nerfreal = await asyncio.get_event_loop().run_in_executor(None, build_nerfreal,sessionid) | ||
| 244 | + nerfreals[sessionid] = nerfreal | ||
| 245 | + | ||
| 246 | + pc = RTCPeerConnection() | ||
| 247 | + pcs.add(pc) | ||
| 248 | + | ||
| 249 | + @pc.on("connectionstatechange") | ||
| 250 | + async def on_connectionstatechange(): | ||
| 251 | + logger.info("Connection state is %s" % pc.connectionState) | ||
| 252 | + if pc.connectionState == "failed": | ||
| 253 | + await pc.close() | ||
| 254 | + pcs.discard(pc) | ||
| 255 | + | ||
| 256 | + player = HumanPlayer(nerfreals[sessionid]) | ||
| 257 | + audio_sender = pc.addTrack(player.audio) | ||
| 258 | + video_sender = pc.addTrack(player.video) | ||
| 259 | + | ||
| 260 | + await pc.setLocalDescription(await pc.createOffer()) | ||
| 261 | + answer = await post(push_url,pc.localDescription.sdp) | ||
| 262 | + await pc.setRemoteDescription(RTCSessionDescription(sdp=answer,type='answer')) | ||
| 263 | +########################################## | ||
| 264 | +# os.environ['MKL_SERVICE_FORCE_INTEL'] = '1' | ||
| 265 | +# os.environ['MULTIPROCESSING_METHOD'] = 'forkserver' | ||
| 266 | +if __name__ == '__main__': | ||
| 267 | + mp.set_start_method('spawn') | ||
| 268 | + parser = argparse.ArgumentParser() | ||
| 269 | + parser.add_argument('--pose', type=str, default="data/data_kf.json", help="transforms.json, pose source") | ||
| 270 | + parser.add_argument('--au', type=str, default="data/au.csv", help="eye blink area") | ||
| 271 | + parser.add_argument('--torso_imgs', type=str, default="", help="torso images path") | ||
| 272 | + | ||
| 273 | + parser.add_argument('-O', action='store_true', help="equals --fp16 --cuda_ray --exp_eye") | ||
| 274 | + | ||
| 275 | + parser.add_argument('--data_range', type=int, nargs='*', default=[0, -1], help="data range to use") | ||
| 276 | + parser.add_argument('--workspace', type=str, default='data/video') | ||
| 277 | + parser.add_argument('--seed', type=int, default=0) | ||
| 278 | + | ||
| 279 | + ### training options | ||
| 280 | + parser.add_argument('--ckpt', type=str, default='data/pretrained/ngp_kf.pth') | ||
| 281 | + | ||
| 282 | + parser.add_argument('--num_rays', type=int, default=4096 * 16, help="num rays sampled per image for each training step") | ||
| 283 | + parser.add_argument('--cuda_ray', action='store_true', help="use CUDA raymarching instead of pytorch") | ||
| 284 | + parser.add_argument('--max_steps', type=int, default=16, help="max num steps sampled per ray (only valid when using --cuda_ray)") | ||
| 285 | + parser.add_argument('--num_steps', type=int, default=16, help="num steps sampled per ray (only valid when NOT using --cuda_ray)") | ||
| 286 | + parser.add_argument('--upsample_steps', type=int, default=0, help="num steps up-sampled per ray (only valid when NOT using --cuda_ray)") | ||
| 287 | + parser.add_argument('--update_extra_interval', type=int, default=16, help="iter interval to update extra status (only valid when using --cuda_ray)") | ||
| 288 | + parser.add_argument('--max_ray_batch', type=int, default=4096, help="batch size of rays at inference to avoid OOM (only valid when NOT using --cuda_ray)") | ||
| 289 | + | ||
| 290 | + ### loss set | ||
| 291 | + parser.add_argument('--warmup_step', type=int, default=10000, help="warm up steps") | ||
| 292 | + parser.add_argument('--amb_aud_loss', type=int, default=1, help="use ambient aud loss") | ||
| 293 | + parser.add_argument('--amb_eye_loss', type=int, default=1, help="use ambient eye loss") | ||
| 294 | + parser.add_argument('--unc_loss', type=int, default=1, help="use uncertainty loss") | ||
| 295 | + parser.add_argument('--lambda_amb', type=float, default=1e-4, help="lambda for ambient loss") | ||
| 296 | + | ||
| 297 | + ### network backbone options | ||
| 298 | + parser.add_argument('--fp16', action='store_true', help="use amp mixed precision training") | ||
| 299 | + | ||
| 300 | + parser.add_argument('--bg_img', type=str, default='white', help="background image") | ||
| 301 | + parser.add_argument('--fbg', action='store_true', help="frame-wise bg") | ||
| 302 | + parser.add_argument('--exp_eye', action='store_true', help="explicitly control the eyes") | ||
| 303 | + parser.add_argument('--fix_eye', type=float, default=-1, help="fixed eye area, negative to disable, set to 0-0.3 for a reasonable eye") | ||
| 304 | + parser.add_argument('--smooth_eye', action='store_true', help="smooth the eye area sequence") | ||
| 305 | + | ||
| 306 | + parser.add_argument('--torso_shrink', type=float, default=0.8, help="shrink bg coords to allow more flexibility in deform") | ||
| 307 | + | ||
| 308 | + ### dataset options | ||
| 309 | + parser.add_argument('--color_space', type=str, default='srgb', help="Color space, supports (linear, srgb)") | ||
| 310 | + parser.add_argument('--preload', type=int, default=0, help="0 means load data from disk on-the-fly, 1 means preload to CPU, 2 means GPU.") | ||
| 311 | + # (the default value is for the fox dataset) | ||
| 312 | + parser.add_argument('--bound', type=float, default=1, help="assume the scene is bounded in box[-bound, bound]^3, if > 1, will invoke adaptive ray marching.") | ||
| 313 | + parser.add_argument('--scale', type=float, default=4, help="scale camera location into box[-bound, bound]^3") | ||
| 314 | + parser.add_argument('--offset', type=float, nargs='*', default=[0, 0, 0], help="offset of camera location") | ||
| 315 | + parser.add_argument('--dt_gamma', type=float, default=1/256, help="dt_gamma (>=0) for adaptive ray marching. set to 0 to disable, >0 to accelerate rendering (but usually with worse quality)") | ||
| 316 | + parser.add_argument('--min_near', type=float, default=0.05, help="minimum near distance for camera") | ||
| 317 | + parser.add_argument('--density_thresh', type=float, default=10, help="threshold for density grid to be occupied (sigma)") | ||
| 318 | + parser.add_argument('--density_thresh_torso', type=float, default=0.01, help="threshold for density grid to be occupied (alpha)") | ||
| 319 | + parser.add_argument('--patch_size', type=int, default=1, help="[experimental] render patches in training, so as to apply LPIPS loss. 1 means disabled, use [64, 32, 16] to enable") | ||
| 320 | + | ||
| 321 | + parser.add_argument('--init_lips', action='store_true', help="init lips region") | ||
| 322 | + parser.add_argument('--finetune_lips', action='store_true', help="use LPIPS and landmarks to fine tune lips region") | ||
| 323 | + parser.add_argument('--smooth_lips', action='store_true', help="smooth the enc_a in an exponential decay way...") | ||
| 324 | + | ||
| 325 | + parser.add_argument('--torso', action='store_true', help="fix head and train torso") | ||
| 326 | + parser.add_argument('--head_ckpt', type=str, default='', help="head model") | ||
| 327 | + | ||
| 328 | + ### GUI options | ||
| 329 | + parser.add_argument('--gui', action='store_true', help="start a GUI") | ||
| 330 | + parser.add_argument('--W', type=int, default=450, help="GUI width") | ||
| 331 | + parser.add_argument('--H', type=int, default=450, help="GUI height") | ||
| 332 | + parser.add_argument('--radius', type=float, default=3.35, help="default GUI camera radius from center") | ||
| 333 | + parser.add_argument('--fovy', type=float, default=21.24, help="default GUI camera fovy") | ||
| 334 | + parser.add_argument('--max_spp', type=int, default=1, help="GUI rendering max sample per pixel") | ||
| 335 | + | ||
| 336 | + ### else | ||
| 337 | + parser.add_argument('--att', type=int, default=2, help="audio attention mode (0 = turn off, 1 = left-direction, 2 = bi-direction)") | ||
| 338 | + parser.add_argument('--aud', type=str, default='', help="audio source (empty will load the default, else should be a path to a npy file)") | ||
| 339 | + parser.add_argument('--emb', action='store_true', help="use audio class + embedding instead of logits") | ||
| 340 | + | ||
| 341 | + parser.add_argument('--ind_dim', type=int, default=4, help="individual code dim, 0 to turn off") | ||
| 342 | + parser.add_argument('--ind_num', type=int, default=10000, help="number of individual codes, should be larger than training dataset size") | ||
| 343 | + | ||
| 344 | + parser.add_argument('--ind_dim_torso', type=int, default=8, help="individual code dim, 0 to turn off") | ||
| 345 | + | ||
| 346 | + parser.add_argument('--amb_dim', type=int, default=2, help="ambient dimension") | ||
| 347 | + parser.add_argument('--part', action='store_true', help="use partial training data (1/10)") | ||
| 348 | + parser.add_argument('--part2', action='store_true', help="use partial training data (first 15s)") | ||
| 349 | + | ||
| 350 | + parser.add_argument('--train_camera', action='store_true', help="optimize camera pose") | ||
| 351 | + parser.add_argument('--smooth_path', action='store_true', help="brute-force smooth camera pose trajectory with a window size") | ||
| 352 | + parser.add_argument('--smooth_path_window', type=int, default=7, help="smoothing window size") | ||
| 353 | + | ||
| 354 | + # asr | ||
| 355 | + parser.add_argument('--asr', action='store_true', help="load asr for real-time app") | ||
| 356 | + parser.add_argument('--asr_wav', type=str, default='', help="load the wav and use as input") | ||
| 357 | + parser.add_argument('--asr_play', action='store_true', help="play out the audio") | ||
| 358 | + | ||
| 359 | + #parser.add_argument('--asr_model', type=str, default='deepspeech') | ||
| 360 | + parser.add_argument('--asr_model', type=str, default='cpierse/wav2vec2-large-xlsr-53-esperanto') # | ||
| 361 | + # parser.add_argument('--asr_model', type=str, default='facebook/wav2vec2-large-960h-lv60-self') | ||
| 362 | + # parser.add_argument('--asr_model', type=str, default='facebook/hubert-large-ls960-ft') | ||
| 363 | + | ||
| 364 | + parser.add_argument('--asr_save_feats', action='store_true') | ||
| 365 | + # audio FPS | ||
| 366 | + parser.add_argument('--fps', type=int, default=50) | ||
| 367 | + # sliding window left-middle-right length (unit: 20ms) | ||
| 368 | + parser.add_argument('-l', type=int, default=10) | ||
| 369 | + parser.add_argument('-m', type=int, default=8) | ||
| 370 | + parser.add_argument('-r', type=int, default=10) | ||
| 371 | + | ||
| 372 | + parser.add_argument('--fullbody', action='store_true', help="fullbody human") | ||
| 373 | + parser.add_argument('--fullbody_img', type=str, default='data/fullbody/img') | ||
| 374 | + parser.add_argument('--fullbody_width', type=int, default=580) | ||
| 375 | + parser.add_argument('--fullbody_height', type=int, default=1080) | ||
| 376 | + parser.add_argument('--fullbody_offset_x', type=int, default=0) | ||
| 377 | + parser.add_argument('--fullbody_offset_y', type=int, default=0) | ||
| 378 | + | ||
| 379 | + #musetalk opt | ||
| 380 | + parser.add_argument('--avatar_id', type=str, default='avator_1') | ||
| 381 | + parser.add_argument('--bbox_shift', type=int, default=5) | ||
| 382 | + parser.add_argument('--batch_size', type=int, default=16) | ||
| 383 | + | ||
| 384 | + # parser.add_argument('--customvideo', action='store_true', help="custom video") | ||
| 385 | + # parser.add_argument('--customvideo_img', type=str, default='data/customvideo/img') | ||
| 386 | + # parser.add_argument('--customvideo_imgnum', type=int, default=1) | ||
| 387 | + | ||
| 388 | + parser.add_argument('--customvideo_config', type=str, default='') | ||
| 389 | + | ||
| 390 | + parser.add_argument('--tts', type=str, default='edgetts') #xtts gpt-sovits cosyvoice | ||
| 391 | + parser.add_argument('--REF_FILE', type=str, default=None) | ||
| 392 | + parser.add_argument('--REF_TEXT', type=str, default=None) | ||
| 393 | + parser.add_argument('--TTS_SERVER', type=str, default='http://127.0.0.1:9880') # http://localhost:9000 | ||
| 394 | + # parser.add_argument('--CHARACTER', type=str, default='test') | ||
| 395 | + # parser.add_argument('--EMOTION', type=str, default='default') | ||
| 396 | + | ||
| 397 | + parser.add_argument('--model', type=str, default='ernerf') #musetalk wav2lip | ||
| 398 | + | ||
| 399 | + parser.add_argument('--transport', type=str, default='rtcpush') #rtmp webrtc rtcpush | ||
| 400 | + parser.add_argument('--push_url', type=str, default='http://localhost:1985/rtc/v1/whip/?app=live&stream=livestream') #rtmp://localhost/live/livestream | ||
| 401 | + | ||
| 402 | + parser.add_argument('--max_session', type=int, default=1) #multi session count | ||
| 403 | + parser.add_argument('--listenport', type=int, default=8010) | ||
| 404 | + | ||
| 405 | + opt = parser.parse_args() | ||
| 406 | + #app.config.from_object(opt) | ||
| 407 | + #print(app.config) | ||
| 408 | + opt.customopt = [] | ||
| 409 | + if opt.customvideo_config!='': | ||
| 410 | + with open(opt.customvideo_config,'r') as file: | ||
| 411 | + opt.customopt = json.load(file) | ||
| 412 | + | ||
| 413 | + if opt.model == 'ernerf': | ||
| 414 | + from nerfreal import NeRFReal,load_model,load_avatar | ||
| 415 | + model = load_model(opt) | ||
| 416 | + avatar = load_avatar(opt) | ||
| 417 | + | ||
| 418 | + # we still need test_loader to provide audio features for testing. | ||
| 419 | + # for k in range(opt.max_session): | ||
| 420 | + # opt.sessionid=k | ||
| 421 | + # nerfreal = NeRFReal(opt, trainer, test_loader,audio_processor,audio_model) | ||
| 422 | + # nerfreals.append(nerfreal) | ||
| 423 | + elif opt.model == 'musetalk': | ||
| 424 | + from musereal import MuseReal,load_model,load_avatar,warm_up | ||
| 425 | + logger.info(opt) | ||
| 426 | + model = load_model() | ||
| 427 | + avatar = load_avatar(opt.avatar_id) | ||
| 428 | + warm_up(opt.batch_size,model) | ||
| 429 | + # for k in range(opt.max_session): | ||
| 430 | + # opt.sessionid=k | ||
| 431 | + # nerfreal = MuseReal(opt,audio_processor,vae, unet, pe,timesteps) | ||
| 432 | + # nerfreals.append(nerfreal) | ||
| 433 | + elif opt.model == 'wav2lip': | ||
| 434 | + from lipreal import LipReal,load_model,load_avatar,warm_up | ||
| 435 | + logger.info(opt) | ||
| 436 | + model = load_model("./models/wav2lip.pth") | ||
| 437 | + avatar = load_avatar(opt.avatar_id) | ||
| 438 | + warm_up(opt.batch_size,model,256) | ||
| 439 | + # for k in range(opt.max_session): | ||
| 440 | + # opt.sessionid=k | ||
| 441 | + # nerfreal = LipReal(opt,model) | ||
| 442 | + # nerfreals.append(nerfreal) | ||
| 443 | + elif opt.model == 'ultralight': | ||
| 444 | + from lightreal import LightReal,load_model,load_avatar,warm_up | ||
| 445 | + logger.info(opt) | ||
| 446 | + model = load_model(opt) | ||
| 447 | + avatar = load_avatar(opt.avatar_id) | ||
| 448 | + warm_up(opt.batch_size,avatar,160) | ||
| 449 | + | ||
| 450 | + if opt.transport=='rtmp': | ||
| 451 | + thread_quit = Event() | ||
| 452 | + nerfreals[0] = build_nerfreal(0) | ||
| 453 | + rendthrd = Thread(target=nerfreals[0].render,args=(thread_quit,)) | ||
| 454 | + rendthrd.start() | ||
| 455 | + | ||
| 456 | + ############################################################################# | ||
| 457 | + appasync = web.Application() | ||
| 458 | + appasync.on_shutdown.append(on_shutdown) | ||
| 459 | + appasync.router.add_post("/offer", offer) | ||
| 460 | + appasync.router.add_post("/human", human) | ||
| 461 | + appasync.router.add_post("/humanaudio", humanaudio) | ||
| 462 | + appasync.router.add_post("/set_audiotype", set_audiotype) | ||
| 463 | + appasync.router.add_post("/record", record) | ||
| 464 | + appasync.router.add_post("/is_speaking", is_speaking) | ||
| 465 | + appasync.router.add_static('/',path='web') | ||
| 466 | + | ||
| 467 | + # Configure default CORS settings. | ||
| 468 | + cors = aiohttp_cors.setup(appasync, defaults={ | ||
| 469 | + "*": aiohttp_cors.ResourceOptions( | ||
| 470 | + allow_credentials=True, | ||
| 471 | + expose_headers="*", | ||
| 472 | + allow_headers="*", | ||
| 473 | + ) | ||
| 474 | + }) | ||
| 475 | + # Configure CORS on all routes. | ||
| 476 | + for route in list(appasync.router.routes()): | ||
| 477 | + cors.add(route) | ||
| 478 | + | ||
| 479 | + pagename='webrtcapi.html' | ||
| 480 | + if opt.transport=='rtmp': | ||
| 481 | + pagename='echoapi.html' | ||
| 482 | + elif opt.transport=='rtcpush': | ||
| 483 | + pagename='rtcpushapi.html' | ||
| 484 | + logger.info('start http server; http://<serverip>:'+str(opt.listenport)+'/'+pagename) | ||
| 485 | + def run_server(runner): | ||
| 486 | + loop = asyncio.new_event_loop() | ||
| 487 | + asyncio.set_event_loop(loop) | ||
| 488 | + loop.run_until_complete(runner.setup()) | ||
| 489 | + site = web.TCPSite(runner, '0.0.0.0', opt.listenport) | ||
| 490 | + loop.run_until_complete(site.start()) | ||
| 491 | + if opt.transport=='rtcpush': | ||
| 492 | + for k in range(opt.max_session): | ||
| 493 | + push_url = opt.push_url | ||
| 494 | + if k!=0: | ||
| 495 | + push_url = opt.push_url+str(k) | ||
| 496 | + loop.run_until_complete(run(push_url,k)) | ||
| 497 | + loop.run_forever() | ||
| 498 | + #Thread(target=run_server, args=(web.AppRunner(appasync),)).start() | ||
| 499 | + run_server(web.AppRunner(appasync)) | ||
| 500 | + | ||
| 501 | + #app.on_shutdown.append(on_shutdown) | ||
| 502 | + #app.router.add_post("/offer", offer) | ||
| 503 | + | ||
| 504 | + # print('start websocket server') | ||
| 505 | + # server = pywsgi.WSGIServer(('0.0.0.0', 8000), app, handler_class=WebSocketHandler) | ||
| 506 | + # server.serve_forever() | ||
| 507 | + | ||
| 508 | + |
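For the `--customvideo_config` option above, here is a hedged sketch of a config file and launch; the field names (`audiotype`, `imgpath`, `audiopath`) match what `basereal.py` reads in `__loadcustom`, while the paths and the clip itself are hypothetical:

```bash
# hypothetical assets; an audiotype value > 1 selects this clip via the /set_audiotype endpoint
cat > custom_config.json <<'EOF'
[
  {"audiotype": 2, "imgpath": "data/customvideo/img", "audiopath": "data/customvideo/audio.wav"}
]
EOF
python app.py --transport webrtc --model wav2lip --avatar_id wav2lip256_avatar1 \
  --customvideo_config custom_config.json
```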
assets/dataflow.png
0 → 100644
14.4 KB
assets/faq.md
0 → 100644
| 1 | +1. pytorch3d fails to install\ | ||
| 2 | + Build it from source | ||
| 3 | + | ||
| 4 | +```bash | ||
| 5 | +git clone https://github.com/facebookresearch/pytorch3d.git | ||
| 6 | +cd pytorch3d | ||
| 7 | +python setup.py install | ||
| 8 | +``` | ||
| 9 | + | ||
| 10 | +2. websocket connection error\ | ||
| 11 | + Edit python/site-packages/flask\_sockets.py | ||
| 12 | + | ||
| 13 | +```python | ||
| 14 | +# change | ||
| 15 | +self.url_map.add(Rule(rule, endpoint=f)) | ||
| 16 | +# to | ||
| 17 | +self.url_map.add(Rule(rule, endpoint=f, websocket=True)) | ||
| 18 | +``` | ||
| 19 | + | ||
| 20 | +3. protobuf version too high | ||
| 21 | + | ||
| 22 | +```bash | ||
| 23 | +pip uninstall protobuf | ||
| 24 | +pip install protobuf==3.20.1 | ||
| 25 | +``` | ||
| 26 | + | ||
| 27 | +4. The digital human does not blink\ | ||
| 28 | +Add the following step when training the model | ||
| 29 | + | ||
| 30 | +> Obtain AU45 for eyes blinking.\ | ||
| 31 | +> Run FeatureExtraction in OpenFace, rename and move the output CSV file to data/\<ID>/au.csv. | ||
| 32 | + | ||
| 33 | +Copy au.csv into this project's data directory | ||
| 34 | + | ||
| 35 | +5. Add a background image to the digital human | ||
| 36 | + | ||
| 37 | +```bash | ||
| 38 | +python app.py --bg_img bc.jpg | ||
| 39 | +``` | ||
| 40 | + | ||
| 41 | +6. Dimension-mismatch error with a self-trained model\ | ||
| 42 | +Use wav2vec to extract audio features when training | ||
| 43 | + | ||
| 44 | +```bash | ||
| 45 | +python main.py data/ --workspace workspace/ -O --iters 100000 --asr_model cpierse/wav2vec2-large-xlsr-53-esperanto | ||
| 46 | +``` | ||
| 47 | + | ||
| 48 | +7. Wrong ffmpeg version when streaming over rtmp\ | ||
| 49 | +Users report that version 4.2.2 works; I am not sure exactly which versions fail. The rule of thumb: run ffmpeg and check that the printed configuration includes libx264; without it, streaming will not work | ||
| 50 | +``` | ||
| 51 | +--enable-libx264 | ||
| 52 | +``` | ||
| 53 | +8. Replace with your own trained model | ||
| 54 | +``` | ||
| 55 | +. | ||
| 56 | +├── data | ||
| 57 | +│ ├── data_kf.json (corresponds to transforms_train.json from the training data) | ||
| 58 | +│ ├── au.csv | ||
| 59 | +│ ├── pretrained | ||
| 60 | +│ └── └── ngp_kf.pth (corresponds to the trained model ngp_ep00xx.pth) | ||
| 61 | +``` | ||
| 62 | + | ||
| 63 | +Other references:\ | ||
| 64 | +https://github.com/lipku/metahuman-stream/issues/43#issuecomment-2008930101 |
assets/main.png
0 → 100644
182 KB
assets/qrcode_wechat.jpg
0 → 100644
27.1 KB
baseasr.py
0 → 100644
| 1 | +############################################################################### | ||
| 2 | +# Copyright (C) 2024 LiveTalking@lipku https://github.com/lipku/LiveTalking | ||
| 3 | +# email: lipku@foxmail.com | ||
| 4 | +# | ||
| 5 | +# Licensed under the Apache License, Version 2.0 (the "License"); | ||
| 6 | +# you may not use this file except in compliance with the License. | ||
| 7 | +# You may obtain a copy of the License at | ||
| 8 | +# | ||
| 9 | +# http://www.apache.org/licenses/LICENSE-2.0 | ||
| 10 | +# | ||
| 11 | +# Unless required by applicable law or agreed to in writing, software | ||
| 12 | +# distributed under the License is distributed on an "AS IS" BASIS, | ||
| 13 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| 14 | +# See the License for the specific language governing permissions and | ||
| 15 | +# limitations under the License. | ||
| 16 | +############################################################################### | ||
| 17 | + | ||
| 18 | +import time | ||
| 19 | +import numpy as np | ||
| 20 | + | ||
| 21 | +import queue | ||
| 22 | +from queue import Queue | ||
| 23 | +import torch.multiprocessing as mp | ||
| 24 | + | ||
| 25 | +from basereal import BaseReal | ||
| 26 | + | ||
| 27 | + | ||
| 28 | +class BaseASR: | ||
| 29 | + def __init__(self, opt, parent:BaseReal = None): | ||
| 30 | + self.opt = opt | ||
| 31 | + self.parent = parent | ||
| 32 | + | ||
| 33 | + self.fps = opt.fps # 20 ms per frame | ||
| 34 | + self.sample_rate = 16000 | ||
| 35 | + self.chunk = self.sample_rate // self.fps # 320 samples per chunk (20ms * 16000 / 1000) | ||
| 36 | + self.queue = Queue() | ||
| 37 | + self.output_queue = mp.Queue() | ||
| 38 | + | ||
| 39 | + self.batch_size = opt.batch_size | ||
| 40 | + | ||
| 41 | + self.frames = [] | ||
| 42 | + self.stride_left_size = opt.l | ||
| 43 | + self.stride_right_size = opt.r | ||
| 44 | + #self.context_size = 10 | ||
| 45 | + self.feat_queue = mp.Queue(2) | ||
| 46 | + | ||
| 47 | + #self.warm_up() | ||
| 48 | + | ||
| 49 | + def flush_talk(self): | ||
| 50 | + self.queue.queue.clear() | ||
| 51 | + | ||
| 52 | + def put_audio_frame(self,audio_chunk,eventpoint=None): #16khz 20ms pcm | ||
| 53 | + self.queue.put((audio_chunk,eventpoint)) | ||
| 54 | + | ||
| 55 | + #return frame:audio pcm; type: 0-normal speak, 1-silence; eventpoint:custom event sync with audio | ||
| 56 | + def get_audio_frame(self): | ||
| 57 | + try: | ||
| 58 | + frame,eventpoint = self.queue.get(block=True,timeout=0.01) | ||
| 59 | + type = 0 | ||
| 60 | + #print(f'[INFO] get frame {frame.shape}') | ||
| 61 | + except queue.Empty: | ||
| 62 | + if self.parent and self.parent.curr_state>1: #play custom audio | ||
| 63 | + frame = self.parent.get_audio_stream(self.parent.curr_state) | ||
| 64 | + type = self.parent.curr_state | ||
| 65 | + else: | ||
| 66 | + frame = np.zeros(self.chunk, dtype=np.float32) | ||
| 67 | + type = 1 | ||
| 68 | + eventpoint = None | ||
| 69 | + | ||
| 70 | + return frame,type,eventpoint | ||
| 71 | + | ||
| 72 | + #return frame:audio pcm; type: 0-normal speak, 1-silence; eventpoint:custom event sync with audio | ||
| 73 | + def get_audio_out(self): | ||
| 74 | + return self.output_queue.get() | ||
| 75 | + | ||
| 76 | + def warm_up(self): | ||
| 77 | + for _ in range(self.stride_left_size + self.stride_right_size): | ||
| 78 | + audio_frame,type,eventpoint=self.get_audio_frame() | ||
| 79 | + self.frames.append(audio_frame) | ||
| 80 | + self.output_queue.put((audio_frame,type,eventpoint)) | ||
| 81 | + for _ in range(self.stride_left_size): | ||
| 82 | + self.output_queue.get() | ||
| 83 | + | ||
| 84 | + def run_step(self): | ||
| 85 | + pass | ||
| 86 | + | ||
| 87 | + def get_next_feat(self,block,timeout): | ||
| 88 | + return self.feat_queue.get(block,timeout) |
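`BaseASR` above consumes 16 kHz PCM in 20 ms chunks (`chunk = 16000 // 50 = 320` samples). Externally supplied audio reaches it through the `/humanaudio` endpoint in `app.py`, which downloads `file_url` when it starts with `http` and feeds the bytes in via `put_audio_file`. A minimal sketch (the URL is a placeholder):

```bash
curl -X POST http://serverip:8010/humanaudio \
  -H 'Content-Type: application/json' \
  -d '{"sessionid": 0, "file_url": "http://example.com/hello.wav"}'
```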
basereal.py
0 → 100644
| 1 | +############################################################################### | ||
| 2 | +# Copyright (C) 2024 LiveTalking@lipku https://github.com/lipku/LiveTalking | ||
| 3 | +# email: lipku@foxmail.com | ||
| 4 | +# | ||
| 5 | +# Licensed under the Apache License, Version 2.0 (the "License"); | ||
| 6 | +# you may not use this file except in compliance with the License. | ||
| 7 | +# You may obtain a copy of the License at | ||
| 8 | +# | ||
| 9 | +# http://www.apache.org/licenses/LICENSE-2.0 | ||
| 10 | +# | ||
| 11 | +# Unless required by applicable law or agreed to in writing, software | ||
| 12 | +# distributed under the License is distributed on an "AS IS" BASIS, | ||
| 13 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| 14 | +# See the License for the specific language governing permissions and | ||
| 15 | +# limitations under the License. | ||
| 16 | +############################################################################### | ||
| 17 | + | ||
| 18 | +import math | ||
| 19 | +import torch | ||
| 20 | +import numpy as np | ||
| 21 | + | ||
| 22 | +import subprocess | ||
| 23 | +import os | ||
| 24 | +import time | ||
| 25 | +import cv2 | ||
| 26 | +import glob | ||
| 27 | +import resampy | ||
| 28 | + | ||
| 29 | +import queue | ||
| 30 | +from queue import Queue | ||
| 31 | +from threading import Thread, Event | ||
| 32 | +from io import BytesIO | ||
| 33 | +import soundfile as sf | ||
| 34 | + | ||
| 35 | +import av | ||
| 36 | +from fractions import Fraction | ||
| 37 | + | ||
| 38 | +from ttsreal import EdgeTTS,VoitsTTS,XTTS,CosyVoiceTTS,FishTTS,TencentTTS | ||
| 39 | +from logger import logger | ||
| 40 | + | ||
| 41 | +from tqdm import tqdm | ||
| 42 | +def read_imgs(img_list): | ||
| 43 | + frames = [] | ||
| 44 | + logger.info('reading images...') | ||
| 45 | + for img_path in tqdm(img_list): | ||
| 46 | + frame = cv2.imread(img_path) | ||
| 47 | + frames.append(frame) | ||
| 48 | + return frames | ||
| 49 | + | ||
| 50 | +class BaseReal: | ||
| 51 | + def __init__(self, opt): | ||
| 52 | + self.opt = opt | ||
| 53 | + self.sample_rate = 16000 | ||
| 54 | + self.chunk = self.sample_rate // opt.fps # 320 samples per chunk (20ms * 16000 / 1000) | ||
| 55 | + self.sessionid = self.opt.sessionid | ||
| 56 | + | ||
| 57 | + if opt.tts == "edgetts": | ||
| 58 | + self.tts = EdgeTTS(opt,self) | ||
| 59 | + elif opt.tts == "gpt-sovits": | ||
| 60 | + self.tts = VoitsTTS(opt,self) | ||
| 61 | + elif opt.tts == "xtts": | ||
| 62 | + self.tts = XTTS(opt,self) | ||
| 63 | + elif opt.tts == "cosyvoice": | ||
| 64 | + self.tts = CosyVoiceTTS(opt,self) | ||
| 65 | + elif opt.tts == "fishtts": | ||
| 66 | + self.tts = FishTTS(opt,self) | ||
| 67 | + elif opt.tts == "tencent": | ||
| 68 | + self.tts = TencentTTS(opt,self) | ||
| 69 | + | ||
| 70 | + self.speaking = False | ||
| 71 | + | ||
| 72 | + self.recording = False | ||
| 73 | + self._record_video_pipe = None | ||
| 74 | + self._record_audio_pipe = None | ||
| 75 | + self.width = self.height = 0 | ||
| 76 | + | ||
| 77 | + self.curr_state=0 | ||
| 78 | + self.custom_img_cycle = {} | ||
| 79 | + self.custom_audio_cycle = {} | ||
| 80 | + self.custom_audio_index = {} | ||
| 81 | + self.custom_index = {} | ||
| 82 | + self.custom_opt = {} | ||
| 83 | + self.__loadcustom() | ||
| 84 | + | ||
| 85 | + def put_msg_txt(self,msg,eventpoint=None): | ||
| 86 | + self.tts.put_msg_txt(msg,eventpoint) | ||
| 87 | + | ||
| 88 | + def put_audio_frame(self,audio_chunk,eventpoint=None): #16khz 20ms pcm | ||
| 89 | + self.asr.put_audio_frame(audio_chunk,eventpoint) | ||
| 90 | + | ||
| 91 | + def put_audio_file(self,filebyte): | ||
| 92 | + input_stream = BytesIO(filebyte) | ||
| 93 | + stream = self.__create_bytes_stream(input_stream) | ||
| 94 | + streamlen = stream.shape[0] | ||
| 95 | + idx=0 | ||
| 96 | + while streamlen >= self.chunk: #and self.state==State.RUNNING | ||
| 97 | + self.put_audio_frame(stream[idx:idx+self.chunk]) | ||
| 98 | + streamlen -= self.chunk | ||
| 99 | + idx += self.chunk | ||
| 100 | + | ||
| 101 | + def __create_bytes_stream(self,byte_stream): | ||
| 102 | + #byte_stream=BytesIO(buffer) | ||
| 103 | + stream, sample_rate = sf.read(byte_stream) # [T*sample_rate,] float64 | ||
| 104 | + logger.info(f'[INFO] put audio stream {sample_rate}: {stream.shape}') | ||
| 105 | + stream = stream.astype(np.float32) | ||
| 106 | + | ||
| 107 | + if stream.ndim > 1: | ||
| 108 | + logger.info(f'[WARN] audio has {stream.shape[1]} channels, only use the first.') | ||
| 109 | + stream = stream[:, 0] | ||
| 110 | + | ||
| 111 | + if sample_rate != self.sample_rate and stream.shape[0]>0: | ||
| 112 | + logger.info(f'[WARN] audio sample rate is {sample_rate}, resampling into {self.sample_rate}.') | ||
| 113 | + stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self.sample_rate) | ||
| 114 | + | ||
| 115 | + return stream | ||
| 116 | + | ||
| 117 | + def flush_talk(self): | ||
| 118 | + self.tts.flush_talk() | ||
| 119 | + self.asr.flush_talk() | ||
| 120 | + | ||
| 121 | + def is_speaking(self)->bool: | ||
| 122 | + return self.speaking | ||
| 123 | + | ||
| 124 | + def __loadcustom(self): | ||
| 125 | + for item in self.opt.customopt: | ||
| 126 | + logger.info(item) | ||
| 127 | + input_img_list = glob.glob(os.path.join(item['imgpath'], '*.[jpJP][pnPN]*[gG]')) | ||
| 128 | + input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0])) | ||
| 129 | + self.custom_img_cycle[item['audiotype']] = read_imgs(input_img_list) | ||
| 130 | + self.custom_audio_cycle[item['audiotype']], sample_rate = sf.read(item['audiopath'], dtype='float32') | ||
| 131 | + self.custom_audio_index[item['audiotype']] = 0 | ||
| 132 | + self.custom_index[item['audiotype']] = 0 | ||
| 133 | + self.custom_opt[item['audiotype']] = item | ||
| 134 | + | ||
| 135 | + def init_customindex(self): | ||
| 136 | + self.curr_state=0 | ||
| 137 | + for key in self.custom_audio_index: | ||
| 138 | + self.custom_audio_index[key]=0 | ||
| 139 | + for key in self.custom_index: | ||
| 140 | + self.custom_index[key]=0 | ||
| 141 | + | ||
| 142 | + def notify(self,eventpoint): | ||
| 143 | + logger.info("notify:%s",eventpoint) | ||
| 144 | + | ||
| 145 | + def start_recording(self): | ||
| 146 | + """开始录制视频""" | ||
| 147 | + if self.recording: | ||
| 148 | + return | ||
| 149 | + | ||
| 150 | + command = ['ffmpeg', | ||
| 151 | + '-y', '-an', | ||
| 152 | + '-f', 'rawvideo', | ||
| 153 | + '-vcodec','rawvideo', | ||
| 154 | + '-pix_fmt', 'bgr24', # pixel format | ||
| 155 | + '-s', "{}x{}".format(self.width, self.height), | ||
| 156 | + '-r', str(25), | ||
| 157 | + '-i', '-', | ||
| 158 | + '-pix_fmt', 'yuv420p', | ||
| 159 | + '-vcodec', "h264", | ||
| 160 | + #'-f' , 'flv', | ||
| 161 | + f'temp{self.opt.sessionid}.mp4'] | ||
| 162 | + self._record_video_pipe = subprocess.Popen(command, shell=False, stdin=subprocess.PIPE) | ||
| 163 | + | ||
| 164 | + acommand = ['ffmpeg', | ||
| 165 | + '-y', '-vn', | ||
| 166 | + '-f', 's16le', | ||
| 167 | + #'-acodec','pcm_s16le', | ||
| 168 | + '-ac', '1', | ||
| 169 | + '-ar', '16000', | ||
| 170 | + '-i', '-', | ||
| 171 | + '-acodec', 'aac', | ||
| 172 | + #'-f' , 'wav', | ||
| 173 | + f'temp{self.opt.sessionid}.aac'] | ||
| 174 | + self._record_audio_pipe = subprocess.Popen(acommand, shell=False, stdin=subprocess.PIPE) | ||
| 175 | + | ||
| 176 | + self.recording = True | ||
| 177 | + # self.recordq_video.queue.clear() | ||
| 178 | + # self.recordq_audio.queue.clear() | ||
| 179 | + # self.container = av.open(path, mode="w") | ||
| 180 | + | ||
| 181 | + # process_thread = Thread(target=self.record_frame, args=()) | ||
| 182 | + # process_thread.start() | ||
| 183 | + | ||
| 184 | + def record_video_data(self,image): | ||
| 185 | + if self.width == 0: | ||
| 186 | + print("image.shape:",image.shape) | ||
| 187 | + self.height,self.width,_ = image.shape | ||
| 188 | + if self.recording: | ||
| 189 | + self._record_video_pipe.stdin.write(image.tobytes()) # ndarray.tostring() is deprecated; use tobytes() | ||
| 190 | + | ||
| 191 | + def record_audio_data(self,frame): | ||
| 192 | + if self.recording: | ||
| 193 | + self._record_audio_pipe.stdin.write(frame.tobytes()) | ||
| 194 | + | ||
| 195 | + # def record_frame(self): | ||
| 196 | + # videostream = self.container.add_stream("libx264", rate=25) | ||
| 197 | + # videostream.codec_context.time_base = Fraction(1, 25) | ||
| 198 | + # audiostream = self.container.add_stream("aac") | ||
| 199 | + # audiostream.codec_context.time_base = Fraction(1, 16000) | ||
| 200 | + # init = True | ||
| 201 | + # framenum = 0 | ||
| 202 | + # while self.recording: | ||
| 203 | + # try: | ||
| 204 | + # videoframe = self.recordq_video.get(block=True, timeout=1) | ||
| 205 | + # videoframe.pts = framenum #int(round(framenum*0.04 / videostream.codec_context.time_base)) | ||
| 206 | + # videoframe.dts = videoframe.pts | ||
| 207 | + # if init: | ||
| 208 | + # videostream.width = videoframe.width | ||
| 209 | + # videostream.height = videoframe.height | ||
| 210 | + # init = False | ||
| 211 | + # for packet in videostream.encode(videoframe): | ||
| 212 | + # self.container.mux(packet) | ||
| 213 | + # for k in range(2): | ||
| 214 | + # audioframe = self.recordq_audio.get(block=True, timeout=1) | ||
| 215 | + # audioframe.pts = int(round((framenum*2+k)*0.02 / audiostream.codec_context.time_base)) | ||
| 216 | + # audioframe.dts = audioframe.pts | ||
| 217 | + # for packet in audiostream.encode(audioframe): | ||
| 218 | + # self.container.mux(packet) | ||
| 219 | + # framenum += 1 | ||
| 220 | + # except queue.Empty: | ||
| 221 | + # print('record queue empty,') | ||
| 222 | + # continue | ||
| 223 | + # except Exception as e: | ||
| 224 | + # print(e) | ||
| 225 | + # #break | ||
| 226 | + # for packet in videostream.encode(None): | ||
| 227 | + # self.container.mux(packet) | ||
| 228 | + # for packet in audiostream.encode(None): | ||
| 229 | + # self.container.mux(packet) | ||
| 230 | + # self.container.close() | ||
| 231 | + # self.recordq_video.queue.clear() | ||
| 232 | + # self.recordq_audio.queue.clear() | ||
| 233 | + # print('record thread stop') | ||
| 234 | + | ||
| 235 | + def stop_recording(self): | ||
| 236 | + """停止录制视频""" | ||
| 237 | + if not self.recording: | ||
| 238 | + return | ||
| 239 | + self.recording = False | ||
| 240 | + self._record_video_pipe.stdin.close() #wait() | ||
| 241 | + self._record_video_pipe.wait() | ||
| 242 | + self._record_audio_pipe.stdin.close() | ||
| 243 | + self._record_audio_pipe.wait() | ||
| 244 | + cmd_combine_audio = f"ffmpeg -y -i temp{self.opt.sessionid}.aac -i temp{self.opt.sessionid}.mp4 -c:v copy -c:a copy data/record.mp4" | ||
| 245 | + os.system(cmd_combine_audio) | ||
| 246 | + #os.remove(output_path) | ||
| 247 | + | ||
| 248 | + def mirror_index(self,size, index): | ||
| 249 | + #size = len(self.coord_list_cycle) | ||
| 250 | + turn = index // size | ||
| 251 | + res = index % size | ||
| 252 | + if turn % 2 == 0: | ||
| 253 | + return res | ||
| 254 | + else: | ||
| 255 | + return size - res - 1 | ||
| 256 | + | ||
| 257 | + def get_audio_stream(self,audiotype): | ||
| 258 | + idx = self.custom_audio_index[audiotype] | ||
| 259 | + stream = self.custom_audio_cycle[audiotype][idx:idx+self.chunk] | ||
| 260 | + self.custom_audio_index[audiotype] += self.chunk | ||
| 261 | + if self.custom_audio_index[audiotype]>=self.custom_audio_cycle[audiotype].shape[0]: | ||
| 262 | + self.curr_state = 1 # custom clip does not loop; switch to the silent state | ||
| 263 | + return stream | ||
| 264 | + | ||
| 265 | + def set_curr_state(self,audiotype, reinit): | ||
| 266 | + print('set_curr_state:',audiotype) | ||
| 267 | + self.curr_state = audiotype | ||
| 268 | + if reinit: | ||
| 269 | + self.custom_audio_index[audiotype] = 0 | ||
| 270 | + self.custom_index[audiotype] = 0 | ||
| 271 | + | ||
| 272 | + # def process_custom(self,audiotype:int,idx:int): | ||
| 273 | + # if self.curr_state!=audiotype: # switch from inference to scripted speech | ||
| 274 | + # if idx in self.switch_pos: # switching is allowed at sync points | ||
| 275 | + # self.curr_state=audiotype | ||
| 276 | + # self.custom_index=0 | ||
| 277 | + # else: | ||
| 278 | + # self.custom_index+=1 |
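As a standalone sanity check of `BaseReal.mirror_index` (reimplemented here so it runs without the class): indices ping-pong through `0..size-1` and back, which lets the avatar frame cycle loop without a visible jump.

```python
# Same logic as BaseReal.mirror_index, extracted as a free function for testing.
def mirror_index(size, index):
    turn = index // size
    res = index % size
    return res if turn % 2 == 0 else size - res - 1

assert [mirror_index(3, i) for i in range(8)] == [0, 1, 2, 2, 1, 0, 0, 1]
```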
| 1 | +# Routines for DeepSpeech feature processing | ||
| 2 | +Several routines for [DeepSpeech](https://github.com/mozilla/DeepSpeech) feature processing, such as speech feature generation for the [VOCA](https://github.com/TimoBolkart/voca) model. | ||
| 3 | + | ||
| 4 | +## Installation | ||
| 5 | + | ||
| 6 | +``` | ||
| 7 | +pip3 install -r requirements.txt | ||
| 8 | +``` | ||
| 9 | + | ||
| 10 | +## Usage | ||
| 11 | + | ||
| 12 | +Generate wav files: | ||
| 13 | +``` | ||
| 14 | +python3 extract_wav.py --in-video=<your_data_dir> | ||
| 15 | +``` | ||
| 16 | + | ||
| 17 | +Generate files with DeepSpeech features: | ||
| 18 | +``` | ||
| 19 | +python3 extract_ds_features.py --input=<your_data_dir> | ||
| 20 | +``` |
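With the defaults above, each output `.npy` holds one 16-frame window of 29-dimensional DeepSpeech logits per video frame; a quick shape check (path is a placeholder):

```
python3 -c "import numpy as np; print(np.load('<your_data_dir>/sample.npy').shape)"
# expected: (num_frames, 16, 29)
```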
| 1 | +""" | ||
| 2 | + DeepSpeech features processing routines. | ||
| 3 | + NB: Based on VOCA code. See the corresponding license restrictions. | ||
| 4 | +""" | ||
| 5 | + | ||
| 6 | +__all__ = ['conv_audios_to_deepspeech'] | ||
| 7 | + | ||
| 8 | +import numpy as np | ||
| 9 | +import warnings | ||
| 10 | +import resampy | ||
| 11 | +from scipy.io import wavfile | ||
| 12 | +from python_speech_features import mfcc | ||
| 13 | +import tensorflow.compat.v1 as tf | ||
| 14 | +tf.disable_v2_behavior() | ||
| 15 | + | ||
| 16 | +def conv_audios_to_deepspeech(audios, | ||
| 17 | + out_files, | ||
| 18 | + num_frames_info, | ||
| 19 | + deepspeech_pb_path, | ||
| 20 | + audio_window_size=1, | ||
| 21 | + audio_window_stride=1): | ||
| 22 | + """ | ||
| 23 | + Convert list of audio files into files with DeepSpeech features. | ||
| 24 | + | ||
| 25 | + Parameters | ||
| 26 | + ---------- | ||
| 27 | + audios : list of str or list of None | ||
| 28 | + Paths to input audio files. | ||
| 29 | + out_files : list of str | ||
| 30 | + Paths to output files with DeepSpeech features. | ||
| 31 | + num_frames_info : list of int | ||
| 32 | + List of numbers of frames. | ||
| 33 | + deepspeech_pb_path : str | ||
| 34 | + Path to DeepSpeech 0.1.0 frozen model. | ||
| 35 | + audio_window_size : int, default 1 | ||
| 36 | + Audio window size. | ||
| 37 | + audio_window_stride : int, default 1 | ||
| 38 | + Audio window stride. | ||
| 39 | + """ | ||
| 40 | + # deepspeech_pb_path="/disk4/keyu/DeepSpeech/deepspeech-0.9.2-models.pbmm" | ||
| 41 | + graph, logits_ph, input_node_ph, input_lengths_ph = prepare_deepspeech_net( | ||
| 42 | + deepspeech_pb_path) | ||
| 43 | + | ||
| 44 | + with tf.compat.v1.Session(graph=graph) as sess: | ||
| 45 | + for audio_file_path, out_file_path, num_frames in zip(audios, out_files, num_frames_info): | ||
| 46 | + print(audio_file_path) | ||
| 47 | + print(out_file_path) | ||
| 48 | + audio_sample_rate, audio = wavfile.read(audio_file_path) | ||
| 49 | + if audio.ndim != 1: | ||
| 50 | + warnings.warn( | ||
| 51 | + "Audio has multiple channels, the first channel is used") | ||
| 52 | + audio = audio[:, 0] | ||
| 53 | + ds_features = pure_conv_audio_to_deepspeech( | ||
| 54 | + audio=audio, | ||
| 55 | + audio_sample_rate=audio_sample_rate, | ||
| 56 | + audio_window_size=audio_window_size, | ||
| 57 | + audio_window_stride=audio_window_stride, | ||
| 58 | + num_frames=num_frames, | ||
| 59 | + net_fn=lambda x: sess.run( | ||
| 60 | + logits_ph, | ||
| 61 | + feed_dict={ | ||
| 62 | + input_node_ph: x[np.newaxis, ...], | ||
| 63 | + input_lengths_ph: [x.shape[0]]})) | ||
| 64 | + | ||
| 65 | + net_output = ds_features.reshape(-1, 29) | ||
| 66 | + win_size = 16 | ||
| 67 | + zero_pad = np.zeros((int(win_size / 2), net_output.shape[1])) | ||
| 68 | + net_output = np.concatenate( | ||
| 69 | + (zero_pad, net_output, zero_pad), axis=0) | ||
| 70 | + windows = [] | ||
| 71 | + for window_index in range(0, net_output.shape[0] - win_size, 2): | ||
| 72 | + windows.append( | ||
| 73 | + net_output[window_index:window_index + win_size]) | ||
| 74 | + print(np.array(windows).shape) | ||
| 75 | + np.save(out_file_path, np.array(windows)) | ||
| 76 | + | ||
| 77 | + | ||
| 78 | +def prepare_deepspeech_net(deepspeech_pb_path): | ||
| 79 | + """ | ||
| 80 | + Load and prepare DeepSpeech network. | ||
| 81 | + | ||
| 82 | + Parameters | ||
| 83 | + ---------- | ||
| 84 | + deepspeech_pb_path : str | ||
| 85 | + Path to DeepSpeech 0.1.0 frozen model. | ||
| 86 | + | ||
| 87 | + Returns | ||
| 88 | + ------- | ||
| 89 | + graph : obj | ||
| 90 | + TensorFlow graph. | ||
| 91 | + logits_ph : obj | ||
| 92 | + TensorFlow placeholder for `logits`. | ||
| 93 | + input_node_ph : obj | ||
| 94 | + TensorFlow placeholder for `input_node`. | ||
| 95 | + input_lengths_ph : obj | ||
| 96 | + TensorFlow placeholder for `input_lengths`. | ||
| 97 | + """ | ||
| 98 | + # Load graph and place_holders: | ||
| 99 | + with tf.io.gfile.GFile(deepspeech_pb_path, "rb") as f: | ||
| 100 | + graph_def = tf.compat.v1.GraphDef() | ||
| 101 | + graph_def.ParseFromString(f.read()) | ||
| 102 | + | ||
| 103 | + graph = tf.compat.v1.get_default_graph() | ||
| 104 | + tf.import_graph_def(graph_def, name="deepspeech") | ||
| 105 | + logits_ph = graph.get_tensor_by_name("deepspeech/logits:0") | ||
| 106 | + input_node_ph = graph.get_tensor_by_name("deepspeech/input_node:0") | ||
| 107 | + input_lengths_ph = graph.get_tensor_by_name("deepspeech/input_lengths:0") | ||
| 108 | + | ||
| 109 | + return graph, logits_ph, input_node_ph, input_lengths_ph | ||
| 110 | + | ||
| 111 | + | ||
| 112 | +def pure_conv_audio_to_deepspeech(audio, | ||
| 113 | + audio_sample_rate, | ||
| 114 | + audio_window_size, | ||
| 115 | + audio_window_stride, | ||
| 116 | + num_frames, | ||
| 117 | + net_fn): | ||
| 118 | + """ | ||
| 119 | + Core routine for converting audio into DeepSpeech features. | ||
| 120 | + | ||
| 121 | + Parameters | ||
| 122 | + ---------- | ||
| 123 | + audio : np.array | ||
| 124 | + Audio data. | ||
| 125 | + audio_sample_rate : int | ||
| 126 | + Audio sample rate. | ||
| 127 | + audio_window_size : int | ||
| 128 | + Audio window size. | ||
| 129 | + audio_window_stride : int | ||
| 130 | + Audio window stride. | ||
| 131 | + num_frames : int or None | ||
| 132 | + Numbers of frames. | ||
| 133 | + net_fn : func | ||
| 134 | + Function for DeepSpeech model call. | ||
| 135 | + | ||
| 136 | + Returns | ||
| 137 | + ------- | ||
| 138 | + np.array | ||
| 139 | + DeepSpeech features. | ||
| 140 | + """ | ||
| 141 | + target_sample_rate = 16000 | ||
| 142 | + if audio_sample_rate != target_sample_rate: | ||
| 143 | + resampled_audio = resampy.resample( | ||
| 144 | + x=audio.astype(np.float64), # the np.float alias was removed from NumPy | ||
| 145 | + sr_orig=audio_sample_rate, | ||
| 146 | + sr_new=target_sample_rate) | ||
| 147 | + else: | ||
| 148 | + resampled_audio = audio.astype(np.float64) | ||
| 149 | + input_vector = conv_audio_to_deepspeech_input_vector( | ||
| 150 | + audio=resampled_audio.astype(np.int16), | ||
| 151 | + sample_rate=target_sample_rate, | ||
| 152 | + num_cepstrum=26, | ||
| 153 | + num_context=9) | ||
| 154 | + | ||
| 155 | + network_output = net_fn(input_vector) | ||
| 156 | + # print(network_output.shape) | ||
| 157 | + | ||
| 158 | + deepspeech_fps = 50 | ||
| 159 | + video_fps = 50 # Change this option if video fps is different | ||
| 160 | + audio_len_s = float(audio.shape[0]) / audio_sample_rate | ||
| 161 | + if num_frames is None: | ||
| 162 | + num_frames = int(round(audio_len_s * video_fps)) | ||
| 163 | + else: | ||
| 164 | + video_fps = num_frames / audio_len_s | ||
| 165 | + network_output = interpolate_features( | ||
| 166 | + features=network_output[:, 0], | ||
| 167 | + input_rate=deepspeech_fps, | ||
| 168 | + output_rate=video_fps, | ||
| 169 | + output_len=num_frames) | ||
| 170 | + | ||
| 171 | + # Make windows: | ||
| 172 | + zero_pad = np.zeros((int(audio_window_size / 2), network_output.shape[1])) | ||
| 173 | + network_output = np.concatenate( | ||
| 174 | + (zero_pad, network_output, zero_pad), axis=0) | ||
| 175 | + windows = [] | ||
| 176 | + for window_index in range(0, network_output.shape[0] - audio_window_size, audio_window_stride): | ||
| 177 | + windows.append( | ||
| 178 | + network_output[window_index:window_index + audio_window_size]) | ||
| 179 | + | ||
| 180 | + return np.array(windows) | ||
| 181 | + | ||
| 182 | + | ||
| 183 | +def conv_audio_to_deepspeech_input_vector(audio, | ||
| 184 | + sample_rate, | ||
| 185 | + num_cepstrum, | ||
| 186 | + num_context): | ||
| 187 | + """ | ||
| 188 | + Convert audio raw data into DeepSpeech input vector. | ||
| 189 | + | ||
| 190 | + Parameters | ||
| 191 | + ---------- | ||
| 192 | + audio : np.array | ||
| 193 | + Audio data. | ||
| 194 | + sample_rate : int | ||
| 195 | + Audio sample rate. | ||
| 196 | + num_cepstrum : int | ||
| 197 | + Number of cepstrum. | ||
| 198 | + num_context : int | ||
| 199 | + Number of context. | ||
| 200 | + | ||
| 201 | + Returns | ||
| 202 | + ------- | ||
| 203 | + np.array | ||
| 204 | + DeepSpeech input vector. | ||
| 205 | + """ | ||
| 206 | + # Get mfcc coefficients: | ||
| 207 | + features = mfcc( | ||
| 208 | + signal=audio, | ||
| 209 | + samplerate=sample_rate, | ||
| 210 | + numcep=num_cepstrum) | ||
| 211 | + | ||
| 212 | + # We only keep every second feature (BiRNN stride = 2): | ||
| 213 | + features = features[::2] | ||
| 214 | + | ||
| 215 | + # One stride per time step in the input: | ||
| 216 | + num_strides = len(features) | ||
| 217 | + | ||
| 218 | + # Add empty initial and final contexts: | ||
| 219 | + empty_context = np.zeros((num_context, num_cepstrum), dtype=features.dtype) | ||
| 220 | + features = np.concatenate((empty_context, features, empty_context)) | ||
| 221 | + | ||
| 222 | + # Create a view into the array with overlapping strides of size | ||
| 223 | + # numcontext (past) + 1 (present) + numcontext (future): | ||
| 224 | + window_size = 2 * num_context + 1 | ||
| 225 | + train_inputs = np.lib.stride_tricks.as_strided( | ||
| 226 | + features, | ||
| 227 | + shape=(num_strides, window_size, num_cepstrum), | ||
| 228 | + strides=(features.strides[0], | ||
| 229 | + features.strides[0], features.strides[1]), | ||
| 230 | + writeable=False) | ||
| 231 | + | ||
| 232 | + # Flatten the second and third dimensions: | ||
| 233 | + train_inputs = np.reshape(train_inputs, [num_strides, -1]) | ||
| 234 | + | ||
| 235 | + train_inputs = np.copy(train_inputs) | ||
| 236 | + train_inputs = (train_inputs - np.mean(train_inputs)) / \ | ||
| 237 | + np.std(train_inputs) | ||
| 238 | + | ||
| 239 | + return train_inputs | ||
| 240 | + | ||
| 241 | + | ||
| 242 | +def interpolate_features(features, | ||
| 243 | + input_rate, | ||
| 244 | + output_rate, | ||
| 245 | + output_len): | ||
| 246 | + """ | ||
| 247 | + Interpolate DeepSpeech features. | ||
| 248 | + | ||
| 249 | + Parameters | ||
| 250 | + ---------- | ||
| 251 | + features : np.array | ||
| 252 | + DeepSpeech features. | ||
| 253 | + input_rate : int | ||
| 254 | + input rate (FPS). | ||
| 255 | + output_rate : int | ||
| 256 | + Output rate (FPS). | ||
| 257 | + output_len : int | ||
| 258 | + Output data length. | ||
| 259 | + | ||
| 260 | + Returns | ||
| 261 | + ------- | ||
| 262 | + np.array | ||
| 263 | + Interpolated data. | ||
| 264 | + """ | ||
| 265 | + input_len = features.shape[0] | ||
| 266 | + num_features = features.shape[1] | ||
| 267 | + input_timestamps = np.arange(input_len) / float(input_rate) | ||
| 268 | + output_timestamps = np.arange(output_len) / float(output_rate) | ||
| 269 | + output_features = np.zeros((output_len, num_features)) | ||
| 270 | + for feature_idx in range(num_features): | ||
| 271 | + output_features[:, feature_idx] = np.interp( | ||
| 272 | + x=output_timestamps, | ||
| 273 | + xp=input_timestamps, | ||
| 274 | + fp=features[:, feature_idx]) | ||
| 275 | + return output_features |
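A worked check of `interpolate_features` (assuming this file is importable as `deepspeech_features`): it is plain per-channel linear interpolation over timestamps via `np.interp`.

```python
import numpy as np
from deepspeech_features import interpolate_features  # assumes this module is on PYTHONPATH

feats = np.arange(4, dtype=np.float64).reshape(4, 1)  # 4 frames at 50 fps, 1 channel
out = interpolate_features(feats, input_rate=50, output_rate=100, output_len=7)
# input timestamps [0, .02, .04, .06] -> output timestamps [0, .01, ..., .06]
print(out.ravel())  # [0.  0.5 1.  1.5 2.  2.5 3. ]
```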
| 1 | +""" | ||
| 2 | + Routines for loading DeepSpeech model. | ||
| 3 | +""" | ||
| 4 | + | ||
| 5 | +__all__ = ['get_deepspeech_model_file'] | ||
| 6 | + | ||
| 7 | +import os | ||
| 8 | +import zipfile | ||
| 9 | +import logging | ||
| 10 | +import hashlib | ||
| 11 | + | ||
| 12 | + | ||
| 13 | +deepspeech_features_repo_url = 'https://github.com/osmr/deepspeech_features' | ||
| 14 | + | ||
| 15 | + | ||
| 16 | +def get_deepspeech_model_file(local_model_store_dir_path=os.path.join("~", ".tensorflow", "models")): | ||
| 17 | + """ | ||
| 18 | + Return the location of the pretrained model on the local file system. This function downloads from the online model zoo when | ||
| 19 | + the model cannot be found or its hash does not match. The root directory will be created if it doesn't exist. | ||
| 20 | + | ||
| 21 | + Parameters | ||
| 22 | + ---------- | ||
| 23 | + local_model_store_dir_path : str, default $TENSORFLOW_HOME/models | ||
| 24 | + Location for keeping the model parameters. | ||
| 25 | + | ||
| 26 | + Returns | ||
| 27 | + ------- | ||
| 28 | + file_path | ||
| 29 | + Path to the requested pretrained model file. | ||
| 30 | + """ | ||
| 31 | + sha1_hash = "b90017e816572ddce84f5843f1fa21e6a377975e" | ||
| 32 | + file_name = "deepspeech-0_1_0-b90017e8.pb" | ||
| 33 | + local_model_store_dir_path = os.path.expanduser(local_model_store_dir_path) | ||
| 34 | + file_path = os.path.join(local_model_store_dir_path, file_name) | ||
| 35 | + if os.path.exists(file_path): | ||
| 36 | + if _check_sha1(file_path, sha1_hash): | ||
| 37 | + return file_path | ||
| 38 | + else: | ||
| 39 | + logging.warning("Mismatch in the content of model file detected. Downloading again.") | ||
| 40 | + else: | ||
| 41 | + logging.info("Model file not found. Downloading to {}.".format(file_path)) | ||
| 42 | + | ||
| 43 | + if not os.path.exists(local_model_store_dir_path): | ||
| 44 | + os.makedirs(local_model_store_dir_path) | ||
| 45 | + | ||
| 46 | + zip_file_path = file_path + ".zip" | ||
| 47 | + _download( | ||
| 48 | + url="{repo_url}/releases/download/{repo_release_tag}/{file_name}.zip".format( | ||
| 49 | + repo_url=deepspeech_features_repo_url, | ||
| 50 | + repo_release_tag="v0.0.1", | ||
| 51 | + file_name=file_name), | ||
| 52 | + path=zip_file_path, | ||
| 53 | + overwrite=True) | ||
| 54 | + with zipfile.ZipFile(zip_file_path) as zf: | ||
| 55 | + zf.extractall(local_model_store_dir_path) | ||
| 56 | + os.remove(zip_file_path) | ||
| 57 | + | ||
| 58 | + if _check_sha1(file_path, sha1_hash): | ||
| 59 | + return file_path | ||
| 60 | + else: | ||
| 61 | + raise ValueError("Downloaded file has different hash. Please try again.") | ||
| 62 | + | ||
| 63 | + | ||
| 64 | +def _download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_ssl=True): | ||
| 65 | + """ | ||
| 66 | + Download a given URL. | ||
| 67 | + | ||
| 68 | + Parameters | ||
| 69 | + ---------- | ||
| 70 | + url : str | ||
| 71 | + URL to download | ||
| 72 | + path : str, optional | ||
| 73 | + Destination path to store downloaded file. By default stores to the | ||
| 74 | + current directory with same name as in url. | ||
| 75 | + overwrite : bool, optional | ||
| 76 | + Whether to overwrite destination file if already exists. | ||
| 77 | + sha1_hash : str, optional | ||
| 78 | + Expected sha1 hash in hexadecimal digits. Will ignore existing file when hash is specified | ||
| 79 | + but doesn't match. | ||
| 80 | + retries : integer, default 5 | ||
| 81 | + The number of times to attempt the download in case of failure or non-200 return codes | ||
| 82 | + verify_ssl : bool, default True | ||
| 83 | + Verify SSL certificates. | ||
| 84 | + | ||
| 85 | + Returns | ||
| 86 | + ------- | ||
| 87 | + str | ||
| 88 | + The file path of the downloaded file. | ||
| 89 | + """ | ||
| 90 | + import warnings | ||
| 91 | + try: | ||
| 92 | + import requests | ||
| 93 | + except ImportError: | ||
| 94 | + class requests_failed_to_import(object): | ||
| 95 | + pass | ||
| 96 | + requests = requests_failed_to_import | ||
| 97 | + | ||
| 98 | + if path is None: | ||
| 99 | + fname = url.split("/")[-1] | ||
| 100 | + # Empty filenames are invalid | ||
| 101 | + assert fname, "Can't construct file-name from this URL. Please set the `path` option manually." | ||
| 102 | + else: | ||
| 103 | + path = os.path.expanduser(path) | ||
| 104 | + if os.path.isdir(path): | ||
| 105 | + fname = os.path.join(path, url.split("/")[-1]) | ||
| 106 | + else: | ||
| 107 | + fname = path | ||
| 108 | + assert retries >= 0, "Number of retries should be at least 0" | ||
| 109 | + | ||
| 110 | + if not verify_ssl: | ||
| 111 | + warnings.warn( | ||
| 112 | + "Unverified HTTPS request is being made (verify_ssl=False). " | ||
| 113 | + "Adding certificate verification is strongly advised.") | ||
| 114 | + | ||
| 115 | + if overwrite or not os.path.exists(fname) or (sha1_hash and not _check_sha1(fname, sha1_hash)): | ||
| 116 | + dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname))) | ||
| 117 | + if not os.path.exists(dirname): | ||
| 118 | + os.makedirs(dirname) | ||
| 119 | + while retries + 1 > 0: | ||
| 120 | + # Disable pylint's too-broad-Exception warning | ||
| 121 | + # pylint: disable=W0703 | ||
| 122 | + try: | ||
| 123 | + print("Downloading {} from {}...".format(fname, url)) | ||
| 124 | + r = requests.get(url, stream=True, verify=verify_ssl) | ||
| 125 | + if r.status_code != 200: | ||
| 126 | + raise RuntimeError("Failed downloading url {}".format(url)) | ||
| 127 | + with open(fname, "wb") as f: | ||
| 128 | + for chunk in r.iter_content(chunk_size=1024): | ||
| 129 | + if chunk: # filter out keep-alive new chunks | ||
| 130 | + f.write(chunk) | ||
| 131 | + if sha1_hash and not _check_sha1(fname, sha1_hash): | ||
| 132 | + raise UserWarning("File {} is downloaded but the content hash does not match." | ||
| 133 | + " The repo may be outdated or download may be incomplete. " | ||
| 134 | + "If the `repo_url` is overridden, consider switching to " | ||
| 135 | + "the default repo.".format(fname)) | ||
| 136 | + break | ||
| 137 | + except Exception as e: | ||
| 138 | + retries -= 1 | ||
| 139 | + if retries <= 0: | ||
| 140 | + raise e | ||
| 141 | + else: | ||
| 142 | + print("download failed, retrying, {} attempt{} left" | ||
| 143 | + .format(retries, "s" if retries > 1 else "")) | ||
| 144 | + | ||
| 145 | + return fname | ||
| 146 | + | ||
| 147 | + | ||
| 148 | +def _check_sha1(filename, sha1_hash): | ||
| 149 | + """ | ||
| 150 | + Check whether the sha1 hash of the file content matches the expected hash. | ||
| 151 | + | ||
| 152 | + Parameters | ||
| 153 | + ---------- | ||
| 154 | + filename : str | ||
| 155 | + Path to the file. | ||
| 156 | + sha1_hash : str | ||
| 157 | + Expected sha1 hash in hexadecimal digits. | ||
| 158 | + | ||
| 159 | + Returns | ||
| 160 | + ------- | ||
| 161 | + bool | ||
| 162 | + Whether the file content matches the expected hash. | ||
| 163 | + """ | ||
| 164 | + sha1 = hashlib.sha1() | ||
| 165 | + with open(filename, "rb") as f: | ||
| 166 | + while True: | ||
| 167 | + data = f.read(1048576) | ||
| 168 | + if not data: | ||
| 169 | + break | ||
| 170 | + sha1.update(data) | ||
| 171 | + | ||
| 172 | + return sha1.hexdigest() == sha1_hash |
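Typical usage, sketched under the assumption that the module layout matches this repo: fetch (and cache) the frozen DeepSpeech 0.1.0 graph, then pass its path to `conv_audios_to_deepspeech`.

```python
from deepspeech_store import get_deepspeech_model_file

# Downloads on first call, verifies the SHA-1, and caches under ~/.tensorflow/models.
pb_path = get_deepspeech_model_file()
print(pb_path)  # .../deepspeech-0_1_0-b90017e8.pb
```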
| 1 | +""" | ||
| 2 | + Script for extracting DeepSpeech features from audio file. | ||
| 3 | +""" | ||
| 4 | + | ||
| 5 | +import os | ||
| 6 | +import argparse | ||
| 7 | +import numpy as np | ||
| 8 | +import pandas as pd | ||
| 9 | +from deepspeech_store import get_deepspeech_model_file | ||
| 10 | +from deepspeech_features import conv_audios_to_deepspeech | ||
| 11 | + | ||
| 12 | + | ||
| 13 | +def parse_args(): | ||
| 14 | + """ | ||
| 15 | + Create python script parameters. | ||
| 16 | + Returns | ||
| 17 | + ------- | ||
| 18 | + ArgumentParser | ||
| 19 | + Resulting args. | ||
| 20 | + """ | ||
| 21 | + parser = argparse.ArgumentParser( | ||
| 22 | + description="Extract DeepSpeech features from audio file", | ||
| 23 | + formatter_class=argparse.ArgumentDefaultsHelpFormatter) | ||
| 24 | + parser.add_argument( | ||
| 25 | + "--input", | ||
| 26 | + type=str, | ||
| 27 | + required=True, | ||
| 28 | + help="path to input audio file or directory") | ||
| 29 | + parser.add_argument( | ||
| 30 | + "--output", | ||
| 31 | + type=str, | ||
| 32 | + help="path to output file with DeepSpeech features") | ||
| 33 | + parser.add_argument( | ||
| 34 | + "--deepspeech", | ||
| 35 | + type=str, | ||
| 36 | + help="path to DeepSpeech 0.1.0 frozen model") | ||
| 37 | + parser.add_argument( | ||
| 38 | + "--metainfo", | ||
| 39 | + type=str, | ||
| 40 | + help="path to file with meta-information") | ||
| 41 | + | ||
| 42 | + args = parser.parse_args() | ||
| 43 | + return args | ||
| 44 | + | ||
| 45 | + | ||
| 46 | +def extract_features(in_audios, | ||
| 47 | + out_files, | ||
| 48 | + deepspeech_pb_path, | ||
| 49 | + metainfo_file_path=None): | ||
| 50 | + """ | ||
| 51 | + Extract DeepSpeech features from audio files. | ||
| 52 | + Parameters | ||
| 53 | + ---------- | ||
| 54 | + in_audios : list of str | ||
| 55 | + Paths to input audio files. | ||
| 56 | + out_files : list of str | ||
| 57 | + Paths to output files with DeepSpeech features. | ||
| 58 | + deepspeech_pb_path : str | ||
| 59 | + Path to DeepSpeech 0.1.0 frozen model. | ||
| 60 | + metainfo_file_path : str, default None | ||
| 61 | + Path to file with meta-information. | ||
| 62 | + """ | ||
| 63 | + #deepspeech_pb_path="/disk4/keyu/DeepSpeech/deepspeech-0.9.2-models.pbmm" | ||
| 64 | + if metainfo_file_path is None: | ||
| 65 | + num_frames_info = [None] * len(in_audios) | ||
| 66 | + else: | ||
| 67 | + train_df = pd.read_csv( | ||
| 68 | + metainfo_file_path, | ||
| 69 | + sep="\t", | ||
| 70 | + index_col=False, | ||
| 71 | + dtype={"Id": np.int, "File": np.unicode, "Count": np.int}) | ||
| 72 | + num_frames_info = train_df["Count"].values | ||
| 73 | + assert (len(num_frames_info) == len(in_audios)) | ||
| 74 | + | ||
| 75 | + for i, in_audio in enumerate(in_audios): | ||
| 76 | + if not out_files[i]: | ||
| 77 | + file_stem, _ = os.path.splitext(in_audio) | ||
| 78 | + out_files[i] = file_stem + ".npy" | ||
| 79 | + #print(out_files[i]) | ||
| 80 | + conv_audios_to_deepspeech( | ||
| 81 | + audios=in_audios, | ||
| 82 | + out_files=out_files, | ||
| 83 | + num_frames_info=num_frames_info, | ||
| 84 | + deepspeech_pb_path=deepspeech_pb_path) | ||
| 85 | + | ||
| 86 | + | ||
| 87 | +def main(): | ||
| 88 | + """ | ||
| 89 | + Main body of script. | ||
| 90 | + """ | ||
| 91 | + args = parse_args() | ||
| 92 | + in_audio = os.path.expanduser(args.input) | ||
| 93 | + if not os.path.exists(in_audio): | ||
| 94 | + raise Exception("Input file/directory doesn't exist: {}".format(in_audio)) | ||
| 95 | + deepspeech_pb_path = args.deepspeech | ||
| 96 | + # override: always use the cached default model path | ||
| 97 | + deepspeech_pb_path = True | ||
| 98 | + args.deepspeech = '~/.tensorflow/models/deepspeech-0_1_0-b90017e8.pb' | ||
| 99 | + #deepspeech_pb_path="/disk4/keyu/DeepSpeech/deepspeech-0.9.2-models.pbmm" | ||
| 100 | + if deepspeech_pb_path is None: | ||
| 101 | + deepspeech_pb_path = "" | ||
| 102 | + if deepspeech_pb_path: | ||
| 103 | + deepspeech_pb_path = os.path.expanduser(args.deepspeech) | ||
| 104 | + if not os.path.exists(deepspeech_pb_path): | ||
| 105 | + deepspeech_pb_path = get_deepspeech_model_file() | ||
| 106 | + if os.path.isfile(in_audio): | ||
| 107 | + extract_features( | ||
| 108 | + in_audios=[in_audio], | ||
| 109 | + out_files=[args.output], | ||
| 110 | + deepspeech_pb_path=deepspeech_pb_path, | ||
| 111 | + metainfo_file_path=args.metainfo) | ||
| 112 | + else: | ||
| 113 | + audio_file_paths = [] | ||
| 114 | + for file_name in os.listdir(in_audio): | ||
| 115 | + if not os.path.isfile(os.path.join(in_audio, file_name)): | ||
| 116 | + continue | ||
| 117 | + _, file_ext = os.path.splitext(file_name) | ||
| 118 | + if file_ext.lower() == ".wav": | ||
| 119 | + audio_file_path = os.path.join(in_audio, file_name) | ||
| 120 | + audio_file_paths.append(audio_file_path) | ||
| 121 | + audio_file_paths = sorted(audio_file_paths) | ||
| 122 | + out_file_paths = [""] * len(audio_file_paths) | ||
| 123 | + extract_features( | ||
| 124 | + in_audios=audio_file_paths, | ||
| 125 | + out_files=out_file_paths, | ||
| 126 | + deepspeech_pb_path=deepspeech_pb_path, | ||
| 127 | + metainfo_file_path=args.metainfo) | ||
| 128 | + | ||
| 129 | + | ||
| 130 | +if __name__ == "__main__": | ||
| 131 | + main() | ||
| 132 | + |
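Besides the directory mode shown in the README, a single-file invocation also works (paths below are placeholders); note that `main()` currently forces the cached default model path regardless of `--deepspeech`.

```
python3 extract_ds_features.py --input=<your_data_dir>/sample.wav --output=<your_data_dir>/sample.npy
```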
| 1 | +""" | ||
| 2 | + Script for extracting audio (16-bit, mono, 16000 Hz) from video file. | ||
| 3 | +""" | ||
| 4 | + | ||
| 5 | +import os | ||
| 6 | +import argparse | ||
| 7 | +import subprocess | ||
| 8 | + | ||
| 9 | + | ||
| 10 | +def parse_args(): | ||
| 11 | + """ | ||
| 12 | + Create python script parameters. | ||
| 13 | + | ||
| 14 | + Returns | ||
| 15 | + ------- | ||
| 16 | + ArgumentParser | ||
| 17 | + Resulting args. | ||
| 18 | + """ | ||
| 19 | + parser = argparse.ArgumentParser( | ||
| 20 | + description="Extract audio from video file", | ||
| 21 | + formatter_class=argparse.ArgumentDefaultsHelpFormatter) | ||
| 22 | + parser.add_argument( | ||
| 23 | + "--in-video", | ||
| 24 | + type=str, | ||
| 25 | + required=True, | ||
| 26 | + help="path to input video file or directory") | ||
| 27 | + parser.add_argument( | ||
| 28 | + "--out-audio", | ||
| 29 | + type=str, | ||
| 30 | + help="path to output audio file") | ||
| 31 | + | ||
| 32 | + args = parser.parse_args() | ||
| 33 | + return args | ||
| 34 | + | ||
| 35 | + | ||
| 36 | +def extract_audio(in_video, | ||
| 37 | + out_audio): | ||
| 38 | + """ | ||
| 39 | + Extract the audio track from a video file. | ||
| 40 | + | ||
| 41 | + Parameters | ||
| 42 | + ---------- | ||
| 43 | + in_video : str | ||
| 44 | + Path to input video file. | ||
| 45 | + out_audio : str | ||
| 46 | + Path to output audio file. | ||
| 47 | + """ | ||
| 48 | + if not out_audio: | ||
| 49 | + file_stem, _ = os.path.splitext(in_video) | ||
| 50 | + out_audio = file_stem + ".wav" | ||
| 51 | + # command1 = "ffmpeg -i {in_video} -vn -acodec copy {aac_audio}" | ||
| 52 | + # command2 = "ffmpeg -i {aac_audio} -vn -acodec pcm_s16le -ac 1 -ar 22000 {out_audio}" | ||
| 53 | + # command = "ffmpeg -i {in_video} -vn -acodec pcm_s16le -ac 1 -ar 22000 {out_audio}" | ||
| 54 | + command = "ffmpeg -i {in_video} -vn -acodec pcm_s16le -ac 1 -ar 16000 {out_audio}" | ||
| 55 | + subprocess.call([command.format(in_video=in_video, out_audio=out_audio)], shell=True) | ||
| 56 | + | ||
| 57 | + | ||
| 58 | +def main(): | ||
| 59 | + """ | ||
| 60 | + Main body of script. | ||
| 61 | + """ | ||
| 62 | + args = parse_args() | ||
| 63 | + in_video = os.path.expanduser(args.in_video) | ||
| 64 | + if not os.path.exists(in_video): | ||
| 65 | + raise Exception("Input file/directory doesn't exist: {}".format(in_video)) | ||
| 66 | + if os.path.isfile(in_video): | ||
| 67 | + extract_audio( | ||
| 68 | + in_video=in_video, | ||
| 69 | + out_audio=args.out_audio) | ||
| 70 | + else: | ||
| 71 | + video_file_paths = [] | ||
| 72 | + for file_name in os.listdir(in_video): | ||
| 73 | + if not os.path.isfile(os.path.join(in_video, file_name)): | ||
| 74 | + continue | ||
| 75 | + _, file_ext = os.path.splitext(file_name) | ||
| 76 | + if file_ext.lower() in (".mp4", ".mkv", ".avi"): | ||
| 77 | + video_file_path = os.path.join(in_video, file_name) | ||
| 78 | + video_file_paths.append(video_file_path) | ||
| 79 | + video_file_paths = sorted(video_file_paths) | ||
| 80 | + for video_file_path in video_file_paths: | ||
| 81 | + extract_audio( | ||
| 82 | + in_video=video_file_path, | ||
| 83 | + out_audio="") | ||
| 84 | + | ||
| 85 | + | ||
| 86 | +if __name__ == "__main__": | ||
| 87 | + main() |
| 1 | +import numpy as np | ||
| 2 | + | ||
| 3 | +net_output = np.load('french.ds.npy').reshape(-1, 29) | ||
| 4 | +win_size = 16 | ||
| 5 | +zero_pad = np.zeros((int(win_size / 2), net_output.shape[1])) | ||
| 6 | +net_output = np.concatenate((zero_pad, net_output, zero_pad), axis=0) | ||
| 7 | +windows = [] | ||
| 8 | +for window_index in range(0, net_output.shape[0] - win_size, 2): | ||
| 9 | + windows.append(net_output[window_index:window_index + win_size]) | ||
| 10 | +print(np.array(windows).shape) | ||
| 11 | +np.save('aud_french.npy', np.array(windows)) |
ernerf/data_utils/face_parsing/logger.py
0 → 100644
| 1 | +#!/usr/bin/python | ||
| 2 | +# -*- encoding: utf-8 -*- | ||
| 3 | + | ||
| 4 | + | ||
| 5 | +import os.path as osp | ||
| 6 | +import time | ||
| 7 | +import sys | ||
| 8 | +import logging | ||
| 9 | + | ||
| 10 | +import torch.distributed as dist | ||
| 11 | + | ||
| 12 | + | ||
| 13 | +def setup_logger(logpth): | ||
| 14 | + logfile = 'BiSeNet-{}.log'.format(time.strftime('%Y-%m-%d-%H-%M-%S')) | ||
| 15 | + logfile = osp.join(logpth, logfile) | ||
| 16 | + FORMAT = '%(levelname)s %(filename)s(%(lineno)d): %(message)s' | ||
| 17 | + log_level = logging.INFO | ||
| 18 | + if dist.is_initialized() and not dist.get_rank()==0: | ||
| 19 | + log_level = logging.ERROR | ||
| 20 | + logging.basicConfig(level=log_level, format=FORMAT, filename=logfile) | ||
| 21 | + logging.root.addHandler(logging.StreamHandler()) | ||
| 22 | + | ||
| 23 | + |
ernerf/data_utils/face_parsing/model.py
0 → 100644
| 1 | +#!/usr/bin/python | ||
| 2 | +# -*- encoding: utf-8 -*- | ||
| 3 | + | ||
| 4 | + | ||
| 5 | +import torch | ||
| 6 | +import torch.nn as nn | ||
| 7 | +import torch.nn.functional as F | ||
| 8 | +import torchvision | ||
| 9 | + | ||
| 10 | +from resnet import Resnet18 | ||
| 11 | +# from modules.bn import InPlaceABNSync as BatchNorm2d | ||
| 12 | + | ||
| 13 | + | ||
| 14 | +class ConvBNReLU(nn.Module): | ||
| 15 | + def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs): | ||
| 16 | + super(ConvBNReLU, self).__init__() | ||
| 17 | + self.conv = nn.Conv2d(in_chan, | ||
| 18 | + out_chan, | ||
| 19 | + kernel_size = ks, | ||
| 20 | + stride = stride, | ||
| 21 | + padding = padding, | ||
| 22 | + bias = False) | ||
| 23 | + self.bn = nn.BatchNorm2d(out_chan) | ||
| 24 | + self.init_weight() | ||
| 25 | + | ||
| 26 | + def forward(self, x): | ||
| 27 | + x = self.conv(x) | ||
| 28 | + x = F.relu(self.bn(x)) | ||
| 29 | + return x | ||
| 30 | + | ||
| 31 | + def init_weight(self): | ||
| 32 | + for ly in self.children(): | ||
| 33 | + if isinstance(ly, nn.Conv2d): | ||
| 34 | + nn.init.kaiming_normal_(ly.weight, a=1) | ||
| 35 | + if not ly.bias is None: nn.init.constant_(ly.bias, 0) | ||
| 36 | + | ||
| 37 | +class BiSeNetOutput(nn.Module): | ||
| 38 | + def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs): | ||
| 39 | + super(BiSeNetOutput, self).__init__() | ||
| 40 | + self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1) | ||
| 41 | + self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False) | ||
| 42 | + self.init_weight() | ||
| 43 | + | ||
| 44 | + def forward(self, x): | ||
| 45 | + x = self.conv(x) | ||
| 46 | + x = self.conv_out(x) | ||
| 47 | + return x | ||
| 48 | + | ||
| 49 | + def init_weight(self): | ||
| 50 | + for ly in self.children(): | ||
| 51 | + if isinstance(ly, nn.Conv2d): | ||
| 52 | + nn.init.kaiming_normal_(ly.weight, a=1) | ||
| 53 | + if not ly.bias is None: nn.init.constant_(ly.bias, 0) | ||
| 54 | + | ||
| 55 | + def get_params(self): | ||
| 56 | + wd_params, nowd_params = [], [] | ||
| 57 | + for name, module in self.named_modules(): | ||
| 58 | + if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): | ||
| 59 | + wd_params.append(module.weight) | ||
| 60 | + if not module.bias is None: | ||
| 61 | + nowd_params.append(module.bias) | ||
| 62 | + elif isinstance(module, nn.BatchNorm2d): | ||
| 63 | + nowd_params += list(module.parameters()) | ||
| 64 | + return wd_params, nowd_params | ||
| 65 | + | ||
| 66 | + | ||
| 67 | +class AttentionRefinementModule(nn.Module): | ||
| 68 | + def __init__(self, in_chan, out_chan, *args, **kwargs): | ||
| 69 | + super(AttentionRefinementModule, self).__init__() | ||
| 70 | + self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1) | ||
| 71 | + self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size= 1, bias=False) | ||
| 72 | + self.bn_atten = nn.BatchNorm2d(out_chan) | ||
| 73 | + self.sigmoid_atten = nn.Sigmoid() | ||
| 74 | + self.init_weight() | ||
| 75 | + | ||
| 76 | + def forward(self, x): | ||
| 77 | + feat = self.conv(x) | ||
| 78 | + atten = F.avg_pool2d(feat, feat.size()[2:]) | ||
| 79 | + atten = self.conv_atten(atten) | ||
| 80 | + atten = self.bn_atten(atten) | ||
| 81 | + atten = self.sigmoid_atten(atten) | ||
| 82 | + out = torch.mul(feat, atten) | ||
| 83 | + return out | ||
| 84 | + | ||
| 85 | + def init_weight(self): | ||
| 86 | + for ly in self.children(): | ||
| 87 | + if isinstance(ly, nn.Conv2d): | ||
| 88 | + nn.init.kaiming_normal_(ly.weight, a=1) | ||
| 89 | + if not ly.bias is None: nn.init.constant_(ly.bias, 0) | ||
| 90 | + | ||
| 91 | + | ||
| 92 | +class ContextPath(nn.Module): | ||
| 93 | + def __init__(self, *args, **kwargs): | ||
| 94 | + super(ContextPath, self).__init__() | ||
| 95 | + self.resnet = Resnet18() | ||
| 96 | + self.arm16 = AttentionRefinementModule(256, 128) | ||
| 97 | + self.arm32 = AttentionRefinementModule(512, 128) | ||
| 98 | + self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1) | ||
| 99 | + self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1) | ||
| 100 | + self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0) | ||
| 101 | + | ||
| 102 | + self.init_weight() | ||
| 103 | + | ||
| 104 | + def forward(self, x): | ||
| 105 | + H0, W0 = x.size()[2:] | ||
| 106 | + feat8, feat16, feat32 = self.resnet(x) | ||
| 107 | + H8, W8 = feat8.size()[2:] | ||
| 108 | + H16, W16 = feat16.size()[2:] | ||
| 109 | + H32, W32 = feat32.size()[2:] | ||
| 110 | + | ||
| 111 | + avg = F.avg_pool2d(feat32, feat32.size()[2:]) | ||
| 112 | + avg = self.conv_avg(avg) | ||
| 113 | + avg_up = F.interpolate(avg, (H32, W32), mode='nearest') | ||
| 114 | + | ||
| 115 | + feat32_arm = self.arm32(feat32) | ||
| 116 | + feat32_sum = feat32_arm + avg_up | ||
| 117 | + feat32_up = F.interpolate(feat32_sum, (H16, W16), mode='nearest') | ||
| 118 | + feat32_up = self.conv_head32(feat32_up) | ||
| 119 | + | ||
| 120 | + feat16_arm = self.arm16(feat16) | ||
| 121 | + feat16_sum = feat16_arm + feat32_up | ||
| 122 | + feat16_up = F.interpolate(feat16_sum, (H8, W8), mode='nearest') | ||
| 123 | + feat16_up = self.conv_head16(feat16_up) | ||
| 124 | + | ||
| 125 | + return feat8, feat16_up, feat32_up # x8, x8, x16 | ||
| 126 | + | ||
| 127 | + def init_weight(self): | ||
| 128 | + for ly in self.children(): | ||
| 129 | + if isinstance(ly, nn.Conv2d): | ||
| 130 | + nn.init.kaiming_normal_(ly.weight, a=1) | ||
| 131 | + if not ly.bias is None: nn.init.constant_(ly.bias, 0) | ||
| 132 | + | ||
| 133 | + def get_params(self): | ||
| 134 | + wd_params, nowd_params = [], [] | ||
| 135 | + for name, module in self.named_modules(): | ||
| 136 | + if isinstance(module, (nn.Linear, nn.Conv2d)): | ||
| 137 | + wd_params.append(module.weight) | ||
| 138 | + if not module.bias is None: | ||
| 139 | + nowd_params.append(module.bias) | ||
| 140 | + elif isinstance(module, nn.BatchNorm2d): | ||
| 141 | + nowd_params += list(module.parameters()) | ||
| 142 | + return wd_params, nowd_params | ||
| 143 | + | ||
| 144 | + | ||
| 145 | +### Not used: replaced by the ResNet feature of the same size | ||
| 146 | +class SpatialPath(nn.Module): | ||
| 147 | + def __init__(self, *args, **kwargs): | ||
| 148 | + super(SpatialPath, self).__init__() | ||
| 149 | + self.conv1 = ConvBNReLU(3, 64, ks=7, stride=2, padding=3) | ||
| 150 | + self.conv2 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1) | ||
| 151 | + self.conv3 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1) | ||
| 152 | + self.conv_out = ConvBNReLU(64, 128, ks=1, stride=1, padding=0) | ||
| 153 | + self.init_weight() | ||
| 154 | + | ||
| 155 | + def forward(self, x): | ||
| 156 | + feat = self.conv1(x) | ||
| 157 | + feat = self.conv2(feat) | ||
| 158 | + feat = self.conv3(feat) | ||
| 159 | + feat = self.conv_out(feat) | ||
| 160 | + return feat | ||
| 161 | + | ||
| 162 | + def init_weight(self): | ||
| 163 | + for ly in self.children(): | ||
| 164 | + if isinstance(ly, nn.Conv2d): | ||
| 165 | + nn.init.kaiming_normal_(ly.weight, a=1) | ||
| 166 | + if not ly.bias is None: nn.init.constant_(ly.bias, 0) | ||
| 167 | + | ||
| 168 | + def get_params(self): | ||
| 169 | + wd_params, nowd_params = [], [] | ||
| 170 | + for name, module in self.named_modules(): | ||
| 171 | + if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): | ||
| 172 | + wd_params.append(module.weight) | ||
| 173 | + if not module.bias is None: | ||
| 174 | + nowd_params.append(module.bias) | ||
| 175 | + elif isinstance(module, nn.BatchNorm2d): | ||
| 176 | + nowd_params += list(module.parameters()) | ||
| 177 | + return wd_params, nowd_params | ||
| 178 | + | ||
| 179 | + | ||
| 180 | +class FeatureFusionModule(nn.Module): | ||
| 181 | + def __init__(self, in_chan, out_chan, *args, **kwargs): | ||
| 182 | + super(FeatureFusionModule, self).__init__() | ||
| 183 | + self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0) | ||
| 184 | + self.conv1 = nn.Conv2d(out_chan, | ||
| 185 | + out_chan//4, | ||
| 186 | + kernel_size = 1, | ||
| 187 | + stride = 1, | ||
| 188 | + padding = 0, | ||
| 189 | + bias = False) | ||
| 190 | + self.conv2 = nn.Conv2d(out_chan//4, | ||
| 191 | + out_chan, | ||
| 192 | + kernel_size = 1, | ||
| 193 | + stride = 1, | ||
| 194 | + padding = 0, | ||
| 195 | + bias = False) | ||
| 196 | + self.relu = nn.ReLU(inplace=True) | ||
| 197 | + self.sigmoid = nn.Sigmoid() | ||
| 198 | + self.init_weight() | ||
| 199 | + | ||
| 200 | + def forward(self, fsp, fcp): | ||
| 201 | + fcat = torch.cat([fsp, fcp], dim=1) | ||
| 202 | + feat = self.convblk(fcat) | ||
| 203 | + atten = F.avg_pool2d(feat, feat.size()[2:]) | ||
| 204 | + atten = self.conv1(atten) | ||
| 205 | + atten = self.relu(atten) | ||
| 206 | + atten = self.conv2(atten) | ||
| 207 | + atten = self.sigmoid(atten) | ||
| 208 | + feat_atten = torch.mul(feat, atten) | ||
| 209 | + feat_out = feat_atten + feat | ||
| 210 | + return feat_out | ||
| 211 | + | ||
| 212 | + def init_weight(self): | ||
| 213 | + for ly in self.children(): | ||
| 214 | + if isinstance(ly, nn.Conv2d): | ||
| 215 | + nn.init.kaiming_normal_(ly.weight, a=1) | ||
| 216 | + if not ly.bias is None: nn.init.constant_(ly.bias, 0) | ||
| 217 | + | ||
| 218 | + def get_params(self): | ||
| 219 | + wd_params, nowd_params = [], [] | ||
| 220 | + for name, module in self.named_modules(): | ||
| 221 | + if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): | ||
| 222 | + wd_params.append(module.weight) | ||
| 223 | + if not module.bias is None: | ||
| 224 | + nowd_params.append(module.bias) | ||
| 225 | + elif isinstance(module, nn.BatchNorm2d): | ||
| 226 | + nowd_params += list(module.parameters()) | ||
| 227 | + return wd_params, nowd_params | ||
| 228 | + | ||
| 229 | + | ||
| 230 | +class BiSeNet(nn.Module): | ||
| 231 | + def __init__(self, n_classes, *args, **kwargs): | ||
| 232 | + super(BiSeNet, self).__init__() | ||
| 233 | + self.cp = ContextPath() | ||
| 234 | + ## here self.sp is deleted | ||
| 235 | + self.ffm = FeatureFusionModule(256, 256) | ||
| 236 | + self.conv_out = BiSeNetOutput(256, 256, n_classes) | ||
| 237 | + self.conv_out16 = BiSeNetOutput(128, 64, n_classes) | ||
| 238 | + self.conv_out32 = BiSeNetOutput(128, 64, n_classes) | ||
| 239 | + self.init_weight() | ||
| 240 | + | ||
| 241 | + def forward(self, x): | ||
| 242 | + H, W = x.size()[2:] | ||
| 243 | + feat_res8, feat_cp8, feat_cp16 = self.cp(x) # here return res3b1 feature | ||
| 244 | + feat_sp = feat_res8 # use res3b1 feature to replace spatial path feature | ||
| 245 | + feat_fuse = self.ffm(feat_sp, feat_cp8) | ||
| 246 | + | ||
| 247 | + feat_out = self.conv_out(feat_fuse) | ||
| 248 | + feat_out16 = self.conv_out16(feat_cp8) | ||
| 249 | + feat_out32 = self.conv_out32(feat_cp16) | ||
| 250 | + | ||
| 251 | + feat_out = F.interpolate(feat_out, (H, W), mode='bilinear', align_corners=True) | ||
| 252 | + feat_out16 = F.interpolate(feat_out16, (H, W), mode='bilinear', align_corners=True) | ||
| 253 | + feat_out32 = F.interpolate(feat_out32, (H, W), mode='bilinear', align_corners=True) | ||
| 254 | + | ||
| 255 | + # return feat_out, feat_out16, feat_out32 | ||
| 256 | + return feat_out | ||
| 257 | + | ||
| 258 | + def init_weight(self): | ||
| 259 | + for ly in self.children(): | ||
| 260 | + if isinstance(ly, nn.Conv2d): | ||
| 261 | + nn.init.kaiming_normal_(ly.weight, a=1) | ||
| 262 | + if not ly.bias is None: nn.init.constant_(ly.bias, 0) | ||
| 263 | + | ||
| 264 | + def get_params(self): | ||
| 265 | + wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], [] | ||
| 266 | + for name, child in self.named_children(): | ||
| 267 | + child_wd_params, child_nowd_params = child.get_params() | ||
| 268 | + if isinstance(child, FeatureFusionModule) or isinstance(child, BiSeNetOutput): | ||
| 269 | + lr_mul_wd_params += child_wd_params | ||
| 270 | + lr_mul_nowd_params += child_nowd_params | ||
| 271 | + else: | ||
| 272 | + wd_params += child_wd_params | ||
| 273 | + nowd_params += child_nowd_params | ||
| 274 | + return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params | ||
| 275 | + | ||
| 276 | + | ||
| 277 | +if __name__ == "__main__": | ||
| 278 | + net = BiSeNet(19) | ||
| 279 | + net.cuda() | ||
| 280 | + net.eval() | ||
| 281 | + in_ten = torch.randn(16, 3, 640, 480).cuda() | ||
| 282 | + out = net(in_ten) # forward() returns a single tensor (feat_out) | ||
| 283 | + print(out.shape) | ||
| 284 | + | ||
| 285 | + net.get_params() |
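A minimal inference sketch for `BiSeNet`, assuming a trained checkpoint such as the `model_final_diss.pth` referenced in `test.py`: `forward()` returns per-class logits at input resolution, and an `argmax` over the class dimension yields the parsing map.

```python
import torch
from model import BiSeNet  # assumes this file's module name

net = BiSeNet(n_classes=19).cuda().eval()
net.load_state_dict(torch.load('model_final_diss.pth'))  # checkpoint path assumed
with torch.no_grad():
    logits = net(torch.randn(1, 3, 512, 512).cuda())  # (1, 19, 512, 512)
    labels = logits.argmax(1)                          # (1, 512, 512) per-pixel class ids
```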
ernerf/data_utils/face_parsing/resnet.py
0 → 100644
| 1 | +#!/usr/bin/python | ||
| 2 | +# -*- encoding: utf-8 -*- | ||
| 3 | + | ||
| 4 | +import torch | ||
| 5 | +import torch.nn as nn | ||
| 6 | +import torch.nn.functional as F | ||
| 7 | +import torch.utils.model_zoo as modelzoo | ||
| 8 | + | ||
| 9 | +# from modules.bn import InPlaceABNSync as BatchNorm2d | ||
| 10 | + | ||
| 11 | +resnet18_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth' | ||
| 12 | + | ||
| 13 | + | ||
| 14 | +def conv3x3(in_planes, out_planes, stride=1): | ||
| 15 | + """3x3 convolution with padding""" | ||
| 16 | + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, | ||
| 17 | + padding=1, bias=False) | ||
| 18 | + | ||
| 19 | + | ||
| 20 | +class BasicBlock(nn.Module): | ||
| 21 | + def __init__(self, in_chan, out_chan, stride=1): | ||
| 22 | + super(BasicBlock, self).__init__() | ||
| 23 | + self.conv1 = conv3x3(in_chan, out_chan, stride) | ||
| 24 | + self.bn1 = nn.BatchNorm2d(out_chan) | ||
| 25 | + self.conv2 = conv3x3(out_chan, out_chan) | ||
| 26 | + self.bn2 = nn.BatchNorm2d(out_chan) | ||
| 27 | + self.relu = nn.ReLU(inplace=True) | ||
| 28 | + self.downsample = None | ||
| 29 | + if in_chan != out_chan or stride != 1: | ||
| 30 | + self.downsample = nn.Sequential( | ||
| 31 | + nn.Conv2d(in_chan, out_chan, | ||
| 32 | + kernel_size=1, stride=stride, bias=False), | ||
| 33 | + nn.BatchNorm2d(out_chan), | ||
| 34 | + ) | ||
| 35 | + | ||
| 36 | + def forward(self, x): | ||
| 37 | + residual = self.conv1(x) | ||
| 38 | + residual = F.relu(self.bn1(residual)) | ||
| 39 | + residual = self.conv2(residual) | ||
| 40 | + residual = self.bn2(residual) | ||
| 41 | + | ||
| 42 | + shortcut = x | ||
| 43 | + if self.downsample is not None: | ||
| 44 | + shortcut = self.downsample(x) | ||
| 45 | + | ||
| 46 | + out = shortcut + residual | ||
| 47 | + out = self.relu(out) | ||
| 48 | + return out | ||
| 49 | + | ||
| 50 | + | ||
| 51 | +def create_layer_basic(in_chan, out_chan, bnum, stride=1): | ||
| 52 | + layers = [BasicBlock(in_chan, out_chan, stride=stride)] | ||
| 53 | + for i in range(bnum-1): | ||
| 54 | + layers.append(BasicBlock(out_chan, out_chan, stride=1)) | ||
| 55 | + return nn.Sequential(*layers) | ||
| 56 | + | ||
| 57 | + | ||
| 58 | +class Resnet18(nn.Module): | ||
| 59 | + def __init__(self): | ||
| 60 | + super(Resnet18, self).__init__() | ||
| 61 | + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, | ||
| 62 | + bias=False) | ||
| 63 | + self.bn1 = nn.BatchNorm2d(64) | ||
| 64 | + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) | ||
| 65 | + self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1) | ||
| 66 | + self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2) | ||
| 67 | + self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2) | ||
| 68 | + self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2) | ||
| 69 | + self.init_weight() | ||
| 70 | + | ||
| 71 | + def forward(self, x): | ||
| 72 | + x = self.conv1(x) | ||
| 73 | + x = F.relu(self.bn1(x)) | ||
| 74 | + x = self.maxpool(x) | ||
| 75 | + | ||
| 76 | + x = self.layer1(x) | ||
| 77 | + feat8 = self.layer2(x) # 1/8 | ||
| 78 | + feat16 = self.layer3(feat8) # 1/16 | ||
| 79 | + feat32 = self.layer4(feat16) # 1/32 | ||
| 80 | + return feat8, feat16, feat32 | ||
| 81 | + | ||
| 82 | + def init_weight(self): | ||
| 83 | + state_dict = modelzoo.load_url(resnet18_url) | ||
| 84 | + self_state_dict = self.state_dict() | ||
| 85 | + for k, v in state_dict.items(): | ||
| 86 | + if 'fc' in k: continue | ||
| 87 | + self_state_dict.update({k: v}) | ||
| 88 | + self.load_state_dict(self_state_dict) | ||
| 89 | + | ||
| 90 | + def get_params(self): | ||
| 91 | + wd_params, nowd_params = [], [] | ||
| 92 | + for name, module in self.named_modules(): | ||
| 93 | + if isinstance(module, (nn.Linear, nn.Conv2d)): | ||
| 94 | + wd_params.append(module.weight) | ||
| 95 | + if not module.bias is None: | ||
| 96 | + nowd_params.append(module.bias) | ||
| 97 | + elif isinstance(module, nn.BatchNorm2d): | ||
| 98 | + nowd_params += list(module.parameters()) | ||
| 99 | + return wd_params, nowd_params | ||
| 100 | + | ||
| 101 | + | ||
| 102 | +if __name__ == "__main__": | ||
| 103 | + net = Resnet18() | ||
| 104 | + x = torch.randn(16, 3, 224, 224) | ||
| 105 | + out = net(x) | ||
| 106 | + print(out[0].size()) | ||
| 107 | + print(out[1].size()) | ||
| 108 | + print(out[2].size()) | ||
| 109 | + net.get_params() |
ernerf/data_utils/face_parsing/test.py
0 → 100644
| 1 | +#!/usr/bin/python | ||
| 2 | +# -*- encoding: utf-8 -*- | ||
| 3 | +import numpy as np | ||
| 4 | +from model import BiSeNet | ||
| 5 | + | ||
| 6 | +import torch | ||
| 7 | + | ||
| 8 | +import os | ||
| 9 | +import os.path as osp | ||
| 10 | + | ||
| 11 | +from PIL import Image | ||
| 12 | +import torchvision.transforms as transforms | ||
| 13 | +import cv2 | ||
| 14 | +from pathlib import Path | ||
| 15 | +import configargparse | ||
| 16 | +import tqdm | ||
| 17 | + | ||
| 18 | +# import ttach as tta | ||
| 19 | + | ||
| 20 | +def vis_parsing_maps(im, parsing_anno, stride, save_im=False, save_path='vis_results/parsing_map_on_im.jpg', | ||
| 21 | + img_size=(512, 512)): | ||
| 22 | + im = np.array(im) | ||
| 23 | + vis_im = im.copy().astype(np.uint8) | ||
| 24 | + vis_parsing_anno = parsing_anno.copy().astype(np.uint8) | ||
| 25 | + vis_parsing_anno = cv2.resize( | ||
| 26 | + vis_parsing_anno, None, fx=stride, fy=stride, interpolation=cv2.INTER_NEAREST) | ||
| 27 | + vis_parsing_anno_color = np.zeros( | ||
| 28 | + (vis_parsing_anno.shape[0], vis_parsing_anno.shape[1], 3)) + np.array([255, 255, 255]) # + 255 | ||
| 29 | + | ||
| 30 | + num_of_class = np.max(vis_parsing_anno) | ||
| 31 | + # print(num_of_class) | ||
| 32 | + for pi in range(1, 14): | ||
| 33 | + index = np.where(vis_parsing_anno == pi) | ||
| 34 | + vis_parsing_anno_color[index[0], index[1], :] = np.array([255, 0, 0]) | ||
| 35 | + | ||
| 36 | + for pi in range(14, 16): | ||
| 37 | + index = np.where(vis_parsing_anno == pi) | ||
| 38 | + vis_parsing_anno_color[index[0], index[1], :] = np.array([0, 255, 0]) | ||
| 39 | + for pi in range(16, 17): | ||
| 40 | + index = np.where(vis_parsing_anno == pi) | ||
| 41 | + vis_parsing_anno_color[index[0], index[1], :] = np.array([0, 0, 255]) | ||
| 42 | + for pi in range(17, num_of_class+1): | ||
| 43 | + index = np.where(vis_parsing_anno == pi) | ||
| 44 | + vis_parsing_anno_color[index[0], index[1], :] = np.array([255, 0, 0]) | ||
| 45 | + | ||
| 46 | + vis_parsing_anno_color = vis_parsing_anno_color.astype(np.uint8) | ||
| 47 | + index = np.where(vis_parsing_anno == num_of_class-1) | ||
| 48 | + vis_im = cv2.resize(vis_parsing_anno_color, img_size, | ||
| 49 | + interpolation=cv2.INTER_NEAREST) | ||
| 50 | + if save_im: | ||
| 51 | + cv2.imwrite(save_path, vis_im) | ||
| 52 | + | ||
| 53 | + | ||
| 54 | +def evaluate(respth='./res/test_res', dspth='./data', cp='model_final_diss.pth'): | ||
| 55 | + | ||
| 56 | + Path(respth).mkdir(parents=True, exist_ok=True) | ||
| 57 | + | ||
| 58 | + print(f'[INFO] loading model...') | ||
| 59 | + n_classes = 19 | ||
| 60 | + net = BiSeNet(n_classes=n_classes) | ||
| 61 | + net.cuda() | ||
| 62 | + net.load_state_dict(torch.load(cp)) | ||
| 63 | + net.eval() | ||
| 64 | + | ||
| 65 | + to_tensor = transforms.Compose([ | ||
| 66 | + transforms.ToTensor(), | ||
| 67 | + transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), | ||
| 68 | + ]) | ||
| 69 | + | ||
| 70 | + image_paths = os.listdir(dspth) | ||
| 71 | + | ||
| 72 | + with torch.no_grad(): | ||
| 73 | + for image_path in tqdm.tqdm(image_paths): | ||
| 74 | + if image_path.endswith(('.jpg', '.png')): | ||
| 75 | + img = Image.open(osp.join(dspth, image_path)) | ||
| 76 | + ori_size = img.size | ||
| 77 | + image = img.resize((512, 512), Image.BILINEAR) | ||
| 78 | + image = image.convert("RGB") | ||
| 79 | + img = to_tensor(image) | ||
| 80 | + | ||
| 81 | + # single forward pass (test-time augmentation via ttach is left disabled above) | ||
| 82 | + inputs = torch.unsqueeze(img, 0) # [1, 3, 512, 512] | ||
| 83 | + outputs = net(inputs.cuda()) | ||
| 84 | + parsing = outputs.mean(0).cpu().numpy().argmax(0) | ||
| 85 | + | ||
| 86 | + image_path = int(image_path[:-4]) | ||
| 87 | + image_path = str(image_path) + '.png' | ||
| 88 | + | ||
| 89 | + vis_parsing_maps(image, parsing, stride=1, save_im=True, save_path=osp.join(respth, image_path), img_size=ori_size) | ||
| 90 | + | ||
| 91 | + | ||
| 92 | +if __name__ == "__main__": | ||
| 93 | + parser = configargparse.ArgumentParser() | ||
| 94 | + parser.add_argument('--respath', type=str, default='./result/', help='result path for label') | ||
| 95 | + parser.add_argument('--imgpath', type=str, default='./imgs/', help='path for input images') | ||
| 96 | + parser.add_argument('--modelpath', type=str, default='data_utils/face_parsing/79999_iter.pth') | ||
| 97 | + args = parser.parse_args() | ||
| 98 | + evaluate(respth=args.respath, dspth=args.imgpath, cp=args.modelpath) |
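The flags here match how `data_utils/process.py` later in this diff shells out to this script; a hedged invocation sketch (the `data/obama` layout is illustrative):

```python
import os

# Mirrors extract_semantics() in data_utils/process.py; paths are illustrative.
cmd = ('python data_utils/face_parsing/test.py '
       '--imgpath=data/obama/ori_imgs '
       '--respath=data/obama/parsing '
       '--modelpath=data_utils/face_parsing/79999_iter.pth')
os.system(cmd)
```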
ernerf/data_utils/face_tracking/__init__.py
0 → 100644
| 1 | +import numpy as np | ||
| 2 | +from scipy.io import loadmat | ||
| 3 | + | ||
| 4 | +original_BFM = loadmat("3DMM/01_MorphableModel.mat") | ||
| 5 | +sub_inds = np.load("3DMM/topology_info.npy", allow_pickle=True).item()["sub_inds"] | ||
| 6 | + | ||
| 7 | +shapePC = original_BFM["shapePC"] | ||
| 8 | +shapeEV = original_BFM["shapeEV"] | ||
| 9 | +shapeMU = original_BFM["shapeMU"] | ||
| 10 | +texPC = original_BFM["texPC"] | ||
| 11 | +texEV = original_BFM["texEV"] | ||
| 12 | +texMU = original_BFM["texMU"] | ||
| 13 | + | ||
| 14 | +b_shape = shapePC.reshape(-1, 199).transpose(1, 0).reshape(199, -1, 3) | ||
| 15 | +mu_shape = shapeMU.reshape(-1, 3) | ||
| 16 | + | ||
| 17 | +b_tex = texPC.reshape(-1, 199).transpose(1, 0).reshape(199, -1, 3) | ||
| 18 | +mu_tex = texMU.reshape(-1, 3) | ||
| 19 | + | ||
| 20 | +b_shape = b_shape[:, sub_inds, :].reshape(199, -1) | ||
| 21 | +mu_shape = mu_shape[sub_inds, :].reshape(-1) | ||
| 22 | +b_tex = b_tex[:, sub_inds, :].reshape(199, -1) | ||
| 23 | +mu_tex = mu_tex[sub_inds, :].reshape(-1) | ||
| 24 | + | ||
| 25 | +exp_info = np.load("3DMM/exp_info.npy", allow_pickle=True).item() | ||
| 26 | +np.save( | ||
| 27 | + "3DMM/3DMM_info.npy", | ||
| 28 | + { | ||
| 29 | + "mu_shape": mu_shape, | ||
| 30 | + "b_shape": b_shape, | ||
| 31 | + "sig_shape": shapeEV.reshape(-1), | ||
| 32 | + "mu_exp": exp_info["mu_exp"], | ||
| 33 | + "b_exp": exp_info["base_exp"], | ||
| 34 | + "sig_exp": exp_info["sig_exp"], | ||
| 35 | + "mu_tex": mu_tex, | ||
| 36 | + "b_tex": b_tex, | ||
| 37 | + "sig_tex": texEV.reshape(-1), | ||
| 38 | + }, | ||
| 39 | +) |
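The reshapes above move the 199 PCA components to the leading axis before restricting the bases to the tracked vertex subset. A shape-bookkeeping sketch with dummy arrays (53490 is assumed here as the full BFM vertex count; 34650 matches `point_num` in the tracker below):

```python
import numpy as np

n_full, n_sub = 53490, 34650                 # illustrative: full mesh vs. tracked subset
shapePC = np.zeros((3 * n_full, 199))        # stand-in for original_BFM["shapePC"]
sub_inds = np.arange(n_sub)                  # stand-in for the topology subset

b_shape = shapePC.reshape(-1, 199).transpose(1, 0).reshape(199, -1, 3)  # (199, n_full, 3)
b_shape = b_shape[:, sub_inds, :].reshape(199, -1)                      # (199, 3 * n_sub)
assert b_shape.shape == (199, 3 * n_sub)
```

ernerf/data_utils/face_tracking/data_loader.py
0 → 100644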
| 1 | +import os | ||
| 2 | +import torch | ||
| 3 | +import numpy as np | ||
| 4 | + | ||
| 5 | + | ||
| 6 | +def load_dir(path, start, end): | ||
| 7 | + lmss = [] | ||
| 8 | + imgs_paths = [] | ||
| 9 | + for i in range(start, end): | ||
| 10 | + if os.path.isfile(os.path.join(path, str(i) + ".lms")): | ||
| 11 | + lms = np.loadtxt(os.path.join(path, str(i) + ".lms"), dtype=np.float32) | ||
| 12 | + lmss.append(lms) | ||
| 13 | + imgs_paths.append(os.path.join(path, str(i) + ".jpg")) | ||
| 14 | + lmss = np.stack(lmss) | ||
| 15 | + lmss = torch.as_tensor(lmss).cuda() | ||
| 16 | + return lmss, imgs_paths |
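`load_dir` skips frames whose `.lms` file is missing, so the returned landmark tensor and path list can be shorter than `end - start`. A hedged usage sketch (paths are illustrative; requires a CUDA device since the landmarks are moved to the GPU):

```python
from data_loader import load_dir

lms, img_paths = load_dir("obama/ori_imgs", 0, 11000)
print(lms.shape)       # e.g. torch.Size([N, 68, 2]) -- 68-point landmarks on the GPU
print(img_paths[0])    # "obama/ori_imgs/0.jpg"
```

ernerf/data_utils/face_tracking/face_tracker.py
0 → 100644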
| 1 | +import os | ||
| 2 | +import sys | ||
| 3 | +import cv2 | ||
| 4 | +import argparse | ||
| 5 | +from pathlib import Path | ||
| 6 | +import torch | ||
| 7 | +import numpy as np | ||
| 8 | +from data_loader import load_dir | ||
| 9 | +from facemodel import Face_3DMM | ||
| 10 | +from util import * | ||
| 11 | +from render_3dmm import Render_3DMM | ||
| 12 | + | ||
| 13 | + | ||
| 14 | +# torch.autograd.set_detect_anomaly(True) | ||
| 15 | + | ||
| 16 | +dir_path = os.path.dirname(os.path.realpath(__file__)) | ||
| 17 | + | ||
| 18 | + | ||
| 19 | +def set_requires_grad(tensor_list): | ||
| 20 | + for tensor in tensor_list: | ||
| 21 | + tensor.requires_grad = True | ||
| 22 | + | ||
| 23 | + | ||
| 24 | +parser = argparse.ArgumentParser() | ||
| 25 | +parser.add_argument( | ||
| 26 | + "--path", type=str, default="obama/ori_imgs", help="idname of target person" | ||
| 27 | +) | ||
| 28 | +parser.add_argument("--img_h", type=int, default=512, help="image height") | ||
| 29 | +parser.add_argument("--img_w", type=int, default=512, help="image width") | ||
| 30 | +parser.add_argument("--frame_num", type=int, default=11000, help="image number") | ||
| 31 | +args = parser.parse_args() | ||
| 32 | + | ||
| 33 | +start_id = 0 | ||
| 34 | +end_id = args.frame_num | ||
| 35 | + | ||
| 36 | +lms, img_paths = load_dir(args.path, start_id, end_id) | ||
| 37 | +num_frames = lms.shape[0] | ||
| 38 | +h, w = args.img_h, args.img_w | ||
| 39 | +cxy = torch.tensor((w / 2.0, h / 2.0), dtype=torch.float).cuda() | ||
| 40 | +id_dim, exp_dim, tex_dim, point_num = 100, 79, 100, 34650 | ||
| 41 | +model_3dmm = Face_3DMM( | ||
| 42 | + os.path.join(dir_path, "3DMM"), id_dim, exp_dim, tex_dim, point_num | ||
| 43 | +) | ||
| 44 | + | ||
| 45 | +# use only one frame out of every 40 to fit the focal length | ||
| 46 | +sel_ids = np.arange(0, num_frames, 40) | ||
| 47 | +sel_num = sel_ids.shape[0] | ||
| 48 | +arg_focal = 1600 | ||
| 49 | +arg_landis = 1e5 | ||
| 50 | + | ||
| 51 | +print(f'[INFO] fitting focal length...') | ||
| 52 | + | ||
| 53 | +# fit the focal length | ||
| 54 | +for focal in range(600, 1500, 100): | ||
| 55 | + id_para = lms.new_zeros((1, id_dim), requires_grad=True) | ||
| 56 | + exp_para = lms.new_zeros((sel_num, exp_dim), requires_grad=True) | ||
| 57 | + euler_angle = lms.new_zeros((sel_num, 3), requires_grad=True) | ||
| 58 | + trans = lms.new_zeros((sel_num, 3), requires_grad=True) | ||
| 59 | + trans.data[:, 2] -= 7 | ||
| 60 | + focal_length = lms.new_zeros(1, requires_grad=False) | ||
| 61 | + focal_length.data += focal | ||
| 62 | + set_requires_grad([id_para, exp_para, euler_angle, trans]) | ||
| 63 | + | ||
| 64 | + optimizer_idexp = torch.optim.Adam([id_para, exp_para], lr=0.1) | ||
| 65 | + optimizer_frame = torch.optim.Adam([euler_angle, trans], lr=0.1) | ||
| 66 | + | ||
| 67 | + for iter in range(2000): | ||
| 68 | + id_para_batch = id_para.expand(sel_num, -1) | ||
| 69 | + geometry = model_3dmm.get_3dlandmarks( | ||
| 70 | + id_para_batch, exp_para, euler_angle, trans, focal_length, cxy | ||
| 71 | + ) | ||
| 72 | + proj_geo = forward_transform(geometry, euler_angle, trans, focal_length, cxy) | ||
| 73 | + loss_lan = cal_lan_loss(proj_geo[:, :, :2], lms[sel_ids].detach()) | ||
| 74 | + loss = loss_lan | ||
| 75 | + optimizer_frame.zero_grad() | ||
| 76 | + loss.backward() | ||
| 77 | + optimizer_frame.step() | ||
| 78 | + # if iter % 100 == 0: | ||
| 79 | + # print(focal, 'pose', iter, loss.item()) | ||
| 80 | + | ||
| 81 | + for iter in range(2500): | ||
| 82 | + id_para_batch = id_para.expand(sel_num, -1) | ||
| 83 | + geometry = model_3dmm.get_3dlandmarks( | ||
| 84 | + id_para_batch, exp_para, euler_angle, trans, focal_length, cxy | ||
| 85 | + ) | ||
| 86 | + proj_geo = forward_transform(geometry, euler_angle, trans, focal_length, cxy) | ||
| 87 | + loss_lan = cal_lan_loss(proj_geo[:, :, :2], lms[sel_ids].detach()) | ||
| 88 | + loss_regid = torch.mean(id_para * id_para) | ||
| 89 | + loss_regexp = torch.mean(exp_para * exp_para) | ||
| 90 | + loss = loss_lan + loss_regid * 0.5 + loss_regexp * 0.4 | ||
| 91 | + optimizer_idexp.zero_grad() | ||
| 92 | + optimizer_frame.zero_grad() | ||
| 93 | + loss.backward() | ||
| 94 | + optimizer_idexp.step() | ||
| 95 | + optimizer_frame.step() | ||
| 96 | + # if iter % 100 == 0: | ||
| 97 | + # print(focal, 'poseidexp', iter, loss_lan.item(), loss_regid.item(), loss_regexp.item()) | ||
| 98 | + | ||
| 99 | + if iter % 1500 == 0 and iter >= 1500: | ||
| 100 | + for param_group in optimizer_idexp.param_groups: | ||
| 101 | + param_group["lr"] *= 0.2 | ||
| 102 | + for param_group in optimizer_frame.param_groups: | ||
| 103 | + param_group["lr"] *= 0.2 | ||
| 104 | + | ||
| 105 | + print(focal, loss_lan.item(), torch.mean(trans[:, 2]).item()) | ||
| 106 | + | ||
| 107 | + if loss_lan.item() < arg_landis: | ||
| 108 | + arg_landis = loss_lan.item() | ||
| 109 | + arg_focal = focal | ||
| 110 | + | ||
| 111 | +print("[INFO] find best focal:", arg_focal) | ||
| 112 | + | ||
| 113 | +print(f'[INFO] coarse fitting...') | ||
| 114 | + | ||
| 115 | +# coarse fitting of pose, identity and expression over all frames | ||
| 116 | +id_para = lms.new_zeros((1, id_dim), requires_grad=True) | ||
| 117 | +exp_para = lms.new_zeros((num_frames, exp_dim), requires_grad=True) | ||
| 118 | +tex_para = lms.new_zeros( | ||
| 119 | + (1, tex_dim), requires_grad=True | ||
| 120 | +)  # texture is only optimized later, in the lighting stage below | ||
| 121 | +euler_angle = lms.new_zeros((num_frames, 3), requires_grad=True) | ||
| 122 | +trans = lms.new_zeros((num_frames, 3), requires_grad=True) | ||
| 123 | +light_para = lms.new_zeros((num_frames, 27), requires_grad=True) | ||
| 124 | +trans.data[:, 2] -= 7  # place the head roughly 7 units in front of the camera | ||
| 125 | +focal_length = lms.new_zeros(1, requires_grad=True) | ||
| 126 | +focal_length.data += arg_focal | ||
| 127 | + | ||
| 128 | +set_requires_grad([id_para, exp_para, tex_para, euler_angle, trans, light_para]) | ||
| 129 | + | ||
| 130 | +optimizer_idexp = torch.optim.Adam([id_para, exp_para], lr=0.1) | ||
| 131 | +optimizer_frame = torch.optim.Adam([euler_angle, trans], lr=1) | ||
| 132 | + | ||
| 133 | +for iter in range(1500): | ||
| 134 | + id_para_batch = id_para.expand(num_frames, -1) | ||
| 135 | + geometry = model_3dmm.get_3dlandmarks( | ||
| 136 | + id_para_batch, exp_para, euler_angle, trans, focal_length, cxy | ||
| 137 | + ) | ||
| 138 | + proj_geo = forward_transform(geometry, euler_angle, trans, focal_length, cxy) | ||
| 139 | + loss_lan = cal_lan_loss(proj_geo[:, :, :2], lms.detach()) | ||
| 140 | + loss = loss_lan | ||
| 141 | + optimizer_frame.zero_grad() | ||
| 142 | + loss.backward() | ||
| 143 | + optimizer_frame.step() | ||
| 144 | + if iter == 1000: | ||
| 145 | + for param_group in optimizer_frame.param_groups: | ||
| 146 | + param_group["lr"] = 0.1 | ||
| 147 | + # if iter % 100 == 0: | ||
| 148 | + # print('pose', iter, loss.item()) | ||
| 149 | + | ||
| 150 | +for param_group in optimizer_frame.param_groups: | ||
| 151 | + param_group["lr"] = 0.1 | ||
| 152 | + | ||
| 153 | +for iter in range(2000): | ||
| 154 | + id_para_batch = id_para.expand(num_frames, -1) | ||
| 155 | + geometry = model_3dmm.get_3dlandmarks( | ||
| 156 | + id_para_batch, exp_para, euler_angle, trans, focal_length, cxy | ||
| 157 | + ) | ||
| 158 | + proj_geo = forward_transform(geometry, euler_angle, trans, focal_length, cxy) | ||
| 159 | + loss_lan = cal_lan_loss(proj_geo[:, :, :2], lms.detach()) | ||
| 160 | + loss_regid = torch.mean(id_para * id_para) | ||
| 161 | + loss_regexp = torch.mean(exp_para * exp_para) | ||
| 162 | + loss = loss_lan + loss_regid * 0.5 + loss_regexp * 0.4 | ||
| 163 | + optimizer_idexp.zero_grad() | ||
| 164 | + optimizer_frame.zero_grad() | ||
| 165 | + loss.backward() | ||
| 166 | + optimizer_idexp.step() | ||
| 167 | + optimizer_frame.step() | ||
| 168 | + # if iter % 100 == 0: | ||
| 169 | + # print('poseidexp', iter, loss_lan.item(), loss_regid.item(), loss_regexp.item()) | ||
| 170 | + if iter % 1000 == 0 and iter >= 1000: | ||
| 171 | + for param_group in optimizer_idexp.param_groups: | ||
| 172 | + param_group["lr"] *= 0.2 | ||
| 173 | + for param_group in optimizer_frame.param_groups: | ||
| 174 | + param_group["lr"] *= 0.2 | ||
| 175 | + | ||
| 176 | +print(loss_lan.item(), torch.mean(trans[:, 2]).item()) | ||
| 177 | + | ||
| 178 | +print(f'[INFO] fitting light...') | ||
| 179 | + | ||
| 180 | +batch_size = 32 | ||
| 181 | + | ||
| 182 | +device_default = torch.device("cuda:0") | ||
| 183 | +device_render = torch.device("cuda:0") | ||
| 184 | +renderer = Render_3DMM(arg_focal, h, w, batch_size, device_render) | ||
| 185 | + | ||
| 186 | +sel_ids = np.arange(0, num_frames, int(num_frames / batch_size))[:batch_size] | ||
| 187 | +imgs = [] | ||
| 188 | +for sel_id in sel_ids: | ||
| 189 | + imgs.append(cv2.imread(img_paths[sel_id])[:, :, ::-1]) | ||
| 190 | +imgs = np.stack(imgs) | ||
| 191 | +sel_imgs = torch.as_tensor(imgs).cuda() | ||
| 192 | +sel_lms = lms[sel_ids] | ||
| 193 | +sel_light = light_para.new_zeros((batch_size, 27), requires_grad=True) | ||
| 194 | +set_requires_grad([sel_light]) | ||
| 195 | + | ||
| 196 | +optimizer_tl = torch.optim.Adam([tex_para, sel_light], lr=0.1) | ||
| 197 | +optimizer_id_frame = torch.optim.Adam([euler_angle, trans, exp_para, id_para], lr=0.01) | ||
| 198 | + | ||
| 199 | +for iter in range(71): | ||
| 200 | + sel_exp_para, sel_euler, sel_trans = ( | ||
| 201 | + exp_para[sel_ids], | ||
| 202 | + euler_angle[sel_ids], | ||
| 203 | + trans[sel_ids], | ||
| 204 | + ) | ||
| 205 | + sel_id_para = id_para.expand(batch_size, -1) | ||
| 206 | + geometry = model_3dmm.get_3dlandmarks( | ||
| 207 | + sel_id_para, sel_exp_para, sel_euler, sel_trans, focal_length, cxy | ||
| 208 | + ) | ||
| 209 | + proj_geo = forward_transform(geometry, sel_euler, sel_trans, focal_length, cxy) | ||
| 210 | + | ||
| 211 | + loss_lan = cal_lan_loss(proj_geo[:, :, :2], sel_lms.detach()) | ||
| 212 | + loss_regid = torch.mean(id_para * id_para) | ||
| 213 | + loss_regexp = torch.mean(sel_exp_para * sel_exp_para) | ||
| 214 | + | ||
| 215 | + sel_tex_para = tex_para.expand(batch_size, -1) | ||
| 216 | + sel_texture = model_3dmm.forward_tex(sel_tex_para) | ||
| 217 | + geometry = model_3dmm.forward_geo(sel_id_para, sel_exp_para) | ||
| 218 | + rott_geo = forward_rott(geometry, sel_euler, sel_trans) | ||
| 219 | + render_imgs = renderer( | ||
| 220 | + rott_geo.to(device_render), | ||
| 221 | + sel_texture.to(device_render), | ||
| 222 | + sel_light.to(device_render), | ||
| 223 | + ) | ||
| 224 | + render_imgs = render_imgs.to(device_default) | ||
| 225 | + | ||
| 226 | + mask = (render_imgs[:, :, :, 3]).detach() > 0.0 | ||
| 227 | + render_proj = sel_imgs.clone() | ||
| 228 | + render_proj[mask] = render_imgs[mask][..., :3].byte() | ||
| 229 | + loss_col = cal_col_loss(render_imgs[:, :, :, :3], sel_imgs.float(), mask) | ||
| 230 | + | ||
| 231 | + if iter > 50: | ||
| 232 | + loss = loss_col + loss_lan * 0.05 + loss_regid * 1.0 + loss_regexp * 0.8 | ||
| 233 | + else: | ||
| 234 | + loss = loss_col + loss_lan * 3 + loss_regid * 2.0 + loss_regexp * 1.0 | ||
| 235 | + | ||
| 236 | + optimizer_tl.zero_grad() | ||
| 237 | + optimizer_id_frame.zero_grad() | ||
| 238 | + loss.backward() | ||
| 239 | + | ||
| 240 | + optimizer_tl.step() | ||
| 241 | + optimizer_id_frame.step() | ||
| 242 | + | ||
| 243 | + if iter % 50 == 0 and iter > 0: | ||
| 244 | + for param_group in optimizer_id_frame.param_groups: | ||
| 245 | + param_group["lr"] *= 0.2 | ||
| 246 | + for param_group in optimizer_tl.param_groups: | ||
| 247 | + param_group["lr"] *= 0.2 | ||
| 248 | + # print(iter, loss_col.item(), loss_lan.item(), loss_regid.item(), loss_regexp.item()) | ||
| 249 | + | ||
| 250 | + | ||
| 251 | +light_mean = torch.mean(sel_light, 0).unsqueeze(0).repeat(num_frames, 1) | ||
| 252 | +light_para.data = light_mean | ||
| 253 | + | ||
| 254 | +exp_para = exp_para.detach() | ||
| 255 | +euler_angle = euler_angle.detach() | ||
| 256 | +trans = trans.detach() | ||
| 257 | +light_para = light_para.detach() | ||
| 258 | + | ||
| 259 | +print(f'[INFO] fine frame-wise fitting...') | ||
| 260 | + | ||
| 261 | +for i in range(int((num_frames - 1) / batch_size + 1)): | ||
| 262 | + | ||
| 263 | + if (i + 1) * batch_size > num_frames: | ||
| 264 | + start_n = num_frames - batch_size | ||
| 265 | + sel_ids = np.arange(num_frames - batch_size, num_frames) | ||
| 266 | + else: | ||
| 267 | + start_n = i * batch_size | ||
| 268 | + sel_ids = np.arange(i * batch_size, i * batch_size + batch_size) | ||
| 269 | + | ||
| 270 | + imgs = [] | ||
| 271 | + for sel_id in sel_ids: | ||
| 272 | + imgs.append(cv2.imread(img_paths[sel_id])[:, :, ::-1]) | ||
| 273 | + imgs = np.stack(imgs) | ||
| 274 | + sel_imgs = torch.as_tensor(imgs).cuda() | ||
| 275 | + sel_lms = lms[sel_ids] | ||
| 276 | + | ||
| 277 | + sel_exp_para = exp_para.new_zeros((batch_size, exp_dim), requires_grad=True) | ||
| 278 | + sel_exp_para.data = exp_para[sel_ids].clone() | ||
| 279 | + sel_euler = euler_angle.new_zeros((batch_size, 3), requires_grad=True) | ||
| 280 | + sel_euler.data = euler_angle[sel_ids].clone() | ||
| 281 | + sel_trans = trans.new_zeros((batch_size, 3), requires_grad=True) | ||
| 282 | + sel_trans.data = trans[sel_ids].clone() | ||
| 283 | + sel_light = light_para.new_zeros((batch_size, 27), requires_grad=True) | ||
| 284 | + sel_light.data = light_para[sel_ids].clone() | ||
| 285 | + | ||
| 286 | + set_requires_grad([sel_exp_para, sel_euler, sel_trans, sel_light]) | ||
| 287 | + | ||
| 288 | + optimizer_cur_batch = torch.optim.Adam( | ||
| 289 | + [sel_exp_para, sel_euler, sel_trans, sel_light], lr=0.005 | ||
| 290 | + ) | ||
| 291 | + | ||
| 292 | + sel_id_para = id_para.expand(batch_size, -1).detach() | ||
| 293 | + sel_tex_para = tex_para.expand(batch_size, -1).detach() | ||
| 294 | + | ||
| 295 | + pre_num = 5 | ||
| 296 | + | ||
| 297 | + if i > 0: | ||
| 298 | + pre_ids = np.arange(start_n - pre_num, start_n) | ||
| 299 | + | ||
| 300 | + for iter in range(50): | ||
| 301 | + | ||
| 302 | + geometry = model_3dmm.get_3dlandmarks( | ||
| 303 | + sel_id_para, sel_exp_para, sel_euler, sel_trans, focal_length, cxy | ||
| 304 | + ) | ||
| 305 | + proj_geo = forward_transform(geometry, sel_euler, sel_trans, focal_length, cxy) | ||
| 306 | + loss_lan = cal_lan_loss(proj_geo[:, :, :2], sel_lms.detach()) | ||
| 307 | + loss_regexp = torch.mean(sel_exp_para * sel_exp_para) | ||
| 308 | + | ||
| 309 | + # texture and geometry for the current batch | ||
| 310 | + sel_texture = model_3dmm.forward_tex(sel_tex_para) | ||
| 311 | + geometry = model_3dmm.forward_geo(sel_id_para, sel_exp_para) | ||
| 312 | + rott_geo = forward_rott(geometry, sel_euler, sel_trans) | ||
| 313 | + render_imgs = renderer( | ||
| 314 | + rott_geo.to(device_render), | ||
| 315 | + sel_texture.to(device_render), | ||
| 316 | + sel_light.to(device_render), | ||
| 317 | + ) | ||
| 318 | + render_imgs = render_imgs.to(device_default) | ||
| 319 | + | ||
| 320 | + mask = (render_imgs[:, :, :, 3]).detach() > 0.0 | ||
| 321 | + | ||
| 322 | + loss_col = cal_col_loss(render_imgs[:, :, :, :3], sel_imgs.float(), mask) | ||
| 323 | + | ||
| 324 | + if i > 0: | ||
| 325 | + geometry_lap = model_3dmm.forward_geo_sub( | ||
| 326 | + id_para.expand(batch_size + pre_num, -1).detach(), | ||
| 327 | + torch.cat((exp_para[pre_ids].detach(), sel_exp_para)), | ||
| 328 | + model_3dmm.rigid_ids, | ||
| 329 | + ) | ||
| 330 | + rott_geo_lap = forward_rott( | ||
| 331 | + geometry_lap, | ||
| 332 | + torch.cat((euler_angle[pre_ids].detach(), sel_euler)), | ||
| 333 | + torch.cat((trans[pre_ids].detach(), sel_trans)), | ||
| 334 | + ) | ||
| 335 | + loss_lap = cal_lap_loss( | ||
| 336 | + [rott_geo_lap.reshape(rott_geo_lap.shape[0], -1).permute(1, 0)], [1.0] | ||
| 337 | + ) | ||
| 338 | + else: | ||
| 339 | + geometry_lap = model_3dmm.forward_geo_sub( | ||
| 340 | + id_para.expand(batch_size, -1).detach(), | ||
| 341 | + sel_exp_para, | ||
| 342 | + model_3dmm.rigid_ids, | ||
| 343 | + ) | ||
| 344 | + rott_geo_lap = forward_rott(geometry_lap, sel_euler, sel_trans) | ||
| 345 | + loss_lap = cal_lap_loss( | ||
| 346 | + [rott_geo_lap.reshape(rott_geo_lap.shape[0], -1).permute(1, 0)], [1.0] | ||
| 347 | + ) | ||
| 348 | + | ||
| 349 | + | ||
| 350 | + if iter > 30: | ||
| 351 | + loss = loss_col * 0.5 + loss_lan * 1.5 + loss_lap * 100000 + loss_regexp * 1.0 | ||
| 352 | + else: | ||
| 353 | + loss = loss_col * 0.5 + loss_lan * 8 + loss_lap * 100000 + loss_regexp * 1.0 | ||
| 354 | + | ||
| 355 | + optimizer_cur_batch.zero_grad() | ||
| 356 | + loss.backward() | ||
| 357 | + optimizer_cur_batch.step() | ||
| 358 | + | ||
| 359 | + # if iter % 10 == 0: | ||
| 360 | + # print( | ||
| 361 | + # i, | ||
| 362 | + # iter, | ||
| 363 | + # loss_col.item(), | ||
| 364 | + # loss_lan.item(), | ||
| 365 | + # loss_lap.item(), | ||
| 366 | + # loss_regexp.item(), | ||
| 367 | + # ) | ||
| 368 | + | ||
| 369 | + print(f"{i} of {int((num_frames - 1) / batch_size + 1)} done") | ||
| 370 | + | ||
| 371 | + render_proj = sel_imgs.clone() | ||
| 372 | + render_proj[mask] = render_imgs[mask][..., :3].byte() | ||
| 373 | + | ||
| 374 | + exp_para[sel_ids] = sel_exp_para.clone() | ||
| 375 | + euler_angle[sel_ids] = sel_euler.clone() | ||
| 376 | + trans[sel_ids] = sel_trans.clone() | ||
| 377 | + light_para[sel_ids] = sel_light.clone() | ||
| 378 | + | ||
| 379 | +torch.save( | ||
| 380 | + { | ||
| 381 | + "id": id_para.detach().cpu(), | ||
| 382 | + "exp": exp_para.detach().cpu(), | ||
| 383 | + "euler": euler_angle.detach().cpu(), | ||
| 384 | + "trans": trans.detach().cpu(), | ||
| 385 | + "focal": focal_length.detach().cpu(), | ||
| 386 | + }, | ||
| 387 | + os.path.join(os.path.dirname(args.path), "track_params.pt"), | ||
| 388 | +) | ||
| 389 | + | ||
| 390 | +print("params saved") |
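The saved dictionary is the interface to the rest of the pipeline; a minimal sketch of reading it back (the `obama` path mirrors the script's default `--path`, with the file written next to `ori_imgs`):

```python
import torch

params = torch.load("obama/track_params.pt", map_location="cpu")
print(params["exp"].shape)     # (num_frames, 79) per-frame expression coefficients
print(params["euler"].shape)   # (num_frames, 3)  per-frame head rotation (Euler angles)
print(params["trans"].shape)   # (num_frames, 3)  per-frame translation
print(params["focal"])         # the fitted focal length
```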
ernerf/data_utils/face_tracking/facemodel.py
0 → 100644
| 1 | +import torch | ||
| 2 | +import torch.nn as nn | ||
| 3 | +import numpy as np | ||
| 4 | +import os | ||
| 5 | +from util import * | ||
| 6 | + | ||
| 7 | + | ||
| 8 | +class Face_3DMM(nn.Module): | ||
| 9 | + def __init__(self, modelpath, id_dim, exp_dim, tex_dim, point_num): | ||
| 10 | + super(Face_3DMM, self).__init__() | ||
| 11 | + # id_dim = 100 | ||
| 12 | + # exp_dim = 79 | ||
| 13 | + # tex_dim = 100 | ||
| 14 | + self.point_num = point_num | ||
| 15 | + DMM_info = np.load( | ||
| 16 | + os.path.join(modelpath, "3DMM_info.npy"), allow_pickle=True | ||
| 17 | + ).item() | ||
| 18 | + base_id = DMM_info["b_shape"][:id_dim, :] | ||
| 19 | + mu_id = DMM_info["mu_shape"] | ||
| 20 | + base_exp = DMM_info["b_exp"][:exp_dim, :] | ||
| 21 | + mu_exp = DMM_info["mu_exp"] | ||
| 22 | + mu = mu_id + mu_exp | ||
| 23 | + mu = mu.reshape(-1, 3) | ||
| 24 | + for i in range(3): | ||
| 25 | + mu[:, i] -= np.mean(mu[:, i]) | ||
| 26 | + mu = mu.reshape(-1) | ||
| 27 | + self.base_id = torch.as_tensor(base_id).cuda() / 100000.0 | ||
| 28 | + self.base_exp = torch.as_tensor(base_exp).cuda() / 100000.0 | ||
| 29 | + self.mu = torch.as_tensor(mu).cuda() / 100000.0 | ||
| 30 | + base_tex = DMM_info["b_tex"][:tex_dim, :] | ||
| 31 | + mu_tex = DMM_info["mu_tex"] | ||
| 32 | + self.base_tex = torch.as_tensor(base_tex).cuda() | ||
| 33 | + self.mu_tex = torch.as_tensor(mu_tex).cuda() | ||
| 34 | + sig_id = DMM_info["sig_shape"][:id_dim] | ||
| 35 | + sig_tex = DMM_info["sig_tex"][:tex_dim] | ||
| 36 | + sig_exp = DMM_info["sig_exp"][:exp_dim] | ||
| 37 | + self.sig_id = torch.as_tensor(sig_id).cuda() | ||
| 38 | + self.sig_tex = torch.as_tensor(sig_tex).cuda() | ||
| 39 | + self.sig_exp = torch.as_tensor(sig_exp).cuda() | ||
| 40 | + | ||
| 41 | + keys_info = np.load( | ||
| 42 | + os.path.join(modelpath, "keys_info.npy"), allow_pickle=True | ||
| 43 | + ).item() | ||
| 44 | + self.keyinds = torch.as_tensor(keys_info["keyinds"]).cuda() | ||
| 45 | + self.left_contours = torch.as_tensor(keys_info["left_contour"]).cuda() | ||
| 46 | + self.right_contours = torch.as_tensor(keys_info["right_contour"]).cuda() | ||
| 47 | + self.rigid_ids = torch.as_tensor(keys_info["rigid_ids"]).cuda() | ||
| 48 | + | ||
| 49 | + def get_3dlandmarks(self, id_para, exp_para, euler_angle, trans, focal_length, cxy): | ||
| 50 | + id_para = id_para * self.sig_id | ||
| 51 | + exp_para = exp_para * self.sig_exp | ||
| 52 | + batch_size = id_para.shape[0] | ||
| 53 | + num_per_contour = self.left_contours.shape[1] | ||
| 54 | + left_contours_flat = self.left_contours.reshape(-1) | ||
| 55 | + right_contours_flat = self.right_contours.reshape(-1) | ||
| 56 | + sel_index = torch.cat( | ||
| 57 | + ( | ||
| 58 | + 3 * left_contours_flat.unsqueeze(1), | ||
| 59 | + 3 * left_contours_flat.unsqueeze(1) + 1, | ||
| 60 | + 3 * left_contours_flat.unsqueeze(1) + 2, | ||
| 61 | + ), | ||
| 62 | + dim=1, | ||
| 63 | + ).reshape(-1) | ||
| 64 | + left_geometry = ( | ||
| 65 | + torch.mm(id_para, self.base_id[:, sel_index]) | ||
| 66 | + + torch.mm(exp_para, self.base_exp[:, sel_index]) | ||
| 67 | + + self.mu[sel_index] | ||
| 68 | + ) | ||
| 69 | + left_geometry = left_geometry.view(batch_size, -1, 3) | ||
| 70 | + proj_x = forward_transform( | ||
| 71 | + left_geometry, euler_angle, trans, focal_length, cxy | ||
| 72 | + )[:, :, 0] | ||
| 73 | + proj_x = proj_x.reshape(batch_size, 8, num_per_contour) | ||
| 74 | + arg_min = proj_x.argmin(dim=2) | ||
| 75 | + left_geometry = left_geometry.view(batch_size * 8, num_per_contour, 3) | ||
| 76 | + left_3dlands = left_geometry[ | ||
| 77 | + torch.arange(batch_size * 8), arg_min.view(-1), : | ||
| 78 | + ].view(batch_size, 8, 3) | ||
| 79 | + | ||
| 80 | + sel_index = torch.cat( | ||
| 81 | + ( | ||
| 82 | + 3 * right_contours_flat.unsqueeze(1), | ||
| 83 | + 3 * right_contours_flat.unsqueeze(1) + 1, | ||
| 84 | + 3 * right_contours_flat.unsqueeze(1) + 2, | ||
| 85 | + ), | ||
| 86 | + dim=1, | ||
| 87 | + ).reshape(-1) | ||
| 88 | + right_geometry = ( | ||
| 89 | + torch.mm(id_para, self.base_id[:, sel_index]) | ||
| 90 | + + torch.mm(exp_para, self.base_exp[:, sel_index]) | ||
| 91 | + + self.mu[sel_index] | ||
| 92 | + ) | ||
| 93 | + right_geometry = right_geometry.view(batch_size, -1, 3) | ||
| 94 | + proj_x = forward_transform( | ||
| 95 | + right_geometry, euler_angle, trans, focal_length, cxy | ||
| 96 | + )[:, :, 0] | ||
| 97 | + proj_x = proj_x.reshape(batch_size, 8, num_per_contour) | ||
| 98 | + arg_max = proj_x.argmax(dim=2) | ||
| 99 | + right_geometry = right_geometry.view(batch_size * 8, num_per_contour, 3) | ||
| 100 | + right_3dlands = right_geometry[ | ||
| 101 | + torch.arange(batch_size * 8), arg_max.view(-1), : | ||
| 102 | + ].view(batch_size, 8, 3) | ||
| 103 | + | ||
| 104 | + sel_index = torch.cat( | ||
| 105 | + ( | ||
| 106 | + 3 * self.keyinds.unsqueeze(1), | ||
| 107 | + 3 * self.keyinds.unsqueeze(1) + 1, | ||
| 108 | + 3 * self.keyinds.unsqueeze(1) + 2, | ||
| 109 | + ), | ||
| 110 | + dim=1, | ||
| 111 | + ).reshape(-1) | ||
| 112 | + geometry = ( | ||
| 113 | + torch.mm(id_para, self.base_id[:, sel_index]) | ||
| 114 | + + torch.mm(exp_para, self.base_exp[:, sel_index]) | ||
| 115 | + + self.mu[sel_index] | ||
| 116 | + ) | ||
| 117 | + lands_3d = geometry.view(-1, self.keyinds.shape[0], 3) | ||
| 118 | + lands_3d[:, :8, :] = left_3dlands | ||
| 119 | + lands_3d[:, 9:17, :] = right_3dlands | ||
| 120 | + return lands_3d | ||
| 121 | + | ||
| 122 | + def forward_geo_sub(self, id_para, exp_para, sub_index): | ||
| 123 | + id_para = id_para * self.sig_id | ||
| 124 | + exp_para = exp_para * self.sig_exp | ||
| 125 | + sel_index = torch.cat( | ||
| 126 | + ( | ||
| 127 | + 3 * sub_index.unsqueeze(1), | ||
| 128 | + 3 * sub_index.unsqueeze(1) + 1, | ||
| 129 | + 3 * sub_index.unsqueeze(1) + 2, | ||
| 130 | + ), | ||
| 131 | + dim=1, | ||
| 132 | + ).reshape(-1) | ||
| 133 | + geometry = ( | ||
| 134 | + torch.mm(id_para, self.base_id[:, sel_index]) | ||
| 135 | + + torch.mm(exp_para, self.base_exp[:, sel_index]) | ||
| 136 | + + self.mu[sel_index] | ||
| 137 | + ) | ||
| 138 | + return geometry.reshape(-1, sub_index.shape[0], 3) | ||
| 139 | + | ||
| 140 | + def forward_geo(self, id_para, exp_para): | ||
| 141 | + id_para = id_para * self.sig_id | ||
| 142 | + exp_para = exp_para * self.sig_exp | ||
| 143 | + geometry = ( | ||
| 144 | + torch.mm(id_para, self.base_id) | ||
| 145 | + + torch.mm(exp_para, self.base_exp) | ||
| 146 | + + self.mu | ||
| 147 | + ) | ||
| 148 | + return geometry.reshape(-1, self.point_num, 3) | ||
| 149 | + | ||
| 150 | + def forward_tex(self, tex_para): | ||
| 151 | + tex_para = tex_para * self.sig_tex | ||
| 152 | + texture = torch.mm(tex_para, self.base_tex) + self.mu_tex | ||
| 153 | + return texture.reshape(-1, self.point_num, 3) |
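The repeated `sel_index` construction interleaves x/y/z columns: vertex `v` occupies columns `3v`, `3v+1`, `3v+2` of the flattened bases. A minimal sketch of the gather (values are illustrative):

```python
import torch

sub_index = torch.tensor([2, 5])   # vertex ids
sel_index = torch.cat((3 * sub_index.unsqueeze(1),
                       3 * sub_index.unsqueeze(1) + 1,
                       3 * sub_index.unsqueeze(1) + 2), dim=1).reshape(-1)
print(sel_index)  # tensor([ 6,  7,  8, 15, 16, 17]) -> xyz columns of vertices 2 and 5
```

ernerf/data_utils/face_tracking/geo_transform.py
0 → 100644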
| 1 | +"""This module contains functions for geometry transform and camera projection""" | ||
| 2 | +import torch | ||
| 3 | +import torch.nn as nn | ||
| 4 | +import numpy as np | ||
| 5 | + | ||
| 6 | + | ||
| 7 | +def euler2rot(euler_angle): | ||
| 8 | + batch_size = euler_angle.shape[0] | ||
| 9 | + theta = euler_angle[:, 0].reshape(-1, 1, 1) | ||
| 10 | + phi = euler_angle[:, 1].reshape(-1, 1, 1) | ||
| 11 | + psi = euler_angle[:, 2].reshape(-1, 1, 1) | ||
| 12 | + one = torch.ones((batch_size, 1, 1), dtype=torch.float32, device=euler_angle.device) | ||
| 13 | + zero = torch.zeros( | ||
| 14 | + (batch_size, 1, 1), dtype=torch.float32, device=euler_angle.device | ||
| 15 | + ) | ||
| 16 | + rot_x = torch.cat( | ||
| 17 | + ( | ||
| 18 | + torch.cat((one, zero, zero), 1), | ||
| 19 | + torch.cat((zero, theta.cos(), theta.sin()), 1), | ||
| 20 | + torch.cat((zero, -theta.sin(), theta.cos()), 1), | ||
| 21 | + ), | ||
| 22 | + 2, | ||
| 23 | + ) | ||
| 24 | + rot_y = torch.cat( | ||
| 25 | + ( | ||
| 26 | + torch.cat((phi.cos(), zero, -phi.sin()), 1), | ||
| 27 | + torch.cat((zero, one, zero), 1), | ||
| 28 | + torch.cat((phi.sin(), zero, phi.cos()), 1), | ||
| 29 | + ), | ||
| 30 | + 2, | ||
| 31 | + ) | ||
| 32 | + rot_z = torch.cat( | ||
| 33 | + ( | ||
| 34 | + torch.cat((psi.cos(), -psi.sin(), zero), 1), | ||
| 35 | + torch.cat((psi.sin(), psi.cos(), zero), 1), | ||
| 36 | + torch.cat((zero, zero, one), 1), | ||
| 37 | + ), | ||
| 38 | + 2, | ||
| 39 | + ) | ||
| 40 | + return torch.bmm(rot_x, torch.bmm(rot_y, rot_z)) | ||
| 41 | + | ||
| 42 | + | ||
| 43 | +def rot_trans_geo(geometry, rot, trans): | ||
| 44 | + rott_geo = torch.bmm(rot, geometry.permute(0, 2, 1)) + trans.view(-1, 3, 1) | ||
| 45 | + return rott_geo.permute(0, 2, 1) | ||
| 46 | + | ||
| 47 | + | ||
| 48 | +def euler_trans_geo(geometry, euler, trans): | ||
| 49 | + rot = euler2rot(euler) | ||
| 50 | + return rot_trans_geo(geometry, rot, trans) | ||
| 51 | + | ||
| 52 | + | ||
| 53 | +def proj_geo(rott_geo, camera_para): | ||
| 54 | + fx = camera_para[:, 0] | ||
| 55 | + fy = camera_para[:, 0] | ||
| 56 | + cx = camera_para[:, 1] | ||
| 57 | + cy = camera_para[:, 2] | ||
| 58 | + | ||
| 59 | + X = rott_geo[:, :, 0] | ||
| 60 | + Y = rott_geo[:, :, 1] | ||
| 61 | + Z = rott_geo[:, :, 2] | ||
| 62 | + | ||
| 63 | + fxX = fx[:, None] * X | ||
| 64 | + fyY = fy[:, None] * Y | ||
| 65 | + | ||
| 66 | + proj_x = -fxX / Z + cx[:, None] | ||
| 67 | + proj_y = fyY / Z + cy[:, None] | ||
| 68 | + | ||
| 69 | + return torch.cat((proj_x[:, :, None], proj_y[:, :, None], Z[:, :, None]), 2) |
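Note that both this module and `util.py` negate the projected x coordinate, so the two projections agree. A quick orthonormality check for `euler2rot` (a sketch, assuming this file is importable as `geo_transform` as `render_land.py` does below):

```python
import torch
from geo_transform import euler2rot  # assumes this file is on the import path

euler = torch.randn(4, 3)
R = euler2rot(euler)                            # (4, 3, 3)
RtR = torch.bmm(R, R.transpose(1, 2))
assert torch.allclose(RtR, torch.eye(3).expand(4, 3, 3), atol=1e-5)
```

ernerf/data_utils/face_tracking/render_3dmm.py
0 → 100644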
| 1 | +import torch | ||
| 2 | +import torch.nn as nn | ||
| 3 | +import numpy as np | ||
| 4 | +import os | ||
| 5 | +from pytorch3d.structures import Meshes | ||
| 6 | +from pytorch3d.renderer import ( | ||
| 7 | + look_at_view_transform, | ||
| 8 | + PerspectiveCameras, | ||
| 9 | + FoVPerspectiveCameras, | ||
| 10 | + PointLights, | ||
| 11 | + DirectionalLights, | ||
| 12 | + Materials, | ||
| 13 | + RasterizationSettings, | ||
| 14 | + MeshRenderer, | ||
| 15 | + MeshRasterizer, | ||
| 16 | + SoftPhongShader, | ||
| 17 | + TexturesUV, | ||
| 18 | + TexturesVertex, | ||
| 19 | + blending, | ||
| 20 | +) | ||
| 21 | + | ||
| 22 | +from pytorch3d.ops import interpolate_face_attributes | ||
| 23 | + | ||
| 24 | +from pytorch3d.renderer.blending import ( | ||
| 25 | + BlendParams, | ||
| 26 | + hard_rgb_blend, | ||
| 27 | + sigmoid_alpha_blend, | ||
| 28 | + softmax_rgb_blend, | ||
| 29 | +) | ||
| 30 | + | ||
| 31 | + | ||
| 32 | +class SoftSimpleShader(nn.Module): | ||
| 33 | + """ | ||
| 34 | + Per pixel lighting - the lighting model is applied using the interpolated | ||
| 35 | + coordinates and normals for each pixel. The blending function returns the | ||
| 36 | + soft aggregated color using all the faces per pixel. | ||
| 37 | + | ||
| 38 | + To use the default values, simply initialize the shader with the desired | ||
| 39 | + device e.g. | ||
| 40 | + | ||
| 41 | + """ | ||
| 42 | + | ||
| 43 | + def __init__( | ||
| 44 | + self, device="cpu", cameras=None, lights=None, materials=None, blend_params=None | ||
| 45 | + ): | ||
| 46 | + super().__init__() | ||
| 47 | + self.lights = lights if lights is not None else PointLights(device=device) | ||
| 48 | + self.materials = ( | ||
| 49 | + materials if materials is not None else Materials(device=device) | ||
| 50 | + ) | ||
| 51 | + self.cameras = cameras | ||
| 52 | + self.blend_params = blend_params if blend_params is not None else BlendParams() | ||
| 53 | + | ||
| 54 | + def to(self, device): | ||
| 55 | + # Manually move to device modules which are not subclasses of nn.Module | ||
| 56 | + self.cameras = self.cameras.to(device) | ||
| 57 | + self.materials = self.materials.to(device) | ||
| 58 | + self.lights = self.lights.to(device) | ||
| 59 | + return self | ||
| 60 | + | ||
| 61 | + def forward(self, fragments, meshes, **kwargs) -> torch.Tensor: | ||
| 62 | + | ||
| 63 | + texels = meshes.sample_textures(fragments) | ||
| 64 | + blend_params = kwargs.get("blend_params", self.blend_params) | ||
| 65 | + | ||
| 66 | + cameras = kwargs.get("cameras", self.cameras) | ||
| 67 | + if cameras is None: | ||
| 68 | + msg = "Cameras must be specified either at initialization \ | ||
| 69 | + or in the forward pass of SoftPhongShader" | ||
| 70 | + raise ValueError(msg) | ||
| 71 | + znear = kwargs.get("znear", getattr(cameras, "znear", 1.0)) | ||
| 72 | + zfar = kwargs.get("zfar", getattr(cameras, "zfar", 100.0)) | ||
| 73 | + images = softmax_rgb_blend( | ||
| 74 | + texels, fragments, blend_params, znear=znear, zfar=zfar | ||
| 75 | + ) | ||
| 76 | + return images | ||
| 77 | + | ||
| 78 | + | ||
| 79 | +class Render_3DMM(nn.Module): | ||
| 80 | + def __init__( | ||
| 81 | + self, | ||
| 82 | + focal=1015, | ||
| 83 | + img_h=500, | ||
| 84 | + img_w=500, | ||
| 85 | + batch_size=1, | ||
| 86 | + device=torch.device("cuda:0"), | ||
| 87 | + ): | ||
| 88 | + super(Render_3DMM, self).__init__() | ||
| 89 | + | ||
| 90 | + self.focal = focal | ||
| 91 | + self.img_h = img_h | ||
| 92 | + self.img_w = img_w | ||
| 93 | + self.device = device | ||
| 94 | + self.renderer = self.get_render(batch_size) | ||
| 95 | + | ||
| 96 | + dir_path = os.path.dirname(os.path.realpath(__file__)) | ||
| 97 | + topo_info = np.load( | ||
| 98 | + os.path.join(dir_path, "3DMM", "topology_info.npy"), allow_pickle=True | ||
| 99 | + ).item() | ||
| 100 | + self.tris = torch.as_tensor(topo_info["tris"]).to(self.device) | ||
| 101 | + self.vert_tris = torch.as_tensor(topo_info["vert_tris"]).to(self.device) | ||
| 102 | + | ||
| 103 | + def compute_normal(self, geometry): | ||
| 104 | + vert_1 = torch.index_select(geometry, 1, self.tris[:, 0]) | ||
| 105 | + vert_2 = torch.index_select(geometry, 1, self.tris[:, 1]) | ||
| 106 | + vert_3 = torch.index_select(geometry, 1, self.tris[:, 2]) | ||
| 107 | + nnorm = torch.cross(vert_2 - vert_1, vert_3 - vert_1, 2) | ||
| 108 | + tri_normal = nn.functional.normalize(nnorm, dim=2) | ||
| 109 | + v_norm = tri_normal[:, self.vert_tris, :].sum(2) | ||
| 110 | + vert_normal = v_norm / v_norm.norm(dim=2).unsqueeze(2) | ||
| 111 | + return vert_normal | ||
| 112 | + | ||
| 113 | + def get_render(self, batch_size=1): | ||
| 114 | + half_s = self.img_w * 0.5 | ||
| 115 | + R, T = look_at_view_transform(10, 0, 0) | ||
| 116 | + R = R.repeat(batch_size, 1, 1) | ||
| 117 | + T = torch.zeros((batch_size, 3), dtype=torch.float32).to(self.device)  # discard look_at T; camera sits at the origin | ||
| 118 | + | ||
| 119 | + cameras = FoVPerspectiveCameras( | ||
| 120 | + device=self.device, | ||
| 121 | + R=R, | ||
| 122 | + T=T, | ||
| 123 | + znear=0.01, | ||
| 124 | + zfar=20, | ||
| 125 | + fov=2 * np.arctan(self.img_w // 2 / self.focal) * 180.0 / np.pi, | ||
| 126 | + ) | ||
| 127 | + lights = PointLights( | ||
| 128 | + device=self.device, | ||
| 129 | + location=[[0.0, 0.0, 1e5]], | ||
| 130 | + ambient_color=[[1, 1, 1]], | ||
| 131 | + specular_color=[[0.0, 0.0, 0.0]], | ||
| 132 | + diffuse_color=[[0.0, 0.0, 0.0]], | ||
| 133 | + ) | ||
| 134 | + sigma = 1e-4 | ||
| 135 | + raster_settings = RasterizationSettings( | ||
| 136 | + image_size=(self.img_h, self.img_w), | ||
| 137 | + blur_radius=np.log(1.0 / 1e-4 - 1.0) * sigma / 18.0, | ||
| 138 | + faces_per_pixel=2, | ||
| 139 | + perspective_correct=False, | ||
| 140 | + ) | ||
| 141 | + blend_params = blending.BlendParams(background_color=[0, 0, 0]) | ||
| 142 | + renderer = MeshRenderer( | ||
| 143 | + rasterizer=MeshRasterizer(raster_settings=raster_settings, cameras=cameras), | ||
| 144 | + shader=SoftSimpleShader( | ||
| 145 | + lights=lights, blend_params=blend_params, cameras=cameras | ||
| 146 | + ), | ||
| 147 | + ) | ||
| 148 | + return renderer.to(self.device) | ||
| 149 | + | ||
| 150 | + @staticmethod | ||
| 151 | + def Illumination_layer(face_texture, norm, gamma): | ||
| 152 | + | ||
| 153 | + n_b, num_vertex, _ = face_texture.size() | ||
| 154 | + n_v_full = n_b * num_vertex | ||
| 155 | + gamma = gamma.view(-1, 3, 9).clone() | ||
| 156 | + gamma[:, :, 0] += 0.8 | ||
| 157 | + | ||
| 158 | + gamma = gamma.permute(0, 2, 1) | ||
| 159 | + | ||
| 160 | + a0 = np.pi | ||
| 161 | + a1 = 2 * np.pi / np.sqrt(3.0) | ||
| 162 | + a2 = 2 * np.pi / np.sqrt(8.0) | ||
| 163 | + c0 = 1 / np.sqrt(4 * np.pi) | ||
| 164 | + c1 = np.sqrt(3.0) / np.sqrt(4 * np.pi) | ||
| 165 | + c2 = 3 * np.sqrt(5.0) / np.sqrt(12 * np.pi) | ||
| 166 | + d0 = 0.5 / np.sqrt(3.0) | ||
| 167 | + | ||
| 168 | + Y0 = torch.ones(n_v_full).to(gamma.device).float() * a0 * c0 | ||
| 169 | + norm = norm.view(-1, 3) | ||
| 170 | + nx, ny, nz = norm[:, 0], norm[:, 1], norm[:, 2] | ||
| 171 | + arrH = [] | ||
| 172 | + | ||
| 173 | + arrH.append(Y0) | ||
| 174 | + arrH.append(-a1 * c1 * ny) | ||
| 175 | + arrH.append(a1 * c1 * nz) | ||
| 176 | + arrH.append(-a1 * c1 * nx) | ||
| 177 | + arrH.append(a2 * c2 * nx * ny) | ||
| 178 | + arrH.append(-a2 * c2 * ny * nz) | ||
| 179 | + arrH.append(a2 * c2 * d0 * (3 * nz.pow(2) - 1)) | ||
| 180 | + arrH.append(-a2 * c2 * nx * nz) | ||
| 181 | + arrH.append(a2 * c2 * 0.5 * (nx.pow(2) - ny.pow(2))) | ||
| 182 | + | ||
| 183 | + H = torch.stack(arrH, 1) | ||
| 184 | + Y = H.view(n_b, num_vertex, 9) | ||
| 185 | + lighting = Y.bmm(gamma) | ||
| 186 | + | ||
| 187 | + face_color = face_texture * lighting | ||
| 188 | + return face_color | ||
| 189 | + | ||
| 190 | + def forward(self, rott_geometry, texture, diffuse_sh): | ||
| 191 | + face_normal = self.compute_normal(rott_geometry) | ||
| 192 | + face_color = self.Illumination_layer(texture, face_normal, diffuse_sh) | ||
| 193 | + face_color = TexturesVertex(face_color) | ||
| 194 | + mesh = Meshes( | ||
| 195 | + rott_geometry, | ||
| 196 | + self.tris.float().repeat(rott_geometry.shape[0], 1, 1), | ||
| 197 | + face_color, | ||
| 198 | + ) | ||
| 199 | + rendered_img = self.renderer(mesh) | ||
| 200 | + rendered_img = torch.clamp(rendered_img, 0, 255) | ||
| 201 | + | ||
| 202 | + return rendered_img |
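The renderer returns RGBA images in `[0, 255]`, and the tracker above thresholds the alpha channel to build its foreground mask. A hedged construction sketch (requires pytorch3d, the `3DMM/topology_info.npy` file, and a CUDA device; shapes follow the calls in `face_tracker.py`):

```python
import torch

device = torch.device("cuda:0")
renderer = Render_3DMM(focal=1200, img_h=512, img_w=512, batch_size=2, device=device)
# rott_geo: (2, point_num, 3) posed vertices; texture: (2, point_num, 3) in [0, 255];
# diffuse_sh: (2, 27) spherical-harmonic lighting (9 coefficients per RGB channel).
# imgs = renderer(rott_geo, texture, diffuse_sh)   # (2, 512, 512, 4), RGBA in [0, 255]
# mask = imgs[..., 3] > 0.0
```

ernerf/data_utils/face_tracking/render_land.py
0 → 100644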
| 1 | +import torch | ||
| 2 | +import torch.nn as nn | ||
| 3 | +import render_util | ||
| 4 | +import geo_transform | ||
| 5 | +import numpy as np | ||
| 6 | + | ||
| 7 | + | ||
| 8 | +def compute_tri_normal(geometry, tris): | ||
| 9 | + geometry = geometry.permute(0, 2, 1) | ||
| 10 | + tri_1 = tris[:, 0] | ||
| 11 | + tri_2 = tris[:, 1] | ||
| 12 | + tri_3 = tris[:, 2] | ||
| 13 | + | ||
| 14 | + vert_1 = torch.index_select(geometry, 2, tri_1) | ||
| 15 | + vert_2 = torch.index_select(geometry, 2, tri_2) | ||
| 16 | + vert_3 = torch.index_select(geometry, 2, tri_3) | ||
| 17 | + | ||
| 18 | + nnorm = torch.cross(vert_2 - vert_1, vert_3 - vert_1, 1) | ||
| 19 | + normal = nn.functional.normalize(nnorm).permute(0, 2, 1) | ||
| 20 | + return normal | ||
| 21 | + | ||
| 22 | + | ||
| 23 | +class Compute_normal_base(torch.autograd.Function): | ||
| 24 | + @staticmethod | ||
| 25 | + def forward(ctx, normal): | ||
| 26 | + (normal_b,) = render_util.normal_base_forward(normal) | ||
| 27 | + ctx.save_for_backward(normal) | ||
| 28 | + return normal_b | ||
| 29 | + | ||
| 30 | + @staticmethod | ||
| 31 | + def backward(ctx, grad_normal_b): | ||
| 32 | + (normal,) = ctx.saved_tensors | ||
| 33 | + (grad_normal,) = render_util.normal_base_backward(grad_normal_b, normal) | ||
| 34 | + return grad_normal | ||
| 35 | + | ||
| 36 | + | ||
| 37 | +class Normal_Base(torch.nn.Module): | ||
| 38 | + def __init__(self): | ||
| 39 | + super(Normal_Base, self).__init__() | ||
| 40 | + | ||
| 41 | + def forward(self, normal): | ||
| 42 | + return Compute_normal_base.apply(normal) | ||
| 43 | + | ||
| 44 | + | ||
| 45 | +def preprocess_render(geometry, euler, trans, cam, tris, vert_tris, ori_img): | ||
| 46 | + point_num = geometry.shape[1] | ||
| 47 | + rott_geo = geo_transform.euler_trans_geo(geometry, euler, trans) | ||
| 48 | + proj_geo = geo_transform.proj_geo(rott_geo, cam) | ||
| 49 | + rot_tri_normal = compute_tri_normal(rott_geo, tris) | ||
| 50 | + rot_vert_normal = torch.index_select(rot_tri_normal, 1, vert_tris) | ||
| 51 | + is_visible = -torch.bmm( | ||
| 52 | + rot_vert_normal.reshape(-1, 1, 3), | ||
| 53 | + nn.functional.normalize(rott_geo.reshape(-1, 3, 1)), | ||
| 54 | + ).reshape(-1, point_num) | ||
| 55 | + is_visible[is_visible < 0.01] = -1 | ||
| 56 | + pixel_valid = torch.zeros( | ||
| 57 | + (ori_img.shape[0], ori_img.shape[1] * ori_img.shape[2]), | ||
| 58 | + dtype=torch.float32, | ||
| 59 | + device=ori_img.device, | ||
| 60 | + ) | ||
| 61 | + return rott_geo, proj_geo, rot_tri_normal, is_visible, pixel_valid | ||
| 62 | + | ||
| 63 | + | ||
| 64 | +class Render_Face(torch.autograd.Function): | ||
| 65 | + @staticmethod | ||
| 66 | + def forward( | ||
| 67 | + ctx, proj_geo, texture, nbl, ori_img, is_visible, tri_inds, pixel_valid | ||
| 68 | + ): | ||
| 69 | + batch_size, h, w, _ = ori_img.shape | ||
| 70 | + ori_img = ori_img.view(batch_size, -1, 3) | ||
| 71 | + ori_size = torch.cat( | ||
| 72 | + ( | ||
| 73 | + torch.ones((batch_size, 1), dtype=torch.int32, device=ori_img.device) | ||
| 74 | + * h, | ||
| 75 | + torch.ones((batch_size, 1), dtype=torch.int32, device=ori_img.device) | ||
| 76 | + * w, | ||
| 77 | + ), | ||
| 78 | + dim=1, | ||
| 79 | + ).view(-1) | ||
| 80 | + tri_index, tri_coord, render, real = render_util.render_face_forward( | ||
| 81 | + proj_geo, ori_img, ori_size, texture, nbl, is_visible, tri_inds, pixel_valid | ||
| 82 | + ) | ||
| 83 | + ctx.save_for_backward( | ||
| 84 | + ori_img, ori_size, proj_geo, texture, nbl, tri_inds, tri_index, tri_coord | ||
| 85 | + ) | ||
| 86 | + return render, real | ||
| 87 | + | ||
| 88 | + @staticmethod | ||
| 89 | + def backward(ctx, grad_render, grad_real): | ||
| 90 | + ( | ||
| 91 | + ori_img, | ||
| 92 | + ori_size, | ||
| 93 | + proj_geo, | ||
| 94 | + texture, | ||
| 95 | + nbl, | ||
| 96 | + tri_inds, | ||
| 97 | + tri_index, | ||
| 98 | + tri_coord, | ||
| 99 | + ) = ctx.saved_tensors | ||
| 100 | + grad_proj_geo, grad_texture, grad_nbl = render_util.render_face_backward( | ||
| 101 | + grad_render, | ||
| 102 | + grad_real, | ||
| 103 | + ori_img, | ||
| 104 | + ori_size, | ||
| 105 | + proj_geo, | ||
| 106 | + texture, | ||
| 107 | + nbl, | ||
| 108 | + tri_inds, | ||
| 109 | + tri_index, | ||
| 110 | + tri_coord, | ||
| 111 | + ) | ||
| 112 | + return grad_proj_geo, grad_texture, grad_nbl, None, None, None, None | ||
| 113 | + | ||
| 114 | + | ||
| 115 | +class Render_RGB(nn.Module): | ||
| 116 | + def __init__(self): | ||
| 117 | + super(Render_RGB, self).__init__() | ||
| 118 | + | ||
| 119 | + def forward( | ||
| 120 | + self, proj_geo, texture, nbl, ori_img, is_visible, tri_inds, pixel_valid | ||
| 121 | + ): | ||
| 122 | + return Render_Face.apply( | ||
| 123 | + proj_geo, texture, nbl, ori_img, is_visible, tri_inds, pixel_valid | ||
| 124 | + ) | ||
| 125 | + | ||
| 126 | + | ||
| 127 | +def cal_land(proj_geo, is_visible, lands_info, land_num): | ||
| 128 | + (land_index,) = render_util.update_contour(lands_info, is_visible, land_num) | ||
| 129 | + proj_land = torch.index_select(proj_geo.reshape(-1, 3), 0, land_index)[ | ||
| 130 | + :, :2 | ||
| 131 | + ].reshape(-1, land_num, 2) | ||
| 132 | + return proj_land | ||
| 133 | + | ||
| 134 | + | ||
| 135 | +class Render_Land(nn.Module): | ||
| 136 | + def __init__(self): | ||
| 137 | + super(Render_Land, self).__init__() | ||
| 138 | + lands_info = np.loadtxt("../data/3DMM/lands_info.txt", dtype=np.int32) | ||
| 139 | + self.lands_info = torch.as_tensor(lands_info).cuda() | ||
| 140 | + tris = np.loadtxt("../data/3DMM/tris.txt", dtype=np.int64) | ||
| 141 | + self.tris = torch.as_tensor(tris).cuda() - 1 | ||
| 142 | + vert_tris = np.loadtxt("../data/3DMM/vert_tris.txt", dtype=np.int64) | ||
| 143 | + self.vert_tris = torch.as_tensor(vert_tris).cuda() | ||
| 144 | + self.normal_baser = Normal_Base().cuda() | ||
| 145 | + self.renderer = Render_RGB().cuda() | ||
| 146 | + | ||
| 147 | + def render_mesh(self, geometry, euler, trans, cam, ori_img, light): | ||
| 148 | + batch_size, h, w, _ = ori_img.shape | ||
| 149 | + ori_img = ori_img.view(batch_size, -1, 3) | ||
| 150 | + ori_size = torch.cat( | ||
| 151 | + ( | ||
| 152 | + torch.ones((batch_size, 1), dtype=torch.int32, device=ori_img.device) | ||
| 153 | + * h, | ||
| 154 | + torch.ones((batch_size, 1), dtype=torch.int32, device=ori_img.device) | ||
| 155 | + * w, | ||
| 156 | + ), | ||
| 157 | + dim=1, | ||
| 158 | + ).view(-1) | ||
| 159 | + rott_geo, proj_geo, rot_tri_normal, _, _ = preprocess_render( | ||
| 160 | + geometry, euler, trans, cam, self.tris, self.vert_tris, ori_img | ||
| 161 | + ) | ||
| 162 | + tri_nb = self.normal_baser(rot_tri_normal.contiguous()) | ||
| 163 | + nbl = torch.bmm( | ||
| 164 | + tri_nb, (light.reshape(-1, 9, 3))[:, :, 0].unsqueeze(-1).repeat(1, 1, 3) | ||
| 165 | + ) | ||
| 166 | + texture = torch.ones_like(geometry) * 200 | ||
| 167 | + (render,) = render_util.render_mesh( | ||
| 168 | + proj_geo, ori_img, ori_size, texture, nbl, self.tris | ||
| 169 | + ) | ||
| 170 | + return render.view(batch_size, h, w, 3).byte() | ||
| 171 | + | ||
| 172 | + def cal_loss_rgb(self, geometry, euler, trans, cam, ori_img, light, texture, lands): | ||
| 173 | + rott_geo, proj_geo, rot_tri_normal, is_visible, pixel_valid = preprocess_render( | ||
| 174 | + geometry, euler, trans, cam, self.tris, self.vert_tris, ori_img | ||
| 175 | + ) | ||
| 176 | + tri_nb = self.normal_baser(rot_tri_normal.contiguous()) | ||
| 177 | + nbl = torch.bmm(tri_nb, light.reshape(-1, 9, 3)) | ||
| 178 | + render, real = self.renderer( | ||
| 179 | + proj_geo, texture, nbl, ori_img, is_visible, self.tris, pixel_valid | ||
| 180 | + ) | ||
| 181 | + proj_land = cal_land(proj_geo, is_visible, self.lands_info, lands.shape[1]) | ||
| 182 | + col_minus = torch.norm((render - real).reshape(-1, 3), dim=1).reshape( | ||
| 183 | + ori_img.shape[0], -1 | ||
| 184 | + ) | ||
| 185 | + col_dis = torch.mean(col_minus * pixel_valid) / ( | ||
| 186 | + torch.mean(pixel_valid) + 0.00001 | ||
| 187 | + ) | ||
| 188 | + land_dists = torch.norm((proj_land - lands).reshape(-1, 2), dim=1).reshape( | ||
| 189 | + ori_img.shape[0], -1 | ||
| 190 | + ) | ||
| 191 | + lan_dis = torch.mean(land_dists) | ||
| 192 | + return col_dis, lan_dis |
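This module depends on a compiled `render_util` extension that is not part of this diff. Its `preprocess_render` marks a vertex visible when its normal faces the camera at the origin; the dot-product test reduces to the following sketch (the camera-at-origin, negative-z convention is inferred from the tracker's `trans[:, 2] -= 7` initialization):

```python
import torch
import torch.nn as nn

normal = torch.tensor([[[0.0, 0.0, 1.0]]])     # vertex normal pointing back at the camera
position = torch.tensor([[[0.0, 0.0, -5.0]]])  # vertex in front of the camera (negative z)
vis = -torch.bmm(normal.reshape(-1, 1, 3),
                 nn.functional.normalize(position.reshape(-1, 3, 1)))
print(vis.item())  # 1.0 -> visible; values below 0.01 are set to -1 in preprocess_render
```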
ernerf/data_utils/face_tracking/util.py
0 → 100644
| 1 | +import torch | ||
| 2 | +import torch.nn as nn | ||
| 3 | +import torch.nn.functional as F | ||
| 4 | + | ||
| 5 | + | ||
| 6 | +def compute_tri_normal(geometry, tris): | ||
| 7 | + tri_1 = tris[:, 0] | ||
| 8 | + tri_2 = tris[:, 1] | ||
| 9 | + tri_3 = tris[:, 2] | ||
| 10 | + vert_1 = torch.index_select(geometry, 1, tri_1) | ||
| 11 | + vert_2 = torch.index_select(geometry, 1, tri_2) | ||
| 12 | + vert_3 = torch.index_select(geometry, 1, tri_3) | ||
| 13 | + nnorm = torch.cross(vert_2 - vert_1, vert_3 - vert_1, 2) | ||
| 14 | + normal = nn.functional.normalize(nnorm, dim=2)  # normalize over xyz (the last dim) | ||
| 15 | + return normal | ||
| 16 | + | ||
| 17 | + | ||
| 18 | +def euler2rot(euler_angle): | ||
| 19 | + batch_size = euler_angle.shape[0] | ||
| 20 | + theta = euler_angle[:, 0].reshape(-1, 1, 1) | ||
| 21 | + phi = euler_angle[:, 1].reshape(-1, 1, 1) | ||
| 22 | + psi = euler_angle[:, 2].reshape(-1, 1, 1) | ||
| 23 | + one = torch.ones(batch_size, 1, 1).to(euler_angle.device) | ||
| 24 | + zero = torch.zeros(batch_size, 1, 1).to(euler_angle.device) | ||
| 25 | + rot_x = torch.cat( | ||
| 26 | + ( | ||
| 27 | + torch.cat((one, zero, zero), 1), | ||
| 28 | + torch.cat((zero, theta.cos(), theta.sin()), 1), | ||
| 29 | + torch.cat((zero, -theta.sin(), theta.cos()), 1), | ||
| 30 | + ), | ||
| 31 | + 2, | ||
| 32 | + ) | ||
| 33 | + rot_y = torch.cat( | ||
| 34 | + ( | ||
| 35 | + torch.cat((phi.cos(), zero, -phi.sin()), 1), | ||
| 36 | + torch.cat((zero, one, zero), 1), | ||
| 37 | + torch.cat((phi.sin(), zero, phi.cos()), 1), | ||
| 38 | + ), | ||
| 39 | + 2, | ||
| 40 | + ) | ||
| 41 | + rot_z = torch.cat( | ||
| 42 | + ( | ||
| 43 | + torch.cat((psi.cos(), -psi.sin(), zero), 1), | ||
| 44 | + torch.cat((psi.sin(), psi.cos(), zero), 1), | ||
| 45 | + torch.cat((zero, zero, one), 1), | ||
| 46 | + ), | ||
| 47 | + 2, | ||
| 48 | + ) | ||
| 49 | + return torch.bmm(rot_x, torch.bmm(rot_y, rot_z)) | ||
| 50 | + | ||
| 51 | + | ||
| 52 | +def rot_trans_pts(geometry, rot, trans): | ||
| 53 | + rott_geo = torch.bmm(rot, geometry.permute(0, 2, 1)) + trans[:, :, None] | ||
| 54 | + return rott_geo.permute(0, 2, 1) | ||
| 55 | + | ||
| 56 | + | ||
| 57 | +def cal_lap_loss(tensor_list, weight_list): | ||
| 58 | + lap_kernel = ( | ||
| 59 | + torch.Tensor((-0.5, 1.0, -0.5)) | ||
| 60 | + .unsqueeze(0) | ||
| 61 | + .unsqueeze(0) | ||
| 62 | + .float() | ||
| 63 | + .to(tensor_list[0].device) | ||
| 64 | + ) | ||
| 65 | + loss_lap = 0 | ||
| 66 | + for i in range(len(tensor_list)): | ||
| 67 | + in_tensor = tensor_list[i] | ||
| 68 | + in_tensor = in_tensor.view(-1, 1, in_tensor.shape[-1]) | ||
| 69 | + out_tensor = F.conv1d(in_tensor, lap_kernel) | ||
| 70 | + loss_lap += torch.mean(out_tensor ** 2) * weight_list[i] | ||
| 71 | + return loss_lap | ||
| 72 | + | ||
| 73 | + | ||
| 74 | +def proj_pts(rott_geo, focal_length, cxy): | ||
| 75 | + cx, cy = cxy[0], cxy[1] | ||
| 76 | + X = rott_geo[:, :, 0] | ||
| 77 | + Y = rott_geo[:, :, 1] | ||
| 78 | + Z = rott_geo[:, :, 2] | ||
| 79 | + fxX = focal_length * X | ||
| 80 | + fyY = focal_length * Y | ||
| 81 | + proj_x = -fxX / Z + cx | ||
| 82 | + proj_y = fyY / Z + cy | ||
| 83 | + return torch.cat((proj_x[:, :, None], proj_y[:, :, None], Z[:, :, None]), 2) | ||
| 84 | + | ||
| 85 | + | ||
| 86 | +def forward_rott(geometry, euler_angle, trans): | ||
| 87 | + rot = euler2rot(euler_angle) | ||
| 88 | + rott_geo = rot_trans_pts(geometry, rot, trans) | ||
| 89 | + return rott_geo | ||
| 90 | + | ||
| 91 | + | ||
| 92 | +def forward_transform(geometry, euler_angle, trans, focal_length, cxy): | ||
| 93 | + rot = euler2rot(euler_angle) | ||
| 94 | + rott_geo = rot_trans_pts(geometry, rot, trans) | ||
| 95 | + proj_geo = proj_pts(rott_geo, focal_length, cxy) | ||
| 96 | + return proj_geo | ||
| 97 | + | ||
| 98 | + | ||
| 99 | +def cal_lan_loss(proj_lan, gt_lan): | ||
| 100 | + return torch.mean((proj_lan - gt_lan) ** 2) | ||
| 101 | + | ||
| 102 | + | ||
| 103 | +def cal_col_loss(pred_img, gt_img, img_mask): | ||
| 104 | + pred_img = pred_img.float() | ||
| 105 | + # loss = torch.sqrt(torch.sum(torch.square(pred_img - gt_img), 3))*img_mask/255 | ||
| 106 | + loss = (torch.sum(torch.square(pred_img - gt_img), 3)) * img_mask / 255 | ||
| 107 | + loss = torch.sum(loss, dim=(1, 2)) / torch.sum(img_mask, dim=(1, 2)) | ||
| 108 | + loss = torch.mean(loss) | ||
| 109 | + return loss |
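`cal_lap_loss` convolves each trajectory with the kernel `(-0.5, 1, -0.5)`, i.e. half the negative second difference, so constant or linear motion is free while temporal jitter is penalized. A minimal check (assumes this file is importable as `util`):

```python
import torch
from util import cal_lap_loss

smooth = torch.arange(10.0).view(1, 10)      # linear trajectory -> zero second difference
jitter = torch.randn(1, 10)
print(cal_lap_loss([smooth], [1.0]).item())  # 0.0
print(cal_lap_loss([jitter], [1.0]).item())  # > 0
```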
ernerf/data_utils/process.py
0 → 100644
| 1 | +import os | ||
| 2 | +import glob | ||
| 3 | +import tqdm | ||
| 4 | +import json | ||
| 5 | +import argparse | ||
| 6 | +import cv2 | ||
| 7 | +import numpy as np | ||
| 8 | + | ||
| 9 | +def extract_audio(path, out_path, sample_rate=16000): | ||
| 10 | + | ||
| 11 | + print(f'[INFO] ===== extract audio from {path} to {out_path} =====') | ||
| 12 | + cmd = f'ffmpeg -i {path} -f wav -ar {sample_rate} {out_path}' | ||
| 13 | + os.system(cmd) | ||
| 14 | + print(f'[INFO] ===== extracted audio =====') | ||
| 15 | + | ||
| 16 | + | ||
| 17 | +def extract_audio_features(path, mode='wav2vec'): | ||
| 18 | + | ||
| 19 | + print(f'[INFO] ===== extract audio labels for {path} =====') | ||
| 20 | + if mode == 'wav2vec': | ||
| 21 | + cmd = f'python nerf/asr.py --wav {path} --save_feats' | ||
| 22 | + else: # deepspeech | ||
| 23 | + cmd = f'python data_utils/deepspeech_features/extract_ds_features.py --input {path}' | ||
| 24 | + os.system(cmd) | ||
| 25 | + print(f'[INFO] ===== extracted audio labels =====') | ||
| 26 | + | ||
| 27 | + | ||
| 28 | + | ||
| 29 | +def extract_images(path, out_path, fps=25): | ||
| 30 | + | ||
| 31 | + print(f'[INFO] ===== extract images from {path} to {out_path} =====') | ||
| 32 | + cmd = f'ffmpeg -i {path} -vf fps={fps} -qmin 1 -q:v 1 -start_number 0 {os.path.join(out_path, "%d.jpg")}' | ||
| 33 | + os.system(cmd) | ||
| 34 | + print(f'[INFO] ===== extracted images =====') | ||
| 35 | + | ||
| 36 | + | ||
| 37 | +def extract_semantics(ori_imgs_dir, parsing_dir): | ||
| 38 | + | ||
| 39 | + print(f'[INFO] ===== extract semantics from {ori_imgs_dir} to {parsing_dir} =====') | ||
| 40 | + cmd = f'python data_utils/face_parsing/test.py --respath={parsing_dir} --imgpath={ori_imgs_dir}' | ||
| 41 | + os.system(cmd) | ||
| 42 | + print(f'[INFO] ===== extracted semantics =====') | ||
| 43 | + | ||
| 44 | + | ||
| 45 | +def extract_landmarks(ori_imgs_dir): | ||
| 46 | + | ||
| 47 | + print(f'[INFO] ===== extract face landmarks from {ori_imgs_dir} =====') | ||
| 48 | + | ||
| 49 | + import face_alignment | ||
| 50 | + fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=False) | ||
| 51 | + image_paths = glob.glob(os.path.join(ori_imgs_dir, '*.jpg')) | ||
| 52 | + for image_path in tqdm.tqdm(image_paths): | ||
| 53 | + img = cv2.imread(image_path, cv2.IMREAD_UNCHANGED) # [H, W, 3] | ||
| 54 | + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) | ||
| 55 | + preds = fa.get_landmarks(img) | ||
| 56 | + if preds is not None and len(preds) > 0: # get_landmarks returns None when no face is detected | ||
| 57 | + lands = preds[0].reshape(-1, 2)[:,:2] | ||
| 58 | + np.savetxt(image_path.replace('jpg', 'lms'), lands, '%f') | ||
| 59 | + del fa | ||
| 60 | + print(f'[INFO] ===== extracted face landmarks =====') | ||
| 61 | + | ||
| 62 | + | ||
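extract_landmarks writes one <frame>.lms text file per image, holding the 68 detected 2D landmarks as whitespace-separated floats. Reading one back is a one-liner (the path below is illustrative):

import numpy as np
lms = np.loadtxt('data/obama/ori_imgs/0.lms')  # -> [68, 2] array of (x, y) pixel coordinates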
| 63 | +def extract_background(base_dir, ori_imgs_dir): | ||
| 64 | + | ||
| 65 | + print(f'[INFO] ===== extract background image from {ori_imgs_dir} =====') | ||
| 66 | + | ||
| 67 | + from sklearn.neighbors import NearestNeighbors | ||
| 68 | + | ||
| 69 | + image_paths = glob.glob(os.path.join(ori_imgs_dir, '*.jpg')) | ||
| 70 | + # only use 1/20 image_paths | ||
| 71 | + image_paths = image_paths[::20] | ||
| 72 | + # read one image to get H/W | ||
| 73 | + tmp_image = cv2.imread(image_paths[0], cv2.IMREAD_UNCHANGED) # [H, W, 3] | ||
| 74 | + h, w = tmp_image.shape[:2] | ||
| 75 | + | ||
| 76 | + # per-pixel distance to the nearest foreground pixel in each sampled frame | ||
| 77 | + all_xys = np.mgrid[0:h, 0:w].reshape(2, -1).transpose() | ||
| 78 | + distss = [] | ||
| 79 | + for image_path in tqdm.tqdm(image_paths): | ||
| 80 | + parse_img = cv2.imread(image_path.replace('ori_imgs', 'parsing').replace('.jpg', '.png')) | ||
| 81 | + bg = (parse_img[..., 0] == 255) & (parse_img[..., 1] == 255) & (parse_img[..., 2] == 255) | ||
| 82 | + fg_xys = np.stack(np.nonzero(~bg)).transpose(1, 0) | ||
| 83 | + nbrs = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(fg_xys) | ||
| 84 | + dists, _ = nbrs.kneighbors(all_xys) | ||
| 85 | + distss.append(dists) | ||
| 86 | + | ||
| 87 | + distss = np.stack(distss) | ||
| 88 | + max_dist = np.max(distss, 0) | ||
| 89 | + max_id = np.argmax(distss, 0) | ||
| 90 | + | ||
| 91 | + bc_pixs = max_dist > 5 | ||
| 92 | + bc_pixs_id = np.nonzero(bc_pixs) | ||
| 93 | + bc_ids = max_id[bc_pixs] | ||
| 94 | + | ||
| 95 | + imgs = [] | ||
| 96 | + num_pixs = distss.shape[1] | ||
| 97 | + for image_path in image_paths: | ||
| 98 | + img = cv2.imread(image_path) | ||
| 99 | + imgs.append(img) | ||
| 100 | + imgs = np.stack(imgs).reshape(-1, num_pixs, 3) | ||
| 101 | + | ||
| 102 | + bc_img = np.zeros((h*w, 3), dtype=np.uint8) | ||
| 103 | + bc_img[bc_pixs_id, :] = imgs[bc_ids, bc_pixs_id, :] | ||
| 104 | + bc_img = bc_img.reshape(h, w, 3) | ||
| 105 | + | ||
| 106 | + max_dist = max_dist.reshape(h, w) | ||
| 107 | + bc_pixs = max_dist > 5 | ||
| 108 | + bg_xys = np.stack(np.nonzero(~bc_pixs)).transpose() | ||
| 109 | + fg_xys = np.stack(np.nonzero(bc_pixs)).transpose() | ||
| 110 | + nbrs = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(fg_xys) | ||
| 111 | + distances, indices = nbrs.kneighbors(bg_xys) | ||
| 112 | + bg_fg_xys = fg_xys[indices[:, 0]] | ||
| 113 | + bc_img[bg_xys[:, 0], bg_xys[:, 1], :] = bc_img[bg_fg_xys[:, 0], bg_fg_xys[:, 1], :] | ||
| 114 | + | ||
| 115 | + cv2.imwrite(os.path.join(base_dir, 'bc.jpg'), bc_img) | ||
| 116 | + | ||
| 117 | + print(f'[INFO] ===== extracted background image =====') | ||
| 118 | + | ||
| 119 | + | ||
| 120 | +def extract_torso_and_gt(base_dir, ori_imgs_dir): | ||
| 121 | + | ||
| 122 | + print(f'[INFO] ===== extract torso and gt images for {base_dir} =====') | ||
| 123 | + | ||
| 124 | + from scipy.ndimage import binary_erosion, binary_dilation | ||
| 125 | + | ||
| 126 | + # load bg | ||
| 127 | + bg_image = cv2.imread(os.path.join(base_dir, 'bc.jpg'), cv2.IMREAD_UNCHANGED) | ||
| 128 | + | ||
| 129 | + image_paths = glob.glob(os.path.join(ori_imgs_dir, '*.jpg')) | ||
| 130 | + | ||
| 131 | + for image_path in tqdm.tqdm(image_paths): | ||
| 132 | + # read ori image | ||
| 133 | + ori_image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED) # [H, W, 3] | ||
| 134 | + | ||
| 135 | + # read semantics | ||
| 136 | + seg = cv2.imread(image_path.replace('ori_imgs', 'parsing').replace('.jpg', '.png')) | ||
| 137 | + head_part = (seg[..., 0] == 255) & (seg[..., 1] == 0) & (seg[..., 2] == 0) | ||
| 138 | + neck_part = (seg[..., 0] == 0) & (seg[..., 1] == 255) & (seg[..., 2] == 0) | ||
| 139 | + torso_part = (seg[..., 0] == 0) & (seg[..., 1] == 0) & (seg[..., 2] == 255) | ||
| 140 | + bg_part = (seg[..., 0] == 255) & (seg[..., 1] == 255) & (seg[..., 2] == 255) | ||
| 141 | + | ||
| 142 | + # get gt image | ||
| 143 | + gt_image = ori_image.copy() | ||
| 144 | + gt_image[bg_part] = bg_image[bg_part] | ||
| 145 | + cv2.imwrite(image_path.replace('ori_imgs', 'gt_imgs'), gt_image) | ||
| 146 | + | ||
| 147 | + # get torso image | ||
| 148 | + torso_image = gt_image.copy() # rgb | ||
| 149 | + torso_image[head_part] = bg_image[head_part] | ||
| 150 | + torso_alpha = 255 * np.ones((gt_image.shape[0], gt_image.shape[1], 1), dtype=np.uint8) # alpha | ||
| 151 | + | ||
| 152 | + # torso part "vertical" in-painting... | ||
| 153 | + L = 8 + 1 | ||
| 154 | + torso_coords = np.stack(np.nonzero(torso_part), axis=-1) # [M, 2] | ||
| 155 | + # lexsort: sort 2D coords first by y then by x, | ||
| 156 | + # ref: https://stackoverflow.com/questions/2706605/sorting-a-2d-numpy-array-by-multiple-axes | ||
| 157 | + inds = np.lexsort((torso_coords[:, 0], torso_coords[:, 1])) | ||
| 158 | + torso_coords = torso_coords[inds] | ||
| 159 | + # choose the top pixel for each column | ||
| 160 | + u, uid, ucnt = np.unique(torso_coords[:, 1], return_index=True, return_counts=True) | ||
| 161 | + top_torso_coords = torso_coords[uid] # [m, 2] | ||
| 162 | + # only keep top-is-head pixels | ||
| 163 | + top_torso_coords_up = top_torso_coords.copy() - np.array([1, 0]) | ||
| 164 | + mask = head_part[tuple(top_torso_coords_up.T)] | ||
| 165 | + if mask.any(): | ||
| 166 | + top_torso_coords = top_torso_coords[mask] | ||
| 167 | + # get the color | ||
| 168 | + top_torso_colors = gt_image[tuple(top_torso_coords.T)] # [m, 3] | ||
| 169 | + # construct inpaint coords (vertically up, or minus in x) | ||
| 170 | + inpaint_torso_coords = top_torso_coords[None].repeat(L, 0) # [L, m, 2] | ||
| 171 | + inpaint_offsets = np.stack([-np.arange(L), np.zeros(L, dtype=np.int32)], axis=-1)[:, None] # [L, 1, 2] | ||
| 172 | + inpaint_torso_coords += inpaint_offsets | ||
| 173 | + inpaint_torso_coords = inpaint_torso_coords.reshape(-1, 2) # [Lm, 2] | ||
| 174 | + inpaint_torso_colors = top_torso_colors[None].repeat(L, 0) # [L, m, 3] | ||
| 175 | + darken_scaler = 0.98 ** np.arange(L).reshape(L, 1, 1) # [L, 1, 1] | ||
| 176 | + inpaint_torso_colors = (inpaint_torso_colors * darken_scaler).reshape(-1, 3) # [Lm, 3] | ||
| 177 | + # set color | ||
| 178 | + torso_image[tuple(inpaint_torso_coords.T)] = inpaint_torso_colors | ||
| 179 | + | ||
| 180 | + inpaint_torso_mask = np.zeros_like(torso_image[..., 0]).astype(bool) | ||
| 181 | + inpaint_torso_mask[tuple(inpaint_torso_coords.T)] = True | ||
| 182 | + else: | ||
| 183 | + inpaint_torso_mask = None | ||
| 184 | + | ||
| 185 | + | ||
| 186 | + # neck part "vertical" in-painting... | ||
| 187 | + push_down = 4 | ||
| 188 | + L = 48 + push_down + 1 | ||
| 189 | + | ||
| 190 | + neck_part = binary_dilation(neck_part, structure=np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=bool), iterations=3) | ||
| 191 | + | ||
| 192 | + neck_coords = np.stack(np.nonzero(neck_part), axis=-1) # [M, 2] | ||
| 193 | + # lexsort: sort 2D coords first by y then by x, | ||
| 194 | + # ref: https://stackoverflow.com/questions/2706605/sorting-a-2d-numpy-array-by-multiple-axes | ||
| 195 | + inds = np.lexsort((neck_coords[:, 0], neck_coords[:, 1])) | ||
| 196 | + neck_coords = neck_coords[inds] | ||
| 197 | + # choose the top pixel for each column | ||
| 198 | + u, uid, ucnt = np.unique(neck_coords[:, 1], return_index=True, return_counts=True) | ||
| 199 | + top_neck_coords = neck_coords[uid] # [m, 2] | ||
| 200 | + # only keep top-is-head pixels | ||
| 201 | + top_neck_coords_up = top_neck_coords.copy() - np.array([1, 0]) | ||
| 202 | + mask = head_part[tuple(top_neck_coords_up.T)] | ||
| 203 | + | ||
| 204 | + top_neck_coords = top_neck_coords[mask] | ||
| 205 | + # push these top down for 4 pixels to make the neck inpainting more natural... | ||
| 206 | + offset_down = np.minimum(ucnt[mask] - 1, push_down) | ||
| 207 | + top_neck_coords += np.stack([offset_down, np.zeros_like(offset_down)], axis=-1) | ||
| 208 | + # get the color | ||
| 209 | + top_neck_colors = gt_image[tuple(top_neck_coords.T)] # [m, 3] | ||
| 210 | + # construct inpaint coords (vertically up, or minus in x) | ||
| 211 | + inpaint_neck_coords = top_neck_coords[None].repeat(L, 0) # [L, m, 2] | ||
| 212 | + inpaint_offsets = np.stack([-np.arange(L), np.zeros(L, dtype=np.int32)], axis=-1)[:, None] # [L, 1, 2] | ||
| 213 | + inpaint_neck_coords += inpaint_offsets | ||
| 214 | + inpaint_neck_coords = inpaint_neck_coords.reshape(-1, 2) # [Lm, 2] | ||
| 215 | + inpaint_neck_colors = top_neck_colors[None].repeat(L, 0) # [L, m, 3] | ||
| 216 | + darken_scaler = 0.98 ** np.arange(L).reshape(L, 1, 1) # [L, 1, 1] | ||
| 217 | + inpaint_neck_colors = (inpaint_neck_colors * darken_scaler).reshape(-1, 3) # [Lm, 3] | ||
| 218 | + # set color | ||
| 219 | + torso_image[tuple(inpaint_neck_coords.T)] = inpaint_neck_colors | ||
| 220 | + | ||
| 221 | + # apply blurring to the inpaint area to avoid vertical-line artifacts... | ||
| 222 | + inpaint_mask = np.zeros_like(torso_image[..., 0]).astype(bool) | ||
| 223 | + inpaint_mask[tuple(inpaint_neck_coords.T)] = True | ||
| 224 | + | ||
| 225 | + blur_img = torso_image.copy() | ||
| 226 | + blur_img = cv2.GaussianBlur(blur_img, (5, 5), cv2.BORDER_DEFAULT) | ||
| 227 | + | ||
| 228 | + torso_image[inpaint_mask] = blur_img[inpaint_mask] | ||
| 229 | + | ||
| 230 | + # set mask | ||
| 231 | + mask = (neck_part | torso_part | inpaint_mask) | ||
| 232 | + if inpaint_torso_mask is not None: | ||
| 233 | + mask = mask | inpaint_torso_mask | ||
| 234 | + torso_image[~mask] = 0 | ||
| 235 | + torso_alpha[~mask] = 0 | ||
| 236 | + | ||
| 237 | + cv2.imwrite(image_path.replace('ori_imgs', 'torso_imgs').replace('.jpg', '.png'), np.concatenate([torso_image, torso_alpha], axis=-1)) | ||
| 238 | + | ||
| 239 | + print(f'[INFO] ===== extracted torso and gt images =====') | ||
| 240 | + | ||
| 241 | + | ||
| 242 | +def face_tracking(ori_imgs_dir): | ||
| 243 | + | ||
| 244 | + print(f'[INFO] ===== perform face tracking =====') | ||
| 245 | + | ||
| 246 | + image_paths = glob.glob(os.path.join(ori_imgs_dir, '*.jpg')) | ||
| 247 | + | ||
| 248 | + # read one image to get H/W | ||
| 249 | + tmp_image = cv2.imread(image_paths[0], cv2.IMREAD_UNCHANGED) # [H, W, 3] | ||
| 250 | + h, w = tmp_image.shape[:2] | ||
| 251 | + | ||
| 252 | + cmd = f'python data_utils/face_tracking/face_tracker.py --path={ori_imgs_dir} --img_h={h} --img_w={w} --frame_num={len(image_paths)}' | ||
| 253 | + | ||
| 254 | + os.system(cmd) | ||
| 255 | + | ||
| 256 | + print(f'[INFO] ===== finished face tracking =====') | ||
| 257 | + | ||
| 258 | + | ||
| 259 | +def save_transforms(base_dir, ori_imgs_dir): | ||
| 260 | + print(f'[INFO] ===== save transforms =====') | ||
| 261 | + | ||
| 262 | + import torch | ||
| 263 | + | ||
| 264 | + image_paths = glob.glob(os.path.join(ori_imgs_dir, '*.jpg')) | ||
| 265 | + | ||
| 266 | + # read one image to get H/W | ||
| 267 | + tmp_image = cv2.imread(image_paths[0], cv2.IMREAD_UNCHANGED) # [H, W, 3] | ||
| 268 | + h, w = tmp_image.shape[:2] | ||
| 269 | + | ||
| 270 | + params_dict = torch.load(os.path.join(base_dir, 'track_params.pt')) | ||
| 271 | + focal_len = params_dict['focal'] | ||
| 272 | + euler_angle = params_dict['euler'] | ||
| 273 | + trans = params_dict['trans'] / 10.0 | ||
| 274 | + valid_num = euler_angle.shape[0] | ||
| 275 | + | ||
| 276 | + def euler2rot(euler_angle): | ||
| 277 | + batch_size = euler_angle.shape[0] | ||
| 278 | + theta = euler_angle[:, 0].reshape(-1, 1, 1) | ||
| 279 | + phi = euler_angle[:, 1].reshape(-1, 1, 1) | ||
| 280 | + psi = euler_angle[:, 2].reshape(-1, 1, 1) | ||
| 281 | + one = torch.ones((batch_size, 1, 1), dtype=torch.float32, device=euler_angle.device) | ||
| 282 | + zero = torch.zeros((batch_size, 1, 1), dtype=torch.float32, device=euler_angle.device) | ||
| 283 | + rot_x = torch.cat(( | ||
| 284 | + torch.cat((one, zero, zero), 1), | ||
| 285 | + torch.cat((zero, theta.cos(), theta.sin()), 1), | ||
| 286 | + torch.cat((zero, -theta.sin(), theta.cos()), 1), | ||
| 287 | + ), 2) | ||
| 288 | + rot_y = torch.cat(( | ||
| 289 | + torch.cat((phi.cos(), zero, -phi.sin()), 1), | ||
| 290 | + torch.cat((zero, one, zero), 1), | ||
| 291 | + torch.cat((phi.sin(), zero, phi.cos()), 1), | ||
| 292 | + ), 2) | ||
| 293 | + rot_z = torch.cat(( | ||
| 294 | + torch.cat((psi.cos(), -psi.sin(), zero), 1), | ||
| 295 | + torch.cat((psi.sin(), psi.cos(), zero), 1), | ||
| 296 | + torch.cat((zero, zero, one), 1) | ||
| 297 | + ), 2) | ||
| 298 | + return torch.bmm(rot_x, torch.bmm(rot_y, rot_z)) | ||
| 299 | + | ||
| 300 | + | ||
| 301 | + # train_val_split = int(valid_num*0.5) | ||
| 302 | + # train_val_split = valid_num - 25 * 20 # take the last 20s as valid set. | ||
| 303 | + train_val_split = int(valid_num * 10 / 11) | ||
| 304 | + | ||
| 305 | + train_ids = torch.arange(0, train_val_split) | ||
| 306 | + val_ids = torch.arange(train_val_split, valid_num) | ||
| 307 | + | ||
| 308 | + rot = euler2rot(euler_angle) | ||
| 309 | + rot_inv = rot.permute(0, 2, 1) | ||
| 310 | + trans_inv = -torch.bmm(rot_inv, trans.unsqueeze(2)) | ||
| 311 | + | ||
| 312 | + pose = torch.eye(4, dtype=torch.float32) | ||
| 313 | + save_ids = ['train', 'val'] | ||
| 314 | + train_val_ids = [train_ids, val_ids] | ||
| 315 | + mean_z = -float(torch.mean(trans[:, 2]).item()) | ||
| 316 | + | ||
| 317 | + for split in range(2): | ||
| 318 | + transform_dict = dict() | ||
| 319 | + transform_dict['focal_len'] = float(focal_len[0]) | ||
| 320 | + transform_dict['cx'] = float(w/2.0) | ||
| 321 | + transform_dict['cy'] = float(h/2.0) | ||
| 322 | + transform_dict['frames'] = [] | ||
| 323 | + ids = train_val_ids[split] | ||
| 324 | + save_id = save_ids[split] | ||
| 325 | + | ||
| 326 | + for i in ids: | ||
| 327 | + i = i.item() | ||
| 328 | + frame_dict = dict() | ||
| 329 | + frame_dict['img_id'] = i | ||
| 330 | + frame_dict['aud_id'] = i | ||
| 331 | + | ||
| 332 | + pose[:3, :3] = rot_inv[i] | ||
| 333 | + pose[:3, 3] = trans_inv[i, :, 0] | ||
| 334 | + | ||
| 335 | + frame_dict['transform_matrix'] = pose.numpy().tolist() | ||
| 336 | + | ||
| 337 | + transform_dict['frames'].append(frame_dict) | ||
| 338 | + | ||
| 339 | + with open(os.path.join(base_dir, 'transforms_' + save_id + '.json'), 'w') as fp: | ||
| 340 | + json.dump(transform_dict, fp, indent=2, separators=(',', ': ')) | ||
| 341 | + | ||
| 342 | + print(f'[INFO] ===== finished saving transforms =====') | ||
| 343 | + | ||
| 344 | + | ||
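Each frame's transform_matrix written above is a 4x4 camera-to-world pose: the tracked world-to-camera rotation/translation is inverted via rot_inv = R^T and trans_inv = -R^T t. A sketch of reading the result back, with an illustrative data path:

import json
import numpy as np

with open('data/obama/transforms_train.json') as fp:
    meta = json.load(fp)

f = meta['focal_len']
K = np.array([[f, 0, meta['cx']],
              [0, f, meta['cy']],
              [0, 0, 1]])                                # pinhole intrinsics
c2w = np.array(meta['frames'][0]['transform_matrix'])    # [4, 4] camera-to-world pose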
| 345 | +if __name__ == '__main__': | ||
| 346 | + parser = argparse.ArgumentParser() | ||
| 347 | + parser.add_argument('path', type=str, help="path to video file") | ||
| 348 | + parser.add_argument('--task', type=int, default=-1, help="-1 means all") | ||
| 349 | + parser.add_argument('--asr', type=str, default='wav2vec', help="wav2vec or deepspeech") | ||
| 350 | + | ||
| 351 | + opt = parser.parse_args() | ||
| 352 | + | ||
| 353 | + base_dir = os.path.dirname(opt.path) | ||
| 354 | + | ||
| 355 | + wav_path = os.path.join(base_dir, 'aud.wav') | ||
| 356 | + ori_imgs_dir = os.path.join(base_dir, 'ori_imgs') | ||
| 357 | + parsing_dir = os.path.join(base_dir, 'parsing') | ||
| 358 | + gt_imgs_dir = os.path.join(base_dir, 'gt_imgs') | ||
| 359 | + torso_imgs_dir = os.path.join(base_dir, 'torso_imgs') | ||
| 360 | + | ||
| 361 | + os.makedirs(ori_imgs_dir, exist_ok=True) | ||
| 362 | + os.makedirs(parsing_dir, exist_ok=True) | ||
| 363 | + os.makedirs(gt_imgs_dir, exist_ok=True) | ||
| 364 | + os.makedirs(torso_imgs_dir, exist_ok=True) | ||
| 365 | + | ||
| 366 | + | ||
| 367 | + # extract audio | ||
| 368 | + if opt.task == -1 or opt.task == 1: | ||
| 369 | + extract_audio(opt.path, wav_path) | ||
| 370 | + | ||
| 371 | + # extract audio features | ||
| 372 | + if opt.task == -1 or opt.task == 2: | ||
| 373 | + extract_audio_features(wav_path, mode=opt.asr) | ||
| 374 | + | ||
| 375 | + # extract images | ||
| 376 | + if opt.task == -1 or opt.task == 3: | ||
| 377 | + extract_images(opt.path, ori_imgs_dir) | ||
| 378 | + | ||
| 379 | + # face parsing | ||
| 380 | + if opt.task == -1 or opt.task == 4: | ||
| 381 | + extract_semantics(ori_imgs_dir, parsing_dir) | ||
| 382 | + | ||
| 383 | + # extract bg | ||
| 384 | + if opt.task == -1 or opt.task == 5: | ||
| 385 | + extract_background(base_dir, ori_imgs_dir) | ||
| 386 | + | ||
| 387 | + # extract torso images and gt_images | ||
| 388 | + if opt.task == -1 or opt.task == 6: | ||
| 389 | + extract_torso_and_gt(base_dir, ori_imgs_dir) | ||
| 390 | + | ||
| 391 | + # extract face landmarks | ||
| 392 | + if opt.task == -1 or opt.task == 7: | ||
| 393 | + extract_landmarks(ori_imgs_dir) | ||
| 394 | + | ||
| 395 | + # face tracking | ||
| 396 | + if opt.task == -1 or opt.task == 8: | ||
| 397 | + face_tracking(ori_imgs_dir) | ||
| 398 | + | ||
| 399 | + # save transforms.json | ||
| 400 | + if opt.task == -1 or opt.task == 9: | ||
| 401 | + save_transforms(base_dir, ori_imgs_dir) | ||
| 402 | + |
ernerf/encoding.py
0 → 100644
| 1 | +import torch | ||
| 2 | +import torch.nn as nn | ||
| 3 | +import torch.nn.functional as F | ||
| 4 | + | ||
| 5 | + | ||
| 6 | +def get_encoder(encoding, input_dim=3, | ||
| 7 | + multires=6, | ||
| 8 | + degree=4, | ||
| 9 | + num_levels=16, level_dim=2, base_resolution=16, log2_hashmap_size=19, desired_resolution=2048, align_corners=False, | ||
| 10 | + **kwargs): | ||
| 11 | + | ||
| 12 | + if encoding == 'None': | ||
| 13 | + return lambda x, **kwargs: x, input_dim | ||
| 14 | + | ||
| 15 | + elif encoding == 'frequency': | ||
| 16 | + from .freqencoder import FreqEncoder | ||
| 17 | + encoder = FreqEncoder(input_dim=input_dim, degree=multires) | ||
| 18 | + | ||
| 19 | + elif encoding == 'spherical_harmonics': | ||
| 20 | + from .shencoder import SHEncoder | ||
| 21 | + encoder = SHEncoder(input_dim=input_dim, degree=degree) | ||
| 22 | + | ||
| 23 | + elif encoding == 'hashgrid': | ||
| 24 | + from .gridencoder import GridEncoder | ||
| 25 | + encoder = GridEncoder(input_dim=input_dim, num_levels=num_levels, level_dim=level_dim, base_resolution=base_resolution, log2_hashmap_size=log2_hashmap_size, desired_resolution=desired_resolution, gridtype='hash', align_corners=align_corners) | ||
| 26 | + | ||
| 27 | + elif encoding == 'tiledgrid': | ||
| 28 | + from .gridencoder import GridEncoder | ||
| 29 | + encoder = GridEncoder(input_dim=input_dim, num_levels=num_levels, level_dim=level_dim, base_resolution=base_resolution, log2_hashmap_size=log2_hashmap_size, desired_resolution=desired_resolution, gridtype='tiled', align_corners=align_corners) | ||
| 30 | + | ||
| 31 | + elif encoding == 'ash': | ||
| 32 | + from .ashencoder import AshEncoder | ||
| 33 | + encoder = AshEncoder(input_dim=input_dim, output_dim=16, log2_hashmap_size=log2_hashmap_size, resolution=desired_resolution) | ||
| 34 | + | ||
| 35 | + else: | ||
| 36 | + raise NotImplementedError('Unknown encoding mode, choose from [None, frequency, spherical_harmonics, hashgrid, tiledgrid, ash]') | ||
| 37 | + | ||
| 38 | + return encoder, encoder.output_dim |
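A usage sketch for get_encoder (values are illustrative, and the grid extension below must be built first):

import torch

encoder, in_dim = get_encoder('hashgrid', input_dim=3, desired_resolution=2048)
encoder = encoder.cuda()                         # the grid kernels are CUDA-only
x = torch.rand(4096, 3, device='cuda') * 2 - 1   # query positions in [-bound, bound]
feats = encoder(x, bound=1)                      # -> [4096, in_dim] == [4096, 32] with the defaults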
ernerf/freqencoder/__init__.py
0 → 100644
| 1 | +from .freq import FreqEncoder |
ernerf/freqencoder/backend.py
0 → 100644
| 1 | +import os | ||
| 2 | +from torch.utils.cpp_extension import load | ||
| 3 | + | ||
| 4 | +_src_path = os.path.dirname(os.path.abspath(__file__)) | ||
| 5 | + | ||
| 6 | +nvcc_flags = [ | ||
| 7 | + '-O3', '-std=c++17', | ||
| 8 | + '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', '-allow-unsupported-compiler', | ||
| 9 | + '-use_fast_math' | ||
| 10 | +] | ||
| 11 | + | ||
| 12 | +if os.name == "posix": | ||
| 13 | + c_flags = ['-O3', '-std=c++17'] | ||
| 14 | +elif os.name == "nt": | ||
| 15 | + c_flags = ['/O2', '/std:c++17'] | ||
| 16 | + | ||
| 17 | + # find cl.exe | ||
| 18 | + def find_cl_path(): | ||
| 19 | + import glob | ||
| 20 | + for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: | ||
| 21 | + paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) | ||
| 22 | + if paths: | ||
| 23 | + return paths[0] | ||
| 24 | + | ||
| 25 | + # If cl.exe is not on path, try to find it. | ||
| 26 | + if os.system("where cl.exe >nul 2>nul") != 0: | ||
| 27 | + cl_path = find_cl_path() | ||
| 28 | + if cl_path is None: | ||
| 29 | + raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") | ||
| 30 | + os.environ["PATH"] += ";" + cl_path | ||
| 31 | + | ||
| 32 | +_backend = load(name='_freqencoder', | ||
| 33 | + extra_cflags=c_flags, | ||
| 34 | + extra_cuda_cflags=nvcc_flags, | ||
| 35 | + sources=[os.path.join(_src_path, 'src', f) for f in [ | ||
| 36 | + 'freqencoder.cu', | ||
| 37 | + 'bindings.cpp', | ||
| 38 | + ]], | ||
| 39 | + ) | ||
| 40 | + | ||
| 41 | +__all__ = ['_backend'] |
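Two build paths exist side by side: this backend.py JIT-compiles the extension with torch.utils.cpp_extension.load on first import (slow once, then cached), while the package's setup.py builds the same two sources ahead of time via `pip install .`. The `try: import _freqencoder` in freq.py prefers the installed module and falls back to the JIT build.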
ernerf/freqencoder/freq.py
0 → 100644
| 1 | +import numpy as np | ||
| 2 | + | ||
| 3 | +import torch | ||
| 4 | +import torch.nn as nn | ||
| 5 | +from torch.autograd import Function | ||
| 6 | +from torch.autograd.function import once_differentiable | ||
| 7 | +from torch.cuda.amp import custom_bwd, custom_fwd | ||
| 8 | + | ||
| 9 | +try: | ||
| 10 | + import _freqencoder as _backend | ||
| 11 | +except ImportError: | ||
| 12 | + from .backend import _backend | ||
| 13 | + | ||
| 14 | + | ||
| 15 | +class _freq_encoder(Function): | ||
| 16 | + @staticmethod | ||
| 17 | + @custom_fwd(cast_inputs=torch.float32) # force float32 for better precision | ||
| 18 | + def forward(ctx, inputs, degree, output_dim): | ||
| 19 | + # inputs: [B, input_dim], float | ||
| 20 | + # RETURN: [B, F], float | ||
| 21 | + | ||
| 22 | + if not inputs.is_cuda: inputs = inputs.cuda() | ||
| 23 | + inputs = inputs.contiguous() | ||
| 24 | + | ||
| 25 | + B, input_dim = inputs.shape # batch size, coord dim | ||
| 26 | + | ||
| 27 | + outputs = torch.empty(B, output_dim, dtype=inputs.dtype, device=inputs.device) | ||
| 28 | + | ||
| 29 | + _backend.freq_encode_forward(inputs, B, input_dim, degree, output_dim, outputs) | ||
| 30 | + | ||
| 31 | + ctx.save_for_backward(inputs, outputs) | ||
| 32 | + ctx.dims = [B, input_dim, degree, output_dim] | ||
| 33 | + | ||
| 34 | + return outputs | ||
| 35 | + | ||
| 36 | + @staticmethod | ||
| 37 | + #@once_differentiable | ||
| 38 | + @custom_bwd | ||
| 39 | + def backward(ctx, grad): | ||
| 40 | + # grad: [B, output_dim] | ||
| 41 | + | ||
| 42 | + grad = grad.contiguous() | ||
| 43 | + inputs, outputs = ctx.saved_tensors | ||
| 44 | + B, input_dim, degree, output_dim = ctx.dims | ||
| 45 | + | ||
| 46 | + grad_inputs = torch.zeros_like(inputs) | ||
| 47 | + _backend.freq_encode_backward(grad, outputs, B, input_dim, degree, output_dim, grad_inputs) | ||
| 48 | + | ||
| 49 | + return grad_inputs, None, None | ||
| 50 | + | ||
| 51 | + | ||
| 52 | +freq_encode = _freq_encoder.apply | ||
| 53 | + | ||
| 54 | + | ||
| 55 | +class FreqEncoder(nn.Module): | ||
| 56 | + def __init__(self, input_dim=3, degree=4): | ||
| 57 | + super().__init__() | ||
| 58 | + | ||
| 59 | + self.input_dim = input_dim | ||
| 60 | + self.degree = degree | ||
| 61 | + self.output_dim = input_dim + input_dim * 2 * degree | ||
| 62 | + | ||
| 63 | + def __repr__(self): | ||
| 64 | + return f"FreqEncoder: input_dim={self.input_dim} degree={self.degree} output_dim={self.output_dim}" | ||
| 65 | + | ||
| 66 | + def forward(self, inputs, **kwargs): | ||
| 67 | + # inputs: [..., input_dim] | ||
| 68 | + # return: [..., output_dim] | ||
| 69 | + | ||
| 70 | + prefix_shape = list(inputs.shape[:-1]) | ||
| 71 | + inputs = inputs.reshape(-1, self.input_dim) | ||
| 72 | + | ||
| 73 | + outputs = freq_encode(inputs, self.degree, self.output_dim) | ||
| 74 | + | ||
| 75 | + outputs = outputs.reshape(prefix_shape + [self.output_dim]) | ||
| 76 | + | ||
| 77 | + return outputs |
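For reference, the channel layout the CUDA kernel produces is [x, sin(2^0 x), cos(2^0 x), ..., sin(2^(deg-1) x), cos(2^(deg-1) x)], where the cosine comes from the kernel's pi/2 phase shift. A plain-PyTorch equivalent, offered as an illustrative sketch rather than the repo's code path:

import torch

def freq_encode_reference(x: torch.Tensor, degree: int) -> torch.Tensor:
    # x: [B, D] -> [B, D + 2 * degree * D], matching FreqEncoder.output_dim
    outs = [x]
    for f in range(degree):
        scaled = (2.0 ** f) * x        # scalbnf(x, f) in the CUDA kernel
        outs += [torch.sin(scaled), torch.cos(scaled)]
    return torch.cat(outs, dim=-1)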
ernerf/freqencoder/setup.py
0 → 100644
| 1 | +import os | ||
| 2 | +from setuptools import setup | ||
| 3 | +from torch.utils.cpp_extension import BuildExtension, CUDAExtension | ||
| 4 | + | ||
| 5 | +_src_path = os.path.dirname(os.path.abspath(__file__)) | ||
| 6 | + | ||
| 7 | +nvcc_flags = [ | ||
| 8 | + '-O3', '-std=c++17', | ||
| 9 | + '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', '-allow-unsupported-compiler', | ||
| 10 | + '-use_fast_math' | ||
| 11 | +] | ||
| 12 | + | ||
| 13 | +if os.name == "posix": | ||
| 14 | + c_flags = ['-O3', '-std=c++17'] | ||
| 15 | +elif os.name == "nt": | ||
| 16 | + c_flags = ['/O2', '/std:c++17'] | ||
| 17 | + | ||
| 18 | + # find cl.exe | ||
| 19 | + def find_cl_path(): | ||
| 20 | + import glob | ||
| 21 | + for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: | ||
| 22 | + paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) | ||
| 23 | + if paths: | ||
| 24 | + return paths[0] | ||
| 25 | + | ||
| 26 | + # If cl.exe is not on path, try to find it. | ||
| 27 | + if os.system("where cl.exe >nul 2>nul") != 0: | ||
| 28 | + cl_path = find_cl_path() | ||
| 29 | + if cl_path is None: | ||
| 30 | + raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") | ||
| 31 | + os.environ["PATH"] += ";" + cl_path | ||
| 32 | + | ||
| 33 | +setup( | ||
| 34 | + name='freqencoder', # package name, import this to use python API | ||
| 35 | + ext_modules=[ | ||
| 36 | + CUDAExtension( | ||
| 37 | + name='_freqencoder', # extension name, import this to use CUDA API | ||
| 38 | + sources=[os.path.join(_src_path, 'src', f) for f in [ | ||
| 39 | + 'freqencoder.cu', | ||
| 40 | + 'bindings.cpp', | ||
| 41 | + ]], | ||
| 42 | + extra_compile_args={ | ||
| 43 | + 'cxx': c_flags, | ||
| 44 | + 'nvcc': nvcc_flags, | ||
| 45 | + } | ||
| 46 | + ), | ||
| 47 | + ], | ||
| 48 | + cmdclass={ | ||
| 49 | + 'build_ext': BuildExtension, | ||
| 50 | + } | ||
| 51 | +) |
ernerf/freqencoder/src/bindings.cpp
0 → 100644
ernerf/freqencoder/src/freqencoder.cu
0 → 100644
| 1 | +#include <stdint.h> | ||
| 2 | + | ||
| 3 | +#include <cuda.h> | ||
| 4 | +#include <cuda_fp16.h> | ||
| 5 | +#include <cuda_runtime.h> | ||
| 6 | + | ||
| 7 | +#include <ATen/cuda/CUDAContext.h> | ||
| 8 | +#include <torch/torch.h> | ||
| 9 | + | ||
| 10 | +#include <algorithm> | ||
| 11 | +#include <stdexcept> | ||
| 12 | + | ||
| 13 | +#include <cstdio> | ||
| 14 | + | ||
| 15 | + | ||
| 16 | +#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") | ||
| 17 | +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor") | ||
| 18 | +#define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor") | ||
| 19 | +#define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor") | ||
| 20 | + | ||
| 21 | +inline constexpr __device__ float PI() { return 3.141592653589793f; } | ||
| 22 | + | ||
| 23 | +template <typename T> | ||
| 24 | +__host__ __device__ T div_round_up(T val, T divisor) { | ||
| 25 | + return (val + divisor - 1) / divisor; | ||
| 26 | +} | ||
| 27 | + | ||
| 28 | +// inputs: [B, D] | ||
| 29 | +// outputs: [B, C], C = D + D * deg * 2 | ||
| 30 | +__global__ void kernel_freq( | ||
| 31 | + const float * __restrict__ inputs, | ||
| 32 | + uint32_t B, uint32_t D, uint32_t deg, uint32_t C, | ||
| 33 | + float * outputs | ||
| 34 | +) { | ||
| 35 | + // parallel on per-element | ||
| 36 | + const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x; | ||
| 37 | + if (t >= B * C) return; | ||
| 38 | + | ||
| 39 | + // get index | ||
| 40 | + const uint32_t b = t / C; | ||
| 41 | + const uint32_t c = t - b * C; // t % C; | ||
| 42 | + | ||
| 43 | + // locate | ||
| 44 | + inputs += b * D; | ||
| 45 | + outputs += t; | ||
| 46 | + | ||
| 47 | + // write self | ||
| 48 | + if (c < D) { | ||
| 49 | + outputs[0] = inputs[c]; | ||
| 50 | + // write freq | ||
| 51 | + } else { | ||
| 52 | + const uint32_t col = c / D - 1; | ||
| 53 | + const uint32_t d = c % D; | ||
| 54 | + const uint32_t freq = col / 2; | ||
| 55 | + const float phase_shift = (col % 2) * (PI() / 2); | ||
| 56 | + outputs[0] = __sinf(scalbnf(inputs[d], freq) + phase_shift); | ||
| 57 | + } | ||
| 58 | +} | ||
| 59 | + | ||
| 60 | +// grad: [B, C], C = D + D * deg * 2 | ||
| 61 | +// outputs: [B, C] | ||
| 62 | +// grad_inputs: [B, D] | ||
| 63 | +__global__ void kernel_freq_backward( | ||
| 64 | + const float * __restrict__ grad, | ||
| 65 | + const float * __restrict__ outputs, | ||
| 66 | + uint32_t B, uint32_t D, uint32_t deg, uint32_t C, | ||
| 67 | + float * grad_inputs | ||
| 68 | +) { | ||
| 69 | + // parallel on per-element | ||
| 70 | + const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x; | ||
| 71 | + if (t >= B * D) return; | ||
| 72 | + | ||
| 73 | + const uint32_t b = t / D; | ||
| 74 | + const uint32_t d = t - b * D; // t % D; | ||
| 75 | + | ||
| 76 | + // locate | ||
| 77 | + grad += b * C; | ||
| 78 | + outputs += b * C; | ||
| 79 | + grad_inputs += t; | ||
| 80 | + | ||
| 81 | + // register | ||
| 82 | + float result = grad[d]; | ||
| 83 | + grad += D; | ||
| 84 | + outputs += D; | ||
| 85 | + | ||
| 86 | + for (uint32_t f = 0; f < deg; f++) { | ||
| 87 | + result += scalbnf(1.0f, f) * (grad[d] * outputs[D + d] - grad[D + d] * outputs[d]); | ||
| 88 | + grad += 2 * D; | ||
| 89 | + outputs += 2 * D; | ||
| 90 | + } | ||
| 91 | + | ||
| 92 | + // write | ||
| 93 | + grad_inputs[0] = result; | ||
| 94 | +} | ||
| 95 | + | ||
| 96 | + | ||
| 97 | +void freq_encode_forward(at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor outputs) { | ||
| 98 | + CHECK_CUDA(inputs); | ||
| 99 | + CHECK_CUDA(outputs); | ||
| 100 | + | ||
| 101 | + CHECK_CONTIGUOUS(inputs); | ||
| 102 | + CHECK_CONTIGUOUS(outputs); | ||
| 103 | + | ||
| 104 | + CHECK_IS_FLOATING(inputs); | ||
| 105 | + CHECK_IS_FLOATING(outputs); | ||
| 106 | + | ||
| 107 | + static constexpr uint32_t N_THREADS = 128; | ||
| 108 | + | ||
| 109 | + kernel_freq<<<div_round_up(B * C, N_THREADS), N_THREADS>>>(inputs.data_ptr<float>(), B, D, deg, C, outputs.data_ptr<float>()); | ||
| 110 | +} | ||
| 111 | + | ||
| 112 | + | ||
| 113 | +void freq_encode_backward(at::Tensor grad, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor grad_inputs) { | ||
| 114 | + CHECK_CUDA(grad); | ||
| 115 | + CHECK_CUDA(outputs); | ||
| 116 | + CHECK_CUDA(grad_inputs); | ||
| 117 | + | ||
| 118 | + CHECK_CONTIGUOUS(grad); | ||
| 119 | + CHECK_CONTIGUOUS(outputs); | ||
| 120 | + CHECK_CONTIGUOUS(grad_inputs); | ||
| 121 | + | ||
| 122 | + CHECK_IS_FLOATING(grad); | ||
| 123 | + CHECK_IS_FLOATING(outputs); | ||
| 124 | + CHECK_IS_FLOATING(grad_inputs); | ||
| 125 | + | ||
| 126 | + static constexpr uint32_t N_THREADS = 128; | ||
| 127 | + | ||
| 128 | + kernel_freq_backward<<<div_round_up(B * D, N_THREADS), N_THREADS>>>(grad.data_ptr<float>(), outputs.data_ptr<float>(), B, D, deg, C, grad_inputs.data_ptr<float>()); | ||
| 129 | +} |
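The backward kernel leans on the sin/cos pairing: the forward pass stores y_sin = sin(2^f x) and y_cos = cos(2^f x) in adjacent channel blocks, and

d/dx sin(2^f x) =  2^f cos(2^f x) =  2^f * y_cos
d/dx cos(2^f x) = -2^f sin(2^f x) = -2^f * y_sin

so each frequency contributes 2^f * (g_sin * y_cos - g_cos * y_sin) to the input gradient. That is exactly the `scalbnf(1.0f, f) * (grad[d] * outputs[D + d] - grad[D + d] * outputs[d])` accumulation in kernel_freq_backward: the saved outputs stand in for the trig values, so nothing is recomputed.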
ernerf/freqencoder/src/freqencoder.h
0 → 100644
| 1 | +# pragma once | ||
| 2 | + | ||
| 3 | +#include <stdint.h> | ||
| 4 | +#include <torch/torch.h> | ||
| 5 | + | ||
| 6 | +// _backend.freq_encode_forward(inputs, B, input_dim, degree, output_dim, outputs) | ||
| 7 | +void freq_encode_forward(at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor outputs); | ||
| 8 | + | ||
| 9 | +// _backend.freq_encode_backward(grad, outputs, B, input_dim, degree, output_dim, grad_inputs) | ||
| 10 | +void freq_encode_backward(at::Tensor grad, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor grad_inputs); |
ernerf/gridencoder/__init__.py
0 → 100644
| 1 | +from .grid import GridEncoder |
ernerf/gridencoder/backend.py
0 → 100644
| 1 | +import os | ||
| 2 | +from torch.utils.cpp_extension import load | ||
| 3 | + | ||
| 4 | +_src_path = os.path.dirname(os.path.abspath(__file__)) | ||
| 5 | + | ||
| 6 | +nvcc_flags = [ | ||
| 7 | + '-O3', '-std=c++17', | ||
| 8 | + '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', | ||
| 9 | +] | ||
| 10 | + | ||
| 11 | +if os.name == "posix": | ||
| 12 | + c_flags = ['-O3', '-std=c++17', '-finput-charset=UTF-8'] | ||
| 13 | +elif os.name == "nt": | ||
| 14 | + c_flags = ['/O2', '/std:c++17', '/utf-8'] # MSVC's equivalent of GCC's -finput-charset=UTF-8 | ||
| 15 | + | ||
| 16 | + # find cl.exe | ||
| 17 | + def find_cl_path(): | ||
| 18 | + import glob | ||
| 19 | + for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: | ||
| 20 | + paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) | ||
| 21 | + if paths: | ||
| 22 | + return paths[0] | ||
| 23 | + | ||
| 24 | + # If cl.exe is not on path, try to find it. | ||
| 25 | + if os.system("where cl.exe >nul 2>nul") != 0: | ||
| 26 | + cl_path = find_cl_path() | ||
| 27 | + if cl_path is None: | ||
| 28 | + raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") | ||
| 29 | + os.environ["PATH"] += ";" + cl_path | ||
| 30 | + | ||
| 31 | +_backend = load(name='_grid_encoder', | ||
| 32 | + extra_cflags=c_flags, | ||
| 33 | + extra_cuda_cflags=nvcc_flags, | ||
| 34 | + sources=[os.path.join(_src_path, 'src', f) for f in [ | ||
| 35 | + 'gridencoder.cu', | ||
| 36 | + 'bindings.cpp', | ||
| 37 | + ]], | ||
| 38 | + ) | ||
| 39 | + | ||
| 40 | +__all__ = ['_backend'] |
ernerf/gridencoder/grid.py
0 → 100644
| 1 | +import numpy as np | ||
| 2 | + | ||
| 3 | +import torch | ||
| 4 | +import torch.nn as nn | ||
| 5 | +from torch.autograd import Function | ||
| 6 | +from torch.autograd.function import once_differentiable | ||
| 7 | +from torch.cuda.amp import custom_bwd, custom_fwd | ||
| 8 | + | ||
| 9 | +try: | ||
| 10 | + import _gridencoder as _backend | ||
| 11 | +except ImportError: | ||
| 12 | + from .backend import _backend | ||
| 13 | + | ||
| 14 | +_gridtype_to_id = { | ||
| 15 | + 'hash': 0, | ||
| 16 | + 'tiled': 1, | ||
| 17 | +} | ||
| 18 | + | ||
| 19 | +class _grid_encode(Function): | ||
| 20 | + @staticmethod | ||
| 21 | + @custom_fwd | ||
| 22 | + def forward(ctx, inputs, embeddings, offsets, per_level_scale, base_resolution, calc_grad_inputs=False, gridtype=0, align_corners=False): | ||
| 23 | + # inputs: [B, D], float in [0, 1] | ||
| 24 | + # embeddings: [sO, C], float | ||
| 25 | + # offsets: [L + 1], int | ||
| 26 | + # RETURN: [B, F], float | ||
| 27 | + | ||
| 28 | + inputs = inputs.float().contiguous() | ||
| 29 | + | ||
| 30 | + B, D = inputs.shape # batch size, coord dim | ||
| 31 | + L = offsets.shape[0] - 1 # level | ||
| 32 | + C = embeddings.shape[1] # embedding dim for each level | ||
| 33 | + S = np.log2(per_level_scale) # resolution multiplier at each level, apply log2 for later CUDA exp2f | ||
| 34 | + H = base_resolution # base resolution | ||
| 35 | + | ||
| 36 | + # manually handle autocast (only use half precision embeddings, inputs must be float for enough precision) | ||
| 37 | + # if C % 2 != 0, force float, since half for atomicAdd is very slow. | ||
| 38 | + if torch.is_autocast_enabled() and C % 2 == 0: | ||
| 39 | + embeddings = embeddings.to(torch.half) | ||
| 40 | + | ||
| 41 | + # L first, optimize cache for cuda kernel, but needs an extra permute later | ||
| 42 | + outputs = torch.empty(L, B, C, device=inputs.device, dtype=embeddings.dtype) | ||
| 43 | + | ||
| 44 | + if calc_grad_inputs: | ||
| 45 | + dy_dx = torch.empty(B, L * D * C, device=inputs.device, dtype=embeddings.dtype) | ||
| 46 | + else: | ||
| 47 | + dy_dx = None | ||
| 48 | + | ||
| 49 | + _backend.grid_encode_forward(inputs, embeddings, offsets, outputs, B, D, C, L, S, H, dy_dx, gridtype, align_corners) | ||
| 50 | + | ||
| 51 | + # permute back to [B, L * C] | ||
| 52 | + outputs = outputs.permute(1, 0, 2).reshape(B, L * C) | ||
| 53 | + | ||
| 54 | + ctx.save_for_backward(inputs, embeddings, offsets, dy_dx) | ||
| 55 | + ctx.dims = [B, D, C, L, S, H, gridtype] | ||
| 56 | + ctx.align_corners = align_corners | ||
| 57 | + | ||
| 58 | + return outputs | ||
| 59 | + | ||
| 60 | + @staticmethod | ||
| 61 | + #@once_differentiable | ||
| 62 | + @custom_bwd | ||
| 63 | + def backward(ctx, grad): | ||
| 64 | + | ||
| 65 | + inputs, embeddings, offsets, dy_dx = ctx.saved_tensors | ||
| 66 | + B, D, C, L, S, H, gridtype = ctx.dims | ||
| 67 | + align_corners = ctx.align_corners | ||
| 68 | + | ||
| 69 | + # grad: [B, L * C] --> [L, B, C] | ||
| 70 | + grad = grad.view(B, L, C).permute(1, 0, 2).contiguous() | ||
| 71 | + | ||
| 72 | + grad_embeddings = torch.zeros_like(embeddings) | ||
| 73 | + | ||
| 74 | + if dy_dx is not None: | ||
| 75 | + grad_inputs = torch.zeros_like(inputs, dtype=embeddings.dtype) | ||
| 76 | + else: | ||
| 77 | + grad_inputs = None | ||
| 78 | + | ||
| 79 | + _backend.grid_encode_backward(grad, inputs, embeddings, offsets, grad_embeddings, B, D, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners) | ||
| 80 | + | ||
| 81 | + if dy_dx is not None: | ||
| 82 | + grad_inputs = grad_inputs.to(inputs.dtype) | ||
| 83 | + | ||
| 84 | + return grad_inputs, grad_embeddings, None, None, None, None, None, None | ||
| 85 | + | ||
| 86 | + | ||
| 87 | + | ||
| 88 | +grid_encode = _grid_encode.apply | ||
| 89 | + | ||
| 90 | + | ||
| 91 | +class GridEncoder(nn.Module): | ||
| 92 | + def __init__(self, input_dim=3, num_levels=16, level_dim=2, per_level_scale=2, base_resolution=16, log2_hashmap_size=19, desired_resolution=None, gridtype='hash', align_corners=False): | ||
| 93 | + super().__init__() | ||
| 94 | + | ||
| 95 | + # the finest resolution desired at the last level, if provided, overrides per_level_scale | ||
| 96 | + if desired_resolution is not None: | ||
| 97 | + per_level_scale = np.exp2(np.log2(desired_resolution / base_resolution) / (num_levels - 1)) | ||
| 98 | + | ||
| 99 | + self.input_dim = input_dim # coord dims, 2 or 3 | ||
| 100 | + self.num_levels = num_levels # num levels, each level multiply resolution by 2 | ||
| 101 | + self.level_dim = level_dim # encode channels per level | ||
| 102 | + self.per_level_scale = per_level_scale # multiply resolution by this scale at each level. | ||
| 103 | + self.log2_hashmap_size = log2_hashmap_size | ||
| 104 | + self.base_resolution = base_resolution | ||
| 105 | + self.output_dim = num_levels * level_dim | ||
| 106 | + self.gridtype = gridtype | ||
| 107 | + self.gridtype_id = _gridtype_to_id[gridtype] # "tiled" or "hash" | ||
| 108 | + self.align_corners = align_corners | ||
| 109 | + | ||
| 110 | + # allocate parameters | ||
| 111 | + offsets = [] | ||
| 112 | + offset = 0 | ||
| 113 | + self.max_params = 2 ** log2_hashmap_size | ||
| 114 | + for i in range(num_levels): | ||
| 115 | + resolution = int(np.ceil(base_resolution * per_level_scale ** i)) | ||
| 116 | + params_in_level = min(self.max_params, (resolution if align_corners else resolution + 1) ** input_dim) # limit max number | ||
| 117 | + params_in_level = int(np.ceil(params_in_level / 8) * 8) # make divisible | ||
| 118 | + offsets.append(offset) | ||
| 119 | + offset += params_in_level | ||
| 120 | + # print(resolution, params_in_level) | ||
| 121 | + offsets.append(offset) | ||
| 122 | + offsets = torch.from_numpy(np.array(offsets, dtype=np.int32)) | ||
| 123 | + self.register_buffer('offsets', offsets) | ||
| 124 | + | ||
| 125 | + self.n_params = offsets[-1] * level_dim | ||
| 126 | + | ||
| 127 | + # parameters | ||
| 128 | + self.embeddings = nn.Parameter(torch.empty(offset, level_dim)) | ||
| 129 | + | ||
| 130 | + self.reset_parameters() | ||
| 131 | + | ||
| 132 | + def reset_parameters(self): | ||
| 133 | + std = 1e-4 | ||
| 134 | + self.embeddings.data.uniform_(-std, std) | ||
| 135 | + | ||
| 136 | + def __repr__(self): | ||
| 137 | + return f"GridEncoder: input_dim={self.input_dim} num_levels={self.num_levels} level_dim={self.level_dim} resolution={self.base_resolution} -> {int(round(self.base_resolution * self.per_level_scale ** (self.num_levels - 1)))} per_level_scale={self.per_level_scale:.4f} params={tuple(self.embeddings.shape)} gridtype={self.gridtype} align_corners={self.align_corners}" | ||
| 138 | + | ||
| 139 | + def forward(self, inputs, bound=1): | ||
| 140 | + # inputs: [..., input_dim], normalized real world positions in [-bound, bound] | ||
| 141 | + # return: [..., num_levels * level_dim] | ||
| 142 | + | ||
| 143 | + inputs = (inputs + bound) / (2 * bound) # map to [0, 1] | ||
| 144 | + | ||
| 145 | + #print('inputs', inputs.shape, inputs.dtype, inputs.min().item(), inputs.max().item()) | ||
| 146 | + | ||
| 147 | + prefix_shape = list(inputs.shape[:-1]) | ||
| 148 | + inputs = inputs.view(-1, self.input_dim) | ||
| 149 | + | ||
| 150 | + outputs = grid_encode(inputs, self.embeddings, self.offsets, self.per_level_scale, self.base_resolution, inputs.requires_grad, self.gridtype_id, self.align_corners) | ||
| 151 | + outputs = outputs.view(prefix_shape + [self.output_dim]) | ||
| 152 | + | ||
| 153 | + #print('outputs', outputs.shape, outputs.dtype, outputs.min().item(), outputs.max().item()) | ||
| 154 | + | ||
| 155 | + return outputs |
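A usage sketch plus a size check under the default settings (illustrative; the CUDA extension must be built):

import torch

enc = GridEncoder(input_dim=3, num_levels=16, level_dim=2, base_resolution=16,
                  log2_hashmap_size=19, desired_resolution=2048, gridtype='hash').cuda()
x = torch.rand(4096, 3, device='cuda') * 2 - 1   # positions in [-bound, bound]
feats = enc(x, bound=1)                          # -> [4096, 32] (16 levels x 2 channels)

The offsets arithmetic caps each level's table: at level 0, (16 + 1) ** 3 = 4913 entries round up to 4920 (the next multiple of 8), while fine levels would need far more than 2 ** 19 = 524288 entries and are clamped to the hash-table size, so hash collisions only occur at high resolutions.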
ernerf/gridencoder/setup.py
0 → 100644
| 1 | +import os | ||
| 2 | +from setuptools import setup | ||
| 3 | +from torch.utils.cpp_extension import BuildExtension, CUDAExtension | ||
| 4 | + | ||
| 5 | +_src_path = os.path.dirname(os.path.abspath(__file__)) | ||
| 6 | + | ||
| 7 | +nvcc_flags = [ | ||
| 8 | + '-O3', '-std=c++17', | ||
| 9 | + '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__','-allow-unsupported-compiler', | ||
| 10 | +] | ||
| 11 | + | ||
| 12 | +if os.name == "posix": | ||
| 13 | + c_flags = ['-O3', '-std=c++17'] | ||
| 14 | +elif os.name == "nt": | ||
| 15 | + c_flags = ['/O2', '/std:c++17'] | ||
| 16 | + | ||
| 17 | + # find cl.exe | ||
| 18 | + def find_cl_path(): | ||
| 19 | + import glob | ||
| 20 | + for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: | ||
| 21 | + paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) | ||
| 22 | + if paths: | ||
| 23 | + return paths[0] | ||
| 24 | + | ||
| 25 | + # If cl.exe is not on path, try to find it. | ||
| 26 | + if os.system("where cl.exe >nul 2>nul") != 0: | ||
| 27 | + cl_path = find_cl_path() | ||
| 28 | + if cl_path is None: | ||
| 29 | + raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") | ||
| 30 | + os.environ["PATH"] += ";" + cl_path | ||
| 31 | + | ||
| 32 | +setup( | ||
| 33 | + name='gridencoder', # package name, import this to use python API | ||
| 34 | + ext_modules=[ | ||
| 35 | + CUDAExtension( | ||
| 36 | + name='_gridencoder', # extension name, import this to use CUDA API | ||
| 37 | + sources=[os.path.join(_src_path, 'src', f) for f in [ | ||
| 38 | + 'gridencoder.cu', | ||
| 39 | + 'bindings.cpp', | ||
| 40 | + ]], | ||
| 41 | + extra_compile_args={ | ||
| 42 | + 'cxx': c_flags, | ||
| 43 | + 'nvcc': nvcc_flags, | ||
| 44 | + } | ||
| 45 | + ), | ||
| 46 | + ], | ||
| 47 | + cmdclass={ | ||
| 48 | + 'build_ext': BuildExtension, | ||
| 49 | + } | ||
| 50 | +) |
ernerf/gridencoder/src/bindings.cpp
0 → 100644