#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
#           This file was automatically generated from src/transformers/models/diffusion_gemma/modular_diffusion_gemma.py.
#               Do NOT edit this file manually as any edits will be overwritten by the generation of
#             the file from the modular. If any change should be done, please apply the change to the
#                          modular_diffusion_gemma.py file directly. One of our CI enforces this.
#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# Copyright 2026 the HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections.abc import Callable
from dataclasses import dataclass
from typing import Optional

import torch
from torch import nn

from ... import initialization as init
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...configuration_utils import PreTrainedConfig
from ...integrations import use_experts_implementation
from ...masking_utils import create_causal_mask, create_masks_for_generate, create_sliding_window_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPast,
    BaseModelOutputWithPooling,
    CausalLMOutputWithPast,
)
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import (
    TransformersKwargs,
    auto_docstring,
    can_return_tuple,
    torch_compilable_check,
)
from ...utils.generic import maybe_autocast, merge_with_config_defaults
from ...utils.output_capturing import OutputRecorder, capture_outputs
from ..auto import AutoModel
from .configuration_diffusion_gemma import DiffusionGemmaConfig, DiffusionGemmaTextConfig
from .generation_diffusion_gemma import DiffusionGemmaGenerationConfig, DiffusionGemmaGenerationMixin


class DiffusionGemmaTextRotaryEmbedding(nn.Module):
    """Rotary position embedding holding a separate set of inverse frequencies per attention layer type.

    For each distinct entry of `config.layer_types` whose `config.rope_parameters[layer_type]` is not `None`, the
    constructor registers the non-persistent buffers `{layer_type}_inv_freq` and `{layer_type}_original_inv_freq`
    and sets the attribute `{layer_type}_attention_scaling`. `forward` picks the set matching its `layer_type`.
    """

    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(self, config: DiffusionGemmaTextConfig, device=None, layer_type=None):
        # NOTE: the incoming `layer_type` argument is never read; the loop below rebinds the name for each layer type.
        super().__init__()
        # NOTE(review): not read anywhere in this class; presumably consumed by `dynamic_rope_update` -- confirm.
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        # De-duplicate so that each layer type is initialized exactly once.
        self.layer_types = set(config.layer_types)
        self.rope_init_fns: dict[str, Callable[..., tuple[torch.Tensor, float]]] = {}
        self.rope_type: dict[str, str] = {}

        for layer_type in self.layer_types:
            rope_params = self.config.rope_parameters[layer_type]
            # Layer types without RoPE parameters get no buffers at all.
            if rope_params is None:
                continue

            # "default" is computed by this class; every other type is looked up in the shared registry.
            if (rope_type := rope_params["rope_type"]) != "default":
                rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
            else:
                rope_init_fn = self.compute_default_rope_parameters

            self.rope_init_fns[layer_type] = rope_init_fn
            self.rope_type[layer_type] = rope_type

            rope_init_fn_kwargs = {"device": device, "layer_type": layer_type}
            # Only the full-attention + "proportional" combination is told to read its head size from
            # `global_head_dim`.
            if layer_type == "full_attention" and rope_type == "proportional":
                rope_init_fn_kwargs["head_dim_key"] = "global_head_dim"

            curr_inv_freq, curr_attention_scaling = rope_init_fn(self.config, **rope_init_fn_kwargs)
            # Non-persistent: these buffers are excluded from the state dict.
            self.register_buffer(f"{layer_type}_inv_freq", curr_inv_freq, persistent=False)
            # An independent copy of the initial frequencies is kept next to the working buffer.
            self.register_buffer(f"{layer_type}_original_inv_freq", curr_inv_freq.clone(), persistent=False)
            setattr(self, f"{layer_type}_attention_scaling", curr_attention_scaling)

    @staticmethod
    def compute_default_rope_parameters(
        config: DiffusionGemmaTextConfig | None = None,
        device: Optional["torch.device"] = None,
        seq_len: int | None = None,
        layer_type: str | None = None,
    ) -> tuple["torch.Tensor", float]:
        """
        Computes the inverse frequencies according to the original RoPE implementation
        Args:
            config ([`~transformers.PreTrainedConfig`]):
                The model configuration. Despite the `None` default it is required: it is dereferenced
                unconditionally.
            device (`torch.device`):
                The device to use for initialization of the inverse frequencies.
            seq_len (`int`, *optional*):
                The current sequence length. Unused for this type of RoPE.
            layer_type (`str`, *optional*):
                The current layer type if the model has different RoPE parameters per type.
                Used as the key into `config.rope_parameters`.

        Returns:
            Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
            post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
        """
        # Per-layer-type RoPE base; the rotary dimension falls back to hidden_size // num_attention_heads when the
        # config has no (or a falsy) `head_dim`.
        base = config.rope_parameters[layer_type]["rope_theta"]
        dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads

        attention_factor = 1.0  # Unused in this type of RoPE

        # Compute the inverse frequencies: one entry per pair of rotary dimensions, i.e. 1 / base^(2i / dim)
        inv_freq = 1.0 / (
            base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
        )
        return inv_freq, attention_factor

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids, layer_type=None):
        """Compute the cos/sin tables for `position_ids` using the frequencies registered for `layer_type`.

        Args:
            x: Tensor from which only the device and dtype are read.
            position_ids: Position indices; indexed as a 2-D tensor whose first dimension is the batch.
            layer_type: Name of the layer type whose `{layer_type}_inv_freq` buffer and
                `{layer_type}_attention_scaling` attribute are used.

        Returns:
            Tuple `(cos, sin)`, both cast to `x.dtype`. For 2-D `position_ids` of shape `(batch, seq_len)` each
            has shape `(batch, seq_len, 2 * len(inv_freq))`.
        """
        inv_freq = getattr(self, f"{layer_type}_inv_freq")
        attention_scaling = getattr(self, f"{layer_type}_attention_scaling")

        # (num_freqs,) -> (batch, num_freqs, 1) and (batch, seq_len) -> (batch, 1, seq_len) for the batched matmul
        inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        # "mps" (and any non-string device type) is mapped to "cpu" for the autocast context below.
        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            # The angles are duplicated along the last dim, matching the half-split layout used by `rotate_half`.
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * attention_scaling
            sin = emb.sin() * attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


class DiffusionGemmaRMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension, with an optional learned per-channel scale.

    The computation runs in float32 and the result is cast back to the dtype of the input.

    Args:
        dim (`int`): Size of the normalized (last) dimension; only used to size `weight`.
        eps (`float`, *optional*, defaults to 1e-6): Added to the mean of squares before the inverse square root.
        with_scale (`bool`, *optional*, defaults to `True`): Whether to create and apply the `weight` parameter.
    """

    def __init__(self, dim: int, eps: float = 1e-6, with_scale: bool = True):
        super().__init__()
        self.eps = eps
        self.with_scale = with_scale

        # Without a scale no parameter is created at all, so nothing appears in the state dict.
        if with_scale:
            self.weight = nn.Parameter(torch.ones(dim), requires_grad=True)

    def _norm(self, hidden_states: torch.Tensor):
        """Divide `hidden_states` by the root of (mean of squares over the last dim + eps)."""
        stabilized_mean_square = hidden_states.pow(2).mean(-1, keepdim=True) + self.eps
        # Use torch.pow() (over torch.sqrt() or torch.rsqrt()) to address compiler differences between Torch and JAX
        inverse_rms = torch.pow(stabilized_mean_square, -0.5)
        return hidden_states * inverse_rms

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Normalize `hidden_states` in float32 and return the result in the input dtype."""
        normalized = self._norm(hidden_states.float())
        if not self.with_scale:
            return normalized.type_as(hidden_states)
        return (normalized * self.weight.float()).type_as(hidden_states)


class DiffusionGemmaClippableLinear(nn.Module):
    """Bias-free linear projection that optionally clamps both its input and its output.

    Clamping is enabled by `config.use_clipped_linears`. The clamp ranges are stored in the buffers `input_min`,
    `input_max`, `output_min` and `output_max`, which start out infinite (i.e. clamping changes nothing until
    the buffers hold finite values).
    """

    def __init__(
        self,
        config: PreTrainedConfig,
        in_features: int,
        out_features: int,
    ) -> None:
        super().__init__()
        self.use_clipped_linears = config.use_clipped_linears
        self.linear = nn.Linear(in_features, out_features, bias=False)

        if not self.use_clipped_linears:
            return

        # The range buffers only exist when clipping is enabled.
        initial_bounds = (
            ("input_min", -float("inf")),
            ("input_max", float("inf")),
            ("output_min", -float("inf")),
            ("output_max", float("inf")),
        )
        for buffer_name, bound in initial_bounds:
            self.register_buffer(buffer_name, torch.tensor(bound))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project `hidden_states`, clamping before and after the projection when clipping is enabled."""
        if not self.use_clipped_linears:
            return self.linear(hidden_states)

        clipped_input = torch.clamp(hidden_states, self.input_min, self.input_max)
        projected = self.linear(clipped_input)
        return torch.clamp(projected, self.output_min, self.output_max)


def rotate_half(x):
    """Swap the two halves of the last dimension of `x` and negate the half that ends up first.

    With `x = [x1, x2]` split at `x.shape[-1] // 2`, the result is `[-x2, x1]`.
    """
    split_at = x.shape[-1] // 2
    leading, trailing = x[..., :split_at], x[..., split_at:]
    return torch.cat((-trailing, leading), dim=-1)


def apply_rotary_pos_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, unsqueeze_dim: int = 1):
    """Applies Rotary Position Embedding to a single tensor (called separately for queries and for keys).

    Args:
        x (`torch.Tensor`): The tensor to embed.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The dimension along which `cos` and `sin` are unsqueezed so that they broadcast against `x`. For
            example, if `cos` and `sin` have the shape [batch_size, seq_len, head_dim] and `x` has the shape
            [batch_size, heads, seq_len, head_dim], use `unsqueeze_dim=1`; if `x` has the shape
            [batch_size, seq_len, heads, head_dim], use `unsqueeze_dim=2`.
    Returns:
        `torch.Tensor`: `x` rotated using the Rotary Position Embedding, i.e. `x * cos + rotate_half(x) * sin`.
    """
    cos_broadcast = torch.unsqueeze(cos, unsqueeze_dim)
    sin_broadcast = torch.unsqueeze(sin, unsqueeze_dim)
    half_rotated = rotate_half(x)
    return x * cos_broadcast + half_rotated * sin_broadcast


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeats every key/value head `n_rep` times, equivalent to torch.repeat_interleave(x, dim=1, repeats=n_rep).
    The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to
    (batch, num_key_value_heads * n_rep, seqlen, head_dim). With `n_rep == 1` the input is returned as is.
    """
    bsz, kv_heads, seq_len, dim_per_head = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a repeat axis right after the head axis, broadcast it, then fold it into the head axis.
    repeated = hidden_states.unsqueeze(2).expand(bsz, kv_heads, n_rep, seq_len, dim_per_head)
    return repeated.reshape(bsz, kv_heads * n_rep, seq_len, dim_per_head)


def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: torch.Tensor | None,
    dropout: float | int = 0.0,
    scaling: float | None = None,
    softcap: float | None = None,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Plain PyTorch scaled dot-product attention with optional tanh soft-capping of the scores.

    Args:
        module: The calling attention module; `num_key_value_groups`, `training` and (when `scaling` is `None`)
            `head_dim` are read from it.
        query: Query states with the head axis at dim 1.
        key: Key states; heads are repeated `module.num_key_value_groups` times to match the query heads.
        value: Value states; repeated like `key`.
        attention_mask: Additive mask added to the scores, or `None`.
        dropout: Dropout probability applied to the attention probabilities (active only in training mode).
        scaling: Multiplier for the raw scores; defaults to `module.head_dim**-0.5`.
        softcap: When given, the scores become `tanh(scores / softcap) * softcap` before masking.
        **kwargs: Accepted and ignored.

    Returns:
        Tuple `(attn_output, attn_weights)`; `attn_output` has dims 1 and 2 transposed and is contiguous.
    """
    score_scale = module.head_dim**-0.5 if scaling is None else scaling

    expanded_keys = repeat_kv(key, module.num_key_value_groups)
    expanded_values = repeat_kv(value, module.num_key_value_groups)

    scores = torch.matmul(query, expanded_keys.transpose(2, 3)) * score_scale

    if softcap is not None:
        scores = torch.tanh(scores / softcap) * softcap
    if attention_mask is not None:
        scores = scores + attention_mask

    # The softmax is evaluated in fp32 and cast back to the query dtype.
    probabilities = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    probabilities = nn.functional.dropout(probabilities, p=dropout, training=module.training)

    context = torch.matmul(probabilities, expanded_values).transpose(1, 2).contiguous()
    return context, probabilities


class DiffusionGemmaEncoderTextAttention(nn.Module):
    """Attention layer for the diffusion model's encoder.

    This layer is just like `Gemma4TextAttention`, with one key difference:
    1. Removes shared KV cache logic, as it is unused in DiffusionGemma.
    """

    def __init__(self, config: DiffusionGemmaTextConfig, layer_idx: int):
        super().__init__()
        # Causal unless the config requests bidirectional attention everywhere ("all").
        self.is_causal = config.use_bidirectional_attention != "all"

        self.layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None
        self.config = config
        self.layer_idx = layer_idx
        self.is_sliding = self.layer_type == "sliding_attention"
        self.sliding_window = config.sliding_window if self.is_sliding else None

        # Non-sliding layers use `global_head_dim` (when set) and `num_global_key_value_heads`;
        # sliding layers use `head_dim` and `num_key_value_heads`.
        self.head_dim = config.global_head_dim if not self.is_sliding and config.global_head_dim else config.head_dim
        num_key_value_heads = config.num_global_key_value_heads if not self.is_sliding else config.num_key_value_heads
        self.num_key_value_groups = config.num_attention_heads // num_key_value_heads
        # No extra score scaling is applied by the attention function.
        # NOTE(review): presumably the q/k RMS norms take the role of 1/sqrt(head_dim) -- confirm.
        self.scaling = 1.0
        self.attention_dropout = self.config.attention_dropout

        self.q_proj = nn.Linear(
            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(config.hidden_size, num_key_value_heads * self.head_dim, bias=config.attention_bias)
        # Only sliding layers own a value projection; other layers derive their values from the key projection
        # (see `forward`).
        self.v_proj = (
            nn.Linear(config.hidden_size, num_key_value_heads * self.head_dim, bias=config.attention_bias)
            if self.is_sliding
            else None
        )

        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
        )

        self.q_norm = DiffusionGemmaRMSNorm(dim=self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = DiffusionGemmaRMSNorm(dim=self.head_dim, eps=config.rms_norm_eps)
        # The value norm has no learnable scale.
        self.v_norm = DiffusionGemmaRMSNorm(dim=self.head_dim, eps=config.rms_norm_eps, with_scale=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: torch.Tensor | None,
        past_key_values: Cache | None = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Run self-attention over `hidden_states`, writing the new keys/values into `past_key_values` if given.

        Args:
            hidden_states: Input whose last dimension is `config.hidden_size`.
            position_embeddings: The `(cos, sin)` pair applied to queries and keys.
            attention_mask: Mask forwarded unchanged to the attention implementation.
            past_key_values: Optional cache; when given it is updated for `self.layer_idx` and the key/value
                states it returns are the ones attended to.
            **kwargs: Forwarded to the attention implementation.

        Returns:
            Tuple `(attn_output, attn_weights)` as produced by the attention implementation, with `attn_output`
            passed through `o_proj`.
        """
        # The code in this function is adapted from Gemma4TextAttention. ** The modified parts are clearly indicated **
        input_shape = hidden_states.shape[:-1]
        # (*input_shape, num_heads, head_dim); the head count is inferred from the projection width
        hidden_shape = (*input_shape, -1, self.head_dim)

        cos, sin = position_embeddings

        # Heads sit at dim 2 while norm + RoPE are applied (hence unsqueeze_dim=2); they move to dim 1 afterwards.
        query_states = self.q_proj(hidden_states).view(hidden_shape)
        query_states = self.q_norm(query_states)
        query_states = apply_rotary_pos_emb(query_states, cos, sin, unsqueeze_dim=2)
        query_states = query_states.transpose(1, 2)

        # CHANGED: removed `if self.is_kv_shared_layer` branch, kept the `else`
        key_states = self.k_proj(hidden_states).view(hidden_shape)
        # Without a value projection the values start from the raw key projection (before k_norm and RoPE).
        value_states = self.v_proj(hidden_states).view(hidden_shape) if self.v_proj is not None else key_states

        key_states = self.k_norm(key_states)
        key_states = apply_rotary_pos_emb(key_states, cos, sin, unsqueeze_dim=2)
        key_states = key_states.transpose(1, 2)

        # Values are normalized but receive no rotary embedding.
        value_states = self.v_norm(value_states)
        value_states = value_states.transpose(1, 2)

        if past_key_values is not None:
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
        # CHANGED: removed the `if self.store_full_length_kv` branch

        # Falls back to the eager implementation defined in this file when no other backend is selected.
        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=self.attention_dropout if self.training else 0.0,
            scaling=self.scaling,
            sliding_window=self.sliding_window,
            is_causal=self.is_causal,
            **kwargs,
        )

        # Merge the head dimensions back into one feature dimension before the output projection.
        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights


class DiffusionGemmaDecoderTextAttention(nn.Module):
    """Attention layer for the diffusion model's decoder.

    This layer is just like `Gemma4TextAttention`, with three key differences:
    1. Removes shared KV cache logic, as it is unused in DiffusionGemma.
    2. It doesn't update the KV cache in the forward pass. The KV cache here corresponds to the
       encoder's KV cache, which is passed in via `past_key_values` -- from the decoder's perspective, it can be seen
       as a read-only encoder KV cache.
    3. `self.is_causal` is set to `False`. `config.use_bidirectional_attention` only controls the
       encoder, not the decoder attention.
    """

    def __init__(self, config: DiffusionGemmaTextConfig, layer_idx: int):
        super().__init__()
        self.is_causal = False  # In the decoder, attention is bidirectional!

        self.layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None
        self.config = config
        self.layer_idx = layer_idx
        self.is_sliding = self.layer_type == "sliding_attention"
        self.sliding_window = config.sliding_window if self.is_sliding else None

        # Non-sliding layers use `global_head_dim` (when set) and `num_global_key_value_heads`;
        # sliding layers use `head_dim` and `num_key_value_heads`.
        self.head_dim = config.global_head_dim if not self.is_sliding and config.global_head_dim else config.head_dim
        num_key_value_heads = config.num_global_key_value_heads if not self.is_sliding else config.num_key_value_heads
        self.num_key_value_groups = config.num_attention_heads // num_key_value_heads
        # No extra score scaling is applied by the attention function.
        # NOTE(review): presumably the q/k RMS norms take the role of 1/sqrt(head_dim) -- confirm.
        self.scaling = 1.0
        self.attention_dropout = self.config.attention_dropout

        self.q_proj = nn.Linear(
            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(config.hidden_size, num_key_value_heads * self.head_dim, bias=config.attention_bias)
        # Only sliding layers own a value projection; other layers derive their values from the key projection
        # (see `forward`).
        self.v_proj = (
            nn.Linear(config.hidden_size, num_key_value_heads * self.head_dim, bias=config.attention_bias)
            if self.is_sliding
            else None
        )

        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
        )

        self.q_norm = DiffusionGemmaRMSNorm(dim=self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = DiffusionGemmaRMSNorm(dim=self.head_dim, eps=config.rms_norm_eps)
        # The value norm has no learnable scale.
        self.v_norm = DiffusionGemmaRMSNorm(dim=self.head_dim, eps=config.rms_norm_eps, with_scale=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: torch.Tensor | None,
        past_key_values: Cache | None = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Attend from `hidden_states` to the cached encoder keys/values followed by the current keys/values.

        Args:
            hidden_states: Input whose last dimension is `config.hidden_size`.
            position_embeddings: The `(cos, sin)` pair applied to queries and keys.
            attention_mask: Mask forwarded unchanged to the attention implementation.
            past_key_values: Optional cache that is only read: the keys/values stored for `self.layer_idx` are
                prepended (along dim 2) to the current ones. The cache itself is not modified here.
            **kwargs: Forwarded to the attention implementation.

        Returns:
            Tuple `(attn_output, attn_weights)` as produced by the attention implementation, with `attn_output`
            passed through `o_proj`.
        """
        # The code in this function is adapted from Gemma4TextAttention. ** The modified parts are clearly indicated **
        input_shape = hidden_states.shape[:-1]
        # (*input_shape, num_heads, head_dim); the head count is inferred from the projection width
        hidden_shape = (*input_shape, -1, self.head_dim)

        cos, sin = position_embeddings

        # Heads sit at dim 2 while norm + RoPE are applied (hence unsqueeze_dim=2); they move to dim 1 afterwards.
        query_states = self.q_proj(hidden_states).view(hidden_shape)
        query_states = self.q_norm(query_states)
        query_states = apply_rotary_pos_emb(query_states, cos, sin, unsqueeze_dim=2)
        query_states = query_states.transpose(1, 2)

        # CHANGED: removed `if self.is_kv_shared_layer` branch, kept the `else`
        key_states = self.k_proj(hidden_states).view(hidden_shape)
        # Without a value projection the values start from the raw key projection (before k_norm and RoPE).
        value_states = self.v_proj(hidden_states).view(hidden_shape) if self.v_proj is not None else key_states

        key_states = self.k_norm(key_states)
        key_states = apply_rotary_pos_emb(key_states, cos, sin, unsqueeze_dim=2)
        key_states = key_states.transpose(1, 2)

        # Values are normalized but receive no rotary embedding.
        value_states = self.v_norm(value_states)
        value_states = value_states.transpose(1, 2)

        if past_key_values is not None:
            # CHANGED: instead of calling `past_key_values.update()` which updates the KV cache in-place and returns
            # the full KV states, we first obtain the encoder cache contents, and then append the current KV states.
            encoder_key_states = past_key_values.layers[self.layer_idx].keys
            encoder_value_states = past_key_values.layers[self.layer_idx].values
            # dim=2 is the sequence axis after the transposes above: encoder positions first, current ones after.
            key_states = torch.cat([encoder_key_states, key_states], dim=2)
            value_states = torch.cat([encoder_value_states, value_states], dim=2)
        # CHANGED: removed the `if self.store_full_length_kv` branch

        # Falls back to the eager implementation defined in this file when no other backend is selected.
        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=self.attention_dropout if self.training else 0.0,
            scaling=self.scaling,
            sliding_window=self.sliding_window,
            is_causal=self.is_causal,
            **kwargs,
        )

        # Merge the head dimensions back into one feature dimension before the output projection.
        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights


class DiffusionGemmaText4MLP(nn.Module):
    """Gated feed-forward block computing `down_proj(act_fn(gate_proj(x)) * up_proj(x))` without biases."""

    def __init__(self, config: DiffusionGemmaTextConfig, layer_idx: int):
        # `layer_idx` is accepted for signature compatibility but is not used by this block.
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size

        model_dim, inner_dim = self.hidden_size, self.intermediate_size
        self.gate_proj = nn.Linear(model_dim, inner_dim, bias=False)
        self.up_proj = nn.Linear(model_dim, inner_dim, bias=False)
        self.down_proj = nn.Linear(inner_dim, model_dim, bias=False)
        self.act_fn = ACT2FN[config.hidden_activation]

    def forward(self, x):
        """Apply the gated projection to `x`; the output has the same last dimension as the input."""
        gate = self.act_fn(self.gate_proj(x))
        candidate = self.up_proj(x)
        return self.down_proj(gate * candidate)


class DiffusionGemmaTextRouter(nn.Module):
    """Top-k softmax router that selects experts for every token.

    The input is RMS-normalized (without a learned norm weight), multiplied by the learned per-channel `scale`
    and by `hidden_size**-0.5`, then projected to one score per expert.
    """

    def __init__(self, config: DiffusionGemmaTextConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.scalar_root_size = self.hidden_size**-0.5
        self.eps = config.rms_norm_eps

        self.norm = DiffusionGemmaRMSNorm(self.hidden_size, eps=self.eps, with_scale=False)
        self.proj = nn.Linear(config.hidden_size, config.num_experts, bias=False)
        self.scale = nn.Parameter(torch.ones(self.hidden_size))
        self.per_expert_scale = nn.Parameter(torch.ones(config.num_experts))

    def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Score all experts and pick the top `config.top_k_experts` per token.

        Args:
            hidden_states: Token representations, `[B*S, hidden_size]`.

        Returns:
            Tuple `(router_probabilities, top_k_weights, top_k_index)`:
            - `router_probabilities`: float32 softmax over all experts, `[B*S, E]`.
            - `top_k_weights`: the selected probabilities, renormalized to sum to 1 per token and then multiplied
              by the selected experts' `per_expert_scale`, `[B*S, K]`.
            - `top_k_index`: indices of the selected experts, `[B*S, K]`.
        """
        hidden_states = self.norm(hidden_states)
        hidden_states = hidden_states * self.scale * self.scalar_root_size

        expert_scores = self.proj(hidden_states)  # [B*S, E]
        # TODO(joao): propagate fp32 to gemma4 and delete the modular overwrite in DiffusionGemma
        router_probabilities = nn.functional.softmax(expert_scores, dim=-1, dtype=torch.float32)

        # topk returns both values (probabilities) and indices directly
        top_k_weights, top_k_index = torch.topk(
            router_probabilities,
            k=self.config.top_k_experts,
            dim=-1,
        )  # both [B*S, K]

        # Normalize the top-k weights so they sum to 1 per token (out-of-place, so the `topk` output stays intact)
        top_k_weights = top_k_weights / top_k_weights.sum(dim=-1, keepdim=True)

        # Apply per-expert scale directly to the weights
        top_k_weights = top_k_weights * self.per_expert_scale[top_k_index]

        return router_probabilities, top_k_weights, top_k_index


@use_experts_implementation
class DiffusionGemmaTextExperts(nn.Module):
    """Collection of expert weights stored as 3D tensors.

    `gate_up_proj` holds, per expert, the gate and up projections stacked along the output dimension;
    `down_proj` holds the per-expert output projection.
    """

    def __init__(self, config: DiffusionGemmaTextConfig):
        super().__init__()
        self.num_experts = config.num_experts
        self.hidden_dim = config.hidden_size
        self.intermediate_dim = config.moe_intermediate_size
        # `torch.empty` leaves both tensors uninitialized; they must be filled elsewhere (initialization or
        # checkpoint loading -- not visible in this class).
        self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
        self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
        self.act_fn = ACT2FN[config.hidden_activation]

    def forward(
        self,
        hidden_states: torch.Tensor,
        top_k_index: torch.Tensor,
        top_k_weights: torch.Tensor,
    ) -> torch.Tensor:
        """Run every token through its selected experts and sum the weighted expert outputs.

        Args:
            hidden_states: Flattened tokens, indexed along dim 0 by token (`[num_tokens, hidden_dim]`).
            top_k_index: Selected expert indices per token, `[num_tokens, top_k]`.
            top_k_weights: Weight of each selected expert, `[num_tokens, top_k]`.

        Returns:
            Tensor with the shape and dtype of `hidden_states` holding the weighted sum of expert outputs.
        """
        final_hidden_states = torch.zeros_like(hidden_states)
        # The dispatch bookkeeping is integer-only and needs no gradient.
        with torch.no_grad():
            # [num_tokens, top_k, num_experts] -> [num_experts, top_k, num_tokens]
            expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts)
            expert_mask = expert_mask.permute(2, 1, 0)
            # Indices of the experts that received at least one token
            expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()

        for expert_idx in expert_hit:
            expert_idx = expert_idx[0]
            # NOTE(review): `one_hot` above only yields indices below `num_experts`, so this guard looks
            # unreachable here; presumably kept for parity with other expert implementations -- confirm.
            if expert_idx == self.num_experts:
                continue
            # Which top-k slot, and which token, selected this expert
            top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
            current_state = hidden_states[token_idx]
            # One matmul yields gate and up projections, split in half along the last dimension.
            gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
            current_hidden_states = self.act_fn(gate) * up
            current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
            # Scale each row by the routing weight of this expert for that token (broadcast over features).
            current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
            # Accumulate into the token rows; a token's contributions from different experts add up.
            final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))

        return final_hidden_states


class DiffusionGemmaEncoderTextLayer(GradientCheckpointingLayer):
    """Encoder layer for the diffusion encoder.

    Identical to `Gemma4TextDecoderLayer` except that:
    1. It doesn't have the PLE code path
    2. Doesn't pipe `shared_kv_states` around

    The feed-forward stage runs a dense MLP and a routed mixture-of-experts side by side on the post-attention
    hidden states and sums their (individually normalized) outputs.
    """

    # NOTE(review): `config` is forwarded to submodules annotated with `DiffusionGemmaTextConfig`; the annotation
    # here may be meant to be the text config -- confirm.
    def __init__(self, config: DiffusionGemmaConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.layer_idx = layer_idx
        self.self_attn = DiffusionGemmaEncoderTextAttention(config=config, layer_idx=layer_idx)
        self.mlp = DiffusionGemmaText4MLP(config, layer_idx)
        self.input_layernorm = DiffusionGemmaRMSNorm(self.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = DiffusionGemmaRMSNorm(self.hidden_size, eps=config.rms_norm_eps)
        self.pre_feedforward_layernorm = DiffusionGemmaRMSNorm(self.hidden_size, eps=config.rms_norm_eps)
        self.post_feedforward_layernorm = DiffusionGemmaRMSNorm(self.hidden_size, eps=config.rms_norm_eps)
        # Multiplier applied to the layer output; a buffer (not a parameter) initialized to 1.
        self.register_buffer("layer_scalar", torch.ones(1))

        # Mixture-of-experts branch and its dedicated norms (`_1`: after the dense MLP, `_2`: around the experts)
        self.router = DiffusionGemmaTextRouter(config)
        self.experts = DiffusionGemmaTextExperts(config)
        self.post_feedforward_layernorm_1 = DiffusionGemmaRMSNorm(self.hidden_size, eps=config.rms_norm_eps)
        self.post_feedforward_layernorm_2 = DiffusionGemmaRMSNorm(self.hidden_size, eps=config.rms_norm_eps)
        self.pre_feedforward_layernorm_2 = DiffusionGemmaRMSNorm(self.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        **kwargs,
    ) -> torch.Tensor:
        """Apply self-attention followed by the combined dense-MLP + mixture-of-experts feed-forward stage.

        Args:
            hidden_states: Layer input; its last dimension is the hidden size.
            position_embeddings: The `(cos, sin)` pair handed to the attention module.
            attention_mask: Mask handed to the attention module.
            position_ids: Forwarded to the attention module, where it ends up in its `**kwargs`.
            past_key_values: Cache handed to the attention module.
            **kwargs: Forwarded to the attention module.

        Returns:
            The layer output, with the same shape as `hidden_states`, multiplied by `layer_scalar`.
        """
        residual = hidden_states

        # Attention sub-block: norm -> attention -> norm -> residual add
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states, _ = self.self_attn(
            hidden_states=hidden_states,
            position_embeddings=position_embeddings,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            **kwargs,
        )
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = residual + hidden_states

        # Dense branch: pre-norm -> MLP -> branch-specific post-norm
        residual = hidden_states
        hidden_states = self.pre_feedforward_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states_1 = self.post_feedforward_layernorm_1(hidden_states)

        # Take hidden states before MLP here
        # Tokens are flattened to 2D; the router sees the un-normalized states (it normalizes internally), while
        # the experts receive the `pre_feedforward_layernorm_2` output.
        hidden_states_flat = residual.reshape(-1, residual.shape[-1])
        hidden_states_2_for_routing = hidden_states_flat
        hidden_states_2_for_experts = self.pre_feedforward_layernorm_2(hidden_states_flat)
        _, top_k_weights, top_k_index = self.router(hidden_states_2_for_routing)
        hidden_states_2 = self.experts(hidden_states_2_for_experts, top_k_index, top_k_weights)
        hidden_states_2 = hidden_states_2.reshape(residual.shape)
        hidden_states_2 = self.post_feedforward_layernorm_2(hidden_states_2)

        # Combine mlp and moe outputs
        hidden_states = hidden_states_1 + hidden_states_2

        hidden_states = self.post_feedforward_layernorm(hidden_states)
        hidden_states = residual + hidden_states

        # In-place is safe here: `hidden_states` is the fresh tensor produced by the addition above.
        hidden_states *= self.layer_scalar
        return hidden_states


class DiffusionGemmaDecoderTextLayer(GradientCheckpointingLayer):
    """Decoder layer for the diffusion decoder.

    Identical to `Gemma4TextDecoderLayer` except that:
    1. Uses `DiffusionGemmaDecoderTextAttention`, which reads from the encoder KV cache without updating it
    2. It doesn't have the PLE code path
    3. Doesn't pipe `shared_kv_states` around

    The feed-forward stage runs a dense MLP and a routed mixture-of-experts side by side on the post-attention
    hidden states and sums their (individually normalized) outputs.
    """

    # NOTE(review): `config` is forwarded to submodules annotated with `DiffusionGemmaTextConfig`; the annotation
    # here may be meant to be the text config -- confirm.
    def __init__(self, config: DiffusionGemmaConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.layer_idx = layer_idx
        self.self_attn = DiffusionGemmaDecoderTextAttention(config=config, layer_idx=layer_idx)
        self.mlp = DiffusionGemmaText4MLP(config, layer_idx)
        self.input_layernorm = DiffusionGemmaRMSNorm(self.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = DiffusionGemmaRMSNorm(self.hidden_size, eps=config.rms_norm_eps)
        self.pre_feedforward_layernorm = DiffusionGemmaRMSNorm(self.hidden_size, eps=config.rms_norm_eps)
        self.post_feedforward_layernorm = DiffusionGemmaRMSNorm(self.hidden_size, eps=config.rms_norm_eps)
        # Multiplier applied to the layer output; a buffer (not a parameter) initialized to 1.
        self.register_buffer("layer_scalar", torch.ones(1))

        # Mixture-of-experts branch and its dedicated norms (`_1`: after the dense MLP, `_2`: around the experts)
        self.router = DiffusionGemmaTextRouter(config)
        self.experts = DiffusionGemmaTextExperts(config)
        self.post_feedforward_layernorm_1 = DiffusionGemmaRMSNorm(self.hidden_size, eps=config.rms_norm_eps)
        self.post_feedforward_layernorm_2 = DiffusionGemmaRMSNorm(self.hidden_size, eps=config.rms_norm_eps)
        self.pre_feedforward_layernorm_2 = DiffusionGemmaRMSNorm(self.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        **kwargs,
    ) -> torch.Tensor:
        """Apply self-attention followed by the combined dense-MLP + mixture-of-experts feed-forward stage.

        Args:
            hidden_states: Layer input; its last dimension is the hidden size.
            position_embeddings: The `(cos, sin)` pair handed to the attention module.
            attention_mask: Mask handed to the attention module.
            position_ids: Forwarded to the attention module, where it ends up in its `**kwargs`.
            past_key_values: Cache handed to the attention module, which only reads from it.
            **kwargs: Forwarded to the attention module.

        Returns:
            The layer output, with the same shape as `hidden_states`, multiplied by `layer_scalar`.
        """
        residual = hidden_states

        # Attention sub-block: norm -> attention -> norm -> residual add
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states, _ = self.self_attn(
            hidden_states=hidden_states,
            position_embeddings=position_embeddings,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            **kwargs,
        )
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = residual + hidden_states

        # Dense branch: pre-norm -> MLP -> branch-specific post-norm
        residual = hidden_states
        hidden_states = self.pre_feedforward_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states_1 = self.post_feedforward_layernorm_1(hidden_states)

        # Take hidden states before MLP here
        # Tokens are flattened to 2D; the router sees the un-normalized states (it normalizes internally), while
        # the experts receive the `pre_feedforward_layernorm_2` output.
        hidden_states_flat = residual.reshape(-1, residual.shape[-1])
        hidden_states_2_for_routing = hidden_states_flat
        hidden_states_2_for_experts = self.pre_feedforward_layernorm_2(hidden_states_flat)
        _, top_k_weights, top_k_index = self.router(hidden_states_2_for_routing)
        hidden_states_2 = self.experts(hidden_states_2_for_experts, top_k_index, top_k_weights)
        hidden_states_2 = hidden_states_2.reshape(residual.shape)
        hidden_states_2 = self.post_feedforward_layernorm_2(hidden_states_2)

        # Combine mlp and moe outputs
        hidden_states = hidden_states_1 + hidden_states_2

        hidden_states = self.post_feedforward_layernorm(hidden_states)
        hidden_states = residual + hidden_states

        # In-place is safe here: `hidden_states` is the fresh tensor produced by the addition above.
        hidden_states *= self.layer_scalar
        return hidden_states


class DiffusionGemmaTextScaledWordEmbedding(nn.Embedding):
    """
    An `nn.Embedding` whose looked-up vectors are multiplied by a constant scale factor.

    The factor is stored twice: as the plain Python value `scalar_embed_scale`, and as the non-persistent buffer
    `embed_scale`, which is the tensor `forward` actually multiplies with.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: float = 1.0):
        super().__init__(num_embeddings, embedding_dim, padding_idx)
        self.scalar_embed_scale = embed_scale
        scale_buffer = torch.tensor(embed_scale)
        # Non-persistent: the buffer is excluded from the module's state dict.
        self.register_buffer("embed_scale", scale_buffer, persistent=False)

    def forward(self, input_ids: torch.Tensor):
        looked_up = super().forward(input_ids)
        # The scale is cast to the embedding weight's dtype before the multiplication.
        scale = self.embed_scale.to(self.weight.dtype)
        return looked_up * scale


class DiffusionGemmaMultimodalEmbedder(nn.Module):
    """
    Projects multimodal soft tokens into the hidden space of the language model.

    Incoming features are first passed through an RMS norm (constructed with `with_scale=False`) and then through a
    bias-free linear layer that maps the multimodal width to `text_config.hidden_size`.
    """

    def __init__(
        self,
        multimodal_config: PreTrainedConfig,
        text_config: DiffusionGemmaTextConfig,
    ):
        super().__init__()

        # Input width: `output_proj_dims` when the multimodal config defines it, otherwise its `hidden_size`.
        fallback_width = multimodal_config.hidden_size
        self.multimodal_hidden_size = getattr(multimodal_config, "output_proj_dims", fallback_width)
        self.eps = multimodal_config.rms_norm_eps
        self.text_hidden_size = text_config.hidden_size

        # Sub-modules are registered in the same order as before (projection first, then the norm).
        self.embedding_projection = nn.Linear(self.multimodal_hidden_size, self.text_hidden_size, bias=False)
        self.embedding_pre_projection_norm = DiffusionGemmaRMSNorm(
            self.multimodal_hidden_size, eps=self.eps, with_scale=False
        )

    def forward(self, inputs_embeds: torch.Tensor) -> torch.Tensor:
        """Normalize the soft tokens and project them into the language model space.

        Args:
            inputs_embeds: A torch.Tensor containing the soft tokens to embed.
        Returns:
            A torch.Tensor whose last dimension is the text model's hidden size.
        """
        normalized = self.embedding_pre_projection_norm(inputs_embeds)
        projected = self.embedding_projection(normalized)
        return projected


class DiffusionGemmaSelfConditioning(nn.Module):
    """
    Gated feed-forward block that injects a self-conditioning signal into the decoder inputs.

    The soft-embeddings derived from the previous denoising step are normalized, passed through a gated MLP
    (`down_proj(act_fn(gate_proj(x)) * up_proj(x))`), added to the decoder's input embeddings, and the sum is
    normalized again.
    """

    def __init__(self, config: DiffusionGemmaTextConfig):
        super().__init__()
        model_dim = config.hidden_size
        ffn_dim = config.intermediate_size
        norm_eps = config.rms_norm_eps

        # Norm applied to the incoming signal, and norm applied to the combined output (built with
        # `with_scale=False`).
        self.pre_norm = DiffusionGemmaRMSNorm(model_dim, eps=norm_eps)
        self.post_norm = DiffusionGemmaRMSNorm(model_dim, eps=norm_eps, with_scale=False)

        # Gated MLP projections, all without bias.
        self.gate_proj = nn.Linear(model_dim, ffn_dim, bias=False)
        self.up_proj = nn.Linear(model_dim, ffn_dim, bias=False)
        self.down_proj = nn.Linear(ffn_dim, model_dim, bias=False)
        self.act_fn = ACT2FN[config.hidden_activation]

    def forward(self, inputs_embeds, self_conditioning_signal: torch.Tensor) -> torch.Tensor:
        """
        Args:
            inputs_embeds: Decoder input embeddings the processed signal is added to.
            self_conditioning_signal: Soft-embeddings from previous denoising step
                of shape `(batch_size, canvas_length, hidden_size)`.

        Returns:
            The normalized sum of `inputs_embeds` and the processed self-conditioning signal.
        """
        normalized_signal = self.pre_norm(self_conditioning_signal)
        gate = self.act_fn(self.gate_proj(normalized_signal))
        up = self.up_proj(normalized_signal)
        conditioning = self.down_proj(gate * up)
        return self.post_norm(inputs_embeds + conditioning)


class DiffusionGemmaPreTrainedModel(PreTrainedModel):
    # Shared class-level settings and weight initialization for the DiffusionGemma model classes.
    config: DiffusionGemmaConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = False
    _no_split_modules = [
        "DiffusionGemmaDecoderTextLayer",
        "DiffusionGemmaEncoderTextLayer",
        "DiffusionGemmaVisionEncoderLayer",
    ]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _can_compile_fullgraph = True
    _supports_attention_backend = True
    _can_record_outputs = None  # override
    input_modalities = ("image", "text")

    @torch.no_grad()
    def _init_weights(self, module):
        """
        Initialize `module`: run the parent initialization, then apply the module-specific rule below (if any).

        Modules are matched by `isinstance` where the class is available in this file, and by class-name suffix
        for the vision modules.
        """
        super()._init_weights(module)
        class_name = module.__class__.__name__

        if isinstance(module, DiffusionGemmaTextRotaryEmbedding):
            # One pair of (current, original) inverse-frequency buffers per layer type.
            for layer_type, init_fn in module.rope_init_fns.items():
                init_kwargs = {"layer_type": layer_type}
                needs_global_head_dim = (
                    layer_type == "full_attention" and module.rope_type[layer_type] == "proportional"
                )
                if needs_global_head_dim:
                    init_kwargs["head_dim_key"] = "global_head_dim"

                inv_freq, _ = init_fn(module.config, **init_kwargs)
                for buffer_name in (f"{layer_type}_inv_freq", f"{layer_type}_original_inv_freq"):
                    init.copy_(getattr(module, buffer_name), inv_freq)

        elif isinstance(module, DiffusionGemmaTextScaledWordEmbedding):
            init.constant_(module.embed_scale, module.scalar_embed_scale)
        elif isinstance(module, DiffusionGemmaTextRouter):
            init.ones_(module.scale)
            init.ones_(module.per_expert_scale)
        elif isinstance(module, DiffusionGemmaTextExperts):
            init.normal_(module.gate_up_proj, mean=0.0, std=self.config.initializer_range)
            init.normal_(module.down_proj, mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, DiffusionGemmaDecoderTextLayer):
            init.ones_(module.layer_scalar)
        elif isinstance(module, DiffusionGemmaClippableLinear) and module.use_clipped_linears:
            # Clipping bounds start fully open: (-inf, +inf) on both the input and the output side.
            negative_inf, positive_inf = -float("inf"), float("inf")
            init.constant_(module.input_min, negative_inf)
            init.constant_(module.input_max, positive_inf)
            init.constant_(module.output_min, negative_inf)
            init.constant_(module.output_max, positive_inf)
        # Gemma4 modules' classes won't be correctly expanded with modular, so we match the class name
        # Gemma4VisionPatchEmbedder
        elif class_name.endswith("VisionPatchEmbedder"):
            init.ones_(module.position_embedding_table)
        # Gemma4VisionRotaryEmbedding
        elif class_name.endswith("VisionRotaryEmbedding"):
            if module.rope_type != "default":
                rope_fn = ROPE_INIT_FUNCTIONS[module.rope_type]
            else:
                rope_fn = module.compute_default_rope_parameters
            inv_freq, _ = rope_fn(module.config)
            init.copy_(module.inv_freq, inv_freq)
            init.copy_(module.original_inv_freq, inv_freq)
        # Gemma4VisionModel
        elif class_name.endswith("Gemma4VisionModel") and module.config.standardize:
            init.zeros_(module.std_bias)
            init.ones_(module.std_scale)


class DiffusionGemmaEncoderTextModel(DiffusionGemmaPreTrainedModel):
    """
    Text-only encoder stack: scaled token embeddings, a list of `DiffusionGemmaEncoderTextLayer`, and a final norm.

    The forward pass always works with a cache: when none is passed, a `DynamicCache` is created, and the cache is
    returned in the output.
    """

    config: DiffusionGemmaTextConfig
    input_modalities = ("text",)
    _can_record_outputs = {
        "router_logits": OutputRecorder(DiffusionGemmaTextRouter, index=0),
        "hidden_states": DiffusionGemmaEncoderTextLayer,
        "attentions": DiffusionGemmaEncoderTextAttention,
    }

    def __init__(self, config: DiffusionGemmaTextConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        # DiffusionGemmaEncoder downcasts the below to bfloat16, causing sqrt(3072)=55.4256 to become 55.5. See https://github.com/huggingface/transformers/pull/29402
        embedding_scale = self.config.hidden_size**0.5
        self.embed_tokens = DiffusionGemmaTextScaledWordEmbedding(
            config.vocab_size, config.hidden_size, self.padding_idx, embed_scale=embedding_scale
        )
        encoder_layers = [
            DiffusionGemmaEncoderTextLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)
        ]
        self.layers = nn.ModuleList(encoder_layers)
        self.norm = DiffusionGemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = DiffusionGemmaTextRotaryEmbedding(config)
        self.unique_layer_types = set(config.layer_types)

        # Initialize weights and apply final processing
        self.post_init()

    @merge_with_config_defaults
    @capture_outputs
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | dict | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPast:
        # Exactly one of the two input forms must be given.
        has_input_ids = input_ids is not None
        has_inputs_embeds = inputs_embeds is not None
        if has_input_ids == has_inputs_embeds:
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if has_input_ids:
            inputs_embeds = self.embed_tokens(input_ids)

        if past_key_values is None:
            past_key_values = DynamicCache(config=self.config)

        if position_ids is None:
            # Default positions continue after whatever is already in the cache (never None at this point).
            past_seen_tokens = past_key_values.get_seq_length()
            new_positions = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
            position_ids = new_positions.unsqueeze(0)

        # A dict is taken as an already prepared per-layer-type mask mapping (e.g. from `generate`).
        if isinstance(attention_mask, dict):
            causal_mask_mapping = attention_mask
        else:
            # Prepare mask arguments
            mask_kwargs = {
                "config": self.config,
                "inputs_embeds": inputs_embeds,
                "attention_mask": attention_mask,
                "past_key_values": past_key_values,
                "position_ids": position_ids,
            }
            # Create the masks
            causal_mask_mapping = {
                "full_attention": create_causal_mask(**mask_kwargs),
                "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
            }

        # Rotary position embeddings are computed once per distinct layer type.
        hidden_states = inputs_embeds
        position_embeddings = {
            layer_type: self.rotary_emb(hidden_states, position_ids, layer_type)
            for layer_type in self.unique_layer_types
        }

        # Run the layer stack; each layer picks the mask and position embeddings of its own layer type.
        for i, encoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
            layer_type = self.config.layer_types[i]
            hidden_states = encoder_layer(
                hidden_states,
                position_embeddings=position_embeddings[layer_type],
                attention_mask=causal_mask_mapping[layer_type],
                position_ids=position_ids,
                past_key_values=past_key_values,
                **kwargs,
            )

        return BaseModelOutputWithPast(
            last_hidden_state=self.norm(hidden_states),
            past_key_values=past_key_values,
        )


def get_block_sequence_ids_for_mask(mm_token_type_ids: torch.Tensor, device: torch.device) -> torch.Tensor:
    """
    Assign a block index to every run of consecutive vision tokens.

    Tokens whose type id is 1 or 2 are treated as vision tokens. Each contiguous run of vision tokens along the
    sequence dimension gets an id counting up from 0 per row; every other position gets -1.

    Args:
        mm_token_type_ids: Token type ids; the cumulative sum runs over `dim=1`.
        device: Device the ids are moved to before any computation.

    Returns:
        A tensor with the same shape as `mm_token_type_ids` holding the block id of each position.
    """
    token_types = mm_token_type_ids.to(device)
    vision_positions = (token_types == 1) | (token_types == 2)

    # A block starts at a vision token whose left neighbour is not a vision token. `roll` wraps the last
    # position around to index 0, so that entry is cleared explicitly.
    left_neighbour_is_vision = vision_positions.roll(shifts=1, dims=-1)
    left_neighbour_is_vision[..., 0] = False
    block_starts = vision_positions & ~left_neighbour_is_vision

    # Counting the starts seen so far (minus one) gives the 0-based id of the block a position belongs to.
    running_block_id = block_starts.int().cumsum(dim=1) - 1
    return torch.where(vision_positions, running_block_id, -1)


@auto_docstring(
    custom_intro="""
    The DiffusionGemma encoder model comprising a vision backbone and a language model, *without* a language modeling
    head. It is very similar to Gemma4Model, except that it doesn't support audio or video inputs, and always
    assumes the MoE code path in the inner layers.
    """
)
class DiffusionGemmaEncoderModel(DiffusionGemmaPreTrainedModel):
    # we are filtering the logits/labels so we shouldn't divide the loss based on num_items_in_batch
    accepts_loss_kwargs = False
    config: DiffusionGemmaConfig
    _can_record_outputs = {
        "router_logits": OutputRecorder(DiffusionGemmaTextRouter, index=0),
        "hidden_states": DiffusionGemmaEncoderTextLayer,
        "attentions": DiffusionGemmaEncoderTextAttention,
    }

    def __init__(self, config: DiffusionGemmaConfig):
        super().__init__(config)
        self.vocab_size = config.text_config.vocab_size

        # Text encoder, vision backbone, and the projector that maps vision features into the text hidden space.
        self.language_model = DiffusionGemmaEncoderTextModel(config=config.text_config)
        self.vision_tower = AutoModel.from_config(config.vision_config)
        self.embed_vision = DiffusionGemmaMultimodalEmbedder(config.vision_config, config.text_config)

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @auto_docstring(custom_intro="Projects the last hidden state from the vision model into language model space.")
    def get_image_features(
        self,
        pixel_values: torch.FloatTensor,
        image_position_ids: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPooling:
        r"""
        image_position_ids (`torch.LongTensor` of shape `(batch_size, max_patches, 2)`, *optional*):
            The patch positions as (x, y) coordinates in the image. Padding patches are indicated by (-1, -1).
        """
        vision_outputs = self.vision_tower(
            pixel_values=pixel_values,
            pixel_position_ids=image_position_ids,
            **kwargs,
        )
        last_hidden_state = vision_outputs.last_hidden_state
        # The projected features are stored on the vision output itself, under `pooler_output`.
        vision_outputs.pooler_output = self.embed_vision(inputs_embeds=last_hidden_state)
        return vision_outputs

    def get_placeholder_mask(
        self,
        input_ids: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
    ) -> torch.BoolTensor:
        """
        Obtains the mask of image placeholder positions (the positions later replaced by image soft tokens).

        The mask is obtained from `input_ids` when available, otherwise from `inputs_embeds`.

        Args:
            input_ids: A tensor containing the hard token IDs from the text tokenizer. When set, a position is a
                placeholder if its id equals `config.image_token_id`.
            inputs_embeds: A tensor containing the embeddings for all hard text tokens. Only used when `input_ids`
                is `None`; a position is then a placeholder if its embedding equals, in every component, the input
                embedding of `config.image_token_id`.

        Returns:
            A boolean mask that is `True` at image placeholder positions.
        """
        if input_ids is not None:
            special_image_mask = input_ids == self.config.image_token_id
        else:
            image_token_embeddings = self.get_input_embeddings()(
                torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
            )
            special_image_mask = (inputs_embeds == image_token_embeddings).all(-1)

        return special_image_mask

    @merge_with_config_defaults
    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        pixel_values: torch.FloatTensor | None = None,
        attention_mask: torch.Tensor | dict | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        mm_token_type_ids: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        image_position_ids: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPast:
        r"""
        image_position_ids (`torch.LongTensor` of shape `(batch_size, max_patches, 2)`, *optional*):
            2D patch position coordinates from the image processor, with `(-1, -1)` indicating padding.
            Passed through to the vision encoder for positional embedding computation.
        """
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        image_mask = self.get_placeholder_mask(input_ids, inputs_embeds)

        # Replace the image token id with PAD in case the image token is OOV, to avoid index errors
        llm_input_ids = None
        if inputs_embeds is None:
            llm_input_ids = input_ids.clone()
            llm_input_ids[image_mask] = self.config.text_config.pad_token_id
            inputs_embeds = self.get_input_embeddings()(llm_input_ids)

        # Merge text and images
        if pixel_values is not None:
            image_features = self.get_image_features(pixel_values, image_position_ids, return_dict=True).pooler_output
            image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)

            # Confirm the number of soft tokens from the vision tower matches the number of slots in the embeddings.
            n_image_tokens = image_mask.sum()
            image_mask = image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
            torch_compilable_check(
                inputs_embeds[image_mask].numel() == image_features.numel(),
                f"Image features and image tokens do not match, tokens: {n_image_tokens}, features:"
                f" {image_features.shape[0]}",
            )

            inputs_embeds = inputs_embeds.masked_scatter(
                image_mask.to(inputs_embeds.device), image_features.to(inputs_embeds.device)
            )

        if position_ids is None:
            # Default positions continue after whatever is already cached.
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
            position_ids = position_ids.unsqueeze(0)

        # It may already have been prepared by, e.g., `generate`
        if not isinstance(causal_mask_mapping := attention_mask, dict):
            # The result must be kept: it is the only place where `mm_token_type_ids` is taken into account.
            # Passing the raw `attention_mask` on instead would make the language model rebuild plain causal
            # masks and ignore it.
            causal_mask_mapping = self.create_masks_for_generate(
                config=self.config.get_text_config(),
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                past_key_values=past_key_values,
                position_ids=position_ids,
                mm_token_type_ids=mm_token_type_ids,
            )

        outputs = self.language_model(
            attention_mask=causal_mask_mapping,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            return_dict=True,
            **kwargs,
        )

        return BaseModelOutputWithPast(
            last_hidden_state=outputs.last_hidden_state,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    # NOTE(review): `embed_tokens_per_layer` is not set in `DiffusionGemmaEncoderTextModel.__init__`; unless it is
    # assigned through the setter below first, this getter presumably raises `AttributeError` -- TODO confirm.
    def get_per_layer_input_embeddings(self):
        return self.language_model.embed_tokens_per_layer

    def set_per_layer_input_embeddings(self, value):
        self.language_model.embed_tokens_per_layer = value

    @staticmethod
    def create_masks_for_generate(
        config: PreTrainedConfig,
        inputs_embeds: torch.Tensor,
        attention_mask: torch.Tensor | None,
        past_key_values: Cache | None,
        position_ids: torch.Tensor | None,
        mm_token_type_ids: torch.Tensor | None = None,
    ) -> dict:
        """
        Build the attention masks for the text model, optionally marking vision blocks.

        When the text config's `use_bidirectional_attention` equals `"vision"`, per-position block ids are added to
        the mask arguments: the ids computed from `mm_token_type_ids` when it is given, otherwise -1 everywhere.

        Args:
            config: Model config; its text config is used for mask creation.
            inputs_embeds: Input embeddings of the current forward pass.
            attention_mask: Optional attention mask, forwarded to the mask utility.
            past_key_values: Optional cache, forwarded to the mask utility.
            position_ids: Optional position ids, forwarded to the mask utility.
            mm_token_type_ids: Optional token type ids used to locate vision tokens.

        Returns:
            The result of the module-level `create_masks_for_generate` utility.
        """
        # TODO(joao): this fn exists in a gemma4 class, but not in Gemma4Model. Move it there, and remove the modular
        # overwrite in DiffusionGemma. Also rewrite Gemma4Model to use this function.
        mask_kwargs = {
            "config": config.get_text_config(),
            "inputs_embeds": inputs_embeds,
            "attention_mask": attention_mask,
            "past_key_values": past_key_values,
            "position_ids": position_ids,
        }

        # Larger Gemma 4 models use Gemma 3's bidirectional attention mask for vision inputs
        # Smaller Gemma models use a conventional causal attention mask
        if getattr(config.get_text_config(), "use_bidirectional_attention", None) == "vision":
            block_sequence_ids = torch.full([*inputs_embeds.size()[:-1]], -1, device=inputs_embeds.device)
            if mm_token_type_ids is not None:
                block_sequence_ids = get_block_sequence_ids_for_mask(mm_token_type_ids, device=inputs_embeds.device)

            mask_kwargs["block_sequence_ids"] = block_sequence_ids

        return create_masks_for_generate(**mask_kwargs)


class DiffusionGemmaDecoderModel(DiffusionGemmaPreTrainedModel):
    """
    Decoder model for DiffusionGemma.

    Processes canvas tokens with bidirectional self-attention and cross-attention to the encoder's KV cache.
    The decoder reads but does not update the KV cache. Excluding these differences, it is similar to
    `DiffusionGemmaEncoderTextModel`, and they share all weights they have in common.
    """

    config: DiffusionGemmaConfig
    input_modalities = ("text",)
    _can_record_outputs = {
        "router_logits": OutputRecorder(DiffusionGemmaTextRouter, index=0),
        "hidden_states": DiffusionGemmaDecoderTextLayer,
        "attentions": DiffusionGemmaDecoderTextAttention,
    }

    def __init__(self, config: DiffusionGemmaConfig):
        super().__init__(config)
        # The decoder is built from the full config, but all of its sub-modules only use the text sub-config.
        self.text_config = config.text_config
        self.padding_idx = config.text_config.pad_token_id
        self.vocab_size = config.text_config.vocab_size

        # Token embeddings, scaled by sqrt(hidden_size).
        self.embed_tokens = DiffusionGemmaTextScaledWordEmbedding(
            num_embeddings=config.text_config.vocab_size,
            embedding_dim=config.text_config.hidden_size,
            padding_idx=self.padding_idx,
            embed_scale=config.text_config.hidden_size**0.5,
        )
        self.layers = nn.ModuleList(
            [
                DiffusionGemmaDecoderTextLayer(config.text_config, layer_idx)
                for layer_idx in range(config.text_config.num_hidden_layers)
            ]
        )
        self.norm = DiffusionGemmaRMSNorm(config.text_config.hidden_size, eps=config.text_config.rms_norm_eps)
        self.rotary_emb = DiffusionGemmaTextRotaryEmbedding(config.text_config)
        self.self_conditioning = DiffusionGemmaSelfConditioning(config.text_config)
        # Distinct layer types; rotary embeddings are computed once per type in `forward`.
        self.unique_layer_types = set(config.text_config.layer_types)

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @capture_outputs
    @auto_docstring
    def forward(
        self,
        decoder_input_ids: torch.LongTensor,
        past_key_values: Cache | None = None,
        self_conditioning_logits: torch.FloatTensor | None = None,
        self_conditioning_mask: torch.BoolTensor | None = None,
        decoder_attention_mask: torch.Tensor | dict | None = None,
        decoder_position_ids: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutput:
        r"""
        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, canvas_length)`):
            Token IDs for the canvas to be refined.
        self_conditioning_logits (`torch.FloatTensor` of shape `(batch_size, canvas_length, vocab_size)`, *optional*):
            Self-conditioning logits from the previous denoising step, used to compute the
            self-conditioning embeddings.
        self_conditioning_mask (`torch.BoolTensor` of shape `(batch_size,)`, *optional*):
            Per-example mask over `self_conditioning_logits`: examples set to `False` get a zeroed self-conditioning
            signal, as if no logits were passed for them. Useful for training, where self-conditioning is enabled per
            example with some probability.
        decoder_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length+canvas_length)` or `dict`, *optional*):
            Attention mask for the decoder KV cache. Used to specify padded/unpopulated encoder KV cached entries.
        decoder_position_ids (`torch.LongTensor` of shape `(batch_size, canvas_length)`, *optional*):
            The position IDs for the tokens in the canvas.
        """
        # Rejected on presence alone, whatever the value is.
        if "use_cache" in kwargs:
            raise ValueError(
                "The decoder of DiffusionGemma always uses a cache, so it doesn't accept the `use_cache` argument"
            )

        inputs_embeds = self.embed_tokens(decoder_input_ids)

        # If no self-conditioning signal is passed, the self-conditioning embeddings should be set to zeros.
        # This corresponds to the first denoising step.
        if self_conditioning_logits is not None:
            # Soft embeddings: the softmax (computed in float32) weights the rows of the embedding matrix, and the
            # result gets the same scale factor as the hard token embeddings.
            soft_embeddings = torch.matmul(
                self_conditioning_logits.softmax(dim=-1, dtype=torch.float32).to(self.embed_tokens.weight.dtype),
                self.embed_tokens.weight,
            ) * self.embed_tokens.embed_scale.to(inputs_embeds.dtype)
            if self_conditioning_mask is not None:
                # Broadcast the per-example mask over the canvas and hidden dimensions.
                soft_embeddings = soft_embeddings * self_conditioning_mask.to(soft_embeddings.dtype)[:, None, None]
        else:
            soft_embeddings = torch.zeros_like(inputs_embeds)
        inputs_embeds = self.self_conditioning(inputs_embeds, soft_embeddings)

        # The decoder positions continue after the encoder sequence. These are the position ids to be used in the
        # canvas.
        if decoder_position_ids is None:
            canvas_length = inputs_embeds.shape[1]
            cache_seq_length = past_key_values.get_seq_length(layer_idx=0) if past_key_values is not None else 0
            decoder_position_ids = torch.arange(
                cache_seq_length,
                cache_seq_length + canvas_length,
                device=inputs_embeds.device,
                dtype=torch.long,
            )
            decoder_position_ids = decoder_position_ids.unsqueeze(0)

        # A dict is used as is, as an already prepared per-layer-type mask mapping. Anything else (a 2D mask or
        # `None`) goes through `create_diffusion_decoder_attention_mask`, which raises if there is no cache.
        if not isinstance(mask_mapping := decoder_attention_mask, dict):
            mask_mapping = self.create_diffusion_decoder_attention_mask(
                config=self.text_config,
                inputs_embeds=inputs_embeds,
                past_key_values=past_key_values,
                decoder_attention_mask=decoder_attention_mask,
            )

        # Embed positions
        hidden_states = inputs_embeds
        position_embeddings = {}
        for layer_type in self.unique_layer_types:
            position_embeddings[layer_type] = self.rotary_emb(hidden_states, decoder_position_ids, layer_type)

        # Each layer picks the mask and position embeddings that match its own layer type.
        for i, decoder_layer in enumerate(self.layers[: self.text_config.num_hidden_layers]):
            hidden_states = decoder_layer(
                hidden_states,
                position_embeddings=position_embeddings[self.text_config.layer_types[i]],
                attention_mask=mask_mapping[self.text_config.layer_types[i]],
                position_ids=decoder_position_ids,
                past_key_values=past_key_values,
                **kwargs,
            )

        hidden_states = self.norm(hidden_states)

        # No past_key_values in the output: the decoder doesn't produce a KV cache
        return BaseModelOutput(last_hidden_state=hidden_states)

    @staticmethod
    def create_diffusion_decoder_attention_mask(
        config: DiffusionGemmaTextConfig,
        inputs_embeds: torch.Tensor,
        past_key_values: Cache,
        decoder_attention_mask: torch.Tensor | dict | None = None,
    ) -> dict[str, torch.Tensor | None]:
        """
        Creates the bidirectional attention mask for the decoder model.

        The decoder mask must have the length of the encoder kv cache plus the canvas being denoised, and it is
        bidirectional. The part of the attention mask corresponding to the encoder kv cache works like a usual
        bidirectional mask for an AR model -- it might be left or right padded. However, the part of the mask
        corresponding to the canvas is *always* set to 1.

        > [!TIP]
        > If `decoder_attention_mask` is manually set, be sure to follow the following practices:
        > 1. It has shape `(batch_size, sequence_length+canvas_length)`;
        > 2. The attention in the last `canvas_length` positions is set to 1s.

        A complex example:
        Let's consider a static-shaped KV cache with batch size = 2. One of the entries is left-padded, because
        it's shorter than the other. In our example, the canvas has a length of 4 tokens. Our cache has a length of 8
        tokens, and is pre-populated -- one of the sequences has 4 cached tokens, the other has 2 cached tokens
        (meaning that it has 2 left-padding tokens). Both sequences will have 4 empty positions in their cache.
        The produced attention mask corresponding to the encoder kv cache should be as follows

        indexing key: [batch_idx, canvas_idx]; shown dimension: kv attention
        [0, 0] ■ ■ ■ ■ ⬚ ⬚ ⬚ ⬚
        [0, 1] ■ ■ ■ ■ ⬚ ⬚ ⬚ ⬚
        [0, 2] ■ ■ ■ ■ ⬚ ⬚ ⬚ ⬚
        [0, 3] ■ ■ ■ ■ ⬚ ⬚ ⬚ ⬚
        [1, 0] ⬚ ⬚ ■ ■ ⬚ ⬚ ⬚ ⬚
        [1, 1] ⬚ ⬚ ■ ■ ⬚ ⬚ ⬚ ⬚
        [1, 2] ⬚ ⬚ ■ ■ ⬚ ⬚ ⬚ ⬚
        [1, 3] ⬚ ⬚ ■ ■ ⬚ ⬚ ⬚ ⬚

        In other words, the canvas will be able to attend to all non-padding and non-empty kv cache positions.
        To complete the attention mask, we add a bidirectional attention to the canvas tokens, resulting in the
        following final attention mask

        indexing key: [batch_idx, canvas_idx]; shown dimension: kv attention
        [0, 0] ■ ■ ■ ■ ⬚ ⬚ ⬚ ⬚ ■ ■ ■ ■
        [0, 1] ■ ■ ■ ■ ⬚ ⬚ ⬚ ⬚ ■ ■ ■ ■
        [0, 2] ■ ■ ■ ■ ⬚ ⬚ ⬚ ⬚ ■ ■ ■ ■
        [0, 3] ■ ■ ■ ■ ⬚ ⬚ ⬚ ⬚ ■ ■ ■ ■
        [1, 0] ⬚ ⬚ ■ ■ ⬚ ⬚ ⬚ ⬚ ■ ■ ■ ■
        [1, 1] ⬚ ⬚ ■ ■ ⬚ ⬚ ⬚ ⬚ ■ ■ ■ ■
        [1, 2] ⬚ ⬚ ■ ■ ⬚ ⬚ ⬚ ⬚ ■ ■ ■ ■
        [1, 3] ⬚ ⬚ ■ ■ ⬚ ⬚ ⬚ ⬚ ■ ■ ■ ■

        As a result, the canvas tokens for each batch index can attend to themselves, as well as to valid entries
        in the corresponding encoder kv cache.

        For more examples, see the tests for this function.

        Args:
            config (`DiffusionGemmaTextConfig`):
                The config used by the model.
            inputs_embeds (`torch.Tensor` of shape `(batch_size, canvas_length, hidden_dimension)`):
                The input embeddings used in the current forward pass. Only used to obtain the first two dimensions.
            past_key_values (`Cache`):
                The cache produced by the encoder part of the model.
            decoder_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length+canvas_length)` or `dict`, *optional*):
                Attention mask for the decoder KV cache. Used to specify padded/unpopulated encoder KV cached entries.

        Returns:
            `dict[str, torch.Tensor | None]`: A mapping with the keys `"full_attention"` and `"sliding_attention"`.
            Both values are `None` when no mask is passed, or when the cache is not compileable and the mask is all
            1s. Otherwise both are boolean masks: the full mask has shape
            `(batch_size, 1, canvas_length, cache_length + canvas_length)`, and the sliding mask is a slice of its
            cache part followed by `canvas_length` entries set to `True`.

        Raises:
            ValueError: If `past_key_values` is `None`; if the cache is compileable and `decoder_attention_mask` is
                not set; if `decoder_attention_mask` does not have shape `(batch_size, cache length + canvas length)`;
                or if any of its rows has more 1s than there are cached + canvas tokens.
        """

        # NOTE: common mask utilities like `create_bidirectional_mask` are NOT used here, as they contain a few subtle
        # AR assumptions. Example: in sliding window mask preparation, we consider a KV with length
        # `sliding_window - 1 + query_length`, where we want `sliding_window + query_length`
        # (https://github.com/huggingface/transformers/blame/b75feb2af64c3e29cbbc1bd859958c5432cc7ed4/src/transformers/cache_utils.py#L249)

        batch_size, canvas_length, _ = inputs_embeds.shape

        if past_key_values is None:
            raise ValueError(
                "`past_key_values` must be a `Cache` instance in `create_diffusion_decoder_attention_mask`."
            )
        if past_key_values.is_compileable and decoder_attention_mask is None:
            raise ValueError(
                "When `past_key_values` is a compileable cache, i.e. a static-shaped cache, `decoder_attention_mask` "
                "must be set."
            )
        # Shortcut: not compiling for sure AND no padding -> delegate mask creation to the inner functions by returning None
        # NOTE: this returns before the shape validation below, so an all-1s mask is not shape-checked here.
        if decoder_attention_mask is None or (not past_key_values.is_compileable and decoder_attention_mask.all()):
            return {"full_attention": None, "sliding_attention": None}

        # If we reach this point, we have padding and/or we may want to compile the forward pass. In either case, we
        # materialize the full mask.
        # - Full attention mask: built from the `decoder_attention_mask` input (if unset, then it's all 1s).
        # - Sliding attention mask: built from full attention mask, taking a slice of the attention mask based on the
        #   filled cache positions, plus the canvas attention
        valid_cache_tokens = past_key_values.get_seq_length()
        # The cache part of the mask spans the cache's maximum length for a compileable cache, and only the tokens
        # seen so far otherwise.
        if past_key_values.is_compileable:
            full_cache_kv_length = past_key_values.max_cache_len
        else:
            full_cache_kv_length = valid_cache_tokens
        full_kv_length = full_cache_kv_length + canvas_length
        if decoder_attention_mask.shape != (batch_size, full_kv_length):
            raise ValueError(
                "When set, `decoder_attention_mask` must have the length = cache length + canvas length."
                f" Got `decoder_attention_mask` with length {decoder_attention_mask.shape[1]} "
                f"(!= {full_cache_kv_length} + {canvas_length})"
            )
        # Sanity check: a row cannot attend to more positions than there are cached + canvas tokens.
        if (decoder_attention_mask.sum(dim=-1) > valid_cache_tokens + canvas_length).any():
            raise ValueError(
                "Your `decoder_attention_mask` has more 1s than there are cached + canvas tokens. "
                "There is one or more rows in the `decoder_attention_mask` with "
                f"{decoder_attention_mask.sum(dim=-1).max()} 1s, while there are at most "
                f"{valid_cache_tokens + canvas_length} tokens to be processed in each "
                "row. If you're using a static cache, don't forget to set empty positions to 0."
            )

        # 2D [batch_size, full_kv_length] -> 4D [batch_size, 1, query_length, full_kv_length]
        full_mask = decoder_attention_mask[:, None, None, :].bool()
        full_mask = full_mask.expand(batch_size, 1, canvas_length, full_kv_length)

        # Sliding window: first take the right slice of the full mask
        sliding_cache_is_full = valid_cache_tokens >= config.sliding_window
        if sliding_cache_is_full:
            # NOTE: currently, the compiled sliding window cache layer is 1 element longer than the non-compiled case.
            # This means that we technically have a slightly different implementation with compilable caches, where
            # the decoder sees one extra token.
            if past_key_values.is_compileable:
                sliding_start_idx = valid_cache_tokens - config.sliding_window
            else:
                sliding_start_idx = valid_cache_tokens - config.sliding_window + 1
            sliding_end_idx = valid_cache_tokens
        else:
            sliding_start_idx = 0
            if past_key_values.is_compileable:
                sliding_end_idx = min(config.sliding_window, past_key_values.max_cache_len)
            else:
                sliding_end_idx = valid_cache_tokens
        # The slice covers cache positions only. Any canvas columns of `full_mask` inside the range are cut off by
        # the padding step below, which re-adds the canvas part as all `True`.
        sliding_mask = full_mask[..., sliding_start_idx:sliding_end_idx]
        # Then append the canvas bidirectional mask
        sliding_mask = torch.nn.functional.pad(sliding_mask, (0, canvas_length), value=True)

        return {"full_attention": full_mask, "sliding_attention": sliding_mask}


@auto_docstring
@dataclass
class DiffusionGemmaModelOutputWithPast(BaseModelOutputWithPast):
    r"""
    encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
        Sequence of hidden states at the output of the last layer of the encoder. Only set when `input_ids` is
        provided, e.g. to compute an autoregressive loss on the encoder during training.
    """

    # The only field added on top of `BaseModelOutputWithPast`. `DiffusionGemmaModel.forward` fills it with the
    # encoder's `last_hidden_state`, and leaves it as `None` when no `input_ids` are encoded (cache-only call).
    # NOTE(review): the docstring above is presumably parsed by `@auto_docstring`; keep its
    # `name (type): description` layout when editing -- confirm against the decorator's implementation.
    encoder_last_hidden_state: torch.FloatTensor | None = None


@auto_docstring
@dataclass
class DiffusionGemmaBlockDiffusionOutputWithPast(CausalLMOutputWithPast):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*):
        Language modeling loss.
    logits (`torch.FloatTensor` of shape `(batch_size, canvas_length, config.text_config.vocab_size)`):
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
    encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
        Sequence of hidden states at the output of the last layer of the encoder. Only set when `input_ids` is
        provided, e.g. to compute an autoregressive loss on the encoder during training.
    """

    # The only field declared here; `loss` and `logits` documented above are not redeclared in this class, so they
    # come from the `CausalLMOutputWithPast` base. `DiffusionGemmaForBlockDiffusion.forward` copies this value
    # straight from the wrapped model's output.
    encoder_last_hidden_state: torch.FloatTensor | None = None


@auto_docstring
class DiffusionGemmaModel(DiffusionGemmaPreTrainedModel):
    """
    DiffusionGemma model consisting of an auto-regressive encoder (DiffusionGemmaEncoderModel, very similar to a
    Gemma4Model), and a diffusion decoder (DiffusionGemmaDecoderModel).

    NOTE: contrary to most encoder-decoder models, where the encoder feeds its hidden states to the decoder, here the
    encoder only feeds its KV cache to the decoder. From the decoder's perspective, the KV cache is read-only.
    """

    # All weights in the text part of the encoder are present in the decoder. However, only the decoder has the
    # self-conditioning layers. At the time of writing, HF code assumes only weights can be tied.
    _tied_weights_keys = {
        "encoder.language_model.norm.weight": "decoder.norm.weight",
        # The lines below are equivalent to `"encoder.language_model.layers": "decoder.layers"`, but don't tie buffers
        # (see comment above).
        r"encoder.language_model.layers\.(?:[^.]+\.)*weight": r"decoder.layers\.(?:[^.]+\.)*weight",
        r"encoder.language_model.layers\.(?:[^.]+\.)*scale": r"decoder.layers\.(?:[^.]+\.)*scale",
        r"encoder.language_model.layers\.(?:[^.]+\.)*per_expert_scale": r"decoder.layers\.(?:[^.]+\.)*per_expert_scale",
        r"encoder.language_model.layers\.(?:[^.]+\.)*gate_up_proj": r"decoder.layers\.(?:[^.]+\.)*gate_up_proj",
        r"encoder.language_model.layers\.(?:[^.]+\.)*down_proj": r"decoder.layers\.(?:[^.]+\.)*down_proj",
        "encoder.language_model.embed_tokens.weight": "decoder.embed_tokens.weight",
    }

    def __init__(self, config: DiffusionGemmaConfig):
        super().__init__(config)

        self.encoder = DiffusionGemmaEncoderModel(config)
        self.decoder = DiffusionGemmaDecoderModel(config)

        self.post_init()

    def get_encoder(self):
        """Return the encoder sub-model."""
        return self.encoder

    def get_decoder(self):
        """Return the decoder sub-model."""
        return self.decoder

    def get_input_embeddings(self):
        """Return the input embeddings, delegating to the encoder."""
        return self.encoder.get_input_embeddings()

    def set_input_embeddings(self, new_embeddings):
        """Set the input embeddings, delegating to the encoder."""
        return self.encoder.set_input_embeddings(new_embeddings)

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | dict | None = None,
        past_key_values: Cache | None = None,
        position_ids: torch.LongTensor | None = None,
        decoder_input_ids: torch.LongTensor | None = None,
        self_conditioning_logits: torch.FloatTensor | None = None,
        self_conditioning_mask: torch.BoolTensor | None = None,
        decoder_attention_mask: torch.Tensor | dict | None = None,
        decoder_position_ids: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> DiffusionGemmaModelOutputWithPast:
        r"""
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Uncached token IDs for the prompt to be encoded as context for the canvas.
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)` or `dict`, *optional*):
            Mask for the input tokens.
        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, canvas_length)`, *optional*):
            Token IDs for the canvas to be refined. If unset, a canvas is sampled uniformly at random from the
            vocabulary; this requires `input_ids`, which provides the batch size of the sampled canvas.
        self_conditioning_logits (`torch.FloatTensor` of shape `(batch_size, canvas_length, vocab_size)`, *optional*):
            Self-conditioning logits from the previous denoising step, used to compute the
            self-conditioning embeddings.
        self_conditioning_mask (`torch.BoolTensor` of shape `(batch_size,)`, *optional*):
            Per-example mask over `self_conditioning_logits`: examples set to `False` get a zeroed self-conditioning
            signal, as if no logits were passed for them. Useful for training, where self-conditioning is enabled per
            example with some probability.
        decoder_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length+canvas_length)` or `dict`, *optional*):
            Attention mask for the decoder KV cache. Used to specify padded/unpopulated encoder KV cached entries.
        decoder_position_ids (`torch.LongTensor` of shape `(batch_size, canvas_length)`, *optional*):
            The position IDs for the tokens in the canvas.
        """

        # 1: Encode new prompt tokens into the KV cache
        encoder_last_hidden_state = None
        if input_ids is not None:
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                past_key_values=past_key_values,
                position_ids=position_ids,
                **kwargs,
            )
            # From here on, the decoder reads the cache returned by the encoder
            past_key_values = encoder_outputs.past_key_values
            encoder_last_hidden_state = encoder_outputs.last_hidden_state
        elif past_key_values is None:
            raise ValueError("Either `input_ids` or `past_key_values` must be provided.")

        # 2: Run decoder with bidirectional self-attention in the canvas, and cross-attention to the KV cache.
        # In other words, the decoder attends to all tokens, KV cache and canvas, by default.

        # 2.a.: Prepare inputs for the decoder
        # If the canvas is unset, randomly sample from the vocabulary with uniform distribution
        if decoder_input_ids is None:
            # The batch size of the sampled canvas is read from `input_ids`. On the cache-only path (`input_ids` is
            # `None`) fail with an explicit message rather than an `AttributeError` on `None.shape`.
            if input_ids is None:
                raise ValueError(
                    "`decoder_input_ids` must be provided when `input_ids` is `None`: the batch size of the randomly "
                    "sampled canvas is taken from `input_ids`."
                )
            decoder_input_ids = torch.randint(
                low=0,
                high=self.config.text_config.vocab_size,
                size=(input_ids.shape[0], self.config.canvas_length),
                device=self.decoder.device,
            )

        # 2.b.: Run the decoder
        decoder_outputs = self.decoder(
            decoder_input_ids=decoder_input_ids,
            past_key_values=past_key_values,
            self_conditioning_logits=self_conditioning_logits,
            self_conditioning_mask=self_conditioning_mask,
            decoder_attention_mask=decoder_attention_mask,
            decoder_position_ids=decoder_position_ids,
            **kwargs,
        )

        return DiffusionGemmaModelOutputWithPast(
            last_hidden_state=decoder_outputs.last_hidden_state,
            hidden_states=decoder_outputs.hidden_states,
            attentions=decoder_outputs.attentions,
            past_key_values=past_key_values,
            encoder_last_hidden_state=encoder_last_hidden_state,
        )


class DiffusionGemmaForBlockDiffusion(DiffusionGemmaPreTrainedModel, DiffusionGemmaGenerationMixin):
    """
    DiffusionGemma model for block diffusion. It calls `DiffusionGemmaModel` to obtain the hidden states for
    the input canvas, conditioned by a prompt KV cache. Using its LM Head and self-conditioning blocks, it converts
    those hidden states into logits to sample the next canvas, as well as the self-conditioning embeddings for the
    next block diffusion step.
    """

    base_model_prefix = "model"
    _tied_weights_keys = {"lm_head.weight": "model.decoder.embed_tokens.weight"}
    generation_config_class = DiffusionGemmaGenerationConfig

    def __init__(self, config: DiffusionGemmaConfig):
        super().__init__(config)

        self.model = DiffusionGemmaModel(config)
        self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
        # Cached from the text config; `forward` skips soft-capping when this is `None`
        self.final_logit_softcapping = config.text_config.final_logit_softcapping

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the input embeddings of the encoder's language model."""
        return self.model.encoder.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        """Set the input embeddings of the encoder's language model."""
        self.model.encoder.language_model.set_input_embeddings(value)

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | dict | None = None,
        past_key_values: Cache | None = None,
        position_ids: torch.LongTensor | None = None,
        decoder_input_ids: torch.LongTensor | None = None,
        self_conditioning_logits: torch.FloatTensor | None = None,
        self_conditioning_mask: torch.BoolTensor | None = None,
        decoder_attention_mask: torch.Tensor | dict | None = None,
        decoder_position_ids: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> DiffusionGemmaBlockDiffusionOutputWithPast:
        r"""
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Uncached token IDs for the prompt to be encoded as context for the canvas.
        attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)` or `dict`, *optional*):
            Mask for the input tokens.
        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, canvas_length)`, *optional*):
            Token IDs for the canvas to be refined.
        self_conditioning_logits (`torch.FloatTensor` of shape `(batch_size, canvas_length, vocab_size)`, *optional*):
            Self-conditioning logits from the previous denoising step, used to compute the self-conditioning
            embeddings.
        self_conditioning_mask (`torch.BoolTensor` of shape `(batch_size,)`, *optional*):
            Per-example mask over `self_conditioning_logits`: examples set to `False` get a zeroed self-conditioning
            signal, as if no logits were passed for them. Useful for training, where self-conditioning is enabled per
            example with some probability.
        decoder_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length+canvas_length)` or `dict`, *optional*):
            Attention mask for the decoder KV cache. Used to specify padded/unpopulated encoder KV cached entries.
        decoder_position_ids (`torch.LongTensor` of shape `(batch_size, canvas_length)`, *optional*):
            The position IDs for the tokens in the canvas.
        """

        # 1: Call the model
        model_outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            position_ids=position_ids,
            decoder_input_ids=decoder_input_ids,
            self_conditioning_logits=self_conditioning_logits,
            self_conditioning_mask=self_conditioning_mask,
            decoder_attention_mask=decoder_attention_mask,
            decoder_position_ids=decoder_position_ids,
            **kwargs,
        )

        # 2. Obtain the logits (always upcast to float32) and apply logits softcapping
        logits = self.lm_head(model_outputs.last_hidden_state)
        logits = logits.to(torch.float32)
        # Soft-capping is skipped when the config leaves `final_logit_softcapping` unset, instead of failing on a
        # division by `None`.
        if self.final_logit_softcapping is not None:
            logits = logits / self.final_logit_softcapping
            logits = torch.tanh(logits)
            logits = logits * self.final_logit_softcapping

        # No `loss` is passed to the output: this forward only produces logits.
        return DiffusionGemmaBlockDiffusionOutputWithPast(
            logits=logits,
            hidden_states=model_outputs.hidden_states,
            attentions=model_outputs.attentions,
            past_key_values=model_outputs.past_key_values,
            encoder_last_hidden_state=model_outputs.encoder_last_hidden_state,
        )


# Names exported by `from <this module> import *`; every entry must be a class defined in this file.
__all__ = [
    "DiffusionGemmaPreTrainedModel",
    "DiffusionGemmaModel",
    "DiffusionGemmaDecoderModel",
    "DiffusionGemmaEncoderModel",
    "DiffusionGemmaEncoderTextModel",
    "DiffusionGemmaForBlockDiffusion",
]
