#pragma once #include "cublas_handle.hpp" namespace tf { // ---------------------------------------------------------------------------- // cublasFlowCapturere level-2 functions // ---------------------------------------------------------------------------- template cudaTask cublasFlowCapturer::gemv( cublasOperation_t trans, int m, int n, const T *alpha, const T *A, int lda, const T *x, int incx, const T *beta, T *y, int incy ) { return on([this, trans, m, n, alpha, A, lda, x, incx, beta, y, incy] (cudaStream_t stream) mutable { _stream(stream); cublasStatus_t stat; if constexpr(std::is_same_v) { stat = cublasSgemv(_handle, trans, m, n, alpha, A, lda, x, incx, beta, y, incy ); } else if constexpr(std::is_same_v) { stat = cublasDgemv(_handle, trans, m, n, alpha, A, lda, x, incx, beta, y, incy ); } else { static_assert(dependent_false_v, "unknown cublas data type"); } TF_CHECK_CUBLAS(stat, "failed to capture gemv"); }); } // gemv template cudaTask cublasFlowCapturer::c_gemv( cublasOperation_t trans, int m, int n, const T *alpha, const T *A, int lda, const T *x, int incx, const T *beta, T *y, int incy ) { return gemv( cublas_rtran(trans), n, m, alpha, A, lda, x, incx, beta, y, incy ); } // trmv template cudaTask cublasFlowCapturer::trmv( cublasFillMode_t uplo, cublasOperation_t tran, cublasDiagType_t diag, int n, const T* A, int lda, T *x, int incx ) { return on([this, uplo, tran, diag, n, A, lda, x, incx] (cudaStream_t stream) mutable { _stream(stream); cublasStatus_t stat; if constexpr(std::is_same_v) { stat = cublasStrmv(_handle, uplo, tran, diag, n, A, lda, x, incx); } else if constexpr(std::is_same_v) { stat = cublasDtrmv(_handle, uplo, tran, diag, n, A, lda, x, incx); } else { static_assert(dependent_false_v, "unknown cublas data type"); } TF_CHECK_CUBLAS(stat, "failed to capture trmv"); }); } // c_trmv template cudaTask cublasFlowCapturer::c_trmv( cublasFillMode_t uplo, cublasOperation_t tran, cublasDiagType_t diag, int n, const T* A, int lda, T *x, int incx ) { return trmv( cublas_rfill(uplo), cublas_rtran(tran), diag, n, A, lda, x, incx ); } // trsv template cudaTask cublasFlowCapturer::trsv( cublasFillMode_t uplo, cublasOperation_t tran, cublasDiagType_t diag, int n, const T* A, int lda, T *x, int incx ) { return on([this, uplo, tran, diag, n, A, lda, x, incx] (cudaStream_t stream) mutable { _stream(stream); cublasStatus_t stat; if constexpr(std::is_same_v) { stat = cublasStrsv(_handle, uplo, tran, diag, n, A, lda, x, incx); } else if constexpr(std::is_same_v) { stat = cublasDtrsv(_handle, uplo, tran, diag, n, A, lda, x, incx); } else { static_assert(dependent_false_v, "unknown cublas data type"); } TF_CHECK_CUBLAS(stat, "failed to capture trsv"); }); } // c_trsv template cudaTask cublasFlowCapturer::c_trsv( cublasFillMode_t uplo, cublasOperation_t tran, cublasDiagType_t diag, int n, const T* A, int lda, T *x, int incx ) { return trsv( cublas_rfill(uplo), cublas_rtran(tran), diag, n, A, lda, x, incx ); } // symv template cudaTask cublasFlowCapturer::symv( cublasFillMode_t uplo, int n, const T *alpha, const T *A, int lda, const T *x, int incx, const T *beta, T *y, int incy ) { return on([this, uplo, n, alpha, A, lda, x, incx, beta, y, incy] (cudaStream_t stream) mutable { _stream(stream); cublasStatus_t stat; if constexpr(std::is_same_v) { stat = cublasSsymv(_handle, uplo, n, alpha, A, lda, x, incx, beta, y, incy); } else if constexpr(std::is_same_v) { stat = cublasDsymv(_handle, uplo, n, alpha, A, lda, x, incx, beta, y, incy); } else { static_assert(dependent_false_v, "unknown cublas data type"); } TF_CHECK_CUBLAS(stat, "failed to capture symv"); }); } // c_symv template cudaTask cublasFlowCapturer::c_symv( cublasFillMode_t uplo, int n, const T *alpha, const T *A, int lda, const T *x, int incx, const T *beta, T *y, int incy ) { return symv( cublas_rfill(uplo), n, alpha, A, lda, x, incx, beta, y, incy ); } // syr template cudaTask cublasFlowCapturer::syr( cublasFillMode_t uplo, int n, const T *alpha, const T *x, int incx, T *A, int lda ) { return on([this, uplo, n, alpha, x, incx, A, lda] (cudaStream_t stream) mutable { _stream(stream); cublasStatus_t stat; if constexpr(std::is_same_v) { stat = cublasSsyr(_handle, uplo, n, alpha, x, incx, A, lda); } else if constexpr(std::is_same_v) { stat = cublasDsyr(_handle, uplo, n, alpha, x, incx, A, lda); } else { static_assert(dependent_false_v, "unknown cublas data type"); } TF_CHECK_CUBLAS(stat, "failed to capture syr"); }); } // c_syr template cudaTask cublasFlowCapturer::c_syr( cublasFillMode_t uplo, int n, const T *alpha, const T *x, int incx, T *A, int lda ) { return syr( cublas_rfill(uplo), n, alpha, x, incx, A, lda ); } // syr2 template cudaTask cublasFlowCapturer::syr2( cublasFillMode_t uplo, int n, const T *alpha, const T *x, int incx, const T *y, int incy, T *A, int lda ) { return on([this, uplo, n, alpha, x, incx, y, incy, A, lda] (cudaStream_t stream) mutable { _stream(stream); cublasStatus_t stat; if constexpr(std::is_same_v) { stat = cublasSsyr2(_handle, uplo, n, alpha, x, incx, y, incy, A, lda); } else if constexpr(std::is_same_v) { stat = cublasDsyr2(_handle, uplo, n, alpha, x, incx, y, incy, A, lda); } else { static_assert(dependent_false_v, "unknown cublas data type"); } TF_CHECK_CUBLAS(stat, "failed to capture syr2"); }); } // c_syr2 template cudaTask cublasFlowCapturer::c_syr2( cublasFillMode_t uplo, int n, const T *alpha, const T *x, int incx, const T *y, int incy, T *A, int lda ) { return syr2( cublas_rfill(uplo), n, alpha, x, incx, y, incy, A, lda ); } } // end of namespace tf -----------------------------------------------------