/****************************************************************************** * Copyright (c) 2011, Duane Merrill. All rights reserved. * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright * notice, this list of conditions and the following disclaimer. * * Redistributions in binary form must reproduce the above copyright * notice, this list of conditions and the following disclaimer in the * documentation and/or other materials provided with the distribution. * * Neither the name of the NVIDIA CORPORATION nor the * names of its contributors may be used to endorse or promote products * derived from this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
* ******************************************************************************/

/**
 * \file
 * Callback operator types for supplying BlockScan prefixes, together with the tile-status
 * storage types and spin-wait delay policies used by single-pass scans in which each tile
 * looks back at the status of its predecessor tiles.
 */

// NOTE(review): in this copy of the file the text between angle brackets appears to have been
// stripped in many places: the #include directives below carry no header names, most `template`
// keywords have no parameter list, and casts such as `reinterpret_cast(...)` have no target type.
// The code cannot compile as-is; restore the missing text from the upstream CUB sources.

#pragma once

#include

#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
#  pragma GCC system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
#  pragma clang system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
#  pragma system_header
#endif // no system header

#include
#include
#include
#include
#include
#include
#include
#include
#include
#include

CUB_NAMESPACE_BEGIN

/******************************************************************************
 * Prefix functor type for maintaining a running prefix while scanning a
 * region independent of other thread blocks
 ******************************************************************************/

/**
 * Stateful callback operator type for supplying BlockScan prefixes.
 * Maintains a running prefix that can be applied to consecutive
 * BlockScan operations: each invocation hands out the prefix accumulated so far
 * and then folds the new block aggregate into it.
 *
 * @tparam T
 *   BlockScan value type
 *
 * @tparam ScanOpT
 *   Wrapped scan operator type
 */
template
struct BlockScanRunningPrefixOp
{
  /// Wrapped scan operator
  ScanOpT op;

  /// Running block-wide prefix
  T running_total;

  /// Constructor. Note: this overload does NOT initialize `running_total`; it must be
  /// assigned before the first call to operator().
  _CCCL_DEVICE _CCCL_FORCEINLINE BlockScanRunningPrefixOp(ScanOpT op)
      : op(op)
  {}

  /// Constructor. Seeds the running prefix with @p starting_prefix.
  _CCCL_DEVICE _CCCL_FORCEINLINE BlockScanRunningPrefixOp(T starting_prefix, ScanOpT op)
      : op(op)
      , running_total(starting_prefix)
  {}

  /**
   * Prefix callback operator. Returns the running total accumulated before this call (the
   * block-wide prefix is the value returned in thread-0) and then advances the running total
   * to `op(running_total, block_aggregate)`. No synchronization is performed here.
   *
   * @param block_aggregate
   *   The aggregate sum of the BlockScan inputs
   *
   * @return The running prefix as it was before @p block_aggregate was folded in
   */
  _CCCL_DEVICE _CCCL_FORCEINLINE T operator()(const T& block_aggregate)
  {
    T retval      = running_total; // exclusive prefix for the current BlockScan
    running_total = op(running_total, block_aggregate);
    return retval;
  }
};

/******************************************************************************
 * Generic tile status interface types for block-cooperative scans
 ******************************************************************************/

/**
 * Enumerations of tile status.
 *
 * Resulting numeric values: SCAN_TILE_OOB = 0, SCAN_TILE_INVALID = 99,
 * SCAN_TILE_PARTIAL = 100, SCAN_TILE_INCLUSIVE = 101.
 */
enum ScanTileStatus
{
  SCAN_TILE_OOB, // Out-of-bounds (e.g., padding)
  SCAN_TILE_INVALID = 99, // Not yet processed
  SCAN_TILE_PARTIAL, // Tile aggregate is available
  SCAN_TILE_INCLUSIVE, // Inclusive tile prefix is available
};

/**
 * Enum class used for specifying the memory order that shall be enforced while reading and writing the tile status.
 */
enum class MemoryOrder
{
  // Uses relaxed loads when reading a tile's status and relaxed stores when updating a tile's status
  relaxed,
  // Uses load acquire when reading a tile's status and store release when updating a tile's status
  acquire_release
};

namespace detail
{
// Spin-wait helpers. Their NV_PROVIDES_SM_70 branches use __nanosleep; on older targets they either
// emit nothing or, for the `*_or_prevent_hoisting` variants, fall back to __threadfence_block() so
// that the compiler cannot hoist a status load out of a polling loop. `Delay` and `GridThreshold`
// are template parameters whose declarations are missing from this copy.

/// On SM70+, when Delay > 0: issues __threadfence_block() if the grid has fewer than GridThreshold
/// blocks in x, otherwise sleeps via __nanosleep(Delay). Emits nothing on older targets.
template
_CCCL_DEVICE _CCCL_FORCEINLINE void delay()
{
  NV_IF_TARGET(NV_PROVIDES_SM_70,
               (if (Delay > 0) {
                 if (gridDim.x < GridThreshold)
                 {
                   __threadfence_block();
                 }
                 else
                 {
                   __nanosleep(Delay);
                 }
               }));
}

/// Runtime-duration variant of delay(): same behaviour with @p ns in place of Delay.
/// NOTE(review): unlike always_delay(int), there is no `(void) ns;` fallback, so `ns` is unused on
/// pre-SM70 targets (possible unused-parameter warning).
template
_CCCL_DEVICE _CCCL_FORCEINLINE void delay(int ns)
{
  NV_IF_TARGET(NV_PROVIDES_SM_70,
               (if (ns > 0) {
                 if (gridDim.x < GridThreshold)
                 {
                   __threadfence_block();
                 }
                 else
                 {
                   __nanosleep(ns);
                 }
               }));
}

/// Unconditionally sleeps for Delay ns on SM70+; emits nothing on older targets.
template
_CCCL_DEVICE _CCCL_FORCEINLINE void always_delay()
{
  NV_IF_TARGET(NV_PROVIDES_SM_70, (__nanosleep(Delay);));
}

/// Unconditionally sleeps for @p ns ns on SM70+; emits nothing on older targets.
_CCCL_DEVICE _CCCL_FORCEINLINE void always_delay(int ns)
{
  NV_IF_TARGET(NV_PROVIDES_SM_70, (__nanosleep(ns);), ((void) ns;));
}

/// SM70+: delay() (template arguments not visible in this copy); older targets: __threadfence_block().
template
_CCCL_DEVICE _CCCL_FORCEINLINE void delay_or_prevent_hoisting()
{
  NV_IF_TARGET(NV_PROVIDES_SM_70, (delay();), (__threadfence_block();));
}

/// SM70+: delay(ns); older targets: __threadfence_block().
template
_CCCL_DEVICE _CCCL_FORCEINLINE void delay_or_prevent_hoisting(int ns)
{
  NV_IF_TARGET(NV_PROVIDES_SM_70, (delay(ns);), ((void) ns; __threadfence_block();));
}

/// SM70+: always_delay(Delay); older targets: __threadfence_block().
template
_CCCL_DEVICE _CCCL_FORCEINLINE void always_delay_or_prevent_hoisting()
{
  NV_IF_TARGET(NV_PROVIDES_SM_70, (always_delay(Delay);), (__threadfence_block();));
}

/// SM70+: always_delay(ns); older targets: __threadfence_block().
_CCCL_DEVICE _CCCL_FORCEINLINE void always_delay_or_prevent_hoisting(int ns)
{
  NV_IF_TARGET(NV_PROVIDES_SM_70, (always_delay(ns);), ((void) ns; __threadfence_block();));
}

// "Delay constructor" policies. Each is constructed from an `unsigned int` seed (ignored by the
// non-jitter variants), may wait once in its constructor, and returns from operator()() a `delay_t`
// functor; every call of a `delay_t` performs one wait step.

/// No per-poll delay: the constructor waits once via delay() (template arguments not visible);
/// the returned delay_t does nothing on SM70+ and only issues __threadfence_block() on older targets.
template
struct no_delay_constructor_t
{
  struct delay_t
  {
    _CCCL_DEVICE _CCCL_FORCEINLINE void operator()()
    {
      NV_IF_TARGET(NV_PROVIDES_SM_70, (), (__threadfence_block();));
    }
  };

  _CCCL_DEVICE _CCCL_FORCEINLINE no_delay_constructor_t(unsigned int /* seed */)
  {
    delay();
  }

  _CCCL_DEVICE _CCCL_FORCEINLINE delay_t operator()()
  {
    return {};
  }
};

/// Delay policy with architecture-specific behaviour: on exactly SM80 delay_t calls delay()
/// (template arguments not visible); on other SM70+ targets it calls delay<0, GridThreshold>(),
/// presumably a no-op because its Delay is 0; on older devices it issues __threadfence_block().
/// The constructor waits once via delay().
template
struct reduce_by_key_delay_constructor_t
{
  struct delay_t
  {
    _CCCL_DEVICE _CCCL_FORCEINLINE void operator()()
    {
      NV_DISPATCH_TARGET(
        NV_IS_EXACTLY_SM_80,
        (delay();),
        NV_PROVIDES_SM_70,
        (delay<0, GridThreshold>();),
        NV_IS_DEVICE,
        (__threadfence_block();));
    }
  };

  _CCCL_DEVICE _CCCL_FORCEINLINE reduce_by_key_delay_constructor_t(unsigned int /* seed */)
  {
    delay();
  }

  _CCCL_DEVICE _CCCL_FORCEINLINE delay_t operator()()
  {
    return {};
  }
};

/// Fixed delay per poll via delay_or_prevent_hoisting() (template arguments not visible);
/// the constructor waits once via delay().
template
struct fixed_delay_constructor_t
{
  struct delay_t
  {
    _CCCL_DEVICE _CCCL_FORCEINLINE void operator()()
    {
      delay_or_prevent_hoisting();
    }
  };

  _CCCL_DEVICE _CCCL_FORCEINLINE fixed_delay_constructor_t(unsigned int /* seed */)
  {
    delay();
  }

  _CCCL_DEVICE _CCCL_FORCEINLINE delay_t operator()()
  {
    return {};
  }
};

/// Exponential backoff: each delay_t starts at InitialDelay ns and doubles after every wait.
template
struct exponential_backoff_constructor_t
{
  struct delay_t
  {
    // NOTE(review): signed and doubled with no cap; after enough retries `delay <<= 1` overflows
    // `int` (eventually undefined behaviour). Consider an unsigned type and/or a maximum delay.
    int delay;

    _CCCL_DEVICE _CCCL_FORCEINLINE void operator()()
    {
      always_delay_or_prevent_hoisting(delay);
      delay <<= 1;
    }
  };

  _CCCL_DEVICE _CCCL_FORCEINLINE exponential_backoff_constructor_t(unsigned int /* seed */)
  {
    always_delay();
  }

  _CCCL_DEVICE _CCCL_FORCEINLINE delay_t operator()()
  {
    return {InitialDelay};
  }
};

/// Exponential backoff with jitter: each wait sleeps a pseudo-random duration in [0, max_delay], then
/// max_delay doubles. delay_t holds a reference to this object's `seed`, so all delay_t objects it
/// returns share one generator state and must not outlive this object.
template
struct exponential_backoff_jitter_constructor_t
{
  struct delay_t
  {
    // Linear congruential generator: seed' = (a * seed + c) mod m. With c == 0 a seed of 0 stays 0,
    // in which case next() always returns `min`.
    static constexpr unsigned int a = 16807;
    static constexpr unsigned int c = 0;
    static constexpr unsigned int m = 1u << 31;

    unsigned int max_delay;
    unsigned int& seed;

    // Advances the shared seed and returns a value in [min, max] (assumes max >= min).
    _CCCL_DEVICE _CCCL_FORCEINLINE unsigned int next(unsigned int min, unsigned int max)
    {
      return (seed = (a * seed + c) % m) % (max + 1 - min) + min;
    }

    _CCCL_DEVICE _CCCL_FORCEINLINE void operator()()
    {
      always_delay_or_prevent_hoisting(next(0, max_delay));
      max_delay <<= 1;
    }
  };

  unsigned int seed;

  _CCCL_DEVICE _CCCL_FORCEINLINE exponential_backoff_jitter_constructor_t(unsigned int seed)
      : seed(seed)
  {
    always_delay();
  }

  _CCCL_DEVICE _CCCL_FORCEINLINE delay_t operator()()
  {
    return {InitialDelay, seed};
  }
};

/// Exponential backoff with windowed jitter: each wait sleeps a pseudo-random duration in
/// [max_delay, 2 * max_delay], then max_delay doubles. Shares this object's `seed` by reference.
template
struct exponential_backoff_jitter_window_constructor_t
{
  struct delay_t
  {
    // Same LCG as exponential_backoff_jitter_constructor_t (seed 0 is a fixed point).
    static constexpr unsigned int a = 16807;
    static constexpr unsigned int c = 0;
    static constexpr unsigned int m = 1u << 31;

    unsigned int max_delay;
    unsigned int& seed;

    // Advances the shared seed and returns a value in [min, max] (assumes max >= min).
    _CCCL_DEVICE _CCCL_FORCEINLINE unsigned int next(unsigned int min, unsigned int max)
    {
      return (seed = (a * seed + c) % m) % (max + 1 - min) + min;
    }

    _CCCL_DEVICE _CCCL_FORCEINLINE void operator()()
    {
      unsigned int next_max_delay = max_delay << 1;
      always_delay_or_prevent_hoisting(next(max_delay, next_max_delay));
      max_delay = next_max_delay;
    }
  };

  unsigned int seed;

  _CCCL_DEVICE _CCCL_FORCEINLINE exponential_backoff_jitter_window_constructor_t(unsigned int seed)
      : seed(seed)
  {
    always_delay();
  }

  _CCCL_DEVICE _CCCL_FORCEINLINE delay_t operator()()
  {
    return {InitialDelay, seed};
  }
};

/// "Back-on" (shrinking) delay with windowed jitter: this object's max_delay starts at InitialDelay
/// and is halved on every operator()() call before the delay_t is created; each delay_t wait sleeps a
/// pseudo-random duration in [max_delay / 2, max_delay] and then halves its own max_delay.
template
struct exponential_backon_jitter_window_constructor_t
{
  struct delay_t
  {
    // Same LCG as exponential_backoff_jitter_constructor_t (seed 0 is a fixed point).
    static constexpr unsigned int a = 16807;
    static constexpr unsigned int c = 0;
    static constexpr unsigned int m = 1u << 31;

    unsigned int max_delay;
    unsigned int& seed;

    // Advances the shared seed and returns a value in [min, max] (assumes max >= min).
    _CCCL_DEVICE _CCCL_FORCEINLINE unsigned int next(unsigned int min, unsigned int max)
    {
      return (seed = (a * seed + c) % m) % (max + 1 - min) + min;
    }

    _CCCL_DEVICE _CCCL_FORCEINLINE void operator()()
    {
      int prev_delay = max_delay >> 1;
      always_delay_or_prevent_hoisting(next(prev_delay, max_delay));
      max_delay = prev_delay;
    }
  };

  unsigned int seed;
  unsigned int max_delay = InitialDelay;

  _CCCL_DEVICE _CCCL_FORCEINLINE exponential_backon_jitter_window_constructor_t(unsigned int seed)
      : seed(seed)
  {
    always_delay();
  }

  _CCCL_DEVICE _CCCL_FORCEINLINE delay_t operator()()
  {
    max_delay >>= 1;
    return {max_delay, seed};
  }
};

/// "Back-on" (shrinking) delay with jitter: like the windowed variant, but each wait sleeps a
/// pseudo-random duration in [0, max_delay] before max_delay is halved.
template
struct exponential_backon_jitter_constructor_t
{
  struct delay_t
  {
    // Same LCG as exponential_backoff_jitter_constructor_t (seed 0 is a fixed point).
    static constexpr unsigned int a = 16807;
    static constexpr unsigned int c = 0;
    static constexpr unsigned int m = 1u << 31;

    unsigned int max_delay;
    unsigned int& seed;

    // Advances the shared seed and returns a value in [min, max] (assumes max >= min).
    _CCCL_DEVICE _CCCL_FORCEINLINE unsigned int next(unsigned int min, unsigned int max)
    {
      return (seed = (a * seed + c) % m) % (max + 1 - min) + min;
    }

    _CCCL_DEVICE _CCCL_FORCEINLINE void operator()()
    {
      always_delay_or_prevent_hoisting(next(0, max_delay));
      max_delay >>= 1;
    }
  };

  unsigned int seed;
  unsigned int max_delay = InitialDelay;

  _CCCL_DEVICE _CCCL_FORCEINLINE exponential_backon_jitter_constructor_t(unsigned int seed)
      : seed(seed)
  {
    always_delay();
  }

  _CCCL_DEVICE _CCCL_FORCEINLINE delay_t operator()()
  {
    max_delay >>= 1;
    return {max_delay, seed};
  }
};

/// "Back-on" (shrinking) delay without jitter: each delay_t sleeps `delay` ns and then halves it;
/// the starting value (this object's max_delay) is halved on every operator()() call.
template
struct exponential_backon_constructor_t
{
  struct delay_t
  {
    unsigned int delay;

    _CCCL_DEVICE _CCCL_FORCEINLINE void operator()()
    {
      always_delay_or_prevent_hoisting(delay);
      delay >>= 1;
    }
  };

  unsigned int max_delay = InitialDelay;

  _CCCL_DEVICE _CCCL_FORCEINLINE exponential_backon_constructor_t(unsigned int /* seed */)
  {
    always_delay();
  }

  _CCCL_DEVICE _CCCL_FORCEINLINE delay_t operator()()
  {
    max_delay >>= 1;
    return {max_delay};
  }
};

// Default policies. The numeric arguments (450; 350, 450) are presumably the per-poll delay and the
// initial constructor delay in ns -- confirm against the template parameter lists upstream.
using default_no_delay_constructor_t = no_delay_constructor_t<450>;
using default_no_delay_t             = default_no_delay_constructor_t::delay_t;

// Fixed delay when the (stripped) `::PRIMITIVE` trait holds, otherwise no per-poll delay.
template
using default_delay_constructor_t =
  ::cuda::std::_If::PRIMITIVE, fixed_delay_constructor_t<350, 450>, default_no_delay_constructor_t>;

template
using default_delay_t = typename default_delay_constructor_t::delay_t;

// Reduce-by-key uses its own policy when the value trait is PRIMITIVE and key + value fit in fewer
// than 16 bytes; otherwise it falls back to default_delay_constructor_t.
template
using default_reduce_by_key_delay_constructor_t =
  ::cuda::std::_If<(Traits::PRIMITIVE) && (sizeof(ValueT) + sizeof(KeyT) < 16),
                   reduce_by_key_delay_constructor_t<350, 450>,
                   default_delay_constructor_t>>;

/**
 * @brief Thin wrapper around a reference to a tile-state object (`ScanTileStateT`) that forwards
 * SetInclusive, SetPartial, WaitForValid and LoadValid to it. The explicit template arguments after
 * `.template` are missing from this copy; presumably they pass a fixed memory order `Order`.
 *
 * @tparam ScanTileStateT The wrapped tile-state type (supplies StatusValueT and StatusWord)
 * @tparam Order (presumably) The memory order forwarded to the wrapped tile state
 */
template
struct tile_state_with_memory_order
{
  ScanTileStateT& tile_state;
  using T          = typename ScanTileStateT::StatusValueT;
  using StatusWord = typename ScanTileStateT::StatusWord;

  /**
   * Update the specified tile's inclusive value and corresponding status
   */
  _CCCL_DEVICE _CCCL_FORCEINLINE void SetInclusive(int tile_idx, T tile_inclusive)
  {
    tile_state.template SetInclusive(tile_idx, tile_inclusive);
  }

  /**
   * Update the specified tile's partial value and corresponding status
   */
  _CCCL_DEVICE _CCCL_FORCEINLINE void SetPartial(int tile_idx, T tile_partial)
  {
    tile_state.template SetPartial(tile_idx, tile_partial);
  }

  /**
   * Wait for the corresponding tile to become non-invalid
   */
  template
  _CCCL_DEVICE _CCCL_FORCEINLINE void WaitForValid(int tile_idx, StatusWord& status, T& value, DelayT delay = {})
  {
    tile_state.template WaitForValid(tile_idx, status, value, delay);
  }

  /// Forwards to the wrapped tile state's LoadValid (same preconditions apply).
  _CCCL_DEVICE _CCCL_FORCEINLINE T LoadValid(int tile_idx)
  {
    return tile_state.template LoadValid(tile_idx);
  }
};
} // namespace detail

/**
 * Tile status interface.
 *
 * Primary template (declared only). The second template parameter defaults to a `::PRIMITIVE`
 * type trait (its full spelling is missing from this copy) and selects between a packed
 * single-word descriptor and separate status/partial/inclusive arrays.
 */
template ::PRIMITIVE>
struct ScanTileState;

/**
 * Tile status interface specialized for scan status and value types
 * that can be combined into one machine word that can be
 * read/written coherently in a single access.
*/
template
struct ScanTileState
{
  /// Value type of the tiles' partial/inclusive aggregates
  using StatusValueT = T;

  // Status word type: unsigned long long when sizeof(T) == 8; the nested selection used for
  // smaller T is missing from this copy (presumably narrower unsigned integer types).
  using StatusWord = ::cuda::std::_If<
    sizeof(T) == 8,
    unsigned long long,
    ::cuda::std::_If>>;

  // Transaction word type: the unsigned integer through which a whole TileDescriptor is loaded
  // and stored in one access (its selection condition is missing from this copy).
  using TxnWord = ::cuda::std::_If>;

  // Device word type: status and value packed together so that a single TxnWord access
  // reads or writes both consistently.
  struct TileDescriptor
  {
    StatusWord status;
    T value;
  };

  // Constants
  enum
  {
    // Number of descriptor slots preceding tile 0. InitializeStatus marks them SCAN_TILE_OOB so
    // that reads at tile indices in [-TILE_STATUS_PADDING, -1] observe a non-invalid status.
    TILE_STATUS_PADDING = CUB_PTX_WARP_THREADS,
  };

  // Device storage: TILE_STATUS_PADDING padding words followed by one descriptor word per tile
  TxnWord* d_tile_descriptors;

  /// Constructor
  _CCCL_HOST_DEVICE _CCCL_FORCEINLINE ScanTileState()
      : d_tile_descriptors(nullptr)
  {}

  /**
   * @brief Initializer. Binds this object to caller-provided device storage.
   *
   * @param[in] num_tiles
   *   Number of tiles (unused by this specialization)
   *
   * @param[in] d_temp_storage
   *   Device-accessible allocation of temporary storage of at least the size reported by
   *   AllocationSize(). The pointer is stored as-is; passing nullptr does NOT report a size
   *   (temp_storage_bytes is taken by value) -- use AllocationSize() for that.
   *
   * @param[in] temp_storage_bytes
   *   Size in bytes of \p d_temp_storage allocation (unused by this specialization)
   *
   * @return Always cudaSuccess
   */
  _CCCL_HOST_DEVICE _CCCL_FORCEINLINE cudaError_t
  Init(int /*num_tiles*/, void* d_temp_storage, size_t /*temp_storage_bytes*/)
  {
    d_tile_descriptors = reinterpret_cast(d_temp_storage);
    return cudaSuccess;
  }

  /**
   * @brief Compute device memory needed for tile status
   *
   * @param[in] num_tiles
   *   Number of tiles
   *
   * @param[out] temp_storage_bytes
   *   Size in bytes of \p d_temp_storage allocation: one TxnWord per tile plus
   *   TILE_STATUS_PADDING padding words
   *
   * @return Always cudaSuccess
   */
  _CCCL_HOST_DEVICE _CCCL_FORCEINLINE static cudaError_t AllocationSize(int num_tiles, size_t& temp_storage_bytes)
  {
    // bytes needed for tile status descriptors
    temp_storage_bytes = (num_tiles + TILE_STATUS_PADDING) * sizeof(TxnWord);
    return cudaSuccess;
  }

  /**
   * Initialize (from device).
   *
   * The thread with global index `blockIdx.x * blockDim.x + threadIdx.x` marks that tile
   * SCAN_TILE_INVALID when the index is < @p num_tiles; threads of block 0 with
   * threadIdx.x < TILE_STATUS_PADDING mark the padding slots SCAN_TILE_OOB. The value bits of every
   * written descriptor come from the value-initialized `TxnWord()`. The launch therefore needs at
   * least @p num_tiles threads in total and, for every padding slot to be written,
   * blockDim.x >= TILE_STATUS_PADDING.
   */
  _CCCL_DEVICE _CCCL_FORCEINLINE void InitializeStatus(int num_tiles)
  {
    int tile_idx = (blockIdx.x * blockDim.x) + threadIdx.x;

    TxnWord val                = TxnWord();
    TileDescriptor* descriptor = reinterpret_cast(&val);

    if (tile_idx < num_tiles)
    {
      // Not-yet-set
      descriptor->status                                 = StatusWord(SCAN_TILE_INVALID);
      d_tile_descriptors[TILE_STATUS_PADDING + tile_idx] = val;
    }

    if ((blockIdx.x == 0) && (threadIdx.x < TILE_STATUS_PADDING))
    {
      // Padding
      descriptor->status              = StatusWord(SCAN_TILE_OOB);
      d_tile_descriptors[threadIdx.x] = val;
    }
  }

private:
  // StoreStatus / LoadStatus / ThreadfenceForLoadAcqPreVolta are overloaded on the memory order via
  // enable_if (the `Order` template parameter declarations are missing from this copy).

  // Relaxed store of a whole descriptor word
  template
  _CCCL_DEVICE _CCCL_FORCEINLINE typename ::cuda::std::enable_if<(Order == MemoryOrder::relaxed), void>::type
  StoreStatus(TxnWord* ptr, TxnWord alias)
  {
    detail::store_relaxed(ptr, alias);
  }

  // Release store of a whole descriptor word
  template
  _CCCL_DEVICE _CCCL_FORCEINLINE typename ::cuda::std::enable_if<(Order == MemoryOrder::acquire_release), void>::type
  StoreStatus(TxnWord* ptr, TxnWord alias)
  {
    detail::store_release(ptr, alias);
  }

  // Relaxed load of a whole descriptor word
  template
  _CCCL_DEVICE _CCCL_FORCEINLINE typename ::cuda::std::enable_if<(Order == MemoryOrder::relaxed), TxnWord>::type
  LoadStatus(TxnWord* ptr)
  {
    return detail::load_relaxed(ptr);
  }

  // Acquire load on SM70+; relaxed load on older targets, where the acquire ordering is instead
  // provided by ThreadfenceForLoadAcqPreVolta() once the polling loop has finished.
  template
  _CCCL_DEVICE _CCCL_FORCEINLINE typename ::cuda::std::enable_if<(Order == MemoryOrder::acquire_release), TxnWord>::type
  LoadStatus(TxnWord* ptr)
  {
    // For pre-volta we hoist the memory barrier to outside the loop, i.e., after reading a valid state
    NV_IF_TARGET(NV_PROVIDES_SM_70, (return detail::load_acquire(ptr);), (return detail::load_relaxed(ptr);));
  }

  // Relaxed order: nothing to do
  template
  _CCCL_DEVICE _CCCL_FORCEINLINE typename ::cuda::std::enable_if<(Order == MemoryOrder::relaxed), void>::type
  ThreadfenceForLoadAcqPreVolta()
  {}

  // Acquire/release order: __threadfence() on pre-SM70 targets only (SM70+ already used load_acquire)
  template
  _CCCL_DEVICE _CCCL_FORCEINLINE typename ::cuda::std::enable_if<(Order == MemoryOrder::acquire_release), void>::type
  ThreadfenceForLoadAcqPreVolta()
  {
    NV_IF_TARGET(NV_PROVIDES_SM_70, (), (__threadfence();));
  }

public:
  /// Update the specified tile's inclusive value and set its status to SCAN_TILE_INCLUSIVE; both are
  /// published together by a single StoreStatus of the packed descriptor word.
  template
  _CCCL_DEVICE _CCCL_FORCEINLINE void SetInclusive(int tile_idx, T tile_inclusive)
  {
    TileDescriptor tile_descriptor;
    tile_descriptor.status = SCAN_TILE_INCLUSIVE;
    tile_descriptor.value  = tile_inclusive;

    TxnWord alias;
    *reinterpret_cast(&alias) = tile_descriptor;

    StoreStatus(d_tile_descriptors + TILE_STATUS_PADDING + tile_idx, alias);
  }

  /// Update the specified tile's partial value and set its status to SCAN_TILE_PARTIAL; both are
  /// published together by a single StoreStatus of the packed descriptor word.
  template
  _CCCL_DEVICE _CCCL_FORCEINLINE void SetPartial(int tile_idx, T tile_partial)
  {
    TileDescriptor tile_descriptor;
    tile_descriptor.status = SCAN_TILE_PARTIAL;
    tile_descriptor.value  = tile_partial;

    TxnWord alias;
    *reinterpret_cast(&alias) = tile_descriptor;

    StoreStatus(d_tile_descriptors + TILE_STATUS_PADDING + tile_idx, alias);
  }

  /**
   * Wait for the corresponding tile to become non-invalid.
   *
   * Re-reads the packed descriptor of @p tile_idx (invoking @p delay_or_prevent_hoisting before each
   * re-read) while any lane of the warp still sees SCAN_TILE_INVALID. The vote uses the full
   * 0xffffffff mask, so presumably all 32 lanes must call this together. Lanes whose tile is already
   * valid keep re-reading until the whole warp is done; each lane returns its last-read descriptor.
   *
   * @param[in] tile_idx Tile to inspect; indices in [-TILE_STATUS_PADDING, -1] address the OOB padding
   * @param[out] status Observed tile status
   * @param[out] value Value stored together with @p status
   * @param[in] delay_or_prevent_hoisting Functor invoked before every re-read
   */
  template , MemoryOrder Order = MemoryOrder::relaxed>
  _CCCL_DEVICE _CCCL_FORCEINLINE void
  WaitForValid(int tile_idx, StatusWord& status, T& value, DelayT delay_or_prevent_hoisting = {})
  {
    TileDescriptor tile_descriptor;

    {
      TxnWord alias   = LoadStatus(d_tile_descriptors + TILE_STATUS_PADDING + tile_idx);
      tile_descriptor = reinterpret_cast(alias);
    }

    while (WARP_ANY((tile_descriptor.status == SCAN_TILE_INVALID), 0xffffffff))
    {
      delay_or_prevent_hoisting();
      TxnWord alias   = LoadStatus(d_tile_descriptors + TILE_STATUS_PADDING + tile_idx);
      tile_descriptor = reinterpret_cast(alias);
    }

    // For pre-Volta and load acquire we emit relaxed loads in LoadStatus and hoist the threadfence here
    ThreadfenceForLoadAcqPreVolta();

    status = tile_descriptor.status;
    value  = tile_descriptor.value;
  }

  /**
   * Loads and returns the tile's value. The returned value is undefined if either (a) the tile's status is invalid or
   * (b) there is no memory fence between reading a non-invalid status and the call to LoadValid.
   *
   * Performs a plain (non-atomic) read of the descriptor word and discards the status bits.
   */
  _CCCL_DEVICE _CCCL_FORCEINLINE T LoadValid(int tile_idx)
  {
    TxnWord alias                  = d_tile_descriptors[TILE_STATUS_PADDING + tile_idx];
    TileDescriptor tile_descriptor = reinterpret_cast(alias);
    return tile_descriptor.value;
  }
};

/**
 * Tile status interface specialized for scan status and value types that
 * cannot be combined into one machine word. Status words, partial values and
 * inclusive values are kept in three separate arrays.
*/
template
struct ScanTileState
{
  /// Value type of the tiles' partial/inclusive aggregates
  using StatusValueT = T;

  // Status word type
  using StatusWord = unsigned int;

  // Constants
  enum
  {
    // Number of slots preceding tile 0 in each array; InitializeStatus marks the corresponding
    // status slots SCAN_TILE_OOB.
    TILE_STATUS_PADDING = CUB_PTX_WARP_THREADS,
  };

  // Device storage: three parallel arrays, each with TILE_STATUS_PADDING leading slots
  StatusWord* d_tile_status;
  T* d_tile_partial;
  T* d_tile_inclusive;

  /// Constructor
  _CCCL_HOST_DEVICE _CCCL_FORCEINLINE ScanTileState()
      : d_tile_status(nullptr)
      , d_tile_partial(nullptr)
      , d_tile_inclusive(nullptr)
  {}

  /**
   * @brief Initializer. Carves the status, partial and inclusive arrays out of a single
   * caller-provided blob via AliasTemporaries.
   *
   * @param[in] num_tiles
   *   Number of tiles
   *
   * @param[in] d_temp_storage
   *   Device-accessible allocation of temporary storage, sized as reported by AllocationSize().
   *   NOTE: temp_storage_bytes is taken by value, so passing nullptr here does not report the
   *   required size back to the caller; use AllocationSize() for that.
   *
   * @param[in] temp_storage_bytes
   *   Size in bytes of \p d_temp_storage allocation
   *
   * @return The error reported by AliasTemporaries (cudaSuccess on success)
   */
  _CCCL_HOST_DEVICE _CCCL_FORCEINLINE cudaError_t Init(int num_tiles, void* d_temp_storage, size_t temp_storage_bytes)
  {
    cudaError_t error = cudaSuccess;
    do
    {
      void* allocations[3] = {};
      size_t allocation_sizes[3];

      // bytes needed for tile status descriptors
      allocation_sizes[0] = (num_tiles + TILE_STATUS_PADDING) * sizeof(StatusWord);

      // bytes needed for partials
      allocation_sizes[1] = (num_tiles + TILE_STATUS_PADDING) * sizeof(Uninitialized);

      // bytes needed for inclusives
      allocation_sizes[2] = (num_tiles + TILE_STATUS_PADDING) * sizeof(Uninitialized);

      // Compute allocation pointers into the single storage blob
      error = CubDebug(AliasTemporaries(d_temp_storage, temp_storage_bytes, allocations, allocation_sizes));
      if (cudaSuccess != error)
      {
        break;
      }

      // Alias the offsets
      d_tile_status    = reinterpret_cast(allocations[0]);
      d_tile_partial   = reinterpret_cast(allocations[1]);
      d_tile_inclusive = reinterpret_cast(allocations[2]);
    } while (0);

    return error;
  }

  /**
   * @brief Compute device memory needed for tile status
   *
   * @param[in] num_tiles
   *   Number of tiles
   *
   * @param[out] temp_storage_bytes
   *   Size in bytes of \p d_temp_storage allocation (the three arrays combined, as laid out by
   *   AliasTemporaries)
   *
   * @return The error reported by AliasTemporaries
   */
  _CCCL_HOST_DEVICE _CCCL_FORCEINLINE static cudaError_t AllocationSize(int num_tiles, size_t& temp_storage_bytes)
  {
    // Specify storage allocation requirements
    size_t allocation_sizes[3];

    // bytes needed for tile status descriptors
    allocation_sizes[0] = (num_tiles + TILE_STATUS_PADDING) * sizeof(StatusWord);

    // bytes needed for partials
    allocation_sizes[1] = (num_tiles + TILE_STATUS_PADDING) * sizeof(Uninitialized);

    // bytes needed for inclusives
    allocation_sizes[2] = (num_tiles + TILE_STATUS_PADDING) * sizeof(Uninitialized);

    // Set the necessary size of the blob
    void* allocations[3] = {};
    return CubDebug(AliasTemporaries(nullptr, temp_storage_bytes, allocations, allocation_sizes));
  }

  /**
   * Initialize (from device).
   *
   * The thread with global index `blockIdx.x * blockDim.x + threadIdx.x` marks that tile
   * SCAN_TILE_INVALID when the index is < @p num_tiles; threads of block 0 with
   * threadIdx.x < TILE_STATUS_PADDING mark the padding status slots SCAN_TILE_OOB. Only status
   * words are written; the partial and inclusive arrays are left untouched.
   */
  _CCCL_DEVICE _CCCL_FORCEINLINE void InitializeStatus(int num_tiles)
  {
    int tile_idx = (blockIdx.x * blockDim.x) + threadIdx.x;
    if (tile_idx < num_tiles)
    {
      // Not-yet-set
      d_tile_status[TILE_STATUS_PADDING + tile_idx] = StatusWord(SCAN_TILE_INVALID);
    }

    if ((blockIdx.x == 0) && (threadIdx.x < TILE_STATUS_PADDING))
    {
      // Padding
      d_tile_status[threadIdx.x] = StatusWord(SCAN_TILE_OOB);
    }
  }

  /**
   * Update the specified tile's inclusive value and corresponding status.
   *
   * The value is written first (ThreadStore), then the status is published with store_release.
   * The template parameter (declaration not visible in this copy) is not referenced by the body:
   * the status store is always a release store.
   */
  template
  _CCCL_DEVICE _CCCL_FORCEINLINE void SetInclusive(int tile_idx, T tile_inclusive)
  {
    // Update tile inclusive value
    ThreadStore(d_tile_inclusive + TILE_STATUS_PADDING + tile_idx, tile_inclusive);
    detail::store_release(d_tile_status + TILE_STATUS_PADDING + tile_idx, StatusWord(SCAN_TILE_INCLUSIVE));
  }

  /**
   * Update the specified tile's partial value and corresponding status.
   *
   * The value is written first (ThreadStore), then the status is published with store_release
   * (the template parameter is likewise unused by the body).
   */
  template
  _CCCL_DEVICE _CCCL_FORCEINLINE void SetPartial(int tile_idx, T tile_partial)
  {
    // Update tile partial value
    ThreadStore(d_tile_partial + TILE_STATUS_PADDING + tile_idx, tile_partial);
    detail::store_release(d_tile_status + TILE_STATUS_PADDING + tile_idx, StatusWord(SCAN_TILE_PARTIAL));
  }

  /**
   * Wait for the corresponding tile to become non-invalid.
   *
   * Each iteration invokes @p delay, loads the status word (relaxed) and issues __threadfence() so
   * that the later value load is ordered after the status read. The loop repeats while any lane of
   * the warp (full 0xffffffff mask) still sees SCAN_TILE_INVALID. The value is then read from the
   * partial array when the status is SCAN_TILE_PARTIAL, and from the inclusive array otherwise
   * (including SCAN_TILE_OOB padding slots, whose values InitializeStatus does not initialize).
   */
  template
  _CCCL_DEVICE _CCCL_FORCEINLINE void WaitForValid(int tile_idx, StatusWord& status, T& value, DelayT delay = {})
  {
    do
    {
      delay();
      status = detail::load_relaxed(d_tile_status + TILE_STATUS_PADDING + tile_idx);
      __threadfence();
    } while (WARP_ANY((status == SCAN_TILE_INVALID), 0xffffffff));

    if (status == StatusWord(SCAN_TILE_PARTIAL))
    {
      value = ThreadLoad(d_tile_partial + TILE_STATUS_PADDING + tile_idx);
    }
    else
    {
      value = ThreadLoad(d_tile_inclusive + TILE_STATUS_PADDING + tile_idx);
    }
  }

  /**
   * Loads and returns the tile's value. The returned value is undefined if either (a) the tile's status is invalid or
   * (b) there is no memory fence between reading a non-invalid status and the call to LoadValid.
   *
   * Note: this performs a plain read of the inclusive array only, so it is meaningful only for tiles
   * whose status is SCAN_TILE_INCLUSIVE (SetPartial never writes the inclusive array).
   */
  _CCCL_DEVICE _CCCL_FORCEINLINE T LoadValid(int tile_idx)
  {
    return d_tile_inclusive[TILE_STATUS_PADDING + tile_idx];
  }
};

/******************************************************************************
 * ReduceByKey tile status interface types for block-cooperative scans
 ******************************************************************************/

/**
 * Tile status interface for reduction by key.
 *
 * Primary template (declared only). Its last template parameter defaults to
 * `(...::PRIMITIVE) && (sizeof(ValueT) + sizeof(KeyT) < 16)` (the trait's full spelling is
 * missing from this copy); the specializations below cover the cases where the key/value pair
 * cannot and can be packed into one machine word.
 */
template ::PRIMITIVE) && (sizeof(ValueT) + sizeof(KeyT) < 16)>
struct ReduceByKeyScanTileState;

/**
 * Tile status interface for reduction by key, specialized for scan status and value types that
 * cannot be combined into one machine word. Reuses the generic ScanTileState (its template
 * arguments are missing from this copy; presumably a KeyValuePair of KeyT and ValueT).
 */
template
struct ReduceByKeyScanTileState : ScanTileState>
{
  using SuperClass = ScanTileState>;

  /// Constructor
  _CCCL_HOST_DEVICE _CCCL_FORCEINLINE ReduceByKeyScanTileState()
      : SuperClass()
  {}
};

/**
 * Tile status interface for reduction by key, specialized for scan status and value types that
 * can be combined into one machine word that can be read/written coherently in a single access.
*/
template
struct ReduceByKeyScanTileState
{
  using KeyValuePairT = KeyValuePair;

  // Constants
  enum
  {
    // Combined size in bytes of one key and one value
    PAIR_SIZE = static_cast(sizeof(ValueT) + sizeof(KeyT)),
    // Size in bytes of the transaction word; a power of two (the Log2 argument is missing from this copy)
    TXN_WORD_SIZE = 1 << Log2::VALUE,
    // Bytes of the transaction word left over for the status field
    STATUS_WORD_SIZE = TXN_WORD_SIZE - PAIR_SIZE,
    // Number of descriptor slots preceding tile 0 (marked SCAN_TILE_OOB by InitializeStatus)
    TILE_STATUS_PADDING = CUB_PTX_WARP_THREADS,
  };

  // Status word type: unsigned long long when STATUS_WORD_SIZE == 8; the nested selection for
  // smaller sizes is missing from this copy.
  using StatusWord = ::cuda::std::_If<
    STATUS_WORD_SIZE == 8,
    unsigned long long,
    ::cuda::std:: _If>>;

  // Transaction word type: the unsigned integer through which a whole TileDescriptor is loaded and
  // stored in one access (its selection condition is missing from this copy).
  using TxnWord = ::cuda::std:: _If>;

  // Device word type (for when sizeof(ValueT) == sizeof(KeyT))
  struct TileDescriptorBigStatus
  {
    KeyT key;
    ValueT value;
    StatusWord status;
  };

  // Device word type (for when sizeof(ValueT) != sizeof(KeyT))
  struct TileDescriptorLittleStatus
  {
    ValueT value;
    StatusWord status;
    KeyT key;
  };

  // Device word type: one of the two layouts above (the selection condition is missing from this copy)
  using TileDescriptor = ::cuda::std::_If;

  // Device storage: TILE_STATUS_PADDING padding words followed by one descriptor word per tile
  TxnWord* d_tile_descriptors;

  /// Constructor
  _CCCL_HOST_DEVICE _CCCL_FORCEINLINE ReduceByKeyScanTileState()
      : d_tile_descriptors(nullptr)
  {}

  /**
   * @brief Initializer. Binds this object to caller-provided device storage.
   *
   * @param[in] num_tiles
   *   Number of tiles (unused by this specialization)
   *
   * @param[in] d_temp_storage
   *   Device-accessible allocation of temporary storage of at least the size reported by
   *   AllocationSize(). The pointer is stored as-is; passing nullptr does NOT report a size
   *   (temp_storage_bytes is taken by value) -- use AllocationSize() for that.
   *
   * @param[in] temp_storage_bytes
   *   Size in bytes of \p d_temp_storage allocation (unused by this specialization)
   *
   * @return Always cudaSuccess
   */
  _CCCL_HOST_DEVICE _CCCL_FORCEINLINE cudaError_t
  Init(int /*num_tiles*/, void* d_temp_storage, size_t /*temp_storage_bytes*/)
  {
    d_tile_descriptors = reinterpret_cast(d_temp_storage);
    return cudaSuccess;
  }

  /**
   * @brief Compute device memory needed for tile status
   *
   * @param[in] num_tiles
   *   Number of tiles
   *
   * @param[out] temp_storage_bytes
   *   Size in bytes of \p d_temp_storage allocation: one TxnWord per tile plus
   *   TILE_STATUS_PADDING padding words
   *
   * @return Always cudaSuccess
   */
  _CCCL_HOST_DEVICE _CCCL_FORCEINLINE static cudaError_t AllocationSize(int num_tiles, size_t& temp_storage_bytes)
  {
    // bytes needed for tile status descriptors
    temp_storage_bytes = (num_tiles + TILE_STATUS_PADDING) * sizeof(TxnWord);
    return cudaSuccess;
  }

  /**
   * Initialize (from device).
   *
   * The thread with global index `blockIdx.x * blockDim.x + threadIdx.x` marks that tile
   * SCAN_TILE_INVALID when the index is < @p num_tiles; threads of block 0 with
   * threadIdx.x < TILE_STATUS_PADDING mark the padding slots SCAN_TILE_OOB. Key/value bits of every
   * written descriptor come from the value-initialized `TxnWord()`.
   */
  _CCCL_DEVICE _CCCL_FORCEINLINE void InitializeStatus(int num_tiles)
  {
    int tile_idx = (blockIdx.x * blockDim.x) + threadIdx.x;

    TxnWord val                = TxnWord();
    TileDescriptor* descriptor = reinterpret_cast(&val);

    if (tile_idx < num_tiles)
    {
      // Not-yet-set
      descriptor->status                                 = StatusWord(SCAN_TILE_INVALID);
      d_tile_descriptors[TILE_STATUS_PADDING + tile_idx] = val;
    }

    if ((blockIdx.x == 0) && (threadIdx.x < TILE_STATUS_PADDING))
    {
      // Padding
      descriptor->status              = StatusWord(SCAN_TILE_OOB);
      d_tile_descriptors[threadIdx.x] = val;
    }
  }

  /**
   * Update the specified tile's inclusive key/value and set its status to SCAN_TILE_INCLUSIVE.
   * Status, key and value are packed and published together with one relaxed store of the
   * descriptor word.
   */
  _CCCL_DEVICE _CCCL_FORCEINLINE void SetInclusive(int tile_idx, KeyValuePairT tile_inclusive)
  {
    TileDescriptor tile_descriptor;
    tile_descriptor.status = SCAN_TILE_INCLUSIVE;
    tile_descriptor.value  = tile_inclusive.value;
    tile_descriptor.key    = tile_inclusive.key;

    TxnWord alias;
    *reinterpret_cast(&alias) = tile_descriptor;

    detail::store_relaxed(d_tile_descriptors + TILE_STATUS_PADDING + tile_idx, alias);
  }

  /// Update the specified tile's partial key/value and set its status to SCAN_TILE_PARTIAL
  /// (one relaxed store of the packed descriptor word).
  _CCCL_DEVICE _CCCL_FORCEINLINE void SetPartial(int tile_idx, KeyValuePairT tile_partial)
  {
    TileDescriptor tile_descriptor;
    tile_descriptor.status = SCAN_TILE_PARTIAL;
    tile_descriptor.value  = tile_partial.value;
    tile_descriptor.key    = tile_partial.key;

    TxnWord alias;
    *reinterpret_cast(&alias) = tile_descriptor;

    detail::store_relaxed(d_tile_descriptors + TILE_STATUS_PADDING + tile_idx, alias);
  }

  /**
   * Wait for the corresponding tile to become non-invalid.
   *
   * Invokes @p delay_or_prevent_hoisting, then re-reads the packed descriptor (relaxed load) while
   * any lane of the warp (full 0xffffffff mask) still sees SCAN_TILE_INVALID. Because status, key and
   * value arrive in the same word, the returned key/value always belong to the returned status.
   *
   * @param[in] tile_idx Tile to inspect; indices in [-TILE_STATUS_PADDING, -1] address the OOB padding
   * @param[out] status Observed tile status
   * @param[out] value Key/value pair stored together with @p status
   * @param[in] delay_or_prevent_hoisting Functor invoked before every read
   */
  template ::delay_t>
  _CCCL_DEVICE _CCCL_FORCEINLINE void
  WaitForValid(int tile_idx, StatusWord& status, KeyValuePairT& value, DelayT delay_or_prevent_hoisting = {})
  {
    // Commented-out alternative, kept for reference: each thread spins on its own descriptor
    // (no warp vote), using __threadfence_block() to prevent hoisting the load out of the loop.
    //
    //        TxnWord alias = ThreadLoad(d_tile_descriptors + TILE_STATUS_PADDING +
    //        tile_idx); TileDescriptor tile_descriptor = reinterpret_cast(alias);
    //
    //        while (tile_descriptor.status == SCAN_TILE_INVALID)
    //        {
    //            __threadfence_block(); // prevent hoisting loads from loop
    //
    //            alias = ThreadLoad(d_tile_descriptors + TILE_STATUS_PADDING + tile_idx);
    //            tile_descriptor = reinterpret_cast(alias);
    //        }
    //
    //        status = tile_descriptor.status;
    //        value.value = tile_descriptor.value;
    //        value.key = tile_descriptor.key;

    TileDescriptor tile_descriptor;

    do
    {
      delay_or_prevent_hoisting();
      TxnWord alias   = detail::load_relaxed(d_tile_descriptors + TILE_STATUS_PADDING + tile_idx);
      tile_descriptor = reinterpret_cast(alias);

    } while (WARP_ANY((tile_descriptor.status == SCAN_TILE_INVALID), 0xffffffff));

    status      = tile_descriptor.status;
    value.value = tile_descriptor.value;
    value.key   = tile_descriptor.key;
  }
};

/******************************************************************************
 * Prefix call-back operator for coupling local block scan within a
 * block-cooperative scan
 ******************************************************************************/

/**
 * Stateful block-scan prefix functor. Provides the running prefix for
 * the current tile by using the call-back warp to wait on
 * aggregates/prefixes from predecessor tiles to become available.
 *
 * @tparam DelayConstructorT
 *   Implementation detail, do not specify directly, requirements on the
 *   content of this type are subject to breaking change.
*/
template >
struct TilePrefixCallbackOp
{
  // Parameterized warp reduce
  using WarpReduceT = WarpReduce;

  // Temporary storage type
  struct _TempStorage
  {
    typename WarpReduceT::TempStorage warp_reduce;
    T exclusive_prefix; // written by thread 0 at the end of operator()
    T inclusive_prefix; // written by thread 0 at the end of operator()
    T block_aggregate; // written by thread 0 at the start of operator()
  };

  // Alias wrapper allowing temporary storage to be unioned
  struct TempStorage : Uninitialized<_TempStorage>
  {};

  // Type of status word
  using StatusWord = typename ScanTileStateT::StatusWord;

  // Fields
  _TempStorage& temp_storage; ///< Reference to the temporary storage (warp-reduce storage and published prefixes)
  ScanTileStateT& tile_status; ///< Interface to tile status
  ScanOpT scan_op; ///< Binary scan operator
  int tile_idx; ///< The current tile index
  T exclusive_prefix; ///< Exclusive prefix for the tile
  T inclusive_prefix; ///< Inclusive prefix for the tile (assigned only in thread 0)

  // Constructs prefix functor for a given tile index.
  // Precondition: thread blocks processing all of the predecessor tiles were scheduled
  // (operator() waits on the predecessors' statuses and cannot make progress otherwise).
  _CCCL_DEVICE _CCCL_FORCEINLINE
  TilePrefixCallbackOp(ScanTileStateT& tile_status, TempStorage& temp_storage, ScanOpT scan_op, int tile_idx)
      : temp_storage(temp_storage.Alias())
      , tile_status(tile_status)
      , scan_op(scan_op)
      , tile_idx(tile_idx)
  {}

  // Computes the tile index and constructs prefix functor with it.
  // Precondition: thread block per tile assignment (the tile index is blockIdx.x).
  _CCCL_DEVICE _CCCL_FORCEINLINE
  TilePrefixCallbackOp(ScanTileStateT& tile_status, TempStorage& temp_storage, ScanOpT scan_op)
      : TilePrefixCallbackOp(tile_status, temp_storage, scan_op, blockIdx.x)
  {}

  /**
   * @brief Block until all predecessors within the warp-wide window have non-invalid status
   *
   * Each lane waits on one predecessor tile, then the warp performs a tail-segmented reduction
   * flagged by SCAN_TILE_INCLUSIVE predecessors, presumably so that the result seen by thread 0
   * only combines values up to the nearest inclusive predecessor (confirm against
   * WarpReduce::TailSegmentedReduce).
   *
   * @param predecessor_idx
   *   Preceding tile index to inspect (one per lane)
   *
   * @param[out] predecessor_status
   *   Preceding tile status
   *
   * @param[out] window_aggregate
   *   Relevant partial reduction from this window of preceding tiles
   *
   * @param delay
   *   Delay functor forwarded to the tile state's WaitForValid
   */
  template >
  _CCCL_DEVICE _CCCL_FORCEINLINE void
  ProcessWindow(int predecessor_idx, StatusWord& predecessor_status, T& window_aggregate, DelayT delay = {})
  {
    T value;
    tile_status.WaitForValid(predecessor_idx, predecessor_status, value, delay);

    // Perform a segmented reduction to get the prefix for the current window.
    // Use the swizzled scan operator because we are now scanning *down* towards thread0.

    int tail_flag = (predecessor_status == StatusWord(SCAN_TILE_INCLUSIVE));
    window_aggregate =
      WarpReduceT(temp_storage.warp_reduce).TailSegmentedReduce(value, tail_flag, SwizzleScanOp(scan_op));
  }

  /**
   * BlockScan prefix callback functor (called by the first warp).
   *
   * 1. Thread 0 records @p block_aggregate in temp storage and publishes it as this tile's partial.
   * 2. Lane i inspects predecessor tile `tile_idx - i - 1`; the warp reduces that window, then keeps
   *    sliding the window back by CUB_PTX_WARP_THREADS tiles while no lane has found a
   *    SCAN_TILE_INCLUSIVE predecessor, combining each new window aggregate in front of the running
   *    exclusive prefix.
   * 3. Thread 0 computes the inclusive prefix, publishes it with SetInclusive, and stores both
   *    prefixes in temp storage.
   *
   * NOTE(review): threadIdx.x is used as the lane index, so this assumes the calling warp is
   * threads [0, CUB_PTX_WARP_THREADS). The reduction is oriented toward thread 0, so presumably only
   * thread 0's return value is meaningful -- confirm against WarpReduce.
   *
   * @param block_aggregate This tile's aggregate
   * @return The exclusive prefix for this tile
   */
  _CCCL_DEVICE _CCCL_FORCEINLINE T operator()(T block_aggregate)
  {
    // Update our status with our tile-aggregate
    if (threadIdx.x == 0)
    {
      detail::uninitialized_copy_single(&temp_storage.block_aggregate, block_aggregate);

      tile_status.SetPartial(tile_idx, block_aggregate);
    }

    // Lanes past tile 0 get negative indices, presumably landing in the tile state's OOB padding.
    int predecessor_idx = tile_idx - threadIdx.x - 1;
    StatusWord predecessor_status;
    T window_aggregate;

    // Wait for the warp-wide window of predecessor tiles to become valid
    DelayConstructorT construct_delay(tile_idx);
    ProcessWindow(predecessor_idx, predecessor_status, window_aggregate, construct_delay());

    // The exclusive tile prefix starts out as the current window aggregate
    exclusive_prefix = window_aggregate;

    // Keep sliding the window back until we come across a tile whose inclusive prefix is known
    while (WARP_ALL((predecessor_status != StatusWord(SCAN_TILE_INCLUSIVE)), 0xffffffff))
    {
      predecessor_idx -= CUB_PTX_WARP_THREADS;

      // Update exclusive tile prefix with the window prefix
      ProcessWindow(predecessor_idx, predecessor_status, window_aggregate, construct_delay());
      exclusive_prefix = scan_op(window_aggregate, exclusive_prefix);
    }

    // Compute the inclusive tile prefix and update the status for this tile
    if (threadIdx.x == 0)
    {
      inclusive_prefix = scan_op(exclusive_prefix, block_aggregate);
      tile_status.SetInclusive(tile_idx, inclusive_prefix);

      detail::uninitialized_copy_single(&temp_storage.exclusive_prefix, exclusive_prefix);

      detail::uninitialized_copy_single(&temp_storage.inclusive_prefix, inclusive_prefix);
    }

    // Return exclusive_prefix
    return exclusive_prefix;
  }

  // The getters below read values written by thread 0 in operator(). Nothing here synchronizes, so
  // presumably the caller must place a block-wide barrier between operator() and these reads.

  // Get the exclusive prefix stored in temporary storage
  _CCCL_DEVICE _CCCL_FORCEINLINE T GetExclusivePrefix()
  {
    return temp_storage.exclusive_prefix;
  }

  // Get the inclusive prefix stored in temporary storage
  _CCCL_DEVICE _CCCL_FORCEINLINE T GetInclusivePrefix()
  {
    return temp_storage.inclusive_prefix;
  }

  // Get the block aggregate stored in temporary storage
  _CCCL_DEVICE _CCCL_FORCEINLINE T GetBlockAggregate()
  {
    return temp_storage.block_aggregate;
  }

  // Get the tile index this functor was constructed for
  _CCCL_DEVICE _CCCL_FORCEINLINE int GetTileIdx() const
  {
    return tile_idx;
  }
};

CUB_NAMESPACE_END