This is an automated email from the ASF dual-hosted git repository.
hongyij pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 2ae2aca7d1 Add Python functor support for TIR expressions and
statements (#18060)
2ae2aca7d1 is described below
commit 2ae2aca7d18b14a5db5c40cc441e2f76307ad5ef
Author: Siyuan Feng <[email protected]>
AuthorDate: Tue Jun 17 02:14:01 2025 +0800
Add Python functor support for TIR expressions and statements (#18060)
This commit introduces Python interfaces for TIR functors,
enabling Python-side customization of expression and statement visiting
and mutation operations.
Key changes:
- Add PyStmtExprVisitorNode and PyStmtExprMutatorNode classes in C++
- Implement Python bindings for all TIR expression and statement types
- Support both visitor (read-only) and mutator (transforming) patterns
- Add comprehensive dispatch mechanisms for Python callbacks
- Include FFI registrations for Python-C++ interoperability
This enables users to write custom TIR transformations and analysis
passes directly in Python while maintaining performance through
selective Python callback dispatch.
---
python/tvm/relax/expr_functor.py | 4 +-
python/tvm/runtime/support.py | 137 ++
python/tvm/tir/__init__.py | 1 +
python/tvm/tir/functor.py | 2051 ++++++++++++++++++++++++
src/tir/analysis/buffer_access_lca_detector.cc | 2 +-
src/tir/ir/py_functor.cc | 859 ++++++++++
tests/python/tir-transform/test_tir_functor.py | 436 +++++
7 files changed, 3487 insertions(+), 3 deletions(-)
diff --git a/python/tvm/relax/expr_functor.py b/python/tvm/relax/expr_functor.py
index 49d3b14505..a40b81c233 100644
--- a/python/tvm/relax/expr_functor.py
+++ b/python/tvm/relax/expr_functor.py
@@ -20,8 +20,8 @@ from typing import Callable, Optional
import tvm
from tvm.ir import Op
-from tvm.meta_schedule.utils import derived_object
from tvm.runtime import Object
+from tvm.runtime.support import derived_object
from ..ir.module import IRModule
from . import _ffi_api
@@ -31,7 +31,6 @@ from .expr import (
BindingBlock,
Call,
Constant,
- Id,
DataflowBlock,
DataflowVar,
DataTypeImm,
@@ -39,6 +38,7 @@ from .expr import (
ExternFunc,
Function,
GlobalVar,
+ Id,
If,
MatchCast,
PrimValue,
diff --git a/python/tvm/runtime/support.py b/python/tvm/runtime/support.py
index 149a66ef7b..2669459d71 100644
--- a/python/tvm/runtime/support.py
+++ b/python/tvm/runtime/support.py
@@ -18,6 +18,7 @@
"""Runtime support infra of TVM."""
import re
+from typing import TypeVar
import tvm.ffi
@@ -67,3 +68,139 @@ def _regex_match(regex_pattern: str, match_against: str) ->
bool:
"""
match = re.match(regex_pattern, match_against)
return match is not None
+
+
# Type variable so that ``derived_object`` is typed as returning the same kind
# of class that it receives.
T = TypeVar("T")


def derived_object(cls: type[T]) -> type[T]:
    """A decorator to register derived subclasses for TVM objects.

    ``cls`` must inherit from a user-facing overriding class (e.g. ``PyRunner``)
    that defines a ``_tvm_metadata`` dict with the keys:

    - ``"cls"``: the TVM object class the returned class derives from;
    - ``"fields"`` (optional): names of attributes read from the Python
      instance and passed positionally to the TVM object constructor;
    - ``"methods"`` (optional): names of methods passed to the constructor as
      callbacks, after the fields. A method is passed as ``None`` unless some
      class in ``cls``'s MRO overrides the definition inherited from its base
      (``__str__`` is always passed).

    Parameters
    ----------
    cls : type
        The derived class to be registered.

    Returns
    -------
    cls : type
        The decorated TVM object. Attribute reads that miss on the TVM object,
        and attribute writes other than ``_inst``/``key``/``handle``, are
        forwarded to an internal instance of ``cls``.

    Raises
    ------
    TypeError
        If ``cls`` inherits from a class already produced by this decorator.

    Example
    -------
    .. code-block:: python

        @register_object("meta_schedule.PyRunner")
        class _PyRunner(meta_schedule.Runner):
            def __init__(self, f_run: Callable = None):
                self.__init_handle_by_constructor__(_ffi_api.RunnerPyRunner, f_run)

        class PyRunner:
            _tvm_metadata = {
                "cls": _PyRunner,
                "methods": ["run"]
            }
            def run(self, runner_inputs):
                raise NotImplementedError

        @derived_object
        class LocalRunner(PyRunner):
            def run(self, runner_inputs):
                ...
    """

    import functools  # pylint: disable=import-outside-toplevel
    import inspect  # pylint: disable=import-outside-toplevel
    import weakref  # pylint: disable=import-outside-toplevel

    def _extract(inst: object, name: str):
        """Return a callable forwarding to ``inst.<name>`` if ``cls`` overrides it, else None."""

        def method(*args, **kwargs):
            return getattr(inst, name)(*args, **kwargs)

        for inherit_cls, base_cls in zip(cls.__mro__, cls.__mro__[1:]):
            # Only forward methods that differ from the base class definition.
            if not hasattr(base_cls, name):
                continue
            if getattr(base_cls, name) is getattr(inherit_cls, name) and name != "__str__":
                continue
            return method

        # Returning None means the C++ side calls its default function instead
        # of a Python callback (otherwise it would report the method as not
        # implemented); __str__ is exempt from the override check above.
        return None

    assert isinstance(cls.__base__, type)
    if hasattr(cls, "_type") and cls._type == "TVMDerivedObject":  # type: ignore
        raise TypeError(
            (
                f"Inheritance from a decorated object `{cls.__name__}` is not allowed. "
                f"Please inherit from `{cls.__name__}._cls`."
            )
        )
    assert hasattr(
        cls, "_tvm_metadata"
    ), "Please use the user-facing method overriding class, i.e., PyRunner."

    base = cls.__base__
    metadata = getattr(base, "_tvm_metadata")
    fields = metadata.get("fields", [])
    methods = metadata.get("methods", [])

    class TVMDerivedObject(metadata["cls"]):  # type: ignore
        """The derived object to avoid cyclic dependency."""

        _cls = cls
        _type = "TVMDerivedObject"

        def __init__(self, *args, **kwargs):
            """Constructor."""
            # `_inst` is stored on this object itself (see __setattr__).
            self._inst = cls(*args, **kwargs)

            super().__init__(
                # the constructor's parameters, read from the Python instance
                *[getattr(self._inst, name) for name in fields],
                # callbacks for the overridable methods (None -> C++ default)
                *[_extract(self._inst, name) for name in methods],
            )

            # Let the Python instance reach back to this object (used for
            # hybrid C++/Python calls); a weak reference avoids a reference cycle.
            self._inst._outer = weakref.ref(self)

        def __getattr__(self, name):
            # `_inst` lives in the instance dict. If it is missing (e.g. the
            # object was created without running __init__), resolving
            # `self._inst` below would re-enter __getattr__ without end.
            if name == "_inst":
                raise AttributeError(f"'{type(self).__name__}' object has no attribute '_inst'")
            try:
                # fall back to the wrapped Python instance's attribute first
                result = self._inst.__getattribute__(name)
            except AttributeError:
                result = super(TVMDerivedObject, self).__getattr__(name)

            if inspect.ismethod(result):

                def method(*args, **kwargs):
                    return result(*args, **kwargs)

                # Keep a strong reference to this object on the returned
                # callable so it is not destroyed while the callable is alive.
                setattr(method, "__own__", self)
                return method

            return result

        def __setattr__(self, name, value):
            # Everything except these bookkeeping names is stored on the
            # wrapped Python instance.
            if name not in ["_inst", "key", "handle"]:
                self._inst.__setattr__(name, value)
            else:
                super(TVMDerivedObject, self).__setattr__(name, value)

    functools.update_wrapper(TVMDerivedObject.__init__, cls.__init__)  # type: ignore
    TVMDerivedObject.__name__ = cls.__name__
    TVMDerivedObject.__doc__ = cls.__doc__
    TVMDerivedObject.__module__ = cls.__module__
    # Re-expose class/static methods so they remain callable on the decorated class.
    for key, value in cls.__dict__.items():
        if isinstance(value, (classmethod, staticmethod)):
            setattr(TVMDerivedObject, key, value)
    return TVMDerivedObject
diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py
index 5c4a2b91f5..120d652dd8 100644
--- a/python/tvm/tir/__init__.py
+++ b/python/tvm/tir/__init__.py
@@ -108,3 +108,4 @@ from . import analysis
from . import stmt_functor
from .build import build
from .pipeline import get_tir_pipeline, get_default_tir_pipeline
+from .functor import PyStmtExprVisitor, PyStmtExprMutator
diff --git a/python/tvm/tir/functor.py b/python/tvm/tir/functor.py
new file mode 100644
index 0000000000..06985f6645
--- /dev/null
+++ b/python/tvm/tir/functor.py
@@ -0,0 +1,2051 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name,
arguments-differ
+"""The expression and statement functor of TIR."""
+from typing import Callable
+
+import tvm
+from tvm.ir import PrimExpr
+from tvm.runtime import Object
+from tvm.runtime.support import derived_object
+
+from . import _ffi_api
+from .expr import (
+ EQ,
+ GE,
+ GT,
+ LE,
+ LT,
+ NE,
+ Add,
+ And,
+ Broadcast,
+ BufferLoad,
+ Call,
+ Cast,
+ Div,
+ FloatImm,
+ FloorDiv,
+ FloorMod,
+ IntImm,
+ Let,
+ Max,
+ Min,
+ Mod,
+ Mul,
+ Not,
+ Or,
+ ProducerLoad,
+ Ramp,
+ Reduce,
+ Select,
+ Shuffle,
+ SizeVar,
+ StringImm,
+ Sub,
+ Var,
+)
+from .stmt import (
+ Allocate,
+ AllocateConst,
+ AssertStmt,
+ AttrStmt,
+ Block,
+ BlockRealize,
+ BufferRealize,
+ BufferStore,
+ DeclBuffer,
+ Evaluate,
+ For,
+ IfThenElse,
+ LetStmt,
+ SeqStmt,
+ Stmt,
+ While,
+)
+
visitor = derived_object
"""
A decorator to wrap user-customized PyStmtExprVisitor as TVM object _PyStmtExprVisitor.

Parameters
----------
visitor_cls : PyStmtExprVisitor
    The user-customized PyStmtExprVisitor.

Returns
-------
cls : _PyStmtExprVisitor
    The decorated TVM object _PyStmtExprVisitor (StmtExprVisitor on the C++ side).

Example
-------
.. code-block:: python

    @tir.functor.visitor
    class MyStmtExprVisitor(PyStmtExprVisitor):
        # customize visit function
        def visit_call_(self, op: Call) -> None:
            # just for demo purposes
            ...

    # myvisitor is now a special visitor that visits every Call with
    # user-customized visit_call_
    myvisitor = MyStmtExprVisitor()
    # apply myvisitor to PrimExpr and Stmt
    myvisitor.visit_expr(expr)
    myvisitor.visit_stmt(stmt)
"""
+
mutator = derived_object
"""
A decorator to wrap user-customized PyStmtExprMutator as TVM object _PyStmtExprMutator.

Parameters
----------
mutator_cls : PyStmtExprMutator
    The user-customized PyStmtExprMutator.

Returns
-------
cls : _PyStmtExprMutator
    The decorated TVM object _PyStmtExprMutator (StmtExprMutator on the C++ side).

Example
-------
.. code-block:: python

    @tir.functor.mutator
    class MyStmtExprMutator(PyStmtExprMutator):
        # customize rewrite function
        def visit_add_(self, op: Add) -> PrimExpr:
            # just for demo purposes
            ...

    # mymutator is now a special mutator that rewrites every Add with
    # user-customized visit_add_
    mymutator = MyStmtExprMutator()
    # apply mymutator to PrimExpr and Stmt; keep the rewritten results
    new_expr = mymutator.visit_expr(expr)
    new_stmt = mymutator.visit_stmt(stmt)
"""
+
+
@tvm.ffi.register_object("tir.PyStmtExprVisitor")
class _PyStmtExprVisitor(Object):
    """
    An internal wrapper to interface between C++ and Python StmtExprVisitor.
    This is the TVM object that wraps PyStmtExprVisitor.

    Do not use this class directly. Use PyStmtExprVisitor instead.

    Each ``f_visit_*`` argument is the callback for the matching method of
    PyStmtExprVisitor, or None. The arguments are forwarded positionally to
    ``_ffi_api.MakePyStmtExprVisitor`` and are supplied positionally by the
    ``visitor`` decorator, so their order must stay in sync with
    ``PyStmtExprVisitor._tvm_metadata["methods"]``.

    See also: PyStmtExprVisitor, visitor
    """

    def __init__(
        self,
        f_visit_stmt: Callable = None,
        f_visit_expr: Callable = None,
        # Stmt
        f_visit_let_stmt: Callable = None,
        f_visit_attr_stmt: Callable = None,
        f_visit_if_then_else: Callable = None,
        f_visit_for: Callable = None,
        f_visit_while: Callable = None,
        f_visit_allocate: Callable = None,
        f_visit_allocate_const: Callable = None,
        f_visit_decl_buffer: Callable = None,
        f_visit_buffer_store: Callable = None,
        f_visit_buffer_realize: Callable = None,
        f_visit_assert_stmt: Callable = None,
        f_visit_seq_stmt: Callable = None,
        f_visit_evaluate: Callable = None,
        f_visit_block: Callable = None,
        f_visit_block_realize: Callable = None,
        # PrimExpr
        f_visit_var: Callable = None,
        f_visit_size_var: Callable = None,
        f_visit_buffer_load: Callable = None,
        f_visit_producer_load: Callable = None,
        f_visit_let: Callable = None,
        f_visit_call: Callable = None,
        f_visit_add: Callable = None,
        f_visit_sub: Callable = None,
        f_visit_mul: Callable = None,
        f_visit_div: Callable = None,
        f_visit_mod: Callable = None,
        f_visit_floor_div: Callable = None,
        f_visit_floor_mod: Callable = None,
        f_visit_min: Callable = None,
        f_visit_max: Callable = None,
        f_visit_eq: Callable = None,
        f_visit_ne: Callable = None,
        f_visit_lt: Callable = None,
        f_visit_le: Callable = None,
        f_visit_gt: Callable = None,
        f_visit_ge: Callable = None,
        f_visit_and: Callable = None,
        f_visit_or: Callable = None,
        f_visit_reduce: Callable = None,
        f_visit_cast: Callable = None,
        f_visit_not: Callable = None,
        f_visit_select: Callable = None,
        f_visit_ramp: Callable = None,
        f_visit_broadcast: Callable = None,
        f_visit_shuffle: Callable = None,
        f_visit_int_imm: Callable = None,
        f_visit_float_imm: Callable = None,
        f_visit_string_imm: Callable = None,
    ) -> None:
        """Constructor."""
        self.__init_handle_by_constructor__(
            _ffi_api.MakePyStmtExprVisitor,  # type: ignore
            f_visit_stmt,
            f_visit_expr,
            # Stmt
            f_visit_let_stmt,
            f_visit_attr_stmt,
            f_visit_if_then_else,
            f_visit_for,
            f_visit_while,
            f_visit_allocate,
            f_visit_allocate_const,
            f_visit_decl_buffer,
            f_visit_buffer_store,
            f_visit_buffer_realize,
            f_visit_assert_stmt,
            f_visit_seq_stmt,
            f_visit_evaluate,
            f_visit_block,
            f_visit_block_realize,
            # PrimExpr
            f_visit_var,
            f_visit_size_var,
            f_visit_buffer_load,
            f_visit_producer_load,
            f_visit_let,
            f_visit_call,
            f_visit_add,
            f_visit_sub,
            f_visit_mul,
            f_visit_div,
            f_visit_mod,
            f_visit_floor_div,
            f_visit_floor_mod,
            f_visit_min,
            f_visit_max,
            f_visit_eq,
            f_visit_ne,
            f_visit_lt,
            f_visit_le,
            f_visit_gt,
            f_visit_ge,
            f_visit_and,
            f_visit_or,
            f_visit_reduce,
            f_visit_cast,
            f_visit_not,
            f_visit_select,
            f_visit_ramp,
            f_visit_broadcast,
            f_visit_shuffle,
            f_visit_int_imm,
            f_visit_float_imm,
            f_visit_string_imm,
        )
+
+
class PyStmtExprVisitor:
    """
    A Python StmtExprVisitor to define custom visitor for both Stmt and PrimExpr.

    Users can customize any of the visit functions.

    Subclasses are meant to be decorated with ``visitor`` from this module.
    The decorator instantiates the subclass and sets ``self._outer`` to a weak
    reference to the wrapping ``_PyStmtExprVisitor`` object; every default
    method below passes ``self._outer()`` to the C++ side.

    A method in ``_tvm_metadata["methods"]`` that a subclass does not override
    is not registered as a callback, and the C++ default is used for it. The
    default implementations below delegate to the same C++ default behavior, so
    an override can call ``super().visit_xxx_(op)`` to continue the traversal.
    """

    _tvm_metadata = {
        "cls": _PyStmtExprVisitor,
        "methods": [
            "visit_stmt",
            "visit_expr",
            # Stmt
            "visit_let_stmt_",
            "visit_attr_stmt_",
            "visit_if_then_else_",
            "visit_for_",
            "visit_while_",
            "visit_allocate_",
            "visit_allocate_const_",
            "visit_decl_buffer_",
            "visit_buffer_store_",
            "visit_buffer_realize_",
            "visit_assert_stmt_",
            "visit_seq_stmt_",
            "visit_evaluate_",
            "visit_block_",
            "visit_block_realize_",
            # PrimExpr
            "visit_var_",
            "visit_size_var_",
            "visit_buffer_load_",
            "visit_producer_load_",
            "visit_let_",
            "visit_call_",
            "visit_add_",
            "visit_sub_",
            "visit_mul_",
            "visit_div_",
            "visit_mod_",
            "visit_floor_div_",
            "visit_floor_mod_",
            "visit_min_",
            "visit_max_",
            "visit_eq_",
            "visit_ne_",
            "visit_lt_",
            "visit_le_",
            "visit_gt_",
            "visit_ge_",
            "visit_and_",
            "visit_or_",
            "visit_reduce_",
            "visit_cast_",
            "visit_not_",
            "visit_select_",
            "visit_ramp_",
            "visit_broadcast_",
            "visit_shuffle_",
            "visit_int_imm_",
            "visit_float_imm_",
            "visit_string_imm_",
        ],
    }

    def visit_stmt(self, stmt: Stmt) -> None:
        """Visit a Stmt.

        Parameters
        ----------
        stmt : Stmt
            The Stmt to be visited.
        """
        _ffi_api.PyStmtExprVisitorVisitStmt(self._outer(), stmt)  # type: ignore

    def visit_expr(self, expr: PrimExpr) -> None:
        """Visit a PrimExpr.

        Parameters
        ----------
        expr : PrimExpr
            The PrimExpr to be visited.
        """
        _ffi_api.PyStmtExprVisitorVisitExpr(self._outer(), expr)  # type: ignore

    # ------------------------------------------------------------------ Stmt

    def visit_let_stmt_(self, op: LetStmt) -> None:
        """Visit LetStmt.

        Users can customize this function to overwrite
        VisitStmt_(const LetStmtNode* op) on the C++ side.

        Parameters
        ----------
        op : LetStmt
            The LetStmt to be visited.
        """
        _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op)  # type: ignore

    def visit_attr_stmt_(self, op: AttrStmt) -> None:
        """Visit AttrStmt.

        Users can customize this function to overwrite
        VisitStmt_(const AttrStmtNode* op) on the C++ side.

        Parameters
        ----------
        op : AttrStmt
            The AttrStmt to be visited.
        """
        _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op)  # type: ignore

    def visit_if_then_else_(self, op: IfThenElse) -> None:
        """Visit IfThenElse.

        Users can customize this function to overwrite
        VisitStmt_(const IfThenElseNode* op) on the C++ side.

        Parameters
        ----------
        op : IfThenElse
            The IfThenElse to be visited.
        """
        _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op)  # type: ignore

    def visit_for_(self, op: For) -> None:
        """Visit For.

        Users can customize this function to overwrite
        VisitStmt_(const ForNode* op) on the C++ side.

        Parameters
        ----------
        op : For
            The For to be visited.
        """
        _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op)  # type: ignore

    def visit_while_(self, op: While) -> None:
        """Visit While.

        Users can customize this function to overwrite
        VisitStmt_(const WhileNode* op) on the C++ side.

        Parameters
        ----------
        op : While
            The While to be visited.
        """
        _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op)  # type: ignore

    def visit_allocate_(self, op: Allocate) -> None:
        """Visit Allocate.

        Users can customize this function to overwrite
        VisitStmt_(const AllocateNode* op) on the C++ side.

        Parameters
        ----------
        op : Allocate
            The Allocate to be visited.
        """
        _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op)  # type: ignore

    def visit_allocate_const_(self, op: AllocateConst) -> None:
        """Visit AllocateConst.

        Users can customize this function to overwrite
        VisitStmt_(const AllocateConstNode* op) on the C++ side.

        Parameters
        ----------
        op : AllocateConst
            The AllocateConst to be visited.
        """
        _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op)  # type: ignore

    def visit_decl_buffer_(self, op: DeclBuffer) -> None:
        """Visit DeclBuffer.

        Users can customize this function to overwrite
        VisitStmt_(const DeclBufferNode* op) on the C++ side.

        Parameters
        ----------
        op : DeclBuffer
            The DeclBuffer to be visited.
        """
        _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op)  # type: ignore

    def visit_buffer_store_(self, op: BufferStore) -> None:
        """Visit BufferStore.

        Users can customize this function to overwrite
        VisitStmt_(const BufferStoreNode* op) on the C++ side.

        Parameters
        ----------
        op : BufferStore
            The BufferStore to be visited.
        """
        _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op)  # type: ignore

    def visit_buffer_realize_(self, op: BufferRealize) -> None:
        """Visit BufferRealize.

        Users can customize this function to overwrite
        VisitStmt_(const BufferRealizeNode* op) on the C++ side.

        Parameters
        ----------
        op : BufferRealize
            The BufferRealize to be visited.
        """
        _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op)  # type: ignore

    def visit_assert_stmt_(self, op: AssertStmt) -> None:
        """Visit AssertStmt.

        Users can customize this function to overwrite
        VisitStmt_(const AssertStmtNode* op) on the C++ side.

        Parameters
        ----------
        op : AssertStmt
            The AssertStmt to be visited.
        """
        _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op)  # type: ignore

    def visit_seq_stmt_(self, op: SeqStmt) -> None:
        """Visit SeqStmt.

        Users can customize this function to overwrite
        VisitStmt_(const SeqStmtNode* op) on the C++ side.

        Parameters
        ----------
        op : SeqStmt
            The SeqStmt to be visited.
        """
        _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op)  # type: ignore

    def visit_evaluate_(self, op: Evaluate) -> None:
        """Visit Evaluate.

        Users can customize this function to overwrite
        VisitStmt_(const EvaluateNode* op) on the C++ side.

        Parameters
        ----------
        op : Evaluate
            The Evaluate to be visited.
        """
        _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op)  # type: ignore

    def visit_block_(self, op: Block) -> None:
        """Visit Block.

        Users can customize this function to overwrite
        VisitStmt_(const BlockNode* op) on the C++ side.

        Parameters
        ----------
        op : Block
            The Block to be visited.
        """
        _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op)  # type: ignore

    def visit_block_realize_(self, op: BlockRealize) -> None:
        """Visit BlockRealize.

        Users can customize this function to overwrite
        VisitStmt_(const BlockRealizeNode* op) on the C++ side.

        Parameters
        ----------
        op : BlockRealize
            The BlockRealize to be visited.
        """
        _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op)  # type: ignore

    # -------------------------------------------------------------- PrimExpr

    def visit_var_(self, op: Var) -> None:
        """Visit Var.

        Users can customize this function to overwrite
        VisitExpr_(const VarNode* op) on the C++ side.

        Parameters
        ----------
        op : Var
            The Var to be visited.
        """
        _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op)  # type: ignore

    def visit_size_var_(self, op: SizeVar) -> None:
        """Visit SizeVar.

        Users can customize this function to overwrite
        VisitExpr_(const SizeVarNode* op) on the C++ side.

        Parameters
        ----------
        op : SizeVar
            The SizeVar to be visited.
        """
        _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op)  # type: ignore

    def visit_buffer_load_(self, op: BufferLoad) -> None:
        """Visit BufferLoad.

        Users can customize this function to overwrite
        VisitExpr_(const BufferLoadNode* op) on the C++ side.

        Parameters
        ----------
        op : BufferLoad
            The BufferLoad to be visited.
        """
        _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op)  # type: ignore

    def visit_producer_load_(self, op: ProducerLoad) -> None:
        """Visit ProducerLoad.

        Users can customize this function to overwrite
        VisitExpr_(const ProducerLoadNode* op) on the C++ side.

        Parameters
        ----------
        op : ProducerLoad
            The ProducerLoad to be visited.
        """
        _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op)  # type: ignore

    def visit_let_(self, op: Let) -> None:
        """Visit Let.

        Users can customize this function to overwrite
        VisitExpr_(const LetNode* op) on the C++ side.

        Parameters
        ----------
        op : Let
            The Let to be visited.
        """
        _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op)  # type: ignore

    def visit_call_(self, op: Call) -> None:
        """Visit Call.

        Users can customize this function to overwrite
        VisitExpr_(const CallNode* op) on the C++ side.

        Parameters
        ----------
        op : Call
            The Call to be visited.
        """
        _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op)  # type: ignore

    def visit_add_(self, op: Add) -> None:
        """Visit Add.

        Users can customize this function to overwrite
        VisitExpr_(const AddNode* op) on the C++ side.

        Parameters
        ----------
        op : Add
            The Add to be visited.
        """
        _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op)  # type: ignore

    def visit_sub_(self, op: Sub) -> None:
        """Visit Sub.

        Users can customize this function to overwrite
        VisitExpr_(const SubNode* op) on the C++ side.

        Parameters
        ----------
        op : Sub
            The Sub to be visited.
        """
        _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op)  # type: ignore

    def visit_mul_(self, op: Mul) -> None:
        """Visit Mul.

        Users can customize this function to overwrite
        VisitExpr_(const MulNode* op) on the C++ side.

        Parameters
        ----------
        op : Mul
            The Mul to be visited.
        """
        _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op)  # type: ignore

    def visit_div_(self, op: Div) -> None:
        """Visit Div.

        Users can customize this function to overwrite
        VisitExpr_(const DivNode* op) on the C++ side.

        Parameters
        ----------
        op : Div
            The Div to be visited.
        """
        _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op)  # type: ignore

    def visit_mod_(self, op: Mod) -> None:
        """Visit Mod.

        Users can customize this function to overwrite
        VisitExpr_(const ModNode* op) on the C++ side.

        Parameters
        ----------
        op : Mod
            The Mod to be visited.
        """
        _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op)  # type: ignore

    def visit_floor_div_(self, op: FloorDiv) -> None:
        """Visit FloorDiv.

        Users can customize this function to overwrite
        VisitExpr_(const FloorDivNode* op) on the C++ side.

        Parameters
        ----------
        op : FloorDiv
            The FloorDiv to be visited.
        """
        _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op)  # type: ignore

    def visit_floor_mod_(self, op: FloorMod) -> None:
        """Visit FloorMod.

        Users can customize this function to overwrite
        VisitExpr_(const FloorModNode* op) on the C++ side.

        Parameters
        ----------
        op : FloorMod
            The FloorMod to be visited.
        """
        _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op)  # type: ignore

    def visit_min_(self, op: Min) -> None:
        """Visit Min.

        Users can customize this function to overwrite
        VisitExpr_(const MinNode* op) on the C++ side.

        Parameters
        ----------
        op : Min
            The Min to be visited.
        """
        _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op)  # type: ignore

    def visit_max_(self, op: Max) -> None:
        """Visit Max.

        Users can customize this function to overwrite
        VisitExpr_(const MaxNode* op) on the C++ side.

        Parameters
        ----------
        op : Max
            The Max to be visited.
        """
        _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op)  # type: ignore

    def visit_eq_(self, op: EQ) -> None:
        """Visit EQ.

        Users can customize this function to overwrite
        VisitExpr_(const EQNode* op) on the C++ side.

        Parameters
        ----------
        op : EQ
            The EQ to be visited.
        """
        _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op)  # type: ignore

    def visit_ne_(self, op: NE) -> None:
        """Visit NE.

        Users can customize this function to overwrite
        VisitExpr_(const NENode* op) on the C++ side.

        Parameters
        ----------
        op : NE
            The NE to be visited.
        """
        _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op)  # type: ignore

    def visit_lt_(self, op: LT) -> None:
        """Visit LT.

        Users can customize this function to overwrite
        VisitExpr_(const LTNode* op) on the C++ side.

        Parameters
        ----------
        op : LT
            The LT to be visited.
        """
        _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op)  # type: ignore

    def visit_le_(self, op: LE) -> None:
        """Visit LE.

        Users can customize this function to overwrite
        VisitExpr_(const LENode* op) on the C++ side.

        Parameters
        ----------
        op : LE
            The LE to be visited.
        """
        _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op)  # type: ignore

    def visit_gt_(self, op: GT) -> None:
        """Visit GT.

        Users can customize this function to overwrite
        VisitExpr_(const GTNode* op) on the C++ side.

        Parameters
        ----------
        op : GT
            The GT to be visited.
        """
        _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op)  # type: ignore

    def visit_ge_(self, op: GE) -> None:
        """Visit GE.

        Users can customize this function to overwrite
        VisitExpr_(const GENode* op) on the C++ side.

        Parameters
        ----------
        op : GE
            The GE to be visited.
        """
        _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op)  # type: ignore

    def visit_and_(self, op: And) -> None:
        """Visit And.

        Users can customize this function to overwrite
        VisitExpr_(const AndNode* op) on the C++ side.

        Parameters
        ----------
        op : And
            The And to be visited.
        """
        _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op)  # type: ignore

    def visit_or_(self, op: Or) -> None:
        """Visit Or.

        Users can customize this function to overwrite
        VisitExpr_(const OrNode* op) on the C++ side.

        Parameters
        ----------
        op : Or
            The Or to be visited.
        """
        _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op)  # type: ignore

    def visit_reduce_(self, op: Reduce) -> None:
        """Visit Reduce.

        Users can customize this function to overwrite
        VisitExpr_(const ReduceNode* op) on the C++ side.

        Parameters
        ----------
        op : Reduce
            The Reduce to be visited.
        """
        _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op)  # type: ignore

    def visit_cast_(self, op: Cast) -> None:
        """Visit Cast.

        Users can customize this function to overwrite
        VisitExpr_(const CastNode* op) on the C++ side.

        Parameters
        ----------
        op : Cast
            The Cast to be visited.
        """
        _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op)  # type: ignore

    def visit_not_(self, op: Not) -> None:
        """Visit Not.

        Users can customize this function to overwrite
        VisitExpr_(const NotNode* op) on the C++ side.

        Parameters
        ----------
        op : Not
            The Not to be visited.
        """
        _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op)  # type: ignore

    def visit_select_(self, op: Select) -> None:
        """Visit Select.

        Users can customize this function to overwrite
        VisitExpr_(const SelectNode* op) on the C++ side.

        Parameters
        ----------
        op : Select
            The Select to be visited.
        """
        _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op)  # type: ignore

    def visit_ramp_(self, op: Ramp) -> None:
        """Visit Ramp.

        Users can customize this function to overwrite
        VisitExpr_(const RampNode* op) on the C++ side.

        Parameters
        ----------
        op : Ramp
            The Ramp to be visited.
        """
        _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op)  # type: ignore

    def visit_broadcast_(self, op: Broadcast) -> None:
        """Visit Broadcast.

        Users can customize this function to overwrite
        VisitExpr_(const BroadcastNode* op) on the C++ side.

        Parameters
        ----------
        op : Broadcast
            The Broadcast to be visited.
        """
        _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op)  # type: ignore

    def visit_shuffle_(self, op: Shuffle) -> None:
        """Visit Shuffle.

        Users can customize this function to overwrite
        VisitExpr_(const ShuffleNode* op) on the C++ side.

        Parameters
        ----------
        op : Shuffle
            The Shuffle to be visited.
        """
        _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op)  # type: ignore

    def visit_int_imm_(self, op: IntImm) -> None:
        """Visit IntImm.

        Users can customize this function to overwrite
        VisitExpr_(const IntImmNode* op) on the C++ side.

        Parameters
        ----------
        op : IntImm
            The IntImm to be visited.
        """
        _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op)  # type: ignore

    def visit_float_imm_(self, op: FloatImm) -> None:
        """Visit FloatImm.

        Users can customize this function to overwrite
        VisitExpr_(const FloatImmNode* op) on the C++ side.

        Parameters
        ----------
        op : FloatImm
            The FloatImm to be visited.
        """
        _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op)  # type: ignore

    def visit_string_imm_(self, op: StringImm) -> None:
        """Visit StringImm.

        Users can customize this function to overwrite
        VisitExpr_(const StringImmNode* op) on the C++ side.

        Parameters
        ----------
        op : StringImm
            The StringImm to be visited.
        """
        _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op)  # type: ignore
+
+
[email protected]_object("tir.PyStmtExprMutator")
class _PyStmtExprMutator(Object):
    """
    FFI-registered object that carries the Python callbacks of a customized
    StmtExprMutator over to the C++ side.

    This is the decorated result returned from stmt_expr_mutator decorator.

    WARNING: This is NOT the user facing class for method overwriting inheritance.

    See also: stmt_expr_mutator, PyStmtExprMutator
    """

    def __init__(
        self,
        f_visit_stmt: Callable = None,
        f_visit_expr: Callable = None,
        # Stmt
        f_visit_let_stmt: Callable = None,
        f_visit_attr_stmt: Callable = None,
        f_visit_if_then_else: Callable = None,
        f_visit_for: Callable = None,
        f_visit_while: Callable = None,
        f_visit_allocate: Callable = None,
        f_visit_allocate_const: Callable = None,
        f_visit_decl_buffer: Callable = None,
        f_visit_buffer_store: Callable = None,
        f_visit_buffer_realize: Callable = None,
        f_visit_assert_stmt: Callable = None,
        f_visit_seq_stmt: Callable = None,
        f_visit_evaluate: Callable = None,
        f_visit_block: Callable = None,
        f_visit_block_realize: Callable = None,
        # PrimExpr
        f_visit_var: Callable = None,
        f_visit_size_var: Callable = None,
        f_visit_buffer_load: Callable = None,
        f_visit_producer_load: Callable = None,
        f_visit_let: Callable = None,
        f_visit_call: Callable = None,
        f_visit_add: Callable = None,
        f_visit_sub: Callable = None,
        f_visit_mul: Callable = None,
        f_visit_div: Callable = None,
        f_visit_mod: Callable = None,
        f_visit_floor_div: Callable = None,
        f_visit_floor_mod: Callable = None,
        f_visit_min: Callable = None,
        f_visit_max: Callable = None,
        f_visit_eq: Callable = None,
        f_visit_ne: Callable = None,
        f_visit_lt: Callable = None,
        f_visit_le: Callable = None,
        f_visit_gt: Callable = None,
        f_visit_ge: Callable = None,
        f_visit_and: Callable = None,
        f_visit_or: Callable = None,
        f_visit_reduce: Callable = None,
        f_visit_cast: Callable = None,
        f_visit_not: Callable = None,
        f_visit_select: Callable = None,
        f_visit_ramp: Callable = None,
        f_visit_broadcast: Callable = None,
        f_visit_shuffle: Callable = None,
        f_visit_int_imm: Callable = None,
        f_visit_float_imm: Callable = None,
        f_visit_string_imm: Callable = None,
    ) -> None:
        """Build the underlying C++ mutator object from the given callbacks.

        Each argument is an optional callable for one visit hook (``None`` when the
        hook is not customized). All callbacks are forwarded positionally to
        ``_ffi_api.MakePyStmtExprMutator`` in the same order as this signature, so
        that order must stay aligned with the C++ constructor.
        """
        # Per-statement hooks, in signature order.
        stmt_hooks = (
            f_visit_let_stmt,
            f_visit_attr_stmt,
            f_visit_if_then_else,
            f_visit_for,
            f_visit_while,
            f_visit_allocate,
            f_visit_allocate_const,
            f_visit_decl_buffer,
            f_visit_buffer_store,
            f_visit_buffer_realize,
            f_visit_assert_stmt,
            f_visit_seq_stmt,
            f_visit_evaluate,
            f_visit_block,
            f_visit_block_realize,
        )
        # Per-expression hooks, in signature order.
        expr_hooks = (
            f_visit_var,
            f_visit_size_var,
            f_visit_buffer_load,
            f_visit_producer_load,
            f_visit_let,
            f_visit_call,
            f_visit_add,
            f_visit_sub,
            f_visit_mul,
            f_visit_div,
            f_visit_mod,
            f_visit_floor_div,
            f_visit_floor_mod,
            f_visit_min,
            f_visit_max,
            f_visit_eq,
            f_visit_ne,
            f_visit_lt,
            f_visit_le,
            f_visit_gt,
            f_visit_ge,
            f_visit_and,
            f_visit_or,
            f_visit_reduce,
            f_visit_cast,
            f_visit_not,
            f_visit_select,
            f_visit_ramp,
            f_visit_broadcast,
            f_visit_shuffle,
            f_visit_int_imm,
            f_visit_float_imm,
            f_visit_string_imm,
        )
        self.__init_handle_by_constructor__(
            _ffi_api.MakePyStmtExprMutator,  # type: ignore
            f_visit_stmt,
            f_visit_expr,
            *stmt_hooks,
            *expr_hooks,
        )
+
+
class PyStmtExprMutator:
    """
    A Python StmtExprMutator to define custom mutator for both Stmt and PrimExpr.

    Users can customize any of the visit functions by subclassing this class and
    overriding the ``visit_*`` methods named in ``_tvm_metadata["methods"]``.
    The base implementations forward to the C++ side through ``_ffi_api``:
    ``visit_stmt`` / ``visit_expr`` call ``PyStmtExprMutatorVisitStmt`` /
    ``PyStmtExprMutatorVisitExpr``, and every per-node ``visit_*_`` hook calls
    ``PyStmtExprMutatorDefaultVisitStmt`` or ``PyStmtExprMutatorDefaultVisitExpr``.

    NOTE(review): ``self._outer()`` is not defined in this class; it is assumed to
    be supplied when the subclass is wrapped into ``_PyStmtExprMutator`` (e.g. by a
    ``derived_object``-style decorator). Confirm against that decorator.
    """

    # "cls" is the FFI object class to instantiate; "methods" lists the hooks a
    # subclass may override. NOTE(review): this order mirrors the parameter order
    # of _PyStmtExprMutator.__init__; presumably the callbacks are passed
    # positionally, so keep the two lists in sync.
    _tvm_metadata = {
        "cls": _PyStmtExprMutator,
        "methods": [
            "visit_stmt",
            "visit_expr",
            # Stmt
            "visit_let_stmt_",
            "visit_attr_stmt_",
            "visit_if_then_else_",
            "visit_for_",
            "visit_while_",
            "visit_allocate_",
            "visit_allocate_const_",
            "visit_decl_buffer_",
            "visit_buffer_store_",
            "visit_buffer_realize_",
            "visit_assert_stmt_",
            "visit_seq_stmt_",
            "visit_evaluate_",
            "visit_block_",
            "visit_block_realize_",
            # PrimExpr
            "visit_var_",
            "visit_size_var_",
            "visit_buffer_load_",
            "visit_producer_load_",
            "visit_let_",
            "visit_call_",
            "visit_add_",
            "visit_sub_",
            "visit_mul_",
            "visit_div_",
            "visit_mod_",
            "visit_floor_div_",
            "visit_floor_mod_",
            "visit_min_",
            "visit_max_",
            "visit_eq_",
            "visit_ne_",
            "visit_lt_",
            "visit_le_",
            "visit_gt_",
            "visit_ge_",
            "visit_and_",
            "visit_or_",
            "visit_reduce_",
            "visit_cast_",
            "visit_not_",
            "visit_select_",
            "visit_ramp_",
            "visit_broadcast_",
            "visit_shuffle_",
            "visit_int_imm_",
            "visit_float_imm_",
            "visit_string_imm_",
        ],
    }

    def visit_expr(self, expr: PrimExpr) -> PrimExpr:
        """Visit PrimExpr.

        Users can customize this function to overwrite
        VisitExpr(const PrimExpr& expr) on the C++ side.

        Parameters
        ----------
        expr : PrimExpr
            The PrimExpr to be visited.

        Returns
        -------
        result : PrimExpr
            The mutated PrimExpr.
        """
        return _ffi_api.PyStmtExprMutatorVisitExpr(self._outer(), expr)  # type: ignore

    def visit_stmt(self, stmt: Stmt) -> Stmt:
        """Visit Stmt.

        Users can customize this function to overwrite
        VisitStmt(const Stmt& stmt) on the C++ side.

        Parameters
        ----------
        stmt : Stmt
            The Stmt to be visited.

        Returns
        -------
        result : Stmt
            The mutated Stmt.
        """
        return _ffi_api.PyStmtExprMutatorVisitStmt(self._outer(), stmt)  # type: ignore

    def visit_attr_stmt_(self, op: AttrStmt) -> Stmt:
        """Visit AttrStmt.

        Users can customize this function to overwrite
        VisitStmt_(const AttrStmtNode* op) on the C++ side.

        Parameters
        ----------
        op : AttrStmt
            The AttrStmt to be visited.

        Returns
        -------
        result : Stmt
            The mutated Stmt.
        """
        return _ffi_api.PyStmtExprMutatorDefaultVisitStmt(self._outer(), op)  # type: ignore

    def visit_if_then_else_(self, op: IfThenElse) -> Stmt:
        """Visit IfThenElse.

        Users can customize this function to overwrite
        VisitStmt_(const IfThenElseNode* op) on the C++ side.

        Parameters
        ----------
        op : IfThenElse
            The IfThenElse to be visited.

        Returns
        -------
        result : Stmt
            The mutated Stmt.
        """
        return _ffi_api.PyStmtExprMutatorDefaultVisitStmt(self._outer(), op)  # type: ignore

    def visit_let_stmt_(self, op: LetStmt) -> Stmt:
        """Visit LetStmt.

        Users can customize this function to overwrite
        VisitStmt_(const LetStmtNode* op) on the C++ side.

        Parameters
        ----------
        op : LetStmt
            The LetStmt to be visited.

        Returns
        -------
        result : Stmt
            The mutated Stmt.
        """
        return _ffi_api.PyStmtExprMutatorDefaultVisitStmt(self._outer(), op)  # type: ignore

    def visit_for_(self, op: For) -> Stmt:
        """Visit For.

        Users can customize this function to overwrite
        VisitStmt_(const ForNode* op) on the C++ side.

        Parameters
        ----------
        op : For
            The For to be visited.

        Returns
        -------
        result : Stmt
            The mutated Stmt.
        """
        return _ffi_api.PyStmtExprMutatorDefaultVisitStmt(self._outer(), op)  # type: ignore

    def visit_while_(self, op: While) -> Stmt:
        """Visit While.

        Users can customize this function to overwrite
        VisitStmt_(const WhileNode* op) on the C++ side.

        Parameters
        ----------
        op : While
            The While to be visited.

        Returns
        -------
        result : Stmt
            The mutated Stmt.
        """
        return _ffi_api.PyStmtExprMutatorDefaultVisitStmt(self._outer(), op)  # type: ignore

    def visit_allocate_(self, op: Allocate) -> Stmt:
        """Visit Allocate.

        Users can customize this function to overwrite
        VisitStmt_(const AllocateNode* op) on the C++ side.

        Parameters
        ----------
        op : Allocate
            The Allocate to be visited.

        Returns
        -------
        result : Stmt
            The mutated Stmt.
        """
        return _ffi_api.PyStmtExprMutatorDefaultVisitStmt(self._outer(), op)  # type: ignore

    def visit_allocate_const_(self, op: AllocateConst) -> Stmt:
        """Visit AllocateConst.

        Users can customize this function to overwrite
        VisitStmt_(const AllocateConstNode* op) on the C++ side.

        Parameters
        ----------
        op : AllocateConst
            The AllocateConst to be visited.

        Returns
        -------
        result : Stmt
            The mutated Stmt.
        """
        return _ffi_api.PyStmtExprMutatorDefaultVisitStmt(self._outer(), op)  # type: ignore

    def visit_decl_buffer_(self, op: DeclBuffer) -> Stmt:
        """Visit DeclBuffer.

        Users can customize this function to overwrite
        VisitStmt_(const DeclBufferNode* op) on the C++ side.

        Parameters
        ----------
        op : DeclBuffer
            The DeclBuffer to be visited.

        Returns
        -------
        result : Stmt
            The mutated Stmt.
        """
        return _ffi_api.PyStmtExprMutatorDefaultVisitStmt(self._outer(), op)  # type: ignore

    def visit_buffer_store_(self, op: BufferStore) -> Stmt:
        """Visit BufferStore.

        Users can customize this function to overwrite
        VisitStmt_(const BufferStoreNode* op) on the C++ side.

        Parameters
        ----------
        op : BufferStore
            The BufferStore to be visited.

        Returns
        -------
        result : Stmt
            The mutated Stmt.
        """
        return _ffi_api.PyStmtExprMutatorDefaultVisitStmt(self._outer(), op)  # type: ignore

    def visit_buffer_realize_(self, op: BufferRealize) -> Stmt:
        """Visit BufferRealize.

        Users can customize this function to overwrite
        VisitStmt_(const BufferRealizeNode* op) on the C++ side.

        Parameters
        ----------
        op : BufferRealize
            The BufferRealize to be visited.

        Returns
        -------
        result : Stmt
            The mutated Stmt.
        """
        return _ffi_api.PyStmtExprMutatorDefaultVisitStmt(self._outer(), op)  # type: ignore

    def visit_assert_stmt_(self, op: AssertStmt) -> Stmt:
        """Visit AssertStmt.

        Users can customize this function to overwrite
        VisitStmt_(const AssertStmtNode* op) on the C++ side.

        Parameters
        ----------
        op : AssertStmt
            The AssertStmt to be visited.

        Returns
        -------
        result : Stmt
            The mutated Stmt.
        """
        return _ffi_api.PyStmtExprMutatorDefaultVisitStmt(self._outer(), op)  # type: ignore

    def visit_seq_stmt_(self, op: SeqStmt) -> Stmt:
        """Visit SeqStmt.

        Users can customize this function to overwrite
        VisitStmt_(const SeqStmtNode* op) on the C++ side.

        Parameters
        ----------
        op : SeqStmt
            The SeqStmt to be visited.

        Returns
        -------
        result : Stmt
            The mutated Stmt.
        """
        return _ffi_api.PyStmtExprMutatorDefaultVisitStmt(self._outer(), op)  # type: ignore

    def visit_evaluate_(self, op: Evaluate) -> Stmt:
        """Visit Evaluate.

        Users can customize this function to overwrite
        VisitStmt_(const EvaluateNode* op) on the C++ side.

        Parameters
        ----------
        op : Evaluate
            The Evaluate to be visited.

        Returns
        -------
        result : Stmt
            The mutated Stmt.
        """
        return _ffi_api.PyStmtExprMutatorDefaultVisitStmt(self._outer(), op)  # type: ignore

    def visit_block_(self, op: Block) -> Stmt:
        """Visit Block.

        Users can customize this function to overwrite
        VisitStmt_(const BlockNode* op) on the C++ side.

        Parameters
        ----------
        op : Block
            The Block to be visited.

        Returns
        -------
        result : Stmt
            The mutated Stmt.
        """
        return _ffi_api.PyStmtExprMutatorDefaultVisitStmt(self._outer(), op)  # type: ignore

    def visit_block_realize_(self, op: BlockRealize) -> Stmt:
        """Visit BlockRealize.

        Users can customize this function to overwrite
        VisitStmt_(const BlockRealizeNode* op) on the C++ side.

        Parameters
        ----------
        op : BlockRealize
            The BlockRealize to be visited.

        Returns
        -------
        result : Stmt
            The mutated Stmt.
        """
        return _ffi_api.PyStmtExprMutatorDefaultVisitStmt(self._outer(), op)  # type: ignore

    def visit_var_(self, op: Var) -> PrimExpr:
        """Visit Var.

        Users can customize this function to overwrite
        VisitExpr_(const VarNode* op) on the C++ side.

        Parameters
        ----------
        op : Var
            The Var to be visited.

        Returns
        -------
        result : PrimExpr
            The mutated PrimExpr.
        """
        return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op)  # type: ignore

    def visit_size_var_(self, op: SizeVar) -> PrimExpr:
        """Visit SizeVar.

        Users can customize this function to overwrite
        VisitExpr_(const SizeVarNode* op) on the C++ side.

        Parameters
        ----------
        op : SizeVar
            The SizeVar to be visited.

        Returns
        -------
        result : PrimExpr
            The mutated PrimExpr.
        """
        return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op)  # type: ignore

    def visit_buffer_load_(self, op: BufferLoad) -> PrimExpr:
        """Visit BufferLoad.

        Users can customize this function to overwrite
        VisitExpr_(const BufferLoadNode* op) on the C++ side.

        Parameters
        ----------
        op : BufferLoad
            The BufferLoad to be visited.

        Returns
        -------
        result : PrimExpr
            The mutated PrimExpr.
        """
        return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op)  # type: ignore

    def visit_producer_load_(self, op: ProducerLoad) -> PrimExpr:
        """Visit ProducerLoad.

        Users can customize this function to overwrite
        VisitExpr_(const ProducerLoadNode* op) on the C++ side.

        Parameters
        ----------
        op : ProducerLoad
            The ProducerLoad to be visited.

        Returns
        -------
        result : PrimExpr
            The mutated PrimExpr.
        """
        return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op)  # type: ignore

    def visit_let_(self, op: Let) -> PrimExpr:
        """Visit Let.

        Users can customize this function to overwrite
        VisitExpr_(const LetNode* op) on the C++ side.

        Parameters
        ----------
        op : Let
            The Let to be visited.

        Returns
        -------
        result : PrimExpr
            The mutated PrimExpr.
        """
        return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op)  # type: ignore

    def visit_call_(self, op: Call) -> PrimExpr:
        """Visit Call.

        Users can customize this function to overwrite
        VisitExpr_(const CallNode* op) on the C++ side.

        Parameters
        ----------
        op : Call
            The Call to be visited.

        Returns
        -------
        result : PrimExpr
            The mutated PrimExpr.
        """
        return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op)  # type: ignore

    def visit_add_(self, op: Add) -> PrimExpr:
        """Visit Add.

        Users can customize this function to overwrite
        VisitExpr_(const AddNode* op) on the C++ side.

        Parameters
        ----------
        op : Add
            The Add to be visited.

        Returns
        -------
        result : PrimExpr
            The mutated PrimExpr.
        """
        return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op)  # type: ignore

    def visit_sub_(self, op: Sub) -> PrimExpr:
        """Visit Sub.

        Users can customize this function to overwrite
        VisitExpr_(const SubNode* op) on the C++ side.

        Parameters
        ----------
        op : Sub
            The Sub to be visited.

        Returns
        -------
        result : PrimExpr
            The mutated PrimExpr.
        """
        return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op)  # type: ignore

    def visit_mul_(self, op: Mul) -> PrimExpr:
        """Visit Mul.

        Users can customize this function to overwrite
        VisitExpr_(const MulNode* op) on the C++ side.

        Parameters
        ----------
        op : Mul
            The Mul to be visited.

        Returns
        -------
        result : PrimExpr
            The mutated PrimExpr.
        """
        return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op)  # type: ignore

    def visit_div_(self, op: Div) -> PrimExpr:
        """Visit Div.

        Users can customize this function to overwrite
        VisitExpr_(const DivNode* op) on the C++ side.

        Parameters
        ----------
        op : Div
            The Div to be visited.

        Returns
        -------
        result : PrimExpr
            The mutated PrimExpr.
        """
        return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op)  # type: ignore

    def visit_mod_(self, op: Mod) -> PrimExpr:
        """Visit Mod.

        Users can customize this function to overwrite
        VisitExpr_(const ModNode* op) on the C++ side.

        Parameters
        ----------
        op : Mod
            The Mod to be visited.

        Returns
        -------
        result : PrimExpr
            The mutated PrimExpr.
        """
        return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op)  # type: ignore

    def visit_floor_div_(self, op: FloorDiv) -> PrimExpr:
        """Visit FloorDiv.

        Users can customize this function to overwrite
        VisitExpr_(const FloorDivNode* op) on the C++ side.

        Parameters
        ----------
        op : FloorDiv
            The FloorDiv to be visited.

        Returns
        -------
        result : PrimExpr
            The mutated PrimExpr.
        """
        return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op)  # type: ignore

    def visit_floor_mod_(self, op: FloorMod) -> PrimExpr:
        """Visit FloorMod.

        Users can customize this function to overwrite
        VisitExpr_(const FloorModNode* op) on the C++ side.

        Parameters
        ----------
        op : FloorMod
            The FloorMod to be visited.

        Returns
        -------
        result : PrimExpr
            The mutated PrimExpr.
        """
        return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op)  # type: ignore

    def visit_min_(self, op: Min) -> PrimExpr:
        """Visit Min.

        Users can customize this function to overwrite
        VisitExpr_(const MinNode* op) on the C++ side.

        Parameters
        ----------
        op : Min
            The Min to be visited.

        Returns
        -------
        result : PrimExpr
            The mutated PrimExpr.
        """
        return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op)  # type: ignore

    def visit_max_(self, op: Max) -> PrimExpr:
        """Visit Max.

        Users can customize this function to overwrite
        VisitExpr_(const MaxNode* op) on the C++ side.

        Parameters
        ----------
        op : Max
            The Max to be visited.

        Returns
        -------
        result : PrimExpr
            The mutated PrimExpr.
        """
        return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op)  # type: ignore

    def visit_eq_(self, op: EQ) -> PrimExpr:
        """Visit EQ.

        Users can customize this function to overwrite
        VisitExpr_(const EQNode* op) on the C++ side.

        Parameters
        ----------
        op : EQ
            The EQ to be visited.

        Returns
        -------
        result : PrimExpr
            The mutated PrimExpr.
        """
        return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op)  # type: ignore

    def visit_ne_(self, op: NE) -> PrimExpr:
        """Visit NE.

        Users can customize this function to overwrite
        VisitExpr_(const NENode* op) on the C++ side.

        Parameters
        ----------
        op : NE
            The NE to be visited.

        Returns
        -------
        result : PrimExpr
            The mutated PrimExpr.
        """
        return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op)  # type: ignore

    def visit_lt_(self, op: LT) -> PrimExpr:
        """Visit LT.

        Users can customize this function to overwrite
        VisitExpr_(const LTNode* op) on the C++ side.

        Parameters
        ----------
        op : LT
            The LT to be visited.

        Returns
        -------
        result : PrimExpr
            The mutated PrimExpr.
        """
        return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op)  # type: ignore

    def visit_le_(self, op: LE) -> PrimExpr:
        """Visit LE.

        Users can customize this function to overwrite
        VisitExpr_(const LENode* op) on the C++ side.

        Parameters
        ----------
        op : LE
            The LE to be visited.

        Returns
        -------
        result : PrimExpr
            The mutated PrimExpr.
        """
        return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op)  # type: ignore

    def visit_gt_(self, op: GT) -> PrimExpr:
        """Visit GT.

        Users can customize this function to overwrite
        VisitExpr_(const GTNode* op) on the C++ side.

        Parameters
        ----------
        op : GT
            The GT to be visited.

        Returns
        -------
        result : PrimExpr
            The mutated PrimExpr.
        """
        return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op)  # type: ignore

    def visit_ge_(self, op: GE) -> PrimExpr:
        """Visit GE.

        Users can customize this function to overwrite
        VisitExpr_(const GENode* op) on the C++ side.

        Parameters
        ----------
        op : GE
            The GE to be visited.

        Returns
        -------
        result : PrimExpr
            The mutated PrimExpr.
        """
        return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op)  # type: ignore

    def visit_and_(self, op: And) -> PrimExpr:
        """Visit And.

        Users can customize this function to overwrite
        VisitExpr_(const AndNode* op) on the C++ side.

        Parameters
        ----------
        op : And
            The And to be visited.

        Returns
        -------
        result : PrimExpr
            The mutated PrimExpr.
        """
        return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op)  # type: ignore

    def visit_or_(self, op: Or) -> PrimExpr:
        """Visit Or.

        Users can customize this function to overwrite
        VisitExpr_(const OrNode* op) on the C++ side.

        Parameters
        ----------
        op : Or
            The Or to be visited.

        Returns
        -------
        result : PrimExpr
            The mutated PrimExpr.
        """
        return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op)  # type: ignore

    def visit_reduce_(self, op: Reduce) -> PrimExpr:
        """Visit Reduce.

        Users can customize this function to overwrite
        VisitExpr_(const ReduceNode* op) on the C++ side.

        Parameters
        ----------
        op : Reduce
            The Reduce to be visited.

        Returns
        -------
        result : PrimExpr
            The mutated PrimExpr.
        """
        return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op)  # type: ignore

    def visit_cast_(self, op: Cast) -> PrimExpr:
        """Visit Cast.

        Users can customize this function to overwrite
        VisitExpr_(const CastNode* op) on the C++ side.

        Parameters
        ----------
        op : Cast
            The Cast to be visited.

        Returns
        -------
        result : PrimExpr
            The mutated PrimExpr.
        """
        return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op)  # type: ignore

    def visit_not_(self, op: Not) -> PrimExpr:
        """Visit Not.

        Users can customize this function to overwrite
        VisitExpr_(const NotNode* op) on the C++ side.

        Parameters
        ----------
        op : Not
            The Not to be visited.

        Returns
        -------
        result : PrimExpr
            The mutated PrimExpr.
        """
        return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op)  # type: ignore

    def visit_select_(self, op: Select) -> PrimExpr:
        """Visit Select.

        Users can customize this function to overwrite
        VisitExpr_(const SelectNode* op) on the C++ side.

        Parameters
        ----------
        op : Select
            The Select to be visited.

        Returns
        -------
        result : PrimExpr
            The mutated PrimExpr.
        """
        return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op)  # type: ignore

    def visit_ramp_(self, op: Ramp) -> PrimExpr:
        """Visit Ramp.

        Users can customize this function to overwrite
        VisitExpr_(const RampNode* op) on the C++ side.

        Parameters
        ----------
        op : Ramp
            The Ramp to be visited.

        Returns
        -------
        result : PrimExpr
            The mutated PrimExpr.
        """
        return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op)  # type: ignore

    def visit_broadcast_(self, op: Broadcast) -> PrimExpr:
        """Visit Broadcast.

        Users can customize this function to overwrite
        VisitExpr_(const BroadcastNode* op) on the C++ side.

        Parameters
        ----------
        op : Broadcast
            The Broadcast to be visited.

        Returns
        -------
        result : PrimExpr
            The mutated PrimExpr.
        """
        return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op)  # type: ignore

    def visit_shuffle_(self, op: Shuffle) -> PrimExpr:
        """Visit Shuffle.

        Users can customize this function to overwrite
        VisitExpr_(const ShuffleNode* op) on the C++ side.

        Parameters
        ----------
        op : Shuffle
            The Shuffle to be visited.

        Returns
        -------
        result : PrimExpr
            The mutated PrimExpr.
        """
        return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op)  # type: ignore

    def visit_int_imm_(self, op: IntImm) -> PrimExpr:
        """Visit IntImm.

        Users can customize this function to overwrite
        VisitExpr_(const IntImmNode* op) on the C++ side.

        Parameters
        ----------
        op : IntImm
            The IntImm to be visited.

        Returns
        -------
        result : PrimExpr
            The mutated PrimExpr.
        """
        return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op)  # type: ignore

    def visit_float_imm_(self, op: FloatImm) -> PrimExpr:
        """Visit FloatImm.

        Users can customize this function to overwrite
        VisitExpr_(const FloatImmNode* op) on the C++ side.

        Parameters
        ----------
        op : FloatImm
            The FloatImm to be visited.

        Returns
        -------
        result : PrimExpr
            The mutated PrimExpr.
        """
        return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op)  # type: ignore

    def visit_string_imm_(self, op: StringImm) -> PrimExpr:
        """Visit StringImm.

        Users can customize this function to overwrite
        VisitExpr_(const StringImmNode* op) on the C++ side.

        Parameters
        ----------
        op : StringImm
            The StringImm to be visited.

        Returns
        -------
        result : PrimExpr
            The mutated PrimExpr.
        """
        return _ffi_api.PyStmtExprMutatorDefaultVisitExpr(self._outer(), op)  # type: ignore
diff --git a/src/tir/analysis/buffer_access_lca_detector.cc
b/src/tir/analysis/buffer_access_lca_detector.cc
index aca4c99e11..7ac84ce894 100644
--- a/src/tir/analysis/buffer_access_lca_detector.cc
+++ b/src/tir/analysis/buffer_access_lca_detector.cc
@@ -147,7 +147,7 @@ class LCADetector : public StmtExprVisitor {
auto do_collect_itervar_scope = [this](const IterVar& itervar,
const PrimExpr& binding) -> const
ScopeInfo* {
const ScopeInfo* highest_scope = nullptr;
- PostOrderVisit(binding, [this, &itervar, &highest_scope](const
ObjectRef& obj) {
+ PostOrderVisit(binding, [this, &highest_scope](const ObjectRef& obj) {
if (const VarNode* loop_var = obj.as<VarNode>()) {
auto it = loop_scope_map_.find(loop_var);
if (it == loop_scope_map_.end()) {
diff --git a/src/tir/ir/py_functor.cc b/src/tir/ir/py_functor.cc
new file mode 100644
index 0000000000..6152c99eaf
--- /dev/null
+++ b/src/tir/ir/py_functor.cc
@@ -0,0 +1,859 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/tir/ir/py_functor.cc
+ * \brief The python interface of ExprVisitor/ExprMutator, StmtVisitor/StmtMutator,
+ *        StmtExprVisitor/StmtExprMutator.
+ */
+
+#include <tvm/tir/expr_functor.h>
+#include <tvm/tir/stmt_functor.h>
+
+namespace tvm {
+namespace tir {
+
+// ================================================
+// Helper Macros
+// ================================================
+/*!
+ * \brief Define a visitor `VisitExpr_` override for node type OP that calls the
+ *        Python callback PY_FUNC when it is set, else StmtExprVisitor::VisitExpr_.
+ */
+#define PY_EXPR_VISITOR_DISPATCH(OP, PY_FUNC) \
+  void VisitExpr_(const OP* op) override {    \
+    if (PY_FUNC != nullptr) {                 \
+      PY_FUNC(op);                            \
+    } else {                                  \
+      StmtExprVisitor::VisitExpr_(op);        \
+    }                                         \
+  }
+
+/*!
+ * \brief Register, in a local `vtable`, the base StmtExprVisitor::VisitExpr_
+ *        handler for node type OP (used to bypass Python callbacks).
+ * \note Named IR_* rather than PY_* for historical reasons; kept for compatibility.
+ */
+#define IR_EXPR_VISITOR_DEFAULT_DISPATCH(OP)                                  \
+  vtable.template set_dispatch<OP>([](const ObjectRef& n, TSelf* self) {     \
+    self->StmtExprVisitor::VisitExpr_(static_cast<const OP*>(n.get()));      \
+  });
+
+/*!
+ * \brief Define a visitor `VisitStmt_` override for node type OP that calls the
+ *        Python callback PY_FUNC when it is set, else StmtExprVisitor::VisitStmt_.
+ */
+#define PY_STMT_VISITOR_DISPATCH(OP, PY_FUNC) \
+  void VisitStmt_(const OP* op) override {    \
+    if (PY_FUNC != nullptr) {                 \
+      PY_FUNC(op);                            \
+    } else {                                  \
+      StmtExprVisitor::VisitStmt_(op);        \
+    }                                         \
+  }
+
+/*!
+ * \brief Register, in a local `vtable`, the base StmtExprVisitor::VisitStmt_
+ *        handler for node type OP (used to bypass Python callbacks).
+ */
+#define PY_STMT_VISITOR_DEFAULT_DISPATCH(OP)                                  \
+  vtable.template set_dispatch<OP>([](const ObjectRef& n, TSelf* self) {     \
+    self->StmtExprVisitor::VisitStmt_(static_cast<const OP*>(n.get()));      \
+  });
+
+/*!
+ * \brief Define a mutator `VisitExpr_` override for node type OP that returns the
+ *        Python callback's result (cast to PrimExpr) when PY_FUNC is set, else
+ *        the result of StmtExprMutator::VisitExpr_.
+ */
+#define PY_EXPR_MUTATOR_DISPATCH(OP, PY_FUNC)   \
+  PrimExpr VisitExpr_(const OP* op) override {  \
+    if (PY_FUNC != nullptr) {                   \
+      return PY_FUNC(op).cast<PrimExpr>();      \
+    } else {                                    \
+      return StmtExprMutator::VisitExpr_(op);   \
+    }                                           \
+  }
+
+/*!
+ * \brief Register, in a local `vtable`, the base StmtExprMutator::VisitExpr_
+ *        handler for node type OP and return its mutated result.
+ */
+#define PY_EXPR_MUTATOR_DEFAULT_DISPATCH(OP)                                         \
+  vtable.template set_dispatch<OP>([](const ObjectRef& n, TSelf* self) {            \
+    return self->StmtExprMutator::VisitExpr_(static_cast<const OP*>(n.get()));      \
+  });
+
+/*!
+ * \brief Define a mutator `VisitStmt_` override for node type OP that returns the
+ *        Python callback's result (cast to Stmt) when PY_FUNC is set, else the
+ *        result of StmtExprMutator::VisitStmt_.
+ */
+#define PY_STMT_MUTATOR_DISPATCH(OP, PY_FUNC)  \
+  Stmt VisitStmt_(const OP* op) override {     \
+    if (PY_FUNC != nullptr) {                  \
+      return PY_FUNC(op).cast<Stmt>();         \
+    } else {                                   \
+      return StmtExprMutator::VisitStmt_(op);  \
+    }                                          \
+  }
+
+/*!
+ * \brief Register, in a local `vtable`, the base StmtExprMutator::VisitStmt_
+ *        handler for node type OP and return its mutated result.
+ */
+#define PY_STMT_MUTATOR_DEFAULT_DISPATCH(OP)                                         \
+  vtable.template set_dispatch<OP>([](const ObjectRef& n, TSelf* self) {            \
+    return self->StmtExprMutator::VisitStmt_(static_cast<const OP*>(n.get()));      \
+  });
+
+/*!
+ * \brief The python interface of StmtExprVisitor.
+ *
+ * Every per-node `f_visit_*` member is an optional Python callback. The private
+ * `VisitExpr_`/`VisitStmt_` overrides below forward a node to its callback when
+ * that callback is non-null and fall back to StmtExprVisitor otherwise.
+ */
+class PyStmtExprVisitorNode : public Object, public StmtExprVisitor {
+ private:
+  using TSelf = PyStmtExprVisitorNode;
+  using FExprType = tvm::NodeFunctor<void(const ObjectRef& n, TSelf* self)>;
+  using FStmtType = tvm::NodeFunctor<void(const ObjectRef& n, TSelf* self)>;
+
+ public:
+  // NOTE(review): f_visit_expr and f_visit_stmt are not read by any override in
+  // this class; confirm they are dispatched elsewhere (e.g. by the FFI layer).
+
+  // Expression functions
+  /*! \brief The packed function to the `VisitExpr(const PrimExpr& expr)` function. */
+  ffi::Function f_visit_expr{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const VarNode* op)` function. */
+  ffi::Function f_visit_var{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const SizeVarNode* op)` function. */
+  ffi::Function f_visit_size_var{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const BufferLoadNode* op)` function. */
+  ffi::Function f_visit_buffer_load{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const ProducerLoadNode* op)` function. */
+  ffi::Function f_visit_producer_load{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const LetNode* op)` function. */
+  ffi::Function f_visit_let{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const CallNode* op)` function. */
+  ffi::Function f_visit_call{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const AddNode* op)` function. */
+  ffi::Function f_visit_add{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const SubNode* op)` function. */
+  ffi::Function f_visit_sub{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const MulNode* op)` function. */
+  ffi::Function f_visit_mul{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const DivNode* op)` function. */
+  ffi::Function f_visit_div{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const ModNode* op)` function. */
+  ffi::Function f_visit_mod{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const FloorDivNode* op)` function. */
+  ffi::Function f_visit_floor_div{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const FloorModNode* op)` function. */
+  ffi::Function f_visit_floor_mod{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const MinNode* op)` function. */
+  ffi::Function f_visit_min{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const MaxNode* op)` function. */
+  ffi::Function f_visit_max{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const EQNode* op)` function. */
+  ffi::Function f_visit_eq{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const NENode* op)` function. */
+  ffi::Function f_visit_ne{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const LTNode* op)` function. */
+  ffi::Function f_visit_lt{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const LENode* op)` function. */
+  ffi::Function f_visit_le{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const GTNode* op)` function. */
+  ffi::Function f_visit_gt{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const GENode* op)` function. */
+  ffi::Function f_visit_ge{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const AndNode* op)` function. */
+  ffi::Function f_visit_and{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const OrNode* op)` function. */
+  ffi::Function f_visit_or{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const ReduceNode* op)` function. */
+  ffi::Function f_visit_reduce{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const CastNode* op)` function. */
+  ffi::Function f_visit_cast{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const NotNode* op)` function. */
+  ffi::Function f_visit_not{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const SelectNode* op)` function. */
+  ffi::Function f_visit_select{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const RampNode* op)` function. */
+  ffi::Function f_visit_ramp{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const BroadcastNode* op)` function. */
+  ffi::Function f_visit_broadcast{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const ShuffleNode* op)` function. */
+  ffi::Function f_visit_shuffle{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const IntImmNode* op)` function. */
+  ffi::Function f_visit_int_imm{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const FloatImmNode* op)` function. */
+  ffi::Function f_visit_float_imm{nullptr};
+  /*! \brief The packed function to the `VisitExpr_(const StringImmNode* op)` function. */
+  ffi::Function f_visit_string_imm{nullptr};
+
+  // Statement functions
+  /*! \brief The packed function to the `VisitStmt(const Stmt& stmt)` function. */
+  ffi::Function f_visit_stmt{nullptr};
+  /*! \brief The packed function to the `VisitStmt_(const AttrStmtNode* op)` function. */
+  ffi::Function f_visit_attr_stmt{nullptr};
+  /*! \brief The packed function to the `VisitStmt_(const IfThenElseNode* op)` function. */
+  ffi::Function f_visit_if_then_else{nullptr};  // NOLINT(readability/braces)
+  /*! \brief The packed function to the `VisitStmt_(const LetStmtNode* op)` function. */
+  ffi::Function f_visit_let_stmt{nullptr};
+  /*! \brief The packed function to the `VisitStmt_(const ForNode* op)` function. */
+  ffi::Function f_visit_for{nullptr};
+  /*! \brief The packed function to the `VisitStmt_(const WhileNode* op)` function. */
+  ffi::Function f_visit_while{nullptr};
+  /*! \brief The packed function to the `VisitStmt_(const AllocateNode* op)` function. */
+  ffi::Function f_visit_allocate{nullptr};
+  /*! \brief The packed function to the `VisitStmt_(const AllocateConstNode* op)` function. */
+  ffi::Function f_visit_allocate_const{nullptr};
+  /*! \brief The packed function to the `VisitStmt_(const DeclBufferNode* op)` function. */
+  ffi::Function f_visit_decl_buffer{nullptr};
+  /*! \brief The packed function to the `VisitStmt_(const BufferStoreNode* op)` function. */
+  ffi::Function f_visit_buffer_store{nullptr};
+  /*! \brief The packed function to the `VisitStmt_(const BufferRealizeNode* op)` function. */
+  ffi::Function f_visit_buffer_realize{nullptr};
+  /*! \brief The packed function to the `VisitStmt_(const AssertStmtNode* op)` function. */
+  ffi::Function f_visit_assert_stmt{nullptr};
+  /*! \brief The packed function to the `VisitStmt_(const SeqStmtNode* op)` function. */
+  ffi::Function f_visit_seq_stmt{nullptr};
+  /*! \brief The packed function to the `VisitStmt_(const EvaluateNode* op)` function. */
+  ffi::Function f_visit_evaluate{nullptr};
+  /*! \brief The packed function to the `VisitStmt_(const BlockNode* op)` function. */
+  ffi::Function f_visit_block{nullptr};
+  /*! \brief The packed function to the `VisitStmt_(const BlockRealizeNode* op)` function. */
+  ffi::Function f_visit_block_realize{nullptr};
+
+  using StmtExprVisitor::VisitExpr;
+  using StmtExprVisitor::VisitStmt;
+
+  /*!
+   * \brief Visit `expr` with the base StmtExprVisitor::VisitExpr_ overload for its
+   *        node type, bypassing any Python callback.
+   */
+  void DefaultVisitExpr(const PrimExpr& expr) {
+    // Built once on first use; maps each node type to the base-class handler.
+    static FExprType vtable = InitExprVTable();
+    vtable(expr, this);
+  }
+
+  /*!
+   * \brief Visit `stmt` with the base StmtExprVisitor::VisitStmt_ overload for its
+   *        node type, bypassing any Python callback.
+   */
+  void DefaultVisitStmt(const Stmt& stmt) {
+    static FStmtType vtable = InitStmtVTable();
+    vtable(stmt, this);
+  }
+
+  void VisitAttrs(AttrVisitor* v) {}
+  static constexpr const char* _type_key = "tir.PyStmtExprVisitor";
+  TVM_DECLARE_BASE_OBJECT_INFO(PyStmtExprVisitorNode, Object);
+
+ private:
+  // Statement functions
+  PY_STMT_VISITOR_DISPATCH(LetStmtNode, f_visit_let_stmt);
+  PY_STMT_VISITOR_DISPATCH(AttrStmtNode, f_visit_attr_stmt);
+  PY_STMT_VISITOR_DISPATCH(IfThenElseNode, f_visit_if_then_else);
+  PY_STMT_VISITOR_DISPATCH(ForNode, f_visit_for);
+  PY_STMT_VISITOR_DISPATCH(WhileNode, f_visit_while);
+  PY_STMT_VISITOR_DISPATCH(AllocateNode, f_visit_allocate);
+  PY_STMT_VISITOR_DISPATCH(AllocateConstNode, f_visit_allocate_const);
+  PY_STMT_VISITOR_DISPATCH(DeclBufferNode, f_visit_decl_buffer);
+  PY_STMT_VISITOR_DISPATCH(BufferStoreNode, f_visit_buffer_store);
+  PY_STMT_VISITOR_DISPATCH(BufferRealizeNode, f_visit_buffer_realize);
+  PY_STMT_VISITOR_DISPATCH(AssertStmtNode, f_visit_assert_stmt);
+  PY_STMT_VISITOR_DISPATCH(SeqStmtNode, f_visit_seq_stmt);
+  PY_STMT_VISITOR_DISPATCH(EvaluateNode, f_visit_evaluate);
+  PY_STMT_VISITOR_DISPATCH(BlockNode, f_visit_block);
+  PY_STMT_VISITOR_DISPATCH(BlockRealizeNode, f_visit_block_realize);
+  // Expression functions
+  PY_EXPR_VISITOR_DISPATCH(VarNode, f_visit_var);
+  PY_EXPR_VISITOR_DISPATCH(SizeVarNode, f_visit_size_var);
+  PY_EXPR_VISITOR_DISPATCH(BufferLoadNode, f_visit_buffer_load);
+  PY_EXPR_VISITOR_DISPATCH(ProducerLoadNode, f_visit_producer_load);
+  PY_EXPR_VISITOR_DISPATCH(LetNode, f_visit_let);
+  PY_EXPR_VISITOR_DISPATCH(CallNode, f_visit_call);
+  PY_EXPR_VISITOR_DISPATCH(AddNode, f_visit_add);
+  PY_EXPR_VISITOR_DISPATCH(SubNode, f_visit_sub);
+  PY_EXPR_VISITOR_DISPATCH(MulNode, f_visit_mul);
+  PY_EXPR_VISITOR_DISPATCH(DivNode, f_visit_div);
+  PY_EXPR_VISITOR_DISPATCH(ModNode, f_visit_mod);
+  PY_EXPR_VISITOR_DISPATCH(FloorDivNode, f_visit_floor_div);
+  PY_EXPR_VISITOR_DISPATCH(FloorModNode, f_visit_floor_mod);
+  PY_EXPR_VISITOR_DISPATCH(MinNode, f_visit_min);
+  PY_EXPR_VISITOR_DISPATCH(MaxNode, f_visit_max);
+  PY_EXPR_VISITOR_DISPATCH(EQNode, f_visit_eq);
+  PY_EXPR_VISITOR_DISPATCH(NENode, f_visit_ne);
+  PY_EXPR_VISITOR_DISPATCH(LTNode, f_visit_lt);
+  PY_EXPR_VISITOR_DISPATCH(LENode, f_visit_le);
+  PY_EXPR_VISITOR_DISPATCH(GTNode, f_visit_gt);
+  PY_EXPR_VISITOR_DISPATCH(GENode, f_visit_ge);
+  PY_EXPR_VISITOR_DISPATCH(AndNode, f_visit_and);
+  PY_EXPR_VISITOR_DISPATCH(OrNode, f_visit_or);
+  PY_EXPR_VISITOR_DISPATCH(ReduceNode, f_visit_reduce);
+  PY_EXPR_VISITOR_DISPATCH(CastNode, f_visit_cast);
+  PY_EXPR_VISITOR_DISPATCH(NotNode, f_visit_not);
+  PY_EXPR_VISITOR_DISPATCH(SelectNode, f_visit_select);
+  PY_EXPR_VISITOR_DISPATCH(RampNode, f_visit_ramp);
+  PY_EXPR_VISITOR_DISPATCH(BroadcastNode, f_visit_broadcast);
+  PY_EXPR_VISITOR_DISPATCH(ShuffleNode, f_visit_shuffle);
+  PY_EXPR_VISITOR_DISPATCH(IntImmNode, f_visit_int_imm);
+  PY_EXPR_VISITOR_DISPATCH(FloatImmNode, f_visit_float_imm);
+  PY_EXPR_VISITOR_DISPATCH(StringImmNode, f_visit_string_imm);
+
+ private:
+  /*! \brief Build the table mapping each expression node type to StmtExprVisitor::VisitExpr_. */
+  static FExprType InitExprVTable() {
+    FExprType vtable;
+    // Set dispatch
+    IR_EXPR_VISITOR_DEFAULT_DISPATCH(VarNode);
+    IR_EXPR_VISITOR_DEFAULT_DISPATCH(SizeVarNode);
+    IR_EXPR_VISITOR_DEFAULT_DISPATCH(BufferLoadNode);
+    IR_EXPR_VISITOR_DEFAULT_DISPATCH(ProducerLoadNode);
+    IR_EXPR_VISITOR_DEFAULT_DISPATCH(LetNode);
+    IR_EXPR_VISITOR_DEFAULT_DISPATCH(CallNode);
+    IR_EXPR_VISITOR_DEFAULT_DISPATCH(AddNode);
+    IR_EXPR_VISITOR_DEFAULT_DISPATCH(SubNode);
+    IR_EXPR_VISITOR_DEFAULT_DISPATCH(MulNode);
+    IR_EXPR_VISITOR_DEFAULT_DISPATCH(DivNode);
+    IR_EXPR_VISITOR_DEFAULT_DISPATCH(ModNode);
+    IR_EXPR_VISITOR_DEFAULT_DISPATCH(FloorDivNode);
+    IR_EXPR_VISITOR_DEFAULT_DISPATCH(FloorModNode);
+    IR_EXPR_VISITOR_DEFAULT_DISPATCH(MinNode);
+    IR_EXPR_VISITOR_DEFAULT_DISPATCH(MaxNode);
+    IR_EXPR_VISITOR_DEFAULT_DISPATCH(EQNode);
+    IR_EXPR_VISITOR_DEFAULT_DISPATCH(NENode);
+    IR_EXPR_VISITOR_DEFAULT_DISPATCH(LTNode);
+    IR_EXPR_VISITOR_DEFAULT_DISPATCH(LENode);
+    IR_EXPR_VISITOR_DEFAULT_DISPATCH(GTNode);
+    IR_EXPR_VISITOR_DEFAULT_DISPATCH(GENode);
+    IR_EXPR_VISITOR_DEFAULT_DISPATCH(AndNode);
+    IR_EXPR_VISITOR_DEFAULT_DISPATCH(OrNode);
+    IR_EXPR_VISITOR_DEFAULT_DISPATCH(ReduceNode);
+    IR_EXPR_VISITOR_DEFAULT_DISPATCH(CastNode);
+    IR_EXPR_VISITOR_DEFAULT_DISPATCH(NotNode);
+    IR_EXPR_VISITOR_DEFAULT_DISPATCH(SelectNode);
+    IR_EXPR_VISITOR_DEFAULT_DISPATCH(RampNode);
+    IR_EXPR_VISITOR_DEFAULT_DISPATCH(ShuffleNode);
+    IR_EXPR_VISITOR_DEFAULT_DISPATCH(BroadcastNode);
+    IR_EXPR_VISITOR_DEFAULT_DISPATCH(IntImmNode);
+    IR_EXPR_VISITOR_DEFAULT_DISPATCH(FloatImmNode);
+    IR_EXPR_VISITOR_DEFAULT_DISPATCH(StringImmNode);
+    vtable.Finalize();
+    return vtable;
+  }
+
+  /*! \brief Build the table mapping each statement node type to StmtExprVisitor::VisitStmt_. */
+  static FStmtType InitStmtVTable() {
+    FStmtType vtable;
+    PY_STMT_VISITOR_DEFAULT_DISPATCH(LetStmtNode);
+    PY_STMT_VISITOR_DEFAULT_DISPATCH(AttrStmtNode);
+    PY_STMT_VISITOR_DEFAULT_DISPATCH(IfThenElseNode);
+    PY_STMT_VISITOR_DEFAULT_DISPATCH(ForNode);
+    PY_STMT_VISITOR_DEFAULT_DISPATCH(WhileNode);
+    PY_STMT_VISITOR_DEFAULT_DISPATCH(AllocateNode);
+    PY_STMT_VISITOR_DEFAULT_DISPATCH(AllocateConstNode);
+    PY_STMT_VISITOR_DEFAULT_DISPATCH(DeclBufferNode);
+    PY_STMT_VISITOR_DEFAULT_DISPATCH(BufferStoreNode);
+    PY_STMT_VISITOR_DEFAULT_DISPATCH(BufferRealizeNode);
+    PY_STMT_VISITOR_DEFAULT_DISPATCH(AssertStmtNode);
+    PY_STMT_VISITOR_DEFAULT_DISPATCH(SeqStmtNode);
+    PY_STMT_VISITOR_DEFAULT_DISPATCH(EvaluateNode);
+    PY_STMT_VISITOR_DEFAULT_DISPATCH(BlockNode);
+    PY_STMT_VISITOR_DEFAULT_DISPATCH(BlockRealizeNode);
+    vtable.Finalize();
+    return vtable;
+  }
+};
+
+/*!
+ * \brief Managed reference to PyStmtExprVisitorNode.
+ * \sa PyStmtExprVisitorNode
+ */
+class PyStmtExprVisitor : public ObjectRef {
+ public:
+  /*!
+   * \brief Create a PyStmtExprVisitor whose visit methods can be customized from Python.
+   *
+   * Each argument is moved into the same-named field of a freshly allocated
+   * PyStmtExprVisitorNode. For the per-node `f_visit_*` callbacks, passing a null
+   * function keeps that node type on the default StmtExprVisitor path.
+   *
+   * \return The PyStmtExprVisitor created.
+   */
+  TVM_DLL static PyStmtExprVisitor MakePyStmtExprVisitor(ffi::Function f_visit_stmt,            //
+                                                         ffi::Function f_visit_expr,            //
+                                                         ffi::Function f_visit_let_stmt,        //
+                                                         ffi::Function f_visit_attr_stmt,       //
+                                                         ffi::Function f_visit_if_then_else,    //
+                                                         ffi::Function f_visit_for,             //
+                                                         ffi::Function f_visit_while,           //
+                                                         ffi::Function f_visit_allocate,        //
+                                                         ffi::Function f_visit_allocate_const,  //
+                                                         ffi::Function f_visit_decl_buffer,     //
+                                                         ffi::Function f_visit_buffer_store,    //
+                                                         ffi::Function f_visit_buffer_realize,  //
+                                                         ffi::Function f_visit_assert_stmt,     //
+                                                         ffi::Function f_visit_seq_stmt,        //
+                                                         ffi::Function f_visit_evaluate,        //
+                                                         ffi::Function f_visit_block,           //
+                                                         ffi::Function f_visit_block_realize,   //
+                                                         ffi::Function f_visit_var,             //
+                                                         ffi::Function f_visit_size_var,        //
+                                                         ffi::Function f_visit_buffer_load,     //
+                                                         ffi::Function f_visit_producer_load,   //
+                                                         ffi::Function f_visit_let,             //
+                                                         ffi::Function f_visit_call,            //
+                                                         ffi::Function f_visit_add,             //
+                                                         ffi::Function f_visit_sub,             //
+                                                         ffi::Function f_visit_mul,             //
+                                                         ffi::Function f_visit_div,             //
+                                                         ffi::Function f_visit_mod,             //
+                                                         ffi::Function f_visit_floor_div,       //
+                                                         ffi::Function f_visit_floor_mod,       //
+                                                         ffi::Function f_visit_min,             //
+                                                         ffi::Function f_visit_max,             //
+                                                         ffi::Function f_visit_eq,              //
+                                                         ffi::Function f_visit_ne,              //
+                                                         ffi::Function f_visit_lt,              //
+                                                         ffi::Function f_visit_le,              //
+                                                         ffi::Function f_visit_gt,              //
+                                                         ffi::Function f_visit_ge,              //
+                                                         ffi::Function f_visit_and,             //
+                                                         ffi::Function f_visit_or,              //
+                                                         ffi::Function f_visit_reduce,          //
+                                                         ffi::Function f_visit_cast,            //
+                                                         ffi::Function f_visit_not,             //
+                                                         ffi::Function f_visit_select,          //
+                                                         ffi::Function f_visit_ramp,            //
+                                                         ffi::Function f_visit_broadcast,       //
+                                                         ffi::Function f_visit_shuffle,         //
+                                                         ffi::Function f_visit_int_imm,         //
+                                                         ffi::Function f_visit_float_imm,       //
+                                                         ffi::Function f_visit_string_imm) {
+    auto node = make_object<PyStmtExprVisitorNode>();
+
+    // Generic entry points.
+    node->f_visit_expr = std::move(f_visit_expr);
+    node->f_visit_stmt = std::move(f_visit_stmt);
+
+    // Per-expression-node callbacks.
+    node->f_visit_var = std::move(f_visit_var);
+    node->f_visit_size_var = std::move(f_visit_size_var);
+    node->f_visit_buffer_load = std::move(f_visit_buffer_load);
+    node->f_visit_producer_load = std::move(f_visit_producer_load);
+    node->f_visit_let = std::move(f_visit_let);
+    node->f_visit_call = std::move(f_visit_call);
+    node->f_visit_add = std::move(f_visit_add);
+    node->f_visit_sub = std::move(f_visit_sub);
+    node->f_visit_mul = std::move(f_visit_mul);
+    node->f_visit_div = std::move(f_visit_div);
+    node->f_visit_mod = std::move(f_visit_mod);
+    node->f_visit_floor_div = std::move(f_visit_floor_div);
+    node->f_visit_floor_mod = std::move(f_visit_floor_mod);
+    node->f_visit_min = std::move(f_visit_min);
+    node->f_visit_max = std::move(f_visit_max);
+    node->f_visit_eq = std::move(f_visit_eq);
+    node->f_visit_ne = std::move(f_visit_ne);
+    node->f_visit_lt = std::move(f_visit_lt);
+    node->f_visit_le = std::move(f_visit_le);
+    node->f_visit_gt = std::move(f_visit_gt);
+    node->f_visit_ge = std::move(f_visit_ge);
+    node->f_visit_and = std::move(f_visit_and);
+    node->f_visit_or = std::move(f_visit_or);
+    node->f_visit_reduce = std::move(f_visit_reduce);
+    node->f_visit_cast = std::move(f_visit_cast);
+    node->f_visit_not = std::move(f_visit_not);
+    node->f_visit_select = std::move(f_visit_select);
+    node->f_visit_ramp = std::move(f_visit_ramp);
+    node->f_visit_broadcast = std::move(f_visit_broadcast);
+    node->f_visit_shuffle = std::move(f_visit_shuffle);
+    node->f_visit_int_imm = std::move(f_visit_int_imm);
+    node->f_visit_float_imm = std::move(f_visit_float_imm);
+    node->f_visit_string_imm = std::move(f_visit_string_imm);
+
+    // Per-statement-node callbacks.
+    node->f_visit_let_stmt = std::move(f_visit_let_stmt);
+    node->f_visit_attr_stmt = std::move(f_visit_attr_stmt);
+    node->f_visit_if_then_else = std::move(f_visit_if_then_else);
+    node->f_visit_for = std::move(f_visit_for);
+    node->f_visit_while = std::move(f_visit_while);
+    node->f_visit_allocate = std::move(f_visit_allocate);
+    node->f_visit_allocate_const = std::move(f_visit_allocate_const);
+    node->f_visit_decl_buffer = std::move(f_visit_decl_buffer);
+    node->f_visit_buffer_store = std::move(f_visit_buffer_store);
+    node->f_visit_buffer_realize = std::move(f_visit_buffer_realize);
+    node->f_visit_assert_stmt = std::move(f_visit_assert_stmt);
+    node->f_visit_seq_stmt = std::move(f_visit_seq_stmt);
+    node->f_visit_evaluate = std::move(f_visit_evaluate);
+    node->f_visit_block = std::move(f_visit_block);
+    node->f_visit_block_realize = std::move(f_visit_block_realize);
+
+    return PyStmtExprVisitor(node);
+  }
+
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(PyStmtExprVisitor, ObjectRef,
+                                                    PyStmtExprVisitorNode);
+};
+
+/*! \brief The python interface of StmtExprMutator. */
+class PyStmtExprMutatorNode : public Object, public StmtExprMutator {
+ private:
+ using TSelf = PyStmtExprMutatorNode;
+ using FExprType = tvm::NodeFunctor<PrimExpr(const ObjectRef& n, TSelf*
self)>;
+ using FStmtType = tvm::NodeFunctor<Stmt(const ObjectRef& n, TSelf* self)>;
+
+ public:
+ // Expression functions
+ /*! \brief The packed function to the `VisitExpr(const Expr& expr)`
function. */
+ ffi::Function f_visit_expr{nullptr};
+ /*! \brief The packed function to the `VisitExpr_(const VarNode* op)`
function. */
+ ffi::Function f_visit_var{nullptr};
+ /*! \brief The packed function to the `VisitExpr_(const SizeVarNode* op)`
function. */
+ ffi::Function f_visit_size_var{nullptr};
+ /*! \brief The packed function to the `VisitExpr_(const BufferLoadNode* op)`
function. */
+ ffi::Function f_visit_buffer_load{nullptr};
+ /*! \brief The packed function to the `VisitExpr_(const ProducerLoadNode*
op)` function. */
+ ffi::Function f_visit_producer_load{nullptr};
+ /*! \brief The packed function to the `VisitExpr_(const LetNode* op)`
function. */
+ ffi::Function f_visit_let{nullptr};
+ /*! \brief The packed function to the `VisitExpr_(const CallNode* op)`
function. */
+ ffi::Function f_visit_call{nullptr};
+ /*! \brief The packed function to the `VisitExpr_(const AddNode* op)`
function. */
+ ffi::Function f_visit_add{nullptr};
+ /*! \brief The packed function to the `VisitExpr_(const SubNode* op)`
function. */
+ ffi::Function f_visit_sub{nullptr};
+ /*! \brief The packed function to the `VisitExpr_(const MulNode* op)`
function. */
+ ffi::Function f_visit_mul{nullptr};
+ /*! \brief The packed function to the `VisitExpr_(const DivNode* op)`
function. */
+ ffi::Function f_visit_div{nullptr};
+ /*! \brief The packed function to the `VisitExpr_(const ModNode* op)`
function. */
+ ffi::Function f_visit_mod{nullptr};
+ /*! \brief The packed function to the `VisitExpr_(const FloorDivNode* op)`
function. */
+ ffi::Function f_visit_floor_div{nullptr};
+ /*! \brief The packed function to the `VisitExpr_(const FloorModNode* op)`
function. */
+ ffi::Function f_visit_floor_mod{nullptr};
+ /*! \brief The packed function to the `VisitExpr_(const MinNode* op)`
function. */
+ ffi::Function f_visit_min{nullptr};
+ /*! \brief The packed function to the `VisitExpr_(const MaxNode* op)`
function. */
+ ffi::Function f_visit_max{nullptr};
+ /*! \brief The packed function to the `VisitExpr_(const EQNode* op)`
function. */
+ ffi::Function f_visit_eq{nullptr};
+ /*! \brief The packed function to the `VisitExpr_(const NENode* op)`
function. */
+ ffi::Function f_visit_ne{nullptr};
+ /*! \brief The packed function to the `VisitExpr_(const LTNode* op)`
function. */
+ ffi::Function f_visit_lt{nullptr};
+ /*! \brief The packed function to the `VisitExpr_(const LENode* op)`
function. */
+ ffi::Function f_visit_le{nullptr};
+ /*! \brief The packed function to the `VisitExpr_(const GTNode* op)`
function. */
+ ffi::Function f_visit_gt{nullptr};
+ /*! \brief The packed function to the `VisitExpr_(const GENode* op)`
function. */
+ ffi::Function f_visit_ge{nullptr};
+ /*! \brief The packed function to the `VisitExpr_(const AndNode* op)`
function. */
+ ffi::Function f_visit_and{nullptr};
+ /*! \brief The packed function to the `VisitExpr_(const OrNode* op)`
function. */
+ ffi::Function f_visit_or{nullptr};
+ /*! \brief The packed function to the `VisitExpr_(const ReduceNode* op)`
function. */
+ ffi::Function f_visit_reduce{nullptr};
+ /*! \brief The packed function to the `VisitExpr_(const CastNode* op)`
function. */
+ ffi::Function f_visit_cast{nullptr};
+ /*! \brief The packed function to the `VisitExpr_(const NotNode* op)`
function. */
+ ffi::Function f_visit_not{nullptr};
+ /*! \brief The packed function to the `VisitExpr_(const SelectNode* op)`
function. */
+ ffi::Function f_visit_select{nullptr};
+ /*! \brief The packed function to the `VisitExpr_(const RampNode* op)`
function. */
+ ffi::Function f_visit_ramp{nullptr};
+ /*! \brief The packed function to the `VisitExpr_(const BroadcastNode* op)`
function. */
+ ffi::Function f_visit_broadcast{nullptr};
+ /*! \brief The packed function to the `VisitExpr_(const ShuffleNode* op)`
function. */
+ ffi::Function f_visit_shuffle{nullptr};
+ /*! \brief The packed function to the `VisitExpr_(const IntImmNode* op)`
function. */
+ ffi::Function f_visit_int_imm{nullptr};
+ /*! \brief The packed function to the `VisitExpr_(const FloatImmNode* op)`
function. */
+ ffi::Function f_visit_float_imm{nullptr};
+ /*! \brief The packed function to the `VisitExpr_(const StringImmNode* op)`
function. */
+ ffi::Function f_visit_string_imm{nullptr};
+
+ // Statement functions
+ /*! \brief The packed function to the `VisitStmt(const Stmt& stmt)`
function. */
+ ffi::Function f_visit_stmt{nullptr};
+ /*! \brief The packed function to the `VisitStmt_(const LetStmtNode* op)`
function. */
+ ffi::Function f_visit_let_stmt{nullptr};
+ /*! \brief The packed function to the `VisitStmt_(const AttrStmtNode* op)`
function. */
+ ffi::Function f_visit_attr_stmt{nullptr};
+ /*! \brief The packed function to the `VisitStmt_(const IfThenElseNode* op)`
function. */
+ ffi::Function f_visit_if_then_else{nullptr}; // NOLINT(readability/braces)
+ /*! \brief The packed function to the `VisitStmt_(const ForNode* op)`
function. */
+ ffi::Function f_visit_for{nullptr};
+ /*! \brief The packed function to the `VisitStmt_(const WhileNode* op)`
function. */
+ ffi::Function f_visit_while{nullptr};
+ /*! \brief The packed function to the `VisitStmt_(const AllocateNode* op)`
function. */
+ ffi::Function f_visit_allocate{nullptr};
+ /*! \brief The packed function to the `VisitStmt_(const AllocateConstNode*
op)` function. */
+ ffi::Function f_visit_allocate_const{nullptr};
+ /*! \brief The packed function to the `VisitStmt_(const DeclBufferNode* op)`
function. */
+ ffi::Function f_visit_decl_buffer{nullptr};
+ /*! \brief The packed function to the `VisitStmt_(const BufferStoreNode*
op)` function. */
+ ffi::Function f_visit_buffer_store{nullptr};
+ /*! \brief The packed function to the `VisitStmt_(const BufferRealizeNode*
op)` function. */
+ ffi::Function f_visit_buffer_realize{nullptr};
+ /*! \brief The packed function to the `VisitStmt_(const AssertStmtNode* op)`
function. */
+ ffi::Function f_visit_assert_stmt{nullptr};
+ /*! \brief The packed function to the `VisitStmt_(const SeqStmtNode* op)`
function. */
+ ffi::Function f_visit_seq_stmt{nullptr};
+ /*! \brief The packed function to the `VisitStmt_(const EvaluateNode* op)`
function. */
+ ffi::Function f_visit_evaluate{nullptr};
+ /*! \brief The packed function to the `VisitStmt_(const BlockNode* op)`
function. */
+ ffi::Function f_visit_block{nullptr};
+ /*! \brief The packed function to the `VisitStmt_(const BlockRealizeNode*
op)` function. */
+ ffi::Function f_visit_block_realize{nullptr};
+
+ using StmtExprMutator::VisitExpr;
+ using StmtExprMutator::VisitStmt;
+
+ void DefaultVisitExpr(const PrimExpr& expr) {
+ static FExprType vtable = InitExprVTable();
+ vtable(expr, this);
+ }
+
+ void DefaultVisitStmt(const Stmt& stmt) {
+ static FStmtType vtable = InitStmtVTable();
+ vtable(stmt, this);
+ }
+ void VisitAttrs(AttrVisitor* v) {}
+ static constexpr const char* _type_key = "tir.PyStmtExprMutator";
+ TVM_DECLARE_BASE_OBJECT_INFO(PyStmtExprMutatorNode, Object);
+
+ private:
+ // Statement functions
+ PY_STMT_MUTATOR_DISPATCH(LetStmtNode, f_visit_let_stmt);
+ PY_STMT_MUTATOR_DISPATCH(AttrStmtNode, f_visit_attr_stmt);
+ PY_STMT_MUTATOR_DISPATCH(IfThenElseNode, f_visit_if_then_else);
+ PY_STMT_MUTATOR_DISPATCH(ForNode, f_visit_for);
+ PY_STMT_MUTATOR_DISPATCH(WhileNode, f_visit_while);
+ PY_STMT_MUTATOR_DISPATCH(AllocateNode, f_visit_allocate);
+ PY_STMT_MUTATOR_DISPATCH(AllocateConstNode, f_visit_allocate_const);
+ PY_STMT_MUTATOR_DISPATCH(DeclBufferNode, f_visit_decl_buffer);
+ PY_STMT_MUTATOR_DISPATCH(BufferStoreNode, f_visit_buffer_store);
+ PY_STMT_MUTATOR_DISPATCH(BufferRealizeNode, f_visit_buffer_realize);
+ PY_STMT_MUTATOR_DISPATCH(AssertStmtNode, f_visit_assert_stmt);
+ PY_STMT_MUTATOR_DISPATCH(SeqStmtNode, f_visit_seq_stmt);
+ PY_STMT_MUTATOR_DISPATCH(EvaluateNode, f_visit_evaluate);
+ PY_STMT_MUTATOR_DISPATCH(BlockNode, f_visit_block);
+ PY_STMT_MUTATOR_DISPATCH(BlockRealizeNode, f_visit_block_realize);
+ // Expression functions
+ PY_EXPR_MUTATOR_DISPATCH(VarNode, f_visit_var);
+ PY_EXPR_MUTATOR_DISPATCH(SizeVarNode, f_visit_size_var);
+ PY_EXPR_MUTATOR_DISPATCH(BufferLoadNode, f_visit_buffer_load);
+ PY_EXPR_MUTATOR_DISPATCH(ProducerLoadNode, f_visit_producer_load);
+ PY_EXPR_MUTATOR_DISPATCH(LetNode, f_visit_let);
+ PY_EXPR_MUTATOR_DISPATCH(CallNode, f_visit_call);
+ PY_EXPR_MUTATOR_DISPATCH(AddNode, f_visit_add);
+ PY_EXPR_MUTATOR_DISPATCH(SubNode, f_visit_sub);
+ PY_EXPR_MUTATOR_DISPATCH(MulNode, f_visit_mul);
+ PY_EXPR_MUTATOR_DISPATCH(DivNode, f_visit_div);
+ PY_EXPR_MUTATOR_DISPATCH(ModNode, f_visit_mod);
+ PY_EXPR_MUTATOR_DISPATCH(FloorDivNode, f_visit_floor_div);
+ PY_EXPR_MUTATOR_DISPATCH(FloorModNode, f_visit_floor_mod);
+ PY_EXPR_MUTATOR_DISPATCH(MinNode, f_visit_min);
+ PY_EXPR_MUTATOR_DISPATCH(MaxNode, f_visit_max);
+ PY_EXPR_MUTATOR_DISPATCH(EQNode, f_visit_eq);
+ PY_EXPR_MUTATOR_DISPATCH(NENode, f_visit_ne);
+ PY_EXPR_MUTATOR_DISPATCH(LTNode, f_visit_lt);
+ PY_EXPR_MUTATOR_DISPATCH(LENode, f_visit_le);
+ PY_EXPR_MUTATOR_DISPATCH(GTNode, f_visit_gt);
+ PY_EXPR_MUTATOR_DISPATCH(GENode, f_visit_ge);
+ PY_EXPR_MUTATOR_DISPATCH(AndNode, f_visit_and);
+ PY_EXPR_MUTATOR_DISPATCH(OrNode, f_visit_or);
+ PY_EXPR_MUTATOR_DISPATCH(ReduceNode, f_visit_reduce);
+ PY_EXPR_MUTATOR_DISPATCH(CastNode, f_visit_cast);
+ PY_EXPR_MUTATOR_DISPATCH(NotNode, f_visit_not);
+ PY_EXPR_MUTATOR_DISPATCH(SelectNode, f_visit_select);
+ PY_EXPR_MUTATOR_DISPATCH(RampNode, f_visit_ramp);
+ PY_EXPR_MUTATOR_DISPATCH(BroadcastNode, f_visit_broadcast);
+ PY_EXPR_MUTATOR_DISPATCH(ShuffleNode, f_visit_shuffle);
+ PY_EXPR_MUTATOR_DISPATCH(IntImmNode, f_visit_int_imm);
+ PY_EXPR_MUTATOR_DISPATCH(FloatImmNode, f_visit_float_imm);
+ PY_EXPR_MUTATOR_DISPATCH(StringImmNode, f_visit_string_imm);
+
+ private:
+  // Static builders for this mutator's per-node-type dispatch tables (vtables).
+  static FExprType InitExprVTable() {
+    FExprType vtable;
+    // Install one PY_EXPR_MUTATOR_DEFAULT_DISPATCH entry per supported expression node type.
+    PY_EXPR_MUTATOR_DEFAULT_DISPATCH(VarNode);
+    PY_EXPR_MUTATOR_DEFAULT_DISPATCH(SizeVarNode);
+    PY_EXPR_MUTATOR_DEFAULT_DISPATCH(BufferLoadNode);
+    PY_EXPR_MUTATOR_DEFAULT_DISPATCH(ProducerLoadNode);
+    PY_EXPR_MUTATOR_DEFAULT_DISPATCH(LetNode);
+    PY_EXPR_MUTATOR_DEFAULT_DISPATCH(CallNode);
+    PY_EXPR_MUTATOR_DEFAULT_DISPATCH(AddNode);
+    PY_EXPR_MUTATOR_DEFAULT_DISPATCH(SubNode);
+    PY_EXPR_MUTATOR_DEFAULT_DISPATCH(MulNode);
+    PY_EXPR_MUTATOR_DEFAULT_DISPATCH(DivNode);
+    PY_EXPR_MUTATOR_DEFAULT_DISPATCH(ModNode);
+    PY_EXPR_MUTATOR_DEFAULT_DISPATCH(FloorDivNode);
+    PY_EXPR_MUTATOR_DEFAULT_DISPATCH(FloorModNode);
+    PY_EXPR_MUTATOR_DEFAULT_DISPATCH(MinNode);
+    PY_EXPR_MUTATOR_DEFAULT_DISPATCH(MaxNode);
+    PY_EXPR_MUTATOR_DEFAULT_DISPATCH(EQNode);
+    PY_EXPR_MUTATOR_DEFAULT_DISPATCH(NENode);
+    PY_EXPR_MUTATOR_DEFAULT_DISPATCH(LTNode);
+    PY_EXPR_MUTATOR_DEFAULT_DISPATCH(LENode);
+    PY_EXPR_MUTATOR_DEFAULT_DISPATCH(GTNode);
+    PY_EXPR_MUTATOR_DEFAULT_DISPATCH(GENode);
+    PY_EXPR_MUTATOR_DEFAULT_DISPATCH(AndNode);
+    PY_EXPR_MUTATOR_DEFAULT_DISPATCH(OrNode);
+    PY_EXPR_MUTATOR_DEFAULT_DISPATCH(ReduceNode);
+    PY_EXPR_MUTATOR_DEFAULT_DISPATCH(CastNode);
+    PY_EXPR_MUTATOR_DEFAULT_DISPATCH(NotNode);
+    PY_EXPR_MUTATOR_DEFAULT_DISPATCH(SelectNode);
+    PY_EXPR_MUTATOR_DEFAULT_DISPATCH(RampNode);
+    PY_EXPR_MUTATOR_DEFAULT_DISPATCH(ShuffleNode);
+    PY_EXPR_MUTATOR_DEFAULT_DISPATCH(BroadcastNode);
+    PY_EXPR_MUTATOR_DEFAULT_DISPATCH(IntImmNode);
+    PY_EXPR_MUTATOR_DEFAULT_DISPATCH(FloatImmNode);
+    PY_EXPR_MUTATOR_DEFAULT_DISPATCH(StringImmNode);
+    vtable.Finalize();
+    return vtable;
+  }
+
+ static FStmtType InitStmtVTable() {  // Build the statement-node dispatch table.
+ FStmtType vtable;  // one PY_STMT_MUTATOR_DEFAULT_DISPATCH entry per statement node type below
+ PY_STMT_MUTATOR_DEFAULT_DISPATCH(LetStmtNode);
+ PY_STMT_MUTATOR_DEFAULT_DISPATCH(AttrStmtNode);
+ PY_STMT_MUTATOR_DEFAULT_DISPATCH(IfThenElseNode);
+ PY_STMT_MUTATOR_DEFAULT_DISPATCH(ForNode);
+ PY_STMT_MUTATOR_DEFAULT_DISPATCH(WhileNode);
+ PY_STMT_MUTATOR_DEFAULT_DISPATCH(AllocateNode);
+ PY_STMT_MUTATOR_DEFAULT_DISPATCH(AllocateConstNode);
+ PY_STMT_MUTATOR_DEFAULT_DISPATCH(DeclBufferNode);
+ PY_STMT_MUTATOR_DEFAULT_DISPATCH(BufferStoreNode);
+ PY_STMT_MUTATOR_DEFAULT_DISPATCH(BufferRealizeNode);
+ PY_STMT_MUTATOR_DEFAULT_DISPATCH(AssertStmtNode);
+ PY_STMT_MUTATOR_DEFAULT_DISPATCH(SeqStmtNode);
+ PY_STMT_MUTATOR_DEFAULT_DISPATCH(EvaluateNode);
+ PY_STMT_MUTATOR_DEFAULT_DISPATCH(BlockNode);
+ PY_STMT_MUTATOR_DEFAULT_DISPATCH(BlockRealizeNode);
+ vtable.Finalize();  // mirrors InitExprVTable; called once all entries are set
+ return vtable;
+ }
+};
+
+/*! \brief Managed reference to PyStmtExprMutatorNode. */
+class PyStmtExprMutator : public ObjectRef {
+ public:
+  /*!
+   * \brief Create a PyStmtExprMutator from Python-side callbacks, one per visit method.
+   * \return The PyStmtExprMutator created.
+   */
+  TVM_DLL static PyStmtExprMutator MakePyStmtExprMutator(ffi::Function f_visit_stmt,            //
+                                                         ffi::Function f_visit_expr,            //
+                                                         ffi::Function f_visit_let_stmt,        //
+                                                         ffi::Function f_visit_attr_stmt,       //
+                                                         ffi::Function f_visit_if_then_else,    //
+                                                         ffi::Function f_visit_for,             //
+                                                         ffi::Function f_visit_while,           //
+                                                         ffi::Function f_visit_allocate,        //
+                                                         ffi::Function f_visit_allocate_const,  //
+                                                         ffi::Function f_visit_decl_buffer,     //
+                                                         ffi::Function f_visit_buffer_store,    //
+                                                         ffi::Function f_visit_buffer_realize,  //
+                                                         ffi::Function f_visit_assert_stmt,     //
+                                                         ffi::Function f_visit_seq_stmt,        //
+                                                         ffi::Function f_visit_evaluate,        //
+                                                         ffi::Function f_visit_block,           //
+                                                         ffi::Function f_visit_block_realize,   //
+                                                         ffi::Function f_visit_var,             //
+                                                         ffi::Function f_visit_size_var,        //
+                                                         ffi::Function f_visit_buffer_load,     //
+                                                         ffi::Function f_visit_producer_load,   //
+                                                         ffi::Function f_visit_let,             //
+                                                         ffi::Function f_visit_call,            //
+                                                         ffi::Function f_visit_add,             //
+                                                         ffi::Function f_visit_sub,             //
+                                                         ffi::Function f_visit_mul,             //
+                                                         ffi::Function f_visit_div,             //
+                                                         ffi::Function f_visit_mod,             //
+                                                         ffi::Function f_visit_floor_div,       //
+                                                         ffi::Function f_visit_floor_mod,       //
+                                                         ffi::Function f_visit_min,             //
+                                                         ffi::Function f_visit_max,             //
+                                                         ffi::Function f_visit_eq,              //
+                                                         ffi::Function f_visit_ne,              //
+                                                         ffi::Function f_visit_lt,              //
+                                                         ffi::Function f_visit_le,              //
+                                                         ffi::Function f_visit_gt,              //
+                                                         ffi::Function f_visit_ge,              //
+                                                         ffi::Function f_visit_and,             //
+                                                         ffi::Function f_visit_or,              //
+                                                         ffi::Function f_visit_reduce,          //
+                                                         ffi::Function f_visit_cast,            //
+                                                         ffi::Function f_visit_not,             //
+                                                         ffi::Function f_visit_select,          //
+                                                         ffi::Function f_visit_ramp,            //
+                                                         ffi::Function f_visit_broadcast,       //
+                                                         ffi::Function f_visit_shuffle,         //
+                                                         ffi::Function f_visit_int_imm,         //
+                                                         ffi::Function f_visit_float_imm,       //
+                                                         ffi::Function f_visit_string_imm) {
+    ObjectPtr<PyStmtExprMutatorNode> n = make_object<PyStmtExprMutatorNode>();
+    n->f_visit_stmt = std::move(f_visit_stmt);
+    n->f_visit_expr = std::move(f_visit_expr);
+    // Statement functions
+    n->f_visit_let_stmt = std::move(f_visit_let_stmt);
+    n->f_visit_attr_stmt = std::move(f_visit_attr_stmt);
+    n->f_visit_if_then_else = std::move(f_visit_if_then_else);
+    n->f_visit_for = std::move(f_visit_for);
+    n->f_visit_while = std::move(f_visit_while);
+    n->f_visit_allocate = std::move(f_visit_allocate);
+    n->f_visit_allocate_const = std::move(f_visit_allocate_const);
+    n->f_visit_decl_buffer = std::move(f_visit_decl_buffer);
+    n->f_visit_buffer_store = std::move(f_visit_buffer_store);
+    n->f_visit_buffer_realize = std::move(f_visit_buffer_realize);
+    n->f_visit_assert_stmt = std::move(f_visit_assert_stmt);
+    n->f_visit_seq_stmt = std::move(f_visit_seq_stmt);
+    n->f_visit_evaluate = std::move(f_visit_evaluate);
+    n->f_visit_block = std::move(f_visit_block);
+    n->f_visit_block_realize = std::move(f_visit_block_realize);
+    // Expression functions
+    n->f_visit_var = std::move(f_visit_var);
+    n->f_visit_size_var = std::move(f_visit_size_var);
+    n->f_visit_buffer_load = std::move(f_visit_buffer_load);
+    n->f_visit_producer_load = std::move(f_visit_producer_load);
+    n->f_visit_let = std::move(f_visit_let);
+    n->f_visit_call = std::move(f_visit_call);
+    n->f_visit_add = std::move(f_visit_add);
+    n->f_visit_sub = std::move(f_visit_sub);
+    n->f_visit_mul = std::move(f_visit_mul);
+    n->f_visit_div = std::move(f_visit_div);
+    n->f_visit_mod = std::move(f_visit_mod);
+    n->f_visit_floor_div = std::move(f_visit_floor_div);
+    n->f_visit_floor_mod = std::move(f_visit_floor_mod);
+    n->f_visit_min = std::move(f_visit_min);
+    n->f_visit_max = std::move(f_visit_max);
+    n->f_visit_eq = std::move(f_visit_eq);
+    n->f_visit_ne = std::move(f_visit_ne);
+    n->f_visit_lt = std::move(f_visit_lt);
+    n->f_visit_le = std::move(f_visit_le);
+    n->f_visit_gt = std::move(f_visit_gt);
+    n->f_visit_ge = std::move(f_visit_ge);
+    n->f_visit_and = std::move(f_visit_and);
+    n->f_visit_or = std::move(f_visit_or);
+    n->f_visit_reduce = std::move(f_visit_reduce);
+    n->f_visit_cast = std::move(f_visit_cast);
+    n->f_visit_not = std::move(f_visit_not);
+    n->f_visit_select = std::move(f_visit_select);
+    n->f_visit_ramp = std::move(f_visit_ramp);
+    n->f_visit_broadcast = std::move(f_visit_broadcast);
+    n->f_visit_shuffle = std::move(f_visit_shuffle);
+    n->f_visit_int_imm = std::move(f_visit_int_imm);
+    n->f_visit_float_imm = std::move(f_visit_float_imm);
+    n->f_visit_string_imm = std::move(f_visit_string_imm);
+    return PyStmtExprMutator(n);
+  }
+
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(PyStmtExprMutator, ObjectRef,
+                                                    PyStmtExprMutatorNode);
+};
+
+// ================================================
+// Node-type and FFI global registration
+// ================================================
+
+TVM_REGISTER_NODE_TYPE(PyStmtExprVisitorNode);
+TVM_REGISTER_NODE_TYPE(PyStmtExprMutatorNode);
+
+TVM_FFI_REGISTER_GLOBAL("tir.MakePyStmtExprVisitor")
+    .set_body_typed(PyStmtExprVisitor::MakePyStmtExprVisitor);
+TVM_FFI_REGISTER_GLOBAL("tir.MakePyStmtExprMutator")
+    .set_body_typed(PyStmtExprMutator::MakePyStmtExprMutator);
+
+// PyStmtExprVisitor entry points callable from Python; these return nothing.
+TVM_FFI_REGISTER_GLOBAL("tir.PyStmtExprVisitorDefaultVisitExpr")
+    .set_body_typed([](PyStmtExprVisitor visitor, const PrimExpr& expr) {
+      visitor->DefaultVisitExpr(expr);
+    });
+TVM_FFI_REGISTER_GLOBAL("tir.PyStmtExprVisitorDefaultVisitStmt")
+    .set_body_typed([](PyStmtExprVisitor visitor, const Stmt& stmt) {
+      visitor->DefaultVisitStmt(stmt);
+    });
+TVM_FFI_REGISTER_GLOBAL("tir.PyStmtExprVisitorVisitStmt")
+    .set_body_typed([](PyStmtExprVisitor visitor, const Stmt& stmt) { visitor->VisitStmt(stmt); });
+TVM_FFI_REGISTER_GLOBAL("tir.PyStmtExprVisitorVisitExpr")
+    .set_body_typed([](PyStmtExprVisitor visitor, const PrimExpr& expr) {
+      visitor->VisitExpr(expr);
+    });
+
+// PyStmtExprMutator entry points callable from Python; each returns the mutator's result.
+TVM_FFI_REGISTER_GLOBAL("tir.PyStmtExprMutatorDefaultVisitExpr")
+    .set_body_typed([](PyStmtExprMutator mutator, const PrimExpr& expr) {
+      return mutator->DefaultVisitExpr(expr);
+    });
+TVM_FFI_REGISTER_GLOBAL("tir.PyStmtExprMutatorDefaultVisitStmt")
+    .set_body_typed([](PyStmtExprMutator mutator, const Stmt& stmt) {
+      return mutator->DefaultVisitStmt(stmt);
+    });
+TVM_FFI_REGISTER_GLOBAL("tir.PyStmtExprMutatorVisitExpr")
+    .set_body_typed([](PyStmtExprMutator mutator, const PrimExpr& expr) {
+      return mutator->VisitExpr(expr);
+    });
+TVM_FFI_REGISTER_GLOBAL("tir.PyStmtExprMutatorVisitStmt")
+    .set_body_typed([](PyStmtExprMutator mutator, const Stmt& stmt) {
+      return mutator->VisitStmt(stmt);
+    });
+
+} // namespace tir
+} // namespace tvm
diff --git a/tests/python/tir-transform/test_tir_functor.py b/tests/python/tir-transform/test_tir_functor.py
new file mode 100644
index 0000000000..e8463027db
--- /dev/null
+++ b/tests/python/tir-transform/test_tir_functor.py
@@ -0,0 +1,436 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+import tvm.testing
+from tvm import tir
+from tvm.tir import (
+ EQ,
+ LT,
+ Add,
+ Cast,
+ Evaluate,
+ FloatImm,
+ For,
+ IfThenElse,
+ IntImm,
+ Max,
+ Min,
+ Mul,
+ PyStmtExprMutator,
+ PyStmtExprVisitor,
+ StringImm,
+ Sub,
+ Var,
+)
+
+
+class ASTLog:
+    """Accumulate log lines, each prefixed by ``indent`` repeated ``level`` times."""
+
+    def __init__(self) -> None:
+        self.log = []
+        self.indent = "\t"  # one indent unit per nesting level
+        self.level = 0
+
+    def push_scope(self):
+        self.level += 1
+
+    def pop_scope(self):
+        self.level -= 1
+
+    def add(self, s: str):
+        self.log.append(self.indent * self.level + s)
+
+    def __str__(self) -> str:
+        return "\n".join(self.log)
+
+
+@tir.functor.visitor
+class ASTPrinter(PyStmtExprVisitor):
+    """Log each visited Var and Add node (pre-order), then continue the default traversal."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.log = ASTLog()
+
+    def visit_var_(self, op: Var) -> None:
+        self.log.add("Stmt: Var")
+        super().visit_var_(op)
+
+    def visit_add_(self, op: Add) -> None:
+        self.log.add("Stmt: Add")
+        super().visit_add_(op)
+
+
+@tir.functor.visitor
+class SimpleExprCounter(PyStmtExprVisitor):
+    """Count Var, Add and Mul nodes, including those nested inside Add/Mul operands"""
+
+    def __init__(self):
+        super().__init__()
+        self.var_count = 0
+        self.add_count = 0
+        self.mul_count = 0
+
+    def visit_var_(self, op: Var):
+        self.var_count += 1
+        # No super() call: a Var has no sub-expressions to visit
+
+    def visit_add_(self, op: Add):
+        self.add_count += 1
+        # Delegate to the default visitor so both operands are visited (and counted)
+        super().visit_add_(op)
+
+    def visit_mul_(self, op: Mul):
+        self.mul_count += 1
+        # Delegate to the default visitor so both operands are visited (and counted)
+        super().visit_mul_(op)
+
+
+@tir.functor.mutator
+class VariableReplacer(PyStmtExprMutator):
+    """Replace Vars named in ``replacements`` with int32 IntImm constants"""
+
+    def __init__(self, replacements):
+        super().__init__()
+        self.replacements = replacements  # maps Var name -> int value
+
+    def visit_var_(self, op: Var):
+        if op.name in self.replacements:
+            return IntImm("int32", self.replacements[op.name])
+        return op
+
+
+@tir.functor.mutator
+class AddToSubMutator(PyStmtExprMutator):
+    """Rewrite every ``a + b`` into ``a - b`` after mutating both operands"""
+
+    def visit_add_(self, op: Add):
+        # First mutate the operands
+        a = self.visit_expr(op.a)
+        b = self.visit_expr(op.b)
+        # Convert Add to Sub
+        return Sub(a, b)
+
+
+@tir.functor.visitor
+class SimpleStmtCounter(PyStmtExprVisitor):
+    """Count For, IfThenElse and Evaluate statements, including nested ones"""
+
+    def __init__(self):
+        super().__init__()
+        self.for_count = 0
+        self.if_count = 0
+        self.evaluate_count = 0
+
+    def visit_for_(self, op: For):
+        self.for_count += 1
+        super().visit_for_(op)
+
+    def visit_if_then_else_(self, op: IfThenElse):
+        self.if_count += 1
+        super().visit_if_then_else_(op)
+
+    def visit_evaluate_(self, op: Evaluate):
+        self.evaluate_count += 1
+        super().visit_evaluate_(op)
+
+
+@tir.functor.mutator
+class ForLoopUnroller(PyStmtExprMutator):
+    """Placeholder loop unroller; ``unroll_factor`` is stored but not yet used"""
+
+    def __init__(self, unroll_factor=2):
+        super().__init__()
+        self.unroll_factor = unroll_factor
+
+    def visit_for_(self, op: For):
+        # For demonstration, defer to the default mutation of the loop
+        # In a real implementation, we would unroll small loops
+        return super().visit_for_(op)
+
+
+@tir.functor.visitor
+class SimpleStmtExprVisitor(PyStmtExprVisitor):
+    """Count Evaluate statements and the Vars reachable from their values"""
+
+    def __init__(self):
+        super().__init__()
+        self.expr_count = 0  # counts Var nodes only
+        self.stmt_count = 0  # counts Evaluate nodes only
+        self.var_names = set()
+
+    def visit_var_(self, op: Var):
+        self.var_names.add(op.name)
+        self.expr_count += 1
+
+    def visit_evaluate_(self, op: Evaluate):
+        self.stmt_count += 1
+        # Visit the evaluated expression so its Vars are recorded
+        self.visit_expr(op.value)
+
+
+@tir.functor.mutator
+class ComplexMutator(PyStmtExprMutator):
+    """Rewrite every ``a + b`` into ``a * 2 + b``; statements use the default mutation"""
+
+    def __init__(self):
+        super().__init__()
+        self.modifications = 0  # number of Add nodes rewritten
+
+    def visit_add_(self, op: Add):
+        self.modifications += 1
+        # Convert a + b to a * 2 + b for demonstration
+        a = self.visit_expr(op.a)
+        b = self.visit_expr(op.b)
+        return Add(Mul(a, IntImm("int32", 2)), b)
+
+
+def test_basic_visitor():
+    """The default traversal reaches an Add before its two Var operands (pre-order)"""
+    expr = Add(Var("x", dtype="int32"), Var("y", dtype="int32"))
+    printer = ASTPrinter()
+    printer.visit_expr(expr)
+    assert str(printer.log) == "\n".join(["Stmt: Add", "Stmt: Var", "Stmt: Var"])
+
+
+def test_simple_expr_counter():
+    """Counting visitor sees one Add and both of its Var operands"""
+    x = Var("x", dtype="int32")
+    y = Var("y", dtype="int32")
+
+    # Create simple expression: x + y
+    expr = Add(x, y)
+
+    counter = SimpleExprCounter()
+    counter.visit_expr(expr)
+
+    assert counter.var_count == 2  # x and y
+    assert counter.add_count == 1  # one add
+
+
+def test_variable_replacer():
+    """Mutator replaces Vars by constants, including Vars nested under Mul"""
+    x = Var("x", dtype="int32")
+    y = Var("y", dtype="int32")
+    expr = Add(x, Mul(y, IntImm("int32", 3)))
+
+    replacer = VariableReplacer({"x": 10, "y": 5})
+    result = replacer.visit_expr(expr)
+
+    # Should be Add(IntImm(10), Mul(IntImm(5), IntImm(3)))
+    assert isinstance(result, Add)
+    assert isinstance(result.a, IntImm)
+    assert result.a.value == 10
+    assert isinstance(result.b, Mul)
+    assert isinstance(result.b.a, IntImm)
+    assert result.b.a.value == 5
+
+
+def test_add_to_sub_mutator():
+    """Mutator turns ``x + y`` into ``x - y`` and keeps the operands"""
+    x = Var("x", dtype="int32")
+    y = Var("y", dtype="int32")
+    expr = Add(x, y)
+
+    mutator = AddToSubMutator()
+    result = mutator.visit_expr(expr)
+
+    assert isinstance(result, Sub)
+    assert isinstance(result.a, Var)
+    assert isinstance(result.b, Var)
+    assert result.a.name == "x"
+    assert result.b.name == "y"
+
+
+def test_simple_stmt_counter():
+    """Statement visitor counts a For loop and the Evaluate in its body"""
+    i = Var("i", dtype="int32")
+
+    # Create a simple for loop
+    loop_body = Evaluate(IntImm("int32", 0))
+    for_stmt = For(i, IntImm("int32", 0), IntImm("int32", 10), tir.ForKind.SERIAL, loop_body)
+
+    counter = SimpleStmtCounter()
+    counter.visit_stmt(for_stmt)
+
+    assert counter.for_count == 1  # One for loop
+    assert counter.evaluate_count == 1  # One evaluate in the body
+
+
+def test_if_then_else_visitor():
+    """Statement visitor counts an IfThenElse and no For loops"""
+    x = Var("x", dtype="int32")
+    condition = EQ(x, IntImm("int32", 0))
+    then_stmt = Evaluate(IntImm("int32", 1))
+    else_stmt = Evaluate(IntImm("int32", 2))
+
+    if_stmt = IfThenElse(condition, then_stmt, else_stmt)
+
+    counter = SimpleStmtCounter()
+    counter.visit_stmt(if_stmt)
+
+    assert counter.if_count == 1
+    assert counter.for_count == 0
+
+
+def test_simple_stmt_expr_visitor():
+    """A statement override can descend into expressions via ``visit_expr``"""
+    x = Var("x", dtype="int32")
+    y = Var("y", dtype="int32")
+
+    # Create an evaluate statement with an expression
+    expr = Add(x, y)
+    stmt = Evaluate(expr)
+
+    visitor = SimpleStmtExprVisitor()
+    visitor.visit_stmt(stmt)
+
+    assert visitor.stmt_count == 1  # One Evaluate statement
+    assert visitor.expr_count == 2  # Two variables
+    assert "x" in visitor.var_names
+    assert "y" in visitor.var_names
+
+
+def test_complex_mutator():
+    """Expression rewrites inside a statement are reflected in the mutated statement"""
+    x = Var("x", dtype="int32")
+    y = Var("y", dtype="int32")
+
+    # Expression with Add operations
+    expr = Add(x, y)
+    stmt = Evaluate(expr)
+
+    mutator = ComplexMutator()
+    result = mutator.visit_stmt(stmt)
+
+    assert mutator.modifications == 1  # One Add operation modified
+    assert isinstance(result, Evaluate)
+
+    # Check that the expression was modified: x + y -> x * 2 + y
+    modified_expr = result.value
+    assert isinstance(modified_expr, Add)
+    assert isinstance(modified_expr.a, Mul)  # First operand should be multiplied by 2
+    assert modified_expr.a.b.value == 2
+
+
+def test_different_expr_types():
+    """Visiting assorted expression types succeeds and reaches the nested Vars"""
+    x = Var("x", dtype="int32")
+
+    # Test different expression types individually
+    exprs = [
+        IntImm("int32", 42),
+        FloatImm("float32", 3.14),
+        StringImm("hello"),
+        Cast("float32", x),
+        Min(x, IntImm("int32", 10)),
+        Max(x, IntImm("int32", 0)),
+        LT(x, IntImm("int32", 5)),
+    ]
+
+    # These are all standard TIR expression nodes; a failure while visiting one is a
+    # real bug, so it must not be swallowed.
+    counter = SimpleExprCounter()
+    for expr in exprs:
+        counter.visit_expr(expr)
+
+    # Cast, Min, Max and LT each reference x once; the immediates reference no Var.
+    assert counter.var_count == 4
+
+
+def test_decorator_functionality():
+    """Decorated visitor and mutator classes expose the ``_outer`` wrapper attribute"""
+
+    # Test that decorated classes are properly wrapped
+    visitor = SimpleExprCounter()
+    assert hasattr(visitor, "_outer")  # Should have the wrapper functionality
+
+    mutator = VariableReplacer({})
+    assert hasattr(mutator, "_outer")
+
+
+def test_empty_expressions():
+    """Visiting a lone leaf expression (a Var, or a constant) on its own"""
+    counter = SimpleExprCounter()
+
+    # Test with just a variable
+    x = Var("x", dtype="int32")
+    counter.visit_expr(x)
+
+    assert counter.var_count == 1
+
+    # Test with just a constant
+    counter = SimpleExprCounter()
+    const = IntImm("int32", 5)
+    counter.visit_expr(const)
+
+    # Constants don't increase var_count
+    assert counter.var_count == 0
+
+
+def test_stmt_mutator():
+    """A mutator with only a For override leaves a non-loop statement's type unchanged"""
+    x = Var("x", dtype="int32")
+    stmt = Evaluate(Add(x, IntImm("int32", 1)))
+
+    unroller = ForLoopUnroller()
+    result = unroller.visit_stmt(stmt)
+
+    # Should return the same statement (no actual unrolling implemented)
+    assert isinstance(result, Evaluate)
+
+
+def test_nested_expressions():
+    """Default traversal through Mul then Add reaches all three Vars"""
+    x = Var("x", dtype="int32")
+    y = Var("y", dtype="int32")
+    z = Var("z", dtype="int32")
+
+    # Create nested expression: (x + y) * z
+    inner_add = Add(x, y)
+    expr = Mul(inner_add, z)
+
+    counter = SimpleExprCounter()
+    counter.visit_expr(expr)
+
+    assert counter.var_count == 3  # x, y, z
+    assert counter.add_count == 1  # one add
+    assert counter.mul_count == 1  # one mul
+
+
+def test_simple_mutations():
+    """Both operands of an Add are replaced when both names are mapped"""
+    x = Var("x", dtype="int32")
+    y = Var("y", dtype="int32")
+
+    # Test multiple replacements
+    expr = Add(x, y)
+    replacer = VariableReplacer({"x": 1, "y": 2})
+    result = replacer.visit_expr(expr)
+
+    assert isinstance(result, Add)
+    assert isinstance(result.a, IntImm)
+    assert isinstance(result.b, IntImm)
+    assert result.a.value == 1
+    assert result.b.value == 2
+
+
+if __name__ == "__main__":
+    # Let pytest collect and run every test in this file exactly once.
+    tvm.testing.main()