import torch
import torch.nn as nn
import numpy as np
from PIL import Image
import json
import re
import os
import time
import math
import importlib
import logging

from comfy import model_management
import folder_paths
from nodes import MAX_RESOLUTION
from comfy.utils import common_upscale, ProgressBar, load_torch_file, save_torch_file, state_dict_prefix_replace
from comfy.comfy_types.node_typing import IO
from comfy_api.latest import io, ui
import comfy.latent_formats
import node_helpers
from io import BytesIO

# Absolute path of the parent of the directory that contains this file.
script_directory = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# Register the "fonts" sub-folder of that directory with folder_paths under the "kjnodes_fonts" key.
folder_paths.add_model_folder_path("kjnodes_fonts", os.path.join(script_directory, "fonts"))

class BOOLConstant:
    """Constant node that outputs the boolean selected in its widget."""

    RETURN_TYPES = ("BOOLEAN",)
    RETURN_NAMES = ("value",)
    FUNCTION = "get_value"
    CATEGORY = "KJNodes/constants"
    SEARCH_ALIASES = ["boolean", "value"]

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the single required BOOLEAN input named ``value``."""
        required = {"value": ("BOOLEAN", {"default": True})}
        return {"required": required}

    def get_value(self, value):
        """Return ``value`` unchanged, wrapped in a one-element tuple."""
        return (value,)
    
class INTConstant:
    """Constant node that outputs the integer entered in its widget."""

    RETURN_TYPES = ("INT",)
    RETURN_NAMES = ("value",)
    FUNCTION = "get_value"
    CATEGORY = "KJNodes/constants"
    SEARCH_ALIASES = ["integer", "value"]

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the single required INT input named ``value``."""
        limit = 0xffffffffffffffff  # symmetric bound used for both min and max
        widget_options = {"default": 0, "min": -limit, "max": limit}
        return {"required": {"value": ("INT", widget_options)}}

    def get_value(self, value):
        """Return ``value`` unchanged, wrapped in a one-element tuple."""
        return (value,)

class FloatConstant:
    """Constant node that outputs the float entered in its widget."""

    RETURN_TYPES = ("FLOAT",)
    RETURN_NAMES = ("value",)
    FUNCTION = "get_value"
    CATEGORY = "KJNodes/constants"
    SEARCH_ALIASES = ["float", "value"]

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the single required FLOAT input named ``value``."""
        limit = 0xffffffffffffffff  # symmetric bound used for both min and max
        widget_options = {
            "default": 0.0,
            "min": -limit,
            "max": limit,
            "step": 0.00001,
        }
        return {"required": {"value": ("FLOAT", widget_options)}}

    def get_value(self, value):
        """Return ``value`` rounded to six decimal places, as a one-element tuple."""
        rounded = round(value, 6)
        return (rounded,)

class StringConstant:
    """Constant node that outputs the single-line string entered in its widget."""

    RETURN_TYPES = ("STRING",)
    FUNCTION = "passtring"
    CATEGORY = "KJNodes/constants"
    SEARCH_ALIASES = ["text", "value"]

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the single required single-line STRING input named ``string``."""
        text_widget = ("STRING", {"default": '', "multiline": False})
        return {"required": {"string": text_widget}}

    def passtring(self, string):
        """Return ``string`` unchanged, wrapped in a one-element tuple."""
        return (string,)

class StringConstantMultiline:
    """Constant node for multi-line text, optionally flattened to one line."""

    RETURN_TYPES = ("STRING",)
    FUNCTION = "stringify"
    CATEGORY = "KJNodes/constants"
    SEARCH_ALIASES = ["text", "value"]

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the multi-line STRING input and the ``strip_newlines`` toggle."""
        required = {
            "string": ("STRING", {"default": "", "multiline": True}),
            "strip_newlines": ("BOOLEAN", {"default": True}),
        }
        return {"required": required}

    def stringify(self, string, strip_newlines):
        """Return the text as a one-element tuple.

        When ``strip_newlines`` is truthy every ``\\n`` is removed and the
        result is stripped of surrounding whitespace; otherwise the text is
        returned untouched.
        """
        if not strip_newlines:
            return (string,)
        flattened = string.replace('\n', '')
        return (flattened.strip(),)


    
class ScaleBatchPromptSchedule:
    """Rescale the frame numbers of a quoted ``"frame":"text"`` schedule string."""

    RETURN_TYPES = ("STRING",)
    FUNCTION = "scaleschedule"
    CATEGORY = "KJNodes/misc"
    DESCRIPTION = """
Scales a batch schedule from Fizz' nodes BatchPromptSchedule
to a different frame count.
"""

    @classmethod
    def INPUT_TYPES(s):
        """Declare the schedule string and the old/new frame counts (all forced inputs)."""
        return {
            "required": {
                 "input_str": ("STRING", {"forceInput": True,"default": "0:(0.0),\n7:(1.0),\n15:(0.0)\n"}),
                 "old_frame_count": ("INT", {"forceInput": True,"default": 1,"min": 1, "max": 4096, "step": 1}),
                 "new_frame_count": ("INT", {"forceInput": True,"default": 1,"min": 1, "max": 4096, "step": 1}),

        },
    }

    def scaleschedule(self, old_frame_count, input_str, new_frame_count):
        """Map every frame key from the old frame range onto the new one.

        Args:
            old_frame_count: Number of frames the schedule was written for.
            input_str: Schedule text made of ``"<frame>":"<text>"`` entries
                separated by commas; only entries matching that quoted form
                are picked up.
            new_frame_count: Number of frames to rescale the schedule to.

        Returns:
            A one-element tuple holding the rescaled schedule, entries sorted
            by frame number and joined with ``", "``. Entries that land on the
            same new frame overwrite each other (the last one parsed wins).
        """
        pattern = r'"(\d+)"\s*:\s*"(.*?)"(?:,|\Z)'
        frame_strings = dict(re.findall(pattern, input_str))

        # A one-frame schedule has a zero-length span; dividing by it used to
        # raise ZeroDivisionError (and 1 is the widget default). Keep the frame
        # numbers as they are in that case.
        old_span = old_frame_count - 1
        scaling_factor = (new_frame_count - 1) / old_span if old_span != 0 else 1.0

        new_frame_strings = {
            int(round(int(old_frame) * scaling_factor)): string
            for old_frame, string in frame_strings.items()
        }

        output_str = ', '.join(f'"{k}":"{v}"' for k, v in sorted(new_frame_strings.items()))
        return (output_str,)


class GetLatentsFromBatchIndexed:
    """Pick latents out of a batch by a user-supplied list of indices."""

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "indexedlatentsfrombatch"
    CATEGORY = "KJNodes/latents"
    DESCRIPTION = """
Selects and returns the latents at the specified indices as an latent batch.
"""

    @classmethod
    def INPUT_TYPES(s):
        """Declare the latent input, the index list text and the layout selector."""
        return {
            "required": {
                 "latents": ("LATENT",),
                 "indexes": ("STRING", {"default": "0, 1, 2", "multiline": True}),
                 "latent_format": (["BCHW", "BTCHW", "BCTHW"], {"default": "BCHW"}),
        },
    }

    def indexedlatentsfrombatch(self, latents, indexes, latent_format):
        """Return a shallow copy of ``latents`` whose "samples" holds only the chosen indices.

        Args:
            latents: Mapping with a "samples" entry that supports tensor indexing.
            indexes: Integers separated by commas and/or newlines. Empty
                entries (e.g. from a trailing comma) are ignored.
            latent_format: "BCHW" indexes dim 0, "BTCHW" dim 1, "BCTHW" dim 2.

        Returns:
            A one-element tuple with the new latent mapping.

        Raises:
            ValueError: If no index is given, an entry is not an integer, or
                ``latent_format`` is not one of the supported layouts.
        """
        samples = latents.copy()  # shallow copy: the input mapping is not modified
        latent_samples = samples["samples"]

        # The widget is multiline, so accept newlines as separators as well and
        # skip blank entries instead of crashing on int('').
        index_list = [
            int(token)
            for token in indexes.replace('\n', ',').split(',')
            if token.strip()
        ]
        if not index_list:
            raise ValueError("GetLatentsFromBatchIndexed: no indexes were provided")

        indices_tensor = torch.tensor(index_list, dtype=torch.long)

        if latent_format == "BCHW":
            chosen_latents = latent_samples[indices_tensor]
        elif latent_format == "BTCHW":
            chosen_latents = latent_samples[:, indices_tensor]
        elif latent_format == "BCTHW":
            chosen_latents = latent_samples[:, :, indices_tensor]
        else:
            # Previously this fell through to an unbound local variable.
            raise ValueError(f"GetLatentsFromBatchIndexed: unsupported latent_format '{latent_format}'")

        samples["samples"] = chosen_latents
        return (samples,)
    

class ConditioningMultiCombine:
    """Fold ``inputcount`` conditioning inputs into one via combine or concat."""

    RETURN_TYPES = ("CONDITIONING", "INT")
    RETURN_NAMES = ("combined", "inputcount")
    FUNCTION = "combine"
    CATEGORY = "KJNodes/masking/conditioning"
    DESCRIPTION = """
Combines multiple conditioning nodes into one
"""

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the input count, the operation and the first two conditioning inputs."""
        required = {
            "inputcount": ("INT", {"default": 2, "min": 2, "max": 20, "step": 1}),
            "operation": (["combine", "concat"], {"default": "combine"}),
            "conditioning_1": ("CONDITIONING", ),
            "conditioning_2": ("CONDITIONING", ),
        }
        return {"required": required}

    def combine(self, inputcount, operation, **kwargs):
        """Merge ``conditioning_1`` .. ``conditioning_<inputcount>`` left to right.

        Each step delegates to the ``ConditioningCombine`` or
        ``ConditioningConcat`` node from ``nodes``, depending on ``operation``;
        any other ``operation`` value leaves the running result untouched.
        Returns ``(merged_conditioning, inputcount)``.
        """
        from nodes import ConditioningCombine
        from nodes import ConditioningConcat

        combiner = ConditioningCombine()
        concatenator = ConditioningConcat()

        merged = kwargs["conditioning_1"]
        for slot in range(2, inputcount + 1):
            incoming = kwargs[f"conditioning_{slot}"]
            if operation == "combine":
                merged = combiner.combine(merged, incoming)[0]
            elif operation == "concat":
                merged = concatenator.concat(merged, incoming)[0]
        return (merged, inputcount,)

class AppendStringsToList:
    """Concatenate two inputs into one list, wrapping non-list inputs first."""

    RETURN_TYPES = ("STRING",)
    FUNCTION = "joinstring"
    CATEGORY = "KJNodes/text"

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the two required forced STRING inputs."""
        required = {
            "string1": ("STRING", {"default": '', "forceInput": True}),
            "string2": ("STRING", {"default": '', "forceInput": True}),
        }
        return {"required": required}

    def joinstring(self, string1, string2):
        """Return ``(list1 + list2,)`` where a non-list argument becomes a one-item list."""
        head = string1 if isinstance(string1, list) else [string1]
        tail = string2 if isinstance(string2, list) else [string2]
        return (head + tail,)
    
class JoinStrings:
    """Join two optional strings with a delimiter placed between them."""

    RETURN_TYPES = ("STRING",)
    FUNCTION = "joinstring"
    CATEGORY = "KJNodes/text"

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the delimiter widget and the two optional forced STRING inputs."""
        required = {"delimiter": ("STRING", {"default": ' ', "multiline": False})}
        optional = {
            "string1": ("STRING", {"default": '', "forceInput": True}),
            "string2": ("STRING", {"default": '', "forceInput": True}),
        }
        return {"required": required, "optional": optional}

    def joinstring(self, delimiter, string1="", string2=""):
        """Return ``(string1 + delimiter + string2,)``; the delimiter is always inserted."""
        return (string1 + delimiter + string2,)
    
class JoinStringMulti:
    """Join a variable number of string inputs into one string or a list."""

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("string",)
    FUNCTION = "combine"
    CATEGORY = "KJNodes/text"
    DESCRIPTION = """
Creates single string, or a list of strings, from  
multiple input strings.  
You can set how many inputs the node has,  
with the **inputcount** and clicking update.
"""

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the input count, first string, delimiter, list toggle and optional second string."""
        required = {
            "inputcount": ("INT", {"default": 2, "min": 2, "max": 1000, "step": 1}),
            "string_1": ("STRING", {"default": '', "forceInput": True}),
            "delimiter": ("STRING", {"default": ' ', "multiline": False}),
            "return_list": ("BOOLEAN", {"default": False}),
        }
        optional = {
            "string_2": ("STRING", {"default": '', "forceInput": True}),
        }
        return {"required": required, "optional": optional}

    def combine(self, inputcount, delimiter, **kwargs):
        """Combine ``string_1`` .. ``string_<inputcount>`` taken from ``kwargs``.

        ``string_1`` and ``return_list`` must be present in ``kwargs``. Missing
        or empty later inputs are skipped. Returns a one-element tuple holding
        either the list of kept strings (``return_list`` truthy) or the strings
        joined with ``delimiter``.
        """
        first = kwargs["string_1"]
        as_list = kwargs["return_list"]

        # Gather the non-empty additional inputs in slot order.
        extras = []
        for slot in range(2, inputcount + 1):
            candidate = kwargs.get(f"string_{slot}", "")
            if candidate:
                extras.append(candidate)

        if as_list:
            return ([first] + extras,)

        joined = first
        for piece in extras:
            joined = joined + delimiter + piece
        return (joined,)

class CondPassThrough:
    """Node that forwards its optional positive/negative conditioning unchanged."""

    RETURN_TYPES = ("CONDITIONING", "CONDITIONING",)
    RETURN_NAMES = ("positive", "negative")
    FUNCTION = "passthrough"
    CATEGORY = "KJNodes/misc"
    DESCRIPTION = """
    Simply passes through the positive and negative conditioning,
    workaround for Set node not allowing bypassed inputs.
"""

    @classmethod
    def INPUT_TYPES(cls):
        """Declare no required inputs and two optional CONDITIONING inputs."""
        optional = {
            "positive": ("CONDITIONING", ),
            "negative": ("CONDITIONING", ),
        }
        return {"required": {}, "optional": optional}

    def passthrough(self, positive=None, negative=None):
        """Return ``(positive, negative)`` exactly as received (``None`` when unconnected)."""
        return (positive, negative,)

class ModelPassThrough:
    """Node that forwards its optional model input unchanged."""

    RETURN_TYPES = ("MODEL", )
    RETURN_NAMES = ("model",)
    FUNCTION = "passthrough"
    CATEGORY = "KJNodes/misc"
    DESCRIPTION = """
    Simply passes through the model,
    workaround for Set node not allowing bypassed inputs.
"""

    @classmethod
    def INPUT_TYPES(cls):
        """Declare no required inputs and one optional MODEL input."""
        optional = {"model": ("MODEL", )}
        return {"required": {}, "optional": optional}

    def passthrough(self, model=None):
        """Return ``(model,)`` exactly as received (``None`` when unconnected)."""
        return (model,)

def append_helper(t, mask, c, set_area_to_bounds, strength):
    """Append a masked copy of conditioning entry ``t`` to the list ``c``.

    ``t`` is indexed as ``[payload, options]``; ``options`` is copied so the
    caller's entry is not modified, then gains the 'mask',
    'set_area_to_bounds' and 'mask_strength' keys.
    """
    payload, options = t[0], t[1].copy()
    # Unpacking into three names makes any mask whose shape is not 3-long raise here.
    _batch, _height, _width = mask.shape
    options.update({
        'mask': mask,
        'set_area_to_bounds': set_area_to_bounds,
        'mask_strength': strength,
    })
    c.append([payload, options])

class ConditioningSetMaskAndCombine:
    """Attach masks to two positive/negative conditioning pairs and merge them."""

    RETURN_TYPES = ("CONDITIONING","CONDITIONING",)
    RETURN_NAMES = ("combined_positive", "combined_negative",)
    FUNCTION = "append"
    CATEGORY = "KJNodes/masking/conditioning"
    DESCRIPTION = """
Bundles multiple conditioning mask and combine nodes into one,functionality is identical to ComfyUI native nodes
"""

    @classmethod
    def INPUT_TYPES(cls):
        """Declare conditioning pairs, masks, mask strengths and the area mode."""
        slots = (1, 2)
        required = {}
        for n in slots:
            required[f"positive_{n}"] = ("CONDITIONING", )
            required[f"negative_{n}"] = ("CONDITIONING", )
        for n in slots:
            required[f"mask_{n}"] = ("MASK", )
        for n in slots:
            required[f"mask_{n}_strength"] = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01})
        required["set_cond_area"] = (["default", "mask bounds"],)
        return {"required": required}

    def append(self, positive_1, negative_1, positive_2, negative_2, mask_1, mask_2, set_cond_area, mask_1_strength, mask_2_strength):
        """Return ``(combined_positive, combined_negative)``.

        Every entry of ``positive_N`` / ``negative_N`` is passed to
        ``append_helper`` together with ``mask_N`` and ``mask_N_strength``.
        Masks whose shape has fewer than three entries get ``unsqueeze(0)``
        first. Any ``set_cond_area`` other than "default" turns on
        ``set_area_to_bounds``.
        """
        set_area_to_bounds = set_cond_area != "default"
        masks = [m.unsqueeze(0) if len(m.shape) < 3 else m for m in (mask_1, mask_2)]
        strengths = (mask_1_strength, mask_2_strength)

        combined_positive = []
        combined_negative = []
        jobs = (
            (combined_positive, (positive_1, positive_2)),
            (combined_negative, (negative_1, negative_2)),
        )
        for target, groups in jobs:
            for entries, mask, strength in zip(groups, masks, strengths):
                for t in entries:
                    append_helper(t, mask, target, set_area_to_bounds, strength)
        return (combined_positive, combined_negative)

class ConditioningSetMaskAndCombine3:
    """Attach masks to three positive/negative conditioning pairs and merge them."""

    RETURN_TYPES = ("CONDITIONING","CONDITIONING",)
    RETURN_NAMES = ("combined_positive", "combined_negative",)
    FUNCTION = "append"
    CATEGORY = "KJNodes/masking/conditioning"
    DESCRIPTION = """
Bundles multiple conditioning mask and combine nodes into one,functionality is identical to ComfyUI native nodes
"""

    @classmethod
    def INPUT_TYPES(cls):
        """Declare conditioning pairs, masks, mask strengths and the area mode."""
        slots = (1, 2, 3)
        required = {}
        for n in slots:
            required[f"positive_{n}"] = ("CONDITIONING", )
            required[f"negative_{n}"] = ("CONDITIONING", )
        for n in slots:
            required[f"mask_{n}"] = ("MASK", )
        for n in slots:
            required[f"mask_{n}_strength"] = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01})
        required["set_cond_area"] = (["default", "mask bounds"],)
        return {"required": required}

    def append(self, positive_1, negative_1, positive_2, positive_3, negative_2, negative_3, mask_1, mask_2, mask_3, set_cond_area, mask_1_strength, mask_2_strength, mask_3_strength):
        """Return ``(combined_positive, combined_negative)``.

        Every entry of ``positive_N`` / ``negative_N`` is passed to
        ``append_helper`` together with ``mask_N`` and ``mask_N_strength``.
        Masks whose shape has fewer than three entries get ``unsqueeze(0)``
        first. Any ``set_cond_area`` other than "default" turns on
        ``set_area_to_bounds``.
        """
        set_area_to_bounds = set_cond_area != "default"
        masks = [m.unsqueeze(0) if len(m.shape) < 3 else m for m in (mask_1, mask_2, mask_3)]
        strengths = (mask_1_strength, mask_2_strength, mask_3_strength)

        combined_positive = []
        combined_negative = []
        jobs = (
            (combined_positive, (positive_1, positive_2, positive_3)),
            (combined_negative, (negative_1, negative_2, negative_3)),
        )
        for target, groups in jobs:
            for entries, mask, strength in zip(groups, masks, strengths):
                for t in entries:
                    append_helper(t, mask, target, set_area_to_bounds, strength)
        return (combined_positive, combined_negative)

class ConditioningSetMaskAndCombine4:
    """Attach masks to four positive/negative conditioning pairs and merge them."""

    RETURN_TYPES = ("CONDITIONING","CONDITIONING",)
    RETURN_NAMES = ("combined_positive", "combined_negative",)
    FUNCTION = "append"
    CATEGORY = "KJNodes/masking/conditioning"
    DESCRIPTION = """
Bundles multiple conditioning mask and combine nodes into one,functionality is identical to ComfyUI native nodes
"""

    @classmethod
    def INPUT_TYPES(cls):
        """Declare conditioning pairs, masks, mask strengths and the area mode."""
        slots = (1, 2, 3, 4)
        required = {}
        for n in slots:
            required[f"positive_{n}"] = ("CONDITIONING", )
            required[f"negative_{n}"] = ("CONDITIONING", )
        for n in slots:
            required[f"mask_{n}"] = ("MASK", )
        for n in slots:
            required[f"mask_{n}_strength"] = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01})
        required["set_cond_area"] = (["default", "mask bounds"],)
        return {"required": required}

    def append(self, positive_1, negative_1, positive_2, positive_3, positive_4, negative_2, negative_3, negative_4, mask_1, mask_2, mask_3, mask_4, set_cond_area, mask_1_strength, mask_2_strength, mask_3_strength, mask_4_strength):
        """Return ``(combined_positive, combined_negative)``.

        Every entry of ``positive_N`` / ``negative_N`` is passed to
        ``append_helper`` together with ``mask_N`` and ``mask_N_strength``.
        Masks whose shape has fewer than three entries get ``unsqueeze(0)``
        first. Any ``set_cond_area`` other than "default" turns on
        ``set_area_to_bounds``.
        """
        set_area_to_bounds = set_cond_area != "default"
        masks = [
            m.unsqueeze(0) if len(m.shape) < 3 else m
            for m in (mask_1, mask_2, mask_3, mask_4)
        ]
        strengths = (mask_1_strength, mask_2_strength, mask_3_strength, mask_4_strength)

        combined_positive = []
        combined_negative = []
        jobs = (
            (combined_positive, (positive_1, positive_2, positive_3, positive_4)),
            (combined_negative, (negative_1, negative_2, negative_3, negative_4)),
        )
        for target, groups in jobs:
            for entries, mask, strength in zip(groups, masks, strengths):
                for t in entries:
                    append_helper(t, mask, target, set_area_to_bounds, strength)
        return (combined_positive, combined_negative)

class ConditioningSetMaskAndCombine5:
    """Attach masks to five positive/negative conditioning pairs and merge them."""

    RETURN_TYPES = ("CONDITIONING","CONDITIONING",)
    RETURN_NAMES = ("combined_positive", "combined_negative",)
    FUNCTION = "append"
    CATEGORY = "KJNodes/masking/conditioning"
    DESCRIPTION = """
Bundles multiple conditioning mask and combine nodes into one,functionality is identical to ComfyUI native nodes
"""

    @classmethod
    def INPUT_TYPES(cls):
        """Declare conditioning pairs, masks, mask strengths and the area mode."""
        slots = (1, 2, 3, 4, 5)
        required = {}
        for n in slots:
            required[f"positive_{n}"] = ("CONDITIONING", )
            required[f"negative_{n}"] = ("CONDITIONING", )
        for n in slots:
            required[f"mask_{n}"] = ("MASK", )
        for n in slots:
            required[f"mask_{n}_strength"] = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01})
        required["set_cond_area"] = (["default", "mask bounds"],)
        return {"required": required}

    def append(self, positive_1, negative_1, positive_2, positive_3, positive_4, positive_5, negative_2, negative_3, negative_4, negative_5, mask_1, mask_2, mask_3, mask_4, mask_5, set_cond_area, mask_1_strength, mask_2_strength, mask_3_strength, mask_4_strength, mask_5_strength):
        """Return ``(combined_positive, combined_negative)``.

        Every entry of ``positive_N`` / ``negative_N`` is passed to
        ``append_helper`` together with ``mask_N`` and ``mask_N_strength``.
        Masks whose shape has fewer than three entries get ``unsqueeze(0)``
        first. Any ``set_cond_area`` other than "default" turns on
        ``set_area_to_bounds``.
        """
        set_area_to_bounds = set_cond_area != "default"
        masks = [
            m.unsqueeze(0) if len(m.shape) < 3 else m
            for m in (mask_1, mask_2, mask_3, mask_4, mask_5)
        ]
        strengths = (
            mask_1_strength,
            mask_2_strength,
            mask_3_strength,
            mask_4_strength,
            mask_5_strength,
        )

        combined_positive = []
        combined_negative = []
        jobs = (
            (combined_positive, (positive_1, positive_2, positive_3, positive_4, positive_5)),
            (combined_negative, (negative_1, negative_2, negative_3, negative_4, negative_5)),
        )
        for target, groups in jobs:
            for entries, mask, strength in zip(groups, masks, strengths):
                for t in entries:
                    append_helper(t, mask, target, set_area_to_bounds, strength)
        return (combined_positive, combined_negative)
    
class VRAM_Debug:
    """Pass-through node that can free memory and reports free memory before and after."""

    RETURN_TYPES = (IO.ANY, "IMAGE","MODEL","INT", "INT",)
    RETURN_NAMES = ("any_output", "image_pass", "model_pass", "freemem_before", "freemem_after")
    FUNCTION = "VRAMdebug"
    CATEGORY = "KJNodes/memory"
    DESCRIPTION = """
Returns the inputs unchanged, they are only used as triggers,  
and performs comfy model management functions and garbage collection,  
reports free VRAM before and after the operations.
"""

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the three cleanup toggles and the optional pass-through inputs."""
        required = {
            "empty_cache": ("BOOLEAN", {"default": True}),
            "gc_collect": ("BOOLEAN", {"default": True}),
            "unload_all_models": ("BOOLEAN", {"default": False}),
        }
        optional = {
            "any_input": (IO.ANY,),
            "image_pass": ("IMAGE",),
            "model_pass": ("MODEL",),
        }
        return {"required": required, "optional": optional}

    def VRAMdebug(self, gc_collect, empty_cache, unload_all_models, image_pass=None, model_pass=None, any_input=None):
        """Run the selected cleanup steps and return the inputs with memory readings.

        Steps run in this order when enabled: ``soft_empty_cache``,
        ``unload_all_models``, then ``gc.collect``. Free memory is read from
        ``model_management.get_free_memory()`` before and after and logged.

        Returns:
            A dict with a "ui" text entry ("<before>x<after>") and a "result"
            tuple ``(any_input, image_pass, model_pass, freemem_before,
            freemem_after)``.
        """
        freemem_before = model_management.get_free_memory()
        logging.info(f"VRAMdebug: free memory before: {freemem_before:,.0f}")

        if empty_cache:
            model_management.soft_empty_cache()
        if unload_all_models:
            model_management.unload_all_models()
        if gc_collect:
            import gc
            gc.collect()

        freemem_after = model_management.get_free_memory()
        logging.info(f"VRAMdebug: free memory after: {freemem_after:,.0f}")
        logging.info(f"VRAMdebug: freed memory: {freemem_after - freemem_before:,.0f}")

        ui_payload = {"text": [f"{freemem_before:,.0f}x{freemem_after:,.0f}"]}
        outputs = (any_input, image_pass, model_pass, freemem_before, freemem_after)
        return {"ui": ui_payload, "result": outputs}

class SomethingToString:
    """Turn scalars and lists into a string, with optional prefix and suffix."""

    RETURN_TYPES = ("STRING",)
    FUNCTION = "stringify"
    CATEGORY = "KJNodes/text"
    DESCRIPTION = """
Converts any type to a string.
"""

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the required any-typed input and the optional prefix/suffix strings."""
        required = {"input": (IO.ANY, )}
        optional = {
            "prefix": ("STRING", {"default": ""}),
            "suffix": ("STRING", {"default": ""}),
        }
        return {"required": required, "optional": optional}

    def stringify(self, input, prefix="", suffix=""):
        """Return the string form of ``input`` as a one-element tuple.

        ``int``/``float``/``bool``/``str`` values go through ``str()``; lists
        become their items' ``str()`` joined with ", ". Any other type is
        returned as-is, without prefix or suffix applied.
        """
        if isinstance(input, list):
            text = ', '.join(map(str, input))
        elif isinstance(input, (int, float, bool, str)):
            text = str(input)
        else:
            # Unsupported types are handed back untouched.
            return (input,)

        if prefix:
            text = prefix + text
        if suffix:
            text = text + suffix
        return (text,)

class Sleep:
    """Pass-through node that blocks for a configurable amount of time."""

    RETURN_TYPES = (IO.ANY,)
    FUNCTION = "sleepdelay"
    CATEGORY = "KJNodes/misc"
    DESCRIPTION = """
Delays the execution for the input amount of time.
"""

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the any-typed input plus the minutes and seconds widgets."""
        required = {
            "input": (IO.ANY, ),
            "minutes": ("INT", {"default": 0, "min": 0, "max": 1439}),
            "seconds": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 59.99, "step": 0.01}),
        }
        return {"required": required}

    def sleepdelay(self, input, minutes, seconds):
        """Sleep for ``minutes * 60 + seconds`` seconds, then return ``(input,)``."""
        time.sleep(minutes * 60 + seconds)
        return (input,)
    
class EmptyLatentImagePresets:
    """Create an empty latent from one of a fixed list of size presets."""

    RETURN_TYPES = ("LATENT", "INT", "INT")
    RETURN_NAMES = ("Latent", "Width", "Height")
    FUNCTION = "generate"
    CATEGORY = "KJNodes/latents"

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the preset dropdown, the invert toggle and the batch size."""
        presets = [
            '512 x 512 (1:1)',
            '768 x 512 (1.5:1)',
            '960 x 512 (1.875:1)',
            '1024 x 512 (2:1)',
            '1024 x 576 (1.778:1)',
            '1536 x 640 (2.4:1)',
            '1344 x 768 (1.75:1)',
            '1216 x 832 (1.46:1)',
            '1152 x 896 (1.286:1)',
            '1024 x 1024 (1:1)',
        ]
        required = {
            "dimensions": (presets, {"default": '512 x 512 (1:1)'}),
            "invert": ("BOOLEAN", {"default": False}),
            "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
        }
        return {"required": required}

    def generate(self, dimensions, invert, batch_size):
        """Parse a "W x H (ratio)" preset and build the latent via ``EmptyLatentImage``.

        With ``invert`` set, the two parsed numbers swap roles. Returns
        ``(latent, width, height)``.
        """
        from nodes import EmptyLatentImage

        pieces = [piece.strip() for piece in dimensions.split('x')]
        # Drop the trailing "(ratio)" annotation from both pieces.
        first_text = pieces[0].split('(')[0].strip()
        second_text = pieces[1].split('(')[0].strip()

        if invert:
            width = int(second_text.split(' ')[0])
            height = int(first_text)
        else:
            width = int(first_text)
            height = int(second_text.split(' ')[0])

        latent = EmptyLatentImage().generate(width, height, batch_size)[0]
        return (latent, int(width), int(height),)

class EmptyLatentImageCustomPresets:
    """Create an empty latent from size presets listed in 'custom_dimensions.json'."""

    @classmethod
    def INPUT_TYPES(cls):
        """Build the preset dropdown from 'custom_dimensions.json' in ``script_directory``.

        Each JSON entry is read through its 'label' and 'value' keys and shown
        as "<label> - <value>". A missing file yields an empty choice list.
        """
        try:
            with open(os.path.join(script_directory, 'custom_dimensions.json')) as f:
                dimensions_dict = json.load(f)
        except FileNotFoundError:
            dimensions_dict = []
        return {
        "required": {
            "dimensions": (
                 [f"{d['label']} - {d['value']}" for d in dimensions_dict],
            ),

            "invert": ("BOOLEAN", {"default": False}),
            "batch_size": ("INT", {
            "default": 1,
            "min": 1,
            "max": 4096
            }),
        },
        }

    RETURN_TYPES = ("LATENT", "INT", "INT")
    RETURN_NAMES = ("Latent", "Width", "Height")
    FUNCTION = "generate"
    CATEGORY = "KJNodes/latents"
    DESCRIPTION = """
Generates an empty latent image with the specified dimensions.  
The choices are loaded from 'custom_dimensions.json' in the nodes folder.
"""

    def generate(self, dimensions, invert, batch_size):
        """Parse a "<label> - <width>x<height>" choice and build the latent.

        Args:
            dimensions: The selected dropdown entry.
            invert: When truthy, width and height are swapped.
            batch_size: Forwarded to ``EmptyLatentImage().generate``.

        Returns:
            ``(latent, width, height)`` with width and height as ints.

        Raises:
            ValueError: If the entry has no " - " separator, or the value part
                is not two integers separated by "x".
        """
        from nodes import EmptyLatentImage

        # Split on the LAST " - " only: the label itself may contain " - ",
        # which made a plain split() yield more than two parts and fail to unpack.
        _label, value = dimensions.rsplit(' - ', 1)
        width, height = [int(x.strip()) for x in value.split('x')]

        if invert:
            width, height = height, width

        latent = EmptyLatentImage().generate(width, height, batch_size)[0]

        return (latent, width, height,)

class WidgetToString:
    """Looks up another node of the workflow and returns its widget value(s) as a string."""

    @classmethod
    def IS_CHANGED(cls,*,id,node_title,any_input,**kwargs):
        # Returns NaN when an input is linked and a node id/title is given,
        # otherwise falls through and returns None.
        # NOTE(review): presumably NaN is used because it never equals itself,
        # so the executor always sees the node as changed — confirm against ComfyUI.
        if any_input is not None and (id != 0 or node_title != ""):
            return float("NaN")

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "id": ("INT", {"default": 0, "min": 0, "max": 100000, "step": 1}),
                "widget_name": ("STRING", {"multiline": False}),
                "return_all": ("BOOLEAN", {"default": False}),
            },
            "optional": {
                         "any_input": (IO.ANY, ),
                         "node_title": ("STRING", {"multiline": False}),
                         "allowed_float_decimals": ("INT", {"default": 2, "min": 0, "max": 10, "tooltip": "Number of decimal places to display for float values"}),
                         },
            "hidden": {"extra_pnginfo": "EXTRA_PNGINFO",
                       "prompt": "PROMPT",
                       "unique_id": "UNIQUE_ID",},
        }

    RETURN_TYPES = ("STRING", )
    FUNCTION = "get_widget_value"
    CATEGORY = "KJNodes/text"
    DESCRIPTION = """
Selects a node and it's specified widget and outputs the value as a string.  
If no node id or title is provided it will use the 'any_input' link and use that node.  
To see node id's, enable "Node ID Badge Mode" in main settings.
Alternatively you can search with the node title. Node titles ONLY exist if they  
are manually edited!
'widget_name' can be a comma separated list.
The 'any_input' is required for making sure the node you want the value from exists in the workflow.
"""

    @staticmethod
    def _format_value(value, decimals):
        """Render one widget value; floats are fixed to `decimals` decimal places."""
        if isinstance(value, float):
            return f"{value:.{decimals}f}"
        return str(value)

    def get_widget_value(self, id, widget_name, extra_pnginfo, prompt, unique_id, return_all=False, any_input=None, node_title="", allowed_float_decimals=2):
        """Resolve the target node and return the requested widget value(s).

        The target node is selected by, in priority order: `node_title`,
        a non-zero `id`, or the node feeding this node's 'any_input' link.

        Args:
            id: Target node id; 0 means "not given".
            widget_name: One widget name or a comma separated list of names.
            extra_pnginfo: Dict holding the serialized workflow under "workflow".
            prompt: Dict keyed by node id string, each entry holding "inputs".
            unique_id: This node's id, either plain or "parent:id" inside a subgraph.
            return_all: When True, return every input of the target node.
            any_input: Optional link used to locate the target node.
            node_title: Optional title of the target node.
            allowed_float_decimals: Decimal places used for float values.

        Returns:
            One-element tuple with the formatted string.

        Raises:
            ValueError: No node matches the title, id or link.
            KeyError: The resolved node is missing from `prompt`.
            NameError: A requested widget does not exist on the node.
        """
        workflow = extra_pnginfo["workflow"]
        results = []
        node_id = link_id = subgraph_prefix = None
        link_to_node_map = {}
        node_to_subgraph_map = {}  # Track which subgraph each node belongs to

        # Parse unique_id - handle both "parent:id" format and simple int format
        if isinstance(unique_id, str) and ":" in unique_id:
            unique_id_parts = unique_id.split(":")
            unique_id_int = int(unique_id_parts[-1])  # Use the last part as the node id
            subgraph_prefix = ":".join(unique_id_parts[:-1])  # Store the parent prefix (e.g., "14")
        else:
            unique_id_int = int(unique_id)

        # Collect all nodes from main workflow and subgraphs
        all_nodes = list(workflow.get("nodes", []))
        definitions = workflow.get("definitions", {})
        subgraphs = definitions.get("subgraphs", [])

        # Find which main workflow node references each subgraph
        subgraph_id_to_parent = {}
        for node in workflow.get("nodes", []):
            node_type = node.get("type", "")
            # Subgraph nodes have a UUID as their type
            if "-" in node_type and len(node_type) == 36:  # UUID format check
                subgraph_id_to_parent[node_type] = node["id"]

        for subgraph in subgraphs:
            subgraph_id = subgraph.get("id", "")
            parent_node_id = subgraph_id_to_parent.get(subgraph_id)

            subgraph_nodes = subgraph.get("nodes", [])
            for node in subgraph_nodes:
                # Track which subgraph (parent node) this node belongs to
                if parent_node_id is not None:
                    node_to_subgraph_map[node["id"]] = parent_node_id
            all_nodes.extend(subgraph_nodes)

            # Also build link_to_node_map from subgraph links
            subgraph_links = subgraph.get("links", [])
            for link in subgraph_links:
                # link format: [link_id, origin_id, origin_slot, target_id, target_slot, type]
                if isinstance(link, dict):
                    link_to_node_map[link["id"]] = link["origin_id"]
                elif isinstance(link, list) and len(link) >= 2:
                    link_to_node_map[link[0]] = link[1]

        for node in all_nodes:
            if node_title:
                # Untitled nodes simply do not match; a miss is reported once below.
                if node.get("title") == node_title:
                    node_id = node["id"]
                    break
            elif id != 0:
                if node["id"] == id:
                    node_id = id
                    break
            elif any_input is not None:
                if node["type"] == "WidgetToString" and node["id"] == unique_id_int and not link_id:
                    for node_input in node["inputs"]:
                        if node_input["name"] == "any_input":
                            link_id = node_input["link"]

                # Construct a map of links to node IDs for future reference
                node_outputs = node.get("outputs", None)
                if not node_outputs:
                    continue
                for output in node_outputs:
                    node_links = output.get("links", None)
                    if not node_links:
                        continue
                    for link in node_links:
                        link_to_node_map[link] = node["id"]
                        if link_id and link == link_id:
                            break

        if link_id:
            node_id = link_to_node_map.get(link_id, None)

        if node_id is None:
            if node_title:
                logging.warning("Node title not found.")
            raise ValueError("No matching node found for the given title or id")

        # Determine the candidate prompt keys: a prefixed key first (the target's
        # own subgraph parent wins over this node's prefix), then the bare node id.
        target_subgraph_parent = node_to_subgraph_map.get(node_id)
        candidate_keys = []
        if target_subgraph_parent is not None:
            candidate_keys.append(f"{target_subgraph_parent}:{node_id}")
        elif subgraph_prefix is not None:
            candidate_keys.append(f"{subgraph_prefix}:{node_id}")
        candidate_keys.append(str(node_id))

        prompt_key = next((key for key in candidate_keys if key in prompt), None)
        if prompt_key is None:
            # Report the keys that were really tried, not a hard-coded guess.
            tried = " and ".join(f"'{key}'" for key in candidate_keys)
            raise KeyError(f"Node not found in prompt. Tried keys: {tried}")

        values = prompt[prompt_key]
        if "inputs" in values:
            inputs = values["inputs"]

            # support comma-separated list and trim whitespace
            widget_names = []
            if widget_name:
                widget_names = [w.strip() for w in widget_name.split(",") if w.strip()]

            if return_all:
                formatted_items = [
                    f"{k}: {self._format_value(v, allowed_float_decimals)}"
                    for k, v in inputs.items()
                ]
                results.append(", ".join(formatted_items))

            # Single widget name (trimmed): return the bare value
            elif len(widget_names) == 1:
                name = widget_names[0]
                if name not in inputs:
                    raise NameError(f"Widget not found: {node_id}.{name}")
                return (self._format_value(inputs[name], allowed_float_decimals), )

            # Multiple widget names: return "name: value" pairs
            elif len(widget_names) > 1:
                formatted_items = []
                for name in widget_names:
                    if name not in inputs:
                        raise NameError(f"Widget not found: {node_id}.{name}")
                    formatted_items.append(f"{name}: {self._format_value(inputs[name], allowed_float_decimals)}")
                return (", ".join(formatted_items), )

            else:
                # No valid widget name provided
                raise NameError(f"Widget not found: {node_id}.{widget_name}")

        return (", ".join(results).strip(", "), )


class DummyOut:
    """Output node that hands its single input straight back."""

    RETURN_TYPES = (IO.ANY,)
    FUNCTION = "dummy"
    CATEGORY = "KJNodes/misc"
    OUTPUT_NODE = True
    DESCRIPTION = """
Does nothing, used to trigger generic workflow output.    
A way to get previews in the UI without saving anything to disk.
"""

    @classmethod
    def INPUT_TYPES(cls):
        # A single required input that accepts any type.
        required_inputs = {"any_input": (IO.ANY, )}
        return {"required": required_inputs}

    def dummy(self, any_input):
        """Return the input unchanged, wrapped in a one-element tuple."""
        passthrough = (any_input,)
        return passthrough
    
class FlipSigmasAdjusted:
    """Reverses a sigma schedule, shifts it by an index offset and rescales it."""

    RETURN_TYPES = ("SIGMAS", "STRING",)
    RETURN_NAMES = ("SIGMAS", "sigmas_string",)
    CATEGORY = "KJNodes/noise"
    FUNCTION = "get_sigmas_adjusted"

    @classmethod
    def INPUT_TYPES(s):
        required = {
            "sigmas": ("SIGMAS", ),
            "divide_by_last_sigma": ("BOOLEAN", {"default": False}),
            "divide_by": ("FLOAT", {"default": 1,"min": 1, "max": 255, "step": 0.01}),
            "offset_by": ("INT", {"default": 1,"min": -100, "max": 100, "step": 1}),
        }
        return {"required": required}

    def get_sigmas_adjusted(self, sigmas, divide_by_last_sigma, divide_by, offset_by):
        """Flip `sigmas`, shift entries 1..n-1 by `offset_by`, then rescale.

        Returns a tuple of (adjusted sigmas divided by `divide_by`, string
        rendering of the sigmas taken BEFORE the `divide_by` division).
        """
        floor = 0.0001  # replacement for a zero first sigma and for out-of-range slots

        flipped = sigmas.flip(0)
        if flipped[0] == 0:
            flipped[0] = floor

        shifted = flipped.clone()
        count = len(flipped)
        # Index 0 is never shifted; every later slot reads from `offset_by` steps earlier.
        for dst in range(1, count):
            src = dst - offset_by
            shifted[dst] = flipped[src] if 0 <= src < count else floor

        if shifted[0] == 0:
            shifted[0] = floor

        if divide_by_last_sigma:
            shifted = shifted / shifted[-1]

        # The string is produced before the final division by `divide_by`.
        array_string = np.array2string(
            shifted.numpy(), precision=2, separator=', ', threshold=np.inf
        )
        return (shifted / divide_by, array_string,)
    
class CustomSigmas:
    """Builds a sigma schedule from a comma separated string of numbers."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required":
                    {
                     "sigmas_string" :("STRING", {"default": "14.615, 6.475, 3.861, 2.697, 1.886, 1.396, 0.963, 0.652, 0.399, 0.152, 0.029","multiline": True}),
                     "interpolate_to_steps": ("INT", {"default": 10,"min": 0, "max": 255, "step": 1}),
                     }
                }
    RETURN_TYPES = ("SIGMAS",)
    RETURN_NAMES = ("SIGMAS",)
    CATEGORY = "KJNodes/noise"
    FUNCTION = "customsigmas"
    DESCRIPTION = """
Creates a sigmas tensor from a string of comma separated values.  
Examples: 
   
Nvidia's optimized AYS 10 step schedule for SD 1.5:  
14.615, 6.475, 3.861, 2.697, 1.886, 1.396, 0.963, 0.652, 0.399, 0.152, 0.029  
SDXL:   
14.615, 6.315, 3.771, 2.181, 1.342, 0.862, 0.555, 0.380, 0.234, 0.113, 0.029  
SVD:  
700.00, 54.5, 15.886, 7.977, 4.248, 1.789, 0.981, 0.403, 0.173, 0.034, 0.002  
"""
    def customsigmas(self, sigmas_string, interpolate_to_steps):
        """Parse `sigmas_string` and resample it to `interpolate_to_steps + 1` values.

        Values may be separated by commas with any surrounding whitespace
        (including newlines); empty entries such as a trailing comma are ignored.
        The last sigma of the result is always forced to 0.

        Raises:
            ValueError: If the string holds no numbers or a token is not a number.
        """
        # Split on bare commas and strip, so "1,2", "1, 2" and "1,\n2" all parse.
        tokens = (token.strip() for token in sigmas_string.split(','))
        sigmas_float_list = [float(token) for token in tokens if token]
        if not sigmas_float_list:
            raise ValueError("CustomSigmas: sigmas_string must contain at least one number")

        sigmas_tensor = torch.tensor(sigmas_float_list, dtype=torch.float32)
        if len(sigmas_tensor) != interpolate_to_steps + 1:
            sigmas_tensor = self.loglinear_interp(sigmas_tensor, interpolate_to_steps + 1)
        sigmas_tensor[-1] = 0
        return (sigmas_tensor.float(),)
     
    def loglinear_interp(self, t_steps, num_steps):
        """
        Performs log-linear interpolation of a given array of decreasing numbers.

        Returns a tensor with `num_steps` entries, in the same (decreasing) order.
        """
        t_steps_np = t_steps.numpy()

        # np.interp needs increasing sample positions, so work on the reversed
        # values in log space and flip the result back at the end.
        xs = np.linspace(0, 1, len(t_steps_np))
        ys = np.log(t_steps_np[::-1])
        
        new_xs = np.linspace(0, 1, num_steps)
        new_ys = np.interp(new_xs, xs, ys)
        
        # .copy() makes the reversed view contiguous before handing it to torch.
        interped_ys = np.exp(new_ys)[::-1].copy()
        interped_ys_tensor = torch.tensor(interped_ys)
        return interped_ys_tensor
    
class StringToFloatList:
    """Converts a comma separated string into a list of floats."""

    RETURN_TYPES = ("FLOAT",)
    RETURN_NAMES = ("FLOAT",)
    CATEGORY = "KJNodes/misc"
    FUNCTION = "createlist"

    @classmethod
    def INPUT_TYPES(s):
        string_widget = ("STRING", {"default": "1, 2, 3", "multiline": True})
        return {"required": {"string": string_widget}}

    def createlist(self, string):
        """Split `string` on commas and convert every stripped piece to float."""
        values = []
        for token in string.split(','):
            values.append(float(token.strip()))
        return (values,)

 
class InjectNoiseToLatent:
    """Adds a noise latent (and optionally fresh gaussian noise) to a latent."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "latents":("LATENT",),  
            "strength": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 200.0, "step": 0.0001}),
            "noise":  ("LATENT",),
            "normalize": ("BOOLEAN", {"default": False}),
            "average": ("BOOLEAN", {"default": False}),
            },
            "optional":{
                "mask": ("MASK", ),
                "mix_randn_amount": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1000.0, "step": 0.001}),
                "seed": ("INT", {"default": 123,"min": 0, "max": 0xffffffffffffffff, "step": 1}),
            }
            }
    
    RETURN_TYPES = ("LATENT",)
    FUNCTION = "injectnoise"
    CATEGORY = "KJNodes/noise"
        
    def injectnoise(self, latents, strength, noise, normalize, average, mix_randn_amount=0, seed=None, mask=None):
        """Blend `noise` into `latents`.

        Args:
            latents: Dict with the base latent tensor under "samples".
            strength: Multiplier for the noise when `average` is False.
            noise: Dict with the noise tensor under "samples"; must match the
                shape of the base latent.
            normalize: Divide the result by its standard deviation.
            average: Use the mean of latent and noise instead of a weighted sum.
            mix_randn_amount: Amount of additional seeded gaussian noise.
            seed: Seed for the additional noise; nothing is mixed in when None.
            mask: Optional mask; the noised result is only kept where the mask is set.

        Returns:
            One-element tuple holding {"samples": tensor} (CPU tensor).

        Raises:
            ValueError: If latent and noise shapes differ.
        """
        samples = latents["samples"].clone().cpu()
        noise = noise["samples"].clone().cpu()
        # Compare the latent against the NOISE (the previous check compared the
        # latent with itself and could never trigger).
        if samples.shape != noise.shape:
            raise ValueError("InjectNoiseToLatent: Latent and noise must have the same shape")
        if average:
            noised = (samples + noise) / 2
        else:
            noised = samples + noise * strength
        if normalize:
            noised = noised / noised.std()
        if mask is not None:
            # Resize the mask to the latent resolution, broadcast it over the
            # channels and tile it along the batch when it has fewer entries.
            mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(noised.shape[2], noised.shape[3]), mode="bilinear")
            mask = mask.expand((-1,noised.shape[1],-1,-1))
            if mask.shape[0] < noised.shape[0]:
                mask = mask.repeat((noised.shape[0] -1) // mask.shape[0] + 1, 1, 1, 1)[:noised.shape[0]]
            noised = mask * noised + (1-mask) * samples
        if mix_randn_amount > 0 and seed is not None:
            # torch.manual_seed seeds and returns the default generator.
            generator = torch.manual_seed(seed)
            rand_noise = torch.randn(noised.size(), dtype=noised.dtype, layout=noised.layout, generator=generator, device="cpu")
            noised = noised + (mix_randn_amount * rand_noise)
        
        return ({"samples":noised},)
 
class SoundReactive:
    """Scales an incoming sound level and also exposes it as an integer."""

    RETURN_TYPES = ("FLOAT","INT",)
    RETURN_NAMES =("sound_level", "sound_level_int",)
    FUNCTION = "react"
    CATEGORY = "KJNodes/audio"
    DESCRIPTION = """
Reacts to the sound level of the input.  
Uses your browsers sound input options and requires.  
Meant to be used with realtime diffusion with autoqueue.
"""

    @classmethod
    def INPUT_TYPES(s):
        required = {
            "sound_level": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 99999, "step": 0.01}),
            "start_range_hz": ("INT", {"default": 150, "min": 0, "max": 9999, "step": 1}),
            "end_range_hz": ("INT", {"default": 2000, "min": 0, "max": 9999, "step": 1}),
            "multiplier": ("FLOAT", {"default": 1.0, "min": 0.01, "max": 99999, "step": 0.01}),
            "smoothing_factor": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
            "normalize": ("BOOLEAN", {"default": False}),
        }
        return {"required": required}

    def react(self, sound_level, start_range_hz, end_range_hz, smoothing_factor, multiplier, normalize):
        """Return (level, int(level)) after applying `multiplier` and optional /255 scaling.

        `start_range_hz`, `end_range_hz` and `smoothing_factor` are not used here.
        """
        level = sound_level * multiplier
        if normalize:
            level = level / 255
        return (level, int(level), )
       
class GenerateNoise:
    """Creates seeded gaussian noise shaped like a latent."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": { 
            "width": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}),
            "height": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}),
            "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
            "seed": ("INT", {"default": 123,"min": 0, "max": 0xffffffffffffffff, "step": 1}),
            "multiplier": ("FLOAT", {"default": 1.0,"min": 0.0, "max": 4096, "step": 0.01}),
            "constant_batch_noise": ("BOOLEAN", {"default": False}),
            "normalize": ("BOOLEAN", {"default": False}),
            },
            "optional": {
                "model": ("MODEL", ),
                "sigmas": ("SIGMAS", ),
                "latent_channels": (['4', '16', ],),
                "shape": (["BCHW", "BCTHW","BTCHW",],),
            }
        }
    
    RETURN_TYPES = ("LATENT",)
    FUNCTION = "generatenoise"
    CATEGORY = "KJNodes/noise"
    DESCRIPTION = """
Generates noise for injection or to be used as empty latents on samplers with add_noise off.
"""
        
    def generatenoise(self, batch_size, width, height, seed, multiplier, constant_batch_noise, normalize, sigmas=None, model=None, latent_channels=4, shape="BCHW"):
        """Generate noise of spatial size (height // 8, width // 8).

        Args:
            batch_size: Number of items; placed on the B axis for "BCHW" and on
                the T axis for "BCTHW" / "BTCHW" (whose B axis is 1).
            width, height: Pixel size; divided by 8 for the noise resolution.
            seed: Seed for the noise.
            multiplier: Final scale applied to the noise.
            constant_batch_noise: Repeat the first item along the axis that
                carries `batch_size`, so every item gets identical noise.
            normalize: Divide the noise by its standard deviation.
            sigmas: Optional schedule; the noise is scaled by
                (sigmas[0] - sigmas[-1]) / model latent scale factor.
            model: Required when `sigmas` is given.
            latent_channels: Channel count (int or numeric string).
            shape: One of "BCHW", "BCTHW", "BTCHW".

        Returns:
            One-element tuple holding {"samples": tensor}.

        Raises:
            ValueError: Unknown `shape`, or `sigmas` given without `model`.
        """
        channels = int(latent_channels)
        latent_h, latent_w = height // 8, width // 8
        # layout name -> (tensor dimensions, index of the axis holding batch_size)
        layouts = {
            "BCHW": ([batch_size, channels, latent_h, latent_w], 0),
            "BCTHW": ([1, channels, batch_size, latent_h, latent_w], 2),
            "BTCHW": ([1, batch_size, channels, latent_h, latent_w], 1),
        }
        if shape not in layouts:
            raise ValueError(f"GenerateNoise: unsupported shape '{shape}'")
        dims, batch_axis = layouts[shape]

        # torch.manual_seed seeds and returns the default generator.
        generator = torch.manual_seed(seed)
        noise = torch.randn(dims, dtype=torch.float32, layout=torch.strided, generator=generator, device="cpu")

        if sigmas is not None:
            if model is None:
                raise ValueError("GenerateNoise: 'model' is required when 'sigmas' is provided")
            sigma = sigmas[0] - sigmas[-1]
            sigma /= model.model.latent_format.scale_factor
            noise *= sigma

        noise *=multiplier

        if normalize:
            noise = noise / noise.std()
        if constant_batch_noise:
            # Repeat along the axis that actually holds the batch/frames; the old
            # noise[0].repeat(batch_size, 1, 1, 1) produced a wrongly shaped
            # tensor for the 5D layouts.
            first_item = noise.narrow(batch_axis, 0, 1)
            repeats = [1] * noise.dim()
            repeats[batch_axis] = batch_size
            noise = first_item.repeat(*repeats)

        return ({"samples":noise}, )

def camera_embeddings(elevation, azimuth):
    """Build a 4-value camera embedding from angles given in degrees.

    The last axis holds, in order: the polar term in radians, sin(azimuth),
    cos(azimuth) and a constant 90 degree term in radians. A singleton axis
    is inserted at position 1 of the stacked result.
    """
    elev = torch.as_tensor([elevation])
    azim_rad = torch.deg2rad(torch.as_tensor([azimuth]))

    # Zero123 polar is 90-elevation
    polar = torch.deg2rad((90 - elev) - (90))
    constant_term = torch.deg2rad(90 - torch.full_like(elev, 0))

    components = [polar, torch.sin(azim_rad), torch.cos(azim_rad), constant_term]
    return torch.stack(components, dim=-1).unsqueeze(1)

def interpolate_angle(start, end, fraction):
    """Interpolate between two angles in degrees along the shorter direction.

    The result is wrapped into the half-open range [-180, 180).
    """
    # Signed difference folded into [-180, 180) so wraparound is handled.
    shortest_delta = (end - start + 540) % 360 - 180
    raw_angle = start + fraction * shortest_delta
    # Fold the interpolated angle back into [-180, 180).
    wrapped = (raw_angle + 180) % 360 - 180
    return wrapped


class StableZero123_BatchSchedule:
    """Per-frame azimuth/elevation keyframe schedule producing batched conditioning."""

    RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
    RETURN_NAMES = ("positive", "negative", "latent")
    FUNCTION = "encode"
    CATEGORY = "KJNodes/experimental"

    @classmethod
    def INPUT_TYPES(s):
        required = {
            "clip_vision": ("CLIP_VISION",),
            "init_image": ("IMAGE",),
            "vae": ("VAE",),
            "width": ("INT", {"default": 256, "min": 16, "max": MAX_RESOLUTION, "step": 8}),
            "height": ("INT", {"default": 256, "min": 16, "max": MAX_RESOLUTION, "step": 8}),
            "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
            "interpolation": (["linear", "ease_in", "ease_out", "ease_in_out"],),
            "azimuth_points_string": ("STRING", {"default": "0:(0.0),\n7:(1.0),\n15:(0.0)\n", "multiline": True}),
            "elevation_points_string": ("STRING", {"default": "0:(0.0),\n7:(0.0),\n15:(0.0)\n", "multiline": True}),
        }
        return {"required": required}

    def encode(self, clip_vision, init_image, vae, width, height, batch_size, azimuth_points_string, elevation_points_string, interpolation):
        """Encode the image once and build one conditioning entry per frame.

        The schedule strings hold comma separated "frame:(value)" keyframes.
        For each frame index in range(batch_size) the azimuth and elevation
        are interpolated between the surrounding keyframes with
        `interpolate_angle`, after the optional easing of the fraction.

        Returns:
            (positive, negative, {"samples": zero latent}).
        """
        vision_out = clip_vision.encode_image(init_image)
        pooled = vision_out.image_embeds.unsqueeze(0)
        resized = common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1)
        # Only the first three channels are passed to the VAE.
        t = vae.encode(resized[:,:,:,:3])

        # "linear" has no entry: the fraction is then used unchanged.
        easing = {
            "ease_in": lambda x: x * x,
            "ease_out": lambda x: 1 - (1 - x) * (1 - x),
            "ease_in_out": lambda x: 3 * x * x - 2 * x * x * x,
        }
        ease = easing.get(interpolation)

        def parse_points(text):
            # "frame:(value)" entries -> list of (frame, value) sorted by frame
            points = []
            for entry in text.rstrip(',\n').split(','):
                frame_part, value_part = entry.split(':')
                # [1:-1] drops the surrounding parentheses of the value
                points.append((int(frame_part.strip()), float(value_part.strip()[1:-1])))
            points.sort(key=lambda point: point[0])
            return points

        def value_at(points, frame):
            # Index of the first keyframe located after `frame`, clamped to the last one
            upper = 1
            while upper < len(points) and frame >= points[upper][0]:
                upper += 1
            if upper == len(points):
                upper -= 1
            lower = max(upper - 1, 0)

            start_frame, start_value = points[lower]
            end_frame, end_value = points[upper]
            if end_frame == start_frame:
                # Same keyframe on both sides; also avoids a division by zero
                return start_value

            fraction = (frame - start_frame) / (end_frame - start_frame)
            if ease is not None:
                fraction = ease(fraction)
            return interpolate_angle(start_value, end_value, fraction)

        azimuth_points = parse_points(azimuth_points_string)
        elevation_points = parse_points(elevation_points_string)

        positive_conds, positive_latents = [], []
        negative_conds, negative_latents = [], []
        for frame in range(batch_size):
            frame_azimuth = value_at(azimuth_points, frame)
            frame_elevation = value_at(elevation_points, frame)

            cam_embeds = camera_embeddings(frame_elevation, frame_azimuth)
            cond = torch.cat([pooled, cam_embeds.repeat((pooled.shape[0], 1, 1))], dim=-1)

            positive_latents.append(t)
            positive_conds.append(cond)
            negative_latents.append(torch.zeros_like(t))
            negative_conds.append(torch.zeros_like(pooled))

        # Stack the per-frame entries along the batch axis
        final_positive = [[
            torch.cat(positive_conds, dim=0),
            {"concat_latent_image": torch.cat(positive_latents, dim=0)},
        ]]
        final_negative = [[
            torch.cat(negative_conds, dim=0),
            {"concat_latent_image": torch.cat(negative_latents, dim=0)},
        ]]

        latent = torch.zeros([batch_size, 4, height // 8, width // 8])
        return (final_positive, final_negative, {"samples": latent})

def linear_interpolate(start, end, fraction):
    """Return the value located `fraction` of the way from `start` to `end`."""
    span = end - start
    return start + span * fraction

class SV3D_BatchSchedule:
    """Per-frame azimuth/elevation keyframe schedule attached to SV3D conditioning."""

    RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
    RETURN_NAMES = ("positive", "negative", "latent")
    FUNCTION = "encode"
    CATEGORY = "KJNodes/experimental"
    DESCRIPTION = """
Allow scheduling of the azimuth and elevation conditions for SV3D.  
Note that SV3D is still a video model and the schedule needs to always go forward  
https://huggingface.co/stabilityai/sv3d
"""

    @classmethod
    def INPUT_TYPES(s):
        required = {
            "clip_vision": ("CLIP_VISION",),
            "init_image": ("IMAGE",),
            "vae": ("VAE",),
            "width": ("INT", {"default": 576, "min": 16, "max": MAX_RESOLUTION, "step": 8}),
            "height": ("INT", {"default": 576, "min": 16, "max": MAX_RESOLUTION, "step": 8}),
            "batch_size": ("INT", {"default": 21, "min": 1, "max": 4096}),
            "interpolation": (["linear", "ease_in", "ease_out", "ease_in_out"],),
            "azimuth_points_string": ("STRING", {"default": "0:(0.0),\n9:(180.0),\n20:(360.0)\n", "multiline": True}),
            "elevation_points_string": ("STRING", {"default": "0:(0.0),\n9:(0.0),\n20:(0.0)\n", "multiline": True}),
        }
        return {"required": required}

    def encode(self, clip_vision, init_image, vae, width, height, batch_size, azimuth_points_string, elevation_points_string, interpolation):
        """Encode the image once and attach per-frame azimuth/elevation lists.

        The schedule strings hold comma separated "frame:(value)" keyframes.
        For each frame index in range(batch_size) both angles are linearly
        interpolated between the surrounding keyframes, after the optional
        easing of the fraction.

        Returns:
            (positive, negative, {"samples": zero latent}); both conditionings
            carry the "elevation" and "azimuth" lists.
        """
        vision_out = clip_vision.encode_image(init_image)
        pooled = vision_out.image_embeds.unsqueeze(0)
        resized = common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1)
        # Only the first three channels are passed to the VAE.
        t = vae.encode(resized[:,:,:,:3])

        def apply_easing(fraction):
            # "linear" (or any other value) leaves the fraction unchanged
            if interpolation == "ease_in":
                return fraction * fraction
            if interpolation == "ease_out":
                return 1 - (1 - fraction) * (1 - fraction)
            if interpolation == "ease_in_out":
                return 3 * fraction * fraction - 2 * fraction * fraction * fraction
            return fraction

        def parse_points(text):
            # "frame:(value)" entries -> list of (frame, value) sorted by frame
            points = []
            for entry in text.rstrip(',\n').split(','):
                frame_part, value_part = entry.split(':')
                # [1:-1] drops the surrounding parentheses of the value
                points.append((int(frame_part.strip()), float(value_part.strip()[1:-1])))
            points.sort(key=lambda point: point[0])
            return points

        def value_at(points, frame):
            # Index of the first keyframe located after `frame`, clamped to the last one
            upper = 1
            while upper < len(points) and frame >= points[upper][0]:
                upper += 1
            if upper == len(points):
                upper -= 1
            lower = max(upper - 1, 0)

            start_frame, start_value = points[lower]
            end_frame, end_value = points[upper]
            if end_frame == start_frame:
                # Same keyframe on both sides; also avoids a division by zero
                return start_value

            fraction = apply_easing((frame - start_frame) / (end_frame - start_frame))
            return linear_interpolate(start_value, end_value, fraction)

        azimuth_points = parse_points(azimuth_points_string)
        elevation_points = parse_points(elevation_points_string)

        frames = range(batch_size)
        azimuths = [value_at(azimuth_points, frame) for frame in frames]
        elevations = [value_at(elevation_points, frame) for frame in frames]

        # Positive and negative conditioning share the same schedule lists
        positive_extras = {"concat_latent_image": t, "elevation": elevations, "azimuth": azimuths}
        negative_extras = {"concat_latent_image": torch.zeros_like(t),"elevation": elevations, "azimuth": azimuths}
        final_positive = [[pooled, positive_extras]]
        final_negative = [[torch.zeros_like(pooled), negative_extras]]

        latent = torch.zeros([batch_size, 4, height // 8, width // 8])
        return (final_positive, final_negative, {"samples": latent})

        
class Superprompt:
    """Node that rewrites a prompt with the SuperPrompt T5 checkpoint.

    The checkpoint is looked up in ``<script_directory>/models/superprompt-v1``
    and fetched from the Hugging Face Hub when that folder does not exist.
    """

    @classmethod
    def INPUT_TYPES(s):
        required_inputs = {
            "instruction_prompt": ("STRING", {"default": 'Expand the following prompt to add more detail', "multiline": True}),
            "prompt": ("STRING", {"default": '', "multiline": True, "forceInput": True}),
            "max_new_tokens": ("INT", {"default": 128, "min": 1, "max": 4096, "step": 1}),
        }
        return {"required": required_inputs}

    RETURN_TYPES = ("STRING",)
    FUNCTION = "process"
    CATEGORY = "KJNodes/text"
    DESCRIPTION = """
# SuperPrompt
A T5 model fine-tuned on the SuperPrompt dataset for  
upsampling text prompts to more detailed descriptions.  
Meant to be used as a pre-generation step for text-to-image  
models that benefit from more detailed prompts.  
https://huggingface.co/roborovski/superprompt-v1
"""

    def process(self, instruction_prompt, prompt, max_new_tokens):
        """Generate an expanded prompt.

        Args:
            instruction_prompt: Instruction text placed before the prompt.
            prompt: The prompt to expand.
            max_new_tokens: Forwarded to ``model.generate``.

        Returns:
            One-element tuple holding the decoded text with the ``<pad>`` and
            ``</s>`` markers removed.
        """
        device = model_management.get_torch_device()
        from transformers import T5Tokenizer, T5ForConditionalGeneration

        checkpoint_path = os.path.join(script_directory, "models", "superprompt-v1")
        checkpoint_missing = not os.path.exists(checkpoint_path)
        if checkpoint_missing:
            logging.info(f"Downloading model to: {checkpoint_path}")
            from huggingface_hub import snapshot_download
            snapshot_download(
                repo_id="roborovski/superprompt-v1",
                local_dir=checkpoint_path,
                local_dir_use_symlinks=False,
            )

        # The tokenizer comes from the flan-t5-small repo, the weights from the local checkpoint.
        tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-small", legacy=False)
        model = T5ForConditionalGeneration.from_pretrained(checkpoint_path, device_map=device)
        model.to(device)

        full_text = ": ".join((instruction_prompt, prompt))
        token_ids = tokenizer(full_text, return_tensors="pt").input_ids.to(device)
        generated = model.generate(token_ids, max_new_tokens=max_new_tokens)

        decoded = tokenizer.decode(generated[0])
        cleaned = decoded.replace('<pad>', '').replace('</s>', '')
        return (cleaned, )


class CameraPoseVisualizer:
    """Node that draws camera poses as colored pyramids in a 3D matplotlib plot."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "pose_file_path": ("STRING", {"default": '', "multiline": False}),
            "base_xval": ("FLOAT", {"default": 0.2,"min": 0, "max": 100, "step": 0.01}),
            "zval": ("FLOAT", {"default": 0.3,"min": 0, "max": 100, "step": 0.01}),
            "scale": ("FLOAT", {"default": 1.0,"min": 0.01, "max": 10.0, "step": 0.01}),
            "use_exact_fx": ("BOOLEAN", {"default": False}),
            "relative_c2w": ("BOOLEAN", {"default": True}),
            "use_viewer": ("BOOLEAN", {"default": False}),
            },
            "optional": {
                "cameractrl_poses": ("CAMERACTRL_POSES", {"default": None}),
            }
            }
    
    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "plot"
    CATEGORY = "KJNodes/misc"
    DESCRIPTION = """
Visualizes the camera poses, from Animatediff-Evolved CameraCtrl Pose  
or a .txt file with RealEstate camera intrinsics and coordinates, in a 3D plot. 
"""

    def plot(self, pose_file_path, scale, base_xval, zval, use_exact_fx, relative_c2w, use_viewer, cameractrl_poses=None):
        """Render the poses and return the plot as an image tensor.

        Poses are read from ``pose_file_path`` when it is non-empty, otherwise
        from ``cameractrl_poses``. In both cases index 1 of a pose is used as
        its fx value and the values from index 7 onwards are reshaped into a
        3x4 world-to-camera matrix.

        Args:
            pose_file_path: Text file with one whitespace-separated pose per
                line; its first line is skipped.
            scale: Half-extent multiplier for the axis limits (limits are
                ``-2 * scale`` .. ``2 * scale`` on every axis).
            base_xval: Half-width of each drawn pyramid base.
            zval: Pyramid depth used when ``use_exact_fx`` is false.
            use_exact_fx: Use each pose's fx value as the pyramid depth.
            relative_c2w: Express poses relative to the first camera.
            use_viewer: Also call ``plt.show()`` after rendering.
            cameractrl_poses: Sequence of poses used when no file is given.

        Returns:
            One-element tuple with the rendered PNG as a tensor permuted to
            (1, height, width, channels).

        Raises:
            ValueError: If neither a pose file nor ``cameractrl_poses`` is given.
        """
        import matplotlib as mpl
        import matplotlib.pyplot as plt
        from torchvision.transforms import ToTensor

        # Parse the poses before any figure exists so that a malformed input
        # cannot leave a half-built figure behind.
        if pose_file_path != "":
            with open(pose_file_path, 'r') as f:
                # The first line is never parsed as a pose. Blank lines (e.g. a
                # trailing newline) carry no pose and would break the reshape.
                rows = [line.split() for line in f.readlines()[1:] if line.strip()]
            w2cs = [np.asarray([float(p) for p in row[7:]]).reshape(3, 4) for row in rows]
            fxs = [float(row[1]) for row in rows]
        elif cameractrl_poses is not None:
            w2cs = [np.array(pose[7:]).reshape(3, 4) for pose in cameractrl_poses]
            fxs = [pose[1] for pose in cameractrl_poses]
        else:
            raise ValueError("Please provide either pose_file_path or cameractrl_poses")

        axis_limit = 2.0 * scale
        plt.rcParams['text.color'] = '#999999'
        self.fig = plt.figure(figsize=(18, 7))
        try:
            self.fig.patch.set_facecolor('#353535')
            self.ax = self.fig.add_subplot(projection='3d')
            self.ax.set_facecolor('#353535') # Set the background color here
            self.ax.grid(color='#999999', linestyle='-', linewidth=0.5)
            self.plotly_data = None  # plotly data traces
            self.ax.set_aspect("auto")
            self.ax.set_xlim(-axis_limit, axis_limit)
            self.ax.set_ylim(-axis_limit, axis_limit)
            self.ax.set_zlim(-axis_limit, axis_limit)
            self.ax.set_xlabel('x', color='#999999')
            self.ax.set_ylabel('y', color='#999999')
            self.ax.set_zlabel('z', color='#999999')
            for text in self.ax.get_xticklabels() + self.ax.get_yticklabels() + self.ax.get_zticklabels():
                text.set_color('#999999')
            logging.info('initialize camera pose visualizer')

            total_frames = len(w2cs)
            transform_matrix = np.asarray([[1, 0, 0, 0], [0, 0, 1, 0], [0, -1, 0, 0], [0, 0, 0, 1]]).reshape(4, 4)
            # Extend every 3x4 matrix with the homogeneous row [0, 0, 0, 1].
            last_row = np.zeros((1, 4))
            last_row[0, -1] = 1.0

            w2cs = [np.concatenate((w2c, last_row), axis=0) for w2c in w2cs]
            c2ws = self.get_c2w(w2cs, transform_matrix, relative_c2w)

            for frame_idx, c2w in enumerate(c2ws):
                # The frame's position in the sequence picks its rainbow color.
                self.extrinsic2pyramid(c2w, frame_idx / total_frames, hw_ratio=1/1, base_xval=base_xval,
                                        zval=(fxs[frame_idx] if use_exact_fx else zval))

            # Colorbar mapping the rainbow colors back to frame indices.
            cmap = mpl.cm.rainbow
            norm = mpl.colors.Normalize(vmin=0, vmax=total_frames)
            colorbar = self.fig.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cmap), ax=self.ax, orientation='vertical')
            colorbar.set_label('Frame', color='#999999')
            colorbar.ax.yaxis.set_tick_params(colors='#999999')
            # One tick every 10th frame.
            colorbar.ax.yaxis.set_ticks(np.arange(0, total_frames, 10))

            plt.title('')
            plt.draw()
            with BytesIO() as buf:
                self.fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
                buf.seek(0)
                # ToTensor reads the pixel data while the buffer is still open.
                tensor_img = ToTensor()(Image.open(buf))
            tensor_img = tensor_img.permute(1, 2, 0).unsqueeze(0)
            if use_viewer:
                time.sleep(1)
                plt.show()
        finally:
            # pyplot keeps every figure alive until it is closed explicitly;
            # without this each execution would leak one figure.
            plt.close(self.fig)
        return (tensor_img,)

    def extrinsic2pyramid(self, extrinsic, color_map='red', hw_ratio=1/1, base_xval=1, zval=3):
        """Add one pyramid, transformed by ``extrinsic``, to ``self.ax``.

        Args:
            extrinsic: 4x4 matrix applied to the pyramid's homogeneous vertices.
            color_map: A color name, or a number passed to ``plt.cm.rainbow``.
            hw_ratio: Ratio applied to the y extent of the base.
            base_xval: Half-width of the base along x.
            zval: z coordinate of the base (the apex sits at the origin).
        """
        import matplotlib.pyplot as plt
        from mpl_toolkits.mplot3d.art3d import Poly3DCollection
        # Apex followed by the four base corners, in homogeneous coordinates.
        vertex_std = np.array([[0, 0, 0, 1],
                            [base_xval, -base_xval * hw_ratio, zval, 1],
                            [base_xval, base_xval * hw_ratio, zval, 1],
                            [-base_xval, base_xval * hw_ratio, zval, 1],
                            [-base_xval, -base_xval * hw_ratio, zval, 1]])
        vertex_transformed = vertex_std @ extrinsic.T
        # Four triangular sides plus the quadrilateral base.
        meshes = [[vertex_transformed[0, :-1], vertex_transformed[1][:-1], vertex_transformed[2, :-1]],
                            [vertex_transformed[0, :-1], vertex_transformed[2, :-1], vertex_transformed[3, :-1]],
                            [vertex_transformed[0, :-1], vertex_transformed[3, :-1], vertex_transformed[4, :-1]],
                            [vertex_transformed[0, :-1], vertex_transformed[4, :-1], vertex_transformed[1, :-1]],
                            [vertex_transformed[1, :-1], vertex_transformed[2, :-1], vertex_transformed[3, :-1], vertex_transformed[4, :-1]]]

        color = color_map if isinstance(color_map, str) else plt.cm.rainbow(color_map)

        self.ax.add_collection3d(
            Poly3DCollection(meshes, facecolors=color, linewidths=0.3, edgecolors=color, alpha=0.25))

    def customize_legend(self, list_label):
        """Add a legend with one rainbow-colored patch per label in ``list_label``."""
        from matplotlib.patches import Patch
        import matplotlib.pyplot as plt
        list_handle = []
        for idx, label in enumerate(list_label):
            color = plt.cm.rainbow(idx / len(list_label))
            patch = Patch(color=color, label=label)
            list_handle.append(patch)
        plt.legend(loc='right', bbox_to_anchor=(1.8, 0.5), handles=list_handle)

    def get_c2w(self, w2cs, transform_matrix, relative_c2w):
        """Convert world-to-camera matrices into transformed camera-to-world poses.

        Args:
            w2cs: Sequence of 4x4 world-to-camera matrices.
            transform_matrix: 4x4 matrix left-multiplied onto every pose.
            relative_c2w: When true, the first pose becomes the identity and the
                remaining poses are expressed relative to it; otherwise every
                matrix is simply inverted.

        Returns:
            float32 array holding one 4x4 pose per input matrix.
        """
        if relative_c2w:
            target_cam_c2w = np.array([
                [1, 0, 0, 0],
                [0, 1, 0, 0],
                [0, 0, 1, 0],
                [0, 0, 0, 1]
            ])
            abs2rel = target_cam_c2w @ w2cs[0]
            ret_poses = [target_cam_c2w, ] + [abs2rel @ np.linalg.inv(w2c) for w2c in w2cs[1:]]
        else:
            ret_poses = [np.linalg.inv(w2c) for w2c in w2cs]
        ret_poses = [transform_matrix @ x for x in ret_poses]
        return np.array(ret_poses, dtype=np.float32)
    
    
            
class CheckpointPerturbWeights:
    """Node that adds seeded Gaussian noise to the weights of a copied model."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "model": ("MODEL",),
            "joint_blocks": ("FLOAT", {"default": 0.02, "min": 0.001, "max": 10.0, "step": 0.001}),
            "final_layer": ("FLOAT", {"default": 0.02, "min": 0.001, "max": 10.0, "step": 0.001}),
            "rest_of_the_blocks": ("FLOAT", {"default": 0.02, "min": 0.001, "max": 10.0, "step": 0.001}),
            "seed": ("INT", {"default": 123,"min": 0, "max": 0xffffffffffffffff, "step": 1}),
            }
        }
    RETURN_TYPES = ("MODEL",)
    FUNCTION = "mod"
    OUTPUT_NODE = True

    CATEGORY = "KJNodes/experimental"

    def mod(self, seed, model, joint_blocks, final_layer, rest_of_the_blocks):
        """Return a deep copy of ``model`` with noise added to its diffusion model.

        Every tensor in the diffusion model's state dict receives zero-mean
        Gaussian noise whose standard deviation is the tensor's own standard
        deviation times a multiplier selected by the key prefix.

        Args:
            seed: Seed applied to the global torch (and CUDA) RNG.
            model: Model to copy; the input itself is not modified.
            joint_blocks: Multiplier for keys starting with ``joint_blocks``.
            final_layer: Multiplier for keys starting with ``final_layer``.
            rest_of_the_blocks: Multiplier for every other key.

        Returns:
            One-element tuple with the perturbed copy.
        """
        import copy
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        device = model_management.get_torch_device()
        model_copy = copy.deepcopy(model)
        model_copy.model.to(device)

        diffusion_model = model_copy.model.diffusion_model
        # Build the state dict once; calling state_dict() per key rebuilds it every time.
        state = diffusion_model.state_dict()

        pbar = ProgressBar(len(state))
        for k, v in state.items():
            # std() is undefined for non-float tensors and NaN for a single
            # element, so such entries are left untouched.
            if not v.is_floating_point() or v.numel() < 2:
                pbar.update(1)
                continue
            std = v.std()
            logging.info(f'{k}: {std}')
            if k.startswith('joint_blocks'):
                multiplier = joint_blocks
            elif k.startswith('final_layer'):
                multiplier = final_layer
            else:
                multiplier = rest_of_the_blocks
            noise = torch.normal(torch.zeros_like(v), torch.ones_like(v) * std * multiplier)
            # In-place add keeps the entry bound to the same tensor object.
            v.add_(noise.to(device))
            pbar.update(1)
        diffusion_model.load_state_dict(state)
        return model_copy,
    
class DifferentialDiffusionAdvanced():
    """Node that installs a thresholding denoise-mask function on a model clone."""

    @classmethod
    def INPUT_TYPES(s):
        required_inputs = {
            "model": ("MODEL", ),
            "samples": ("LATENT",),
            "mask": ("MASK",),
            "multiplier": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.001}),
        }
        return {"required": required_inputs}

    RETURN_TYPES = ("MODEL", "LATENT")
    FUNCTION = "apply"
    CATEGORY = "_for_testing"
    INIT = False

    def apply(self, model, samples, mask, multiplier):
        """Patch a clone of ``model`` and attach ``mask`` to a copy of ``samples``.

        The multiplier is stored on the node instance, where ``forward`` reads
        it later. The mask is reshaped to (-1, 1, H, W) and stored under the
        ``noise_mask`` key.

        Returns:
            Tuple of (patched model clone, latent dict with ``noise_mask``).
        """
        self.multiplier = multiplier

        patched = model.clone()
        patched.set_model_denoise_mask_function(self.forward)

        height, width = mask.shape[-2], mask.shape[-1]
        latent_out = samples.copy()
        latent_out["noise_mask"] = mask.reshape((-1, 1, height, width))
        return (patched, latent_out)

    def forward(self, sigma: torch.Tensor, denoise_mask: torch.Tensor, extra_options: dict):
        """Binarize ``denoise_mask`` against a threshold derived from the current step.

        The threshold is the normalized position of the current timestep
        between the first and the last schedule timestep, divided by
        ``self.multiplier``.

        Returns:
            Tensor of the mask's dtype holding 1 where the mask is at or above
            the threshold and 0 elsewhere.
        """
        sampling = extra_options["model"].inner_model.model_sampling
        schedule = extra_options["sigmas"]

        # The end of the range is the larger of sigma_min and the last sigma.
        final_sigma = schedule[-1]
        sigma_to = final_sigma if final_sigma > sampling.sigma_min else sampling.sigma_min
        sigma_from = schedule[0]

        ts_from = sampling.timestep(sigma_from)
        ts_to = sampling.timestep(sigma_to)
        ts_now = sampling.timestep(sigma[0])

        progress = (ts_now - ts_to) / (ts_from - ts_to)
        threshold = progress / self.multiplier

        keep = denoise_mask >= threshold
        return keep.to(denoise_mask.dtype)
    
class FluxBlockLoraSelect:
    """Node exposing one alpha value per block: 19 double and 38 single blocks."""

    def __init__(self):
        # Not read anywhere in this class; kept so existing instances keep the attribute.
        self.loaded_lora = None

    @classmethod
    def INPUT_TYPES(s):
        # Every block input shares the same widget definition.
        argument = ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1000.0, "step": 0.01})

        arg_dict = {"double_blocks.{}.".format(i): argument for i in range(19)}
        arg_dict.update({"single_blocks.{}.".format(i): argument for i in range(38)})

        return {"required": arg_dict}
    
    RETURN_TYPES = ("SELECTEDDITBLOCKS", )
    RETURN_NAMES = ("blocks", )
    # The output is the block selection, not a model.
    OUTPUT_TOOLTIPS = ("The selected per-block alpha values.",)
    FUNCTION = "load_lora"

    CATEGORY = "KJNodes/experimental"
    DESCRIPTION = "Select individual block alpha values, value of 0 removes the block altogether"

    def load_lora(self, **kwargs):
        """Return the received block-name -> alpha mapping unchanged, wrapped in a tuple."""
        return (kwargs,)
    
class HunyuanVideoBlockLoraSelect:
    """Node exposing one alpha value per block: 20 double and 40 single blocks."""

    @classmethod
    def INPUT_TYPES(s):
        # Every block input shares the same widget definition.
        argument = ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1000.0, "step": 0.01})

        arg_dict = {"double_blocks.{}.".format(i): argument for i in range(20)}
        arg_dict.update({"single_blocks.{}.".format(i): argument for i in range(40)})

        return {"required": arg_dict}
    
    RETURN_TYPES = ("SELECTEDDITBLOCKS", )
    RETURN_NAMES = ("blocks", )
    # The output is the block selection, not a model.
    OUTPUT_TOOLTIPS = ("The selected per-block alpha values.",)
    FUNCTION = "load_lora"

    CATEGORY = "KJNodes/hunyuanvideo"
    DESCRIPTION = "Select individual block alpha values, value of 0 removes the block altogether"

    def load_lora(self, **kwargs):
        """Return the received block-name -> alpha mapping unchanged, wrapped in a tuple."""
        return (kwargs,)

class Wan21BlockLoraSelect:
    """Node exposing one alpha value for each of 40 ``blocks.N.`` entries."""

    @classmethod
    def INPUT_TYPES(s):
        # Every block input shares the same widget definition.
        argument = ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1000.0, "step": 0.01})

        arg_dict = {"blocks.{}.".format(i): argument for i in range(40)}

        return {"required": arg_dict}
    
    RETURN_TYPES = ("SELECTEDDITBLOCKS", )
    RETURN_NAMES = ("blocks", )
    # The output is the block selection, not a model.
    OUTPUT_TOOLTIPS = ("The selected per-block alpha values.",)
    FUNCTION = "load_lora"

    CATEGORY = "KJNodes/wan"
    DESCRIPTION = "Select individual block alpha values, value of 0 removes the block altogether"

    def load_lora(self, **kwargs):
        """Return the received block-name -> alpha mapping unchanged, wrapped in a tuple."""
        return (kwargs,)

class LTX2BlockLoraSelect:
    """Node exposing one alpha value for each of 48 ``blocks.N.`` entries."""

    @classmethod
    def INPUT_TYPES(s):
        # Every block input shares the same widget definition.
        argument = ("FLOAT", {"default": 0.0, "min": 0.0, "max": 10000.0, "step": 0.01})

        arg_dict = {"blocks.{}.".format(i): argument for i in range(48)}

        return {"required": arg_dict}
    
    RETURN_TYPES = ("SELECTEDDITBLOCKS", )
    RETURN_NAMES = ("blocks", )
    # The output is the block selection, not a model.
    OUTPUT_TOOLTIPS = ("The selected per-block alpha values.",)
    FUNCTION = "load_lora"

    CATEGORY = "KJNodes/ltxv"
    DESCRIPTION = "Select individual block alpha values, value of 0 removes the block altogether"

    def load_lora(self, **kwargs):
        """Return the received block-name -> alpha mapping unchanged, wrapped in a tuple."""
        return (kwargs,)

    
class DiTBlockLoraLoader:
    """Node that applies a LoRA to a model with optional per-block alpha overrides."""

    def __init__(self):
        # (path, loaded state dict) of the most recently loaded LoRA file.
        self.loaded_lora = None

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
                "model": ("MODEL", {"tooltip": "The diffusion model the LoRA will be applied to."}),
                "strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01, "tooltip": "How strongly to modify the diffusion model. This value can be negative."}),
                
                },
                "optional": {
                    "lora_name": (folder_paths.get_filename_list("loras"), {"tooltip": "The name of the LoRA."}),
                    "opt_lora_path": ("STRING", {"forceInput": True, "tooltip": "Absolute path of the LoRA."}),
                    "blocks": ("SELECTEDDITBLOCKS",),
                }
               }
    
    RETURN_TYPES = ("MODEL", "STRING", )
    RETURN_NAMES = ("model", "rank", )
    OUTPUT_TOOLTIPS = ("The modified diffusion model.", "possible rank of the LoRA.")
    FUNCTION = "load_lora"
    CATEGORY = "KJNodes/lora"

    def load_lora(self, model, strength_model, lora_name=None, opt_lora_path=None, blocks=None):
        """Load a LoRA, optionally rescale or drop blocks, and patch a model clone.

        Args:
            model: Model to clone and patch. With ``None`` nothing is patched
                and ``None`` is returned in its place.
            strength_model: Strength passed to ``add_patches``.
            lora_name: Name resolved through the ``loras`` model folder.
            opt_lora_path: Explicit path; takes precedence over ``lora_name``.
            blocks: Mapping of block-name substring -> alpha. An alpha of 0
                removes every matching key; any other value replaces element 2
                of the matching adapter's ``weights`` tuple.

        Returns:
            Tuple of (patched model clone, rank as a string).

        Raises:
            ValueError: If neither ``lora_name`` nor ``opt_lora_path`` is given.
            FileNotFoundError: If ``lora_name`` cannot be resolved to a file.
        """
        # Both inputs are optional in INPUT_TYPES, so fail early with a clear
        # message instead of deep inside the path lookup.
        if not opt_lora_path and lora_name is None:
            raise ValueError("Please provide either lora_name or opt_lora_path")

        import comfy.lora

        if opt_lora_path:
            lora_path = opt_lora_path
        else:
            lora_path = folder_paths.get_full_path("loras", lora_name)
            if lora_path is None:
                raise FileNotFoundError(f"LoRA file not found: {lora_name}")
        
        lora = None
        if self.loaded_lora is not None:
            if self.loaded_lora[0] == lora_path:
                lora = self.loaded_lora[1]
            else:
                # Drop the stale entry before loading the next file.
                self.loaded_lora = None
        
        if lora is None:
            lora = load_torch_file(lora_path, safe_load=True)
            self.loaded_lora = (lora_path, lora)

        # The first key ending in "weight" is used to report the rank.
        weight_key = next((key for key in lora.keys() if key.endswith('weight')), None)
        if weight_key:
            logging.info(f"Shape of the first 'weight' key ({weight_key}): {lora[weight_key].shape}")
            rank = str(lora[weight_key].shape[0])
        else:
            logging.warning("No key ending with 'weight' found.")
            rank = "Couldn't find rank"

        key_map = {}
        if model is not None:
            key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)

        loaded = comfy.lora.load_lora(lora, key_map)

        if blocks is not None:
            # A set, so a key matched by several zero-alpha blocks is removed once.
            keys_to_delete = set()

            for block, ratio in blocks.items():
                for key in list(loaded.keys()):
                    if isinstance(key, str):
                        match = block in key
                    elif isinstance(key, tuple):
                        # Only string parts of a tuple key can contain a block name.
                        match = any(isinstance(part, str) and block in part for part in key)
                    else:
                        match = False

                    if not match:
                        continue

                    if ratio == 0:
                        keys_to_delete.add(key)
                        continue

                    # Only modify LoRA adapters, skip diff tuples
                    value = loaded[key]
                    if hasattr(value, 'weights'):
                        logging.info(f"Modifying LoRA adapter for key: {key}")
                        weights_list = list(value.weights)
                        weights_list[2] = ratio
                        value.weights = tuple(weights_list)
                    else:
                        logging.info(f"Skipping non-LoRA entry for key: {key}")

            for key in keys_to_delete:
                loaded.pop(key, None)

            logging.info("loading lora keys:")
            for key, value in loaded.items():
                if hasattr(value, 'weights'):
                    logging.info(f"Key: {key}, Alpha: {value.weights[2]}")
                else:
                    logging.info(f"Key: {key}, Type: {type(value)}")

        # Defaults keep the code below valid when no model was supplied.
        new_modelpatcher = None
        applied_keys = ()
        if model is not None:
            new_modelpatcher = model.clone()
            applied_keys = new_modelpatcher.add_patches(loaded, strength_model)

        applied_keys = set(applied_keys)
        for x in loaded:
            if x not in applied_keys:
                logging.warning(f"NOT LOADED {x}")

        return (new_modelpatcher, rank)
    
class CustomControlNetWeightsFluxFromList:
    """Node that builds Advanced-ControlNet weights from a list of floats."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "list_of_floats": ("FLOAT", {"forceInput": True}, ),
            },
            "optional": {
                "uncond_multiplier": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}, ),
                "cn_extras": ("CN_WEIGHTS_EXTRAS",),
                "autosize": ("ACNAUTOSIZE", {"padding": 0}),
            }
        }
    
    RETURN_TYPES = ("CONTROL_NET_WEIGHTS", "TIMESTEP_KEYFRAME",)
    RETURN_NAMES = ("CN_WEIGHTS", "TK_SHORTCUT")
    FUNCTION = "load_weights"
    DESCRIPTION = "Creates controlnet weights from a list of floats for Advanced-ControlNet"

    CATEGORY = "KJNodes/controlnet"

    def load_weights(self, list_of_floats: list[float],
                     uncond_multiplier: float=1.0, cn_extras: dict=None):
        """Create control weights and a default keyframe group wrapping them.

        Args:
            list_of_floats: Passed as ``weights_input`` to ``ControlWeights.controlnet``.
            uncond_multiplier: Passed through to ``ControlWeights.controlnet``.
            cn_extras: Passed as ``extras``; an empty dict is used when omitted.

        Returns:
            Tuple of (weights, default ``TimestepKeyframeGroup`` built from a
            keyframe holding those weights).
        """
        # A fresh dict per call: a ``{}`` default would be one object shared by every call.
        extras = {} if cn_extras is None else cn_extras

        # The package name contains a dash, so it cannot be imported with an import statement.
        adv_control = importlib.import_module("ComfyUI-Advanced-ControlNet.adv_control")
        ControlWeights = adv_control.utils.ControlWeights
        TimestepKeyframeGroup = adv_control.utils.TimestepKeyframeGroup
        TimestepKeyframe = adv_control.utils.TimestepKeyframe

        weights = ControlWeights.controlnet(weights_input=list_of_floats, uncond_multiplier=uncond_multiplier, extras=extras)
        logging.info(weights.weights_input)
        return (weights, TimestepKeyframeGroup.default(TimestepKeyframe(control_weights=weights)))
    
# Control type name -> numeric id; ids follow the listing order, starting at 0.
SHAKKERLABS_UNION_CONTROLNET_TYPES = {
    type_name: type_id
    for type_id, type_name in enumerate((
        "canny",
        "tile",
        "depth",
        "blur",
        "pose",
        "gray",
        "low quality",
    ))
}

class SetShakkerLabsUnionControlNetType:
    """Node that sets the ``control_type`` extra argument on a control net copy."""

    @classmethod
    def INPUT_TYPES(s):
        type_choices = ["auto"]
        type_choices.extend(SHAKKERLABS_UNION_CONTROLNET_TYPES.keys())
        return {"required": {"control_net": ("CONTROL_NET", ),
                             "type": (type_choices,)
                             }}

    CATEGORY = "conditioning/controlnet"
    RETURN_TYPES = ("CONTROL_NET",)

    FUNCTION = "set_controlnet_type"

    def set_controlnet_type(self, control_net, type):
        """Return a copy of ``control_net`` with its ``control_type`` extra arg set.

        A name found in ``SHAKKERLABS_UNION_CONTROLNET_TYPES`` yields a
        one-element list with its id; any other value (such as ``"auto"``)
        yields an empty list.
        """
        patched = control_net.copy()
        type_number = SHAKKERLABS_UNION_CONTROLNET_TYPES.get(type, -1)
        control_type = [type_number] if type_number >= 0 else []
        patched.set_extra_arg("control_type", control_type)
        return (patched,)

class ModelSaveKJ:
    """Output node that saves a model's state dict with a configurable key prefix."""

    def __init__(self):
        self.output_dir = folder_paths.get_output_directory()

    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "model": ("MODEL",),
                              "filename_prefix": ("STRING", {"default": "diffusion_models/ComfyUI"}),
                              "model_key_prefix": ("STRING", {"default": "model.diffusion_model."}),
                              },
                "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
    RETURN_TYPES = ()
    FUNCTION = "save"
    OUTPUT_NODE = True

    CATEGORY = "advanced/model_merging"

    def save(self, model, filename_prefix, model_key_prefix, prompt=None, extra_pnginfo=None):
        """Write the model's state dict to ``<prefix>_<counter>_.safetensors``.

        Keys starting with ``model.diffusion_model.`` have that prefix replaced
        by ``model_key_prefix``; all other keys are kept as they are. Tensors
        are made contiguous before saving.

        Args:
            model: Model whose ``state_dict_for_saving`` output is written.
            filename_prefix: Prefix resolved through ``folder_paths.get_save_image_path``.
            model_key_prefix: Replacement for the default key prefix.
            prompt: Accepted but not used here.
            extra_pnginfo: Accepted but not used here.

        Returns:
            An empty dict.
        """
        full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)

        # Join the folder exactly once; a second join only works when the folder is absolute.
        output_checkpoint = os.path.join(full_output_folder, f"{filename}_{counter:05}_.safetensors")

        model_management.load_models_gpu([model])
        default_prefix = "model.diffusion_model."

        sd = model.state_dict_for_saving(None, None, None)

        new_sd = {}
        for k, t in sd.items():
            if k.startswith(default_prefix):
                new_key = model_key_prefix + k[len(default_prefix):]
            else:
                new_key = k  # In case the key doesn't start with the default prefix, keep it unchanged
            if not t.is_contiguous():
                t = t.contiguous()
            new_sd[new_key] = t
        logging.info(f"full_output_folder: {full_output_folder}")
        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(full_output_folder, exist_ok=True)
        save_torch_file(new_sd, output_checkpoint)
        return {}
       
class StyleModelApplyAdvanced:
    """Node that appends a strength-scaled style-model embedding to a conditioning."""

    @classmethod
    def INPUT_TYPES(s):
        required_inputs = {
            "conditioning": ("CONDITIONING", ),
            "style_model": ("STYLE_MODEL", ),
            "clip_vision_output": ("CLIP_VISION_OUTPUT", ),
            "strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.001}),
        }
        return {"required": required_inputs}

    RETURN_TYPES = ("CONDITIONING",)
    FUNCTION = "apply_stylemodel"
    CATEGORY = "KJNodes/experimental"
    DESCRIPTION = "StyleModelApply but with strength parameter"

    def apply_stylemodel(self, clip_vision_output, style_model, conditioning, strength=1.0):
        """Concatenate the scaled style embedding onto every conditioning entry.

        The output of ``style_model.get_cond`` has its first two dimensions
        merged, gets a new leading dimension, and is multiplied by
        ``strength``. It is then concatenated along dim 1 to the tensor of each
        entry; each entry's options dict is shallow-copied.

        Returns:
            One-element tuple with the new conditioning list.
        """
        style_cond = style_model.get_cond(clip_vision_output)
        style_cond = style_cond.flatten(start_dim=0, end_dim=1).unsqueeze(dim=0)
        style_cond = strength * style_cond

        styled = [
            [torch.cat((entry[0], style_cond), dim=1), entry[1].copy()]
            for entry in conditioning
        ]
        return (styled, )

class AudioConcatenate:
    """Node that joins two audio clips end to end."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "audio1": ("AUDIO",),
            "audio2": ("AUDIO",),
            "direction": (
            [   'right',
                'left',
            ],
            {
            "default": 'right'
             }),
        }}

    RETURN_TYPES = ("AUDIO",)
    FUNCTION = "concanate"
    CATEGORY = "KJNodes/audio"
    DESCRIPTION = """
Concatenates the audio1 to audio2 in the specified direction.
"""

    def concanate(self, audio1, audio2, direction):
        """Concatenate the two waveforms along dimension 2.

        Args:
            audio1: Dict with ``waveform`` and ``sample_rate`` entries.
            audio2: Dict with ``waveform`` and ``sample_rate`` entries.
            direction: ``'right'`` puts ``audio2`` after ``audio1``;
                ``'left'`` puts ``audio2`` before ``audio1``.

        Returns:
            One-element tuple with the combined audio dict, which keeps the
            sample rate of ``audio1``.

        Raises:
            ValueError: If the sample rates differ or ``direction`` is unknown.
        """
        sample_rate_1 = audio1["sample_rate"]
        sample_rate_2 = audio2["sample_rate"]
        if sample_rate_1 != sample_rate_2:
            raise ValueError("Sample rates of the two audios do not match")

        waveform_1 = audio1["waveform"]
        waveform_2 = audio2["waveform"]

        # Dimension 2 is the last axis of a 3-D waveform (presumably the sample axis — TODO confirm).
        if direction == 'right':
            concatenated_audio = torch.cat((waveform_1, waveform_2), dim=2)
        elif direction == 'left':
            concatenated_audio = torch.cat((waveform_2, waveform_1), dim=2)
        else:
            # Without this branch an unknown value surfaced as an UnboundLocalError.
            raise ValueError(f"Unknown direction: {direction!r}, expected 'right' or 'left'")
        return ({"waveform": concatenated_audio, "sample_rate": sample_rate_1},)

class LeapfusionHunyuanI2V:
    """Node that overwrites one frame of the model input with a given latent."""

    @classmethod
    def INPUT_TYPES(s):
        required_inputs = {
            "model": ("MODEL",),
            "latent": ("LATENT",),
            "index": ("INT", {"default": 0, "min": -1, "max": 1000, "step": 1,"tooltip": "The index of the latent to be replaced. 0 for first frame and -1 for last"}),
            "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "The start percentage of steps to apply"}),
            "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "The end percentage of steps to apply"}),
            "strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.001}),
        }
        return {"required": required_inputs}

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"

    CATEGORY = "KJNodes/hunyuanvideo"

    def patch(self, model, latent, index, strength, start_percent, end_percent):
        """Clone ``model`` and install a unet wrapper that replaces one input frame.

        On every call the wrapper works out the current step from
        ``sample_sigmas``. While the step's fraction of the schedule lies
        within ``[start_percent, end_percent]``, position ``index`` on
        dimension 2 of the input is overwritten with the first frame of the
        scaled latent; outside that range it is zeroed.

        Returns:
            One-element tuple with the patched model clone.
        """
        # 0.476986 is presumably the latent scaling factor of the target model — TODO confirm.
        scaled_samples = latent["samples"] * 0.476986 * strength

        def unet_wrapper(apply_model, args):
            sigmas = args["c"]["transformer_options"]["sample_sigmas"]
            x, timestep, cond = args["input"], args["timestep"], args["c"]

            exact_hits = (sigmas == timestep).nonzero()
            if len(exact_hits) > 0:
                step_idx = exact_hits.item()
            else:
                # No exact match: take the first interval whose endpoints
                # bracket the timestep, falling back to step 0.
                step_idx = 0
                for i in range(len(sigmas) - 1):
                    if (sigmas[i] - timestep[0]) * (sigmas[i + 1] - timestep[0]) <= 0:
                        step_idx = i
                        break

            progress = step_idx / (len(sigmas) - 1)
            if scaled_samples is not None:
                if start_percent <= progress <= end_percent:
                    x[:, :, [index], :, :] = scaled_samples[:, :, [0], :, :].to(x)
                else:
                    x[:, :, [index], :, :] = torch.zeros(1)
            return apply_model(x, timestep, **cond)

        m = model.clone()
        m.set_model_unet_function_wrapper(unet_wrapper)

        return (m,)

class ImageNoiseAugmentation:
    """Node that adds seeded Gaussian noise to an image batch."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",),
                # A numeric default: ``None`` is not a valid FLOAT value and
                # cannot be multiplied in add_noise.
                "noise_aug_strength": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 100.0, "step": 0.001}),
                "seed": ("INT", {"default": 123,"min": 0, "max": 0xffffffffffffffff, "step": 1}),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "add_noise"
    CATEGORY = "KJNodes/image"
    DESCRIPTION = """
    Add noise to an image.  
    """

    def add_noise(self, image, noise_aug_strength, seed):
        """Return ``image`` plus Gaussian noise scaled by ``noise_aug_strength``.

        Elements equal to -1 receive no noise and therefore stay -1.

        Args:
            image: 4-D tensor; the first dimension is treated as the batch.
            noise_aug_strength: Standard deviation of the added noise.
            seed: Seed applied to the global torch RNG before sampling.

        Returns:
            One-element tuple with the noised tensor.
        """
        torch.manual_seed(seed)
        # One sigma per batch entry, broadcast over the remaining three dimensions.
        sigma = torch.ones((image.shape[0],)).to(image.device, image.dtype) * noise_aug_strength
        image_noise = torch.randn_like(image) * sigma[:, None, None, None]
        # Suppress the noise wherever the input holds exactly -1.
        image_noise = torch.where(image==-1, torch.zeros_like(image), image_noise)
        image_out = image + image_noise
        return image_out,

class VAELoaderKJ:
    """VAE loader node with a selectable load device and weight dtype.

    Besides the files of the "vae" folder, the selection list offers tiny
    autoencoders found in "vae_approx" (image TAESD variants as virtual
    entries, video TAEs by file name) and a "pixel_space" pseudo entry.
    """
    # File-name prefixes of video tiny autoencoders looked up in "vae_approx".
    video_taes = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5"]
    # Virtual entries, each assembled from an encoder/decoder file pair.
    image_taes = ["taesd", "taesdxl", "taesd3", "taef1"]

    @staticmethod
    def vae_list(s):
        """Build the list of selectable VAE names.

        Args:
            s: The class itself; INPUT_TYPES passes it explicitly because this
                is a staticmethod that needs ``video_taes``.

        Returns:
            The "vae" folder listing, followed by matching video TAE files, the
            image TAE names whose encoder and decoder are both present, and
            finally "pixel_space".
        """
        vaes = folder_paths.get_filename_list("vae")
        approx_vaes = folder_paths.get_filename_list("vae_approx")
        sdxl_taesd_enc = False
        sdxl_taesd_dec = False
        sd1_taesd_enc = False
        sd1_taesd_dec = False
        sd3_taesd_enc = False
        sd3_taesd_dec = False
        f1_taesd_enc = False
        f1_taesd_dec = False

        for v in approx_vaes:
            if v.startswith("taesd_decoder."):
                sd1_taesd_dec = True
            elif v.startswith("taesd_encoder."):
                sd1_taesd_enc = True
            elif v.startswith("taesdxl_decoder."):
                sdxl_taesd_dec = True
            elif v.startswith("taesdxl_encoder."):
                sdxl_taesd_enc = True
            elif v.startswith("taesd3_decoder."):
                sd3_taesd_dec = True
            elif v.startswith("taesd3_encoder."):
                sd3_taesd_enc = True
            elif v.startswith("taef1_encoder."):
                f1_taesd_enc = True
            elif v.startswith("taef1_decoder."):
                f1_taesd_dec = True
            else:
                for tae in s.video_taes:
                    if v.startswith(tae):
                        vaes.append(v)

        # An image TAE is only usable when both halves are available.
        if sd1_taesd_dec and sd1_taesd_enc:
            vaes.append("taesd")
        if sdxl_taesd_dec and sdxl_taesd_enc:
            vaes.append("taesdxl")
        if sd3_taesd_dec and sd3_taesd_enc:
            vaes.append("taesd3")
        if f1_taesd_dec and f1_taesd_enc:
            vaes.append("taef1")
        vaes.append("pixel_space")
        return vaes

    @staticmethod
    def load_taesd(name):
        """Load an image TAE encoder/decoder pair into a single state dict.

        Args:
            name: One of the ``image_taes`` names; used as the file-name prefix.

        Returns:
            A state dict with the encoder weights under "taesd_encoder.", the
            decoder weights under "taesd_decoder.", and, for the known names,
            "vae_scale" / "vae_shift" tensors.

        Raises:
            FileNotFoundError: If the encoder or decoder file is not present in
                the "vae_approx" folder.
        """
        sd = {}
        approx_vaes = folder_paths.get_filename_list("vae_approx")

        encoder = next((a for a in approx_vaes if a.startswith("{}_encoder.".format(name))), None)
        decoder = next((a for a in approx_vaes if a.startswith("{}_decoder.".format(name))), None)
        if encoder is None or decoder is None:
            # A bare StopIteration from next() would carry no hint about what is missing.
            raise FileNotFoundError(
                "Could not find both '{0}_encoder.*' and '{0}_decoder.*' in the vae_approx folder".format(name)
            )

        enc = load_torch_file(folder_paths.get_full_path_or_raise("vae_approx", encoder))
        for k in enc:
            sd["taesd_encoder.{}".format(k)] = enc[k]

        dec = load_torch_file(folder_paths.get_full_path_or_raise("vae_approx", decoder))
        for k in dec:
            sd["taesd_decoder.{}".format(k)] = dec[k]

        if name == "taesd":
            sd["vae_scale"] = torch.tensor(0.18215)
            sd["vae_shift"] = torch.tensor(0.0)
        elif name == "taesdxl":
            sd["vae_scale"] = torch.tensor(0.13025)
            sd["vae_shift"] = torch.tensor(0.0)
        elif name == "taesd3":
            sd["vae_scale"] = torch.tensor(1.5305)
            sd["vae_shift"] = torch.tensor(0.0609)
        elif name == "taef1":
            sd["vae_scale"] = torch.tensor(0.3611)
            sd["vae_shift"] = torch.tensor(0.1159)
        return sd

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": { "vae_name": (s.vae_list(s), ),
                          "device": (["main_device", "cpu"],),
                          "weight_dtype": (["bf16", "fp16", "fp32" ],),
                         }
            }

    RETURN_TYPES = ("VAE",)
    FUNCTION = "load_vae"
    CATEGORY = "KJNodes/vae"

    def load_vae(self, vae_name, device, weight_dtype):
        """Load the selected VAE.

        Args:
            vae_name: An entry produced by ``vae_list``.
            device: "main_device" or "cpu".
            weight_dtype: "bf16", "fp16" or "fp32".

        Returns:
            A one-element tuple holding the constructed VAE.
        """
        from comfy.sd import VAE
        metadata = None
        dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[weight_dtype]
        if device == "main_device":
            device = model_management.get_torch_device()
        elif device == "cpu":
            device = torch.device("cpu")

        if vae_name == "pixel_space":
            sd = {}
            sd["pixel_space_vae"] = torch.tensor(1.0)
        elif vae_name in self.image_taes:
            sd = self.load_taesd(vae_name)
        else:
            if os.path.splitext(vae_name)[0] in self.video_taes:
                vae_path = folder_paths.get_full_path_or_raise("vae_approx", vae_name)
            elif (
                vae_name.startswith(tuple(self.video_taes))
                and folder_paths.get_full_path("vae", vae_name) is None
            ):
                # vae_list offers video TAEs by prefix, so a file such as
                # "<tae>_something.ext" lives in "vae_approx" as well. Only
                # fall back to it when the "vae" folder has no such file.
                vae_path = folder_paths.get_full_path_or_raise("vae_approx", vae_name)
            else:
                vae_path = folder_paths.get_full_path_or_raise("vae", vae_name)
            sd, metadata = load_torch_file(vae_path, return_metadata=True)

        # Vocoder weights in the state dict mark an audio VAE.
        is_audio_vae = (
            "vocoder.conv_post.weight" in sd
            or "vocoder.vocoder.conv_post.weight" in sd
            or "vocoder.resblocks.0.convs1.0.weight" in sd
            or "vocoder.vocoder.resblocks.0.convs1.0.weight" in sd
        )
        if is_audio_vae:
            # Audio VAEs are built without the device/dtype overrides.
            sd_audio = state_dict_prefix_replace(dict(sd), {"audio_vae.": "autoencoder.", "vocoder.": "vocoder."}, filter_keys=True)
            vae = VAE(sd=sd_audio, metadata=metadata)
        else:
            vae = VAE(sd=sd, device=device, dtype=dtype, metadata=metadata)
        vae.throw_exception_if_invalid()
        return (vae,)

from comfy.samplers import sampling_function, CFGGuider
class Guider_ScheduledCFG(CFGGuider):
    """CFG guider that applies CFG only inside a step-percentage window.

    Outside the window the negative conditioning is dropped and CFG 1.0 is
    used. ``cfg`` may be a single value or a list with one entry per step.
    """

    def set_cfg(self, cfg, start_percent, end_percent):
        """Store the CFG value (or per-step list) and the active window."""
        self.cfg = cfg
        self.start_percent = start_percent
        self.end_percent = end_percent

    # The mutable default is kept for signature compatibility; it is only read here.
    def predict_noise(self, x, timestep, model_options={}, seed=None):
        """Run the sampling function with the CFG scheduled for the current step.

        The current step is found by locating ``timestep`` in
        ``model_options["transformer_options"]["sample_sigmas"]``.

        Raises:
            ValueError: If ``cfg`` is a list whose length is not
                ``len(sample_sigmas) - 1``.
        """
        steps = model_options["transformer_options"]["sample_sigmas"]
        if isinstance(self.cfg, list) and len(self.cfg) != (len(steps) - 1):
            # Explicit check: an assert would be stripped under ``python -O``.
            raise ValueError("cfg list length must match step count")

        # Only the first timestep of the batch is used to locate the step.
        if isinstance(timestep, torch.Tensor):
            timestep_value = timestep.reshape(-1)[0].to(steps)
        else:
            timestep_value = torch.tensor(timestep, device=steps.device, dtype=steps.dtype)

        matched_step_index = torch.isclose(steps, timestep_value).nonzero()
        if len(matched_step_index) > 0:
            # Several sigmas may be "close"; take the earliest instead of
            # failing in .item() on a multi-element tensor.
            current_step_index = matched_step_index[0].item()
        else:
            for i in range(len(steps) - 1):
                # walk from beginning of steps until crossing the timestep
                if (steps[i] - timestep_value) * (steps[i + 1] - timestep_value) <= 0:
                    current_step_index = i
                    break
            else:
                current_step_index = 0
        current_percent = current_step_index / (len(steps) - 1)

        if self.start_percent <= current_percent <= self.end_percent:
            if isinstance(self.cfg, list):
                # The list has len(steps) - 1 entries, so an evaluation at the
                # final sigma reuses the last entry instead of raising IndexError.
                cfg = self.cfg[min(current_step_index, len(self.cfg) - 1)]
            else:
                cfg = self.cfg
            uncond = self.conds.get("negative", None)
        else:
            uncond = None
            cfg = 1.0

        return sampling_function(self.inner_model, x, timestep, uncond, self.conds.get("positive", None), cfg, model_options=model_options, seed=seed)

class ScheduledCFGGuidance:
    """Node that builds a Guider_ScheduledCFG from a model and conditionings."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": { 
                    "model": ("MODEL",),
                    "positive": ("CONDITIONING", ),
                    "negative": ("CONDITIONING", ),
                    "cfg": ("FLOAT", {"default": 6.0, "min": 0.0, "max": 100.0, "step": 0.01}),
                    "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step":0.01}),
                    "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step":0.01}),
                    },
                }
    RETURN_TYPES = ("GUIDER",)
    FUNCTION = "get_guider"
    CATEGORY = "KJNodes/experimental"
    # Spelled DESCRIPTION (it was "DESCRiPTION") to match the attribute name
    # every other node in this file uses.
    DESCRIPTION = """
CFG Guider that allows for scheduled CFG changes over steps, the steps outside the range will use CFG 1.0 thus being processed faster.  
cfg input can be a list of floats matching step count, or a single float for all steps.  
"""

    def get_guider(self, model, cfg, positive, negative, start_percent, end_percent):
        """Create the scheduled guider and return it as a one-element tuple."""
        guider = Guider_ScheduledCFG(model)
        guider.set_conds(positive, negative)
        guider.set_cfg(cfg, start_percent, end_percent)
        return (guider, )


class ApplyRifleXRoPE_WanVideo:
    """Node that swaps the model's ``rope_embedder`` for a RIFLEx variant."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL",),
                "latent": ("LATENT", {"tooltip": "Only used to get the latent count"}),
                "k": ("INT", {"default": 6, "min": 1, "max": 100, "step": 1, "tooltip": "Index of intrinsic frequency"}),
            } 
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"
    CATEGORY = "KJNodes/wan"
    EXPERIMENTAL = True
    # This is the Wan node; the text used to name HunyuanVideo (copied from
    # the sibling node).
    DESCRIPTION = "Extends the potential frame count of Wan video models using this method: https://github.com/thu-ml/RIFLEx"

    def patch(self, model, latent, k):
        """Return a clone of ``model`` with a RIFLEx rope embedder patched in.

        Args:
            model: Model patcher whose ``model.diffusion_model`` exposes
                ``dim`` and ``num_heads``.
            latent: Latent dict; dimension 2 of ``latent["samples"]`` is used
                as the frame count.
            k: Index of the intrinsic frequency handed to the embedder.
        """
        model_class = model.model.diffusion_model

        model_clone = model.clone()
        num_frames = latent["samples"].shape[2]
        # Per-head channel count.
        d = model_class.dim // model_class.num_heads

        rope_embedder = EmbedND_RifleX(
            d,
            10000.0,
            # Per-axis split of the head dim; the first axis takes the remainder.
            [d - 4 * (d // 6), 2 * (d // 6), 2 * (d // 6)],
            num_frames,
            k
            )

        model_clone.add_object_patch("diffusion_model.rope_embedder", rope_embedder)

        return (model_clone, )
    
class ApplyRifleXRoPE_HunuyanVideo:
    """Node that swaps the model's ``pe_embedder`` for a RIFLEx variant."""

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"
    CATEGORY = "KJNodes/hunyuanvideo"
    EXPERIMENTAL = True
    DESCRIPTION = "Extends the potential frame count of HunyuanVideo using this method: https://github.com/thu-ml/RIFLEx"

    @classmethod
    def INPUT_TYPES(s):
        required_inputs = {
            "model": ("MODEL",),
            "latent": ("LATENT", {"tooltip": "Only used to get the latent count"}),
            "k": ("INT", {"default": 4, "min": 1, "max": 100, "step": 1, "tooltip": "Index of intrinsic frequency"}),
        }
        return {"required": required_inputs}

    def patch(self, model, latent, k):
        """Return a clone of ``model`` whose positional embedder uses RIFLEx.

        The embedder is configured from ``diffusion_model.params``; the frame
        count is read from dimension 2 of ``latent["samples"]``.
        """
        diffusion_model = model.model.diffusion_model
        patched_model = model.clone()
        frame_count = latent["samples"].shape[2]

        params = diffusion_model.params
        head_dim = params.hidden_size // params.num_heads
        riflex_embedder = EmbedND_RifleX(head_dim, params.theta, params.axes_dim, frame_count, k)

        patched_model.add_object_patch("diffusion_model.pe_embedder", riflex_embedder)
        return (patched_model,)

def rope_riflex(pos, dim, theta, L_test, k):
    """Build rotary-embedding rotation matrices with the RIFLEx adjustment.

    Args:
        pos: Position tensor; the final ``rearrange`` expects the intermediate
            result to be 4-D, i.e. ``pos`` of shape (b, n) -- TODO confirm
            against callers.
        dim: Number of channels for this axis; must be even.
        theta: Base of the frequency schedule.
        L_test: Length used for the replaced frequency.
        k: 1-based index of the frequency to replace; a falsy value (or a
            falsy ``L_test``) disables the replacement.

    Returns:
        A float32 tensor on ``pos.device`` whose last two dims are the 2x2
        blocks ``[[cos, -sin], [sin, cos]]``.
    """
    from einops import rearrange
    assert dim % 2 == 0

    # The computation runs on CPU for MPS / Intel XPU / DirectML.
    use_cpu = (
        model_management.is_device_mps(pos.device)
        or model_management.is_intel_xpu()
        or model_management.is_directml_enabled()
    )
    compute_device = torch.device("cpu") if use_cpu else pos.device

    exponents = torch.linspace(0, (dim - 2) / dim, steps=dim//2, dtype=torch.float64, device=compute_device)
    frequencies = 1.0 / (theta**exponents)

    # RIFLEx modification: overwrite the k-th frequency (index k - 1).
    if k and L_test:
        frequencies[k-1] = 0.9 * 2 * torch.pi / L_test

    positions = pos.to(dtype=torch.float32, device=compute_device)
    angles = torch.einsum("...n,d->...nd", positions, frequencies)
    cos_part = torch.cos(angles)
    sin_part = torch.sin(angles)
    rotation = torch.stack([cos_part, -sin_part, sin_part, cos_part], dim=-1)
    rotation = rearrange(rotation, "b n d (i j) -> b n d i j", i=2, j=2)
    return rotation.to(dtype=torch.float32, device=pos.device)

class EmbedND_RifleX(nn.Module):
    """N-axis rotary embedder that applies RIFLEx on the first axis only."""

    def __init__(self, dim, theta, axes_dim, num_frames, k):
        super().__init__()
        self.dim, self.theta, self.axes_dim = dim, theta, axes_dim
        self.num_frames, self.k = num_frames, k

    def forward(self, ids):
        """Concatenate the per-axis embeddings of ``ids`` along dim -3.

        The last dimension of ``ids`` indexes the axes; only axis 0 receives
        ``k`` (the others get 0, which disables the RIFLEx replacement).
        """
        per_axis = []
        for axis in range(ids.shape[-1]):
            axis_k = self.k if axis == 0 else 0
            per_axis.append(
                rope_riflex(ids[..., axis], self.axes_dim[axis], self.theta, self.num_frames, axis_k)
            )
        stacked = torch.cat(per_axis, dim=-3)
        return stacked.unsqueeze(1)


class Timer:
    """Plain state holder for a named timer."""

    def __init__(self, name):
        # start_time stays None until something starts the timer;
        # elapsed starts at 0.
        self.name, self.start_time, self.elapsed = name, None, 0

class TimerNodeKJ:
    """Pass-through node that starts or stops a Timer."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "any_input": (IO.ANY, ),
                "mode": (["start", "stop"],),
                "name": ("STRING", {"default": "Timer"}),
            },
            "optional": {
                "timer": ("TIMER",),
            },
        }

    RETURN_TYPES = (IO.ANY, "TIMER", "INT", )
    RETURN_NAMES = ("any_output", "timer", "time")
    FUNCTION = "timer"
    CATEGORY = "KJNodes/misc"

    def timer(self, mode, name, any_input=None, timer=None):
        """Start or stop a timer while forwarding ``any_input`` unchanged.

        Args:
            mode: "start" or "stop".
            name: Name given to a newly created timer.
            any_input: Value forwarded to the first output.
            timer: Existing timer; required for "stop". When given with
                "start" it is restarted instead of the node returning nothing.

        Returns:
            For "start": a dict with the start time as UI text and the result
            tuple ``(any_input, timer, 0)``. For "stop": the tuple
            ``(any_input, timer, elapsed_ms)``.

        Raises:
            ValueError: If "stop" is requested without a running timer, or the
                mode is unknown.
        """
        if mode == "start":
            if timer is None:
                timer = Timer(name=name)
            timer.start_time = time.time()
            return {"ui": {
                "text": [f"{timer.start_time}"]},
                "result": (any_input, timer, 0)
                 }

        if mode == "stop":
            # Previously these cases fell through and returned None (or failed
            # on the None start_time); fail with an explicit message instead.
            if timer is None:
                raise ValueError("TimerNodeKJ: 'stop' mode requires a timer input")
            if timer.start_time is None:
                raise ValueError("TimerNodeKJ: the timer has not been started")
            end_time = time.time()
            timer.elapsed = int((end_time - timer.start_time) * 1000)  # milliseconds
            timer.start_time = None
            return (any_input, timer, timer.elapsed)

        raise ValueError(f"TimerNodeKJ: unknown mode '{mode}'")

class HunyuanVideoEncodeKeyframesToCond:
    """Encode a start/end keyframe pair into a concat-latent conditioning."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
                    "model": ("MODEL",),
                    "positive": ("CONDITIONING", ),
                    "vae": ("VAE", ),
                    "start_frame": ("IMAGE", ),
                    "end_frame": ("IMAGE", ),
                    "num_frames": ("INT", {"default": 33, "min": 2, "max": 4096, "step": 1}),
                    "tile_size": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 64}),
                    "overlap": ("INT", {"default": 64, "min": 0, "max": 4096, "step": 32}),
                    "temporal_size": ("INT", {"default": 64, "min": 8, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to encode at a time."}),
                    "temporal_overlap": ("INT", {"default": 8, "min": 4, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to overlap."}),
                    },
                    "optional": {
                        "negative": ("CONDITIONING", ),
                    }
                }

    RETURN_TYPES = ("MODEL", "CONDITIONING","CONDITIONING","LATENT")
    RETURN_NAMES = ("model", "positive", "negative", "latent")
    FUNCTION = "encode"

    CATEGORY = "KJNodes/hunyuanvideo"

    def encode(self, model, positive, start_frame, end_frame, num_frames, vae, tile_size, overlap, temporal_size, temporal_overlap, negative=None):
        """Build the keyframe conditioning and an empty latent of matching shape.

        The frames are stacked as ``start_frame``, ``num_frames - 2`` zero
        frames and ``end_frame``, tile-encoded with ``vae`` and attached to
        every conditioning entry as "concat_latent_image".

        Returns:
            ``(model_clone, positive, negative, latent)`` where ``negative`` is
            an empty list when no negative conditioning was given and
            ``latent["samples"]`` is all zeros with the encoded shape.
        """
        model_clone = model.clone()

        model_clone.add_object_patch("concat_keys", ("concat_image",))

        # Target size: dims 1 and 2 of the start frame rounded down to a
        # multiple of 8 (presumably height and width of a B,H,W,C image).
        x = (start_frame.shape[1] // 8) * 8
        y = (start_frame.shape[2] // 8) * 8

        if start_frame.shape[1] != x or start_frame.shape[2] != y:
            x_offset = (start_frame.shape[1] % 8) // 2
            y_offset = (start_frame.shape[2] % 8) // 2
            start_frame = start_frame[:,x_offset:x + x_offset, y_offset:y + y_offset,:]
        if end_frame.shape[1] != x or end_frame.shape[2] != y:
            # Offsets must come from the end frame itself: start_frame has
            # already been cropped here, so deriving them from it always gave
            # 0 and cropped the end frame from the top-left corner instead of
            # the centre.
            x_offset = max(0, (end_frame.shape[1] - x) // 2)
            y_offset = max(0, (end_frame.shape[2] - y) // 2)
            end_frame = end_frame[:,x_offset:x + x_offset, y_offset:y + y_offset,:]

        video_frames = torch.zeros(num_frames-2, start_frame.shape[1], start_frame.shape[2], start_frame.shape[3], device=start_frame.device, dtype=start_frame.dtype)
        video_frames = torch.cat([start_frame, video_frames, end_frame], dim=0)

        # Only the first three channels are encoded.
        concat_latent = vae.encode_tiled(video_frames[:,:,:,:3], tile_x=tile_size, tile_y=tile_size, overlap=overlap, tile_t=temporal_size, overlap_t=temporal_overlap)

        out_latent = {}
        out_latent["samples"] = torch.zeros_like(concat_latent)

        out = []
        # Two passes always run, so a missing negative yields an empty
        # conditioning list (the old ``len(out) == 1`` fallback could never trigger).
        for conditioning in [positive, negative if negative is not None else []]:
            c = []
            for t in conditioning:
                d = t[1].copy()
                d["concat_latent_image"] = concat_latent
                n = [t[0], d]
                c.append(n)
            out.append(c)
        return (model_clone, out[0], out[1], out_latent)
    

class LazySwitchKJ:
    """Boolean switch whose two branch inputs are declared lazy."""

    RETURN_TYPES = (IO.ANY,)
    FUNCTION = "switch"
    CATEGORY = "KJNodes/misc"
    DESCRIPTION = "Controls flow of execution based on a boolean switch."

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        lazy_any = (IO.ANY, {"lazy": True})
        return {
            "required": {
                "switch": ("BOOLEAN",),
                "on_false": lazy_any,
                "on_true": lazy_any,
            },
        }

    def check_lazy_status(self, switch, on_false=None, on_true=None):
        """Name the selected branch input if it has not been provided yet.

        Returns None when the selected input is already available.
        """
        if switch:
            if on_true is None:
                return ["on_true"]
        elif on_false is None:
            return ["on_false"]
        return None

    def switch(self, switch, on_false = None, on_true=None):
        """Return the input selected by ``switch`` as a one-element tuple."""
        if switch:
            return (on_true,)
        return (on_false,)


from comfy.patcher_extension import WrappersMP
from comfy.sampler_helpers import prepare_mask
class TTM_OuterSampleWrapper:
    """Outer-sample wrapper that installs the TTM apply-model wrapper."""

    def __init__(self, mask, steps):
        self.mask = mask
        self.steps = steps

    def __call__(self, executor, noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes):
        """Register a TTM_ApplyModel_Wrapper on the guider, then run sampling.

        The wrapper receives the latent image and noise of this sampling call,
        the prepared motion mask (or None when no mask was supplied) and the
        configured step count.
        """
        guider = executor.class_obj
        wrappers = guider.model_options["transformer_options"]["wrappers"]
        w = wrappers.setdefault(WrappersMP.APPLY_MODEL, {})

        # The mask input is optional; without this default the name was
        # unbound below and sampling failed with UnboundLocalError.
        motion_mask = None
        if self.mask is not None:
            motion_mask = self.mask.reshape((-1, 1, self.mask.shape[-2], self.mask.shape[-1]))
            shape = latent_shapes[0]
            motion_mask = prepare_mask(motion_mask, shape, noise.device)

        scale_latent_inpaint = guider.model_patcher.model.scale_latent_inpaint
        w["TTM_ApplyModel_Wrapper"] = [TTM_ApplyModel_Wrapper(latent_image, noise, motion_mask, self.steps, scale_latent_inpaint)]

        out = executor(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes)

        return out


class TTM_ApplyModel_Wrapper:
    """Apply-model wrapper that re-injects a noised reference latent.

    For step indices ``1 .. steps - 1`` the model input is replaced (inside
    the motion mask, or entirely when there is no mask) by the reference
    latent noised to the next sigma.
    """

    def __init__(self, reference_samples, noise, motion_mask, steps, scale_latent_inpaint):
        self.reference_samples = reference_samples
        self.noise = noise
        self.motion_mask = motion_mask
        self.steps = steps
        self.scale_latent_inpaint = scale_latent_inpaint

    def __call__(self, executor, x, t, c_concat, c_crossattn, control, transformer_options, **kwargs):
        """Optionally replace ``x`` and forward the call to ``executor``."""
        sigmas = transformer_options["sample_sigmas"]

        # Locate the step from the first timestep only: comparing the whole
        # (batch,) tensor against sigmas fails to broadcast for batch sizes > 1.
        t_value = t.reshape(-1)[0] if isinstance(t, torch.Tensor) else t

        matched = (sigmas == t_value).nonzero(as_tuple=True)[0]
        if matched.numel() > 0:
            # Take the first hit; .item() on several matches would raise.
            current_step_index = matched[0].item()
        else:
            # No exact match: use the first interval that brackets the timestep.
            crossing = ((sigmas[:-1] - t_value) * (sigmas[1:] - t_value) <= 0).nonzero(as_tuple=True)[0]
            current_step_index = crossing[0].item() if crossing.numel() > 0 else 0

        next_sigma = sigmas[current_step_index + 1] if current_step_index < len(sigmas) - 1 else sigmas[current_step_index]

        # The very first step is left untouched.
        if current_step_index != 0 and current_step_index < self.steps:
            noisy_latent = self.scale_latent_inpaint(x=x, sigma=torch.tensor([next_sigma]), noise=self.noise.to(x), latent_image=self.reference_samples.to(x))
            if self.motion_mask is not None:
                x = x * (1-self.motion_mask).to(x) + noisy_latent * self.motion_mask.to(x)
            else:
                x = noisy_latent

        return executor(x, t, c_concat, c_crossattn, control, transformer_options, **kwargs)


class LatentInpaintTTM:
    """Node that attaches the TTM outer-sample wrapper to a model clone."""

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"
    EXPERIMENTAL = True
    DESCRIPTION = "https://github.com/time-to-move/TTM"
    SEARCH_ALIASES = ["time to move"]
    CATEGORY = "KJNodes/experimental"

    @classmethod
    def INPUT_TYPES(s):
        required_inputs = {
            "model": ("MODEL", ),
            "steps": ("INT", {"default": 7, "min": 0, "max": 888, "step": 1, "tooltip": "Number of steps to apply TTM inpainting for."}),
        }
        optional_inputs = {
            "mask": ("MASK", {"tooltip": "Latent mask where white (1.0) is the area to inpaint and black (0.0) is the area to keep unchanged."}),
        }
        return {"required": required_inputs, "optional": optional_inputs}

    def patch(self, model, steps, mask=None):
        """Clone ``model`` and register a TTM_OuterSampleWrapper on the clone."""
        patched = model.clone()
        outer_wrapper = TTM_OuterSampleWrapper(mask, steps)
        patched.add_wrapper_with_key(WrappersMP.OUTER_SAMPLE, "TTM_OuterSampleWrapper", outer_wrapper)
        return (patched, )


class SimpleCalculatorKJ(io.ComfyNode):
    """Calculator node that evaluates an expression over named inputs.

    The expression is parsed with ``ast`` and evaluated by a small
    whitelist-based interpreter; ``eval`` is never called.
    """

    @classmethod
    def define_schema(cls):
        template = io.Autogrow.TemplateNames(input=io.MultiType.Input("var", [io.Int, io.Float, io.Boolean], optional=True), names=["a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k"], min=2)
        return io.Schema(
            node_id="SimpleCalculatorKJ",
            category="KJNodes/misc",
            description="""
Calculator node that evaluates a mathematical expression using inputs a and b.  
    Supported operations: +, -, *, /, //, %, **, <<, >>, unary +/-  
    Supported comparisons: ==, !=, <, <=, >, >=  
    Supported logic: and, or, not  
    Supported functions: abs(), round(), min(), max(), pow(), sqrt(), sin(), cos(), tan(), log(), log10(), exp(), floor(), ceil()  
    Supported constants: pi, euler, True, False  
""",
            search_aliases=["math", "arithmetic", "expression", "logic"],
            inputs=[
                io.String.Input("expression", default="a + b", multiline=True),
                io.Autogrow.Input("variables", template=template),
            ],
            outputs=[
                io.Float.Output(),
                io.Int.Output(),
                io.Boolean.Output(),
            ],
        )

    @classmethod
    def execute(cls, variables, expression, a=None, b=None) -> io.NodeOutput:
        """Evaluate ``expression`` and return it as float, int and bool.

        Args:
            variables: Mapping of variable name to value from the autogrow input.
            expression: Expression text; only whitelisted syntax is accepted.
            a, b: Legacy inputs that override ``variables`` when not None.

        Returns:
            ``io.NodeOutput(float(result), int(result), bool(result))``; on any
            error the problem is logged and ``(0.0, 0, False)`` is returned.
        """
        import ast
        import operator

        # Allowed operations
        allowed_operators = {
            ast.Add: operator.add, ast.Sub: operator.sub, ast.Mult: operator.mul, ast.Div: operator.truediv,
            ast.FloorDiv: operator.floordiv, ast.Mod: operator.mod, ast.Pow: operator.pow, 
            ast.USub: operator.neg, ast.UAdd: operator.pos, ast.LShift: operator.lshift, 
            ast.RShift: operator.rshift, ast.Eq: operator.eq, ast.NotEq: operator.ne, ast.Lt: operator.lt,
            ast.LtE: operator.le, ast.Gt: operator.gt, ast.GtE: operator.ge, ast.And: operator.and_, 
            ast.Or: operator.or_, ast.Not: operator.not_,
        }

        # Allowed functions
        allowed_functions = {
            'abs': abs, 'round': round, 'min': min, 'max': max,
            'pow': pow, 'sqrt': math.sqrt, 'sin': math.sin,
            'cos': math.cos, 'tan': math.tan, 'log': math.log,
            'log10': math.log10, 'exp': math.exp, 'floor': math.floor,
            'ceil': math.ceil
        }

        # Allowed constants - start with pi, e, True, False
        allowed_names = {'pi': math.pi, 'euler': math.e, 'True': True, 'False': False}

        # Add all variables from autogrow to allowed_names
        for var_name, var_value in variables.items():
            allowed_names[var_name] = var_value

        # Backwards compatibility: add a and b if they're provided (for old workflows)
        if a is not None:
            allowed_names['a'] = a
        if b is not None:
            allowed_names['b'] = b

        def eval_node(node):
            """Recursively evaluate one whitelisted AST node."""
            if isinstance(node, ast.Constant):  # Numbers and booleans
                return node.value
            elif isinstance(node, ast.Name):  # Variables
                if node.id in allowed_names:
                    return allowed_names[node.id]
                raise ValueError(f"Name '{node.id}' is not allowed")
            elif isinstance(node, ast.BinOp):  # Binary operations
                if type(node.op) not in allowed_operators:
                    raise ValueError(f"Operator {type(node.op).__name__} is not allowed")
                left = eval_node(node.left)
                right = eval_node(node.right)
                return allowed_operators[type(node.op)](left, right)
            elif isinstance(node, ast.UnaryOp):  # Unary operations
                if type(node.op) not in allowed_operators:
                    raise ValueError(f"Operator {type(node.op).__name__} is not allowed")
                operand = eval_node(node.operand)
                return allowed_operators[type(node.op)](operand)
            elif isinstance(node, ast.Compare):  # Comparison operations
                # Chained comparisons (a < b < c) stop at the first false link.
                left = eval_node(node.left)
                for op, comparator in zip(node.ops, node.comparators):
                    if type(op) not in allowed_operators:
                        raise ValueError(f"Operator {type(op).__name__} is not allowed")
                    right = eval_node(comparator)
                    result = allowed_operators[type(op)](left, right)
                    if not result:
                        return False
                    left = right
                return True
            elif isinstance(node, ast.BoolOp):  # Boolean operations (and, or)
                if type(node.op) not in allowed_operators:
                    raise ValueError(f"Operator {type(node.op).__name__} is not allowed")
                # Evaluate operands lazily so guards such as
                # "b != 0 and a / b > 1" short-circuit like in Python.
                # The result is still reduced to a bool, as before.
                if isinstance(node.op, ast.And):
                    return all(eval_node(value) for value in node.values)
                elif isinstance(node.op, ast.Or):
                    return any(eval_node(value) for value in node.values)
            elif isinstance(node, ast.Call):  # Function calls
                if not isinstance(node.func, ast.Name):
                    raise ValueError("Only simple function calls are allowed")
                if node.func.id not in allowed_functions:
                    raise ValueError(f"Function '{node.func.id}' is not allowed")
                args = [eval_node(arg) for arg in node.args]
                # Keyword arguments used to be dropped silently; pass them on.
                kwargs = {}
                for keyword in node.keywords:
                    if keyword.arg is None:
                        raise ValueError("Keyword unpacking is not allowed")
                    kwargs[keyword.arg] = eval_node(keyword.value)
                return allowed_functions[node.func.id](*args, **kwargs)
            else:
                raise ValueError(f"Node type {type(node).__name__} is not allowed")

        try:
            tree = ast.parse(expression, mode='eval')
            result = eval_node(tree.body)
            return io.NodeOutput(float(result), int(result), bool(result))
        except Exception as e:
            # Deliberate best effort: report the problem and emit neutral outputs.
            logging.error(f"CalculatorKJ Error: {str(e)}")
            return io.NodeOutput(0.0, 0, False)


class GetTrackRange(io.ComfyNode):
    """Node that extracts a contiguous frame window from a tracks dict."""

    @classmethod
    def define_schema(cls):
        schema_inputs = [
            io.Tracks.Input("tracks"),
            io.Int.Input("start_index", default=24, min=-10000, max=10000, step=1),
            io.Int.Input("num_frames", default=10, min=1, max=10000, step=1),
        ]
        return io.Schema(
            node_id="GetTrackRange",
            category="conditioning/video_models",
            inputs=schema_inputs,
            outputs=[io.Tracks.Output()],
        )

    @classmethod
    def execute(cls, tracks, start_index, num_frames) -> io.NodeOutput:
        """Slice "track_path" and "track_visibility" along their first dim.

        A negative ``start_index`` counts from the end. Both ends of the
        window are clamped to ``[0, total_frames]``.
        """
        path = tracks["track_path"]
        visibility = tracks["track_visibility"]
        frame_total = path.shape[0]

        # Resolve a negative start relative to the end, then clamp.
        first = frame_total + start_index if start_index < 0 else start_index
        first = max(0, min(first, frame_total))
        last = max(0, min(first + num_frames, frame_total))

        window = slice(first, last)
        return io.NodeOutput({
            "track_path": path[window, ...],
            "track_visibility": visibility[window, ...],
        })

class AddNoiseToTrackPath(io.ComfyNode):
    """Node that perturbs "track_path" with seeded Gaussian noise."""

    @classmethod
    def define_schema(cls):
        schema_inputs = [
            io.Tracks.Input("tracks"),
            io.Float.Input("strength", default=1.0, min=0.0, max=100.0, step=0.01),
            io.Int.Input("seed", default=0, min=0, max=0xffffffffffffffff, step=1),
            io.Float.Input("noise_x_ratio", default=1.0, min=0.0, max=100.0, step=0.01,
                         tooltip="Multiplier for horizontal noise component"),
            io.Float.Input("noise_y_ratio", default=1.0, min=0.0, max=100.0, step=0.01,
                         tooltip="Multiplier for vertical noise component"),
            io.Float.Input("noise_temporal_ratio", default=1.0, min=0.0, max=100.0, step=0.01,
                         tooltip="Multiplier for temporal (frame-to-frame) noise"),
        ]
        return io.Schema(
            node_id="AddNoiseToTrackPath",
            category="conditioning/video_models",
            inputs=schema_inputs,
            outputs=[io.Tracks.Output()],
        )

    @classmethod
    def execute(cls, tracks, strength, seed, noise_x_ratio, noise_y_ratio, noise_temporal_ratio) -> io.NodeOutput:
        """Return tracks whose path has noise added; visibility is passed through.

        The global torch RNG is re-seeded with ``seed`` as a side effect.
        """
        base_path = tracks["track_path"].clone()
        visibility = tracks["track_visibility"]

        torch.manual_seed(seed)
        noise = torch.randn_like(base_path) * strength

        # Scale the first two entries of the last dim independently
        # (index 0 is treated as X, index 1 as Y).
        noise[..., 0].mul_(noise_x_ratio)
        noise[..., 1].mul_(noise_y_ratio)

        if noise_temporal_ratio < 1.0:
            frame_count = base_path.shape[0]
            window = max(1, int((1.0 - noise_temporal_ratio) * 10))
            half_window = window // 2

            # Moving average of the noise over neighbouring frames.
            averaged = noise.clone()
            for frame in range(frame_count):
                lo = max(0, frame - half_window)
                hi = min(frame_count, frame + half_window + 1)
                averaged[frame] = noise[lo:hi].mean(dim=0)

            # NOTE(review): the averaged noise is weighted by the ratio itself,
            # so its share shrinks as the ratio drops -- confirm this is intended.
            noise = averaged * noise_temporal_ratio + noise * (1 - noise_temporal_ratio)

        noisy_tracks = {
            "track_path": base_path + noise,
            "track_visibility": visibility,
        }
        return io.NodeOutput(noisy_tracks)


class VAEDecodeLoopKJ:
    """VAE decode node that re-decodes the loop seam of a video latent."""

    RETURN_TYPES = ("IMAGE",)
    OUTPUT_TOOLTIPS = ("The decoded images.",)
    FUNCTION = "decode"
    CATEGORY = "KJNodes/vae"
    DESCRIPTION = "Video latent VAE decoding to fix artifacts on loop seams."

    @classmethod
    def INPUT_TYPES(s):
        required_inputs = {
            "samples": ("LATENT", {"tooltip": "The latent to be decoded."}),
            "vae": ("VAE", {"tooltip": "The VAE model used for decoding the latent."}),
            "overlap_latent_frames": ("INT", {"default": 2, "min": 2, "max": 8, "step": 1, "tooltip": "Number of frames to blend for seamless loop, for Wan 2 works and HunyuanVideo 1.5 should use 4"}),
        }
        return {"required": required_inputs}

    def decode(self, vae, samples, overlap_latent_frames):
        """Decode the latent and replace the opening frames with seam frames.

        A second decode runs on the last ``overlap_latent_frames + 1`` latent
        frames followed by the first ``overlap_latent_frames``; the tail of
        that result is placed in front of the trimmed main decode.
        """
        def flatten_batch(frames):
            # Collapse a 5-D result to 4-D by merging the leading dims.
            if len(frames.shape) == 5:
                return frames.reshape(-1, frames.shape[-3], frames.shape[-2], frames.shape[-1])
            return frames

        latents = samples["samples"]
        decoded = vae.decode(latents)
        if overlap_latent_frames <= 0:
            return (flatten_batch(decoded), )

        tail_count = overlap_latent_frames + 1
        head_count = overlap_latent_frames

        # Wrap-around clip: end of the video followed by its beginning.
        seam_latents = torch.cat([latents[:, :, -tail_count:], latents[:, :, :head_count]], dim=2)
        seam_images = vae.decode(seam_latents).cpu().float()

        seam_latent_total = tail_count + head_count
        seam_skip = seam_latent_total * 2 - 1
        extra_skip = overlap_latent_frames if overlap_latent_frames > 2 else 0
        main_skip = seam_latent_total + extra_skip

        stitched = torch.cat([seam_images[:, seam_skip:].to(decoded), decoded[:, main_skip:]], dim=1)
        return (flatten_batch(stitched), )

class WanImageToVideoSVIPro(io.ComfyNode):
    """Build image-to-video conditioning from an anchor latent and optional motion latents.

    The anchor latent (optionally followed by the last ``motion_latent_count``
    latent frames of ``prev_samples``) is padded along the frame axis up to
    the requested length and attached to both conditionings as
    ``concat_latent_image``, together with a ``concat_mask`` whose first
    frame is 0 and all other frames are 1.
    """

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="WanImageToVideoSVIPro",
            category="conditioning/video_models",
            inputs=[
                io.Conditioning.Input("positive"),
                io.Conditioning.Input("negative"),
                io.Int.Input("length", default=81, min=1, max=MAX_RESOLUTION, step=4),
                io.Latent.Input("anchor_samples"),
                io.Latent.Input("prev_samples", optional=True),
                io.Int.Input("motion_latent_count", default=1, min=0, max=128, step=1),
            ],
            outputs=[
                io.Conditioning.Output(display_name="positive"),
                io.Conditioning.Output(display_name="negative"),
                io.Latent.Output(display_name="latent"),
            ],
        )

    @classmethod
    def execute(cls, positive, negative, length, motion_latent_count, anchor_samples, prev_samples=None) -> io.NodeOutput:
        """Attach the padded conditioning latent and mask, and return an empty latent.

        Args:
            positive, negative: Conditionings to receive ``concat_latent_image``
                and ``concat_mask``.
            length: Requested length; converted to ``(length - 1) // 4 + 1``
                latent frames.
            motion_latent_count: How many trailing latent frames of
                ``prev_samples`` to append after the anchor (0 disables).
            anchor_samples: Latent dict whose ``"samples"`` tensor is 5-D
                (unpacked as B, C, T, H, W).
            prev_samples: Optional latent dict providing the motion frames.

        Returns:
            ``io.NodeOutput(positive, negative, {"samples": zeros})``.

        Raises:
            ValueError: If the anchor plus motion frames exceed the number of
                latent frames implied by ``length``.
        """
        anchor_latent = anchor_samples["samples"].clone()

        B, C, _, H, W = anchor_latent.shape
        total_latents = (length - 1) // 4 + 1
        empty_latent = torch.zeros([B, 16, total_latents, H, W], device=model_management.intermediate_device())

        device = anchor_latent.device
        dtype = anchor_latent.dtype

        if prev_samples is None or motion_latent_count == 0:
            image_cond_latent = anchor_latent
        else:
            motion_latent = prev_samples["samples"][:, :, -motion_latent_count:].clone()
            image_cond_latent = torch.cat([anchor_latent, motion_latent], dim=2)

        padding_size = total_latents - image_cond_latent.shape[2]
        if padding_size < 0:
            raise ValueError(
                f"WanImageToVideoSVIPro: anchor and motion latents span {image_cond_latent.shape[2]} "
                f"latent frames but length={length} only allows {total_latents}"
            )

        # The padding must share the anchor's batch size: torch.cat requires
        # all non-concatenated dims to match, so a batch of 1 breaks for B > 1.
        padding = torch.zeros(B, C, padding_size, H, W, dtype=dtype, device=device)
        padding = comfy.latent_formats.Wan21().process_out(padding)
        image_cond_latent = torch.cat([image_cond_latent, padding], dim=2)

        # Mask is 0 on the first latent frame and 1 everywhere else.
        mask = torch.ones((1, 1, empty_latent.shape[2], H, W), device=device, dtype=dtype)
        mask[:, :, :1] = 0.0

        cond_values = {"concat_latent_image": image_cond_latent, "concat_mask": mask}
        positive = node_helpers.conditioning_set_values(positive, cond_values)
        negative = node_helpers.conditioning_set_values(negative, cond_values)

        out_latent = {"samples": empty_latent}
        return io.NodeOutput(positive, negative, out_latent)

class DeprecatedCompileNodeKJ:
    """Placeholder for a retired node: forwards its input unchanged."""

    RETURN_TYPES = (IO.ANY,)
    FUNCTION = "passthrough"
    CATEGORY = "KJNodes/deprecated"
    DESCRIPTION = "This node has been replaced with TorchCompileModelAdvanced node, please use that instead."

    @classmethod
    def INPUT_TYPES(cls):
        required_inputs = {"model": (IO.ANY,)}
        return {"required": required_inputs}

    def passthrough(self, model):
        """Return ``model`` as a 1-tuple without touching it."""
        return (model,)


class VisualizeSigmasKJ(io.ComfyNode):
    """Plot a sigma schedule as an image and return a slice of it.

    ``start_step`` / ``end_step`` select the slice; the same indices are
    drawn on the plot as vertical markers and a shaded span.
    """

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="VisualizeSigmasKJ",
            category="KJNodes/misc",
            inputs=[
                io.Sigmas.Input("sigmas"),
                io.Int.Input("start_step", default=0, min=-1, max=1000, step=1,
                             tooltip="Step index to mark as the start of a range (inclusive). Set to -1 to disable."),
                io.Int.Input("end_step", default=-1, min=-1, max=1000, step=1,
                             tooltip="Step index to mark as the end of a range (inclusive). Set to - 1 to disable."),
            ],
            outputs=[
                io.Sigmas.Output(display_name="sigmas_out"),
                io.Image.Output(display_name="image"),
            ],
        )

    @classmethod
    def execute(cls, sigmas, start_step=0, end_step=-1) -> io.NodeOutput:
        """Render the sigma curve and slice ``sigmas`` to the selected range.

        Args:
            sigmas: 1-D tensor of sigma values.
            start_step: Int index of the first sigma to keep (values <= 0 keep
                from the beginning). A float is treated as a sigma threshold:
                the first index whose sigma is <= the value.
            end_step: Int index of the last sigma to keep (-1 keeps to the
                end). A float is treated as a sigma threshold: the last index
                whose sigma is >= the value.

        Returns:
            ``io.NodeOutput(sigmas_out, image)`` where ``image`` is a
            (1, H, W, 3) float tensor in [0, 1].
        """
        start_idx = 0
        end_idx = len(sigmas) - 1

        if isinstance(start_step, float):
            idxs = (sigmas <= start_step).nonzero(as_tuple=True)[0]
            if len(idxs) > 0:
                start_idx = idxs[0].item()
        elif isinstance(start_step, int):
            if start_step > 0:
                start_idx = start_step

        if isinstance(end_step, float):
            idxs = (sigmas >= end_step).nonzero(as_tuple=True)[0]
            if len(idxs) > 0:
                end_idx = idxs[-1].item()
        elif isinstance(end_step, int):
            if end_step != -1:
                end_idx = end_step - 1

        import matplotlib
        matplotlib.use('Agg')
        import matplotlib.pyplot as plt

        sigmas_np = sigmas.cpu().numpy()
        # Make sure the curve ends at zero; also covers an empty schedule,
        # which would otherwise fail on the [-1] lookup.
        if sigmas_np.size == 0 or not np.isclose(sigmas_np[-1], 0.0, atol=1e-6):
            sigmas_np = np.append(sigmas_np, 0.0)

        num_points = len(sigmas_np)
        x_values = range(0, num_points)
        # Annotate every point only for short schedules; otherwise just the split steps.
        show_annotation = num_points <= 10

        fig = plt.figure(facecolor='#353535')
        try:
            ax = fig.add_subplot(111)
            ax.set_facecolor('#353535')  # Set axes background color
            ax.plot(x_values, sigmas_np)
            ax.scatter(x_values, sigmas_np, color='white', s=20, zorder=3)  # Small dots at each sigma
            for x, y in zip(x_values, sigmas_np):
                is_split_step = (start_idx > 0 and x == start_idx) or (end_idx != -1 and x == end_idx + 1)

                if show_annotation or is_split_step:
                    color = 'yellow' if is_split_step else 'orange'
                    ax.annotate(f"{y:.3f}", (x, y), textcoords="offset points", xytext=(10, 1), ha='center', color=color, fontsize=12)
            ax.set_xticks(x_values)
            ax.set_title("Sigmas", color='white')           # Title font color
            ax.set_xlabel("Step", color='white')            # X label font color
            ax.set_ylabel("Sigma Value", color='white')     # Y label font color
            ax.tick_params(axis='x', colors='white', labelsize=10)        # X tick color
            ax.tick_params(axis='y', colors='white', labelsize=10)        # Y tick color

            # From here on end_idx is the index of the end marker itself.
            end_idx += 1
            if end_idx != -1 and 0 <= end_idx < num_points - 1:
                ax.axvline(end_idx, color='red', linestyle='--', linewidth=2, label='end_step split')
            if start_idx > 0 and 0 <= start_idx < num_points:
                ax.axvline(start_idx, color='green', linestyle='--', linewidth=2, label='start_step split')
            if (end_idx != -1 and 0 <= end_idx < num_points) or (start_idx > 0 and 0 <= start_idx < num_points):
                handles, labels = ax.get_legend_handles_labels()
                if labels:
                    ax.legend()

            # Shaded range. It is added after the legend is built, so its
            # label is not listed there.
            range_start_idx = start_idx if start_idx > 0 else 0
            range_end_idx = end_idx if end_idx > 0 and end_idx < num_points else num_points - 1
            if range_start_idx < range_end_idx:
                ax.axvspan(range_start_idx, range_end_idx, color='lightblue', alpha=0.1, label='Sampled Range')

            fig.tight_layout()
            fig.canvas.draw()
            w, h = fig.canvas.get_width_height()
            try:
                buf = np.frombuffer(fig.canvas.tostring_argb(), dtype=np.uint8)
                buf = buf.reshape(h, w, 4)
                buf = buf[:, :, [1, 2, 3]]  # Convert ARGB to RGB
            except AttributeError:
                # Fallback for canvases without tostring_argb.
                buf = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
                buf = buf.reshape(h, w, 3).copy()
            image = torch.from_numpy(buf).float() / 255.0
            image = image.unsqueeze(0)  # (H, W, C) -> (1, H, W, C)
        finally:
            # Always release the pyplot figure, even if rendering failed.
            plt.close(fig)

        sigmas_out = sigmas[start_idx:end_idx + 1] if end_idx != -1 else sigmas[start_idx:]

        return io.NodeOutput(sigmas_out, image)

class PreviewLatentNoiseMask(io.ComfyNode):
    """Expose the ``noise_mask`` stored in a latent dict as a mask output."""

    @classmethod
    def define_schema(cls):
        schema_inputs = [io.Latent.Input("latent",)]
        schema_outputs = [io.Mask.Output(display_name="mask")]
        return io.Schema(
            node_id="PreviewLatentNoiseMask",
            category="KJNodes/latents",
            description="Previews the latent noise mask",
            inputs=schema_inputs,
            outputs=schema_outputs,
        )

    @classmethod
    def execute(cls, latent) -> io.NodeOutput:
        """Return a copy of the latent's noise mask.

        When the latent has no ``noise_mask`` entry (or it is ``None``) a
        zero tensor of shape (1, 64, 64) is returned instead. A 5-D mask is
        reduced to its ``[0, 0]`` slice; other ranks pass through as-is.
        """
        stored_mask = latent.get("noise_mask")
        if stored_mask is None:
            return io.NodeOutput(torch.zeros((1, 64, 64)))

        mask_copy = stored_mask.clone()
        result = mask_copy[0, 0] if mask_copy.ndim == 5 else mask_copy
        return io.NodeOutput(result)

class PlaySoundKJ(io.ComfyNode):
    """Output node that publishes audio to the UI and forwards ``any_input``.

    The audio comes from the ``audio`` input, else from ``audio_path``, else
    from a built-in synthesized chime.
    """

    @classmethod
    def define_schema(cls):
        schema_inputs = [
            io.AnyType.Input("any_input", optional=True),
            io.Audio.Input("audio", optional=True),
            io.String.Input("audio_path", default="", tooltip="Path to an audio file. Used when audio input is not connected."),
            io.Combo.Input("mode", options=["always", "on_empty_queue", "on_change"], default="always"),
            io.Float.Input("volume", default=0.5, min=0.0, max=1.0, step=0.01),
            io.Float.Input("duration", default=5.0, min=0.0, max=300.0, step=0.1, tooltip="Duration in seconds to play. 0 = play full audio."),
        ]
        schema_outputs = [
            io.AnyType.Output("any_output", display_name="any_output"),
        ]
        return io.Schema(
            node_id="PlaySoundKJ",
            category="KJNodes/audio",
            description="Plays the input audio in the browser. Modes: 'always' plays on every execution, 'on_empty_queue' plays only when the queue finishes, 'on_change' plays only when the audio content changes. Duration limits playback length (0 = full audio).",
            inputs=schema_inputs,
            outputs=schema_outputs,
            is_output_node=True,
        )

    @classmethod
    def fingerprint_inputs(cls, **kwargs):
        """Return a constant for ``on_change`` mode, NaN for every other mode.

        NaN never compares equal to itself, so it yields a fingerprint that
        differs on every call.
        """
        return False if kwargs.get("mode") == "on_change" else float("NaN")

    @staticmethod
    def _generate_chime():
        """Synthesize the default chime as an audio dict (mono, 32 kHz)."""
        sample_rate = 32000
        note_c, note_ab, note_bb = 523.26, 421.30, 466.16  # note frequencies
        eighth = 0.14  # eighth note — adjust to change tempo
        staccato, medium, held = 16, 7, 2.5  # decay rates of the envelope
        # (frequency, duration in seconds, decay rate) per note
        score = [
            (note_c, eighth, staccato),
            (note_c, eighth, staccato),
            (note_c, eighth, staccato),
            (note_c, eighth*3, medium),
            (note_ab, eighth*3, medium),
            (note_bb, eighth*3, medium),
            (note_c, eighth*2, staccato),
            (note_bb, eighth, staccato),
            (note_c, eighth*5, held),
        ]

        # Normalized 15-tap Gaussian-shaped smoothing kernel.
        kernel = torch.exp(-torch.linspace(-2, 2, 15) ** 2)
        kernel = (kernel / kernel.sum()).reshape(1, 1, -1)

        def render_note(freq, dur, decay):
            # tanh-saturated sine, smoothed, then shaped by an exponential decay.
            t = torch.linspace(0, dur, int(sample_rate * dur))
            tone = torch.tanh(3 * torch.sin(2 * math.pi * freq * t))
            tone = torch.nn.functional.conv1d(tone.reshape(1, 1, -1), kernel, padding=7).squeeze()
            return tone * torch.exp(-t * decay)

        wav = torch.cat([render_note(*note) for note in score]) * 0.45
        return {"waveform": wav.unsqueeze(0).unsqueeze(0), "sample_rate": sample_rate}

    @staticmethod
    def _load_audio_file(audio_path):
        """Decode the first audio stream of ``audio_path`` into an audio dict."""
        import av

        with av.open(audio_path) as container:
            stream = container.streams.audio[0]
            sample_rate = stream.codec_context.sample_rate
            chunks = []
            for frame in container.decode(streams=stream.index):
                chunk = torch.from_numpy(frame.to_ndarray())
                # Presumably handles interleaved sample layouts by reshaping
                # to (channels, samples) — verify against PyAV's to_ndarray.
                if chunk.shape[0] != stream.channels:
                    chunk = chunk.view(-1, stream.channels).t()
                chunks.append(chunk)
            wav = torch.cat(chunks, dim=1).float()
        return {"waveform": wav.unsqueeze(0), "sample_rate": sample_rate}

    @classmethod
    def execute(cls, audio=None, audio_path="", mode="always", volume=0.5, duration=5.0, any_input=None) -> io.NodeOutput:
        """Publish the audio as a UI preview and pass ``any_input`` through.

        ``mode``, ``volume`` and ``duration`` are not read here; presumably
        they are consumed by the frontend — verify against the web extension.
        """
        if audio is None:
            audio = cls._load_audio_file(audio_path) if audio_path else cls._generate_chime()

        ui_dict = ui.PreviewAudio(audio, cls=cls).as_dict()
        # Hash of the waveform's sum, stored alongside the preview data.
        waveform_total = audio["waveform"].sum().item()
        ui_dict["audio_hash"] = [hash(waveform_total)]
        return io.NodeOutput(any_input, ui=ui_dict)
