戒酒的李白

Optimize database connection management in the spider module, fix WebSocket handling, and add parameter validation.
@@ -9,6 +9,7 @@ from pytz import utc
9 from datetime import datetime, timedelta 9 from datetime import datetime, timedelta
10 import time 10 import time
11 from utils.logger import app_logger as logging 11 from utils.logger import app_logger as logging
  12 +from utils.db_manager import DatabaseManager
12 13
13 def get_db_connection_interactive(): 14 def get_db_connection_interactive():
14 """ 15 """
@@ -232,6 +233,9 @@ DB_CONFIG = {
232 'charset': 'utf8mb4' 233 'charset': 'utf8mb4'
233 } 234 }
234 235
  236 +# 初始化数据库管理器
  237 +DatabaseManager.initialize(DB_CONFIG)
  238 +
235 # 主程序入口 239 # 主程序入口
236 if __name__ == '__main__': 240 if __name__ == '__main__':
237 # 检测是否需要初始化数据库 241 # 检测是否需要初始化数据库
@@ -10,6 +10,7 @@ import logging
10 from bs4 import BeautifulSoup 10 from bs4 import BeautifulSoup
11 from datetime import datetime 11 from datetime import datetime
12 from utils.logger import spider_logger as logging 12 from utils.logger import spider_logger as logging
  13 +from utils.db_manager import DatabaseManager
13 14
14 def spiderData(): 15 def spiderData():
15 if not os.path.exists(navAddr): 16 if not os.path.exists(navAddr):
@@ -26,6 +27,7 @@ class SpiderData:
26 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36' 27 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
27 } 28 }
28 self.base_url = 'https://s.weibo.com' 29 self.base_url = 'https://s.weibo.com'
  30 + self.db = DatabaseManager()
29 31
30 def crawl_topic(self, topic, depth=3, interval=5, max_retries=3, timeout=30): 32 def crawl_topic(self, topic, depth=3, interval=5, max_retries=3, timeout=30):
31 """ 33 """
@@ -37,7 +39,17 @@ class SpiderData:
37 :param max_retries: 最大重试次数 39 :param max_retries: 最大重试次数
38 :param timeout: 请求超时时间(秒) 40 :param timeout: 请求超时时间(秒)
39 """ 41 """
40 - logging.info(f"开始爬取话题: {topic}") 42 + # 参数验证
  43 + if not isinstance(depth, int) or depth < 1 or depth > 10:
  44 + raise ValueError("爬取深度必须在1-10页之间")
  45 + if not isinstance(interval, int) or interval < 3 or interval > 30:
  46 + raise ValueError("请求间隔必须在3-30秒之间")
  47 + if not isinstance(max_retries, int) or max_retries < 1 or max_retries > 5:
  48 + raise ValueError("最大重试次数必须在1-5次之间")
  49 + if not isinstance(timeout, int) or timeout < 10 or timeout > 60:
  50 + raise ValueError("请求超时时间必须在10-60秒之间")
  51 +
  52 + logging.info(f"开始爬取话题: {topic}, 参数: depth={depth}, interval={interval}, max_retries={max_retries}, timeout={timeout}")
41 53
42 for page in range(1, depth + 1): 54 for page in range(1, depth + 1):
43 retries = 0 55 retries = 0
@@ -140,11 +152,34 @@ class SpiderData:
140 152
141 :param data: 要保存的数据字典 153 :param data: 要保存的数据字典
142 """ 154 """
  155 + connection = None
143 try: 156 try:
144 - # TODO: 实现数据库保存逻辑  
145 - logging.info(f"保存数据: {data}") 157 + connection = self.db.get_connection()
  158 +
  159 + with connection.cursor() as cursor:
  160 + # 插入文章数据
  161 + sql = """
  162 + INSERT INTO article (content, user_name, publish_time, forward_count,
  163 + comment_count, like_count, crawl_time)
  164 + VALUES (%s, %s, %s, %s, %s, %s, %s)
  165 + """
  166 + cursor.execute(sql, (
  167 + data['content'],
  168 + data['user_name'],
  169 + data['publish_time'],
  170 + data['forward_count'],
  171 + data['comment_count'],
  172 + data['like_count'],
  173 + data['crawl_time']
  174 + ))
  175 +
  176 + connection.commit()
  177 + logging.info(f"成功保存微博数据: {data['content'][:30]}...")
  178 +
146 except Exception as e: 179 except Exception as e:
147 logging.error(f"保存数据时出错: {e}") 180 logging.error(f"保存数据时出错: {e}")
  181 + if connection:
  182 + connection.rollback()
148 183
149 if __name__ == '__main__': 184 if __name__ == '__main__':
150 spiderData() 185 spiderData()
  1 +import pymysql
  2 +from pymysql.cursors import DictCursor
  3 +
  4 +class DatabaseManager:
  5 + _instance = None
  6 + _connection = None
  7 + _config = None
  8 +
  9 + def __new__(cls):
  10 + if cls._instance is None:
  11 + cls._instance = super(DatabaseManager, cls).__new__(cls)
  12 + return cls._instance
  13 +
  14 + @classmethod
  15 + def initialize(cls, config):
  16 + """初始化数据库配置"""
  17 + cls._config = config
  18 +
  19 + @classmethod
  20 + def get_connection(cls):
  21 + """获取数据库连接"""
  22 + if cls._connection is None or not cls._connection.open:
  23 + if cls._config is None:
  24 + raise ValueError("数据库未初始化,请先调用initialize方法设置配置")
  25 + cls._connection = pymysql.connect(
  26 + **cls._config,
  27 + cursorclass=DictCursor
  28 + )
  29 + return cls._connection
  30 +
  31 + @classmethod
  32 + def close(cls):
  33 + """关闭数据库连接"""
  34 + if cls._connection and cls._connection.open:
  35 + cls._connection.close()
  36 + cls._connection = None
@@ -73,6 +73,15 @@ def spider_worker(topics, parameters):
73 total_topics = len(topics) 73 total_topics = len(topics)
74 completed_topics = 0 74 completed_topics = 0
75 75
  76 + async def send_message(message):
  77 + """异步发送消息的包装函数"""
  78 + loop = asyncio.new_event_loop()
  79 + asyncio.set_event_loop(loop)
  80 + try:
  81 + await broadcast_message(message)
  82 + finally:
  83 + loop.close()
  84 +
76 try: 85 try:
77 spider = SpiderData() 86 spider = SpiderData()
78 87
@@ -80,13 +89,13 @@ def spider_worker(topics, parameters):
80 try: 89 try:
81 # 更新进度 90 # 更新进度
82 progress = int((completed_topics / total_topics) * 100) 91 progress = int((completed_topics / total_topics) * 100)
83 - asyncio.run(broadcast_message({ 92 + asyncio.run(send_message({
84 'type': 'progress', 93 'type': 'progress',
85 'value': progress 94 'value': progress
86 })) 95 }))
87 96
88 # 发送开始爬取的日志 97 # 发送开始爬取的日志
89 - asyncio.run(broadcast_message({ 98 + asyncio.run(send_message({
90 'type': 'log', 99 'type': 'log',
91 'message': f'开始爬取话题: {topic}' 100 'message': f'开始爬取话题: {topic}'
92 })) 101 }))
@@ -103,33 +112,33 @@ def spider_worker(topics, parameters):
103 completed_topics += 1 112 completed_topics += 1
104 113
105 # 发送完成爬取的日志 114 # 发送完成爬取的日志
106 - asyncio.run(broadcast_message({ 115 + asyncio.run(send_message({
107 'type': 'log', 116 'type': 'log',
108 'message': f'话题 {topic} 爬取完成' 117 'message': f'话题 {topic} 爬取完成'
109 })) 118 }))
110 119
111 except Exception as e: 120 except Exception as e:
112 # 发送错误日志 121 # 发送错误日志
113 - asyncio.run(broadcast_message({ 122 + asyncio.run(send_message({
114 'type': 'log', 123 'type': 'log',
115 'message': f'爬取话题 {topic} 时出错: {str(e)}' 124 'message': f'爬取话题 {topic} 时出错: {str(e)}'
116 })) 125 }))
117 126
118 # 更新最终进度 127 # 更新最终进度
119 - asyncio.run(broadcast_message({ 128 + asyncio.run(send_message({
120 'type': 'progress', 129 'type': 'progress',
121 'value': 100 130 'value': 100
122 })) 131 }))
123 132
124 # 发送完成消息 133 # 发送完成消息
125 - asyncio.run(broadcast_message({ 134 + asyncio.run(send_message({
126 'type': 'log', 135 'type': 'log',
127 'message': '所有话题爬取完成' 136 'message': '所有话题爬取完成'
128 })) 137 }))
129 138
130 except Exception as e: 139 except Exception as e:
131 # 发送错误日志 140 # 发送错误日志
132 - asyncio.run(broadcast_message({ 141 + asyncio.run(send_message({
133 'type': 'log', 142 'type': 'log',
134 'message': f'爬虫任务执行出错: {str(e)}' 143 'message': f'爬虫任务执行出错: {str(e)}'
135 })) 144 }))
@@ -196,23 +205,27 @@ def save_spider_config():
196 }) 205 })
197 206
198 @spider_bp.websocket('/ws/spider-status') 207 @spider_bp.websocket('/ws/spider-status')
199 -async def spider_status_socket(): 208 +async def spider_status_socket(websocket):
200 """WebSocket连接处理""" 209 """WebSocket连接处理"""
201 try: 210 try:
202 - websocket = websockets.WebSocketServerProtocol()  
203 websocket_connections.add(websocket) 211 websocket_connections.add(websocket)
  212 + logging.info("新的WebSocket连接已建立")
204 213
205 try: 214 try:
206 while True: 215 while True:
207 - # 保持连接活跃  
208 - await websocket.ping()  
209 - await asyncio.sleep(30) 216 + # 等待消息,保持连接活跃
  217 + message = await websocket.receive()
  218 + if message is None:
  219 + break
210 except websockets.exceptions.ConnectionClosed: 220 except websockets.exceptions.ConnectionClosed:
211 - pass 221 + logging.info("WebSocket连接已关闭")
212 finally: 222 finally:
213 websocket_connections.remove(websocket) 223 websocket_connections.remove(websocket)
  224 + logging.info("WebSocket连接已移除")
214 except Exception as e: 225 except Exception as e:
215 logger.error(f"WebSocket连接处理失败: {e}") 226 logger.error(f"WebSocket连接处理失败: {e}")
  227 + if websocket in websocket_connections:
  228 + websocket_connections.remove(websocket)
216 229
217 def get_ai_client(): 230 def get_ai_client():
218 """获取可用的AI客户端""" 231 """获取可用的AI客户端"""