This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 0c9aa58907 [Unity][BYOC] Add shape validation for bias arg in cuBLAS 
(#14809)
0c9aa58907 is described below

commit 0c9aa58907c66ec4249644718f457a9c51736498
Author: Lite Ye <[email protected]>
AuthorDate: Tue May 9 20:50:12 2023 -0400

    [Unity][BYOC] Add shape validation for bias arg in cuBLAS (#14809)
    
    Add shape check for bias arg in cuBLAS backend
    
    And use the correct type key for cublas json runtime
    
    Fix lint
---
 python/tvm/relax/backend/contrib/cublas.py        |  8 ++++++++
 python/tvm/relax/testing/matmul.py                | 13 ++++---------
 src/runtime/contrib/cublas/cublas_json_runtime.cc |  2 ++
 tests/python/relax/test_codegen_cublas.py         | 11 ++++++++++-
 tests/python/relax/test_codegen_cutlass.py        |  6 ++----
 5 files changed, 26 insertions(+), 14 deletions(-)

diff --git a/python/tvm/relax/backend/contrib/cublas.py 
b/python/tvm/relax/backend/contrib/cublas.py
index 627c936993..a929b92285 100644
--- a/python/tvm/relax/backend/contrib/cublas.py
+++ b/python/tvm/relax/backend/contrib/cublas.py
@@ -53,6 +53,14 @@ def _check_matmul(context: PatternCheckContext) -> bool:
     lhs_batches = reduce(operator.mul, lhs_shape[:-2], 1)
     rhs_batches = reduce(operator.mul, rhs_shape[:-2], 1)
 
+    if "bias" in context.annotated_expr:
+        bias = context.annotated_expr["bias"]
+        bias_shape = bias.struct_info.shape.values
+        bias_batches = reduce(operator.mul, bias_shape[:-1], 1)
+        if not isinstance(bias_batches, (tvm.tir.expr.IntImm, int)) or 
int(bias_batches) > 1:
+            # cuBLAS only supports bias vector
+            return False
+
     # cuBLASLt does not seem to support batched GEMM with one of matrices 
having
     # one batch (with batch_stride 0). So for batched GEMM, the two batch 
counts
     # must be equal.
diff --git a/python/tvm/relax/testing/matmul.py 
b/python/tvm/relax/testing/matmul.py
index bac6fc6c9a..3744bc9f42 100644
--- a/python/tvm/relax/testing/matmul.py
+++ b/python/tvm/relax/testing/matmul.py
@@ -26,31 +26,26 @@ def get_relax_matmul_module(
     y_shape,
     dtype,
     transposed_y=False,
-    with_bias=False,
+    bias_shape=None,
     activation=None,
     residual_bin_op=None,
     residual_activation=None,
 ):
     """Create a matmul op followed by epilogue operations."""
-    if transposed_y:
-        n = y_shape[-2]
-    else:
-        n = y_shape[-1]
-
     with IRBuilder() as builder:
         with relax_builder.function():
             R.func_name("main")
             x = R.arg("x", R.Tensor(x_shape, dtype))
             y = R.arg("y", R.Tensor(y_shape, dtype))
-            if with_bias:
-                bias = R.arg("bias", R.Tensor((n,), dtype))
+            if bias_shape is not None:
+                bias = R.arg("bias", R.Tensor(bias_shape, dtype))
 
             with R.dataflow() as frame:
                 if transposed_y:
                     axes = list(range(len(y_shape) - 2)) + [-1, -2]
                     y = R.emit(R.permute_dims(y, axes=axes))
                 result = R.emit(R.matmul(x, y, out_dtype=dtype))
-                if with_bias:
+                if bias_shape is not None:
                     result = R.emit(result + bias)
                 if activation is not None:
                     result = R.emit(activation(result))
diff --git a/src/runtime/contrib/cublas/cublas_json_runtime.cc 
b/src/runtime/contrib/cublas/cublas_json_runtime.cc
index 8afccb2730..b3931fb9fe 100644
--- a/src/runtime/contrib/cublas/cublas_json_runtime.cc
+++ b/src/runtime/contrib/cublas/cublas_json_runtime.cc
@@ -49,6 +49,8 @@ class CublasJSONRuntime : public JSONRuntimeBase {
 
   void Init(const Array<NDArray>& consts) override {}
 
+  const char* type_key() const override { return "cublas_json"; }  // May be 
overridden
+
   void Run() override {
     // TODO(masahi): Reuse the same handle across different subgraphs
     cublasLtHandle_t handle;
diff --git a/tests/python/relax/test_codegen_cublas.py 
b/tests/python/relax/test_codegen_cublas.py
index 023054256e..4eb0cc3b0a 100644
--- a/tests/python/relax/test_codegen_cublas.py
+++ b/tests/python/relax/test_codegen_cublas.py
@@ -141,7 +141,7 @@ def test_matmul_offload(
         x_shape,
         y_shape,
         dtype,
-        with_bias=with_bias,
+        bias_shape=bias.shape if with_bias else None,
         transposed_y=transpose_y,
         activation=activation,
     )
@@ -152,5 +152,14 @@ def test_matmul_offload(
     tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
 
 
+def test_cublass_partition_matmul_without_bias():
+    # cuBLAS does not handle 2D bias (residual input)
+    mod = get_relax_matmul_module((16, 32), (32, 32), "float16", 
bias_shape=(16, 32))
+    mod = partition_for_cublas(mod)
+
+    # R.add is still in the main function
+    assert len(mod["main"].body.blocks[0].bindings) == 2
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/relax/test_codegen_cutlass.py 
b/tests/python/relax/test_codegen_cutlass.py
index 7a831c094e..75285998a6 100644
--- a/tests/python/relax/test_codegen_cutlass.py
+++ b/tests/python/relax/test_codegen_cutlass.py
@@ -364,7 +364,7 @@ def test_matmul_offload(
         x_shape,
         y_shape,
         dtype,
-        with_bias=with_bias,
+        bias_shape=bias.shape if with_bias else None,
         transposed_y=transpose_y,
         activation=activation,
         residual_bin_op=residual_bin_op,
@@ -479,9 +479,7 @@ def test_cutlass_partition_matmul_blocked(x_shape, y_shape, 
transpose_y, dtype):
     if transpose_y:
         y_shape = (*y_shape[:-2], y_shape[-1], y_shape[-2])
 
-    mod = get_relax_matmul_module(
-        x_shape, y_shape, dtype, with_bias=False, transposed_y=transpose_y
-    )
+    mod = get_relax_matmul_module(x_shape, y_shape, dtype, 
transposed_y=transpose_y)
     mod = partition_for_cutlass(mod)
 
     assert len(mod.functions) == 1

Reply via email to the mailing list.