戒酒的李白

Optimize code structure and enhance security features.

import os
import re
import getpass
import pymysql
import subprocess
from flask import Flask, session, request, redirect, render_template, jsonify
from flask import Flask, session, request, redirect
from apscheduler.schedulers.background import BackgroundScheduler
from pytz import utc
from datetime import datetime, timedelta
import time
from utils.logger import app_logger as logging
from utils.db_manager import DatabaseManager
import secrets
from dotenv import load_dotenv
from functools import wraps
import bleach
from utils.logger import app_logger as logging
from utils.db_pool import DatabasePool
from utils.error_handlers import register_error_handlers
from middleware.security import set_secure_headers, log_request_info, require_https
# 加载环境变量
load_dotenv()
... ... @@ -56,21 +54,6 @@ def get_db_connection_interactive():
logging.error(f"数据库连接失败: {e}")
raise
def sanitize_input(text):
"""清理用户输入,防止XSS攻击"""
if text is None:
return None
return bleach.clean(str(text), strip=True)
def set_secure_headers(response):
"""设置安全响应头"""
response.headers['X-Content-Type-Options'] = 'nosniff'
response.headers['X-Frame-Options'] = 'SAMEORIGIN'
response.headers['X-XSS-Protection'] = '1; mode=block'
response.headers['Strict-Transport-Security'] = 'max-age=31536000; includeSubDomains'
response.headers['Content-Security-Policy'] = "default-src 'self'"
return response
# 初始化 Flask 应用
app = Flask(__name__)
app.secret_key = os.getenv('FLASK_SECRET_KEY', secrets.token_hex(32))
... ... @@ -87,11 +70,15 @@ from views.workflow_api import workflow_bp, workflow_api_bp
app.register_blueprint(page.pb)
app.register_blueprint(user.ub)
app.register_blueprint(spider_bp)
app.register_blueprint(workflow_bp) # 注册工作流蓝图
app.register_blueprint(workflow_api_bp) # 注册工作流API蓝图
app.register_blueprint(workflow_bp)
app.register_blueprint(workflow_api_bp)
# 注册错误处理器
register_error_handlers(app)
# 首页路由
@app.route('/')
@require_https()
def hello_world():
session.clear()
return redirect('/user/login')
... ... @@ -99,11 +86,9 @@ def hello_world():
# 请求前中间件
@app.before_request
def before_request():
# 检查是否是HTTPS
if not request.is_secure and not app.debug:
url = request.url.replace('http://', 'https://', 1)
return redirect(url, code=301)
# 记录请求信息
log_request_info()
# 如果请求的是静态文件路径,允许访问
if request.path.startswith('/static'):
return
... ... @@ -138,35 +123,6 @@ def before_request():
def after_request(response):
return set_secure_headers(response)
# 错误处理
@app.errorhandler(404)
def not_found_error(error):
return render_template('404.html'), 404
@app.errorhandler(500)
def internal_error(error):
return render_template('error.html',
error_code=500,
error_title='服务器错误',
error_message='服务器遇到了一个问题,请稍后再试。',
error_i18n_key='serverError'), 500
@app.errorhandler(403)
def forbidden_error(error):
return render_template('error.html',
error_code=403,
error_title='禁止访问',
error_message='您没有权限访问此页面。',
error_i18n_key='forbidden'), 403
@app.errorhandler(400)
def bad_request_error(error):
return render_template('error.html',
error_code=400,
error_title='错误请求',
error_message='服务器无法理解您的请求。',
error_i18n_key='badRequest'), 400
# 数据库配置
DB_CONFIG = {
'host': os.getenv('DB_HOST', 'localhost'),
... ... @@ -178,9 +134,6 @@ DB_CONFIG = {
'ssl': {'ca': os.getenv('DB_SSL_CA')} if os.getenv('DB_SSL_CA') else None
}
# 初始化数据库管理器
DatabaseManager.initialize(DB_CONFIG)
if __name__ == '__main__':
# 检测是否需要初始化数据库
try:
... ... @@ -194,6 +147,13 @@ if __name__ == '__main__':
logging.error(f"数据库初始化失败: {e}")
exit(1)
# 初始化数据库连接池
try:
DatabasePool.initialize(DB_CONFIG)
except Exception as e:
logging.error(f"数据库连接池初始化失败: {e}")
exit(1)
# 设置定时任务
try:
scheduler = BackgroundScheduler(timezone=utc)
... ... @@ -222,23 +182,9 @@ if __name__ == '__main__':
logging.error(f"应用启动失败: {e}")
if 'scheduler' in locals():
scheduler.shutdown()
DatabasePool.close()
exit(1)
finally:
if 'scheduler' in locals():
scheduler.shutdown()
# 请求日志记录
@app.before_request
def log_request_info():
# 记录请求信息,但排除敏感数据
sanitized_headers = dict(request.headers)
if 'Authorization' in sanitized_headers:
sanitized_headers['Authorization'] = '[FILTERED]'
if 'Cookie' in sanitized_headers:
sanitized_headers['Cookie'] = '[FILTERED]'
logging.info(
f"Request: {request.method} {request.path}\n"
f"Remote IP: {request.remote_addr}\n"
f"Headers: {sanitized_headers}"
)
DatabasePool.close()
... ...
from flask import request, redirect
from functools import wraps
import bleach
from utils.logger import app_logger as logging
def sanitize_input(text):
"""清理用户输入,防止XSS攻击"""
if text is None:
return None
return bleach.clean(str(text), strip=True)
def set_secure_headers(response):
"""设置安全响应头"""
response.headers['X-Content-Type-Options'] = 'nosniff'
response.headers['X-Frame-Options'] = 'SAMEORIGIN'
response.headers['X-XSS-Protection'] = '1; mode=block'
response.headers['Strict-Transport-Security'] = 'max-age=31536000; includeSubDomains'
response.headers['Content-Security-Policy'] = "default-src 'self'; script-src 'self' 'unsafe-inline' 'unsafe-eval'; style-src 'self' 'unsafe-inline';"
return response
def require_https():
"""强制HTTPS中间件"""
def decorator(f):
@wraps(f)
def decorated_function(*args, **kwargs):
if not request.is_secure and not request.is_localhost:
url = request.url.replace('http://', 'https://', 1)
return redirect(url, code=301)
return f(*args, **kwargs)
return decorated_function
return decorator
def log_request_info():
"""请求日志记录中间件"""
sanitized_headers = dict(request.headers)
if 'Authorization' in sanitized_headers:
sanitized_headers['Authorization'] = '[FILTERED]'
if 'Cookie' in sanitized_headers:
sanitized_headers['Cookie'] = '[FILTERED]'
logging.info(
f"Request: {request.method} {request.path}\n"
f"Remote IP: {request.remote_addr}\n"
f"Headers: {sanitized_headers}"
)
\ No newline at end of file
... ...
... ... @@ -88,3 +88,5 @@ xz=5.4.6=h8cc25b3_1
zipp=3.17.0=py38haa95532_0
zlib=1.2.13=h8cc25b3_1
zstd=1.5.5=hd43e919_2
DBUtils==3.0.2
bleach==6.1.0
... ...
import pymysql
from pymysql.cursors import DictCursor
from dbutils.pooled_db import PooledDB
from utils.logger import app_logger as logging
class DatabasePool:
_pool = None
@classmethod
def initialize(cls, db_config):
"""初始化数据库连接池"""
try:
cls._pool = PooledDB(
creator=pymysql,
maxconnections=10,
mincached=2,
maxcached=5,
maxshared=3,
blocking=True,
maxusage=None,
setsession=[],
ping=0,
host=db_config['host'],
port=db_config['port'],
user=db_config['user'],
password=db_config['password'],
database=db_config['database'],
charset=db_config['charset'],
cursorclass=DictCursor,
ssl=db_config.get('ssl')
)
logging.info("数据库连接池初始化成功")
except Exception as e:
logging.error(f"数据库连接池初始化失败: {e}")
raise
@classmethod
def get_connection(cls):
"""获取数据库连接"""
if cls._pool is None:
raise Exception("数据库连接池未初始化")
return cls._pool.connection()
@classmethod
def close(cls):
"""关闭数据库连接池"""
if cls._pool:
cls._pool._pool.close()
cls._pool = None
logging.info("数据库连接池已关闭")
\ No newline at end of file
... ...
from flask import render_template
from utils.logger import app_logger as logging
def register_error_handlers(app):
"""注册错误处理器"""
@app.errorhandler(404)
def not_found_error(error):
logging.warning(f"404错误: {request.url}")
return render_template('404.html'), 404
@app.errorhandler(500)
def internal_error(error):
logging.error(f"500错误: {error}")
return render_template('error.html',
error_code=500,
error_title='服务器错误',
error_message='服务器遇到了一个问题,请稍后再试。',
error_i18n_key='serverError'), 500
@app.errorhandler(403)
def forbidden_error(error):
logging.warning(f"403错误: {request.url}")
return render_template('error.html',
error_code=403,
error_title='禁止访问',
error_message='您没有权限访问此页面。',
error_i18n_key='forbidden'), 403
@app.errorhandler(400)
def bad_request_error(error):
logging.warning(f"400错误: {error}")
return render_template('error.html',
error_code=400,
error_title='错误请求',
error_message='服务器无法理解您的请求。',
error_i18n_key='badRequest'), 400
@app.errorhandler(Exception)
def handle_exception(error):
logging.error(f"未处理的异常: {error}")
return render_template('error.html',
error_code=500,
error_title='系统错误',
error_message='系统发生了一个未预期的错误。',
error_i18n_key='unexpectedError'), 500
\ No newline at end of file
... ...