"""
ComfyUI custom node: Image Mask Comparer
=========================================

Compares the masked regions of two images and returns a boolean
indicating whether they match above a configurable similarity
threshold.  Useful for verifying that a specific region (e.g. a logo
or label) has been preserved after an AI generation or inpainting
step.

Retry behaviour
---------------
When ``max_retries`` > 0 and the comparison fails the node will:

1. **Randomise** the seed of every KSampler / noise node it finds
   in the prompt so that ComfyUI's cache is invalidated and a
   genuinely new image is generated.
2. **Queue** the modified prompt via direct access to the internal
   ``PromptServer.prompt_queue`` (with HTTP POST as fallback).
3. **Interrupt** the current execution so downstream nodes never
   see the bad result.

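The loop repeats until the comparison passes or ``max_retries`` total
attempts (counting the first run) have been made, then gives up and
outputs ``False``.  Because the first run counts as an attempt,
``max_retries`` values of 0 and 1 both disable retrying.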
"""

from __future__ import annotations

import copy
import json
import random
import threading
import time
import urllib.request
import uuid as _uuid
from typing import Tuple, Dict, Any

import cv2
import numpy as np
import torch

# ── Sampler class_type -> seed field name mapping ────────────────────
# Covers the built-in ComfyUI samplers and the most popular custom
# packs.  Any node whose class_type is not listed here will still be
# caught by the generic "seed" / "noise_seed" scan below.
_KNOWN_SEED_FIELDS: Dict[str, str] = {
    "KSampler":                 "seed",
    "KSamplerAdvanced":         "noise_seed",
    "SamplerCustom":            "noise_seed",
    "SamplerCustomAdvanced":    "noise_seed",
    "RandomNoise":              "noise_seed",
    "Noise_RandomNoise":        "noise_seed",
    # Impact Pack
    "KSamplerProvider":         "seed",
    "BasicScheduler":           "seed",
}


class ImageMaskComparer:
    """Compare masked regions of two images; optionally retry on mismatch.

    The node outputs ``(is_match, similarity, image_out)``, where
    ``image_out`` is always ``image_b`` passed through untouched.  When
    the comparison fails and attempts remain, a seed-randomised copy of
    the running prompt is re-queued and the current execution is
    interrupted.
    """

    # Per-node retry counters keyed by node unique_id.  This is a class
    # attribute on purpose: every instance shares it, so the count carries
    # over between the re-queued executions of the same node.
    # NOTE(review): a counter is only cleared on match, empty mask or
    # give-up; a run cancelled mid-retry leaves a stale count behind.
    _retry_counts: Dict[str, int] = {}

    # ── Node definition ──────────────────────────────────────────────
    @classmethod
    def INPUT_TYPES(cls) -> Dict[str, Dict[str, Any]]:
        """Declare the node inputs.

        ``max_retries`` is the total number of attempts, counting the
        first run.  The hidden ``prompt`` and ``unique_id`` inputs feed
        the retry machinery.
        """
        return {
            "required": {
                "image_a": ("IMAGE",),
                "image_b": ("IMAGE",),
                "mask": ("MASK",),
                "threshold": (
                    "FLOAT",
                    {"default": 0.90, "min": 0.0, "max": 1.0, "step": 0.01},
                ),
                "max_retries": (
                    "INT",
                    {"default": 4, "min": 0, "max": 50, "step": 1},
                ),
            },
            "hidden": {
                "prompt": "PROMPT",
                "unique_id": "UNIQUE_ID",
            },
        }

    RETURN_TYPES: Tuple[str, ...] = ("BOOLEAN", "FLOAT", "IMAGE")
    RETURN_NAMES: Tuple[str, ...] = ("is_match", "similarity", "image_out")
    FUNCTION: str = "compare"
    CATEGORY: str = "image/compare"
    OUTPUT_NODE = True

    # ── Image / mask helpers ─────────────────────────────────────────
    @staticmethod
    def _image_to_numpy(img: torch.Tensor) -> np.ndarray:
        """Convert an image tensor to a float32 array clipped to [0, 1].

        Only the first element of a 4-D (batched) tensor is used.  A 2-D
        array, or a 3-D array with a single channel, is replicated to
        three channels.  Channels beyond the third (e.g. alpha) are
        dropped.  This lets grayscale, RGB and RGBA inputs be compared
        with each other.
        """
        t = img[0] if img.ndim == 4 else img
        # .float() first: NumPy cannot represent some torch dtypes
        # (e.g. bfloat16), so converting them directly would raise.
        arr = t.detach().cpu().float().numpy().astype(np.float32)
        arr = np.clip(arr, 0.0, 1.0)
        if arr.ndim == 2:
            arr = arr[:, :, np.newaxis]
        if arr.ndim == 3 and arr.shape[2] == 1:
            arr = np.repeat(arr, 3, axis=2)
        elif arr.ndim == 3 and arr.shape[2] > 3:
            arr = arr[:, :, :3]
        return arr

    @staticmethod
    def _mask_to_numpy(mask: torch.Tensor) -> np.ndarray:
        """Convert a mask tensor to a 2-D float32 array clipped to [0, 1].

        For 3-D input the first slice is taken.  For 4-D input the
        ``[0, :, :, 0]`` slice is taken.
        """
        t = mask
        if t.ndim == 4:
            t = t[0, :, :, 0]
        elif t.ndim == 3:
            t = t[0]
        arr = t.detach().cpu().float().numpy().astype(np.float32)
        return np.clip(arr, 0.0, 1.0)

    @staticmethod
    def _masked_similarity(
        a: np.ndarray, b: np.ndarray, mask_bool: np.ndarray
    ) -> float:
        """Return a similarity score in [0, 1] over the masked pixels.

        The score is ``0.6 * (1 - MAE) + 0.4 * NCC``, with NCC clipped to
        [0, 1] and the result rounded to 4 decimal places.  If either
        region has (near-)zero variance, NCC is undefined.  It is then
        taken as 1 when the regions are nearly equal (MAE < 0.01) and
        as 0 otherwise.
        """
        pixels_a = a[mask_bool].astype(np.float64).ravel()
        pixels_b = b[mask_bool].astype(np.float64).ravel()

        mae = np.mean(np.abs(pixels_a - pixels_b))

        a_c = pixels_a - np.mean(pixels_a)
        b_c = pixels_b - np.mean(pixels_b)
        denom = np.sqrt(np.sum(a_c ** 2) * np.sum(b_c ** 2))
        if denom < 1e-8:
            ncc = 1.0 if mae < 0.01 else 0.0
        else:
            ncc = float(np.clip(np.sum(a_c * b_c) / denom, 0.0, 1.0))

        blended = 0.6 * (1.0 - mae) + 0.4 * ncc
        return round(float(np.clip(blended, 0.0, 1.0)), 4)

    # ── Seed randomisation ───────────────────────────────────────────
    @staticmethod
    def _randomize_seeds(prompt: dict) -> Tuple[dict, int]:
        """Return a deep copy of *prompt* with every sampler seed randomised.

        A node's seed is replaced only when it is a literal number, not a
        link to another node.  Known samplers are looked up in
        ``_KNOWN_SEED_FIELDS``.  Any other node has its numeric "seed" /
        "noise_seed" inputs replaced.  The input *prompt* is not modified.

        Returns ``(modified_prompt, num_seeds_changed)``.
        """
        # NOTE(review): new seeds range up to 2**53; a custom node whose
        # seed input has a smaller maximum might reject them — confirm.
        modified = copy.deepcopy(prompt)
        changed = 0

        for node_data in modified.values():
            class_type = node_data.get("class_type", "")
            inputs = node_data.get("inputs", {})

            # 1. Known-sampler table: change exactly that field.
            seed_field = _KNOWN_SEED_FIELDS.get(class_type)
            if seed_field and seed_field in inputs:
                if isinstance(inputs[seed_field], (int, float)):
                    inputs[seed_field] = random.randint(1, 2**53)
                    changed += 1
                    continue

            # 2. Generic fallback – any numeric "seed" or "noise_seed".
            for field in ("seed", "noise_seed"):
                if field in inputs and isinstance(inputs[field], (int, float)):
                    inputs[field] = random.randint(1, 2**53)
                    changed += 1

        return modified, changed

    # ── Queue-item capture (synchronous) ─────────────────────────────
    @staticmethod
    def _capture_queue_item():
        """Deep-copy the currently-running queue item, or return ``None``.

        Must be called *before* ``interrupt_processing`` while the item
        is still present in ``currently_running``.  Returns ``None`` when
        nothing is running or the server internals are unavailable.
        Failures are logged rather than raised, because the caller falls
        back to the HTTP path.
        """
        try:
            from server import PromptServer  # type: ignore

            running = PromptServer.instance.prompt_queue.currently_running
            for item in running.values():
                return copy.deepcopy(item)
        except Exception as exc:
            print(
                f"[ImageMaskComparer] Could not capture running queue "
                f"item: {exc}"
            )
        return None

    # ── Re-queue helpers ─────────────────────────────────────────────
    @staticmethod
    def _server_base_url() -> str:
        """Return ``http://host:port`` for the local ComfyUI server.

        Falls back to ``127.0.0.1:8188`` when ``PromptServer`` details
        are unavailable.  Wildcard listen addresses ("0.0.0.0", "::") are
        mapped to loopback, and IPv6 literals are bracketed so that the
        URL is well-formed.
        """
        host, port = "127.0.0.1", 8188
        try:
            from server import PromptServer  # type: ignore

            ps = PromptServer.instance
            addr = getattr(ps, "address", None)
            if addr:
                host = str(addr)
            port = getattr(ps, "port", None) or 8188
        except Exception:
            pass  # best effort: keep the defaults

        if host == "0.0.0.0":
            host = "127.0.0.1"
        elif host == "::":
            host = "::1"
        if ":" in host and not host.startswith("["):
            host = f"[{host}]"  # IPv6 literal must be bracketed in a URL
        return f"http://{host}:{port}"

    @staticmethod
    def _requeue_direct(queue_item) -> bool:
        """Push a seed-randomised copy of *queue_item* onto the queue.

        Index 1 of the item is treated as the prompt id and index 2 as
        the prompt.  Returns ``True`` on success.
        """
        try:
            from server import PromptServer  # type: ignore

            ps = PromptServer.instance
            new_item = list(queue_item)
            new_item[1] = str(_uuid.uuid4())  # fresh id
            new_item[2], n = ImageMaskComparer._randomize_seeds(new_item[2])
            ps.prompt_queue.put(tuple(new_item))
        except Exception as exc:
            print(f"[ImageMaskComparer] Direct re-queue failed: {exc}")
            return False
        print(
            f"[ImageMaskComparer] Re-queued (direct) — "
            f"randomised {n} seed(s)."
        )
        return True

    @staticmethod
    def _requeue_http(prompt) -> bool:
        """POST a seed-randomised copy of *prompt* to ``/prompt``.

        Returns ``True`` on success.
        """
        try:
            modified, n = ImageMaskComparer._randomize_seeds(prompt)
            req = urllib.request.Request(
                f"{ImageMaskComparer._server_base_url()}/prompt",
                data=json.dumps({"prompt": modified}).encode("utf-8"),
                headers={"Content-Type": "application/json"},
            )
            # Use the response as a context manager so the socket is
            # released instead of leaking until garbage collection.
            with urllib.request.urlopen(req, timeout=5) as resp:
                resp.read()
        except Exception as exc:
            print(f"[ImageMaskComparer] HTTP re-queue failed: {exc}")
            return False
        print(
            f"[ImageMaskComparer] Re-queued (HTTP) — "
            f"randomised {n} seed(s)."
        )
        return True

    # ── Re-queue (background thread) ─────────────────────────────────
    @staticmethod
    def _requeue(queue_item, fallback_prompt):
        """Put a seed-randomised copy of the prompt back in the queue.

        Runs on a daemon thread.  It first tries direct queue access
        with *queue_item*, then an HTTP POST of *fallback_prompt*, and
        logs a hint if both are unavailable or fail.
        """

        def _worker():
            # Short delay, presumably so the interrupt issued by compare()
            # takes effect before the new item is queued.
            time.sleep(0.5)
            if queue_item is not None and ImageMaskComparer._requeue_direct(
                queue_item
            ):
                return
            if fallback_prompt is not None and ImageMaskComparer._requeue_http(
                fallback_prompt
            ):
                return
            print(
                "[ImageMaskComparer] All re-queue strategies failed.  "
                "Enable 'Auto Queue' in ComfyUI (Extra Options) "
                "for manual retries."
            )

        threading.Thread(target=_worker, daemon=True).start()

    # ── Main comparison ──────────────────────────────────────────────
    def compare(
        self,
        image_a: torch.Tensor,
        image_b: torch.Tensor,
        mask: torch.Tensor,
        threshold: float = 0.90,
        max_retries: int = 4,
        prompt: dict | None = None,
        unique_id: str | None = None,
    ) -> Tuple[bool, float, torch.Tensor]:
        """Compare the masked regions of *image_a* and *image_b*.

        Args:
            image_a: Reference image; its size defines the comparison grid.
            image_b: Candidate image; resized to match *image_a* if needed
                and always returned unchanged as the third output.
            mask: Region to compare; pixels with value > 0.5 are included.
            threshold: Minimum similarity counted as a match.
            max_retries: Total number of attempts, counting the first run;
                values of 0 and 1 both disable retrying.
            prompt: Hidden prompt, used for the HTTP re-queue fallback.
            unique_id: Hidden node id; keys the retry counter.

        Returns:
            ``(is_match, similarity, image_b)``.  An empty mask counts as
            a perfect match.
        """
        node_id = str(unique_id) if unique_id else "_default"

        # ── Convert inputs ───────────────────────────────────────────
        a = self._image_to_numpy(image_a)
        b = self._image_to_numpy(image_b)
        m = self._mask_to_numpy(mask)

        # Resample image_b and the mask onto image_a's pixel grid.
        h, w = a.shape[:2]
        if b.shape[:2] != (h, w):
            b = cv2.resize(b, (w, h), interpolation=cv2.INTER_LINEAR)
        if m.shape[:2] != (h, w):
            m = cv2.resize(m, (w, h), interpolation=cv2.INTER_NEAREST)

        mask_bool = m > 0.5
        if not mask_bool.any():
            # Nothing to compare: treat as a match and reset the counter.
            self._retry_counts.pop(node_id, None)
            return (True, 1.0, image_b)

        similarity = self._masked_similarity(a, b, mask_bool)

        # ── MATCH ────────────────────────────────────────────────────
        if similarity >= threshold:
            attempts_used = self._retry_counts.pop(node_id, 0)
            print(
                f"[ImageMaskComparer] MATCH  "
                f"(similarity {similarity:.4f} >= {threshold})  "
                f"after {attempts_used} retry(ies).  Continuing workflow."
            )
            return (True, similarity, image_b)

        # ── MISMATCH ────────────────────────────────────────────────
        attempt = self._retry_counts.get(node_id, 0) + 1
        self._retry_counts[node_id] = attempt

        if max_retries > 0 and attempt < max_retries:
            print(
                f"[ImageMaskComparer] MISMATCH  "
                f"(similarity {similarity:.4f} < {threshold})  "
                f"attempt {attempt}/{max_retries} — "
                f"randomising seeds & re-queuing …"
            )

            # Capture BEFORE interrupt (still in currently_running).
            queue_item = self._capture_queue_item()
            self._requeue(queue_item, prompt)

            from nodes import interrupt_processing  # type: ignore

            interrupt_processing()
            # Explicit return — guarantees we never fall through even
            # if interrupt_processing() only sets a flag in this
            # ComfyUI build.
            return (False, similarity, image_b)

        # ── Retries exhausted ────────────────────────────────────────
        self._retry_counts.pop(node_id, None)
        print(
            f"[ImageMaskComparer] MISMATCH  "
            f"(similarity {similarity:.4f} < {threshold})  "
            f"after {attempt} attempt(s) — giving up."
        )
        return (False, similarity, image_b)


# ── ComfyUI registration ────────────────────────────────────────────
# Module-level mappings read by ComfyUI when loading this custom node.
# Both dicts use the same type key: the first maps it to the node class,
# the second to the human-readable label.
NODE_CLASS_MAPPINGS = dict(ImageMaskComparer=ImageMaskComparer)

NODE_DISPLAY_NAME_MAPPINGS = dict(ImageMaskComparer="Image Mask Comparer")
