# LICENSE HEADER MANAGED BY add-license-header
#
# Copyright 2018 Kornia Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Based on the original code from Meta Platforms, Inc. and affiliates.

https://github.com/facebookresearch/segment-anything/blob/3518c86b78b3bc9cf4fbe3d18e682fad1c79dc51/segment_anything/build_sam.py

https://github.com/facebookresearch/segment-anything/blob/3518c86b78b3bc9cf4fbe3d18e682fad1c79dc51/segment_anything/modeling/sam.py
"""

from __future__ import annotations

import warnings
from dataclasses import dataclass
from enum import Enum
from typing import Any, Optional

import torch

from kornia.core.check import KORNIA_CHECK, KORNIA_CHECK_SHAPE
from kornia.models.base import ModelBase
from kornia.models.sam.architecture.common import LayerNorm
from kornia.models.sam.architecture.image_encoder import ImageEncoderViT
from kornia.models.sam.architecture.mask_decoder import MaskDecoder
from kornia.models.sam.architecture.prompt_encoder import PromptEncoder
from kornia.models.sam.architecture.transformer import TwoWayTransformer
from kornia.models.structures import SegmentationResults
from kornia.models.tiny_vit import TinyViT


class SamModelType(Enum):
    """Enumerate the SAM architecture variants that can be built.

    The integer value of each member is the numeric code that may be given as ``SamConfig.model_type``.
    """

    # ViT image encoder with embed dim 1280 and depth 32.
    vit_h = 0
    # ViT image encoder with embed dim 1024 and depth 24.
    vit_l = 1
    # ViT image encoder with embed dim 768 and depth 12.
    vit_b = 2
    # MobileSAM: TinyViT image encoder in place of the ViT one.
    mobile_sam = 3


@dataclass
class SamConfig:
    """Encapsulate the Config to build a SAM model.

    Args:
        model_type: the available models are:

            - 0, 'vit_h' or :attr:`SamModelType.vit_h`
            - 1, 'vit_l' or :attr:`SamModelType.vit_l`
            - 2, 'vit_b' or :attr:`SamModelType.vit_b`
            - 3, 'mobile_sam', or :attr:`SamModelType.mobile_sam`

            When ``None``, the model is built from the ``encoder_*`` parameters below instead.
        checkpoint: URL or a path for a file with the weights of the model. When set, it is loaded and takes
            precedence over ``pretrained``.
        pretrained: If ``True`` and no ``checkpoint`` is given, load the official pretrained weights for the given
            ``model_type``. Pretrained weights are only available for the named model types.
        encoder_embed_dim: Patch embedding dimension. Only used when ``model_type`` is ``None``.
        encoder_depth: Depth of ViT. Only used when ``model_type`` is ``None``.
        encoder_num_heads: Number of attention heads in each ViT block. Only used when ``model_type`` is ``None``.
        encoder_global_attn_indexes: Encoder indexes for blocks using global attention. Only used when
            ``model_type`` is ``None``.

    """

    model_type: Optional[str | int | SamModelType] = None
    checkpoint: Optional[str] = None
    pretrained: bool = False

    # Custom ViT encoder parameters; all four must be set to build without a ``model_type``.
    encoder_embed_dim: Optional[int] = None
    encoder_depth: Optional[int] = None
    encoder_num_heads: Optional[int] = None
    encoder_global_attn_indexes: Optional[tuple[int, ...]] = None


class Sam(ModelBase[SamConfig]):
    """Implement the Segment Anything Model (SAM) wrapper.

    This class coordinates the image encoder, prompt encoder, and mask decoder.
    """

    # Threshold passed to every ``SegmentationResults`` produced by :meth:`forward`.
    mask_threshold: float = 0.0

    def __init__(
        self, image_encoder: ImageEncoderViT | TinyViT, prompt_encoder: PromptEncoder, mask_decoder: MaskDecoder
    ) -> None:
        """SAM predicts object masks from an image and input prompts.

        Args:
            image_encoder: The backbone used to encode the image into image embeddings that allow for efficient mask
                           prediction.
            prompt_encoder: Encodes various types of input prompts.
            mask_decoder: Predicts masks from the image embeddings and encoded prompts.

        """
        super().__init__()
        self.image_encoder = image_encoder
        self.prompt_encoder = prompt_encoder
        self.mask_decoder = mask_decoder

    @staticmethod
    def from_name(name: str) -> Sam:
        """Build/load the SAM model based on its name.

        Args:
            name: The name of the SAM model. Valid names are:
                - 'vit_b'
                - 'vit_l'
                - 'vit_h'
                - 'mobile_sam'

        Returns:
            The respective SAM model

        Raises:
            ValueError: If ``name`` is not one of the valid names.

        """
        if name in ["vit_b", "vit_l", "vit_h", "mobile_sam"]:
            return Sam.from_config(SamConfig(name))
        raise ValueError(f"Invalid SAM model name: {name}")

    @staticmethod
    def from_config(config: SamConfig) -> Sam:
        """Build/load the SAM model based on its config.

        Args:
            config: The SamConfig data structure. If the model_type is available, build from it, otherwise will use
                    the encoder parameters set (``encoder_embed_dim``, ``encoder_depth``, ``encoder_num_heads`` and
                    ``encoder_global_attn_indexes``).

        Returns:
            The respective SAM model

        Raises:
            NotImplementedError: If neither a known ``model_type`` nor a complete set of encoder parameters is given.
            ValueError: If ``pretrained=True`` without a ``checkpoint`` and no pretrained weights are known for the
                config (e.g. a model built from custom encoder parameters).

        Example:
            >>> from kornia.models.sam import SamConfig
            >>> sam_model = Sam.from_config(SamConfig('vit_b'))

        """
        model_type = config.model_type

        # Normalise the accepted spellings (int code / name string) to a SamModelType member.
        if isinstance(model_type, int):
            model_type = SamModelType(model_type)
        elif isinstance(model_type, str):
            _map_sam_type = {
                "vit_h": SamModelType.vit_h,
                "vit_l": SamModelType.vit_l,
                "vit_b": SamModelType.vit_b,
                "mobile_sam": SamModelType.mobile_sam,
            }
            model_type = _map_sam_type[model_type]

        if model_type == SamModelType.vit_b:
            model = _build_sam(
                encoder_embed_dim=768, encoder_depth=12, encoder_num_heads=12, encoder_global_attn_indexes=(2, 5, 8, 11)
            )

        elif model_type == SamModelType.vit_l:
            model = _build_sam(
                encoder_embed_dim=1024,
                encoder_depth=24,
                encoder_num_heads=16,
                encoder_global_attn_indexes=(5, 11, 17, 23),
            )

        elif model_type == SamModelType.vit_h:
            model = _build_sam(
                encoder_embed_dim=1280,
                encoder_depth=32,
                encoder_num_heads=16,
                encoder_global_attn_indexes=(7, 15, 23, 31),
            )

        elif model_type == SamModelType.mobile_sam:
            # TODO: merge this with _build_sam()
            prompt_embed_dim = 256
            image_size = 1024
            vit_patch_size = 16
            image_embedding_size = image_size // vit_patch_size

            # NOTE: the upstream SAM constructor also takes pixel_mean=[123.675, 116.28, 103.53] and
            # pixel_std=[58.395, 57.12, 57.375]; this wrapper's __init__ does not accept them.
            model = Sam(
                image_encoder=TinyViT.from_config("5m", img_size=image_size, mobile_sam=True),
                prompt_encoder=PromptEncoder(
                    embed_dim=prompt_embed_dim,
                    image_embedding_size=(image_embedding_size, image_embedding_size),
                    input_image_size=(image_size, image_size),
                    mask_in_chans=16,
                ),
                mask_decoder=MaskDecoder(
                    num_multimask_outputs=3,
                    transformer=TwoWayTransformer(depth=2, embedding_dim=prompt_embed_dim, mlp_dim=2048, num_heads=8),
                    transformer_dim=prompt_embed_dim,
                    iou_head_depth=3,
                    iou_head_hidden_dim=256,
                ),
            )

        elif (
            isinstance(config.encoder_embed_dim, int)
            and isinstance(config.encoder_depth, int)
            and isinstance(config.encoder_num_heads, int)
            # A sequence of block indexes (typed ``tuple[int, ...]``), never a single int.
            and isinstance(config.encoder_global_attn_indexes, (tuple, list))
        ):
            model = _build_sam(
                encoder_embed_dim=config.encoder_embed_dim,
                encoder_depth=config.encoder_depth,
                encoder_num_heads=config.encoder_num_heads,
                encoder_global_attn_indexes=tuple(config.encoder_global_attn_indexes),
            )

        else:
            raise NotImplementedError("Unexpected config. The model_type should be provide or the encoder configs.")

        checkpoint = config.checkpoint
        if config.pretrained:
            if checkpoint is None:
                _pretrained_urls = {
                    SamModelType.vit_b: "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
                    SamModelType.vit_l: "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
                    SamModelType.vit_h: "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
                    SamModelType.mobile_sam: "https://github.com/ChaoningZhang/MobileSAM/raw/a509aac54fdd7af59f843135f2f7cee307283c88/weights/mobile_sam.pt",
                }
                # Custom encoder configs have no model_type, so there is no official checkpoint to infer.
                if model_type not in _pretrained_urls:
                    raise ValueError(
                        "pretrained=True requires a known `model_type` or an explicit `checkpoint`; "
                        "no pretrained weights are available for custom encoder configurations."
                    )
                checkpoint = _pretrained_urls[model_type]
            else:
                # stacklevel=2 attributes the warning to the caller instead of this line.
                warnings.warn("checkpoint is not None. pretrained=True is ignored", stacklevel=2)

        if checkpoint:
            model.load_checkpoint(checkpoint)

        return model

    @torch.no_grad()
    def forward(
        self, images: torch.Tensor, batched_prompts: list[dict[str, Any]], multimask_output: bool
    ) -> list[SegmentationResults]:
        """Predict masks end-to-end from provided images and prompts.

        This method expects that the images have already been pre-processed, at least been normalized, resized and
        padded to be compatible with the `self.image_encoder`.

        .. note:: For each image :math:`(3, H, W)`, it is possible to input a batch (:math:`K`) of :math:`N` prompts,
                 the results are batched by the number of prompts batch. So given a prompt with :math:`K=5`, and
                 :math:`N=10`, the results will look like :math:`5xCxHxW` where :math:`C` is determined by
                 multimask_output. And within each of these masks :math:`(5xC)`, it should be possible to find
                 :math:`N` instances if the model succeeds.

        Args:
            images: The image as a torch tensor in :math:`(B, 3, H, W)` format, already transformed for input to the
                    model.
            batched_prompts: A list over the batch of images (list length should be :math:`B`), each a dictionary with
                             the following keys. If it does not have the respective prompt, it should not be included
                             in this dictionary. The options are:

                - "points": tuple of (torch.Tensor, torch.Tensor) within the coordinate keypoints
                  and their respective labels. The tuple should look like (keypoints, labels), where the keypoints
                  (a tensor) are a batched point prompts for this image, with shape :math:`(K, N, 2)`. Already
                  transformed to the input frame of the model. The labels (a tensor) are a batched labels for point
                  prompts, with shape :math:`(K, N)`. Where 1 indicates a foreground point and 0 indicates a background
                  point.

                - "boxes": (torch.Tensor) Batched box inputs, with shape :math:`(K, 4)`.
                  Already transformed to the input frame of the model.

                - "mask_inputs": (torch.Tensor) Batched mask inputs to the model, in the form :math:`(K, 1, H, W)`.

            multimask_output: Whether the model should predict multiple disambiguating masks, or return a single mask.

        Returns:
            A list over input images, where each element is a SegmentationResults with the following:

                - logits: Low resolution logits with shape :math:`(K, C, H, W)`. Can be passed as mask input to
                  subsequent iterations of prediction. Where :math:`K` is the number of input prompts,
                  :math:`C` is determined by multimask_output, and :math:`H=W=256` are the model output size.
                - scores: The model's predictions of mask quality (iou prediction), in shape :math:`(K, C)`.

        """
        KORNIA_CHECK_SHAPE(images, ["B", "3", "H", "W"])
        KORNIA_CHECK(
            images.shape[0] == len(batched_prompts),
            "The number of images (`B`) should match with the length of prompts!",
        )

        # Encode the whole image batch once; prompts are then decoded image by image.
        image_embeddings = self.image_encoder(images)

        outputs = []
        for prompt_record, curr_embedding in zip(batched_prompts, image_embeddings):
            # Embed prompts
            sparse_embeddings, dense_embeddings = self.prompt_encoder(
                points=prompt_record.get("points", None),
                boxes=prompt_record.get("boxes", None),
                masks=prompt_record.get("mask_inputs", None),
            )

            # Predict masks; `[None, ...]` restores the batch dim dropped by iterating over the embeddings.
            low_res_logits, iou_predictions = self.mask_decoder(
                image_embeddings=curr_embedding[None, ...],
                image_pe=self.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=multimask_output,
            )

            # Save results
            outputs.append(SegmentationResults(low_res_logits, iou_predictions, self.mask_threshold))

        return outputs


def _build_sam(
    encoder_embed_dim: int, encoder_depth: int, encoder_num_heads: int, encoder_global_attn_indexes: tuple[int, ...]
) -> Sam:
    """Assemble a ViT-based SAM model with the shared prompt encoder and mask decoder settings.

    Args:
        encoder_embed_dim: Patch embedding dimension of the ViT image encoder.
        encoder_depth: Number of blocks in the ViT image encoder.
        encoder_num_heads: Number of attention heads in each ViT block.
        encoder_global_attn_indexes: Indexes of the ViT blocks using global attention.

    Returns:
        A :class:`Sam` model for 1024x1024 inputs. No weights are loaded here.

    """
    # Channel width shared by the encoder output, the prompt encoder and the mask decoder.
    prompt_dim = 256
    input_side = 1024
    patch_side = 16
    grid_side = input_side // patch_side
    embedding_grid = (grid_side, grid_side)

    # Keep construction order: image encoder, prompt encoder, transformer, decoder
    # (relevant if the modules draw their initial weights from the global RNG).
    vit_encoder = ImageEncoderViT(
        depth=encoder_depth,
        embed_dim=encoder_embed_dim,
        img_size=input_side,
        mlp_ratio=4,
        norm_layer=LayerNorm,
        num_heads=encoder_num_heads,
        patch_size=patch_side,
        qkv_bias=True,
        use_rel_pos=True,
        global_attn_indexes=encoder_global_attn_indexes,
        window_size=14,
        out_chans=prompt_dim,
    )
    prompt_enc = PromptEncoder(
        embed_dim=prompt_dim,
        image_embedding_size=embedding_grid,
        input_image_size=(input_side, input_side),
        mask_in_chans=16,
    )
    two_way = TwoWayTransformer(depth=2, embedding_dim=prompt_dim, mlp_dim=2048, num_heads=8)
    decoder = MaskDecoder(
        num_multimask_outputs=3,
        transformer=two_way,
        transformer_dim=prompt_dim,
        iou_head_depth=3,
        iou_head_hidden_dim=256,
    )

    # NOTE: the upstream SAM constructor also takes pixel_mean=[123.675, 116.28, 103.53] and
    # pixel_std=[58.395, 57.12, 57.375]; this wrapper's Sam.__init__ does not accept them.
    return Sam(image_encoder=vit_encoder, prompt_encoder=prompt_enc, mask_decoder=decoder)
