import torch
import comfy.model_management
import comfy.utils
import folder_paths
import os
import logging
from tqdm import tqdm
import numpy as np
from comfy_api.latest import io

# Torch device used for the SVD / matmul work below; resolved once at import time
# through ComfyUI's model management.
device = comfy.model_management.get_torch_device()

# Quantile used by extract_lora() when clamp_quantile=True: both LoRA factors are
# clamped to +/- this quantile of their combined values.
CLAMP_QUANTILE = 0.99

def _adaptive_lora_rank(S, lora_type, adaptive_param):
    """Choose a LoRA rank from descending singular values ``S`` for an adaptive ``lora_type``.

    Selection rules follow LyCORIS' extract_locon.py. The returned rank is NOT yet capped
    by the caller's maximum rank.

    Raises:
        ValueError: if ``lora_type`` is not one of the known adaptive modes.
    """
    if lora_type == "adaptive_ratio":
        # Keep singular values strictly larger than adaptive_param * max(S).
        min_s = torch.max(S) * adaptive_param
        return int(torch.sum(S > min_s).item())
    if lora_type == "adaptive_energy":
        # Smallest rank whose cumulative energy (sum of S^2) reaches adaptive_param of the total,
        # e.g. adaptive_param=0.95 keeps 95% of the energy.
        energy = torch.cumsum(S ** 2, dim=0)
        threshold = adaptive_param * torch.sum(S ** 2)
        return int(torch.sum(energy < threshold).item()) + 1
    if lora_type == "adaptive_quantile":
        # Number of singular values whose running sum stays below adaptive_param of sum(S).
        s_cum = torch.cumsum(S, dim=0)
        min_cum_sum = adaptive_param * torch.sum(S)
        return int(torch.sum(s_cum < min_cum_sum).item())
    if lora_type == "adaptive_fro":
        # Smallest rank retaining adaptive_param of the Frobenius norm.
        S_squared = S.pow(2)
        S_fro_sq = float(torch.sum(S_squared))
        if S_fro_sq == 0.0:
            # All-zero difference: nothing to retain, and normalising would divide 0 by 0.
            return 1
        sum_S_squared = torch.cumsum(S_squared, dim=0) / S_fro_sq
        lora_rank = int(torch.searchsorted(sum_S_squared, adaptive_param ** 2)) + 1
        return max(1, min(lora_rank, len(S)))
    raise ValueError(f"Unknown adaptive lora_type: {lora_type}")


def extract_lora(diff, key, rank, algorithm, lora_type, lowrank_iters=7, adaptive_param=1.0, clamp_quantile=True):
    """
    Extract LoRA (up, down) factors from a weight difference tensor using SVD.

    Args:
        diff: Weight difference, either 2D (out, in) or 4D conv (out, in, kh, kw).
        key: Parameter name; only used in log messages.
        rank: Rank for "standard" extraction, or the upper bound for adaptive modes.
        algorithm: "svd_lowrank" (randomized, top-k only) or anything else for a full
            torch.linalg.svd.
        lora_type: "standard" or one of "adaptive_ratio", "adaptive_energy",
            "adaptive_quantile", "adaptive_fro". Adaptive modes are only honoured by
            the full-SVD path.
        lowrank_iters: Subspace iterations passed to torch.svd_lowrank.
        adaptive_param: Threshold used by the adaptive rank-selection modes.
        clamp_quantile: If True, clamp both factors to +/- the CLAMP_QUANTILE quantile of
            their combined values (estimated from a 100,000-element random sample for
            large factors).

    Returns:
        (up, down). The singular values are folded into ``up``. For 2D input ``up`` is
        (out, r) and ``down`` is (r, in); for 4D input ``up`` is (out, r, 1, 1) and
        ``down`` is (r, in, kh, kw).

    Raises:
        ValueError: if ``lora_type`` contains "adaptive" but is not a known mode.
    """
    conv2d = diff.ndim == 4
    kernel_size = diff.shape[2:4] if conv2d else None
    out_dim, in_dim = diff.shape[0:2]

    if conv2d:
        # (out, in, kh, kw) -> (out, in*kh*kw). Unlike squeeze(), this keeps both matrix
        # axes even when out_dim or in_dim is 1 (e.g. a single-channel 1x1 conv).
        diff = diff.reshape(out_dim, -1)

    diff_float = diff.float()
    if algorithm == "svd_lowrank":
        # lora_rank must be bound on this path too: the conv reshape below relies on it.
        lora_rank = min(rank, in_dim, out_dim)
        U, S, V = torch.svd_lowrank(diff_float, q=lora_rank, niter=lowrank_iters)
        U = U @ torch.diag(S)
        Vh = V.t()
    else:
        # Reduced SVD: at most min(out, in) components are ever kept, so the full square
        # U / Vh matrices would only waste memory.
        U, S, Vh = torch.linalg.svd(diff_float, full_matrices=False)
        # Flexible rank selection like locon: https://github.com/KohakuBlueleaf/LyCORIS/blob/main/tools/extract_locon.py
        if "adaptive" in lora_type:
            # Cap the adaptive rank by the specified max rank.
            lora_rank = min(_adaptive_lora_rank(S, lora_type, adaptive_param), rank)

            if lora_type == "adaptive_fro":
                # Report the Frobenius share actually retained after capping.
                S_squared = S.pow(2)
                s_fro = torch.sqrt(torch.sum(S_squared))
                s_red_fro = torch.sqrt(torch.sum(S_squared[:lora_rank]))
                fro_percent = float(s_red_fro / s_fro)
                logging.info(f"{key} Extracted LoRA rank: {lora_rank}, Frobenius retained: {fro_percent:.1%}")
            else:
                logging.info(f"{key} Extracted LoRA rank: {lora_rank}")
        else:
            lora_rank = rank

        lora_rank = max(1, lora_rank)
        lora_rank = min(out_dim, in_dim, lora_rank)

        U = U[:, :lora_rank] @ torch.diag(S[:lora_rank])
        Vh = Vh[:lora_rank, :]

    if clamp_quantile:
        dist = torch.cat([U.flatten(), Vh.flatten()])
        if dist.numel() > 100_000:
            # Sample 100,000 elements for quantile estimation (torch.quantile is slow/limited on huge inputs).
            idx = torch.randperm(dist.numel(), device=dist.device)[:100_000]
            hi_val = torch.quantile(dist[idx], CLAMP_QUANTILE)
        else:
            hi_val = torch.quantile(dist, CLAMP_QUANTILE)
        low_val = -hi_val

        U = U.clamp(low_val, hi_val)
        Vh = Vh.clamp(low_val, hi_val)

    if conv2d:
        U = U.reshape(out_dim, lora_rank, 1, 1)
        Vh = Vh.reshape(lora_rank, in_dim, kernel_size[0], kernel_size[1])
    return (U, Vh)


def calc_lora_model(model_diff, rank, prefix_model, prefix_lora, output_sd, lora_type, algorithm, lowrank_iters, out_dtype, bias_diff=False, adaptive_param=1.0, clamp_quantile=True):
    """
    Extract LoRA (or full-difference) tensors for every weight of a patched model.

    ``model_diff`` is expected to be a patcher whose patches turn each weight into
    (finetuned - original). Keys are patched and fetched one at a time with
    ``patch_weight_to_device(k, return_weight=True)``, so only one weight is
    materialised at once.

    Args:
        model_diff: Patcher exposing ``.model`` (a torch module) and ``patch_weight_to_device``.
        rank: Rank (or max rank for adaptive types) passed to ``extract_lora``.
        prefix_model: Only keys starting with this prefix are processed, and the prefix is
            stripped from output names. ``None`` processes every key.
        prefix_lora: Prefix prepended to every output key.
        output_sd: Dict that receives the output tensors (mutated and returned).
        lora_type: "full" stores raw diffs; anything else extracts lora_up/lora_down.
        algorithm, lowrank_iters, adaptive_param, clamp_quantile: Forwarded to ``extract_lora``.
        out_dtype: dtype every stored tensor is cast to.
        bias_diff: Also store ".bias" diffs (as ``.diff_b``) and, for non-"full" types,
            1D weights (as ``.diff``).

    Returns:
        ``output_sd`` with CPU tensors added.
    """
    prefix_len = len(prefix_model) if prefix_model else 0

    # Get key names from the module structure without materializing weights.
    sd_keys = []
    for named in (model_diff.model.named_parameters(), model_diff.model.named_buffers()):
        sd_keys.extend(name for name, _ in named if prefix_model is None or name.startswith(prefix_model))

    def out_name(k, suffix):
        # Strip the model prefix and the trailing ".weight" / ".bias".
        return prefix_lora + k[prefix_len:len(k) - len(suffix)]

    def store(name, tensor):
        output_sd[name] = tensor.contiguous().to(out_dtype).cpu()

    def process_weight(k):
        # Patch and retrieve a single weight; the local goes out of scope on return.
        weight_diff = model_diff.patch_weight_to_device(k, return_weight=True)
        if weight_diff is None:
            return
        if weight_diff.ndim == 5:
            logging.info(f"Skipping 5D tensor for key {k}")
            return
        base = out_name(k, ".weight")
        if lora_type == "full":
            store(f"{base}.diff", weight_diff)
            return
        if weight_diff.ndim < 2:
            # 1D weights cannot be factorised; keep them as plain diffs only when bias_diff is on.
            if bias_diff:
                store(f"{base}.diff", weight_diff)
            return
        try:
            up, down = extract_lora(weight_diff.to(device), k, rank, algorithm, lora_type, lowrank_iters=lowrank_iters, adaptive_param=adaptive_param, clamp_quantile=clamp_quantile)
        except Exception as e:
            # Best effort: a failing layer is logged and left out rather than aborting the extraction.
            logging.warning(f"Could not generate lora weights for key {k}, error {e}")
            return
        store(f"{base}.lora_up.weight", up)
        store(f"{base}.lora_down.weight", down)

    def process_bias(k):
        bias = model_diff.patch_weight_to_device(k, return_weight=True)
        if bias is not None:
            store(f"{out_name(k, '.bias')}.diff_b", bias)

    total_keys = sum(1 for k in sd_keys if k.endswith(".weight") or (bias_diff and k.endswith(".bias")))
    progress_bar = tqdm(total=total_keys, desc=f"Extracting LoRA ({prefix_lora.strip('.')})")
    comfy_pbar = comfy.utils.ProgressBar(total_keys)

    # Process one weight at a time to minimize memory usage.
    try:
        for k in sd_keys:
            if k.endswith(".weight"):
                process_weight(k)
            elif bias_diff and k.endswith(".bias"):
                process_bias(k)
            else:
                continue
            progress_bar.update(1)
            comfy_pbar.update(1)
    finally:
        progress_bar.close()

    comfy.model_management.soft_empty_cache()
    return output_sd


class LoraExtractKJ(io.ComfyNode):
    """Output node that extracts a LoRA (or full weight diff) from a finetuned MODEL/CLIP
    relative to its original and saves it as a .safetensors file."""

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="LoraExtractKJ",
            category="KJNodes/lora",
            is_output_node=True,
            inputs=[
                io.MultiType.Input("finetuned", [io.Model, io.Clip], tooltip="The finetuned model or clip to extract LoRA from."),
                io.MultiType.Input("original", [io.Model, io.Clip], tooltip="The original base model or clip to diff against."),
                io.String.Input("filename_prefix", default="loras/ComfyUI_extracted_lora"),
                io.Int.Input("rank", default=64, min=1, max=4096, step=1, tooltip="The rank to use for standard LoRA, or maximum rank limit for adaptive methods."),
                io.Combo.Input("lora_type", options=["standard", "full", "adaptive_ratio", "adaptive_quantile", "adaptive_energy", "adaptive_fro"]),
                io.Combo.Input("algorithm", options=["svd_linalg", "svd_lowrank"], default="svd_lowrank", tooltip="SVD algorithm to use, svd_lowrank is faster but less accurate."),
                io.Int.Input("lowrank_iters", default=7, min=1, max=100, step=1, tooltip="The number of subspace iterations for lowrank SVD algorithm."),
                io.Combo.Input("output_dtype", options=["fp16", "bf16", "fp32"], default="fp16"),
                io.Boolean.Input("bias_diff", default=True),
                io.Float.Input("adaptive_param", default=0.15, min=0.0, max=1.0, step=0.01, tooltip="For ratio mode, this is the ratio of the maximum singular value. For quantile mode, this is the quantile of the singular values. For fro mode, this is the Frobenius norm retention ratio."),
                io.Boolean.Input("clamp_quantile", default=False),
            ],
        )

    @classmethod
    def execute(cls, finetuned, original, filename_prefix, rank, lora_type, algorithm, lowrank_iters, output_dtype, bias_diff, adaptive_param, clamp_quantile) -> io.NodeOutput:
        """Compute (finetuned - original), extract LoRA tensors and save them.

        Raises:
            ValueError: if svd_lowrank is combined with an adaptive lora_type, or if
                finetuned and original are not the same kind of input (MODEL vs CLIP).
        """
        # Adaptive rank selection needs the full singular spectrum. "standard" works with
        # svd_lowrank, and "full" never runs an SVD, so neither should be rejected here.
        if algorithm == "svd_lowrank" and "adaptive" in lora_type:
            raise ValueError("svd_lowrank algorithm does not support adaptive LoRA types; use svd_linalg.")

        # CLIP objects wrap their model in a `.patcher`; diffing a CLIP against a MODEL is meaningless.
        is_clip = hasattr(finetuned, "patcher")
        if is_clip != hasattr(original, "patcher"):
            raise ValueError("finetuned and original must both be MODEL inputs or both be CLIP inputs.")

        dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[output_dtype]

        output_dir = folder_paths.get_output_directory()
        full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, output_dir)

        output_sd = {}

        if is_clip:
            clip_diff = finetuned.clone()
            kp = original.get_key_patches()
            kp = {k: v for k, v in kp.items() if not k.endswith(".position_ids") and not k.endswith(".logit_scale")}
            # Subtract the original weights (strength -1.0) so each patched weight becomes the diff.
            clip_diff.add_patches(kp, -1.0, 1.0)
            output_sd = calc_lora_model(clip_diff.patcher, rank, "", "text_encoders.", output_sd, lora_type, algorithm, lowrank_iters, dtype, bias_diff=bias_diff, adaptive_param=adaptive_param, clamp_quantile=clamp_quantile)
        else:
            m = finetuned.clone()
            kp = original.get_key_patches("diffusion_model.")
            m.add_patches(kp, -1.0, 1.0)
            output_sd = calc_lora_model(m, rank, "diffusion_model.", "diffusion_model.", output_sd, lora_type, algorithm, lowrank_iters, dtype, bias_diff=bias_diff, adaptive_param=adaptive_param, clamp_quantile=clamp_quantile)

        if "adaptive" in lora_type:
            rank_str = f"{lora_type}_{adaptive_param:.2f}"
        else:
            rank_str = rank
        output_checkpoint = f"{filename}_rank_{rank_str}_{output_dtype}_{counter:05}_.safetensors"
        output_checkpoint = os.path.join(full_output_folder, output_checkpoint)

        comfy.utils.save_torch_file(output_sd, output_checkpoint, metadata=None)
        return io.NodeOutput()

class LoraReduceRank(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="LoraReduceRankKJ",
            display_name="LoraReduceRank",
            category="KJNodes/lora",
            description="Resize a LoRA model by reducing its rank. Based on kohya's sd-scripts: https://github.com/kohya-ss/sd-scripts/blob/main/networks/resize_lora.py",
            is_output_node=True,
            is_experimental=True,
            inputs=[
                io.Combo.Input("lora_name", options=folder_paths.get_filename_list("loras"), tooltip="The name of the LoRA."),
                io.Int.Input("new_rank", default=8, min=1, max=4096, step=1, tooltip="The new rank to resize the LoRA. Acts as max rank when using dynamic_method."),
                io.Combo.Input("dynamic_method", options=["disabled", "sv_ratio", "sv_cumulative", "sv_fro", "sv_knee"], default="disabled", tooltip="Method to use for dynamically determining new alphas and dims. sv_knee finds the elbow point in the singular value curve."),
                io.Float.Input("dynamic_param", default=0.2, min=0.0, max=2.0, step=0.01, tooltip="Parameter for dynamic methods. For sv_knee: sensitivity (1.0=standard knee, <1.0=more aggressive/lower rank, >1.0=more conservative)."),
                io.Combo.Input("output_dtype", options=["match_original", "fp16", "bf16", "fp32"], default="match_original", tooltip="Data type to save the LoRA as."),
                io.Boolean.Input("verbose", default=True),
            ],
        )

    @classmethod
    def execute(cls, lora_name, new_rank, dynamic_method, dynamic_param, output_dtype, verbose) -> io.NodeOutput:
        """Load ``lora_name``, resize its rank and save it under the output "loras/" folder.

        Raises:
            ValueError: if the file contains no recognised LoRA down/up weight pairs.
        """
        lora_path = folder_paths.get_full_path("loras", lora_name)
        lora_sd, metadata = comfy.utils.load_torch_file(lora_path, return_metadata=True)

        explicit_dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp32": torch.float32}
        if output_dtype in explicit_dtypes:
            save_dtype = explicit_dtypes[output_dtype]
        else:
            # "match_original": use the dtype of the first ".weight" tensor. Fall back to any
            # floating-point tensor (ControlLoRA-style "down"/"up" keys have no ".weight"
            # suffix), then to fp16, instead of failing with a bare StopIteration.
            reference = next((v for k, v in lora_sd.items() if k.endswith(".weight") and isinstance(v, torch.Tensor)), None)
            if reference is None:
                reference = next((v for v in lora_sd.values() if isinstance(v, torch.Tensor) and v.is_floating_point()), None)
            save_dtype = reference.dtype if reference is not None else torch.float16

        # Strip ".default" from key names (presumably a PEFT adapter-name segment) so keys
        # match the LoRA formats resize_lora_model recognises.
        new_lora_sd = {k.replace(".default", ""): v for k, v in lora_sd.items()}
        del lora_sd
        logging.info("Resizing Lora...")
        output_sd, old_dim, new_alpha, rank_list = resize_lora_model(new_lora_sd, new_rank, save_dtype, device, dynamic_method, dynamic_param, verbose)
        if old_dim is None:
            raise ValueError(f"No LoRA down/up weight pairs found in {lora_name}")

        if metadata is None:
            metadata = {}

        comment = metadata.get("ss_training_comment", "")

        if dynamic_method == "disabled":
            metadata["ss_training_comment"] = f"dimension is resized from {old_dim} to {new_rank}; {comment}"
            metadata["ss_network_dim"] = str(new_rank)
            metadata["ss_network_alpha"] = str(new_alpha)
        else:
            metadata["ss_training_comment"] = f"Dynamic resize with {dynamic_method}: {dynamic_param} from {old_dim}; {comment}"
            metadata["ss_network_dim"] = "Dynamic"
            metadata["ss_network_alpha"] = "Dynamic"

        # Tensors that were not resized (e.g. skipped layers) still need the target dtype.
        for key, value in list(output_sd.items()):
            if isinstance(value, torch.Tensor) and value.dtype.is_floating_point and value.dtype != save_dtype:
                output_sd[key] = value.to(save_dtype)

        output_dir = folder_paths.get_output_directory()
        output_filename_prefix = "loras/" + lora_name

        full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(output_filename_prefix, output_dir)
        output_dtype_str = f"_{output_dtype}" if output_dtype != "match_original" else ""
        # rank_list is empty when every pair was skipped (too large); ranks are then unchanged.
        average_rank = str(int(np.mean(rank_list))) if rank_list else str(old_dim)
        rank_str = new_rank if dynamic_method == "disabled" else f"dynamic_{average_rank}"
        output_checkpoint = f"{filename.replace('.safetensors', '')}_resized_from_{old_dim}_to_{rank_str}{output_dtype_str}_{counter:05}_.safetensors"
        output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
        logging.info(f"Saving resized LoRA to {output_checkpoint}")

        comfy.utils.save_torch_file(output_sd, output_checkpoint, metadata=metadata)
        return io.NodeOutput()

# Convert a LoRA to a different-rank approximation (intended only for reducing the rank).
# This code is based on the extract_lora_from_models.py file, which is in turn based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py
# Thanks to cloneofsimo

# This version is based on
# https://github.com/kohya-ss/sd-scripts/blob/main/networks/resize_lora.py

# Threshold for treating a matrix as zero: rank_resize() falls back to rank 1 when the
# largest singular value does not exceed it.
MIN_SV = 1e-6

# (down, up) key-name pairs recognised when looking for LoRA factor tensors. Entries are
# tried in order against the last one or two dot-separated parts of a state-dict key.
LORA_DOWN_UP_FORMATS = [
    ("lora_down", "lora_up"),  # sd-scripts LoRA
    ("lora_A", "lora_B"),  # PEFT LoRA
    ("down", "up"),  # ControlLoRA
]

# Indexing functions
def index_sv_cumulative(S, target):
    """Return how many leading singular values are needed for their share of sum(S) to reach ``target``.

    The result is clamped to the range [1, len(S) - 1].
    """
    share = torch.cumsum(S, dim=0) / float(torch.sum(S))
    first_hit = int(torch.searchsorted(share, target))
    return max(1, min(first_hit + 1, len(S) - 1))


def index_sv_fro(S, target):
    """Return how many leading singular values retain a ``target`` fraction of the Frobenius norm.

    Works on squared values: finds where the cumulative share of sum(S^2) reaches
    target^2. The result is clamped to the range [1, len(S) - 1].
    """
    squares = S.pow(2)
    energy_share = torch.cumsum(squares, dim=0) / float(torch.sum(squares))
    first_hit = int(torch.searchsorted(energy_share, target**2))
    return max(1, min(first_hit + 1, len(S) - 1))


def index_sv_ratio(S, target):
    """Count the singular values larger than S[0] / target, clamped to [1, len(S) - 1]."""
    cutoff = S[0] / target
    above = int((S > cutoff).sum().item())
    return max(1, min(above, len(S) - 1))


def index_sv_knee(S, sensitivity=1.0):
    """Find the knee/elbow point in the singular value curve (Kneedle-style).

    Both axes are normalised to [0, 1]: x maps positions 0..n-1 linearly onto [0, 1],
    and y rescales S so that S[0] maps to ~1 and S[-1] maps to 0. The result is the
    index with the largest signed distance ``y - y_line`` to the reference line
    ``y_line = 1 - sensitivity * x``, i.e. the line from (0, 1) to (1, 1 - sensitivity).

    The node tooltip describes sensitivity as:
      1.0 = standard knee point (reference line from (0, 1) to (1, 0))
      < 1.0 = more aggressive (lower rank)
      > 1.0 = more conservative (higher rank)

    NOTE(review): if the normalised curve is convex (on or below the chord from (0, 1)
    to (1, 0), which is typical for decaying singular values), then
    y - y_line <= (sensitivity - 1) * x everywhere. The maximum then sits at an endpoint
    rather than an interior elbow: roughly index 0 (clamped to 1) for sensitivity < 1,
    n - 1 for sensitivity > 1, and either end for sensitivity == 1 depending on
    rounding. Confirm the sign of the distance is what was intended.

    Args:
        S: 1D tensor of singular values, presumably sorted in descending order (as
            produced by torch.linalg.svd) — TODO confirm for other callers.
        sensitivity: Slope of the reference line (see above).

    Returns:
        int: index clamped to [1, n - 1], or 1 when len(S) <= 2. rank_resize adds 1
        to turn it into a rank.
    """
    n = len(S)
    if n <= 2:
        return 1

    S_np = S.cpu().float().numpy()

    # Normalize x and y to [0, 1]; the 1e-10 avoids division by zero for a flat curve
    x = np.linspace(0, 1, n)
    y = (S_np - S_np[-1]) / (S_np[0] - S_np[-1] + 1e-10)

    # Reference line from (0, 1) to (1, 1-sensitivity)
    # sensitivity=1.0: (0,1)->(1,0), the standard kneedle chord
    # sensitivity=0.5: (0,1)->(1,0.5), a shallower line
    y_line = 1.0 - sensitivity * x

    # Signed distance: positive means the curve is above the line
    distances = y - y_line

    # Index of the largest distance (np.argmax picks the first index on ties)
    index = int(np.argmax(distances))
    index = max(1, min(index, n - 1))
    return index


# Modified from Kohaku-blueleaf's extract/merge functions
def _svd_extract(weight_2d, lora_rank, dynamic_method, dynamic_param, device, scale=1):
    """Shared SVD extraction for both linear and conv weights.
    Dynamic mode: single full SVD (need all singular values for rank selection).
    Disabled mode: svd_lowrank only (rank is known, much faster for large matrices)."""
    weight_2d = weight_2d.to(device)
    if weight_2d.dtype != torch.float32:
        weight_2d = weight_2d.float()

    if dynamic_method and dynamic_method != "disabled":
        # Full SVD: we need all singular values for dynamic rank selection,
        # and we reuse U/Vh directly — one SVD, no wasted work
        U, S, Vh = torch.linalg.svd(weight_2d, full_matrices=False)
        param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale)
        lora_rank = param_dict["new_rank"]
        U = U[:, :lora_rank]
        S = S[:lora_rank]
        Vh = Vh[:lora_rank, :]
    else:
        # Randomized lowrank SVD: only compute top-k, much faster when rank << min(m,n)
        U, S, V = torch.svd_lowrank(weight_2d, q=lora_rank, niter=7)
        Vh = V.t()
        param_dict = {"new_rank": lora_rank, "new_alpha": float(scale * lora_rank)}
        del V

    sqrt_S = torch.diag(torch.sqrt(S))
    U = U @ sqrt_S
    Vh = sqrt_S @ Vh

    return U, Vh, param_dict


def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1):
    """Re-factor a 4D conv weight (out, in, kh, kw) into LoRA down/up tensors.

    The weight is flattened to (out, in*kh*kw) and factored by _svd_extract. Non-square
    kernels (kh != kw) are supported.

    Returns:
        The param_dict from _svd_extract, extended with "lora_down" of shape
        (new_rank, in, kh, kw) and "lora_up" of shape (out, new_rank, 1, 1), both on CPU.
    """
    out_size, in_size, kernel_h, kernel_w = weight.size()

    U, Vh, param_dict = _svd_extract(
        weight.reshape(out_size, -1), lora_rank, dynamic_method, dynamic_param, device, scale
    )
    new_rank = param_dict["new_rank"]

    param_dict["lora_down"] = Vh.reshape(new_rank, in_size, kernel_h, kernel_w).cpu()
    param_dict["lora_up"] = U.reshape(out_size, new_rank, 1, 1).cpu()
    return param_dict


def extract_linear(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1):
    """Re-factor a 2D linear weight (out, in) into LoRA down/up tensors.

    Returns:
        The param_dict produced by _svd_extract, extended with "lora_down"
        (new_rank, in) and "lora_up" (out, new_rank), both moved to CPU.
    """
    rows, cols = weight.size()
    up, down, result = _svd_extract(weight, lora_rank, dynamic_method, dynamic_param, device, scale)
    rank_out = result["new_rank"]
    result["lora_down"] = down.reshape(rank_out, cols).cpu()
    result["lora_up"] = up.reshape(rows, rank_out).cpu()
    return result


def merge_conv(lora_down, lora_up, device):
    """Multiply conv LoRA factors back into a full 4D conv weight on ``device``.

    Args:
        lora_down: (rank, in, kh, kw) tensor; kh and kw may differ.
        lora_up: 4D (out, rank, ...) tensor; the trailing dims are presumably 1x1 — TODO confirm.
        device: Device the product is computed on.

    Returns:
        (out, in, kh, kw) tensor equal to lora_up @ lora_down over the rank dimension.
    """
    in_rank, in_size, kernel_h, kernel_w = lora_down.shape
    out_size, out_rank, _, _ = lora_up.shape
    # Only the ranks must agree; non-square kernels (e.g. 1x3) are valid conv weights.
    assert in_rank == out_rank, f"rank {in_rank} {out_rank} mismatch"

    merged = lora_up.to(device).reshape(out_size, -1) @ lora_down.to(device).reshape(in_rank, -1)
    return merged.reshape(out_size, in_size, kernel_h, kernel_w)


def merge_linear(lora_down, lora_up, device):
    """Return lora_up @ lora_down computed on ``device``.

    Args:
        lora_down: (rank, in) tensor.
        lora_up: (out, rank) tensor.

    Raises:
        AssertionError: if the two rank dimensions differ.
    """
    rank_down, _ = lora_down.shape
    _, rank_up = lora_up.shape
    assert rank_down == rank_up, f"rank {rank_down} {rank_up} mismatch"
    return lora_up.to(device) @ lora_down.to(device)


def merge_conv3d(lora_down, lora_up, device):
    """Multiply 3D-conv LoRA factors back into a full 5D weight on ``device``.

    Args:
        lora_down: (rank, in, kD, kH, kW) tensor.
        lora_up: 5D (out, rank, ...) tensor, flattened to (out, -1) for the product.

    Returns:
        (out, in, kD, kH, kW) tensor.
    """
    rank_down, channels_in, depth, height, width = lora_down.shape
    channels_out, rank_up, _, _, _ = lora_up.shape
    assert rank_down == rank_up, f"rank {rank_down} {rank_up} mismatch"

    product = lora_up.to(device).reshape(channels_out, -1) @ lora_down.to(device).reshape(rank_down, -1)
    return product.reshape(channels_out, channels_in, depth, height, width)


def extract_conv3d(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1):
    """Re-factor a 5D conv weight (out, in, kD, kH, kW) into LoRA down/up tensors.

    Returns:
        The param_dict from _svd_extract, extended with "lora_down"
        (new_rank, in, kD, kH, kW) and "lora_up" (out, new_rank, 1, 1, 1), both on CPU.
    """
    channels_out, channels_in, depth, height, width = weight.size()
    flat = weight.reshape(channels_out, -1)
    up, down, result = _svd_extract(flat, lora_rank, dynamic_method, dynamic_param, device, scale)
    rank_out = result["new_rank"]
    result["lora_down"] = down.reshape(rank_out, channels_in, depth, height, width).cpu()
    result["lora_up"] = up.reshape(channels_out, rank_out, 1, 1, 1).cpu()
    return result


# Calculate new rank


def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1):
    """Choose a new rank/alpha from singular values ``S`` and report how much of S is kept.

    Args:
        S: 1D tensor of singular values, largest first (torch.linalg.svd output).
        rank: Maximum allowed rank. Also used as the rank when ``dynamic_method`` is not
            one of "sv_ratio", "sv_cumulative", "sv_fro", "sv_knee".
        dynamic_method: Name of the index_sv_* rule to apply.
        dynamic_param: Threshold / sensitivity passed to that rule.
        scale: Multiplier for alpha; new_alpha = scale * new_rank.

    Returns:
        dict with "new_rank", "new_alpha", "sum_retained" (share of sum(|S|) kept),
        "fro_retained" (float share of the Frobenius norm kept) and "max_ratio"
        (S[0] / S[new_rank - 1]).
    """
    index_fns = {
        "sv_ratio": index_sv_ratio,
        "sv_cumulative": index_sv_cumulative,
        "sv_fro": index_sv_fro,
        "sv_knee": index_sv_knee,  # dynamic_param acts as the knee sensitivity
    }
    index_fn = index_fns.get(dynamic_method)
    # Each index_sv_* helper returns an index of at least 1; adding 1 gives the new rank.
    new_rank = index_fn(S, dynamic_param) + 1 if index_fn is not None else rank

    if S[0] <= MIN_SV:
        # Zero matrix: keep a single component.
        new_rank = 1
    else:
        # Cap at the requested maximum and at the number of available singular values
        # (with len(S) == 1 the index helpers return 1, i.e. rank 2, which would make
        # S[new_rank - 1] below go out of range).
        new_rank = min(new_rank, rank, len(S))
    new_alpha = float(scale * new_rank)

    # Calculate resize info
    s_sum = torch.sum(torch.abs(S))
    s_rank = torch.sum(torch.abs(S[:new_rank]))

    S_squared = S.pow(2)
    s_fro = torch.sqrt(torch.sum(S_squared))
    s_red_fro = torch.sqrt(torch.sum(S_squared[:new_rank]))

    return {
        "new_rank": new_rank,
        "new_alpha": new_alpha,
        "sum_retained": s_rank / s_sum,
        "fro_retained": float(s_red_fro / s_fro),
        "max_ratio": S[0] / S[new_rank - 1],
    }


def _match_lora_down_key(key):
    """Match ``key`` against LORA_DOWN_UP_FORMATS as a LoRA "down" weight.

    Only the last two key parts are inspected, because generic names such as
    "down"/"up" may also appear inside the block name.

    Returns:
        (block_name, down_suffix, up_suffix, weight_suffix), where weight_suffix is
        e.g. ".weight" or "", or None if ``key`` is not a down weight.
    """
    key_parts = key.split(".")
    for down, up in LORA_DOWN_UP_FORMATS:
        if len(key_parts) >= 2 and key_parts[-2] == down:
            return ".".join(key_parts[:-2]), "." + down, "." + up, "." + key_parts[-1]
        if key_parts[-1] == down:
            return ".".join(key_parts[:-1]), "." + down, "." + up, ""
    return None


def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dynamic_param, verbose):
    """Resize every LoRA down/up pair in ``lora_sd`` to a fixed or dynamically chosen rank.

    Each pair is merged into its full weight, re-factored by SVD, and written back
    together with a matching ".alpha".

    Args:
        lora_sd: LoRA state dict (not modified).
        new_rank: Target rank, or maximum rank in dynamic mode.
        save_dtype: dtype of the written down/up/alpha tensors.
        device: Device used for the merge and SVD.
        dynamic_method: "disabled" (or falsy) for a fixed rank, otherwise a rank_resize method.
        dynamic_param: Parameter for the dynamic method.
        verbose: Log per-layer retention stats (dynamic mode) and their average.

    Returns:
        (o_lora_sd, max_old_rank, new_alpha, rank_list):
        o_lora_sd is a copy of lora_sd with resized pairs replaced; skipped pairs keep
        their original tensors. max_old_rank is the largest original rank seen (None if
        no pair was found). new_alpha is the alpha of the last resized pair (None if none
        was resized). rank_list holds the new rank of each resized pair.
    """
    # "disabled" is a truthy string, so it has to be excluded explicitly.
    is_dynamic = bool(dynamic_method) and dynamic_method != "disabled"
    max_old_rank = None
    new_alpha = None
    fro_list = []
    rank_list = []

    if is_dynamic:
        logging.info(f"Dynamically determining new alphas and dims based off {dynamic_method}: {dynamic_param}, max rank is {new_rank}")

    o_lora_sd = lora_sd.copy()

    # Exactly one progress step is taken per down key below.
    pbar = comfy.utils.ProgressBar(sum(1 for k in lora_sd if _match_lora_down_key(k) is not None))
    for key, lora_down_weight in tqdm(lora_sd.items(), leave=True, desc="Resizing LoRA weights"):
        match = _match_lora_down_key(key)
        if match is None:
            # This parameter is not lora_down
            continue
        block_name, lora_down_name, lora_up_name, weight_name = match

        # Find the corresponding lora_up and alpha
        lora_up_weight = lora_sd.get(block_name + lora_up_name + weight_name, None)
        lora_alpha = lora_sd.get(block_name + ".alpha", None)
        if lora_down_weight is None or lora_up_weight is None:
            pbar.update(1)
            continue

        old_rank = lora_down_weight.size()[0]
        max_old_rank = max(max_old_rank or 0, old_rank)

        # Skip if the merged weight would exceed 100M elements in total
        # (out * in * kernel dims, with kernels of any shape).
        merged_size = lora_up_weight.shape[0] * lora_down_weight.shape[1:].numel()
        if merged_size > 100_000_000:
            logging.warning(f"Skipping {block_name}: merged weight too large ({merged_size:,} elements)")
            tqdm.write(f"SKIPPED: {block_name} - too large ({merged_size:,} elements)")
            pbar.update(1)
            continue

        scale = 1.0 if lora_alpha is None else lora_alpha / old_rank

        if lora_down_weight.ndim == 4:
            merge_fn, extract_fn = merge_conv, extract_conv
        elif lora_down_weight.ndim == 5:
            merge_fn, extract_fn = merge_conv3d, extract_conv3d
        else:
            merge_fn, extract_fn = merge_linear, extract_linear
        full_weight_matrix = merge_fn(lora_down_weight, lora_up_weight, device)
        param_dict = extract_fn(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale)
        del full_weight_matrix

        # Retention stats are only produced by rank_resize, i.e. in dynamic mode.
        if verbose and "fro_retained" in param_dict:
            max_ratio = param_dict["max_ratio"]
            sum_retained = param_dict["sum_retained"]
            fro_retained = param_dict["fro_retained"]
            if not np.isnan(fro_retained):
                fro_list.append(float(fro_retained))
            log_str = f"{block_name:75} | sum(S) retained: {sum_retained:.1%}, fro retained: {fro_retained:.1%}, max(S) ratio: {max_ratio:0.1f}, new dim: {param_dict['new_rank']}"
            tqdm.write(log_str)

        new_alpha = param_dict["new_alpha"]
        o_lora_sd[block_name + lora_down_name + weight_name] = param_dict["lora_down"].to(save_dtype).contiguous()
        o_lora_sd[block_name + lora_up_name + weight_name] = param_dict["lora_up"].to(save_dtype).contiguous()
        o_lora_sd[block_name + ".alpha"] = torch.tensor(new_alpha).to(save_dtype)
        rank_list.append(param_dict["new_rank"])
        pbar.update(1)

    # fro_list is empty in fixed-rank mode; np.mean([]) would only log "nan".
    if verbose and fro_list:
        logging.info(f"Average Frobenius norm retention: {np.mean(fro_list):.2%} | std: {np.std(fro_list):0.3f}")
    return o_lora_sd, max_old_rank, new_alpha, rank_list
