# SPDX-FileCopyrightText: Copyright (c) 2025 Comfy Org. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import torch


def apply_rope1(x: torch.Tensor, freqs_cis: torch.Tensor):
    x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
    if x_.shape[2] != 1 and freqs_cis.shape[2] != 1 and x_.shape[2] != freqs_cis.shape[2]:
        freqs_cis = freqs_cis[:, :, :x_.shape[2]]

    x_out = freqs_cis[..., 0] * x_[..., 0]
    x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])
    return x_out.reshape(*x.shape).type_as(x)


def apply_rope(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor):
    return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)


def apply_rope_split_half1(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    t_ = x.reshape(*x.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2).to(freqs_cis.dtype)
    t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1]
    return t_out.movedim(-1, -2).reshape(*x.shape).type_as(x)


def apply_rope_split_half(
    xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    return apply_rope_split_half1(xq, freqs_cis), apply_rope_split_half1(xk, freqs_cis)


# =============================================================================
# torch.library Custom Op Definitions
# =============================================================================


@torch.library.custom_op("comfy_kitchen::apply_rope", mutates_args=())
def _op_apply_rope(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    from comfy_kitchen.registry import registry

    kwargs = {"xq": xq, "xk": xk, "freqs_cis": freqs_cis}
    impl = registry.get_implementation("apply_rope", kwargs=kwargs)
    return impl(**kwargs)


@_op_apply_rope.register_fake
def _op_apply_rope_fake(xq, xk, freqs_cis):
    return torch.empty_like(xq), torch.empty_like(xk)


@torch.library.custom_op("comfy_kitchen::apply_rope1", mutates_args=())
def _op_apply_rope1(
    x: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> torch.Tensor:
    from comfy_kitchen.registry import registry

    kwargs = {"x": x, "freqs_cis": freqs_cis}
    impl = registry.get_implementation("apply_rope1", kwargs=kwargs)
    return impl(**kwargs)


@_op_apply_rope1.register_fake
def _op_apply_rope1_fake(x, freqs_cis):
    return torch.empty_like(x)


@torch.library.custom_op("comfy_kitchen::apply_rope_split_half", mutates_args=())
def _op_apply_rope_split_half(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    from comfy_kitchen.registry import registry

    kwargs = {"xq": xq, "xk": xk, "freqs_cis": freqs_cis}
    impl = registry.get_implementation("apply_rope_split_half", kwargs=kwargs)
    return impl(**kwargs)


@_op_apply_rope_split_half.register_fake
def _op_apply_rope_split_half_fake(xq, xk, freqs_cis):
    return torch.empty_like(xq), torch.empty_like(xk)


@torch.library.custom_op("comfy_kitchen::apply_rope_split_half1", mutates_args=())
def _op_apply_rope_split_half1(
    x: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> torch.Tensor:
    from comfy_kitchen.registry import registry

    kwargs = {"x": x, "freqs_cis": freqs_cis}
    impl = registry.get_implementation("apply_rope_split_half1", kwargs=kwargs)
    return impl(**kwargs)


@_op_apply_rope_split_half1.register_fake
def _op_apply_rope_split_half1_fake(x, freqs_cis):
    return torch.empty_like(x)
