# middleware.py — request access logging and optional mainland-China IP filtering.
import ipaddress
import logging
import os
import time
from datetime import datetime
from logging.handlers import RotatingFileHandler
from typing import Optional, Tuple

import httpx
from fastapi import Request, HTTPException, Response
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse, PlainTextResponse

# Configuration
# When ENABLE_IP_FILTER=true, clients whose IP does not resolve to mainland
# China are rejected with HTTP 403; otherwise requests are only logged.
ENABLE_IP_FILTER = os.getenv("ENABLE_IP_FILTER", "false").lower() == "true"
# Access log file, placed under the user's ~/.pm2/logs directory.
LOG_FILE = os.path.join(os.path.expanduser("~"), ".pm2", "logs", "access_monitor.log")

# Setup Logger
# Dedicated logger whose records are the bare access lines (no level/time prefix;
# the access time is embedded in the message itself).
logger = logging.getLogger("access_monitor")
logger.setLevel(logging.INFO)
formatter = logging.Formatter('%(message)s')

# RotatingFileHandler raises FileNotFoundError when the parent directory is
# missing (e.g. pm2 has never run for this user), so create it up front.
os.makedirs(os.path.dirname(LOG_FILE), exist_ok=True)

# Rotate at 10 MB and keep 10 backups: at most 11 files, ~110 MB on disk.
# Reuse a handler already attached for the same file (e.g. after the module is
# re-executed) so each access line is not written more than once.
file_handler = next(
    (
        handler
        for handler in logger.handlers
        if isinstance(handler, RotatingFileHandler)
        and handler.baseFilename == os.path.abspath(LOG_FILE)
    ),
    None,
)
if file_handler is None:
    file_handler = RotatingFileHandler(
        LOG_FILE, maxBytes=10 * 1024 * 1024, backupCount=10, encoding='utf-8'
    )
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)

# IP Cache: {ip: {"is_cn": bool, "location": str, "timestamp": float}}
# Process-local cache of geolocation lookups, filled by get_ip_info().
IP_CACHE = {}

async def get_ip_info(
    ip: str,
    *,
    cache_ttl: float = 86400,
    max_cache_size: int = 10000,
) -> Tuple[bool, str]:
    """
    Check whether an IP address is from mainland China and describe its location.

    Private, loopback and link-local addresses (including IPv4-mapped IPv6
    forms) return ``(True, "Local Network")`` without a lookup. Strings that
    are not valid IP addresses return ``(True, "Unknown/Private")`` without a
    lookup, so untrusted header values are never interpolated into the API URL.
    Other addresses are looked up on ip-api.com, and successful results are
    cached in ``IP_CACHE``.

    Args:
        ip: Client address as a string; may originate from a request header.
        cache_ttl: Seconds a cached lookup stays valid (default: 24 hours).
        max_cache_size: Maximum number of ``IP_CACHE`` entries. When a new
            entry would exceed it, expired entries are purged first, then the
            least recently stored entries are evicted.

    Returns:
        ``(is_cn, location_string)``. ``is_cn`` is also True when the address
        is local or invalid, or when the lookup fails, so callers fail open.
    """
    try:
        addr = ipaddress.ip_address(ip.strip())
    except ValueError:
        # Not an IP (e.g. "unknown" or a forged header value): report it the
        # same way as an unsuccessful API answer, without any network call.
        return True, "Unknown/Private"

    # Classify IPv4-mapped IPv6 addresses (e.g. "::ffff:10.0.0.1") by their IPv4 part.
    if isinstance(addr, ipaddress.IPv6Address) and addr.ipv4_mapped is not None:
        addr = addr.ipv4_mapped
    if addr.is_private or addr.is_loopback or addr.is_link_local:
        return True, "Local Network"

    # Use the canonical text form as cache key and query value.
    key = str(addr)

    cached = IP_CACHE.get(key)
    if cached is not None and time.time() - cached["timestamp"] < cache_ttl:
        return cached["is_cn"], cached["location"]

    # Query API
    # Using ip-api.com (free, non-commercial use, 45 req/min)
    # lang=zh-CN requests Chinese place names.
    url = f"http://ip-api.com/json/{key}?lang=zh-CN&fields=status,countryCode,country,regionName,city"

    try:
        async with httpx.AsyncClient(timeout=3.0) as client:
            resp = await client.get(url)
            # Treat HTTP error statuses (e.g. rate limiting) as lookup failures.
            resp.raise_for_status()
            data = resp.json()
    except Exception as exc:  # deliberate fail-open boundary: network, HTTP or JSON errors
        logging.getLogger(__name__).warning("IP Lookup failed for %s: %s", ip, exc)
        return True, "Lookup Failed"

    if not isinstance(data, dict) or data.get("status") != "success":
        # API could not resolve the address.
        return True, "Unknown/Private"

    is_cn = data.get("countryCode") == "CN"
    location = f"{data.get('country')}-{data.get('regionName')}-{data.get('city')}"

    now = time.time()
    if key not in IP_CACHE and len(IP_CACHE) >= max_cache_size:
        # Bound memory: drop expired entries, then the oldest stored ones.
        expired = [k for k, entry in IP_CACHE.items() if now - entry["timestamp"] >= cache_ttl]
        for stale in expired:
            del IP_CACHE[stale]
        overflow = len(IP_CACHE) - max_cache_size + 1
        for old in list(IP_CACHE)[:max(overflow, 0)]:
            del IP_CACHE[old]

    # Pop before re-inserting so dict insertion order tracks refresh time,
    # which the eviction above relies on.
    IP_CACHE.pop(key, None)
    IP_CACHE[key] = {
        "is_cn": is_cn,
        "location": location,
        "timestamp": now,
    }
    return is_cn, location

class IPFilterMiddleware(BaseHTTPMiddleware):
    """Write an access-log line for every request and optionally block non-CN clients.

    Blocking is active only when ``ENABLE_IP_FILTER`` is true; blocked
    requests get a 403 plain-text response and never reach the app.

    ``X-Forwarded-For`` is honoured only when the direct peer is a loopback,
    private or link-local address (a local reverse proxy). Otherwise a remote
    client could claim to be e.g. ``127.0.0.1`` and bypass the filter.
    NOTE(review): this assumes the proxy appends the real peer to the header
    (or overwrites it). A proxy that forwards the client's header untouched
    cannot be protected against here; confirm against the deployment.
    """

    @staticmethod
    def _is_local_ip(value: str) -> Optional[bool]:
        """Classify *value* as a local IP (True), a public IP (False) or not an IP (None)."""
        try:
            addr = ipaddress.ip_address(value)
        except ValueError:
            return None
        if isinstance(addr, ipaddress.IPv6Address) and addr.ipv4_mapped is not None:
            addr = addr.ipv4_mapped
        return addr.is_private or addr.is_loopback or addr.is_link_local

    @classmethod
    def _client_ip(cls, request: Request) -> str:
        """Return the best-effort real client IP for *request* (``"unknown"`` if absent)."""
        peer = request.client.host if request.client else "unknown"
        forwarded = request.headers.get("X-Forwarded-For")
        if not forwarded or not cls._is_local_ip(peer):
            # Not behind a local proxy: the header is fully client-controlled.
            return peer

        hops = [hop.strip() for hop in forwarded.split(",") if hop.strip()]
        # Proxies append on the right, so walk right-to-left past local proxy
        # hops. The first public address is the one our proxy chain actually
        # saw; entries further left were supplied by the client.
        for hop in reversed(hops):
            kind = cls._is_local_ip(hop)
            if kind is None:
                # Malformed entry: stop trusting the header here.
                return peer
            if not kind:
                return hop
        # Every hop is local: the client really is on the local network.
        return hops[0] if hops else peer

    async def dispatch(self, request: Request, call_next):
        """Log the request, then return 403 or pass it on to the app."""
        client_ip = self._client_ip(request)

        # Local port the server is listening on. ASGI allows scope["server"]
        # to be absent or None (e.g. Unix sockets), so guard before indexing.
        server_port = (request.scope.get("server") or ("", "unknown"))[1]

        is_cn, location = await get_ip_info(client_ip)

        # ip-api.com reports HK/TW/MO under their own country codes rather
        # than "CN", so is_cn alone means mainland China.
        is_allowed = is_cn or not ENABLE_IP_FILTER

        if not is_allowed:
            logging.getLogger(__name__).warning("[BLOCK] Blocked %s (%s)", client_ip, location)

        access_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

        # Log Format: IP | Location | Time | Port | Path | Allowed
        logger.info(
            "%s | %s | %s | %s | %s | %s",
            client_ip,
            location,
            access_time,
            server_port,
            request.url.path,
            "ALLOWED" if is_allowed else "BLOCKED",
        )

        if not is_allowed:
            return PlainTextResponse("Access Denied: Mainland China IP Required", status_code=403)

        return await call_next(request)