This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 354683860c [Unity][BYOC] Add batch matmul support to Relax CUTLASS BYOC (#14166)
354683860c is described below

commit 354683860c7b45b18923b1c2e3541ebc49053c0d
Author: Lite Ye <[email protected]>
AuthorDate: Thu Mar 2 01:13:08 2023 -0500

    [Unity][BYOC] Add batch matmul support to Relax CUTLASS BYOC (#14166)
    
    * Add batch matmul support to Relax CUTLASS BYOC
    
    * Allow more dtypes
    
    * Fix tests
    
    * Revert how to get batch attr
---
 python/tvm/contrib/cutlass/build.py          |  59 +++++-
 python/tvm/contrib/cutlass/gemm_operation.py |   9 +-
 python/tvm/contrib/cutlass/gen_tensor_op.py  |  15 +-
 python/tvm/relax/backend/contrib/cutlass.py  |  90 +++++++--
 python/tvm/relax/backend/pattern_registry.py |  77 ++++++--
 src/relax/backend/pattern_registry.cc        |  20 +-
 src/relax/backend/pattern_registry.h         |  17 +-
 src/relax/transform/fuse_ops.cc              |   3 +-
 tests/python/relax/test_codegen_cutlass.py   | 286 ++++++++++++++-------------
 9 files changed, 380 insertions(+), 196 deletions(-)

diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py
index 954aef60c2..7e81113f44 100644
--- a/python/tvm/contrib/cutlass/build.py
+++ b/python/tvm/contrib/cutlass/build.py
@@ -16,10 +16,13 @@
 # under the License.
 # pylint: disable=invalid-name, dangerous-default-value, arguments-differ
 """Driver for partitioning and building a Relay module for CUTLASS offload."""
+import itertools
 import logging
 import multiprocessing
+import operator
 import os
-from typing import Optional
+from functools import reduce
+from typing import Optional, Tuple
 
 import tvm
 from tvm import relax, relay, runtime
@@ -569,6 +572,31 @@ def _extract_arg_idx(pattern_name, f):
     return arg_idx
 
 
+def is_valid_for_cutlass_matmul(lhs_shape: Tuple[int], rhs_shape: Tuple[int]) -> bool:
+    """
+    Check whether the shape of inputs can be handled by CUTLASS GEMM.
+
+    The stride-based batch matmul in CUTLASS cannot handle cases that some of
+    the batch dimensions need to be stretched while others don't. This means
+    it can only handle ND x ND whose batch dimensions match exactly on both side,
+    as well as ND x 2D and 2D x ND. For example, it cannot handle matmul with shape
+    (2, 1, 4, 8) x (2, 3, 8, 16), because the batch stride of lhs is not constant.
+    """
+    lhs_batches = reduce(operator.mul, lhs_shape[:-2], 1)
+    rhs_batches = reduce(operator.mul, rhs_shape[:-2], 1)
+    if lhs_batches == 1 or rhs_batches == 1:
+        # This could be regular matmul or batch matmul with shape ND x 2D or 2D x ND
+        return True
+
+    # If one side has less dimensions, use 1 to fill the gap
+    batch_dim_pairs = itertools.zip_longest(
+        lhs_shape[-3::-1],  # Remove the last two dimensions and reverse
+        rhs_shape[-3::-1],
+        fillvalue=1,
+    )
+    return all(p[0] == p[1] for p in batch_dim_pairs)
+
+
 @relax.expr_functor.mutator
 class CutlassRelaxFunctionAnnotator(relax.PyExprMutator):
     """A Relax function mutator that tunes and annotates CUTLASS composite functions
@@ -659,17 +687,35 @@ class CutlassRelaxFunctionAnnotator(relax.PyExprMutator):
         rhs_dtype = signature[f"{rhs_arg}_dtype"]
         out_dtype = signature["ret_dtype"]
 
-        MM = lhs_shape[0]
-        KK = lhs_shape[1]
+        if not is_valid_for_cutlass_matmul(lhs_shape, rhs_shape):
+            raise ValueError(f"Cannot handle the input shapes, lhs: {lhs_shape}, rhs: {rhs_shape}")
+
+        MM = lhs_shape[-2]
+        KK = lhs_shape[-1]
         if "transposed" in op_type:
-            NN = rhs_shape[0]
+            NN = rhs_shape[-2]
             ldb = "K"
             layout_b = LayoutType.ColumnMajor
         else:
-            NN = rhs_shape[1]
+            NN = rhs_shape[-1]
             ldb = "N"
             layout_b = LayoutType.RowMajor
 
+        lhs_batches = reduce(operator.mul, lhs_shape[:-2], 1)
+        rhs_batches = reduce(operator.mul, rhs_shape[:-2], 1)
+        if lhs_batches == 1 and rhs_batches == 1:
+            # Regular matmul
+            is_batched = False
+            batch_attrs = {}
+        else:
+            is_batched = True
+            batch_attrs = {
+                "batch": max(lhs_batches, rhs_batches),
+                "batch_stride_A": 0 if lhs_batches == 1 else MM * KK,
+                "batch_stride_B": 0 if rhs_batches == 1 else KK * NN,
+                "batch_stride_C": MM * NN,
+            }
+
         use_3xtf32 = self.options.get("use_3xtf32", False)
         find_first_valid = self.options.get("find_first_valid", True)
         use_multiprocessing = self.options.get("use_multiprocessing", True)
@@ -683,7 +729,7 @@ class CutlassRelaxFunctionAnnotator(relax.PyExprMutator):
             lhs_dtype,
             rhs_dtype,
             use_3xtf32,
-            batched=False,
+            batched=is_batched,
             find_first_valid=find_first_valid,
             use_multiprocessing=use_multiprocessing,
             layout_b=layout_b,
@@ -706,6 +752,7 @@ class CutlassRelaxFunctionAnnotator(relax.PyExprMutator):
                 "ldc": "N",
                 "cutlass_op_name": op_name,
                 "cutlass_op_def": op_def,
+                **batch_attrs,
             }
         )
 
diff --git a/python/tvm/contrib/cutlass/gemm_operation.py b/python/tvm/contrib/cutlass/gemm_operation.py
index 3e74cbaec8..1675e1f035 100644
--- a/python/tvm/contrib/cutlass/gemm_operation.py
+++ b/python/tvm/contrib/cutlass/gemm_operation.py
@@ -287,7 +287,7 @@ def instantiate_gemm_template(attrs):
    {static_cast<ElementInputA*>(ptr_a), ${lda}}, ${batch_stride_A}
    {static_cast<ElementInputB*>(ptr_b), ${ldb}}, ${batch_stride_B}
    {static_cast<ElementOutput*>(${ptr_c}), ${c_stride}}, ${batch_stride_C}
-   {static_cast<ElementOutput*>(ptr_out), ${ldc}}, ${batch_stride_C}
+   {static_cast<ElementOutput*>(ptr_out), ${ldc}}, ${batch_stride_D}
    {${alpha_beta}},
    ${split_k_slices_or_batch}
   };
@@ -303,8 +303,7 @@ def instantiate_gemm_template(attrs):
 """
     has_bias = "bias" in attrs["op_type"]
     is_gelu = "gelu" in attrs["op_type"]
-    batched = "batch_matmul" in attrs["op_type"]
-
+    batched = "batch" in attrs
     aux_map = {"kernel": "Gemm"}
 
     if has_bias:
@@ -335,6 +334,10 @@ def instantiate_gemm_template(attrs):
         else:
             aux_map[key] = attrs[key] + ","
 
+    aux_map["batch_stride_D"] = aux_map["batch_stride_C"]
+    if has_bias and batched:
+        aux_map["batch_stride_C"] = "0,"
+
     if batched:
         attrs["split_k_slices_or_batch"] = attrs["batch"]
     else:
diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py
index 92bf04e863..2976946dd2 100644
--- a/python/tvm/contrib/cutlass/gen_tensor_op.py
+++ b/python/tvm/contrib/cutlass/gen_tensor_op.py
@@ -529,8 +529,7 @@ def instantiate_template(func_name, annotations, func_args):
         return dim1 + " * " + dim2
 
     if "dense" in func_name or "matmul" in func_name:
-        batched = "batch_matmul" in func_name
-        batched_offset = 1 if batched else 0
+        batched = "batch" in annotations
         transposed = "transposed" in func_name
         lhs_arg_idx = _get_optional_int_annotation(annotations, "lhs_arg_idx", 0)
         rhs_arg_idx = _get_optional_int_annotation(annotations, "rhs_arg_idx", 1)
@@ -539,6 +538,8 @@ def instantiate_template(func_name, annotations, func_args):
         rhs_arg = func_args[rhs_arg_idx]
         lhs_shape = annotations[f"arg{lhs_arg_idx}_shape"]
         rhs_shape = annotations[f"arg{rhs_arg_idx}_shape"]
+        lhs_batched_offset = len(lhs_shape) - 2
+        rhs_batched_offset = len(rhs_shape) - 2
 
         attrs["lhs_arg"] = lhs_arg
         attrs["rhs_arg"] = rhs_arg
@@ -548,17 +549,17 @@ def instantiate_template(func_name, annotations, func_args):
         attrs["ElementInputB"] = DataTypeTag[dtype_map[annotations[f"arg{rhs_arg_idx}_dtype"]]]
         attrs["ElementOutput"] = DataTypeTag[dtype_map[annotations["ret_dtype"]]]
 
-        attrs["K"] = str(int(lhs_shape[batched_offset + 1]))
-        attrs["M"] = get_dim(lhs_shape[batched_offset], lhs_arg, 0, batched_offset)
+        attrs["K"] = str(int(lhs_shape[lhs_batched_offset + 1]))
+        attrs["M"] = get_dim(lhs_shape[lhs_batched_offset], lhs_arg, 0, lhs_batched_offset)
 
         if transposed:
-            attrs["N"] = get_dim(rhs_shape[batched_offset], rhs_arg, 0, batched_offset)
+            attrs["N"] = get_dim(rhs_shape[rhs_batched_offset], rhs_arg, 0, rhs_batched_offset)
         else:
-            attrs["N"] = get_dim(rhs_shape[batched_offset + 1], rhs_arg, 1, batched_offset)
+            attrs["N"] = get_dim(rhs_shape[rhs_batched_offset + 1], rhs_arg, 1, rhs_batched_offset)
 
         if batched:
             headers.append("cutlass/gemm/device/gemm_batched.h")
-            attrs["batch"] = get_dim(lhs_shape[0], lhs_arg, 0)
+            attrs["batch"] = get_dim(annotations["batch"], lhs_arg, 0)
             attrs["batch_stride_A"] = get_batch_stride(
                 annotations["batch_stride_A"],
                 lhs_arg_idx,
diff --git a/python/tvm/relax/backend/contrib/cutlass.py b/python/tvm/relax/backend/contrib/cutlass.py
index e98194ca21..2d8908184b 100644
--- a/python/tvm/relax/backend/contrib/cutlass.py
+++ b/python/tvm/relax/backend/contrib/cutlass.py
@@ -17,16 +17,69 @@
 
 """Pattern table for CUTLASS backend"""
 
-from tvm.relax import transform
+from typing import Mapping, Optional, Tuple
+
+import tvm
+from tvm.contrib.cutlass.build import is_valid_for_cutlass_matmul
+from tvm.relax import Call, Expr, ShapeExpr, transform
+from tvm.relax.dpl import DFPattern
 
 from ..pattern_registry import get_patterns_with_prefix, register_patterns
 from ..patterns import make_fused_bias_activation_pattern, make_matmul_pattern
 
+
+def _get_static_shape(shape: ShapeExpr) -> Optional[Tuple[int]]:
+    result = []
+    for dim in shape.values:
+        if isinstance(dim, tvm.tir.expr.IntImm):
+            result.append(int(dim))
+        else:
+            return None
+    return result
+
+
+def _is_supported_dtype(lhs_dtype, rhs_dtype):
+    """Check if dtypes in the given workload are supported by CUTLASS."""
+    return (
+        (lhs_dtype == "float16" and rhs_dtype == "float16")
+        or (lhs_dtype == "float32" and rhs_dtype == "float32")
+        or (lhs_dtype in ("int8", "uint8") and rhs_dtype in ("int8", "uint8"))
+    )
+
+
+def _check_matmul(
+    match_result: Mapping[DFPattern, Expr],
+    _: Expr,
+) -> bool:
+    matmul_call: Call = None
+    for _, expr in match_result.items():
+        if (
+            isinstance(expr, Call)
+            and isinstance(expr.op, tvm.ir.Op)
+            and expr.op.name == "relax.matmul"
+        ):
+            matmul_call = expr
+    if matmul_call is None:
+        raise ValueError("Cannot find call to matmul from match_result.")
+
+    lhs_shape = _get_static_shape(matmul_call.args[0].struct_info.shape)
+    rhs_shape = _get_static_shape(matmul_call.args[1].struct_info.shape)
+    if len(lhs_shape) < 2 or len(rhs_shape) < 2:
+        return False
+
+    lhs_dtype = matmul_call.args[0].struct_info.dtype
+    rhs_dtype = matmul_call.args[1].struct_info.dtype
+    if not _is_supported_dtype(lhs_dtype, rhs_dtype):
+        return False
+
+    return is_valid_for_cutlass_matmul(lhs_shape, rhs_shape)
+
+
 register_patterns(
     [
         (
             "cutlass.conv2d",
-            make_fused_bias_activation_pattern(
+            *make_fused_bias_activation_pattern(
                 "relax.nn.conv2d",
                 with_bias=False,
                 activation=None,
@@ -34,7 +87,7 @@ register_patterns(
         ),
         (
             "cutlass.conv2d_bias_relu",
-            make_fused_bias_activation_pattern(
+            *make_fused_bias_activation_pattern(
                 "relax.nn.conv2d",
                 with_bias=True,
                 activation="relax.nn.relu",
@@ -42,59 +95,67 @@ register_patterns(
         ),
         (
             "cutlass.matmul",
-            make_matmul_pattern(
+            *make_matmul_pattern(
                 with_bias=False,
             ),
+            _check_matmul,
         ),
         (
             "cutlass.matmul_bias",
-            make_matmul_pattern(
+            *make_matmul_pattern(
                 with_bias=True,
             ),
+            _check_matmul,
         ),
         (
             "cutlass.matmul_bias_relu",
-            make_matmul_pattern(
+            *make_matmul_pattern(
                 with_bias=True,
                 activation="relax.nn.relu",
             ),
+            _check_matmul,
         ),
         (
             "cutlass.matmul_bias_gelu",
-            make_matmul_pattern(
+            *make_matmul_pattern(
                 with_bias=True,
                 activation="relax.nn.gelu",
             ),
+            _check_matmul,
         ),
         (
             "cutlass.matmul_transposed",
-            make_matmul_pattern(
+            *make_matmul_pattern(
                 with_bias=False,
                 transposed_rhs=True,
             ),
+            _check_matmul,
         ),
         (
             "cutlass.matmul_transposed_bias",
-            make_matmul_pattern(
+            *make_matmul_pattern(
                 with_bias=True,
                 transposed_rhs=True,
             ),
+            _check_matmul,
         ),
         (
             "cutlass.matmul_transposed_bias_relu",
-            make_matmul_pattern(
+            *make_matmul_pattern(
                 with_bias=True,
                 activation="relax.nn.relu",
                 transposed_rhs=True,
             ),
+            _check_matmul,
         ),
         (
             "cutlass.matmul_transposed_bias_gelu",
-            make_matmul_pattern(
+            *make_matmul_pattern(
                 with_bias=True,
                 activation="relax.nn.gelu",
                 transposed_rhs=True,
             ),
+            _check_matmul,
         ),
     ]
 )
@@ -116,7 +177,6 @@ def partition_for_cutlass(mod):
         compiled by the CUTLASS backend.
     """
 
-    cutlass_patterns = get_patterns_with_prefix("cutlass")
-    return transform.FuseOpsByPattern(cutlass_patterns, bind_constants=True, annotate_codegen=True)(
-        mod
-    )
+    cutlass_pattern_entries = get_patterns_with_prefix("cutlass")
+    patterns = [(e.name, e.pattern, e.check) for e in cutlass_pattern_entries]
+    return transform.FuseOpsByPattern(patterns, bind_constants=False, annotate_codegen=True)(mod)
diff --git a/python/tvm/relax/backend/pattern_registry.py b/python/tvm/relax/backend/pattern_registry.py
index 0cccab7842..5a35eba03d 100644
--- a/python/tvm/relax/backend/pattern_registry.py
+++ b/python/tvm/relax/backend/pattern_registry.py
@@ -17,14 +17,15 @@
 
 """Pattern registry for BYOC backends"""
 
-from typing import Callable, List, Mapping, Optional, Tuple, Union
+import atexit
+from typing import Callable, List, Mapping, Optional, Set, Tuple, Union
 
 import tvm
 from tvm.relax.dpl import DFPattern
 from tvm.runtime import Object
 
-from . import _ffi_api
 from ..expr import Expr
+from . import _ffi_api
 
 
 @tvm._ffi.register_object("relax.backend.PatternRegistryEntry")
@@ -46,25 +47,59 @@ class PatternRegistryEntry(Object):
     arg_patterns: Mapping[str, DFPattern]
         The mapping from arg name to its pattern. It can be used to extract arg expression
         from match result. All DFPattern in this map should be part of the `pattern`.
+
+    check: Callable[[Mapping[DFPattern, Expr], Expr], bool]
+        The function to check whether the match result is accepted.
     """
 
     name: str
     pattern: DFPattern
     arg_patterns: Mapping[str, DFPattern]
-
-    def __init__(self, name: str, pattern: DFPattern, arg_patterns: Mapping[str, DFPattern]):
+    check: Callable[[Mapping[DFPattern, Expr], Expr], bool]
+
+    def __init__(
+        self,
+        name: str,
+        pattern: DFPattern,
+        arg_patterns: Mapping[str, DFPattern],
+        check: Callable[[Mapping[DFPattern, Expr], Expr], bool],
+    ):
         self.__init_handle_by_constructor__(
-            _ffi_api.PatternRegistryEntry, name, pattern, arg_patterns  # type: ignore
+            _ffi_api.PatternRegistryEntry, name, pattern, arg_patterns, check  # type: ignore
         )
 
 
+_REGISTERED_PATTERN_NAMES: Set[str] = set()
+
+
+def _cleanup_registered_patterns():
+    _ffi_api.RemovePatterns(list(_REGISTERED_PATTERN_NAMES))  # type: ignore # pylint: disable=no-member
+
+
+_CLEANUP_REGISTERED = False
+
+
+def _ensure_cleanup_function_registered():
+    """
+    Add a cleanup function to be called on interpreter termination, to remove all
+    patterns registered on the Python side. Without cleaning up those patterns,
+    program will segfault on termination. It's because the check functions of pattern
+    entries are referenced from the static memory of libtvm, thus they will be cleaned
+    up at the very end, making calls to Py_DecRef after Python interpreter terminates.
+    """
+    global _CLEANUP_REGISTERED  # pylint: disable=global-statement
+
+    if not _CLEANUP_REGISTERED:
+        atexit.register(_cleanup_registered_patterns)
+        _CLEANUP_REGISTERED = True
+
+
+CheckFunc = Callable[[Mapping[DFPattern, Expr], Expr], bool]
 Pattern = Union[
     PatternRegistryEntry,
     Tuple[str, DFPattern],
-    Tuple[
-        str,
-        Tuple[DFPattern, Mapping[str, DFPattern], Callable[[Mapping[DFPattern, Expr], Expr], bool]],
-    ],
+    Tuple[str, DFPattern, Mapping[str, DFPattern]],
+    Tuple[str, DFPattern, Mapping[str, DFPattern], CheckFunc],
 ]
 
 
@@ -79,21 +114,27 @@ def register_patterns(patterns: List[Pattern]):
         Patterns to be registered. Patterns that appear later in the list have
         higher priority when partitioning DataflowBlock.
     """
+    _ensure_cleanup_function_registered()
+
     entries = []
     for item in patterns:
         if isinstance(item, PatternRegistryEntry):
             entries.append(item)
         elif isinstance(item, tuple):
-            name, pattern_or_tuple = item
-            if isinstance(pattern_or_tuple, tuple):
-                if len(pattern_or_tuple) == 2:
-                    pattern, arg_patterns = pattern_or_tuple
-                    check = lambda *_: True
-                else:
-                    pattern, arg_patterns, check = pattern_or_tuple
+            name, pattern, *rest = item
+
+            if len(rest) > 0:
+                arg_patterns = rest[0]
             else:
-                pattern, arg_patterns, check = pattern_or_tuple, {}, lambda *_: True
-            entries.append(PatternRegistryEntry(name, pattern, arg_patterns))
+                arg_patterns = {}
+
+            if len(rest) > 1:
+                check = rest[1]
+            else:
+                check = lambda *_: True
+
+            entries.append(PatternRegistryEntry(name, pattern, arg_patterns, check))
+            _REGISTERED_PATTERN_NAMES.add(name)
         else:
             raise TypeError(f"Cannot register type {type(pattern)} as pattern")
     _ffi_api.RegisterPatterns(entries)
diff --git a/src/relax/backend/pattern_registry.cc b/src/relax/backend/pattern_registry.cc
index 3ca7973365..553018d690 100644
--- a/src/relax/backend/pattern_registry.cc
+++ b/src/relax/backend/pattern_registry.cc
@@ -26,11 +26,12 @@ namespace relax {
 namespace backend {
 
 PatternRegistryEntry::PatternRegistryEntry(String name, DFPattern pattern,
-                                           Map<String, DFPattern> arg_patterns) {
+                                           Map<String, DFPattern> arg_patterns, PackedFunc check) {
   ObjectPtr<PatternRegistryEntryNode> n = make_object<PatternRegistryEntryNode>();
   n->name = std::move(name);
   n->pattern = std::move(pattern);
   n->arg_patterns = std::move(arg_patterns);
+  n->check = check;
   data_ = std::move(n);
 }
 
@@ -48,6 +49,17 @@ void RegisterPatterns(Array<PatternRegistryEntry> entries) {
   }
 }
 
+void RemovePatterns(Array<String> names) {
+  std::unordered_set<String> name_set{names.begin(), names.end()};
+
+  auto* table = GetRegistryTable();
+  table->erase(std::remove_if(table->begin(), table->end(),
+                              [&](const PatternRegistryEntry& entry) {
+                                return name_set.count(entry->name) > 0;
+                              }),
+               table->end());
+}
+
 Array<PatternRegistryEntry> GetPatternsWithPrefix(const String& prefix) {
   auto* table = GetRegistryTable();
   Array<PatternRegistryEntry> result;
@@ -70,10 +82,12 @@ Optional<PatternRegistryEntry> GetPattern(const String& pattern_name) {
 }
 
 TVM_REGISTER_GLOBAL("relax.backend.PatternRegistryEntry")
-    .set_body_typed([](String name, DFPattern pattern, Map<String, DFPattern> arg_patterns) {
-      return PatternRegistryEntry(name, pattern, arg_patterns);
+    .set_body_typed([](String name, DFPattern pattern, Map<String, DFPattern> arg_patterns,
+                       PackedFunc check) {
+      return PatternRegistryEntry(name, pattern, arg_patterns, check);
     });
 TVM_REGISTER_GLOBAL("relax.backend.RegisterPatterns").set_body_typed(RegisterPatterns);
+TVM_REGISTER_GLOBAL("relax.backend.RemovePatterns").set_body_typed(RemovePatterns);
 TVM_REGISTER_GLOBAL("relax.backend.GetPatternsWithPrefix").set_body_typed(GetPatternsWithPrefix);
 TVM_REGISTER_GLOBAL("relax.backend.GetPattern").set_body_typed(GetPattern);
 
diff --git a/src/relax/backend/pattern_registry.h b/src/relax/backend/pattern_registry.h
index 1fdb69319f..e765f56b4e 100644
--- a/src/relax/backend/pattern_registry.h
+++ b/src/relax/backend/pattern_registry.h
@@ -42,8 +42,6 @@ namespace backend {
  */
 class PatternRegistryEntryNode : public Object {
  public:
-  using FCheckMatch = runtime::TypedPackedFunc<bool(const Map<DFPattern, Expr>&, const Expr&)>;
-
   /*!
    * \brief The name of pattern. Usually it starts with the name of backend, 
like
    * 'cutlass.matmul'.
@@ -63,13 +61,17 @@ class PatternRegistryEntryNode : public Object {
 
   /*!
    * \brief The function to check whether the match result is accepted.
+   *
+   * It should have signature
+   * bool(const Map<DFPattern, Expr>& match_result, const Expr& matched_expr)
    */
-  FCheckMatch check;
+  PackedFunc check;
 
   void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("name", &name);
     v->Visit("pattern", &pattern);
     v->Visit("arg_patterns", &arg_patterns);
+    v->Visit("check", &check);
   }
 
   static constexpr const char* _type_key = "relax.backend.PatternRegistryEntry";
@@ -78,7 +80,8 @@ class PatternRegistryEntryNode : public Object {
 
 class PatternRegistryEntry : public ObjectRef {
  public:
-  PatternRegistryEntry(String name, DFPattern pattern, Map<String, DFPattern> arg_patterns);
+  PatternRegistryEntry(String name, DFPattern pattern, Map<String, DFPattern> arg_patterns,
+                       PackedFunc check);
 
   TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(PatternRegistryEntry, ObjectRef,
                                             PatternRegistryEntryNode);
@@ -92,6 +95,12 @@ class PatternRegistryEntry : public ObjectRef {
  */
 void RegisterPatterns(Array<PatternRegistryEntry> entries);
 
+/*!
+ * \brief Remove patterns from the registry by their name.
+ * \param names The name of patterns to be removed
+ */
+void RemovePatterns(Array<String> names);
+
 /*!
  * \brief Find patterns whose name starts with a particular prefix.
  * \param prefx The pattern name prefix.
diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc
index 72427a8a0e..d6013c8874 100644
--- a/src/relax/transform/fuse_ops.cc
+++ b/src/relax/transform/fuse_ops.cc
@@ -40,7 +40,6 @@
 
 #include "../../relay/analysis/graph_partitioner.h"
 #include "../../support/arena.h"
-#include "../backend/pattern_registry.h"
 
 namespace tvm {
 namespace relax {
@@ -902,7 +901,7 @@ class PatternBasedPartitioner : ExprVisitor {
   using Group = GraphPartitioner::Group;
   using GroupMap = OperatorFusor::GroupMap;
   using ExprVisitor::VisitExpr_;
-  using FCheckMatch = backend::PatternRegistryEntryNode::FCheckMatch;
+  using FCheckMatch = runtime::TypedPackedFunc<bool(const Map<DFPattern, Expr>&, const Expr&)>;
 
   static GroupMap Run(String pattern_name, DFPattern pattern, FCheckMatch check, Expr expr,
                       support::Arena* arena) {
diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py
index 6eb476496c..83104d6fe1 100644
--- a/tests/python/relax/test_codegen_cutlass.py
+++ b/tests/python/relax/test_codegen_cutlass.py
@@ -16,11 +16,14 @@
 # under the License.
 import numpy as np
 import pytest
+import scipy
 
 import tvm
 import tvm.testing
-from tvm import relax
+from tvm import relax, relay
+from tvm.contrib.cutlass.build import is_valid_for_cutlass_matmul
 from tvm.relax.backend import get_patterns_with_prefix
+from tvm.relax.backend.contrib.cutlass import partition_for_cutlass
 from tvm.script import relax as R
 
 
@@ -95,20 +98,17 @@ def build_and_run(mod, inputs_np, target, legalize=False):
 
 def get_result_with_relax_cutlass_offload(mod, *args):
     patterns = [(entry.name, entry.pattern) for entry in get_patterns_with_prefix("cutlass")]
-
     assert len(patterns) != 0, "Cannot find cutlass patterns"
 
-    seq = tvm.transform.Sequential(
-        [
-            relax.transform.FuseOpsByPattern(patterns, bind_constants=False, annotate_codegen=True),
-            relax.transform.RunCodegen({"cutlass": {"sm": 80, "find_first_valid": True}}),
-        ]
-    )
+    mod = partition_for_cutlass(mod)
+    codegen_pass = relax.transform.RunCodegen({"cutlass": {"sm": 80, "find_first_valid": True}})
+    mod = codegen_pass(mod)
 
-    return build_and_run(seq(mod), args, "cuda")
+    return build_and_run(mod, args, "cuda")
 
 
 def test_conv2d_offload():
+    low, high = -1, 1
     data = np.random.randint(low, high, size=(16, 32, 32, 16)).astype("float16")
     weight = np.random.randint(low, high, size=(32, 3, 3, 16)).astype("float16")
     bias = np.random.randint(low, high, size=(1, 1, 1, 32)).astype("float16")
@@ -120,12 +120,26 @@ def test_conv2d_offload():
     np.testing.assert_equal(out, ref)
 
 
+def test_kernel_sharing():
+    low, high = -1, 1
+    data_np = np.random.randint(low, high, size=(16, 32, 32, 8)).astype("float16")
+    weight1_np = np.random.randint(low, high, size=(8, 3, 3, 8)).astype("float16")
+    weight2_np = np.random.randint(low, high, size=(8, 3, 3, 8)).astype("float16")
+
+    out = get_result_with_relax_cutlass_offload(Conv2dx2, data_np, weight1_np, weight2_np)
+    ref = build_and_run(Conv2dx2, [data_np, weight1_np, weight2_np], "llvm", legalize=True)
+
+    np.testing.assert_equal(out, ref)
+
+
 def get_relax_matmul_module(x, y, transposed_y=False, with_bias=False, activation=None):
+    m, k = x.shape[-2:]
     if transposed_y:
         n = y.shape[-2]
     else:
         n = y.shape[-1]
     dtype = str(x.dtype)
+    y_shape = y.shape
 
     from tvm.script.ir_builder import IRBuilder
     from tvm.script.ir_builder import relax as relax_builder
@@ -140,7 +154,8 @@ def get_relax_matmul_module(x, y, transposed_y=False, with_bias=False, activatio
 
             with R.dataflow() as frame:
                 if transposed_y:
-                    y = R.emit(R.permute_dims(y))
+                    axes = list(range(len(y_shape) - 2)) + [-1, -2]
+                    y = R.emit(R.permute_dims(y, axes=axes))
                 result = R.emit(R.matmul(x, y, out_dtype=dtype))
                 if with_bias:
                     result = R.emit(result + bias)
@@ -154,140 +169,135 @@ def get_relax_matmul_module(x, y, transposed_y=False, 
with_bias=False, activatio
     return tvm.IRModule({"main": func})
 
 
[email protected](params=["float16"])
-def target_dtype(request):
-    return request.param
-
-
[email protected](
-    params=[
-        # M, K, N
-        (32, 6, 16),
-        (29, 17, 19),
-        (64, 128, 1024),
-    ]
[email protected](
+    "x_shape, y_shape, transpose_y",
+    [
+        # Regular
+        ((32, 6), (6, 16), False),
+        # Transposed
+        ((4, 16), (16, 128), True),
+        ((35, 8), (8, 8), True),
+        # 3D x 3D
+        ((6, 32, 8), (6, 8, 10), False),
+        ((6, 32, 8), (6, 8, 10), True),
+        # 3D x 2D
+        ((6, 32, 8), (8, 10), False),
+        ((10, 16, 8), (8, 10), True),
+        # 2D x 3D
+        ((32, 8), (10, 8, 10), False),
+        ((32, 8), (10, 8, 10), True),
+        # ND x 2D
+        ((3, 6, 32, 8), (8, 10), False),
+        # 2D x ND
+        ((32, 8), (5, 3, 8, 10), False),
+        # ND x ND
+        ((5, 3, 32, 8), (5, 3, 8, 10), True),
+        ((3, 2, 4, 16, 15), (1, 1, 15, 2), True),
+        ((1, 1, 16, 15), (3, 2, 4, 15, 2), False),
+    ],
 )
-def matmul_size(request):
-    return request.param
-
-
-low, high = -10, 10
-
-
[email protected]
-def matmul_x(matmul_size, target_dtype):
-    m, k, _ = matmul_size
-    return np.random.randint(low, high, size=(m, k)).astype(target_dtype)
-
-
[email protected]
-def matmul_y(matmul_size, target_dtype):
-    _, k, n = matmul_size
-    return np.random.randint(low, high, size=(k, n)).astype(target_dtype)
-
-
[email protected]
-def matmul_bias(matmul_size, target_dtype):
-    _, _, n = matmul_size
-    return np.random.randint(low, high, size=(n,)).astype(target_dtype)
-
-
-def test_matmul_offload(matmul_x, matmul_y):
-    x, y = matmul_x, matmul_y
-
-    mod = get_relax_matmul_module(x, y)
-    out = get_result_with_relax_cutlass_offload(mod, x, y)
-    ref = build_and_run(mod, [x, y], "llvm", legalize=True)
-
-    np.testing.assert_equal(out, ref)
-
-
-def test_matmul_bias_offload(matmul_x, matmul_y, matmul_bias):
-    x, y, bias = matmul_x, matmul_y, matmul_bias
-
-    mod = get_relax_matmul_module(x, y, with_bias=True)
-    out = get_result_with_relax_cutlass_offload(mod, x, y, bias)
-    ref = build_and_run(mod, [x, y, bias], "llvm", legalize=True)
-
-    np.testing.assert_equal(out, ref)
-
-
-def test_matmul_bias_relu_offload(matmul_x, matmul_y, matmul_bias):
-    x, y, bias = matmul_x, matmul_y, matmul_bias
-
-    mod = get_relax_matmul_module(x, y, with_bias=True, activation=R.nn.relu)
-    out = get_result_with_relax_cutlass_offload(mod, x, y, bias)
-    ref = build_and_run(mod, [x, y, bias], "llvm", legalize=True)
-
-    np.testing.assert_equal(out, ref)
-
-
-def test_matmul_bias_gelu_offload(matmul_x, matmul_y, matmul_bias):
-    x, y, bias = matmul_x, matmul_y, matmul_bias
-    mod = get_relax_matmul_module(x, y, with_bias=True, activation=R.nn.gelu)
-
-    out = get_result_with_relax_cutlass_offload(mod, x, y, bias)
-    ref = build_and_run(mod, [x, y, bias], "llvm", legalize=True)
-
-    tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-3)
-
-
-def test_kernel_sharing():
-    low, high = -1, 1
-    data_np = np.random.randint(low, high, size=(16, 32, 32, 
8)).astype("float16")
-    weight1_np = np.random.randint(low, high, size=(8, 3, 3, 
8)).astype("float16")
-    weight2_np = np.random.randint(low, high, size=(8, 3, 3, 
8)).astype("float16")
-
-    out = get_result_with_relax_cutlass_offload(Conv2dx2, data_np, weight1_np, 
weight2_np)
-    ref = build_and_run(Conv2dx2, [data_np, weight1_np, weight2_np], "llvm", 
legalize=True)
-
-    np.testing.assert_equal(out, ref)
-
-
-def test_matmul_transposed_offload(matmul_x, matmul_y):
-    x, y = matmul_x, matmul_y
-
-    mod = get_relax_matmul_module(x, y.transpose(), transposed_y=True)
-    out = get_result_with_relax_cutlass_offload(mod, x, y.transpose())
-    ref = build_and_run(mod, [x, y.transpose()], "llvm", legalize=True)
-
-    np.testing.assert_equal(out, ref)
-
-
-def test_matmul_transposed_bias_offload(matmul_x, matmul_y, matmul_bias):
-    x, y, bias = matmul_x, matmul_y, matmul_bias
-
-    mod = get_relax_matmul_module(
-        x, y.transpose(), transposed_y=True, with_bias=True, activation=None
-    )
-    out = get_result_with_relax_cutlass_offload(mod, x, y.transpose(), bias)
-    ref = build_and_run(mod, [x, y.transpose(), bias], "llvm", legalize=True)
-
-    np.testing.assert_equal(out, ref)
-
-
-def test_matmul_transposed_bias_relu_offload(matmul_x, matmul_y, matmul_bias):
-    x, y, bias = matmul_x, matmul_y, matmul_bias
[email protected](
+    "with_bias, activation",
+    [
+        (True, None),
+        (False, None),
+        (True, R.nn.relu),
+        (True, R.nn.gelu),
+    ],
+    ids=[
+        "no_bias",
+        "biased",
+        "biased_relu",
+        "biased_gelu",
+    ],
+)
[email protected](
+    "dtype",
+    [
+        "float16",
+    ],
+)
+def test_matmul_offload(
+    x_shape,
+    y_shape,
+    transpose_y,
+    with_bias,
+    activation,
+    dtype,
+):
+    x = np.random.randn(*x_shape).astype(dtype)
+    y = np.random.randn(*y_shape).astype(dtype)
+
+    if transpose_y:
+        y = np.swapaxes(y, -2, -1)
+
+    if with_bias:
+        bias = np.random.randn(y_shape[-1]).astype(dtype)
+        args = (x, y, bias)
+    else:
+        bias = None
+        args = (x, y)
 
     mod = get_relax_matmul_module(
-        x, y.transpose(), transposed_y=True, with_bias=True, 
activation=R.nn.relu
+        x, y, with_bias=with_bias, transposed_y=transpose_y, 
activation=activation
     )
-    out = get_result_with_relax_cutlass_offload(mod, x, y.transpose(), bias)
-    ref = build_and_run(mod, [x, y.transpose(), bias], "llvm", legalize=True)
-
-    np.testing.assert_equal(out, ref)
-
-
-def test_matmul_transposed_bias_gelu_offload(matmul_x, matmul_y, matmul_bias):
-    x, y, bias = matmul_x, matmul_y, matmul_bias
+    out = get_result_with_relax_cutlass_offload(mod, *args)
+    ref = build_and_run(mod, args, "llvm", legalize=True)
+
+    tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
+
+
[email protected](
+    "x_shape, y_shape, expected",
+    [
+        # Regular matmul
+        ((3, 4), (4, 5), True),
+        # Batch matmul without stretching
+        ((3, 16, 15), (3, 15, 2), True),
+        # Broadcast 2D to 3D
+        ((3, 16, 15), (15, 2), True),
+        ((16, 15), (3, 15, 2), True),
+        # Broadcast one-length dimension
+        ((1, 16, 15), (3, 15, 2), True),
+        ((3, 16, 15), (1, 15, 2), True),
+        ((1, 1, 16, 15), (3, 2, 4, 15, 2), True),
+        # ND x ND
+        ((3, 2, 4, 16, 15), (3, 2, 4, 15, 2), True),
+        # ND x ND with one-length dimension
+        ((1, 2, 4, 16, 15), (1, 2, 4, 15, 2), True),
+        ((3, 2, 1, 16, 15), (3, 2, 1, 15, 2), True),
+        # Extra one-length dimension doesn't block broadcasting
+        ((3, 2, 1, 16, 15), (1, 1, 3, 2, 1, 15, 2), True),
+        # Not broadcasting all dims. Cannot be computed by stride-based batch 
gemm
+        ((3, 1, 1, 16, 15), (3, 2, 4, 15, 2), False),
+        ((3, 2, 4, 16, 15), (2, 4, 15, 2), False),
+        # Different shape
+        ((3, 4, 16, 15), (3, 2, 15, 2), False),
+    ],
+)
+def test_is_valid_for_cutlass_matmul(x_shape, y_shape, expected):
+    assert is_valid_for_cutlass_matmul(x_shape, y_shape) == expected
+
+
[email protected](
+    "x_shape, y_shape, transpose_y, dtype",
+    [
+        # Not broadcasting all dims. Cannot be computed by stride-based batch 
gemm
+        ((3, 1, 1, 16, 15), (3, 2, 4, 15, 2), False, "float16"),
+        ((1, 2, 1, 16, 15), (2, 1, 4, 15, 2), False, "float16"),
+        ((3, 2, 4, 16, 15), (2, 4, 15, 2), True, "float16"),
+        ((3, 16, 15), (2, 1, 3, 15, 2), True, "float16"),
+    ],
+)
+def test_cutlass_partition_matmul_blocked(x_shape, y_shape, transpose_y, 
dtype):
+    x = np.random.randn(*x_shape).astype(dtype)
+    y = np.random.randn(*y_shape).astype(dtype)
+    if transpose_y:
+        y = np.swapaxes(y, -2, -1)
 
-    mod = get_relax_matmul_module(
-        x, y.transpose(), transposed_y=True, with_bias=True, 
activation=R.nn.gelu
-    )
-    out = get_result_with_relax_cutlass_offload(mod, x, y.transpose(), bias)
-    ref = build_and_run(mod, [x, y.transpose(), bias], "llvm", legalize=True)
+    mod = get_relax_matmul_module(x, y, with_bias=False, 
transposed_y=transpose_y)
 
-    tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-3)
+    tvm.ir.assert_structural_equal(mod, partition_for_cutlass(mod))
 
 
 if __name__ == "__main__":


Reply via email to