thread_manager.py 8.04 KB
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
AIfeng/2025-07-02 11:24:08
线程管理器 - 提供增强的线程功能
"""

import threading
import time
import traceback
from typing import Callable, Any, Optional
from utils.util import log


class MyThread(threading.Thread):
    """增强的线程类,提供更好的错误处理和监控功能"""
    
    def __init__(self, target: Optional[Callable] = None, name: Optional[str] = None, 
                 args: tuple = (), kwargs: dict = None, daemon: bool = True):
        """
        初始化线程
        
        Args:
            target: 目标函数
            name: 线程名称
            args: 位置参数
            kwargs: 关键字参数
            daemon: 是否为守护线程
        """
        super().__init__(target=target, name=name, args=args, kwargs=kwargs or {}, daemon=daemon)
        self._target = target
        self._args = args
        self._kwargs = kwargs or {}
        self._result = None
        self._exception = None
        self._start_time = None
        self._end_time = None
        self._running = False
    
    def run(self):
        """重写run方法,添加错误处理和监控"""
        self._start_time = time.time()
        self._running = True
        
        try:
            if self._target:
                log(1, f"线程 {self.name} 开始执行")
                self._result = self._target(*self._args, **self._kwargs)
                log(1, f"线程 {self.name} 执行完成")
        except Exception as e:
            self._exception = e
            log(3, f"线程 {self.name} 执行出错: {e}")
            log(3, f"错误详情: {traceback.format_exc()}")
        finally:
            self._end_time = time.time()
            self._running = False
            duration = self._end_time - self._start_time
            log(1, f"线程 {self.name} 运行时长: {duration:.2f}秒")
    
    def get_result(self) -> Any:
        """获取线程执行结果"""
        if self.is_alive():
            raise RuntimeError("线程仍在运行中")
        
        if self._exception:
            raise self._exception
        
        return self._result
    
    def get_exception(self) -> Optional[Exception]:
        """获取线程执行过程中的异常"""
        return self._exception
    
    def get_duration(self) -> Optional[float]:
        """获取线程运行时长(秒)"""
        if self._start_time is None:
            return None
        
        end_time = self._end_time or time.time()
        return end_time - self._start_time
    
    def is_running(self) -> bool:
        """检查线程是否正在运行"""
        return self._running and self.is_alive()
    
    def stop_gracefully(self, timeout: float = 5.0) -> bool:
        """优雅地停止线程
        
        Args:
            timeout: 等待超时时间(秒)
        
        Returns:
            bool: 是否成功停止
        """
        if not self.is_alive():
            return True
        
        log(1, f"正在停止线程 {self.name}")
        
        # 等待线程自然结束
        self.join(timeout=timeout)
        
        if self.is_alive():
            log(2, f"线程 {self.name} 在 {timeout} 秒内未能自然结束")
            return False
        else:
            log(1, f"线程 {self.name} 已成功停止")
            return True
    
    def __str__(self) -> str:
        """字符串表示"""
        status = "运行中" if self.is_running() else "已停止"
        duration = self.get_duration()
        duration_str = f", 运行时长: {duration:.2f}秒" if duration else ""
        return f"MyThread(name={self.name}, status={status}{duration_str})"
    
    def __repr__(self) -> str:
        """详细字符串表示"""
        return self.__str__()


class ThreadManager:
    """线程管理器,用于管理多个线程"""
    
    def __init__(self):
        self._threads = {}
        self._lock = threading.Lock()
    
    def create_thread(self, name: str, target: Callable, args: tuple = (), 
                     kwargs: dict = None, daemon: bool = True) -> MyThread:
        """创建新线程
        
        Args:
            name: 线程名称
            target: 目标函数
            args: 位置参数
            kwargs: 关键字参数
            daemon: 是否为守护线程
        
        Returns:
            MyThread: 创建的线程对象
        """
        with self._lock:
            if name in self._threads:
                raise ValueError(f"线程名称 '{name}' 已存在")
            
            thread = MyThread(target=target, name=name, args=args, 
                            kwargs=kwargs or {}, daemon=daemon)
            self._threads[name] = thread
            return thread
    
    def start_thread(self, name: str) -> bool:
        """启动指定线程
        
        Args:
            name: 线程名称
        
        Returns:
            bool: 是否成功启动
        """
        with self._lock:
            if name not in self._threads:
                log(3, f"线程 '{name}' 不存在")
                return False
            
            thread = self._threads[name]
            if thread.is_alive():
                log(2, f"线程 '{name}' 已在运行中")
                return False
            
            try:
                thread.start()
                log(1, f"线程 '{name}' 启动成功")
                return True
            except Exception as e:
                log(3, f"启动线程 '{name}' 失败: {e}")
                return False
    
    def stop_thread(self, name: str, timeout: float = 5.0) -> bool:
        """停止指定线程
        
        Args:
            name: 线程名称
            timeout: 等待超时时间
        
        Returns:
            bool: 是否成功停止
        """
        with self._lock:
            if name not in self._threads:
                log(3, f"线程 '{name}' 不存在")
                return False
            
            thread = self._threads[name]
            return thread.stop_gracefully(timeout)
    
    def stop_all_threads(self, timeout: float = 5.0) -> bool:
        """停止所有线程
        
        Args:
            timeout: 每个线程的等待超时时间
        
        Returns:
            bool: 是否所有线程都成功停止
        """
        log(1, "正在停止所有线程...")
        success = True
        
        with self._lock:
            for name, thread in self._threads.items():
                if thread.is_alive():
                    if not thread.stop_gracefully(timeout):
                        success = False
        
        if success:
            log(1, "所有线程已成功停止")
        else:
            log(2, "部分线程未能在指定时间内停止")
        
        return success
    
    def get_thread_status(self) -> dict:
        """获取所有线程状态
        
        Returns:
            dict: 线程状态信息
        """
        status = {}
        with self._lock:
            for name, thread in self._threads.items():
                status[name] = {
                    'alive': thread.is_alive(),
                    'running': thread.is_running(),
                    'duration': thread.get_duration(),
                    'exception': str(thread.get_exception()) if thread.get_exception() else None
                }
        return status
    
    def cleanup_finished_threads(self):
        """清理已完成的线程"""
        with self._lock:
            finished_threads = [name for name, thread in self._threads.items() 
                              if not thread.is_alive()]
            
            for name in finished_threads:
                del self._threads[name]
                log(1, f"已清理完成的线程: {name}")
    
    def __len__(self) -> int:
        """返回线程数量"""
        with self._lock:
            return len(self._threads)
    
    def __contains__(self, name: str) -> bool:
        """检查是否包含指定名称的线程"""
        with self._lock:
            return name in self._threads


# 全局线程管理器实例
thread_manager = ThreadManager()