戒酒的李白

Comprehensive security enhancement, fix race conditions and injection vulnerabilities.

@@ -3,13 +3,20 @@ import re @@ -3,13 +3,20 @@ import re
3 import getpass 3 import getpass
4 import pymysql 4 import pymysql
5 import subprocess 5 import subprocess
6 -from flask import Flask, session, request, redirect, render_template 6 +from flask import Flask, session, request, redirect, render_template, jsonify
7 from apscheduler.schedulers.background import BackgroundScheduler 7 from apscheduler.schedulers.background import BackgroundScheduler
8 from pytz import utc 8 from pytz import utc
9 from datetime import datetime, timedelta 9 from datetime import datetime, timedelta
10 import time 10 import time
11 from utils.logger import app_logger as logging 11 from utils.logger import app_logger as logging
12 from utils.db_manager import DatabaseManager 12 from utils.db_manager import DatabaseManager
  13 +import secrets
  14 +from dotenv import load_dotenv
  15 +from functools import wraps
  16 +import bleach
  17 +
  18 +# 加载环境变量
  19 +load_dotenv()
13 20
14 def get_db_connection_interactive(): 21 def get_db_connection_interactive():
15 """ 22 """
@@ -18,17 +25,17 @@ def get_db_connection_interactive(): @@ -18,17 +25,17 @@ def get_db_connection_interactive():
18 """ 25 """
19 print("请依次输入数据库连接信息(直接按回车使用默认值):") 26 print("请依次输入数据库连接信息(直接按回车使用默认值):")
20 27
21 - host = input(" 1. 主机 (默认: localhost): ") or "localhost"  
22 - port_str = input(" 2. 端口 (默认: 3306): ") or "3306" 28 + host = input(" 1. 主机 (默认: localhost): ") or os.getenv('DB_HOST', 'localhost')
  29 + port_str = input(" 2. 端口 (默认: 3306): ") or os.getenv('DB_PORT', '3306')
23 try: 30 try:
24 port = int(port_str) 31 port = int(port_str)
25 except ValueError: 32 except ValueError:
26 logging.warning("端口号无效,使用默认端口 3306。") 33 logging.warning("端口号无效,使用默认端口 3306。")
27 port = 3306 34 port = 3306
28 35
29 - user = input(" 3. 用户名 (默认: root): ") or "root"  
30 - password = getpass.getpass(" 4. 密码 (默认: 12345678): ") or "12345678"  
31 - db_name = input(" 5. 数据库名 (默认: Weibo_PublicOpinion_AnalysisSystem): ") or "Weibo_PublicOpinion_AnalysisSystem" 36 + user = input(" 3. 用户名 (默认: root): ") or os.getenv('DB_USER', 'root')
  37 + password = getpass.getpass(" 4. 密码: ") or os.getenv('DB_PASSWORD', '')
  38 + db_name = input(" 5. 数据库名 (默认: Weibo_PublicOpinion_AnalysisSystem): ") or os.getenv('DB_NAME', 'Weibo_PublicOpinion_AnalysisSystem')
32 39
33 logging.info(f"尝试连接到数据库: {user}@{host}:{port}/{db_name}") 40 logging.info(f"尝试连接到数据库: {user}@{host}:{port}/{db_name}")
34 41
@@ -40,237 +47,183 @@ def get_db_connection_interactive(): @@ -40,237 +47,183 @@ def get_db_connection_interactive():
40 password=password, 47 password=password,
41 database=db_name, 48 database=db_name,
42 charset='utf8mb4', 49 charset='utf8mb4',
43 - cursorclass=pymysql.cursors.DictCursor # 返回字典格式 50 + cursorclass=pymysql.cursors.DictCursor,
  51 + ssl={'ssl': {'ca': os.getenv('DB_SSL_CA')}} if os.getenv('DB_SSL_CA') else None
44 ) 52 )
45 logging.info("数据库连接成功。") 53 logging.info("数据库连接成功。")
46 return connection 54 return connection
47 except pymysql.MySQLError as e: 55 except pymysql.MySQLError as e:
48 logging.error(f"数据库连接失败: {e}") 56 logging.error(f"数据库连接失败: {e}")
49 - exit(1)  
50 -  
51 -def initialize_database(connection, sql_file_path):  
52 - """  
53 - 执行 SQL 文件中的语句以初始化数据库。  
54 -  
55 - :param connection: 已建立的数据库连接  
56 - :param sql_file_path: SQL 文件的路径  
57 - """  
58 - try:  
59 - with open(sql_file_path, 'r', encoding='utf8') as file:  
60 - sql_commands = file.read()  
61 -  
62 - with connection.cursor() as cursor:  
63 - for statement in sql_commands.split(';'):  
64 - statement = statement.strip()  
65 - if statement:  
66 - cursor.execute(statement)  
67 - connection.commit()  
68 - logging.info("数据库初始化成功。")  
69 - except FileNotFoundError:  
70 - logging.error(f"SQL 文件未找到: {sql_file_path}")  
71 - exit(1)  
72 - except pymysql.MySQLError as e:  
73 - logging.error(f"执行 SQL 时出错: {e}")  
74 - connection.rollback()  
75 - exit(1)  
76 - except Exception as e:  
77 - logging.error(f"初始化数据库时出错: {e}")  
78 - connection.rollback()  
79 - exit(1)  
80 -  
81 -def prompt_first_run():  
82 - """  
83 - 询问用户是否首次运行,需要初始化数据库。  
84 -  
85 - :return: Boolean,True 表示需要初始化数据库  
86 - """  
87 - while True:  
88 - choice = input("是否首次运行该项目,需要初始化数据库?(Y/n): ").strip().lower()  
89 - if choice in ['y', 'yes', '']:  
90 - return True  
91 - elif choice in ['n', 'no']:  
92 - return False  
93 - else:  
94 - print("请输入 Y 或 N。") 57 + raise
  58 +
def sanitize_input(text):
    """Strip markup from user-supplied text to guard against XSS.

    ``None`` is passed through unchanged. Any other value is converted with
    ``str()`` and cleaned by ``bleach.clean`` with ``strip=True``.
    """
    return None if text is None else bleach.clean(str(text), strip=True)
  64 +
def set_secure_headers(response):
    """Set the fixed security headers on *response* and return it.

    Each header is assigned through ``response.headers[...]``, so a value
    already present under the same name is overwritten.
    """
    security_headers = (
        ('X-Content-Type-Options', 'nosniff'),
        ('X-Frame-Options', 'SAMEORIGIN'),
        ('X-XSS-Protection', '1; mode=block'),
        ('Strict-Transport-Security', 'max-age=31536000; includeSubDomains'),
        ('Content-Security-Policy', "default-src 'self'"),
    )
    for header_name, header_value in security_headers:
        response.headers[header_name] = header_value
    return response
95 73
# Initialise the Flask application.
app = Flask(__name__)

# Session signing key. A value from the environment is preferred; an unset
# OR empty FLASK_SECRET_KEY falls back to a freshly generated key so the
# application never runs with an empty secret. The fallback is logged
# because a per-process key invalidates existing sessions on restart.
_secret_key = os.getenv('FLASK_SECRET_KEY')
if not _secret_key:
    _secret_key = secrets.token_hex(32)
    logging.warning("未设置 FLASK_SECRET_KEY,已生成临时密钥;重启后现有会话将失效。")
app.secret_key = _secret_key

# Session cookie hardening and a two-hour lifetime for permanent sessions.
app.config['SESSION_COOKIE_SECURE'] = True
app.config['SESSION_COOKIE_HTTPONLY'] = True
app.config['SESSION_COOKIE_SAMESITE'] = 'Lax'
app.config['PERMANENT_SESSION_LIFETIME'] = timedelta(hours=2)
99 81
100 # 导入蓝图 82 # 导入蓝图
101 from views.page import page 83 from views.page import page
102 from views.user import user 84 from views.user import user
103 from views.spider_control import spider_bp 85 from views.spider_control import spider_bp
104 -app.register_blueprint(page.pb) # 注册页面蓝图  
105 -app.register_blueprint(user.ub) # 注册用户蓝图  
106 -app.register_blueprint(spider_bp) # 注册爬虫控制蓝图 86 +app.register_blueprint(page.pb)
  87 +app.register_blueprint(user.ub)
  88 +app.register_blueprint(spider_bp)
107 89
108 -# 首页路由,清空 session 90 +# 首页路由
@app.route('/')
def hello_world():
    """Clear the current session and redirect to the login page."""
    session.clear()
    login_redirect = redirect('/user/login')
    return login_redirect
113 95
114 -# 中间件:处理请求前的逻辑 96 +# 请求前中间件
@app.before_request
def before_request():
    """Gate every request: optional HTTPS upgrade, then session checks.

    Returns a redirect response when the request must not proceed, or
    ``None`` to let Flask continue with normal dispatch.

    Order of checks:
      1. When HTTPS is enabled (``ENABLE_HTTPS=true``, the same switch the
         start-up code uses for ``ssl_context``) and the app is not in debug
         mode, plain-HTTP requests are redirected to their ``https://`` URL.
      2. Static files and the login / register pages bypass session checks.
      3. A session without ``username`` is sent to the login page.
      4. The session must carry a ``client_info`` dict whose ``ip`` and
         ``user_agent`` match the current request; otherwise the session is
         cleared and the client is sent to the login page.
    """
    # Only force HTTPS when the server is actually configured to serve it;
    # redirecting unconditionally sends clients to a port with no TLS.
    https_enabled = os.getenv('ENABLE_HTTPS', 'false').lower() == 'true'
    if https_enabled and not request.is_secure and not app.debug:
        url = request.url.replace('http://', 'https://', 1)
        return redirect(url, code=301)

    # Static assets are always reachable.
    if request.path.startswith('/static'):
        return None

    # Login and registration pages need no session.
    if request.path in ['/user/login', '/user/register']:
        return None

    # Require an authenticated session.
    if not session.get('username'):
        return redirect('/user/login')

    # NOTE(review): `client_info` is presumably stored at login by the user
    # blueprint, which is not visible here - confirm it is always set.
    stored_client = session.get('client_info')
    if not isinstance(stored_client, dict):
        # Missing or malformed client binding: treat the session as invalid.
        session.clear()
        return redirect('/user/login')

    # Reject the session if it is replayed from another IP or user agent.
    current_client = {
        'ip': request.remote_addr,
        'user_agent': str(request.user_agent)
    }
    if (current_client['ip'] != stored_client.get('ip') or
            current_client['user_agent'] != stored_client.get('user_agent')):
        session.clear()
        return redirect('/user/login')

    return None
128 132
129 -# 404 错误页面路由  
130 -@app.route('/<path:path>')  
131 -def catch_all(path):  
132 - return render_template('404.html') # 如果路径不存在,返回 404 页面 133 +# 响应后中间件
@app.after_request
def after_request(response):
    """Pass every outgoing response through ``set_secure_headers``."""
    secured_response = set_secure_headers(response)
    return secured_response
133 137
134 -# 定义定时任务,运行爬虫脚本  
135 -def run_script():  
136 - current_dir = os.path.dirname(os.path.abspath(__file__)) # 获取当前脚本的目录  
137 - spider_script = os.path.join(current_dir, 'spider', 'main.py') # 爬虫脚本路径  
138 - # cutComments_script = os.path.join(current_dir, 'utils', 'cutComments.py') # 评论处理脚本路径  
139 - # cipingTotal_script = os.path.join(current_dir, 'utils', 'cipingTotal.py') # 评分处理脚本路径 138 +# 错误处理
@app.errorhandler(404)
def not_found_error(error):
    """Render ``404.html`` with HTTP status 404."""
    page = render_template('404.html')
    return page, 404
140 142
141 - # 定义所有要运行的脚本  
142 - scripts = [  
143 - ("Spider Script", spider_script),  
144 - # ("Cut Comments Script", cutComments_script),  
145 - # ("Ciping Total Script", cipingTotal_script)  
@app.errorhandler(500)
def internal_error(error):
    """Render ``500.html`` with HTTP status 500."""
    page = render_template('500.html')
    return page, 500
147 146
148 - # 执行所有脚本  
149 - for script_name, script_path in scripts:  
150 - try:  
151 - logging.info(f"Running {script_name}...")  
152 - subprocess.run(['python', script_path], check=True) # 使用 subprocess 执行脚本  
153 - logging.info(f"{script_name} finished successfully.")  
154 - except subprocess.CalledProcessError as e:  
@app.errorhandler(403)
def forbidden_error(error):
    """Render ``403.html`` with HTTP status 403."""
    page = render_template('403.html')
    return page, 403
156 150
157 -# 新增功能:动态调度爬虫脚本  
158 -def check_database_empty():  
159 - """  
160 - 检查数据库中的指定表是否为空。  
161 -  
162 - :return: 如果表为空则返回 True,否则返回 False  
163 - """  
164 - try:  
165 - connection = pymysql.connect(**DB_CONFIG)  
166 - with connection.cursor() as cursor:  
167 - cursor.execute("SELECT COUNT(*) as count FROM article")  
168 - result = cursor.fetchone()  
169 - count = result['count'] if result and 'count' in result else 0  
170 - logging.info(f"数据库中共有 {count} 条记录。")  
171 - return count == 0  
172 - except pymysql.MySQLError as e:  
173 - logging.error(f"检查数据库失败: {e}")  
174 - return True # 连接失败时假设数据库为空,以防止阻塞  
175 - finally:  
176 - if 'connection' in locals():  
177 - connection.close()  
178 -  
179 -def dynamic_crawl():  
180 - """  
181 - 执行爬取任务并根据爬取耗时和获取的数据量动态调度下次爬取时间。  
182 - """  
183 - try:  
184 - start_time = time.time()  
185 - logging.info("开始爬取数据。")  
186 -  
187 - run_script() # 执行爬虫脚本  
188 -  
189 - end_time = time.time()  
190 - duration = end_time - start_time # 爬取耗时  
191 -  
192 - # 获取爬取后数据库中记录的数量作为数据量  
193 - try:  
194 - connection = pymysql.connect(**DB_CONFIG)  
195 - with connection.cursor() as cursor:  
196 - cursor.execute("SELECT COUNT(*) as count FROM article")  
197 - result = cursor.fetchone()  
198 - data_fetched = result['count'] if result and 'count' in result else 0  
199 - logging.info(f"爬取完成,耗时 {duration:.2f} 秒,数据库中共有 {data_fetched} 条记录。")  
200 - except pymysql.MySQLError as e:  
201 - logging.error(f"获取数据量失败: {e}")  
202 - data_fetched = 0  
203 - finally:  
204 - if 'connection' in locals():  
205 - connection.close()  
206 -  
207 - # 根据爬取耗时和数据量调整下次爬取时间  
208 - base_interval = 5 * 60 * 60 # 5小时的基础时间间隔(秒)  
209 -  
210 - if duration > 3600: # 爬取耗时超过1小时  
211 - next_interval = base_interval + duration  
212 - logging.info(f"检测到长时间爬取。下次爬取将在 {next_interval/3600:.2f} 小时后执行。")  
213 - elif data_fetched < 50: # 获取的数据量少于50条  
214 - next_interval = base_interval / 2  
215 - logging.info(f"获取数据量较少。下次爬取将在 {next_interval/60:.2f} 分钟后执行。")  
216 - else:  
217 - next_interval = base_interval  
218 - logging.info(f"标准爬取完成。下次爬取将在 {next_interval/3600:.2f} 小时后执行。")  
219 -  
220 - # 安排下次爬取任务  
221 - scheduler.add_job(dynamic_crawl, 'date', run_date=datetime.now() + timedelta(seconds=next_interval), id='dynamic_crawl')  
222 -  
223 - except Exception as e:  
@app.errorhandler(400)
def bad_request_error(error):
    """Render ``400.html`` with HTTP status 400."""
    page = render_template('400.html')
    return page, 400
225 154
226 -# 数据库配置,用于动态调度功能 155 +# 数据库配置
# Parse the port defensively: a malformed DB_PORT must not crash the module
# at import time. This mirrors the fallback used by the interactive
# connection helper (warn, then use 3306).
try:
    _db_port = int(os.getenv('DB_PORT', '3306'))
except ValueError:
    logging.warning("端口号无效,使用默认端口 3306。")
    _db_port = 3306

# Path to a CA bundle; TLS options are only supplied when it is configured.
_db_ssl_ca = os.getenv('DB_SSL_CA')

# Connection settings read from the environment, with development defaults.
DB_CONFIG = {
    'host': os.getenv('DB_HOST', 'localhost'),
    'user': os.getenv('DB_USER', 'root'),
    'password': os.getenv('DB_PASSWORD', ''),
    'database': os.getenv('DB_NAME', 'Weibo_PublicOpinion_AnalysisSystem'),
    'port': _db_port,
    'charset': 'utf8mb4',
    'ssl': {'ca': _db_ssl_ca} if _db_ssl_ca else None
}
235 165
236 # 初始化数据库管理器 166 # 初始化数据库管理器
237 DatabaseManager.initialize(DB_CONFIG) 167 DatabaseManager.initialize(DB_CONFIG)
238 168
239 -# 主程序入口  
if __name__ == '__main__':
    # Local import: `sys.exit` is used instead of the interactive-only
    # `exit` helper, which the `site` module may not install.
    import sys

    # Optional one-off schema initialisation, enabled via INITIALIZE_DB=true.
    # NOTE(review): `initialize_database`, `check_database_empty` and
    # `dynamic_crawl` are not defined in the visible part of this module -
    # confirm they are still defined or imported before this block runs.
    try:
        if os.getenv('INITIALIZE_DB', 'false').lower() == 'true':
            connection = get_db_connection_interactive()
            try:
                sql_file = os.path.join(
                    os.path.dirname(os.path.abspath(__file__)),
                    'createTables.sql'
                )
                initialize_database(connection, sql_file)
            finally:
                # Release the connection even if initialisation fails.
                connection.close()
            logging.info("数据库初始化完成。")
    except Exception as e:
        logging.error(f"数据库初始化失败: {e}")
        sys.exit(1)

    # Start the background scheduler, queue the first crawl, then serve.
    scheduler = None
    try:
        scheduler = BackgroundScheduler(timezone=utc)
        scheduler.start()

        if check_database_empty():
            logging.info("数据库为空。立即开始初始爬取。")
            dynamic_crawl()
        else:
            logging.info("数据库已有数据。安排首次爬取。")
            base_interval = int(os.getenv('CRAWL_INTERVAL', '18000'))  # default: 5 hours
            scheduler.add_job(
                dynamic_crawl,
                'date',
                # Timezone-aware "now": the scheduler runs in UTC, so a naive
                # local timestamp would be shifted by the local UTC offset.
                run_date=datetime.now(utc) + timedelta(seconds=base_interval),
                id='dynamic_crawl'
            )

        # Start the web application (blocks until the server stops).
        app.run(
            host=os.getenv('FLASK_HOST', '127.0.0.1'),
            port=int(os.getenv('FLASK_PORT', '5000')),
            ssl_context='adhoc' if os.getenv('ENABLE_HTTPS', 'false').lower() == 'true' else None
        )
    except Exception as e:
        logging.error(f"应用启动失败: {e}")
        sys.exit(1)
    finally:
        # Single shutdown point: this runs on normal exit and on the
        # SystemExit raised above, so the scheduler is stopped exactly once.
        if scheduler is not None and scheduler.running:
            scheduler.shutdown()
271 214
272 -# 设置日志记录,捕获应用的请求信息 215 +# 请求日志记录
# NOTE(review): this hook is declared after the `if __name__ == '__main__'`
# section, whose `app.run()` call blocks - confirm the hook is actually
# registered when the module is started as a script.
@app.before_request
def log_request_info():
    """Log method, path, remote IP and headers of each incoming request.

    The values of the ``Authorization`` and ``Cookie`` headers are replaced
    with ``[FILTERED]`` before logging.
    """
    redacted_names = ('Authorization', 'Cookie')
    sanitized_headers = {
        name: '[FILTERED]' if name in redacted_names else value
        for name, value in dict(request.headers).items()
    }

    summary = "\n".join((
        f"Request: {request.method} {request.path}",
        f"Remote IP: {request.remote_addr}",
        f"Headers: {sanitized_headers}",
    ))
    logging.info(summary)
1 -from flask import render_template  
2 -def errorResponse(errorMsg):  
3 - return render_template('error.html',errorMsg=errorMsg)  
  1 +from flask import render_template, jsonify
  2 +import bleach
  3 +import re
  4 +
def sanitize_error_message(message):
    """Return a redacted, markup-free, length-capped error message.

    :param message: raw error message (any object; converted with ``str()``)
    :return: the generic "unknown error" text for falsy input; otherwise the
             message with secret-looking ``name=value`` pairs redacted,
             cleaned by ``bleach.clean(strip=True)`` and cut to 200 characters
    """
    if not message:
        return "发生未知错误"

    # Redact the value of password/token/key/secret assignments. Matching is
    # case-insensitive and the value extends to the next whitespace or
    # delimiter, so a value containing punctuation is not partially leaked.
    message = re.sub(
        r'(password|token|key|secret)=[^\s&;,]+',
        r'\1=[FILTERED]',
        str(message),
        flags=re.IGNORECASE,
    )

    # Strip HTML and special markup.
    message = bleach.clean(message, strip=True)

    # Cap the length (slicing is a no-op for shorter messages).
    return message[:200]
  20 +
def errorResponse(errorMsg, status_code=400):
    """Build a uniform error response.

    :param errorMsg: error message; passed through ``sanitize_error_message``
    :param status_code: HTTP status code to return (default 400)
    :return: ``(body, status_code)``, where the body is JSON when the
             request's ``Accept`` header contains ``application/json`` and
             the rendered ``error.html`` template otherwise
    """
    # Imported locally: the module-level flask import only brings in
    # render_template and jsonify, so `request` would otherwise be undefined.
    from flask import request

    safe_message = sanitize_error_message(errorMsg)

    accepts_json = 'application/json' in request.headers.get('Accept', '')
    if accepts_json:
        payload = {
            'success': False,
            'error': safe_message
        }
        return jsonify(payload), status_code

    page = render_template(
        'error.html',
        errorMsg=safe_message,
        status_code=status_code
    )
    return page, status_code
1 -from flask import Flask, session, render_template, redirect, Blueprint, request, jsonify 1 +from flask import Flask, session, render_template, redirect, Blueprint, request, jsonify, abort
2 from utils.mynlp import SnowNLP 2 from utils.mynlp import SnowNLP
3 from utils.getHomePageData import * 3 from utils.getHomePageData import *
4 from utils.getHotWordPageData import * 4 from utils.getHotWordPageData import *
@@ -16,12 +16,60 @@ from sqlalchemy import create_engine @@ -16,12 +16,60 @@ from sqlalchemy import create_engine
16 import asyncio 16 import asyncio
17 import torch 17 import torch
18 from BCAT_front.predict import model_manager 18 from BCAT_front.predict import model_manager
  19 +from functools import wraps
  20 +import bleach
  21 +import re
  22 +from datetime import datetime, timedelta
19 23
20 pb = Blueprint('page', 24 pb = Blueprint('page',
21 __name__, 25 __name__,
22 url_prefix='/page', 26 url_prefix='/page',
23 template_folder='templates') 27 template_folder='templates')
24 28
def sanitize_input(text):
    """Clean user input with ``bleach`` to guard against XSS.

    Returns ``None`` for ``None``; otherwise the ``str()`` form of *text*
    cleaned by ``bleach.clean`` with ``strip=True``.
    """
    if text is not None:
        as_text = str(text)
        return bleach.clean(as_text, strip=True)
    return None
  34 +
def validate_csrf_token():
    """Check the submitted CSRF token against the one stored in the session.

    Reads ``csrf_token`` from the request form and from the session.

    :return: ``True`` only when both tokens are present and equal
    """
    # Local import keeps this block self-contained.
    import hmac

    token = request.form.get('csrf_token')
    stored_token = session.get('csrf_token')
    if not token or not stored_token:
        return False

    # Constant-time comparison avoids leaking how many leading characters
    # matched. Both sides are encoded to bytes because compare_digest
    # rejects non-ASCII str input, which a client could otherwise submit.
    return hmac.compare_digest(
        str(token).encode('utf-8'),
        str(stored_token).encode('utf-8'),
    )
  42 +
def login_required(f):
    """Decorator for page views: redirect to the login page when the
    session has no ``username`` key, otherwise call the wrapped view."""
    @wraps(f)
    def guarded_view(*args, **kwargs):
        if 'username' in session:
            return f(*args, **kwargs)
        return redirect('/user/login')
    return guarded_view
  50 +
def api_login_required(f):
    """Decorator for API views: answer with a JSON ``Unauthorized`` error
    and status 401 when the session has no ``username`` key, otherwise call
    the wrapped view."""
    @wraps(f)
    def guarded_api_view(*args, **kwargs):
        if 'username' in session:
            return f(*args, **kwargs)
        return jsonify({'error': 'Unauthorized'}), 401
    return guarded_api_view
  58 +
def rate_limit(f):
    """Decorator limiting each client IP to 100 calls per minute per view.

    The counter lives under ``rate_limit:<remote_addr>:<view name>``. It is
    incremented before the limit is checked, so concurrent requests cannot
    all pass a stale read, and the 60-second expiry is set only when the key
    has none, so the window is fixed rather than extended by every request.

    Responds with a JSON error and status 429 once the limit is exceeded.

    NOTE(review): `redis_client` is not defined in the visible part of this
    module; it is assumed to be a redis-py style client - confirm.
    """
    @wraps(f)
    def decorated_function(*args, **kwargs):
        key = f"rate_limit:{request.remote_addr}:{f.__name__}"

        # Count this request and read the remaining TTL in one round trip.
        pipe = redis_client.pipeline()
        pipe.incr(key)
        pipe.ttl(key)
        current, ttl = pipe.execute()

        # A negative TTL means the key has no expiry yet (new window, or an
        # earlier expire call was lost): start the 60-second window now.
        if ttl is None or ttl < 0:
            redis_client.expire(key, 60)

        if int(current) > 100:  # limit: 100 requests per minute
            return jsonify({'error': 'Too many requests'}), 429
        return f(*args, **kwargs)
    return decorated_function
  72 +
25 # 设置设备 73 # 设置设备
26 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 74 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
27 75
@@ -37,14 +85,22 @@ except Exception as e: @@ -37,14 +85,22 @@ except Exception as e:
37 logging.error(f"模型加载失败: {e}") 85 logging.error(f"模型加载失败: {e}")
38 86
39 # 数据库配置 87 # 数据库配置
40 -DATABASE_URL = "sqlite:///ai_analysis.db" 88 +DATABASE_URL = os.getenv('DATABASE_URL', "sqlite:///ai_analysis.db")
41 engine = create_engine(DATABASE_URL) 89 engine = create_engine(DATABASE_URL)
42 AIAnalysis.metadata.create_all(engine) 90 AIAnalysis.metadata.create_all(engine)
43 91
44 def predict_sentiment(text): 92 def predict_sentiment(text):
45 """使用改进版模型预测单个文本的情感""" 93 """使用改进版模型预测单个文本的情感"""
46 try: 94 try:
47 - predictions, probabilities = model_manager.predict_batch([text]) 95 + if not text or len(text.strip()) == 0:
  96 + return None, None
  97 +
  98 + # 清理输入
  99 + cleaned_text = sanitize_input(text)
  100 + if not cleaned_text:
  101 + return None, None
  102 +
  103 + predictions, probabilities = model_manager.predict_batch([cleaned_text])
48 if predictions is not None and len(predictions) > 0: 104 if predictions is not None and len(predictions) > 0:
49 return predictions[0], probabilities[0][predictions[0]] 105 return predictions[0], probabilities[0][predictions[0]]
50 return None, None 106 return None, None
@@ -53,55 +109,70 @@ def predict_sentiment(text): @@ -53,55 +109,70 @@ def predict_sentiment(text):
53 return None, None 109 return None, None
54 110
@pb.route('/home')
@login_required
def home():
    """Render the dashboard (``index.html``) from the home-page data helpers.

    Any exception is logged and answered with ``error.html``.
    """
    try:
        context = {'username': session.get('username')}
        (context['articleLenMax'],
         context['likeCountMaxAuthorName'],
         context['cityMax']) = getHomeTagsData()
        context['commentsLikeCountTopFore'] = getHomeCommentsLikeCountTopFore()
        context['xData'], context['yData'] = getHomeArticleCreatedAtChart()
        context['typeChart'] = getHomeTypeChart()
        context['createAtChart'] = getHomeCommentCreatedChart()

        return render_template('index.html', **context)
    except Exception as e:
        logging.error(f"加载首页时发生错误: {e}")
        return render_template('error.html', error_message="加载首页失败")
75 135
@pb.route('/hotWord')
@login_required
def hotWord():
    """Render the hot-word page for the requested (or default) hot word.

    The ``hotWord`` query parameter selects the word. When it is missing or
    empty, the first entry of the hot-word list is used. A word that is not
    in the list is answered with ``error.html`` and status 400. Any other
    failure is logged and answered with ``error.html``.
    """
    try:
        username = session.get('username')
        hotWordList = getAllHotWords()
        if not hotWordList:
            return render_template('error.html', error_message="无法获取热词列表")

        # `or` (not a .get default) so an empty `?hotWord=` also falls back.
        defaultHotWord = sanitize_input(request.args.get('hotWord') or hotWordList[0][0])

        # Only accept words present in the list. The error is returned
        # directly instead of via abort(): abort() raises an exception that
        # the `except Exception` below would swallow, turning the intended
        # 400 into a generic error page.
        # NOTE(review): assumes each row of hotWordList is a sequence that
        # contains the word - confirm against getAllHotWords().
        if not any(defaultHotWord in word for word in hotWordList):
            return render_template('error.html', error_message="无效的热词"), 400

        hotWordLen = getHotWordLen(defaultHotWord)
        X, Y = getHotWordPageCreatedAtCharData(defaultHotWord)

        # Default keeps `sentences` bound even if no comparison below holds.
        sentences = ''
        value = SnowNLP(defaultHotWord).sentiments
        if value == 0.5:
            sentences = '中性'
        elif value > 0.5:
            sentences = '正面'
        elif value < 0.5:
            sentences = '负面'

        comments = getCommentFilterData(defaultHotWord)

        return render_template('hotWord.html',
                               username=username,
                               hotWordList=hotWordList,
                               defaultHotWord=defaultHotWord,
                               hotWordLen=hotWordLen,
                               sentences=sentences,
                               xData=X,
                               yData=Y,
                               comments=comments)
    except Exception as e:
        logging.error(f"加载热词页面时发生错误: {e}")
        return render_template('error.html', error_message="加载热词页面失败")
105 176
106 @pb.route('/hotTopic') 177 @pb.route('/hotTopic')
107 def hotTopic(): 178 def hotTopic():
@@ -127,18 +198,21 @@ def hotTopic(): @@ -127,18 +198,21 @@ def hotTopic():
127 yData=Y, 198 yData=Y,
128 comments=comments) 199 comments=comments)
129 200
130 -  
@pb.route('/tableData')
@login_required
def tableData():
    """Render ``tableData.html`` with the rows from ``getTableDataList``.

    Any non-empty ``flag`` query value turns the flag on. Any exception is
    logged and answered with ``error.html``.
    """
    try:
        username = session.get('username')
        defaultFlag = True if request.args.get('flag', False) else False
        rows = getTableDataList(defaultFlag)

        context = {
            'username': username,
            'tableData': rows,
            'defaultFlag': defaultFlag,
        }
        return render_template('tableData.html', **context)
    except Exception as e:
        logging.error(f"加载表格数据时发生错误: {e}")
        return render_template('error.html', error_message="加载表格数据失败")
142 216
143 @pb.route('/articleChar') 217 @pb.route('/articleChar')
144 def articleChar(): 218 def articleChar():
@@ -160,63 +234,89 @@ def articleChar(): @@ -160,63 +234,89 @@ def articleChar():
160 x2Data=x2Data, 234 x2Data=x2Data,
161 y2Data=y2Data) 235 y2Data=y2Data)
162 236
163 -  
@pb.route('/ipChar')
@login_required
def ipChar():
    """Render ``ipChar.html`` with article and comment region data.

    Any exception is logged and answered with ``error.html``.
    """
    try:
        context = {'username': session.get('username')}
        context['articleRegionData'] = getIPByArticleRegion()
        context['commentRegionData'] = getIPByCommentsRegion()

        return render_template('ipChar.html', **context)
    except Exception as e:
        logging.error(f"加载IP统计时发生错误: {e}")
        return render_template('error.html', error_message="加载IP统计失败")
174 252
@pb.route('/commentChar')
@login_required
def commentChar():
    """Render ``commentChar.html`` with the comment chart data.

    Any exception is logged and answered with ``error.html``.
    """
    try:
        context = {'username': session.get('username')}
        context['xData'], context['yData'] = getCommentDataOne()
        context['genderPieData'] = getCommentDataTwo()

        return render_template('commentChar.html', **context)
    except Exception as e:
        logging.error(f"加载评论统计时发生错误: {e}")
        return render_template('error.html', error_message="加载评论统计失败")
186 269
@pb.route('/yuqingChar')
@login_required
def yuqingChar():
    """Render the public-opinion statistics page.

    The ``model`` query parameter selects the model and must be ``pro``
    (the default) or ``basic``. Any other value is answered with
    ``error.html`` and status 400. Any other failure is logged and answered
    with ``error.html``.
    """
    try:
        username = session.get('username')
        model_type = sanitize_input(request.args.get('model', 'pro'))

        # Validate the model type. The error is returned directly instead of
        # via abort(): abort() raises an exception that the
        # `except Exception` below would swallow, losing the 400 status.
        if model_type not in ['pro', 'basic']:
            return render_template('error.html', error_message="无效的模型类型"), 400

        X, Y, biedata = getYuQingCharDataOne()
        biedata1, biedata2 = getYuQingCharDataTwo(model_type)
        x1Data, y1Data = getYuQingCharDataThree()

        return render_template('yuqingChar.html',
                               username=username,
                               xData=X,
                               yData=Y,
                               biedata=biedata,
                               biedata1=biedata1,
                               biedata2=biedata2,
                               x1Data=x1Data,
                               y1Data=y1Data,
                               model_type=model_type)
    except Exception as e:
        logging.error(f"加载舆情统计时发生错误: {e}")
        return render_template('error.html', error_message="加载舆情统计失败")
206 298
207 @pb.route('/yuqingpredict') 299 @pb.route('/yuqingpredict')
  300 +@login_required
208 def yuqingpredict(): 301 def yuqingpredict():
209 try: 302 try:
210 username = session.get('username') 303 username = session.get('username')
211 TopicList = getAllTopicData() 304 TopicList = getAllTopicData()
212 - defaultTopic = TopicList[0][0]  
213 - if request.args.get('Topic'):  
214 - defaultTopic = request.args.get('Topic') 305 + if not TopicList:
  306 + return render_template('error.html', error_message="无法获取话题列表")
  307 +
  308 + defaultTopic = sanitize_input(request.args.get('Topic', TopicList[0][0]))
  309 +
  310 + # 验证话题是否在列表中
  311 + if not any(defaultTopic in topic for topic in TopicList):
  312 + return abort(400, "无效的话题")
  313 +
215 TopicLen = getTopicLen(defaultTopic) 314 TopicLen = getTopicLen(defaultTopic)
216 X, Y = getTopicCreatedAtandpredictData(defaultTopic) 315 X, Y = getTopicCreatedAtandpredictData(defaultTopic)
217 316
218 - # 获取模型选择参数  
219 - model_type = request.args.get('model', 'pro') # 默认使用改进模型 317 + model_type = sanitize_input(request.args.get('model', 'pro'))
  318 + if model_type not in ['pro', 'basic']:
  319 + return abort(400, "无效的模型类型")
220 320
221 # 尝试从缓存获取预测结果 321 # 尝试从缓存获取预测结果
222 cache_key = f"{defaultTopic}_{model_type}" 322 cache_key = f"{defaultTopic}_{model_type}"
@@ -226,7 +326,6 @@ def yuqingpredict(): @@ -226,7 +326,6 @@ def yuqingpredict():
226 sentences = cached_result 326 sentences = cached_result
227 else: 327 else:
228 if model_type == 'basic': 328 if model_type == 'basic':
229 - # 使用基础模型(SnowNLP)  
230 value = SnowNLP(defaultTopic).sentiments 329 value = SnowNLP(defaultTopic).sentiments
231 if value == 0.5: 330 if value == 0.5:
232 sentences = '中性' 331 sentences = '中性'
@@ -235,7 +334,6 @@ def yuqingpredict(): @@ -235,7 +334,6 @@ def yuqingpredict():
235 elif value < 0.5: 334 elif value < 0.5:
236 sentences = '负面' 335 sentences = '负面'
237 else: 336 else:
238 - # 使用改进模型  
239 predicted_label, confidence = predict_sentiment(defaultTopic) 337 predicted_label, confidence = predict_sentiment(defaultTopic)
240 if predicted_label is not None: 338 if predicted_label is not None:
241 sentences = '良好' if predicted_label == 0 else '不良' 339 sentences = '良好' if predicted_label == 0 else '不良'
@@ -248,26 +346,30 @@ def yuqingpredict(): @@ -248,26 +346,30 @@ def yuqingpredict():
248 prediction_cache.set(cache_key, sentences) 346 prediction_cache.set(cache_key, sentences)
249 347
250 comments = getCommentFilterDataTopic(defaultTopic) 348 comments = getCommentFilterDataTopic(defaultTopic)
  349 +
251 return render_template('yuqingpredict.html', 350 return render_template('yuqingpredict.html',
252 - username=username,  
253 - hotWordList=TopicList,  
254 - defaultHotWord=defaultTopic,  
255 - hotWordLen=TopicLen,  
256 - sentences=sentences,  
257 - xData=X,  
258 - yData=Y,  
259 - comments=comments,  
260 - model_type=model_type) 351 + username=username,
  352 + TopicList=TopicList,
  353 + defaultTopic=defaultTopic,
  354 + TopicLen=TopicLen,
  355 + sentences=sentences,
  356 + xData=X,
  357 + yData=Y,
  358 + comments=comments,
  359 + model_type=model_type)
261 except Exception as e: 360 except Exception as e:
262 - logging.error(f"舆情预测页面渲染失败: {e}")  
263 - return render_template('error.html', error_message="加载舆情预测页面失败,请稍后重试")  
264 - 361 + logging.error(f"加载舆情预测时发生错误: {e}")
  362 + return render_template('error.html', error_message="加载舆情预测失败")
265 363
266 @pb.route('/articleCloud') 364 @pb.route('/articleCloud')
  365 +@login_required
267 def articleCloud(): 366 def articleCloud():
268 - username = session.get('username')  
269 - return render_template('articleContentCloud.html', username=username)  
270 - 367 + try:
  368 + username = session.get('username')
  369 + return render_template('articleContentCloud.html', username=username)
  370 + except Exception as e:
  371 + logging.error(f"加载文章云图时发生错误: {e}")
  372 + return render_template('error.html', error_message="加载文章云图失败")
271 373
272 @pb.route('/page/index') 374 @pb.route('/page/index')
273 def index(): 375 def index():
@@ -306,15 +408,28 @@ def articleChar(id): @@ -306,15 +408,28 @@ def articleChar(id):
306 return render_template('error.html', error_message="加载文章详情失败") 408 return render_template('error.html', error_message="加载文章详情失败")
307 409
308 @pb.route('/api/analyze_messages', methods=['POST']) 410 @pb.route('/api/analyze_messages', methods=['POST'])
  411 +@api_login_required
  412 +@rate_limit
309 async def analyze_messages(): 413 async def analyze_messages():
310 try: 414 try:
311 - # 获取请求参数 415 + if not validate_csrf_token():
  416 + return jsonify({'error': 'Invalid CSRF token'}), 403
  417 +
312 data = request.get_json() 418 data = request.get_json()
313 - batch_size = data.get('batch_size', 50)  
314 - model_type = data.get('model_type', 'gpt-3.5-turbo')  
315 - analysis_depth = data.get('analysis_depth', 'standard') 419 + if not data:
  420 + return jsonify({'error': 'No data provided'}), 400
  421 +
  422 + batch_size = min(int(data.get('batch_size', 50)), 100) # 限制批量大小
  423 + model_type = sanitize_input(data.get('model_type', 'gpt-3.5-turbo'))
  424 + analysis_depth = sanitize_input(data.get('analysis_depth', 'standard'))
  425 +
  426 + # 验证参数
  427 + if model_type not in ['gpt-3.5-turbo', 'gpt-4']:
  428 + return jsonify({'error': 'Invalid model type'}), 400
  429 +
  430 + if analysis_depth not in ['basic', 'standard', 'deep']:
  431 + return jsonify({'error': 'Invalid analysis depth'}), 400
316 432
317 - # 获取最近的消息  
318 messages = getRecentMessages(batch_size) 433 messages = getRecentMessages(batch_size)
319 if not messages: 434 if not messages:
320 return jsonify({ 435 return jsonify({
@@ -322,7 +437,6 @@ async def analyze_messages(): @@ -322,7 +437,6 @@ async def analyze_messages():
322 'error': '没有找到需要分析的消息' 437 'error': '没有找到需要分析的消息'
323 }), 404 438 }), 404
324 439
325 - # 调用AI进行分析  
326 analysis_results = await ai_analyzer.analyze_messages( 440 analysis_results = await ai_analyzer.analyze_messages(
327 messages=messages, 441 messages=messages,
328 batch_size=batch_size, 442 batch_size=batch_size,
@@ -336,22 +450,27 @@ async def analyze_messages(): @@ -336,22 +450,27 @@ async def analyze_messages():
336 'error': '分析过程中出现错误' 450 'error': '分析过程中出现错误'
337 }), 500 451 }), 500
338 452
339 - # 保存到数据库  
340 - with Session(engine) as session:  
341 - for result in analysis_results:  
342 - analysis = AIAnalysis(  
343 - message_id=result['message_id'],  
344 - sentiment=result['sentiment'],  
345 - sentiment_score=float(result['sentiment_score']),  
346 - keywords=result['keywords'],  
347 - key_points=result['key_points'],  
348 - influence_analysis=result['influence_analysis'],  
349 - risk_level=result['risk_level']  
350 - )  
351 - session.add(analysis)  
352 - session.commit() 453 + try:
  454 + with Session(engine) as session:
  455 + for result in analysis_results:
  456 + analysis = AIAnalysis(
  457 + message_id=result['message_id'],
  458 + sentiment=result['sentiment'],
  459 + sentiment_score=float(result['sentiment_score']),
  460 + keywords=result['keywords'],
  461 + key_points=result['key_points'],
  462 + influence_analysis=result['influence_analysis'],
  463 + risk_level=result['risk_level']
  464 + )
  465 + session.add(analysis)
  466 + session.commit()
  467 + except Exception as e:
  468 + logging.error(f"保存分析结果时出错: {e}")
  469 + return jsonify({
  470 + 'success': False,
  471 + 'error': '保存分析结果失败'
  472 + }), 500
353 473
354 - # 格式化结果用于显示  
355 display_results = [ 474 display_results = [
356 ai_analyzer.format_analysis_for_display(result) 475 ai_analyzer.format_analysis_for_display(result)
357 for result in analysis_results 476 for result in analysis_results
@@ -359,27 +478,25 @@ async def analyze_messages(): @@ -359,27 +478,25 @@ async def analyze_messages():
359 478
360 return jsonify({ 479 return jsonify({
361 'success': True, 480 'success': True,
362 - 'data': display_results,  
363 - 'meta': {  
364 - 'total_messages': len(messages),  
365 - 'analyzed_messages': len(analysis_results),  
366 - 'batch_size': batch_size,  
367 - 'model_type': model_type,  
368 - 'analysis_depth': analysis_depth  
369 - } 481 + 'data': display_results
370 }) 482 })
371 - 483 +
372 except Exception as e: 484 except Exception as e:
373 - logging.error(f"AI分析过程出错: {e}") 485 + logging.error(f"分析消息时发生错误: {e}")
374 return jsonify({ 486 return jsonify({
375 'success': False, 487 'success': False,
376 'error': str(e) 488 'error': str(e)
377 }), 500 489 }), 500
378 490
379 @pb.route('/api/get_analysis/<int:message_id>') 491 @pb.route('/api/get_analysis/<int:message_id>')
  492 +@api_login_required
  493 +@rate_limit
380 def get_message_analysis(message_id): 494 def get_message_analysis(message_id):
381 """获取特定消息的分析结果""" 495 """获取特定消息的分析结果"""
382 try: 496 try:
  497 + if not message_id or message_id < 1:
  498 + return jsonify({'error': 'Invalid message ID'}), 400
  499 +
383 with Session(engine) as session: 500 with Session(engine) as session:
384 analysis = session.query(AIAnalysis)\ 501 analysis = session.query(AIAnalysis)\
385 .filter(AIAnalysis.message_id == message_id)\ 502 .filter(AIAnalysis.message_id == message_id)\