# test_z_image_client_2.py (6.24 KB)
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Client script to test Z-Image server."""

import base64
import time
import requests
from PIL import Image
import io


def test_health_check(base_url="http://localhost:9009", timeout=10):
    """Query the server's ``/health`` endpoint and print the result.

    Args:
        base_url: Root URL of the Z-Image server, without a trailing slash.
        timeout: Seconds requests waits to connect and between received
            bytes before giving up.

    Returns:
        The decoded JSON payload of a successful (HTTP 200) health check.

    Raises:
        requests.exceptions.RequestException: If the request cannot be
            completed, or (as ``HTTPError``) if the server answers with a
            status other than 200.
        ValueError: If a 200 response body is not valid JSON.
    """
    print("Testing health check...")

    # Only the network call lives in this try. requests>=2.27 raises
    # requests.exceptions.JSONDecodeError, which subclasses BOTH
    # RequestException and ValueError. If .json() sat in the same try, a bad
    # body would be misreported as a "Request error".
    try:
        response = requests.get(f"{base_url}/health", timeout=timeout)
    except requests.exceptions.RequestException as e:
        print(f"Request error: {e}")
        raise

    print(f"Status code: {response.status_code}")
    print(f"Response text: {response.text}")
    print(f"Response headers: {dict(response.headers)}")

    if response.status_code != 200:
        print(f"Error: HTTP {response.status_code}")
        # Raise so callers (main() aborts on any exception here) do not go on
        # to run the generation tests against an unhealthy server.
        raise requests.exceptions.HTTPError(
            f"Health check returned HTTP {response.status_code}",
            response=response,
        )

    try:
        payload = response.json()
    except ValueError as e:
        print(f"JSON decode error: {e}")
        print(f"Response content: {response.text}")
        raise

    print(f"Health check: {payload}")
    print()
    return payload


def test_generate_image(
    base_url="http://localhost:9009",
    save_path="generated_image.png",
    output_format="base64",
    timeout=600,
):
    """Request an image from ``/generate`` and save it to *save_path*.

    With ``output_format="url"`` the image is downloaded from the URL in the
    response. If no URL comes back but the response still carries base64
    data under ``"image"``, that data is used instead. Any other
    ``output_format`` value is treated as base64: the image is decoded from
    the ``"image"`` field.

    Args:
        base_url: Root URL of the Z-Image server, without a trailing slash.
        save_path: Destination file. PIL infers the file format from the
            extension.
        output_format: Sent to the server as ``output_format``; ``"url"``
            selects the URL-handling path.
        timeout: Seconds requests waits to connect and between received
            bytes of the generation request. ``None`` waits indefinitely
            (the previous behavior).

    Raises:
        requests.exceptions.RequestException: If the generation request fails
            or times out.
        ValueError: If a 200 response body is not valid JSON.
    """
    format_name = "base64" if output_format == "base64" else "URL"
    print(f"Testing image generation ({format_name})...")

    request_data = {
        # "A cute panda eating bamboo in a bamboo forest, sunlight filtering
        # through the leaves, 4K HD, photography"
        "prompt": "一只可爱的熊猫在竹林里吃竹子,阳光透过树叶洒下,4K高清,摄影作品",
        "height": 1024,
        "width": 1024,
        "num_inference_steps": 8,
        "guidance_scale": 0.0,
        "seed": 42,
        "output_format": output_format,
    }

    print(f"Prompt: {request_data['prompt']}")
    print(f"Size: {request_data['width']}x{request_data['height']}")
    print(f"Output format: {output_format}")

    # perf_counter is monotonic, so the duration is immune to clock changes.
    start = time.perf_counter()
    # Without a timeout, requests would block forever on a stalled server.
    response = requests.post(f"{base_url}/generate", json=request_data, timeout=timeout)
    elapsed = time.perf_counter() - start

    if response.status_code != 200:
        print(f"Error: {response.status_code}")
        try:
            print(response.json())
        except ValueError:
            # Error bodies need not be JSON (e.g. an HTML error page); show
            # the raw text rather than masking the HTTP status with a crash.
            print(response.text)
        print()
        return

    result = response.json()
    print("Generation successful!")
    # Missing timing info should not prevent the image from being saved.
    time_taken = result.get("time_taken")
    if time_taken is not None:
        print(f"Time taken: {time_taken:.2f}s")
    print(f"Total request time: {elapsed:.2f}s")

    if output_format == "url":
        image_url = result.get("url")
        if image_url:
            print(f"Image URL: {image_url}")
            _download_image(image_url, save_path)
        else:
            print("Warning: No URL returned in response")
            if result.get("image"):
                print("Fallback: Using base64 image from response")
                _save_base64_image(result["image"], save_path)
    elif result.get("image"):
        _save_base64_image(result["image"], save_path)
    else:
        print("Warning: No image data in response")

    print()


def _save_base64_image(encoded, save_path):
    """Decode a base64-encoded image, re-save it via PIL, and report the path."""
    img = Image.open(io.BytesIO(base64.b64decode(encoded)))
    img.save(save_path)
    print(f"Image saved to: {save_path}")


def _download_image(image_url, save_path, timeout=10):
    """Download *image_url* and save it; report any failure instead of raising."""
    try:
        img_response = requests.get(image_url, timeout=timeout)
        if img_response.status_code == 200:
            img = Image.open(io.BytesIO(img_response.content))
            img.save(save_path)
            print(f"Image downloaded and saved to: {save_path}")
        else:
            print(f"Failed to download image from URL: HTTP {img_response.status_code}")
    except Exception as e:
        # Best effort: a failed download is reported, not propagated, so the
        # remaining tests still run.
        print(f"Error downloading image from URL: {e}")


def test_generate_stream(
    base_url="http://localhost:9009",
    save_path="generated_stream.png",
    timeout=600,
):
    """Request an image from ``/generate_stream`` and save it to *save_path*.

    The response body is treated as raw image bytes (no JSON wrapper). It is
    read fully into memory and then decoded with PIL.

    Args:
        base_url: Root URL of the Z-Image server, without a trailing slash.
        save_path: Destination file. PIL infers the file format from the
            extension.
        timeout: Seconds requests waits to connect and between received
            bytes. ``None`` waits indefinitely (the previous behavior).

    Raises:
        requests.exceptions.RequestException: If the request fails or times out.
        PIL.UnidentifiedImageError: If a 200 response body is not an image.
    """
    print("Testing image generation (stream)...")

    request_data = {
        "prompt": "A futuristic cityscape at sunset, flying cars, neon lights, cyberpunk style, highly detailed",
        "height": 1024,
        "width": 1024,
        "num_inference_steps": 8,
        "guidance_scale": 0.0,
        "seed": 123,
    }

    print(f"Prompt: {request_data['prompt']}")
    print(f"Size: {request_data['width']}x{request_data['height']}")

    # perf_counter is monotonic, so the duration is immune to clock changes.
    start = time.perf_counter()
    # Without a timeout, requests would block forever on a stalled server.
    response = requests.post(
        f"{base_url}/generate_stream", json=request_data, timeout=timeout
    )
    elapsed = time.perf_counter() - start

    if response.status_code == 200:
        print("Generation successful!")
        print(f"Total request time: {elapsed:.2f}s")

        img = Image.open(io.BytesIO(response.content))
        img.save(save_path)
        print(f"Image saved to: {save_path}")
    else:
        print(f"Error: {response.status_code}")
        print(response.text)

    print()


def main(base_url=None):
    """Run the health check, then every image-generation test.

    The run aborts if the health check raises. Each generation test runs
    independently: a failure is reported and the remaining tests still run.
    A summary of failed tests is printed at the end.

    Args:
        base_url: Server root URL. When omitted, the ``Z_IMAGE_BASE_URL``
            environment variable is used, falling back to the previously
            hard-coded test server.
    """
    import os

    if base_url is None:
        base_url = os.environ.get("Z_IMAGE_BASE_URL", "http://106.120.52.146:39009")

    print("=" * 60)
    print("Z-Image Server Client Test")
    print("=" * 60)
    print()

    try:
        test_health_check(base_url)
    except Exception as e:
        # Top-level boundary: any failure here means the server is unusable.
        print(f"Health check failed: {e}")
        print("Make sure the server is running!")
        return

    # (label used in the failure message, zero-argument test runner)
    generation_tests = [
        (
            "Image generation test (base64)",
            lambda: test_generate_image(base_url, "test_output_base64.png", output_format="base64"),
        ),
        (
            "Image generation test (URL)",
            lambda: test_generate_image(base_url, "test_output_url.png", output_format="url"),
        ),
        (
            "Stream generation test",
            lambda: test_generate_stream(base_url, "test_output_stream.png"),
        ),
    ]

    failed = []
    for label, run_test in generation_tests:
        try:
            run_test()
        except Exception as e:
            print(f"{label} failed: {e}")
            failed.append(label)

    print("=" * 60)
    if failed:
        print(f"{len(failed)} test(s) failed: {', '.join(failed)}")
    print("All tests completed!")
    print("=" * 60)


if __name__ == "__main__":
    # Run the client tests only when executed as a script, not when imported.
    main()