import comfy
import folder_paths
import torch
import os

from nodes import CheckpointLoaderSimple
from dfloat11 import DFloat11Model, compress_model
from .dfloat11_custom import DFloat11ModelPatcher
from .dfloat11_decompress import decompress_state_dict_func_map
from .dfloat11_diffusers import DFloat11FluxDiffusersModel
from .convert_fixed_tensors import convert_diffusers_to_comfyui_flux
from .pattern_dict import MODEL_TO_PATTERN_DICT

def filter_df11_keys(state_dict):
    """Return a new dict holding every entry of ``state_dict`` except DF11 payload entries.

    An entry is dropped when its key ends with one of the suffixes listed
    below; all remaining entries are kept in their original order and the
    input mapping is left untouched.
    """
    df11_suffixes = (
        "sign_mantissa",
        "encoded_exponent",
        "luts",
        "gaps",
        "output_positions",
        "split_positions",
    )

    kept = {}
    for name, value in state_dict.items():
        # `str.endswith` accepts a tuple and matches if any suffix matches.
        if name.endswith(df11_suffixes):
            continue
        kept[name] = value
    return kept

class disable_weight_init_df11(comfy.ops.disable_weight_init):
    """Variant of `comfy.ops.disable_weight_init` with a customised `Linear` layer.

    Installed as `model_config.custom_operations` by the advanced loader when
    the `custom_ops` compatibility strategy is selected.
    """

    class Linear(comfy.ops.disable_weight_init.Linear):
        """`Linear` op whose constructor bypasses comfy's own `Linear.__init__`."""

        def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
            # Begin the MRO lookup *after* comfy's Linear, so its `__init__` is
            # skipped and the next class in the MRO is initialised instead.
            # NOTE(review): presumably that next class is `torch.nn.Linear`, so
            # the layer is built the stock way — confirm against comfy.ops.
            skipped_cls = comfy.ops.disable_weight_init.Linear
            super(skipped_cls, self).__init__(in_features, out_features, bias, device, dtype)

class DFloat11ModelLoaderAdvanced:
    """
    A custom node to load a DFloat11 diffusion model from the `diffusion_models` directory.

    DFloat11 models are >30% smaller than their float16 counterparts, yet produce bit-for-bit identical outputs.

    Node inputs:
        dfloat11_model_name: File name of the DFloat11 model inside `diffusion_models`
        cpu_offload: Enables CPU offloading; only keeps a single block of weights in GPU at once
        cpu_offload_blocks: Number of transformer blocks to offload to CPU; 0 offloads all blocks
        pin_memory: Enables memory-pinning/page-locking when using CPU offloading
        dynamic_vram_compatibility: Strategy used when `comfy.memory_management.aimdo_enabled` is set
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "dfloat11_model_name": (folder_paths.get_filename_list("diffusion_models"),),
                "cpu_offload": ("BOOLEAN", {"default": False, "tooltip": "Whether to offload to CPU RAM"}),
                "cpu_offload_blocks": ("INT", {"default": 0, "min": 0, "max": 999, "step": 1, "tooltip": "If set to 0, all blocks will be offloaded to CPU RAM"}),
                "pin_memory": ("BOOLEAN", {"default": True, "tooltip": "Whether to lock/pin the weights to CPU RAM. Enabling this option increases RAM usage (which might cause OOM), but should increase speed"}),
                "dynamic_vram_compatibility": (["custom_ops", "load_state_dict"], {"default": "custom_ops", "tooltip": "Strategy for compatibility with `dynamic_vram`, if it is enabled. `custom_ops` is better since it causes a smaller spike in RAM usage"}),
            }
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "load_dfloat11_model_advanced"
    CATEGORY = "DFloat11"

    @staticmethod
    def _detection_placeholders(state_dict):
        """Return placeholder entries to merge into ``state_dict`` before model detection.

        Each rule probes the DF11 state dict for a model family and, on a
        match, contributes keys that are absent from the file. Values are
        ``None`` where only the key is supplied, or a shape-only ``meta``
        tensor where explicit shapes are supplied.

        Raises:
            Exception: if an Ace-Step-v1.5 layout is detected with an unknown size.
        """
        placeholders = {}

        has_txt_norm = "txt_norm.scale" in state_dict or "txt_norm.weight" in state_dict
        if "double_blocks.0.img_mlp.gate_proj.bias" in state_dict and has_txt_norm:  # for Flux.2
            placeholders["double_blocks.0.img_mlp.gate_proj.weight"] = None

        if "double_stream_modulation_img.lin.sign_mantissa" in state_dict and "double_stream_modulation_img.lin.weight" not in state_dict:  # for Flux.2
            placeholders["double_stream_modulation_img.lin.weight"] = None

        # The membership test on `layers.0.sign_mantissa` guards the lookup below;
        # without it a file matching only the first two keys raised KeyError.
        if (
            "adaLN_modulation.1.sign_mantissa" in state_dict
            and "time_embedding.sign_mantissa" in state_dict
            and "layers.0.sign_mantissa" in state_dict
            and state_dict["layers.0.sign_mantissa"].numel() == 218103808
        ):  # for ErnieImage
            placeholders["layers.0.mlp.linear_fc2.weight"] = None

        if "transformer_blocks.0.attn.norm_added_q.weight" in state_dict and state_dict["transformer_blocks.0.attn.norm_added_q.weight"].numel() == 64:  # for Lens
            placeholders["transformer_blocks.0.img_mlp.w1.weight"] = None

        if "encoder.lyric_encoder.layers.0.input_layernorm.weight" in state_dict and "decoder.layers.0.sign_mantissa" in state_dict:  # for Ace-Step-v1.5
            decoder_numel = state_dict["decoder.layers.0.sign_mantissa"].numel()
            if decoder_numel == 62914560:  # The smaller version
                gate_proj_shape, q_proj_shape = [6144, 2048], [2048, 2048]
            elif decoder_numel == 127139840:  # The XL version
                gate_proj_shape, q_proj_shape = [9728, 2560], [4096, 2560]
            else:
                raise Exception(f"Detected Ace-Step-v1.5 model, but unsure of size {decoder_numel}")

            # `meta` tensors carry a shape without allocating any storage.
            placeholders["decoder.layers.0.mlp.gate_proj.weight"] = torch.empty(gate_proj_shape, device="meta")
            placeholders["decoder.layers.0.self_attn.q_proj.weight"] = torch.empty(q_proj_shape, device="meta")

            # The lyric encoder placeholders are identical for both sizes.
            placeholders["encoder.lyric_encoder.layers.0.self_attn.q_proj.weight"] = torch.empty([2048, 2048], device="meta")
            placeholders["encoder.lyric_encoder.layers.0.mlp.gate_proj.weight"] = torch.empty([6144, 2048], device="meta")

        return placeholders

    @classmethod
    def _detect_model_config(cls, state_dict):
        """Return the ComfyUI model config detected for a DF11 ``state_dict``.

        Detection first runs with the placeholders from `_detection_placeholders`,
        then, if nothing matched, with the Anima / CosmosPredict2 placeholders.

        Raises:
            ValueError: if neither attempt yields a model config.
        """
        detect = comfy.sd.model_detection.model_config_from_unet

        model_config = detect(state_dict | cls._detection_placeholders(state_dict), "")
        if model_config is not None:
            return model_config

        # In case `model_config` cannot be found, the possible options are Anima and CosmosPredict2
        fallback_placeholders = {}
        if "llm_adapter.blocks.0.sign_mantissa" in state_dict:
            # This should be Anima
            fallback_placeholders["llm_adapter.blocks.0.cross_attn.q_proj.weight"] = None

        # This applies to both Anima and CosmosPredict2
        fallback_placeholders["blocks.0.mlp.layer1.weight"] = None

        model_config = detect(state_dict | fallback_placeholders, "")
        if model_config is None:
            # Raised explicitly (not `assert`) so the check survives `python -O`.
            raise ValueError("Unable to detect model type")
        return model_config

    def load_dfloat11_model_advanced(self, dfloat11_model_name, cpu_offload, cpu_offload_blocks, pin_memory, dynamic_vram_compatibility):
        """Load a DF11 diffusion model and return it wrapped in a `DFloat11ModelPatcher`.

        Args:
            dfloat11_model_name: File name inside the `diffusion_models` folder.
            cpu_offload: Whether to enable CPU offloading.
            cpu_offload_blocks: Number of blocks to offload; 0 is forwarded as ``None``.
            pin_memory: Whether to pin offloaded weights in CPU RAM.
            dynamic_vram_compatibility: ``"custom_ops"`` or ``"load_state_dict"``;
                only consulted when `comfy.memory_management.aimdo_enabled` is set.

        Returns:
            A 1-tuple containing the `DFloat11ModelPatcher`.

        Raises:
            ValueError: if the file has no ``encoded_exponent`` keys, or the
                model type cannot be detected.
        """
        if not cpu_offload:
            # Offload-specific options are reset when offloading is disabled.
            cpu_offload_blocks = 0
            pin_memory = True

        dfloat11_model_path = folder_paths.get_full_path_or_raise("diffusion_models", dfloat11_model_name)
        state_dict = comfy.utils.load_torch_file(dfloat11_model_path)

        if not any(k.endswith("encoded_exponent") for k in state_dict.keys()):
            raise ValueError(f"The model '{dfloat11_model_name}' is not a DFloat11 model.")

        load_device = comfy.model_management.get_torch_device()
        offload_device = comfy.model_management.unet_offload_device()

        model_config = self._detect_model_config(state_dict)

        # The pattern dictionary is keyed by the model config's class name.
        df11_type = type(model_config).__name__

        if df11_type == "FluxSchnell" and model_config.unet_config.get("yak_mlp", False) and model_config.unet_config.get("txt_norm", False):
            df11_type = "OvisImage"

        aimdo_enabled = comfy.memory_management.aimdo_enabled

        if aimdo_enabled and dynamic_vram_compatibility == "custom_ops":
            model_config.custom_operations = disable_weight_init_df11

        model_config.set_inference_dtype(torch.bfloat16, torch.bfloat16)
        model = model_config.get_model(state_dict, "")

        if not aimdo_enabled:
            model = model.to(offload_device)
        elif dynamic_vram_compatibility == "load_state_dict":
            # Load only the entries that are not DF11 payload data.
            model.diffusion_model.load_state_dict(filter_df11_keys(state_dict), strict=False, assign=False)

        # The return value is discarded.
        # NOTE(review): presumably this attaches the DF11 weights to
        # `model.diffusion_model` in place — confirm against the dfloat11 package.
        DFloat11Model.from_single_file(
            dfloat11_model_path,
            pattern_dict=MODEL_TO_PATTERN_DICT[df11_type],
            bfloat16_model=model.diffusion_model,
            device=offload_device,
            cpu_offload=cpu_offload,
            cpu_offload_blocks=cpu_offload_blocks if cpu_offload_blocks > 0 else None,
            pin_memory=pin_memory,
        )

        # Always use DFloat11ModelPatcher for DF11 models (required due to missing .weight attributes)
        return (
            DFloat11ModelPatcher(model, load_device=load_device, offload_device=offload_device),
        )

class DFloat11ModelLoader(DFloat11ModelLoaderAdvanced):
    """
    Simplified DFloat11 loader node that only asks for the model file.

    It delegates to the advanced loader with CPU offloading disabled and the
    `custom_ops` compatibility strategy.
    """

    @classmethod
    def INPUT_TYPES(cls):
        available_models = folder_paths.get_filename_list("diffusion_models")
        return {
            "required": {
                "dfloat11_model_name": (available_models,),
            }
        }

    FUNCTION = "load_dfloat11_model"

    def load_dfloat11_model(self, dfloat11_model_name):
        """Load ``dfloat11_model_name`` with fixed options and return the advanced loader's result."""
        fixed_options = {
            "cpu_offload": False,
            "cpu_offload_blocks": 0,
            "pin_memory": True,
            "dynamic_vram_compatibility": "custom_ops",
        }
        return self.load_dfloat11_model_advanced(dfloat11_model_name, **fixed_options)


class DFloat11Decompressor:
    """
    A custom node that decompresses a DFloat11 model file and loads the result as a regular diffusion model.

    The decompression routine is chosen by `decompress_recipe` from
    `decompress_state_dict_func_map`.
    """

    @classmethod
    def INPUT_TYPES(cls):
        model_files = folder_paths.get_filename_list("diffusion_models")
        recipes = ["Flux.2-Klein-4B", "Flux.2-Klein-9B"]
        dtype_choices = ["default", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2"]
        return {
            "required": {
                "dfloat11_model_name": (model_files,),
                "decompress_recipe": (recipes,),
                "weight_dtype": (dtype_choices,)
            }
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "decompress_dfloat11_model"
    CATEGORY = "DFloat11"

    def decompress_dfloat11_model(self, dfloat11_model_name, decompress_recipe, weight_dtype):
        """Decompress the DF11 file and return ``(model_patcher,)`` built from the rebuilt state dict.

        Raises:
            ValueError: if the file contains no ``encoded_exponent`` keys.
        """
        model_path = folder_paths.get_full_path_or_raise("diffusion_models", dfloat11_model_name)
        compressed = comfy.utils.load_torch_file(model_path)

        looks_like_df11 = any(name.endswith("encoded_exponent") for name in compressed.keys())
        if not looks_like_df11:
            raise ValueError(f"The model '{dfloat11_model_name}' is not a DFloat11 model.")

        decompress = decompress_state_dict_func_map[decompress_recipe]
        restored = decompress(compressed)

        # Map the node choice to a torch dtype attribute *name*; the attribute is
        # only looked up for the selected choice. "default" maps to nothing.
        torch_dtype_names = {
            "fp8_e4m3fn": "float8_e4m3fn",
            "fp8_e4m3fn_fast": "float8_e4m3fn",
            "fp8_e5m2": "float8_e5m2",
        }

        model_options = {}
        dtype_name = torch_dtype_names.get(weight_dtype)
        if dtype_name is not None:
            model_options["dtype"] = getattr(torch, dtype_name)
        if weight_dtype == "fp8_e4m3fn_fast":
            model_options["fp8_optimizations"] = True

        return (comfy.sd.load_diffusion_model_state_dict(restored, model_options=model_options),)



class DFloat11LoadingPatch:
    """
    A custom node that switches a DFloat11 model patcher's loading methods and scales its memory usage estimate.
    """

    # Attribute stored on the model config to remember the unscaled factor.
    _BASE_FACTOR_ATTR = "_df11_base_memory_usage_factor"

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "model_patcher": ("MODEL", {"tooltip": "The model whose loading methods will be patched"}),
                "load_version": (["v1", "v1.5", "v2"],),
                "memory_usage_factor_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step":0.01, "round": 0.01, "tooltip": "The multiplier to scale ComfyUI's memory usage estimation"}),
            }
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch_loading_methods"
    CATEGORY = "DFloat11"

    def patch_loading_methods(self, model_patcher, load_version, memory_usage_factor_scale):
        """Clone ``model_patcher``, patch its loading methods and rescale the memory estimate.

        Args:
            model_patcher: Patcher exposing ``clone()`` and ``patch_loading_methods()``.
            load_version: Loading implementation to select (``"v1"``, ``"v1.5"`` or ``"v2"``).
            memory_usage_factor_scale: Multiplier applied to the model's *original*
                ``memory_usage_factor``.

        Returns:
            A 1-tuple containing the patched clone.
        """
        new_model_patcher = model_patcher.clone()
        new_model_patcher.patch_loading_methods(load_version)

        model = new_model_patcher.model
        model_config = model.model_config

        # The factor lives on the model config, which is reached through the clone.
        # NOTE(review): assumes `clone()` shares the underlying model object (the
        # usual ModelPatcher behaviour) — in that case multiplying in place would
        # compound the scale on every re-execution. Scaling from a remembered
        # base keeps the result equal to `base * scale` no matter how often the
        # node runs.
        base_factor = getattr(model_config, self._BASE_FACTOR_ATTR, None)
        if base_factor is None:
            base_factor = model_config.memory_usage_factor
            setattr(model_config, self._BASE_FACTOR_ATTR, base_factor)

        model_config.memory_usage_factor = base_factor * memory_usage_factor_scale
        model.memory_usage_factor = model_config.memory_usage_factor

        return (new_model_patcher,)


class CheckpointLoaderWithDFloat11(CheckpointLoaderSimple):
    """
    A custom node that loads a regular checkpoint (model, CLIP, VAE) and then applies a DF11 file to its diffusion model.
    """

    @classmethod
    def INPUT_TYPES(s):
        checkpoint_files = folder_paths.get_filename_list("checkpoints")
        df11_files = folder_paths.get_filename_list("diffusion_models")
        return {
            "required": {
                "ckpt_name": (checkpoint_files, {"tooltip": "The name of the checkpoint (model) to load."}),
                "dfloat11_model_name": (df11_files, {"tooltip": "The diffusion model in DF11 format."}),
            }
        }

    FUNCTION = "load_checkpoint_with_df11"
    CATEGORY = "DFloat11"
    DESCRIPTION = "Loads a diffusion model checkpoint, along with a DF11 unet."

    def load_checkpoint_with_df11(self, ckpt_name, dfloat11_model_name):
        """Load the checkpoint, then load the DF11 file into its diffusion model.

        Returns:
            ``(df11_model_patcher, clip, vae)``.

        Raises:
            ValueError: if the DF11 file contains no ``encoded_exponent`` keys.
        """
        checkpoint_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
        embedding_dirs = folder_paths.get_folder_paths("embeddings")
        base_patcher, clip, vae, *_ = comfy.sd.load_checkpoint_guess_config(
            checkpoint_path,
            output_vae=True,
            output_clip=True,
            embedding_directory=embedding_dirs,
        )

        dfloat11_model_path = folder_paths.get_full_path_or_raise("diffusion_models", dfloat11_model_name)

        # The state dict is loaded here only to check that the file is DF11.
        df11_tensors = comfy.utils.load_torch_file(dfloat11_model_path)
        is_df11 = any(name.endswith("encoded_exponent") for name in df11_tensors.keys())
        if not is_df11:
            raise ValueError(f"The model '{dfloat11_model_name}' is not a DFloat11 model.")

        load_device = comfy.model_management.get_torch_device()
        offload_device = comfy.model_management.unet_offload_device()

        # NOTE(review): the pattern key here is the class name of the loaded
        # *model* object — confirm that MODEL_TO_PATTERN_DICT has matching keys.
        base_model = base_patcher.model
        df11_type = type(base_model).__name__

        df11_model_patcher = DFloat11ModelPatcher(
            base_model,
            load_device=load_device,
            offload_device=offload_device,
        )

        # Drop the reference to the original patcher; only the DF11 one is returned.
        del base_patcher

        DFloat11Model.from_single_file(
            dfloat11_model_path,
            pattern_dict=MODEL_TO_PATTERN_DICT[df11_type],
            bfloat16_model=df11_model_patcher.model.diffusion_model,
            device=offload_device,
        )

        return (df11_model_patcher, clip, vae)



class DFloat11DiffusersModelLoader:
    """
    A custom node to load a diffusers-native DFloat11 diffusion model from the `diffusion_models` directory.

    DFloat11 models are >30% smaller than their float16 counterparts, yet produce bit-for-bit identical outputs.
    """

    @classmethod
    def INPUT_TYPES(cls):
        model_files = folder_paths.get_filename_list("diffusion_models")
        return {
            "required": {
                "dfloat11_model_name": (model_files,),
                "model_type": (["Flux",],)
            }
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "load_dfloat11_model"
    CATEGORY = "DFloat11"

    def load_dfloat11_model(self, dfloat11_model_name, model_type):
        """Build a Flux model from a fixed config and load the DF11 file into it.

        Returns:
            A 1-tuple containing the `DFloat11ModelPatcher`.

        Raises:
            ValueError: if the file contains no ``encoded_exponent`` keys.
        """
        model_path = folder_paths.get_full_path_or_raise("diffusion_models", dfloat11_model_name)

        # The state dict is used as loaded; `convert_diffusers_to_comfyui_flux`
        # is not applied to it here.
        tensors = comfy.utils.load_torch_file(model_path)

        is_df11 = any(name.endswith("encoded_exponent") for name in tensors.keys())
        if not is_df11:
            raise ValueError(f"The model '{dfloat11_model_name}' is not a DFloat11 model.")

        load_device = comfy.model_management.get_torch_device()
        offload_device = comfy.model_management.unet_offload_device()

        # Guidance embedding is enabled only when the file has the guidance embedder weight.
        has_guidance = "time_text_embed.guidance_embedder.linear_1.weight" in tensors

        flux_config = {
            'image_model': 'flux',
            'axes_dim': [16, 56, 56],
            'num_heads': 24,
            'mlp_ratio': 4.0,
            'theta': 10000,
            'out_channels': 16,
            'qkv_bias': True,
            'txt_ids_dims': [],
            'in_channels': 16,
            'hidden_size': 3072,
            'context_in_dim': 4096,
            'patch_size': 2,
            'vec_in_dim': 768,
            'depth': 19,
            'depth_single_blocks': 38,
            'guidance_embed': has_guidance,
            'yak_mlp': False,
            'txt_norm': False,
        }

        model_config = comfy.supported_models.Flux(flux_config)
        model_config.set_inference_dtype(torch.bfloat16, torch.bfloat16)
        model = model_config.get_model(tensors, "").to(offload_device)

        DFloat11FluxDiffusersModel.from_single_file(
            model_path,
            pattern_dict=MODEL_TO_PATTERN_DICT[model_type],
            bfloat16_model=model.diffusion_model,
            device=offload_device,
        )

        patcher = DFloat11ModelPatcher(model, load_device=load_device, offload_device=offload_device)
        return (patcher,)

class DFloat11ModelCompressor:
    """
    A custom node to compress a bfloat16 diffusion model from the `diffusion_models` directory into DFloat11.

    DFloat11 models are >30% smaller than their float16 counterparts, yet produce bit-for-bit identical outputs.
    """

    @classmethod
    def INPUT_TYPES(cls):
        source_models = folder_paths.get_filename_list("diffusion_models")
        supported_types = list(MODEL_TO_PATTERN_DICT.keys())
        return {
            "required": {
                "bfloat16_model_name": (source_models,),
                "model_type": (supported_types,)
            }
        }

    RETURN_TYPES = ("STRING",)
    FUNCTION = "load_bfloat16_model"
    CATEGORY = "DFloat11"

    def load_bfloat16_model(self, bfloat16_model_name, model_type):
        """Load the model as bfloat16, compress it and return ``(save_path,)``.

        ``save_path`` is the source file path with its extension removed and
        ``-DF11`` appended; it is what gets handed to `compress_model`.
        """
        source_path = folder_paths.get_full_path_or_raise("diffusion_models", bfloat16_model_name)
        patcher = comfy.sd.load_diffusion_model(source_path, {"dtype" : torch.bfloat16})

        path_without_ext, _ext = os.path.splitext(source_path)
        save_path = f"{path_without_ext}-DF11"

        compress_options = {
            "pattern_dict": MODEL_TO_PATTERN_DICT[model_type],
            "save_path": save_path,
            "save_single_file": True,
            "check_correctness": True,
            "block_range": (0, 500),
        }
        compress_model(model=patcher.model.diffusion_model, **compress_options)

        return (save_path,)


class DFloat11CheckpointCompressor:
    """
    A custom node to compress the diffusion model (unet only) of a checkpoint from the `checkpoints` directory into DFloat11.

    DFloat11 models are >30% smaller than their float16 counterparts, yet produce bit-for-bit identical outputs.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "ckpt_name": (folder_paths.get_filename_list("checkpoints"), {"tooltip": "The name of the checkpoint (model) to load."}),
                "model_type": (list(MODEL_TO_PATTERN_DICT.keys()),)
            }
        }

    RETURN_TYPES = ("STRING",)
    FUNCTION = "load_bfloat16_checkpoint"
    CATEGORY = "DFloat11"

    def load_bfloat16_checkpoint(self, ckpt_name, model_type):
        """Load the checkpoint's diffusion model, compress it and return ``(save_path,)``.

        Args:
            ckpt_name: File name inside the `checkpoints` folder.
            model_type: Key into `MODEL_TO_PATTERN_DICT` selecting the layer patterns.

        Returns:
            A 1-tuple with the checkpoint path, extension removed and ``-DF11``
            appended; this is the `save_path` handed to `compress_model`.

        Raises:
            ValueError: if no supported model config is detected for the checkpoint.
        """
        ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)

        state_dict, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True)

        diffusion_model_prefix = comfy.model_detection.unet_prefix_from_state_dict(state_dict)
        parameters = comfy.utils.calculate_parameters(state_dict, diffusion_model_prefix)
        # NOTE(review): despite the method name, the dtype used is the one
        # detected from the checkpoint; nothing here converts to bfloat16 —
        # confirm how `compress_model` treats other dtypes.
        weight_dtype = comfy.utils.weight_dtype(state_dict, diffusion_model_prefix)

        load_device = comfy.model_management.get_torch_device()
        offload_device = comfy.model_management.unet_offload_device()

        model_config = comfy.model_detection.model_config_from_unet(state_dict, diffusion_model_prefix, metadata=metadata)
        if model_config is None:
            # Detection can come back empty; fail with a clear message instead of
            # an AttributeError on the next line.
            raise ValueError(f"Unable to detect the model type of checkpoint '{ckpt_name}'.")
        model_config.set_inference_dtype(weight_dtype, weight_dtype)

        unet_weight_dtype = list(model_config.supported_inference_dtypes)

        unet_dtype = comfy.model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype, weight_dtype=weight_dtype)

        # `unet_inital_load_device` is the spelling of the comfy API.
        initial_load_device = comfy.model_management.unet_inital_load_device(parameters, unet_dtype)
        model = model_config.get_model(state_dict, diffusion_model_prefix, device=initial_load_device)

        model.load_model_weights(state_dict, diffusion_model_prefix)
        model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device)

        diffusion_model = model_patcher.model.diffusion_model

        save_path = f"{os.path.splitext(ckpt_path)[0]}-DF11"

        compress_model(
            model=diffusion_model,
            pattern_dict=MODEL_TO_PATTERN_DICT[model_type],
            save_path=save_path,
            save_single_file=True,
            check_correctness=True,
            block_range=(0, 500),
        )

        return (save_path,)
