from __future__ import annotations

import math
import os
import shutil
import subprocess
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable, List, Sequence, Tuple

import cv2
import numpy as np
from PIL import Image

try:
    from nudenet import NudeDetector
except Exception:  # pragma: no cover
    NudeDetector = None

# File extensions (lower-case, with the leading dot) routed to the image and
# video pipelines; is_image / is_video / is_supported compare these against
# Path.suffix.lower().
IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".webp", ".bmp"}
VIDEO_EXTS = {".mp4", ".mov", ".mkv", ".webm", ".avi", ".m4v"}

# NudeNet labels vary by version/model. Keep only labels that correspond to
# nipples, genitalia, anus/buttocks, or breasts; exclude face/feet/armpits etc.
# Matching (see _sensitive_label) is a case-insensitive substring test, so e.g.
# "buttock" already covers "buttocks", and a label containing any keyword below
# matches regardless of the other words in it.
# NOTE(review): substring matching ignores any exposed/covered qualifier a label
# may carry — confirm that blurring covered regions as well is intended.
SENSITIVE_KEYWORDS = (
    "anus",
    "buttock",
    "buttocks",
    "breast",
    "genital",
    "penis",
    "vagina",
    "vulva",
    "pussy",
    "testicle",
    "nipple",
)
# Checked first: a label containing any of these is never treated as sensitive,
# even when it also contains a SENSITIVE_KEYWORDS entry.
IGNORED_KEYWORDS = (
    "face",
    "feet",
    "foot",
    "armpit",
    "belly",
    "navel",
)


@dataclass
class Box:
    """Pixel rectangle with a detection score and label.

    ``(x1, y1)`` is the top-left corner and ``(x2, y2)`` the bottom-right
    corner. ``score`` defaults to 1.0 and ``label`` to ``"sensitive"``.
    """

    x1: int
    y1: int
    x2: int
    y2: int
    score: float = 1.0
    label: str = "sensitive"

    def clamp(self, width: int, height: int) -> "Box":
        """Return a copy with coordinates cast to int and limited to the frame.

        The start corner is limited to ``[0, size - 1]`` and the end corner to
        ``[0, size]``; the score is cast to float and the label is kept.
        """
        def _fit(value, upper):
            return max(0, min(upper, int(value)))

        return Box(
            _fit(self.x1, width - 1),
            _fit(self.y1, height - 1),
            _fit(self.x2, width),
            _fit(self.y2, height),
            float(self.score),
            self.label,
        )

    def expanded(self, width: int, height: int, margin: float) -> "Box":
        """Grow the box by ``margin`` times its own width/height on every side.

        The result is clamped to a ``width`` x ``height`` frame.
        """
        pad_x = int(round((self.x2 - self.x1) * margin))
        pad_y = int(round((self.y2 - self.y1) * margin))
        grown = Box(
            self.x1 - pad_x,
            self.y1 - pad_y,
            self.x2 + pad_x,
            self.y2 + pad_y,
            self.score,
            self.label,
        )
        return grown.clamp(width, height)

    def valid(self) -> bool:
        """Return True when the box spans more than 2 pixels in both directions."""
        wide_enough = self.x2 > self.x1 + 2
        tall_enough = self.y2 > self.y1 + 2
        return wide_enough and tall_enough


def is_supported(path: Path) -> bool:
    """Return True when *path* has an image or video extension (case-insensitive)."""
    suffix = path.suffix.lower()
    return suffix in IMAGE_EXTS or suffix in VIDEO_EXTS


def is_image(path: Path) -> bool:
    """Return True when *path*'s extension (case-insensitive) is in IMAGE_EXTS."""
    suffix = path.suffix.lower()
    return suffix in IMAGE_EXTS


def is_video(path: Path) -> bool:
    """Return True when *path*'s extension (case-insensitive) is in VIDEO_EXTS."""
    suffix = path.suffix.lower()
    return suffix in VIDEO_EXTS


def _sensitive_label(label: str) -> bool:
    """Return True if a detector label names a region that should be blurred.

    The label is lower-cased (underscores turned into hyphens) and searched
    for keyword substrings. Any IGNORED_KEYWORDS hit wins over a
    SENSITIVE_KEYWORDS hit.
    """
    normalized = label.lower().replace("_", "-")
    for word in IGNORED_KEYWORDS:
        if word in normalized:
            return False
    return any(word in normalized for word in SENSITIVE_KEYWORDS)


class SensitiveDetector:
    """Wrap NudeNet's ``NudeDetector`` and return only sensitive-region boxes.

    Detections are kept when their score reaches ``score_threshold`` and their
    label passes ``_sensitive_label``. They are converted from NudeNet's
    ``[x, y, width, height]`` format to corner coordinates, clamped to the
    frame, and dropped if degenerate.
    """

    def __init__(self, score_threshold: float = 0.25):
        """Create the underlying NudeNet detector.

        Args:
            score_threshold: minimum detection score to keep.

        Raises:
            RuntimeError: if ``nudenet`` could not be imported.
        """
        if NudeDetector is None:
            raise RuntimeError("nudenet is not installed or could not be imported")
        self.detector = NudeDetector()
        self.score_threshold = score_threshold

    def detect_bgr(self, frame_bgr: np.ndarray) -> List[Box]:
        """Detect sensitive regions in an OpenCV (BGR) frame.

        Args:
            frame_bgr: image as returned by ``cv2.imread`` / ``VideoCapture.read``.

        Returns:
            Clamped, valid ``Box`` objects carrying the NudeNet score and label.
        """
        # NudeNet reads file paths with cv2.imread (BGR) and performs its own
        # BGR->RGB swap when building the network input, so numpy frames must
        # also be passed in BGR order. Pre-converting to RGB would feed the
        # model channel-swapped pixels and degrade detection.
        detections = self.detector.detect(frame_bgr)
        h, w = frame_bgr.shape[:2]
        boxes: List[Box] = []
        for det in detections:
            label = str(det.get("class", det.get("label", "")))
            score = float(det.get("score", det.get("confidence", 1.0)))
            if score < self.score_threshold or not _sensitive_label(label):
                continue
            # Explicit None checks: `a or b` would evaluate the truthiness of a
            # numpy array box and raise.
            raw = det.get("box")
            if raw is None:
                raw = det.get("bbox")
            if raw is None or len(raw) != 4:
                continue
            # NudeNet returns [x, y, width, height]
            x, y, bw, bh = (float(v) for v in raw)
            box = Box(int(x), int(y), int(x + bw), int(y + bh), score, label).clamp(w, h)
            if box.valid():
                boxes.append(box)
        return boxes


def blur_boxes(frame_bgr: np.ndarray, boxes: Sequence[Box], blur_strength: int = 61, margin: float = 0.15) -> np.ndarray:
    """Return a copy of *frame_bgr* with each box region Gaussian-blurred.

    Each box is first grown by ``margin`` (see ``Box.expanded``). The kernel
    size is ``blur_strength`` forced odd and at least 3, and is further capped
    to roughly the region's shorter side. The input frame is not modified.
    """
    strength = blur_strength + 1 if blur_strength % 2 == 0 else blur_strength
    strength = max(3, int(strength))
    result = frame_bgr.copy()
    height, width = result.shape[:2]
    for original in boxes:
        region = original.expanded(width, height, margin)
        if not region.valid():
            continue
        patch = result[region.y1:region.y2, region.x1:region.x2]
        if patch.size == 0:
            continue
        # Cap the kernel near the patch's shorter side so tiny regions don't get a huge kernel.
        shortest_side = min(patch.shape[:2])
        fitted = max(3, (shortest_side // 2) * 2 + 1)
        kernel = min(strength, fitted)
        if kernel % 2 == 0:
            kernel += 1
        result[region.y1:region.y2, region.x1:region.x2] = cv2.GaussianBlur(patch, (kernel, kernel), 0)
    return result


def process_image(input_path: Path, output_path: Path, detector: SensitiveDetector, blur_strength: int = 61, margin: float = 0.15) -> dict:
    """Blur detected sensitive regions in an image file and save the result.

    Returns:
        ``{"type": "image", "detections": <box count>, "output": <path str>}``.

    Raises:
        ValueError: if the input cannot be read or the output cannot be written.
    """
    source = str(input_path)
    image = cv2.imread(source, cv2.IMREAD_COLOR)
    if image is None:
        raise ValueError(f"Cannot read image: {input_path}")
    found = detector.detect_bgr(image)
    blurred = blur_boxes(image, found, blur_strength=blur_strength, margin=margin)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    # cv2.imwrite reports failure through its return value rather than raising.
    if not cv2.imwrite(str(output_path), blurred):
        raise ValueError(f"Cannot write image: {output_path}")
    return {"type": "image", "detections": len(found), "output": str(output_path)}


def iou(a: Box, b: Box) -> float:
    """Return the intersection-over-union of two boxes.

    Negative widths/heights count as zero area; returns 0.0 when the union
    area is zero.
    """
    overlap_w = max(0, min(a.x2, b.x2) - max(a.x1, b.x1))
    overlap_h = max(0, min(a.y2, b.y2) - max(a.y1, b.y1))
    overlap = overlap_w * overlap_h

    def _area(box):
        return max(0, box.x2 - box.x1) * max(0, box.y2 - box.y1)

    union = _area(a) + _area(b) - overlap
    if not union:
        return 0.0
    return overlap / union


def smooth_boxes(current: Sequence[Box], previous: Sequence[Box], alpha: float = 0.65, iou_threshold: float = 0.05) -> List[Box]:
    """Temporally smooth new detections against the previous ones.

    Each box in *current*, in order, greedily claims the unclaimed box in
    *previous* with the highest positive IoU. If that IoU reaches
    ``iou_threshold``, the output box blends the coordinates as
    ``alpha * current + (1 - alpha) * previous`` (truncated to int) and keeps
    the current score and label. Otherwise the current box is passed through
    unchanged.
    """
    claimed: set[int] = set()
    keep = 1 - alpha
    result: List[Box] = []
    for box in current:
        match_idx, match_overlap = -1, 0.0
        for idx, candidate in enumerate(previous):
            if idx in claimed:
                continue
            overlap = iou(box, candidate)
            if overlap > match_overlap:
                match_idx, match_overlap = idx, overlap
        matched = match_idx >= 0 and match_overlap >= iou_threshold
        if not matched:
            result.append(box)
            continue
        claimed.add(match_idx)
        old = previous[match_idx]
        result.append(Box(
            int(alpha * box.x1 + keep * old.x1),
            int(alpha * box.y1 + keep * old.y1),
            int(alpha * box.x2 + keep * old.x2),
            int(alpha * box.y2 + keep * old.y2),
            box.score,
            box.label,
        ))
    return result


def _ffmpeg_copy_audio(original: Path, silent_video: Path, final_output: Path) -> bool:
    if shutil.which("ffmpeg") is None:
        shutil.move(str(silent_video), str(final_output))
        return False
    cmd = [
        "ffmpeg", "-y",
        "-i", str(silent_video),
        "-i", str(original),
        "-map", "0:v:0",
        "-map", "1:a:0?",
        "-c:v", "copy",
        "-c:a", "aac",
        "-shortest",
        str(final_output),
    ]
    proc = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
    if proc.returncode == 0 and final_output.exists():
        silent_video.unlink(missing_ok=True)
        return True
    shutil.move(str(silent_video), str(final_output))
    return False


def preview_frame(
    input_path: Path,
    output_path: Path,
    detector: SensitiveDetector,
    blur_strength: int = 61,
    margin: float = 0.15,
    frame_index: int = 0,
) -> dict:
    """Create a blurred preview still for an image or one selected video frame.

    For a video, ``frame_index`` is clamped to ``[0, frame_count - 1]`` when
    OpenCV reports a frame count, and only to ``>= 0`` otherwise.

    Returns:
        Image input: ``{"type": "image", "frame_index": 0, "detections", "output"}``.
        Video input: ``{"type": "video", "frame_index", "input_frames", "fps",
        "detections", "output"}``; ``fps`` falls back to 24.0 when unreported.

    Raises:
        ValueError: on an unsupported extension, an unreadable image, video or
        frame, or a failed write of the preview.
    """
    output_path.parent.mkdir(parents=True, exist_ok=True)
    if is_image(input_path):
        image = cv2.imread(str(input_path), cv2.IMREAD_COLOR)
        if image is None:
            raise ValueError(f"Cannot read image: {input_path}")
        boxes = detector.detect_bgr(image)
        result = blur_boxes(image, boxes, blur_strength=blur_strength, margin=margin)
        # cv2.imwrite signals failure through its return value, not an
        # exception; check it just as the video branch below does.
        if not cv2.imwrite(str(output_path), result):
            raise ValueError(f"Cannot write preview: {output_path}")
        return {"type": "image", "frame_index": 0, "detections": len(boxes), "output": str(output_path)}

    if not is_video(input_path):
        raise ValueError(f"Unsupported file type: {input_path.suffix}")
    cap = cv2.VideoCapture(str(input_path))
    try:
        if not cap.isOpened():
            raise ValueError(f"Cannot open video: {input_path}")
        total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT) or 0)
        fps = float(cap.get(cv2.CAP_PROP_FPS) or 24.0)
        if total > 0:
            frame_index = max(0, min(int(frame_index), total - 1))
        else:
            frame_index = max(0, int(frame_index))
        # NOTE(review): CAP_PROP_POS_FRAMES seeking may be approximate for some
        # backends/codecs — confirm if frame-exact previews matter.
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_index)
        ok, frame = cap.read()
    finally:
        # Release the capture even if a property query, seek or read raises.
        cap.release()
    if not ok or frame is None:
        raise ValueError(f"Cannot read frame {frame_index} from video")
    boxes = detector.detect_bgr(frame)
    result = blur_boxes(frame, boxes, blur_strength=blur_strength, margin=margin)
    if not cv2.imwrite(str(output_path), result):
        raise ValueError(f"Cannot write preview: {output_path}")
    return {"type": "video", "frame_index": frame_index, "input_frames": total, "fps": fps, "detections": len(boxes), "output": str(output_path)}


def process_video(
    input_path: Path,
    output_path: Path,
    detector: SensitiveDetector,
    blur_strength: int = 61,
    margin: float = 0.15,
    detect_every: int = 5,
    max_frames: int | None = None,
) -> dict:
    """Blur sensitive regions in every frame of a video and write the result.

    Detection runs on every ``detect_every``-th frame, and on every frame while
    no boxes are active. New detections are smoothed against the previous
    detection with ``smooth_boxes`` and reused for the frames in between.
    Frames go to a temporary ``.silent.mp4`` beside *output_path*; then
    ``_ffmpeg_copy_audio`` tries to mux the original audio back in.

    Args:
        max_frames: stop after writing this many frames (``None`` = all).

    Returns:
        Dict with ``type``, ``frames`` (written), ``input_frames`` (reported
        count), ``fps``, ``detection_frames``, ``detections`` (sum of boxes over
        detection frames), ``audio_copied`` and ``output``.

    Raises:
        ValueError: if the video cannot be opened or the writer cannot be created.
    """
    cap = cv2.VideoCapture(str(input_path))
    if not cap.isOpened():
        raise ValueError(f"Cannot open video: {input_path}")
    fps = cap.get(cv2.CAP_PROP_FPS) or 24.0
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT) or 0)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    tmp_video = output_path.with_suffix(".silent.mp4")
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    writer = cv2.VideoWriter(str(tmp_video), fourcc, fps, (width, height))
    if not writer.isOpened():
        cap.release()
        raise ValueError(f"Cannot create video writer: {tmp_video}")

    step = max(1, detect_every)
    prev_boxes: List[Box] = []
    current_boxes: List[Box] = []
    frames = 0
    detection_frames = 0
    total_detections = 0
    finished = False
    try:
        # Check the frame limit before reading so no extra frame is decoded.
        while max_frames is None or frames < max_frames:
            ok, frame = cap.read()
            if not ok:
                break
            if frames % step == 0 or not current_boxes:
                detected = detector.detect_bgr(frame)
                current_boxes = smooth_boxes(detected, prev_boxes)
                prev_boxes = current_boxes
                detection_frames += 1
                total_detections += len(current_boxes)
            blurred = blur_boxes(frame, current_boxes, blur_strength=blur_strength, margin=margin)
            # VideoWriter requires frames of exactly the size it was opened with
            # and silently drops any other size. The container-reported size can
            # disagree with decoded frames, so resize instead of losing frames.
            if blurred.shape[1] != width or blurred.shape[0] != height:
                blurred = cv2.resize(blurred, (width, height))
            writer.write(blurred)
            frames += 1
        finished = True
    finally:
        cap.release()
        writer.release()
        if not finished:
            # Don't leave a partial temp file behind when processing fails.
            tmp_video.unlink(missing_ok=True)

    audio_copied = _ffmpeg_copy_audio(input_path, tmp_video, output_path)
    return {
        "type": "video",
        "frames": frames,
        "input_frames": total,
        "fps": fps,
        "detection_frames": detection_frames,
        "detections": total_detections,
        "audio_copied": audio_copied,
        "output": str(output_path),
    }
