middleware.py
4.58 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import ipaddress
import logging
import os
import time
from datetime import datetime
from logging.handlers import RotatingFileHandler
from typing import Optional, Tuple

import httpx
from fastapi import Request, HTTPException, Response
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse, PlainTextResponse
# Configuration
# Filtering is opt-in: only the (case-insensitive) value "true" enables it.
ENABLE_IP_FILTER = os.getenv("ENABLE_IP_FILTER", "false").lower() == "true"
LOG_FILE = os.path.join(os.path.expanduser("~"), ".pm2", "logs", "access_monitor.log")

# RotatingFileHandler opens the log file as soon as it is constructed and
# raises FileNotFoundError when the parent directory is missing, which would
# make importing this module fail. Create the directory up front.
os.makedirs(os.path.dirname(LOG_FILE), exist_ok=True)

# Setup Logger
logger = logging.getLogger("access_monitor")
logger.setLevel(logging.INFO)
# Records are written verbatim; the middleware builds the full log line itself.
formatter = logging.Formatter('%(message)s')
# Use RotatingFileHandler instead of FileHandler
# maxBytes=10MB, backupCount=10 -> up to 10 rotated backups plus the active
# file, i.e. roughly 110MB of storage at most.
file_handler = RotatingFileHandler(LOG_FILE, maxBytes=10*1024*1024, backupCount=10, encoding='utf-8')
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)

# IP Cache: {ip: {"is_cn": bool, "location": str, "timestamp": float}}
IP_CACHE = {}
# How long a successful lookup stays valid, in seconds (24 hours).
_IP_CACHE_TTL_SECONDS = 86400
# Upper bound on cached entries so that a flood of distinct client addresses
# cannot grow the in-process cache without limit.
_IP_CACHE_MAX_ENTRIES = 10000


async def get_ip_info(ip: str) -> Tuple[bool, str]:
    """
    Check if IP is from China and get location.

    The address is parsed with ``ipaddress`` before anything else so that
    only a syntactically valid IP is ever placed in the lookup URL or used
    as a cache key (the value may originate from a client-supplied header).

    Returns (is_cn, location_string):
      * (True, "Local Network")   - loopback / private / link-local address
      * (True, "Unknown/Private") - not a valid IP, or the API reported failure
      * (True, "Lookup Failed")   - the HTTP request or JSON decoding raised
      * (countryCode == "CN", "<country>-<region>-<city>") on success

    Every failure mode deliberately fails open (is_cn=True).
    """
    try:
        addr = ipaddress.ip_address(ip.strip())
    except ValueError:
        # Not an IP at all (e.g. "unknown" or a spoofed header value).
        # Never send it to the API or store it in the cache.
        return True, "Unknown/Private"

    # Treat IPv4-mapped IPv6 (::ffff:a.b.c.d) as the embedded IPv4 address.
    mapped = getattr(addr, "ipv4_mapped", None)
    if mapped is not None:
        addr = mapped

    # Localhost / Private IP checks (covers 127/8, 10/8, 172.16/12,
    # 192.168/16, link-local and the IPv6 equivalents).
    if addr.is_loopback or addr.is_private or addr.is_link_local:
        return True, "Local Network"

    # Normalised text form: one cache key per address regardless of spelling.
    key = str(addr)

    # Check Cache
    cached = IP_CACHE.get(key)
    if cached is not None:
        if time.time() - cached["timestamp"] < _IP_CACHE_TTL_SECONDS:
            return cached["is_cn"], cached["location"]
        # Expired: drop it so stale entries do not accumulate.
        IP_CACHE.pop(key, None)

    # Query API
    # Using ip-api.com (free, non-commercial use, 45 req/min)
    # Lang=zh-CN for Chinese output
    url = f"http://ip-api.com/json/{key}?lang=zh-CN&fields=status,countryCode,country,regionName,city"
    try:
        async with httpx.AsyncClient(timeout=3.0) as client:
            resp = await client.get(url)
            data = resp.json()
            if data.get("status") == "success":
                is_cn = data.get("countryCode") == "CN"
                location = f"{data.get('country')}-{data.get('regionName')}-{data.get('city')}"
                # Update Cache; dicts keep insertion order, so the first key
                # is the oldest entry and is the one evicted when full.
                if len(IP_CACHE) >= _IP_CACHE_MAX_ENTRIES and key not in IP_CACHE:
                    IP_CACHE.pop(next(iter(IP_CACHE)), None)
                IP_CACHE[key] = {
                    "is_cn": is_cn,
                    "location": location,
                    "timestamp": time.time()
                }
                return is_cn, location
            else:
                # API failed or private IP
                return True, "Unknown/Private"
    except Exception as e:
        # Broad on purpose: any network/decoding problem must not break the
        # request being served.
        print(f"IP Lookup failed for {ip}: {e}")
        # Fail open (allow access) if API fails
        return True, "Lookup Failed"
class IPFilterMiddleware(BaseHTTPMiddleware):
    """Log each request's client IP and location, optionally blocking non-CN IPs.

    For every request one line is written to the access logger in the form
    ``IP | Location | Time | Port | Path | ALLOWED/BLOCKED``. When
    ``ENABLE_IP_FILTER`` is true and the lookup reports a non-CN address, the
    request is answered with a 403 instead of being passed downstream.
    """

    @staticmethod
    def _client_ip(request: Request) -> str:
        """Return the client address, preferring the first X-Forwarded-For hop."""
        # SECURITY NOTE: X-Forwarded-For is client-controlled unless a trusted
        # reverse proxy overwrites it. If this app is reachable directly, a
        # client can spoof the header (e.g. claim "127.0.0.1") and bypass the
        # filter. Confirm the deployment always sits behind such a proxy.
        forwarded = request.headers.get("X-Forwarded-For")
        if forwarded:
            # "client, proxy1, proxy2": entries may carry surrounding spaces.
            first_hop = forwarded.split(",")[0].strip()
            if first_hop:
                return first_hop
        return request.client.host if request.client else "unknown"

    async def dispatch(self, request: Request, call_next):
        """Look up the client IP, log the access, and block or forward the request."""
        client_ip = self._client_ip(request)

        # Determine Port (Local port server is listening on).
        # The ASGI "server" scope entry may be absent or None.
        server = request.scope.get("server")
        server_port = server[1] if server else "unknown"

        # IP Lookup. All requests are logged, including static assets.
        is_cn, location = await get_ip_info(client_ip)

        # ip-api.com reports HK/TW/MO under their own country codes, not "CN",
        # so the is_cn flag alone is sufficient for a mainland-only filter.
        is_allowed = is_cn or not ENABLE_IP_FILTER

        # DEBUG: Log if we are blocking to ensure config works
        if not is_allowed:
            print(f"[BLOCK] Blocked {client_ip} ({location})")

        # Access Time
        access_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

        # Log Format: IP | Location | Time | Port | Path | Allowed
        logger.info(
            "%s | %s | %s | %s | %s | %s",
            client_ip,
            location,
            access_time,
            server_port,
            request.url.path,
            "ALLOWED" if is_allowed else "BLOCKED",
        )

        if not is_allowed:
            return PlainTextResponse("Access Denied: Mainland China IP Required", status_code=403)

        return await call_next(request)