Optimize database connection management in the spider module; fix WebSocket handling and parameter validation.
Showing 4 changed files with 104 additions and 16 deletions.
| @@ -9,6 +9,7 @@ from pytz import utc | @@ -9,6 +9,7 @@ from pytz import utc | ||
| 9 | from datetime import datetime, timedelta | 9 | from datetime import datetime, timedelta |
| 10 | import time | 10 | import time |
| 11 | from utils.logger import app_logger as logging | 11 | from utils.logger import app_logger as logging |
| 12 | +from utils.db_manager import DatabaseManager | ||
| 12 | 13 | ||
| 13 | def get_db_connection_interactive(): | 14 | def get_db_connection_interactive(): |
| 14 | """ | 15 | """ |
| @@ -232,6 +233,9 @@ DB_CONFIG = { | @@ -232,6 +233,9 @@ DB_CONFIG = { | ||
| 232 | 'charset': 'utf8mb4' | 233 | 'charset': 'utf8mb4' |
| 233 | } | 234 | } |
| 234 | 235 | ||
| 236 | +# 初始化数据库管理器 | ||
| 237 | +DatabaseManager.initialize(DB_CONFIG) | ||
| 238 | + | ||
| 235 | # 主程序入口 | 239 | # 主程序入口 |
| 236 | if __name__ == '__main__': | 240 | if __name__ == '__main__': |
| 237 | # 检测是否需要初始化数据库 | 241 | # 检测是否需要初始化数据库 |
| @@ -10,6 +10,7 @@ import logging | @@ -10,6 +10,7 @@ import logging | ||
| 10 | from bs4 import BeautifulSoup | 10 | from bs4 import BeautifulSoup |
| 11 | from datetime import datetime | 11 | from datetime import datetime |
| 12 | from utils.logger import spider_logger as logging | 12 | from utils.logger import spider_logger as logging |
| 13 | +from utils.db_manager import DatabaseManager | ||
| 13 | 14 | ||
| 14 | def spiderData(): | 15 | def spiderData(): |
| 15 | if not os.path.exists(navAddr): | 16 | if not os.path.exists(navAddr): |
| @@ -26,6 +27,7 @@ class SpiderData: | @@ -26,6 +27,7 @@ class SpiderData: | ||
| 26 | 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36' | 27 | 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36' |
| 27 | } | 28 | } |
| 28 | self.base_url = 'https://s.weibo.com' | 29 | self.base_url = 'https://s.weibo.com' |
| 30 | + self.db = DatabaseManager() | ||
| 29 | 31 | ||
| 30 | def crawl_topic(self, topic, depth=3, interval=5, max_retries=3, timeout=30): | 32 | def crawl_topic(self, topic, depth=3, interval=5, max_retries=3, timeout=30): |
| 31 | """ | 33 | """ |
| @@ -37,7 +39,17 @@ class SpiderData: | @@ -37,7 +39,17 @@ class SpiderData: | ||
| 37 | :param max_retries: 最大重试次数 | 39 | :param max_retries: 最大重试次数 |
| 38 | :param timeout: 请求超时时间(秒) | 40 | :param timeout: 请求超时时间(秒) |
| 39 | """ | 41 | """ |
| 40 | - logging.info(f"开始爬取话题: {topic}") | 42 | + # 参数验证 |
| 43 | + if not isinstance(depth, int) or depth < 1 or depth > 10: | ||
| 44 | + raise ValueError("爬取深度必须在1-10页之间") | ||
| 45 | + if not isinstance(interval, int) or interval < 3 or interval > 30: | ||
| 46 | + raise ValueError("请求间隔必须在3-30秒之间") | ||
| 47 | + if not isinstance(max_retries, int) or max_retries < 1 or max_retries > 5: | ||
| 48 | + raise ValueError("最大重试次数必须在1-5次之间") | ||
| 49 | + if not isinstance(timeout, int) or timeout < 10 or timeout > 60: | ||
| 50 | + raise ValueError("请求超时时间必须在10-60秒之间") | ||
| 51 | + | ||
| 52 | + logging.info(f"开始爬取话题: {topic}, 参数: depth={depth}, interval={interval}, max_retries={max_retries}, timeout={timeout}") | ||
| 41 | 53 | ||
| 42 | for page in range(1, depth + 1): | 54 | for page in range(1, depth + 1): |
| 43 | retries = 0 | 55 | retries = 0 |
| @@ -140,11 +152,34 @@ class SpiderData: | @@ -140,11 +152,34 @@ class SpiderData: | ||
| 140 | 152 | ||
| 141 | :param data: 要保存的数据字典 | 153 | :param data: 要保存的数据字典 |
| 142 | """ | 154 | """ |
| 155 | + connection = None | ||
| 143 | try: | 156 | try: |
| 144 | - # TODO: 实现数据库保存逻辑 | ||
| 145 | - logging.info(f"保存数据: {data}") | 157 | + connection = self.db.get_connection() |
| 158 | + | ||
| 159 | + with connection.cursor() as cursor: | ||
| 160 | + # 插入文章数据 | ||
| 161 | + sql = """ | ||
| 162 | + INSERT INTO article (content, user_name, publish_time, forward_count, | ||
| 163 | + comment_count, like_count, crawl_time) | ||
| 164 | + VALUES (%s, %s, %s, %s, %s, %s, %s) | ||
| 165 | + """ | ||
| 166 | + cursor.execute(sql, ( | ||
| 167 | + data['content'], | ||
| 168 | + data['user_name'], | ||
| 169 | + data['publish_time'], | ||
| 170 | + data['forward_count'], | ||
| 171 | + data['comment_count'], | ||
| 172 | + data['like_count'], | ||
| 173 | + data['crawl_time'] | ||
| 174 | + )) | ||
| 175 | + | ||
| 176 | + connection.commit() | ||
| 177 | + logging.info(f"成功保存微博数据: {data['content'][:30]}...") | ||
| 178 | + | ||
| 146 | except Exception as e: | 179 | except Exception as e: |
| 147 | logging.error(f"保存数据时出错: {e}") | 180 | logging.error(f"保存数据时出错: {e}") |
| 181 | + if connection: | ||
| 182 | + connection.rollback() | ||
| 148 | 183 | ||
| 149 | if __name__ == '__main__': | 184 | if __name__ == '__main__': |
| 150 | spiderData() | 185 | spiderData() |
utils/db_manager.py
0 → 100644
import threading

import pymysql
from pymysql.cursors import DictCursor


class DatabaseManager:
    """Process-wide MySQL connection manager (singleton).

    Call ``initialize(config)`` once at startup with a pymysql-compatible
    config dict, then ``get_connection()`` wherever a connection is needed.
    One shared connection is created lazily and re-established when it has
    been closed or dropped server-side (e.g. MySQL ``wait_timeout``).

    NOTE(review): a single shared connection is not safe for concurrent use
    from multiple threads; if spider workers ever run in parallel, switch to
    a per-thread connection or a real pool — TODO confirm usage pattern.
    """

    _instance = None          # the one shared instance
    _connection = None        # lazily-created pymysql connection
    _config = None            # kwargs for pymysql.connect
    _lock = threading.Lock()  # guards connect/close across threads

    def __new__(cls):
        # Classic singleton: every instantiation yields the same object.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    @classmethod
    def initialize(cls, config):
        """Store the connection config; must run before get_connection()."""
        cls._config = config

    @classmethod
    def get_connection(cls):
        """Return a live connection, (re)connecting as needed.

        :raises ValueError: if initialize() has not been called yet.
        """
        with cls._lock:
            if cls._connection is not None and cls._connection.open:
                try:
                    # `.open` only reflects client-side state; ping() detects
                    # server-side timeouts/drops and reconnects in place.
                    cls._connection.ping(reconnect=True)
                    return cls._connection
                except Exception:
                    # Ping failed beyond repair — fall through and rebuild.
                    cls._connection = None
            if cls._config is None:
                raise ValueError("数据库未初始化,请先调用initialize方法设置配置")
            cls._connection = pymysql.connect(
                **cls._config,
                cursorclass=DictCursor,
            )
            return cls._connection

    @classmethod
    def close(cls):
        """Close and forget the shared connection (idempotent)."""
        with cls._lock:
            if cls._connection and cls._connection.open:
                cls._connection.close()
            cls._connection = None
| @@ -73,6 +73,15 @@ def spider_worker(topics, parameters): | @@ -73,6 +73,15 @@ def spider_worker(topics, parameters): | ||
| 73 | total_topics = len(topics) | 73 | total_topics = len(topics) |
| 74 | completed_topics = 0 | 74 | completed_topics = 0 |
| 75 | 75 | ||
# Nested inside spider_worker(); callers invoke it as
# asyncio.run(send_message({...})) from the worker thread.
async def send_message(message):
    """Forward *message* (a dict) to all connected WebSocket clients.

    Each ``asyncio.run(send_message(...))`` call already creates and
    tears down its own event loop, so this wrapper must NOT build a
    second loop: the previous version called ``asyncio.new_event_loop``
    inside the running coroutine, set it current, never scheduled
    anything on it, then closed it — leaking loops and triggering
    "attached to a different loop" errors. Simply await the broadcast
    on the loop we are already running on.
    """
    await broadcast_message(message)
| 76 | try: | 85 | try: |
| 77 | spider = SpiderData() | 86 | spider = SpiderData() |
| 78 | 87 | ||
| @@ -80,13 +89,13 @@ def spider_worker(topics, parameters): | @@ -80,13 +89,13 @@ def spider_worker(topics, parameters): | ||
| 80 | try: | 89 | try: |
| 81 | # 更新进度 | 90 | # 更新进度 |
| 82 | progress = int((completed_topics / total_topics) * 100) | 91 | progress = int((completed_topics / total_topics) * 100) |
| 83 | - asyncio.run(broadcast_message({ | 92 | + asyncio.run(send_message({ |
| 84 | 'type': 'progress', | 93 | 'type': 'progress', |
| 85 | 'value': progress | 94 | 'value': progress |
| 86 | })) | 95 | })) |
| 87 | 96 | ||
| 88 | # 发送开始爬取的日志 | 97 | # 发送开始爬取的日志 |
| 89 | - asyncio.run(broadcast_message({ | 98 | + asyncio.run(send_message({ |
| 90 | 'type': 'log', | 99 | 'type': 'log', |
| 91 | 'message': f'开始爬取话题: {topic}' | 100 | 'message': f'开始爬取话题: {topic}' |
| 92 | })) | 101 | })) |
| @@ -103,33 +112,33 @@ def spider_worker(topics, parameters): | @@ -103,33 +112,33 @@ def spider_worker(topics, parameters): | ||
| 103 | completed_topics += 1 | 112 | completed_topics += 1 |
| 104 | 113 | ||
| 105 | # 发送完成爬取的日志 | 114 | # 发送完成爬取的日志 |
| 106 | - asyncio.run(broadcast_message({ | 115 | + asyncio.run(send_message({ |
| 107 | 'type': 'log', | 116 | 'type': 'log', |
| 108 | 'message': f'话题 {topic} 爬取完成' | 117 | 'message': f'话题 {topic} 爬取完成' |
| 109 | })) | 118 | })) |
| 110 | 119 | ||
| 111 | except Exception as e: | 120 | except Exception as e: |
| 112 | # 发送错误日志 | 121 | # 发送错误日志 |
| 113 | - asyncio.run(broadcast_message({ | 122 | + asyncio.run(send_message({ |
| 114 | 'type': 'log', | 123 | 'type': 'log', |
| 115 | 'message': f'爬取话题 {topic} 时出错: {str(e)}' | 124 | 'message': f'爬取话题 {topic} 时出错: {str(e)}' |
| 116 | })) | 125 | })) |
| 117 | 126 | ||
| 118 | # 更新最终进度 | 127 | # 更新最终进度 |
| 119 | - asyncio.run(broadcast_message({ | 128 | + asyncio.run(send_message({ |
| 120 | 'type': 'progress', | 129 | 'type': 'progress', |
| 121 | 'value': 100 | 130 | 'value': 100 |
| 122 | })) | 131 | })) |
| 123 | 132 | ||
| 124 | # 发送完成消息 | 133 | # 发送完成消息 |
| 125 | - asyncio.run(broadcast_message({ | 134 | + asyncio.run(send_message({ |
| 126 | 'type': 'log', | 135 | 'type': 'log', |
| 127 | 'message': '所有话题爬取完成' | 136 | 'message': '所有话题爬取完成' |
| 128 | })) | 137 | })) |
| 129 | 138 | ||
| 130 | except Exception as e: | 139 | except Exception as e: |
| 131 | # 发送错误日志 | 140 | # 发送错误日志 |
| 132 | - asyncio.run(broadcast_message({ | 141 | + asyncio.run(send_message({ |
| 133 | 'type': 'log', | 142 | 'type': 'log', |
| 134 | 'message': f'爬虫任务执行出错: {str(e)}' | 143 | 'message': f'爬虫任务执行出错: {str(e)}' |
| 135 | })) | 144 | })) |
| @@ -196,23 +205,27 @@ def save_spider_config(): | @@ -196,23 +205,27 @@ def save_spider_config(): | ||
| 196 | }) | 205 | }) |
| 197 | 206 | ||
@spider_bp.websocket('/ws/spider-status')
async def spider_status_socket(websocket):
    """Handle one WebSocket client on /ws/spider-status.

    Registers the connection so spider_worker can push progress/log
    messages to it, then blocks reading (and discarding) client frames
    to keep the connection alive until the peer disconnects. The channel
    is push-only; incoming payloads are ignored.
    """
    websocket_connections.add(websocket)
    logging.info("新的WebSocket连接已建立")
    try:
        while True:
            # Await client traffic; a None frame signals a clean close.
            message = await websocket.receive()
            if message is None:
                break
    except websockets.exceptions.ConnectionClosed:
        logging.info("WebSocket连接已关闭")
    except Exception as e:
        logger.error(f"WebSocket连接处理失败: {e}")
    finally:
        # discard() never raises, even if the socket was never added or
        # was already removed — the old remove()+second-remove-in-except
        # could raise KeyError and mask the original error.
        websocket_connections.discard(websocket)
        logging.info("WebSocket连接已移除")
| 217 | def get_ai_client(): | 230 | def get_ai_client(): |
| 218 | """获取可用的AI客户端""" | 231 | """获取可用的AI客户端""" |
-
Please register or login to post a comment