# LICENSE HEADER MANAGED BY add-license-header
#
# Copyright 2018 Kornia Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import datetime
import inspect
import math
import os
from functools import wraps
from typing import Any, Callable, List, Literal, Optional, Tuple, Union

import torch

from kornia.core.external import PILImage as Image
from kornia.core.external import numpy as np


class ImageModuleMixIn:
    """A MixIn that handles image-based operations.

    This modules accepts multiple input and output data types, provides end-to-end visualization, file saving features.
    Note that this MixIn fits the classes that return one image tensor only.
    """

    _output_image: Any

    def convert_input_output(
        self,
        input_names_to_handle: Optional[List[Any]] = None,
        output_type: Literal["pt", "numpy", "pil"] = "pt",
    ) -> Callable[[Any], Any]:
        """Convert input and output types for a function.

        Args:
            input_names_to_handle: List of input names to convert, if None, handle all inputs.
            output_type: Desired output type ('pt', 'numpy', or 'pil').

        Returns:
            Callable: Decorated function with converted input and output types.

        """
        # Validate output_type at the start
        if output_type not in ("pt", "numpy", "pil"):
            raise ValueError(f"Invalid output_type '{output_type}'. Must be one of 'pt', 'numpy', or 'pil'.")

        def decorator(func: Callable[[Any], Any]) -> Callable[[Any], Any]:
            @wraps(func)
            def wrapper(*args: Any, **kwargs: Any) -> Union[Any, List[Any]]:
                # If input_names_to_handle is None, handle all inputs
                if input_names_to_handle is None:
                    # Convert all args to tensors
                    args = tuple(self.to_tensor(arg) if self._is_valid_arg(arg) else arg for arg in args)
                    # Convert all kwargs to tensors
                    kwargs = {k: self.to_tensor(v) if self._is_valid_arg(v) else v for k, v in kwargs.items()}
                else:
                    # Convert specified args to tensors
                    args = list(args)  # type:ignore
                    for i, (arg, name) in enumerate(zip(args, func.__code__.co_varnames)):  # ty: ignore[unresolved-attribute]
                        if name in input_names_to_handle:
                            args[i] = self.to_tensor(arg)  # type:ignore
                    # Convert specified kwargs to tensors
                    for name, value in kwargs.items():
                        if name in input_names_to_handle:
                            kwargs[name] = self.to_tensor(value)

                # Call the actual forward method
                tensor_outputs = func(*args, **kwargs)

                if not isinstance(tensor_outputs, tuple):
                    tensor_outputs = (tensor_outputs,)

                # Convert outputs to the desired type
                outputs = []
                for output in tensor_outputs:
                    if output_type == "pt":
                        outputs.append(output)
                    elif output_type == "numpy":
                        outputs.append(self.to_numpy(output))
                    elif output_type == "pil":
                        outputs.append(self.to_pil(output))
                    else:
                        raise ValueError("Output type not supported. Choose from 'pt', 'numpy', or 'pil'.")

                return outputs if len(outputs) > 1 else outputs[0]

            return wrapper

        return decorator

    def _is_valid_arg(self, arg: Any) -> bool:
        """Check if the argument is a valid type for conversion.

        Args:
            arg: The argument to check.

        Returns:
            bool: True if valid, False otherwise.

        """
        if isinstance(arg, str) and os.path.exists(arg):
            return True
        if isinstance(arg, torch.Tensor):
            return True
        # Make sure that the numpy and PIL are not necessarily needed to be imported.
        if isinstance(arg, np.ndarray):  # type: ignore
            return True
        if isinstance(arg, (Image.Image)):  # type: ignore
            return True
        return False

    def to_tensor(self, x: Any) -> torch.Tensor:
        """Convert input to tensor.

        Supports image path, numpy array, PIL image, and raw tensor.

        Args:
            x: The input to convert.

        Returns:
            Tensor: The converted tensor.

        """
        if isinstance(x, str):
            from kornia.io import ImageLoadType, load_image  # pylint: disable=C0415

            return load_image(x, ImageLoadType.UNCHANGED) / 255
        if isinstance(x, torch.Tensor):
            return x
        if isinstance(x, np.ndarray):  # type: ignore
            from kornia.image.utils import image_to_tensor  # pylint: disable=C0415

            return image_to_tensor(x) / 255
        if isinstance(x, Image.Image):  # type: ignore
            return torch.from_numpy(np.array(x)).permute(2, 0, 1).float() / 255  # type: ignore
        raise TypeError("Input type not supported")

    def to_numpy(self, x: Any) -> "np.array":  # type: ignore
        """Convert input to numpy array.

        Args:
            x: The input to convert.

        Returns:
            np.array: The converted numpy array.

        """
        if isinstance(x, torch.Tensor):
            return x.cpu().detach().numpy()
        if isinstance(x, np.ndarray):  # type: ignore
            return x
        if isinstance(x, Image.Image):  # type: ignore
            return np.array(x)  # type: ignore
        raise TypeError("Input type not supported")

    def to_pil(self, x: Any) -> "Image.Image":  # type: ignore
        """Convert input to PIL image.

        Args:
            x: The input to convert.

        Returns:
            Image.Image: The converted PIL image.

        """
        if isinstance(x, torch.Tensor):
            x = x.cpu().detach() * 255
            if x.dim() == 3:
                x = x.permute(1, 2, 0)
                return Image.fromarray(x.byte().numpy())  # type: ignore
            elif x.dim() == 4:
                x = x.permute(0, 2, 3, 1)
                return [Image.fromarray(_x.byte().numpy()) for _x in x]  # type: ignore
            else:
                raise NotImplementedError
        if isinstance(x, np.ndarray):  # type: ignore
            raise NotImplementedError
        if isinstance(x, Image.Image):  # type: ignore
            return x
        raise TypeError("Input type not supported")

    def _detach_tensor_to_cpu(
        self, output_image: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]
    ) -> Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]:
        if isinstance(output_image, torch.Tensor):
            return output_image.detach().cpu()
        if isinstance(output_image, list | tuple):
            return type(output_image)([self._detach_tensor_to_cpu(out) for out in output_image])  # type: ignore
        raise RuntimeError(f"Unexpected object {output_image} with a type of `{type(output_image)}`")

    def show(self, n_row: Optional[int] = None, backend: str = "pil", display: bool = True) -> Optional[Any]:
        """Return PIL images.

        Args:
            n_row: Number of images displayed in each row of the grid.
            backend: visualization backend. Only PIL is supported now.
            display: Whether or not to show the image.

        """
        if self._output_image is None:
            raise ValueError("No pre-computed images found. Needs to execute first.")

        if len(self._output_image.shape) == 3:
            out_image = self._output_image
        elif len(self._output_image.shape) == 4:
            from kornia.image.utils import make_grid  # pylint: disable=C0415

            if n_row is None:
                n_row = math.ceil(self._output_image.shape[0] ** 0.5)
            out_image = make_grid(self._output_image, n_row, padding=2)
        else:
            raise ValueError

        if backend == "pil" and display:
            Image.fromarray((out_image.permute(1, 2, 0).squeeze().numpy() * 255).astype(np.uint8)).show()  # type: ignore
            return None
        if backend == "pil":
            return Image.fromarray((out_image.permute(1, 2, 0).squeeze().numpy() * 255).astype(np.uint8))  # type: ignore
        raise ValueError(f"Unsupported backend `{backend}`.")

    def save(self, name: Optional[str] = None, n_row: Optional[int] = None) -> None:
        """Save the output image(s) to a directory.

        Args:
            name: Directory to save the images.
            n_row: Number of images displayed in each row of the grid.

        """
        from kornia.image.utils import make_grid  # pylint: disable=C0415
        from kornia.io import write_image  # pylint: disable=C0415

        if name is None:
            name = f"Kornia-{datetime.datetime.now(tz=datetime.UTC).strftime('%Y%m%d%H%M%S')!s}.jpg"
        if len(self._output_image.shape) == 3:
            out_image = self._output_image
        if len(self._output_image.shape) == 4:
            if n_row is None:
                n_row = math.ceil(self._output_image.shape[0] ** 0.5)
            out_image = make_grid(self._output_image, n_row, padding=2)
        write_image(name, out_image.mul(255.0).byte())
