戒酒的李白

Implement a two-level caching system (memory + disk) to optimize topic switch re…

…sponse speed, support asynchronous writing, and automatically clean up expired data.
  1 +import json
  2 +import os
  3 +import time
  4 +from datetime import datetime, timedelta
  5 +import threading
  6 +import queue
  7 +
  8 +class PredictionCache:
  9 + _instance = None
  10 + _lock = threading.Lock()
  11 +
  12 + def __new__(cls):
  13 + with cls._lock:
  14 + if cls._instance is None:
  15 + cls._instance = super(PredictionCache, cls).__new__(cls)
  16 + return cls._instance
  17 +
  18 + def __init__(self):
  19 + if not hasattr(self, 'initialized'):
  20 + self.cache_dir = 'cache/predictions'
  21 + self.cache_duration = timedelta(hours=24) # 缓存24小时
  22 + self.cache = {}
  23 + self.cache_queue = queue.Queue()
  24 + self.initialized = True
  25 +
  26 + # 确保缓存目录存在
  27 + os.makedirs(self.cache_dir, exist_ok=True)
  28 +
  29 + # 启动缓存清理线程
  30 + self.cleanup_thread = threading.Thread(target=self._cleanup_old_cache, daemon=True)
  31 + self.cleanup_thread.start()
  32 +
  33 + # 加载现有缓存
  34 + self._load_cache()
  35 +
  36 + def _load_cache(self):
  37 + """加载磁盘上的缓存文件"""
  38 + try:
  39 + for filename in os.listdir(self.cache_dir):
  40 + if filename.endswith('.json'):
  41 + filepath = os.path.join(self.cache_dir, filename)
  42 + with open(filepath, 'r', encoding='utf-8') as f:
  43 + cache_data = json.load(f)
  44 + # 检查缓存是否过期
  45 + if self._is_cache_valid(cache_data['timestamp']):
  46 + topic = filename[:-5] # 移除.json后缀
  47 + self.cache[topic] = cache_data
  48 + else:
  49 + # 删除过期缓存文件
  50 + os.remove(filepath)
  51 + except Exception as e:
  52 + print(f"加载缓存失败: {e}")
  53 +
  54 + def _cleanup_old_cache(self):
  55 + """定期清理过期缓存的后台线程"""
  56 + while True:
  57 + try:
  58 + # 检查并清理内存缓存
  59 + current_time = datetime.now()
  60 + expired_topics = []
  61 +
  62 + for topic, cache_data in self.cache.items():
  63 + if not self._is_cache_valid(cache_data['timestamp']):
  64 + expired_topics.append(topic)
  65 +
  66 + # 删除过期缓存
  67 + for topic in expired_topics:
  68 + del self.cache[topic]
  69 + cache_file = os.path.join(self.cache_dir, f"{topic}.json")
  70 + if os.path.exists(cache_file):
  71 + os.remove(cache_file)
  72 +
  73 + # 休眠1小时后再次检查
  74 + time.sleep(3600)
  75 + except Exception as e:
  76 + print(f"清理缓存时出错: {e}")
  77 + time.sleep(3600) # 发生错误时也等待1小时
  78 +
  79 + def _is_cache_valid(self, timestamp):
  80 + """检查缓存是否有效"""
  81 + cache_time = datetime.fromtimestamp(timestamp)
  82 + return datetime.now() - cache_time < self.cache_duration
  83 +
  84 + def get(self, topic):
  85 + """获取话题的预测缓存"""
  86 + if topic in self.cache and self._is_cache_valid(self.cache[topic]['timestamp']):
  87 + return self.cache[topic]['prediction']
  88 + return None
  89 +
  90 + def set(self, topic, prediction):
  91 + """设置话题的预测缓存"""
  92 + cache_data = {
  93 + 'prediction': prediction,
  94 + 'timestamp': datetime.now().timestamp()
  95 + }
  96 +
  97 + # 更新内存缓存
  98 + self.cache[topic] = cache_data
  99 +
  100 + # 异步保存到磁盘
  101 + self.cache_queue.put((topic, cache_data))
  102 + threading.Thread(target=self._save_cache_to_disk, daemon=True).start()
  103 +
  104 + def _save_cache_to_disk(self):
  105 + """异步保存缓存到磁盘"""
  106 + try:
  107 + while not self.cache_queue.empty():
  108 + topic, cache_data = self.cache_queue.get()
  109 + cache_file = os.path.join(self.cache_dir, f"{topic}.json")
  110 + with open(cache_file, 'w', encoding='utf-8') as f:
  111 + json.dump(cache_data, f, ensure_ascii=False, indent=2)
  112 + except Exception as e:
  113 + print(f"保存缓存到磁盘失败: {e}")
  114 +
  115 +# 创建全局缓存实例
  116 +prediction_cache = PredictionCache()
@@ -8,6 +8,7 @@ from utils.getEchartsData import * @@ -8,6 +8,7 @@ from utils.getEchartsData import *
8 from utils.getTopicPageData import * 8 from utils.getTopicPageData import *
9 from utils.yuqingpredict import * 9 from utils.yuqingpredict import *
10 from utils.logger import app_logger as logging 10 from utils.logger import app_logger as logging
  11 +from utils.cache_manager import prediction_cache
11 import torch 12 import torch
12 from BCAT_front.predict import model_manager 13 from BCAT_front.predict import model_manager
13 14
@@ -207,6 +208,13 @@ def yuqingpredict(): @@ -207,6 +208,13 @@ def yuqingpredict():
207 # 获取模型选择参数 208 # 获取模型选择参数
208 model_type = request.args.get('model', 'pro') # 默认使用改进模型 209 model_type = request.args.get('model', 'pro') # 默认使用改进模型
209 210
  211 + # 尝试从缓存获取预测结果
  212 + cache_key = f"{defaultTopic}_{model_type}"
  213 + cached_result = prediction_cache.get(cache_key)
  214 +
  215 + if cached_result is not None:
  216 + sentences = cached_result
  217 + else:
210 if model_type == 'basic': 218 if model_type == 'basic':
211 # 使用基础模型(SnowNLP) 219 # 使用基础模型(SnowNLP)
212 value = SnowNLP(defaultTopic).sentiments 220 value = SnowNLP(defaultTopic).sentiments
@@ -226,6 +234,9 @@ def yuqingpredict(): @@ -226,6 +234,9 @@ def yuqingpredict():
226 sentences = '预测失败,请稍后重试' 234 sentences = '预测失败,请稍后重试'
227 logging.error(f"预测失败,话题: {defaultTopic}") 235 logging.error(f"预测失败,话题: {defaultTopic}")
228 236
  237 + # 将结果存入缓存
  238 + prediction_cache.set(cache_key, sentences)
  239 +
229 comments = getCommentFilterDataTopic(defaultTopic) 240 comments = getCommentFilterDataTopic(defaultTopic)
230 return render_template('yuqingpredict.html', 241 return render_template('yuqingpredict.html',
231 username=username, 242 username=username,