
    
9jb                       S SK Jr  S SKrS SKrS SKrS SKJr  S SKJr  S SK	J
r
  S SKJr   S SKJr  Sr\R"                  (       a  \R$                  \R&                  \R(                  \R*                  \R,                  \R.                  \R,                  \R.                  \R0                  \R2                  \R4                  \R2                  \R4                  S
.r\R8                  \R:                  \R<                  \R>                  S.r O0 r0 r SS jr! " S S\
5      r" " S S5      r#S r$S r% " S S5      r&g! \ a    S	r GNf = f)    )annotationsN)nccl)_store)_Backend)sparse)MPITF)bBiIlLqQefdFD)sumprodmaxminc                    U R                   R                  nU[        ;  a  [        SU R                    S35      e[        U   nUc  U R                  nUS;   a  USU-  4$ X14$ )NUnknown dtype 	 for NCCLFD   )dtypechar_nccl_dtypes	TypeErrorsize)arraycountr   
nccl_dtypes       V/home/wildlama/miniconda3/lib/python3.13/site-packages/cupyx/distributed/_nccl_comm.py_get_nccl_dtype_and_countr(   0   sh    KKEL .Y?@@e$J}

}1u9$$    c                     ^  \ rS rSrSr\R                  \R                  S4U 4S jjrS r	S r
S rS rS	 rS
 rSS jrSS jrSS jr SS jrSS jrSS jrSS jrSS jrSS jrSS jrSS jrS rSrU =r$ )NCCLBackend<   a  Interface that uses NVIDIA's NCCL to perform communications.

Args:
    n_devices (int): Total number of devices that will be used in the
        distributed execution.
    rank (int): Unique id of the GPU that the communicator is associated to
        its value needs to be `0 <= rank < n_devices`.
    host (str, optional): host address for the process rendezvous on
        initialization. Defaults to `"127.0.0.1"`.
    port (int, optional): port used for the process rendezvous on
        initialization. Defaults to `13333`.
    use_mpi(bool, optional): switch between MPI and use the included TCP
        server for initialization & synchronization. Defaults to `False`.
Fc                   > [         TU ]  XX45        [        =(       a    UU l        U R                  (       a  U R	                  X5        g U R                  XX45        g N)super__init___mpi_available_use_mpi_init_with_mpi_init_with_tcp_store)self	n_devicesrankhostportuse_mpi	__class__s         r'   r0   NCCLBackend.__init__L   sE     	$5&27==	0%%itBr)   c                Z   [         R                  U l        U R                  R                  5       U l        U R                  R                  5         S nU R                  S:X  a  [        R                  " 5       nU R                  R                  USS9n[        R                  " XU5      U l
        g )Nr   root)r   
COMM_WORLD	_mpi_commGet_rank	_mpi_rankBarrierr   get_unique_idbcastNcclCommunicator_comm)r5   r6   r7   nccl_ids       r'   r3   NCCLBackend._init_with_mpiV   s     002 >>Q((*G..&&wQ&7**9tD
r)   c                P   S nUS:X  aY  U R                   R                  X45        [        R                  " 5       nXPR                  S'   U R                  R                  5         O)U R                  R                  5         U R                  S   n[        R                  " XU5      U l        g )Nr   rI   )r   runr   rE   _store_proxybarrierrG   rH   )r5   r6   r7   r8   r9   rI   s         r'   r4    NCCLBackend._init_with_tcp_storec   s    19KKOOD'((*G+2i(%%'%%'''	2G**9tD
r)   c                    UR                   R                  (       d'  UR                   R                  (       d  [        S5      eg g )Nz4NCCL requires arrays to be either c- or f-contiguous)flagsc_contiguousf_contiguousRuntimeError)r5   r$   s     r'   _check_contiguousNCCLBackend._check_contiguouso   s6    {{''0H0HFH H 1I'r)   c                p    Uc(  [         R                  R                  R                  5       nUR                  $ r.   )cupycudastreamget_current_streamptr)r5   rZ   s     r'   _get_streamNCCLBackend._get_streamt   s)    >YY%%88:Fzzr)   c                t    U[         ;  a  [        SU S35      eUS;   a  US:w  a  [        S5      e[         U   $ )NzUnknown op r   r   r   z-Only nccl.SUM is supported for complex arrays)	_nccl_opsrT   
ValueError)r5   opr   s      r'   _get_opNCCLBackend._get_opy   sG    YRD	:;;D=R5[?A A}r)   c                    [         n[        US   [        [        45      (       a!  [        R
                  " US   S   5      (       d  [        R
                  " US   5      (       a  [        n[        X15      " U /UQ76   g Nr   )_DenseNCCLCommunicator
isinstancelisttupler   issparse_SparseNCCLCommunicatorgetattr)r5   functionargs
comm_classs       r'   _dispatch_arg_typeNCCLBackend._dispatch_arg_type   s]    +
Q$//a,,tAw''0J
%d2T2r)   c                ,    U R                  SXX445        g)a  Performs an all reduce operation.

Args:
    in_array (cupy.ndarray): array to be sent.
    out_array (cupy.ndarray): array where the result with be stored.
    op (str): reduction operation, can be one of
        ('sum', 'prod', 'min' 'max'), arrays of complex type only
        support `'sum'`. Defaults to `'sum'`.
    stream (cupy.cuda.Stream, optional): if supported, stream to
        perform the communication.

all_reduceNrq   )r5   in_array	out_arrayrb   rZ   s        r'   rt   NCCLBackend.all_reduce   s     	8;	=r)   c                .    U R                  SXX4U45        g)a:  Performs a reduce operation.

Args:
    in_array (cupy.ndarray): array to be sent.
    out_array (cupy.ndarray): array where the result with be stored.
        will only be modified by the `root` process.
    root (int, optional): rank of the process that will perform the
        reduction. Defaults to `0`.
    op (str): reduction operation, can be one of
        ('sum', 'prod', 'min' 'max'), arrays of complex type only
        support `'sum'`. Defaults to `'sum'`.
    stream (cupy.cuda.Stream, optional): if supported, stream to
        perform the communication.
reduceNru   )r5   rv   rw   r?   rb   rZ   s         r'   rz   NCCLBackend.reduce   s     	xDf=	?r)   c                ,    U R                  SXU45        g)am  Performs a broadcast operation.

Args:
    in_out_array (cupy.ndarray): array to be sent for `root` rank.
        Other ranks will receive the broadcast data here.
    root (int, optional): rank of the process that will send the
        broadcast. Defaults to `0`.
    stream (cupy.cuda.Stream, optional): if supported, stream to
        perform the communication.
	broadcastNru   )r5   in_out_arrayr?   rZ   s       r'   r}   NCCLBackend.broadcast   s     	,f5	7r)   c                .    U R                  SXX4U45        g)a  Performs a reduce scatter operation.

Args:
    in_array (cupy.ndarray): array to be sent.
    out_array (cupy.ndarray): array where the result with be stored.
    count (int): Number of elements to send to each rank.
    op (str): reduction operation, can be one of
        ('sum', 'prod', 'min' 'max'), arrays of complex type only
        support `'sum'`. Defaults to `'sum'`.
    stream (cupy.cuda.Stream, optional): if supported, stream to
        perform the communication.
reduce_scatterNru   )r5   rv   rw   r%   rb   rZ   s         r'   r   NCCLBackend.reduce_scatter   s     	xEvF	Hr)   c                ,    U R                  SXX445        g)a;  Performs an all gather operation.

Args:
    in_array (cupy.ndarray): array to be sent.
    out_array (cupy.ndarray): array where the result with be stored.
    count (int): Number of elements to send to each rank.
    stream (cupy.cuda.Stream, optional): if supported, stream to
        perform the communication.

all_gatherNru   )r5   rv   rw   r%   rZ   s        r'   r   NCCLBackend.all_gather   s     	8>	@r)   c                ,    U R                  SXU45        g)zPerforms a send operation.

Args:
    array (cupy.ndarray): array to be sent.
    peer (int): rank of the process `array` will be sent to.
    stream (cupy.cuda.Stream, optional): if supported, stream to
        perform the communication.
sendNru   )r5   r$   peerrZ   s       r'   r   NCCLBackend.send   s     	f(=>r)   c                ,    U R                  SXU45        g)a  Performs a receive operation.

Args:
    array (cupy.ndarray): array used to receive data.
    peer (int): rank of the process `array` will be received from.
    stream (cupy.cuda.Stream, optional): if supported, stream to
        perform the communication.
recvNru   )r5   rw   r   rZ   s       r'   r   NCCLBackend.recv   s     	&(ABr)   c                ,    U R                  SXX445        g)aS  Performs a send and receive operation.

Args:
    in_array (cupy.ndarray): array to be sent.
    out_array (cupy.ndarray): array used to receive data.
    peer (int): rank of the process to send `in_array` and receive
        `out_array`.
    stream (cupy.cuda.Stream, optional): if supported, stream to
        perform the communication.
	send_recvNru   )r5   rv   rw   r   rZ   s        r'   r   NCCLBackend.send_recv   s     	(t<	>r)   c                ,    U R                  SXX445        g)ap  Performs a scatter operation.

Args:
    in_array (cupy.ndarray): array to be sent. Its shape must be
        `(total_ranks, ...)`.
    out_array (cupy.ndarray): array where the result with be stored.
    root (int): rank that will send the `in_array` to other ranks.
    stream (cupy.cuda.Stream, optional): if supported, stream to
        perform the communication.
scatterNru   r5   rv   rw   r?   rZ   s        r'   r   NCCLBackend.scatter   s     	T:	<r)   c                ,    U R                  SXX445        g)ap  Performs a gather operation.

Args:
    in_array (cupy.ndarray): array to be sent.
    out_array (cupy.ndarray): array where the result with be stored.
        Its shape must be `(total_ranks, ...)`.
    root (int): rank that will receive `in_array` from other ranks.
    stream (cupy.cuda.Stream, optional): if supported, stream to
        perform the communication.
gatherNru   r   s        r'   r   NCCLBackend.gather  s     	xD9	;r)   c                ,    U R                  SXU45        g)aa  Performs an all to all operation.

Args:
    in_array (cupy.ndarray): array to be sent. Its shape must be
        `(total_ranks, ...)`.
    out_array (cupy.ndarray): array where the result with be stored.
        Its shape must be `(total_ranks, ...)`.
    stream (cupy.cuda.Stream, optional): if supported, stream to
        perform the communication.

all_to_allNru   )r5   rv   rw   rZ   s       r'   r   NCCLBackend.all_to_all  s     	87	9r)   c                    U R                   (       a  U R                  R                  5         gU R                  R	                  5         g)zPerforms a barrier operation.

The barrier is done in the cpu and is a explicit synchronization
mechanism that halts the thread progression.
N)r2   rA   rD   rM   rN   )r5   s    r'   rN   NCCLBackend.barrier'  s-     ==NN""$%%'r)   )rH   rA   rC   r2   r   Nr   r   Nr   Nr.   )__name__
__module____qualname____firstlineno____doc__r   _DEFAULT_HOST_DEFAULT_PORTr0   r3   r4   rU   r]   rc   rq   rt   rz   r}   r   r   r   r   r   r   r   r   rN   __static_attributes____classcell__)r;   s   @r'   r+   r+   <   s      **1E1ECE
EH

3=?$7  @DH"@	?	C><;9( (r)   r+   c                     \ rS rSr\SS j5       r\SS j5       r\SS j5       r\ SS j5       r\SS j5       r	\SS j5       r
\SS	 j5       r\SS
 j5       r\SS j5       r\SS j5       r\SS j5       r\SS j5       r\SS j5       rSrg)rg   i5  Nc                Z   UR                  U5        UR                  U5        UR                  U5      n[        U5      u  pgUR                  XBR                  R
                  5      nUR                  R                  UR                  R                  UR                  R                  XvXE5        g r.   )
rU   r]   r(   rc   r   r    rH   	allReducedatar\   )clscommrv   rw   rb   rZ   r   r%   s           r'   rt   !_DenseNCCLCommunicator.all_reduce7  s    x(y)!!&)0:\\"nn112

MMy~~115	Mr)   c           	     |   UR                  U5        UR                  U:X  a  UR                  U5        UR                  U5      n[        U5      u  pxUR	                  XRR
                  R                  5      nUR                  R                  UR                  R                  UR                  R                  XXTU5        g r.   )rU   r7   r]   r(   rc   r   r    rH   rz   r   r\   )	r   r   rv   rw   r?   rb   rZ   r   r%   s	            r'   rz   _DenseNCCLCommunicator.reduceA  s    x(99""9-!!&)0:\\"nn112

MMy~~11"F	,r)   c                    UR                  U5        UR                  U5      n[        U5      u  pVUR                  R	                  UR
                  R                  UR
                  R                  XeX45        g r.   )rU   r]   r(   rH   r}   r   r\   )r   r   r~   r?   rZ   r   r%   s          r'   r}    _DenseNCCLCommunicator.broadcastM  s`    |,!!&)0>

!!<#4#4#8#8$	(r)   c                Z   UR                  U5        UR                  U5        UR                  U5      n[        X$5      u  ptUR                  XRR                  R
                  5      nUR                  R                  UR                  R                  UR                  R                  XGXV5        g r.   )
rU   r]   r(   rc   r   r    rH   reduceScatterr   r\   )r   r   rv   rw   r%   rb   rZ   r   s           r'   r   %_DenseNCCLCommunicator.reduce_scatterV  s     	x(y)!!&)0A\\"nn112

  MMy~~115	Mr)   c                   UR                  U5        UR                  U5        UR                  U5      n[        X$5      u  pdUR                  R	                  UR
                  R                  UR
                  R                  XFU5        g r.   )rU   r]   r(   rH   	allGatherr   r\   )r   r   rv   rw   r%   rZ   r   s          r'   r   !_DenseNCCLCommunicator.all_gathera  sj    x(y)!!&)0A

MMy~~115	Ir)   c                    UR                  U5        UR                  U5      n[        U5      u  pVU R                  XX5Xd5        g r.   )rU   r]   r(   _send)r   r   r$   r   rZ   r   r%   s          r'   r   _DenseNCCLCommunicator.sendj  s<    u%!!&)07		$tE:r)   c                f    UR                   R                  UR                  R                  XTX65        g r.   )rH   r   r   r\   r   r   r$   r   r   r%   rZ   s          r'   r   _DenseNCCLCommunicator._sendq  s    



dCr)   c                    UR                  U5        UR                  U5      n[        U5      u  pVU R                  XX5Xd5        g r.   )rU   r]   r(   _recv)r   r   rw   r   rZ   r   r%   s          r'   r   _DenseNCCLCommunicator.recvu  s<    y)!!&)0;		$4>r)   c                f    UR                   R                  UR                  R                  XTX65        g r.   )rH   r   r   r\   r   r   rw   r   r   r%   rZ   s          r'   r   _DenseNCCLCommunicator._recv|  s     

	**E$Gr)   c                >   UR                  U5        UR                  U5        UR                  U5      n[        U5      u  pg[        U5      u  p[        R                  " 5         U R                  XXFXu5        U R                  XXHX5        [        R                  " 5         g r.   )rU   r]   r(   r   
groupStartr   r   groupEnd)
r   r   rv   rw   r   rZ   idtypeicountodtypeocounts
             r'   r    _DenseNCCLCommunicator.send_recv  sy    x(y)!!&)28<29=		$$?		$4@r)   c           	         UR                   S   UR                  :w  a%  [        SUR                   SUR                    35      eUR                  U5        UR                  U5        UR	                  U5      n[
        R                  " 5         XAR                  :X  a@  [        UR                  5       H'  nX&   n[        U5      u  pU R                  XXhX5        M)     [        U5      u  pU R                  XXJX5        [
        R                  " 5         g )Nr   z"scatter requires in_array to have 'elements in its first dimension, found )shape
_n_devicesrT   rU   r]   r   r   r7   ranger(   r   r   r   )r   r   rv   rw   r?   rZ   r   r$   r   r   r   r%   s               r'   r   _DenseNCCLCommunicator.scatter  s    >>!/4T__4E9(..9IKL L 	x(y)!!&)994??+ !:5!A		$q&A , 1;		$4>r)   c           	         UR                   S   UR                  :w  a%  [        SUR                   SUR                    35      eUR                  U5        UR                  U5        UR	                  U5      n[
        R                  " 5         XAR                  :X  a@  [        UR                  5       H'  nX6   n[        U5      u  pU R                  XXhX5        M)     [        U5      u  pU R                  XXJX5        [
        R                  " 5         g )Nr   z"gather requires out_array to have r   )r   r   rT   rU   r]   r   r   r7   r   r(   r   r   r   )r   r   rv   rw   r?   rZ   r   r$   r   r   r   r%   s               r'   r   _DenseNCCLCommunicator.gather  s     ??104T__4E9)//9JLM M 	x(y)!!&)994??+!!:5!A		$q&A , 1:		$$u=r)   c           	        UR                   S   UR                  :w  a%  [        SUR                   SUR                    35      eUR                   S   UR                  :w  a%  [        SUR                   SUR                    35      eUR                  U5        UR                  U5        UR	                  U5      n[        US   5      u  pV[        US   5      u  px[        R                  " 5         [        UR                  5       H/  n	U R                  XU	   XXd5        U R                  XU	   XX5        M1     [        R                  " 5         g )Nr   %all_to_all requires in_array to have r   z&all_to_all requires out_array to have )r   r   rT   rU   r]   r(   r   r   r   r   r   r   )
r   r   rv   rw   rZ   r   r   r   r   r   s
             r'   r   !_DenseNCCLCommunicator.all_to_all  s(    ??1077H9(..9IKL L ??1088I9)//9JLM M 	x(y)!!&)28A;?29Q<@t'AIIdQKFCIIdaL!VD ( 	r)    r   r   r   r.   )r   r   r   r   classmethodrt   rz   r}   r   r   r   r   r   r   r   r   r   r   r   r   r)   r'   rg   rg   5  s   M M 	, 	, ( ( DHM M I I ; ; D D ? ? H H 	 	  $  &  r)   rg   c                R   [         R                  " SU 5      n[         R                  " SS5      n[         R                  " SS5      nUS:X  a  [        R                  " X#U4SS9$ US:X  a  [        R                  " X#U4SS9$ US:X  a  [        R
                  " X#U44SS9$ [        S5      e)	N   r   csr)r   r   )r   csccoo4NCCL is not supported for this type of sparse matrix)rX   emptyr   
csr_matrix
csc_matrix
coo_matrixr"   )r   sparse_typer   ar	   s        r'   _make_sparse_emptyr     s    ::aD

1cA

1cAe  $1V<<		  $1V<<		  $Av>>BD 	Dr)   c                    [         R                  " U 5      (       a  g[         R                  " U 5      (       a  g[         R                  " U 5      (       a  g[	        S5      e)Nr   r   r   r   )r   isspmatrix_cooisspmatrix_csrisspmatrix_cscr"   )matrixs    r'   _get_sparse_typer     sP    V$$			v	&	&			v	&	&BD 	Dr)   c                  P   \ rS rSr\S 5       r\S 5       r\S 5       rS r\SS j5       r	\SS j5       r
\SS	 j5       r\ SS
 j5       r\SS j5       r\SS j5       r\SS j5       r\SS j5       r\SS j5       r\SS j5       r\SS j5       r\SS j5       r\SS j5       rSrg)rl   i  c                f   [         R                  " U5      (       a3  UR                  5         UR                  UR                  UR
                  4$ [         R                  " U5      (       d  [         R                  " U5      (       a#  UR                  UR                  UR                  4$ [        S5      e)Nr   )r   r   sum_duplicatesr   rowcolr   r   indptrindicesr"   )r   r$   s     r'   _get_internal_arrays,_SparseNCCLCommunicator._get_internal_arrays  s~      ''  "JJ		59955""5))V-B-B5-I-IJJemm<<NOOr)   c                0    U[        S U 5       5      -   nU$ )Nc              3  8   #    U  H  oR                   v   M     g 7fr.   )r#   ).0r   s     r'   	<genexpr>?_SparseNCCLCommunicator._get_shape_and_sizes.<locals>.<genexpr>  s     #;FqFFFs   )rj   )r   arraysr   sizes_shapes       r'   _get_shape_and_sizes,_SparseNCCLCommunicator._get_shape_and_sizes  s    
 e#;F#;;;r)   c                   UR                   (       Ga  US:X  a0  [        R                  " USS9nUR                  R	                  X2SS9  g US:X  a1  [        R
                  " SSS9nUR                  R                  X2SS9  U$ US	:X  aV  UR                  U:X  a  [        R                  " USS9nO[        R
                  " SSS9nUR                  R                  X2S
9  U$ US:X  aT  [        R                  " USS9n[        R
                  " UR                  S/SS9nUR                  R                  X6U5        U$ US:X  aS  [        R                  " USS9n[        R
                  " UR                  S/SS9nUR                  R                  X65        U$ [        S5      e[        R                  " S5        US:X  a4  [        R                  " USS9nU R!                  XX#R"                  SU5        g US:X  aI  [        R
                  " SSS9nU R%                  XX#R"                  SU5        [        R&                  " U5      $ US	:X  ae  UR                  U:X  a  [        R                  " USS9nO[        R
                  " SSS9n[(        R+                  XX%S9  [        R&                  " U5      $ US:X  aa  [        R                  " USS9n[        R
                  " UR                  S4SS9n[(        R-                  XXbUS9  [        R&                  " U5      $ US:X  a`  [        R                  " USS9n[        R
                  " UR                  S4SS9n[(        R/                  XXeS9  [        R&                  " U5      $ [        S5      e)Nr   r   r   r   )desttagr      )sourcer  rF   r>   r   alltoallzUnsupported methodzUsing NCCL for transferring sparse arrays metadata. This will cause device synchronization and a huge performance degradation. Please install MPI and `mpi4py` in order to avoid this issue.)r?   rZ   )rZ   )r2   numpyr$   rA   Sendr   Recvr7   Bcastr   GatherAlltoallrT   warningswarnrX   r   r   r   asnumpyrg   r}   r   r   )r   r   r   r   methodrZ   recv_bufs          r'   _exchange_shape_and_sizes1_SparseNCCLCommunicator._exchange_shape_and_sizes  s    ===#kk+SA##K#B6! $kk!37##K!#D""7"99$"'++k"EK"'++as";K$$[$<""8##kk+SA ;;';3G%%kTB:%#kk+SA ;;';3G''>"#788MM% "jjC@		t->->6K6! #jj#6		t->->6K||K007"99$"&**["DK"&**Qc":K&00D 1 A||K008#"jjC@::t&:#F&--x6 . K||H--:%"jjC@::t&:#F&11x 2 @||H--"#788r)   c                v   [         R                  " U 5      (       a/  US   U l        US   U l        US   U l        [        U5      U l        g [         R                  " U 5      (       d  [         R                  " U 5      (       a/  US   U l        US   U l	        US   U l
        [        U5      U l        g [        S5      e)Nr   r   r   r   )r   r   r   r   r   rj   _shaper   r   r   r   r"   )r   r   r   s      r'   _assign_arrays&_SparseNCCLCommunicator._assign_arraysE  s      (( )FKFJFJ!%LFM""6**f.C.CF.K.K )FK"1IFM#AYFN!%LFMFH Hr)   Nc                R    SnU R                  XX6XE5        U R                  XXe5        g rf   )rz   r}   )r   r   rv   rw   rb   rZ   r?   s          r'   rt   "_SparseNCCLCommunicator.all_reduceT  s'     

49B?dt4r)   c           
     B   U R                  U5      nU R                  XrR                  5      nU R                  XUSU5      nUR                  U:X  Gak  [        U5      [        U5      :w  a  [        S5      eUn	[        UR                  [        U5      5      n
[        U5       H  u  p[        USS 5      nUSS  n[        X5       VVs/ s H%  u  nn[        R                  " UUR                  S9PM'     nnnX:w  d  M\  [        R                  " 5         U H-  nU R!                  UUUUR                  UR"                  U5        M/     [        R$                  " 5         U R'                  XU5        US:X  a  X-   n	M  US:X  a  X-  n	M  [        S5      e   U R'                  X0R                  U	5      U	R                  5        g [        R                  " 5         U H-  nU R)                  UUUUR                  UR"                  U5        M/     [        R$                  " 5         g s  snnf )	Nr   z.in_array and out_array must be the same formatr   r   r   r   r   z.Sparse matrix only supports sum/prod reduction)r   r   r   r  r7   r   ra   r   r   	enumeraterj   ziprX   r   r   r   r   r#   r   r  r   )r   r   rv   rw   r?   rb   rZ   r   shape_and_sizesresultpartialr   ssr   sizessr   s                    r'   rz   _SparseNCCLCommunicator.reduce\  s   ))(3226>>J776;99)-=i-HH DF FF( 0 :<G &o6b1g12=@=O=OTQDJJq0=O   <OO%#		$4!&&&I $MMO&&w>U{!'!1v!'!1(LN N# 7, 33F;V\\K OO		!T177AFFF<  MMO5s   ,Hc           	     (   U R                  U5      nUR                  U:X  a  U R                  XRR                  5      nOSnU R	                  XUSU5      n[        USS 5      nUSS  nUR                  U:w  a;  [        X5       V	V
s/ s H#  u  p[        R                  " XR                  S9PM%     nn	n
[        R                  " 5         U H  n
[        R                  XX45        M     [        R                  " 5         U R                  X%U5        g s  sn
n	f )Nr   rF   r   r   r   )r   r7   r   r   r  rj   r  rX   r   r   r   r   rg   r}   r   r  )r   r   r~   r?   rZ   r   r  r   r   r!  r   s              r'   r}   !_SparseNCCLCommunicator.broadcast  s    )),799!66**,O !O77&:oa*+#999<U9KM9K

1GG,9K  M 	A",,TdC <7Ms   *Dc           	        Sn/ n[        U[        [        45      (       d  [        S5      eU HF  n	[	        U	R
                  [        U	5      5      n
U R                  XXXV5        UR                  U
5        MH     U R                  XX7U5        g )Nr   z5in_array must be a list or a tuple of sparse matrices)
rh   ri   rj   ra   r   r   r   rz   appendr   )r   r   rv   rw   r%   rb   rZ   r?   reduce_out_arrayss_mpartial_out_arrays              r'   r   &_SparseNCCLCommunicator.reduce_scatter  s    
 (T5M22GI IC 2		+C0!2JJt"32F$$%67	 
 	DYfEr)   c           	     8   Sn/ nU R                  XXvU5        UR                  U:w  aB  [        UR                  5       Vs/ s H"  n[	        UR
                  [        U5      5      PM$     nnU H&  n	U R                  XXe5        UR                  U	5        M(     g s  snf rf   )	r   r7   r   r   r   r   r   r}   r&  )
r   r   rv   rw   r%   rZ   r?   gather_out_arrays_arrs
             r'   r   "_SparseNCCLCommunicator.all_gather  s     

4#4FC99 t/!/A #8>>3CH3MN/  ! %CMM$T2S! %	!s   )Bc           	     :   U R                  U5      nU R                  XRR                  5      nU R                  XUSU5        [        R
                  " 5         U H+  nU R                  XX7R                  UR                  U5        M-     [        R                  " 5         g )Nr   )
r   r   r   r  r   r   r   r   r#   r   )r   r   r$   r   rZ   r   r  r   s           r'   r   _SparseNCCLCommunicator.send  sv    ))%0226;;G%%	9 	AIIdtWWafff= r)   c                   UR                   R                  nU[        ;  a  [        SUR                    S35      e[	        U5      u  pEUR                  U5      nUR                  R                  UR                  R                  XTX65        g Nr   r   )
r   r    r!   r"   r(   r]   rH   r   r   r\   r   s          r'   r   _SparseNCCLCommunicator._send  si      $nU[[MCDD07!!&)



dCr)   c           	        U R                  XSSU5      nU R                  U5      n[        USS 5      nUSS  n[        X5       V	V
s/ s H#  u  p[        R
                  " XR                  S9PM%     nn	n
[        R                  " 5         U H+  n
U R                  XX:R                  U
R                  U5        M-     [        R                  " 5         U R                  X+U5        g s  sn
n	f )Nr   r   r   r   r   )r  r   rj   r  rX   r   r   r   r   r   r#   r   r  )r   r   rw   r   rZ   r  r   r   r   r!  r   arrss               r'   r   _SparseNCCLCommunicator.recv  s    77FF, )))4oa*+#9<U9KL9K

1GG,9KLAIIdtWWafff= 9E2 Ms   *C"c                    UR                   nU[        ;  a  [        SUR                   S35      e[	        U5      u  pEUR                  U5      nUR                  R                  UR                  R                  XTX65        g r3  )
r    r!   r"   r   r(   r]   rH   r   r   r\   r   s          r'   r   _SparseNCCLCommunicator._recv  sf    

$nY__,=YGHH0;!!&)

	**E$Gr)   c                    [         R                  " 5         U R                  XXE5        U R                  XXE5        [         R                  " 5         g r.   )r   r   r   r   r   )r   r   rv   rw   r   rZ   s         r'   r   !_SparseNCCLCommunicator.send_recv  s1    .$/r)   c                X   UR                   U:X  a  [        R                  " 5         [        U5       H  u  pgXd:w  d  M  U R	                  XXe5        M      [        R
                  " 5         U R                  UU R                  X$   5      X$   R                  5        g U R                  XXE5        g r.   )
r7   r   r   r  r   r   r  r   r   r   )r   r   rv   rw   r?   rZ   r   s_as           r'   r   _SparseNCCLCommunicator.scatter  s     99OO&x0	<HHT5 1 MMO((8$$&
 HHTd3r)   c                j   UR                   U:X  a  [        UR                  5       Hw  n[        UR                  [        U5      5      nXd:w  a  U R                  XXe5        O,U R                  UU R                  U5      UR                  5        UR                  U5        My     g U R                  XXE5        g r.   )r7   r   r   r   r   r   r   r  r   r   r&  r   )r   r   rv   rw   r?   rZ   r   ress           r'   r   _SparseNCCLCommunicator.gather  s     99doo.(NN$4X$>@<HHT5&&00: (   % / HHTT2r)   c           
        [        U5      UR                  :w  a$  [        SUR                   S[        U5       35      e/ n/ n[        U5       H@  u  pxU R	                  U5      n	UR                  U R                  XR                  5      5        MB     U R                  UWUSU5      n[        UR                  5       GH;  n[        Xg   SS 5      n
Xg   SS  nU R	                  X'   5      n[        X5       VVs/ s H#  u  p[        R                  " XR                  S9PM%     nnn[        R                   " 5         U H+  nU R#                  XXxR                  UR$                  U5        M-     U H+  nU R'                  XXxR                  UR$                  U5        M-     [        R(                  " 5         UR                  [+        X'   R                  [-        X'   5      5      5        U R/                  X7   X5        GM>     g s  snnf )Nr   zelements, found r  r   r   r   )lenr   rT   r  r   r&  r   r   r  r   rj   r  rX   r   r   r   r   r   r#   r   r   r   r   r  )r   r   rv   rw   rZ   r  recv_shape_and_sizesr   r   r   r   r   s_arraysr!  r_arrayss                  r'   r   "_SparseNCCLCommunicator.all_to_all#  s    x=DOO+77H"3x=/34 4
 !h'DA--a0F""3#;#;FGG#LM (  #<<!_j& : t'A.1!A67E(+AB/E//<H :=U9MO9M

1GG,9M  OOO		$1ggqvvv> 		$1ggqvvv> MMO/!! -/ 0 y|X=! (
Os   7*G<r   r   r   r   r.   )r   r   r   r   r   r   r   r  r  rt   rz   r}   r   r   r   r   r   r   r   r   r   r   r   r   r)   r'   rl   rl     sU   P P   I9 I9VH 5 5 + +Z 8 80 DHF F  " "& 	 	 D D 3 3& H H   4 4 3 3" "> ">r)   rl   r.   )'
__future__r   r  r  rX   	cupy.cudar   cupyx.distributedr   cupyx.distributed._commr   cupyx.scipyr   mpi4pyr   r1   ImportError	available	NCCL_INT8
NCCL_UINT8
NCCL_INT32NCCL_UINT32
NCCL_INT64NCCL_UINT64NCCL_FLOAT16NCCL_FLOAT32NCCL_FLOAT64r!   NCCL_SUM	NCCL_PRODNCCL_MAXNCCL_MINr`   r(   r+   rg   r   r   rl   r   r)   r'   <module>r]     s(   "     $ , N
 >> ))))))**********,L 'I
 LI	v(( v(rS SlD	D`> `>m  Ns   D6 6EE