import torch

chroma_keys_dict = {
    "distilled_guidance_layer.in_proj.bias" : torch.Size([5120]),
    "distilled_guidance_layer.in_proj.weight" : torch.Size([5120, 64]),
    "distilled_guidance_layer.layers.0.in_layer.bias" : torch.Size([5120]),
    "distilled_guidance_layer.layers.0.in_layer.weight" : torch.Size([5120, 5120]),
    "distilled_guidance_layer.layers.0.out_layer.bias" : torch.Size([5120]),
    "distilled_guidance_layer.layers.0.out_layer.weight" : torch.Size([5120, 5120]),
    "distilled_guidance_layer.layers.1.in_layer.bias" : torch.Size([5120]),
    "distilled_guidance_layer.layers.1.in_layer.weight" : torch.Size([5120, 5120]),
    "distilled_guidance_layer.layers.1.out_layer.bias" : torch.Size([5120]),
    "distilled_guidance_layer.layers.1.out_layer.weight" : torch.Size([5120, 5120]),
    "distilled_guidance_layer.layers.2.in_layer.bias" : torch.Size([5120]),
    "distilled_guidance_layer.layers.2.in_layer.weight" : torch.Size([5120, 5120]),
    "distilled_guidance_layer.layers.2.out_layer.bias" : torch.Size([5120]),
    "distilled_guidance_layer.layers.2.out_layer.weight" : torch.Size([5120, 5120]),
    "distilled_guidance_layer.layers.3.in_layer.bias" : torch.Size([5120]),
    "distilled_guidance_layer.layers.3.in_layer.weight" : torch.Size([5120, 5120]),
    "distilled_guidance_layer.layers.3.out_layer.bias" : torch.Size([5120]),
    "distilled_guidance_layer.layers.3.out_layer.weight" : torch.Size([5120, 5120]),
    "distilled_guidance_layer.layers.4.in_layer.bias" : torch.Size([5120]),
    "distilled_guidance_layer.layers.4.in_layer.weight" : torch.Size([5120, 5120]),
    "distilled_guidance_layer.layers.4.out_layer.bias" : torch.Size([5120]),
    "distilled_guidance_layer.layers.4.out_layer.weight" : torch.Size([5120, 5120]),
    "distilled_guidance_layer.norms.0.scale" : torch.Size([5120]),
    "distilled_guidance_layer.norms.1.scale" : torch.Size([5120]),
    "distilled_guidance_layer.norms.2.scale" : torch.Size([5120]),
    "distilled_guidance_layer.norms.3.scale" : torch.Size([5120]),
    "distilled_guidance_layer.norms.4.scale" : torch.Size([5120]),
    "distilled_guidance_layer.out_proj.bias" : torch.Size([3072]),
    "distilled_guidance_layer.out_proj.weight" : torch.Size([3072, 5120]),
    "double_blocks.0.img_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.0.img_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.0.img_attn.proj.bias" : torch.Size([3072]),
    "double_blocks.0.img_attn.proj.weight" : torch.Size([3072, 3072]),
    "double_blocks.0.img_attn.qkv.bias" : torch.Size([9216]),
    "double_blocks.0.img_attn.qkv.weight" : torch.Size([9216, 3072]),
    "double_blocks.0.img_mlp.0.bias" : torch.Size([12288]),
    "double_blocks.0.img_mlp.0.weight" : torch.Size([12288, 3072]),
    "double_blocks.0.img_mlp.2.bias" : torch.Size([3072]),
    "double_blocks.0.img_mlp.2.weight" : torch.Size([3072, 12288]),
    "double_blocks.0.txt_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.0.txt_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.0.txt_attn.proj.bias" : torch.Size([3072]),
    "double_blocks.0.txt_attn.proj.weight" : torch.Size([3072, 3072]),
    "double_blocks.0.txt_attn.qkv.bias" : torch.Size([9216]),
    "double_blocks.0.txt_attn.qkv.weight" : torch.Size([9216, 3072]),
    "double_blocks.0.txt_mlp.0.bias" : torch.Size([12288]),
    "double_blocks.0.txt_mlp.0.weight" : torch.Size([12288, 3072]),
    "double_blocks.0.txt_mlp.2.bias" : torch.Size([3072]),
    "double_blocks.0.txt_mlp.2.weight" : torch.Size([3072, 12288]),
    "double_blocks.1.img_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.1.img_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.1.img_attn.proj.bias" : torch.Size([3072]),
    "double_blocks.1.img_attn.proj.weight" : torch.Size([3072, 3072]),
    "double_blocks.1.img_attn.qkv.bias" : torch.Size([9216]),
    "double_blocks.1.img_attn.qkv.weight" : torch.Size([9216, 3072]),
    "double_blocks.1.img_mlp.0.bias" : torch.Size([12288]),
    "double_blocks.1.img_mlp.0.weight" : torch.Size([12288, 3072]),
    "double_blocks.1.img_mlp.2.bias" : torch.Size([3072]),
    "double_blocks.1.img_mlp.2.weight" : torch.Size([3072, 12288]),
    "double_blocks.1.txt_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.1.txt_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.1.txt_attn.proj.bias" : torch.Size([3072]),
    "double_blocks.1.txt_attn.proj.weight" : torch.Size([3072, 3072]),
    "double_blocks.1.txt_attn.qkv.bias" : torch.Size([9216]),
    "double_blocks.1.txt_attn.qkv.weight" : torch.Size([9216, 3072]),
    "double_blocks.1.txt_mlp.0.bias" : torch.Size([12288]),
    "double_blocks.1.txt_mlp.0.weight" : torch.Size([12288, 3072]),
    "double_blocks.1.txt_mlp.2.bias" : torch.Size([3072]),
    "double_blocks.1.txt_mlp.2.weight" : torch.Size([3072, 12288]),
    "double_blocks.10.img_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.10.img_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.10.img_attn.proj.bias" : torch.Size([3072]),
    "double_blocks.10.img_attn.proj.weight" : torch.Size([3072, 3072]),
    "double_blocks.10.img_attn.qkv.bias" : torch.Size([9216]),
    "double_blocks.10.img_attn.qkv.weight" : torch.Size([9216, 3072]),
    "double_blocks.10.img_mlp.0.bias" : torch.Size([12288]),
    "double_blocks.10.img_mlp.0.weight" : torch.Size([12288, 3072]),
    "double_blocks.10.img_mlp.2.bias" : torch.Size([3072]),
    "double_blocks.10.img_mlp.2.weight" : torch.Size([3072, 12288]),
    "double_blocks.10.txt_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.10.txt_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.10.txt_attn.proj.bias" : torch.Size([3072]),
    "double_blocks.10.txt_attn.proj.weight" : torch.Size([3072, 3072]),
    "double_blocks.10.txt_attn.qkv.bias" : torch.Size([9216]),
    "double_blocks.10.txt_attn.qkv.weight" : torch.Size([9216, 3072]),
    "double_blocks.10.txt_mlp.0.bias" : torch.Size([12288]),
    "double_blocks.10.txt_mlp.0.weight" : torch.Size([12288, 3072]),
    "double_blocks.10.txt_mlp.2.bias" : torch.Size([3072]),
    "double_blocks.10.txt_mlp.2.weight" : torch.Size([3072, 12288]),
    "double_blocks.11.img_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.11.img_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.11.img_attn.proj.bias" : torch.Size([3072]),
    "double_blocks.11.img_attn.proj.weight" : torch.Size([3072, 3072]),
    "double_blocks.11.img_attn.qkv.bias" : torch.Size([9216]),
    "double_blocks.11.img_attn.qkv.weight" : torch.Size([9216, 3072]),
    "double_blocks.11.img_mlp.0.bias" : torch.Size([12288]),
    "double_blocks.11.img_mlp.0.weight" : torch.Size([12288, 3072]),
    "double_blocks.11.img_mlp.2.bias" : torch.Size([3072]),
    "double_blocks.11.img_mlp.2.weight" : torch.Size([3072, 12288]),
    "double_blocks.11.txt_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.11.txt_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.11.txt_attn.proj.bias" : torch.Size([3072]),
    "double_blocks.11.txt_attn.proj.weight" : torch.Size([3072, 3072]),
    "double_blocks.11.txt_attn.qkv.bias" : torch.Size([9216]),
    "double_blocks.11.txt_attn.qkv.weight" : torch.Size([9216, 3072]),
    "double_blocks.11.txt_mlp.0.bias" : torch.Size([12288]),
    "double_blocks.11.txt_mlp.0.weight" : torch.Size([12288, 3072]),
    "double_blocks.11.txt_mlp.2.bias" : torch.Size([3072]),
    "double_blocks.11.txt_mlp.2.weight" : torch.Size([3072, 12288]),
    "double_blocks.12.img_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.12.img_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.12.img_attn.proj.bias" : torch.Size([3072]),
    "double_blocks.12.img_attn.proj.weight" : torch.Size([3072, 3072]),
    "double_blocks.12.img_attn.qkv.bias" : torch.Size([9216]),
    "double_blocks.12.img_attn.qkv.weight" : torch.Size([9216, 3072]),
    "double_blocks.12.img_mlp.0.bias" : torch.Size([12288]),
    "double_blocks.12.img_mlp.0.weight" : torch.Size([12288, 3072]),
    "double_blocks.12.img_mlp.2.bias" : torch.Size([3072]),
    "double_blocks.12.img_mlp.2.weight" : torch.Size([3072, 12288]),
    "double_blocks.12.txt_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.12.txt_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.12.txt_attn.proj.bias" : torch.Size([3072]),
    "double_blocks.12.txt_attn.proj.weight" : torch.Size([3072, 3072]),
    "double_blocks.12.txt_attn.qkv.bias" : torch.Size([9216]),
    "double_blocks.12.txt_attn.qkv.weight" : torch.Size([9216, 3072]),
    "double_blocks.12.txt_mlp.0.bias" : torch.Size([12288]),
    "double_blocks.12.txt_mlp.0.weight" : torch.Size([12288, 3072]),
    "double_blocks.12.txt_mlp.2.bias" : torch.Size([3072]),
    "double_blocks.12.txt_mlp.2.weight" : torch.Size([3072, 12288]),
    "double_blocks.13.img_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.13.img_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.13.img_attn.proj.bias" : torch.Size([3072]),
    "double_blocks.13.img_attn.proj.weight" : torch.Size([3072, 3072]),
    "double_blocks.13.img_attn.qkv.bias" : torch.Size([9216]),
    "double_blocks.13.img_attn.qkv.weight" : torch.Size([9216, 3072]),
    "double_blocks.13.img_mlp.0.bias" : torch.Size([12288]),
    "double_blocks.13.img_mlp.0.weight" : torch.Size([12288, 3072]),
    "double_blocks.13.img_mlp.2.bias" : torch.Size([3072]),
    "double_blocks.13.img_mlp.2.weight" : torch.Size([3072, 12288]),
    "double_blocks.13.txt_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.13.txt_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.13.txt_attn.proj.bias" : torch.Size([3072]),
    "double_blocks.13.txt_attn.proj.weight" : torch.Size([3072, 3072]),
    "double_blocks.13.txt_attn.qkv.bias" : torch.Size([9216]),
    "double_blocks.13.txt_attn.qkv.weight" : torch.Size([9216, 3072]),
    "double_blocks.13.txt_mlp.0.bias" : torch.Size([12288]),
    "double_blocks.13.txt_mlp.0.weight" : torch.Size([12288, 3072]),
    "double_blocks.13.txt_mlp.2.bias" : torch.Size([3072]),
    "double_blocks.13.txt_mlp.2.weight" : torch.Size([3072, 12288]),
    "double_blocks.14.img_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.14.img_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.14.img_attn.proj.bias" : torch.Size([3072]),
    "double_blocks.14.img_attn.proj.weight" : torch.Size([3072, 3072]),
    "double_blocks.14.img_attn.qkv.bias" : torch.Size([9216]),
    "double_blocks.14.img_attn.qkv.weight" : torch.Size([9216, 3072]),
    "double_blocks.14.img_mlp.0.bias" : torch.Size([12288]),
    "double_blocks.14.img_mlp.0.weight" : torch.Size([12288, 3072]),
    "double_blocks.14.img_mlp.2.bias" : torch.Size([3072]),
    "double_blocks.14.img_mlp.2.weight" : torch.Size([3072, 12288]),
    "double_blocks.14.txt_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.14.txt_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.14.txt_attn.proj.bias" : torch.Size([3072]),
    "double_blocks.14.txt_attn.proj.weight" : torch.Size([3072, 3072]),
    "double_blocks.14.txt_attn.qkv.bias" : torch.Size([9216]),
    "double_blocks.14.txt_attn.qkv.weight" : torch.Size([9216, 3072]),
    "double_blocks.14.txt_mlp.0.bias" : torch.Size([12288]),
    "double_blocks.14.txt_mlp.0.weight" : torch.Size([12288, 3072]),
    "double_blocks.14.txt_mlp.2.bias" : torch.Size([3072]),
    "double_blocks.14.txt_mlp.2.weight" : torch.Size([3072, 12288]),
    "double_blocks.15.img_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.15.img_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.15.img_attn.proj.bias" : torch.Size([3072]),
    "double_blocks.15.img_attn.proj.weight" : torch.Size([3072, 3072]),
    "double_blocks.15.img_attn.qkv.bias" : torch.Size([9216]),
    "double_blocks.15.img_attn.qkv.weight" : torch.Size([9216, 3072]),
    "double_blocks.15.img_mlp.0.bias" : torch.Size([12288]),
    "double_blocks.15.img_mlp.0.weight" : torch.Size([12288, 3072]),
    "double_blocks.15.img_mlp.2.bias" : torch.Size([3072]),
    "double_blocks.15.img_mlp.2.weight" : torch.Size([3072, 12288]),
    "double_blocks.15.txt_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.15.txt_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.15.txt_attn.proj.bias" : torch.Size([3072]),
    "double_blocks.15.txt_attn.proj.weight" : torch.Size([3072, 3072]),
    "double_blocks.15.txt_attn.qkv.bias" : torch.Size([9216]),
    "double_blocks.15.txt_attn.qkv.weight" : torch.Size([9216, 3072]),
    "double_blocks.15.txt_mlp.0.bias" : torch.Size([12288]),
    "double_blocks.15.txt_mlp.0.weight" : torch.Size([12288, 3072]),
    "double_blocks.15.txt_mlp.2.bias" : torch.Size([3072]),
    "double_blocks.15.txt_mlp.2.weight" : torch.Size([3072, 12288]),
    "double_blocks.16.img_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.16.img_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.16.img_attn.proj.bias" : torch.Size([3072]),
    "double_blocks.16.img_attn.proj.weight" : torch.Size([3072, 3072]),
    "double_blocks.16.img_attn.qkv.bias" : torch.Size([9216]),
    "double_blocks.16.img_attn.qkv.weight" : torch.Size([9216, 3072]),
    "double_blocks.16.img_mlp.0.bias" : torch.Size([12288]),
    "double_blocks.16.img_mlp.0.weight" : torch.Size([12288, 3072]),
    "double_blocks.16.img_mlp.2.bias" : torch.Size([3072]),
    "double_blocks.16.img_mlp.2.weight" : torch.Size([3072, 12288]),
    "double_blocks.16.txt_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.16.txt_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.16.txt_attn.proj.bias" : torch.Size([3072]),
    "double_blocks.16.txt_attn.proj.weight" : torch.Size([3072, 3072]),
    "double_blocks.16.txt_attn.qkv.bias" : torch.Size([9216]),
    "double_blocks.16.txt_attn.qkv.weight" : torch.Size([9216, 3072]),
    "double_blocks.16.txt_mlp.0.bias" : torch.Size([12288]),
    "double_blocks.16.txt_mlp.0.weight" : torch.Size([12288, 3072]),
    "double_blocks.16.txt_mlp.2.bias" : torch.Size([3072]),
    "double_blocks.16.txt_mlp.2.weight" : torch.Size([3072, 12288]),
    "double_blocks.17.img_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.17.img_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.17.img_attn.proj.bias" : torch.Size([3072]),
    "double_blocks.17.img_attn.proj.weight" : torch.Size([3072, 3072]),
    "double_blocks.17.img_attn.qkv.bias" : torch.Size([9216]),
    "double_blocks.17.img_attn.qkv.weight" : torch.Size([9216, 3072]),
    "double_blocks.17.img_mlp.0.bias" : torch.Size([12288]),
    "double_blocks.17.img_mlp.0.weight" : torch.Size([12288, 3072]),
    "double_blocks.17.img_mlp.2.bias" : torch.Size([3072]),
    "double_blocks.17.img_mlp.2.weight" : torch.Size([3072, 12288]),
    "double_blocks.17.txt_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.17.txt_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.17.txt_attn.proj.bias" : torch.Size([3072]),
    "double_blocks.17.txt_attn.proj.weight" : torch.Size([3072, 3072]),
    "double_blocks.17.txt_attn.qkv.bias" : torch.Size([9216]),
    "double_blocks.17.txt_attn.qkv.weight" : torch.Size([9216, 3072]),
    "double_blocks.17.txt_mlp.0.bias" : torch.Size([12288]),
    "double_blocks.17.txt_mlp.0.weight" : torch.Size([12288, 3072]),
    "double_blocks.17.txt_mlp.2.bias" : torch.Size([3072]),
    "double_blocks.17.txt_mlp.2.weight" : torch.Size([3072, 12288]),
    "double_blocks.18.img_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.18.img_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.18.img_attn.proj.bias" : torch.Size([3072]),
    "double_blocks.18.img_attn.proj.weight" : torch.Size([3072, 3072]),
    "double_blocks.18.img_attn.qkv.bias" : torch.Size([9216]),
    "double_blocks.18.img_attn.qkv.weight" : torch.Size([9216, 3072]),
    "double_blocks.18.img_mlp.0.bias" : torch.Size([12288]),
    "double_blocks.18.img_mlp.0.weight" : torch.Size([12288, 3072]),
    "double_blocks.18.img_mlp.2.bias" : torch.Size([3072]),
    "double_blocks.18.img_mlp.2.weight" : torch.Size([3072, 12288]),
    "double_blocks.18.txt_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.18.txt_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.18.txt_attn.proj.bias" : torch.Size([3072]),
    "double_blocks.18.txt_attn.proj.weight" : torch.Size([3072, 3072]),
    "double_blocks.18.txt_attn.qkv.bias" : torch.Size([9216]),
    "double_blocks.18.txt_attn.qkv.weight" : torch.Size([9216, 3072]),
    "double_blocks.18.txt_mlp.0.bias" : torch.Size([12288]),
    "double_blocks.18.txt_mlp.0.weight" : torch.Size([12288, 3072]),
    "double_blocks.18.txt_mlp.2.bias" : torch.Size([3072]),
    "double_blocks.18.txt_mlp.2.weight" : torch.Size([3072, 12288]),
    "double_blocks.2.img_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.2.img_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.2.img_attn.proj.bias" : torch.Size([3072]),
    "double_blocks.2.img_attn.proj.weight" : torch.Size([3072, 3072]),
    "double_blocks.2.img_attn.qkv.bias" : torch.Size([9216]),
    "double_blocks.2.img_attn.qkv.weight" : torch.Size([9216, 3072]),
    "double_blocks.2.img_mlp.0.bias" : torch.Size([12288]),
    "double_blocks.2.img_mlp.0.weight" : torch.Size([12288, 3072]),
    "double_blocks.2.img_mlp.2.bias" : torch.Size([3072]),
    "double_blocks.2.img_mlp.2.weight" : torch.Size([3072, 12288]),
    "double_blocks.2.txt_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.2.txt_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.2.txt_attn.proj.bias" : torch.Size([3072]),
    "double_blocks.2.txt_attn.proj.weight" : torch.Size([3072, 3072]),
    "double_blocks.2.txt_attn.qkv.bias" : torch.Size([9216]),
    "double_blocks.2.txt_attn.qkv.weight" : torch.Size([9216, 3072]),
    "double_blocks.2.txt_mlp.0.bias" : torch.Size([12288]),
    "double_blocks.2.txt_mlp.0.weight" : torch.Size([12288, 3072]),
    "double_blocks.2.txt_mlp.2.bias" : torch.Size([3072]),
    "double_blocks.2.txt_mlp.2.weight" : torch.Size([3072, 12288]),
    "double_blocks.3.img_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.3.img_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.3.img_attn.proj.bias" : torch.Size([3072]),
    "double_blocks.3.img_attn.proj.weight" : torch.Size([3072, 3072]),
    "double_blocks.3.img_attn.qkv.bias" : torch.Size([9216]),
    "double_blocks.3.img_attn.qkv.weight" : torch.Size([9216, 3072]),
    "double_blocks.3.img_mlp.0.bias" : torch.Size([12288]),
    "double_blocks.3.img_mlp.0.weight" : torch.Size([12288, 3072]),
    "double_blocks.3.img_mlp.2.bias" : torch.Size([3072]),
    "double_blocks.3.img_mlp.2.weight" : torch.Size([3072, 12288]),
    "double_blocks.3.txt_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.3.txt_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.3.txt_attn.proj.bias" : torch.Size([3072]),
    "double_blocks.3.txt_attn.proj.weight" : torch.Size([3072, 3072]),
    "double_blocks.3.txt_attn.qkv.bias" : torch.Size([9216]),
    "double_blocks.3.txt_attn.qkv.weight" : torch.Size([9216, 3072]),
    "double_blocks.3.txt_mlp.0.bias" : torch.Size([12288]),
    "double_blocks.3.txt_mlp.0.weight" : torch.Size([12288, 3072]),
    "double_blocks.3.txt_mlp.2.bias" : torch.Size([3072]),
    "double_blocks.3.txt_mlp.2.weight" : torch.Size([3072, 12288]),
    "double_blocks.4.img_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.4.img_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.4.img_attn.proj.bias" : torch.Size([3072]),
    "double_blocks.4.img_attn.proj.weight" : torch.Size([3072, 3072]),
    "double_blocks.4.img_attn.qkv.bias" : torch.Size([9216]),
    "double_blocks.4.img_attn.qkv.weight" : torch.Size([9216, 3072]),
    "double_blocks.4.img_mlp.0.bias" : torch.Size([12288]),
    "double_blocks.4.img_mlp.0.weight" : torch.Size([12288, 3072]),
    "double_blocks.4.img_mlp.2.bias" : torch.Size([3072]),
    "double_blocks.4.img_mlp.2.weight" : torch.Size([3072, 12288]),
    "double_blocks.4.txt_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.4.txt_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.4.txt_attn.proj.bias" : torch.Size([3072]),
    "double_blocks.4.txt_attn.proj.weight" : torch.Size([3072, 3072]),
    "double_blocks.4.txt_attn.qkv.bias" : torch.Size([9216]),
    "double_blocks.4.txt_attn.qkv.weight" : torch.Size([9216, 3072]),
    "double_blocks.4.txt_mlp.0.bias" : torch.Size([12288]),
    "double_blocks.4.txt_mlp.0.weight" : torch.Size([12288, 3072]),
    "double_blocks.4.txt_mlp.2.bias" : torch.Size([3072]),
    "double_blocks.4.txt_mlp.2.weight" : torch.Size([3072, 12288]),
    "double_blocks.5.img_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.5.img_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.5.img_attn.proj.bias" : torch.Size([3072]),
    "double_blocks.5.img_attn.proj.weight" : torch.Size([3072, 3072]),
    "double_blocks.5.img_attn.qkv.bias" : torch.Size([9216]),
    "double_blocks.5.img_attn.qkv.weight" : torch.Size([9216, 3072]),
    "double_blocks.5.img_mlp.0.bias" : torch.Size([12288]),
    "double_blocks.5.img_mlp.0.weight" : torch.Size([12288, 3072]),
    "double_blocks.5.img_mlp.2.bias" : torch.Size([3072]),
    "double_blocks.5.img_mlp.2.weight" : torch.Size([3072, 12288]),
    "double_blocks.5.txt_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.5.txt_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.5.txt_attn.proj.bias" : torch.Size([3072]),
    "double_blocks.5.txt_attn.proj.weight" : torch.Size([3072, 3072]),
    "double_blocks.5.txt_attn.qkv.bias" : torch.Size([9216]),
    "double_blocks.5.txt_attn.qkv.weight" : torch.Size([9216, 3072]),
    "double_blocks.5.txt_mlp.0.bias" : torch.Size([12288]),
    "double_blocks.5.txt_mlp.0.weight" : torch.Size([12288, 3072]),
    "double_blocks.5.txt_mlp.2.bias" : torch.Size([3072]),
    "double_blocks.5.txt_mlp.2.weight" : torch.Size([3072, 12288]),
    "double_blocks.6.img_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.6.img_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.6.img_attn.proj.bias" : torch.Size([3072]),
    "double_blocks.6.img_attn.proj.weight" : torch.Size([3072, 3072]),
    "double_blocks.6.img_attn.qkv.bias" : torch.Size([9216]),
    "double_blocks.6.img_attn.qkv.weight" : torch.Size([9216, 3072]),
    "double_blocks.6.img_mlp.0.bias" : torch.Size([12288]),
    "double_blocks.6.img_mlp.0.weight" : torch.Size([12288, 3072]),
    "double_blocks.6.img_mlp.2.bias" : torch.Size([3072]),
    "double_blocks.6.img_mlp.2.weight" : torch.Size([3072, 12288]),
    "double_blocks.6.txt_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.6.txt_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.6.txt_attn.proj.bias" : torch.Size([3072]),
    "double_blocks.6.txt_attn.proj.weight" : torch.Size([3072, 3072]),
    "double_blocks.6.txt_attn.qkv.bias" : torch.Size([9216]),
    "double_blocks.6.txt_attn.qkv.weight" : torch.Size([9216, 3072]),
    "double_blocks.6.txt_mlp.0.bias" : torch.Size([12288]),
    "double_blocks.6.txt_mlp.0.weight" : torch.Size([12288, 3072]),
    "double_blocks.6.txt_mlp.2.bias" : torch.Size([3072]),
    "double_blocks.6.txt_mlp.2.weight" : torch.Size([3072, 12288]),
    "double_blocks.7.img_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.7.img_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.7.img_attn.proj.bias" : torch.Size([3072]),
    "double_blocks.7.img_attn.proj.weight" : torch.Size([3072, 3072]),
    "double_blocks.7.img_attn.qkv.bias" : torch.Size([9216]),
    "double_blocks.7.img_attn.qkv.weight" : torch.Size([9216, 3072]),
    "double_blocks.7.img_mlp.0.bias" : torch.Size([12288]),
    "double_blocks.7.img_mlp.0.weight" : torch.Size([12288, 3072]),
    "double_blocks.7.img_mlp.2.bias" : torch.Size([3072]),
    "double_blocks.7.img_mlp.2.weight" : torch.Size([3072, 12288]),
    "double_blocks.7.txt_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.7.txt_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.7.txt_attn.proj.bias" : torch.Size([3072]),
    "double_blocks.7.txt_attn.proj.weight" : torch.Size([3072, 3072]),
    "double_blocks.7.txt_attn.qkv.bias" : torch.Size([9216]),
    "double_blocks.7.txt_attn.qkv.weight" : torch.Size([9216, 3072]),
    "double_blocks.7.txt_mlp.0.bias" : torch.Size([12288]),
    "double_blocks.7.txt_mlp.0.weight" : torch.Size([12288, 3072]),
    "double_blocks.7.txt_mlp.2.bias" : torch.Size([3072]),
    "double_blocks.7.txt_mlp.2.weight" : torch.Size([3072, 12288]),
    "double_blocks.8.img_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.8.img_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.8.img_attn.proj.bias" : torch.Size([3072]),
    "double_blocks.8.img_attn.proj.weight" : torch.Size([3072, 3072]),
    "double_blocks.8.img_attn.qkv.bias" : torch.Size([9216]),
    "double_blocks.8.img_attn.qkv.weight" : torch.Size([9216, 3072]),
    "double_blocks.8.img_mlp.0.bias" : torch.Size([12288]),
    "double_blocks.8.img_mlp.0.weight" : torch.Size([12288, 3072]),
    "double_blocks.8.img_mlp.2.bias" : torch.Size([3072]),
    "double_blocks.8.img_mlp.2.weight" : torch.Size([3072, 12288]),
    "double_blocks.8.txt_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.8.txt_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.8.txt_attn.proj.bias" : torch.Size([3072]),
    "double_blocks.8.txt_attn.proj.weight" : torch.Size([3072, 3072]),
    "double_blocks.8.txt_attn.qkv.bias" : torch.Size([9216]),
    "double_blocks.8.txt_attn.qkv.weight" : torch.Size([9216, 3072]),
    "double_blocks.8.txt_mlp.0.bias" : torch.Size([12288]),
    "double_blocks.8.txt_mlp.0.weight" : torch.Size([12288, 3072]),
    "double_blocks.8.txt_mlp.2.bias" : torch.Size([3072]),
    "double_blocks.8.txt_mlp.2.weight" : torch.Size([3072, 12288]),
    "double_blocks.9.img_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.9.img_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.9.img_attn.proj.bias" : torch.Size([3072]),
    "double_blocks.9.img_attn.proj.weight" : torch.Size([3072, 3072]),
    "double_blocks.9.img_attn.qkv.bias" : torch.Size([9216]),
    "double_blocks.9.img_attn.qkv.weight" : torch.Size([9216, 3072]),
    "double_blocks.9.img_mlp.0.bias" : torch.Size([12288]),
    "double_blocks.9.img_mlp.0.weight" : torch.Size([12288, 3072]),
    "double_blocks.9.img_mlp.2.bias" : torch.Size([3072]),
    "double_blocks.9.img_mlp.2.weight" : torch.Size([3072, 12288]),
    "double_blocks.9.txt_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.9.txt_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.9.txt_attn.proj.bias" : torch.Size([3072]),
    "double_blocks.9.txt_attn.proj.weight" : torch.Size([3072, 3072]),
    "double_blocks.9.txt_attn.qkv.bias" : torch.Size([9216]),
    "double_blocks.9.txt_attn.qkv.weight" : torch.Size([9216, 3072]),
    "double_blocks.9.txt_mlp.0.bias" : torch.Size([12288]),
    "double_blocks.9.txt_mlp.0.weight" : torch.Size([12288, 3072]),
    "double_blocks.9.txt_mlp.2.bias" : torch.Size([3072]),
    "double_blocks.9.txt_mlp.2.weight" : torch.Size([3072, 12288]),
    "final_layer.linear.bias" : torch.Size([64]),
    "final_layer.linear.weight" : torch.Size([64, 3072]),
    "img_in.bias" : torch.Size([3072]),
    "img_in.weight" : torch.Size([3072, 64]),
    "single_blocks.0.linear1.bias" : torch.Size([21504]),
    "single_blocks.0.linear1.weight" : torch.Size([21504, 3072]),
    "single_blocks.0.linear2.bias" : torch.Size([3072]),
    "single_blocks.0.linear2.weight" : torch.Size([3072, 15360]),
    "single_blocks.0.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.0.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.1.linear1.bias" : torch.Size([21504]),
    "single_blocks.1.linear1.weight" : torch.Size([21504, 3072]),
    "single_blocks.1.linear2.bias" : torch.Size([3072]),
    "single_blocks.1.linear2.weight" : torch.Size([3072, 15360]),
    "single_blocks.1.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.1.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.10.linear1.bias" : torch.Size([21504]),
    "single_blocks.10.linear1.weight" : torch.Size([21504, 3072]),
    "single_blocks.10.linear2.bias" : torch.Size([3072]),
    "single_blocks.10.linear2.weight" : torch.Size([3072, 15360]),
    "single_blocks.10.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.10.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.11.linear1.bias" : torch.Size([21504]),
    "single_blocks.11.linear1.weight" : torch.Size([21504, 3072]),
    "single_blocks.11.linear2.bias" : torch.Size([3072]),
    "single_blocks.11.linear2.weight" : torch.Size([3072, 15360]),
    "single_blocks.11.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.11.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.12.linear1.bias" : torch.Size([21504]),
    "single_blocks.12.linear1.weight" : torch.Size([21504, 3072]),
    "single_blocks.12.linear2.bias" : torch.Size([3072]),
    "single_blocks.12.linear2.weight" : torch.Size([3072, 15360]),
    "single_blocks.12.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.12.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.13.linear1.bias" : torch.Size([21504]),
    "single_blocks.13.linear1.weight" : torch.Size([21504, 3072]),
    "single_blocks.13.linear2.bias" : torch.Size([3072]),
    "single_blocks.13.linear2.weight" : torch.Size([3072, 15360]),
    "single_blocks.13.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.13.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.14.linear1.bias" : torch.Size([21504]),
    "single_blocks.14.linear1.weight" : torch.Size([21504, 3072]),
    "single_blocks.14.linear2.bias" : torch.Size([3072]),
    "single_blocks.14.linear2.weight" : torch.Size([3072, 15360]),
    "single_blocks.14.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.14.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.15.linear1.bias" : torch.Size([21504]),
    "single_blocks.15.linear1.weight" : torch.Size([21504, 3072]),
    "single_blocks.15.linear2.bias" : torch.Size([3072]),
    "single_blocks.15.linear2.weight" : torch.Size([3072, 15360]),
    "single_blocks.15.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.15.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.16.linear1.bias" : torch.Size([21504]),
    "single_blocks.16.linear1.weight" : torch.Size([21504, 3072]),
    "single_blocks.16.linear2.bias" : torch.Size([3072]),
    "single_blocks.16.linear2.weight" : torch.Size([3072, 15360]),
    "single_blocks.16.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.16.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.17.linear1.bias" : torch.Size([21504]),
    "single_blocks.17.linear1.weight" : torch.Size([21504, 3072]),
    "single_blocks.17.linear2.bias" : torch.Size([3072]),
    "single_blocks.17.linear2.weight" : torch.Size([3072, 15360]),
    "single_blocks.17.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.17.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.18.linear1.bias" : torch.Size([21504]),
    "single_blocks.18.linear1.weight" : torch.Size([21504, 3072]),
    "single_blocks.18.linear2.bias" : torch.Size([3072]),
    "single_blocks.18.linear2.weight" : torch.Size([3072, 15360]),
    "single_blocks.18.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.18.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.19.linear1.bias" : torch.Size([21504]),
    "single_blocks.19.linear1.weight" : torch.Size([21504, 3072]),
    "single_blocks.19.linear2.bias" : torch.Size([3072]),
    "single_blocks.19.linear2.weight" : torch.Size([3072, 15360]),
    "single_blocks.19.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.19.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.2.linear1.bias" : torch.Size([21504]),
    "single_blocks.2.linear1.weight" : torch.Size([21504, 3072]),
    "single_blocks.2.linear2.bias" : torch.Size([3072]),
    "single_blocks.2.linear2.weight" : torch.Size([3072, 15360]),
    "single_blocks.2.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.2.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.20.linear1.bias" : torch.Size([21504]),
    "single_blocks.20.linear1.weight" : torch.Size([21504, 3072]),
    "single_blocks.20.linear2.bias" : torch.Size([3072]),
    "single_blocks.20.linear2.weight" : torch.Size([3072, 15360]),
    "single_blocks.20.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.20.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.21.linear1.bias" : torch.Size([21504]),
    "single_blocks.21.linear1.weight" : torch.Size([21504, 3072]),
    "single_blocks.21.linear2.bias" : torch.Size([3072]),
    "single_blocks.21.linear2.weight" : torch.Size([3072, 15360]),
    "single_blocks.21.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.21.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.22.linear1.bias" : torch.Size([21504]),
    "single_blocks.22.linear1.weight" : torch.Size([21504, 3072]),
    "single_blocks.22.linear2.bias" : torch.Size([3072]),
    "single_blocks.22.linear2.weight" : torch.Size([3072, 15360]),
    "single_blocks.22.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.22.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.23.linear1.bias" : torch.Size([21504]),
    "single_blocks.23.linear1.weight" : torch.Size([21504, 3072]),
    "single_blocks.23.linear2.bias" : torch.Size([3072]),
    "single_blocks.23.linear2.weight" : torch.Size([3072, 15360]),
    "single_blocks.23.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.23.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.24.linear1.bias" : torch.Size([21504]),
    "single_blocks.24.linear1.weight" : torch.Size([21504, 3072]),
    "single_blocks.24.linear2.bias" : torch.Size([3072]),
    "single_blocks.24.linear2.weight" : torch.Size([3072, 15360]),
    "single_blocks.24.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.24.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.25.linear1.bias" : torch.Size([21504]),
    "single_blocks.25.linear1.weight" : torch.Size([21504, 3072]),
    "single_blocks.25.linear2.bias" : torch.Size([3072]),
    "single_blocks.25.linear2.weight" : torch.Size([3072, 15360]),
    "single_blocks.25.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.25.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.26.linear1.bias" : torch.Size([21504]),
    "single_blocks.26.linear1.weight" : torch.Size([21504, 3072]),
    "single_blocks.26.linear2.bias" : torch.Size([3072]),
    "single_blocks.26.linear2.weight" : torch.Size([3072, 15360]),
    "single_blocks.26.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.26.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.27.linear1.bias" : torch.Size([21504]),
    "single_blocks.27.linear1.weight" : torch.Size([21504, 3072]),
    "single_blocks.27.linear2.bias" : torch.Size([3072]),
    "single_blocks.27.linear2.weight" : torch.Size([3072, 15360]),
    "single_blocks.27.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.27.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.28.linear1.bias" : torch.Size([21504]),
    "single_blocks.28.linear1.weight" : torch.Size([21504, 3072]),
    "single_blocks.28.linear2.bias" : torch.Size([3072]),
    "single_blocks.28.linear2.weight" : torch.Size([3072, 15360]),
    "single_blocks.28.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.28.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.29.linear1.bias" : torch.Size([21504]),
    "single_blocks.29.linear1.weight" : torch.Size([21504, 3072]),
    "single_blocks.29.linear2.bias" : torch.Size([3072]),
    "single_blocks.29.linear2.weight" : torch.Size([3072, 15360]),
    "single_blocks.29.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.29.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.3.linear1.bias" : torch.Size([21504]),
    "single_blocks.3.linear1.weight" : torch.Size([21504, 3072]),
    "single_blocks.3.linear2.bias" : torch.Size([3072]),
    "single_blocks.3.linear2.weight" : torch.Size([3072, 15360]),
    "single_blocks.3.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.3.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.30.linear1.bias" : torch.Size([21504]),
    "single_blocks.30.linear1.weight" : torch.Size([21504, 3072]),
    "single_blocks.30.linear2.bias" : torch.Size([3072]),
    "single_blocks.30.linear2.weight" : torch.Size([3072, 15360]),
    "single_blocks.30.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.30.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.31.linear1.bias" : torch.Size([21504]),
    "single_blocks.31.linear1.weight" : torch.Size([21504, 3072]),
    "single_blocks.31.linear2.bias" : torch.Size([3072]),
    "single_blocks.31.linear2.weight" : torch.Size([3072, 15360]),
    "single_blocks.31.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.31.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.32.linear1.bias" : torch.Size([21504]),
    "single_blocks.32.linear1.weight" : torch.Size([21504, 3072]),
    "single_blocks.32.linear2.bias" : torch.Size([3072]),
    "single_blocks.32.linear2.weight" : torch.Size([3072, 15360]),
    "single_blocks.32.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.32.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.33.linear1.bias" : torch.Size([21504]),
    "single_blocks.33.linear1.weight" : torch.Size([21504, 3072]),
    "single_blocks.33.linear2.bias" : torch.Size([3072]),
    "single_blocks.33.linear2.weight" : torch.Size([3072, 15360]),
    "single_blocks.33.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.33.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.34.linear1.bias" : torch.Size([21504]),
    "single_blocks.34.linear1.weight" : torch.Size([21504, 3072]),
    "single_blocks.34.linear2.bias" : torch.Size([3072]),
    "single_blocks.34.linear2.weight" : torch.Size([3072, 15360]),
    "single_blocks.34.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.34.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.35.linear1.bias" : torch.Size([21504]),
    "single_blocks.35.linear1.weight" : torch.Size([21504, 3072]),
    "single_blocks.35.linear2.bias" : torch.Size([3072]),
    "single_blocks.35.linear2.weight" : torch.Size([3072, 15360]),
    "single_blocks.35.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.35.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.36.linear1.bias" : torch.Size([21504]),
    "single_blocks.36.linear1.weight" : torch.Size([21504, 3072]),
    "single_blocks.36.linear2.bias" : torch.Size([3072]),
    "single_blocks.36.linear2.weight" : torch.Size([3072, 15360]),
    "single_blocks.36.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.36.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.37.linear1.bias" : torch.Size([21504]),
    "single_blocks.37.linear1.weight" : torch.Size([21504, 3072]),
    "single_blocks.37.linear2.bias" : torch.Size([3072]),
    "single_blocks.37.linear2.weight" : torch.Size([3072, 15360]),
    "single_blocks.37.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.37.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.4.linear1.bias" : torch.Size([21504]),
    "single_blocks.4.linear1.weight" : torch.Size([21504, 3072]),
    "single_blocks.4.linear2.bias" : torch.Size([3072]),
    "single_blocks.4.linear2.weight" : torch.Size([3072, 15360]),
    "single_blocks.4.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.4.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.5.linear1.bias" : torch.Size([21504]),
    "single_blocks.5.linear1.weight" : torch.Size([21504, 3072]),
    "single_blocks.5.linear2.bias" : torch.Size([3072]),
    "single_blocks.5.linear2.weight" : torch.Size([3072, 15360]),
    "single_blocks.5.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.5.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.6.linear1.bias" : torch.Size([21504]),
    "single_blocks.6.linear1.weight" : torch.Size([21504, 3072]),
    "single_blocks.6.linear2.bias" : torch.Size([3072]),
    "single_blocks.6.linear2.weight" : torch.Size([3072, 15360]),
    "single_blocks.6.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.6.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.7.linear1.bias" : torch.Size([21504]),
    "single_blocks.7.linear1.weight" : torch.Size([21504, 3072]),
    "single_blocks.7.linear2.bias" : torch.Size([3072]),
    "single_blocks.7.linear2.weight" : torch.Size([3072, 15360]),
    "single_blocks.7.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.7.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.8.linear1.bias" : torch.Size([21504]),
    "single_blocks.8.linear1.weight" : torch.Size([21504, 3072]),
    "single_blocks.8.linear2.bias" : torch.Size([3072]),
    "single_blocks.8.linear2.weight" : torch.Size([3072, 15360]),
    "single_blocks.8.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.8.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.9.linear1.bias" : torch.Size([21504]),
    "single_blocks.9.linear1.weight" : torch.Size([21504, 3072]),
    "single_blocks.9.linear2.bias" : torch.Size([3072]),
    "single_blocks.9.linear2.weight" : torch.Size([3072, 15360]),
    "single_blocks.9.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.9.norm.query_norm.scale" : torch.Size([128]),
    "txt_in.bias" : torch.Size([3072]),
    "txt_in.weight" : torch.Size([3072, 4096]),
}


flux_schnell_keys_dict = {
    "double_blocks.0.img_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.0.img_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.0.img_attn.proj.bias" : torch.Size([3072]),
    "double_blocks.0.img_attn.proj.weight" : torch.Size([3072, 3072]),
    "double_blocks.0.img_attn.qkv.bias" : torch.Size([9216]),
    "double_blocks.0.img_attn.qkv.weight" : torch.Size([9216, 3072]),
    "double_blocks.0.img_mlp.0.bias" : torch.Size([12288]),
    "double_blocks.0.img_mlp.0.weight" : torch.Size([12288, 3072]),
    "double_blocks.0.img_mlp.2.bias" : torch.Size([3072]),
    "double_blocks.0.img_mlp.2.weight" : torch.Size([3072, 12288]),
    "double_blocks.0.img_mod.lin.bias" : torch.Size([18432]),
    "double_blocks.0.img_mod.lin.weight" : torch.Size([18432, 3072]),
    "double_blocks.0.txt_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.0.txt_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.0.txt_attn.proj.bias" : torch.Size([3072]),
    "double_blocks.0.txt_attn.proj.weight" : torch.Size([3072, 3072]),
    "double_blocks.0.txt_attn.qkv.bias" : torch.Size([9216]),
    "double_blocks.0.txt_attn.qkv.weight" : torch.Size([9216, 3072]),
    "double_blocks.0.txt_mlp.0.bias" : torch.Size([12288]),
    "double_blocks.0.txt_mlp.0.weight" : torch.Size([12288, 3072]),
    "double_blocks.0.txt_mlp.2.bias" : torch.Size([3072]),
    "double_blocks.0.txt_mlp.2.weight" : torch.Size([3072, 12288]),
    "double_blocks.0.txt_mod.lin.bias" : torch.Size([18432]),
    "double_blocks.0.txt_mod.lin.weight" : torch.Size([18432, 3072]),
    "double_blocks.1.img_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.1.img_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.1.img_attn.proj.bias" : torch.Size([3072]),
    "double_blocks.1.img_attn.proj.weight" : torch.Size([3072, 3072]),
    "double_blocks.1.img_attn.qkv.bias" : torch.Size([9216]),
    "double_blocks.1.img_attn.qkv.weight" : torch.Size([9216, 3072]),
    "double_blocks.1.img_mlp.0.bias" : torch.Size([12288]),
    "double_blocks.1.img_mlp.0.weight" : torch.Size([12288, 3072]),
    "double_blocks.1.img_mlp.2.bias" : torch.Size([3072]),
    "double_blocks.1.img_mlp.2.weight" : torch.Size([3072, 12288]),
    "double_blocks.1.img_mod.lin.bias" : torch.Size([18432]),
    "double_blocks.1.img_mod.lin.weight" : torch.Size([18432, 3072]),
    "double_blocks.1.txt_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.1.txt_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.1.txt_attn.proj.bias" : torch.Size([3072]),
    "double_blocks.1.txt_attn.proj.weight" : torch.Size([3072, 3072]),
    "double_blocks.1.txt_attn.qkv.bias" : torch.Size([9216]),
    "double_blocks.1.txt_attn.qkv.weight" : torch.Size([9216, 3072]),
    "double_blocks.1.txt_mlp.0.bias" : torch.Size([12288]),
    "double_blocks.1.txt_mlp.0.weight" : torch.Size([12288, 3072]),
    "double_blocks.1.txt_mlp.2.bias" : torch.Size([3072]),
    "double_blocks.1.txt_mlp.2.weight" : torch.Size([3072, 12288]),
    "double_blocks.1.txt_mod.lin.bias" : torch.Size([18432]),
    "double_blocks.1.txt_mod.lin.weight" : torch.Size([18432, 3072]),
    "double_blocks.10.img_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.10.img_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.10.img_attn.proj.bias" : torch.Size([3072]),
    "double_blocks.10.img_attn.proj.weight" : torch.Size([3072, 3072]),
    "double_blocks.10.img_attn.qkv.bias" : torch.Size([9216]),
    "double_blocks.10.img_attn.qkv.weight" : torch.Size([9216, 3072]),
    "double_blocks.10.img_mlp.0.bias" : torch.Size([12288]),
    "double_blocks.10.img_mlp.0.weight" : torch.Size([12288, 3072]),
    "double_blocks.10.img_mlp.2.bias" : torch.Size([3072]),
    "double_blocks.10.img_mlp.2.weight" : torch.Size([3072, 12288]),
    "double_blocks.10.img_mod.lin.bias" : torch.Size([18432]),
    "double_blocks.10.img_mod.lin.weight" : torch.Size([18432, 3072]),
    "double_blocks.10.txt_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.10.txt_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.10.txt_attn.proj.bias" : torch.Size([3072]),
    "double_blocks.10.txt_attn.proj.weight" : torch.Size([3072, 3072]),
    "double_blocks.10.txt_attn.qkv.bias" : torch.Size([9216]),
    "double_blocks.10.txt_attn.qkv.weight" : torch.Size([9216, 3072]),
    "double_blocks.10.txt_mlp.0.bias" : torch.Size([12288]),
    "double_blocks.10.txt_mlp.0.weight" : torch.Size([12288, 3072]),
    "double_blocks.10.txt_mlp.2.bias" : torch.Size([3072]),
    "double_blocks.10.txt_mlp.2.weight" : torch.Size([3072, 12288]),
    "double_blocks.10.txt_mod.lin.bias" : torch.Size([18432]),
    "double_blocks.10.txt_mod.lin.weight" : torch.Size([18432, 3072]),
    "double_blocks.11.img_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.11.img_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.11.img_attn.proj.bias" : torch.Size([3072]),
    "double_blocks.11.img_attn.proj.weight" : torch.Size([3072, 3072]),
    "double_blocks.11.img_attn.qkv.bias" : torch.Size([9216]),
    "double_blocks.11.img_attn.qkv.weight" : torch.Size([9216, 3072]),
    "double_blocks.11.img_mlp.0.bias" : torch.Size([12288]),
    "double_blocks.11.img_mlp.0.weight" : torch.Size([12288, 3072]),
    "double_blocks.11.img_mlp.2.bias" : torch.Size([3072]),
    "double_blocks.11.img_mlp.2.weight" : torch.Size([3072, 12288]),
    "double_blocks.11.img_mod.lin.bias" : torch.Size([18432]),
    "double_blocks.11.img_mod.lin.weight" : torch.Size([18432, 3072]),
    "double_blocks.11.txt_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.11.txt_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.11.txt_attn.proj.bias" : torch.Size([3072]),
    "double_blocks.11.txt_attn.proj.weight" : torch.Size([3072, 3072]),
    "double_blocks.11.txt_attn.qkv.bias" : torch.Size([9216]),
    "double_blocks.11.txt_attn.qkv.weight" : torch.Size([9216, 3072]),
    "double_blocks.11.txt_mlp.0.bias" : torch.Size([12288]),
    "double_blocks.11.txt_mlp.0.weight" : torch.Size([12288, 3072]),
    "double_blocks.11.txt_mlp.2.bias" : torch.Size([3072]),
    "double_blocks.11.txt_mlp.2.weight" : torch.Size([3072, 12288]),
    "double_blocks.11.txt_mod.lin.bias" : torch.Size([18432]),
    "double_blocks.11.txt_mod.lin.weight" : torch.Size([18432, 3072]),
    "double_blocks.12.img_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.12.img_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.12.img_attn.proj.bias" : torch.Size([3072]),
    "double_blocks.12.img_attn.proj.weight" : torch.Size([3072, 3072]),
    "double_blocks.12.img_attn.qkv.bias" : torch.Size([9216]),
    "double_blocks.12.img_attn.qkv.weight" : torch.Size([9216, 3072]),
    "double_blocks.12.img_mlp.0.bias" : torch.Size([12288]),
    "double_blocks.12.img_mlp.0.weight" : torch.Size([12288, 3072]),
    "double_blocks.12.img_mlp.2.bias" : torch.Size([3072]),
    "double_blocks.12.img_mlp.2.weight" : torch.Size([3072, 12288]),
    "double_blocks.12.img_mod.lin.bias" : torch.Size([18432]),
    "double_blocks.12.img_mod.lin.weight" : torch.Size([18432, 3072]),
    "double_blocks.12.txt_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.12.txt_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.12.txt_attn.proj.bias" : torch.Size([3072]),
    "double_blocks.12.txt_attn.proj.weight" : torch.Size([3072, 3072]),
    "double_blocks.12.txt_attn.qkv.bias" : torch.Size([9216]),
    "double_blocks.12.txt_attn.qkv.weight" : torch.Size([9216, 3072]),
    "double_blocks.12.txt_mlp.0.bias" : torch.Size([12288]),
    "double_blocks.12.txt_mlp.0.weight" : torch.Size([12288, 3072]),
    "double_blocks.12.txt_mlp.2.bias" : torch.Size([3072]),
    "double_blocks.12.txt_mlp.2.weight" : torch.Size([3072, 12288]),
    "double_blocks.12.txt_mod.lin.bias" : torch.Size([18432]),
    "double_blocks.12.txt_mod.lin.weight" : torch.Size([18432, 3072]),
    "double_blocks.13.img_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.13.img_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.13.img_attn.proj.bias" : torch.Size([3072]),
    "double_blocks.13.img_attn.proj.weight" : torch.Size([3072, 3072]),
    "double_blocks.13.img_attn.qkv.bias" : torch.Size([9216]),
    "double_blocks.13.img_attn.qkv.weight" : torch.Size([9216, 3072]),
    "double_blocks.13.img_mlp.0.bias" : torch.Size([12288]),
    "double_blocks.13.img_mlp.0.weight" : torch.Size([12288, 3072]),
    "double_blocks.13.img_mlp.2.bias" : torch.Size([3072]),
    "double_blocks.13.img_mlp.2.weight" : torch.Size([3072, 12288]),
    "double_blocks.13.img_mod.lin.bias" : torch.Size([18432]),
    "double_blocks.13.img_mod.lin.weight" : torch.Size([18432, 3072]),
    "double_blocks.13.txt_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.13.txt_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.13.txt_attn.proj.bias" : torch.Size([3072]),
    "double_blocks.13.txt_attn.proj.weight" : torch.Size([3072, 3072]),
    "double_blocks.13.txt_attn.qkv.bias" : torch.Size([9216]),
    "double_blocks.13.txt_attn.qkv.weight" : torch.Size([9216, 3072]),
    "double_blocks.13.txt_mlp.0.bias" : torch.Size([12288]),
    "double_blocks.13.txt_mlp.0.weight" : torch.Size([12288, 3072]),
    "double_blocks.13.txt_mlp.2.bias" : torch.Size([3072]),
    "double_blocks.13.txt_mlp.2.weight" : torch.Size([3072, 12288]),
    "double_blocks.13.txt_mod.lin.bias" : torch.Size([18432]),
    "double_blocks.13.txt_mod.lin.weight" : torch.Size([18432, 3072]),
    "double_blocks.14.img_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.14.img_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.14.img_attn.proj.bias" : torch.Size([3072]),
    "double_blocks.14.img_attn.proj.weight" : torch.Size([3072, 3072]),
    "double_blocks.14.img_attn.qkv.bias" : torch.Size([9216]),
    "double_blocks.14.img_attn.qkv.weight" : torch.Size([9216, 3072]),
    "double_blocks.14.img_mlp.0.bias" : torch.Size([12288]),
    "double_blocks.14.img_mlp.0.weight" : torch.Size([12288, 3072]),
    "double_blocks.14.img_mlp.2.bias" : torch.Size([3072]),
    "double_blocks.14.img_mlp.2.weight" : torch.Size([3072, 12288]),
    "double_blocks.14.img_mod.lin.bias" : torch.Size([18432]),
    "double_blocks.14.img_mod.lin.weight" : torch.Size([18432, 3072]),
    "double_blocks.14.txt_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.14.txt_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.14.txt_attn.proj.bias" : torch.Size([3072]),
    "double_blocks.14.txt_attn.proj.weight" : torch.Size([3072, 3072]),
    "double_blocks.14.txt_attn.qkv.bias" : torch.Size([9216]),
    "double_blocks.14.txt_attn.qkv.weight" : torch.Size([9216, 3072]),
    "double_blocks.14.txt_mlp.0.bias" : torch.Size([12288]),
    "double_blocks.14.txt_mlp.0.weight" : torch.Size([12288, 3072]),
    "double_blocks.14.txt_mlp.2.bias" : torch.Size([3072]),
    "double_blocks.14.txt_mlp.2.weight" : torch.Size([3072, 12288]),
    "double_blocks.14.txt_mod.lin.bias" : torch.Size([18432]),
    "double_blocks.14.txt_mod.lin.weight" : torch.Size([18432, 3072]),
    "double_blocks.15.img_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.15.img_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.15.img_attn.proj.bias" : torch.Size([3072]),
    "double_blocks.15.img_attn.proj.weight" : torch.Size([3072, 3072]),
    "double_blocks.15.img_attn.qkv.bias" : torch.Size([9216]),
    "double_blocks.15.img_attn.qkv.weight" : torch.Size([9216, 3072]),
    "double_blocks.15.img_mlp.0.bias" : torch.Size([12288]),
    "double_blocks.15.img_mlp.0.weight" : torch.Size([12288, 3072]),
    "double_blocks.15.img_mlp.2.bias" : torch.Size([3072]),
    "double_blocks.15.img_mlp.2.weight" : torch.Size([3072, 12288]),
    "double_blocks.15.img_mod.lin.bias" : torch.Size([18432]),
    "double_blocks.15.img_mod.lin.weight" : torch.Size([18432, 3072]),
    "double_blocks.15.txt_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.15.txt_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.15.txt_attn.proj.bias" : torch.Size([3072]),
    "double_blocks.15.txt_attn.proj.weight" : torch.Size([3072, 3072]),
    "double_blocks.15.txt_attn.qkv.bias" : torch.Size([9216]),
    "double_blocks.15.txt_attn.qkv.weight" : torch.Size([9216, 3072]),
    "double_blocks.15.txt_mlp.0.bias" : torch.Size([12288]),
    "double_blocks.15.txt_mlp.0.weight" : torch.Size([12288, 3072]),
    "double_blocks.15.txt_mlp.2.bias" : torch.Size([3072]),
    "double_blocks.15.txt_mlp.2.weight" : torch.Size([3072, 12288]),
    "double_blocks.15.txt_mod.lin.bias" : torch.Size([18432]),
    "double_blocks.15.txt_mod.lin.weight" : torch.Size([18432, 3072]),
    "double_blocks.16.img_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.16.img_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.16.img_attn.proj.bias" : torch.Size([3072]),
    "double_blocks.16.img_attn.proj.weight" : torch.Size([3072, 3072]),
    "double_blocks.16.img_attn.qkv.bias" : torch.Size([9216]),
    "double_blocks.16.img_attn.qkv.weight" : torch.Size([9216, 3072]),
    "double_blocks.16.img_mlp.0.bias" : torch.Size([12288]),
    "double_blocks.16.img_mlp.0.weight" : torch.Size([12288, 3072]),
    "double_blocks.16.img_mlp.2.bias" : torch.Size([3072]),
    "double_blocks.16.img_mlp.2.weight" : torch.Size([3072, 12288]),
    "double_blocks.16.img_mod.lin.bias" : torch.Size([18432]),
    "double_blocks.16.img_mod.lin.weight" : torch.Size([18432, 3072]),
    "double_blocks.16.txt_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.16.txt_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.16.txt_attn.proj.bias" : torch.Size([3072]),
    "double_blocks.16.txt_attn.proj.weight" : torch.Size([3072, 3072]),
    "double_blocks.16.txt_attn.qkv.bias" : torch.Size([9216]),
    "double_blocks.16.txt_attn.qkv.weight" : torch.Size([9216, 3072]),
    "double_blocks.16.txt_mlp.0.bias" : torch.Size([12288]),
    "double_blocks.16.txt_mlp.0.weight" : torch.Size([12288, 3072]),
    "double_blocks.16.txt_mlp.2.bias" : torch.Size([3072]),
    "double_blocks.16.txt_mlp.2.weight" : torch.Size([3072, 12288]),
    "double_blocks.16.txt_mod.lin.bias" : torch.Size([18432]),
    "double_blocks.16.txt_mod.lin.weight" : torch.Size([18432, 3072]),
    "double_blocks.17.img_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.17.img_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.17.img_attn.proj.bias" : torch.Size([3072]),
    "double_blocks.17.img_attn.proj.weight" : torch.Size([3072, 3072]),
    "double_blocks.17.img_attn.qkv.bias" : torch.Size([9216]),
    "double_blocks.17.img_attn.qkv.weight" : torch.Size([9216, 3072]),
    "double_blocks.17.img_mlp.0.bias" : torch.Size([12288]),
    "double_blocks.17.img_mlp.0.weight" : torch.Size([12288, 3072]),
    "double_blocks.17.img_mlp.2.bias" : torch.Size([3072]),
    "double_blocks.17.img_mlp.2.weight" : torch.Size([3072, 12288]),
    "double_blocks.17.img_mod.lin.bias" : torch.Size([18432]),
    "double_blocks.17.img_mod.lin.weight" : torch.Size([18432, 3072]),
    "double_blocks.17.txt_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.17.txt_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.17.txt_attn.proj.bias" : torch.Size([3072]),
    "double_blocks.17.txt_attn.proj.weight" : torch.Size([3072, 3072]),
    "double_blocks.17.txt_attn.qkv.bias" : torch.Size([9216]),
    "double_blocks.17.txt_attn.qkv.weight" : torch.Size([9216, 3072]),
    "double_blocks.17.txt_mlp.0.bias" : torch.Size([12288]),
    "double_blocks.17.txt_mlp.0.weight" : torch.Size([12288, 3072]),
    "double_blocks.17.txt_mlp.2.bias" : torch.Size([3072]),
    "double_blocks.17.txt_mlp.2.weight" : torch.Size([3072, 12288]),
    "double_blocks.17.txt_mod.lin.bias" : torch.Size([18432]),
    "double_blocks.17.txt_mod.lin.weight" : torch.Size([18432, 3072]),
    "double_blocks.18.img_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.18.img_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.18.img_attn.proj.bias" : torch.Size([3072]),
    "double_blocks.18.img_attn.proj.weight" : torch.Size([3072, 3072]),
    "double_blocks.18.img_attn.qkv.bias" : torch.Size([9216]),
    "double_blocks.18.img_attn.qkv.weight" : torch.Size([9216, 3072]),
    "double_blocks.18.img_mlp.0.bias" : torch.Size([12288]),
    "double_blocks.18.img_mlp.0.weight" : torch.Size([12288, 3072]),
    "double_blocks.18.img_mlp.2.bias" : torch.Size([3072]),
    "double_blocks.18.img_mlp.2.weight" : torch.Size([3072, 12288]),
    "double_blocks.18.img_mod.lin.bias" : torch.Size([18432]),
    "double_blocks.18.img_mod.lin.weight" : torch.Size([18432, 3072]),
    "double_blocks.18.txt_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.18.txt_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.18.txt_attn.proj.bias" : torch.Size([3072]),
    "double_blocks.18.txt_attn.proj.weight" : torch.Size([3072, 3072]),
    "double_blocks.18.txt_attn.qkv.bias" : torch.Size([9216]),
    "double_blocks.18.txt_attn.qkv.weight" : torch.Size([9216, 3072]),
    "double_blocks.18.txt_mlp.0.bias" : torch.Size([12288]),
    "double_blocks.18.txt_mlp.0.weight" : torch.Size([12288, 3072]),
    "double_blocks.18.txt_mlp.2.bias" : torch.Size([3072]),
    "double_blocks.18.txt_mlp.2.weight" : torch.Size([3072, 12288]),
    "double_blocks.18.txt_mod.lin.bias" : torch.Size([18432]),
    "double_blocks.18.txt_mod.lin.weight" : torch.Size([18432, 3072]),
    "double_blocks.2.img_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.2.img_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.2.img_attn.proj.bias" : torch.Size([3072]),
    "double_blocks.2.img_attn.proj.weight" : torch.Size([3072, 3072]),
    "double_blocks.2.img_attn.qkv.bias" : torch.Size([9216]),
    "double_blocks.2.img_attn.qkv.weight" : torch.Size([9216, 3072]),
    "double_blocks.2.img_mlp.0.bias" : torch.Size([12288]),
    "double_blocks.2.img_mlp.0.weight" : torch.Size([12288, 3072]),
    "double_blocks.2.img_mlp.2.bias" : torch.Size([3072]),
    "double_blocks.2.img_mlp.2.weight" : torch.Size([3072, 12288]),
    "double_blocks.2.img_mod.lin.bias" : torch.Size([18432]),
    "double_blocks.2.img_mod.lin.weight" : torch.Size([18432, 3072]),
    "double_blocks.2.txt_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.2.txt_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.2.txt_attn.proj.bias" : torch.Size([3072]),
    "double_blocks.2.txt_attn.proj.weight" : torch.Size([3072, 3072]),
    "double_blocks.2.txt_attn.qkv.bias" : torch.Size([9216]),
    "double_blocks.2.txt_attn.qkv.weight" : torch.Size([9216, 3072]),
    "double_blocks.2.txt_mlp.0.bias" : torch.Size([12288]),
    "double_blocks.2.txt_mlp.0.weight" : torch.Size([12288, 3072]),
    "double_blocks.2.txt_mlp.2.bias" : torch.Size([3072]),
    "double_blocks.2.txt_mlp.2.weight" : torch.Size([3072, 12288]),
    "double_blocks.2.txt_mod.lin.bias" : torch.Size([18432]),
    "double_blocks.2.txt_mod.lin.weight" : torch.Size([18432, 3072]),
    "double_blocks.3.img_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.3.img_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.3.img_attn.proj.bias" : torch.Size([3072]),
    "double_blocks.3.img_attn.proj.weight" : torch.Size([3072, 3072]),
    "double_blocks.3.img_attn.qkv.bias" : torch.Size([9216]),
    "double_blocks.3.img_attn.qkv.weight" : torch.Size([9216, 3072]),
    "double_blocks.3.img_mlp.0.bias" : torch.Size([12288]),
    "double_blocks.3.img_mlp.0.weight" : torch.Size([12288, 3072]),
    "double_blocks.3.img_mlp.2.bias" : torch.Size([3072]),
    "double_blocks.3.img_mlp.2.weight" : torch.Size([3072, 12288]),
    "double_blocks.3.img_mod.lin.bias" : torch.Size([18432]),
    "double_blocks.3.img_mod.lin.weight" : torch.Size([18432, 3072]),
    "double_blocks.3.txt_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.3.txt_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.3.txt_attn.proj.bias" : torch.Size([3072]),
    "double_blocks.3.txt_attn.proj.weight" : torch.Size([3072, 3072]),
    "double_blocks.3.txt_attn.qkv.bias" : torch.Size([9216]),
    "double_blocks.3.txt_attn.qkv.weight" : torch.Size([9216, 3072]),
    "double_blocks.3.txt_mlp.0.bias" : torch.Size([12288]),
    "double_blocks.3.txt_mlp.0.weight" : torch.Size([12288, 3072]),
    "double_blocks.3.txt_mlp.2.bias" : torch.Size([3072]),
    "double_blocks.3.txt_mlp.2.weight" : torch.Size([3072, 12288]),
    "double_blocks.3.txt_mod.lin.bias" : torch.Size([18432]),
    "double_blocks.3.txt_mod.lin.weight" : torch.Size([18432, 3072]),
    "double_blocks.4.img_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.4.img_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.4.img_attn.proj.bias" : torch.Size([3072]),
    "double_blocks.4.img_attn.proj.weight" : torch.Size([3072, 3072]),
    "double_blocks.4.img_attn.qkv.bias" : torch.Size([9216]),
    "double_blocks.4.img_attn.qkv.weight" : torch.Size([9216, 3072]),
    "double_blocks.4.img_mlp.0.bias" : torch.Size([12288]),
    "double_blocks.4.img_mlp.0.weight" : torch.Size([12288, 3072]),
    "double_blocks.4.img_mlp.2.bias" : torch.Size([3072]),
    "double_blocks.4.img_mlp.2.weight" : torch.Size([3072, 12288]),
    "double_blocks.4.img_mod.lin.bias" : torch.Size([18432]),
    "double_blocks.4.img_mod.lin.weight" : torch.Size([18432, 3072]),
    "double_blocks.4.txt_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.4.txt_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.4.txt_attn.proj.bias" : torch.Size([3072]),
    "double_blocks.4.txt_attn.proj.weight" : torch.Size([3072, 3072]),
    "double_blocks.4.txt_attn.qkv.bias" : torch.Size([9216]),
    "double_blocks.4.txt_attn.qkv.weight" : torch.Size([9216, 3072]),
    "double_blocks.4.txt_mlp.0.bias" : torch.Size([12288]),
    "double_blocks.4.txt_mlp.0.weight" : torch.Size([12288, 3072]),
    "double_blocks.4.txt_mlp.2.bias" : torch.Size([3072]),
    "double_blocks.4.txt_mlp.2.weight" : torch.Size([3072, 12288]),
    "double_blocks.4.txt_mod.lin.bias" : torch.Size([18432]),
    "double_blocks.4.txt_mod.lin.weight" : torch.Size([18432, 3072]),
    "double_blocks.5.img_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.5.img_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.5.img_attn.proj.bias" : torch.Size([3072]),
    "double_blocks.5.img_attn.proj.weight" : torch.Size([3072, 3072]),
    "double_blocks.5.img_attn.qkv.bias" : torch.Size([9216]),
    "double_blocks.5.img_attn.qkv.weight" : torch.Size([9216, 3072]),
    "double_blocks.5.img_mlp.0.bias" : torch.Size([12288]),
    "double_blocks.5.img_mlp.0.weight" : torch.Size([12288, 3072]),
    "double_blocks.5.img_mlp.2.bias" : torch.Size([3072]),
    "double_blocks.5.img_mlp.2.weight" : torch.Size([3072, 12288]),
    "double_blocks.5.img_mod.lin.bias" : torch.Size([18432]),
    "double_blocks.5.img_mod.lin.weight" : torch.Size([18432, 3072]),
    "double_blocks.5.txt_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.5.txt_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.5.txt_attn.proj.bias" : torch.Size([3072]),
    "double_blocks.5.txt_attn.proj.weight" : torch.Size([3072, 3072]),
    "double_blocks.5.txt_attn.qkv.bias" : torch.Size([9216]),
    "double_blocks.5.txt_attn.qkv.weight" : torch.Size([9216, 3072]),
    "double_blocks.5.txt_mlp.0.bias" : torch.Size([12288]),
    "double_blocks.5.txt_mlp.0.weight" : torch.Size([12288, 3072]),
    "double_blocks.5.txt_mlp.2.bias" : torch.Size([3072]),
    "double_blocks.5.txt_mlp.2.weight" : torch.Size([3072, 12288]),
    "double_blocks.5.txt_mod.lin.bias" : torch.Size([18432]),
    "double_blocks.5.txt_mod.lin.weight" : torch.Size([18432, 3072]),
    "double_blocks.6.img_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.6.img_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.6.img_attn.proj.bias" : torch.Size([3072]),
    "double_blocks.6.img_attn.proj.weight" : torch.Size([3072, 3072]),
    "double_blocks.6.img_attn.qkv.bias" : torch.Size([9216]),
    "double_blocks.6.img_attn.qkv.weight" : torch.Size([9216, 3072]),
    "double_blocks.6.img_mlp.0.bias" : torch.Size([12288]),
    "double_blocks.6.img_mlp.0.weight" : torch.Size([12288, 3072]),
    "double_blocks.6.img_mlp.2.bias" : torch.Size([3072]),
    "double_blocks.6.img_mlp.2.weight" : torch.Size([3072, 12288]),
    "double_blocks.6.img_mod.lin.bias" : torch.Size([18432]),
    "double_blocks.6.img_mod.lin.weight" : torch.Size([18432, 3072]),
    "double_blocks.6.txt_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.6.txt_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.6.txt_attn.proj.bias" : torch.Size([3072]),
    "double_blocks.6.txt_attn.proj.weight" : torch.Size([3072, 3072]),
    "double_blocks.6.txt_attn.qkv.bias" : torch.Size([9216]),
    "double_blocks.6.txt_attn.qkv.weight" : torch.Size([9216, 3072]),
    "double_blocks.6.txt_mlp.0.bias" : torch.Size([12288]),
    "double_blocks.6.txt_mlp.0.weight" : torch.Size([12288, 3072]),
    "double_blocks.6.txt_mlp.2.bias" : torch.Size([3072]),
    "double_blocks.6.txt_mlp.2.weight" : torch.Size([3072, 12288]),
    "double_blocks.6.txt_mod.lin.bias" : torch.Size([18432]),
    "double_blocks.6.txt_mod.lin.weight" : torch.Size([18432, 3072]),
    "double_blocks.7.img_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.7.img_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.7.img_attn.proj.bias" : torch.Size([3072]),
    "double_blocks.7.img_attn.proj.weight" : torch.Size([3072, 3072]),
    "double_blocks.7.img_attn.qkv.bias" : torch.Size([9216]),
    "double_blocks.7.img_attn.qkv.weight" : torch.Size([9216, 3072]),
    "double_blocks.7.img_mlp.0.bias" : torch.Size([12288]),
    "double_blocks.7.img_mlp.0.weight" : torch.Size([12288, 3072]),
    "double_blocks.7.img_mlp.2.bias" : torch.Size([3072]),
    "double_blocks.7.img_mlp.2.weight" : torch.Size([3072, 12288]),
    "double_blocks.7.img_mod.lin.bias" : torch.Size([18432]),
    "double_blocks.7.img_mod.lin.weight" : torch.Size([18432, 3072]),
    "double_blocks.7.txt_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.7.txt_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.7.txt_attn.proj.bias" : torch.Size([3072]),
    "double_blocks.7.txt_attn.proj.weight" : torch.Size([3072, 3072]),
    "double_blocks.7.txt_attn.qkv.bias" : torch.Size([9216]),
    "double_blocks.7.txt_attn.qkv.weight" : torch.Size([9216, 3072]),
    "double_blocks.7.txt_mlp.0.bias" : torch.Size([12288]),
    "double_blocks.7.txt_mlp.0.weight" : torch.Size([12288, 3072]),
    "double_blocks.7.txt_mlp.2.bias" : torch.Size([3072]),
    "double_blocks.7.txt_mlp.2.weight" : torch.Size([3072, 12288]),
    "double_blocks.7.txt_mod.lin.bias" : torch.Size([18432]),
    "double_blocks.7.txt_mod.lin.weight" : torch.Size([18432, 3072]),
    "double_blocks.8.img_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.8.img_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.8.img_attn.proj.bias" : torch.Size([3072]),
    "double_blocks.8.img_attn.proj.weight" : torch.Size([3072, 3072]),
    "double_blocks.8.img_attn.qkv.bias" : torch.Size([9216]),
    "double_blocks.8.img_attn.qkv.weight" : torch.Size([9216, 3072]),
    "double_blocks.8.img_mlp.0.bias" : torch.Size([12288]),
    "double_blocks.8.img_mlp.0.weight" : torch.Size([12288, 3072]),
    "double_blocks.8.img_mlp.2.bias" : torch.Size([3072]),
    "double_blocks.8.img_mlp.2.weight" : torch.Size([3072, 12288]),
    "double_blocks.8.img_mod.lin.bias" : torch.Size([18432]),
    "double_blocks.8.img_mod.lin.weight" : torch.Size([18432, 3072]),
    "double_blocks.8.txt_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.8.txt_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.8.txt_attn.proj.bias" : torch.Size([3072]),
    "double_blocks.8.txt_attn.proj.weight" : torch.Size([3072, 3072]),
    "double_blocks.8.txt_attn.qkv.bias" : torch.Size([9216]),
    "double_blocks.8.txt_attn.qkv.weight" : torch.Size([9216, 3072]),
    "double_blocks.8.txt_mlp.0.bias" : torch.Size([12288]),
    "double_blocks.8.txt_mlp.0.weight" : torch.Size([12288, 3072]),
    "double_blocks.8.txt_mlp.2.bias" : torch.Size([3072]),
    "double_blocks.8.txt_mlp.2.weight" : torch.Size([3072, 12288]),
    "double_blocks.8.txt_mod.lin.bias" : torch.Size([18432]),
    "double_blocks.8.txt_mod.lin.weight" : torch.Size([18432, 3072]),
    "double_blocks.9.img_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.9.img_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.9.img_attn.proj.bias" : torch.Size([3072]),
    "double_blocks.9.img_attn.proj.weight" : torch.Size([3072, 3072]),
    "double_blocks.9.img_attn.qkv.bias" : torch.Size([9216]),
    "double_blocks.9.img_attn.qkv.weight" : torch.Size([9216, 3072]),
    "double_blocks.9.img_mlp.0.bias" : torch.Size([12288]),
    "double_blocks.9.img_mlp.0.weight" : torch.Size([12288, 3072]),
    "double_blocks.9.img_mlp.2.bias" : torch.Size([3072]),
    "double_blocks.9.img_mlp.2.weight" : torch.Size([3072, 12288]),
    "double_blocks.9.img_mod.lin.bias" : torch.Size([18432]),
    "double_blocks.9.img_mod.lin.weight" : torch.Size([18432, 3072]),
    "double_blocks.9.txt_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.9.txt_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.9.txt_attn.proj.bias" : torch.Size([3072]),
    "double_blocks.9.txt_attn.proj.weight" : torch.Size([3072, 3072]),
    "double_blocks.9.txt_attn.qkv.bias" : torch.Size([9216]),
    "double_blocks.9.txt_attn.qkv.weight" : torch.Size([9216, 3072]),
    "double_blocks.9.txt_mlp.0.bias" : torch.Size([12288]),
    "double_blocks.9.txt_mlp.0.weight" : torch.Size([12288, 3072]),
    "double_blocks.9.txt_mlp.2.bias" : torch.Size([3072]),
    "double_blocks.9.txt_mlp.2.weight" : torch.Size([3072, 12288]),
    "double_blocks.9.txt_mod.lin.bias" : torch.Size([18432]),
    "double_blocks.9.txt_mod.lin.weight" : torch.Size([18432, 3072]),
    "final_layer.adaLN_modulation.1.bias" : torch.Size([6144]),
    "final_layer.adaLN_modulation.1.weight" : torch.Size([6144, 3072]),
    "final_layer.linear.bias" : torch.Size([64]),
    "final_layer.linear.weight" : torch.Size([64, 3072]),
    "img_in.bias" : torch.Size([3072]),
    "img_in.weight" : torch.Size([3072, 64]),
    "single_blocks.0.linear1.bias" : torch.Size([21504]),
    "single_blocks.0.linear1.weight" : torch.Size([21504, 3072]),
    "single_blocks.0.linear2.bias" : torch.Size([3072]),
    "single_blocks.0.linear2.weight" : torch.Size([3072, 15360]),
    "single_blocks.0.modulation.lin.bias" : torch.Size([9216]),
    "single_blocks.0.modulation.lin.weight" : torch.Size([9216, 3072]),
    "single_blocks.0.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.0.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.1.linear1.bias" : torch.Size([21504]),
    "single_blocks.1.linear1.weight" : torch.Size([21504, 3072]),
    "single_blocks.1.linear2.bias" : torch.Size([3072]),
    "single_blocks.1.linear2.weight" : torch.Size([3072, 15360]),
    "single_blocks.1.modulation.lin.bias" : torch.Size([9216]),
    "single_blocks.1.modulation.lin.weight" : torch.Size([9216, 3072]),
    "single_blocks.1.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.1.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.10.linear1.bias" : torch.Size([21504]),
    "single_blocks.10.linear1.weight" : torch.Size([21504, 3072]),
    "single_blocks.10.linear2.bias" : torch.Size([3072]),
    "single_blocks.10.linear2.weight" : torch.Size([3072, 15360]),
    "single_blocks.10.modulation.lin.bias" : torch.Size([9216]),
    "single_blocks.10.modulation.lin.weight" : torch.Size([9216, 3072]),
    "single_blocks.10.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.10.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.11.linear1.bias" : torch.Size([21504]),
    "single_blocks.11.linear1.weight" : torch.Size([21504, 3072]),
    "single_blocks.11.linear2.bias" : torch.Size([3072]),
    "single_blocks.11.linear2.weight" : torch.Size([3072, 15360]),
    "single_blocks.11.modulation.lin.bias" : torch.Size([9216]),
    "single_blocks.11.modulation.lin.weight" : torch.Size([9216, 3072]),
    "single_blocks.11.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.11.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.12.linear1.bias" : torch.Size([21504]),
    "single_blocks.12.linear1.weight" : torch.Size([21504, 3072]),
    "single_blocks.12.linear2.bias" : torch.Size([3072]),
    "single_blocks.12.linear2.weight" : torch.Size([3072, 15360]),
    "single_blocks.12.modulation.lin.bias" : torch.Size([9216]),
    "single_blocks.12.modulation.lin.weight" : torch.Size([9216, 3072]),
    "single_blocks.12.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.12.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.13.linear1.bias" : torch.Size([21504]),
    "single_blocks.13.linear1.weight" : torch.Size([21504, 3072]),
    "single_blocks.13.linear2.bias" : torch.Size([3072]),
    "single_blocks.13.linear2.weight" : torch.Size([3072, 15360]),
    "single_blocks.13.modulation.lin.bias" : torch.Size([9216]),
    "single_blocks.13.modulation.lin.weight" : torch.Size([9216, 3072]),
    "single_blocks.13.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.13.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.14.linear1.bias" : torch.Size([21504]),
    "single_blocks.14.linear1.weight" : torch.Size([21504, 3072]),
    "single_blocks.14.linear2.bias" : torch.Size([3072]),
    "single_blocks.14.linear2.weight" : torch.Size([3072, 15360]),
    "single_blocks.14.modulation.lin.bias" : torch.Size([9216]),
    "single_blocks.14.modulation.lin.weight" : torch.Size([9216, 3072]),
    "single_blocks.14.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.14.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.15.linear1.bias" : torch.Size([21504]),
    "single_blocks.15.linear1.weight" : torch.Size([21504, 3072]),
    "single_blocks.15.linear2.bias" : torch.Size([3072]),
    "single_blocks.15.linear2.weight" : torch.Size([3072, 15360]),
    "single_blocks.15.modulation.lin.bias" : torch.Size([9216]),
    "single_blocks.15.modulation.lin.weight" : torch.Size([9216, 3072]),
    "single_blocks.15.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.15.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.16.linear1.bias" : torch.Size([21504]),
    "single_blocks.16.linear1.weight" : torch.Size([21504, 3072]),
    "single_blocks.16.linear2.bias" : torch.Size([3072]),
    "single_blocks.16.linear2.weight" : torch.Size([3072, 15360]),
    "single_blocks.16.modulation.lin.bias" : torch.Size([9216]),
    "single_blocks.16.modulation.lin.weight" : torch.Size([9216, 3072]),
    "single_blocks.16.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.16.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.17.linear1.bias" : torch.Size([21504]),
    "single_blocks.17.linear1.weight" : torch.Size([21504, 3072]),
    "single_blocks.17.linear2.bias" : torch.Size([3072]),
    "single_blocks.17.linear2.weight" : torch.Size([3072, 15360]),
    "single_blocks.17.modulation.lin.bias" : torch.Size([9216]),
    "single_blocks.17.modulation.lin.weight" : torch.Size([9216, 3072]),
    "single_blocks.17.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.17.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.18.linear1.bias" : torch.Size([21504]),
    "single_blocks.18.linear1.weight" : torch.Size([21504, 3072]),
    "single_blocks.18.linear2.bias" : torch.Size([3072]),
    "single_blocks.18.linear2.weight" : torch.Size([3072, 15360]),
    "single_blocks.18.modulation.lin.bias" : torch.Size([9216]),
    "single_blocks.18.modulation.lin.weight" : torch.Size([9216, 3072]),
    "single_blocks.18.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.18.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.19.linear1.bias" : torch.Size([21504]),
    "single_blocks.19.linear1.weight" : torch.Size([21504, 3072]),
    "single_blocks.19.linear2.bias" : torch.Size([3072]),
    "single_blocks.19.linear2.weight" : torch.Size([3072, 15360]),
    "single_blocks.19.modulation.lin.bias" : torch.Size([9216]),
    "single_blocks.19.modulation.lin.weight" : torch.Size([9216, 3072]),
    "single_blocks.19.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.19.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.2.linear1.bias" : torch.Size([21504]),
    "single_blocks.2.linear1.weight" : torch.Size([21504, 3072]),
    "single_blocks.2.linear2.bias" : torch.Size([3072]),
    "single_blocks.2.linear2.weight" : torch.Size([3072, 15360]),
    "single_blocks.2.modulation.lin.bias" : torch.Size([9216]),
    "single_blocks.2.modulation.lin.weight" : torch.Size([9216, 3072]),
    "single_blocks.2.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.2.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.20.linear1.bias" : torch.Size([21504]),
    "single_blocks.20.linear1.weight" : torch.Size([21504, 3072]),
    "single_blocks.20.linear2.bias" : torch.Size([3072]),
    "single_blocks.20.linear2.weight" : torch.Size([3072, 15360]),
    "single_blocks.20.modulation.lin.bias" : torch.Size([9216]),
    "single_blocks.20.modulation.lin.weight" : torch.Size([9216, 3072]),
    "single_blocks.20.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.20.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.21.linear1.bias" : torch.Size([21504]),
    "single_blocks.21.linear1.weight" : torch.Size([21504, 3072]),
    "single_blocks.21.linear2.bias" : torch.Size([3072]),
    "single_blocks.21.linear2.weight" : torch.Size([3072, 15360]),
    "single_blocks.21.modulation.lin.bias" : torch.Size([9216]),
    "single_blocks.21.modulation.lin.weight" : torch.Size([9216, 3072]),
    "single_blocks.21.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.21.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.22.linear1.bias" : torch.Size([21504]),
    "single_blocks.22.linear1.weight" : torch.Size([21504, 3072]),
    "single_blocks.22.linear2.bias" : torch.Size([3072]),
    "single_blocks.22.linear2.weight" : torch.Size([3072, 15360]),
    "single_blocks.22.modulation.lin.bias" : torch.Size([9216]),
    "single_blocks.22.modulation.lin.weight" : torch.Size([9216, 3072]),
    "single_blocks.22.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.22.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.23.linear1.bias" : torch.Size([21504]),
    "single_blocks.23.linear1.weight" : torch.Size([21504, 3072]),
    "single_blocks.23.linear2.bias" : torch.Size([3072]),
    "single_blocks.23.linear2.weight" : torch.Size([3072, 15360]),
    "single_blocks.23.modulation.lin.bias" : torch.Size([9216]),
    "single_blocks.23.modulation.lin.weight" : torch.Size([9216, 3072]),
    "single_blocks.23.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.23.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.24.linear1.bias" : torch.Size([21504]),
    "single_blocks.24.linear1.weight" : torch.Size([21504, 3072]),
    "single_blocks.24.linear2.bias" : torch.Size([3072]),
    "single_blocks.24.linear2.weight" : torch.Size([3072, 15360]),
    "single_blocks.24.modulation.lin.bias" : torch.Size([9216]),
    "single_blocks.24.modulation.lin.weight" : torch.Size([9216, 3072]),
    "single_blocks.24.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.24.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.25.linear1.bias" : torch.Size([21504]),
    "single_blocks.25.linear1.weight" : torch.Size([21504, 3072]),
    "single_blocks.25.linear2.bias" : torch.Size([3072]),
    "single_blocks.25.linear2.weight" : torch.Size([3072, 15360]),
    "single_blocks.25.modulation.lin.bias" : torch.Size([9216]),
    "single_blocks.25.modulation.lin.weight" : torch.Size([9216, 3072]),
    "single_blocks.25.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.25.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.26.linear1.bias" : torch.Size([21504]),
    "single_blocks.26.linear1.weight" : torch.Size([21504, 3072]),
    "single_blocks.26.linear2.bias" : torch.Size([3072]),
    "single_blocks.26.linear2.weight" : torch.Size([3072, 15360]),
    "single_blocks.26.modulation.lin.bias" : torch.Size([9216]),
    "single_blocks.26.modulation.lin.weight" : torch.Size([9216, 3072]),
    "single_blocks.26.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.26.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.27.linear1.bias" : torch.Size([21504]),
    "single_blocks.27.linear1.weight" : torch.Size([21504, 3072]),
    "single_blocks.27.linear2.bias" : torch.Size([3072]),
    "single_blocks.27.linear2.weight" : torch.Size([3072, 15360]),
    "single_blocks.27.modulation.lin.bias" : torch.Size([9216]),
    "single_blocks.27.modulation.lin.weight" : torch.Size([9216, 3072]),
    "single_blocks.27.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.27.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.28.linear1.bias" : torch.Size([21504]),
    "single_blocks.28.linear1.weight" : torch.Size([21504, 3072]),
    "single_blocks.28.linear2.bias" : torch.Size([3072]),
    "single_blocks.28.linear2.weight" : torch.Size([3072, 15360]),
    "single_blocks.28.modulation.lin.bias" : torch.Size([9216]),
    "single_blocks.28.modulation.lin.weight" : torch.Size([9216, 3072]),
    "single_blocks.28.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.28.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.29.linear1.bias" : torch.Size([21504]),
    "single_blocks.29.linear1.weight" : torch.Size([21504, 3072]),
    "single_blocks.29.linear2.bias" : torch.Size([3072]),
    "single_blocks.29.linear2.weight" : torch.Size([3072, 15360]),
    "single_blocks.29.modulation.lin.bias" : torch.Size([9216]),
    "single_blocks.29.modulation.lin.weight" : torch.Size([9216, 3072]),
    "single_blocks.29.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.29.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.3.linear1.bias" : torch.Size([21504]),
    "single_blocks.3.linear1.weight" : torch.Size([21504, 3072]),
    "single_blocks.3.linear2.bias" : torch.Size([3072]),
    "single_blocks.3.linear2.weight" : torch.Size([3072, 15360]),
    "single_blocks.3.modulation.lin.bias" : torch.Size([9216]),
    "single_blocks.3.modulation.lin.weight" : torch.Size([9216, 3072]),
    "single_blocks.3.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.3.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.30.linear1.bias" : torch.Size([21504]),
    "single_blocks.30.linear1.weight" : torch.Size([21504, 3072]),
    "single_blocks.30.linear2.bias" : torch.Size([3072]),
    "single_blocks.30.linear2.weight" : torch.Size([3072, 15360]),
    "single_blocks.30.modulation.lin.bias" : torch.Size([9216]),
    "single_blocks.30.modulation.lin.weight" : torch.Size([9216, 3072]),
    "single_blocks.30.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.30.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.31.linear1.bias" : torch.Size([21504]),
    "single_blocks.31.linear1.weight" : torch.Size([21504, 3072]),
    "single_blocks.31.linear2.bias" : torch.Size([3072]),
    "single_blocks.31.linear2.weight" : torch.Size([3072, 15360]),
    "single_blocks.31.modulation.lin.bias" : torch.Size([9216]),
    "single_blocks.31.modulation.lin.weight" : torch.Size([9216, 3072]),
    "single_blocks.31.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.31.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.32.linear1.bias" : torch.Size([21504]),
    "single_blocks.32.linear1.weight" : torch.Size([21504, 3072]),
    "single_blocks.32.linear2.bias" : torch.Size([3072]),
    "single_blocks.32.linear2.weight" : torch.Size([3072, 15360]),
    "single_blocks.32.modulation.lin.bias" : torch.Size([9216]),
    "single_blocks.32.modulation.lin.weight" : torch.Size([9216, 3072]),
    "single_blocks.32.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.32.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.33.linear1.bias" : torch.Size([21504]),
    "single_blocks.33.linear1.weight" : torch.Size([21504, 3072]),
    "single_blocks.33.linear2.bias" : torch.Size([3072]),
    "single_blocks.33.linear2.weight" : torch.Size([3072, 15360]),
    "single_blocks.33.modulation.lin.bias" : torch.Size([9216]),
    "single_blocks.33.modulation.lin.weight" : torch.Size([9216, 3072]),
    "single_blocks.33.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.33.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.34.linear1.bias" : torch.Size([21504]),
    "single_blocks.34.linear1.weight" : torch.Size([21504, 3072]),
    "single_blocks.34.linear2.bias" : torch.Size([3072]),
    "single_blocks.34.linear2.weight" : torch.Size([3072, 15360]),
    "single_blocks.34.modulation.lin.bias" : torch.Size([9216]),
    "single_blocks.34.modulation.lin.weight" : torch.Size([9216, 3072]),
    "single_blocks.34.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.34.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.35.linear1.bias" : torch.Size([21504]),
    "single_blocks.35.linear1.weight" : torch.Size([21504, 3072]),
    "single_blocks.35.linear2.bias" : torch.Size([3072]),
    "single_blocks.35.linear2.weight" : torch.Size([3072, 15360]),
    "single_blocks.35.modulation.lin.bias" : torch.Size([9216]),
    "single_blocks.35.modulation.lin.weight" : torch.Size([9216, 3072]),
    "single_blocks.35.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.35.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.36.linear1.bias" : torch.Size([21504]),
    "single_blocks.36.linear1.weight" : torch.Size([21504, 3072]),
    "single_blocks.36.linear2.bias" : torch.Size([3072]),
    "single_blocks.36.linear2.weight" : torch.Size([3072, 15360]),
    "single_blocks.36.modulation.lin.bias" : torch.Size([9216]),
    "single_blocks.36.modulation.lin.weight" : torch.Size([9216, 3072]),
    "single_blocks.36.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.36.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.37.linear1.bias" : torch.Size([21504]),
    "single_blocks.37.linear1.weight" : torch.Size([21504, 3072]),
    "single_blocks.37.linear2.bias" : torch.Size([3072]),
    "single_blocks.37.linear2.weight" : torch.Size([3072, 15360]),
    "single_blocks.37.modulation.lin.bias" : torch.Size([9216]),
    "single_blocks.37.modulation.lin.weight" : torch.Size([9216, 3072]),
    "single_blocks.37.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.37.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.4.linear1.bias" : torch.Size([21504]),
    "single_blocks.4.linear1.weight" : torch.Size([21504, 3072]),
    "single_blocks.4.linear2.bias" : torch.Size([3072]),
    "single_blocks.4.linear2.weight" : torch.Size([3072, 15360]),
    "single_blocks.4.modulation.lin.bias" : torch.Size([9216]),
    "single_blocks.4.modulation.lin.weight" : torch.Size([9216, 3072]),
    "single_blocks.4.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.4.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.5.linear1.bias" : torch.Size([21504]),
    "single_blocks.5.linear1.weight" : torch.Size([21504, 3072]),
    "single_blocks.5.linear2.bias" : torch.Size([3072]),
    "single_blocks.5.linear2.weight" : torch.Size([3072, 15360]),
    "single_blocks.5.modulation.lin.bias" : torch.Size([9216]),
    "single_blocks.5.modulation.lin.weight" : torch.Size([9216, 3072]),
    "single_blocks.5.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.5.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.6.linear1.bias" : torch.Size([21504]),
    "single_blocks.6.linear1.weight" : torch.Size([21504, 3072]),
    "single_blocks.6.linear2.bias" : torch.Size([3072]),
    "single_blocks.6.linear2.weight" : torch.Size([3072, 15360]),
    "single_blocks.6.modulation.lin.bias" : torch.Size([9216]),
    "single_blocks.6.modulation.lin.weight" : torch.Size([9216, 3072]),
    "single_blocks.6.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.6.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.7.linear1.bias" : torch.Size([21504]),
    "single_blocks.7.linear1.weight" : torch.Size([21504, 3072]),
    "single_blocks.7.linear2.bias" : torch.Size([3072]),
    "single_blocks.7.linear2.weight" : torch.Size([3072, 15360]),
    "single_blocks.7.modulation.lin.bias" : torch.Size([9216]),
    "single_blocks.7.modulation.lin.weight" : torch.Size([9216, 3072]),
    "single_blocks.7.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.7.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.8.linear1.bias" : torch.Size([21504]),
    "single_blocks.8.linear1.weight" : torch.Size([21504, 3072]),
    "single_blocks.8.linear2.bias" : torch.Size([3072]),
    "single_blocks.8.linear2.weight" : torch.Size([3072, 15360]),
    "single_blocks.8.modulation.lin.bias" : torch.Size([9216]),
    "single_blocks.8.modulation.lin.weight" : torch.Size([9216, 3072]),
    "single_blocks.8.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.8.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.9.linear1.bias" : torch.Size([21504]),
    "single_blocks.9.linear1.weight" : torch.Size([21504, 3072]),
    "single_blocks.9.linear2.bias" : torch.Size([3072]),
    "single_blocks.9.linear2.weight" : torch.Size([3072, 15360]),
    "single_blocks.9.modulation.lin.bias" : torch.Size([9216]),
    "single_blocks.9.modulation.lin.weight" : torch.Size([9216, 3072]),
    "single_blocks.9.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.9.norm.query_norm.scale" : torch.Size([128]),
    "time_in.in_layer.bias" : torch.Size([3072]),
    "time_in.in_layer.weight" : torch.Size([3072, 256]),
    "time_in.out_layer.bias" : torch.Size([3072]),
    "time_in.out_layer.weight" : torch.Size([3072, 3072]),
    "txt_in.bias" : torch.Size([3072]),
    "txt_in.weight" : torch.Size([3072, 4096]),
    "vector_in.in_layer.bias" : torch.Size([3072]),
    "vector_in.in_layer.weight" : torch.Size([3072, 768]),
    "vector_in.out_layer.bias" : torch.Size([3072]),
    "vector_in.out_layer.weight" : torch.Size([3072, 3072]),
}

# Expected parameter shapes for the guidance-embedded FLUX checkpoint: every
# entry of ``flux_schnell_keys_dict`` followed by the four ``guidance_in`` MLP
# tensors (their shapes are identical to the ``time_in`` entries).
# NOTE(review): presumably this is the FLUX.1-dev layout — confirm against a
# real checkpoint before relying on that name.
flux_keys_dict = {
    **flux_schnell_keys_dict,
    "guidance_in.in_layer.bias" : torch.Size([3072]),
    "guidance_in.in_layer.weight" : torch.Size([3072, 256]),
    "guidance_in.out_layer.bias" : torch.Size([3072]),
    "guidance_in.out_layer.weight" : torch.Size([3072, 3072]),
}

ernie_keys_dict = {
    "x_embedder.proj.weight" : torch.Size([4096, 128, 1, 1]),
    "x_embedder.proj.bias" : torch.Size([4096]),
    "text_proj.weight" : torch.Size([4096, 3072]),
    "time_embedding.linear_1.weight" : torch.Size([4096, 4096]),
    "time_embedding.linear_1.bias" : torch.Size([4096]),
    "time_embedding.linear_2.weight" : torch.Size([4096, 4096]),
    "time_embedding.linear_2.bias" : torch.Size([4096]),
    "adaLN_modulation.1.weight" : torch.Size([24576, 4096]),
    "adaLN_modulation.1.bias" : torch.Size([24576]),
    "layers.0.adaLN_sa_ln.weight" : torch.Size([4096]),
    "layers.0.self_attention.to_q.weight" : torch.Size([4096, 4096]),
    "layers.0.self_attention.to_k.weight" : torch.Size([4096, 4096]),
    "layers.0.self_attention.to_v.weight" : torch.Size([4096, 4096]),
    "layers.0.self_attention.norm_q.weight" : torch.Size([128]),
    "layers.0.self_attention.norm_k.weight" : torch.Size([128]),
    "layers.0.self_attention.to_out.0.weight" : torch.Size([4096, 4096]),
    "layers.0.adaLN_mlp_ln.weight" : torch.Size([4096]),
    "layers.0.mlp.gate_proj.weight" : torch.Size([12288, 4096]),
    "layers.0.mlp.up_proj.weight" : torch.Size([12288, 4096]),
    "layers.0.mlp.linear_fc2.weight" : torch.Size([4096, 12288]),
    "layers.1.adaLN_sa_ln.weight" : torch.Size([4096]),
    "layers.1.self_attention.to_q.weight" : torch.Size([4096, 4096]),
    "layers.1.self_attention.to_k.weight" : torch.Size([4096, 4096]),
    "layers.1.self_attention.to_v.weight" : torch.Size([4096, 4096]),
    "layers.1.self_attention.norm_q.weight" : torch.Size([128]),
    "layers.1.self_attention.norm_k.weight" : torch.Size([128]),
    "layers.1.self_attention.to_out.0.weight" : torch.Size([4096, 4096]),
    "layers.1.adaLN_mlp_ln.weight" : torch.Size([4096]),
    "layers.1.mlp.gate_proj.weight" : torch.Size([12288, 4096]),
    "layers.1.mlp.up_proj.weight" : torch.Size([12288, 4096]),
    "layers.1.mlp.linear_fc2.weight" : torch.Size([4096, 12288]),
    "layers.2.adaLN_sa_ln.weight" : torch.Size([4096]),
    "layers.2.self_attention.to_q.weight" : torch.Size([4096, 4096]),
    "layers.2.self_attention.to_k.weight" : torch.Size([4096, 4096]),
    "layers.2.self_attention.to_v.weight" : torch.Size([4096, 4096]),
    "layers.2.self_attention.norm_q.weight" : torch.Size([128]),
    "layers.2.self_attention.norm_k.weight" : torch.Size([128]),
    "layers.2.self_attention.to_out.0.weight" : torch.Size([4096, 4096]),
    "layers.2.adaLN_mlp_ln.weight" : torch.Size([4096]),
    "layers.2.mlp.gate_proj.weight" : torch.Size([12288, 4096]),
    "layers.2.mlp.up_proj.weight" : torch.Size([12288, 4096]),
    "layers.2.mlp.linear_fc2.weight" : torch.Size([4096, 12288]),
    "layers.3.adaLN_sa_ln.weight" : torch.Size([4096]),
    "layers.3.self_attention.to_q.weight" : torch.Size([4096, 4096]),
    "layers.3.self_attention.to_k.weight" : torch.Size([4096, 4096]),
    "layers.3.self_attention.to_v.weight" : torch.Size([4096, 4096]),
    "layers.3.self_attention.norm_q.weight" : torch.Size([128]),
    "layers.3.self_attention.norm_k.weight" : torch.Size([128]),
    "layers.3.self_attention.to_out.0.weight" : torch.Size([4096, 4096]),
    "layers.3.adaLN_mlp_ln.weight" : torch.Size([4096]),
    "layers.3.mlp.gate_proj.weight" : torch.Size([12288, 4096]),
    "layers.3.mlp.up_proj.weight" : torch.Size([12288, 4096]),
    "layers.3.mlp.linear_fc2.weight" : torch.Size([4096, 12288]),
    "layers.4.adaLN_sa_ln.weight" : torch.Size([4096]),
    "layers.4.self_attention.to_q.weight" : torch.Size([4096, 4096]),
    "layers.4.self_attention.to_k.weight" : torch.Size([4096, 4096]),
    "layers.4.self_attention.to_v.weight" : torch.Size([4096, 4096]),
    "layers.4.self_attention.norm_q.weight" : torch.Size([128]),
    "layers.4.self_attention.norm_k.weight" : torch.Size([128]),
    "layers.4.self_attention.to_out.0.weight" : torch.Size([4096, 4096]),
    "layers.4.adaLN_mlp_ln.weight" : torch.Size([4096]),
    "layers.4.mlp.gate_proj.weight" : torch.Size([12288, 4096]),
    "layers.4.mlp.up_proj.weight" : torch.Size([12288, 4096]),
    "layers.4.mlp.linear_fc2.weight" : torch.Size([4096, 12288]),
    "layers.5.adaLN_sa_ln.weight" : torch.Size([4096]),
    "layers.5.self_attention.to_q.weight" : torch.Size([4096, 4096]),
    "layers.5.self_attention.to_k.weight" : torch.Size([4096, 4096]),
    "layers.5.self_attention.to_v.weight" : torch.Size([4096, 4096]),
    "layers.5.self_attention.norm_q.weight" : torch.Size([128]),
    "layers.5.self_attention.norm_k.weight" : torch.Size([128]),
    "layers.5.self_attention.to_out.0.weight" : torch.Size([4096, 4096]),
    "layers.5.adaLN_mlp_ln.weight" : torch.Size([4096]),
    "layers.5.mlp.gate_proj.weight" : torch.Size([12288, 4096]),
    "layers.5.mlp.up_proj.weight" : torch.Size([12288, 4096]),
    "layers.5.mlp.linear_fc2.weight" : torch.Size([4096, 12288]),
    "layers.6.adaLN_sa_ln.weight" : torch.Size([4096]),
    "layers.6.self_attention.to_q.weight" : torch.Size([4096, 4096]),
    "layers.6.self_attention.to_k.weight" : torch.Size([4096, 4096]),
    "layers.6.self_attention.to_v.weight" : torch.Size([4096, 4096]),
    "layers.6.self_attention.norm_q.weight" : torch.Size([128]),
    "layers.6.self_attention.norm_k.weight" : torch.Size([128]),
    "layers.6.self_attention.to_out.0.weight" : torch.Size([4096, 4096]),
    "layers.6.adaLN_mlp_ln.weight" : torch.Size([4096]),
    "layers.6.mlp.gate_proj.weight" : torch.Size([12288, 4096]),
    "layers.6.mlp.up_proj.weight" : torch.Size([12288, 4096]),
    "layers.6.mlp.linear_fc2.weight" : torch.Size([4096, 12288]),
    "layers.7.adaLN_sa_ln.weight" : torch.Size([4096]),
    "layers.7.self_attention.to_q.weight" : torch.Size([4096, 4096]),
    "layers.7.self_attention.to_k.weight" : torch.Size([4096, 4096]),
    "layers.7.self_attention.to_v.weight" : torch.Size([4096, 4096]),
    "layers.7.self_attention.norm_q.weight" : torch.Size([128]),
    "layers.7.self_attention.norm_k.weight" : torch.Size([128]),
    "layers.7.self_attention.to_out.0.weight" : torch.Size([4096, 4096]),
    "layers.7.adaLN_mlp_ln.weight" : torch.Size([4096]),
    "layers.7.mlp.gate_proj.weight" : torch.Size([12288, 4096]),
    "layers.7.mlp.up_proj.weight" : torch.Size([12288, 4096]),
    "layers.7.mlp.linear_fc2.weight" : torch.Size([4096, 12288]),
    "layers.8.adaLN_sa_ln.weight" : torch.Size([4096]),
    "layers.8.self_attention.to_q.weight" : torch.Size([4096, 4096]),
    "layers.8.self_attention.to_k.weight" : torch.Size([4096, 4096]),
    "layers.8.self_attention.to_v.weight" : torch.Size([4096, 4096]),
    "layers.8.self_attention.norm_q.weight" : torch.Size([128]),
    "layers.8.self_attention.norm_k.weight" : torch.Size([128]),
    "layers.8.self_attention.to_out.0.weight" : torch.Size([4096, 4096]),
    "layers.8.adaLN_mlp_ln.weight" : torch.Size([4096]),
    "layers.8.mlp.gate_proj.weight" : torch.Size([12288, 4096]),
    "layers.8.mlp.up_proj.weight" : torch.Size([12288, 4096]),
    "layers.8.mlp.linear_fc2.weight" : torch.Size([4096, 12288]),
    "layers.9.adaLN_sa_ln.weight" : torch.Size([4096]),
    "layers.9.self_attention.to_q.weight" : torch.Size([4096, 4096]),
    "layers.9.self_attention.to_k.weight" : torch.Size([4096, 4096]),
    "layers.9.self_attention.to_v.weight" : torch.Size([4096, 4096]),
    "layers.9.self_attention.norm_q.weight" : torch.Size([128]),
    "layers.9.self_attention.norm_k.weight" : torch.Size([128]),
    "layers.9.self_attention.to_out.0.weight" : torch.Size([4096, 4096]),
    "layers.9.adaLN_mlp_ln.weight" : torch.Size([4096]),
    "layers.9.mlp.gate_proj.weight" : torch.Size([12288, 4096]),
    "layers.9.mlp.up_proj.weight" : torch.Size([12288, 4096]),
    "layers.9.mlp.linear_fc2.weight" : torch.Size([4096, 12288]),
    "layers.10.adaLN_sa_ln.weight" : torch.Size([4096]),
    "layers.10.self_attention.to_q.weight" : torch.Size([4096, 4096]),
    "layers.10.self_attention.to_k.weight" : torch.Size([4096, 4096]),
    "layers.10.self_attention.to_v.weight" : torch.Size([4096, 4096]),
    "layers.10.self_attention.norm_q.weight" : torch.Size([128]),
    "layers.10.self_attention.norm_k.weight" : torch.Size([128]),
    "layers.10.self_attention.to_out.0.weight" : torch.Size([4096, 4096]),
    "layers.10.adaLN_mlp_ln.weight" : torch.Size([4096]),
    "layers.10.mlp.gate_proj.weight" : torch.Size([12288, 4096]),
    "layers.10.mlp.up_proj.weight" : torch.Size([12288, 4096]),
    "layers.10.mlp.linear_fc2.weight" : torch.Size([4096, 12288]),
    "layers.11.adaLN_sa_ln.weight" : torch.Size([4096]),
    "layers.11.self_attention.to_q.weight" : torch.Size([4096, 4096]),
    "layers.11.self_attention.to_k.weight" : torch.Size([4096, 4096]),
    "layers.11.self_attention.to_v.weight" : torch.Size([4096, 4096]),
    "layers.11.self_attention.norm_q.weight" : torch.Size([128]),
    "layers.11.self_attention.norm_k.weight" : torch.Size([128]),
    "layers.11.self_attention.to_out.0.weight" : torch.Size([4096, 4096]),
    "layers.11.adaLN_mlp_ln.weight" : torch.Size([4096]),
    "layers.11.mlp.gate_proj.weight" : torch.Size([12288, 4096]),
    "layers.11.mlp.up_proj.weight" : torch.Size([12288, 4096]),
    "layers.11.mlp.linear_fc2.weight" : torch.Size([4096, 12288]),
    "layers.12.adaLN_sa_ln.weight" : torch.Size([4096]),
    "layers.12.self_attention.to_q.weight" : torch.Size([4096, 4096]),
    "layers.12.self_attention.to_k.weight" : torch.Size([4096, 4096]),
    "layers.12.self_attention.to_v.weight" : torch.Size([4096, 4096]),
    "layers.12.self_attention.norm_q.weight" : torch.Size([128]),
    "layers.12.self_attention.norm_k.weight" : torch.Size([128]),
    "layers.12.self_attention.to_out.0.weight" : torch.Size([4096, 4096]),
    "layers.12.adaLN_mlp_ln.weight" : torch.Size([4096]),
    "layers.12.mlp.gate_proj.weight" : torch.Size([12288, 4096]),
    "layers.12.mlp.up_proj.weight" : torch.Size([12288, 4096]),
    "layers.12.mlp.linear_fc2.weight" : torch.Size([4096, 12288]),
    "layers.13.adaLN_sa_ln.weight" : torch.Size([4096]),
    "layers.13.self_attention.to_q.weight" : torch.Size([4096, 4096]),
    "layers.13.self_attention.to_k.weight" : torch.Size([4096, 4096]),
    "layers.13.self_attention.to_v.weight" : torch.Size([4096, 4096]),
    "layers.13.self_attention.norm_q.weight" : torch.Size([128]),
    "layers.13.self_attention.norm_k.weight" : torch.Size([128]),
    "layers.13.self_attention.to_out.0.weight" : torch.Size([4096, 4096]),
    "layers.13.adaLN_mlp_ln.weight" : torch.Size([4096]),
    "layers.13.mlp.gate_proj.weight" : torch.Size([12288, 4096]),
    "layers.13.mlp.up_proj.weight" : torch.Size([12288, 4096]),
    "layers.13.mlp.linear_fc2.weight" : torch.Size([4096, 12288]),
    "layers.14.adaLN_sa_ln.weight" : torch.Size([4096]),
    "layers.14.self_attention.to_q.weight" : torch.Size([4096, 4096]),
    "layers.14.self_attention.to_k.weight" : torch.Size([4096, 4096]),
    "layers.14.self_attention.to_v.weight" : torch.Size([4096, 4096]),
    "layers.14.self_attention.norm_q.weight" : torch.Size([128]),
    "layers.14.self_attention.norm_k.weight" : torch.Size([128]),
    "layers.14.self_attention.to_out.0.weight" : torch.Size([4096, 4096]),
    "layers.14.adaLN_mlp_ln.weight" : torch.Size([4096]),
    "layers.14.mlp.gate_proj.weight" : torch.Size([12288, 4096]),
    "layers.14.mlp.up_proj.weight" : torch.Size([12288, 4096]),
    "layers.14.mlp.linear_fc2.weight" : torch.Size([4096, 12288]),
    "layers.15.adaLN_sa_ln.weight" : torch.Size([4096]),
    "layers.15.self_attention.to_q.weight" : torch.Size([4096, 4096]),
    "layers.15.self_attention.to_k.weight" : torch.Size([4096, 4096]),
    "layers.15.self_attention.to_v.weight" : torch.Size([4096, 4096]),
    "layers.15.self_attention.norm_q.weight" : torch.Size([128]),
    "layers.15.self_attention.norm_k.weight" : torch.Size([128]),
    "layers.15.self_attention.to_out.0.weight" : torch.Size([4096, 4096]),
    "layers.15.adaLN_mlp_ln.weight" : torch.Size([4096]),
    "layers.15.mlp.gate_proj.weight" : torch.Size([12288, 4096]),
    "layers.15.mlp.up_proj.weight" : torch.Size([12288, 4096]),
    "layers.15.mlp.linear_fc2.weight" : torch.Size([4096, 12288]),
    "layers.16.adaLN_sa_ln.weight" : torch.Size([4096]),
    "layers.16.self_attention.to_q.weight" : torch.Size([4096, 4096]),
    "layers.16.self_attention.to_k.weight" : torch.Size([4096, 4096]),
    "layers.16.self_attention.to_v.weight" : torch.Size([4096, 4096]),
    "layers.16.self_attention.norm_q.weight" : torch.Size([128]),
    "layers.16.self_attention.norm_k.weight" : torch.Size([128]),
    "layers.16.self_attention.to_out.0.weight" : torch.Size([4096, 4096]),
    "layers.16.adaLN_mlp_ln.weight" : torch.Size([4096]),
    "layers.16.mlp.gate_proj.weight" : torch.Size([12288, 4096]),
    "layers.16.mlp.up_proj.weight" : torch.Size([12288, 4096]),
    "layers.16.mlp.linear_fc2.weight" : torch.Size([4096, 12288]),
    "layers.17.adaLN_sa_ln.weight" : torch.Size([4096]),
    "layers.17.self_attention.to_q.weight" : torch.Size([4096, 4096]),
    "layers.17.self_attention.to_k.weight" : torch.Size([4096, 4096]),
    "layers.17.self_attention.to_v.weight" : torch.Size([4096, 4096]),
    "layers.17.self_attention.norm_q.weight" : torch.Size([128]),
    "layers.17.self_attention.norm_k.weight" : torch.Size([128]),
    "layers.17.self_attention.to_out.0.weight" : torch.Size([4096, 4096]),
    "layers.17.adaLN_mlp_ln.weight" : torch.Size([4096]),
    "layers.17.mlp.gate_proj.weight" : torch.Size([12288, 4096]),
    "layers.17.mlp.up_proj.weight" : torch.Size([12288, 4096]),
    "layers.17.mlp.linear_fc2.weight" : torch.Size([4096, 12288]),
    "layers.18.adaLN_sa_ln.weight" : torch.Size([4096]),
    "layers.18.self_attention.to_q.weight" : torch.Size([4096, 4096]),
    "layers.18.self_attention.to_k.weight" : torch.Size([4096, 4096]),
    "layers.18.self_attention.to_v.weight" : torch.Size([4096, 4096]),
    "layers.18.self_attention.norm_q.weight" : torch.Size([128]),
    "layers.18.self_attention.norm_k.weight" : torch.Size([128]),
    "layers.18.self_attention.to_out.0.weight" : torch.Size([4096, 4096]),
    "layers.18.adaLN_mlp_ln.weight" : torch.Size([4096]),
    "layers.18.mlp.gate_proj.weight" : torch.Size([12288, 4096]),
    "layers.18.mlp.up_proj.weight" : torch.Size([12288, 4096]),
    "layers.18.mlp.linear_fc2.weight" : torch.Size([4096, 12288]),
    "layers.19.adaLN_sa_ln.weight" : torch.Size([4096]),
    "layers.19.self_attention.to_q.weight" : torch.Size([4096, 4096]),
    "layers.19.self_attention.to_k.weight" : torch.Size([4096, 4096]),
    "layers.19.self_attention.to_v.weight" : torch.Size([4096, 4096]),
    "layers.19.self_attention.norm_q.weight" : torch.Size([128]),
    "layers.19.self_attention.norm_k.weight" : torch.Size([128]),
    "layers.19.self_attention.to_out.0.weight" : torch.Size([4096, 4096]),
    "layers.19.adaLN_mlp_ln.weight" : torch.Size([4096]),
    "layers.19.mlp.gate_proj.weight" : torch.Size([12288, 4096]),
    "layers.19.mlp.up_proj.weight" : torch.Size([12288, 4096]),
    "layers.19.mlp.linear_fc2.weight" : torch.Size([4096, 12288]),
    "layers.20.adaLN_sa_ln.weight" : torch.Size([4096]),
    "layers.20.self_attention.to_q.weight" : torch.Size([4096, 4096]),
    "layers.20.self_attention.to_k.weight" : torch.Size([4096, 4096]),
    "layers.20.self_attention.to_v.weight" : torch.Size([4096, 4096]),
    "layers.20.self_attention.norm_q.weight" : torch.Size([128]),
    "layers.20.self_attention.norm_k.weight" : torch.Size([128]),
    "layers.20.self_attention.to_out.0.weight" : torch.Size([4096, 4096]),
    "layers.20.adaLN_mlp_ln.weight" : torch.Size([4096]),
    "layers.20.mlp.gate_proj.weight" : torch.Size([12288, 4096]),
    "layers.20.mlp.up_proj.weight" : torch.Size([12288, 4096]),
    "layers.20.mlp.linear_fc2.weight" : torch.Size([4096, 12288]),
    "layers.21.adaLN_sa_ln.weight" : torch.Size([4096]),
    "layers.21.self_attention.to_q.weight" : torch.Size([4096, 4096]),
    "layers.21.self_attention.to_k.weight" : torch.Size([4096, 4096]),
    "layers.21.self_attention.to_v.weight" : torch.Size([4096, 4096]),
    "layers.21.self_attention.norm_q.weight" : torch.Size([128]),
    "layers.21.self_attention.norm_k.weight" : torch.Size([128]),
    "layers.21.self_attention.to_out.0.weight" : torch.Size([4096, 4096]),
    "layers.21.adaLN_mlp_ln.weight" : torch.Size([4096]),
    "layers.21.mlp.gate_proj.weight" : torch.Size([12288, 4096]),
    "layers.21.mlp.up_proj.weight" : torch.Size([12288, 4096]),
    "layers.21.mlp.linear_fc2.weight" : torch.Size([4096, 12288]),
    "layers.22.adaLN_sa_ln.weight" : torch.Size([4096]),
    "layers.22.self_attention.to_q.weight" : torch.Size([4096, 4096]),
    "layers.22.self_attention.to_k.weight" : torch.Size([4096, 4096]),
    "layers.22.self_attention.to_v.weight" : torch.Size([4096, 4096]),
    "layers.22.self_attention.norm_q.weight" : torch.Size([128]),
    "layers.22.self_attention.norm_k.weight" : torch.Size([128]),
    "layers.22.self_attention.to_out.0.weight" : torch.Size([4096, 4096]),
    "layers.22.adaLN_mlp_ln.weight" : torch.Size([4096]),
    "layers.22.mlp.gate_proj.weight" : torch.Size([12288, 4096]),
    "layers.22.mlp.up_proj.weight" : torch.Size([12288, 4096]),
    "layers.22.mlp.linear_fc2.weight" : torch.Size([4096, 12288]),
    "layers.23.adaLN_sa_ln.weight" : torch.Size([4096]),
    "layers.23.self_attention.to_q.weight" : torch.Size([4096, 4096]),
    "layers.23.self_attention.to_k.weight" : torch.Size([4096, 4096]),
    "layers.23.self_attention.to_v.weight" : torch.Size([4096, 4096]),
    "layers.23.self_attention.norm_q.weight" : torch.Size([128]),
    "layers.23.self_attention.norm_k.weight" : torch.Size([128]),
    "layers.23.self_attention.to_out.0.weight" : torch.Size([4096, 4096]),
    "layers.23.adaLN_mlp_ln.weight" : torch.Size([4096]),
    "layers.23.mlp.gate_proj.weight" : torch.Size([12288, 4096]),
    "layers.23.mlp.up_proj.weight" : torch.Size([12288, 4096]),
    "layers.23.mlp.linear_fc2.weight" : torch.Size([4096, 12288]),
    "layers.24.adaLN_sa_ln.weight" : torch.Size([4096]),
    "layers.24.self_attention.to_q.weight" : torch.Size([4096, 4096]),
    "layers.24.self_attention.to_k.weight" : torch.Size([4096, 4096]),
    "layers.24.self_attention.to_v.weight" : torch.Size([4096, 4096]),
    "layers.24.self_attention.norm_q.weight" : torch.Size([128]),
    "layers.24.self_attention.norm_k.weight" : torch.Size([128]),
    "layers.24.self_attention.to_out.0.weight" : torch.Size([4096, 4096]),
    "layers.24.adaLN_mlp_ln.weight" : torch.Size([4096]),
    "layers.24.mlp.gate_proj.weight" : torch.Size([12288, 4096]),
    "layers.24.mlp.up_proj.weight" : torch.Size([12288, 4096]),
    "layers.24.mlp.linear_fc2.weight" : torch.Size([4096, 12288]),
    "layers.25.adaLN_sa_ln.weight" : torch.Size([4096]),
    "layers.25.self_attention.to_q.weight" : torch.Size([4096, 4096]),
    "layers.25.self_attention.to_k.weight" : torch.Size([4096, 4096]),
    "layers.25.self_attention.to_v.weight" : torch.Size([4096, 4096]),
    "layers.25.self_attention.norm_q.weight" : torch.Size([128]),
    "layers.25.self_attention.norm_k.weight" : torch.Size([128]),
    "layers.25.self_attention.to_out.0.weight" : torch.Size([4096, 4096]),
    "layers.25.adaLN_mlp_ln.weight" : torch.Size([4096]),
    "layers.25.mlp.gate_proj.weight" : torch.Size([12288, 4096]),
    "layers.25.mlp.up_proj.weight" : torch.Size([12288, 4096]),
    "layers.25.mlp.linear_fc2.weight" : torch.Size([4096, 12288]),
    "layers.26.adaLN_sa_ln.weight" : torch.Size([4096]),
    "layers.26.self_attention.to_q.weight" : torch.Size([4096, 4096]),
    "layers.26.self_attention.to_k.weight" : torch.Size([4096, 4096]),
    "layers.26.self_attention.to_v.weight" : torch.Size([4096, 4096]),
    "layers.26.self_attention.norm_q.weight" : torch.Size([128]),
    "layers.26.self_attention.norm_k.weight" : torch.Size([128]),
    "layers.26.self_attention.to_out.0.weight" : torch.Size([4096, 4096]),
    "layers.26.adaLN_mlp_ln.weight" : torch.Size([4096]),
    "layers.26.mlp.gate_proj.weight" : torch.Size([12288, 4096]),
    "layers.26.mlp.up_proj.weight" : torch.Size([12288, 4096]),
    "layers.26.mlp.linear_fc2.weight" : torch.Size([4096, 12288]),
    "layers.27.adaLN_sa_ln.weight" : torch.Size([4096]),
    "layers.27.self_attention.to_q.weight" : torch.Size([4096, 4096]),
    "layers.27.self_attention.to_k.weight" : torch.Size([4096, 4096]),
    "layers.27.self_attention.to_v.weight" : torch.Size([4096, 4096]),
    "layers.27.self_attention.norm_q.weight" : torch.Size([128]),
    "layers.27.self_attention.norm_k.weight" : torch.Size([128]),
    "layers.27.self_attention.to_out.0.weight" : torch.Size([4096, 4096]),
    "layers.27.adaLN_mlp_ln.weight" : torch.Size([4096]),
    "layers.27.mlp.gate_proj.weight" : torch.Size([12288, 4096]),
    "layers.27.mlp.up_proj.weight" : torch.Size([12288, 4096]),
    "layers.27.mlp.linear_fc2.weight" : torch.Size([4096, 12288]),
    "layers.28.adaLN_sa_ln.weight" : torch.Size([4096]),
    "layers.28.self_attention.to_q.weight" : torch.Size([4096, 4096]),
    "layers.28.self_attention.to_k.weight" : torch.Size([4096, 4096]),
    "layers.28.self_attention.to_v.weight" : torch.Size([4096, 4096]),
    "layers.28.self_attention.norm_q.weight" : torch.Size([128]),
    "layers.28.self_attention.norm_k.weight" : torch.Size([128]),
    "layers.28.self_attention.to_out.0.weight" : torch.Size([4096, 4096]),
    "layers.28.adaLN_mlp_ln.weight" : torch.Size([4096]),
    "layers.28.mlp.gate_proj.weight" : torch.Size([12288, 4096]),
    "layers.28.mlp.up_proj.weight" : torch.Size([12288, 4096]),
    "layers.28.mlp.linear_fc2.weight" : torch.Size([4096, 12288]),
    "layers.29.adaLN_sa_ln.weight" : torch.Size([4096]),
    "layers.29.self_attention.to_q.weight" : torch.Size([4096, 4096]),
    "layers.29.self_attention.to_k.weight" : torch.Size([4096, 4096]),
    "layers.29.self_attention.to_v.weight" : torch.Size([4096, 4096]),
    "layers.29.self_attention.norm_q.weight" : torch.Size([128]),
    "layers.29.self_attention.norm_k.weight" : torch.Size([128]),
    "layers.29.self_attention.to_out.0.weight" : torch.Size([4096, 4096]),
    "layers.29.adaLN_mlp_ln.weight" : torch.Size([4096]),
    "layers.29.mlp.gate_proj.weight" : torch.Size([12288, 4096]),
    "layers.29.mlp.up_proj.weight" : torch.Size([12288, 4096]),
    "layers.29.mlp.linear_fc2.weight" : torch.Size([4096, 12288]),
    "layers.30.adaLN_sa_ln.weight" : torch.Size([4096]),
    "layers.30.self_attention.to_q.weight" : torch.Size([4096, 4096]),
    "layers.30.self_attention.to_k.weight" : torch.Size([4096, 4096]),
    "layers.30.self_attention.to_v.weight" : torch.Size([4096, 4096]),
    "layers.30.self_attention.norm_q.weight" : torch.Size([128]),
    "layers.30.self_attention.norm_k.weight" : torch.Size([128]),
    "layers.30.self_attention.to_out.0.weight" : torch.Size([4096, 4096]),
    "layers.30.adaLN_mlp_ln.weight" : torch.Size([4096]),
    "layers.30.mlp.gate_proj.weight" : torch.Size([12288, 4096]),
    "layers.30.mlp.up_proj.weight" : torch.Size([12288, 4096]),
    "layers.30.mlp.linear_fc2.weight" : torch.Size([4096, 12288]),
    "layers.31.adaLN_sa_ln.weight" : torch.Size([4096]),
    "layers.31.self_attention.to_q.weight" : torch.Size([4096, 4096]),
    "layers.31.self_attention.to_k.weight" : torch.Size([4096, 4096]),
    "layers.31.self_attention.to_v.weight" : torch.Size([4096, 4096]),
    "layers.31.self_attention.norm_q.weight" : torch.Size([128]),
    "layers.31.self_attention.norm_k.weight" : torch.Size([128]),
    "layers.31.self_attention.to_out.0.weight" : torch.Size([4096, 4096]),
    "layers.31.adaLN_mlp_ln.weight" : torch.Size([4096]),
    "layers.31.mlp.gate_proj.weight" : torch.Size([12288, 4096]),
    "layers.31.mlp.up_proj.weight" : torch.Size([12288, 4096]),
    "layers.31.mlp.linear_fc2.weight" : torch.Size([4096, 12288]),
    "layers.32.adaLN_sa_ln.weight" : torch.Size([4096]),
    "layers.32.self_attention.to_q.weight" : torch.Size([4096, 4096]),
    "layers.32.self_attention.to_k.weight" : torch.Size([4096, 4096]),
    "layers.32.self_attention.to_v.weight" : torch.Size([4096, 4096]),
    "layers.32.self_attention.norm_q.weight" : torch.Size([128]),
    "layers.32.self_attention.norm_k.weight" : torch.Size([128]),
    "layers.32.self_attention.to_out.0.weight" : torch.Size([4096, 4096]),
    "layers.32.adaLN_mlp_ln.weight" : torch.Size([4096]),
    "layers.32.mlp.gate_proj.weight" : torch.Size([12288, 4096]),
    "layers.32.mlp.up_proj.weight" : torch.Size([12288, 4096]),
    "layers.32.mlp.linear_fc2.weight" : torch.Size([4096, 12288]),
    "layers.33.adaLN_sa_ln.weight" : torch.Size([4096]),
    "layers.33.self_attention.to_q.weight" : torch.Size([4096, 4096]),
    "layers.33.self_attention.to_k.weight" : torch.Size([4096, 4096]),
    "layers.33.self_attention.to_v.weight" : torch.Size([4096, 4096]),
    "layers.33.self_attention.norm_q.weight" : torch.Size([128]),
    "layers.33.self_attention.norm_k.weight" : torch.Size([128]),
    "layers.33.self_attention.to_out.0.weight" : torch.Size([4096, 4096]),
    "layers.33.adaLN_mlp_ln.weight" : torch.Size([4096]),
    "layers.33.mlp.gate_proj.weight" : torch.Size([12288, 4096]),
    "layers.33.mlp.up_proj.weight" : torch.Size([12288, 4096]),
    "layers.33.mlp.linear_fc2.weight" : torch.Size([4096, 12288]),
    "layers.34.adaLN_sa_ln.weight" : torch.Size([4096]),
    "layers.34.self_attention.to_q.weight" : torch.Size([4096, 4096]),
    "layers.34.self_attention.to_k.weight" : torch.Size([4096, 4096]),
    "layers.34.self_attention.to_v.weight" : torch.Size([4096, 4096]),
    "layers.34.self_attention.norm_q.weight" : torch.Size([128]),
    "layers.34.self_attention.norm_k.weight" : torch.Size([128]),
    "layers.34.self_attention.to_out.0.weight" : torch.Size([4096, 4096]),
    "layers.34.adaLN_mlp_ln.weight" : torch.Size([4096]),
    "layers.34.mlp.gate_proj.weight" : torch.Size([12288, 4096]),
    "layers.34.mlp.up_proj.weight" : torch.Size([12288, 4096]),
    "layers.34.mlp.linear_fc2.weight" : torch.Size([4096, 12288]),
    "layers.35.adaLN_sa_ln.weight" : torch.Size([4096]),
    "layers.35.self_attention.to_q.weight" : torch.Size([4096, 4096]),
    "layers.35.self_attention.to_k.weight" : torch.Size([4096, 4096]),
    "layers.35.self_attention.to_v.weight" : torch.Size([4096, 4096]),
    "layers.35.self_attention.norm_q.weight" : torch.Size([128]),
    "layers.35.self_attention.norm_k.weight" : torch.Size([128]),
    "layers.35.self_attention.to_out.0.weight" : torch.Size([4096, 4096]),
    "layers.35.adaLN_mlp_ln.weight" : torch.Size([4096]),
    "layers.35.mlp.gate_proj.weight" : torch.Size([12288, 4096]),
    "layers.35.mlp.up_proj.weight" : torch.Size([12288, 4096]),
    "layers.35.mlp.linear_fc2.weight" : torch.Size([4096, 12288]),
    "final_norm.linear.weight" : torch.Size([8192, 4096]),
    "final_norm.linear.bias" : torch.Size([8192]),
    "final_linear.weight" : torch.Size([128, 4096]),
    "final_linear.bias" : torch.Size([128]),
}

zimage_keys_dict = {
    "cap_embedder.0.weight" : torch.Size([2560]),
    "cap_embedder.1.bias" : torch.Size([3840]),
    "cap_embedder.1.weight" : torch.Size([3840, 2560]),
    "context_refiner.0.attention.k_norm.weight" : torch.Size([128]),
    "context_refiner.0.attention.out.weight" : torch.Size([3840, 3840]),
    "context_refiner.0.attention.q_norm.weight" : torch.Size([128]),
    "context_refiner.0.attention.qkv.weight" : torch.Size([11520, 3840]),
    "context_refiner.0.attention_norm1.weight" : torch.Size([3840]),
    "context_refiner.0.attention_norm2.weight" : torch.Size([3840]),
    "context_refiner.0.feed_forward.w1.weight" : torch.Size([10240, 3840]),
    "context_refiner.0.feed_forward.w2.weight" : torch.Size([3840, 10240]),
    "context_refiner.0.feed_forward.w3.weight" : torch.Size([10240, 3840]),
    "context_refiner.0.ffn_norm1.weight" : torch.Size([3840]),
    "context_refiner.0.ffn_norm2.weight" : torch.Size([3840]),
    "context_refiner.1.attention.k_norm.weight" : torch.Size([128]),
    "context_refiner.1.attention.out.weight" : torch.Size([3840, 3840]),
    "context_refiner.1.attention.q_norm.weight" : torch.Size([128]),
    "context_refiner.1.attention.qkv.weight" : torch.Size([11520, 3840]),
    "context_refiner.1.attention_norm1.weight" : torch.Size([3840]),
    "context_refiner.1.attention_norm2.weight" : torch.Size([3840]),
    "context_refiner.1.feed_forward.w1.weight" : torch.Size([10240, 3840]),
    "context_refiner.1.feed_forward.w2.weight" : torch.Size([3840, 10240]),
    "context_refiner.1.feed_forward.w3.weight" : torch.Size([10240, 3840]),
    "context_refiner.1.ffn_norm1.weight" : torch.Size([3840]),
    "context_refiner.1.ffn_norm2.weight" : torch.Size([3840]),
    "final_layer.adaLN_modulation.1.bias" : torch.Size([3840]),
    "final_layer.adaLN_modulation.1.weight" : torch.Size([3840, 256]),
    "final_layer.linear.bias" : torch.Size([64]),
    "final_layer.linear.weight" : torch.Size([64, 3840]),
    "layers.0.adaLN_modulation.0.bias" : torch.Size([15360]),
    "layers.0.adaLN_modulation.0.weight" : torch.Size([15360, 256]),
    "layers.0.attention.k_norm.weight" : torch.Size([128]),
    "layers.0.attention.out.weight" : torch.Size([3840, 3840]),
    "layers.0.attention.q_norm.weight" : torch.Size([128]),
    "layers.0.attention.qkv.weight" : torch.Size([11520, 3840]),
    "layers.0.attention_norm1.weight" : torch.Size([3840]),
    "layers.0.attention_norm2.weight" : torch.Size([3840]),
    "layers.0.feed_forward.w1.weight" : torch.Size([10240, 3840]),
    "layers.0.feed_forward.w2.weight" : torch.Size([3840, 10240]),
    "layers.0.feed_forward.w3.weight" : torch.Size([10240, 3840]),
    "layers.0.ffn_norm1.weight" : torch.Size([3840]),
    "layers.0.ffn_norm2.weight" : torch.Size([3840]),
    "layers.1.adaLN_modulation.0.bias" : torch.Size([15360]),
    "layers.1.adaLN_modulation.0.weight" : torch.Size([15360, 256]),
    "layers.1.attention.k_norm.weight" : torch.Size([128]),
    "layers.1.attention.out.weight" : torch.Size([3840, 3840]),
    "layers.1.attention.q_norm.weight" : torch.Size([128]),
    "layers.1.attention.qkv.weight" : torch.Size([11520, 3840]),
    "layers.1.attention_norm1.weight" : torch.Size([3840]),
    "layers.1.attention_norm2.weight" : torch.Size([3840]),
    "layers.1.feed_forward.w1.weight" : torch.Size([10240, 3840]),
    "layers.1.feed_forward.w2.weight" : torch.Size([3840, 10240]),
    "layers.1.feed_forward.w3.weight" : torch.Size([10240, 3840]),
    "layers.1.ffn_norm1.weight" : torch.Size([3840]),
    "layers.1.ffn_norm2.weight" : torch.Size([3840]),
    "layers.10.adaLN_modulation.0.bias" : torch.Size([15360]),
    "layers.10.adaLN_modulation.0.weight" : torch.Size([15360, 256]),
    "layers.10.attention.k_norm.weight" : torch.Size([128]),
    "layers.10.attention.out.weight" : torch.Size([3840, 3840]),
    "layers.10.attention.q_norm.weight" : torch.Size([128]),
    "layers.10.attention.qkv.weight" : torch.Size([11520, 3840]),
    "layers.10.attention_norm1.weight" : torch.Size([3840]),
    "layers.10.attention_norm2.weight" : torch.Size([3840]),
    "layers.10.feed_forward.w1.weight" : torch.Size([10240, 3840]),
    "layers.10.feed_forward.w2.weight" : torch.Size([3840, 10240]),
    "layers.10.feed_forward.w3.weight" : torch.Size([10240, 3840]),
    "layers.10.ffn_norm1.weight" : torch.Size([3840]),
    "layers.10.ffn_norm2.weight" : torch.Size([3840]),
    "layers.11.adaLN_modulation.0.bias" : torch.Size([15360]),
    "layers.11.adaLN_modulation.0.weight" : torch.Size([15360, 256]),
    "layers.11.attention.k_norm.weight" : torch.Size([128]),
    "layers.11.attention.out.weight" : torch.Size([3840, 3840]),
    "layers.11.attention.q_norm.weight" : torch.Size([128]),
    "layers.11.attention.qkv.weight" : torch.Size([11520, 3840]),
    "layers.11.attention_norm1.weight" : torch.Size([3840]),
    "layers.11.attention_norm2.weight" : torch.Size([3840]),
    "layers.11.feed_forward.w1.weight" : torch.Size([10240, 3840]),
    "layers.11.feed_forward.w2.weight" : torch.Size([3840, 10240]),
    "layers.11.feed_forward.w3.weight" : torch.Size([10240, 3840]),
    "layers.11.ffn_norm1.weight" : torch.Size([3840]),
    "layers.11.ffn_norm2.weight" : torch.Size([3840]),
    "layers.12.adaLN_modulation.0.bias" : torch.Size([15360]),
    "layers.12.adaLN_modulation.0.weight" : torch.Size([15360, 256]),
    "layers.12.attention.k_norm.weight" : torch.Size([128]),
    "layers.12.attention.out.weight" : torch.Size([3840, 3840]),
    "layers.12.attention.q_norm.weight" : torch.Size([128]),
    "layers.12.attention.qkv.weight" : torch.Size([11520, 3840]),
    "layers.12.attention_norm1.weight" : torch.Size([3840]),
    "layers.12.attention_norm2.weight" : torch.Size([3840]),
    "layers.12.feed_forward.w1.weight" : torch.Size([10240, 3840]),
    "layers.12.feed_forward.w2.weight" : torch.Size([3840, 10240]),
    "layers.12.feed_forward.w3.weight" : torch.Size([10240, 3840]),
    "layers.12.ffn_norm1.weight" : torch.Size([3840]),
    "layers.12.ffn_norm2.weight" : torch.Size([3840]),
    "layers.13.adaLN_modulation.0.bias" : torch.Size([15360]),
    "layers.13.adaLN_modulation.0.weight" : torch.Size([15360, 256]),
    "layers.13.attention.k_norm.weight" : torch.Size([128]),
    "layers.13.attention.out.weight" : torch.Size([3840, 3840]),
    "layers.13.attention.q_norm.weight" : torch.Size([128]),
    "layers.13.attention.qkv.weight" : torch.Size([11520, 3840]),
    "layers.13.attention_norm1.weight" : torch.Size([3840]),
    "layers.13.attention_norm2.weight" : torch.Size([3840]),
    "layers.13.feed_forward.w1.weight" : torch.Size([10240, 3840]),
    "layers.13.feed_forward.w2.weight" : torch.Size([3840, 10240]),
    "layers.13.feed_forward.w3.weight" : torch.Size([10240, 3840]),
    "layers.13.ffn_norm1.weight" : torch.Size([3840]),
    "layers.13.ffn_norm2.weight" : torch.Size([3840]),
    "layers.14.adaLN_modulation.0.bias" : torch.Size([15360]),
    "layers.14.adaLN_modulation.0.weight" : torch.Size([15360, 256]),
    "layers.14.attention.k_norm.weight" : torch.Size([128]),
    "layers.14.attention.out.weight" : torch.Size([3840, 3840]),
    "layers.14.attention.q_norm.weight" : torch.Size([128]),
    "layers.14.attention.qkv.weight" : torch.Size([11520, 3840]),
    "layers.14.attention_norm1.weight" : torch.Size([3840]),
    "layers.14.attention_norm2.weight" : torch.Size([3840]),
    "layers.14.feed_forward.w1.weight" : torch.Size([10240, 3840]),
    "layers.14.feed_forward.w2.weight" : torch.Size([3840, 10240]),
    "layers.14.feed_forward.w3.weight" : torch.Size([10240, 3840]),
    "layers.14.ffn_norm1.weight" : torch.Size([3840]),
    "layers.14.ffn_norm2.weight" : torch.Size([3840]),
    "layers.15.adaLN_modulation.0.bias" : torch.Size([15360]),
    "layers.15.adaLN_modulation.0.weight" : torch.Size([15360, 256]),
    "layers.15.attention.k_norm.weight" : torch.Size([128]),
    "layers.15.attention.out.weight" : torch.Size([3840, 3840]),
    "layers.15.attention.q_norm.weight" : torch.Size([128]),
    "layers.15.attention.qkv.weight" : torch.Size([11520, 3840]),
    "layers.15.attention_norm1.weight" : torch.Size([3840]),
    "layers.15.attention_norm2.weight" : torch.Size([3840]),
    "layers.15.feed_forward.w1.weight" : torch.Size([10240, 3840]),
    "layers.15.feed_forward.w2.weight" : torch.Size([3840, 10240]),
    "layers.15.feed_forward.w3.weight" : torch.Size([10240, 3840]),
    "layers.15.ffn_norm1.weight" : torch.Size([3840]),
    "layers.15.ffn_norm2.weight" : torch.Size([3840]),
    "layers.16.adaLN_modulation.0.bias" : torch.Size([15360]),
    "layers.16.adaLN_modulation.0.weight" : torch.Size([15360, 256]),
    "layers.16.attention.k_norm.weight" : torch.Size([128]),
    "layers.16.attention.out.weight" : torch.Size([3840, 3840]),
    "layers.16.attention.q_norm.weight" : torch.Size([128]),
    "layers.16.attention.qkv.weight" : torch.Size([11520, 3840]),
    "layers.16.attention_norm1.weight" : torch.Size([3840]),
    "layers.16.attention_norm2.weight" : torch.Size([3840]),
    "layers.16.feed_forward.w1.weight" : torch.Size([10240, 3840]),
    "layers.16.feed_forward.w2.weight" : torch.Size([3840, 10240]),
    "layers.16.feed_forward.w3.weight" : torch.Size([10240, 3840]),
    "layers.16.ffn_norm1.weight" : torch.Size([3840]),
    "layers.16.ffn_norm2.weight" : torch.Size([3840]),
    "layers.17.adaLN_modulation.0.bias" : torch.Size([15360]),
    "layers.17.adaLN_modulation.0.weight" : torch.Size([15360, 256]),
    "layers.17.attention.k_norm.weight" : torch.Size([128]),
    "layers.17.attention.out.weight" : torch.Size([3840, 3840]),
    "layers.17.attention.q_norm.weight" : torch.Size([128]),
    "layers.17.attention.qkv.weight" : torch.Size([11520, 3840]),
    "layers.17.attention_norm1.weight" : torch.Size([3840]),
    "layers.17.attention_norm2.weight" : torch.Size([3840]),
    "layers.17.feed_forward.w1.weight" : torch.Size([10240, 3840]),
    "layers.17.feed_forward.w2.weight" : torch.Size([3840, 10240]),
    "layers.17.feed_forward.w3.weight" : torch.Size([10240, 3840]),
    "layers.17.ffn_norm1.weight" : torch.Size([3840]),
    "layers.17.ffn_norm2.weight" : torch.Size([3840]),
    "layers.18.adaLN_modulation.0.bias" : torch.Size([15360]),
    "layers.18.adaLN_modulation.0.weight" : torch.Size([15360, 256]),
    "layers.18.attention.k_norm.weight" : torch.Size([128]),
    "layers.18.attention.out.weight" : torch.Size([3840, 3840]),
    "layers.18.attention.q_norm.weight" : torch.Size([128]),
    "layers.18.attention.qkv.weight" : torch.Size([11520, 3840]),
    "layers.18.attention_norm1.weight" : torch.Size([3840]),
    "layers.18.attention_norm2.weight" : torch.Size([3840]),
    "layers.18.feed_forward.w1.weight" : torch.Size([10240, 3840]),
    "layers.18.feed_forward.w2.weight" : torch.Size([3840, 10240]),
    "layers.18.feed_forward.w3.weight" : torch.Size([10240, 3840]),
    "layers.18.ffn_norm1.weight" : torch.Size([3840]),
    "layers.18.ffn_norm2.weight" : torch.Size([3840]),
    "layers.19.adaLN_modulation.0.bias" : torch.Size([15360]),
    "layers.19.adaLN_modulation.0.weight" : torch.Size([15360, 256]),
    "layers.19.attention.k_norm.weight" : torch.Size([128]),
    "layers.19.attention.out.weight" : torch.Size([3840, 3840]),
    "layers.19.attention.q_norm.weight" : torch.Size([128]),
    "layers.19.attention.qkv.weight" : torch.Size([11520, 3840]),
    "layers.19.attention_norm1.weight" : torch.Size([3840]),
    "layers.19.attention_norm2.weight" : torch.Size([3840]),
    "layers.19.feed_forward.w1.weight" : torch.Size([10240, 3840]),
    "layers.19.feed_forward.w2.weight" : torch.Size([3840, 10240]),
    "layers.19.feed_forward.w3.weight" : torch.Size([10240, 3840]),
    "layers.19.ffn_norm1.weight" : torch.Size([3840]),
    "layers.19.ffn_norm2.weight" : torch.Size([3840]),
    "layers.2.adaLN_modulation.0.bias" : torch.Size([15360]),
    "layers.2.adaLN_modulation.0.weight" : torch.Size([15360, 256]),
    "layers.2.attention.k_norm.weight" : torch.Size([128]),
    "layers.2.attention.out.weight" : torch.Size([3840, 3840]),
    "layers.2.attention.q_norm.weight" : torch.Size([128]),
    "layers.2.attention.qkv.weight" : torch.Size([11520, 3840]),
    "layers.2.attention_norm1.weight" : torch.Size([3840]),
    "layers.2.attention_norm2.weight" : torch.Size([3840]),
    "layers.2.feed_forward.w1.weight" : torch.Size([10240, 3840]),
    "layers.2.feed_forward.w2.weight" : torch.Size([3840, 10240]),
    "layers.2.feed_forward.w3.weight" : torch.Size([10240, 3840]),
    "layers.2.ffn_norm1.weight" : torch.Size([3840]),
    "layers.2.ffn_norm2.weight" : torch.Size([3840]),
    "layers.20.adaLN_modulation.0.bias" : torch.Size([15360]),
    "layers.20.adaLN_modulation.0.weight" : torch.Size([15360, 256]),
    "layers.20.attention.k_norm.weight" : torch.Size([128]),
    "layers.20.attention.out.weight" : torch.Size([3840, 3840]),
    "layers.20.attention.q_norm.weight" : torch.Size([128]),
    "layers.20.attention.qkv.weight" : torch.Size([11520, 3840]),
    "layers.20.attention_norm1.weight" : torch.Size([3840]),
    "layers.20.attention_norm2.weight" : torch.Size([3840]),
    "layers.20.feed_forward.w1.weight" : torch.Size([10240, 3840]),
    "layers.20.feed_forward.w2.weight" : torch.Size([3840, 10240]),
    "layers.20.feed_forward.w3.weight" : torch.Size([10240, 3840]),
    "layers.20.ffn_norm1.weight" : torch.Size([3840]),
    "layers.20.ffn_norm2.weight" : torch.Size([3840]),
    "layers.21.adaLN_modulation.0.bias" : torch.Size([15360]),
    "layers.21.adaLN_modulation.0.weight" : torch.Size([15360, 256]),
    "layers.21.attention.k_norm.weight" : torch.Size([128]),
    "layers.21.attention.out.weight" : torch.Size([3840, 3840]),
    "layers.21.attention.q_norm.weight" : torch.Size([128]),
    "layers.21.attention.qkv.weight" : torch.Size([11520, 3840]),
    "layers.21.attention_norm1.weight" : torch.Size([3840]),
    "layers.21.attention_norm2.weight" : torch.Size([3840]),
    "layers.21.feed_forward.w1.weight" : torch.Size([10240, 3840]),
    "layers.21.feed_forward.w2.weight" : torch.Size([3840, 10240]),
    "layers.21.feed_forward.w3.weight" : torch.Size([10240, 3840]),
    "layers.21.ffn_norm1.weight" : torch.Size([3840]),
    "layers.21.ffn_norm2.weight" : torch.Size([3840]),
    "layers.22.adaLN_modulation.0.bias" : torch.Size([15360]),
    "layers.22.adaLN_modulation.0.weight" : torch.Size([15360, 256]),
    "layers.22.attention.k_norm.weight" : torch.Size([128]),
    "layers.22.attention.out.weight" : torch.Size([3840, 3840]),
    "layers.22.attention.q_norm.weight" : torch.Size([128]),
    "layers.22.attention.qkv.weight" : torch.Size([11520, 3840]),
    "layers.22.attention_norm1.weight" : torch.Size([3840]),
    "layers.22.attention_norm2.weight" : torch.Size([3840]),
    "layers.22.feed_forward.w1.weight" : torch.Size([10240, 3840]),
    "layers.22.feed_forward.w2.weight" : torch.Size([3840, 10240]),
    "layers.22.feed_forward.w3.weight" : torch.Size([10240, 3840]),
    "layers.22.ffn_norm1.weight" : torch.Size([3840]),
    "layers.22.ffn_norm2.weight" : torch.Size([3840]),
    "layers.23.adaLN_modulation.0.bias" : torch.Size([15360]),
    "layers.23.adaLN_modulation.0.weight" : torch.Size([15360, 256]),
    "layers.23.attention.k_norm.weight" : torch.Size([128]),
    "layers.23.attention.out.weight" : torch.Size([3840, 3840]),
    "layers.23.attention.q_norm.weight" : torch.Size([128]),
    "layers.23.attention.qkv.weight" : torch.Size([11520, 3840]),
    "layers.23.attention_norm1.weight" : torch.Size([3840]),
    "layers.23.attention_norm2.weight" : torch.Size([3840]),
    "layers.23.feed_forward.w1.weight" : torch.Size([10240, 3840]),
    "layers.23.feed_forward.w2.weight" : torch.Size([3840, 10240]),
    "layers.23.feed_forward.w3.weight" : torch.Size([10240, 3840]),
    "layers.23.ffn_norm1.weight" : torch.Size([3840]),
    "layers.23.ffn_norm2.weight" : torch.Size([3840]),
    "layers.24.adaLN_modulation.0.bias" : torch.Size([15360]),
    "layers.24.adaLN_modulation.0.weight" : torch.Size([15360, 256]),
    "layers.24.attention.k_norm.weight" : torch.Size([128]),
    "layers.24.attention.out.weight" : torch.Size([3840, 3840]),
    "layers.24.attention.q_norm.weight" : torch.Size([128]),
    "layers.24.attention.qkv.weight" : torch.Size([11520, 3840]),
    "layers.24.attention_norm1.weight" : torch.Size([3840]),
    "layers.24.attention_norm2.weight" : torch.Size([3840]),
    "layers.24.feed_forward.w1.weight" : torch.Size([10240, 3840]),
    "layers.24.feed_forward.w2.weight" : torch.Size([3840, 10240]),
    "layers.24.feed_forward.w3.weight" : torch.Size([10240, 3840]),
    "layers.24.ffn_norm1.weight" : torch.Size([3840]),
    "layers.24.ffn_norm2.weight" : torch.Size([3840]),
    "layers.25.adaLN_modulation.0.bias" : torch.Size([15360]),
    "layers.25.adaLN_modulation.0.weight" : torch.Size([15360, 256]),
    "layers.25.attention.k_norm.weight" : torch.Size([128]),
    "layers.25.attention.out.weight" : torch.Size([3840, 3840]),
    "layers.25.attention.q_norm.weight" : torch.Size([128]),
    "layers.25.attention.qkv.weight" : torch.Size([11520, 3840]),
    "layers.25.attention_norm1.weight" : torch.Size([3840]),
    "layers.25.attention_norm2.weight" : torch.Size([3840]),
    "layers.25.feed_forward.w1.weight" : torch.Size([10240, 3840]),
    "layers.25.feed_forward.w2.weight" : torch.Size([3840, 10240]),
    "layers.25.feed_forward.w3.weight" : torch.Size([10240, 3840]),
    "layers.25.ffn_norm1.weight" : torch.Size([3840]),
    "layers.25.ffn_norm2.weight" : torch.Size([3840]),
    "layers.26.adaLN_modulation.0.bias" : torch.Size([15360]),
    "layers.26.adaLN_modulation.0.weight" : torch.Size([15360, 256]),
    "layers.26.attention.k_norm.weight" : torch.Size([128]),
    "layers.26.attention.out.weight" : torch.Size([3840, 3840]),
    "layers.26.attention.q_norm.weight" : torch.Size([128]),
    "layers.26.attention.qkv.weight" : torch.Size([11520, 3840]),
    "layers.26.attention_norm1.weight" : torch.Size([3840]),
    "layers.26.attention_norm2.weight" : torch.Size([3840]),
    "layers.26.feed_forward.w1.weight" : torch.Size([10240, 3840]),
    "layers.26.feed_forward.w2.weight" : torch.Size([3840, 10240]),
    "layers.26.feed_forward.w3.weight" : torch.Size([10240, 3840]),
    "layers.26.ffn_norm1.weight" : torch.Size([3840]),
    "layers.26.ffn_norm2.weight" : torch.Size([3840]),
    "layers.27.adaLN_modulation.0.bias" : torch.Size([15360]),
    "layers.27.adaLN_modulation.0.weight" : torch.Size([15360, 256]),
    "layers.27.attention.k_norm.weight" : torch.Size([128]),
    "layers.27.attention.out.weight" : torch.Size([3840, 3840]),
    "layers.27.attention.q_norm.weight" : torch.Size([128]),
    "layers.27.attention.qkv.weight" : torch.Size([11520, 3840]),
    "layers.27.attention_norm1.weight" : torch.Size([3840]),
    "layers.27.attention_norm2.weight" : torch.Size([3840]),
    "layers.27.feed_forward.w1.weight" : torch.Size([10240, 3840]),
    "layers.27.feed_forward.w2.weight" : torch.Size([3840, 10240]),
    "layers.27.feed_forward.w3.weight" : torch.Size([10240, 3840]),
    "layers.27.ffn_norm1.weight" : torch.Size([3840]),
    "layers.27.ffn_norm2.weight" : torch.Size([3840]),
    "layers.28.adaLN_modulation.0.bias" : torch.Size([15360]),
    "layers.28.adaLN_modulation.0.weight" : torch.Size([15360, 256]),
    "layers.28.attention.k_norm.weight" : torch.Size([128]),
    "layers.28.attention.out.weight" : torch.Size([3840, 3840]),
    "layers.28.attention.q_norm.weight" : torch.Size([128]),
    "layers.28.attention.qkv.weight" : torch.Size([11520, 3840]),
    "layers.28.attention_norm1.weight" : torch.Size([3840]),
    "layers.28.attention_norm2.weight" : torch.Size([3840]),
    "layers.28.feed_forward.w1.weight" : torch.Size([10240, 3840]),
    "layers.28.feed_forward.w2.weight" : torch.Size([3840, 10240]),
    "layers.28.feed_forward.w3.weight" : torch.Size([10240, 3840]),
    "layers.28.ffn_norm1.weight" : torch.Size([3840]),
    "layers.28.ffn_norm2.weight" : torch.Size([3840]),
    "layers.29.adaLN_modulation.0.bias" : torch.Size([15360]),
    "layers.29.adaLN_modulation.0.weight" : torch.Size([15360, 256]),
    "layers.29.attention.k_norm.weight" : torch.Size([128]),
    "layers.29.attention.out.weight" : torch.Size([3840, 3840]),
    "layers.29.attention.q_norm.weight" : torch.Size([128]),
    "layers.29.attention.qkv.weight" : torch.Size([11520, 3840]),
    "layers.29.attention_norm1.weight" : torch.Size([3840]),
    "layers.29.attention_norm2.weight" : torch.Size([3840]),
    "layers.29.feed_forward.w1.weight" : torch.Size([10240, 3840]),
    "layers.29.feed_forward.w2.weight" : torch.Size([3840, 10240]),
    "layers.29.feed_forward.w3.weight" : torch.Size([10240, 3840]),
    "layers.29.ffn_norm1.weight" : torch.Size([3840]),
    "layers.29.ffn_norm2.weight" : torch.Size([3840]),
    "layers.3.adaLN_modulation.0.bias" : torch.Size([15360]),
    "layers.3.adaLN_modulation.0.weight" : torch.Size([15360, 256]),
    "layers.3.attention.k_norm.weight" : torch.Size([128]),
    "layers.3.attention.out.weight" : torch.Size([3840, 3840]),
    "layers.3.attention.q_norm.weight" : torch.Size([128]),
    "layers.3.attention.qkv.weight" : torch.Size([11520, 3840]),
    "layers.3.attention_norm1.weight" : torch.Size([3840]),
    "layers.3.attention_norm2.weight" : torch.Size([3840]),
    "layers.3.feed_forward.w1.weight" : torch.Size([10240, 3840]),
    "layers.3.feed_forward.w2.weight" : torch.Size([3840, 10240]),
    "layers.3.feed_forward.w3.weight" : torch.Size([10240, 3840]),
    "layers.3.ffn_norm1.weight" : torch.Size([3840]),
    "layers.3.ffn_norm2.weight" : torch.Size([3840]),
    "layers.4.adaLN_modulation.0.bias" : torch.Size([15360]),
    "layers.4.adaLN_modulation.0.weight" : torch.Size([15360, 256]),
    "layers.4.attention.k_norm.weight" : torch.Size([128]),
    "layers.4.attention.out.weight" : torch.Size([3840, 3840]),
    "layers.4.attention.q_norm.weight" : torch.Size([128]),
    "layers.4.attention.qkv.weight" : torch.Size([11520, 3840]),
    "layers.4.attention_norm1.weight" : torch.Size([3840]),
    "layers.4.attention_norm2.weight" : torch.Size([3840]),
    "layers.4.feed_forward.w1.weight" : torch.Size([10240, 3840]),
    "layers.4.feed_forward.w2.weight" : torch.Size([3840, 10240]),
    "layers.4.feed_forward.w3.weight" : torch.Size([10240, 3840]),
    "layers.4.ffn_norm1.weight" : torch.Size([3840]),
    "layers.4.ffn_norm2.weight" : torch.Size([3840]),
    "layers.5.adaLN_modulation.0.bias" : torch.Size([15360]),
    "layers.5.adaLN_modulation.0.weight" : torch.Size([15360, 256]),
    "layers.5.attention.k_norm.weight" : torch.Size([128]),
    "layers.5.attention.out.weight" : torch.Size([3840, 3840]),
    "layers.5.attention.q_norm.weight" : torch.Size([128]),
    "layers.5.attention.qkv.weight" : torch.Size([11520, 3840]),
    "layers.5.attention_norm1.weight" : torch.Size([3840]),
    "layers.5.attention_norm2.weight" : torch.Size([3840]),
    "layers.5.feed_forward.w1.weight" : torch.Size([10240, 3840]),
    "layers.5.feed_forward.w2.weight" : torch.Size([3840, 10240]),
    "layers.5.feed_forward.w3.weight" : torch.Size([10240, 3840]),
    "layers.5.ffn_norm1.weight" : torch.Size([3840]),
    "layers.5.ffn_norm2.weight" : torch.Size([3840]),
    "layers.6.adaLN_modulation.0.bias" : torch.Size([15360]),
    "layers.6.adaLN_modulation.0.weight" : torch.Size([15360, 256]),
    "layers.6.attention.k_norm.weight" : torch.Size([128]),
    "layers.6.attention.out.weight" : torch.Size([3840, 3840]),
    "layers.6.attention.q_norm.weight" : torch.Size([128]),
    "layers.6.attention.qkv.weight" : torch.Size([11520, 3840]),
    "layers.6.attention_norm1.weight" : torch.Size([3840]),
    "layers.6.attention_norm2.weight" : torch.Size([3840]),
    "layers.6.feed_forward.w1.weight" : torch.Size([10240, 3840]),
    "layers.6.feed_forward.w2.weight" : torch.Size([3840, 10240]),
    "layers.6.feed_forward.w3.weight" : torch.Size([10240, 3840]),
    "layers.6.ffn_norm1.weight" : torch.Size([3840]),
    "layers.6.ffn_norm2.weight" : torch.Size([3840]),
    "layers.7.adaLN_modulation.0.bias" : torch.Size([15360]),
    "layers.7.adaLN_modulation.0.weight" : torch.Size([15360, 256]),
    "layers.7.attention.k_norm.weight" : torch.Size([128]),
    "layers.7.attention.out.weight" : torch.Size([3840, 3840]),
    "layers.7.attention.q_norm.weight" : torch.Size([128]),
    "layers.7.attention.qkv.weight" : torch.Size([11520, 3840]),
    "layers.7.attention_norm1.weight" : torch.Size([3840]),
    "layers.7.attention_norm2.weight" : torch.Size([3840]),
    "layers.7.feed_forward.w1.weight" : torch.Size([10240, 3840]),
    "layers.7.feed_forward.w2.weight" : torch.Size([3840, 10240]),
    "layers.7.feed_forward.w3.weight" : torch.Size([10240, 3840]),
    "layers.7.ffn_norm1.weight" : torch.Size([3840]),
    "layers.7.ffn_norm2.weight" : torch.Size([3840]),
    "layers.8.adaLN_modulation.0.bias" : torch.Size([15360]),
    "layers.8.adaLN_modulation.0.weight" : torch.Size([15360, 256]),
    "layers.8.attention.k_norm.weight" : torch.Size([128]),
    "layers.8.attention.out.weight" : torch.Size([3840, 3840]),
    "layers.8.attention.q_norm.weight" : torch.Size([128]),
    "layers.8.attention.qkv.weight" : torch.Size([11520, 3840]),
    "layers.8.attention_norm1.weight" : torch.Size([3840]),
    "layers.8.attention_norm2.weight" : torch.Size([3840]),
    "layers.8.feed_forward.w1.weight" : torch.Size([10240, 3840]),
    "layers.8.feed_forward.w2.weight" : torch.Size([3840, 10240]),
    "layers.8.feed_forward.w3.weight" : torch.Size([10240, 3840]),
    "layers.8.ffn_norm1.weight" : torch.Size([3840]),
    "layers.8.ffn_norm2.weight" : torch.Size([3840]),
    "layers.9.adaLN_modulation.0.bias" : torch.Size([15360]),
    "layers.9.adaLN_modulation.0.weight" : torch.Size([15360, 256]),
    "layers.9.attention.k_norm.weight" : torch.Size([128]),
    "layers.9.attention.out.weight" : torch.Size([3840, 3840]),
    "layers.9.attention.q_norm.weight" : torch.Size([128]),
    "layers.9.attention.qkv.weight" : torch.Size([11520, 3840]),
    "layers.9.attention_norm1.weight" : torch.Size([3840]),
    "layers.9.attention_norm2.weight" : torch.Size([3840]),
    "layers.9.feed_forward.w1.weight" : torch.Size([10240, 3840]),
    "layers.9.feed_forward.w2.weight" : torch.Size([3840, 10240]),
    "layers.9.feed_forward.w3.weight" : torch.Size([10240, 3840]),
    "layers.9.ffn_norm1.weight" : torch.Size([3840]),
    "layers.9.ffn_norm2.weight" : torch.Size([3840]),
    "noise_refiner.0.adaLN_modulation.0.bias" : torch.Size([15360]),
    "noise_refiner.0.adaLN_modulation.0.weight" : torch.Size([15360, 256]),
    "noise_refiner.0.attention.k_norm.weight" : torch.Size([128]),
    "noise_refiner.0.attention.out.weight" : torch.Size([3840, 3840]),
    "noise_refiner.0.attention.q_norm.weight" : torch.Size([128]),
    "noise_refiner.0.attention.qkv.weight" : torch.Size([11520, 3840]),
    "noise_refiner.0.attention_norm1.weight" : torch.Size([3840]),
    "noise_refiner.0.attention_norm2.weight" : torch.Size([3840]),
    "noise_refiner.0.feed_forward.w1.weight" : torch.Size([10240, 3840]),
    "noise_refiner.0.feed_forward.w2.weight" : torch.Size([3840, 10240]),
    "noise_refiner.0.feed_forward.w3.weight" : torch.Size([10240, 3840]),
    "noise_refiner.0.ffn_norm1.weight" : torch.Size([3840]),
    "noise_refiner.0.ffn_norm2.weight" : torch.Size([3840]),
    "noise_refiner.1.adaLN_modulation.0.bias" : torch.Size([15360]),
    "noise_refiner.1.adaLN_modulation.0.weight" : torch.Size([15360, 256]),
    "noise_refiner.1.attention.k_norm.weight" : torch.Size([128]),
    "noise_refiner.1.attention.out.weight" : torch.Size([3840, 3840]),
    "noise_refiner.1.attention.q_norm.weight" : torch.Size([128]),
    "noise_refiner.1.attention.qkv.weight" : torch.Size([11520, 3840]),
    "noise_refiner.1.attention_norm1.weight" : torch.Size([3840]),
    "noise_refiner.1.attention_norm2.weight" : torch.Size([3840]),
    "noise_refiner.1.feed_forward.w1.weight" : torch.Size([10240, 3840]),
    "noise_refiner.1.feed_forward.w2.weight" : torch.Size([3840, 10240]),
    "noise_refiner.1.feed_forward.w3.weight" : torch.Size([10240, 3840]),
    "noise_refiner.1.ffn_norm1.weight" : torch.Size([3840]),
    "noise_refiner.1.ffn_norm2.weight" : torch.Size([3840]),
    "t_embedder.mlp.0.bias" : torch.Size([1024]),
    "t_embedder.mlp.0.weight" : torch.Size([1024, 256]),
    "t_embedder.mlp.2.bias" : torch.Size([256]),
    "t_embedder.mlp.2.weight" : torch.Size([256, 1024]),
    "x_embedder.bias" : torch.Size([3840]),
    "x_embedder.weight" : torch.Size([3840, 64]),
}

anima_keys_dict = {
    "t_embedder.1.linear_1.weight" : torch.Size([2048, 2048]),
    "t_embedder.1.linear_2.weight" : torch.Size([6144, 2048]),
    "x_embedder.proj.1.weight" : torch.Size([2048, 68]),
    "blocks.0.self_attn.q_proj.weight" : torch.Size([2048, 2048]),
    "blocks.0.self_attn.q_norm.weight" : torch.Size([128]),
    "blocks.0.self_attn.k_proj.weight" : torch.Size([2048, 2048]),
    "blocks.0.self_attn.k_norm.weight" : torch.Size([128]),
    "blocks.0.self_attn.v_proj.weight" : torch.Size([2048, 2048]),
    "blocks.0.self_attn.output_proj.weight" : torch.Size([2048, 2048]),
    "blocks.0.cross_attn.q_proj.weight" : torch.Size([2048, 2048]),
    "blocks.0.cross_attn.q_norm.weight" : torch.Size([128]),
    "blocks.0.cross_attn.k_proj.weight" : torch.Size([2048, 1024]),
    "blocks.0.cross_attn.k_norm.weight" : torch.Size([128]),
    "blocks.0.cross_attn.v_proj.weight" : torch.Size([2048, 1024]),
    "blocks.0.cross_attn.output_proj.weight" : torch.Size([2048, 2048]),
    "blocks.0.mlp.layer1.weight" : torch.Size([8192, 2048]),
    "blocks.0.mlp.layer2.weight" : torch.Size([2048, 8192]),
    "blocks.0.adaln_modulation_self_attn.1.weight" : torch.Size([256, 2048]),
    "blocks.0.adaln_modulation_self_attn.2.weight" : torch.Size([6144, 256]),
    "blocks.0.adaln_modulation_cross_attn.1.weight" : torch.Size([256, 2048]),
    "blocks.0.adaln_modulation_cross_attn.2.weight" : torch.Size([6144, 256]),
    "blocks.0.adaln_modulation_mlp.1.weight" : torch.Size([256, 2048]),
    "blocks.0.adaln_modulation_mlp.2.weight" : torch.Size([6144, 256]),
    "blocks.1.self_attn.q_proj.weight" : torch.Size([2048, 2048]),
    "blocks.1.self_attn.q_norm.weight" : torch.Size([128]),
    "blocks.1.self_attn.k_proj.weight" : torch.Size([2048, 2048]),
    "blocks.1.self_attn.k_norm.weight" : torch.Size([128]),
    "blocks.1.self_attn.v_proj.weight" : torch.Size([2048, 2048]),
    "blocks.1.self_attn.output_proj.weight" : torch.Size([2048, 2048]),
    "blocks.1.cross_attn.q_proj.weight" : torch.Size([2048, 2048]),
    "blocks.1.cross_attn.q_norm.weight" : torch.Size([128]),
    "blocks.1.cross_attn.k_proj.weight" : torch.Size([2048, 1024]),
    "blocks.1.cross_attn.k_norm.weight" : torch.Size([128]),
    "blocks.1.cross_attn.v_proj.weight" : torch.Size([2048, 1024]),
    "blocks.1.cross_attn.output_proj.weight" : torch.Size([2048, 2048]),
    "blocks.1.mlp.layer1.weight" : torch.Size([8192, 2048]),
    "blocks.1.mlp.layer2.weight" : torch.Size([2048, 8192]),
    "blocks.1.adaln_modulation_self_attn.1.weight" : torch.Size([256, 2048]),
    "blocks.1.adaln_modulation_self_attn.2.weight" : torch.Size([6144, 256]),
    "blocks.1.adaln_modulation_cross_attn.1.weight" : torch.Size([256, 2048]),
    "blocks.1.adaln_modulation_cross_attn.2.weight" : torch.Size([6144, 256]),
    "blocks.1.adaln_modulation_mlp.1.weight" : torch.Size([256, 2048]),
    "blocks.1.adaln_modulation_mlp.2.weight" : torch.Size([6144, 256]),
    "blocks.2.self_attn.q_proj.weight" : torch.Size([2048, 2048]),
    "blocks.2.self_attn.q_norm.weight" : torch.Size([128]),
    "blocks.2.self_attn.k_proj.weight" : torch.Size([2048, 2048]),
    "blocks.2.self_attn.k_norm.weight" : torch.Size([128]),
    "blocks.2.self_attn.v_proj.weight" : torch.Size([2048, 2048]),
    "blocks.2.self_attn.output_proj.weight" : torch.Size([2048, 2048]),
    "blocks.2.cross_attn.q_proj.weight" : torch.Size([2048, 2048]),
    "blocks.2.cross_attn.q_norm.weight" : torch.Size([128]),
    "blocks.2.cross_attn.k_proj.weight" : torch.Size([2048, 1024]),
    "blocks.2.cross_attn.k_norm.weight" : torch.Size([128]),
    "blocks.2.cross_attn.v_proj.weight" : torch.Size([2048, 1024]),
    "blocks.2.cross_attn.output_proj.weight" : torch.Size([2048, 2048]),
    "blocks.2.mlp.layer1.weight" : torch.Size([8192, 2048]),
    "blocks.2.mlp.layer2.weight" : torch.Size([2048, 8192]),
    "blocks.2.adaln_modulation_self_attn.1.weight" : torch.Size([256, 2048]),
    "blocks.2.adaln_modulation_self_attn.2.weight" : torch.Size([6144, 256]),
    "blocks.2.adaln_modulation_cross_attn.1.weight" : torch.Size([256, 2048]),
    "blocks.2.adaln_modulation_cross_attn.2.weight" : torch.Size([6144, 256]),
    "blocks.2.adaln_modulation_mlp.1.weight" : torch.Size([256, 2048]),
    "blocks.2.adaln_modulation_mlp.2.weight" : torch.Size([6144, 256]),
    "blocks.3.self_attn.q_proj.weight" : torch.Size([2048, 2048]),
    "blocks.3.self_attn.q_norm.weight" : torch.Size([128]),
    "blocks.3.self_attn.k_proj.weight" : torch.Size([2048, 2048]),
    "blocks.3.self_attn.k_norm.weight" : torch.Size([128]),
    "blocks.3.self_attn.v_proj.weight" : torch.Size([2048, 2048]),
    "blocks.3.self_attn.output_proj.weight" : torch.Size([2048, 2048]),
    "blocks.3.cross_attn.q_proj.weight" : torch.Size([2048, 2048]),
    "blocks.3.cross_attn.q_norm.weight" : torch.Size([128]),
    "blocks.3.cross_attn.k_proj.weight" : torch.Size([2048, 1024]),
    "blocks.3.cross_attn.k_norm.weight" : torch.Size([128]),
    "blocks.3.cross_attn.v_proj.weight" : torch.Size([2048, 1024]),
    "blocks.3.cross_attn.output_proj.weight" : torch.Size([2048, 2048]),
    "blocks.3.mlp.layer1.weight" : torch.Size([8192, 2048]),
    "blocks.3.mlp.layer2.weight" : torch.Size([2048, 8192]),
    "blocks.3.adaln_modulation_self_attn.1.weight" : torch.Size([256, 2048]),
    "blocks.3.adaln_modulation_self_attn.2.weight" : torch.Size([6144, 256]),
    "blocks.3.adaln_modulation_cross_attn.1.weight" : torch.Size([256, 2048]),
    "blocks.3.adaln_modulation_cross_attn.2.weight" : torch.Size([6144, 256]),
    "blocks.3.adaln_modulation_mlp.1.weight" : torch.Size([256, 2048]),
    "blocks.3.adaln_modulation_mlp.2.weight" : torch.Size([6144, 256]),
    "blocks.4.self_attn.q_proj.weight" : torch.Size([2048, 2048]),
    "blocks.4.self_attn.q_norm.weight" : torch.Size([128]),
    "blocks.4.self_attn.k_proj.weight" : torch.Size([2048, 2048]),
    "blocks.4.self_attn.k_norm.weight" : torch.Size([128]),
    "blocks.4.self_attn.v_proj.weight" : torch.Size([2048, 2048]),
    "blocks.4.self_attn.output_proj.weight" : torch.Size([2048, 2048]),
    "blocks.4.cross_attn.q_proj.weight" : torch.Size([2048, 2048]),
    "blocks.4.cross_attn.q_norm.weight" : torch.Size([128]),
    "blocks.4.cross_attn.k_proj.weight" : torch.Size([2048, 1024]),
    "blocks.4.cross_attn.k_norm.weight" : torch.Size([128]),
    "blocks.4.cross_attn.v_proj.weight" : torch.Size([2048, 1024]),
    "blocks.4.cross_attn.output_proj.weight" : torch.Size([2048, 2048]),
    "blocks.4.mlp.layer1.weight" : torch.Size([8192, 2048]),
    "blocks.4.mlp.layer2.weight" : torch.Size([2048, 8192]),
    "blocks.4.adaln_modulation_self_attn.1.weight" : torch.Size([256, 2048]),
    "blocks.4.adaln_modulation_self_attn.2.weight" : torch.Size([6144, 256]),
    "blocks.4.adaln_modulation_cross_attn.1.weight" : torch.Size([256, 2048]),
    "blocks.4.adaln_modulation_cross_attn.2.weight" : torch.Size([6144, 256]),
    "blocks.4.adaln_modulation_mlp.1.weight" : torch.Size([256, 2048]),
    "blocks.4.adaln_modulation_mlp.2.weight" : torch.Size([6144, 256]),
    "blocks.5.self_attn.q_proj.weight" : torch.Size([2048, 2048]),
    "blocks.5.self_attn.q_norm.weight" : torch.Size([128]),
    "blocks.5.self_attn.k_proj.weight" : torch.Size([2048, 2048]),
    "blocks.5.self_attn.k_norm.weight" : torch.Size([128]),
    "blocks.5.self_attn.v_proj.weight" : torch.Size([2048, 2048]),
    "blocks.5.self_attn.output_proj.weight" : torch.Size([2048, 2048]),
    "blocks.5.cross_attn.q_proj.weight" : torch.Size([2048, 2048]),
    "blocks.5.cross_attn.q_norm.weight" : torch.Size([128]),
    "blocks.5.cross_attn.k_proj.weight" : torch.Size([2048, 1024]),
    "blocks.5.cross_attn.k_norm.weight" : torch.Size([128]),
    "blocks.5.cross_attn.v_proj.weight" : torch.Size([2048, 1024]),
    "blocks.5.cross_attn.output_proj.weight" : torch.Size([2048, 2048]),
    "blocks.5.mlp.layer1.weight" : torch.Size([8192, 2048]),
    "blocks.5.mlp.layer2.weight" : torch.Size([2048, 8192]),
    "blocks.5.adaln_modulation_self_attn.1.weight" : torch.Size([256, 2048]),
    "blocks.5.adaln_modulation_self_attn.2.weight" : torch.Size([6144, 256]),
    "blocks.5.adaln_modulation_cross_attn.1.weight" : torch.Size([256, 2048]),
    "blocks.5.adaln_modulation_cross_attn.2.weight" : torch.Size([6144, 256]),
    "blocks.5.adaln_modulation_mlp.1.weight" : torch.Size([256, 2048]),
    "blocks.5.adaln_modulation_mlp.2.weight" : torch.Size([6144, 256]),
    "blocks.6.self_attn.q_proj.weight" : torch.Size([2048, 2048]),
    "blocks.6.self_attn.q_norm.weight" : torch.Size([128]),
    "blocks.6.self_attn.k_proj.weight" : torch.Size([2048, 2048]),
    "blocks.6.self_attn.k_norm.weight" : torch.Size([128]),
    "blocks.6.self_attn.v_proj.weight" : torch.Size([2048, 2048]),
    "blocks.6.self_attn.output_proj.weight" : torch.Size([2048, 2048]),
    "blocks.6.cross_attn.q_proj.weight" : torch.Size([2048, 2048]),
    "blocks.6.cross_attn.q_norm.weight" : torch.Size([128]),
    "blocks.6.cross_attn.k_proj.weight" : torch.Size([2048, 1024]),
    "blocks.6.cross_attn.k_norm.weight" : torch.Size([128]),
    "blocks.6.cross_attn.v_proj.weight" : torch.Size([2048, 1024]),
    "blocks.6.cross_attn.output_proj.weight" : torch.Size([2048, 2048]),
    "blocks.6.mlp.layer1.weight" : torch.Size([8192, 2048]),
    "blocks.6.mlp.layer2.weight" : torch.Size([2048, 8192]),
    "blocks.6.adaln_modulation_self_attn.1.weight" : torch.Size([256, 2048]),
    "blocks.6.adaln_modulation_self_attn.2.weight" : torch.Size([6144, 256]),
    "blocks.6.adaln_modulation_cross_attn.1.weight" : torch.Size([256, 2048]),
    "blocks.6.adaln_modulation_cross_attn.2.weight" : torch.Size([6144, 256]),
    "blocks.6.adaln_modulation_mlp.1.weight" : torch.Size([256, 2048]),
    "blocks.6.adaln_modulation_mlp.2.weight" : torch.Size([6144, 256]),
    "blocks.7.self_attn.q_proj.weight" : torch.Size([2048, 2048]),
    "blocks.7.self_attn.q_norm.weight" : torch.Size([128]),
    "blocks.7.self_attn.k_proj.weight" : torch.Size([2048, 2048]),
    "blocks.7.self_attn.k_norm.weight" : torch.Size([128]),
    "blocks.7.self_attn.v_proj.weight" : torch.Size([2048, 2048]),
    "blocks.7.self_attn.output_proj.weight" : torch.Size([2048, 2048]),
    "blocks.7.cross_attn.q_proj.weight" : torch.Size([2048, 2048]),
    "blocks.7.cross_attn.q_norm.weight" : torch.Size([128]),
    "blocks.7.cross_attn.k_proj.weight" : torch.Size([2048, 1024]),
    "blocks.7.cross_attn.k_norm.weight" : torch.Size([128]),
    "blocks.7.cross_attn.v_proj.weight" : torch.Size([2048, 1024]),
    "blocks.7.cross_attn.output_proj.weight" : torch.Size([2048, 2048]),
    "blocks.7.mlp.layer1.weight" : torch.Size([8192, 2048]),
    "blocks.7.mlp.layer2.weight" : torch.Size([2048, 8192]),
    "blocks.7.adaln_modulation_self_attn.1.weight" : torch.Size([256, 2048]),
    "blocks.7.adaln_modulation_self_attn.2.weight" : torch.Size([6144, 256]),
    "blocks.7.adaln_modulation_cross_attn.1.weight" : torch.Size([256, 2048]),
    "blocks.7.adaln_modulation_cross_attn.2.weight" : torch.Size([6144, 256]),
    "blocks.7.adaln_modulation_mlp.1.weight" : torch.Size([256, 2048]),
    "blocks.7.adaln_modulation_mlp.2.weight" : torch.Size([6144, 256]),
    "blocks.8.self_attn.q_proj.weight" : torch.Size([2048, 2048]),
    "blocks.8.self_attn.q_norm.weight" : torch.Size([128]),
    "blocks.8.self_attn.k_proj.weight" : torch.Size([2048, 2048]),
    "blocks.8.self_attn.k_norm.weight" : torch.Size([128]),
    "blocks.8.self_attn.v_proj.weight" : torch.Size([2048, 2048]),
    "blocks.8.self_attn.output_proj.weight" : torch.Size([2048, 2048]),
    "blocks.8.cross_attn.q_proj.weight" : torch.Size([2048, 2048]),
    "blocks.8.cross_attn.q_norm.weight" : torch.Size([128]),
    "blocks.8.cross_attn.k_proj.weight" : torch.Size([2048, 1024]),
    "blocks.8.cross_attn.k_norm.weight" : torch.Size([128]),
    "blocks.8.cross_attn.v_proj.weight" : torch.Size([2048, 1024]),
    "blocks.8.cross_attn.output_proj.weight" : torch.Size([2048, 2048]),
    "blocks.8.mlp.layer1.weight" : torch.Size([8192, 2048]),
    "blocks.8.mlp.layer2.weight" : torch.Size([2048, 8192]),
    "blocks.8.adaln_modulation_self_attn.1.weight" : torch.Size([256, 2048]),
    "blocks.8.adaln_modulation_self_attn.2.weight" : torch.Size([6144, 256]),
    "blocks.8.adaln_modulation_cross_attn.1.weight" : torch.Size([256, 2048]),
    "blocks.8.adaln_modulation_cross_attn.2.weight" : torch.Size([6144, 256]),
    "blocks.8.adaln_modulation_mlp.1.weight" : torch.Size([256, 2048]),
    "blocks.8.adaln_modulation_mlp.2.weight" : torch.Size([6144, 256]),
    "blocks.9.self_attn.q_proj.weight" : torch.Size([2048, 2048]),
    "blocks.9.self_attn.q_norm.weight" : torch.Size([128]),
    "blocks.9.self_attn.k_proj.weight" : torch.Size([2048, 2048]),
    "blocks.9.self_attn.k_norm.weight" : torch.Size([128]),
    "blocks.9.self_attn.v_proj.weight" : torch.Size([2048, 2048]),
    "blocks.9.self_attn.output_proj.weight" : torch.Size([2048, 2048]),
    "blocks.9.cross_attn.q_proj.weight" : torch.Size([2048, 2048]),
    "blocks.9.cross_attn.q_norm.weight" : torch.Size([128]),
    "blocks.9.cross_attn.k_proj.weight" : torch.Size([2048, 1024]),
    "blocks.9.cross_attn.k_norm.weight" : torch.Size([128]),
    "blocks.9.cross_attn.v_proj.weight" : torch.Size([2048, 1024]),
    "blocks.9.cross_attn.output_proj.weight" : torch.Size([2048, 2048]),
    "blocks.9.mlp.layer1.weight" : torch.Size([8192, 2048]),
    "blocks.9.mlp.layer2.weight" : torch.Size([2048, 8192]),
    "blocks.9.adaln_modulation_self_attn.1.weight" : torch.Size([256, 2048]),
    "blocks.9.adaln_modulation_self_attn.2.weight" : torch.Size([6144, 256]),
    "blocks.9.adaln_modulation_cross_attn.1.weight" : torch.Size([256, 2048]),
    "blocks.9.adaln_modulation_cross_attn.2.weight" : torch.Size([6144, 256]),
    "blocks.9.adaln_modulation_mlp.1.weight" : torch.Size([256, 2048]),
    "blocks.9.adaln_modulation_mlp.2.weight" : torch.Size([6144, 256]),
    "blocks.10.self_attn.q_proj.weight" : torch.Size([2048, 2048]),
    "blocks.10.self_attn.q_norm.weight" : torch.Size([128]),
    "blocks.10.self_attn.k_proj.weight" : torch.Size([2048, 2048]),
    "blocks.10.self_attn.k_norm.weight" : torch.Size([128]),
    "blocks.10.self_attn.v_proj.weight" : torch.Size([2048, 2048]),
    "blocks.10.self_attn.output_proj.weight" : torch.Size([2048, 2048]),
    "blocks.10.cross_attn.q_proj.weight" : torch.Size([2048, 2048]),
    "blocks.10.cross_attn.q_norm.weight" : torch.Size([128]),
    "blocks.10.cross_attn.k_proj.weight" : torch.Size([2048, 1024]),
    "blocks.10.cross_attn.k_norm.weight" : torch.Size([128]),
    "blocks.10.cross_attn.v_proj.weight" : torch.Size([2048, 1024]),
    "blocks.10.cross_attn.output_proj.weight" : torch.Size([2048, 2048]),
    "blocks.10.mlp.layer1.weight" : torch.Size([8192, 2048]),
    "blocks.10.mlp.layer2.weight" : torch.Size([2048, 8192]),
    "blocks.10.adaln_modulation_self_attn.1.weight" : torch.Size([256, 2048]),
    "blocks.10.adaln_modulation_self_attn.2.weight" : torch.Size([6144, 256]),
    "blocks.10.adaln_modulation_cross_attn.1.weight" : torch.Size([256, 2048]),
    "blocks.10.adaln_modulation_cross_attn.2.weight" : torch.Size([6144, 256]),
    "blocks.10.adaln_modulation_mlp.1.weight" : torch.Size([256, 2048]),
    "blocks.10.adaln_modulation_mlp.2.weight" : torch.Size([6144, 256]),
    "blocks.11.self_attn.q_proj.weight" : torch.Size([2048, 2048]),
    "blocks.11.self_attn.q_norm.weight" : torch.Size([128]),
    "blocks.11.self_attn.k_proj.weight" : torch.Size([2048, 2048]),
    "blocks.11.self_attn.k_norm.weight" : torch.Size([128]),
    "blocks.11.self_attn.v_proj.weight" : torch.Size([2048, 2048]),
    "blocks.11.self_attn.output_proj.weight" : torch.Size([2048, 2048]),
    "blocks.11.cross_attn.q_proj.weight" : torch.Size([2048, 2048]),
    "blocks.11.cross_attn.q_norm.weight" : torch.Size([128]),
    "blocks.11.cross_attn.k_proj.weight" : torch.Size([2048, 1024]),
    "blocks.11.cross_attn.k_norm.weight" : torch.Size([128]),
    "blocks.11.cross_attn.v_proj.weight" : torch.Size([2048, 1024]),
    "blocks.11.cross_attn.output_proj.weight" : torch.Size([2048, 2048]),
    "blocks.11.mlp.layer1.weight" : torch.Size([8192, 2048]),
    "blocks.11.mlp.layer2.weight" : torch.Size([2048, 8192]),
    "blocks.11.adaln_modulation_self_attn.1.weight" : torch.Size([256, 2048]),
    "blocks.11.adaln_modulation_self_attn.2.weight" : torch.Size([6144, 256]),
    "blocks.11.adaln_modulation_cross_attn.1.weight" : torch.Size([256, 2048]),
    "blocks.11.adaln_modulation_cross_attn.2.weight" : torch.Size([6144, 256]),
    "blocks.11.adaln_modulation_mlp.1.weight" : torch.Size([256, 2048]),
    "blocks.11.adaln_modulation_mlp.2.weight" : torch.Size([6144, 256]),
    "blocks.12.self_attn.q_proj.weight" : torch.Size([2048, 2048]),
    "blocks.12.self_attn.q_norm.weight" : torch.Size([128]),
    "blocks.12.self_attn.k_proj.weight" : torch.Size([2048, 2048]),
    "blocks.12.self_attn.k_norm.weight" : torch.Size([128]),
    "blocks.12.self_attn.v_proj.weight" : torch.Size([2048, 2048]),
    "blocks.12.self_attn.output_proj.weight" : torch.Size([2048, 2048]),
    "blocks.12.cross_attn.q_proj.weight" : torch.Size([2048, 2048]),
    "blocks.12.cross_attn.q_norm.weight" : torch.Size([128]),
    "blocks.12.cross_attn.k_proj.weight" : torch.Size([2048, 1024]),
    "blocks.12.cross_attn.k_norm.weight" : torch.Size([128]),
    "blocks.12.cross_attn.v_proj.weight" : torch.Size([2048, 1024]),
    "blocks.12.cross_attn.output_proj.weight" : torch.Size([2048, 2048]),
    "blocks.12.mlp.layer1.weight" : torch.Size([8192, 2048]),
    "blocks.12.mlp.layer2.weight" : torch.Size([2048, 8192]),
    "blocks.12.adaln_modulation_self_attn.1.weight" : torch.Size([256, 2048]),
    "blocks.12.adaln_modulation_self_attn.2.weight" : torch.Size([6144, 256]),
    "blocks.12.adaln_modulation_cross_attn.1.weight" : torch.Size([256, 2048]),
    "blocks.12.adaln_modulation_cross_attn.2.weight" : torch.Size([6144, 256]),
    "blocks.12.adaln_modulation_mlp.1.weight" : torch.Size([256, 2048]),
    "blocks.12.adaln_modulation_mlp.2.weight" : torch.Size([6144, 256]),
    "blocks.13.self_attn.q_proj.weight" : torch.Size([2048, 2048]),
    "blocks.13.self_attn.q_norm.weight" : torch.Size([128]),
    "blocks.13.self_attn.k_proj.weight" : torch.Size([2048, 2048]),
    "blocks.13.self_attn.k_norm.weight" : torch.Size([128]),
    "blocks.13.self_attn.v_proj.weight" : torch.Size([2048, 2048]),
    "blocks.13.self_attn.output_proj.weight" : torch.Size([2048, 2048]),
    "blocks.13.cross_attn.q_proj.weight" : torch.Size([2048, 2048]),
    "blocks.13.cross_attn.q_norm.weight" : torch.Size([128]),
    "blocks.13.cross_attn.k_proj.weight" : torch.Size([2048, 1024]),
    "blocks.13.cross_attn.k_norm.weight" : torch.Size([128]),
    "blocks.13.cross_attn.v_proj.weight" : torch.Size([2048, 1024]),
    "blocks.13.cross_attn.output_proj.weight" : torch.Size([2048, 2048]),
    "blocks.13.mlp.layer1.weight" : torch.Size([8192, 2048]),
    "blocks.13.mlp.layer2.weight" : torch.Size([2048, 8192]),
    "blocks.13.adaln_modulation_self_attn.1.weight" : torch.Size([256, 2048]),
    "blocks.13.adaln_modulation_self_attn.2.weight" : torch.Size([6144, 256]),
    "blocks.13.adaln_modulation_cross_attn.1.weight" : torch.Size([256, 2048]),
    "blocks.13.adaln_modulation_cross_attn.2.weight" : torch.Size([6144, 256]),
    "blocks.13.adaln_modulation_mlp.1.weight" : torch.Size([256, 2048]),
    "blocks.13.adaln_modulation_mlp.2.weight" : torch.Size([6144, 256]),
    "blocks.14.self_attn.q_proj.weight" : torch.Size([2048, 2048]),
    "blocks.14.self_attn.q_norm.weight" : torch.Size([128]),
    "blocks.14.self_attn.k_proj.weight" : torch.Size([2048, 2048]),
    "blocks.14.self_attn.k_norm.weight" : torch.Size([128]),
    "blocks.14.self_attn.v_proj.weight" : torch.Size([2048, 2048]),
    "blocks.14.self_attn.output_proj.weight" : torch.Size([2048, 2048]),
    "blocks.14.cross_attn.q_proj.weight" : torch.Size([2048, 2048]),
    "blocks.14.cross_attn.q_norm.weight" : torch.Size([128]),
    "blocks.14.cross_attn.k_proj.weight" : torch.Size([2048, 1024]),
    "blocks.14.cross_attn.k_norm.weight" : torch.Size([128]),
    "blocks.14.cross_attn.v_proj.weight" : torch.Size([2048, 1024]),
    "blocks.14.cross_attn.output_proj.weight" : torch.Size([2048, 2048]),
    "blocks.14.mlp.layer1.weight" : torch.Size([8192, 2048]),
    "blocks.14.mlp.layer2.weight" : torch.Size([2048, 8192]),
    "blocks.14.adaln_modulation_self_attn.1.weight" : torch.Size([256, 2048]),
    "blocks.14.adaln_modulation_self_attn.2.weight" : torch.Size([6144, 256]),
    "blocks.14.adaln_modulation_cross_attn.1.weight" : torch.Size([256, 2048]),
    "blocks.14.adaln_modulation_cross_attn.2.weight" : torch.Size([6144, 256]),
    "blocks.14.adaln_modulation_mlp.1.weight" : torch.Size([256, 2048]),
    "blocks.14.adaln_modulation_mlp.2.weight" : torch.Size([6144, 256]),
    "blocks.15.self_attn.q_proj.weight" : torch.Size([2048, 2048]),
    "blocks.15.self_attn.q_norm.weight" : torch.Size([128]),
    "blocks.15.self_attn.k_proj.weight" : torch.Size([2048, 2048]),
    "blocks.15.self_attn.k_norm.weight" : torch.Size([128]),
    "blocks.15.self_attn.v_proj.weight" : torch.Size([2048, 2048]),
    "blocks.15.self_attn.output_proj.weight" : torch.Size([2048, 2048]),
    "blocks.15.cross_attn.q_proj.weight" : torch.Size([2048, 2048]),
    "blocks.15.cross_attn.q_norm.weight" : torch.Size([128]),
    "blocks.15.cross_attn.k_proj.weight" : torch.Size([2048, 1024]),
    "blocks.15.cross_attn.k_norm.weight" : torch.Size([128]),
    "blocks.15.cross_attn.v_proj.weight" : torch.Size([2048, 1024]),
    "blocks.15.cross_attn.output_proj.weight" : torch.Size([2048, 2048]),
    "blocks.15.mlp.layer1.weight" : torch.Size([8192, 2048]),
    "blocks.15.mlp.layer2.weight" : torch.Size([2048, 8192]),
    "blocks.15.adaln_modulation_self_attn.1.weight" : torch.Size([256, 2048]),
    "blocks.15.adaln_modulation_self_attn.2.weight" : torch.Size([6144, 256]),
    "blocks.15.adaln_modulation_cross_attn.1.weight" : torch.Size([256, 2048]),
    "blocks.15.adaln_modulation_cross_attn.2.weight" : torch.Size([6144, 256]),
    "blocks.15.adaln_modulation_mlp.1.weight" : torch.Size([256, 2048]),
    "blocks.15.adaln_modulation_mlp.2.weight" : torch.Size([6144, 256]),
    "blocks.16.self_attn.q_proj.weight" : torch.Size([2048, 2048]),
    "blocks.16.self_attn.q_norm.weight" : torch.Size([128]),
    "blocks.16.self_attn.k_proj.weight" : torch.Size([2048, 2048]),
    "blocks.16.self_attn.k_norm.weight" : torch.Size([128]),
    "blocks.16.self_attn.v_proj.weight" : torch.Size([2048, 2048]),
    "blocks.16.self_attn.output_proj.weight" : torch.Size([2048, 2048]),
    "blocks.16.cross_attn.q_proj.weight" : torch.Size([2048, 2048]),
    "blocks.16.cross_attn.q_norm.weight" : torch.Size([128]),
    "blocks.16.cross_attn.k_proj.weight" : torch.Size([2048, 1024]),
    "blocks.16.cross_attn.k_norm.weight" : torch.Size([128]),
    "blocks.16.cross_attn.v_proj.weight" : torch.Size([2048, 1024]),
    "blocks.16.cross_attn.output_proj.weight" : torch.Size([2048, 2048]),
    "blocks.16.mlp.layer1.weight" : torch.Size([8192, 2048]),
    "blocks.16.mlp.layer2.weight" : torch.Size([2048, 8192]),
    "blocks.16.adaln_modulation_self_attn.1.weight" : torch.Size([256, 2048]),
    "blocks.16.adaln_modulation_self_attn.2.weight" : torch.Size([6144, 256]),
    "blocks.16.adaln_modulation_cross_attn.1.weight" : torch.Size([256, 2048]),
    "blocks.16.adaln_modulation_cross_attn.2.weight" : torch.Size([6144, 256]),
    "blocks.16.adaln_modulation_mlp.1.weight" : torch.Size([256, 2048]),
    "blocks.16.adaln_modulation_mlp.2.weight" : torch.Size([6144, 256]),
    "blocks.17.self_attn.q_proj.weight" : torch.Size([2048, 2048]),
    "blocks.17.self_attn.q_norm.weight" : torch.Size([128]),
    "blocks.17.self_attn.k_proj.weight" : torch.Size([2048, 2048]),
    "blocks.17.self_attn.k_norm.weight" : torch.Size([128]),
    "blocks.17.self_attn.v_proj.weight" : torch.Size([2048, 2048]),
    "blocks.17.self_attn.output_proj.weight" : torch.Size([2048, 2048]),
    "blocks.17.cross_attn.q_proj.weight" : torch.Size([2048, 2048]),
    "blocks.17.cross_attn.q_norm.weight" : torch.Size([128]),
    "blocks.17.cross_attn.k_proj.weight" : torch.Size([2048, 1024]),
    "blocks.17.cross_attn.k_norm.weight" : torch.Size([128]),
    "blocks.17.cross_attn.v_proj.weight" : torch.Size([2048, 1024]),
    "blocks.17.cross_attn.output_proj.weight" : torch.Size([2048, 2048]),
    "blocks.17.mlp.layer1.weight" : torch.Size([8192, 2048]),
    "blocks.17.mlp.layer2.weight" : torch.Size([2048, 8192]),
    "blocks.17.adaln_modulation_self_attn.1.weight" : torch.Size([256, 2048]),
    "blocks.17.adaln_modulation_self_attn.2.weight" : torch.Size([6144, 256]),
    "blocks.17.adaln_modulation_cross_attn.1.weight" : torch.Size([256, 2048]),
    "blocks.17.adaln_modulation_cross_attn.2.weight" : torch.Size([6144, 256]),
    "blocks.17.adaln_modulation_mlp.1.weight" : torch.Size([256, 2048]),
    "blocks.17.adaln_modulation_mlp.2.weight" : torch.Size([6144, 256]),
    "blocks.18.self_attn.q_proj.weight" : torch.Size([2048, 2048]),
    "blocks.18.self_attn.q_norm.weight" : torch.Size([128]),
    "blocks.18.self_attn.k_proj.weight" : torch.Size([2048, 2048]),
    "blocks.18.self_attn.k_norm.weight" : torch.Size([128]),
    "blocks.18.self_attn.v_proj.weight" : torch.Size([2048, 2048]),
    "blocks.18.self_attn.output_proj.weight" : torch.Size([2048, 2048]),
    "blocks.18.cross_attn.q_proj.weight" : torch.Size([2048, 2048]),
    "blocks.18.cross_attn.q_norm.weight" : torch.Size([128]),
    "blocks.18.cross_attn.k_proj.weight" : torch.Size([2048, 1024]),
    "blocks.18.cross_attn.k_norm.weight" : torch.Size([128]),
    "blocks.18.cross_attn.v_proj.weight" : torch.Size([2048, 1024]),
    "blocks.18.cross_attn.output_proj.weight" : torch.Size([2048, 2048]),
    "blocks.18.mlp.layer1.weight" : torch.Size([8192, 2048]),
    "blocks.18.mlp.layer2.weight" : torch.Size([2048, 8192]),
    "blocks.18.adaln_modulation_self_attn.1.weight" : torch.Size([256, 2048]),
    "blocks.18.adaln_modulation_self_attn.2.weight" : torch.Size([6144, 256]),
    "blocks.18.adaln_modulation_cross_attn.1.weight" : torch.Size([256, 2048]),
    "blocks.18.adaln_modulation_cross_attn.2.weight" : torch.Size([6144, 256]),
    "blocks.18.adaln_modulation_mlp.1.weight" : torch.Size([256, 2048]),
    "blocks.18.adaln_modulation_mlp.2.weight" : torch.Size([6144, 256]),
    "blocks.19.self_attn.q_proj.weight" : torch.Size([2048, 2048]),
    "blocks.19.self_attn.q_norm.weight" : torch.Size([128]),
    "blocks.19.self_attn.k_proj.weight" : torch.Size([2048, 2048]),
    "blocks.19.self_attn.k_norm.weight" : torch.Size([128]),
    "blocks.19.self_attn.v_proj.weight" : torch.Size([2048, 2048]),
    "blocks.19.self_attn.output_proj.weight" : torch.Size([2048, 2048]),
    "blocks.19.cross_attn.q_proj.weight" : torch.Size([2048, 2048]),
    "blocks.19.cross_attn.q_norm.weight" : torch.Size([128]),
    "blocks.19.cross_attn.k_proj.weight" : torch.Size([2048, 1024]),
    "blocks.19.cross_attn.k_norm.weight" : torch.Size([128]),
    "blocks.19.cross_attn.v_proj.weight" : torch.Size([2048, 1024]),
    "blocks.19.cross_attn.output_proj.weight" : torch.Size([2048, 2048]),
    "blocks.19.mlp.layer1.weight" : torch.Size([8192, 2048]),
    "blocks.19.mlp.layer2.weight" : torch.Size([2048, 8192]),
    "blocks.19.adaln_modulation_self_attn.1.weight" : torch.Size([256, 2048]),
    "blocks.19.adaln_modulation_self_attn.2.weight" : torch.Size([6144, 256]),
    "blocks.19.adaln_modulation_cross_attn.1.weight" : torch.Size([256, 2048]),
    "blocks.19.adaln_modulation_cross_attn.2.weight" : torch.Size([6144, 256]),
    "blocks.19.adaln_modulation_mlp.1.weight" : torch.Size([256, 2048]),
    "blocks.19.adaln_modulation_mlp.2.weight" : torch.Size([6144, 256]),
    "blocks.20.self_attn.q_proj.weight" : torch.Size([2048, 2048]),
    "blocks.20.self_attn.q_norm.weight" : torch.Size([128]),
    "blocks.20.self_attn.k_proj.weight" : torch.Size([2048, 2048]),
    "blocks.20.self_attn.k_norm.weight" : torch.Size([128]),
    "blocks.20.self_attn.v_proj.weight" : torch.Size([2048, 2048]),
    "blocks.20.self_attn.output_proj.weight" : torch.Size([2048, 2048]),
    "blocks.20.cross_attn.q_proj.weight" : torch.Size([2048, 2048]),
    "blocks.20.cross_attn.q_norm.weight" : torch.Size([128]),
    "blocks.20.cross_attn.k_proj.weight" : torch.Size([2048, 1024]),
    "blocks.20.cross_attn.k_norm.weight" : torch.Size([128]),
    "blocks.20.cross_attn.v_proj.weight" : torch.Size([2048, 1024]),
    "blocks.20.cross_attn.output_proj.weight" : torch.Size([2048, 2048]),
    "blocks.20.mlp.layer1.weight" : torch.Size([8192, 2048]),
    "blocks.20.mlp.layer2.weight" : torch.Size([2048, 8192]),
    "blocks.20.adaln_modulation_self_attn.1.weight" : torch.Size([256, 2048]),
    "blocks.20.adaln_modulation_self_attn.2.weight" : torch.Size([6144, 256]),
    "blocks.20.adaln_modulation_cross_attn.1.weight" : torch.Size([256, 2048]),
    "blocks.20.adaln_modulation_cross_attn.2.weight" : torch.Size([6144, 256]),
    "blocks.20.adaln_modulation_mlp.1.weight" : torch.Size([256, 2048]),
    "blocks.20.adaln_modulation_mlp.2.weight" : torch.Size([6144, 256]),
    "blocks.21.self_attn.q_proj.weight" : torch.Size([2048, 2048]),
    "blocks.21.self_attn.q_norm.weight" : torch.Size([128]),
    "blocks.21.self_attn.k_proj.weight" : torch.Size([2048, 2048]),
    "blocks.21.self_attn.k_norm.weight" : torch.Size([128]),
    "blocks.21.self_attn.v_proj.weight" : torch.Size([2048, 2048]),
    "blocks.21.self_attn.output_proj.weight" : torch.Size([2048, 2048]),
    "blocks.21.cross_attn.q_proj.weight" : torch.Size([2048, 2048]),
    "blocks.21.cross_attn.q_norm.weight" : torch.Size([128]),
    "blocks.21.cross_attn.k_proj.weight" : torch.Size([2048, 1024]),
    "blocks.21.cross_attn.k_norm.weight" : torch.Size([128]),
    "blocks.21.cross_attn.v_proj.weight" : torch.Size([2048, 1024]),
    "blocks.21.cross_attn.output_proj.weight" : torch.Size([2048, 2048]),
    "blocks.21.mlp.layer1.weight" : torch.Size([8192, 2048]),
    "blocks.21.mlp.layer2.weight" : torch.Size([2048, 8192]),
    "blocks.21.adaln_modulation_self_attn.1.weight" : torch.Size([256, 2048]),
    "blocks.21.adaln_modulation_self_attn.2.weight" : torch.Size([6144, 256]),
    "blocks.21.adaln_modulation_cross_attn.1.weight" : torch.Size([256, 2048]),
    "blocks.21.adaln_modulation_cross_attn.2.weight" : torch.Size([6144, 256]),
    "blocks.21.adaln_modulation_mlp.1.weight" : torch.Size([256, 2048]),
    "blocks.21.adaln_modulation_mlp.2.weight" : torch.Size([6144, 256]),
    "blocks.22.self_attn.q_proj.weight" : torch.Size([2048, 2048]),
    "blocks.22.self_attn.q_norm.weight" : torch.Size([128]),
    "blocks.22.self_attn.k_proj.weight" : torch.Size([2048, 2048]),
    "blocks.22.self_attn.k_norm.weight" : torch.Size([128]),
    "blocks.22.self_attn.v_proj.weight" : torch.Size([2048, 2048]),
    "blocks.22.self_attn.output_proj.weight" : torch.Size([2048, 2048]),
    "blocks.22.cross_attn.q_proj.weight" : torch.Size([2048, 2048]),
    "blocks.22.cross_attn.q_norm.weight" : torch.Size([128]),
    "blocks.22.cross_attn.k_proj.weight" : torch.Size([2048, 1024]),
    "blocks.22.cross_attn.k_norm.weight" : torch.Size([128]),
    "blocks.22.cross_attn.v_proj.weight" : torch.Size([2048, 1024]),
    "blocks.22.cross_attn.output_proj.weight" : torch.Size([2048, 2048]),
    "blocks.22.mlp.layer1.weight" : torch.Size([8192, 2048]),
    "blocks.22.mlp.layer2.weight" : torch.Size([2048, 8192]),
    "blocks.22.adaln_modulation_self_attn.1.weight" : torch.Size([256, 2048]),
    "blocks.22.adaln_modulation_self_attn.2.weight" : torch.Size([6144, 256]),
    "blocks.22.adaln_modulation_cross_attn.1.weight" : torch.Size([256, 2048]),
    "blocks.22.adaln_modulation_cross_attn.2.weight" : torch.Size([6144, 256]),
    "blocks.22.adaln_modulation_mlp.1.weight" : torch.Size([256, 2048]),
    "blocks.22.adaln_modulation_mlp.2.weight" : torch.Size([6144, 256]),
    "blocks.23.self_attn.q_proj.weight" : torch.Size([2048, 2048]),
    "blocks.23.self_attn.q_norm.weight" : torch.Size([128]),
    "blocks.23.self_attn.k_proj.weight" : torch.Size([2048, 2048]),
    "blocks.23.self_attn.k_norm.weight" : torch.Size([128]),
    "blocks.23.self_attn.v_proj.weight" : torch.Size([2048, 2048]),
    "blocks.23.self_attn.output_proj.weight" : torch.Size([2048, 2048]),
    "blocks.23.cross_attn.q_proj.weight" : torch.Size([2048, 2048]),
    "blocks.23.cross_attn.q_norm.weight" : torch.Size([128]),
    "blocks.23.cross_attn.k_proj.weight" : torch.Size([2048, 1024]),
    "blocks.23.cross_attn.k_norm.weight" : torch.Size([128]),
    "blocks.23.cross_attn.v_proj.weight" : torch.Size([2048, 1024]),
    "blocks.23.cross_attn.output_proj.weight" : torch.Size([2048, 2048]),
    "blocks.23.mlp.layer1.weight" : torch.Size([8192, 2048]),
    "blocks.23.mlp.layer2.weight" : torch.Size([2048, 8192]),
    "blocks.23.adaln_modulation_self_attn.1.weight" : torch.Size([256, 2048]),
    "blocks.23.adaln_modulation_self_attn.2.weight" : torch.Size([6144, 256]),
    "blocks.23.adaln_modulation_cross_attn.1.weight" : torch.Size([256, 2048]),
    "blocks.23.adaln_modulation_cross_attn.2.weight" : torch.Size([6144, 256]),
    "blocks.23.adaln_modulation_mlp.1.weight" : torch.Size([256, 2048]),
    "blocks.23.adaln_modulation_mlp.2.weight" : torch.Size([6144, 256]),
    "blocks.24.self_attn.q_proj.weight" : torch.Size([2048, 2048]),
    "blocks.24.self_attn.q_norm.weight" : torch.Size([128]),
    "blocks.24.self_attn.k_proj.weight" : torch.Size([2048, 2048]),
    "blocks.24.self_attn.k_norm.weight" : torch.Size([128]),
    "blocks.24.self_attn.v_proj.weight" : torch.Size([2048, 2048]),
    "blocks.24.self_attn.output_proj.weight" : torch.Size([2048, 2048]),
    "blocks.24.cross_attn.q_proj.weight" : torch.Size([2048, 2048]),
    "blocks.24.cross_attn.q_norm.weight" : torch.Size([128]),
    "blocks.24.cross_attn.k_proj.weight" : torch.Size([2048, 1024]),
    "blocks.24.cross_attn.k_norm.weight" : torch.Size([128]),
    "blocks.24.cross_attn.v_proj.weight" : torch.Size([2048, 1024]),
    "blocks.24.cross_attn.output_proj.weight" : torch.Size([2048, 2048]),
    "blocks.24.mlp.layer1.weight" : torch.Size([8192, 2048]),
    "blocks.24.mlp.layer2.weight" : torch.Size([2048, 8192]),
    "blocks.24.adaln_modulation_self_attn.1.weight" : torch.Size([256, 2048]),
    "blocks.24.adaln_modulation_self_attn.2.weight" : torch.Size([6144, 256]),
    "blocks.24.adaln_modulation_cross_attn.1.weight" : torch.Size([256, 2048]),
    "blocks.24.adaln_modulation_cross_attn.2.weight" : torch.Size([6144, 256]),
    "blocks.24.adaln_modulation_mlp.1.weight" : torch.Size([256, 2048]),
    "blocks.24.adaln_modulation_mlp.2.weight" : torch.Size([6144, 256]),
    "blocks.25.self_attn.q_proj.weight" : torch.Size([2048, 2048]),
    "blocks.25.self_attn.q_norm.weight" : torch.Size([128]),
    "blocks.25.self_attn.k_proj.weight" : torch.Size([2048, 2048]),
    "blocks.25.self_attn.k_norm.weight" : torch.Size([128]),
    "blocks.25.self_attn.v_proj.weight" : torch.Size([2048, 2048]),
    "blocks.25.self_attn.output_proj.weight" : torch.Size([2048, 2048]),
    "blocks.25.cross_attn.q_proj.weight" : torch.Size([2048, 2048]),
    "blocks.25.cross_attn.q_norm.weight" : torch.Size([128]),
    "blocks.25.cross_attn.k_proj.weight" : torch.Size([2048, 1024]),
    "blocks.25.cross_attn.k_norm.weight" : torch.Size([128]),
    "blocks.25.cross_attn.v_proj.weight" : torch.Size([2048, 1024]),
    "blocks.25.cross_attn.output_proj.weight" : torch.Size([2048, 2048]),
    "blocks.25.mlp.layer1.weight" : torch.Size([8192, 2048]),
    "blocks.25.mlp.layer2.weight" : torch.Size([2048, 8192]),
    "blocks.25.adaln_modulation_self_attn.1.weight" : torch.Size([256, 2048]),
    "blocks.25.adaln_modulation_self_attn.2.weight" : torch.Size([6144, 256]),
    "blocks.25.adaln_modulation_cross_attn.1.weight" : torch.Size([256, 2048]),
    "blocks.25.adaln_modulation_cross_attn.2.weight" : torch.Size([6144, 256]),
    "blocks.25.adaln_modulation_mlp.1.weight" : torch.Size([256, 2048]),
    "blocks.25.adaln_modulation_mlp.2.weight" : torch.Size([6144, 256]),
    "blocks.26.self_attn.q_proj.weight" : torch.Size([2048, 2048]),
    "blocks.26.self_attn.q_norm.weight" : torch.Size([128]),
    "blocks.26.self_attn.k_proj.weight" : torch.Size([2048, 2048]),
    "blocks.26.self_attn.k_norm.weight" : torch.Size([128]),
    "blocks.26.self_attn.v_proj.weight" : torch.Size([2048, 2048]),
    "blocks.26.self_attn.output_proj.weight" : torch.Size([2048, 2048]),
    "blocks.26.cross_attn.q_proj.weight" : torch.Size([2048, 2048]),
    "blocks.26.cross_attn.q_norm.weight" : torch.Size([128]),
    "blocks.26.cross_attn.k_proj.weight" : torch.Size([2048, 1024]),
    "blocks.26.cross_attn.k_norm.weight" : torch.Size([128]),
    "blocks.26.cross_attn.v_proj.weight" : torch.Size([2048, 1024]),
    "blocks.26.cross_attn.output_proj.weight" : torch.Size([2048, 2048]),
    "blocks.26.mlp.layer1.weight" : torch.Size([8192, 2048]),
    "blocks.26.mlp.layer2.weight" : torch.Size([2048, 8192]),
    "blocks.26.adaln_modulation_self_attn.1.weight" : torch.Size([256, 2048]),
    "blocks.26.adaln_modulation_self_attn.2.weight" : torch.Size([6144, 256]),
    "blocks.26.adaln_modulation_cross_attn.1.weight" : torch.Size([256, 2048]),
    "blocks.26.adaln_modulation_cross_attn.2.weight" : torch.Size([6144, 256]),
    "blocks.26.adaln_modulation_mlp.1.weight" : torch.Size([256, 2048]),
    "blocks.26.adaln_modulation_mlp.2.weight" : torch.Size([6144, 256]),
    "blocks.27.self_attn.q_proj.weight" : torch.Size([2048, 2048]),
    "blocks.27.self_attn.q_norm.weight" : torch.Size([128]),
    "blocks.27.self_attn.k_proj.weight" : torch.Size([2048, 2048]),
    "blocks.27.self_attn.k_norm.weight" : torch.Size([128]),
    "blocks.27.self_attn.v_proj.weight" : torch.Size([2048, 2048]),
    "blocks.27.self_attn.output_proj.weight" : torch.Size([2048, 2048]),
    "blocks.27.cross_attn.q_proj.weight" : torch.Size([2048, 2048]),
    "blocks.27.cross_attn.q_norm.weight" : torch.Size([128]),
    "blocks.27.cross_attn.k_proj.weight" : torch.Size([2048, 1024]),
    "blocks.27.cross_attn.k_norm.weight" : torch.Size([128]),
    "blocks.27.cross_attn.v_proj.weight" : torch.Size([2048, 1024]),
    "blocks.27.cross_attn.output_proj.weight" : torch.Size([2048, 2048]),
    "blocks.27.mlp.layer1.weight" : torch.Size([8192, 2048]),
    "blocks.27.mlp.layer2.weight" : torch.Size([2048, 8192]),
    "blocks.27.adaln_modulation_self_attn.1.weight" : torch.Size([256, 2048]),
    "blocks.27.adaln_modulation_self_attn.2.weight" : torch.Size([6144, 256]),
    "blocks.27.adaln_modulation_cross_attn.1.weight" : torch.Size([256, 2048]),
    "blocks.27.adaln_modulation_cross_attn.2.weight" : torch.Size([6144, 256]),
    "blocks.27.adaln_modulation_mlp.1.weight" : torch.Size([256, 2048]),
    "blocks.27.adaln_modulation_mlp.2.weight" : torch.Size([6144, 256]),
    "final_layer.linear.weight" : torch.Size([64, 2048]),
    "final_layer.adaln_modulation.1.weight" : torch.Size([256, 2048]),
    "final_layer.adaln_modulation.2.weight" : torch.Size([4096, 256]),
    "t_embedding_norm.weight" : torch.Size([2048]),
    "llm_adapter.embed.weight" : torch.Size([32128, 1024]),
    "llm_adapter.blocks.0.norm_self_attn.weight" : torch.Size([1024]),
    "llm_adapter.blocks.0.self_attn.q_proj.weight" : torch.Size([1024, 1024]),
    "llm_adapter.blocks.0.self_attn.q_norm.weight" : torch.Size([64]),
    "llm_adapter.blocks.0.self_attn.k_proj.weight" : torch.Size([1024, 1024]),
    "llm_adapter.blocks.0.self_attn.k_norm.weight" : torch.Size([64]),
    "llm_adapter.blocks.0.self_attn.v_proj.weight" : torch.Size([1024, 1024]),
    "llm_adapter.blocks.0.self_attn.o_proj.weight" : torch.Size([1024, 1024]),
    "llm_adapter.blocks.0.norm_cross_attn.weight" : torch.Size([1024]),
    "llm_adapter.blocks.0.cross_attn.q_proj.weight" : torch.Size([1024, 1024]),
    "llm_adapter.blocks.0.cross_attn.q_norm.weight" : torch.Size([64]),
    "llm_adapter.blocks.0.cross_attn.k_proj.weight" : torch.Size([1024, 1024]),
    "llm_adapter.blocks.0.cross_attn.k_norm.weight" : torch.Size([64]),
    "llm_adapter.blocks.0.cross_attn.v_proj.weight" : torch.Size([1024, 1024]),
    "llm_adapter.blocks.0.cross_attn.o_proj.weight" : torch.Size([1024, 1024]),
    "llm_adapter.blocks.0.norm_mlp.weight" : torch.Size([1024]),
    "llm_adapter.blocks.0.mlp.0.weight" : torch.Size([4096, 1024]),
    "llm_adapter.blocks.0.mlp.0.bias" : torch.Size([4096]),
    "llm_adapter.blocks.0.mlp.2.weight" : torch.Size([1024, 4096]),
    "llm_adapter.blocks.0.mlp.2.bias" : torch.Size([1024]),
    "llm_adapter.blocks.1.norm_self_attn.weight" : torch.Size([1024]),
    "llm_adapter.blocks.1.self_attn.q_proj.weight" : torch.Size([1024, 1024]),
    "llm_adapter.blocks.1.self_attn.q_norm.weight" : torch.Size([64]),
    "llm_adapter.blocks.1.self_attn.k_proj.weight" : torch.Size([1024, 1024]),
    "llm_adapter.blocks.1.self_attn.k_norm.weight" : torch.Size([64]),
    "llm_adapter.blocks.1.self_attn.v_proj.weight" : torch.Size([1024, 1024]),
    "llm_adapter.blocks.1.self_attn.o_proj.weight" : torch.Size([1024, 1024]),
    "llm_adapter.blocks.1.norm_cross_attn.weight" : torch.Size([1024]),
    "llm_adapter.blocks.1.cross_attn.q_proj.weight" : torch.Size([1024, 1024]),
    "llm_adapter.blocks.1.cross_attn.q_norm.weight" : torch.Size([64]),
    "llm_adapter.blocks.1.cross_attn.k_proj.weight" : torch.Size([1024, 1024]),
    "llm_adapter.blocks.1.cross_attn.k_norm.weight" : torch.Size([64]),
    "llm_adapter.blocks.1.cross_attn.v_proj.weight" : torch.Size([1024, 1024]),
    "llm_adapter.blocks.1.cross_attn.o_proj.weight" : torch.Size([1024, 1024]),
    "llm_adapter.blocks.1.norm_mlp.weight" : torch.Size([1024]),
    "llm_adapter.blocks.1.mlp.0.weight" : torch.Size([4096, 1024]),
    "llm_adapter.blocks.1.mlp.0.bias" : torch.Size([4096]),
    "llm_adapter.blocks.1.mlp.2.weight" : torch.Size([1024, 4096]),
    "llm_adapter.blocks.1.mlp.2.bias" : torch.Size([1024]),
    "llm_adapter.blocks.2.norm_self_attn.weight" : torch.Size([1024]),
    "llm_adapter.blocks.2.self_attn.q_proj.weight" : torch.Size([1024, 1024]),
    "llm_adapter.blocks.2.self_attn.q_norm.weight" : torch.Size([64]),
    "llm_adapter.blocks.2.self_attn.k_proj.weight" : torch.Size([1024, 1024]),
    "llm_adapter.blocks.2.self_attn.k_norm.weight" : torch.Size([64]),
    "llm_adapter.blocks.2.self_attn.v_proj.weight" : torch.Size([1024, 1024]),
    "llm_adapter.blocks.2.self_attn.o_proj.weight" : torch.Size([1024, 1024]),
    "llm_adapter.blocks.2.norm_cross_attn.weight" : torch.Size([1024]),
    "llm_adapter.blocks.2.cross_attn.q_proj.weight" : torch.Size([1024, 1024]),
    "llm_adapter.blocks.2.cross_attn.q_norm.weight" : torch.Size([64]),
    "llm_adapter.blocks.2.cross_attn.k_proj.weight" : torch.Size([1024, 1024]),
    "llm_adapter.blocks.2.cross_attn.k_norm.weight" : torch.Size([64]),
    "llm_adapter.blocks.2.cross_attn.v_proj.weight" : torch.Size([1024, 1024]),
    "llm_adapter.blocks.2.cross_attn.o_proj.weight" : torch.Size([1024, 1024]),
    "llm_adapter.blocks.2.norm_mlp.weight" : torch.Size([1024]),
    "llm_adapter.blocks.2.mlp.0.weight" : torch.Size([4096, 1024]),
    "llm_adapter.blocks.2.mlp.0.bias" : torch.Size([4096]),
    "llm_adapter.blocks.2.mlp.2.weight" : torch.Size([1024, 4096]),
    "llm_adapter.blocks.2.mlp.2.bias" : torch.Size([1024]),
    "llm_adapter.blocks.3.norm_self_attn.weight" : torch.Size([1024]),
    "llm_adapter.blocks.3.self_attn.q_proj.weight" : torch.Size([1024, 1024]),
    "llm_adapter.blocks.3.self_attn.q_norm.weight" : torch.Size([64]),
    "llm_adapter.blocks.3.self_attn.k_proj.weight" : torch.Size([1024, 1024]),
    "llm_adapter.blocks.3.self_attn.k_norm.weight" : torch.Size([64]),
    "llm_adapter.blocks.3.self_attn.v_proj.weight" : torch.Size([1024, 1024]),
    "llm_adapter.blocks.3.self_attn.o_proj.weight" : torch.Size([1024, 1024]),
    "llm_adapter.blocks.3.norm_cross_attn.weight" : torch.Size([1024]),
    "llm_adapter.blocks.3.cross_attn.q_proj.weight" : torch.Size([1024, 1024]),
    "llm_adapter.blocks.3.cross_attn.q_norm.weight" : torch.Size([64]),
    "llm_adapter.blocks.3.cross_attn.k_proj.weight" : torch.Size([1024, 1024]),
    "llm_adapter.blocks.3.cross_attn.k_norm.weight" : torch.Size([64]),
    "llm_adapter.blocks.3.cross_attn.v_proj.weight" : torch.Size([1024, 1024]),
    "llm_adapter.blocks.3.cross_attn.o_proj.weight" : torch.Size([1024, 1024]),
    "llm_adapter.blocks.3.norm_mlp.weight" : torch.Size([1024]),
    "llm_adapter.blocks.3.mlp.0.weight" : torch.Size([4096, 1024]),
    "llm_adapter.blocks.3.mlp.0.bias" : torch.Size([4096]),
    "llm_adapter.blocks.3.mlp.2.weight" : torch.Size([1024, 4096]),
    "llm_adapter.blocks.3.mlp.2.bias" : torch.Size([1024]),
    "llm_adapter.blocks.4.norm_self_attn.weight" : torch.Size([1024]),
    "llm_adapter.blocks.4.self_attn.q_proj.weight" : torch.Size([1024, 1024]),
    "llm_adapter.blocks.4.self_attn.q_norm.weight" : torch.Size([64]),
    "llm_adapter.blocks.4.self_attn.k_proj.weight" : torch.Size([1024, 1024]),
    "llm_adapter.blocks.4.self_attn.k_norm.weight" : torch.Size([64]),
    "llm_adapter.blocks.4.self_attn.v_proj.weight" : torch.Size([1024, 1024]),
    "llm_adapter.blocks.4.self_attn.o_proj.weight" : torch.Size([1024, 1024]),
    "llm_adapter.blocks.4.norm_cross_attn.weight" : torch.Size([1024]),
    "llm_adapter.blocks.4.cross_attn.q_proj.weight" : torch.Size([1024, 1024]),
    "llm_adapter.blocks.4.cross_attn.q_norm.weight" : torch.Size([64]),
    "llm_adapter.blocks.4.cross_attn.k_proj.weight" : torch.Size([1024, 1024]),
    "llm_adapter.blocks.4.cross_attn.k_norm.weight" : torch.Size([64]),
    "llm_adapter.blocks.4.cross_attn.v_proj.weight" : torch.Size([1024, 1024]),
    "llm_adapter.blocks.4.cross_attn.o_proj.weight" : torch.Size([1024, 1024]),
    "llm_adapter.blocks.4.norm_mlp.weight" : torch.Size([1024]),
    "llm_adapter.blocks.4.mlp.0.weight" : torch.Size([4096, 1024]),
    "llm_adapter.blocks.4.mlp.0.bias" : torch.Size([4096]),
    "llm_adapter.blocks.4.mlp.2.weight" : torch.Size([1024, 4096]),
    "llm_adapter.blocks.4.mlp.2.bias" : torch.Size([1024]),
    "llm_adapter.blocks.5.norm_self_attn.weight" : torch.Size([1024]),
    "llm_adapter.blocks.5.self_attn.q_proj.weight" : torch.Size([1024, 1024]),
    "llm_adapter.blocks.5.self_attn.q_norm.weight" : torch.Size([64]),
    "llm_adapter.blocks.5.self_attn.k_proj.weight" : torch.Size([1024, 1024]),
    "llm_adapter.blocks.5.self_attn.k_norm.weight" : torch.Size([64]),
    "llm_adapter.blocks.5.self_attn.v_proj.weight" : torch.Size([1024, 1024]),
    "llm_adapter.blocks.5.self_attn.o_proj.weight" : torch.Size([1024, 1024]),
    "llm_adapter.blocks.5.norm_cross_attn.weight" : torch.Size([1024]),
    "llm_adapter.blocks.5.cross_attn.q_proj.weight" : torch.Size([1024, 1024]),
    "llm_adapter.blocks.5.cross_attn.q_norm.weight" : torch.Size([64]),
    "llm_adapter.blocks.5.cross_attn.k_proj.weight" : torch.Size([1024, 1024]),
    "llm_adapter.blocks.5.cross_attn.k_norm.weight" : torch.Size([64]),
    "llm_adapter.blocks.5.cross_attn.v_proj.weight" : torch.Size([1024, 1024]),
    "llm_adapter.blocks.5.cross_attn.o_proj.weight" : torch.Size([1024, 1024]),
    "llm_adapter.blocks.5.norm_mlp.weight" : torch.Size([1024]),
    "llm_adapter.blocks.5.mlp.0.weight" : torch.Size([4096, 1024]),
    "llm_adapter.blocks.5.mlp.0.bias" : torch.Size([4096]),
    "llm_adapter.blocks.5.mlp.2.weight" : torch.Size([1024, 4096]),
    "llm_adapter.blocks.5.mlp.2.bias" : torch.Size([1024]),
    "llm_adapter.out_proj.weight" : torch.Size([1024, 1024]),
    "llm_adapter.out_proj.bias" : torch.Size([1024]),
    "llm_adapter.norm.weight" : torch.Size([1024]),
}

flux2_keys_dict = {
    "double_blocks.0.img_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.0.img_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.0.img_attn.proj.weight" : torch.Size([4096, 4096]),
    "double_blocks.0.img_attn.qkv.weight" : torch.Size([12288, 4096]),
    "double_blocks.0.img_mlp.0.weight" : torch.Size([24576, 4096]),
    "double_blocks.0.img_mlp.2.weight" : torch.Size([4096, 12288]),
    "double_blocks.0.txt_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.0.txt_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.0.txt_attn.proj.weight" : torch.Size([4096, 4096]),
    "double_blocks.0.txt_attn.qkv.weight" : torch.Size([12288, 4096]),
    "double_blocks.0.txt_mlp.0.weight" : torch.Size([24576, 4096]),
    "double_blocks.0.txt_mlp.2.weight" : torch.Size([4096, 12288]),
    "double_blocks.1.img_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.1.img_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.1.img_attn.proj.weight" : torch.Size([4096, 4096]),
    "double_blocks.1.img_attn.qkv.weight" : torch.Size([12288, 4096]),
    "double_blocks.1.img_mlp.0.weight" : torch.Size([24576, 4096]),
    "double_blocks.1.img_mlp.2.weight" : torch.Size([4096, 12288]),
    "double_blocks.1.txt_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.1.txt_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.1.txt_attn.proj.weight" : torch.Size([4096, 4096]),
    "double_blocks.1.txt_attn.qkv.weight" : torch.Size([12288, 4096]),
    "double_blocks.1.txt_mlp.0.weight" : torch.Size([24576, 4096]),
    "double_blocks.1.txt_mlp.2.weight" : torch.Size([4096, 12288]),
    "double_blocks.2.img_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.2.img_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.2.img_attn.proj.weight" : torch.Size([4096, 4096]),
    "double_blocks.2.img_attn.qkv.weight" : torch.Size([12288, 4096]),
    "double_blocks.2.img_mlp.0.weight" : torch.Size([24576, 4096]),
    "double_blocks.2.img_mlp.2.weight" : torch.Size([4096, 12288]),
    "double_blocks.2.txt_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.2.txt_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.2.txt_attn.proj.weight" : torch.Size([4096, 4096]),
    "double_blocks.2.txt_attn.qkv.weight" : torch.Size([12288, 4096]),
    "double_blocks.2.txt_mlp.0.weight" : torch.Size([24576, 4096]),
    "double_blocks.2.txt_mlp.2.weight" : torch.Size([4096, 12288]),
    "double_blocks.3.img_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.3.img_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.3.img_attn.proj.weight" : torch.Size([4096, 4096]),
    "double_blocks.3.img_attn.qkv.weight" : torch.Size([12288, 4096]),
    "double_blocks.3.img_mlp.0.weight" : torch.Size([24576, 4096]),
    "double_blocks.3.img_mlp.2.weight" : torch.Size([4096, 12288]),
    "double_blocks.3.txt_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.3.txt_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.3.txt_attn.proj.weight" : torch.Size([4096, 4096]),
    "double_blocks.3.txt_attn.qkv.weight" : torch.Size([12288, 4096]),
    "double_blocks.3.txt_mlp.0.weight" : torch.Size([24576, 4096]),
    "double_blocks.3.txt_mlp.2.weight" : torch.Size([4096, 12288]),
    "double_blocks.4.img_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.4.img_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.4.img_attn.proj.weight" : torch.Size([4096, 4096]),
    "double_blocks.4.img_attn.qkv.weight" : torch.Size([12288, 4096]),
    "double_blocks.4.img_mlp.0.weight" : torch.Size([24576, 4096]),
    "double_blocks.4.img_mlp.2.weight" : torch.Size([4096, 12288]),
    "double_blocks.4.txt_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.4.txt_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.4.txt_attn.proj.weight" : torch.Size([4096, 4096]),
    "double_blocks.4.txt_attn.qkv.weight" : torch.Size([12288, 4096]),
    "double_blocks.4.txt_mlp.0.weight" : torch.Size([24576, 4096]),
    "double_blocks.4.txt_mlp.2.weight" : torch.Size([4096, 12288]),
    "double_blocks.5.img_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.5.img_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.5.img_attn.proj.weight" : torch.Size([4096, 4096]),
    "double_blocks.5.img_attn.qkv.weight" : torch.Size([12288, 4096]),
    "double_blocks.5.img_mlp.0.weight" : torch.Size([24576, 4096]),
    "double_blocks.5.img_mlp.2.weight" : torch.Size([4096, 12288]),
    "double_blocks.5.txt_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.5.txt_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.5.txt_attn.proj.weight" : torch.Size([4096, 4096]),
    "double_blocks.5.txt_attn.qkv.weight" : torch.Size([12288, 4096]),
    "double_blocks.5.txt_mlp.0.weight" : torch.Size([24576, 4096]),
    "double_blocks.5.txt_mlp.2.weight" : torch.Size([4096, 12288]),
    "double_blocks.6.img_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.6.img_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.6.img_attn.proj.weight" : torch.Size([4096, 4096]),
    "double_blocks.6.img_attn.qkv.weight" : torch.Size([12288, 4096]),
    "double_blocks.6.img_mlp.0.weight" : torch.Size([24576, 4096]),
    "double_blocks.6.img_mlp.2.weight" : torch.Size([4096, 12288]),
    "double_blocks.6.txt_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.6.txt_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.6.txt_attn.proj.weight" : torch.Size([4096, 4096]),
    "double_blocks.6.txt_attn.qkv.weight" : torch.Size([12288, 4096]),
    "double_blocks.6.txt_mlp.0.weight" : torch.Size([24576, 4096]),
    "double_blocks.6.txt_mlp.2.weight" : torch.Size([4096, 12288]),
    "double_blocks.7.img_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.7.img_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.7.img_attn.proj.weight" : torch.Size([4096, 4096]),
    "double_blocks.7.img_attn.qkv.weight" : torch.Size([12288, 4096]),
    "double_blocks.7.img_mlp.0.weight" : torch.Size([24576, 4096]),
    "double_blocks.7.img_mlp.2.weight" : torch.Size([4096, 12288]),
    "double_blocks.7.txt_attn.norm.key_norm.scale" : torch.Size([128]),
    "double_blocks.7.txt_attn.norm.query_norm.scale" : torch.Size([128]),
    "double_blocks.7.txt_attn.proj.weight" : torch.Size([4096, 4096]),
    "double_blocks.7.txt_attn.qkv.weight" : torch.Size([12288, 4096]),
    "double_blocks.7.txt_mlp.0.weight" : torch.Size([24576, 4096]),
    "double_blocks.7.txt_mlp.2.weight" : torch.Size([4096, 12288]),
    "double_stream_modulation_img.lin.weight" : torch.Size([24576, 4096]),
    "double_stream_modulation_txt.lin.weight" : torch.Size([24576, 4096]),
    "final_layer.adaLN_modulation.1.weight" : torch.Size([8192, 4096]),
    "final_layer.linear.weight" : torch.Size([128, 4096]),
    "img_in.weight" : torch.Size([4096, 128]),
    "single_blocks.0.linear1.weight" : torch.Size([36864, 4096]),
    "single_blocks.0.linear2.weight" : torch.Size([4096, 16384]),
    "single_blocks.0.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.0.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.1.linear1.weight" : torch.Size([36864, 4096]),
    "single_blocks.1.linear2.weight" : torch.Size([4096, 16384]),
    "single_blocks.1.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.1.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.10.linear1.weight" : torch.Size([36864, 4096]),
    "single_blocks.10.linear2.weight" : torch.Size([4096, 16384]),
    "single_blocks.10.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.10.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.11.linear1.weight" : torch.Size([36864, 4096]),
    "single_blocks.11.linear2.weight" : torch.Size([4096, 16384]),
    "single_blocks.11.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.11.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.12.linear1.weight" : torch.Size([36864, 4096]),
    "single_blocks.12.linear2.weight" : torch.Size([4096, 16384]),
    "single_blocks.12.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.12.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.13.linear1.weight" : torch.Size([36864, 4096]),
    "single_blocks.13.linear2.weight" : torch.Size([4096, 16384]),
    "single_blocks.13.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.13.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.14.linear1.weight" : torch.Size([36864, 4096]),
    "single_blocks.14.linear2.weight" : torch.Size([4096, 16384]),
    "single_blocks.14.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.14.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.15.linear1.weight" : torch.Size([36864, 4096]),
    "single_blocks.15.linear2.weight" : torch.Size([4096, 16384]),
    "single_blocks.15.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.15.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.16.linear1.weight" : torch.Size([36864, 4096]),
    "single_blocks.16.linear2.weight" : torch.Size([4096, 16384]),
    "single_blocks.16.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.16.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.17.linear1.weight" : torch.Size([36864, 4096]),
    "single_blocks.17.linear2.weight" : torch.Size([4096, 16384]),
    "single_blocks.17.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.17.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.18.linear1.weight" : torch.Size([36864, 4096]),
    "single_blocks.18.linear2.weight" : torch.Size([4096, 16384]),
    "single_blocks.18.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.18.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.19.linear1.weight" : torch.Size([36864, 4096]),
    "single_blocks.19.linear2.weight" : torch.Size([4096, 16384]),
    "single_blocks.19.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.19.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.2.linear1.weight" : torch.Size([36864, 4096]),
    "single_blocks.2.linear2.weight" : torch.Size([4096, 16384]),
    "single_blocks.2.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.2.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.20.linear1.weight" : torch.Size([36864, 4096]),
    "single_blocks.20.linear2.weight" : torch.Size([4096, 16384]),
    "single_blocks.20.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.20.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.21.linear1.weight" : torch.Size([36864, 4096]),
    "single_blocks.21.linear2.weight" : torch.Size([4096, 16384]),
    "single_blocks.21.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.21.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.22.linear1.weight" : torch.Size([36864, 4096]),
    "single_blocks.22.linear2.weight" : torch.Size([4096, 16384]),
    "single_blocks.22.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.22.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.23.linear1.weight" : torch.Size([36864, 4096]),
    "single_blocks.23.linear2.weight" : torch.Size([4096, 16384]),
    "single_blocks.23.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.23.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.3.linear1.weight" : torch.Size([36864, 4096]),
    "single_blocks.3.linear2.weight" : torch.Size([4096, 16384]),
    "single_blocks.3.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.3.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.4.linear1.weight" : torch.Size([36864, 4096]),
    "single_blocks.4.linear2.weight" : torch.Size([4096, 16384]),
    "single_blocks.4.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.4.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.5.linear1.weight" : torch.Size([36864, 4096]),
    "single_blocks.5.linear2.weight" : torch.Size([4096, 16384]),
    "single_blocks.5.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.5.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.6.linear1.weight" : torch.Size([36864, 4096]),
    "single_blocks.6.linear2.weight" : torch.Size([4096, 16384]),
    "single_blocks.6.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.6.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.7.linear1.weight" : torch.Size([36864, 4096]),
    "single_blocks.7.linear2.weight" : torch.Size([4096, 16384]),
    "single_blocks.7.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.7.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.8.linear1.weight" : torch.Size([36864, 4096]),
    "single_blocks.8.linear2.weight" : torch.Size([4096, 16384]),
    "single_blocks.8.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.8.norm.query_norm.scale" : torch.Size([128]),
    "single_blocks.9.linear1.weight" : torch.Size([36864, 4096]),
    "single_blocks.9.linear2.weight" : torch.Size([4096, 16384]),
    "single_blocks.9.norm.key_norm.scale" : torch.Size([128]),
    "single_blocks.9.norm.query_norm.scale" : torch.Size([128]),
    "single_stream_modulation.lin.weight" : torch.Size([12288, 4096]),
    "time_in.in_layer.weight" : torch.Size([4096, 256]),
    "time_in.out_layer.weight" : torch.Size([4096, 4096]),
    "txt_in.weight" : torch.Size([4096, 12288]),
}

ace_step_keys_dict = {
    "decoder.scale_shift_table" : torch.Size([1, 2, 2560]),
    "decoder.proj_in.1.weight" : torch.Size([2560, 192, 2]),
    "decoder.proj_in.1.bias" : torch.Size([2560]),
    "decoder.time_embed.linear_1.weight" : torch.Size([2560, 256]),
    "decoder.time_embed.linear_1.bias" : torch.Size([2560]),
    "decoder.time_embed.linear_2.weight" : torch.Size([2560, 2560]),
    "decoder.time_embed.linear_2.bias" : torch.Size([2560]),
    "decoder.time_embed.time_proj.weight" : torch.Size([15360, 2560]),
    "decoder.time_embed.time_proj.bias" : torch.Size([15360]),
    "decoder.time_embed_r.linear_1.weight" : torch.Size([2560, 256]),
    "decoder.time_embed_r.linear_1.bias" : torch.Size([2560]),
    "decoder.time_embed_r.linear_2.weight" : torch.Size([2560, 2560]),
    "decoder.time_embed_r.linear_2.bias" : torch.Size([2560]),
    "decoder.time_embed_r.time_proj.weight" : torch.Size([15360, 2560]),
    "decoder.time_embed_r.time_proj.bias" : torch.Size([15360]),
    "decoder.condition_embedder.weight" : torch.Size([2560, 2048]),
    "decoder.condition_embedder.bias" : torch.Size([2560]),
    "decoder.layers.0.scale_shift_table" : torch.Size([1, 6, 2560]),
    "decoder.layers.0.self_attn_norm.weight" : torch.Size([2560]),
    "decoder.layers.0.self_attn.q_proj.weight" : torch.Size([4096, 2560]),
    "decoder.layers.0.self_attn.k_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.0.self_attn.v_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.0.self_attn.o_proj.weight" : torch.Size([2560, 4096]),
    "decoder.layers.0.self_attn.q_norm.weight" : torch.Size([128]),
    "decoder.layers.0.self_attn.k_norm.weight" : torch.Size([128]),
    "decoder.layers.0.cross_attn_norm.weight" : torch.Size([2560]),
    "decoder.layers.0.cross_attn.q_proj.weight" : torch.Size([4096, 2560]),
    "decoder.layers.0.cross_attn.k_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.0.cross_attn.v_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.0.cross_attn.o_proj.weight" : torch.Size([2560, 4096]),
    "decoder.layers.0.cross_attn.q_norm.weight" : torch.Size([128]),
    "decoder.layers.0.cross_attn.k_norm.weight" : torch.Size([128]),
    "decoder.layers.0.mlp_norm.weight" : torch.Size([2560]),
    "decoder.layers.0.mlp.gate_proj.weight" : torch.Size([9728, 2560]),
    "decoder.layers.0.mlp.up_proj.weight" : torch.Size([9728, 2560]),
    "decoder.layers.0.mlp.down_proj.weight" : torch.Size([2560, 9728]),
    "decoder.layers.1.scale_shift_table" : torch.Size([1, 6, 2560]),
    "decoder.layers.1.self_attn_norm.weight" : torch.Size([2560]),
    "decoder.layers.1.self_attn.q_proj.weight" : torch.Size([4096, 2560]),
    "decoder.layers.1.self_attn.k_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.1.self_attn.v_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.1.self_attn.o_proj.weight" : torch.Size([2560, 4096]),
    "decoder.layers.1.self_attn.q_norm.weight" : torch.Size([128]),
    "decoder.layers.1.self_attn.k_norm.weight" : torch.Size([128]),
    "decoder.layers.1.cross_attn_norm.weight" : torch.Size([2560]),
    "decoder.layers.1.cross_attn.q_proj.weight" : torch.Size([4096, 2560]),
    "decoder.layers.1.cross_attn.k_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.1.cross_attn.v_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.1.cross_attn.o_proj.weight" : torch.Size([2560, 4096]),
    "decoder.layers.1.cross_attn.q_norm.weight" : torch.Size([128]),
    "decoder.layers.1.cross_attn.k_norm.weight" : torch.Size([128]),
    "decoder.layers.1.mlp_norm.weight" : torch.Size([2560]),
    "decoder.layers.1.mlp.gate_proj.weight" : torch.Size([9728, 2560]),
    "decoder.layers.1.mlp.up_proj.weight" : torch.Size([9728, 2560]),
    "decoder.layers.1.mlp.down_proj.weight" : torch.Size([2560, 9728]),
    "decoder.layers.2.scale_shift_table" : torch.Size([1, 6, 2560]),
    "decoder.layers.2.self_attn_norm.weight" : torch.Size([2560]),
    "decoder.layers.2.self_attn.q_proj.weight" : torch.Size([4096, 2560]),
    "decoder.layers.2.self_attn.k_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.2.self_attn.v_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.2.self_attn.o_proj.weight" : torch.Size([2560, 4096]),
    "decoder.layers.2.self_attn.q_norm.weight" : torch.Size([128]),
    "decoder.layers.2.self_attn.k_norm.weight" : torch.Size([128]),
    "decoder.layers.2.cross_attn_norm.weight" : torch.Size([2560]),
    "decoder.layers.2.cross_attn.q_proj.weight" : torch.Size([4096, 2560]),
    "decoder.layers.2.cross_attn.k_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.2.cross_attn.v_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.2.cross_attn.o_proj.weight" : torch.Size([2560, 4096]),
    "decoder.layers.2.cross_attn.q_norm.weight" : torch.Size([128]),
    "decoder.layers.2.cross_attn.k_norm.weight" : torch.Size([128]),
    "decoder.layers.2.mlp_norm.weight" : torch.Size([2560]),
    "decoder.layers.2.mlp.gate_proj.weight" : torch.Size([9728, 2560]),
    "decoder.layers.2.mlp.up_proj.weight" : torch.Size([9728, 2560]),
    "decoder.layers.2.mlp.down_proj.weight" : torch.Size([2560, 9728]),
    "decoder.layers.3.scale_shift_table" : torch.Size([1, 6, 2560]),
    "decoder.layers.3.self_attn_norm.weight" : torch.Size([2560]),
    "decoder.layers.3.self_attn.q_proj.weight" : torch.Size([4096, 2560]),
    "decoder.layers.3.self_attn.k_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.3.self_attn.v_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.3.self_attn.o_proj.weight" : torch.Size([2560, 4096]),
    "decoder.layers.3.self_attn.q_norm.weight" : torch.Size([128]),
    "decoder.layers.3.self_attn.k_norm.weight" : torch.Size([128]),
    "decoder.layers.3.cross_attn_norm.weight" : torch.Size([2560]),
    "decoder.layers.3.cross_attn.q_proj.weight" : torch.Size([4096, 2560]),
    "decoder.layers.3.cross_attn.k_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.3.cross_attn.v_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.3.cross_attn.o_proj.weight" : torch.Size([2560, 4096]),
    "decoder.layers.3.cross_attn.q_norm.weight" : torch.Size([128]),
    "decoder.layers.3.cross_attn.k_norm.weight" : torch.Size([128]),
    "decoder.layers.3.mlp_norm.weight" : torch.Size([2560]),
    "decoder.layers.3.mlp.gate_proj.weight" : torch.Size([9728, 2560]),
    "decoder.layers.3.mlp.up_proj.weight" : torch.Size([9728, 2560]),
    "decoder.layers.3.mlp.down_proj.weight" : torch.Size([2560, 9728]),
    "decoder.layers.4.scale_shift_table" : torch.Size([1, 6, 2560]),
    "decoder.layers.4.self_attn_norm.weight" : torch.Size([2560]),
    "decoder.layers.4.self_attn.q_proj.weight" : torch.Size([4096, 2560]),
    "decoder.layers.4.self_attn.k_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.4.self_attn.v_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.4.self_attn.o_proj.weight" : torch.Size([2560, 4096]),
    "decoder.layers.4.self_attn.q_norm.weight" : torch.Size([128]),
    "decoder.layers.4.self_attn.k_norm.weight" : torch.Size([128]),
    "decoder.layers.4.cross_attn_norm.weight" : torch.Size([2560]),
    "decoder.layers.4.cross_attn.q_proj.weight" : torch.Size([4096, 2560]),
    "decoder.layers.4.cross_attn.k_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.4.cross_attn.v_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.4.cross_attn.o_proj.weight" : torch.Size([2560, 4096]),
    "decoder.layers.4.cross_attn.q_norm.weight" : torch.Size([128]),
    "decoder.layers.4.cross_attn.k_norm.weight" : torch.Size([128]),
    "decoder.layers.4.mlp_norm.weight" : torch.Size([2560]),
    "decoder.layers.4.mlp.gate_proj.weight" : torch.Size([9728, 2560]),
    "decoder.layers.4.mlp.up_proj.weight" : torch.Size([9728, 2560]),
    "decoder.layers.4.mlp.down_proj.weight" : torch.Size([2560, 9728]),
    "decoder.layers.5.scale_shift_table" : torch.Size([1, 6, 2560]),
    "decoder.layers.5.self_attn_norm.weight" : torch.Size([2560]),
    "decoder.layers.5.self_attn.q_proj.weight" : torch.Size([4096, 2560]),
    "decoder.layers.5.self_attn.k_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.5.self_attn.v_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.5.self_attn.o_proj.weight" : torch.Size([2560, 4096]),
    "decoder.layers.5.self_attn.q_norm.weight" : torch.Size([128]),
    "decoder.layers.5.self_attn.k_norm.weight" : torch.Size([128]),
    "decoder.layers.5.cross_attn_norm.weight" : torch.Size([2560]),
    "decoder.layers.5.cross_attn.q_proj.weight" : torch.Size([4096, 2560]),
    "decoder.layers.5.cross_attn.k_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.5.cross_attn.v_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.5.cross_attn.o_proj.weight" : torch.Size([2560, 4096]),
    "decoder.layers.5.cross_attn.q_norm.weight" : torch.Size([128]),
    "decoder.layers.5.cross_attn.k_norm.weight" : torch.Size([128]),
    "decoder.layers.5.mlp_norm.weight" : torch.Size([2560]),
    "decoder.layers.5.mlp.gate_proj.weight" : torch.Size([9728, 2560]),
    "decoder.layers.5.mlp.up_proj.weight" : torch.Size([9728, 2560]),
    "decoder.layers.5.mlp.down_proj.weight" : torch.Size([2560, 9728]),
    "decoder.layers.6.scale_shift_table" : torch.Size([1, 6, 2560]),
    "decoder.layers.6.self_attn_norm.weight" : torch.Size([2560]),
    "decoder.layers.6.self_attn.q_proj.weight" : torch.Size([4096, 2560]),
    "decoder.layers.6.self_attn.k_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.6.self_attn.v_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.6.self_attn.o_proj.weight" : torch.Size([2560, 4096]),
    "decoder.layers.6.self_attn.q_norm.weight" : torch.Size([128]),
    "decoder.layers.6.self_attn.k_norm.weight" : torch.Size([128]),
    "decoder.layers.6.cross_attn_norm.weight" : torch.Size([2560]),
    "decoder.layers.6.cross_attn.q_proj.weight" : torch.Size([4096, 2560]),
    "decoder.layers.6.cross_attn.k_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.6.cross_attn.v_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.6.cross_attn.o_proj.weight" : torch.Size([2560, 4096]),
    "decoder.layers.6.cross_attn.q_norm.weight" : torch.Size([128]),
    "decoder.layers.6.cross_attn.k_norm.weight" : torch.Size([128]),
    "decoder.layers.6.mlp_norm.weight" : torch.Size([2560]),
    "decoder.layers.6.mlp.gate_proj.weight" : torch.Size([9728, 2560]),
    "decoder.layers.6.mlp.up_proj.weight" : torch.Size([9728, 2560]),
    "decoder.layers.6.mlp.down_proj.weight" : torch.Size([2560, 9728]),
    "decoder.layers.7.scale_shift_table" : torch.Size([1, 6, 2560]),
    "decoder.layers.7.self_attn_norm.weight" : torch.Size([2560]),
    "decoder.layers.7.self_attn.q_proj.weight" : torch.Size([4096, 2560]),
    "decoder.layers.7.self_attn.k_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.7.self_attn.v_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.7.self_attn.o_proj.weight" : torch.Size([2560, 4096]),
    "decoder.layers.7.self_attn.q_norm.weight" : torch.Size([128]),
    "decoder.layers.7.self_attn.k_norm.weight" : torch.Size([128]),
    "decoder.layers.7.cross_attn_norm.weight" : torch.Size([2560]),
    "decoder.layers.7.cross_attn.q_proj.weight" : torch.Size([4096, 2560]),
    "decoder.layers.7.cross_attn.k_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.7.cross_attn.v_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.7.cross_attn.o_proj.weight" : torch.Size([2560, 4096]),
    "decoder.layers.7.cross_attn.q_norm.weight" : torch.Size([128]),
    "decoder.layers.7.cross_attn.k_norm.weight" : torch.Size([128]),
    "decoder.layers.7.mlp_norm.weight" : torch.Size([2560]),
    "decoder.layers.7.mlp.gate_proj.weight" : torch.Size([9728, 2560]),
    "decoder.layers.7.mlp.up_proj.weight" : torch.Size([9728, 2560]),
    "decoder.layers.7.mlp.down_proj.weight" : torch.Size([2560, 9728]),
    "decoder.layers.8.scale_shift_table" : torch.Size([1, 6, 2560]),
    "decoder.layers.8.self_attn_norm.weight" : torch.Size([2560]),
    "decoder.layers.8.self_attn.q_proj.weight" : torch.Size([4096, 2560]),
    "decoder.layers.8.self_attn.k_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.8.self_attn.v_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.8.self_attn.o_proj.weight" : torch.Size([2560, 4096]),
    "decoder.layers.8.self_attn.q_norm.weight" : torch.Size([128]),
    "decoder.layers.8.self_attn.k_norm.weight" : torch.Size([128]),
    "decoder.layers.8.cross_attn_norm.weight" : torch.Size([2560]),
    "decoder.layers.8.cross_attn.q_proj.weight" : torch.Size([4096, 2560]),
    "decoder.layers.8.cross_attn.k_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.8.cross_attn.v_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.8.cross_attn.o_proj.weight" : torch.Size([2560, 4096]),
    "decoder.layers.8.cross_attn.q_norm.weight" : torch.Size([128]),
    "decoder.layers.8.cross_attn.k_norm.weight" : torch.Size([128]),
    "decoder.layers.8.mlp_norm.weight" : torch.Size([2560]),
    "decoder.layers.8.mlp.gate_proj.weight" : torch.Size([9728, 2560]),
    "decoder.layers.8.mlp.up_proj.weight" : torch.Size([9728, 2560]),
    "decoder.layers.8.mlp.down_proj.weight" : torch.Size([2560, 9728]),
    "decoder.layers.9.scale_shift_table" : torch.Size([1, 6, 2560]),
    "decoder.layers.9.self_attn_norm.weight" : torch.Size([2560]),
    "decoder.layers.9.self_attn.q_proj.weight" : torch.Size([4096, 2560]),
    "decoder.layers.9.self_attn.k_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.9.self_attn.v_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.9.self_attn.o_proj.weight" : torch.Size([2560, 4096]),
    "decoder.layers.9.self_attn.q_norm.weight" : torch.Size([128]),
    "decoder.layers.9.self_attn.k_norm.weight" : torch.Size([128]),
    "decoder.layers.9.cross_attn_norm.weight" : torch.Size([2560]),
    "decoder.layers.9.cross_attn.q_proj.weight" : torch.Size([4096, 2560]),
    "decoder.layers.9.cross_attn.k_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.9.cross_attn.v_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.9.cross_attn.o_proj.weight" : torch.Size([2560, 4096]),
    "decoder.layers.9.cross_attn.q_norm.weight" : torch.Size([128]),
    "decoder.layers.9.cross_attn.k_norm.weight" : torch.Size([128]),
    "decoder.layers.9.mlp_norm.weight" : torch.Size([2560]),
    "decoder.layers.9.mlp.gate_proj.weight" : torch.Size([9728, 2560]),
    "decoder.layers.9.mlp.up_proj.weight" : torch.Size([9728, 2560]),
    "decoder.layers.9.mlp.down_proj.weight" : torch.Size([2560, 9728]),
    "decoder.layers.10.scale_shift_table" : torch.Size([1, 6, 2560]),
    "decoder.layers.10.self_attn_norm.weight" : torch.Size([2560]),
    "decoder.layers.10.self_attn.q_proj.weight" : torch.Size([4096, 2560]),
    "decoder.layers.10.self_attn.k_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.10.self_attn.v_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.10.self_attn.o_proj.weight" : torch.Size([2560, 4096]),
    "decoder.layers.10.self_attn.q_norm.weight" : torch.Size([128]),
    "decoder.layers.10.self_attn.k_norm.weight" : torch.Size([128]),
    "decoder.layers.10.cross_attn_norm.weight" : torch.Size([2560]),
    "decoder.layers.10.cross_attn.q_proj.weight" : torch.Size([4096, 2560]),
    "decoder.layers.10.cross_attn.k_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.10.cross_attn.v_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.10.cross_attn.o_proj.weight" : torch.Size([2560, 4096]),
    "decoder.layers.10.cross_attn.q_norm.weight" : torch.Size([128]),
    "decoder.layers.10.cross_attn.k_norm.weight" : torch.Size([128]),
    "decoder.layers.10.mlp_norm.weight" : torch.Size([2560]),
    "decoder.layers.10.mlp.gate_proj.weight" : torch.Size([9728, 2560]),
    "decoder.layers.10.mlp.up_proj.weight" : torch.Size([9728, 2560]),
    "decoder.layers.10.mlp.down_proj.weight" : torch.Size([2560, 9728]),
    "decoder.layers.11.scale_shift_table" : torch.Size([1, 6, 2560]),
    "decoder.layers.11.self_attn_norm.weight" : torch.Size([2560]),
    "decoder.layers.11.self_attn.q_proj.weight" : torch.Size([4096, 2560]),
    "decoder.layers.11.self_attn.k_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.11.self_attn.v_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.11.self_attn.o_proj.weight" : torch.Size([2560, 4096]),
    "decoder.layers.11.self_attn.q_norm.weight" : torch.Size([128]),
    "decoder.layers.11.self_attn.k_norm.weight" : torch.Size([128]),
    "decoder.layers.11.cross_attn_norm.weight" : torch.Size([2560]),
    "decoder.layers.11.cross_attn.q_proj.weight" : torch.Size([4096, 2560]),
    "decoder.layers.11.cross_attn.k_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.11.cross_attn.v_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.11.cross_attn.o_proj.weight" : torch.Size([2560, 4096]),
    "decoder.layers.11.cross_attn.q_norm.weight" : torch.Size([128]),
    "decoder.layers.11.cross_attn.k_norm.weight" : torch.Size([128]),
    "decoder.layers.11.mlp_norm.weight" : torch.Size([2560]),
    "decoder.layers.11.mlp.gate_proj.weight" : torch.Size([9728, 2560]),
    "decoder.layers.11.mlp.up_proj.weight" : torch.Size([9728, 2560]),
    "decoder.layers.11.mlp.down_proj.weight" : torch.Size([2560, 9728]),
    "decoder.layers.12.scale_shift_table" : torch.Size([1, 6, 2560]),
    "decoder.layers.12.self_attn_norm.weight" : torch.Size([2560]),
    "decoder.layers.12.self_attn.q_proj.weight" : torch.Size([4096, 2560]),
    "decoder.layers.12.self_attn.k_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.12.self_attn.v_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.12.self_attn.o_proj.weight" : torch.Size([2560, 4096]),
    "decoder.layers.12.self_attn.q_norm.weight" : torch.Size([128]),
    "decoder.layers.12.self_attn.k_norm.weight" : torch.Size([128]),
    "decoder.layers.12.cross_attn_norm.weight" : torch.Size([2560]),
    "decoder.layers.12.cross_attn.q_proj.weight" : torch.Size([4096, 2560]),
    "decoder.layers.12.cross_attn.k_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.12.cross_attn.v_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.12.cross_attn.o_proj.weight" : torch.Size([2560, 4096]),
    "decoder.layers.12.cross_attn.q_norm.weight" : torch.Size([128]),
    "decoder.layers.12.cross_attn.k_norm.weight" : torch.Size([128]),
    "decoder.layers.12.mlp_norm.weight" : torch.Size([2560]),
    "decoder.layers.12.mlp.gate_proj.weight" : torch.Size([9728, 2560]),
    "decoder.layers.12.mlp.up_proj.weight" : torch.Size([9728, 2560]),
    "decoder.layers.12.mlp.down_proj.weight" : torch.Size([2560, 9728]),
    "decoder.layers.13.scale_shift_table" : torch.Size([1, 6, 2560]),
    "decoder.layers.13.self_attn_norm.weight" : torch.Size([2560]),
    "decoder.layers.13.self_attn.q_proj.weight" : torch.Size([4096, 2560]),
    "decoder.layers.13.self_attn.k_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.13.self_attn.v_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.13.self_attn.o_proj.weight" : torch.Size([2560, 4096]),
    "decoder.layers.13.self_attn.q_norm.weight" : torch.Size([128]),
    "decoder.layers.13.self_attn.k_norm.weight" : torch.Size([128]),
    "decoder.layers.13.cross_attn_norm.weight" : torch.Size([2560]),
    "decoder.layers.13.cross_attn.q_proj.weight" : torch.Size([4096, 2560]),
    "decoder.layers.13.cross_attn.k_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.13.cross_attn.v_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.13.cross_attn.o_proj.weight" : torch.Size([2560, 4096]),
    "decoder.layers.13.cross_attn.q_norm.weight" : torch.Size([128]),
    "decoder.layers.13.cross_attn.k_norm.weight" : torch.Size([128]),
    "decoder.layers.13.mlp_norm.weight" : torch.Size([2560]),
    "decoder.layers.13.mlp.gate_proj.weight" : torch.Size([9728, 2560]),
    "decoder.layers.13.mlp.up_proj.weight" : torch.Size([9728, 2560]),
    "decoder.layers.13.mlp.down_proj.weight" : torch.Size([2560, 9728]),
    "decoder.layers.14.scale_shift_table" : torch.Size([1, 6, 2560]),
    "decoder.layers.14.self_attn_norm.weight" : torch.Size([2560]),
    "decoder.layers.14.self_attn.q_proj.weight" : torch.Size([4096, 2560]),
    "decoder.layers.14.self_attn.k_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.14.self_attn.v_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.14.self_attn.o_proj.weight" : torch.Size([2560, 4096]),
    "decoder.layers.14.self_attn.q_norm.weight" : torch.Size([128]),
    "decoder.layers.14.self_attn.k_norm.weight" : torch.Size([128]),
    "decoder.layers.14.cross_attn_norm.weight" : torch.Size([2560]),
    "decoder.layers.14.cross_attn.q_proj.weight" : torch.Size([4096, 2560]),
    "decoder.layers.14.cross_attn.k_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.14.cross_attn.v_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.14.cross_attn.o_proj.weight" : torch.Size([2560, 4096]),
    "decoder.layers.14.cross_attn.q_norm.weight" : torch.Size([128]),
    "decoder.layers.14.cross_attn.k_norm.weight" : torch.Size([128]),
    "decoder.layers.14.mlp_norm.weight" : torch.Size([2560]),
    "decoder.layers.14.mlp.gate_proj.weight" : torch.Size([9728, 2560]),
    "decoder.layers.14.mlp.up_proj.weight" : torch.Size([9728, 2560]),
    "decoder.layers.14.mlp.down_proj.weight" : torch.Size([2560, 9728]),
    "decoder.layers.15.scale_shift_table" : torch.Size([1, 6, 2560]),
    "decoder.layers.15.self_attn_norm.weight" : torch.Size([2560]),
    "decoder.layers.15.self_attn.q_proj.weight" : torch.Size([4096, 2560]),
    "decoder.layers.15.self_attn.k_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.15.self_attn.v_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.15.self_attn.o_proj.weight" : torch.Size([2560, 4096]),
    "decoder.layers.15.self_attn.q_norm.weight" : torch.Size([128]),
    "decoder.layers.15.self_attn.k_norm.weight" : torch.Size([128]),
    "decoder.layers.15.cross_attn_norm.weight" : torch.Size([2560]),
    "decoder.layers.15.cross_attn.q_proj.weight" : torch.Size([4096, 2560]),
    "decoder.layers.15.cross_attn.k_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.15.cross_attn.v_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.15.cross_attn.o_proj.weight" : torch.Size([2560, 4096]),
    "decoder.layers.15.cross_attn.q_norm.weight" : torch.Size([128]),
    "decoder.layers.15.cross_attn.k_norm.weight" : torch.Size([128]),
    "decoder.layers.15.mlp_norm.weight" : torch.Size([2560]),
    "decoder.layers.15.mlp.gate_proj.weight" : torch.Size([9728, 2560]),
    "decoder.layers.15.mlp.up_proj.weight" : torch.Size([9728, 2560]),
    "decoder.layers.15.mlp.down_proj.weight" : torch.Size([2560, 9728]),
    "decoder.layers.16.scale_shift_table" : torch.Size([1, 6, 2560]),
    "decoder.layers.16.self_attn_norm.weight" : torch.Size([2560]),
    "decoder.layers.16.self_attn.q_proj.weight" : torch.Size([4096, 2560]),
    "decoder.layers.16.self_attn.k_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.16.self_attn.v_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.16.self_attn.o_proj.weight" : torch.Size([2560, 4096]),
    "decoder.layers.16.self_attn.q_norm.weight" : torch.Size([128]),
    "decoder.layers.16.self_attn.k_norm.weight" : torch.Size([128]),
    "decoder.layers.16.cross_attn_norm.weight" : torch.Size([2560]),
    "decoder.layers.16.cross_attn.q_proj.weight" : torch.Size([4096, 2560]),
    "decoder.layers.16.cross_attn.k_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.16.cross_attn.v_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.16.cross_attn.o_proj.weight" : torch.Size([2560, 4096]),
    "decoder.layers.16.cross_attn.q_norm.weight" : torch.Size([128]),
    "decoder.layers.16.cross_attn.k_norm.weight" : torch.Size([128]),
    "decoder.layers.16.mlp_norm.weight" : torch.Size([2560]),
    "decoder.layers.16.mlp.gate_proj.weight" : torch.Size([9728, 2560]),
    "decoder.layers.16.mlp.up_proj.weight" : torch.Size([9728, 2560]),
    "decoder.layers.16.mlp.down_proj.weight" : torch.Size([2560, 9728]),
    "decoder.layers.17.scale_shift_table" : torch.Size([1, 6, 2560]),
    "decoder.layers.17.self_attn_norm.weight" : torch.Size([2560]),
    "decoder.layers.17.self_attn.q_proj.weight" : torch.Size([4096, 2560]),
    "decoder.layers.17.self_attn.k_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.17.self_attn.v_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.17.self_attn.o_proj.weight" : torch.Size([2560, 4096]),
    "decoder.layers.17.self_attn.q_norm.weight" : torch.Size([128]),
    "decoder.layers.17.self_attn.k_norm.weight" : torch.Size([128]),
    "decoder.layers.17.cross_attn_norm.weight" : torch.Size([2560]),
    "decoder.layers.17.cross_attn.q_proj.weight" : torch.Size([4096, 2560]),
    "decoder.layers.17.cross_attn.k_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.17.cross_attn.v_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.17.cross_attn.o_proj.weight" : torch.Size([2560, 4096]),
    "decoder.layers.17.cross_attn.q_norm.weight" : torch.Size([128]),
    "decoder.layers.17.cross_attn.k_norm.weight" : torch.Size([128]),
    "decoder.layers.17.mlp_norm.weight" : torch.Size([2560]),
    "decoder.layers.17.mlp.gate_proj.weight" : torch.Size([9728, 2560]),
    "decoder.layers.17.mlp.up_proj.weight" : torch.Size([9728, 2560]),
    "decoder.layers.17.mlp.down_proj.weight" : torch.Size([2560, 9728]),
    "decoder.layers.18.scale_shift_table" : torch.Size([1, 6, 2560]),
    "decoder.layers.18.self_attn_norm.weight" : torch.Size([2560]),
    "decoder.layers.18.self_attn.q_proj.weight" : torch.Size([4096, 2560]),
    "decoder.layers.18.self_attn.k_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.18.self_attn.v_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.18.self_attn.o_proj.weight" : torch.Size([2560, 4096]),
    "decoder.layers.18.self_attn.q_norm.weight" : torch.Size([128]),
    "decoder.layers.18.self_attn.k_norm.weight" : torch.Size([128]),
    "decoder.layers.18.cross_attn_norm.weight" : torch.Size([2560]),
    "decoder.layers.18.cross_attn.q_proj.weight" : torch.Size([4096, 2560]),
    "decoder.layers.18.cross_attn.k_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.18.cross_attn.v_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.18.cross_attn.o_proj.weight" : torch.Size([2560, 4096]),
    "decoder.layers.18.cross_attn.q_norm.weight" : torch.Size([128]),
    "decoder.layers.18.cross_attn.k_norm.weight" : torch.Size([128]),
    "decoder.layers.18.mlp_norm.weight" : torch.Size([2560]),
    "decoder.layers.18.mlp.gate_proj.weight" : torch.Size([9728, 2560]),
    "decoder.layers.18.mlp.up_proj.weight" : torch.Size([9728, 2560]),
    "decoder.layers.18.mlp.down_proj.weight" : torch.Size([2560, 9728]),
    "decoder.layers.19.scale_shift_table" : torch.Size([1, 6, 2560]),
    "decoder.layers.19.self_attn_norm.weight" : torch.Size([2560]),
    "decoder.layers.19.self_attn.q_proj.weight" : torch.Size([4096, 2560]),
    "decoder.layers.19.self_attn.k_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.19.self_attn.v_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.19.self_attn.o_proj.weight" : torch.Size([2560, 4096]),
    "decoder.layers.19.self_attn.q_norm.weight" : torch.Size([128]),
    "decoder.layers.19.self_attn.k_norm.weight" : torch.Size([128]),
    "decoder.layers.19.cross_attn_norm.weight" : torch.Size([2560]),
    "decoder.layers.19.cross_attn.q_proj.weight" : torch.Size([4096, 2560]),
    "decoder.layers.19.cross_attn.k_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.19.cross_attn.v_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.19.cross_attn.o_proj.weight" : torch.Size([2560, 4096]),
    "decoder.layers.19.cross_attn.q_norm.weight" : torch.Size([128]),
    "decoder.layers.19.cross_attn.k_norm.weight" : torch.Size([128]),
    "decoder.layers.19.mlp_norm.weight" : torch.Size([2560]),
    "decoder.layers.19.mlp.gate_proj.weight" : torch.Size([9728, 2560]),
    "decoder.layers.19.mlp.up_proj.weight" : torch.Size([9728, 2560]),
    "decoder.layers.19.mlp.down_proj.weight" : torch.Size([2560, 9728]),
    "decoder.layers.20.scale_shift_table" : torch.Size([1, 6, 2560]),
    "decoder.layers.20.self_attn_norm.weight" : torch.Size([2560]),
    "decoder.layers.20.self_attn.q_proj.weight" : torch.Size([4096, 2560]),
    "decoder.layers.20.self_attn.k_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.20.self_attn.v_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.20.self_attn.o_proj.weight" : torch.Size([2560, 4096]),
    "decoder.layers.20.self_attn.q_norm.weight" : torch.Size([128]),
    "decoder.layers.20.self_attn.k_norm.weight" : torch.Size([128]),
    "decoder.layers.20.cross_attn_norm.weight" : torch.Size([2560]),
    "decoder.layers.20.cross_attn.q_proj.weight" : torch.Size([4096, 2560]),
    "decoder.layers.20.cross_attn.k_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.20.cross_attn.v_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.20.cross_attn.o_proj.weight" : torch.Size([2560, 4096]),
    "decoder.layers.20.cross_attn.q_norm.weight" : torch.Size([128]),
    "decoder.layers.20.cross_attn.k_norm.weight" : torch.Size([128]),
    "decoder.layers.20.mlp_norm.weight" : torch.Size([2560]),
    "decoder.layers.20.mlp.gate_proj.weight" : torch.Size([9728, 2560]),
    "decoder.layers.20.mlp.up_proj.weight" : torch.Size([9728, 2560]),
    "decoder.layers.20.mlp.down_proj.weight" : torch.Size([2560, 9728]),
    "decoder.layers.21.scale_shift_table" : torch.Size([1, 6, 2560]),
    "decoder.layers.21.self_attn_norm.weight" : torch.Size([2560]),
    "decoder.layers.21.self_attn.q_proj.weight" : torch.Size([4096, 2560]),
    "decoder.layers.21.self_attn.k_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.21.self_attn.v_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.21.self_attn.o_proj.weight" : torch.Size([2560, 4096]),
    "decoder.layers.21.self_attn.q_norm.weight" : torch.Size([128]),
    "decoder.layers.21.self_attn.k_norm.weight" : torch.Size([128]),
    "decoder.layers.21.cross_attn_norm.weight" : torch.Size([2560]),
    "decoder.layers.21.cross_attn.q_proj.weight" : torch.Size([4096, 2560]),
    "decoder.layers.21.cross_attn.k_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.21.cross_attn.v_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.21.cross_attn.o_proj.weight" : torch.Size([2560, 4096]),
    "decoder.layers.21.cross_attn.q_norm.weight" : torch.Size([128]),
    "decoder.layers.21.cross_attn.k_norm.weight" : torch.Size([128]),
    "decoder.layers.21.mlp_norm.weight" : torch.Size([2560]),
    "decoder.layers.21.mlp.gate_proj.weight" : torch.Size([9728, 2560]),
    "decoder.layers.21.mlp.up_proj.weight" : torch.Size([9728, 2560]),
    "decoder.layers.21.mlp.down_proj.weight" : torch.Size([2560, 9728]),
    "decoder.layers.22.scale_shift_table" : torch.Size([1, 6, 2560]),
    "decoder.layers.22.self_attn_norm.weight" : torch.Size([2560]),
    "decoder.layers.22.self_attn.q_proj.weight" : torch.Size([4096, 2560]),
    "decoder.layers.22.self_attn.k_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.22.self_attn.v_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.22.self_attn.o_proj.weight" : torch.Size([2560, 4096]),
    "decoder.layers.22.self_attn.q_norm.weight" : torch.Size([128]),
    "decoder.layers.22.self_attn.k_norm.weight" : torch.Size([128]),
    "decoder.layers.22.cross_attn_norm.weight" : torch.Size([2560]),
    "decoder.layers.22.cross_attn.q_proj.weight" : torch.Size([4096, 2560]),
    "decoder.layers.22.cross_attn.k_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.22.cross_attn.v_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.22.cross_attn.o_proj.weight" : torch.Size([2560, 4096]),
    "decoder.layers.22.cross_attn.q_norm.weight" : torch.Size([128]),
    "decoder.layers.22.cross_attn.k_norm.weight" : torch.Size([128]),
    "decoder.layers.22.mlp_norm.weight" : torch.Size([2560]),
    "decoder.layers.22.mlp.gate_proj.weight" : torch.Size([9728, 2560]),
    "decoder.layers.22.mlp.up_proj.weight" : torch.Size([9728, 2560]),
    "decoder.layers.22.mlp.down_proj.weight" : torch.Size([2560, 9728]),
    "decoder.layers.23.scale_shift_table" : torch.Size([1, 6, 2560]),
    "decoder.layers.23.self_attn_norm.weight" : torch.Size([2560]),
    "decoder.layers.23.self_attn.q_proj.weight" : torch.Size([4096, 2560]),
    "decoder.layers.23.self_attn.k_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.23.self_attn.v_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.23.self_attn.o_proj.weight" : torch.Size([2560, 4096]),
    "decoder.layers.23.self_attn.q_norm.weight" : torch.Size([128]),
    "decoder.layers.23.self_attn.k_norm.weight" : torch.Size([128]),
    "decoder.layers.23.cross_attn_norm.weight" : torch.Size([2560]),
    "decoder.layers.23.cross_attn.q_proj.weight" : torch.Size([4096, 2560]),
    "decoder.layers.23.cross_attn.k_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.23.cross_attn.v_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.23.cross_attn.o_proj.weight" : torch.Size([2560, 4096]),
    "decoder.layers.23.cross_attn.q_norm.weight" : torch.Size([128]),
    "decoder.layers.23.cross_attn.k_norm.weight" : torch.Size([128]),
    "decoder.layers.23.mlp_norm.weight" : torch.Size([2560]),
    "decoder.layers.23.mlp.gate_proj.weight" : torch.Size([9728, 2560]),
    "decoder.layers.23.mlp.up_proj.weight" : torch.Size([9728, 2560]),
    "decoder.layers.23.mlp.down_proj.weight" : torch.Size([2560, 9728]),
    "decoder.layers.24.scale_shift_table" : torch.Size([1, 6, 2560]),
    "decoder.layers.24.self_attn_norm.weight" : torch.Size([2560]),
    "decoder.layers.24.self_attn.q_proj.weight" : torch.Size([4096, 2560]),
    "decoder.layers.24.self_attn.k_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.24.self_attn.v_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.24.self_attn.o_proj.weight" : torch.Size([2560, 4096]),
    "decoder.layers.24.self_attn.q_norm.weight" : torch.Size([128]),
    "decoder.layers.24.self_attn.k_norm.weight" : torch.Size([128]),
    "decoder.layers.24.cross_attn_norm.weight" : torch.Size([2560]),
    "decoder.layers.24.cross_attn.q_proj.weight" : torch.Size([4096, 2560]),
    "decoder.layers.24.cross_attn.k_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.24.cross_attn.v_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.24.cross_attn.o_proj.weight" : torch.Size([2560, 4096]),
    "decoder.layers.24.cross_attn.q_norm.weight" : torch.Size([128]),
    "decoder.layers.24.cross_attn.k_norm.weight" : torch.Size([128]),
    "decoder.layers.24.mlp_norm.weight" : torch.Size([2560]),
    "decoder.layers.24.mlp.gate_proj.weight" : torch.Size([9728, 2560]),
    "decoder.layers.24.mlp.up_proj.weight" : torch.Size([9728, 2560]),
    "decoder.layers.24.mlp.down_proj.weight" : torch.Size([2560, 9728]),
    "decoder.layers.25.scale_shift_table" : torch.Size([1, 6, 2560]),
    "decoder.layers.25.self_attn_norm.weight" : torch.Size([2560]),
    "decoder.layers.25.self_attn.q_proj.weight" : torch.Size([4096, 2560]),
    "decoder.layers.25.self_attn.k_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.25.self_attn.v_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.25.self_attn.o_proj.weight" : torch.Size([2560, 4096]),
    "decoder.layers.25.self_attn.q_norm.weight" : torch.Size([128]),
    "decoder.layers.25.self_attn.k_norm.weight" : torch.Size([128]),
    "decoder.layers.25.cross_attn_norm.weight" : torch.Size([2560]),
    "decoder.layers.25.cross_attn.q_proj.weight" : torch.Size([4096, 2560]),
    "decoder.layers.25.cross_attn.k_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.25.cross_attn.v_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.25.cross_attn.o_proj.weight" : torch.Size([2560, 4096]),
    "decoder.layers.25.cross_attn.q_norm.weight" : torch.Size([128]),
    "decoder.layers.25.cross_attn.k_norm.weight" : torch.Size([128]),
    "decoder.layers.25.mlp_norm.weight" : torch.Size([2560]),
    "decoder.layers.25.mlp.gate_proj.weight" : torch.Size([9728, 2560]),
    "decoder.layers.25.mlp.up_proj.weight" : torch.Size([9728, 2560]),
    "decoder.layers.25.mlp.down_proj.weight" : torch.Size([2560, 9728]),
    "decoder.layers.26.scale_shift_table" : torch.Size([1, 6, 2560]),
    "decoder.layers.26.self_attn_norm.weight" : torch.Size([2560]),
    "decoder.layers.26.self_attn.q_proj.weight" : torch.Size([4096, 2560]),
    "decoder.layers.26.self_attn.k_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.26.self_attn.v_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.26.self_attn.o_proj.weight" : torch.Size([2560, 4096]),
    "decoder.layers.26.self_attn.q_norm.weight" : torch.Size([128]),
    "decoder.layers.26.self_attn.k_norm.weight" : torch.Size([128]),
    "decoder.layers.26.cross_attn_norm.weight" : torch.Size([2560]),
    "decoder.layers.26.cross_attn.q_proj.weight" : torch.Size([4096, 2560]),
    "decoder.layers.26.cross_attn.k_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.26.cross_attn.v_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.26.cross_attn.o_proj.weight" : torch.Size([2560, 4096]),
    "decoder.layers.26.cross_attn.q_norm.weight" : torch.Size([128]),
    "decoder.layers.26.cross_attn.k_norm.weight" : torch.Size([128]),
    "decoder.layers.26.mlp_norm.weight" : torch.Size([2560]),
    "decoder.layers.26.mlp.gate_proj.weight" : torch.Size([9728, 2560]),
    "decoder.layers.26.mlp.up_proj.weight" : torch.Size([9728, 2560]),
    "decoder.layers.26.mlp.down_proj.weight" : torch.Size([2560, 9728]),
    "decoder.layers.27.scale_shift_table" : torch.Size([1, 6, 2560]),
    "decoder.layers.27.self_attn_norm.weight" : torch.Size([2560]),
    "decoder.layers.27.self_attn.q_proj.weight" : torch.Size([4096, 2560]),
    "decoder.layers.27.self_attn.k_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.27.self_attn.v_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.27.self_attn.o_proj.weight" : torch.Size([2560, 4096]),
    "decoder.layers.27.self_attn.q_norm.weight" : torch.Size([128]),
    "decoder.layers.27.self_attn.k_norm.weight" : torch.Size([128]),
    "decoder.layers.27.cross_attn_norm.weight" : torch.Size([2560]),
    "decoder.layers.27.cross_attn.q_proj.weight" : torch.Size([4096, 2560]),
    "decoder.layers.27.cross_attn.k_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.27.cross_attn.v_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.27.cross_attn.o_proj.weight" : torch.Size([2560, 4096]),
    "decoder.layers.27.cross_attn.q_norm.weight" : torch.Size([128]),
    "decoder.layers.27.cross_attn.k_norm.weight" : torch.Size([128]),
    "decoder.layers.27.mlp_norm.weight" : torch.Size([2560]),
    "decoder.layers.27.mlp.gate_proj.weight" : torch.Size([9728, 2560]),
    "decoder.layers.27.mlp.up_proj.weight" : torch.Size([9728, 2560]),
    "decoder.layers.27.mlp.down_proj.weight" : torch.Size([2560, 9728]),
    "decoder.layers.28.scale_shift_table" : torch.Size([1, 6, 2560]),
    "decoder.layers.28.self_attn_norm.weight" : torch.Size([2560]),
    "decoder.layers.28.self_attn.q_proj.weight" : torch.Size([4096, 2560]),
    "decoder.layers.28.self_attn.k_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.28.self_attn.v_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.28.self_attn.o_proj.weight" : torch.Size([2560, 4096]),
    "decoder.layers.28.self_attn.q_norm.weight" : torch.Size([128]),
    "decoder.layers.28.self_attn.k_norm.weight" : torch.Size([128]),
    "decoder.layers.28.cross_attn_norm.weight" : torch.Size([2560]),
    "decoder.layers.28.cross_attn.q_proj.weight" : torch.Size([4096, 2560]),
    "decoder.layers.28.cross_attn.k_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.28.cross_attn.v_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.28.cross_attn.o_proj.weight" : torch.Size([2560, 4096]),
    "decoder.layers.28.cross_attn.q_norm.weight" : torch.Size([128]),
    "decoder.layers.28.cross_attn.k_norm.weight" : torch.Size([128]),
    "decoder.layers.28.mlp_norm.weight" : torch.Size([2560]),
    "decoder.layers.28.mlp.gate_proj.weight" : torch.Size([9728, 2560]),
    "decoder.layers.28.mlp.up_proj.weight" : torch.Size([9728, 2560]),
    "decoder.layers.28.mlp.down_proj.weight" : torch.Size([2560, 9728]),
    "decoder.layers.29.scale_shift_table" : torch.Size([1, 6, 2560]),
    "decoder.layers.29.self_attn_norm.weight" : torch.Size([2560]),
    "decoder.layers.29.self_attn.q_proj.weight" : torch.Size([4096, 2560]),
    "decoder.layers.29.self_attn.k_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.29.self_attn.v_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.29.self_attn.o_proj.weight" : torch.Size([2560, 4096]),
    "decoder.layers.29.self_attn.q_norm.weight" : torch.Size([128]),
    "decoder.layers.29.self_attn.k_norm.weight" : torch.Size([128]),
    "decoder.layers.29.cross_attn_norm.weight" : torch.Size([2560]),
    "decoder.layers.29.cross_attn.q_proj.weight" : torch.Size([4096, 2560]),
    "decoder.layers.29.cross_attn.k_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.29.cross_attn.v_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.29.cross_attn.o_proj.weight" : torch.Size([2560, 4096]),
    "decoder.layers.29.cross_attn.q_norm.weight" : torch.Size([128]),
    "decoder.layers.29.cross_attn.k_norm.weight" : torch.Size([128]),
    "decoder.layers.29.mlp_norm.weight" : torch.Size([2560]),
    "decoder.layers.29.mlp.gate_proj.weight" : torch.Size([9728, 2560]),
    "decoder.layers.29.mlp.up_proj.weight" : torch.Size([9728, 2560]),
    "decoder.layers.29.mlp.down_proj.weight" : torch.Size([2560, 9728]),
    "decoder.layers.30.scale_shift_table" : torch.Size([1, 6, 2560]),
    "decoder.layers.30.self_attn_norm.weight" : torch.Size([2560]),
    "decoder.layers.30.self_attn.q_proj.weight" : torch.Size([4096, 2560]),
    "decoder.layers.30.self_attn.k_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.30.self_attn.v_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.30.self_attn.o_proj.weight" : torch.Size([2560, 4096]),
    "decoder.layers.30.self_attn.q_norm.weight" : torch.Size([128]),
    "decoder.layers.30.self_attn.k_norm.weight" : torch.Size([128]),
    "decoder.layers.30.cross_attn_norm.weight" : torch.Size([2560]),
    "decoder.layers.30.cross_attn.q_proj.weight" : torch.Size([4096, 2560]),
    "decoder.layers.30.cross_attn.k_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.30.cross_attn.v_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.30.cross_attn.o_proj.weight" : torch.Size([2560, 4096]),
    "decoder.layers.30.cross_attn.q_norm.weight" : torch.Size([128]),
    "decoder.layers.30.cross_attn.k_norm.weight" : torch.Size([128]),
    "decoder.layers.30.mlp_norm.weight" : torch.Size([2560]),
    "decoder.layers.30.mlp.gate_proj.weight" : torch.Size([9728, 2560]),
    "decoder.layers.30.mlp.up_proj.weight" : torch.Size([9728, 2560]),
    "decoder.layers.30.mlp.down_proj.weight" : torch.Size([2560, 9728]),
    "decoder.layers.31.scale_shift_table" : torch.Size([1, 6, 2560]),
    "decoder.layers.31.self_attn_norm.weight" : torch.Size([2560]),
    "decoder.layers.31.self_attn.q_proj.weight" : torch.Size([4096, 2560]),
    "decoder.layers.31.self_attn.k_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.31.self_attn.v_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.31.self_attn.o_proj.weight" : torch.Size([2560, 4096]),
    "decoder.layers.31.self_attn.q_norm.weight" : torch.Size([128]),
    "decoder.layers.31.self_attn.k_norm.weight" : torch.Size([128]),
    "decoder.layers.31.cross_attn_norm.weight" : torch.Size([2560]),
    "decoder.layers.31.cross_attn.q_proj.weight" : torch.Size([4096, 2560]),
    "decoder.layers.31.cross_attn.k_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.31.cross_attn.v_proj.weight" : torch.Size([1024, 2560]),
    "decoder.layers.31.cross_attn.o_proj.weight" : torch.Size([2560, 4096]),
    "decoder.layers.31.cross_attn.q_norm.weight" : torch.Size([128]),
    "decoder.layers.31.cross_attn.k_norm.weight" : torch.Size([128]),
    "decoder.layers.31.mlp_norm.weight" : torch.Size([2560]),
    "decoder.layers.31.mlp.gate_proj.weight" : torch.Size([9728, 2560]),
    "decoder.layers.31.mlp.up_proj.weight" : torch.Size([9728, 2560]),
    "decoder.layers.31.mlp.down_proj.weight" : torch.Size([2560, 9728]),
    "decoder.norm_out.weight" : torch.Size([2560]),
    "decoder.proj_out.1.weight" : torch.Size([2560, 64, 2]),
    "decoder.proj_out.1.bias" : torch.Size([64]),
    "encoder.text_projector.weight" : torch.Size([2048, 1024]),
    "encoder.lyric_encoder.embed_tokens.weight" : torch.Size([2048, 1024]),
    "encoder.lyric_encoder.embed_tokens.bias" : torch.Size([2048]),
    "encoder.lyric_encoder.norm.weight" : torch.Size([2048]),
    "encoder.lyric_encoder.layers.0.self_attn.q_proj.weight" : torch.Size([2048, 2048]),
    "encoder.lyric_encoder.layers.0.self_attn.k_proj.weight" : torch.Size([1024, 2048]),
    "encoder.lyric_encoder.layers.0.self_attn.v_proj.weight" : torch.Size([1024, 2048]),
    "encoder.lyric_encoder.layers.0.self_attn.o_proj.weight" : torch.Size([2048, 2048]),
    "encoder.lyric_encoder.layers.0.self_attn.q_norm.weight" : torch.Size([128]),
    "encoder.lyric_encoder.layers.0.self_attn.k_norm.weight" : torch.Size([128]),
    "encoder.lyric_encoder.layers.0.input_layernorm.weight" : torch.Size([2048]),
    "encoder.lyric_encoder.layers.0.post_attention_layernorm.weight" : torch.Size([2048]),
    "encoder.lyric_encoder.layers.0.mlp.gate_proj.weight" : torch.Size([6144, 2048]),
    "encoder.lyric_encoder.layers.0.mlp.up_proj.weight" : torch.Size([6144, 2048]),
    "encoder.lyric_encoder.layers.0.mlp.down_proj.weight" : torch.Size([2048, 6144]),
    "encoder.lyric_encoder.layers.1.self_attn.q_proj.weight" : torch.Size([2048, 2048]),
    "encoder.lyric_encoder.layers.1.self_attn.k_proj.weight" : torch.Size([1024, 2048]),
    "encoder.lyric_encoder.layers.1.self_attn.v_proj.weight" : torch.Size([1024, 2048]),
    "encoder.lyric_encoder.layers.1.self_attn.o_proj.weight" : torch.Size([2048, 2048]),
    "encoder.lyric_encoder.layers.1.self_attn.q_norm.weight" : torch.Size([128]),
    "encoder.lyric_encoder.layers.1.self_attn.k_norm.weight" : torch.Size([128]),
    "encoder.lyric_encoder.layers.1.input_layernorm.weight" : torch.Size([2048]),
    "encoder.lyric_encoder.layers.1.post_attention_layernorm.weight" : torch.Size([2048]),
    "encoder.lyric_encoder.layers.1.mlp.gate_proj.weight" : torch.Size([6144, 2048]),
    "encoder.lyric_encoder.layers.1.mlp.up_proj.weight" : torch.Size([6144, 2048]),
    "encoder.lyric_encoder.layers.1.mlp.down_proj.weight" : torch.Size([2048, 6144]),
    "encoder.lyric_encoder.layers.2.self_attn.q_proj.weight" : torch.Size([2048, 2048]),
    "encoder.lyric_encoder.layers.2.self_attn.k_proj.weight" : torch.Size([1024, 2048]),
    "encoder.lyric_encoder.layers.2.self_attn.v_proj.weight" : torch.Size([1024, 2048]),
    "encoder.lyric_encoder.layers.2.self_attn.o_proj.weight" : torch.Size([2048, 2048]),
    "encoder.lyric_encoder.layers.2.self_attn.q_norm.weight" : torch.Size([128]),
    "encoder.lyric_encoder.layers.2.self_attn.k_norm.weight" : torch.Size([128]),
    "encoder.lyric_encoder.layers.2.input_layernorm.weight" : torch.Size([2048]),
    "encoder.lyric_encoder.layers.2.post_attention_layernorm.weight" : torch.Size([2048]),
    "encoder.lyric_encoder.layers.2.mlp.gate_proj.weight" : torch.Size([6144, 2048]),
    "encoder.lyric_encoder.layers.2.mlp.up_proj.weight" : torch.Size([6144, 2048]),
    "encoder.lyric_encoder.layers.2.mlp.down_proj.weight" : torch.Size([2048, 6144]),
    "encoder.lyric_encoder.layers.3.self_attn.q_proj.weight" : torch.Size([2048, 2048]),
    "encoder.lyric_encoder.layers.3.self_attn.k_proj.weight" : torch.Size([1024, 2048]),
    "encoder.lyric_encoder.layers.3.self_attn.v_proj.weight" : torch.Size([1024, 2048]),
    "encoder.lyric_encoder.layers.3.self_attn.o_proj.weight" : torch.Size([2048, 2048]),
    "encoder.lyric_encoder.layers.3.self_attn.q_norm.weight" : torch.Size([128]),
    "encoder.lyric_encoder.layers.3.self_attn.k_norm.weight" : torch.Size([128]),
    "encoder.lyric_encoder.layers.3.input_layernorm.weight" : torch.Size([2048]),
    "encoder.lyric_encoder.layers.3.post_attention_layernorm.weight" : torch.Size([2048]),
    "encoder.lyric_encoder.layers.3.mlp.gate_proj.weight" : torch.Size([6144, 2048]),
    "encoder.lyric_encoder.layers.3.mlp.up_proj.weight" : torch.Size([6144, 2048]),
    "encoder.lyric_encoder.layers.3.mlp.down_proj.weight" : torch.Size([2048, 6144]),
    "encoder.lyric_encoder.layers.4.self_attn.q_proj.weight" : torch.Size([2048, 2048]),
    "encoder.lyric_encoder.layers.4.self_attn.k_proj.weight" : torch.Size([1024, 2048]),
    "encoder.lyric_encoder.layers.4.self_attn.v_proj.weight" : torch.Size([1024, 2048]),
    "encoder.lyric_encoder.layers.4.self_attn.o_proj.weight" : torch.Size([2048, 2048]),
    "encoder.lyric_encoder.layers.4.self_attn.q_norm.weight" : torch.Size([128]),
    "encoder.lyric_encoder.layers.4.self_attn.k_norm.weight" : torch.Size([128]),
    "encoder.lyric_encoder.layers.4.input_layernorm.weight" : torch.Size([2048]),
    "encoder.lyric_encoder.layers.4.post_attention_layernorm.weight" : torch.Size([2048]),
    "encoder.lyric_encoder.layers.4.mlp.gate_proj.weight" : torch.Size([6144, 2048]),
    "encoder.lyric_encoder.layers.4.mlp.up_proj.weight" : torch.Size([6144, 2048]),
    "encoder.lyric_encoder.layers.4.mlp.down_proj.weight" : torch.Size([2048, 6144]),
    "encoder.lyric_encoder.layers.5.self_attn.q_proj.weight" : torch.Size([2048, 2048]),
    "encoder.lyric_encoder.layers.5.self_attn.k_proj.weight" : torch.Size([1024, 2048]),
    "encoder.lyric_encoder.layers.5.self_attn.v_proj.weight" : torch.Size([1024, 2048]),
    "encoder.lyric_encoder.layers.5.self_attn.o_proj.weight" : torch.Size([2048, 2048]),
    "encoder.lyric_encoder.layers.5.self_attn.q_norm.weight" : torch.Size([128]),
    "encoder.lyric_encoder.layers.5.self_attn.k_norm.weight" : torch.Size([128]),
    "encoder.lyric_encoder.layers.5.input_layernorm.weight" : torch.Size([2048]),
    "encoder.lyric_encoder.layers.5.post_attention_layernorm.weight" : torch.Size([2048]),
    "encoder.lyric_encoder.layers.5.mlp.gate_proj.weight" : torch.Size([6144, 2048]),
    "encoder.lyric_encoder.layers.5.mlp.up_proj.weight" : torch.Size([6144, 2048]),
    "encoder.lyric_encoder.layers.5.mlp.down_proj.weight" : torch.Size([2048, 6144]),
    "encoder.lyric_encoder.layers.6.self_attn.q_proj.weight" : torch.Size([2048, 2048]),
    "encoder.lyric_encoder.layers.6.self_attn.k_proj.weight" : torch.Size([1024, 2048]),
    "encoder.lyric_encoder.layers.6.self_attn.v_proj.weight" : torch.Size([1024, 2048]),
    "encoder.lyric_encoder.layers.6.self_attn.o_proj.weight" : torch.Size([2048, 2048]),
    "encoder.lyric_encoder.layers.6.self_attn.q_norm.weight" : torch.Size([128]),
    "encoder.lyric_encoder.layers.6.self_attn.k_norm.weight" : torch.Size([128]),
    "encoder.lyric_encoder.layers.6.input_layernorm.weight" : torch.Size([2048]),
    "encoder.lyric_encoder.layers.6.post_attention_layernorm.weight" : torch.Size([2048]),
    "encoder.lyric_encoder.layers.6.mlp.gate_proj.weight" : torch.Size([6144, 2048]),
    "encoder.lyric_encoder.layers.6.mlp.up_proj.weight" : torch.Size([6144, 2048]),
    "encoder.lyric_encoder.layers.6.mlp.down_proj.weight" : torch.Size([2048, 6144]),
    "encoder.lyric_encoder.layers.7.self_attn.q_proj.weight" : torch.Size([2048, 2048]),
    "encoder.lyric_encoder.layers.7.self_attn.k_proj.weight" : torch.Size([1024, 2048]),
    "encoder.lyric_encoder.layers.7.self_attn.v_proj.weight" : torch.Size([1024, 2048]),
    "encoder.lyric_encoder.layers.7.self_attn.o_proj.weight" : torch.Size([2048, 2048]),
    "encoder.lyric_encoder.layers.7.self_attn.q_norm.weight" : torch.Size([128]),
    "encoder.lyric_encoder.layers.7.self_attn.k_norm.weight" : torch.Size([128]),
    "encoder.lyric_encoder.layers.7.input_layernorm.weight" : torch.Size([2048]),
    "encoder.lyric_encoder.layers.7.post_attention_layernorm.weight" : torch.Size([2048]),
    "encoder.lyric_encoder.layers.7.mlp.gate_proj.weight" : torch.Size([6144, 2048]),
    "encoder.lyric_encoder.layers.7.mlp.up_proj.weight" : torch.Size([6144, 2048]),
    "encoder.lyric_encoder.layers.7.mlp.down_proj.weight" : torch.Size([2048, 6144]),
    "encoder.timbre_encoder.special_token" : torch.Size([1, 1, 2048]),
    "encoder.timbre_encoder.embed_tokens.weight" : torch.Size([2048, 64]),
    "encoder.timbre_encoder.embed_tokens.bias" : torch.Size([2048]),
    "encoder.timbre_encoder.norm.weight" : torch.Size([2048]),
    "encoder.timbre_encoder.layers.0.self_attn.q_proj.weight" : torch.Size([2048, 2048]),
    "encoder.timbre_encoder.layers.0.self_attn.k_proj.weight" : torch.Size([1024, 2048]),
    "encoder.timbre_encoder.layers.0.self_attn.v_proj.weight" : torch.Size([1024, 2048]),
    "encoder.timbre_encoder.layers.0.self_attn.o_proj.weight" : torch.Size([2048, 2048]),
    "encoder.timbre_encoder.layers.0.self_attn.q_norm.weight" : torch.Size([128]),
    "encoder.timbre_encoder.layers.0.self_attn.k_norm.weight" : torch.Size([128]),
    "encoder.timbre_encoder.layers.0.input_layernorm.weight" : torch.Size([2048]),
    "encoder.timbre_encoder.layers.0.post_attention_layernorm.weight" : torch.Size([2048]),
    "encoder.timbre_encoder.layers.0.mlp.gate_proj.weight" : torch.Size([6144, 2048]),
    "encoder.timbre_encoder.layers.0.mlp.up_proj.weight" : torch.Size([6144, 2048]),
    "encoder.timbre_encoder.layers.0.mlp.down_proj.weight" : torch.Size([2048, 6144]),
    "encoder.timbre_encoder.layers.1.self_attn.q_proj.weight" : torch.Size([2048, 2048]),
    "encoder.timbre_encoder.layers.1.self_attn.k_proj.weight" : torch.Size([1024, 2048]),
    "encoder.timbre_encoder.layers.1.self_attn.v_proj.weight" : torch.Size([1024, 2048]),
    "encoder.timbre_encoder.layers.1.self_attn.o_proj.weight" : torch.Size([2048, 2048]),
    "encoder.timbre_encoder.layers.1.self_attn.q_norm.weight" : torch.Size([128]),
    "encoder.timbre_encoder.layers.1.self_attn.k_norm.weight" : torch.Size([128]),
    "encoder.timbre_encoder.layers.1.input_layernorm.weight" : torch.Size([2048]),
    "encoder.timbre_encoder.layers.1.post_attention_layernorm.weight" : torch.Size([2048]),
    "encoder.timbre_encoder.layers.1.mlp.gate_proj.weight" : torch.Size([6144, 2048]),
    "encoder.timbre_encoder.layers.1.mlp.up_proj.weight" : torch.Size([6144, 2048]),
    "encoder.timbre_encoder.layers.1.mlp.down_proj.weight" : torch.Size([2048, 6144]),
    "encoder.timbre_encoder.layers.2.self_attn.q_proj.weight" : torch.Size([2048, 2048]),
    "encoder.timbre_encoder.layers.2.self_attn.k_proj.weight" : torch.Size([1024, 2048]),
    "encoder.timbre_encoder.layers.2.self_attn.v_proj.weight" : torch.Size([1024, 2048]),
    "encoder.timbre_encoder.layers.2.self_attn.o_proj.weight" : torch.Size([2048, 2048]),
    "encoder.timbre_encoder.layers.2.self_attn.q_norm.weight" : torch.Size([128]),
    "encoder.timbre_encoder.layers.2.self_attn.k_norm.weight" : torch.Size([128]),
    "encoder.timbre_encoder.layers.2.input_layernorm.weight" : torch.Size([2048]),
    "encoder.timbre_encoder.layers.2.post_attention_layernorm.weight" : torch.Size([2048]),
    "encoder.timbre_encoder.layers.2.mlp.gate_proj.weight" : torch.Size([6144, 2048]),
    "encoder.timbre_encoder.layers.2.mlp.up_proj.weight" : torch.Size([6144, 2048]),
    "encoder.timbre_encoder.layers.2.mlp.down_proj.weight" : torch.Size([2048, 6144]),
    "encoder.timbre_encoder.layers.3.self_attn.q_proj.weight" : torch.Size([2048, 2048]),
    "encoder.timbre_encoder.layers.3.self_attn.k_proj.weight" : torch.Size([1024, 2048]),
    "encoder.timbre_encoder.layers.3.self_attn.v_proj.weight" : torch.Size([1024, 2048]),
    "encoder.timbre_encoder.layers.3.self_attn.o_proj.weight" : torch.Size([2048, 2048]),
    "encoder.timbre_encoder.layers.3.self_attn.q_norm.weight" : torch.Size([128]),
    "encoder.timbre_encoder.layers.3.self_attn.k_norm.weight" : torch.Size([128]),
    "encoder.timbre_encoder.layers.3.input_layernorm.weight" : torch.Size([2048]),
    "encoder.timbre_encoder.layers.3.post_attention_layernorm.weight" : torch.Size([2048]),
    "encoder.timbre_encoder.layers.3.mlp.gate_proj.weight" : torch.Size([6144, 2048]),
    "encoder.timbre_encoder.layers.3.mlp.up_proj.weight" : torch.Size([6144, 2048]),
    "encoder.timbre_encoder.layers.3.mlp.down_proj.weight" : torch.Size([2048, 6144]),
    "tokenizer.audio_acoustic_proj.weight" : torch.Size([2048, 64]),
    "tokenizer.audio_acoustic_proj.bias" : torch.Size([2048]),
    "tokenizer.attention_pooler.special_token" : torch.Size([1, 1, 2048]),
    "tokenizer.attention_pooler.embed_tokens.weight" : torch.Size([2048, 2048]),
    "tokenizer.attention_pooler.embed_tokens.bias" : torch.Size([2048]),
    "tokenizer.attention_pooler.norm.weight" : torch.Size([2048]),
    "tokenizer.attention_pooler.layers.0.self_attn.q_proj.weight" : torch.Size([2048, 2048]),
    "tokenizer.attention_pooler.layers.0.self_attn.k_proj.weight" : torch.Size([1024, 2048]),
    "tokenizer.attention_pooler.layers.0.self_attn.v_proj.weight" : torch.Size([1024, 2048]),
    "tokenizer.attention_pooler.layers.0.self_attn.o_proj.weight" : torch.Size([2048, 2048]),
    "tokenizer.attention_pooler.layers.0.self_attn.q_norm.weight" : torch.Size([128]),
    "tokenizer.attention_pooler.layers.0.self_attn.k_norm.weight" : torch.Size([128]),
    "tokenizer.attention_pooler.layers.0.input_layernorm.weight" : torch.Size([2048]),
    "tokenizer.attention_pooler.layers.0.post_attention_layernorm.weight" : torch.Size([2048]),
    "tokenizer.attention_pooler.layers.0.mlp.gate_proj.weight" : torch.Size([6144, 2048]),
    "tokenizer.attention_pooler.layers.0.mlp.up_proj.weight" : torch.Size([6144, 2048]),
    "tokenizer.attention_pooler.layers.0.mlp.down_proj.weight" : torch.Size([2048, 6144]),
    "tokenizer.attention_pooler.layers.1.self_attn.q_proj.weight" : torch.Size([2048, 2048]),
    "tokenizer.attention_pooler.layers.1.self_attn.k_proj.weight" : torch.Size([1024, 2048]),
    "tokenizer.attention_pooler.layers.1.self_attn.v_proj.weight" : torch.Size([1024, 2048]),
    "tokenizer.attention_pooler.layers.1.self_attn.o_proj.weight" : torch.Size([2048, 2048]),
    "tokenizer.attention_pooler.layers.1.self_attn.q_norm.weight" : torch.Size([128]),
    "tokenizer.attention_pooler.layers.1.self_attn.k_norm.weight" : torch.Size([128]),
    "tokenizer.attention_pooler.layers.1.input_layernorm.weight" : torch.Size([2048]),
    "tokenizer.attention_pooler.layers.1.post_attention_layernorm.weight" : torch.Size([2048]),
    "tokenizer.attention_pooler.layers.1.mlp.gate_proj.weight" : torch.Size([6144, 2048]),
    "tokenizer.attention_pooler.layers.1.mlp.up_proj.weight" : torch.Size([6144, 2048]),
    "tokenizer.attention_pooler.layers.1.mlp.down_proj.weight" : torch.Size([2048, 6144]),
    "tokenizer.quantizer.project_in.weight" : torch.Size([6, 2048]),
    "tokenizer.quantizer.project_in.bias" : torch.Size([6]),
    "tokenizer.quantizer.project_out.weight" : torch.Size([2048, 6]),
    "tokenizer.quantizer.project_out.bias" : torch.Size([2048]),
    "detokenizer.special_tokens" : torch.Size([1, 5, 2048]),
    "detokenizer.embed_tokens.weight" : torch.Size([2048, 2048]),
    "detokenizer.embed_tokens.bias" : torch.Size([2048]),
    "detokenizer.layers.0.self_attn.q_proj.weight" : torch.Size([2048, 2048]),
    "detokenizer.layers.0.self_attn.k_proj.weight" : torch.Size([1024, 2048]),
    "detokenizer.layers.0.self_attn.v_proj.weight" : torch.Size([1024, 2048]),
    "detokenizer.layers.0.self_attn.o_proj.weight" : torch.Size([2048, 2048]),
    "detokenizer.layers.0.self_attn.q_norm.weight" : torch.Size([128]),
    "detokenizer.layers.0.self_attn.k_norm.weight" : torch.Size([128]),
    "detokenizer.layers.0.input_layernorm.weight" : torch.Size([2048]),
    "detokenizer.layers.0.post_attention_layernorm.weight" : torch.Size([2048]),
    "detokenizer.layers.0.mlp.gate_proj.weight" : torch.Size([6144, 2048]),
    "detokenizer.layers.0.mlp.up_proj.weight" : torch.Size([6144, 2048]),
    "detokenizer.layers.0.mlp.down_proj.weight" : torch.Size([2048, 6144]),
    "detokenizer.layers.1.self_attn.q_proj.weight" : torch.Size([2048, 2048]),
    "detokenizer.layers.1.self_attn.k_proj.weight" : torch.Size([1024, 2048]),
    "detokenizer.layers.1.self_attn.v_proj.weight" : torch.Size([1024, 2048]),
    "detokenizer.layers.1.self_attn.o_proj.weight" : torch.Size([2048, 2048]),
    "detokenizer.layers.1.self_attn.q_norm.weight" : torch.Size([128]),
    "detokenizer.layers.1.self_attn.k_norm.weight" : torch.Size([128]),
    "detokenizer.layers.1.input_layernorm.weight" : torch.Size([2048]),
    "detokenizer.layers.1.post_attention_layernorm.weight" : torch.Size([2048]),
    "detokenizer.layers.1.mlp.gate_proj.weight" : torch.Size([6144, 2048]),
    "detokenizer.layers.1.mlp.up_proj.weight" : torch.Size([6144, 2048]),
    "detokenizer.layers.1.mlp.down_proj.weight" : torch.Size([2048, 6144]),
    "detokenizer.norm.weight" : torch.Size([2048]),
    "detokenizer.proj_out.weight" : torch.Size([64, 2048]),
    "detokenizer.proj_out.bias" : torch.Size([64]),
}

# Lookup table from a model-architecture label to the reference dict of
# expected state-dict entries for that architecture. Each referenced dict
# maps a parameter key (e.g. "decoder.norm_out.weight") to the torch.Size
# that tensor is expected to have, as chroma_keys_dict does above.
# Insertion order of the labels is significant only if callers iterate it.
state_dict_mapping = dict(
    Chroma=chroma_keys_dict,
    Flux=flux_keys_dict,
    FluxSchnell=flux_schnell_keys_dict,
    ErnieImage=ernie_keys_dict,
    ZImage=zimage_keys_dict,
    Flux2=flux2_keys_dict,
    Anima=anima_keys_dict,
    ACEStep15=ace_step_keys_dict,
)
