This is an automated email from the ASF dual-hosted git repository.

manupa pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 290395a  [microNPU] Refactor Relay to TIR hook (#10599)
290395a is described below

commit 290395a87be1c4be0b426928841af4cbbca069cf
Author: Luke Hutton <[email protected]>
AuthorDate: Mon Mar 21 17:19:23 2022 +0000

    [microNPU] Refactor Relay to TIR hook (#10599)
    
    * [microNPU] Refactor Relay to TIR hook
    
    Refactors the Relay to TIR python hook for the NPU so that optimizations
    can be applied across the whole module and not just functions that will
    be offloaded to the NPU. A pass `OutlineCompilerFunctions` is introduced
    to outline NPU functions, which now happens before optimization passes
    are run (this previously happened after the prim_func had been created).
    
    In addition, optimization passes that should only run on NPU functions
    are now limited to running on outlined functions for the NPU (by
    checking the "Compiler" attribute). To avoid code duplication, a
    decorator `create_npu_function_pass` has been added for Python passes
    that should only run on NPU functions.
    
    This refactor helps move a number of passes in the microNPU codegen to
    use an IRModule -> IRModule philosophy.
    
    Change-Id: Icdea9ba43da0157d5ee17529d2b23b761396d112
    
    * add mixed compilers to test
    
    Change-Id: I3ca48738e096bb0f4dc362f0e9550317fc0d5afd
    
    * Address comments including renaming both npu_pass and RelayToTIR
    
    This commit renames `npu_pass` -> `create_npu_function_pass`.
    
    It also renames the `RelayToTIR` pass created in Python to `LowerToTIR`
    and moves it to compiler.py to make it clear that this pass is a
    wrapper around the `_lower_to_tir` function. In addition, to make it
    explicit that the `lower_to_tir` func->func pass should not be used
    directly, it has been renamed to `_lower_to_tir`; it is kept because
    it is used in many tests.
    
    Change-Id: I3a0a06801f029aeaa4a51c2d86d8703bb0d7afbb
    
    * address nit and small fix to example
    
    Change-Id: I44c64de15fa8680cc89ce0440ffa6c9e0ec62a50
---
 python/tvm/relay/backend/contrib/ethosu/codegen.py |  98 +++--
 .../tvm/relay/backend/contrib/ethosu/legalize.py   | 484 ++-------------------
 .../relay/backend/contrib/ethosu/tir/compiler.py   |  38 +-
 python/tvm/relay/backend/contrib/ethosu/util.py    |  61 +++
 src/relay/backend/contrib/ethosu/codegen.cc        |  99 +++--
 tests/python/contrib/test_ethosu/test_compiler.py  |   4 +-
 .../contrib/test_ethosu/test_encode_constants.py   |  12 +-
 .../contrib/test_ethosu/test_identity_optimizer.py |  15 +-
 .../contrib/test_ethosu/test_layout_optimizer.py   |  13 +-
 tests/python/contrib/test_ethosu/test_legalize.py  |  28 +-
 .../contrib/test_ethosu/test_lut_optimizer.py      |  10 +-
 .../test_ethosu/test_outline_compiler_functions.py |  86 ++++
 .../test_ethosu/test_remove_concatenates.py        |   4 +-
 .../test_ethosu/test_replace_binary_elementwise.py |   6 +-
 .../contrib/test_ethosu/test_replace_conv2d.py     |  12 +-
 .../contrib/test_ethosu/test_replace_copy.py       |   6 +-
 .../test_ethosu/test_replace_depthwise_conv2d.py   |   4 +-
 .../contrib/test_ethosu/test_replace_identity.py   |   4 +-
 .../contrib/test_ethosu/test_replace_pooling.py    |   6 +-
 .../test_ethosu/test_replace_unary_elementwise.py  |   4 +-
 tests/python/contrib/test_ethosu/test_scheduler.py |   4 +-
 21 files changed, 405 insertions(+), 593 deletions(-)

diff --git a/python/tvm/relay/backend/contrib/ethosu/codegen.py 
b/python/tvm/relay/backend/contrib/ethosu/codegen.py
index f968d6a..e8b5cc2 100644
--- a/python/tvm/relay/backend/contrib/ethosu/codegen.py
+++ b/python/tvm/relay/backend/contrib/ethosu/codegen.py
@@ -19,8 +19,7 @@ from collections import defaultdict
 
 import tvm
 from tvm import relay
-from tvm import ir
-from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_tir
+from tvm.relay.backend.contrib.ethosu.tir.compiler import LowerToTIR
 from tvm.relay.backend.contrib.ethosu.tir.scheduler import copy_constants
 from tvm.relay.backend.contrib.ethosu.legalize import LegalizeEthosU
 from tvm.relay.backend.contrib.ethosu import tir_to_cs_translator
@@ -112,30 +111,24 @@ class OptimizeLUTs(ExprMutator):
         return new_call
 
 
[email protected]_pass(opt_level=1, name="LUTsOptimizer")
[email protected]_npu_function_pass(opt_level=1)
 class LUTsOptimizer:
     """Register LUTsOptimizer as a relay pass."""
 
-    def transform_module(self, mod: tvm.ir.IRModule, _) -> tvm.IRModule:
-        """Visit relay nodes in the given module.
+    def transform_npu_function(self, _, func: relay.Function) -> 
relay.Function:
+        """Visit relay nodes in the given NPU function.
 
         Parameters
         ----------
         func : tvm.relay.function.Function
             The function to apply the optimization pass for multiple LUTs to.
-        mod : tvm.IRModule
-            The module to apply the optimization pass for multiple LUTs to.
 
         Returns
         -------
         mod : tvm.IRModule
             New module with optimized LUTs.
         """
-        assert len(mod.functions.items()) == 1, "Module can only contain one 
function."
-        global_var, func = mod.functions.items()[0]
-        optimized_func = OptimizeLUTs().visit(func)
-        mod.update_func(global_var, optimized_func)
-        return mod
+        return OptimizeLUTs().visit(func)
 
     def __call__(self, *args, **kwargs):
         pass
@@ -272,30 +265,27 @@ class LayoutOptimization(ExprMutator):
         return super().visit_call(call)
 
 
[email protected]_pass(opt_level=1, name="LayoutOptimizer")
[email protected]_npu_function_pass(opt_level=1)
 class LayoutOptimizer:
     """Register LayoutOptimizer as a Relay pass."""
 
-    OPTIMIZE_OPS = {
-        "contrib.ethosu.conv2d": op.ethosu_conv2d,
-        "contrib.ethosu.depthwise_conv2d": op.ethosu_depthwise_conv2d,
-        "contrib.ethosu.pooling": op.ethosu_pooling,
-        "contrib.ethosu.binary_elementwise": op.ethosu_binary_elementwise,
-        "contrib.ethosu.unary_elementwise": op.ethosu_unary_elementwise,
-    }
-
-    def transform_module(self, mod: tvm.ir.IRModule, _) -> tvm.IRModule:
+    def transform_npu_function(self, _, func: relay.Function) -> 
relay.Function:
         """A pass to optimize the layout of NPU operations. If both the
         producer and consumer of a tensor are NPU operators, then the
         layout is converted from NHWC to NHCWB16 as this is the layout NPU
         uses internally."""
-        assert len(mod.functions.items()) == 1, "Module can only contain one 
function."
-        global_var, func = mod.functions.items()[0]
-        analyze = AnalyzeConsumers(self.OPTIMIZE_OPS)
+
+        optimize_ops = {
+            "contrib.ethosu.conv2d": op.ethosu_conv2d,
+            "contrib.ethosu.depthwise_conv2d": op.ethosu_depthwise_conv2d,
+            "contrib.ethosu.pooling": op.ethosu_pooling,
+            "contrib.ethosu.binary_elementwise": op.ethosu_binary_elementwise,
+            "contrib.ethosu.unary_elementwise": op.ethosu_unary_elementwise,
+        }
+
+        analyze = AnalyzeConsumers(optimize_ops)
         analyze.visit(func)
-        optimized_func = LayoutOptimization(analyze.npu_consumers, 
self.OPTIMIZE_OPS).visit(func)
-        mod.update_func(global_var, optimized_func)
-        return mod
+        return LayoutOptimization(analyze.npu_consumers, 
optimize_ops).visit(func)
 
     def __call__(self, *args, **kwargs):
         pass
@@ -312,6 +302,22 @@ def IdentityOptimizer():  # pylint: disable=invalid-name
     return _ffi_api.IdentityOptimizer()
 
 
+def OutlineCompilerFunctions(compiler_name):  # pylint: disable=invalid-name
+    """Pass that outlines functions given a named Compiler attribute.
+
+    Parameters
+    ----------
+    compiler_name
+        The name of the compiler to look for and outline.
+
+    Return
+    ------
+    Pass
+        The module pass.
+    """
+    return _ffi_api.OutlineCompilerFunctions(compiler_name)
+
+
 @tvm._ffi.register_func("relay.ext.ethos-u.constant_updater")
 def constant_updater(expr, symbol):  # pylint: disable=unused-argument
     """
@@ -322,43 +328,41 @@ def constant_updater(expr, symbol):  # pylint: 
disable=unused-argument
     return dict()
 
 
-@tvm._ffi.register_func("relay.ext.ethos-u.relay_to_tir_func")
-def relay_to_tir_func(ext_func: relay.Function) -> tvm.tir.PrimFunc:
+@tvm._ffi.register_func("relay.ext.ethos-u.relay_to_tir")
+def relay_to_tir(mod: tvm.ir.IRModule) -> tvm.ir.IRModule:
     """
-    This is the hook for python-based lowering of relay function
-    that gets offloaded to the microNPU.
+    This is the hook for python-based lowering of a Relay module which lowers 
NPU
+    external functions to TIR.
 
     Parameters
     ----------
-    ext_func : relay.Function
-        This is the partitioned relay function
+    mod : tvm.ir.IRModule
+        This is the Relay module.
 
     Returns
     -------
-    primfunc : tir.PrimFunc
-        This returns the scheduled PrimFunc
+    mod : tvm.ir.IRModule
+        The Relay module with scheduled NPU external functions.
     """
-    assert len(ext_func.params) == 1
-    mod = tvm.IRModule()
-    mod["main"] = ext_func
+    mod = OutlineCompilerFunctions("ethos-u")(mod)
     mod = LegalizeEthosU()(mod)
     mod = LUTsOptimizer()(mod)
     mod = IdentityOptimizer()(mod)
     mod = LayoutOptimizer()(mod)
     mod = relay.transform.InferType()(mod)
+
+    device_contexts = {
+        gv: "ethos-u" for gv, _ in filter(lambda x: util.is_npu_func(x[1]), 
mod.functions.items())
+    }
+    mod = mod.with_attr("device_contexts", device_contexts)
+
     # We are currently using copy_constants scheduler In the long run,
     # this should be a single intelligent and a composite scheduler
     # that can perform scheduling based on user inputs such as
     # scratch memory size.
-    tir_mod, const_dict = lower_to_tir(mod["main"], copy_constants())
-
-    for param in const_dict.keys():
-        const_dict[param] = tvm.nd.array(const_dict[param])
+    mod = LowerToTIR(copy_constants)(mod)
 
-    primfunc = tir_mod["main"]
-    primfunc = primfunc.with_attr("global_symbol", 
ext_func.attrs["global_symbol"])
-    primfunc = primfunc.with_attr("ethos-u.constants", const_dict)
-    return primfunc
+    return mod
 
 
 @tvm._ffi.register_func("relay.ext.ethos-u.primfunc_to_artifact")
diff --git a/python/tvm/relay/backend/contrib/ethosu/legalize.py 
b/python/tvm/relay/backend/contrib/ethosu/legalize.py
index 3fdcdb6..6f37b90 100644
--- a/python/tvm/relay/backend/contrib/ethosu/legalize.py
+++ b/python/tvm/relay/backend/contrib/ethosu/legalize.py
@@ -23,7 +23,6 @@ import numpy as np  # type: ignore
 
 import tvm  # type: ignore
 from tvm import relay
-from tvm import ir
 from tvm.relay.dataflow_pattern import DFPatternCallback  # type: ignore
 from tvm.relay.dataflow_pattern import wildcard
 from tvm.relay.dataflow_pattern import is_op
@@ -127,23 +126,6 @@ class PartitionedSplitRewriter(DFPatternCallback):
         return relay.op.split(split_input, indices_or_sections, 
axis=axis).astuple()
 
 
[email protected]_pass(opt_level=1)
-class LegalizeSplit:
-    """This is the pass that wraps SplitRewriter"""
-
-    def transform_module(
-        self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
-    ) -> tvm.ir.IRModule:
-        for global_var, func in mod.functions.items():
-            func = rewrite(PartitionedSplitRewriter(), func)
-            func = rewrite(SplitRewriter(), func)
-            mod.update_func(global_var, func)
-        return mod
-
-    def __call__(self, *args, **kwargs):
-        pass
-
-
 def get_lut_from_func(
     ifm_scale: float,
     ifm_zp: int,
@@ -244,22 +226,6 @@ class TanhRewriter(LutActivationRewriter):
         )
 
 
[email protected]_pass(opt_level=1)
-class LegalizeTanh:
-    """This is the pass that wraps TanhRewriter"""
-
-    def transform_module(
-        self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
-    ) -> tvm.ir.IRModule:
-        for global_var, func in mod.functions.items():
-            func = rewrite(TanhRewriter(), func)
-            mod.update_func(global_var, func)
-        return mod
-
-    def __call__(self, *args, **kwargs):
-        pass
-
-
 def sigmoid_calc_func(x: float) -> float:
     """Function to calculate the values for sigmoid"""
     # These limits are inherited from TFLite
@@ -286,22 +252,6 @@ class SigmoidRewriter(LutActivationRewriter):
         )
 
 
[email protected]_pass(opt_level=1)
-class LegalizeSigmoid:
-    """This is the pass that wraps SigmoidRewriter"""
-
-    def transform_module(
-        self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
-    ) -> tvm.ir.IRModule:
-        for global_var, func in mod.functions.items():
-            func = rewrite(SigmoidRewriter(), func)
-            mod.update_func(global_var, func)
-        return mod
-
-    def __call__(self, *args, **kwargs):
-        pass
-
-
 def leaky_relu_calc_func(x: float, alpha: float) -> float:
     """Function to calculate the values for leaky relu."""
     return x if x >= 0 else x * alpha
@@ -322,22 +272,6 @@ class LeakyReLURewriter(LutActivationRewriter):
         return {"alpha": params.alpha}
 
 
[email protected]_pass(opt_level=1)
-class LegalizeLeakyReLU:
-    """This is the pass that wraps LeakyReLURewriter."""
-
-    def transform_module(
-        self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
-    ) -> tvm.ir.IRModule:
-        for global_var, func in mod.functions.items():
-            func = rewrite(LeakyReLURewriter(), func)
-            mod.update_func(global_var, func)
-        return mod
-
-    def __call__(self, *args, **kwargs):
-        pass
-
-
 class Conv2DRewriter(DFPatternCallback):
     """Convert conv2d related composite functions into ethosu_conv2d 
operators"""
 
@@ -405,22 +339,6 @@ class Conv2DRewriter(DFPatternCallback):
         return ethosu_conv2d
 
 
[email protected]_pass(opt_level=1)
-class LegalizeConv2D:
-    """This is the pass that wraps the Conv2DRewriter"""
-
-    def transform_module(
-        self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
-    ) -> tvm.ir.IRModule:
-        for global_var, func in mod.functions.items():
-            func = rewrite(Conv2DRewriter(), func)
-            mod.update_func(global_var, func)
-        return mod
-
-    def __call__(self, *args, **kwargs):
-        pass
-
-
 class Conv2DTransposeRewriter(DFPatternCallback):
     """Convert conv2d_transpose related composite functions into
     ethosu_conv2d_transpose operators."""
@@ -486,22 +404,6 @@ class Conv2DTransposeRewriter(DFPatternCallback):
         return relay.strided_slice(reduced_op, (0, 0, 0, 0), ofm_shape)
 
 
[email protected]_pass(opt_level=1)
-class LegalizeConv2DTranspose:
-    """This is the pass that wraps the Conv2DTransposeRewriter"""
-
-    def transform_module(
-        self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
-    ) -> tvm.ir.IRModule:
-        for global_var, func in mod.functions.items():
-            func = rewrite(Conv2DTransposeRewriter(), func)
-            mod.update_func(global_var, func)
-        return mod
-
-    def __call__(self, *args, **kwargs):
-        pass
-
-
 class DepthwiseConv2DRewriter(DFPatternCallback):
     """Convert ethosu.qnn_depthwise_conv2d composite functions to 
ethosu_depthwise_conv2d
     operators"""
@@ -576,22 +478,6 @@ class DepthwiseConv2DRewriter(DFPatternCallback):
         return ethosu_depthwise_conv2d
 
 
[email protected]_pass(opt_level=1)
-class LegalizeDepthwiseConv2D:
-    """This is the pass that wraps the DepthwiseConv2DRewriter"""
-
-    def transform_module(
-        self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
-    ) -> tvm.ir.IRModule:
-        for global_var, func in mod.functions.items():
-            func = rewrite(DepthwiseConv2DRewriter(), func)
-            mod.update_func(global_var, func)
-        return mod
-
-    def __call__(self, *args, **kwargs):
-        pass
-
-
 class PoolingRewriter(DFPatternCallback):
     """Convert ethosu.avgpool2d and ethosu.maxpool2d composite functions to
     ethosu_pooling operators"""
@@ -658,22 +544,6 @@ class MaxPoolingRewriter(PoolingRewriter):
         )
 
 
[email protected]_pass(opt_level=1)
-class LegalizeMaxPooling:
-    """This is the pass that wraps the MaxPoolingRewriter"""
-
-    def transform_module(
-        self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
-    ) -> tvm.ir.IRModule:
-        for global_var, func in mod.functions.items():
-            func = rewrite(MaxPoolingRewriter(), func)
-            mod.update_func(global_var, func)
-        return mod
-
-    def __call__(self, *args, **kwargs):
-        pass
-
-
 class AvgPoolingRewriter(PoolingRewriter):
     def __init__(self):
         super().__init__(
@@ -684,22 +554,6 @@ class AvgPoolingRewriter(PoolingRewriter):
         )
 
 
[email protected]_pass(opt_level=1)
-class LegalizeAvgPooling:
-    """This is the pass that wraps the AvgPoolingRewriter"""
-
-    def transform_module(
-        self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
-    ) -> tvm.ir.IRModule:
-        for global_var, func in mod.functions.items():
-            func = rewrite(AvgPoolingRewriter(), func)
-            mod.update_func(global_var, func)
-        return mod
-
-    def __call__(self, *args, **kwargs):
-        pass
-
-
 class BinaryElementwiseRewriter(DFPatternCallback):
     """Convert ethosu binary elementwise composite functions to
     ethosu_binary_elementwise operators"""
@@ -826,22 +680,6 @@ class AddRewriter(BinaryElementwiseRewriter):
         )
 
 
[email protected]_pass(opt_level=1)
-class LegalizeAdd:
-    """This is the pass that wraps the AddRewriter"""
-
-    def transform_module(
-        self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
-    ) -> tvm.ir.IRModule:
-        for global_var, func in mod.functions.items():
-            func = rewrite(AddRewriter(), func)
-            mod.update_func(global_var, func)
-        return mod
-
-    def __call__(self, *args, **kwargs):
-        pass
-
-
 class SubRewriter(BinaryElementwiseRewriter):
     def __init__(self):
         super().__init__(
@@ -852,22 +690,6 @@ class SubRewriter(BinaryElementwiseRewriter):
         )
 
 
[email protected]_pass(opt_level=1)
-class LegalizeSub:
-    """This is the pass that wraps the SubRewriter"""
-
-    def transform_module(
-        self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
-    ) -> tvm.ir.IRModule:
-        for global_var, func in mod.functions.items():
-            func = rewrite(SubRewriter(), func)
-            mod.update_func(global_var, func)
-        return mod
-
-    def __call__(self, *args, **kwargs):
-        pass
-
-
 class MulRewriter(BinaryElementwiseRewriter):
     def __init__(self):
         super().__init__(
@@ -878,22 +700,6 @@ class MulRewriter(BinaryElementwiseRewriter):
         )
 
 
[email protected]_pass(opt_level=1)
-class LegalizeMul:
-    """This is the pass that wraps the MulRewriter"""
-
-    def transform_module(
-        self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
-    ) -> tvm.ir.IRModule:
-        for global_var, func in mod.functions.items():
-            func = rewrite(MulRewriter(), func)
-            mod.update_func(global_var, func)
-        return mod
-
-    def __call__(self, *args, **kwargs):
-        pass
-
-
 class MinRewriter(BinaryElementwiseRewriter):
     def __init__(self):
         super().__init__(
@@ -904,22 +710,6 @@ class MinRewriter(BinaryElementwiseRewriter):
         )
 
 
[email protected]_pass(opt_level=1)
-class LegalizeMin:
-    """This is the pass that wraps the MinRewriter"""
-
-    def transform_module(
-        self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
-    ) -> tvm.ir.IRModule:
-        for global_var, func in mod.functions.items():
-            func = rewrite(MinRewriter(), func)
-            mod.update_func(global_var, func)
-        return mod
-
-    def __call__(self, *args, **kwargs):
-        pass
-
-
 class MaxRewriter(BinaryElementwiseRewriter):
     def __init__(self):
         super().__init__(
@@ -930,22 +720,6 @@ class MaxRewriter(BinaryElementwiseRewriter):
         )
 
 
[email protected]_pass(opt_level=1)
-class LegalizeMax:
-    """This is the pass that wraps the MaxRewriter"""
-
-    def transform_module(
-        self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
-    ) -> tvm.ir.IRModule:
-        for global_var, func in mod.functions.items():
-            func = rewrite(MaxRewriter(), func)
-            mod.update_func(global_var, func)
-        return mod
-
-    def __call__(self, *args, **kwargs):
-        pass
-
-
 class ShlRewriter(BinaryElementwiseRewriter):
     def __init__(self):
         super().__init__(
@@ -956,22 +730,6 @@ class ShlRewriter(BinaryElementwiseRewriter):
         )
 
 
[email protected]_pass(opt_level=1)
-class LegalizeShl:
-    """This is the pass that wraps the ShlRewriter"""
-
-    def transform_module(
-        self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
-    ) -> tvm.ir.IRModule:
-        for global_var, func in mod.functions.items():
-            func = rewrite(ShlRewriter(), func)
-            mod.update_func(global_var, func)
-        return mod
-
-    def __call__(self, *args, **kwargs):
-        pass
-
-
 class StridedSliceRewriter(DFPatternCallback):
     """This pass brings the strided slice out of the partitioned function"""
 
@@ -1005,22 +763,6 @@ class StridedSliceRewriter(DFPatternCallback):
         return strided_slice
 
 
[email protected]_pass(opt_level=1)
-class LegalizeStridedSlice:
-    """This is the pass that wraps StridedSliceRewriter"""
-
-    def transform_module(
-        self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
-    ) -> tvm.ir.IRModule:
-        for global_var, func in mod.functions.items():
-            func = rewrite(StridedSliceRewriter(), func)
-            mod.update_func(global_var, func)
-        return mod
-
-    def __call__(self, *args, **kwargs):
-        pass
-
-
 class ReshapeRewriter(DFPatternCallback):
     """This pass brings the reshape out of the partitioned function"""
 
@@ -1039,22 +781,6 @@ class ReshapeRewriter(DFPatternCallback):
         return relay.op.reshape(reshape_input, newshape=new_shape)
 
 
[email protected]_pass(opt_level=1)
-class LegalizeReshape:
-    """This is the pass that wraps ReshapeRewriter"""
-
-    def transform_module(
-        self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
-    ) -> tvm.ir.IRModule:
-        for global_var, func in mod.functions.items():
-            func = rewrite(ReshapeRewriter(), func)
-            mod.update_func(global_var, func)
-        return mod
-
-    def __call__(self, *args, **kwargs):
-        pass
-
-
 class NoOpRewriter(DFPatternCallback):
     """This pass adds an idenity operator to reshape and strided slice to 
avoid a no op
     without a consumer"""
@@ -1073,22 +799,6 @@ class NoOpRewriter(DFPatternCallback):
         return ethosu_ops.ethosu_identity(ifm=post, lut=relay.const([], 
dtype="int8"))
 
 
[email protected]_pass(opt_level=1)
-class LegalizeNoOps:
-    """This is the pass that wraps RewriteNoOps"""
-
-    def transform_module(
-        self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
-    ) -> tvm.ir.IRModule:
-        for global_var, func in mod.functions.items():
-            func = rewrite(NoOpRewriter(), func)
-            mod.update_func(global_var, func)
-        return mod
-
-    def __call__(self, *args, **kwargs):
-        pass
-
-
 class UnaryElementwiseRewriter(DFPatternCallback):
     """
     Convert ethosu unary elementwise composite function to
@@ -1160,22 +870,6 @@ class AbsRewriter(UnaryElementwiseRewriter):
         )
 
 
[email protected]_pass(opt_level=1)
-class LegalizeAbs:
-    """This is the pass that wraps the AbsRewriter"""
-
-    def transform_module(
-        self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
-    ) -> tvm.ir.IRModule:
-        for global_var, func in mod.functions.items():
-            func = rewrite(AbsRewriter(), func)
-            mod.update_func(global_var, func)
-        return mod
-
-    def __call__(self, *args, **kwargs):
-        pass
-
-
 class MeanRewriter(DFPatternCallback):
     """Convert ethosu.mean composite functions to to an equivalent 
legalization:
     - Case 1 (axis == [1, 2] and keepsdims == True):
@@ -1324,22 +1018,6 @@ class MeanRewriter(DFPatternCallback):
         return reduced_op
 
 
[email protected]_pass(opt_level=1)
-class LegalizeMean:
-    """This is the pass that wraps the MeanRewriter"""
-
-    def transform_module(
-        self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
-    ) -> tvm.ir.IRModule:
-        for global_var, func in mod.functions.items():
-            func = rewrite(MeanRewriter(), func)
-            mod.update_func(global_var, func)
-        return mod
-
-    def __call__(self, *args, **kwargs):
-        pass
-
-
 class ConcatRewriter(DFPatternCallback):
     """The newer versions of TFLite converters return a concatenate operator 
that concatenates
     tensors with same QNN params (if the QNN params of tensors were initially 
different,
@@ -1366,22 +1044,6 @@ class ConcatRewriter(DFPatternCallback):
         return concat
 
 
[email protected]_pass(opt_level=1)
-class LegalizeConcat:
-    """This is the pass that wraps ConcatRewriter"""
-
-    def transform_module(
-        self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
-    ) -> tvm.ir.IRModule:
-        for global_var, func in mod.functions.items():
-            func = rewrite(ConcatRewriter(), func)
-            mod.update_func(global_var, func)
-        return mod
-
-    def __call__(self, *args, **kwargs):
-        pass
-
-
 class RequantizeRewriter(DFPatternCallback):
     """Convert ethos-u.requantize composite function to an identity 
operation."""
 
@@ -1409,22 +1071,6 @@ class RequantizeRewriter(DFPatternCallback):
         )
 
 
[email protected]_pass(opt_level=1)
-class LegalizeRequantize:
-    """This is the pass that wraps RequantizeRewriter."""
-
-    def transform_module(
-        self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
-    ) -> tvm.ir.IRModule:
-        for global_var, func in mod.functions.items():
-            func = rewrite(RequantizeRewriter(), func)
-            mod.update_func(global_var, func)
-        return mod
-
-    def __call__(self, *args, **kwargs):
-        pass
-
-
 class Resize2dRewriter(DFPatternCallback):
     """
     Convert ethos-u.resize2d composite function to an equivalent operation that
@@ -1504,22 +1150,6 @@ class Resize2dRewriter(DFPatternCallback):
         return total_padding
 
 
[email protected]_pass(opt_level=1)
-class LegalizeResize2d:
-    """This is the pass that wraps Resize2dRewriter"""
-
-    def transform_module(
-        self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
-    ) -> tvm.ir.IRModule:
-        for global_var, func in mod.functions.items():
-            func = rewrite(Resize2dRewriter(), func)
-            mod.update_func(global_var, func)
-        return mod
-
-    def __call__(self, *args, **kwargs):
-        pass
-
-
 class ExpandDimsRewriter(DFPatternCallback):
     """Legalize expand dims to a reshape operator."""
 
@@ -1536,22 +1166,6 @@ class ExpandDimsRewriter(DFPatternCallback):
         return relay.op.reshape(post.args[0], newshape=params.output.shape)
 
 
[email protected]_pass(opt_level=1)
-class LegalizeExpandDims:
-    """This is the pass that wraps ExpandDimsRewriter."""
-
-    def transform_module(
-        self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
-    ) -> tvm.ir.IRModule:
-        for global_var, func in mod.functions.items():
-            func = rewrite(ExpandDimsRewriter(), func)
-            mod.update_func(global_var, func)
-        return mod
-
-    def __call__(self, *args, **kwargs):
-        pass
-
-
 class SqueezeRewriter(DFPatternCallback):
     """Legalize squeeze to a reshape operator."""
 
@@ -1568,22 +1182,6 @@ class SqueezeRewriter(DFPatternCallback):
         return relay.op.reshape(post.args[0], newshape=params.output.shape)
 
 
[email protected]_pass(opt_level=1)
-class LegalizeSqueeze:
-    """This is the pass that wraps SqueezeRewriter."""
-
-    def transform_module(
-        self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
-    ) -> tvm.ir.IRModule:
-        for global_var, func in mod.functions.items():
-            func = rewrite(SqueezeRewriter(), func)
-            mod.update_func(global_var, func)
-        return mod
-
-    def __call__(self, *args, **kwargs):
-        pass
-
-
 class FullyConnectedRewriter(DFPatternCallback):
     """Legalize Fully Connected (with bias and clip) to an NPU operator"""
 
@@ -1654,62 +1252,50 @@ class FullyConnectedRewriter(DFPatternCallback):
         return ethosu_fc
 
 
[email protected]_pass(opt_level=1)
-class LegalizeFullyConnected:
-    """This is the pass that wraps the FullyConnectedRewriter"""
-
-    def transform_module(
-        self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
-    ) -> tvm.ir.IRModule:
-        for global_var, func in mod.functions.items():
-            func = rewrite(FullyConnectedRewriter(), func)
-            mod.update_func(global_var, func)
-        return mod
-
-    def __call__(self, *args, **kwargs):
-        pass
-
-
[email protected]_pass(opt_level=1)
[email protected]_npu_function_pass(opt_level=1)
 class LegalizeEthosU:
     """This is the pass to call graph-rewrites to perform graph transformation
     in a way such that the operations are replaced with hardware/codegen 
supported
     operations.
     """
 
-    def transform_module(
-        self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
-    ) -> tvm.ir.IRModule:
+    def transform_npu_function(self, _, func: relay.Function) -> 
relay.Function:
         """This is the method that replaces the operations with 
hardware/codegen supported
         operations.
         """
-        mod = LegalizeSplit()(mod)
-        mod = LegalizeConv2D()(mod)
-        mod = LegalizeConv2DTranspose()(mod)
-        mod = LegalizeDepthwiseConv2D()(mod)
-        mod = LegalizeMaxPooling()(mod)
-        mod = LegalizeAvgPooling()(mod)
-        mod = LegalizeAdd()(mod)
-        mod = LegalizeSub()(mod)
-        mod = LegalizeMul()(mod)
-        mod = LegalizeMin()(mod)
-        mod = LegalizeMax()(mod)
-        mod = LegalizeShl()(mod)
-        mod = LegalizeAbs()(mod)
-        mod = LegalizeTanh()(mod)
-        mod = LegalizeLeakyReLU()(mod)
-        mod = LegalizeMean()(mod)
-        mod = LegalizeConcat()(mod)
-        mod = LegalizeSigmoid()(mod)
-        mod = LegalizeRequantize()(mod)
-        mod = LegalizeResize2d()(mod)
-        mod = LegalizeExpandDims()(mod)
-        mod = LegalizeSqueeze()(mod)
-        mod = LegalizeReshape()(mod)
-        mod = LegalizeStridedSlice()(mod)
-        mod = LegalizeFullyConnected()(mod)
-        mod = LegalizeNoOps()(mod)
-        return mod
+        rewriters = [
+            PartitionedSplitRewriter(),
+            SplitRewriter(),
+            Conv2DRewriter(),
+            Conv2DTransposeRewriter(),
+            DepthwiseConv2DRewriter(),
+            FullyConnectedRewriter(),
+            MaxPoolingRewriter(),
+            AvgPoolingRewriter(),
+            AddRewriter(),
+            SubRewriter(),
+            MulRewriter(),
+            MinRewriter(),
+            MaxRewriter(),
+            ShlRewriter(),
+            AbsRewriter(),
+            TanhRewriter(),
+            LeakyReLURewriter(),
+            MeanRewriter(),
+            ConcatRewriter(),
+            SigmoidRewriter(),
+            RequantizeRewriter(),
+            Resize2dRewriter(),
+            ExpandDimsRewriter(),
+            SqueezeRewriter(),
+            ReshapeRewriter(),
+            StridedSliceRewriter(),
+            NoOpRewriter(),
+        ]
+        for rewriter in rewriters:
+            func = rewrite(rewriter, func)
+
+        return func
 
     def __call__(self, *args, **kwargs):
         # pylint is unable figure out the decorated
diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py 
b/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py
index bdc3b31..aa15d91 100644
--- a/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py
+++ b/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py
@@ -23,6 +23,7 @@ from tvm.driver.build_module import schedule_to_module
 
 from . import passes as ethosu_passes
 from .scheduler import schedule
+from .. import util
 
 
 def lower_ethosu(sch, args, const_dict, name="main"):
@@ -172,7 +173,42 @@ def extract_constants(func):
     return new_func, const_dict
 
 
-def lower_to_tir(func, cascader=None):
[email protected]_npu_function_pass(opt_level=1)
+class LowerToTIR:
+    """A pass that lowers NPU Relay functions to TIR. This pass wraps
+    the _lower_to_tir pass that operates function->function, while this
+    is IRModule->IRModule.
+
+    Attributes
+    ----------
+    scheduler : callable
+        A function to schedule NPU operations. For example,
+        scheduler.py/copy_constants.
+    """
+
+    def __init__(self, scheduler):
+        self.scheduler = scheduler
+
+    def transform_npu_function(self, _, func: relay.Function) -> 
relay.Function:
+        """Lower NPU functions to TIR."""
+
+        tir_mod, const_dict = _lower_to_tir(func, self.scheduler())
+
+        for param in const_dict.keys():
+            const_dict[param] = tvm.nd.array(const_dict[param])
+
+        compiler_name = "ethos-u"
+        primfunc = tir_mod["main"]
+        primfunc = primfunc.with_attr("global_symbol", 
func.attrs["global_symbol"])
+        primfunc = primfunc.with_attr("ethos-u.constants", const_dict)
+        primfunc = primfunc.with_attr("target", 
tvm.target.Target(compiler_name))
+        return primfunc
+
+    def __call__(self, *args, **kwargs):
+        pass
+
+
+def _lower_to_tir(func, cascader=None):
     """Lower a Relay function to TIR for the Arm(R) Ethos(TM)-U NPU target.
 
     The Relay function should only contain operations supported
diff --git a/python/tvm/relay/backend/contrib/ethosu/util.py 
b/python/tvm/relay/backend/contrib/ethosu/util.py
index dffc237..64c561e 100644
--- a/python/tvm/relay/backend/contrib/ethosu/util.py
+++ b/python/tvm/relay/backend/contrib/ethosu/util.py
@@ -143,6 +143,11 @@ class QDenseArgs(Enum):
     WEIGHTS_SCALE = 5
 
 
+def is_npu_func(func: relay.Function) -> bool:
+    """Check if the given function is an NPU function."""
+    return func.attrs and "Compiler" in func.attrs and func.attrs["Compiler"] 
== "ethos-u"
+
+
 def is_composite_func(func: relay.Function, name: str) -> bool:
     """
     This method checks whether the call is to
@@ -313,3 +318,59 @@ class CompilationArtifact(Object):
             encoded_constants,
             base_addresses,
         )
+
+
+def create_npu_function_pass(opt_level: int, name: str = ""):
+    """
+    A utility decorator that wraps a given class as an NPU function pass. That is,
+    a pass that behaves like a function pass and only traverses NPU external
+    functions. How each NPU function is mutated is defined by the
+    `transform_npu_function(global_variable, relay_function)` function which should
+    be created in the class that is to be decorated. See the example below.
+
+    Example
+    -------
+    This small example demonstrates a pass over NPU functions that performs no
+    mutation.
+
+    @create_npu_function_pass(opt_level=1)
+    class MyPass:
+        def transform_npu_function(self, global_var, func):
+            return func
+
+    mod = tvm.IRModule()
+    mod = MyPass()(mod)
+
+    Parameters
+    ----------
+    opt_level: int
+        Optimization level for the module pass.
+    name: str, optional
+        Name for the module pass.
+
+    Returns
+    -------
+    decorator
+        A class decorator that turns the decorated class into an NPU function pass.
+    """
+
+    def decorator(npu_pass_class):
+        @tvm.ir.transform.module_pass(name=name, opt_level=opt_level)
+        class ModulePassWrapper:
+            """The wrapper for the NPU pass."""
+
+            def __init__(self, *args, **kwargs):
+                self.args = args
+                self.kwargs = kwargs
+
+            def transform_module(self, mod: tvm.ir.IRModule, _) -> 
tvm.ir.IRModule:
+                npu_functions = filter(lambda x: is_npu_func(x[1]), 
mod.functions.items())
+                for global_var, func in npu_functions:
+                    npu_pass = npu_pass_class(*self.args, **self.kwargs)
+                    func = npu_pass.transform_npu_function(global_var, func)
+                    mod.update_func(global_var, func)
+                return mod
+
+        return ModulePassWrapper
+
+    return decorator
diff --git a/src/relay/backend/contrib/ethosu/codegen.cc 
b/src/relay/backend/contrib/ethosu/codegen.cc
index ca41ccd..7044669 100644
--- a/src/relay/backend/contrib/ethosu/codegen.cc
+++ b/src/relay/backend/contrib/ethosu/codegen.cc
@@ -48,59 +48,63 @@ namespace contrib {
 namespace ethosu {
 
 /*!
- * \brief This mutator lowers each external
- * relay function to a TIR PrimFunc
+ * \brief This mutator outlines functions that are marked with a named
+ * "Compiler" attribute. Functions that do not match this condition remain
+ * unaltered.
  */
-class RelayToTIRMutator : public MixedModeMutator {
+class OutlineCompilerFunctionsMutator : public MixedModeMutator {
  public:
-  explicit RelayToTIRMutator(IRModule ir_module) : ir_module_(ir_module) {}
-
-  IRModule operator()() {
-    GlobalVar main_global_var = ir_module_->GetGlobalVar("main");
-    Function main = Downcast<Function>(ir_module_->Lookup(main_global_var));
-    Function mutated_main = WithFields(main, main->params, 
VisitExpr(main->body));
-
-    ir_module_->Update(main_global_var, mutated_main);
-    ir_module_ = WithAttr(ir_module_, "device_contexts", device_contexts_);
-    return ir_module_;
-  }
+  explicit OutlineCompilerFunctionsMutator(const IRModule& mod, const 
std::string& compiler_name)
+      : mod_(mod), compiler_name_(compiler_name) {}
 
   Expr Rewrite_(const CallNode* pre, const Expr& post) override {
     Call call = Downcast<Call>(post);
     if (call->op->IsInstance<FunctionNode>()) {
       Function func = Downcast<Function>(call->op);
-      auto codegen_name = func->GetAttr<String>(attr::kCompiler);
-      if (codegen_name.defined() && codegen_name == "ethos-u") {
-        auto relay_to_tir_func_pf =
-            tvm::runtime::Registry::Get("relay.ext.ethos-u.relay_to_tir_func");
-        ICHECK(relay_to_tir_func_pf);
-        tir::PrimFunc prim_func = (*relay_to_tir_func_pf)(func);
-        prim_func = WithAttr(prim_func, tvm::attr::kTarget, Target("ethos-u"));
-        String symbol_name = 
prim_func->GetAttr<String>(tvm::attr::kGlobalSymbol).value();
-        GlobalVar gv(symbol_name);
-        Array<RelayExpr> args = call->args;
-        gv->checked_type_ = func->checked_type();
-        ir_module_->Update(gv, prim_func);
-        device_contexts_.Set(gv, codegen_name.value());
-        return Call(gv, args, call->attrs, call->type_args);
+      auto compiler = func->GetAttr<String>(attr::kCompiler);
+      if (compiler.defined() && compiler == compiler_name_) {
+        auto gv_name = func->GetAttr<String>("global_symbol").value_or("");
+        ICHECK_NE(gv_name, "")
+            << "Function to be outlined must have global_symbol attribute, but 
didn't.";
+        GlobalVar gv(gv_name);
+        if (func->checked_type_.defined()) {
+          gv->checked_type_ = func->checked_type();
+        }
+        mod_->Update(gv, func);
+        return Call(gv, call->args, call->attrs, call->type_args);
       }
     }
     return post;
   }
 
  private:
-  IRModule ir_module_;
-  Map<GlobalVar, String> device_contexts_;
+  IRModule mod_;
+  std::string compiler_name_;
 };
 
-tvm::transform::Pass RelayToTIR() {
+/*!
+ * \brief A pass to outline compiler specific functions.
+ */
+tvm::transform::Pass OutlineCompilerFunctions(const std::string& 
compiler_name) {
   runtime::TypedPackedFunc<IRModule(IRModule, transform::PassContext)> 
pass_func =
-      [=](IRModule ir_module, transform::PassContext pass_context) {
-        return RelayToTIRMutator(ir_module)();
+      [=](IRModule mod, transform::PassContext ctx) {
+        GlobalVar gv = mod->GetGlobalVar("main");
+        Function main_func = Downcast<Function>(mod->Lookup("main"));
+        auto new_main_body =
+            OutlineCompilerFunctionsMutator(mod, 
compiler_name).VisitExpr(main_func->body);
+        if (!new_main_body.same_as(main_func->body)) {
+          Function new_main_func = WithFields(main_func, main_func->params, 
new_main_body);
+          mod->Update(gv, new_main_func);
+        }
+        return mod;
       };
-  return tvm::transform::CreateModulePass(pass_func, 0, 
"relay.contrib.ethos-u.RelayToTIR", {});
+  return tvm::transform::CreateModulePass(
+      pass_func, 0, "relay.backend.contrib.ethos-u.OutlineCompilerFunctions", 
{});
 }
 
+TVM_REGISTER_GLOBAL("relay.ext.ethos-u.OutlineCompilerFunctions")
+    .set_body_typed(OutlineCompilerFunctions);
+
 /*!
  * \brief This mutator removes identity operations that are not necessary. 
Specifically, an
  * identity operation can be removed when it is immediately followed by an NPU 
compute
@@ -161,11 +165,14 @@ tvm::transform::Pass IdentityOptimizer() {
   runtime::TypedPackedFunc<IRModule(IRModule, transform::PassContext)> 
pass_func =
       [=](IRModule mod, transform::PassContext ctx) {
         for (auto gv : mod->GetGlobalVars()) {
-          Function main_func = Downcast<Function>(mod->Lookup(gv));
-          auto new_main_body = 
RemoveRedundantIdentities().VisitExpr(main_func->body);
-          if (!new_main_body.same_as(main_func->body)) {
-            Function new_main_func = WithFields(main_func, main_func->params, 
new_main_body);
-            mod->Update(gv, new_main_func);
+          Function func = Downcast<Function>(mod->Lookup(gv));
+          auto compiler_name = func->GetAttr<String>(attr::kCompiler);
+          if (compiler_name.defined() && compiler_name == "ethos-u") {
+            auto new_body = RemoveRedundantIdentities().VisitExpr(func->body);
+            if (!new_body.same_as(func->body)) {
+              Function new_func = WithFields(func, func->params, new_body);
+              mod->Update(gv, new_func);
+            }
           }
         }
         return mod;
@@ -177,6 +184,20 @@ tvm::transform::Pass IdentityOptimizer() {
 
TVM_REGISTER_GLOBAL("relay.ext.ethos-u.IdentityOptimizer").set_body_typed(IdentityOptimizer);
 
 /*!
+ * \brief This pass will lower NPU functions in a Relay module to scheduled TIR prim functions.
+ */
+tvm::transform::Pass RelayToTIR() {
+  runtime::TypedPackedFunc<IRModule(IRModule, transform::PassContext)> 
pass_func =
+      [=](IRModule ir_module, transform::PassContext pass_context) {
+        auto relay_to_tir_pf = 
tvm::runtime::Registry::Get("relay.ext.ethos-u.relay_to_tir");
+        ICHECK(relay_to_tir_pf);
+        ir_module = (*relay_to_tir_pf)(ir_module);
+        return ir_module;
+      };
+  return tvm::transform::CreateModulePass(pass_func, 0, 
"relay.contrib.ethos-u.RelayToTIR", {});
+}
+
+/*!
  * \brief This function lowers the IRModule with PrimFunc
  * with the target of the microNPU to a C-source runtime module
  */
diff --git a/tests/python/contrib/test_ethosu/test_compiler.py 
b/tests/python/contrib/test_ethosu/test_compiler.py
index 0e31be8..5da9163 100644
--- a/tests/python/contrib/test_ethosu/test_compiler.py
+++ b/tests/python/contrib/test_ethosu/test_compiler.py
@@ -19,7 +19,7 @@ import pytest
 pytest.importorskip("ethosu.vela")
 import tvm
 from tvm import relay
-from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_tir
+from tvm.relay.backend.contrib.ethosu.tir.compiler import _lower_to_tir
 from . import infra
 
 
@@ -57,7 +57,7 @@ def test_lower_to_tir_arg_count(relay_function, arg_count):
     mod = tvm.IRModule()
     mod["main"] = relay_function()
     mod = relay.transform.InferType()(mod)
-    tir_mod = lower_to_tir(mod["main"])[0]
+    tir_mod = _lower_to_tir(mod["main"])[0]
     primfunc = tir_mod["main"]
     assert len(primfunc.params) == arg_count
 
diff --git a/tests/python/contrib/test_ethosu/test_encode_constants.py 
b/tests/python/contrib/test_ethosu/test_encode_constants.py
index 8878e46..760f375 100644
--- a/tests/python/contrib/test_ethosu/test_encode_constants.py
+++ b/tests/python/contrib/test_ethosu/test_encode_constants.py
@@ -22,7 +22,7 @@ import tvm
 from tvm import relay
 from tvm.script import tir as T
 from tvm.relay.testing import run_opt_pass
-from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_tir
+from tvm.relay.backend.contrib.ethosu.tir.compiler import _lower_to_tir
 from tvm.relay.backend.contrib.ethosu.tir.scheduler import Convolution2DCompute
 from tvm.relay.backend.contrib.ethosu.tir.scheduler import copy_constants
 from tvm.relay.backend.contrib.ethosu import tir_to_cs_translator
@@ -96,7 +96,7 @@ def test_weight_stream_only():
         return func
 
     func = _get_func()
-    mod, consts = lower_to_tir(func, cascader=_planner)
+    mod, consts = _lower_to_tir(func, cascader=_planner)
     script = mod.script(show_meta=True)
     test_mod = tvm.script.from_source(script)
     reference_mod = WeightStreamOnly
@@ -159,7 +159,7 @@ def test_re_read_weights():
         return func
 
     func = _get_func()
-    mod, consts = lower_to_tir(func, cascader=_cascader)
+    mod, consts = _lower_to_tir(func, cascader=_cascader)
     script = mod.script(show_meta=True)
     test_mod = tvm.script.from_source(script)
     reference_mod = RereadWeights
@@ -217,7 +217,7 @@ def test_direct_read_only():
         return func
 
     func = _get_func()
-    mod, consts = lower_to_tir(func)
+    mod, consts = _lower_to_tir(func)
 
     script = mod.script(show_meta=True)
     test_mod = tvm.script.from_source(script)
@@ -306,7 +306,7 @@ def test_mixed_read():
         return func
 
     func = _get_func()
-    mod, consts = lower_to_tir(func, cascader=_planner)
+    mod, consts = _lower_to_tir(func, cascader=_planner)
 
     script = mod.script(show_meta=True)
     test_mod = tvm.script.from_source(script)
@@ -353,7 +353,7 @@ def test_constant_as_input():
         func = run_opt_pass(func, relay.transform.InferType())
         return func
 
-    tir_mod, params = lower_to_tir(get_graph(), copy_constants())
+    tir_mod, params = _lower_to_tir(get_graph(), copy_constants())
 
     # Check tile address for the scalar constant input hasn't been
     # overwritten.
diff --git a/tests/python/contrib/test_ethosu/test_identity_optimizer.py 
b/tests/python/contrib/test_ethosu/test_identity_optimizer.py
index 833b8d0..a2bb4f4 100644
--- a/tests/python/contrib/test_ethosu/test_identity_optimizer.py
+++ b/tests/python/contrib/test_ethosu/test_identity_optimizer.py
@@ -28,21 +28,22 @@ import tensorflow as tf
 import tvm
 from tvm import relay
 from tvm.relay.op.contrib.ethosu import partition_for_ethosu
-from tvm.relay.backend.contrib.ethosu.codegen import relay_to_tir_func
+from tvm.relay.backend.contrib.ethosu.codegen import relay_to_tir
 from tvm.relay.backend.contrib.ethosu.codegen import IdentityOptimizer
 
 from . import infra
 from .test_codegen import _compare_tvm_with_tflite
 
 
-def _optimize(expr, optimize=True):
+def _optimize(func, optimize=True):
     """Create IRModule and run identity optimizer pass."""
-    mod = tvm.IRModule.from_expr(expr)
+    func = func.with_attr("Compiler", "ethos-u")
+    mod = tvm.IRModule.from_expr(func)
     mod = relay.transform.InferType()(mod)
     if optimize:
         mod = IdentityOptimizer()(mod)
     entry = mod["main"]
-    return entry if isinstance(expr, relay.Function) else entry.body
+    return entry if isinstance(func, relay.Function) else entry.body
 
 
 def _assert_structural_equal(a, b):
@@ -266,7 +267,7 @@ def test_identity_single_removal_on_binary_elementwise():
     _assert_structural_equal(actual, expected)
 
 
-def test_layout_optimizer_runs_in_compilation_pipeline():
+def test_identity_optimizer_runs_in_compilation_pipeline():
     """Checks that the identity optimization pass is run as part of the NPU 
compilation pipeline."""
 
     def get_graph():
@@ -278,10 +279,10 @@ def test_layout_optimizer_runs_in_compilation_pipeline():
 
     mod = get_graph()
     mod = partition_for_ethosu(mod)
+    mod = relay_to_tir(mod)
 
     external_gv_name = mod["main"].body.op.name_hint
-    external_func = mod[external_gv_name]
-    prim_func = relay_to_tir_func(external_func)
+    prim_func = mod[external_gv_name]
 
     # Check for hints in the TIR prim func that the identity optimization pass
     # has ran. There should not be an identity in the prim func.
diff --git a/tests/python/contrib/test_ethosu/test_layout_optimizer.py 
b/tests/python/contrib/test_ethosu/test_layout_optimizer.py
index 9199cdd..a2161c7 100644
--- a/tests/python/contrib/test_ethosu/test_layout_optimizer.py
+++ b/tests/python/contrib/test_ethosu/test_layout_optimizer.py
@@ -33,19 +33,20 @@ import tvm
 from tvm import relay
 from tvm.relay.op.contrib.ethosu import partition_for_ethosu
 from tvm.relay.backend.contrib.ethosu.codegen import LayoutOptimizer
-from tvm.relay.backend.contrib.ethosu.codegen import relay_to_tir_func
+from tvm.relay.backend.contrib.ethosu.codegen import relay_to_tir
 
 from . import infra
 
 
-def _optimize(expr, optimize=True):
+def _optimize(func, optimize=True):
     """Create IRModule and run layout optimizer pass."""
-    mod = tvm.IRModule.from_expr(expr)
+    func = func.with_attr("Compiler", "ethos-u")
+    mod = tvm.IRModule.from_expr(func)
     mod = relay.transform.InferType()(mod)
     if optimize:
         mod = LayoutOptimizer()(mod)
     entry = mod["main"]
-    return entry if isinstance(expr, relay.Function) else entry.body
+    return entry if isinstance(func, relay.Function) else entry.body
 
 
 def _assert_structural_equal(a, b):
@@ -721,10 +722,10 @@ def test_layout_optimizer_runs_in_compilation_pipeline():
 
     mod = get_graph()
     mod = partition_for_ethosu(mod)
+    mod = relay_to_tir(mod)
 
     external_gv_name = mod["main"].body.op.name_hint
-    external_func = mod[external_gv_name]
-    prim_func = relay_to_tir_func(external_func)
+    prim_func = mod[external_gv_name]
 
     # Check for hints in the TIR prim func that the layout optimization pass 
has ran
     ops = prim_func.body.body.seq
diff --git a/tests/python/contrib/test_ethosu/test_legalize.py 
b/tests/python/contrib/test_ethosu/test_legalize.py
index 710c3e8..32cf2c1 100644
--- a/tests/python/contrib/test_ethosu/test_legalize.py
+++ b/tests/python/contrib/test_ethosu/test_legalize.py
@@ -102,15 +102,21 @@ def test_split_indices_legalize():
         """
         return tvm.parser.fromtext(expected_ir_string)
 
+    rewrite_split = [legalize.PartitionedSplitRewriter(), 
legalize.SplitRewriter()]
+
     mod_axis1 = tvm.IRModule()
-    mod_axis1["tvmgen_default_ethos_u_main_0"] = create_graph(1)
-    mod_axis1 = legalize.LegalizeSplit()(mod_axis1)
+    func = create_graph(1)
+    for r in rewrite_split:
+        func = dataflow_pattern.rewrite(r, func)
+    mod_axis1["tvmgen_default_ethos_u_main_0"] = func
     expected_axis1 = expected_mod_axis1()
     tvm.ir.assert_structural_equal(mod_axis1, expected_axis1)
 
     mod_axis2 = tvm.IRModule()
-    mod_axis2["tvmgen_default_ethos_u_main_0"] = create_graph(2)
-    mod_axis2 = legalize.LegalizeSplit()(mod_axis2)
+    func = create_graph(2)
+    for r in rewrite_split:
+        func = dataflow_pattern.rewrite(r, func)
+    mod_axis2["tvmgen_default_ethos_u_main_0"] = func
     expected_axis2 = expected_mod_axis2()
     tvm.ir.assert_structural_equal(mod_axis2, expected_axis2)
 
@@ -198,15 +204,21 @@ def test_split_sections_legalize():
         """
         return tvm.parser.fromtext(expected_ir_string)
 
+    rewrite_split = [legalize.PartitionedSplitRewriter(), 
legalize.SplitRewriter()]
+
     mod_axis1 = tvm.IRModule()
-    mod_axis1["tvmgen_default_ethos_u_main_0"] = create_graph(1, 5)
-    mod_axis1 = legalize.LegalizeSplit()(mod_axis1)
+    func = create_graph(1, 5)
+    for r in rewrite_split:
+        func = dataflow_pattern.rewrite(r, func)
+    mod_axis1["tvmgen_default_ethos_u_main_0"] = func
     expected_axis1 = expected_mod_axis1()
     tvm.ir.assert_structural_equal(mod_axis1, expected_axis1)
 
     mod_axis2 = tvm.IRModule()
-    mod_axis2["tvmgen_default_ethos_u_main_0"] = create_graph(2, 5)
-    mod_axis2 = legalize.LegalizeSplit()(mod_axis2)
+    func = create_graph(2, 5)
+    for r in rewrite_split:
+        func = dataflow_pattern.rewrite(r, func)
+    mod_axis2["tvmgen_default_ethos_u_main_0"] = func
     expected_axis2 = expected_mod_axis2()
     tvm.ir.assert_structural_equal(mod_axis2, expected_axis2)
 
diff --git a/tests/python/contrib/test_ethosu/test_lut_optimizer.py 
b/tests/python/contrib/test_ethosu/test_lut_optimizer.py
index d9a543c..db2a1d5 100644
--- a/tests/python/contrib/test_ethosu/test_lut_optimizer.py
+++ b/tests/python/contrib/test_ethosu/test_lut_optimizer.py
@@ -27,7 +27,7 @@ import numpy as np
 import tvm
 from tvm import relay
 from tvm.relay.backend.contrib.ethosu.codegen import LUTsOptimizer
-from tvm.relay.backend.contrib.ethosu.codegen import relay_to_tir_func
+from tvm.relay.backend.contrib.ethosu.codegen import relay_to_tir
 from tvm.relay.op.contrib.ethosu import partition_for_ethosu
 
 from .test_codegen import _get_tflite_graph
@@ -49,6 +49,7 @@ def test_merge_lut_into_conv():
         id2 = infra.make_ethosu_identity(conv2, lut=lut2, activation="SIGMOID")
 
         func = relay.Function(relay.analysis.free_vars(id2), id2)
+        func = func.with_attr("Compiler", "ethos-u")
         mod = tvm.IRModule.from_expr(func)
         return mod
 
@@ -61,6 +62,7 @@ def test_merge_lut_into_conv():
         )
 
         func = relay.Function(relay.analysis.free_vars(conv2), conv2)
+        func = func.with_attr("Compiler", "ethos-u")
         mod = tvm.IRModule.from_expr(func)
         mod = relay.transform.InferType()(mod)
         return mod
@@ -84,6 +86,7 @@ def test_multiple_luts():
         id2 = infra.make_ethosu_identity(id1, lut=lut2, activation="TANH")
 
         func = relay.Function(relay.analysis.free_vars(id2), id2)
+        func = func.with_attr("Compiler", "ethos-u")
         mod = tvm.IRModule.from_expr(func)
         return mod
 
@@ -94,6 +97,7 @@ def test_multiple_luts():
         id2 = infra.make_ethosu_identity(conv1, lut=lut2, activation="TANH")
 
         func = relay.Function(relay.analysis.free_vars(id2), id2)
+        func = func.with_attr("Compiler", "ethos-u")
         mod = tvm.IRModule.from_expr(func)
         mod = relay.transform.InferType()(mod)
         return mod
@@ -119,10 +123,10 @@ def test_lut_optimizer_runs_in_compilation_pipeline():
 
     mod, _ = _get_tflite_graph(get_graph, [ifm_shape])
     mod = partition_for_ethosu(mod)
+    mod = relay_to_tir(mod)
 
     external_gv_name = mod["main"].body.op.name_hint
-    external_func = mod[external_gv_name]
-    prim_func = relay_to_tir_func(external_func)
+    prim_func = mod[external_gv_name]
 
     # Check for hints in the TIR prim func that the LUT optimization pass has 
ran.
     # If the module was optimized, there should be no identity operations.
diff --git 
a/tests/python/contrib/test_ethosu/test_outline_compiler_functions.py 
b/tests/python/contrib/test_ethosu/test_outline_compiler_functions.py
new file mode 100644
index 0000000..91458f6
--- /dev/null
+++ b/tests/python/contrib/test_ethosu/test_outline_compiler_functions.py
@@ -0,0 +1,86 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+Test the outline compiler functions pass.
+"""
+
+import pytest
+
+pytest.importorskip("ethosu.vela")
+
+import tvm
+from tvm import relay
+from tvm.relay.backend.contrib.ethosu.codegen import OutlineCompilerFunctions
+
+
+def test_outline_compiler_functions():
+    compiler_name = "my-compiler"
+    wrong_compiler_name = "wrong-compiler"
+
+    def before():
+        inp = relay.var("input")
+
+        # Inlined function for "my-compiler"
+        x = relay.var("x", shape=(1, 2, 2, 4))
+        x = relay.reshape(x, newshape=(1, 4, 4))
+        x = relay.Function(relay.analysis.free_vars(x), x)
+        x = x.with_attr("Compiler", compiler_name)
+        x = x.with_attr("global_symbol", "ext_func")
+
+        # Inlined function for "wrong-compiler"
+        y = relay.var("y", shape=(1, 4, 4))
+        y = relay.reshape(y, newshape=(1, 16))
+        y = relay.Function(relay.analysis.free_vars(y), y)
+        y = y.with_attr("Compiler", wrong_compiler_name)
+        y = y.with_attr("global_symbol", "ext_func_2")
+
+        out = relay.Call(x, [inp])
+        out = relay.Call(y, [out])
+        out = relay.Function([inp], out)
+        return tvm.ir.IRModule.from_expr(out)
+
+    def expected():
+        mod = tvm.ir.IRModule()
+
+        inp = relay.var("input")
+
+        x = relay.var("x", shape=(1, 2, 2, 4))
+        x = relay.reshape(x, newshape=(1, 4, 4))
+        x = relay.Function(relay.analysis.free_vars(x), x)
+        x = x.with_attr("Compiler", compiler_name)
+        x = x.with_attr("global_symbol", "ext_func")
+        mod["ext_func"] = x
+
+        y = relay.var("y", shape=(1, 4, 4))
+        y = relay.reshape(y, newshape=(1, 16))
+        y = relay.Function(relay.analysis.free_vars(y), y)
+        y = y.with_attr("Compiler", wrong_compiler_name)
+        y = y.with_attr("global_symbol", "ext_func_2")
+
+        out = relay.Call(mod.get_global_var("ext_func"), [inp])
+        out = relay.Call(y, [out])
+        mod["main"] = relay.Function([inp], out)
+        return mod
+
+    after = OutlineCompilerFunctions(compiler_name)(before())
+    exp = expected()
+
+    global_vars = [str(gv) for gv in after.get_global_vars()]
+    assert "@ext_func" in global_vars
+    assert "@ext_func_2" not in global_vars
+    assert tvm.ir.structural_equal(after["ext_func"], exp["ext_func"])
diff --git a/tests/python/contrib/test_ethosu/test_remove_concatenates.py 
b/tests/python/contrib/test_ethosu/test_remove_concatenates.py
index f82351c..355b756 100644
--- a/tests/python/contrib/test_ethosu/test_remove_concatenates.py
+++ b/tests/python/contrib/test_ethosu/test_remove_concatenates.py
@@ -22,7 +22,7 @@ import tvm.script
 from tvm.script import tir as T
 from tvm import relay
 from tvm.relay.testing import run_opt_pass
-from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_tir
+from tvm.relay.backend.contrib.ethosu.tir.compiler import _lower_to_tir
 from .infra import make_ethosu_conv2d
 
 
@@ -69,7 +69,7 @@ def test_concat():
         return func
 
     func = _get_func()
-    mod, _ = lower_to_tir(func)
+    mod, _ = _lower_to_tir(func)
     script = mod.script(show_meta=True)
     test_mod = tvm.script.from_source(script)
 
diff --git 
a/tests/python/contrib/test_ethosu/test_replace_binary_elementwise.py 
b/tests/python/contrib/test_ethosu/test_replace_binary_elementwise.py
index 7d40054..b518f51 100644
--- a/tests/python/contrib/test_ethosu/test_replace_binary_elementwise.py
+++ b/tests/python/contrib/test_ethosu/test_replace_binary_elementwise.py
@@ -22,7 +22,7 @@ import tvm
 from tvm import relay
 from tvm.relay.testing import run_opt_pass
 from tvm.relay.backend.contrib.ethosu.tir import spec
-from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_tir
+from tvm.relay.backend.contrib.ethosu.tir.compiler import _lower_to_tir
 from .infra import make_ethosu_binary_elementwise, get_binary_elementwise_args
 
 
@@ -71,7 +71,7 @@ def test_binary_elementwise_single(
     )
     func = relay.Function(relay.analysis.free_vars(binary_elementwise), 
binary_elementwise)
     func = run_opt_pass(func, relay.transform.InferType())
-    mod, _ = lower_to_tir(func)
+    mod, _ = _lower_to_tir(func)
     data = []
 
     def _visit(stmt):
@@ -227,7 +227,7 @@ def test_shift_binary_elementwise_single(
     )
     func = relay.Function(relay.analysis.free_vars(binary_elementwise), 
binary_elementwise)
     func = run_opt_pass(func, relay.transform.InferType())
-    mod, _ = lower_to_tir(func)
+    mod, _ = _lower_to_tir(func)
     data = []
 
     def _visit(stmt):
diff --git a/tests/python/contrib/test_ethosu/test_replace_conv2d.py 
b/tests/python/contrib/test_ethosu/test_replace_conv2d.py
index 5a9aa98..b51c932 100644
--- a/tests/python/contrib/test_ethosu/test_replace_conv2d.py
+++ b/tests/python/contrib/test_ethosu/test_replace_conv2d.py
@@ -21,7 +21,7 @@ import tvm
 from tvm.script import tir as T
 from tvm import relay
 from tvm.relay.testing import run_opt_pass
-from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_tir
+from tvm.relay.backend.contrib.ethosu.tir.compiler import _lower_to_tir
 from tvm.relay.backend.contrib.ethosu.tir.scheduler import total_cascader
 from .infra import make_ethosu_conv2d, get_convolutional_args
 
@@ -316,7 +316,7 @@ def test_conv2d_single(trial):
         [(1, 2, 12, 9, 16), 182, 67, (1, 3), (6, 3), (2, 2), (1, 1), "CLIP", 
"NHCWB16", "NHCWB16"],
     ]
     func = _get_func(*trial)
-    mod, _ = lower_to_tir(func)
+    mod, _ = _lower_to_tir(func)
     data = []
 
     def _visit(stmt):
@@ -593,7 +593,7 @@ def test_conv2d_double_cascade(trial):
     reference_mod = trial[0]
     params = trial[1:]
     func = _get_func(*params[:-1])
-    mod, _ = lower_to_tir(func, cascader=total_cascader(params[-1]))
+    mod, _ = _lower_to_tir(func, cascader=total_cascader(params[-1]))
     script = mod.script(show_meta=True)
     mod = tvm.script.from_source(script)
     tvm.ir.assert_structural_equal(mod["main"], reference_mod["main"], True)
@@ -652,7 +652,7 @@ def test_conv2d_inline_copy(trial):
     reference_mod = trial[0]
     params = trial[1:]
     func = _get_func(*params)
-    mod, _ = lower_to_tir(func)
+    mod, _ = _lower_to_tir(func)
     script = mod.script(show_meta=True)
     mod = tvm.script.from_source(script)
     tvm.ir.assert_structural_equal(mod["main"], reference_mod["main"], True)
@@ -755,7 +755,7 @@ def test_conv2d_inline_reshape(trial):
     reference_mod = trial[0]
     params = trial[1:]
     func = _get_func(*params)
-    mod, _ = lower_to_tir(func, cascader=total_cascader((1, 4, 6, 16)))
+    mod, _ = _lower_to_tir(func, cascader=total_cascader((1, 4, 6, 16)))
     script = mod.script(show_meta=True)
     mod = tvm.script.from_source(script)
     tvm.ir.assert_structural_equal(mod["main"], reference_mod["main"], True)
@@ -775,7 +775,7 @@ def test_conv2d_big_pad():
         return func
 
     func = _get_func()
-    mod, _ = lower_to_tir(func, cascader=total_cascader((1, 4, 4, 16)))
+    mod, _ = _lower_to_tir(func, cascader=total_cascader((1, 4, 4, 16)))
 
 
 if __name__ == "__main__":
diff --git a/tests/python/contrib/test_ethosu/test_replace_copy.py 
b/tests/python/contrib/test_ethosu/test_replace_copy.py
index 4bfbae5..92b2940 100644
--- a/tests/python/contrib/test_ethosu/test_replace_copy.py
+++ b/tests/python/contrib/test_ethosu/test_replace_copy.py
@@ -21,7 +21,7 @@ import tvm
 from tvm.script import tir as T
 from tvm import relay
 from tvm.relay.testing import run_opt_pass
-from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_tir
+from tvm.relay.backend.contrib.ethosu.tir.compiler import _lower_to_tir
 from tvm.relay.backend.contrib.ethosu.tir.scheduler import copy_constants, 
Convolution2DCompute
 
 from .infra import make_ethosu_conv2d
@@ -65,7 +65,7 @@ def test_copy():
         return func
 
     func = _get_func()
-    mod, _ = lower_to_tir(func, cascader=copy_constants())
+    mod, _ = _lower_to_tir(func, cascader=copy_constants())
 
     script = mod.script(show_meta=True)
     test_mod = tvm.script.from_source(script)
@@ -129,7 +129,7 @@ def test_weight_stream():
         return func
 
     func = _get_func()
-    mod, _ = lower_to_tir(func, cascader=_cascader)
+    mod, _ = _lower_to_tir(func, cascader=_cascader)
 
     script = mod.script(show_meta=True)
     test_mod = tvm.script.from_source(script)
diff --git a/tests/python/contrib/test_ethosu/test_replace_depthwise_conv2d.py b/tests/python/contrib/test_ethosu/test_replace_depthwise_conv2d.py
index edbfb49..fe11a0f 100644
--- a/tests/python/contrib/test_ethosu/test_replace_depthwise_conv2d.py
+++ b/tests/python/contrib/test_ethosu/test_replace_depthwise_conv2d.py
@@ -22,7 +22,7 @@ pytest.importorskip("ethosu.vela")
 import tvm
 from tvm import relay
 from tvm.relay.testing import run_opt_pass
-from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_tir
+from tvm.relay.backend.contrib.ethosu.tir.compiler import _lower_to_tir
 from .infra import make_ethosu_depthwise_conv2d, get_convolutional_args
 
 
@@ -108,7 +108,7 @@ def test_depthwise_conv2d_single(trial):
         return func
 
     func = _get_func(*trial)
-    mod, _ = lower_to_tir(func)
+    mod, _ = _lower_to_tir(func)
     data = []
 
     def _visit(stmt):
diff --git a/tests/python/contrib/test_ethosu/test_replace_identity.py b/tests/python/contrib/test_ethosu/test_replace_identity.py
index 1ce55c4..e53230c 100644
--- a/tests/python/contrib/test_ethosu/test_replace_identity.py
+++ b/tests/python/contrib/test_ethosu/test_replace_identity.py
@@ -22,7 +22,7 @@ import tvm
 from tvm import relay
 from tvm.relay.testing import run_opt_pass
 from tvm.relay.backend.contrib.ethosu.tir import spec
-from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_tir
+from tvm.relay.backend.contrib.ethosu.tir.compiler import _lower_to_tir
 from .infra import make_ethosu_identity, get_pooling_args
 
 
@@ -33,7 +33,7 @@ def test_identity(ifm_shape):
 
     func = relay.Function(relay.analysis.free_vars(identity), identity)
     func = run_opt_pass(func, relay.transform.InferType())
-    mod, _ = lower_to_tir(func)
+    mod, _ = _lower_to_tir(func)
     data = []
 
     def _visit(stmt):
diff --git a/tests/python/contrib/test_ethosu/test_replace_pooling.py b/tests/python/contrib/test_ethosu/test_replace_pooling.py
index c535498..0680f0c 100644
--- a/tests/python/contrib/test_ethosu/test_replace_pooling.py
+++ b/tests/python/contrib/test_ethosu/test_replace_pooling.py
@@ -22,7 +22,7 @@ import tvm
 from tvm import relay
 from tvm.relay.testing import run_opt_pass
 from tvm.relay.backend.contrib.ethosu.tir import spec
-from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_tir
+from tvm.relay.backend.contrib.ethosu.tir.compiler import _lower_to_tir
 from .infra import make_ethosu_pooling, get_pooling_args
 
 
@@ -181,7 +181,7 @@ def test_pooling_single(
     )
     func = relay.Function(relay.analysis.free_vars(pooling), pooling)
     func = run_opt_pass(func, relay.transform.InferType())
-    mod, _ = lower_to_tir(func)
+    mod, _ = _lower_to_tir(func)
     data = []
 
     def _visit(stmt):
@@ -241,7 +241,7 @@ def test_correct_stride_with_multiple_pooling():
     )
     func = relay.Function(relay.analysis.free_vars(op), op)
     func = run_opt_pass(func, relay.transform.InferType())
-    mod, _ = lower_to_tir(func)
+    mod, _ = _lower_to_tir(func)
 
     data = []
 
diff --git a/tests/python/contrib/test_ethosu/test_replace_unary_elementwise.py b/tests/python/contrib/test_ethosu/test_replace_unary_elementwise.py
index 498609f..6240b54 100644
--- a/tests/python/contrib/test_ethosu/test_replace_unary_elementwise.py
+++ b/tests/python/contrib/test_ethosu/test_replace_unary_elementwise.py
@@ -22,7 +22,7 @@ import tvm.script
 from tvm import relay
 from tvm.relay.testing import run_opt_pass
 from tvm.relay.backend.contrib.ethosu.tir import spec
-from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_tir
+from tvm.relay.backend.contrib.ethosu.tir.compiler import _lower_to_tir
 from .infra import make_ethosu_unary_elementwise
 
 
@@ -69,7 +69,7 @@ def test_unary_elementwise_single(
     )
     func = relay.Function(relay.analysis.free_vars(unary_elementwise), unary_elementwise)
     func = run_opt_pass(func, relay.transform.InferType())
-    mod, _ = lower_to_tir(func)
+    mod, _ = _lower_to_tir(func)
     data = []
 
     def _visit(stmt):
diff --git a/tests/python/contrib/test_ethosu/test_scheduler.py b/tests/python/contrib/test_ethosu/test_scheduler.py
index 5c6f064..0602591 100644
--- a/tests/python/contrib/test_ethosu/test_scheduler.py
+++ b/tests/python/contrib/test_ethosu/test_scheduler.py
@@ -34,7 +34,7 @@ from tvm.relay.backend.contrib.ethosu.tir.scheduler import (
 from tvm.relay.backend.contrib.ethosu.tir.compiler import (
     lower_to_te,
     extract_constants,
-    lower_to_tir,
+    _lower_to_tir,
 )
 from .infra import (
     AttachType,
@@ -216,7 +216,7 @@ def test_schedule_diamond_graph():
     func = relay.Function(relay.analysis.free_vars(add), add)
     func = run_opt_pass(func, relay.transform.InferType())
 
-    test_mod, _ = lower_to_tir(func, copy_constants())
+    test_mod, _ = _lower_to_tir(func, copy_constants())
     reference_mod = DiamondGraphTir
 
     tvm.ir.assert_structural_equal(test_mod["main"], reference_mod["main"], True)

Reply via email to