// CuPy <-> CUB/hipCUB glue: type boilerplate, CUB "patches", and the
// per-operation reduce functors dispatched by dtype_dispatcher.
//  - Includes cupy_cub.h first (its comment: to make atomicAdd visible to CUB
//    templates early), then selects CUB (CUDA) or hipCUB (CUPY_USE_HIP) via
//    `using namespace` and the CUPY_CUB_NAMESPACE macro.
//  - numeric_limits specializations for thrust::complex<float/double>
//    (max()/lowest() chosen to follow NumPy's lexical ordering; infinity() is
//    inf + inf*j) and for __half (infinity(); plus max()/lowest() on HIP).
//    Per the inline comment they exist to provide initial values.
//  - half_negate_inf(): -inf as a __half (bit pattern 0xFC00).
//  - _multiply (product functor; cuda::std::multiplies<> on CUDA), plus _arange
//    and seg_offset_itr, which turn a counting iterator into the segment
//    offsets 0, step, 2*step, ...
//  - NaN-aware Max/Min (HIP: specializations of hipCUB's Max/Min; CUDA:
//    select_max/select_min choose nan_handling_* for float/double/complex and
//    a dedicated __half functor, else cuda::maximum<>/minimum<> when
//    CCCL_VERSION >= 2008000 or cub::Max/Min before). ArgMax/ArgMin
//    specializations follow the same rule: a NaN operand is always returned
//    (the first argument if both are NaN), and equal values resolve to the
//    pair with the smaller key.
//  - _cub_reduce_{sum,prod,min,max,argmin,argmax} and
//    _cub_segmented_reduce_{sum,prod,min,max}: reinterpret the void* buffers
//    and call DeviceReduce / DeviceSegmentedReduce. prod uses the initial
//    value 1 (cast from 1.0f for __half). min/max on types whose
//    numeric_limits has_infinity go through Reduce with the NaN-aware
//    operator and +inf / -inf (half_negate_inf() for __half) as the initial
//    value. Other types use CUB's built-in Min/Max.
// NOTE(review): in this copy the #include targets and most template argument
// lists appear to have been stripped (e.g. `numeric_limits>`,
// `template struct select_max`, `static_cast(x)`). The region is also
// collapsed onto a few physical lines, so each `//` comment swallows the code
// that follows it. Restore it from upstream before compiling; do not hand-edit
// the collapsed form.
// NOTE(review): the non-MSVC branch of the CUDA numeric_limits<__half>::infinity()
// is declared constexpr but returns through reinterpret_cast. That is not
// permitted in constant evaluation, so it is presumably never used in a
// constant expression. Confirm before relying on its constexpr-ness.
#include "cupy_cub.h" // need to make atomicAdd visible to CUB templates early #include #ifndef CUPY_USE_HIP #include #include #include #include #include #include #include #include #include // numeric_limits #else #include #include #include #include #include #include #endif /* ------------------------------------ Minimum boilerplate to support complex numbers ------------------------------------ */ #ifndef CUPY_USE_HIP // - This works only because all data fields in the *Traits struct are not // used in . // - The Max() and Lowest() below are chosen to comply with NumPy's lexical // ordering; note that std::numeric_limits does not support complex // numbers as in general the comparison is ill defined. // - DO NOT USE THIS STUB for supporting CUB sorting!!!!!! using namespace cub; #define CUPY_CUB_NAMESPACE cub namespace cuda { namespace std { template <> class numeric_limits> { public: static __host__ __device__ thrust::complex max() noexcept { return thrust::complex(cuda::std::numeric_limits::max(), cuda::std::numeric_limits::max()); } static __host__ __device__ thrust::complex lowest() noexcept { return thrust::complex(-cuda::std::numeric_limits::max(), -cuda::std::numeric_limits::max()); } static __host__ __device__ thrust::complex infinity() noexcept { return thrust::complex(cuda::std::numeric_limits::infinity(), cuda::std::numeric_limits::infinity()); } static constexpr bool has_infinity = true; static constexpr bool is_specialized = true; }; template <> class numeric_limits> { public: static __host__ __device__ thrust::complex max() noexcept { return thrust::complex(cuda::std::numeric_limits::max(), cuda::std::numeric_limits::max()); } static __host__ __device__ thrust::complex lowest() noexcept { return thrust::complex(-cuda::std::numeric_limits::max(), -cuda::std::numeric_limits::max()); } static __host__ __device__ thrust::complex infinity() noexcept { return thrust::complex(cuda::std::numeric_limits::infinity(), cuda::std::numeric_limits::infinity()); } 
static constexpr bool has_infinity = true; static constexpr bool is_specialized = true; }; } // namespace std template <> inline constexpr bool is_floating_point_v> = true; template <> inline constexpr bool is_floating_point_v> = true; } // namespace cuda // NumericTraits specializations for complex types were removed because // marking them as primitive (is_primitive=true) caused UB in CUB's // decoupled lookback scan (torn reads/writes for sizeof(T) >= 16). // CUB reduce/scan still works for complex types without these traits // because CuPy provides custom operator specializations (Max, Min, // ArgMax, ArgMin) below, and the dtype_dispatcher handles the type // dispatch at the C++ level. // See: https://github.com/NVIDIA/cccl/issues/8207 // need specializations for initial values namespace std { template <> class numeric_limits> { public: static __host__ __device__ thrust::complex infinity() noexcept { return thrust::complex(std::numeric_limits::infinity(), std::numeric_limits::infinity()); } static constexpr bool has_infinity = true; }; template <> class numeric_limits> { public: static __host__ __device__ thrust::complex infinity() noexcept { return thrust::complex(std::numeric_limits::infinity(), std::numeric_limits::infinity()); } static constexpr bool has_infinity = true; }; template <> class numeric_limits<__half> { public: static __host__ __device__ constexpr __half infinity() noexcept { unsigned short inf_half = 0x7C00U; #if (defined(_MSC_VER) && _MSC_VER >= 1920) #if CUDA_VERSION < 11030 // WAR: CUDA 11.2.x + VS 2019 fails with __builtin_bit_cast union caster { unsigned short u_; __half h_; }; return caster{inf_half}.h_; #else // CUDA_VERSION < 11030 // WAR: // - we want a constexpr here, but reinterpret_cast cannot be used // - we want to use std::bit_cast, but it requires C++20 which is too new // - we use the compiler builtin, fortunately both gcc and msvc have it return __builtin_bit_cast(__half, inf_half); #endif #else return 
*reinterpret_cast<__half*>(&inf_half); #endif } static constexpr bool has_infinity = true; }; } // namespace std #else // hipCUB internally uses std::numeric_limits, so we should provide specializations for the complex numbers. // Note that there's std::complex, so to avoid name collision we must use the full decoration (thrust::complex)! // TODO(leofang): wrap CuPy's thrust namespace with another one (say, cupy::thrust) for safer scope resolution? #define CUPY_CUB_NAMESPACE hipcub namespace std { template <> class numeric_limits> { public: static __host__ __device__ thrust::complex max() noexcept { return thrust::complex(std::numeric_limits::max(), std::numeric_limits::max()); } static __host__ __device__ thrust::complex lowest() noexcept { return thrust::complex(-std::numeric_limits::max(), -std::numeric_limits::max()); } static __host__ __device__ thrust::complex infinity() noexcept { return thrust::complex(std::numeric_limits::infinity(), std::numeric_limits::infinity()); } static constexpr bool has_infinity = true; }; template <> class numeric_limits> { public: static __host__ __device__ thrust::complex max() noexcept { return thrust::complex(std::numeric_limits::max(), std::numeric_limits::max()); } static __host__ __device__ thrust::complex lowest() noexcept { return thrust::complex(-std::numeric_limits::max(), -std::numeric_limits::max()); } static __host__ __device__ thrust::complex infinity() noexcept { return thrust::complex(std::numeric_limits::infinity(), std::numeric_limits::infinity()); } static constexpr bool has_infinity = true; }; // Copied from https://github.com/ROCmSoftwarePlatform/hipCUB/blob/master-rocm-3.5/hipcub/include/hipcub/backend/rocprim/device/device_reduce.hpp // (For some reason the specialization for __half defined in the above file does not work, so we have to go // through the same route as we did above for complex numbers.) 
template <> class numeric_limits<__half> { public: static __host__ __device__ __half max() noexcept { unsigned short max_half = 0x7bff; __half max_value = *reinterpret_cast<__half*>(&max_half); return max_value; } static __host__ __device__ __half lowest() noexcept { unsigned short lowest_half = 0xfbff; __half lowest_value = *reinterpret_cast<__half*>(&lowest_half); return lowest_value; } static __host__ __device__ __half infinity() noexcept { unsigned short inf_half = 0x7C00U; __half inf_value = *reinterpret_cast<__half*>(&inf_half); return inf_value; } static constexpr bool has_infinity = true; }; } // namespace std using namespace hipcub; #endif // ifndef CUPY_USE_HIP __host__ __device__ __half half_negate_inf() { unsigned short minf_half = 0xFC00U; __half* minf_value = reinterpret_cast<__half*>(&minf_half); return *minf_value; } /* ------------------------------------ end of boilerplate ------------------------------------ */ /* ------------------------------------ "Patches" to CUB ------------------------------------ This stub is needed because CUB does not have a built-in "prod" operator */ // // product functor // #ifdef CUPY_USE_HIP struct _multiply { template __host__ __device__ __forceinline__ T operator()(const T &a, const T &b) const { return a * b; } }; #else using _multiply = cuda::std::multiplies<>; #endif // // arange functor: arange(0, n+1) -> arange(0, n+1, step_size) // struct _arange { private: int step_size; public: __host__ __device__ __forceinline__ _arange(int i): step_size(i) {} __host__ __device__ __forceinline__ int operator()(const int &in) const { return step_size * in; } }; #ifndef CUPY_USE_HIP typedef thrust::transform_iterator<_arange, thrust::counting_iterator> seg_offset_itr; #else typedef TransformInputIterator> seg_offset_itr; #endif /* These stubs are needed because CUB does not handle NaNs properly, while NumPy has certain behaviors with which we must comply. 
CUDA/HIP have different signatures for Max/Min because of the recent changes in CCCL (for the former). */ #if ((__CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ == 2)) \ && (__CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__))) || (defined(__HIPCC__) || defined(CUPY_USE_HIP)) __host__ __device__ __forceinline__ bool half_isnan(const __half& x) { #if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) return __hisnan(x); #else // TODO: avoid cast to float return isnan(__half2float(x)); #endif } __host__ __device__ __forceinline__ bool half_less(const __half& l, const __half& r) { #if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) return l < r; #else // TODO: avoid cast to float return __half2float(l) < __half2float(r); #endif } __host__ __device__ __forceinline__ bool half_equal(const __half& l, const __half& r) { #if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) return l == r; #else // TODO: avoid cast to float return __half2float(l) == __half2float(r); #endif } #endif #ifdef CUPY_USE_HIP // // Max() // template struct select_max { using type = Max; }; // specialization for float for handling NaNs template <> __host__ __device__ __forceinline__ float Max::operator()(const float &a, const float &b) const { // NumPy behavior: NaN is always chosen! if (isnan(a)) {return a;} else if (isnan(b)) {return b;} else {return a < b ? b : a;} } // specialization for double for handling NaNs template <> __host__ __device__ __forceinline__ double Max::operator()(const double &a, const double &b) const { // NumPy behavior: NaN is always chosen! if (isnan(a)) {return a;} else if (isnan(b)) {return b;} else {return a < b ? 
b : a;} } // specialization for complex for handling NaNs template <> __host__ __device__ __forceinline__ complex Max::operator()(const complex &a, const complex &b) const { // - TODO(leofang): just call max() here when the bug in cupy/complex.cuh is fixed // - NumPy behavior: If both a and b contain NaN, the first argument is chosen // - isnan() and max() are defined in cupy/complex.cuh if (isnan(a)) {return a;} else if (isnan(b)) {return b;} else {return a < b ? b : a;} } // specialization for complex for handling NaNs template <> __host__ __device__ __forceinline__ complex Max::operator()(const complex &a, const complex &b) const { // - TODO(leofang): just call max() here when the bug in cupy/complex.cuh is fixed // - NumPy behavior: If both a and b contain NaN, the first argument is chosen // - isnan() and max() are defined in cupy/complex.cuh if (isnan(a)) {return a;} else if (isnan(b)) {return b;} else {return a < b ? b : a;} } #if ((__CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ == 2)) \ && (__CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__))) || (defined(__HIPCC__) || defined(CUPY_USE_HIP)) // specialization for half for handling NaNs template <> __host__ __device__ __forceinline__ __half Max::operator()(const __half &a, const __half &b) const { // NumPy behavior: NaN is always chosen! if (half_isnan(a)) {return a;} else if (half_isnan(b)) {return b;} else { return half_less(a, b) ? b : a; } } #endif // // Min() // template struct select_min { using type = Min; }; // specialization for float for handling NaNs template <> __host__ __device__ __forceinline__ float Min::operator()(const float &a, const float &b) const { // NumPy behavior: NaN is always chosen! if (isnan(a)) {return a;} else if (isnan(b)) {return b;} else {return a < b ? 
a : b;} } // specialization for double for handling NaNs template <> __host__ __device__ __forceinline__ double Min::operator()(const double &a, const double &b) const { // NumPy behavior: NaN is always chosen! if (isnan(a)) {return a;} else if (isnan(b)) {return b;} else {return a < b ? a : b;} } // specialization for complex for handling NaNs template <> __host__ __device__ __forceinline__ complex Min::operator()(const complex &a, const complex &b) const { // - TODO(leofang): just call min() here when the bug in cupy/complex.cuh is fixed // - NumPy behavior: If both a and b contain NaN, the first argument is chosen // - isnan() and min() are defined in cupy/complex.cuh if (isnan(a)) {return a;} else if (isnan(b)) {return b;} else {return a < b ? a : b;} } // specialization for complex for handling NaNs template <> __host__ __device__ __forceinline__ complex Min::operator()(const complex &a, const complex &b) const { // - TODO(leofang): just call min() here when the bug in cupy/complex.cuh is fixed // - NumPy behavior: If both a and b contain NaN, the first argument is chosen // - isnan() and min() are defined in cupy/complex.cuh if (isnan(a)) {return a;} else if (isnan(b)) {return b;} else {return a < b ? a : b;} } #if ((__CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ == 2)) \ && (__CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__))) || (defined(__HIPCC__) || defined(CUPY_USE_HIP)) // specialization for half for handling NaNs template <> __host__ __device__ __forceinline__ __half Min::operator()(const __half &a, const __half &b) const { // NumPy behavior: NaN is always chosen! if (half_isnan(a)) {return a;} else if (half_isnan(b)) {return b;} else { return half_less(a, b) ? 
a : b; } } #endif #endif // ifdef CUPY_USE_HIP // // ArgMax() // // specialization for float for handling NaNs template <> __host__ __device__ __forceinline__ KeyValuePair ArgMax::operator()( const KeyValuePair &a, const KeyValuePair &b) const { if (isnan(a.value)) return a; else if (isnan(b.value)) return b; else if ((b.value > a.value) || ((a.value == b.value) && (b.key < a.key))) return b; else return a; } // specialization for double for handling NaNs template <> __host__ __device__ __forceinline__ KeyValuePair ArgMax::operator()( const KeyValuePair &a, const KeyValuePair &b) const { if (isnan(a.value)) return a; else if (isnan(b.value)) return b; else if ((b.value > a.value) || ((a.value == b.value) && (b.key < a.key))) return b; else return a; } // specialization for complex for handling NaNs template <> __host__ __device__ __forceinline__ KeyValuePair> ArgMax::operator()( const KeyValuePair> &a, const KeyValuePair> &b) const { if (isnan(a.value)) return a; else if (isnan(b.value)) return b; else if ((b.value > a.value) || ((a.value == b.value) && (b.key < a.key))) return b; else return a; } // specialization for complex for handling NaNs template <> __host__ __device__ __forceinline__ KeyValuePair> ArgMax::operator()( const KeyValuePair> &a, const KeyValuePair> &b) const { if (isnan(a.value)) return a; else if (isnan(b.value)) return b; else if ((b.value > a.value) || ((a.value == b.value) && (b.key < a.key))) return b; else return a; } #if ((__CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ == 2)) \ && (__CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__))) || (defined(__HIPCC__) || defined(CUPY_USE_HIP)) // specialization for half for handling NaNs template <> __host__ __device__ __forceinline__ KeyValuePair ArgMax::operator()( const KeyValuePair &a, const KeyValuePair &b) const { if (half_isnan(a.value)) return a; else if (half_isnan(b.value)) return b; else if ((half_less(a.value, b.value)) || (half_equal(a.value, b.value) && 
(b.key < a.key))) return b; else return a; } #endif // // ArgMin() // // specialization for float for handling NaNs template <> __host__ __device__ __forceinline__ KeyValuePair ArgMin::operator()( const KeyValuePair &a, const KeyValuePair &b) const { if (isnan(a.value)) return a; else if (isnan(b.value)) return b; else if ((b.value < a.value) || ((a.value == b.value) && (b.key < a.key))) return b; else return a; } // specialization for double for handling NaNs template <> __host__ __device__ __forceinline__ KeyValuePair ArgMin::operator()( const KeyValuePair &a, const KeyValuePair &b) const { if (isnan(a.value)) return a; else if (isnan(b.value)) return b; else if ((b.value < a.value) || ((a.value == b.value) && (b.key < a.key))) return b; else return a; } // specialization for complex for handling NaNs template <> __host__ __device__ __forceinline__ KeyValuePair> ArgMin::operator()( const KeyValuePair> &a, const KeyValuePair> &b) const { if (isnan(a.value)) return a; else if (isnan(b.value)) return b; else if ((b.value < a.value) || ((a.value == b.value) && (b.key < a.key))) return b; else return a; } // specialization for complex for handling NaNs template <> __host__ __device__ __forceinline__ KeyValuePair> ArgMin::operator()( const KeyValuePair> &a, const KeyValuePair> &b) const { if (isnan(a.value)) return a; else if (isnan(b.value)) return b; else if ((b.value < a.value) || ((a.value == b.value) && (b.key < a.key))) return b; else return a; } #if ((__CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ == 2)) \ && (__CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__))) || (defined(__HIPCC__) || defined(CUPY_USE_HIP)) // specialization for half for handling NaNs template <> __host__ __device__ __forceinline__ KeyValuePair ArgMin::operator()( const KeyValuePair &a, const KeyValuePair &b) const { if (half_isnan(a.value)) return a; else if (half_isnan(b.value)) return b; else if ((half_less(b.value, a.value)) || (half_equal(a.value, b.value) && 
(b.key < a.key))) return b; else return a; } #endif #ifndef CUPY_USE_HIP // // Max() // template struct select_max { #if CCCL_VERSION >= 2008000 using type = cuda::maximum<>; #else using type = cub::Max; #endif }; template struct nan_handling_max { __host__ __device__ __forceinline__ T operator()(const T &a, const T &b) const { // NumPy behavior: NaN is always chosen! if (isnan(a)) {return a;} else if (isnan(b)) {return b;} else {return a < b ? b : a;} } }; template <> struct select_max { using type = nan_handling_max; }; template <> struct select_max { using type = nan_handling_max; }; // - NumPy behavior: If both a and b contain NaN, the first argument is chosen // - isnan() and max() are defined in cupy/complex.cuh template <> struct select_max> { using type = nan_handling_max>; }; template <> struct select_max> { using type = nan_handling_max>; }; #if ((__CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ == 2)) \ && (__CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__))) || (defined(__HIPCC__) || defined(CUPY_USE_HIP)) template <> struct select_max<__half> { struct type { __host__ __device__ __forceinline__ __half operator()(const __half &a, const __half &b) const { // NumPy behavior: NaN is always chosen! if (half_isnan(a)) {return a;} else if (half_isnan(b)) {return b;} else { return half_less(a, b) ? b : a; } } }; }; #endif // // Min() // template struct select_min { #if CCCL_VERSION >= 2008000 using type = cuda::minimum<>; #else using type = cub::Min; #endif }; template struct nan_handling_min { __host__ __device__ __forceinline__ T operator()(const T &a, const T &b) const { // NumPy behavior: NaN is always chosen! if (isnan(a)) {return a;} else if (isnan(b)) {return b;} else {return a < b ? 
a : b;} } }; template <> struct select_min { using type = nan_handling_min; }; template <> struct select_min { using type = nan_handling_min; }; // - NumPy behavior: If both a and b contain NaN, the first argument is chosen // - isnan() and min() are defined in cupy/complex.cuh template <> struct select_min> { using type = nan_handling_min>; }; template <> struct select_min> { using type = nan_handling_min>; }; #if ((__CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ == 2)) \ && (__CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__))) || (defined(__HIPCC__) || defined(CUPY_USE_HIP)) template <> struct select_min<__half> { struct type { __host__ __device__ __forceinline__ __half operator()(const __half &a, const __half &b) const { // NumPy behavior: NaN is always chosen! if (half_isnan(a)) {return a;} else if (half_isnan(b)) {return b;} else { return half_less(a, b) ? a: b; } } }; }; #endif #endif // #ifndef CUPY_USE_HIP /* ------------------------------------ End of "patches" ------------------------------------ */ // // **** CUB Sum **** // struct _cub_reduce_sum { template void operator()(void* workspace, size_t& workspace_size, void* x, void* y, int num_items, cudaStream_t s) { DeviceReduce::Sum(workspace, workspace_size, static_cast(x), static_cast(y), num_items, s); } }; struct _cub_segmented_reduce_sum { template void operator()(void* workspace, size_t& workspace_size, void* x, void* y, int num_segments, seg_offset_itr offset_start, cudaStream_t s) { DeviceSegmentedReduce::Sum(workspace, workspace_size, static_cast(x), static_cast(y), num_segments, offset_start, offset_start+1, s); } }; // // **** CUB Prod **** // struct _cub_reduce_prod { template void operator()(void* workspace, size_t& workspace_size, void* x, void* y, int num_items, cudaStream_t s) { _multiply product_op; // the init value is cast from 1.0f because on host __half can only be // initialized by float or double; static_cast<__half>(1) = 0 on host. 
DeviceReduce::Reduce(workspace, workspace_size, static_cast(x), static_cast(y), num_items, product_op, static_cast(1.0f), s); } }; struct _cub_segmented_reduce_prod { template void operator()(void* workspace, size_t& workspace_size, void* x, void* y, int num_segments, seg_offset_itr offset_start, cudaStream_t s) { _multiply product_op; // the init value is cast from 1.0f because on host __half can only be // initialized by float or double; static_cast<__half>(1) = 0 on host. DeviceSegmentedReduce::Reduce(workspace, workspace_size, static_cast(x), static_cast(y), num_segments, offset_start, offset_start+1, product_op, static_cast(1.0f), s); } }; // // **** CUB Min **** // struct _cub_reduce_min { template void operator()(void* workspace, size_t& workspace_size, void* x, void* y, int num_items, cudaStream_t s) { if constexpr (std::numeric_limits::has_infinity) { DeviceReduce::Reduce(workspace, workspace_size, static_cast(x), static_cast(y), num_items, typename select_min::type{}, std::numeric_limits::infinity(), s); } else { DeviceReduce::Min(workspace, workspace_size, static_cast(x), static_cast(y), num_items, s); } } }; struct _cub_segmented_reduce_min { template void operator()(void* workspace, size_t& workspace_size, void* x, void* y, int num_segments, seg_offset_itr offset_start, cudaStream_t s) { if constexpr (std::numeric_limits::has_infinity) { DeviceSegmentedReduce::Reduce(workspace, workspace_size, static_cast(x), static_cast(y), num_segments, offset_start, offset_start+1, typename select_min::type{}, std::numeric_limits::infinity(), s); } else { DeviceSegmentedReduce::Min(workspace, workspace_size, static_cast(x), static_cast(y), num_segments, offset_start, offset_start+1, s); } } }; // // **** CUB Max **** // struct _cub_reduce_max { template void operator()(void* workspace, size_t& workspace_size, void* x, void* y, int num_items, cudaStream_t s) { if constexpr (std::numeric_limits::has_infinity) { // to avoid compiler error: invalid argument type 
'__half' to unary expression on HIP... if constexpr (std::is_same_v) { DeviceReduce::Reduce(workspace, workspace_size, static_cast(x), static_cast(y), num_items, typename select_max::type{}, half_negate_inf(), s); } else { DeviceReduce::Reduce(workspace, workspace_size, static_cast(x), static_cast(y), num_items, typename select_max::type{}, -std::numeric_limits::infinity(), s); } } else { DeviceReduce::Max(workspace, workspace_size, static_cast(x), static_cast(y), num_items, s); } } }; struct _cub_segmented_reduce_max { template void operator()(void* workspace, size_t& workspace_size, void* x, void* y, int num_segments, seg_offset_itr offset_start, cudaStream_t s) { if constexpr (std::numeric_limits::has_infinity) { // to avoid compiler error: invalid argument type '__half' to unary expression on HIP... if constexpr (std::is_same_v) { DeviceSegmentedReduce::Reduce(workspace, workspace_size, static_cast(x), static_cast(y), num_segments, offset_start, offset_start+1, typename select_max::type{}, half_negate_inf(), s); } else { DeviceSegmentedReduce::Reduce(workspace, workspace_size, static_cast(x), static_cast(y), num_segments, offset_start, offset_start+1, typename select_max::type{}, -std::numeric_limits::infinity(), s); } } else { DeviceSegmentedReduce::Max(workspace, workspace_size, static_cast(x), static_cast(y), num_segments, offset_start, offset_start+1, s); } } }; // // **** CUB ArgMin **** // struct _cub_reduce_argmin { template void operator()(void* workspace, size_t& workspace_size, void* x, void* y, int num_items, cudaStream_t s) { DeviceReduce::ArgMin(workspace, workspace_size, static_cast(x), static_cast*>(y), num_items, s); } }; // TODO(leofang): add _cub_segmented_reduce_argmin // // **** CUB ArgMax **** // struct _cub_reduce_argmax { template void operator()(void* workspace, size_t& workspace_size, void* x, void* y, int num_items, cudaStream_t s) { DeviceReduce::ArgMax(workspace, workspace_size, static_cast(x), static_cast*>(y), num_items, s); } }; 
// TODO(leofang): add _cub_segmented_reduce_argmax

//
// **** CUB InclusiveSum ****
//
// dtype_dispatcher functor: inclusive prefix sum (cumsum) over `num_items`
// elements. `input` and `output` are both reinterpreted as T*, and the work
// runs on stream `s`. Following the CUB convention, a NULL `workspace` only
// stores the required temporary-storage size in `workspace_size`.
struct _cub_inclusive_sum {
    template <typename T>
    void operator()(void* workspace, size_t& workspace_size, void* input, void* output,
                    int num_items, cudaStream_t s) {
        DeviceScan::InclusiveSum(workspace, workspace_size, static_cast<T*>(input),
                                 static_cast<T*>(output), num_items, s);
    }
};

//
// **** CUB inclusive product ****
//
// dtype_dispatcher functor: inclusive prefix product (cumprod) over
// `num_items` elements using the _multiply functor. Buffer and workspace
// conventions are the same as for _cub_inclusive_sum.
struct _cub_inclusive_product {
    template <typename T>
    void operator()(void* workspace, size_t& workspace_size, void* input, void* output,
                    int num_items, cudaStream_t s) {
        _multiply product_op;
        DeviceScan::InclusiveScan(workspace, workspace_size, static_cast<T*>(input),
                                  static_cast<T*>(output), product_op, num_items, s);
    }
};

//
// **** CUB histogram range ****
//
// dtype_dispatcher functor: histograms the `n_samples` values of `input`
// against the boundary values in `bins`, writing counts to `output`.
// - `n_bins` is forwarded as CUB's num_levels, i.e. the number of boundary
//   values. NOTE(review): presumably the caller passes len(bins) here; confirm.
// - Integral sample types default to double bin boundaries (binT).
// - Complex sample/bin types are mapped to double only so the template
//   instantiates; cub_device_histogram_range rejects complex dtypes up front.
// - n_samples is narrowed to int (see the CUB TODO below). The entry points
//   reject values that do not fit.
struct _cub_histogram_range {
    template <typename sampleT,
              typename binT = typename std::conditional<std::is_integral<sampleT>::value, double, sampleT>::type>
    void operator()(void* workspace, size_t& workspace_size, void* input, void* output,
                    int n_bins, void* bins, size_t n_samples, cudaStream_t s) const {
        // Ugly hack to avoid specializing complex types, which cub::DeviceHistogram does not support.
        // TODO(leofang): revisit this part when complex support is added to cupy.histogram()
        typedef typename std::conditional<(std::is_same<sampleT, complex<float>>::value
                                           || std::is_same<sampleT, complex<double>>::value),
                                          double, sampleT>::type h_sampleT;
        typedef typename std::conditional<(std::is_same<binT, complex<float>>::value
                                           || std::is_same<binT, complex<double>>::value),
                                          double, binT>::type h_binT;

        // TODO(leofang): CUB has a bug that when specializing n_samples with type size_t,
        // it would error out. Before the fix (thrust/cub#38) is merged we disable the code
        // path splitting for now. A type/range check must be done in the caller.
        // TODO(leofang): check if hipCUB has the same bug or not

        // if (n_samples < (1ULL << 31)) {
        int num_samples = static_cast<int>(n_samples);
        DeviceHistogram::HistogramRange(workspace, workspace_size, static_cast<h_sampleT*>(input),
#ifndef CUPY_USE_HIP
            static_cast<long long*>(output), n_bins, static_cast<h_binT*>(bins), num_samples, s);
#else
            // rocPRIM looks up atomic_add() from the namespace rocprim::detail; there's no way we can
            // inject a "long long" version as we did for CUDA, so we must do it in "unsigned long long"
            // and convert later...
            static_cast<unsigned long long*>(output), n_bins, static_cast<h_binT*>(bins), num_samples, s);
#endif
        // } else {
        //     DeviceHistogram::HistogramRange(workspace, workspace_size, static_cast<h_sampleT*>(input),
        //         static_cast<long long*>(output), n_bins, static_cast<h_binT*>(bins), n_samples, s);
        // }
    }
};

//
// **** CUB histogram even ****
//
// dtype_dispatcher functor (CUDA only): histograms the `n_samples` values of
// `input` into equal-width bins between `lower` and `upper`, writing
// long long counts to `output`. `n_bins`/`lower`/`upper` are forwarded as
// CUB's num_levels/lower_level/upper_level. Non-integral sample types are
// reinterpreted as int only so the template instantiates.
// NOTE(review): presumably only integral dtypes reach this; confirm at the caller.
struct _cub_histogram_even {
    template <typename sampleT>
    void operator()(void* workspace, size_t& workspace_size, void* input, void* output,
                    int& n_bins, int& lower, int& upper, size_t n_samples, cudaStream_t s) const {
#ifndef CUPY_USE_HIP
        // Ugly hack to avoid specializing numerical types
        typedef typename std::conditional<std::is_integral<sampleT>::value, sampleT, int>::type h_sampleT;
        int num_samples = static_cast<int>(n_samples);
        static_assert(sizeof(long long) == sizeof(intptr_t), "not supported");
        DeviceHistogram::HistogramEven(workspace, workspace_size, static_cast<h_sampleT*>(input),
            static_cast<long long*>(output), n_bins, lower, upper, num_samples, s);
#else
        throw std::runtime_error("HIP is not supported yet");
#endif
    }
};

//
// APIs exposed to CuPy
//
// Every entry point follows the CUB two-phase protocol: call it with
// workspace == NULL to obtain the temporary-storage size in `workspace_size`
// (the *_get_workspace_size helpers do exactly that), then call it again with
// a buffer of that size. Unknown `op` values throw std::runtime_error.
//

// Throws if `n_samples` cannot be narrowed to the int sample count that the
// histogram functors hand to CUB. Without this check the count would be
// silently truncated and produce a wrong histogram.
static void _cub_check_histogram_num_samples(size_t n_samples) {
    if (n_samples > static_cast<size_t>(std::numeric_limits<int>::max())) {
        throw std::runtime_error("n_samples is too large (must be at most INT_MAX)");
    }
}

/* -------- device reduce -------- */

// Full reduction of `num_items` elements of `x` (dtype `dtype_id`) into `y`
// with the CUB operation selected by `op`.
void cub_device_reduce(void* workspace, size_t& workspace_size, void* x, void* y,
                       int num_items, cudaStream_t stream, int op, int dtype_id) {
    switch (op) {
    case CUPY_CUB_SUM:
        return dtype_dispatcher(dtype_id, _cub_reduce_sum(),
                                workspace, workspace_size, x, y, num_items, stream);
    case CUPY_CUB_MIN:
        return dtype_dispatcher(dtype_id, _cub_reduce_min(),
                                workspace, workspace_size, x, y, num_items, stream);
    case CUPY_CUB_MAX:
        return dtype_dispatcher(dtype_id, _cub_reduce_max(),
                                workspace, workspace_size, x, y, num_items, stream);
    case CUPY_CUB_ARGMIN:
        return dtype_dispatcher(dtype_id, _cub_reduce_argmin(),
                                workspace, workspace_size, x, y, num_items, stream);
    case CUPY_CUB_ARGMAX:
        return dtype_dispatcher(dtype_id, _cub_reduce_argmax(),
                                workspace, workspace_size, x, y, num_items, stream);
    case CUPY_CUB_PROD:
        return dtype_dispatcher(dtype_id, _cub_reduce_prod(),
                                workspace, workspace_size, x, y, num_items, stream);
    default:
        throw std::runtime_error("Unsupported operation");
    }
}

// Returns the temporary-storage size cub_device_reduce needs for these arguments.
size_t cub_device_reduce_get_workspace_size(void* x, void* y, int num_items,
                                            cudaStream_t stream, int op, int dtype_id) {
    size_t workspace_size = 0;
    cub_device_reduce(NULL, workspace_size, x, y, num_items, stream, op, dtype_id);
    return workspace_size;
}

/* -------- device segmented reduce -------- */

// Reduces `num_segments` contiguous segments of `segment_size` elements each.
// Segment i covers [i*segment_size, (i+1)*segment_size) of `x`, and its result
// is written to y[i].
void cub_device_segmented_reduce(void* workspace, size_t& workspace_size, void* x, void* y,
                                 int num_segments, int segment_size,
                                 cudaStream_t stream, int op, int dtype_id) {
    // CUB internally use int for offset...
    // This iterates over [0, segment_size, 2*segment_size, 3*segment_size, ...]
#ifndef CUPY_USE_HIP
    thrust::counting_iterator<int> count_itr(0);
#else
    rocprim::counting_iterator<int> count_itr(0);
#endif
    _arange scaling(segment_size);
    seg_offset_itr itr(count_itr, scaling);

    switch (op) {
    case CUPY_CUB_SUM:
        return dtype_dispatcher(dtype_id, _cub_segmented_reduce_sum(),
                                workspace, workspace_size, x, y, num_segments, itr, stream);
    case CUPY_CUB_MIN:
        return dtype_dispatcher(dtype_id, _cub_segmented_reduce_min(),
                                workspace, workspace_size, x, y, num_segments, itr, stream);
    case CUPY_CUB_MAX:
        return dtype_dispatcher(dtype_id, _cub_segmented_reduce_max(),
                                workspace, workspace_size, x, y, num_segments, itr, stream);
    case CUPY_CUB_PROD:
        return dtype_dispatcher(dtype_id, _cub_segmented_reduce_prod(),
                                workspace, workspace_size, x, y, num_segments, itr, stream);
    default:
        throw std::runtime_error("Unsupported operation");
    }
}

// Returns the temporary-storage size cub_device_segmented_reduce needs.
size_t cub_device_segmented_reduce_get_workspace_size(void* x, void* y, int num_segments,
                                                      int segment_size, cudaStream_t stream,
                                                      int op, int dtype_id) {
    size_t workspace_size = 0;
    cub_device_segmented_reduce(NULL, workspace_size, x, y, num_segments, segment_size,
                                stream, op, dtype_id);
    return workspace_size;
}

/* -------- device scan -------- */

// Inclusive scan (CUPY_CUB_CUMSUM or CUPY_CUB_CUMPROD) of `num_items`
// elements of `x` into `y`.
void cub_device_scan(void* workspace, size_t& workspace_size, void* x, void* y,
                     int num_items, cudaStream_t stream, int op, int dtype_id) {
    switch (op) {
    case CUPY_CUB_CUMSUM:
        return dtype_dispatcher(dtype_id, _cub_inclusive_sum(),
                                workspace, workspace_size, x, y, num_items, stream);
    case CUPY_CUB_CUMPROD:
        return dtype_dispatcher(dtype_id, _cub_inclusive_product(),
                                workspace, workspace_size, x, y, num_items, stream);
    default:
        throw std::runtime_error("Unsupported operation");
    }
}

// Returns the temporary-storage size cub_device_scan needs.
size_t cub_device_scan_get_workspace_size(void* x, void* y, int num_items,
                                          cudaStream_t stream, int op, int dtype_id) {
    size_t workspace_size = 0;
    cub_device_scan(NULL, workspace_size, x, y, num_items, stream, op, dtype_id);
    return workspace_size;
}

/* -------- device histogram -------- */

// Histogram of `n_samples` values of `x` over the boundaries in `bins`,
// written to `y`. Throws for complex dtypes and for n_samples > INT_MAX.
void cub_device_histogram_range(void* workspace, size_t& workspace_size, void* x, void* y,
                                int n_bins, void* bins, size_t n_samples,
                                cudaStream_t stream, int dtype_id) {
    // TODO(leofang): support complex
    if (dtype_id == CUPY_TYPE_COMPLEX64 || dtype_id == CUPY_TYPE_COMPLEX128) {
        throw std::runtime_error("complex dtype is not yet supported");
    }
    // n_samples is size_t, but _cub_histogram_range narrows it to int.
    _cub_check_histogram_num_samples(n_samples);
    return dtype_dispatcher(dtype_id, _cub_histogram_range(),
                            workspace, workspace_size, x, y, n_bins, bins, n_samples, stream);
}

// Returns the temporary-storage size cub_device_histogram_range needs.
size_t cub_device_histogram_range_get_workspace_size(void* x, void* y, int n_bins, void* bins,
                                                     size_t n_samples, cudaStream_t stream,
                                                     int dtype_id) {
    size_t workspace_size = 0;
    cub_device_histogram_range(NULL, workspace_size, x, y, n_bins, bins, n_samples,
                               stream, dtype_id);
    return workspace_size;
}

// Equal-width histogram of `n_samples` values of `x` between `lower` and
// `upper`, written to `y`. CUDA only: under HIP this throws instead of
// silently doing nothing, matching _cub_histogram_even.
void cub_device_histogram_even(void* workspace, size_t& workspace_size, void* x, void* y,
                               int n_bins, int lower, int upper, size_t n_samples,
                               cudaStream_t stream, int dtype_id) {
#ifndef CUPY_USE_HIP
    // n_samples is size_t, but _cub_histogram_even narrows it to int.
    _cub_check_histogram_num_samples(n_samples);
    return dtype_dispatcher(dtype_id, _cub_histogram_even(),
                            workspace, workspace_size, x, y, n_bins, lower, upper,
                            n_samples, stream);
#else
    throw std::runtime_error("HIP is not supported yet");
#endif
}

// Returns the temporary-storage size cub_device_histogram_even needs.
size_t cub_device_histogram_even_get_workspace_size(void* x, void* y, int n_bins, int lower,
                                                    int upper, size_t n_samples,
                                                    cudaStream_t stream, int dtype_id) {
    size_t workspace_size = 0;
    cub_device_histogram_even(NULL, workspace_size, x, y, n_bins, lower, upper, n_samples,
                              stream, dtype_id);
    return workspace_size;
}