import importlib
import os
from enum import Enum

import cv2
import numpy as np
import torch


class LandmarksType(Enum):
    """Enum class defining the type of landmarks to detect.

    ``_2D`` - the detected points ``(x,y)`` lie in a 2D space and follow the visible contour of the face
    ``_2halfD`` - the detected points are the 2D projections of the 3D landmark points
    ``_3D`` - the detected points ``(x,y,z)`` lie in a 3D space

    """
    _2D = 1
    _2halfD = 2
    _3D = 3


class ImageStyle(Enum):
    """Enum class defining different image styles for face detection optimization.
    
    ``REALISTIC`` - Real human faces, standard detection parameters
    ``ANIME`` - Anime/cartoon style faces, optimized for 2D illustrations
    ``ANCIENT`` - Ancient/traditional art style, enhanced for classical paintings
    ``AUTO`` - Automatic style detection based on image characteristics
    """
    REALISTIC = 1
    ANIME = 2
    ANCIENT = 3
    AUTO = 4


class NetworkSize(Enum):
    # TINY = 1
    # SMALL = 2
    # MEDIUM = 3
    LARGE = 4

    def __new__(cls, value):
        member = object.__new__(cls)
        member._value_ = value
        return member

    def __int__(self):
        return self.value

ROOT = os.path.dirname(os.path.abspath(__file__))


def detect_image_style(image):
    """Automatically detect image style based on visual characteristics.
    
    Args:
        image: Input image as numpy array
        
    Returns:
        ImageStyle: Detected style enum
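
    Example (illustrative; ``frame`` is assumed to be a BGR image, e.g. loaded
    with ``cv2.imread``)::

        style = detect_image_style(frame)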
    """
    # Convert to grayscale for analysis
    if len(image.shape) == 3:
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    else:
        gray = image
    
    # Calculate edge density (anime/cartoon images typically have more defined edges)
    edges = cv2.Canny(gray, 50, 150)
    edge_density = np.sum(edges > 0) / (edges.shape[0] * edges.shape[1])
    
    # Calculate color saturation (anime images often have higher saturation)
    if len(image.shape) == 3:
        hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
        saturation_mean = np.mean(hsv[:, :, 1])
    else:
        saturation_mean = 0
    
    # Calculate texture complexity
    laplacian_var = cv2.Laplacian(gray, cv2.CV_64F).var()
    
    # Heuristic style classification based on the statistics measured above
    if edge_density > 0.15 and saturation_mean > 100:
        return ImageStyle.ANIME
    elif laplacian_var < 100 and saturation_mean < 80:
        return ImageStyle.ANCIENT
    else:
        return ImageStyle.REALISTIC


class FaceAlignment:
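    """Style-aware face detection front end.

    Wraps a pluggable detector backend (loaded from
    ``face_detection.detection.<face_detector>``) and applies per-style image
    preprocessing and confidence thresholds before returning bounding boxes.
    """
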
    def __init__(self, landmarks_type, network_size=NetworkSize.LARGE,
                 device='cuda', flip_input=False, face_detector='sfd', verbose=False,
                 image_style=ImageStyle.AUTO, confidence_threshold=None):
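        """
        Args:
            landmarks_type: A ``LandmarksType`` value; stored for downstream landmark code.
            network_size: A ``NetworkSize`` value (only ``LARGE`` is defined).
            device: Torch device string, e.g. ``'cuda'`` or ``'cpu'``.
            flip_input: Whether the flipped input should also be evaluated; stored on the instance.
            face_detector: Name of the detector backend module, e.g. ``'sfd'``.
            verbose: If True, print diagnostic messages.
            image_style: An ``ImageStyle`` value; ``AUTO`` infers the style per batch.
            confidence_threshold: Explicit detection threshold. If None, the
                per-style default from ``style_thresholds`` is used; in ``AUTO``
                mode the inferred style's threshold takes precedence either way.
        """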
        self.device = device
        self.flip_input = flip_input
        self.landmarks_type = landmarks_type
        self.verbose = verbose
        self.image_style = image_style
        
        # Style-specific confidence thresholds
        self.style_thresholds = {
            ImageStyle.REALISTIC: 0.5,
            ImageStyle.ANIME: 0.3,      # Lower threshold for anime faces
            ImageStyle.ANCIENT: 0.25,   # Even lower for ancient art
            ImageStyle.AUTO: 0.4        # Balanced default
        }
        
        # Compare against None (rather than using `or`) so an explicit threshold of 0.0 is respected
        self.confidence_threshold = (confidence_threshold if confidence_threshold is not None
                                     else self.style_thresholds.get(image_style, 0.4))
        
        if 'cuda' in device:
            torch.backends.cudnn.benchmark = True

        # Dynamically load the requested detector backend
        # (e.g. face_detection.detection.sfd)
        face_detector_module = importlib.import_module(
            'face_detection.detection.' + face_detector)
        self.face_detector = face_detector_module.FaceDetector(device=device, verbose=verbose)

    def preprocess_image_by_style(self, image, style):
        """Apply style-specific preprocessing to improve detection.
        
        Args:
            image: Input image
            style: ImageStyle enum
            
        Returns:
            Preprocessed copy of the image; the input is not modified.
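
        Example (illustrative; ``fa`` is a ``FaceAlignment`` instance and
        ``frame`` a BGR image)::

            enhanced = fa.preprocess_image_by_style(frame, ImageStyle.ANIME)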
        """
        processed = image.copy()
        
        if style == ImageStyle.ANIME:
            # Sharpen with a 3x3 kernel to enhance the strong outlines of anime/cartoon faces
            kernel = np.array([[-1, -1, -1], [-1, 9, -1], [-1, -1, -1]])
            processed = cv2.filter2D(processed, -1, kernel)
            # Increase contrast
            processed = cv2.convertScaleAbs(processed, alpha=1.2, beta=10)
            
        elif style == ImageStyle.ANCIENT:
            # Enhance contrast and reduce noise for ancient art
            processed = cv2.convertScaleAbs(processed, alpha=1.3, beta=15)
            # Apply slight gaussian blur to reduce texture noise
            processed = cv2.GaussianBlur(processed, (3, 3), 0.5)
            
        return processed
    
    def get_detections_for_batch(self, images):
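        """Detect one face bounding box per image in a batch.

        In ``AUTO`` mode the style is inferred from the first image and applied
        to the whole batch, along with that style's confidence threshold.

        Args:
            images: Batch of same-sized images (assumed BGR, as with ``cv2.imread``).

        Returns:
            A list with one entry per image: an ``(x1, y1, x2, y2)`` tuple for
            the highest-confidence detection above the threshold, or ``None``.
        """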
        # Auto-detect style if needed
        if self.image_style == ImageStyle.AUTO and len(images) > 0:
            detected_style = detect_image_style(images[0])
            current_threshold = self.style_thresholds[detected_style]
            if self.verbose:
                print(f"Auto-detected style: {detected_style.name}, using threshold: {current_threshold}")
        else:
            detected_style = self.image_style
            current_threshold = self.confidence_threshold
        
        # Apply style-specific preprocessing
        processed_images = []
        for img in images:
            processed = self.preprocess_image_by_style(img, detected_style)
            processed_images.append(processed)
        
        # Stack the batch and reverse the channel order (BGR -> RGB) for the detector
        processed_images = np.array(processed_images)
        processed_images = processed_images[..., ::-1]
        
        # Detect faces with original method
        detected_faces = self.face_detector.detect_from_batch(processed_images.copy())
        results = []

        for i, d in enumerate(detected_faces):
            if len(d) == 0:
                results.append(None)
                continue
            
            # Filter by style-specific confidence threshold
            valid_detections = [det for det in d if len(det) > 4 and det[-1] > current_threshold]
            
            if len(valid_detections) == 0:
                results.append(None)
                continue
                
            # Keep the highest-confidence detection and clamp negative
            # coordinates to zero
            best_detection = max(valid_detections, key=lambda x: x[-1])
            best_detection = np.clip(best_detection, 0, None)
            
            x1, y1, x2, y2 = map(int, best_detection[:-1])
            results.append((x1, y1, x2, y2))

        return results
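

if __name__ == '__main__':
    # Minimal usage sketch (illustrative): 'example.jpg' is a placeholder path,
    # and the 'sfd' backend must be importable as face_detection.detection.sfd.
    fa = FaceAlignment(LandmarksType._2D, device='cpu',
                       image_style=ImageStyle.AUTO, verbose=True)
    frame = cv2.imread('example.jpg')  # BGR frame
    boxes = fa.get_detections_for_batch(np.array([frame]))
    print(boxes)  # e.g. [(x1, y1, x2, y2)], or [None] if no face passed the threshold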