冯杨

The chat interface can display the text content of the conversation.

For GPU startup and training, the GPU to use can now be selected manually; a brief usage sketch follows.
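A minimal usage sketch, assuming the --gpu flag and the load_model/get_device helpers introduced in the hunks below; the launch command and values are illustrative, not taken from the diff:

# Start the wav2lip pipeline on the second GPU (--gpu is added to app.py below):
#   python app.py --model wav2lip --gpu 1 --transport rtcpush
# app.py then forwards the index to the loader:
#   model = load_model("./models/wav2lip.pth", opt.gpu)   # resolved to "cuda:1" via get_device()
# get_device() falls back to "cuda:0", then "mps", then "cpu" when the requested GPU is absent.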
... ... @@ -407,6 +407,7 @@ if __name__ == '__main__':
    # parser.add_argument('--EMOTION', type=str, default='default')
    parser.add_argument('--model', type=str, default='ernerf') #musetalk wav2lip
    parser.add_argument('--gpu', type=int, default=0, help="Index of the GPU to use, e.g. 0 for the first GPU, 1 for the second")
    parser.add_argument('--transport', type=str, default='rtcpush') #rtmp webrtc rtcpush
    parser.add_argument('--push_url', type=str, default='http://localhost:1985/rtc/v1/whip/?app=live&stream=livestream') #rtmp://localhost/live/livestream
... ... @@ -445,7 +446,7 @@ if __name__ == '__main__':
    elif opt.model == 'wav2lip':
        from lipreal import LipReal,load_model,load_avatar,warm_up
        logger.info(opt)
        model = load_model("./models/wav2lip.pth")
        model = load_model("./models/wav2lip.pth", opt.gpu)
        avatar = load_avatar(opt.avatar_id)
        warm_up(opt.batch_size,model,256)
        # for k in range(opt.max_session):
... ...
... ... @@ -44,8 +44,25 @@ from basereal import BaseReal
from tqdm import tqdm
from logger import logger
device = "cuda" if torch.cuda.is_available() else ("mps" if (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) else "cpu")
print('Using {} for inference.'.format(device))
# Select the GPU device according to the command-line argument
def get_device(gpu_id=0):
    if torch.cuda.is_available():
        if torch.cuda.device_count() > gpu_id:
            torch.cuda.set_device(gpu_id)
            return f"cuda:{gpu_id}"
        else:
            available_gpus = torch.cuda.device_count()
            print(f"The requested GPU {gpu_id} is not available; {available_gpus} GPU(s) detected, falling back to device 0")
            torch.cuda.set_device(0)
            return "cuda:0"
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        return "mps"
    else:
        return "cpu"

# Global variable; used in load_model and other functions
device = None
print('Device will be set when the model is loaded.')
def _load(checkpoint_path):
if device == 'cuda':
... ... @@ -55,7 +72,11 @@ def _load(checkpoint_path):
map_location=lambda storage, loc: storage)
return checkpoint
def load_model(path):
def load_model(path, gpu_id=0):
    global device
    device = get_device(gpu_id)
    logger.info("Using {} for inference.".format(device))

    model = Wav2Lip()
    logger.info("Load checkpoint from: {}".format(path))
    checkpoint = _load(path)
... ...
... ... @@ -4,4 +4,4 @@ __author__ = """Adrian Bulat"""
__email__ = 'adrian.bulat@nottingham.ac.uk'
__version__ = '1.0.1'
from .api import FaceAlignment, LandmarksType, NetworkSize
from .api import FaceAlignment, LandmarksType, NetworkSize, ImageStyle
... ...
... ... @@ -5,11 +5,18 @@ from torch.utils.model_zoo import load_url
from enum import Enum
import numpy as np
import cv2
from .detection.core import FaceDetector
try:
    import urllib.request as request_file
except BaseException:
except ImportError:
    import urllib as request_file

try:
    import dlib
except ImportError:
    dlib = None
from .models import FAN, ResNetDepth
from .utils import *
... ... @@ -27,6 +34,20 @@ class LandmarksType(Enum):
_3D = 3
class ImageStyle(Enum):
    """Enum class defining different image styles for face detection optimization.

    ``REALISTIC`` - Real human faces, standard detection parameters
    ``ANIME`` - Anime/cartoon style faces, optimized for 2D illustrations
    ``ANCIENT`` - Ancient/traditional art style, enhanced for classical paintings
    ``AUTO`` - Automatic style detection based on image characteristics
    """
    REALISTIC = 1
    ANIME = 2
    ANCIENT = 3
    AUTO = 4
class NetworkSize(Enum):
# TINY = 1
# SMALL = 2
... ... @@ -43,14 +64,65 @@ class NetworkSize(Enum):
ROOT = os.path.dirname(os.path.abspath(__file__))
def detect_image_style(image):
    """Automatically detect image style based on visual characteristics.

    Args:
        image: Input image as numpy array

    Returns:
        ImageStyle: Detected style enum
    """
    # Convert to grayscale for analysis
    if len(image.shape) == 3:
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    else:
        gray = image

    # Calculate edge density (anime/cartoon images typically have more defined edges)
    edges = cv2.Canny(gray, 50, 150)
    edge_density = np.sum(edges > 0) / (edges.shape[0] * edges.shape[1])

    # Calculate color saturation (anime images often have higher saturation)
    if len(image.shape) == 3:
        hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
        saturation_mean = np.mean(hsv[:, :, 1])
    else:
        saturation_mean = 0

    # Calculate texture complexity
    laplacian_var = cv2.Laplacian(gray, cv2.CV_64F).var()

    # Style classification logic
    if edge_density > 0.15 and saturation_mean > 100:
        return ImageStyle.ANIME
    elif laplacian_var < 100 and saturation_mean < 80:
        return ImageStyle.ANCIENT
    else:
        return ImageStyle.REALISTIC
class FaceAlignment:
    def __init__(self, landmarks_type, network_size=NetworkSize.LARGE,
                 device='cuda', flip_input=False, face_detector='sfd', verbose=False):
                 device='cuda', flip_input=False, face_detector='sfd', verbose=False,
                 image_style=ImageStyle.AUTO, confidence_threshold=None):
        self.device = device
        self.flip_input = flip_input
        self.landmarks_type = landmarks_type
        self.verbose = verbose
        self.image_style = image_style

        # Style-specific confidence thresholds
        self.style_thresholds = {
            ImageStyle.REALISTIC: 0.5,
            ImageStyle.ANIME: 0.3,     # Lower threshold for anime faces
            ImageStyle.ANCIENT: 0.25,  # Even lower for ancient art
            ImageStyle.AUTO: 0.4       # Balanced default
        }
        self.confidence_threshold = confidence_threshold or self.style_thresholds.get(image_style, 0.4)

        network_size = int(network_size)

        if 'cuda' in device:
... ... @@ -61,19 +133,75 @@ class FaceAlignment:
globals(), locals(), [face_detector], 0)
self.face_detector = face_detector_module.FaceDetector(device=device, verbose=verbose)
    def preprocess_image_by_style(self, image, style):
        """Apply style-specific preprocessing to improve detection.

        Args:
            image: Input image
            style: ImageStyle enum

        Returns:
            Preprocessed image
        """
        processed = image.copy()

        if style == ImageStyle.ANIME:
            # Enhance edges for anime/cartoon faces
            kernel = np.array([[-1,-1,-1], [-1,9,-1], [-1,-1,-1]])
            processed = cv2.filter2D(processed, -1, kernel)
            # Increase contrast
            processed = cv2.convertScaleAbs(processed, alpha=1.2, beta=10)
        elif style == ImageStyle.ANCIENT:
            # Enhance contrast and reduce noise for ancient art
            processed = cv2.convertScaleAbs(processed, alpha=1.3, beta=15)
            # Apply slight gaussian blur to reduce texture noise
            processed = cv2.GaussianBlur(processed, (3, 3), 0.5)

        return processed
    def get_detections_for_batch(self, images):
        images = images[..., ::-1]
        detected_faces = self.face_detector.detect_from_batch(images.copy())

        # Auto-detect style if needed
        if self.image_style == ImageStyle.AUTO and len(images) > 0:
            detected_style = detect_image_style(images[0])
            current_threshold = self.style_thresholds[detected_style]
            if self.verbose:
                print(f"Auto-detected style: {detected_style.name}, using threshold: {current_threshold}")
        else:
            detected_style = self.image_style
            current_threshold = self.confidence_threshold

        # Apply style-specific preprocessing
        processed_images = []
        for img in images:
            processed = self.preprocess_image_by_style(img, detected_style)
            processed_images.append(processed)

        # Convert color format
        processed_images = np.array(processed_images)
        processed_images = processed_images[..., ::-1]

        # Detect faces with original method
        detected_faces = self.face_detector.detect_from_batch(processed_images.copy())

        results = []
        for i, d in enumerate(detected_faces):
            if len(d) == 0:
                results.append(None)
                continue
            d = d[0]
            d = np.clip(d, 0, None)
            x1, y1, x2, y2 = map(int, d[:-1])
            # Filter by style-specific confidence threshold
            valid_detections = [det for det in d if len(det) > 4 and det[-1] > current_threshold]
            if len(valid_detections) == 0:
                results.append(None)
                continue

            # Use the detection with highest confidence
            best_detection = max(valid_detections, key=lambda x: x[-1])
            best_detection = np.clip(best_detection, 0, None)
            x1, y1, x2, y2 = map(int, best_detection[:-1])

            results.append((x1, y1, x2, y2))

        return results
\ No newline at end of file
... ...
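As a quick sanity check of detect_image_style above, a sketch of how it classifies a synthetic image (the import path is an assumption; the thresholds are the ones defined in the hunk above):

import numpy as np
# from face_detection.api import detect_image_style, ImageStyle   # import path assumed

# A uniform gray image: edge density ~0, mean HSV saturation ~0, Laplacian variance ~0.
flat = np.full((256, 256, 3), 120, dtype=np.uint8)
# The ANIME test fails (edge_density <= 0.15), the ANCIENT test passes
# (laplacian_var < 100 and saturation_mean < 80), so
# detect_image_style(flat) is expected to return ImageStyle.ANCIENT.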
... ... @@ -19,10 +19,21 @@ parser.add_argument('--pads', nargs='+', type=int, default=[0, 10, 0, 0],
                    help='Padding (top, bottom, left, right). Please adjust to include chin at least')
parser.add_argument('--face_det_batch_size', type=int,
                    help='Batch size for face detection', default=16)
parser.add_argument('--gpu_id', type=int, default=0,
                    help='GPU device ID to use (default: 0)')
parser.add_argument('--image_style', type=str, default='auto',
                    choices=['auto', 'realistic', 'anime', 'ancient'],
                    help='Image style for face detection optimization (default: auto)')
parser.add_argument('--confidence_threshold', type=float, default=None,
                    help='Custom confidence threshold for face detection (overrides style defaults)')
args = parser.parse_args()
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Using {} for inference.'.format(device))
if torch.cuda.is_available():
    device = f'cuda:{args.gpu_id}'
    print(f'Using GPU {args.gpu_id} for inference.')
else:
    device = 'cpu'
    print('CUDA not available, using CPU for inference.')
def osmakedirs(path_list):
for path in path_list:
... ... @@ -60,8 +71,24 @@ def get_smoothened_boxes(boxes, T):
return boxes
def face_detect(images):
    detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D,
                                            flip_input=False, device=device)
    # Convert style string to enum
    style_map = {
        'auto': face_detection.ImageStyle.AUTO,
        'realistic': face_detection.ImageStyle.REALISTIC,
        'anime': face_detection.ImageStyle.ANIME,
        'ancient': face_detection.ImageStyle.ANCIENT
    }
    image_style = style_map.get(args.image_style, face_detection.ImageStyle.AUTO)

    detector = face_detection.FaceAlignment(
        face_detection.LandmarksType._2D,
        flip_input=False,
        device=device,
        image_style=image_style,
        confidence_threshold=args.confidence_threshold,
        verbose=True
    )

    batch_size = args.face_det_batch_size
... ...
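A small sketch of how the new preprocessing flags are meant to interact, assuming the script is invoked directly (its name is elided in the diff); the threshold resolution mirrors the FaceAlignment constructor above:

# Illustrative invocation (script name elided in the diff):
#   python <script>.py --gpu_id 1 --image_style anime
style_thresholds = {'realistic': 0.5, 'anime': 0.3, 'ancient': 0.25, 'auto': 0.4}

def resolve_threshold(image_style, confidence_threshold=None):
    # Mirrors `confidence_threshold or self.style_thresholds.get(image_style, 0.4)` above.
    return confidence_threshold or style_thresholds.get(image_style, 0.4)

print(resolve_threshold('anime'))        # 0.3 (per-style default)
print(resolve_threshold('anime', 0.45))  # 0.45 (explicit --confidence_threshold wins)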
... ... @@ -221,6 +221,52 @@
border-radius: 8px;
box-shadow: 0 1px 3px rgba(0,0,0,0.05);
}
/* Chat message styles */
#chatOverlay .message {
display: flex;
margin-bottom: 10px;
max-width: 100%;
}
#chatOverlay .message.right {
justify-content: flex-end;
}
#chatOverlay .message.left {
justify-content: flex-start;
}
#chatOverlay .avatar {
width: 30px;
height: 30px;
border-radius: 50%;
margin: 0 5px;
}
#chatOverlay .text-container {
background-color: rgba(255,255,255,0.9);
border-radius: 10px;
padding: 8px 12px;
max-width: 70%;
color: #333;
}
#chatOverlay .message.right .text-container {
background-color: #4285f4;
color: white;
}
#chatOverlay .time {
font-size: 10px;
color: #888;
margin-top: 4px;
text-align: right;
}
#chatOverlay .message.right .time {
color: rgba(255,255,255,0.8);
}
</style>
</head>
<body>
... ... @@ -281,6 +327,10 @@
<audio id="audio" autoplay="true"></audio>
<video id="video" autoplay="true" playsinline="true"></video>
</div>
<!-- Chat message display area -->
<div id="chatOverlay" style="position: absolute; bottom: 20px; right: 20px; width: 300px; max-height: 400px; overflow-y: auto; background-color: rgba(0,0,0,0.7); border-radius: 10px; padding: 10px; color: white; z-index: 1005;">
<!-- Messages are appended here dynamically -->
</div>
</div>
<script src="client.js"></script>
... ... @@ -554,12 +604,18 @@
function addMessage(text, type = "right") {
const chatOverlay = document.getElementById("chatOverlay");
if (!chatOverlay) {
console.error('Chat display area does not exist');
return;
}
const messageDiv = document.createElement("div");
messageDiv.classList.add("message", type);
const avatar = document.createElement("img");
avatar.classList.add("avatar");
avatar.src = type === "right" ? "images/avatar-right.png" : "images/avatar-left.png";
// Use default avatars so that a missing image file does not cause an error
avatar.src = type === "right" ? "data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' viewBox='0 0 24 24'%3E%3Ccircle cx='12' cy='12' r='12' fill='%234285f4'/%3E%3Cpath d='M12 6a3 3 0 1 0 0 6 3 3 0 0 0 0-6zm0 8c-2.67 0-8 1.34-8 4v2h16v-2c0-2.66-5.33-4-8-4z' fill='white'/%3E%3C/svg%3E" : "data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' viewBox='0 0 24 24'%3E%3Ccircle cx='12' cy='12' r='12' fill='%23999'/%3E%3Cpath d='M12 6a3 3 0 1 0 0 6 3 3 0 0 0 0-6zm0 8c-2.67 0-8 1.34-8 4v2h16v-2c0-2.66-5.33-4-8-4z' fill='white'/%3E%3C/svg%3E";
const textContainer = document.createElement("div");
textContainer.classList.add("text-container");
... ... @@ -582,6 +638,9 @@
// Automatically scroll to the bottom
chatOverlay.scrollTop = chatOverlay.scrollHeight;
// Show the chat area (in case it was hidden before)
chatOverlay.style.display = 'block';
}
ws.onmessage = function(e) {
... ... @@ -605,6 +664,47 @@
}else if (messageData.Data.Key == "text") {
var reply = messageData.Data.Value;
addMessage(reply, "left");
// Push "text" messages to the server so the digital-human service can synthesize speech via TTS and play it
fetch('/human', {
body: JSON.stringify({
text: reply,
type: 'echo',
interrupt: true,
sessionid: parseInt(document.getElementById('sessionid').value),
}),
headers: {
'Content-Type': 'application/json'
},
method: 'POST'
});
// If the message is plain text with no audio, use the browser's speech synthesis API for local TTS
// if (!messageData.Data.HttpValue && window.speechSynthesis) {
// console.log('Playing text with local speech synthesis:', reply);
// var utterance = new SpeechSynthesisUtterance(reply);
// utterance.lang = 'zh-CN'; // set language to Chinese
// utterance.rate = 1.0; // speech rate
// utterance.pitch = 1.0; // pitch
// utterance.volume = 1.0; // volume
// speechSynthesis.speak(utterance);
// }
}else if (messageData.Data.Key == "plaintext") {
// Handle the plain-text message type
var textContent = messageData.Data.Value;
console.log('Received plain-text message:', textContent);
addMessage(textContent, "left");
// Use the browser's speech synthesis API for local TTS
if (window.speechSynthesis) {
console.log('Playing text with local speech synthesis:', textContent);
var utterance = new SpeechSynthesisUtterance(textContent);
utterance.lang = 'zh-CN'; // set language to Chinese
utterance.rate = 1.0; // speech rate
utterance.pitch = 1.0; // pitch
utterance.volume = 1.0; // volume
speechSynthesis.speak(utterance);
}
}
}
... ...
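For testing outside the browser, a minimal Python sketch of the same /human request that the handler above sends for "text" messages; the host, port and session id are placeholders, while the payload fields match the fetch call above:

import json
import urllib.request

payload = {
    "text": "Hello",     # the reply text the avatar should speak
    "type": "echo",      # echo mode: the server synthesizes the text via TTS and plays it
    "interrupt": True,   # interrupt any speech currently playing
    "sessionid": 0,      # must match the sessionid of the active session
}
req = urllib.request.Request(
    "http://localhost:8010/human",   # host and port are placeholders
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
    method="POST",
)
urllib.request.urlopen(req)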