"""Core utilities for the Nano Banana 2 declination app.

Focus of this version:
- understand visible text in an image
- generate a clean plate from that understanding
- fall back to the local ComfyUI inpaint workflow if the Gemini image backend
  is unavailable
"""

from __future__ import annotations

import base64
import io
import json
import math
import os
import re
import shutil
import subprocess
import tempfile
import time
import zipfile
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Iterable, Sequence

import cv2
import numpy as np
import requests
from PIL import Image, ImageDraw, ImageFont, ImageOps

# Directory containing this module; bundled resources are resolved from it.
ROOT = Path(__file__).resolve().parent
# Generated artefacts: bundles go to OUTPUT_ROOT, ComfyUI results to RUN_ROOT.
OUTPUT_ROOT = ROOT / "outputs"
RUN_ROOT = OUTPUT_ROOT / "runs"
# Inpaint workflow shipped next to this module.
COMFY_WORKFLOW = ROOT / "workflows" / "sdxl_inpaint.json"
# Machine-specific ComfyUI locations. The defaults are the historical paths;
# the environment variables let another machine point somewhere else without
# editing this file. An empty variable falls back to the default.
COMFY_SCRIPT = Path(
    os.getenv("NANO_BANANA_COMFY_SCRIPT")
    or "/home/wildlama/.hermes/skills/creative/comfyui/scripts/run_workflow.py"
)
COMFY_WORKSPACE = Path(os.getenv("NANO_BANANA_COMFY_WORKSPACE") or "/home/wildlama/comfy/ComfyUI")

# Model name sent to the backend; overridable through NANO_BANANA_MODEL.
DEFAULT_MODEL = os.getenv("NANO_BANANA_MODEL", "gemini-3.1-flash-image")
# First non-empty API key among the supported variables, "" when none is set.
DEFAULT_API_KEY = (
    os.getenv("NANO_BANANA_API_KEY")
    or os.getenv("GOOGLE_API_KEY")
    or os.getenv("GEMINI_API_KEY")
    or os.getenv("GOOGLE_GENAI_API_KEY")
    or ""
)
DEFAULT_API_BASE = os.getenv("NANO_BANANA_API_BASE", "https://generativelanguage.googleapis.com/v1beta")

# Fonts tried in order when rendering contact-sheet titles.
DEFAULT_FONT_CANDIDATES = [
    Path("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf"),
    Path("/usr/share/fonts/truetype/liberation2/LiberationSans-Regular.ttf"),
    Path("/usr/share/fonts/truetype/freefont/FreeSans.ttf"),
]

# google-genai is an optional dependency: when the import fails for any reason
# `genai` is bound to None, which the backend helpers below check before use.
try:
    from google import genai  # type: ignore
except Exception:  # pragma: no cover - optional dependency
    genai = None


@dataclass(frozen=True)
class Box:
    """Immutable axis-aligned rectangle with an attached detection score.

    The rectangle is stored as two corners, ``(x1, y1)`` and ``(x2, y2)``.
    """

    x1: int
    y1: int
    x2: int
    y2: int
    score: float = 0.0

    @property
    def width(self) -> int:
        """Horizontal extent, never negative."""
        span = self.x2 - self.x1
        return span if span > 0 else 0

    @property
    def height(self) -> int:
        """Vertical extent, never negative."""
        span = self.y2 - self.y1
        return span if span > 0 else 0

    @property
    def area(self) -> int:
        """Product of ``width`` and ``height``."""
        return self.width * self.height

    def padded(self, pad: int, limit_w: int, limit_h: int) -> "Box":
        """Return a copy grown by ``pad`` on every side.

        The first corner is clamped at 0 and the second corner at
        ``(limit_w, limit_h)``; the score is carried over unchanged.
        """
        left = max(0, self.x1 - pad)
        top = max(0, self.y1 - pad)
        right = min(limit_w, self.x2 + pad)
        bottom = min(limit_h, self.y2 + pad)
        return Box(left, top, right, bottom, self.score)

    def to_tuple(self) -> tuple[int, int, int, int]:
        """Return the corners as ``(x1, y1, x2, y2)``, without the score."""
        return (self.x1, self.y1, self.x2, self.y2)


# ---------------------------------------------------------------------------
# Files and images
# ---------------------------------------------------------------------------


def ensure_dirs() -> None:
    """Create the output and run directories if they are missing."""
    for directory in (OUTPUT_ROOT, RUN_ROOT):
        directory.mkdir(parents=True, exist_ok=True)



def load_image(image_path: str | Path) -> Image.Image:
    """Open an image file and return it as an in-memory RGB image.

    The file is opened in a ``with`` block so its handle is released as soon
    as the pixels have been converted, instead of whenever the lazily loaded
    source image happens to be garbage collected.

    Args:
        image_path: Location of the image to read.

    Returns:
        A new RGB image that no longer depends on the open file.
    """
    with Image.open(image_path) as source:
        return source.convert("RGB")



def save_image(image: Image.Image, path: str | Path) -> Path:
    """Write ``image`` to ``path``, creating parent directories as needed.

    Returns:
        The destination as a ``Path``.
    """
    destination = Path(path)
    destination.parent.mkdir(parents=True, exist_ok=True)
    image.save(destination)
    return destination



def _font_path() -> Path:
    """Return the first existing candidate font, else the DejaVu Sans path."""
    last_resort = Path("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf")
    return next((font for font in DEFAULT_FONT_CANDIDATES if font.exists()), last_resort)



def make_contact_sheet(images: Sequence[Image.Image], titles: Sequence[str], thumb_size: tuple[int, int] = (512, 512)) -> Image.Image:
    """Lay out thumbnails of ``images`` on a dark grid with a title above each.

    Args:
        images: Pictures to show; each is fitted inside ``thumb_size`` and
            centred on its own cell.
        titles: Captions matched to ``images`` by position. Missing captions
            are treated as empty so that every image is still drawn; surplus
            captions are ignored.
        thumb_size: ``(width, height)`` of one cell, excluding the title band.

    Returns:
        The composed sheet: two columns (one when there is a single image),
        or a blank white image of ``thumb_size`` when ``images`` is empty.
    """
    if not images:
        return Image.new("RGB", thumb_size, "white")

    cell_w, cell_h = thumb_size
    row_h = cell_h + 40  # each row reserves 40 px for the title band

    thumbs = []
    for img in images:
        thumb = ImageOps.contain(img.convert("RGB"), thumb_size)
        canvas = Image.new("RGB", thumb_size, "#111111")
        canvas.paste(thumb, ((cell_w - thumb.width) // 2, (cell_h - thumb.height) // 2))
        thumbs.append(canvas)

    # Pad the captions so zip() below cannot silently drop untitled images.
    labels = list(titles)[: len(thumbs)]
    labels += [""] * (len(thumbs) - len(labels))

    cols = 2 if len(thumbs) > 1 else 1
    rows = math.ceil(len(thumbs) / cols)
    sheet = Image.new("RGB", (cols * cell_w, rows * row_h), "#0d0d0d")
    draw = ImageDraw.Draw(sheet)
    try:
        font = ImageFont.truetype(str(_font_path()), 22)
    except OSError:
        # None of the expected font files is readable on this machine.
        font = ImageFont.load_default()

    for i, (img, title) in enumerate(zip(thumbs, labels)):
        x = (i % cols) * cell_w
        y = (i // cols) * row_h
        sheet.paste(img, (x, y + 30))
        draw.text((x + 12, y + 4), title, fill="white", font=font)
    return sheet



def bundle_outputs(paths: Sequence[str | Path], bundle_name: str) -> Path:
    """Zip the existing files among ``paths`` into ``OUTPUT_ROOT / bundle_name``.

    Entries are stored flat under their base name. Paths that do not exist
    are skipped, and a file listed more than once is stored only once. When
    two different files share a base name, the later one is stored as
    ``stem_1.ext``, ``stem_2.ext``, ... so that neither shadows the other on
    extraction.

    Returns:
        Path of the archive that was written.
    """
    ensure_dirs()
    bundle_path = OUTPUT_ROOT / bundle_name
    seen_files: set[Path] = set()
    used_names: set[str] = set()
    with zipfile.ZipFile(bundle_path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
        for path in paths:
            p = Path(path)
            if not p.exists():
                continue
            resolved = p.resolve()
            if resolved in seen_files:
                continue
            seen_files.add(resolved)

            arcname = p.name
            suffix_no = 1
            while arcname in used_names:
                arcname = f"{p.stem}_{suffix_no}{p.suffix}"
                suffix_no += 1
            used_names.add(arcname)
            zf.write(p, arcname=arcname)
    return bundle_path


# ---------------------------------------------------------------------------
# Heuristic text detection, kept as a fallback and for previews
# ---------------------------------------------------------------------------


def _merge_overlapping(boxes: Sequence[Box], gap: int = 8) -> list[Box]:
    """Fuse boxes that overlap or lie within ``gap`` pixels of each other.

    Every incoming box is compared with all groups built so far, not only
    the most recent one, and the comparison is repeated after each fusion
    because the grown rectangle may reach groups it did not touch before.

    Args:
        boxes: Candidate rectangles, in any order.
        gap: Maximum distance, on each axis, at which two boxes still fuse.

    Returns:
        The fused boxes sorted by ``(y1, x1)``. Each carries the highest
        score among the boxes it absorbed.
    """

    def _touches(a: Box, b: Box) -> bool:
        return (
            a.x1 <= b.x2 + gap
            and a.x2 >= b.x1 - gap
            and a.y1 <= b.y2 + gap
            and a.y2 >= b.y1 - gap
        )

    def _union(a: Box, b: Box) -> Box:
        return Box(
            min(a.x1, b.x1),
            min(a.y1, b.y1),
            max(a.x2, b.x2),
            max(a.y2, b.y2),
            max(a.score, b.score),
        )

    merged: list[Box] = []
    for box in sorted(boxes, key=lambda b: (b.y1, b.x1)):
        grown = box
        absorbed = True
        while absorbed:
            absorbed = False
            for idx, other in enumerate(merged):
                if _touches(grown, other):
                    grown = _union(other, grown)
                    del merged[idx]
                    absorbed = True
                    break  # rescan: the union may now reach other groups
        merged.append(grown)

    merged.sort(key=lambda b: (b.y1, b.x1))
    return merged



def detect_text_regions(image: Image.Image, min_area_ratio: float = 0.0015) -> list[Box]:
    """Heuristically locate wide, text-like regions in ``image``.

    Three candidate sources are combined: Otsu-thresholded horizontal
    gradients of the black-hat and of the top-hat transforms, and Canny
    edges of the grayscale image. Candidates are merged, padded by 12 px
    and capped at ten boxes.

    Args:
        image: Picture to analyse; non-RGB modes are converted first.
        min_area_ratio: Smallest accepted box area as a fraction of the
            image area (the edge-based pass uses 35 % of this value).

    Returns:
        At most ten boxes. When more are found, the ten with the highest
        score are kept, in their original merge order.
    """
    # cvtColor(RGB2GRAY) needs exactly three channels, so normalise the mode.
    if getattr(image, "mode", "RGB") != "RGB":
        image = image.convert("RGB")
    arr = np.array(image)
    gray = cv2.cvtColor(arr, cv2.COLOR_RGB2GRAY)
    h, w = gray.shape

    def _scan(binary: np.ndarray, kernel_size: tuple[int, int], min_area: int, min_width: int, area_gain: float) -> list[Box]:
        # Close horizontal gaps so the glyphs of a line fuse into one blob,
        # then keep bounding boxes that are big, wide and at least 8 px tall.
        kernel = cv2.getStructuringElement(cv2.MORPH_RECT, kernel_size)
        closed = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, kernel, iterations=2)
        closed = cv2.dilate(closed, None, iterations=1)
        contours, _ = cv2.findContours(closed, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        found: list[Box] = []
        for cnt in contours:
            x, y, bw, bh = cv2.boundingRect(cnt)
            area = bw * bh
            if area < min_area:
                continue
            if bw / max(1, bh) < 1.0:
                continue
            if bw < min_width or bh < 8:
                continue
            local = gray[max(0, y):min(h, y + bh), max(0, x):min(w, x + bw)]
            contrast = float(local.std()) / 255.0
            # Score mixes relative size and local contrast, capped at 1.0.
            score = min(1.0, (area / float(w * h)) * area_gain + contrast)
            found.append(Box(x, y, x + bw, y + bh, score))
        return found

    def _find_from_map(src: np.ndarray) -> list[Box]:
        blur = cv2.GaussianBlur(src, (3, 3), 0)
        _, thresh = cv2.threshold(blur, 0, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)
        return _scan(thresh, (25, 7), int(h * w * min_area_ratio), 18, 12.0)

    def _find_from_edges(src: np.ndarray) -> list[Box]:
        edges = cv2.Canny(src, 50, 150)
        return _scan(edges, (21, 5), int(h * w * (min_area_ratio * 0.35)), 14, 10.0)

    def _normalised_x_gradient(src: np.ndarray) -> np.ndarray:
        grad = np.absolute(cv2.Scharr(src, cv2.CV_32F, 1, 0))
        # The epsilon avoids dividing by zero on a perfectly flat map.
        return (255 * (grad / (grad.max() + 1e-6))).astype("uint8")

    hat_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (17, 5))
    grad_black = _normalised_x_gradient(cv2.morphologyEx(gray, cv2.MORPH_BLACKHAT, hat_kernel))
    grad_white = _normalised_x_gradient(cv2.morphologyEx(gray, cv2.MORPH_TOPHAT, hat_kernel))

    candidates = _find_from_map(grad_black) + _find_from_map(grad_white) + _find_from_edges(gray)
    candidates.sort(key=lambda b: b.score, reverse=True)
    boxes = [b.padded(12, w, h) for b in _merge_overlapping(candidates)]

    # Merging reorders boxes by position, so a plain [:10] would keep the
    # top-most boxes instead of the best ones. Select by score explicitly,
    # then restore the merge order.
    if len(boxes) > 10:
        best = sorted(range(len(boxes)), key=lambda i: boxes[i].score, reverse=True)[:10]
        boxes = [boxes[i] for i in sorted(best)]
    return boxes



def boxes_to_mask(image: Image.Image, boxes: Sequence[Box], pad: int = 10) -> Image.Image:
    """Build a single-channel mask: 255 inside the padded boxes, 0 elsewhere.

    The mask has the same size as ``image``; every box is grown by ``pad``
    pixels (clamped to the image size) before being filled.
    """
    width, height = image.size
    mask = Image.new("L", (width, height), 0)
    painter = ImageDraw.Draw(mask)
    for grown in (box.padded(pad, width, height) for box in boxes):
        painter.rectangle([grown.x1, grown.y1, grown.x2, grown.y2], fill=255)
    return mask



def draw_boxes_preview(image: Image.Image, boxes: Sequence[Box], color: str = "#ff3366") -> Image.Image:
    """Return an RGB copy of ``image`` with each box outlined and numbered from 1."""
    preview = image.copy().convert("RGB")
    pen = ImageDraw.Draw(preview)
    number = 0
    for box in boxes:
        number += 1
        corners = [box.x1, box.y1, box.x2, box.y2]
        pen.rectangle(corners, outline=color, width=4)
        pen.text((box.x1 + 4, box.y1 + 4), str(number), fill=color)
    return preview


# ---------------------------------------------------------------------------
# Nano Banana 2 backend
# ---------------------------------------------------------------------------


def nano_banana_ready() -> tuple[bool, str]:
    """Report whether the remote backend can be used.

    Returns:
        ``(True, "ok (<model>)")`` when the client package is importable and
        an API key is configured, otherwise ``(False, <reason>)``.
    """
    checks = (
        (genai is None, "google-genai n'est pas installé"),
        (not DEFAULT_API_KEY, "clé API manquante"),
    )
    for failed, reason in checks:
        if failed:
            return False, reason
    return True, f"ok ({DEFAULT_MODEL})"



def _client():
    """Build a google-genai client from the configured API key.

    Raises:
        RuntimeError: If the package is not installed or no key is set.
    """
    if genai is not None and DEFAULT_API_KEY:
        return genai.Client(api_key=DEFAULT_API_KEY)
    if genai is None:
        raise RuntimeError("Le package google-genai n'est pas installé")
    raise RuntimeError("NANO_BANANA_API_KEY ou GOOGLE_API_KEY manque")



def _image_part(image: Image.Image) -> dict[str, str]:
    """Encode ``image`` as a base64 PNG input part for the model request.

    The PNG is produced in an in-memory buffer that is closed on exit, so no
    temporary file object is left open.

    Returns:
        A dict with ``type`` ``"image"``, the base64 text under ``data`` and
        ``mime_type`` ``"image/png"``.
    """
    with io.BytesIO() as buf:
        image.save(buf, format="PNG")
        data = base64.b64encode(buf.getvalue()).decode("utf-8")
    return {"type": "image", "data": data, "mime_type": "image/png"}



def _text_part(text: str) -> dict[str, str]:
    return {"type": "text", "text": text}



def _extract_json(text: str) -> Any:
    text = text.strip()
    try:
        return json.loads(text)
    except Exception:
        match = re.search(r"\{.*\}", text, re.S)
        if match:
            return json.loads(match.group(0))
        raise



def build_analysis_prompt(user_request: str = "") -> str:
    """Compose the prompt asking the model to describe the visible text as JSON.

    Args:
        user_request: Optional user context. After stripping, a non-empty
            value is appended on its own line.

    Returns:
        The full prompt text.
    """
    instructions = (
        "Analyse l'image pour comprendre tout le texte visible et la structure éditoriale. ",
        "Réponds uniquement en JSON strict, sans markdown, avec les clés suivantes: ",
        "language, visible_text, blocks, reading_order, clean_plate_notes, recommended_clean_prompt, confidence. ",
        "blocks doit être un tableau d'objets avec: text, role, position_hint, confidence. ",
        "position_hint doit être une description courte, par exemple top-left, center, banner, footer. ",
        "Si un texte est incertain, mets-le dans visible_text avec un préfixe [uncertain]. ",
        "N'invente pas de texte. ",
        "clean_plate_notes doit expliquer quoi préserver pour faire un clean plate propre.",
    )
    prompt = "".join(instructions)
    request = user_request.strip()
    if request:
        prompt += f"\nContext utilisateur: {request}"
    return prompt



def build_clean_prompt(analysis: dict[str, Any], extra_instruction: str = "") -> str:
    """Compose the prompt asking the model for a text-free clean plate.

    Args:
        analysis: Analysis result; it is embedded as indented JSON at the end
            of the prompt.
        extra_instruction: Optional extra directive. After stripping, a
            non-empty value is appended on its own line.

    Returns:
        The full prompt text.
    """
    summary = json.dumps(analysis, ensure_ascii=False, indent=2)
    directive = extra_instruction.strip()
    sections = [
        "Tu reçois une image avec du texte à retirer. Crée un clean plate propre. ",
        "Supprime tous les textes, lettres, logos typographiques, chiffres et artefacts de composition liés au texte. ",
        "Préserve la composition, l'éclairage, les ombres, la perspective, les textures, les contours des objets et le fond. ",
        "Ne rajoute aucun texte. Si une zone doit être reconstruite, invente un arrière-plan plausible et neutre. ",
        "Utilise ce résumé d'analyse comme vérité de référence pour comprendre ce qu'il faut enlever: ",
        summary,
    ]
    if directive:
        sections.append(f"\nConsigne additionnelle: {directive}")
    return "".join(sections)



def _interaction_text(interaction: Any) -> str:
    """Return the text carried by an interaction result, or "" if none is found."""
    direct = getattr(interaction, "output_text", "") or ""
    if direct:
        return direct
    # NOTE(review): the Interactions API appears to expose results as a list
    # under `outputs` whose text items carry `.text` - confirm against the
    # installed google-genai version. `output_text` stays the first choice.
    pieces = [getattr(item, "text", None) for item in getattr(interaction, "outputs", None) or []]
    return "".join(piece for piece in pieces if isinstance(piece, str))


def analyze_with_nano_banana(image: Image.Image, user_request: str = "") -> dict[str, Any]:
    """Describe the visible text of ``image``, remotely when possible.

    The local heuristic detector always runs. Its boxes are returned under
    ``detected_boxes`` and also form the whole answer whenever the remote
    model cannot be used: backend not ready, request failure, or a reply
    that is not a JSON object.

    Args:
        image: Picture to analyse.
        user_request: Optional user context forwarded in the prompt.

    Returns:
        A dict with at least ``backend``, ``status``, ``language``,
        ``visible_text``, ``blocks``, ``reading_order``,
        ``clean_plate_notes``, ``recommended_clean_prompt``,
        ``detected_boxes`` and ``raw_model_output``.
    """
    ready, reason = nano_banana_ready()
    fallback_boxes = detect_text_regions(image)
    fallback = {
        "backend": "local-heuristic",
        "status": reason,
        "language": "unknown",
        "visible_text": [],
        "blocks": [
            {
                "text": "",
                "role": "text_block",
                "position_hint": f"{b.x1},{b.y1},{b.x2},{b.y2}",
                "confidence": round(b.score, 3),
            }
            for b in fallback_boxes
        ],
        "reading_order": [],
        "clean_plate_notes": ["Fallback local heuristics used because the Nano Banana backend is unavailable."],
        "recommended_clean_prompt": "Remove all visible text while preserving the background and composition.",
        "detected_boxes": [b.to_tuple() for b in fallback_boxes],
        "raw_model_output": None,
    }

    if not ready:
        return fallback

    prompt = build_analysis_prompt(user_request)
    try:
        interaction = _client().interactions.create(
            model=DEFAULT_MODEL,
            input=[
                _text_part(prompt),
                _image_part(image),
            ],
        )
    except Exception as exc:
        # A failing remote call must degrade to the heuristic answer rather
        # than abort the caller; the cause is kept in the status.
        fallback["status"] = f"Erreur backend Nano Banana: {exc}"
        return fallback

    raw_text = _interaction_text(interaction)
    try:
        data = _extract_json(raw_text)
        if not isinstance(data, dict):
            raise ValueError("La réponse n'est pas un objet JSON")
    except Exception:
        fallback["backend"] = "nano-banana-2"
        fallback["status"] = "Réponse modèle non JSON, fallback local des zones détectées"
        fallback["raw_model_output"] = raw_text or None
        return fallback

    # Fill in any key the model left out so callers can rely on the schema.
    data.setdefault("language", "unknown")
    data.setdefault("visible_text", [])
    data.setdefault("blocks", [])
    data.setdefault("reading_order", [])
    data.setdefault("clean_plate_notes", [])
    data.setdefault("recommended_clean_prompt", "Remove all visible text while preserving the background and composition.")
    data["backend"] = "nano-banana-2"
    data["status"] = "ok"
    data["detected_boxes"] = [b.to_tuple() for b in fallback_boxes]
    data["raw_model_output"] = raw_text or None
    return data



def _interaction_image_data(interaction: Any) -> Any:
    """Return the encoded image payload of an interaction result, or None."""
    payload = getattr(getattr(interaction, "output_image", None), "data", None)
    if payload:
        return payload
    # NOTE(review): the Interactions API appears to expose results as a list
    # under `outputs`, image items having type "image" and base64 `.data` -
    # confirm against the installed google-genai version.
    for item in getattr(interaction, "outputs", None) or []:
        if getattr(item, "type", None) == "image" and getattr(item, "data", None):
            return item.data
    return None


def generate_clean_plate_with_nano_banana(
    image: Image.Image,
    analysis: dict[str, Any],
    extra_instruction: str = "",
    seed: int | None = None,
) -> tuple[Image.Image, str]:
    """Ask the remote model for a text-free version of ``image``.

    Args:
        image: Source picture.
        analysis: Analysis dict embedded in the prompt.
        extra_instruction: Optional extra directive for the prompt.
        seed: Optional value; it is only mentioned in the prompt text.

    Returns:
        ``(clean_image, status)``. ``status`` is ``"ok"`` for a remote
        result, or a message naming the reason when the backend is not
        ready and the local fallback was used instead.

    Raises:
        RuntimeError: If the model reply contains no image.
    """
    ready, reason = nano_banana_ready()
    if not ready:
        return _fallback_clean_plate(image, analysis), f"Fallback local activé: {reason}"

    client = _client()
    prompt = build_clean_prompt(analysis, extra_instruction=extra_instruction)
    if seed is not None:
        prompt += f"\nSeed de référence: {seed}"

    interaction = client.interactions.create(
        model=DEFAULT_MODEL,
        input=[
            _text_part(prompt),
            _image_part(image),
        ],
    )
    payload = _interaction_image_data(interaction)
    if payload is None:
        raise RuntimeError("Nano Banana n'a pas renvoyé d'image de sortie")
    return _image_from_bytes(base64.b64decode(payload)), "ok"



def _image_from_bytes(data: bytes) -> Image.Image:
    """Decode encoded image bytes into an RGB image.

    Decoding happens from memory, so no temporary file is written and no
    file handle is left open.

    Args:
        data: Encoded image content (for example a PNG file's bytes).

    Returns:
        The decoded picture converted to RGB.
    """
    with Image.open(io.BytesIO(data)) as decoded:
        return decoded.convert("RGB")



def _fallback_clean_plate(image: Image.Image, analysis: dict[str, Any]) -> Image.Image:
    """Remove text locally by inpainting the detected boxes with ComfyUI.

    Boxes come from ``analysis["detected_boxes"]`` when present, otherwise
    from the heuristic detector. The picture is returned unchanged (as an
    RGB copy) when there is nothing to mask or when the ComfyUI runner
    script or workflow file is missing.

    Args:
        image: Source picture.
        analysis: Analysis dict; ``recommended_clean_prompt`` is used as the
            inpaint prompt when present.

    Returns:
        The inpainted picture, or an RGB copy of ``image``.
    """
    boxes = [Box(*coords) for coords in analysis.get("detected_boxes", []) if len(coords) == 4]
    if not boxes:
        boxes = detect_text_regions(image)
    if not boxes:
        return image.copy().convert("RGB")

    # Without the runner or the workflow there is nothing to execute.
    if not COMFY_SCRIPT.exists() or not COMFY_WORKFLOW.exists():
        return image.copy().convert("RGB")

    ensure_comfy_running()
    mask = boxes_to_mask(image, boxes)
    with tempfile.TemporaryDirectory() as tmp:
        tmpdir = Path(tmp)
        input_path = tmpdir / "input.png"
        mask_path = tmpdir / "mask.png"
        image.save(input_path)
        mask.save(mask_path)
        prompt = str(analysis.get("recommended_clean_prompt", "Remove all visible text while preserving the background and composition."))
        clean_path = run_comfy_inpaint(str(input_path), str(mask_path), prompt)
        # Close the result file once its pixels have been copied.
        with Image.open(clean_path) as result:
            return result.convert("RGB")



def run_nano_banana_clean(image: Image.Image, user_request: str = "", extra_instruction: str = "", seed: int | None = None) -> tuple[dict[str, Any], Image.Image, str]:
    """Analyse ``image`` and then produce its clean plate.

    If remote generation raises, the local fallback is tried. If that fails
    as well, an unmodified RGB copy of ``image`` is returned and both
    failures are reported in the status, so the caller always gets a result.

    Args:
        image: Source picture.
        user_request: Optional context for the analysis prompt.
        extra_instruction: Optional directive for the clean-plate prompt.
        seed: Optional value forwarded to the generation step.

    Returns:
        ``(analysis, clean_image, clean_status)``; the status is also stored
        in ``analysis["clean_status"]``.
    """
    analysis = analyze_with_nano_banana(image, user_request=user_request)
    try:
        clean = generate_clean_plate_with_nano_banana(image, analysis, extra_instruction=extra_instruction, seed=seed)
        # Tolerate a bare image as well as the documented (image, status) pair.
        clean_img, clean_status = clean if isinstance(clean, tuple) else (clean, "ok")
    except Exception as exc:
        try:
            clean_img = _fallback_clean_plate(image, analysis)
            clean_status = f"Fallback clean plate: {exc}"
        except Exception as fallback_exc:
            # The fallback can fail too (e.g. ComfyUI cannot be started);
            # degrade to the untouched picture instead of crashing.
            clean_img = image.copy().convert("RGB")
            clean_status = f"Fallback clean plate: {exc}; fallback local en échec: {fallback_exc}"
    analysis["clean_status"] = clean_status
    return analysis, clean_img, clean_status


# ---------------------------------------------------------------------------
# ComfyUI fallback
# ---------------------------------------------------------------------------


def _comfy_is_up() -> bool:
    """Return True when the local ComfyUI server answers its stats endpoint."""
    try:
        return requests.get("http://127.0.0.1:8188/system_stats", timeout=3).ok
    except requests.RequestException:
        return False


def ensure_comfy_running() -> None:
    """Make sure ComfyUI is reachable, launching it in the background if needed.

    After launching, the server is polled once per second for up to 60
    attempts. The pause matters: a refused connection fails instantly, so
    without it every attempt would be used up before the server had any
    chance to start.

    Raises:
        RuntimeError: If ``comfy`` is not on PATH or the server never
            becomes ready.
        subprocess.CalledProcessError: If the launch command fails.
    """
    if _comfy_is_up():
        return
    comfy = shutil.which("comfy")
    if not comfy:
        raise RuntimeError("comfy-cli is not available on PATH")
    subprocess.run([comfy, "--workspace", str(COMFY_WORKSPACE), "launch", "--background"], check=True)
    for _ in range(60):
        if _comfy_is_up():
            return
        time.sleep(1)
    raise RuntimeError("ComfyUI did not become ready on 127.0.0.1:8188")



def run_comfy_inpaint(image_path: str, mask_path: str, prompt: str, negative_prompt: str = "", seed: int = 42) -> Path:
    """Run the inpaint workflow through the ComfyUI runner script.

    Args:
        image_path: Picture to inpaint.
        mask_path: Mask image passed to the workflow as ``mask_image``.
        prompt: Positive prompt.
        negative_prompt: Negative prompt.
        seed: Sampler seed.

    Returns:
        Path of the most recent PNG in ``RUN_ROOT`` that this run created or
        modified.

    Raises:
        FileNotFoundError: If the runner script or the workflow is missing.
        subprocess.CalledProcessError: If the runner exits with an error.
        RuntimeError: If the run leaves no new PNG in ``RUN_ROOT``.
    """
    ensure_dirs()
    ensure_comfy_running()
    if not COMFY_SCRIPT.exists():
        raise FileNotFoundError(f"Missing Comfy workflow runner: {COMFY_SCRIPT}")
    if not COMFY_WORKFLOW.exists():
        raise FileNotFoundError(f"Missing workflow: {COMFY_WORKFLOW}")
    workflow_args = json.dumps(
        {
            "prompt": prompt,
            "negative_prompt": negative_prompt,
            "steps": 28,
            "seed": seed,
        }
    )
    cmd = [
        shutil.which("python3") or "python3",
        str(COMFY_SCRIPT),
        "--workflow",
        str(COMFY_WORKFLOW),
        "--input-image",
        f"image={image_path}",
        "--input-image",
        f"mask_image={mask_path}",
        "--args",
        workflow_args,
        "--output-dir",
        str(RUN_ROOT),
        "--timeout",
        "900",
    ]
    # Snapshot the directory first: RUN_ROOT is shared between runs, and a
    # leftover PNG from an earlier run must not be mistaken for this result.
    before = {p: p.stat().st_mtime_ns for p in RUN_ROOT.glob("*.png")}
    subprocess.run(cmd, check=True)
    stamped = [(p.stat().st_mtime_ns, p) for p in RUN_ROOT.glob("*.png")]
    fresh = [(stamp, p) for stamp, p in stamped if before.get(p) != stamp]
    if not fresh:
        raise RuntimeError("ComfyUI run completed but no PNG output was found")
    return max(fresh, key=lambda item: item[0])[1]


# ---------------------------------------------------------------------------
# Small helpers used by the UI
# ---------------------------------------------------------------------------


def image_from_upload(value: Any) -> Image.Image | None:
    """Turn an upload value into an RGB image, or None when that fails.

    An image object is converted directly; any other non-None value is
    handed to ``load_image``. ``None`` and unreadable values give ``None``.
    """
    if isinstance(value, Image.Image):
        return value.convert("RGB")
    if value is not None:
        try:
            return load_image(value)
        except Exception:
            # Best effort: an unreadable upload is reported as "no image".
            pass
    return None



def write_json(path: str | Path, data: Any) -> Path:
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(json.dumps(data, ensure_ascii=False, indent=2), encoding="utf-8")
    return path



def to_overlay(image: Image.Image, analysis: dict[str, Any]) -> Image.Image:
    """Draw the boxes listed under ``analysis["detected_boxes"]`` over ``image``.

    Entries that do not have exactly four values are ignored.
    """
    boxes = []
    for coords in analysis.get("detected_boxes", []):
        if len(coords) == 4:
            boxes.append(Box(*coords))
    return draw_boxes_preview(image, boxes)
