Toggle navigation
Toggle navigation
This project
Loading...
Sign in
万朱浩
/
Venue-Ops
Go to a project
Toggle navigation
Projects
Groups
Snippets
Help
Toggle navigation pinning
Project
Activity
Repository
Pipelines
Graphs
Issues
0
Merge Requests
0
Wiki
Network
Create a new issue
Builds
Commits
Authored by
戒酒的李白
2025-03-08 00:17:42 +0800
Browse Files
Options
Browse Files
Download
Email Patches
Plain Diff
Commit
f81a71e97093c4dfed0e75d662e5099dacbc7eb4
f81a71e9
1 parent
5630b300
Comprehensive security enhancement, fix race conditions and injection vulnerabilities.
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
450 additions
and
343 deletions
app.py
utils/errorResponse.py
views/page/page.py
app.py
View file @
f81a71e
...
...
@@ -3,13 +3,20 @@ import re
import
getpass
import
pymysql
import
subprocess
from
flask
import
Flask
,
session
,
request
,
redirect
,
render_template
from
flask
import
Flask
,
session
,
request
,
redirect
,
render_template
,
jsonify
from
apscheduler.schedulers.background
import
BackgroundScheduler
from
pytz
import
utc
from
datetime
import
datetime
,
timedelta
import
time
from
utils.logger
import
app_logger
as
logging
from
utils.db_manager
import
DatabaseManager
import
secrets
from
dotenv
import
load_dotenv
from
functools
import
wraps
import
bleach
# 加载环境变量
load_dotenv
()
def
get_db_connection_interactive
():
"""
...
...
@@ -18,17 +25,17 @@ def get_db_connection_interactive():
"""
print
(
"请依次输入数据库连接信息(直接按回车使用默认值):"
)
host
=
input
(
" 1. 主机 (默认: localhost): "
)
or
"localhost"
port_str
=
input
(
" 2. 端口 (默认: 3306): "
)
or
"3306"
host
=
input
(
" 1. 主机 (默认: localhost): "
)
or
os
.
getenv
(
'DB_HOST'
,
'localhost'
)
port_str
=
input
(
" 2. 端口 (默认: 3306): "
)
or
os
.
getenv
(
'DB_PORT'
,
'3306'
)
try
:
port
=
int
(
port_str
)
except
ValueError
:
logging
.
warning
(
"端口号无效,使用默认端口 3306。"
)
port
=
3306
user
=
input
(
" 3. 用户名 (默认: root): "
)
or
"root"
password
=
getpass
.
getpass
(
" 4. 密码 (默认: 12345678): "
)
or
"12345678"
db_name
=
input
(
" 5. 数据库名 (默认: Weibo_PublicOpinion_AnalysisSystem): "
)
or
"Weibo_PublicOpinion_AnalysisSystem"
user
=
input
(
" 3. 用户名 (默认: root): "
)
or
os
.
getenv
(
'DB_USER'
,
'root'
)
password
=
getpass
.
getpass
(
" 4. 密码: "
)
or
os
.
getenv
(
'DB_PASSWORD'
,
''
)
db_name
=
input
(
" 5. 数据库名 (默认: Weibo_PublicOpinion_AnalysisSystem): "
)
or
os
.
getenv
(
'DB_NAME'
,
'Weibo_PublicOpinion_AnalysisSystem'
)
logging
.
info
(
f
"尝试连接到数据库: {user}@{host}:{port}/{db_name}"
)
...
...
@@ -40,237 +47,183 @@ def get_db_connection_interactive():
password
=
password
,
database
=
db_name
,
charset
=
'utf8mb4'
,
cursorclass
=
pymysql
.
cursors
.
DictCursor
# 返回字典格式
cursorclass
=
pymysql
.
cursors
.
DictCursor
,
ssl
=
{
'ssl'
:
{
'ca'
:
os
.
getenv
(
'DB_SSL_CA'
)}}
if
os
.
getenv
(
'DB_SSL_CA'
)
else
None
)
logging
.
info
(
"数据库连接成功。"
)
return
connection
except
pymysql
.
MySQLError
as
e
:
logging
.
error
(
f
"数据库连接失败: {e}"
)
exit
(
1
)
def
initialize_database
(
connection
,
sql_file_path
):
"""
执行 SQL 文件中的语句以初始化数据库。
:param connection: 已建立的数据库连接
:param sql_file_path: SQL 文件的路径
"""
try
:
with
open
(
sql_file_path
,
'r'
,
encoding
=
'utf8'
)
as
file
:
sql_commands
=
file
.
read
()
with
connection
.
cursor
()
as
cursor
:
for
statement
in
sql_commands
.
split
(
';'
):
statement
=
statement
.
strip
()
if
statement
:
cursor
.
execute
(
statement
)
connection
.
commit
()
logging
.
info
(
"数据库初始化成功。"
)
except
FileNotFoundError
:
logging
.
error
(
f
"SQL 文件未找到: {sql_file_path}"
)
exit
(
1
)
except
pymysql
.
MySQLError
as
e
:
logging
.
error
(
f
"执行 SQL 时出错: {e}"
)
connection
.
rollback
()
exit
(
1
)
except
Exception
as
e
:
logging
.
error
(
f
"初始化数据库时出错: {e}"
)
connection
.
rollback
()
exit
(
1
)
def
prompt_first_run
():
"""
询问用户是否首次运行,需要初始化数据库。
:return: Boolean,True 表示需要初始化数据库
"""
while
True
:
choice
=
input
(
"是否首次运行该项目,需要初始化数据库?(Y/n): "
)
.
strip
()
.
lower
()
if
choice
in
[
'y'
,
'yes'
,
''
]:
return
True
elif
choice
in
[
'n'
,
'no'
]:
return
False
else
:
print
(
"请输入 Y 或 N。"
)
raise
def
sanitize_input
(
text
):
"""清理用户输入,防止XSS攻击"""
if
text
is
None
:
return
None
return
bleach
.
clean
(
str
(
text
),
strip
=
True
)
def
set_secure_headers
(
response
):
"""设置安全响应头"""
response
.
headers
[
'X-Content-Type-Options'
]
=
'nosniff'
response
.
headers
[
'X-Frame-Options'
]
=
'SAMEORIGIN'
response
.
headers
[
'X-XSS-Protection'
]
=
'1; mode=block'
response
.
headers
[
'Strict-Transport-Security'
]
=
'max-age=31536000; includeSubDomains'
response
.
headers
[
'Content-Security-Policy'
]
=
"default-src 'self'"
return
response
# 初始化 Flask 应用
app
=
Flask
(
__name__
)
app
.
secret_key
=
'this is secret_key you know ?'
# 设置 Flask 的密钥,用于 session 加密
app
.
secret_key
=
os
.
getenv
(
'FLASK_SECRET_KEY'
,
secrets
.
token_hex
(
32
))
app
.
config
[
'SESSION_COOKIE_SECURE'
]
=
True
app
.
config
[
'SESSION_COOKIE_HTTPONLY'
]
=
True
app
.
config
[
'SESSION_COOKIE_SAMESITE'
]
=
'Lax'
app
.
config
[
'PERMANENT_SESSION_LIFETIME'
]
=
timedelta
(
hours
=
2
)
# 导入蓝图
from
views.page
import
page
from
views.user
import
user
from
views.spider_control
import
spider_bp
app
.
register_blueprint
(
page
.
pb
)
# 注册页面蓝图
app
.
register_blueprint
(
user
.
ub
)
# 注册用户蓝图
app
.
register_blueprint
(
spider_bp
)
# 注册爬虫控制蓝图
app
.
register_blueprint
(
page
.
pb
)
app
.
register_blueprint
(
user
.
ub
)
app
.
register_blueprint
(
spider_bp
)
# 首页路由
,清空 session
# 首页路由
@app.route
(
'/'
)
def
hello_world
():
session
.
clear
()
# 清空 session,用户退出登录
return
"Session Cleared"
session
.
clear
()
return
redirect
(
'/user/login'
)
#
中间件:处理请求前的逻辑
#
请求前中间件
@app.before_request
def
before_request
():
# 检查是否是HTTPS
if
not
request
.
is_secure
and
not
app
.
debug
:
url
=
request
.
url
.
replace
(
'http://'
,
'https://'
,
1
)
return
redirect
(
url
,
code
=
301
)
# 如果请求的是静态文件路径,允许访问
if
request
.
path
.
startswith
(
'/static'
):
return
# 如果请求的是登录或注册页面,不需要会话验证
if
request
.
path
in
[
'/user/login'
,
'/user/register'
]:
return
# 如果 session 中没有用户名,重定向到登录页面
# 验证会话
if
not
session
.
get
(
'username'
):
return
redirect
(
'/user/login'
)
# 验证会话完整性
if
'client_info'
not
in
session
:
session
.
clear
()
return
redirect
(
'/user/login'
)
# 验证客户端信息
current_client
=
{
'ip'
:
request
.
remote_addr
,
'user_agent'
:
str
(
request
.
user_agent
)
}
stored_client
=
session
.
get
(
'client_info'
,
{})
if
(
current_client
[
'ip'
]
!=
stored_client
.
get
(
'ip'
)
or
current_client
[
'user_agent'
]
!=
stored_client
.
get
(
'user_agent'
)):
session
.
clear
()
return
redirect
(
'/user/login'
)
# 404 错误页面路由
@app.route
(
'/<path:path>'
)
def
catch_all
(
path
):
return
render_template
(
'404.html'
)
# 如果路径不存在,返回 404 页面
# 响应后中间件
@app.after_request
def
after_request
(
response
):
return
set_secure_headers
(
response
)
# 定义定时任务,运行爬虫脚本
def
run_script
():
current_dir
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
# 获取当前脚本的目录
spider_script
=
os
.
path
.
join
(
current_dir
,
'spider'
,
'main.py'
)
# 爬虫脚本路径
# cutComments_script = os.path.join(current_dir, 'utils', 'cutComments.py') # 评论处理脚本路径
# cipingTotal_script = os.path.join(current_dir, 'utils', 'cipingTotal.py') # 评分处理脚本路径
# 错误处理
@app.errorhandler
(
404
)
def
not_found_error
(
error
):
return
render_template
(
'404.html'
),
404
# 定义所有要运行的脚本
scripts
=
[
(
"Spider Script"
,
spider_script
),
# ("Cut Comments Script", cutComments_script),
# ("Ciping Total Script", cipingTotal_script)
]
@app.errorhandler
(
500
)
def
internal_error
(
error
):
return
render_template
(
'500.html'
),
500
# 执行所有脚本
for
script_name
,
script_path
in
scripts
:
try
:
logging
.
info
(
f
"Running {script_name}..."
)
subprocess
.
run
([
'python'
,
script_path
],
check
=
True
)
# 使用 subprocess 执行脚本
logging
.
info
(
f
"{script_name} finished successfully."
)
except
subprocess
.
CalledProcessError
as
e
:
logging
.
error
(
f
"An error occurred while running {script_name}: {e}"
)
@app.errorhandler
(
403
)
def
forbidden_error
(
error
):
return
render_template
(
'403.html'
),
403
# 新增功能:动态调度爬虫脚本
def
check_database_empty
():
"""
检查数据库中的指定表是否为空。
:return: 如果表为空则返回 True,否则返回 False
"""
try
:
connection
=
pymysql
.
connect
(
**
DB_CONFIG
)
with
connection
.
cursor
()
as
cursor
:
cursor
.
execute
(
"SELECT COUNT(*) as count FROM article"
)
result
=
cursor
.
fetchone
()
count
=
result
[
'count'
]
if
result
and
'count'
in
result
else
0
logging
.
info
(
f
"数据库中共有 {count} 条记录。"
)
return
count
==
0
except
pymysql
.
MySQLError
as
e
:
logging
.
error
(
f
"检查数据库失败: {e}"
)
return
True
# 连接失败时假设数据库为空,以防止阻塞
finally
:
if
'connection'
in
locals
():
connection
.
close
()
def
dynamic_crawl
():
"""
执行爬取任务并根据爬取耗时和获取的数据量动态调度下次爬取时间。
"""
try
:
start_time
=
time
.
time
()
logging
.
info
(
"开始爬取数据。"
)
run_script
()
# 执行爬虫脚本
end_time
=
time
.
time
()
duration
=
end_time
-
start_time
# 爬取耗时
# 获取爬取后数据库中记录的数量作为数据量
try
:
connection
=
pymysql
.
connect
(
**
DB_CONFIG
)
with
connection
.
cursor
()
as
cursor
:
cursor
.
execute
(
"SELECT COUNT(*) as count FROM article"
)
result
=
cursor
.
fetchone
()
data_fetched
=
result
[
'count'
]
if
result
and
'count'
in
result
else
0
logging
.
info
(
f
"爬取完成,耗时 {duration:.2f} 秒,数据库中共有 {data_fetched} 条记录。"
)
except
pymysql
.
MySQLError
as
e
:
logging
.
error
(
f
"获取数据量失败: {e}"
)
data_fetched
=
0
finally
:
if
'connection'
in
locals
():
connection
.
close
()
# 根据爬取耗时和数据量调整下次爬取时间
base_interval
=
5
*
60
*
60
# 5小时的基础时间间隔(秒)
if
duration
>
3600
:
# 爬取耗时超过1小时
next_interval
=
base_interval
+
duration
logging
.
info
(
f
"检测到长时间爬取。下次爬取将在 {next_interval/3600:.2f} 小时后执行。"
)
elif
data_fetched
<
50
:
# 获取的数据量少于50条
next_interval
=
base_interval
/
2
logging
.
info
(
f
"获取数据量较少。下次爬取将在 {next_interval/60:.2f} 分钟后执行。"
)
else
:
next_interval
=
base_interval
logging
.
info
(
f
"标准爬取完成。下次爬取将在 {next_interval/3600:.2f} 小时后执行。"
)
# 安排下次爬取任务
scheduler
.
add_job
(
dynamic_crawl
,
'date'
,
run_date
=
datetime
.
now
()
+
timedelta
(
seconds
=
next_interval
),
id
=
'dynamic_crawl'
)
except
Exception
as
e
:
logging
.
error
(
f
"动态爬取过程中发生错误: {e}"
)
@app.errorhandler
(
400
)
def
bad_request_error
(
error
):
return
render_template
(
'400.html'
),
400
# 数据库配置
,用于动态调度功能
# 数据库配置
DB_CONFIG
=
{
'host'
:
'localhost'
,
'user'
:
'root'
,
'password'
:
'12345678'
,
'database'
:
'Weibo_PublicOpinion_AnalysisSystem'
,
'port'
:
3306
,
'charset'
:
'utf8mb4'
'host'
:
os
.
getenv
(
'DB_HOST'
,
'localhost'
),
'user'
:
os
.
getenv
(
'DB_USER'
,
'root'
),
'password'
:
os
.
getenv
(
'DB_PASSWORD'
,
''
),
'database'
:
os
.
getenv
(
'DB_NAME'
,
'Weibo_PublicOpinion_AnalysisSystem'
),
'port'
:
int
(
os
.
getenv
(
'DB_PORT'
,
'3306'
)),
'charset'
:
'utf8mb4'
,
'ssl'
:
{
'ca'
:
os
.
getenv
(
'DB_SSL_CA'
)}
if
os
.
getenv
(
'DB_SSL_CA'
)
else
None
}
# 初始化数据库管理器
DatabaseManager
.
initialize
(
DB_CONFIG
)
# 主程序入口
if
__name__
==
'__main__'
:
# 检测是否需要初始化数据库
if
prompt_first_run
():
# 获取数据库连接
connection
=
get_db_connection_interactive
()
# 执行数据库初始化
sql_file
=
os
.
path
.
join
(
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
)),
'createTables.sql'
)
initialize_database
(
connection
,
sql_file
)
# 关闭数据库连接
connection
.
close
()
logging
.
info
(
"数据库连接已关闭。"
)
# 设置定时任务,动态执行爬虫脚本
scheduler
=
BackgroundScheduler
(
timezone
=
utc
)
# 创建后台任务调度器
scheduler
.
start
()
# 启动调度器
# 初始化调度:如果数据库为空,立即爬取;否则,按照基础时间间隔安排首次爬取
if
check_database_empty
():
logging
.
info
(
"数据库为空。立即开始初始爬取。"
)
dynamic_crawl
()
else
:
logging
.
info
(
"数据库已有数据。安排首次爬取。"
)
base_interval
=
5
*
60
*
60
# 5小时
scheduler
.
add_job
(
dynamic_crawl
,
'date'
,
run_date
=
datetime
.
now
()
+
timedelta
(
seconds
=
base_interval
),
id
=
'dynamic_crawl'
)
try
:
app
.
run
()
# 启动 Flask 应用
if
os
.
getenv
(
'INITIALIZE_DB'
,
'false'
)
.
lower
()
==
'true'
:
connection
=
get_db_connection_interactive
()
sql_file
=
os
.
path
.
join
(
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
)),
'createTables.sql'
)
initialize_database
(
connection
,
sql_file
)
connection
.
close
()
logging
.
info
(
"数据库初始化完成。"
)
except
Exception
as
e
:
logging
.
error
(
f
"数据库初始化失败: {e}"
)
exit
(
1
)
# 设置定时任务
try
:
scheduler
=
BackgroundScheduler
(
timezone
=
utc
)
scheduler
.
start
()
if
check_database_empty
():
logging
.
info
(
"数据库为空。立即开始初始爬取。"
)
dynamic_crawl
()
else
:
logging
.
info
(
"数据库已有数据。安排首次爬取。"
)
base_interval
=
int
(
os
.
getenv
(
'CRAWL_INTERVAL'
,
'18000'
))
# 默认5小时
scheduler
.
add_job
(
dynamic_crawl
,
'date'
,
run_date
=
datetime
.
now
()
+
timedelta
(
seconds
=
base_interval
),
id
=
'dynamic_crawl'
)
# 启动应用
app
.
run
(
host
=
os
.
getenv
(
'FLASK_HOST'
,
'127.0.0.1'
),
port
=
int
(
os
.
getenv
(
'FLASK_PORT'
,
'5000'
)),
ssl_context
=
'adhoc'
if
os
.
getenv
(
'ENABLE_HTTPS'
,
'false'
)
.
lower
()
==
'true'
else
None
)
except
Exception
as
e
:
logging
.
error
(
f
"应用启动失败: {e}"
)
if
'scheduler'
in
locals
():
scheduler
.
shutdown
()
exit
(
1
)
finally
:
scheduler
.
shutdown
()
# 确保在应用关闭时关闭调度器
if
'scheduler'
in
locals
():
scheduler
.
shutdown
()
#
设置日志记录,捕获应用的请求信息
#
请求日志记录
@app.before_request
def
log_request_info
():
# 记录每次请求的信息,便于调试和监控
logging
.
info
(
f
"Request: {request.method} {request.path}"
)
# 记录请求的方式(GET/POST)和路径
# 记录请求信息,但排除敏感数据
sanitized_headers
=
dict
(
request
.
headers
)
if
'Authorization'
in
sanitized_headers
:
sanitized_headers
[
'Authorization'
]
=
'[FILTERED]'
if
'Cookie'
in
sanitized_headers
:
sanitized_headers
[
'Cookie'
]
=
'[FILTERED]'
logging
.
info
(
f
"Request: {request.method} {request.path}
\n
"
f
"Remote IP: {request.remote_addr}
\n
"
f
"Headers: {sanitized_headers}"
)
...
...
utils/errorResponse.py
View file @
f81a71e
from
flask
import
render_template
def
errorResponse
(
errorMsg
):
return
render_template
(
'error.html'
,
errorMsg
=
errorMsg
)
\ No newline at end of file
from
flask
import
render_template
,
jsonify
import
bleach
import
re
def
sanitize_error_message
(
message
):
"""
清理和验证错误消息
"""
if
not
message
:
return
"发生未知错误"
# 移除任何敏感信息
message
=
re
.
sub
(
r'(password|token|key|secret)=[
\
w
\
-]+'
,
r'
\
1=[FILTERED]'
,
str
(
message
))
# 清理HTML和特殊字符
message
=
bleach
.
clean
(
message
,
strip
=
True
)
# 限制消息长度
return
message
[:
200
]
if
len
(
message
)
>
200
else
message
def
errorResponse
(
errorMsg
,
status_code
=
400
):
"""
统一的错误响应处理
:param errorMsg: 错误消息
:param status_code: HTTP状态码
:return: 错误响应
"""
safe_message
=
sanitize_error_message
(
errorMsg
)
if
'application/json'
in
request
.
headers
.
get
(
'Accept'
,
''
):
return
jsonify
({
'success'
:
False
,
'error'
:
safe_message
}),
status_code
return
render_template
(
'error.html'
,
errorMsg
=
safe_message
,
status_code
=
status_code
),
status_code
\ No newline at end of file
...
...
views/page/page.py
View file @
f81a71e
from
flask
import
Flask
,
session
,
render_template
,
redirect
,
Blueprint
,
request
,
jsonify
from
flask
import
Flask
,
session
,
render_template
,
redirect
,
Blueprint
,
request
,
jsonify
,
abort
from
utils.mynlp
import
SnowNLP
from
utils.getHomePageData
import
*
from
utils.getHotWordPageData
import
*
...
...
@@ -16,12 +16,60 @@ from sqlalchemy import create_engine
import
asyncio
import
torch
from
BCAT_front.predict
import
model_manager
from
functools
import
wraps
import
bleach
import
re
from
datetime
import
datetime
,
timedelta
pb
=
Blueprint
(
'page'
,
__name__
,
url_prefix
=
'/page'
,
template_folder
=
'templates'
)
def
sanitize_input
(
text
):
"""清理用户输入,防止XSS攻击"""
if
text
is
None
:
return
None
return
bleach
.
clean
(
str
(
text
),
strip
=
True
)
def
validate_csrf_token
():
"""验证CSRF令牌"""
token
=
request
.
form
.
get
(
'csrf_token'
)
stored_token
=
session
.
get
(
'csrf_token'
)
if
not
token
or
not
stored_token
or
token
!=
stored_token
:
return
False
return
True
def
login_required
(
f
):
@wraps
(
f
)
def
decorated_function
(
*
args
,
**
kwargs
):
if
'username'
not
in
session
:
return
redirect
(
'/user/login'
)
return
f
(
*
args
,
**
kwargs
)
return
decorated_function
def
api_login_required
(
f
):
@wraps
(
f
)
def
decorated_function
(
*
args
,
**
kwargs
):
if
'username'
not
in
session
:
return
jsonify
({
'error'
:
'Unauthorized'
}),
401
return
f
(
*
args
,
**
kwargs
)
return
decorated_function
def
rate_limit
(
f
):
@wraps
(
f
)
def
decorated_function
(
*
args
,
**
kwargs
):
key
=
f
"rate_limit:{request.remote_addr}:{f.__name__}"
current
=
int
(
redis_client
.
get
(
key
)
or
0
)
if
current
>=
100
:
# 每分钟100次请求限制
return
jsonify
({
'error'
:
'Too many requests'
}),
429
pipe
=
redis_client
.
pipeline
()
pipe
.
incr
(
key
)
pipe
.
expire
(
key
,
60
)
# 60秒后重置
pipe
.
execute
()
return
f
(
*
args
,
**
kwargs
)
return
decorated_function
# 设置设备
device
=
torch
.
device
(
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
)
...
...
@@ -37,14 +85,22 @@ except Exception as e:
logging
.
error
(
f
"模型加载失败: {e}"
)
# 数据库配置
DATABASE_URL
=
"sqlite:///ai_analysis.db"
DATABASE_URL
=
os
.
getenv
(
'DATABASE_URL'
,
"sqlite:///ai_analysis.db"
)
engine
=
create_engine
(
DATABASE_URL
)
AIAnalysis
.
metadata
.
create_all
(
engine
)
def
predict_sentiment
(
text
):
"""使用改进版模型预测单个文本的情感"""
try
:
predictions
,
probabilities
=
model_manager
.
predict_batch
([
text
])
if
not
text
or
len
(
text
.
strip
())
==
0
:
return
None
,
None
# 清理输入
cleaned_text
=
sanitize_input
(
text
)
if
not
cleaned_text
:
return
None
,
None
predictions
,
probabilities
=
model_manager
.
predict_batch
([
cleaned_text
])
if
predictions
is
not
None
and
len
(
predictions
)
>
0
:
return
predictions
[
0
],
probabilities
[
0
][
predictions
[
0
]]
return
None
,
None
...
...
@@ -53,55 +109,70 @@ def predict_sentiment(text):
return
None
,
None
@pb.route
(
'/home'
)
@login_required
def
home
():
username
=
session
.
get
(
'username'
)
articleLenMax
,
likeCountMaxAuthorName
,
cityMax
=
getHomeTagsData
()
commentsLikeCountTopFore
=
getHomeCommentsLikeCountTopFore
()
X
,
Y
=
getHomeArticleCreatedAtChart
()
typeChart
=
getHomeTypeChart
()
createAtChart
=
getHomeCommentCreatedChart
()
# getUserNameWordCloud()
return
render_template
(
'index.html'
,
username
=
username
,
articleLenMax
=
articleLenMax
,
likeCountMaxAuthorName
=
likeCountMaxAuthorName
,
cityMax
=
cityMax
,
commentsLikeCountTopFore
=
commentsLikeCountTopFore
,
xData
=
X
,
yData
=
Y
,
typeChart
=
typeChart
,
createAtChart
=
createAtChart
)
try
:
username
=
session
.
get
(
'username'
)
articleLenMax
,
likeCountMaxAuthorName
,
cityMax
=
getHomeTagsData
()
commentsLikeCountTopFore
=
getHomeCommentsLikeCountTopFore
()
X
,
Y
=
getHomeArticleCreatedAtChart
()
typeChart
=
getHomeTypeChart
()
createAtChart
=
getHomeCommentCreatedChart
()
return
render_template
(
'index.html'
,
username
=
username
,
articleLenMax
=
articleLenMax
,
likeCountMaxAuthorName
=
likeCountMaxAuthorName
,
cityMax
=
cityMax
,
commentsLikeCountTopFore
=
commentsLikeCountTopFore
,
xData
=
X
,
yData
=
Y
,
typeChart
=
typeChart
,
createAtChart
=
createAtChart
)
except
Exception
as
e
:
logging
.
error
(
f
"加载首页时发生错误: {e}"
)
return
render_template
(
'error.html'
,
error_message
=
"加载首页失败"
)
@pb.route
(
'/hotWord'
)
@login_required
def
hotWord
():
username
=
session
.
get
(
'username'
)
hotWordList
=
getAllHotWords
()
print
(
hotWordList
)
defaultHotWord
=
hotWordList
[
0
][
0
]
if
request
.
args
.
get
(
'hotWord'
):
defaultHotWord
=
request
.
args
.
get
(
'hotWord'
)
hotWordLen
=
getHotWordLen
(
defaultHotWord
)
X
,
Y
=
getHotWordPageCreatedAtCharData
(
defaultHotWord
)
sentences
=
''
value
=
SnowNLP
(
defaultHotWord
)
.
sentiments
if
value
==
0.5
:
sentences
=
'中性'
elif
value
>
0.5
:
sentences
=
'正面'
elif
value
<
0.5
:
sentences
=
'负面'
comments
=
getCommentFilterData
(
defaultHotWord
)
return
render_template
(
'hotWord.html'
,
username
=
username
,
hotWordList
=
hotWordList
,
defaultHotWord
=
defaultHotWord
,
hotWordLen
=
hotWordLen
,
sentences
=
sentences
,
xData
=
X
,
yData
=
Y
,
comments
=
comments
)
try
:
username
=
session
.
get
(
'username'
)
hotWordList
=
getAllHotWords
()
if
not
hotWordList
:
return
render_template
(
'error.html'
,
error_message
=
"无法获取热词列表"
)
defaultHotWord
=
sanitize_input
(
request
.
args
.
get
(
'hotWord'
,
hotWordList
[
0
][
0
]))
# 验证热词是否在列表中
if
not
any
(
defaultHotWord
in
word
for
word
in
hotWordList
):
return
abort
(
400
,
"无效的热词"
)
hotWordLen
=
getHotWordLen
(
defaultHotWord
)
X
,
Y
=
getHotWordPageCreatedAtCharData
(
defaultHotWord
)
value
=
SnowNLP
(
defaultHotWord
)
.
sentiments
if
value
==
0.5
:
sentences
=
'中性'
elif
value
>
0.5
:
sentences
=
'正面'
elif
value
<
0.5
:
sentences
=
'负面'
comments
=
getCommentFilterData
(
defaultHotWord
)
return
render_template
(
'hotWord.html'
,
username
=
username
,
hotWordList
=
hotWordList
,
defaultHotWord
=
defaultHotWord
,
hotWordLen
=
hotWordLen
,
sentences
=
sentences
,
xData
=
X
,
yData
=
Y
,
comments
=
comments
)
except
Exception
as
e
:
logging
.
error
(
f
"加载热词页面时发生错误: {e}"
)
return
render_template
(
'error.html'
,
error_message
=
"加载热词页面失败"
)
@pb.route
(
'/hotTopic'
)
def
hotTopic
():
...
...
@@ -127,18 +198,21 @@ def hotTopic():
yData
=
Y
,
comments
=
comments
)
@pb.route
(
'/tableData'
)
@login_required
def
tableData
():
username
=
session
.
get
(
'username'
)
defaultFlag
=
False
if
request
.
args
.
get
(
'flag'
):
defaultFlag
=
True
tableData
=
getTableDataList
(
defaultFlag
)
return
render_template
(
'tableData.html'
,
username
=
username
,
tableData
=
tableData
,
defaultFlag
=
defaultFlag
)
try
:
username
=
session
.
get
(
'username'
)
defaultFlag
=
bool
(
request
.
args
.
get
(
'flag'
,
False
))
tableData
=
getTableDataList
(
defaultFlag
)
return
render_template
(
'tableData.html'
,
username
=
username
,
tableData
=
tableData
,
defaultFlag
=
defaultFlag
)
except
Exception
as
e
:
logging
.
error
(
f
"加载表格数据时发生错误: {e}"
)
return
render_template
(
'error.html'
,
error_message
=
"加载表格数据失败"
)
@pb.route
(
'/articleChar'
)
def
articleChar
():
...
...
@@ -160,63 +234,89 @@ def articleChar():
x2Data
=
x2Data
,
y2Data
=
y2Data
)
@pb.route
(
'/ipChar'
)
@login_required
def
ipChar
():
username
=
session
.
get
(
'username'
)
articleRegionData
=
getIPByArticleRegion
()
commentRegionData
=
getIPByCommentsRegion
()
return
render_template
(
'ipChar.html'
,
username
=
username
,
articleRegionData
=
articleRegionData
,
commentRegionData
=
commentRegionData
)
try
:
username
=
session
.
get
(
'username'
)
articleRegionData
=
getIPByArticleRegion
()
commentRegionData
=
getIPByCommentsRegion
()
return
render_template
(
'ipChar.html'
,
username
=
username
,
articleRegionData
=
articleRegionData
,
commentRegionData
=
commentRegionData
)
except
Exception
as
e
:
logging
.
error
(
f
"加载IP统计时发生错误: {e}"
)
return
render_template
(
'error.html'
,
error_message
=
"加载IP统计失败"
)
@pb.route
(
'/commentChar'
)
@login_required
def
commentChar
():
username
=
session
.
get
(
'username'
)
X
,
Y
=
getCommentDataOne
()
genderPieData
=
getCommentDataTwo
()
return
render_template
(
'commentChar.html'
,
username
=
username
,
xData
=
X
,
yData
=
Y
,
genderPieData
=
genderPieData
)
try
:
username
=
session
.
get
(
'username'
)
X
,
Y
=
getCommentDataOne
()
genderPieData
=
getCommentDataTwo
()
return
render_template
(
'commentChar.html'
,
username
=
username
,
xData
=
X
,
yData
=
Y
,
genderPieData
=
genderPieData
)
except
Exception
as
e
:
logging
.
error
(
f
"加载评论统计时发生错误: {e}"
)
return
render_template
(
'error.html'
,
error_message
=
"加载评论统计失败"
)
@pb.route
(
'/yuqingChar'
)
@login_required
def
yuqingChar
():
username
=
session
.
get
(
'username'
)
# 获取模型选择参数
model_type
=
request
.
args
.
get
(
'model'
,
'pro'
)
# 默认使用改进模型
X
,
Y
,
biedata
=
getYuQingCharDataOne
()
biedata1
,
biedata2
=
getYuQingCharDataTwo
(
model_type
)
x1Data
,
y1Data
=
getYuQingCharDataThree
()
return
render_template
(
'yuqingChar.html'
,
username
=
username
,
xData
=
X
,
yData
=
Y
,
biedata
=
biedata
,
biedata1
=
biedata1
,
biedata2
=
biedata2
,
x1Data
=
x1Data
,
y1Data
=
y1Data
,
model_type
=
model_type
)
try
:
username
=
session
.
get
(
'username'
)
model_type
=
sanitize_input
(
request
.
args
.
get
(
'model'
,
'pro'
))
# 验证模型类型
if
model_type
not
in
[
'pro'
,
'basic'
]:
return
abort
(
400
,
"无效的模型类型"
)
X
,
Y
,
biedata
=
getYuQingCharDataOne
()
biedata1
,
biedata2
=
getYuQingCharDataTwo
(
model_type
)
x1Data
,
y1Data
=
getYuQingCharDataThree
()
return
render_template
(
'yuqingChar.html'
,
username
=
username
,
xData
=
X
,
yData
=
Y
,
biedata
=
biedata
,
biedata1
=
biedata1
,
biedata2
=
biedata2
,
x1Data
=
x1Data
,
y1Data
=
y1Data
,
model_type
=
model_type
)
except
Exception
as
e
:
logging
.
error
(
f
"加载舆情统计时发生错误: {e}"
)
return
render_template
(
'error.html'
,
error_message
=
"加载舆情统计失败"
)
@pb.route
(
'/yuqingpredict'
)
@login_required
def
yuqingpredict
():
try
:
username
=
session
.
get
(
'username'
)
TopicList
=
getAllTopicData
()
defaultTopic
=
TopicList
[
0
][
0
]
if
request
.
args
.
get
(
'Topic'
):
defaultTopic
=
request
.
args
.
get
(
'Topic'
)
if
not
TopicList
:
return
render_template
(
'error.html'
,
error_message
=
"无法获取话题列表"
)
defaultTopic
=
sanitize_input
(
request
.
args
.
get
(
'Topic'
,
TopicList
[
0
][
0
]))
# 验证话题是否在列表中
if
not
any
(
defaultTopic
in
topic
for
topic
in
TopicList
):
return
abort
(
400
,
"无效的话题"
)
TopicLen
=
getTopicLen
(
defaultTopic
)
X
,
Y
=
getTopicCreatedAtandpredictData
(
defaultTopic
)
# 获取模型选择参数
model_type
=
request
.
args
.
get
(
'model'
,
'pro'
)
# 默认使用改进模型
model_type
=
sanitize_input
(
request
.
args
.
get
(
'model'
,
'pro'
))
if
model_type
not
in
[
'pro'
,
'basic'
]:
return
abort
(
400
,
"无效的模型类型"
)
# 尝试从缓存获取预测结果
cache_key
=
f
"{defaultTopic}_{model_type}"
...
...
@@ -226,7 +326,6 @@ def yuqingpredict():
sentences
=
cached_result
else
:
if
model_type
==
'basic'
:
# 使用基础模型(SnowNLP)
value
=
SnowNLP
(
defaultTopic
)
.
sentiments
if
value
==
0.5
:
sentences
=
'中性'
...
...
@@ -235,7 +334,6 @@ def yuqingpredict():
elif
value
<
0.5
:
sentences
=
'负面'
else
:
# 使用改进模型
predicted_label
,
confidence
=
predict_sentiment
(
defaultTopic
)
if
predicted_label
is
not
None
:
sentences
=
'良好'
if
predicted_label
==
0
else
'不良'
...
...
@@ -248,26 +346,30 @@ def yuqingpredict():
prediction_cache
.
set
(
cache_key
,
sentences
)
comments
=
getCommentFilterDataTopic
(
defaultTopic
)
return
render_template
(
'yuqingpredict.html'
,
username
=
username
,
hotWordList
=
TopicList
,
defaultHotWord
=
defaultTopic
,
hotWordLen
=
TopicLen
,
sentences
=
sentences
,
xData
=
X
,
yData
=
Y
,
comments
=
comments
,
model_type
=
model_type
)
username
=
username
,
TopicList
=
TopicList
,
defaultTopic
=
defaultTopic
,
TopicLen
=
TopicLen
,
sentences
=
sentences
,
xData
=
X
,
yData
=
Y
,
comments
=
comments
,
model_type
=
model_type
)
except
Exception
as
e
:
logging
.
error
(
f
"舆情预测页面渲染失败: {e}"
)
return
render_template
(
'error.html'
,
error_message
=
"加载舆情预测页面失败,请稍后重试"
)
logging
.
error
(
f
"加载舆情预测时发生错误: {e}"
)
return
render_template
(
'error.html'
,
error_message
=
"加载舆情预测失败"
)
@pb.route
(
'/articleCloud'
)
@login_required
def
articleCloud
():
username
=
session
.
get
(
'username'
)
return
render_template
(
'articleContentCloud.html'
,
username
=
username
)
try
:
username
=
session
.
get
(
'username'
)
return
render_template
(
'articleContentCloud.html'
,
username
=
username
)
except
Exception
as
e
:
logging
.
error
(
f
"加载文章云图时发生错误: {e}"
)
return
render_template
(
'error.html'
,
error_message
=
"加载文章云图失败"
)
@pb.route
(
'/page/index'
)
def
index
():
...
...
@@ -306,15 +408,28 @@ def articleChar(id):
return
render_template
(
'error.html'
,
error_message
=
"加载文章详情失败"
)
@pb.route
(
'/api/analyze_messages'
,
methods
=
[
'POST'
])
@api_login_required
@rate_limit
async
def
analyze_messages
():
try
:
# 获取请求参数
if
not
validate_csrf_token
():
return
jsonify
({
'error'
:
'Invalid CSRF token'
}),
403
data
=
request
.
get_json
()
batch_size
=
data
.
get
(
'batch_size'
,
50
)
model_type
=
data
.
get
(
'model_type'
,
'gpt-3.5-turbo'
)
analysis_depth
=
data
.
get
(
'analysis_depth'
,
'standard'
)
if
not
data
:
return
jsonify
({
'error'
:
'No data provided'
}),
400
batch_size
=
min
(
int
(
data
.
get
(
'batch_size'
,
50
)),
100
)
# 限制批量大小
model_type
=
sanitize_input
(
data
.
get
(
'model_type'
,
'gpt-3.5-turbo'
))
analysis_depth
=
sanitize_input
(
data
.
get
(
'analysis_depth'
,
'standard'
))
# 验证参数
if
model_type
not
in
[
'gpt-3.5-turbo'
,
'gpt-4'
]:
return
jsonify
({
'error'
:
'Invalid model type'
}),
400
if
analysis_depth
not
in
[
'basic'
,
'standard'
,
'deep'
]:
return
jsonify
({
'error'
:
'Invalid analysis depth'
}),
400
# 获取最近的消息
messages
=
getRecentMessages
(
batch_size
)
if
not
messages
:
return
jsonify
({
...
...
@@ -322,7 +437,6 @@ async def analyze_messages():
'error'
:
'没有找到需要分析的消息'
}),
404
# 调用AI进行分析
analysis_results
=
await
ai_analyzer
.
analyze_messages
(
messages
=
messages
,
batch_size
=
batch_size
,
...
...
@@ -336,22 +450,27 @@ async def analyze_messages():
'error'
:
'分析过程中出现错误'
}),
500
# 保存到数据库
with
Session
(
engine
)
as
session
:
for
result
in
analysis_results
:
analysis
=
AIAnalysis
(
message_id
=
result
[
'message_id'
],
sentiment
=
result
[
'sentiment'
],
sentiment_score
=
float
(
result
[
'sentiment_score'
]),
keywords
=
result
[
'keywords'
],
key_points
=
result
[
'key_points'
],
influence_analysis
=
result
[
'influence_analysis'
],
risk_level
=
result
[
'risk_level'
]
)
session
.
add
(
analysis
)
session
.
commit
()
try
:
with
Session
(
engine
)
as
session
:
for
result
in
analysis_results
:
analysis
=
AIAnalysis
(
message_id
=
result
[
'message_id'
],
sentiment
=
result
[
'sentiment'
],
sentiment_score
=
float
(
result
[
'sentiment_score'
]),
keywords
=
result
[
'keywords'
],
key_points
=
result
[
'key_points'
],
influence_analysis
=
result
[
'influence_analysis'
],
risk_level
=
result
[
'risk_level'
]
)
session
.
add
(
analysis
)
session
.
commit
()
except
Exception
as
e
:
logging
.
error
(
f
"保存分析结果时出错: {e}"
)
return
jsonify
({
'success'
:
False
,
'error'
:
'保存分析结果失败'
}),
500
# 格式化结果用于显示
display_results
=
[
ai_analyzer
.
format_analysis_for_display
(
result
)
for
result
in
analysis_results
...
...
@@ -359,27 +478,25 @@ async def analyze_messages():
return
jsonify
({
'success'
:
True
,
'data'
:
display_results
,
'meta'
:
{
'total_messages'
:
len
(
messages
),
'analyzed_messages'
:
len
(
analysis_results
),
'batch_size'
:
batch_size
,
'model_type'
:
model_type
,
'analysis_depth'
:
analysis_depth
}
'data'
:
display_results
})
except
Exception
as
e
:
logging
.
error
(
f
"
AI分析过程出错
: {e}"
)
logging
.
error
(
f
"
分析消息时发生错误
: {e}"
)
return
jsonify
({
'success'
:
False
,
'error'
:
str
(
e
)
}),
500
@pb.route
(
'/api/get_analysis/<int:message_id>'
)
@api_login_required
@rate_limit
def
get_message_analysis
(
message_id
):
"""获取特定消息的分析结果"""
try
:
if
not
message_id
or
message_id
<
1
:
return
jsonify
({
'error'
:
'Invalid message ID'
}),
400
with
Session
(
engine
)
as
session
:
analysis
=
session
.
query
(
AIAnalysis
)
\
.
filter
(
AIAnalysis
.
message_id
==
message_id
)
\
...
...
Please
register
or
login
to post a comment