#pragma once #include "cublas_handle.hpp" namespace tf { // ---------------------------------------------------------------------------- // cublasFlowCapturere level-3 functions // ---------------------------------------------------------------------------- // Function: geam template cudaTask cublasFlowCapturer::geam( cublasOperation_t ta, cublasOperation_t tb, int m, int n, const T *alpha, const T *A, int lda, const T *beta, const T *B, int ldb, T *C, int ldc ) { return on([this, ta, tb, m, n, alpha, A, lda, beta, B, ldb, C, ldc] (cudaStream_t stream) mutable { _stream(stream); cublasStatus_t stat; if constexpr(std::is_same_v) { stat = cublasSgeam(_handle, ta, tb, m, n, alpha, A, lda, beta, B, ldb, C, ldc ); } else if constexpr(std::is_same_v) { stat = cublasDgeam(_handle, ta, tb, m, n, alpha, A, lda, beta, B, ldb, C, ldc ); } else { static_assert(dependent_false_v, "unknown cublas data type"); } TF_CHECK_CUBLAS(stat, "failed to run geam"); }); } // Function: c_geam template cudaTask cublasFlowCapturer::c_geam( cublasOperation_t ta, cublasOperation_t tb, int m, int n, const T *alpha, const T *A, int lda, const T *beta, const T *B, int ldb, T *C, int ldc ) { return geam( ta, tb, n, m, alpha, A, lda, beta, B, ldb, C, ldc ); } // Function: gemm template cudaTask cublasFlowCapturer::gemm( cublasOperation_t ta, cublasOperation_t tb, int m, int n, int k, const T *alpha, const T *A, int lda, const T *B, int ldb, const T *beta, T *C, int ldc ) { return on([this, ta, tb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc] (cudaStream_t stream) mutable { _stream(stream); cublasStatus_t stat; if constexpr(std::is_same_v) { stat = cublasSgemm(_handle, ta, tb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc ); } else if constexpr(std::is_same_v) { stat = cublasDgemm(_handle, ta, tb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc ); } else { static_assert(dependent_false_v, "unknown cublas data type"); } TF_CHECK_CUBLAS(stat, "failed to run gemm"); }); } template cudaTask cublasFlowCapturer::c_gemm( cublasOperation_t ta, cublasOperation_t tb, int m, int n, int k, const T *alpha, const T *A, int lda, const T *B, int ldb, const T *beta, T *C, int ldc ) { return gemm( tb, ta, n, m, k, alpha, B, ldb, A, lda, beta, C, ldc ); } // Function: gemm_batched template cudaTask cublasFlowCapturer::gemm_batched( cublasOperation_t ta, cublasOperation_t tb, int m, int n, int k, const T *alpha, const T *A[], int lda, const T *B[], int ldb, const T *beta, T *C[], int ldc, int bc ) { return on([this, ta, tb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc, bc] (cudaStream_t stream) mutable { _stream(stream); cublasStatus_t stat; if constexpr(std::is_same_v) { stat = cublasSgemmBatched(_handle, ta, tb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc, bc ); } else if constexpr(std::is_same_v) { stat = cublasDgemmBatched(_handle, ta, tb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc, bc ); } else static_assert(dependent_false_v, "unknown cublas data type"); TF_CHECK_CUBLAS(stat, "failed to run gemm_batched"); }); } // Function: c_gemm_batched template cudaTask cublasFlowCapturer::c_gemm_batched( cublasOperation_t ta, cublasOperation_t tb, int m, int n, int k, const T *alpha, const T *A[], int lda, const T *B[], int ldb, const T *beta, T *C[], int ldc, int bc ) { return gemm_batched( tb, ta, n, m, k, alpha, B, ldb, A, lda, beta, C, ldc, bc ); } // Function: gemm_sbatched (strided) template cudaTask cublasFlowCapturer::gemm_sbatched( cublasOperation_t ta, cublasOperation_t tb, int m, int n, int k, const T *alpha, const T *A, int lda, long long int sA, const T *B, int ldb, long long int sB, const T *beta, T *C, int ldc, long long int sC, int bc ) { return on([this, ta, tb, m, n, k, alpha, A, lda, sA, B, ldb, sB, beta, C, ldc, sC, bc] (cudaStream_t stream) mutable { _stream(stream); cublasStatus_t stat; if constexpr(std::is_same_v) { stat = cublasSgemmStridedBatched(_handle, ta, tb, m, n, k, alpha, A, lda, sA, B, ldb, sB, beta, C, ldc, sC, bc ); } else if constexpr(std::is_same_v) { stat = cublasDgemmStridedBatched(_handle, ta, tb, m, n, k, alpha, A, lda, sA, B, ldb, sB, beta, C, ldc, sC, bc ); } else static_assert(dependent_false_v, "unknown cublas data type"); TF_CHECK_CUBLAS(stat, "failed to run gemm_sbatched"); }); } // Function: c_gemm_sbatched (strided) template cudaTask cublasFlowCapturer::c_gemm_sbatched( cublasOperation_t ta, cublasOperation_t tb, int m, int n, int k, const T *alpha, const T *A, int lda, long long int sA, const T *B, int ldb, long long int sB, const T *beta, T *C, int ldc, long long int sC, int bc ){ return gemm_sbatched( tb, ta, n, m, k, alpha, B, ldb, sB, A, lda, sA, beta, C, ldc, sC, bc ); } // symm template cudaTask cublasFlowCapturer::symm( cublasSideMode_t side, cublasFillMode_t uplo, int m, int n, const T *alpha, const T *A, int lda, const T *B, int ldb, const T *beta, T *C, int ldc ) { return on( [this, side, uplo, m, n, alpha, A, lda, B, ldb, beta, C, ldc] (cudaStream_t stream) mutable { _stream(stream); cublasStatus_t stat; if constexpr(std::is_same_v) { stat = cublasSsymm(_handle, side, uplo, m, n, alpha, A, lda, B, ldb, beta, C, ldc ); } else if constexpr(std::is_same_v) { stat = cublasDsymm(_handle, side, uplo, m, n, alpha, A, lda, B, ldb, beta, C, ldc ); } else static_assert(dependent_false_v, "unknown cublas data type"); TF_CHECK_CUBLAS(stat, "failed to run symm"); }); } // c_symm template cudaTask cublasFlowCapturer::c_symm( cublasSideMode_t side, cublasFillMode_t uplo, int m, int n, const T *alpha, const T *A, int lda, const T *B, int ldb, const T *beta, T *C, int ldc ) { return symm( cublas_rside(side), cublas_rfill(uplo), n, m, alpha, A, lda, B, ldb, beta, C, ldc ); } // syrk template cudaTask cublasFlowCapturer::syrk( cublasFillMode_t uplo, cublasOperation_t tran, int n, int k, const T *alpha, const T *A, int lda, const T *beta, T *C, int ldc ) { return on( [this, uplo, tran, n, k, alpha, A, lda, beta, C, ldc] (cudaStream_t stream) mutable { _stream(stream); cublasStatus_t stat; if constexpr(std::is_same_v) { stat = cublasSsyrk(_handle, uplo, tran, n, k, alpha, A, lda, beta, C, ldc ); } else if constexpr(std::is_same_v) { stat = cublasDsyrk(_handle, uplo, tran, n, k, alpha, A, lda, beta, C, ldc ); } else static_assert(dependent_false_v, "unknown cublas data type"); TF_CHECK_CUBLAS(stat, "failed to run syrk"); }); } // c_syrk template cudaTask cublasFlowCapturer::c_syrk( cublasFillMode_t uplo, cublasOperation_t tran, int n, int k, const T *alpha, const T *A, int lda, const T *beta, T *C, int ldc ) { return syrk( cublas_rfill(uplo), cublas_rtran(tran), n, k, alpha, A, lda, beta, C, ldc ); } // syr2k template cudaTask cublasFlowCapturer::syr2k( cublasFillMode_t uplo, cublasOperation_t tran, int n, int k, const T *alpha, const T *A, int lda, const T *B, int ldb, const T *beta, T *C, int ldc ) { return on( [this, uplo, tran, n, k, alpha, A, lda, B, ldb, beta, C, ldc] (cudaStream_t stream) mutable { _stream(stream); cublasStatus_t stat; if constexpr(std::is_same_v) { stat = cublasSsyr2k(_handle, uplo, tran, n, k, alpha, A, lda, B, ldb, beta, C, ldc ); } else if constexpr(std::is_same_v) { stat = cublasDsyr2k(_handle, uplo, tran, n, k, alpha, A, lda, B, ldb, beta, C, ldc ); } else static_assert(dependent_false_v, "unknown cublas data type"); TF_CHECK_CUBLAS(stat, "failed to run syr2k"); }); } // c_syr2k template cudaTask cublasFlowCapturer::c_syr2k( cublasFillMode_t uplo, cublasOperation_t tran, int n, int k, const T *alpha, const T *A, int lda, const T *B, int ldb, const T *beta, T *C, int ldc ) { return syr2k( cublas_rfill(uplo), cublas_rtran(tran), n, k, alpha, B, ldb, A, lda, beta, C, ldc ); } // syrkx template cudaTask cublasFlowCapturer::syrkx( cublasFillMode_t uplo, cublasOperation_t tran, int n, int k, const T *alpha, const T *A, int lda, const T *B, int ldb, const T *beta, T *C, int ldc ) { return on( [this, uplo, tran, n, k, alpha, A, lda, B, ldb, beta, C, ldc] (cudaStream_t stream) mutable { _stream(stream); cublasStatus_t stat; if constexpr(std::is_same_v) { stat = cublasSsyrkx(_handle, uplo, tran, n, k, alpha, A, lda, B, ldb, beta, C, ldc ); } else if constexpr(std::is_same_v) { stat = cublasDsyrkx(_handle, uplo, tran, n, k, alpha, A, lda, B, ldb, beta, C, ldc ); } else static_assert(dependent_false_v, "unknown cublas data type"); TF_CHECK_CUBLAS(stat, "failed to run syrkx"); }); } // c_syrkx template cudaTask cublasFlowCapturer::c_syrkx( cublasFillMode_t uplo, cublasOperation_t tran, int n, int k, const T *alpha, const T *A, int lda, const T *B, int ldb, const T *beta, T *C, int ldc ) { return syrkx( cublas_rfill(uplo), cublas_rtran(tran), n, k, alpha, B, ldb, A, lda, beta, C, ldc ); } // trmm template cudaTask cublasFlowCapturer::trmm( cublasSideMode_t side, cublasFillMode_t uplo, cublasOperation_t tran, cublasDiagType_t diag, int m, int n, const T *alpha, const T *A, int lda, const T *B, int ldb, T *C, int ldc ) { return on( [this, side, uplo, tran, diag, m, n, alpha, A, lda, B, ldb, C, ldc] (cudaStream_t stream) mutable { _stream(stream); cublasStatus_t stat; if constexpr(std::is_same_v) { stat = cublasStrmm(_handle, side, uplo, tran, diag, m, n, alpha, A, lda, B, ldb, C, ldc ); } else if constexpr(std::is_same_v) { stat = cublasDtrmm(_handle, side, uplo, tran, diag, m, n, alpha, A, lda, B, ldb, C, ldc ); } else static_assert(dependent_false_v, "unknown cublas data type"); TF_CHECK_CUBLAS(stat, "failed to run trmm"); }); } // c_trmm template cudaTask cublasFlowCapturer::c_trmm( cublasSideMode_t side, cublasFillMode_t uplo, cublasOperation_t tran, cublasDiagType_t diag, int m, int n, const T *alpha, const T *A, int lda, const T *B, int ldb, T *C, int ldc ) { return trmm( cublas_rside(side), cublas_rfill(uplo), tran, diag, n, m, alpha, A, lda, B, ldb, C, ldc ); } // trsm template cudaTask cublasFlowCapturer::trsm( cublasSideMode_t side, cublasFillMode_t uplo, cublasOperation_t tran, cublasDiagType_t diag, int m, int n, const T *alpha, const T *A, int lda, T *B, int ldb ) { return on( [this, side, uplo, tran, diag, m, n, alpha, A, lda, B, ldb] (cudaStream_t stream) mutable { _stream(stream); cublasStatus_t stat; if constexpr(std::is_same_v) { stat = cublasStrsm(_handle, side, uplo, tran, diag, m, n, alpha, A, lda, B, ldb ); } else if constexpr(std::is_same_v) { stat = cublasDtrsm(_handle, side, uplo, tran, diag, m, n, alpha, A, lda, B, ldb ); } else static_assert(dependent_false_v, "unknown cublas data type"); TF_CHECK_CUBLAS(stat, "failed to run trsm"); }); } // c_trsm template cudaTask cublasFlowCapturer::c_trsm( cublasSideMode_t side, cublasFillMode_t uplo, cublasOperation_t tran, cublasDiagType_t diag, int m, int n, const T *alpha, const T *A, int lda, T *B, int ldb ) { return trsm( cublas_rside(side), cublas_rfill(uplo), tran, diag, n, m, alpha, A, lda, B, ldb ); } } // end of namespace tf -----------------------------------------------------