import os
import sys
import logging
import torch
import importlib
import math
import datetime
from tqdm import tqdm

import folder_paths
import comfy.model_management as mm
from comfy.cli_args import args
from comfy.ldm.modules.attention import wrap_attn, optimized_attention, attention_pytorch
import comfy.utils
import comfy.sd
import comfy.ops

try:
    from comfy_api.latest import io
    v3_available = True
except ImportError:
    v3_available = False
    logging.warning("ComfyUI v3 node API not available, please update ComfyUI to access latest v3 nodes.")

# Selectable SageAttention modes. The list is offered as the "sage_attention"
# combo options by the nodes in this file and dispatched on by get_sage_func();
# the nodes treat "disabled" as "do not patch attention".
sageattn_modes = ["disabled", "auto", "sageattn_qk_int8_pv_fp16_cuda", "sageattn_qk_int8_pv_fp16_triton", "sageattn_qk_int8_pv_fp8_cuda", "sageattn_qk_int8_pv_fp8_cuda++", "sageattn3", "sageattn3_per_block_mean"]

def get_sage_func(sage_attention, allow_compile=False):
    """Build an attention function backed by the selected SageAttention kernel.

    Args:
        sage_attention: Kernel selector string: "auto", one of the explicit
            ``sageattn_qk_int8_*`` names, or any name containing "sageattn3".
        allow_compile: When False (default) the kernel call is wrapped with
            ``torch.compiler.disable()`` so torch.compile does not trace into it.

    Returns:
        The ``wrap_attn``-decorated ``attention_sage`` callable.

    Raises:
        ValueError: If ``sage_attention`` is not a supported mode (this
            includes "disabled" and non-string values).
        ImportError: If the library backing the selected mode is not installed.
    """
    logging.info(f"Using sage attention mode: {sage_attention}")
    if sage_attention == "auto":
        from sageattention import sageattn
        def sage_func(q, k, v, is_causal=False, attn_mask=None, tensor_layout="NHD"):
            return sageattn(q, k, v, is_causal=is_causal, attn_mask=attn_mask, tensor_layout=tensor_layout)
    elif sage_attention == "sageattn_qk_int8_pv_fp16_cuda":
        from sageattention import sageattn_qk_int8_pv_fp16_cuda
        def sage_func(q, k, v, is_causal=False, attn_mask=None, tensor_layout="NHD"):
            return sageattn_qk_int8_pv_fp16_cuda(q, k, v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype="fp32", tensor_layout=tensor_layout)
    elif sage_attention == "sageattn_qk_int8_pv_fp16_triton":
        from sageattention import sageattn_qk_int8_pv_fp16_triton
        def sage_func(q, k, v, is_causal=False, attn_mask=None, tensor_layout="NHD"):
            return sageattn_qk_int8_pv_fp16_triton(q, k, v, is_causal=is_causal, attn_mask=attn_mask, tensor_layout=tensor_layout)
    elif sage_attention == "sageattn_qk_int8_pv_fp8_cuda":
        from sageattention import sageattn_qk_int8_pv_fp8_cuda
        def sage_func(q, k, v, is_causal=False, attn_mask=None, tensor_layout="NHD"):
            return sageattn_qk_int8_pv_fp8_cuda(q, k, v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype="fp32+fp32", tensor_layout=tensor_layout)
    elif sage_attention == "sageattn_qk_int8_pv_fp8_cuda++":
        # Same kernel as above, only the PV accumulation dtype differs.
        from sageattention import sageattn_qk_int8_pv_fp8_cuda
        def sage_func(q, k, v, is_causal=False, attn_mask=None, tensor_layout="NHD"):
            return sageattn_qk_int8_pv_fp8_cuda(q, k, v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype="fp32+fp16", tensor_layout=tensor_layout)
    elif isinstance(sage_attention, str) and "sageattn3" in sage_attention:
        from sageattn3 import sageattn3_blackwell
        def sage_func(q, k, v, is_causal=False, attn_mask=None, tensor_layout="NHD", **kwargs):
            # The kernel is fed the transposed layout when the caller uses "NHD";
            # the output is transposed back so the caller sees its own layout.
            q, k, v = [x.transpose(1, 2) if tensor_layout == "NHD" else x for x in (q, k, v)]
            out = sageattn3_blackwell(q, k, v, is_causal=is_causal, attn_mask=attn_mask, per_block_mean=(sage_attention == "sageattn3_per_block_mean"))
            return out.transpose(1, 2) if tensor_layout == "NHD" else out
    else:
        # Without this branch sage_func would be unbound and the failure would
        # surface later as an opaque UnboundLocalError.
        raise ValueError(f"Unsupported sage attention mode: {sage_attention!r}")

    if not allow_compile:
        sage_func = torch.compiler.disable()(sage_func)

    @wrap_attn
    def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
        # Callers can opt out of low precision attention per call.
        if kwargs.get("low_precision_attention", True) is False:
            return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=skip_reshape, skip_output_reshape=skip_output_reshape, **kwargs)
        in_dtype = v.dtype
        # fp32 inputs are downcast to fp16 for the kernel; the output is cast back to in_dtype.
        if q.dtype == torch.float32 or k.dtype == torch.float32 or v.dtype == torch.float32:
            q, k, v = q.to(torch.float16), k.to(torch.float16), v.to(torch.float16)
        if skip_reshape:
            # 4D input, passed through untouched and labelled "HND".
            b, _, _, dim_head = q.shape
            tensor_layout="HND"
        else:
            # 3D input (b, seq, heads * dim_head) is split into (b, seq, heads, dim_head).
            b, _, dim_head = q.shape
            dim_head //= heads
            q, k, v = map(
                lambda t: t.view(b, -1, heads, dim_head),
                (q, k, v),
            )
            tensor_layout="NHD"
        if mask is not None:
            # add a batch dimension if there isn't already one
            if mask.ndim == 2:
                mask = mask.unsqueeze(0)
            # add a heads dimension if there isn't already one
            if mask.ndim == 3:
                mask = mask.unsqueeze(1)
        out = sage_func(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout).to(in_dtype)
        if tensor_layout == "HND":
            if not skip_output_reshape:
                out = (
                    out.transpose(1, 2).reshape(b, -1, heads * dim_head)
                )
        else:
            if skip_output_reshape:
                out = out.transpose(1, 2)
            else:
                out = out.reshape(b, -1, heads * dim_head)
        return out
    return attention_sage


from comfy.patcher_extension import CallbacksMP
class PathchSageAttentionKJ():
    """ComfyUI node that overrides a model's attention with SageAttention.

    The override is stored on a clone of the incoming model under
    ``model_options["transformer_options"]["optimized_attention_override"]``.
    """
    @classmethod
    def INPUT_TYPES(s):
        """Describe the node inputs: the model, the sage mode and the compile switch."""
        return {"required": {
            "model": ("MODEL",),
            # The default must be a member of the option list; "disabled" is the no-op choice.
            "sage_attention": (sageattn_modes, {"default": "disabled", "tooltip": "Patch the attention of the model passing through this node to use sageattn. To revert, run this node again with the disabled option. Requires the sageattention library to be installed."}),
        },
        "optional": {
            "allow_compile": ("BOOLEAN", {"default": False, "tooltip": "Allow the use of torch.compile for the sage attention function, requires latest sageattn 2.2.0 or higher."})
            }
        }

    RETURN_TYPES = ("MODEL", )
    FUNCTION = "patch"
    DESCRIPTION = "Experimental node for patching attention mode. This doesn't use the model patching system and thus can't be disabled without running the node again with 'disabled' option."
    EXPERIMENTAL = True
    CATEGORY = "KJNodes/experimental"

    def patch(self, model, sage_attention, allow_compile=False):
        """Return a one-tuple holding the model with the sage attention override set.

        With ``sage_attention == "disabled"`` the incoming model is returned
        as-is (not cloned, nothing changed).
        """
        if sage_attention == "disabled":
            return model,

        model_clone = model.clone()

        new_attention = get_sage_func(sage_attention, allow_compile=allow_compile)
        def attention_override_sage(func, *args, **kwargs):
            # `func` (the attention being overridden) is ignored; the undecorated
            # sage implementation is called directly via __wrapped__.
            return new_attention.__wrapped__(*args, **kwargs)

        # attention override
        model_clone.model_options["transformer_options"]["optimized_attention_override"] = attention_override_sage

        return model_clone,


def get_flash_func(allow_compile=False, cast_dtype=torch.float16):
    """Build an attention function backed by flash attention (FA2, else FA3).

    Args:
        allow_compile: When False the kernel call is wrapped with
            ``torch.compiler.disable()``.
        cast_dtype: dtype that float32 q/k/v are cast to before the kernel call.

    Returns:
        The ``wrap_attn``-decorated ``attention_flash`` callable.

    Raises:
        ImportError: If neither ``flash_attn`` nor ``flash_attn_interface``
            can be imported.
    """
    # FA2 is tried first (broad arch support, plain tensor return). FA3
    # (flash_attn_interface) is the fallback: no dropout arg, returns (out, lse).
    try:
        from flash_attn import flash_attn_func
        is_fa3 = False
    except ImportError:
        try:
            from flash_attn_interface import flash_attn_func
        except ImportError:
            raise ImportError(
                "Flash attention not found. Install either FA2 ('flash_attn') or "
                "FA3 ('flash_attn_interface', pip package 'flash-attn-3')."
            )
        is_fa3 = True
    logging.info(f"Using flash attention {'3' if is_fa3 else '2'}: cast_dtype={cast_dtype}")

    def flash_func(q, k, v):
        # Inputs are NHD: (b, seq, heads, dim_head).
        if is_fa3:
            result = flash_attn_func(q, k, v, causal=False)
        else:
            result = flash_attn_func(q, k, v, dropout_p=0.0, causal=False)
        # A tuple result (FA3: out, softmax_lse) is reduced to its first element.
        if isinstance(result, tuple):
            return result[0]
        return result

    if not allow_compile:
        flash_func = torch.compiler.disable()(flash_func)

    # Run the kernel once on a tiny tensor at build time when CUDA is present.
    if torch.cuda.is_available():
        probe = torch.zeros(1, 8, 2, 64, dtype=cast_dtype, device="cuda")
        flash_func(probe, probe, probe)

    @wrap_attn
    def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
        if mask is not None:
            raise RuntimeError("Flash attention does not support attention masks")
        in_dtype = v.dtype
        # flash_attn only supports fp16/bf16, so any fp32 input forces a downcast of all three
        if torch.float32 in (q.dtype, k.dtype, v.dtype):
            q, k, v = q.to(cast_dtype), k.to(cast_dtype), v.to(cast_dtype)
        # The kernel wants NHD layout (b, seq, heads, dim_head)
        if skip_reshape:
            # incoming layout is HND (b, heads, seq, dim_head)
            b, _, _, dim_head = q.shape
            q, k, v = [t.transpose(1, 2) for t in (q, k, v)]
        else:
            b, _, dim_head = q.shape
            dim_head //= heads
            q, k, v = [t.view(b, -1, heads, dim_head) for t in (q, k, v)]
        out = flash_func(q, k, v).to(in_dtype)
        if skip_output_reshape:
            return out.transpose(1, 2)  # NHD -> HND
        return out.reshape(b, -1, heads * dim_head)
    return attention_flash


class PatchFlashAttentionKJ():
    """ComfyUI node that overrides a model's attention with flash attention."""

    RETURN_TYPES = ("MODEL", )
    FUNCTION = "patch"
    DESCRIPTION = "Experimental node for patching attention to use flash attention, without the silent SDPA fallback the ComfyUI default does. Patches the attention of the model passing through this node; to disable, bypass or disconnect this node. Requires the flash_attn library to be installed."
    EXPERIMENTAL = True
    CATEGORY = "KJNodes/experimental"

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node inputs: the model plus an optional compile switch."""
        required = {
            "model": ("MODEL",),
        }
        optional = {
            "allow_compile": ("BOOLEAN", {"default": False, "tooltip": "Allow torch.compile to trace into the flash attention function. If disabled (default), the function is wrapped with torch.compiler.disable() for compatibility, matching the sage attention node."}),
        }
        return {"required": required, "optional": optional}

    def patch(self, model, allow_compile=False):
        """Return a one-tuple holding a clone of ``model`` with the flash attention override set."""
        # Use the model's inference dtype for the fp32 downcast when it is
        # fp16/bf16; anything else (or no getter) falls back to fp16.
        if hasattr(model.model, "get_dtype_inference"):
            inference_dtype = model.model.get_dtype_inference()
        else:
            inference_dtype = torch.float16
        if inference_dtype in (torch.float16, torch.bfloat16):
            cast_dtype = inference_dtype
        else:
            cast_dtype = torch.float16

        flash_attention = get_flash_func(allow_compile=allow_compile, cast_dtype=cast_dtype)

        patched = model.clone()

        def attention_override_flash(func, *args, **kwargs):
            # The overridden `func` is ignored; the undecorated implementation is called.
            return flash_attention.__wrapped__(*args, **kwargs)

        # attention override
        transformer_options = patched.model_options["transformer_options"]
        transformer_options["optimized_attention_override"] = attention_override_flash

        return (patched,)


class CheckpointLoaderKJ():
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "ckpt_name": (folder_paths.get_filename_list("checkpoints"), {"tooltip": "The name of the checkpoint (model) to load."}),
            "weight_dtype": (["default", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2", "fp16", "bf16", "fp32"],),
            "compute_dtype": (["default", "fp16", "bf16", "fp32"], {"default": "default", "tooltip": "The compute dtype to use for the model."}),
            "patch_cublaslinear": ("BOOLEAN", {"default": False, "tooltip": "Enable or disable the cublas_ops arg"}),
            "sage_attention": (sageattn_modes, {"default": False, "tooltip": "Patch comfy attention to use sageattn."}),
            "enable_fp16_accumulation": ("BOOLEAN", {"default": False, "tooltip": "Enable torch.backends.cuda.matmul.allow_fp16_accumulation, required minimum pytorch version 2.7.1"}),
        }}

    RETURN_TYPES = ("MODEL", "CLIP", "VAE")
    FUNCTION = "load"
    DESCRIPTION = "Experimental node for patching torch.nn.Linear with CublasLinear."
    EXPERIMENTAL = True
    CATEGORY = "KJNodes/model_loaders"

    def load(self, ckpt_name, weight_dtype, compute_dtype, patch_cublaslinear, sage_attention, enable_fp16_accumulation):
        DTYPE_MAP = {
            "fp8_e4m3fn": torch.float8_e4m3fn,
            "fp8_e5m2": torch.float8_e5m2,
            "fp16": torch.float16,
            "bf16": torch.bfloat16,
            "fp32": torch.float32
        }
        model_options = {}
        if dtype := DTYPE_MAP.get(weight_dtype):
            model_options["dtype"] = dtype
            logging.info(f"Setting {ckpt_name} weight dtype to {dtype}")

        if weight_dtype == "fp8_e4m3fn_fast":
            model_options["dtype"] = torch.float8_e4m3fn
            model_options["fp8_optimizations"] = True

        if patch_cublaslinear:
            args.fast.add("cublas_ops")
        else:
            args.fast.discard("cublas_ops")

        ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
        model, clip, vae, _ = comfy.sd.load_checkpoint_guess_config(
            ckpt_path,
            output_vae=True,
            output_clip=True,
            embedding_directory=folder_paths.get_folder_paths("embeddings"),
            model_options=model_options)

        if dtype := DTYPE_MAP.get(compute_dtype):
            model.set_model_compute_dtype(dtype)
            model.force_cast_weights = False
            logging.info(f"Setting {ckpt_name} compute dtype to {dtype}")

        if enable_fp16_accumulation:
            if hasattr(torch.backends.cuda.matmul, "allow_fp16_accumulation"):
                torch.backends.cuda.matmul.allow_fp16_accumulation = True
            else:
                raise RuntimeError("Failed to set fp16 accumulation, requires pytorch version 2.7.1 or higher")
        else:
            if hasattr(torch.backends.cuda.matmul, "allow_fp16_accumulation"):
                torch.backends.cuda.matmul.allow_fp16_accumulation = False

        if sage_attention != "disabled":
            new_attention = get_sage_func(sage_attention)
            def attention_override_sage(func, *args, **kwargs):
                return new_attention.__wrapped__(*args, **kwargs)

            # attention override
            model.model_options["transformer_options"]["optimized_attention_override"] = attention_override_sage

        return model, clip, vae


class DiffusionModelSelector():
    """ComfyUI node that resolves a selected model file name to its full path."""

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("model_path",)
    FUNCTION = "get_path"
    DESCRIPTION = "Returns the path to the model as a string."
    EXPERIMENTAL = True
    CATEGORY = "KJNodes/model_loaders"

    @classmethod
    def INPUT_TYPES(s):
        """Offer all diffusion models plus text-encoder files whose name contains "connector"."""
        connector_files = [
            name
            for name in folder_paths.get_filename_list("text_encoders")
            if "connector" in name.lower()
        ]
        choices = folder_paths.get_filename_list("diffusion_models") + connector_files
        return {
            "required": {
                "model_name": (choices, {"tooltip": "The name of the checkpoint (model) to load."}),
            },
        }

    def get_path(self, model_name):
        """Return a one-tuple with the full path of ``model_name``.

        Names containing "connector" (case-insensitive) are looked up in the
        "text_encoders" folder, everything else in "diffusion_models".
        """
        folder = "text_encoders" if "connector" in model_name.lower() else "diffusion_models"
        return (folder_paths.get_full_path_or_raise(folder, model_name),)

def _load_diffusion_model_kj(unet_path, model_options=None, extra_state_dict=None, disable_dynamic=False):
    """Load a diffusion model from ``unet_path``, optionally merged with an extra state dict.

    Args:
        unet_path: Path of the main model file.
        model_options: Options forwarded to the loader; copied, never mutated.
        extra_state_dict: Optional path of a second state dict whose entries
            are merged over the main one.
        disable_dynamic: Forwarded to ``comfy.sd.load_diffusion_model_state_dict``.

    Returns:
        The loaded model, with ``cached_patcher_init`` set to this function and
        its positional arguments.

    Raises:
        RuntimeError: If the loader could not build a model from the state dict.
    """
    model_options = {} if model_options is None else dict(model_options)

    sd, metadata = comfy.utils.load_torch_file(unet_path, return_metadata=True)
    if extra_state_dict is not None:
        extra_sd = comfy.utils.load_torch_file(extra_state_dict)
        sd.update(extra_sd)
        del extra_sd

        # Strip the detected diffusion-model prefix while keeping every other
        # key (filter_keys=False), so the merged entries are preserved.
        diffusion_model_prefix = comfy.sd.model_detection.unet_prefix_from_state_dict(sd)
        sd = comfy.utils.state_dict_prefix_replace(sd, {diffusion_model_prefix: ""}, filter_keys=False)

    model = comfy.sd.load_diffusion_model_state_dict(
        sd,
        model_options=model_options,
        metadata=metadata,
        disable_dynamic=disable_dynamic,
    )
    if model is None:
        # Fail with a clear message instead of an AttributeError on the
        # attribute assignment below.
        logging.error(f"ERROR UNSUPPORTED DIFFUSION MODEL {unet_path}")
        raise RuntimeError(f"ERROR: Could not detect model type of: {unet_path}")

    model.cached_patcher_init = (_load_diffusion_model_kj, (unet_path, model_options, extra_state_dict))
    return model

class DiffusionModelLoaderKJ():
    """ComfyUI diffusion-model loader with extra dtype, cublas, sage attention and fp16-accumulation options."""
    @classmethod
    def INPUT_TYPES(s):
        """Describe the node inputs; model names come from the "diffusion_models" folder."""
        return {"required": {
            "model_name": (folder_paths.get_filename_list("diffusion_models"), {"tooltip": "The name of the checkpoint (model) to load."}),
            "weight_dtype": (["default", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2", "fp16", "bf16", "fp32"],),
            "compute_dtype": (["default", "fp16", "bf16", "fp32"], {"default": "default", "tooltip": "The compute dtype to use for the model."}),
            "patch_cublaslinear": ("BOOLEAN", {"default": False, "tooltip": "Enable or disable the cublas_ops arg"}),
            # The default must be a member of the option list; "disabled" is the no-op choice.
            "sage_attention": (sageattn_modes, {"default": "disabled", "tooltip": "Patch comfy attention to use sageattn."}),
            "enable_fp16_accumulation": ("BOOLEAN", {"default": False, "tooltip": "Enable torch.backends.cuda.matmul.allow_fp16_accumulation, requires pytorch 2.7.0 nightly."}),
        },
        "optional": {
            "extra_state_dict": ("STRING", {"forceInput": True, "tooltip": "The full path to an additional state dict to load, this will be merged with the main state dict. Useful for example to add VACE module to a WanVideoModel. You can use DiffusionModelSelector to easily get the path."}),
        }
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch_and_load"
    DESCRIPTION = "Node for patching torch.nn.Linear with CublasLinear."
    EXPERIMENTAL = True
    CATEGORY = "KJNodes/model_loaders"

    def patch_and_load(self, model_name, weight_dtype, compute_dtype, patch_cublaslinear, sage_attention, enable_fp16_accumulation, extra_state_dict=None):
        """Apply the global torch/comfy settings, load the model and patch it.

        Returns:
            One-tuple ``(model,)``.

        Raises:
            RuntimeError: If fp16 accumulation is requested but the installed
                torch does not expose the setting.
        """
        DTYPE_MAP = {
            "fp8_e4m3fn": torch.float8_e4m3fn,
            "fp8_e5m2": torch.float8_e5m2,
            "fp16": torch.float16,
            "bf16": torch.bfloat16,
            "fp32": torch.float32
        }
        model_options = {}
        if dtype := DTYPE_MAP.get(weight_dtype):
            model_options["dtype"] = dtype
            logging.info(f"Setting {model_name} weight dtype to {dtype}")

        # "fp8_e4m3fn_fast" is not in DTYPE_MAP: e4m3fn weights plus fp8 optimizations.
        if weight_dtype == "fp8_e4m3fn_fast":
            model_options["dtype"] = torch.float8_e4m3fn
            model_options["fp8_optimizations"] = True

        # Process-wide torch setting; on torch builds without the attribute,
        # disabling is a silent no-op.
        if enable_fp16_accumulation:
            if hasattr(torch.backends.cuda.matmul, "allow_fp16_accumulation"):
                torch.backends.cuda.matmul.allow_fp16_accumulation = True
            else:
                raise RuntimeError("Failed to set fp16 accumulation, this requires pytorch 2.7.1 or higher")
        else:
            if hasattr(torch.backends.cuda.matmul, "allow_fp16_accumulation"):
                torch.backends.cuda.matmul.allow_fp16_accumulation = False

        if patch_cublaslinear:
            args.fast.add("cublas_ops")
        else:
            args.fast.discard("cublas_ops")

        unet_path = folder_paths.get_full_path_or_raise("diffusion_models", model_name)

        model = _load_diffusion_model_kj(unet_path, model_options=model_options, extra_state_dict=extra_state_dict)
        if dtype := DTYPE_MAP.get(compute_dtype):
            model.set_model_compute_dtype(dtype)
            model.force_cast_weights = False
            logging.info(f"Setting {model_name} compute dtype to {dtype}")

        if sage_attention != "disabled":
            new_attention = get_sage_func(sage_attention)
            def attention_override_sage(func, *args, **kwargs):
                # The overridden `func` is ignored; the undecorated implementation is called.
                return new_attention.__wrapped__(*args, **kwargs)

            # attention override
            model.model_options["transformer_options"]["optimized_attention_override"] = attention_override_sage

        return (model,)

class ModelPatchTorchSettings:
    """ComfyUI node that toggles fp16 matmul accumulation around model execution via patcher callbacks."""
    @classmethod
    def INPUT_TYPES(s):
        """Describe the node inputs: the model and the fp16-accumulation switch."""
        return {"required": {
            "model": ("MODEL",),
            "enable_fp16_accumulation": ("BOOLEAN", {"default": False, "tooltip": "Enable torch.backends.cuda.matmul.allow_fp16_accumulation, requires pytorch 2.7.0 nightly."}),
        }}

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"
    DESCRIPTION = "Adds callbacks to model to set torch settings before and after running the model."
    EXPERIMENTAL = True
    CATEGORY = "KJNodes/experimental"

    def patch(self, model, enable_fp16_accumulation):
        """Return a one-tuple holding a clone of ``model`` with the torch-setting callbacks attached.

        Enabled: the flag is set to True on ON_PRE_RUN and back to False on
        ON_CLEANUP. Disabled: the flag is set to False on ON_PRE_RUN. On a torch
        build that does not expose the flag, disabling is a no-op.

        Raises:
            RuntimeError: If enabling is requested but the installed torch does
                not expose ``allow_fp16_accumulation``.
        """
        supported = hasattr(torch.backends.cuda.matmul, "allow_fp16_accumulation")
        if enable_fp16_accumulation and not supported:
            raise RuntimeError("Failed to set fp16 accumulation, this requires pytorch 2.7.1 or higher")

        model_clone = model.clone()
        if not supported:
            # Nothing to switch off on this torch build. Behave like the loader
            # nodes in this file, which skip the assignment rather than fail.
            return (model_clone,)

        def patch_enable_fp16_accum(model):
            logging.info("Patching torch settings: torch.backends.cuda.matmul.allow_fp16_accumulation = True")
            torch.backends.cuda.matmul.allow_fp16_accumulation = True
        def patch_disable_fp16_accum(model):
            logging.info("Patching torch settings: torch.backends.cuda.matmul.allow_fp16_accumulation = False")
            torch.backends.cuda.matmul.allow_fp16_accumulation = False

        if enable_fp16_accumulation:
            model_clone.add_callback(CallbacksMP.ON_PRE_RUN, patch_enable_fp16_accum)
            model_clone.add_callback(CallbacksMP.ON_CLEANUP, patch_disable_fp16_accum)
        else:
            model_clone.add_callback(CallbacksMP.ON_PRE_RUN, patch_disable_fp16_accum)

        return (model_clone,)


class PatchModelPatcherOrder:
    """Deprecated pass-through node, kept so existing workflows still load."""

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"
    CATEGORY = "KJNodes/deprecated"
    DESCRIPTION = "NO LONGER NECESSARY OR FUNCTIONAL, keeping node for backwards compatibility. Use the TorchCompileModelAdvanced to use LoRA with torch.compile."
    DEPRECATED = True

    @classmethod
    def INPUT_TYPES(s):
        """Describe the (now ignored) inputs of the node."""
        patch_order_spec = (
            ["object_patch_first", "weight_patch_first"],
            {"default": "weight_patch_first", "tooltip": "Patch the comfy patch_model function to load weight patches (LoRAs) before compiling the model"},
        )
        full_load_spec = (
            ["enabled", "disabled", "auto"],
            {"default": "auto", "tooltip": "Disabling may help with memory issues when loading large models, when changing this you should probably force model reload to avoid issues!"},
        )
        return {
            "required": {
                "model": ("MODEL",),
                "patch_order": patch_order_spec,
                "full_load": full_load_spec,
            }
        }

    def patch(self, model, patch_order, full_load):
        """Return the model untouched as a one-tuple; both options are ignored."""
        return (model,)


class TorchCompileModelFluxAdvancedV2:
    """Deprecated node that torch.compiles the double/single blocks of a diffusion model."""
    def __init__(self):
        self._compiled = False

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node inputs: compile settings plus which block groups to compile."""
        return {"required": { 
                    "model": ("MODEL",),
                    "backend": (["inductor", "cudagraphs"],),
                    "fullgraph": ("BOOLEAN", {"default": False, "tooltip": "Enable full graph mode"}),
                    "mode": (["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"], {"default": "default"}),
                    "double_blocks": ("BOOLEAN", {"default": True, "tooltip": "Compile double blocks"}),
                    "single_blocks": ("BOOLEAN", {"default": True, "tooltip": "Compile single blocks"}),
                    "dynamic": ("BOOLEAN", {"default": False, "tooltip": "Enable dynamic mode"}),
                },
                "optional": {
                    "dynamo_cache_size_limit": ("INT", {"default": 64, "min": 0, "max": 1024, "step": 1, "tooltip": "torch._dynamo.config.cache_size_limit"}),
                    "force_parameter_static_shapes": ("BOOLEAN", {"default": True, "tooltip": "torch._dynamo.config.force_parameter_static_shapes"}),
                }
                }
    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"

    CATEGORY = "KJNodes/torchcompile"
    EXPERIMENTAL = True
    DEPRECATED = True
    DESCRIPTION = "Deprecated, use TorchCompileModelAdvanced instead."

    def patch(self, model, backend, mode, fullgraph, single_blocks, double_blocks, dynamic, dynamo_cache_size_limit=64, force_parameter_static_shapes=True):
        """Return a one-tuple holding a clone of ``model`` set up for torch.compile.

        The two ``torch._dynamo.config`` values are set process-wide.

        Raises:
            RuntimeError: If collecting the block keys or installing the compile
                wrapper fails; the original exception is chained.
        """
        from comfy_api.torch_helpers import set_torch_compile_wrapper
        m = model.clone()
        diffusion_model = m.get_model_object("diffusion_model")
        torch._dynamo.config.cache_size_limit = dynamo_cache_size_limit
        torch._dynamo.config.force_parameter_static_shapes = force_parameter_static_shapes

        compile_key_list = []

        try:
            if double_blocks:
                for i, _ in enumerate(diffusion_model.double_blocks):
                    compile_key_list.append(f"diffusion_model.double_blocks.{i}")
            if single_blocks:
                for i, _ in enumerate(diffusion_model.single_blocks):
                    compile_key_list.append(f"diffusion_model.single_blocks.{i}")
            # One debug record instead of a print per block, consistent with
            # the logging used elsewhere in this file.
            logging.debug(f"Compile key list: {compile_key_list}")

            set_torch_compile_wrapper(model=m, keys=compile_key_list, backend=backend, mode=mode, dynamic=dynamic, fullgraph=fullgraph)
        except Exception as e:
            raise RuntimeError("Failed to compile model") from e

        return (m, )


class TorchCompileModelWanVideoV2:
    """Deprecated node that torch.compiles a diffusion model or only its ``blocks``."""

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"

    CATEGORY = "KJNodes/torchcompile"
    EXPERIMENTAL = True
    DEPRECATED = True
    DESCRIPTION = "Deprecated, use TorchCompileModelAdvanced instead."

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node inputs: the model and the torch.compile settings."""
        required = {
            "model": ("MODEL",),
            "backend": (["inductor","cudagraphs"], {"default": "inductor"}),
            "fullgraph": ("BOOLEAN", {"default": False, "tooltip": "Enable full graph mode"}),
            "mode": (["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"], {"default": "default"}),
            "dynamic": ("BOOLEAN", {"default": False, "tooltip": "Enable dynamic mode"}),
            "compile_transformer_blocks_only": ("BOOLEAN", {"default": True, "tooltip": "Compile only transformer blocks, faster compile and less error prone"}),
            "dynamo_cache_size_limit": ("INT", {"default": 64, "min": 0, "max": 1024, "step": 1, "tooltip": "torch._dynamo.config.cache_size_limit"}),
        }
        optional = {
            "force_parameter_static_shapes": ("BOOLEAN", {"default": True, "tooltip": "torch._dynamo.config.force_parameter_static_shapes"}),
        }
        return {"required": required, "optional": optional}

    def patch(self, model, backend, fullgraph, mode, dynamic, dynamo_cache_size_limit, compile_transformer_blocks_only, force_parameter_static_shapes=True):
        """Return a one-tuple holding a clone of ``model`` set up for torch.compile.

        The two ``torch._dynamo.config`` values are set process-wide. Any
        failure while building the key list or installing the wrapper is
        re-raised as ``RuntimeError("Failed to compile model")``.
        """
        from comfy_api.torch_helpers import set_torch_compile_wrapper
        m = model.clone()
        diffusion_model = m.get_model_object("diffusion_model")
        dynamo_config = torch._dynamo.config
        dynamo_config.cache_size_limit = dynamo_cache_size_limit
        dynamo_config.force_parameter_static_shapes = force_parameter_static_shapes
        try:
            if not compile_transformer_blocks_only:
                # compile the diffusion model as a single unit
                compile_key_list = ["diffusion_model"]
            else:
                # one key per entry of diffusion_model.blocks
                compile_key_list = [
                    f"diffusion_model.blocks.{i}"
                    for i, _ in enumerate(diffusion_model.blocks)
                ]

            set_torch_compile_wrapper(model=m, keys=compile_key_list, backend=backend, mode=mode, dynamic=dynamic, fullgraph=fullgraph)
        except Exception as e:
            raise RuntimeError("Failed to compile model") from e

        return (m, )


# Set once patch_aimdo_for_compile() has run, so the patch is applied at most once per process.
_aimdo_patched = False

def patch_aimdo_for_compile():
    """Wrap the comfy.ops weight-cast helpers with ``torch._dynamo.disable``.

    Intended to reduce recompiles with dynamic VRAM. Runs at most once per
    process; later calls return immediately. Helpers missing from
    ``comfy.ops`` are skipped, and patching ``comfy_aimdo.torch`` is
    best-effort.
    """
    global _aimdo_patched
    if _aimdo_patched:
        return
    # Marked before patching so a failure part-way through is not retried
    # (and already-wrapped functions are never wrapped twice).
    _aimdo_patched = True
    names = ("cast_bias_weight", "uncast_bias_weight", "cast_modules_with_vbar", "resolve_cast_module_with_vbar")
    for name in names:
        fn = getattr(comfy.ops, name, None)
        if fn is not None:
            setattr(comfy.ops, name, torch._dynamo.disable(fn))
    try:
        import comfy_aimdo.torch as _at
        _at.get_tensor_from_raw_ptr = torch._dynamo.disable(_at.get_tensor_from_raw_ptr)
    except Exception as e:
        # Best-effort: the module or the attribute may be unavailable. Leave a
        # trace for debugging instead of swallowing the error silently.
        logging.debug(f"KJNodes dynamic-compile: comfy_aimdo.torch was not patched: {e}")
    logging.info("KJNodes dynamic-compile: comfy.ops weight cast marked as eager graph break (recompile fix active).")


def skip_torch_compile_dict(guard_entries):
    """Return one bool per guard entry: False when its name mentions "transformer_options".

    Used as a guard filter so that changes to transformer_options do not
    trigger a recompile.
    """
    keep_flags = []
    for guard in guard_entries:
        mentions_options = "transformer_options" in guard.name
        keep_flags.append(not mentions_options)
    return keep_flags


def build_compile_kwargs(backend, mode, fullgraph, dynamic, use_guard_filter=True):
    """Assemble a kwargs dict carrying either ``mode`` or ``options``, never both.

    torch.compile forbids passing mode and options together, so an explicit
    (non-"default") mode wins. Otherwise, when ``use_guard_filter`` is set, the
    guard filter is attached through ``options``.
    """
    compile_kwargs = dict(backend=backend, fullgraph=fullgraph, dynamic=dynamic)
    has_explicit_mode = bool(mode) and mode != "default"
    if has_explicit_mode:
        compile_kwargs["mode"] = mode
        return compile_kwargs
    if use_guard_filter:
        compile_kwargs["options"] = {"guard_filter_fn": skip_torch_compile_dict}
    return compile_kwargs

import weakref as _kj_weakref

# Key under which the KJ compile wrapper is added to (and removed from) a
# model patcher's APPLY_MODEL wrappers.
_KJ_COMPILE_KEY = "torch.compile"
# Registry of compiled submodules, weakly keyed by the `model.model` object
# they were compiled for.
_KJ_COMPILED_BY_MODEL = _kj_weakref.WeakKeyDictionary()  # BaseModel instance -> {key: compiled_module}

# resolve compiled modules by the BaseModel actually executing
def _kj_apply_torch_compile_wrapper(executor, *args, **kwargs):
    """Run ``executor`` with the compiled submodules swapped in for the call.

    The compiled modules are looked up in ``_KJ_COMPILED_BY_MODEL`` by
    ``executor.class_obj``. With no entry the call runs unchanged. Otherwise
    each registered attribute is replaced by its compiled counterpart for the
    duration of the call, and the original objects are restored afterwards
    even if the call raises.
    """
    target = executor.class_obj
    compiled_modules = _KJ_COMPILED_BY_MODEL.get(target)
    if not compiled_modules:
        # nothing compiled for this model -> run eager, no swap
        return executor(*args, **kwargs)

    originals = {}
    try:
        for attr_path, compiled_module in compiled_modules.items():
            # record the original before overwriting so only swapped entries get restored
            originals[attr_path] = comfy.utils.get_attr(target, attr_path)
            comfy.utils.set_attr(target, attr_path, compiled_module)
        return executor(*args, **kwargs)
    finally:
        for attr_path, original_module in originals.items():
            comfy.utils.set_attr(target, attr_path, original_module)


def kj_set_torch_compile_wrapper(model, backend, options=None, mode=None, fullgraph=False, dynamic=None, keys=("diffusion_model",)):
    """Compile the selected model objects and install the KJ compile wrapper on ``model``.

    Args:
        model: Model patcher to modify in place.
        backend, options, mode, fullgraph, dynamic: Forwarded to ``torch.compile``.
        keys: Model-object keys to compile; an empty value falls back to
            ``["diffusion_model"]``.

    Side effects: replaces any wrapper previously registered under the KJ
    compile key, records the compiled modules in ``_KJ_COMPILED_BY_MODEL`` for
    ``model.model``, and stores the compile kwargs in
    ``model.model_options["torch_compile_kwargs"]``.
    """
    apply_model = comfy.patcher_extension.WrappersMP.APPLY_MODEL
    model.remove_wrappers_with_key(apply_model, _KJ_COMPILE_KEY)

    target_keys = keys if keys else ["diffusion_model"]
    compile_kwargs = dict(backend=backend, options=options, mode=mode, fullgraph=fullgraph, dynamic=dynamic)

    compiled_modules = {}
    for key in target_keys:
        compiled_modules[key] = torch.compile(model=model.get_model_object(key), **compile_kwargs)

    # register by the BaseModel that will execute
    _KJ_COMPILED_BY_MODEL[model.model] = compiled_modules
    model.add_wrapper_with_key(apply_model, _KJ_COMPILE_KEY, _kj_apply_torch_compile_wrapper)
    model.model_options["torch_compile_kwargs"] = compile_kwargs


class TorchCompileModelAdvanced:
    """Node that clones a MODEL and registers torch.compile for it (whole model or per block)."""

    @classmethod
    def INPUT_TYPES(cls):
        required = {
            "model": ("MODEL",),
            "backend": (["inductor","cudagraphs"], {"default": "inductor"}),
            "fullgraph": ("BOOLEAN", {"default": False, "tooltip": "Enable full graph mode"}),
            "mode": (["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"], {"default": "default"}),
            "dynamic": (
                ["auto", "true", "false"],
                {"default": "auto", "tooltip": "Use dynamic shape tracing."},
            ),
            "compile_transformer_blocks_only": ("BOOLEAN", {"default": True, "tooltip": "Compile only transformer blocks, faster compile and less error prone"}),
            "dynamo_cache_size_limit": ("INT", {"default": 64, "min": 0, "max": 1024, "step": 1, "tooltip": "torch._dynamo.config.cache_size_limit"}),
            "debug_compile_keys": ("BOOLEAN", {"default": False, "tooltip": "Print the compile keys used for torch.compile"}),
        }
        optional = {
            "disable_dynamic_vram": ("BOOLEAN", {"default": False, "tooltip": "Disable dynamic VRAM feature as it can cause issues with compile"}),
        }
        return {"required": required, "optional": optional}

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"
    CATEGORY = "KJNodes/torchcompile"
    DESCRIPTION = "Advanced torch.compile patching for diffusion models."
    EXPERIMENTAL = True

    def patch(self, model, backend, fullgraph, mode, dynamic, dynamo_cache_size_limit, compile_transformer_blocks_only, debug_compile_keys, disable_dynamic_vram=False):
        """Clone ``model`` and set it up for torch.compile.

        Sets ``torch._dynamo.config.cache_size_limit`` globally, collects the
        attribute paths to compile (individual blocks of known block containers,
        or the whole ``diffusion_model``) and hands them to
        ``kj_set_torch_compile_wrapper``.

        Returns:
            One-element tuple holding the patched clone.

        Raises:
            RuntimeError: wrapping any exception raised while preparing compilation.
        """
        if not disable_dynamic_vram:
            m = model.clone()
        else:
            try:
                m = model.clone(disable_dynamic=True)
            except TypeError:
                # clone() of this ComfyUI version does not accept the keyword.
                logging.warning("This ComfyUI version do not support disabling dynamic VRAM through a node. This may cause issues with torch.compile.")
                m = model.clone()

        diffusion_model = m.get_model_object("diffusion_model")
        torch._dynamo.config.cache_size_limit = dynamo_cache_size_limit

        try:
            compile_key_list = []
            if compile_transformer_blocks_only:
                block_containers = (
                    "double_blocks",
                    "single_blocks",
                    "layers",
                    "transformer_blocks",
                    "blocks",
                    "visual_transformer_blocks",
                    "text_transformer_blocks",
                    "patch_blocks",
                    "pixel_blocks",
                )
                for layer_name in block_containers:
                    if not hasattr(diffusion_model, layer_name):
                        continue
                    block_count = len(getattr(diffusion_model, layer_name))
                    compile_key_list.extend(
                        f"diffusion_model.{layer_name}.{i}" for i in range(block_count)
                    )
                if not compile_key_list:
                    logging.warning("No known transformer blocks found to compile, compiling entire diffusion model instead")
                elif debug_compile_keys:
                    logging.info("TorchCompileModelAdvanced: Compile key list:")
                    for key in compile_key_list:
                        logging.info(f" - {key}")
            if not compile_key_list:
                # Either whole-model compile was requested or no block container matched.
                compile_key_list = ["diffusion_model"]

            # Translate the combo-box string into torch.compile's tri-state flag.
            dynamic_kv = {"true": True, "false": False, "auto": None}
            try:
                dynamic = dynamic_kv[dynamic]
            except KeyError:
                raise ValueError(f"Invalid dynamic arg {dynamic}")

            if not disable_dynamic_vram and getattr(m, "is_dynamic", lambda: False)():
                # reduce recompiles with dynamic VRAM, will break the graph still but better than nothing
                patch_aimdo_for_compile()
            compile_kwargs = build_compile_kwargs(backend, mode, fullgraph, dynamic)
            kj_set_torch_compile_wrapper(model=m, keys=compile_key_list, **compile_kwargs)
        except Exception as e:
            raise RuntimeError("Failed to compile model") from e

        return (m,)


class TorchCompileModelQwenImage:
    """Deprecated torch.compile node for models exposing ``transformer_blocks``."""

    @classmethod
    def INPUT_TYPES(cls):
        required = {
            "model": ("MODEL",),
            "backend": (["inductor","cudagraphs"], {"default": "inductor"}),
            "fullgraph": ("BOOLEAN", {"default": False, "tooltip": "Enable full graph mode"}),
            "mode": (["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"], {"default": "default"}),
            "dynamic": ("BOOLEAN", {"default": False, "tooltip": "Enable dynamic mode"}),
            "compile_transformer_blocks_only": ("BOOLEAN", {"default": True, "tooltip": "Compile only transformer blocks, faster compile and less error prone"}),
            "dynamo_cache_size_limit": ("INT", {"default": 64, "min": 0, "max": 1024, "step": 1, "tooltip": "torch._dynamo.config.cache_size_limit"}),
        }
        return {"required": required}

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"

    CATEGORY = "KJNodes/torchcompile"
    EXPERIMENTAL = True
    DEPRECATED = True
    DESCRIPTION = "Deprecated, use TorchCompileModelAdvanced instead."

    def patch(self, model, backend, fullgraph, mode, dynamic, dynamo_cache_size_limit, compile_transformer_blocks_only):
        """Clone ``model`` and register torch.compile through ComfyUI's helper.

        Sets ``torch._dynamo.config.cache_size_limit`` globally, then compiles
        either every entry of ``diffusion_model.transformer_blocks`` or the
        whole ``diffusion_model``.

        Returns:
            One-element tuple holding the patched clone.

        Raises:
            RuntimeError: wrapping any exception raised while setting up compilation.
        """
        from comfy_api.torch_helpers import set_torch_compile_wrapper

        m = model.clone()
        diffusion_model = m.get_model_object("diffusion_model")
        torch._dynamo.config.cache_size_limit = dynamo_cache_size_limit
        try:
            if not compile_transformer_blocks_only:
                compile_key_list = ["diffusion_model"]
            else:
                # Only the indices are needed to build the attribute paths.
                compile_key_list = [
                    f"diffusion_model.transformer_blocks.{i}"
                    for i, _ in enumerate(diffusion_model.transformer_blocks)
                ]

            set_torch_compile_wrapper(
                model=m,
                keys=compile_key_list,
                backend=backend,
                mode=mode,
                dynamic=dynamic,
                fullgraph=fullgraph,
            )
        except Exception as e:
            raise RuntimeError("Failed to compile model") from e

        return (m,)

class TorchCompileVAE:
    """Node that wraps a VAE's encoder and/or decoder with ``torch.compile`` in place.

    Whether a submodule still needs compiling is decided by inspecting the
    submodule itself (a module returned by ``torch.compile`` exposes
    ``_orig_mod``), not by state stored on the node. The node instance can
    outlive a single VAE, so a node-level "already compiled" flag would let a
    different VAE pass through uncompiled.
    """

    def __init__(self):
        # Informational only: record that this node has processed an
        # encoder/decoder at least once. Kept for backward compatibility;
        # they no longer gate compilation.
        self._compiled_encoder = False
        self._compiled_decoder = False

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
                    "vae": ("VAE",),
                    "backend": (["inductor", "cudagraphs"],),
                    "fullgraph": ("BOOLEAN", {"default": False, "tooltip": "Enable full graph mode"}),
                    "mode": (["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"], {"default": "default"}),
                    "compile_encoder": ("BOOLEAN", {"default": True, "tooltip": "Compile encoder"}),
                    "compile_decoder": ("BOOLEAN", {"default": True, "tooltip": "Compile decoder"}),
                }}
    RETURN_TYPES = ("VAE",)
    FUNCTION = "compile"

    CATEGORY = "KJNodes/torchcompile"
    EXPERIMENTAL = True

    @staticmethod
    def _compile_submodule(first_stage_model, default_name, taesd_name, backend, mode, fullgraph):
        """Replace one submodule of ``first_stage_model`` with its compiled wrapper.

        The ``taesd_name`` attribute is preferred when present, otherwise
        ``default_name`` is used. A submodule that is already a compiled
        wrapper is left untouched so it is never wrapped twice.

        Raises:
            RuntimeError: wrapping any exception raised while compiling.
        """
        name = taesd_name if hasattr(first_stage_model, taesd_name) else default_name
        try:
            module = getattr(first_stage_model, name)
            if hasattr(module, "_orig_mod"):
                return  # already compiled (by this node or another one)
            compiled = torch.compile(module, mode=mode, fullgraph=fullgraph, backend=backend)
            setattr(first_stage_model, name, compiled)
        except Exception as e:
            raise RuntimeError("Failed to compile model") from e

    def compile(self, vae, backend, mode, fullgraph, compile_encoder, compile_decoder):
        """Compile the requested parts of ``vae`` in place.

        Args:
            vae: object exposing ``first_stage_model`` with the encoder/decoder attributes.
            backend, mode, fullgraph: forwarded to ``torch.compile``.
            compile_encoder: compile ``taesd_encoder`` if present, else ``encoder``.
            compile_decoder: compile ``taesd_decoder`` if present, else ``decoder``.

        Returns:
            One-element tuple holding the same ``vae`` object.

        Raises:
            RuntimeError: if compiling a submodule fails.
        """
        if compile_encoder:
            self._compile_submodule(vae.first_stage_model, "encoder", "taesd_encoder", backend, mode, fullgraph)
            self._compiled_encoder = True
        if compile_decoder:
            self._compile_submodule(vae.first_stage_model, "decoder", "taesd_decoder", backend, mode, fullgraph)
            self._compiled_decoder = True
        return (vae, )

class TorchCompileControlNet:
    """Node that replaces ``controlnet.control_model`` with its ``torch.compile`` wrapper.

    Whether compilation is still needed is decided from the control model
    itself (a module returned by ``torch.compile`` exposes ``_orig_mod``)
    rather than from node state, so a different controlnet routed through the
    same node instance is compiled as well and nothing is wrapped twice.
    """

    def __init__(self):
        # Informational only: True after the last call succeeded. Kept for
        # backward compatibility; it no longer gates compilation.
        self._compiled = False

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
                    "controlnet": ("CONTROL_NET",),
                    "backend": (["inductor", "cudagraphs"],),
                    "fullgraph": ("BOOLEAN", {"default": False, "tooltip": "Enable full graph mode"}),
                    "mode": (["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"], {"default": "default"}),
                }}
    RETURN_TYPES = ("CONTROL_NET",)
    FUNCTION = "compile"

    CATEGORY = "KJNodes/torchcompile"
    EXPERIMENTAL = True

    def compile(self, controlnet, backend, mode, fullgraph):
        """Compile the whole control model of ``controlnet`` in place.

        Args:
            controlnet: object exposing a ``control_model`` attribute.
            backend, mode, fullgraph: forwarded to ``torch.compile``.

        Returns:
            One-element tuple holding the same ``controlnet`` object.

        Raises:
            RuntimeError: wrapping any exception raised while compiling.
        """
        try:
            control_model = controlnet.control_model
            # The whole control model is compiled as one unit (an earlier
            # per-``double_blocks`` variant was left disabled in this file).
            if not hasattr(control_model, "_orig_mod"):
                controlnet.control_model = torch.compile(control_model, mode=mode, fullgraph=fullgraph, backend=backend)
            self._compiled = True
        except Exception as e:
            self._compiled = False
            raise RuntimeError("Failed to compile model") from e

        return (controlnet, )

class WanVideoTeaCacheKJ:
    """Deprecated TeaCache node; kept so old workflows load, it passes the model through."""

    @classmethod
    def INPUT_TYPES(cls):
        required = {}
        required["model"] = ("MODEL",)
        required["rel_l1_thresh"] = ("FLOAT", {"default": 0.275, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Threshold for to determine when to apply the cache, compromise between speed and accuracy. When using coefficients a good value range is something between 0.2-0.4 for all but 1.3B model, which should be about 10 times smaller, same as when not using coefficients."})
        required["start_percent"] = ("FLOAT", {"default": 0.1, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "The start percentage of the steps to use with TeaCache."})
        required["end_percent"] = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "The end percentage of the steps to use with TeaCache."})
        required["cache_device"] = (["main_device", "offload_device"], {"default": "offload_device", "tooltip": "Device to cache to"})
        required["coefficients"] = (["disabled", "1.3B", "14B", "i2v_480", "i2v_720"], {"default": "i2v_480", "tooltip": "Coefficients for rescaling the relative l1 distance, if disabled the threshold value should be about 10 times smaller than the value used with coefficients."})
        return {"required": required}

    RETURN_TYPES = ("MODEL",)
    RETURN_NAMES = ("model",)
    FUNCTION = "patch_teacache"
    CATEGORY = "KJNodes/deprecated"
    DEPRECATED = True
    DESCRIPTION = """DEPRECATED, use the native EasyCache or alternative custom node that's up to date instead of this."""
    EXPERIMENTAL = True

    def patch_teacache(self, model, rel_l1_thresh, start_percent, end_percent, cache_device, coefficients):
        """Return ``model`` unchanged; every other argument is ignored."""
        return (model,)


from comfy.ldm.flux.math import apply_rope

def modified_wan_self_attention_forward(self, x, freqs, transformer_options={}):
    r"""
    Self-attention forward that scales its output by the Enhance-A-Video score.

    Reads ``self.num_frames`` and ``self.enhance_weight`` in addition to the
    module's projections and norms.

    Args:
        x(Tensor): hidden states; the first two dims are used as batch and
            sequence length (presumably [B, L, C] - TODO confirm)
        freqs(Tensor): rope freqs, passed straight to ``apply_rope``
        transformer_options(dict): forwarded to the attention function
    """
    batch, seq_len = x.shape[:2]
    heads, head_dim = self.num_heads, self.head_dim

    # q/k are split per head for rope and the score; v stays flat for attention.
    q = self.norm_q(self.q(x)).view(batch, seq_len, heads, head_dim)
    k = self.norm_k(self.k(x)).view(batch, seq_len, heads, head_dim)
    v = self.v(x).view(batch, seq_len, heads * head_dim)

    q, k = apply_rope(q, k, freqs)
    feta_scores = get_feta_scores(q, k, self.num_frames, self.enhance_weight)

    attn_out = comfy.ldm.modules.attention.optimized_attention(
        q.view(batch, seq_len, heads * head_dim),
        k.view(batch, seq_len, heads * head_dim),
        v,
        heads=self.num_heads,
        transformer_options=transformer_options,
    )

    out = self.o(attn_out)
    out *= feta_scores  # in-place scale by the scalar enhance score
    return out

from einops import rearrange
def get_feta_scores(query, key, num_frames, enhance_weight, num_heads=12):
    """Rearrange q/k into per-location temporal form and compute the enhance score.

    Args:
        query, key: either 4-D ``[B, ST, num_heads, head_dim]`` or 3-D
            ``[B, ST, hidden_dim]`` tensors; the 3-D form is split into heads
            using ``num_heads``.
        num_frames: number of frames ``T``; ``ST`` is factored as ``T * S``.
        enhance_weight: forwarded to ``feta_score``.
        num_heads: head count used only for 3-D inputs (for 4-D inputs it is
            taken from the tensor shape).

    Returns:
        The value returned by ``feta_score``.

    Raises:
        ValueError: if ``query`` is neither 3-D nor 4-D.
    """
    img_q, img_k = query, key #torch.Size([2, 9216, 12, 128])

    if img_q.ndim == 4:
        B, ST, num_heads, head_dim = img_q.shape
    elif img_q.ndim == 3:
        B, ST, hidden_dim = img_q.shape
        head_dim = hidden_dim // num_heads

        # Reshape from [B, ST, hidden_dim] to [B, ST, num_heads, head_dim]
        img_q = img_q.view(B, ST, num_heads, head_dim)
        img_k = img_k.view(B, ST, num_heads, head_dim)
    else:
        # Fail with a clear message instead of an UnboundLocalError further down.
        raise ValueError(f"get_feta_scores expects a 3D or 4D query tensor, got {img_q.ndim}D")

    spatial_dim = ST // num_frames

    # Group tokens by spatial location so attention runs across frames only.
    query_image = rearrange(
        img_q, "B (T S) N C -> (B S) N T C", T=num_frames, S=spatial_dim, N=num_heads, C=head_dim
    )
    key_image = rearrange(
        img_k, "B (T S) N C -> (B S) N T C", T=num_frames, S=spatial_dim, N=num_heads, C=head_dim
    )

    return feta_score(query_image, key_image, head_dim, num_frames, enhance_weight)

def feta_score(query_image, key_image, head_dim, num_frames, enhance_weight):
    """Compute the scalar Enhance-A-Video score from temporal attention.

    The softmax attention between frames is computed, its diagonal (a frame
    attending to itself) is zeroed, and the mean of the remaining entries is
    scaled by ``num_frames + enhance_weight`` and clamped to at least 1.

    Args:
        query_image, key_image: tensors whose last two dims are
            ``[num_frames, head_dim]``.
        head_dim: used for the ``head_dim ** -0.5`` attention scale.
        num_frames: number of frames ``T``.
        enhance_weight: added to ``num_frames`` in the final scale.

    Returns:
        0-dim float32 tensor, always >= 1.
    """
    if num_frames == 1:
        # A single frame has no off-diagonal entries; the mean below would be
        # 0/0 = NaN and poison everything multiplied by the score. Return the
        # neutral value the clamp would otherwise enforce.
        return torch.ones((), dtype=torch.float32, device=query_image.device)

    scale = head_dim**-0.5
    query_image = query_image * scale
    attn_temp = query_image @ key_image.transpose(-2, -1)
    attn_temp = attn_temp.to(torch.float32)  # translate attn to float32 before softmax
    attn_temp = attn_temp.softmax(dim=-1)

    # Reshape to [batch_size * num_tokens, num_frames, num_frames]
    attn_temp = attn_temp.reshape(-1, num_frames, num_frames)

    # Mask for diagonal elements; broadcasts over the leading dimension
    diag_mask = torch.eye(num_frames, device=attn_temp.device).bool()

    # Zero out diagonal elements
    attn_wo_diag = attn_temp.masked_fill(diag_mask, 0)

    # Calculate mean for each token's attention matrix
    # Number of off-diagonal elements per matrix is n*n - n
    num_off_diag = num_frames * num_frames - num_frames
    mean_scores = attn_wo_diag.sum(dim=(1, 2)) / num_off_diag

    enhance_scores = mean_scores.mean() * (num_frames + enhance_weight)
    return enhance_scores.clamp(min=1)

import types
class WanAttentionPatch:
    """Descriptor-style factory for a self-attention forward bound to a module.

    Stores the frame count and enhance weight; ``__get__`` returns a bound
    method that copies both onto the module before delegating to
    ``modified_wan_self_attention_forward``.
    """

    def __init__(self, num_frames, weight):
        self.num_frames = num_frames
        self.enhance_weight = weight

    def __get__(self, obj, objtype=None):
        settings = self

        def wrapped_attention(self_module, *args, **kwargs):
            # The forward reads its parameters from the module, so refresh them on every call.
            self_module.num_frames = settings.num_frames
            self_module.enhance_weight = settings.enhance_weight
            return modified_wan_self_attention_forward(self_module, *args, **kwargs)

        return types.MethodType(wrapped_attention, obj)

class WanVideoEnhanceAVideoKJ:
    """Node that patches every block's self-attention forward with the Enhance-A-Video variant."""

    @classmethod
    def INPUT_TYPES(cls):
        required = {
            "model": ("MODEL",),
            "latent": ("LATENT", {"tooltip": "Only used to get the latent count"}),
            "weight": ("FLOAT", {"default": 2.0, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Strength of the enhance effect"}),
        }
        return {"required": required}

    RETURN_TYPES = ("MODEL",)
    RETURN_NAMES = ("model",)
    FUNCTION = "enhance"
    CATEGORY = "KJNodes/wan"
    DESCRIPTION = "https://github.com/NUS-HPC-AI-Lab/Enhance-A-Video"
    EXPERIMENTAL = True

    def enhance(self, model, weight, latent):
        """Return a clone of ``model`` with patched self-attention forwards.

        A ``weight`` of 0 returns the input model untouched. Otherwise the
        frame count is read from dim 2 of ``latent["samples"]`` and an object
        patch is added for ``self_attn.forward`` of every entry in
        ``diffusion_model.blocks``. When the original model carries
        ``compile_settings``, each patched forward is wrapped in
        ``torch.compile`` with those settings.
        """
        if weight == 0:
            return (model,)

        num_frames = latent["samples"].shape[2]

        model_clone = model.clone()
        transformer_options = model_clone.model_options.setdefault("transformer_options", {})
        transformer_options["enhance_weight"] = weight
        diffusion_model = model_clone.get_model_object("diffusion_model")

        compile_settings = getattr(model.model, "compile_settings", None)
        for idx, block in enumerate(diffusion_model.blocks):
            patch = WanAttentionPatch(num_frames, weight)
            patched_attn = patch.__get__(block.self_attn, block.__class__)
            if compile_settings is not None:
                patched_attn = torch.compile(
                    patched_attn,
                    mode=compile_settings["mode"],
                    dynamic=compile_settings["dynamic"],
                    fullgraph=compile_settings["fullgraph"],
                    backend=compile_settings["backend"],
                )

            model_clone.add_object_patch(f"diffusion_model.blocks.{idx}.self_attn.forward", patched_attn)

        return (model_clone,)

try:
    from comfy.ldm.lightricks.model import apply_rotary_emb
except ImportError:
    apply_rotary_emb = None

try:
    from comfy.ldm.lightricks.model import GuideAttentionMask as _GuideAttentionMask, _attention_with_guide_mask as _ltx_attn_with_guide_mask
except ImportError:
    _GuideAttentionMask = None
    _ltx_attn_with_guide_mask = None


def ltxv_feta_forward(self, x, context=None, mask=None, pe=None, k_pe=None, transformer_options={}):
    """Attention forward whose output is scaled by the Enhance-A-Video score.

    Reads ``self.num_frames`` and ``self.enhance_weight`` in addition to the
    module's own projections, norms and optional gate.

    Args:
        x: query-side hidden states.
        context: key/value source; ``x`` is used when it is None.
        mask: None, a guide attention mask object, or a regular attention mask;
            selects which attention routine is called.
        pe: rotary embedding applied to q (and to k unless ``k_pe`` is given).
        k_pe: optional separate rotary embedding for k.
        transformer_options: forwarded to the attention routine.

    Returns:
        ``self.to_out(attention_output) * feta_scores``.
    """
    q = self.to_q(x)
    if context is None:
        context = x
    k = self.to_k(context)
    v = self.to_v(context)

    q = self.q_norm(q)
    k = self.k_norm(k)

    if pe is not None:
        q = apply_rotary_emb(q, pe)
        k = apply_rotary_emb(k, pe if k_pe is None else k_pe)

    feta_scores = get_feta_scores(q, k, self.num_frames, self.enhance_weight, self.heads)

    # Pick the attention routine matching the kind of mask supplied.
    if mask is None:
        attn_out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
    elif _GuideAttentionMask is not None and isinstance(mask, _GuideAttentionMask):
        attn_out = _ltx_attn_with_guide_mask(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options)
    else:
        attn_out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options)

    if self.to_gate_logits is not None:
        gate_logits = self.to_gate_logits(x)  # (B, T, H)
        batch, tokens, _ = attn_out.shape
        per_head = attn_out.view(batch, tokens, self.heads, self.dim_head)
        gates = 2.0 * torch.sigmoid(gate_logits)  # zero-init -> identity
        per_head = per_head * gates.unsqueeze(-1)
        attn_out = per_head.view(batch, tokens, self.heads * self.dim_head)

    return self.to_out(attn_out) * feta_scores


class LTXCrossAttentionPatch:
    """Descriptor-style factory for an attention forward bound to a module.

    Stores the frame count and enhance weight; ``__get__`` returns a bound
    method that copies both onto the module before delegating to
    ``ltxv_feta_forward``.
    """

    def __init__(self, num_frames, weight):
        self.num_frames = num_frames
        self.enhance_weight = weight

    def __get__(self, obj, objtype=None):
        settings = self

        def wrapped_attention(self_module, *args, **kwargs):
            # The forward reads its parameters from the module, so refresh them on every call.
            self_module.num_frames = settings.num_frames
            self_module.enhance_weight = settings.enhance_weight
            return ltxv_feta_forward(self_module, *args, **kwargs)

        return types.MethodType(wrapped_attention, obj)

class LTXVEnhanceAVideoKJ:
    """Node that patches ``attn1.forward`` of every transformer block with the Enhance-A-Video variant."""

    @classmethod
    def INPUT_TYPES(cls):
        required = {
            "model": ("MODEL",),
            "latent": ("LATENT", {"tooltip": "Only used to get the latent count"}),
            "weight": ("FLOAT", {"default": 4.0, "min": 0.0, "max": 100.0, "step": 0.001, "tooltip": "Strength of the enhance effect"}),
        }
        return {"required": required}

    RETURN_TYPES = ("MODEL",)
    RETURN_NAMES = ("model",)
    FUNCTION = "enhance"
    CATEGORY = "KJNodes/ltxv"
    DESCRIPTION = "https://github.com/NUS-HPC-AI-Lab/Enhance-A-Video"
    EXPERIMENTAL = True

    def enhance(self, model, weight, latent):
        """Return a clone of ``model`` with patched ``attn1`` forwards.

        A ``weight`` of 0 returns the input model untouched. Otherwise the
        frame count is read from dim 2 of ``latent["samples"]`` and an object
        patch is added for ``attn1.forward`` of every entry in
        ``diffusion_model.transformer_blocks``.
        """
        if weight == 0:
            return (model,)

        num_frames = latent["samples"].shape[2]

        model_clone = model.clone()
        transformer_options = model_clone.model_options.setdefault("transformer_options", {})
        transformer_options["enhance_weight"] = weight
        diffusion_model = model_clone.get_model_object("diffusion_model")

        for idx, block in enumerate(diffusion_model.transformer_blocks):
            patch = LTXCrossAttentionPatch(num_frames, weight)
            patched_attn1 = patch.__get__(block.attn1, block.__class__)
            model_clone.add_object_patch(f"diffusion_model.transformer_blocks.{idx}.attn1.forward", patched_attn1)
        return (model_clone,)

def _wan_compute_attention(self, query, context, transformer_options={}):
    """Attend ``query`` over keys/values projected from ``context``.

    Keys go through ``self.k`` and ``self.norm_k``, values through ``self.v``;
    the attention result is flattened from dim 2 onwards before being returned.
    """
    keys = self.norm_k(self.k(context))
    values = self.v(context)
    attended = comfy.ldm.modules.attention.optimized_attention(
        query,
        keys,
        values,
        heads=self.num_heads,
        transformer_options=transformer_options,
    )
    return attended.flatten(2)

def wan_nag_attention(self, query, context_positive, nag_context, transformer_options={}):
    """Run the same query against the positive and the NAG context.

    Returns:
        Tuple ``(x_positive, x_negative)`` of the two attention outputs, the
        positive one being computed first.
    """
    return (
        _wan_compute_attention(self, query, context_positive, transformer_options),
        _wan_compute_attention(self, query, nag_context, transformer_options),
    )

def normalized_attention_guidance(self, x_positive, x_negative):
    """Blend positive attention with a norm-limited extrapolation away from the negative.

    Steps, using ``self.nag_scale``, ``self.nag_tau`` and ``self.nag_alpha``:

    1. guidance = ``x_positive * scale - x_negative * (scale - 1)``
    2. where the L1 norm (last dim) of the guidance exceeds ``tau`` times the
       L1 norm of ``x_positive``, the guidance is rescaled down to that limit
    3. result = ``guidance * alpha + x_positive * (1 - alpha)``

    When ``self.inplace`` is true the computation reuses ``x_negative``'s
    storage, so that tensor is overwritten.

    Returns:
        Tensor with the shape of ``x_positive``.
    """
    in_place = self.inplace
    if in_place:
        # Build the guidance inside x_negative's storage.
        x_negative.mul_(self.nag_scale - 1)
        x_negative.neg_()
        guided = x_negative.add_(x_positive, alpha=self.nag_scale)
        del x_negative
    else:
        scaled_negative = x_negative * (self.nag_scale - 1)
        del x_negative
        guided = (x_positive * self.nag_scale).sub_(scaled_negative)

    l1_positive = torch.norm(x_positive, p=1, dim=-1, keepdim=True)
    l1_guided = torch.norm(guided, p=1, dim=-1, keepdim=True)

    ratio = l1_guided / l1_positive
    # 0/0 ratios become 10.0 instead of NaN before the threshold comparison.
    torch.nan_to_num_(ratio, nan=10.0)
    exceeds_tau = ratio > self.nag_tau
    del ratio

    rescale = (l1_positive * self.nag_tau) / (l1_guided + 1e-7)
    del l1_positive, l1_guided

    guided.mul_(torch.where(exceeds_tau, rescale, 1.0))
    del exceeds_tau, rescale

    if in_place:
        return guided.sub_(x_positive).mul_(self.nag_alpha).add_(x_positive)

    guided.mul_(self.nag_alpha)
    return guided.add_(x_positive * (1 - self.nag_alpha))

#region NAG
def wan_crossattn_forward_nag(self, x, context, transformer_options={}, **kwargs):
    """Cross-attention forward with normalized attention guidance on the positive part.

    ``self.input_type`` selects how the batch is interpreted:

    * ``"default"``: a context batch of 1 is treated as positive only;
      otherwise ``x`` and ``context`` are split in two halves along dim 0,
      the first being positive and the second negative.
    * ``"batch"``: the whole batch is positive and ``self.nag_context`` is
      repeated along dim 0 to match it.

    The positive part is attended against both its own context and
    ``self.nag_context`` and combined by ``normalized_attention_guidance``; a
    negative half, when present, goes through plain cross-attention. Extra
    keyword arguments are accepted and ignored.

    Returns:
        ``self.o`` applied to the (re-concatenated) attention output.
    """
    input_type = self.input_type

    # Determine batch splitting and context handling
    if input_type == "default":
        if context.shape[0] == 1:
            # Single conditioning: nothing to split off.
            x_pos, context_pos, x_neg, context_neg = x, context, None, None
        else:
            # [pos, neg] pair stacked along the batch dimension.
            x_pos, x_neg = torch.chunk(x, 2, dim=0)
            context_pos, context_neg = torch.chunk(context, 2, dim=0)
    elif input_type == "batch":
        # Standard batch, no CFG
        x_pos, context_pos, x_neg, context_neg = x, context, None, None

    # Positive branch
    q_pos = self.norm_q(self.q(x_pos))
    nag_context = self.nag_context
    if input_type == "batch":
        nag_context = nag_context.repeat(x_pos.shape[0], 1, 1)
    del x_pos

    x_positive, x_negative = wan_nag_attention(self, q_pos, context_pos, nag_context, transformer_options=transformer_options)
    del context_pos, q_pos

    x_pos_out = normalized_attention_guidance(self, x_positive, x_negative)
    del x_positive, x_negative

    # Negative branch
    if x_neg is None or context_neg is None:
        x = x_pos_out
    else:
        q_neg = self.norm_q(self.q(x_neg))
        k_neg = self.norm_k(self.k(context_neg))
        v_neg = self.v(context_neg)
        x_neg_out = comfy.ldm.modules.attention.optimized_attention(q_neg, k_neg, v_neg, heads=self.num_heads, transformer_options=transformer_options)
        x = torch.cat([x_pos_out, x_neg_out], dim=0)

    return self.o(x)

def wan_i2v_crossattn_forward_nag(self, x, context, context_img_len, transformer_options={}, **kwargs):
    """Image-to-video cross-attention forward with normalized attention guidance.

    The first ``context_img_len`` tokens of ``context`` are the image context
    and are attended by the full ``x`` through the ``*_img`` projections. The
    remaining text context is handled like this: when its batch size is 2,
    ``x`` and the text context are split in halves (positive, negative); the
    positive half gets normalized attention guidance against
    ``self.nag_context`` and the negative half plain cross-attention. Extra
    keyword arguments are accepted and ignored.

    Returns:
        ``self.o`` applied to the sum of the text and image attention outputs.
    """
    image_context = context[:, :context_img_len]
    text_context = context[:, context_img_len:]

    # Image branch: always computed on the unsplit input.
    img_x = comfy.ldm.modules.attention.optimized_attention(
        self.norm_q(self.q(x)),
        self.norm_k_img(self.k_img(image_context)),
        self.v_img(image_context),
        heads=self.num_heads,
        transformer_options=transformer_options,
    )
    del image_context

    if text_context.shape[0] == 2:
        x, x_real_negative = torch.chunk(x, 2, dim=0)
        context_positive, context_negative = torch.chunk(text_context, 2, dim=0)
    else:
        context_positive, context_negative = text_context, None

    q = self.norm_q(self.q(x))

    x_positive, x_negative = wan_nag_attention(self, q, context_positive, self.nag_context, transformer_options=transformer_options)
    del q, context_positive
    x = normalized_attention_guidance(self, x_positive, x_negative)
    del x_positive, x_negative

    if context_negative is not None:
        # Real negative half: ordinary cross-attention, then restore the batch layout.
        x_real_negative = comfy.ldm.modules.attention.optimized_attention(
            self.norm_q(self.q(x_real_negative)),
            self.norm_k(self.k(context_negative)),
            self.v(context_negative),
            heads=self.num_heads,
            transformer_options=transformer_options,
        )
        x = torch.cat([x, x_real_negative], dim=0)

    return self.o(x + img_x)


class WanCrossAttentionPatch:
    """Descriptor-style factory for a NAG cross-attention forward bound to a module.

    Holds the NAG context and parameters; ``__get__`` returns a bound method
    that copies them onto the module and then calls the i2v or the regular
    NAG cross-attention forward, depending on ``i2v``.
    """

    def __init__(self, context, nag_scale, nag_alpha, nag_tau, i2v=False, input_type="default", inplace=True):
        self.nag_context = context
        self.nag_scale = nag_scale
        self.nag_alpha = nag_alpha
        self.nag_tau = nag_tau
        self.i2v = i2v
        self.input_type = input_type
        self.inplace = inplace

    def __get__(self, obj, objtype=None):
        settings = self

        def wrapped_attention(self_module, *args, **kwargs):
            # The forwards read their parameters from the module, so refresh them on every call.
            self_module.nag_context = settings.nag_context
            self_module.nag_scale = settings.nag_scale
            self_module.nag_alpha = settings.nag_alpha
            self_module.nag_tau = settings.nag_tau
            self_module.input_type = settings.input_type
            self_module.inplace = settings.inplace
            if settings.i2v:
                return wan_i2v_crossattn_forward_nag(self_module, *args, **kwargs)
            return wan_crossattn_forward_nag(self_module, *args, **kwargs)

        return types.MethodType(wrapped_attention, obj)


class WanVideoNAG:
    """Node that patches every block's cross-attention forward with normalized attention guidance."""

    @classmethod
    def INPUT_TYPES(cls):
        required = {
            "model": ("MODEL",),
            "conditioning": ("CONDITIONING",),
            "nag_scale": ("FLOAT", {"default": 11.0, "min": 0.0, "max": 100.0, "step": 0.001, "tooltip": "Strength of negative guidance effect"}),
            "nag_alpha": ("FLOAT", {"default": 0.25, "min": 0.0, "max": 1.0, "step": 0.001, "tooltip": "Mixing coefficient in that controls the balance between the normalized guided representation and the original positive representation."}),
            "nag_tau": ("FLOAT", {"default": 2.5, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Clipping threshold that controls how much the guided attention can deviate from the positive attention."}),
        }
        optional = {
            "input_type": (["default", "batch"], {"tooltip": "Type of the model input"}),
            "inplace": ("BOOLEAN", {"default": False, "tooltip": "If true, modifies tensors in place to save memory. Leads to different numerical results which may change the output slightly."}),
        }
        return {"required": required, "optional": optional}

    RETURN_TYPES = ("MODEL",)
    RETURN_NAMES = ("model",)
    FUNCTION = "patch"
    CATEGORY = "KJNodes/wan"
    DESCRIPTION = "https://github.com/ChenDarYen/Normalized-Attention-Guidance"
    EXPERIMENTAL = True

    def patch(self, model, conditioning, nag_scale, nag_alpha, nag_tau, input_type="default", inplace=False):
        """Return a clone of ``model`` whose cross-attention forwards apply NAG.

        A ``nag_scale`` of 0 returns the input model untouched. Otherwise the
        first conditioning tensor is run through the model's ``text_embedding``
        (which is moved to the torch device first) to obtain the NAG context,
        and an object patch is added for ``cross_attn.forward`` of every entry
        in ``diffusion_model.blocks``. The i2v forward is used when the model
        config's class name contains ``"WAN21_I2V"``.
        """
        if nag_scale == 0:
            return (model,)

        device = mm.get_torch_device()
        dtype = mm.unet_dtype()

        model_clone = model.clone()
        diffusion_model = model_clone.get_model_object("diffusion_model")

        # Embed the guidance prompt once; every block shares the same context tensor.
        diffusion_model.text_embedding.to(device)
        context = diffusion_model.text_embedding(conditioning[0][0].to(device, dtype))

        config_name = type(model.model.model_config).__name__
        i2v = "WAN21_I2V" in config_name

        for idx, block in enumerate(diffusion_model.blocks):
            patch = WanCrossAttentionPatch(context, nag_scale, nag_alpha, nag_tau, i2v, input_type=input_type, inplace=inplace)
            patched_attn = patch.__get__(block.cross_attn, block.__class__)
            model_clone.add_object_patch(f"diffusion_model.blocks.{idx}.cross_attn.forward", patched_attn)

        return (model_clone,)
    
class SkipLayerGuidanceWanVideo:
    """Deprecated node registering a block replacement that skips the uncond pass on chosen blocks."""

    @classmethod
    def INPUT_TYPES(cls):
        required = {
            "model": ("MODEL", ),
            "blocks": ("STRING", {"default": "10", "multiline": False}),
            "start_percent": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 1.0, "step": 0.001}),
            "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}),
        }
        return {"required": required}

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "slg"
    EXPERIMENTAL = True
    DESCRIPTION = "Simplified skip layer guidance that only skips the uncond on selected blocks"
    DEPRECATED = True
    CATEGORY = "advanced/guidance"

    def slg(self, model, start_percent, end_percent, blocks):
        """Clone ``model`` and register the skip function for the listed blocks.

        Args:
            model: model patcher whose ``model_options["transformer_options"]`` exists.
            start_percent, end_percent: inclusive range of
                ``transformer_options["current_percent"]`` in which skipping is active.
            blocks: comma-separated block indices.

        Returns:
            One-element tuple holding the patched clone.
        """
        def skip(args, extra_args):
            transformer_options = extra_args.get("transformer_options", {})
            original_block = extra_args["original_block"]

            if not transformer_options:
                raise ValueError("transformer_options not found in extra_args, currently SkipLayerGuidanceWanVideo only works with TeaCacheKJ")

            in_window = start_percent <= transformer_options["current_percent"] <= end_percent
            if not in_window:
                return original_block(args)

            if args["img"].shape[0] != 2:
                # Separate cond/uncond passes: run the block only when this pass is [0].
                if transformer_options.get("cond_or_uncond") == [0]:
                    return original_block(args)
                return args

            # Batch of two: keep entry 0 of img as is and run the block on entry 1 only.
            prev_img_uncond = args["img"][0].unsqueeze(0)
            new_args = {name: args[name][1].unsqueeze(0) for name in ("img", "txt", "vec", "pe")}

            block_out = original_block(new_args)

            return {
                "img": torch.cat([prev_img_uncond, block_out["img"]], dim=0),
                "txt": args["txt"],
                "vec": args["vec"],
                "pe": args["pe"]
            }

        blocks = [int(entry.strip()) for entry in blocks.split(",")]
        logging.info(f"Selected blocks to skip uncond on: {blocks}")

        m = model.clone()

        for b in blocks:
            #m.set_model_patch_replace(skip, "dit", "double_block", b)
            # Copy each nesting level before writing so dicts that may be shared are not mutated.
            options = m.model_options["transformer_options"].copy()
            replacements = options["patches_replace"].copy() if "patches_replace" in options else {}
            dit_patches = replacements["dit"].copy() if "dit" in replacements else {}

            dit_patches[("double_block", b)] = skip
            replacements["dit"] = dit_patches
            options["patches_replace"] = replacements
            m.model_options["transformer_options"] = options

        return (m, )

class CFGZeroStarAndInit:
    """Patch a model's CFG function with CFG-Zero* scaling and optional zero-init.

    The installed CFG function rescales the unconditional prediction by its
    projection coefficient onto the conditional prediction before applying
    the guidance scale, and can return an all-zero prediction for the first
    sampling steps.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {"required": { 
                    "model": ("MODEL",),
                    "use_zero_init": ("BOOLEAN", {"default": True}),
                    "zero_init_steps": ("INT", {"default": 0, "min": 0, "tooltip": "for zero init, starts from 0 so first step is always zeroed out if use_zero_init enabled"}),
                }}
    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"
    DESCRIPTION = "https://github.com/WeichenFan/CFG-Zero-star"
    CATEGORY = "KJNodes/experimental"
    EXPERIMENTAL = True

    def patch(self, model, use_zero_init, zero_init_steps):
        """Return a 1-tuple with a clone of ``model`` using the CFG-Zero* function.

        Args:
            model: object exposing ``clone()`` and
                ``set_model_sampler_cfg_function()``.
            use_zero_init: when true, steps with index <= ``zero_init_steps``
                return ``cond * 0``.
            zero_init_steps: last step index (inclusive) that is zeroed out.
        """
        def cfg_zerostar(args):
            cond = args["cond"]
            timestep = args["timestep"]
            sigmas = args["model_options"]["transformer_options"]["sample_sigmas"]

            # Locate the current step: exact sigma match first, otherwise the
            # first interval [sigmas[i], sigmas[i + 1]] that brackets the timestep.
            matched_step_index = (sigmas == timestep[0]).nonzero()
            if len(matched_step_index) > 0:
                # A schedule can contain the same sigma more than once; use the
                # first occurrence instead of failing on a multi-element .item().
                current_step_index = matched_step_index[0].item()
            else:
                for i in range(len(sigmas) - 1):
                    if (sigmas[i] - timestep[0]) * (sigmas[i + 1] - timestep[0]) <= 0:
                        current_step_index = i
                        break
                else:
                    # No bracketing interval found: treat as the first step.
                    current_step_index = 0

            # zero init
            if (current_step_index <= zero_init_steps) and use_zero_init:
                return cond * 0

            uncond = args["uncond"]
            cond_scale = args["cond_scale"]

            batch_size = cond.shape[0]

            # reshape (not view) so non-contiguous predictions are accepted too.
            positive_flat = cond.reshape(batch_size, -1)
            negative_flat = uncond.reshape(batch_size, -1)

            # Per-sample projection coefficient <cond, uncond> / ||uncond||^2;
            # the epsilon guards against an all-zero uncond.
            dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
            squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8
            alpha = dot_product / squared_norm
            # Broadcast alpha over every non-batch dimension of cond.
            alpha = alpha.reshape(batch_size, *([1] * (len(cond.shape) - 1)))

            noise_pred = uncond * alpha + cond_scale * (cond - uncond * alpha)
            return noise_pred

        m = model.clone()
        m.set_model_sampler_cfg_function(cfg_zerostar)
        return (m, )

class GGUFLoaderKJ(io.ComfyNode):
    """Load a GGUF diffusion model through ComfyUI-GGUF with extra options.

    Besides the main model it can merge a second state dict (another GGUF
    file or a "connector" checkpoint) into it, configure the dequant/patch
    dtypes of the GGML linear op, override the attention implementation and
    toggle fp16 accumulation.
    """

    @classmethod
    def define_schema(cls):
        """Build the node schema, listing the selectable model files."""
        # Get GGUF models safely, fallback to empty list if unet_gguf folder doesn't exist
        try:
            gguf_models = folder_paths.get_filename_list("unet_gguf")
            ltx2_connector_models = folder_paths.get_filename_list("text_encoders")
            ltx2_connector_models = [m for m in ltx2_connector_models if "connector" in m.lower()]
        except KeyError:
            gguf_models = []
            ltx2_connector_models = []

        return io.Schema(
            node_id="GGUFLoaderKJ",
            category="KJNodes/model_loaders",
            description="Loads a GGUF model with advanced options, requires [ComfyUI-GGUF](https://github.com/city96/ComfyUI-GGUF) to be installed.",
            is_experimental=True,
            inputs=[
                io.Combo.Input("model_name", options=gguf_models),
                io.Combo.Input("extra_model_name", options=gguf_models + ltx2_connector_models + ["none"], default="none", tooltip="An extra gguf model to load and merge into the main model, for example VACE module"),
                io.Combo.Input("dequant_dtype", options=["default", "target", "float32", "float16", "bfloat16"], default="default"),
                io.Combo.Input("patch_dtype", options=["default", "target", "float32", "float16", "bfloat16"], default="default"),
                io.Boolean.Input("patch_on_device", default=False),
                io.Boolean.Input("enable_fp16_accumulation", default=False, tooltip="Enable torch.backends.cuda.matmul.allow_fp16_accumulation, required minimum pytorch version 2.7.1"),
                io.Combo.Input("attention_override", options=["none", "sdpa", "sageattn", "xformers", "flashattn"], default="none", tooltip="Overrides the used attention implementation, requires the respective library to be installed"),

            ],
            outputs=[io.Model.Output(),],
        )

    # The functions below are deliberately plain functions (not methods): they
    # are stored in ATTENTION_OVERRIDES and installed as the
    # "optimized_attention_override" transformer option. Each one ignores the
    # attention function it is handed (``func``) and forwards the call to the
    # ``__wrapped__`` attribute of a fixed implementation instead.
    def attention_override_pytorch(func, *args, **kwargs):
        new_attention = comfy.ldm.modules.attention.attention_pytorch
        return new_attention.__wrapped__(*args, **kwargs)
    def attention_override_sage(func, *args, **kwargs):
        new_attention = comfy.ldm.modules.attention.attention_sage
        return new_attention.__wrapped__(*args, **kwargs)
    def attention_override_xformers(func, *args, **kwargs):
        new_attention = comfy.ldm.modules.attention.attention_xformers
        return new_attention.__wrapped__(*args, **kwargs)
    def attention_override_flash(func, *args, **kwargs):
        new_attention = comfy.ldm.modules.attention.attention_flash
        return new_attention.__wrapped__(*args, **kwargs)

    # Maps the "attention_override" combo value to its override function;
    # "none" is intentionally absent so no override gets installed for it.
    ATTENTION_OVERRIDES = {
        "sdpa": attention_override_pytorch,
        "sageattn": attention_override_sage,
        "xformers": attention_override_xformers,
        "flashattn": attention_override_flash,
    }


    @classmethod
    def _get_gguf_module(cls):
        """Return the ComfyUI-GGUF module, importing it if needed.

        Already-imported modules are searched first (they must expose both
        ``ops`` and ``nodes``), then a list of candidate module names is tried.

        Raises:
            ImportError: if no candidate can be imported.
        """
        for key, mod in sys.modules.items():
            if key.endswith("ComfyUI-GGUF") or key.endswith("comfyui-gguf"):
                if hasattr(mod, "ops") and hasattr(mod, "nodes"):
                    return mod

        gguf_path = os.path.join(folder_paths.folder_names_and_paths["custom_nodes"][0][0], "ComfyUI-GGUF")
        for module_name in ["ComfyUI-GGUF", "custom_nodes.ComfyUI-GGUF", "comfyui-gguf", "custom_nodes.comfyui-gguf", gguf_path, gguf_path.lower()]:
            try:
                module = importlib.import_module(module_name)
                return module
            except ImportError:
                continue

        raise ImportError(
            "Compatible ComfyUI-GGUF not found. "
            "Please install/update from: https://github.com/city96/ComfyUI-GGUF"
        )

    @staticmethod
    def _load_gguf_state_dict(gguf_nodes, path):
        """Load ``path`` with ``gguf_sd_loader`` and return ``(state_dict, extra)``.

        The loader is called exactly once. A tuple result is unpacked as
        ``(state_dict, extra)``; any other result is treated as a bare state
        dict and paired with an empty ``extra`` mapping. Checking the type
        explicitly avoids relying on the exception raised by unpacking a
        dict (a ValueError, which a TypeError handler would not catch).
        """
        result = gguf_nodes.loader.gguf_sd_loader(path)
        if isinstance(result, tuple):
            state_dict, extra = result
            return state_dict, extra
        return result, {}

    @classmethod
    def execute(cls, model_name, extra_model_name, dequant_dtype, patch_dtype, patch_on_device, attention_override, enable_fp16_accumulation):
        """Load the model, apply the selected options and return it as node output.

        Raises:
            ImportError: ComfyUI-GGUF could not be located.
            ValueError: ``extra_model_name`` is neither a ``.gguf`` file nor a
                connector checkpoint.
            RuntimeError: the model type could not be detected, or fp16
                accumulation was requested but is unavailable in this torch.
        """
        gguf_nodes = cls._get_gguf_module()
        ops = gguf_nodes.ops.GGMLOps()

        def set_linear_dtype(attr, value):
            # "default" -> None, "target" -> the literal string, anything else
            # -> the torch dtype of that name.
            if value == "default":
                setattr(ops.Linear, attr, None)
            elif value == "target":
                setattr(ops.Linear, attr, value)
            else:
                setattr(ops.Linear, attr, getattr(torch, value))

        set_linear_dtype("dequant_dtype", dequant_dtype)
        set_linear_dtype("patch_dtype", patch_dtype)

        # init model
        model_path = folder_paths.get_full_path("unet", model_name)
        sd, extra = cls._load_gguf_state_dict(gguf_nodes, model_path)

        if extra_model_name is not None and extra_model_name != "none":
            if extra_model_name.endswith(".gguf"):
                extra_model_full_path = folder_paths.get_full_path("unet", extra_model_name)
                extra_model, _ = cls._load_gguf_state_dict(gguf_nodes, extra_model_full_path)
            elif "connector" in extra_model_name.lower():
                extra_model_full_path = folder_paths.get_full_path("text_encoders", extra_model_name)
                extra_model = comfy.utils.load_torch_file(extra_model_full_path)
                # Strip the "model.diffusion_model." prefix when present so the
                # keys line up with the main state dict; keep the original
                # dict if stripping would leave nothing.
                diffusion_model_prefix = comfy.model_detection.unet_prefix_from_state_dict(extra_model)
                if diffusion_model_prefix == "model.diffusion_model.":
                    temp_sd = comfy.utils.state_dict_prefix_replace(extra_model, {diffusion_model_prefix: ""}, filter_keys=True)
                    if len(temp_sd) > 0:
                        extra_model = temp_sd
            else:
                raise ValueError("Extra model must also be a .gguf file")
            # Keys of the extra model overwrite same-named keys of the main one.
            sd.update(extra_model)

        model = comfy.sd.load_diffusion_model_state_dict(
            sd, model_options={"custom_operations": ops}, metadata=extra.get("metadata", {})
        )
        if model is None:
            raise RuntimeError(f"ERROR: Could not detect model type of: {model_path}")

        model = gguf_nodes.nodes.GGUFModelPatcher.clone(model)
        model.patch_on_device = patch_on_device

        # attention override
        if attention_override in cls.ATTENTION_OVERRIDES:
            model.model_options["transformer_options"]["optimized_attention_override"] = cls.ATTENTION_OVERRIDES[attention_override]

        # NOTE: allow_fp16_accumulation is a process-wide torch flag, so it is
        # explicitly reset to False when the option is disabled.
        if enable_fp16_accumulation:
            if hasattr(torch.backends.cuda.matmul, "allow_fp16_accumulation"):
                torch.backends.cuda.matmul.allow_fp16_accumulation = True
            else:
                raise RuntimeError("Failed to set fp16 accumulation, requires pytorch version 2.7.1 or higher")
        else:
            if hasattr(torch.backends.cuda.matmul, "allow_fp16_accumulation"):
                torch.backends.cuda.matmul.allow_fp16_accumulation = False

        return io.NodeOutput(model,)

try:
    from torch.nn.attention.flex_attention import flex_attention, BlockMask
except ImportError:
    flex_attention = None
    BlockMask = None

class NABLA_AttentionKJ():
    """Node that installs a NABLA sparse-attention override on a cloned model."""

    RETURN_TYPES = ("MODEL", )
    FUNCTION = "patch"
    DESCRIPTION = "Experimental node for patching attention mode to use NABLA sparse attention for video models, currently only works with Kadinsky5"
    CATEGORY = "KJNodes/experimental"

    @classmethod
    def INPUT_TYPES(s):
        required = {
            "model": ("MODEL",),
            "latent": ("LATENT", {"tooltip": "Only used to get the latent shape"}),
            "window_time": ("INT", {"default": 11, "min": 1, "tooltip": "Temporal attention window size"}),
            "window_width": ("INT", {"default": 3, "min": 1, "tooltip": "Spatial attention window size"}),
            "window_height": ("INT", {"default": 3, "min": 1, "tooltip": "Spatial attention window size"}),
            "sparsity": ("FLOAT", {"default": 0.9, "min": 0.0, "max": 1.0, "step": 0.01}),
            "torch_compile": ("BOOLEAN", {"default": True, "tooltip": "Most likely required for reasonable memory usage"}),
        }
        return {"required": required}

    def patch(self, model, latent, window_time, window_width, window_height, sparsity, torch_compile):
        """Clone ``model`` and route its attention through ``NABLA_Attention``.

        The sparse parameters are derived once from the shape of
        ``latent["samples"]``. Raises RuntimeError when flex attention could
        not be imported from torch.
        """
        flex_available = flex_attention is not None and BlockMask is not None
        if not flex_available:
            raise RuntimeError("can't import flex_attention from torch.nn.attention, requires newer pytorch version")

        model_clone = model.clone()

        sparse_params = get_sparse_params(latent["samples"], window_time, window_height, window_width, sparsity)
        nabla_attention = NABLA_Attention(sparse_params)

        def attention_override_nabla(func, *args, **kwargs):
            # The attention function passed in by the caller is ignored.
            return nabla_attention(*args, **kwargs)

        override = attention_override_nabla
        if torch_compile:
            override = torch.compile(override, mode="max-autotune-no-cudagraphs", dynamic=True)

        # attention override
        transformer_options = model_clone.model_options["transformer_options"]
        transformer_options["optimized_attention_override"] = override

        return (model_clone,)


class NABLA_Attention():
    """Callable attention replacement that builds a block-sparse mask per call."""

    def __init__(self, sparse_params):
        # Mapping produced by get_sparse_params; "sta_mask" and "P" are read here.
        self.sparse_params = sparse_params

    def __call__(self, q, k, v, heads, **kwargs):
        """Run sparse flex attention, or the regular attention for short sequences."""
        too_short = q.shape[-2] < 3000 or k.shape[-2] < 3000
        if too_short:
            return optimized_attention(q, k, v, heads, **kwargs)

        params = self.sparse_params
        block_mask = self.nablaT_v2(q, k, params["sta_mask"], thr=params["P"])
        attended = flex_attention(q, k, v, block_mask=block_mask)
        # Swap dims 1 and 2, then merge the last two dims into one.
        return attended.transpose(1, 2).contiguous().flatten(-2, -1)

    def nablaT_v2(self, q, k, sta, thr=0.9):
        """Build a BlockMask from block-averaged attention scores combined with ``sta``.

        ``q`` must be 4-D (unpacked as B, h, S, D) and ``k`` must reshape to
        the same block layout; S is assumed to be a multiple of the block
        size (64) — TODO confirm against callers.
        """
        # Map estimation
        BLOCK_SIZE = 64
        B, h, S, D = q.shape
        n_blocks = S // BLOCK_SIZE
        q_blocks = q.reshape(B, h, n_blocks, BLOCK_SIZE, D).mean(-2)
        k_blocks = k.reshape(B, h, n_blocks, BLOCK_SIZE, D).mean(-2)
        scores = q_blocks @ k_blocks.transpose(-2, -1)
        probs = torch.softmax(scores / math.sqrt(D), dim=-1)

        # Map binarization: sort ascending, accumulate in place, keep entries
        # whose cumulative mass reaches 1 - thr, then undo the sort.
        sorted_probs, sort_order = probs.sort(-1)
        cumulative = sorted_probs.cumsum_(-1)
        keep_sorted = (cumulative >= 1 - thr).int()
        keep = keep_sorted.gather(-1, sort_order.argsort(-1))

        keep = torch.logical_or(keep, sta)

        # BlockMask creation
        kv_count = keep.sum(-1).to(torch.int32)
        kv_order = keep.argsort(dim=-1, descending=True).to(torch.int32)
        return BlockMask.from_kv_blocks(
            torch.zeros_like(kv_count),
            kv_order,
            kv_count,
            kv_order,
            BLOCK_SIZE=BLOCK_SIZE,
            mask_mod=None,
        )

def fast_sta_nabla(T, H, W, wT=3, wH=3, wW=3, device=None):
    """Build a boolean sliding-window mask over a flattened (T, H, W) grid.

    Two positions are marked True when their index distance is at most
    ``wT // 2`` along T, ``wH // 2`` along H and ``wW // 2`` along W.

    Args:
        T, H, W: grid sizes.
        wT, wH, wW: window sizes along each axis.
        device: device the mask is built on; defaults to
            ``mm.get_torch_device()`` when None.

    Returns:
        Bool tensor of shape ``(T * H * W, T * H * W)``; rows and columns are
        indexed by the position flattened in (T, H, W) order.
    """
    if device is None:
        device = mm.get_torch_device()

    # One |i - j| distance table sized for the largest axis, sliced per axis below.
    size = max(T, H, W)
    r = torch.arange(0, size, 1, dtype=torch.int16, device=device)
    mat = (r.unsqueeze(1) - r.unsqueeze(0)).abs()
    sta_t, sta_h, sta_w = (
        mat[:T, :T].flatten(),
        mat[:H, :H].flatten(),
        mat[:W, :W].flatten(),
    )
    sta_t = sta_t <= wT // 2
    sta_h = sta_h <= wH // 2
    sta_w = sta_w <= wW // 2
    # Outer products of the per-axis masks; the transposes regroup the axes
    # from (h1, h2, w1, w2) to (h1, w1, h2, w2), and likewise for T vs HW.
    sta_hw = (sta_h.unsqueeze(1) * sta_w.unsqueeze(0)).reshape(H, H, W, W).transpose(1, 2).flatten()
    sta = (sta_t.unsqueeze(1) * sta_hw.unsqueeze(0)).reshape(T, T, H * W, H * W).transpose(1, 2)
    return sta.reshape(T * H * W, T * H * W)


def get_sparse_params(x, wT, wH, wW, sparsity=0.9):
    """Assemble the parameter dict consumed by the NABLA attention override.

    ``x`` must be 5-D; its last three sizes are divided by the patch size
    (1, 2, 2) to get the visual shape. The window mask is built on a grid
    whose H and W are further divided by 8.
    """
    _batch, _channels, frames, height, width = x.shape
    patch_t, patch_h, patch_w = 1, 2, 2
    T = frames // patch_t
    H = height // patch_h
    W = width // patch_w

    sta_mask = fast_sta_nabla(T, H // 8, W // 8, wT, wH, wW)
    # Prepend two singleton dims in place.
    sta_mask = sta_mask.unsqueeze_(0).unsqueeze_(0)

    return {
        "sta_mask": sta_mask,
        "to_fractal": True,
        "P": sparsity,
        "wT": wT,
        "wH": wH,
        "wW": wW,
        "add_sta": True,
        "visual_shape": (T, H, W),
        "method": "topcdf",
    }

from comfy.comfy_types.node_typing import IO
class StartRecordCUDAMemoryHistory():
    """Pass-through node that starts recording CUDA memory allocation history."""

    # NOTE(review): DESCRIPTION says the node always runs, but the IS_CHANGED
    # hook below is commented out — confirm which behaviour is intended.
    # @classmethod
    # def IS_CHANGED(s):
    #     return True

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "input": (IO.ANY,),
                "enabled": (["all", "state", "None"], {"default": "all", "tooltip": "None: disable, 'state': keep info for allocated memory, 'all': keep history of all alloc/free calls"}),
                "context": (["all", "state", "alloc", "None"], {"default": "all", "tooltip": "None: no tracebacks, 'state': tracebacks for allocated memory, 'alloc': for alloc calls, 'all': for free calls"}),
                "stacks": (["python", "all"], {"default": "all", "tooltip": "'python': Python/TorchScript/inductor frames, 'all': also C++ frames"}),
                "max_entries": ("INT", {"default": 100000, "min": 1000, "max": 10000000, "tooltip": "Maximum number of entries to record"}),
            },
        }

    RETURN_TYPES = (IO.ANY, )
    # One name per entry in RETURN_TYPES: this node only returns its input.
    RETURN_NAMES = ("input",)
    FUNCTION = "start"
    CATEGORY = "KJNodes/memory"
    DESCRIPTION = "THIS NODE ALWAYS RUNS. Starts recording CUDA memory allocation history, can be ended and saved with EndRecordCUDAMemoryHistory. "

    def start(self, input, enabled, context, stacks, max_entries):
        """Reset peak memory stats, start recording and return ``input`` unchanged.

        The combo value "None" for ``enabled`` / ``context`` is translated to
        a real ``None`` before being passed to torch.
        """
        mm.soft_empty_cache()
        torch.cuda.reset_peak_memory_stats(mm.get_torch_device())
        torch.cuda.memory._record_memory_history(
            max_entries=max_entries,
            enabled=enabled if enabled != "None" else None,
            context=context if context != "None" else None,
            stacks=stacks
        )
        return input,

class EndRecordCUDAMemoryHistory():
    """Pass-through node that dumps the recorded CUDA memory history to a file."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "input": (IO.ANY,),
            # The tooltip belongs inside the options dict; as a third tuple
            # element it was not part of the input options.
            "output_path": ("STRING", {"default": "comfy_cuda_memory_history", "tooltip": "Base path for saving the CUDA memory history file, timestamp and .pt extension will be added"}),
        },
        }

    RETURN_TYPES = (IO.ANY, "STRING",)
    RETURN_NAMES = ("input", "output_path",)
    FUNCTION = "end"
    CATEGORY = "KJNodes/memory"
    DESCRIPTION = "Records CUDA memory allocation history between start and end, saves to a file that can be analyzed here: https://docs.pytorch.org/memory_viz or with VisualizeCUDAMemoryHistory node"

    def end(self, input, output_path):
        """Dump the snapshot, stop recording and return ``(input, written_path)``.

        The written path is ``output_path`` followed directly by a
        ``YYYYmmdd_HHMMSS`` timestamp and ``.pt``.
        """
        mm.soft_empty_cache()
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        output_path = f"{output_path}{timestamp}.pt"
        torch.cuda.memory._dump_snapshot(output_path)
        # Passing enabled=None turns recording off again.
        torch.cuda.memory._record_memory_history(enabled=None)
        return input, output_path

try:
    from server import PromptServer
except ImportError:
    PromptServer = None

class VisualizeCUDAMemoryHistory():
    """Output node that renders a CUDA memory snapshot file to an HTML trace plot."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "snapshot_path": ("STRING", ),
        },
         "hidden": {
                "unique_id": "UNIQUE_ID",
            },
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("output_path",)
    FUNCTION = "visualize"
    CATEGORY = "KJNodes/memory"
    DESCRIPTION = "Visualizes a CUDA memory allocation history file, opens in browser"
    OUTPUT_NODE = True

    def visualize(self, snapshot_path, unique_id):
        """Write the trace plot under ``<output>/memory_history`` and return its view URL.

        Args:
            snapshot_path: path of a pickled memory snapshot.
            unique_id: node id used to attach the URL as progress text.

        Returns:
            1-tuple with the URL of the generated HTML file.
        """
        import pickle
        from torch.cuda import _memory_viz
        import uuid

        from folder_paths import get_output_directory
        output_dir = get_output_directory()

        # SECURITY: pickle.load executes arbitrary code from the file; only
        # point this node at snapshot files you created yourself.
        with open(snapshot_path, "rb") as f:
            snapshot = pickle.load(f)

        html = _memory_viz.trace_plot(snapshot)
        html_filename = f"cuda_memory_history_{uuid.uuid4().hex}.html"
        output_path = os.path.join(output_dir, "memory_history", html_filename)
        os.makedirs(os.path.dirname(output_path), exist_ok=True)

        with open(output_path, "w", encoding="utf-8") as f:
            f.write(html)

        # NOTE(review): host and port are hard-coded; the link is wrong when
        # the server listens elsewhere.
        api_url = f"http://localhost:8188/api/view?type=output&filename={html_filename}&subfolder=memory_history"

        # Progress UI (best effort: a failure here must not fail the node)
        if unique_id and PromptServer is not None:
            try:
                PromptServer.instance.send_progress_text(
                    api_url,
                    unique_id
                )
            except Exception as e:
                logging.warning(f"VisualizeCUDAMemoryHistory: could not send progress text: {e}")

        return api_url,


class ModelMemoryUseReportPatch:
    """Node that attaches callbacks logging peak CUDA memory around a run."""

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"
    DESCRIPTION = "Adds callbacks to model to report memory usage during after sampling"
    EXPERIMENTAL = True
    CATEGORY = "KJNodes/memory"

    @classmethod
    def INPUT_TYPES(s):
        required = {"model": ("MODEL",)}
        return {"required": required}

    def patch(self, model):
        """Return a 1-tuple with a clone of ``model`` carrying the two callbacks."""
        model_clone = model.clone()
        device = mm.get_torch_device()
        bytes_per_gb = 1024**3

        def reset_mem_usage(model):
            # Registered for ON_PRE_RUN: start from fresh peak statistics.
            torch.cuda.reset_peak_memory_stats(device)

        def report_mem_usage(model):
            # Registered for ON_CLEANUP. The local names are part of the log
            # text because of the "{name=}" f-string form.
            max_memory = torch.cuda.max_memory_allocated(device) / bytes_per_gb
            max_reserved = torch.cuda.max_memory_reserved(device) / bytes_per_gb
            logging.info(f"Sampling max allocated memory: {max_memory=:.3f} GB")
            logging.info(f"Sampling max reserved memory: {max_reserved=:.3f} GB")

        registrations = (
            (CallbacksMP.ON_PRE_RUN, reset_mem_usage),
            (CallbacksMP.ON_CLEANUP, report_mem_usage),
        )
        for call_type, callback in registrations:
            model_clone.add_callback(call_type, callback)

        return (model_clone,)


class MemoryUsageFactorAdjustWrapper:
    """Wrapper that swaps ``model.memory_usage_factor`` while the executor runs."""

    def __init__(self, memory_usage_factor, original_factor):
        # Factor applied during the wrapped call / factor written back afterwards.
        self.original_factor = original_factor
        self.memory_usage_factor = memory_usage_factor

    def __call__(self, executor, model, noise_shape: torch.Tensor, *args, **kwargs):
        """Call ``executor`` on a clone of ``model`` with the override in place.

        The original factor is written back in a ``finally`` block, so it is
        restored even when the executor raises.
        """
        patched = model.clone()
        patched.model.memory_usage_factor = self.memory_usage_factor
        logging.info(f"Temporarily set memory usage factor to {self.memory_usage_factor}")
        try:
            return executor(patched, noise_shape, *args, **kwargs)
        finally:
            logging.info(f"Model memory usage calculated, restoring original memory usage factor: {self.original_factor}")
            patched.model.memory_usage_factor = self.original_factor

class ModelMemoryUsageFactorOverride:
    """Node that registers a PREPARE_SAMPLING wrapper overriding the memory usage factor."""

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"
    DESCRIPTION = "Overrides the memory usage factor of the model during sampling."
    EXPERIMENTAL = True
    CATEGORY = "KJNodes/memory"

    @classmethod
    def INPUT_TYPES(s):
        required = {
            "model": ("MODEL",),
            "memory_usage_factor": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.001}),
        }
        return {"required": required}

    def patch(self, model, memory_usage_factor):
        """Return a 1-tuple with a clone of ``model`` carrying the wrapper.

        The factor currently set on the model is captured so the wrapper can
        write it back after each wrapped call.
        """
        model_clone = model.clone()
        original_memory_usage_factor = model_clone.model.memory_usage_factor
        logging.info(f"Original memory usage factor: {original_memory_usage_factor}")

        model_clone.add_wrapper_with_key(
            comfy.patcher_extension.WrappersMP.PREPARE_SAMPLING,
            "memory_usage_factor_adjust_prepare_sampling",
            MemoryUsageFactorAdjustWrapper(memory_usage_factor, original_memory_usage_factor),
        )
        return (model_clone,)

def wan_ffn_chunked_forward(self, x):
    """Run the ``torch.nn.Sequential`` forward of ``self`` on ``x``, chunked along dim 1.

    Chunking only happens when ``x.shape[1]`` exceeds ``self.dim_threshold``;
    ``x`` is then split into ``self.num_chunks`` pieces along dim 1, each
    piece is run separately and the outputs are concatenated along dim 1.
    """
    run_sequential = torch.nn.Sequential.forward

    if x.shape[1] <= self.dim_threshold:
        return run_sequential(self, x)

    pieces = [run_sequential(self, part) for part in torch.chunk(x, self.num_chunks, dim=1)]
    return torch.cat(pieces, dim=1)

class WanffnChunkPatch:
    """Descriptor-style factory for a chunked forward bound to a given module."""

    def __init__(self, num_chunks, dim_threshold=4096):
        self.dim_threshold = dim_threshold
        self.num_chunks = num_chunks

    def __get__(self, obj, objtype=None):
        """Return ``wrapped_forward`` bound to ``obj``; ``objtype`` is unused."""
        patch = self

        def wrapped_forward(self_module, *args, **kwargs):
            # The settings are copied onto the module on every call because
            # wan_ffn_chunked_forward reads them from the module itself.
            self_module.num_chunks = patch.num_chunks
            self_module.dim_threshold = patch.dim_threshold
            return wan_ffn_chunked_forward(self_module, *args, **kwargs)

        return types.MethodType(wrapped_forward, obj)

class WanChunkFeedForward(io.ComfyNode):
    """Node that patches every block's ``ffn.forward`` with a chunked forward."""

    @classmethod
    def define_schema(cls):
        node_inputs = [
            io.Model.Input("model"),
            io.Int.Input("chunks", default=2, min=1, max=100, step=1, tooltip="Number of chunks to split the feedforward activations into to reduce peak VRAM usage."),
            io.Int.Input("dim_threshold", default=4096, min=1024, max=16384, step=256, tooltip="Dimension threshold above which to apply chunking."),
        ]
        node_outputs = [
            io.Model.Output(display_name="model"),
        ]
        return io.Schema(
            node_id="WanChunkFeedForward",
            display_name="Wan Chunk FeedForward",
            category="KJNodes/wan",
            description="EXPERIMENTAL AND MAY CHANGE THE MODEL OUTPUT!! Chunks feedforward activations to reduce peak VRAM usage.",
            is_experimental=True,
            inputs=node_inputs,
            outputs=node_outputs,
        )

    @classmethod
    def execute(cls, model, chunks, dim_threshold) -> io.NodeOutput:
        """Return the model with chunked FFN forwards; unchanged when ``chunks`` is 1."""
        if chunks == 1:
            # A single chunk is the regular forward, so nothing is patched.
            return io.NodeOutput(model)

        model_clone = model.clone()
        transformer_blocks = model_clone.get_model_object("diffusion_model").blocks

        for idx, block in enumerate(transformer_blocks):
            chunk_patch = WanffnChunkPatch(chunks, dim_threshold)
            bound_forward = chunk_patch.__get__(block.ffn, block.__class__)
            model_clone.add_object_patch(f"diffusion_model.blocks.{idx}.ffn.forward", bound_forward)

        return io.NodeOutput(model_clone)


# Ideogram4 peak-VRAM patches (FFN sequence chunking + bf16 RoPE)
from comfy.ldm.lumina.model import FeedForward as _Ideogram4FeedForward
from comfy.ldm.modules.attention import optimized_attention_masked as _ideogram4_attn


def ideogram4_ffn_chunked_forward(self, x):
    """Run the FeedForward forward of ``self`` on ``x``, chunked along dim 1.

    Chunking is used only when ``x.shape[1]`` exceeds
    ``self.kj_dim_threshold`` and ``self.kj_num_chunks`` is greater than 1;
    the per-chunk outputs are concatenated back along dim 1.
    """
    # x: (B, L, dim) — dim 1 is treated as the token dimension.
    feed_forward = _Ideogram4FeedForward.forward
    should_chunk = x.shape[1] > self.kj_dim_threshold and self.kj_num_chunks > 1
    if not should_chunk:
        return feed_forward(self, x)

    pieces = []
    for part in torch.chunk(x, self.kj_num_chunks, dim=1):
        pieces.append(feed_forward(self, part))
    return torch.cat(pieces, dim=1)


class Ideogram4FFNChunkPatch:
    """Descriptor-style factory for a chunked feed-forward bound to a given module."""

    def __init__(self, num_chunks, dim_threshold):
        self.dim_threshold = dim_threshold
        self.num_chunks = num_chunks

    def __get__(self, obj, objtype=None):
        """Return ``wrapped_forward`` bound to ``obj``; ``objtype`` is unused."""
        patch = self

        def wrapped_forward(self_module, *args, **kwargs):
            # Settings live on the module under kj_-prefixed names because
            # ideogram4_ffn_chunked_forward reads them from there.
            self_module.kj_num_chunks = patch.num_chunks
            self_module.kj_dim_threshold = patch.dim_threshold
            return ideogram4_ffn_chunked_forward(self_module, *args, **kwargs)

        return types.MethodType(wrapped_forward, obj)


def _ideogram4_apply_rope_lowp(xq, xk, freqs_cis):
    # (bf16/fp16) instead of being upcast to fp32 -> ~halves RoPE activation memory.
    cos = freqs_cis[0].to(xq.dtype)
    sin = freqs_cis[1].to(xq.dtype)
    nsin = freqs_cis[2].to(xq.dtype)

    q_embed = xq * cos
    qs = q_embed.shape[-1] // 2
    q_embed[..., :qs].addcmul_(xq[..., qs:], nsin)
    q_embed[..., qs:].addcmul_(xq[..., :qs], sin)

    k_embed = xk * cos
    ks = k_embed.shape[-1] // 2
    k_embed[..., :ks].addcmul_(xk[..., ks:], nsin)
    k_embed[..., ks:].addcmul_(xk[..., :ks], sin)
    return q_embed, k_embed


def ideogram4_attention_lowp_rope_forward(self, x, attn_mask, freqs_cis, transformer_options={}):
    """Attention forward that applies RoPE without upcasting q/k.

    ``x`` must be 3-D (unpacked as batch, sequence, features). The fused
    ``self.qkv`` projection is split into q, k and v, q and k are normalised
    and rotated, and the masked attention result goes through ``self.o``.
    """
    bsz, n_tokens, _ = x.shape

    fused = self.qkv(x).view(bsz, n_tokens, 3, self.num_heads, self.head_dim)
    q, k, v = fused.unbind(dim=2)

    # Normalise on the per-head layout, then swap dims 1 and 2 for attention.
    q = self.norm_q(q).transpose(1, 2)
    k = self.norm_k(k).transpose(1, 2)
    v = v.transpose(1, 2)

    q, k = _ideogram4_apply_rope_lowp(q, k, freqs_cis)
    attn_out = _ideogram4_attn(
        q, k, v, self.num_heads, attn_mask,
        skip_reshape=True,
        transformer_options=transformer_options,
    )
    return self.o(attn_out)


class Ideogram4RopePatch:
    """Descriptor-style factory binding the low-precision RoPE forward to a module."""

    def __get__(self, obj, objtype=None):
        """Return ``ideogram4_attention_lowp_rope_forward`` bound to ``obj``."""
        bound_forward = types.MethodType(ideogram4_attention_lowp_rope_forward, obj)
        return bound_forward


class Ideogram4OptimizationsKJ(io.ComfyNode):
    """Node that patches per-layer feed-forward and attention forwards to cut peak VRAM."""

    @classmethod
    def define_schema(cls):
        node_inputs = [
            io.Model.Input("model"),
            io.Boolean.Input("chunk_ffn", default=True,
                             tooltip="Chunk the feedforward activations over the sequence dim to cap the (B, L, hidden) intermediate."),
            io.Int.Input("ffn_chunks", default=2, min=1, max=64, step=1,
                         tooltip="Number of chunks to split the feedforward sequence into. More chunks = lower peak, slightly more overhead."),
            io.Int.Input("ffn_seq_threshold", default=1024, min=256, max=65536, step=256,
                         tooltip="Only chunk when the token sequence length exceeds this (skips chunking for tiny sequences)."),
            io.Boolean.Input("bf16_rope", default=True,
                             tooltip="Apply RoPE in the input dtype instead of fp32. ~Halves RoPE activation memory; matches the HF reference dtype."),
        ]
        node_outputs = [
            io.Model.Output(display_name="model"),
        ]
        return io.Schema(
            node_id="Ideogram4OptimizationsKJ",
            display_name="Ideogram4 Optimizations KJ",
            category="KJNodes/experimental",
            description="EXPERIMENTAL AND MAY CHANGE THE MODEL OUTPUT!! Reduces peak VRAM of the Ideogram4 forward. "
                        "chunk_ffn splits the SwiGLU activations over the token dim; bf16_rope applies RoPE in the model "
                        "dtype instead of upcasting to fp32. Both target the two largest transient tensors in the block.",
            is_experimental=True,
            inputs=node_inputs,
            outputs=node_outputs,
        )

    @classmethod
    def execute(cls, model, chunk_ffn, ffn_chunks, ffn_seq_threshold, bf16_rope) -> io.NodeOutput:
        """Return a patched clone of ``model``, or ``model`` itself when nothing applies.

        The input model is returned unchanged when both options are off, or
        when ``diffusion_model.layers`` is missing/empty or its first layer
        lacks ``feed_forward`` / ``attention`` (a warning is logged then).
        """
        nothing_requested = not (chunk_ffn or bf16_rope)
        if nothing_requested:
            return io.NodeOutput(model)

        m = model.clone()
        diffusion_model = m.get_model_object("diffusion_model")

        layers = getattr(diffusion_model, "layers", None)
        looks_like_ideogram4 = (
            bool(layers)
            and hasattr(layers[0], "feed_forward")
            and hasattr(layers[0], "attention")
        )
        if not looks_like_ideogram4:
            logging.warning("Ideogram4OptimizationsKJ: model does not look like Ideogram4 "
                            "(expected diffusion_model.layers[*].feed_forward/.attention); returning model unchanged.")
            return io.NodeOutput(model)

        # A single chunk is the regular forward, so FFN patching is skipped for it.
        patch_ffn = chunk_ffn and ffn_chunks > 1

        for idx, block in enumerate(layers):
            if patch_ffn:
                ffn_module = block.feed_forward
                patched_ffn = Ideogram4FFNChunkPatch(ffn_chunks, ffn_seq_threshold).__get__(ffn_module, ffn_module.__class__)
                m.add_object_patch(f"diffusion_model.layers.{idx}.feed_forward.forward", patched_ffn)
            if bf16_rope:
                attn_module = block.attention
                patched_attn = Ideogram4RopePatch().__get__(attn_module, attn_module.__class__)
                m.add_object_patch(f"diffusion_model.layers.{idx}.attention.forward", patched_attn)

        return io.NodeOutput(m)


from comfy.samplers import KSAMPLER
from comfy.k_diffusion.sampling import to_d

def sample_selfrefinevideo(model, x, sigmas, stochastic_step_map, certain_percentage=0.999, uncertainty_threshold=0.25, extra_args=None, callback=None, disable=None, verbose=False, video_shape=None, seed=None):
    """Sampler loop with optional repeated "re-noise and re-predict" substeps.

    For each pair ``(sigmas[i], sigmas[i + 1])`` the model is called and the latent
    is advanced with ``x + (sigma_next - sigma) * to_d(x, sigma, denoised)``.
    When ``stochastic_step_map`` assigns ``k > 0`` to step ``i``, that step is run
    up to ``k + 1`` times: every pass after the first rebuilds its input as
    ``(1 - sigma) * previous_denoised + sigma * noise``. Positions whose prediction
    changed by less than ``uncertainty_threshold`` between passes keep their
    previous result, and the substeps stop early once the fraction of such
    positions exceeds ``certain_percentage``.

    Args:
        model: Callable ``model(x, sigma_batch, **extra_args)`` returning the denoised prediction.
        x: Starting latent tensor.
        sigmas: Noise levels; ``len(sigmas) - 1`` steps are executed.
        stochastic_step_map: Dict mapping step index -> number of extra passes
            (missing or ``0`` means a single plain step).
        certain_percentage: Fraction of "certain" positions above which the
            remaining substeps of a step are skipped.
        uncertainty_threshold: Positions with uncertainty below this are "certain".
        extra_args: Extra keyword arguments forwarded to ``model``.
        callback: Optional callable invoked after every model call with a dict
            holding ``x``, ``i``, ``sigma``, ``sigma_hat`` and ``denoised``.
        disable: Forwarded to ``tqdm(disable=...)``.
        verbose: When true, diagnostics are written with ``tqdm.write``.
        video_shape: Optional shape; when given and the latent is 3-D, the first
            ``prod(video_shape[1:])`` entries of the last axis are treated as the
            video part (masking is applied only there) and the rest as audio.
        seed: Seed for the CPU noise generator. ``None`` draws the noise from
            torch's default generator instead.

    Returns:
        The latent tensor after the last step.
    """
    extra_args = {} if extra_args is None else extra_args
    sigma_in = x.new_ones([x.shape[0]])

    # Always bind the generator: with seed=None torch.randn receives generator=None
    # and uses the default RNG instead of failing on an unbound local.
    generator = None
    if seed is not None:
        generator = torch.Generator(torch.device("cpu")).manual_seed(seed)

    total_steps = len(sigmas) - 1

    # The context manager closes the bar even if the model call raises.
    with tqdm(total=total_steps, disable=disable, desc="Sampling") as pbar:
        for i in range(total_steps):

            # Get stochastic steps for this noise level
            current_num_anneal_steps = stochastic_step_map.get(i, 0)
            use_stochastic = current_num_anneal_steps > 0
            m = current_num_anneal_steps + 1 if use_stochastic else 1

            sigma, sigma_next = sigmas[i], sigmas[i + 1]

            # Per-step state carried from one substep to the next.
            prev_certain_mask = None
            prev_denoised = None
            prev_denoised_full = None
            prev_x_next = None
            prev_x_next_video = None
            is_certain = False

            for ii in range(m):
                if m > 1:
                    pbar.set_description(f"Step {i}/{total_steps} (substep {ii+1}/{m})")
                # Early exit if certain threshold reached
                if is_certain:
                    x = prev_x_next
                    break

                # Determine input. Noise is drawn on every substep (it is only used
                # for ii > 0) so the position in the seeded random stream depends
                # solely on the number of model calls.
                noise = torch.randn(x.shape, device=torch.device("cpu"), generator=generator).to(x)
                x_in = x if ii == 0 else (1.0 - sigma) * prev_denoised_full + sigma * noise
                if ii > 0:
                    x = x_in

                denoised = model(x_in, sigmas[i] * sigma_in, **extra_args)

                if callback is not None:
                    callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})

                # Compute next latents
                d = to_d(x, sigma, denoised)
                x_next = x + (sigma_next - sigma) * d

                # Separate video and audio if joint model
                if d.ndim == 3 and video_shape is not None:
                    cut = math.prod(video_shape[1:])
                    denoised_video = denoised[:, :, :cut].reshape([denoised.shape[0]] + list(video_shape)[1:])
                    x_next_video = x_next[:, :, :cut].reshape([denoised.shape[0]] + list(video_shape)[1:])
                    denoised_audio = denoised[:, :, cut:]
                    x_next_audio = x_next[:, :, cut:]
                    if verbose:
                        tqdm.write(f"Video shape: {denoised_video.shape}, Audio shape: {denoised_audio.shape}")
                else:
                    denoised_video = denoised
                    x_next_video = x_next
                    denoised_audio = None
                    x_next_audio = None

                # Stochastic sampling with uncertainty masking
                if use_stochastic and prev_denoised is not None:
                    # Uncertainty: L2 norm over dim 1 of the change between consecutive
                    # predictions, divided by the size of dim 1 (video part only).
                    diff = denoised_video - prev_denoised
                    uncertainty = torch.sqrt(torch.sum(diff ** 2, dim=1)) / denoised_video.shape[1]
                    certain_mask = uncertainty < uncertainty_threshold

                    if verbose:
                        tqdm.write(f"Step {i}/{total_steps} substep {ii+1}/{m}:")
                        tqdm.write(f"Uncertainty: min {uncertainty.min():.4f}, max {uncertainty.max():.4f}, threshold {uncertainty_threshold}")
                        tqdm.write(f"Certain pixels: {certain_mask.sum()}/{certain_mask.numel()} = {certain_mask.sum()/certain_mask.numel():.4f}")

                    # Update certain mask (union with previous)
                    if prev_certain_mask is not None:
                        certain_mask = certain_mask | prev_certain_mask

                    # Check certainty threshold; the flag takes effect at the top of the next substep
                    if certain_mask.sum() / certain_mask.numel() > certain_percentage:
                        is_certain = True
                        if verbose:
                            tqdm.write(f"{ii}/{current_num_anneal_steps}: Certain region is more than {certain_percentage}, we are certain")

                    # Apply masking to video: certain positions keep the previous substep's values
                    certain_mask_float = certain_mask.float().unsqueeze(1)
                    x_next_video = certain_mask_float * prev_x_next_video + (1.0 - certain_mask_float) * x_next_video
                    denoised_video = certain_mask_float * prev_denoised + (1.0 - certain_mask_float) * denoised_video

                    # Reconstruct full latents by replacing the video portion
                    if x_next_audio is not None:
                        # Flatten masked video back to match original format and replace video portion
                        x_next = x_next.clone()
                        x_next[:, :, :cut] = x_next_video.reshape([x_next_video.shape[0], x_next.shape[1], -1])
                        # Also reconstruct full denoised for next iteration input
                        denoised_full = denoised.clone()
                        denoised_full[:, :, :cut] = denoised_video.reshape([denoised_video.shape[0], denoised.shape[1], -1])
                    else:
                        # No audio separation
                        x_next = x_next_video
                        denoised_full = denoised_video

                    prev_certain_mask = certain_mask
                    prev_denoised = denoised_video
                    prev_denoised_full = denoised_full
                    prev_x_next_video = x_next_video
                    prev_x_next = x_next
                elif use_stochastic:
                    # For first stochastic step, create denoised_full if we have audio
                    if x_next_audio is not None:
                        denoised_full = denoised.clone()
                        denoised_full[:, :, :cut] = denoised_video.reshape([denoised_video.shape[0], denoised.shape[1], -1])
                    else:
                        denoised_full = denoised_video

                    prev_certain_mask = None
                    prev_denoised = denoised_video
                    prev_denoised_full = denoised_full
                    prev_x_next_video = x_next_video
                    prev_x_next = x_next

                # Update x for final step
                if use_stochastic and ii == m - 1:
                    x = prev_x_next
                elif not use_stochastic:
                    x = x_next

            pbar.update(1)
            if m == 1:
                pbar.set_description("Sampling")
    return x

class SamplerSelfRefineVideo(io.ComfyNode):
    """Node that builds a ``KSAMPLER`` around ``sample_selfrefinevideo``.

    The per-step refinement plan can be entered as one or two numeric ranges, or
    as a free-form string such as ``"2-5:3,6-14:1"``.
    """

    @classmethod
    def define_schema(cls):
        """Describe the node: a dynamic combo for the step plan plus sampler settings."""
        default_ranges = [
            (2, 5, 3),   # Range 1
            (6, 14, 1),  # Range 2
        ]

        def build_range_inputs(count):
            # Widgets for ranges 1..count; new input objects are created on every
            # call so the combo options do not share instances.
            inputs = []
            for i, (start_default, end_default, steps_default) in enumerate(default_ranges[:count], start=1):
                inputs.extend([
                    io.Int.Input(f"start_step{i}", default=start_default, min=0, max=999, step=1, tooltip=f"Start step for range {i}"),
                    io.Int.Input(f"end_step{i}", default=end_default, min=0, max=999, step=1, tooltip=f"End step for range {i}"),
                    io.Int.Input(f"steps_{i}", default=steps_default, min=1, max=100, step=1, tooltip=f"Number of P&P steps for range {i}"),
                ])
            return inputs

        options = [
            # Option 1: 2 ranges
            io.DynamicCombo.Option(key="2 ranges", inputs=build_range_inputs(2)),
            # Option 2: 1 range
            io.DynamicCombo.Option(key="1 range", inputs=build_range_inputs(1)),
            # Option 3: Manual string input
            io.DynamicCombo.Option(
                key="from_string",
                inputs=[
                    io.String.Input(
                        "stochastic_plan",
                        default="2-5:3,6-14:1",
                        multiline=True,
                        tooltip="Format: 'start-end:steps,start-end:steps' e.g. '2-5:3,6-14:1'"
                    )
                ]
            ),
        ]
        return io.Schema(
            node_id="SamplerSelfRefineVideo",
            category="KJNodes/samplers",
            description="Attempt to implement https://github.com/agwmon/self-refine-video, for testing only, MAY NOT WORK AS INTENDED.",
            is_experimental=True,
            inputs=[
                io.DynamicCombo.Input("input_mode", options=options, tooltip="How to configure the step plan"),
                io.Float.Input("certain_percentage", default=0.999, min=0.0, max=1.0, step=0.001, round=False, tooltip="Percentage of certain pixels to consider the frame as certain and skip further refinement"),
                io.Float.Input("uncertainty_threshold", default=0.2, min=0.0, max=1.0, step=0.01, round=False, tooltip="Threshold of uncertainty to consider a pixel uncertain"),
                io.Boolean.Input("verbose", default=False, tooltip="Enable verbose logging during sampling"),
                io.Latent.Input("latent", optional=True, tooltip="Optional latent input to get input shape for LTX2 audio/video separation"),
                io.Int.Input("seed", default=0, min=0, max=0xffffffffffffffff, step=1, tooltip="Seed for stochastic sampling"),
            ],
            outputs=[io.Sampler.Output()]
        )

    @classmethod
    def execute(cls, input_mode, certain_percentage, uncertainty_threshold, seed, verbose, latent=None) -> io.NodeOutput:
        """Turn the chosen plan into a step -> substeps map and build the sampler.

        Args:
            input_mode: Mapping of the active combo option's values: either
                ``stochastic_plan`` (string) or ``start_stepN`` / ``end_stepN`` /
                ``steps_N`` integers.
            certain_percentage, uncertainty_threshold, verbose, seed: Forwarded
                unchanged to ``sample_selfrefinevideo``.
            latent: Optional latent dict; the shape of ``latent["samples"]`` is
                forwarded as ``video_shape``.

        Returns:
            ``io.NodeOutput`` wrapping the configured ``KSAMPLER``.

        Raises:
            ValueError: If an entry of ``stochastic_plan`` is not ``start-end:steps``.
        """
        video_shape = latent["samples"].shape if latent is not None else None

        # Maps step index -> number of refinement passes; ranges are inclusive and
        # a later range overrides an earlier one where they overlap.
        stochastic_step_map = {}
        if "stochastic_plan" in input_mode:
            # Parse manual string format: "2-5:3,6-14:1"
            for range_spec in input_mode["stochastic_plan"].split(","):
                range_spec = range_spec.strip()
                if not range_spec:
                    continue
                try:
                    range_part, steps_part = range_spec.split(":")
                    start, end = range_part.split("-")
                    start, end, steps = int(start), int(end), int(steps_part)
                except ValueError as err:
                    # Keep the underlying parse error attached as the cause.
                    raise ValueError(f"Invalid format in stochastic_plan: '{range_spec}'. Expected format: 'start-end:steps'") from err
                for idx in range(start, end + 1):
                    stochastic_step_map[idx] = steps
        else:
            range_keys = [k for k in input_mode.keys() if k.startswith('start_step')]
            for start_key in range_keys:
                # The suffix after "start_step" identifies the range (e.g. "1", "2").
                i = start_key.replace('start_step', '')
                start = input_mode.get(f"start_step{i}")
                end = input_mode.get(f"end_step{i}")
                steps = input_mode.get(f"steps_{i}")

                if start is not None and end is not None and steps is not None:
                    for idx in range(start, end + 1):
                        stochastic_step_map[idx] = steps

        sampler = KSAMPLER(sample_selfrefinevideo, {
            "stochastic_step_map": stochastic_step_map,
            "certain_percentage": certain_percentage,
            "uncertainty_threshold": uncertainty_threshold,
            "verbose": verbose,
            "video_shape": video_shape,
            "seed": seed,
        })
        return io.NodeOutput(sampler)


# Multi-feature linear bias corrector for PiD (Flux2 backbone, 4-step).
# Calibrated on 124 natural-image samples (LOO-CV per-channel RMSE 0.027/0.026/0.024).
# Held-out validation on 20 unseen images: 60% reduction in total drift vs uncorrected.
#
# Shape: (10, 3). Each row holds the weights of one feature, in this order:
#   [R_mean, G_mean, B_mean, R_std, G_std, B_std,
#    R_mean*G_mean, R_mean*B_mean, G_mean*B_mean, intercept(1.0)]
# Each column is the predicted bias for R, G, B (subtract from x0_pred at step 0),
# so a length-10 feature vector in the order above gives the bias as `feats @ coef`.
PID_BIAS_COEF_FLUX2 = torch.tensor([
    [-0.130306, +0.127184, +0.014058],  # R_mean
    [-0.053279, -0.408929, +0.004243],  # G_mean
    [-0.009386, +0.109546, -0.134091],  # B_mean
    [-0.033373, -0.011615, -0.026129],  # R_std
    [+0.180052, +0.062021, +0.071317],  # G_std
    [-0.067958, -0.058595, -0.098645],  # B_std
    [-0.248116, -0.240633, -0.105600],  # R_mean*G_mean
    [+0.304035, +0.322566, +0.093224],  # R_mean*B_mean
    [-0.157648, -0.227127, -0.112368],  # G_mean*B_mean
    [-0.062814, +0.030765, +0.062735],  # intercept
], dtype=torch.float32)


class PiDColorBiasCorrection:
    """Model patch node that subtracts a predicted per-channel bias at the first sampling step."""

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"
    CATEGORY = "KJNodes/experimental"
    EXPERIMENTAL = True
    DESCRIPTION = (
        "PiD 4-step decoder color/brightness drift corrector. "
        "Subtracts a per-channel bias from x0_pred at the first sampling step, "
        "using a small linear model calibrated against the model's systematic drift "
        "(model tends to brighten dark scenes and add a blue cast)."
    )

    @classmethod
    def INPUT_TYPES(s):
        """Return the input specification: model, correction strength and backbone choice."""
        strength_spec = ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01,
                                   "tooltip": "Correction strength. 1.0 = full predicted bias subtracted. <1 = milder, >1 = stronger, 0 = disabled."})
        backbone_spec = (["flux2"], {"default": "flux2",
                                     "tooltip": "Calibrated PiD backbone (currently only flux2 — others use the same model but coefficients differ)."})
        return {"required": {
            "model": ("MODEL",),
            "strength": strength_spec,
            "backbone": backbone_spec,
        }}

    def patch(self, model, strength, backbone):
        """Return a clone of ``model`` with the bias-correcting post-CFG hook installed.

        The input model is returned as-is when ``strength`` is 0 or the backbone
        is anything other than ``"flux2"``.
        """
        correction_enabled = strength != 0.0 and backbone == "flux2"
        if not correction_enabled:
            return (model,)
        coef_cpu = PID_BIAS_COEF_FLUX2  # (10, 3)

        def at_first_step(args):
            # Step detection: only the first sampling step qualifies.
            # Use sample_sigmas like CFGZeroStarAndInit for robustness across schedules.
            try:
                schedule = args["model_options"]["transformer_options"]["sample_sigmas"]
                current = args.get("sigma", args.get("timestep"))
                if current is None:
                    return False
                # First step matches sigmas[0]
                return bool(torch.isclose(current.max(), schedule[0]).item())
            except (KeyError, AttributeError):
                # Fallback heuristic: PiD's first step has sigma=0.999
                current = args.get("sigma")
                if current is None:
                    return False
                return not (current.max().item() < 0.95)

        def pid_bias_post_cfg(args):
            denoised = args["denoised"]
            if not at_first_step(args):
                return denoised

            # Per-channel statistics, reduced over every axis except dim 1.
            channel_mean = denoised.mean(dim=(0, 2, 3))
            channel_std = denoised.std(dim=(0, 2, 3))
            r_mean, g_mean, b_mean = channel_mean[0], channel_mean[1], channel_mean[2]
            r_std, g_std, b_std = channel_std[0], channel_std[1], channel_std[2]
            intercept = torch.tensor(1.0, device=denoised.device, dtype=denoised.dtype)

            # Feature order must match the rows of the coefficient table.
            feats = torch.stack([
                r_mean, g_mean, b_mean,
                r_std, g_std, b_std,
                r_mean * g_mean, r_mean * b_mean, g_mean * b_mean,
                intercept,
            ])
            coef = coef_cpu.to(denoised.device, dtype=denoised.dtype)
            bias = feats @ coef  # (3,)
            return denoised - strength * bias.view(1, 3, 1, 1)

        m = model.clone()
        m.set_model_sampler_post_cfg_function(pid_bias_post_cfg)
        return (m,)
