simple_musetalk.py 13 KB
import argparse
import glob
import json
import os
import pickle
import shutil

import cv2
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
from diffusers import AutoencoderKL
from face_alignment import NetworkSize
from mmpose.apis import inference_topdown, init_model
from mmpose.structures import merge_data_samples
from tqdm import tqdm

try:
    from utils.face_parsing import FaceParsing
except ModuleNotFoundError:
    from musetalk.utils.face_parsing import FaceParsing


def video2imgs(vid_path, save_path, ext='.png', cut_frame=10000000):
    cap = cv2.VideoCapture(vid_path)
    count = 0
    while True:
        if count > cut_frame:
            break
        ret, frame = cap.read()
        if ret:
            cv2.putText(frame, "Ewin", (10, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.3, (128,128,128), 1)
            cv2.imwrite(f"{save_path}/{count:08d}.png", frame)
            count += 1
        else:
            break


def read_imgs(img_list):
    frames = []
    print('reading images...')
    for img_path in tqdm(img_list):
        frame = cv2.imread(img_path)
        frames.append(frame)
    return frames


def get_landmark_and_bbox(img_list, upperbondrange=0):
    frames = read_imgs(img_list)
    batch_size_fa = 1
    batches = [frames[i:i + batch_size_fa] for i in range(0, len(frames), batch_size_fa)]
    coords_list = []
    landmarks = []
    if upperbondrange != 0:
        print('get key_landmark and face bounding boxes with the bbox_shift:', upperbondrange)
    else:
        print('get key_landmark and face bounding boxes with the default value')
    average_range_minus = []
    average_range_plus = []
    coord_placeholder = (0.0, 0.0, 0.0, 0.0)
    for fb in tqdm(batches):
        results = inference_topdown(model, np.asarray(fb)[0])
        results = merge_data_samples(results)
        keypoints = results.pred_instances.keypoints
        face_land_mark = keypoints[0][23:91]
        face_land_mark = face_land_mark.astype(np.int32)

        # get bounding boxes by face detetion
        bbox = fa.get_detections_for_batch(np.asarray(fb))

        # adjust the bounding box refer to landmark
        # Add the bounding box to a tuple and append it to the coordinates list
        for j, f in enumerate(bbox):
            if f is None:  # no face in the image
                coords_list += [coord_placeholder]
                continue

            half_face_coord = face_land_mark[29]  # np.mean([face_land_mark[28], face_land_mark[29]], axis=0)
            range_minus = (face_land_mark[30] - face_land_mark[29])[1]
            range_plus = (face_land_mark[29] - face_land_mark[28])[1]
            average_range_minus.append(range_minus)
            average_range_plus.append(range_plus)
            if upperbondrange != 0:
                half_face_coord[1] = upperbondrange + half_face_coord[1]  # 手动调整  + 向下(偏29)  - 向上(偏28)
            half_face_dist = np.max(face_land_mark[:, 1]) - half_face_coord[1]
            upper_bond = half_face_coord[1] - half_face_dist

            f_landmark = (
                np.min(face_land_mark[:, 0]), int(upper_bond), np.max(face_land_mark[:, 0]),
                np.max(face_land_mark[:, 1]))
            x1, y1, x2, y2 = f_landmark

            if y2 - y1 <= 0 or x2 - x1 <= 0 or x1 < 0:  # if the landmark bbox is not suitable, reuse the bbox
                coords_list += [f]
                w, h = f[2] - f[0], f[3] - f[1]
                print("error bbox:", f)
            else:
                coords_list += [f_landmark]
    return coords_list, frames


class FaceAlignment:
    def __init__(self, landmarks_type, network_size=NetworkSize.LARGE,
                 device='cuda', flip_input=False, face_detector='sfd', verbose=False):
        self.device = device
        self.flip_input = flip_input
        self.landmarks_type = landmarks_type
        self.verbose = verbose

        network_size = int(network_size)
        if 'cuda' in device:
            torch.backends.cudnn.benchmark = True
            #             torch.backends.cuda.matmul.allow_tf32 = False
            #             torch.backends.cudnn.benchmark = True
            #             torch.backends.cudnn.deterministic = False
            #             torch.backends.cudnn.allow_tf32 = True
            print('cuda start')

        # Get the face detector
        face_detector_module = __import__('face_detection.detection.' + face_detector,
                                          globals(), locals(), [face_detector], 0)

        self.face_detector = face_detector_module.FaceDetector(device=device, verbose=verbose)

    def get_detections_for_batch(self, images):
        images = images[..., ::-1]
        detected_faces = self.face_detector.detect_from_batch(images.copy())
        results = []

        for i, d in enumerate(detected_faces):
            if len(d) == 0:
                results.append(None)
                continue
            d = d[0]
            d = np.clip(d, 0, None)

            x1, y1, x2, y2 = map(int, d[:-1])
            results.append((x1, y1, x2, y2))
        return results


def get_mask_tensor():
    """
    Creates a mask tensor for image processing.
    :return: A mask tensor.
    """
    mask_tensor = torch.zeros((256, 256))
    mask_tensor[:256 // 2, :] = 1
    mask_tensor[mask_tensor < 0.5] = 0
    mask_tensor[mask_tensor >= 0.5] = 1
    return mask_tensor


def preprocess_img(img_name, half_mask=False):
    window = []
    if isinstance(img_name, str):
        window_fnames = [img_name]
        for fname in window_fnames:
            img = cv2.imread(fname)
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            img = cv2.resize(img, (256, 256),
                             interpolation=cv2.INTER_LANCZOS4)
            window.append(img)
    else:
        img = cv2.cvtColor(img_name, cv2.COLOR_BGR2RGB)
        window.append(img)
    x = np.asarray(window) / 255.
    x = np.transpose(x, (3, 0, 1, 2))
    x = torch.squeeze(torch.FloatTensor(x))
    if half_mask:
        x = x * (get_mask_tensor() > 0.5)
    normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    x = normalize(x)
    x = x.unsqueeze(0)  # [1, 3, 256, 256] torch tensor
    x = x.to(device)
    return x


def encode_latents(image):
    with torch.no_grad():
        init_latent_dist = vae.encode(image.to(vae.dtype)).latent_dist
    init_latents = vae.config.scaling_factor * init_latent_dist.sample()
    return init_latents


def get_latents_for_unet(img):
    ref_image = preprocess_img(img, half_mask=True)  # [1, 3, 256, 256] RGB, torch tensor
    masked_latents = encode_latents(ref_image)  # [1, 4, 32, 32], torch tensor
    ref_image = preprocess_img(img, half_mask=False)  # [1, 3, 256, 256] RGB, torch tensor
    ref_latents = encode_latents(ref_image)  # [1, 4, 32, 32], torch tensor
    latent_model_input = torch.cat([masked_latents, ref_latents], dim=1)
    return latent_model_input


def get_crop_box(box, expand):
    x, y, x1, y1 = box
    x_c, y_c = (x + x1) // 2, (y + y1) // 2
    w, h = x1 - x, y1 - y
    s = int(max(w, h) // 2 * expand)
    crop_box = [x_c - s, y_c - s, x_c + s, y_c + s]
    return crop_box, s


def face_seg(image):
    seg_image = fp(image)
    if seg_image is None:
        print("error, no person_segment")
        return None

    seg_image = seg_image.resize(image.size)
    return seg_image


def get_image_prepare_material(image, face_box, upper_boundary_ratio=0.5, expand=1.2):
    body = Image.fromarray(image[:, :, ::-1])

    x, y, x1, y1 = face_box
    # print(x1-x,y1-y)
    crop_box, s = get_crop_box(face_box, expand)
    x_s, y_s, x_e, y_e = crop_box

    face_large = body.crop(crop_box)
    ori_shape = face_large.size

    mask_image = face_seg(face_large)
    mask_small = mask_image.crop((x - x_s, y - y_s, x1 - x_s, y1 - y_s))
    mask_image = Image.new('L', ori_shape, 0)
    mask_image.paste(mask_small, (x - x_s, y - y_s, x1 - x_s, y1 - y_s))

    # keep upper_boundary_ratio of talking area
    width, height = mask_image.size
    top_boundary = int(height * upper_boundary_ratio)
    modified_mask_image = Image.new('L', ori_shape, 0)
    modified_mask_image.paste(mask_image.crop((0, top_boundary, width, height)), (0, top_boundary))

    blur_kernel_size = int(0.1 * ori_shape[0] // 2 * 2) + 1
    mask_array = cv2.GaussianBlur(np.array(modified_mask_image), (blur_kernel_size, blur_kernel_size), 0)
    return mask_array, crop_box


##todo 简单根据文件后缀判断  要更精确的可以自己修改 使用 magic
def is_video_file(file_path):
    video_exts = ['.mp4', '.mkv', '.flv', '.avi', '.mov']  # 这里列出了一些常见的视频文件扩展名,可以根据需要添加更多
    file_ext = os.path.splitext(file_path)[1].lower()  # 获取文件扩展名并转换为小写
    return file_ext in video_exts


def create_dir(dir_path):
    if not os.path.exists(dir_path):
        os.makedirs(dir_path)


current_dir = os.path.dirname(os.path.abspath(__file__))


def create_musetalk_human(file, avatar_id):
    # 保存文件设置 可以不动
    save_path = os.path.join(current_dir, f'../data/avatars/avator_{avatar_id}')
    save_full_path = os.path.join(current_dir, f'../data/avatars/avator_{avatar_id}/full_imgs')
    create_dir(save_path)
    create_dir(save_full_path)
    mask_out_path = os.path.join(current_dir, f'../data/avatars/avator_{avatar_id}/mask')
    create_dir(mask_out_path)

    # 模型
    mask_coords_path = os.path.join(current_dir, f'{save_path}/mask_coords.pkl')
    coords_path = os.path.join(current_dir, f'{save_path}/coords.pkl')
    latents_out_path = os.path.join(current_dir, f'{save_path}/latents.pt')

    with open(os.path.join(current_dir, f'{save_path}/avator_info.json'), "w") as f:
        json.dump({
            "avatar_id": avatar_id,
            "video_path": file,
            "bbox_shift": 5
        }, f)

    if os.path.isfile(file):
        if is_video_file(file):
            video2imgs(file, save_full_path, ext='png')
        else:
            shutil.copyfile(file, f"{save_full_path}/{os.path.basename(file)}")
    else:
        files = os.listdir(file)
        files.sort()
        files = [file for file in files if file.split(".")[-1] == "png"]
        for filename in files:
            shutil.copyfile(f"{file}/{filename}", f"{save_full_path}/{filename}")
    input_img_list = sorted(glob.glob(os.path.join(save_full_path, '*.[jpJP][pnPN]*[gG]')))
    print("extracting landmarks...")
    coord_list, frame_list = get_landmark_and_bbox(input_img_list, 5)
    input_latent_list = []
    idx = -1
    # maker if the bbox is not sufficient
    coord_placeholder = (0.0, 0.0, 0.0, 0.0)
    for bbox, frame in zip(coord_list, frame_list):
        idx = idx + 1
        if bbox == coord_placeholder:
            continue
        x1, y1, x2, y2 = bbox
        crop_frame = frame[y1:y2, x1:x2]
        resized_crop_frame = cv2.resize(crop_frame, (256, 256), interpolation=cv2.INTER_LANCZOS4)
        latents = get_latents_for_unet(resized_crop_frame)
        input_latent_list.append(latents)

    frame_list_cycle = frame_list #+ frame_list[::-1]
    coord_list_cycle = coord_list #+ coord_list[::-1]
    input_latent_list_cycle = input_latent_list #+ input_latent_list[::-1]
    mask_coords_list_cycle = []
    mask_list_cycle = []
    for i, frame in enumerate(tqdm(frame_list_cycle)):
        cv2.imwrite(f"{save_full_path}/{str(i).zfill(8)}.png", frame)
        face_box = coord_list_cycle[i]
        mask, crop_box = get_image_prepare_material(frame, face_box)
        cv2.imwrite(f"{mask_out_path}/{str(i).zfill(8)}.png", mask)
        mask_coords_list_cycle += [crop_box]
        mask_list_cycle.append(mask)

    with open(mask_coords_path, 'wb') as f:
        pickle.dump(mask_coords_list_cycle, f)

    with open(coords_path, 'wb') as f:
        pickle.dump(coord_list_cycle, f)
    torch.save(input_latent_list_cycle, os.path.join(latents_out_path))


# initialize the mmpose model
device = "cuda" if torch.cuda.is_available() else ("mps" if (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) else "cpu")
fa = FaceAlignment(1, flip_input=False, device=device)
config_file = os.path.join(current_dir, 'utils/dwpose/rtmpose-l_8xb32-270e_coco-ubody-wholebody-384x288.py')
checkpoint_file = os.path.abspath(os.path.join(current_dir, '../models/dwpose/dw-ll_ucoco_384.pth'))
model = init_model(config_file, checkpoint_file, device=device)
vae = AutoencoderKL.from_pretrained(os.path.abspath(os.path.join(current_dir, '../models/sd-vae-ft-mse')))
vae.to(device)
fp = FaceParsing(os.path.abspath(os.path.join(current_dir, '../models/face-parse-bisent/resnet18-5c106cde.pth')),
                 os.path.abspath(os.path.join(current_dir, '../models/face-parse-bisent/79999_iter.pth')))
if __name__ == '__main__':
    # 视频文件地址
    parser = argparse.ArgumentParser()
    parser.add_argument("--file",
                        type=str,
                        default=r'D:\ok\00000000.png',
                        )
    parser.add_argument("--avatar_id",
                        type=str,
                        default='3',
                        )
    args = parser.parse_args()
    create_musetalk_human(args.file, args.avatar_id)