This is an automated email from the ASF dual-hosted git repository.

hongyij pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 2ae2aca7d1 Add Python functor support for TIR expressions and 
statements (#18060)
2ae2aca7d1 is described below

commit 2ae2aca7d18b14a5db5c40cc441e2f76307ad5ef
Author: Siyuan Feng <[email protected]>
AuthorDate: Tue Jun 17 02:14:01 2025 +0800

    Add Python functor support for TIR expressions and statements (#18060)
    
    This commit introduces Python interfaces for TIR functors,
    enabling Python-side customization of expression and statement visiting
    and mutation operations.
    
    Key changes:
    - Add PyStmtExprVisitorNode and PyStmtExprMutatorNode classes in C++
    - Implement Python bindings for all TIR expression and statement types
    - Support both visitor (read-only) and mutator (transforming) patterns
    - Add comprehensive dispatch mechanisms for Python callbacks
    - Include FFI registrations for Python-C++ interoperability
    
    This enables users to write custom TIR transformations and analysis
    passes directly in Python while maintaining performance through
    selective Python callback dispatch.
---
 python/tvm/relax/expr_functor.py               |    4 +-
 python/tvm/runtime/support.py                  |  137 ++
 python/tvm/tir/__init__.py                     |    1 +
 python/tvm/tir/functor.py                      | 2051 ++++++++++++++++++++++++
 src/tir/analysis/buffer_access_lca_detector.cc |    2 +-
 src/tir/ir/py_functor.cc                       |  859 ++++++++++
 tests/python/tir-transform/test_tir_functor.py |  436 +++++
 7 files changed, 3487 insertions(+), 3 deletions(-)

diff --git a/python/tvm/relax/expr_functor.py b/python/tvm/relax/expr_functor.py
index 49d3b14505..a40b81c233 100644
--- a/python/tvm/relax/expr_functor.py
+++ b/python/tvm/relax/expr_functor.py
@@ -20,8 +20,8 @@ from typing import Callable, Optional
 
 import tvm
 from tvm.ir import Op
-from tvm.meta_schedule.utils import derived_object
 from tvm.runtime import Object
+from tvm.runtime.support import derived_object
 
 from ..ir.module import IRModule
 from . import _ffi_api
@@ -31,7 +31,6 @@ from .expr import (
     BindingBlock,
     Call,
     Constant,
-    Id,
     DataflowBlock,
     DataflowVar,
     DataTypeImm,
@@ -39,6 +38,7 @@ from .expr import (
     ExternFunc,
     Function,
     GlobalVar,
+    Id,
     If,
     MatchCast,
     PrimValue,
diff --git a/python/tvm/runtime/support.py b/python/tvm/runtime/support.py
index 149a66ef7b..2669459d71 100644
--- a/python/tvm/runtime/support.py
+++ b/python/tvm/runtime/support.py
@@ -18,6 +18,7 @@
 """Runtime support infra of TVM."""
 
 import re
+from typing import TypeVar
 
 import tvm.ffi
 
@@ -67,3 +68,139 @@ def _regex_match(regex_pattern: str, match_against: str) -> 
bool:
     """
     match = re.match(regex_pattern, match_against)
     return match is not None
+
+
+# Type variable used in the `derived_object` signature.
+T = TypeVar("T")
+
+
+def derived_object(cls: type[T]) -> type[T]:
+    """A decorator to register derived subclasses for TVM objects.
+
+    The direct base of ``cls`` must be a user-facing class carrying a ``_tvm_metadata``
+    dict with key ``"cls"`` (the TVM object class to derive from) and optional keys
+    ``"fields"`` / ``"methods"`` (attribute and method names that are forwarded
+    positionally to that class's constructor).
+
+    Parameters
+    ----------
+    cls : type
+        The derived class to be registered.
+
+    Returns
+    -------
+    cls : type
+        The decorated TVM object. It subclasses ``_tvm_metadata["cls"]`` and keeps a
+        private instance of ``cls`` to which attribute access is delegated.
+
+    Raises
+    ------
+    TypeError
+        If ``cls`` inherits from an already-decorated class, or if its direct base
+        does not provide ``_tvm_metadata``.
+
+    Example
+    -------
+    .. code-block:: python
+
+        @register_object("meta_schedule.PyRunner")
+        class _PyRunner(meta_schedule.Runner):
+            def __init__(self, f_run: Callable = None):
+                self.__init_handle_by_constructor__(_ffi_api.RunnerPyRunner, f_run)
+
+        class PyRunner:
+            _tvm_metadata = {
+                "cls": _PyRunner,
+                "methods": ["run"]
+            }
+            def run(self, runner_inputs):
+                raise NotImplementedError
+
+        @derived_object
+        class LocalRunner(PyRunner):
+            def run(self, runner_inputs):
+                ...
+    """
+
+    import functools  # pylint: disable=import-outside-toplevel
+    import weakref  # pylint: disable=import-outside-toplevel
+
+    def _extract(inst: object, name: str):
+        """Return a callable forwarding to ``inst.<name>``, or None if never overridden.
+
+        ``name`` counts as overridden when some class in ``cls``'s MRO defines it
+        differently from its immediate base; ``__str__`` is always forwarded when a
+        base defines it.
+        """
+
+        def method(*args, **kwargs):
+            return getattr(inst, name)(*args, **kwargs)
+
+        for inherit_cls, base_cls in zip(cls.__mro__, cls.__mro__[1:]):
+            # extract functions that differ from the base class
+            if not hasattr(base_cls, name):
+                continue
+            if getattr(base_cls, name) is getattr(inherit_cls, name) and name != "__str__":
+                continue
+            return method
+
+        # No Python override: returning None lets the C++ side use its default behavior.
+        # (Carried over from the meta_schedule task scheduler, where a missing callback
+        # otherwise triggers a "method not implemented" TVMError; __str__ is not required.)
+        return None
+
+    assert isinstance(cls.__base__, type)
+    if hasattr(cls, "_type") and cls._type == "TVMDerivedObject":  # type: ignore
+        raise TypeError(
+            (
+                f"Inheritance from a decorated object `{cls.__name__}` is not allowed. "
+                f"Please inherit from `{cls.__name__}._cls`."
+            )
+        )
+
+    base = cls.__base__
+    # Validate explicitly rather than with `assert` (stripped under `python -O`), and check
+    # the direct base because that is where the metadata is read from.
+    if not hasattr(base, "_tvm_metadata"):
+        raise TypeError("Please use the user-facing method overriding class, i.e., PyRunner.")
+    metadata = getattr(base, "_tvm_metadata")
+    fields = metadata.get("fields", [])
+    methods = metadata.get("methods", [])
+
+    class TVMDerivedObject(metadata["cls"]):  # type: ignore
+        """The derived object to avoid cyclic dependency."""
+
+        _cls = cls
+        _type = "TVMDerivedObject"
+
+        def __init__(self, *args, **kwargs):
+            """Constructor."""
+            self._inst = cls(*args, **kwargs)
+
+            super().__init__(
+                # the constructor's parameters, builder, runner, etc.
+                *[getattr(self._inst, name) for name in fields],
+                # the function methods, init_with_tune_context, build, run, etc.
+                *[_extract(self._inst, name) for name in methods],
+            )
+
+            # Give the Python instance a back-reference to this TVM object (callable as
+            # `self._outer()`); a weakref avoids a reference cycle between the two.
+            self._inst._outer = weakref.ref(self)
+
+        def __getattr__(self, name):
+            import inspect  # pylint: disable=import-outside-toplevel
+
+            try:
+                # prefer attributes of the wrapped Python instance
+                result = self._inst.__getattribute__(name)
+            except AttributeError:
+                result = super(TVMDerivedObject, self).__getattr__(name)
+
+            if inspect.ismethod(result):
+
+                def method(*args, **kwargs):
+                    return result(*args, **kwargs)
+
+                # Attach `self` so this TVM object stays alive as long as the returned
+                # callable does (the inner instance only holds a weakref back to it).
+                setattr(method, "__own__", self)
+                return method
+
+            return result
+
+        def __setattr__(self, name, value):
+            # Only handle bookkeeping lives on the TVM object; everything else is
+            # stored on the wrapped Python instance.
+            if name not in ["_inst", "key", "handle"]:
+                self._inst.__setattr__(name, value)
+            else:
+                super(TVMDerivedObject, self).__setattr__(name, value)
+
+    functools.update_wrapper(TVMDerivedObject.__init__, cls.__init__)  # type: ignore
+    TVMDerivedObject.__name__ = cls.__name__
+    TVMDerivedObject.__doc__ = cls.__doc__
+    TVMDerivedObject.__module__ = cls.__module__
+    for key, value in cls.__dict__.items():
+        if isinstance(value, (classmethod, staticmethod)):
+            setattr(TVMDerivedObject, key, value)
+    return TVMDerivedObject
diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py
index 5c4a2b91f5..120d652dd8 100644
--- a/python/tvm/tir/__init__.py
+++ b/python/tvm/tir/__init__.py
@@ -108,3 +108,4 @@ from . import analysis
 from . import stmt_functor
 from .build import build
 from .pipeline import get_tir_pipeline, get_default_tir_pipeline
+from .functor import PyStmtExprVisitor, PyStmtExprMutator
diff --git a/python/tvm/tir/functor.py b/python/tvm/tir/functor.py
new file mode 100644
index 0000000000..06985f6645
--- /dev/null
+++ b/python/tvm/tir/functor.py
@@ -0,0 +1,2051 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name, arguments-differ
+"""The expression and statement functor of TIR."""
+from typing import Callable
+
+import tvm
+from tvm.ir import PrimExpr
+from tvm.runtime import Object
+from tvm.runtime.support import derived_object
+
+from . import _ffi_api
+from .expr import (
+    EQ,
+    GE,
+    GT,
+    LE,
+    LT,
+    NE,
+    Add,
+    And,
+    Broadcast,
+    BufferLoad,
+    Call,
+    Cast,
+    Div,
+    FloatImm,
+    FloorDiv,
+    FloorMod,
+    IntImm,
+    Let,
+    Max,
+    Min,
+    Mod,
+    Mul,
+    Not,
+    Or,
+    ProducerLoad,
+    Ramp,
+    Reduce,
+    Select,
+    Shuffle,
+    SizeVar,
+    StringImm,
+    Sub,
+    Var,
+)
+from .stmt import (
+    Allocate,
+    AllocateConst,
+    AssertStmt,
+    AttrStmt,
+    Block,
+    BlockRealize,
+    BufferRealize,
+    BufferStore,
+    DeclBuffer,
+    Evaluate,
+    For,
+    IfThenElse,
+    LetStmt,
+    SeqStmt,
+    Stmt,
+    While,
+)
+
+visitor = derived_object
+"""
+A decorator to wrap user-customized PyStmtExprVisitor as TVM object _PyStmtExprVisitor.
+
+This is an alias of ``tvm.runtime.support.derived_object``.
+
+Parameters
+----------
+visitor_cls : PyStmtExprVisitor
+    The user-customized PyStmtExprVisitor.
+
+Returns
+-------
+cls : _PyStmtExprVisitor
+    The decorated TVM object _PyStmtExprVisitor (StmtExprVisitor on the C++ side).
+
+Example
+-------
+.. code-block:: python
+
+    @tir.functor.visitor
+    class MyStmtExprVisitor(PyStmtExprVisitor):
+        # customize visit function
+        def visit_call_(self, op: Call) -> None:
+            # just for demo purposes
+            ...
+
+    # myvisitor is now a special visitor that visits every Call with
+    # the user-customized visit_call_
+    myvisitor = MyStmtExprVisitor()
+    # apply myvisitor to PrimExpr and Stmt
+    myvisitor.visit_expr(expr)
+    myvisitor.visit_stmt(stmt)
+"""
+
+mutator = derived_object
+"""
+A decorator to wrap user-customized PyStmtExprMutator as TVM object _PyStmtExprMutator.
+
+This is an alias of ``tvm.runtime.support.derived_object``.
+
+Parameters
+----------
+mutator_cls : PyStmtExprMutator
+    The user-customized PyStmtExprMutator.
+
+Returns
+-------
+cls : _PyStmtExprMutator
+    The decorated TVM object _PyStmtExprMutator (StmtExprMutator on the C++ side).
+
+Example
+-------
+.. code-block:: python
+
+    @tir.functor.mutator
+    class MyStmtExprMutator(PyStmtExprMutator):
+        # customize rewrite function
+        def visit_add_(self, op: Add) -> PrimExpr:
+            # just for demo purposes
+            ...
+
+    # mymutator is now a special mutator that rewrites every Add with
+    # the user-customized visit_add_
+    mymutator = MyStmtExprMutator()
+    # apply mymutator to PrimExpr and Stmt; keep the rewritten results
+    new_expr = mymutator.visit_expr(expr)
+    new_stmt = mymutator.visit_stmt(stmt)
+"""
+
+
+@tvm.ffi.register_object("tir.PyStmtExprVisitor")
+class _PyStmtExprVisitor(Object):
+    """
+    An internal wrapper to interface between C++ and Python StmtExprVisitor.
+    This is the TVM object that wraps PyStmtExprVisitor.
+
+    Do not use this class directly. Use PyStmtExprVisitor instead.
+
+    See also: PyStmtExprVisitor, visitor
+    """
+
+    def __init__(
+        self,
+        f_visit_stmt: Callable = None,
+        f_visit_expr: Callable = None,
+        # Stmt
+        f_visit_let_stmt: Callable = None,
+        f_visit_attr_stmt: Callable = None,
+        f_visit_if_then_else: Callable = None,
+        f_visit_for: Callable = None,
+        f_visit_while: Callable = None,
+        f_visit_allocate: Callable = None,
+        f_visit_allocate_const: Callable = None,
+        f_visit_decl_buffer: Callable = None,
+        f_visit_buffer_store: Callable = None,
+        f_visit_buffer_realize: Callable = None,
+        f_visit_assert_stmt: Callable = None,
+        f_visit_seq_stmt: Callable = None,
+        f_visit_evaluate: Callable = None,
+        f_visit_block: Callable = None,
+        f_visit_block_realize: Callable = None,
+        # PrimExpr
+        f_visit_var: Callable = None,
+        f_visit_size_var: Callable = None,
+        f_visit_buffer_load: Callable = None,
+        f_visit_producer_load: Callable = None,
+        f_visit_let: Callable = None,
+        f_visit_call: Callable = None,
+        f_visit_add: Callable = None,
+        f_visit_sub: Callable = None,
+        f_visit_mul: Callable = None,
+        f_visit_div: Callable = None,
+        f_visit_mod: Callable = None,
+        f_visit_floor_div: Callable = None,
+        f_visit_floor_mod: Callable = None,
+        f_visit_min: Callable = None,
+        f_visit_max: Callable = None,
+        f_visit_eq: Callable = None,
+        f_visit_ne: Callable = None,
+        f_visit_lt: Callable = None,
+        f_visit_le: Callable = None,
+        f_visit_gt: Callable = None,
+        f_visit_ge: Callable = None,
+        f_visit_and: Callable = None,
+        f_visit_or: Callable = None,
+        f_visit_reduce: Callable = None,
+        f_visit_cast: Callable = None,
+        f_visit_not: Callable = None,
+        f_visit_select: Callable = None,
+        f_visit_ramp: Callable = None,
+        f_visit_broadcast: Callable = None,
+        f_visit_shuffle: Callable = None,
+        f_visit_int_imm: Callable = None,
+        f_visit_float_imm: Callable = None,
+        f_visit_string_imm: Callable = None,
+    ) -> None:
+        """Constructor.
+
+        Every ``f_visit_*`` argument is an optional callback. It is ``None`` when the
+        Python subclass does not override the corresponding ``visit_*`` method. All
+        callbacks are forwarded positionally, in exactly this order, to
+        ``_ffi_api.MakePyStmtExprVisitor``.
+        """
+        self.__init_handle_by_constructor__(
+            _ffi_api.MakePyStmtExprVisitor,  # type: ignore
+            f_visit_stmt,
+            f_visit_expr,
+            # Stmt
+            f_visit_let_stmt,
+            f_visit_attr_stmt,
+            f_visit_if_then_else,
+            f_visit_for,
+            f_visit_while,
+            f_visit_allocate,
+            f_visit_allocate_const,
+            f_visit_decl_buffer,
+            f_visit_buffer_store,
+            f_visit_buffer_realize,
+            f_visit_assert_stmt,
+            f_visit_seq_stmt,
+            f_visit_evaluate,
+            f_visit_block,
+            f_visit_block_realize,
+            # PrimExpr
+            f_visit_var,
+            f_visit_size_var,
+            f_visit_buffer_load,
+            f_visit_producer_load,
+            f_visit_let,
+            f_visit_call,
+            f_visit_add,
+            f_visit_sub,
+            f_visit_mul,
+            f_visit_div,
+            f_visit_mod,
+            f_visit_floor_div,
+            f_visit_floor_mod,
+            f_visit_min,
+            f_visit_max,
+            f_visit_eq,
+            f_visit_ne,
+            f_visit_lt,
+            f_visit_le,
+            f_visit_gt,
+            f_visit_ge,
+            f_visit_and,
+            f_visit_or,
+            f_visit_reduce,
+            f_visit_cast,
+            f_visit_not,
+            f_visit_select,
+            f_visit_ramp,
+            f_visit_broadcast,
+            f_visit_shuffle,
+            f_visit_int_imm,
+            f_visit_float_imm,
+            f_visit_string_imm,
+        )
+
+
+class PyStmtExprVisitor:
+    """
+    A Python StmtExprVisitor to define custom visitor for both Stmt and 
PrimExpr.
+
+    Users can customize any of the visit function.
+    """
+
+    # Metadata consumed by `derived_object`: "cls" is the TVM object class to derive from.
+    # Each name in "methods" is extracted from the user's subclass (None if not overridden)
+    # and passed positionally to `_PyStmtExprVisitor.__init__`, so this order must match
+    # that constructor's parameter order exactly.
+    _tvm_metadata = {
+        "cls": _PyStmtExprVisitor,
+        "methods": [
+            "visit_stmt",
+            "visit_expr",
+            # Stmt
+            "visit_let_stmt_",
+            "visit_attr_stmt_",
+            "visit_if_then_else_",
+            "visit_for_",
+            "visit_while_",
+            "visit_allocate_",
+            "visit_allocate_const_",
+            "visit_decl_buffer_",
+            "visit_buffer_store_",
+            "visit_buffer_realize_",
+            "visit_assert_stmt_",
+            "visit_seq_stmt_",
+            "visit_evaluate_",
+            "visit_block_",
+            "visit_block_realize_",
+            # PrimExpr
+            "visit_var_",
+            "visit_size_var_",
+            "visit_buffer_load_",
+            "visit_producer_load_",
+            "visit_let_",
+            "visit_call_",
+            "visit_add_",
+            "visit_sub_",
+            "visit_mul_",
+            "visit_div_",
+            "visit_mod_",
+            "visit_floor_div_",
+            "visit_floor_mod_",
+            "visit_min_",
+            "visit_max_",
+            "visit_eq_",
+            "visit_ne_",
+            "visit_lt_",
+            "visit_le_",
+            "visit_gt_",
+            "visit_ge_",
+            "visit_and_",
+            "visit_or_",
+            "visit_reduce_",
+            "visit_cast_",
+            "visit_not_",
+            "visit_select_",
+            "visit_ramp_",
+            "visit_broadcast_",
+            "visit_shuffle_",
+            "visit_int_imm_",
+            "visit_float_imm_",
+            "visit_string_imm_",
+        ],
+    }
+
+    def visit_stmt(self, stmt: Stmt) -> None:
+        """Visit a Stmt.
+
+        The wrapping TVM object (dereferenced from the ``self._outer`` weak
+        reference) and ``stmt`` are handed to ``_ffi_api.PyStmtExprVisitorVisitStmt``.
+
+        Parameters
+        ----------
+        stmt : Stmt
+            The Stmt to be visited.
+        """
+        outer = self._outer()
+        _ffi_api.PyStmtExprVisitorVisitStmt(outer, stmt)  # type: ignore
+
+    def visit_expr(self, expr: PrimExpr) -> None:
+        """Visit a PrimExpr.
+
+        The wrapping TVM object (dereferenced from the ``self._outer`` weak
+        reference) and ``expr`` are handed to ``_ffi_api.PyStmtExprVisitorVisitExpr``.
+
+        Parameters
+        ----------
+        expr : PrimExpr
+            The PrimExpr to be visited.
+        """
+        outer = self._outer()
+        _ffi_api.PyStmtExprVisitorVisitExpr(outer, expr)  # type: ignore
+
+    def visit_attr_stmt_(self, op: AttrStmt) -> None:
+        """Visit AttrStmt.
+
+        Users can override this method to customize
+        ``VisitStmt_(const AttrStmtNode* op)`` on the C++ side. By default, ``op`` is
+        forwarded to ``_ffi_api.PyStmtExprVisitorDefaultVisitStmt``.
+
+        Parameters
+        ----------
+        op : AttrStmt
+            The AttrStmt to be visited.
+        """
+        _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op)  # type: ignore
+
+    def visit_if_then_else_(self, op: IfThenElse) -> None:
+        """Visit IfThenElse.
+
+        Users can override this method to customize
+        ``VisitStmt_(const IfThenElseNode* op)`` on the C++ side. By default, ``op`` is
+        forwarded to ``_ffi_api.PyStmtExprVisitorDefaultVisitStmt``.
+
+        Parameters
+        ----------
+        op : IfThenElse
+            The IfThenElse to be visited.
+        """
+        _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op)  # type: ignore
+
+    def visit_let_stmt_(self, op: LetStmt) -> None:
+        """Visit LetStmt.
+
+        Users can override this method to customize
+        ``VisitStmt_(const LetStmtNode* op)`` on the C++ side. By default, ``op`` is
+        forwarded to ``_ffi_api.PyStmtExprVisitorDefaultVisitStmt``.
+
+        Parameters
+        ----------
+        op : LetStmt
+            The LetStmt to be visited.
+        """
+        _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op)  # type: ignore
+
+    def visit_for_(self, op: For) -> None:
+        """Visit For.
+
+        Users can override this method to customize
+        ``VisitStmt_(const ForNode* op)`` on the C++ side. By default, ``op`` is
+        forwarded to ``_ffi_api.PyStmtExprVisitorDefaultVisitStmt``.
+
+        Parameters
+        ----------
+        op : For
+            The For to be visited.
+        """
+        _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op)  # type: ignore
+
+    def visit_while_(self, op: While) -> None:
+        """Visit While.
+
+        Users can override this method to customize
+        ``VisitStmt_(const WhileNode* op)`` on the C++ side. By default, ``op`` is
+        forwarded to ``_ffi_api.PyStmtExprVisitorDefaultVisitStmt``.
+
+        Parameters
+        ----------
+        op : While
+            The While to be visited.
+        """
+        _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op)  # type: ignore
+
+    def visit_allocate_(self, op: Allocate) -> None:
+        """Visit Allocate.
+
+        Users can override this method to customize
+        ``VisitStmt_(const AllocateNode* op)`` on the C++ side. By default, ``op`` is
+        forwarded to ``_ffi_api.PyStmtExprVisitorDefaultVisitStmt``.
+
+        Parameters
+        ----------
+        op : Allocate
+            The Allocate to be visited.
+        """
+        _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op)  # type: ignore
+
+    def visit_allocate_const_(self, op: AllocateConst) -> None:
+        """Visit AllocateConst.
+
+        Users can override this method to customize
+        ``VisitStmt_(const AllocateConstNode* op)`` on the C++ side. By default, ``op``
+        is forwarded to ``_ffi_api.PyStmtExprVisitorDefaultVisitStmt``.
+
+        Parameters
+        ----------
+        op : AllocateConst
+            The AllocateConst to be visited.
+        """
+        _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op)  # type: ignore
+
+    def visit_decl_buffer_(self, op: DeclBuffer) -> None:
+        """Visit DeclBuffer.
+
+        Users can override this method to customize
+        ``VisitStmt_(const DeclBufferNode* op)`` on the C++ side. By default, ``op`` is
+        forwarded to ``_ffi_api.PyStmtExprVisitorDefaultVisitStmt``.
+
+        Parameters
+        ----------
+        op : DeclBuffer
+            The DeclBuffer to be visited.
+        """
+        _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op)  # type: ignore
+
+    def visit_buffer_store_(self, op: BufferStore) -> None:
+        """Visit BufferStore.
+
+        Users can override this method to customize
+        ``VisitStmt_(const BufferStoreNode* op)`` on the C++ side. By default, ``op`` is
+        forwarded to ``_ffi_api.PyStmtExprVisitorDefaultVisitStmt``.
+
+        Parameters
+        ----------
+        op : BufferStore
+            The BufferStore to be visited.
+        """
+        _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op)  # type: ignore
+
+    def visit_buffer_realize_(self, op: BufferRealize) -> None:
+        """Visit BufferRealize.
+
+        Users can override this method to customize
+        ``VisitStmt_(const BufferRealizeNode* op)`` on the C++ side. By default, ``op``
+        is forwarded to ``_ffi_api.PyStmtExprVisitorDefaultVisitStmt``.
+
+        Parameters
+        ----------
+        op : BufferRealize
+            The BufferRealize to be visited.
+        """
+        _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op)  # type: ignore
+
+    def visit_assert_stmt_(self, op: AssertStmt) -> None:
+        """Visit AssertStmt.
+
+        Users can override this method to customize
+        ``VisitStmt_(const AssertStmtNode* op)`` on the C++ side. By default, ``op`` is
+        forwarded to ``_ffi_api.PyStmtExprVisitorDefaultVisitStmt``.
+
+        Parameters
+        ----------
+        op : AssertStmt
+            The AssertStmt to be visited.
+        """
+        _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op)  # type: ignore
+
+    def visit_seq_stmt_(self, op: SeqStmt) -> None:
+        """Visit SeqStmt.
+
+        Users can override this method to customize
+        ``VisitStmt_(const SeqStmtNode* op)`` on the C++ side. By default, ``op`` is
+        forwarded to ``_ffi_api.PyStmtExprVisitorDefaultVisitStmt``.
+
+        Parameters
+        ----------
+        op : SeqStmt
+            The SeqStmt to be visited.
+        """
+        _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op)  # type: ignore
+
+    def visit_evaluate_(self, op: Evaluate) -> None:
+        """Visit Evaluate.
+
+        Users can override this method to customize
+        ``VisitStmt_(const EvaluateNode* op)`` on the C++ side. By default, ``op`` is
+        forwarded to ``_ffi_api.PyStmtExprVisitorDefaultVisitStmt``.
+
+        Parameters
+        ----------
+        op : Evaluate
+            The Evaluate to be visited.
+        """
+        _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op)  # type: ignore
+
+    def visit_block_(self, op: Block) -> None:
+        """Visit Block.
+
+        Users can override this method to customize
+        ``VisitStmt_(const BlockNode* op)`` on the C++ side. By default, ``op`` is
+        forwarded to ``_ffi_api.PyStmtExprVisitorDefaultVisitStmt``.
+
+        Parameters
+        ----------
+        op : Block
+            The Block to be visited.
+        """
+        _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op)  # type: ignore
+
+    def visit_block_realize_(self, op: BlockRealize) -> None:
+        """Visit BlockRealize.
+
+        Users can override this method to customize
+        ``VisitStmt_(const BlockRealizeNode* op)`` on the C++ side. By default, ``op``
+        is forwarded to ``_ffi_api.PyStmtExprVisitorDefaultVisitStmt``.
+
+        Parameters
+        ----------
+        op : BlockRealize
+            The BlockRealize to be visited.
+        """
+        _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op)  # type: ignore
+
+    def visit_var_(self, op: Var) -> None:
+        """Visit Var.
+
+        Users can override this method to customize
+        ``VisitExpr_(const VarNode* op)`` on the C++ side. By default, ``op`` is
+        forwarded to ``_ffi_api.PyStmtExprVisitorDefaultVisitExpr``.
+
+        Parameters
+        ----------
+        op : Var
+            The Var to be visited.
+        """
+        _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op)  # type: ignore
+
+    def visit_size_var_(self, op: SizeVar) -> None:
+        """Visit SizeVar.
+
+        Users can override this method to customize
+        ``VisitExpr_(const SizeVarNode* op)`` on the C++ side. By default, ``op`` is
+        forwarded to ``_ffi_api.PyStmtExprVisitorDefaultVisitExpr``.
+
+        Parameters
+        ----------
+        op : SizeVar
+            The SizeVar to be visited.
+        """
+        _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op)  # type: ignore
+
+    def visit_buffer_load_(self, op: BufferLoad) -> None:
+        """Visit BufferLoad.
+
+        Users can override this method to customize
+        ``VisitExpr_(const BufferLoadNode* op)`` on the C++ side. By default, ``op`` is
+        forwarded to ``_ffi_api.PyStmtExprVisitorDefaultVisitExpr``.
+
+        Parameters
+        ----------
+        op : BufferLoad
+            The BufferLoad to be visited.
+        """
+        _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op)  # type: ignore
+
+    def visit_producer_load_(self, op: ProducerLoad) -> None:
+        """Visit ProducerLoad.
+
+        Users can override this method to customize
+        ``VisitExpr_(const ProducerLoadNode* op)`` on the C++ side. By default, ``op``
+        is forwarded to ``_ffi_api.PyStmtExprVisitorDefaultVisitExpr``.
+
+        Parameters
+        ----------
+        op : ProducerLoad
+            The ProducerLoad to be visited.
+        """
+        _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op)  # type: ignore
+
+    def visit_let_(self, op: Let) -> None:
+        """Visit Let.
+
+        Users can override this method to customize
+        ``VisitExpr_(const LetNode* op)`` on the C++ side. By default, ``op`` is
+        forwarded to ``_ffi_api.PyStmtExprVisitorDefaultVisitExpr``.
+
+        Parameters
+        ----------
+        op : Let
+            The Let to be visited.
+        """
+        _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op)  # type: ignore
+
+    def visit_call_(self, op: Call) -> None:
+        """Visit Call.
+
+        Users can override this method to customize
+        ``VisitExpr_(const CallNode* op)`` on the C++ side. By default, ``op`` is
+        forwarded to ``_ffi_api.PyStmtExprVisitorDefaultVisitExpr``.
+
+        Parameters
+        ----------
+        op : Call
+            The Call to be visited.
+        """
+        _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op)  # type: ignore
+
+    def visit_add_(self, op: Add) -> None:
+        """Visit Add.
+
+        Users can override this method to customize
+        ``VisitExpr_(const AddNode* op)`` on the C++ side. By default, ``op`` is
+        forwarded to ``_ffi_api.PyStmtExprVisitorDefaultVisitExpr``.
+
+        Parameters
+        ----------
+        op : Add
+            The Add to be visited.
+        """
+        _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op)  # type: ignore
+
+    def visit_sub_(self, op: Sub) -> None:
+        """Visit Sub.
+
+        Users can override this method to customize
+        ``VisitExpr_(const SubNode* op)`` on the C++ side. By default, ``op`` is
+        forwarded to ``_ffi_api.PyStmtExprVisitorDefaultVisitExpr``.
+
+        Parameters
+        ----------
+        op : Sub
+            The Sub to be visited.
+        """
+        _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op)  # type: ignore
+
+    def visit_mul_(self, op: Mul) -> None:
+        """Visit Mul.
+
+        Users can override this method to customize
+        ``VisitExpr_(const MulNode* op)`` on the C++ side. By default, ``op`` is
+        forwarded to ``_ffi_api.PyStmtExprVisitorDefaultVisitExpr``.
+
+        Parameters
+        ----------
+        op : Mul
+            The Mul to be visited.
+        """
+        _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op)  # type: ignore
+
+    def visit_div_(self, op: Div) -> None:
+        """Visit Div.
+
+        Users can override this method to customize
+        ``VisitExpr_(const DivNode* op)`` on the C++ side. By default, ``op`` is
+        forwarded to ``_ffi_api.PyStmtExprVisitorDefaultVisitExpr``.
+
+        Parameters
+        ----------
+        op : Div
+            The Div to be visited.
+        """
+        _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op)  # type: ignore
+
+    def visit_mod_(self, op: Mod) -> None:
+        """Visit Mod.
+
+        Users can override this method to customize
+        ``VisitExpr_(const ModNode* op)`` on the C++ side. By default, ``op`` is
+        forwarded to ``_ffi_api.PyStmtExprVisitorDefaultVisitExpr``.
+
+        Parameters
+        ----------
+        op : Mod
+            The Mod to be visited.
+        """
+        _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op)  # type: ignore
+
+    def visit_floor_div_(self, op: FloorDiv) -> None:
+        """Visit FloorDiv.
+
+        Users can override this method to customize
+        ``VisitExpr_(const FloorDivNode* op)`` on the C++ side. By default, ``op`` is
+        forwarded to ``_ffi_api.PyStmtExprVisitorDefaultVisitExpr``.
+
+        Parameters
+        ----------
+        op : FloorDiv
+            The FloorDiv to be visited.
+        """
+        _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op)  # type: ignore
+
+    def visit_floor_mod_(self, op: FloorMod) -> None:
+        """Visit FloorMod.
+
+        Users can override this method to customize
+        ``VisitExpr_(const FloorModNode* op)`` on the C++ side. By default, ``op`` is
+        forwarded to ``_ffi_api.PyStmtExprVisitorDefaultVisitExpr``.
+
+        Parameters
+        ----------
+        op : FloorMod
+            The FloorMod to be visited.
+        """
+        _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op)  # type: ignore
+
+    def visit_min_(self, op: Min) -> None:
+        """Visit Min.
+
+        Users can override this method to customize
+        ``VisitExpr_(const MinNode* op)`` on the C++ side. By default, ``op`` is
+        forwarded to ``_ffi_api.PyStmtExprVisitorDefaultVisitExpr``.
+
+        Parameters
+        ----------
+        op : Min
+            The Min to be visited.
+        """
+        _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op)  # type: ignore
+
+    def visit_max_(self, op: Max) -> None:
+        """Visit Max.
+
+        Users can override this method to customize
+        ``VisitExpr_(const MaxNode* op)`` on the C++ side. By default, ``op`` is
+        forwarded to ``_ffi_api.PyStmtExprVisitorDefaultVisitExpr``.
+
+        Parameters
+        ----------
+        op : Max
+            The Max to be visited.
+        """
+        _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op)  # type: ignore
+
+    def visit_eq_(self, op: EQ) -> None:
+        """Visit EQ.
+
+        Users can override this method to customize
+        ``VisitExpr_(const EQNode* op)`` on the C++ side. By default, ``op`` is
+        forwarded to ``_ffi_api.PyStmtExprVisitorDefaultVisitExpr``.
+
+        Parameters
+        ----------
+        op : EQ
+            The EQ to be visited.
+        """
+        _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op)  # type: ignore
+
+    def visit_ne_(self, op: NE) -> None:
+        """Visit NE.
+
+        Users can override this method to customize
+        ``VisitExpr_(const NENode* op)`` on the C++ side. By default, ``op`` is
+        forwarded to ``_ffi_api.PyStmtExprVisitorDefaultVisitExpr``.
+
+        Parameters
+        ----------
+        op : NE
+            The NE to be visited.
+        """
+        _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op)  # type: ignore
+
+    def visit_lt_(self, op: LT) -> None:
+        """Visit LT.
+
+        Users can customize this function to overwrite VisitLT_(const LTNode* 
op)
+        on the C++ side.
+
+        Parameters
+        ----------
+        op : LT
+            The LT to be visited.
+        """
+        _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op)  # type: 
ignore
+
+    def visit_le_(self, op: LE) -> None:
+        """Visit LE.
+
+        Users can customize this function to overwrite VisitLE_(const LENode* 
op)
+        on the C++ side.
+
+        Parameters
+        ----------
+        op : LE
+            The LE to be visited.
+        """
+        _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op)  # type: 
ignore
+
+    def visit_gt_(self, op: GT) -> None:
+        """Visit GT.
+
+        Users can customize this function to overwrite VisitGT_(const GTNode* 
op)
+        on the C++ side.
+
+        Parameters
+        ----------
+        op : GT
+            The GT to be visited.
+        """
+        _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op)  # type: 
ignore
+
+    def visit_ge_(self, op: GE) -> None:
+        """Visit GE.
+
+        Users can customize this function to overwrite VisitGE_(const GENode* 
op)
+        on the C++ side.
+
+        Parameters
+        ----------
+        op : GE
+            The GE to be visited.
+        """
+        _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op)  # type: 
ignore
+
+    def visit_and_(self, op: And) -> None:
+        """Visit And.
+
+        Users can customize this function to overwrite VisitAnd_(const 
AndNode* op)
+        on the C++ side.
+
+        Parameters
+        ----------
+        op : And
+            The And to be visited.
+        """
+        _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op)  # type: 
ignore
+
+    def visit_or_(self, op: Or) -> None:
+        """Visit Or.
+
+        Users can customize this function to overwrite VisitOr_(const OrNode* 
op)
+        on the C++ side.
+
+        Parameters
+        ----------
+        op : Or
+            The Or to be visited.
+        """
+        _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op)  # type: 
ignore
+
+    def visit_reduce_(self, op: Reduce) -> None:
+        """Visit Reduce.
+
+        Users can customize this function to overwrite VisitReduce_(const 
ReduceNode* op)
+        on the C++ side.
+
+        Parameters
+        ----------
+        op : Reduce
+            The Reduce to be visited.
+        """
+        _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op)  # type: 
ignore
+
+    def visit_cast_(self, op: Cast) -> None:
+        """Visit Cast.
+
+        Users can customize this function to overwrite VisitCast_(const 
CastNode* op)
+        on the C++ side.
+
+        Parameters
+        ----------
+        op : Cast
+            The Cast to be visited.
+        """
+        _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op)  # type: 
ignore
+
+    def visit_not_(self, op: Not) -> None:
+        """Visit Not.
+
+        Users can customize this function to overwrite VisitNot_(const 
NotNode* op)
+        on the C++ side.
+
+        Parameters
+        ----------
+        op : Not
+            The Not to be visited.
+        """
+        _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op)  # type: 
ignore
+
+    def visit_select_(self, op: Select) -> None:
+        """Visit Select.
+
+        Users can customize this function to overwrite VisitSelect_(const 
SelectNode* op)
+        on the C++ side.
+
+        Parameters
+        ----------
+        op : Select
+            The Select to be visited.
+        """
+        _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op)  # type: 
ignore
+
+    def visit_ramp_(self, op: Ramp) -> None:
+        """Visit Ramp.
+
+        Users can customize this function to overwrite VisitRamp_(const 
RampNode* op)
+        on the C++ side.
+
+        Parameters
+        ----------
+        op : Ramp
+            The Ramp to be visited.
+        """
+        _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op)  # type: 
ignore
+
+    def visit_broadcast_(self, op: Broadcast) -> None:
+        """Visit Broadcast.
+
+        Users can customize this function to overwrite VisitBroadcast_(const 
BroadcastNode* op)
+        on the C++ side.
+
+        Parameters
+        ----------
+        op : Broadcast
+            The Broadcast to be visited.
+        """
+        _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op)  # type: 
ignore
+
+    def visit_shuffle_(self, op: Shuffle) -> None:
+        """Visit Shuffle.
+
+        Users can customize this function to overwrite VisitShuffle_(const 
ShuffleNode* op)
+        on the C++ side.
+
+        Parameters
+        ----------
+        op : Shuffle
+            The Shuffle to be visited.
+        """
+        _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op)  # type: 
ignore
+
+    def visit_int_imm_(self, op: IntImm) -> None:
+        """Visit IntImm.
+
+        Users can customize this function to overwrite VisitIntImm_(const 
IntImmNode* op)
+        on the C++ side.
+
+        Parameters
+        ----------
+        op : IntImm
+            The IntImm to be visited.
+        """
+        _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op)  # type: 
ignore
+
+    def visit_float_imm_(self, op: FloatImm) -> None:
+        """Visit FloatImm.
+
+        Users can customize this function to overwrite VisitFloatImm_(const 
FloatImmNode* op)
+        on the C++ side.
+
+        Parameters
+        ----------
+        op : FloatImm
+            The FloatImm to be visited.
+        """
+        _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op)  # type: 
ignore
+
+    def visit_string_imm_(self, op: StringImm) -> None:
+        """Visit StringImm.
+
+        Users can customize this function to overwrite VisitStringImm_(const 
StringImmNode* op)
+        on the C++ side.
+
+        Parameters
+        ----------
+        op : StringImm
+            The StringImm to be visited.
+        """
+        _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op)  # type: 
ignore
+
+
[email protected]_object("tir.PyStmtExprMutator")
+class _PyStmtExprMutator(Object):
+    """
+    A TVM object to support customization of StmtExprMutator on the python 
side.
+    This is the decorated result returned from stmt_expr_mutator decorator.
+
+    WARNING: This is NOT the user facing class for method overwriting 
inheritance.
+
+    See also: stmt_expr_mutator, PyStmtExprMutator
+    """
+
+    def __init__(
+        self,
+        f_visit_stmt: Callable = None,
+        f_visit_expr: Callable = None,
+        # Stmt
+        f_visit_let_stmt: Callable = None,
+        f_visit_attr_stmt: Callable = None,
+        f_visit_if_then_else: Callable = None,
+        f_visit_for: Callable = None,
+        f_visit_while: Callable = None,
+        f_visit_allocate: Callable = None,
+        f_visit_allocate_const: Callable = None,
+        f_visit_decl_buffer: Callable = None,
+        f_visit_buffer_store: Callable = None,
+        f_visit_buffer_realize: Callable = None,
+        f_visit_assert_stmt: Callable = None,
+        f_visit_seq_stmt: Callable = None,
+        f_visit_evaluate: Callable = None,
+        f_visit_block: Callable = None,
+        f_visit_block_realize: Callable = None,
+        # PrimExpr
+        f_visit_var: Callable = None,
+        f_visit_size_var: Callable = None,
+        f_visit_buffer_load: Callable = None,
+        f_visit_producer_load: Callable = None,
+        f_visit_let: Callable = None,
+        f_visit_call: Callable = None,
+        f_visit_add: Callable = None,
+        f_visit_sub: Callable = None,
+        f_visit_mul: Callable = None,
+        f_visit_div: Callable = None,
+        f_visit_mod: Callable = None,
+        f_visit_floor_div: Callable = None,
+        f_visit_floor_mod: Callable = None,
+        f_visit_min: Callable = None,
+        f_visit_max: Callable = None,
+        f_visit_eq: Callable = None,
+        f_visit_ne: Callable = None,
+        f_visit_lt: Callable = None,
+        f_visit_le: Callable = None,
+        f_visit_gt: Callable = None,
+        f_visit_ge: Callable = None,
+        f_visit_and: Callable = None,
+        f_visit_or: Callable = None,
+        f_visit_reduce: Callable = None,
+        f_visit_cast: Callable = None,
+        f_visit_not: Callable = None,
+        f_visit_select: Callable = None,
+        f_visit_ramp: Callable = None,
+        f_visit_broadcast: Callable = None,
+        f_visit_shuffle: Callable = None,
+        f_visit_int_imm: Callable = None,
+        f_visit_float_imm: Callable = None,
+        f_visit_string_imm: Callable = None,
+    ) -> None:
+        """Constructor."""
+        self.__init_handle_by_constructor__(
+            _ffi_api.MakePyStmtExprMutator,  # type: ignore
+            f_visit_stmt,
+            f_visit_expr,
+            # Stmt
+            f_visit_let_stmt,
+            f_visit_attr_stmt,
+            f_visit_if_then_else,
+            f_visit_for,
+            f_visit_while,
+            f_visit_allocate,
+            f_visit_allocate_const,
+            f_visit_decl_buffer,
+            f_visit_buffer_store,
+            f_visit_buffer_realize,
+            f_visit_assert_stmt,
+            f_visit_seq_stmt,
+            f_visit_evaluate,
+            f_visit_block,
+            f_visit_block_realize,
+            # PrimExpr
+            f_visit_var,
+            f_visit_size_var,
+            f_visit_buffer_load,
+            f_visit_producer_load,
+            f_visit_let,
+            f_visit_call,
+            f_visit_add,
+            f_visit_sub,
+            f_visit_mul,
+            f_visit_div,
+            f_visit_mod,
+            f_visit_floor_div,
+            f_visit_floor_mod,
+            f_visit_min,
+            f_visit_max,
+            f_visit_eq,
+            f_visit_ne,
+            f_visit_lt,
+            f_visit_le,
+            f_visit_gt,
+            f_visit_ge,
+            f_visit_and,
+            f_visit_or,
+            f_visit_reduce,
+            f_visit_cast,
+            f_visit_not,
+            f_visit_select,
+            f_visit_ramp,
+            f_visit_broadcast,
+            f_visit_shuffle,
+            f_visit_int_imm,
+            f_visit_float_imm,
+            f_visit_string_imm,
+        )
+
+
+class PyStmtExprMutator:
+    """
+    A Python StmtExprMutator to define custom mutator for both Stmt and 
PrimExpr.
+
+    Users can customize any of the visit function.
+    """
+
+    _tvm_metadata = {
+        "cls": _PyStmtExprMutator,
+        "methods": [
+            "visit_stmt",
+            "visit_expr",
+            # Stmt
+            "visit_let_stmt_",
+            "visit_attr_stmt_",
+            "visit_if_then_else_",
+            "visit_for_",
+            "visit_while_",
+            "visit_allocate_",
+            "visit_allocate_const_",
+            "visit_decl_buffer_",
+            "visit_buffer_store_",
+            "visit_buffer_realize_",
+            "visit_assert_stmt_",
+            "visit_seq_stmt_",
+            "visit_evaluate_",
+            "visit_block_",
+            "visit_block_realize_",
+            # PrimExpr
+            "visit_var_",
+            "visit_size_var_",
+            "visit_buffer_load_",
+            "visit_producer_load_",
+            "visit_let_",
+            "visit_call_",
+            "visit_add_",
+            "visit_sub_",
+            "visit_mul_",
+            "visit_div_",
+            "visit_mod_",
+            "visit_floor_div_",
+            "visit_floor_mod_",
+            "visit_min_",
+            "visit_max_",
+            "visit_eq_",
+            "visit_ne_",
+            "visit_lt_",
+            "visit_le_",
+            "visit_gt_",
+            "visit_ge_",
+            "visit_and_",
+            "visit_or_",
+            "visit_reduce_",
+            "visit_cast_",
+            "visit_not_",
+            "visit_select_",
+            "visit_ramp_",
+            "visit_broadcast_",
+            "visit_shuffle_",
+            "visit_int_imm_",
+            "visit_float_imm_",
+            "visit_string_imm_",
+        ],
+    }
+
+    def visit_expr(self, expr: PrimExpr) -> PrimExpr:
+        """Visit PrimExpr.
+        Users can customize this function to overwrite VisitExpr(const 
PrimExpr& expr)
+        on the C++ side.
+
+        Parameters
+        ----------
+        expr : PrimExpr
+            The PrimExpr to be visited.
+
+        Returns
+        -------
+        result : PrimExpr
+            The mutated PrimExpr.
+        """
+        return _ffi_api.PyStmtExprMutatorVisitExpr(self._outer(), expr)  # 
type: ignore
+
+    def visit_stmt(self, stmt: Stmt) -> Stmt:
+        """Visit Stmt.
+        Users can customize this function to overwrite VisitStmt(const Stmt& 
stmt)
+        on the C++ side.
+
+        Parameters
+        ----------
+        stmt : Stmt
+            The Stmt to be visited.
+
+        Returns
+        -------
+        result : Stmt
+            The mutated Stmt.
+        """
+        return _ffi_api.PyStmtExprMutatorVisitStmt(self._outer(), stmt)  # 
type: ignore
+
+    def visit_attr_stmt_(self, op: AttrStmt) -> Stmt:
+        """Visit AttrStmt.
+        Users can customize this function to overwrite VisitStmt_(const 
AttrStmtNode* op)
+        on the C++ side.
+
+        Parameters
+        ----------
+        op : AttrStmt
+            The AttrStmt to be visited.
+
+        Returns
+        -------
+        result : Stmt
+            The mutated Stmt.
+        """
+        return _ffi_api.PyStmtExprMutatorDefaultVisitStmt(self._outer(), op)  
# type: ignore
+
+    def visit_if_then_else_(self, op: IfThenElse) -> Stmt:
+        """Visit IfThenElse.
+        Users can customize this function to overwrite VisitStmt_(const 
IfThenElseNode* op)
+        on the C++ side.
+
+        Parameters
+        ----------
+        op : IfThenElse
+            The IfThenElse to be visited.
+
+        Returns
+        -------
+        result : Stmt
+            The mutated Stmt.
+        """
+        return _ffi_api.PyStmtExprMutatorDefaultVisitStmt(self._outer(), op)  
# type: ignore
+
+    def visit_let_stmt_(self, op: LetStmt) -> Stmt:
+        """Visit LetStmt.
+        Users can customize this function to overwrite VisitStmt_(const 
LetStmtNode* op)
+        on the C++ side.
+
+        Parameters
+        ----------
+        op : LetStmt
+            The LetStmt to be visited.
+
+        Returns
+        -------
+        result : Stmt
+            The mutated Stmt.
+        """
+        return _ffi_api.PyStmtExprMutatorDefaultVisitStmt(self._outer(), op)  
# type: ignore
+
+    def visit_for_(self, op: For) -> Stmt:
+        """Visit For.
+        Users can customize this function to overwrite VisitStmt_(const 
ForNode* op)
+        on the C++ side.
+
+        Parameters
+        ----------
+        op : For
+            The For to be visited.
+
+        Returns
+        -------
+        result : Stmt
+            The mutated Stmt.
+        """
+        return _ffi_api.PyStmtExprMutatorDefaultVisitStmt(self._outer(), op)  
# type: ignore
+
+    def visit_while_(self, op: While) -> Stmt:
+        """Visit While.
+        Users can customize this function to overwrite VisitStmt_(const 
WhileNode* op)
+        on the C++ side.
+
+        Parameters
+        ----------
+        op : While
+            The While to be visited.
+
+        Returns
+        -------
+        result : Stmt
+            The mutated Stmt.
+        """
+        return _ffi_api.PyStmtExprMutatorDefaultVisitStmt(self._outer(), op)  
# type: ignore
+
+    def visit_allocate_(self, op: Allocate) -> Stmt:
+        """Visit Allocate.
+        Users can customize this function to overwrite VisitStmt_(const 
AllocateNode* op)
+        on the C++ side.
+
+        Parameters
+        ----------
+        op : Allocate
+            The Allocate to be visited.
+
+        Returns
+        -------
+        result : Stmt
+            The mutated Stmt.
+        """
+        return _ffi_api.PyStmtExprMutatorDefaultVisitStmt(self._outer(), op)  
# type: ignore
+
+    def visit_allocate_const_(self, op: AllocateConst) -> Stmt:
+        """Visit AllocateConst.
+        Users can customize this function to overwrite VisitStmt_(const 
AllocateConstNode* op)
+        on the C++ side.
+
+        Parameters
+        ----------
+        op : AllocateConst
+            The AllocateConst to be visited.
+
+        Returns
+        -------
+        result : Stmt
+            The mutated Stmt.
+        """
+        return _ffi_api.PyStmtExprMutatorDefaultVisitStmt(self._outer(), op)  
# type: ignore
+
+    def visit_decl_buffer_(self, op: DeclBuffer) -> Stmt:
+        """Visit DeclBuffer.
+        Users can customize this function to overwrite VisitStmt_(const 
DeclBufferNode* op)
+        on the C++ side.
+
+        Parameters
+        ----------
+        op : DeclBuffer
+            The DeclBuffer to be visited.
+
+        Returns
+        -------
+        result : Stmt
+            The mutated Stmt.
+        """
+        return _ffi_api.PyStmtExprMutatorDefaultVisitStmt(self._outer(), op)  
# type: ignore
+
+    def visit_buffer_store_(self, op: BufferStore) -> Stmt:
+        """Visit BufferStore.
+        Users can customize this function to overwrite VisitStmt_(const 
BufferStoreNode* op)
+        on the C++ side.
+
+        Parameters
+        ----------
+        op : BufferStore
+            The BufferStore to be visited.
+
+        Returns
+        -------
+        result : Stmt
+            The mutated Stmt.
+        """
+        return _ffi_api.PyStmtExprMutatorDefaultVisitStmt(self._outer(), op)  
# type: ignore
+
+    def visit_buffer_realize_(self, op: BufferRealize) -> Stmt:
+        """Visit BufferRealize.
+        Users can customize this function to overwrite VisitStmt_(const 
BufferRealizeNode* op)
+        on the C++ side.
+
+        Parameters
+        ----------
+        op : BufferRealize
+            The BufferRealize to be visited.
+
+        Returns
+        -------
+        result : Stmt
+            The mutated Stmt.
+        """
+        return _ffi_api.PyStmtExprMutatorDefaultVisitStmt(self._outer(), op)  
# type: ignore
+
+    def visit_assert_stmt_(self, op: AssertStmt) -> Stmt:
+        """Visit AssertStmt.
+        Users can customize this function to overwrite VisitStmt_(const 
AssertStmtNode* op)
+        on the C++ side.
+
+        Parameters
+        ----------
+        op : AssertStmt
+            The AssertStmt to be visited.
+
+        Returns
+        -------
+        result : Stmt
+            The mutated Stmt.
+        """
+        return _ffi_api.PyStmtExprMutatorDefaultVisitStmt(self._outer(), op)  
# type: ignore
+
+    def visit_seq_stmt_(self, op: SeqStmt) -> Stmt:
+        """Visit SeqStmt.
+        Users can customize this function to overwrite VisitStmt_(const 
SeqStmtNode* op)
+        on the C++ side.
+
+        Parameters
+        ----------
+        op : SeqStmt
+            The SeqStmt to be visited.
+
+        Returns
+        -------
+        result : Stmt
+            The mutated Stmt.
+        """
+        return _ffi_api.PyStmtExprMutatorDefaultVisitStmt(self._outer(), op)  
# type: ignore
+
+    def visit_evaluate_(self, op: Evaluate) -> Stmt:
+        """Visit Evaluate.
+        Users can customize this function to overwrite VisitStmt_(const 
EvaluateNode* op)
+        on the C++ side.
+
+        Parameters
+        ----------
+        op : Evaluate
+            The Evaluate to be visited.
+
+        Returns
+        -------
+        result : Stmt
+            The mutated Stmt.
+        """
+        return _ffi_api.PyStmtExprMutatorDefaultVisitStmt(self._outer(), op)  
# type: ignore
+
+    def visit_block_(self, op: Block) -> Stmt:
+        """Visit Block.
+        Users can customize this function to overwrite VisitStmt_(const 
BlockNode* op)
+        on the C++ side.
+
+        Parameters
+        ----------
+        op : Block
+            The Block to be visited.
+
+        Returns
+        -------
+        result : Stmt
+            The mutated Stmt.
+        """
+        return _ffi_api.PyStmtExprMutatorDefaultVisitStmt(self._outer(), op)  
# type: ignore
+
+    def visit_block_realize_(self, op: BlockRealize) -> Stmt:
+        """Visit BlockRealize.
+        Users can customize this function to overwrite VisitStmt_(const 
BlockRealizeNode* op)
+        on the C++ side.
+
+        Parameters
+        ----------
+        op : BlockRealize
+            The BlockRealize to be visited.
+
+        Returns
+        -------
+        result : Stmt
+            The mutated Stmt.
+        """
+        return _ffi_api.PyStmtExprMutatorDefaultVisitStmt(self._outer(), op)  
# type: ignore
+
+    def visit_var_(self, op: Var) -> PrimExpr:
+        """Visit Var.
+
+        Users can customize this function to overwrite VisitVar_(const 
VarNode* op)
+        on the C++ side.
+
+        Parameters
+        ----------
+        op : Var
+            The Var to be visited.
+
+        Returns
+        -------
+        result : PrimExpr
+            The mutated PrimExpr.
+        """
+        return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op)  
# type: ignore
+
+    def visit_size_var_(self, op: SizeVar) -> PrimExpr:
+        """Visit SizeVar.
+
+        Users can customize this function to overwrite VisitSizeVar_(const 
SizeVarNode* op)
+        on the C++ side.
+
+        Parameters
+        ----------
+        op : SizeVar
+            The SizeVar to be visited.
+
+        Returns
+        -------
+        result : PrimExpr
+            The mutated PrimExpr.
+        """
+        return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op)  
# type: ignore
+
+    def visit_buffer_load_(self, op: BufferLoad) -> PrimExpr:
+        """Visit BufferLoad.
+
+        Users can customize this function to overwrite VisitBufferLoad_(const 
BufferLoadNode* op)
+        on the C++ side.
+
+        Parameters
+        ----------
+        op : BufferLoad
+            The BufferLoad to be visited.
+
+        Returns
+        -------
+        result : PrimExpr
+            The mutated PrimExpr.
+        """
+        return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op)  
# type: ignore
+
+    def visit_producer_load_(self, op: ProducerLoad) -> PrimExpr:
+        """Visit ProducerLoad.
+
+        Users can customize this function to overwrite
+        VisitProducerLoad_(const ProducerLoadNode* op) on the C++ side.
+
+        Parameters
+        ----------
+        op : ProducerLoad
+            The ProducerLoad to be visited.
+
+        Returns
+        -------
+        result : PrimExpr
+            The mutated PrimExpr.
+        """
+        return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op)  
# type: ignore
+
+    def visit_let_(self, op: Let) -> PrimExpr:
+        """Visit Let.
+
+        Users can customize this function to overwrite VisitLet_(const 
LetNode* op)
+        on the C++ side.
+
+        Parameters
+        ----------
+        op : Let
+            The Let to be visited.
+
+        Returns
+        -------
+        result : PrimExpr
+            The mutated PrimExpr.
+        """
+        return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op)  
# type: ignore
+
+    def visit_call_(self, op: Call) -> PrimExpr:
+        """Visit Call.
+
+        Users can customize this function to overwrite VisitCall_(const 
CallNode* op)
+        on the C++ side.
+
+        Parameters
+        ----------
+        op : Call
+            The Call to be visited.
+
+        Returns
+        -------
+        result : PrimExpr
+            The mutated PrimExpr.
+        """
+        return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op)  
# type: ignore
+
+    def visit_add_(self, op: Add) -> PrimExpr:
+        """Visit Add.
+
+        Users can customize this function to overwrite VisitAdd_(const 
AddNode* op)
+        on the C++ side.
+
+        Parameters
+        ----------
+        op : Add
+            The Add to be visited.
+
+        Returns
+        -------
+        result : PrimExpr
+            The mutated PrimExpr.
+        """
+        return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op)  
# type: ignore
+
+    def visit_sub_(self, op: Sub) -> PrimExpr:
+        """Visit Sub.
+
+        Users can customize this function to overwrite VisitSub_(const 
SubNode* op)
+        on the C++ side.
+
+        Parameters
+        ----------
+        op : Sub
+            The Sub to be visited.
+
+        Returns
+        -------
+        result : PrimExpr
+            The mutated PrimExpr.
+        """
+        return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op)  
# type: ignore
+
+    def visit_mul_(self, op: Mul) -> PrimExpr:
+        """Visit Mul.
+
+        Users can customize this function to overwrite VisitMul_(const 
MulNode* op)
+        on the C++ side.
+
+        Parameters
+        ----------
+        op : Mul
+            The Mul to be visited.
+
+        Returns
+        -------
+        result : PrimExpr
+            The mutated PrimExpr.
+        """
+        return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op)  
# type: ignore
+
+    def visit_div_(self, op: Div) -> PrimExpr:
+        """Visit Div.
+
+        Users can customize this function to overwrite VisitDiv_(const 
DivNode* op)
+        on the C++ side.
+
+        Parameters
+        ----------
+        op : Div
+            The Div to be visited.
+
+        Returns
+        -------
+        result : PrimExpr
+            The mutated PrimExpr.
+        """
+        return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op)  
# type: ignore
+
+    def visit_mod_(self, op: Mod) -> PrimExpr:
+        """Visit Mod.
+
+        Users can customize this function to overwrite VisitMod_(const 
ModNode* op)
+        on the C++ side.
+
+        Parameters
+        ----------
+        op : Mod
+            The Mod to be visited.
+
+        Returns
+        -------
+        result : PrimExpr
+            The mutated PrimExpr.
+        """
+        return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op)  
# type: ignore
+
+    def visit_floor_div_(self, op: FloorDiv) -> PrimExpr:
+        """Visit FloorDiv.
+
+        Users can customize this function to overwrite VisitFloorDiv_(const 
FloorDivNode* op)
+        on the C++ side.
+
+        Parameters
+        ----------
+        op : FloorDiv
+            The FloorDiv to be visited.
+
+        Returns
+        -------
+        result : PrimExpr
+            The mutated PrimExpr.
+        """
+        return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op)  
# type: ignore
+
+    def visit_floor_mod_(self, op: FloorMod) -> PrimExpr:
+        """Visit FloorMod.
+
+        Users can customize this function to overwrite VisitFloorMod_(const 
FloorModNode* op)
+        on the C++ side.
+
+        Parameters
+        ----------
+        op : FloorMod
+            The FloorMod to be visited.
+
+        Returns
+        -------
+        result : PrimExpr
+            The mutated PrimExpr.
+        """
+        return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op)  
# type: ignore
+
+    def visit_min_(self, op: Min) -> PrimExpr:
+        """Visit Min.
+
+        Users can customize this function to overwrite VisitMin_(const 
MinNode* op)
+        on the C++ side.
+
+        Parameters
+        ----------
+        op : Min
+            The Min to be visited.
+
+        Returns
+        -------
+        result : PrimExpr
+            The mutated PrimExpr.
+        """
+        return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op)  
# type: ignore
+
+    def visit_max_(self, op: Max) -> PrimExpr:
+        """Visit Max.
+
+        Users can customize this function to overwrite VisitMax_(const 
MaxNode* op)
+        on the C++ side.
+
+        Parameters
+        ----------
+        op : Max
+            The Max to be visited.
+
+        Returns
+        -------
+        result : PrimExpr
+            The mutated PrimExpr.
+        """
+        return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op)  
# type: ignore
+
+    def visit_eq_(self, op: EQ) -> PrimExpr:
+        """Visit EQ.
+
+        Users can customize this function to overwrite VisitEQ_(const EQNode* 
op)
+        on the C++ side.
+
+        Parameters
+        ----------
+        op : EQ
+            The EQ to be visited.
+
+        Returns
+        -------
+        result : PrimExpr
+            The mutated PrimExpr.
+        """
+        return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op)  
# type: ignore
+
+    def visit_ne_(self, op: NE) -> PrimExpr:
+        """Visit NE.
+
+        Users can customize this function to overwrite VisitNE_(const NENode* 
op)
+        on the C++ side.
+
+        Parameters
+        ----------
+        op : NE
+            The NE to be visited.
+
+        Returns
+        -------
+        result : PrimExpr
+            The mutated PrimExpr.
+        """
+        return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op)  
# type: ignore
+
+    def visit_lt_(self, op: LT) -> PrimExpr:
+        """Visit LT.
+
+        Users can customize this function to overwrite VisitLT_(const LTNode* 
op)
+        on the C++ side.
+
+        Parameters
+        ----------
+        op : LT
+            The LT to be visited.
+
+        Returns
+        -------
+        result : PrimExpr
+            The mutated PrimExpr.
+        """
+        return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op)  
# type: ignore
+
+    def visit_le_(self, op: LE) -> PrimExpr:
+        """Visit LE.
+
+        Users can customize this function to overwrite VisitLE_(const LENode* 
op)
+        on the C++ side.
+
+        Parameters
+        ----------
+        op : LE
+            The LE to be visited.
+
+        Returns
+        -------
+        result : PrimExpr
+            The mutated PrimExpr.
+        """
+        return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op)  
# type: ignore
+
+    def visit_gt_(self, op: GT) -> PrimExpr:
+        """Visit GT.
+
+        Users can customize this function to overwrite VisitGT_(const GTNode* 
op)
+        on the C++ side.
+
+        Parameters
+        ----------
+        op : GT
+            The GT to be visited.
+
+        Returns
+        -------
+        result : PrimExpr
+            The mutated PrimExpr.
+        """
+        return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op)  
# type: ignore
+
+    def visit_ge_(self, op: GE) -> PrimExpr:
+        """Visit GE.
+
+        Users can customize this function to overwrite VisitGE_(const GENode* 
op)
+        on the C++ side.
+
+        Parameters
+        ----------
+        op : GE
+            The GE to be visited.
+
+        Returns
+        -------
+        result : PrimExpr
+            The mutated PrimExpr.
+        """
+        return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op)  
# type: ignore
+
+    def visit_and_(self, op: And) -> PrimExpr:
+        """Visit And.
+
+        Users can customize this function to overwrite VisitAnd_(const 
AndNode* op)
+        on the C++ side.
+
+        Parameters
+        ----------
+        op : And
+            The And to be visited.
+
+        Returns
+        -------
+        result : PrimExpr
+            The mutated PrimExpr.
+        """
+        return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op)  
# type: ignore
+
+    def visit_or_(self, op: Or) -> PrimExpr:
+        """Visit Or.
+
+        Users can customize this function to overwrite VisitOr_(const OrNode* 
op)
+        on the C++ side.
+
+        Parameters
+        ----------
+        op : Or
+            The Or to be visited.
+
+        Returns
+        -------
+        result : PrimExpr
+            The mutated PrimExpr.
+        """
+        return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op)  
# type: ignore
+
+    def visit_reduce_(self, op: Reduce) -> PrimExpr:
+        """Visit Reduce.
+
+        Users can customize this function to overwrite VisitReduce_(const 
ReduceNode* op)
+        on the C++ side.
+
+        Parameters
+        ----------
+        op : Reduce
+            The Reduce to be visited.
+
+        Returns
+        -------
+        result : PrimExpr
+            The mutated PrimExpr.
+        """
+        return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op)  
# type: ignore
+
+    def visit_cast_(self, op: Cast) -> PrimExpr:
+        """Visit Cast.
+
+        Users can customize this function to overwrite VisitCast_(const 
CastNode* op)
+        on the C++ side.
+
+        Parameters
+        ----------
+        op : Cast
+            The Cast to be visited.
+
+        Returns
+        -------
+        result : PrimExpr
+            The mutated PrimExpr.
+        """
+        return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op)  
# type: ignore
+
+    def visit_not_(self, op: Not) -> PrimExpr:
+        """Visit Not.
+
+        Users can customize this function to overwrite VisitNot_(const 
NotNode* op)
+        on the C++ side.
+
+        Parameters
+        ----------
+        op : Not
+            The Not to be visited.
+
+        Returns
+        -------
+        result : PrimExpr
+            The mutated PrimExpr.
+        """
+        return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op)  
# type: ignore
+
+    def visit_select_(self, op: Select) -> PrimExpr:
+        """Visit Select.
+
+        Users can customize this function to overwrite VisitSelect_(const 
SelectNode* op)
+        on the C++ side.
+
+        Parameters
+        ----------
+        op : Select
+            The Select to be visited.
+
+        Returns
+        -------
+        result : PrimExpr
+            The mutated PrimExpr.
+        """
+        return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op)  
# type: ignore
+
+    def visit_ramp_(self, op: Ramp) -> PrimExpr:
+        """Visit Ramp.
+
+        Users can customize this function to overwrite VisitRamp_(const 
RampNode* op)
+        on the C++ side.
+
+        Parameters
+        ----------
+        op : Ramp
+            The Ramp to be visited.
+
+        Returns
+        -------
+        result : PrimExpr
+            The mutated PrimExpr.
+        """
+        return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op)  
# type: ignore
+
+    def visit_broadcast_(self, op: Broadcast) -> PrimExpr:
+        """Visit Broadcast.
+
+        Users can customize this function to overwrite VisitBroadcast_(const 
BroadcastNode* op)
+        on the C++ side.
+
+        Parameters
+        ----------
+        op : Broadcast
+            The Broadcast to be visited.
+
+        Returns
+        -------
+        result : PrimExpr
+            The mutated PrimExpr.
+        """
+        return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op)  
# type: ignore
+
+    def visit_shuffle_(self, op: Shuffle) -> PrimExpr:
+        """Visit Shuffle.
+
+        Users can customize this function to overwrite VisitShuffle_(const 
ShuffleNode* op)
+        on the C++ side.
+
+        Parameters
+        ----------
+        op : Shuffle
+            The Shuffle to be visited.
+
+        Returns
+        -------
+        result : PrimExpr
+            The mutated PrimExpr.
+        """
+        return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op)  
# type: ignore
+
+    def visit_int_imm_(self, op: IntImm) -> PrimExpr:
+        """Visit IntImm.
+
+        Users can customize this function to overwrite VisitIntImm_(const 
IntImmNode* op)
+        on the C++ side.
+
+        Parameters
+        ----------
+        op : IntImm
+            The IntImm to be visited.
+
+        Returns
+        -------
+        result : PrimExpr
+            The mutated PrimExpr.
+        """
+        return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op)  
# type: ignore
+
+    def visit_float_imm_(self, op: FloatImm) -> PrimExpr:
+        """Visit FloatImm.
+
+        Users can customize this function to overwrite VisitFloatImm_(const 
FloatImmNode* op)
+        on the C++ side.
+
+        Parameters
+        ----------
+        op : FloatImm
+            The FloatImm to be visited.
+
+        Returns
+        -------
+        result : PrimExpr
+            The mutated PrimExpr.
+        """
+        return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op)  
# type: ignore
+
+    def visit_string_imm_(self, op: StringImm) -> PrimExpr:
+        """Visit StringImm.
+
+        Users can customize this function to overwrite VisitStringImm_(const 
StringImmNode* op)
+        on the C++ side.
+
+        Parameters
+        ----------
+        op : StringImm
+            The StringImm to be visited.
+
+        Returns
+        -------
+        result : PrimExpr
+            The mutated PrimExpr.
+        """
+        return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op)  
# type: ignore
diff --git a/src/tir/analysis/buffer_access_lca_detector.cc 
b/src/tir/analysis/buffer_access_lca_detector.cc
index aca4c99e11..7ac84ce894 100644
--- a/src/tir/analysis/buffer_access_lca_detector.cc
+++ b/src/tir/analysis/buffer_access_lca_detector.cc
@@ -147,7 +147,7 @@ class LCADetector : public StmtExprVisitor {
     auto do_collect_itervar_scope = [this](const IterVar& itervar,
                                            const PrimExpr& binding) -> const 
ScopeInfo* {
       const ScopeInfo* highest_scope = nullptr;
-      PostOrderVisit(binding, [this, &itervar, &highest_scope](const 
ObjectRef& obj) {
+      PostOrderVisit(binding, [this, &highest_scope](const ObjectRef& obj) {
         if (const VarNode* loop_var = obj.as<VarNode>()) {
           auto it = loop_scope_map_.find(loop_var);
           if (it == loop_scope_map_.end()) {
diff --git a/src/tir/ir/py_functor.cc b/src/tir/ir/py_functor.cc
new file mode 100644
index 0000000000..6152c99eaf
--- /dev/null
+++ b/src/tir/ir/py_functor.cc
@@ -0,0 +1,859 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/tir/ir/py_functor.cc
+ * \brief The Python interface of StmtExprVisitor/StmtExprMutator, through which
+ *        Python code can customize both statement and expression visiting.
+ */
+
+#include <tvm/tir/expr_functor.h>
+#include <tvm/tir/stmt_functor.h>
+
+namespace tvm {
+namespace tir {
+
+// ================================================
+// Helper Macros
+// ================================================
+// Override `VisitExpr_(const OP*)` so that it calls the Python callback PY_FUNC
+// when one was supplied, and otherwise runs StmtExprVisitor::VisitExpr_.
+#define PY_EXPR_VISITOR_DISPATCH(OP, PY_FUNC) \
+  void VisitExpr_(const OP* op) override {    \
+    if (PY_FUNC != nullptr) {                 \
+      PY_FUNC(op);                            \
+    } else {                                  \
+      StmtExprVisitor::VisitExpr_(op);        \
+    }                                         \
+  }
+
+// Register an OP entry in the `vtable` in scope at the expansion site (whose
+// callables take a `TSelf*`). The entry calls StmtExprVisitor::VisitExpr_ through a
+// qualified, non-virtual call, so the Python-dispatching override is skipped.
+#define PY_EXPR_VISITOR_DEFAULT_DISPATCH(OP)                             \
+  vtable.template set_dispatch<OP>([](const ObjectRef& n, TSelf* self) { \
+    self->StmtExprVisitor::VisitExpr_(static_cast<const OP*>(n.get()));  \
+  });
+
+// Original spelling of PY_EXPR_VISITOR_DEFAULT_DISPATCH, kept so that existing
+// call sites continue to compile. Prefer the PY_* name, which matches its siblings.
+#define IR_EXPR_VISITOR_DEFAULT_DISPATCH(OP) PY_EXPR_VISITOR_DEFAULT_DISPATCH(OP)
+
+// Statement counterpart of PY_EXPR_VISITOR_DISPATCH: overrides
+// `VisitStmt_(const OP*)` and falls back to StmtExprVisitor::VisitStmt_.
+#define PY_STMT_VISITOR_DISPATCH(OP, PY_FUNC) \
+  void VisitStmt_(const OP* op) override {    \
+    if (PY_FUNC != nullptr) {                 \
+      PY_FUNC(op);                            \
+    } else {                                  \
+      StmtExprVisitor::VisitStmt_(op);        \
+    }                                         \
+  }
+
+// Statement counterpart of PY_EXPR_VISITOR_DEFAULT_DISPATCH.
+#define PY_STMT_VISITOR_DEFAULT_DISPATCH(OP)                             \
+  vtable.template set_dispatch<OP>([](const ObjectRef& n, TSelf* self) { \
+    self->StmtExprVisitor::VisitStmt_(static_cast<const OP*>(n.get()));  \
+  });
+
+// Mutator form of PY_EXPR_VISITOR_DISPATCH. The Python callback's result is cast
+// to PrimExpr and returned; otherwise StmtExprMutator::VisitExpr_'s result is used.
+#define PY_EXPR_MUTATOR_DISPATCH(OP, PY_FUNC)  \
+  PrimExpr VisitExpr_(const OP* op) override { \
+    if (PY_FUNC != nullptr) {                  \
+      return PY_FUNC(op).cast<PrimExpr>();     \
+    } else {                                   \
+      return StmtExprMutator::VisitExpr_(op);  \
+    }                                          \
+  }
+
+// Mutator form of PY_EXPR_VISITOR_DEFAULT_DISPATCH; the vtable entry returns the
+// result of the qualified StmtExprMutator::VisitExpr_ call.
+#define PY_EXPR_MUTATOR_DEFAULT_DISPATCH(OP)                                   \
+  vtable.template set_dispatch<OP>([](const ObjectRef& n, TSelf* self) {       \
+    return self->StmtExprMutator::VisitExpr_(static_cast<const OP*>(n.get())); \
+  });
+
+// Mutator form of PY_STMT_VISITOR_DISPATCH; results are cast to / returned as Stmt.
+#define PY_STMT_MUTATOR_DISPATCH(OP, PY_FUNC) \
+  Stmt VisitStmt_(const OP* op) override {    \
+    if (PY_FUNC != nullptr) {                 \
+      return PY_FUNC(op).cast<Stmt>();        \
+    } else {                                  \
+      return StmtExprMutator::VisitStmt_(op); \
+    }                                         \
+  }
+
+// Mutator form of PY_STMT_VISITOR_DEFAULT_DISPATCH.
+#define PY_STMT_MUTATOR_DEFAULT_DISPATCH(OP)                                   \
+  vtable.template set_dispatch<OP>([](const ObjectRef& n, TSelf* self) {       \
+    return self->StmtExprMutator::VisitStmt_(static_cast<const OP*>(n.get())); \
+  });
+
+/*!
+ * \brief The Python interface of StmtExprVisitor.
+ *
+ * Each `VisitStmt_` / `VisitExpr_` overload declared below calls the matching
+ * `f_visit_*` packed function when it is non-null, and otherwise falls back to the
+ * StmtExprVisitor implementation. DefaultVisitExpr / DefaultVisitStmt always run the
+ * StmtExprVisitor implementation for the node's type, never a Python callback.
+ */
+class PyStmtExprVisitorNode : public Object, public StmtExprVisitor {
+ private:
+  using TSelf = PyStmtExprVisitorNode;
+  using FExprType = tvm::NodeFunctor<void(const ObjectRef& n, TSelf* self)>;
+  using FStmtType = tvm::NodeFunctor<void(const ObjectRef& n, TSelf* self)>;
+
+ public:
+  // Expression functions
+  /*!
+   * \brief The packed function to the `VisitExpr(const PrimExpr& expr)` function.
+   * \note NOTE(review): no member of this class reads this field; confirm it is
+   *       consumed elsewhere, or add a VisitExpr override that dispatches to it.
+   */
+  ffi::Function f_visit_expr{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const VarNode* op)` function. */
+  ffi::Function f_visit_var{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const SizeVarNode* op)` function. */
+  ffi::Function f_visit_size_var{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const BufferLoadNode* op)` function. */
+  ffi::Function f_visit_buffer_load{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const ProducerLoadNode* op)` function. */
+  ffi::Function f_visit_producer_load{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const LetNode* op)` function. */
+  ffi::Function f_visit_let{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const CallNode* op)` function. */
+  ffi::Function f_visit_call{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const AddNode* op)` function. */
+  ffi::Function f_visit_add{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const SubNode* op)` function. */
+  ffi::Function f_visit_sub{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const MulNode* op)` function. */
+  ffi::Function f_visit_mul{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const DivNode* op)` function. */
+  ffi::Function f_visit_div{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const ModNode* op)` function. */
+  ffi::Function f_visit_mod{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const FloorDivNode* op)` function. */
+  ffi::Function f_visit_floor_div{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const FloorModNode* op)` function. */
+  ffi::Function f_visit_floor_mod{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const MinNode* op)` function. */
+  ffi::Function f_visit_min{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const MaxNode* op)` function. */
+  ffi::Function f_visit_max{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const EQNode* op)` function. */
+  ffi::Function f_visit_eq{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const NENode* op)` function. */
+  ffi::Function f_visit_ne{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const LTNode* op)` function. */
+  ffi::Function f_visit_lt{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const LENode* op)` function. */
+  ffi::Function f_visit_le{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const GTNode* op)` function. */
+  ffi::Function f_visit_gt{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const GENode* op)` function. */
+  ffi::Function f_visit_ge{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const AndNode* op)` function. */
+  ffi::Function f_visit_and{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const OrNode* op)` function. */
+  ffi::Function f_visit_or{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const ReduceNode* op)` function. */
+  ffi::Function f_visit_reduce{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const CastNode* op)` function. */
+  ffi::Function f_visit_cast{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const NotNode* op)` function. */
+  ffi::Function f_visit_not{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const SelectNode* op)` function. */
+  ffi::Function f_visit_select{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const RampNode* op)` function. */
+  ffi::Function f_visit_ramp{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const BroadcastNode* op)` function. */
+  ffi::Function f_visit_broadcast{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const ShuffleNode* op)` function. */
+  ffi::Function f_visit_shuffle{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const IntImmNode* op)` function. */
+  ffi::Function f_visit_int_imm{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const FloatImmNode* op)` function. */
+  ffi::Function f_visit_float_imm{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const StringImmNode* op)` function. */
+  ffi::Function f_visit_string_imm{nullptr};
+
+  // Statement functions
+  /*!
+   * \brief The packed function to the `VisitStmt(const Stmt& stmt)` function.
+   * \note NOTE(review): no member of this class reads this field; confirm it is
+   *       consumed elsewhere, or add a VisitStmt override that dispatches to it.
+   */
+  ffi::Function f_visit_stmt{nullptr};
+  /*! \brief The packed function to the `VisitStmt_(const LetStmtNode* op)` function. */
+  ffi::Function f_visit_let_stmt{nullptr};
+  /*! \brief The packed function to the `VisitStmt_(const AttrStmtNode* op)` function. */
+  ffi::Function f_visit_attr_stmt{nullptr};
+  /*! \brief The packed function to the `VisitStmt_(const IfThenElseNode* op)` function. */
+  ffi::Function f_visit_if_then_else{nullptr};  // NOLINT(readability/braces)
+  /*! \brief The packed function to the `VisitStmt_(const ForNode* op)` function. */
+  ffi::Function f_visit_for{nullptr};
+  /*! \brief The packed function to the `VisitStmt_(const WhileNode* op)` function. */
+  ffi::Function f_visit_while{nullptr};
+  /*! \brief The packed function to the `VisitStmt_(const AllocateNode* op)` function. */
+  ffi::Function f_visit_allocate{nullptr};
+  /*! \brief The packed function to the `VisitStmt_(const AllocateConstNode* op)` function. */
+  ffi::Function f_visit_allocate_const{nullptr};
+  /*! \brief The packed function to the `VisitStmt_(const DeclBufferNode* op)` function. */
+  ffi::Function f_visit_decl_buffer{nullptr};
+  /*! \brief The packed function to the `VisitStmt_(const BufferStoreNode* op)` function. */
+  ffi::Function f_visit_buffer_store{nullptr};
+  /*! \brief The packed function to the `VisitStmt_(const BufferRealizeNode* op)` function. */
+  ffi::Function f_visit_buffer_realize{nullptr};
+  /*! \brief The packed function to the `VisitStmt_(const AssertStmtNode* op)` function. */
+  ffi::Function f_visit_assert_stmt{nullptr};
+  /*! \brief The packed function to the `VisitStmt_(const SeqStmtNode* op)` function. */
+  ffi::Function f_visit_seq_stmt{nullptr};
+  /*! \brief The packed function to the `VisitStmt_(const EvaluateNode* op)` function. */
+  ffi::Function f_visit_evaluate{nullptr};
+  /*! \brief The packed function to the `VisitStmt_(const BlockNode* op)` function. */
+  ffi::Function f_visit_block{nullptr};
+  /*! \brief The packed function to the `VisitStmt_(const BlockRealizeNode* op)` function. */
+  ffi::Function f_visit_block_realize{nullptr};
+
+  // Bring the base-class VisitExpr / VisitStmt entry points into this class's
+  // public interface.
+  using StmtExprVisitor::VisitExpr;
+  using StmtExprVisitor::VisitStmt;
+
+  /*!
+   * \brief Visit `expr` with the StmtExprVisitor implementation for its node type.
+   *
+   * Python callbacks are bypassed because the vtable entries make qualified
+   * (non-virtual) calls to StmtExprVisitor::VisitExpr_. The function-local static
+   * vtable is built once, on first use.
+   */
+  void DefaultVisitExpr(const PrimExpr& expr) {
+    static FExprType vtable = InitExprVTable();
+    vtable(expr, this);
+  }
+
+  /*!
+   * \brief Visit `stmt` with the StmtExprVisitor implementation for its node type.
+   *
+   * Statement counterpart of DefaultVisitExpr; Python callbacks are bypassed.
+   */
+  void DefaultVisitStmt(const Stmt& stmt) {
+    static FStmtType vtable = InitStmtVTable();
+    vtable(stmt, this);
+  }
+
+  // This node exposes no reflected attributes.
+  void VisitAttrs(AttrVisitor* v) {}
+  static constexpr const char* _type_key = "tir.PyStmtExprVisitor";
+  TVM_DECLARE_BASE_OBJECT_INFO(PyStmtExprVisitorNode, Object);
+
+ private:
+  // Statement functions: dispatch to the Python callback if set, else to the base.
+  PY_STMT_VISITOR_DISPATCH(LetStmtNode, f_visit_let_stmt);
+  PY_STMT_VISITOR_DISPATCH(AttrStmtNode, f_visit_attr_stmt);
+  PY_STMT_VISITOR_DISPATCH(IfThenElseNode, f_visit_if_then_else);
+  PY_STMT_VISITOR_DISPATCH(ForNode, f_visit_for);
+  PY_STMT_VISITOR_DISPATCH(WhileNode, f_visit_while);
+  PY_STMT_VISITOR_DISPATCH(AllocateNode, f_visit_allocate);
+  PY_STMT_VISITOR_DISPATCH(AllocateConstNode, f_visit_allocate_const);
+  PY_STMT_VISITOR_DISPATCH(DeclBufferNode, f_visit_decl_buffer);
+  PY_STMT_VISITOR_DISPATCH(BufferStoreNode, f_visit_buffer_store);
+  PY_STMT_VISITOR_DISPATCH(BufferRealizeNode, f_visit_buffer_realize);
+  PY_STMT_VISITOR_DISPATCH(AssertStmtNode, f_visit_assert_stmt);
+  PY_STMT_VISITOR_DISPATCH(SeqStmtNode, f_visit_seq_stmt);
+  PY_STMT_VISITOR_DISPATCH(EvaluateNode, f_visit_evaluate);
+  PY_STMT_VISITOR_DISPATCH(BlockNode, f_visit_block);
+  PY_STMT_VISITOR_DISPATCH(BlockRealizeNode, f_visit_block_realize);
+  // Expression functions: dispatch to the Python callback if set, else to the base.
+  PY_EXPR_VISITOR_DISPATCH(VarNode, f_visit_var);
+  PY_EXPR_VISITOR_DISPATCH(SizeVarNode, f_visit_size_var);
+  PY_EXPR_VISITOR_DISPATCH(BufferLoadNode, f_visit_buffer_load);
+  PY_EXPR_VISITOR_DISPATCH(ProducerLoadNode, f_visit_producer_load);
+  PY_EXPR_VISITOR_DISPATCH(LetNode, f_visit_let);
+  PY_EXPR_VISITOR_DISPATCH(CallNode, f_visit_call);
+  PY_EXPR_VISITOR_DISPATCH(AddNode, f_visit_add);
+  PY_EXPR_VISITOR_DISPATCH(SubNode, f_visit_sub);
+  PY_EXPR_VISITOR_DISPATCH(MulNode, f_visit_mul);
+  PY_EXPR_VISITOR_DISPATCH(DivNode, f_visit_div);
+  PY_EXPR_VISITOR_DISPATCH(ModNode, f_visit_mod);
+  PY_EXPR_VISITOR_DISPATCH(FloorDivNode, f_visit_floor_div);
+  PY_EXPR_VISITOR_DISPATCH(FloorModNode, f_visit_floor_mod);
+  PY_EXPR_VISITOR_DISPATCH(MinNode, f_visit_min);
+  PY_EXPR_VISITOR_DISPATCH(MaxNode, f_visit_max);
+  PY_EXPR_VISITOR_DISPATCH(EQNode, f_visit_eq);
+  PY_EXPR_VISITOR_DISPATCH(NENode, f_visit_ne);
+  PY_EXPR_VISITOR_DISPATCH(LTNode, f_visit_lt);
+  PY_EXPR_VISITOR_DISPATCH(LENode, f_visit_le);
+  PY_EXPR_VISITOR_DISPATCH(GTNode, f_visit_gt);
+  PY_EXPR_VISITOR_DISPATCH(GENode, f_visit_ge);
+  PY_EXPR_VISITOR_DISPATCH(AndNode, f_visit_and);
+  PY_EXPR_VISITOR_DISPATCH(OrNode, f_visit_or);
+  PY_EXPR_VISITOR_DISPATCH(ReduceNode, f_visit_reduce);
+  PY_EXPR_VISITOR_DISPATCH(CastNode, f_visit_cast);
+  PY_EXPR_VISITOR_DISPATCH(NotNode, f_visit_not);
+  PY_EXPR_VISITOR_DISPATCH(SelectNode, f_visit_select);
+  PY_EXPR_VISITOR_DISPATCH(RampNode, f_visit_ramp);
+  PY_EXPR_VISITOR_DISPATCH(BroadcastNode, f_visit_broadcast);
+  PY_EXPR_VISITOR_DISPATCH(ShuffleNode, f_visit_shuffle);
+  PY_EXPR_VISITOR_DISPATCH(IntImmNode, f_visit_int_imm);
+  PY_EXPR_VISITOR_DISPATCH(FloatImmNode, f_visit_float_imm);
+  PY_EXPR_VISITOR_DISPATCH(StringImmNode, f_visit_string_imm);
+
+  /*! \brief Build the vtable used by DefaultVisitExpr (one entry per expression node type). */
+  static FExprType InitExprVTable() {
+    FExprType vtable;
+    // Set dispatch
+    IR_EXPR_VISITOR_DEFAULT_DISPATCH(VarNode);
+    IR_EXPR_VISITOR_DEFAULT_DISPATCH(SizeVarNode);
+    IR_EXPR_VISITOR_DEFAULT_DISPATCH(BufferLoadNode);
+    IR_EXPR_VISITOR_DEFAULT_DISPATCH(ProducerLoadNode);
+    IR_EXPR_VISITOR_DEFAULT_DISPATCH(LetNode);
+    IR_EXPR_VISITOR_DEFAULT_DISPATCH(CallNode);
+    IR_EXPR_VISITOR_DEFAULT_DISPATCH(AddNode);
+    IR_EXPR_VISITOR_DEFAULT_DISPATCH(SubNode);
+    IR_EXPR_VISITOR_DEFAULT_DISPATCH(MulNode);
+    IR_EXPR_VISITOR_DEFAULT_DISPATCH(DivNode);
+    IR_EXPR_VISITOR_DEFAULT_DISPATCH(ModNode);
+    IR_EXPR_VISITOR_DEFAULT_DISPATCH(FloorDivNode);
+    IR_EXPR_VISITOR_DEFAULT_DISPATCH(FloorModNode);
+    IR_EXPR_VISITOR_DEFAULT_DISPATCH(MinNode);
+    IR_EXPR_VISITOR_DEFAULT_DISPATCH(MaxNode);
+    IR_EXPR_VISITOR_DEFAULT_DISPATCH(EQNode);
+    IR_EXPR_VISITOR_DEFAULT_DISPATCH(NENode);
+    IR_EXPR_VISITOR_DEFAULT_DISPATCH(LTNode);
+    IR_EXPR_VISITOR_DEFAULT_DISPATCH(LENode);
+    IR_EXPR_VISITOR_DEFAULT_DISPATCH(GTNode);
+    IR_EXPR_VISITOR_DEFAULT_DISPATCH(GENode);
+    IR_EXPR_VISITOR_DEFAULT_DISPATCH(AndNode);
+    IR_EXPR_VISITOR_DEFAULT_DISPATCH(OrNode);
+    IR_EXPR_VISITOR_DEFAULT_DISPATCH(ReduceNode);
+    IR_EXPR_VISITOR_DEFAULT_DISPATCH(CastNode);
+    IR_EXPR_VISITOR_DEFAULT_DISPATCH(NotNode);
+    IR_EXPR_VISITOR_DEFAULT_DISPATCH(SelectNode);
+    IR_EXPR_VISITOR_DEFAULT_DISPATCH(RampNode);
+    IR_EXPR_VISITOR_DEFAULT_DISPATCH(BroadcastNode);
+    IR_EXPR_VISITOR_DEFAULT_DISPATCH(ShuffleNode);
+    IR_EXPR_VISITOR_DEFAULT_DISPATCH(IntImmNode);
+    IR_EXPR_VISITOR_DEFAULT_DISPATCH(FloatImmNode);
+    IR_EXPR_VISITOR_DEFAULT_DISPATCH(StringImmNode);
+    vtable.Finalize();
+    return vtable;
+  }
+
+  /*! \brief Build the vtable used by DefaultVisitStmt (one entry per statement node type). */
+  static FStmtType InitStmtVTable() {
+    FStmtType vtable;
+    PY_STMT_VISITOR_DEFAULT_DISPATCH(LetStmtNode);
+    PY_STMT_VISITOR_DEFAULT_DISPATCH(AttrStmtNode);
+    PY_STMT_VISITOR_DEFAULT_DISPATCH(IfThenElseNode);
+    PY_STMT_VISITOR_DEFAULT_DISPATCH(ForNode);
+    PY_STMT_VISITOR_DEFAULT_DISPATCH(WhileNode);
+    PY_STMT_VISITOR_DEFAULT_DISPATCH(AllocateNode);
+    PY_STMT_VISITOR_DEFAULT_DISPATCH(AllocateConstNode);
+    PY_STMT_VISITOR_DEFAULT_DISPATCH(DeclBufferNode);
+    PY_STMT_VISITOR_DEFAULT_DISPATCH(BufferStoreNode);
+    PY_STMT_VISITOR_DEFAULT_DISPATCH(BufferRealizeNode);
+    PY_STMT_VISITOR_DEFAULT_DISPATCH(AssertStmtNode);
+    PY_STMT_VISITOR_DEFAULT_DISPATCH(SeqStmtNode);
+    PY_STMT_VISITOR_DEFAULT_DISPATCH(EvaluateNode);
+    PY_STMT_VISITOR_DEFAULT_DISPATCH(BlockNode);
+    PY_STMT_VISITOR_DEFAULT_DISPATCH(BlockRealizeNode);
+    vtable.Finalize();
+    return vtable;
+  }
+};
+
+/*!
+ * \brief Reference-counted handle to a PyStmtExprVisitorNode.
+ * \sa PyStmtExprVisitorNode
+ */
+class PyStmtExprVisitor : public ObjectRef {
+ public:
+  /*!
+   * \brief Create a visitor node and populate it with the given packed functions.
+   *
+   * Each argument is moved into the node's `f_visit_*` field of the same name. For the
+   * per-node-type callbacks, a null function leaves the StmtExprVisitor behavior for
+   * that node type in place. The argument order is positional and must be kept
+   * stable.
+   */
+  TVM_DLL static PyStmtExprVisitor MakePyStmtExprVisitor(
+      // Whole-node entry points.
+      ffi::Function f_visit_stmt, ffi::Function f_visit_expr,
+      // Statement callbacks.
+      ffi::Function f_visit_let_stmt, ffi::Function f_visit_attr_stmt,
+      ffi::Function f_visit_if_then_else, ffi::Function f_visit_for,
+      ffi::Function f_visit_while, ffi::Function f_visit_allocate,
+      ffi::Function f_visit_allocate_const, ffi::Function f_visit_decl_buffer,
+      ffi::Function f_visit_buffer_store, ffi::Function f_visit_buffer_realize,
+      ffi::Function f_visit_assert_stmt, ffi::Function f_visit_seq_stmt,
+      ffi::Function f_visit_evaluate, ffi::Function f_visit_block,
+      ffi::Function f_visit_block_realize,
+      // Expression callbacks.
+      ffi::Function f_visit_var, ffi::Function f_visit_size_var,
+      ffi::Function f_visit_buffer_load, ffi::Function f_visit_producer_load,
+      ffi::Function f_visit_let, ffi::Function f_visit_call,
+      ffi::Function f_visit_add, ffi::Function f_visit_sub,
+      ffi::Function f_visit_mul, ffi::Function f_visit_div,
+      ffi::Function f_visit_mod, ffi::Function f_visit_floor_div,
+      ffi::Function f_visit_floor_mod, ffi::Function f_visit_min,
+      ffi::Function f_visit_max, ffi::Function f_visit_eq,
+      ffi::Function f_visit_ne, ffi::Function f_visit_lt,
+      ffi::Function f_visit_le, ffi::Function f_visit_gt,
+      ffi::Function f_visit_ge, ffi::Function f_visit_and,
+      ffi::Function f_visit_or, ffi::Function f_visit_reduce,
+      ffi::Function f_visit_cast, ffi::Function f_visit_not,
+      ffi::Function f_visit_select, ffi::Function f_visit_ramp,
+      ffi::Function f_visit_broadcast, ffi::Function f_visit_shuffle,
+      ffi::Function f_visit_int_imm, ffi::Function f_visit_float_imm,
+      ffi::Function f_visit_string_imm) {
+    auto node = make_object<PyStmtExprVisitorNode>();
+
+    // Whole-node entry points.
+    node->f_visit_stmt = std::move(f_visit_stmt);
+    node->f_visit_expr = std::move(f_visit_expr);
+
+    // Expression callbacks.
+    node->f_visit_var = std::move(f_visit_var);
+    node->f_visit_size_var = std::move(f_visit_size_var);
+    node->f_visit_buffer_load = std::move(f_visit_buffer_load);
+    node->f_visit_producer_load = std::move(f_visit_producer_load);
+    node->f_visit_let = std::move(f_visit_let);
+    node->f_visit_call = std::move(f_visit_call);
+    node->f_visit_add = std::move(f_visit_add);
+    node->f_visit_sub = std::move(f_visit_sub);
+    node->f_visit_mul = std::move(f_visit_mul);
+    node->f_visit_div = std::move(f_visit_div);
+    node->f_visit_mod = std::move(f_visit_mod);
+    node->f_visit_floor_div = std::move(f_visit_floor_div);
+    node->f_visit_floor_mod = std::move(f_visit_floor_mod);
+    node->f_visit_min = std::move(f_visit_min);
+    node->f_visit_max = std::move(f_visit_max);
+    node->f_visit_eq = std::move(f_visit_eq);
+    node->f_visit_ne = std::move(f_visit_ne);
+    node->f_visit_lt = std::move(f_visit_lt);
+    node->f_visit_le = std::move(f_visit_le);
+    node->f_visit_gt = std::move(f_visit_gt);
+    node->f_visit_ge = std::move(f_visit_ge);
+    node->f_visit_and = std::move(f_visit_and);
+    node->f_visit_or = std::move(f_visit_or);
+    node->f_visit_reduce = std::move(f_visit_reduce);
+    node->f_visit_cast = std::move(f_visit_cast);
+    node->f_visit_not = std::move(f_visit_not);
+    node->f_visit_select = std::move(f_visit_select);
+    node->f_visit_ramp = std::move(f_visit_ramp);
+    node->f_visit_broadcast = std::move(f_visit_broadcast);
+    node->f_visit_shuffle = std::move(f_visit_shuffle);
+    node->f_visit_int_imm = std::move(f_visit_int_imm);
+    node->f_visit_float_imm = std::move(f_visit_float_imm);
+    node->f_visit_string_imm = std::move(f_visit_string_imm);
+
+    // Statement callbacks.
+    node->f_visit_let_stmt = std::move(f_visit_let_stmt);
+    node->f_visit_attr_stmt = std::move(f_visit_attr_stmt);
+    node->f_visit_if_then_else = std::move(f_visit_if_then_else);
+    node->f_visit_for = std::move(f_visit_for);
+    node->f_visit_while = std::move(f_visit_while);
+    node->f_visit_allocate = std::move(f_visit_allocate);
+    node->f_visit_allocate_const = std::move(f_visit_allocate_const);
+    node->f_visit_decl_buffer = std::move(f_visit_decl_buffer);
+    node->f_visit_buffer_store = std::move(f_visit_buffer_store);
+    node->f_visit_buffer_realize = std::move(f_visit_buffer_realize);
+    node->f_visit_assert_stmt = std::move(f_visit_assert_stmt);
+    node->f_visit_seq_stmt = std::move(f_visit_seq_stmt);
+    node->f_visit_evaluate = std::move(f_visit_evaluate);
+    node->f_visit_block = std::move(f_visit_block);
+    node->f_visit_block_realize = std::move(f_visit_block_realize);
+
+    return PyStmtExprVisitor(std::move(node));
+  }
+
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(PyStmtExprVisitor, ObjectRef,
+                                                    PyStmtExprVisitorNode);
+};
+
+/*! \brief The python interface of StmtExprMutator. */
+class PyStmtExprMutatorNode : public Object, public StmtExprMutator {
+ private:
+  using TSelf = PyStmtExprMutatorNode;
+  using FExprType = tvm::NodeFunctor<PrimExpr(const ObjectRef& n, TSelf* self)>;
+  using FStmtType = tvm::NodeFunctor<Stmt(const ObjectRef& n, TSelf* self)>;
+
+ public:
+  // Expression functions
+  /*! \brief The packed function to the `VisitExpr(const PrimExpr& expr)` function. */
+  ffi::Function f_visit_expr{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const VarNode* op)` function. */
+  ffi::Function f_visit_var{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const SizeVarNode* op)` function. */
+  ffi::Function f_visit_size_var{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const BufferLoadNode* op)` function. */
+  ffi::Function f_visit_buffer_load{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const ProducerLoadNode* op)` function. */
+  ffi::Function f_visit_producer_load{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const LetNode* op)` function. */
+  ffi::Function f_visit_let{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const CallNode* op)` function. */
+  ffi::Function f_visit_call{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const AddNode* op)` function. */
+  ffi::Function f_visit_add{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const SubNode* op)` function. */
+  ffi::Function f_visit_sub{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const MulNode* op)` function. */
+  ffi::Function f_visit_mul{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const DivNode* op)` function. */
+  ffi::Function f_visit_div{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const ModNode* op)` function. */
+  ffi::Function f_visit_mod{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const FloorDivNode* op)` function. */
+  ffi::Function f_visit_floor_div{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const FloorModNode* op)` function. */
+  ffi::Function f_visit_floor_mod{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const MinNode* op)` function. */
+  ffi::Function f_visit_min{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const MaxNode* op)` function. */
+  ffi::Function f_visit_max{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const EQNode* op)` function. */
+  ffi::Function f_visit_eq{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const NENode* op)` function. */
+  ffi::Function f_visit_ne{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const LTNode* op)` function. */
+  ffi::Function f_visit_lt{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const LENode* op)` function. */
+  ffi::Function f_visit_le{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const GTNode* op)` function. */
+  ffi::Function f_visit_gt{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const GENode* op)` function. */
+  ffi::Function f_visit_ge{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const AndNode* op)` function. */
+  ffi::Function f_visit_and{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const OrNode* op)` function. */
+  ffi::Function f_visit_or{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const ReduceNode* op)` function. */
+  ffi::Function f_visit_reduce{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const CastNode* op)` function. */
+  ffi::Function f_visit_cast{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const NotNode* op)` function. */
+  ffi::Function f_visit_not{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const SelectNode* op)` function. */
+  ffi::Function f_visit_select{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const RampNode* op)` function. */
+  ffi::Function f_visit_ramp{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const BroadcastNode* op)` function. */
+  ffi::Function f_visit_broadcast{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const ShuffleNode* op)` function. */
+  ffi::Function f_visit_shuffle{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const IntImmNode* op)` function. */
+  ffi::Function f_visit_int_imm{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const FloatImmNode* op)` function. */
+  ffi::Function f_visit_float_imm{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const StringImmNode* op)` function. */
+  ffi::Function f_visit_string_imm{nullptr};
+
+  // Statement functions
+  /*! \brief The packed function to the `VisitStmt(const Stmt& stmt)` function. */
+  ffi::Function f_visit_stmt{nullptr};
+  /*! \brief The packed function to the `VisitStmt_(const LetStmtNode* op)` function. */
+  ffi::Function f_visit_let_stmt{nullptr};
+  /*! \brief The packed function to the `VisitStmt_(const AttrStmtNode* op)` function. */
+  ffi::Function f_visit_attr_stmt{nullptr};
+  /*! \brief The packed function to the `VisitStmt_(const IfThenElseNode* op)` function. */
+  ffi::Function f_visit_if_then_else{nullptr};  // NOLINT(readability/braces)
+  /*! \brief The packed function to the `VisitStmt_(const ForNode* op)` function. */
+  ffi::Function f_visit_for{nullptr};
+  /*! \brief The packed function to the `VisitStmt_(const WhileNode* op)` function. */
+  ffi::Function f_visit_while{nullptr};
+  /*! \brief The packed function to the `VisitStmt_(const AllocateNode* op)` function. */
+  ffi::Function f_visit_allocate{nullptr};
+  /*! \brief The packed function to the `VisitStmt_(const AllocateConstNode* op)` function. */
+  ffi::Function f_visit_allocate_const{nullptr};
+  /*! \brief The packed function to the `VisitStmt_(const DeclBufferNode* op)` function. */
+  ffi::Function f_visit_decl_buffer{nullptr};
+  /*! \brief The packed function to the `VisitStmt_(const BufferStoreNode* op)` function. */
+  ffi::Function f_visit_buffer_store{nullptr};
+  /*! \brief The packed function to the `VisitStmt_(const BufferRealizeNode* op)` function. */
+  ffi::Function f_visit_buffer_realize{nullptr};
+  /*! \brief The packed function to the `VisitStmt_(const AssertStmtNode* op)` function. */
+  ffi::Function f_visit_assert_stmt{nullptr};
+  /*! \brief The packed function to the `VisitStmt_(const SeqStmtNode* op)` function. */
+  ffi::Function f_visit_seq_stmt{nullptr};
+  /*! \brief The packed function to the `VisitStmt_(const EvaluateNode* op)` function. */
+  ffi::Function f_visit_evaluate{nullptr};
+  /*! \brief The packed function to the `VisitStmt_(const BlockNode* op)` function. */
+  ffi::Function f_visit_block{nullptr};
+  /*! \brief The packed function to the `VisitStmt_(const BlockRealizeNode* op)` function. */
+  ffi::Function f_visit_block_realize{nullptr};
+
+  using StmtExprMutator::VisitExpr;
+  using StmtExprMutator::VisitStmt;
+
+  /*!
+   * \brief Rewrite `expr` through the per-node-type table built by InitExprVTable().
+   * \param expr The expression to rewrite.
+   * \return The rewritten expression. It must be returned rather than dropped, because the
+   *         FFI entry "tir.PyStmtExprMutatorDefaultVisitExpr" forwards it back to Python.
+   */
+  PrimExpr DefaultVisitExpr(const PrimExpr& expr) {
+    // Function-local static: the table is built once, on first use, and shared by all instances.
+    static FExprType vtable = InitExprVTable();
+    return vtable(expr, this);
+  }
+
+  /*!
+   * \brief Rewrite `stmt` through the per-node-type table built by InitStmtVTable().
+   * \param stmt The statement to rewrite.
+   * \return The rewritten statement. It must be returned rather than dropped, because the
+   *         FFI entry "tir.PyStmtExprMutatorDefaultVisitStmt" forwards it back to Python.
+   */
+  Stmt DefaultVisitStmt(const Stmt& stmt) {
+    static FStmtType vtable = InitStmtVTable();
+    return vtable(stmt, this);
+  }
+
+  void VisitAttrs(AttrVisitor* v) {}
+  static constexpr const char* _type_key = "tir.PyStmtExprMutator";
+  TVM_DECLARE_BASE_OBJECT_INFO(PyStmtExprMutatorNode, Object);
+
+ private:
+  // Statement functions
+  PY_STMT_MUTATOR_DISPATCH(LetStmtNode, f_visit_let_stmt);
+  PY_STMT_MUTATOR_DISPATCH(AttrStmtNode, f_visit_attr_stmt);
+  PY_STMT_MUTATOR_DISPATCH(IfThenElseNode, f_visit_if_then_else);
+  PY_STMT_MUTATOR_DISPATCH(ForNode, f_visit_for);
+  PY_STMT_MUTATOR_DISPATCH(WhileNode, f_visit_while);
+  PY_STMT_MUTATOR_DISPATCH(AllocateNode, f_visit_allocate);
+  PY_STMT_MUTATOR_DISPATCH(AllocateConstNode, f_visit_allocate_const);
+  PY_STMT_MUTATOR_DISPATCH(DeclBufferNode, f_visit_decl_buffer);
+  PY_STMT_MUTATOR_DISPATCH(BufferStoreNode, f_visit_buffer_store);
+  PY_STMT_MUTATOR_DISPATCH(BufferRealizeNode, f_visit_buffer_realize);
+  PY_STMT_MUTATOR_DISPATCH(AssertStmtNode, f_visit_assert_stmt);
+  PY_STMT_MUTATOR_DISPATCH(SeqStmtNode, f_visit_seq_stmt);
+  PY_STMT_MUTATOR_DISPATCH(EvaluateNode, f_visit_evaluate);
+  PY_STMT_MUTATOR_DISPATCH(BlockNode, f_visit_block);
+  PY_STMT_MUTATOR_DISPATCH(BlockRealizeNode, f_visit_block_realize);
+  // Expression functions
+  PY_EXPR_MUTATOR_DISPATCH(VarNode, f_visit_var);
+  PY_EXPR_MUTATOR_DISPATCH(SizeVarNode, f_visit_size_var);
+  PY_EXPR_MUTATOR_DISPATCH(BufferLoadNode, f_visit_buffer_load);
+  PY_EXPR_MUTATOR_DISPATCH(ProducerLoadNode, f_visit_producer_load);
+  PY_EXPR_MUTATOR_DISPATCH(LetNode, f_visit_let);
+  PY_EXPR_MUTATOR_DISPATCH(CallNode, f_visit_call);
+  PY_EXPR_MUTATOR_DISPATCH(AddNode, f_visit_add);
+  PY_EXPR_MUTATOR_DISPATCH(SubNode, f_visit_sub);
+  PY_EXPR_MUTATOR_DISPATCH(MulNode, f_visit_mul);
+  PY_EXPR_MUTATOR_DISPATCH(DivNode, f_visit_div);
+  PY_EXPR_MUTATOR_DISPATCH(ModNode, f_visit_mod);
+  PY_EXPR_MUTATOR_DISPATCH(FloorDivNode, f_visit_floor_div);
+  PY_EXPR_MUTATOR_DISPATCH(FloorModNode, f_visit_floor_mod);
+  PY_EXPR_MUTATOR_DISPATCH(MinNode, f_visit_min);
+  PY_EXPR_MUTATOR_DISPATCH(MaxNode, f_visit_max);
+  PY_EXPR_MUTATOR_DISPATCH(EQNode, f_visit_eq);
+  PY_EXPR_MUTATOR_DISPATCH(NENode, f_visit_ne);
+  PY_EXPR_MUTATOR_DISPATCH(LTNode, f_visit_lt);
+  PY_EXPR_MUTATOR_DISPATCH(LENode, f_visit_le);
+  PY_EXPR_MUTATOR_DISPATCH(GTNode, f_visit_gt);
+  PY_EXPR_MUTATOR_DISPATCH(GENode, f_visit_ge);
+  PY_EXPR_MUTATOR_DISPATCH(AndNode, f_visit_and);
+  PY_EXPR_MUTATOR_DISPATCH(OrNode, f_visit_or);
+  PY_EXPR_MUTATOR_DISPATCH(ReduceNode, f_visit_reduce);
+  PY_EXPR_MUTATOR_DISPATCH(CastNode, f_visit_cast);
+  PY_EXPR_MUTATOR_DISPATCH(NotNode, f_visit_not);
+  PY_EXPR_MUTATOR_DISPATCH(SelectNode, f_visit_select);
+  PY_EXPR_MUTATOR_DISPATCH(RampNode, f_visit_ramp);
+  PY_EXPR_MUTATOR_DISPATCH(BroadcastNode, f_visit_broadcast);
+  PY_EXPR_MUTATOR_DISPATCH(ShuffleNode, f_visit_shuffle);
+  PY_EXPR_MUTATOR_DISPATCH(IntImmNode, f_visit_int_imm);
+  PY_EXPR_MUTATOR_DISPATCH(FloatImmNode, f_visit_float_imm);
+  PY_EXPR_MUTATOR_DISPATCH(StringImmNode, f_visit_string_imm);
+
+  // Builders for the dispatch tables consulted by DefaultVisitExpr / DefaultVisitStmt.
+  static FExprType InitExprVTable() {
+    FExprType vtable;
+    // Set dispatch
+    PY_EXPR_MUTATOR_DEFAULT_DISPATCH(VarNode);
+    PY_EXPR_MUTATOR_DEFAULT_DISPATCH(SizeVarNode);
+    PY_EXPR_MUTATOR_DEFAULT_DISPATCH(BufferLoadNode);
+    PY_EXPR_MUTATOR_DEFAULT_DISPATCH(ProducerLoadNode);
+    PY_EXPR_MUTATOR_DEFAULT_DISPATCH(LetNode);
+    PY_EXPR_MUTATOR_DEFAULT_DISPATCH(CallNode);
+    PY_EXPR_MUTATOR_DEFAULT_DISPATCH(AddNode);
+    PY_EXPR_MUTATOR_DEFAULT_DISPATCH(SubNode);
+    PY_EXPR_MUTATOR_DEFAULT_DISPATCH(MulNode);
+    PY_EXPR_MUTATOR_DEFAULT_DISPATCH(DivNode);
+    PY_EXPR_MUTATOR_DEFAULT_DISPATCH(ModNode);
+    PY_EXPR_MUTATOR_DEFAULT_DISPATCH(FloorDivNode);
+    PY_EXPR_MUTATOR_DEFAULT_DISPATCH(FloorModNode);
+    PY_EXPR_MUTATOR_DEFAULT_DISPATCH(MinNode);
+    PY_EXPR_MUTATOR_DEFAULT_DISPATCH(MaxNode);
+    PY_EXPR_MUTATOR_DEFAULT_DISPATCH(EQNode);
+    PY_EXPR_MUTATOR_DEFAULT_DISPATCH(NENode);
+    PY_EXPR_MUTATOR_DEFAULT_DISPATCH(LTNode);
+    PY_EXPR_MUTATOR_DEFAULT_DISPATCH(LENode);
+    PY_EXPR_MUTATOR_DEFAULT_DISPATCH(GTNode);
+    PY_EXPR_MUTATOR_DEFAULT_DISPATCH(GENode);
+    PY_EXPR_MUTATOR_DEFAULT_DISPATCH(AndNode);
+    PY_EXPR_MUTATOR_DEFAULT_DISPATCH(OrNode);
+    PY_EXPR_MUTATOR_DEFAULT_DISPATCH(ReduceNode);
+    PY_EXPR_MUTATOR_DEFAULT_DISPATCH(CastNode);
+    PY_EXPR_MUTATOR_DEFAULT_DISPATCH(NotNode);
+    PY_EXPR_MUTATOR_DEFAULT_DISPATCH(SelectNode);
+    PY_EXPR_MUTATOR_DEFAULT_DISPATCH(RampNode);
+    PY_EXPR_MUTATOR_DEFAULT_DISPATCH(ShuffleNode);
+    PY_EXPR_MUTATOR_DEFAULT_DISPATCH(BroadcastNode);
+    PY_EXPR_MUTATOR_DEFAULT_DISPATCH(IntImmNode);
+    PY_EXPR_MUTATOR_DEFAULT_DISPATCH(FloatImmNode);
+    PY_EXPR_MUTATOR_DEFAULT_DISPATCH(StringImmNode);
+    vtable.Finalize();
+    return vtable;
+  }
+
+  static FStmtType InitStmtVTable() {
+    FStmtType vtable;
+    PY_STMT_MUTATOR_DEFAULT_DISPATCH(LetStmtNode);
+    PY_STMT_MUTATOR_DEFAULT_DISPATCH(AttrStmtNode);
+    PY_STMT_MUTATOR_DEFAULT_DISPATCH(IfThenElseNode);
+    PY_STMT_MUTATOR_DEFAULT_DISPATCH(ForNode);
+    PY_STMT_MUTATOR_DEFAULT_DISPATCH(WhileNode);
+    PY_STMT_MUTATOR_DEFAULT_DISPATCH(AllocateNode);
+    PY_STMT_MUTATOR_DEFAULT_DISPATCH(AllocateConstNode);
+    PY_STMT_MUTATOR_DEFAULT_DISPATCH(DeclBufferNode);
+    PY_STMT_MUTATOR_DEFAULT_DISPATCH(BufferStoreNode);
+    PY_STMT_MUTATOR_DEFAULT_DISPATCH(BufferRealizeNode);
+    PY_STMT_MUTATOR_DEFAULT_DISPATCH(AssertStmtNode);
+    PY_STMT_MUTATOR_DEFAULT_DISPATCH(SeqStmtNode);
+    PY_STMT_MUTATOR_DEFAULT_DISPATCH(EvaluateNode);
+    PY_STMT_MUTATOR_DEFAULT_DISPATCH(BlockNode);
+    PY_STMT_MUTATOR_DEFAULT_DISPATCH(BlockRealizeNode);
+    vtable.Finalize();
+    return vtable;
+  }
+};
+
+/*! \brief Managed reference to PyStmtExprMutatorNode. */
+class PyStmtExprMutator : public ObjectRef {
+ public:
+  /*!
+   * \brief Build a mutator node whose visit hooks are the supplied python callbacks.
+   *
+   * Each argument is moved, unchanged, into the node field of the same name. The two
+   * generic hooks come first, then the statement hooks, then the expression hooks; this
+   * positional order is part of the FFI contract and must not be changed.
+   * \return A reference owning the newly created PyStmtExprMutatorNode.
+   */
+  TVM_DLL static PyStmtExprMutator MakePyStmtExprMutator(
+      // generic hooks
+      ffi::Function f_visit_stmt, ffi::Function f_visit_expr,
+      // statement hooks
+      ffi::Function f_visit_let_stmt, ffi::Function f_visit_attr_stmt,
+      ffi::Function f_visit_if_then_else, ffi::Function f_visit_for,
+      ffi::Function f_visit_while, ffi::Function f_visit_allocate,
+      ffi::Function f_visit_allocate_const, ffi::Function f_visit_decl_buffer,
+      ffi::Function f_visit_buffer_store, ffi::Function f_visit_buffer_realize,
+      ffi::Function f_visit_assert_stmt, ffi::Function f_visit_seq_stmt,
+      ffi::Function f_visit_evaluate, ffi::Function f_visit_block,
+      ffi::Function f_visit_block_realize,
+      // expression hooks
+      ffi::Function f_visit_var, ffi::Function f_visit_size_var,
+      ffi::Function f_visit_buffer_load, ffi::Function f_visit_producer_load,
+      ffi::Function f_visit_let, ffi::Function f_visit_call,
+      ffi::Function f_visit_add, ffi::Function f_visit_sub,
+      ffi::Function f_visit_mul, ffi::Function f_visit_div,
+      ffi::Function f_visit_mod, ffi::Function f_visit_floor_div,
+      ffi::Function f_visit_floor_mod, ffi::Function f_visit_min,
+      ffi::Function f_visit_max, ffi::Function f_visit_eq,
+      ffi::Function f_visit_ne, ffi::Function f_visit_lt,
+      ffi::Function f_visit_le, ffi::Function f_visit_gt,
+      ffi::Function f_visit_ge, ffi::Function f_visit_and,
+      ffi::Function f_visit_or, ffi::Function f_visit_reduce,
+      ffi::Function f_visit_cast, ffi::Function f_visit_not,
+      ffi::Function f_visit_select, ffi::Function f_visit_ramp,
+      ffi::Function f_visit_broadcast, ffi::Function f_visit_shuffle,
+      ffi::Function f_visit_int_imm, ffi::Function f_visit_float_imm,
+      ffi::Function f_visit_string_imm) {
+    auto node = make_object<PyStmtExprMutatorNode>();
+    // generic hooks
+    node->f_visit_stmt = std::move(f_visit_stmt);
+    node->f_visit_expr = std::move(f_visit_expr);
+    // statement hooks
+    node->f_visit_let_stmt = std::move(f_visit_let_stmt);
+    node->f_visit_attr_stmt = std::move(f_visit_attr_stmt);
+    node->f_visit_if_then_else = std::move(f_visit_if_then_else);
+    node->f_visit_for = std::move(f_visit_for);
+    node->f_visit_while = std::move(f_visit_while);
+    node->f_visit_allocate = std::move(f_visit_allocate);
+    node->f_visit_allocate_const = std::move(f_visit_allocate_const);
+    node->f_visit_decl_buffer = std::move(f_visit_decl_buffer);
+    node->f_visit_buffer_store = std::move(f_visit_buffer_store);
+    node->f_visit_buffer_realize = std::move(f_visit_buffer_realize);
+    node->f_visit_assert_stmt = std::move(f_visit_assert_stmt);
+    node->f_visit_seq_stmt = std::move(f_visit_seq_stmt);
+    node->f_visit_evaluate = std::move(f_visit_evaluate);
+    node->f_visit_block = std::move(f_visit_block);
+    node->f_visit_block_realize = std::move(f_visit_block_realize);
+    // expression hooks
+    node->f_visit_var = std::move(f_visit_var);
+    node->f_visit_size_var = std::move(f_visit_size_var);
+    node->f_visit_buffer_load = std::move(f_visit_buffer_load);
+    node->f_visit_producer_load = std::move(f_visit_producer_load);
+    node->f_visit_let = std::move(f_visit_let);
+    node->f_visit_call = std::move(f_visit_call);
+    node->f_visit_add = std::move(f_visit_add);
+    node->f_visit_sub = std::move(f_visit_sub);
+    node->f_visit_mul = std::move(f_visit_mul);
+    node->f_visit_div = std::move(f_visit_div);
+    node->f_visit_mod = std::move(f_visit_mod);
+    node->f_visit_floor_div = std::move(f_visit_floor_div);
+    node->f_visit_floor_mod = std::move(f_visit_floor_mod);
+    node->f_visit_min = std::move(f_visit_min);
+    node->f_visit_max = std::move(f_visit_max);
+    node->f_visit_eq = std::move(f_visit_eq);
+    node->f_visit_ne = std::move(f_visit_ne);
+    node->f_visit_lt = std::move(f_visit_lt);
+    node->f_visit_le = std::move(f_visit_le);
+    node->f_visit_gt = std::move(f_visit_gt);
+    node->f_visit_ge = std::move(f_visit_ge);
+    node->f_visit_and = std::move(f_visit_and);
+    node->f_visit_or = std::move(f_visit_or);
+    node->f_visit_reduce = std::move(f_visit_reduce);
+    node->f_visit_cast = std::move(f_visit_cast);
+    node->f_visit_not = std::move(f_visit_not);
+    node->f_visit_select = std::move(f_visit_select);
+    node->f_visit_ramp = std::move(f_visit_ramp);
+    node->f_visit_broadcast = std::move(f_visit_broadcast);
+    node->f_visit_shuffle = std::move(f_visit_shuffle);
+    node->f_visit_int_imm = std::move(f_visit_int_imm);
+    node->f_visit_float_imm = std::move(f_visit_float_imm);
+    node->f_visit_string_imm = std::move(f_visit_string_imm);
+    return PyStmtExprMutator(node);
+  }
+
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(PyStmtExprMutator, ObjectRef,
+                                                    PyStmtExprMutatorNode);
+};
+
+// ================================================
+// TVM Register
+// ================================================
+
+TVM_REGISTER_NODE_TYPE(PyStmtExprVisitorNode);
+TVM_REGISTER_NODE_TYPE(PyStmtExprMutatorNode);
+
+// Factories: arguments are forwarded positionally to the Make* functions.
+TVM_FFI_REGISTER_GLOBAL("tir.MakePyStmtExprVisitor")
+    .set_body_typed(PyStmtExprVisitor::MakePyStmtExprVisitor);
+TVM_FFI_REGISTER_GLOBAL("tir.MakePyStmtExprMutator")
+    .set_body_typed(PyStmtExprMutator::MakePyStmtExprMutator);
+
+// StmtExprVisitor: side-effect-only entry points, nothing is returned to the caller.
+TVM_FFI_REGISTER_GLOBAL("tir.PyStmtExprVisitorDefaultVisitExpr")
+    .set_body_typed([](PyStmtExprVisitor self, const PrimExpr& e) {
+      self->DefaultVisitExpr(e);
+    });
+TVM_FFI_REGISTER_GLOBAL("tir.PyStmtExprVisitorDefaultVisitStmt")
+    .set_body_typed([](PyStmtExprVisitor self, const Stmt& s) {
+      self->DefaultVisitStmt(s);
+    });
+TVM_FFI_REGISTER_GLOBAL("tir.PyStmtExprVisitorVisitStmt")
+    .set_body_typed([](PyStmtExprVisitor self, const Stmt& s) {
+      self->VisitStmt(s);
+    });
+TVM_FFI_REGISTER_GLOBAL("tir.PyStmtExprVisitorVisitExpr")
+    .set_body_typed([](PyStmtExprVisitor self, const PrimExpr& e) {
+      self->VisitExpr(e);
+    });
+
+// StmtExprMutator: each entry forwards whatever the node method returns.
+TVM_FFI_REGISTER_GLOBAL("tir.PyStmtExprMutatorDefaultVisitExpr")
+    .set_body_typed([](PyStmtExprMutator self, const PrimExpr& e) {
+      return self->DefaultVisitExpr(e);
+    });
+TVM_FFI_REGISTER_GLOBAL("tir.PyStmtExprMutatorDefaultVisitStmt")
+    .set_body_typed([](PyStmtExprMutator self, const Stmt& s) {
+      return self->DefaultVisitStmt(s);
+    });
+TVM_FFI_REGISTER_GLOBAL("tir.PyStmtExprMutatorVisitExpr")
+    .set_body_typed([](PyStmtExprMutator self, const PrimExpr& e) {
+      return self->VisitExpr(e);
+    });
+TVM_FFI_REGISTER_GLOBAL("tir.PyStmtExprMutatorVisitStmt")
+    .set_body_typed([](PyStmtExprMutator self, const Stmt& s) {
+      return self->VisitStmt(s);
+    });
+
+}  // namespace tir
+}  // namespace tvm
diff --git a/tests/python/tir-transform/test_tir_functor.py 
b/tests/python/tir-transform/test_tir_functor.py
new file mode 100644
index 0000000000..e8463027db
--- /dev/null
+++ b/tests/python/tir-transform/test_tir_functor.py
@@ -0,0 +1,436 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+import tvm.testing
+from tvm import tir
+from tvm.tir import (
+    EQ,
+    LT,
+    Add,
+    Cast,
+    Evaluate,
+    FloatImm,
+    For,
+    IfThenElse,
+    IntImm,
+    Max,
+    Min,
+    Mul,
+    PyStmtExprMutator,
+    PyStmtExprVisitor,
+    StringImm,
+    Sub,
+    Var,
+)
+
+
+class ASTLog:
+    """Accumulate indented lines describing a traversal; ``str()`` joins them."""
+
+    def __init__(self) -> None:
+        self.level = 0
+        self.indent = "\t"
+        self.log = []
+
+    def push_scope(self):
+        """Indent subsequent entries one level deeper."""
+        self.level = self.level + 1
+
+    def pop_scope(self):
+        """Undo one ``push_scope``."""
+        self.level = self.level - 1
+
+    def add(self, s: str):
+        """Record ``s`` prefixed by the current indentation."""
+        prefix = self.indent * self.level
+        self.log.append(prefix + s)
+
+    def __str__(self) -> str:
+        return "\n".join(line for line in self.log)
+
+
+@tir.functor.visitor
+class ASTPrinter(PyStmtExprVisitor):
+    """Log visited Var and Add nodes in pre-order.
+
+    Entries are labelled "Stmt: <Node>" even though Var and Add are expressions;
+    ``test_basic_visitor`` asserts these exact labels.
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.log = ASTLog()
+
+    def visit_var_(self, op: Var) -> None:
+        self.log.add("Stmt: Var")
+        super().visit_var_(op)
+
+    def visit_add_(self, op: Add) -> None:
+        # Log before delegating so the parent is recorded ahead of its operands.
+        self.log.add("Stmt: Add")
+        super().visit_add_(op)
+
+
+@tir.functor.visitor
+class SimpleExprCounter(PyStmtExprVisitor):
+    """Count the Var, Add and Mul nodes reached while visiting an expression.
+
+    ``visit_add_`` and ``visit_mul_`` delegate to the default visitor, so their
+    operands are visited (and counted) as well. ``visit_var_`` does not delegate.
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.var_count = 0
+        self.add_count = 0
+        self.mul_count = 0
+
+    def visit_var_(self, op: Var):
+        self.var_count += 1
+
+    def visit_add_(self, op: Add):
+        self.add_count += 1
+        # Let the default visitor descend into both operands.
+        super().visit_add_(op)
+
+    def visit_mul_(self, op: Mul):
+        self.mul_count += 1
+        # Let the default visitor descend into both operands.
+        super().visit_mul_(op)
+
+
+@tir.functor.mutator
+class VariableReplacer(PyStmtExprMutator):
+    """Replace Vars whose name is a key of ``replacements`` with int32 constants.
+
+    ``replacements`` maps a Var name to the integer value of its replacement IntImm.
+    Vars with any other name are returned unchanged.
+    """
+
+    def __init__(self, replacements):
+        super().__init__()
+        self.replacements = replacements
+
+    def visit_var_(self, op: Var):
+        if op.name in self.replacements:
+            return IntImm("int32", self.replacements[op.name])
+        return op
+
+
+@tir.functor.mutator
+class AddToSubMutator(PyStmtExprMutator):
+    """Rewrite every ``a + b`` into ``a' - b'``, where a', b' are the mutated operands."""
+
+    def visit_add_(self, op: Add):
+        # Mutate the operands first so nested Adds are rewritten too.
+        a = self.visit_expr(op.a)
+        b = self.visit_expr(op.b)
+        return Sub(a, b)
+
+
+@tir.functor.visitor
+class SimpleStmtCounter(PyStmtExprVisitor):
+    """Count For, IfThenElse and Evaluate statements reached during a visit.
+
+    Every hook delegates to the default visitor, so nested statements are counted too.
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.for_count = 0
+        self.if_count = 0
+        self.evaluate_count = 0
+
+    def visit_for_(self, op: For):
+        self.for_count += 1
+        super().visit_for_(op)
+
+    def visit_if_then_else_(self, op: IfThenElse):
+        self.if_count += 1
+        super().visit_if_then_else_(op)
+
+    def visit_evaluate_(self, op: Evaluate):
+        self.evaluate_count += 1
+        super().visit_evaluate_(op)
+
+
+@tir.functor.mutator
+class ForLoopUnroller(PyStmtExprMutator):
+    """Placeholder loop unroller used to exercise the statement mutator path.
+
+    No unrolling is performed: ``visit_for_`` returns the default mutation, and
+    ``unroll_factor`` is stored but not read anywhere yet.
+    """
+
+    def __init__(self, unroll_factor=2):
+        super().__init__()
+        self.unroll_factor = unroll_factor
+
+    def visit_for_(self, op: For):
+        # A real implementation would unroll small loops here.
+        return super().visit_for_(op)
+
+
+@tir.functor.visitor
+class SimpleStmtExprVisitor(PyStmtExprVisitor):
+    """Visitor that reacts to both a statement kind and an expression kind.
+
+    ``stmt_count`` counts Evaluate statements only; ``expr_count`` counts Var nodes,
+    whose names are collected in ``var_names``.
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.expr_count = 0
+        self.stmt_count = 0
+        self.var_names = set()
+
+    def visit_var_(self, op: Var):
+        self.var_names.add(op.name)
+        self.expr_count += 1
+
+    def visit_evaluate_(self, op: Evaluate):
+        self.stmt_count += 1
+        # Descend into the evaluated expression explicitly instead of calling super().
+        self.visit_expr(op.value)
+
+
+@tir.functor.mutator
+class ComplexMutator(PyStmtExprMutator):
+    """Rewrite every ``a + b`` into ``a' * 2 + b'`` and count the rewrites.
+
+    a', b' are the mutated operands; ``modifications`` is incremented once per Add.
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.modifications = 0
+
+    def visit_add_(self, op: Add):
+        self.modifications += 1
+        a = self.visit_expr(op.a)
+        b = self.visit_expr(op.b)
+        return Add(Mul(a, IntImm("int32", 2)), b)
+
+
+def test_basic_visitor():
+    """ASTPrinter logs the Add first, then each of its two Var operands."""
+    lhs = Var("x", dtype="int32")
+    rhs = Var("y", dtype="int32")
+    printer = ASTPrinter()
+    printer.visit_expr(Add(lhs, rhs))
+    expected = ["Stmt: Add", "Stmt: Var", "Stmt: Var"]
+    assert str(printer.log) == "\n".join(expected)
+
+
+def test_simple_expr_counter():
+    """SimpleExprCounter sees one Add and both of its Var operands."""
+    counter = SimpleExprCounter()
+    counter.visit_expr(Add(Var("x", dtype="int32"), Var("y", dtype="int32")))
+
+    assert counter.var_count == 2  # x and y
+    assert counter.add_count == 1  # the single Add
+
+
+def test_variable_replacer():
+    """Vars named in the mapping become IntImm constants; the tree shape is kept."""
+    x = Var("x", dtype="int32")
+    y = Var("y", dtype="int32")
+    source = Add(x, Mul(y, IntImm("int32", 3)))
+
+    result = VariableReplacer({"x": 10, "y": 5}).visit_expr(source)
+
+    # Expected: Add(IntImm(10), Mul(IntImm(5), IntImm(3)))
+    assert isinstance(result, Add)
+    lhs, rhs = result.a, result.b
+    assert isinstance(lhs, IntImm)
+    assert lhs.value == 10
+    assert isinstance(rhs, Mul)
+    assert isinstance(rhs.a, IntImm)
+    assert rhs.a.value == 5
+
+
+def test_add_to_sub_mutator():
+    """AddToSubMutator turns x + y into x - y and keeps both operands."""
+    source = Add(Var("x", dtype="int32"), Var("y", dtype="int32"))
+
+    result = AddToSubMutator().visit_expr(source)
+
+    assert isinstance(result, Sub)
+    assert isinstance(result.a, Var)
+    assert isinstance(result.b, Var)
+    assert (result.a.name, result.b.name) == ("x", "y")
+
+
+def test_simple_stmt_counter():
+    """A serial For wrapping one Evaluate is counted as one of each."""
+    loop_var = Var("i", dtype="int32")
+    body = Evaluate(IntImm("int32", 0))
+    loop = For(
+        loop_var,
+        IntImm("int32", 0),
+        IntImm("int32", 10),
+        tir.ForKind.SERIAL,
+        body,
+    )
+
+    counter = SimpleStmtCounter()
+    counter.visit_stmt(loop)
+
+    assert counter.for_count == 1  # the loop itself
+    assert counter.evaluate_count == 1  # its body
+
+
+def test_if_then_else_visitor():
+    """An IfThenElse is counted once and never mistaken for a For."""
+    x = Var("x", dtype="int32")
+    branch = IfThenElse(
+        EQ(x, IntImm("int32", 0)),
+        Evaluate(IntImm("int32", 1)),
+        Evaluate(IntImm("int32", 2)),
+    )
+
+    counter = SimpleStmtCounter()
+    counter.visit_stmt(branch)
+
+    assert counter.if_count == 1
+    assert counter.for_count == 0
+
+
+def test_simple_stmt_expr_visitor():
+    """Evaluate(x + y): one statement, two Var expressions, both names recorded."""
+    x = Var("x", dtype="int32")
+    y = Var("y", dtype="int32")
+
+    visitor = SimpleStmtExprVisitor()
+    visitor.visit_stmt(Evaluate(Add(x, y)))
+
+    assert visitor.stmt_count == 1  # the Evaluate
+    assert visitor.expr_count == 2  # x and y
+    assert {"x", "y"} <= visitor.var_names
+
+
+def test_complex_mutator():
+    """ComplexMutator rewrites the Add inside an Evaluate into (x * 2) + y."""
+    x = Var("x", dtype="int32")
+    y = Var("y", dtype="int32")
+    stmt = Evaluate(Add(x, y))
+
+    mutator = ComplexMutator()
+    result = mutator.visit_stmt(stmt)
+
+    assert mutator.modifications == 1  # exactly one Add rewritten
+    assert isinstance(result, Evaluate)
+
+    modified_expr = result.value
+    assert isinstance(modified_expr, Add)
+    # The first operand is the original `x` scaled by the constant 2.
+    assert isinstance(modified_expr.a, Mul)
+    assert isinstance(modified_expr.a.b, IntImm)
+    assert modified_expr.a.b.value == 2
+
+
+def test_different_expr_types():
+    """Visiting each of several expression kinds completes without raising."""
+    x = Var("x", dtype="int32")
+
+    # One expression of each kind under test
+    exprs = [
+        IntImm("int32", 42),
+        FloatImm("float32", 3.14),
+        StringImm("hello"),
+        Cast("float32", x),
+        Min(x, IntImm("int32", 10)),
+        Max(x, IntImm("int32", 0)),
+        LT(x, IntImm("int32", 5)),
+    ]
+
+    # Exceptions are intentionally allowed to propagate: an expression kind the
+    # visitor cannot handle must fail this test rather than be silently skipped.
+    counter = SimpleExprCounter()
+    for expr in exprs:
+        counter.visit_expr(expr)
+
+
+def test_decorator_functionality():
+    """Instances of the decorated visitor and mutator classes carry the wrapper hook."""
+    # Build each instance lazily so they are created and checked one at a time.
+    factories = (SimpleExprCounter, lambda: VariableReplacer({}))
+    for make_instance in factories:
+        # ``_outer`` is the attribute expected from the decorator's wrapping.
+        assert hasattr(make_instance(), "_outer")
+
+
+def test_empty_expressions():
+    """Leaf expressions: a lone Var is counted once, a lone constant not at all."""
+    var_counter = SimpleExprCounter()
+    var_counter.visit_expr(Var("x", dtype="int32"))
+    assert var_counter.var_count == 1
+
+    # A fresh counter over a bare integer constant must not report any variables.
+    const_counter = SimpleExprCounter()
+    const_counter.visit_expr(IntImm("int32", 5))
+    assert const_counter.var_count == 0
+
+
+def test_stmt_mutator():
+    """A statement mutator applied to a loop-free Evaluate yields an Evaluate."""
+    x = Var("x", dtype="int32")
+    evaluate_stmt = Evaluate(Add(x, IntImm("int32", 1)))
+
+    mutated = ForLoopUnroller().visit_stmt(evaluate_stmt)
+
+    # No For loop is present in the input, so the result should still be an Evaluate.
+    assert isinstance(mutated, Evaluate)
+
+
+def test_nested_expressions():
+    """Counting over the nested expression (x + y) * z."""
+    x, y, z = (Var(var_name, dtype="int32") for var_name in ("x", "y", "z"))
+
+    counter = SimpleExprCounter()
+    counter.visit_expr(Mul(Add(x, y), z))
+
+    # Three variables, one Add and one Mul in the tree.
+    assert counter.var_count == 3
+    assert counter.add_count == 1
+    assert counter.mul_count == 1
+
+
+def test_simple_mutations():
+    """VariableReplacer substitutes both operands of an Add by variable name."""
+    x = Var("x", dtype="int32")
+    y = Var("y", dtype="int32")
+    add_expr = Add(x, y)
+
+    replaced = VariableReplacer({"x": 1, "y": 2}).visit_expr(add_expr)
+
+    assert isinstance(replaced, Add)
+    lhs, rhs = replaced.a, replaced.b
+    assert isinstance(lhs, IntImm)
+    assert isinstance(rhs, IntImm)
+    assert (lhs.value, rhs.value) == (1, 2)
+
+
+if __name__ == "__main__":
+    # tvm.testing.main() hands this module to pytest, which already collects every
+    # test_* function; invoking a single test by hand first would only run it twice.
+    tvm.testing.main()

Reply via email to