minminsun commented on a change in pull request #4353: [Perf] Enhance cudnn and 
cublas backend and enable TensorCore
URL: https://github.com/apache/incubator-tvm/pull/4353#discussion_r347709505
 
 

 ##########
 File path: src/runtime/contrib/cublas/cublas.cc
 ##########
 @@ -124,35 +169,203 @@ struct CublasDgemmBatchOp {
   }
 };
 
+// Returns true if the (in_dtype, out_dtype) pair is one of the mixed-precision
+// combinations accepted here for cuBLAS GEMM:
+//   - in int8    -> out int32   (only when int_support is true)
+//   - in int8    -> out float32
+//   - in float16 -> out float32
+// Any other pair returns false. Note: this only returns a bool; it does not
+// compute or return a cuBLAS computeType.
+bool CheckMixPrecisionType(DLDataType in_dtype, DLDataType out_dtype, bool 
int_support = true) {
+  // int32 output is only accepted (from int8 input) when the caller allows it.
+  if (int_support && TypeMatch(out_dtype, kDLInt, 32)) {
+    return TypeMatch(in_dtype, kDLInt, 8);
+  } else if (TypeMatch(out_dtype, kDLFloat, 32)) {
+    return TypeMatch(in_dtype, kDLInt, 8) ||
+           TypeMatch(in_dtype, kDLFloat, 16);
+  } else {
+    // Includes int32 output when int_support is false.
+    return false;
+  }
+}
+
+inline void CallGemmEx(TVMArgs args, TVMRetValue *ret, cublasHandle_t hdl) {
+  DLTensor *A = args[0];
+  DLTensor *B = args[1];
+  DLTensor *C = args[2];
+  bool transa = args[3];
+  bool transb = args[4];
+  CHECK_EQ(A->ndim, 2);
+  CHECK_EQ(B->ndim, 2);
+  CHECK_EQ(C->ndim, 2);
+
+  CHECK_EQ(ElementStride(A), 1);
+  CHECK_EQ(ElementStride(B), 1);
+  CHECK_EQ(ElementStride(C), 1);
+
+  CHECK(TypeEqual(A->dtype, B->dtype));
+
+  // C can never be transposed.
+  CHECK(!IsInPlaceTransposed(C));
+
+  // Reversed strides indicates an in-place transpose operation.
+  transa = IsInPlaceTransposed(A) ? !transa : transa;
+  transb = IsInPlaceTransposed(B) ? !transb : transb;
+
+  CHECK(CheckMixPrecisionType(A->dtype, C->dtype)) << "Unsupported data type";
+  CHECK(!TypeMatch(A->dtype, kDLInt, 8) ||
+      ColumnStride(A) % 4 == 0) << "leading dimension must divide 4 for int8 
gemm";
+  CHECK(!TypeMatch(B->dtype, kDLInt, 8) ||
+      ColumnStride(B) % 4 == 0) << "leading dimension must divide 4 for int8 
gemm";
+  double alpha = args.size() > 5 ? args[5] : 1.0;
+  double beta = args.size() > 6 ? args[6] : 0.0;
+
+  cudaDataType_t cuda_in_type = GetCudaDataType(A->dtype);
+  cudaDataType_t cuda_out_type = GetCudaDataType(C->dtype);
+  cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
+  void *alpha_ptr = nullptr, *beta_ptr = nullptr;
+  auto alpha_int = static_cast<int32_t>(alpha);
+  auto beta_int = static_cast<int32_t>(beta);
+  auto alpha_float = static_cast<float>(alpha);
+  auto beta_float = static_cast<float>(beta);
+  if (C->dtype.code == kDLInt) {
+    alpha_ptr = &alpha_int;
+    beta_ptr = &beta_int;
+  } else if (C->dtype.code == kDLFloat) {
+    alpha_ptr = &alpha_float;
+    beta_ptr = &beta_float;
+  }
+
+  auto A_data = reinterpret_cast<void *>(static_cast<char *>(A->data) + 
A->byte_offset);
+  auto B_data = reinterpret_cast<void *>(static_cast<char *>(B->data) + 
B->byte_offset);
+  auto C_data = reinterpret_cast<void *>(static_cast<char *>(C->data) + 
C->byte_offset);
+
+  CHECK_CUBLAS_ERROR(cublasGemmEx(hdl,
+                                 BooleanToTranspose(transb),
+                                 BooleanToTranspose(transa),
+                                 ColumnCount(B, transb),
+                                 RowCount(A, transa),
+                                 ColumnCount(A, transa),
+                                 alpha_ptr,
+                                 B_data, cuda_in_type, ColumnStride(B),
+                                 A_data, cuda_in_type, ColumnStride(A),
+                                 beta_ptr,
+                                 C_data, cuda_out_type, ColumnStride(C),
+                                 cuda_out_type, algo));
 
 Review comment:
  Thanks, I see. So there's no problem here as long as the output type is not 
fp16 or int8.
   
  I double-checked that cublasHgemm is equivalent to cublasGemmEx with an fp16 
compute type. So, for computations with both fp16 input and fp16 output, it 
would be better to use cublasGemmEx with an fp32 compute type.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

Reply via email to