# LICENSE HEADER MANAGED BY add-license-header
#
# Copyright 2018 Kornia Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import copy
from abc import ABCMeta, abstractmethod
from typing import Any, Callable, Dict, Generic, List, Optional, Type, TypeVar, Union

import torch
from torch import nn
from typing_extensions import ParamSpec

import kornia.augmentation as K
from kornia.augmentation.base import _AugmentationBase
from kornia.constants import DataKey
from kornia.geometry.boxes import Boxes
from kornia.geometry.keypoints import Keypoints

from .params import ParamItem

# A single payload handled by the sequential ops: an image/mask tensor, a list of tensors
# (``AugmentationSequentialOps.transform`` routes a list given under the MASK key to
# ``MaskSequentialOps.transform_list``), boxes, or keypoints.
DataType = Union[torch.Tensor, List[torch.Tensor], Boxes, Keypoints]

# NOTE(review): this is a union of *homogeneous* lists, not ``List[DataType]``. The container ops
# (``AugmentationSequentialOps.transform`` / ``inverse``) return one entry per data key, and those
# entries can mix types (e.g. a tensor next to ``Boxes``), which ``List[DataType]`` would describe
# more accurately. Confirm no caller depends on the current form before changing it.
SequenceDataType = Union[List[torch.Tensor], List[List[torch.Tensor]], List[Boxes], List[Keypoints]]

# Payload type that each concrete ``SequentialOpsInterface`` subclass binds (tensor, Boxes, Keypoints).
T = TypeVar("T")


class SequentialOpsInterface(Generic[T], metaclass=ABCMeta):
    """Contract for pushing one kind of payload through a module, forwards and backwards.

    Subclasses bind ``T`` to the payload they handle (tensor, boxes, keypoints, ...) and implement
    :meth:`transform` / :meth:`inverse` as classmethods. The two helpers below unwrap the
    :class:`ParamItem` recorded by the sequential containers.
    """

    @classmethod
    def get_instance_module_param(cls, param: ParamItem) -> Dict[str, torch.Tensor]:
        """Unwrap the parameter dictionary recorded for a single (non-container) module.

        Args:
            param: Wrapper whose ``data`` should hold the module's parameter dictionary.

        Returns:
            ``param.data`` itself; no copy is made.

        Raises:
            TypeError: ``param`` is not a :class:`ParamItem`, or its ``data`` is not a ``dict``.
        """
        # ``.data`` is only read once ``param`` is known to be a ParamItem.
        data = param.data if isinstance(param, ParamItem) else None
        if not isinstance(data, dict):
            raise TypeError(f"Expected param (ParamItem.data) be a dictionary. Gotcha {param}.")
        return data

    @classmethod
    def get_sequential_module_param(cls, param: ParamItem) -> List[ParamItem]:
        """Unwrap the per-child parameters recorded for a nested sequential container.

        Args:
            param: Wrapper whose ``data`` should hold a list of child :class:`ParamItem` values.

        Returns:
            ``param.data`` itself; no copy is made.

        Raises:
            TypeError: ``param`` is not a :class:`ParamItem`, or its ``data`` is not a ``list``.
        """
        data = param.data if isinstance(param, ParamItem) else None
        if not isinstance(data, list):
            raise TypeError(f"Expected param (ParamItem.data) be a list. Gotcha {param}.")
        return data

    @classmethod
    @abstractmethod
    def transform(cls, input: T, module: nn.Module, param: ParamItem, extra_args: Optional[Dict[str, Any]] = None) -> T:
        """Run ``input`` forward through ``module`` using its recorded parameters.

        Args:
            input: Payload of type ``T`` to transform.
            module: Any ``nn.Module``; each implementation decides which module types affect ``T``.
            param: Parameters recorded for ``module``.
            extra_args: Optional keyword overrides specific to this payload type.

        Returns:
            The transformed payload.
        """
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def inverse(cls, input: T, module: nn.Module, param: ParamItem, extra_args: Optional[Dict[str, Any]] = None) -> T:
        """Undo the effect of ``module`` on ``input`` using the parameters recorded in the forward pass.

        Args:
            input: Payload of type ``T`` to invert.
            module: Any ``nn.Module``; each implementation decides which module types affect ``T``.
            param: Parameters recorded for ``module``.
            extra_args: Optional keyword overrides specific to this payload type.

        Returns:
            The inverse-transformed payload.
        """
        raise NotImplementedError


class AugmentationSequentialOps:
    """Route the inputs of an augmentation container to per-data-key operations.

    Every positional input given to :meth:`transform` / :meth:`inverse` is paired with a
    :class:`DataKey` (e.g. IMAGE, MASK, BBOX, KEYPOINTS, CLASS). It is then handled by the matching
    :class:`SequentialOpsInterface` implementation returned by :meth:`_get_op`.

    Args:
        data_keys: Default keys used when a call does not pass its own. Values accepted by
            ``DataKey.get`` (e.g. strings) are resolved to :class:`DataKey` members. ``None``
            means every call must provide ``data_keys``.
    """

    def __init__(self, data_keys: Optional[List[DataKey]]) -> None:
        # Resolve non-DataKey entries up front, as the ``data_keys`` setter does. Previously they were
        # stored verbatim and could fail to match in ``_get_op`` (e.g. a string key). A list that
        # already holds only DataKey members is kept as the same object, so existing callers see no change.
        if data_keys is not None and not all(isinstance(key, DataKey) for key in data_keys):
            data_keys = [DataKey.get(key) for key in data_keys]
        self._data_keys = data_keys

    @property
    def data_keys(self) -> Optional[List[DataKey]]:
        """Return the default data keys (``None`` when unset)."""
        return self._data_keys

    @data_keys.setter
    def data_keys(self, data_keys: Optional[Union[List[DataKey], List[str], List[int]]]) -> None:
        # Normalize to DataKey members; both ``None`` and an empty list clear the defaults.
        if data_keys:
            self._data_keys = [DataKey.get(inp) for inp in data_keys]
        else:
            self._data_keys = None

    def preproc_datakeys(self, data_keys: Optional[Union[List[str], List[int], List[DataKey]]] = None) -> List[DataKey]:
        """Normalize user-provided data keys into :class:`DataKey` values.

        Args:
            data_keys: Optional keys passed by the caller. If omitted, the instance-level
                ``self.data_keys`` are returned (the stored list itself, not a copy).

        Returns:
            Normalized list of data keys.

        Raises:
            ValueError: Neither argument nor instance-level keys are available.
        """
        if data_keys is None:
            if isinstance(self.data_keys, list):
                return self.data_keys
            raise ValueError("nn.Sequential ops needs data keys to be able to process.")
        else:
            return [DataKey.get(inp) for inp in data_keys]

    def _get_op(self, data_key: DataKey) -> Type[SequentialOpsInterface[Any]]:
        """Return the ops class that handles ``data_key``.

        Raises:
            RuntimeError: No ops class is associated with ``data_key``.
        """
        if data_key == DataKey.INPUT:
            return InputSequentialOps
        if data_key == DataKey.MASK:
            return MaskSequentialOps
        if data_key in {DataKey.BBOX, DataKey.BBOX_XYWH, DataKey.BBOX_XYXY}:
            return BoxSequentialOps
        if data_key == DataKey.KEYPOINTS:
            return KeypointSequentialOps
        if data_key == DataKey.CLASS:
            return ClassSequentialOps
        raise RuntimeError(f"Operation for `{data_key.name}` is not found.")

    def transform(
        self,
        *arg: DataType,
        module: nn.Module,
        param: ParamItem,
        extra_args: Dict[DataKey, Dict[str, Any]],
        data_keys: Optional[Union[List[str], List[int], List[DataKey]]] = None,
    ) -> Union[DataType, SequenceDataType]:
        """Apply one module to all inputs according to their data keys.

        Args:
            *arg: Inputs to transform (image, mask, boxes, keypoints, ...), one per data key.
            module: Module to execute.
            param: Parameters associated with ``module``.
            extra_args: Per-data-key keyword overrides; a key that is absent gets ``{}``.
            data_keys: Optional key override for this call; defaults to ``self.data_keys``.

        Returns:
            The transformed value itself when exactly one output is produced, otherwise a list
            of outputs in input order.

        Raises:
            ValueError: No data keys are available.
            RuntimeError: A data key has no associated operation.
        """
        _data_keys = self.preproc_datakeys(data_keys)

        if isinstance(module, K.RandomTransplantation):
            # RandomTransplantation needs the full set of inputs to compute its parameters, so they are
            # recomputed here and re-wrapped under the original parameter name.
            param = ParamItem(
                name=param.name,
                data=module.params_from_input(
                    *arg,  # type: ignore[arg-type]
                    data_keys=_data_keys,
                    params=param.data,  # type: ignore[arg-type]
                    extra_args=extra_args,
                ),
            )

        outputs = []
        # NOTE: ``zip`` stops at the shorter of ``arg`` and ``_data_keys``; unmatched entries are dropped.
        for inp, dcate in zip(arg, _data_keys):
            op = self._get_op(dcate)
            extra_arg = extra_args.get(dcate, {})
            if dcate == DataKey.MASK and isinstance(inp, list):
                # A list of mask tensors is transformed item by item (item ``i`` uses ``batch_prob[i]``).
                outputs.append(MaskSequentialOps.transform_list(inp, module, param=param, extra_args=extra_arg))
            else:
                outputs.append(op.transform(inp, module, param=param, extra_args=extra_arg))
        # ``outputs`` is always a list, so a single result is simply unwrapped.
        return outputs[0] if len(outputs) == 1 else outputs

    def inverse(
        self,
        *arg: DataType,
        module: nn.Module,
        param: ParamItem,
        extra_args: Dict[DataKey, Dict[str, Any]],
        data_keys: Optional[Union[List[str], List[int], List[DataKey]]] = None,
    ) -> Union[DataType, SequenceDataType]:
        """Apply inverse transformation dispatch for one module step.

        Args:
            *arg: Inputs to invert, one per data key.
            module: Module used in the forward pass.
            param: Parameters captured for ``module`` during forward.
            extra_args: Per-data-key keyword overrides; a key that is absent gets ``{}``.
            data_keys: Optional key override for this call; defaults to ``self.data_keys``.

        Returns:
            The inverse-transformed value itself when exactly one output is produced, otherwise
            a list of outputs in input order.

        Raises:
            ValueError: No data keys are available.
            RuntimeError: A data key has no associated operation.
        """
        _data_keys = self.preproc_datakeys(data_keys)
        outputs = []
        # NOTE: ``zip`` stops at the shorter of ``arg`` and ``_data_keys``; unmatched entries are dropped.
        for inp, dcate in zip(arg, _data_keys):
            op = self._get_op(dcate)
            extra_arg = extra_args.get(dcate, {})
            outputs.append(op.inverse(inp, module, param=param, extra_args=extra_arg))
        return outputs[0] if len(outputs) == 1 else outputs


P = ParamSpec("P")


def make_input_only_sequential(module: "K.container.ImageSequentialBase") -> Callable[P, torch.Tensor]:
    """Disable all other additional inputs (e.g. ) for ImageSequential."""

    def f(*args: P.args, **kwargs: P.kwargs) -> torch.Tensor:
        return module(*args, **kwargs)

    return f


def get_geometric_only_param(module: "K.container.ImageSequentialBase", param: List[ParamItem]) -> List[ParamItem]:
    """Keep only the parameters that belong to geometric augmentation steps.

    The ``(name, submodule)`` pairs from ``module.get_forward_sequence(param)`` are matched
    positionally with ``param``. An item is kept when its submodule is a 2D or 3D geometric
    augmentation.

    Args:
        module: Sequential container the parameters were recorded for.
        param: Per-step parameters recorded for ``module``.

    Returns:
        The geometric subset of ``param``, in its original order.
    """
    forward_sequence = module.get_forward_sequence(param)
    geometric_bases = (K.GeometricAugmentationBase2D, K.GeometricAugmentationBase3D)
    return [
        item
        for (_, submodule), item in zip(forward_sequence, param)
        if isinstance(submodule, geometric_bases)
    ]


class InputSequentialOps(SequentialOpsInterface[torch.Tensor]):
    """Implement the operations for processing input tensors within a sequential container.

    This class provides class methods to apply transformations and manage the
    flow of data through the augmentation pipeline.
    """

    @classmethod
    def transform(
        cls, input: torch.Tensor, module: nn.Module, param: ParamItem, extra_args: Optional[Dict[str, Any]] = None
    ) -> torch.Tensor:
        """Apply one module step to an image tensor.

        Args:
            input: Input image tensor.
            module: Module to execute.
            param: Parameters for ``module``.
            extra_args: Optional runtime overrides. They are forwarded to augmentation modules and
                sequential containers, but not to auto operations or plain modules.

        Returns:
            Transformed tensor.

        Raises:
            AssertionError: A module that is not an augmentation, container or auto operation
                receives non-``None`` params.
        """
        if extra_args is None:
            extra_args = {}
        if isinstance(module, (_AugmentationBase, K.MixAugmentationBaseV2)):
            input = module(input, params=cls.get_instance_module_param(param), data_keys=[DataKey.INPUT], **extra_args)
        elif isinstance(module, (K.container.ImageSequentialBase,)):
            input = module.transform_inputs(input, params=cls.get_sequential_module_param(param), extra_args=extra_args)
        elif isinstance(module, (K.auto.operations.OperationBase,)):
            input = module(input, params=cls.get_instance_module_param(param))
        else:
            # Plain modules have no sampled parameters, so their ParamItem must be empty.
            if param.data is not None:
                raise AssertionError(f"Non-augmentaion operation {param.name} require empty parameters. Got {param}.")
            input = module(input)
        return input

    @classmethod
    def inverse(
        cls, input: torch.Tensor, module: nn.Module, param: ParamItem, extra_args: Optional[Dict[str, Any]] = None
    ) -> torch.Tensor:
        """Apply one inverse module step to an image tensor.

        Only 2D geometric augmentations, sequential containers and auto operations wrapping one
        of those are inverted. Any other module (e.g. an intensity augmentation) leaves ``input``
        unchanged.

        Args:
            input: Tensor to invert.
            module: Module used in the forward path.
            param: Forward parameters for ``module``.
            extra_args: Optional runtime overrides.

        Returns:
            Inverse-transformed tensor.

        Raises:
            NotImplementedError: Inverse for 3D geometric ops is not supported.
        """
        if extra_args is None:
            extra_args = {}
        if isinstance(module, K.GeometricAugmentationBase2D):
            input = module.inverse(input, params=cls.get_instance_module_param(param), **extra_args)
        elif isinstance(module, (K.GeometricAugmentationBase3D,)):
            raise NotImplementedError(
                "The support for 3d inverse operations are not yet supported. You are welcome to file a PR in our repo."
            )
        elif isinstance(module, (K.auto.operations.OperationBase,)):
            # Auto operations delegate to the module they wrap (``module.op``) with the same params.
            return InputSequentialOps.inverse(input, module=module.op, param=param, extra_args=extra_args)
        elif isinstance(module, K.container.ImageSequentialBase):
            # ``K.ImageSequential`` derives from ImageSequentialBase. The former separate branch for
            # non-intensity-only ImageSequential made this exact call, so it is covered here.
            input = module.inverse_inputs(input, params=cls.get_sequential_module_param(param), extra_args=extra_args)
        return input


class ClassSequentialOps(SequentialOpsInterface[torch.Tensor]):
    """Forward/inverse dispatch for class-label tensors, which pass through untouched."""

    @classmethod
    def transform(
        cls, input: torch.Tensor, module: nn.Module, param: ParamItem, extra_args: Optional[Dict[str, Any]] = None
    ) -> torch.Tensor:
        """Pass class labels through one forward module step.

        Args:
            input: Class-label tensor.
            module: Module being executed.
            param: Parameters for ``module`` (unused).
            extra_args: Optional runtime overrides (unused).

        Returns:
            ``input`` unchanged.

        Raises:
            NotImplementedError: ``module`` is a mix augmentation (``K.MixAugmentationBaseV2``).
        """
        is_mix_augmentation = isinstance(module, K.MixAugmentationBaseV2)
        if not is_mix_augmentation:
            return input
        raise NotImplementedError(
            "The support for class labels for mix augmentations that change the class label is not yet supported."
        )

    @classmethod
    def inverse(
        cls, input: torch.Tensor, module: nn.Module, param: ParamItem, extra_args: Optional[Dict[str, Any]] = None
    ) -> torch.Tensor:
        """Pass class labels through one inverse module step.

        Class labels have no geometric inverse here, so nothing is undone.

        Args:
            input: Class-label tensor.
            module: Module used in the forward pass (unused).
            param: Forward parameters (unused).
            extra_args: Optional runtime overrides (unused).

        Returns:
            ``input`` unchanged.
        """
        return input


class MaskSequentialOps(SequentialOpsInterface[torch.Tensor]):
    """Apply and inverse transformations for mask tensors."""

    @classmethod
    def transform(
        cls, input: torch.Tensor, module: nn.Module, param: ParamItem, extra_args: Optional[Dict[str, Any]] = None
    ) -> torch.Tensor:
        """Apply one module step to a mask tensor.

        Args:
            input: Mask tensor.
            module: any torch nn.Module but only kornia augmentation modules, sequential containers
                and auto operations affect the mask.
            param: the corresponding parameters to the module.
            extra_args: Optional dictionary of extra arguments with specific options for different input types.

        Returns:
            The transformed mask, or ``input`` unchanged when ``module`` matches none of the handled types.

        Raises:
            NotImplementedError: ``module`` is a 3D geometric augmentation.
        """
        if extra_args is None:
            extra_args = {}

        if isinstance(module, (K.GeometricAugmentationBase2D,)):
            input = module.transform_masks(
                input,
                params=cls.get_instance_module_param(param),
                flags=module.flags,
                transform=module.transform_matrix,
                **extra_args,
            )

        elif isinstance(module, (K.GeometricAugmentationBase3D,)):
            raise NotImplementedError(
                "The support for 3d mask operations are not yet supported. You are welcome to file a PR in our repo."
            )

        elif isinstance(module, K.RandomTransplantation):
            # Runs through the module's forward with ``data_keys=[DataKey.MASK]`` rather than
            # ``transform_masks``; keep this check ahead of the generic ``_AugmentationBase`` branch.
            input = module(input, params=cls.get_instance_module_param(param), data_keys=[DataKey.MASK], **extra_args)

        elif isinstance(module, (_AugmentationBase)):
            input = module.transform_masks(
                input, params=cls.get_instance_module_param(param), flags=module.flags, **extra_args
            )

        elif isinstance(module, K.container.ImageSequentialBase):
            # Also covers ``K.ImageSequential`` (a subclass); its former dedicated branch made this same call.
            input = module.transform_masks(input, params=cls.get_sequential_module_param(param), extra_args=extra_args)

        elif isinstance(module, (K.auto.operations.OperationBase,)):
            input = MaskSequentialOps.transform(input, module=module.op, param=param, extra_args=extra_args)

        return input

    @classmethod
    def transform_list(
        cls, input: List[torch.Tensor], module: nn.Module, param: ParamItem, extra_args: Optional[Dict[str, Any]] = None
    ) -> List[torch.Tensor]:
        """Apply one module step to a list of mask tensors, item by item.

        For augmentation modules, item ``i`` is transformed with a copy of the module parameters
        whose ``"batch_prob"`` entry is replaced by ``params["batch_prob"][i]``. All other entries
        are passed as recorded.

        Args:
            input: list of input tensors.
            module: any torch nn.Module but only kornia augmentation modules and sequential
                containers affect the masks.
            param: the corresponding parameters to the module.
            extra_args: Optional dictionary of extra arguments with specific options for different input types.

        Returns:
            A new list of transformed masks, or ``input`` itself when ``module`` matches none of the
            handled types.

        Raises:
            NotImplementedError: ``module`` is a 3D geometric augmentation or an auto operation.
        """
        if extra_args is None:
            extra_args = {}
        if isinstance(module, (K.GeometricAugmentationBase2D,)):
            tfm_input = []
            params = cls.get_instance_module_param(param)
            # Deep-copied once so the per-item ``batch_prob`` override (and anything the module might
            # write into the dict) never touches the recorded params, which other data keys also use.
            params_i = copy.deepcopy(params)
            for i, inp in enumerate(input):
                params_i["batch_prob"] = params["batch_prob"][i]
                tfm_inp = module.transform_masks(
                    inp, params=params_i, flags=module.flags, transform=module.transform_matrix, **extra_args
                )
                tfm_input.append(tfm_inp)
            input = tfm_input

        elif isinstance(module, (K.GeometricAugmentationBase3D,)):
            raise NotImplementedError(
                "The support for 3d mask operations are not yet supported. You are welcome to file a PR in our repo."
            )

        elif isinstance(module, (_AugmentationBase)):
            tfm_input = []
            params = cls.get_instance_module_param(param)
            params_i = copy.deepcopy(params)
            for i, inp in enumerate(input):
                params_i["batch_prob"] = params["batch_prob"][i]
                tfm_inp = module.transform_masks(inp, params=params_i, flags=module.flags, **extra_args)
                tfm_input.append(tfm_inp)
            input = tfm_input

        elif isinstance(module, K.container.ImageSequentialBase):
            # Also covers ``K.ImageSequential`` (a subclass); its former dedicated branch made this same call.
            seq_params = cls.get_sequential_module_param(param)
            input = [module.transform_masks(inp, params=seq_params, extra_args=extra_args) for inp in input]

        elif isinstance(module, (K.auto.operations.OperationBase,)):
            raise NotImplementedError(
                "The support for list of masks under auto operations are not yet supported. You are welcome to file a"
                " PR in our repo."
            )
        return input

    @classmethod
    def inverse(
        cls, input: torch.Tensor, module: nn.Module, param: ParamItem, extra_args: Optional[Dict[str, Any]] = None
    ) -> torch.Tensor:
        """Undo one module step on a mask tensor.

        Only 2D geometric augmentations, sequential containers and auto operations wrapping one
        of those are inverted. Any other module leaves ``input`` unchanged.

        Args:
            input: Mask tensor to invert.
            module: Module used in the forward pass.
            param: the corresponding parameters to the module.
            extra_args: Optional dictionary of extra arguments with specific options for different input types.

        Returns:
            The inverse-transformed mask.

        Raises:
            ValueError: A 2D geometric module has no ``transform_matrix`` to invert.
            NotImplementedError: ``module`` is a 3D geometric augmentation.
        """
        if extra_args is None:
            extra_args = {}

        if isinstance(module, (K.GeometricAugmentationBase2D,)):
            if module.transform_matrix is None:
                raise ValueError(f"No valid transformation matrix found in {module.__class__}.")
            transform = module.compute_inverse_transformation(module.transform_matrix)
            input = module.inverse_masks(
                input,
                params=cls.get_instance_module_param(param),
                flags=module.flags,
                transform=transform,
                **extra_args,
            )

        elif isinstance(module, (K.GeometricAugmentationBase3D,)):
            raise NotImplementedError(
                "The support for 3d mask operations are not yet supported. You are welcome to file a PR in our repo."
            )

        elif isinstance(module, K.container.ImageSequentialBase):
            input = module.inverse_masks(input, params=cls.get_sequential_module_param(param), extra_args=extra_args)

        elif isinstance(module, (K.auto.operations.OperationBase,)):
            input = MaskSequentialOps.inverse(input, module=module.op, param=param, extra_args=extra_args)

        return input


class BoxSequentialOps(SequentialOpsInterface[Boxes]):
    """Apply and inverse transformations for bounding box tensors.

    This is for transform boxes in the format (B, N, 4, 2).
    """

    @classmethod
    def transform(
        cls, input: Boxes, module: nn.Module, param: ParamItem, extra_args: Optional[Dict[str, Any]] = None
    ) -> Boxes:
        """Apply one module step to boxes.

        The boxes are cloned first, so the caller's ``input`` is not modified.

        Args:
            input: Boxes to transform, laid out as (B, N, 4, 2) or (B, 4, 2).
            module: any torch nn.Module but only kornia 2D geometric augmentations, sequential
                containers and auto operations affect the boxes.
            param: the corresponding parameters to the module.
            extra_args: Optional dictionary of extra arguments with specific options for different input types.

        Returns:
            The transformed boxes; a clone of ``input`` when ``module`` matches none of the handled types.

        Raises:
            NotImplementedError: ``module`` is a 3D geometric augmentation.
        """
        if extra_args is None:
            extra_args = {}
        _input = input.clone()

        if isinstance(module, (K.GeometricAugmentationBase2D,)):
            _input = module.transform_boxes(
                _input,
                cls.get_instance_module_param(param),
                module.flags,
                transform=module.transform_matrix,
                **extra_args,
            )

        elif isinstance(module, (K.GeometricAugmentationBase3D,)):
            raise NotImplementedError(
                "The support for 3d box operations are not yet supported. You are welcome to file a PR in our repo."
            )

        elif isinstance(module, K.container.ImageSequentialBase):
            # Also covers ``K.ImageSequential`` (a subclass); its former dedicated branch made this same call.
            _input = module.transform_boxes(
                _input, params=cls.get_sequential_module_param(param), extra_args=extra_args
            )

        elif isinstance(module, (K.auto.operations.OperationBase,)):
            return BoxSequentialOps.transform(input, module=module.op, param=param, extra_args=extra_args)

        return _input

    @classmethod
    def inverse(
        cls, input: Boxes, module: nn.Module, param: ParamItem, extra_args: Optional[Dict[str, Any]] = None
    ) -> Boxes:
        """Undo one module step on boxes.

        The boxes are cloned first, so the caller's ``input`` is not modified.

        Args:
            input: Boxes to invert.
            module: Module used in the forward pass.
            param: the corresponding parameters to the module.
            extra_args: Optional dictionary of extra arguments with specific options for different input types.

        Returns:
            The inverse-transformed boxes; a clone of ``input`` when ``module`` matches none of the
            handled types.

        Raises:
            ValueError: A 2D geometric module has no ``transform_matrix`` to invert.
            NotImplementedError: ``module`` is a 3D geometric augmentation.
            TypeError: ``param.data`` of a 2D geometric module is not a dictionary.
        """
        if extra_args is None:
            extra_args = {}
        _input = input.clone()

        if isinstance(module, (K.GeometricAugmentationBase2D,)):
            if module.transform_matrix is None:
                raise ValueError(f"No valid transformation matrix found in {module.__class__}.")
            transform = module.compute_inverse_transformation(module.transform_matrix)
            # Unwrap params through the validated accessor, as ``transform`` does.
            _input = module.inverse_boxes(
                _input,
                cls.get_instance_module_param(param),
                module.flags,
                transform=transform,
                **extra_args,
            )

        elif isinstance(module, (K.GeometricAugmentationBase3D,)):
            raise NotImplementedError(
                "The support for 3d box operations are not yet supported. You are welcome to file a PR in our repo."
            )

        elif isinstance(module, K.container.ImageSequentialBase):
            # Also covers ``K.ImageSequential`` (a subclass); its former dedicated branch made this same call.
            _input = module.inverse_boxes(_input, params=cls.get_sequential_module_param(param), extra_args=extra_args)

        elif isinstance(module, (K.auto.operations.OperationBase,)):
            return BoxSequentialOps.inverse(input, module=module.op, param=param, extra_args=extra_args)
        return _input


class KeypointSequentialOps(SequentialOpsInterface[Keypoints]):
    """Apply and inverse transformations for keypoints tensors.

    This is for transform keypoints in the format (B, N, 2).
    """

    @classmethod
    def transform(
        cls, input: Keypoints, module: nn.Module, param: ParamItem, extra_args: Optional[Dict[str, Any]] = None
    ) -> Keypoints:
        """Apply one module step to keypoints.

        The keypoints are cloned first, so the caller's ``input`` is not modified.

        Args:
            input: Keypoints to transform, laid out as (B, N, 2).
            module: any torch nn.Module but only kornia 2D geometric augmentations, sequential
                containers and auto operations affect the keypoints.
            param: the corresponding parameters to the module.
            extra_args: Optional dictionary of extra arguments with specific options for different input types.

        Returns:
            The transformed keypoints; a clone of ``input`` when ``module`` matches none of the handled types.

        Raises:
            NotImplementedError: ``module`` is a 3D geometric augmentation.
        """
        if extra_args is None:
            extra_args = {}
        _input = input.clone()

        if isinstance(module, (K.GeometricAugmentationBase2D,)):
            _input = module.transform_keypoints(
                _input,
                cls.get_instance_module_param(param),
                module.flags,
                transform=module.transform_matrix,
                **extra_args,
            )

        elif isinstance(module, (K.GeometricAugmentationBase3D,)):
            raise NotImplementedError(
                "The support for 3d keypoint operations are not yet supported. "
                "You are welcome to file a PR in our repo."
            )

        elif isinstance(module, K.container.ImageSequentialBase):
            # Also covers ``K.ImageSequential`` (a subclass); its former dedicated branch made this same call.
            _input = module.transform_keypoints(
                _input, params=cls.get_sequential_module_param(param), extra_args=extra_args
            )

        elif isinstance(module, (K.auto.operations.OperationBase,)):
            return KeypointSequentialOps.transform(input, module=module.op, param=param, extra_args=extra_args)

        return _input

    @classmethod
    def inverse(
        cls, input: Keypoints, module: nn.Module, param: ParamItem, extra_args: Optional[Dict[str, Any]] = None
    ) -> Keypoints:
        """Undo one module step on keypoints.

        The keypoints are cloned first, so the caller's ``input`` is not modified.

        Args:
            input: Input keypoints. Coordinates are conceptually stored as
                ``(B, N, 2)``, where the last dimension stores ``(x, y)``.
            module: Module used in the forward pass.
            param: the corresponding parameters to the module.
            extra_args: Optional dictionary of extra arguments with specific options for different input types.

        Returns:
            Keypoints after inverse transformation; a clone of ``input`` when ``module`` matches none
            of the handled types.

        Raises:
            ValueError: A 2D geometric module has no ``transform_matrix`` to invert.
            NotImplementedError: ``module`` is a 3D geometric augmentation.
        """
        if extra_args is None:
            extra_args = {}
        _input = input.clone()

        if isinstance(module, (K.GeometricAugmentationBase2D,)):
            if module.transform_matrix is None:
                raise ValueError(f"No valid transformation matrix found in {module.__class__}.")
            transform = module.compute_inverse_transformation(module.transform_matrix)
            _input = module.inverse_keypoints(
                _input, cls.get_instance_module_param(param), module.flags, transform=transform, **extra_args
            )

        elif isinstance(module, (K.GeometricAugmentationBase3D,)):
            raise NotImplementedError(
                "The support for 3d keypoint operations are not yet supported. "
                "You are welcome to file a PR in our repo."
            )

        elif isinstance(module, K.container.ImageSequentialBase):
            # Also covers ``K.ImageSequential`` (a subclass); its former dedicated branch made this same call.
            _input = module.inverse_keypoints(
                _input, params=cls.get_sequential_module_param(param), extra_args=extra_args
            )

        elif isinstance(module, (K.auto.operations.OperationBase,)):
            return KeypointSequentialOps.inverse(input, module=module.op, param=param, extra_args=extra_args)

        return _input
