戒酒的李白

Optimize code structure and enhance security features.

1 import os 1 import os
2 -import re  
3 import getpass 2 import getpass
4 import pymysql 3 import pymysql
5 import subprocess 4 import subprocess
6 -from flask import Flask, session, request, redirect, render_template, jsonify 5 +from flask import Flask, session, request, redirect
7 from apscheduler.schedulers.background import BackgroundScheduler 6 from apscheduler.schedulers.background import BackgroundScheduler
8 from pytz import utc 7 from pytz import utc
9 from datetime import datetime, timedelta 8 from datetime import datetime, timedelta
10 -import time  
11 -from utils.logger import app_logger as logging  
12 -from utils.db_manager import DatabaseManager  
13 import secrets 9 import secrets
14 from dotenv import load_dotenv 10 from dotenv import load_dotenv
15 -from functools import wraps  
16 -import bleach 11 +from utils.logger import app_logger as logging
  12 +from utils.db_pool import DatabasePool
  13 +from utils.error_handlers import register_error_handlers
  14 +from middleware.security import set_secure_headers, log_request_info, require_https
17 15
18 # 加载环境变量 16 # 加载环境变量
19 load_dotenv() 17 load_dotenv()
@@ -56,21 +54,6 @@ def get_db_connection_interactive(): @@ -56,21 +54,6 @@ def get_db_connection_interactive():
56 logging.error(f"数据库连接失败: {e}") 54 logging.error(f"数据库连接失败: {e}")
57 raise 55 raise
58 56
59 -def sanitize_input(text):  
60 - """清理用户输入,防止XSS攻击"""  
61 - if text is None:  
62 - return None  
63 - return bleach.clean(str(text), strip=True)  
64 -  
65 -def set_secure_headers(response):  
66 - """设置安全响应头"""  
67 - response.headers['X-Content-Type-Options'] = 'nosniff'  
68 - response.headers['X-Frame-Options'] = 'SAMEORIGIN'  
69 - response.headers['X-XSS-Protection'] = '1; mode=block'  
70 - response.headers['Strict-Transport-Security'] = 'max-age=31536000; includeSubDomains'  
71 - response.headers['Content-Security-Policy'] = "default-src 'self'"  
72 - return response  
73 -  
74 # 初始化 Flask 应用 57 # 初始化 Flask 应用
75 app = Flask(__name__) 58 app = Flask(__name__)
76 app.secret_key = os.getenv('FLASK_SECRET_KEY', secrets.token_hex(32)) 59 app.secret_key = os.getenv('FLASK_SECRET_KEY', secrets.token_hex(32))
@@ -87,11 +70,15 @@ from views.workflow_api import workflow_bp, workflow_api_bp @@ -87,11 +70,15 @@ from views.workflow_api import workflow_bp, workflow_api_bp
87 app.register_blueprint(page.pb) 70 app.register_blueprint(page.pb)
88 app.register_blueprint(user.ub) 71 app.register_blueprint(user.ub)
89 app.register_blueprint(spider_bp) 72 app.register_blueprint(spider_bp)
90 -app.register_blueprint(workflow_bp) # 注册工作流蓝图  
91 -app.register_blueprint(workflow_api_bp) # 注册工作流API蓝图 73 +app.register_blueprint(workflow_bp)
  74 +app.register_blueprint(workflow_api_bp)
  75 +
  76 +# 注册错误处理器
  77 +register_error_handlers(app)
92 78
93 # 首页路由 79 # 首页路由
94 @app.route('/') 80 @app.route('/')
  81 +@require_https()
95 def hello_world(): 82 def hello_world():
96 session.clear() 83 session.clear()
97 return redirect('/user/login') 84 return redirect('/user/login')
@@ -99,10 +86,8 @@ def hello_world(): @@ -99,10 +86,8 @@ def hello_world():
99 # 请求前中间件 86 # 请求前中间件
100 @app.before_request 87 @app.before_request
101 def before_request(): 88 def before_request():
102 - # 检查是否是HTTPS  
103 - if not request.is_secure and not app.debug:  
104 - url = request.url.replace('http://', 'https://', 1)  
105 - return redirect(url, code=301) 89 + # 记录请求信息
  90 + log_request_info()
106 91
107 # 如果请求的是静态文件路径,允许访问 92 # 如果请求的是静态文件路径,允许访问
108 if request.path.startswith('/static'): 93 if request.path.startswith('/static'):
@@ -138,35 +123,6 @@ def before_request(): @@ -138,35 +123,6 @@ def before_request():
138 def after_request(response): 123 def after_request(response):
139 return set_secure_headers(response) 124 return set_secure_headers(response)
140 125
141 -# 错误处理  
142 -@app.errorhandler(404)  
143 -def not_found_error(error):  
144 - return render_template('404.html'), 404  
145 -  
146 -@app.errorhandler(500)  
147 -def internal_error(error):  
148 - return render_template('error.html',  
149 - error_code=500,  
150 - error_title='服务器错误',  
151 - error_message='服务器遇到了一个问题,请稍后再试。',  
152 - error_i18n_key='serverError'), 500  
153 -  
154 -@app.errorhandler(403)  
155 -def forbidden_error(error):  
156 - return render_template('error.html',  
157 - error_code=403,  
158 - error_title='禁止访问',  
159 - error_message='您没有权限访问此页面。',  
160 - error_i18n_key='forbidden'), 403  
161 -  
162 -@app.errorhandler(400)  
163 -def bad_request_error(error):  
164 - return render_template('error.html',  
165 - error_code=400,  
166 - error_title='错误请求',  
167 - error_message='服务器无法理解您的请求。',  
168 - error_i18n_key='badRequest'), 400  
169 -  
170 # 数据库配置 126 # 数据库配置
171 DB_CONFIG = { 127 DB_CONFIG = {
172 'host': os.getenv('DB_HOST', 'localhost'), 128 'host': os.getenv('DB_HOST', 'localhost'),
@@ -178,9 +134,6 @@ DB_CONFIG = { @@ -178,9 +134,6 @@ DB_CONFIG = {
178 'ssl': {'ca': os.getenv('DB_SSL_CA')} if os.getenv('DB_SSL_CA') else None 134 'ssl': {'ca': os.getenv('DB_SSL_CA')} if os.getenv('DB_SSL_CA') else None
179 } 135 }
180 136
181 -# 初始化数据库管理器  
182 -DatabaseManager.initialize(DB_CONFIG)  
183 -  
184 if __name__ == '__main__': 137 if __name__ == '__main__':
185 # 检测是否需要初始化数据库 138 # 检测是否需要初始化数据库
186 try: 139 try:
@@ -194,6 +147,13 @@ if __name__ == '__main__': @@ -194,6 +147,13 @@ if __name__ == '__main__':
194 logging.error(f"数据库初始化失败: {e}") 147 logging.error(f"数据库初始化失败: {e}")
195 exit(1) 148 exit(1)
196 149
  150 + # 初始化数据库连接池
  151 + try:
  152 + DatabasePool.initialize(DB_CONFIG)
  153 + except Exception as e:
  154 + logging.error(f"数据库连接池初始化失败: {e}")
  155 + exit(1)
  156 +
197 # 设置定时任务 157 # 设置定时任务
198 try: 158 try:
199 scheduler = BackgroundScheduler(timezone=utc) 159 scheduler = BackgroundScheduler(timezone=utc)
@@ -222,23 +182,9 @@ if __name__ == '__main__': @@ -222,23 +182,9 @@ if __name__ == '__main__':
222 logging.error(f"应用启动失败: {e}") 182 logging.error(f"应用启动失败: {e}")
223 if 'scheduler' in locals(): 183 if 'scheduler' in locals():
224 scheduler.shutdown() 184 scheduler.shutdown()
  185 + DatabasePool.close()
225 exit(1) 186 exit(1)
226 finally: 187 finally:
227 if 'scheduler' in locals(): 188 if 'scheduler' in locals():
228 scheduler.shutdown() 189 scheduler.shutdown()
229 -  
230 -# 请求日志记录  
231 -@app.before_request  
232 -def log_request_info():  
233 - # 记录请求信息,但排除敏感数据  
234 - sanitized_headers = dict(request.headers)  
235 - if 'Authorization' in sanitized_headers:  
236 - sanitized_headers['Authorization'] = '[FILTERED]'  
237 - if 'Cookie' in sanitized_headers:  
238 - sanitized_headers['Cookie'] = '[FILTERED]'  
239 -  
240 - logging.info(  
241 - f"Request: {request.method} {request.path}\n"  
242 - f"Remote IP: {request.remote_addr}\n"  
243 - f"Headers: {sanitized_headers}"  
244 - ) 190 + DatabasePool.close()
  1 +from flask import request, redirect
  2 +from functools import wraps
  3 +import bleach
  4 +from utils.logger import app_logger as logging
  5 +
def sanitize_input(text):
    """Clean user-supplied text to guard against XSS.

    ``None`` is returned unchanged. Any other value is converted with
    ``str()`` and passed through ``bleach.clean(..., strip=True)``.
    """
    return None if text is None else bleach.clean(str(text), strip=True)
  11 +
def set_secure_headers(response):
    """Set the security-related HTTP headers on ``response`` and return it.

    Any existing value for one of these header names is overwritten.
    """
    security_headers = (
        ('X-Content-Type-Options', 'nosniff'),
        ('X-Frame-Options', 'SAMEORIGIN'),
        ('X-XSS-Protection', '1; mode=block'),
        ('Strict-Transport-Security', 'max-age=31536000; includeSubDomains'),
        (
            'Content-Security-Policy',
            "default-src 'self'; script-src 'self' 'unsafe-inline' 'unsafe-eval'; style-src 'self' 'unsafe-inline';",
        ),
    )
    for header_name, header_value in security_headers:
        response.headers[header_name] = header_value
    return response
  20 +
def require_https():
    """Build a view decorator that redirects plain-HTTP requests to HTTPS.

    Requests that are already secure, or whose URL host is a loopback name
    (``localhost``, ``127.0.0.1``, ``::1``), are passed to the wrapped view
    unchanged. Every other request receives a 301 redirect to the same URL
    with the ``http://`` scheme replaced by ``https://``.

    Returns:
        A decorator to apply to a Flask view function.
    """
    # Local import so this function does not depend on the module header.
    from urllib.parse import urlsplit

    local_hosts = {'localhost', '127.0.0.1', '::1'}

    def decorator(f):
        @wraps(f)
        def decorated_function(*args, **kwargs):
            if not request.is_secure:
                # The request object has no ``is_localhost`` attribute
                # (accessing it raises AttributeError), so the hostname of
                # the requested URL is inspected instead. ``hostname`` is
                # lower-cased and has the port and IPv6 brackets removed.
                hostname = urlsplit(request.url).hostname
                if hostname not in local_hosts:
                    url = request.url.replace('http://', 'https://', 1)
                    return redirect(url, code=301)
            return f(*args, **kwargs)
        return decorated_function
    return decorator
  32 +
def log_request_info():
    """Log the method, path, remote address and headers of the current request.

    The values of the ``Authorization`` and ``Cookie`` headers are replaced
    with ``[FILTERED]`` before the headers are written to the log.
    """
    headers_for_log = dict(request.headers)
    for sensitive_name in ('Authorization', 'Cookie'):
        if sensitive_name in headers_for_log:
            headers_for_log[sensitive_name] = '[FILTERED]'

    message = (
        f"Request: {request.method} {request.path}\n"
        f"Remote IP: {request.remote_addr}\n"
        f"Headers: {headers_for_log}"
    )
    logging.info(message)
@@ -88,3 +88,5 @@ xz=5.4.6=h8cc25b3_1 @@ -88,3 +88,5 @@ xz=5.4.6=h8cc25b3_1
88 zipp=3.17.0=py38haa95532_0 88 zipp=3.17.0=py38haa95532_0
89 zlib=1.2.13=h8cc25b3_1 89 zlib=1.2.13=h8cc25b3_1
90 zstd=1.5.5=hd43e919_2 90 zstd=1.5.5=hd43e919_2
  91 +DBUtils==3.0.2
  92 +bleach==6.1.0
  1 +import pymysql
  2 +from pymysql.cursors import DictCursor
  3 +from dbutils.pooled_db import PooledDB
  4 +from utils.logger import app_logger as logging
  5 +
class DatabasePool:
    """Process-wide holder for a single DBUtils ``PooledDB`` connection pool.

    The pool is stored on the class, so every method is a classmethod and no
    instance is ever created.
    """

    # The PooledDB instance, or None before initialize() / after close().
    _pool = None

    @classmethod
    def initialize(cls, db_config):
        """Create the connection pool from ``db_config``.

        Args:
            db_config: Mapping that provides ``host``, ``port``, ``user``,
                ``password``, ``database`` and ``charset``; ``ssl`` is
                optional.

        Raises:
            Exception: Whatever pool creation raises; it is logged and
                re-raised unchanged.
        """
        try:
            cls._pool = PooledDB(
                creator=pymysql,
                maxconnections=10,
                mincached=2,
                maxcached=5,
                maxshared=3,
                blocking=True,
                maxusage=None,
                setsession=[],
                # 1 = check the connection whenever it is fetched from the
                # pool (the DBUtils default). 0 never checks, so a connection
                # the server closed while idle would be handed to a caller.
                ping=1,
                host=db_config['host'],
                port=db_config['port'],
                user=db_config['user'],
                password=db_config['password'],
                database=db_config['database'],
                charset=db_config['charset'],
                cursorclass=DictCursor,
                ssl=db_config.get('ssl'),
            )
            logging.info("数据库连接池初始化成功")
        except Exception as e:
            logging.error(f"数据库连接池初始化失败: {e}")
            raise

    @classmethod
    def get_connection(cls):
        """Return a connection obtained from the pool.

        Raises:
            RuntimeError: If ``initialize`` has not been called, or the pool
                has already been closed.
        """
        if cls._pool is None:
            raise RuntimeError("数据库连接池未初始化")
        return cls._pool.connection()

    @classmethod
    def close(cls):
        """Close the pool if one exists; calling it again is a no-op."""
        pool = cls._pool
        if pool is None:
            return
        try:
            # PooledDB exposes close() itself; it has no inner ``_pool``
            # attribute, so ``pool._pool.close()`` raised AttributeError.
            pool.close()
        finally:
            # Drop the reference even when close() fails so that a later
            # get_connection() reports "not initialized" instead of using a
            # half-closed pool.
            cls._pool = None
        logging.info("数据库连接池已关闭")
from flask import render_template, request
from werkzeug.exceptions import HTTPException

from utils.logger import app_logger as logging
  3 +
def register_error_handlers(app):
    """Register the application's error handlers on ``app``.

    Handlers are installed for HTTP 400, 403, 404 and 500, plus a catch-all
    for any other exception. Each handler logs the problem and renders an
    error template with the matching status code.

    Args:
        app: The Flask application; only its ``errorhandler`` decorator is
            used here.
    """

    @app.errorhandler(404)
    def not_found_error(error):
        # ``request`` must be imported at module level; without it this
        # handler raised NameError and the 404 became a 500.
        logging.warning(f"404错误: {request.url}")
        return render_template('404.html'), 404

    @app.errorhandler(500)
    def internal_error(error):
        logging.error(f"500错误: {error}")
        return render_template('error.html',
                               error_code=500,
                               error_title='服务器错误',
                               error_message='服务器遇到了一个问题,请稍后再试。',
                               error_i18n_key='serverError'), 500

    @app.errorhandler(403)
    def forbidden_error(error):
        logging.warning(f"403错误: {request.url}")
        return render_template('error.html',
                               error_code=403,
                               error_title='禁止访问',
                               error_message='您没有权限访问此页面。',
                               error_i18n_key='forbidden'), 403

    @app.errorhandler(400)
    def bad_request_error(error):
        logging.warning(f"400错误: {error}")
        return render_template('error.html',
                               error_code=400,
                               error_title='错误请求',
                               error_message='服务器无法理解您的请求。',
                               error_i18n_key='badRequest'), 400

    @app.errorhandler(Exception)
    def handle_exception(error):
        # A handler registered for Exception also receives HTTP errors that
        # have no dedicated handler (405, 401, ...). Return those unchanged
        # so they keep their own status code instead of becoming a 500.
        if isinstance(error, HTTPException):
            return error
        logging.error(f"未处理的异常: {error}")
        return render_template('error.html',
                               error_code=500,
                               error_title='系统错误',
                               error_message='系统发生了一个未预期的错误。',
                               error_i18n_key='unexpectedError'), 500