# LICENSE HEADER MANAGED BY add-license-header
#
# Copyright 2018 Kornia Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import warnings
from typing import Any, Dict, List, Literal, Optional, Sequence, Tuple, Union, cast

import torch
from torch import nn

from kornia.augmentation._2d.base import RigidAffineAugmentationBase2D
from kornia.augmentation._3d.base import AugmentationBase3D, RigidAffineAugmentationBase3D
from kornia.augmentation.base import _AugmentationBase
from kornia.constants import DataKey, Resample
from kornia.core.ops import eye_like
from kornia.core.utils import is_autocast_enabled
from kornia.geometry.boxes import Boxes, VideoBoxes
from kornia.geometry.keypoints import Keypoints, VideoKeypoints

from .base import TransformMatrixMinIn
from .image import ImageSequential
from .ops import AugmentationSequentialOps, DataType
from .params import ParamItem
from .patch import PatchSequential
from .video import VideoSequential

__all__ = ["AugmentationSequential"]

# Groups of data keys that share the same pre-/post-processing path in `AugmentationSequential`.
# They are frozen so these module-level lookup tables cannot be mutated by accident.
_BOXES_OPTIONS = frozenset({DataKey.BBOX, DataKey.BBOX_XYXY, DataKey.BBOX_XYWH})
_KEYPOINTS_OPTIONS = frozenset({DataKey.KEYPOINTS})
_IMG_OPTIONS = frozenset({DataKey.INPUT, DataKey.IMAGE})
_MSK_OPTIONS = frozenset({DataKey.MASK})
_CLS_OPTIONS = frozenset({DataKey.CLASS, DataKey.LABEL})

# A mask input is either a single tensor or a list of tensors (one per sample).
MaskDataType = Union[torch.Tensor, List[torch.Tensor]]


class AugmentationSequential(TransformMatrixMinIn, ImageSequential):
    r"""AugmentationSequential for handling multiple input types like inputs, masks, keypoints at once.

    .. image:: _static/img/AugmentationSequential.png

    Args:
        *args: a list of kornia augmentation modules.

        data_keys: the data type of each positional input, in order. Accepts "input", "image", "mask",
                   "bbox", "bbox_xyxy", "bbox_xywh", "keypoints", "class", "label". The first key must denote
                   the image input (``DataKey.INPUT``). If None, the data must be passed as a single
                   dictionary whose keys start with one of these names.

        same_on_batch: apply the same transformation across the batch. If None, it will not overwrite the function-wise
                       settings.

        keepdim: whether to keep the output shape the same as input (True) or broadcast it to the batch form (False).
                 If None, it will not overwrite the function-wise settings.

        random_apply: randomly select a sublist (order agnostic) of args to apply transformation.
                      If int, a fixed number of transformations will be selected.
                      If (a,), x number of transformations (a <= x <= len(args)) will be selected.
                      If (a, b), x number of transformations (a <= x <= b) will be selected.
                      If True, the whole list of args will be processed as a sequence in a random order.
                      If False, the whole list of args will be processed as a sequence in original order.

        random_apply_weights: sampling weights for the modules in ``args``, used together with ``random_apply``
                              when a random subset of modules is selected.

        transformation_matrix_mode: computation mode for the chained transformation matrix, via `.transform_matrix`
                                    attribute.
                                    If `silent`, transformation matrix will be computed silently and the non-rigid
                                    modules will be ignored as identity transformations.
                                    If `rigid`, transformation matrix will be computed silently and the non-rigid
                                    modules will trigger errors.
                                    If `skip`, transformation matrix will be totally ignored.

        extra_args: per-data-key keyword arguments forwarded to the underlying operations. By default, masks are
                    resampled with nearest-neighbour interpolation.

    .. note::
        Mix augmentations (e.g. RandomMixUp, RandomCutMix) only work with the "input"/"image" data key.
        It is not clear how masks, bounding boxes and keypoints should be converted for them.

    .. note::
        See a working example `here <https://kornia.github.io/tutorials/nbs/data_augmentation_sequential.html>`__.

    Examples:
        >>> import kornia
        >>> input = torch.randn(2, 3, 5, 6)
        >>> mask = torch.ones(2, 3, 5, 6)
        >>> bbox = torch.tensor([[
        ...     [1., 1.],
        ...     [2., 1.],
        ...     [2., 2.],
        ...     [1., 2.],
        ... ]]).expand(2, 1, -1, -1)
        >>> points = torch.tensor([[[1., 1.]]]).expand(2, -1, -1)
        >>> aug_list = AugmentationSequential(
        ...     kornia.augmentation.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=1.0),
        ...     kornia.augmentation.RandomAffine(360, p=1.0),
        ...     data_keys=["input", "mask", "bbox", "keypoints"],
        ...     same_on_batch=False,
        ...     random_apply=10,
        ... )
        >>> out = aug_list(input, mask, bbox, points)
        >>> [o.shape for o in out]
        [torch.Size([2, 3, 5, 6]), torch.Size([2, 3, 5, 6]), torch.Size([2, 1, 4, 2]), torch.Size([2, 1, 2])]
        >>> # apply the exact augmentation again.
        >>> out_rep = aug_list(input, mask, bbox, points, params=aug_list._params)
        >>> [(o == o_rep).all() for o, o_rep in zip(out, out_rep)]
        [tensor(True), tensor(True), tensor(True), tensor(True)]
        >>> # inverse the augmentations
        >>> out_inv = aug_list.inverse(*out)
        >>> [o.shape for o in out_inv]
        [torch.Size([2, 3, 5, 6]), torch.Size([2, 3, 5, 6]), torch.Size([2, 1, 4, 2]), torch.Size([2, 1, 2])]

    This example demonstrates the integration of VideoSequential and AugmentationSequential.

        >>> import kornia
        >>> input = torch.randn(2, 3, 5, 6)[None]
        >>> mask = torch.ones(2, 3, 5, 6)[None]
        >>> bbox = torch.tensor([[
        ...     [1., 1.],
        ...     [2., 1.],
        ...     [2., 2.],
        ...     [1., 2.],
        ... ]]).expand(2, 1, -1, -1)[None]
        >>> points = torch.tensor([[[1., 1.]]]).expand(2, -1, -1)[None]
        >>> aug_list = AugmentationSequential(
        ...     VideoSequential(
        ...         kornia.augmentation.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=1.0),
        ...         kornia.augmentation.RandomAffine(360, p=1.0),
        ...     ),
        ...     data_keys=["input", "mask", "bbox", "keypoints"]
        ... )
        >>> out = aug_list(input, mask, bbox, points)
        >>> [o.shape for o in out]  # doctest: +ELLIPSIS
        [torch.Size([1, 2, 3, 5, 6]), torch.Size([1, 2, 3, 5, 6]), ...([1, 2, 1, 4, 2]), torch.Size([1, 2, 1, 2])]

    Perform ``OneOf`` transformation with ``random_apply=1`` and ``random_apply_weights``
    in ``AugmentationSequential``.

        >>> import kornia
        >>> input = torch.randn(2, 3, 5, 6)[None]
        >>> mask = torch.ones(2, 3, 5, 6)[None]
        >>> bbox = torch.tensor([[
        ...     [1., 1.],
        ...     [2., 1.],
        ...     [2., 2.],
        ...     [1., 2.],
        ... ]]).expand(2, 1, -1, -1)[None]
        >>> points = torch.tensor([[[1., 1.]]]).expand(2, -1, -1)[None]
        >>> aug_list = AugmentationSequential(
        ...     VideoSequential(
        ...         kornia.augmentation.RandomAffine(360, p=1.0),
        ...     ),
        ...     VideoSequential(
        ...         kornia.augmentation.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=1.0),
        ...     ),
        ...     data_keys=["input", "mask", "bbox", "keypoints"],
        ...     random_apply=1,
        ...     random_apply_weights=[0.5, 0.3]
        ... )
        >>> out = aug_list(input, mask, bbox, points)
        >>> [o.shape for o in out]  # doctest: +ELLIPSIS
        [torch.Size([1, 2, 3, 5, 6]), torch.Size([1, 2, 3, 5, 6]), ...([1, 2, 1, 4, 2]), torch.Size([1, 2, 1, 2])]

    This example shows how to use a list of masks and boxes within AugmentationSequential

        >>> import kornia.augmentation as K
        >>> input = torch.randn(2, 3, 256, 256)
        >>> mask = [torch.ones(1, 3, 256, 256), torch.ones(1, 2, 256, 256)]
        >>> bbox = [
        ...    torch.tensor([[28.0, 53.0, 143.0, 164.0], [254.0, 158.0, 364.0, 290.0], [307.0, 204.0, 413.0, 350.0]]),
        ...    torch.tensor([[254.0, 158.0, 364.0, 290.0], [307.0, 204.0, 413.0, 350.0]])
        ... ]
        >>> bbox = [Boxes.from_tensor(i).data for i in bbox]

        >>> aug_list = K.AugmentationSequential(
        ...    K.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=1.0),
        ...    K.RandomHorizontalFlip(p=1.0),
        ...    K.ImageSequential(K.RandomHorizontalFlip(p=1.0)),
        ...    K.ImageSequential(K.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=1.0)),
        ...    data_keys=["input", "mask", "bbox"],
        ...    same_on_batch=False,
        ...    random_apply=10,
        ... )
        >>> out = aug_list(input, mask, bbox)

    How to use a dictionary as input with AugmentationSequential? The dictionary keys that start with
    one of the available datakeys will be augmented accordingly. Otherwise, the dictionary item is passed
    without any augmentation.

        >>> import kornia.augmentation as K
        >>> img = torch.randn(1, 3, 256, 256)
        >>> mask = [torch.ones(1, 3, 256, 256), torch.ones(1, 2, 256, 256)]
        >>> bbox = [
        ...    torch.tensor([[28.0, 53.0, 143.0, 164.0], [254.0, 158.0, 364.0, 290.0], [307.0, 204.0, 413.0, 350.0]]),
        ...    torch.tensor([[254.0, 158.0, 364.0, 290.0], [307.0, 204.0, 413.0, 350.0]])
        ... ]
        >>> bbox = [Boxes.from_tensor(i).data for i in bbox]
        >>> aug_dict = K.AugmentationSequential(
        ...    K.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=1.0),
        ...    K.RandomHorizontalFlip(p=1.0),
        ...    K.ImageSequential(K.RandomHorizontalFlip(p=1.0)),
        ...    K.ImageSequential(K.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=1.0)),
        ...    data_keys=None,
        ...    same_on_batch=False,
        ...    random_apply=10,
        ... )
        >>> data = {'image': img, 'mask': mask[0], 'mask-b': mask[1], 'bbox': bbox[0], 'bbox-other':bbox[1]}
        >>> out = aug_dict(data)
        >>> out.keys()
        dict_keys(['image', 'mask', 'mask-b', 'bbox', 'bbox-other'])

    """

    # dtype of the most recently seen image input; masks are cast to it before augmentation.
    input_dtype = None
    # dtype of the most recently seen mask input; augmented masks are cast back to it.
    mask_dtype = None

    def __init__(
        self,
        *args: Union[_AugmentationBase, ImageSequential],
        data_keys: Optional[Union[Sequence[str], Sequence[int], Sequence[DataKey]]] = (DataKey.INPUT,),
        same_on_batch: Optional[bool] = None,
        keepdim: Optional[bool] = None,
        random_apply: Union[int, bool, Tuple[int, int]] = False,
        random_apply_weights: Optional[List[float]] = None,
        transformation_matrix_mode: str = "silent",
        extra_args: Optional[Dict[DataKey, Dict[str, Any]]] = None,
    ) -> None:
        self._transform_matrix: Optional[torch.Tensor]
        self._transform_matrices: List[Optional[torch.Tensor]] = []

        super().__init__(
            *args,
            same_on_batch=same_on_batch,
            keepdim=keepdim,
            random_apply=random_apply,
            random_apply_weights=random_apply_weights,
        )

        self._parse_transformation_matrix_mode(transformation_matrix_mode)

        self._valid_ops_for_transform_computation: Tuple[Any, ...] = (
            RigidAffineAugmentationBase2D,
            RigidAffineAugmentationBase3D,
            AugmentationSequential,
        )

        self.data_keys: Optional[List[DataKey]]
        if data_keys is not None:
            self.data_keys = [DataKey.get(inp) for inp in data_keys]
        else:
            self.data_keys = data_keys

        if self.data_keys:
            if any(in_type not in DataKey for in_type in self.data_keys):
                raise AssertionError(f"`data_keys` must be in {DataKey}. Got {self.data_keys}.")

            if self.data_keys[0] != DataKey.INPUT:
                raise NotImplementedError(f"The first input must be {DataKey.INPUT}.")

        self.transform_op = AugmentationSequentialOps(self.data_keys)

        self.contains_video_sequential: bool = False
        self.contains_3d_augmentation: bool = False
        for arg in args:
            if isinstance(arg, PatchSequential) and not arg.is_intensity_only():
                warnings.warn(
                    "Geometric transformation detected in PatchSeqeuntial, which would break bbox, mask.", stacklevel=1
                )
            if isinstance(arg, VideoSequential):
                self.contains_video_sequential = True
            # NOTE: only for images are supported for 3D.
            if isinstance(arg, AugmentationBase3D):
                self.contains_3d_augmentation = True
        self._transform_matrix = None
        # Masks default to nearest-neighbour resampling so label values are not blended.
        self.extra_args = extra_args or {DataKey.MASK: {"resample": Resample.NEAREST, "align_corners": None}}

    def clear_state(self) -> None:
        """Reset cached params and transformation-matrix state."""
        self._reset_transform_matrix_state()
        return super().clear_state()

    def _update_transform_matrix_for_valid_op(self, module: nn.Module) -> None:
        """Record the transformation matrix of ``module`` for later chaining."""
        self._transform_matrices.append(module.transform_matrix)

    def identity_matrix(self, input: torch.Tensor) -> torch.Tensor:
        """Return identity matrix (4x4 when 3D augmentations are present, 3x3 otherwise)."""
        if self.contains_3d_augmentation:
            return eye_like(4, input)
        else:
            return eye_like(3, input)

    def inverse(  # type: ignore[override]
        self,
        *args: Union[DataType, Dict[str, DataType]],
        params: Optional[List[ParamItem]] = None,
        data_keys: Optional[Union[List[str], List[int], List[DataKey]]] = None,
    ) -> Union[DataType, List[DataType], Dict[str, DataType]]:
        """Reverse the transformation applied.

        Number of input tensors must align with the number of ``data_keys``. If ``data_keys`` is not set, use
        ``self.data_keys`` by default.

        Args:
            args: the augmented inputs to invert, or a single dictionary whose keys start with a data key name.
            params: the parameters to invert. Defaults to the parameters stored by the last forward pass.
            data_keys: data keys overriding ``self.data_keys`` for this call only.

        Returns:
            A dictionary when a dictionary was given, the single value when there is only one output,
            otherwise a list of outputs.

        Raises:
            ValueError: if ``params`` is not given and no parameters are stored from a forward pass.
        """
        original_keys = None
        invalid_data = None
        if len(args) == 1 and isinstance(args[0], dict):
            original_keys, data_keys, args, invalid_data = self._preproc_dict_data(cast(Dict[str, DataType], args[0]))

        # args here should already be `DataType`
        # NOTE: how to right type to: unpacked args <-> tuple of args to unpack
        # issue with `self._preproc_dict_data` return args type

        self.transform_op.data_keys = self.transform_op.preproc_datakeys(data_keys)
        try:
            self._validate_args_datakeys(*args, data_keys=self.transform_op.data_keys)  # type: ignore

            in_args = self._arguments_preproc(*args, data_keys=self.transform_op.data_keys)  # type: ignore

            if params is None:
                if self._params is None:
                    raise ValueError(
                        "No parameters available for inversing, please run a forward pass first "
                        "or passing valid params into this function."
                    )
                params = self._params

            outputs: List[DataType] = in_args
            # Undo the modules in the reverse order they were applied.
            for param in params[::-1]:
                module = self.get_submodule(param.name)
                outputs = self.transform_op.inverse(  # type: ignore
                    *outputs, module=module, param=param, extra_args=self.extra_args
                )
                if not isinstance(outputs, list | tuple):
                    # Make sure we are unpacking a list whilst post-proc
                    outputs = [outputs]

            outputs = self._arguments_postproc(args, outputs, data_keys=self.transform_op.data_keys)  # type: ignore
        finally:
            # Restore the constructor data keys so a per-call override does not leak into later calls.
            self.transform_op.data_keys = self.data_keys

        if isinstance(original_keys, tuple):
            result = {k: v for v, k in zip(outputs, original_keys)}
            if invalid_data:
                result.update(invalid_data)
            return result

        if len(outputs) == 1 and isinstance(outputs, list):
            return outputs[0]

        return outputs

    def _validate_args_datakeys(self, *args: DataType, data_keys: List[DataKey]) -> None:
        """Raise ``AssertionError`` unless there is exactly one data key per input."""
        if len(args) != len(data_keys):
            raise AssertionError(
                f"The number of inputs must align with the number of data_keys. Got {len(args)} and {len(data_keys)}."
            )
        # TODO: validate args batching, and its consistency

    def _arguments_preproc(self, *args: DataType, data_keys: List[DataKey]) -> List[DataType]:
        """Convert the raw inputs into the representation expected by the transform ops.

        Images and class labels are passed through; masks are cast by ``_preproc_mask``; boxes and keypoints
        are wrapped into ``Boxes``/``Keypoints`` objects. The image and mask dtypes are recorded so that
        ``_arguments_postproc`` can restore them.

        Raises:
            NotImplementedError: for an unsupported data key.
        """
        inp: List[DataType] = []
        for arg, dcate in zip(args, data_keys):
            key = DataKey.get(dcate)
            if key in _IMG_OPTIONS:
                arg = cast(torch.Tensor, arg)
                self.input_dtype = arg.dtype
                inp.append(arg)
            elif key in _MSK_OPTIONS:
                # Remember the caller's mask dtype; a mask may be one tensor or a list of tensors.
                if isinstance(arg, list):
                    arg = cast(List[torch.Tensor], arg)
                    self.mask_dtype = arg[0].dtype
                else:
                    arg = cast(torch.Tensor, arg)
                    self.mask_dtype = arg.dtype
                inp.append(self._preproc_mask(arg))
            elif key in _KEYPOINTS_OPTIONS:
                inp.append(self._preproc_keypoints(arg, dcate))
            elif key in _BOXES_OPTIONS:
                inp.append(self._preproc_boxes(arg, dcate))
            elif key in _CLS_OPTIONS:
                inp.append(arg)
            else:
                raise NotImplementedError(f"input type of {dcate} is not implemented.")
        return inp

    @staticmethod
    def _cast_to_dtype(
        data: Union[torch.Tensor, List[torch.Tensor], Boxes, Keypoints], dtype: torch.dtype
    ) -> Union[torch.Tensor, List[torch.Tensor], Boxes, Keypoints]:
        """Call ``.type(dtype)`` on ``data``, or on each of its items when it is a list."""
        if isinstance(data, list):
            return [item.type(dtype) for item in data]
        return data.type(dtype)

    def _arguments_postproc(
        self, in_args: List[DataType], out_args: List[DataType], data_keys: List[DataKey]
    ) -> List[DataType]:
        """Convert the augmented outputs back to the caller's representation.

        Masks are cast back to the recorded mask dtype. Boxes and keypoints are returned as tensors unless the
        caller passed ``Boxes``/``Keypoints`` objects; under autocast they are cast back to the input dtype.

        Raises:
            NotImplementedError: for an unsupported data key.
        """
        out: List[DataType] = []
        for in_arg, out_arg, dcate in zip(in_args, out_args, data_keys):
            key = DataKey.get(dcate)
            if key in _IMG_OPTIONS:
                # It is torch.Tensor type already.
                out.append(out_arg)
                # TODO: may add the float to integer (for masks), etc.
            elif key in _MSK_OPTIONS:
                out.append(self._postproc_mask(cast(MaskDataType, out_arg)))

            elif key in _KEYPOINTS_OPTIONS:
                _out_k = self._postproc_keypoint(in_arg, cast(Keypoints, out_arg), dcate)
                # Autocast may have changed the precision; restore the caller's dtype.
                if is_autocast_enabled() and isinstance(in_arg, torch.Tensor | Keypoints):
                    _out_k = self._cast_to_dtype(_out_k, in_arg.dtype)
                out.append(_out_k)

            elif key in _BOXES_OPTIONS:
                _out_b = self._postproc_boxes(in_arg, cast(Boxes, out_arg), dcate)
                # Autocast may have changed the precision; restore the caller's dtype.
                if is_autocast_enabled() and isinstance(in_arg, torch.Tensor | Boxes):
                    _out_b = self._cast_to_dtype(_out_b, in_arg.dtype)
                out.append(_out_b)

            elif key in _CLS_OPTIONS:
                out.append(out_arg)

            else:
                raise NotImplementedError(f"input type of {dcate} is not implemented.")

        return out

    def forward(  # type: ignore[override]
        self,
        *args: Union[DataType, Dict[str, DataType]],
        params: Optional[List[ParamItem]] = None,
        data_keys: Optional[Union[List[str], List[int], List[DataKey]]] = None,
    ) -> Union[DataType, List[DataType], Dict[str, DataType]]:
        """Compute multiple tensors simultaneously according to ``self.data_keys``.

        Args:
            args: the inputs, one per data key, or a single dictionary whose keys start with a data key name.
            params: parameters to replay. If None, new parameters are sampled from the image input.
            data_keys: data keys overriding ``self.data_keys`` for this call only.

        Returns:
            A dictionary when a dictionary was given, the single value when there is only one output,
            otherwise a list of outputs.

        Raises:
            ValueError: if ``params`` is None and no tensor image input is provided.
        """
        self.clear_state()

        # Unpack/handle dictionary args
        original_keys = None
        invalid_data = None
        if len(args) == 1 and isinstance(args[0], dict):
            original_keys, data_keys, args, invalid_data = self._preproc_dict_data(cast(Dict[str, DataType], args[0]))

        self.transform_op.data_keys = self.transform_op.preproc_datakeys(data_keys)
        try:
            self._validate_args_datakeys(*args, data_keys=self.transform_op.data_keys)  # type: ignore

            in_args = self._arguments_preproc(*args, data_keys=self.transform_op.data_keys)  # type: ignore

            if params is None:
                # image data must exist if params is not provided.
                if DataKey.INPUT in self.transform_op.data_keys:
                    inp = in_args[self.transform_op.data_keys.index(DataKey.INPUT)]
                    if not isinstance(inp, torch.Tensor):
                        raise ValueError(f"`INPUT` should be a torch.Tensor but `{type(inp)}` received.")
                    # A video input shall be BCDHW while an image input shall be BCHW
                    if self.contains_video_sequential or self.contains_3d_augmentation:
                        _, out_shape = self.autofill_dim(inp, dim_range=(3, 5))
                    else:
                        _, out_shape = self.autofill_dim(inp, dim_range=(2, 4))
                    params = self.forward_parameters(out_shape)
                else:
                    raise ValueError("`params` must be provided whilst INPUT is not in data_keys.")

            outputs: Union[torch.Tensor, List[DataType]] = in_args
            for param in params:
                module = self.get_submodule(param.name)
                outputs = self.transform_op.transform(  # type: ignore
                    *outputs, module=module, param=param, extra_args=self.extra_args
                )
                if not isinstance(outputs, list | tuple):
                    # Make sure we are unpacking a list whilst post-proc
                    outputs = [outputs]
                self._update_transform_matrix_by_module(module)

            outputs = self._arguments_postproc(args, outputs, data_keys=self.transform_op.data_keys)  # type: ignore
        finally:
            # Restore it back, also when an error was raised, so per-call keys never leak into later calls.
            self.transform_op.data_keys = self.data_keys

        self._params = params

        if isinstance(original_keys, tuple):
            result = {k: v for v, k in zip(outputs, original_keys)}
            if invalid_data:
                result.update(invalid_data)
            return result

        if len(outputs) == 1 and isinstance(outputs, list):
            return outputs[0]

        return outputs

    def __call__(
        self,
        *inputs: Any,
        input_names_to_handle: Optional[List[Any]] = None,
        output_type: Literal["pt", "numpy", "pil"] = "pt",
        **kwargs: Any,
    ) -> Any:
        """Overwrite the __call__ function to handle various inputs.

        Args:
            inputs: Inputs to operate on.
            input_names_to_handle: List of input names to convert, if None, handle all inputs.
            output_type: Desired output type ('pt', 'numpy', or 'pil').
            kwargs: Additional arguments.

        Returns:
            The output of the module call, as converted by ``convert_input_output`` for ``output_type``.

        """
        # Wrap the forward method with the decorator
        if not self._disable_features:
            # TODO: Some more behaviour for AugmentationSequential needs to be revisited later
            # e.g. We convert only images, etc.
            decorated_forward = self.convert_input_output(
                input_names_to_handle=input_names_to_handle, output_type=output_type
            )(super(ImageSequential, self).__call__)
            _output_image = decorated_forward(*inputs, **kwargs)

            in_data_keys: Optional[List[DataKey]]
            if len(inputs) == 1 and isinstance(inputs[0], dict):
                original_keys, in_data_keys, inputs, _invalid_data = self._preproc_dict_data(inputs[0])
            else:
                in_data_keys = kwargs.get("data_keys", self.data_keys)
            data_keys = self.transform_op.preproc_datakeys(in_data_keys)

            # NOTE(review): in the "pt" branch the indexed assignments below write an element onto itself
            # (``self._output_image`` is ``_output_image``); the other branches update a previously cached
            # ``self._output_image``. Kept as-is; confirm the intended caching behaviour.
            if len(data_keys) > 1 and DataKey.INPUT in data_keys:
                # NOTE: we may update it later for more supports of drawing boxes, etc.
                idx = data_keys.index(DataKey.INPUT)
                if output_type == "pt":
                    self._output_image = _output_image
                    if isinstance(_output_image, dict):
                        self._output_image[original_keys[idx]] = _output_image[original_keys[idx]]
                    else:
                        self._output_image[idx] = _output_image[idx]
                elif isinstance(_output_image, dict):
                    self._output_image[original_keys[idx]] = _output_image[original_keys[idx]]
                else:
                    self._output_image[idx] = _output_image[idx]
            else:
                self._output_image = _output_image
        else:
            _output_image = super(ImageSequential, self).__call__(*inputs, **kwargs)
        return _output_image

    def _preproc_dict_data(
        self, data: Dict[str, DataType]
    ) -> Tuple[Tuple[str, ...], List[DataKey], Tuple[DataType, ...], Optional[Dict[str, Any]]]:
        """Split a dictionary input into its data-key entries and its pass-through entries.

        The caller's dictionary is not modified.

        Returns:
            The recognised keys (in input order), their data keys, their values, and a dictionary of the
            unrecognised entries (None when there are none).

        Raises:
            ValueError: if ``self.data_keys`` is set, since dictionary inputs derive the data keys themselves.
        """
        if self.data_keys is not None:
            raise ValueError("If you are using a dictionary as input, the data_keys should be None.")

        all_keys = tuple(data.keys())
        data_keys, invalid_keys = self._read_datakeys_from_dict(all_keys)
        if invalid_keys:
            invalid_set = set(invalid_keys)
            invalid_data: Optional[Dict[str, Any]] = {k: data[k] for k in invalid_keys}
            keys = tuple(k for k in all_keys if k not in invalid_set)
        else:
            invalid_data = None
            keys = all_keys
        data_unpacked = tuple(data[k] for k in keys)

        return keys, data_keys, data_unpacked, invalid_data

    def _read_datakeys_from_dict(self, keys: Sequence[str]) -> Tuple[List[DataKey], Optional[List[str]]]:
        """Map dictionary keys to data keys by case-insensitive prefix.

        Returns:
            The data keys of the recognised dictionary keys, and the list of unrecognised dictionary keys.
        """

        def retrieve_key(key: str) -> DataKey:
            """Try to retrieve the datakey value by matching `<datakey>*`."""
            upper_key = key.upper()
            # Alias cases, like INPUT, will not be get by the enum iterator.
            if upper_key.startswith("INPUT"):
                return DataKey.INPUT

            # Check the multi-word box keys first: otherwise a key such as "bbox_xyxy-b" could be captured by
            # the shorter `BBOX` prefix, depending on the enum iteration order.
            for box_key in ("BBOX_XYXY", "BBOX_XYWH"):
                if upper_key.startswith(box_key):
                    return DataKey.get(box_key)

            for dk in DataKey:
                if upper_key.startswith(dk.name):
                    return DataKey.get(dk.name)

            allowed_dk = " | ".join(f"`{d.name}`" for d in DataKey)
            raise ValueError(
                f"Your input data dictionary keys should start with some of datakey values: {allowed_dk}. Got `{key}`"
            )

        valid_data_keys = []
        invalid_keys = []
        for k in keys:
            try:
                valid_data_keys.append(DataKey.get(retrieve_key(k)))
            except ValueError:
                invalid_keys.append(k)

        return valid_data_keys, invalid_keys

    def _preproc_mask(self, arg: MaskDataType) -> MaskDataType:
        """Cast mask(s) to the recorded image dtype, or to float when no image dtype was recorded."""
        dtype = self.input_dtype if self.input_dtype else torch.float
        if isinstance(arg, list):
            return [a.to(dtype) for a in arg]
        return arg.to(dtype)

    def _postproc_mask(self, arg: MaskDataType) -> MaskDataType:
        """Cast augmented mask(s) back to the recorded mask dtype, or to float when none was recorded."""
        dtype = self.mask_dtype if self.mask_dtype else torch.float
        if isinstance(arg, list):
            return [a.to(dtype) for a in arg]
        return arg.to(dtype)

    @staticmethod
    def _boxes_mode(dcate: DataKey) -> str:
        """Return the ``Boxes`` tensor mode string for a box data key.

        Raises:
            ValueError: if ``dcate`` is not a box data key.
        """
        key = DataKey.get(dcate)
        if key == DataKey.BBOX:
            return "vertices_plus"
        if key == DataKey.BBOX_XYXY:
            return "xyxy_plus"
        if key == DataKey.BBOX_XYWH:
            return "xywh"
        raise ValueError(f"Unsupported mode `{DataKey.get(dcate).name}`.")

    def _preproc_boxes(self, arg: DataType, dcate: DataKey) -> Boxes:
        """Wrap a box tensor into a ``Boxes`` (or ``VideoBoxes``) object; ``Boxes`` inputs pass through.

        Raises:
            ValueError: if ``dcate`` is not a box data key.
            NotImplementedError: for tensor boxes when 3D augmentations are present.
        """
        mode = self._boxes_mode(dcate)
        if isinstance(arg, Boxes):
            return arg
        elif self.contains_video_sequential:
            arg = cast(torch.Tensor, arg)
            return VideoBoxes.from_tensor(arg)
        elif self.contains_3d_augmentation:
            raise NotImplementedError("3D box handlers are not yet supported.")
        else:
            arg = cast(torch.Tensor, arg)
            return Boxes.from_tensor(arg, mode=mode)

    def _postproc_boxes(
        self, in_arg: DataType, out_arg: Boxes, dcate: DataKey
    ) -> Union[torch.Tensor, List[torch.Tensor], Boxes]:
        """Return ``out_arg`` as given if the caller passed ``Boxes``, else convert it back to a tensor.

        Raises:
            ValueError: if ``dcate`` is not a box data key.
        """
        mode = self._boxes_mode(dcate)

        # TODO: handle 3d scenarios
        if isinstance(in_arg, Boxes):
            return out_arg
        else:
            return out_arg.to_tensor(mode=mode)

    def _preproc_keypoints(self, arg: DataType, dcate: DataKey) -> Keypoints:
        """Wrap keypoints into a ``Keypoints`` (or ``VideoKeypoints``) object.

        Non-floating tensors are converted to float for construction and the result is cast back to the
        original dtype.

        Raises:
            NotImplementedError: when 3D augmentations are present and no video sequential is used.
        """
        dtype = None

        if self.contains_video_sequential:
            arg = cast(Union[torch.Tensor, List[torch.Tensor]], arg)
            if isinstance(arg, list):
                if not torch.is_floating_point(arg[0]):
                    dtype = arg[0].dtype
                    arg = [a.float() for a in arg]
            elif not torch.is_floating_point(arg):
                dtype = arg.dtype
                arg = arg.float()
            video_result = VideoKeypoints.from_tensor(arg)
            return video_result.type(dtype) if dtype else video_result
        elif self.contains_3d_augmentation:
            raise NotImplementedError("3D keypoint handlers are not yet supported.")
        elif isinstance(arg, Keypoints):
            return arg
        else:
            arg = cast(torch.Tensor, arg)
            if not torch.is_floating_point(arg):
                dtype = arg.dtype
                arg = arg.float()
            # TODO: Add List[torch.Tensor] in the future.
            result = Keypoints.from_tensor(arg)
            return result.type(dtype) if dtype else result

    def _postproc_keypoint(
        self, in_arg: DataType, out_arg: Keypoints, dcate: DataKey
    ) -> Union[torch.Tensor, List[torch.Tensor], Keypoints]:
        """Return ``out_arg`` as given if the caller passed ``Keypoints``, else convert it back to a tensor."""
        if isinstance(in_arg, Keypoints):
            return out_arg
        else:
            return out_arg.to_tensor()
