// cocos-engine-external/sources/taskflow/cuda/cublas/cublas_handle.hpp
//
// 135 lines
// 3.2 KiB
// C++

#pragma once
#include "cublas_error.hpp"
/**
@file cublas_handle.hpp
*/
namespace tf {
/**
@private
@brief functor that creates a cublas handle set to device pointer mode
*/
struct cublasHandleCreator {

  /**
  @brief creates a cublas handle and sets its pointer mode to
         @c CUBLAS_POINTER_MODE_DEVICE

  @return the newly created cublas handle

  Both cublas calls are checked with @c TF_CHECK_CUBLAS. If setting the
  pointer mode fails, the handle created in the first step is destroyed
  before the failure is reported, so it is not leaked.
  */
  cublasHandle_t operator () () const {

    cublasHandle_t handle {nullptr};

    TF_CHECK_CUBLAS(
      cublasCreate(&handle), "failed to create a cublas handle"
    );

    // Keep the status in a local so the handle can be cleaned up before the
    // failure is reported through the checking macro.
    const cublasStatus_t mode_status =
      cublasSetPointerMode(handle, CUBLAS_POINTER_MODE_DEVICE);

    if(mode_status != CUBLAS_STATUS_SUCCESS) {
      // The status of this clean-up call is intentionally ignored; the
      // original failure is the one worth reporting.
      cublasDestroy(handle);
      // NOTE(review): TF_CHECK_CUBLAS is presumed not to return normally on
      // failure (confirm in cublas_error.hpp); resetting the handle makes
      // sure a destroyed handle can never be handed out if it does.
      handle = nullptr;
    }

    TF_CHECK_CUBLAS(
      mode_status, "failed to set cublas pointer mode on device"
    );

    return handle;
  }
};
/**
@private
@brief functor that destroys a cublas handle
*/
struct cublasHandleDeleter {

  /**
  @brief destroys the given cublas handle

  @param handle cublas handle to pass to @c cublasDestroy

  The status returned by @c cublasDestroy is discarded.
  */
  void operator () (cublasHandle_t handle) const {
    static_cast<void>(cublasDestroy(handle));
  }
};
/**
@private
@brief alias of the per-thread cublas handle pool type

The alias instantiates cudaPerThreadDeviceObjectPool with @c cublasHandle_t as
the pooled object type, and passes tf::cublasHandleCreator and
tf::cublasHandleDeleter as its creator and deleter template arguments.
*/
using cublasPerThreadHandlePool = cudaPerThreadDeviceObjectPool<
cublasHandle_t, cublasHandleCreator, cublasHandleDeleter
>;
/**
@private
@brief per-thread cublas handle pool

The object is declared @c thread_local, so each thread has its own pool
instance. tf::cublasScopedPerThreadHandle acquires handles from this pool
and releases them back to it.
*/
inline thread_local cublasPerThreadHandlePool cublas_per_thread_handle_pool;
// ----------------------------------------------------------------------------
// cublasScopedPerThreadHandle definition
// ----------------------------------------------------------------------------
/**
@brief class to provide RAII-styled guard of cublas handle acquisition
Sample usage:
@code{.cpp}
{
tf::cublasScopedPerThreadHandle handle(1); // acquires a cublas handle on device 1
// use handle as a normal cublas handle (cublasHandle_t)
cublasSetStream(handle, stream);
} // leaving the scope to release the handle back to the pool on device 1
@endcode
By default, the cublas handle has its pointer mode set to device
(i.e., @c CUBLAS_POINTER_MODE_DEVICE),
which is required for capturing cublas kernels.
The scoped per-thread cublas handle is primarily used by tf::cublasFlowCapturer.
%cublasScopedPerThreadHandle is neither movable nor copyable.
*/
class cublasScopedPerThreadHandle {

  // Shared pointer to the pooled device object whose @c object member is the
  // native cublas handle.
  std::shared_ptr<cublasPerThreadHandlePool::cudaDeviceObject> _ptr;

  public:

  /**
  @brief constructs a scoped handle under the given device context

  @param d device identifier passed to the pool's @c acquire

  The handle is acquired from the calling thread's cublas handle pool.
  */
  explicit cublasScopedPerThreadHandle(int d) :
    _ptr {cublas_per_thread_handle_pool.acquire(d)} {
  }

  /**
  @brief constructs a scoped handle under the caller's device context

  The device is queried through @c cuda_get_device and the construction is
  forwarded to the device-taking constructor.
  */
  cublasScopedPerThreadHandle() :
    cublasScopedPerThreadHandle(cuda_get_device()) {
  }

  /**
  @brief destructs the scoped handle guard

  The held pointer is moved back into the calling thread's cublas handle pool
  through the pool's @c release.
  */
  ~cublasScopedPerThreadHandle() {
    cublas_per_thread_handle_pool.release(std::move(_ptr));
  }

  /**
  @brief implicit conversion to the native cublas handle (cublasHandle_t)

  @return the handle stored in the held pool object
  */
  operator cublasHandle_t () const {
    return _ptr->object;
  }

  /**
  @brief returns the number of shared owners

  @return the use count reported by the underlying shared pointer
  */
  long use_count() const {
    return _ptr.use_count();
  }

  private:

  // The guard is neither copyable nor movable. With a user-declared move
  // constructor the implicit copy assignment is deleted as well, so the
  // guard cannot be assigned either.
  cublasScopedPerThreadHandle(const cublasScopedPerThreadHandle&) = delete;
  cublasScopedPerThreadHandle(cublasScopedPerThreadHandle&&) = delete;
};
} // end of namespace tf -----------------------------------------------------