# LICENSE HEADER MANAGED BY add-license-header
#
# Copyright 2018 Kornia Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from __future__ import annotations

from typing import ClassVar

import torch
import torch.nn.functional as F
from torch import nn

from kornia.core.check import KORNIA_CHECK_IS_TENSOR, KORNIA_CHECK_SHAPE


def rgb_to_xyz(image: torch.Tensor) -> torch.Tensor:
    r"""Convert a RGB image to XYZ.

    .. image:: _static/img/rgb_to_xyz.png

    The conversion is a fixed 3x3 linear map applied along the channel dimension.
    No gamma correction is applied by this function.

    Args:
        image: RGB Image to be converted to XYZ with shape :math:`(*, 3, H, W)`.

    Returns:
         XYZ version of the image with shape :math:`(*, 3, H, W)`. Floating point
         inputs keep their dtype; other inputs are computed in ``float32``.

    Example:
        >>> input = torch.rand(2, 3, 4, 5)
        >>> output = rgb_to_xyz(input)  # 2x3x4x5

    """
    KORNIA_CHECK_IS_TENSOR(image)
    KORNIA_CHECK_SHAPE(image, ["*", "3", "H", "W"])

    # Build the matrix directly in the dtype the computation will run in.
    # Creating it as float32 and upcasting afterwards would leave float64 inputs
    # with coefficients that only carry float32 precision.
    kernel_dtype = image.dtype if image.is_floating_point() else torch.float32

    # RGB -> CIE XYZ matrix (Rec. 709 primaries, D65 white point), i.e. the
    # coefficients listed in the OpenCV color-conversion reference.
    kernel = torch.tensor(
        [
            [0.412453, 0.357580, 0.180423],
            [0.212671, 0.715160, 0.072169],
            [0.019334, 0.119193, 0.950227],
        ],
        device=image.device,
        dtype=kernel_dtype,
    )

    # The helper picks the fastest backend for the image's device.
    return _apply_linear_transformation(image, kernel)


def xyz_to_rgb(image: torch.Tensor) -> torch.Tensor:
    r"""Convert a XYZ image to RGB.

    The conversion is a fixed 3x3 linear map applied along the channel dimension.
    No gamma correction is applied by this function.

    Args:
        image: XYZ Image to be converted to RGB with shape :math:`(*, 3, H, W)`.

    Returns:
        RGB version of the image with shape :math:`(*, 3, H, W)`. Floating point
        inputs keep their dtype; other inputs are computed in ``float32``.

    Example:
        >>> input = torch.rand(2, 3, 4, 5)
        >>> output = xyz_to_rgb(input)  # 2x3x4x5

    """
    KORNIA_CHECK_IS_TENSOR(image)
    KORNIA_CHECK_SHAPE(image, ["*", "3", "H", "W"])

    # Build the matrix directly in the dtype the computation will run in.
    # The literals below are given at double precision; creating the tensor as
    # float32 and upcasting afterwards would throw those digits away for
    # float64 inputs.
    kernel_dtype = image.dtype if image.is_floating_point() else torch.float32

    # CIE XYZ -> RGB matrix (D65 white point).
    kernel = torch.tensor(
        [
            [3.2404813432005266, -1.5371515162713185, -0.4985363261688878],
            [-0.9692549499965682, 1.8759900014898907, 0.0415559265582928],
            [0.0556466391351772, -0.2040413383665112, 1.0573110696453443],
        ],
        device=image.device,
        dtype=kernel_dtype,
    )

    # The helper picks the fastest backend for the image's device.
    return _apply_linear_transformation(image, kernel)


def _apply_linear_transformation(image: torch.Tensor, kernel: torch.Tensor) -> torch.Tensor:
    """Mix the three channels of an image with a 3x3 matrix.

    Each output channel ``o`` is the sum over input channels ``c`` of
    ``kernel[o, c] * image[..., c, :, :]``. The backend used to evaluate that
    sum depends on the device of ``image``.

    Args:
        image: Input image tensor with shape :math:`(*, 3, H, W)`.
        kernel: Transformation matrix with shape :math:`(3, 3)`; rows index the
            output channels and columns the input channels.

    Returns:
        Tensor with the same shape as ``image``. Floating point inputs keep their
        dtype, any other dtype is converted to ``float32`` before the computation.
    """
    # Non-floating inputs are promoted to float32; floating inputs
    # (including float64, e.g. under gradcheck) are used untouched.
    src = image if image.is_floating_point() else image.float()

    # Bring the matrix onto the same dtype/device as the data it multiplies.
    mat = kernel.to(dtype=src.dtype, device=src.device)

    # The original authors report einsum as the faster option on CPU and a
    # 1x1 convolution as the faster option on GPU / other accelerators.
    if src.device.type != "cpu":
        full_shape = src.shape
        # conv2d wants a single batch axis: collapse all leading dimensions.
        batched = src.reshape(-1, 3, full_shape[-2], full_shape[-1])
        # A (3, 3) matrix is exactly a 1x1 convolution with 3 in / 3 out channels.
        mixed = F.conv2d(batched, mat.view(3, 3, 1, 1))
        # Restore the caller's leading dimensions.
        return mixed.reshape(full_shape)

    # CPU path: contract the channel axis directly, then make memory dense.
    return torch.einsum("...chw,oc->...ohw", src, mat).contiguous()


class RgbToXyz(nn.Module):
    r"""Module form of :func:`rgb_to_xyz`: maps an RGB image to XYZ.

    The image data is assumed to be in the range of (0, 1).

    Returns:
        The XYZ representation of the input image.

    Shape:
        - image: :math:`(*, 3, H, W)`
        - output: :math:`(*, 3, H, W)`

    Examples:
        >>> input = torch.rand(2, 3, 4, 5)
        >>> xyz = RgbToXyz()
        >>> output = xyz(input)  # 2x3x4x5

    Reference:
        [1] https://docs.opencv.org/4.0.1/de/d25/imgproc_color_conversions.html

    """

    # Default shapes advertised for ONNX export.
    # NOTE(review): -1 presumably marks a dynamic dimension - confirm against the export helper.
    ONNX_DEFAULT_INPUTSHAPE: ClassVar[list[int]] = [-1, 3, -1, -1]
    ONNX_DEFAULT_OUTPUTSHAPE: ClassVar[list[int]] = [-1, 3, -1, -1]

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        """Run the RGB to XYZ conversion on ``image``.

        Args:
            image: RGB tensor with shape :math:`(*, 3, H, W)`, where ``*`` stands for
                any number of leading dimensions (such as a batch axis), ``3`` is the
                channel axis and ``H``/``W`` are the spatial size.

        Returns:
            XYZ tensor with shape :math:`(*, 3, H, W)`.
        """
        # The module holds no state; all work is delegated to the functional API.
        xyz = rgb_to_xyz(image)
        return xyz


class XyzToRgb(nn.Module):
    r"""Module form of :func:`xyz_to_rgb`: maps an XYZ image to RGB.

    Returns:
        The RGB representation of the input image.

    Shape:
        - image: :math:`(*, 3, H, W)`
        - output: :math:`(*, 3, H, W)`

    Examples:
        >>> input = torch.rand(2, 3, 4, 5)
        >>> rgb = XyzToRgb()
        >>> output = rgb(input)  # 2x3x4x5

    Reference:
        [1] https://docs.opencv.org/4.0.1/de/d25/imgproc_color_conversions.html

    """

    # Default shapes advertised for ONNX export.
    # NOTE(review): -1 presumably marks a dynamic dimension - confirm against the export helper.
    ONNX_DEFAULT_INPUTSHAPE: ClassVar[list[int]] = [-1, 3, -1, -1]
    ONNX_DEFAULT_OUTPUTSHAPE: ClassVar[list[int]] = [-1, 3, -1, -1]

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        """Run the XYZ to RGB conversion on ``image``.

        Args:
            image: XYZ tensor with shape :math:`(*, 3, H, W)`, where ``*`` stands for
                any number of leading dimensions (such as a batch axis), ``3`` is the
                channel axis and ``H``/``W`` are the spatial size.

        Returns:
            RGB tensor with shape :math:`(*, 3, H, W)`.
        """
        # The module holds no state; all work is delegated to the functional API.
        rgb = xyz_to_rgb(image)
        return rgb
