This is an automated email from the ASF dual-hosted git repository.
wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new d1ac73ca2d [CUBLAS][FP8] Support e4m3 gemm in cuBLAS BYOC (#16888)
d1ac73ca2d is described below
commit d1ac73ca2d3c14dc69e47818871478e8b0f295aa
Author: Ivan Sidorenko <[email protected]>
AuthorDate: Tue Apr 16 21:55:11 2024 +0300
[CUBLAS][FP8] Support e4m3 gemm in cuBLAS BYOC (#16888)
[CUBLAS][FP8] Support e4m3 gemm in cuBLAS BYOC (#63)
Co-authored-by: Andrey Malyshev <[email protected]>
---
include/tvm/runtime/data_type.h | 3 ++
python/tvm/contrib/tvmjs.py | 19 +++++++++
python/tvm/relax/backend/contrib/cublas.py | 16 ++++++-
python/tvm/relax/transform/legalize_ops/qdq.py | 27 +++++++-----
src/relax/backend/contrib/utils.h | 4 ++
src/relax/op/tensor/qdq.cc | 18 +++++---
src/runtime/contrib/cublas/cublas.cc | 3 ++
src/tir/op/op.cc | 2 +
tests/python/relax/test_codegen_cublas.py | 59 ++++++++++++++++++++++++++
tests/python/relax/test_op_qdq.py | 37 ++++++++++++++++
10 files changed, 169 insertions(+), 19 deletions(-)
diff --git a/include/tvm/runtime/data_type.h b/include/tvm/runtime/data_type.h
index f7284ec690..a330ccbbdf 100644
--- a/include/tvm/runtime/data_type.h
+++ b/include/tvm/runtime/data_type.h
@@ -126,6 +126,9 @@ class DataType {
code() == DataType::kE5M2Float) &&
bits() == 8;
}
+  /*! \return whether type is an e4m3 float8 type. */
+  bool is_e4m3_float8() const { return code() == DataType::kE4M3Float && bits() == 8; }
+
+  /*! \return whether type is an e5m2 float8 type. */
+  bool is_e5m2_float8() const { return code() == DataType::kE5M2Float && bits() == 8; }
/*! \return whether type is a float16 type. */
bool is_float16() const { return is_float() && bits() == 16; }
/*! \return whether type is a bfloat16 type. */
diff --git a/python/tvm/contrib/tvmjs.py b/python/tvm/contrib/tvmjs.py
index 8d8bd1b051..923301a1f5 100644
--- a/python/tvm/contrib/tvmjs.py
+++ b/python/tvm/contrib/tvmjs.py
@@ -28,6 +28,11 @@ from typing import Iterator, Mapping, Tuple, Union
import numpy as np
+try:
+ import ml_dtypes
+except ImportError:
+ ml_dtypes = None
+
import tvm
from tvm._ffi.libinfo import find_lib_path
@@ -295,6 +300,20 @@ def load_ndarray_cache(cachepath: str, device:
tvm.runtime.Device):
arr = tvm.nd.empty(shape, dtype, device=device)
assert offset + nbytes <= len(raw_data)
buffer_source = raw_data[offset : offset + nbytes]
+ if dtype == "e4m3_float8":
+ if ml_dtypes is not None:
+ dtype = ml_dtypes.float8_e4m3fn
+ else:
+ raise RuntimeError(
+ "ml_dtypes is not installed, cannot convert
e4m3_float8 array to numpy."
+ )
+ if dtype == "e5m2_float8":
+ if ml_dtypes is not None:
+ dtype = ml_dtypes.float8_e5m2
+ else:
+ raise RuntimeError(
+ "ml_dtypes is not installed, cannot convert
e5m2_float8 array to numpy."
+ )
if encode_format == "f32-to-bf16" and dtype == "float32":
data = np.frombuffer(buffer_source,
dtype="uint16").reshape(shape)
arr.copyfrom(_convert_bf16_to_f32(data))
diff --git a/python/tvm/relax/backend/contrib/cublas.py
b/python/tvm/relax/backend/contrib/cublas.py
index eecd531e74..f66001d0e8 100644
--- a/python/tvm/relax/backend/contrib/cublas.py
+++ b/python/tvm/relax/backend/contrib/cublas.py
@@ -28,8 +28,11 @@ from ..patterns import make_matmul_pattern
from ..utils import has_leaking_intermediate_variables
-def _is_supported_dtype(lhs_dtype, rhs_dtype):
+def _is_supported_dtype(lhs_dtype, rhs_dtype, out_dtype):
"""Check if dtypes in the given workload are supported by cuBLAS BYOC."""
+ if lhs_dtype == "e4m3_float8" and rhs_dtype == "e4m3_float8":
+ # The output cannot be 'e5m2_float8' if inputs are 'e4m3_float8'
+ return out_dtype != "e5m2_float8"
return (
(lhs_dtype == "float16" and rhs_dtype == "float16")
or (lhs_dtype == "float32" and rhs_dtype == "float32")
@@ -42,10 +45,12 @@ def _check_matmul(context: PatternCheckContext) -> bool:
return False
lhs = context.annotated_expr["lhs"]
rhs = context.annotated_expr["rhs"]
+ matmul_call = context.annotated_expr["root"]
lhs_dtype = lhs.struct_info.dtype
rhs_dtype = rhs.struct_info.dtype
- if not _is_supported_dtype(lhs_dtype, rhs_dtype):
+ out_dtype = matmul_call.struct_info.dtype
+ if not _is_supported_dtype(lhs_dtype, rhs_dtype, out_dtype):
return False
lhs_shape = lhs.struct_info.shape.values
@@ -62,6 +67,13 @@ def _check_matmul(context: PatternCheckContext) -> bool:
if not isinstance(rhs_shape[-1], (tvm.tir.expr.IntImm, int)) or
rhs_shape[-1] % 4 != 0:
# Rows number must be multiples of 4 for IGEMM
return False
+ elif lhs_dtype == "e4m3_float8" and rhs_dtype == "e4m3_float8":
+ # Matrix dimensions must be multiples of 16. This requirement is
missing from the cuBLAS
+ # docs, but it was observed during testing.
+ if not isinstance(rhs_shape[-1], (tvm.tir.expr.IntImm, int)) or
rhs_shape[-1] % 16 != 0:
+ return False
+ if not isinstance(rhs_shape[-2], (tvm.tir.expr.IntImm, int)) or
rhs_shape[-2] % 16 != 0:
+ return False
lhs_batches = reduce(operator.mul, lhs_shape[:-2], 1)
rhs_batches = reduce(operator.mul, rhs_shape[:-2], 1)
diff --git a/python/tvm/relax/transform/legalize_ops/qdq.py
b/python/tvm/relax/transform/legalize_ops/qdq.py
index 4f1e43d988..7484285c1e 100644
--- a/python/tvm/relax/transform/legalize_ops/qdq.py
+++ b/python/tvm/relax/transform/legalize_ops/qdq.py
@@ -52,7 +52,8 @@ def _quantize(bb: BlockBuilder, call: Call) -> Expr:
def quantize_compute(*indices):
scale_value = scale if is_const_scalar(scale) else
scale[indices[axis]]
zp_value = zp if is_const_scalar(zp) else zp[indices[axis]]
- round_val = te.round(data[indices] / scale_value) + zp_value
+ scaled = data[indices] / scale_value
+ round_val = (te.round(scaled) if "int" in out_dtype else scaled) +
zp_value
return clip_cast(round_val, out_dtype)
output_shape = data.shape
@@ -75,15 +76,18 @@ def _dequantize(bb: BlockBuilder, call: Call) -> Expr:
Compute datatype: float32
Example of lowering:
- qnn.dequantize(data, scale, zp, "float32") -->
- sub = subtract(cast(data, "int32"), zp)
- out = multiply(cast(sub, "float32"), scale)
-
- qnn.dequantize(data, scale, zp, "float16") -->
- sub = subtract(cast(data, "int32"), zp)
- mul = multiply(cast(sub, "float32"), cast(scale, "float32"))
- clipped_out = clip(mul, float32(-65504.0), float32(65504.0))
- out = cast(clipped_out, "float16")
+
+ dtype = ["int32"|"float32"]
+
+ qnn.dequantize(data, scale, zp, "float32") -->
+ sub = subtract(cast(data, dtype), zp)
+ out = multiply(cast(sub, "float32"), scale)
+
+ qnn.dequantize(data, scale, zp, "float16") -->
+ sub = subtract(cast(data, dtype), zp)
+ mul = multiply(cast(sub, "float32"), cast(scale, "float32"))
+ clipped_out = clip(mul, float32(-65504.0), float32(65504.0))
+ out = cast(clipped_out, "float16")
"""
axis = call.attrs.axis
out_dtype = call.attrs.out_dtype
@@ -96,7 +100,8 @@ def _dequantize(bb: BlockBuilder, call: Call) -> Expr:
def dequantize_compute(*indices):
scale_value = scale if is_const_scalar(scale) else
scale[indices[axis]]
zp_value = zp if is_const_scalar(zp) else zp[indices[axis]]
- sub = te.subtract(data[indices].astype("int32"), zp_value)
+ dtype = "float32" if "float" in data.dtype else "int32"
+ sub = te.subtract(data[indices].astype(dtype), zp_value)
out = te.multiply(sub, scale_value.astype("float32"))
if out_dtype == "float32":
return out
diff --git a/src/relax/backend/contrib/utils.h
b/src/relax/backend/contrib/utils.h
index ee1240aaed..412651d3f9 100644
--- a/src/relax/backend/contrib/utils.h
+++ b/src/relax/backend/contrib/utils.h
@@ -72,6 +72,10 @@ inline std::string DType2String(const tvm::DataType dtype) {
std::ostringstream os;
if (dtype.is_float()) {
os << "float";
+ } else if (dtype.is_e4m3_float8()) {
+ os << "e4m3_float";
+ } else if (dtype.is_e5m2_float8()) {
+ os << "e5m2_float";
} else if (dtype.is_int()) {
os << "int";
} else if (dtype.is_uint()) {
diff --git a/src/relax/op/tensor/qdq.cc b/src/relax/op/tensor/qdq.cc
index f8b0ed0ca2..0189ef9678 100644
--- a/src/relax/op/tensor/qdq.cc
+++ b/src/relax/op/tensor/qdq.cc
@@ -49,7 +49,9 @@
TVM_REGISTER_GLOBAL("relax.op.quantize").set_body_typed(quantize);
StructInfo InferStructInfoQuantize(const Call& call, const BlockBuilder& ctx) {
const auto* attrs = call->attrs.as<QuantizeAttrs>();
if (attrs->out_dtype != DataType::Int(8) && attrs->out_dtype !=
DataType::UInt(8) &&
- attrs->out_dtype != DataType::Int(16) && attrs->out_dtype !=
DataType::UInt(16)) {
+ attrs->out_dtype != DataType::Int(16) && attrs->out_dtype !=
DataType::UInt(16) &&
+ attrs->out_dtype != DataType::NVFloat8E4M3() &&
+ attrs->out_dtype != DataType::NVFloat8E5M2()) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "Unsupported output datatype attribute for operation:
'"
<< attrs->out_dtype);
@@ -73,9 +75,10 @@ StructInfo InferStructInfoQuantize(const Call& call, const
BlockBuilder& ctx) {
}
// Check datatype of zero_point param:
- if (zp_sinfo->dtype != DataType::Int(8)) {
+ if (zp_sinfo->dtype != DataType::Int(8) && zp_sinfo->dtype !=
DataType::Float(16)) {
ctx->ReportFatal(Diagnostic::Error(call)
- << "zero_point param datatype should be int8, but got "
<< zp_sinfo->dtype);
+ << "zero_point param datatype should be 'int8' or
'float16', but got "
+ << zp_sinfo->dtype);
}
// Check that "axis" attribute is not out of range:
@@ -142,7 +145,9 @@ StructInfo InferStructInfoDequantize(const Call& call,
const BlockBuilder& ctx)
// Check input datatype:
if (input_sinfo->dtype != DataType::Int(8) && input_sinfo->dtype !=
DataType::UInt(8) &&
input_sinfo->dtype != DataType::Int(16) && input_sinfo->dtype !=
DataType::UInt(16) &&
- input_sinfo->dtype != DataType::Int(32)) {
+ input_sinfo->dtype != DataType::Int(32) && input_sinfo->dtype !=
DataType::NVFloat8E4M3() &&
+ input_sinfo->dtype != DataType::NVFloat8E5M2() && input_sinfo->dtype !=
DataType::Float(16) &&
+ input_sinfo->dtype != DataType::Float(32)) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "Unsupported input datatype for operation: " <<
attrs->out_dtype);
}
@@ -155,9 +160,10 @@ StructInfo InferStructInfoDequantize(const Call& call,
const BlockBuilder& ctx)
}
// Check datatype of zero_point param:
- if (zp_sinfo->dtype != DataType::Int(8)) {
+ if (zp_sinfo->dtype != DataType::Int(8) && zp_sinfo->dtype !=
DataType::Float(16)) {
ctx->ReportFatal(Diagnostic::Error(call)
- << "zero_point param datatype should be int8, but got "
<< zp_sinfo->dtype);
+ << "zero_point param datatype should be 'int8' or
'float16', but got "
+ << zp_sinfo->dtype);
}
// Check that "axis" attribute is not out of range:
diff --git a/src/runtime/contrib/cublas/cublas.cc
b/src/runtime/contrib/cublas/cublas.cc
index 7a867f4bae..49aa35a7e0 100644
--- a/src/runtime/contrib/cublas/cublas.cc
+++ b/src/runtime/contrib/cublas/cublas.cc
@@ -161,6 +161,9 @@ void CallCublasLt(cublasLtHandle_t hdl, cudaStream_t stream,
ab_type = CUDA_R_16F;
} else if (TypeMatch(A->dtype, kDLInt, 8)) {
ab_type = CUDA_R_8I;
+ } else if (TypeMatch(A->dtype, DataType::TypeCode::kE4M3Float, 8)) {
+ ICHECK(TypeMatch(B->dtype, DataType::TypeCode::kE4M3Float, 8));
+ ab_type = CUDA_R_8F_E4M3;
}
if (TypeMatch(C->dtype, kDLFloat, 16)) {
diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc
index b613639786..c79a148e4b 100644
--- a/src/tir/op/op.cc
+++ b/src/tir/op/op.cc
@@ -263,6 +263,7 @@ PrimExpr max_value(const DataType& dtype, Span span) {
} else if (dtype.is_bfloat16()) {
return FloatImm(dtype, std::numeric_limits<float>::max(), span);
} else if (dtype.is_float8()) {
+ // according to https://arxiv.org/pdf/2209.05433.pdf
if (dtype.code() == DataType::TypeCode::kE5M2Float) {
return FloatImm(dtype, 57344.0, span);
} else if (dtype.code() == DataType::TypeCode::kE4M3Float) {
@@ -303,6 +304,7 @@ PrimExpr min_value(const DataType& dtype, Span span) {
} else if (dtype.is_bfloat16()) {
return FloatImm(dtype, std::numeric_limits<float>::lowest(), span);
} else if (dtype.is_float8()) {
+ // according to https://arxiv.org/pdf/2209.05433.pdf
if (dtype.code() == DataType::TypeCode::kE5M2Float) {
return FloatImm(dtype, -57344.0, span);
} else if (dtype.code() == DataType::TypeCode::kE4M3Float) {
diff --git a/tests/python/relax/test_codegen_cublas.py
b/tests/python/relax/test_codegen_cublas.py
index 52ad8b94b9..11247b3801 100644
--- a/tests/python/relax/test_codegen_cublas.py
+++ b/tests/python/relax/test_codegen_cublas.py
@@ -25,6 +25,11 @@ from tvm.relax.backend.contrib.cublas import
partition_for_cublas
from tvm.relax.testing import get_relax_matmul_module
from tvm.script import relax as R
+try:
+ import ml_dtypes
+except ImportError:
+ ml_dtypes = None
+
@pytest.fixture(autouse=True)
def reset_seed():
@@ -226,6 +231,60 @@ def test_matmul_igemm_offload(
tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
+@pytest.mark.skipif(ml_dtypes is None, reason="requires ml_dtypes to be installed")
+@pytest.mark.parametrize(
+    "x_shape, y_shape, transpose_y, out_dtype",
+    [
+        ((10, 32), (64, 32), True, "float32"),
+        ((32, 16), (32, 16), True, "float16"),
+        ((2, 10, 32), (2, 64, 32), True, "float32"),
+    ],
+)
+def test_matmul_fp8_offload(
+    x_shape,
+    y_shape,
+    transpose_y,
+    out_dtype,
+):
+    """Offload an e4m3_float8 matmul to cuBLAS and compare it against a NumPy reference.
+
+    The inputs are quantized to float8_e4m3fn once. The reference then upcasts those
+    quantized values to float32, which represents every e4m3 value exactly, so the
+    expected result is rounded only once, to ``out_dtype``. This avoids depending on how
+    NumPy/ml_dtypes evaluate matmul directly on float8 arrays (presumably with a float8
+    result).
+    """
+    in_dtype = "e4m3_float8"
+    mod = get_relax_matmul_module(
+        x_shape,
+        y_shape,
+        in_dtype,
+        out_dtype,
+        bias_shape=None,
+        transposed_y=transpose_y,
+        activation=None,
+    )
+    numpytype = "float8_e4m3fn"
+    x = np.random.uniform(low=0, high=5, size=x_shape).astype(numpytype)
+    y = np.random.uniform(low=0, high=5, size=y_shape).astype(numpytype)
+    z = np.swapaxes(y, -2, -1) if transpose_y else y
+    args = (x, y)
+
+    out = get_result_with_relax_cublas_offload(mod, args)
+    # Upcast the already-quantized fp8 operands (exact) before multiplying.
+    ref_out = np.matmul(x.astype("float32"), z.astype("float32")).astype(out_dtype)
+
+    tvm.testing.assert_allclose(out, ref_out, rtol=1e-3, atol=1e-3)
+
+
+@pytest.mark.parametrize(
+    "M, N, K, out_dtype, partition_done",
+    [
+        (15, 64, 32, "float32", True),
+        (15, 64, 32, "e4m3_float8", True),
+        (15, 64, 32, "e5m2_float8", False),
+        (16, 32, 60, "float32", False),
+        (16, 30, 64, "float32", False),
+    ],
+)
+def test_cublas_partition_fp8_matmul(M, N, K, out_dtype, partition_done):
+    """Check which e4m3_float8 matmuls partition_for_cublas offloads.
+
+    With e4m3 inputs, an e5m2_float8 output is rejected. The (N, K) weight, used
+    transposed, must have both dimensions divisible by 16. The check only inspects the
+    weight, so M=15 is still offloaded.
+    """
+    mod = get_relax_matmul_module((M, K), (N, K), "e4m3_float8", out_dtype, transposed_y=True)
+    mod = partition_for_cublas(mod)
+    func_name = "relax_matmul_cublas" if partition_done else "R.matmul"
+    assert func_name in mod["main"].script()
+
+
def test_cublas_partition_matmul_without_bias():
# cuBLAS does not handle 2D bias (residual input)
mod = get_relax_matmul_module((16, 32), (32, 32), "float16", "float16",
bias_shape=(16, 32))
diff --git a/tests/python/relax/test_op_qdq.py
b/tests/python/relax/test_op_qdq.py
index 42391120e9..8b2d499041 100644
--- a/tests/python/relax/test_op_qdq.py
+++ b/tests/python/relax/test_op_qdq.py
@@ -68,5 +68,42 @@ def test_qdq_op_infer_struct_info_symbolic():
)
+def test_qdq_e4m3_float8_op_infer_struct_info_symbolic():
+    """Struct-info inference of quantize/dequantize for e4m3_float8 with a float16 zero point."""
+    dtype = "e4m3_float8"
+    bb = relax.BlockBuilder()
+    n = tir.Var("n", "int64")
+    x = relax.Var("x", R.Tensor((n, 3), "float32"))
+    dx = relax.Var("dx", R.Tensor((n, 3), dtype))
+    s = relax.Var("s", R.Tensor([3], "float32"))
+    zp = relax.Var("zp", R.Tensor([3], "float16"))
+    _check_inference(
+        bb,
+        relax.op.quantize(x, s, zp, 1, dtype),
+        relax.TensorStructInfo((n, 3), dtype),
+    )
+    _check_inference(
+        bb,
+        relax.op.dequantize(dx, s, zp, 1, "float32"),
+        relax.TensorStructInfo((n, 3), "float32"),
+    )
+
+
+def test_qdq_e5m2_float8_op_infer_struct_info_symbolic():
+    """Struct-info inference of quantize/dequantize for e5m2_float8 with a float16 zero point."""
+    dtype = "e5m2_float8"
+    bb = relax.BlockBuilder()
+    n = tir.Var("n", "int64")
+    x = relax.Var("x", R.Tensor((n, 3), "float32"))
+    dx = relax.Var("dx", R.Tensor((n, 3), dtype))
+    s = relax.Var("s", R.Tensor([3], "float32"))
+    zp = relax.Var("zp", R.Tensor([3], "float16"))
+    _check_inference(
+        bb, relax.op.quantize(x, s, zp, 1, dtype), relax.TensorStructInfo((n, 3), dtype)
+    )
+    _check_inference(
+        bb,
+        relax.op.dequantize(dx, s, zp, 1, "float32"),
+        relax.TensorStructInfo((n, 3), "float32"),
+    )
+
+
if __name__ == "__main__":
tvm.testing.main()