Showing
1 changed file
with
70 additions
and
0 deletions
InsightEngine/utils/db.py
0 → 100644
| 1 | +""" | ||
| 2 | +通用数据库工具(异步) | ||
| 3 | + | ||
| 4 | +此模块提供基于 SQLAlchemy 2.x 异步引擎的数据库访问封装,支持 MySQL 与 PostgreSQL。 | ||
| 5 | +数据模型定义位置: | ||
| 6 | +- 无(本模块仅提供连接与查询工具,不定义数据模型) | ||
| 7 | +""" | ||
| 8 | + | ||
| 9 | +from __future__ import annotations | ||
| 10 | + | ||
| 11 | +import asyncio | ||
| 12 | +import os | ||
| 13 | +from typing import Any, Dict, Iterable, List, Optional, Union | ||
| 14 | + | ||
| 15 | +from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine | ||
| 16 | +from sqlalchemy import text | ||
| 17 | +from InsightEngine.utils.config import settings | ||
| 18 | + | ||
| 19 | +__all__ = [ | ||
| 20 | + "get_async_engine", | ||
| 21 | + "fetch_all", | ||
| 22 | +] | ||
| 23 | + | ||
| 24 | + | ||
| 25 | +_engine: Optional[AsyncEngine] = None | ||
| 26 | + | ||
| 27 | + | ||
| 28 | +def _build_database_url() -> str: | ||
| 29 | + dialect: str = (settings.DB_DIALECT or "mysql").lower() | ||
| 30 | + host: str = settings.DB_HOST or "" | ||
| 31 | + port: str = str(settings.DB_PORT or "") | ||
| 32 | + user: str = settings.DB_USER or "" | ||
| 33 | + password: str = settings.DB_PASSWORD or "" | ||
| 34 | + db_name: str = settings.DB_NAME or "" | ||
| 35 | + | ||
| 36 | + if os.getenv("DATABASE_URL"): | ||
| 37 | + return os.getenv("DATABASE_URL") # 直接使用外部提供的完整URL | ||
| 38 | + | ||
| 39 | + if dialect in ("postgresql", "postgres"): | ||
| 40 | + # PostgreSQL 使用 asyncpg 驱动 | ||
| 41 | + return f"postgresql+asyncpg://{user}:{password}@{host}:{port}/{db_name}" | ||
| 42 | + | ||
| 43 | + # 默认 MySQL 使用 aiomysql 驱动 | ||
| 44 | + return f"mysql+aiomysql://{user}:{password}@{host}:{port}/{db_name}" | ||
| 45 | + | ||
| 46 | + | ||
| 47 | +def get_async_engine() -> AsyncEngine: | ||
| 48 | + global _engine | ||
| 49 | + if _engine is None: | ||
| 50 | + database_url: str = _build_database_url() | ||
| 51 | + _engine = create_async_engine( | ||
| 52 | + database_url, | ||
| 53 | + pool_pre_ping=True, | ||
| 54 | + pool_recycle=1800, | ||
| 55 | + ) | ||
| 56 | + return _engine | ||
| 57 | + | ||
| 58 | + | ||
| 59 | +async def fetch_all(query: str, params: Optional[Union[Iterable[Any], Dict[str, Any]]] = None) -> List[Dict[str, Any]]: | ||
| 60 | + """ | ||
| 61 | + 执行只读查询并返回字典列表。 | ||
| 62 | + """ | ||
| 63 | + engine: AsyncEngine = get_async_engine() | ||
| 64 | + async with engine.connect() as conn: | ||
| 65 | + result = await conn.execute(text(query), params or {}) | ||
| 66 | + rows = result.mappings().all() | ||
| 67 | + # 将 RowMapping 转换为普通字典 | ||
| 68 | + return [dict(row) for row in rows] | ||
| 69 | + | ||
| 70 | + |
-
Please register or login to post a comment