冯杨

Base commit

Showing 46 changed files with 4734 additions and 0 deletions

... ...
__pycache__/
build/
*.egg-info/
*.so
*.mp4
tmp*
trial*/
data
data_utils/face_tracking/3DMM/*
data_utils/face_parsing/79999_iter.pth
pretrained
.DS_Store
workspace/log_ngp.txt
.idea
\ No newline at end of file
... ...
{
"version": "0.2.0",
"configurations": [
{
"name": "Python: 带参数调试",
"type": "python",
"request": "launch",
"program": "${workspaceFolder}/app.py",
"args": [
"--transport", "webrtc",
"--model", "wav2lip",
"--avatar_id", "wav2lip256_avatar1"
],
"console": "integratedTerminal",
"justMyCode": true
}
]
}
\ No newline at end of file
... ...
# Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
ARG BASE_IMAGE=nvcr.io/nvidia/cuda:11.6.1-cudnn8-devel-ubuntu20.04
FROM $BASE_IMAGE
RUN apt-get update -yq --fix-missing \
&& DEBIAN_FRONTEND=noninteractive apt-get install -yq --no-install-recommends \
pkg-config \
wget \
cmake \
curl \
git \
vim
#ENV PYTHONDONTWRITEBYTECODE=1
#ENV PYTHONUNBUFFERED=1
# nvidia-container-runtime
#ENV NVIDIA_VISIBLE_DEVICES all
#ENV NVIDIA_DRIVER_CAPABILITIES compute,utility,graphics
RUN wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
RUN sh Miniconda3-latest-Linux-x86_64.sh -b -u -p /root/miniconda3
ENV PATH=/root/miniconda3/bin:$PATH
RUN conda init bash
RUN conda create -n nerfstream python=3.10 -y
# make the nerfstream env the default for the remaining layers (RUN cannot persist `conda activate`)
ENV PATH=/root/miniconda3/envs/nerfstream/bin:$PATH
RUN pip config set global.index-url https://mirrors.aliyun.com/pypi/simple/
# install dependencies
RUN conda install -n nerfstream pytorch==1.12.1 torchvision==0.13.1 cudatoolkit=11.3 -c pytorch -y
COPY requirements.txt ./
RUN pip install -r requirements.txt
# additional libraries
RUN pip install "git+https://github.com/facebookresearch/pytorch3d.git"
RUN pip install tensorflow-gpu==2.8.0
RUN pip uninstall -y protobuf
RUN pip install protobuf==3.20.1
RUN conda install -n nerfstream -y ffmpeg
COPY ../python_rtmpstream /python_rtmpstream
WORKDIR /python_rtmpstream/python
RUN pip install .
COPY ../nerfstream /nerfstream
WORKDIR /nerfstream
CMD ["python3", "app.py"]
... ...
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [livetalking@lipku]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
... ...
Real-time interactive streaming digital human with synchronized audio and video dialogue. It can essentially achieve commercial-grade results.
[ernerf demo](https://www.bilibili.com/video/BV1PM4m1y7Q2/) [musetalk demo](https://www.bilibili.com/video/BV1gm421N7vQ/) [wav2lip demo](https://www.bilibili.com/video/BV1Bw4m1e74P/)
## To avoid confusion with 3D digital humans, the original project metahuman-stream has been renamed livetalking; the original links remain usable
## News
- 2024.12.8 Improved multi-session concurrency; GPU memory no longer grows with the number of concurrent sessions
- 2024.12.21 Added model warm-up for wav2lip and musetalk to fix stuttering on the first inference. Thanks to @heimaojinzhangyz
- 2024.12.28 Added the Ultralight-Digital-Human digital human model. Thanks to @lijihua2017
- 2025.2.7 Added fish-speech TTS
- 2025.2.21 Added the open-source wav2lip256 model. Thanks to @不蠢不蠢
- 2025.3.2 Added the Tencent text-to-speech service
## Features
1. Supports multiple digital human models: ernerf, musetalk, wav2lip, Ultralight-Digital-Human
2. Supports voice cloning
3. Supports interrupting the digital human while it is speaking
4. Supports full-body video stitching
5. Supports RTMP and WebRTC
6. Supports video orchestration: plays a custom video while the avatar is not speaking
7. Supports multiple concurrent sessions
## 1. Installation
Tested on Ubuntu 20.04, Python 3.10, PyTorch 1.12 and CUDA 11.3
### 1.1 Install dependency
```bash
conda create -n nerfstream python=3.10
conda activate nerfstream
# If your CUDA version is not 11.3 (check with nvidia-smi), install the matching PyTorch build per <https://pytorch.org/get-started/previous-versions/>
conda install pytorch==1.12.1 torchvision==0.13.1 cudatoolkit=11.3 -c pytorch
pip install -r requirements.txt
# If you need to train the ernerf model, install the libraries below
# pip install "git+https://github.com/facebookresearch/pytorch3d.git"
# pip install tensorflow-gpu==2.8.0
# pip install --upgrade "protobuf<=3.20.1"
```
Common installation issues: [FAQ](https://livetalking-doc.readthedocs.io/en/latest/faq.html)
For setting up a CUDA environment on Linux, see this article: https://zhuanlan.zhihu.com/p/674972886
## 2. Quick Start
- Download the models
Baidu Netdisk <https://pan.baidu.com/s/1yOsQ06-RIDTJd3HFCw4wtA> password: ltua
Google Drive <https://drive.google.com/drive/folders/1FOC_MD6wdogyyX_7V1d4NDIO7P9NlSAJ?usp=sharing>
Copy wav2lip256.pth into this project's models directory and rename it to wav2lip.pth;
extract wav2lip256_avatar1.tar.gz and copy the whole folder into this project's data/avatars directory
- Run
python app.py --transport webrtc --model wav2lip --avatar_id wav2lip256_avatar1 --preload 2
To start avatar No. 3 with GPU preloading: python app.py --transport webrtc --model wav2lip --avatar_id wav2lip256_avatar3 --preload 2
Open http://serverip:8010/webrtcapi.html in a browser, click 'start' to play the digital human video, then type any text in the text box and submit it. The digital human will speak that text (an example API call is shown after this list).
<font color=red>The server must open ports tcp:8010 and udp:1-65536</font>
If you need the commercial high-definition wav2lip model, contact me to purchase it.
- Quick trial
<https://www.compshare.cn/images-detail?ImageID=compshareImage-18tpjhhxoq3j&referral_code=3XW3852OBmnD089hMMrtuU&ytag=GPU_GitHub_livetalking1.3> Create an instance from this image and it runs out of the box.
If huggingface is not reachable, run this before starting:
```
export HF_ENDPOINT=https://hf-mirror.com
```
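The text box on the web page simply posts to the /human endpoint defined in app.py, so the avatar can also be driven directly over HTTP. A minimal sketch, assuming the server runs on localhost with the default port 8010; the text strings are example values, and the sessionid should be the one returned by your /offer call (0 is just a placeholder):
```bash
# make the avatar speak a fixed text (type = echo)
curl -X POST http://localhost:8010/human \
  -H "Content-Type: application/json" \
  -d '{"sessionid": 0, "type": "echo", "text": "Hello, welcome to livetalking", "interrupt": true}'

# let the LLM generate the reply and have the avatar speak it (type = chat)
curl -X POST http://localhost:8010/human \
  -H "Content-Type: application/json" \
  -d '{"sessionid": 0, "type": "chat", "text": "Please introduce yourself"}'
```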
## 3. More Usage
Documentation: <https://livetalking-doc.readthedocs.io/>
## 4. Docker Run
No need for the installation steps above; run directly:
```
docker run --gpus all -it --network=host --rm registry.cn-beijing.aliyuncs.com/codewithgpu2/lipku-metahuman-stream:2K9qaMBu8v
```
The code is in /root/metahuman-stream. Run git pull first to fetch the latest code, then run the same commands as in sections 2 and 3, for example:
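A minimal sketch of that sequence inside the container (the app.py flags simply mirror the Quick Start above; adjust them to the model and avatar you use):
```bash
cd /root/metahuman-stream
git pull
python app.py --transport webrtc --model wav2lip --avatar_id wav2lip256_avatar1
```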
The following images are available:
- autodl image: <https://www.codewithgpu.com/i/lipku/metahuman-stream/base>
[autodl tutorial](https://livetalking-doc.readthedocs.io/en/latest/autodl/README.html)
- ucloud image: <https://www.compshare.cn/images-detail?ImageID=compshareImage-18tpjhhxoq3j&referral_code=3XW3852OBmnD089hMMrtuU&ytag=GPU_livetalking1.3>
Any port can be opened, and there is no need to deploy an SRS service separately.
[ucloud tutorial](https://livetalking-doc.readthedocs.io/en/latest/ucloud/ucloud.html)
## 5. TODO
- [x] Add ChatGPT for digital human conversation
- [x] Voice cloning
- [x] Play a video clip while the digital human is silent
- [x] MuseTalk
- [x] Wav2Lip
- [x] Ultralight-Digital-Human
---
If this project is helpful to you, please give it a star. Anyone interested is also welcome to help improve it.
- Zhishi Xingqiu (Knowledge Planet): https://t.zsxq.com/7NMyO for curated high-quality FAQs, best practices, and answers
- WeChat official account: 数字人技术 (Digital Human Technology)
![](https://mmbiz.qpic.cn/sz_mmbiz_jpg/l3ZibgueFiaeyfaiaLZGuMGQXnhLWxibpJUS2gfs8Dje6JuMY8zu2tVyU9n8Zx1yaNncvKHBMibX0ocehoITy5qQEZg/640?wxfrom=12&tp=wxpic&usePicPrefetch=1&wx_fmt=jpeg&from=appmsg)
... ...
###############################################################################
# Copyright (C) 2024 LiveTalking@lipku https://github.com/lipku/LiveTalking
# email: lipku@foxmail.com
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
###############################################################################
# server.py
from flask import Flask, render_template,send_from_directory,request, jsonify
from flask_sockets import Sockets
import base64
import json
#import gevent
#from gevent import pywsgi
#from geventwebsocket.handler import WebSocketHandler
import re
import numpy as np
from threading import Thread,Event
#import multiprocessing
import torch.multiprocessing as mp
from aiohttp import web
import aiohttp
import aiohttp_cors
from aiortc import RTCPeerConnection, RTCSessionDescription
from aiortc.rtcrtpsender import RTCRtpSender
from webrtc import HumanPlayer
from basereal import BaseReal
from llm import llm_response
import argparse
import random
import shutil
import asyncio
import torch
from typing import Dict
from logger import logger
app = Flask(__name__)
#sockets = Sockets(app)
nerfreals:Dict[int, BaseReal] = {} #sessionid:BaseReal
opt = None
model = None
avatar = None
#####webrtc###############################
pcs = set()
def randN(N)->int:
'''Generate a random integer with N digits'''
min = pow(10, N - 1)
max = pow(10, N)
return random.randint(min, max - 1)
def build_nerfreal(sessionid:int)->BaseReal:
opt.sessionid=sessionid
if opt.model == 'wav2lip':
from lipreal import LipReal
nerfreal = LipReal(opt,model,avatar)
elif opt.model == 'musetalk':
from musereal import MuseReal
nerfreal = MuseReal(opt,model,avatar)
elif opt.model == 'ernerf':
from nerfreal import NeRFReal
nerfreal = NeRFReal(opt,model,avatar)
elif opt.model == 'ultralight':
from lightreal import LightReal
nerfreal = LightReal(opt,model,avatar)
return nerfreal
#@app.route('/offer', methods=['POST'])
async def offer(request):
params = await request.json()
offer = RTCSessionDescription(sdp=params["sdp"], type=params["type"])
if len(nerfreals) >= opt.max_session:
logger.info('reach max session')
return web.Response(
content_type="application/json",
text=json.dumps({"code": -1, "msg": "reach max session"})
)
sessionid = randN(6) #len(nerfreals)
logger.info('sessionid=%d',sessionid)
nerfreals[sessionid] = None
nerfreal = await asyncio.get_event_loop().run_in_executor(None, build_nerfreal,sessionid)
nerfreals[sessionid] = nerfreal
pc = RTCPeerConnection()
pcs.add(pc)
@pc.on("connectionstatechange")
async def on_connectionstatechange():
logger.info("Connection state is %s" % pc.connectionState)
if pc.connectionState == "failed":
await pc.close()
pcs.discard(pc)
del nerfreals[sessionid]
if pc.connectionState == "closed":
pcs.discard(pc)
del nerfreals[sessionid]
player = HumanPlayer(nerfreals[sessionid])
audio_sender = pc.addTrack(player.audio)
video_sender = pc.addTrack(player.video)
capabilities = RTCRtpSender.getCapabilities("video")
preferences = list(filter(lambda x: x.name == "H264", capabilities.codecs))
preferences += list(filter(lambda x: x.name == "VP8", capabilities.codecs))
preferences += list(filter(lambda x: x.name == "rtx", capabilities.codecs))
transceiver = pc.getTransceivers()[1]
transceiver.setCodecPreferences(preferences)
await pc.setRemoteDescription(offer)
answer = await pc.createAnswer()
await pc.setLocalDescription(answer)
#return jsonify({"sdp": pc.localDescription.sdp, "type": pc.localDescription.type})
return web.Response(
content_type="application/json",
text=json.dumps(
{"sdp": pc.localDescription.sdp, "type": pc.localDescription.type, "sessionid":sessionid}
),
)
async def human(request):
params = await request.json()
sessionid = params.get('sessionid',0)
if params.get('interrupt'):
nerfreals[sessionid].flush_talk()
if params['type']=='echo':
nerfreals[sessionid].put_msg_txt(params['text'])
elif params['type']=='chat':
res=await asyncio.get_event_loop().run_in_executor(None, llm_response, params['text'],nerfreals[sessionid])
#nerfreals[sessionid].put_msg_txt(res)
return web.Response(
content_type="application/json",
text=json.dumps(
{"code": 0, "data":"ok"}
),
)
async def humanaudio(request):
try:
params = await request.json()
sessionid = int(params.get('sessionid', 0))
fileobj = params['file_url']
if fileobj.startswith("http"):
async with aiohttp.ClientSession() as session:
async with session.get(fileobj) as response:
if response.status == 200:
filebytes = await response.read()
else:
return web.Response(
content_type="application/json",
text=json.dumps({"code": -1, "msg": "Error downloading file"})
)
else:
filename = fileobj.filename
filebytes = fileobj.file.read()
nerfreals[sessionid].put_audio_file(filebytes)
return web.Response(
content_type="application/json",
text=json.dumps({"code": 0, "msg": "ok"})
)
except Exception as e:
return web.Response(
content_type="application/json",
text=json.dumps({"code": -1, "msg": "err", "data": str(e)})
)
async def set_audiotype(request):
params = await request.json()
sessionid = params.get('sessionid',0)
nerfreals[sessionid].set_curr_state(params['audiotype'],params['reinit'])
return web.Response(
content_type="application/json",
text=json.dumps(
{"code": 0, "data":"ok"}
),
)
async def record(request):
params = await request.json()
sessionid = params.get('sessionid',0)
if params['type']=='start_record':
# nerfreals[sessionid].put_msg_txt(params['text'])
nerfreals[sessionid].start_recording()
elif params['type']=='end_record':
nerfreals[sessionid].stop_recording()
return web.Response(
content_type="application/json",
text=json.dumps(
{"code": 0, "data":"ok"}
),
)
async def is_speaking(request):
params = await request.json()
sessionid = params.get('sessionid',0)
return web.Response(
content_type="application/json",
text=json.dumps(
{"code": 0, "data": nerfreals[sessionid].is_speaking()}
),
)
async def on_shutdown(app):
# close peer connections
coros = [pc.close() for pc in pcs]
await asyncio.gather(*coros)
pcs.clear()
async def post(url,data):
try:
async with aiohttp.ClientSession() as session:
async with session.post(url,data=data) as response:
return await response.text()
except aiohttp.ClientError as e:
logger.info(f'Error: {e}')
async def run(push_url,sessionid):
nerfreal = await asyncio.get_event_loop().run_in_executor(None, build_nerfreal,sessionid)
nerfreals[sessionid] = nerfreal
pc = RTCPeerConnection()
pcs.add(pc)
@pc.on("connectionstatechange")
async def on_connectionstatechange():
logger.info("Connection state is %s" % pc.connectionState)
if pc.connectionState == "failed":
await pc.close()
pcs.discard(pc)
player = HumanPlayer(nerfreals[sessionid])
audio_sender = pc.addTrack(player.audio)
video_sender = pc.addTrack(player.video)
await pc.setLocalDescription(await pc.createOffer())
answer = await post(push_url,pc.localDescription.sdp)
await pc.setRemoteDescription(RTCSessionDescription(sdp=answer,type='answer'))
##########################################
# os.environ['MKL_SERVICE_FORCE_INTEL'] = '1'
# os.environ['MULTIPROCESSING_METHOD'] = 'forkserver'
if __name__ == '__main__':
mp.set_start_method('spawn')
parser = argparse.ArgumentParser()
parser.add_argument('--pose', type=str, default="data/data_kf.json", help="transforms.json, pose source")
parser.add_argument('--au', type=str, default="data/au.csv", help="eye blink area")
parser.add_argument('--torso_imgs', type=str, default="", help="torso images path")
parser.add_argument('-O', action='store_true', help="equals --fp16 --cuda_ray --exp_eye")
parser.add_argument('--data_range', type=int, nargs='*', default=[0, -1], help="data range to use")
parser.add_argument('--workspace', type=str, default='data/video')
parser.add_argument('--seed', type=int, default=0)
### training options
parser.add_argument('--ckpt', type=str, default='data/pretrained/ngp_kf.pth')
parser.add_argument('--num_rays', type=int, default=4096 * 16, help="num rays sampled per image for each training step")
parser.add_argument('--cuda_ray', action='store_true', help="use CUDA raymarching instead of pytorch")
parser.add_argument('--max_steps', type=int, default=16, help="max num steps sampled per ray (only valid when using --cuda_ray)")
parser.add_argument('--num_steps', type=int, default=16, help="num steps sampled per ray (only valid when NOT using --cuda_ray)")
parser.add_argument('--upsample_steps', type=int, default=0, help="num steps up-sampled per ray (only valid when NOT using --cuda_ray)")
parser.add_argument('--update_extra_interval', type=int, default=16, help="iter interval to update extra status (only valid when using --cuda_ray)")
parser.add_argument('--max_ray_batch', type=int, default=4096, help="batch size of rays at inference to avoid OOM (only valid when NOT using --cuda_ray)")
### loss set
parser.add_argument('--warmup_step', type=int, default=10000, help="warm up steps")
parser.add_argument('--amb_aud_loss', type=int, default=1, help="use ambient aud loss")
parser.add_argument('--amb_eye_loss', type=int, default=1, help="use ambient eye loss")
parser.add_argument('--unc_loss', type=int, default=1, help="use uncertainty loss")
parser.add_argument('--lambda_amb', type=float, default=1e-4, help="lambda for ambient loss")
### network backbone options
parser.add_argument('--fp16', action='store_true', help="use amp mixed precision training")
parser.add_argument('--bg_img', type=str, default='white', help="background image")
parser.add_argument('--fbg', action='store_true', help="frame-wise bg")
parser.add_argument('--exp_eye', action='store_true', help="explicitly control the eyes")
parser.add_argument('--fix_eye', type=float, default=-1, help="fixed eye area, negative to disable, set to 0-0.3 for a reasonable eye")
parser.add_argument('--smooth_eye', action='store_true', help="smooth the eye area sequence")
parser.add_argument('--torso_shrink', type=float, default=0.8, help="shrink bg coords to allow more flexibility in deform")
### dataset options
parser.add_argument('--color_space', type=str, default='srgb', help="Color space, supports (linear, srgb)")
parser.add_argument('--preload', type=int, default=0, help="0 means load data from disk on-the-fly, 1 means preload to CPU, 2 means GPU.")
# (the default value is for the fox dataset)
parser.add_argument('--bound', type=float, default=1, help="assume the scene is bounded in box[-bound, bound]^3, if > 1, will invoke adaptive ray marching.")
parser.add_argument('--scale', type=float, default=4, help="scale camera location into box[-bound, bound]^3")
parser.add_argument('--offset', type=float, nargs='*', default=[0, 0, 0], help="offset of camera location")
parser.add_argument('--dt_gamma', type=float, default=1/256, help="dt_gamma (>=0) for adaptive ray marching. set to 0 to disable, >0 to accelerate rendering (but usually with worse quality)")
parser.add_argument('--min_near', type=float, default=0.05, help="minimum near distance for camera")
parser.add_argument('--density_thresh', type=float, default=10, help="threshold for density grid to be occupied (sigma)")
parser.add_argument('--density_thresh_torso', type=float, default=0.01, help="threshold for density grid to be occupied (alpha)")
parser.add_argument('--patch_size', type=int, default=1, help="[experimental] render patches in training, so as to apply LPIPS loss. 1 means disabled, use [64, 32, 16] to enable")
parser.add_argument('--init_lips', action='store_true', help="init lips region")
parser.add_argument('--finetune_lips', action='store_true', help="use LPIPS and landmarks to fine tune lips region")
parser.add_argument('--smooth_lips', action='store_true', help="smooth the enc_a in a exponential decay way...")
parser.add_argument('--torso', action='store_true', help="fix head and train torso")
parser.add_argument('--head_ckpt', type=str, default='', help="head model")
### GUI options
parser.add_argument('--gui', action='store_true', help="start a GUI")
parser.add_argument('--W', type=int, default=450, help="GUI width")
parser.add_argument('--H', type=int, default=450, help="GUI height")
parser.add_argument('--radius', type=float, default=3.35, help="default GUI camera radius from center")
parser.add_argument('--fovy', type=float, default=21.24, help="default GUI camera fovy")
parser.add_argument('--max_spp', type=int, default=1, help="GUI rendering max sample per pixel")
### else
parser.add_argument('--att', type=int, default=2, help="audio attention mode (0 = turn off, 1 = left-direction, 2 = bi-direction)")
parser.add_argument('--aud', type=str, default='', help="audio source (empty will load the default, else should be a path to a npy file)")
parser.add_argument('--emb', action='store_true', help="use audio class + embedding instead of logits")
parser.add_argument('--ind_dim', type=int, default=4, help="individual code dim, 0 to turn off")
parser.add_argument('--ind_num', type=int, default=10000, help="number of individual codes, should be larger than training dataset size")
parser.add_argument('--ind_dim_torso', type=int, default=8, help="individual code dim, 0 to turn off")
parser.add_argument('--amb_dim', type=int, default=2, help="ambient dimension")
parser.add_argument('--part', action='store_true', help="use partial training data (1/10)")
parser.add_argument('--part2', action='store_true', help="use partial training data (first 15s)")
parser.add_argument('--train_camera', action='store_true', help="optimize camera pose")
parser.add_argument('--smooth_path', action='store_true', help="brute-force smooth camera pose trajectory with a window size")
parser.add_argument('--smooth_path_window', type=int, default=7, help="smoothing window size")
# asr
parser.add_argument('--asr', action='store_true', help="load asr for real-time app")
parser.add_argument('--asr_wav', type=str, default='', help="load the wav and use as input")
parser.add_argument('--asr_play', action='store_true', help="play out the audio")
#parser.add_argument('--asr_model', type=str, default='deepspeech')
parser.add_argument('--asr_model', type=str, default='cpierse/wav2vec2-large-xlsr-53-esperanto') #
# parser.add_argument('--asr_model', type=str, default='facebook/wav2vec2-large-960h-lv60-self')
# parser.add_argument('--asr_model', type=str, default='facebook/hubert-large-ls960-ft')
parser.add_argument('--asr_save_feats', action='store_true')
# audio FPS
parser.add_argument('--fps', type=int, default=50)
# sliding window left-middle-right length (unit: 20ms)
parser.add_argument('-l', type=int, default=10)
parser.add_argument('-m', type=int, default=8)
parser.add_argument('-r', type=int, default=10)
parser.add_argument('--fullbody', action='store_true', help="fullbody human")
parser.add_argument('--fullbody_img', type=str, default='data/fullbody/img')
parser.add_argument('--fullbody_width', type=int, default=580)
parser.add_argument('--fullbody_height', type=int, default=1080)
parser.add_argument('--fullbody_offset_x', type=int, default=0)
parser.add_argument('--fullbody_offset_y', type=int, default=0)
#musetalk opt
parser.add_argument('--avatar_id', type=str, default='avator_1')
parser.add_argument('--bbox_shift', type=int, default=5)
parser.add_argument('--batch_size', type=int, default=16)
# parser.add_argument('--customvideo', action='store_true', help="custom video")
# parser.add_argument('--customvideo_img', type=str, default='data/customvideo/img')
# parser.add_argument('--customvideo_imgnum', type=int, default=1)
parser.add_argument('--customvideo_config', type=str, default='')
parser.add_argument('--tts', type=str, default='edgetts') #xtts gpt-sovits cosyvoice
parser.add_argument('--REF_FILE', type=str, default=None)
parser.add_argument('--REF_TEXT', type=str, default=None)
parser.add_argument('--TTS_SERVER', type=str, default='http://127.0.0.1:9880') # http://localhost:9000
# parser.add_argument('--CHARACTER', type=str, default='test')
# parser.add_argument('--EMOTION', type=str, default='default')
parser.add_argument('--model', type=str, default='ernerf') #musetalk wav2lip
parser.add_argument('--transport', type=str, default='rtcpush') #rtmp webrtc rtcpush
parser.add_argument('--push_url', type=str, default='http://localhost:1985/rtc/v1/whip/?app=live&stream=livestream') #rtmp://localhost/live/livestream
parser.add_argument('--max_session', type=int, default=1) #multi session count
parser.add_argument('--listenport', type=int, default=8010)
opt = parser.parse_args()
#app.config.from_object(opt)
#print(app.config)
opt.customopt = []
if opt.customvideo_config!='':
with open(opt.customvideo_config,'r') as file:
opt.customopt = json.load(file)
if opt.model == 'ernerf':
from nerfreal import NeRFReal,load_model,load_avatar
model = load_model(opt)
avatar = load_avatar(opt)
# we still need test_loader to provide audio features for testing.
# for k in range(opt.max_session):
# opt.sessionid=k
# nerfreal = NeRFReal(opt, trainer, test_loader,audio_processor,audio_model)
# nerfreals.append(nerfreal)
elif opt.model == 'musetalk':
from musereal import MuseReal,load_model,load_avatar,warm_up
logger.info(opt)
model = load_model()
avatar = load_avatar(opt.avatar_id)
warm_up(opt.batch_size,model)
# for k in range(opt.max_session):
# opt.sessionid=k
# nerfreal = MuseReal(opt,audio_processor,vae, unet, pe,timesteps)
# nerfreals.append(nerfreal)
elif opt.model == 'wav2lip':
from lipreal import LipReal,load_model,load_avatar,warm_up
logger.info(opt)
model = load_model("./models/wav2lip.pth")
avatar = load_avatar(opt.avatar_id)
warm_up(opt.batch_size,model,256)
# for k in range(opt.max_session):
# opt.sessionid=k
# nerfreal = LipReal(opt,model)
# nerfreals.append(nerfreal)
elif opt.model == 'ultralight':
from lightreal import LightReal,load_model,load_avatar,warm_up
logger.info(opt)
model = load_model(opt)
avatar = load_avatar(opt.avatar_id)
warm_up(opt.batch_size,avatar,160)
if opt.transport=='rtmp':
thread_quit = Event()
nerfreals[0] = build_nerfreal(0)
rendthrd = Thread(target=nerfreals[0].render,args=(thread_quit,))
rendthrd.start()
#############################################################################
appasync = web.Application()
appasync.on_shutdown.append(on_shutdown)
appasync.router.add_post("/offer", offer)
appasync.router.add_post("/human", human)
appasync.router.add_post("/humanaudio", humanaudio)
appasync.router.add_post("/set_audiotype", set_audiotype)
appasync.router.add_post("/record", record)
appasync.router.add_post("/is_speaking", is_speaking)
appasync.router.add_static('/',path='web')
# Configure default CORS settings.
cors = aiohttp_cors.setup(appasync, defaults={
"*": aiohttp_cors.ResourceOptions(
allow_credentials=True,
expose_headers="*",
allow_headers="*",
)
})
# Configure CORS on all routes.
for route in list(appasync.router.routes()):
cors.add(route)
pagename='webrtcapi.html'
if opt.transport=='rtmp':
pagename='echoapi.html'
elif opt.transport=='rtcpush':
pagename='rtcpushapi.html'
logger.info('start http server; http://<serverip>:'+str(opt.listenport)+'/'+pagename)
def run_server(runner):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(runner.setup())
site = web.TCPSite(runner, '0.0.0.0', opt.listenport)
loop.run_until_complete(site.start())
if opt.transport=='rtcpush':
for k in range(opt.max_session):
push_url = opt.push_url
if k!=0:
push_url = opt.push_url+str(k)
loop.run_until_complete(run(push_url,k))
loop.run_forever()
#Thread(target=run_server, args=(web.AppRunner(appasync),)).start()
run_server(web.AppRunner(appasync))
#app.on_shutdown.append(on_shutdown)
#app.router.add_post("/offer", offer)
# print('start websocket server')
# server = pywsgi.WSGIServer(('0.0.0.0', 8000), app, handler_class=WebSocketHandler)
# server.serve_forever()
... ...
1. pytorch3d fails to install\
Build it from source:
```bash
git clone https://github.com/facebookresearch/pytorch3d.git
cd pytorch3d
python setup.py install
```
2. WebSocket connection error\
Edit python/site-packages/flask\_sockets.py:
```python
# change
self.url_map.add(Rule(rule, endpoint=f))
# to
self.url_map.add(Rule(rule, endpoint=f, websocket=True))
```
3. protobuf version is too high
```bash
pip uninstall protobuf
pip install protobuf==3.20.1
```
4. The digital human does not blink\
Add the following step when training the model:
> Obtain AU45 for eyes blinking.\
> Run FeatureExtraction in OpenFace, rename and move the output CSV file to data/\<ID>/au.csv.
Then copy au.csv into this project's data directory (a command sketch follows).
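A hedged sketch of that step; the FeatureExtraction flags assume OpenFace 2.x, and the paths/`<ID>` are placeholders for your own training data:
```bash
# extract Action Units (AU45 = blink) from the training video with OpenFace
./FeatureExtraction -f <ID>.mp4 -out_dir processed -aus
# rename/move the resulting CSV to where training expects it, then copy it into this project
mv processed/<ID>.csv data/<ID>/au.csv
cp data/<ID>/au.csv data/au.csv
```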
5. Add a background image for the digital human
```bash
python app.py --bg_img bc.jpg
```
6. Dimension-mismatch error when using your own trained model\
Use wav2vec to extract audio features when training the model:
```bash
python main.py data/ --workspace workspace/ -O --iters 100000 --asr_model cpierse/wav2vec2-large-xlsr-53-esperanto
```
7. Wrong ffmpeg version when pushing RTMP streams\
Users have reported that version 4.2.2 works; I am not sure exactly which versions fail. The rule of thumb: run ffmpeg and check that the printed build information contains libx264; without it, it definitely will not work (see the check below).
```
--enable-libx264
```
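A quick way to check this (plain ffmpeg commands, nothing project-specific):
```bash
# the build configuration printed by ffmpeg should contain --enable-libx264
ffmpeg -version | grep -- --enable-libx264
# alternatively, confirm the libx264 encoder is registered
ffmpeg -hide_banner -encoders | grep libx264
```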
8. Replace the model with your own trained one\
Place the files in the following layout:
```
.
└── data
    ├── data_kf.json (corresponds to transforms_train.json from the training data)
    ├── au.csv
    └── pretrained
        └── ngp_kf.pth (corresponds to the trained model ngp_ep00xx.pth)
```
Other references:
https://github.com/lipku/metahuman-stream/issues/43#issuecomment-2008930101
... ...
###############################################################################
# Copyright (C) 2024 LiveTalking@lipku https://github.com/lipku/LiveTalking
# email: lipku@foxmail.com
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
###############################################################################
import time
import numpy as np
import queue
from queue import Queue
import torch.multiprocessing as mp
from basereal import BaseReal
class BaseASR:
def __init__(self, opt, parent:BaseReal = None):
self.opt = opt
self.parent = parent
self.fps = opt.fps # 20 ms per frame
self.sample_rate = 16000
self.chunk = self.sample_rate // self.fps # 320 samples per chunk (20ms * 16000 / 1000)
self.queue = Queue()
self.output_queue = mp.Queue()
self.batch_size = opt.batch_size
self.frames = []
self.stride_left_size = opt.l
self.stride_right_size = opt.r
#self.context_size = 10
self.feat_queue = mp.Queue(2)
#self.warm_up()
def flush_talk(self):
self.queue.queue.clear()
def put_audio_frame(self,audio_chunk,eventpoint=None): #16khz 20ms pcm
self.queue.put((audio_chunk,eventpoint))
#return frame:audio pcm; type: 0-normal speak, 1-silence; eventpoint:custom event sync with audio
def get_audio_frame(self):
try:
frame,eventpoint = self.queue.get(block=True,timeout=0.01)
type = 0
#print(f'[INFO] get frame {frame.shape}')
except queue.Empty:
if self.parent and self.parent.curr_state>1: # play custom audio
frame = self.parent.get_audio_stream(self.parent.curr_state)
type = self.parent.curr_state
else:
frame = np.zeros(self.chunk, dtype=np.float32)
type = 1
eventpoint = None
return frame,type,eventpoint
#return frame:audio pcm; type: 0-normal speak, 1-silence; eventpoint:custom event sync with audio
def get_audio_out(self):
return self.output_queue.get()
def warm_up(self):
for _ in range(self.stride_left_size + self.stride_right_size):
audio_frame,type,eventpoint=self.get_audio_frame()
self.frames.append(audio_frame)
self.output_queue.put((audio_frame,type,eventpoint))
for _ in range(self.stride_left_size):
self.output_queue.get()
def run_step(self):
pass
def get_next_feat(self,block,timeout):
return self.feat_queue.get(block,timeout)
\ No newline at end of file
... ...
###############################################################################
# Copyright (C) 2024 LiveTalking@lipku https://github.com/lipku/LiveTalking
# email: lipku@foxmail.com
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
###############################################################################
import math
import torch
import numpy as np
import subprocess
import os
import time
import cv2
import glob
import resampy
import queue
from queue import Queue
from threading import Thread, Event
from io import BytesIO
import soundfile as sf
import av
from fractions import Fraction
from ttsreal import EdgeTTS,VoitsTTS,XTTS,CosyVoiceTTS,FishTTS,TencentTTS
from logger import logger
from tqdm import tqdm
def read_imgs(img_list):
frames = []
logger.info('reading images...')
for img_path in tqdm(img_list):
frame = cv2.imread(img_path)
frames.append(frame)
return frames
class BaseReal:
def __init__(self, opt):
self.opt = opt
self.sample_rate = 16000
self.chunk = self.sample_rate // opt.fps # 320 samples per chunk (20ms * 16000 / 1000)
self.sessionid = self.opt.sessionid
if opt.tts == "edgetts":
self.tts = EdgeTTS(opt,self)
elif opt.tts == "gpt-sovits":
self.tts = VoitsTTS(opt,self)
elif opt.tts == "xtts":
self.tts = XTTS(opt,self)
elif opt.tts == "cosyvoice":
self.tts = CosyVoiceTTS(opt,self)
elif opt.tts == "fishtts":
self.tts = FishTTS(opt,self)
elif opt.tts == "tencent":
self.tts = TencentTTS(opt,self)
self.speaking = False
self.recording = False
self._record_video_pipe = None
self._record_audio_pipe = None
self.width = self.height = 0
self.curr_state=0
self.custom_img_cycle = {}
self.custom_audio_cycle = {}
self.custom_audio_index = {}
self.custom_index = {}
self.custom_opt = {}
self.__loadcustom()
def put_msg_txt(self,msg,eventpoint=None):
self.tts.put_msg_txt(msg,eventpoint)
def put_audio_frame(self,audio_chunk,eventpoint=None): #16khz 20ms pcm
self.asr.put_audio_frame(audio_chunk,eventpoint)
def put_audio_file(self,filebyte):
input_stream = BytesIO(filebyte)
stream = self.__create_bytes_stream(input_stream)
streamlen = stream.shape[0]
idx=0
while streamlen >= self.chunk: #and self.state==State.RUNNING
self.put_audio_frame(stream[idx:idx+self.chunk])
streamlen -= self.chunk
idx += self.chunk
def __create_bytes_stream(self,byte_stream):
#byte_stream=BytesIO(buffer)
stream, sample_rate = sf.read(byte_stream) # [T*sample_rate,] float64
logger.info(f'[INFO]put audio stream {sample_rate}: {stream.shape}')
stream = stream.astype(np.float32)
if stream.ndim > 1:
logger.info(f'[WARN] audio has {stream.shape[1]} channels, only use the first.')
stream = stream[:, 0]
if sample_rate != self.sample_rate and stream.shape[0]>0:
logger.info(f'[WARN] audio sample rate is {sample_rate}, resampling into {self.sample_rate}.')
stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self.sample_rate)
return stream
def flush_talk(self):
self.tts.flush_talk()
self.asr.flush_talk()
def is_speaking(self)->bool:
return self.speaking
def __loadcustom(self):
for item in self.opt.customopt:
logger.info(item)
input_img_list = glob.glob(os.path.join(item['imgpath'], '*.[jpJP][pnPN]*[gG]'))
input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
self.custom_img_cycle[item['audiotype']] = read_imgs(input_img_list)
self.custom_audio_cycle[item['audiotype']], sample_rate = sf.read(item['audiopath'], dtype='float32')
self.custom_audio_index[item['audiotype']] = 0
self.custom_index[item['audiotype']] = 0
self.custom_opt[item['audiotype']] = item
def init_customindex(self):
self.curr_state=0
for key in self.custom_audio_index:
self.custom_audio_index[key]=0
for key in self.custom_index:
self.custom_index[key]=0
def notify(self,eventpoint):
logger.info("notify:%s",eventpoint)
def start_recording(self):
"""开始录制视频"""
if self.recording:
return
command = ['ffmpeg',
'-y', '-an',
'-f', 'rawvideo',
'-vcodec','rawvideo',
'-pix_fmt', 'bgr24', # pixel format
'-s', "{}x{}".format(self.width, self.height),
'-r', str(25),
'-i', '-',
'-pix_fmt', 'yuv420p',
'-vcodec', "h264",
#'-f' , 'flv',
f'temp{self.opt.sessionid}.mp4']
self._record_video_pipe = subprocess.Popen(command, shell=False, stdin=subprocess.PIPE)
acommand = ['ffmpeg',
'-y', '-vn',
'-f', 's16le',
#'-acodec','pcm_s16le',
'-ac', '1',
'-ar', '16000',
'-i', '-',
'-acodec', 'aac',
#'-f' , 'wav',
f'temp{self.opt.sessionid}.aac']
self._record_audio_pipe = subprocess.Popen(acommand, shell=False, stdin=subprocess.PIPE)
self.recording = True
# self.recordq_video.queue.clear()
# self.recordq_audio.queue.clear()
# self.container = av.open(path, mode="w")
# process_thread = Thread(target=self.record_frame, args=())
# process_thread.start()
def record_video_data(self,image):
if self.width == 0:
print("image.shape:",image.shape)
self.height,self.width,_ = image.shape
if self.recording:
self._record_video_pipe.stdin.write(image.tobytes())
def record_audio_data(self,frame):
if self.recording:
self._record_audio_pipe.stdin.write(frame.tobytes())
# def record_frame(self):
# videostream = self.container.add_stream("libx264", rate=25)
# videostream.codec_context.time_base = Fraction(1, 25)
# audiostream = self.container.add_stream("aac")
# audiostream.codec_context.time_base = Fraction(1, 16000)
# init = True
# framenum = 0
# while self.recording:
# try:
# videoframe = self.recordq_video.get(block=True, timeout=1)
# videoframe.pts = framenum #int(round(framenum*0.04 / videostream.codec_context.time_base))
# videoframe.dts = videoframe.pts
# if init:
# videostream.width = videoframe.width
# videostream.height = videoframe.height
# init = False
# for packet in videostream.encode(videoframe):
# self.container.mux(packet)
# for k in range(2):
# audioframe = self.recordq_audio.get(block=True, timeout=1)
# audioframe.pts = int(round((framenum*2+k)*0.02 / audiostream.codec_context.time_base))
# audioframe.dts = audioframe.pts
# for packet in audiostream.encode(audioframe):
# self.container.mux(packet)
# framenum += 1
# except queue.Empty:
# print('record queue empty,')
# continue
# except Exception as e:
# print(e)
# #break
# for packet in videostream.encode(None):
# self.container.mux(packet)
# for packet in audiostream.encode(None):
# self.container.mux(packet)
# self.container.close()
# self.recordq_video.queue.clear()
# self.recordq_audio.queue.clear()
# print('record thread stop')
def stop_recording(self):
"""停止录制视频"""
if not self.recording:
return
self.recording = False
self._record_video_pipe.stdin.close() #wait()
self._record_video_pipe.wait()
self._record_audio_pipe.stdin.close()
self._record_audio_pipe.wait()
cmd_combine_audio = f"ffmpeg -y -i temp{self.opt.sessionid}.aac -i temp{self.opt.sessionid}.mp4 -c:v copy -c:a copy data/record.mp4"
os.system(cmd_combine_audio)
#os.remove(output_path)
def mirror_index(self,size, index):
#size = len(self.coord_list_cycle)
turn = index // size
res = index % size
if turn % 2 == 0:
return res
else:
return size - res - 1
def get_audio_stream(self,audiotype):
idx = self.custom_audio_index[audiotype]
stream = self.custom_audio_cycle[audiotype][idx:idx+self.chunk]
self.custom_audio_index[audiotype] += self.chunk
if self.custom_audio_index[audiotype]>=self.custom_audio_cycle[audiotype].shape[0]:
self.curr_state = 1 # the current clip does not loop; switch to the silent state
return stream
def set_curr_state(self,audiotype, reinit):
print('set_curr_state:',audiotype)
self.curr_state = audiotype
if reinit:
self.custom_audio_index[audiotype] = 0
self.custom_index[audiotype] = 0
# def process_custom(self,audiotype:int,idx:int):
# if self.curr_state!=audiotype: # switch from inference to scripted playback
# if idx in self.switch_pos: # switching is only allowed at a sync point
# self.curr_state=audiotype
# self.custom_index=0
# else:
# self.custom_index+=1
\ No newline at end of file
... ...
# Routines for DeepSpeech features processing
Several routines for [DeepSpeech](https://github.com/mozilla/DeepSpeech) features processing, like speech features generation for [VOCA](https://github.com/TimoBolkart/voca) model.
## Installation
```
pip3 install -r requirements.txt
```
## Usage
Generate wav files:
```
python3 extract_wav.py --in-video=<your_data_dir>
```
Generate files with DeepSpeech features:
```
python3 extract_ds_features.py --input=<your_data_dir>
```
... ...
"""
DeepSpeech features processing routines.
NB: Based on VOCA code. See the corresponding license restrictions.
"""
__all__ = ['conv_audios_to_deepspeech']
import numpy as np
import warnings
import resampy
from scipy.io import wavfile
from python_speech_features import mfcc
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
def conv_audios_to_deepspeech(audios,
out_files,
num_frames_info,
deepspeech_pb_path,
audio_window_size=1,
audio_window_stride=1):
"""
Convert list of audio files into files with DeepSpeech features.
Parameters
----------
audios : list of str or list of None
Paths to input audio files.
out_files : list of str
Paths to output files with DeepSpeech features.
num_frames_info : list of int
List of numbers of frames.
deepspeech_pb_path : str
Path to DeepSpeech 0.1.0 frozen model.
audio_window_size : int, default 1
Audio window size.
audio_window_stride : int, default 1
Audio window stride.
"""
# deepspeech_pb_path="/disk4/keyu/DeepSpeech/deepspeech-0.9.2-models.pbmm"
graph, logits_ph, input_node_ph, input_lengths_ph = prepare_deepspeech_net(
deepspeech_pb_path)
with tf.compat.v1.Session(graph=graph) as sess:
for audio_file_path, out_file_path, num_frames in zip(audios, out_files, num_frames_info):
print(audio_file_path)
print(out_file_path)
audio_sample_rate, audio = wavfile.read(audio_file_path)
if audio.ndim != 1:
warnings.warn(
"Audio has multiple channels, the first channel is used")
audio = audio[:, 0]
ds_features = pure_conv_audio_to_deepspeech(
audio=audio,
audio_sample_rate=audio_sample_rate,
audio_window_size=audio_window_size,
audio_window_stride=audio_window_stride,
num_frames=num_frames,
net_fn=lambda x: sess.run(
logits_ph,
feed_dict={
input_node_ph: x[np.newaxis, ...],
input_lengths_ph: [x.shape[0]]}))
net_output = ds_features.reshape(-1, 29)
win_size = 16
zero_pad = np.zeros((int(win_size / 2), net_output.shape[1]))
net_output = np.concatenate(
(zero_pad, net_output, zero_pad), axis=0)
windows = []
for window_index in range(0, net_output.shape[0] - win_size, 2):
windows.append(
net_output[window_index:window_index + win_size])
print(np.array(windows).shape)
np.save(out_file_path, np.array(windows))
def prepare_deepspeech_net(deepspeech_pb_path):
"""
Load and prepare DeepSpeech network.
Parameters
----------
deepspeech_pb_path : str
Path to DeepSpeech 0.1.0 frozen model.
Returns
-------
graph : obj
TensorFlow graph.
logits_ph : obj
TensorFlow placeholder for `logits`.
input_node_ph : obj
TensorFlow placeholder for `input_node`.
input_lengths_ph : obj
TensorFlow placeholder for `input_lengths`.
"""
# Load graph and place_holders:
with tf.io.gfile.GFile(deepspeech_pb_path, "rb") as f:
graph_def = tf.compat.v1.GraphDef()
graph_def.ParseFromString(f.read())
graph = tf.compat.v1.get_default_graph()
tf.import_graph_def(graph_def, name="deepspeech")
logits_ph = graph.get_tensor_by_name("deepspeech/logits:0")
input_node_ph = graph.get_tensor_by_name("deepspeech/input_node:0")
input_lengths_ph = graph.get_tensor_by_name("deepspeech/input_lengths:0")
return graph, logits_ph, input_node_ph, input_lengths_ph
def pure_conv_audio_to_deepspeech(audio,
audio_sample_rate,
audio_window_size,
audio_window_stride,
num_frames,
net_fn):
"""
Core routine for converting audio into DeepSpeech features.
Parameters
----------
audio : np.array
Audio data.
audio_sample_rate : int
Audio sample rate.
audio_window_size : int
Audio window size.
audio_window_stride : int
Audio window stride.
num_frames : int or None
Numbers of frames.
net_fn : func
Function for DeepSpeech model call.
Returns
-------
np.array
DeepSpeech features.
"""
target_sample_rate = 16000
if audio_sample_rate != target_sample_rate:
resampled_audio = resampy.resample(
x=audio.astype(np.float64),
sr_orig=audio_sample_rate,
sr_new=target_sample_rate)
else:
resampled_audio = audio.astype(np.float64)
input_vector = conv_audio_to_deepspeech_input_vector(
audio=resampled_audio.astype(np.int16),
sample_rate=target_sample_rate,
num_cepstrum=26,
num_context=9)
network_output = net_fn(input_vector)
# print(network_output.shape)
deepspeech_fps = 50
video_fps = 50 # Change this option if video fps is different
audio_len_s = float(audio.shape[0]) / audio_sample_rate
if num_frames is None:
num_frames = int(round(audio_len_s * video_fps))
else:
video_fps = num_frames / audio_len_s
network_output = interpolate_features(
features=network_output[:, 0],
input_rate=deepspeech_fps,
output_rate=video_fps,
output_len=num_frames)
# Make windows:
zero_pad = np.zeros((int(audio_window_size / 2), network_output.shape[1]))
network_output = np.concatenate(
(zero_pad, network_output, zero_pad), axis=0)
windows = []
for window_index in range(0, network_output.shape[0] - audio_window_size, audio_window_stride):
windows.append(
network_output[window_index:window_index + audio_window_size])
return np.array(windows)
def conv_audio_to_deepspeech_input_vector(audio,
sample_rate,
num_cepstrum,
num_context):
"""
Convert audio raw data into DeepSpeech input vector.
Parameters
----------
audio : np.array
Audio data.
sample_rate : int
Audio sample rate.
num_cepstrum : int
Number of cepstrum.
num_context : int
Number of context.
Returns
-------
np.array
DeepSpeech input vector.
"""
# Get mfcc coefficients:
features = mfcc(
signal=audio,
samplerate=sample_rate,
numcep=num_cepstrum)
# We only keep every second feature (BiRNN stride = 2):
features = features[::2]
# One stride per time step in the input:
num_strides = len(features)
# Add empty initial and final contexts:
empty_context = np.zeros((num_context, num_cepstrum), dtype=features.dtype)
features = np.concatenate((empty_context, features, empty_context))
# Create a view into the array with overlapping strides of size
# numcontext (past) + 1 (present) + numcontext (future):
window_size = 2 * num_context + 1
train_inputs = np.lib.stride_tricks.as_strided(
features,
shape=(num_strides, window_size, num_cepstrum),
strides=(features.strides[0],
features.strides[0], features.strides[1]),
writeable=False)
# Flatten the second and third dimensions:
train_inputs = np.reshape(train_inputs, [num_strides, -1])
train_inputs = np.copy(train_inputs)
train_inputs = (train_inputs - np.mean(train_inputs)) / \
np.std(train_inputs)
return train_inputs
def interpolate_features(features,
input_rate,
output_rate,
output_len):
"""
Interpolate DeepSpeech features.
Parameters
----------
features : np.array
DeepSpeech features.
input_rate : int
input rate (FPS).
output_rate : int
Output rate (FPS).
output_len : int
Output data length.
Returns
-------
np.array
Interpolated data.
"""
input_len = features.shape[0]
num_features = features.shape[1]
input_timestamps = np.arange(input_len) / float(input_rate)
output_timestamps = np.arange(output_len) / float(output_rate)
output_features = np.zeros((output_len, num_features))
for feature_idx in range(num_features):
output_features[:, feature_idx] = np.interp(
x=output_timestamps,
xp=input_timestamps,
fp=features[:, feature_idx])
return output_features
... ...
"""
Routines for loading DeepSpeech model.
"""
__all__ = ['get_deepspeech_model_file']
import os
import zipfile
import logging
import hashlib
deepspeech_features_repo_url = 'https://github.com/osmr/deepspeech_features'
def get_deepspeech_model_file(local_model_store_dir_path=os.path.join("~", ".tensorflow", "models")):
"""
Return the location of the pretrained model on the local file system. This function will download it from the online model zoo when
the model cannot be found locally or its hash does not match. The root directory will be created if it doesn't exist.
Parameters
----------
local_model_store_dir_path : str, default $TENSORFLOW_HOME/models
Location for keeping the model parameters.
Returns
-------
file_path
Path to the requested pretrained model file.
"""
sha1_hash = "b90017e816572ddce84f5843f1fa21e6a377975e"
file_name = "deepspeech-0_1_0-b90017e8.pb"
local_model_store_dir_path = os.path.expanduser(local_model_store_dir_path)
file_path = os.path.join(local_model_store_dir_path, file_name)
if os.path.exists(file_path):
if _check_sha1(file_path, sha1_hash):
return file_path
else:
logging.warning("Mismatch in the content of model file detected. Downloading again.")
else:
logging.info("Model file not found. Downloading to {}.".format(file_path))
if not os.path.exists(local_model_store_dir_path):
os.makedirs(local_model_store_dir_path)
zip_file_path = file_path + ".zip"
_download(
url="{repo_url}/releases/download/{repo_release_tag}/{file_name}.zip".format(
repo_url=deepspeech_features_repo_url,
repo_release_tag="v0.0.1",
file_name=file_name),
path=zip_file_path,
overwrite=True)
with zipfile.ZipFile(zip_file_path) as zf:
zf.extractall(local_model_store_dir_path)
os.remove(zip_file_path)
if _check_sha1(file_path, sha1_hash):
return file_path
else:
raise ValueError("Downloaded file has different hash. Please try again.")
def _download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_ssl=True):
"""
    Download a file from the given URL.
Parameters
----------
url : str
URL to download
path : str, optional
Destination path to store downloaded file. By default stores to the
current directory with same name as in url.
overwrite : bool, optional
Whether to overwrite destination file if already exists.
sha1_hash : str, optional
Expected sha1 hash in hexadecimal digits. Will ignore existing file when hash is specified
but doesn't match.
retries : integer, default 5
        The number of times to attempt the download in case of failure or non-200 return codes.
verify_ssl : bool, default True
Verify SSL certificates.
Returns
-------
str
The file path of the downloaded file.
"""
import warnings
try:
import requests
except ImportError:
class requests_failed_to_import(object):
pass
requests = requests_failed_to_import
if path is None:
fname = url.split("/")[-1]
# Empty filenames are invalid
assert fname, "Can't construct file-name from this URL. Please set the `path` option manually."
else:
path = os.path.expanduser(path)
if os.path.isdir(path):
fname = os.path.join(path, url.split("/")[-1])
else:
fname = path
assert retries >= 0, "Number of retries should be at least 0"
if not verify_ssl:
warnings.warn(
"Unverified HTTPS request is being made (verify_ssl=False). "
"Adding certificate verification is strongly advised.")
if overwrite or not os.path.exists(fname) or (sha1_hash and not _check_sha1(fname, sha1_hash)):
dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname)))
if not os.path.exists(dirname):
os.makedirs(dirname)
while retries + 1 > 0:
            # Disable pylint too-broad-exception warning
# pylint: disable=W0703
try:
print("Downloading {} from {}...".format(fname, url))
r = requests.get(url, stream=True, verify=verify_ssl)
if r.status_code != 200:
raise RuntimeError("Failed downloading url {}".format(url))
with open(fname, "wb") as f:
for chunk in r.iter_content(chunk_size=1024):
if chunk: # filter out keep-alive new chunks
f.write(chunk)
if sha1_hash and not _check_sha1(fname, sha1_hash):
raise UserWarning("File {} is downloaded but the content hash does not match."
" The repo may be outdated or download may be incomplete. "
"If the `repo_url` is overridden, consider switching to "
"the default repo.".format(fname))
break
except Exception as e:
retries -= 1
if retries <= 0:
raise e
else:
print("download failed, retrying, {} attempt{} left"
.format(retries, "s" if retries > 1 else ""))
return fname
def _check_sha1(filename, sha1_hash):
"""
Check whether the sha1 hash of the file content matches the expected hash.
Parameters
----------
filename : str
Path to the file.
sha1_hash : str
Expected sha1 hash in hexadecimal digits.
Returns
-------
bool
Whether the file content matches the expected hash.
"""
sha1 = hashlib.sha1()
with open(filename, "rb") as f:
while True:
data = f.read(1048576)
if not data:
break
sha1.update(data)
return sha1.hexdigest() == sha1_hash
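# Minimal usage sketch (hypothetical helper, not used elsewhere in the repo):
# _check_sha1() streams the file in 1 MiB chunks, so it is cheap to call on
# large model files after get_deepspeech_model_file() has downloaded them.
def _demo_verify(path, expected_sha1):
    return os.path.exists(path) and _check_sha1(path, expected_sha1)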
... ...
"""
Script for extracting DeepSpeech features from audio file.
"""
import os
import argparse
import numpy as np
import pandas as pd
from deepspeech_store import get_deepspeech_model_file
from deepspeech_features import conv_audios_to_deepspeech
def parse_args():
"""
Create python script parameters.
Returns
-------
    argparse.Namespace
        Parsed arguments.
"""
parser = argparse.ArgumentParser(
description="Extract DeepSpeech features from audio file",
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument(
"--input",
type=str,
required=True,
help="path to input audio file or directory")
parser.add_argument(
"--output",
type=str,
help="path to output file with DeepSpeech features")
parser.add_argument(
"--deepspeech",
type=str,
help="path to DeepSpeech 0.1.0 frozen model")
parser.add_argument(
"--metainfo",
type=str,
help="path to file with meta-information")
args = parser.parse_args()
return args
def extract_features(in_audios,
out_files,
deepspeech_pb_path,
metainfo_file_path=None):
"""
    Extract DeepSpeech features from the given audio files.
Parameters
----------
in_audios : list of str
Paths to input audio files.
out_files : list of str
Paths to output files with DeepSpeech features.
deepspeech_pb_path : str
Path to DeepSpeech 0.1.0 frozen model.
metainfo_file_path : str, default None
Path to file with meta-information.
"""
#deepspeech_pb_path="/disk4/keyu/DeepSpeech/deepspeech-0.9.2-models.pbmm"
if metainfo_file_path is None:
num_frames_info = [None] * len(in_audios)
else:
train_df = pd.read_csv(
metainfo_file_path,
sep="\t",
index_col=False,
dtype={"Id": np.int, "File": np.unicode, "Count": np.int})
num_frames_info = train_df["Count"].values
assert (len(num_frames_info) == len(in_audios))
for i, in_audio in enumerate(in_audios):
if not out_files[i]:
file_stem, _ = os.path.splitext(in_audio)
out_files[i] = file_stem + ".npy"
#print(out_files[i])
conv_audios_to_deepspeech(
audios=in_audios,
out_files=out_files,
num_frames_info=num_frames_info,
deepspeech_pb_path=deepspeech_pb_path)
def main():
"""
Main body of script.
"""
args = parse_args()
in_audio = os.path.expanduser(args.input)
if not os.path.exists(in_audio):
raise Exception("Input file/directory doesn't exist: {}".format(in_audio))
deepspeech_pb_path = args.deepspeech
    # Override the CLI value and always use the DeepSpeech 0.1.0 frozen graph
    # from the local TensorFlow model store (downloaded on demand below).
    deepspeech_pb_path = True
    args.deepspeech = '~/.tensorflow/models/deepspeech-0_1_0-b90017e8.pb'
#deepspeech_pb_path="/disk4/keyu/DeepSpeech/deepspeech-0.9.2-models.pbmm"
if deepspeech_pb_path is None:
deepspeech_pb_path = ""
if deepspeech_pb_path:
deepspeech_pb_path = os.path.expanduser(args.deepspeech)
if not os.path.exists(deepspeech_pb_path):
deepspeech_pb_path = get_deepspeech_model_file()
if os.path.isfile(in_audio):
extract_features(
in_audios=[in_audio],
out_files=[args.output],
deepspeech_pb_path=deepspeech_pb_path,
metainfo_file_path=args.metainfo)
else:
audio_file_paths = []
for file_name in os.listdir(in_audio):
if not os.path.isfile(os.path.join(in_audio, file_name)):
continue
_, file_ext = os.path.splitext(file_name)
if file_ext.lower() == ".wav":
audio_file_path = os.path.join(in_audio, file_name)
audio_file_paths.append(audio_file_path)
audio_file_paths = sorted(audio_file_paths)
out_file_paths = [""] * len(audio_file_paths)
extract_features(
in_audios=audio_file_paths,
out_files=out_file_paths,
deepspeech_pb_path=deepspeech_pb_path,
metainfo_file_path=args.metainfo)
if __name__ == "__main__":
main()
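# Usage sketch (the input paths are placeholders, not shipped with the repo):
#   python data_utils/deepspeech_features/extract_ds_features.py --input data/obama/aud.wav
#   python data_utils/deepspeech_features/extract_ds_features.py --input data/obama/
# A single .wav produces a matching <name>.npy next to it unless --output is
# given; a directory processes every *.wav inside it.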
... ...
"""
Script for extracting audio (16-bit, mono, 16000 Hz) from video file.
"""
import os
import argparse
import subprocess
def parse_args():
"""
Create python script parameters.
Returns
-------
    argparse.Namespace
        Parsed arguments.
"""
parser = argparse.ArgumentParser(
description="Extract audio from video file",
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument(
"--in-video",
type=str,
required=True,
help="path to input video file or directory")
parser.add_argument(
"--out-audio",
type=str,
help="path to output audio file")
args = parser.parse_args()
return args
def extract_audio(in_video,
out_audio):
"""
    Extract the audio track from a video file as 16-bit mono 16 kHz WAV.
Parameters
----------
in_video : str
Path to input video file.
out_audio : str
Path to output audio file.
"""
if not out_audio:
file_stem, _ = os.path.splitext(in_video)
out_audio = file_stem + ".wav"
# command1 = "ffmpeg -i {in_video} -vn -acodec copy {aac_audio}"
# command2 = "ffmpeg -i {aac_audio} -vn -acodec pcm_s16le -ac 1 -ar 22000 {out_audio}"
# command = "ffmpeg -i {in_video} -vn -acodec pcm_s16le -ac 1 -ar 22000 {out_audio}"
command = "ffmpeg -i {in_video} -vn -acodec pcm_s16le -ac 1 -ar 16000 {out_audio}"
subprocess.call([command.format(in_video=in_video, out_audio=out_audio)], shell=True)
def main():
"""
Main body of script.
"""
args = parse_args()
in_video = os.path.expanduser(args.in_video)
if not os.path.exists(in_video):
raise Exception("Input file/directory doesn't exist: {}".format(in_video))
if os.path.isfile(in_video):
extract_audio(
in_video=in_video,
out_audio=args.out_audio)
else:
video_file_paths = []
for file_name in os.listdir(in_video):
if not os.path.isfile(os.path.join(in_video, file_name)):
continue
_, file_ext = os.path.splitext(file_name)
if file_ext.lower() in (".mp4", ".mkv", ".avi"):
video_file_path = os.path.join(in_video, file_name)
video_file_paths.append(video_file_path)
video_file_paths = sorted(video_file_paths)
for video_file_path in video_file_paths:
extract_audio(
in_video=video_file_path,
out_audio="")
if __name__ == "__main__":
main()
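# Usage sketch (placeholder paths; the script name is whatever this file is
# saved as): accepts one video or a directory of .mp4/.mkv/.avi files and
# writes a 16 kHz mono WAV next to each input unless --out-audio is given.
#   python path/to/this_script.py --in-video data/obama.mp4
#   python path/to/this_script.py --in-video data/videos/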
... ...
import numpy as np
net_output = np.load('french.ds.npy').reshape(-1, 29)
win_size = 16
zero_pad = np.zeros((int(win_size / 2), net_output.shape[1]))
net_output = np.concatenate((zero_pad, net_output, zero_pad), axis=0)
windows = []
for window_index in range(0, net_output.shape[0] - win_size, 2):
windows.append(net_output[window_index:window_index + win_size])
print(np.array(windows).shape)
np.save('aud_french.npy', np.array(windows))
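# Shape sketch: for T DeepSpeech frames the loop above yields (T + 1) // 2
# windows of shape (16, 29), so aud_french.npy has shape (ceil(T/2), 16, 29),
# roughly one window per 25 FPS video frame when the features arrive at
# ~50 frames per second (an assumption of this note, not checked here).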
... ...
#!/usr/bin/python
# -*- encoding: utf-8 -*-
import os.path as osp
import time
import sys
import logging
import torch.distributed as dist
def setup_logger(logpth):
logfile = 'BiSeNet-{}.log'.format(time.strftime('%Y-%m-%d-%H-%M-%S'))
logfile = osp.join(logpth, logfile)
FORMAT = '%(levelname)s %(filename)s(%(lineno)d): %(message)s'
log_level = logging.INFO
    if dist.is_initialized() and dist.get_rank() != 0:
log_level = logging.ERROR
logging.basicConfig(level=log_level, format=FORMAT, filename=logfile)
logging.root.addHandler(logging.StreamHandler())
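if __name__ == "__main__":
    # Minimal usage sketch: write BiSeNet-<timestamp>.log into ./log (a
    # directory this demo creates, not something the repo requires) and echo
    # messages to stdout as well.
    import os
    os.makedirs('./log', exist_ok=True)
    setup_logger('./log')
    logging.info('logger ready')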
... ...
#!/usr/bin/python
# -*- encoding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from resnet import Resnet18
# from modules.bn import InPlaceABNSync as BatchNorm2d
class ConvBNReLU(nn.Module):
def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs):
super(ConvBNReLU, self).__init__()
self.conv = nn.Conv2d(in_chan,
out_chan,
kernel_size = ks,
stride = stride,
padding = padding,
bias = False)
self.bn = nn.BatchNorm2d(out_chan)
self.init_weight()
def forward(self, x):
x = self.conv(x)
x = F.relu(self.bn(x))
return x
def init_weight(self):
for ly in self.children():
if isinstance(ly, nn.Conv2d):
nn.init.kaiming_normal_(ly.weight, a=1)
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
class BiSeNetOutput(nn.Module):
def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs):
super(BiSeNetOutput, self).__init__()
self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False)
self.init_weight()
def forward(self, x):
x = self.conv(x)
x = self.conv_out(x)
return x
def init_weight(self):
for ly in self.children():
if isinstance(ly, nn.Conv2d):
nn.init.kaiming_normal_(ly.weight, a=1)
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
def get_params(self):
wd_params, nowd_params = [], []
for name, module in self.named_modules():
if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
wd_params.append(module.weight)
if not module.bias is None:
nowd_params.append(module.bias)
elif isinstance(module, nn.BatchNorm2d):
nowd_params += list(module.parameters())
return wd_params, nowd_params
class AttentionRefinementModule(nn.Module):
def __init__(self, in_chan, out_chan, *args, **kwargs):
super(AttentionRefinementModule, self).__init__()
self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)
self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size= 1, bias=False)
self.bn_atten = nn.BatchNorm2d(out_chan)
self.sigmoid_atten = nn.Sigmoid()
self.init_weight()
def forward(self, x):
feat = self.conv(x)
atten = F.avg_pool2d(feat, feat.size()[2:])
atten = self.conv_atten(atten)
atten = self.bn_atten(atten)
atten = self.sigmoid_atten(atten)
out = torch.mul(feat, atten)
return out
def init_weight(self):
for ly in self.children():
if isinstance(ly, nn.Conv2d):
nn.init.kaiming_normal_(ly.weight, a=1)
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
class ContextPath(nn.Module):
def __init__(self, *args, **kwargs):
super(ContextPath, self).__init__()
self.resnet = Resnet18()
self.arm16 = AttentionRefinementModule(256, 128)
self.arm32 = AttentionRefinementModule(512, 128)
self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0)
self.init_weight()
def forward(self, x):
H0, W0 = x.size()[2:]
feat8, feat16, feat32 = self.resnet(x)
H8, W8 = feat8.size()[2:]
H16, W16 = feat16.size()[2:]
H32, W32 = feat32.size()[2:]
avg = F.avg_pool2d(feat32, feat32.size()[2:])
avg = self.conv_avg(avg)
avg_up = F.interpolate(avg, (H32, W32), mode='nearest')
feat32_arm = self.arm32(feat32)
feat32_sum = feat32_arm + avg_up
feat32_up = F.interpolate(feat32_sum, (H16, W16), mode='nearest')
feat32_up = self.conv_head32(feat32_up)
feat16_arm = self.arm16(feat16)
feat16_sum = feat16_arm + feat32_up
feat16_up = F.interpolate(feat16_sum, (H8, W8), mode='nearest')
feat16_up = self.conv_head16(feat16_up)
return feat8, feat16_up, feat32_up # x8, x8, x16
def init_weight(self):
for ly in self.children():
if isinstance(ly, nn.Conv2d):
nn.init.kaiming_normal_(ly.weight, a=1)
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
def get_params(self):
wd_params, nowd_params = [], []
for name, module in self.named_modules():
if isinstance(module, (nn.Linear, nn.Conv2d)):
wd_params.append(module.weight)
if not module.bias is None:
nowd_params.append(module.bias)
elif isinstance(module, nn.BatchNorm2d):
nowd_params += list(module.parameters())
return wd_params, nowd_params
### Not used: the spatial path is replaced by the ResNet feature map of the same size
class SpatialPath(nn.Module):
def __init__(self, *args, **kwargs):
super(SpatialPath, self).__init__()
self.conv1 = ConvBNReLU(3, 64, ks=7, stride=2, padding=3)
self.conv2 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
self.conv3 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
self.conv_out = ConvBNReLU(64, 128, ks=1, stride=1, padding=0)
self.init_weight()
def forward(self, x):
feat = self.conv1(x)
feat = self.conv2(feat)
feat = self.conv3(feat)
feat = self.conv_out(feat)
return feat
def init_weight(self):
for ly in self.children():
if isinstance(ly, nn.Conv2d):
nn.init.kaiming_normal_(ly.weight, a=1)
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
def get_params(self):
wd_params, nowd_params = [], []
for name, module in self.named_modules():
if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
wd_params.append(module.weight)
if not module.bias is None:
nowd_params.append(module.bias)
elif isinstance(module, nn.BatchNorm2d):
nowd_params += list(module.parameters())
return wd_params, nowd_params
class FeatureFusionModule(nn.Module):
def __init__(self, in_chan, out_chan, *args, **kwargs):
super(FeatureFusionModule, self).__init__()
self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
self.conv1 = nn.Conv2d(out_chan,
out_chan//4,
kernel_size = 1,
stride = 1,
padding = 0,
bias = False)
self.conv2 = nn.Conv2d(out_chan//4,
out_chan,
kernel_size = 1,
stride = 1,
padding = 0,
bias = False)
self.relu = nn.ReLU(inplace=True)
self.sigmoid = nn.Sigmoid()
self.init_weight()
def forward(self, fsp, fcp):
fcat = torch.cat([fsp, fcp], dim=1)
feat = self.convblk(fcat)
atten = F.avg_pool2d(feat, feat.size()[2:])
atten = self.conv1(atten)
atten = self.relu(atten)
atten = self.conv2(atten)
atten = self.sigmoid(atten)
feat_atten = torch.mul(feat, atten)
feat_out = feat_atten + feat
return feat_out
def init_weight(self):
for ly in self.children():
if isinstance(ly, nn.Conv2d):
nn.init.kaiming_normal_(ly.weight, a=1)
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
def get_params(self):
wd_params, nowd_params = [], []
for name, module in self.named_modules():
if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
wd_params.append(module.weight)
if not module.bias is None:
nowd_params.append(module.bias)
elif isinstance(module, nn.BatchNorm2d):
nowd_params += list(module.parameters())
return wd_params, nowd_params
class BiSeNet(nn.Module):
def __init__(self, n_classes, *args, **kwargs):
super(BiSeNet, self).__init__()
self.cp = ContextPath()
        ## the spatial path (self.sp) is omitted; the ResNet feat8 map is used instead
self.ffm = FeatureFusionModule(256, 256)
self.conv_out = BiSeNetOutput(256, 256, n_classes)
self.conv_out16 = BiSeNetOutput(128, 64, n_classes)
self.conv_out32 = BiSeNetOutput(128, 64, n_classes)
self.init_weight()
def forward(self, x):
H, W = x.size()[2:]
feat_res8, feat_cp8, feat_cp16 = self.cp(x) # here return res3b1 feature
feat_sp = feat_res8 # use res3b1 feature to replace spatial path feature
feat_fuse = self.ffm(feat_sp, feat_cp8)
feat_out = self.conv_out(feat_fuse)
feat_out16 = self.conv_out16(feat_cp8)
feat_out32 = self.conv_out32(feat_cp16)
feat_out = F.interpolate(feat_out, (H, W), mode='bilinear', align_corners=True)
feat_out16 = F.interpolate(feat_out16, (H, W), mode='bilinear', align_corners=True)
feat_out32 = F.interpolate(feat_out32, (H, W), mode='bilinear', align_corners=True)
# return feat_out, feat_out16, feat_out32
return feat_out
def init_weight(self):
for ly in self.children():
if isinstance(ly, nn.Conv2d):
nn.init.kaiming_normal_(ly.weight, a=1)
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
def get_params(self):
wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], []
for name, child in self.named_children():
child_wd_params, child_nowd_params = child.get_params()
if isinstance(child, FeatureFusionModule) or isinstance(child, BiSeNetOutput):
lr_mul_wd_params += child_wd_params
lr_mul_nowd_params += child_nowd_params
else:
wd_params += child_wd_params
nowd_params += child_nowd_params
return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params
if __name__ == "__main__":
net = BiSeNet(19)
net.cuda()
net.eval()
in_ten = torch.randn(16, 3, 640, 480).cuda()
    out = net(in_ten)  # forward() returns only the fused output map
print(out.shape)
net.get_params()
... ...
#!/usr/bin/python
# -*- encoding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as modelzoo
# from modules.bn import InPlaceABNSync as BatchNorm2d
resnet18_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
def conv3x3(in_planes, out_planes, stride=1):
"""3x3 convolution with padding"""
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
padding=1, bias=False)
class BasicBlock(nn.Module):
def __init__(self, in_chan, out_chan, stride=1):
super(BasicBlock, self).__init__()
self.conv1 = conv3x3(in_chan, out_chan, stride)
self.bn1 = nn.BatchNorm2d(out_chan)
self.conv2 = conv3x3(out_chan, out_chan)
self.bn2 = nn.BatchNorm2d(out_chan)
self.relu = nn.ReLU(inplace=True)
self.downsample = None
if in_chan != out_chan or stride != 1:
self.downsample = nn.Sequential(
nn.Conv2d(in_chan, out_chan,
kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(out_chan),
)
def forward(self, x):
residual = self.conv1(x)
residual = F.relu(self.bn1(residual))
residual = self.conv2(residual)
residual = self.bn2(residual)
shortcut = x
if self.downsample is not None:
shortcut = self.downsample(x)
out = shortcut + residual
out = self.relu(out)
return out
def create_layer_basic(in_chan, out_chan, bnum, stride=1):
layers = [BasicBlock(in_chan, out_chan, stride=stride)]
for i in range(bnum-1):
layers.append(BasicBlock(out_chan, out_chan, stride=1))
return nn.Sequential(*layers)
class Resnet18(nn.Module):
def __init__(self):
super(Resnet18, self).__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
bias=False)
self.bn1 = nn.BatchNorm2d(64)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1)
self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2)
self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2)
self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2)
self.init_weight()
def forward(self, x):
x = self.conv1(x)
x = F.relu(self.bn1(x))
x = self.maxpool(x)
x = self.layer1(x)
feat8 = self.layer2(x) # 1/8
feat16 = self.layer3(feat8) # 1/16
feat32 = self.layer4(feat16) # 1/32
return feat8, feat16, feat32
def init_weight(self):
state_dict = modelzoo.load_url(resnet18_url)
self_state_dict = self.state_dict()
for k, v in state_dict.items():
if 'fc' in k: continue
self_state_dict.update({k: v})
self.load_state_dict(self_state_dict)
def get_params(self):
wd_params, nowd_params = [], []
for name, module in self.named_modules():
if isinstance(module, (nn.Linear, nn.Conv2d)):
wd_params.append(module.weight)
if not module.bias is None:
nowd_params.append(module.bias)
elif isinstance(module, nn.BatchNorm2d):
nowd_params += list(module.parameters())
return wd_params, nowd_params
if __name__ == "__main__":
net = Resnet18()
x = torch.randn(16, 3, 224, 224)
out = net(x)
print(out[0].size())
print(out[1].size())
print(out[2].size())
net.get_params()
... ...
#!/usr/bin/python
# -*- encoding: utf-8 -*-
import numpy as np
from model import BiSeNet
import torch
import os
import os.path as osp
from PIL import Image
import torchvision.transforms as transforms
import cv2
from pathlib import Path
import configargparse
import tqdm
# import ttach as tta
def vis_parsing_maps(im, parsing_anno, stride, save_im=False, save_path='vis_results/parsing_map_on_im.jpg',
img_size=(512, 512)):
im = np.array(im)
vis_im = im.copy().astype(np.uint8)
vis_parsing_anno = parsing_anno.copy().astype(np.uint8)
vis_parsing_anno = cv2.resize(
vis_parsing_anno, None, fx=stride, fy=stride, interpolation=cv2.INTER_NEAREST)
vis_parsing_anno_color = np.zeros(
(vis_parsing_anno.shape[0], vis_parsing_anno.shape[1], 3)) + np.array([255, 255, 255]) # + 255
num_of_class = np.max(vis_parsing_anno)
# print(num_of_class)
for pi in range(1, 14):
index = np.where(vis_parsing_anno == pi)
vis_parsing_anno_color[index[0], index[1], :] = np.array([255, 0, 0])
for pi in range(14, 16):
index = np.where(vis_parsing_anno == pi)
vis_parsing_anno_color[index[0], index[1], :] = np.array([0, 255, 0])
for pi in range(16, 17):
index = np.where(vis_parsing_anno == pi)
vis_parsing_anno_color[index[0], index[1], :] = np.array([0, 0, 255])
for pi in range(17, num_of_class+1):
index = np.where(vis_parsing_anno == pi)
vis_parsing_anno_color[index[0], index[1], :] = np.array([255, 0, 0])
vis_parsing_anno_color = vis_parsing_anno_color.astype(np.uint8)
index = np.where(vis_parsing_anno == num_of_class-1)
vis_im = cv2.resize(vis_parsing_anno_color, img_size,
interpolation=cv2.INTER_NEAREST)
if save_im:
cv2.imwrite(save_path, vis_im)
def evaluate(respth='./res/test_res', dspth='./data', cp='model_final_diss.pth'):
Path(respth).mkdir(parents=True, exist_ok=True)
print(f'[INFO] loading model...')
n_classes = 19
net = BiSeNet(n_classes=n_classes)
net.cuda()
net.load_state_dict(torch.load(cp))
net.eval()
to_tensor = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])
image_paths = os.listdir(dspth)
with torch.no_grad():
for image_path in tqdm.tqdm(image_paths):
if image_path.endswith('.jpg') or image_path.endswith('.png'):
img = Image.open(osp.join(dspth, image_path))
ori_size = img.size
image = img.resize((512, 512), Image.BILINEAR)
image = image.convert("RGB")
img = to_tensor(image)
# test-time augmentation.
inputs = torch.unsqueeze(img, 0) # [1, 3, 512, 512]
outputs = net(inputs.cuda())
parsing = outputs.mean(0).cpu().numpy().argmax(0)
image_path = int(image_path[:-4])
image_path = str(image_path) + '.png'
vis_parsing_maps(image, parsing, stride=1, save_im=True, save_path=osp.join(respth, image_path), img_size=ori_size)
if __name__ == "__main__":
parser = configargparse.ArgumentParser()
parser.add_argument('--respath', type=str, default='./result/', help='result path for label')
parser.add_argument('--imgpath', type=str, default='./imgs/', help='path for input images')
parser.add_argument('--modelpath', type=str, default='data_utils/face_parsing/79999_iter.pth')
args = parser.parse_args()
evaluate(respth=args.respath, dspth=args.imgpath, cp=args.modelpath)
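# Usage sketch (the default --modelpath matches this repo; the image and
# result paths below are placeholders):
#   python data_utils/face_parsing/test.py \
#       --respath=data/obama/parsing --imgpath=data/obama/ori_imgs \
#       --modelpath=data_utils/face_parsing/79999_iter.pth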
... ...
import numpy as np
from scipy.io import loadmat
original_BFM = loadmat("3DMM/01_MorphableModel.mat")
sub_inds = np.load("3DMM/topology_info.npy", allow_pickle=True).item()["sub_inds"]
shapePC = original_BFM["shapePC"]
shapeEV = original_BFM["shapeEV"]
shapeMU = original_BFM["shapeMU"]
texPC = original_BFM["texPC"]
texEV = original_BFM["texEV"]
texMU = original_BFM["texMU"]
b_shape = shapePC.reshape(-1, 199).transpose(1, 0).reshape(199, -1, 3)
mu_shape = shapeMU.reshape(-1, 3)
b_tex = texPC.reshape(-1, 199).transpose(1, 0).reshape(199, -1, 3)
mu_tex = texMU.reshape(-1, 3)
b_shape = b_shape[:, sub_inds, :].reshape(199, -1)
mu_shape = mu_shape[sub_inds, :].reshape(-1)
b_tex = b_tex[:, sub_inds, :].reshape(199, -1)
mu_tex = mu_tex[sub_inds, :].reshape(-1)
exp_info = np.load("3DMM/exp_info.npy", allow_pickle=True).item()
np.save(
"3DMM/3DMM_info.npy",
{
"mu_shape": mu_shape,
"b_shape": b_shape,
"sig_shape": shapeEV.reshape(-1),
"mu_exp": exp_info["mu_exp"],
"b_exp": exp_info["base_exp"],
"sig_exp": exp_info["sig_exp"],
"mu_tex": mu_tex,
"b_tex": b_tex,
"sig_tex": texEV.reshape(-1),
},
)
... ...
import os
import torch
import numpy as np
def load_dir(path, start, end):
lmss = []
imgs_paths = []
for i in range(start, end):
if os.path.isfile(os.path.join(path, str(i) + ".lms")):
lms = np.loadtxt(os.path.join(path, str(i) + ".lms"), dtype=np.float32)
lmss.append(lms)
imgs_paths.append(os.path.join(path, str(i) + ".jpg"))
lmss = np.stack(lmss)
lmss = torch.as_tensor(lmss).cuda()
return lmss, imgs_paths
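# Usage sketch (the layout is what face_tracker.py expects: <path>/<i>.lms
# landmark files next to <path>/<i>.jpg frames; requires CUDA):
#   lms, img_paths = load_dir("obama/ori_imgs", 0, 100)
#   # lms: float32 CUDA tensor, typically (N, 68, 2) for 68-point landmarks
#   # img_paths: list of the N matching .jpg paths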
... ...
import os
import sys
import cv2
import argparse
from pathlib import Path
import torch
import numpy as np
from data_loader import load_dir
from facemodel import Face_3DMM
from util import *
from render_3dmm import Render_3DMM
# torch.autograd.set_detect_anomaly(True)
dir_path = os.path.dirname(os.path.realpath(__file__))
def set_requires_grad(tensor_list):
for tensor in tensor_list:
tensor.requires_grad = True
parser = argparse.ArgumentParser()
parser.add_argument(
"--path", type=str, default="obama/ori_imgs", help="idname of target person"
)
parser.add_argument("--img_h", type=int, default=512, help="image height")
parser.add_argument("--img_w", type=int, default=512, help="image width")
parser.add_argument("--frame_num", type=int, default=11000, help="image number")
args = parser.parse_args()
start_id = 0
end_id = args.frame_num
lms, img_paths = load_dir(args.path, start_id, end_id)
num_frames = lms.shape[0]
h, w = args.img_h, args.img_w
cxy = torch.tensor((w / 2.0, h / 2.0), dtype=torch.float).cuda()
id_dim, exp_dim, tex_dim, point_num = 100, 79, 100, 34650
model_3dmm = Face_3DMM(
os.path.join(dir_path, "3DMM"), id_dim, exp_dim, tex_dim, point_num
)
# use only one frame out of every 40 to fit the focal length
sel_ids = np.arange(0, num_frames, 40)
sel_num = sel_ids.shape[0]
arg_focal = 1600
arg_landis = 1e5
print(f'[INFO] fitting focal length...')
# fit the focal length
for focal in range(600, 1500, 100):
id_para = lms.new_zeros((1, id_dim), requires_grad=True)
exp_para = lms.new_zeros((sel_num, exp_dim), requires_grad=True)
euler_angle = lms.new_zeros((sel_num, 3), requires_grad=True)
trans = lms.new_zeros((sel_num, 3), requires_grad=True)
trans.data[:, 2] -= 7
focal_length = lms.new_zeros(1, requires_grad=False)
focal_length.data += focal
set_requires_grad([id_para, exp_para, euler_angle, trans])
optimizer_idexp = torch.optim.Adam([id_para, exp_para], lr=0.1)
optimizer_frame = torch.optim.Adam([euler_angle, trans], lr=0.1)
for iter in range(2000):
id_para_batch = id_para.expand(sel_num, -1)
geometry = model_3dmm.get_3dlandmarks(
id_para_batch, exp_para, euler_angle, trans, focal_length, cxy
)
proj_geo = forward_transform(geometry, euler_angle, trans, focal_length, cxy)
loss_lan = cal_lan_loss(proj_geo[:, :, :2], lms[sel_ids].detach())
loss = loss_lan
optimizer_frame.zero_grad()
loss.backward()
optimizer_frame.step()
# if iter % 100 == 0:
# print(focal, 'pose', iter, loss.item())
for iter in range(2500):
id_para_batch = id_para.expand(sel_num, -1)
geometry = model_3dmm.get_3dlandmarks(
id_para_batch, exp_para, euler_angle, trans, focal_length, cxy
)
proj_geo = forward_transform(geometry, euler_angle, trans, focal_length, cxy)
loss_lan = cal_lan_loss(proj_geo[:, :, :2], lms[sel_ids].detach())
loss_regid = torch.mean(id_para * id_para)
loss_regexp = torch.mean(exp_para * exp_para)
loss = loss_lan + loss_regid * 0.5 + loss_regexp * 0.4
optimizer_idexp.zero_grad()
optimizer_frame.zero_grad()
loss.backward()
optimizer_idexp.step()
optimizer_frame.step()
# if iter % 100 == 0:
# print(focal, 'poseidexp', iter, loss_lan.item(), loss_regid.item(), loss_regexp.item())
if iter % 1500 == 0 and iter >= 1500:
for param_group in optimizer_idexp.param_groups:
param_group["lr"] *= 0.2
for param_group in optimizer_frame.param_groups:
param_group["lr"] *= 0.2
print(focal, loss_lan.item(), torch.mean(trans[:, 2]).item())
if loss_lan.item() < arg_landis:
arg_landis = loss_lan.item()
arg_focal = focal
print("[INFO] find best focal:", arg_focal)
print(f'[INFO] coarse fitting...')
# coarse fitting over all frames
id_para = lms.new_zeros((1, id_dim), requires_grad=True)
exp_para = lms.new_zeros((num_frames, exp_dim), requires_grad=True)
tex_para = lms.new_zeros(
(1, tex_dim), requires_grad=True
)  # not optimized in this stage; used later when fitting texture and lighting
euler_angle = lms.new_zeros((num_frames, 3), requires_grad=True)
trans = lms.new_zeros((num_frames, 3), requires_grad=True)
light_para = lms.new_zeros((num_frames, 27), requires_grad=True)
trans.data[:, 2] -= 7  # rough depth initialization for the head translation
focal_length = lms.new_zeros(1, requires_grad=True)
focal_length.data += arg_focal
set_requires_grad([id_para, exp_para, tex_para, euler_angle, trans, light_para])
optimizer_idexp = torch.optim.Adam([id_para, exp_para], lr=0.1)
optimizer_frame = torch.optim.Adam([euler_angle, trans], lr=1)
for iter in range(1500):
id_para_batch = id_para.expand(num_frames, -1)
geometry = model_3dmm.get_3dlandmarks(
id_para_batch, exp_para, euler_angle, trans, focal_length, cxy
)
proj_geo = forward_transform(geometry, euler_angle, trans, focal_length, cxy)
loss_lan = cal_lan_loss(proj_geo[:, :, :2], lms.detach())
loss = loss_lan
optimizer_frame.zero_grad()
loss.backward()
optimizer_frame.step()
if iter == 1000:
for param_group in optimizer_frame.param_groups:
param_group["lr"] = 0.1
# if iter % 100 == 0:
# print('pose', iter, loss.item())
for param_group in optimizer_frame.param_groups:
param_group["lr"] = 0.1
for iter in range(2000):
id_para_batch = id_para.expand(num_frames, -1)
geometry = model_3dmm.get_3dlandmarks(
id_para_batch, exp_para, euler_angle, trans, focal_length, cxy
)
proj_geo = forward_transform(geometry, euler_angle, trans, focal_length, cxy)
loss_lan = cal_lan_loss(proj_geo[:, :, :2], lms.detach())
loss_regid = torch.mean(id_para * id_para)
loss_regexp = torch.mean(exp_para * exp_para)
loss = loss_lan + loss_regid * 0.5 + loss_regexp * 0.4
optimizer_idexp.zero_grad()
optimizer_frame.zero_grad()
loss.backward()
optimizer_idexp.step()
optimizer_frame.step()
# if iter % 100 == 0:
# print('poseidexp', iter, loss_lan.item(), loss_regid.item(), loss_regexp.item())
if iter % 1000 == 0 and iter >= 1000:
for param_group in optimizer_idexp.param_groups:
param_group["lr"] *= 0.2
for param_group in optimizer_frame.param_groups:
param_group["lr"] *= 0.2
print(loss_lan.item(), torch.mean(trans[:, 2]).item())
print(f'[INFO] fitting light...')
batch_size = 32
device_default = torch.device("cuda:0")
device_render = torch.device("cuda:0")
renderer = Render_3DMM(arg_focal, h, w, batch_size, device_render)
sel_ids = np.arange(0, num_frames, int(num_frames / batch_size))[:batch_size]
imgs = []
for sel_id in sel_ids:
imgs.append(cv2.imread(img_paths[sel_id])[:, :, ::-1])
imgs = np.stack(imgs)
sel_imgs = torch.as_tensor(imgs).cuda()
sel_lms = lms[sel_ids]
sel_light = light_para.new_zeros((batch_size, 27), requires_grad=True)
set_requires_grad([sel_light])
optimizer_tl = torch.optim.Adam([tex_para, sel_light], lr=0.1)
optimizer_id_frame = torch.optim.Adam([euler_angle, trans, exp_para, id_para], lr=0.01)
for iter in range(71):
sel_exp_para, sel_euler, sel_trans = (
exp_para[sel_ids],
euler_angle[sel_ids],
trans[sel_ids],
)
sel_id_para = id_para.expand(batch_size, -1)
geometry = model_3dmm.get_3dlandmarks(
sel_id_para, sel_exp_para, sel_euler, sel_trans, focal_length, cxy
)
proj_geo = forward_transform(geometry, sel_euler, sel_trans, focal_length, cxy)
loss_lan = cal_lan_loss(proj_geo[:, :, :2], sel_lms.detach())
loss_regid = torch.mean(id_para * id_para)
loss_regexp = torch.mean(sel_exp_para * sel_exp_para)
sel_tex_para = tex_para.expand(batch_size, -1)
sel_texture = model_3dmm.forward_tex(sel_tex_para)
geometry = model_3dmm.forward_geo(sel_id_para, sel_exp_para)
rott_geo = forward_rott(geometry, sel_euler, sel_trans)
render_imgs = renderer(
rott_geo.to(device_render),
sel_texture.to(device_render),
sel_light.to(device_render),
)
render_imgs = render_imgs.to(device_default)
mask = (render_imgs[:, :, :, 3]).detach() > 0.0
render_proj = sel_imgs.clone()
render_proj[mask] = render_imgs[mask][..., :3].byte()
loss_col = cal_col_loss(render_imgs[:, :, :, :3], sel_imgs.float(), mask)
if iter > 50:
loss = loss_col + loss_lan * 0.05 + loss_regid * 1.0 + loss_regexp * 0.8
else:
loss = loss_col + loss_lan * 3 + loss_regid * 2.0 + loss_regexp * 1.0
optimizer_tl.zero_grad()
optimizer_id_frame.zero_grad()
loss.backward()
optimizer_tl.step()
optimizer_id_frame.step()
if iter % 50 == 0 and iter > 0:
for param_group in optimizer_id_frame.param_groups:
param_group["lr"] *= 0.2
for param_group in optimizer_tl.param_groups:
param_group["lr"] *= 0.2
# print(iter, loss_col.item(), loss_lan.item(), loss_regid.item(), loss_regexp.item())
light_mean = torch.mean(sel_light, 0).unsqueeze(0).repeat(num_frames, 1)
light_para.data = light_mean
exp_para = exp_para.detach()
euler_angle = euler_angle.detach()
trans = trans.detach()
light_para = light_para.detach()
print(f'[INFO] fine frame-wise fitting...')
for i in range(int((num_frames - 1) / batch_size + 1)):
if (i + 1) * batch_size > num_frames:
start_n = num_frames - batch_size
sel_ids = np.arange(num_frames - batch_size, num_frames)
else:
start_n = i * batch_size
sel_ids = np.arange(i * batch_size, i * batch_size + batch_size)
imgs = []
for sel_id in sel_ids:
imgs.append(cv2.imread(img_paths[sel_id])[:, :, ::-1])
imgs = np.stack(imgs)
sel_imgs = torch.as_tensor(imgs).cuda()
sel_lms = lms[sel_ids]
sel_exp_para = exp_para.new_zeros((batch_size, exp_dim), requires_grad=True)
sel_exp_para.data = exp_para[sel_ids].clone()
sel_euler = euler_angle.new_zeros((batch_size, 3), requires_grad=True)
sel_euler.data = euler_angle[sel_ids].clone()
sel_trans = trans.new_zeros((batch_size, 3), requires_grad=True)
sel_trans.data = trans[sel_ids].clone()
sel_light = light_para.new_zeros((batch_size, 27), requires_grad=True)
sel_light.data = light_para[sel_ids].clone()
set_requires_grad([sel_exp_para, sel_euler, sel_trans, sel_light])
optimizer_cur_batch = torch.optim.Adam(
[sel_exp_para, sel_euler, sel_trans, sel_light], lr=0.005
)
sel_id_para = id_para.expand(batch_size, -1).detach()
sel_tex_para = tex_para.expand(batch_size, -1).detach()
pre_num = 5
if i > 0:
pre_ids = np.arange(start_n - pre_num, start_n)
for iter in range(50):
geometry = model_3dmm.get_3dlandmarks(
sel_id_para, sel_exp_para, sel_euler, sel_trans, focal_length, cxy
)
proj_geo = forward_transform(geometry, sel_euler, sel_trans, focal_length, cxy)
loss_lan = cal_lan_loss(proj_geo[:, :, :2], sel_lms.detach())
loss_regexp = torch.mean(sel_exp_para * sel_exp_para)
sel_geometry = model_3dmm.forward_geo(sel_id_para, sel_exp_para)
sel_texture = model_3dmm.forward_tex(sel_tex_para)
geometry = model_3dmm.forward_geo(sel_id_para, sel_exp_para)
rott_geo = forward_rott(geometry, sel_euler, sel_trans)
render_imgs = renderer(
rott_geo.to(device_render),
sel_texture.to(device_render),
sel_light.to(device_render),
)
render_imgs = render_imgs.to(device_default)
mask = (render_imgs[:, :, :, 3]).detach() > 0.0
loss_col = cal_col_loss(render_imgs[:, :, :, :3], sel_imgs.float(), mask)
if i > 0:
geometry_lap = model_3dmm.forward_geo_sub(
id_para.expand(batch_size + pre_num, -1).detach(),
torch.cat((exp_para[pre_ids].detach(), sel_exp_para)),
model_3dmm.rigid_ids,
)
rott_geo_lap = forward_rott(
geometry_lap,
torch.cat((euler_angle[pre_ids].detach(), sel_euler)),
torch.cat((trans[pre_ids].detach(), sel_trans)),
)
loss_lap = cal_lap_loss(
[rott_geo_lap.reshape(rott_geo_lap.shape[0], -1).permute(1, 0)], [1.0]
)
else:
geometry_lap = model_3dmm.forward_geo_sub(
id_para.expand(batch_size, -1).detach(),
sel_exp_para,
model_3dmm.rigid_ids,
)
rott_geo_lap = forward_rott(geometry_lap, sel_euler, sel_trans)
loss_lap = cal_lap_loss(
[rott_geo_lap.reshape(rott_geo_lap.shape[0], -1).permute(1, 0)], [1.0]
)
if iter > 30:
loss = loss_col * 0.5 + loss_lan * 1.5 + loss_lap * 100000 + loss_regexp * 1.0
else:
loss = loss_col * 0.5 + loss_lan * 8 + loss_lap * 100000 + loss_regexp * 1.0
optimizer_cur_batch.zero_grad()
loss.backward()
optimizer_cur_batch.step()
# if iter % 10 == 0:
# print(
# i,
# iter,
# loss_col.item(),
# loss_lan.item(),
# loss_lap.item(),
# loss_regexp.item(),
# )
print(str(i) + " of " + str(int((num_frames - 1) / batch_size + 1)) + " done")
render_proj = sel_imgs.clone()
render_proj[mask] = render_imgs[mask][..., :3].byte()
exp_para[sel_ids] = sel_exp_para.clone()
euler_angle[sel_ids] = sel_euler.clone()
trans[sel_ids] = sel_trans.clone()
light_para[sel_ids] = sel_light.clone()
torch.save(
{
"id": id_para.detach().cpu(),
"exp": exp_para.detach().cpu(),
"euler": euler_angle.detach().cpu(),
"trans": trans.detach().cpu(),
"focal": focal_length.detach().cpu(),
},
os.path.join(os.path.dirname(args.path), "track_params.pt"),
)
print("params saved")
... ...
import torch
import torch.nn as nn
import numpy as np
import os
from util import *
class Face_3DMM(nn.Module):
def __init__(self, modelpath, id_dim, exp_dim, tex_dim, point_num):
super(Face_3DMM, self).__init__()
# id_dim = 100
# exp_dim = 79
# tex_dim = 100
self.point_num = point_num
DMM_info = np.load(
os.path.join(modelpath, "3DMM_info.npy"), allow_pickle=True
).item()
base_id = DMM_info["b_shape"][:id_dim, :]
mu_id = DMM_info["mu_shape"]
base_exp = DMM_info["b_exp"][:exp_dim, :]
mu_exp = DMM_info["mu_exp"]
mu = mu_id + mu_exp
mu = mu.reshape(-1, 3)
for i in range(3):
mu[:, i] -= np.mean(mu[:, i])
mu = mu.reshape(-1)
self.base_id = torch.as_tensor(base_id).cuda() / 100000.0
self.base_exp = torch.as_tensor(base_exp).cuda() / 100000.0
self.mu = torch.as_tensor(mu).cuda() / 100000.0
base_tex = DMM_info["b_tex"][:tex_dim, :]
mu_tex = DMM_info["mu_tex"]
self.base_tex = torch.as_tensor(base_tex).cuda()
self.mu_tex = torch.as_tensor(mu_tex).cuda()
sig_id = DMM_info["sig_shape"][:id_dim]
sig_tex = DMM_info["sig_tex"][:tex_dim]
sig_exp = DMM_info["sig_exp"][:exp_dim]
self.sig_id = torch.as_tensor(sig_id).cuda()
self.sig_tex = torch.as_tensor(sig_tex).cuda()
self.sig_exp = torch.as_tensor(sig_exp).cuda()
keys_info = np.load(
os.path.join(modelpath, "keys_info.npy"), allow_pickle=True
).item()
self.keyinds = torch.as_tensor(keys_info["keyinds"]).cuda()
self.left_contours = torch.as_tensor(keys_info["left_contour"]).cuda()
self.right_contours = torch.as_tensor(keys_info["right_contour"]).cuda()
self.rigid_ids = torch.as_tensor(keys_info["rigid_ids"]).cuda()
def get_3dlandmarks(self, id_para, exp_para, euler_angle, trans, focal_length, cxy):
id_para = id_para * self.sig_id
exp_para = exp_para * self.sig_exp
batch_size = id_para.shape[0]
num_per_contour = self.left_contours.shape[1]
left_contours_flat = self.left_contours.reshape(-1)
right_contours_flat = self.right_contours.reshape(-1)
sel_index = torch.cat(
(
3 * left_contours_flat.unsqueeze(1),
3 * left_contours_flat.unsqueeze(1) + 1,
3 * left_contours_flat.unsqueeze(1) + 2,
),
dim=1,
).reshape(-1)
left_geometry = (
torch.mm(id_para, self.base_id[:, sel_index])
+ torch.mm(exp_para, self.base_exp[:, sel_index])
+ self.mu[sel_index]
)
left_geometry = left_geometry.view(batch_size, -1, 3)
proj_x = forward_transform(
left_geometry, euler_angle, trans, focal_length, cxy
)[:, :, 0]
proj_x = proj_x.reshape(batch_size, 8, num_per_contour)
arg_min = proj_x.argmin(dim=2)
left_geometry = left_geometry.view(batch_size * 8, num_per_contour, 3)
left_3dlands = left_geometry[
torch.arange(batch_size * 8), arg_min.view(-1), :
].view(batch_size, 8, 3)
sel_index = torch.cat(
(
3 * right_contours_flat.unsqueeze(1),
3 * right_contours_flat.unsqueeze(1) + 1,
3 * right_contours_flat.unsqueeze(1) + 2,
),
dim=1,
).reshape(-1)
right_geometry = (
torch.mm(id_para, self.base_id[:, sel_index])
+ torch.mm(exp_para, self.base_exp[:, sel_index])
+ self.mu[sel_index]
)
right_geometry = right_geometry.view(batch_size, -1, 3)
proj_x = forward_transform(
right_geometry, euler_angle, trans, focal_length, cxy
)[:, :, 0]
proj_x = proj_x.reshape(batch_size, 8, num_per_contour)
arg_max = proj_x.argmax(dim=2)
right_geometry = right_geometry.view(batch_size * 8, num_per_contour, 3)
right_3dlands = right_geometry[
torch.arange(batch_size * 8), arg_max.view(-1), :
].view(batch_size, 8, 3)
sel_index = torch.cat(
(
3 * self.keyinds.unsqueeze(1),
3 * self.keyinds.unsqueeze(1) + 1,
3 * self.keyinds.unsqueeze(1) + 2,
),
dim=1,
).reshape(-1)
geometry = (
torch.mm(id_para, self.base_id[:, sel_index])
+ torch.mm(exp_para, self.base_exp[:, sel_index])
+ self.mu[sel_index]
)
lands_3d = geometry.view(-1, self.keyinds.shape[0], 3)
lands_3d[:, :8, :] = left_3dlands
lands_3d[:, 9:17, :] = right_3dlands
return lands_3d
def forward_geo_sub(self, id_para, exp_para, sub_index):
id_para = id_para * self.sig_id
exp_para = exp_para * self.sig_exp
sel_index = torch.cat(
(
3 * sub_index.unsqueeze(1),
3 * sub_index.unsqueeze(1) + 1,
3 * sub_index.unsqueeze(1) + 2,
),
dim=1,
).reshape(-1)
geometry = (
torch.mm(id_para, self.base_id[:, sel_index])
+ torch.mm(exp_para, self.base_exp[:, sel_index])
+ self.mu[sel_index]
)
return geometry.reshape(-1, sub_index.shape[0], 3)
def forward_geo(self, id_para, exp_para):
id_para = id_para * self.sig_id
exp_para = exp_para * self.sig_exp
geometry = (
torch.mm(id_para, self.base_id)
+ torch.mm(exp_para, self.base_exp)
+ self.mu
)
return geometry.reshape(-1, self.point_num, 3)
def forward_tex(self, tex_para):
tex_para = tex_para * self.sig_tex
texture = torch.mm(tex_para, self.base_tex) + self.mu_tex
return texture.reshape(-1, self.point_num, 3)
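# Usage sketch (dimensions follow face_tracker.py; requires CUDA and the 3DMM
# assets under data_utils/face_tracking/3DMM):
#   model = Face_3DMM("data_utils/face_tracking/3DMM", 100, 79, 100, 34650)
#   geometry = model.forward_geo(id_para, exp_para)       # (B, 34650, 3)
#   texture = model.forward_tex(tex_para)                 # (B, 34650, 3)
#   lands3d = model.get_3dlandmarks(id_para, exp_para, euler_angle, trans,
#                                   focal_length, cxy)    # (B, n_keypoints, 3)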
... ...
"""This module contains functions for geometry transform and camera projection"""
import torch
import torch.nn as nn
import numpy as np
def euler2rot(euler_angle):
batch_size = euler_angle.shape[0]
theta = euler_angle[:, 0].reshape(-1, 1, 1)
phi = euler_angle[:, 1].reshape(-1, 1, 1)
psi = euler_angle[:, 2].reshape(-1, 1, 1)
one = torch.ones((batch_size, 1, 1), dtype=torch.float32, device=euler_angle.device)
zero = torch.zeros(
(batch_size, 1, 1), dtype=torch.float32, device=euler_angle.device
)
rot_x = torch.cat(
(
torch.cat((one, zero, zero), 1),
torch.cat((zero, theta.cos(), theta.sin()), 1),
torch.cat((zero, -theta.sin(), theta.cos()), 1),
),
2,
)
rot_y = torch.cat(
(
torch.cat((phi.cos(), zero, -phi.sin()), 1),
torch.cat((zero, one, zero), 1),
torch.cat((phi.sin(), zero, phi.cos()), 1),
),
2,
)
rot_z = torch.cat(
(
torch.cat((psi.cos(), -psi.sin(), zero), 1),
torch.cat((psi.sin(), psi.cos(), zero), 1),
torch.cat((zero, zero, one), 1),
),
2,
)
return torch.bmm(rot_x, torch.bmm(rot_y, rot_z))
def rot_trans_geo(geometry, rot, trans):
rott_geo = torch.bmm(rot, geometry.permute(0, 2, 1)) + trans.view(-1, 3, 1)
return rott_geo.permute(0, 2, 1)
def euler_trans_geo(geometry, euler, trans):
rot = euler2rot(euler)
return rot_trans_geo(geometry, rot, trans)
def proj_geo(rott_geo, camera_para):
fx = camera_para[:, 0]
fy = camera_para[:, 0]
cx = camera_para[:, 1]
cy = camera_para[:, 2]
X = rott_geo[:, :, 0]
Y = rott_geo[:, :, 1]
Z = rott_geo[:, :, 2]
fxX = fx[:, None] * X
fyY = fy[:, None] * Y
proj_x = -fxX / Z + cx[:, None]
proj_y = fyY / Z + cy[:, None]
return torch.cat((proj_x[:, :, None], proj_y[:, :, None], Z[:, :, None]), 2)
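if __name__ == "__main__":
    # Minimal self-check sketch with arbitrary values: euler2rot() should
    # return orthonormal rotation matrices, and proj_geo() maps (B, N, 3)
    # camera-space points to (B, N, 3) = (pixel x, pixel y, depth).
    euler = torch.zeros(2, 3)
    rot = euler2rot(euler)
    print(torch.allclose(torch.bmm(rot, rot.transpose(1, 2)),
                         torch.eye(3).expand(2, 3, 3), atol=1e-5))
    pts = torch.randn(2, 5, 3) - torch.tensor([0.0, 0.0, 10.0])
    cam = torch.tensor([[1200.0, 256.0, 256.0],
                        [1200.0, 256.0, 256.0]])
    print(proj_geo(rot_trans_geo(pts, rot, torch.zeros(2, 3)), cam).shape)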
... ...
import torch
import torch.nn as nn
import numpy as np
import os
from pytorch3d.structures import Meshes
from pytorch3d.renderer import (
look_at_view_transform,
PerspectiveCameras,
FoVPerspectiveCameras,
PointLights,
DirectionalLights,
Materials,
RasterizationSettings,
MeshRenderer,
MeshRasterizer,
SoftPhongShader,
TexturesUV,
TexturesVertex,
blending,
)
from pytorch3d.ops import interpolate_face_attributes
from pytorch3d.renderer.blending import (
BlendParams,
hard_rgb_blend,
sigmoid_alpha_blend,
softmax_rgb_blend,
)
class SoftSimpleShader(nn.Module):
"""
Per pixel lighting - the lighting model is applied using the interpolated
coordinates and normals for each pixel. The blending function returns the
soft aggregated color using all the faces per pixel.
    To use the default values, simply initialize the shader with the desired
    device, e.g. SoftSimpleShader(device="cuda").
    """
def __init__(
self, device="cpu", cameras=None, lights=None, materials=None, blend_params=None
):
super().__init__()
self.lights = lights if lights is not None else PointLights(device=device)
self.materials = (
materials if materials is not None else Materials(device=device)
)
self.cameras = cameras
self.blend_params = blend_params if blend_params is not None else BlendParams()
def to(self, device):
# Manually move to device modules which are not subclasses of nn.Module
self.cameras = self.cameras.to(device)
self.materials = self.materials.to(device)
self.lights = self.lights.to(device)
return self
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
texels = meshes.sample_textures(fragments)
blend_params = kwargs.get("blend_params", self.blend_params)
cameras = kwargs.get("cameras", self.cameras)
if cameras is None:
msg = "Cameras must be specified either at initialization \
or in the forward pass of SoftPhongShader"
raise ValueError(msg)
znear = kwargs.get("znear", getattr(cameras, "znear", 1.0))
zfar = kwargs.get("zfar", getattr(cameras, "zfar", 100.0))
images = softmax_rgb_blend(
texels, fragments, blend_params, znear=znear, zfar=zfar
)
return images
class Render_3DMM(nn.Module):
def __init__(
self,
focal=1015,
img_h=500,
img_w=500,
batch_size=1,
device=torch.device("cuda:0"),
):
super(Render_3DMM, self).__init__()
self.focal = focal
self.img_h = img_h
self.img_w = img_w
self.device = device
self.renderer = self.get_render(batch_size)
dir_path = os.path.dirname(os.path.realpath(__file__))
topo_info = np.load(
os.path.join(dir_path, "3DMM", "topology_info.npy"), allow_pickle=True
).item()
self.tris = torch.as_tensor(topo_info["tris"]).to(self.device)
self.vert_tris = torch.as_tensor(topo_info["vert_tris"]).to(self.device)
def compute_normal(self, geometry):
vert_1 = torch.index_select(geometry, 1, self.tris[:, 0])
vert_2 = torch.index_select(geometry, 1, self.tris[:, 1])
vert_3 = torch.index_select(geometry, 1, self.tris[:, 2])
nnorm = torch.cross(vert_2 - vert_1, vert_3 - vert_1, 2)
tri_normal = nn.functional.normalize(nnorm, dim=2)
v_norm = tri_normal[:, self.vert_tris, :].sum(2)
vert_normal = v_norm / v_norm.norm(dim=2).unsqueeze(2)
return vert_normal
def get_render(self, batch_size=1):
half_s = self.img_w * 0.5
R, T = look_at_view_transform(10, 0, 0)
R = R.repeat(batch_size, 1, 1)
T = torch.zeros((batch_size, 3), dtype=torch.float32).to(self.device)
cameras = FoVPerspectiveCameras(
device=self.device,
R=R,
T=T,
znear=0.01,
zfar=20,
fov=2 * np.arctan(self.img_w // 2 / self.focal) * 180.0 / np.pi,
)
lights = PointLights(
device=self.device,
location=[[0.0, 0.0, 1e5]],
ambient_color=[[1, 1, 1]],
specular_color=[[0.0, 0.0, 0.0]],
diffuse_color=[[0.0, 0.0, 0.0]],
)
sigma = 1e-4
raster_settings = RasterizationSettings(
image_size=(self.img_h, self.img_w),
blur_radius=np.log(1.0 / 1e-4 - 1.0) * sigma / 18.0,
faces_per_pixel=2,
perspective_correct=False,
)
blend_params = blending.BlendParams(background_color=[0, 0, 0])
renderer = MeshRenderer(
rasterizer=MeshRasterizer(raster_settings=raster_settings, cameras=cameras),
shader=SoftSimpleShader(
lights=lights, blend_params=blend_params, cameras=cameras
),
)
return renderer.to(self.device)
@staticmethod
def Illumination_layer(face_texture, norm, gamma):
n_b, num_vertex, _ = face_texture.size()
n_v_full = n_b * num_vertex
gamma = gamma.view(-1, 3, 9).clone()
gamma[:, :, 0] += 0.8
gamma = gamma.permute(0, 2, 1)
a0 = np.pi
a1 = 2 * np.pi / np.sqrt(3.0)
a2 = 2 * np.pi / np.sqrt(8.0)
c0 = 1 / np.sqrt(4 * np.pi)
c1 = np.sqrt(3.0) / np.sqrt(4 * np.pi)
c2 = 3 * np.sqrt(5.0) / np.sqrt(12 * np.pi)
d0 = 0.5 / np.sqrt(3.0)
Y0 = torch.ones(n_v_full).to(gamma.device).float() * a0 * c0
norm = norm.view(-1, 3)
nx, ny, nz = norm[:, 0], norm[:, 1], norm[:, 2]
arrH = []
arrH.append(Y0)
arrH.append(-a1 * c1 * ny)
arrH.append(a1 * c1 * nz)
arrH.append(-a1 * c1 * nx)
arrH.append(a2 * c2 * nx * ny)
arrH.append(-a2 * c2 * ny * nz)
arrH.append(a2 * c2 * d0 * (3 * nz.pow(2) - 1))
arrH.append(-a2 * c2 * nx * nz)
arrH.append(a2 * c2 * 0.5 * (nx.pow(2) - ny.pow(2)))
H = torch.stack(arrH, 1)
Y = H.view(n_b, num_vertex, 9)
lighting = Y.bmm(gamma)
face_color = face_texture * lighting
return face_color
def forward(self, rott_geometry, texture, diffuse_sh):
face_normal = self.compute_normal(rott_geometry)
face_color = self.Illumination_layer(texture, face_normal, diffuse_sh)
face_color = TexturesVertex(face_color)
mesh = Meshes(
rott_geometry,
self.tris.float().repeat(rott_geometry.shape[0], 1, 1),
face_color,
)
rendered_img = self.renderer(mesh)
rendered_img = torch.clamp(rendered_img, 0, 255)
return rendered_img
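# Usage sketch (requires CUDA, pytorch3d and 3DMM/topology_info.npy; shapes
# follow face_tracker.py):
#   renderer = Render_3DMM(focal=1200, img_h=512, img_w=512, batch_size=4,
#                          device=torch.device("cuda:0"))
#   # rott_geo: (4, point_num, 3)  texture: (4, point_num, 3)  light: (4, 27)
#   imgs = renderer(rott_geo, texture, light)   # (4, 512, 512, 4) RGBA, 0..255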
... ...
import torch
import torch.nn as nn
import render_util
import geo_transform
import numpy as np
def compute_tri_normal(geometry, tris):
geometry = geometry.permute(0, 2, 1)
tri_1 = tris[:, 0]
tri_2 = tris[:, 1]
tri_3 = tris[:, 2]
vert_1 = torch.index_select(geometry, 2, tri_1)
vert_2 = torch.index_select(geometry, 2, tri_2)
vert_3 = torch.index_select(geometry, 2, tri_3)
nnorm = torch.cross(vert_2 - vert_1, vert_3 - vert_1, 1)
normal = nn.functional.normalize(nnorm).permute(0, 2, 1)
return normal
class Compute_normal_base(torch.autograd.Function):
@staticmethod
def forward(ctx, normal):
(normal_b,) = render_util.normal_base_forward(normal)
ctx.save_for_backward(normal)
return normal_b
@staticmethod
def backward(ctx, grad_normal_b):
(normal,) = ctx.saved_tensors
(grad_normal,) = render_util.normal_base_backward(grad_normal_b, normal)
return grad_normal
class Normal_Base(torch.nn.Module):
def __init__(self):
super(Normal_Base, self).__init__()
def forward(self, normal):
return Compute_normal_base.apply(normal)
def preprocess_render(geometry, euler, trans, cam, tris, vert_tris, ori_img):
point_num = geometry.shape[1]
rott_geo = geo_transform.euler_trans_geo(geometry, euler, trans)
proj_geo = geo_transform.proj_geo(rott_geo, cam)
rot_tri_normal = compute_tri_normal(rott_geo, tris)
rot_vert_normal = torch.index_select(rot_tri_normal, 1, vert_tris)
is_visible = -torch.bmm(
rot_vert_normal.reshape(-1, 1, 3),
nn.functional.normalize(rott_geo.reshape(-1, 3, 1)),
).reshape(-1, point_num)
is_visible[is_visible < 0.01] = -1
pixel_valid = torch.zeros(
(ori_img.shape[0], ori_img.shape[1] * ori_img.shape[2]),
dtype=torch.float32,
device=ori_img.device,
)
return rott_geo, proj_geo, rot_tri_normal, is_visible, pixel_valid
class Render_Face(torch.autograd.Function):
@staticmethod
def forward(
ctx, proj_geo, texture, nbl, ori_img, is_visible, tri_inds, pixel_valid
):
batch_size, h, w, _ = ori_img.shape
ori_img = ori_img.view(batch_size, -1, 3)
ori_size = torch.cat(
(
torch.ones((batch_size, 1), dtype=torch.int32, device=ori_img.device)
* h,
torch.ones((batch_size, 1), dtype=torch.int32, device=ori_img.device)
* w,
),
dim=1,
).view(-1)
tri_index, tri_coord, render, real = render_util.render_face_forward(
proj_geo, ori_img, ori_size, texture, nbl, is_visible, tri_inds, pixel_valid
)
ctx.save_for_backward(
ori_img, ori_size, proj_geo, texture, nbl, tri_inds, tri_index, tri_coord
)
return render, real
@staticmethod
def backward(ctx, grad_render, grad_real):
(
ori_img,
ori_size,
proj_geo,
texture,
nbl,
tri_inds,
tri_index,
tri_coord,
) = ctx.saved_tensors
grad_proj_geo, grad_texture, grad_nbl = render_util.render_face_backward(
grad_render,
grad_real,
ori_img,
ori_size,
proj_geo,
texture,
nbl,
tri_inds,
tri_index,
tri_coord,
)
return grad_proj_geo, grad_texture, grad_nbl, None, None, None, None
class Render_RGB(nn.Module):
def __init__(self):
super(Render_RGB, self).__init__()
def forward(
self, proj_geo, texture, nbl, ori_img, is_visible, tri_inds, pixel_valid
):
return Render_Face.apply(
proj_geo, texture, nbl, ori_img, is_visible, tri_inds, pixel_valid
)
def cal_land(proj_geo, is_visible, lands_info, land_num):
(land_index,) = render_util.update_contour(lands_info, is_visible, land_num)
proj_land = torch.index_select(proj_geo.reshape(-1, 3), 0, land_index)[
:, :2
].reshape(-1, land_num, 2)
return proj_land
class Render_Land(nn.Module):
def __init__(self):
super(Render_Land, self).__init__()
lands_info = np.loadtxt("../data/3DMM/lands_info.txt", dtype=np.int32)
self.lands_info = torch.as_tensor(lands_info).cuda()
tris = np.loadtxt("../data/3DMM/tris.txt", dtype=np.int64)
self.tris = torch.as_tensor(tris).cuda() - 1
vert_tris = np.loadtxt("../data/3DMM/vert_tris.txt", dtype=np.int64)
self.vert_tris = torch.as_tensor(vert_tris).cuda()
self.normal_baser = Normal_Base().cuda()
self.renderer = Render_RGB().cuda()
def render_mesh(self, geometry, euler, trans, cam, ori_img, light):
batch_size, h, w, _ = ori_img.shape
ori_img = ori_img.view(batch_size, -1, 3)
ori_size = torch.cat(
(
torch.ones((batch_size, 1), dtype=torch.int32, device=ori_img.device)
* h,
torch.ones((batch_size, 1), dtype=torch.int32, device=ori_img.device)
* w,
),
dim=1,
).view(-1)
rott_geo, proj_geo, rot_tri_normal, _, _ = preprocess_render(
geometry, euler, trans, cam, self.tris, self.vert_tris, ori_img
)
tri_nb = self.normal_baser(rot_tri_normal.contiguous())
nbl = torch.bmm(
tri_nb, (light.reshape(-1, 9, 3))[:, :, 0].unsqueeze(-1).repeat(1, 1, 3)
)
texture = torch.ones_like(geometry) * 200
(render,) = render_util.render_mesh(
proj_geo, ori_img, ori_size, texture, nbl, self.tris
)
return render.view(batch_size, h, w, 3).byte()
def cal_loss_rgb(self, geometry, euler, trans, cam, ori_img, light, texture, lands):
rott_geo, proj_geo, rot_tri_normal, is_visible, pixel_valid = preprocess_render(
geometry, euler, trans, cam, self.tris, self.vert_tris, ori_img
)
tri_nb = self.normal_baser(rot_tri_normal.contiguous())
nbl = torch.bmm(tri_nb, light.reshape(-1, 9, 3))
render, real = self.renderer(
proj_geo, texture, nbl, ori_img, is_visible, self.tris, pixel_valid
)
proj_land = cal_land(proj_geo, is_visible, self.lands_info, lands.shape[1])
col_minus = torch.norm((render - real).reshape(-1, 3), dim=1).reshape(
ori_img.shape[0], -1
)
col_dis = torch.mean(col_minus * pixel_valid) / (
torch.mean(pixel_valid) + 0.00001
)
land_dists = torch.norm((proj_land - lands).reshape(-1, 2), dim=1).reshape(
ori_img.shape[0], -1
)
lan_dis = torch.mean(land_dists)
return col_dis, lan_dis
... ...
import torch
import torch.nn as nn
import torch.nn.functional as F
def compute_tri_normal(geometry, tris):
tri_1 = tris[:, 0]
tri_2 = tris[:, 1]
tri_3 = tris[:, 2]
vert_1 = torch.index_select(geometry, 1, tri_1)
vert_2 = torch.index_select(geometry, 1, tri_2)
vert_3 = torch.index_select(geometry, 1, tri_3)
nnorm = torch.cross(vert_2 - vert_1, vert_3 - vert_1, 2)
normal = nn.functional.normalize(nnorm)
return normal
def euler2rot(euler_angle):
batch_size = euler_angle.shape[0]
theta = euler_angle[:, 0].reshape(-1, 1, 1)
phi = euler_angle[:, 1].reshape(-1, 1, 1)
psi = euler_angle[:, 2].reshape(-1, 1, 1)
one = torch.ones(batch_size, 1, 1).to(euler_angle.device)
zero = torch.zeros(batch_size, 1, 1).to(euler_angle.device)
rot_x = torch.cat(
(
torch.cat((one, zero, zero), 1),
torch.cat((zero, theta.cos(), theta.sin()), 1),
torch.cat((zero, -theta.sin(), theta.cos()), 1),
),
2,
)
rot_y = torch.cat(
(
torch.cat((phi.cos(), zero, -phi.sin()), 1),
torch.cat((zero, one, zero), 1),
torch.cat((phi.sin(), zero, phi.cos()), 1),
),
2,
)
rot_z = torch.cat(
(
torch.cat((psi.cos(), -psi.sin(), zero), 1),
torch.cat((psi.sin(), psi.cos(), zero), 1),
torch.cat((zero, zero, one), 1),
),
2,
)
return torch.bmm(rot_x, torch.bmm(rot_y, rot_z))
def rot_trans_pts(geometry, rot, trans):
rott_geo = torch.bmm(rot, geometry.permute(0, 2, 1)) + trans[:, :, None]
return rott_geo.permute(0, 2, 1)
def cal_lap_loss(tensor_list, weight_list):
lap_kernel = (
torch.Tensor((-0.5, 1.0, -0.5))
.unsqueeze(0)
.unsqueeze(0)
.float()
.to(tensor_list[0].device)
)
loss_lap = 0
for i in range(len(tensor_list)):
in_tensor = tensor_list[i]
in_tensor = in_tensor.view(-1, 1, in_tensor.shape[-1])
out_tensor = F.conv1d(in_tensor, lap_kernel)
loss_lap += torch.mean(out_tensor ** 2) * weight_list[i]
return loss_lap
def proj_pts(rott_geo, focal_length, cxy):
cx, cy = cxy[0], cxy[1]
X = rott_geo[:, :, 0]
Y = rott_geo[:, :, 1]
Z = rott_geo[:, :, 2]
fxX = focal_length * X
fyY = focal_length * Y
proj_x = -fxX / Z + cx
proj_y = fyY / Z + cy
return torch.cat((proj_x[:, :, None], proj_y[:, :, None], Z[:, :, None]), 2)
def forward_rott(geometry, euler_angle, trans):
rot = euler2rot(euler_angle)
rott_geo = rot_trans_pts(geometry, rot, trans)
return rott_geo
def forward_transform(geometry, euler_angle, trans, focal_length, cxy):
rot = euler2rot(euler_angle)
rott_geo = rot_trans_pts(geometry, rot, trans)
proj_geo = proj_pts(rott_geo, focal_length, cxy)
return proj_geo
def cal_lan_loss(proj_lan, gt_lan):
return torch.mean((proj_lan - gt_lan) ** 2)
def cal_col_loss(pred_img, gt_img, img_mask):
pred_img = pred_img.float()
# loss = torch.sqrt(torch.sum(torch.square(pred_img - gt_img), 3))*img_mask/255
loss = (torch.sum(torch.square(pred_img - gt_img), 3)) * img_mask / 255
loss = torch.sum(loss, dim=(1, 2)) / torch.sum(img_mask, dim=(1, 2))
loss = torch.mean(loss)
return loss
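if __name__ == "__main__":
    # Minimal self-check sketch (illustrative only; not used by the tracking pipeline):
    # euler2rot should yield orthonormal rotation matrices, and forward_transform should
    # return one (proj_x, proj_y, depth) triple per vertex. Shapes and values are dummies.
    B, N = 2, 5
    euler = torch.randn(B, 3) * 0.1
    rot = euler2rot(euler)  # [B, 3, 3]
    eye = torch.eye(3).expand(B, 3, 3)
    assert torch.allclose(torch.bmm(rot, rot.permute(0, 2, 1)), eye, atol=1e-5)
    geometry = torch.randn(B, N, 3)
    trans = torch.tensor([0.0, 0.0, 5.0]).repeat(B, 1)  # keep points in front of the camera
    proj = forward_transform(geometry, euler, trans, focal_length=1200.0, cxy=(256.0, 256.0))
    assert proj.shape == (B, N, 3)  # (proj_x, proj_y, Z) per vertex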
... ...
import os
import glob
import tqdm
import json
import argparse
import cv2
import numpy as np
def extract_audio(path, out_path, sample_rate=16000):
print(f'[INFO] ===== extract audio from {path} to {out_path} =====')
cmd = f'ffmpeg -i {path} -f wav -ar {sample_rate} {out_path}'
os.system(cmd)
print(f'[INFO] ===== extracted audio =====')
def extract_audio_features(path, mode='wav2vec'):
print(f'[INFO] ===== extract audio labels for {path} =====')
if mode == 'wav2vec':
cmd = f'python nerf/asr.py --wav {path} --save_feats'
else: # deepspeech
cmd = f'python data_utils/deepspeech_features/extract_ds_features.py --input {path}'
os.system(cmd)
print(f'[INFO] ===== extracted audio labels =====')
def extract_images(path, out_path, fps=25):
print(f'[INFO] ===== extract images from {path} to {out_path} =====')
cmd = f'ffmpeg -i {path} -vf fps={fps} -qmin 1 -q:v 1 -start_number 0 {os.path.join(out_path, "%d.jpg")}'
os.system(cmd)
print(f'[INFO] ===== extracted images =====')
def extract_semantics(ori_imgs_dir, parsing_dir):
print(f'[INFO] ===== extract semantics from {ori_imgs_dir} to {parsing_dir} =====')
cmd = f'python data_utils/face_parsing/test.py --respath={parsing_dir} --imgpath={ori_imgs_dir}'
os.system(cmd)
print(f'[INFO] ===== extracted semantics =====')
def extract_landmarks(ori_imgs_dir):
print(f'[INFO] ===== extract face landmarks from {ori_imgs_dir} =====')
import face_alignment
fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=False)
image_paths = glob.glob(os.path.join(ori_imgs_dir, '*.jpg'))
for image_path in tqdm.tqdm(image_paths):
        image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED) # [H, W, 3]
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        preds = fa.get_landmarks(image)
        # fa.get_landmarks returns None when no face is detected; skip such frames
        if preds is not None and len(preds) > 0:
            lands = preds[0].reshape(-1, 2)[:,:2]
            np.savetxt(image_path.replace('jpg', 'lms'), lands, '%f')
del fa
print(f'[INFO] ===== extracted face landmarks =====')
def extract_background(base_dir, ori_imgs_dir):
print(f'[INFO] ===== extract background image from {ori_imgs_dir} =====')
from sklearn.neighbors import NearestNeighbors
image_paths = glob.glob(os.path.join(ori_imgs_dir, '*.jpg'))
# only use 1/20 image_paths
image_paths = image_paths[::20]
# read one image to get H/W
tmp_image = cv2.imread(image_paths[0], cv2.IMREAD_UNCHANGED) # [H, W, 3]
h, w = tmp_image.shape[:2]
# nearest neighbors
all_xys = np.mgrid[0:h, 0:w].reshape(2, -1).transpose()
distss = []
for image_path in tqdm.tqdm(image_paths):
parse_img = cv2.imread(image_path.replace('ori_imgs', 'parsing').replace('.jpg', '.png'))
bg = (parse_img[..., 0] == 255) & (parse_img[..., 1] == 255) & (parse_img[..., 2] == 255)
fg_xys = np.stack(np.nonzero(~bg)).transpose(1, 0)
nbrs = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(fg_xys)
dists, _ = nbrs.kneighbors(all_xys)
distss.append(dists)
distss = np.stack(distss)
max_dist = np.max(distss, 0)
max_id = np.argmax(distss, 0)
bc_pixs = max_dist > 5
bc_pixs_id = np.nonzero(bc_pixs)
bc_ids = max_id[bc_pixs]
imgs = []
num_pixs = distss.shape[1]
for image_path in image_paths:
img = cv2.imread(image_path)
imgs.append(img)
imgs = np.stack(imgs).reshape(-1, num_pixs, 3)
bc_img = np.zeros((h*w, 3), dtype=np.uint8)
bc_img[bc_pixs_id, :] = imgs[bc_ids, bc_pixs_id, :]
bc_img = bc_img.reshape(h, w, 3)
max_dist = max_dist.reshape(h, w)
bc_pixs = max_dist > 5
bg_xys = np.stack(np.nonzero(~bc_pixs)).transpose()
fg_xys = np.stack(np.nonzero(bc_pixs)).transpose()
nbrs = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(fg_xys)
distances, indices = nbrs.kneighbors(bg_xys)
bg_fg_xys = fg_xys[indices[:, 0]]
bc_img[bg_xys[:, 0], bg_xys[:, 1], :] = bc_img[bg_fg_xys[:, 0], bg_fg_xys[:, 1], :]
cv2.imwrite(os.path.join(base_dir, 'bc.jpg'), bc_img)
print(f'[INFO] ===== extracted background image =====')
def extract_torso_and_gt(base_dir, ori_imgs_dir):
print(f'[INFO] ===== extract torso and gt images for {base_dir} =====')
from scipy.ndimage import binary_erosion, binary_dilation
# load bg
bg_image = cv2.imread(os.path.join(base_dir, 'bc.jpg'), cv2.IMREAD_UNCHANGED)
image_paths = glob.glob(os.path.join(ori_imgs_dir, '*.jpg'))
for image_path in tqdm.tqdm(image_paths):
# read ori image
ori_image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED) # [H, W, 3]
# read semantics
seg = cv2.imread(image_path.replace('ori_imgs', 'parsing').replace('.jpg', '.png'))
head_part = (seg[..., 0] == 255) & (seg[..., 1] == 0) & (seg[..., 2] == 0)
neck_part = (seg[..., 0] == 0) & (seg[..., 1] == 255) & (seg[..., 2] == 0)
torso_part = (seg[..., 0] == 0) & (seg[..., 1] == 0) & (seg[..., 2] == 255)
bg_part = (seg[..., 0] == 255) & (seg[..., 1] == 255) & (seg[..., 2] == 255)
# get gt image
gt_image = ori_image.copy()
gt_image[bg_part] = bg_image[bg_part]
cv2.imwrite(image_path.replace('ori_imgs', 'gt_imgs'), gt_image)
# get torso image
torso_image = gt_image.copy() # rgb
torso_image[head_part] = bg_image[head_part]
torso_alpha = 255 * np.ones((gt_image.shape[0], gt_image.shape[1], 1), dtype=np.uint8) # alpha
# torso part "vertical" in-painting...
L = 8 + 1
torso_coords = np.stack(np.nonzero(torso_part), axis=-1) # [M, 2]
        # lexsort: group the 2D coords by column (x) and sort by row (y) within each column,
        # ref: https://stackoverflow.com/questions/2706605/sorting-a-2d-numpy-array-by-multiple-axes
inds = np.lexsort((torso_coords[:, 0], torso_coords[:, 1]))
torso_coords = torso_coords[inds]
# choose the top pixel for each column
u, uid, ucnt = np.unique(torso_coords[:, 1], return_index=True, return_counts=True)
top_torso_coords = torso_coords[uid] # [m, 2]
# only keep top-is-head pixels
top_torso_coords_up = top_torso_coords.copy() - np.array([1, 0])
mask = head_part[tuple(top_torso_coords_up.T)]
if mask.any():
top_torso_coords = top_torso_coords[mask]
# get the color
top_torso_colors = gt_image[tuple(top_torso_coords.T)] # [m, 3]
# construct inpaint coords (vertically up, or minus in x)
inpaint_torso_coords = top_torso_coords[None].repeat(L, 0) # [L, m, 2]
inpaint_offsets = np.stack([-np.arange(L), np.zeros(L, dtype=np.int32)], axis=-1)[:, None] # [L, 1, 2]
inpaint_torso_coords += inpaint_offsets
inpaint_torso_coords = inpaint_torso_coords.reshape(-1, 2) # [Lm, 2]
inpaint_torso_colors = top_torso_colors[None].repeat(L, 0) # [L, m, 3]
darken_scaler = 0.98 ** np.arange(L).reshape(L, 1, 1) # [L, 1, 1]
inpaint_torso_colors = (inpaint_torso_colors * darken_scaler).reshape(-1, 3) # [Lm, 3]
# set color
torso_image[tuple(inpaint_torso_coords.T)] = inpaint_torso_colors
inpaint_torso_mask = np.zeros_like(torso_image[..., 0]).astype(bool)
inpaint_torso_mask[tuple(inpaint_torso_coords.T)] = True
else:
inpaint_torso_mask = None
# neck part "vertical" in-painting...
push_down = 4
L = 48 + push_down + 1
neck_part = binary_dilation(neck_part, structure=np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=bool), iterations=3)
neck_coords = np.stack(np.nonzero(neck_part), axis=-1) # [M, 2]
        # lexsort: group the 2D coords by column (x) and sort by row (y) within each column,
        # ref: https://stackoverflow.com/questions/2706605/sorting-a-2d-numpy-array-by-multiple-axes
inds = np.lexsort((neck_coords[:, 0], neck_coords[:, 1]))
neck_coords = neck_coords[inds]
# choose the top pixel for each column
u, uid, ucnt = np.unique(neck_coords[:, 1], return_index=True, return_counts=True)
top_neck_coords = neck_coords[uid] # [m, 2]
# only keep top-is-head pixels
top_neck_coords_up = top_neck_coords.copy() - np.array([1, 0])
mask = head_part[tuple(top_neck_coords_up.T)]
top_neck_coords = top_neck_coords[mask]
        # push these top pixels down by up to push_down (4) rows to make the neck inpainting look more natural...
offset_down = np.minimum(ucnt[mask] - 1, push_down)
top_neck_coords += np.stack([offset_down, np.zeros_like(offset_down)], axis=-1)
# get the color
top_neck_colors = gt_image[tuple(top_neck_coords.T)] # [m, 3]
# construct inpaint coords (vertically up, or minus in x)
inpaint_neck_coords = top_neck_coords[None].repeat(L, 0) # [L, m, 2]
inpaint_offsets = np.stack([-np.arange(L), np.zeros(L, dtype=np.int32)], axis=-1)[:, None] # [L, 1, 2]
inpaint_neck_coords += inpaint_offsets
inpaint_neck_coords = inpaint_neck_coords.reshape(-1, 2) # [Lm, 2]
inpaint_neck_colors = top_neck_colors[None].repeat(L, 0) # [L, m, 3]
darken_scaler = 0.98 ** np.arange(L).reshape(L, 1, 1) # [L, 1, 1]
inpaint_neck_colors = (inpaint_neck_colors * darken_scaler).reshape(-1, 3) # [Lm, 3]
# set color
torso_image[tuple(inpaint_neck_coords.T)] = inpaint_neck_colors
        # apply blurring to the inpainted area to avoid vertical-line artifacts...
inpaint_mask = np.zeros_like(torso_image[..., 0]).astype(bool)
inpaint_mask[tuple(inpaint_neck_coords.T)] = True
blur_img = torso_image.copy()
blur_img = cv2.GaussianBlur(blur_img, (5, 5), cv2.BORDER_DEFAULT)
torso_image[inpaint_mask] = blur_img[inpaint_mask]
# set mask
mask = (neck_part | torso_part | inpaint_mask)
if inpaint_torso_mask is not None:
mask = mask | inpaint_torso_mask
torso_image[~mask] = 0
torso_alpha[~mask] = 0
cv2.imwrite(image_path.replace('ori_imgs', 'torso_imgs').replace('.jpg', '.png'), np.concatenate([torso_image, torso_alpha], axis=-1))
print(f'[INFO] ===== extracted torso and gt images =====')
def face_tracking(ori_imgs_dir):
print(f'[INFO] ===== perform face tracking =====')
image_paths = glob.glob(os.path.join(ori_imgs_dir, '*.jpg'))
# read one image to get H/W
tmp_image = cv2.imread(image_paths[0], cv2.IMREAD_UNCHANGED) # [H, W, 3]
h, w = tmp_image.shape[:2]
cmd = f'python data_utils/face_tracking/face_tracker.py --path={ori_imgs_dir} --img_h={h} --img_w={w} --frame_num={len(image_paths)}'
os.system(cmd)
print(f'[INFO] ===== finished face tracking =====')
def save_transforms(base_dir, ori_imgs_dir):
print(f'[INFO] ===== save transforms =====')
import torch
image_paths = glob.glob(os.path.join(ori_imgs_dir, '*.jpg'))
# read one image to get H/W
tmp_image = cv2.imread(image_paths[0], cv2.IMREAD_UNCHANGED) # [H, W, 3]
h, w = tmp_image.shape[:2]
params_dict = torch.load(os.path.join(base_dir, 'track_params.pt'))
focal_len = params_dict['focal']
euler_angle = params_dict['euler']
trans = params_dict['trans'] / 10.0
valid_num = euler_angle.shape[0]
def euler2rot(euler_angle):
batch_size = euler_angle.shape[0]
theta = euler_angle[:, 0].reshape(-1, 1, 1)
phi = euler_angle[:, 1].reshape(-1, 1, 1)
psi = euler_angle[:, 2].reshape(-1, 1, 1)
one = torch.ones((batch_size, 1, 1), dtype=torch.float32, device=euler_angle.device)
zero = torch.zeros((batch_size, 1, 1), dtype=torch.float32, device=euler_angle.device)
rot_x = torch.cat((
torch.cat((one, zero, zero), 1),
torch.cat((zero, theta.cos(), theta.sin()), 1),
torch.cat((zero, -theta.sin(), theta.cos()), 1),
), 2)
rot_y = torch.cat((
torch.cat((phi.cos(), zero, -phi.sin()), 1),
torch.cat((zero, one, zero), 1),
torch.cat((phi.sin(), zero, phi.cos()), 1),
), 2)
rot_z = torch.cat((
torch.cat((psi.cos(), -psi.sin(), zero), 1),
torch.cat((psi.sin(), psi.cos(), zero), 1),
torch.cat((zero, zero, one), 1)
), 2)
return torch.bmm(rot_x, torch.bmm(rot_y, rot_z))
# train_val_split = int(valid_num*0.5)
# train_val_split = valid_num - 25 * 20 # take the last 20s as valid set.
train_val_split = int(valid_num * 10 / 11)
train_ids = torch.arange(0, train_val_split)
val_ids = torch.arange(train_val_split, valid_num)
rot = euler2rot(euler_angle)
rot_inv = rot.permute(0, 2, 1)
trans_inv = -torch.bmm(rot_inv, trans.unsqueeze(2))
pose = torch.eye(4, dtype=torch.float32)
save_ids = ['train', 'val']
train_val_ids = [train_ids, val_ids]
mean_z = -float(torch.mean(trans[:, 2]).item())
for split in range(2):
transform_dict = dict()
transform_dict['focal_len'] = float(focal_len[0])
transform_dict['cx'] = float(w/2.0)
transform_dict['cy'] = float(h/2.0)
transform_dict['frames'] = []
ids = train_val_ids[split]
save_id = save_ids[split]
for i in ids:
i = i.item()
frame_dict = dict()
frame_dict['img_id'] = i
frame_dict['aud_id'] = i
pose[:3, :3] = rot_inv[i]
pose[:3, 3] = trans_inv[i, :, 0]
frame_dict['transform_matrix'] = pose.numpy().tolist()
transform_dict['frames'].append(frame_dict)
with open(os.path.join(base_dir, 'transforms_' + save_id + '.json'), 'w') as fp:
json.dump(transform_dict, fp, indent=2, separators=(',', ': '))
print(f'[INFO] ===== finished saving transforms =====')
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('path', type=str, help="path to video file")
parser.add_argument('--task', type=int, default=-1, help="-1 means all")
parser.add_argument('--asr', type=str, default='wav2vec', help="wav2vec or deepspeech")
opt = parser.parse_args()
base_dir = os.path.dirname(opt.path)
wav_path = os.path.join(base_dir, 'aud.wav')
ori_imgs_dir = os.path.join(base_dir, 'ori_imgs')
parsing_dir = os.path.join(base_dir, 'parsing')
gt_imgs_dir = os.path.join(base_dir, 'gt_imgs')
torso_imgs_dir = os.path.join(base_dir, 'torso_imgs')
os.makedirs(ori_imgs_dir, exist_ok=True)
os.makedirs(parsing_dir, exist_ok=True)
os.makedirs(gt_imgs_dir, exist_ok=True)
os.makedirs(torso_imgs_dir, exist_ok=True)
# extract audio
if opt.task == -1 or opt.task == 1:
extract_audio(opt.path, wav_path)
# extract audio features
if opt.task == -1 or opt.task == 2:
extract_audio_features(wav_path, mode=opt.asr)
# extract images
if opt.task == -1 or opt.task == 3:
extract_images(opt.path, ori_imgs_dir)
# face parsing
if opt.task == -1 or opt.task == 4:
extract_semantics(ori_imgs_dir, parsing_dir)
# extract bg
if opt.task == -1 or opt.task == 5:
extract_background(base_dir, ori_imgs_dir)
# extract torso images and gt_images
if opt.task == -1 or opt.task == 6:
extract_torso_and_gt(base_dir, ori_imgs_dir)
# extract face landmarks
if opt.task == -1 or opt.task == 7:
extract_landmarks(ori_imgs_dir)
# face tracking
if opt.task == -1 or opt.task == 8:
face_tracking(ori_imgs_dir)
# save transforms.json
if opt.task == -1 or opt.task == 9:
save_transforms(base_dir, ori_imgs_dir)
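# Usage sketch (script and data paths are assumptions; adjust to the actual repo layout):
#   python data_utils/process.py <video_dir>/<name>.mp4              # run all steps
#   python data_utils/process.py <video_dir>/<name>.mp4 --task 3     # run a single step
# Steps: 1 audio, 2 audio features (wav2vec/deepspeech), 3 frames, 4 face parsing,
# 5 background, 6 torso/gt images, 7 landmarks, 8 face tracking, 9 transforms json.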
... ...
import torch
import torch.nn as nn
import torch.nn.functional as F
def get_encoder(encoding, input_dim=3,
multires=6,
degree=4,
num_levels=16, level_dim=2, base_resolution=16, log2_hashmap_size=19, desired_resolution=2048, align_corners=False,
**kwargs):
if encoding == 'None':
return lambda x, **kwargs: x, input_dim
elif encoding == 'frequency':
from .freqencoder import FreqEncoder
encoder = FreqEncoder(input_dim=input_dim, degree=multires)
elif encoding == 'spherical_harmonics':
from .shencoder import SHEncoder
encoder = SHEncoder(input_dim=input_dim, degree=degree)
elif encoding == 'hashgrid':
from .gridencoder import GridEncoder
encoder = GridEncoder(input_dim=input_dim, num_levels=num_levels, level_dim=level_dim, base_resolution=base_resolution, log2_hashmap_size=log2_hashmap_size, desired_resolution=desired_resolution, gridtype='hash', align_corners=align_corners)
elif encoding == 'tiledgrid':
from .gridencoder import GridEncoder
encoder = GridEncoder(input_dim=input_dim, num_levels=num_levels, level_dim=level_dim, base_resolution=base_resolution, log2_hashmap_size=log2_hashmap_size, desired_resolution=desired_resolution, gridtype='tiled', align_corners=align_corners)
elif encoding == 'ash':
from .ashencoder import AshEncoder
encoder = AshEncoder(input_dim=input_dim, output_dim=16, log2_hashmap_size=log2_hashmap_size, resolution=desired_resolution)
else:
        raise NotImplementedError('Unknown encoding mode, choose from [None, frequency, spherical_harmonics, hashgrid, tiledgrid, ash]')
return encoder, encoder.output_dim
\ No newline at end of file
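# Usage sketch (assumes the gridencoder CUDA extension has been built and a GPU is
# available; numbers reflect the defaults above):
#   encoder, feat_dim = get_encoder('hashgrid', input_dim=3, desired_resolution=2048)
#   x = torch.rand(4096, 3, device='cuda') * 2 - 1   # positions in [-bound, bound]
#   feats = encoder(x, bound=1)                      # [4096, feat_dim], feat_dim = 16 levels * 2 dims = 32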
... ...
from .freq import FreqEncoder
\ No newline at end of file
... ...
import os
from torch.utils.cpp_extension import load
_src_path = os.path.dirname(os.path.abspath(__file__))
nvcc_flags = [
'-O3', '-std=c++17',
'-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', '-allow-unsupported-compiler',
'-use_fast_math'
]
if os.name == "posix":
c_flags = ['-O3', '-std=c++17']
elif os.name == "nt":
c_flags = ['/O2', '/std:c++17']
# find cl.exe
def find_cl_path():
import glob
for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
if paths:
return paths[0]
# If cl.exe is not on path, try to find it.
if os.system("where cl.exe >nul 2>nul") != 0:
cl_path = find_cl_path()
if cl_path is None:
raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
os.environ["PATH"] += ";" + cl_path
_backend = load(name='_freqencoder',
extra_cflags=c_flags,
extra_cuda_cflags=nvcc_flags,
sources=[os.path.join(_src_path, 'src', f) for f in [
'freqencoder.cu',
'bindings.cpp',
]],
)
__all__ = ['_backend']
\ No newline at end of file
... ...
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.cuda.amp import custom_bwd, custom_fwd
try:
import _freqencoder as _backend
except ImportError:
from .backend import _backend
class _freq_encoder(Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float32) # force float32 for better precision
def forward(ctx, inputs, degree, output_dim):
# inputs: [B, input_dim], float
# RETURN: [B, F], float
if not inputs.is_cuda: inputs = inputs.cuda()
inputs = inputs.contiguous()
B, input_dim = inputs.shape # batch size, coord dim
outputs = torch.empty(B, output_dim, dtype=inputs.dtype, device=inputs.device)
_backend.freq_encode_forward(inputs, B, input_dim, degree, output_dim, outputs)
ctx.save_for_backward(inputs, outputs)
ctx.dims = [B, input_dim, degree, output_dim]
return outputs
@staticmethod
#@once_differentiable
@custom_bwd
def backward(ctx, grad):
        # grad: [B, C]
grad = grad.contiguous()
inputs, outputs = ctx.saved_tensors
B, input_dim, degree, output_dim = ctx.dims
grad_inputs = torch.zeros_like(inputs)
_backend.freq_encode_backward(grad, outputs, B, input_dim, degree, output_dim, grad_inputs)
return grad_inputs, None, None
freq_encode = _freq_encoder.apply
class FreqEncoder(nn.Module):
def __init__(self, input_dim=3, degree=4):
super().__init__()
self.input_dim = input_dim
self.degree = degree
self.output_dim = input_dim + input_dim * 2 * degree
def __repr__(self):
return f"FreqEncoder: input_dim={self.input_dim} degree={self.degree} output_dim={self.output_dim}"
def forward(self, inputs, **kwargs):
# inputs: [..., input_dim]
        # return: [..., output_dim]
prefix_shape = list(inputs.shape[:-1])
inputs = inputs.reshape(-1, self.input_dim)
outputs = freq_encode(inputs, self.degree, self.output_dim)
outputs = outputs.reshape(prefix_shape + [self.output_dim])
return outputs
\ No newline at end of file
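# Pure-PyTorch reference of the layout produced by the CUDA kernel (illustrative helper,
# not used by this module): columns are [x, sin(2^0 x), cos(2^0 x), ...,
# sin(2^(degree-1) x), cos(2^(degree-1) x)], each block spanning input_dim columns.
def _freq_encode_reference(x, degree):
    feats = [x]
    for f in range(degree):
        feats.append(torch.sin((2.0 ** f) * x))
        feats.append(torch.cos((2.0 ** f) * x))
    return torch.cat(feats, dim=-1)  # [..., input_dim * (1 + 2 * degree)]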
... ...
import os
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
_src_path = os.path.dirname(os.path.abspath(__file__))
nvcc_flags = [
'-O3', '-std=c++17',
'-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', '-allow-unsupported-compiler',
'-use_fast_math'
]
if os.name == "posix":
c_flags = ['-O3', '-std=c++17']
elif os.name == "nt":
c_flags = ['/O2', '/std:c++17']
# find cl.exe
def find_cl_path():
import glob
for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
if paths:
return paths[0]
# If cl.exe is not on path, try to find it.
if os.system("where cl.exe >nul 2>nul") != 0:
cl_path = find_cl_path()
if cl_path is None:
raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
os.environ["PATH"] += ";" + cl_path
setup(
name='freqencoder', # package name, import this to use python API
ext_modules=[
CUDAExtension(
name='_freqencoder', # extension name, import this to use CUDA API
sources=[os.path.join(_src_path, 'src', f) for f in [
'freqencoder.cu',
'bindings.cpp',
]],
extra_compile_args={
'cxx': c_flags,
'nvcc': nvcc_flags,
}
),
],
cmdclass={
'build_ext': BuildExtension,
}
)
\ No newline at end of file
... ...
#include <torch/extension.h>
#include "freqencoder.h"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("freq_encode_forward", &freq_encode_forward, "freq encode forward (CUDA)");
m.def("freq_encode_backward", &freq_encode_backward, "freq encode backward (CUDA)");
}
\ No newline at end of file
... ...
#include <stdint.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/torch.h>
#include <algorithm>
#include <stdexcept>
#include <cstdio>
#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor")
#define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor")
#define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor")
inline constexpr __device__ float PI() { return 3.141592653589793f; }
template <typename T>
__host__ __device__ T div_round_up(T val, T divisor) {
return (val + divisor - 1) / divisor;
}
// inputs: [B, D]
// outputs: [B, C], C = D + D * deg * 2
__global__ void kernel_freq(
const float * __restrict__ inputs,
uint32_t B, uint32_t D, uint32_t deg, uint32_t C,
float * outputs
) {
// parallel on per-element
const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x;
if (t >= B * C) return;
// get index
const uint32_t b = t / C;
const uint32_t c = t - b * C; // t % C;
// locate
inputs += b * D;
outputs += t;
// write self
if (c < D) {
outputs[0] = inputs[c];
// write freq
} else {
const uint32_t col = c / D - 1;
const uint32_t d = c % D;
const uint32_t freq = col / 2;
const float phase_shift = (col % 2) * (PI() / 2);
outputs[0] = __sinf(scalbnf(inputs[d], freq) + phase_shift);
}
}
// grad: [B, C], C = D + D * deg * 2
// outputs: [B, C]
// grad_inputs: [B, D]
__global__ void kernel_freq_backward(
const float * __restrict__ grad,
const float * __restrict__ outputs,
uint32_t B, uint32_t D, uint32_t deg, uint32_t C,
float * grad_inputs
) {
// parallel on per-element
const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x;
if (t >= B * D) return;
const uint32_t b = t / D;
const uint32_t d = t - b * D; // t % D;
// locate
grad += b * C;
outputs += b * C;
grad_inputs += t;
// register
float result = grad[d];
grad += D;
outputs += D;
for (uint32_t f = 0; f < deg; f++) {
result += scalbnf(1.0f, f) * (grad[d] * outputs[D + d] - grad[D + d] * outputs[d]);
grad += 2 * D;
outputs += 2 * D;
}
// write
grad_inputs[0] = result;
}
void freq_encode_forward(at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor outputs) {
CHECK_CUDA(inputs);
CHECK_CUDA(outputs);
CHECK_CONTIGUOUS(inputs);
CHECK_CONTIGUOUS(outputs);
CHECK_IS_FLOATING(inputs);
CHECK_IS_FLOATING(outputs);
static constexpr uint32_t N_THREADS = 128;
kernel_freq<<<div_round_up(B * C, N_THREADS), N_THREADS>>>(inputs.data_ptr<float>(), B, D, deg, C, outputs.data_ptr<float>());
}
void freq_encode_backward(at::Tensor grad, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor grad_inputs) {
CHECK_CUDA(grad);
CHECK_CUDA(outputs);
CHECK_CUDA(grad_inputs);
CHECK_CONTIGUOUS(grad);
CHECK_CONTIGUOUS(outputs);
CHECK_CONTIGUOUS(grad_inputs);
CHECK_IS_FLOATING(grad);
CHECK_IS_FLOATING(outputs);
CHECK_IS_FLOATING(grad_inputs);
static constexpr uint32_t N_THREADS = 128;
kernel_freq_backward<<<div_round_up(B * D, N_THREADS), N_THREADS>>>(grad.data_ptr<float>(), outputs.data_ptr<float>(), B, D, deg, C, grad_inputs.data_ptr<float>());
}
\ No newline at end of file
... ...
# pragma once
#include <stdint.h>
#include <torch/torch.h>
// _backend.freq_encode_forward(inputs, B, input_dim, degree, output_dim, outputs)
void freq_encode_forward(at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor outputs);
// _backend.freq_encode_backward(grad, outputs, B, input_dim, degree, output_dim, grad_inputs)
void freq_encode_backward(at::Tensor grad, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor grad_inputs);
\ No newline at end of file
... ...
from .grid import GridEncoder
\ No newline at end of file
... ...
import os
from torch.utils.cpp_extension import load
_src_path = os.path.dirname(os.path.abspath(__file__))
nvcc_flags = [
'-O3', '-std=c++17',
'-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
]
if os.name == "posix":
c_flags = ['-O3', '-std=c++17', '-finput-charset=UTF-8']
elif os.name == "nt":
c_flags = ['/O2', '/std:c++17', '/finput-charset=UTF-8']
# find cl.exe
def find_cl_path():
import glob
for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
if paths:
return paths[0]
# If cl.exe is not on path, try to find it.
if os.system("where cl.exe >nul 2>nul") != 0:
cl_path = find_cl_path()
if cl_path is None:
raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
os.environ["PATH"] += ";" + cl_path
_backend = load(name='_grid_encoder',
extra_cflags=c_flags,
extra_cuda_cflags=nvcc_flags,
sources=[os.path.join(_src_path, 'src', f) for f in [
'gridencoder.cu',
'bindings.cpp',
]],
)
__all__ = ['_backend']
\ No newline at end of file
... ...
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.cuda.amp import custom_bwd, custom_fwd
try:
import _gridencoder as _backend
except ImportError:
from .backend import _backend
_gridtype_to_id = {
'hash': 0,
'tiled': 1,
}
class _grid_encode(Function):
@staticmethod
@custom_fwd
def forward(ctx, inputs, embeddings, offsets, per_level_scale, base_resolution, calc_grad_inputs=False, gridtype=0, align_corners=False):
# inputs: [B, D], float in [0, 1]
# embeddings: [sO, C], float
# offsets: [L + 1], int
# RETURN: [B, F], float
inputs = inputs.float().contiguous()
B, D = inputs.shape # batch size, coord dim
L = offsets.shape[0] - 1 # level
C = embeddings.shape[1] # embedding dim for each level
S = np.log2(per_level_scale) # resolution multiplier at each level, apply log2 for later CUDA exp2f
H = base_resolution # base resolution
# manually handle autocast (only use half precision embeddings, inputs must be float for enough precision)
# if C % 2 != 0, force float, since half for atomicAdd is very slow.
if torch.is_autocast_enabled() and C % 2 == 0:
embeddings = embeddings.to(torch.half)
# L first, optimize cache for cuda kernel, but needs an extra permute later
outputs = torch.empty(L, B, C, device=inputs.device, dtype=embeddings.dtype)
if calc_grad_inputs:
dy_dx = torch.empty(B, L * D * C, device=inputs.device, dtype=embeddings.dtype)
else:
dy_dx = None
_backend.grid_encode_forward(inputs, embeddings, offsets, outputs, B, D, C, L, S, H, dy_dx, gridtype, align_corners)
# permute back to [B, L * C]
outputs = outputs.permute(1, 0, 2).reshape(B, L * C)
ctx.save_for_backward(inputs, embeddings, offsets, dy_dx)
ctx.dims = [B, D, C, L, S, H, gridtype]
ctx.align_corners = align_corners
return outputs
@staticmethod
#@once_differentiable
@custom_bwd
def backward(ctx, grad):
inputs, embeddings, offsets, dy_dx = ctx.saved_tensors
B, D, C, L, S, H, gridtype = ctx.dims
align_corners = ctx.align_corners
# grad: [B, L * C] --> [L, B, C]
grad = grad.view(B, L, C).permute(1, 0, 2).contiguous()
grad_embeddings = torch.zeros_like(embeddings)
if dy_dx is not None:
grad_inputs = torch.zeros_like(inputs, dtype=embeddings.dtype)
else:
grad_inputs = None
_backend.grid_encode_backward(grad, inputs, embeddings, offsets, grad_embeddings, B, D, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners)
if dy_dx is not None:
grad_inputs = grad_inputs.to(inputs.dtype)
return grad_inputs, grad_embeddings, None, None, None, None, None, None
grid_encode = _grid_encode.apply
class GridEncoder(nn.Module):
def __init__(self, input_dim=3, num_levels=16, level_dim=2, per_level_scale=2, base_resolution=16, log2_hashmap_size=19, desired_resolution=None, gridtype='hash', align_corners=False):
super().__init__()
        # the finest resolution desired at the last level; if provided, overrides per_level_scale
if desired_resolution is not None:
per_level_scale = np.exp2(np.log2(desired_resolution / base_resolution) / (num_levels - 1))
self.input_dim = input_dim # coord dims, 2 or 3
self.num_levels = num_levels # num levels, each level multiply resolution by 2
self.level_dim = level_dim # encode channels per level
self.per_level_scale = per_level_scale # multiply resolution by this scale at each level.
self.log2_hashmap_size = log2_hashmap_size
self.base_resolution = base_resolution
self.output_dim = num_levels * level_dim
self.gridtype = gridtype
self.gridtype_id = _gridtype_to_id[gridtype] # "tiled" or "hash"
self.align_corners = align_corners
# allocate parameters
offsets = []
offset = 0
self.max_params = 2 ** log2_hashmap_size
for i in range(num_levels):
resolution = int(np.ceil(base_resolution * per_level_scale ** i))
params_in_level = min(self.max_params, (resolution if align_corners else resolution + 1) ** input_dim) # limit max number
params_in_level = int(np.ceil(params_in_level / 8) * 8) # make divisible
offsets.append(offset)
offset += params_in_level
# print(resolution, params_in_level)
offsets.append(offset)
offsets = torch.from_numpy(np.array(offsets, dtype=np.int32))
self.register_buffer('offsets', offsets)
self.n_params = offsets[-1] * level_dim
# parameters
self.embeddings = nn.Parameter(torch.empty(offset, level_dim))
self.reset_parameters()
def reset_parameters(self):
std = 1e-4
self.embeddings.data.uniform_(-std, std)
def __repr__(self):
return f"GridEncoder: input_dim={self.input_dim} num_levels={self.num_levels} level_dim={self.level_dim} resolution={self.base_resolution} -> {int(round(self.base_resolution * self.per_level_scale ** (self.num_levels - 1)))} per_level_scale={self.per_level_scale:.4f} params={tuple(self.embeddings.shape)} gridtype={self.gridtype} align_corners={self.align_corners}"
def forward(self, inputs, bound=1):
# inputs: [..., input_dim], normalized real world positions in [-bound, bound]
# return: [..., num_levels * level_dim]
inputs = (inputs + bound) / (2 * bound) # map to [0, 1]
#print('inputs', inputs.shape, inputs.dtype, inputs.min().item(), inputs.max().item())
prefix_shape = list(inputs.shape[:-1])
inputs = inputs.view(-1, self.input_dim)
outputs = grid_encode(inputs, self.embeddings, self.offsets, self.per_level_scale, self.base_resolution, inputs.requires_grad, self.gridtype_id, self.align_corners)
outputs = outputs.view(prefix_shape + [self.output_dim])
#print('outputs', outputs.shape, outputs.dtype, outputs.min().item(), outputs.max().item())
return outputs
\ No newline at end of file
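# Illustrative helper (not used by GridEncoder): reproduces the per-level resolution
# schedule computed in __init__ above. With the defaults base_resolution=16,
# num_levels=16, desired_resolution=2048, per_level_scale is exp2(log2(2048/16)/15),
# roughly 1.382, giving a geometric schedule of resolutions from 16 up to ~2048.
def _grid_resolution_schedule(base_resolution=16, num_levels=16, desired_resolution=2048):
    per_level_scale = np.exp2(np.log2(desired_resolution / base_resolution) / (num_levels - 1))
    return [int(np.ceil(base_resolution * per_level_scale ** i)) for i in range(num_levels)]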
... ...
import os
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
_src_path = os.path.dirname(os.path.abspath(__file__))
nvcc_flags = [
'-O3', '-std=c++17',
'-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__','-allow-unsupported-compiler',
]
if os.name == "posix":
c_flags = ['-O3', '-std=c++17']
elif os.name == "nt":
c_flags = ['/O2', '/std:c++17']
# find cl.exe
def find_cl_path():
import glob
for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
if paths:
return paths[0]
# If cl.exe is not on path, try to find it.
if os.system("where cl.exe >nul 2>nul") != 0:
cl_path = find_cl_path()
if cl_path is None:
raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
os.environ["PATH"] += ";" + cl_path
setup(
name='gridencoder', # package name, import this to use python API
ext_modules=[
CUDAExtension(
name='_gridencoder', # extension name, import this to use CUDA API
sources=[os.path.join(_src_path, 'src', f) for f in [
'gridencoder.cu',
'bindings.cpp',
]],
extra_compile_args={
'cxx': c_flags,
'nvcc': nvcc_flags,
}
),
],
cmdclass={
'build_ext': BuildExtension,
}
)
\ No newline at end of file
... ...
#include <torch/extension.h>
#include "gridencoder.h"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("grid_encode_forward", &grid_encode_forward, "grid_encode_forward (CUDA)");
m.def("grid_encode_backward", &grid_encode_backward, "grid_encode_backward (CUDA)");
}
\ No newline at end of file
... ...