# main.py 9.53 KB
"""FastAPI proxy for the Z-Image generator frontend."""
import json
import os
import secrets
import time
from pathlib import Path
from threading import Lock
from typing import List, Literal, Optional

import httpx
from fastapi import FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field, ConfigDict
import logging

# Reuse the logger named "uvicorn.error" for this module's log output.
logger = logging.getLogger("uvicorn.error")
logging.basicConfig(level=logging.INFO)
# NOTE(review): this looks like a leftover placeholder log call that fires at
# import time — confirm and remove if it is not intentional.
logger.info("your message %s", "hello")
# Base URL of the upstream Z-Image service; trailing slashes are stripped so
# request paths can be appended with a single "/".
Z_IMAGE_BASE_URL = os.getenv("Z_IMAGE_BASE_URL", "http://106.120.52.146:39009").rstrip("/")
# Timeout, in seconds, passed to httpx.Timeout for upstream requests.
REQUEST_TIMEOUT_SECONDS = float(os.getenv("REQUEST_TIMEOUT_SECONDS", "120"))
# Location of the gallery JSON file; defaults to gallery_data.json next to this module.
GALLERY_DATA_PATH = Path(os.getenv("GALLERY_DATA_PATH", Path(__file__).with_name("gallery_data.json")))
# Maximum number of gallery entries the store keeps.
GALLERY_MAX_ITEMS = int(os.getenv("GALLERY_MAX_ITEMS", "500"))


class ImageGenerationPayload(BaseModel):
    """Request body accepted by the ``/generate`` endpoint."""

    # Accept fields by their Python name as well as by alias (e.g. both
    # ``author_id`` and ``authorId``).
    model_config = ConfigDict(populate_by_name=True)

    # Text prompt; must be non-empty and at most 2048 characters.
    prompt: str = Field(..., min_length=1, max_length=2048)
    # Output dimensions, each constrained to 64..2048.
    height: int = Field(1024, ge=64, le=2048)
    width: int = Field(1024, ge=64, le=2048)
    num_inference_steps: int = Field(8, ge=1, le=200)
    guidance_scale: float = Field(0.0, ge=0.0, le=20.0)
    # When omitted, the ``/generate`` handler picks a random seed.
    seed: Optional[int] = Field(default=None, ge=0)
    negative_prompt: Optional[str] = Field(default=None, max_length=2048)
    # Whether the caller asks for the image as base64 data or as a URL.
    output_format: Literal["base64", "url"] = "base64"
    # Gallery metadata only: the ``/generate`` handler does not forward this
    # field to the upstream service.
    author_id: Optional[str] = Field(default=None, alias="authorId", min_length=1, max_length=64)


class ImageGenerationResponse(BaseModel):
    """Response body returned by the ``/generate`` endpoint."""

    # ``image`` and ``url`` are copied from the upstream response's "image"
    # and "url" keys; the handler requires at least one of them to be present.
    image: Optional[str] = None
    url: Optional[str] = None
    # Copied from the upstream response's "time_taken" value (0.0 if absent).
    time_taken: float = 0.0
    # Copied from the upstream response's "error" value, if any.
    error: Optional[str] = None
    # The request parameters as actually used, including the resolved seed.
    request_params: ImageGenerationPayload
    # The stored gallery entry, or None when persisting it failed.
    # Written as a string because GalleryImage is defined later in the module.
    gallery_item: Optional["GalleryImage"] = None


class GalleryImage(BaseModel):
    """A single gallery entry as persisted by the gallery store."""

    # Accept fields by their Python name as well as by their camelCase alias.
    model_config = ConfigDict(populate_by_name=True)

    id: str
    prompt: str = Field(..., min_length=1, max_length=2048)
    height: int = Field(..., ge=64, le=2048)
    width: int = Field(..., ge=64, le=2048)
    num_inference_steps: int = Field(..., ge=1, le=200)
    guidance_scale: float = Field(..., ge=0.0, le=20.0)
    seed: int = Field(..., ge=0)
    # Either the upstream URL or a ``data:image/png;base64,...`` URL built by
    # the ``/generate`` handler.
    url: str
    # Defaults to the current time expressed in milliseconds (time.time() * 1000).
    created_at: float = Field(default_factory=lambda: time.time() * 1000, alias="createdAt")
    author_id: Optional[str] = Field(default=None, alias="authorId")
    likes: int = 0
    is_mock: bool = Field(default=False, alias="isMock")
    negative_prompt: Optional[str] = None
    # User ids that currently like this image.
    liked_by: List[str] = Field(default_factory=list, alias="likedBy")


# Resolve the "GalleryImage" string annotation on ImageGenerationResponse now
# that the GalleryImage class has been defined.
ImageGenerationResponse.model_rebuild()


class GalleryStore:
    """Simple JSON file backed store for generated images.

    The file holds one JSON object of the form ``{"images": [...]}`` with the
    newest image first. Every public method takes ``self.lock`` around its
    read-modify-write cycle. If the filesystem cannot be prepared at
    construction time the store switches to an in-memory list instead
    (``enabled`` is set to False).
    """

    def __init__(self, path: Path, max_items: int = 500) -> None:
        """Prepare the backing file, falling back to memory on filesystem errors.

        Args:
            path: Location of the JSON file.
            max_items: Maximum number of images kept; older ones are dropped.
        """
        self.path = path
        self.max_items = max_items
        self.lock = Lock()
        self.enabled = True
        self._memory_cache: List[dict] = []
        try:
            self.path.parent.mkdir(parents=True, exist_ok=True)
            if self.path.exists():
                self._memory_cache = self._read().get("images", [])
            else:
                self._write({"images": []})
        except OSError as exc:  # pragma: no cover - filesystem guards
            self.enabled = False
            print(f"[WARN] Gallery store disabled due to filesystem error: {exc}")

    def _read(self) -> dict:
        """Load the store contents as a dict whose "images" value is a list.

        A missing file, invalid JSON, or JSON that is not an object yields an
        empty gallery instead of raising, so callers can always use the
        result as ``{"images": [...]}``.
        """
        if not self.enabled:
            return {"images": list(self._memory_cache)}
        try:
            with self.path.open("r", encoding="utf-8") as file:
                data = json.load(file)
        except (FileNotFoundError, json.JSONDecodeError):
            return {"images": []}
        # Valid JSON is not necessarily the expected shape (e.g. ``[]`` or
        # ``null``); treat anything unexpected as an empty gallery.
        if not isinstance(data, dict):
            return {"images": []}
        if not isinstance(data.get("images"), list):
            data["images"] = []
        return data

    def _write(self, data: dict) -> None:
        """Persist ``data`` and refresh the in-memory copy of its images.

        Writes to a sibling ``.tmp`` file and renames it over the target. If
        that fails, the payload is written directly to the target file.

        Raises:
            OSError: If both the atomic and the direct write fail.
        """
        if not self.enabled:
            self._memory_cache = list(data.get("images", []))
            return
        payload = json.dumps(data, ensure_ascii=False, indent=2)
        temp_path = self.path.with_suffix(".tmp")
        try:
            with temp_path.open("w", encoding="utf-8") as file:
                file.write(payload)
            temp_path.replace(self.path)
        except OSError as exc:
            # Some filesystems (or permissions) may block atomic replace; fall back to direct write
            print(f"[WARN] Atomic gallery write failed, attempting direct write: {exc}")
            try:
                # Best-effort cleanup so a half-written temp file is not left behind.
                temp_path.unlink(missing_ok=True)
            except OSError:
                pass
            with self.path.open("w", encoding="utf-8") as file:
                file.write(payload)
        self._memory_cache = list(data.get("images", []))

    def list_images(self) -> List[dict]:
        """Return a new list containing the stored image dicts, newest first."""
        with self.lock:
            data = self._read()
            return list(data.get("images", []))

    def add_image(self, image: "GalleryImage") -> dict:
        """Insert ``image`` at the front of the gallery and trim to ``max_items``.

        Args:
            image: Object exposing ``model_dump(by_alias=True)``.

        Returns:
            The alias-keyed dict that was stored.
        """
        payload = image.model_dump(by_alias=True)
        with self.lock:
            data = self._read()
            images = data.get("images", [])
            images.insert(0, payload)
            data["images"] = images[: self.max_items]
            self._write(data)
        return payload

    def toggle_like(self, image_id: str, user_id: str) -> Optional[dict]:
        """Add or remove ``user_id`` from the likes of the image ``image_id``.

        Returns:
            The updated image dict, or None when no image has that id.
        """
        with self.lock:
            data = self._read()
            images = data.get("images", [])
            # Skip entries that are not dicts so one bad record cannot break likes.
            target_image = next(
                (
                    img
                    for img in images
                    if isinstance(img, dict) and img.get("id") == image_id
                ),
                None,
            )

            if target_image is None:
                return None

            liked_by = target_image.get("likedBy", [])
            # Handle legacy data where likedBy might be missing
            if not isinstance(liked_by, list):
                liked_by = []

            if user_id in liked_by:
                liked_by.remove(user_id)
                # Never let the counter go negative, even if it drifted.
                target_image["likes"] = max(0, target_image.get("likes", 0) - 1)
            else:
                liked_by.append(user_id)
                target_image["likes"] = target_image.get("likes", 0) + 1

            target_image["likedBy"] = liked_by
            self._write(data)
            return target_image


# Single store instance shared by the request handlers in this module.
gallery_store = GalleryStore(GALLERY_DATA_PATH, GALLERY_MAX_ITEMS)


app = FastAPI(title="Z-Image Proxy", version="1.0.0")

# ALLOWED_ORIGINS is a comma-separated list. Entries are stripped and empty
# ones dropped so that a value such as "https://a.example, https://b.example"
# does not produce an origin with a leading space that can never match.
# NOTE(review): the default "*" is combined with allow_credentials=True; the
# FastAPI CORS docs describe wildcard origins and credentials as incompatible —
# confirm whether credentials are actually needed here.
app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        origin.strip()
        for origin in os.getenv("ALLOWED_ORIGINS", "*").split(",")
        if origin.strip()
    ],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.on_event("startup")
async def startup() -> None:
    """Create the shared async HTTP client and attach it to ``app.state``."""
    # REQUEST_TIMEOUT_SECONDS is the default timeout; connecting is capped at 5 seconds.
    app.state.http = httpx.AsyncClient(
        timeout=httpx.Timeout(REQUEST_TIMEOUT_SECONDS, connect=5.0),
    )


@app.on_event("shutdown")
async def shutdown() -> None:
    """Close the HTTP client stored on ``app.state``."""
    http_client = app.state.http
    await http_client.aclose()


@app.get("/health")
async def health() -> dict:
    """Return a fixed payload indicating the service is running."""
    status_report = dict(status="ok")
    return status_report


@app.post("/likes/{image_id}")
async def toggle_like(
    image_id: str,
    user_id: str = Query(..., alias="userId"),
) -> dict:
    """Flip the like state of ``image_id`` for the user given by ``userId``.

    Responds with the updated image record, or a 404 when the store has no
    image with that id.
    """
    record = gallery_store.toggle_like(image_id, user_id)
    if record:
        return record
    raise HTTPException(status_code=404, detail="Image not found")


@app.get("/gallery")
async def gallery(
    limit: int = Query(200, ge=1, le=1000),
    author_id: Optional[str] = Query(default=None, alias="authorId"),
) -> dict:
    """List stored gallery images, capped at ``limit``.

    When ``authorId`` is supplied (and non-empty), only images whose
    "authorId" equals it are returned.
    """
    records = gallery_store.list_images()
    if author_id:
        records = list(
            filter(lambda entry: entry.get("authorId") == author_id, records)
        )
    page = records[:limit]
    return {"images": page}


@app.post("/generate", response_model=ImageGenerationResponse)
async def generate_image(payload: ImageGenerationPayload) -> ImageGenerationResponse:
    """Forward a generation request to Z-Image and record the result in the gallery.

    The upstream body contains every non-null payload field except
    ``author_id``. A random seed is chosen when the caller did not supply one,
    and the seed actually used is echoed back in ``request_params``.

    Raises:
        HTTPException: 502 when the upstream service is unreachable or returns
            a body that is not a JSON object containing "image" or "url";
            the upstream status code when it answers with a non-200 status.
    """
    request_params_data = payload.model_dump()
    # author_id is gallery metadata only and is not sent upstream.
    body = {
        key: value
        for key, value in request_params_data.items()
        if value is not None and key != "author_id"
    }
    if "seed" not in body:
        body["seed"] = secrets.randbelow(1_000_000_000)
    request_params_data["seed"] = body["seed"]
    request_params = ImageGenerationPayload(**request_params_data)
    url = f"{Z_IMAGE_BASE_URL}/generate"

    try:
        resp = await app.state.http.post(url, json=body)
    except httpx.RequestError as exc:  # pragma: no cover - network errors only
        raise HTTPException(status_code=502, detail=f"Z-Image service unreachable: {exc}") from exc

    if resp.status_code != 200:
        raise HTTPException(status_code=resp.status_code, detail=f"Z-Image error: {resp.text}")

    # A 200 response with a non-JSON body is an upstream fault, not a server
    # error of this proxy, so report it as 502 instead of crashing.
    try:
        data = resp.json()
    except ValueError as exc:
        raise HTTPException(
            status_code=502,
            detail="Malformed response from Z-Image: body is not valid JSON",
        ) from exc
    if not isinstance(data, dict):
        raise HTTPException(status_code=502, detail=f"Malformed response from Z-Image: {data}")

    image = data.get("image")
    image_url = data.get("url")
    if not image and not image_url:
        raise HTTPException(status_code=502, detail=f"Malformed response from Z-Image: {data}")

    # Timing is informational only: a null or non-numeric value must not fail
    # the whole request.
    try:
        time_taken = float(data.get("time_taken") or 0.0)
    except (TypeError, ValueError):
        time_taken = 0.0

    stored_image: Optional[GalleryImage] = None
    try:
        stored = gallery_store.add_image(
            GalleryImage(
                id=data.get("id") or secrets.token_hex(16),
                prompt=payload.prompt,
                width=payload.width,
                height=payload.height,
                num_inference_steps=payload.num_inference_steps,
                guidance_scale=payload.guidance_scale,
                seed=request_params.seed,
                # Base64 results are stored as a data URL so gallery entries always have a URL.
                url=image_url or f"data:image/png;base64,{image}",
                author_id=payload.author_id,
                negative_prompt=payload.negative_prompt,
            )
        )
        stored_image = GalleryImage.model_validate(stored)
    except Exception as exc:  # pragma: no cover - diagnostics only
        # Persisting gallery data should not block the response
        print(f"[WARN] Failed to store gallery image: {exc}")

    return ImageGenerationResponse(
        image=image,
        url=image_url,
        time_taken=time_taken,
        error=data.get("error"),
        request_params=request_params,
        gallery_item=stored_image,
    )