Hzfengsy commented on a change in pull request #4353: [Perf] Enhance cudnn and
cublas backend and enable TensorCore
URL: https://github.com/apache/incubator-tvm/pull/4353#discussion_r347546810
##
File path: src/runtime/contrib/cublas/cublas.cc
##
@@ -124,35 +169,203 @@ struct CublasDgemmBatchOp {
}
};
+// Check cublas supported mix-precision computation type and return computeType
+bool CheckMixPrecisionType(DLDataType in_dtype, DLDataType out_dtype, bool
int_support = true) {
+ if (int_support && TypeMatch(out_dtype, kDLInt, 32)) {
+return TypeMatch(in_dtype, kDLInt, 8);
+ } else if (TypeMatch(out_dtype, kDLFloat, 32)) {
+return TypeMatch(in_dtype, kDLInt, 8) ||
+ TypeMatch(in_dtype, kDLFloat, 16);
+ } else {
+return false;
+ }
+}
+
+inline void CallGemmEx(TVMArgs args, TVMRetValue *ret, cublasHandle_t hdl) {
+ DLTensor *A = args[0];
+ DLTensor *B = args[1];
+ DLTensor *C = args[2];
+ bool transa = args[3];
+ bool transb = args[4];
+ CHECK_EQ(A->ndim, 2);
+ CHECK_EQ(B->ndim, 2);
+ CHECK_EQ(C->ndim, 2);
+
+ CHECK_EQ(ElementStride(A), 1);
+ CHECK_EQ(ElementStride(B), 1);
+ CHECK_EQ(ElementStride(C), 1);
+
+ CHECK(TypeEqual(A->dtype, B->dtype));
+
+ // C can never be transposed.
+ CHECK(!IsInPlaceTransposed(C));
+
+ // Reversed strides indicates an in-place transpose operation.
+ transa = IsInPlaceTransposed(A) ? !transa : transa;
+ transb = IsInPlaceTransposed(B) ? !transb : transb;
+
+ CHECK(CheckMixPrecisionType(A->dtype, C->dtype)) << "Unsupported data type";
+ CHECK(!TypeMatch(A->dtype, kDLInt, 8) ||
+ ColumnStride(A) % 4 == 0) << "leading dimension must divide 4 for int8
gemm";
+ CHECK(!TypeMatch(B->dtype, kDLInt, 8) ||
+ ColumnStride(B) % 4 == 0) << "leading dimension must divide 4 for int8
gemm";
+ double alpha = args.size() > 5 ? args[5] : 1.0;
+ double beta = args.size() > 6 ? args[6] : 0.0;
+
+ cudaDataType_t cuda_in_type = GetCudaDataType(A->dtype);
+ cudaDataType_t cuda_out_type = GetCudaDataType(C->dtype);
+ cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
+ void *alpha_ptr = nullptr, *beta_ptr = nullptr;
+ auto alpha_int = static_cast<int32_t>(alpha);
+ auto beta_int = static_cast<int32_t>(beta);
+ auto alpha_float = static_cast<float>(alpha);
+ auto beta_float = static_cast<float>(beta);
+ if (C->dtype.code == kDLInt) {
+alpha_ptr = &alpha_int;
+beta_ptr = &beta_int;
+ } else if (C->dtype.code == kDLFloat) {
+alpha_ptr = &alpha_float;
+beta_ptr = &beta_float;
+ }
+
+ auto A_data = reinterpret_cast<void*>(static_cast<char*>(A->data) + A->byte_offset);
+ auto B_data = reinterpret_cast<void*>(static_cast<char*>(B->data) + B->byte_offset);
+ auto C_data = reinterpret_cast<void*>(static_cast<char*>(C->data) + C->byte_offset);
+
+ CHECK_CUBLAS_ERROR(cublasGemmEx(hdl,
+ BooleanToTranspose(transb),
+ BooleanToTranspose(transa),
+ ColumnCount(B, transb),
+ RowCount(A, transa),
+ ColumnCount(A, transa),
+ alpha_ptr,
+ B_data, cuda_in_type, ColumnStride(B),
+ A_data, cuda_in_type, ColumnStride(A),
+ beta_ptr,
+ C_data, cuda_out_type, ColumnStride(C),
+ cuda_out_type, algo));
Review comment:
Thank you! Here are two things worth noting:
- If we have fp16 input and output, we will use `CublasHgemm` rather than
`cublasGemmEx`.
- Currently, I cannot find any int8 output support in `cublasGemmEx`. Please see
details at https://docs.nvidia.com/cuda/cublas/index.html#cublas-GemmEx
Hence, I don't think there is a precision problem for now. Please figure out
if I have missed anything.
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org
With regards,
Apache Git Services