
    
3j"              	       <   S SK r S SKJr  S SKJrJrJrJrJrJ	r	  S SK
rS SKrS SKrSSKJrJr  SSKJrJr  SSKJrJr  \" 5       (       a  S SKr\ " S	 S
\5      5       r " S S5      r " S S5      r  SS\S\S\S   S\R<                  4S jjr " S S\\5      r g)    N)	dataclass)CallableListLiteralOptionalTupleUnion   )ConfigMixinregister_to_config)
BaseOutputis_scipy_available   )KarrasDiffusionSchedulersSchedulerMixinc                   `    \ rS rSr% Sr\R                  \S'   Sr\R                  S-  \S'   Sr	g)DPMSolverSDESchedulerOutput    aM  
Output class for the scheduler's `step` function output.

Args:
    prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
        Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
        denoising loop.
    pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
        The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
        `pred_original_sample` can be used to preview progress or for guidance.
prev_sampleNpred_original_sample )
__name__
__module____qualname____firstlineno____doc__torchTensor__annotations__r   __static_attributes__r       g/home/wildlama/miniconda3/lib/python3.13/site-packages/diffusers/schedulers/scheduling_dpmsolver_sde.pyr   r       s'    
 04%,,-4r!   r   c                       \ rS rSrSr SS\R                  S\S\S\\	\
\\
   4      4S jjr\S	\S
\S\\\\4   4S j5       rS\S\S\R                  4S jrSrg)BatchedBrownianTree3   zGA wrapper around torchsde.BrownianTree that enables batches of entropy.Nxt0t1seedc                    U R                  X#5      u  p#U l        UR                  S[        R                  " U5      5      nUc&  [        R
                  " SS/ 5      R                  5       nSU l         [        U5      UR                  S   :X  d   eUS   nU Vs/ s H=  n[        R                  " UUUR                  UR                  UR                  USSSS9	PM?     snU l        g ! [         a    U/nSU l         Ngf = fs  snf )	Nw0r   l    TFư>   )	r'   r(   sizedtypedeviceentropytol	pool_sizehalfway_tree)sortsigngetr   
zeros_likerandintitembatchedlenshape	TypeErrortorchsdeBrownianIntervalr/   r0   trees)selfr&   r'   r(   r)   kwargsr+   ss           r"   __init__BatchedBrownianTree.__init__6   s     !IIb-	ZZe..q12<==Ir2779D	!t9
***AB  
  %%XXhhyy!
 

  	!6D DL	!
s   1#C% AC?%C<;C<abreturnc                     X:  a  XS4$ XS4$ )aG  
Sorts two float values and returns them along with a sign indicating if they were swapped.

Args:
    a (`float`):
        The first value.
    b (`float`):
        The second value.

Returns:
    `Tuple[float, float, float]`:
        A tuple containing the sorted values (min, max) and a sign (1.0 if a < b, -1.0 otherwise).
      ?g      r   )rG   rH   s     r"   r5   BatchedBrownianTree.sortX   s      ec{5!5r!   c           	          U R                  X5      u  pn[        R                  " U R                   Vs/ s H
  oD" X5      PM     sn5      U R                  U-  -  nU R
                  (       a  U$ US   $ s  snf )Nr   )r5   r   stackrA   r6   r;   )rB   r'   r(   r6   treews         r"   __call__BatchedBrownianTree.__call__i   sa    yy(KK$**=*$b*=>$))dBRSLLq*ad* >s   A3)r;   r6   rA   N)r   r   r   r   r   r   r   floatr   r	   intr   rE   staticmethodr   r5   rQ   r    r   r!   r"   r$   r$   3   s    Q 15 
<< 
  
 	 

 uS$s)^,- 
D 6 6% 6E%*=$> 6 6 +5 +e + +r!   r$   c                       \ rS rSrSrSS 4S\R                  S\S\S\\	\
\\
   4      S	\\/\4   4
S
 jjrS\S\S\R                  4S jrSrg)BrownianTreeNoiseSamplero   aR  A noise sampler backed by a torchsde.BrownianTree.

Args:
    x (`torch.Tensor`): The tensor whose shape, device and dtype is used to generate random samples.
    sigma_min (`float`): The low end of the valid interval.
    sigma_max (`float`): The high end of the valid interval.
    seed (`int` or `List[int]`): The random seed. If a list of seeds is
        supplied instead of a single integer, then the noise sampler will use one BrownianTree per batch item, each
        with its own seed.
    transform (`callable`): A function that maps sigma to the sampler's
        internal timestep.
Nc                     U $ rS   r   )r&   s    r"   <lambda>!BrownianTreeNoiseSampler.<lambda>   s    r!   r&   	sigma_min	sigma_maxr)   	transformc                     XPl         U R                  [        R                  " U5      5      U R                  [        R                  " U5      5      pv[        XXt5      U l        g rS   )r_   r   	as_tensorr$   rO   )rB   r&   r]   r^   r)   r_   r'   r(   s           r"   rE   !BrownianTreeNoiseSampler.__init__}   sC     #	 :;T^^EOO\eLf=gB'r8	r!   sigma
sigma_nextrI   c                     U R                  [        R                  " U5      5      U R                  [        R                  " U5      5      pCU R                  X45      XC-
  R	                  5       R                  5       -  $ rS   )r_   r   ra   rO   abssqrt)rB   rc   rd   r'   r(   s        r"   rQ   !BrownianTreeNoiseSampler.__call__   sS     67XbHc9dByy BG==?#7#7#999r!   )r_   rO   )r   r   r   r   r   r   r   rT   r   r	   rU   r   r   rE   rQ   r    r   r!   r"   rX   rX   o   s    $ 15.9
9<<
9 
9 	
9
 uS$s)^,-
9 UGUN+
9:e : :5<< :r!   rX   num_diffusion_timestepsmax_betaalpha_transform_type)cosineexplaplacerI   c           
      :   US:X  a  S nO"US:X  a  S nOUS:X  a  S nO[        SU 35      e/ n[        U 5       H<  nXP-  nUS-   U -  nUR                  [        SU" U5      U" U5      -  -
  U5      5        M>     [        R
                  " U[        R                  S	9$ )
a  
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].

Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
to that part of the diffusion process.

Args:
    num_diffusion_timesteps (`int`):
        The number of betas to produce.
    max_beta (`float`, defaults to `0.999`):
        The maximum beta to use; use values lower than 1 to avoid numerical instability.
    alpha_transform_type (`str`, defaults to `"cosine"`):
        The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`.

Returns:
    `torch.Tensor`:
        The betas used by the scheduler to step the model outputs.
rl   c                 h    [         R                  " U S-   S-  [         R                  -  S-  5      S-  $ )NgMb?gT㥛 ?r
   )mathcospits    r"   alpha_bar_fn)betas_for_alpha_bar.<locals>.alpha_bar_fn   s-    88QY%/$''9A=>!CCr!   rn   c           	         S[         R                  " SSU -
  5      -  [         R                  " SS[         R                  " SU -
  5      -  -
  S-   5      -  n[         R                  " U5      n[         R
                  " USU-   -  5      $ )Ng      r         ?r
   r,   )rq   copysignlogfabsrm   rg   )ru   lmbsnrs      r"   rv   rw      sm    q#'22TXXa!diiPSVWPWFXBX>X[_>_5``C((3-C99SAG_--r!   rm   c                 4    [         R                  " U S-  5      $ )Ng      ()rq   rm   rt   s    r"   rv   rw      s    88AI&&r!   z"Unsupported alpha_transform_type: r   r/   )
ValueErrorrangeappendminr   tensorfloat32)ri   rj   rk   rv   betasir(   t2s           r"   betas_for_alpha_barr      s    0 x'	D 
	*	.
 
	&	' =>R=STUUE*+(!e..S\"-R0@@@(KL , <<U]]33r!   c                      \ rS rSrSr\ V Vs/ s H  oR                  PM     snn rSr\	            S:S\
S\S\S\S	   S
\R                  \\   -  S-  S\S   S\S\S\S\
S-  S\S   S\
4S jj5       r S;S\\R&                  -  S\R&                  S-  S\
4S jjrS\\R&                  -  SS4S jr\S\R&                  4S j5       r\S\\
S4   4S j5       r\S\\
S4   4S j5       rS<S\
SS4S jjrS\R&                  S\\R&                  -  S\R&                  4S  jr  S=S!\
S"\\R<                  -  S\
S-  SS4S# jjrS$\R                  S%\R                  S\R                  4S& jr S'\R                  S%\R                  S\R                  4S( jr!S)\R&                  S\R&                  4S* jr"S)\R&                  S!\
S\R&                  4S+ jr# S>S)\R&                  S!\
S,\S-\S\R&                  4
S. jjr$\S\4S/ j5       r%  S?S0\R&                  S\\R&                  -  S\R&                  S1\S2\S\&\'-  4S3 jjr(S4\R&                  S5\R&                  S6\R&                  S\R&                  4S7 jr)S\
4S8 jr*S9r+gs  snn f )@DPMSolverSDEScheduler   u7
  
DPMSolverSDEScheduler implements the stochastic sampler from the [Elucidating the Design Space of Diffusion-Based
Generative Models](https://huggingface.co/papers/2206.00364) paper.

This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
methods the library implements for all schedulers such as loading and saving.

Args:
    num_train_timesteps (`int`, defaults to 1000):
        The number of diffusion steps to train the model.
    beta_start (`float`, defaults to 0.00085):
        The starting `beta` value of inference.
    beta_end (`float`, defaults to 0.012):
        The final `beta` value.
    beta_schedule (`str`, defaults to `"linear"`):
        The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
        `linear` or `scaled_linear`.
    trained_betas (`np.ndarray`, *optional*):
        Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
    prediction_type (`str`, defaults to `epsilon`, *optional*):
        Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
        `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
        Video](https://huggingface.co/papers/2210.02303) paper).
    use_karras_sigmas (`bool`, *optional*, defaults to `False`):
        Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
        the sigmas are determined according to a sequence of noise levels {σi}.
    use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
        Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
    use_beta_sigmas (`bool`, *optional*, defaults to `False`):
        Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
        Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
    noise_sampler_seed (`int`, *optional*, defaults to `None`):
        The random seed to use for the noise sampler. If `None`, a random seed is generated.
    timestep_spacing (`str`, defaults to `"linspace"`):
        The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
        Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
    steps_offset (`int`, defaults to 0):
        An offset added to the inference steps, as required by some model families.
r
   Nnum_train_timesteps
beta_startbeta_endbeta_schedule)linearscaled_linearsquaredcos_cap_v2trained_betasprediction_type)epsilonsamplev_predictionuse_karras_sigmasuse_exponential_sigmasuse_beta_sigmasnoise_sampler_seedtimestep_spacing)linspaceleadingtrailingsteps_offsetc                    U R                   R                  (       a  [        5       (       d  [        S5      e[	        U R                   R                  U R                   R
                  U R                   R                  /5      S:  a  [        S5      eUb)  [        R                  " U[        R                  S9U l        OUS:X  a*  [        R                  " X#U[        R                  S9U l        OkUS:X  a4  [        R                  " US-  US-  U[        R                  S9S-  U l        O1US	:X  a  [        U5      U l        O[        U S
U R                   35      eSU R                  -
  U l        [        R"                  " U R                   SS9U l        U R'                  US U5        Xpl        S U l        Xl        S U l        S U l        U R0                  R3                  S5      U l        g )Nz:Make sure to install scipy if you want to use beta sigmas.r   znOnly one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used.r   r   r   ry   r
   r   z is not implemented for rK   r   )dimcpu)configr   r   ImportErrorsumr   r   r   r   r   r   r   r   r   NotImplementedError	__class__alphascumprodalphas_cumprodset_timestepsnoise_samplerr   _step_index_begin_indexsigmasto)rB   r   r   r   r   r   r   r   r   r   r   r   r   s                r"   rE   DPMSolverSDEScheduler.__init__   s     ;;&&/A/C/CZ[[KK//KK66KK11   A  $m5==IDJh&
>QY^YfYfgDJo- OcM'--	  J 11,-@ADJ%7OPTP^P^O_&`aaDJJ&#mmDKKQ? 	.6IJ!2!"4 kknnU+r!   timestepschedule_timestepsrI   c                     Uc  U R                   nX!:H  R                  5       n[        U5      S:  a  SOSnX4   R                  5       $ )a  
Find the index of a given timestep in the timestep schedule.

Args:
    timestep (`float` or `torch.Tensor`):
        The timestep value to find in the schedule.
    schedule_timesteps (`torch.Tensor`, *optional*):
        The timestep schedule to search in. If `None`, uses `self.timesteps`.

Returns:
    `int`:
        The index of the timestep in the schedule. For the very first step, returns the second index if
        multiple matches exist to avoid skipping a sigma when starting mid-schedule (e.g., for image-to-image).
r   r   )	timestepsnonzeror<   r:   )rB   r   r   indicesposs        r"   index_for_timestep(DPMSolverSDEScheduler.index_for_timestep0  sJ    " %!%%1::< w<!#a|  ""r!   c                     U R                   c[  [        U[        R                  5      (       a%  UR	                  U R
                  R                  5      nU R                  U5      U l        gU R                  U l        g)z
Initialize the step index for the scheduler based on the given timestep.

Args:
    timestep (`float` or `torch.Tensor`):
        The current timestep to initialize the step index from.
N)
begin_index
isinstancer   r   r   r   r0   r   r   r   )rB   r   s     r"   _init_step_index&DPMSolverSDEScheduler._init_step_indexO  sZ     #(ELL11#;;t~~'<'<=#66x@D#00Dr!   c                     U R                   R                  S;   a  U R                  R                  5       $ U R                  R                  5       S-  S-   S-  $ )N)r   r   r
   r   ry   )r   r   r   maxrB   s    r"   init_noise_sigma&DPMSolverSDEScheduler.init_noise_sigma^  sH     ;;''+CC;;??$$!Q&*s22r!   c                     U R                   $ )zW
The index counter for current timestep. It will increase 1 after each scheduler step.
)r   r   s    r"   
step_index DPMSolverSDEScheduler.step_indexf  s    
 r!   c                     U R                   $ )za
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
r   r   s    r"   r   !DPMSolverSDEScheduler.begin_indexm  s    
    r!   r   c                     Xl         g)z
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.

Args:
    begin_index (`int`, defaults to `0`):
        The begin index for the scheduler.
Nr   )rB   r   s     r"   set_begin_index%DPMSolverSDEScheduler.set_begin_indexu  s
     (r!   r   c                     U R                   c  U R                  U5        U R                  U R                      nU R                  (       a  UOU R                  nXS-  S-   S-  -  nU$ )aN  
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep.

Args:
    sample (`torch.Tensor`):
        The input sample.
    timestep (`int`, *optional*):
        The current timestep in the diffusion chain.

Returns:
    `torch.Tensor`:
        A scaled input sample.
r
   r   ry   )r   r   r   state_in_first_ordermid_point_sigma)rB   r   r   rc   sigma_inputs        r"   scale_model_input'DPMSolverSDEScheduler.scale_model_input  s^    & ??"!!(+DOO,#88ed>R>RNQ.367r!   num_inference_stepsr0   c           	      	   Xl         U=(       d    U R                  R                  nU R                  R                  S:X  a4  [        R
                  " SUS-
  U[        S9SSS2   R                  5       nGO(U R                  R                  S:X  av  X0R                   -  n[        R                  " SU5      U-  R                  5       SSS2   R                  5       R                  [        5      nX@R                  R                  -  nOU R                  R                  S:X  a\  X0R                   -  n[        R                  " USU* 5      R                  5       R                  5       R                  [        5      nUS-  nO"[        U R                  R                   S	35      e[        R                  " SU R                  -
  U R                  -  S
-  5      n[        R                  " U5      n[        R                   " U[        R                  " S[#        U5      5      U5      nU R                  R$                  (       aE  U R'                  US9n[        R                  " U Vs/ s H  oR)                  X5      PM     sn5      nOU R                  R*                  (       aE  U R-                  XaS9n[        R                  " U Vs/ s H  oR)                  X5      PM     sn5      nO_U R                  R.                  (       aD  U R1                  XaS9n[        R                  " U Vs/ s H  oR)                  X5      PM     sn5      nU R3                  Xg5      n	[        R4                  " US//5      R                  [        R6                  5      n[8        R:                  " U5      R=                  US9n[8        R>                  " USS USS RA                  S5      USS /5      U l!        [8        R:                  " U5      n[8        R:                  " U	5      n	[8        R>                  " USS USS RA                  S5      /5      nXSSS2'   [E        U5      RG                  S5      (       a$  UR=                  U[8        R6                  S9U l$        OUR=                  US9U l$        SU l%        SU l&        SU l'        SU l(        U RB                  R=                  S5      U l!        SU l)        gs  snf s  snf s  snf )a  
Sets the discrete timesteps used for the diffusion chain (to be run before inference).

Args:
    num_inference_steps (`int`):
        The number of diffusion steps used when generating samples with a pre-trained model.
    device (`str` or `torch.device`, *optional*):
        The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
    num_train_timesteps (`int`, *optional*):
        The number of train timesteps. If `None`, uses `self.config.num_train_timesteps`.
r   r   r   r   Nr   r   zY is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'.ry   )	in_sigmas)r   r   g        )r0   r
   mpsr   )*r   r   r   r   npr   rT   copyarangeroundastyper   r   arrayr   r{   interpr<   r   _convert_to_karras_sigma_to_tr   _convert_to_exponentialr   _convert_to_beta_second_order_timestepsconcatenater   r   
from_numpyr   catrepeat_interleaver   str
startswithr   r   r   r   r   r   )
rB   r   r0   r   r   
step_ratior   
log_sigmasrc   second_order_timestepss
             r"   r   #DPMSolverSDEScheduler.set_timesteps  s   " $7 1TT[[5T5T ;;'':5A':Q'>@S[`abfdfbfgllnI[[))Y6,0H0HHJ 1&9:ZGNNPQUSUQUV[[]ddejkI111I[[))Z7,/G/GGJ #6J;GNNPUUW^^_deINI;;//0  1J  K  A 3 33t7J7JJsRSVVF^
9bii3v;&?H;;((,,v,>FSY!ZSY%"2"25"ESY!Z[I[[//11F1lFSY!ZSY%"2"25"ESY!Z[I[[((**V*eFSY!ZSY%"2"25"ESY!Z[I!%!=!=f!Q#077

C!!&),,F,;iiVAb\-K-KA-NPVWYWZP[ \]$$Y/	!&!1!12H!IIIy!}im.M.Ma.PQR	0!$Q$v;!!%((&\\&\FDN&\\\8DN # kknnU+!A "[ "[ "[s   S9SSr   r   c           	          S nS nSnU" U5      n[         R                  " U5      nUS S Xu-  -   nU" U5      n	[         R                  " U	 V
s/ s H  oR                  X5      PM     sn
5      nU$ s  sn
f )Nc                 0    [         R                  " U * 5      $ rS   )r   rm   _ts    r"   sigma_fn?DPMSolverSDEScheduler._second_order_timesteps.<locals>.sigma_fn  s    662#;r!   c                 0    [         R                  " U 5      * $ rS   )r   r{   _sigmas    r"   t_fn;DPMSolverSDEScheduler._second_order_timesteps.<locals>.t_fn  s    FF6N?"r!   ry   r   )r   diffr   r   )rB   r   r   r   r   midpoint_ratioru   
delta_time
t_proposedsig_proposedrc   r   s               r"   r   -DPMSolverSDEScheduler._second_order_timesteps  sx    		# LWWQZ
sVj99

+HH|\|e..uA|\]	 ]s   A0rc   c                    [         R                  " [         R                  " US5      5      nX2SS2[         R                  4   -
  n[         R                  " US:  SS9R                  SS9R                  UR                  S   S-
  S9nUS-   nX%   nX&   nXs-
  Xx-
  -  n	[         R                  " U	SS5      n	SU	-
  U-  X-  -   n
U
R                  UR                  5      n
U
$ )at  
Convert sigma values to corresponding timestep values through interpolation.

Args:
    sigma (`np.ndarray`):
        The sigma value(s) to convert to timestep(s).
    log_sigmas (`np.ndarray`):
        The logarithm of the sigma schedule used for interpolation.

Returns:
    `np.ndarray`:
        The interpolated timestep value(s) corresponding to the input sigma(s).
g|=Nr   )axisr
   )r   r   )	r   r{   maximumnewaxiscumsumargmaxclipr=   reshape)rB   rc   r   	log_sigmadistslow_idxhigh_idxlowhighrP   ru   s              r"   r   !DPMSolverSDEScheduler._sigma_to_t  s     FF2::eU34	 q"**}55 ))UaZq188a8@EE*JZJZ[\J]`aJaEbQ;!# _,GGAq! Ug,IIekk"r!   r   c                     US   R                  5       nUS   R                  5       nSn[        R                  " SSU R                  5      nUSU-  -  nUSU-  -  nXuXg-
  -  -   U-  nU$ )aY  
Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative
Models](https://huggingface.co/papers/2206.00364).

Args:
    in_sigmas (`torch.Tensor`):
        The input sigma values to be converted.

Returns:
    `torch.Tensor`:
        The converted sigma values following the Karras noise schedule.
r   r   g      @r   )r:   r   r   r   )	rB   r   r]   r^   rhorampmin_inv_rhomax_inv_rhor   s	            r"   r   (DPMSolverSDEScheduler._convert_to_karras   s{     %R=--/	$Q<,,.	{{1a!9!9:AG,AG,(A BBsJr!   c                    [        U R                  S5      (       a  U R                  R                  nOSn[        U R                  S5      (       a  U R                  R                  nOSnUb  UOUS   R	                  5       nUb  UOUS   R	                  5       n[
        R                  " [
        R                  " [        R                  " U5      [        R                  " U5      U5      5      nU$ )aP  
Construct an exponential noise schedule.

Args:
    in_sigmas (`torch.Tensor`):
        The input sigma values to be converted.
    num_inference_steps (`int`):
        The number of inference steps to generate the noise schedule for.

Returns:
    `torch.Tensor`:
        The converted sigma values following an exponential schedule.
r]   Nr^   r   r   )
hasattrr   r]   r^   r:   r   rm   r   rq   r{   )rB   r   r   r]   r^   r   s         r"   r   -DPMSolverSDEScheduler._convert_to_exponential9  s    " 4;;,,--II4;;,,--II!*!6IIbM<N<N<P	!*!6IIaL<M<M<O	DHHY$7)9LNabcr!   alphabetac           
      J   [        U R                  S5      (       a  U R                  R                  nOSn[        U R                  S5      (       a  U R                  R                  nOSnUb  UOUS   R	                  5       nUb  UOUS   R	                  5       n[
        R                  " S[
        R                  " SSU5      -
   Vs/ s H-  n[        R                  R                  R                  XsU5      PM/     sn Vs/ s H  nXXXe-
  -  -   PM     sn5      n	U	$ s  snf s  snf )az  
Construct a beta noise schedule as proposed in [Beta Sampling is All You
Need](https://huggingface.co/papers/2407.12173).

Args:
    in_sigmas (`torch.Tensor`):
        The input sigma values to be converted.
    num_inference_steps (`int`):
        The number of inference steps to generate the noise schedule for.
    alpha (`float`, *optional*, defaults to `0.6`):
        The alpha parameter for the beta distribution.
    beta (`float`, *optional*, defaults to `0.6`):
        The beta parameter for the beta distribution.

Returns:
    `torch.Tensor`:
        The converted sigma values following a beta distribution schedule.
r]   Nr^   r   r   r   )r  r   r]   r^   r:   r   r   r   scipystatsr  ppf)
rB   r   r   r  r  r]   r^   r   r   r   s
             r"   r   &DPMSolverSDEScheduler._convert_to_beta[  s   0 4;;,,--II4;;,,--II!*!6IIbM<N<N<P	!*!6IIaL<M<M<O	
 %&Aq:M(N$N$N KK$$(($?$NC I$9:;
 s   4D?D c                     U R                   S L $ rS   )r   r   s    r"   r   *DPMSolverSDEScheduler.state_in_first_order  s    {{d""r!   model_outputreturn_dicts_noisec                 .   U R                   c  U R                  U5        U R                  c{  U R                  U R                  S:     R	                  5       U R                  R                  5       pv[        X6R                  5       UR                  5       U R                  5      U l        S[        R                  S[        R                  4S jnS[        R                  S[        R                  4S jn	U R                  (       a6  U R                  U R                      n
U R                  U R                   S-      nO5U R                  U R                   S-
     n
U R                  U R                      nS	nU	" U
5      U	" U5      pX-
  nXU-  -   nU R                  R                  S
:X  a$  U R                  (       a  U
OU" U5      nUUU-  -
  nOU R                  R                  S:X  a:  U R                  (       a  U
OU" U5      nUU* US-  S-   S	-  -  -  UUS-  S-   -  -   nOHU R                  R                  S:X  a  [        S5      e[        SU R                  R                   S35      eUS:X  a  UU-
  U
-  nX-
  nUUU-  -   nOU R                  (       a  UnOU R                   nU" U5      nU" U5      n[	        UUS-  US-  US-  -
  -  US-  -  S	-  5      nUS-  US-  -
  S	-  nU	" U5      nU" U5      U" U5      -  U-  UU-
  R#                  5       U-  -
  nUU R                  U" U5      U" U5      5      U-  U-  -   nU R                  (       a  X0l        U" U5      U l        OSU l        SU l        U =R&                  S-  sl        U(       d  UU4$ [)        UUS9$ )a
  
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
process from the learned model outputs (most often the predicted noise).

Args:
    model_output (`torch.Tensor`):
        The direct output from learned diffusion model.
    timestep (`float` or `torch.Tensor`):
        The current discrete timestep in the diffusion chain.
    sample (`torch.Tensor`):
        A current instance of a sample created by the diffusion process.
    return_dict (`bool`):
        Whether or not to return a [`~schedulers.scheduling_dpmsolver_sde.DPMSolverSDESchedulerOutput`] or
        tuple.
    s_noise (`float`, *optional*, defaults to 1.0):
        Scaling factor for noise added to the sample.

Returns:
    [`~schedulers.scheduling_dpmsolver_sde.DPMSolverSDESchedulerOutput`] or `tuple`:
        If return_dict is `True`, [`~schedulers.scheduling_dpmsolver_sde.DPMSolverSDESchedulerOutput`] is
        returned, otherwise a tuple is returned where the first element is the sample tensor.
Nr   r   rI   c                 >    U R                  5       R                  5       $ rS   )negrm   r   s    r"   r   ,DPMSolverSDEScheduler.step.<locals>.sigma_fn  s    668<<>!r!   r   c                 >    U R                  5       R                  5       $ rS   )r{   r)  r   s    r"   r   (DPMSolverSDEScheduler.step.<locals>.t_fn  s    ::<##%%r!   r   ry   r   r   r
   r   z+prediction_type not implemented yet: samplezprediction_type given as z, must be one of `epsilon`, or `v_prediction`)r   r   )r   r   r   r   r   r   rX   r:   r   r   r   r   r   r   r   r   r   expm1r   r   r   )rB   r$  r   r   r%  r&  	min_sigma	max_sigmar   r   rc   rd   r   ru   t_nextr   r   r   r   
derivativedtr   
sigma_fromsigma_tosigma_up
sigma_downancestral_ts                              r"   stepDPMSolverSDEScheduler.step  s   < ??"!!(+ %#';;t{{Q#?#C#C#Et{{GXy!9()..*:D<S<S"D
	" 	"%,, 	"	& 	&%,, 	& $$KK0ET__q%89J KK! 34ET__5J Kj!16Z
n44
 ;;&&)3#'#<#<%(:BVK#)K,,F#F [[((N:#'#<#<%(:BVK#/K<;PQ>TUCUZ]B]3]#^+q.1,-$  [[((H4%&STT+DKK,G,G+HHtu  ? #775@J#B :?2K((#!!J'H1
A! ;<z1}LQTTH #A+!3;Jz*K#K08A;>&HKeg,L- -K &(:(:8A;QWHX(Y\c(cfn(nnK(($'/'7$ #'+$ 	A$ 
 +{Ymnnr!   original_samplesnoiser   c                     U R                   R                  UR                  UR                  S9nUR                  R                  S:X  av  [
        R                  " U5      (       a[  U R                  R                  UR                  [
        R                  S9nUR                  UR                  [
        R                  S9nO@U R                  R                  UR                  5      nUR                  UR                  5      nU R                  c!  U Vs/ s H  o`R                  Xe5      PM     nnOHU R                  b  U R                  /UR                  S   -  nOU R                  /UR                  S   -  nXG   R                  5       n[        UR                  5      [        UR                  5      :  a?  UR                  S5      n[        UR                  5      [        UR                  5      :  a  M?  XU-  -   n	U	$ s  snf )a  
Add noise to the original samples according to the noise schedule at the specified timesteps.

Args:
    original_samples (`torch.Tensor`):
        The original samples to which noise will be added.
    noise (`torch.Tensor`):
        The noise tensor to add to the original samples.
    timesteps (`torch.Tensor`):
        The timesteps at which to add noise, determining the noise level from the schedule.

Returns:
    `torch.Tensor`:
        The noisy samples with added noise scaled according to the timestep schedule.
)r0   r/   r   r   r   r   )r   r   r0   r/   typer   is_floating_pointr   r   r   r   r   r=   flattenr<   	unsqueeze)
rB   r:  r;  r   r   r   ru   step_indicesrc   noisy_sampless
             r"   	add_noiseDPMSolverSDEScheduler.add_noise  s   , '7'>'>FVF\F\]""''50U5L5LY5W5W!%!2!23C3J3JRWR_R_!2!`!%5%<%<EMMRI!%!2!23C3J3J!K!%5%<%<=I #T]^T]q33AJT]L^L__( OO,yq/AAL !,,-	0BBL$,,.%++%5%;%;!<<OOB'E %++%5%;%;!<< )5=8 _s   G;c                 .    U R                   R                  $ rS   )r   r   r   s    r"   __len__DPMSolverSDEScheduler.__len__8  s    {{...r!   )r   r   r   r   r   r   r   r   r   r   r   r   r   )i  g_QK?g~jt?r   Nr   FFFNr   r   rS   )r   )NN)333333?rH  )TrK   ),r   r   r   r   r   r   name_compatiblesorderr   rU   rT   r   r   ndarraylistboolrE   r   r   r   r   propertyr   r	   r   r   r   r   r   r0   r   r   r   r   r   r   r   r   tupler8  rC  rF  r    ).0es   00r"   r   r      s   &P %>>$=qFF$=>LE $(#QY9=HQ"'', %)-GQ=, =, =, 	=,
 MN=, zzDK/$6=, !!DE=,  =, !%=, =,  $J=, ""CD=, =, =,B Y]#,#BG,,QUBU#	#>1)= 1$ 1 3%,, 3 3  E#t),     !U39- ! !(3 (t ( %,,& 
	< &**.	O" O" ell"O" !4Z	O"
 
O"bbjj bjj UWU_U_  " " "

 "JELL U\\ 2 TW \a\h\h F dg..<?.HM.[`.	.` #d # # !vollvo %,,&vo 	vo
 vo vo 
%u	,vor.,,. ||. <<	.
 
.`/ /Y ?s   I>r   )g+?rl   )!rq   dataclassesr   typingr   r   r   r   r   r	   numpyr   r   r?   configuration_utilsr   r   utilsr   r   scheduling_utilsr   r   scipy.statsr  r   r$   rX   rU   rT   r   r   r   r   r!   r"   <module>rZ     s     ! B B    A 2 G  5* 5 5"9+ 9+x: :D @H14 1414 ""<=14 \\	14hv	/NK v	/r!   