vinx13 commented on code in PR #14291:
URL: https://github.com/apache/tvm/pull/14291#discussion_r1153789203


##########
src/runtime/contrib/cublas/cublas.cc:
##########
@@ -133,6 +134,120 @@ bool CheckMixPrecisionType(DLDataType in_dtype, DLDataType out_dtype, bool int_s
 int roundoff(int v, int d) { return (v + d - 1) / d * d; }
 
 #if CUDART_VERSION >= 10010
+
+void CallCublasLt(cublasLtHandle_t hdl, const DLTensor* A, const DLTensor* B, const DLTensor* bias,
+                  const DLTensor* C, bool transa, bool transb, cublasLtEpilogue_t epilogue) {
+  ICHECK(TypeEqual(A->dtype, B->dtype));
+  // Reversed strides indicate an in-place transpose operation.
+  transa = IsInPlaceTransposed(A) ? !transa : transa;
+  transb = IsInPlaceTransposed(B) ? !transb : transb;
+
+  auto compute_type = CUBLAS_COMPUTE_32F;
+  auto scale_type = CUDA_R_32F;
+  cudaDataType_t ab_type = CUDA_R_32F;
+  cudaDataType_t c_type = CUDA_R_32F;
+  float one_fp32 = 1.0;
+  float zero_fp32 = 0.0;
+  auto one_fp16 = __truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 10>(1.0);
+  auto zero_fp16 = __truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 10>(0.0);
+  void* alpha = &one_fp32;
+  void* beta = &zero_fp32;
+
+  if (A->dtype.bits == 16 && A->dtype.code == kDLFloat) {
+    ab_type = CUDA_R_16F;
+  }
+
+  if (C->dtype.bits == 16 && C->dtype.code == kDLFloat) {
+    c_type = CUDA_R_16F;
+    compute_type = CUBLAS_COMPUTE_16F;
+    scale_type = CUDA_R_16F;
+    alpha = &one_fp16;
+    beta = &zero_fp16;
+  }
+
+  cublasLtMatmulDesc_t op_desc;
+  cublasOperation_t op_transa = CUBLASBooleanToTranspose(transa);
+  cublasOperation_t op_transb = CUBLASBooleanToTranspose(transb);
+
+  CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate(&op_desc, compute_type, scale_type));
+  CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(op_desc, CUBLASLT_MATMUL_DESC_TRANSA,
+                                                    &op_transb, sizeof(op_transb)));
+  CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(op_desc, CUBLASLT_MATMUL_DESC_TRANSB,
+                                                    &op_transa, sizeof(op_transa)));
+
+  if (bias != nullptr) {
+    CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(op_desc, CUBLASLT_MATMUL_DESC_BIAS_POINTER,
+                                                      &bias->data, sizeof(float*)));
+  }
+
+  if (epilogue != CUBLASLT_EPILOGUE_DEFAULT) {
+    CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(op_desc, CUBLASLT_MATMUL_DESC_EPILOGUE,
+                                                      &epilogue, sizeof(epilogue)));
+  }
+
+  int batch_offset_A = A->ndim - 2;
+  int batch_offset_B = B->ndim - 2;
+
+  int M = ColumnCount(B, transb, batch_offset_B);
+  int N = RowCount(A, transa, batch_offset_A);
+  int K = ColumnCount(A, transa, batch_offset_A);
+
+  int lda = transb ? K : M;
+  int ldb = transa ? N : K;
+  int ldc = M;
+
+  cublasLtMatrixLayout_t A_desc, B_desc, C_desc;
+  CHECK_CUBLAS_ERROR(
+      cublasLtMatrixLayoutCreate(&A_desc, ab_type, !transb ? M : K, !transb ? K : M, lda));
+  CHECK_CUBLAS_ERROR(
+      cublasLtMatrixLayoutCreate(&B_desc, ab_type, !transa ? K : N, !transa ? N : K, ldb));
+  CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&C_desc, c_type, M, N, ldc));
+
+  if (A->ndim != 2 || B->ndim != 2) {
+    auto get_batch_count = [](int64_t* shape, int batch_offset) {
+      int64_t count = 1;
+      for (int i = 0; i < batch_offset; ++i) {
+        count *= shape[i];
+      }
+      return count;
+    };
+    auto set_batch = [](cublasLtMatrixLayout_t mat_desc, int batch_count, int64_t batch_stride) {
+      CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
+          mat_desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count, sizeof(batch_count)));
+      CHECK_CUBLAS_ERROR(
+          cublasLtMatrixLayoutSetAttribute(mat_desc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
+                                           &batch_stride, sizeof(batch_stride)));
+    };
+
+    int batch_count_A = get_batch_count(A->shape, batch_offset_A);
+    int batch_count_B = get_batch_count(B->shape, batch_offset_B);
+    int batch_count_C = get_batch_count(C->shape, C->ndim - 2);
+    int64_t batch_stride_A = M * K;
+    int64_t batch_stride_B = K * N;
+    int64_t batch_stride_C = M * N;
+
+    // cuBLASLt does not seem to support batched GEMM where one of the matrices
+    // has a single batch (with batch_stride 0).
+    ICHECK_EQ(batch_count_A, batch_count_B);
+
+    set_batch(A_desc, batch_count_A, batch_stride_A);
+    set_batch(B_desc, batch_count_B, batch_stride_B);
+    set_batch(C_desc, batch_count_C, batch_stride_C);
+  }
+
+  auto A_data = static_cast<char*>(A->data) + A->byte_offset;
+  auto B_data = static_cast<char*>(B->data) + B->byte_offset;
+  auto C_data = static_cast<char*>(C->data) + C->byte_offset;
+
+  CHECK_CUBLAS_ERROR(cublasLtMatmul(hdl, op_desc, alpha, B_data, A_desc, A_data, B_desc, beta,
+                                    C_data, C_desc, C_data, C_desc, nullptr, nullptr, 0, nullptr));

Review Comment:
   The plain cuBLAS API has a default workspace pool, but it seems cuBLASLt always requires the workspace to be set explicitly. Does passing nullptr here impact performance? We may want to allocate a default workspace and store it as a thread-local in `CublasThreadEntry`.
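   For example, a minimal sketch of that idea (hypothetical code, not part of this PR; the `CublasLtWorkspace` name and the 32 MiB default size are assumptions):
   
   ```cpp
   #include <cuda_runtime.h>
   #include <cstddef>
   
   // Hypothetical helper: lazily allocates one cuBLASLt workspace per host
   // thread, mirroring how CublasThreadEntry caches the cuBLAS handle.
   struct CublasLtWorkspace {
     void* ptr = nullptr;
     size_t size = 32 * 1024 * 1024;  // 32 MiB is a guess; tune per architecture
     CublasLtWorkspace() { cudaMalloc(&ptr, size); }  // error check omitted for brevity
     ~CublasLtWorkspace() {
       if (ptr != nullptr) cudaFree(ptr);
     }
   };
   
   // One workspace per thread, released automatically at thread exit.
   thread_local CublasLtWorkspace lt_workspace;
   ```
   
   The last three arguments of `cublasLtMatmul` are the workspace pointer, its size in bytes, and the stream, so the call above could then become:
   
   ```cpp
   CHECK_CUBLAS_ERROR(cublasLtMatmul(hdl, op_desc, alpha, B_data, A_desc, A_data, B_desc, beta,
                                     C_data, C_desc, C_data, C_desc, nullptr,
                                     lt_workspace.ptr, lt_workspace.size, nullptr));
   ```
   
   Since `algo` is null here, the implicit heuristic query is limited to algorithms whose scratch needs fit in the workspace size passed, so a nonzero workspace may unlock faster kernels.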



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]
