冯杨

Initial commit

Showing 46 changed files with 4734 additions and 0 deletions


  1 +github: [lipku]
  1 +__pycache__/
  2 +build/
  3 +*.egg-info/
  4 +*.so
  5 +*.mp4
  6 +
  7 +tmp*
  8 +trial*/
  9 +
  10 +data
  11 +data_utils/face_tracking/3DMM/*
  12 +data_utils/face_parsing/79999_iter.pth
  13 +
  14 +pretrained
  15 +*.mp4
  16 +.DS_Store
  17 +workspace/log_ngp.txt
  18 +.idea
  1 +{
  2 + "version": "0.2.0",
  3 + "configurations": [
  4 + {
  5 + "name": "Python: 带参数调试",
  6 + "type": "python",
  7 + "request": "launch",
  8 + "program": "${workspaceFolder}/app.py",
  9 + "args": [
  10 + "--transport", "webrtc",
  11 + "--model", "wav2lip",
  12 + "--avatar_id", "wav2lip256_avatar1"
  13 + ],
  14 + "console": "integratedTerminal",
  15 + "justMyCode": true
  16 + }
  17 + ]
  18 +}
  1 +# Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.
  2 +#
  3 +# NVIDIA CORPORATION and its licensors retain all intellectual property
  4 +# and proprietary rights in and to this software, related documentation
  5 +# and any modifications thereto. Any use, reproduction, disclosure or
  6 +# distribution of this software and related documentation without an express
  7 +# license agreement from NVIDIA CORPORATION is strictly prohibited.
  8 +
  9 +ARG BASE_IMAGE=nvcr.io/nvidia/cuda:11.6.1-cudnn8-devel-ubuntu20.04
  10 +FROM $BASE_IMAGE
  11 +
  12 +RUN apt-get update -yq --fix-missing \
  13 + && DEBIAN_FRONTEND=noninteractive apt-get install -yq --no-install-recommends \
  14 + pkg-config \
  15 + wget \
  16 + cmake \
  17 + curl \
  18 + git \
  19 + vim
  20 +
  21 +#ENV PYTHONDONTWRITEBYTECODE=1
  22 +#ENV PYTHONUNBUFFERED=1
  23 +
  24 +# nvidia-container-runtime
  25 +#ENV NVIDIA_VISIBLE_DEVICES all
  26 +#ENV NVIDIA_DRIVER_CAPABILITIES compute,utility,graphics
  27 +
  28 +RUN wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
  29 +RUN sh Miniconda3-latest-Linux-x86_64.sh -b -u -p /root/miniconda3
  30 +ENV PATH=/root/miniconda3/bin:$PATH
  31 +RUN conda create -y -n nerfstream python=3.10
  32 +# "conda activate" does not persist across RUN steps; run subsequent commands inside the env instead
  33 +SHELL ["conda", "run", "-n", "nerfstream", "/bin/bash", "-c"]
  34 +
  35 +RUN pip config set global.index-url https://mirrors.aliyun.com/pypi/simple/
  36 +# install depend
  37 +RUN conda install -y pytorch==1.12.1 torchvision==0.13.1 cudatoolkit=11.3 -c pytorch
  38 +COPY requirements.txt ./
  39 +RUN pip install -r requirements.txt
  40 +
  41 +# additional libraries
  42 +RUN pip install "git+https://github.com/facebookresearch/pytorch3d.git"
  43 +RUN pip install tensorflow-gpu==2.8.0
  44 +
  45 +RUN pip uninstall -y protobuf
  46 +RUN pip install protobuf==3.20.1
  47 +
  48 +RUN conda install -y ffmpeg
  49 +# COPY cannot reach outside the build context ("../"); build the image from the parent directory that contains python_rtmpstream and nerfstream
  50 +COPY python_rtmpstream /python_rtmpstream
  51 +WORKDIR /python_rtmpstream/python
  52 +RUN pip install .
  53 +
  54 +COPY nerfstream /nerfstream
  55 +WORKDIR /nerfstream
  56 +CMD ["conda", "run", "--no-capture-output", "-n", "nerfstream", "python3", "app.py"]
  1 + Apache License
  2 + Version 2.0, January 2004
  3 + http://www.apache.org/licenses/
  4 +
  5 + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
  6 +
  7 + 1. Definitions.
  8 +
  9 + "License" shall mean the terms and conditions for use, reproduction,
  10 + and distribution as defined by Sections 1 through 9 of this document.
  11 +
  12 + "Licensor" shall mean the copyright owner or entity authorized by
  13 + the copyright owner that is granting the License.
  14 +
  15 + "Legal Entity" shall mean the union of the acting entity and all
  16 + other entities that control, are controlled by, or are under common
  17 + control with that entity. For the purposes of this definition,
  18 + "control" means (i) the power, direct or indirect, to cause the
  19 + direction or management of such entity, whether by contract or
  20 + otherwise, or (ii) ownership of fifty percent (50%) or more of the
  21 + outstanding shares, or (iii) beneficial ownership of such entity.
  22 +
  23 + "You" (or "Your") shall mean an individual or Legal Entity
  24 + exercising permissions granted by this License.
  25 +
  26 + "Source" form shall mean the preferred form for making modifications,
  27 + including but not limited to software source code, documentation
  28 + source, and configuration files.
  29 +
  30 + "Object" form shall mean any form resulting from mechanical
  31 + transformation or translation of a Source form, including but
  32 + not limited to compiled object code, generated documentation,
  33 + and conversions to other media types.
  34 +
  35 + "Work" shall mean the work of authorship, whether in Source or
  36 + Object form, made available under the License, as indicated by a
  37 + copyright notice that is included in or attached to the work
  38 + (an example is provided in the Appendix below).
  39 +
  40 + "Derivative Works" shall mean any work, whether in Source or Object
  41 + form, that is based on (or derived from) the Work and for which the
  42 + editorial revisions, annotations, elaborations, or other modifications
  43 + represent, as a whole, an original work of authorship. For the purposes
  44 + of this License, Derivative Works shall not include works that remain
  45 + separable from, or merely link (or bind by name) to the interfaces of,
  46 + the Work and Derivative Works thereof.
  47 +
  48 + "Contribution" shall mean any work of authorship, including
  49 + the original version of the Work and any modifications or additions
  50 + to that Work or Derivative Works thereof, that is intentionally
  51 + submitted to Licensor for inclusion in the Work by the copyright owner
  52 + or by an individual or Legal Entity authorized to submit on behalf of
  53 + the copyright owner. For the purposes of this definition, "submitted"
  54 + means any form of electronic, verbal, or written communication sent
  55 + to the Licensor or its representatives, including but not limited to
  56 + communication on electronic mailing lists, source code control systems,
  57 + and issue tracking systems that are managed by, or on behalf of, the
  58 + Licensor for the purpose of discussing and improving the Work, but
  59 + excluding communication that is conspicuously marked or otherwise
  60 + designated in writing by the copyright owner as "Not a Contribution."
  61 +
  62 + "Contributor" shall mean Licensor and any individual or Legal Entity
  63 + on behalf of whom a Contribution has been received by Licensor and
  64 + subsequently incorporated within the Work.
  65 +
  66 + 2. Grant of Copyright License. Subject to the terms and conditions of
  67 + this License, each Contributor hereby grants to You a perpetual,
  68 + worldwide, non-exclusive, no-charge, royalty-free, irrevocable
  69 + copyright license to reproduce, prepare Derivative Works of,
  70 + publicly display, publicly perform, sublicense, and distribute the
  71 + Work and such Derivative Works in Source or Object form.
  72 +
  73 + 3. Grant of Patent License. Subject to the terms and conditions of
  74 + this License, each Contributor hereby grants to You a perpetual,
  75 + worldwide, non-exclusive, no-charge, royalty-free, irrevocable
  76 + (except as stated in this section) patent license to make, have made,
  77 + use, offer to sell, sell, import, and otherwise transfer the Work,
  78 + where such license applies only to those patent claims licensable
  79 + by such Contributor that are necessarily infringed by their
  80 + Contribution(s) alone or by combination of their Contribution(s)
  81 + with the Work to which such Contribution(s) was submitted. If You
  82 + institute patent litigation against any entity (including a
  83 + cross-claim or counterclaim in a lawsuit) alleging that the Work
  84 + or a Contribution incorporated within the Work constitutes direct
  85 + or contributory patent infringement, then any patent licenses
  86 + granted to You under this License for that Work shall terminate
  87 + as of the date such litigation is filed.
  88 +
  89 + 4. Redistribution. You may reproduce and distribute copies of the
  90 + Work or Derivative Works thereof in any medium, with or without
  91 + modifications, and in Source or Object form, provided that You
  92 + meet the following conditions:
  93 +
  94 + (a) You must give any other recipients of the Work or
  95 + Derivative Works a copy of this License; and
  96 +
  97 + (b) You must cause any modified files to carry prominent notices
  98 + stating that You changed the files; and
  99 +
  100 + (c) You must retain, in the Source form of any Derivative Works
  101 + that You distribute, all copyright, patent, trademark, and
  102 + attribution notices from the Source form of the Work,
  103 + excluding those notices that do not pertain to any part of
  104 + the Derivative Works; and
  105 +
  106 + (d) If the Work includes a "NOTICE" text file as part of its
  107 + distribution, then any Derivative Works that You distribute must
  108 + include a readable copy of the attribution notices contained
  109 + within such NOTICE file, excluding those notices that do not
  110 + pertain to any part of the Derivative Works, in at least one
  111 + of the following places: within a NOTICE text file distributed
  112 + as part of the Derivative Works; within the Source form or
  113 + documentation, if provided along with the Derivative Works; or,
  114 + within a display generated by the Derivative Works, if and
  115 + wherever such third-party notices normally appear. The contents
  116 + of the NOTICE file are for informational purposes only and
  117 + do not modify the License. You may add Your own attribution
  118 + notices within Derivative Works that You distribute, alongside
  119 + or as an addendum to the NOTICE text from the Work, provided
  120 + that such additional attribution notices cannot be construed
  121 + as modifying the License.
  122 +
  123 + You may add Your own copyright statement to Your modifications and
  124 + may provide additional or different license terms and conditions
  125 + for use, reproduction, or distribution of Your modifications, or
  126 + for any such Derivative Works as a whole, provided Your use,
  127 + reproduction, and distribution of the Work otherwise complies with
  128 + the conditions stated in this License.
  129 +
  130 + 5. Submission of Contributions. Unless You explicitly state otherwise,
  131 + any Contribution intentionally submitted for inclusion in the Work
  132 + by You to the Licensor shall be under the terms and conditions of
  133 + this License, without any additional terms or conditions.
  134 + Notwithstanding the above, nothing herein shall supersede or modify
  135 + the terms of any separate license agreement you may have executed
  136 + with Licensor regarding such Contributions.
  137 +
  138 + 6. Trademarks. This License does not grant permission to use the trade
  139 + names, trademarks, service marks, or product names of the Licensor,
  140 + except as required for reasonable and customary use in describing the
  141 + origin of the Work and reproducing the content of the NOTICE file.
  142 +
  143 + 7. Disclaimer of Warranty. Unless required by applicable law or
  144 + agreed to in writing, Licensor provides the Work (and each
  145 + Contributor provides its Contributions) on an "AS IS" BASIS,
  146 + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  147 + implied, including, without limitation, any warranties or conditions
  148 + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
  149 + PARTICULAR PURPOSE. You are solely responsible for determining the
  150 + appropriateness of using or redistributing the Work and assume any
  151 + risks associated with Your exercise of permissions under this License.
  152 +
  153 + 8. Limitation of Liability. In no event and under no legal theory,
  154 + whether in tort (including negligence), contract, or otherwise,
  155 + unless required by applicable law (such as deliberate and grossly
  156 + negligent acts) or agreed to in writing, shall any Contributor be
  157 + liable to You for damages, including any direct, indirect, special,
  158 + incidental, or consequential damages of any character arising as a
  159 + result of this License or out of the use or inability to use the
  160 + Work (including but not limited to damages for loss of goodwill,
  161 + work stoppage, computer failure or malfunction, or any and all
  162 + other commercial damages or losses), even if such Contributor
  163 + has been advised of the possibility of such damages.
  164 +
  165 + 9. Accepting Warranty or Additional Liability. While redistributing
  166 + the Work or Derivative Works thereof, You may choose to offer,
  167 + and charge a fee for, acceptance of support, warranty, indemnity,
  168 + or other liability obligations and/or rights consistent with this
  169 + License. However, in accepting such obligations, You may act only
  170 + on Your own behalf and on Your sole responsibility, not on behalf
  171 + of any other Contributor, and only if You agree to indemnify,
  172 + defend, and hold each Contributor harmless for any liability
  173 + incurred by, or claims asserted against, such Contributor by reason
  174 + of your accepting any such warranty or additional liability.
  175 +
  176 + END OF TERMS AND CONDITIONS
  177 +
  178 + APPENDIX: How to apply the Apache License to your work.
  179 +
  180 + To apply the Apache License to your work, attach the following
  181 + boilerplate notice, with the fields enclosed by brackets "[]"
  182 + replaced with your own identifying information. (Don't include
  183 + the brackets!) The text should be enclosed in the appropriate
  184 + comment syntax for the file format. We also recommend that a
  185 + file or class name and description of purpose be included on the
  186 + same "printed page" as the copyright notice for easier
  187 + identification within third-party archives.
  188 +
  189 + Copyright [livetalking@lipku]
  190 +
  191 + Licensed under the Apache License, Version 2.0 (the "License");
  192 + you may not use this file except in compliance with the License.
  193 + You may obtain a copy of the License at
  194 +
  195 + http://www.apache.org/licenses/LICENSE-2.0
  196 +
  197 + Unless required by applicable law or agreed to in writing, software
  198 + distributed under the License is distributed on an "AS IS" BASIS,
  199 + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  200 + See the License for the specific language governing permissions and
  201 + limitations under the License.
  1 +Real-time interactive streaming digital human with synchronized audio and video dialogue. It can basically achieve commercial-grade results.
  3 +
  4 +[ernerf demo](https://www.bilibili.com/video/BV1PM4m1y7Q2/) [musetalk demo](https://www.bilibili.com/video/BV1gm421N7vQ/) [wav2lip demo](https://www.bilibili.com/video/BV1Bw4m1e74P/)
  5 +
  6 +## To avoid confusion with 3D digital humans, the original project metahuman-stream has been renamed to livetalking; the original links remain valid
  7 +
  8 +## News
  9 +
  10 +- 2024.12.8 Improved multi-session concurrency; GPU memory no longer grows with the number of concurrent sessions
  11 +- 2024.12.21 Added model warm-up for wav2lip and musetalk, fixing the stutter on the first inference. Thanks to @heimaojinzhangyz
  12 +- 2024.12.28 Added the Ultralight-Digital-Human digital-human model. Thanks to @lijihua2017
  13 +- 2025.2.7 Added fish-speech TTS
  14 +- 2025.2.21 Added the wav2lip256 open-source model. Thanks to @不蠢不蠢
  15 +- 2025.3.2 Added Tencent text-to-speech service
  16 +
  17 +## Features
  18 +
  19 +1. Supports multiple digital-human models: ernerf, musetalk, wav2lip, Ultralight-Digital-Human
  20 +2. Supports voice cloning
  21 +3. Supports interrupting the digital human while it is speaking
  22 +4. Supports full-body video stitching
  23 +5. Supports rtmp and webrtc
  24 +6. Supports video choreography: plays a custom video while the avatar is not speaking (see the config sketch below)
  25 +7. Supports multiple concurrent sessions
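
Feature 6 (video choreography) is configured through the `--customvideo_config` option of app.py. Based on `__loadcustom` in basereal.py, the config file is a JSON list whose entries carry an `audiotype`, an image folder and an audio file. A minimal sketch of writing such a config; the paths, the output file name and the audiotype value are placeholders, not files shipped with the project:

```python
import json

# Illustrative custom-video config: "audiotype", "imgpath" and "audiopath" are the
# keys read by BaseReal.__loadcustom(); the concrete paths below are placeholders.
custom_config = [
    {
        "audiotype": 2,                            # states > 1 are treated as custom playback
        "imgpath": "data/customvideo/img",         # folder of numbered frame images
        "audiopath": "data/customvideo/audio.wav", # audio played alongside the frames
    }
]

with open("data/custom_config.json", "w") as f:
    json.dump(custom_config, f, indent=2)

# then start the server with: python app.py --customvideo_config data/custom_config.json
```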
  26 +
  27 +## 1. Installation
  28 +
  29 +Tested on Ubuntu 20.04 with Python 3.10, PyTorch 1.12 and CUDA 11.3
  30 +
  31 +### 1.1 Install dependency
  32 +
  33 +```bash
  34 +conda create -n nerfstream python=3.10
  35 +conda activate nerfstream
  36 +# If your CUDA version is not 11.3 (run nvidia-smi to check), install the matching pytorch build according to <https://pytorch.org/get-started/previous-versions/>
  37 +conda install pytorch==1.12.1 torchvision==0.13.1 cudatoolkit=11.3 -c pytorch
  38 +pip install -r requirements.txt
  39 +# If you need to train the ernerf model, also install the libraries below
  40 +# pip install "git+https://github.com/facebookresearch/pytorch3d.git"
  41 +# pip install tensorflow-gpu==2.8.0
  42 +# pip install --upgrade "protobuf<=3.20.1"
  43 +```
  44 +
  45 +Common installation problems: [FAQ](https://livetalking-doc.readthedocs.io/en/latest/faq.html)
  46 +For setting up a CUDA environment on Linux, see https://zhuanlan.zhihu.com/p/674972886
  47 +
  48 +## 2. Quick Start
  49 +
  50 +- Download the models
  51 + Baidu Netdisk <https://pan.baidu.com/s/1yOsQ06-RIDTJd3HFCw4wtA> password: ltua
  52 + Google Drive <https://drive.google.com/drive/folders/1FOC_MD6wdogyyX_7V1d4NDIO7P9NlSAJ?usp=sharing>
  53 + Copy wav2lip256.pth into this project's models directory and rename it to wav2lip.pth;
  54 + Extract wav2lip256_avatar1.tar.gz and copy the whole folder into this project's data/avatars directory
  55 +- Run
  56 + python app.py --transport webrtc --model wav2lip --avatar_id wav2lip256_avatar1 --preload 2
  57 +
  58 + Using the GPU, start avatar No. 3: python app.py --transport webrtc --model wav2lip --avatar_id wav2lip256_avatar3 --preload 2
  59 + Open http://serverip:8010/webrtcapi.html in a browser, click 'start' to play the digital-human video, then type any text in the text box and submit it; the digital human will speak that text (you can also call the HTTP API directly; see the example after this list)
  60 + <font color=red>The server needs to open ports tcp:8010 and udp:1-65536</font>
  61 + If you need the commercial high-definition wav2lip model, contact me to purchase it
  62 +
  63 +- Quick trial
  64 + <https://www.compshare.cn/images-detail?ImageID=compshareImage-18tpjhhxoq3j&referral_code=3XW3852OBmnD089hMMrtuU&ytag=GPU_GitHub_livetalking1.3> Create an instance from this image and it runs out of the box
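
Besides the web page, the avatar can be driven through the server's HTTP API. A minimal sketch based on the `/human` handler in app.py, assuming the server runs locally on the default port 8010, the `requests` package is installed, and `sessionid` has been taken from the `/offer` response (the value 0 below is only a placeholder):

```python
import requests

# Ask the avatar to speak a piece of text ("echo") or to answer via the LLM ("chat").
resp = requests.post(
    "http://127.0.0.1:8010/human",
    json={
        "sessionid": 0,                        # placeholder; webrtc sessions get theirs from /offer
        "type": "echo",                        # "echo": speak the text as-is, "chat": LLM reply
        "interrupt": True,                     # flush whatever the avatar is currently saying
        "text": "Hello, this is LiveTalking.",
    },
)
print(resp.json())  # expected: {"code": 0, "data": "ok"}
```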
  65 +
  66 +If huggingface is not reachable, set the following before running:
  67 +
  68 +```
  69 +export HF_ENDPOINT=https://hf-mirror.com
  70 +```
  71 +
  72 +## 3. More Usage
  73 +
  74 +Documentation: <https://livetalking-doc.readthedocs.io/>
  75 +
  76 +## 4. Docker Run
  77 +
  78 +No need for the installation steps above; just run:
  79 +
  80 +```
  81 +docker run --gpus all -it --network=host --rm registry.cn-beijing.aliyuncs.com/codewithgpu2/lipku-metahuman-stream:2K9qaMBu8v
  82 +```
  83 +
  84 +The code is under /root/metahuman-stream. Run git pull to fetch the latest code first, then run the same commands as in sections 2 and 3.
  85 +
  86 +The following images are provided:
  87 +
  88 +- autodl image: <https://www.codewithgpu.com/i/lipku/metahuman-stream/base>
  89 + [autodl tutorial](https://livetalking-doc.readthedocs.io/en/latest/autodl/README.html)
  90 +- ucloud image: <https://www.compshare.cn/images-detail?ImageID=compshareImage-18tpjhhxoq3j&referral_code=3XW3852OBmnD089hMMrtuU&ytag=GPU_livetalking1.3>
  91 + Any port can be opened; there is no need to deploy an srs service separately.
  92 + [ucloud tutorial](https://livetalking-doc.readthedocs.io/en/latest/ucloud/ucloud.html)
  93 +
  94 +## 5. TODO
  95 +
  96 +- [x] Add chatgpt-based dialogue for the digital human
  97 +- [x] Voice cloning
  98 +- [x] Play a video clip while the digital human is silent
  99 +- [x] MuseTalk
  100 +- [x] Wav2Lip
  101 +- [x] Ultralight-Digital-Human
  102 +
  103 +---
  104 +
  105 +If this project helps you, please give it a star. Anyone interested is also welcome to help improve it.
  106 +
  107 +- Zhishixingqiu (知识星球): https://t.zsxq.com/7NMyO, with curated FAQs, best practices and answers
  108 +- WeChat official account: 数字人技术
  109 + ![](https://mmbiz.qpic.cn/sz_mmbiz_jpg/l3ZibgueFiaeyfaiaLZGuMGQXnhLWxibpJUS2gfs8Dje6JuMY8zu2tVyU9n8Zx1yaNncvKHBMibX0ocehoITy5qQEZg/640?wxfrom=12&tp=wxpic&usePicPrefetch=1&wx_fmt=jpeg&from=appmsg)
  1 +###############################################################################
  2 +# Copyright (C) 2024 LiveTalking@lipku https://github.com/lipku/LiveTalking
  3 +# email: lipku@foxmail.com
  4 +#
  5 +# Licensed under the Apache License, Version 2.0 (the "License");
  6 +# you may not use this file except in compliance with the License.
  7 +# You may obtain a copy of the License at
  8 +#
  9 +# http://www.apache.org/licenses/LICENSE-2.0
  10 +#
  11 +# Unless required by applicable law or agreed to in writing, software
  12 +# distributed under the License is distributed on an "AS IS" BASIS,
  13 +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14 +# See the License for the specific language governing permissions and
  15 +# limitations under the License.
  16 +###############################################################################
  17 +
  18 +# server.py
  19 +from flask import Flask, render_template,send_from_directory,request, jsonify
  20 +from flask_sockets import Sockets
  21 +import base64
  22 +import json
  23 +#import gevent
  24 +#from gevent import pywsgi
  25 +#from geventwebsocket.handler import WebSocketHandler
  26 +import re
  27 +import numpy as np
  28 +from threading import Thread,Event
  29 +#import multiprocessing
  30 +import torch.multiprocessing as mp
  31 +
  32 +from aiohttp import web
  33 +import aiohttp
  34 +import aiohttp_cors
  35 +from aiortc import RTCPeerConnection, RTCSessionDescription
  36 +from aiortc.rtcrtpsender import RTCRtpSender
  37 +from webrtc import HumanPlayer
  38 +from basereal import BaseReal
  39 +from llm import llm_response
  40 +
  41 +import argparse
  42 +import random
  43 +import shutil
  44 +import asyncio
  45 +import torch
  46 +from typing import Dict
  47 +from logger import logger
  48 +
  49 +
  50 +app = Flask(__name__)
  51 +#sockets = Sockets(app)
  52 +nerfreals:Dict[int, BaseReal] = {} #sessionid:BaseReal
  53 +opt = None
  54 +model = None
  55 +avatar = None
  56 +
  57 +
  58 +#####webrtc###############################
  59 +pcs = set()
  60 +
  61 +def randN(N)->int:
  62 + '''Generate a random number with N digits'''
  63 + min = pow(10, N - 1)
  64 + max = pow(10, N)
  65 + return random.randint(min, max - 1)
  66 +
  67 +def build_nerfreal(sessionid:int)->BaseReal:
  68 + opt.sessionid=sessionid
  69 + if opt.model == 'wav2lip':
  70 + from lipreal import LipReal
  71 + nerfreal = LipReal(opt,model,avatar)
  72 + elif opt.model == 'musetalk':
  73 + from musereal import MuseReal
  74 + nerfreal = MuseReal(opt,model,avatar)
  75 + elif opt.model == 'ernerf':
  76 + from nerfreal import NeRFReal
  77 + nerfreal = NeRFReal(opt,model,avatar)
  78 + elif opt.model == 'ultralight':
  79 + from lightreal import LightReal
  80 + nerfreal = LightReal(opt,model,avatar)
  81 + return nerfreal
  82 +
  83 +#@app.route('/offer', methods=['POST'])
  84 +async def offer(request):
  85 + params = await request.json()
  86 + offer = RTCSessionDescription(sdp=params["sdp"], type=params["type"])
  87 +
  88 + if len(nerfreals) >= opt.max_session:
  89 + logger.info('reach max session')
  90 + return web.Response(content_type="application/json", text=json.dumps({"code": -1, "msg": "reach max session"}))
  91 + sessionid = randN(6) #len(nerfreals)
  92 + logger.info('sessionid=%d',sessionid)
  93 + nerfreals[sessionid] = None
  94 + nerfreal = await asyncio.get_event_loop().run_in_executor(None, build_nerfreal,sessionid)
  95 + nerfreals[sessionid] = nerfreal
  96 +
  97 + pc = RTCPeerConnection()
  98 + pcs.add(pc)
  99 +
  100 + @pc.on("connectionstatechange")
  101 + async def on_connectionstatechange():
  102 + logger.info("Connection state is %s" % pc.connectionState)
  103 + if pc.connectionState == "failed":
  104 + await pc.close()
  105 + pcs.discard(pc)
  106 + del nerfreals[sessionid]
  107 + if pc.connectionState == "closed":
  108 + pcs.discard(pc)
  109 + del nerfreals[sessionid]
  110 +
  111 + player = HumanPlayer(nerfreals[sessionid])
  112 + audio_sender = pc.addTrack(player.audio)
  113 + video_sender = pc.addTrack(player.video)
  114 + capabilities = RTCRtpSender.getCapabilities("video")
  115 + preferences = list(filter(lambda x: x.name == "H264", capabilities.codecs))
  116 + preferences += list(filter(lambda x: x.name == "VP8", capabilities.codecs))
  117 + preferences += list(filter(lambda x: x.name == "rtx", capabilities.codecs))
  118 + transceiver = pc.getTransceivers()[1]
  119 + transceiver.setCodecPreferences(preferences)
  120 +
  121 + await pc.setRemoteDescription(offer)
  122 +
  123 + answer = await pc.createAnswer()
  124 + await pc.setLocalDescription(answer)
  125 +
  126 + #return jsonify({"sdp": pc.localDescription.sdp, "type": pc.localDescription.type})
  127 +
  128 + return web.Response(
  129 + content_type="application/json",
  130 + text=json.dumps(
  131 + {"sdp": pc.localDescription.sdp, "type": pc.localDescription.type, "sessionid":sessionid}
  132 + ),
  133 + )
  134 +
  135 +async def human(request):
  136 + params = await request.json()
  137 +
  138 + sessionid = params.get('sessionid',0)
  139 + if params.get('interrupt'):
  140 + nerfreals[sessionid].flush_talk()
  141 +
  142 + if params['type']=='echo':
  143 + nerfreals[sessionid].put_msg_txt(params['text'])
  144 + elif params['type']=='chat':
  145 + res=await asyncio.get_event_loop().run_in_executor(None, llm_response, params['text'],nerfreals[sessionid])
  146 + #nerfreals[sessionid].put_msg_txt(res)
  147 +
  148 + return web.Response(
  149 + content_type="application/json",
  150 + text=json.dumps(
  151 + {"code": 0, "data":"ok"}
  152 + ),
  153 + )
  154 +
  155 +async def humanaudio(request):
  156 + try:
  157 + params = await request.json()
  158 + sessionid = int(params.get('sessionid', 0))
  159 + fileobj = params['file_url']
  160 + if fileobj.startswith("http"):
  161 + async with aiohttp.ClientSession() as session:
  162 + async with session.get(fileobj) as response:
  163 + if response.status == 200:
  164 + filebytes = await response.read()
  165 + else:
  166 + return web.Response(
  167 + content_type="application/json",
  168 + text=json.dumps({"code": -1, "msg": "Error downloading file"})
  169 + )
  170 + else:
  171 + filename = fileobj.filename
  172 + filebytes = fileobj.file.read()
  173 +
  174 + nerfreals[sessionid].put_audio_file(filebytes)
  175 +
  176 + return web.Response(
  177 + content_type="application/json",
  178 + text=json.dumps({"code": 0, "msg": "ok"})
  179 + )
  180 +
  181 + except Exception as e:
  182 + return web.Response(
  183 + content_type="application/json",
  184 + text=json.dumps({"code": -1, "msg": "err", "data": str(e)})
  185 + )
  186 +
  187 +async def set_audiotype(request):
  188 + params = await request.json()
  189 +
  190 + sessionid = params.get('sessionid',0)
  191 + nerfreals[sessionid].set_curr_state(params['audiotype'],params['reinit'])
  192 +
  193 + return web.Response(
  194 + content_type="application/json",
  195 + text=json.dumps(
  196 + {"code": 0, "data":"ok"}
  197 + ),
  198 + )
  199 +
  200 +async def record(request):
  201 + params = await request.json()
  202 +
  203 + sessionid = params.get('sessionid',0)
  204 + if params['type']=='start_record':
  205 + # nerfreals[sessionid].put_msg_txt(params['text'])
  206 + nerfreals[sessionid].start_recording()
  207 + elif params['type']=='end_record':
  208 + nerfreals[sessionid].stop_recording()
  209 + return web.Response(
  210 + content_type="application/json",
  211 + text=json.dumps(
  212 + {"code": 0, "data":"ok"}
  213 + ),
  214 + )
  215 +
  216 +async def is_speaking(request):
  217 + params = await request.json()
  218 +
  219 + sessionid = params.get('sessionid',0)
  220 + return web.Response(
  221 + content_type="application/json",
  222 + text=json.dumps(
  223 + {"code": 0, "data": nerfreals[sessionid].is_speaking()}
  224 + ),
  225 + )
  226 +
  227 +
  228 +async def on_shutdown(app):
  229 + # close peer connections
  230 + coros = [pc.close() for pc in pcs]
  231 + await asyncio.gather(*coros)
  232 + pcs.clear()
  233 +
  234 +async def post(url,data):
  235 + try:
  236 + async with aiohttp.ClientSession() as session:
  237 + async with session.post(url,data=data) as response:
  238 + return await response.text()
  239 + except aiohttp.ClientError as e:
  240 + logger.info(f'Error: {e}')
  241 +
  242 +async def run(push_url,sessionid):
  243 + nerfreal = await asyncio.get_event_loop().run_in_executor(None, build_nerfreal,sessionid)
  244 + nerfreals[sessionid] = nerfreal
  245 +
  246 + pc = RTCPeerConnection()
  247 + pcs.add(pc)
  248 +
  249 + @pc.on("connectionstatechange")
  250 + async def on_connectionstatechange():
  251 + logger.info("Connection state is %s" % pc.connectionState)
  252 + if pc.connectionState == "failed":
  253 + await pc.close()
  254 + pcs.discard(pc)
  255 +
  256 + player = HumanPlayer(nerfreals[sessionid])
  257 + audio_sender = pc.addTrack(player.audio)
  258 + video_sender = pc.addTrack(player.video)
  259 +
  260 + await pc.setLocalDescription(await pc.createOffer())
  261 + answer = await post(push_url,pc.localDescription.sdp)
  262 + await pc.setRemoteDescription(RTCSessionDescription(sdp=answer,type='answer'))
  263 +##########################################
  264 +# os.environ['MKL_SERVICE_FORCE_INTEL'] = '1'
  265 +# os.environ['MULTIPROCESSING_METHOD'] = 'forkserver'
  266 +if __name__ == '__main__':
  267 + mp.set_start_method('spawn')
  268 + parser = argparse.ArgumentParser()
  269 + parser.add_argument('--pose', type=str, default="data/data_kf.json", help="transforms.json, pose source")
  270 + parser.add_argument('--au', type=str, default="data/au.csv", help="eye blink area")
  271 + parser.add_argument('--torso_imgs', type=str, default="", help="torso images path")
  272 +
  273 + parser.add_argument('-O', action='store_true', help="equals --fp16 --cuda_ray --exp_eye")
  274 +
  275 + parser.add_argument('--data_range', type=int, nargs='*', default=[0, -1], help="data range to use")
  276 + parser.add_argument('--workspace', type=str, default='data/video')
  277 + parser.add_argument('--seed', type=int, default=0)
  278 +
  279 + ### training options
  280 + parser.add_argument('--ckpt', type=str, default='data/pretrained/ngp_kf.pth')
  281 +
  282 + parser.add_argument('--num_rays', type=int, default=4096 * 16, help="num rays sampled per image for each training step")
  283 + parser.add_argument('--cuda_ray', action='store_true', help="use CUDA raymarching instead of pytorch")
  284 + parser.add_argument('--max_steps', type=int, default=16, help="max num steps sampled per ray (only valid when using --cuda_ray)")
  285 + parser.add_argument('--num_steps', type=int, default=16, help="num steps sampled per ray (only valid when NOT using --cuda_ray)")
  286 + parser.add_argument('--upsample_steps', type=int, default=0, help="num steps up-sampled per ray (only valid when NOT using --cuda_ray)")
  287 + parser.add_argument('--update_extra_interval', type=int, default=16, help="iter interval to update extra status (only valid when using --cuda_ray)")
  288 + parser.add_argument('--max_ray_batch', type=int, default=4096, help="batch size of rays at inference to avoid OOM (only valid when NOT using --cuda_ray)")
  289 +
  290 + ### loss set
  291 + parser.add_argument('--warmup_step', type=int, default=10000, help="warm up steps")
  292 + parser.add_argument('--amb_aud_loss', type=int, default=1, help="use ambient aud loss")
  293 + parser.add_argument('--amb_eye_loss', type=int, default=1, help="use ambient eye loss")
  294 + parser.add_argument('--unc_loss', type=int, default=1, help="use uncertainty loss")
  295 + parser.add_argument('--lambda_amb', type=float, default=1e-4, help="lambda for ambient loss")
  296 +
  297 + ### network backbone options
  298 + parser.add_argument('--fp16', action='store_true', help="use amp mixed precision training")
  299 +
  300 + parser.add_argument('--bg_img', type=str, default='white', help="background image")
  301 + parser.add_argument('--fbg', action='store_true', help="frame-wise bg")
  302 + parser.add_argument('--exp_eye', action='store_true', help="explicitly control the eyes")
  303 + parser.add_argument('--fix_eye', type=float, default=-1, help="fixed eye area, negative to disable, set to 0-0.3 for a reasonable eye")
  304 + parser.add_argument('--smooth_eye', action='store_true', help="smooth the eye area sequence")
  305 +
  306 + parser.add_argument('--torso_shrink', type=float, default=0.8, help="shrink bg coords to allow more flexibility in deform")
  307 +
  308 + ### dataset options
  309 + parser.add_argument('--color_space', type=str, default='srgb', help="Color space, supports (linear, srgb)")
  310 + parser.add_argument('--preload', type=int, default=0, help="0 means load data from disk on-the-fly, 1 means preload to CPU, 2 means GPU.")
  311 + # (the default value is for the fox dataset)
  312 + parser.add_argument('--bound', type=float, default=1, help="assume the scene is bounded in box[-bound, bound]^3, if > 1, will invoke adaptive ray marching.")
  313 + parser.add_argument('--scale', type=float, default=4, help="scale camera location into box[-bound, bound]^3")
  314 + parser.add_argument('--offset', type=float, nargs='*', default=[0, 0, 0], help="offset of camera location")
  315 + parser.add_argument('--dt_gamma', type=float, default=1/256, help="dt_gamma (>=0) for adaptive ray marching. set to 0 to disable, >0 to accelerate rendering (but usually with worse quality)")
  316 + parser.add_argument('--min_near', type=float, default=0.05, help="minimum near distance for camera")
  317 + parser.add_argument('--density_thresh', type=float, default=10, help="threshold for density grid to be occupied (sigma)")
  318 + parser.add_argument('--density_thresh_torso', type=float, default=0.01, help="threshold for density grid to be occupied (alpha)")
  319 + parser.add_argument('--patch_size', type=int, default=1, help="[experimental] render patches in training, so as to apply LPIPS loss. 1 means disabled, use [64, 32, 16] to enable")
  320 +
  321 + parser.add_argument('--init_lips', action='store_true', help="init lips region")
  322 + parser.add_argument('--finetune_lips', action='store_true', help="use LPIPS and landmarks to fine tune lips region")
  323 + parser.add_argument('--smooth_lips', action='store_true', help="smooth the enc_a in a exponential decay way...")
  324 +
  325 + parser.add_argument('--torso', action='store_true', help="fix head and train torso")
  326 + parser.add_argument('--head_ckpt', type=str, default='', help="head model")
  327 +
  328 + ### GUI options
  329 + parser.add_argument('--gui', action='store_true', help="start a GUI")
  330 + parser.add_argument('--W', type=int, default=450, help="GUI width")
  331 + parser.add_argument('--H', type=int, default=450, help="GUI height")
  332 + parser.add_argument('--radius', type=float, default=3.35, help="default GUI camera radius from center")
  333 + parser.add_argument('--fovy', type=float, default=21.24, help="default GUI camera fovy")
  334 + parser.add_argument('--max_spp', type=int, default=1, help="GUI rendering max sample per pixel")
  335 +
  336 + ### else
  337 + parser.add_argument('--att', type=int, default=2, help="audio attention mode (0 = turn off, 1 = left-direction, 2 = bi-direction)")
  338 + parser.add_argument('--aud', type=str, default='', help="audio source (empty will load the default, else should be a path to a npy file)")
  339 + parser.add_argument('--emb', action='store_true', help="use audio class + embedding instead of logits")
  340 +
  341 + parser.add_argument('--ind_dim', type=int, default=4, help="individual code dim, 0 to turn off")
  342 + parser.add_argument('--ind_num', type=int, default=10000, help="number of individual codes, should be larger than training dataset size")
  343 +
  344 + parser.add_argument('--ind_dim_torso', type=int, default=8, help="individual code dim, 0 to turn off")
  345 +
  346 + parser.add_argument('--amb_dim', type=int, default=2, help="ambient dimension")
  347 + parser.add_argument('--part', action='store_true', help="use partial training data (1/10)")
  348 + parser.add_argument('--part2', action='store_true', help="use partial training data (first 15s)")
  349 +
  350 + parser.add_argument('--train_camera', action='store_true', help="optimize camera pose")
  351 + parser.add_argument('--smooth_path', action='store_true', help="brute-force smooth camera pose trajectory with a window size")
  352 + parser.add_argument('--smooth_path_window', type=int, default=7, help="smoothing window size")
  353 +
  354 + # asr
  355 + parser.add_argument('--asr', action='store_true', help="load asr for real-time app")
  356 + parser.add_argument('--asr_wav', type=str, default='', help="load the wav and use as input")
  357 + parser.add_argument('--asr_play', action='store_true', help="play out the audio")
  358 +
  359 + #parser.add_argument('--asr_model', type=str, default='deepspeech')
  360 + parser.add_argument('--asr_model', type=str, default='cpierse/wav2vec2-large-xlsr-53-esperanto') #
  361 + # parser.add_argument('--asr_model', type=str, default='facebook/wav2vec2-large-960h-lv60-self')
  362 + # parser.add_argument('--asr_model', type=str, default='facebook/hubert-large-ls960-ft')
  363 +
  364 + parser.add_argument('--asr_save_feats', action='store_true')
  365 + # audio FPS
  366 + parser.add_argument('--fps', type=int, default=50)
  367 + # sliding window left-middle-right length (unit: 20ms)
  368 + parser.add_argument('-l', type=int, default=10)
  369 + parser.add_argument('-m', type=int, default=8)
  370 + parser.add_argument('-r', type=int, default=10)
  371 +
  372 + parser.add_argument('--fullbody', action='store_true', help="fullbody human")
  373 + parser.add_argument('--fullbody_img', type=str, default='data/fullbody/img')
  374 + parser.add_argument('--fullbody_width', type=int, default=580)
  375 + parser.add_argument('--fullbody_height', type=int, default=1080)
  376 + parser.add_argument('--fullbody_offset_x', type=int, default=0)
  377 + parser.add_argument('--fullbody_offset_y', type=int, default=0)
  378 +
  379 + #musetalk opt
  380 + parser.add_argument('--avatar_id', type=str, default='avator_1')
  381 + parser.add_argument('--bbox_shift', type=int, default=5)
  382 + parser.add_argument('--batch_size', type=int, default=16)
  383 +
  384 + # parser.add_argument('--customvideo', action='store_true', help="custom video")
  385 + # parser.add_argument('--customvideo_img', type=str, default='data/customvideo/img')
  386 + # parser.add_argument('--customvideo_imgnum', type=int, default=1)
  387 +
  388 + parser.add_argument('--customvideo_config', type=str, default='')
  389 +
  390 + parser.add_argument('--tts', type=str, default='edgetts') #xtts gpt-sovits cosyvoice
  391 + parser.add_argument('--REF_FILE', type=str, default=None)
  392 + parser.add_argument('--REF_TEXT', type=str, default=None)
  393 + parser.add_argument('--TTS_SERVER', type=str, default='http://127.0.0.1:9880') # http://localhost:9000
  394 + # parser.add_argument('--CHARACTER', type=str, default='test')
  395 + # parser.add_argument('--EMOTION', type=str, default='default')
  396 +
  397 + parser.add_argument('--model', type=str, default='ernerf') #musetalk wav2lip
  398 +
  399 + parser.add_argument('--transport', type=str, default='rtcpush') #rtmp webrtc rtcpush
  400 + parser.add_argument('--push_url', type=str, default='http://localhost:1985/rtc/v1/whip/?app=live&stream=livestream') #rtmp://localhost/live/livestream
  401 +
  402 + parser.add_argument('--max_session', type=int, default=1) #multi session count
  403 + parser.add_argument('--listenport', type=int, default=8010)
  404 +
  405 + opt = parser.parse_args()
  406 + #app.config.from_object(opt)
  407 + #print(app.config)
  408 + opt.customopt = []
  409 + if opt.customvideo_config!='':
  410 + with open(opt.customvideo_config,'r') as file:
  411 + opt.customopt = json.load(file)
  412 +
  413 + if opt.model == 'ernerf':
  414 + from nerfreal import NeRFReal,load_model,load_avatar
  415 + model = load_model(opt)
  416 + avatar = load_avatar(opt)
  417 +
  418 + # we still need test_loader to provide audio features for testing.
  419 + # for k in range(opt.max_session):
  420 + # opt.sessionid=k
  421 + # nerfreal = NeRFReal(opt, trainer, test_loader,audio_processor,audio_model)
  422 + # nerfreals.append(nerfreal)
  423 + elif opt.model == 'musetalk':
  424 + from musereal import MuseReal,load_model,load_avatar,warm_up
  425 + logger.info(opt)
  426 + model = load_model()
  427 + avatar = load_avatar(opt.avatar_id)
  428 + warm_up(opt.batch_size,model)
  429 + # for k in range(opt.max_session):
  430 + # opt.sessionid=k
  431 + # nerfreal = MuseReal(opt,audio_processor,vae, unet, pe,timesteps)
  432 + # nerfreals.append(nerfreal)
  433 + elif opt.model == 'wav2lip':
  434 + from lipreal import LipReal,load_model,load_avatar,warm_up
  435 + logger.info(opt)
  436 + model = load_model("./models/wav2lip.pth")
  437 + avatar = load_avatar(opt.avatar_id)
  438 + warm_up(opt.batch_size,model,256)
  439 + # for k in range(opt.max_session):
  440 + # opt.sessionid=k
  441 + # nerfreal = LipReal(opt,model)
  442 + # nerfreals.append(nerfreal)
  443 + elif opt.model == 'ultralight':
  444 + from lightreal import LightReal,load_model,load_avatar,warm_up
  445 + logger.info(opt)
  446 + model = load_model(opt)
  447 + avatar = load_avatar(opt.avatar_id)
  448 + warm_up(opt.batch_size,avatar,160)
  449 +
  450 + if opt.transport=='rtmp':
  451 + thread_quit = Event()
  452 + nerfreals[0] = build_nerfreal(0)
  453 + rendthrd = Thread(target=nerfreals[0].render,args=(thread_quit,))
  454 + rendthrd.start()
  455 +
  456 + #############################################################################
  457 + appasync = web.Application()
  458 + appasync.on_shutdown.append(on_shutdown)
  459 + appasync.router.add_post("/offer", offer)
  460 + appasync.router.add_post("/human", human)
  461 + appasync.router.add_post("/humanaudio", humanaudio)
  462 + appasync.router.add_post("/set_audiotype", set_audiotype)
  463 + appasync.router.add_post("/record", record)
  464 + appasync.router.add_post("/is_speaking", is_speaking)
  465 + appasync.router.add_static('/',path='web')
  466 +
  467 + # Configure default CORS settings.
  468 + cors = aiohttp_cors.setup(appasync, defaults={
  469 + "*": aiohttp_cors.ResourceOptions(
  470 + allow_credentials=True,
  471 + expose_headers="*",
  472 + allow_headers="*",
  473 + )
  474 + })
  475 + # Configure CORS on all routes.
  476 + for route in list(appasync.router.routes()):
  477 + cors.add(route)
  478 +
  479 + pagename='webrtcapi.html'
  480 + if opt.transport=='rtmp':
  481 + pagename='echoapi.html'
  482 + elif opt.transport=='rtcpush':
  483 + pagename='rtcpushapi.html'
  484 + logger.info('start http server; http://<serverip>:'+str(opt.listenport)+'/'+pagename)
  485 + def run_server(runner):
  486 + loop = asyncio.new_event_loop()
  487 + asyncio.set_event_loop(loop)
  488 + loop.run_until_complete(runner.setup())
  489 + site = web.TCPSite(runner, '0.0.0.0', opt.listenport)
  490 + loop.run_until_complete(site.start())
  491 + if opt.transport=='rtcpush':
  492 + for k in range(opt.max_session):
  493 + push_url = opt.push_url
  494 + if k!=0:
  495 + push_url = opt.push_url+str(k)
  496 + loop.run_until_complete(run(push_url,k))
  497 + loop.run_forever()
  498 + #Thread(target=run_server, args=(web.AppRunner(appasync),)).start()
  499 + run_server(web.AppRunner(appasync))
  500 +
  501 + #app.on_shutdown.append(on_shutdown)
  502 + #app.router.add_post("/offer", offer)
  503 +
  504 + # print('start websocket server')
  505 + # server = pywsgi.WSGIServer(('0.0.0.0', 8000), app, handler_class=WebSocketHandler)
  506 + # server.serve_forever()
  507 +
  508 +
  1 +1. pytorch3d fails to install\
  2 + Build it from source:
  3 +
  4 +```bash
  5 +git clone https://github.com/facebookresearch/pytorch3d.git
  6 +python setup.py install
  7 +```
  8 +
  9 +2. WebSocket connection errors\
  10 + Edit python/site-packages/flask\_sockets.py:
  11 +
  12 +```python
  13 +# change self.url_map.add(Rule(rule, endpoint=f)) to:
  14 +self.url_map.add(Rule(rule, endpoint=f, websocket=True))
  15 +```
  16 +
  17 +3. protobuf version is too high
  18 +
  19 +```bash
  20 +pip uninstall protobuf
  21 +pip install protobuf==3.20.1
  22 +```
  23 +
  24 +4. The digital human does not blink\
  25 +Add the following step when training the model:
  26 +
  27 +> Obtain AU45 for eyes blinking.\
  28 +> Run FeatureExtraction in OpenFace, rename and move the output CSV file to data/\<ID>/au.csv.
  29 +
  30 +Copy au.csv into this project's data directory
  31 +
  32 +5. Add a background image to the digital human
  33 +
  34 +```bash
  35 +python app.py --bg_img bc.jpg
  36 +```
  37 +
  38 +6. Dimension-mismatch error when using your own trained model\
  39 +Extract audio features with wav2vec when training the model:
  40 +
  41 +```bash
  42 +python main.py data/ --workspace workspace/ -O --iters 100000 --asr_model cpierse/wav2vec2-large-xlsr-53-esperanto
  43 +```
  44 +
  45 +7. Wrong ffmpeg version when pushing rtmp
  46 +Users report that version 4.2.2 works; I am not sure exactly which versions fail. The rule of thumb is to run ffmpeg and check that the printed build information contains libx264; without it, rtmp push definitely will not work:
  47 +```
  48 +--enable-libx264
  49 +```
  50 +8. Replace the model with your own trained one
  51 +```
  52 +.
  53 +├── data
  54 +│   ├── data_kf.json (corresponds to transforms_train.json in the training data)
  55 +│   ├── au.csv
  56 +│   └── pretrained
  57 +│       └── ngp_kf.pth (corresponds to the trained model ngp_ep00xx.pth)
  58 +
  59 +```
  60 +
  61 +
  62 +Other references:
  63 +https://github.com/lipku/metahuman-stream/issues/43#issuecomment-2008930101
  64 +
  65 +
  1 +###############################################################################
  2 +# Copyright (C) 2024 LiveTalking@lipku https://github.com/lipku/LiveTalking
  3 +# email: lipku@foxmail.com
  4 +#
  5 +# Licensed under the Apache License, Version 2.0 (the "License");
  6 +# you may not use this file except in compliance with the License.
  7 +# You may obtain a copy of the License at
  8 +#
  9 +# http://www.apache.org/licenses/LICENSE-2.0
  10 +#
  11 +# Unless required by applicable law or agreed to in writing, software
  12 +# distributed under the License is distributed on an "AS IS" BASIS,
  13 +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14 +# See the License for the specific language governing permissions and
  15 +# limitations under the License.
  16 +###############################################################################
  17 +
  18 +import time
  19 +import numpy as np
  20 +
  21 +import queue
  22 +from queue import Queue
  23 +import torch.multiprocessing as mp
  24 +
  25 +from basereal import BaseReal
  26 +
  27 +
  28 +class BaseASR:
  29 + def __init__(self, opt, parent:BaseReal = None):
  30 + self.opt = opt
  31 + self.parent = parent
  32 +
  33 + self.fps = opt.fps # 20 ms per frame
  34 + self.sample_rate = 16000
  35 + self.chunk = self.sample_rate // self.fps # 320 samples per chunk (20ms * 16000 / 1000)
  36 + self.queue = Queue()
  37 + self.output_queue = mp.Queue()
  38 +
  39 + self.batch_size = opt.batch_size
  40 +
  41 + self.frames = []
  42 + self.stride_left_size = opt.l
  43 + self.stride_right_size = opt.r
  44 + #self.context_size = 10
  45 + self.feat_queue = mp.Queue(2)
  46 +
  47 + #self.warm_up()
  48 +
  49 + def flush_talk(self):
  50 + self.queue.queue.clear()
  51 +
  52 + def put_audio_frame(self,audio_chunk,eventpoint=None): #16khz 20ms pcm
  53 + self.queue.put((audio_chunk,eventpoint))
  54 +
  55 + #return frame:audio pcm; type: 0-normal speak, 1-silence; eventpoint:custom event sync with audio
  56 + def get_audio_frame(self):
  57 + try:
  58 + frame,eventpoint = self.queue.get(block=True,timeout=0.01)
  59 + type = 0
  60 + #print(f'[INFO] get frame {frame.shape}')
  61 + except queue.Empty:
  62 + if self.parent and self.parent.curr_state>1: #play custom audio
  63 + frame = self.parent.get_audio_stream(self.parent.curr_state)
  64 + type = self.parent.curr_state
  65 + else:
  66 + frame = np.zeros(self.chunk, dtype=np.float32)
  67 + type = 1
  68 + eventpoint = None
  69 +
  70 + return frame,type,eventpoint
  71 +
  72 + #return frame:audio pcm; type: 0-normal speak, 1-silence; eventpoint:custom event sync with audio
  73 + def get_audio_out(self):
  74 + return self.output_queue.get()
  75 +
  76 + def warm_up(self):
  77 + for _ in range(self.stride_left_size + self.stride_right_size):
  78 + audio_frame,type,eventpoint=self.get_audio_frame()
  79 + self.frames.append(audio_frame)
  80 + self.output_queue.put((audio_frame,type,eventpoint))
  81 + for _ in range(self.stride_left_size):
  82 + self.output_queue.get()
  83 +
  84 + def run_step(self):
  85 + pass
  86 +
  87 + def get_next_feat(self,block,timeout):
  88 + return self.feat_queue.get(block,timeout)
  1 +###############################################################################
  2 +# Copyright (C) 2024 LiveTalking@lipku https://github.com/lipku/LiveTalking
  3 +# email: lipku@foxmail.com
  4 +#
  5 +# Licensed under the Apache License, Version 2.0 (the "License");
  6 +# you may not use this file except in compliance with the License.
  7 +# You may obtain a copy of the License at
  8 +#
  9 +# http://www.apache.org/licenses/LICENSE-2.0
  10 +#
  11 +# Unless required by applicable law or agreed to in writing, software
  12 +# distributed under the License is distributed on an "AS IS" BASIS,
  13 +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14 +# See the License for the specific language governing permissions and
  15 +# limitations under the License.
  16 +###############################################################################
  17 +
  18 +import math
  19 +import torch
  20 +import numpy as np
  21 +
  22 +import subprocess
  23 +import os
  24 +import time
  25 +import cv2
  26 +import glob
  27 +import resampy
  28 +
  29 +import queue
  30 +from queue import Queue
  31 +from threading import Thread, Event
  32 +from io import BytesIO
  33 +import soundfile as sf
  34 +
  35 +import av
  36 +from fractions import Fraction
  37 +
  38 +from ttsreal import EdgeTTS,VoitsTTS,XTTS,CosyVoiceTTS,FishTTS,TencentTTS
  39 +from logger import logger
  40 +
  41 +from tqdm import tqdm
  42 +def read_imgs(img_list):
  43 + frames = []
  44 + logger.info('reading images...')
  45 + for img_path in tqdm(img_list):
  46 + frame = cv2.imread(img_path)
  47 + frames.append(frame)
  48 + return frames
  49 +
  50 +class BaseReal:
  51 + def __init__(self, opt):
  52 + self.opt = opt
  53 + self.sample_rate = 16000
  54 + self.chunk = self.sample_rate // opt.fps # 320 samples per chunk (20ms * 16000 / 1000)
  55 + self.sessionid = self.opt.sessionid
  56 +
  57 + if opt.tts == "edgetts":
  58 + self.tts = EdgeTTS(opt,self)
  59 + elif opt.tts == "gpt-sovits":
  60 + self.tts = VoitsTTS(opt,self)
  61 + elif opt.tts == "xtts":
  62 + self.tts = XTTS(opt,self)
  63 + elif opt.tts == "cosyvoice":
  64 + self.tts = CosyVoiceTTS(opt,self)
  65 + elif opt.tts == "fishtts":
  66 + self.tts = FishTTS(opt,self)
  67 + elif opt.tts == "tencent":
  68 + self.tts = TencentTTS(opt,self)
  69 +
  70 + self.speaking = False
  71 +
  72 + self.recording = False
  73 + self._record_video_pipe = None
  74 + self._record_audio_pipe = None
  75 + self.width = self.height = 0
  76 +
  77 + self.curr_state=0
  78 + self.custom_img_cycle = {}
  79 + self.custom_audio_cycle = {}
  80 + self.custom_audio_index = {}
  81 + self.custom_index = {}
  82 + self.custom_opt = {}
  83 + self.__loadcustom()
  84 +
  85 + def put_msg_txt(self,msg,eventpoint=None):
  86 + self.tts.put_msg_txt(msg,eventpoint)
  87 +
  88 + def put_audio_frame(self,audio_chunk,eventpoint=None): #16khz 20ms pcm
  89 + self.asr.put_audio_frame(audio_chunk,eventpoint)
  90 +
  91 + def put_audio_file(self,filebyte):
  92 + input_stream = BytesIO(filebyte)
  93 + stream = self.__create_bytes_stream(input_stream)
  94 + streamlen = stream.shape[0]
  95 + idx=0
  96 + while streamlen >= self.chunk: #and self.state==State.RUNNING
  97 + self.put_audio_frame(stream[idx:idx+self.chunk])
  98 + streamlen -= self.chunk
  99 + idx += self.chunk
  100 +
  101 + def __create_bytes_stream(self,byte_stream):
  102 + #byte_stream=BytesIO(buffer)
  103 + stream, sample_rate = sf.read(byte_stream) # [T*sample_rate,] float64
  104 + logger.info(f'[INFO]put audio stream {sample_rate}: {stream.shape}')
  105 + stream = stream.astype(np.float32)
  106 +
  107 + if stream.ndim > 1:
  108 + logger.info(f'[WARN] audio has {stream.shape[1]} channels, only use the first.')
  109 + stream = stream[:, 0]
  110 +
  111 + if sample_rate != self.sample_rate and stream.shape[0]>0:
  112 + logger.info(f'[WARN] audio sample rate is {sample_rate}, resampling into {self.sample_rate}.')
  113 + stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self.sample_rate)
  114 +
  115 + return stream
  116 +
  117 + def flush_talk(self):
  118 + self.tts.flush_talk()
  119 + self.asr.flush_talk()
  120 +
  121 + def is_speaking(self)->bool:
  122 + return self.speaking
  123 +
  124 + def __loadcustom(self):
  125 + for item in self.opt.customopt:
  126 + logger.info(item)
  127 + input_img_list = glob.glob(os.path.join(item['imgpath'], '*.[jpJP][pnPN]*[gG]'))
  128 + input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
  129 + self.custom_img_cycle[item['audiotype']] = read_imgs(input_img_list)
  130 + self.custom_audio_cycle[item['audiotype']], sample_rate = sf.read(item['audiopath'], dtype='float32')
  131 + self.custom_audio_index[item['audiotype']] = 0
  132 + self.custom_index[item['audiotype']] = 0
  133 + self.custom_opt[item['audiotype']] = item
  134 +
  135 + def init_customindex(self):
  136 + self.curr_state=0
  137 + for key in self.custom_audio_index:
  138 + self.custom_audio_index[key]=0
  139 + for key in self.custom_index:
  140 + self.custom_index[key]=0
  141 +
  142 + def notify(self,eventpoint):
  143 + logger.info("notify:%s",eventpoint)
  144 +
  145 + def start_recording(self):
  146 + """开始录制视频"""
  147 + if self.recording:
  148 + return
  149 +
  150 + command = ['ffmpeg',
  151 + '-y', '-an',
  152 + '-f', 'rawvideo',
  153 + '-vcodec','rawvideo',
  154 + '-pix_fmt', 'bgr24', # pixel format
  155 + '-s', "{}x{}".format(self.width, self.height),
  156 + '-r', str(25),
  157 + '-i', '-',
  158 + '-pix_fmt', 'yuv420p',
  159 + '-vcodec', "h264",
  160 + #'-f' , 'flv',
  161 + f'temp{self.opt.sessionid}.mp4']
  162 + self._record_video_pipe = subprocess.Popen(command, shell=False, stdin=subprocess.PIPE)
  163 +
  164 + acommand = ['ffmpeg',
  165 + '-y', '-vn',
  166 + '-f', 's16le',
  167 + #'-acodec','pcm_s16le',
  168 + '-ac', '1',
  169 + '-ar', '16000',
  170 + '-i', '-',
  171 + '-acodec', 'aac',
  172 + #'-f' , 'wav',
  173 + f'temp{self.opt.sessionid}.aac']
  174 + self._record_audio_pipe = subprocess.Popen(acommand, shell=False, stdin=subprocess.PIPE)
  175 +
  176 + self.recording = True
  177 + # self.recordq_video.queue.clear()
  178 + # self.recordq_audio.queue.clear()
  179 + # self.container = av.open(path, mode="w")
  180 +
  181 + # process_thread = Thread(target=self.record_frame, args=())
  182 + # process_thread.start()
  183 +
  184 + def record_video_data(self,image):
  185 + if self.width == 0:
  186 + print("image.shape:",image.shape)
  187 + self.height,self.width,_ = image.shape
  188 + if self.recording:
  189 + self._record_video_pipe.stdin.write(image.tobytes())
  190 +
  191 + def record_audio_data(self,frame):
  192 + if self.recording:
  193 + self._record_audio_pipe.stdin.write(frame.tobytes())
  194 +
  195 + # def record_frame(self):
  196 + # videostream = self.container.add_stream("libx264", rate=25)
  197 + # videostream.codec_context.time_base = Fraction(1, 25)
  198 + # audiostream = self.container.add_stream("aac")
  199 + # audiostream.codec_context.time_base = Fraction(1, 16000)
  200 + # init = True
  201 + # framenum = 0
  202 + # while self.recording:
  203 + # try:
  204 + # videoframe = self.recordq_video.get(block=True, timeout=1)
  205 + # videoframe.pts = framenum #int(round(framenum*0.04 / videostream.codec_context.time_base))
  206 + # videoframe.dts = videoframe.pts
  207 + # if init:
  208 + # videostream.width = videoframe.width
  209 + # videostream.height = videoframe.height
  210 + # init = False
  211 + # for packet in videostream.encode(videoframe):
  212 + # self.container.mux(packet)
  213 + # for k in range(2):
  214 + # audioframe = self.recordq_audio.get(block=True, timeout=1)
  215 + # audioframe.pts = int(round((framenum*2+k)*0.02 / audiostream.codec_context.time_base))
  216 + # audioframe.dts = audioframe.pts
  217 + # for packet in audiostream.encode(audioframe):
  218 + # self.container.mux(packet)
  219 + # framenum += 1
  220 + # except queue.Empty:
  221 + # print('record queue empty,')
  222 + # continue
  223 + # except Exception as e:
  224 + # print(e)
  225 + # #break
  226 + # for packet in videostream.encode(None):
  227 + # self.container.mux(packet)
  228 + # for packet in audiostream.encode(None):
  229 + # self.container.mux(packet)
  230 + # self.container.close()
  231 + # self.recordq_video.queue.clear()
  232 + # self.recordq_audio.queue.clear()
  233 + # print('record thread stop')
  234 +
  235 + def stop_recording(self):
  236 + """停止录制视频"""
  237 + if not self.recording:
  238 + return
  239 + self.recording = False
  240 + self._record_video_pipe.stdin.close() #wait()
  241 + self._record_video_pipe.wait()
  242 + self._record_audio_pipe.stdin.close()
  243 + self._record_audio_pipe.wait()
  244 + cmd_combine_audio = f"ffmpeg -y -i temp{self.opt.sessionid}.aac -i temp{self.opt.sessionid}.mp4 -c:v copy -c:a copy data/record.mp4"
  245 + os.system(cmd_combine_audio)
  246 + #os.remove(output_path)
  247 +
  248 + def mirror_index(self,size, index):
  249 + #size = len(self.coord_list_cycle)
  250 + turn = index // size
  251 + res = index % size
  252 + if turn % 2 == 0:
  253 + return res
  254 + else:
  255 + return size - res - 1
  256 +
  257 + def get_audio_stream(self,audiotype):
  258 + idx = self.custom_audio_index[audiotype]
  259 + stream = self.custom_audio_cycle[audiotype][idx:idx+self.chunk]
  260 + self.custom_audio_index[audiotype] += self.chunk
  261 + if self.custom_audio_index[audiotype]>=self.custom_audio_cycle[audiotype].shape[0]:
  262 + self.curr_state = 1 # this clip does not loop; switch back to the silent state
  263 + return stream
  264 +
  265 + def set_curr_state(self,audiotype, reinit):
  266 + print('set_curr_state:',audiotype)
  267 + self.curr_state = audiotype
  268 + if reinit:
  269 + self.custom_audio_index[audiotype] = 0
  270 + self.custom_index[audiotype] = 0
  271 +
  272 + # def process_custom(self,audiotype:int,idx:int):
  273 + # if self.curr_state!=audiotype: # switch from inference to custom (scripted) playback
  274 + # if idx in self.switch_pos: # switching is only allowed at a sync point
  275 + # self.curr_state=audiotype
  276 + # self.custom_index=0
  277 + # else:
  278 + # self.custom_index+=1
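
The recording helpers above pipe raw BGR frames and 16 kHz PCM samples into two ffmpeg subprocesses and then mux the temporary files in `stop_recording`. As a minimal, self-contained sketch of the same raw-video piping pattern (the resolution, frame rate, codec flags and output name here are illustrative, not this class's exact command):

```
import subprocess
import numpy as np

width, height, fps = 512, 512, 25
command = ['ffmpeg', '-y',
           '-f', 'rawvideo', '-pix_fmt', 'bgr24',
           '-s', f'{width}x{height}', '-r', str(fps),
           '-i', '-',                      # raw frames arrive on stdin
           '-c:v', 'libx264', '-pix_fmt', 'yuv420p',
           'demo_record.mp4']
pipe = subprocess.Popen(command, stdin=subprocess.PIPE)

for _ in range(fps * 2):                   # two seconds of black frames
    frame = np.zeros((height, width, 3), dtype=np.uint8)
    pipe.stdin.write(frame.tobytes())      # same role as record_video_data()

pipe.stdin.close()                         # same role as stop_recording()
pipe.wait()
```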
  1 +# Routines for DeepSpeech features processing
  2 +Several routines for [DeepSpeech](https://github.com/mozilla/DeepSpeech) features processing, like speech features generation for [VOCA](https://github.com/TimoBolkart/voca) model.
  3 +
  4 +## Installation
  5 +
  6 +```
  7 +pip3 install -r requirements.txt
  8 +```
  9 +
  10 +## Usage
  11 +
  12 +Generate wav files:
  13 +```
  14 +python3 extract_wav.py --in-video=<your_data_dir>
  15 +```
  16 +
  17 +Generate files with DeepSpeech features:
  18 +```
  19 +python3 extract_ds_features.py --input=<your_data_dir>
  20 +```
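
For reference, the `.npy` written by `extract_ds_features.py` holds windowed DeepSpeech logits: following the windowing code in `deepspeech_features.py` below, each entry is a 16-step window over the 29-dimensional logit vector, so the saved array has shape `(N, 16, 29)`. A quick sanity check (the file name is illustrative; the output is written next to the input `.wav`):

```
import numpy as np

feats = np.load('aud.npy')   # produced next to the input .wav
print(feats.shape)           # expected: (N, 16, 29)
```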
  1 +"""
  2 + DeepSpeech features processing routines.
  3 + NB: Based on VOCA code. See the corresponding license restrictions.
  4 +"""
  5 +
  6 +__all__ = ['conv_audios_to_deepspeech']
  7 +
  8 +import numpy as np
  9 +import warnings
  10 +import resampy
  11 +from scipy.io import wavfile
  12 +from python_speech_features import mfcc
  13 +import tensorflow.compat.v1 as tf
  14 +tf.disable_v2_behavior()
  15 +
  16 +def conv_audios_to_deepspeech(audios,
  17 + out_files,
  18 + num_frames_info,
  19 + deepspeech_pb_path,
  20 + audio_window_size=1,
  21 + audio_window_stride=1):
  22 + """
  23 + Convert list of audio files into files with DeepSpeech features.
  24 +
  25 + Parameters
  26 + ----------
  27 + audios : list of str or list of None
  28 + Paths to input audio files.
  29 + out_files : list of str
  30 + Paths to output files with DeepSpeech features.
  31 + num_frames_info : list of int
  32 + List of numbers of frames.
  33 + deepspeech_pb_path : str
  34 + Path to DeepSpeech 0.1.0 frozen model.
   35 + audio_window_size : int, default 1
  36 + Audio window size.
  37 + audio_window_stride : int, default 1
  38 + Audio window stride.
  39 + """
  40 + # deepspeech_pb_path="/disk4/keyu/DeepSpeech/deepspeech-0.9.2-models.pbmm"
  41 + graph, logits_ph, input_node_ph, input_lengths_ph = prepare_deepspeech_net(
  42 + deepspeech_pb_path)
  43 +
  44 + with tf.compat.v1.Session(graph=graph) as sess:
  45 + for audio_file_path, out_file_path, num_frames in zip(audios, out_files, num_frames_info):
  46 + print(audio_file_path)
  47 + print(out_file_path)
  48 + audio_sample_rate, audio = wavfile.read(audio_file_path)
  49 + if audio.ndim != 1:
  50 + warnings.warn(
  51 + "Audio has multiple channels, the first channel is used")
  52 + audio = audio[:, 0]
  53 + ds_features = pure_conv_audio_to_deepspeech(
  54 + audio=audio,
  55 + audio_sample_rate=audio_sample_rate,
  56 + audio_window_size=audio_window_size,
  57 + audio_window_stride=audio_window_stride,
  58 + num_frames=num_frames,
  59 + net_fn=lambda x: sess.run(
  60 + logits_ph,
  61 + feed_dict={
  62 + input_node_ph: x[np.newaxis, ...],
  63 + input_lengths_ph: [x.shape[0]]}))
  64 +
  65 + net_output = ds_features.reshape(-1, 29)
  66 + win_size = 16
  67 + zero_pad = np.zeros((int(win_size / 2), net_output.shape[1]))
  68 + net_output = np.concatenate(
  69 + (zero_pad, net_output, zero_pad), axis=0)
  70 + windows = []
  71 + for window_index in range(0, net_output.shape[0] - win_size, 2):
  72 + windows.append(
  73 + net_output[window_index:window_index + win_size])
  74 + print(np.array(windows).shape)
  75 + np.save(out_file_path, np.array(windows))
  76 +
  77 +
  78 +def prepare_deepspeech_net(deepspeech_pb_path):
  79 + """
  80 + Load and prepare DeepSpeech network.
  81 +
  82 + Parameters
  83 + ----------
  84 + deepspeech_pb_path : str
  85 + Path to DeepSpeech 0.1.0 frozen model.
  86 +
  87 + Returns
  88 + -------
  89 + graph : obj
   90 + TensorFlow graph.
   91 + logits_ph : obj
   92 + TensorFlow placeholder for `logits`.
   93 + input_node_ph : obj
   94 + TensorFlow placeholder for `input_node`.
   95 + input_lengths_ph : obj
   96 + TensorFlow placeholder for `input_lengths`.
  97 + """
  98 + # Load graph and place_holders:
  99 + with tf.io.gfile.GFile(deepspeech_pb_path, "rb") as f:
  100 + graph_def = tf.compat.v1.GraphDef()
  101 + graph_def.ParseFromString(f.read())
  102 +
  103 + graph = tf.compat.v1.get_default_graph()
  104 + tf.import_graph_def(graph_def, name="deepspeech")
  105 + logits_ph = graph.get_tensor_by_name("deepspeech/logits:0")
  106 + input_node_ph = graph.get_tensor_by_name("deepspeech/input_node:0")
  107 + input_lengths_ph = graph.get_tensor_by_name("deepspeech/input_lengths:0")
  108 +
  109 + return graph, logits_ph, input_node_ph, input_lengths_ph
  110 +
  111 +
  112 +def pure_conv_audio_to_deepspeech(audio,
  113 + audio_sample_rate,
  114 + audio_window_size,
  115 + audio_window_stride,
  116 + num_frames,
  117 + net_fn):
  118 + """
  119 + Core routine for converting audio into DeepSpeech features.
  120 +
  121 + Parameters
  122 + ----------
  123 + audio : np.array
  124 + Audio data.
  125 + audio_sample_rate : int
  126 + Audio sample rate.
  127 + audio_window_size : int
  128 + Audio window size.
  129 + audio_window_stride : int
  130 + Audio window stride.
  131 + num_frames : int or None
  132 + Numbers of frames.
  133 + net_fn : func
  134 + Function for DeepSpeech model call.
  135 +
  136 + Returns
  137 + -------
  138 + np.array
  139 + DeepSpeech features.
  140 + """
  141 + target_sample_rate = 16000
  142 + if audio_sample_rate != target_sample_rate:
  143 + resampled_audio = resampy.resample(
  144 + x=audio.astype(np.float64),
  145 + sr_orig=audio_sample_rate,
  146 + sr_new=target_sample_rate)
  147 + else:
  148 + resampled_audio = audio.astype(np.float64)
  149 + input_vector = conv_audio_to_deepspeech_input_vector(
  150 + audio=resampled_audio.astype(np.int16),
  151 + sample_rate=target_sample_rate,
  152 + num_cepstrum=26,
  153 + num_context=9)
  154 +
  155 + network_output = net_fn(input_vector)
  156 + # print(network_output.shape)
  157 +
  158 + deepspeech_fps = 50
  159 + video_fps = 50 # Change this option if video fps is different
  160 + audio_len_s = float(audio.shape[0]) / audio_sample_rate
  161 + if num_frames is None:
  162 + num_frames = int(round(audio_len_s * video_fps))
  163 + else:
  164 + video_fps = num_frames / audio_len_s
  165 + network_output = interpolate_features(
  166 + features=network_output[:, 0],
  167 + input_rate=deepspeech_fps,
  168 + output_rate=video_fps,
  169 + output_len=num_frames)
  170 +
  171 + # Make windows:
  172 + zero_pad = np.zeros((int(audio_window_size / 2), network_output.shape[1]))
  173 + network_output = np.concatenate(
  174 + (zero_pad, network_output, zero_pad), axis=0)
  175 + windows = []
  176 + for window_index in range(0, network_output.shape[0] - audio_window_size, audio_window_stride):
  177 + windows.append(
  178 + network_output[window_index:window_index + audio_window_size])
  179 +
  180 + return np.array(windows)
  181 +
  182 +
  183 +def conv_audio_to_deepspeech_input_vector(audio,
  184 + sample_rate,
  185 + num_cepstrum,
  186 + num_context):
  187 + """
  188 + Convert audio raw data into DeepSpeech input vector.
  189 +
  190 + Parameters
  191 + ----------
  192 + audio : np.array
  193 + Audio data.
  194 + sample_rate : int
  195 + Audio sample rate.
  196 + num_cepstrum : int
  197 + Number of cepstrum.
  198 + num_context : int
  199 + Number of context.
  200 +
  201 + Returns
  202 + -------
  203 + np.array
  204 + DeepSpeech input vector.
  205 + """
  206 + # Get mfcc coefficients:
  207 + features = mfcc(
  208 + signal=audio,
  209 + samplerate=sample_rate,
  210 + numcep=num_cepstrum)
  211 +
  212 + # We only keep every second feature (BiRNN stride = 2):
  213 + features = features[::2]
  214 +
  215 + # One stride per time step in the input:
  216 + num_strides = len(features)
  217 +
  218 + # Add empty initial and final contexts:
  219 + empty_context = np.zeros((num_context, num_cepstrum), dtype=features.dtype)
  220 + features = np.concatenate((empty_context, features, empty_context))
  221 +
  222 + # Create a view into the array with overlapping strides of size
  223 + # numcontext (past) + 1 (present) + numcontext (future):
  224 + window_size = 2 * num_context + 1
  225 + train_inputs = np.lib.stride_tricks.as_strided(
  226 + features,
  227 + shape=(num_strides, window_size, num_cepstrum),
  228 + strides=(features.strides[0],
  229 + features.strides[0], features.strides[1]),
  230 + writeable=False)
  231 +
  232 + # Flatten the second and third dimensions:
  233 + train_inputs = np.reshape(train_inputs, [num_strides, -1])
  234 +
  235 + train_inputs = np.copy(train_inputs)
  236 + train_inputs = (train_inputs - np.mean(train_inputs)) / \
  237 + np.std(train_inputs)
  238 +
  239 + return train_inputs
  240 +
  241 +
  242 +def interpolate_features(features,
  243 + input_rate,
  244 + output_rate,
  245 + output_len):
  246 + """
  247 + Interpolate DeepSpeech features.
  248 +
  249 + Parameters
  250 + ----------
  251 + features : np.array
  252 + DeepSpeech features.
  253 + input_rate : int
  254 + input rate (FPS).
  255 + output_rate : int
  256 + Output rate (FPS).
  257 + output_len : int
  258 + Output data length.
  259 +
  260 + Returns
  261 + -------
  262 + np.array
  263 + Interpolated data.
  264 + """
  265 + input_len = features.shape[0]
  266 + num_features = features.shape[1]
  267 + input_timestamps = np.arange(input_len) / float(input_rate)
  268 + output_timestamps = np.arange(output_len) / float(output_rate)
  269 + output_features = np.zeros((output_len, num_features))
  270 + for feature_idx in range(num_features):
  271 + output_features[:, feature_idx] = np.interp(
  272 + x=output_timestamps,
  273 + xp=input_timestamps,
  274 + fp=features[:, feature_idx])
  275 + return output_features
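
`interpolate_features` resamples each feature channel independently with `np.interp`, mapping the DeepSpeech frame rate (50 fps) onto the video frame rate. A small standalone sketch of the same idea on synthetic data:

```
import numpy as np

feats_50fps = np.random.randn(100, 29)        # 2 s of 50 fps DeepSpeech logits
in_t = np.arange(100) / 50.0
out_len = 50                                  # resample to 25 fps over the same 2 s
out_t = np.arange(out_len) / 25.0
resampled = np.stack(
    [np.interp(out_t, in_t, feats_50fps[:, c]) for c in range(feats_50fps.shape[1])],
    axis=1)
print(resampled.shape)                        # (50, 29)
```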
  1 +"""
  2 + Routines for loading DeepSpeech model.
  3 +"""
  4 +
  5 +__all__ = ['get_deepspeech_model_file']
  6 +
  7 +import os
  8 +import zipfile
  9 +import logging
  10 +import hashlib
  11 +
  12 +
  13 +deepspeech_features_repo_url = 'https://github.com/osmr/deepspeech_features'
  14 +
  15 +
  16 +def get_deepspeech_model_file(local_model_store_dir_path=os.path.join("~", ".tensorflow", "models")):
  17 + """
   18 + Return the location of the pretrained model on the local file system. This function will download it from the online
   19 + model zoo when the model cannot be found or its hash does not match. The root directory will be created if it doesn't exist.
  20 +
  21 + Parameters
  22 + ----------
  23 + local_model_store_dir_path : str, default $TENSORFLOW_HOME/models
  24 + Location for keeping the model parameters.
  25 +
  26 + Returns
  27 + -------
  28 + file_path
  29 + Path to the requested pretrained model file.
  30 + """
  31 + sha1_hash = "b90017e816572ddce84f5843f1fa21e6a377975e"
  32 + file_name = "deepspeech-0_1_0-b90017e8.pb"
  33 + local_model_store_dir_path = os.path.expanduser(local_model_store_dir_path)
  34 + file_path = os.path.join(local_model_store_dir_path, file_name)
  35 + if os.path.exists(file_path):
  36 + if _check_sha1(file_path, sha1_hash):
  37 + return file_path
  38 + else:
  39 + logging.warning("Mismatch in the content of model file detected. Downloading again.")
  40 + else:
  41 + logging.info("Model file not found. Downloading to {}.".format(file_path))
  42 +
  43 + if not os.path.exists(local_model_store_dir_path):
  44 + os.makedirs(local_model_store_dir_path)
  45 +
  46 + zip_file_path = file_path + ".zip"
  47 + _download(
  48 + url="{repo_url}/releases/download/{repo_release_tag}/{file_name}.zip".format(
  49 + repo_url=deepspeech_features_repo_url,
  50 + repo_release_tag="v0.0.1",
  51 + file_name=file_name),
  52 + path=zip_file_path,
  53 + overwrite=True)
  54 + with zipfile.ZipFile(zip_file_path) as zf:
  55 + zf.extractall(local_model_store_dir_path)
  56 + os.remove(zip_file_path)
  57 +
  58 + if _check_sha1(file_path, sha1_hash):
  59 + return file_path
  60 + else:
  61 + raise ValueError("Downloaded file has different hash. Please try again.")
  62 +
  63 +
  64 +def _download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_ssl=True):
  65 + """
   66 + Download a file from the given URL.
  67 +
  68 + Parameters
  69 + ----------
  70 + url : str
  71 + URL to download
  72 + path : str, optional
  73 + Destination path to store downloaded file. By default stores to the
  74 + current directory with same name as in url.
  75 + overwrite : bool, optional
  76 + Whether to overwrite destination file if already exists.
  77 + sha1_hash : str, optional
  78 + Expected sha1 hash in hexadecimal digits. Will ignore existing file when hash is specified
  79 + but doesn't match.
  80 + retries : integer, default 5
   81 + The number of times to attempt the download in case of failure or non-200 return codes
  82 + verify_ssl : bool, default True
  83 + Verify SSL certificates.
  84 +
  85 + Returns
  86 + -------
  87 + str
  88 + The file path of the downloaded file.
  89 + """
  90 + import warnings
  91 + try:
  92 + import requests
  93 + except ImportError:
  94 + class requests_failed_to_import(object):
  95 + pass
  96 + requests = requests_failed_to_import
  97 +
  98 + if path is None:
  99 + fname = url.split("/")[-1]
  100 + # Empty filenames are invalid
  101 + assert fname, "Can't construct file-name from this URL. Please set the `path` option manually."
  102 + else:
  103 + path = os.path.expanduser(path)
  104 + if os.path.isdir(path):
  105 + fname = os.path.join(path, url.split("/")[-1])
  106 + else:
  107 + fname = path
  108 + assert retries >= 0, "Number of retries should be at least 0"
  109 +
  110 + if not verify_ssl:
  111 + warnings.warn(
  112 + "Unverified HTTPS request is being made (verify_ssl=False). "
  113 + "Adding certificate verification is strongly advised.")
  114 +
  115 + if overwrite or not os.path.exists(fname) or (sha1_hash and not _check_sha1(fname, sha1_hash)):
  116 + dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname)))
  117 + if not os.path.exists(dirname):
  118 + os.makedirs(dirname)
  119 + while retries + 1 > 0:
  120 + # Disable pylint broad-except (W0703) warning
  121 + # pylint: disable=W0703
  122 + try:
  123 + print("Downloading {} from {}...".format(fname, url))
  124 + r = requests.get(url, stream=True, verify=verify_ssl)
  125 + if r.status_code != 200:
  126 + raise RuntimeError("Failed downloading url {}".format(url))
  127 + with open(fname, "wb") as f:
  128 + for chunk in r.iter_content(chunk_size=1024):
  129 + if chunk: # filter out keep-alive new chunks
  130 + f.write(chunk)
  131 + if sha1_hash and not _check_sha1(fname, sha1_hash):
  132 + raise UserWarning("File {} is downloaded but the content hash does not match."
  133 + " The repo may be outdated or download may be incomplete. "
  134 + "If the `repo_url` is overridden, consider switching to "
  135 + "the default repo.".format(fname))
  136 + break
  137 + except Exception as e:
  138 + retries -= 1
  139 + if retries <= 0:
  140 + raise e
  141 + else:
  142 + print("download failed, retrying, {} attempt{} left"
  143 + .format(retries, "s" if retries > 1 else ""))
  144 +
  145 + return fname
  146 +
  147 +
  148 +def _check_sha1(filename, sha1_hash):
  149 + """
  150 + Check whether the sha1 hash of the file content matches the expected hash.
  151 +
  152 + Parameters
  153 + ----------
  154 + filename : str
  155 + Path to the file.
  156 + sha1_hash : str
  157 + Expected sha1 hash in hexadecimal digits.
  158 +
  159 + Returns
  160 + -------
  161 + bool
  162 + Whether the file content matches the expected hash.
  163 + """
  164 + sha1 = hashlib.sha1()
  165 + with open(filename, "rb") as f:
  166 + while True:
  167 + data = f.read(1048576)
  168 + if not data:
  169 + break
  170 + sha1.update(data)
  171 +
  172 + return sha1.hexdigest() == sha1_hash
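
`_check_sha1` streams the file in 1 MiB chunks and compares the hex digest against the value hard-coded in `get_deepspeech_model_file`. The same verification can be done by hand on a downloaded model file (the path in the commented line is illustrative):

```
import hashlib

def sha1_of(path, chunk_size=1 << 20):
    h = hashlib.sha1()
    with open(path, 'rb') as f:
        for chunk in iter(lambda: f.read(chunk_size), b''):
            h.update(chunk)
    return h.hexdigest()

expected = 'b90017e816572ddce84f5843f1fa21e6a377975e'   # value from get_deepspeech_model_file
# print(sha1_of('deepspeech-0_1_0-b90017e8.pb') == expected)
```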
  1 +"""
  2 + Script for extracting DeepSpeech features from audio file.
  3 +"""
  4 +
  5 +import os
  6 +import argparse
  7 +import numpy as np
  8 +import pandas as pd
  9 +from deepspeech_store import get_deepspeech_model_file
  10 +from deepspeech_features import conv_audios_to_deepspeech
  11 +
  12 +
  13 +def parse_args():
  14 + """
  15 + Create python script parameters.
  16 + Returns
  17 + -------
  18 + ArgumentParser
  19 + Resulted args.
  20 + """
  21 + parser = argparse.ArgumentParser(
  22 + description="Extract DeepSpeech features from audio file",
  23 + formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  24 + parser.add_argument(
  25 + "--input",
  26 + type=str,
  27 + required=True,
  28 + help="path to input audio file or directory")
  29 + parser.add_argument(
  30 + "--output",
  31 + type=str,
  32 + help="path to output file with DeepSpeech features")
  33 + parser.add_argument(
  34 + "--deepspeech",
  35 + type=str,
  36 + help="path to DeepSpeech 0.1.0 frozen model")
  37 + parser.add_argument(
  38 + "--metainfo",
  39 + type=str,
  40 + help="path to file with meta-information")
  41 +
  42 + args = parser.parse_args()
  43 + return args
  44 +
  45 +
  46 +def extract_features(in_audios,
  47 + out_files,
  48 + deepspeech_pb_path,
  49 + metainfo_file_path=None):
  50 + """
   51 + Extract DeepSpeech features from audio files.
  52 + Parameters
  53 + ----------
  54 + in_audios : list of str
  55 + Paths to input audio files.
  56 + out_files : list of str
  57 + Paths to output files with DeepSpeech features.
  58 + deepspeech_pb_path : str
  59 + Path to DeepSpeech 0.1.0 frozen model.
  60 + metainfo_file_path : str, default None
  61 + Path to file with meta-information.
  62 + """
  63 + #deepspeech_pb_path="/disk4/keyu/DeepSpeech/deepspeech-0.9.2-models.pbmm"
  64 + if metainfo_file_path is None:
  65 + num_frames_info = [None] * len(in_audios)
  66 + else:
  67 + train_df = pd.read_csv(
  68 + metainfo_file_path,
  69 + sep="\t",
  70 + index_col=False,
   71 + dtype={"Id": int, "File": str, "Count": int})
  72 + num_frames_info = train_df["Count"].values
  73 + assert (len(num_frames_info) == len(in_audios))
  74 +
  75 + for i, in_audio in enumerate(in_audios):
  76 + if not out_files[i]:
  77 + file_stem, _ = os.path.splitext(in_audio)
  78 + out_files[i] = file_stem + ".npy"
  79 + #print(out_files[i])
  80 + conv_audios_to_deepspeech(
  81 + audios=in_audios,
  82 + out_files=out_files,
  83 + num_frames_info=num_frames_info,
  84 + deepspeech_pb_path=deepspeech_pb_path)
  85 +
  86 +
  87 +def main():
  88 + """
  89 + Main body of script.
  90 + """
  91 + args = parse_args()
  92 + in_audio = os.path.expanduser(args.input)
  93 + if not os.path.exists(in_audio):
  94 + raise Exception("Input file/directory doesn't exist: {}".format(in_audio))
  95 + deepspeech_pb_path = args.deepspeech
   96 + # override: always use the cached DeepSpeech 0.1.0 model path (downloaded below if missing)
  97 + deepspeech_pb_path = True
  98 + args.deepspeech = '~/.tensorflow/models/deepspeech-0_1_0-b90017e8.pb'
  99 + #deepspeech_pb_path="/disk4/keyu/DeepSpeech/deepspeech-0.9.2-models.pbmm"
  100 + if deepspeech_pb_path is None:
  101 + deepspeech_pb_path = ""
  102 + if deepspeech_pb_path:
  103 + deepspeech_pb_path = os.path.expanduser(args.deepspeech)
  104 + if not os.path.exists(deepspeech_pb_path):
  105 + deepspeech_pb_path = get_deepspeech_model_file()
  106 + if os.path.isfile(in_audio):
  107 + extract_features(
  108 + in_audios=[in_audio],
  109 + out_files=[args.output],
  110 + deepspeech_pb_path=deepspeech_pb_path,
  111 + metainfo_file_path=args.metainfo)
  112 + else:
  113 + audio_file_paths = []
  114 + for file_name in os.listdir(in_audio):
  115 + if not os.path.isfile(os.path.join(in_audio, file_name)):
  116 + continue
  117 + _, file_ext = os.path.splitext(file_name)
  118 + if file_ext.lower() == ".wav":
  119 + audio_file_path = os.path.join(in_audio, file_name)
  120 + audio_file_paths.append(audio_file_path)
  121 + audio_file_paths = sorted(audio_file_paths)
  122 + out_file_paths = [""] * len(audio_file_paths)
  123 + extract_features(
  124 + in_audios=audio_file_paths,
  125 + out_files=out_file_paths,
  126 + deepspeech_pb_path=deepspeech_pb_path,
  127 + metainfo_file_path=args.metainfo)
  128 +
  129 +
  130 +if __name__ == "__main__":
  131 + main()
  132 +
  1 +"""
   2 + Script for extracting audio (16-bit, mono, 16000 Hz) from video file.
  3 +"""
  4 +
  5 +import os
  6 +import argparse
  7 +import subprocess
  8 +
  9 +
  10 +def parse_args():
  11 + """
  12 + Create python script parameters.
  13 +
  14 + Returns
  15 + -------
  16 + ArgumentParser
   17 + Resulting arguments.
  18 + """
  19 + parser = argparse.ArgumentParser(
  20 + description="Extract audio from video file",
  21 + formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  22 + parser.add_argument(
  23 + "--in-video",
  24 + type=str,
  25 + required=True,
  26 + help="path to input video file or directory")
  27 + parser.add_argument(
  28 + "--out-audio",
  29 + type=str,
  30 + help="path to output audio file")
  31 +
  32 + args = parser.parse_args()
  33 + return args
  34 +
  35 +
  36 +def extract_audio(in_video,
  37 + out_audio):
  38 + """
   39 + Extract audio from a video file.
  40 +
  41 + Parameters
  42 + ----------
  43 + in_video : str
  44 + Path to input video file.
  45 + out_audio : str
  46 + Path to output audio file.
  47 + """
  48 + if not out_audio:
  49 + file_stem, _ = os.path.splitext(in_video)
  50 + out_audio = file_stem + ".wav"
  51 + # command1 = "ffmpeg -i {in_video} -vn -acodec copy {aac_audio}"
  52 + # command2 = "ffmpeg -i {aac_audio} -vn -acodec pcm_s16le -ac 1 -ar 22000 {out_audio}"
  53 + # command = "ffmpeg -i {in_video} -vn -acodec pcm_s16le -ac 1 -ar 22000 {out_audio}"
  54 + command = "ffmpeg -i {in_video} -vn -acodec pcm_s16le -ac 1 -ar 16000 {out_audio}"
  55 + subprocess.call([command.format(in_video=in_video, out_audio=out_audio)], shell=True)
  56 +
  57 +
  58 +def main():
  59 + """
  60 + Main body of script.
  61 + """
  62 + args = parse_args()
  63 + in_video = os.path.expanduser(args.in_video)
  64 + if not os.path.exists(in_video):
  65 + raise Exception("Input file/directory doesn't exist: {}".format(in_video))
  66 + if os.path.isfile(in_video):
  67 + extract_audio(
  68 + in_video=in_video,
  69 + out_audio=args.out_audio)
  70 + else:
  71 + video_file_paths = []
  72 + for file_name in os.listdir(in_video):
  73 + if not os.path.isfile(os.path.join(in_video, file_name)):
  74 + continue
  75 + _, file_ext = os.path.splitext(file_name)
  76 + if file_ext.lower() in (".mp4", ".mkv", ".avi"):
  77 + video_file_path = os.path.join(in_video, file_name)
  78 + video_file_paths.append(video_file_path)
  79 + video_file_paths = sorted(video_file_paths)
  80 + for video_file_path in video_file_paths:
  81 + extract_audio(
  82 + in_video=video_file_path,
  83 + out_audio="")
  84 +
  85 +
  86 +if __name__ == "__main__":
  87 + main()
  1 +import numpy as np
  2 +
  3 +net_output = np.load('french.ds.npy').reshape(-1, 29)
  4 +win_size = 16
  5 +zero_pad = np.zeros((int(win_size / 2), net_output.shape[1]))
  6 +net_output = np.concatenate((zero_pad, net_output, zero_pad), axis=0)
  7 +windows = []
  8 +for window_index in range(0, net_output.shape[0] - win_size, 2):
  9 + windows.append(net_output[window_index:window_index + win_size])
  10 +print(np.array(windows).shape)
  11 +np.save('aud_french.npy', np.array(windows))
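
With `win_size = 16` and a stride of 2 over the zero-padded logits, this produces roughly one window per 25 fps video frame, i.e. half the 50 fps DeepSpeech rate. A quick check of the shape arithmetic on synthetic data:

```
import numpy as np

net_output = np.random.randn(500, 29)            # ~10 s of 50 fps DeepSpeech logits
win_size = 16
pad = np.zeros((win_size // 2, net_output.shape[1]))
padded = np.concatenate((pad, net_output, pad), axis=0)
windows = [padded[i:i + win_size] for i in range(0, padded.shape[0] - win_size, 2)]
print(np.array(windows).shape)                   # (250, 16, 29) -> about 25 windows per second
```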
  1 +#!/usr/bin/python
  2 +# -*- encoding: utf-8 -*-
  3 +
  4 +
  5 +import os.path as osp
  6 +import time
  7 +import sys
  8 +import logging
  9 +
  10 +import torch.distributed as dist
  11 +
  12 +
  13 +def setup_logger(logpth):
  14 + logfile = 'BiSeNet-{}.log'.format(time.strftime('%Y-%m-%d-%H-%M-%S'))
  15 + logfile = osp.join(logpth, logfile)
  16 + FORMAT = '%(levelname)s %(filename)s(%(lineno)d): %(message)s'
  17 + log_level = logging.INFO
  18 + if dist.is_initialized() and not dist.get_rank()==0:
  19 + log_level = logging.ERROR
  20 + logging.basicConfig(level=log_level, format=FORMAT, filename=logfile)
  21 + logging.root.addHandler(logging.StreamHandler())
  22 +
  23 +
  1 +#!/usr/bin/python
  2 +# -*- encoding: utf-8 -*-
  3 +
  4 +
  5 +import torch
  6 +import torch.nn as nn
  7 +import torch.nn.functional as F
  8 +import torchvision
  9 +
  10 +from resnet import Resnet18
  11 +# from modules.bn import InPlaceABNSync as BatchNorm2d
  12 +
  13 +
  14 +class ConvBNReLU(nn.Module):
  15 + def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs):
  16 + super(ConvBNReLU, self).__init__()
  17 + self.conv = nn.Conv2d(in_chan,
  18 + out_chan,
  19 + kernel_size = ks,
  20 + stride = stride,
  21 + padding = padding,
  22 + bias = False)
  23 + self.bn = nn.BatchNorm2d(out_chan)
  24 + self.init_weight()
  25 +
  26 + def forward(self, x):
  27 + x = self.conv(x)
  28 + x = F.relu(self.bn(x))
  29 + return x
  30 +
  31 + def init_weight(self):
  32 + for ly in self.children():
  33 + if isinstance(ly, nn.Conv2d):
  34 + nn.init.kaiming_normal_(ly.weight, a=1)
  35 + if not ly.bias is None: nn.init.constant_(ly.bias, 0)
  36 +
  37 +class BiSeNetOutput(nn.Module):
  38 + def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs):
  39 + super(BiSeNetOutput, self).__init__()
  40 + self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
  41 + self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False)
  42 + self.init_weight()
  43 +
  44 + def forward(self, x):
  45 + x = self.conv(x)
  46 + x = self.conv_out(x)
  47 + return x
  48 +
  49 + def init_weight(self):
  50 + for ly in self.children():
  51 + if isinstance(ly, nn.Conv2d):
  52 + nn.init.kaiming_normal_(ly.weight, a=1)
  53 + if not ly.bias is None: nn.init.constant_(ly.bias, 0)
  54 +
  55 + def get_params(self):
  56 + wd_params, nowd_params = [], []
  57 + for name, module in self.named_modules():
  58 + if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
  59 + wd_params.append(module.weight)
  60 + if not module.bias is None:
  61 + nowd_params.append(module.bias)
  62 + elif isinstance(module, nn.BatchNorm2d):
  63 + nowd_params += list(module.parameters())
  64 + return wd_params, nowd_params
  65 +
  66 +
  67 +class AttentionRefinementModule(nn.Module):
  68 + def __init__(self, in_chan, out_chan, *args, **kwargs):
  69 + super(AttentionRefinementModule, self).__init__()
  70 + self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)
  71 + self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size= 1, bias=False)
  72 + self.bn_atten = nn.BatchNorm2d(out_chan)
  73 + self.sigmoid_atten = nn.Sigmoid()
  74 + self.init_weight()
  75 +
  76 + def forward(self, x):
  77 + feat = self.conv(x)
  78 + atten = F.avg_pool2d(feat, feat.size()[2:])
  79 + atten = self.conv_atten(atten)
  80 + atten = self.bn_atten(atten)
  81 + atten = self.sigmoid_atten(atten)
  82 + out = torch.mul(feat, atten)
  83 + return out
  84 +
  85 + def init_weight(self):
  86 + for ly in self.children():
  87 + if isinstance(ly, nn.Conv2d):
  88 + nn.init.kaiming_normal_(ly.weight, a=1)
  89 + if not ly.bias is None: nn.init.constant_(ly.bias, 0)
  90 +
  91 +
  92 +class ContextPath(nn.Module):
  93 + def __init__(self, *args, **kwargs):
  94 + super(ContextPath, self).__init__()
  95 + self.resnet = Resnet18()
  96 + self.arm16 = AttentionRefinementModule(256, 128)
  97 + self.arm32 = AttentionRefinementModule(512, 128)
  98 + self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
  99 + self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
  100 + self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0)
  101 +
  102 + self.init_weight()
  103 +
  104 + def forward(self, x):
  105 + H0, W0 = x.size()[2:]
  106 + feat8, feat16, feat32 = self.resnet(x)
  107 + H8, W8 = feat8.size()[2:]
  108 + H16, W16 = feat16.size()[2:]
  109 + H32, W32 = feat32.size()[2:]
  110 +
  111 + avg = F.avg_pool2d(feat32, feat32.size()[2:])
  112 + avg = self.conv_avg(avg)
  113 + avg_up = F.interpolate(avg, (H32, W32), mode='nearest')
  114 +
  115 + feat32_arm = self.arm32(feat32)
  116 + feat32_sum = feat32_arm + avg_up
  117 + feat32_up = F.interpolate(feat32_sum, (H16, W16), mode='nearest')
  118 + feat32_up = self.conv_head32(feat32_up)
  119 +
  120 + feat16_arm = self.arm16(feat16)
  121 + feat16_sum = feat16_arm + feat32_up
  122 + feat16_up = F.interpolate(feat16_sum, (H8, W8), mode='nearest')
  123 + feat16_up = self.conv_head16(feat16_up)
  124 +
  125 + return feat8, feat16_up, feat32_up # x8, x8, x16
  126 +
  127 + def init_weight(self):
  128 + for ly in self.children():
  129 + if isinstance(ly, nn.Conv2d):
  130 + nn.init.kaiming_normal_(ly.weight, a=1)
  131 + if not ly.bias is None: nn.init.constant_(ly.bias, 0)
  132 +
  133 + def get_params(self):
  134 + wd_params, nowd_params = [], []
  135 + for name, module in self.named_modules():
  136 + if isinstance(module, (nn.Linear, nn.Conv2d)):
  137 + wd_params.append(module.weight)
  138 + if not module.bias is None:
  139 + nowd_params.append(module.bias)
  140 + elif isinstance(module, nn.BatchNorm2d):
  141 + nowd_params += list(module.parameters())
  142 + return wd_params, nowd_params
  143 +
  144 +
  145 +### Not used: replaced by the ResNet feature map of the same spatial size
  146 +class SpatialPath(nn.Module):
  147 + def __init__(self, *args, **kwargs):
  148 + super(SpatialPath, self).__init__()
  149 + self.conv1 = ConvBNReLU(3, 64, ks=7, stride=2, padding=3)
  150 + self.conv2 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
  151 + self.conv3 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
  152 + self.conv_out = ConvBNReLU(64, 128, ks=1, stride=1, padding=0)
  153 + self.init_weight()
  154 +
  155 + def forward(self, x):
  156 + feat = self.conv1(x)
  157 + feat = self.conv2(feat)
  158 + feat = self.conv3(feat)
  159 + feat = self.conv_out(feat)
  160 + return feat
  161 +
  162 + def init_weight(self):
  163 + for ly in self.children():
  164 + if isinstance(ly, nn.Conv2d):
  165 + nn.init.kaiming_normal_(ly.weight, a=1)
  166 + if not ly.bias is None: nn.init.constant_(ly.bias, 0)
  167 +
  168 + def get_params(self):
  169 + wd_params, nowd_params = [], []
  170 + for name, module in self.named_modules():
  171 + if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
  172 + wd_params.append(module.weight)
  173 + if not module.bias is None:
  174 + nowd_params.append(module.bias)
  175 + elif isinstance(module, nn.BatchNorm2d):
  176 + nowd_params += list(module.parameters())
  177 + return wd_params, nowd_params
  178 +
  179 +
  180 +class FeatureFusionModule(nn.Module):
  181 + def __init__(self, in_chan, out_chan, *args, **kwargs):
  182 + super(FeatureFusionModule, self).__init__()
  183 + self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
  184 + self.conv1 = nn.Conv2d(out_chan,
  185 + out_chan//4,
  186 + kernel_size = 1,
  187 + stride = 1,
  188 + padding = 0,
  189 + bias = False)
  190 + self.conv2 = nn.Conv2d(out_chan//4,
  191 + out_chan,
  192 + kernel_size = 1,
  193 + stride = 1,
  194 + padding = 0,
  195 + bias = False)
  196 + self.relu = nn.ReLU(inplace=True)
  197 + self.sigmoid = nn.Sigmoid()
  198 + self.init_weight()
  199 +
  200 + def forward(self, fsp, fcp):
  201 + fcat = torch.cat([fsp, fcp], dim=1)
  202 + feat = self.convblk(fcat)
  203 + atten = F.avg_pool2d(feat, feat.size()[2:])
  204 + atten = self.conv1(atten)
  205 + atten = self.relu(atten)
  206 + atten = self.conv2(atten)
  207 + atten = self.sigmoid(atten)
  208 + feat_atten = torch.mul(feat, atten)
  209 + feat_out = feat_atten + feat
  210 + return feat_out
  211 +
  212 + def init_weight(self):
  213 + for ly in self.children():
  214 + if isinstance(ly, nn.Conv2d):
  215 + nn.init.kaiming_normal_(ly.weight, a=1)
  216 + if not ly.bias is None: nn.init.constant_(ly.bias, 0)
  217 +
  218 + def get_params(self):
  219 + wd_params, nowd_params = [], []
  220 + for name, module in self.named_modules():
  221 + if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
  222 + wd_params.append(module.weight)
  223 + if not module.bias is None:
  224 + nowd_params.append(module.bias)
  225 + elif isinstance(module, nn.BatchNorm2d):
  226 + nowd_params += list(module.parameters())
  227 + return wd_params, nowd_params
  228 +
  229 +
  230 +class BiSeNet(nn.Module):
  231 + def __init__(self, n_classes, *args, **kwargs):
  232 + super(BiSeNet, self).__init__()
  233 + self.cp = ContextPath()
  234 + ## here self.sp is deleted
  235 + self.ffm = FeatureFusionModule(256, 256)
  236 + self.conv_out = BiSeNetOutput(256, 256, n_classes)
  237 + self.conv_out16 = BiSeNetOutput(128, 64, n_classes)
  238 + self.conv_out32 = BiSeNetOutput(128, 64, n_classes)
  239 + self.init_weight()
  240 +
  241 + def forward(self, x):
  242 + H, W = x.size()[2:]
  243 + feat_res8, feat_cp8, feat_cp16 = self.cp(x) # here return res3b1 feature
  244 + feat_sp = feat_res8 # use res3b1 feature to replace spatial path feature
  245 + feat_fuse = self.ffm(feat_sp, feat_cp8)
  246 +
  247 + feat_out = self.conv_out(feat_fuse)
  248 + feat_out16 = self.conv_out16(feat_cp8)
  249 + feat_out32 = self.conv_out32(feat_cp16)
  250 +
  251 + feat_out = F.interpolate(feat_out, (H, W), mode='bilinear', align_corners=True)
  252 + feat_out16 = F.interpolate(feat_out16, (H, W), mode='bilinear', align_corners=True)
  253 + feat_out32 = F.interpolate(feat_out32, (H, W), mode='bilinear', align_corners=True)
  254 +
  255 + # return feat_out, feat_out16, feat_out32
  256 + return feat_out
  257 +
  258 + def init_weight(self):
  259 + for ly in self.children():
  260 + if isinstance(ly, nn.Conv2d):
  261 + nn.init.kaiming_normal_(ly.weight, a=1)
  262 + if not ly.bias is None: nn.init.constant_(ly.bias, 0)
  263 +
  264 + def get_params(self):
  265 + wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], []
  266 + for name, child in self.named_children():
  267 + child_wd_params, child_nowd_params = child.get_params()
  268 + if isinstance(child, FeatureFusionModule) or isinstance(child, BiSeNetOutput):
  269 + lr_mul_wd_params += child_wd_params
  270 + lr_mul_nowd_params += child_nowd_params
  271 + else:
  272 + wd_params += child_wd_params
  273 + nowd_params += child_nowd_params
  274 + return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params
  275 +
  276 +
  277 +if __name__ == "__main__":
  278 + net = BiSeNet(19)
  279 + net.cuda()
  280 + net.eval()
  281 + in_ten = torch.randn(16, 3, 640, 480).cuda()
  282 + out = net(in_ten)  # forward() returns only the fused output
  283 + print(out.shape)
  284 +
  285 + net.get_params()
  1 +#!/usr/bin/python
  2 +# -*- encoding: utf-8 -*-
  3 +
  4 +import torch
  5 +import torch.nn as nn
  6 +import torch.nn.functional as F
  7 +import torch.utils.model_zoo as modelzoo
  8 +
  9 +# from modules.bn import InPlaceABNSync as BatchNorm2d
  10 +
  11 +resnet18_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
  12 +
  13 +
  14 +def conv3x3(in_planes, out_planes, stride=1):
  15 + """3x3 convolution with padding"""
  16 + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
  17 + padding=1, bias=False)
  18 +
  19 +
  20 +class BasicBlock(nn.Module):
  21 + def __init__(self, in_chan, out_chan, stride=1):
  22 + super(BasicBlock, self).__init__()
  23 + self.conv1 = conv3x3(in_chan, out_chan, stride)
  24 + self.bn1 = nn.BatchNorm2d(out_chan)
  25 + self.conv2 = conv3x3(out_chan, out_chan)
  26 + self.bn2 = nn.BatchNorm2d(out_chan)
  27 + self.relu = nn.ReLU(inplace=True)
  28 + self.downsample = None
  29 + if in_chan != out_chan or stride != 1:
  30 + self.downsample = nn.Sequential(
  31 + nn.Conv2d(in_chan, out_chan,
  32 + kernel_size=1, stride=stride, bias=False),
  33 + nn.BatchNorm2d(out_chan),
  34 + )
  35 +
  36 + def forward(self, x):
  37 + residual = self.conv1(x)
  38 + residual = F.relu(self.bn1(residual))
  39 + residual = self.conv2(residual)
  40 + residual = self.bn2(residual)
  41 +
  42 + shortcut = x
  43 + if self.downsample is not None:
  44 + shortcut = self.downsample(x)
  45 +
  46 + out = shortcut + residual
  47 + out = self.relu(out)
  48 + return out
  49 +
  50 +
  51 +def create_layer_basic(in_chan, out_chan, bnum, stride=1):
  52 + layers = [BasicBlock(in_chan, out_chan, stride=stride)]
  53 + for i in range(bnum-1):
  54 + layers.append(BasicBlock(out_chan, out_chan, stride=1))
  55 + return nn.Sequential(*layers)
  56 +
  57 +
  58 +class Resnet18(nn.Module):
  59 + def __init__(self):
  60 + super(Resnet18, self).__init__()
  61 + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
  62 + bias=False)
  63 + self.bn1 = nn.BatchNorm2d(64)
  64 + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
  65 + self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1)
  66 + self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2)
  67 + self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2)
  68 + self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2)
  69 + self.init_weight()
  70 +
  71 + def forward(self, x):
  72 + x = self.conv1(x)
  73 + x = F.relu(self.bn1(x))
  74 + x = self.maxpool(x)
  75 +
  76 + x = self.layer1(x)
  77 + feat8 = self.layer2(x) # 1/8
  78 + feat16 = self.layer3(feat8) # 1/16
  79 + feat32 = self.layer4(feat16) # 1/32
  80 + return feat8, feat16, feat32
  81 +
  82 + def init_weight(self):
  83 + state_dict = modelzoo.load_url(resnet18_url)
  84 + self_state_dict = self.state_dict()
  85 + for k, v in state_dict.items():
  86 + if 'fc' in k: continue
  87 + self_state_dict.update({k: v})
  88 + self.load_state_dict(self_state_dict)
  89 +
  90 + def get_params(self):
  91 + wd_params, nowd_params = [], []
  92 + for name, module in self.named_modules():
  93 + if isinstance(module, (nn.Linear, nn.Conv2d)):
  94 + wd_params.append(module.weight)
  95 + if not module.bias is None:
  96 + nowd_params.append(module.bias)
  97 + elif isinstance(module, nn.BatchNorm2d):
  98 + nowd_params += list(module.parameters())
  99 + return wd_params, nowd_params
  100 +
  101 +
  102 +if __name__ == "__main__":
  103 + net = Resnet18()
  104 + x = torch.randn(16, 3, 224, 224)
  105 + out = net(x)
  106 + print(out[0].size())
  107 + print(out[1].size())
  108 + print(out[2].size())
  109 + net.get_params()
  1 +#!/usr/bin/python
  2 +# -*- encoding: utf-8 -*-
  3 +import numpy as np
  4 +from model import BiSeNet
  5 +
  6 +import torch
  7 +
  8 +import os
  9 +import os.path as osp
  10 +
  11 +from PIL import Image
  12 +import torchvision.transforms as transforms
  13 +import cv2
  14 +from pathlib import Path
  15 +import configargparse
  16 +import tqdm
  17 +
  18 +# import ttach as tta
  19 +
  20 +def vis_parsing_maps(im, parsing_anno, stride, save_im=False, save_path='vis_results/parsing_map_on_im.jpg',
  21 + img_size=(512, 512)):
  22 + im = np.array(im)
  23 + vis_im = im.copy().astype(np.uint8)
  24 + vis_parsing_anno = parsing_anno.copy().astype(np.uint8)
  25 + vis_parsing_anno = cv2.resize(
  26 + vis_parsing_anno, None, fx=stride, fy=stride, interpolation=cv2.INTER_NEAREST)
  27 + vis_parsing_anno_color = np.zeros(
  28 + (vis_parsing_anno.shape[0], vis_parsing_anno.shape[1], 3)) + np.array([255, 255, 255]) # + 255
  29 +
  30 + num_of_class = np.max(vis_parsing_anno)
  31 + # print(num_of_class)
  32 + for pi in range(1, 14):
  33 + index = np.where(vis_parsing_anno == pi)
  34 + vis_parsing_anno_color[index[0], index[1], :] = np.array([255, 0, 0])
  35 +
  36 + for pi in range(14, 16):
  37 + index = np.where(vis_parsing_anno == pi)
  38 + vis_parsing_anno_color[index[0], index[1], :] = np.array([0, 255, 0])
  39 + for pi in range(16, 17):
  40 + index = np.where(vis_parsing_anno == pi)
  41 + vis_parsing_anno_color[index[0], index[1], :] = np.array([0, 0, 255])
  42 + for pi in range(17, num_of_class+1):
  43 + index = np.where(vis_parsing_anno == pi)
  44 + vis_parsing_anno_color[index[0], index[1], :] = np.array([255, 0, 0])
  45 +
  46 + vis_parsing_anno_color = vis_parsing_anno_color.astype(np.uint8)
  47 + index = np.where(vis_parsing_anno == num_of_class-1)
  48 + vis_im = cv2.resize(vis_parsing_anno_color, img_size,
  49 + interpolation=cv2.INTER_NEAREST)
  50 + if save_im:
  51 + cv2.imwrite(save_path, vis_im)
  52 +
  53 +
  54 +def evaluate(respth='./res/test_res', dspth='./data', cp='model_final_diss.pth'):
  55 +
  56 + Path(respth).mkdir(parents=True, exist_ok=True)
  57 +
  58 + print(f'[INFO] loading model...')
  59 + n_classes = 19
  60 + net = BiSeNet(n_classes=n_classes)
  61 + net.cuda()
  62 + net.load_state_dict(torch.load(cp))
  63 + net.eval()
  64 +
  65 + to_tensor = transforms.Compose([
  66 + transforms.ToTensor(),
  67 + transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
  68 + ])
  69 +
  70 + image_paths = os.listdir(dspth)
  71 +
  72 + with torch.no_grad():
  73 + for image_path in tqdm.tqdm(image_paths):
  74 + if image_path.endswith('.jpg') or image_path.endswith('.png'):
  75 + img = Image.open(osp.join(dspth, image_path))
  76 + ori_size = img.size
  77 + image = img.resize((512, 512), Image.BILINEAR)
  78 + image = image.convert("RGB")
  79 + img = to_tensor(image)
  80 +
  81 + # test-time augmentation.
  82 + inputs = torch.unsqueeze(img, 0) # [1, 3, 512, 512]
  83 + outputs = net(inputs.cuda())
  84 + parsing = outputs.mean(0).cpu().numpy().argmax(0)
  85 +
  86 + image_path = int(image_path[:-4])
  87 + image_path = str(image_path) + '.png'
  88 +
  89 + vis_parsing_maps(image, parsing, stride=1, save_im=True, save_path=osp.join(respth, image_path), img_size=ori_size)
  90 +
  91 +
  92 +if __name__ == "__main__":
  93 + parser = configargparse.ArgumentParser()
  94 + parser.add_argument('--respath', type=str, default='./result/', help='result path for label')
  95 + parser.add_argument('--imgpath', type=str, default='./imgs/', help='path for input images')
  96 + parser.add_argument('--modelpath', type=str, default='data_utils/face_parsing/79999_iter.pth')
  97 + args = parser.parse_args()
  98 + evaluate(respth=args.respath, dspth=args.imgpath, cp=args.modelpath)
  1 +import numpy as np
  2 +from scipy.io import loadmat
  3 +
  4 +original_BFM = loadmat("3DMM/01_MorphableModel.mat")
  5 +sub_inds = np.load("3DMM/topology_info.npy", allow_pickle=True).item()["sub_inds"]
  6 +
  7 +shapePC = original_BFM["shapePC"]
  8 +shapeEV = original_BFM["shapeEV"]
  9 +shapeMU = original_BFM["shapeMU"]
  10 +texPC = original_BFM["texPC"]
  11 +texEV = original_BFM["texEV"]
  12 +texMU = original_BFM["texMU"]
  13 +
  14 +b_shape = shapePC.reshape(-1, 199).transpose(1, 0).reshape(199, -1, 3)
  15 +mu_shape = shapeMU.reshape(-1, 3)
  16 +
  17 +b_tex = texPC.reshape(-1, 199).transpose(1, 0).reshape(199, -1, 3)
  18 +mu_tex = texMU.reshape(-1, 3)
  19 +
  20 +b_shape = b_shape[:, sub_inds, :].reshape(199, -1)
  21 +mu_shape = mu_shape[sub_inds, :].reshape(-1)
  22 +b_tex = b_tex[:, sub_inds, :].reshape(199, -1)
  23 +mu_tex = mu_tex[sub_inds, :].reshape(-1)
  24 +
  25 +exp_info = np.load("3DMM/exp_info.npy", allow_pickle=True).item()
  26 +np.save(
  27 + "3DMM/3DMM_info.npy",
  28 + {
  29 + "mu_shape": mu_shape,
  30 + "b_shape": b_shape,
  31 + "sig_shape": shapeEV.reshape(-1),
  32 + "mu_exp": exp_info["mu_exp"],
  33 + "b_exp": exp_info["base_exp"],
  34 + "sig_exp": exp_info["sig_exp"],
  35 + "mu_tex": mu_tex,
  36 + "b_tex": b_tex,
  37 + "sig_tex": texEV.reshape(-1),
  38 + },
  39 +)
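
The arrays saved to `3DMM/3DMM_info.npy` are the flattened mean shape/texture plus the 199 PCA bases restricted to the `sub_inds` vertex subset. A hedged sketch of how such a basis is typically combined into vertices (the scaling by `sig_shape` mirrors common BFM usage and is an assumption here, not necessarily how `Face_3DMM` consumes these arrays):

```
import numpy as np

info = np.load("3DMM/3DMM_info.npy", allow_pickle=True).item()
coeffs = np.zeros(199)                     # identity coefficients; all zeros -> mean face
# scaling by sig_shape follows common BFM usage (assumption)
verts = info["mu_shape"] + (coeffs * info["sig_shape"]) @ info["b_shape"]
verts = verts.reshape(-1, 3)               # (num_sub_vertices, 3)
print(verts.shape)
```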
  1 +import os
  2 +import torch
  3 +import numpy as np
  4 +
  5 +
  6 +def load_dir(path, start, end):
  7 + lmss = []
  8 + imgs_paths = []
  9 + for i in range(start, end):
  10 + if os.path.isfile(os.path.join(path, str(i) + ".lms")):
  11 + lms = np.loadtxt(os.path.join(path, str(i) + ".lms"), dtype=np.float32)
  12 + lmss.append(lms)
  13 + imgs_paths.append(os.path.join(path, str(i) + ".jpg"))
  14 + lmss = np.stack(lmss)
  15 + lmss = torch.as_tensor(lmss).cuda()
  16 + return lmss, imgs_paths
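
`load_dir` collects the `.lms` landmark files and matching `.jpg` paths for frames `start..end`, stacks the landmarks and moves them onto the GPU. A hedged usage sketch (the directory layout follows the tracker's `--path` default below; the landmark count per frame depends on the upstream detector):

```
from data_loader import load_dir

# assumes frames were dumped as obama/ori_imgs/<i>.jpg with matching <i>.lms landmark files
lms, img_paths = load_dir("obama/ori_imgs", 0, 1000)
print(lms.shape)        # (num_frames_found, num_landmarks, 2)
print(img_paths[0])     # 'obama/ori_imgs/0.jpg' (only frames that have a .lms file are kept)
```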
  1 +import os
  2 +import sys
  3 +import cv2
  4 +import argparse
  5 +from pathlib import Path
  6 +import torch
  7 +import numpy as np
  8 +from data_loader import load_dir
  9 +from facemodel import Face_3DMM
  10 +from util import *
  11 +from render_3dmm import Render_3DMM
  12 +
  13 +
  14 +# torch.autograd.set_detect_anomaly(True)
  15 +
  16 +dir_path = os.path.dirname(os.path.realpath(__file__))
  17 +
  18 +
  19 +def set_requires_grad(tensor_list):
  20 + for tensor in tensor_list:
  21 + tensor.requires_grad = True
  22 +
  23 +
  24 +parser = argparse.ArgumentParser()
  25 +parser.add_argument(
  26 + "--path", type=str, default="obama/ori_imgs", help="idname of target person"
  27 +)
  28 +parser.add_argument("--img_h", type=int, default=512, help="image height")
  29 +parser.add_argument("--img_w", type=int, default=512, help="image width")
  30 +parser.add_argument("--frame_num", type=int, default=11000, help="image number")
  31 +args = parser.parse_args()
  32 +
  33 +start_id = 0
  34 +end_id = args.frame_num
  35 +
  36 +lms, img_paths = load_dir(args.path, start_id, end_id)
  37 +num_frames = lms.shape[0]
  38 +h, w = args.img_h, args.img_w
  39 +cxy = torch.tensor((w / 2.0, h / 2.0), dtype=torch.float).cuda()
  40 +id_dim, exp_dim, tex_dim, point_num = 100, 79, 100, 34650
  41 +model_3dmm = Face_3DMM(
  42 + os.path.join(dir_path, "3DMM"), id_dim, exp_dim, tex_dim, point_num
  43 +)
  44 +
  45 +# only use one image out of every 40 frames to fit the focal length
  46 +sel_ids = np.arange(0, num_frames, 40)
  47 +sel_num = sel_ids.shape[0]
  48 +arg_focal = 1600
  49 +arg_landis = 1e5
  50 +
  51 +print(f'[INFO] fitting focal length...')
  52 +
  53 +# fit the focal length
  54 +for focal in range(600, 1500, 100):
  55 + id_para = lms.new_zeros((1, id_dim), requires_grad=True)
  56 + exp_para = lms.new_zeros((sel_num, exp_dim), requires_grad=True)
  57 + euler_angle = lms.new_zeros((sel_num, 3), requires_grad=True)
  58 + trans = lms.new_zeros((sel_num, 3), requires_grad=True)
  59 + trans.data[:, 2] -= 7
  60 + focal_length = lms.new_zeros(1, requires_grad=False)
  61 + focal_length.data += focal
  62 + set_requires_grad([id_para, exp_para, euler_angle, trans])
  63 +
  64 + optimizer_idexp = torch.optim.Adam([id_para, exp_para], lr=0.1)
  65 + optimizer_frame = torch.optim.Adam([euler_angle, trans], lr=0.1)
  66 +
  67 + for iter in range(2000):
  68 + id_para_batch = id_para.expand(sel_num, -1)
  69 + geometry = model_3dmm.get_3dlandmarks(
  70 + id_para_batch, exp_para, euler_angle, trans, focal_length, cxy
  71 + )
  72 + proj_geo = forward_transform(geometry, euler_angle, trans, focal_length, cxy)
  73 + loss_lan = cal_lan_loss(proj_geo[:, :, :2], lms[sel_ids].detach())
  74 + loss = loss_lan
  75 + optimizer_frame.zero_grad()
  76 + loss.backward()
  77 + optimizer_frame.step()
  78 + # if iter % 100 == 0:
  79 + # print(focal, 'pose', iter, loss.item())
  80 +
  81 + for iter in range(2500):
  82 + id_para_batch = id_para.expand(sel_num, -1)
  83 + geometry = model_3dmm.get_3dlandmarks(
  84 + id_para_batch, exp_para, euler_angle, trans, focal_length, cxy
  85 + )
  86 + proj_geo = forward_transform(geometry, euler_angle, trans, focal_length, cxy)
  87 + loss_lan = cal_lan_loss(proj_geo[:, :, :2], lms[sel_ids].detach())
  88 + loss_regid = torch.mean(id_para * id_para)
  89 + loss_regexp = torch.mean(exp_para * exp_para)
  90 + loss = loss_lan + loss_regid * 0.5 + loss_regexp * 0.4
  91 + optimizer_idexp.zero_grad()
  92 + optimizer_frame.zero_grad()
  93 + loss.backward()
  94 + optimizer_idexp.step()
  95 + optimizer_frame.step()
  96 + # if iter % 100 == 0:
  97 + # print(focal, 'poseidexp', iter, loss_lan.item(), loss_regid.item(), loss_regexp.item())
  98 +
  99 + if iter % 1500 == 0 and iter >= 1500:
  100 + for param_group in optimizer_idexp.param_groups:
  101 + param_group["lr"] *= 0.2
  102 + for param_group in optimizer_frame.param_groups:
  103 + param_group["lr"] *= 0.2
  104 +
  105 + print(focal, loss_lan.item(), torch.mean(trans[:, 2]).item())
  106 +
  107 + if loss_lan.item() < arg_landis:
  108 + arg_landis = loss_lan.item()
  109 + arg_focal = focal
  110 +
  111 +print("[INFO] find best focal:", arg_focal)
  112 +
  113 +print(f'[INFO] coarse fitting...')
  114 +
  115 +# for all frames, do a coarse fitting ???
  116 +id_para = lms.new_zeros((1, id_dim), requires_grad=True)
  117 +exp_para = lms.new_zeros((num_frames, exp_dim), requires_grad=True)
  118 +tex_para = lms.new_zeros(
  119 + (1, tex_dim), requires_grad=True
  120 +) # not optimized in this block ???
  121 +euler_angle = lms.new_zeros((num_frames, 3), requires_grad=True)
  122 +trans = lms.new_zeros((num_frames, 3), requires_grad=True)
  123 +light_para = lms.new_zeros((num_frames, 27), requires_grad=True)
  124 +trans.data[:, 2] -= 7 # ???
  125 +focal_length = lms.new_zeros(1, requires_grad=True)
  126 +focal_length.data += arg_focal
  127 +
  128 +set_requires_grad([id_para, exp_para, tex_para, euler_angle, trans, light_para])
  129 +
  130 +optimizer_idexp = torch.optim.Adam([id_para, exp_para], lr=0.1)
  131 +optimizer_frame = torch.optim.Adam([euler_angle, trans], lr=1)
  132 +
  133 +for iter in range(1500):
  134 + id_para_batch = id_para.expand(num_frames, -1)
  135 + geometry = model_3dmm.get_3dlandmarks(
  136 + id_para_batch, exp_para, euler_angle, trans, focal_length, cxy
  137 + )
  138 + proj_geo = forward_transform(geometry, euler_angle, trans, focal_length, cxy)
  139 + loss_lan = cal_lan_loss(proj_geo[:, :, :2], lms.detach())
  140 + loss = loss_lan
  141 + optimizer_frame.zero_grad()
  142 + loss.backward()
  143 + optimizer_frame.step()
  144 + if iter == 1000:
  145 + for param_group in optimizer_frame.param_groups:
  146 + param_group["lr"] = 0.1
  147 + # if iter % 100 == 0:
  148 + # print('pose', iter, loss.item())
  149 +
  150 +for param_group in optimizer_frame.param_groups:
  151 + param_group["lr"] = 0.1
  152 +
  153 +for iter in range(2000):
  154 + id_para_batch = id_para.expand(num_frames, -1)
  155 + geometry = model_3dmm.get_3dlandmarks(
  156 + id_para_batch, exp_para, euler_angle, trans, focal_length, cxy
  157 + )
  158 + proj_geo = forward_transform(geometry, euler_angle, trans, focal_length, cxy)
  159 + loss_lan = cal_lan_loss(proj_geo[:, :, :2], lms.detach())
  160 + loss_regid = torch.mean(id_para * id_para)
  161 + loss_regexp = torch.mean(exp_para * exp_para)
  162 + loss = loss_lan + loss_regid * 0.5 + loss_regexp * 0.4
  163 + optimizer_idexp.zero_grad()
  164 + optimizer_frame.zero_grad()
  165 + loss.backward()
  166 + optimizer_idexp.step()
  167 + optimizer_frame.step()
  168 + # if iter % 100 == 0:
  169 + # print('poseidexp', iter, loss_lan.item(), loss_regid.item(), loss_regexp.item())
  170 + if iter % 1000 == 0 and iter >= 1000:
  171 + for param_group in optimizer_idexp.param_groups:
  172 + param_group["lr"] *= 0.2
  173 + for param_group in optimizer_frame.param_groups:
  174 + param_group["lr"] *= 0.2
  175 +
  176 +print(loss_lan.item(), torch.mean(trans[:, 2]).item())
  177 +
  178 +print(f'[INFO] fitting light...')
  179 +
  180 +batch_size = 32
  181 +
  182 +device_default = torch.device("cuda:0")
  183 +device_render = torch.device("cuda:0")
  184 +renderer = Render_3DMM(arg_focal, h, w, batch_size, device_render)
  185 +
  186 +sel_ids = np.arange(0, num_frames, int(num_frames / batch_size))[:batch_size]
  187 +imgs = []
  188 +for sel_id in sel_ids:
  189 + imgs.append(cv2.imread(img_paths[sel_id])[:, :, ::-1])
  190 +imgs = np.stack(imgs)
  191 +sel_imgs = torch.as_tensor(imgs).cuda()
  192 +sel_lms = lms[sel_ids]
  193 +sel_light = light_para.new_zeros((batch_size, 27), requires_grad=True)
  194 +set_requires_grad([sel_light])
  195 +
  196 +optimizer_tl = torch.optim.Adam([tex_para, sel_light], lr=0.1)
  197 +optimizer_id_frame = torch.optim.Adam([euler_angle, trans, exp_para, id_para], lr=0.01)
  198 +
  199 +for iter in range(71):
  200 + sel_exp_para, sel_euler, sel_trans = (
  201 + exp_para[sel_ids],
  202 + euler_angle[sel_ids],
  203 + trans[sel_ids],
  204 + )
  205 + sel_id_para = id_para.expand(batch_size, -1)
  206 + geometry = model_3dmm.get_3dlandmarks(
  207 + sel_id_para, sel_exp_para, sel_euler, sel_trans, focal_length, cxy
  208 + )
  209 + proj_geo = forward_transform(geometry, sel_euler, sel_trans, focal_length, cxy)
  210 +
  211 + loss_lan = cal_lan_loss(proj_geo[:, :, :2], sel_lms.detach())
  212 + loss_regid = torch.mean(id_para * id_para)
  213 + loss_regexp = torch.mean(sel_exp_para * sel_exp_para)
  214 +
  215 + sel_tex_para = tex_para.expand(batch_size, -1)
  216 + sel_texture = model_3dmm.forward_tex(sel_tex_para)
  217 + geometry = model_3dmm.forward_geo(sel_id_para, sel_exp_para)
  218 + rott_geo = forward_rott(geometry, sel_euler, sel_trans)
  219 + render_imgs = renderer(
  220 + rott_geo.to(device_render),
  221 + sel_texture.to(device_render),
  222 + sel_light.to(device_render),
  223 + )
  224 + render_imgs = render_imgs.to(device_default)
  225 +
  226 + mask = (render_imgs[:, :, :, 3]).detach() > 0.0
  227 + render_proj = sel_imgs.clone()
  228 + render_proj[mask] = render_imgs[mask][..., :3].byte()
  229 + loss_col = cal_col_loss(render_imgs[:, :, :, :3], sel_imgs.float(), mask)
  230 +
  231 + if iter > 50:
  232 + loss = loss_col + loss_lan * 0.05 + loss_regid * 1.0 + loss_regexp * 0.8
  233 + else:
  234 + loss = loss_col + loss_lan * 3 + loss_regid * 2.0 + loss_regexp * 1.0
  235 +
  236 + optimizer_tl.zero_grad()
  237 + optimizer_id_frame.zero_grad()
  238 + loss.backward()
  239 +
  240 + optimizer_tl.step()
  241 + optimizer_id_frame.step()
  242 +
  243 + if iter % 50 == 0 and iter > 0:
  244 + for param_group in optimizer_id_frame.param_groups:
  245 + param_group["lr"] *= 0.2
  246 + for param_group in optimizer_tl.param_groups:
  247 + param_group["lr"] *= 0.2
  248 + # print(iter, loss_col.item(), loss_lan.item(), loss_regid.item(), loss_regexp.item())
  249 +
  250 +
  251 +light_mean = torch.mean(sel_light, 0).unsqueeze(0).repeat(num_frames, 1)
  252 +light_para.data = light_mean
  253 +
  254 +exp_para = exp_para.detach()
  255 +euler_angle = euler_angle.detach()
  256 +trans = trans.detach()
  257 +light_para = light_para.detach()
  258 +
  259 +print(f'[INFO] fine frame-wise fitting...')
  260 +
  261 +for i in range(int((num_frames - 1) / batch_size + 1)):
  262 +
  263 + if (i + 1) * batch_size > num_frames:
  264 + start_n = num_frames - batch_size
  265 + sel_ids = np.arange(num_frames - batch_size, num_frames)
  266 + else:
  267 + start_n = i * batch_size
  268 + sel_ids = np.arange(i * batch_size, i * batch_size + batch_size)
  269 +
  270 + imgs = []
  271 + for sel_id in sel_ids:
  272 + imgs.append(cv2.imread(img_paths[sel_id])[:, :, ::-1])
  273 + imgs = np.stack(imgs)
  274 + sel_imgs = torch.as_tensor(imgs).cuda()
  275 + sel_lms = lms[sel_ids]
  276 +
  277 + sel_exp_para = exp_para.new_zeros((batch_size, exp_dim), requires_grad=True)
  278 + sel_exp_para.data = exp_para[sel_ids].clone()
  279 + sel_euler = euler_angle.new_zeros((batch_size, 3), requires_grad=True)
  280 + sel_euler.data = euler_angle[sel_ids].clone()
  281 + sel_trans = trans.new_zeros((batch_size, 3), requires_grad=True)
  282 + sel_trans.data = trans[sel_ids].clone()
  283 + sel_light = light_para.new_zeros((batch_size, 27), requires_grad=True)
  284 + sel_light.data = light_para[sel_ids].clone()
  285 +
  286 + set_requires_grad([sel_exp_para, sel_euler, sel_trans, sel_light])
  287 +
  288 + optimizer_cur_batch = torch.optim.Adam(
  289 + [sel_exp_para, sel_euler, sel_trans, sel_light], lr=0.005
  290 + )
  291 +
  292 + sel_id_para = id_para.expand(batch_size, -1).detach()
  293 + sel_tex_para = tex_para.expand(batch_size, -1).detach()
  294 +
  295 + pre_num = 5
  296 +
  297 + if i > 0:
  298 + pre_ids = np.arange(start_n - pre_num, start_n)
  299 +
  300 + for iter in range(50):
  301 +
  302 + geometry = model_3dmm.get_3dlandmarks(
  303 + sel_id_para, sel_exp_para, sel_euler, sel_trans, focal_length, cxy
  304 + )
  305 + proj_geo = forward_transform(geometry, sel_euler, sel_trans, focal_length, cxy)
  306 + loss_lan = cal_lan_loss(proj_geo[:, :, :2], sel_lms.detach())
  307 + loss_regexp = torch.mean(sel_exp_para * sel_exp_para)
  308 +
  309 + sel_geometry = model_3dmm.forward_geo(sel_id_para, sel_exp_para)
  310 + sel_texture = model_3dmm.forward_tex(sel_tex_para)
  311 + geometry = sel_geometry  # reuse the geometry computed above instead of a second forward_geo call
  312 + rott_geo = forward_rott(geometry, sel_euler, sel_trans)
  313 + render_imgs = renderer(
  314 + rott_geo.to(device_render),
  315 + sel_texture.to(device_render),
  316 + sel_light.to(device_render),
  317 + )
  318 + render_imgs = render_imgs.to(device_default)
  319 +
  320 + mask = (render_imgs[:, :, :, 3]).detach() > 0.0
  321 +
  322 + loss_col = cal_col_loss(render_imgs[:, :, :, :3], sel_imgs.float(), mask)
  323 +
  324 + if i > 0:
  325 + geometry_lap = model_3dmm.forward_geo_sub(
  326 + id_para.expand(batch_size + pre_num, -1).detach(),
  327 + torch.cat((exp_para[pre_ids].detach(), sel_exp_para)),
  328 + model_3dmm.rigid_ids,
  329 + )
  330 + rott_geo_lap = forward_rott(
  331 + geometry_lap,
  332 + torch.cat((euler_angle[pre_ids].detach(), sel_euler)),
  333 + torch.cat((trans[pre_ids].detach(), sel_trans)),
  334 + )
  335 + loss_lap = cal_lap_loss(
  336 + [rott_geo_lap.reshape(rott_geo_lap.shape[0], -1).permute(1, 0)], [1.0]
  337 + )
  338 + else:
  339 + geometry_lap = model_3dmm.forward_geo_sub(
  340 + id_para.expand(batch_size, -1).detach(),
  341 + sel_exp_para,
  342 + model_3dmm.rigid_ids,
  343 + )
  344 + rott_geo_lap = forward_rott(geometry_lap, sel_euler, sel_trans)
  345 + loss_lap = cal_lap_loss(
  346 + [rott_geo_lap.reshape(rott_geo_lap.shape[0], -1).permute(1, 0)], [1.0]
  347 + )
  348 +
  349 +
  350 + if iter > 30:
  351 + loss = loss_col * 0.5 + loss_lan * 1.5 + loss_lap * 100000 + loss_regexp * 1.0
  352 + else:
  353 + loss = loss_col * 0.5 + loss_lan * 8 + loss_lap * 100000 + loss_regexp * 1.0
  354 +
  355 + optimizer_cur_batch.zero_grad()
  356 + loss.backward()
  357 + optimizer_cur_batch.step()
  358 +
  359 + # if iter % 10 == 0:
  360 + # print(
  361 + # i,
  362 + # iter,
  363 + # loss_col.item(),
  364 + # loss_lan.item(),
  365 + # loss_lap.item(),
  366 + # loss_regexp.item(),
  367 + # )
  368 +
  369 + print(str(i) + " of " + str(int((num_frames - 1) / batch_size + 1)) + " done")
  370 +
  371 + render_proj = sel_imgs.clone()
  372 + render_proj[mask] = render_imgs[mask][..., :3].byte()
  373 +
  374 + exp_para[sel_ids] = sel_exp_para.clone()
  375 + euler_angle[sel_ids] = sel_euler.clone()
  376 + trans[sel_ids] = sel_trans.clone()
  377 + light_para[sel_ids] = sel_light.clone()
  378 +
  379 +torch.save(
  380 + {
  381 + "id": id_para.detach().cpu(),
  382 + "exp": exp_para.detach().cpu(),
  383 + "euler": euler_angle.detach().cpu(),
  384 + "trans": trans.detach().cpu(),
  385 + "focal": focal_length.detach().cpu(),
  386 + },
  387 + os.path.join(os.path.dirname(args.path), "track_params.pt"),
  388 +)
  389 +
  390 +print("params saved")
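For quick inspection, a minimal sketch of loading the saved tracking parameters (the data path "data/obama/track_params.pt" is only an assumed example; the tracker writes track_params.pt next to the processed image folder):

import torch

params = torch.load("data/obama/track_params.pt", map_location="cpu")
print(params["id"].shape)      # shared identity coefficients
print(params["exp"].shape)     # per-frame expression coefficients
print(params["euler"].shape)   # per-frame head rotation (Euler angles), [num_frames, 3]
print(params["trans"].shape)   # per-frame translation, [num_frames, 3]
print(params["focal"])         # estimated focal length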
  1 +import torch
  2 +import torch.nn as nn
  3 +import numpy as np
  4 +import os
  5 +from util import *
  6 +
  7 +
  8 +class Face_3DMM(nn.Module):
  9 + def __init__(self, modelpath, id_dim, exp_dim, tex_dim, point_num):
  10 + super(Face_3DMM, self).__init__()
  11 + # id_dim = 100
  12 + # exp_dim = 79
  13 + # tex_dim = 100
  14 + self.point_num = point_num
  15 + DMM_info = np.load(
  16 + os.path.join(modelpath, "3DMM_info.npy"), allow_pickle=True
  17 + ).item()
  18 + base_id = DMM_info["b_shape"][:id_dim, :]
  19 + mu_id = DMM_info["mu_shape"]
  20 + base_exp = DMM_info["b_exp"][:exp_dim, :]
  21 + mu_exp = DMM_info["mu_exp"]
  22 + mu = mu_id + mu_exp
  23 + mu = mu.reshape(-1, 3)
  24 + for i in range(3):
  25 + mu[:, i] -= np.mean(mu[:, i])
  26 + mu = mu.reshape(-1)
  27 + self.base_id = torch.as_tensor(base_id).cuda() / 100000.0
  28 + self.base_exp = torch.as_tensor(base_exp).cuda() / 100000.0
  29 + self.mu = torch.as_tensor(mu).cuda() / 100000.0
  30 + base_tex = DMM_info["b_tex"][:tex_dim, :]
  31 + mu_tex = DMM_info["mu_tex"]
  32 + self.base_tex = torch.as_tensor(base_tex).cuda()
  33 + self.mu_tex = torch.as_tensor(mu_tex).cuda()
  34 + sig_id = DMM_info["sig_shape"][:id_dim]
  35 + sig_tex = DMM_info["sig_tex"][:tex_dim]
  36 + sig_exp = DMM_info["sig_exp"][:exp_dim]
  37 + self.sig_id = torch.as_tensor(sig_id).cuda()
  38 + self.sig_tex = torch.as_tensor(sig_tex).cuda()
  39 + self.sig_exp = torch.as_tensor(sig_exp).cuda()
  40 +
  41 + keys_info = np.load(
  42 + os.path.join(modelpath, "keys_info.npy"), allow_pickle=True
  43 + ).item()
  44 + self.keyinds = torch.as_tensor(keys_info["keyinds"]).cuda()
  45 + self.left_contours = torch.as_tensor(keys_info["left_contour"]).cuda()
  46 + self.right_contours = torch.as_tensor(keys_info["right_contour"]).cuda()
  47 + self.rigid_ids = torch.as_tensor(keys_info["rigid_ids"]).cuda()
  48 +
  49 + def get_3dlandmarks(self, id_para, exp_para, euler_angle, trans, focal_length, cxy):
  50 + id_para = id_para * self.sig_id
  51 + exp_para = exp_para * self.sig_exp
  52 + batch_size = id_para.shape[0]
  53 + num_per_contour = self.left_contours.shape[1]
  54 + left_contours_flat = self.left_contours.reshape(-1)
  55 + right_contours_flat = self.right_contours.reshape(-1)
  56 + sel_index = torch.cat(
  57 + (
  58 + 3 * left_contours_flat.unsqueeze(1),
  59 + 3 * left_contours_flat.unsqueeze(1) + 1,
  60 + 3 * left_contours_flat.unsqueeze(1) + 2,
  61 + ),
  62 + dim=1,
  63 + ).reshape(-1)
  64 + left_geometry = (
  65 + torch.mm(id_para, self.base_id[:, sel_index])
  66 + + torch.mm(exp_para, self.base_exp[:, sel_index])
  67 + + self.mu[sel_index]
  68 + )
  69 + left_geometry = left_geometry.view(batch_size, -1, 3)
  70 + proj_x = forward_transform(
  71 + left_geometry, euler_angle, trans, focal_length, cxy
  72 + )[:, :, 0]
  73 + proj_x = proj_x.reshape(batch_size, 8, num_per_contour)
  74 + arg_min = proj_x.argmin(dim=2)
  75 + left_geometry = left_geometry.view(batch_size * 8, num_per_contour, 3)
  76 + left_3dlands = left_geometry[
  77 + torch.arange(batch_size * 8), arg_min.view(-1), :
  78 + ].view(batch_size, 8, 3)
  79 +
  80 + sel_index = torch.cat(
  81 + (
  82 + 3 * right_contours_flat.unsqueeze(1),
  83 + 3 * right_contours_flat.unsqueeze(1) + 1,
  84 + 3 * right_contours_flat.unsqueeze(1) + 2,
  85 + ),
  86 + dim=1,
  87 + ).reshape(-1)
  88 + right_geometry = (
  89 + torch.mm(id_para, self.base_id[:, sel_index])
  90 + + torch.mm(exp_para, self.base_exp[:, sel_index])
  91 + + self.mu[sel_index]
  92 + )
  93 + right_geometry = right_geometry.view(batch_size, -1, 3)
  94 + proj_x = forward_transform(
  95 + right_geometry, euler_angle, trans, focal_length, cxy
  96 + )[:, :, 0]
  97 + proj_x = proj_x.reshape(batch_size, 8, num_per_contour)
  98 + arg_max = proj_x.argmax(dim=2)
  99 + right_geometry = right_geometry.view(batch_size * 8, num_per_contour, 3)
  100 + right_3dlands = right_geometry[
  101 + torch.arange(batch_size * 8), arg_max.view(-1), :
  102 + ].view(batch_size, 8, 3)
  103 +
  104 + sel_index = torch.cat(
  105 + (
  106 + 3 * self.keyinds.unsqueeze(1),
  107 + 3 * self.keyinds.unsqueeze(1) + 1,
  108 + 3 * self.keyinds.unsqueeze(1) + 2,
  109 + ),
  110 + dim=1,
  111 + ).reshape(-1)
  112 + geometry = (
  113 + torch.mm(id_para, self.base_id[:, sel_index])
  114 + + torch.mm(exp_para, self.base_exp[:, sel_index])
  115 + + self.mu[sel_index]
  116 + )
  117 + lands_3d = geometry.view(-1, self.keyinds.shape[0], 3)
  118 + lands_3d[:, :8, :] = left_3dlands
  119 + lands_3d[:, 9:17, :] = right_3dlands
  120 + return lands_3d
  121 +
  122 + def forward_geo_sub(self, id_para, exp_para, sub_index):
  123 + id_para = id_para * self.sig_id
  124 + exp_para = exp_para * self.sig_exp
  125 + sel_index = torch.cat(
  126 + (
  127 + 3 * sub_index.unsqueeze(1),
  128 + 3 * sub_index.unsqueeze(1) + 1,
  129 + 3 * sub_index.unsqueeze(1) + 2,
  130 + ),
  131 + dim=1,
  132 + ).reshape(-1)
  133 + geometry = (
  134 + torch.mm(id_para, self.base_id[:, sel_index])
  135 + + torch.mm(exp_para, self.base_exp[:, sel_index])
  136 + + self.mu[sel_index]
  137 + )
  138 + return geometry.reshape(-1, sub_index.shape[0], 3)
  139 +
  140 + def forward_geo(self, id_para, exp_para):
  141 + id_para = id_para * self.sig_id
  142 + exp_para = exp_para * self.sig_exp
  143 + geometry = (
  144 + torch.mm(id_para, self.base_id)
  145 + + torch.mm(exp_para, self.base_exp)
  146 + + self.mu
  147 + )
  148 + return geometry.reshape(-1, self.point_num, 3)
  149 +
  150 + def forward_tex(self, tex_para):
  151 + tex_para = tex_para * self.sig_tex
  152 + texture = torch.mm(tex_para, self.base_tex) + self.mu_tex
  153 + return texture.reshape(-1, self.point_num, 3)
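A minimal usage sketch of Face_3DMM, assuming this file is importable as facemodel.py and the 3DMM assets live under data_utils/face_tracking/3DMM; point_num depends on the downloaded model and is only an assumed value here:

import torch
from facemodel import Face_3DMM  # assumed module name for this file

id_dim, exp_dim, tex_dim, point_num = 100, 79, 100, 34650  # dims follow the comments in __init__; point_num is assumed
model = Face_3DMM("data_utils/face_tracking/3DMM", id_dim, exp_dim, tex_dim, point_num)

id_para = torch.zeros(1, id_dim).cuda()
exp_para = torch.zeros(1, exp_dim).cuda()
geometry = model.forward_geo(id_para, exp_para)              # [1, point_num, 3]; zero coefficients give the mean face
texture = model.forward_tex(torch.zeros(1, tex_dim).cuda())  # [1, point_num, 3] mean texture
print(geometry.shape, texture.shape)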
  1 +"""Functions for geometry transformation and camera projection."""
  2 +import torch
  3 +import torch.nn as nn
  4 +import numpy as np
  5 +
  6 +
  7 +def euler2rot(euler_angle):
  8 + batch_size = euler_angle.shape[0]
  9 + theta = euler_angle[:, 0].reshape(-1, 1, 1)
  10 + phi = euler_angle[:, 1].reshape(-1, 1, 1)
  11 + psi = euler_angle[:, 2].reshape(-1, 1, 1)
  12 + one = torch.ones((batch_size, 1, 1), dtype=torch.float32, device=euler_angle.device)
  13 + zero = torch.zeros(
  14 + (batch_size, 1, 1), dtype=torch.float32, device=euler_angle.device
  15 + )
  16 + rot_x = torch.cat(
  17 + (
  18 + torch.cat((one, zero, zero), 1),
  19 + torch.cat((zero, theta.cos(), theta.sin()), 1),
  20 + torch.cat((zero, -theta.sin(), theta.cos()), 1),
  21 + ),
  22 + 2,
  23 + )
  24 + rot_y = torch.cat(
  25 + (
  26 + torch.cat((phi.cos(), zero, -phi.sin()), 1),
  27 + torch.cat((zero, one, zero), 1),
  28 + torch.cat((phi.sin(), zero, phi.cos()), 1),
  29 + ),
  30 + 2,
  31 + )
  32 + rot_z = torch.cat(
  33 + (
  34 + torch.cat((psi.cos(), -psi.sin(), zero), 1),
  35 + torch.cat((psi.sin(), psi.cos(), zero), 1),
  36 + torch.cat((zero, zero, one), 1),
  37 + ),
  38 + 2,
  39 + )
  40 + return torch.bmm(rot_x, torch.bmm(rot_y, rot_z))
  41 +
  42 +
  43 +def rot_trans_geo(geometry, rot, trans):
  44 + rott_geo = torch.bmm(rot, geometry.permute(0, 2, 1)) + trans.view(-1, 3, 1)
  45 + return rott_geo.permute(0, 2, 1)
  46 +
  47 +
  48 +def euler_trans_geo(geometry, euler, trans):
  49 + rot = euler2rot(euler)
  50 + return rot_trans_geo(geometry, rot, trans)
  51 +
  52 +
  53 +def proj_geo(rott_geo, camera_para):
  54 + fx = camera_para[:, 0]
  55 + fy = camera_para[:, 0]
  56 + cx = camera_para[:, 1]
  57 + cy = camera_para[:, 2]
  58 +
  59 + X = rott_geo[:, :, 0]
  60 + Y = rott_geo[:, :, 1]
  61 + Z = rott_geo[:, :, 2]
  62 +
  63 + fxX = fx[:, None] * X
  64 + fyY = fy[:, None] * Y
  65 +
  66 + proj_x = -fxX / Z + cx[:, None]
  67 + proj_y = fyY / Z + cy[:, None]
  68 +
  69 + return torch.cat((proj_x[:, :, None], proj_y[:, :, None], Z[:, :, None]), 2)
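A small sanity-check sketch, assuming this file is importable as geo_transform.py: euler2rot should return proper rotation matrices.

import torch
from geo_transform import euler2rot  # assumed module name for this file

euler = torch.tensor([[0.1, -0.2, 0.3]], dtype=torch.float32)
R = euler2rot(euler)                                   # [1, 3, 3]
I = torch.eye(3).unsqueeze(0)
print(torch.allclose(torch.bmm(R, R.transpose(1, 2)), I, atol=1e-6))  # True: R is orthonormal
print(torch.det(R))                                    # ~1.0: a proper rotation, not a reflection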
  1 +import torch
  2 +import torch.nn as nn
  3 +import numpy as np
  4 +import os
  5 +from pytorch3d.structures import Meshes
  6 +from pytorch3d.renderer import (
  7 + look_at_view_transform,
  8 + PerspectiveCameras,
  9 + FoVPerspectiveCameras,
  10 + PointLights,
  11 + DirectionalLights,
  12 + Materials,
  13 + RasterizationSettings,
  14 + MeshRenderer,
  15 + MeshRasterizer,
  16 + SoftPhongShader,
  17 + TexturesUV,
  18 + TexturesVertex,
  19 + blending,
  20 +)
  21 +
  22 +from pytorch3d.ops import interpolate_face_attributes
  23 +
  24 +from pytorch3d.renderer.blending import (
  25 + BlendParams,
  26 + hard_rgb_blend,
  27 + sigmoid_alpha_blend,
  28 + softmax_rgb_blend,
  29 +)
  30 +
  31 +
  32 +class SoftSimpleShader(nn.Module):
  33 + """
  34 + Soft shader without an extra lighting pass: the vertex textures are expected
  35 + to already carry the baked-in illumination, so the forward pass only samples
  36 + them and returns the softly aggregated color over all faces per pixel.
  37 +
  38 + To use the default values, simply initialize the shader with the desired
  39 + device, e.g. shader = SoftSimpleShader(device=torch.device("cuda:0")).
  40 +
  41 + """
  42 +
  43 + def __init__(
  44 + self, device="cpu", cameras=None, lights=None, materials=None, blend_params=None
  45 + ):
  46 + super().__init__()
  47 + self.lights = lights if lights is not None else PointLights(device=device)
  48 + self.materials = (
  49 + materials if materials is not None else Materials(device=device)
  50 + )
  51 + self.cameras = cameras
  52 + self.blend_params = blend_params if blend_params is not None else BlendParams()
  53 +
  54 + def to(self, device):
  55 + # Manually move to device modules which are not subclasses of nn.Module
  56 + self.cameras = self.cameras.to(device)
  57 + self.materials = self.materials.to(device)
  58 + self.lights = self.lights.to(device)
  59 + return self
  60 +
  61 + def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
  62 +
  63 + texels = meshes.sample_textures(fragments)
  64 + blend_params = kwargs.get("blend_params", self.blend_params)
  65 +
  66 + cameras = kwargs.get("cameras", self.cameras)
  67 + if cameras is None:
  68 + msg = "Cameras must be specified either at initialization \
  69 + or in the forward pass of SoftPhongShader"
  70 + raise ValueError(msg)
  71 + znear = kwargs.get("znear", getattr(cameras, "znear", 1.0))
  72 + zfar = kwargs.get("zfar", getattr(cameras, "zfar", 100.0))
  73 + images = softmax_rgb_blend(
  74 + texels, fragments, blend_params, znear=znear, zfar=zfar
  75 + )
  76 + return images
  77 +
  78 +
  79 +class Render_3DMM(nn.Module):
  80 + def __init__(
  81 + self,
  82 + focal=1015,
  83 + img_h=500,
  84 + img_w=500,
  85 + batch_size=1,
  86 + device=torch.device("cuda:0"),
  87 + ):
  88 + super(Render_3DMM, self).__init__()
  89 +
  90 + self.focal = focal
  91 + self.img_h = img_h
  92 + self.img_w = img_w
  93 + self.device = device
  94 + self.renderer = self.get_render(batch_size)
  95 +
  96 + dir_path = os.path.dirname(os.path.realpath(__file__))
  97 + topo_info = np.load(
  98 + os.path.join(dir_path, "3DMM", "topology_info.npy"), allow_pickle=True
  99 + ).item()
  100 + self.tris = torch.as_tensor(topo_info["tris"]).to(self.device)
  101 + self.vert_tris = torch.as_tensor(topo_info["vert_tris"]).to(self.device)
  102 +
  103 + def compute_normal(self, geometry):
  104 + vert_1 = torch.index_select(geometry, 1, self.tris[:, 0])
  105 + vert_2 = torch.index_select(geometry, 1, self.tris[:, 1])
  106 + vert_3 = torch.index_select(geometry, 1, self.tris[:, 2])
  107 + nnorm = torch.cross(vert_2 - vert_1, vert_3 - vert_1, 2)
  108 + tri_normal = nn.functional.normalize(nnorm, dim=2)
  109 + v_norm = tri_normal[:, self.vert_tris, :].sum(2)
  110 + vert_normal = v_norm / v_norm.norm(dim=2).unsqueeze(2)
  111 + return vert_normal
  112 +
  113 + def get_render(self, batch_size=1):
  114 + half_s = self.img_w * 0.5
  115 + R, T = look_at_view_transform(10, 0, 0)
  116 + R = R.repeat(batch_size, 1, 1)
  117 + T = torch.zeros((batch_size, 3), dtype=torch.float32).to(self.device)
  118 +
  119 + cameras = FoVPerspectiveCameras(
  120 + device=self.device,
  121 + R=R,
  122 + T=T,
  123 + znear=0.01,
  124 + zfar=20,
  125 + fov=2 * np.arctan(self.img_w // 2 / self.focal) * 180.0 / np.pi,
  126 + )
  127 + lights = PointLights(
  128 + device=self.device,
  129 + location=[[0.0, 0.0, 1e5]],
  130 + ambient_color=[[1, 1, 1]],
  131 + specular_color=[[0.0, 0.0, 0.0]],
  132 + diffuse_color=[[0.0, 0.0, 0.0]],
  133 + )
  134 + sigma = 1e-4
  135 + raster_settings = RasterizationSettings(
  136 + image_size=(self.img_h, self.img_w),
  137 + blur_radius=np.log(1.0 / 1e-4 - 1.0) * sigma / 18.0,
  138 + faces_per_pixel=2,
  139 + perspective_correct=False,
  140 + )
  141 + blend_params = blending.BlendParams(background_color=[0, 0, 0])
  142 + renderer = MeshRenderer(
  143 + rasterizer=MeshRasterizer(raster_settings=raster_settings, cameras=cameras),
  144 + shader=SoftSimpleShader(
  145 + lights=lights, blend_params=blend_params, cameras=cameras
  146 + ),
  147 + )
  148 + return renderer.to(self.device)
  149 +
  150 + @staticmethod
  151 + def Illumination_layer(face_texture, norm, gamma):
  152 +
  153 + n_b, num_vertex, _ = face_texture.size()
  154 + n_v_full = n_b * num_vertex
  155 + gamma = gamma.view(-1, 3, 9).clone()
  156 + gamma[:, :, 0] += 0.8
  157 +
  158 + gamma = gamma.permute(0, 2, 1)
  159 +
  160 + a0 = np.pi
  161 + a1 = 2 * np.pi / np.sqrt(3.0)
  162 + a2 = 2 * np.pi / np.sqrt(8.0)
  163 + c0 = 1 / np.sqrt(4 * np.pi)
  164 + c1 = np.sqrt(3.0) / np.sqrt(4 * np.pi)
  165 + c2 = 3 * np.sqrt(5.0) / np.sqrt(12 * np.pi)
  166 + d0 = 0.5 / np.sqrt(3.0)
  167 +
  168 + Y0 = torch.ones(n_v_full).to(gamma.device).float() * a0 * c0
  169 + norm = norm.view(-1, 3)
  170 + nx, ny, nz = norm[:, 0], norm[:, 1], norm[:, 2]
  171 + arrH = []
  172 +
  173 + arrH.append(Y0)
  174 + arrH.append(-a1 * c1 * ny)
  175 + arrH.append(a1 * c1 * nz)
  176 + arrH.append(-a1 * c1 * nx)
  177 + arrH.append(a2 * c2 * nx * ny)
  178 + arrH.append(-a2 * c2 * ny * nz)
  179 + arrH.append(a2 * c2 * d0 * (3 * nz.pow(2) - 1))
  180 + arrH.append(-a2 * c2 * nx * nz)
  181 + arrH.append(a2 * c2 * 0.5 * (nx.pow(2) - ny.pow(2)))
  182 +
  183 + H = torch.stack(arrH, 1)
  184 + Y = H.view(n_b, num_vertex, 9)
  185 + lighting = Y.bmm(gamma)
  186 +
  187 + face_color = face_texture * lighting
  188 + return face_color
  189 +
  190 + def forward(self, rott_geometry, texture, diffuse_sh):
  191 + face_normal = self.compute_normal(rott_geometry)
  192 + face_color = self.Illumination_layer(texture, face_normal, diffuse_sh)
  193 + face_color = TexturesVertex(face_color)
  194 + mesh = Meshes(
  195 + rott_geometry,
  196 + self.tris.float().repeat(rott_geometry.shape[0], 1, 1),
  197 + face_color,
  198 + )
  199 + rendered_img = self.renderer(mesh)
  200 + rendered_img = torch.clamp(rendered_img, 0, 255)
  201 +
  202 + return rendered_img
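The 27-dim light vector optimised by the tracker packs 9 second-order spherical-harmonic coefficients per RGB channel. A toy sketch of the (static) Illumination_layer, assuming this file is importable as render_3dmm.py:

import torch
from render_3dmm import Render_3DMM  # assumed module name; importing it requires pytorch3d

texture = torch.rand(1, 5, 3)                                    # 1 mesh, 5 vertices, RGB vertex colors
normals = torch.nn.functional.normalize(torch.rand(1, 5, 3), dim=2)
gamma = torch.zeros(1, 27)                                       # 3 channels x 9 SH coefficients
lit = Render_3DMM.Illumination_layer(texture, normals, gamma)    # with zero gamma only the +0.8 DC (ambient) term lights the mesh
print(lit.shape)                                                 # torch.Size([1, 5, 3])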
  1 +import torch
  2 +import torch.nn as nn
  3 +import render_util
  4 +import geo_transform
  5 +import numpy as np
  6 +
  7 +
  8 +def compute_tri_normal(geometry, tris):
  9 + geometry = geometry.permute(0, 2, 1)
  10 + tri_1 = tris[:, 0]
  11 + tri_2 = tris[:, 1]
  12 + tri_3 = tris[:, 2]
  13 +
  14 + vert_1 = torch.index_select(geometry, 2, tri_1)
  15 + vert_2 = torch.index_select(geometry, 2, tri_2)
  16 + vert_3 = torch.index_select(geometry, 2, tri_3)
  17 +
  18 + nnorm = torch.cross(vert_2 - vert_1, vert_3 - vert_1, 1)
  19 + normal = nn.functional.normalize(nnorm).permute(0, 2, 1)
  20 + return normal
  21 +
  22 +
  23 +class Compute_normal_base(torch.autograd.Function):
  24 + @staticmethod
  25 + def forward(ctx, normal):
  26 + (normal_b,) = render_util.normal_base_forward(normal)
  27 + ctx.save_for_backward(normal)
  28 + return normal_b
  29 +
  30 + @staticmethod
  31 + def backward(ctx, grad_normal_b):
  32 + (normal,) = ctx.saved_tensors
  33 + (grad_normal,) = render_util.normal_base_backward(grad_normal_b, normal)
  34 + return grad_normal
  35 +
  36 +
  37 +class Normal_Base(torch.nn.Module):
  38 + def __init__(self):
  39 + super(Normal_Base, self).__init__()
  40 +
  41 + def forward(self, normal):
  42 + return Compute_normal_base.apply(normal)
  43 +
  44 +
  45 +def preprocess_render(geometry, euler, trans, cam, tris, vert_tris, ori_img):
  46 + point_num = geometry.shape[1]
  47 + rott_geo = geo_transform.euler_trans_geo(geometry, euler, trans)
  48 + proj_geo = geo_transform.proj_geo(rott_geo, cam)
  49 + rot_tri_normal = compute_tri_normal(rott_geo, tris)
  50 + rot_vert_normal = torch.index_select(rot_tri_normal, 1, vert_tris)
  51 + is_visible = -torch.bmm(
  52 + rot_vert_normal.reshape(-1, 1, 3),
  53 + nn.functional.normalize(rott_geo.reshape(-1, 3, 1)),
  54 + ).reshape(-1, point_num)
  55 + is_visible[is_visible < 0.01] = -1
  56 + pixel_valid = torch.zeros(
  57 + (ori_img.shape[0], ori_img.shape[1] * ori_img.shape[2]),
  58 + dtype=torch.float32,
  59 + device=ori_img.device,
  60 + )
  61 + return rott_geo, proj_geo, rot_tri_normal, is_visible, pixel_valid
  62 +
  63 +
  64 +class Render_Face(torch.autograd.Function):
  65 + @staticmethod
  66 + def forward(
  67 + ctx, proj_geo, texture, nbl, ori_img, is_visible, tri_inds, pixel_valid
  68 + ):
  69 + batch_size, h, w, _ = ori_img.shape
  70 + ori_img = ori_img.view(batch_size, -1, 3)
  71 + ori_size = torch.cat(
  72 + (
  73 + torch.ones((batch_size, 1), dtype=torch.int32, device=ori_img.device)
  74 + * h,
  75 + torch.ones((batch_size, 1), dtype=torch.int32, device=ori_img.device)
  76 + * w,
  77 + ),
  78 + dim=1,
  79 + ).view(-1)
  80 + tri_index, tri_coord, render, real = render_util.render_face_forward(
  81 + proj_geo, ori_img, ori_size, texture, nbl, is_visible, tri_inds, pixel_valid
  82 + )
  83 + ctx.save_for_backward(
  84 + ori_img, ori_size, proj_geo, texture, nbl, tri_inds, tri_index, tri_coord
  85 + )
  86 + return render, real
  87 +
  88 + @staticmethod
  89 + def backward(ctx, grad_render, grad_real):
  90 + (
  91 + ori_img,
  92 + ori_size,
  93 + proj_geo,
  94 + texture,
  95 + nbl,
  96 + tri_inds,
  97 + tri_index,
  98 + tri_coord,
  99 + ) = ctx.saved_tensors
  100 + grad_proj_geo, grad_texture, grad_nbl = render_util.render_face_backward(
  101 + grad_render,
  102 + grad_real,
  103 + ori_img,
  104 + ori_size,
  105 + proj_geo,
  106 + texture,
  107 + nbl,
  108 + tri_inds,
  109 + tri_index,
  110 + tri_coord,
  111 + )
  112 + return grad_proj_geo, grad_texture, grad_nbl, None, None, None, None
  113 +
  114 +
  115 +class Render_RGB(nn.Module):
  116 + def __init__(self):
  117 + super(Render_RGB, self).__init__()
  118 +
  119 + def forward(
  120 + self, proj_geo, texture, nbl, ori_img, is_visible, tri_inds, pixel_valid
  121 + ):
  122 + return Render_Face.apply(
  123 + proj_geo, texture, nbl, ori_img, is_visible, tri_inds, pixel_valid
  124 + )
  125 +
  126 +
  127 +def cal_land(proj_geo, is_visible, lands_info, land_num):
  128 + (land_index,) = render_util.update_contour(lands_info, is_visible, land_num)
  129 + proj_land = torch.index_select(proj_geo.reshape(-1, 3), 0, land_index)[
  130 + :, :2
  131 + ].reshape(-1, land_num, 2)
  132 + return proj_land
  133 +
  134 +
  135 +class Render_Land(nn.Module):
  136 + def __init__(self):
  137 + super(Render_Land, self).__init__()
  138 + lands_info = np.loadtxt("../data/3DMM/lands_info.txt", dtype=np.int32)
  139 + self.lands_info = torch.as_tensor(lands_info).cuda()
  140 + tris = np.loadtxt("../data/3DMM/tris.txt", dtype=np.int64)
  141 + self.tris = torch.as_tensor(tris).cuda() - 1
  142 + vert_tris = np.loadtxt("../data/3DMM/vert_tris.txt", dtype=np.int64)
  143 + self.vert_tris = torch.as_tensor(vert_tris).cuda()
  144 + self.normal_baser = Normal_Base().cuda()
  145 + self.renderer = Render_RGB().cuda()
  146 +
  147 + def render_mesh(self, geometry, euler, trans, cam, ori_img, light):
  148 + batch_size, h, w, _ = ori_img.shape
  149 + ori_img = ori_img.view(batch_size, -1, 3)
  150 + ori_size = torch.cat(
  151 + (
  152 + torch.ones((batch_size, 1), dtype=torch.int32, device=ori_img.device)
  153 + * h,
  154 + torch.ones((batch_size, 1), dtype=torch.int32, device=ori_img.device)
  155 + * w,
  156 + ),
  157 + dim=1,
  158 + ).view(-1)
  159 + rott_geo, proj_geo, rot_tri_normal, _, _ = preprocess_render(
  160 + geometry, euler, trans, cam, self.tris, self.vert_tris, ori_img
  161 + )
  162 + tri_nb = self.normal_baser(rot_tri_normal.contiguous())
  163 + nbl = torch.bmm(
  164 + tri_nb, (light.reshape(-1, 9, 3))[:, :, 0].unsqueeze(-1).repeat(1, 1, 3)
  165 + )
  166 + texture = torch.ones_like(geometry) * 200
  167 + (render,) = render_util.render_mesh(
  168 + proj_geo, ori_img, ori_size, texture, nbl, self.tris
  169 + )
  170 + return render.view(batch_size, h, w, 3).byte()
  171 +
  172 + def cal_loss_rgb(self, geometry, euler, trans, cam, ori_img, light, texture, lands):
  173 + rott_geo, proj_geo, rot_tri_normal, is_visible, pixel_valid = preprocess_render(
  174 + geometry, euler, trans, cam, self.tris, self.vert_tris, ori_img
  175 + )
  176 + tri_nb = self.normal_baser(rot_tri_normal.contiguous())
  177 + nbl = torch.bmm(tri_nb, light.reshape(-1, 9, 3))
  178 + render, real = self.renderer(
  179 + proj_geo, texture, nbl, ori_img, is_visible, self.tris, pixel_valid
  180 + )
  181 + proj_land = cal_land(proj_geo, is_visible, self.lands_info, lands.shape[1])
  182 + col_minus = torch.norm((render - real).reshape(-1, 3), dim=1).reshape(
  183 + ori_img.shape[0], -1
  184 + )
  185 + col_dis = torch.mean(col_minus * pixel_valid) / (
  186 + torch.mean(pixel_valid) + 0.00001
  187 + )
  188 + land_dists = torch.norm((proj_land - lands).reshape(-1, 2), dim=1).reshape(
  189 + ori_img.shape[0], -1
  190 + )
  191 + lan_dis = torch.mean(land_dists)
  192 + return col_dis, lan_dis
  1 +import torch
  2 +import torch.nn as nn
  3 +import torch.nn.functional as F
  4 +
  5 +
  6 +def compute_tri_normal(geometry, tris):
  7 + tri_1 = tris[:, 0]
  8 + tri_2 = tris[:, 1]
  9 + tri_3 = tris[:, 2]
  10 + vert_1 = torch.index_select(geometry, 1, tri_1)
  11 + vert_2 = torch.index_select(geometry, 1, tri_2)
  12 + vert_3 = torch.index_select(geometry, 1, tri_3)
  13 + nnorm = torch.cross(vert_2 - vert_1, vert_3 - vert_1, 2)
  14 + normal = nn.functional.normalize(nnorm, dim=2)  # normalize each triangle normal (xyz in the last dim)
  15 + return normal
  16 +
  17 +
  18 +def euler2rot(euler_angle):
  19 + batch_size = euler_angle.shape[0]
  20 + theta = euler_angle[:, 0].reshape(-1, 1, 1)
  21 + phi = euler_angle[:, 1].reshape(-1, 1, 1)
  22 + psi = euler_angle[:, 2].reshape(-1, 1, 1)
  23 + one = torch.ones(batch_size, 1, 1).to(euler_angle.device)
  24 + zero = torch.zeros(batch_size, 1, 1).to(euler_angle.device)
  25 + rot_x = torch.cat(
  26 + (
  27 + torch.cat((one, zero, zero), 1),
  28 + torch.cat((zero, theta.cos(), theta.sin()), 1),
  29 + torch.cat((zero, -theta.sin(), theta.cos()), 1),
  30 + ),
  31 + 2,
  32 + )
  33 + rot_y = torch.cat(
  34 + (
  35 + torch.cat((phi.cos(), zero, -phi.sin()), 1),
  36 + torch.cat((zero, one, zero), 1),
  37 + torch.cat((phi.sin(), zero, phi.cos()), 1),
  38 + ),
  39 + 2,
  40 + )
  41 + rot_z = torch.cat(
  42 + (
  43 + torch.cat((psi.cos(), -psi.sin(), zero), 1),
  44 + torch.cat((psi.sin(), psi.cos(), zero), 1),
  45 + torch.cat((zero, zero, one), 1),
  46 + ),
  47 + 2,
  48 + )
  49 + return torch.bmm(rot_x, torch.bmm(rot_y, rot_z))
  50 +
  51 +
  52 +def rot_trans_pts(geometry, rot, trans):
  53 + rott_geo = torch.bmm(rot, geometry.permute(0, 2, 1)) + trans[:, :, None]
  54 + return rott_geo.permute(0, 2, 1)
  55 +
  56 +
  57 +def cal_lap_loss(tensor_list, weight_list):
  58 + lap_kernel = (
  59 + torch.Tensor((-0.5, 1.0, -0.5))
  60 + .unsqueeze(0)
  61 + .unsqueeze(0)
  62 + .float()
  63 + .to(tensor_list[0].device)
  64 + )
  65 + loss_lap = 0
  66 + for i in range(len(tensor_list)):
  67 + in_tensor = tensor_list[i]
  68 + in_tensor = in_tensor.view(-1, 1, in_tensor.shape[-1])
  69 + out_tensor = F.conv1d(in_tensor, lap_kernel)
  70 + loss_lap += torch.mean(out_tensor ** 2) * weight_list[i]
  71 + return loss_lap
  72 +
  73 +
  74 +def proj_pts(rott_geo, focal_length, cxy):
  75 + cx, cy = cxy[0], cxy[1]
  76 + X = rott_geo[:, :, 0]
  77 + Y = rott_geo[:, :, 1]
  78 + Z = rott_geo[:, :, 2]
  79 + fxX = focal_length * X
  80 + fyY = focal_length * Y
  81 + proj_x = -fxX / Z + cx
  82 + proj_y = fyY / Z + cy
  83 + return torch.cat((proj_x[:, :, None], proj_y[:, :, None], Z[:, :, None]), 2)
  84 +
  85 +
  86 +def forward_rott(geometry, euler_angle, trans):
  87 + rot = euler2rot(euler_angle)
  88 + rott_geo = rot_trans_pts(geometry, rot, trans)
  89 + return rott_geo
  90 +
  91 +
  92 +def forward_transform(geometry, euler_angle, trans, focal_length, cxy):
  93 + rot = euler2rot(euler_angle)
  94 + rott_geo = rot_trans_pts(geometry, rot, trans)
  95 + proj_geo = proj_pts(rott_geo, focal_length, cxy)
  96 + return proj_geo
  97 +
  98 +
  99 +def cal_lan_loss(proj_lan, gt_lan):
  100 + return torch.mean((proj_lan - gt_lan) ** 2)
  101 +
  102 +
  103 +def cal_col_loss(pred_img, gt_img, img_mask):
  104 + pred_img = pred_img.float()
  105 + # loss = torch.sqrt(torch.sum(torch.square(pred_img - gt_img), 3))*img_mask/255
  106 + loss = (torch.sum(torch.square(pred_img - gt_img), 3)) * img_mask / 255
  107 + loss = torch.sum(loss, dim=(1, 2)) / torch.sum(img_mask, dim=(1, 2))
  108 + loss = torch.mean(loss)
  109 + return loss
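A toy sketch of the loss helpers, assuming this file is importable as util.py: cal_lan_loss is a plain MSE over landmarks, and cal_lap_loss penalises the discrete Laplacian along the last dimension (time, in the tracker's usage).

import torch
from util import cal_lan_loss, cal_lap_loss  # assumed module name for this file

pred_lms = torch.zeros(1, 68, 2)
gt_lms = torch.ones(1, 68, 2)
print(cal_lan_loss(pred_lms, gt_lms))        # tensor(1.) - mean squared landmark error

traj = torch.arange(6, dtype=torch.float32).view(1, 6)   # a perfectly linear trajectory over 6 frames
print(cal_lap_loss([traj], [1.0]))           # tensor(0.) - zero second difference, so no smoothness penalty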
  1 +import os
  2 +import glob
  3 +import tqdm
  4 +import json
  5 +import argparse
  6 +import cv2
  7 +import numpy as np
  8 +
  9 +def extract_audio(path, out_path, sample_rate=16000):
  10 +
  11 + print(f'[INFO] ===== extract audio from {path} to {out_path} =====')
  12 + cmd = f'ffmpeg -i {path} -f wav -ar {sample_rate} {out_path}'
  13 + os.system(cmd)
  14 + print(f'[INFO] ===== extracted audio =====')
  15 +
  16 +
  17 +def extract_audio_features(path, mode='wav2vec'):
  18 +
  19 + print(f'[INFO] ===== extract audio labels for {path} =====')
  20 + if mode == 'wav2vec':
  21 + cmd = f'python nerf/asr.py --wav {path} --save_feats'
  22 + else: # deepspeech
  23 + cmd = f'python data_utils/deepspeech_features/extract_ds_features.py --input {path}'
  24 + os.system(cmd)
  25 + print(f'[INFO] ===== extracted audio labels =====')
  26 +
  27 +
  28 +
  29 +def extract_images(path, out_path, fps=25):
  30 +
  31 + print(f'[INFO] ===== extract images from {path} to {out_path} =====')
  32 + cmd = f'ffmpeg -i {path} -vf fps={fps} -qmin 1 -q:v 1 -start_number 0 {os.path.join(out_path, "%d.jpg")}'
  33 + os.system(cmd)
  34 + print(f'[INFO] ===== extracted images =====')
  35 +
  36 +
  37 +def extract_semantics(ori_imgs_dir, parsing_dir):
  38 +
  39 + print(f'[INFO] ===== extract semantics from {ori_imgs_dir} to {parsing_dir} =====')
  40 + cmd = f'python data_utils/face_parsing/test.py --respath={parsing_dir} --imgpath={ori_imgs_dir}'
  41 + os.system(cmd)
  42 + print(f'[INFO] ===== extracted semantics =====')
  43 +
  44 +
  45 +def extract_landmarks(ori_imgs_dir):
  46 +
  47 + print(f'[INFO] ===== extract face landmarks from {ori_imgs_dir} =====')
  48 +
  49 + import face_alignment
  50 + fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=False)
  51 + image_paths = glob.glob(os.path.join(ori_imgs_dir, '*.jpg'))
  52 + for image_path in tqdm.tqdm(image_paths):
  53 + input = cv2.imread(image_path, cv2.IMREAD_UNCHANGED) # [H, W, 3]
  54 + input = cv2.cvtColor(input, cv2.COLOR_BGR2RGB)
  55 + preds = fa.get_landmarks(input)
  56 + if preds is not None and len(preds) > 0:  # get_landmarks returns None when no face is detected
  57 + lands = preds[0].reshape(-1, 2)[:,:2]
  58 + np.savetxt(image_path.replace('jpg', 'lms'), lands, '%f')
  59 + del fa
  60 + print(f'[INFO] ===== extracted face landmarks =====')
  61 +
  62 +
  63 +def extract_background(base_dir, ori_imgs_dir):
  64 +
  65 + print(f'[INFO] ===== extract background image from {ori_imgs_dir} =====')
  66 +
  67 + from sklearn.neighbors import NearestNeighbors
  68 +
  69 + image_paths = glob.glob(os.path.join(ori_imgs_dir, '*.jpg'))
  70 + # only use 1/20 image_paths
  71 + image_paths = image_paths[::20]
  72 + # read one image to get H/W
  73 + tmp_image = cv2.imread(image_paths[0], cv2.IMREAD_UNCHANGED) # [H, W, 3]
  74 + h, w = tmp_image.shape[:2]
  75 +
  76 + # nearest neighbors
  77 + all_xys = np.mgrid[0:h, 0:w].reshape(2, -1).transpose()
  78 + distss = []
  79 + for image_path in tqdm.tqdm(image_paths):
  80 + parse_img = cv2.imread(image_path.replace('ori_imgs', 'parsing').replace('.jpg', '.png'))
  81 + bg = (parse_img[..., 0] == 255) & (parse_img[..., 1] == 255) & (parse_img[..., 2] == 255)
  82 + fg_xys = np.stack(np.nonzero(~bg)).transpose(1, 0)
  83 + nbrs = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(fg_xys)
  84 + dists, _ = nbrs.kneighbors(all_xys)
  85 + distss.append(dists)
  86 +
  87 + distss = np.stack(distss)
  88 + max_dist = np.max(distss, 0)
  89 + max_id = np.argmax(distss, 0)
  90 +
  91 + bc_pixs = max_dist > 5
  92 + bc_pixs_id = np.nonzero(bc_pixs)
  93 + bc_ids = max_id[bc_pixs]
  94 +
  95 + imgs = []
  96 + num_pixs = distss.shape[1]
  97 + for image_path in image_paths:
  98 + img = cv2.imread(image_path)
  99 + imgs.append(img)
  100 + imgs = np.stack(imgs).reshape(-1, num_pixs, 3)
  101 +
  102 + bc_img = np.zeros((h*w, 3), dtype=np.uint8)
  103 + bc_img[bc_pixs_id, :] = imgs[bc_ids, bc_pixs_id, :]
  104 + bc_img = bc_img.reshape(h, w, 3)
  105 +
  106 + max_dist = max_dist.reshape(h, w)
  107 + bc_pixs = max_dist > 5
  108 + bg_xys = np.stack(np.nonzero(~bc_pixs)).transpose()
  109 + fg_xys = np.stack(np.nonzero(bc_pixs)).transpose()
  110 + nbrs = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(fg_xys)
  111 + distances, indices = nbrs.kneighbors(bg_xys)
  112 + bg_fg_xys = fg_xys[indices[:, 0]]
  113 + bc_img[bg_xys[:, 0], bg_xys[:, 1], :] = bc_img[bg_fg_xys[:, 0], bg_fg_xys[:, 1], :]
  114 +
  115 + cv2.imwrite(os.path.join(base_dir, 'bc.jpg'), bc_img)
  116 +
  117 + print(f'[INFO] ===== extracted background image =====')
  118 +
  119 +
  120 +def extract_torso_and_gt(base_dir, ori_imgs_dir):
  121 +
  122 + print(f'[INFO] ===== extract torso and gt images for {base_dir} =====')
  123 +
  124 + from scipy.ndimage import binary_erosion, binary_dilation
  125 +
  126 + # load bg
  127 + bg_image = cv2.imread(os.path.join(base_dir, 'bc.jpg'), cv2.IMREAD_UNCHANGED)
  128 +
  129 + image_paths = glob.glob(os.path.join(ori_imgs_dir, '*.jpg'))
  130 +
  131 + for image_path in tqdm.tqdm(image_paths):
  132 + # read ori image
  133 + ori_image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED) # [H, W, 3]
  134 +
  135 + # read semantics
  136 + seg = cv2.imread(image_path.replace('ori_imgs', 'parsing').replace('.jpg', '.png'))
  137 + head_part = (seg[..., 0] == 255) & (seg[..., 1] == 0) & (seg[..., 2] == 0)
  138 + neck_part = (seg[..., 0] == 0) & (seg[..., 1] == 255) & (seg[..., 2] == 0)
  139 + torso_part = (seg[..., 0] == 0) & (seg[..., 1] == 0) & (seg[..., 2] == 255)
  140 + bg_part = (seg[..., 0] == 255) & (seg[..., 1] == 255) & (seg[..., 2] == 255)
  141 +
  142 + # get gt image
  143 + gt_image = ori_image.copy()
  144 + gt_image[bg_part] = bg_image[bg_part]
  145 + cv2.imwrite(image_path.replace('ori_imgs', 'gt_imgs'), gt_image)
  146 +
  147 + # get torso image
  148 + torso_image = gt_image.copy() # rgb
  149 + torso_image[head_part] = bg_image[head_part]
  150 + torso_alpha = 255 * np.ones((gt_image.shape[0], gt_image.shape[1], 1), dtype=np.uint8) # alpha
  151 +
  152 + # torso part "vertical" in-painting...
  153 + L = 8 + 1
  154 + torso_coords = np.stack(np.nonzero(torso_part), axis=-1) # [M, 2]
  155 + # lexsort: sort 2D coords first by y then by x,
  156 + # ref: https://stackoverflow.com/questions/2706605/sorting-a-2d-numpy-array-by-multiple-axes
  157 + inds = np.lexsort((torso_coords[:, 0], torso_coords[:, 1]))
  158 + torso_coords = torso_coords[inds]
  159 + # choose the top pixel for each column
  160 + u, uid, ucnt = np.unique(torso_coords[:, 1], return_index=True, return_counts=True)
  161 + top_torso_coords = torso_coords[uid] # [m, 2]
  162 + # only keep top-is-head pixels
  163 + top_torso_coords_up = top_torso_coords.copy() - np.array([1, 0])
  164 + mask = head_part[tuple(top_torso_coords_up.T)]
  165 + if mask.any():
  166 + top_torso_coords = top_torso_coords[mask]
  167 + # get the color
  168 + top_torso_colors = gt_image[tuple(top_torso_coords.T)] # [m, 3]
  169 + # construct inpaint coords (vertically up, or minus in x)
  170 + inpaint_torso_coords = top_torso_coords[None].repeat(L, 0) # [L, m, 2]
  171 + inpaint_offsets = np.stack([-np.arange(L), np.zeros(L, dtype=np.int32)], axis=-1)[:, None] # [L, 1, 2]
  172 + inpaint_torso_coords += inpaint_offsets
  173 + inpaint_torso_coords = inpaint_torso_coords.reshape(-1, 2) # [Lm, 2]
  174 + inpaint_torso_colors = top_torso_colors[None].repeat(L, 0) # [L, m, 3]
  175 + darken_scaler = 0.98 ** np.arange(L).reshape(L, 1, 1) # [L, 1, 1]
  176 + inpaint_torso_colors = (inpaint_torso_colors * darken_scaler).reshape(-1, 3) # [Lm, 3]
  177 + # set color
  178 + torso_image[tuple(inpaint_torso_coords.T)] = inpaint_torso_colors
  179 +
  180 + inpaint_torso_mask = np.zeros_like(torso_image[..., 0]).astype(bool)
  181 + inpaint_torso_mask[tuple(inpaint_torso_coords.T)] = True
  182 + else:
  183 + inpaint_torso_mask = None
  184 +
  185 +
  186 + # neck part "vertical" in-painting...
  187 + push_down = 4
  188 + L = 48 + push_down + 1
  189 +
  190 + neck_part = binary_dilation(neck_part, structure=np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=bool), iterations=3)
  191 +
  192 + neck_coords = np.stack(np.nonzero(neck_part), axis=-1) # [M, 2]
  193 + # lexsort: sort 2D coords first by y then by x,
  194 + # ref: https://stackoverflow.com/questions/2706605/sorting-a-2d-numpy-array-by-multiple-axes
  195 + inds = np.lexsort((neck_coords[:, 0], neck_coords[:, 1]))
  196 + neck_coords = neck_coords[inds]
  197 + # choose the top pixel for each column
  198 + u, uid, ucnt = np.unique(neck_coords[:, 1], return_index=True, return_counts=True)
  199 + top_neck_coords = neck_coords[uid] # [m, 2]
  200 + # only keep top-is-head pixels
  201 + top_neck_coords_up = top_neck_coords.copy() - np.array([1, 0])
  202 + mask = head_part[tuple(top_neck_coords_up.T)]
  203 +
  204 + top_neck_coords = top_neck_coords[mask]
  205 + # push these top pixels down by up to 4 pixels to make the neck inpainting more natural...
  206 + offset_down = np.minimum(ucnt[mask] - 1, push_down)
  207 + top_neck_coords += np.stack([offset_down, np.zeros_like(offset_down)], axis=-1)
  208 + # get the color
  209 + top_neck_colors = gt_image[tuple(top_neck_coords.T)] # [m, 3]
  210 + # construct inpaint coords (vertically up, or minus in x)
  211 + inpaint_neck_coords = top_neck_coords[None].repeat(L, 0) # [L, m, 2]
  212 + inpaint_offsets = np.stack([-np.arange(L), np.zeros(L, dtype=np.int32)], axis=-1)[:, None] # [L, 1, 2]
  213 + inpaint_neck_coords += inpaint_offsets
  214 + inpaint_neck_coords = inpaint_neck_coords.reshape(-1, 2) # [Lm, 2]
  215 + inpaint_neck_colors = top_neck_colors[None].repeat(L, 0) # [L, m, 3]
  216 + darken_scaler = 0.98 ** np.arange(L).reshape(L, 1, 1) # [L, 1, 1]
  217 + inpaint_neck_colors = (inpaint_neck_colors * darken_scaler).reshape(-1, 3) # [Lm, 3]
  218 + # set color
  219 + torso_image[tuple(inpaint_neck_coords.T)] = inpaint_neck_colors
  220 +
  221 + # apply blurring to the inpaint area to avoid vertical-line artifacts...
  222 + inpaint_mask = np.zeros_like(torso_image[..., 0]).astype(bool)
  223 + inpaint_mask[tuple(inpaint_neck_coords.T)] = True
  224 +
  225 + blur_img = torso_image.copy()
  226 + blur_img = cv2.GaussianBlur(blur_img, (5, 5), cv2.BORDER_DEFAULT)
  227 +
  228 + torso_image[inpaint_mask] = blur_img[inpaint_mask]
  229 +
  230 + # set mask
  231 + mask = (neck_part | torso_part | inpaint_mask)
  232 + if inpaint_torso_mask is not None:
  233 + mask = mask | inpaint_torso_mask
  234 + torso_image[~mask] = 0
  235 + torso_alpha[~mask] = 0
  236 +
  237 + cv2.imwrite(image_path.replace('ori_imgs', 'torso_imgs').replace('.jpg', '.png'), np.concatenate([torso_image, torso_alpha], axis=-1))
  238 +
  239 + print(f'[INFO] ===== extracted torso and gt images =====')
  240 +
  241 +
  242 +def face_tracking(ori_imgs_dir):
  243 +
  244 + print(f'[INFO] ===== perform face tracking =====')
  245 +
  246 + image_paths = glob.glob(os.path.join(ori_imgs_dir, '*.jpg'))
  247 +
  248 + # read one image to get H/W
  249 + tmp_image = cv2.imread(image_paths[0], cv2.IMREAD_UNCHANGED) # [H, W, 3]
  250 + h, w = tmp_image.shape[:2]
  251 +
  252 + cmd = f'python data_utils/face_tracking/face_tracker.py --path={ori_imgs_dir} --img_h={h} --img_w={w} --frame_num={len(image_paths)}'
  253 +
  254 + os.system(cmd)
  255 +
  256 + print(f'[INFO] ===== finished face tracking =====')
  257 +
  258 +
  259 +def save_transforms(base_dir, ori_imgs_dir):
  260 + print(f'[INFO] ===== save transforms =====')
  261 +
  262 + import torch
  263 +
  264 + image_paths = glob.glob(os.path.join(ori_imgs_dir, '*.jpg'))
  265 +
  266 + # read one image to get H/W
  267 + tmp_image = cv2.imread(image_paths[0], cv2.IMREAD_UNCHANGED) # [H, W, 3]
  268 + h, w = tmp_image.shape[:2]
  269 +
  270 + params_dict = torch.load(os.path.join(base_dir, 'track_params.pt'))
  271 + focal_len = params_dict['focal']
  272 + euler_angle = params_dict['euler']
  273 + trans = params_dict['trans'] / 10.0
  274 + valid_num = euler_angle.shape[0]
  275 +
  276 + def euler2rot(euler_angle):
  277 + batch_size = euler_angle.shape[0]
  278 + theta = euler_angle[:, 0].reshape(-1, 1, 1)
  279 + phi = euler_angle[:, 1].reshape(-1, 1, 1)
  280 + psi = euler_angle[:, 2].reshape(-1, 1, 1)
  281 + one = torch.ones((batch_size, 1, 1), dtype=torch.float32, device=euler_angle.device)
  282 + zero = torch.zeros((batch_size, 1, 1), dtype=torch.float32, device=euler_angle.device)
  283 + rot_x = torch.cat((
  284 + torch.cat((one, zero, zero), 1),
  285 + torch.cat((zero, theta.cos(), theta.sin()), 1),
  286 + torch.cat((zero, -theta.sin(), theta.cos()), 1),
  287 + ), 2)
  288 + rot_y = torch.cat((
  289 + torch.cat((phi.cos(), zero, -phi.sin()), 1),
  290 + torch.cat((zero, one, zero), 1),
  291 + torch.cat((phi.sin(), zero, phi.cos()), 1),
  292 + ), 2)
  293 + rot_z = torch.cat((
  294 + torch.cat((psi.cos(), -psi.sin(), zero), 1),
  295 + torch.cat((psi.sin(), psi.cos(), zero), 1),
  296 + torch.cat((zero, zero, one), 1)
  297 + ), 2)
  298 + return torch.bmm(rot_x, torch.bmm(rot_y, rot_z))
  299 +
  300 +
  301 + # train_val_split = int(valid_num*0.5)
  302 + # train_val_split = valid_num - 25 * 20 # take the last 20s as valid set.
  303 + train_val_split = int(valid_num * 10 / 11)
  304 +
  305 + train_ids = torch.arange(0, train_val_split)
  306 + val_ids = torch.arange(train_val_split, valid_num)
  307 +
  308 + rot = euler2rot(euler_angle)
  309 + rot_inv = rot.permute(0, 2, 1)
  310 + trans_inv = -torch.bmm(rot_inv, trans.unsqueeze(2))
  311 +
  312 + pose = torch.eye(4, dtype=torch.float32)
  313 + save_ids = ['train', 'val']
  314 + train_val_ids = [train_ids, val_ids]
  315 + mean_z = -float(torch.mean(trans[:, 2]).item())
  316 +
  317 + for split in range(2):
  318 + transform_dict = dict()
  319 + transform_dict['focal_len'] = float(focal_len[0])
  320 + transform_dict['cx'] = float(w/2.0)
  321 + transform_dict['cy'] = float(h/2.0)
  322 + transform_dict['frames'] = []
  323 + ids = train_val_ids[split]
  324 + save_id = save_ids[split]
  325 +
  326 + for i in ids:
  327 + i = i.item()
  328 + frame_dict = dict()
  329 + frame_dict['img_id'] = i
  330 + frame_dict['aud_id'] = i
  331 +
  332 + pose[:3, :3] = rot_inv[i]
  333 + pose[:3, 3] = trans_inv[i, :, 0]
  334 +
  335 + frame_dict['transform_matrix'] = pose.numpy().tolist()
  336 +
  337 + transform_dict['frames'].append(frame_dict)
  338 +
  339 + with open(os.path.join(base_dir, 'transforms_' + save_id + '.json'), 'w') as fp:
  340 + json.dump(transform_dict, fp, indent=2, separators=(',', ': '))
  341 +
  342 + print(f'[INFO] ===== finished saving transforms =====')
  343 +
  344 +
  345 +if __name__ == '__main__':
  346 + parser = argparse.ArgumentParser()
  347 + parser.add_argument('path', type=str, help="path to video file")
  348 + parser.add_argument('--task', type=int, default=-1, help="-1 means all")
  349 + parser.add_argument('--asr', type=str, default='wav2vec', help="wav2vec or deepspeech")
  350 +
  351 + opt = parser.parse_args()
  352 +
  353 + base_dir = os.path.dirname(opt.path)
  354 +
  355 + wav_path = os.path.join(base_dir, 'aud.wav')
  356 + ori_imgs_dir = os.path.join(base_dir, 'ori_imgs')
  357 + parsing_dir = os.path.join(base_dir, 'parsing')
  358 + gt_imgs_dir = os.path.join(base_dir, 'gt_imgs')
  359 + torso_imgs_dir = os.path.join(base_dir, 'torso_imgs')
  360 +
  361 + os.makedirs(ori_imgs_dir, exist_ok=True)
  362 + os.makedirs(parsing_dir, exist_ok=True)
  363 + os.makedirs(gt_imgs_dir, exist_ok=True)
  364 + os.makedirs(torso_imgs_dir, exist_ok=True)
  365 +
  366 +
  367 + # extract audio
  368 + if opt.task == -1 or opt.task == 1:
  369 + extract_audio(opt.path, wav_path)
  370 +
  371 + # extract audio features
  372 + if opt.task == -1 or opt.task == 2:
  373 + extract_audio_features(wav_path, mode=opt.asr)
  374 +
  375 + # extract images
  376 + if opt.task == -1 or opt.task == 3:
  377 + extract_images(opt.path, ori_imgs_dir)
  378 +
  379 + # face parsing
  380 + if opt.task == -1 or opt.task == 4:
  381 + extract_semantics(ori_imgs_dir, parsing_dir)
  382 +
  383 + # extract bg
  384 + if opt.task == -1 or opt.task == 5:
  385 + extract_background(base_dir, ori_imgs_dir)
  386 +
  387 + # extract torso images and gt_images
  388 + if opt.task == -1 or opt.task == 6:
  389 + extract_torso_and_gt(base_dir, ori_imgs_dir)
  390 +
  391 + # extract face landmarks
  392 + if opt.task == -1 or opt.task == 7:
  393 + extract_landmarks(ori_imgs_dir)
  394 +
  395 + # face tracking
  396 + if opt.task == -1 or opt.task == 8:
  397 + face_tracking(ori_imgs_dir)
  398 +
  399 + # save transforms.json
  400 + if opt.task == -1 or opt.task == 9:
  401 + save_transforms(base_dir, ori_imgs_dir)
  402 +
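A minimal sketch of driving the full preprocessing pipeline from Python; the video path and the script location data_utils/process.py are assumptions about the repository layout:

import subprocess

# --task -1 runs every step; --asr selects the audio feature extractor (wav2vec or deepspeech).
subprocess.run(
    ["python", "data_utils/process.py", "data/obama/obama.mp4", "--task", "-1", "--asr", "wav2vec"],
    check=True,
)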
  1 +import torch
  2 +import torch.nn as nn
  3 +import torch.nn.functional as F
  4 +
  5 +
  6 +def get_encoder(encoding, input_dim=3,
  7 + multires=6,
  8 + degree=4,
  9 + num_levels=16, level_dim=2, base_resolution=16, log2_hashmap_size=19, desired_resolution=2048, align_corners=False,
  10 + **kwargs):
  11 +
  12 + if encoding == 'None':
  13 + return lambda x, **kwargs: x, input_dim
  14 +
  15 + elif encoding == 'frequency':
  16 + from .freqencoder import FreqEncoder
  17 + encoder = FreqEncoder(input_dim=input_dim, degree=multires)
  18 +
  19 + elif encoding == 'spherical_harmonics':
  20 + from .shencoder import SHEncoder
  21 + encoder = SHEncoder(input_dim=input_dim, degree=degree)
  22 +
  23 + elif encoding == 'hashgrid':
  24 + from .gridencoder import GridEncoder
  25 + encoder = GridEncoder(input_dim=input_dim, num_levels=num_levels, level_dim=level_dim, base_resolution=base_resolution, log2_hashmap_size=log2_hashmap_size, desired_resolution=desired_resolution, gridtype='hash', align_corners=align_corners)
  26 +
  27 + elif encoding == 'tiledgrid':
  28 + from .gridencoder import GridEncoder
  29 + encoder = GridEncoder(input_dim=input_dim, num_levels=num_levels, level_dim=level_dim, base_resolution=base_resolution, log2_hashmap_size=log2_hashmap_size, desired_resolution=desired_resolution, gridtype='tiled', align_corners=align_corners)
  30 +
  31 + elif encoding == 'ash':
  32 + from .ashencoder import AshEncoder
  33 + encoder = AshEncoder(input_dim=input_dim, output_dim=16, log2_hashmap_size=log2_hashmap_size, resolution=desired_resolution)
  34 +
  35 + else:
  36 + raise NotImplementedError('Unknown encoding mode, choose from [None, frequency, spherical_harmonics, hashgrid, tiledgrid]')
  37 +
  38 + return encoder, encoder.output_dim
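A small usage sketch of get_encoder; the import path is an assumption about the package layout, and the frequency branch needs the freqencoder CUDA extension to be built:

import torch
from nerf.encoding import get_encoder  # assumed import path; adjust to where encoding.py actually lives

encoder, out_dim = get_encoder('frequency', input_dim=3, multires=6)
x = torch.rand(4096, 3).cuda()
feats = encoder(x)          # [4096, 3 + 3*2*6] = [4096, 39]
print(out_dim, feats.shape)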
  1 +from .freq import FreqEncoder
  1 +import os
  2 +from torch.utils.cpp_extension import load
  3 +
  4 +_src_path = os.path.dirname(os.path.abspath(__file__))
  5 +
  6 +nvcc_flags = [
  7 + '-O3', '-std=c++17',
  8 + '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', '-allow-unsupported-compiler',
  9 + '-use_fast_math'
  10 +]
  11 +
  12 +if os.name == "posix":
  13 + c_flags = ['-O3', '-std=c++17']
  14 +elif os.name == "nt":
  15 + c_flags = ['/O2', '/std:c++17']
  16 +
  17 + # find cl.exe
  18 + def find_cl_path():
  19 + import glob
  20 + for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
  21 + paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
  22 + if paths:
  23 + return paths[0]
  24 +
  25 + # If cl.exe is not on path, try to find it.
  26 + if os.system("where cl.exe >nul 2>nul") != 0:
  27 + cl_path = find_cl_path()
  28 + if cl_path is None:
  29 + raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
  30 + os.environ["PATH"] += ";" + cl_path
  31 +
  32 +_backend = load(name='_freqencoder',
  33 + extra_cflags=c_flags,
  34 + extra_cuda_cflags=nvcc_flags,
  35 + sources=[os.path.join(_src_path, 'src', f) for f in [
  36 + 'freqencoder.cu',
  37 + 'bindings.cpp',
  38 + ]],
  39 + )
  40 +
  41 +__all__ = ['_backend']
  1 +import numpy as np
  2 +
  3 +import torch
  4 +import torch.nn as nn
  5 +from torch.autograd import Function
  6 +from torch.autograd.function import once_differentiable
  7 +from torch.cuda.amp import custom_bwd, custom_fwd
  8 +
  9 +try:
  10 + import _freqencoder as _backend
  11 +except ImportError:
  12 + from .backend import _backend
  13 +
  14 +
  15 +class _freq_encoder(Function):
  16 + @staticmethod
  17 + @custom_fwd(cast_inputs=torch.float32) # force float32 for better precision
  18 + def forward(ctx, inputs, degree, output_dim):
  19 + # inputs: [B, input_dim], float
  20 + # RETURN: [B, F], float
  21 +
  22 + if not inputs.is_cuda: inputs = inputs.cuda()
  23 + inputs = inputs.contiguous()
  24 +
  25 + B, input_dim = inputs.shape # batch size, coord dim
  26 +
  27 + outputs = torch.empty(B, output_dim, dtype=inputs.dtype, device=inputs.device)
  28 +
  29 + _backend.freq_encode_forward(inputs, B, input_dim, degree, output_dim, outputs)
  30 +
  31 + ctx.save_for_backward(inputs, outputs)
  32 + ctx.dims = [B, input_dim, degree, output_dim]
  33 +
  34 + return outputs
  35 +
  36 + @staticmethod
  37 + #@once_differentiable
  38 + @custom_bwd
  39 + def backward(ctx, grad):
  40 + # grad: [B, C], C = D + D * deg * 2 (matches the forward output_dim)
  41 +
  42 + grad = grad.contiguous()
  43 + inputs, outputs = ctx.saved_tensors
  44 + B, input_dim, degree, output_dim = ctx.dims
  45 +
  46 + grad_inputs = torch.zeros_like(inputs)
  47 + _backend.freq_encode_backward(grad, outputs, B, input_dim, degree, output_dim, grad_inputs)
  48 +
  49 + return grad_inputs, None, None
  50 +
  51 +
  52 +freq_encode = _freq_encoder.apply
  53 +
  54 +
  55 +class FreqEncoder(nn.Module):
  56 + def __init__(self, input_dim=3, degree=4):
  57 + super().__init__()
  58 +
  59 + self.input_dim = input_dim
  60 + self.degree = degree
  61 + self.output_dim = input_dim + input_dim * 2 * degree
  62 +
  63 + def __repr__(self):
  64 + return f"FreqEncoder: input_dim={self.input_dim} degree={self.degree} output_dim={self.output_dim}"
  65 +
  66 + def forward(self, inputs, **kwargs):
  67 + # inputs: [..., input_dim]
  68 + # return: [..., output_dim]
  69 +
  70 + prefix_shape = list(inputs.shape[:-1])
  71 + inputs = inputs.reshape(-1, self.input_dim)
  72 +
  73 + outputs = freq_encode(inputs, self.degree, self.output_dim)
  74 +
  75 + outputs = outputs.reshape(prefix_shape + [self.output_dim])
  76 +
  77 + return outputs
  1 +import os
  2 +from setuptools import setup
  3 +from torch.utils.cpp_extension import BuildExtension, CUDAExtension
  4 +
  5 +_src_path = os.path.dirname(os.path.abspath(__file__))
  6 +
  7 +nvcc_flags = [
  8 + '-O3', '-std=c++17',
  9 + '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', '-allow-unsupported-compiler',
  10 + '-use_fast_math'
  11 +]
  12 +
  13 +if os.name == "posix":
  14 + c_flags = ['-O3', '-std=c++17']
  15 +elif os.name == "nt":
  16 + c_flags = ['/O2', '/std:c++17']
  17 +
  18 + # find cl.exe
  19 + def find_cl_path():
  20 + import glob
  21 + for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
  22 + paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
  23 + if paths:
  24 + return paths[0]
  25 +
  26 + # If cl.exe is not on path, try to find it.
  27 + if os.system("where cl.exe >nul 2>nul") != 0:
  28 + cl_path = find_cl_path()
  29 + if cl_path is None:
  30 + raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
  31 + os.environ["PATH"] += ";" + cl_path
  32 +
  33 +setup(
  34 + name='freqencoder', # package name, import this to use python API
  35 + ext_modules=[
  36 + CUDAExtension(
  37 + name='_freqencoder', # extension name, import this to use CUDA API
  38 + sources=[os.path.join(_src_path, 'src', f) for f in [
  39 + 'freqencoder.cu',
  40 + 'bindings.cpp',
  41 + ]],
  42 + extra_compile_args={
  43 + 'cxx': c_flags,
  44 + 'nvcc': nvcc_flags,
  45 + }
  46 + ),
  47 + ],
  48 + cmdclass={
  49 + 'build_ext': BuildExtension,
  50 + }
  51 +)
  1 +#include <torch/extension.h>
  2 +
  3 +#include "freqencoder.h"
  4 +
  5 +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  6 + m.def("freq_encode_forward", &freq_encode_forward, "freq encode forward (CUDA)");
  7 + m.def("freq_encode_backward", &freq_encode_backward, "freq encode backward (CUDA)");
  8 +}
  1 +#include <stdint.h>
  2 +
  3 +#include <cuda.h>
  4 +#include <cuda_fp16.h>
  5 +#include <cuda_runtime.h>
  6 +
  7 +#include <ATen/cuda/CUDAContext.h>
  8 +#include <torch/torch.h>
  9 +
  10 +#include <algorithm>
  11 +#include <stdexcept>
  12 +
  13 +#include <cstdio>
  14 +
  15 +
  16 +#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
  17 +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor")
  18 +#define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor")
  19 +#define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor")
  20 +
  21 +inline constexpr __device__ float PI() { return 3.141592653589793f; }
  22 +
  23 +template <typename T>
  24 +__host__ __device__ T div_round_up(T val, T divisor) {
  25 + return (val + divisor - 1) / divisor;
  26 +}
  27 +
  28 +// inputs: [B, D]
  29 +// outputs: [B, C], C = D + D * deg * 2
  30 +__global__ void kernel_freq(
  31 + const float * __restrict__ inputs,
  32 + uint32_t B, uint32_t D, uint32_t deg, uint32_t C,
  33 + float * outputs
  34 +) {
  35 + // parallel on per-element
  36 + const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x;
  37 + if (t >= B * C) return;
  38 +
  39 + // get index
  40 + const uint32_t b = t / C;
  41 + const uint32_t c = t - b * C; // t % C;
  42 +
  43 + // locate
  44 + inputs += b * D;
  45 + outputs += t;
  46 +
  47 + // write self
  48 + if (c < D) {
  49 + outputs[0] = inputs[c];
  50 + // write freq
  51 + } else {
  52 + const uint32_t col = c / D - 1;
  53 + const uint32_t d = c % D;
  54 + const uint32_t freq = col / 2;
  55 + const float phase_shift = (col % 2) * (PI() / 2);
  56 + outputs[0] = __sinf(scalbnf(inputs[d], freq) + phase_shift);
  57 + }
  58 +}
  59 +
  60 +// grad: [B, C], C = D + D * deg * 2
  61 +// outputs: [B, C]
  62 +// grad_inputs: [B, D]
  63 +__global__ void kernel_freq_backward(
  64 + const float * __restrict__ grad,
  65 + const float * __restrict__ outputs,
  66 + uint32_t B, uint32_t D, uint32_t deg, uint32_t C,
  67 + float * grad_inputs
  68 +) {
  69 + // parallel on per-element
  70 + const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x;
  71 + if (t >= B * D) return;
  72 +
  73 + const uint32_t b = t / D;
  74 + const uint32_t d = t - b * D; // t % D;
  75 +
  76 + // locate
  77 + grad += b * C;
  78 + outputs += b * C;
  79 + grad_inputs += t;
  80 +
  81 + // register
  82 + float result = grad[d];
  83 + grad += D;
  84 + outputs += D;
  85 +
  86 + for (uint32_t f = 0; f < deg; f++) {
  87 + result += scalbnf(1.0f, f) * (grad[d] * outputs[D + d] - grad[D + d] * outputs[d]);
  88 + grad += 2 * D;
  89 + outputs += 2 * D;
  90 + }
  91 +
  92 + // write
  93 + grad_inputs[0] = result;
  94 +}
  95 +
  96 +
  97 +void freq_encode_forward(at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor outputs) {
  98 + CHECK_CUDA(inputs);
  99 + CHECK_CUDA(outputs);
  100 +
  101 + CHECK_CONTIGUOUS(inputs);
  102 + CHECK_CONTIGUOUS(outputs);
  103 +
  104 + CHECK_IS_FLOATING(inputs);
  105 + CHECK_IS_FLOATING(outputs);
  106 +
  107 + static constexpr uint32_t N_THREADS = 128;
  108 +
  109 + kernel_freq<<<div_round_up(B * C, N_THREADS), N_THREADS>>>(inputs.data_ptr<float>(), B, D, deg, C, outputs.data_ptr<float>());
  110 +}
  111 +
  112 +
  113 +void freq_encode_backward(at::Tensor grad, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor grad_inputs) {
  114 + CHECK_CUDA(grad);
  115 + CHECK_CUDA(outputs);
  116 + CHECK_CUDA(grad_inputs);
  117 +
  118 + CHECK_CONTIGUOUS(grad);
  119 + CHECK_CONTIGUOUS(outputs);
  120 + CHECK_CONTIGUOUS(grad_inputs);
  121 +
  122 + CHECK_IS_FLOATING(grad);
  123 + CHECK_IS_FLOATING(outputs);
  124 + CHECK_IS_FLOATING(grad_inputs);
  125 +
  126 + static constexpr uint32_t N_THREADS = 128;
  127 +
  128 + kernel_freq_backward<<<div_round_up(B * D, N_THREADS), N_THREADS>>>(grad.data_ptr<float>(), outputs.data_ptr<float>(), B, D, deg, C, grad_inputs.data_ptr<float>());
  129 +}
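For reference, a minimal NumPy sketch of the output layout that kernel_freq writes; this is illustrative only, not part of the committed extension, and the helper name is made up:

import numpy as np

def freq_encode_reference(inputs, deg):
    # inputs: [B, D]; returns [B, C] with C = D + D * deg * 2, matching kernel_freq
    B, D = inputs.shape
    cols = [inputs]                               # identity part (c < D)
    for f in range(deg):                          # frequency bands 2^0 .. 2^(deg-1)
        cols.append(np.sin(np.ldexp(inputs, f)))  # phase shift 0
        cols.append(np.cos(np.ldexp(inputs, f)))  # phase shift pi/2: sin(x + pi/2) = cos(x)
    return np.concatenate(cols, axis=1)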
  1 +# pragma once
  2 +
  3 +#include <stdint.h>
  4 +#include <torch/torch.h>
  5 +
  6 +// _backend.freq_encode_forward(inputs, B, input_dim, degree, output_dim, outputs)
  7 +void freq_encode_forward(at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor outputs);
  8 +
  9 +// _backend.freq_encode_backward(grad, outputs, B, input_dim, degree, output_dim, grad_inputs)
  10 +void freq_encode_backward(at::Tensor grad, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor grad_inputs);
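The header above only declares the bindings. A hedged sketch of driving them from Python with correctly shaped tensors, assuming the extension is exposed under a name like `_freqencoder` (the matching setup/bindings files for this encoder are not shown in this section):

import torch
import _freqencoder as _backend  # assumed module name, see note above

B, D, deg = 4096, 3, 4
C = D + D * deg * 2  # output width documented in freqencoder.cu
inputs = torch.rand(B, D, device='cuda', dtype=torch.float32).contiguous()
outputs = torch.empty(B, C, device='cuda', dtype=torch.float32)
_backend.freq_encode_forward(inputs, B, D, deg, C, outputs)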
  1 +from .grid import GridEncoder
  1 +import os
  2 +from torch.utils.cpp_extension import load
  3 +
  4 +_src_path = os.path.dirname(os.path.abspath(__file__))
  5 +
  6 +nvcc_flags = [
  7 + '-O3', '-std=c++17',
  8 + '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
  9 +]
  10 +
  11 +if os.name == "posix":
  12 + c_flags = ['-O3', '-std=c++17', '-finput-charset=UTF-8']
  13 +elif os.name == "nt":
  14 + c_flags = ['/O2', '/std:c++17', '/utf-8']
  15 +
  16 + # find cl.exe
  17 + def find_cl_path():
  18 + import glob
  19 + for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
  20 + paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
  21 + if paths:
  22 + return paths[0]
  23 +
  24 + # If cl.exe is not on path, try to find it.
  25 + if os.system("where cl.exe >nul 2>nul") != 0:
  26 + cl_path = find_cl_path()
  27 + if cl_path is None:
  28 + raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
  29 + os.environ["PATH"] += ";" + cl_path
  30 +
  31 +_backend = load(name='_grid_encoder',
  32 + extra_cflags=c_flags,
  33 + extra_cuda_cflags=nvcc_flags,
  34 + sources=[os.path.join(_src_path, 'src', f) for f in [
  35 + 'gridencoder.cu',
  36 + 'bindings.cpp',
  37 + ]],
  38 + )
  39 +
  40 +__all__ = ['_backend']
  1 +import numpy as np
  2 +
  3 +import torch
  4 +import torch.nn as nn
  5 +from torch.autograd import Function
  6 +from torch.autograd.function import once_differentiable
  7 +from torch.cuda.amp import custom_bwd, custom_fwd
  8 +
  9 +try:
  10 + import _gridencoder as _backend
  11 +except ImportError:
  12 + from .backend import _backend
  13 +
  14 +_gridtype_to_id = {
  15 + 'hash': 0,
  16 + 'tiled': 1,
  17 +}
  18 +
  19 +class _grid_encode(Function):
  20 + @staticmethod
  21 + @custom_fwd
  22 + def forward(ctx, inputs, embeddings, offsets, per_level_scale, base_resolution, calc_grad_inputs=False, gridtype=0, align_corners=False):
  23 + # inputs: [B, D], float in [0, 1]
  24 + # embeddings: [sO, C], float
  25 + # offsets: [L + 1], int
  26 + # RETURN: [B, L * C], float
  27 +
  28 + inputs = inputs.float().contiguous()
  29 +
  30 + B, D = inputs.shape # batch size, coord dim
  31 + L = offsets.shape[0] - 1 # level
  32 + C = embeddings.shape[1] # embedding dim for each level
  33 + S = np.log2(per_level_scale) # resolution multiplier at each level, apply log2 for later CUDA exp2f
  34 + H = base_resolution # base resolution
  35 +
  36 + # manually handle autocast (only use half precision embeddings, inputs must be float for enough precision)
  37 + # if C % 2 != 0, force float, since half for atomicAdd is very slow.
  38 + if torch.is_autocast_enabled() and C % 2 == 0:
  39 + embeddings = embeddings.to(torch.half)
  40 +
  41 + # level-first layout [L, B, C]: better cache locality for the CUDA kernel, at the cost of an extra permute back to [B, L * C] later
  42 + outputs = torch.empty(L, B, C, device=inputs.device, dtype=embeddings.dtype)
  43 +
  44 + if calc_grad_inputs:
  45 + dy_dx = torch.empty(B, L * D * C, device=inputs.device, dtype=embeddings.dtype)
  46 + else:
  47 + dy_dx = None
  48 +
  49 + _backend.grid_encode_forward(inputs, embeddings, offsets, outputs, B, D, C, L, S, H, dy_dx, gridtype, align_corners)
  50 +
  51 + # permute back to [B, L * C]
  52 + outputs = outputs.permute(1, 0, 2).reshape(B, L * C)
  53 +
  54 + ctx.save_for_backward(inputs, embeddings, offsets, dy_dx)
  55 + ctx.dims = [B, D, C, L, S, H, gridtype]
  56 + ctx.align_corners = align_corners
  57 +
  58 + return outputs
  59 +
  60 + @staticmethod
  61 + #@once_differentiable
  62 + @custom_bwd
  63 + def backward(ctx, grad):
  64 +
  65 + inputs, embeddings, offsets, dy_dx = ctx.saved_tensors
  66 + B, D, C, L, S, H, gridtype = ctx.dims
  67 + align_corners = ctx.align_corners
  68 +
  69 + # grad: [B, L * C] --> [L, B, C]
  70 + grad = grad.view(B, L, C).permute(1, 0, 2).contiguous()
  71 +
  72 + grad_embeddings = torch.zeros_like(embeddings)
  73 +
  74 + if dy_dx is not None:
  75 + grad_inputs = torch.zeros_like(inputs, dtype=embeddings.dtype)
  76 + else:
  77 + grad_inputs = None
  78 +
  79 + _backend.grid_encode_backward(grad, inputs, embeddings, offsets, grad_embeddings, B, D, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners)
  80 +
  81 + if dy_dx is not None:
  82 + grad_inputs = grad_inputs.to(inputs.dtype)
  83 +
  84 + return grad_inputs, grad_embeddings, None, None, None, None, None, None
  85 +
  86 +
  87 +
  88 +grid_encode = _grid_encode.apply
  89 +
  90 +
  91 +class GridEncoder(nn.Module):
  92 + def __init__(self, input_dim=3, num_levels=16, level_dim=2, per_level_scale=2, base_resolution=16, log2_hashmap_size=19, desired_resolution=None, gridtype='hash', align_corners=False):
  93 + super().__init__()
  94 +
  95 + # the finest resolution desired at the last level; if provided, this overrides per_level_scale
  96 + if desired_resolution is not None:
  97 + per_level_scale = np.exp2(np.log2(desired_resolution / base_resolution) / (num_levels - 1))
  98 +
  99 + self.input_dim = input_dim # coord dims, 2 or 3
  100 + self.num_levels = num_levels # num levels, each level multiply resolution by 2
  101 + self.level_dim = level_dim # encode channels per level
  102 + self.per_level_scale = per_level_scale # multiply resolution by this scale at each level.
  103 + self.log2_hashmap_size = log2_hashmap_size
  104 + self.base_resolution = base_resolution
  105 + self.output_dim = num_levels * level_dim
  106 + self.gridtype = gridtype
  107 + self.gridtype_id = _gridtype_to_id[gridtype] # 'hash' -> 0, 'tiled' -> 1
  108 + self.align_corners = align_corners
  109 +
  110 + # allocate parameters
  111 + offsets = []
  112 + offset = 0
  113 + self.max_params = 2 ** log2_hashmap_size
  114 + for i in range(num_levels):
  115 + resolution = int(np.ceil(base_resolution * per_level_scale ** i))
  116 + params_in_level = min(self.max_params, (resolution if align_corners else resolution + 1) ** input_dim) # cap params at this level by the hash table size
  117 + params_in_level = int(np.ceil(params_in_level / 8) * 8) # round up to a multiple of 8
  118 + offsets.append(offset)
  119 + offset += params_in_level
  120 + # print(resolution, params_in_level)
  121 + offsets.append(offset)
  122 + offsets = torch.from_numpy(np.array(offsets, dtype=np.int32))
  123 + self.register_buffer('offsets', offsets)
  124 +
  125 + self.n_params = offsets[-1] * level_dim
  126 +
  127 + # parameters
  128 + self.embeddings = nn.Parameter(torch.empty(offset, level_dim))
  129 +
  130 + self.reset_parameters()
  131 +
  132 + def reset_parameters(self):
  133 + std = 1e-4
  134 + self.embeddings.data.uniform_(-std, std)
  135 +
  136 + def __repr__(self):
  137 + return f"GridEncoder: input_dim={self.input_dim} num_levels={self.num_levels} level_dim={self.level_dim} resolution={self.base_resolution} -> {int(round(self.base_resolution * self.per_level_scale ** (self.num_levels - 1)))} per_level_scale={self.per_level_scale:.4f} params={tuple(self.embeddings.shape)} gridtype={self.gridtype} align_corners={self.align_corners}"
  138 +
  139 + def forward(self, inputs, bound=1):
  140 + # inputs: [..., input_dim], normalized real world positions in [-bound, bound]
  141 + # return: [..., num_levels * level_dim]
  142 +
  143 + inputs = (inputs + bound) / (2 * bound) # map to [0, 1]
  144 +
  145 + #print('inputs', inputs.shape, inputs.dtype, inputs.min().item(), inputs.max().item())
  146 +
  147 + prefix_shape = list(inputs.shape[:-1])
  148 + inputs = inputs.view(-1, self.input_dim)
  149 +
  150 + outputs = grid_encode(inputs, self.embeddings, self.offsets, self.per_level_scale, self.base_resolution, inputs.requires_grad, self.gridtype_id, self.align_corners)
  151 + outputs = outputs.view(prefix_shape + [self.output_dim])
  152 +
  153 + #print('outputs', outputs.shape, outputs.dtype, outputs.min().item(), outputs.max().item())
  154 +
  155 + return outputs
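A short usage sketch for the GridEncoder defined above (assumes a CUDA device and that the extension compiled; shapes follow the comments in forward):

import torch
from gridencoder import GridEncoder

encoder = GridEncoder(input_dim=3, num_levels=16, level_dim=2,
                      base_resolution=16, log2_hashmap_size=19,
                      desired_resolution=2048, gridtype='hash').cuda()
x = torch.rand(4096, 3, device='cuda') * 2 - 1   # positions in [-bound, bound], bound=1
feats = encoder(x, bound=1)                      # -> [4096, num_levels * level_dim] == [4096, 32]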
  1 +import os
  2 +from setuptools import setup
  3 +from torch.utils.cpp_extension import BuildExtension, CUDAExtension
  4 +
  5 +_src_path = os.path.dirname(os.path.abspath(__file__))
  6 +
  7 +nvcc_flags = [
  8 + '-O3', '-std=c++17',
  9 + '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__','-allow-unsupported-compiler',
  10 +]
  11 +
  12 +if os.name == "posix":
  13 + c_flags = ['-O3', '-std=c++17']
  14 +elif os.name == "nt":
  15 + c_flags = ['/O2', '/std:c++17']
  16 +
  17 + # find cl.exe
  18 + def find_cl_path():
  19 + import glob
  20 + for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
  21 + paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
  22 + if paths:
  23 + return paths[0]
  24 +
  25 + # If cl.exe is not on path, try to find it.
  26 + if os.system("where cl.exe >nul 2>nul") != 0:
  27 + cl_path = find_cl_path()
  28 + if cl_path is None:
  29 + raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
  30 + os.environ["PATH"] += ";" + cl_path
  31 +
  32 +setup(
  33 + name='gridencoder', # package name, import this to use python API
  34 + ext_modules=[
  35 + CUDAExtension(
  36 + name='_gridencoder', # extension name, import this to use CUDA API
  37 + sources=[os.path.join(_src_path, 'src', f) for f in [
  38 + 'gridencoder.cu',
  39 + 'bindings.cpp',
  40 + ]],
  41 + extra_compile_args={
  42 + 'cxx': c_flags,
  43 + 'nvcc': nvcc_flags,
  44 + }
  45 + ),
  46 + ],
  47 + cmdclass={
  48 + 'build_ext': BuildExtension,
  49 + }
  50 +)
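How the two build paths line up, as a sketch under the assumption that the build succeeds: installing this setup.py exposes the CUDA extension as `_gridencoder`, which grid.py imports first; otherwise grid.py falls back to the JIT build in backend.py.

try:
    import _gridencoder as _backend            # available after `pip install .` in gridencoder/
    print("using prebuilt extension")
except ImportError:
    from gridencoder.backend import _backend   # JIT-compiles via torch.utils.cpp_extension.load
    print("using JIT-compiled extension")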
  1 +#include <torch/extension.h>
  2 +
  3 +#include "gridencoder.h"
  4 +
  5 +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  6 + m.def("grid_encode_forward", &grid_encode_forward, "grid_encode_forward (CUDA)");
  7 + m.def("grid_encode_backward", &grid_encode_backward, "grid_encode_backward (CUDA)");
  8 +}