This is an automated email from the ASF dual-hosted git repository.
wuwei pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new f7835a6f80 [Unity][CUTLASS] Require the residual input to have the same shape as input (#14657)
f7835a6f80 is described below
commit f7835a6f808183aaae2ac35ed3e8769884be0173
Author: masahi <[email protected]>
AuthorDate: Thu Apr 20 04:02:04 2023 +0900
[Unity][CUTLASS] Require the residual input to have the same shape as input (#14657)
Require residual input to have the same shape as input
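In short: the CUTLASS epilogue binds the residual term elementwise to the
primary output, so a residual whose static shape merely broadcasts against
the output (rather than matching it exactly) is no longer offloaded. A
minimal standalone sketch of the new shape comparison, with example shapes
borrowed from the added test (an illustration only, not TVM's API):

    def residual_shape_ok(out_shape, residual_shape):
        # Mirrors the comparison in _check_residual: require an exact
        # static-shape match between the root op output and the residual.
        return [int(s) for s in out_shape] == [int(s) for s in residual_shape]

    assert residual_shape_ok((2, 64, 64, 8), (2, 64, 64, 8))    # exact match: fusable
    assert not residual_shape_ok((2, 64, 64, 8), (2, 1, 1, 8))  # broadcast-only: rejected

The existing dependency test (_has_dependency) is unchanged; it is now
shared between the conv2d and matmul paths via the new _check_residual
helper.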
---
python/tvm/relax/backend/contrib/cutlass.py | 35 +++++++++++++++++++++--------
tests/python/relax/test_codegen_cutlass.py | 31 +++++++++++++++++++++++++
2 files changed, 57 insertions(+), 9 deletions(-)
diff --git a/python/tvm/relax/backend/contrib/cutlass.py b/python/tvm/relax/backend/contrib/cutlass.py
index 0c2f38e300..7d6dc6bf89 100644
--- a/python/tvm/relax/backend/contrib/cutlass.py
+++ b/python/tvm/relax/backend/contrib/cutlass.py
@@ -20,7 +20,7 @@
from typing import Mapping, Sequence
from tvm.contrib.cutlass.build import is_shape_valid_for_cutlass_matmul
-from tvm.relax import DataflowVar, Var, transform
+from tvm.relax import DataflowVar, Var, transform, Call
from tvm.relax.transform import PatternCheckContext
from ..pattern_registry import get_patterns_with_prefix, register_patterns
@@ -82,6 +82,26 @@ def _has_dependency(from_var: Var, to_var: Var, var_usages: Mapping[Var, Sequenc
return False
+def _check_residual(root_call: Call, context: PatternCheckContext) -> bool:
+ if "residual" in context.annotated_expr:
+ residual = context.annotated_expr["residual"]
+ if not isinstance(residual, Var):
+ residual = context.value_to_bound_var[residual]
+
+ root_var = context.value_to_bound_var[root_call]
+        if _has_dependency(from_var=residual, to_var=root_var, var_usages=context.var_usages):
+            # If residual depends on the result of the root call, this cannot be handled by cutlass.
+ return False
+
+ shape1 = [int(s) for s in root_var.struct_info.shape]
+ shape2 = [int(s) for s in residual.struct_info.shape]
+
+ if shape1 != shape2:
+ return False
+
+ return True
+
+
def _check_conv2d(context: PatternCheckContext) -> bool:
"""Check if the given conv2d workload can be offloaded to CUTLASS."""
if _has_leaking_intermediate_variables(context):
@@ -98,14 +118,8 @@ def _check_conv2d(context: PatternCheckContext) -> bool:
):
return False
- if "residual" in context.annotated_expr:
- residual = context.annotated_expr["residual"]
- if not isinstance(residual, Var):
- residual = context.value_to_bound_var[residual]
- conv2d_var = context.value_to_bound_var[conv2d_call]
-        if _has_dependency(from_var=residual, to_var=conv2d_var, var_usages=context.var_usages):
-            # If residual depends on the result of conv2d, this cannot be handled by cutlass.
- return False
+ if not _check_residual(conv2d_call, context):
+ return False
# pylint: disable=invalid-name
IC = data.struct_info.shape.values[3]
@@ -127,6 +141,9 @@ def _check_matmul(context: PatternCheckContext) -> bool:
if not _is_supported_dtype(lhs_dtype, rhs_dtype):
return False
+ if not _check_residual(lhs, context):
+ return False
+
lhs_shape = lhs.struct_info.shape.values
rhs_shape = rhs.struct_info.shape.values
return is_shape_valid_for_cutlass_matmul(lhs_shape, rhs_shape)
diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py
index db8abf34c2..a5fbf0f642 100644
--- a/tests/python/relax/test_codegen_cutlass.py
+++ b/tests/python/relax/test_codegen_cutlass.py
@@ -744,5 +744,36 @@ def test_stacked_attention_strided_slice_offload(stacked_attention_size):
tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
+def test_invalid_residual():
+ @tvm.script.ir_module
+ class Module:
+ @R.function
+ def main(
+ x: R.Tensor((2, 64, 64, 8), dtype="float16"),
+ w: R.Tensor((8, 3, 3, 8), dtype="float16"),
+ bias: R.Tensor((1, 1, 8), dtype="float16"),
+ residual: R.Tensor((2, 1, 1, 8), dtype="float16"),
+ ) -> R.Tensor((1, 256, 64, 64), dtype="float16"):
+ with R.dataflow():
+ conv = R.nn.conv2d(
+ x,
+ w,
+ padding=[1, 1, 1, 1],
+ out_dtype="float16",
+ data_layout="NHWC",
+ kernel_layout="OHWI",
+ )
+ bias_out = R.add(conv, bias)
+ out = R.add(bias_out, residual)
+ R.output(out)
+ return out
+
+ rewritten = partition_for_cutlass(Module)
+ func_names = [gv.name_hint for gv in rewritten.functions.keys()]
+
+ assert "fused_relax_nn_conv2d_relax_add_relax_add_cutlass" not in
func_names
+ assert "fused_relax_nn_conv2d_relax_add_cutlass" in func_names
+
+
if __name__ == "__main__":
tvm.testing.main()
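To exercise just the new test locally (assuming a TVM build with CUTLASS
enabled and an NVIDIA GPU available), something like:

    pytest tests/python/relax/test_codegen_cutlass.py -k invalid_residual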