#!/usr/bin/env python3
"""Generate an English .srt next to each video using whisper.cpp."""

import argparse
import os
import re
import shutil
import subprocess
import sys
import tempfile
from pathlib import Path

def _env_path(var: str, default: Path) -> Path:
    """Return environment variable *var* as a Path, or *default* when unset."""
    return Path(os.environ.get(var, default))


# Root of the whisper.cpp checkout; the binary/model defaults below live under it.
WHISPER_DIR = _env_path("WHISPER_DIR", Path.home() / "repos/whisper.cpp")
# ggml Whisper model used unless -m/--model is given.
DEFAULT_MODEL = _env_path("WHISPER_MODEL", WHISPER_DIR / "models/ggml-large-v3.bin")
# whisper-cli executable (overridable via the hidden --bin option).
DEFAULT_BIN = _env_path("WHISPER_BIN", WHISPER_DIR / "build/bin/whisper-cli")
# Silero VAD model (overridable via the hidden --vad-model option).
DEFAULT_VAD = _env_path("VAD_MODEL", WHISPER_DIR / "models/ggml-silero-v6.2.0.bin")

# A cue body needs at least this many words before wrap_line will split it.
MIN_WORDS_TO_WRAP = 2
# Leading lines of each SRT cue kept verbatim by wrap_srt (index + timing).
SRT_HEADER_LINES = 2


def wrap_line(text: str, limit: int) -> str:
    """Break *text* into at most two lines, aiming for *limit* chars per line.

    Runs of whitespace are collapsed to single spaces first. If the result
    already fits within *limit*, or has fewer than MIN_WORDS_TO_WRAP words,
    it is returned as a single line. Otherwise exactly one break is inserted
    between two words. Each candidate break is scored as follows:

    - the length difference between the two lines;
    - plus 100 per character that either line runs past *limit*;
    - minus 8 when the break follows a word ending in , . ! ? : or ;.

    The lowest score wins, and ties go to the earliest break.
    """
    words = text.split()
    flat = " ".join(words)
    if len(flat) <= limit or len(words) < MIN_WORDS_TO_WRAP:
        return flat

    # top_lengths[k] is the first line's length when breaking after words[k].
    top_lengths = []
    running = -1  # cancels the separator counted in front of the first word
    for word in words[:-1]:
        running += 1 + len(word)
        top_lengths.append(running)

    def penalty(k: int) -> int:
        top = top_lengths[k]
        bottom = len(flat) - top - 1  # the separating space becomes "\n"
        cost = abs(top - bottom)
        cost += 100 * max(0, top - limit)
        cost += 100 * max(0, bottom - limit)
        if words[k].rstrip(",.!?:;") != words[k]:
            cost -= 8
        return cost

    # min() returns the first minimal candidate, i.e. the earliest break.
    split_after = min(range(len(top_lengths)), key=penalty)
    first = " ".join(words[: split_after + 1])
    second = " ".join(words[split_after + 1 :])
    return "\n".join((first, second))


def wrap_srt(path: Path, limit: int) -> None:
    """Rewrap the text of every cue in the SRT file at *path*, in place.

    Cues are separated by one or more blank lines. In each cue, the first
    SRT_HEADER_LINES lines are kept verbatim; in standard SRT these are the
    index and timing lines. The remaining text lines are joined with spaces
    and re-split by wrap_line() to at most two lines of about *limit* chars.
    Cues with no text lines are left untouched. The file is rewritten as UTF-8
    with a single trailing newline.
    """

    def rewrap(cue: str) -> str:
        cue_lines = cue.split("\n")
        if len(cue_lines) <= SRT_HEADER_LINES:
            return cue
        header = cue_lines[:SRT_HEADER_LINES]
        body = " ".join(cue_lines[SRT_HEADER_LINES:])
        return "\n".join(header + [wrap_line(body, limit)])

    content = path.read_text(encoding="utf-8").strip()
    cues = [rewrap(cue) for cue in re.split(r"\n\n+", content)]
    path.write_text("\n\n".join(cues) + "\n", encoding="utf-8")


def output_stem(video: Path, override: str | None) -> Path:
    if override:
        stem = override[:-4] if override.endswith(".srt") else override
        return Path(stem)
    return video.with_suffix("")


def _extract_audio(video: Path, wav_path: Path) -> None:
    """Decode the audio of *video* to 16 kHz mono 16-bit PCM WAV at *wav_path*.

    Raises:
        subprocess.CalledProcessError: if ffmpeg exits non-zero.
    """
    subprocess.run(
        [
            "ffmpeg",
            # Never read stdin: otherwise ffmpeg consumes it for interactive
            # commands, which breaks callers that pipe file lists into a loop.
            "-nostdin",
            "-hide_banner",
            "-loglevel",
            "error",
            "-y",
            "-i",
            str(video),
            "-vn",
            "-ar",
            "16000",
            "-ac",
            "1",
            "-c:a",
            "pcm_s16le",
            str(wav_path),
        ],
        check=True,
    )


def _whisper_command(
    args: argparse.Namespace, wav_path: Path, out_stem: Path
) -> list[str]:
    """Return the whisper-cli argv that writes ``<out_stem>.srt`` from *wav_path*.

    When VAD is enabled but the VAD model file is missing, prints a warning
    to stderr and leaves the VAD flags out.
    """
    cmd = [
        str(args.bin),
        "-m",
        str(args.model),
        "-f",
        str(wav_path),
        "-of",
        str(out_stem),
        "--output-srt",
        "-l",
        args.lang,
        # Zero text context: don't condition on previously decoded text
        # (presumably to curb repetition loops).
        "-mc",
        "0",
    ]
    if args.translate:
        cmd.append("--translate")
    if args.max_len > 0:
        cmd += ["--max-len", str(args.max_len), "--split-on-word"]
    if args.vad:
        if args.vad_model.exists():
            cmd += ["--vad", "--vad-model", str(args.vad_model)]
            if args.vad_max_speech > 0:
                cmd += ["-vmsd", str(args.vad_max_speech)]
        else:
            print(
                f"warn: VAD model not found at {args.vad_model}, running without VAD",
                file=sys.stderr,
            )
            print(
                f"      download with: sh {WHISPER_DIR}/models/download-vad-model.sh silero-v6.2.0",
                file=sys.stderr,
            )
    return cmd


def process(video: Path, args: argparse.Namespace) -> None:
    """Write an .srt subtitle file for *video* using ffmpeg and whisper-cli.

    The output goes to ``<stem>.srt``, where ``<stem>`` comes from
    output_stem(). An existing file is left alone unless ``args.force`` is
    set. When ``args.line_len > 0``, the result is rewrapped with wrap_srt().

    Raises:
        subprocess.CalledProcessError: if ffmpeg or whisper-cli fails.
    """
    out_stem = output_stem(video, args.output)
    # whisper-cli appends ".srt" to the -of prefix verbatim. with_suffix()
    # would instead replace the stem's last dotted part ("Show.S01E02" ->
    # "Show.srt"), so the skip check and the rewrap would target the wrong file.
    srt = out_stem.with_name(out_stem.name + ".srt")

    if srt.exists() and not args.force:
        print(f"skip: {srt} exists (use --force to overwrite)")
        return

    out_stem.parent.mkdir(parents=True, exist_ok=True)
    print(f">> {video}")

    # The intermediate WAV only lives as long as the temporary directory.
    with tempfile.TemporaryDirectory() as tmpdir:
        wav_path = Path(tmpdir) / "audio.wav"
        _extract_audio(video, wav_path)
        subprocess.run(_whisper_command(args, wav_path, out_stem), check=True)

    if args.line_len > 0:
        wrap_srt(srt, args.line_len)

    print(f"<< {srt}")


def main() -> int:
    """Parse the command line and subtitle each given video.

    Exits early (via argparse or sys.exit) on invalid option combinations,
    a missing whisper-cli binary or model, or a missing ffmpeg. Inputs that
    are not regular files are skipped with a message. A failing ffmpeg or
    whisper-cli run is reported and the remaining inputs are still processed.

    Returns:
        0 if every processed video succeeded, 1 if any external tool failed.
    """
    p = argparse.ArgumentParser(
        prog="video2srt",
        description="Generate an English .srt next to each video using whisper.cpp.",
    )
    p.add_argument("videos", nargs="+", type=Path, metavar="VIDEO")
    p.add_argument(
        "-t",
        "--transcribe",
        dest="translate",
        action="store_false",
        help="Transcribe in source language (default: translate to English)",
    )
    p.add_argument(
        "-l",
        "--lang",
        default=os.environ.get("SRC_LANG", "auto"),
        help="Force source language (default: auto-detect)",
    )
    p.add_argument(
        "-m",
        "--model",
        type=Path,
        default=DEFAULT_MODEL,
        help=f"Path to ggml model (default: {DEFAULT_MODEL})",
    )
    p.add_argument(
        "-o",
        "--output",
        help="Output .srt path (single input only; default: <video>.srt)",
    )
    p.add_argument(
        "-f", "--force", action="store_true", help="Overwrite existing .srt"
    )
    p.add_argument(
        "--no-vad",
        dest="vad",
        action="store_false",
        help="Disable Silero VAD pre-filtering (VAD reduces hallucination loops)",
    )
    p.add_argument(
        "--vad-max-speech",
        type=float,
        default=15.0,
        help="Max seconds of speech per VAD chunk; shorter values give tighter timestamps (default: 15)",
    )
    p.add_argument(
        "--max-len",
        type=int,
        default=84,
        help="Max characters per SRT entry, 0 to disable (default: 84)",
    )
    p.add_argument(
        "--line-len",
        type=int,
        default=42,
        help="Max characters per visible line, 0 to disable (default: 42)",
    )
    # Hidden overrides; their defaults can also be set via environment variables.
    p.add_argument(
        "--bin", type=Path, default=DEFAULT_BIN, help=argparse.SUPPRESS
    )
    p.add_argument(
        "--vad-model", type=Path, default=DEFAULT_VAD, help=argparse.SUPPRESS
    )

    args = p.parse_args()

    if args.output and len(args.videos) > 1:
        p.error("--output cannot be combined with multiple input files")
    if not args.bin.is_file() or not os.access(args.bin, os.X_OK):
        sys.exit(f"whisper-cli not found at {args.bin}")
    if not args.model.is_file():
        sys.exit(f"model not found at {args.model}")
    if shutil.which("ffmpeg") is None:
        sys.exit("ffmpeg not installed")

    failures = 0
    for video in args.videos:
        if not video.is_file():
            print(f"skip: {video} (not a file)", file=sys.stderr)
            continue
        try:
            process(video, args)
        except subprocess.CalledProcessError as err:
            # One bad input (no audio stream, corrupt file, ...) should not
            # abort the rest of the batch; report it and keep going.
            failures += 1
            tool = err.cmd[0] if isinstance(err.cmd, list) and err.cmd else err.cmd
            print(
                f"error: {video}: {tool} exited with status {err.returncode}",
                file=sys.stderr,
            )

    return 1 if failures else 0


if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    raise SystemExit(main())
