/****************************************************************************** * Copyright (c) 2011, Duane Merrill. All rights reserved. * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright * notice, this list of conditions and the following disclaimer. * * Redistributions in binary form must reproduce the above copyright * notice, this list of conditions and the following disclaimer in the * documentation and/or other materials provided with the distribution. * * Neither the name of the NVIDIA CORPORATION nor the * names of its contributors may be used to endorse or promote products * derived from this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * ******************************************************************************/ /** * \file * cub::AgentSelectIf implements a stateful abstraction of CUDA thread blocks for participating in device-wide select. 
*/ #pragma once #include #if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC) # pragma GCC system_header #elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG) # pragma clang system_header #elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC) # pragma system_header #endif // no system header #include #include #include #include #include #include #include #include #include #include #include CUB_NAMESPACE_BEGIN /****************************************************************************** * Tuning policy types ******************************************************************************/ /** * Parameterizable tuning policy type for AgentSelectIf * * @tparam _BLOCK_THREADS * Threads per thread block * * @tparam _ITEMS_PER_THREAD * Items per thread (per tile of input) * * @tparam _LOAD_ALGORITHM * The BlockLoad algorithm to use * * @tparam _LOAD_MODIFIER * Cache load modifier for reading input elements * * @tparam _SCAN_ALGORITHM * The BlockScan algorithm to use * * @tparam DelayConstructorT * Implementation detail, do not specify directly, requirements on the * content of this type are subject to breaking change. 
*/ template > struct AgentSelectIfPolicy { enum { /// Threads per thread block BLOCK_THREADS = _BLOCK_THREADS, /// Items per thread (per tile of input) ITEMS_PER_THREAD = _ITEMS_PER_THREAD, }; /// The BlockLoad algorithm to use static constexpr BlockLoadAlgorithm LOAD_ALGORITHM = _LOAD_ALGORITHM; /// Cache load modifier for reading input elements static constexpr CacheLoadModifier LOAD_MODIFIER = _LOAD_MODIFIER; /// The BlockScan algorithm to use static constexpr BlockScanAlgorithm SCAN_ALGORITHM = _SCAN_ALGORITHM; /** Exposes the delay constructor type, which AgentSelectIf reads as AgentSelectIfPolicyT::detail::delay_constructor_t */ struct detail { using delay_constructor_t = DelayConstructorT; }; }; /****************************************************************************** * Thread block abstractions ******************************************************************************/ namespace detail { /** Pair of output iterators used when partitioning into two distinct outputs: selected items are written through selected_it and rejected items through rejected_it */ template struct partition_distinct_output_t { using selected_iterator_t = SelectedOutputItT; using rejected_iterator_t = RejectedOutputItT; selected_iterator_t selected_it; rejected_iterator_t rejected_it; }; } // namespace detail /** * @brief AgentSelectIf implements a stateful abstraction of CUDA thread blocks for participating in * device-wide selection * * If both a selection functor and selection flags are in use, the functor is applied to the flags (stencil) * Otherwise performs functor-based selection if SelectOpT functor type != NullType * Otherwise performs flag-based selection if FlagsInputIteratorT's value type != NullType * Otherwise performs discontinuity selection (keep unique) * * @tparam AgentSelectIfPolicyT * Parameterized AgentSelectIfPolicy tuning policy type * * @tparam InputIteratorT * Random-access input iterator type for selection items * * @tparam FlagsInputIteratorT * Random-access input iterator type for selections (NullType* if a selection functor or * discontinuity flagging is to be used for selection) * * @tparam OutputIteratorWrapperT * Either a random-access iterator or an instance of the `partition_distinct_output_t` template.
* * @tparam SelectOpT * Selection operator type (NullType if selections or discontinuity flagging is to be used for * selection) * * @tparam EqualityOpT * Equality operator type (NullType if selection functor or selections is to be used for * selection) * * @tparam OffsetT * Signed integer type for global offsets * * @tparam ScanTileStateT * The tile state class used in the decoupled look-back * * @tparam KEEP_REJECTS * Whether or not we push rejected items to the back of the output */ template struct AgentSelectIf { //--------------------------------------------------------------------- // Types and constants //--------------------------------------------------------------------- using ScanTileStateT = ScanTileState; // Indicates whether the BlockLoad algorithm uses shared memory to load or exchange the data static constexpr bool loads_via_smem = !(AgentSelectIfPolicyT::LOAD_ALGORITHM == BLOCK_LOAD_DIRECT || AgentSelectIfPolicyT::LOAD_ALGORITHM == BLOCK_LOAD_STRIPED || AgentSelectIfPolicyT::LOAD_ALGORITHM == BLOCK_LOAD_VECTORIZE); // If this may be an *in-place* stream compaction, we need to ensure that all of a tile's items have been loaded // before signalling a subsequent thread block's partial or inclusive state, hence we need a store release when // updating a tile state. Similarly, we need to make sure that the load of previous tile states precedes writing of // the stream-compacted items and, hence, we need a load acquire when reading those tile states. static constexpr MemoryOrder memory_order = ((!KEEP_REJECTS) && MayAlias && (!loads_via_smem)) ?
MemoryOrder::acquire_release : MemoryOrder::relaxed; // If we need to enforce memory order for in-place stream compaction, wrap the default decoupled look-back tile // state in a helper class that enforces memory order on reads and writes using MemoryOrderedTileStateT = detail::tile_state_with_memory_order; // The input value type using InputT = cub::detail::value_t; // The flag value type using FlagT = cub::detail::value_t; // Constants enum { USE_SELECT_OP, USE_SELECT_FLAGS, USE_DISCONTINUITY, USE_STENCIL_WITH_OP }; static constexpr ::cuda::std::int32_t BLOCK_THREADS = AgentSelectIfPolicyT::BLOCK_THREADS; static constexpr ::cuda::std::int32_t ITEMS_PER_THREAD = AgentSelectIfPolicyT::ITEMS_PER_THREAD; static constexpr ::cuda::std::int32_t TILE_ITEMS = BLOCK_THREADS * ITEMS_PER_THREAD; static constexpr bool TWO_PHASE_SCATTER = (ITEMS_PER_THREAD > 1); static constexpr bool has_select_op = (!::cuda::std::is_same::value); static constexpr bool has_flags_it = (!::cuda::std::is_same::value); static constexpr bool use_stencil_with_op = has_select_op && has_flags_it; static constexpr auto SELECT_METHOD = use_stencil_with_op ? USE_STENCIL_WITH_OP : has_select_op ? USE_SELECT_OP : has_flags_it ?
USE_SELECT_FLAGS : USE_DISCONTINUITY; // Cache-modified Input iterator wrapper type (for applying cache modifier) for items // Wrap the native input pointer with CacheModifiedInputIterator // or directly use the supplied input iterator type using WrappedInputIteratorT = ::cuda::std::_If<::cuda::std::is_pointer::value, CacheModifiedInputIterator, InputIteratorT>; // Cache-modified Input iterator wrapper type (for applying cache modifier) for flags // Wrap the native input pointer with CacheModifiedInputIterator // or directly use the supplied flags iterator type using WrappedFlagsInputIteratorT = ::cuda::std::_If<::cuda::std::is_pointer::value, CacheModifiedInputIterator, FlagsInputIteratorT>; // Parameterized BlockLoad type for input data using BlockLoadT = BlockLoad; // Parameterized BlockLoad type for flags using BlockLoadFlags = BlockLoad; // Parameterized BlockDiscontinuity type for items using BlockDiscontinuityT = BlockDiscontinuity; // Parameterized BlockScan type using BlockScanT = BlockScan; // Callback type for obtaining tile prefix during block scan using DelayConstructorT = typename AgentSelectIfPolicyT::detail::delay_constructor_t; using TilePrefixCallbackOpT = TilePrefixCallbackOp; // Item exchange type using ItemExchangeT = InputT[TILE_ITEMS]; // Shared memory type for this thread block union _TempStorage { struct ScanStorage { // Smem needed for tile scanning typename BlockScanT::TempStorage scan; // Smem needed for cooperative prefix callback typename TilePrefixCallbackOpT::TempStorage prefix; // Smem needed for discontinuity detection typename BlockDiscontinuityT::TempStorage discontinuity; } scan_storage; // Smem needed for loading items typename BlockLoadT::TempStorage load_items; // Smem needed for loading flags typename BlockLoadFlags::TempStorage load_flags; // Smem needed for compacting items (allows non POD items in this union) Uninitialized raw_exchange; }; // Alias wrapper allowing storage to be unioned struct TempStorage :
Uninitialized<_TempStorage> {}; //--------------------------------------------------------------------- // Per-thread fields //--------------------------------------------------------------------- _TempStorage& temp_storage; ///< Reference to temp_storage WrappedInputIteratorT d_in; ///< Input items OutputIteratorWrapperT d_selected_out; ///< Output iterator (or partition output wrapper) receiving the selected items WrappedFlagsInputIteratorT d_flags_in; ///< Input selection flags (if applicable) InequalityWrapper inequality_op; ///< Inequality operator constructed from the supplied equality operator SelectOpT select_op; ///< Selection operator OffsetT num_items; ///< Total number of input items //--------------------------------------------------------------------- // Constructor //--------------------------------------------------------------------- /** * @param temp_storage * Reference to temp_storage * * @param d_in * Input data * * @param d_flags_in * Input selection flags (if applicable) * * @param d_selected_out * Output data * * @param select_op * Selection operator * * @param equality_op * Equality operator * * @param num_items * Total number of input items */ _CCCL_DEVICE _CCCL_FORCEINLINE AgentSelectIf( TempStorage& temp_storage, InputIteratorT d_in, FlagsInputIteratorT d_flags_in, OutputIteratorWrapperT d_selected_out, SelectOpT select_op, EqualityOpT equality_op, OffsetT num_items) : temp_storage(temp_storage.Alias()) , d_in(d_in) , d_selected_out(d_selected_out) , d_flags_in(d_flags_in) , inequality_op(equality_op) , select_op(select_op) , num_items(num_items) {} //--------------------------------------------------------------------- // Utility methods for initializing the selections //--------------------------------------------------------------------- /** * Initialize selections (specialized for selection operator) */ template _CCCL_DEVICE _CCCL_FORCEINLINE void InitializeSelections( OffsetT /*tile_offset*/, OffsetT num_tile_items, InputT (&items)[ITEMS_PER_THREAD], OffsetT (&selection_flags)[ITEMS_PER_THREAD], Int2Type /*select_method*/) {
#pragma unroll for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) { // Out-of-bounds items are flagged as selected; in-bounds items are overwritten with the select_op result selection_flags[ITEM] = 1; if (!IS_LAST_TILE || (static_cast(threadIdx.x * ITEMS_PER_THREAD + ITEM) < num_tile_items)) { selection_flags[ITEM] = static_cast(select_op(items[ITEM])); } } } /** * Initialize selections (specialized for selection_op applied to d_flags_in) */ template _CCCL_DEVICE _CCCL_FORCEINLINE void InitializeSelections( OffsetT tile_offset, OffsetT num_tile_items, InputT (& /*items*/)[ITEMS_PER_THREAD], OffsetT (&selection_flags)[ITEMS_PER_THREAD], Int2Type /*select_method*/) { CTA_SYNC(); FlagT flags[ITEMS_PER_THREAD]; if (IS_LAST_TILE) { // Initialize the out-of-bounds flags #pragma unroll for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) { selection_flags[ITEM] = true; } // Guarded loads BlockLoadFlags(temp_storage.load_flags).Load(d_flags_in + tile_offset, flags, num_tile_items); } else { BlockLoadFlags(temp_storage.load_flags).Load(d_flags_in + tile_offset, flags); } #pragma unroll for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) { // Apply the selection operator to the flags of in-bounds items only if ((!IS_LAST_TILE) || (static_cast(threadIdx.x * ITEMS_PER_THREAD + ITEM) < num_tile_items)) { selection_flags[ITEM] = static_cast(select_op(flags[ITEM])); } } } /** * Initialize selections (specialized for valid flags) */ template _CCCL_DEVICE _CCCL_FORCEINLINE void InitializeSelections( OffsetT tile_offset, OffsetT num_tile_items, InputT (& /*items*/)[ITEMS_PER_THREAD], OffsetT (&selection_flags)[ITEMS_PER_THREAD], Int2Type /*select_method*/) { CTA_SYNC(); FlagT flags[ITEMS_PER_THREAD]; if (IS_LAST_TILE) { // Out-of-bounds items are flagged as selected BlockLoadFlags(temp_storage.load_flags).Load(d_flags_in + tile_offset, flags, num_tile_items, 1); } else { BlockLoadFlags(temp_storage.load_flags).Load(d_flags_in + tile_offset, flags); } // Convert flag type to selection_flags type #pragma unroll for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) {
selection_flags[ITEM] = static_cast(flags[ITEM]); } } /** * Initialize selections (specialized for discontinuity detection) */ template _CCCL_DEVICE _CCCL_FORCEINLINE void InitializeSelections( OffsetT tile_offset, OffsetT num_tile_items, InputT (&items)[ITEMS_PER_THREAD], OffsetT (&selection_flags)[ITEMS_PER_THREAD], Int2Type /*select_method*/) { if (IS_FIRST_TILE) { CTA_SYNC(); // Set head selection_flags. First tile sets the first flag for the first item BlockDiscontinuityT(temp_storage.scan_storage.discontinuity).FlagHeads(selection_flags, items, inequality_op); } else { InputT tile_predecessor; if (threadIdx.x == 0) { tile_predecessor = d_in[tile_offset - 1]; } CTA_SYNC(); BlockDiscontinuityT(temp_storage.scan_storage.discontinuity) .FlagHeads(selection_flags, items, inequality_op, tile_predecessor); } // Set selection flags for out-of-bounds items #pragma unroll for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) { // Out-of-bounds items of the last tile are flagged as selected if ((IS_LAST_TILE) && (OffsetT(threadIdx.x * ITEMS_PER_THREAD) + ITEM >= num_tile_items)) { selection_flags[ITEM] = 1; } } } //--------------------------------------------------------------------- // Scatter utility methods //--------------------------------------------------------------------- /** * Scatter flagged items to output offsets (specialized for direct scattering).
*/ template _CCCL_DEVICE _CCCL_FORCEINLINE void ScatterSelectedDirect( InputT (&items)[ITEMS_PER_THREAD], OffsetT (&selection_flags)[ITEMS_PER_THREAD], OffsetT (&selection_indices)[ITEMS_PER_THREAD], OffsetT num_selections) { // Scatter flagged items #pragma unroll for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) { if (selection_flags[ITEM]) { if ((!IS_LAST_TILE) || selection_indices[ITEM] < num_selections) { d_selected_out[selection_indices[ITEM]] = items[ITEM]; } } } } /** * @brief Scatter flagged items to output offsets (specialized for two-phase scattering) * * @param items * Items held by this thread * * @param selection_flags * Per-item selection flags; a non-zero flag marks a selected item * * @param selection_indices * Per-item output offsets; num_selections_prefix is subtracted to obtain the tile-local offset * * @param num_tile_selections * Number of selections in this tile * * @param num_selections_prefix * Total number of selections prior to this tile */ template _CCCL_DEVICE _CCCL_FORCEINLINE void ScatterSelectedTwoPhase( InputT (&items)[ITEMS_PER_THREAD], OffsetT (&selection_flags)[ITEMS_PER_THREAD], OffsetT (&selection_indices)[ITEMS_PER_THREAD], int num_tile_selections, OffsetT num_selections_prefix) { CTA_SYNC(); // Compact and scatter items #pragma unroll for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) { int local_scatter_offset = selection_indices[ITEM] - num_selections_prefix; if (selection_flags[ITEM]) { temp_storage.raw_exchange.Alias()[local_scatter_offset] = items[ITEM]; } } CTA_SYNC(); for (int item = threadIdx.x; item < num_tile_selections; item += BLOCK_THREADS) { d_selected_out[num_selections_prefix + item] = temp_storage.raw_exchange.Alias()[item]; } } /** * @brief Scatter flagged items.
Specialized for selection algorithm that simply discards rejected items * * @param num_tile_items * Number of valid items in this tile * * @param num_tile_selections * Number of selections in this tile * * @param num_selections_prefix * Total number of selections prior to this tile * * @param num_rejected_prefix * Total number of rejections prior to this tile * * @param num_selections * Total number of selections including this tile */ template _CCCL_DEVICE _CCCL_FORCEINLINE void Scatter( InputT (&items)[ITEMS_PER_THREAD], OffsetT (&selection_flags)[ITEMS_PER_THREAD], OffsetT (&selection_indices)[ITEMS_PER_THREAD], int num_tile_items, int num_tile_selections, OffsetT num_selections_prefix, OffsetT num_rejected_prefix, OffsetT num_selections, Int2Type /*is_keep_rejects*/) { // Do a two-phase scatter if two-phase is enabled and the average number of selected items per thread is // greater than one if (TWO_PHASE_SCATTER && (num_tile_selections > BLOCK_THREADS)) { ScatterSelectedTwoPhase( items, selection_flags, selection_indices, num_tile_selections, num_selections_prefix); } else { ScatterSelectedDirect(items, selection_flags, selection_indices, num_selections); } } /** * @brief Scatter flagged items. Specialized for partitioning algorithm that writes rejected items to a second * partition.
* * @param num_tile_items * Number of valid items in this tile * * @param num_tile_selections * Number of selections in this tile * * @param num_selections_prefix * Total number of selections prior to this tile * * @param num_rejected_prefix * Total number of rejections prior to this tile * * @param is_keep_rejects * Marker type indicating whether to keep rejected items in the second partition */ template _CCCL_DEVICE _CCCL_FORCEINLINE void Scatter( InputT (&items)[ITEMS_PER_THREAD], OffsetT (&selection_flags)[ITEMS_PER_THREAD], OffsetT (&selection_indices)[ITEMS_PER_THREAD], int num_tile_items, int num_tile_selections, OffsetT num_selections_prefix, OffsetT num_rejected_prefix, OffsetT num_selections, Int2Type /*is_keep_rejects*/) { CTA_SYNC(); int tile_num_rejections = num_tile_items - num_tile_selections; // Scatter items to shared memory (rejections first) #pragma unroll for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) { int item_idx = (threadIdx.x * ITEMS_PER_THREAD) + ITEM; int local_selection_idx = selection_indices[ITEM] - num_selections_prefix; int local_rejection_idx = item_idx - local_selection_idx; int local_scatter_offset = (selection_flags[ITEM]) ? tile_num_rejections + local_selection_idx : local_rejection_idx; temp_storage.raw_exchange.Alias()[local_scatter_offset] = items[ITEM]; } // Ensure all threads finished scattering to shared memory CTA_SYNC(); // Gather items from shared memory and scatter to global ScatterPartitionsToGlobal( num_tile_items, tile_num_rejections, num_selections_prefix, num_rejected_prefix, d_selected_out); } /** * @brief Second phase of scattering partitioned items to global memory. Specialized for partitioning to two * distinct partitions.
*/ template _CCCL_DEVICE _CCCL_FORCEINLINE void ScatterPartitionsToGlobal( int num_tile_items, int tile_num_rejections, OffsetT num_selections_prefix, OffsetT num_rejected_prefix, detail::partition_distinct_output_t partitioned_out_it_wrapper) { #pragma unroll for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) { int item_idx = (ITEM * BLOCK_THREADS) + threadIdx.x; int rejection_idx = item_idx; int selection_idx = item_idx - tile_num_rejections; OffsetT scatter_offset = (item_idx < tile_num_rejections) ? num_rejected_prefix + rejection_idx : num_selections_prefix + selection_idx; InputT item = temp_storage.raw_exchange.Alias()[item_idx]; if (!IS_LAST_TILE || (item_idx < num_tile_items)) { if (item_idx >= tile_num_rejections) { partitioned_out_it_wrapper.selected_it[scatter_offset] = item; } else { partitioned_out_it_wrapper.rejected_it[scatter_offset] = item; } } } } /** * @brief Second phase of scattering partitioned items to global memory. Specialized for partitioning to a single * iterator, where selected items are written in order from the beginning of the iterator and rejected items are * written from the iterator's end backwards. */ template _CCCL_DEVICE _CCCL_FORCEINLINE void ScatterPartitionsToGlobal( int num_tile_items, int tile_num_rejections, OffsetT num_selections_prefix, OffsetT num_rejected_prefix, PartitionedOutputItT partitioned_out_it) { #pragma unroll for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) { int item_idx = (ITEM * BLOCK_THREADS) + threadIdx.x; int rejection_idx = item_idx; int selection_idx = item_idx - tile_num_rejections; OffsetT scatter_offset = (item_idx < tile_num_rejections) ?
num_items - num_rejected_prefix - rejection_idx - 1 : num_selections_prefix + selection_idx; InputT item = temp_storage.raw_exchange.Alias()[item_idx]; if (!IS_LAST_TILE || (item_idx < num_tile_items)) { partitioned_out_it[scatter_offset] = item; } } } //--------------------------------------------------------------------- // Cooperatively scan a device-wide sequence of tiles with other CTAs //--------------------------------------------------------------------- /** * @brief Process first tile of input (dynamic chained scan). * * @param num_tile_items * Number of input items comprising this tile * * @param tile_offset * Tile offset * * @param tile_state_wrapper * A global tile state descriptor wrapped in a MemoryOrderedTileStateT that ensures consistent memory order across * all tile status updates and loads * * @return The running count of selections (including this tile) */ template _CCCL_DEVICE _CCCL_FORCEINLINE OffsetT ConsumeFirstTile(int num_tile_items, OffsetT tile_offset, MemoryOrderedTileStateT& tile_state_wrapper) { InputT items[ITEMS_PER_THREAD]; OffsetT selection_flags[ITEMS_PER_THREAD]; OffsetT selection_indices[ITEMS_PER_THREAD]; // Load items if (IS_LAST_TILE) { BlockLoadT(temp_storage.load_items).Load(d_in + tile_offset, items, num_tile_items); } else { BlockLoadT(temp_storage.load_items).Load(d_in + tile_offset, items); } // Initialize selection_flags InitializeSelections( tile_offset, num_tile_items, items, selection_flags, Int2Type()); // Ensure temporary storage used during block load can be reused // Also, in case of in-place stream compaction, this is needed to order the loads of // *all threads of this thread block* before the st.release of the thread writing this thread block's tile state CTA_SYNC(); // Exclusive scan of selection_flags OffsetT num_tile_selections; BlockScanT(temp_storage.scan_storage.scan).ExclusiveSum(selection_flags, selection_indices, num_tile_selections); if (threadIdx.x == 0) { // Update tile status if this is not the
last tile if (!IS_LAST_TILE) { tile_state_wrapper.SetInclusive(0, num_tile_selections); } } // Discount any out-of-bounds selections if (IS_LAST_TILE) { num_tile_selections -= (TILE_ITEMS - num_tile_items); } // Scatter flagged items Scatter( items, selection_flags, selection_indices, num_tile_items, num_tile_selections, 0, 0, num_tile_selections, cub::Int2Type{}); return num_tile_selections; } /** * @brief Process subsequent tile of input (dynamic chained scan). * * @param num_tile_items * Number of input items comprising this tile * * @param tile_idx * Tile index * * @param tile_offset * Tile offset * * @param tile_state_wrapper * A global tile state descriptor wrapped in a MemoryOrderedTileStateT that ensures consistent memory order across * all tile status updates and loads * * @return The running count of selections (including this tile) */ template _CCCL_DEVICE _CCCL_FORCEINLINE OffsetT ConsumeSubsequentTile( int num_tile_items, int tile_idx, OffsetT tile_offset, MemoryOrderedTileStateT& tile_state_wrapper) { InputT items[ITEMS_PER_THREAD]; OffsetT selection_flags[ITEMS_PER_THREAD]; OffsetT selection_indices[ITEMS_PER_THREAD]; // Load items if (IS_LAST_TILE) { BlockLoadT(temp_storage.load_items).Load(d_in + tile_offset, items, num_tile_items); } else { BlockLoadT(temp_storage.load_items).Load(d_in + tile_offset, items); } // Initialize selection_flags InitializeSelections( tile_offset, num_tile_items, items, selection_flags, Int2Type()); // Ensure temporary storage used during block load can be reused // Also, in case of in-place stream compaction, this is needed to order the loads of // *all threads of this thread block* before the st.release of the thread writing this thread block's tile state CTA_SYNC(); // Exclusive scan of selection_flags; prefix_op is passed to the scan as the tile prefix callback TilePrefixCallbackOpT prefix_op(tile_state_wrapper, temp_storage.scan_storage.prefix, cub::Sum(), tile_idx); BlockScanT(temp_storage.scan_storage.scan).ExclusiveSum(selection_flags, selection_indices,
prefix_op); OffsetT num_tile_selections = prefix_op.GetBlockAggregate(); OffsetT num_selections = prefix_op.GetInclusivePrefix(); OffsetT num_selections_prefix = prefix_op.GetExclusivePrefix(); OffsetT num_rejected_prefix = tile_offset - num_selections_prefix; // Discount any out-of-bounds selections if (IS_LAST_TILE) { int num_discount = TILE_ITEMS - num_tile_items; num_selections -= num_discount; num_tile_selections -= num_discount; } // note (only applies to in-place stream compaction): We can avoid having to introduce explicit memory order between // the look-back (i.e., loading previous tiles' states) and scattering items (which means, potentially overwriting // previous tiles' input items, in case of in-place compaction), because this is implicitly ensured through // execution dependency: The scatter stage requires the offset from the prefix-sum and it can only know the // prefix-sum after having read that from the decoupled look-back. Scatter flagged items Scatter( items, selection_flags, selection_indices, num_tile_items, num_tile_selections, num_selections_prefix, num_rejected_prefix, num_selections, cub::Int2Type{}); return num_selections; } /** * @brief Process a tile of input * * @param num_tile_items * Number of input items comprising this tile * * @param tile_idx * Tile index * * @param tile_offset * Tile offset * * @param tile_state_wrapper * A global tile state descriptor wrapped in a MemoryOrderedTileStateT that ensures consistent memory order across * all tile status updates and loads * * @return The running count of selections (including this tile), as returned by ConsumeFirstTile or * ConsumeSubsequentTile */ template _CCCL_DEVICE _CCCL_FORCEINLINE OffsetT ConsumeTile(int num_tile_items, int tile_idx, OffsetT tile_offset, MemoryOrderedTileStateT& tile_state_wrapper) { OffsetT num_selections; if (tile_idx == 0) { num_selections = ConsumeFirstTile(num_tile_items, tile_offset, tile_state_wrapper); } else { num_selections = ConsumeSubsequentTile(num_tile_items, tile_idx, tile_offset, tile_state_wrapper); } return num_selections; } /** * @brief Scan tiles of items as part
of a dynamic chained scan * * @param num_tiles * Total number of input tiles * * @param tile_state * Global tile state descriptor * * @param d_num_selected_out * Output total number selection_flags * * @tparam NumSelectedIteratorT * Output iterator type for recording number of items selection_flags */ template _CCCL_DEVICE _CCCL_FORCEINLINE void ConsumeRange(int num_tiles, ScanTileStateT& tile_state, NumSelectedIteratorT d_num_selected_out) { // Ensure consistent memory order across all tile status updates and loads auto tile_state_wrapper = MemoryOrderedTileStateT{tile_state}; // Blocks are launched in increasing order, so just assign one tile per block int tile_idx = (blockIdx.x * gridDim.y) + blockIdx.y; // Current tile index OffsetT tile_offset = static_cast(tile_idx) * static_cast(TILE_ITEMS); if (tile_idx < num_tiles - 1) { // Not the last tile (full) ConsumeTile(TILE_ITEMS, tile_idx, tile_offset, tile_state_wrapper); } else { // The last tile (possibly partially-full) OffsetT num_remaining = num_items - tile_offset; OffsetT num_selections = ConsumeTile(num_remaining, tile_idx, tile_offset, tile_state_wrapper); if (threadIdx.x == 0) { // Output the total number of selected items *d_num_selected_out = num_selections; } } } }; CUB_NAMESPACE_END