import numpy as np
import time
import torch
import torch.nn.functional as F
import torchvision.transforms as T
import base64
import random
import math
import os
import re
import json
import importlib
import hashlib
import pathlib
import logging
from io import BytesIO

try:
    import cv2
    HAS_CV2 = True
except ImportError:
    logging.warning("OpenCV not installed")
    HAS_CV2 = False

from PIL import ImageGrab, ImageDraw, ImageFont, Image, ImageOps, ImageSequence, ImageStat
from PIL.PngImagePlugin import PngInfo

from nodes import MAX_RESOLUTION, SaveImage
from comfy_extras.nodes_mask import composite
from comfy.cli_args import args
from comfy.utils import ProgressBar, common_upscale, tiled_scale_multidim
from comfy import model_management
from comfy_api.latest import io, InputImpl, Types, ui
from fractions import Fraction
import node_helpers
import folder_paths

from ..utility.utility import string_to_color

try:
    from server import PromptServer, BinaryEventTypes
except ImportError:
    PromptServer = None
    BinaryEventTypes = None
from concurrent.futures import ThreadPoolExecutor

# Absolute path of the directory two levels above this module (the parent of the folder
# that contains this file).
script_directory = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

class ImagePass:
    """Identity node: hands its optional IMAGE input back unchanged (None when unconnected)."""

    @classmethod
    def INPUT_TYPES(s):
        optional_inputs = {"image": ("IMAGE",)}
        return {"required": {}, "optional": optional_inputs}

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "passthrough"
    CATEGORY = "KJNodes/image"
    DESCRIPTION = """
Passes the image through without modifying it.
"""

    def passthrough(self, image=None):
        # ComfyUI expects a tuple with one entry per RETURN_TYPES slot.
        result = (image,)
        return result

class ColorMatch:
    """Color-transfer node backed by the ``color-matcher`` package (flagged ``DEPRECATED``)."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "image_ref": ("IMAGE",),
                "image_target": ("IMAGE",),
                "method": (['mkl', 'hm', 'reinhard', 'mvgd', 'hm-mvgd-hm', 'hm-mkl-hm'], {"default": 'mkl'}),
            },
            "optional": {
                "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
                "multithread": ("BOOLEAN", {"default": True}),
            }
        }

    CATEGORY = "KJNodes/image"

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("image",)
    FUNCTION = "colormatch"
    DEPRECATED = True
    DESCRIPTION = """
color-matcher enables color transfer across images which comes in handy for automatic  
color-grading of photographs, paintings and film sequences as well as light-field  
and stopmotion corrections.  

The methods behind the mappings are based on the approach from Reinhard et al.,  
the Monge-Kantorovich Linearization (MKL) as proposed by Pitie et al. and our analytical solution  
to a Multi-Variate Gaussian Distribution (MVGD) transfer in conjunction with classical histogram   
matching. As shown below our HM-MVGD-HM compound outperforms existing methods.   
https://github.com/hahnec/color-matcher/

"""

    def colormatch(self, image_ref, image_target, method, strength=1.0, multithread=True):
        """Transfer the colors of ``image_ref`` onto ``image_target`` frame by frame.

        Args:
            image_ref: reference IMAGE batch. Target frame i uses reference frame
                ``min(i, ref_batch - 1)``, so a single reference frame serves every target.
            image_target: IMAGE batch to recolor.
            method: color-matcher transfer method name.
            strength: 0 returns ``image_target`` untouched, 1 is the full transfer, other
                values compute ``target + strength * (matched - target)``.
            multithread: process frames in a thread pool when the batch has more than one frame.

        Returns:
            1-tuple holding a float32 batch clamped to [0, 1]. A frame whose transfer fails
            is logged and passed through unchanged.

        Raises:
            ImportError: if the ``color_matcher`` package is not installed.
        """
        # Skip unnecessary processing
        if strength == 0:
            return (image_target,)

        try:
            from color_matcher import ColorMatcher
        except ImportError as e:
            raise ImportError("Can't import color-matcher, did you install requirements.txt? Manual install: pip install color-matcher") from e

        image_ref = image_ref.cpu()
        image_target = image_target.cpu()
        batch_size = image_target.size(0)
        ref_batch_size = image_ref.size(0)

        def process(i):
            cm = ColorMatcher()
            # Index single frames instead of squeeze()-ing the batch: squeeze() also drops any
            # other size-1 dimension, and indexing the reference with a bare i raised IndexError
            # whenever 1 < reference batch < target batch.
            image_target_np_i = image_target[i].float().numpy()
            image_ref_np_i = image_ref[min(i, ref_batch_size - 1)].float().numpy()
            try:
                image_result = cm.transfer(src=image_target_np_i, ref=image_ref_np_i, method=method)
                # Avoid potential blur when only the fully color-matched image is used
                if strength != 1:
                    image_result = image_target_np_i + strength * (image_result - image_target_np_i)
                # .float() gives every frame the same dtype as the fallback frames before stacking.
                return torch.from_numpy(image_result).float()
            except Exception as e:
                logging.warning(f"Thread {i} error: {e}")
                return torch.from_numpy(image_target_np_i).float()  # fallback

        if multithread and batch_size > 1:
            max_threads = min(os.cpu_count() or 1, batch_size)
            with ThreadPoolExecutor(max_workers=max_threads) as executor:
                out = list(executor.map(process, range(batch_size)))
        else:
            out = [process(i) for i in range(batch_size)]

        out = torch.stack(out, dim=0).to(torch.float32)
        out.clamp_(0, 1)
        return (out,)

class ColorMatchV2(io.ComfyNode):
    """Color-transfer node: ``reinhard_lab_gpu`` runs in torch/Kornia, other methods use ``color-matcher``."""

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="ColorMatchV2",
            category="KJNodes/image",
            description="""
color-matcher enables color transfer across images which comes in handy for automatic  
color-grading of photographs, paintings and film sequences as well as light-field  
and stopmotion corrections.  

The methods behind the mappings are based on the approach from Reinhard et al.,  
the Monge-Kantorovich Linearization (MKL) as proposed by Pitie et al. and our analytical solution  
to a Multi-Variate Gaussian Distribution (MVGD) transfer in conjunction with classical histogram   
matching. As shown below our HM-MVGD-HM compound outperforms existing methods.   
https://github.com/hahnec/color-matcher/   

'reinhard_lab_gpu' method uses Kornia for GPU-accelerated color transfer in Lab color space.
""",
            inputs=[
                io.Image.Input("image_target"),
                io.Image.Input("image_ref"),
                io.Combo.Input("method",
                    options=['mkl', 'hm', 'reinhard', 'mvgd', 'hm-mvgd-hm', 'hm-mkl-hm', 'reinhard_lab_gpu'],
                    default='mkl'),
                io.Float.Input("strength", default=1.0, min=0.0, max=10.0, step=0.01),
                io.Boolean.Input("multithread", default=True),
            ],
            outputs=[
                io.Image.Output(display_name="image"),
            ],
        )

    @classmethod
    def execute(cls, image_target, image_ref, method, strength=1.0, multithread=True) -> io.NodeOutput:
        """Transfer the colors of ``image_ref`` onto ``image_target``.

        Target frame i is always paired with reference frame ``min(i, ref_batch - 1)``, in both
        the GPU and the CPU code paths.

        Args:
            image_target: IMAGE batch (B, H, W, C) to recolor.
            image_ref: reference IMAGE batch.
            method: a color-matcher method name, or 'reinhard_lab_gpu' for the Kornia
                mean/std transfer in Lab space.
            strength: 0 returns the target untouched, 1 is the full transfer, other values
                blend/extrapolate linearly between target and result.
            multithread: CPU path only; process frames in a thread pool.

        Returns:
            NodeOutput with a float32 batch clamped to [0, 1].

        Raises:
            ImportError: if ``color_matcher`` is missing for a non-GPU method.
        """
        # Skip unnecessary processing
        if strength == 0:
            return io.NodeOutput(image_target)

        if method == "reinhard_lab_gpu":
            import kornia
            device = model_management.get_torch_device()

            B, H, W, C = image_target.shape
            # Reference frames beyond the target batch are never used.
            image_ref = image_ref[:B]
            ref_B = image_ref.shape[0]

            src_bchw = image_target.to(device).permute(0, 3, 1, 2).contiguous()  # (B, H, W, C) -> (B, C, H, W)
            ref_bchw = image_ref.to(device).permute(0, 3, 1, 2).contiguous()
            # RGB->Lab
            src_lab = kornia.color.rgb_to_lab(src_bchw)
            ref_lab = kornia.color.rgb_to_lab(ref_bchw)

            src_lab_flat = src_lab.reshape(B, C, -1)  # (B, C, HW)
            ref_lab_flat = ref_lab.reshape(ref_B, C, -1)  # (ref_B, C, HW)

            # Per-image, per-channel statistics for the Reinhard mean/std transfer.
            src_std, src_mean = torch.std_mean(src_lab_flat, dim=-1, keepdim=True, unbiased=False)
            ref_std, ref_mean = torch.std_mean(ref_lab_flat, dim=-1, keepdim=True, unbiased=False)
            src_std = src_std.clamp(min=1e-6)  # avoid dividing by zero on flat channels

            if ref_B != B:
                # Map target frame i to reference frame min(i, ref_B - 1). Previously only
                # ref_B == 1 was broadcast; any other mismatch raised a shape error.
                ref_idx = torch.arange(B, device=ref_mean.device).clamp(max=ref_B - 1)
                ref_mean = ref_mean[ref_idx]
                ref_std = ref_std[ref_idx]

            corrected_lab_flat = (src_lab_flat - src_mean) * (ref_std / src_std) + ref_mean
            corrected_lab = corrected_lab_flat.reshape(B, C, H, W)

            # Lab->RGB
            corrected_rgb_01 = kornia.color.lab_to_rgb(corrected_lab)
            out = (1.0 - strength) * src_bchw + strength * corrected_rgb_01
            out = out.permute(0, 2, 3, 1).contiguous()  # (B, C, H, W) -> (B, H, W, C)

            return io.NodeOutput(out.cpu().float().clamp_(0, 1))

        try:
            from color_matcher import ColorMatcher
        except ImportError as e:
            raise ImportError("Can't import color-matcher, did you install requirements.txt? Manual install: pip install color-matcher") from e

        batch_size = image_target.size(0)
        ref_batch_size = image_ref.size(0)

        def process(i):
            cm = ColorMatcher()
            # .float() before .numpy(): numpy cannot represent every torch float dtype (e.g. bf16).
            image_target_np = image_target[i].cpu().float().numpy()
            image_ref_np = image_ref[min(i, ref_batch_size - 1)].cpu().float().numpy()
            try:
                image_result = cm.transfer(src=image_target_np, ref=image_ref_np, method=method)
                if strength != 1:
                    image_result = image_target_np + strength * (image_result - image_target_np)
                # Keep all frames float32 so results and fallbacks stack cleanly.
                return torch.from_numpy(image_result).float()
            except Exception as e:
                logging.error(f"Thread {i} error: {e}")
                return torch.from_numpy(image_target_np)  # fallback

        if multithread and batch_size > 1:
            max_threads = min(os.cpu_count() or 1, batch_size)
            with ThreadPoolExecutor(max_workers=max_threads) as executor:
                out = list(executor.map(process, range(batch_size)))
        else:
            out = [process(i) for i in range(batch_size)]

        out = torch.stack(out, dim=0).to(torch.float32).clamp_(0, 1)

        return io.NodeOutput(out)

class SaveImageWithAlpha:
    """Output node that writes IMAGE frames as RGBA PNGs, taking alpha from an inverted MASK."""

    def __init__(self):
        self.output_dir = folder_paths.get_output_directory()
        self.type = "output"
        self.prefix_append = ""

    @classmethod
    def INPUT_TYPES(s):
        return {"required":
                    {"images": ("IMAGE", ),
                    "mask": ("MASK", ),
                    "filename_prefix": ("STRING", {"default": "ComfyUI"})},
                "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
                }

    RETURN_TYPES = ()
    FUNCTION = "save_images_alpha"
    OUTPUT_NODE = True
    CATEGORY = "KJNodes/image"
    DESCRIPTION = """
Saves an image and mask as .PNG with the mask as the alpha channel. 
"""

    @staticmethod
    def _max_existing_counter(folder, filename):
        """Return the largest N among files in *folder* named ``{filename}_N.<ext>`` or
        ``{filename}_N_.<ext>``, or 0 if there are none.

        *filename* is regex-escaped, so prefixes containing characters such as ``+``, ``(``
        or ``.`` are matched literally instead of matching unrelated files or raising ``re.error``.
        """
        pattern = re.compile(rf"{re.escape(filename)}_(\d+)_?\.[a-zA-Z0-9]+")
        matches = (pattern.fullmatch(name) for name in os.listdir(folder))
        return max((int(m.group(1)) for m in matches if m), default=0)

    def save_images_alpha(self, images, mask, filename_prefix="ComfyUI_image_with_alpha", prompt=None, extra_pnginfo=None):
        """Save each image with ``1 - mask`` as its alpha channel.

        Args:
            images: IMAGE batch; each frame becomes one PNG.
            mask: MASK batch (a single 2-D mask is accepted too). It is resized to each image's
                size with LANCZOS. When there are fewer masks than images, the last mask is
                reused (previously the extra images were silently not saved).
            filename_prefix: prefix resolved through ``folder_paths.get_save_image_path``.
            prompt, extra_pnginfo: embedded as PNG text chunks unless metadata is disabled.

        Returns:
            ComfyUI UI dict listing the saved files.
        """
        filename_prefix += self.prefix_append
        full_output_folder, filename, _, subfolder, filename_prefix = folder_paths.get_save_image_path(
            filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0])

        # The metadata is the same for every frame, so build it once.
        metadata = None
        if not args.disable_metadata:
            metadata = PngInfo()
            if prompt is not None:
                metadata.add_text("prompt", json.dumps(prompt))
            if extra_pnginfo is not None:
                for x in extra_pnginfo:
                    metadata.add_text(x, json.dumps(extra_pnginfo[x]))

        if mask.dim() == 2:
            mask = mask.unsqueeze(0)
        last_mask = mask.shape[0] - 1

        # Scan the folder once; each saved file then simply takes the next number.
        counter = self._max_existing_counter(full_output_folder, filename)
        results = []
        for i, image in enumerate(images):
            alpha = mask[min(i, last_mask)]

            rgb = np.clip(255. * image.cpu().numpy(), 0, 255).astype(np.uint8)
            img = Image.fromarray(rgb)

            # Invert the mask (mask 1.0 -> alpha 0), resize it in float precision to the
            # image size, and only then quantize to 8 bits.
            alpha_np = 255. * (1.0 - alpha.cpu().float()).numpy()
            alpha_resized = np.asarray(Image.fromarray(alpha_np).resize(img.size, Image.LANCZOS))
            # A 2-D uint8 array becomes an 'L' (8-bit grayscale) image.
            img.putalpha(Image.fromarray(np.clip(alpha_resized, 0, 255).astype(np.uint8)))

            counter += 1
            file = f"{filename}_{counter:05}.png"
            img.save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=4)
            results.append({
                "filename": file,
                "subfolder": subfolder,
                "type": self.type
            })

        return { "ui": { "images": results } }


class ImageConcanate(io.ComfyNode):
    """Node that joins two IMAGE/MASK batches side by side or stacked vertically."""

    @classmethod
    def define_schema(cls):
        # image1 drives the output type; image2 can independently be IMAGE or MASK and gets
        # converted to image1's type inside concatenate().
        type_template = io.MatchType.Template("image_or_mask", allowed_types=[io.Image, io.Mask])
        return io.Schema(
            node_id="ImageConcanate",
            category="KJNodes/image",
            description=(
                "Concatenates image2 to image1 in the specified direction.\n"
                "Both inputs accept IMAGE or MASK; the output type follows image1.\n"
                "If image2 is a different type than image1 it's converted (RGB mean for image→mask,\n"
                "channel-replicate for mask→image).\n"
                "When match_image_size is False and dimensions don't match along the shared axis,\n"
                "the smaller image is centered and zero-padded instead of erroring."
            ),
            inputs=[
                io.MatchType.Input("image1", template=type_template),
                io.MultiType.Input("image2", types=[io.Image, io.Mask]),
                io.Combo.Input("direction", options=['right', 'down', 'left', 'up'], default='right'),
                io.Boolean.Input("match_image_size", default=True),
            ],
            outputs=[
                io.MatchType.Output(template=type_template, display_name="output"),
            ],
        )

    @classmethod
    def execute(cls, image1, image2, direction, match_image_size) -> io.NodeOutput:
        """Node entry point: delegates to concatenate() and wraps the tensor in a NodeOutput."""
        return io.NodeOutput(cls.concatenate(image1, image2, direction, match_image_size))

    @staticmethod
    def concatenate(image1, image2, direction, match_image_size, first_image_shape=None):
        """Place image2 next to image1 on a shared zero-initialized canvas.

        Args:
            image1: IMAGE (B, H, W, C) or MASK (B, H, W); its type decides the output type.
            image2: IMAGE or MASK; converted to image1's type when they differ.
            direction: 'right', 'left', 'down' or 'up' -- the side of image1 where image2 goes.
            match_image_size: when True, image2 is resized (bicubic, antialiased, aspect ratio
                kept) so its height ('left'/'right') or width ('up'/'down') equals the target
                shape's; when False image2 keeps its own size.
            first_image_shape: optional shape used instead of ``image1.shape`` as the resize
                target when match_image_size is True.

        Returns:
            Tensor (max(B1, B2), out_H, out_W, max(C1, C2)), or (B, out_H, out_W) when image1
            is a mask, allocated with model_management's intermediate dtype/device. Along the
            non-concatenated axis each image is centered and uncovered pixels stay 0. The shorter
            batch is padded by repeating its last frame.
        """
        # IMAGE is BHWC, MASK is BHW. Output type follows image1; convert image2 to match,
        # then unsqueeze any masks to BHW1 so the rest of the function can stay BHWC-only.
        output_is_mask = image1.dim() == 3
        if output_is_mask and image2.dim() == 4:
            # image -> mask: mean of the first (up to) three channels
            ch = min(3, image2.shape[-1])
            image2 = image2[..., :ch].mean(dim=-1)
        elif not output_is_mask and image2.dim() == 3:
            # mask -> image: replicate the single channel to image1's channel count
            image2 = image2.unsqueeze(-1).expand(-1, -1, -1, image1.shape[-1])
        if output_is_mask:
            image1 = image1.unsqueeze(-1)
            image2 = image2.unsqueeze(-1)

        bs1 = image1.shape[0]
        bs2 = image2.shape[0]
        B = max(bs1, bs2)

        H1, W1 = image1.shape[1], image1.shape[2]
        C1, C2 = image1.shape[-1], image2.shape[-1]
        out_C = max(C1, C2)

        # (H2, W2): the size image2 will occupy on the canvas.
        if match_image_size:
            target_shape = first_image_shape if first_image_shape is not None else image1.shape
            orig_aspect = image2.shape[2] / image2.shape[1]  # width / height of image2
            if direction in ('left', 'right'):
                H2 = target_shape[1]
                W2 = int(H2 * orig_aspect)
            else:
                W2 = target_shape[2]
                H2 = int(W2 / orig_aspect)
        else:
            H2, W2 = image2.shape[1], image2.shape[2]

        # Canvas size: sum along the concatenation axis, max along the other axis.
        if direction in ('right', 'left'):
            out_H, out_W = max(H1, H2), W1 + W2
        else:
            out_H, out_W = H1 + H2, max(W1, W2)

        # Top-left (y, x) offsets of image1 (i1_*) and image2 (i2_*) on the canvas; the
        # "(out - size) // 2" terms center each image along the non-concatenated axis.
        if direction == 'right':
            i1_y, i1_x, i2_y, i2_x = (out_H - H1) // 2, 0, (out_H - H2) // 2, W1
        elif direction == 'left':
            i1_y, i1_x, i2_y, i2_x = (out_H - H1) // 2, W2, (out_H - H2) // 2, 0
        elif direction == 'down':
            i1_y, i1_x, i2_y, i2_x = 0, (out_W - W1) // 2, H1, (out_W - W2) // 2
        else:  # 'up'
            i1_y, i1_x, i2_y, i2_x = H2, (out_W - W1) // 2, 0, (out_W - W2) // 2

        output = torch.zeros(
            (B, out_H, out_W, out_C),
            dtype=model_management.intermediate_dtype(),
            device=model_management.intermediate_device(),
        )

        def write(dst, src, src_C):
            # Copy src into a canvas view in place. If the canvas has more channels than src,
            # the extra channels are filled with 1.0 (presumably so a missing alpha is opaque).
            if dst.shape[-1] == src_C:
                dst.copy_(src)
            else:
                dst[..., :src_C].copy_(src)
                dst[..., src_C:].fill_(1.0)

        # image1: copy its frames, then repeat its last frame for any remaining batch slots.
        slot1 = output[:, i1_y:i1_y + H1, i1_x:i1_x + W1, :]
        if bs1 == B:
            write(slot1, image1, C1)
        else:
            write(slot1[:bs1], image1, C1)
            write(slot1[bs1:], image1[-1:].expand(B - bs1, -1, -1, -1), C1)
        del slot1

        # image2: when resizing, process one frame at a time on the torch device (reusing the
        # last frame once image2's batch runs out); otherwise copy as for image1.
        slot2 = output[:, i2_y:i2_y + H2, i2_x:i2_x + W2, :]
        if match_image_size:
            pbar = ProgressBar(B)
            device = model_management.get_torch_device()
            for i in range(B):
                src_i = min(i, bs2 - 1)
                frame = image2[src_i:src_i + 1].to(device, non_blocking=True).permute(0, 3, 1, 2)
                resized = F.interpolate(frame, size=(H2, W2), mode='bicubic', antialias=True).permute(0, 2, 3, 1)
                write(slot2[i:i + 1], resized, C2)
                del frame, resized
                pbar.update(1)
        else:
            if bs2 == B:
                write(slot2, image2, C2)
            else:
                write(slot2[:bs2], image2, C2)
                write(slot2[bs2:], image2[-1:].expand(B - bs2, -1, -1, -1), C2)
        del slot2

        # Masks were carried as BHW1 internally; return them as BHW.
        if output_is_mask:
            return output.squeeze(-1)
        return output


class ImageConcatFromBatch:
    """Lay the frames of one IMAGE batch out as a single grid image."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "images": ("IMAGE",),
            "num_columns": ("INT", {"default": 3, "min": 1, "max": 255, "step": 1}),
            "match_image_size": ("BOOLEAN", {"default": False}),
            "max_resolution": ("INT", {"default": 4096}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "concat"
    CATEGORY = "KJNodes/image"
    DESCRIPTION = """
    Concatenates images from a batch into a grid with a specified number of columns.
    """

    def concat(self, images, num_columns, match_image_size, max_resolution):
        """Arrange ``images`` (B, H, W, C) row-major into a grid with ``num_columns`` columns.

        Each cell is scaled so neither grid side exceeds ``max_resolution`` (never upscaled),
        then rounded to the nearest multiple of 8 (minimum 1 px). Empty trailing cells stay 0.

        Returns:
            1-tuple with a (1, grid_H, grid_W, C) tensor on the CPU, in ``images.dtype``.
        """
        batch_size, height, width, channels = images.shape
        num_rows = (batch_size + num_columns - 1) // num_columns  # ceil(batch_size / num_columns)

        logging.info(f"Initial dimensions: batch_size={batch_size}, height={height}, width={width}, channels={channels}")
        logging.info(f"num_rows={num_rows}, num_columns={num_columns}")

        if match_image_size:
            target_shape = images[0].shape

            resized_images = []
            for image in images:
                original_height = image.shape[0]
                original_width = image.shape[1]
                original_aspect_ratio = original_width / original_height

                if original_aspect_ratio > 1:
                    target_height = target_shape[0]
                    target_width = int(target_height * original_aspect_ratio)
                else:
                    target_width = target_shape[1]
                    target_height = int(target_width / original_aspect_ratio)

                logging.info(f"Resizing image from ({original_height}, {original_width}) to ({target_height}, {target_width})")

                # Give common_upscale a 4-D (1, C, H, W) batch -- the layout ComfyUI's own resize
                # nodes pass it -- instead of a bare (C, H, W) tensor, then restore (H, W, C).
                resized_image = common_upscale(image.movedim(-1, 0).unsqueeze(0), target_width, target_height, "lanczos", "disabled")
                resized_images.append(resized_image.squeeze(0).movedim(0, -1))

            images = torch.stack(resized_images)

            height, width = target_shape[:2]

        grid_height = num_rows * height
        grid_width = num_columns * width

        logging.info(f"Grid dimensions before scaling: grid_height={grid_height}, grid_width={grid_width}")

        # Shrink (never enlarge) so that neither grid side exceeds max_resolution.
        scale_factor = min(max_resolution / grid_height, max_resolution / grid_width, 1.0)

        scaled_height = height * scale_factor
        scaled_width = width * scale_factor

        # Round each cell to the nearest multiple of 8 (at least 1 px). Nearest-multiple rounding
        # is never more than 4 px off, so the former "> 4 px" re-rounding branches never fired.
        height = max(1, int(round(scaled_height / 8) * 8))
        width = max(1, int(round(scaled_width / 8) * 8))

        grid_height = num_rows * height
        grid_width = num_columns * width
        logging.info(f"Grid dimensions after scaling: grid_height={grid_height}, grid_width={grid_width}")
        logging.info(f"Final image dimensions: height={height}, width={width}")

        grid = torch.zeros((grid_height, grid_width, channels), dtype=images.dtype)

        for idx, image in enumerate(images):
            # squeeze(0) drops only the batch dim; a bare squeeze() also removed a size-1
            # channel dim, which made permute() fail for single-channel images.
            resized_image = torch.nn.functional.interpolate(
                image.unsqueeze(0).permute(0, 3, 1, 2), size=(height, width), mode="bilinear"
            ).squeeze(0).permute(1, 2, 0)
            row, col = divmod(idx, num_columns)
            grid[row*height:(row+1)*height, col*width:(col+1)*width, :] = resized_image

        return grid.unsqueeze(0),

    
class ImageGridComposite2x2:
    """Tile four same-sized IMAGE batches as [[image1, image2], [image3, image4]]."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {f"image{n}": ("IMAGE",) for n in range(1, 5)}}

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "compositegrid"
    CATEGORY = "KJNodes/image"
    DESCRIPTION = """
Concatenates the 4 input images into a 2x2 grid. 
"""

    def compositegrid(self, image1, image2, image3, image4):
        # For BHWC tensors dim=2 is width (side by side) and dim=1 is height (stacked).
        upper = torch.cat([image1, image2], dim=2)
        lower = torch.cat([image3, image4], dim=2)
        return (torch.cat([upper, lower], dim=1),)
    
class ImageGridComposite3x3:
    """Tile nine same-sized IMAGE batches into a 3x3 grid, filled row by row."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {f"image{n}": ("IMAGE",) for n in range(1, 10)}}

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "compositegrid"
    CATEGORY = "KJNodes/image"
    DESCRIPTION = """
Concatenates the 9 input images into a 3x3 grid. 
"""

    def compositegrid(self, image1, image2, image3, image4, image5, image6, image7, image8, image9):
        tiles = (image1, image2, image3, image4, image5, image6, image7, image8, image9)
        # Each slice of three tiles becomes one row (joined along width, dim=2); the rows are
        # then stacked along height (dim=1).
        rows = [torch.cat(tiles[start:start + 3], dim=2) for start in range(0, 9, 3)]
        return (torch.cat(rows, dim=1),)


class ImageBatchTestPattern(io.ComfyNode):
    """Numbered test-pattern generator: frame i shows the number ``start_from + i``."""

    @classmethod
    def define_schema(cls):
        schema_inputs = [
            io.Int.Input("batch_size", default=1, min=1, max=4096, step=1),
            io.Int.Input("start_from", default=0, min=0, max=4096, step=1),
            io.Int.Input("text_x", default=256, min=0, max=4096, step=1),
            io.Int.Input("text_y", default=256, min=0, max=4096, step=1),
            io.Int.Input("width", default=512, min=16, max=4096, step=1),
            io.Int.Input("height", default=512, min=16, max=4096, step=1),
            io.Combo.Input("font", options=folder_paths.get_filename_list("kjnodes_fonts")),
            io.Int.Input("font_size", default=255, min=8, max=4096, step=1),
        ]
        schema_outputs = [io.Image.Output(display_name="image")]
        return io.Schema(
            node_id="ImageBatchTestPattern",
            category="KJNodes/text",
            description="Generate a batch of images with sequential numbers rendered in a chosen font.",
            inputs=schema_inputs,
            outputs=schema_outputs,
        )

    @classmethod
    def execute(cls, batch_size, font, font_size, start_from, width, height, text_x, text_y) -> io.NodeOutput:
        """Render ``batch_size`` frames, each with its number in a random color on black.

        Returns a NodeOutput holding a (batch_size, height, width, 3) tensor scaled to [0, 1],
        on model_management's intermediate device/dtype.
        """
        pil_font = ImageFont.truetype(folder_paths.get_full_path("kjnodes_fonts", font), font_size)

        # The '-liga' feature (no ligatures) may not be supported by this PIL build/font:
        # probe it once on a 1x1 scratch image and fall back to plain text rendering.
        draw_options = {"features": ['-liga']}
        try:
            ImageDraw.Draw(Image.new("RGB", (1, 1))).text(
                (0, 0), "0", font=pil_font, fill=(0, 0, 0), **draw_options
            )
        except Exception:
            draw_options = {}

        canvas = Image.new("RGB", (width, height), color='black')
        pen = ImageDraw.Draw(canvas)
        frames = np.empty((batch_size, height, width, 3), dtype=np.uint8)
        progress = ProgressBar(batch_size)

        for frame_index in range(batch_size):
            # Repaint one shared canvas rather than allocating a new PIL image per frame.
            pen.rectangle((0, 0, width, height), fill='black')
            color = tuple(random.randint(0, 255) for _ in range(3))
            pen.text((text_x, text_y), str(start_from + frame_index), font=pil_font, fill=color, **draw_options)
            frames[frame_index] = np.asarray(canvas)
            progress.update(1)

        tensor = torch.from_numpy(frames).to(
            device=model_management.intermediate_device(),
            dtype=model_management.intermediate_dtype(),
        )
        return io.NodeOutput(tensor.div_(255.0))

class ImageGrabPIL:
    """Capture a screen region with PIL's ``ImageGrab`` (one or more frames)."""

    @classmethod
    def IS_CHANGED(cls, **kwargs):
        # Screen contents change outside the workflow, so never report "unchanged". This
        # follows Screencap_mss/ScreencapStream: accept the node inputs as keyword arguments
        # and return NaN, which compares unequal even to itself. The old signature accepted
        # no inputs and returned None.
        return float("NaN")

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("image",)
    FUNCTION = "screencap"
    CATEGORY = "KJNodes/image"
    DESCRIPTION = """
Captures an area specified by screen coordinates.  
Can be used for realtime diffusion with autoqueue.
"""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                 "x": ("INT", {"default": 0,"min": 0, "max": 4096, "step": 1}),
                 "y": ("INT", {"default": 0,"min": 0, "max": 4096, "step": 1}),
                 "width": ("INT", {"default": 512,"min": 0, "max": 4096, "step": 1}),
                 "height": ("INT", {"default": 512,"min": 0, "max": 4096, "step": 1}),
                 "num_frames": ("INT", {"default": 1,"min": 1, "max": 255, "step": 1}),
                 "delay": ("FLOAT", {"default": 0.1,"min": 0.0, "max": 10.0, "step": 0.01}),
        },
    }

    def screencap(self, x, y, width, height, num_frames, delay):
        """Grab ``num_frames`` screenshots of the region, ``delay`` seconds apart.

        Returns:
            1-tuple with a float32 (num_frames, H, W, 3) tensor in [0, 1].
        """
        start_time = time.time()
        bbox = (x, y, x + width, y + height)
        captures = []

        for frame_index in range(num_frames):
            # Wait only between frames; the previous loop also slept after the final capture.
            if frame_index > 0:
                time.sleep(delay)
            # convert("RGB") guarantees three channels whatever mode the grab backend returns.
            screen_capture = ImageGrab.grab(bbox=bbox).convert("RGB")
            captures.append(torch.from_numpy(np.asarray(screen_capture, dtype=np.float32) / 255.0))

        elapsed_time = time.time() - start_time
        logging.info(f"screengrab took {elapsed_time} seconds.")

        return (torch.stack(captures, dim=0),)
    
class Screencap_mss:
    """Capture a screen region with the ``mss`` package (one or more frames)."""

    @classmethod
    def IS_CHANGED(s, **kwargs):
        # NaN compares unequal even to itself, so the result is never treated as unchanged.
        return float("NaN")

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("image",)
    FUNCTION = "screencap"
    CATEGORY = "KJNodes/image"
    DESCRIPTION = """
Captures an area specified by screen coordinates.  
Can be used for realtime diffusion with autoqueue.
"""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                 "x": ("INT", {"default": 0,"min": 0, "max": 10000, "step": 1}),
                 "y": ("INT", {"default": 0,"min": 0, "max": 10000, "step": 1}),
                 "width": ("INT", {"default": 512,"min": 0, "max": 10000, "step": 1}),
                 "height": ("INT", {"default": 512,"min": 0, "max": 10000, "step": 1}),
                 "num_frames": ("INT", {"default": 1,"min": 1, "max": 255, "step": 1}),
                 "delay": ("FLOAT", {"default": 0.1,"min": 0.0, "max": 10.0, "step": 0.01}),
        },
    }

    def screencap(self, x, y, width, height, num_frames, delay):
        """Grab ``num_frames`` frames of the region, ``delay`` seconds apart.

        Returns:
            1-tuple with a float32 (num_frames, H, W, 3) tensor in [0, 1].
        """
        from mss import mss
        bbox = {'top': y, 'left': x, 'width': width, 'height': height}
        captures = []
        with mss() as sct:
            for frame_index in range(num_frames):
                # Wait only between frames; the previous loop also slept after the final capture.
                if frame_index > 0:
                    time.sleep(delay)
                img_np = np.array(sct.grab(bbox))
                # Channels [2, 1, 0] -> RGB (assumes mss's BGRA pixel layout; alpha is dropped).
                captures.append(torch.from_numpy(img_np[..., [2, 1, 0]]).float() / 255.0)

        return (torch.stack(captures, 0),)

class ScreencapStream:
    """Decode a frame pushed by the browser screen/window-share widget into an IMAGE."""

    @classmethod
    def IS_CHANGED(s, **kwargs):
        # NaN compares unequal even to itself, so the result is never treated as unchanged.
        return float("NaN")

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("image",)
    FUNCTION = "capture"
    CATEGORY = "KJNodes/image"
    DESCRIPTION = """
Captures a frame from a browser screen/window share stream.
Click 'Start capture' to select a screen or window to share.
Live preview is shown in the node. Works with auto-queue.

Crop controls:
- Drag on preview to draw a crop box
- Drag inside the box to move it
- Drag edges or corners to resize
- Shift+drag to lock aspect ratio
- Right-click or double-click to clear crop
"""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "frame_data": ("STRING", {"default": "", "multiline": False}),
                "crop_width": ("INT", {"default": 1, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
                "crop_height": ("INT", {"default": 1, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
            },
        }

    MAX_FRAME_BYTES = 50 * 1024 * 1024  # 50MB base64 limit (PNG is larger than JPEG)

    def capture(self, crop_width, crop_height, frame_data):
        """Return the received frame as a (1, H, W, 3) float tensor in [0, 1].

        With no frame yet, a black (1, crop_height, crop_width, 3) placeholder is returned.

        Raises:
            ValueError: if ``frame_data`` is longer than MAX_FRAME_BYTES or fails base64 decoding.
        """
        if not frame_data:
            placeholder_h = crop_height if crop_height > 0 else 512
            placeholder_w = crop_width if crop_width > 0 else 512
            return (torch.zeros(1, placeholder_h, placeholder_w, 3),)

        limit = self.MAX_FRAME_BYTES
        if len(frame_data) > limit:
            raise ValueError(f"Frame data exceeds {limit // (1024*1024)}MB limit")

        # Strip an optional "data:<mime>;base64," header: keep what follows the first comma,
        # or the whole string when there is no comma.
        head, sep, tail = frame_data.partition(",")
        payload = tail if sep else head
        try:
            raw_bytes = base64.b64decode(payload)
        except Exception:
            raise ValueError("Invalid frame data encoding")

        rgb = np.asarray(Image.open(BytesIO(raw_bytes)).convert("RGB"), dtype=np.float32) / 255.0
        return (torch.from_numpy(rgb).unsqueeze(0),)

class WebcamCaptureCV2:

    @classmethod
    def IS_CHANGED(cls, **kwargs):
        # A live camera never yields a reusable result; NaN never equals a
        # previous fingerprint, so the node runs again on every queue.
        return float("NaN")

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("image",)
    FUNCTION = "capture"
    CATEGORY = "KJNodes/experimental"
    DESCRIPTION = """
Captures a frame from a webcam using CV2.  
Can be used for realtime diffusion with autoqueue.
"""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                 "x": ("INT", {"default": 0,"min": 0, "max": 4096, "step": 1}),
                 "y": ("INT", {"default": 0,"min": 0, "max": 4096, "step": 1}),
                 "width": ("INT", {"default": 512,"min": 0, "max": 4096, "step": 1}),
                 "height": ("INT", {"default": 512,"min": 0, "max": 4096, "step": 1}),
                 "cam_index": ("INT", {"default": 0,"min": 0, "max": 255, "step": 1}),
                 "release": ("BOOLEAN", {"default": False}),
            },
        } 

    def capture(self, x, y, cam_index, width, height, release):
        """Grab one frame from a local camera and crop it.

        The capture handle is cached on the instance and reopened only when
        ``cam_index`` changes, or after it was released or failed.

        Args:
            x, y: top-left corner of the crop region, in frame pixels.
            cam_index: OpenCV camera index.
            width, height: requested capture resolution and crop size.
            release: close the camera after grabbing this frame.

        Returns:
            A 1-tuple holding a (1, h, w, 3) float RGB tensor in [0, 1].

        Raises:
            RuntimeError: if OpenCV is missing, the camera cannot be opened,
                or no frame can be read.
            ValueError: if the crop region does not overlap the frame.
        """
        if not HAS_CV2:
            raise RuntimeError("OpenCV (cv2) is not installed; WebcamCaptureCV2 requires it")

        cap = getattr(self, "cap", None)
        if cap is None or getattr(self, "current_cam_index", None) != cam_index:
            if cap is not None:
                cap.release()
            # Clear the cached handle first so a failed open is retried on the next run
            self.cap = None
            cap = cv2.VideoCapture(cam_index)
            try:
                cap.set(cv2.CAP_PROP_FRAME_WIDTH, width)
                cap.set(cv2.CAP_PROP_FRAME_HEIGHT, height)
            except cv2.error:
                pass
            if not cap.isOpened():
                cap.release()
                raise RuntimeError("Could not open webcam")
            self.cap = cap
            self.current_cam_index = cam_index

        ret, frame = self.cap.read()
        if not ret:
            # Drop the dead handle so the device is reopened next time
            self.cap.release()
            self.cap = None
            raise RuntimeError("Failed to capture image from webcam")

        # Crop the frame to the specified bbox
        cropped = frame[y:y + height, x:x + width]

        if release:
            self.cap.release()
            self.cap = None

        if cropped.size == 0:
            raise ValueError(
                f"Crop region x={x}, y={y}, {width}x{height} lies outside the "
                f"{frame.shape[1]}x{frame.shape[0]} frame"
            )

        # OpenCV frames are BGR; reorder to RGB and scale to [0, 1]
        img_torch = torch.from_numpy(cropped[..., [2, 1, 0]]).float() / 255.0
        return (img_torch.unsqueeze(0),)

class AddLabel:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "image":("IMAGE",),
            "text_x": ("INT", {"default": 10, "min": 0, "max": 4096, "step": 1}),
            "text_y": ("INT", {"default": 2, "min": 0, "max": 4096, "step": 1}),
            "height": ("INT", {"default": 48, "min": -1, "max": 4096, "step": 1}),
            "font_size": ("INT", {"default": 32, "min": 0, "max": 4096, "step": 1}),
            "font_color": ("STRING", {"default": "white"}),
            "label_color": ("STRING", {"default": "black"}),
            "font": (folder_paths.get_filename_list("kjnodes_fonts"), ),
            "text": ("STRING", {"default": "Text"}),
            "direction": (
            [   'up',
                'down',
                'left',
                'right',
                'overlay'
            ],
            {
            "default": 'up'
             }),
            },
            "optional":{
                "caption": ("STRING", {"default": "", "forceInput": True}),
            }
            }
    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "addlabel"
    CATEGORY = "KJNodes/text"
    DESCRIPTION = """
Creates a new with the given text, and concatenates it to  
either above or below the input image.  
Note that this changes the input image's height!  
Fonts are loaded from this folder:  
ComfyUI/custom_nodes/ComfyUI-KJNodes/fonts
"""

    def addlabel(self, image, text_x, text_y, text, height, font_size, font_color, label_color, font, direction, caption=""):
        """Render text into a label strip (or onto the frames) and attach it.

        'up'/'down': the label is as wide as the image and stacked vertically.
        'left'/'right': the label is laid out as wide as the image is tall,
        rotated 90 degrees counter-clockwise and placed beside the image.
        'overlay': the text is drawn directly onto each frame.

        Args:
            image: image batch indexed as (batch, height, width, channels).
            text: label text used for every frame when ``caption`` is empty.
            height: label height in pixels, or -1 to fit the wrapped text.
            caption: optional sequence holding one caption per frame.

        Raises:
            ValueError: if ``caption`` is given but its length differs from the batch size.
        """
        batch_size = image.shape[0]
        # Left/right labels are rotated before concatenation, so their text must
        # run along the image height; otherwise only square images would fit.
        label_width = image.shape[1] if direction in ('left', 'right') else image.shape[2]

        font_path = os.path.join(script_directory, "fonts", "TTNorms-Black.otf") if font == "TTNorms-Black.otf" else folder_paths.get_full_path("kjnodes_fonts", font)

        # Parse colors using helper function; PIL only needs the RGB part
        font_color_tuple = tuple(string_to_color(font_color)[:3])
        label_color_tuple = tuple(string_to_color(label_color)[:3])

        # Load the font once instead of once per frame
        pil_font = ImageFont.truetype(font_path, font_size)
        max_line_width = label_width - 2 * text_x

        def measure(line):
            """Pixel width of ``line`` in the loaded font."""
            try:
                return pil_font.getbbox(line)[2]
            except Exception:
                # Fallback for Pillow versions without getbbox()
                return pil_font.getsize(line)[0]

        def wrap(caption_text):
            """Greedy word-wrap of ``caption_text`` to ``max_line_width``."""
            lines = []
            for text_line in caption_text.split('\n'):
                if text_line.strip() == "":
                    # Preserve empty lines for multiple newlines
                    lines.append("")
                    continue
                current_line = []
                for word in text_line.split():
                    # A word always starts an empty line, even if it is too wide,
                    # so no blank line is emitted in front of an over-long word.
                    if not current_line or measure(" ".join(current_line + [word])) <= max_line_width:
                        current_line.append(word)
                    else:
                        lines.append(" ".join(current_line))
                        current_line = [word]
                if current_line:
                    lines.append(" ".join(current_line))
            return lines

        def process_image(input_image, caption_text):
            """Draw one caption and return it as a (1, h, w, C) float tensor."""
            lines = wrap(caption_text)

            if direction == 'overlay':
                pixels = np.clip(input_image.cpu().numpy() * 255, 0, 255).astype(np.uint8)
                pil_image = Image.fromarray(pixels)
            else:
                if height == -1:
                    # Size the label to the wrapped text plus a small margin
                    margin = 8
                    label_height = text_y + len(lines) * font_size + margin
                else:
                    label_height = height
                pil_image = Image.new("RGB", (label_width, label_height), label_color_tuple)

            draw = ImageDraw.Draw(pil_image)
            y_offset = text_y
            for line in lines:
                try:
                    # Disable ligatures where the text layout engine supports it
                    draw.text((text_x, y_offset), line, font=pil_font, fill=font_color_tuple, features=['-liga'])
                except Exception:
                    draw.text((text_x, y_offset), line, font=pil_font, fill=font_color_tuple)
                y_offset += font_size

            return torch.from_numpy(np.array(pil_image).astype(np.float32) / 255.0).unsqueeze(0)

        if caption == "":
            captions = [text] * batch_size
        else:
            if len(caption) != batch_size:
                raise ValueError(f"Number of captions {(len(caption))} does not match number of images")
            captions = caption
        processed_batch = torch.cat(
            [process_image(img, cap) for img, cap in zip(image, captions)], dim=0
        ).to(image.device)

        # Combine images based on direction
        if direction == 'down':
            combined_images = torch.cat((image, processed_batch), dim=1)
        elif direction == 'up':
            combined_images = torch.cat((processed_batch, image), dim=1)
        elif direction in ('left', 'right'):
            # Rotate 90 degrees counter-clockwise: (B, h, H, C) -> (B, H, h, C)
            rotated = torch.rot90(processed_batch, 1, (1, 2))
            if direction == 'left':
                combined_images = torch.cat((rotated, image), dim=2)
            else:
                combined_images = torch.cat((image, rotated), dim=2)
        else:
            combined_images = processed_batch

        return (combined_images,)
    
class GetImageSizeAndCount:
    """Report an image batch's width, height and frame count, passing the batch through."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",),
            }
        }

    RETURN_TYPES = ("IMAGE","INT", "INT", "INT",)
    RETURN_NAMES = ("image", "width", "height", "count",)
    FUNCTION = "getsize"
    CATEGORY = "KJNodes/image"
    DESCRIPTION = """
Returns width, height and batch size of the image,  
and passes it through unchanged.  

"""

    def getsize(self, image):
        # Batches are indexed (count, height, width, channels)
        count, height, width = image.shape[:3]
        summary = f"{count}x{width}x{height}"
        outputs = (image, width, height, count)
        return {"ui": {"text": [summary]}, "result": outputs}

class GetLatentSizeAndCount:
    """Report the dimensions of a 4-D (image) or 5-D (video) latent, passing it through."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "latent": ("LATENT",),
            }
        }

    RETURN_TYPES = ("LATENT","INT", "INT", "INT", "INT", "INT")
    RETURN_NAMES = ("latent", "batch_size", "channels", "frames", "height", "width")
    FUNCTION = "getsize"
    CATEGORY = "KJNodes/image"
    DESCRIPTION = """
Returns latent tensor dimensions,  
and passes the latent through unchanged.  

"""
    def getsize(self, latent):
        dims = latent["samples"].shape
        rank = len(dims)
        if rank == 4:
            # Image latents have no time axis; report zero frames
            batch, channels, height, width = dims
            frames = 0
        elif rank == 5:
            batch, channels, frames, height, width = dims
        else:
            raise ValueError("Invalid latent shape")

        summary = f"{batch}x{channels}x{frames}x{height}x{width}"
        return {
            "ui": {"text": [summary]},
            "result": (latent, batch, channels, frames, height, width),
        }

class ImageBatchRepeatInterleaving:
    """Repeat every frame of a batch in place (0,0,1,1,...) and mark the first copy of each."""

    RETURN_TYPES = ("IMAGE", "MASK",)
    FUNCTION = "repeat"
    CATEGORY = "KJNodes/image"
    DESCRIPTION = """
Repeats each image in a batch by the specified number of times.  
Example batch of 5 images: 0, 1 ,2, 3, 4  
with repeats 2 becomes batch of 10 images: 0, 0, 1, 1, 2, 2, 3, 3, 4, 4  
"""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "images": ("IMAGE",),
                "repeats": ("INT", {"default": 1, "min": 1, "max": 4096}),
            },
            "optional": {
                "mask": ("MASK",),
            },
        }

    def repeat(self, images, repeats, mask=None):
        frames, rows, cols = images.shape[0], images.shape[1], images.shape[2]
        out_images = torch.repeat_interleave(images, repeats=repeats, dim=0)

        if mask is None:
            # Generated mask: 1 on the first copy of every source frame, 0 elsewhere
            out_mask = torch.zeros((frames * repeats, rows, cols),
                                   device=images.device, dtype=images.dtype)
            out_mask[::repeats] = 1.0
        else:
            out_mask = torch.repeat_interleave(mask, repeats=repeats, dim=0)

        return (out_images, out_mask)

class ImageUpscaleWithModelBatched:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "upscale_model": ("UPSCALE_MODEL",),
                              "images": ("IMAGE",),
                              "per_batch": ("INT", {"default": 16, "min": 1, "max": 4096, "step": 1}),
                              },
                "optional": {
                    "downscale_ratio": ("FLOAT", {"default": 1.0, "min": 0.01, "max": 1.0, "step": 0.01}),
                    "downscale_method": (["nearest-exact", "bilinear", "area", "bicubic", "lanczos"], {"default": "lanczos"}),
                    "precision": (["float32", "float16", "bfloat16"], {"default": "float32"}),
                }}
    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "upscale"
    CATEGORY = "KJNodes/image"
    DESCRIPTION = """
Same as ComfyUI native model upscaling node,  
but allows setting sub-batches for reduced VRAM usage.
Optionally downscale the result with a ratio.
"""
    def upscale(self, upscale_model, images, per_batch, downscale_ratio=1.0, downscale_method="lanczos", precision="float32"):
        """Run the upscale model over ``images`` in chunks of ``per_batch`` frames.

        Args:
            upscale_model: model called on (N, C, H, W) chunks.
            images: image batch indexed as (batch, height, width, channels).
            per_batch: frames sent to the device per model call.
            downscale_ratio: values below 1.0 shrink the upscaled result by this factor.
            downscale_method: resampling method passed to ``common_upscale``.
            precision: dtype used for the model and its inputs.

        Returns:
            A 1-tuple with the float32 result on the CPU, (batch, H', W', C).
        """
        dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}.get(precision, torch.float32)
        device = model_management.get_torch_device()
        upscale_model.to(device, dtype=dtype)
        in_img = images.movedim(-1, -3).to(dtype)

        pbar = ProgressBar(in_img.shape[0])
        chunks = []
        try:
            for start_idx in range(0, in_img.shape[0], per_batch):
                sub_images = upscale_model(in_img[start_idx:start_idx + per_batch].to(device))
                # Move each result off the device right away to keep VRAM bounded
                chunks.append(sub_images.cpu())
                pbar.update(sub_images.shape[0])
        finally:
            # Offload the model even if inference fails (e.g. out of memory)
            upscale_model.cpu()

        t = torch.cat(chunks, dim=0).permute(0, 2, 3, 1).cpu().float()

        # Apply downscaling if ratio is less than 1.0
        if downscale_ratio < 1.0:
            # Never request a zero-sized output for very small results
            new_height = max(1, int(t.shape[1] * downscale_ratio))
            new_width = max(1, int(t.shape[2] * downscale_ratio))
            t = common_upscale(t.movedim(-1, 1), new_width, new_height, downscale_method, "disabled").movedim(1, -1)

        return (t,)

class ImageNormalize_Neg1_To_1:
    """Affinely map image values from [0, 1] to [-1, 1]."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "images": ("IMAGE",),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "normalize"
    CATEGORY = "KJNodes/image"
    DESCRIPTION = """
Normalize the images to be in the range [-1, 1]  
"""

    def normalize(self, images):
        # x -> 2x - 1 sends 0 to -1 and 1 to 1
        rescaled = images * 2.0 - 1.0
        return (rescaled,)

class RemapImageRange:
    """Linearly map image values from [0, 1] onto [min, max], optionally clamping to [0, 1]."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",),
                "min": ("FLOAT", {"default": 0.0,"min": -10.0, "max": 1.0, "step": 0.01}),
                "max": ("FLOAT", {"default": 1.0,"min": 0.0, "max": 10.0, "step": 0.01}),
                "clamp": ("BOOLEAN", {"default": True}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "remap"
    CATEGORY = "KJNodes/image"
    DESCRIPTION = """
Remaps the image values to the specified range. 
"""

    def remap(self, image, min, max, clamp):
        # `min`/`max` are node input names; alias them to avoid reading like builtins
        low, high = min, max
        # Half precision is promoted before the arithmetic
        if image.dtype == torch.float16:
            image = image.float()
        remapped = low + image * (high - low)
        return (remapped.clamp(0.0, 1.0) if clamp else remapped, )

class SplitImageChannels:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": { 
            "image": ("IMAGE",),
            },
            }

    RETURN_TYPES = ("IMAGE", "IMAGE", "IMAGE", "MASK")
    RETURN_NAMES = ("red", "green", "blue", "mask")
    FUNCTION = "split"
    CATEGORY = "KJNodes/image"
    DESCRIPTION = """
Splits image channels into images where the selected channel  
is repeated for all channels, and the alpha as a mask. 
"""

    def split(self, image):
        """Split an image batch into greyscale R, G, B images plus an alpha mask.

        Args:
            image: batch indexed as (batch, height, width, channels), with 3 or 4 channels.

        Returns:
            (red, green, blue, mask): each colour output repeats its channel
            three times; mask is the alpha channel (index 3) for 4-channel
            input, or all zeros otherwise.
        """
        red = image[:, :, :, 0:1]
        green = image[:, :, :, 1:2]
        blue = image[:, :, :, 2:3]
        if image.shape[3] == 4:
            # Channels are 0-indexed, so alpha is index 3 (index 4 would be out of range)
            alpha = image[:, :, :, 3]
        else:
            alpha = torch.zeros(image.shape[0], image.shape[1], image.shape[2], device=image.device, dtype=image.dtype)

        # Repeat the selected channel for all channels
        red = red.repeat(1, 1, 1, 3)
        green = green.repeat(1, 1, 1, 3)
        blue = blue.repeat(1, 1, 1, 3)
        return (red, green, blue, alpha)

class MergeImageChannels:
    """Assemble an image from the R, G and B channels of three images plus an optional alpha mask."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "red": ("IMAGE",),
                "green": ("IMAGE",),
                "blue": ("IMAGE",),
            },
            "optional": {
                "alpha": ("MASK", {"default": None}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("image",)
    FUNCTION = "merge"
    CATEGORY = "KJNodes/image"
    DESCRIPTION = """
Merges channel data into an image.  
"""

    def merge(self, red, green, blue, alpha=None):
        # Take channel 0 from `red`, 1 from `green` and 2 from `blue`
        planes = [red[..., 0:1], green[..., 1:2], blue[..., 2:3]]
        if alpha is not None:
            # The mask has no channel axis; add one before appending it
            planes.append(alpha.unsqueeze(-1))
        return (torch.cat(planes, dim=-1),)

class ImagePadForOutpaintMasked:

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",),
                "left": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
                "top": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
                "right": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
                "bottom": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
                "feathering": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
            },
            "optional": {
                "mask": ("MASK",),
            }
        }

    RETURN_TYPES = ("IMAGE", "MASK")
    FUNCTION = "expand_image"

    CATEGORY = "image"

    def expand_image(self, image, left, top, right, bottom, feathering, mask=None):
        """Pad an image batch for outpainting and build the matching mask.

        Args:
            image: batch indexed as (B, H, W, C).
            left, top, right, bottom: padding in pixels on each side.
            feathering: width of the quadratic fade inside the original area,
                measured from each padded edge (only used without ``mask``).
            mask: optional (B, H, W) mask; an all-zero mask is treated as absent.

        Returns:
            (new_image, new_mask). The padding in new_image is filled with 0.5.
            Without ``mask``, new_mask is 1 in the padding and the feathered
            ramp inside the original area. With ``mask``, it is ``1 - mask``
            after zero-padding.
        """
        if mask is not None and torch.allclose(mask, torch.zeros_like(mask)):
            logging.warning("The incoming mask is fully black. Handling it as None.")
            mask = None
        B, H, W, C = image.size()

        new_image = torch.ones(
            (B, H + top + bottom, W + left + right, C),
            dtype=torch.float32,
        ) * 0.5
        new_image[:, top:top + H, left:left + W, :] = image

        if mask is not None:
            # If a mask is provided, pad it to fit the new image size
            mask = F.pad(mask, (left, right, top, bottom), mode='constant', value=0)
            mask = 1 - mask
            # NOTE(review): feathering is not applied to a supplied mask (the
            # previous implementation computed it but never returned it) — confirm intent.
            return (new_image, mask,)

        new_mask = torch.ones(
            (B, H + top + bottom, W + left + right),
            dtype=torch.float32,
        )

        inner = torch.zeros((H, W), dtype=torch.float32)
        if feathering > 0 and feathering * 2 < H and feathering * 2 < W:
            # Vectorised distance of every pixel to the nearest *padded* edge;
            # unpadded sides contribute H or W, which always exceeds `feathering`.
            # float64 keeps the values bit-identical to Python float arithmetic.
            rows = torch.arange(H, dtype=torch.float64).unsqueeze(1).expand(H, W)
            cols = torch.arange(W, dtype=torch.float64).unsqueeze(0).expand(H, W)
            far_v = torch.full((H, W), float(H), dtype=torch.float64)
            far_h = torch.full((H, W), float(W), dtype=torch.float64)
            dist = torch.minimum(
                torch.minimum(rows if top != 0 else far_v, (H - rows) if bottom != 0 else far_v),
                torch.minimum(cols if left != 0 else far_h, (W - cols) if right != 0 else far_h),
            )
            ramp = (feathering - dist) / feathering
            inner = torch.where(dist < feathering, ramp * ramp, torch.zeros_like(ramp)).to(torch.float32)

        new_mask[:, top:top + H, left:left + W] = inner
        return (new_image, new_mask,)

class ImagePadForOutpaintTargetSize:
    """Fit an image inside a target size (downscaling only) and pad the remainder for outpainting."""

    upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",),
                "target_width": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
                "target_height": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
                "feathering": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
                "upscale_method": (s.upscale_methods,),
            },
            "optional": {
                "mask": ("MASK",),
            },
        }

    RETURN_TYPES = ("IMAGE", "MASK")
    FUNCTION = "expand_image"

    CATEGORY = "image"

    def expand_image(self, image, target_width, target_height, feathering, upscale_method, mask=None):
        _, src_h, src_w, _ = image.size()

        # Largest uniform scale that fits the image inside the target box
        scale = min(target_width / src_w, target_height / src_h)

        if scale >= 1:
            # Never upscale: keep the original pixels and size
            fit_w, fit_h = src_w, src_h
            image_scaled = image
        else:
            fit_w = int(src_w * scale)
            fit_h = int(src_h * scale)
            channels_first = image.movedim(-1, 1)
            image_scaled = common_upscale(channels_first, fit_w, fit_h, upscale_method, "disabled").movedim(1, -1)

        mask_scaled = None
        if mask is not None:
            # interpolate sees the (B, H, W) mask as one sample whose channels are the batch
            mask_scaled = F.interpolate(mask.unsqueeze(0), size=(fit_h, fit_w), mode="nearest").squeeze(0)

        # Split the leftover space as evenly as possible between opposite sides
        extra_h = target_height - fit_h
        extra_w = target_width - fit_w
        pad_top = max(0, extra_h // 2)
        pad_bottom = max(0, extra_h - pad_top)
        pad_left = max(0, extra_w // 2)
        pad_right = max(0, extra_w - pad_left)

        return ImagePadForOutpaintMasked.expand_image(self, image_scaled, pad_left, pad_top, pad_right, pad_bottom, feathering, mask_scaled)

class ImagePrepForICLora:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "reference_image": ("IMAGE",),
                "output_width": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
                "output_height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
                "border_width": ("INT", {"default": 0, "min": 0, "max": 4096, "step": 1}),
            },
            "optional": {
                "latent_image": ("IMAGE",),
                "latent_mask": ("MASK",),
                "reference_mask": ("MASK",),
            }
        }

    RETURN_TYPES = ("IMAGE", "MASK")
    FUNCTION = "expand_image"

    CATEGORY = "image"

    def expand_image(self, reference_image, output_width, output_height, border_width, latent_image=None, reference_mask=None, latent_mask=None):
        """Place a reference image beside a generation area for IC-LoRA style workflows.

        Layout, left to right: the reference resized to ``output_height`` (aspect
        preserved), an optional black border of ``border_width`` pixels, then the
        generation area of ``output_width`` x ``output_height``. The generation
        area is black, or ``latent_image`` resized to that size.

        Args:
            reference_image: batch indexed as (B, H, W, C).
            latent_image: optional image placed in the generation area.
            reference_mask: optional (B, H, W) mask multiplied into the reference;
                an all-zero mask is ignored.
            latent_mask: optional (B, h, w) mask resized onto the generation area.

        Returns:
            (padded_image, padded_mask). Without ``latent_mask``, the mask is 0
            over reference + border and 1 over the generation area. With it, the
            mask is 0 over reference + border and the resized latent mask elsewhere.
        """
        if reference_mask is not None and torch.allclose(reference_mask, torch.zeros_like(reference_mask)):
            logging.warning("The incoming mask is fully black. Handling it as None.")
            reference_mask = None

        image = reference_image
        if latent_image is not None and image.shape[0] != latent_image.shape[0]:
            image = image.repeat(latent_image.shape[0], 1, 1, 1)
        B, H, W, C = image.size()
        device = image.device

        if reference_mask is not None:
            resized_mask = F.interpolate(
                reference_mask.unsqueeze(1),
                size=(H, W),
                mode='nearest'
            ).squeeze(1)
            image = image * resized_mask.unsqueeze(-1)

        # Resize the reference to the output height, keeping its aspect ratio
        new_width = int((W / H) * output_height)
        resized_image = common_upscale(image.movedim(-1, 1), new_width, output_height, "lanczos", "disabled").movedim(1, -1)

        if latent_image is None:
            pad_image = torch.zeros((B, output_height, output_width, C), device=device)
        else:
            pad_image = common_upscale(latent_image.movedim(-1, 1), output_width, output_height, "lanczos", "disabled").movedim(1, -1)

        parts = [resized_image]
        if border_width > 0:
            parts.append(torch.zeros((B, output_height, border_width, C), device=device))
        parts.append(pad_image)
        padded_image = torch.cat(parts, dim=2)

        # Column where the generation area starts
        offset = new_width + border_width
        mask_shape = (B, padded_image.shape[1], padded_image.shape[2])
        if latent_mask is not None:
            # Resized independently of latent_image, so a mask alone no longer crashes
            resized_latent_mask = F.interpolate(
                latent_mask.unsqueeze(1),
                size=(output_height, output_width),
                mode='nearest'
            ).squeeze(1)
            padded_mask = torch.zeros(mask_shape, device=device)
            padded_mask[:, :, offset:] = resized_latent_mask
        else:
            padded_mask = torch.ones(mask_shape, device=device)
            padded_mask[:, :, :offset] = 0

        return (padded_image, padded_mask)


class ImageAndMaskPreview(SaveImage):
    def __init__(self):
        self.output_dir = folder_paths.get_temp_directory()
        self.type = "temp"
        self.prefix_append = "_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for x in range(5))
        self.compress_level = 4

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "mask_opacity": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
                "mask_color": ("STRING", {"default": "255, 255, 255", "tooltip": "RGB (255,255,255) or RGBA (255,255,255,128) or Hex (#RRGGBB / #RRGGBBAA)"}),
                "pass_through": ("BOOLEAN", {"default": False}),
             },
            "optional": {
                "image": ("IMAGE",),
                "mask": ("MASK",),
            },
            "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
        }
    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("composite",)
    FUNCTION = "execute"
    CATEGORY = "KJNodes/masking"
    DESCRIPTION = """
Preview an image or a mask, when both inputs are used  
composites the mask on top of the image.
with pass_through on the preview is disabled and the  
composite is returned from the composite slot instead,  
this allows for the preview to be passed for video combine  
nodes for example. Supports RGBA for mask_color to adjust transparency per color.  
"""

    def execute(self, mask_opacity, mask_color, pass_through, filename_prefix="ComfyUI", image=None, mask=None, prompt=None, extra_pnginfo=None):
        """Preview an image, a mask, or the mask composited over the image.

        Returns the composite as a 1-tuple when ``pass_through`` is set, otherwise
        the result of ``save_images`` for the temp preview.

        Raises:
            ValueError: if neither ``image`` nor ``mask`` is connected.
        """
        if image is None and mask is None:
            raise ValueError("ImageAndMaskPreview needs an image, a mask, or both")

        if image is None:
            # Mask only: show it as a 3-channel greyscale image
            preview = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3)
        elif mask is None:
            preview = image
        else:
            mask_adjusted = mask * mask_opacity
            mask_image = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3).clone()

            # Use helper function to parse color string
            color_list = string_to_color(mask_color)

            # Fill the overlay with the requested RGB color
            mask_image[:, :, :, 0] = color_list[0] / 255 # Red channel
            mask_image[:, :, :, 1] = color_list[1] / 255 # Green channel
            mask_image[:, :, :, 2] = color_list[2] / 255 # Blue channel

            if len(color_list) == 4: # An alpha component further scales the overlay strength
                alpha_factor = color_list[3] / 255.0
                mask_adjusted = mask_adjusted * alpha_factor

            destination, source = node_helpers.image_alpha_fix(image, mask_image)
            destination = destination.clone().movedim(-1, 1)
            preview = composite(destination, source.movedim(-1, 1), 0, 0, mask_adjusted, 1, True).movedim(1, -1)

        if pass_through:
            return (preview, )
        return(self.save_images(preview, filename_prefix, prompt, extra_pnginfo))

def crossfade(images_1, images_2, alpha):
    """Linearly blend two images: alpha = 0 gives ``images_1``, alpha = 1 gives ``images_2``."""
    keep = 1 - alpha
    return keep * images_1 + alpha * images_2
def ease_in(t):
    """Quadratic ease-in: starts slowly and accelerates."""
    return t * t


def ease_out(t):
    """Quadratic ease-out: starts fast and decelerates."""
    remaining = 1 - t
    return 1 - remaining * remaining


def ease_in_out(t):
    """Smoothstep (3t^2 - 2t^3): slow at both ends."""
    quadratic = 3 * t * t
    cubic = 2 * t * t * t
    return quadratic - cubic


def bounce(t):
    """Ease-out over the first half, then ease-in over the second half."""
    return ease_out(t * 2) * 0.5 if t < 0.5 else ease_in((t - 0.5) * 2) * 0.5 + 0.5


def elastic(t):
    """Oscillating curve with exponentially growing amplitude; 0 at t = 0, ~1 at t = 1."""
    angle = 13 * math.pi / 2 * t
    envelope = math.pow(2, 10 * (t - 1))
    return math.sin(angle) * envelope


def glitchy(t):
    """Linear ramp plus a fast sine wobble; note it does not end exactly at 1."""
    wobble = 0.1 * math.sin(40 * t)
    return t + wobble


def exponential_ease_out(t):
    """Quartic ease-out: very fast start, long gentle approach to 1."""
    return 1 - (1 - t) ** 4


# Maps the `interpolation` option names to their curves
easing_functions = dict(
    linear=lambda t: t,
    ease_in=ease_in,
    ease_out=ease_out,
    ease_in_out=ease_in_out,
    bounce=bounce,
    elastic=elastic,
    glitchy=glitchy,
    exponential_ease_out=exponential_ease_out,
)

class CrossFadeImages:
    
    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "crossfadeimages"
    CATEGORY = "KJNodes/image"

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                 "images_1": ("IMAGE",),
                 "images_2": ("IMAGE",),
                 "interpolation": (["linear", "ease_in", "ease_out", "ease_in_out", "bounce", "elastic", "glitchy", "exponential_ease_out"],),
                 "transition_start_index": ("INT", {"default": 1,"min": -4096, "max": 4096, "step": 1}),
                 "transitioning_frames": ("INT", {"default": 1,"min": 0, "max": 4096, "step": 1}),
                 "start_level": ("FLOAT", {"default": 0.0,"min": 0.0, "max": 1.0, "step": 0.01}),
                 "end_level": ("FLOAT", {"default": 1.0,"min": 0.0, "max": 1.0, "step": 0.01}),
        },
    } 
    
    def crossfadeimages(self, images_1, images_2, transition_start_index, transitioning_frames, interpolation, start_level, end_level):
        """Play ``images_1`` up to a start frame, blend into ``images_2``, then play the rest of ``images_2``.

        The blend weight runs linearly from ``start_level`` to ``end_level`` and is
        then passed through the chosen easing curve. A negative start index counts
        from the end of ``images_1``. Frames of ``images_1`` after the blend are dropped.

        Raises:
            ValueError: if a negative start index points before the first frame.
            KeyError: if ``interpolation`` is not a known easing name.
        """
        if transition_start_index < 0:
            transition_start_index = len(images_1) + transition_start_index
            if transition_start_index < 0:
                raise ValueError("Transition start index is out of range for images_1.")

        # Never blend past the end of either batch; a start beyond images_1
        # (or transitioning_frames=0) simply means no blended frames.
        transitioning_frames = max(0, min(transitioning_frames, len(images_1) - transition_start_index, len(images_2)))

        easing_function = easing_functions[interpolation]
        alphas = torch.linspace(start_level, end_level, transitioning_frames)
        blended = [
            crossfade(images_1[transition_start_index + i], images_2[i], easing_function(alphas[i]))
            for i in range(transitioning_frames)
        ]

        # Beginning of images_1 (before the transition)
        parts = [images_1[:transition_start_index]]
        if blended:
            parts.append(torch.stack(blended, dim=0))
        # Remaining frames of images_2 (after the transition)
        remaining_images_2 = images_2[transitioning_frames:]
        if len(remaining_images_2) > 0:
            parts.append(remaining_images_2)

        return (torch.cat(parts, dim=0), )
    
class CrossFadeImagesMulti:
    """Chain several image clips, inserting an eased crossfade between each consecutive pair."""
    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "crossfadeimages"
    CATEGORY = "KJNodes/image"

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "inputcount": ("INT", {"default": 2, "min": 2, "max": 1000, "step": 1}),
                "image_1": ("IMAGE",),
                "interpolation": (["linear", "ease_in", "ease_out", "ease_in_out", "bounce", "elastic", "glitchy", "exponential_ease_out"],),
                "transitioning_frames": ("INT", {"default": 1,"min": 0, "max": 4096, "step": 1}),
            },
            "optional": {
                "image_2": ("IMAGE",),
            }
        }

    def crossfadeimages(self, inputcount, transitioning_frames, interpolation, **kwargs):
        """Concatenate image_1 ... image_<inputcount> with crossfades in between.

        Between the accumulated result and clip k + 1, ``transitioning_frames`` frames are made
        by crossfading the last accumulated frame with the first frame of clip k + 1. Progress
        runs evenly from 0 to 1 across those frames and is shaped by the easing curve. A single
        transition frame uses progress 1.0.

        Args:
            inputcount: Number of ``image_<n>`` inputs to chain.
            transitioning_frames: Frames inserted per transition; 0 joins the clips directly.
            interpolation: Key into ``easing_functions``.
            **kwargs: The ``image_<n>`` batches. ``image_1`` is required. A missing (or None)
                later input is replaced by an all-zero batch shaped like ``image_1``. Inputs whose
                height/width differ from ``image_1`` are resized to match with lanczos.

        Returns:
            A 1-tuple with the combined image batch.
        """
        image_1 = kwargs["image_1"]
        first_image_shape = image_1.shape
        first_image_device = image_1.device
        height = image_1.shape[1]
        width = image_1.shape[2]

        easing_function = easing_functions[interpolation]

        for c in range(1, inputcount):
            new_image = kwargs.get(f"image_{c + 1}")
            if new_image is None:
                # Only allocate the placeholder when the input is actually missing.
                new_image = torch.zeros(first_image_shape)
            new_image = new_image.to(first_image_device)

            if new_image.shape[1] != height or new_image.shape[2] != width:
                # common_upscale works channels-first: move C to dim 1 and back afterwards.
                new_image = common_upscale(new_image.movedim(-1, 1), width, height, "lanczos", "disabled").movedim(1, -1)

            pieces = [image_1]
            if transitioning_frames > 0:
                last_frame_image_1 = image_1[-1]
                first_frame_image_2 = new_image[0]
                frames = []
                for frame in range(transitioning_frames):
                    # (transitioning_frames - 1) is 0 for a single frame, so guard the division.
                    t = frame / (transitioning_frames - 1) if transitioning_frames > 1 else 1.0
                    alpha = easing_function(t)
                    alpha_tensor = torch.tensor(alpha, dtype=last_frame_image_1.dtype, device=last_frame_image_1.device)
                    frames.append(crossfade(last_frame_image_1, first_frame_image_2, alpha_tensor))
                pieces.append(torch.stack(frames))
            pieces.append(new_image)
            image_1 = torch.cat(pieces, dim=0)

        return image_1,

def transition_images(images_1, images_2, alpha, transition_type, blur_radius, reverse):
    """Composite two frames through a wipe/fade mask for one step of a transition.

    Args:
        images_1: Frame shown where the mask is 0. It is indexed as [H, W, C]:
            ``shape[0]`` is used as height and ``shape[1]`` as width.
        images_2: Frame shown where the mask is 1. It must broadcast against ``images_1``.
        alpha: Transition progress (a 0-d tensor or a number). It is clamped to [0, 1] first.
            A negative value would otherwise turn slice bounds into "all but the last k"
            (filling nearly the whole frame), and would square into a visible circle radius.
        transition_type: Wipe name, matched by substring ("horizontal slide", "vertical slide",
            "box", "circle", "horizontal door", "vertical door", "fade"). An unrecognised name
            leaves the mask empty, so ``images_1`` is returned.
        blur_radius: Mask softening applied through ``gaussian_blur``; 0 disables it.
        reverse: If True, progress runs backwards (``1 - alpha``).

    Returns:
        ``images_1 * (1 - mask) + images_2 * mask``.
    """
    height = images_1.shape[0]
    width = images_1.shape[1]

    mask = torch.zeros_like(images_1)

    # Some easing curves may overshoot [0, 1]; keep the wipe geometry well defined.
    alpha = min(max(float(alpha), 0.0), 1.0)
    if reverse:
        alpha = 1.0 - alpha

    #transitions from matteo's essential nodes
    if "horizontal slide" in transition_type:
        pos = round(width * alpha)
        mask[:, :pos, :] = 1.0
    elif "vertical slide" in transition_type:
        pos = round(height * alpha)
        mask[:pos, :, :] = 1.0
    elif "box" in transition_type:
        box_w = round(width * alpha)
        box_h = round(height * alpha)
        x1 = (width - box_w) // 2
        y1 = (height - box_h) // 2
        mask[y1:y1 + box_h, x1:x1 + box_w, :] = 1.0
    elif "circle" in transition_type:
        # The radius grows to half the frame diagonal, so alpha == 1 covers every pixel.
        radius = math.ceil(math.sqrt(pow(width, 2) + pow(height, 2)) * alpha / 2)
        c_x = width // 2
        c_y = height // 2
        # Build the grid where the mask lives to avoid a cross-device boolean index.
        x = torch.arange(0, width, dtype=torch.float32, device=mask.device)
        y = torch.arange(0, height, dtype=torch.float32, device=mask.device)
        y, x = torch.meshgrid((y, x), indexing="ij")
        circle = ((x - c_x) ** 2 + (y - c_y) ** 2) <= (radius ** 2)
        mask[circle] = 1.0
    elif "horizontal door" in transition_type:
        # Bars close in from the top and bottom edges.
        bar = math.ceil(height * alpha / 2)
        if bar > 0:
            mask[:bar, :, :] = 1.0
            mask[-bar:, :, :] = 1.0
    elif "vertical door" in transition_type:
        # Bars close in from the left and right edges.
        bar = math.ceil(width * alpha / 2)
        if bar > 0:
            mask[:, :bar, :] = 1.0
            mask[:, -bar:, :] = 1.0
    elif "fade" in transition_type:
        mask[:, :, :] = alpha

    mask = gaussian_blur(mask, blur_radius)

    return images_1 * (1 - mask) + images_2 * mask

def gaussian_blur(mask, blur_radius):
    """Blur an [H, W, C] mask channel-by-channel with a normalised Gaussian kernel.

    Args:
        mask: Tensor laid out as [H, W, C] (permuted to [1, C, H, W] for ``conv2d``).
        blur_radius: Blur radius in pixels. The kernel size is ``int(2 * radius) + 1``
            (forced odd) and sigma is ``radius / 3``. Values <= 0 return ``mask`` untouched.

    Returns:
        The blurred mask with the same shape, dtype and device as the input.
    """
    if blur_radius <= 0:
        return mask

    kernel_size = int(blur_radius * 2) + 1
    if kernel_size % 2 == 0:
        kernel_size += 1  # Ensure kernel size is odd
    sigma = blur_radius / 3
    x = torch.arange(-kernel_size // 2 + 1, kernel_size // 2 + 1, dtype=torch.float32)
    x = torch.exp(-0.5 * (x / sigma) ** 2)
    kernel1d = x / x.sum()
    kernel2d = kernel1d[:, None] * kernel1d[None, :]
    # conv2d needs both operands on the same device AND dtype (e.g. fp16 masks on GPU).
    kernel2d = kernel2d.to(device=mask.device, dtype=mask.dtype)
    channels = mask.shape[2]
    kernel2d = kernel2d.expand(channels, 1, kernel_size, kernel_size)

    pad = kernel_size // 2
    mask = mask.permute(2, 0, 1).unsqueeze(0)  # [H, W, C] -> [1, C, H, W]
    # Replicate the border instead of zero padding. Zero padding dragged a fully set mask
    # towards 0 along the frame edges, letting the outgoing image bleed through there.
    mask = F.pad(mask, (pad, pad, pad, pad), mode="replicate")
    mask = F.conv2d(mask, kernel2d, groups=channels)
    return mask.squeeze(0).permute(1, 2, 0)  # back to [H, W, C]

class TransitionImagesMulti:
    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "transition"
    CATEGORY = "KJNodes/image"
    DESCRIPTION = """
Creates transitions between images.
"""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "inputcount": ("INT", {"default": 2, "min": 2, "max": 1000, "step": 1}),
                "image_1": ("IMAGE",),
                "interpolation": (["linear", "ease_in", "ease_out", "ease_in_out", "bounce", "elastic", "glitchy", "exponential_ease_out"],),
                "transition_type": (["horizontal slide", "vertical slide", "box", "circle", "horizontal door", "vertical door", "fade"],),
                "transitioning_frames": ("INT", {"default": 2,"min": 2, "max": 4096, "step": 1}),
                "blur_radius": ("FLOAT", {"default": 0.0,"min": 0.0, "max": 100.0, "step": 0.1}),
                "reverse": ("BOOLEAN", {"default": False}),
                "device": (["CPU", "GPU"], {"default": "CPU"}),
            },
            "optional": {
                "image_2": ("IMAGE",),
            }
        }

    def transition(self, inputcount, transitioning_frames, transition_type, interpolation, device, blur_radius, reverse, **kwargs):
        """Chain image_1 ... image_<inputcount>, inserting a wipe/fade transition between clips.

        Each transition renders ``transitioning_frames`` frames with ``transition_images``,
        going from the last accumulated frame to the first frame of the next clip.

        Args:
            inputcount: Number of ``image_<n>`` inputs to chain.
            transitioning_frames: Frames per transition. The UI enforces >= 2, but 1 is handled
                (progress 1.0) and 0 joins the clips directly.
            transition_type: Wipe name passed to ``transition_images``.
            interpolation: Key into ``easing_functions``.
            device: "GPU" renders the transition frames on the ComfyUI torch device.
            blur_radius: Mask softening passed to ``transition_images``.
            reverse: Swaps the frames and reverses the mask progression.
            **kwargs: The ``image_<n>`` batches. A missing (or None) later input becomes an
                all-zero batch shaped like ``image_1``. Mismatched sizes are resized with lanczos.

        Returns:
            A 1-tuple with the combined batch on the CPU.
        """
        gpu = model_management.get_torch_device()

        image_1 = kwargs["image_1"]
        height = image_1.shape[1]
        width = image_1.shape[2]
        first_image_shape = image_1.shape
        first_image_device = image_1.device

        easing_function = easing_functions[interpolation]

        for c in range(1, inputcount):
            new_image = kwargs.get(f"image_{c + 1}")
            if new_image is None:
                # Only allocate the placeholder when the input is actually missing.
                new_image = torch.zeros(first_image_shape)
            new_image = new_image.to(first_image_device)

            if new_image.shape[1] != height or new_image.shape[2] != width:
                # common_upscale works channels-first: move C to dim 1 and back afterwards.
                new_image = common_upscale(new_image.movedim(-1, 1), width, height, "lanczos", "disabled").movedim(1, -1)

            last_frame_image_1 = image_1[-1]
            first_frame_image_2 = new_image[0]
            if device == "GPU":
                last_frame_image_1 = last_frame_image_1.to(gpu)
                first_frame_image_2 = first_frame_image_2.to(gpu)

            if reverse:
                last_frame_image_1, first_frame_image_2 = first_frame_image_2, last_frame_image_1

            frames = []
            for frame in range(transitioning_frames):
                # (transitioning_frames - 1) is 0 for a single frame, so guard the division.
                t = frame / (transitioning_frames - 1) if transitioning_frames > 1 else 1.0
                alpha = easing_function(t)
                alpha_tensor = torch.tensor(alpha, dtype=last_frame_image_1.dtype, device=last_frame_image_1.device)
                frames.append(transition_images(last_frame_image_1, first_frame_image_2, alpha_tensor, transition_type, blur_radius, reverse))

            pieces = [image_1]
            if frames:
                # Bring the rendered frames back to the device image_1/new_image live on, so the
                # concatenation never mixes devices.
                pieces.append(torch.stack(frames).to(image_1.device))
            pieces.append(new_image)
            image_1 = torch.cat(pieces, dim=0)

        return image_1.cpu(),

class TransitionImagesInBatch:
    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "transition"
    CATEGORY = "KJNodes/image"
    DESCRIPTION = """
Creates transitions between images in a batch.
"""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "images": ("IMAGE",),
                "interpolation": (["linear", "ease_in", "ease_out", "ease_in_out", "bounce", "elastic", "glitchy", "exponential_ease_out"],),
                "transition_type": (["horizontal slide", "vertical slide", "box", "circle", "horizontal door", "vertical door", "fade"],),
                "transitioning_frames": ("INT", {"default": 1,"min": 0, "max": 4096, "step": 1}),
                "blur_radius": ("FLOAT", {"default": 0.0,"min": 0.0, "max": 100.0, "step": 0.1}),
                "reverse": ("BOOLEAN", {"default": False}),
                "device": (["CPU", "GPU"], {"default": "CPU"}),
            },
        }

    #transitions from matteo's essential nodes
    def transition(self, images, transitioning_frames, transition_type, interpolation, device, blur_radius, reverse):
        """Render a transition clip between every pair of consecutive frames in ``images``.

        For each pair (i, i + 1), ``transitioning_frames`` frames are produced with
        ``transition_images``. The output is the concatenation of those clips only; the
        source frames are not inserted separately.

        Args:
            images: Image batch (frames along dim 0).
            transitioning_frames: Frames per pair. 1 uses progress 1.0. 0 (or a batch with fewer
                than two frames) returns ``images`` unchanged.
            transition_type: Wipe name passed to ``transition_images``.
            interpolation: Key into ``easing_functions``.
            device: "GPU" renders on the ComfyUI torch device.
            blur_radius: Mask softening passed to ``transition_images``.
            reverse: Swaps each pair and reverses the mask progression.

        Returns:
            A 1-tuple with the rendered frames on the CPU (or the untouched input).
        """
        # Nothing to pair up, or no frames requested: pass the batch through unchanged
        # (an empty frame list could not be stacked/concatenated).
        if images.shape[0] <= 1 or transitioning_frames <= 0:
            return images,

        gpu = model_management.get_torch_device()

        easing_function = easing_functions[interpolation]

        images_list = []
        pbar = ProgressBar(images.shape[0] - 1)
        for i in range(images.shape[0] - 1):
            image_1 = images[i]
            image_2 = images[i + 1]

            if device == "GPU":
                image_1 = image_1.to(gpu)
                image_2 = image_2.to(gpu)

            if reverse:
                image_1, image_2 = image_2, image_1

            frames = []
            for frame in range(transitioning_frames):
                # (transitioning_frames - 1) is 0 for a single frame, so guard the division.
                t = frame / (transitioning_frames - 1) if transitioning_frames > 1 else 1.0
                alpha = easing_function(t)
                alpha_tensor = torch.tensor(alpha, dtype=image_1.dtype, device=image_1.device)
                frames.append(transition_images(image_1, image_2, alpha_tensor, transition_type, blur_radius, reverse))
            pbar.update(1)

            images_list.append(torch.stack(frames).cpu())

        return torch.cat(images_list, dim=0).cpu(),

class ImageBatchJoinWithTransition:
    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "transition_batches"
    CATEGORY = "KJNodes/image"
    DESCRIPTION = """
Transitions between two batches of images, starting at a specified index in the first batch.
During the transition, frames from both batches are blended frame-by-frame, so the video keeps playing.
"""

    @classmethod
    def INPUT_TYPES(cls):
        easings = ["linear", "ease_in", "ease_out", "ease_in_out", "bounce", "elastic", "glitchy", "exponential_ease_out"]
        wipes = ["horizontal slide", "vertical slide", "box", "circle", "horizontal door", "vertical door", "fade"]
        required = {
            "images_1": ("IMAGE",),
            "images_2": ("IMAGE",),
            "start_index": ("INT", {"default": 0, "min": -10000, "max": 10000, "step": 1}),
            "interpolation": (easings,),
            "transition_type": (wipes,),
            "transitioning_frames": ("INT", {"default": 1, "min": 1, "max": 4096, "step": 1}),
            "blur_radius": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 100.0, "step": 0.1}),
            "reverse": ("BOOLEAN", {"default": False}),
            "device": (["CPU", "GPU"], {"default": "CPU"}),
        }
        return {"required": required}

    def transition_batches(self, images_1, images_2, start_index, interpolation, transition_type, transitioning_frames, blur_radius, reverse, device):
        """Join two batches, blending overlapping frames so both clips keep playing.

        Output layout: ``images_1[:start]``, then up to ``transitioning_frames`` frames blending
        ``images_1[start + i]`` with ``images_2[i]``, then the unused tail of ``images_2``.
        Frames of ``images_1`` beyond the blended region are not included.

        Raises:
            ValueError: if either batch is empty or ``start_index`` (negative counts from the
                end) falls outside ``[0, len(images_1)]``.
        """
        count_1 = images_1.shape[0]
        count_2 = images_2.shape[0]
        if count_1 == 0 or count_2 == 0:
            raise ValueError("Both input batches must have at least one image.")

        start = count_1 + start_index if start_index < 0 else start_index
        if not 0 <= start <= count_1:
            raise ValueError("start_index is out of range.")

        target_device = model_management.get_torch_device()
        ease = easing_functions[interpolation]

        # Only frames that exist in both batches can be blended.
        blend_count = min(transitioning_frames, count_1 - start, count_2)

        pieces = [images_1[:start]] if start > 0 else []

        def render(i):
            """Blend the i-th overlapping pair into a single-frame [1, H, W, C] CPU batch."""
            frame_a, frame_b = images_1[start + i], images_2[i]
            if device == "GPU":
                frame_a, frame_b = frame_a.to(target_device), frame_b.to(target_device)
            if reverse:
                frame_a, frame_b = frame_b, frame_a
            progress = 1.0 if blend_count <= 1 else i / (blend_count - 1)
            weight = torch.tensor(ease(progress), dtype=frame_a.dtype, device=frame_a.device)
            blended = transition_images(frame_a, frame_b, weight, transition_type, blur_radius, reverse)
            return blended.cpu().unsqueeze(0)

        pieces.extend(render(i) for i in range(blend_count))

        if count_2 > blend_count:
            pieces.append(images_2[blend_count:])

        return (torch.cat(pieces, dim=0).cpu(),)

class ShuffleImageBatch:
    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "shuffle"
    CATEGORY = "KJNodes/image"

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "images": ("IMAGE",),
                "seed": ("INT", {"default": 123,"min": 0, "max": 0xffffffffffffffff, "step": 1}),
            },
        }

    def shuffle(self, images, seed):
        """Return ``images`` reordered along dim 0 by a seed-determined permutation.

        Args:
            images: Image batch (frames along dim 0).
            seed: Seed for the permutation; the same seed always yields the same order.

        Returns:
            A 1-tuple with the shuffled batch.
        """
        # A private generator gives the same permutation as seeding torch globally, but does
        # not reseed torch's process-wide RNG (which would alter every later random op).
        generator = torch.Generator().manual_seed(seed)
        indices = torch.randperm(images.shape[0], generator=generator)
        shuffled_images = images[indices]

        return shuffled_images,

class GetImageRangeFromBatch:
    
    RETURN_TYPES = ("IMAGE", "MASK", )
    FUNCTION = "imagesfrombatch"
    CATEGORY = "KJNodes/image"
    DESCRIPTION = """
Returns a range of images from a batch.
"""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "start_index": ("INT", {"default": 0,"min": -1, "max": 4096, "step": 1}),
                "num_frames": ("INT", {"default": 1,"min": 1, "max": 4096, "step": 1}),
            },
            "optional": {
                "images": ("IMAGE",),
                "masks": ("MASK",),
            }
        }

    def imagesfrombatch(self, start_index, num_frames, images=None, masks=None):
        """Slice up to ``num_frames`` consecutive frames from the image and/or mask batch.

        Args:
            start_index: First frame to take. -1 means "the last ``num_frames`` frames" and is
                resolved separately against each batch's own length.
            num_frames: Maximum number of frames to return; truncated at the batch end.
            images: Optional image batch.
            masks: Optional mask batch.

        Returns:
            (chosen_images, chosen_masks); an entry is None when its input was not given.

        Raises:
            ValueError: if the resolved start index lies outside a provided batch.
        """
        chosen_images = None
        chosen_masks = None

        # Each batch resolves start_index on its own; an index derived from the image batch
        # must not leak into the (possibly shorter) mask batch.
        if images is not None:
            chosen_images = self._take_range(images, start_index, num_frames, "Start index is out of range")
        if masks is not None:
            chosen_masks = self._take_range(masks, start_index, num_frames, "Start index is out of range for masks")

        return (chosen_images, chosen_masks,)

    @staticmethod
    def _take_range(batch, start_index, num_frames, error_message):
        """Return ``batch[start:start + num_frames]`` (clipped), resolving start_index == -1 to the tail."""
        if start_index == -1:
            start_index = max(0, len(batch) - num_frames)
        if start_index < 0 or start_index >= len(batch):
            raise ValueError(error_message)
        end_index = min(start_index + num_frames, len(batch))
        return batch[start_index:end_index]

class RandomImageFromBatch(io.ComfyNode):
    """Pick a sorted selection of frames from an IMAGE or MASK batch.

    Picks are a seeded blend of evenly spaced and random positions inside an inclusive index
    range. Minimum/maximum gaps between consecutive picks can optionally be enforced.
    """
    @classmethod
    def define_schema(cls):
        """Return the node schema; input and output share one MatchType template (Image or Mask)."""
        template = io.MatchType.Template("input_type", [io.Image, io.Mask])
        return io.Schema(
            node_id="RandomImageFromBatch",
            display_name="Random Image From Batch",
            search_aliases=["random", "mask", "sequence", "frame"],
            category="KJNodes/image",
            description="Picks a sequence of frames from an image or mask batch within a selected index range. "
                        "At randomness=0 the picks are evenly spaced across the range; at randomness=1 they are "
                        "uniformly random without replacement; values in between blend linearly. "
                        "Output is always sorted by batch index. Negative indices count from the end (-1 = last).",
            inputs=[
                io.MatchType.Input("input", template=template,
                                   tooltip="Image or mask batch to sample from."),
                io.Int.Input("start_index", default=0, min=-4096, max=4096,
                             tooltip="Inclusive start of the sampling range. Negative values count from the end."),
                io.Int.Input("end_index", default=-1, min=-4096, max=4096,
                             tooltip="Inclusive end of the sampling range. -1 means the last frame."),
                io.Int.Input("num_frames", default=1, min=1, max=4096,
                             tooltip="How many frames to pick from the range."),
                io.Float.Input("randomness", default=1.0, min=0.0, max=1.0, step=0.01,
                               tooltip="0 = evenly spaced across the range, 1 = uniformly random without replacement, "
                                       "in-between = linear blend (jittered even spacing)."),
                io.Int.Input("min_distance", default=0, min=0, max=4096,
                             tooltip="Minimum gap (in frames) between consecutive picks. 0 = no minimum. "
                                     "Picks are pushed forward to satisfy this; later picks may clamp to the range end."),
                io.Int.Input("max_distance", default=0, min=0, max=4096,
                             tooltip="Maximum gap (in frames) between consecutive picks. 0 = no maximum. "
                                     "Picks are pulled in to satisfy this, which may compress the sequence toward the start."),
                io.Int.Input("seed", default=0, min=0, max=0xffffffffffffffff, step=1,
                             tooltip="Random seed for reproducible sampling. Ignored when randomness is 0."),
            ],
            outputs=[
                io.MatchType.Output(template=template, display_name="output"),
            ],
        )

    @classmethod
    def execute(cls, input, start_index, end_index, num_frames, randomness, min_distance, max_distance, seed) -> io.NodeOutput:
        """Select ``num_frames`` frames of ``input`` along dim 0 and return them in index order.

        Args:
            input: Image or mask batch; frames are indexed along the first dimension.
            start_index: Inclusive range start; negative values count from the end.
            end_index: Inclusive range end; negative values count from the end. Both bounds are
                clamped into [0, n - 1] and swapped if given in reverse order.
            num_frames: Number of picks. It may exceed the range size; the result can then
                contain repeated frames (clamping at the range end can also repeat frames).
            randomness: <= 0 keeps the even spacing. Otherwise each even position is blended
                linearly with the sorted random pick of the same rank.
            min_distance: Minimum gap enforced between consecutive picks (0 disables).
            max_distance: Maximum gap enforced between consecutive picks (0 disables). It is
                applied after min_distance, so it wins when the two conflict.
            seed: Seed for a private ``random.Random`` instance.

        Returns:
            ``io.NodeOutput`` wrapping the selected frames.

        Raises:
            ValueError: if ``input`` has no frames.
        """
        n = input.shape[0]
        if n == 0:
            raise ValueError("Input batch is empty.")

        # Resolve negative indices relative to n, clamp both bounds into [0, n - 1],
        # and accept a reversed range by swapping the bounds.
        s = start_index if start_index >= 0 else n + start_index
        e = end_index if end_index >= 0 else n + end_index
        s = max(0, min(s, n - 1))
        e = max(0, min(e, n - 1))
        if e < s:
            s, e = e, s
        range_size = e - s + 1

        # Evenly spaced float positions: the range midpoint for a single pick,
        # otherwise num_frames points from s to e inclusive.
        if num_frames == 1:
            even = [(s + e) / 2]
        else:
            even = [s + i * (e - s) / (num_frames - 1) for i in range(num_frames)]

        if randomness <= 0:
            picks_float = even
        else:
            # Private RNG instance: the module-level `random` state is left untouched.
            rng = random.Random(seed)
            if num_frames <= range_size:
                # Enough distinct frames: sample without replacement.
                random_picks = rng.sample(range(s, e + 1), num_frames)
            else:
                # More picks than frames: sample with replacement (duplicates possible).
                random_picks = [rng.randint(s, e) for _ in range(num_frames)]
            # Sorting pairs the k-th smallest random pick with the k-th even position.
            random_picks.sort()
            picks_float = [(1 - randomness) * ev + randomness * rp for ev, rp in zip(even, random_picks)]

        # Round to integer indices (Python round(): exact .5 ties go to the even integer),
        # clamp into [s, e], and sort.
        picks = sorted(max(s, min(e, int(round(p)))) for p in picks_float)

        # Walk the picks left to right against the previously *adjusted* pick: push forward to
        # honour min_distance, then pull back to honour max_distance, then clamp into [s, e].
        if num_frames > 1 and (min_distance > 0 or max_distance > 0):
            adjusted = [picks[0]]
            for i in range(1, len(picks)):
                prev = adjusted[-1]
                target = picks[i]
                if min_distance > 0 and target - prev < min_distance:
                    target = prev + min_distance
                if max_distance > 0 and target - prev > max_distance:
                    target = prev + max_distance
                adjusted.append(min(e, max(s, target)))
            picks = adjusted

        # Gather the selected frames along dim 0, with the index tensor on the input's device.
        idx = torch.tensor(picks, dtype=torch.long, device=input.device)
        chosen = input.index_select(0, idx)

        return io.NodeOutput(chosen)

class ImageBatchExtendWithOverlap:

    RETURN_TYPES = ("IMAGE", "IMAGE", "IMAGE", )
    RETURN_NAMES = ("source_images", "start_images", "extended_images")
    OUTPUT_TOOLTIPS = (
        "The original source images (passthrough)",
        "The input images used as the starting point for extension",
        "The extended images with overlap, if no new images are provided this will be empty",
    )
    FUNCTION = "imagesfrombatch"
    CATEGORY = "KJNodes/image"
    DESCRIPTION = """
Helper node for video generation extension   
First input source and overlap amount to get the starting frames for the extension.  
Then on another copy of the node provide the newly generated frames and choose how to overlap them.
"""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "source_images": ("IMAGE", {"tooltip": "The source images to extend"}),
                "overlap": ("INT", {"default": 13,"min": 1, "max": 4096, "step": 1, "tooltip": "Number of overlapping frames between source and new images"}),
                "overlap_side": (["source", "new_images"], {"default": "source", "tooltip": "Which side to overlap on"}),
                "overlap_mode": (["cut", "linear_blend", "ease_in_out", "filmic_crossfade", "perceptual_crossfade"], {"default": "linear_blend", "tooltip": "Method to use for overlapping frames"}),
            },
            "optional": {
                "new_images": ("IMAGE", {"tooltip": "The new images to extend with"}),
            }
        }

    def imagesfrombatch(self, source_images, overlap, overlap_side, overlap_mode, new_images=None):
        """Return the overlap seed frames and, when new frames are given, the joined clip.

        Args:
            source_images: Existing clip, laid out as [N, H, W, C].
            overlap: Number of frames shared between the end of the source and the start of
                the new clip.
            overlap_side: "source" or "new_images"; selects which side's overlapping frames are
                dropped ("cut") or which side the blend starts from.
            overlap_mode: "cut", "linear_blend", "ease_in_out", "filmic_crossfade" or
                "perceptual_crossfade".
            new_images: Newly generated frames with the same H and W as the source.

        Returns:
            (source_images, start_images, extended_images). ``start_images`` are the last
            ``overlap`` source frames. ``extended_images`` is a 1x64x64x3 zero placeholder
            when ``new_images`` is None. If ``overlap`` exceeds the source length, the source
            batch is returned in all three slots.

        Raises:
            ValueError: on mismatched frame sizes, an unknown side/mode, or too few new frames
                to blend.
        """
        if overlap > len(source_images):
            return source_images, source_images, source_images

        if new_images is not None:
            if source_images.shape[1:3] != new_images.shape[1:3]:
                raise ValueError(f"Source and new images must have the same shape: {source_images.shape[1:3]} vs {new_images.shape[1:3]}")
            extended_images = self._extend(source_images, new_images, overlap, overlap_side, overlap_mode)
        else:
            # Placeholder so the output socket always carries a valid IMAGE tensor.
            extended_images = torch.zeros((1, 64, 64, 3), device="cpu")

        start_images = source_images[-overlap:]

        return (source_images, start_images, extended_images)

    @staticmethod
    def _extend(source_images, new_images, overlap, overlap_side, overlap_mode):
        """Join source and new frames, resolving their ``overlap`` shared frames per overlap_mode."""
        if overlap_side not in ("source", "new_images"):
            raise ValueError(f"Unknown overlap_side: {overlap_side}")

        if overlap_mode == "cut":
            # Drop the overlapping frames from the side named by overlap_side.
            if overlap_side == "source":
                return torch.cat((source_images[:-overlap], new_images), dim=0)
            return torch.cat((source_images, new_images[overlap:]), dim=0)

        if len(new_images) < overlap:
            # Otherwise the blend would silently broadcast a single new frame, or fail with a
            # cryptic shape error.
            raise ValueError(f"new_images must contain at least {overlap} frames to blend the overlap, got {len(new_images)}")

        if overlap_side == "source":
            blend_src, blend_dst = source_images[-overlap:], new_images[:overlap]
        else:
            blend_src, blend_dst = new_images[:overlap], source_images[-overlap:]

        # Interior points of a 0..1 ramp: the pure-src / pure-dst endpoints are excluded, because
        # the neighbouring prefix and suffix frames already provide them. Shape [overlap, 1, 1, 1].
        t = torch.linspace(0, 1, overlap + 2, device=blend_src.device, dtype=blend_src.dtype)[1:-1].view(-1, 1, 1, 1)

        if overlap_mode == "linear_blend":
            blended_images = (1 - t) * blend_src + t * blend_dst
        elif overlap_mode == "ease_in_out":
            eased_t = 3 * t * t - 2 * t * t * t  # smoothstep ease-in-out
            blended_images = (1 - eased_t) * blend_src + eased_t * blend_dst
        elif overlap_mode == "filmic_crossfade":
            # Blend in (approximately) linear light: undo gamma 2.2, mix, reapply.
            gamma = 2.2
            blended = (1 - t) * torch.pow(blend_src, gamma) + t * torch.pow(blend_dst, gamma)
            blended_images = torch.pow(blended, 1.0 / gamma)
        elif overlap_mode == "perceptual_crossfade":
            import kornia
            # kornia colour conversions expect channels-first [N, C, H, W].
            lab_src = kornia.color.rgb_to_lab(blend_src.movedim(-1, 1))
            lab_dst = kornia.color.rgb_to_lab(blend_dst.movedim(-1, 1))
            blended_lab = (1 - t) * lab_src + t * lab_dst
            blended_images = kornia.color.lab_to_rgb(blended_lab).movedim(1, -1)
        else:
            raise ValueError(f"Unknown overlap_mode: {overlap_mode}")

        return torch.cat((source_images[:-overlap], blended_images, new_images[overlap:]), dim=0)

class GetLatentRangeFromBatch:
    
    RETURN_TYPES = ("LATENT", )
    FUNCTION = "latentsfrombatch"
    CATEGORY = "KJNodes/latents"
    DESCRIPTION = """
Returns a range of latents from a batch.
"""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "latents": ("LATENT",),
                "start_index": ("INT", {"default": 0,"min": -1, "max": 4096, "step": 1}),
                "num_frames": ("INT", {"default": 1,"min": -1, "max": 4096, "step": 1}),
            },
        }

    def latentsfrombatch(self, latents, start_index, num_frames):
        """Slice a contiguous range of latent frames.

        4-D samples are treated as [B, C, H, W] and sliced along B. 5-D samples are treated as
        [B, C, T, H, W] and sliced along T.

        Args:
            latents: Latent dict with a "samples" tensor.
            start_index: First frame; -1 means "the last ``num_frames`` frames" (the whole
                range when ``num_frames`` is also -1).
            num_frames: Number of frames to take, truncated at the end; -1 takes everything
                from ``start_index`` on.

        Returns:
            A 1-tuple with a new latent dict holding only contiguous "samples".

        Raises:
            ValueError: for unsupported ranks or an out-of-range start index.
        """
        samples = latents["samples"]
        if samples.ndim == 4:
            num_latents = samples.shape[0]
        elif samples.ndim == 5:
            num_latents = samples.shape[2]
        else:
            raise ValueError(f"Unsupported latent shape {tuple(samples.shape)}; expected 4 or 5 dimensions")

        if start_index == -1:
            # "Last num_frames frames"; with num_frames == -1 ("all") that is the whole range.
            start_index = 0 if num_frames == -1 else max(0, num_latents - num_frames)
        if start_index < 0 or start_index >= num_latents:
            raise ValueError("Start index is out of range")

        end_index = num_latents if num_frames == -1 else min(start_index + num_frames, num_latents)

        if samples.ndim == 4:
            chosen_latents = samples[start_index:end_index]
        else:
            chosen_latents = samples[:, :, start_index:end_index]

        return ({"samples": chosen_latents.contiguous(),},)
    
class InsertLatentToIndex:
    
    RETURN_TYPES = ("LATENT", )
    FUNCTION = "insert"
    CATEGORY = "KJNodes/latents"
    DESCRIPTION = """
Inserts a latent at the specified index into the original latent batch.
"""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "source": ("LATENT",),
                "destination": ("LATENT",),
                "index": ("INT", {"default": 0,"min": -1, "max": 4096, "step": 1}),
            },
        }

    def insert(self, source, destination, index):
        """Splice the source latents into the destination in place of the entry at ``index``.

        The destination entry at ``index`` is dropped and the whole source batch takes its
        place, so a multi-frame source lengthens the result. 4-D samples ([B, C, H, W]) are
        joined along dim 0. 5-D samples ([B, C, T, H, W]) are joined along T (dim 2).

        Args:
            source: Latent dict whose "samples" are inserted; cast to the destination's
                dtype/device first.
            destination: Latent dict whose "samples" receive the insertion.
            index: Destination position to replace; negative values count from the end.

        Returns:
            A 1-tuple with a new latent dict holding only "samples".

        Raises:
            ValueError: for unsupported or mismatched ranks, or an index that does not address
                an existing destination entry.
        """
        samples_destination = destination["samples"]
        samples_source = source["samples"].to(samples_destination)

        if samples_destination.ndim == 4:
            frame_dim = 0
        elif samples_destination.ndim == 5:
            frame_dim = 2
        else:
            raise ValueError(f"Unsupported latent shape {tuple(samples_destination.shape)}; expected 4 or 5 dimensions")
        if samples_source.ndim != samples_destination.ndim:
            raise ValueError(f"Source latents ({samples_source.ndim}D) and destination latents ({samples_destination.ndim}D) must have the same rank")

        # Bounds are checked against the tensor being indexed (the destination), not the source.
        num_latents = samples_destination.shape[frame_dim]
        resolved = index + num_latents if index < 0 else index
        if resolved >= num_latents or resolved < 0:
            raise ValueError(f"Index {index} out of bounds for tensor with {num_latents} latents")

        if frame_dim == 0:
            joined_latents = torch.cat([
                samples_destination[:resolved],
                samples_source,
                samples_destination[resolved + 1:]
            ], dim=0)
        else:
            joined_latents = torch.cat([
                samples_destination[:, :, :resolved],
                samples_source,
                samples_destination[:, :, resolved + 1:]
            ], dim=2)

        return ({"samples": joined_latents,},)

class ImageBatchFilter:
    
    RETURN_TYPES = ("IMAGE", "STRING",)
    RETURN_NAMES = ("images", "removed_indices",)
    FUNCTION = "filter"
    CATEGORY = "KJNodes/image"
    DESCRIPTION = "Removes empty images from a batch"

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "images": ("IMAGE",),
                "empty_color": ("STRING", {"default": "0, 0, 0"}),
                "empty_threshold": ("FLOAT", {"default": 0.01,"min": 0.0, "max": 1.0, "step": 0.01}),
            },
            "optional": {
                "replacement_image": ("IMAGE",),
            }
        }

    def filter(self, images, empty_color, empty_threshold, replacement_image=None):
        """Detect frames that are (nearly) a solid colour, then drop or replace them.

        A frame counts as empty when the mean absolute difference between its pixels and
        ``empty_color`` is at most ``empty_threshold``.

        Args:
            images: Image batch, unpacked as [B, H, W, C].
            empty_color: Comma-separated colour components, e.g. "0, 0, 0" or "0.5, 0.5, 0.5".
                If any component exceeds 1, all components are read on a 0-255 scale and
                divided by 255 (assumes IMAGE tensors hold 0-1 values).
            empty_threshold: Maximum mean difference for a frame to count as empty.
            replacement_image: Optional batch whose first frame replaces every empty frame.
                It is passed through common_upscale to W x H when its height, width or channel
                count differs. Without it, empty frames are removed.

        Returns:
            (images, removed_indices), where removed_indices is a ", "-joined string of the
            batch indices detected as empty.
        """
        B, H, W, C = images.shape

        input_images = images.clone()

        # Parse as floats so fractional colours work (int("0.5") used to raise).
        empty_color_list = [float(color.strip()) for color in empty_color.split(',')]
        if any(value > 1.0 for value in empty_color_list):
            # Components above 1 only make sense on a 0-255 scale.
            empty_color_list = [value / 255.0 for value in empty_color_list]
        empty_color_tensor = torch.tensor(empty_color_list, dtype=torch.float32, device=input_images.device)

        color_diff = torch.abs(input_images - empty_color_tensor)
        mean_diff = color_diff.mean(dim=(1, 2, 3))

        empty_indices = mean_diff <= empty_threshold
        # One device->host transfer instead of indexing the tensor element by element.
        empty_indices_string = ', '.join(str(i) for i in torch.nonzero(empty_indices).flatten().tolist())

        if replacement_image is not None:
            B_rep, H_rep, W_rep, C_rep = replacement_image.shape
            replacement = replacement_image.clone()
            if (H_rep != H) or (W_rep != W) or (C_rep != C):
                replacement = common_upscale(replacement.movedim(-1, 1), W, H, "lanczos", "center").movedim(1, -1)
            input_images[empty_indices] = replacement[0]

            return (input_images, empty_indices_string,)

        non_empty_images = input_images[~empty_indices]
        return (non_empty_images, empty_indices_string,)
    
class GetImagesFromBatchIndexed:
    
    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "indexedimagesfrombatch"
    CATEGORY = "KJNodes/image"
    DESCRIPTION = """
Selects and returns the images at the specified indices as an image batch.
"""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "images": ("IMAGE",),
                "indexes": ("STRING", {"default": "0, 1, 2", "multiline": True}),
            },
        }

    def indexedimagesfrombatch(self, images, indexes):
        """Gather the frames at the listed indices, in the listed order.

        Args:
            images: Image batch (frames along dim 0).
            indexes: Integers separated by commas and/or whitespace (the widget is multiline).
                Empty entries such as a trailing comma are ignored; negative values count from
                the end (torch indexing).

        Returns:
            A 1-tuple with the selected frames.

        Raises:
            ValueError: if no index is given or an entry is not an integer.
        """
        # Split on commas and any whitespace (including newlines); drop empty tokens.
        index_list = [int(token) for token in re.split(r"[,\s]+", indexes) if token]
        if not index_list:
            raise ValueError("No indexes provided")

        indices_tensor = torch.tensor(index_list, dtype=torch.long)
        chosen_images = images[indices_tensor]

        return (chosen_images,)

class InsertImagesToBatchIndexed:

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "insertimagesfrombatch"
    CATEGORY = "KJNodes/image"
    DESCRIPTION = """
Inserts images at the specified indices into the original image batch.
"""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "original_images": ("IMAGE",),
                "images_to_insert": ("IMAGE",),
                "indexes": ("STRING", {"default": "0, 1, 2", "multiline": True}),
            },
            "optional": {
                "mode": (["replace", "insert"],),
            }
        }

    def insertimagesfrombatch(self, original_images, images_to_insert, indexes, mode="replace"):
        """Replace or insert frames of ``original_images`` at the given indices.

        Args:
            original_images: batch tensor; it is never modified in place.
            images_to_insert: frames to place; converted with ``torch.tensor`` if
                it is not already a tensor.
            indexes: integers separated by commas and/or newlines. Blank entries
                are ignored; if no index remains, ``original_images`` is returned as-is.
            mode: "replace" overwrites frame ``indexes[k]`` with
                ``images_to_insert[k]``. Pairing stops at the shorter of the two
                lists. Any other value inserts ``images_to_insert[k % len]`` so that it
                lands at output position ``indexes[k]``. Indices may be unsorted or
                repeated; positions past the end append, negatives insert at the front.

        Returns:
            1-tuple with the resulting batch.
        """
        index_list = [int(token) for token in re.split(r"[,\n]", indexes) if token.strip()]
        if not index_list:
            return (original_images,)

        input_images = original_images.clone()

        # Ensure the images_to_insert is a tensor
        if not isinstance(images_to_insert, torch.Tensor):
            images_to_insert = torch.tensor(images_to_insert)

        if mode == "replace":
            for index, image in zip(index_list, images_to_insert):
                input_images[index] = image
            return (input_images,)

        num_insert = len(images_to_insert)
        num_original = len(input_images)
        # Remember which insert image each target was listed with, then visit the
        # targets in ascending output position so unsorted input is handled.
        targets = sorted((target, k) for k, target in enumerate(index_list))

        new_images = []
        source_pos = 0
        for target, k in targets:
            # Copy original frames until the output reaches the target position
            # (or the originals run out, in which case the insert is appended).
            while len(new_images) < target and source_pos < num_original:
                new_images.append(input_images[source_pos])
                source_pos += 1
            # Use modulo to cycle through images_to_insert
            new_images.append(images_to_insert[k % num_insert])
        new_images.extend(input_images[source_pos:])

        return (torch.stack(new_images, dim=0),)

class PadImageBatchInterleaved:

    RETURN_TYPES = ("IMAGE", "MASK",)
    RETURN_NAMES = ("images", "masks",)
    FUNCTION = "pad"
    CATEGORY = "KJNodes/image"
    DESCRIPTION = """
Inserts empty frames between the images in a batch.
"""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "images": ("IMAGE",),
                "empty_frames_per_image": ("INT", {"default": 1, "min": 0, "max": 4096, "step": 1}),
                "pad_frame_value": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}),
                "add_after_last": ("BOOLEAN", {"default": False}),
            },
        }

    def pad(self, images, empty_frames_per_image, pad_frame_value, add_after_last):
        """Interleave runs of filler frames between the frames of ``images``.

        Each gap between consecutive input frames gets ``empty_frames_per_image``
        frames filled with ``pad_frame_value``. When ``add_after_last`` is set,
        one more run is appended after the final input frame.

        Returns:
            (padded_batch, mask). The mask is 1.0 on frames copied from ``images``
            and 0.0 on the filler frames.
        """
        count, height, width, channels = images.shape
        stride = empty_frames_per_image + 1
        trailing = empty_frames_per_image if add_after_last else 0
        # originals + one gap between each consecutive pair + optional tail
        total_frames = count + (count - 1) * empty_frames_per_image + trailing

        padded_batch = torch.full(
            (total_frames, height, width, channels),
            pad_frame_value,
            dtype=images.dtype,
            device=images.device,
        )
        mask = torch.zeros((total_frames, height, width), dtype=images.dtype, device=images.device)

        # Input frame k lands at k * stride; everything else stays filler.
        source_positions = torch.arange(count, device=images.device) * stride
        padded_batch[source_positions] = images
        mask[source_positions] = 1.0

        return (padded_batch, mask)

class ReplaceImagesInBatch:

    RETURN_TYPES = ("IMAGE", "MASK",)
    FUNCTION = "replace"
    CATEGORY = "KJNodes/image"
    DESCRIPTION = """
Replaces the images in a batch, starting from the specified start index,  
with the replacement images.
"""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "start_index": ("INT", {"default": 1, "min": 0, "max": 4096, "step": 1}),
            },
            "optional": {
                "original_images": ("IMAGE",),
                "replacement_images": ("IMAGE",),
                "original_masks": ("MASK",),
                "replacement_masks": ("MASK",),
            }
        }

    def replace(self, original_images=None, replacement_images=None, start_index=1, original_masks=None, replacement_masks=None):
        """Overwrite a contiguous run of frames (and/or masks) starting at ``start_index``.

        Images and masks are handled independently. Each pair is processed only when
        both its original and replacement are given; otherwise a zero placeholder is
        returned for that output ((1, 64, 64, 3) image / (1, 64, 64) mask).
        Replacements whose height/width differ from the originals are resized first:
        lanczos with center crop for images, nearest-exact for masks.

        Returns:
            (images, masks). The inputs are cloned, never modified in place.

        Raises:
            ValueError: if ``start_index`` or the end of the replaced run falls
                outside the original batch.
        """
        if original_images is not None and replacement_images is not None:
            if start_index >= len(original_images):
                raise ValueError("ReplaceImagesInBatch: Start index is out of range")
            end_index = start_index + len(replacement_images)
            if end_index > len(original_images):
                raise ValueError("ReplaceImagesInBatch: End index is out of range")

            images = original_images.clone()
            target_h, target_w = images.shape[1], images.shape[2]
            if tuple(replacement_images.shape[1:3]) != (target_h, target_w):
                # common_upscale takes (width, height) and channels-first samples
                replacement_images = common_upscale(replacement_images.movedim(-1, 1), target_w, target_h, "lanczos", "center").movedim(1, -1)

            images[start_index:end_index] = replacement_images
        else:
            images = torch.zeros((1, 64, 64, 3))

        if original_masks is not None and replacement_masks is not None:
            if start_index >= len(original_masks):
                raise ValueError("ReplaceImagesInBatch: Start index is out of range")
            end_index = start_index + len(replacement_masks)
            if end_index > len(original_masks):
                raise ValueError("ReplaceImagesInBatch: End index is out of range")

            masks = original_masks.clone()
            target_h, target_w = masks.shape[1], masks.shape[2]
            if tuple(replacement_masks.shape[1:3]) != (target_h, target_w):
                # add a channel dim for common_upscale, then drop that same dim (1)
                replacement_masks = common_upscale(replacement_masks.unsqueeze(1), target_w, target_h, "nearest-exact", "center").squeeze(1)

            masks[start_index:end_index] = replacement_masks
        else:
            masks = torch.zeros((1, 64, 64))

        return (images, masks)
    

class ReverseImageBatch:

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "reverseimagebatch"
    CATEGORY = "KJNodes/image"
    DESCRIPTION = """
Reverses the order of the images in a batch.
"""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "images": ("IMAGE",),
            },
        }

    def reverseimagebatch(self, images):
        """Return a new batch whose frames are in reverse order (last frame first)."""
        flipped = images.flip(dims=(0,))
        return (flipped,)

class ImageBatchMulti:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "inputcount": ("INT", {"default": 2, "min": 2, "max": 1000, "step": 1}),
                "image_1": ("IMAGE", ),
            },
            "optional": {
                "image_2": ("IMAGE", ),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("images",)
    FUNCTION = "combine"
    CATEGORY = "KJNodes/image"
    DESCRIPTION = """
Creates an image batch from multiple images.  
You can set how many inputs the node has,  
with the **inputcount** and clicking update.
"""

    def combine(self, inputcount, **kwargs):
        """Concatenate image_1..image_<inputcount> into one batch along dim 0.

        Every input is resized (bilinear, center crop) to image_1's height/width.
        Inputs with fewer channels than the widest input are padded with 1.0-valued
        channels. An input that is absent or None contributes
        ``image_1.shape[0]`` all-zero frames.

        Returns:
            1-tuple with the combined batch on the CPU, in image_1's dtype.
        """
        first = kwargs["image_1"]
        h, w = first.shape[1], first.shape[2]
        placeholder_frames = first.shape[0]
        inputs = [first] + [kwargs.get(f"image_{c + 1}") for c in range(1, inputcount)]

        # determine output shape
        max_ch = max(img.shape[-1] for img in inputs if img is not None)
        total_frames = sum(placeholder_frames if img is None else img.shape[0] for img in inputs)

        # pre-allocate output
        out = torch.empty((total_frames, h, w, max_ch), dtype=first.dtype)
        offset = 0

        for img in inputs:
            if img is None:
                # Only allocated for inputs that are actually missing.
                img = torch.zeros((placeholder_frames, h, w, max_ch), dtype=first.dtype)

            if img.shape[1:3] != (h, w):
                img = common_upscale(img.movedim(-1, 1), w, h, "bilinear", "center").movedim(1, -1)

            if img.shape[-1] < max_ch:
                # presumably an opaque alpha channel when mixing RGB with RGBA inputs
                img = torch.nn.functional.pad(img, (0, max_ch - img.shape[-1]), mode='constant', value=1.0)

            n = img.shape[0]
            out[offset:offset + n].copy_(img, non_blocking=True)
            offset += n

        return (out.cpu(),)


class ImageTensorList:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "image1": ("IMAGE",),
            "image2": ("IMAGE",),
        }}

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "append"
    CATEGORY = "KJNodes/image"
    DESCRIPTION = """
Creates an image list from the input images.
"""

    def append(self, image1, image2):
        """Join two inputs, each a tensor or a list of tensors, into one flat list.

        A single tensor counts as a one-element list. If either input is neither a
        tensor nor a list, the result is an empty list.

        Returns:
            1-tuple holding the new list.
        """
        def as_list(item):
            if isinstance(item, torch.Tensor):
                return [item]
            if isinstance(item, list):
                return item
            return None

        head, tail = as_list(image1), as_list(image2)
        if head is None or tail is None:
            return [],
        return head + tail,

class ImageAddMulti:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "inputcount": ("INT", {"default": 2, "min": 2, "max": 1000, "step": 1}),
                "image_1": ("IMAGE", ),
                "image_2": ("IMAGE", ),
                "blending": (
                    ['add',
                     'subtract',
                     'multiply',
                     'difference',
                     ],
                    {
                        "default": 'add'
                    }),
                "blend_amount": ("FLOAT", {"default": 0.5, "min": 0, "max": 1, "step": 0.01}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("images",)
    FUNCTION = "add"
    CATEGORY = "KJNodes/image"
    DESCRIPTION = """
Add blends multiple images together.    
You can set how many inputs the node has,  
with the **inputcount** and clicking update.
"""

    def add(self, inputcount, blending, blend_amount, **kwargs):
        """Fold image_2..image_<inputcount> into image_1 with the chosen blend mode.

        add / subtract / multiply scale both operands by ``blend_amount`` before
        combining them. difference returns the absolute difference |a - b| and
        ignores ``blend_amount``. Results are not clamped.

        Returns:
            1-tuple with the blended image tensor.

        Raises:
            KeyError: if one of image_1..image_<inputcount> is not supplied.
        """
        image = kwargs["image_1"]
        for c in range(1, inputcount):
            new_image = kwargs[f"image_{c + 1}"]
            if blending == "add":
                image = torch.add(image * blend_amount, new_image * blend_amount)
            elif blending == "subtract":
                image = torch.sub(image * blend_amount, new_image * blend_amount)
            elif blending == "multiply":
                image = torch.mul(image * blend_amount, new_image * blend_amount)
            elif blending == "difference":
                # A difference blend is |a - b|; a plain subtraction would yield
                # negative pixel values.
                image = torch.abs(torch.sub(image, new_image))
        return (image,)


class ImageConcatMulti(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        # The output type follows image_1. image_2 (and the image_3+ slots added by
        # the JS frontend) may each independently be an IMAGE or a MASK.
        type_template = io.MatchType.Template("multi_image_or_mask", allowed_types=[io.Image, io.Mask])
        node_inputs = [
            io.Int.Input("inputcount", default=2, min=2, max=1000, step=1),
            io.MatchType.Input("image_1", template=type_template),
            io.Combo.Input("direction", options=['right', 'down', 'left', 'up'], default='right'),
            io.Boolean.Input("match_image_size", default=False),
            io.MultiType.Input("image_2", types=[io.Image, io.Mask], optional=True),
        ]
        node_outputs = [
            io.MatchType.Output(template=type_template, display_name="output"),
        ]
        return io.Schema(
            node_id="ImageConcatMulti",
            display_name="Image Concatenate Multi",
            category="KJNodes/image",
            description=(
                "Creates an image from multiple images or masks.\n"
                "Set the input count and click 'Update inputs' to add more slots.\n"
                "The output type follows image_1; other inputs are converted to match."
            ),
            # extra image_N slots beyond the declared inputs arrive via **kwargs
            accept_all_inputs=True,
            inputs=node_inputs,
            outputs=node_outputs,
        )

    @classmethod
    def execute(cls, inputcount, image_1, direction, match_image_size, image_2=None, **kwargs) -> io.NodeOutput:
        """Concatenate image_1..image_<inputcount> one after another along ``direction``.

        A slot that has no entry at all is filled with zeros shaped like image_1,
        created on the intermediate device/dtype. Each step delegates to
        ``ImageConcanate.concatenate`` with image_1's shape as the reference.
        """
        if image_2 is not None:
            kwargs["image_2"] = image_2

        reference_shape = image_1.shape
        fill_device = model_management.intermediate_device()
        fill_dtype = model_management.intermediate_dtype()

        result = image_1
        for slot in range(2, inputcount + 1):
            slot_name = f"image_{slot}"
            if slot_name in kwargs:
                next_image = kwargs[slot_name]
            else:
                next_image = torch.zeros(reference_shape, dtype=fill_dtype, device=fill_device)
            result = ImageConcanate.concatenate(result, next_image, direction, match_image_size, first_image_shape=reference_shape)
        return io.NodeOutput(result)

class PreviewAnimation:
    def __init__(self):
        self.output_dir = folder_paths.get_temp_directory()
        self.type = "temp"
        # Random 5-letter suffix (not used by preview() itself).
        self.prefix_append = "_temp_" + ''.join(random.choice("abcdefghijklmnopqrstuvwxyz") for x in range(5))
        self.compress_level = 1

    methods = {"default": 4, "fastest": 0, "slowest": 6}
    @classmethod
    def INPUT_TYPES(s):
        return {"required":
                    {
                     "fps": ("FLOAT", {"default": 8.0, "min": 0.01, "max": 1000.0, "step": 0.01}),
                     },
                "optional": {
                    "images": ("IMAGE", ),
                    "masks": ("MASK", ),
                },
            }

    RETURN_TYPES = ()
    FUNCTION = "preview"
    OUTPUT_NODE = True
    CATEGORY = "KJNodes/image"

    @staticmethod
    def _image_to_pil(image):
        """Convert one float frame (scaled by 255, clipped to 0..255) to an 8-bit PIL image."""
        array = np.clip(255. * image.cpu().numpy(), 0, 255).astype(np.uint8)
        return Image.fromarray(array)

    @staticmethod
    def _mask_to_pil(mask):
        """Convert one 2D float mask (scaled by 255, clipped) to an 8-bit grayscale PIL image."""
        array = np.clip(255. * mask.cpu().numpy(), 0, 255).astype(np.uint8)
        return Image.fromarray(array)

    @staticmethod
    def _overlay_mask(img, mask_img):
        """Paint white over ``img`` using ``mask_img`` as opacity; returns an RGBA image."""
        base = img.convert("RGBA")
        overlay = Image.new("RGBA", base.size, (255, 255, 255, 255))
        overlay.putalpha(mask_img)
        return Image.alpha_composite(base, overlay)

    def preview(self, fps, images=None, masks=None):
        """Write the frames as one animated WebP in the temp dir and return UI info.

        With both inputs, frame i gets mask i overlaid in white. Frames beyond the
        mask count are kept unchanged and surplus masks are ignored. With masks only,
        the masks themselves are previewed as grayscale frames.
        """
        filename_prefix = "AnimPreview"
        full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
        results = list()

        if images is None and masks is None:
            logging.warning("PreviewAnimation: No images or masks provided")
            return { "ui": { "images": results, "animated": (None,), "text": "empty" }}

        if images is not None:
            pil_images = [self._image_to_pil(image) for image in images]
            if masks is not None:
                # Pair by index so frame order is preserved whatever the counts are.
                for i, mask in zip(range(len(pil_images)), masks):
                    pil_images[i] = self._overlay_mask(pil_images[i], self._mask_to_pil(mask))
        else:
            pil_images = [self._mask_to_pil(mask) for mask in masks]

        if not pil_images:
            logging.warning("PreviewAnimation: received an empty batch")
            return { "ui": { "images": results, "animated": (None,), "text": "empty" }}

        num_frames = len(pil_images)
        file = f"{filename}_{counter:05}_.webp"
        pil_images[0].save(
            os.path.join(full_output_folder, file),
            save_all=True,
            duration=int(1000.0/fps),
            append_images=pil_images[1:],
            lossless=False,
            quality=50,
            method=0,
        )
        results.append({
            "filename": file,
            "subfolder": subfolder,
            "type": self.type
        })

        animated = num_frames != 1
        return { "ui": { "images": results, "animated": (animated,), "text": [f"{num_frames}x{pil_images[0].size[0]}x{pil_images[0].size[1]}"] } }
    
class ImageResizeKJ:
    upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",),
                "width": ("INT", { "default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 1, }),
                "height": ("INT", { "default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 1, }),
                "upscale_method": (s.upscale_methods,),
                "keep_proportion": ("BOOLEAN", { "default": False }),
                "divisible_by": ("INT", { "default": 2, "min": 0, "max": 512, "step": 1, }),
            },
            "optional" : {
                #"width_input": ("INT", { "forceInput": True}),
                #"height_input": ("INT", { "forceInput": True}),
                "get_image_size": ("IMAGE",),
                "crop": (["disabled","center", 0], { "tooltip": "0 will do the default center crop, this is a workaround for the widget order changing with the new frontend, as in old workflows the value of this widget becomes 0 automatically" }),
            }
        }

    RETURN_TYPES = ("IMAGE", "INT", "INT",)
    RETURN_NAMES = ("IMAGE", "width", "height",)
    FUNCTION = "resize"
    CATEGORY = "KJNodes/image"
    DEPRECATED = True
    DESCRIPTION = """
DEPRECATED!

Due to ComfyUI frontend changes, this node should no longer be used, please check the   
v2 of the node. This node is only kept to not completely break older workflows.  

"""

    def resize(self, image, width, height, keep_proportion, upscale_method, divisible_by, 
               width_input=None, height_input=None, get_image_size=None, crop="disabled"):
        """Resize an image batch (deprecated; kept for old workflows).

        The target size comes from width/height, the optional width_input/
        height_input, or the shape of ``get_image_size``. A 0 dimension means "derive
        it": from the aspect ratio when ``keep_proportion`` is set, otherwise the
        source size is used. If both are 0 the source size is kept in either mode.
        Without ``get_image_size``, the size is rounded down to a multiple of
        ``divisible_by`` (when > 1).

        Returns:
            (image, width, height) of the resized batch.
        """
        B, H, W, C = image.shape

        if width_input:
            width = width_input
        if height_input:
            height = height_input
        if get_image_size is not None:
            _, height, width, _ = get_image_size.shape

        if keep_proportion and get_image_size is None:
            if width == 0 and height == 0:
                # Nothing requested: keep the source size instead of resizing to 0x0.
                width, height = W, H
            elif width == 0:
                ratio = height / H
                width = round(W * ratio)
            elif height == 0:
                ratio = width / W
                height = round(H * ratio)
            else:
                # Fit inside the requested box: scale by the tighter of the two ratios
                ratio = min(width / W, height / H)
                width = round(W * ratio)
                height = round(H * ratio)
        else:
            if width == 0:
                width = W
            if height == 0:
                height = H

        if divisible_by > 1 and get_image_size is None:
            width = width - (width % divisible_by)
            height = height - (height % divisible_by)

        if crop == 0:  # workaround for old workflows
            crop = "center"

        image = image.movedim(-1, 1)
        image = common_upscale(image, width, height, upscale_method, crop)
        image = image.movedim(1, -1)

        return (image, image.shape[2], image.shape[1],)

class ImageResizeKJv2:
    upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos", "nvidia_rtx_vsr"]
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",),
                "width": ("INT", { "default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 1, }),
                "height": ("INT", { "default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 1, }),
                "upscale_method": (s.upscale_methods,),
                "keep_proportion": (["stretch", "resize", "pad", "pad_edge", "pad_edge_pixel", "crop", "pillarbox_blur", "total_pixels"], { "default": False }),
                "pad_color": ("STRING", { "default": "0, 0, 0", "tooltip": "Color to use for padding."}),
                "crop_position": (["center", "top", "bottom", "left", "right"], { "default": "center" }),
                "divisible_by": ("INT", { "default": 2, "min": 0, "max": 512, "step": 1, }),
            },
            "optional" : {
                "mask": ("MASK",),
                "device": (["cpu", "gpu"],),
                #"per_batch": ("INT", { "default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1, "tooltip": "Process images in sub-batches to reduce memory usage. 0 disables sub-batching."}),
            },
            "hidden": {
                "unique_id": "UNIQUE_ID",
            },
        }

    RETURN_TYPES = ("IMAGE", "INT", "INT", "MASK",)
    RETURN_NAMES = ("IMAGE", "width", "height", "mask",)
    FUNCTION = "resize"
    CATEGORY = "KJNodes/image"
    DESCRIPTION = """
Resizes the image to the specified width and height.  
Size can be retrieved from the input.

Keep proportions keeps the aspect ratio of the image, by  
highest dimension.  
"""

    def resize(self, image, width, height, keep_proportion, upscale_method, divisible_by, pad_color, crop_position, unique_id, device="cpu", mask=None, per_batch=64):
        """Resize an image batch (and optional mask), optionally cropping or padding.

        keep_proportion:
            "stretch"       resize straight to width x height (0 keeps the source size).
            "resize"        fit inside width x height, keeping the aspect ratio.
            "pad*" / "pillarbox_blur"
                            fit, then pad out toward width x height via ImagePadKJ.pad.
            "crop"          crop to the target aspect ratio at crop_position, then resize.
            "total_pixels"  keep the aspect ratio at roughly width*height pixels.
        Width/height are rounded down to a multiple of ``divisible_by`` when it is > 1.
        Batches larger than ``per_batch`` are processed in chunks (0 disables chunking).

        Returns:
            (image, out_width, out_height, mask), all on the CPU. The mask is a
            64x64 zero placeholder when no mask was produced.

        Raises:
            ValueError: for lanczos on the GPU.
            ImportError: if "nvidia_rtx_vsr" is requested without the nvvfx library.
        """
        import contextlib

        B, H, W, C = image.shape

        # Treat ComfyUI's 64x64 placeholder mask as no mask
        if mask is not None and mask.shape[-2:] == (64, 64) and (H != 64 or W != 64):
            mask = None

        # Scale mask to match image dimensions if they differ
        if mask is not None and mask.shape[-2:] != (H, W):
            mask = common_upscale(mask.unsqueeze(1), W, H, "bilinear", crop="disabled").squeeze(1)

        if device == "gpu":
            if upscale_method == "lanczos":
                raise ValueError("Lanczos is not supported on the GPU")
            device = model_management.get_torch_device()
        else:
            device = torch.device("cpu")

        pillarbox_blur = keep_proportion == "pillarbox_blur"

        # Initialize padding variables
        pad_left = pad_right = pad_top = pad_bottom = 0

        if keep_proportion in ["resize", "total_pixels"] or keep_proportion.startswith("pad") or pillarbox_blur:
            if keep_proportion == "total_pixels":
                total_pixels = width * height
                aspect_ratio = W / H
                new_height = int(math.sqrt(total_pixels / aspect_ratio))
                new_width = int(math.sqrt(total_pixels * aspect_ratio))

            # If one of the dimensions is zero, calculate it to maintain the aspect ratio
            elif width == 0 and height == 0:
                new_width = W
                new_height = H
            elif width == 0 and height != 0:
                ratio = height / H
                new_width = round(W * ratio)
                new_height = height
            elif height == 0 and width != 0:
                ratio = width / W
                new_width = width
                new_height = round(H * ratio)
            elif width != 0 and height != 0:
                ratio = min(width / W, height / H)
                new_width = round(W * ratio)
                new_height = round(H * ratio)
            else:
                new_width = width
                new_height = height

            if keep_proportion.startswith("pad") or pillarbox_blur:
                # Calculate padding based on position
                if crop_position == "center":
                    pad_left = (width - new_width) // 2
                    pad_right = width - new_width - pad_left
                    pad_top = (height - new_height) // 2
                    pad_bottom = height - new_height - pad_top
                elif crop_position == "top":
                    pad_left = (width - new_width) // 2
                    pad_right = width - new_width - pad_left
                    pad_top = 0
                    pad_bottom = height - new_height
                elif crop_position == "bottom":
                    pad_left = (width - new_width) // 2
                    pad_right = width - new_width - pad_left
                    pad_top = height - new_height
                    pad_bottom = 0
                elif crop_position == "left":
                    pad_left = 0
                    pad_right = width - new_width
                    pad_top = (height - new_height) // 2
                    pad_bottom = height - new_height - pad_top
                elif crop_position == "right":
                    pad_left = width - new_width
                    pad_right = 0
                    pad_top = (height - new_height) // 2
                    pad_bottom = height - new_height - pad_top

            width = new_width
            height = new_height
        else:
            if width == 0:
                width = W
            if height == 0:
                height = H

        if divisible_by > 1:
            width = width - (width % divisible_by)
            height = height - (height % divisible_by)

        # Preflight estimate (log-only when batching is active)
        if per_batch != 0 and B > per_batch:
            try:
                bytes_per_elem = image.element_size()  # typically 4 for float32
                est_total_bytes = B * height * width * C * bytes_per_elem
                est_mb = est_total_bytes / (1024 * 1024)
                msg = f"<tr><td>Resize v2</td><td>estimated output ~{est_mb:.2f} MB; batching {per_batch}/{B}</td></tr>"
                if unique_id and PromptServer is not None:
                    try:
                        PromptServer.instance.send_progress_text(msg, unique_id)
                    except Exception:
                        pass
                else:
                    logging.info(f"[ImageResizeKJv2] estimated output ~{est_mb:.2f} MB; batching {per_batch}/{B}")
            except Exception:
                pass

        # Set inside the ExitStack below; read by _process_subbatch at call time.
        nvvfx_sr = None

        def _process_subbatch(in_image, in_mask, pad_left, pad_right, pad_top, pad_bottom):
            # Avoid unnecessary clones; only move if needed
            out_image = in_image if in_image.device == device else in_image.to(device)
            out_mask = None if in_mask is None else (in_mask if in_mask.device == device else in_mask.to(device))

            # Crop logic
            if keep_proportion == "crop":
                old_height = out_image.shape[-3]
                old_width = out_image.shape[-2]
                old_aspect = old_width / old_height
                new_aspect = width / height
                if old_aspect > new_aspect:
                    crop_w = round(old_height * new_aspect)
                    crop_h = old_height
                else:
                    crop_w = old_width
                    crop_h = round(old_width / new_aspect)
                if crop_position == "center":
                    x = (old_width - crop_w) // 2
                    y = (old_height - crop_h) // 2
                elif crop_position == "top":
                    x = (old_width - crop_w) // 2
                    y = 0
                elif crop_position == "bottom":
                    x = (old_width - crop_w) // 2
                    y = old_height - crop_h
                elif crop_position == "left":
                    x = 0
                    y = (old_height - crop_h) // 2
                elif crop_position == "right":
                    x = old_width - crop_w
                    y = (old_height - crop_h) // 2
                out_image = out_image.narrow(-2, x, crop_w).narrow(-3, y, crop_h)
                if out_mask is not None:
                    out_mask = out_mask.narrow(-1, x, crop_w).narrow(-2, y, crop_h)

            if upscale_method == "nvidia_rtx_vsr":
                # Process each frame through RTX Video Super Resolution
                frames_chw = out_image.movedim(-1, 1).cuda().contiguous()
                upscaled_frames = []
                for j in range(frames_chw.shape[0]):
                    dlpack_out = nvvfx_sr.run(frames_chw[j]).image
                    upscaled_frames.append(torch.from_dlpack(dlpack_out).clone())
                out_image = torch.stack(upscaled_frames, dim=0).movedim(1, -1).cpu()
                if out_mask is not None:
                    out_mask = common_upscale(out_mask.unsqueeze(1), width, height, "bilinear", crop="disabled").squeeze(1)
            else:
                out_image = common_upscale(out_image.movedim(-1,1), width, height, upscale_method, crop="disabled").movedim(1,-1)
                if out_mask is not None:
                    if upscale_method == "lanczos":
                        out_mask = common_upscale(out_mask.unsqueeze(1).repeat(1, 3, 1, 1), width, height, upscale_method, crop="disabled").movedim(1,-1)[:, :, :, 0]
                    else:
                        out_mask = common_upscale(out_mask.unsqueeze(1), width, height, upscale_method, crop="disabled").squeeze(1)

            # Pad logic
            if (keep_proportion.startswith("pad") or pillarbox_blur) and (pad_left > 0 or pad_right > 0 or pad_top > 0 or pad_bottom > 0):
                padded_width = width + pad_left + pad_right
                padded_height = height + pad_top + pad_bottom
                if divisible_by > 1:
                    width_remainder = padded_width % divisible_by
                    height_remainder = padded_height % divisible_by
                    if width_remainder > 0:
                        extra_width = divisible_by - width_remainder
                        pad_right += extra_width
                    if height_remainder > 0:
                        extra_height = divisible_by - height_remainder
                        pad_bottom += extra_height

                pad_mode = (
                    "pillarbox_blur" if pillarbox_blur else
                    "edge" if keep_proportion == "pad_edge" else
                    "edge_pixel" if keep_proportion == "pad_edge_pixel" else
                    "color"
                )
                out_image, out_mask = ImagePadKJ.pad(self, out_image, pad_left, pad_right, pad_top, pad_bottom, 0, pad_color, pad_mode, mask=out_mask)

            return out_image, out_mask

        # ExitStack exits the RTX VSR context even if load() or processing raises,
        # not only on the success path.
        with contextlib.ExitStack() as stack:
            # NVIDIA RTX Video Super Resolution setup
            if upscale_method == "nvidia_rtx_vsr":
                try:
                    import nvvfx
                except ImportError:
                    raise ImportError("NVIDIA RTX Video Super Resolution is not available. Please install the nvidia-vfx library and ensure you have a compatible NVIDIA GPU.")
                nvvfx_sr = stack.enter_context(nvvfx.VideoSuperRes(nvvfx.effects.QualityLevel.ULTRA))
                # Output dimensions are snapped to a multiple of 8 (minimum 8)
                nvvfx_sr.output_width = max(8, round(width / 8) * 8)
                nvvfx_sr.output_height = max(8, round(height / 8) * 8)
                nvvfx_sr.load()

            # If batching disabled (per_batch==0) or batch fits, process whole batch
            if per_batch == 0 or B <= per_batch:
                out_image, out_mask = _process_subbatch(image, mask, pad_left, pad_right, pad_top, pad_bottom)
            else:
                chunks = []
                mask_chunks = [] if mask is not None else None
                total_batches = (B + per_batch - 1) // per_batch
                for current_batch, start_idx in enumerate(range(0, B, per_batch), start=1):
                    end_idx = min(start_idx + per_batch, B)
                    sub_img = image[start_idx:end_idx]
                    sub_mask = mask[start_idx:end_idx] if mask is not None else None
                    sub_out_img, sub_out_mask = _process_subbatch(sub_img, sub_mask, pad_left, pad_right, pad_top, pad_bottom)
                    chunks.append(sub_out_img.cpu())
                    if mask is not None:
                        mask_chunks.append(sub_out_mask.cpu() if sub_out_mask is not None else None)
                    # Per-batch progress update
                    if unique_id and PromptServer is not None:
                        try:
                            PromptServer.instance.send_progress_text(
                                f"<tr><td>Resize v2</td><td>batch {current_batch}/{total_batches} · images {end_idx}/{B}</td></tr>",
                                unique_id
                            )
                        except Exception:
                            pass
                    else:
                        logging.info(f"[ImageResizeKJv2] batch {current_batch}/{total_batches} · images {end_idx}/{B}")
                out_image = torch.cat(chunks, dim=0)
                if mask is not None and any(m is not None for m in mask_chunks):
                    out_mask = torch.cat([m for m in mask_chunks if m is not None], dim=0)
                else:
                    out_mask = None

        # Progress UI
        if unique_id and PromptServer is not None:
            try:
                num_elements = out_image.numel()
                element_size = out_image.element_size()
                memory_size_mb = (num_elements * element_size) / (1024 * 1024)
                PromptServer.instance.send_progress_text(
                    f"<tr><td>Output: </td><td><b>{out_image.shape[0]}</b> x <b>{out_image.shape[2]}</b> x <b>{out_image.shape[1]} | {memory_size_mb:.2f}MB</b></td></tr>",
                    unique_id
                )
            except Exception:
                pass

        return (out_image.cpu(), out_image.shape[2], out_image.shape[1], out_mask.cpu() if out_mask is not None else torch.zeros(64,64, device=torch.device("cpu"), dtype=torch.float32))

class LoadAndResizeImage:
    # Selectable mask sources; the upper-cased first letter is matched against PIL band names.
    _color_channels = ["alpha", "red", "green", "blue"]
    @classmethod
    def INPUT_TYPES(s):
        input_dir = folder_paths.get_input_directory()
        files = [f.name for f in pathlib.Path(input_dir).iterdir() if f.is_file()]
        return {"required":
                    {
                    "image": (sorted(files), {"image_upload": True}),
                    "resize": ("BOOLEAN", { "default": False }),
                    "width": ("INT", { "default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 8, }),
                    "height": ("INT", { "default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 8, }),
                    "repeat": ("INT", { "default": 1, "min": 1, "max": 4096, "step": 1, }),
                    "keep_proportion": ("BOOLEAN", { "default": False }),
                    "divisible_by": ("INT", { "default": 2, "min": 0, "max": 512, "step": 1, }),
                    "mask_channel": (s._color_channels, {"tooltip": "Channel to use for the mask output"}), 
                    "background_color": ("STRING", { "default": "", "tooltip": "Fills the alpha channel with the specified color."}),
                    },
                }

    CATEGORY = "KJNodes/image"
    RETURN_TYPES = ("IMAGE", "MASK", "INT", "INT", "STRING",)
    RETURN_NAMES = ("image", "mask", "width", "height","image_path",)
    FUNCTION = "load_image"

    @staticmethod
    def _target_size(W, H, resize, width, height, keep_proportion, divisible_by):
        """Return the (width, height) frames are resized to for a W x H source.

        A requested width/height of 0 means "unconstrained" on that axis: in proportional mode
        the other axis alone decides the scale, otherwise the source size is used for that axis.
        When ``divisible_by > 1`` each side is rounded down to a multiple of it, but never below
        ``divisible_by`` itself (which would yield an unusable 0-px size).
        """
        if not resize:
            return W, H
        if keep_proportion:
            ratios = []
            if width > 0:
                ratios.append(width / W)
            if height > 0:
                ratios.append(height / H)
            # Both axes unconstrained: keep the source size.
            ratio = min(ratios) if ratios else 1.0
            width = max(1, round(W * ratio))
            height = max(1, round(H * ratio))
        else:
            if width == 0:
                width = W
            if height == 0:
                height = H
        if divisible_by > 1:
            width = max(divisible_by, width - (width % divisible_by))
            height = max(divisible_by, height - (height % divisible_by))
        return width, height

    def load_image(self, image, resize, width, height, repeat, keep_proportion, divisible_by, mask_channel, background_color):
        """Load an image file, optionally resize it, and return (images, masks, width, height, path).

        Every frame of a multi-frame file whose size matches the first frame is batched, except
        for formats in ``excluded_formats``. A single-frame result is tiled ``repeat`` times.
        If ``background_color`` is set, transparent frames are composited onto it; the
        ``alpha`` mask is then taken from the original (pre-composite) alpha.
        """
        image_path = folder_paths.get_annotated_filepath(image)

        img = node_helpers.pillow(Image.open, image_path)
        # exif_transpose returns a new single-frame image, so it is only used to read the upright
        # size. Iterating `img` itself below keeps every frame and the real `img.format`.
        W, H = node_helpers.pillow(ImageOps.exif_transpose, img).size

        # Parse background_color into an RGBA tuple (alpha defaults to fully opaque).
        if background_color:
            color_list = string_to_color(background_color)
            if len(color_list) == 3:
                bg_color_rgba = tuple(color_list) + (255,)
            else:
                bg_color_rgba = tuple(color_list)
        else:
            bg_color_rgba = None  # No background color specified

        width, height = self._target_size(W, H, resize, width, height, keep_proportion, divisible_by)

        output_images = []
        output_masks = []
        w, h = None, None

        # Multi-frame formats whose extra frames are not batched.
        excluded_formats = ['MPO']
        c = mask_channel[0].upper()

        for frame in ImageSequence.Iterator(img):
            frame = node_helpers.pillow(ImageOps.exif_transpose, frame)

            if frame.mode == 'I':
                frame = frame.point(lambda i: i * (1 / 255))

            if frame.mode == 'P' or 'A' in frame.getbands():
                frame = frame.convert("RGBA")

            # Compositing onto the background makes the frame opaque, so the original alpha
            # channel is kept aside for the mask output.
            original_alpha = None
            if 'A' in frame.getbands() and bg_color_rgba:
                original_alpha = frame.getchannel('A')
                bg_image = Image.new("RGBA", frame.size, bg_color_rgba)
                frame = Image.alpha_composite(bg_image, frame)

            rgb = frame.convert("RGB")

            # Only frames with the first frame's size can be concatenated into one batch.
            if not output_images:
                w, h = rgb.size
            if rgb.size != (w, h):
                continue
            if resize:
                rgb = rgb.resize((width, height), Image.Resampling.BILINEAR)

            image_tensor = torch.from_numpy(np.array(rgb).astype(np.float32) / 255.0)[None,]

            if c in frame.getbands():
                if c == 'A' and original_alpha is not None:
                    channel = original_alpha
                    if resize:
                        # Resize together with the image so mask and image sizes agree.
                        channel = channel.resize((width, height), Image.Resampling.BILINEAR)
                else:
                    if resize:
                        frame = frame.resize((width, height), Image.Resampling.BILINEAR)
                    channel = frame.getchannel(c)
                mask = torch.from_numpy(np.array(channel).astype(np.float32) / 255.0)
                if c == 'A':
                    # The mask is inverted alpha: opaque pixels map to 0.
                    mask = 1. - mask
            else:
                mask = torch.zeros((64, 64), dtype=torch.float32, device="cpu")

            output_images.append(image_tensor)
            output_masks.append(mask.unsqueeze(0))

        if len(output_images) > 1 and img.format not in excluded_formats:
            output_image = torch.cat(output_images, dim=0)
            output_mask = torch.cat(output_masks, dim=0)
        else:
            output_image = output_images[0]
            output_mask = output_masks[0]
            if repeat > 1:
                output_image = output_image.repeat(repeat, 1, 1, 1)
                output_mask = output_mask.repeat(repeat, 1, 1)

        return (output_image, output_mask, width, height, image_path)
        

    # @classmethod
    # def IS_CHANGED(s, image, **kwargs):
    #     image_path = folder_paths.get_annotated_filepath(image)
    #     m = hashlib.sha256()
    #     with open(image_path, 'rb') as f:
    #         m.update(f.read())
    #     return m.digest().hex()

    @classmethod
    def VALIDATE_INPUTS(s, image):
        """Reject image names that do not resolve to an existing annotated file path."""
        if not folder_paths.exists_annotated_filepath(image):
            return "Invalid image file: {}".format(image)

        return True

class LoadImagesFromFolderKJ:
    # Last computed content hash per resolved folder path (class-level, shared by all instances).
    folder_hashes = {}
    # Lower-case file extensions treated as loadable images.
    _VALID_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.webp', '.tga')

    @staticmethod
    def _resolve_folder(folder):
        """Join a relative folder onto ``args.base_directory`` when that is set."""
        if folder and not os.path.isabs(folder) and args.base_directory:
            folder = os.path.join(args.base_directory, folder)
        return folder

    @classmethod
    def _list_image_paths(cls, folder, include_subfolders):
        """Return paths under ``folder`` (recursively if requested) with an image extension."""
        if include_subfolders:
            return [os.path.join(root, file)
                    for root, _, files in os.walk(folder)
                    for file in files
                    if file.lower().endswith(cls._VALID_EXTENSIONS)]
        return [os.path.join(folder, file)
                for file in sorted(os.listdir(folder))
                if file.lower().endswith(cls._VALID_EXTENSIONS)]

    @classmethod
    def IS_CHANGED(cls, folder, **kwargs):
        """Fingerprint the folder's image files (paths and mtimes) as an md5 hex digest.

        Returns NaN when the folder is missing (NaN never compares equal, presumably so the
        node is always re-evaluated in that case).
        """
        folder = cls._resolve_folder(folder)
        if not folder or not os.path.isdir(folder):
            return float("NaN")

        file_data = []
        for path in cls._list_image_paths(folder, kwargs.get('include_subfolders', False)):
            try:
                file_data.append((path, os.path.getmtime(path)))
            except OSError:
                # File vanished or is unreadable between listing and stat; skip it.
                pass
        file_data.sort()

        combined_hash = hashlib.md5()
        combined_hash.update(folder.encode('utf-8'))
        combined_hash.update(str(len(file_data)).encode('utf-8'))
        for path, mtime in file_data:
            combined_hash.update(f"{path}:{mtime}".encode('utf-8'))

        current_hash = combined_hash.hexdigest()
        cls.folder_hashes[folder] = current_hash
        return current_hash

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "folder": ("STRING", {"default": ""}),
                "width": ("INT", {"default": 1024, "min": -1, "step": 1}),
                "height": ("INT", {"default": 1024, "min": -1, "step": 1}),
                "keep_aspect_ratio": (["crop", "pad", "stretch",],), 
            },
            "optional": {
                "image_load_cap": ("INT", {"default": 0, "min": 0, "step": 1}),
                "start_index": ("INT", {"default": 0, "min": 0, "step": 1}),
                "include_subfolders": ("BOOLEAN", {"default": False}),
            }
        }

    RETURN_TYPES = ("IMAGE", "MASK", "INT", "STRING",)
    RETURN_NAMES = ("image", "mask", "count", "image_path",)
    FUNCTION = "load_images"
    CATEGORY = "KJNodes/image"
    DESCRIPTION = """Loads images from a folder into a batch, images are resized and loaded into a batch."""

    def load_images(self, folder, width, height, image_load_cap, start_index, keep_aspect_ratio, include_subfolders=False):    
        """Load sorted images from ``folder`` into one batch resized to ``width`` x ``height``.

        ``width == height == -1`` uses the first loaded image's size for the whole batch.
        ``image_load_cap > 0`` limits how many images are loaded, starting at ``start_index``.

        Returns:
            (images (B, H, W, 3), masks (B, H, W), count, list of loaded paths)

        Raises:
            FileNotFoundError: if the folder is missing, has no image files, or no image
                remains after applying ``start_index``.
        """
        folder = self._resolve_folder(folder)
        if not folder or not os.path.isdir(folder):
            raise FileNotFoundError(f"Folder '{folder}' cannot be found.")

        dir_files = sorted(self._list_image_paths(folder, include_subfolders))
        if not dir_files:
            raise FileNotFoundError(f"No files in directory '{folder}'.")

        # start at start_index
        dir_files = dir_files[start_index:]

        images = []
        masks = []
        image_path_list = []

        limit_images = image_load_cap > 0
        pbar = ProgressBar(min(len(dir_files), image_load_cap) if limit_images else len(dir_files))

        for image_path in dir_files:
            # In non-recursive mode os.listdir can return a directory named like "x.png".
            if os.path.isdir(image_path):
                continue
            if limit_images and len(images) >= image_load_cap:
                break
            # exif_transpose returns an independent, loaded image, so the file can be closed now.
            with Image.open(image_path) as src:
                i = ImageOps.exif_transpose(src)

            # Resize image to maximum dimensions; -1/-1 adopts the first image's size.
            if width == -1 and height == -1:
                width, height = i.size
            if i.size != (width, height):
                i = self.resize_with_aspect_ratio(i, width, height, keep_aspect_ratio)

            image = i.convert("RGB")
            image = torch.from_numpy(np.array(image).astype(np.float32) / 255.0)[None,]

            if 'A' in i.getbands():
                mask = 1. - torch.from_numpy(np.array(i.getchannel('A')).astype(np.float32) / 255.0)
                if mask.shape != (height, width):
                    # Index rather than squeeze() so a 1-px wide/high mask keeps both dims.
                    mask = F.interpolate(mask[None, None],
                                         size=(height, width),
                                         mode='bilinear',
                                         align_corners=False)[0, 0]
            else:
                mask = torch.zeros((height, width), dtype=torch.float32, device="cpu")

            images.append(image)
            masks.append(mask)
            image_path_list.append(image_path)
            pbar.update(1)

        if not images:
            raise FileNotFoundError(f"No images could be loaded from '{folder}' (start_index={start_index}).")

        # One concatenation instead of growing the batch tensor per image (quadratic copying).
        # Masks are always batched (B, H, W), including the single-image case.
        return (torch.cat(images, dim=0), torch.stack(masks, dim=0), len(images), image_path_list)

    def resize_with_aspect_ratio(self, img, width, height, mode):
        """Fit ``img`` to exactly ``width`` x ``height`` by stretching, center-cropping or padding.

        In ``pad`` mode the result is RGBA, filled with the median edge color.
        """
        if mode == "stretch":
            return img.resize((width, height), Image.Resampling.LANCZOS)
        
        img_width, img_height = img.size
        aspect_ratio = img_width / img_height
        target_ratio = width / height

        if mode == "crop":
            # Calculate dimensions for center crop
            if aspect_ratio > target_ratio:
                # Image is wider - crop width
                new_width = int(height * aspect_ratio)
                img = img.resize((new_width, height), Image.Resampling.LANCZOS)
                left = (new_width - width) // 2
                return img.crop((left, 0, left + width, height))
            else:
                # Image is taller - crop height
                new_height = int(width / aspect_ratio)
                img = img.resize((width, new_height), Image.Resampling.LANCZOS)
                top = (new_height - height) // 2
                return img.crop((0, top, width, top + height))

        elif mode == "pad":
            pad_color = self.get_edge_color(img)
            # Calculate dimensions for padding
            if aspect_ratio > target_ratio:
                # Image is wider - pad height
                new_height = int(width / aspect_ratio)
                img = img.resize((width, new_height), Image.Resampling.LANCZOS)
                padding = (height - new_height) // 2
                padded = Image.new('RGBA', (width, height), pad_color)
                padded.paste(img, (0, padding))
                return padded
            else:
                # Image is taller - pad width
                new_width = int(height * aspect_ratio)
                img = img.resize((new_width, height), Image.Resampling.LANCZOS)
                padding = (width - new_width) // 2
                padded = Image.new('RGBA', (width, height), pad_color)
                padded.paste(img, (padding, 0))
                return padded

    def get_edge_color(self, img):
        """Sample edges and return dominant color"""
        width, height = img.size
        img = img.convert('RGBA')
        
        # Create 1-pixel high/wide images from edges
        top = img.crop((0, 0, width, 1))
        bottom = img.crop((0, height-1, width, height))
        left = img.crop((0, 0, 1, height))
        right = img.crop((width-1, 0, width, height))
        
        # Combine edges into single image
        edges = Image.new('RGBA', (width*2 + height*2, 1))
        edges.paste(top, (0, 0))
        edges.paste(bottom, (width, 0))
        edges.paste(left.resize((height, 1)), (width*2, 0))
        edges.paste(right.resize((height, 1)), (width*2 + height, 0))
        
        # Get median color
        stat = ImageStat.Stat(edges)
        median = tuple(map(int, stat.median))
        return median

class ImageGridtoBatch:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
                    "image": ("IMAGE", ),
                    "columns": ("INT", {"default": 3, "min": 1, "max": 8, "tooltip": "The number of columns in the grid."}),
                    # min must be 0 so the documented "0 = automatic" option (and the default) is valid.
                    "rows": ("INT", {"default": 0, "min": 0, "max": 8, "tooltip": "The number of rows in the grid. Set to 0 for automatic calculation."}),
                  }
                }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "decompose"
    CATEGORY = "KJNodes/image"
    DESCRIPTION = "Converts a grid of images to a batch of images."

    def decompose(self, image, columns, rows):
        """Split each (H, W) grid image into ``rows * columns`` cells, returned row-major as a batch.

        Cell sizes are rounded down; leftover pixels at the right/bottom are discarded.
        With ``rows == 0`` the cells are assumed square (cell height = cell width) and as many
        full rows as fit in the image height are taken.

        Raises:
            ValueError: if the grid would contain zero-sized cells or no rows.
        """
        B, H, W, C = image.shape

        # Calculate cell width, rounding down
        cell_width = W // columns

        if rows == 0:
            # Automatic: square cells, count how many full rows fit.
            cell_height = cell_width
            rows = H // cell_height if cell_height > 0 else 0
        else:
            cell_height = H // rows

        if cell_width == 0 or cell_height == 0 or rows == 0:
            raise ValueError(
                f"Cannot split a {W}x{H} image into a grid with {columns} columns and {rows} rows."
            )

        # Crop the image to fit full cells
        image = image[:, :rows * cell_height, :columns * cell_width, :]

        # (B, rows, cell_h, cols, cell_w, C) -> (B, rows, cols, cell_h, cell_w, C) -> flat batch
        image = image.reshape(B, rows, cell_height, columns, cell_width, C)
        image = image.permute(0, 1, 3, 2, 4, 5)
        img_tensor = image.reshape(-1, cell_height, cell_width, C)

        return (img_tensor,)

class SaveImageKJ:
    def __init__(self):
        self.type = "output"
        self.prefix_append = ""
        self.compress_level = 4

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "images": ("IMAGE", {"tooltip": "The images to save."}),
                "filename_prefix": ("STRING", {"default": "ComfyUI", "tooltip": "The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."}),
                "output_folder": ("STRING", {"default": "output", "tooltip": "The folder to save the images to."}),
            },
            "optional": {
                "caption_file_extension": ("STRING", {"default": ".txt", "tooltip": "The extension for the caption file. Limited to plain-text/data formats."}),
                "caption": ("STRING", {"forceInput": True, "tooltip": "string to save as .txt file"}),
            },
            "hidden": {
                "prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"
            },
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("filename",)
    FUNCTION = "save_images"

    OUTPUT_NODE = True

    CATEGORY = "KJNodes/image"
    DESCRIPTION = "Saves the input images to your ComfyUI output directory."

    def save_images(self, images, output_folder, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None, caption=None, caption_file_extension=".txt"):
        """Save each image of the batch as a PNG, optionally with a caption file next to it.

        An absolute ``output_folder`` is used directly (created if missing). Any other value
        saves into the ComfyUI output directory. Unless metadata is disabled, the prompt and
        extra PNG info are embedded as JSON text chunks.

        Returns:
            A 1-tuple with the file name of the last image written.

        Raises:
            ValueError: if the caption extension is not allow-listed or the caption path would
                resolve outside the target folder.
        """
        filename_prefix += self.prefix_append

        if os.path.isabs(output_folder):
            # Normalize (trailing separators, "..") before handing the path on.
            base_output_dir = os.path.abspath(output_folder)
            os.makedirs(base_output_dir, exist_ok=True)
        else:
            self.output_dir = folder_paths.get_output_directory()
            base_output_dir = self.output_dir

        # Use the folder returned by get_save_image_path in both branches (it accounts for a
        # subfolder in filename_prefix). Discarding it in the absolute branch wrote files to a
        # different folder than the one the counter was derived from.
        full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(
            filename_prefix, base_output_dir, images[0].shape[1], images[0].shape[0]
        )

        # sanitize caption extension: strip path components so it can't traverse out of the chosen folder, and allowlist to text/data formats
        if caption is not None:
            caption_file_extension = os.path.basename(caption_file_extension)
            if caption_file_extension and not caption_file_extension.startswith("."):
                caption_file_extension = "." + caption_file_extension
            if caption_file_extension.lower() not in SaveStringKJ.ALLOWED_EXTENSIONS:
                raise ValueError(f"Disallowed caption extension '{caption_file_extension}'. Allowed: {', '.join(SaveStringKJ.ALLOWED_EXTENSIONS)}")

        results = list()
        for (batch_number, image) in enumerate(images):
            i = 255. * image.cpu().numpy()
            img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
            metadata = None
            if not args.disable_metadata:
                metadata = PngInfo()
                if prompt is not None:
                    metadata.add_text("prompt", json.dumps(prompt))
                if extra_pnginfo is not None:
                    for x in extra_pnginfo:
                        metadata.add_text(x, json.dumps(extra_pnginfo[x]))

            filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
            base_file_name = f"{filename_with_batch_num}_{counter:05}_"
            file = f"{base_file_name}.png"
            img.save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=self.compress_level)
            results.append({
                "filename": file,
                "subfolder": subfolder,
                "type": self.type
            })
            if caption is not None:
                txt_file = base_file_name + caption_file_extension
                file_path = os.path.join(full_output_folder, txt_file)

                if os.path.commonpath((os.path.abspath(full_output_folder), os.path.abspath(file_path))) != os.path.abspath(full_output_folder):
                    raise ValueError(f"Refusing to write caption outside the target folder: {file_path}")
                with open(file_path, "w", encoding="utf-8") as f:
                    f.write(caption)

            counter += 1

        return file, 

class SaveStringKJ:
    # Extensions permitted for saved text files (also used by SaveImageKJ for captions).
    ALLOWED_EXTENSIONS = [".txt", ".caption", ".json", ".yaml", ".yml", ".md", ".csv", ".tsv", ".xml", ".log", ".ini", ".toml"]

    def __init__(self):
        self.output_dir = folder_paths.get_output_directory()
        self.type = "output"
        self.prefix_append = ""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "string": ("STRING", {"forceInput": True, "tooltip": "string to save as .txt file"}),
                "filename_prefix": ("STRING", {"default": "text", "tooltip": "The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."}),
                "output_folder": ("STRING", {"default": "output", "tooltip": "Subfolder within the ComfyUI output directory to save to. Paths resolving outside the output directory are rejected."}),
            },
            "optional": {
                "file_extension": ("STRING", {"default": ".txt", "tooltip": "The extension for the saved file. Limited to plain-text/data formats."}),
            },
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("filename",)
    FUNCTION = "save_string"

    OUTPUT_NODE = True

    CATEGORY = "KJNodes/misc"
    DESCRIPTION = "Saves the input string to your ComfyUI output directory."

    def save_string(self, string, output_folder, filename_prefix="text", file_extension=".txt"):
        """Write ``string`` to a new, non-clobbering file inside the ComfyUI output directory.

        ``output_folder`` is treated as a subfolder of the output directory (drive letters and
        leading separators are stripped). ``"output"`` or empty means the output directory itself.

        Returns:
            A 1-tuple with the full path of the written file.

        Raises:
            ValueError: if the folder resolves outside the output directory, or the extension
                is not in ``ALLOWED_EXTENSIONS``.
        """
        filename_prefix += self.prefix_append

        output_dir = os.path.abspath(self.output_dir)
        if output_folder and output_folder != "output":
            sub = os.path.splitdrive(output_folder)[1].replace("\\", "/").lstrip("/")
            target_dir = os.path.abspath(os.path.join(output_dir, sub))
        else:
            target_dir = output_dir

        try:
            inside = os.path.commonpath((output_dir, target_dir)) == output_dir
        except ValueError:
            # Different drives on Windows: commonpath cannot relate them, so treat as outside.
            inside = False
        if not inside:
            raise ValueError(f"output_folder must resolve within the ComfyUI output directory: {target_dir}")
        os.makedirs(target_dir, exist_ok=True)

        full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, target_dir)

        file_extension = os.path.basename(file_extension)
        if file_extension and not file_extension.startswith("."):
            file_extension = "." + file_extension

        if file_extension.lower() not in self.ALLOWED_EXTENSIONS:
            raise ValueError(f"Disallowed file extension '{file_extension}'. Allowed: {', '.join(self.ALLOWED_EXTENSIONS)}")

        # Build names from `filename` (the prefix's base name), not `filename_prefix`: the
        # returned folder already contains any subfolder from the prefix, so joining the full
        # prefix again would duplicate it (e.g. "sub/sub/text_00001_.txt").
        base_file_name = f"{filename}_{counter:05}_"
        txt_file = base_file_name + file_extension
        file_path = os.path.join(full_output_folder, txt_file)
        while os.path.exists(file_path):
            counter += 1
            base_file_name = f"{filename}_{counter:05}_"
            txt_file = base_file_name + file_extension
            file_path = os.path.join(full_output_folder, txt_file)

        if os.path.commonpath((os.path.abspath(full_output_folder), os.path.abspath(file_path))) != os.path.abspath(full_output_folder):
            raise ValueError(f"Refusing to write outside the target folder: {file_path}")
        with open(file_path, 'w', encoding="utf-8") as f:
            f.write(string)

        return file_path,

class FastPreview:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "image": ("IMAGE", ),
                "format": (["JPEG", "PNG"], {"default": "JPEG"}),
                "max_size": ("INT", {"default": 768, "min": 128, "max": 4096, "step": 64,
                             "tooltip": "Maximum width or height for the preview. Images larger than this are downscaled before encoding."}),
            },
            "hidden": {
                "unique_id": "UNIQUE_ID",
                "prompt_id": "PROMPT_ID",
            },
        }

    RETURN_TYPES = ()
    FUNCTION = "preview"
    CATEGORY = "KJNodes/experimental"
    OUTPUT_NODE = True
    DESCRIPTION = "Fast image preview using binary websocket, bypassing base64/JSON overhead."

    def preview(self, image, format, max_size, unique_id=None, prompt_id=None):
        """Send the first image of the batch to the client as a binary preview event.

        The image is downscaled so its longer side is at most ``max_size``. The metadata-carrying
        event is used when the client advertises ``supports_preview_metadata``; otherwise the
        plain unencoded preview event is sent. Nothing is sent without a server or ``unique_id``.
        """
        arr = image[0].cpu().mul(255).clamp(0, 255).byte().numpy()
        h, w = arr.shape[:2]

        if w > max_size or h > max_size:
            scale = max_size / max(w, h)
            # Clamp to 1 px so extreme aspect ratios never produce a zero-sized target.
            new_w, new_h = max(1, int(w * scale)), max(1, int(h * scale))
            if HAS_CV2:
                arr = cv2.resize(arr, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
                pil_image = Image.fromarray(arr)
            else:
                pil_image = Image.fromarray(arr).resize((new_w, new_h), Image.Resampling.BILINEAR)
        else:
            pil_image = Image.fromarray(arr)

        # JPEG cannot store alpha.
        if format == "JPEG" and pil_image.mode != "RGB":
            pil_image = pil_image.convert("RGB")

        if PromptServer is not None and unique_id is not None:
            server = PromptServer.instance
            client_supports_metadata = False
            if hasattr(BinaryEventTypes, "PREVIEW_IMAGE_WITH_METADATA"):
                try:
                    from comfy_api import feature_flags
                    client_supports_metadata = feature_flags.supports_feature(
                        server.sockets_metadata, server.client_id, "supports_preview_metadata"
                    )
                except Exception:
                    # Older servers without feature flags: fall back to the plain event.
                    client_supports_metadata = False

            if client_supports_metadata:
                server.send_sync(
                    BinaryEventTypes.PREVIEW_IMAGE_WITH_METADATA,
                    (
                        (format, pil_image, None),
                        {
                            "node_id": unique_id,
                            "display_node_id": unique_id,
                            "prompt_id": prompt_id or "",
                        },
                    ),
                    server.client_id,
                )
            else:
                server.send_sync(
                    BinaryEventTypes.UNENCODED_PREVIEW_IMAGE,
                    (format, pil_image, None),
                    server.client_id,
                )

        return {"ui": {"fast_preview": [True]}, "result": ()}


class FastPreviewBatch(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="FastPreviewBatch",
            display_name="Fast Preview Batch",
            category="KJNodes/experimental",
            description="Encodes an image batch as an all-I-frame H.264 MP4 thumbnail strip "
                        "and shows it as an interactive grid. Click a tile to enlarge with "
                        "prev/next browsing. Avoids materializing N PNGs.",
            inputs=[
                io.MultiType.Input("input", [io.Image, io.Mask], tooltip="Image or mask batch to preview."),
                io.Int.Input("max_thumb_size", default=512, min=512, max=1024, step=8,
                             tooltip="Detail-view (mp4) thumbnail max side. Strip thumbs for the grid are auto-capped at 256."),
                io.Int.Input("crf", default=25, min=0, max=51, step=1,
                             tooltip="H.264 CRF. Lower = higher quality / larger file."),
                io.Int.Input("max_grid_frames", default=1024, min=1, max=4096, step=1,
                             tooltip="If batch exceeds this, frames are stride-sampled evenly."),
            ],
            is_output_node=True,
        )

    @classmethod
    def execute(cls, input, max_thumb_size, crf, max_grid_frames) -> io.NodeOutput:
        """Write an intra-only H.264 MP4 and a tiled JPEG strip of the batch to the temp dir.

        At most ``max_grid_frames`` frames are used (evenly stride-sampled). A background thread
        resizes chunks on the torch device while this thread encodes them with PyAV. Returns a
        UI payload describing both files and the strip layout.
        """
        import av
        import threading
        import queue as _queue
        if input.ndim == 2:
            # A single un-batched mask (H, W): add the batch dimension.
            input = input.unsqueeze(0)
        if input.ndim == 3:
            # Mask batch (B, H, W) -> 3-channel image batch (B, H, W, 3).
            images = input.reshape((-1, 1, input.shape[-2], input.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3)
        else:
            images = input
        B, H, W, _ = images.shape

        if B > max_grid_frames:
            idx = torch.linspace(0, B - 1, max_grid_frames).round().long().tolist()
        else:
            idx = list(range(B))
        total = len(idx)

        scale = min(1.0, max_thumb_size / max(H, W))
        new_w = max(2, int(round(W * scale)))
        new_h = max(2, int(round(H * scale)))
        # yuv420p needs even dimensions
        new_w -= new_w & 1
        new_h -= new_h & 1

        # Strip thumbs serve the grid only; cap at 256 so the tiled JPEG stays well
        # under any browser image-decode limit regardless of detail-view size.
        STRIP_MAX = 256
        strip_scale = min(1.0, STRIP_MAX / max(new_h, new_w))
        strip_w = max(2, int(round(new_w * strip_scale)))
        strip_h = max(2, int(round(new_h * strip_scale)))

        output_dir = folder_paths.get_temp_directory()
        prefix = "kj_batch_preview_" + ''.join(random.choice("abcdefghijklmnopqrstuvwxyz") for _ in range(6))
        full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
            prefix, output_dir, new_w, new_h
        )
        file = f"{filename}_{counter:05}_.mp4"
        filepath = os.path.join(full_output_folder, file)
        strip_file = f"{filename}_{counter:05}_grid.jpg"
        strip_path = os.path.join(full_output_folder, strip_file)

        # Square-ish tiling for the JS grid renderer.
        strip_cols = max(1, int(math.ceil(math.sqrt(total))))
        strip_rows = int(math.ceil(total / strip_cols))
        strip_arr = np.zeros((strip_rows * strip_h, strip_cols * strip_w, 3), dtype=np.uint8)

        fps = 30
        container = av.open(filepath, mode="w")
        try:
            stream = container.add_stream("libx264", rate=Fraction(fps, 1))
            stream.width = new_w
            stream.height = new_h
            stream.pix_fmt = "yuv420p"
            # g=1 makes every frame a keyframe so the viewer can seek to any tile directly.
            stream.options = {"crf": str(crf), "preset": "ultrafast", "g": "1", "tune": "fastdecode"}

            chunk_size = 32
            need_resize = (new_h, new_w) != (H, W)
            need_strip_resize = (strip_h, strip_w) != (new_h, new_w)
            mode = 'area' if scale < 1.0 else 'bilinear'
            work_device = model_management.get_torch_device()

            def _to_numpy_nhwc_u8(t):
                return (t.mul(255).clamp(0, 255)
                         .to(dtype=torch.uint8, device='cpu')
                         .permute(0, 2, 3, 1).contiguous().numpy())

            # Producer: device resize + transfer; consumer (this thread): PyAV encode.
            # PyTorch device ops, host transfers, and PyAV's libx264 call all release the
            # GIL, so threading actually overlaps the two stages.
            frame_queue = _queue.Queue(maxsize=2)
            producer_error = [None]
            # Set by the consumer if it fails, so the producer stops instead of blocking forever
            # on the bounded queue.
            stop_event = threading.Event()

            def producer():
                try:
                    for c_start in range(0, total, chunk_size):
                        if stop_event.is_set():
                            break
                        c_idx = idx[c_start:c_start + chunk_size]
                        sel = (images[c_idx, ..., :3].permute(0, 3, 1, 2).contiguous()
                                                      .to(device=work_device, non_blocking=True))
                        sel_video = F.interpolate(sel, size=(new_h, new_w), mode=mode) if need_resize else sel
                        sel_strip = F.interpolate(sel_video, size=(strip_h, strip_w), mode='area') if need_strip_resize else sel_video
                        video_frames = _to_numpy_nhwc_u8(sel_video)
                        strip_frames = video_frames if sel_strip is sel_video else _to_numpy_nhwc_u8(sel_strip)
                        del sel, sel_video, sel_strip
                        frame_queue.put((c_start, video_frames, strip_frames))
                except Exception as e:
                    producer_error[0] = e
                finally:
                    # Sentinel: always signal end-of-stream, even on error or early stop.
                    frame_queue.put(None)

            producer_thread = threading.Thread(target=producer, daemon=True)
            producer_thread.start()

            producer_finished = False
            try:
                pbar = ProgressBar(total)
                while True:
                    item = frame_queue.get()
                    if item is None:
                        producer_finished = True
                        break
                    c_start, video_frames, strip_frames = item
                    for i in range(video_frames.shape[0]):
                        global_idx = c_start + i
                        sr = global_idx // strip_cols
                        sc = global_idx % strip_cols
                        strip_arr[sr * strip_h:(sr + 1) * strip_h, sc * strip_w:(sc + 1) * strip_w] = strip_frames[i]
                        frame = av.VideoFrame.from_ndarray(video_frames[i], format="rgb24")
                        for packet in stream.encode(frame):
                            container.mux(packet)
                        pbar.update(1)
            finally:
                if not producer_finished:
                    # The consumer raised (encode error, interrupt from the progress hook, ...):
                    # stop the producer and drain the queue until its sentinel so the thread
                    # can exit instead of staying blocked on put().
                    stop_event.set()
                    while frame_queue.get() is not None:
                        pass
                producer_thread.join()

            if producer_error[0] is not None:
                raise producer_error[0]

            # Flush frames still buffered inside the encoder.
            for packet in stream.encode():
                container.mux(packet)
        finally:
            container.close()

        Image.fromarray(strip_arr).save(strip_path, quality=85)

        return io.NodeOutput(ui={"kj_batch_preview": [{
            "filename": file,
            "subfolder": subfolder,
            "type": "temp",
            "frame_count": total,
            "fps": fps,
            "thumb_w": new_w,
            "thumb_h": new_h,
            "strip_filename": strip_file,
            "strip_cols": strip_cols,
            "strip_cell_w": strip_w,
            "strip_cell_h": strip_h,
        }]})


class ImageCropByMaskAndResize:
    """Crop each image around its mask and resize every crop to one shared size.

    Outputs the resized crops, the matching resized masks and, per item, the
    bbox ``(x0, y0, x1, y1)`` in source-image pixels that the crop was taken
    from (clamped to the image so it always matches the pixels actually used).
    """

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE", ),
                "mask": ("MASK", ),
                "base_resolution": ("INT", { "default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 8, }),
                "padding": ("INT", { "default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1, }),
                "min_crop_resolution": ("INT", { "default": 128, "min": 0, "max": MAX_RESOLUTION, "step": 8, }),
                "max_crop_resolution": ("INT", { "default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 8, }),
            },
        }

    RETURN_TYPES = ("IMAGE", "MASK", "BBOX", )
    RETURN_NAMES = ("images", "masks", "bbox",)
    FUNCTION = "crop"
    CATEGORY = "KJNodes/image"

    def crop_by_mask(self, mask, padding=0, min_crop_resolution=None, max_crop_resolution=None):
        """Return ``(x0, y0, width, height)`` of a padded crop around the mask's 1-pixels.

        The box is centered on the mask's extent (or the image center for an
        empty mask), widened to ``min_crop_resolution`` / narrowed to
        ``max_crop_resolution`` when those are truthy, grown by ``padding`` on
        each side as far as the image allows, and finally clamped to lie fully
        inside the ``(h0, w0)`` mask.
        """
        iy, ix = (mask == 1).nonzero(as_tuple=True)
        h0, w0 = mask.shape

        if iy.numel() == 0:
            # Empty mask: zero-sized box at the image center.
            x_c = w0 / 2.0
            y_c = h0 / 2.0
            width = 0
            height = 0
        else:
            x_min = ix.min().item()
            x_max = ix.max().item()
            y_min = iy.min().item()
            y_max = iy.max().item()
            width = x_max - x_min + 1  # inclusive of boundary pixels
            height = y_max - y_min + 1
            x_c = (x_min + x_max) / 2.0
            y_c = (y_min + y_max) / 2.0

        if min_crop_resolution:
            width = max(width, min_crop_resolution)
            height = max(height, min_crop_resolution)
        if max_crop_resolution:
            width = min(width, max_crop_resolution)
            height = min(height, max_crop_resolution)

        # Padding may only use the space the image actually has left; never
        # let it go negative (which would shrink the box below the frame size
        # when the requested box is already larger than the image).
        pad_x = max(0, min((w0 - width) // 2, padding))
        pad_y = max(0, min((h0 - height) // 2, padding))

        final_width = min(width + 2 * pad_x, w0)
        final_height = min(height + 2 * pad_y, h0)

        # Center on the mask but keep the box inside the image.
        x0 = max(0, min(int(x_c - final_width / 2), w0 - final_width))
        y0 = max(0, min(int(y_c - final_height / 2), h0 - final_height))

        return (x0, y0, final_width, final_height)

    def crop(self, image, mask, base_resolution, padding=0, min_crop_resolution=128, max_crop_resolution=512):
        """Crop every batch item around its mask and resize all crops to one size.

        All items use a crop window of the same size (largest per-item box,
        rounded up to a multiple of 16) centered on each item's own box. The
        output size is derived from ``base_resolution`` and the widest
        per-item aspect ratio, rounded up to a multiple of 16.

        Returns:
            (images, masks, bbox_list) where bbox_list holds one
            ``(x0, y0, x1, y1)`` per item, clamped to the image bounds.
        """
        mask = mask.round()
        img_h, img_w = image.shape[1], image.shape[2]
        batch = image.shape[0]

        # Step 1: per-item bounding boxes.
        bbox_params = [
            self.crop_by_mask(mask[i], padding, min_crop_resolution, max_crop_resolution)
            for i in range(batch)
        ]
        # Guard zero-height boxes (empty mask with no minimum crop size).
        aspect_ratios = [w / h if h > 0 else 1.0 for _, _, w, h in bbox_params]

        # Step 2: one shared crop window and one shared output size.
        max_w = max(w for _, _, w, _ in bbox_params)
        max_h = max(h for _, _, _, h in bbox_params)
        max_aspect_ratio = max(aspect_ratios)

        # Round up to a multiple of 16; never allow an empty crop window.
        max_w = max(16, (max_w + 15) // 16 * 16)
        max_h = max(16, (max_h + 15) // 16 * 16)

        if max_aspect_ratio > 1:
            target_width = base_resolution
            target_height = int(base_resolution / max_aspect_ratio)
        else:
            target_height = base_resolution
            target_width = int(base_resolution * max_aspect_ratio)

        target_width = (target_width + 15) // 16 * 16
        target_height = (target_height + 15) // 16 * 16

        image_list = []
        mask_list = []
        bbox_list = []

        # Step 3: crop each item with the shared window and resize.
        for i in range(batch):
            orig_x0, orig_y0, orig_w, orig_h = bbox_params[i]
            x_center = orig_x0 + orig_w / 2
            y_center = orig_y0 + orig_h / 2

            x0_new = max(0, min(int(x_center - max_w / 2), img_w - max_w))
            y0_new = max(0, min(int(y_center - max_h / 2), img_h - max_h))
            # The rounded window can be larger than the image; clamp so the
            # reported bbox matches the pixels the slice really returns.
            x1_new = min(x0_new + max_w, img_w)
            y1_new = min(y0_new + max_h, img_h)

            cropped_image = image[i][y0_new:y1_new, x0_new:x1_new, :]
            cropped_mask = mask[i][y0_new:y1_new, x0_new:x1_new]

            # Image: lanczos resize (NHWC -> NCHW for common_upscale and back).
            cropped_image = cropped_image.unsqueeze(0).movedim(-1, 1)
            cropped_image = common_upscale(cropped_image, target_width, target_height, "lanczos", "disabled")
            cropped_image = cropped_image.movedim(1, -1).squeeze(0)

            # Mask: bilinear resize.
            cropped_mask = cropped_mask.unsqueeze(0).unsqueeze(0)
            cropped_mask = common_upscale(cropped_mask, target_width, target_height, 'bilinear', "disabled")
            cropped_mask = cropped_mask.squeeze(0).squeeze(0)

            image_list.append(cropped_image)
            mask_list.append(cropped_mask)
            bbox_list.append((x0_new, y0_new, x1_new, y1_new))

        return (torch.stack(image_list), torch.stack(mask_list), bbox_list)
    
class ImageCropByMask:
    """Crop each image to the bounding box of its mask's positive pixels."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE", ),
                "mask": ("MASK", ),           
            },
        }

    RETURN_TYPES = ("IMAGE", )
    RETURN_NAMES = ("image", )
    FUNCTION = "crop"
    CATEGORY = "KJNodes/image"
    DESCRIPTION = "Crops the input images based on the provided mask."

    def crop(self, image, mask):
        """Crop ``image[b]`` to the tight bbox of ``mask[b] > 0`` (after rounding).

        When the mask batch is shorter than the image batch, the last mask is
        reused. An all-zero mask keeps the full frame instead of failing.

        Raises:
            ValueError: if the per-item crops have different sizes and so
                cannot be combined into one batch.
        """
        B, H, W, C = image.shape
        mask = mask.round()
        last_mask = mask.shape[0] - 1

        crops = []
        for b in range(B):
            active = mask[min(b, last_mask)] > 0
            rows = torch.any(active, dim=1)
            cols = torch.any(active, dim=0)

            if not bool(rows.any()):
                # Nothing selected: keep the whole frame.
                y_min, y_max, x_min, x_max = 0, H - 1, 0, W - 1
            else:
                row_idx = torch.where(rows)[0]
                col_idx = torch.where(cols)[0]
                y_min, y_max = int(row_idx[0]), int(row_idx[-1])
                x_min, x_max = int(col_idx[0]), int(col_idx[-1])

            crops.append(image[b:b + 1, y_min:y_max + 1, x_min:x_max + 1, :])

        sizes = {tuple(c.shape[1:3]) for c in crops}
        if len(sizes) > 1:
            raise ValueError(
                f"ImageCropByMask: masks produce crops of different sizes {sorted(sizes)}; "
                "cannot batch them together."
            )

        return (torch.cat(crops, dim=0), )

       
    
class ImageUncropByMask:
    """Paste resized crops back into destination frames, blended by a mask."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required":
                    {    
                        "destination": ("IMAGE",),
                        "source": ("IMAGE",),
                        "mask": ("MASK",),
                        "bbox": ("BBOX",),
                     },
                }

    CATEGORY = "KJNodes/image"
    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("image",)
    FUNCTION = "uncrop"

    def uncrop(self, destination, source, mask, bbox=None):
        """Composite ``source[i]`` into its destination frame at ``bbox[i]``.

        Each bbox entry is unpacked as ``(x0, y0, x1, y1)``. The source is
        lanczos-resized and the mask bilinear-resized to the bbox size, both
        are zero-padded to the destination size, and the result is
        ``dest * (1 - mask) + source * mask``. One output frame is produced
        per source item; a shorter destination or mask batch reuses its last
        entry.

        Raises:
            ValueError: if no bbox is supplied.
        """
        if bbox is None:
            raise ValueError("ImageUncropByMask requires a bbox input.")

        B, H, W, C = destination.shape
        last_dest = B - 1
        last_mask = mask.shape[0] - 1

        output_list = []
        for i in range(source.shape[0]):
            dest = destination[min(i, last_dest)]
            item_mask = mask[min(i, last_mask)]
            x0, y0, x1, y1 = bbox[i]
            bbox_width = x1 - x0
            bbox_height = y1 - y0

            resized_source = common_upscale(source[i].unsqueeze(0).movedim(-1, 1), bbox_width, bbox_height, "lanczos", "disabled")
            resized_source = resized_source.movedim(1, -1).squeeze(0)

            resized_mask = common_upscale(item_mask.unsqueeze(0).unsqueeze(0), bbox_width, bbox_height, "bilinear", "disabled")
            resized_mask = resized_mask.squeeze(0).squeeze(0)

            # F.pad lists pads from the last dim backwards: (W, H) for the mask
            # and (C, W, H) for the HWC image. Negative values trim instead.
            spatial_pad = (x0, W - x1, y0, H - y1)
            padded_source = F.pad(resized_source, pad=(0, 0) + tuple(spatial_pad), mode='constant', value=0)
            padded_mask = F.pad(resized_mask, pad=spatial_pad, mode='constant', value=0)

            # Broadcast the mask over channels; a single-channel source also
            # broadcasts across the destination channels in the blend below.
            padded_mask = padded_mask.unsqueeze(2).expand(-1, -1, dest.shape[2])

            output_list.append(dest * (1.0 - padded_mask) + padded_source * padded_mask)

        return (torch.stack(output_list),)
    
class ImageCropByMaskBatch:
    """Cut one region per mask out of the image(s) and place it on a fixed canvas."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
                    "image": ("IMAGE", ),
                    "masks": ("MASK", ),
                    "width": ("INT", {"default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 8, }),
                    "height": ("INT", {"default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 8, }),
                    "padding": ("INT", {"default": 0, "min": 0, "max": 4096, "step": 1, }),
                    "preserve_size": ("BOOLEAN", {"default": False}),
                    "bg_color": ("STRING", {"default": "0, 0, 0", "tooltip": "Color as RGB values in range 0-255 or 0.0-1.0, or color name or hex code"}),
                  }
                }
    
    RETURN_TYPES = ("IMAGE", "MASK", )
    RETURN_NAMES = ("images", "masks",)
    FUNCTION = "crop"
    CATEGORY = "KJNodes/image"
    DESCRIPTION = "Crops the input images based on the provided masks."

    def crop(self, image, masks, width, height, bg_color, padding, preserve_size):
        """Crop the bbox of every non-empty mask and center it on a width x height canvas.

        Mask ``i`` crops ``image[i]``; when there are fewer images than masks
        the last image is reused, so a single image can be cut by many masks.
        Masks are resized (nearest-exact) to the image size when needed. The
        crop is scaled to fit the canvas unless ``preserve_size`` is set and it
        already fits. Pixels outside the (resized) mask are set to ``bg_color``.

        Returns:
            (images, masks) with shapes (N, height, width, 3) and
            (N, height, width), N being the number of non-empty masks.
        """
        B, H, W, C = image.shape
        _, HM, WM = masks.shape
        if HM != H or WM != W:
            masks = F.interpolate(masks.unsqueeze(1), size=(H, W), mode='nearest-exact').squeeze(1)

        device = image.device
        dtype = image.dtype

        color_list = string_to_color(bg_color)
        bg_values = [x / 255.0 for x in color_list]
        if len(bg_values) == 1:
            bg_values = bg_values * 3  # grayscale -> RGB

        output_images = []
        output_masks = []
        for i in range(masks.shape[0]):
            curr_mask = masks[i]

            y_indices, x_indices = torch.nonzero(curr_mask, as_tuple=True)
            if y_indices.numel() == 0:
                continue

            # Exact bounds grown by padding, clipped to the image.
            min_y = max(0, y_indices.min().item() - padding)
            max_y = min(H, y_indices.max().item() + 1 + padding)
            min_x = max(0, x_indices.min().item() - padding)
            max_x = min(W, x_indices.max().item() + 1 + padding)

            cropped_img = image[min(i, B - 1), min_y:max_y, min_x:max_x, :]
            cropped_mask = curr_mask[min_y:max_y, min_x:max_x]
            crop_h, crop_w = cropped_img.shape[0:2]

            if not preserve_size or crop_w > width or crop_h > height:
                scale = min(width / crop_w, height / crop_h)
                # Never resize a crop down to zero pixels.
                new_w = max(1, int(crop_w * scale))
                new_h = max(1, int(crop_h * scale))
                resized_img = common_upscale(cropped_img.permute(2, 0, 1).unsqueeze(0), new_w, new_h, "lanczos", "disabled").squeeze(0).permute(1, 2, 0)
                resized_mask = F.interpolate(cropped_mask[None, None], size=(new_h, new_w), mode='nearest')[0, 0]
            else:
                new_w, new_h = crop_w, crop_h
                resized_img = cropped_img
                resized_mask = cropped_mask

            # Canvases live on the image device so the final blend does not mix devices.
            new_img = torch.zeros((height, width, 3), dtype=dtype, device=device)
            new_mask = torch.zeros((height, width), dtype=dtype, device=device)

            pad_x = (width - new_w) // 2
            pad_y = (height - new_h) // 2
            new_img[pad_y:pad_y + new_h, pad_x:pad_x + new_w, :] = resized_img[..., :3]
            new_mask[pad_y:pad_y + new_h, pad_x:pad_x + new_w] = resized_mask.to(device=device, dtype=dtype)

            output_images.append(new_img)
            output_masks.append(new_mask)

        if not output_images:
            # Node declares two outputs, so return an empty batch for both.
            return (
                torch.zeros((0, height, width, 3), dtype=dtype, device=device),
                torch.zeros((0, height, width), dtype=dtype, device=device),
            )

        out_rgb = torch.stack(output_images, dim=0)
        out_masks = torch.stack(output_masks, dim=0)

        mask_expanded = out_masks.unsqueeze(-1).expand(-1, -1, -1, 3)
        background_color = torch.tensor(bg_values, dtype=dtype, device=device)
        out_rgb = out_rgb * mask_expanded + background_color * (1 - mask_expanded)

        return (out_rgb, out_masks)
    
class ImagePadKJ:
    """Pad images (and optionally masks) by explicit margins or up to a target size."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
                    "image": ("IMAGE", ),
                    "left": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1, }),
                    "right": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1, }),
                    "top": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1, }),
                    "bottom": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1, }),
                    "extra_padding": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1, }),
                    "pad_mode": (["edge", "edge_pixel", "color", "pillarbox_blur"],),
                    "color": ("STRING", {"default": "0, 0, 0", "tooltip": "Color as RGB values in range 0-255 or 0.0-1.0, or color name or hex code"}),
                  },
                "optional": {
                    "mask": ("MASK", ),
                    "target_width": ("INT", {"default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 1, "forceInput": True}),
                    "target_height": ("INT", {"default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 1, "forceInput": True}),
                }
                }
    
    RETURN_TYPES = ("IMAGE", "MASK", )
    RETURN_NAMES = ("images", "masks",)
    FUNCTION = "pad"
    CATEGORY = "KJNodes/image"
    DESCRIPTION = "Pad the input image and optionally mask with the specified padding."

    @staticmethod
    def _fit_mask(mask, H, W):
        """Resize a (B, h, w) mask to (B, H, W) with nearest-exact sampling when sizes differ."""
        _, HM, WM = mask.shape
        if HM != H or WM != W:
            mask = F.interpolate(mask.unsqueeze(1), size=(H, W), mode='nearest-exact').squeeze(1)
        return mask

    @staticmethod
    def _gaussian_blur_nchw(img_nchw, sigma_px):
        """Separable Gaussian blur of an NCHW tensor (kernel radius 3*sigma, zero padding)."""
        if sigma_px <= 0:
            return img_nchw
        radius = max(1, int(3.0 * float(sigma_px)))
        k = 2 * radius + 1
        x = torch.arange(-radius, radius + 1, device=img_nchw.device, dtype=img_nchw.dtype)
        k1 = torch.exp(-(x * x) / (2.0 * float(sigma_px) * float(sigma_px)))
        k1 = k1 / k1.sum()
        c = img_nchw.shape[1]
        kx = k1.view(1, 1, 1, k).repeat(c, 1, 1, 1)
        ky = k1.view(1, 1, k, 1).repeat(c, 1, 1, 1)
        img_nchw = F.conv2d(img_nchw, kx, padding=(0, radius), groups=c)
        img_nchw = F.conv2d(img_nchw, ky, padding=(radius, 0), groups=c)
        return img_nchw

    @classmethod
    def _pillarbox_background(cls, image, padded_height, padded_width):
        """Build the blurred, slightly desaturated and dimmed fill for pillarbox_blur mode.

        Each frame is scaled to cover the padded canvas, center-cropped,
        blurred, desaturated by 20% (when it has >= 3 channels) and multiplied
        by 0.35. Returns a (B, padded_height, padded_width, C) tensor.
        """
        B, H, W, C = image.shape
        out_image = torch.zeros((B, padded_height, padded_width, C), dtype=image.dtype, device=image.device)

        # Geometry and blur strength depend only on sizes: compute once.
        scale_fill = max(padded_width / float(W), padded_height / float(H)) if (W > 0 and H > 0) else 1.0
        bg_w = max(1, int(round(W * scale_fill)))
        bg_h = max(1, int(round(H * scale_fill)))
        y0 = max(0, (bg_h - padded_height) // 2)
        x0 = max(0, (bg_w - padded_width) // 2)
        y1 = min(bg_h, y0 + padded_height)
        x1 = min(bg_w, x0 + padded_width)
        sigma = max(1.0, 0.006 * float(min(padded_height, padded_width)))

        for b in range(B):
            src_b = image[b].movedim(-1, 0).unsqueeze(0)
            bg = common_upscale(src_b, bg_w, bg_h, "bilinear", crop="disabled")
            bg = bg[:, :, y0:y1, x0:x1]
            if bg.shape[2] != padded_height or bg.shape[3] != padded_width:
                # Rounding left the cover slightly short: replicate edges to fill.
                pad_h = padded_height - bg.shape[2]
                pad_w = padded_width - bg.shape[3]
                pad_top_fix = max(0, pad_h // 2)
                pad_bottom_fix = max(0, pad_h - pad_top_fix)
                pad_left_fix = max(0, pad_w // 2)
                pad_right_fix = max(0, pad_w - pad_left_fix)
                bg = F.pad(bg, (pad_left_fix, pad_right_fix, pad_top_fix, pad_bottom_fix), mode="replicate")
            bg = cls._gaussian_blur_nchw(bg, sigma_px=sigma)
            if C >= 3:
                r, g, bch = bg[:, 0:1], bg[:, 1:2], bg[:, 2:3]
                luma = 0.2126 * r + 0.7152 * g + 0.0722 * bch
                gray = torch.cat([luma, luma, luma], dim=1)
                desat = 0.20
                rgb = torch.cat([r, g, bch], dim=1)
                bg[:, 0:3, :, :] = rgb * (1.0 - desat) + gray * desat
            dim = 0.35
            bg = torch.clamp(bg * dim, 0.0, 1.0)
            out_image[b] = bg.squeeze(0).movedim(0, -1)
        return out_image

    def pad(self, image, left, right, top, bottom, extra_padding, color, pad_mode, mask=None, target_width=None, target_height=None):
        """Pad ``image`` (B, H, W, C) and return ``(images, masks)``.

        With both ``target_width`` and ``target_height`` given, the image is
        first shrunk by ``extra_padding`` pixels (if > 0) and then centered on
        a canvas of the target size; otherwise ``left/right/top/bottom`` plus
        ``extra_padding`` are added on each side.

        pad_mode:
            "edge"           - fill each side with the mean of that edge row/column
            "edge_pixel"     - extend the exact edge pixels outward
            "pillarbox_blur" - blurred, dimmed copy of the frame behind it
            anything else    - solid ``color``

        Output mask: without an input mask, 1 in padded area and 0 over the
        image. With one, pillarbox mode writes 1 in the padded area and the
        mask over the image; other modes replicate the mask's edges outward.

        Raises:
            ValueError: if the target size is smaller than the image.
        """
        B, H, W, C = image.shape
        if mask is not None:
            mask = self._fit_mask(mask, H, W)

        color_list = string_to_color(color)
        bg_color = [x / 255.0 for x in color_list]
        if len(bg_color) == 1:
            bg_color = bg_color * 3  # Grayscale to RGB
        bg_color = torch.tensor(bg_color, dtype=image.dtype, device=image.device)

        if target_width is not None and target_height is not None:
            if extra_padding > 0:
                image = common_upscale(image.movedim(-1, 1), W - extra_padding, H - extra_padding, "lanczos", "disabled").movedim(1, -1)
                B, H, W, C = image.shape
                # Keep the mask aligned with the shrunken image.
                if mask is not None:
                    mask = self._fit_mask(mask, H, W)

            padded_width = target_width
            padded_height = target_height
            if padded_width < W or padded_height < H:
                raise ValueError(
                    f"ImagePadKJ: target size {padded_width}x{padded_height} is smaller than "
                    f"the image ({W}x{H}); cannot pad."
                )
            pad_left = (padded_width - W) // 2
            pad_right = padded_width - W - pad_left
            pad_top = (padded_height - H) // 2
            pad_bottom = padded_height - H - pad_top
        else:
            pad_left = left + extra_padding
            pad_right = right + extra_padding
            pad_top = top + extra_padding
            pad_bottom = bottom + extra_padding
            padded_width = W + pad_left + pad_right
            padded_height = H + pad_top + pad_bottom

        y_end = pad_top + H    # first row below the image
        x_end = pad_left + W   # first column right of the image

        if pad_mode == "pillarbox_blur":
            out_image = self._pillarbox_background(image, padded_height, padded_width)
            out_image[:, pad_top:y_end, pad_left:x_end, :] = image
            out_masks = torch.ones((B, padded_height, padded_width), dtype=image.dtype, device=image.device)
            out_masks[:, pad_top:y_end, pad_left:x_end] = mask if mask is not None else 0.0
            return (out_image, out_masks)

        # Standard pad logic (edge / edge_pixel / color).
        out_image = torch.zeros((B, padded_height, padded_width, C), dtype=image.dtype, device=image.device)
        for b in range(B):
            src = image[b]
            if pad_mode == "edge":
                # Each side gets the mean color of the adjacent image edge.
                out_image[b, :pad_top, :, :] = src[0, :, :].mean(dim=0)
                out_image[b, y_end:, :, :] = src[H - 1, :, :].mean(dim=0)
                out_image[b, :, :pad_left, :] = src[:, 0, :].mean(dim=0)
                out_image[b, :, x_end:, :] = src[:, W - 1, :].mean(dim=0)
            elif pad_mode == "edge_pixel":
                # Broadcast edge rows/columns outward, then fill corners.
                out_image[b, :pad_top, pad_left:x_end, :] = src[0, :, :]
                out_image[b, y_end:, pad_left:x_end, :] = src[H - 1, :, :]
                out_image[b, pad_top:y_end, :pad_left, :] = src[:, 0:1, :]
                out_image[b, pad_top:y_end, x_end:, :] = src[:, W - 1:W, :]
                out_image[b, :pad_top, :pad_left, :] = src[0, 0, :]
                out_image[b, :pad_top, x_end:, :] = src[0, W - 1, :]
                out_image[b, y_end:, :pad_left, :] = src[H - 1, 0, :]
                out_image[b, y_end:, x_end:, :] = src[H - 1, W - 1, :]
            else:
                out_image[b] = bg_color
        out_image[:, pad_top:y_end, pad_left:x_end, :] = image

        if mask is not None:
            out_masks = F.pad(mask, (pad_left, pad_right, pad_top, pad_bottom), mode='replicate')
        else:
            out_masks = torch.ones((B, padded_height, padded_width), dtype=image.dtype, device=image.device)
            out_masks[:, pad_top:y_end, pad_left:x_end] = 0.0

        return (out_image, out_masks)

# extends https://github.com/Kosinkadink/ComfyUI-VideoHelperSuite
class LoadVideosFromFolder:
    """Load every video in a folder through VideoHelperSuite as a batch or a grid."""

    # Resolved VideoHelperSuite package. Set on first instantiation; a class
    # attribute so IS_CHANGED (a classmethod) works even before that.
    vhs_nodes = None

    VIDEO_EXTENSIONS = ('webm', 'mp4', 'mkv', 'gif', 'mov')

    @classmethod
    def __init__(cls):
        cls.vhs_nodes = cls._find_vhs_module()

    @staticmethod
    def _find_vhs_module():
        """Return the VideoHelperSuite module or raise ImportError if it cannot be found."""
        for name in ("ComfyUI-VideoHelperSuite.videohelpersuite",
                     "comfyui-videohelpersuite.videohelpersuite"):
            try:
                return importlib.import_module(name)
            except ImportError:
                pass

        # Fallback for installs under other folder names (e.g. on Windows):
        # look for an already-imported module. Snapshot sys.modules, and skip
        # None placeholders it may contain.
        import sys
        loaded = [(n, m) for n, m in list(sys.modules.items()) if m is not None]
        for module_name, module in loaded:
            if 'videohelpersuite' in module_name and 'videohelpersuite' in getattr(module, '__dict__', {}):
                return module
        for module_name, module in loaded:
            if module_name.endswith('videohelpersuite'):
                return module
        raise ImportError("This node requires ComfyUI-VideoHelperSuite to be installed.")

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "video": ("STRING", {"default": "X://insert/path/"},),
                "force_rate": ("FLOAT", {"default": 0, "min": 0, "max": 60, "step": 1, "disable": 0}),
                "custom_width": ("INT", {"default": 0, "min": 0, "max": 4096, 'disable': 0}),
                "custom_height": ("INT", {"default": 0, "min": 0, "max": 4096, 'disable': 0}),
                "frame_load_cap": ("INT", {"default": 0, "min": 0, "max": 10000, "step": 1, "disable": 0}),
                "skip_first_frames": ("INT", {"default": 0, "min": 0, "max": 10000, "step": 1}),
                "select_every_nth": ("INT", {"default": 1, "min": 1, "max": 1000, "step": 1}),
                "output_type": (["batch", "grid"], {"default": "batch"}),
                "grid_max_columns": ("INT", {"default": 4, "min": 1, "max": 16, "step": 1, "disable": 1}),
                "add_label": ( "BOOLEAN", {"default": False} ),
            },
            "hidden": {
                "force_size": "STRING",
                "unique_id": "UNIQUE_ID"
            },
        }

    CATEGORY = "KJNodes/misc"

    RETURN_TYPES = ("IMAGE", )
    RETURN_NAMES = ("IMAGE", )

    FUNCTION = "load_video"

    @staticmethod
    def _add_label(video_tensor, label_text):
        """Prepend a black strip with centered white ``label_text`` above every frame."""
        if video_tensor.dim() == 4:
            _, h, w, c = video_tensor.shape
        else:
            h, w, c = video_tensor.shape
        font_size = max(16, w // 20)
        try:
            font = ImageFont.truetype("arial.ttf", font_size)
        except OSError:
            font = ImageFont.load_default()
        # Measure the text on a scratch image to size the label strip.
        dummy_img = Image.new("RGB", (w, 10), (0, 0, 0))
        text_bbox = ImageDraw.Draw(dummy_img).textbbox((0, 0), label_text, font=font)
        extra_padding = max(12, font_size // 2)  # more padding under the font
        label_height = text_bbox[3] - text_bbox[1] + extra_padding
        label_img = Image.new("RGB", (w, label_height), (0, 0, 0))
        ImageDraw.Draw(label_img).text((w // 2 - (text_bbox[2] - text_bbox[0]) // 2, 4), label_text, font=font, fill=(255, 255, 255))
        label_tensor = torch.from_numpy(np.asarray(label_img).astype(np.float32) / 255.0)
        # Match the video's channel count.
        if c == 1:
            label_tensor = label_tensor.mean(dim=2, keepdim=True)
        elif c == 4:
            alpha = torch.ones((label_height, w, 1), dtype=label_tensor.dtype)
            label_tensor = torch.cat([label_tensor, alpha], dim=2)
        if video_tensor.dim() == 4:
            label_tensor = label_tensor.unsqueeze(0).expand(video_tensor.shape[0], -1, -1, -1)
            return torch.cat([label_tensor, video_tensor], dim=1)
        return torch.cat([label_tensor, video_tensor], dim=0)

    @staticmethod
    def _build_grid(videos, columns):
        """Tile videos into rows of ``columns``; works for (F, H, W, C) and (H, W, C).

        Missing slots in the last row are black; videos in a row are padded at
        the bottom to the row's tallest height.
        """
        rows = (len(videos) + columns - 1) // columns
        filler = torch.zeros_like(videos[0])
        videos = list(videos) + [filler] * (rows * columns - len(videos))

        row_tensors = []
        for row_idx in range(rows):
            row = videos[row_idx * columns:(row_idx + 1) * columns]
            max_height = max(v.shape[-3] for v in row)
            # 6-element pad covers (C, W, H) from the last dim backwards: pad H at the bottom.
            row = [
                F.pad(v, (0, 0, 0, 0, 0, max_height - v.shape[-3])) if v.shape[-3] < max_height else v
                for v in row
            ]
            row_tensors.append(torch.cat(row, dim=-2))  # side by side (width axis)
        return torch.cat(row_tensors, dim=-3)  # stack rows (height axis)

    def load_video(self, output_type, grid_max_columns, add_label=False, **kwargs):
        """Load all supported videos in ``kwargs['video']`` (sorted by name).

        Remaining kwargs are forwarded to VideoHelperSuite's ``load_video``.

        Raises:
            ImportError: if VideoHelperSuite is unavailable.
            ValueError: if the folder holds no supported videos or
                ``output_type`` is unknown.
        """
        if kwargs.get('video') and not os.path.isabs(kwargs['video']) and args.base_directory:
            kwargs['video'] = os.path.join(args.base_directory, kwargs['video'])

        if self.vhs_nodes is None:
            raise ImportError("This node requires ComfyUI-VideoHelperSuite to be installed.")

        folder = kwargs.pop('video')
        entries = []
        for f in sorted(os.listdir(folder)):
            path = os.path.join(folder, f)
            if os.path.isfile(path) and '.' in f and f.rsplit('.', 1)[-1].lower() in self.VIDEO_EXTENSIONS:
                entries.append((path, f))

        if not entries:
            raise ValueError(f"No video files ({', '.join(self.VIDEO_EXTENSIONS)}) found in: {folder}")

        loaded_videos = []
        for path, filename in entries:
            video_tensor = self.vhs_nodes.load_video_nodes.load_video(video=path, **kwargs)[0]
            if add_label:
                # Label is the filename without its extension.
                video_tensor = self._add_label(video_tensor, filename.rsplit('.', 1)[0])
            loaded_videos.append(video_tensor)

        if output_type == "batch":
            out_tensor = torch.cat(loaded_videos)
        elif output_type == "grid":
            out_tensor = self._build_grid(loaded_videos, grid_max_columns)
        else:
            raise ValueError(f"Unknown output_type: {output_type}")
        return out_tensor,

    @classmethod
    def IS_CHANGED(s, video, **kwargs):
        if s.vhs_nodes is not None:
            return s.vhs_nodes.utils.hash_path(video)
        return None


class EncodeVideoComponents(io.ComfyNode):
    """Decode a video frame by frame, resize each frame, and VAE-encode the result.

    Each frame is resized and cast to the VAE dtype as soon as it is decoded, so the
    full-resolution image batch is never held in memory. Audio is read in a second pass
    over the same source.
    """

    @classmethod
    def define_schema(cls):
        position_options = ["center", "top", "bottom", "left", "right"]
        options = [
            io.DynamicCombo.Option(key="stretch", inputs=[]),
            io.DynamicCombo.Option(key="resize", inputs=[]),
            io.DynamicCombo.Option(key="total_pixels", inputs=[]),
            io.DynamicCombo.Option(key="crop", inputs=[
                io.Combo.Input("crop_position", options=position_options, tooltip="Position to crop from."),
            ]),
            io.DynamicCombo.Option(key="pad", inputs=[
                io.String.Input("pad_color", default="0, 0, 0", tooltip="Color to use for padding."),
                io.Combo.Input("pad_position", options=position_options, tooltip="Position to align the image within the padded area."),
            ]),
            io.DynamicCombo.Option(key="pad_edge", inputs=[
                io.Combo.Input("pad_position", options=position_options, tooltip="Position to align the image within the padded area."),
            ]),
            io.DynamicCombo.Option(key="pad_edge_pixel", inputs=[
                io.Combo.Input("pad_position", options=position_options, tooltip="Position to align the image within the padded area."),
            ]),
            io.DynamicCombo.Option(key="pillarbox_blur", inputs=[
                io.Combo.Input("pad_position", options=position_options, tooltip="Position to align the image within the padded area."),
            ]),
        ]
        return io.Schema(
            node_id="EncodeVideoComponents",
            search_aliases=["video to latent", "encode video", "vae encode video"],
            display_name="Encode Video Components",
            category="KJNodes/image",
            description="Extracts video frames, resizes them, and encodes with a VAE directly, avoiding storing the full image tensor.",
            inputs=[
                io.Video.Input("video", tooltip="The video to extract and encode."),
                io.Vae.Input("vae", tooltip="The VAE model to use for encoding."),
                io.Int.Input("width", default=768, min=0, max=16384, step=2, tooltip="Target width for the frames before encoding. 0 = original width."),
                io.Int.Input("height", default=512, min=0, max=16384, step=2, tooltip="Target height for the frames before encoding. 0 = original height."),
                io.Int.Input("max_frames", default=0, min=0, max=999999, step=1, tooltip="Maximum number of frames. 0 = no limit."),
                io.Combo.Input("upscale_method", options=["nearest-exact", "bilinear", "area", "bicubic", "lanczos"], default="lanczos", tooltip="Interpolation method for resizing."),
                io.DynamicCombo.Input(
                    "keep_proportion",
                    options=options,
                    display_name="Keep Proportion",
                    tooltip="How to handle aspect ratio mismatch when resizing.",
                ),
            ],
            outputs=[
                io.Latent.Output(display_name="latent"),
                io.Audio.Output(display_name="audio"),
                io.Float.Output(display_name="fps"),
                io.Int.Output(display_name="frame_count", tooltip="Number pixel space frames after any possible cropping"),
            ],
        )

    @staticmethod
    def _compute_resize_params(mode, position, width, height, src_w, src_h):
        """Compute target resize dimensions, crop region, and padding from keep_proportion mode.

        Args:
            mode: One of "stretch", "resize", "total_pixels", "crop", "pad", "pad_edge",
                "pad_edge_pixel", "pillarbox_blur".
            position: Alignment used for crop/pad modes ("center", "top", "bottom", "left", "right").
            width, height: Requested output size; 0 means "use the source size".
            src_w, src_h: Source frame size.

        Returns:
            (width, height, crop_region, (pad_left, pad_right, pad_top, pad_bottom)).
            For pad/pillarbox modes, width/height are the size the frame is resized to *before*
            padding; the padding amounts bring it back up to the requested size.
            crop_region is (x, y, crop_w, crop_h) for "crop" mode, else None.
        """
        if width == 0:
            width = src_w
        if height == 0:
            height = src_h
        pillarbox_blur = mode == "pillarbox_blur"
        pad_left = pad_right = pad_top = pad_bottom = 0
        crop_region = None  # (x, y, crop_w, crop_h) or None

        if mode in ["resize", "total_pixels"] or mode.startswith("pad") or pillarbox_blur:
            if mode == "total_pixels":
                # Keep the source aspect ratio while matching the requested pixel count.
                total_pixels = width * height
                aspect_ratio = src_w / src_h
                new_height = int(math.sqrt(total_pixels / aspect_ratio))
                new_width = int(math.sqrt(total_pixels * aspect_ratio))
            else:
                # Fit inside the requested box without exceeding either dimension.
                ratio = min(width / src_w, height / src_h)
                new_width = round(src_w * ratio)
                new_height = round(src_h * ratio)

            if mode.startswith("pad") or pillarbox_blur:
                if position == "center":
                    pad_left = (width - new_width) // 2
                    pad_right = width - new_width - pad_left
                    pad_top = (height - new_height) // 2
                    pad_bottom = height - new_height - pad_top
                elif position == "top":
                    pad_left = (width - new_width) // 2
                    pad_right = width - new_width - pad_left
                    pad_top = 0
                    pad_bottom = height - new_height
                elif position == "bottom":
                    pad_left = (width - new_width) // 2
                    pad_right = width - new_width - pad_left
                    pad_top = height - new_height
                    pad_bottom = 0
                elif position == "left":
                    pad_left = 0
                    pad_right = width - new_width
                    pad_top = (height - new_height) // 2
                    pad_bottom = height - new_height - pad_top
                elif position == "right":
                    pad_left = width - new_width
                    pad_right = 0
                    pad_top = (height - new_height) // 2
                    pad_bottom = height - new_height - pad_top

            width = new_width
            height = new_height

        if mode == "crop":
            # Crop the source to the target aspect ratio first; the crop is then resized to width x height.
            old_aspect = src_w / src_h
            new_aspect = width / height
            if old_aspect > new_aspect:
                crop_w = round(src_h * new_aspect)
                crop_h = src_h
            else:
                crop_w = src_w
                crop_h = round(src_w / new_aspect)
            if position == "center":
                x = (src_w - crop_w) // 2
                y = (src_h - crop_h) // 2
            elif position == "top":
                x = (src_w - crop_w) // 2
                y = 0
            elif position == "bottom":
                x = (src_w - crop_w) // 2
                y = src_h - crop_h
            elif position == "left":
                x = 0
                y = (src_h - crop_h) // 2
            elif position == "right":
                x = src_w - crop_w
                y = (src_h - crop_h) // 2
            crop_region = (x, y, crop_w, crop_h)

        return width, height, crop_region, (pad_left, pad_right, pad_top, pad_bottom)

    @staticmethod
    def _trim_to_vae_temporal(s, vae):
        """Drop trailing frames so the frame count round-trips through the VAE's temporal compression.

        Uses ``vae.downscale_ratio[0]`` and ``vae.upscale_ratio[0]`` as frame-count mappings.
        If they are not callable/indexable (TypeError/IndexError), ``s`` is returned unchanged.
        """
        try:
            temporal_compress = vae.downscale_ratio[0]
            temporal_decompress = vae.upscale_ratio[0]
            valid_frames = temporal_decompress(temporal_compress(s.shape[0]))
        except (TypeError, IndexError):
            return s
        if valid_frames < s.shape[0]:
            logging.warning(f"[EncodeVideoComponents] Trimming {s.shape[0] - valid_frames} frames ({s.shape[0]} -> {valid_frames}) to match VAE temporal compression ratio")
            s = s[:valid_frames]
        return s

    @staticmethod
    def _extract_audio(source, start_time, duration):
        """Decode the last audio stream of ``source`` into an AUDIO dict, or return None.

        Audio is trimmed to begin at ``start_time`` seconds and, when ``duration`` is truthy,
        to ``duration`` seconds. Frames are resampled to planar float ('fltp').
        """
        import av
        import itertools

        # The video pass may have consumed an in-memory stream; rewind before reopening.
        if isinstance(source, BytesIO):
            source.seek(0)
        with av.open(source, mode='r') as container:
            if len(container.streams.audio) == 0:
                return None
            audio_stream = container.streams.audio[-1]
            sample_rate = audio_stream.sample_rate
            if start_time > 0:
                audio_start_pts = int(start_time / audio_stream.time_base)
                container.seek(audio_start_pts, stream=audio_stream)
            resample = av.audio.resampler.AudioResampler(format='fltp').resample
            aframes = itertools.chain.from_iterable(
                map(resample, container.decode(audio_stream))
            )

            audio_frames = []
            # Find the first frame that still contains audio at/after start_time and drop its
            # samples that precede start_time.
            for aframe in aframes:
                # Clamp at 0: a frame starting *after* start_time must be kept whole. A negative
                # value would slice from the end of the frame and drop its leading samples.
                to_skip = max(0, int((start_time - aframe.time) * sample_rate))
                if to_skip < aframe.samples:
                    audio_frames.append(aframe.to_ndarray()[..., to_skip:])
                    break
            if audio_frames:
                for aframe in aframes:
                    if duration and aframe.time > start_time + duration:
                        break
                    audio_frames.append(aframe.to_ndarray())
            if not audio_frames:
                return None

            audio_data = np.concatenate(audio_frames, axis=1)
            if duration:
                audio_data = audio_data[..., :int(duration * sample_rate)]
            return {
                "waveform": torch.from_numpy(audio_data).unsqueeze(0),
                "sample_rate": int(sample_rate) if sample_rate else 1,
            }

    @classmethod
    def execute(cls, video, vae, width, height, max_frames, upscale_method, keep_proportion) -> io.NodeOutput:
        """Decode, resize/crop/pad and VAE-encode ``video``; also extract its audio.

        Returns a NodeOutput of (latent dict, audio dict or None, fps as float,
        number of pixel-space frames that were encoded).

        Raises:
            ValueError: if no video frames were decoded in the requested range.
        """
        import av

        mode = keep_proportion["keep_proportion"]
        position = keep_proportion.get("crop_position") or keep_proportion.get("pad_position", "center")
        pad_color = keep_proportion.get("pad_color", "0, 0, 0")
        target_dtype = vae.vae_dtype

        # Access VideoFromFile internals (name-mangled private attributes) for efficient per-frame decode.
        # NOTE(review): relies on VideoFromFile implementation details; other video types may not expose them.
        source = video.get_stream_source()
        start_time = getattr(video, '_VideoFromFile__start_time', 0)
        duration = getattr(video, '_VideoFromFile__duration', 0)

        # Get frame count for progress bar, capped by max_frames
        try:
            total_frames = video.get_frame_count()
        except (ValueError, AttributeError):
            total_frames = 0
        if max_frames > 0 and total_frames > 0:
            total_frames = min(total_frames, max_frames)
        pbar = ProgressBar(total_frames) if total_frames > 0 else None

        # Lanczos requires PIL (CPU-only), all other methods use torch on GPU
        use_gpu = upscale_method != "lanczos"
        device = model_management.get_torch_device() if use_gpu else torch.device("cpu")

        # --- Decode video frames with per-frame resize + dtype cast ---
        with av.open(source, mode='r') as container:
            video_stream = container.streams.video[0]
            start_pts = int(start_time / video_stream.time_base)
            end_pts = int((start_time + duration) / video_stream.time_base) if duration else 0
            container.seek(start_pts, stream=video_stream)

            res_w, res_h, crop_region, padding = None, None, None, (0, 0, 0, 0)
            frames = []
            for frame in container.decode(video_stream):
                if frame.pts < start_pts:
                    continue
                if duration and frame.pts >= end_pts:
                    break
                if max_frames > 0 and len(frames) >= max_frames:
                    break

                # Resize parameters are derived once, from the first frame's size.
                if res_w is None:
                    src_h, src_w = frame.height, frame.width
                    res_w, res_h, crop_region, padding = cls._compute_resize_params(
                        mode, position, width, height, src_w, src_h
                    )

                # Decode to tensor and normalize to [0, 1]
                img = torch.from_numpy(frame.to_ndarray(format='rgb24')).to(device=device, dtype=torch.float32) / 255.0

                # Crop if needed (before resize)
                if crop_region is not None:
                    cx, cy, cw, ch = crop_region
                    img = img[cy:cy+ch, cx:cx+cw, :]

                # Resize (GPU for torch-native methods, CPU/PIL for lanczos), then move back to CPU in VAE dtype
                img = common_upscale(
                    img.unsqueeze(0).movedim(-1, 1), res_w, res_h, upscale_method, crop="disabled"
                ).movedim(1, -1).squeeze(0).to(dtype=target_dtype, device="cpu")

                frames.append(img)
                if pbar is not None:
                    pbar.update(1)

            frame_rate = video_stream.average_rate if video_stream.average_rate else 1

        if not frames:
            raise ValueError("[EncodeVideoComponents] No video frames were decoded from the input video.")
        s = torch.stack(frames)
        # torch.stack copies; drop the per-frame tensors so they are not kept alive through encoding.
        del frames

        # Pad logic (applied on the full stack since padding modes like pillarbox_blur need all frames)
        pillarbox_blur = mode == "pillarbox_blur"
        pad_left, pad_right, pad_top, pad_bottom = padding
        if (mode.startswith("pad") or pillarbox_blur) and (pad_left > 0 or pad_right > 0 or pad_top > 0 or pad_bottom > 0):
            pad_mode = (
                "pillarbox_blur" if pillarbox_blur else
                "edge" if mode == "pad_edge" else
                "edge_pixel" if mode == "pad_edge_pixel" else
                "color"
            )
            s, _ = ImagePadKJ.pad(None, s, pad_left, pad_right, pad_top, pad_bottom, 0, pad_color, pad_mode)

        # Trim frames to a count valid for the VAE's temporal compression
        s = cls._trim_to_vae_temporal(s, vae)

        t = vae.encode(s)

        # --- Extract audio in a separate pass ---
        audio = cls._extract_audio(source, start_time, duration)

        return io.NodeOutput({"samples": t}, audio, float(frame_rate), s.shape[0])


class DecodeAndSaveVideo(io.ComfyNode):
    """Output node: VAE-decode a video latent (and optional audio latent) and save it as a video file."""

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="DecodeAndSaveVideo",
            search_aliases=["video to latent", "decode video"],
            display_name="Decode and Save Video",
            category="KJNodes/image",
            description="Decodes video frames and audio from latent representations, combines them, and saves as a video file, without keeping intermediate images in memory.",
            inputs=[
                io.Latent.Input("video_latent", tooltip="The latent representation of the video frames."),
                io.Latent.Input("audio_latent", optional=True, tooltip="The latent representation of the audio frames."),
                io.Float.Input("fps", default=25.0, min=0.0, max=999.0, step=0.01, tooltip="Frame rate for the output video."),
                io.String.Input("filename_prefix", default="video/ComfyUI", tooltip="The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."),
                io.Combo.Input("format", options=Types.VideoContainer.as_input(), default="auto", tooltip="The format to save the video as."),
                io.Combo.Input("codec", options=Types.VideoCodec.as_input(), default="auto", tooltip="The codec to use for the video."),
                io.Vae.Input("video_vae", tooltip="The VAE model to use for encoding."),
                io.Vae.Input("audio_vae", optional=True, tooltip="The VAE model to use for decoding audio."),
                io.DynamicCombo.Input("tiling", options=[
                    io.DynamicCombo.Option(key="disabled", inputs=[]),
                    io.DynamicCombo.Option(key="enabled", inputs=[
                        io.Int.Input("tile_size", default=512, min=64, max=4096, step=32, tooltip="Size of the tiles to decode. Smaller tiles use less memory but take more time."),
                        io.Int.Input("overlap", default=64, min=0, max=4096, step=32, tooltip="Amount of overlap between tiles. Higher overlap can improve quality at the edges of tiles but uses more memory and takes more time."),
                        io.Int.Input("temporal_size", default=4096, min=8, max=4096, step=4, tooltip="Only used for video VAEs: Amount of frames to decode at a time. Higher value than number of frames = disabled"),
                        io.Int.Input("temporal_overlap", default=16, min=4, max=4096, step=4, tooltip="Only used for video VAEs: Amount of frames to overlap. Higher overlap can improve quality at the edges of temporal tiles but uses more memory and takes more time."),
                    ]),
                ]),
            ],
            hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
            is_output_node=True,
        )

    @staticmethod
    def _video_samples(samples_in):
        """Return the video part of a latent tensor.

        Nested latents are unbound and their first entry is used (``decode_audio`` takes the
        last entry for audio). Non-nested tensors are returned unchanged.
        """
        if samples_in.is_nested:
            return samples_in.unbind()[0]
        return samples_in

    @classmethod
    def execute(cls, video_latent, video_vae, filename_prefix, format, codec, tiling, audio_latent=None, audio_vae=None, fps=25.0) -> io.NodeOutput:
        """Decode the latents, mux video (and audio, if given) and save to the output directory.

        Raises:
            ValueError: if ``audio_latent`` is given without ``audio_vae``.
        """
        if tiling["tiling"] == "enabled":
            tile_size = tiling["tile_size"]
            overlap = tiling["overlap"]
            temporal_size = tiling["temporal_size"]
            temporal_overlap = tiling["temporal_overlap"]

            if tile_size < overlap * 4:
                overlap = tile_size // 4
            if temporal_size < temporal_overlap * 2:
                temporal_overlap = temporal_overlap // 2
            temporal_compression = video_vae.temporal_compression_decode()
            if temporal_compression is not None:
                # Convert pixel-frame counts into latent-frame counts.
                temporal_size = max(2, temporal_size // temporal_compression)
                temporal_overlap = max(1, min(temporal_size // 2, temporal_overlap // temporal_compression))
            else:
                temporal_size = None
                temporal_overlap = None

            compression = video_vae.spacial_compression_decode()

            # Without a temporal compression ratio there is no temporal tiling: use a temporal
            # tile larger than any practical length (decode_tiled's default) instead of
            # calling max(2, None), which raises TypeError.
            tile_t = temporal_size if temporal_size is not None else 999
            images = cls.decode_tiled(video_vae, cls._video_samples(video_latent["samples"]),
                                      tile_t=tile_t,
                                      tile_x=tile_size // compression,
                                      tile_y=tile_size // compression,
                                      overlap=(temporal_overlap if temporal_overlap is not None else 1, max(1, overlap // compression), max(1, overlap // compression)),
            ).movedim(1, -1)
            if len(images.shape) == 5: #Combine batches
                images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1])
        else:
            images = cls.decode_video(video_vae, video_latent)

        if audio_latent is not None:
            if audio_vae is None:
                raise ValueError("Audio VAE must be provided if audio latent is provided.")
            audio = cls.decode_audio(audio_latent, audio_vae)
        else:
            audio = None

        video = InputImpl.VideoFromComponents(Types.VideoComponents(images=images, audio=audio, frame_rate=Fraction(fps)))
        file, subfolder = cls.save_video(video, filename_prefix, format, codec)

        return io.NodeOutput(ui=ui.PreviewVideo([ui.SavedResult(file, subfolder, io.FolderType.output)]))

    @classmethod
    def decode_video(cls, vae, samples):
        """Decode a LATENT dict with ``vae``, batching to fit free memory.

        Falls back to tiled decoding when the regular decode runs out of memory. Returns the
        decoded pixels with the channel dim moved last; 5-D outputs are flattened into a
        4-D batch of frames.

        Raises:
            ValueError: if decoding produced no output (e.g. empty latent or unsupported rank).
        """
        samples_in = cls._video_samples(samples["samples"])

        vae.throw_exception_if_invalid()
        pixel_samples = None
        do_tile = False
        if vae.latent_dim == 2 and samples_in.ndim == 5:
            samples_in = samples_in[:, :, 0]
        try:
            memory_used = vae.memory_used_decode(samples_in.shape, vae.vae_dtype)
            model_management.load_models_gpu([vae.patcher], memory_required=memory_used, force_full_load=True)
            free_memory = vae.patcher.get_free_memory(vae.device)
            batch_number = max(1, int(free_memory / memory_used))

            for x in range(0, samples_in.shape[0], batch_number):
                samples = samples_in[x:x+batch_number].to(vae.vae_dtype).to(vae.device)
                # Outputs are kept in float16 to limit memory of the decoded video.
                out = vae.process_output(vae.first_stage_model.decode(samples).to(vae.output_device).to(torch.float16))
                if pixel_samples is None:
                    pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=vae.output_device, dtype=out.dtype)
                pixel_samples[x:x+batch_number] = out
        except Exception as e:
            model_management.raise_non_oom(e)
            logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
            do_tile = True

        if do_tile:
            dims = samples_in.ndim - 2
            # extra_1d_channel is a property of the VAE, not of this node class.
            if dims == 1 or vae.extra_1d_channel is not None:
                pixel_samples = vae.decode_tiled_1d(samples_in)
            elif dims == 2:
                pixel_samples = vae.decode_tiled_2d(samples_in)
            elif dims == 3:
                tile = 256 // vae.spacial_compression_decode()
                overlap = tile // 4
                pixel_samples = vae.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))

        if pixel_samples is None:
            raise ValueError(f"VAE decoding produced no output for latent of shape {tuple(samples_in.shape)}.")

        pixel_samples = pixel_samples.to(vae.output_device).movedim(1,-1)

        if len(pixel_samples.shape) == 5: #Combine batches
            pixel_samples = pixel_samples.reshape(-1, pixel_samples.shape[-3], pixel_samples.shape[-2], pixel_samples.shape[-1])
        return pixel_samples

    @classmethod
    def decode_tiled(cls, vae, samples, tile_t=999, tile_x=32, tile_y=32, overlap=(1, 8, 8)):
        """Decode ``samples`` in (t, x, y) tiles via ``tiled_scale_multidim``, casting tiles to float16.

        ``overlap`` is a 3-tuple matching the tile dims. The result is channels-first;
        callers move the channel dim themselves.
        """
        vae.throw_exception_if_invalid()
        memory_used = vae.memory_used_decode(samples.shape, vae.vae_dtype)
        model_management.load_models_gpu([vae.patcher], memory_required=memory_used, force_full_load=vae.disable_offload)
        decode_fn = lambda a: vae.first_stage_model.decode(a.to(vae.vae_dtype).to(vae.device)).to(torch.float16)
        return vae.process_output(tiled_scale_multidim(samples, decode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap,
                                                       upscale_amount=vae.upscale_ratio, out_channels=vae.output_channels, index_formulas=vae.upscale_index_formula, output_device=vae.output_device))


    @classmethod
    def decode_audio(cls, samples, audio_vae):
        """Decode an audio LATENT dict into an AUDIO dict ``{"waveform", "sample_rate"}``."""
        audio_latent = samples["samples"]
        if audio_latent.is_nested:
            audio_latent = audio_latent.unbind()[-1]
        audio = audio_vae.decode(audio_latent)
        # Post-PR #13486: audio_vae is a comfy.sd.VAE wrapper returning channels-last (BTC).
        # Pre-PR: audio_vae is a raw AudioVAE returning channels-first (BCT).
        if hasattr(audio_vae, "first_stage_model"):
            audio = audio.movedim(-1, 1)
        audio = audio.to(audio_latent.device)
        output_audio_sample_rate = getattr(
            audio_vae,
            "audio_sample_rate_output",
            getattr(audio_vae, "output_sample_rate", None),
        )
        if output_audio_sample_rate is None:
            output_audio_sample_rate = getattr(
                getattr(audio_vae, "first_stage_model", None), "output_sample_rate", 44100
            )
        return {"waveform": audio, "sample_rate": int(output_audio_sample_rate)}

    @classmethod
    def save_video(cls, video, filename_prefix, format, codec) -> tuple[str, str]:
        """Save ``video`` into the output directory and return ``(file_name, subfolder)``.

        Workflow/prompt metadata is embedded unless metadata is disabled on the command line.
        """
        width, height = video.get_dimensions()
        full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(
            filename_prefix,
            folder_paths.get_output_directory(),
            width,
            height
        )
        saved_metadata = None
        if not args.disable_metadata:
            metadata = {}
            if cls.hidden.extra_pnginfo is not None:
                metadata.update(cls.hidden.extra_pnginfo)
            if cls.hidden.prompt is not None:
                metadata["prompt"] = cls.hidden.prompt
            if len(metadata) > 0:
                saved_metadata = metadata
        file = f"{filename}_{counter:05}_.{Types.VideoContainer.get_extension(format)}"
        video.save_to(
            os.path.join(full_output_folder, file),
            format=Types.VideoContainer(format),
            codec=codec,
            metadata=saved_metadata
        )
        return file, subfolder


class PreviewImageOrMask(io.ComfyNode):
    """Output node that shows an IMAGE or MASK tensor as a UI preview."""

    @classmethod
    def define_schema(cls):
        preview_input = io.MultiType.Input("input", [io.Image, io.Mask], tooltip="The image or mask to preview.")
        hidden_inputs = [io.Hidden.prompt, io.Hidden.extra_pnginfo]
        return io.Schema(
            node_id="PreviewImageOrMask",
            display_name="Preview Image Or Mask",
            category="KJNodes/misc",
            description="Previews the input images or masks.",
            search_aliases=["output"],
            inputs=[preview_input],
            hidden=hidden_inputs,
            is_output_node=True,
        )

    @classmethod
    def execute(cls, input) -> io.NodeOutput:
        # Rank-3 tensors are previewed as masks; every other rank is previewed as an image.
        if input.ndim != 3:
            preview = ui.PreviewImage(input, cls=cls)
        else:
            preview = ui.PreviewMask(input, cls=cls)
        return io.NodeOutput(ui=preview)

