This is an automated email from the ASF dual-hosted git repository.

spectrometerHBH pushed a commit to branch tir-bench
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit bb57b65fd8723cb82f7834a0b9974adaad093dca
Author: Hongyi Jin <[email protected]>
AuthorDate: Wed May 20 22:08:14 2026 -0700

    fix(tirx/stmt_functor): add ScopeIdDefStmt to Python StmtFunctor dispatch (#637)
    
    The PR-#636 / kernels-#306 ``Tx.device_entry()`` flat-form migration
    introduced free-standing ``ScopeIdDefStmt`` nodes (``wg_id =
    Tx.warpgroup_id([N])``, ``warp_id = Tx.warp_id_in_wg([4])``, …) at the
    kernel-body top level. The C++ ``StmtVisitor`` / ``StmtMutator`` already
    handle them (visit extents via ``VisitExpr``; mutate via rebuild), but
    the Python ``StmtFunctor.__init__`` did not register
    ``"tirx.ScopeIdDefStmt"`` in its ``_dispatch_map``, so every Python
    visitor / mutator that walks a post-migration kernel falls through to
    ``visit_stmt_default_`` and blows up with
    ``NotImplementedError: Do not have a default for ScopeIdDefStmt``.
    
    This surfaced concretely in cpusim, where ~47 visitor / mutator
    subclasses were affected; the tools/cpusim-side workaround
    (monkey-patching ``StmtFunctor.__init__``) belongs upstream.
    
    Changes:
    
    * ``stmt.py`` — register a Python ``ScopeIdDefStmt`` class so the FFI
      returns instances of it (instead of an auto-generated fallback).
      Re-export from ``tvm.tirx``. The C++ field is named ``def`` (a
      Python keyword), so access is ``getattr(stmt, "def")``.
    
    * ``stmt_functor.py`` — wire ``"tirx.ScopeIdDefStmt"`` into the
      ``StmtFunctor._dispatch_map`` and add ``visit_scope_id_def_stmt_``
      with the same shape as the existing ``visit_*_`` methods:
        - Abstract on ``StmtFunctor`` (raises via ``visit_stmt_default_``).
        - Concrete on ``StmtVisitor`` — walk extents and preferred_extents
          via ``visit_expr`` (mirrors the C++ visitor).
        - Concrete on ``StmtMutator`` — walk extents and preferred_extents
          via ``visit_expr``; if any changed, rebuild ``ScopeIdDef`` (using
          the new ``_SCOPE_BINDING_TO_PARENT_CUR`` reverse map) and wrap
          in a fresh ``ScopeIdDefStmt`` (mirrors the C++ mutator).
    
    * ``exec_scope.py`` — add ``_SCOPE_BINDING_TO_PARENT_CUR`` mirroring
      the C++ ``ScopeBinding`` enum, so the Python mutator can rebuild a
      ``ScopeIdDef`` from an existing one (whose ``scope`` field is the
      int form).
---
 python/tvm/tirx/__init__.py     |  2 +-
 python/tvm/tirx/exec_scope.py   | 19 ++++++++++++++
 python/tvm/tirx/stmt.py         | 35 +++++++++++++++++++++++++-
 python/tvm/tirx/stmt_functor.py | 55 +++++++++++++++++++++++++++++++++++++++++
 4 files changed, 109 insertions(+), 2 deletions(-)

diff --git a/python/tvm/tirx/__init__.py b/python/tvm/tirx/__init__.py
index 00a3522238..10de65a564 100644
--- a/python/tvm/tirx/__init__.py
+++ b/python/tvm/tirx/__init__.py
@@ -44,7 +44,7 @@ from .stmt import BufferStore, AllocBuffer, AttrStmt, DeclBuffer
 from .stmt import SeqStmt
 from .stmt import IfThenElse, Evaluate, stmt_seq, stmt_list
 from .stmt import BufferRegion, MatchBufferRegion, SBlock, SBlockRealize
-from .stmt import TilePrimitiveCall, ExecScopeStmt
+from .stmt import TilePrimitiveCall, ExecScopeStmt, ScopeIdDefStmt
 
 from .function import PrimFunc, TensorIntrin, IndexMap
 
diff --git a/python/tvm/tirx/exec_scope.py b/python/tvm/tirx/exec_scope.py
index 9a6e00eb6c..e63d6830df 100644
--- a/python/tvm/tirx/exec_scope.py
+++ b/python/tvm/tirx/exec_scope.py
@@ -65,6 +65,25 @@ _SCOPE_KIND_TO_NAME = {
 }
 
 
+# Mirror of ``enum class ScopeBinding`` in tvm/tirx/exec_scope.h. Maps the
+# ``int`` value of ``ScopeIdDef.scope`` back to the ``(parent, cur)`` pair
+# that ``ScopeIdDef.__init__`` accepts — needed when Python code wants to
+# rebuild a ``ScopeIdDef`` from an existing one (e.g. a StmtMutator
+# walking and rewriting extents).
+_SCOPE_BINDING_TO_PARENT_CUR = {
+    0: ("kernel", "cluster"),
+    1: ("kernel", "cta"),
+    2: ("cluster", "cta"),
+    3: ("cta", "warpgroup"),
+    4: ("cta", "warp"),
+    5: ("warpgroup", "warp"),
+    6: ("warp", "thread"),
+    7: ("cta", "thread"),
+    8: ("warpgroup", "thread"),
+    9: ("cluster", "cta_pair"),
+}
+
+
 @register_object("tirx.ExecScope")
 class ExecScope(Object):
     """An execution scope, identified by one of {cluster, cta, warpgroup, warp,
diff --git a/python/tvm/tirx/stmt.py b/python/tvm/tirx/stmt.py
index f1072bf25a..4972c71518 100644
--- a/python/tvm/tirx/stmt.py
+++ b/python/tvm/tirx/stmt.py
@@ -39,7 +39,7 @@ from tvm.tirx import FloatImm
 
 from . import _ffi_api
 from .buffer import Buffer
-from .exec_scope import ExecScope
+from .exec_scope import ExecScope, ScopeIdDef
 from .expr import IterVar, StringImm, Var
 
 if TYPE_CHECKING:
@@ -848,6 +848,39 @@ class ExecScopeStmt(Stmt):
         )  # type: ignore
 
 
+@tvm_ffi.register_object("tirx.ScopeIdDefStmt")
+class ScopeIdDefStmt(Stmt):
+    """ScopeIdDefStmt node.
+
+    Leaf statement that introduces scope-identifier vars
+    (``wg_id = Tx.warpgroup_id([N])``, ``warp_id = Tx.warp_id_in_wg([4])``,
+    ``lane_id = Tx.lane_id([32])``, …) at the kernel-body top level. The
+    underlying ``ScopeIdDef`` carries the def vars, their extents, and
+    the parent/child scope binding.
+
+    Note: the C++ field is named ``def`` (a Python keyword). Access it
+    via ``getattr(stmt, "def")`` or ``stmt.__getattribute__("def")`` —
+    the constructor parameter is spelled ``def_`` only to avoid the keyword.
+
+    Parameters
+    ----------
+    def_ : ScopeIdDef
+        The scope-id definition (def vars, extents, scope binding).
+
+    span : Optional[Span]
+        The location of this statement in the source code.
+    """
+
+    span: Span | None
+
+    def __init__(self, def_: ScopeIdDef, span: Span | None = None) -> None:
+        self.__init_handle_by_constructor__(
+            _ffi_api.ScopeIdDefStmt,  # type: ignore
+            def_,
+            span,
+        )  # type: ignore
+
+
 @tvm_ffi.register_object("tirx.Break")
 class Break(Stmt):
     """Break node.
diff --git a/python/tvm/tirx/stmt_functor.py b/python/tvm/tirx/stmt_functor.py
index 65c08921b9..c67032d4b0 100644
--- a/python/tvm/tirx/stmt_functor.py
+++ b/python/tvm/tirx/stmt_functor.py
@@ -54,6 +54,7 @@ class StmtFunctor:
             "tirx.SBlock": self.visit_block_,
             "tirx.SBlockRealize": self.visit_block_realize_,
             "tirx.ExecScopeStmt": self.visit_exec_scope_stmt_,
+            "tirx.ScopeIdDefStmt": self.visit_scope_id_def_stmt_,
             "tirx.TilePrimitiveCall": self.visit_op_call_,
             "tirx.AllocBuffer": self.visit_alloc_buffer_,
         }
@@ -176,6 +177,10 @@ class StmtFunctor:
         """Visitor for ExecScopeStmt nodes."""
         return self.visit_stmt_default_(op)
 
+    def visit_scope_id_def_stmt_(self, op):
+        """Visitor for ScopeIdDefStmt nodes."""
+        return self.visit_stmt_default_(op)
+
     def visit_op_call_(self, op):
         """Visitor for TilePrimitiveCall nodes."""
         return self.visit_stmt_default_(op)
@@ -338,6 +343,23 @@ class StmtVisitor(StmtFunctor):
         """Visitor implementation for ExecScopeStmt."""
         self.visit_stmt(op.body)
 
+    def visit_scope_id_def_stmt_(self, op):
+        """Visitor implementation for ScopeIdDefStmt.
+
+        Mirrors the C++ visitor: walk extents and preferred_extents via
+        ``visit_expr``; there is no body to recurse into (the def vars
+        themselves are leaves the visitor doesn't otherwise inspect).
+        """
+        # The C++ field is named ``def``, which is a Python keyword,
+        # so it's accessed via ``getattr``.
+        sid = getattr(op, "def")
+        if sid.extents is not None:
+            for e in sid.extents:
+                self.visit_expr(e)
+        if sid.preferred_extents is not None:
+            for e in sid.preferred_extents:
+                self.visit_expr(e)
+
     def visit_op_call_(self, op):
         """Visitor implementation for TilePrimitiveCall."""
         for arg in op.args:
@@ -781,6 +803,39 @@ class StmtMutator(StmtFunctor):
 
         return tvm.tirx.ExecScopeStmt(op.exec_scope, body, op.span)
 
+    def visit_scope_id_def_stmt_(self, op):
+        """Mutator implementation for ScopeIdDefStmt.
+
+        Mirrors the C++ mutator: rewrite ``extents`` and ``preferred_extents``
+        via ``visit_expr``, detecting changes with ``same_as``. Deferred-extent
+        defs (extents is None) and unchanged extents pass through.
+        """
+        from .exec_scope import _SCOPE_BINDING_TO_PARENT_CUR, ScopeIdDef
+
+        # ``def`` is a Python keyword; access the C++ field via ``getattr``.
+        sid = getattr(op, "def")
+        changed = False
+
+        def _walk(arr):
+            nonlocal changed
+            if arr is None:
+                return None
+            out = []
+            for e in arr:
+                ne = self.visit_expr(e)
+                if not ne.same_as(e):  # node identity, not Python-wrapper identity
+                    changed = True
+                out.append(ne)
+            return out
+
+        new_extents = _walk(sid.extents)
+        new_pref = _walk(sid.preferred_extents)
+        if not changed:
+            return op
+        parent, cur = _SCOPE_BINDING_TO_PARENT_CUR[sid.scope]
+        new_def = ScopeIdDef(sid.def_ids, new_extents, parent, cur, new_pref)
+        return tvm.tirx.ScopeIdDefStmt(new_def, op.span)
+
     def visit_op_call_(self, op):
         """Mutator implementation for TilePrimitiveCall."""
         new_args = []

Reply via email to