This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 0c9aa58907 [Unity][BYOC] Add shape validation for bias arg in cuBLAS (#14809)
0c9aa58907 is described below
commit 0c9aa58907c66ec4249644718f457a9c51736498
Author: Lite Ye <[email protected]>
AuthorDate: Tue May 9 20:50:12 2023 -0400
[Unity][BYOC] Add shape validation for bias arg in cuBLAS (#14809)
Add shape check for bias arg in cuBLAS backend
And use the correct type key for cublas json runtime
Fix lint
---
python/tvm/relax/backend/contrib/cublas.py | 8 ++++++++
python/tvm/relax/testing/matmul.py | 13 ++++---------
src/runtime/contrib/cublas/cublas_json_runtime.cc | 2 ++
tests/python/relax/test_codegen_cublas.py | 11 ++++++++++-
tests/python/relax/test_codegen_cutlass.py | 6 ++----
5 files changed, 26 insertions(+), 14 deletions(-)
diff --git a/python/tvm/relax/backend/contrib/cublas.py b/python/tvm/relax/backend/contrib/cublas.py
index 627c936993..a929b92285 100644
--- a/python/tvm/relax/backend/contrib/cublas.py
+++ b/python/tvm/relax/backend/contrib/cublas.py
@@ -53,6 +53,14 @@ def _check_matmul(context: PatternCheckContext) -> bool:
lhs_batches = reduce(operator.mul, lhs_shape[:-2], 1)
rhs_batches = reduce(operator.mul, rhs_shape[:-2], 1)
+ if "bias" in context.annotated_expr:
+ bias = context.annotated_expr["bias"]
+ bias_shape = bias.struct_info.shape.values
+ bias_batches = reduce(operator.mul, bias_shape[:-1], 1)
+ if not isinstance(bias_batches, (tvm.tir.expr.IntImm, int)) or int(bias_batches) > 1:
+ # cuBLAS only supports bias vector
+ return False
+
# cuBLASLt does not seem to support batched GEMM with one of matrices having
# one batch (with batch_stride 0). So for batched GEMM, the two batch counts
# must be equal.
diff --git a/python/tvm/relax/testing/matmul.py b/python/tvm/relax/testing/matmul.py
index bac6fc6c9a..3744bc9f42 100644
--- a/python/tvm/relax/testing/matmul.py
+++ b/python/tvm/relax/testing/matmul.py
@@ -26,31 +26,26 @@ def get_relax_matmul_module(
y_shape,
dtype,
transposed_y=False,
- with_bias=False,
+ bias_shape=None,
activation=None,
residual_bin_op=None,
residual_activation=None,
):
"""Create a matmul op followd by epilogue operations."""
- if transposed_y:
- n = y_shape[-2]
- else:
- n = y_shape[-1]
-
with IRBuilder() as builder:
with relax_builder.function():
R.func_name("main")
x = R.arg("x", R.Tensor(x_shape, dtype))
y = R.arg("y", R.Tensor(y_shape, dtype))
- if with_bias:
- bias = R.arg("bias", R.Tensor((n,), dtype))
+ if bias_shape is not None:
+ bias = R.arg("bias", R.Tensor(bias_shape, dtype))
with R.dataflow() as frame:
if transposed_y:
axes = list(range(len(y_shape) - 2)) + [-1, -2]
y = R.emit(R.permute_dims(y, axes=axes))
result = R.emit(R.matmul(x, y, out_dtype=dtype))
- if with_bias:
+ if bias_shape is not None:
result = R.emit(result + bias)
if activation is not None:
result = R.emit(activation(result))
diff --git a/src/runtime/contrib/cublas/cublas_json_runtime.cc b/src/runtime/contrib/cublas/cublas_json_runtime.cc
index 8afccb2730..b3931fb9fe 100644
--- a/src/runtime/contrib/cublas/cublas_json_runtime.cc
+++ b/src/runtime/contrib/cublas/cublas_json_runtime.cc
@@ -49,6 +49,8 @@ class CublasJSONRuntime : public JSONRuntimeBase {
void Init(const Array<NDArray>& consts) override {}
+ const char* type_key() const override { return "cublas_json"; }  // May be overridden
+
void Run() override {
// TODO(masahi): Reuse the same handle across different subgraphs
cublasLtHandle_t handle;
diff --git a/tests/python/relax/test_codegen_cublas.py b/tests/python/relax/test_codegen_cublas.py
index 023054256e..4eb0cc3b0a 100644
--- a/tests/python/relax/test_codegen_cublas.py
+++ b/tests/python/relax/test_codegen_cublas.py
@@ -141,7 +141,7 @@ def test_matmul_offload(
x_shape,
y_shape,
dtype,
- with_bias=with_bias,
+ bias_shape=bias.shape if with_bias else None,
transposed_y=transpose_y,
activation=activation,
)
@@ -152,5 +152,14 @@ def test_matmul_offload(
tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
+def test_cublass_partition_matmul_without_bias():
+ # cuBLAS does not handle 2D bias (residual input)
+ mod = get_relax_matmul_module((16, 32), (32, 32), "float16", bias_shape=(16, 32))
+ mod = partition_for_cublas(mod)
+
+ # R.add is still in the main function
+ assert len(mod["main"].body.blocks[0].bindings) == 2
+
+
if __name__ == "__main__":
tvm.testing.main()
diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py
index 7a831c094e..75285998a6 100644
--- a/tests/python/relax/test_codegen_cutlass.py
+++ b/tests/python/relax/test_codegen_cutlass.py
@@ -364,7 +364,7 @@ def test_matmul_offload(
x_shape,
y_shape,
dtype,
- with_bias=with_bias,
+ bias_shape=bias.shape if with_bias else None,
transposed_y=transpose_y,
activation=activation,
residual_bin_op=residual_bin_op,
@@ -479,9 +479,7 @@ def test_cutlass_partition_matmul_blocked(x_shape, y_shape, transpose_y, dtype):
if transpose_y:
y_shape = (*y_shape[:-2], y_shape[-1], y_shape[-2])
- mod = get_relax_matmul_module(
- x_shape, y_shape, dtype, with_bias=False, transposed_y=transpose_y
- )
+ mod = get_relax_matmul_module(x_shape, y_shape, dtype, transposed_y=transpose_y)
mod = partition_for_cutlass(mod)
assert len(mod.functions) == 1