#pragma once #include "cublas_handle.hpp" namespace tf { // ---------------------------------------------------------------------------- // cublasFlowCapturere level-1 functions // ---------------------------------------------------------------------------- // Function: amax template cudaTask cublasFlowCapturer::amax( int n, const T* x, int incx, int* result ) { return on([this, n, x, incx, result] (cudaStream_t stream) mutable { _stream(stream); cublasStatus_t stat; if constexpr(std::is_same_v) { stat = cublasIsamax(_handle, n, x, incx, result); } else if constexpr(std::is_same_v) { stat = cublasIdamax(_handle, n, x, incx, result); } else { static_assert(dependent_false_v, "unknown cublas data type"); } TF_CHECK_CUBLAS(stat, "failed to run cublasamax"); }); } // Function: amin template cudaTask cublasFlowCapturer::amin( int n, const T* x, int incx, int* result ) { return on([this, n, x, incx, result] (cudaStream_t stream) mutable { _stream(stream); cublasStatus_t stat; if constexpr(std::is_same_v) { stat = cublasIsamin(_handle, n, x, incx, result); } else if constexpr(std::is_same_v) { stat = cublasIdamin(_handle, n, x, incx, result); } else { static_assert(dependent_false_v, "unknown cublas data type"); } TF_CHECK_CUBLAS(stat, "failed to run cublasamin"); }); } // Function: asum template cudaTask cublasFlowCapturer::asum( int n, const T* x, int incx, T* result ) { return on([this, n, x, incx, result] (cudaStream_t stream) mutable { _stream(stream); cublasStatus_t stat; if constexpr(std::is_same_v) { stat = cublasSasum(_handle, n, x, incx, result); } else if constexpr(std::is_same_v) { stat = cublasDasum(_handle, n, x, incx, result); } else { static_assert(dependent_false_v, "unknown cublas data type"); } TF_CHECK_CUBLAS(stat, "failed to run cublasasum"); }); } // Function: axpy template cudaTask cublasFlowCapturer::axpy( int n, const T *alpha, const T *x, int incx, T *y, int incy ) { return on([this, n, alpha, x, incx, y, incy] (cudaStream_t stream) mutable { _stream(stream); cublasStatus_t stat; if constexpr(std::is_same_v) { stat = cublasSaxpy(_handle, n, alpha, x, incx, y, incy); } else if constexpr(std::is_same_v) { stat = cublasDaxpy(_handle, n, alpha, x, incx, y, incy); } else { static_assert(dependent_false_v, "unknown cublas data type"); } TF_CHECK_CUBLAS(stat, "failed to run cublasaxpy"); }); } // Function: vcopy template cudaTask cublasFlowCapturer::vcopy( int n, const T* x, int incx, T* y, int incy ) { return on([this, n, x, incx, y, incy] (cudaStream_t stream) mutable { _stream(stream); cublasStatus_t stat; if constexpr(std::is_same_v) { stat = cublasScopy(_handle, n, x, incx, y, incy); } else if constexpr(std::is_same_v) { stat = cublasDcopy(_handle, n, x, incx, y, incy); } else { static_assert(dependent_false_v, "unknown cublas data type"); } TF_CHECK_CUBLAS(stat, "failed to run cublascopy"); }); } // Function: dot template cudaTask cublasFlowCapturer::dot( int n, const T* x, int incx, const T* y, int incy, T* result ) { return on([this, n, x, incx, y, incy, result] (cudaStream_t stream) mutable { _stream(stream); cublasStatus_t stat; if constexpr(std::is_same_v) { stat = cublasSdot(_handle, n, x, incx, y, incy, result); } else if constexpr(std::is_same_v) { stat = cublasDdot(_handle, n, x, incx, y, incy, result); } else { static_assert(dependent_false_v, "unknown cublas data type"); } TF_CHECK_CUBLAS(stat, "failed to run cublasdot"); }); } template cudaTask cublasFlowCapturer::nrm2(int n, const T* x, int incx, T* result) { return on([this, n, x, incx, result] (cudaStream_t stream) mutable { _stream(stream); cublasStatus_t stat; if constexpr(std::is_same_v) { stat = cublasSnrm2(_handle, n, x, incx, result); } else if constexpr(std::is_same_v) { stat = cublasDnrm2(_handle, n, x, incx, result); } else { static_assert(dependent_false_v, "unknown cublas data type"); } TF_CHECK_CUBLAS(stat, "failed to run cublasnrm2"); }); } // Function: scal template cudaTask cublasFlowCapturer::scal(int n, const T* scalar, T* x, int incx) { return on([this, n, scalar, x, incx] (cudaStream_t stream) mutable { _stream(stream); cublasStatus_t stat; if constexpr(std::is_same_v) { stat = cublasSscal(_handle, n, scalar, x, incx); } else if constexpr(std::is_same_v) { stat = cublasDscal(_handle, n, scalar, x, incx); } else { static_assert(dependent_false_v, "unknown cublas data type"); } TF_CHECK_CUBLAS(stat, "failed to run cublasscal"); }); } template cudaTask cublasFlowCapturer::swap(int n, T* x, int incx, T* y, int incy) { return on([this, n, x, incx, y, incy] (cudaStream_t stream) mutable { _stream(stream); cublasStatus_t stat; if constexpr(std::is_same_v) { stat = cublasSswap(_handle, n, x, incx, y, incy); } else if constexpr(std::is_same_v) { stat = cublasDswap(_handle, n, x, incx, y, incy); } else { static_assert(dependent_false_v, "unknown cublas data type"); } TF_CHECK_CUBLAS(stat, "failed to run cublasswap"); }); } } // end of namespace tf -----------------------------------------------------