This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new c0a47ed139 [CUBLAS][FP8] Enable R.matmul + R.multiply offloading (#16974)
c0a47ed139 is described below
commit c0a47ed13999881d2e6ea68e3904f5c613bbdb94
Author: Ivan Sidorenko <[email protected]>
AuthorDate: Wed May 8 12:54:01 2024 +0300
[CUBLAS][FP8] Enable R.matmul + R.multiply offloading (#16974)
This commit enables offloading of the following pattern to cuBLAS:
mm = R.linear(data, weights)
scale = R.multiply(a_scale, w_scale)
out = R.multiply(mm, scale)
out = R.cast(out, dtype)
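For reference, a minimal Relax module that matches this pattern could look
as follows (a sketch: the shapes, dtypes, and the module/function names are
illustrative, not taken from this commit):

    import tvm
    from tvm.script import relax as R

    @tvm.script.ir_module
    class Module:
        @R.function
        def main(
            data: R.Tensor((10, 32), "e4m3_float8"),
            weights: R.Tensor((64, 32), "e4m3_float8"),
            a_scale: R.Tensor((1,), "float32"),
            w_scale: R.Tensor((1,), "float32"),
        ) -> R.Tensor((10, 64), "float16"):
            with R.dataflow():
                # R.linear is matmul with a transposed weight matrix
                w_t = R.permute_dims(weights, axes=[1, 0])
                mm = R.matmul(data, w_t, out_dtype="float32")
                scale = R.multiply(a_scale, w_scale)
                out = R.multiply(mm, scale)
                out = R.astype(out, "float16")
                R.output(out)
            return out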
---
python/tvm/relax/backend/contrib/cublas.py | 11 +++-
python/tvm/relax/backend/patterns.py | 38 +++++++++++
src/relax/backend/contrib/cublas/codegen.cc | 5 +-
src/runtime/contrib/cublas/cublas.cc | 14 +++-
src/runtime/contrib/cublas/cublas_json_runtime.cc | 15 +++--
src/runtime/contrib/cublas/cublas_utils.h | 6 +-
tests/python/relax/test_codegen_cublas.py | 79 +++++++++++++++++++++++
7 files changed, 156 insertions(+), 12 deletions(-)
diff --git a/python/tvm/relax/backend/contrib/cublas.py b/python/tvm/relax/backend/contrib/cublas.py
index e5bc55c327..db4bd332c5 100644
--- a/python/tvm/relax/backend/contrib/cublas.py
+++ b/python/tvm/relax/backend/contrib/cublas.py
@@ -25,7 +25,11 @@ from tvm.relax import transform
from tvm.relax.transform import PatternCheckContext
from ..pattern_registry import get_patterns_with_prefix, register_patterns
-from ..patterns import make_matmul_pattern, make_matmul_dequantize_pattern
+from ..patterns import (
+ make_matmul_pattern,
+ make_matmul_dequantize_pattern,
+ make_matmul_multiply_pattern,
+)
from ..utils import has_leaking_intermediate_variables
@@ -202,6 +206,11 @@ register_patterns(
*make_matmul_dequantize_pattern(transposed_rhs=True),
_check_matmul,
),
+ (
+ "cublas.matmul_transposed_multiply",
+ *make_matmul_multiply_pattern(transposed_rhs=True),
+ _check_matmul,
+ ),
]
)
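With the entry registered, partition_for_cublas picks it up through the
shared pattern table. Roughly, the lookup flow is the following (a sketch of
the existing mechanism, not code added by this commit; offload_to_cublas is
a hypothetical helper name):

    from tvm.relax import transform
    from tvm.relax.backend.pattern_registry import get_patterns_with_prefix

    def offload_to_cublas(mod):
        # The table now contains "cublas.matmul_transposed_multiply".
        patterns = get_patterns_with_prefix("cublas")
        return transform.FuseOpsByPattern(
            patterns, bind_constants=False, annotate_codegen=True
        )(mod)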
diff --git a/python/tvm/relax/backend/patterns.py b/python/tvm/relax/backend/patterns.py
index 404f7dc975..8ec43f1f27 100644
--- a/python/tvm/relax/backend/patterns.py
+++ b/python/tvm/relax/backend/patterns.py
@@ -376,6 +376,44 @@ def make_matmul_dequantize_pattern(
return out, annotations
+def make_matmul_multiply_pattern(
+ transposed_rhs: bool = False,
+) -> Tuple[DFPattern, Mapping[str, DFPattern]]:
+ """
+ Create a pattern for matrix multiplication followed by a multiply operation.
+
+ Parameters
+ ----------
+ transposed_rhs: bool
+ Whether the right hand side of multiplication is transposed.
+
+ Returns
+ -------
+ pattern: DFPattern
+ The resulting pattern describing a matrix multiplication followed by a multiply.
+
+ annotations: Mapping[str, DFPattern]
+ A mapping from name to sub pattern. It can be used to extract important expressions from the
+ match result, to power the partition check function and codegen.
+ """
+
+ lhs = wildcard()
+ rhs = wildcard()
+ scaleA = wildcard()
+ scaleB = wildcard()
+ annotations = {"lhs": lhs, "rhs": rhs, "scaleA": scaleA, "scaleB": scaleB}
+
+ if transposed_rhs:
+ rhs = is_op("relax.permute_dims")(rhs)
+ out = is_op("relax.matmul")(lhs, rhs)
+ annotations["root"] = out
+ scale = is_op("relax.multiply")(scaleA.has_shape((1,)), scaleB.has_shape((1,)))
+ out = is_op("relax.multiply")(out, scale)
+ out = is_op("relax.astype")(out)
+
+ return out, annotations
+
+
def make_attention_rewrite_pattern(
qkv_layout: str, out_layout: str, with_bias: bool, with_cast: bool, with_kv_repeat: bool = False
):
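The new helper can also be exercised on its own; a small sketch of
inspecting its return values (the names mirror the annotations dict above):

    from tvm.relax.backend.patterns import make_matmul_multiply_pattern

    pattern, annotations = make_matmul_multiply_pattern(transposed_rhs=True)
    # "lhs", "rhs", "scaleA", "scaleB" and "root" map to sub-patterns, so a
    # partition check or codegen can look up matched expressions by name.
    print(sorted(annotations.keys()))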
diff --git a/src/relax/backend/contrib/cublas/codegen.cc b/src/relax/backend/contrib/cublas/codegen.cc
index 9f29d21aaa..e92ee57a5a 100644
--- a/src/relax/backend/contrib/cublas/codegen.cc
+++ b/src/relax/backend/contrib/cublas/codegen.cc
@@ -62,7 +62,7 @@ class CublasJSONSerializer : public JSONSerializer {
inputs_tmp.insert(inputs_tmp.end(), res.begin(), res.end());
}
- ICHECK(inputs_tmp.size() <= 3);
+ ICHECK(inputs_tmp.size() <= 4);
NodeEntries inputs(inputs_tmp.size());
auto arg_idx = backend::ExtractArgIdx(composite_name, fn);
@@ -70,6 +70,9 @@ class CublasJSONSerializer : public JSONSerializer {
inputs[1] = inputs_tmp[arg_idx["rhs"]->value];
if (inputs_tmp.size() == 3) {
inputs[2] = inputs_tmp[arg_idx["bias"]->value];
+ } else if (inputs_tmp.size() == 4) {
+ inputs[2] = inputs_tmp[arg_idx["scaleA"]->value];
+ inputs[3] = inputs_tmp[arg_idx["scaleB"]->value];
}
auto node = std::make_shared<JSONGraphNode>(composite_name, /* name_ */
diff --git a/src/runtime/contrib/cublas/cublas.cc b/src/runtime/contrib/cublas/cublas.cc
index 1edb6b95c9..8925080abf 100644
--- a/src/runtime/contrib/cublas/cublas.cc
+++ b/src/runtime/contrib/cublas/cublas.cc
@@ -137,8 +137,9 @@ int roundoff(int v, int d) { return (v + d - 1) / d * d; }
void CallCublasLt(cublasLtHandle_t hdl, cudaStream_t stream,
cublasLtMatmulPreference_t matmul_pref_desc, const DLTensor* A, const DLTensor* B,
- const DLTensor* bias, const DLTensor* C, bool transa, bool transb,
- void* workspace_ptr, size_t workspace_size, cublasLtEpilogue_t epilogue,
+ const DLTensor* bias, const DLTensor* scaleA, const DLTensor* scaleB,
+ const DLTensor* C, bool transa, bool transb, void* workspace_ptr,
+ size_t workspace_size, cublasLtEpilogue_t epilogue,
std::optional<float> dq_scale) {
ICHECK(TypeEqual(A->dtype, B->dtype));
// Reversed strides indicates an in-place transpose operation.
@@ -193,6 +194,15 @@ void CallCublasLt(cublasLtHandle_t hdl, cudaStream_t stream,
&bias->data, sizeof(float*)));
}
+ if (scaleA != nullptr && scaleB != nullptr) {
+ auto scaleA_data = static_cast<char*>(scaleA->data) + scaleA->byte_offset;
+ auto scaleB_data = static_cast<char*>(scaleB->data) + scaleB->byte_offset;
+ CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(op_desc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
+ &scaleA_data, sizeof(float*)));
+ CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(op_desc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
+ &scaleB_data, sizeof(float*)));
+ }
+
if (epilogue != CUBLASLT_EPILOGUE_DEFAULT) {
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(op_desc, CUBLASLT_MATMUL_DESC_EPILOGUE,
&epilogue, sizeof(epilogue)));
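Numerically, the two scale pointers apply per-tensor scales to A and B
before the product, which matches the Relax-side multiply because
(scaleA * A) @ (scaleB * B) == scaleA * scaleB * (A @ B). A small numpy
sketch of that identity (the shapes are arbitrary and float32 stands in
for fp8):

    import numpy as np

    A = np.random.rand(4, 8).astype("float32")
    B = np.random.rand(8, 2).astype("float32")
    scaleA, scaleB = np.float32(0.5), np.float32(0.25)
    # cuBLASLt view: scale each operand, then multiply.
    lhs = (scaleA * A) @ (scaleB * B)
    # Relax pattern view: multiply the matmul result by the fused scale.
    rhs = (A @ B) * (scaleA * scaleB)
    np.testing.assert_allclose(lhs, rhs, rtol=1e-6)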
diff --git a/src/runtime/contrib/cublas/cublas_json_runtime.cc b/src/runtime/contrib/cublas/cublas_json_runtime.cc
index 8578d86789..49ff061da5 100644
--- a/src/runtime/contrib/cublas/cublas_json_runtime.cc
+++ b/src/runtime/contrib/cublas/cublas_json_runtime.cc
@@ -97,12 +97,15 @@ class CublasJSONRuntime : public JSONRuntimeBase {
return dl_tensors[eid];
};
- auto get_inputs = [=](const JSONGraphNode& node, bool has_bias) {
- const DLTensor* bias = nullptr;
+ auto get_inputs = [=](const JSONGraphNode& node, bool has_bias, bool has_scale) {
+ const DLTensor *bias = nullptr, *scaleA = nullptr, *scaleB = nullptr;
if (has_bias) {
bias = get_input(node, 2);
+ } else if (has_scale) {
+ scaleA = get_input(node, 2);
+ scaleB = get_input(node, 3);
}
- return std::make_tuple(get_input(node, 0), get_input(node, 1), bias);
+ return std::make_tuple(get_input(node, 0), get_input(node, 1), bias, scaleA, scaleB);
};
for (size_t i = 0; i < nodes_.size(); ++i) {
@@ -127,7 +130,9 @@ class CublasJSONRuntime : public JSONRuntimeBase {
epilogue = CUBLASLT_EPILOGUE_BIAS;
}
- auto [a_ptr, b_ptr, bias_ptr] = get_inputs(node, epilogue != CUBLASLT_EPILOGUE_DEFAULT);
+ bool has_scale = op_name.find("multiply") != std::string::npos;
+ auto [a_ptr, b_ptr, bias_ptr, scaleA_ptr, scaleB_ptr] =
+ get_inputs(node, epilogue != CUBLASLT_EPILOGUE_DEFAULT, has_scale);
std::optional<float> dq_scale = std::nullopt;
if (op_name.find("dequantize") != std::string::npos) {
@@ -135,7 +140,7 @@ class CublasJSONRuntime : public JSONRuntimeBase {
}
tvm::contrib::CallCublasLt(entry_ptr->handle, stream, entry_ptr->matmul_pref_desc, a_ptr,
- b_ptr, bias_ptr, out_ptr, transa, transb,
+ b_ptr, bias_ptr, scaleA_ptr, scaleB_ptr, out_ptr, transa, transb,
entry_ptr->workspace_ptr,
entry_ptr->workspace_size, epilogue,
dq_scale);
}
diff --git a/src/runtime/contrib/cublas/cublas_utils.h b/src/runtime/contrib/cublas/cublas_utils.h
index 2906279f90..387065093e 100644
--- a/src/runtime/contrib/cublas/cublas_utils.h
+++ b/src/runtime/contrib/cublas/cublas_utils.h
@@ -123,9 +123,9 @@ inline cudaDataType_t GetCudaDataType(DLDataType type) {
/*! \brief Execute matrix multiply followed by the specified epilogue, using cuBLASLt. */
void CallCublasLt(cublasLtHandle_t hdl, cudaStream_t stream,
cublasLtMatmulPreference_t matmul_pref_desc, const DLTensor* A, const DLTensor* B,
- const DLTensor* bias, const DLTensor* C, bool transa, bool transb,
- void* workspace_ptr, size_t workspace_size,
- cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT,
+ const DLTensor* bias, const DLTensor* scaleA, const DLTensor* scaleB,
+ const DLTensor* C, bool transa, bool transb, void* workspace_ptr,
+ size_t workspace_size, cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT,
std::optional<float> dq_scale = std::nullopt);
} // namespace contrib
diff --git a/tests/python/relax/test_codegen_cublas.py b/tests/python/relax/test_codegen_cublas.py
index 4ff498ae2b..913f203d19 100644
--- a/tests/python/relax/test_codegen_cublas.py
+++ b/tests/python/relax/test_codegen_cublas.py
@@ -134,6 +134,40 @@ def get_relax_matmul_dequantize_module(
return tvm.IRModule({"main": func})
+def get_relax_matmul_multiply_module(
+ x_shape,
+ y_shape,
+ z_shape,
+ in_dtype,
+ acc_dtype,
+ out_dtype,
+ transposed_y=False,
+):
+ """Create a matmul op followd by multiply operations."""
+ with IRBuilder() as builder:
+ with relax_builder.function():
+ R.func_name("main")
+ x = R.arg("x", R.Tensor(x_shape, in_dtype))
+ y = R.arg("y", R.Tensor(y_shape, in_dtype))
+ scaleA = R.arg("scaleA", R.Tensor(z_shape, acc_dtype))
+ scaleB = R.arg("scaleB", R.Tensor(z_shape, acc_dtype))
+
+ with R.dataflow() as frame:
+ if transposed_y:
+ axes = list(range(len(y_shape) - 2)) + [-1, -2]
+ y = R.emit(R.permute_dims(y, axes=axes))
+ result = R.emit(R.matmul(x, y, out_dtype=acc_dtype))
+ z = R.emit(R.multiply(scaleA, scaleB))
+ result = R.emit(R.multiply(result, z))
+ if acc_dtype != out_dtype:
+ result = R.emit(R.astype(result, out_dtype))
+ R.output(result)
+ R.func_ret_value(frame.output_vars[0])
+
+ func = builder.get()
+ return tvm.IRModule({"main": func})
+
+
@pytest.mark.parametrize(
"x_shape, y_shape, transpose_y, epilogue",
[
@@ -327,6 +361,36 @@ def test_matmul_fp8_dequantize_offload():
tvm.testing.assert_allclose(out, ref, rtol=1e-3, atol=1e-3)
+@tvm.testing.requires_cuda_compute_version(9)
+@pytest.mark.skipif(ml_dtypes is None, reason="requires ml_dtypes to be installed")
+def test_matmul_fp8_multiply_offload():
+ x_shape = (10, 32)
+ y_shape = (64, 32)
+ z_shape = (1,)
+ in_dtype, acc_dtype = ("e4m3_float8", "float32")
+
+ mod = get_relax_matmul_multiply_module(
+ x_shape,
+ y_shape,
+ z_shape,
+ in_dtype,
+ acc_dtype,
+ "float16",
+ transposed_y=True,
+ )
+
+ numpytype = "float8_e4m3fn"
+ x = np.random.uniform(low=0, high=5, size=x_shape).astype(numpytype)
+ y = np.random.uniform(low=0, high=5, size=y_shape).astype(numpytype)
+ scaleA = np.random.uniform(low=0, high=5, size=z_shape).astype(acc_dtype)
+ scaleB = np.random.uniform(low=0, high=5, size=z_shape).astype(acc_dtype)
+ args = (x, y, scaleA, scaleB)
+
+ out = get_result_with_relax_cublas_offload(mod, args)
+ ref = build_and_run(mod, args, "llvm", legalize=True)
+ tvm.testing.assert_allclose(out, ref, rtol=1e-3, atol=1e-3)
+
+
@pytest.mark.parametrize(
"M, N, K, out_dtype, transposed_y, partition_done",
[
@@ -371,6 +435,21 @@ def test_cublas_partition_fp8_matmul_dequantize(M, N, K, scale, zp, num_bindings
assert len(mod["main"].body.blocks[0].bindings) == num_bindings
+def test_cublas_partition_fp8_matmul_multiply():
+ M, N, K = (32, 64, 128)
+ mod = get_relax_matmul_multiply_module(
+ (M, K),
+ (N, K),
+ (1,),
+ "e4m3_float8",
+ "float32",
+ "float16",
+ transposed_y=True,
+ )
+ mod = partition_for_cublas(mod)
+ assert len(mod["main"].body.blocks[0].bindings) == 1
+
+
def test_cublas_partition_matmul_without_bias():
# cuBLAS does not handle 2D bias (residual input)
mod = get_relax_matmul_module((16, 32), (32, 32), "float16", "float16", bias_shape=(16, 32))
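For completeness, an end-to-end sketch of the new offload path (a sketch,
not a test from this commit; it assumes a CUDA build of TVM with cuBLAS
enabled, an fp8-capable GPU, and the illustrative Module from the commit
message above):

    import ml_dtypes  # noqa: F401  (registers "float8_e4m3fn" with numpy)
    import numpy as np
    import tvm
    from tvm import relax
    from tvm.relax.backend.contrib.cublas import partition_for_cublas

    mod = partition_for_cublas(Module)       # fuse the matched subgraph
    mod = relax.transform.RunCodegen()(mod)  # lower it to a cuBLAS call
    ex = relax.build(mod, target="cuda")
    dev = tvm.device("cuda", 0)
    vm = relax.VirtualMachine(ex, dev)

    x = tvm.nd.array(np.random.rand(10, 32).astype("float8_e4m3fn"), dev)
    y = tvm.nd.array(np.random.rand(64, 32).astype("float8_e4m3fn"), dev)
    sa = tvm.nd.array(np.ones((1,), "float32"), dev)
    sb = tvm.nd.array(np.ones((1,), "float32"), dev)
    out = vm["main"](x, y, sa, sb)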