You need to sign in or sign up before continuing.
戒酒的李白

Optimize database connection management in spider module, fix WebSocket handling…

… and parameter validation.
... ... @@ -9,6 +9,7 @@ from pytz import utc
from datetime import datetime, timedelta
import time
from utils.logger import app_logger as logging
from utils.db_manager import DatabaseManager
def get_db_connection_interactive():
"""
... ... @@ -232,6 +233,9 @@ DB_CONFIG = {
'charset': 'utf8mb4'
}
# 初始化数据库管理器
DatabaseManager.initialize(DB_CONFIG)
# 主程序入口
if __name__ == '__main__':
# 检测是否需要初始化数据库
... ...
... ... @@ -10,6 +10,7 @@ import logging
from bs4 import BeautifulSoup
from datetime import datetime
from utils.logger import spider_logger as logging
from utils.db_manager import DatabaseManager
def spiderData():
if not os.path.exists(navAddr):
... ... @@ -26,6 +27,7 @@ class SpiderData:
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
}
self.base_url = 'https://s.weibo.com'
self.db = DatabaseManager()
def crawl_topic(self, topic, depth=3, interval=5, max_retries=3, timeout=30):
"""
... ... @@ -37,7 +39,17 @@ class SpiderData:
:param max_retries: 最大重试次数
:param timeout: 请求超时时间(秒)
"""
logging.info(f"开始爬取话题: {topic}")
# 参数验证
if not isinstance(depth, int) or depth < 1 or depth > 10:
raise ValueError("爬取深度必须在1-10页之间")
if not isinstance(interval, int) or interval < 3 or interval > 30:
raise ValueError("请求间隔必须在3-30秒之间")
if not isinstance(max_retries, int) or max_retries < 1 or max_retries > 5:
raise ValueError("最大重试次数必须在1-5次之间")
if not isinstance(timeout, int) or timeout < 10 or timeout > 60:
raise ValueError("请求超时时间必须在10-60秒之间")
logging.info(f"开始爬取话题: {topic}, 参数: depth={depth}, interval={interval}, max_retries={max_retries}, timeout={timeout}")
for page in range(1, depth + 1):
retries = 0
... ... @@ -140,11 +152,34 @@ class SpiderData:
:param data: 要保存的数据字典
"""
connection = None
try:
# TODO: 实现数据库保存逻辑
logging.info(f"保存数据: {data}")
connection = self.db.get_connection()
with connection.cursor() as cursor:
# 插入文章数据
sql = """
INSERT INTO article (content, user_name, publish_time, forward_count,
comment_count, like_count, crawl_time)
VALUES (%s, %s, %s, %s, %s, %s, %s)
"""
cursor.execute(sql, (
data['content'],
data['user_name'],
data['publish_time'],
data['forward_count'],
data['comment_count'],
data['like_count'],
data['crawl_time']
))
connection.commit()
logging.info(f"成功保存微博数据: {data['content'][:30]}...")
except Exception as e:
logging.error(f"保存数据时出错: {e}")
if connection:
connection.rollback()
if __name__ == '__main__':
spiderData()
\ No newline at end of file
... ...
import pymysql
from pymysql.cursors import DictCursor
class DatabaseManager:
_instance = None
_connection = None
_config = None
def __new__(cls):
if cls._instance is None:
cls._instance = super(DatabaseManager, cls).__new__(cls)
return cls._instance
@classmethod
def initialize(cls, config):
"""初始化数据库配置"""
cls._config = config
@classmethod
def get_connection(cls):
"""获取数据库连接"""
if cls._connection is None or not cls._connection.open:
if cls._config is None:
raise ValueError("数据库未初始化,请先调用initialize方法设置配置")
cls._connection = pymysql.connect(
**cls._config,
cursorclass=DictCursor
)
return cls._connection
@classmethod
def close(cls):
"""关闭数据库连接"""
if cls._connection and cls._connection.open:
cls._connection.close()
cls._connection = None
\ No newline at end of file
... ...
... ... @@ -73,6 +73,15 @@ def spider_worker(topics, parameters):
total_topics = len(topics)
completed_topics = 0
async def send_message(message):
"""异步发送消息的包装函数"""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
await broadcast_message(message)
finally:
loop.close()
try:
spider = SpiderData()
... ... @@ -80,13 +89,13 @@ def spider_worker(topics, parameters):
try:
# 更新进度
progress = int((completed_topics / total_topics) * 100)
asyncio.run(broadcast_message({
asyncio.run(send_message({
'type': 'progress',
'value': progress
}))
# 发送开始爬取的日志
asyncio.run(broadcast_message({
asyncio.run(send_message({
'type': 'log',
'message': f'开始爬取话题: {topic}'
}))
... ... @@ -103,33 +112,33 @@ def spider_worker(topics, parameters):
completed_topics += 1
# 发送完成爬取的日志
asyncio.run(broadcast_message({
asyncio.run(send_message({
'type': 'log',
'message': f'话题 {topic} 爬取完成'
}))
except Exception as e:
# 发送错误日志
asyncio.run(broadcast_message({
asyncio.run(send_message({
'type': 'log',
'message': f'爬取话题 {topic} 时出错: {str(e)}'
}))
# 更新最终进度
asyncio.run(broadcast_message({
asyncio.run(send_message({
'type': 'progress',
'value': 100
}))
# 发送完成消息
asyncio.run(broadcast_message({
asyncio.run(send_message({
'type': 'log',
'message': '所有话题爬取完成'
}))
except Exception as e:
# 发送错误日志
asyncio.run(broadcast_message({
asyncio.run(send_message({
'type': 'log',
'message': f'爬虫任务执行出错: {str(e)}'
}))
... ... @@ -196,23 +205,27 @@ def save_spider_config():
})
@spider_bp.websocket('/ws/spider-status')
async def spider_status_socket():
async def spider_status_socket(websocket):
"""WebSocket连接处理"""
try:
websocket = websockets.WebSocketServerProtocol()
websocket_connections.add(websocket)
logging.info("新的WebSocket连接已建立")
try:
while True:
# 保持连接活跃
await websocket.ping()
await asyncio.sleep(30)
# 等待消息,保持连接活跃
message = await websocket.receive()
if message is None:
break
except websockets.exceptions.ConnectionClosed:
pass
logging.info("WebSocket连接已关闭")
finally:
websocket_connections.remove(websocket)
logging.info("WebSocket连接已移除")
except Exception as e:
logger.error(f"WebSocket连接处理失败: {e}")
if websocket in websocket_connections:
websocket_connections.remove(websocket)
def get_ai_client():
"""获取可用的AI客户端"""
... ...