Doiiars

修复缺失的文件

  1 +"""
  2 +通用数据库工具(异步)
  3 +
  4 +此模块提供基于 SQLAlchemy 2.x 异步引擎的数据库访问封装,支持 MySQL 与 PostgreSQL。
  5 +数据模型定义位置:
  6 +- 无(本模块仅提供连接与查询工具,不定义数据模型)
  7 +"""
  8 +
  9 +from __future__ import annotations
  10 +
  11 +import asyncio
  12 +import os
  13 +from typing import Any, Dict, Iterable, List, Optional, Union
  14 +
  15 +from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
  16 +from sqlalchemy import text
  17 +from InsightEngine.utils.config import settings
  18 +
  19 +__all__ = [
  20 + "get_async_engine",
  21 + "fetch_all",
  22 +]
  23 +
  24 +
  25 +_engine: Optional[AsyncEngine] = None
  26 +
  27 +
  28 +def _build_database_url() -> str:
  29 + dialect: str = (settings.DB_DIALECT or "mysql").lower()
  30 + host: str = settings.DB_HOST or ""
  31 + port: str = str(settings.DB_PORT or "")
  32 + user: str = settings.DB_USER or ""
  33 + password: str = settings.DB_PASSWORD or ""
  34 + db_name: str = settings.DB_NAME or ""
  35 +
  36 + if os.getenv("DATABASE_URL"):
  37 + return os.getenv("DATABASE_URL") # 直接使用外部提供的完整URL
  38 +
  39 + if dialect in ("postgresql", "postgres"):
  40 + # PostgreSQL 使用 asyncpg 驱动
  41 + return f"postgresql+asyncpg://{user}:{password}@{host}:{port}/{db_name}"
  42 +
  43 + # 默认 MySQL 使用 aiomysql 驱动
  44 + return f"mysql+aiomysql://{user}:{password}@{host}:{port}/{db_name}"
  45 +
  46 +
  47 +def get_async_engine() -> AsyncEngine:
  48 + global _engine
  49 + if _engine is None:
  50 + database_url: str = _build_database_url()
  51 + _engine = create_async_engine(
  52 + database_url,
  53 + pool_pre_ping=True,
  54 + pool_recycle=1800,
  55 + )
  56 + return _engine
  57 +
  58 +
  59 +async def fetch_all(query: str, params: Optional[Union[Iterable[Any], Dict[str, Any]]] = None) -> List[Dict[str, Any]]:
  60 + """
  61 + 执行只读查询并返回字典列表。
  62 + """
  63 + engine: AsyncEngine = get_async_engine()
  64 + async with engine.connect() as conn:
  65 + result = await conn.execute(text(query), params or {})
  66 + rows = result.mappings().all()
  67 + # 将 RowMapping 转换为普通字典
  68 + return [dict(row) for row in rows]
  69 +
  70 +