# LICENSE HEADER MANAGED BY add-license-header
#
# Copyright 2018 Kornia Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import Dict, Optional, Tuple

import torch
from torch import nn

from kornia.feature import DescriptorMatcher, GFTTAffNetHardNet, LocalFeatureMatcher, LoFTR
from kornia.feature.integrated import LocalFeature
from kornia.geometry.linalg import transform_points
from kornia.geometry.ransac import RANSAC
from kornia.geometry.transform import warp_perspective


class HomographyTracker(nn.Module):
    r"""Perform local-feature-based tracking of the target planar object in the sequence of the frames.

    When there is no previous homography, the frame is matched against the target with
    ``initial_matcher``. Otherwise the frame is first prewarped with the previous homography and
    matched with ``fast_matcher``. In both cases the homography is estimated and verified with
    ``ransac``.

    Both matchers are called with a dict containing ``"image0"`` (the target), ``"image1"`` (the
    frame) and any cached target features (suffixed with ``0``). They must return a dict with
    ``"keypoints0"``, ``"keypoints1"`` and ``"batch_indexes"``; only batch element 0 is used.

    Args:
        initial_matcher: image matching module used when no previous homography is available, e.g.
            :class:`~kornia.feature.LocalFeatureMatcher` or :class:`~kornia.feature.LoFTR`.
            Default: :class:`~kornia.feature.LocalFeatureMatcher` built from
            ``GFTTAffNetHardNet(3000)`` and ``DescriptorMatcher("smnn", 0.95)``.
        fast_matcher: image matching module used for frame-to-frame tracking, e.g.
            :class:`~kornia.feature.LocalFeatureMatcher` or :class:`~kornia.feature.LoFTR`.
            Default: ``LoFTR("outdoor")``.
        ransac: homography estimation module called as ``ransac(keypoints0, keypoints1)`` and
            returning ``(H, inliers)``. Default: :class:`~kornia.geometry.RANSAC` with the
            ``"homography"`` model, ``inl_th=5.0``, ``batch_size=4096``, ``max_iter=10``,
            ``max_lo_iters=10``.
        minimum_inliers_num: minimum number of matched keypoints, and of RANSAC inliers, required
            for a match to be considered successful.

    """

    def __init__(
        self,
        initial_matcher: Optional[LocalFeature] = None,
        fast_matcher: Optional[nn.Module] = None,
        ransac: Optional[nn.Module] = None,
        minimum_inliers_num: int = 30,
    ) -> None:
        super().__init__()
        # Explicit ``is None`` checks instead of ``or``: a user-supplied module that is falsy
        # (e.g. defines ``__len__`` returning 0) would otherwise be silently replaced.
        self.initial_matcher = (
            initial_matcher
            if initial_matcher is not None
            else LocalFeatureMatcher(GFTTAffNetHardNet(3000), DescriptorMatcher("smnn", 0.95))
        )
        self.fast_matcher = fast_matcher if fast_matcher is not None else LoFTR("outdoor")
        self.ransac = (
            ransac
            if ransac is not None
            else RANSAC("homography", inl_th=5.0, batch_size=4096, max_iter=10, max_lo_iters=10)
        )
        self.minimum_inliers_num = minimum_inliers_num

        # placeholders; ``target`` is only assigned by ``set_target``
        self.target: torch.Tensor
        self.target_initial_representation: Dict[str, torch.Tensor] = {}
        self.target_fast_representation: Dict[str, torch.Tensor] = {}
        self.previous_homography: Optional[torch.Tensor] = None

        # statistics of the most recent matching attempt
        self.inliers_num: int = 0
        self.keypoints0_num: int = 0
        self.keypoints1_num: int = 0

        self.reset_tracking()

    @property
    def device(self) -> torch.device:
        """Return the device used by the current target image tensor.

        Returns:
            The ``torch.device`` where ``self.target`` is allocated.
        """
        return self.target.device

    @property
    def dtype(self) -> torch.dtype:
        """Return the data type used by the current target image tensor.

        Returns:
            The ``torch.dtype`` of ``self.target``.
        """
        return self.target.dtype

    @torch.no_grad()
    def set_target(self, target: torch.Tensor) -> None:
        """Register a new target image and refresh cached matcher features.

        Args:
            target: Reference target image tensor used for subsequent matching. It must be 4-D;
                its last two dimensions are used as the height and width of the prewarped frame.

        Returns:
            None.

        The method clears previously cached features and precomputes new
        feature representations when the configured matchers expose an
        ``extract_features`` attribute that is an ``nn.Module``.
        """
        self.target = target
        self.target_initial_representation = {}
        self.target_fast_representation = {}
        if hasattr(self.initial_matcher, "extract_features") and isinstance(
            self.initial_matcher.extract_features, nn.Module
        ):
            self.target_initial_representation = self.initial_matcher.extract_features(target)
        if hasattr(self.fast_matcher, "extract_features") and isinstance(self.fast_matcher.extract_features, nn.Module):
            self.target_fast_representation = self.fast_matcher.extract_features(target)

    def reset_tracking(self) -> None:
        """Reset temporal tracking state from previously processed frames.

        After this call ``forward`` falls back to ``match_initial``.

        Returns:
            None.
        """
        self.previous_homography = None

    def no_match(self) -> Tuple[torch.Tensor, bool]:
        """Return a failed-match response and clear current match statistics.

        Returns:
            A tuple ``(H, is_valid)`` where ``H`` is a ``3 x 3`` tensor created with
            ``torch.empty`` (its contents are uninitialized and must not be used) on the
            tracker device and dtype, and ``is_valid`` is ``False``.
        """
        self.inliers_num = 0
        self.keypoints0_num = 0
        self.keypoints1_num = 0
        return torch.empty(3, 3, device=self.device, dtype=self.dtype), False

    def _match(
        self, matcher: nn.Module, image1: torch.Tensor, target_representation: Dict[str, torch.Tensor]
    ) -> Dict[str, torch.Tensor]:
        """Run ``matcher`` between the target and ``image1``, feeding cached target features as ``<key>0``."""
        input_dict: Dict[str, torch.Tensor] = {"image0": self.target, "image1": image1}
        for k, v in target_representation.items():
            input_dict[f"{k}0"] = v
        return matcher(input_dict)

    @staticmethod
    def _first_batch_keypoints(match_dict: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the matched ``(keypoints0, keypoints1)`` pairs whose batch index is 0."""
        in_first_batch = match_dict["batch_indexes"] == 0
        return match_dict["keypoints0"][in_first_batch], match_dict["keypoints1"][in_first_batch]

    def _verify(self, keypoints0: torch.Tensor, keypoints1: torch.Tensor) -> Optional[torch.Tensor]:
        """Record match statistics and estimate a homography with RANSAC.

        Returns:
            The estimated homography, or ``None`` when fewer than ``minimum_inliers_num``
            keypoints (checked before RANSAC) or inliers (checked after RANSAC) are found.
        """
        self.keypoints0_num = len(keypoints0)
        self.keypoints1_num = len(keypoints1)
        if self.keypoints0_num < self.minimum_inliers_num:
            return None

        H, inliers = self.ransac(keypoints0, keypoints1)
        self.inliers_num = int(inliers.sum().item())
        if self.inliers_num < self.minimum_inliers_num:
            return None
        return H

    def match_initial(self, x: torch.Tensor) -> Tuple[torch.Tensor, bool]:
        """Estimate a homography from the initial target frame to frame ``x``.

        Args:
            x: Current frame tensor to match against the stored target image.

        Returns:
            A tuple ``(H, is_valid)`` where ``H`` is the estimated homography
            matrix and ``is_valid`` indicates whether enough keypoints and inliers were found.

        The method updates keypoint counters, inlier statistics, and stores the
        estimated homography as ``previous_homography`` on success.
        """
        match_dict = self._match(self.initial_matcher, x, self.target_initial_representation)
        keypoints0, keypoints1 = self._first_batch_keypoints(match_dict)

        H = self._verify(keypoints0, keypoints1)
        if H is None:
            return self.no_match()

        self.previous_homography = H.clone()
        return H, True

    def track_next_frame(self, x: torch.Tensor) -> Tuple[torch.Tensor, bool]:
        """Track the target in frame ``x`` using the previous homography prior.

        Args:
            x: Current frame tensor to align with the target image.

        Returns:
            A tuple ``(H, is_valid)`` where ``H`` is the updated homography and
            ``is_valid`` indicates whether tracking remained reliable. On failure the
            tracking state is reset so the next ``forward`` call uses ``match_initial``.

        Raises:
            RuntimeError: if there is no previous homography (call ``match_initial`` or
                ``forward`` first).

        The frame is first prewarped by the inverse of the previous homography,
        then matched with ``fast_matcher`` and verified using RANSAC.
        """
        if self.previous_homography is None:
            raise RuntimeError(
                "track_next_frame() requires a previous homography; call match_initial() or forward() first."
            )

        Hwarp = self.previous_homography.clone()[None]
        # make a bit of border for safety: scale the 2x2 linear part by 1/0.8 and shift the
        # translation by -10 px before prewarping
        Hwarp[:, 0:2, 0:2] = Hwarp[:, 0:2, 0:2] / 0.8
        Hwarp[:, 0:2, 2] -= 10.0
        Hinv = torch.inverse(Hwarp)
        h, w = self.target.shape[2:]
        frame_warped = warp_perspective(x, Hinv, (h, w))

        match_dict = self._match(self.fast_matcher, frame_warped, self.target_fast_representation)
        keypoints0, keypoints1 = self._first_batch_keypoints(match_dict)
        # keypoints1 were detected in the prewarped frame; map them back with Hwarp
        keypoints1 = transform_points(Hwarp, keypoints1)

        H = self._verify(keypoints0, keypoints1)
        if H is None:
            self.reset_tracking()
            return self.no_match()

        self.previous_homography = H.clone()
        return H, True

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, bool]:
        """Run one tracking step on frame ``x``.

        Args:
            x: Current frame tensor.

        Returns:
            A tuple ``(H, is_valid)`` from ``track_next_frame`` when previous
            state exists, otherwise from ``match_initial``.
        """
        if self.previous_homography is not None:
            return self.track_next_frame(x)
        return self.match_initial(x)
