This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new c0a47ed139 [CUBLAS][FP8] Enable R.matmul + R.multiply offloading (#16974)
c0a47ed139 is described below

commit c0a47ed13999881d2e6ea68e3904f5c613bbdb94
Author: Ivan Sidorenko <[email protected]>
AuthorDate: Wed May 8 12:54:01 2024 +0300

    [CUBLAS][FP8] Enable R.matmul + R.multiply offloading (#16974)
    
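    This commit enables offloading of the following pattern to cuBLAS:
      mm = R.linear(data, weights)
      scale = R.multiply(a_scale, w_scale)
      out = R.multiply(mm, scale)
      out = R.astype(out, dtype)
    where a_scale and w_scale are single-element tensors of shape (1,) that
    are passed to cuBLASLt as the A and B scale pointers.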
---
 python/tvm/relax/backend/contrib/cublas.py        | 11 +++-
 python/tvm/relax/backend/patterns.py              | 38 +++++++++++
 src/relax/backend/contrib/cublas/codegen.cc       |  5 +-
 src/runtime/contrib/cublas/cublas.cc              | 14 +++-
 src/runtime/contrib/cublas/cublas_json_runtime.cc | 15 +++--
 src/runtime/contrib/cublas/cublas_utils.h         |  6 +-
 tests/python/relax/test_codegen_cublas.py         | 79 +++++++++++++++++++++++
 7 files changed, 156 insertions(+), 12 deletions(-)

diff --git a/python/tvm/relax/backend/contrib/cublas.py b/python/tvm/relax/backend/contrib/cublas.py
index e5bc55c327..db4bd332c5 100644
--- a/python/tvm/relax/backend/contrib/cublas.py
+++ b/python/tvm/relax/backend/contrib/cublas.py
@@ -25,7 +25,11 @@ from tvm.relax import transform
 from tvm.relax.transform import PatternCheckContext
 
 from ..pattern_registry import get_patterns_with_prefix, register_patterns
-from ..patterns import make_matmul_pattern, make_matmul_dequantize_pattern
+from ..patterns import (
+    make_matmul_pattern,
+    make_matmul_dequantize_pattern,
+    make_matmul_multiply_pattern,
+)
 from ..utils import has_leaking_intermediate_variables
 
 
@@ -202,6 +206,11 @@ register_patterns(
             *make_matmul_dequantize_pattern(transposed_rhs=True),
             _check_matmul,
         ),
+        (
+            "cublas.matmul_transposed_multiply",
+            *make_matmul_multiply_pattern(transposed_rhs=True),
+            _check_matmul,
+        ),
     ]
 )
 
diff --git a/python/tvm/relax/backend/patterns.py b/python/tvm/relax/backend/patterns.py
index 404f7dc975..8ec43f1f27 100644
--- a/python/tvm/relax/backend/patterns.py
+++ b/python/tvm/relax/backend/patterns.py
@@ -376,6 +376,44 @@ def make_matmul_dequantize_pattern(
     return out, annotations
 
 
+def make_matmul_multiply_pattern(
+    transposed_rhs: bool = False,
+) -> Tuple[DFPattern, Mapping[str, DFPattern]]:
+    """
+    Create pattern for matrix multiplication and multiply operation.
+
+    Parameters
+    ----------
+    transposed_rhs: bool
+        Whether the right hand side of multiplication is transposed.
+
+    Returns
+    -------
+    pattern: DFPattern
+        The resulting pattern describing a matrix multiplication.
+
+    annotations: Mapping[str, DFPattern]
+        A mapping from name to sub pattern. It can be used to extract 
important expressions from
+        match result, to power the partition check function and codegen.
+    """
+
+    lhs = wildcard()
+    rhs = wildcard()
+    scaleA = wildcard()
+    scaleB = wildcard()
+    annotations = {"lhs": lhs, "rhs": rhs, "scaleA": scaleA, "scaleB": scaleB}
+
+    if transposed_rhs:
+        rhs = is_op("relax.permute_dims")(rhs)
+    out = is_op("relax.matmul")(lhs, rhs)
+    annotations["root"] = out
+    scale = is_op("relax.multiply")(scaleA.has_shape((1,)), 
scaleB.has_shape((1,)))
+    out = is_op("relax.multiply")(out, scale)
+    out = is_op("relax.astype")(out)
+
+    return out, annotations
+
+
 def make_attention_rewrite_pattern(
     qkv_layout: str, out_layout: str, with_bias: bool, with_cast: bool, 
with_kv_repeat: bool = False
 ):
diff --git a/src/relax/backend/contrib/cublas/codegen.cc b/src/relax/backend/contrib/cublas/codegen.cc
index 9f29d21aaa..e92ee57a5a 100644
--- a/src/relax/backend/contrib/cublas/codegen.cc
+++ b/src/relax/backend/contrib/cublas/codegen.cc
@@ -62,7 +62,7 @@ class CublasJSONSerializer : public JSONSerializer {
       inputs_tmp.insert(inputs_tmp.end(), res.begin(), res.end());
     }
 
-    ICHECK(inputs_tmp.size() <= 3);
+    ICHECK(inputs_tmp.size() <= 4);
     NodeEntries inputs(inputs_tmp.size());
 
     auto arg_idx = backend::ExtractArgIdx(composite_name, fn);
@@ -70,6 +70,9 @@ class CublasJSONSerializer : public JSONSerializer {
     inputs[1] = inputs_tmp[arg_idx["rhs"]->value];
     if (inputs_tmp.size() == 3) {
       inputs[2] = inputs_tmp[arg_idx["bias"]->value];
+    } else if (inputs_tmp.size() == 4) {
+      inputs[2] = inputs_tmp[arg_idx["scaleA"]->value];
+      inputs[3] = inputs_tmp[arg_idx["scaleB"]->value];
     }
 
     auto node = std::make_shared<JSONGraphNode>(composite_name, /* name_ */
diff --git a/src/runtime/contrib/cublas/cublas.cc b/src/runtime/contrib/cublas/cublas.cc
index 1edb6b95c9..8925080abf 100644
--- a/src/runtime/contrib/cublas/cublas.cc
+++ b/src/runtime/contrib/cublas/cublas.cc
@@ -137,8 +137,9 @@ int roundoff(int v, int d) { return (v + d - 1) / d * d; }
 
 void CallCublasLt(cublasLtHandle_t hdl, cudaStream_t stream,
                   cublasLtMatmulPreference_t matmul_pref_desc, const DLTensor* 
A, const DLTensor* B,
-                  const DLTensor* bias, const DLTensor* C, bool transa, bool 
transb,
-                  void* workspace_ptr, size_t workspace_size, 
cublasLtEpilogue_t epilogue,
+                  const DLTensor* bias, const DLTensor* scaleA, const 
DLTensor* scaleB,
+                  const DLTensor* C, bool transa, bool transb, void* 
workspace_ptr,
+                  size_t workspace_size, cublasLtEpilogue_t epilogue,
                   std::optional<float> dq_scale) {
   ICHECK(TypeEqual(A->dtype, B->dtype));
   // Reversed strides indicates an in-place transpose operation.
@@ -193,6 +194,15 @@ void CallCublasLt(cublasLtHandle_t hdl, cudaStream_t stream,
                                                       &bias->data, 
sizeof(float*)));
   }
 
+  if (scaleA != nullptr && scaleB != nullptr) {
+    auto scaleA_data = static_cast<char*>(scaleA->data) + scaleA->byte_offset;
+    auto scaleB_data = static_cast<char*>(scaleB->data) + scaleB->byte_offset;
+    CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(op_desc, 
CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
+                                                      &scaleA_data, 
sizeof(float*)));
+    CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(op_desc, 
CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
+                                                      &scaleB_data, 
sizeof(float*)));
+  }
+
   if (epilogue != CUBLASLT_EPILOGUE_DEFAULT) {
     CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(op_desc, 
CUBLASLT_MATMUL_DESC_EPILOGUE,
                                                       &epilogue, 
sizeof(epilogue)));
diff --git a/src/runtime/contrib/cublas/cublas_json_runtime.cc b/src/runtime/contrib/cublas/cublas_json_runtime.cc
index 8578d86789..49ff061da5 100644
--- a/src/runtime/contrib/cublas/cublas_json_runtime.cc
+++ b/src/runtime/contrib/cublas/cublas_json_runtime.cc
@@ -97,12 +97,15 @@ class CublasJSONRuntime : public JSONRuntimeBase {
       return dl_tensors[eid];
     };
 
-    auto get_inputs = [=](const JSONGraphNode& node, bool has_bias) {
-      const DLTensor* bias = nullptr;
+    auto get_inputs = [=](const JSONGraphNode& node, bool has_bias, bool 
has_scale) {
+      const DLTensor *bias = nullptr, *scaleA = nullptr, *scaleB = nullptr;
       if (has_bias) {
         bias = get_input(node, 2);
+      } else if (has_scale) {
+        scaleA = get_input(node, 2);
+        scaleB = get_input(node, 3);
       }
-      return std::make_tuple(get_input(node, 0), get_input(node, 1), bias);
+      return std::make_tuple(get_input(node, 0), get_input(node, 1), bias, 
scaleA, scaleB);
     };
 
     for (size_t i = 0; i < nodes_.size(); ++i) {
@@ -127,7 +130,9 @@ class CublasJSONRuntime : public JSONRuntimeBase {
           epilogue = CUBLASLT_EPILOGUE_BIAS;
         }
 
-        auto [a_ptr, b_ptr, bias_ptr] = get_inputs(node, epilogue != 
CUBLASLT_EPILOGUE_DEFAULT);
+        bool has_scale = op_name.find("multiply") != std::string::npos;
+        auto [a_ptr, b_ptr, bias_ptr, scaleA_ptr, scaleB_ptr] =
+            get_inputs(node, epilogue != CUBLASLT_EPILOGUE_DEFAULT, has_scale);
 
         std::optional<float> dq_scale = std::nullopt;
         if (op_name.find("dequantize") != std::string::npos) {
@@ -135,7 +140,7 @@ class CublasJSONRuntime : public JSONRuntimeBase {
         }
 
         tvm::contrib::CallCublasLt(entry_ptr->handle, stream, 
entry_ptr->matmul_pref_desc, a_ptr,
-                                   b_ptr, bias_ptr, out_ptr, transa, transb,
+                                   b_ptr, bias_ptr, scaleA_ptr, scaleB_ptr, 
out_ptr, transa, transb,
                                    entry_ptr->workspace_ptr, 
entry_ptr->workspace_size, epilogue,
                                    dq_scale);
       }
diff --git a/src/runtime/contrib/cublas/cublas_utils.h b/src/runtime/contrib/cublas/cublas_utils.h
index 2906279f90..387065093e 100644
--- a/src/runtime/contrib/cublas/cublas_utils.h
+++ b/src/runtime/contrib/cublas/cublas_utils.h
@@ -123,9 +123,9 @@ inline cudaDataType_t GetCudaDataType(DLDataType type) {
 /*! \brief Execute matrix multiply followed by the specified epilogue, using 
cuBLASLt. */
 void CallCublasLt(cublasLtHandle_t hdl, cudaStream_t stream,
                   cublasLtMatmulPreference_t matmul_pref_desc, const DLTensor* 
A, const DLTensor* B,
-                  const DLTensor* bias, const DLTensor* C, bool transa, bool 
transb,
-                  void* workspace_ptr, size_t workspace_size,
-                  cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT,
+                  const DLTensor* bias, const DLTensor* scaleA, const 
DLTensor* scaleB,
+                  const DLTensor* C, bool transa, bool transb, void* 
workspace_ptr,
+                  size_t workspace_size, cublasLtEpilogue_t epilogue = 
CUBLASLT_EPILOGUE_DEFAULT,
                   std::optional<float> dq_scale = std::nullopt);
 
 }  // namespace contrib
diff --git a/tests/python/relax/test_codegen_cublas.py b/tests/python/relax/test_codegen_cublas.py
index 4ff498ae2b..913f203d19 100644
--- a/tests/python/relax/test_codegen_cublas.py
+++ b/tests/python/relax/test_codegen_cublas.py
@@ -134,6 +134,40 @@ def get_relax_matmul_dequantize_module(
     return tvm.IRModule({"main": func})
 
 
+def get_relax_matmul_multiply_module(
+    x_shape,
+    y_shape,
+    z_shape,
+    in_dtype,
+    acc_dtype,
+    out_dtype,
+    transposed_y=False,
+):
+    """Create a matmul op followd by multiply operations."""
+    with IRBuilder() as builder:
+        with relax_builder.function():
+            R.func_name("main")
+            x = R.arg("x", R.Tensor(x_shape, in_dtype))
+            y = R.arg("y", R.Tensor(y_shape, in_dtype))
+            scaleA = R.arg("scaleA", R.Tensor(z_shape, acc_dtype))
+            scaleB = R.arg("scaleB", R.Tensor(z_shape, acc_dtype))
+
+            with R.dataflow() as frame:
+                if transposed_y:
+                    axes = list(range(len(y_shape) - 2)) + [-1, -2]
+                    y = R.emit(R.permute_dims(y, axes=axes))
+                result = R.emit(R.matmul(x, y, out_dtype=acc_dtype))
+                z = R.emit(R.multiply(scaleA, scaleB))
+                result = R.emit(R.multiply(result, z))
+                if acc_dtype != out_dtype:
+                    result = R.emit(R.astype(result, out_dtype))
+                R.output(result)
+            R.func_ret_value(frame.output_vars[0])
+
+    func = builder.get()
+    return tvm.IRModule({"main": func})
+
+
 @pytest.mark.parametrize(
     "x_shape, y_shape, transpose_y, epilogue",
     [
@@ -327,6 +361,36 @@ def test_matmul_fp8_dequantize_offload():
     tvm.testing.assert_allclose(out, ref, rtol=1e-3, atol=1e-3)
 
 
+@tvm.testing.requires_cuda_compute_version(9)
+@pytest.mark.skipif(ml_dtypes is None, reason="requires ml_dtypes to be installed")
+def test_matmul_fp8_multiply_offload():
+    x_shape = (10, 32)
+    y_shape = (64, 32)
+    z_shape = (1,)
+    in_dtype, acc_dtype = ("e4m3_float8", "float32")
+
+    mod = get_relax_matmul_multiply_module(
+        x_shape,
+        y_shape,
+        z_shape,
+        in_dtype,
+        acc_dtype,
+        "float16",
+        transposed_y=True,
+    )
+
+    numpytype = "float8_e4m3fn"
+    x = np.random.uniform(low=0, high=5, size=x_shape).astype(numpytype)
+    y = np.random.uniform(low=0, high=5, size=y_shape).astype(numpytype)
+    scaleA = np.random.uniform(low=0, high=5, size=z_shape).astype(acc_dtype)
+    scaleB = np.random.uniform(low=0, high=5, size=z_shape).astype(acc_dtype)
+    args = (x, y, scaleA, scaleB)
+
+    out = get_result_with_relax_cublas_offload(mod, args)
+    ref = build_and_run(mod, args, "llvm", legalize=True)
+    tvm.testing.assert_allclose(out, ref, rtol=1e-3, atol=1e-3)
+
+
 @pytest.mark.parametrize(
     "M, N, K, out_dtype, transposed_y, partition_done",
     [
@@ -371,6 +435,21 @@ def test_cublas_partition_fp8_matmul_dequantize(M, N, K, scale, zp, num_bindings
     assert len(mod["main"].body.blocks[0].bindings) == num_bindings
 
 
+def test_cublas_partition_fp8_matmul_multiply():
+    M, N, K = (32, 64, 128)
+    mod = get_relax_matmul_multiply_module(
+        (M, K),
+        (N, K),
+        (1,),
+        "e4m3_float8",
+        "float32",
+        "float16",
+        transposed_y=True,
+    )
+    mod = partition_for_cublas(mod)
+    assert len(mod["main"].body.blocks[0].bindings) == 1
+
+
 def test_cublas_partition_matmul_without_bias():
     # cuBLAS does not handle 2D bias (residual input)
     mod = get_relax_matmul_module((16, 32), (32, 32), "float16", "float16", 
bias_shape=(16, 32))

Reply via email to