# predict.py 5.56 KB
from __future__ import annotations

import argparse
from pathlib import Path


# Hugging Face model identifier used when no local copy of the model exists.
MODEL_NAME = "tabularisai/multilingual-sentiment-analysis"

# Folder containing this script; the default model directory sits next to it.
SCRIPT_DIR = Path(__file__).resolve().parent
DEFAULT_MODEL_DIR = SCRIPT_DIR.joinpath("model")

# Display label for each predicted class index, ordered from
# "very negative" (0) to "very positive" (4).
SENTIMENT_LABELS = dict(
    enumerate(("非常负面", "负面", "中性", "正面", "非常正面"))
)

# Built-in (text, language hint) samples used by the demo run.
DEMO_TEXTS = (
    # Chinese
    ("今天天气真好,心情特别棒。", "zh"),
    ("服务态度太差了,很失望。", "zh"),
    # English
    ("I absolutely love this product.", "en"),
    ("The customer service was disappointing.", "en"),
    # Spanish
    ("Este lugar es increible y muy acogedor.", "es"),
    ("El servicio fue terrible y muy lento.", "es"),
    # Japanese
    ("このレストランの料理は本当に美味しいです。", "ja"),
    ("このホテルのサービスにはがっかりしました。", "ja"),
)


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options of the sentiment analysis demo.

    Args:
        argv: Argument strings to parse. ``None`` (the default) lets argparse
            read the process command line, which is what the script entry
            point relies on; passing a list allows programmatic use.

    Returns:
        Namespace with the attributes ``text``, ``lang``, ``demo``,
        ``model_dir`` and ``device``.
    """
    parser = argparse.ArgumentParser(
        description="Run the multilingual sentiment analysis research demo."
    )
    parser.add_argument(
        "--text",
        help="Single text to analyze. If omitted, the script enters interactive mode.",
    )
    parser.add_argument(
        "--lang",
        default="auto",
        help="Compatibility-only language hint for CLI parity. Default: auto.",
    )
    parser.add_argument(
        "--demo",
        action="store_true",
        help="Run a built-in multilingual demo set and exit.",
    )
    parser.add_argument(
        "--model-dir",
        type=Path,
        default=DEFAULT_MODEL_DIR,
        help="Directory used to cache the Hugging Face model locally.",
    )
    parser.add_argument(
        "--device",
        choices=("auto", "cpu", "cuda"),
        default="auto",
        help="Execution device. Default: auto.",
    )
    # argparse falls back to the process arguments when argv is None.
    return parser.parse_args(argv)


def import_runtime_dependencies():
    """Import torch and the transformers auto classes at call time.

    The imports live inside the function, so they run when it is called
    rather than when this module is imported.

    Returns:
        Tuple ``(torch, AutoTokenizer, AutoModelForSequenceClassification)``.
    """
    import torch
    import transformers

    tokenizer_cls = transformers.AutoTokenizer
    classifier_cls = transformers.AutoModelForSequenceClassification
    return torch, tokenizer_cls, classifier_cls


def choose_device(device_name: str):
    """Translate the ``--device`` choice into a ``torch.device``.

    Args:
        device_name: ``"cpu"``, ``"cuda"``, or any other value (``"auto"`` on
            the CLI), which selects CUDA when available and CPU otherwise.

    Returns:
        The selected ``torch.device``.

    Raises:
        RuntimeError: If ``"cuda"`` is requested but CUDA is not available.
    """
    torch = import_runtime_dependencies()[0]

    if device_name == "cpu":
        return torch.device("cpu")

    cuda_ready = torch.cuda.is_available()
    if device_name == "cuda" and not cuda_ready:
        raise RuntimeError("CUDA requested but not available.")

    # Explicit "cuda" with CUDA present and the automatic choice end up here.
    return torch.device("cuda" if cuda_ready else "cpu")


def has_local_model(model_dir: Path) -> bool:
    """Report whether *model_dir* holds a saved model configuration file.

    Args:
        model_dir: Directory expected to contain the locally saved model.

    Returns:
        ``True`` only when ``config.json`` inside *model_dir* is a regular
        file; ``False`` when it is missing or is something else.
    """
    # is_file() rather than exists(): a directory that happens to be named
    # config.json must not be mistaken for a saved model.
    return (model_dir / "config.json").is_file()


def load_model(model_dir: Path):
    """Load the tokenizer and classifier, saving them locally on first use.

    When *model_dir* already holds a saved model (per ``has_local_model``),
    both objects are loaded from that directory. Otherwise they are loaded by
    the ``MODEL_NAME`` identifier and then saved into *model_dir*, which is
    created if needed.

    Args:
        model_dir: Directory used as the local model cache.

    Returns:
        Tuple ``(tokenizer, model, downloaded)``; ``downloaded`` is ``True``
        when the model was loaded by name instead of from *model_dir*.
    """
    _, tokenizer_cls, classifier_cls = import_runtime_dependencies()
    cache_dir = model_dir.resolve()

    downloaded = not has_local_model(cache_dir)
    source = MODEL_NAME if downloaded else str(cache_dir)

    tokenizer = tokenizer_cls.from_pretrained(source)
    model = classifier_cls.from_pretrained(source)

    if downloaded:
        # Persist both pieces so the next run can load from the directory.
        cache_dir.mkdir(parents=True, exist_ok=True)
        for artifact in (tokenizer, model):
            artifact.save_pretrained(cache_dir)

    return tokenizer, model, downloaded


def predict_text(tokenizer, model, device, text: str):
    """Classify the sentiment of a single text.

    Args:
        tokenizer: Callable that encodes *text* into a mapping of tensors.
        model: Classifier whose output exposes ``logits``.
        device: Device the encoded tensors are moved to before inference.
        text: Text to analyze.

    Returns:
        Tuple ``(label, confidence, probabilities)``: the ``SENTIMENT_LABELS``
        entry of the most probable class, that class's probability, and the
        full list of class probabilities.
    """
    torch = import_runtime_dependencies()[0]

    encoded = tokenizer(
        text,
        max_length=512,
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    on_device = {}
    for name, tensor in encoded.items():
        on_device[name] = tensor.to(device)

    with torch.no_grad():
        logits = model(**on_device).logits
        # Only the first row of the softmax output is used.
        first_row = torch.softmax(logits, dim=1)[0]
        probabilities = first_row.cpu().tolist()

    # max() keeps the first of equally probable classes.
    best_index, best_score = max(enumerate(probabilities), key=lambda pair: pair[1])
    return SENTIMENT_LABELS[best_index], best_score, probabilities


def print_prediction(tokenizer, model, device, text: str, language_hint: str = "auto") -> None:
    """Run a prediction for *text* and print a report to standard output.

    The report shows the text, the language hint, the predicted label with
    its confidence, and the probability of every class.

    Args:
        tokenizer: Tokenizer forwarded to ``predict_text``.
        model: Classifier forwarded to ``predict_text``.
        device: Device forwarded to ``predict_text``.
        text: Text to analyze.
        language_hint: Label echoed in the report; it is not passed to the
            prediction call.
    """
    outcome = predict_text(tokenizer, model, device, text)
    top_label, top_score, distribution = outcome

    print(f"文本: {text}")
    print(f"语言提示: {language_hint}")
    print(f"预测结果: {top_label} (置信度: {top_score:.4f})")
    print("详细概率分布:")
    for class_index, class_score in enumerate(distribution):
        class_label = SENTIMENT_LABELS[class_index]
        print(f"  - {class_label}: {class_score:.4f}")


def run_demo(tokenizer, model, device) -> None:
    """Print a prediction report for every entry of ``DEMO_TEXTS``.

    Each report is followed by a blank line.

    Args:
        tokenizer: Tokenizer forwarded to ``print_prediction``.
        model: Classifier forwarded to ``print_prediction``.
        device: Device forwarded to ``print_prediction``.
    """
    print("运行内置多语言示例...")
    for sample in DEMO_TEXTS:
        # Each sample is a (text, language hint) pair.
        print_prediction(tokenizer, model, device, *sample)
        print("")


def interactive_loop(tokenizer, model, device) -> None:
    """Read texts from standard input and print a prediction for each one.

    Every line is stripped and then handled as follows (commands are
    case-insensitive):

    * ``q`` ends the loop;
    * an empty line prints a reminder and prompts again;
    * ``demo`` runs the built-in demo set;
    * anything else is analyzed and reported.

    The loop also ends quietly when standard input reaches end-of-file
    (Ctrl-D, or piped input running out) instead of failing with an
    unhandled ``EOFError``.

    Args:
        tokenizer: Tokenizer forwarded to the prediction helpers.
        model: Classifier forwarded to the prediction helpers.
        device: Device forwarded to the prediction helpers.
    """
    print("进入交互模式。输入 `demo` 运行示例,输入 `q` 退出。")
    while True:
        try:
            text = input("请输入文本: ").strip()
        except EOFError:
            # No more input is coming; finish the prompt line and leave.
            print("")
            return

        command = text.lower()
        if command == "q":
            return
        if not text:
            print("输入不能为空,请重新输入。")
            continue
        if command == "demo":
            run_demo(tokenizer, model, device)
            continue
        print_prediction(tokenizer, model, device, text)
        print("")


def main() -> None:
    """Entry point: load the model, then run the requested mode.

    After parsing the arguments, the device is chosen, the model is loaded,
    moved to the device and put into eval mode. The mode is then picked in
    this order: ``--demo`` runs the built-in demo, a non-empty ``--text``
    analyzes that single text, and otherwise the interactive loop starts.
    """
    args = parse_args()
    device = choose_device(args.device)
    tokenizer, model, downloaded = load_model(args.model_dir)
    model.to(device)
    model.eval()

    resolved_dir = args.model_dir.resolve()
    if downloaded:
        status = f"首次运行已下载模型到: {resolved_dir}"
    else:
        status = f"已从本地缓存加载模型: {resolved_dir}"
    print(status)
    print(f"当前执行设备: {device}")

    if args.demo:
        run_demo(tokenizer, model, device)
    elif args.text:
        print_prediction(tokenizer, model, device, args.text, args.lang)
    else:
        interactive_loop(tokenizer, model, device)


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()