# LICENSE HEADER MANAGED BY add-license-header
#
# Copyright 2018 Kornia Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import Any, Callable, Dict, Optional, Tuple, Type, TypeVar, Union

import torch
from torch import nn
from torch.distributions import Distribution, Uniform

from kornia.augmentation.utils.helpers import MultiprocessWrapper

# Type variable used in the signature of ``_PostInitInjectionMetaClass.__call__`` to tie
# the class being called (``Type[T]``) to the instance it returns (``T``).
T = TypeVar("T")


class _PostInitInjectionMetaClass(type):
    """Metaclass that runs ``__post_init__`` on every freshly constructed instance."""

    def __call__(cls: Type[T], *args: Any, **kwargs: Any) -> T:
        # Let ``type`` perform the regular construction first, then invoke the
        # post-initialisation hook on the finished instance before handing it back.
        instance = type.__call__(cls, *args, **kwargs)
        instance.__post_init__()
        return instance


class RandomGeneratorBase(nn.Module, metaclass=_PostInitInjectionMetaClass):
    """Base class for generating random augmentation parameters.

    Subclasses implement :meth:`make_samplers` to build their samplers and
    :meth:`forward` to draw the parameters. The metaclass calls
    :meth:`__post_init__` once ``__init__`` has returned, which by default builds
    the samplers for the CPU in ``torch.float32``.
    """

    # Device and dtype the samplers were last built for. ``device`` is ``None`` and
    # ``dtype`` is unset until ``set_rng_device_and_dtype`` has run once.
    device: Union[None, str, torch.device] = None
    dtype: torch.dtype

    def __init__(self) -> None:
        super().__init__()

    def __post_init__(self) -> None:
        """Build the initial samplers with the default device and dtype."""
        self.set_rng_device_and_dtype()

    def set_rng_device_and_dtype(
        self,
        device: Optional[torch.device] = None,
        dtype: torch.dtype = torch.float32,
    ) -> None:
        """Change the random generation device and dtype.

        The samplers are rebuilt through :meth:`make_samplers` only when the requested
        device or dtype differs from the one currently stored on the instance.

        Args:
            device: Target device. ``None`` selects the CPU.
            dtype: Target dtype for the samplers.

        Note:
            The generated random numbers are not reproducible across different devices and dtypes.

        """
        if device is None:
            device = torch.device("cpu")
        # ``dtype`` has no class-level default, so read it defensively instead of
        # relying on the device comparison short-circuiting first.
        current_dtype = getattr(self, "dtype", None)
        if self.device != device or current_dtype != dtype:
            self.make_samplers(device, dtype)
            # Only record the new state once the samplers were built successfully.
            self.device = device
            self.dtype = dtype

    # TODO: refine the logic with module.to()
    def to(self, *args: Any, **kwargs: Any) -> "RandomGeneratorBase":
        """Update sampler device and dtype using ``torch.nn.Module.to`` semantics.

        Only the sampler device/dtype is updated; ``nn.Module.to`` itself is not invoked.
        A device or dtype that is not specified in the call keeps its current value.

        Args:
            *args: Positional arguments accepted by ``Module.to``.
            **kwargs: Keyword arguments accepted by ``Module.to``.

        Returns:
            This generator instance.
        """
        device, dtype, _, _ = torch._C._nn._parse_to(*args, **kwargs)
        # ``_parse_to`` yields ``None`` for whichever of device/dtype the caller left
        # out. Forwarding that ``None`` would reset the device to CPU (for
        # ``.to(dtype)``) or build samplers with ``dtype=None`` (for ``.to(device)``),
        # so fall back to the current settings instead.
        if device is None:
            device = self.device
        if dtype is None:
            dtype = getattr(self, "dtype", None)
            if dtype is None:
                dtype = torch.float32
        self.set_rng_device_and_dtype(device=device, dtype=dtype)
        return self

    def make_samplers(self, device: torch.device, dtype: torch.dtype) -> None:
        """Create distribution samplers for the given device and dtype.

        Args:
            device: Target device.
            dtype: Target floating-point dtype.

        Raises:
            NotImplementedError: Subclass did not implement sampler creation.
        """
        raise NotImplementedError

    def forward(self, batch_shape: Tuple[int, ...], same_on_batch: bool = False) -> Dict[str, torch.Tensor]:
        """Sample random augmentation parameters.

        Args:
            batch_shape: Target batch shape.
            same_on_batch: If ``True``, use one sample and broadcast it across the
                batch dimension.

        Returns:
            Dictionary of tensors consumed by augmentation modules.

        Raises:
            NotImplementedError: Subclass did not implement parameter sampling.
        """
        raise NotImplementedError


class DistributionWithMapper(Distribution):
    """Wraps a distribution with a value mapper function.

    This is used to restrict the output values of a given distribution by a value mapper function.
    The value mapper function can be functions like sigmoid, tanh, etc.

    Attributes that are not found on the wrapper itself (for example ``loc`` or
    ``scale`` of a wrapped ``Normal``) are looked up on the wrapped distribution.

    Args:
        dist: the target distribution.
        map_fn: the callable function to adjust the output from distributions.

    Example:
        >>> from torch.distributions import Normal
        >>> import torch.nn as nn
        >>> # without mapper
        >>> dist = DistributionWithMapper(Normal(0., 1.,), map_fn=None)
        >>> _ = torch.manual_seed(0)
        >>> dist.rsample((8,))
        tensor([ 1.5410, -0.2934, -2.1788,  0.5684, -1.0845, -1.3986,  0.4033,  0.8380])
        >>> # with sigmoid mapper
        >>> dist = DistributionWithMapper(Normal(0., 1.,), map_fn=nn.Sigmoid())
        >>> _ = torch.manual_seed(0)
        >>> dist.rsample((8,))
        tensor([0.8236, 0.4272, 0.1017, 0.6384, 0.2527, 0.1980, 0.5995, 0.6980])

    """

    def __init__(self, dist: Distribution, map_fn: Optional[Callable[[torch.Tensor], torch.Tensor]] = None) -> None:
        # ``Distribution.__init__`` is not called here; attributes it would set are
        # resolved on the wrapped distribution through ``__getattr__``.
        self.dist = dist
        self.map_fn = map_fn

    def _apply_map(self, values: torch.Tensor) -> torch.Tensor:
        """Return ``values`` passed through ``map_fn``, or unchanged when no mapper is set."""
        if self.map_fn is None:
            return values
        return self.map_fn(values)

    def rsample(self, sample_shape: Tuple[int, ...] = ()) -> torch.Tensor:  # type: ignore[override]
        """Draw a reparameterized sample and apply ``map_fn`` when provided.

        Args:
            sample_shape: Desired sample shape. Defaults to an empty shape.

        Returns:
            Sample tensor after optional mapping.
        """
        return self._apply_map(self.dist.rsample(torch.Size(sample_shape)))

    def sample(self, sample_shape: Tuple[int, ...] = ()) -> torch.Tensor:  # type: ignore[override]
        """Draw a sample and apply ``map_fn`` when provided.

        Args:
            sample_shape: Desired sample shape. Defaults to an empty shape.

        Returns:
            Sample tensor after optional mapping.
        """
        return self._apply_map(self.dist.sample(torch.Size(sample_shape)))

    def sample_n(self, n: int) -> torch.Tensor:
        """Draw ``n`` samples and apply ``map_fn`` when provided.

        Args:
            n: Number of samples.

        Returns:
            Sample tensor after optional mapping.
        """
        # Routed through ``sample`` rather than the wrapped distribution's own
        # ``sample_n``, which torch marks as deprecated.
        return self.sample((n,))

    def __getattr__(self, attr: str) -> Any:
        """Look up attributes missing on the wrapper on the wrapped distribution.

        Args:
            attr: Name of the attribute that normal lookup failed to find.

        Returns:
            The attribute value taken from the wrapped distribution.

        Raises:
            AttributeError: ``attr`` is a dunder name, the wrapped distribution is not
                set yet, or the wrapped distribution does not have ``attr`` either.
        """
        # ``__getattr__`` only runs after normal lookup has already failed, so calling
        # ``getattr(self, attr)`` here would re-enter this method until RecursionError.
        missing = AttributeError(f"'{type(self).__name__}' object has no attribute '{attr}'")
        # Protocol probes (copy, pickle, ...) must see the wrapper's own answer, not the
        # wrapped distribution's special methods.
        if attr.startswith("__") and attr.endswith("__"):
            raise missing
        try:
            # Bypass ``__getattr__`` so a missing ``dist`` (e.g. while unpickling)
            # surfaces as AttributeError instead of recursing.
            wrapped = object.__getattribute__(self, "dist")
        except AttributeError:
            raise missing from None
        return getattr(wrapped, attr)


class UniformDistribution(MultiprocessWrapper, Uniform):
    """Wrapper around torch Uniform distribution which makes it work with the 'spawn' multiprocessing context.

    The class adds no members of its own: it only combines ``MultiprocessWrapper``
    with ``torch.distributions.Uniform``, listing ``MultiprocessWrapper`` first so
    that it takes precedence in the method resolution order.
    """
