This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 4e03d85514 [Unity][BYOC]Add relax backend pattern registry (#14106)
4e03d85514 is described below
commit 4e03d85514191996bb5c1fabc1c8e3463efffa2c
Author: Lite Ye <[email protected]>
AuthorDate: Fri Feb 24 04:16:47 2023 -0500
[Unity][BYOC]Add relax backend pattern registry (#14106)
* Add relax backend pattern registry
* Add doc
---
CMakeLists.txt | 1 +
python/tvm/relax/backend/__init__.py | 20 +++++
python/tvm/relax/backend/_ffi_api.py | 21 +++++
python/tvm/relax/backend/contrib/__init__.py | 20 +++++
python/tvm/relax/backend/contrib/cutlass.py | 90 +++++++++++++++++++
python/tvm/relax/backend/pattern_registry.py | 125 +++++++++++++++++++++++++++
python/tvm/relax/backend/patterns.py | 115 ++++++++++++++++++++++++
python/tvm/relax/dpl/pattern.py | 27 ++----
src/relax/backend/pattern_registry.cc | 82 ++++++++++++++++++
src/relax/backend/pattern_registry.h | 106 +++++++++++++++++++++++
tests/python/relax/test_codegen_cutlass.py | 67 +++-----------
11 files changed, 598 insertions(+), 76 deletions(-)
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 18be118832..22e82e2fb7 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -295,6 +295,7 @@ tvm_file_glob(GLOB_RECURSE COMPILER_SRCS
src/relax/transform/*.cc
src/relax/backend/vm/*.cc
src/relax/backend/task_extraction.cc
+ src/relax/backend/pattern_registry.cc
src/relax/utils.cc
)
diff --git a/python/tvm/relax/backend/__init__.py
b/python/tvm/relax/backend/__init__.py
new file mode 100644
index 0000000000..c3786591e3
--- /dev/null
+++ b/python/tvm/relax/backend/__init__.py
@@ -0,0 +1,20 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Relax backends"""
+
+from . import contrib
+from .pattern_registry import get_pattern, get_patterns_with_prefix
diff --git a/python/tvm/relax/backend/_ffi_api.py
b/python/tvm/relax/backend/_ffi_api.py
new file mode 100644
index 0000000000..d1378b2eac
--- /dev/null
+++ b/python/tvm/relax/backend/_ffi_api.py
@@ -0,0 +1,21 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""FFI API for Relax backend."""
+
+import tvm._ffi
+
+tvm._ffi._init_api("relax.backend", __name__)
diff --git a/python/tvm/relax/backend/contrib/__init__.py
b/python/tvm/relax/backend/contrib/__init__.py
new file mode 100644
index 0000000000..a094c97d24
--- /dev/null
+++ b/python/tvm/relax/backend/contrib/__init__.py
@@ -0,0 +1,20 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""External backend codegen modules for Relax."""
+
+from .cutlass import partition_for_cutlass
diff --git a/python/tvm/relax/backend/contrib/cutlass.py
b/python/tvm/relax/backend/contrib/cutlass.py
new file mode 100644
index 0000000000..20cf57a40a
--- /dev/null
+++ b/python/tvm/relax/backend/contrib/cutlass.py
@@ -0,0 +1,90 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Pattern table for CUTLASS backend"""
+
+from tvm.relax import transform
+
+from ..pattern_registry import get_patterns_with_prefix, register_patterns
+from ..patterns import make_fused_bias_activation_pattern, make_matmul_pattern
+
# Register the CUTLASS-supported fused-operator patterns at import time.
# Entries later in this list take priority when partitioning a DataflowBlock.
register_patterns(
    [
        (
            "cutlass.conv2d",
            make_fused_bias_activation_pattern(
                "relax.nn.conv2d",
                with_bias=False,
                activation=None,
            ),
        ),
        (
            "cutlass.conv2d_bias_relu",
            make_fused_bias_activation_pattern(
                "relax.nn.conv2d",
                with_bias=True,
                activation="relax.nn.relu",
            ),
        ),
        (
            "cutlass.matmul",
            make_matmul_pattern(
                with_bias=False,
            ),
        ),
        (
            "cutlass.matmul_bias",
            make_matmul_pattern(
                with_bias=True,
            ),
        ),
        (
            "cutlass.matmul_bias_relu",
            make_matmul_pattern(
                with_bias=True,
                activation="relax.nn.relu",
            ),
        ),
        (
            "cutlass.matmul_bias_gelu",
            make_matmul_pattern(
                with_bias=True,
                activation="relax.nn.gelu",
            ),
        ),
    ]
)
+
+
def partition_for_cutlass(mod):
    """
    Partition the input module into CUTLASS-supported subgraphs.

    Parameters
    ----------
    mod: tvm.IRModule
        The IRModule to be partitioned.

    Returns
    -------
    mod: tvm.IRModule
        The resulting IRModule, containing partitioned subgraphs to be
        compiled by the CUTLASS backend.
    """
    # Every registered pattern whose name starts with "cutlass" participates.
    patterns = get_patterns_with_prefix("cutlass")
    fuse_pass = transform.FuseOpsByPattern(patterns, annotate_codegen=True)
    return fuse_pass(mod)
diff --git a/python/tvm/relax/backend/pattern_registry.py
b/python/tvm/relax/backend/pattern_registry.py
new file mode 100644
index 0000000000..0016de0a50
--- /dev/null
+++ b/python/tvm/relax/backend/pattern_registry.py
@@ -0,0 +1,125 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Pattern registry for BYOC backends"""
+
+from typing import List, Mapping, Optional, Tuple, Union
+
+import tvm
+from tvm.relax.dpl import DFPattern
+from tvm.runtime import Object
+
+from . import _ffi_api
+
+
@tvm._ffi.register_object("relax.backend.PatternRegistryEntry")
class PatternRegistryEntry(Object):
    """
    An entry in the pattern registry. This represents a single pattern that
    can be used to identify expressions that can be handled by external
    backends, like CUTLASS and TensorRT.

    Parameters
    ----------
    name: str
        The name of pattern. Usually it starts with the name of backend, like
        'cutlass.matmul'.

    pattern: DFPattern
        The dataflow pattern that will be used to match expressions that can
        be handled by external backends.

    arg_patterns: Mapping[str, DFPattern]
        The mapping from arg name to its pattern. It can be used to extract
        arg expression from match result. All DFPattern in this map should be
        part of the `pattern`.
    """

    # Attributes mirror the C++ PatternRegistryEntryNode fields.
    name: str
    pattern: DFPattern
    arg_patterns: Mapping[str, DFPattern]

    def __init__(self, name: str, pattern: DFPattern, arg_patterns: Mapping[str, DFPattern]):
        # Construction is delegated to the C++ side via the FFI constructor.
        self.__init_handle_by_constructor__(
            _ffi_api.PatternRegistryEntry, name, pattern, arg_patterns  # type: ignore
        )
+
+
# Accepted input forms for `register_patterns`:
#   - a fully constructed PatternRegistryEntry,
#   - a (name, pattern) tuple, or
#   - a (name, (pattern, arg_patterns)) tuple.
Pattern = Union[
    PatternRegistryEntry,
    Tuple[str, DFPattern],
    Tuple[str, Tuple[DFPattern, Mapping[str, DFPattern]]],
]
+
+
def register_patterns(patterns: List[Pattern]):
    """
    Register patterns which will be used to partition the DataflowBlock into
    subgraphs that are supported by external backends.

    Parameters
    ----------
    patterns: List[Pattern]
        Patterns to be registered. Patterns that appear later in the list have
        higher priority when partitioning DataflowBlock.

    Raises
    ------
    TypeError
        If an item is neither a PatternRegistryEntry nor a supported tuple form.
    """
    entries = []
    for item in patterns:
        if isinstance(item, PatternRegistryEntry):
            entries.append(item)
        elif isinstance(item, tuple):
            name, pattern_or_tuple = item
            if isinstance(pattern_or_tuple, tuple):
                pattern, arg_patterns = pattern_or_tuple
            else:
                pattern, arg_patterns = pattern_or_tuple, {}
            entries.append(PatternRegistryEntry(name, pattern, arg_patterns))
        else:
            # Bug fix: the original referenced `type(pattern)` here, but
            # `pattern` is unbound on this branch (item was never unpacked),
            # turning the intended TypeError into a NameError.
            raise TypeError(f"Cannot register type {type(item)} as pattern")
    _ffi_api.RegisterPatterns(entries)
+
+
def get_patterns_with_prefix(prefix: str) -> List[PatternRegistryEntry]:
    """
    Get a list of patterns whose names start with `prefix`.

    Parameters
    ----------
    prefix: str
        The prefix of pattern name.

    Returns
    -------
    patterns: List[PatternRegistryEntry]
        Matched patterns, ordered by priority from high to low.
    """
    return _ffi_api.GetPatternsWithPrefix(prefix)
+
+
def get_pattern(name: str) -> Optional[PatternRegistryEntry]:
    """
    Look up a registered pattern by its exact name.

    Parameters
    ----------
    name: str
        The pattern name.

    Returns
    -------
    pattern: Optional[PatternRegistryEntry]
        The matched pattern, or None when no pattern with that name exists.
    """
    entry = _ffi_api.GetPattern(name)
    return entry
diff --git a/python/tvm/relax/backend/patterns.py
b/python/tvm/relax/backend/patterns.py
new file mode 100644
index 0000000000..2f744af660
--- /dev/null
+++ b/python/tvm/relax/backend/patterns.py
@@ -0,0 +1,115 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Common patterns used in BYOC"""
+
+from typing import Dict, Mapping, Tuple
+
+from tvm.relax.dpl.pattern import DFPattern, is_op, wildcard
+
+
def _with_bias_activation_pattern(
    out: DFPattern,
    args: Dict[str, DFPattern],
    with_bias: bool = False,
    activation: str = None,
) -> Tuple[DFPattern, Mapping[str, DFPattern]]:
    """Optionally wrap `out` in a bias addition and an activation op.

    When `with_bias` is set, a wildcard bias operand is recorded in `args`
    under the key "bias" (the dict is mutated in place).
    """
    result = out
    if with_bias:
        bias = wildcard()
        args["bias"] = bias
        result = is_op("relax.add")(result, bias)
    if activation:
        result = is_op(activation)(result)
    return result, args
+
+
def make_fused_bias_activation_pattern(
    op_name: str,
    with_bias: bool = False,
    activation: str = None,
) -> Tuple[DFPattern, Mapping[str, DFPattern]]:
    """
    Build a pattern for `op_name` optionally fused with bias addition and activation.

    Parameters
    ----------
    op_name: str
        The name of a Relax op, such as "relax.nn.conv2d"

    with_bias: bool
        Whether or not to include bias addition

    activation: str
        The name of an activation Relax op, such as "relax.nn.relu"

    Returns
    -------
    pattern: DFPattern
        The resulting pattern describing a fused operation

    args: Mapping[str, DFPattern]
        The mapping from arg name to its pattern, usable to extract argument
        expressions from a match result.
    """
    operands = {"lhs": wildcard(), "rhs": wildcard()}
    fused_out = is_op(op_name)(operands["lhs"], operands["rhs"])
    return _with_bias_activation_pattern(fused_out, operands, with_bias, activation)
+
+
def make_matmul_pattern(
    with_bias: bool = False,
    activation: str = None,
    transposed_rhs: bool = False,
) -> Tuple[DFPattern, Mapping[str, DFPattern]]:
    """
    Create pattern for matrix multiplication.

    Parameters
    ----------
    with_bias: bool
        Whether or not to include bias addition

    activation: str
        The name of an activation Relax op, such as "relax.nn.relu"

    transposed_rhs: bool
        Whether the right hand side of multiplication is transposed.

    Returns
    -------
    pattern: DFPattern
        The resulting pattern describing a matrix multiplication.

    args: Mapping[str, DFPattern]
        The mapping from arg name to its pattern, usable to extract argument
        expressions from a match result.
    """
    lhs = wildcard()
    rhs = wildcard()
    args = {"lhs": lhs, "rhs": rhs}

    # NOTE(review): when `transposed_rhs` is set, args["rhs"] still refers to
    # the wildcard *before* the permute_dims wrapper — presumably intentional
    # so callers extract the pre-transpose operand; verify against consumers.
    matmul_rhs = is_op("relax.permute_dims")(rhs) if transposed_rhs else rhs
    out = is_op("relax.matmul")(lhs, matmul_rhs)

    return _with_bias_activation_pattern(out, args, with_bias, activation)
diff --git a/python/tvm/relax/dpl/pattern.py b/python/tvm/relax/dpl/pattern.py
index 44faa0c93a..9e1963f7ed 100644
--- a/python/tvm/relax/dpl/pattern.py
+++ b/python/tvm/relax/dpl/pattern.py
@@ -1046,17 +1046,6 @@ def _only_used_by(
return ffi.only_used_by(lhs, rhs, index) # type: ignore
-def _add_bias_activation_pattern(out, with_bias=False, activation=None):
- if with_bias:
- bias = wildcard()
- out = is_op("relax.add")(out, bias)
-
- if activation:
- return is_op(activation)(out)
-
- return out
-
-
def make_fused_bias_activation_pattern(op_name, with_bias=False,
activation=None):
"""
A simple utility to create patterns for an operation fused with bias
addition and activation.
@@ -1081,15 +1070,11 @@ def make_fused_bias_activation_pattern(op_name,
with_bias=False, activation=None
rhs = wildcard()
out = is_op(op_name)(lhs, rhs)
- return _add_bias_activation_pattern(out, with_bias, activation)
-
+ if with_bias:
+ bias = wildcard()
+ out = is_op("relax.add")(out, bias)
-def make_matmul_pattern(with_bias=False, activation=None, transposed_b=False):
- lhs = wildcard()
- if transposed_b:
- rhs = is_op("relax.permute_dims")(wildcard())
- else:
- rhs = wildcard()
- out = is_op("relax.matmul")(lhs, rhs)
+ if activation:
+ return is_op(activation)(out)
- return _add_bias_activation_pattern(out, with_bias, activation)
+ return out
diff --git a/src/relax/backend/pattern_registry.cc
b/src/relax/backend/pattern_registry.cc
new file mode 100644
index 0000000000..3ca7973365
--- /dev/null
+++ b/src/relax/backend/pattern_registry.cc
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "./pattern_registry.h"
+
+#include "../../support/utils.h"
+
+namespace tvm {
+namespace relax {
+namespace backend {
+
// Construct a registry entry from the pattern name, the dataflow pattern, and
// the mapping from argument names to their sub-patterns.
PatternRegistryEntry::PatternRegistryEntry(String name, DFPattern pattern,
                                           Map<String, DFPattern> arg_patterns) {
  ObjectPtr<PatternRegistryEntryNode> n = make_object<PatternRegistryEntryNode>();
  n->name = std::move(name);
  n->pattern = std::move(pattern);
  n->arg_patterns = std::move(arg_patterns);
  data_ = std::move(n);
}
+
+TVM_REGISTER_NODE_TYPE(PatternRegistryEntryNode);
+
// Returns the process-wide pattern table. A function-local static is used so
// the table is lazily initialized on first use.
static std::vector<PatternRegistryEntry>* GetRegistryTable() {
  static std::vector<PatternRegistryEntry> table;
  return &table;
}
+
// Append the given entries to the global registry. Later registrations land
// later in the table; lookups scan in reverse, so they take priority.
void RegisterPatterns(Array<PatternRegistryEntry> entries) {
  auto* table = GetRegistryTable();
  for (const auto& entry : entries) {
    table->push_back(entry);
  }
}
+
// Collect all entries whose name starts with `prefix`. Iterating the table in
// reverse puts later (higher-priority) registrations first in the result.
Array<PatternRegistryEntry> GetPatternsWithPrefix(const String& prefix) {
  auto* table = GetRegistryTable();
  Array<PatternRegistryEntry> result;
  for (auto it = table->rbegin(); it != table->rend(); ++it) {
    if (support::StartsWith((*it)->name, prefix.data())) {
      result.push_back(*it);
    }
  }
  return result;
}
+
// Find the entry with an exact name match. The reverse scan returns the most
// recently registered entry when duplicates exist; NullOpt when none matches.
Optional<PatternRegistryEntry> GetPattern(const String& pattern_name) {
  auto* table = GetRegistryTable();
  for (auto it = table->rbegin(); it != table->rend(); ++it) {
    if ((*it)->name == pattern_name) {
      return *it;
    }
  }
  return NullOpt;
}
+
+TVM_REGISTER_GLOBAL("relax.backend.PatternRegistryEntry")
+ .set_body_typed([](String name, DFPattern pattern, Map<String, DFPattern>
arg_patterns) {
+ return PatternRegistryEntry(name, pattern, arg_patterns);
+ });
+TVM_REGISTER_GLOBAL("relax.backend.RegisterPatterns").set_body_typed(RegisterPatterns);
+TVM_REGISTER_GLOBAL("relax.backend.GetPatternsWithPrefix").set_body_typed(GetPatternsWithPrefix);
+TVM_REGISTER_GLOBAL("relax.backend.GetPattern").set_body_typed(GetPattern);
+
+} // namespace backend
+} // namespace relax
+} // namespace tvm
diff --git a/src/relax/backend/pattern_registry.h
b/src/relax/backend/pattern_registry.h
new file mode 100644
index 0000000000..2e199a2bb1
--- /dev/null
+++ b/src/relax/backend/pattern_registry.h
@@ -0,0 +1,106 @@
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file relax/backend/pattern_registry.h
+ * \brief Functions related to registering and retrieving patterns for
+ * functions handled by backends.
+ */
+#ifndef TVM_RELAX_BACKEND_PATTERN_REGISTRY_H_
+#define TVM_RELAX_BACKEND_PATTERN_REGISTRY_H_
+
+#include <tvm/relax/dataflow_pattern.h>
+#include <tvm/relax/expr.h>
+#include <tvm/runtime/container/optional.h>
+#include <tvm/runtime/object.h>
+
+namespace tvm {
+namespace relax {
+namespace backend {
+
+/*!
+ * \brief An entry in the pattern registry. This represents a single pattern
that
+ * can be used to identify expressions that can be handled by external
+ * backends, like CUTLASS and TensorRT.
+ */
+class PatternRegistryEntryNode : public Object {
+ public:
+ /*!
+ * \brief The name of pattern. Usually it starts with the name of backend,
like
+ * 'cutlass.matmul'.
+ */
+ String name;
+ /*!
+ * \brief The dataflow pattern that will be used to match expressions that
can
+ * be handled by external backends.
+ */
+ DFPattern pattern;
+ /*!
+ * \brief The mapping from arg name to its pattern. It can be used to extract
+ * arg expression from match result. All DFPattern in this map should be
part of
+ * the `pattern`.
+ */
+ Map<String, DFPattern> arg_patterns;
+
+ void VisitAttrs(tvm::AttrVisitor* v) {
+ v->Visit("name", &name);
+ v->Visit("pattern", &pattern);
+ v->Visit("arg_patterns", &arg_patterns);
+ }
+
+ static constexpr const char* _type_key =
"relax.backend.PatternRegistryEntry";
+ TVM_DECLARE_FINAL_OBJECT_INFO(PatternRegistryEntryNode, Object);
+};
+
+class PatternRegistryEntry : public ObjectRef {
+ public:
+ PatternRegistryEntry(String name, DFPattern pattern, Map<String, DFPattern>
arg_patterns);
+
+ TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(PatternRegistryEntry, ObjectRef,
+ PatternRegistryEntryNode);
+};
+
+/*!
+ * \brief Register patterns which will be used to partition the DataflowBlock
+ * into subgraphs that are supported by external backends.
+ * \param entries Patterns to be registered. Patterns that appear later in the list have
+ * higher priority when partitioning DataflowBlock.
+ */
+void RegisterPatterns(Array<PatternRegistryEntry> entries);
+
+/*!
+ * \brief Find patterns whose name starts with a particular prefix.
+ * \param prefix The pattern name prefix.
+ * \return Matched patterns, ordered by priority from high to low.
+ */
+Array<PatternRegistryEntry> GetPatternsWithPrefix(const String& prefix);
+
+/*!
+ * \brief Find the pattern with a particular name.
+ * \param name The pattern name.
+ * \return The matched pattern. NullOpt if not found.
+ */
+Optional<PatternRegistryEntry> GetPattern(const String& name);
+
+} // namespace backend
+} // namespace relax
+} // namespace tvm
+
+#endif // TVM_RELAX_BACKEND_PATTERN_REGISTRY_H_
diff --git a/tests/python/relax/test_codegen_cutlass.py
b/tests/python/relax/test_codegen_cutlass.py
index 5556d1e5d9..673155342c 100644
--- a/tests/python/relax/test_codegen_cutlass.py
+++ b/tests/python/relax/test_codegen_cutlass.py
@@ -23,7 +23,7 @@ import pytest
import tvm
import tvm.testing
from tvm import relax, relay
-from tvm.relax.dpl import make_fused_bias_activation_pattern,
make_matmul_pattern
+from tvm.relax.backend import get_patterns_with_prefix
from tvm.script import relax as R
@@ -219,7 +219,11 @@ cutlass_enabled = pytest.mark.skipif(
pytestmark = [cutlass_enabled]
-def get_result_with_relax_cutlass_offload(mod, patterns: List[Tuple], *args):
+def get_result_with_relax_cutlass_offload(mod, *args):
+ patterns = [(entry.name, entry.pattern) for entry in
get_patterns_with_prefix("cutlass")]
+
+ assert len(patterns) != 0, "Cannot find cutlass patterns"
+
seq = tvm.transform.Sequential(
[
relax.transform.FuseOpsByPattern(patterns, annotate_codegen=True),
@@ -243,15 +247,7 @@ def test_conv2d_offload():
weight = np.random.randn(32, 3, 3, 16).astype("float16")
bias = np.random.randn(1, 1, 1, 32).astype("float16")
- patterns = [
- (
- "cutlass.conv2d_bias_relu",
- make_fused_bias_activation_pattern(
- "relax.nn.conv2d", with_bias=True, activation="relax.nn.relu"
- ),
- )
- ]
- out = get_result_with_relax_cutlass_offload(Conv2dBiasReLU, patterns,
data, weight, bias)
+ out = get_result_with_relax_cutlass_offload(Conv2dBiasReLU, data, weight,
bias)
ref_relay_expr = get_relay_conv2d_bias_relu(data.shape, weight.shape)
ref = get_relay_ref(ref_relay_expr, data, weight, bias)
@@ -327,17 +323,8 @@ def matmul_bias(matmul_size, target_dtype):
def test_matmul_offload(matmul_x, matmul_y):
x, y = matmul_x, matmul_y
- patterns = [
- (
- "cutlass.matmul",
- make_matmul_pattern(
- with_bias=False,
- ),
- ),
- ]
-
mod = get_relax_matmul_module(x, y)
- out = get_result_with_relax_cutlass_offload(mod, patterns, x, y)
+ out = get_result_with_relax_cutlass_offload(mod, x, y)
ref_relay_expr = get_relay_matmul(x.shape, y.shape[::-1])
ref = get_relay_ref(ref_relay_expr, x, y.transpose())
@@ -347,16 +334,8 @@ def test_matmul_offload(matmul_x, matmul_y):
def test_matmul_bias_offload(matmul_x, matmul_y, matmul_bias):
x, y, bias = matmul_x, matmul_y, matmul_bias
- patterns = [
- (
- "cutlass.matmul_bias",
- make_matmul_pattern(
- with_bias=True,
- ),
- ),
- ]
mod = get_relax_matmul_module(x, y, with_bias=True)
- out = get_result_with_relax_cutlass_offload(mod, patterns, x, y, bias)
+ out = get_result_with_relax_cutlass_offload(mod, x, y, bias)
ref_relay_expr = get_relay_matmul_bias(x.shape, y.shape[::-1])
ref = get_relay_ref(ref_relay_expr, x, y.transpose(), bias)
@@ -367,17 +346,8 @@ def test_matmul_bias_offload(matmul_x, matmul_y,
matmul_bias):
def test_matmul_bias_relu_offload(matmul_x, matmul_y, matmul_bias):
x, y, bias = matmul_x, matmul_y, matmul_bias
- patterns = [
- (
- "cutlass.matmul_bias_relu",
- make_matmul_pattern(
- with_bias=True,
- activation="relax.nn.relu",
- ),
- ),
- ]
mod = get_relax_matmul_module(x, y, with_bias=True, activation=R.nn.relu)
- out = get_result_with_relax_cutlass_offload(mod, patterns, x, y, bias)
+ out = get_result_with_relax_cutlass_offload(mod, x, y, bias)
ref_relay_expr = get_relay_matmul_bias_relu(x.shape, y.shape[::-1])
ref = get_relay_ref(ref_relay_expr, x, y.transpose(), bias)
@@ -388,17 +358,8 @@ def test_matmul_bias_relu_offload(matmul_x, matmul_y,
matmul_bias):
def test_matmul_bias_gelu_offload(matmul_x, matmul_y, matmul_bias):
x, y, bias = matmul_x, matmul_y, matmul_bias
- patterns = [
- (
- "cutlass.matmul_bias_gelu",
- make_matmul_pattern(
- with_bias=True,
- activation="relax.nn.gelu",
- ),
- ),
- ]
mod = get_relax_matmul_module(x, y, with_bias=True, activation=R.nn.gelu)
- out = get_result_with_relax_cutlass_offload(mod, patterns, x, y, bias)
+ out = get_result_with_relax_cutlass_offload(mod, x, y, bias)
ref_relay_expr = get_relay_matmul_bias_gelu(x.shape, y.shape[::-1])
ref = get_relay_ref(ref_relay_expr, x, y.transpose(), bias)
@@ -411,11 +372,7 @@ def test_kernel_sharing():
weight1_np = np.random.randn(16, 3, 3, 16).astype("float16")
weight2_np = np.random.randn(16, 3, 3, 16).astype("float16")
- pat = make_fused_bias_activation_pattern("relax.nn.conv2d",
with_bias=False, activation=None)
-
- out = get_result_with_relax_cutlass_offload(
- Conv2dx2, [("cutlass.conv2d", pat)], data_np, weight1_np, weight2_np
- )
+ out = get_result_with_relax_cutlass_offload(Conv2dx2, data_np, weight1_np,
weight2_np)
relay_expr = get_relay_conv2d_relu_x2(data_np.shape, weight1_np.shape)
ref = get_relay_ref(relay_expr, data_np, weight1_np, weight2_np)