This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 354683860c [Unity][BYOC] Add batch matmul support to Relax CUTLASS
BYOC (#14166)
354683860c is described below
commit 354683860c7b45b18923b1c2e3541ebc49053c0d
Author: Lite Ye <[email protected]>
AuthorDate: Thu Mar 2 01:13:08 2023 -0500
[Unity][BYOC] Add batch matmul support to Relax CUTLASS BYOC (#14166)
* Add batch matmul support to Relax CUTLASS BYOC
* Allow more dtypes
* Fix tests
* Revert how to get batch attr
---
python/tvm/contrib/cutlass/build.py | 59 +++++-
python/tvm/contrib/cutlass/gemm_operation.py | 9 +-
python/tvm/contrib/cutlass/gen_tensor_op.py | 15 +-
python/tvm/relax/backend/contrib/cutlass.py | 90 +++++++--
python/tvm/relax/backend/pattern_registry.py | 77 ++++++--
src/relax/backend/pattern_registry.cc | 20 +-
src/relax/backend/pattern_registry.h | 17 +-
src/relax/transform/fuse_ops.cc | 3 +-
tests/python/relax/test_codegen_cutlass.py | 286 ++++++++++++++-------------
9 files changed, 380 insertions(+), 196 deletions(-)
diff --git a/python/tvm/contrib/cutlass/build.py
b/python/tvm/contrib/cutlass/build.py
index 954aef60c2..7e81113f44 100644
--- a/python/tvm/contrib/cutlass/build.py
+++ b/python/tvm/contrib/cutlass/build.py
@@ -16,10 +16,13 @@
# under the License.
# pylint: disable=invalid-name, dangerous-default-value, arguments-differ
"""Driver for partitioning and building a Relay module for CUTLASS offload."""
+import itertools
import logging
import multiprocessing
+import operator
import os
-from typing import Optional
+from functools import reduce
+from typing import Optional, Tuple
import tvm
from tvm import relax, relay, runtime
@@ -569,6 +572,31 @@ def _extract_arg_idx(pattern_name, f):
return arg_idx
+def is_valid_for_cutlass_matmul(lhs_shape: Tuple[int, ...], rhs_shape: Tuple[int, ...]) -> bool:
+    """
+    Check whether the shapes of the inputs can be handled by CUTLASS GEMM.
+
+    The stride-based batch matmul in CUTLASS cannot handle cases where some of
+    the batch dimensions need to be stretched while others do not. This means
+    it can only handle ND x ND whose batch dimensions match exactly on both
+    sides, as well as ND x 2D and 2D x ND. For example, it cannot handle a
+    matmul with shape (2, 1, 4, 8) x (2, 3, 8, 16), because the batch stride
+    of lhs is not constant.
+
+    Parameters
+    ----------
+    lhs_shape : Tuple[int, ...]
+        Static shape of the left operand; the last two entries are the matrix dims.
+
+    rhs_shape : Tuple[int, ...]
+        Static shape of the right operand; the last two entries are the matrix dims.
+
+    Returns
+    -------
+    is_valid : bool
+        True if the pair of shapes can be computed by a (batched) GEMM with a
+        constant batch stride on each operand, False otherwise.
+    """
+    if len(lhs_shape) < 2 or len(rhs_shape) < 2:
+        # A GEMM needs two matrix dimensions per operand. Without this guard a
+        # 1-D operand would look like an un-batched matmul and be accepted.
+        return False
+
+    lhs_batches = reduce(operator.mul, lhs_shape[:-2], 1)
+    rhs_batches = reduce(operator.mul, rhs_shape[:-2], 1)
+    if lhs_batches == 1 or rhs_batches == 1:
+        # This could be regular matmul or batch matmul with shape ND x 2D or 2D x ND
+        return True
+
+    # If one side has less dimensions, use 1 to fill the gap
+    batch_dim_pairs = itertools.zip_longest(
+        lhs_shape[-3::-1],  # Remove the last two dimensions and reverse
+        rhs_shape[-3::-1],
+        fillvalue=1,
+    )
+    return all(lhs_dim == rhs_dim for lhs_dim, rhs_dim in batch_dim_pairs)
+
+
@relax.expr_functor.mutator
class CutlassRelaxFunctionAnnotator(relax.PyExprMutator):
"""A Relax function mutator that tunes and annotates CUTLASS composite
functions
@@ -659,17 +687,35 @@ class CutlassRelaxFunctionAnnotator(relax.PyExprMutator):
rhs_dtype = signature[f"{rhs_arg}_dtype"]
out_dtype = signature["ret_dtype"]
- MM = lhs_shape[0]
- KK = lhs_shape[1]
+ if not is_valid_for_cutlass_matmul(lhs_shape, rhs_shape):
+ raise ValueError(f"Cannot handle the input shapes, lhs:
{lhs_shape}, rhs: {rhs_shape}")
+
+ MM = lhs_shape[-2]
+ KK = lhs_shape[-1]
if "transposed" in op_type:
- NN = rhs_shape[0]
+ NN = rhs_shape[-2]
ldb = "K"
layout_b = LayoutType.ColumnMajor
else:
- NN = rhs_shape[1]
+ NN = rhs_shape[-1]
ldb = "N"
layout_b = LayoutType.RowMajor
+ lhs_batches = reduce(operator.mul, lhs_shape[:-2], 1)
+ rhs_batches = reduce(operator.mul, rhs_shape[:-2], 1)
+ if lhs_batches == 1 and rhs_batches == 1:
+ # Regular matmul
+ is_batched = False
+ batch_attrs = {}
+ else:
+ is_batched = True
+ batch_attrs = {
+ "batch": max(lhs_batches, rhs_batches),
+ "batch_stride_A": 0 if lhs_batches == 1 else MM * KK,
+ "batch_stride_B": 0 if rhs_batches == 1 else KK * NN,
+ "batch_stride_C": MM * NN,
+ }
+
use_3xtf32 = self.options.get("use_3xtf32", False)
find_first_valid = self.options.get("find_first_valid", True)
use_multiprocessing = self.options.get("use_multiprocessing", True)
@@ -683,7 +729,7 @@ class CutlassRelaxFunctionAnnotator(relax.PyExprMutator):
lhs_dtype,
rhs_dtype,
use_3xtf32,
- batched=False,
+ batched=is_batched,
find_first_valid=find_first_valid,
use_multiprocessing=use_multiprocessing,
layout_b=layout_b,
@@ -706,6 +752,7 @@ class CutlassRelaxFunctionAnnotator(relax.PyExprMutator):
"ldc": "N",
"cutlass_op_name": op_name,
"cutlass_op_def": op_def,
+ **batch_attrs,
}
)
diff --git a/python/tvm/contrib/cutlass/gemm_operation.py
b/python/tvm/contrib/cutlass/gemm_operation.py
index 3e74cbaec8..1675e1f035 100644
--- a/python/tvm/contrib/cutlass/gemm_operation.py
+++ b/python/tvm/contrib/cutlass/gemm_operation.py
@@ -287,7 +287,7 @@ def instantiate_gemm_template(attrs):
{static_cast<ElementInputA*>(ptr_a), ${lda}}, ${batch_stride_A}
{static_cast<ElementInputB*>(ptr_b), ${ldb}}, ${batch_stride_B}
{static_cast<ElementOutput*>(${ptr_c}), ${c_stride}}, ${batch_stride_C}
- {static_cast<ElementOutput*>(ptr_out), ${ldc}}, ${batch_stride_C}
+ {static_cast<ElementOutput*>(ptr_out), ${ldc}}, ${batch_stride_D}
{${alpha_beta}},
${split_k_slices_or_batch}
};
@@ -303,8 +303,7 @@ def instantiate_gemm_template(attrs):
"""
has_bias = "bias" in attrs["op_type"]
is_gelu = "gelu" in attrs["op_type"]
- batched = "batch_matmul" in attrs["op_type"]
-
+ batched = "batch" in attrs
aux_map = {"kernel": "Gemm"}
if has_bias:
@@ -335,6 +334,10 @@ def instantiate_gemm_template(attrs):
else:
aux_map[key] = attrs[key] + ","
+ aux_map["batch_stride_D"] = aux_map["batch_stride_C"]
+ if has_bias and batched:
+ aux_map["batch_stride_C"] = "0,"
+
if batched:
attrs["split_k_slices_or_batch"] = attrs["batch"]
else:
diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py
b/python/tvm/contrib/cutlass/gen_tensor_op.py
index 92bf04e863..2976946dd2 100644
--- a/python/tvm/contrib/cutlass/gen_tensor_op.py
+++ b/python/tvm/contrib/cutlass/gen_tensor_op.py
@@ -529,8 +529,7 @@ def instantiate_template(func_name, annotations, func_args):
return dim1 + " * " + dim2
if "dense" in func_name or "matmul" in func_name:
- batched = "batch_matmul" in func_name
- batched_offset = 1 if batched else 0
+ batched = "batch" in annotations
transposed = "transposed" in func_name
lhs_arg_idx = _get_optional_int_annotation(annotations, "lhs_arg_idx",
0)
rhs_arg_idx = _get_optional_int_annotation(annotations, "rhs_arg_idx",
1)
@@ -539,6 +538,8 @@ def instantiate_template(func_name, annotations, func_args):
rhs_arg = func_args[rhs_arg_idx]
lhs_shape = annotations[f"arg{lhs_arg_idx}_shape"]
rhs_shape = annotations[f"arg{rhs_arg_idx}_shape"]
+ lhs_batched_offset = len(lhs_shape) - 2
+ rhs_batched_offset = len(rhs_shape) - 2
attrs["lhs_arg"] = lhs_arg
attrs["rhs_arg"] = rhs_arg
@@ -548,17 +549,17 @@ def instantiate_template(func_name, annotations,
func_args):
attrs["ElementInputB"] =
DataTypeTag[dtype_map[annotations[f"arg{rhs_arg_idx}_dtype"]]]
attrs["ElementOutput"] =
DataTypeTag[dtype_map[annotations["ret_dtype"]]]
- attrs["K"] = str(int(lhs_shape[batched_offset + 1]))
- attrs["M"] = get_dim(lhs_shape[batched_offset], lhs_arg, 0,
batched_offset)
+ attrs["K"] = str(int(lhs_shape[lhs_batched_offset + 1]))
+ attrs["M"] = get_dim(lhs_shape[lhs_batched_offset], lhs_arg, 0,
lhs_batched_offset)
if transposed:
- attrs["N"] = get_dim(rhs_shape[batched_offset], rhs_arg, 0,
batched_offset)
+ attrs["N"] = get_dim(rhs_shape[rhs_batched_offset], rhs_arg, 0,
rhs_batched_offset)
else:
- attrs["N"] = get_dim(rhs_shape[batched_offset + 1], rhs_arg, 1,
batched_offset)
+ attrs["N"] = get_dim(rhs_shape[rhs_batched_offset + 1], rhs_arg,
1, rhs_batched_offset)
if batched:
headers.append("cutlass/gemm/device/gemm_batched.h")
- attrs["batch"] = get_dim(lhs_shape[0], lhs_arg, 0)
+ attrs["batch"] = get_dim(annotations["batch"], lhs_arg, 0)
attrs["batch_stride_A"] = get_batch_stride(
annotations["batch_stride_A"],
lhs_arg_idx,
diff --git a/python/tvm/relax/backend/contrib/cutlass.py
b/python/tvm/relax/backend/contrib/cutlass.py
index e98194ca21..2d8908184b 100644
--- a/python/tvm/relax/backend/contrib/cutlass.py
+++ b/python/tvm/relax/backend/contrib/cutlass.py
@@ -17,16 +17,69 @@
"""Pattern table for CUTLASS backend"""
-from tvm.relax import transform
+from typing import Mapping, Optional, Tuple
+
+import tvm
+from tvm.contrib.cutlass.build import is_valid_for_cutlass_matmul
+from tvm.relax import Call, Expr, ShapeExpr, transform
+from tvm.relax.dpl import DFPattern
from ..pattern_registry import get_patterns_with_prefix, register_patterns
from ..patterns import make_fused_bias_activation_pattern, make_matmul_pattern
+
+def _get_static_shape(shape: ShapeExpr) -> Optional[Tuple[int, ...]]:
+    """Convert a fully static shape expression to a tuple of Python ints.
+
+    Returns None as soon as one dimension is not a `tvm.tir.expr.IntImm`,
+    i.e. when the shape cannot be expressed with plain integers.
+    """
+    dims = []
+    for dim in shape.values:
+        if not isinstance(dim, tvm.tir.expr.IntImm):
+            return None
+        dims.append(int(dim))
+    # Return a tuple so the value agrees with the declared return type.
+    return tuple(dims)
+
+
+def _is_supported_dtype(lhs_dtype, rhs_dtype):
+ """Check if dtypes in the given workload are supported by CUTLASS."""
+ return (
+ (lhs_dtype == "float16" and rhs_dtype == "float16")
+ or (lhs_dtype == "float32" and rhs_dtype == "float32")
+ or (lhs_dtype in ("int8", "uint8") and rhs_dtype in ("int8", "uint8"))
+ )
+
+
+def _check_matmul(
+    match_result: Mapping[DFPattern, Expr],
+    _: Expr,
+) -> bool:
+    """Decide whether a matched matmul pattern is accepted for CUTLASS offload.
+
+    Parameters
+    ----------
+    match_result : Mapping[DFPattern, Expr]
+        The mapping from each pattern to the expression it matched. It must
+        contain a call to `relax.matmul`.
+
+    _ : Expr
+        The matched expression as a whole (unused).
+
+    Returns
+    -------
+    accepted : bool
+        True if both operands have static shapes with at least two dimensions,
+        a supported dtype combination, and batch dimensions that pass
+        `is_valid_for_cutlass_matmul`.
+
+    Raises
+    ------
+    ValueError
+        If no call to `relax.matmul` is found in `match_result`.
+    """
+    matmul_call: Optional[Call] = None
+    for _, expr in match_result.items():
+        if (
+            isinstance(expr, Call)
+            and isinstance(expr.op, tvm.ir.Op)
+            and expr.op.name == "relax.matmul"
+        ):
+            matmul_call = expr
+    if matmul_call is None:
+        raise ValueError("Cannot find call to matmul from match_result.")
+
+    lhs_shape = _get_static_shape(matmul_call.args[0].struct_info.shape)
+    rhs_shape = _get_static_shape(matmul_call.args[1].struct_info.shape)
+    if lhs_shape is None or rhs_shape is None:
+        # _get_static_shape returns None for non-constant dimensions. Reject
+        # the match instead of crashing on len(None) below.
+        return False
+    if len(lhs_shape) < 2 or len(rhs_shape) < 2:
+        return False
+
+    lhs_dtype = matmul_call.args[0].struct_info.dtype
+    rhs_dtype = matmul_call.args[1].struct_info.dtype
+    if not _is_supported_dtype(lhs_dtype, rhs_dtype):
+        return False
+
+    return is_valid_for_cutlass_matmul(lhs_shape, rhs_shape)
+
+
register_patterns(
[
(
"cutlass.conv2d",
- make_fused_bias_activation_pattern(
+ *make_fused_bias_activation_pattern(
"relax.nn.conv2d",
with_bias=False,
activation=None,
@@ -34,7 +87,7 @@ register_patterns(
),
(
"cutlass.conv2d_bias_relu",
- make_fused_bias_activation_pattern(
+ *make_fused_bias_activation_pattern(
"relax.nn.conv2d",
with_bias=True,
activation="relax.nn.relu",
@@ -42,59 +95,67 @@ register_patterns(
),
(
"cutlass.matmul",
- make_matmul_pattern(
+ *make_matmul_pattern(
with_bias=False,
),
+ _check_matmul,
),
(
"cutlass.matmul_bias",
- make_matmul_pattern(
+ *make_matmul_pattern(
with_bias=True,
),
+ _check_matmul,
),
(
"cutlass.matmul_bias_relu",
- make_matmul_pattern(
+ *make_matmul_pattern(
with_bias=True,
activation="relax.nn.relu",
),
+ _check_matmul,
),
(
"cutlass.matmul_bias_gelu",
- make_matmul_pattern(
+ *make_matmul_pattern(
with_bias=True,
activation="relax.nn.gelu",
),
+ _check_matmul,
),
(
"cutlass.matmul_transposed",
- make_matmul_pattern(
+ *make_matmul_pattern(
with_bias=False,
transposed_rhs=True,
),
+ _check_matmul,
),
(
"cutlass.matmul_transposed_bias",
- make_matmul_pattern(
+ *make_matmul_pattern(
with_bias=True,
transposed_rhs=True,
),
+ _check_matmul,
),
(
"cutlass.matmul_transposed_bias_relu",
- make_matmul_pattern(
+ *make_matmul_pattern(
with_bias=True,
activation="relax.nn.relu",
transposed_rhs=True,
),
+ _check_matmul,
),
(
"cutlass.matmul_transposed_bias_gelu",
- make_matmul_pattern(
+ *make_matmul_pattern(
with_bias=True,
activation="relax.nn.gelu",
transposed_rhs=True,
),
+ _check_matmul,
),
]
)
@@ -116,7 +177,6 @@ def partition_for_cutlass(mod):
compiled by the CUTLASS backend.
"""
- cutlass_patterns = get_patterns_with_prefix("cutlass")
- return transform.FuseOpsByPattern(cutlass_patterns, bind_constants=True,
annotate_codegen=True)(
- mod
- )
+ cutlass_pattern_entries = get_patterns_with_prefix("cutlass")
+ patterns = [(e.name, e.pattern, e.check) for e in cutlass_pattern_entries]
+ return transform.FuseOpsByPattern(patterns, bind_constants=False,
annotate_codegen=True)(mod)
diff --git a/python/tvm/relax/backend/pattern_registry.py
b/python/tvm/relax/backend/pattern_registry.py
index 0cccab7842..5a35eba03d 100644
--- a/python/tvm/relax/backend/pattern_registry.py
+++ b/python/tvm/relax/backend/pattern_registry.py
@@ -17,14 +17,15 @@
"""Pattern registry for BYOC backends"""
-from typing import Callable, List, Mapping, Optional, Tuple, Union
+import atexit
+from typing import Callable, List, Mapping, Optional, Set, Tuple, Union
import tvm
from tvm.relax.dpl import DFPattern
from tvm.runtime import Object
-from . import _ffi_api
from ..expr import Expr
+from . import _ffi_api
@tvm._ffi.register_object("relax.backend.PatternRegistryEntry")
@@ -46,25 +47,59 @@ class PatternRegistryEntry(Object):
arg_patterns: Mapping[str, DFPattern]
The mapping from arg name to its pattern. It can be used to extract
arg expression
from match result. All DFPattern in this map should be part of the
`pattern`.
+
+ check: Callable[[Mapping[DFPattern, Expr], Expr], bool]
+ The function to check whether the match result is accepted.
"""
name: str
pattern: DFPattern
arg_patterns: Mapping[str, DFPattern]
-
- def __init__(self, name: str, pattern: DFPattern, arg_patterns:
Mapping[str, DFPattern]):
+ check: Callable[[Mapping[DFPattern, Expr], Expr], bool]
+
+ def __init__(
+ self,
+ name: str,
+ pattern: DFPattern,
+ arg_patterns: Mapping[str, DFPattern],
+ check: Callable[[Mapping[DFPattern, Expr], Expr], bool],
+ ):
self.__init_handle_by_constructor__(
- _ffi_api.PatternRegistryEntry, name, pattern, arg_patterns #
type: ignore
+ _ffi_api.PatternRegistryEntry, name, pattern, arg_patterns, check
# type: ignore
)
+_REGISTERED_PATTERN_NAMES: Set[str] = set()
+
+
+def _cleanup_registered_patterns():
+ _ffi_api.RemovePatterns(list(_REGISTERED_PATTERN_NAMES)) # type: ignore #
pylint: disable=no-member
+
+
+_CLEANUP_REGISTERED = False
+
+
+def _ensure_cleanup_function_registered():
+ """
+ Add a cleanup function to be called on interpreter termination, to remove
all
+ patterns registered on the Python side. Without cleaning up those patterns,
+ program will segfault on termination. It's because the check functions of
pattern
+ entries are referenced from the static memory of libtvm, thus they will be
cleaned
+ up at the very end, making calls to Py_DecRef after Python interpreter
terminates.
+ """
+ global _CLEANUP_REGISTERED # pylint: disable=global-statement
+
+ if not _CLEANUP_REGISTERED:
+ atexit.register(_cleanup_registered_patterns)
+ _CLEANUP_REGISTERED = True
+
+
+CheckFunc = Callable[[Mapping[DFPattern, Expr], Expr], bool]
Pattern = Union[
PatternRegistryEntry,
Tuple[str, DFPattern],
- Tuple[
- str,
- Tuple[DFPattern, Mapping[str, DFPattern], Callable[[Mapping[DFPattern,
Expr], Expr], bool]],
- ],
+ Tuple[str, DFPattern, Mapping[str, DFPattern]],
+ Tuple[str, DFPattern, Mapping[str, DFPattern], CheckFunc],
]
@@ -79,21 +114,27 @@ def register_patterns(patterns: List[Pattern]):
Patterns to be registered. Patterns that appear later in the list have
higher priority when partitioning DataflowBlock.
"""
+ _ensure_cleanup_function_registered()
+
entries = []
for item in patterns:
if isinstance(item, PatternRegistryEntry):
entries.append(item)
elif isinstance(item, tuple):
- name, pattern_or_tuple = item
- if isinstance(pattern_or_tuple, tuple):
- if len(pattern_or_tuple) == 2:
- pattern, arg_patterns = pattern_or_tuple
- check = lambda *_: True
- else:
- pattern, arg_patterns, check = pattern_or_tuple
+ name, pattern, *rest = item
+
+ if len(rest) > 0:
+ arg_patterns = rest[0]
else:
- pattern, arg_patterns, check = pattern_or_tuple, {}, lambda
*_: True
- entries.append(PatternRegistryEntry(name, pattern, arg_patterns))
+ arg_patterns = {}
+
+ if len(rest) > 1:
+ check = rest[1]
+ else:
+ check = lambda *_: True
+
+ entries.append(PatternRegistryEntry(name, pattern, arg_patterns,
check))
+ _REGISTERED_PATTERN_NAMES.add(name)
else:
raise TypeError(f"Cannot register type {type(pattern)} as pattern")
_ffi_api.RegisterPatterns(entries)
diff --git a/src/relax/backend/pattern_registry.cc
b/src/relax/backend/pattern_registry.cc
index 3ca7973365..553018d690 100644
--- a/src/relax/backend/pattern_registry.cc
+++ b/src/relax/backend/pattern_registry.cc
@@ -26,11 +26,12 @@ namespace relax {
namespace backend {
PatternRegistryEntry::PatternRegistryEntry(String name, DFPattern pattern,
- Map<String, DFPattern>
arg_patterns) {
+ Map<String, DFPattern>
arg_patterns, PackedFunc check) {
ObjectPtr<PatternRegistryEntryNode> n =
make_object<PatternRegistryEntryNode>();
n->name = std::move(name);
n->pattern = std::move(pattern);
n->arg_patterns = std::move(arg_patterns);
+ n->check = check;
data_ = std::move(n);
}
@@ -48,6 +49,17 @@ void RegisterPatterns(Array<PatternRegistryEntry> entries) {
}
}
+void RemovePatterns(Array<String> names) {
+ std::unordered_set<String> name_set{names.begin(), names.end()};
+
+ auto* table = GetRegistryTable();
+ table->erase(std::remove_if(table->begin(), table->end(),
+ [&](const PatternRegistryEntry& entry) {
+ return name_set.count(entry->name) > 0;
+ }),
+ table->end());
+}
+
Array<PatternRegistryEntry> GetPatternsWithPrefix(const String& prefix) {
auto* table = GetRegistryTable();
Array<PatternRegistryEntry> result;
@@ -70,10 +82,12 @@ Optional<PatternRegistryEntry> GetPattern(const String&
pattern_name) {
}
TVM_REGISTER_GLOBAL("relax.backend.PatternRegistryEntry")
- .set_body_typed([](String name, DFPattern pattern, Map<String, DFPattern>
arg_patterns) {
- return PatternRegistryEntry(name, pattern, arg_patterns);
+ .set_body_typed([](String name, DFPattern pattern, Map<String, DFPattern>
arg_patterns,
+ PackedFunc check) {
+ return PatternRegistryEntry(name, pattern, arg_patterns, check);
});
TVM_REGISTER_GLOBAL("relax.backend.RegisterPatterns").set_body_typed(RegisterPatterns);
+TVM_REGISTER_GLOBAL("relax.backend.RemovePatterns").set_body_typed(RemovePatterns);
TVM_REGISTER_GLOBAL("relax.backend.GetPatternsWithPrefix").set_body_typed(GetPatternsWithPrefix);
TVM_REGISTER_GLOBAL("relax.backend.GetPattern").set_body_typed(GetPattern);
diff --git a/src/relax/backend/pattern_registry.h
b/src/relax/backend/pattern_registry.h
index 1fdb69319f..e765f56b4e 100644
--- a/src/relax/backend/pattern_registry.h
+++ b/src/relax/backend/pattern_registry.h
@@ -42,8 +42,6 @@ namespace backend {
*/
class PatternRegistryEntryNode : public Object {
public:
- using FCheckMatch = runtime::TypedPackedFunc<bool(const Map<DFPattern,
Expr>&, const Expr&)>;
-
/*!
* \brief The name of pattern. Usually it starts with the name of backend,
like
* 'cutlass.matmul'.
@@ -63,13 +61,17 @@ class PatternRegistryEntryNode : public Object {
/*!
* \brief The function to check whether the match result is accepted.
+ *
+ * It should have signature
+ * bool(const Map<DFPattern, Expr>& match_result, const Expr& matched_expr)
*/
- FCheckMatch check;
+ PackedFunc check;
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("name", &name);
v->Visit("pattern", &pattern);
v->Visit("arg_patterns", &arg_patterns);
+ v->Visit("check", &check);
}
static constexpr const char* _type_key =
"relax.backend.PatternRegistryEntry";
@@ -78,7 +80,8 @@ class PatternRegistryEntryNode : public Object {
class PatternRegistryEntry : public ObjectRef {
public:
- PatternRegistryEntry(String name, DFPattern pattern, Map<String, DFPattern>
arg_patterns);
+ PatternRegistryEntry(String name, DFPattern pattern, Map<String, DFPattern>
arg_patterns,
+ PackedFunc check);
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(PatternRegistryEntry, ObjectRef,
PatternRegistryEntryNode);
@@ -92,6 +95,12 @@ class PatternRegistryEntry : public ObjectRef {
*/
void RegisterPatterns(Array<PatternRegistryEntry> entries);
+/*!
+ * \brief Remove patterns from the registry by their name.
+ * \param names The name of patterns to be removed
+ */
+void RemovePatterns(Array<String> names);
+
/*!
* \brief Find patterns whose name starts with a particular prefix.
 * \param prefix The pattern name prefix.
diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc
index 72427a8a0e..d6013c8874 100644
--- a/src/relax/transform/fuse_ops.cc
+++ b/src/relax/transform/fuse_ops.cc
@@ -40,7 +40,6 @@
#include "../../relay/analysis/graph_partitioner.h"
#include "../../support/arena.h"
-#include "../backend/pattern_registry.h"
namespace tvm {
namespace relax {
@@ -902,7 +901,7 @@ class PatternBasedPartitioner : ExprVisitor {
using Group = GraphPartitioner::Group;
using GroupMap = OperatorFusor::GroupMap;
using ExprVisitor::VisitExpr_;
- using FCheckMatch = backend::PatternRegistryEntryNode::FCheckMatch;
+ using FCheckMatch = runtime::TypedPackedFunc<bool(const Map<DFPattern,
Expr>&, const Expr&)>;
static GroupMap Run(String pattern_name, DFPattern pattern, FCheckMatch
check, Expr expr,
support::Arena* arena) {
diff --git a/tests/python/relax/test_codegen_cutlass.py
b/tests/python/relax/test_codegen_cutlass.py
index 6eb476496c..83104d6fe1 100644
--- a/tests/python/relax/test_codegen_cutlass.py
+++ b/tests/python/relax/test_codegen_cutlass.py
@@ -16,11 +16,14 @@
# under the License.
import numpy as np
import pytest
+import scipy
import tvm
import tvm.testing
-from tvm import relax
+from tvm import relax, relay
+from tvm.contrib.cutlass.build import is_valid_for_cutlass_matmul
from tvm.relax.backend import get_patterns_with_prefix
+from tvm.relax.backend.contrib.cutlass import partition_for_cutlass
from tvm.script import relax as R
@@ -95,20 +98,17 @@ def build_and_run(mod, inputs_np, target, legalize=False):
def get_result_with_relax_cutlass_offload(mod, *args):
patterns = [(entry.name, entry.pattern) for entry in
get_patterns_with_prefix("cutlass")]
-
assert len(patterns) != 0, "Cannot find cutlass patterns"
- seq = tvm.transform.Sequential(
- [
- relax.transform.FuseOpsByPattern(patterns, bind_constants=False,
annotate_codegen=True),
- relax.transform.RunCodegen({"cutlass": {"sm": 80,
"find_first_valid": True}}),
- ]
- )
+ mod = partition_for_cutlass(mod)
+ codegen_pass = relax.transform.RunCodegen({"cutlass": {"sm": 80,
"find_first_valid": True}})
+ mod = codegen_pass(mod)
- return build_and_run(seq(mod), args, "cuda")
+ return build_and_run(mod, args, "cuda")
def test_conv2d_offload():
+ low, high = -1, 1
data = np.random.randint(low, high, size=(16, 32, 32,
16)).astype("float16")
weight = np.random.randint(low, high, size=(32, 3, 3,
16)).astype("float16")
bias = np.random.randint(low, high, size=(1, 1, 1, 32)).astype("float16")
@@ -120,12 +120,26 @@ def test_conv2d_offload():
np.testing.assert_equal(out, ref)
+def test_kernel_sharing():
+ low, high = -1, 1
+ data_np = np.random.randint(low, high, size=(16, 32, 32,
8)).astype("float16")
+ weight1_np = np.random.randint(low, high, size=(8, 3, 3,
8)).astype("float16")
+ weight2_np = np.random.randint(low, high, size=(8, 3, 3,
8)).astype("float16")
+
+ out = get_result_with_relax_cutlass_offload(Conv2dx2, data_np, weight1_np,
weight2_np)
+ ref = build_and_run(Conv2dx2, [data_np, weight1_np, weight2_np], "llvm",
legalize=True)
+
+ np.testing.assert_equal(out, ref)
+
+
def get_relax_matmul_module(x, y, transposed_y=False, with_bias=False,
activation=None):
+ m, k = x.shape[-2:]
if transposed_y:
n = y.shape[-2]
else:
n = y.shape[-1]
dtype = str(x.dtype)
+ y_shape = y.shape
from tvm.script.ir_builder import IRBuilder
from tvm.script.ir_builder import relax as relax_builder
@@ -140,7 +154,8 @@ def get_relax_matmul_module(x, y, transposed_y=False,
with_bias=False, activatio
with R.dataflow() as frame:
if transposed_y:
- y = R.emit(R.permute_dims(y))
+ axes = list(range(len(y_shape) - 2)) + [-1, -2]
+ y = R.emit(R.permute_dims(y, axes=axes))
result = R.emit(R.matmul(x, y, out_dtype=dtype))
if with_bias:
result = R.emit(result + bias)
@@ -154,140 +169,135 @@ def get_relax_matmul_module(x, y, transposed_y=False,
with_bias=False, activatio
return tvm.IRModule({"main": func})
-@pytest.fixture(params=["float16"])
-def target_dtype(request):
- return request.param
-
-
-@pytest.fixture(
- params=[
- # M, K, N
- (32, 6, 16),
- (29, 17, 19),
- (64, 128, 1024),
- ]
+@pytest.mark.parametrize(
+ "x_shape, y_shape, transpose_y",
+ [
+ # Regular
+ ((32, 6), (6, 16), False),
+ # Transposed
+ ((4, 16), (16, 128), True),
+ ((35, 8), (8, 8), True),
+ # 3D x 3D
+ ((6, 32, 8), (6, 8, 10), False),
+ ((6, 32, 8), (6, 8, 10), True),
+ # 3D x 2D
+ ((6, 32, 8), (8, 10), False),
+ ((10, 16, 8), (8, 10), True),
+ # 2D x 3D
+ ((32, 8), (10, 8, 10), False),
+ ((32, 8), (10, 8, 10), True),
+ # ND x 2D
+ ((3, 6, 32, 8), (8, 10), False),
+ # 2D x ND
+ ((32, 8), (5, 3, 8, 10), False),
+ # ND x ND
+ ((5, 3, 32, 8), (5, 3, 8, 10), True),
+ ((3, 2, 4, 16, 15), (1, 1, 15, 2), True),
+ ((1, 1, 16, 15), (3, 2, 4, 15, 2), False),
+ ],
)
-def matmul_size(request):
- return request.param
-
-
-low, high = -10, 10
-
-
-@pytest.fixture
-def matmul_x(matmul_size, target_dtype):
- m, k, _ = matmul_size
- return np.random.randint(low, high, size=(m, k)).astype(target_dtype)
-
-
-@pytest.fixture
-def matmul_y(matmul_size, target_dtype):
- _, k, n = matmul_size
- return np.random.randint(low, high, size=(k, n)).astype(target_dtype)
-
-
-@pytest.fixture
-def matmul_bias(matmul_size, target_dtype):
- _, _, n = matmul_size
- return np.random.randint(low, high, size=(n,)).astype(target_dtype)
-
-
-def test_matmul_offload(matmul_x, matmul_y):
- x, y = matmul_x, matmul_y
-
- mod = get_relax_matmul_module(x, y)
- out = get_result_with_relax_cutlass_offload(mod, x, y)
- ref = build_and_run(mod, [x, y], "llvm", legalize=True)
-
- np.testing.assert_equal(out, ref)
-
-
-def test_matmul_bias_offload(matmul_x, matmul_y, matmul_bias):
- x, y, bias = matmul_x, matmul_y, matmul_bias
-
- mod = get_relax_matmul_module(x, y, with_bias=True)
- out = get_result_with_relax_cutlass_offload(mod, x, y, bias)
- ref = build_and_run(mod, [x, y, bias], "llvm", legalize=True)
-
- np.testing.assert_equal(out, ref)
-
-
-def test_matmul_bias_relu_offload(matmul_x, matmul_y, matmul_bias):
- x, y, bias = matmul_x, matmul_y, matmul_bias
-
- mod = get_relax_matmul_module(x, y, with_bias=True, activation=R.nn.relu)
- out = get_result_with_relax_cutlass_offload(mod, x, y, bias)
- ref = build_and_run(mod, [x, y, bias], "llvm", legalize=True)
-
- np.testing.assert_equal(out, ref)
-
-
-def test_matmul_bias_gelu_offload(matmul_x, matmul_y, matmul_bias):
- x, y, bias = matmul_x, matmul_y, matmul_bias
- mod = get_relax_matmul_module(x, y, with_bias=True, activation=R.nn.gelu)
-
- out = get_result_with_relax_cutlass_offload(mod, x, y, bias)
- ref = build_and_run(mod, [x, y, bias], "llvm", legalize=True)
-
- tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-3)
-
-
-def test_kernel_sharing():
- low, high = -1, 1
- data_np = np.random.randint(low, high, size=(16, 32, 32,
8)).astype("float16")
- weight1_np = np.random.randint(low, high, size=(8, 3, 3,
8)).astype("float16")
- weight2_np = np.random.randint(low, high, size=(8, 3, 3,
8)).astype("float16")
-
- out = get_result_with_relax_cutlass_offload(Conv2dx2, data_np, weight1_np,
weight2_np)
- ref = build_and_run(Conv2dx2, [data_np, weight1_np, weight2_np], "llvm",
legalize=True)
-
- np.testing.assert_equal(out, ref)
-
-
-def test_matmul_transposed_offload(matmul_x, matmul_y):
- x, y = matmul_x, matmul_y
-
- mod = get_relax_matmul_module(x, y.transpose(), transposed_y=True)
- out = get_result_with_relax_cutlass_offload(mod, x, y.transpose())
- ref = build_and_run(mod, [x, y.transpose()], "llvm", legalize=True)
-
- np.testing.assert_equal(out, ref)
-
-
-def test_matmul_transposed_bias_offload(matmul_x, matmul_y, matmul_bias):
- x, y, bias = matmul_x, matmul_y, matmul_bias
-
- mod = get_relax_matmul_module(
- x, y.transpose(), transposed_y=True, with_bias=True, activation=None
- )
- out = get_result_with_relax_cutlass_offload(mod, x, y.transpose(), bias)
- ref = build_and_run(mod, [x, y.transpose(), bias], "llvm", legalize=True)
-
- np.testing.assert_equal(out, ref)
-
-
-def test_matmul_transposed_bias_relu_offload(matmul_x, matmul_y, matmul_bias):
- x, y, bias = matmul_x, matmul_y, matmul_bias
+# NOTE: `ids` are positional, so their order must follow the parameter list.
+@pytest.mark.parametrize(
+    "with_bias, activation",
+    [
+        (False, None),
+        (True, None),
+        (True, R.nn.relu),
+        (True, R.nn.gelu),
+    ],
+    ids=[
+        "no_bias",
+        "biased",
+        "biased_relu",
+        "biased_gelu",
+    ],
+)
+@pytest.mark.parametrize(
+ "dtype",
+ [
+ "float16",
+ ],
+)
+def test_matmul_offload(
+ x_shape,
+ y_shape,
+ transpose_y,
+ with_bias,
+ activation,
+ dtype,
+):
+ x = np.random.randn(*x_shape).astype(dtype)
+ y = np.random.randn(*y_shape).astype(dtype)
+
+ if transpose_y:
+ y = np.swapaxes(y, -2, -1)
+
+ if with_bias:
+ bias = np.random.randn(y_shape[-1]).astype(dtype)
+ args = (x, y, bias)
+ else:
+ bias = None
+ args = (x, y)
mod = get_relax_matmul_module(
- x, y.transpose(), transposed_y=True, with_bias=True,
activation=R.nn.relu
+ x, y, with_bias=with_bias, transposed_y=transpose_y,
activation=activation
)
- out = get_result_with_relax_cutlass_offload(mod, x, y.transpose(), bias)
- ref = build_and_run(mod, [x, y.transpose(), bias], "llvm", legalize=True)
-
- np.testing.assert_equal(out, ref)
-
-
-def test_matmul_transposed_bias_gelu_offload(matmul_x, matmul_y, matmul_bias):
- x, y, bias = matmul_x, matmul_y, matmul_bias
+ out = get_result_with_relax_cutlass_offload(mod, *args)
+ ref = build_and_run(mod, args, "llvm", legalize=True)
+
+ tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
+
+
+@pytest.mark.parametrize(
+ "x_shape, y_shape, expected",
+ [
+ # Regular matmul
+ ((3, 4), (4, 5), True),
+ # Batch matmul without stretching
+ ((3, 16, 15), (3, 15, 2), True),
+ # Broadcast 2D to 3D
+ ((3, 16, 15), (15, 2), True),
+ ((16, 15), (3, 15, 2), True),
+ # Broadcast one-length dimension
+ ((1, 16, 15), (3, 15, 2), True),
+ ((3, 16, 15), (1, 15, 2), True),
+ ((1, 1, 16, 15), (3, 2, 4, 15, 2), True),
+ # ND x ND
+ ((3, 2, 4, 16, 15), (3, 2, 4, 15, 2), True),
+ # ND x ND with one-length dimension
+ ((1, 2, 4, 16, 15), (1, 2, 4, 15, 2), True),
+ ((3, 2, 1, 16, 15), (3, 2, 1, 15, 2), True),
+ # Extra one-length dimension doesn't block broadcasting
+ ((3, 2, 1, 16, 15), (1, 1, 3, 2, 1, 15, 2), True),
+ # Not broadcasting all dims. Cannot be computed by stride-based batch
gemm
+ ((3, 1, 1, 16, 15), (3, 2, 4, 15, 2), False),
+ ((3, 2, 4, 16, 15), (2, 4, 15, 2), False),
+ # Different shape
+ ((3, 4, 16, 15), (3, 2, 15, 2), False),
+ ],
+)
+def test_is_valid_for_cutlass_matmul(x_shape, y_shape, expected):
+ assert is_valid_for_cutlass_matmul(x_shape, y_shape) == expected
+
+
+@pytest.mark.parametrize(
+ "x_shape, y_shape, transpose_y, dtype",
+ [
+ # Not broadcasting all dims. Cannot be computed by stride-based batch
gemm
+ ((3, 1, 1, 16, 15), (3, 2, 4, 15, 2), False, "float16"),
+ ((1, 2, 1, 16, 15), (2, 1, 4, 15, 2), False, "float16"),
+ ((3, 2, 4, 16, 15), (2, 4, 15, 2), True, "float16"),
+ ((3, 16, 15), (2, 1, 3, 15, 2), True, "float16"),
+ ],
+)
+def test_cutlass_partition_matmul_blocked(x_shape, y_shape, transpose_y,
dtype):
+ x = np.random.randn(*x_shape).astype(dtype)
+ y = np.random.randn(*y_shape).astype(dtype)
+ if transpose_y:
+ y = np.swapaxes(y, -2, -1)
- mod = get_relax_matmul_module(
- x, y.transpose(), transposed_y=True, with_bias=True,
activation=R.nn.gelu
- )
- out = get_result_with_relax_cutlass_offload(mod, x, y.transpose(), bias)
- ref = build_and_run(mod, [x, y.transpose(), bias], "llvm", legalize=True)
+ mod = get_relax_matmul_module(x, y, with_bias=False,
transposed_y=transpose_y)
- tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-3)
+ tvm.ir.assert_structural_equal(mod, partition_for_cutlass(mod))
if __name__ == "__main__":