冯杨

The chat interface can now display the text of the conversation.

For GPU startup and training, the GPU to use can now be selected manually (via the new --gpu / --gpu_id arguments).
@@ -407,6 +407,7 @@ if __name__ == '__main__':
     # parser.add_argument('--EMOTION', type=str, default='default')

     parser.add_argument('--model', type=str, default='ernerf') #musetalk wav2lip
+    parser.add_argument('--gpu', type=int, default=0, help="GPU index to use, e.g. 0 for the first GPU, 1 for the second")

     parser.add_argument('--transport', type=str, default='rtcpush') #rtmp webrtc rtcpush
     parser.add_argument('--push_url', type=str, default='http://localhost:1985/rtc/v1/whip/?app=live&stream=livestream') #rtmp://localhost/live/livestream
@@ -445,7 +446,7 @@ if __name__ == '__main__':
     elif opt.model == 'wav2lip':
         from lipreal import LipReal,load_model,load_avatar,warm_up
         logger.info(opt)
-        model = load_model("./models/wav2lip.pth")
+        model = load_model("./models/wav2lip.pth", opt.gpu)
         avatar = load_avatar(opt.avatar_id)
         warm_up(opt.batch_size,model,256)
         # for k in range(opt.max_session):
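
With both changes in place, the GPU is chosen at launch time. For example, assuming the entry script is app.py, running "python app.py --model wav2lip --gpu 1" would load the wav2lip model on the second GPU.
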
@@ -44,8 +44,25 @@ from basereal import BaseReal
 from tqdm import tqdm
 from logger import logger

-device = "cuda" if torch.cuda.is_available() else ("mps" if (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) else "cpu")
-print('Using {} for inference.'.format(device))
+# Select the GPU device according to the command-line argument
+def get_device(gpu_id=0):
+    if torch.cuda.is_available():
+        if torch.cuda.device_count() > gpu_id:
+            torch.cuda.set_device(gpu_id)
+            return f"cuda:{gpu_id}"
+        else:
+            available_gpus = torch.cuda.device_count()
+            print(f"Requested GPU {gpu_id} is not available ({available_gpus} GPU(s) detected); falling back to device 0")
+            torch.cuda.set_device(0)
+            return "cuda:0"
+    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
+        return "mps"
+    else:
+        return "cpu"
+
+# Global variable, set in load_model and used by the other functions
+device = None
+print('Device will be set when model is loaded.')

 def _load(checkpoint_path):
     if device == 'cuda':
@@ -55,7 +72,11 @@ def _load(checkpoint_path):
                                 map_location=lambda storage, loc: storage)
     return checkpoint

-def load_model(path):
+def load_model(path, gpu_id=0):
+    global device
+    device = get_device(gpu_id)
+    logger.info("Using {} for inference.".format(device))
+
     model = Wav2Lip()
     logger.info("Load checkpoint from: {}".format(path))
     checkpoint = _load(path)
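
A quick sketch of how the new loader behaves (illustrative only; it assumes a machine with two CUDA devices):

    from lipreal import get_device, load_model

    get_device(1)   # -> "cuda:1" (also calls torch.cuda.set_device(1))
    get_device(5)   # index out of range -> prints a warning and falls back to "cuda:0"
    get_device(0)   # -> "cuda:0"; without CUDA it returns "mps" or "cpu"

    # load_model sets the module-level `device` via get_device before building Wav2Lip
    model = load_model("./models/wav2lip.pth", gpu_id=1)
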
@@ -4,4 +4,4 @@ __author__ = """Adrian Bulat"""
 __email__ = 'adrian.bulat@nottingham.ac.uk'
 __version__ = '1.0.1'

-from .api import FaceAlignment, LandmarksType, NetworkSize
+from .api import FaceAlignment, LandmarksType, NetworkSize, ImageStyle
@@ -5,11 +5,18 @@ from torch.utils.model_zoo import load_url
 from enum import Enum
 import numpy as np
 import cv2
+from .detection.core import FaceDetector
+
 try:
     import urllib.request as request_file
-except BaseException:
+except ImportError:
     import urllib as request_file

+try:
+    import dlib
+except ImportError:
+    dlib = None
+
 from .models import FAN, ResNetDepth
 from .utils import *

@@ -27,6 +34,20 @@ class LandmarksType(Enum):
     _3D = 3


+class ImageStyle(Enum):
+    """Enum class defining different image styles for face detection optimization.
+
+    ``REALISTIC`` - Real human faces, standard detection parameters
+    ``ANIME`` - Anime/cartoon style faces, optimized for 2D illustrations
+    ``ANCIENT`` - Ancient/traditional art style, enhanced for classical paintings
+    ``AUTO`` - Automatic style detection based on image characteristics
+    """
+    REALISTIC = 1
+    ANIME = 2
+    ANCIENT = 3
+    AUTO = 4
+
+
 class NetworkSize(Enum):
     # TINY = 1
     # SMALL = 2
@@ -43,14 +64,65 @@ class NetworkSize(Enum):

 ROOT = os.path.dirname(os.path.abspath(__file__))

+
+def detect_image_style(image):
+    """Automatically detect image style based on visual characteristics.
+
+    Args:
+        image: Input image as numpy array
+
+    Returns:
+        ImageStyle: Detected style enum
+    """
+    # Convert to grayscale for analysis
+    if len(image.shape) == 3:
+        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
+    else:
+        gray = image
+
+    # Calculate edge density (anime/cartoon images typically have more defined edges)
+    edges = cv2.Canny(gray, 50, 150)
+    edge_density = np.sum(edges > 0) / (edges.shape[0] * edges.shape[1])
+
+    # Calculate color saturation (anime images often have higher saturation)
+    if len(image.shape) == 3:
+        hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
+        saturation_mean = np.mean(hsv[:, :, 1])
+    else:
+        saturation_mean = 0
+
+    # Calculate texture complexity
+    laplacian_var = cv2.Laplacian(gray, cv2.CV_64F).var()
+
+    # Style classification logic
+    if edge_density > 0.15 and saturation_mean > 100:
+        return ImageStyle.ANIME
+    elif laplacian_var < 100 and saturation_mean < 80:
+        return ImageStyle.ANCIENT
+    else:
+        return ImageStyle.REALISTIC
+
+
 class FaceAlignment:
     def __init__(self, landmarks_type, network_size=NetworkSize.LARGE,
-                 device='cuda', flip_input=False, face_detector='sfd', verbose=False):
+                 device='cuda', flip_input=False, face_detector='sfd', verbose=False,
+                 image_style=ImageStyle.AUTO, confidence_threshold=None):
         self.device = device
         self.flip_input = flip_input
         self.landmarks_type = landmarks_type
         self.verbose = verbose
-
+        self.image_style = image_style
+
+        # Style-specific confidence thresholds
+        self.style_thresholds = {
+            ImageStyle.REALISTIC: 0.5,
+            ImageStyle.ANIME: 0.3,     # Lower threshold for anime faces
+            ImageStyle.ANCIENT: 0.25,  # Even lower for ancient art
+            ImageStyle.AUTO: 0.4       # Balanced default
+        }
+
+        self.confidence_threshold = confidence_threshold or self.style_thresholds.get(image_style, 0.4)
+
         network_size = int(network_size)

         if 'cuda' in device:
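
A minimal usage sketch of the style heuristic and the per-style thresholds above (the input path is hypothetical; the 0.3 value is the ANIME entry of style_thresholds):

    import cv2
    import face_detection
    from face_detection.api import detect_image_style

    img = cv2.imread("frames/00000001.png")   # hypothetical frame from an avatar video
    style = detect_image_style(img)           # e.g. ImageStyle.ANIME for dense edges + high saturation

    fa = face_detection.FaceAlignment(
        face_detection.LandmarksType._2D,
        device='cuda:0',
        image_style=style)
    print(fa.confidence_threshold)            # 0.3 when the detected style is ANIME
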
@@ -61,19 +133,75 @@ class FaceAlignment:
                                                   globals(), locals(), [face_detector], 0)
         self.face_detector = face_detector_module.FaceDetector(device=device, verbose=verbose)

+    def preprocess_image_by_style(self, image, style):
+        """Apply style-specific preprocessing to improve detection.
+
+        Args:
+            image: Input image
+            style: ImageStyle enum
+
+        Returns:
+            Preprocessed image
+        """
+        processed = image.copy()
+
+        if style == ImageStyle.ANIME:
+            # Enhance edges for anime/cartoon faces
+            kernel = np.array([[-1,-1,-1], [-1,9,-1], [-1,-1,-1]])
+            processed = cv2.filter2D(processed, -1, kernel)
+            # Increase contrast
+            processed = cv2.convertScaleAbs(processed, alpha=1.2, beta=10)
+
+        elif style == ImageStyle.ANCIENT:
+            # Enhance contrast and reduce noise for ancient art
+            processed = cv2.convertScaleAbs(processed, alpha=1.3, beta=15)
+            # Apply slight gaussian blur to reduce texture noise
+            processed = cv2.GaussianBlur(processed, (3, 3), 0.5)
+
+        return processed
+
     def get_detections_for_batch(self, images):
-        images = images[..., ::-1]
-        detected_faces = self.face_detector.detect_from_batch(images.copy())
+        # Auto-detect style if needed
+        if self.image_style == ImageStyle.AUTO and len(images) > 0:
+            detected_style = detect_image_style(images[0])
+            current_threshold = self.style_thresholds[detected_style]
+            if self.verbose:
+                print(f"Auto-detected style: {detected_style.name}, using threshold: {current_threshold}")
+        else:
+            detected_style = self.image_style
+            current_threshold = self.confidence_threshold
+
+        # Apply style-specific preprocessing
+        processed_images = []
+        for img in images:
+            processed = self.preprocess_image_by_style(img, detected_style)
+            processed_images.append(processed)
+
+        # Convert color format
+        processed_images = np.array(processed_images)
+        processed_images = processed_images[..., ::-1]
+
+        # Detect faces with original method
+        detected_faces = self.face_detector.detect_from_batch(processed_images.copy())
         results = []

         for i, d in enumerate(detected_faces):
             if len(d) == 0:
                 results.append(None)
                 continue
-            d = d[0]
-            d = np.clip(d, 0, None)

-            x1, y1, x2, y2 = map(int, d[:-1])
+            # Filter by style-specific confidence threshold
+            valid_detections = [det for det in d if len(det) > 4 and det[-1] > current_threshold]
+
+            if len(valid_detections) == 0:
+                results.append(None)
+                continue
+
+            # Use the detection with highest confidence
+            best_detection = max(valid_detections, key=lambda x: x[-1])
+            best_detection = np.clip(best_detection, 0, None)
+
+            x1, y1, x2, y2 = map(int, best_detection[:-1])
             results.append((x1, y1, x2, y2))

         return results
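
The new per-box filtering can be traced on dummy sfd-style detections ([x1, y1, x2, y2, score]; the 0.3 threshold below is the ANIME default):

    import numpy as np

    d = [np.array([10, 20, 110, 140, 0.62]),
         np.array([0, 0, 40, 40, 0.28])]
    current_threshold = 0.3
    valid = [det for det in d if len(det) > 4 and det[-1] > current_threshold]   # drops the 0.28 box
    best = max(valid, key=lambda x: x[-1])                                       # keeps the 0.62 box
    x1, y1, x2, y2 = map(int, np.clip(best, 0, None)[:-1])                       # -> (10, 20, 110, 140)

Compared with the old code, which always took d[0], this picks the highest-confidence box and discards low-confidence hits below the style-specific threshold.
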
@@ -19,10 +19,21 @@ parser.add_argument('--pads', nargs='+', type=int, default=[0, 10, 0, 0],
                     help='Padding (top, bottom, left, right). Please adjust to include chin at least')
 parser.add_argument('--face_det_batch_size', type=int,
                     help='Batch size for face detection', default=16)
+parser.add_argument('--gpu_id', type=int, default=0,
+                    help='GPU device ID to use (default: 0)')
+parser.add_argument('--image_style', type=str, default='auto',
+                    choices=['auto', 'realistic', 'anime', 'ancient'],
+                    help='Image style for face detection optimization (default: auto)')
+parser.add_argument('--confidence_threshold', type=float, default=None,
+                    help='Custom confidence threshold for face detection (overrides style defaults)')
 args = parser.parse_args()

-device = 'cuda' if torch.cuda.is_available() else 'cpu'
-print('Using {} for inference.'.format(device))
+if torch.cuda.is_available():
+    device = f'cuda:{args.gpu_id}'
+    print(f'Using GPU {args.gpu_id} for inference.')
+else:
+    device = 'cpu'
+    print('CUDA not available, using CPU for inference.')

 def osmakedirs(path_list):
     for path in path_list:
@@ -60,8 +71,24 @@ def get_smoothened_boxes(boxes, T):
     return boxes

 def face_detect(images):
-    detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D,
-                                            flip_input=False, device=device)
+    # Convert style string to enum
+    style_map = {
+        'auto': face_detection.ImageStyle.AUTO,
+        'realistic': face_detection.ImageStyle.REALISTIC,
+        'anime': face_detection.ImageStyle.ANIME,
+        'ancient': face_detection.ImageStyle.ANCIENT
+    }
+
+    image_style = style_map.get(args.image_style, face_detection.ImageStyle.AUTO)
+
+    detector = face_detection.FaceAlignment(
+        face_detection.LandmarksType._2D,
+        flip_input=False,
+        device=device,
+        image_style=image_style,
+        confidence_threshold=args.confidence_threshold,
+        verbose=True
+    )

     batch_size = args.face_det_batch_size

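
Taken together, a preprocessing run on the second GPU with anime-tuned detection would add something like "--gpu_id 1 --image_style anime" (or "--confidence_threshold 0.35" to override the per-style default) to the script's existing arguments.
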
@@ -221,6 +221,52 @@
             border-radius: 8px;
             box-shadow: 0 1px 3px rgba(0,0,0,0.05);
         }
+
+        /* Chat message styles */
+        #chatOverlay .message {
+            display: flex;
+            margin-bottom: 10px;
+            max-width: 100%;
+        }
+
+        #chatOverlay .message.right {
+            justify-content: flex-end;
+        }
+
+        #chatOverlay .message.left {
+            justify-content: flex-start;
+        }
+
+        #chatOverlay .avatar {
+            width: 30px;
+            height: 30px;
+            border-radius: 50%;
+            margin: 0 5px;
+        }
+
+        #chatOverlay .text-container {
+            background-color: rgba(255,255,255,0.9);
+            border-radius: 10px;
+            padding: 8px 12px;
+            max-width: 70%;
+            color: #333;
+        }
+
+        #chatOverlay .message.right .text-container {
+            background-color: #4285f4;
+            color: white;
+        }
+
+        #chatOverlay .time {
+            font-size: 10px;
+            color: #888;
+            margin-top: 4px;
+            text-align: right;
+        }
+
+        #chatOverlay .message.right .time {
+            color: rgba(255,255,255,0.8);
+        }
     </style>
 </head>
 <body>
@@ -281,6 +327,10 @@
         <audio id="audio" autoplay="true"></audio>
         <video id="video" autoplay="true" playsinline="true"></video>
     </div>
+    <!-- Chat message display area -->
+    <div id="chatOverlay" style="position: absolute; bottom: 20px; right: 20px; width: 300px; max-height: 400px; overflow-y: auto; background-color: rgba(0,0,0,0.7); border-radius: 10px; padding: 10px; color: white; z-index: 1005;">
+        <!-- Messages are appended here dynamically -->
+    </div>
 </div>

 <script src="client.js"></script>
@@ -554,12 +604,18 @@

    function addMessage(text, type = "right") {
        const chatOverlay = document.getElementById("chatOverlay");
+        if (!chatOverlay) {
+            console.error('Chat overlay container not found');
+            return;
+        }
+
        const messageDiv = document.createElement("div");
        messageDiv.classList.add("message", type);

        const avatar = document.createElement("img");
        avatar.classList.add("avatar");
-        avatar.src = type === "right" ? "images/avatar-right.png" : "images/avatar-left.png";
+        // Use built-in default avatars (inline SVG data URIs), so a missing image file cannot break rendering
+        avatar.src = type === "right" ? "data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' viewBox='0 0 24 24'%3E%3Ccircle cx='12' cy='12' r='12' fill='%234285f4'/%3E%3Cpath d='M12 6a3 3 0 1 0 0 6 3 3 0 0 0 0-6zm0 8c-2.67 0-8 1.34-8 4v2h16v-2c0-2.66-5.33-4-8-4z' fill='white'/%3E%3C/svg%3E" : "data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' viewBox='0 0 24 24'%3E%3Ccircle cx='12' cy='12' r='12' fill='%23999'/%3E%3Cpath d='M12 6a3 3 0 1 0 0 6 3 3 0 0 0 0-6zm0 8c-2.67 0-8 1.34-8 4v2h16v-2c0-2.66-5.33-4-8-4z' fill='white'/%3E%3C/svg%3E";

        const textContainer = document.createElement("div");
        textContainer.classList.add("text-container");
@@ -582,6 +638,9 @@

        // Auto-scroll to the bottom
        chatOverlay.scrollTop = chatOverlay.scrollHeight;
+
+        // Show the chat overlay (in case it was previously hidden)
+        chatOverlay.style.display = 'block';
    }

    ws.onmessage = function(e) {
@@ -605,6 +664,47 @@
        }else if (messageData.Data.Key == "text") {
            var reply = messageData.Data.Value;
            addMessage(reply, "left");
+
+            // Push "text" messages to the server so the digital-human service synthesizes and plays them via TTS
+            fetch('/human', {
+                body: JSON.stringify({
+                    text: reply,
+                    type: 'echo',
+                    interrupt: true,
+                    sessionid: parseInt(document.getElementById('sessionid').value),
+                }),
+                headers: {
+                    'Content-Type': 'application/json'
+                },
+                method: 'POST'
+            });
+
+            // If the message is text-only with no audio, fall back to the browser speech synthesis API
+            // if (!messageData.Data.HttpValue && window.speechSynthesis) {
+            //     console.log('Playing text with local speech synthesis:', reply);
+            //     var utterance = new SpeechSynthesisUtterance(reply);
+            //     utterance.lang = 'zh-CN'; // language: Chinese
+            //     utterance.rate = 1.0;     // speaking rate
+            //     utterance.pitch = 1.0;    // pitch
+            //     utterance.volume = 1.0;   // volume
+            //     speechSynthesis.speak(utterance);
+            // }
+        }else if (messageData.Data.Key == "plaintext") {
+            // Handle the plaintext message type
+            var textContent = messageData.Data.Value;
+            console.log('Received plaintext message:', textContent);
+            addMessage(textContent, "left");
+
+            // Use the browser speech synthesis API for local playback
+            if (window.speechSynthesis) {
+                console.log('Playing text with local speech synthesis:', textContent);
+                var utterance = new SpeechSynthesisUtterance(textContent);
+                utterance.lang = 'zh-CN'; // language: Chinese
+                utterance.rate = 1.0;     // speaking rate
+                utterance.pitch = 1.0;    // pitch
+                utterance.volume = 1.0;   // volume
+                speechSynthesis.speak(utterance);
+            }
        }
    }

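For reference, the same /human echo request that the page issues can be reproduced from Python; the host, port and session id below are assumptions, only the payload fields come from the handler above:

    import requests

    requests.post("http://localhost:8010/human", json={  # assumed host:port of the digital-human server
        "text": "你好,数字人",                           # text to be synthesized and spoken via TTS
        "type": "echo",
        "interrupt": True,
        "sessionid": 0,                                   # assumed session id
    })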