//===----------------------------------------------------------------------===// // // Part of libcu++, the C++ Standard Library for your entire system, // under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. // //===----------------------------------------------------------------------===// #ifndef _CUDA_STREAM_REF #define _CUDA_STREAM_REF /* stream_ref synopsis namespace cuda { class stream_ref { using value_type = cudaStream_t; stream_ref() = default; stream_ref(cudaStream_t stream_) noexcept : stream(stream_) {} stream_ref(int) = delete; stream_ref(nullptr_t) = delete; [[nodiscard]] value_type get() const noexcept; void wait() const; [[nodiscard]] bool ready() const; [[nodiscard]] friend bool operator==(stream_ref, stream_ref); [[nodiscard]] friend bool operator!=(stream_ref, stream_ref); private: cudaStream_t stream = 0; // exposition only }; } // cuda */ #include // cuda_runtime_api needs to come first #include #if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC) # pragma GCC system_header #elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG) # pragma clang system_header #elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC) # pragma system_header #endif // no system header #include #include #include _LIBCUDACXX_BEGIN_NAMESPACE_CUDA /** * \brief A non-owning wrapper for a `cudaStream_t`. */ class stream_ref { protected: ::cudaStream_t __stream{0}; public: using value_type = ::cudaStream_t; /** * \brief Constructs a `stream_ref` of the "default" CUDA stream. * * For behavior of the default stream, * \see * https://docs.nvidia.com/cuda/cuda-runtime-api/stream-sync-behavior.html * */ _CCCL_HIDE_FROM_ABI stream_ref() = default; /** * \brief Constructs a `stream_ref` from a `cudaStream_t` handle. * * This constructor provides implicit conversion from `cudaStream_t`. * * \note: It is the callers responsibilty to ensure the `stream_ref` does not * outlive the stream identified by the `cudaStream_t` handle. * */ constexpr stream_ref(value_type __stream_) noexcept : __stream{__stream_} {} /// Disallow construction from an `int`, e.g., `0`. stream_ref(int) = delete; /// Disallow construction from `nullptr`. stream_ref(_CUDA_VSTD::nullptr_t) = delete; /** * \brief Compares two `stream_ref`s for equality * * \note Allows comparison with `cudaStream_t` due to implicit conversion to * `stream_ref`. * * \param lhs The first `stream_ref` to compare * \param rhs The second `stream_ref` to compare * \return true if equal, false if unequal */ _CCCL_NODISCARD_FRIEND constexpr bool operator==(const stream_ref& __lhs, const stream_ref& __rhs) noexcept { return __lhs.__stream == __rhs.__stream; } /** * \brief Compares two `stream_ref`s for inequality * * \note Allows comparison with `cudaStream_t` due to implicit conversion to * `stream_ref`. * * \param lhs The first `stream_ref` to compare * \param rhs The second `stream_ref` to compare * \return true if unequal, false if equal */ _CCCL_NODISCARD_FRIEND constexpr bool operator!=(const stream_ref& __lhs, const stream_ref& __rhs) noexcept { return __lhs.__stream != __rhs.__stream; } /// Returns the wrapped `cudaStream_t` handle. _CCCL_NODISCARD constexpr value_type get() const noexcept { return __stream; } /** * \brief Synchronizes the wrapped stream. * * \throws cuda::cuda_error if synchronization fails. * */ void wait() const { _CCCL_TRY_CUDA_API(::cudaStreamSynchronize, "Failed to synchronize stream.", get()); } /** * \brief Queries if all operations on the wrapped stream have completed. * * \throws cuda::cuda_error if the query fails. * * \return `true` if all operations have completed, or `false` if not. */ _CCCL_NODISCARD bool ready() const { const auto __result = ::cudaStreamQuery(get()); if (__result == ::cudaErrorNotReady) { return false; } switch (__result) { case ::cudaSuccess: break; default: ::cudaGetLastError(); // Clear CUDA error state ::cuda::__throw_cuda_error(__result, "Failed to query stream."); } return true; } /** * \brief Queries the priority of the wrapped stream. * * \throws cuda::cuda_error if the query fails. * * \return value representing the priority of the wrapped stream. */ _CCCL_NODISCARD int priority() const { int __result = 0; _CCCL_TRY_CUDA_API(::cudaStreamGetPriority, "Failed to get stream priority", get(), &__result); return __result; } }; _LIBCUDACXX_END_NAMESPACE_CUDA #endif //_CUDA_STREAM_REF