This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 95a4b1a819 [IR][TIR] Remove body from AssertStmt (#18832)
95a4b1a819 is described below

commit 95a4b1a8193c28a62f8195be6a13539e7f701a4d
Author: Tianqi Chen <[email protected]>
AuthorDate: Fri Feb 27 19:34:38 2026 -0500

    [IR][TIR] Remove body from AssertStmt (#18832)
    
    This PR removes the body field from AssertStmt, making it a leaf
    statement. Constraints from AssertStmt are now tracked via
    WithGroup<ConstraintContext> in a ScopeStack, providing clean RAII-based
    scope management.
    
    New utilities:
    - WithGroup<T>: manages a dynamic group of With<T> RAII contexts
    - ScopeStack<T>: scope stack for hierarchical state during IR visiting
---
 include/tvm/arith/analyzer.h                       |   2 +-
 include/tvm/ir/scope_stack.h                       | 123 +++++++++++++++++
 include/tvm/support/with.h                         |  74 +++++++++++
 include/tvm/tir/stmt.h                             |   2 +-
 python/tvm/tir/stmt.py                             |   9 +-
 src/arith/ir_mutator_with_analyzer.cc              | 146 ++++++++++++---------
 src/arith/ir_mutator_with_analyzer.h               |   5 +
 src/arith/ir_visitor_with_analyzer.cc              |  79 ++++++-----
 src/arith/ir_visitor_with_analyzer.h               |   6 +
 src/relax/op/tensor/inspect.cc                     |  11 +-
 src/s_tir/analysis/estimate_flops.cc               |   1 -
 src/s_tir/transform/bound_checker.cc               |   3 +-
 src/script/ir_builder/tir/frame.cc                 |  11 +-
 src/script/printer/tir/stmt.cc                     |  22 ++--
 src/target/llvm/codegen_llvm.cc                    |   5 +-
 src/target/source/codegen_c.cc                     |   1 -
 src/target/source/codegen_c_host.cc                |   1 -
 src/target/source/codegen_webgpu.cc                |   3 +-
 src/target/spirv/codegen_spirv.cc                  |   3 +-
 src/tir/ir/stmt.cc                                 |  10 +-
 src/tir/ir/stmt_functor.cc                         |   5 +-
 src/tir/ir/tir_visitor_with_path.cc                |   1 -
 src/tir/transform/arg_binder.cc                    |  16 +--
 src/tir/transform/ir_utils.cc                      |   7 +-
 src/tir/transform/make_packed_api.cc               |  12 +-
 src/tir/transform/skip_assert.cc                   |   5 +-
 src/tir/transform/split_host_device.cc             |   2 +-
 tests/python/tir-base/test_tir_constructor.py      |   3 +-
 .../tvmscript/test_tvmscript_ir_builder_tir.py     |   9 +-
 .../python/tvmscript/test_tvmscript_printer_tir.py |   4 +-
 30 files changed, 403 insertions(+), 178 deletions(-)

diff --git a/include/tvm/arith/analyzer.h b/include/tvm/arith/analyzer.h
index 85d814f5b2..b77f2ee5db 100644
--- a/include/tvm/arith/analyzer.h
+++ b/include/tvm/arith/analyzer.h
@@ -563,7 +563,7 @@ class ConstraintContext {
   Analyzer* analyzer_;
   /*! \brief The constraint */
   PrimExpr constraint_;
-  /*! \brief function to be called in recovery */
+  /*! \brief functions to be called in recovery */
   std::vector<std::function<void()>> recovery_functions_;
 };
 
diff --git a/include/tvm/ir/scope_stack.h b/include/tvm/ir/scope_stack.h
new file mode 100644
index 0000000000..694d35e19e
--- /dev/null
+++ b/include/tvm/ir/scope_stack.h
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/ir/scope_stack.h
+ * \brief A generic scope stack for managing hierarchical state during IR visiting.
+ */
+#ifndef TVM_IR_SCOPE_STACK_H_
+#define TVM_IR_SCOPE_STACK_H_
+
+#include <tvm/ffi/error.h>
+
+#include <deque>
+#include <type_traits>
+
+namespace tvm {
+
+/*!
+ * \brief A scope stack for maintaining hierarchical state during IR visiting.
+ *
+ * During IR tree traversal, visitors often need to track scope-local state
+ * (e.g., active constraints, variable bindings) that should be automatically
+ * cleaned up when leaving a scope. ScopeStack provides this via WithNewScope,
+ * which pushes a new element on entry and pops it on exit.
+ *
+ * \code
+ *   ScopeStack<WithGroup<ConstraintContext>> constraints;
+ *
+ *   // In VisitStmt_(ForNode):
+ *   return constraints.WithNewScope([&]() -> Stmt {
+ *     constraints.Current().Emplace(&analyzer, condition);
+ *     return StmtExprMutator::VisitStmt_(op);
+ *   });
+ * \endcode
+ *
+ * \tparam T The element type stored on the stack. Must be default-constructible.
+ */
+template <typename T>
+class ScopeStack {
+ public:
+  /*! \brief Construct with one initial scope level. */
+  ScopeStack() { stack_.emplace_back(); }
+
+  /*! \brief Return the number of active scopes. */
+  size_t size() const { return stack_.size(); }
+
+  /*! \brief Return true if no scopes are active. */
+  bool empty() const { return stack_.empty(); }
+
+  /*!
+   * \brief Access the current (innermost) scope element.
+   *
+   * The returned reference is stable across push_back/pop_back because
+   * std::deque guarantees pointer stability for these operations.
+   *
+   * \return Mutable reference to the top element.
+   */
+  T& Current() {
+    TVM_FFI_ICHECK(!stack_.empty());
+    return stack_.back();
+  }
+
+  /*! \brief Const access to the current (innermost) scope element. */
+  const T& Current() const {
+    TVM_FFI_ICHECK(!stack_.empty());
+    return stack_.back();
+  }
+
+  /*!
+   * \brief Execute body within a new scope.
+   *
+   * Pushes a new T onto the stack, executes the body, then pops it.
+   *
+   * \param body A callable to execute within the scope.
+   * \return The return value of body(), if non-void.
+   */
+  template <typename F>
+  auto WithNewScope(F&& body) -> decltype(body()) {
+    stack_.emplace_back();
+    struct Guard {
+      std::deque<T>* stack;
+      ~Guard() noexcept(false) { stack->pop_back(); }
+    } guard{&stack_};
+    if constexpr (std::is_void_v<decltype(body())>) {
+      body();
+    } else {
+      return body();
+    }
+  }
+
+ private:
+  /*!
+   * \brief The scope stack.
+   *
+   * We use std::deque rather than std::vector for pointer stability:
+   * references returned by Current() remain valid across push/pop operations.
+   * This is critical because methods called on Current() (e.g., Emplace on
+   * a WithGroup) may trigger re-entrant code that pushes new scopes onto
+   * the same stack. With std::vector the internal buffer reallocation would
+   * invalidate the reference, causing use-after-free.
+   */
+  std::deque<T> stack_;
+};
+
+}  // namespace tvm
+
+#endif  // TVM_IR_SCOPE_STACK_H_
diff --git a/include/tvm/support/with.h b/include/tvm/support/with.h
index 8d4d48038d..8cf2823e06 100644
--- a/include/tvm/support/with.h
+++ b/include/tvm/support/with.h
@@ -25,8 +25,11 @@
 #ifndef TVM_SUPPORT_WITH_H_
 #define TVM_SUPPORT_WITH_H_
 
+#include <exception>
 #include <functional>
+#include <memory>
 #include <utility>
+#include <vector>
 
 namespace tvm {
 
@@ -90,5 +93,76 @@ class With {
   ContextType ctx_;
 };
 
+/*!
+ * \brief A group of RAII contexts managed together.
+ *
+ * Allows dynamically emplacing multiple context objects that are
+ * all exited (in reverse order) when the group is destroyed.
+ * ContextType must declare `friend class With<ContextType>`
+ * and provide EnterWithScope() / ExitWithScope() methods.
+ *
+ * \code
+ *   WithGroup<ConstraintContext> group;
+ *   group.Emplace(&analyzer, cond1);  // constructs and enters
+ *   group.Emplace(&analyzer, cond2);  // constructs and enters
+ *   // destructor: exits cond2, then cond1
+ * \endcode
+ *
+ * \tparam ContextType The context type with EnterWithScope/ExitWithScope.
+ */
+template <typename ContextType>
+class WithGroup {
+ public:
+  WithGroup() = default;
+  WithGroup(WithGroup&&) = default;
+  WithGroup& operator=(WithGroup&&) = default;
+  WithGroup(const WithGroup&) = delete;
+  WithGroup& operator=(const WithGroup&) = delete;
+
+  /*!
+   * \brief Construct a context and enter its scope.
+   * \param args Arguments forwarded to ContextType constructor.
+   */
+  template <typename... Args>
+  void Emplace(Args&&... args) {
+    
entries_.push_back(std::make_unique<With<ContextType>>(std::forward<Args>(args)...));
+  }
+
+  /*! \brief Number of active contexts in this group. */
+  size_t size() const { return entries_.size(); }
+
+  /*!
+   * \brief Destructor — exits all contexts in reverse order.
+   *
+   * On normal exit: if any ExitWithScope throws, the remaining
+   * contexts are still cleaned up, then the first exception
+   * is re-thrown.
+   *
+   * During stack unwinding: all exceptions are swallowed
+   * to avoid std::terminate.
+   */
+  ~WithGroup() noexcept(false) {
+    bool unwinding = std::uncaught_exceptions() > 0;
+    std::exception_ptr first_exc;
+    while (!entries_.empty()) {
+      // Move the last entry out of the vector first, then destroy it.
+      // This ensures entries_ shrinks even if ~With() throws.
+      auto entry = std::move(entries_.back());
+      entries_.pop_back();
+      try {
+        entry.reset();  // calls ~With<ContextType>() -> ExitWithScope()
+      } catch (...) {
+        if (!unwinding && !first_exc) {
+          first_exc = std::current_exception();
+        }
+      }
+    }
+    if (first_exc) std::rethrow_exception(first_exc);
+  }
+
+ private:
+  std::vector<std::unique_ptr<With<ContextType>>> entries_;
+};
+
 }  // namespace tvm
 #endif  // TVM_SUPPORT_WITH_H_
diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h
index b41c92d66a..c12e14e0c0 100644
--- a/include/tvm/tir/stmt.h
+++ b/include/tvm/tir/stmt.h
@@ -176,7 +176,7 @@ class AssertStmtNode : public StmtNode {
  */
 class AssertStmt : public Stmt {
  public:
-  TVM_DLL AssertStmt(PrimExpr condition, PrimExpr message, Stmt body, Span 
span = Span());
+  TVM_DLL AssertStmt(PrimExpr condition, PrimExpr message, Span span = Span());
 
   TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(AssertStmt, Stmt, AssertStmtNode);
   TVM_DEFINE_OBJECT_REF_COW_METHOD(AssertStmtNode);
diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py
index 6abd96a360..293ebed6a7 100644
--- a/python/tvm/tir/stmt.py
+++ b/python/tvm/tir/stmt.py
@@ -90,26 +90,19 @@ class AssertStmt(Stmt):
     message : PrimExpr
         The error message.
 
-    body : tvm.tir.Stmt
-        The body statement.
-
     span : Optional[Span]
         The location of the stmt in the source code.
     """
 
     condition: PrimExpr
     message: PrimExpr
-    body: Stmt
     span: Optional[Span]
 
-    def __init__(
-        self, condition: PrimExpr, message: PrimExpr, body: Stmt, span: 
Optional[Span] = None
-    ) -> None:
+    def __init__(self, condition: PrimExpr, message: PrimExpr, span: 
Optional[Span] = None) -> None:
         self.__init_handle_by_constructor__(
             _ffi_api.AssertStmt,
             condition,
             message,
-            body,
             span,  # type: ignore
         )
 
diff --git a/src/arith/ir_mutator_with_analyzer.cc 
b/src/arith/ir_mutator_with_analyzer.cc
index e57e4c0d5a..96a64c5e47 100644
--- a/src/arith/ir_mutator_with_analyzer.cc
+++ b/src/arith/ir_mutator_with_analyzer.cc
@@ -60,19 +60,23 @@ ffi::Array<PrimExpr> 
IRMutatorWithAnalyzer::IterMapSimplifyWithContext(
 }
 
 Stmt IRMutatorWithAnalyzer::VisitStmt_(const ForNode* op) {
-  // record the loop variable as iterators
-  Range dom = Range::FromMinExtent(op->min, op->extent);
-  analyzer_->Bind(op->loop_var, dom);
-  iter_vars_.Set(op->loop_var, dom);
-  return StmtExprMutator::VisitStmt_(op);
+  return constraint_scope_.WithNewScope([&]() -> Stmt {
+    // record the loop variable as iterators
+    Range dom = Range::FromMinExtent(op->min, op->extent);
+    analyzer_->Bind(op->loop_var, dom);
+    iter_vars_.Set(op->loop_var, dom);
+    return StmtExprMutator::VisitStmt_(op);
+  });
 }
 
 Stmt IRMutatorWithAnalyzer::VisitStmt_(const SBlockNode* op) {
-  for (const auto& iter_var : op->iter_vars) {
-    analyzer_->Bind(iter_var->var, iter_var->dom);
-    iter_vars_.Set(iter_var->var, iter_var->dom);
-  }
-  return StmtExprMutator::VisitStmt_(op);
+  return constraint_scope_.WithNewScope([&]() -> Stmt {
+    for (const auto& iter_var : op->iter_vars) {
+      analyzer_->Bind(iter_var->var, iter_var->dom);
+      iter_vars_.Set(iter_var->var, iter_var->dom);
+    }
+    return StmtExprMutator::VisitStmt_(op);
+  });
 }
 
 Stmt IRMutatorWithAnalyzer::VisitStmt_(const LetStmtNode* op) {
@@ -94,88 +98,97 @@ Stmt IRMutatorWithAnalyzer::VisitStmt_(const LetStmtNode* 
op) {
 }
 
 Stmt IRMutatorWithAnalyzer::VisitStmt_(const IfThenElseNode* op) {
-  PrimExpr condition = this->VisitExpr(op->condition);
-  PrimExpr real_condition = condition;
-  static auto op_likely = Op::Get("tir.likely");
+  return constraint_scope_.WithNewScope([&]() -> Stmt {
+    PrimExpr condition = this->VisitExpr(op->condition);
+    PrimExpr real_condition = condition;
+    static auto op_likely = Op::Get("tir.likely");
 
-  if (auto call = condition.as<CallNode>()) {
-    if (call->op.same_as(op_likely)) {
-      real_condition = call->args[0];
+    if (auto call = condition.as<CallNode>()) {
+      if (call->op.same_as(op_likely)) {
+        real_condition = call->args[0];
+      }
     }
-  }
 
-  Stmt then_case;
-  ffi::Optional<Stmt> else_case;
-  {
-    With<ConstraintContext> ctx(analyzer_, real_condition);
-    WithRecordIterPredicate(real_condition, [&] { then_case = 
this->VisitStmt(op->then_case); });
-  }
-  if (op->else_case) {
-    With<ConstraintContext> ctx(analyzer_, 
analyzer_->rewrite_simplify(Not(real_condition)));
-    else_case = this->VisitStmt(op->else_case.value());
-  }
-  if (is_one(real_condition)) return then_case;
-  if (is_zero(real_condition)) {
-    return else_case.value_or(Evaluate(0));
-  }
+    Stmt then_case;
+    ffi::Optional<Stmt> else_case;
+    constraint_scope_.WithNewScope([&]() {
+      constraint_scope_.Current().Emplace(analyzer_, real_condition);
+      WithRecordIterPredicate(real_condition, [&] { then_case = 
this->VisitStmt(op->then_case); });
+    });
+    if (op->else_case) {
+      PrimExpr neg_condition = 
analyzer_->rewrite_simplify(Not(real_condition));
+      constraint_scope_.WithNewScope([&]() {
+        constraint_scope_.Current().Emplace(analyzer_, neg_condition);
+        else_case = this->VisitStmt(op->else_case.value());
+      });
+    }
+    if (is_one(real_condition)) return then_case;
+    if (is_zero(real_condition)) {
+      return else_case.value_or(Evaluate(0));
+    }
 
-  if (condition.same_as(op->condition) && then_case.same_as(op->then_case) &&
-      else_case.same_as(op->else_case)) {
-    return ffi::GetRef<Stmt>(op);
-  } else {
-    auto n = this->CopyOnWrite(op);
-    n->condition = std::move(condition);
-    n->then_case = std::move(then_case);
-    n->else_case = std::move(else_case);
-    return Stmt(n);
-  }
+    if (condition.same_as(op->condition) && then_case.same_as(op->then_case) &&
+        else_case.same_as(op->else_case)) {
+      return ffi::GetRef<Stmt>(op);
+    } else {
+      auto n = this->CopyOnWrite(op);
+      n->condition = std::move(condition);
+      n->then_case = std::move(then_case);
+      n->else_case = std::move(else_case);
+      return Stmt(n);
+    }
+  });
 }
 
 Stmt IRMutatorWithAnalyzer::VisitStmt_(const AttrStmtNode* op) {
-  if (op->attr_key == tir::attr::thread_extent || op->attr_key == 
tir::attr::virtual_thread) {
-    IterVar iv = Downcast<IterVar>(op->node);
-    TVM_FFI_ICHECK_NE(iv->thread_tag.length(), 0U);
-    Range dom = Range::FromMinExtent(make_zero(op->value.dtype()), op->value);
-    analyzer_->Bind(iv->var, dom);
-    iter_vars_.Set(iv->var, dom);
-    Stmt stmt = StmtExprMutator::VisitStmt_(op);
-    return stmt;
-  } else {
+  return constraint_scope_.WithNewScope([&]() -> Stmt {
+    if (op->attr_key == tir::attr::thread_extent || op->attr_key == 
tir::attr::virtual_thread) {
+      IterVar iv = Downcast<IterVar>(op->node);
+      TVM_FFI_ICHECK_NE(iv->thread_tag.length(), 0U);
+      Range dom = Range::FromMinExtent(make_zero(op->value.dtype()), 
op->value);
+      analyzer_->Bind(iv->var, dom);
+      iter_vars_.Set(iv->var, dom);
+    }
     return StmtExprMutator::VisitStmt_(op);
-  }
+  });
 }
 
 Stmt IRMutatorWithAnalyzer::VisitStmt_(const AssertStmtNode* op) {
   PrimExpr condition = this->VisitExpr(op->condition);
   PrimExpr message = this->VisitExpr(op->message);
-  With<ConstraintContext> ctx(analyzer_, condition);
-  Stmt body = this->VisitStmt(op->body);
+  constraint_scope_.Current().Emplace(analyzer_, condition);
 
-  if (condition.same_as(op->condition) && message.same_as(op->message) && 
body.same_as(op->body)) {
+  if (condition.same_as(op->condition) && message.same_as(op->message)) {
     return ffi::GetRef<Stmt>(op);
   } else {
     auto n = this->CopyOnWrite(op);
     n->condition = std::move(condition);
     n->message = std::move(message);
-    n->body = std::move(body);
     return Stmt(n);
   }
 }
 
+Stmt IRMutatorWithAnalyzer::VisitStmt_(const SeqStmtNode* op) {
+  // SeqStmt does NOT get WithNewScope — constraints accumulate across siblings.
+  return StmtExprMutator::VisitStmt_(op);
+}
+
 PrimExpr IRMutatorWithAnalyzer::VisitExpr_(const CallNode* op) {
   // add condition context to if_then_else
   static auto op_if_then_else = Op::Get("tir.if_then_else");
   if (op->op.same_as(op_if_then_else)) {
     PrimExpr cond = this->VisitExpr(op->args[0]);
     PrimExpr true_value, false_value;
-    {
-      With<ConstraintContext> constraint(analyzer_, cond);
+    constraint_scope_.WithNewScope([&]() {
+      constraint_scope_.Current().Emplace(analyzer_, cond);
       WithRecordIterPredicate(cond, [&] { true_value = 
this->VisitExpr(op->args[1]); });
-    }
+    });
     {
       PrimExpr not_cond = Not(cond);
-      With<ConstraintContext> constraint(analyzer_, not_cond);
-      WithRecordIterPredicate(not_cond, [&] { false_value = 
this->VisitExpr(op->args[2]); });
+      constraint_scope_.WithNewScope([&]() {
+        constraint_scope_.Current().Emplace(analyzer_, not_cond);
+        WithRecordIterPredicate(not_cond, [&] { false_value = 
this->VisitExpr(op->args[2]); });
+      });
     }
     if (is_zero(cond)) {
       return false_value;
@@ -211,13 +224,16 @@ PrimExpr IRMutatorWithAnalyzer::VisitExpr_(const LetNode* 
op) {
 PrimExpr IRMutatorWithAnalyzer::VisitExpr_(const SelectNode* op) {
   PrimExpr cond = this->VisitExpr(op->condition);
   PrimExpr true_value, false_value;
-  {
-    With<ConstraintContext> constraint(analyzer_, cond);
+  constraint_scope_.WithNewScope([&]() {
+    constraint_scope_.Current().Emplace(analyzer_, cond);
     true_value = VisitExpr(op->true_value);
-  }
+  });
   {
-    With<ConstraintContext> constraint(analyzer_, 
analyzer_->rewrite_simplify(Not(cond)));
-    false_value = VisitExpr(op->false_value);
+    PrimExpr neg_cond = analyzer_->rewrite_simplify(Not(cond));
+    constraint_scope_.WithNewScope([&]() {
+      constraint_scope_.Current().Emplace(analyzer_, neg_cond);
+      false_value = VisitExpr(op->false_value);
+    });
   }
   if (is_zero(cond)) {
     return false_value;
diff --git a/src/arith/ir_mutator_with_analyzer.h 
b/src/arith/ir_mutator_with_analyzer.h
index 5b5ac7e6cd..8810a8f78f 100644
--- a/src/arith/ir_mutator_with_analyzer.h
+++ b/src/arith/ir_mutator_with_analyzer.h
@@ -25,6 +25,8 @@
 #define TVM_ARITH_IR_MUTATOR_WITH_ANALYZER_H_
 
 #include <tvm/arith/analyzer.h>
+#include <tvm/ir/scope_stack.h>
+#include <tvm/support/with.h>
 #include <tvm/tir/analysis.h>
 #include <tvm/tir/stmt_functor.h>
 
@@ -56,6 +58,7 @@ class IRMutatorWithAnalyzer : public tir::StmtExprMutator {
   tir::Stmt VisitStmt_(const tir::IfThenElseNode* op) override;
   tir::Stmt VisitStmt_(const tir::AttrStmtNode* op) override;
   tir::Stmt VisitStmt_(const tir::AssertStmtNode* op) override;
+  tir::Stmt VisitStmt_(const tir::SeqStmtNode* op) override;
   PrimExpr VisitExpr_(const tir::LetNode* op) override;
   PrimExpr VisitExpr_(const tir::SelectNode* op) override;
   PrimExpr VisitExpr_(const tir::CallNode* op) override;
@@ -79,6 +82,8 @@ class IRMutatorWithAnalyzer : public tir::StmtExprMutator {
 
   /*! \brief internal analyzer field. */
   Analyzer* analyzer_;
+  /*! \brief Scope stack for accumulated assert constraints. */
+  ScopeStack<WithGroup<ConstraintContext>> constraint_scope_;
   // the following two fields are useful in case we want
   // note however that iter map analysis are usually more
   // expensive and we only encourage doing them during
diff --git a/src/arith/ir_visitor_with_analyzer.cc 
b/src/arith/ir_visitor_with_analyzer.cc
index fada12e9c4..02c194b14e 100644
--- a/src/arith/ir_visitor_with_analyzer.cc
+++ b/src/arith/ir_visitor_with_analyzer.cc
@@ -32,15 +32,19 @@ namespace arith {
 using namespace tir;
 
 void IRVisitorWithAnalyzer::VisitStmt_(const ForNode* op) {
-  analyzer_.Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent));
-  StmtExprVisitor::VisitStmt_(op);
+  constraint_scope_.WithNewScope([&]() {
+    analyzer_.Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent));
+    StmtExprVisitor::VisitStmt_(op);
+  });
 }
 
 void IRVisitorWithAnalyzer::VisitStmt_(const SBlockNode* op) {
-  for (const auto& iter_var : op->iter_vars) {
-    analyzer_.Bind(iter_var->var, iter_var->dom);
-  }
-  StmtExprVisitor::VisitStmt_(op);
+  constraint_scope_.WithNewScope([&]() {
+    for (const auto& iter_var : op->iter_vars) {
+      analyzer_.Bind(iter_var->var, iter_var->dom);
+    }
+    StmtExprVisitor::VisitStmt_(op);
+  });
 }
 
 void IRVisitorWithAnalyzer::VisitStmt_(const LetStmtNode* op) {
@@ -50,34 +54,45 @@ void IRVisitorWithAnalyzer::VisitStmt_(const LetStmtNode* 
op) {
 }
 
 void IRVisitorWithAnalyzer::VisitStmt_(const IfThenElseNode* op) {
-  this->VisitExpr(op->condition);
-
-  PrimExpr real_condition = ExtractRealCondition(op->condition);
-
-  {
-    With<ConstraintContext> constraint(&analyzer_, real_condition);
-    this->VisitStmt(op->then_case);
-  }
-  if (op->else_case) {
-    With<ConstraintContext> constraint(&analyzer_, 
analyzer_.rewrite_simplify(Not(real_condition)));
-    this->VisitStmt(op->else_case.value());
-  }
+  constraint_scope_.WithNewScope([&]() {
+    this->VisitExpr(op->condition);
+
+    PrimExpr real_condition = ExtractRealCondition(op->condition);
+
+    constraint_scope_.WithNewScope([&]() {
+      constraint_scope_.Current().Emplace(&analyzer_, real_condition);
+      this->VisitStmt(op->then_case);
+    });
+    if (op->else_case) {
+      constraint_scope_.WithNewScope([&]() {
+        constraint_scope_.Current().Emplace(&analyzer_,
+                                            
analyzer_.rewrite_simplify(Not(real_condition)));
+        this->VisitStmt(op->else_case.value());
+      });
+    }
+  });
 }
 
 void IRVisitorWithAnalyzer::VisitStmt_(const AttrStmtNode* op) {
-  if (op->attr_key == tir::attr::thread_extent || op->attr_key == 
tir::attr::virtual_thread) {
-    IterVar iv = Downcast<IterVar>(op->node);
-    TVM_FFI_ICHECK_NE(iv->thread_tag.length(), 0U);
-    analyzer_.Bind(iv->var, Range::FromMinExtent(IntImm(op->value->dtype, 0), 
op->value));
-  }
-  StmtExprVisitor::VisitStmt_(op);
+  constraint_scope_.WithNewScope([&]() {
+    if (op->attr_key == tir::attr::thread_extent || op->attr_key == 
tir::attr::virtual_thread) {
+      IterVar iv = Downcast<IterVar>(op->node);
+      TVM_FFI_ICHECK_NE(iv->thread_tag.length(), 0U);
+      analyzer_.Bind(iv->var, Range::FromMinExtent(IntImm(op->value->dtype, 
0), op->value));
+    }
+    StmtExprVisitor::VisitStmt_(op);
+  });
 }
 
 void IRVisitorWithAnalyzer::VisitStmt_(const AssertStmtNode* op) {
   this->VisitExpr(op->condition);
   this->VisitExpr(op->message);
-  With<ConstraintContext> constraint(&analyzer_, op->condition);
-  this->VisitStmt(op->body);
+  constraint_scope_.Current().Emplace(&analyzer_, op->condition);
+}
+
+void IRVisitorWithAnalyzer::VisitStmt_(const SeqStmtNode* op) {
+  // SeqStmt does NOT get WithNewScope — constraints accumulate across siblings.
+  StmtExprVisitor::VisitStmt_(op);
 }
 
 void IRVisitorWithAnalyzer::VisitExpr_(const CallNode* op) {
@@ -86,14 +101,14 @@ void IRVisitorWithAnalyzer::VisitExpr_(const CallNode* op) 
{
   if (op->op.same_as(op_if_then_else)) {
     PrimExpr cond = op->args[0];
     this->VisitExpr(op->args[0]);
-    {
-      With<ConstraintContext> constraint(&analyzer_, cond);
+    constraint_scope_.WithNewScope([&]() {
+      constraint_scope_.Current().Emplace(&analyzer_, cond);
       this->VisitExpr(op->args[1]);
-    }
-    {
-      With<ConstraintContext> constraint(&analyzer_, 
analyzer_.rewrite_simplify(Not(cond)));
+    });
+    constraint_scope_.WithNewScope([&]() {
+      constraint_scope_.Current().Emplace(&analyzer_, 
analyzer_.rewrite_simplify(Not(cond)));
       this->VisitExpr(op->args[2]);
-    }
+    });
   } else {
     StmtExprVisitor::VisitExpr_(op);
   }
diff --git a/src/arith/ir_visitor_with_analyzer.h 
b/src/arith/ir_visitor_with_analyzer.h
index cd2b9bfdec..f0553a1c42 100644
--- a/src/arith/ir_visitor_with_analyzer.h
+++ b/src/arith/ir_visitor_with_analyzer.h
@@ -26,6 +26,8 @@
 #define TVM_ARITH_IR_VISITOR_WITH_ANALYZER_H_
 
 #include <tvm/arith/analyzer.h>
+#include <tvm/ir/scope_stack.h>
+#include <tvm/support/with.h>
 #include <tvm/tir/expr.h>
 #include <tvm/tir/stmt_functor.h>
 
@@ -45,6 +47,7 @@ class IRVisitorWithAnalyzer : public tir::StmtExprVisitor {
   void VisitStmt_(const tir::IfThenElseNode* op);
   void VisitStmt_(const tir::AttrStmtNode* op);
   void VisitStmt_(const tir::AssertStmtNode* op);
+  void VisitStmt_(const tir::SeqStmtNode* op);
   void VisitExpr_(const tir::CallNode* op);
   void VisitExpr_(const tir::LetNode* op);
   void VisitExpr_(const tir::ReduceNode* op);
@@ -57,6 +60,9 @@ class IRVisitorWithAnalyzer : public tir::StmtExprVisitor {
   /*! \brief internal analyzer field. */
   arith::Analyzer analyzer_;
 
+  /*! \brief Scope stack for accumulated assert constraints. */
+  ScopeStack<WithGroup<ConstraintContext>> constraint_scope_;
+
   /*! \brief Extract a constraint from a conditional statement
    *
    * Intended for preparing argument for use in
diff --git a/src/relax/op/tensor/inspect.cc b/src/relax/op/tensor/inspect.cc
index 2e05ea4a81..bd0f963c49 100644
--- a/src/relax/op/tensor/inspect.cc
+++ b/src/relax/op/tensor/inspect.cc
@@ -315,9 +315,11 @@ Expr LegalizeTensorShape(const BlockBuilder& bb, const 
Call& call) {
                    IntImm(DataType::Int(32), 
tir::builtin::TVMStructFieldKind::kArrShape)}),
         body);
 
-    body = tir::AssertStmt(
-        axis < tvm::cast(axis->dtype, ndim),
-        tir::StringImm("Specified axis may not be larger than the tensor's 
dimensionality"), body);
+    body = tir::SeqStmt(
+        {tir::AssertStmt(
+             axis < tvm::cast(axis->dtype, ndim),
+             tir::StringImm("Specified axis may not be larger than the 
tensor's dimensionality")),
+         body});
 
     body = tir::LetStmt(
         ndim,
@@ -326,7 +328,8 @@ Expr LegalizeTensorShape(const BlockBuilder& bb, const 
Call& call) {
                    IntImm(DataType::Int(32), 
tir::builtin::TVMStructFieldKind::kArrNDim)}),
         body);
 
-    body = tir::AssertStmt(0 <= axis, tir::StringImm("Specified axis may not 
be negative"), body);
+    body = tir::SeqStmt(
+        {tir::AssertStmt(0 <= axis, tir::StringImm("Specified axis may not be 
negative")), body});
 
     DictAttrs attrs({{"tir.is_scheduled", true}, {"tir.is_host", true}});
 
diff --git a/src/s_tir/analysis/estimate_flops.cc 
b/src/s_tir/analysis/estimate_flops.cc
index b11d262281..0f106f5854 100644
--- a/src/s_tir/analysis/estimate_flops.cc
+++ b/src/s_tir/analysis/estimate_flops.cc
@@ -199,7 +199,6 @@ class FlopEstimator : private ExprFunctor<TResult(const 
PrimExpr& n)>,
     if (op->message.defined()) {
       result += VisitExpr(op->message);
     }
-    result += VisitStmt(op->body);
     return result;
   }
 
diff --git a/src/s_tir/transform/bound_checker.cc 
b/src/s_tir/transform/bound_checker.cc
index dbee2effdb..62b76fa806 100644
--- a/src/s_tir/transform/bound_checker.cc
+++ b/src/s_tir/transform/bound_checker.cc
@@ -96,9 +96,8 @@ class BoundChecker : public StmtExprMutator {
     if (store_scope_bound_collector_.size()) {
       PrimExpr condition = MakeCondition();
       if (!condition.as<StringImmNode>()) {
-        Stmt nop = Evaluate(1);
         Stmt then_case = ffi::GetRef<Stmt>(op);
-        Stmt else_case = AssertStmt(condition, StringImm(error_message_), nop);
+        Stmt else_case = AssertStmt(condition, StringImm(error_message_));
         Stmt body = IfThenElse(condition, then_case, else_case);
         return body;
       }
diff --git a/src/script/ir_builder/tir/frame.cc 
b/src/script/ir_builder/tir/frame.cc
index e5008c74ed..e65be1d45f 100644
--- a/src/script/ir_builder/tir/frame.cc
+++ b/src/script/ir_builder/tir/frame.cc
@@ -128,7 +128,16 @@ void ForFrameNode::ExitWithScope() {
 
 void AssertFrameNode::ExitWithScope() {
   TIRFrameNode::ExitWithScope();
-  AddToParent(tvm::tir::AssertStmt(condition, message, AsStmt(stmts)));
+  if (stmts.empty()) {
+    AddToParent(tvm::tir::AssertStmt(condition, message));
+  } else {
+    ffi::Array<tvm::tir::Stmt> seq;
+    seq.push_back(tvm::tir::AssertStmt(condition, message));
+    for (const auto& stmt : stmts) {
+      seq.push_back(stmt);
+    }
+    AddToParent(tvm::tir::SeqStmt(seq));
+  }
 }
 
 void LetFrameNode::ExitWithScope() {
diff --git a/src/script/printer/tir/stmt.cc b/src/script/printer/tir/stmt.cc
index eff58cf411..633704c164 100644
--- a/src/script/printer/tir/stmt.cc
+++ b/src/script/printer/tir/stmt.cc
@@ -131,20 +131,14 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
     });
 
 TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
-    .set_dispatch<tir::AssertStmt>(
-        "", [](tir::AssertStmt stmt, AccessPath p, IRDocsifier d) -> Doc {
-          bool concise = AllowConciseScoping(d, stmt);
-          ExprDoc cond = d->AsDoc<ExprDoc>(stmt->condition, 
p->Attr("condition"));
-          ExprDoc msg = d->AsDoc<ExprDoc>(stmt->message, p->Attr("message"));
-          With<TIRFrame> f(d, stmt);
-          AsDocBody(stmt->body, p->Attr("body"), f->get(), d);
-          if (concise) {
-            ffi::Array<StmtDoc>* stmts = &(*f)->stmts;
-            stmts->insert(stmts->begin(), AssertDoc(cond, msg));
-            return StmtBlockDoc(*stmts);
-          }
-          return ScopeDoc(std::nullopt, TIR(d, "Assert")->Call({cond, msg}), 
(*f)->stmts);
-        });
+    .set_dispatch<tir::AssertStmt>("",
+                                   [](tir::AssertStmt stmt, AccessPath p, 
IRDocsifier d) -> Doc {
+                                     ExprDoc cond =
+                                         d->AsDoc<ExprDoc>(stmt->condition, 
p->Attr("condition"));
+                                     ExprDoc msg =
+                                         d->AsDoc<ExprDoc>(stmt->message, 
p->Attr("message"));
+                                     return AssertDoc(cond, msg);
+                                   });
 
 TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
     .set_dispatch<tir::While>("", [](tir::While stmt, AccessPath p, 
IRDocsifier d) -> Doc {
diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc
index 93f0282015..472567efc7 100644
--- a/src/target/llvm/codegen_llvm.cc
+++ b/src/target/llvm/codegen_llvm.cc
@@ -2165,9 +2165,8 @@ void CodeGenLLVM::VisitStmt_(const AttrStmtNode* op) {
 
 void CodeGenLLVM::VisitStmt_(const AssertStmtNode* op) {
   EmitDebugLocation(op);
-  // auto a_cu =
-  With<arith::ConstraintContext> cctx(analyzer_.get(), op->condition);
-  this->VisitStmt(op->body);
+  // AssertStmt is a leaf — no body to visit.
+  // Constraint scoping is handled by ScopeStack in analysis passes.
 }
 
 void CodeGenLLVM::VisitStmt_(const LetStmtNode* op) {
diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc
index d3e1cee46e..95b1260e45 100644
--- a/src/target/source/codegen_c.cc
+++ b/src/target/source/codegen_c.cc
@@ -1088,7 +1088,6 @@ void CodeGenC::VisitStmt_(const AssertStmtNode* op) {
   } else {
     stream << "assert(" << cond << ");\n";
   }
-  this->PrintStmt(op->body);
 }
 
 void CodeGenC::VisitStmt_(const ForNode* op) {
diff --git a/src/target/source/codegen_c_host.cc 
b/src/target/source/codegen_c_host.cc
index cb6ba238ef..ca6b71b4d8 100644
--- a/src/target/source/codegen_c_host.cc
+++ b/src/target/source/codegen_c_host.cc
@@ -332,7 +332,6 @@ void CodeGenCHost::VisitStmt_(const AssertStmtNode* op) {  
// NOLINT(*)
     PrintIndent();
     stream << "}\n";
   }
-  this->PrintStmt(op->body);
 }
 
 void CodeGenCHost::VisitExpr_(const MinNode* op, std::ostream& os) {  // 
NOLINT(*)
diff --git a/src/target/source/codegen_webgpu.cc 
b/src/target/source/codegen_webgpu.cc
index eb5351bd0f..3a9cceb687 100644
--- a/src/target/source/codegen_webgpu.cc
+++ b/src/target/source/codegen_webgpu.cc
@@ -698,8 +698,7 @@ void CodeGenWebGPU::VisitStmt_(const ForNode* op) {
 }
 
 void CodeGenWebGPU::VisitStmt_(const AssertStmtNode* op) {
-  // skip assert
-  PrintStmt(op->body);
+  // skip assert — AssertStmt is a leaf, nothing to emit.
 }
 
 void CodeGenWebGPU::VisitStmt_(const WhileNode* op) {
diff --git a/src/target/spirv/codegen_spirv.cc 
b/src/target/spirv/codegen_spirv.cc
index ff8053a309..114bdbc806 100644
--- a/src/target/spirv/codegen_spirv.cc
+++ b/src/target/spirv/codegen_spirv.cc
@@ -886,8 +886,7 @@ void CodeGenSPIRV::VisitStmt_(const AttrStmtNode* op) {
 }
 
 void CodeGenSPIRV::VisitStmt_(const AssertStmtNode* op) {
-  With<arith::ConstraintContext> cctx(analyzer_.get(), op->condition);
-  this->VisitStmt(op->body);
+  // AssertStmt is a leaf — no body to visit.
 }
 
 void CodeGenSPIRV::VisitStmt_(const LetStmtNode* op) {
diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc
index 4f0dbaf121..5eaff0c314 100644
--- a/src/tir/ir/stmt.cc
+++ b/src/tir/ir/stmt.cc
@@ -104,7 +104,7 @@ TVM_FFI_STATIC_INIT_BLOCK() {
 }
 
 // AssertStmt
-AssertStmt::AssertStmt(PrimExpr condition, PrimExpr message, Stmt body, Span 
span) {
+AssertStmt::AssertStmt(PrimExpr condition, PrimExpr message, Span span) {
   TVM_FFI_ICHECK(condition.defined());
   TVM_FFI_ICHECK(condition.dtype().is_bool())
       << "AssertStmt should have boolean condition, "
@@ -115,17 +115,15 @@ AssertStmt::AssertStmt(PrimExpr condition, PrimExpr 
message, Stmt body, Span spa
   ObjectPtr<AssertStmtNode> node = ffi::make_object<AssertStmtNode>();
   node->condition = std::move(condition);
   node->message = std::move(message);
-  node->body = std::move(body);
   node->span = std::move(span);
   data_ = std::move(node);
 }
 
 TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
-  refl::GlobalDef().def("tir.AssertStmt",
-                        [](PrimExpr condition, StringImm message, Stmt body, 
Span span) {
-                          return AssertStmt(condition, message, body, span);
-                        });
+  refl::GlobalDef().def("tir.AssertStmt", [](PrimExpr condition, StringImm 
message, Span span) {
+    return AssertStmt(condition, message, span);
+  });
 }
 
 // For
diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc
index a3e59914c1..659056a3a2 100644
--- a/src/tir/ir/stmt_functor.cc
+++ b/src/tir/ir/stmt_functor.cc
@@ -81,7 +81,6 @@ void StmtVisitor::VisitStmt_(const IfThenElseNode* op) {
 void StmtVisitor::VisitStmt_(const AssertStmtNode* op) {
   this->VisitExpr(op->condition);
   this->VisitExpr(op->message);
-  this->VisitStmt(op->body);
 }
 
 void StmtVisitor::VisitStmt_(const SeqStmtNode* op) {
@@ -398,15 +397,13 @@ Stmt StmtMutator::VisitSeqStmt_(const SeqStmtNode* op, 
bool flatten_before_visit
 Stmt StmtMutator::VisitStmt_(const AssertStmtNode* op) {
   PrimExpr condition = this->VisitExpr(op->condition);
   PrimExpr message = this->VisitExpr(op->message);
-  Stmt body = this->VisitStmt(op->body);
 
-  if (condition.same_as(op->condition) && message.same_as(op->message) && 
body.same_as(op->body)) {
+  if (condition.same_as(op->condition) && message.same_as(op->message)) {
     return ffi::GetRef<Stmt>(op);
   } else {
     auto n = CopyOnWrite(op);
     n->condition = std::move(condition);
     n->message = std::move(message);
-    n->body = std::move(body);
     return Stmt(n);
   }
 }
diff --git a/src/tir/ir/tir_visitor_with_path.cc 
b/src/tir/ir/tir_visitor_with_path.cc
index 2e9968790f..0c03a3c2c8 100644
--- a/src/tir/ir/tir_visitor_with_path.cc
+++ b/src/tir/ir/tir_visitor_with_path.cc
@@ -252,7 +252,6 @@ void TIRVisitorWithPath::VisitStmt_(const IfThenElseNode* 
op, AccessPath path) {
 void TIRVisitorWithPath::VisitStmt_(const AssertStmtNode* op, AccessPath path) 
{
   Visit(op->condition, path->Attr("condition"));
   Visit(op->message, path->Attr("message"));
-  Visit(op->body, path->Attr("body"));
 }
 
 void TIRVisitorWithPath::VisitStmt_(const SeqStmtNode* op, AccessPath path) {
diff --git a/src/tir/transform/arg_binder.cc b/src/tir/transform/arg_binder.cc
index de44c1449d..bab4cdb3da 100644
--- a/src/tir/transform/arg_binder.cc
+++ b/src/tir/transform/arg_binder.cc
@@ -43,7 +43,7 @@ void BinderAddAssert(arith::Analyzer* ana, PrimExpr cond, 
const std::string& arg
   if (!is_one(scond)) {
     std::ostringstream os;
     os << "Argument " << arg_name << " has an unsatisfied constraint: " << 
cond;
-    asserts->emplace_back(AssertStmt(scond, tvm::tir::StringImm(os.str()), 
Evaluate(0)));
+    asserts->emplace_back(AssertStmt(scond, tvm::tir::StringImm(os.str())));
   }
 }
 
@@ -159,7 +159,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const 
PrimExpr& device_type,
 
   init_nest_.emplace_back(AssertStmt(
       !Call(DataType::Bool(), builtin::isnullptr(), {handle}),
-      tvm::tir::StringImm(arg_name + " is expected to have non-NULL DLTensor* 
pointer"), nop));
+      tvm::tir::StringImm(arg_name + " is expected to have non-NULL DLTensor* 
pointer")));
 
   // dimension checks
   PrimExpr v_ndim = TVMArrayGet(tvm_ndim_type, handle, builtin::kArrNDim);
@@ -179,7 +179,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const 
PrimExpr& device_type,
   std::ostringstream ndim_err_msg;
   ndim_err_msg << arg_name << ".ndim is expected to equal " << 
buffer->shape.size();
   auto msg = tvm::tir::StringImm(ndim_err_msg.str());
-  init_nest_.emplace_back(AssertStmt(a_ndim == v_ndim, msg, nop));
+  init_nest_.emplace_back(AssertStmt(a_ndim == v_ndim, msg));
   // type checks
   std::ostringstream type_err_msg;
   type_err_msg << arg_name << ".dtype is expected to be " << buffer->dtype;
@@ -192,7 +192,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const 
PrimExpr& device_type,
   if (!(buffer->dtype == DataType::Int(1) || buffer->dtype == DataType::Int(4) 
||
         buffer->dtype == DataType::UInt(4))) {
     auto type_msg = tvm::tir::StringImm(type_err_msg.str());
-    asserts_.emplace_back(AssertStmt(cond, type_msg, nop));
+    asserts_.emplace_back(AssertStmt(cond, type_msg));
   }
 
   // shape field
@@ -238,7 +238,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const 
PrimExpr& device_type,
       Stmt check = AssertStmt(
           foldl([](PrimExpr a, PrimExpr b, Span span) { return logical_and(a, 
b, span); },
                 const_true(1), conds),
-          stride_msg, Evaluate(0));
+          stride_msg);
       check = IfThenElse(Not(v_strides_is_null), check);
       asserts_.emplace_back(SeqStmt({check, Evaluate(0)}));
     }
@@ -314,9 +314,9 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const 
PrimExpr& device_type,
       }
       return product;
     }();
-    asserts_.emplace_back(AssertStmt(
-        alloc_size == 0 || !Call(DataType::Bool(), builtin::isnullptr(), 
{vptr}),
-        tvm::tir::StringImm(arg_name + " is expected to have non-NULL data 
pointer"), nop));
+    asserts_.emplace_back(
+        AssertStmt(alloc_size == 0 || !Call(DataType::Bool(), 
builtin::isnullptr(), {vptr}),
+                   tvm::tir::StringImm(arg_name + " is expected to have 
non-NULL data pointer")));
 
     def_handle_dtype_.Set(vptr, tir::TypeAnnotation(buffer->dtype));
     // mark alignment of external bufs
diff --git a/src/tir/transform/ir_utils.cc b/src/tir/transform/ir_utils.cc
index f661722e2f..4e8a590a7d 100644
--- a/src/tir/transform/ir_utils.cc
+++ b/src/tir/transform/ir_utils.cc
@@ -67,11 +67,8 @@ Stmt MergeNest(const std::vector<Stmt>& nest, Stmt body) {
       TVM_FFI_ICHECK(n->size() != 0 && is_no_op(n->seq[n->size() - 1]));
       n->seq.Set(n->size() - 1, body);
       body = Stmt(n);
-    } else if (const auto* assert_ = s.as<AssertStmtNode>()) {
-      auto n = ffi::make_object<AssertStmtNode>(*assert_);
-      TVM_FFI_ICHECK(is_no_op(n->body));
-      n->body = body;
-      body = Stmt(n);
+    } else if (s.as<AssertStmtNode>()) {
+      body = SeqStmt({s, body});
     } else if (const auto* alloc = s.as<AllocateNode>()) {
       auto n = ffi::make_object<AllocateNode>(*alloc);
       TVM_FFI_ICHECK(is_no_op(n->body));
diff --git a/src/tir/transform/make_packed_api.cc 
b/src/tir/transform/make_packed_api.cc
index 4c35b3fdd8..af63d9ba54 100644
--- a/src/tir/transform/make_packed_api.cc
+++ b/src/tir/transform/make_packed_api.cc
@@ -168,12 +168,12 @@ class SubroutineCallRewriter : public StmtExprMutator {
 }  // namespace
 
 inline Stmt MakeAssertEQ(PrimExpr lhs, PrimExpr rhs, std::string msg) {
-  return AssertStmt(lhs == rhs, tvm::tir::StringImm(msg), Evaluate(0));
+  return AssertStmt(lhs == rhs, tvm::tir::StringImm(msg));
 }
 
 inline Stmt MakeAssertNotNull(PrimExpr ptr, std::string msg) {
   Call isnull(DataType::Bool(), builtin::isnullptr(), {ptr});
-  return AssertStmt(!isnull, tvm::tir::StringImm(msg), Evaluate(0));
+  return AssertStmt(!isnull, tvm::tir::StringImm(msg));
 }
 
 /* \brief Return the global_symbol of the function, if it should be updated
@@ -300,7 +300,7 @@ PrimFunc MakePackedAPI(PrimFunc func) {
                                            type_index == 
ffi::TypeIndex::kTVMFFIOpaquePtr ||
                                            type_index == 
ffi::TypeIndex::kTVMFFIDLTensorPtr ||
                                            type_index >= 
ffi::TypeIndex::kTVMFFIStaticObjectBegin,
-                                       tvm::tir::StringImm(msg.str()), nop));
+                                       tvm::tir::StringImm(msg.str())));
       // if type_index is Tensor, we need to add the offset of the DLTensor 
header
       // which always equals 16 bytes, this ensures that T.handle always shows 
up as a DLTensor*
       const int64_t object_cell_offset = sizeof(TVMFFIObject);
@@ -316,7 +316,7 @@ PrimFunc MakePackedAPI(PrimFunc func) {
       msg << name_hint << ": Expect arg[" << i << "] to be boolean";
       seq_init.emplace_back(AssertStmt(
           type_index == ffi::TypeIndex::kTVMFFIBool || type_index == 
ffi::TypeIndex::kTVMFFIInt,
-          tvm::tir::StringImm(msg.str()), nop));
+          tvm::tir::StringImm(msg.str())));
       arg_value = Cast(DataType::Bool(), f_load_arg_value(DataType::Int(64), 
i));
 
     } else if (dtype.is_int() || dtype.is_uint()) {
@@ -324,7 +324,7 @@ PrimFunc MakePackedAPI(PrimFunc func) {
       msg << name_hint << ": Expect arg[" << i << "] to be int";
       seq_init.emplace_back(AssertStmt(
           type_index == ffi::TypeIndex::kTVMFFIInt || type_index == 
ffi::TypeIndex::kTVMFFIBool,
-          tvm::tir::StringImm(msg.str()), nop));
+          tvm::tir::StringImm(msg.str())));
       arg_value = f_load_arg_value(param.dtype(), i);
     } else {
       TVM_FFI_ICHECK(dtype.is_float());
@@ -333,7 +333,7 @@ PrimFunc MakePackedAPI(PrimFunc func) {
       seq_init.emplace_back(AssertStmt(type_index == 
ffi::TypeIndex::kTVMFFIFloat ||
                                            type_index == 
ffi::TypeIndex::kTVMFFIInt ||
                                            type_index == 
ffi::TypeIndex::kTVMFFIBool,
-                                       tvm::tir::StringImm(msg.str()), nop));
+                                       tvm::tir::StringImm(msg.str())));
       // use select so we can also handle int conversion to bool
       arg_value = tir::Select(
           type_index == ffi::TypeIndex::kTVMFFIFloat,
diff --git a/src/tir/transform/skip_assert.cc b/src/tir/transform/skip_assert.cc
index b2c473c97c..8e997bc9ee 100644
--- a/src/tir/transform/skip_assert.cc
+++ b/src/tir/transform/skip_assert.cc
@@ -29,9 +29,8 @@ namespace tir {
 class AssertSkipper : public StmtMutator {
  public:
   Stmt VisitStmt_(const AssertStmtNode* op) final {
-    Stmt stmt = StmtMutator::VisitStmt_(op);
-    op = stmt.as<AssertStmtNode>();
-    return op->body;
+    // AssertStmt is a leaf — just remove it.
+    return Evaluate(0);
   }
 };
 
diff --git a/src/tir/transform/split_host_device.cc 
b/src/tir/transform/split_host_device.cc
index 130cc177f0..43b6701f1f 100644
--- a/src/tir/transform/split_host_device.cc
+++ b/src/tir/transform/split_host_device.cc
@@ -104,7 +104,7 @@ class HostDeviceSplitter : public StmtMutator {
       Var kernel_error_code("kernel_error_code", success->dtype);
       Call kernel_call(success->dtype, kernel_symbol_global, args);
       AssertStmt assert_success(kernel_error_code == success,
-                                StringImm("Error executing compute kernel"), 
Evaluate(0));
+                                StringImm("Error executing compute kernel"));
       LetStmt let_check(kernel_error_code, kernel_call, assert_success);
 
       return let_check;
diff --git a/tests/python/tir-base/test_tir_constructor.py 
b/tests/python/tir-base/test_tir_constructor.py
index 5347146c05..d05201a143 100644
--- a/tests/python/tir-base/test_tir_constructor.py
+++ b/tests/python/tir-base/test_tir_constructor.py
@@ -141,9 +141,8 @@ def test_stmt_constructor():
     assert isinstance(x, tvm.tir.AttrStmt)
     assert x.value.value == 1
 
-    x = tvm.tir.AssertStmt(tvm.tir.const(1, "bool"), 
tvm.runtime.convert("hellow"), nop)
+    x = tvm.tir.AssertStmt(tvm.tir.const(1, "bool"), 
tvm.runtime.convert("hellow"))
     assert isinstance(x, tvm.tir.AssertStmt)
-    assert x.body == nop
 
     x = tvm.tir.For(tvm.tir.Var("x", "int32"), 0, 10, tvm.tir.ForKind.SERIAL, 
nop)
     assert isinstance(x, tvm.tir.For)
diff --git a/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py 
b/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py
index b65c0ec23a..71e8f37803 100644
--- a/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py
+++ b/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py
@@ -298,8 +298,13 @@ def test_ir_builder_tir_assert():
     # the assert generated by IRBuilder
     assert_actual = ib.get()
 
-    # the expected assert statement
-    assert_expected = tir.AssertStmt(T.int32() == 0, tir.StringImm("a is 0"), 
tir.Evaluate(0))
+    # AssertStmt is a leaf. The frame emits the assert and then the body stmts 
as siblings.
+    assert_expected = tir.SeqStmt(
+        [
+            tir.AssertStmt(T.int32() == 0, tir.StringImm("a is 0")),
+            tir.Evaluate(0),
+        ]
+    )
 
     # Check if the generated ir is expected
     assert_structural_equal(assert_actual, assert_expected, map_free_vars=True)
diff --git a/tests/python/tvmscript/test_tvmscript_printer_tir.py 
b/tests/python/tvmscript/test_tvmscript_printer_tir.py
index c2091ae5e6..93ddff0fd4 100644
--- a/tests/python/tvmscript/test_tvmscript_printer_tir.py
+++ b/tests/python/tvmscript/test_tvmscript_printer_tir.py
@@ -289,8 +289,8 @@ def test_assert_stmt():
     _assert_print(
         obj,
         """
-with T.Assert(T.bool(True), "assertion"):
-    T.evaluate(0)
+assert T.bool(True), "assertion"
+T.evaluate(0)
 """,
     )
 


Reply via email to