This is an automated email from the ASF dual-hosted git repository.
MasterJH5574 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new b7807dbc1b [Relax][TensorRT] Add partition_for_tensorrt and a pattern table (#19820)
b7807dbc1b is described below
commit b7807dbc1b4af7351aaa734bc7349da7afa97fb1
Author: Shushi Hong <[email protected]>
AuthorDate: Thu Jun 18 14:47:47 2026 -0400
[Relax][TensorRT] Add partition_for_tensorrt and a pattern table (#19820)
This PR is a follow-up to #19810. It adds partition_for_tensorrt, which
offloads TensorRT-supported subgraphs from a module with a single call,
together with the pattern table whose composite names ("tensorrt.<op>")
match the runtime converter registry. This mirrors the partition_for_<name>
entry point that other BYOC backends expose.
---
python/tvm/relax/backend/contrib/tensorrt.py | 140 +++++++++++++++++++++++++++
tests/python/relax/test_codegen_tensorrt.py | 33 +++++++
2 files changed, 173 insertions(+)
diff --git a/python/tvm/relax/backend/contrib/tensorrt.py b/python/tvm/relax/backend/contrib/tensorrt.py
new file mode 100644
index 0000000000..303ebc394c
--- /dev/null
+++ b/python/tvm/relax/backend/contrib/tensorrt.py
@@ -0,0 +1,140 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Pattern table and partitioning for the TensorRT BYOC backend.
+
+The composite name of each pattern is "tensorrt.<op>", matching the runtime
+converter registered under the same name (the converters are keyed by
+"tensorrt." + op_name). ``partition_for_tensorrt`` carves the matched subgraphs
+out of the module and annotates them for the ``tensorrt`` codegen.
+"""
+
+from collections.abc import Mapping
+
+from tvm.ir import IRModule
+from tvm.relax.dpl.pattern import DFPattern, is_op, wildcard
+from tvm.relax.transform import FuseOpsByPattern, MergeCompositeFunctions
+
+from ..pattern_registry import get_patterns_with_prefix, register_patterns
+
+Pattern = tuple[str, DFPattern, Mapping[str, DFPattern]]  # (composite name, op pattern, str -> DFPattern mapping)
+
+
+def _op_pattern(composite_name: str, op_name: str, num_args: int) -> Pattern:
+    """A pattern matching a single op called with ``num_args`` wildcard arguments."""
+    args = [wildcard() for _ in range(num_args)]  # one fresh wildcard per argument position
+    return (composite_name, is_op(op_name)(*args), {})
+
+
+def _tensorrt_patterns() -> list[Pattern]:
+    patterns: list[Pattern] = []
+
+    # Single-argument ops: activations, unary elementwise, layout, pooling, reduction, concat/split.
+    for composite, op in [
+        ("tensorrt.nn.relu", "relax.nn.relu"),
+        ("tensorrt.sigmoid", "relax.sigmoid"),
+        ("tensorrt.tanh", "relax.tanh"),
+        ("tensorrt.exp", "relax.exp"),
+        ("tensorrt.log", "relax.log"),
+        ("tensorrt.sqrt", "relax.sqrt"),
+        ("tensorrt.abs", "relax.abs"),
+        ("tensorrt.negative", "relax.negative"),
+        ("tensorrt.sin", "relax.sin"),
+        ("tensorrt.cos", "relax.cos"),
+        ("tensorrt.atan", "relax.atan"),
+        ("tensorrt.ceil", "relax.ceil"),
+        ("tensorrt.floor", "relax.floor"),
+        ("tensorrt.erf", "relax.erf"),
+        ("tensorrt.nn.softmax", "relax.nn.softmax"),
+        ("tensorrt.nn.batch_flatten", "relax.nn.batch_flatten"),
+        ("tensorrt.expand_dims", "relax.expand_dims"),
+        ("tensorrt.squeeze", "relax.squeeze"),
+        ("tensorrt.transpose", "relax.permute_dims"),
+        ("tensorrt.layout_transform", "relax.layout_transform"),
+        ("tensorrt.nn.max_pool2d", "relax.nn.max_pool2d"),
+        ("tensorrt.nn.avg_pool2d", "relax.nn.avg_pool2d"),
+        ("tensorrt.nn.max_pool3d", "relax.nn.max_pool3d"),
+        ("tensorrt.nn.avg_pool3d", "relax.nn.avg_pool3d"),
+        ("tensorrt.nn.adaptive_avg_pool2d", "relax.nn.adaptive_avg_pool2d"),
+        ("tensorrt.sum", "relax.sum"),
+        ("tensorrt.prod", "relax.prod"),
+        ("tensorrt.max", "relax.max"),
+        ("tensorrt.min", "relax.min"),
+        ("tensorrt.mean", "relax.mean"),
+        ("tensorrt.concatenate", "relax.concat"),
+        ("tensorrt.split", "relax.split"),
+    ]:
+        patterns.append(_op_pattern(composite, op, 1))
+
+    # Binary elementwise ops (two tensor arguments).
+    for composite, op in [
+        ("tensorrt.add", "relax.add"),
+        ("tensorrt.subtract", "relax.subtract"),
+        ("tensorrt.multiply", "relax.multiply"),
+        ("tensorrt.divide", "relax.divide"),
+        ("tensorrt.power", "relax.power"),
+        ("tensorrt.maximum", "relax.maximum"),
+        ("tensorrt.minimum", "relax.minimum"),
+    ]:
+        patterns.append(_op_pattern(composite, op, 2))
+
+    # Convolutions and matmul (data + weight); reshape is also matched with two arguments.
+    for composite, op in [
+        ("tensorrt.nn.conv1d", "relax.nn.conv1d"),
+        ("tensorrt.nn.conv2d", "relax.nn.conv2d"),
+        ("tensorrt.nn.conv3d", "relax.nn.conv3d"),
+        ("tensorrt.nn.conv2d_transpose", "relax.nn.conv2d_transpose"),
+        ("tensorrt.nn.conv3d_transpose", "relax.nn.conv3d_transpose"),
+        ("tensorrt.nn.batch_matmul", "relax.matmul"),
+        ("tensorrt.reshape", "relax.reshape"),
+    ]:
+        patterns.append(_op_pattern(composite, op, 2))
+
+    # layer_norm (data, gamma, beta) and clip (data, min, max).
+    patterns.append(_op_pattern("tensorrt.nn.layer_norm", "relax.nn.layer_norm", 3))
+    patterns.append(_op_pattern("tensorrt.clip", "relax.clip", 3))
+
+    # strided_slice is called either with or without the optional strides argument.
+    patterns.append(_op_pattern("tensorrt.strided_slice", "relax.strided_slice", 5))
+    patterns.append(_op_pattern("tensorrt.strided_slice", "relax.strided_slice", 4))
+
+    return patterns
+
+
+register_patterns(_tensorrt_patterns())  # module-level call: runs when this module is imported
+
+
+def partition_for_tensorrt(mod: IRModule) -> IRModule:
+    """Partition the module, offloading TensorRT-supported subgraphs.
+
+    Parameters
+    ----------
+    mod : tvm.ir.IRModule
+        The module to partition. Bind model parameters (e.g. via
+        ``relax.transform.BindParams``) before calling this so that weights are
+        available to TensorRT as constants.
+
+    Returns
+    -------
+    mod : tvm.ir.IRModule
+        The module with TensorRT-supported subgraphs grouped into composite
+        functions annotated for the ``tensorrt`` codegen.
+    """
+    patterns = get_patterns_with_prefix("tensorrt")  # composites in this file are named "tensorrt.<op>"
+    mod = FuseOpsByPattern(patterns, bind_constants=True, annotate_codegen=False)(mod)
+    mod = MergeCompositeFunctions()(mod)
+    return mod
diff --git a/tests/python/relax/test_codegen_tensorrt.py b/tests/python/relax/test_codegen_tensorrt.py
index 14d3394a48..1afe4cc174 100644
--- a/tests/python/relax/test_codegen_tensorrt.py
+++ b/tests/python/relax/test_codegen_tensorrt.py
@@ -596,5 +596,38 @@ def test_tensorrt_split_indices():
_offload_and_compare(SplitIdx, {}, patterns, data)
+def test_partition_for_tensorrt():
+    # End-to-end test of the partition_for_tensorrt entry point: it should offload the
+    # conv2d -> relu subgraph to TensorRT with a single call.
+    from tvm.relax.backend.contrib.tensorrt import partition_for_tensorrt
+
+    @tvm.script.ir_module
+    class Model:
+        @R.function
+        def main(
+            data: R.Tensor((1, 8, 16, 16), "float32"), weight: R.Tensor((16, 8, 3, 3), "float32")
+        ):
+            with R.dataflow():
+                conv = relax.op.nn.conv2d(data, weight, padding=1)
+                out = relax.op.nn.relu(conv)
+                R.output(out)
+            return out
+
+    data = np.random.randn(1, 8, 16, 16).astype("float32")
+    weight = np.random.randn(16, 8, 3, 3).astype("float32")
+    ref = build_and_run(Model, [data, weight], "llvm", legalize=True)  # reference: unpartitioned module, llvm target
+
+    mod = relax.transform.BindParams("main", {"weight": weight})(Model)
+    mod = partition_for_tensorrt(mod)
+    assert any(
+        isinstance(fn, relax.Function) and fn.attrs is not None and "Codegen" in fn.attrs
+        for fn in mod.functions.values()
+    ), "expected partition_for_tensorrt to offload a subgraph to TensorRT"
+
+    mod = relax.transform.RunCodegen()(mod)
+    out = build_and_run(mod, [data], "cuda")  # weight was bound via BindParams, so only data is passed
+    tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
+
+
if __name__ == "__main__":
tvm.testing.main()