This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 95a4b1a819 [IR][TIR] Remove body from AssertStmt (#18832)
95a4b1a819 is described below
commit 95a4b1a8193c28a62f8195be6a13539e7f701a4d
Author: Tianqi Chen <[email protected]>
AuthorDate: Fri Feb 27 19:34:38 2026 -0500
[IR][TIR] Remove body from AssertStmt (#18832)
This PR removes the body field from AssertStmt, making it a leaf
statement. Constraints from AssertStmt are now tracked via
WithGroup<ConstraintContext> in a ScopeStack, providing clean RAII-based
scope management.
New utilities:
- WithGroup<T>: manages a dynamic group of With<T> RAII contexts
- ScopeStack<T>: scope stack for hierarchical state during IR visiting
---
include/tvm/arith/analyzer.h | 2 +-
include/tvm/ir/scope_stack.h | 123 +++++++++++++++++
include/tvm/support/with.h | 74 +++++++++++
include/tvm/tir/stmt.h | 2 +-
python/tvm/tir/stmt.py | 9 +-
src/arith/ir_mutator_with_analyzer.cc | 146 ++++++++++++---------
src/arith/ir_mutator_with_analyzer.h | 5 +
src/arith/ir_visitor_with_analyzer.cc | 79 ++++++-----
src/arith/ir_visitor_with_analyzer.h | 6 +
src/relax/op/tensor/inspect.cc | 11 +-
src/s_tir/analysis/estimate_flops.cc | 1 -
src/s_tir/transform/bound_checker.cc | 3 +-
src/script/ir_builder/tir/frame.cc | 11 +-
src/script/printer/tir/stmt.cc | 22 ++--
src/target/llvm/codegen_llvm.cc | 5 +-
src/target/source/codegen_c.cc | 1 -
src/target/source/codegen_c_host.cc | 1 -
src/target/source/codegen_webgpu.cc | 3 +-
src/target/spirv/codegen_spirv.cc | 3 +-
src/tir/ir/stmt.cc | 10 +-
src/tir/ir/stmt_functor.cc | 5 +-
src/tir/ir/tir_visitor_with_path.cc | 1 -
src/tir/transform/arg_binder.cc | 16 +--
src/tir/transform/ir_utils.cc | 7 +-
src/tir/transform/make_packed_api.cc | 12 +-
src/tir/transform/skip_assert.cc | 5 +-
src/tir/transform/split_host_device.cc | 2 +-
tests/python/tir-base/test_tir_constructor.py | 3 +-
.../tvmscript/test_tvmscript_ir_builder_tir.py | 9 +-
.../python/tvmscript/test_tvmscript_printer_tir.py | 4 +-
30 files changed, 403 insertions(+), 178 deletions(-)
diff --git a/include/tvm/arith/analyzer.h b/include/tvm/arith/analyzer.h
index 85d814f5b2..b77f2ee5db 100644
--- a/include/tvm/arith/analyzer.h
+++ b/include/tvm/arith/analyzer.h
@@ -563,7 +563,7 @@ class ConstraintContext {
Analyzer* analyzer_;
/*! \brief The constraint */
PrimExpr constraint_;
- /*! \brief function to be called in recovery */
+ /*! \brief functions to be called in recovery */
std::vector<std::function<void()>> recovery_functions_;
};
diff --git a/include/tvm/ir/scope_stack.h b/include/tvm/ir/scope_stack.h
new file mode 100644
index 0000000000..694d35e19e
--- /dev/null
+++ b/include/tvm/ir/scope_stack.h
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/ir/scope_stack.h
+ * \brief A generic scope stack for managing hierarchical state during IR visiting.
+ */
+#ifndef TVM_IR_SCOPE_STACK_H_
+#define TVM_IR_SCOPE_STACK_H_
+
+#include <tvm/ffi/error.h>
+
+#include <deque>
+#include <type_traits>
+#include <utility>
+
+namespace tvm {
+
+/*!
+ * \brief A scope stack for maintaining hierarchical state during IR visiting.
+ *
+ * During IR tree traversal, visitors often need to track scope-local state
+ * (e.g., active constraints, variable bindings) that should be automatically
+ * cleaned up when leaving a scope. ScopeStack provides this via WithNewScope,
+ * which pushes a new element on entry and pops it on exit.
+ *
+ * \code
+ * ScopeStack<WithGroup<ConstraintContext>> constraints;
+ *
+ * // In VisitStmt_(ForNode):
+ * return constraints.WithNewScope([&]() -> Stmt {
+ *   constraints.Current().Emplace(&analyzer, condition);
+ *   return StmtExprMutator::VisitStmt_(op);
+ * });
+ * \endcode
+ *
+ * \tparam T The element type stored on the stack. Must be default-constructible.
+ *   If T is also move-constructible, a popped element is moved out of the
+ *   underlying container before it is destroyed. A throwing destructor (such
+ *   as WithGroup re-throwing an ExitWithScope failure) therefore never throws
+ *   from inside a standard-library container operation.
+ */
+template <typename T>
+class ScopeStack {
+ public:
+  /*! \brief Construct with one initial scope level. */
+  ScopeStack() { stack_.emplace_back(); }
+
+  /*! \brief Return the number of active scopes. */
+  size_t size() const { return stack_.size(); }
+
+  /*! \brief Return true if no scopes are active. */
+  bool empty() const { return stack_.empty(); }
+
+  /*!
+   * \brief Access the current (innermost) scope element.
+   *
+   * The returned reference stays valid while deeper scopes are pushed and
+   * popped. std::deque::push_back never invalidates references to existing
+   * elements, and pop_back only invalidates references to the removed one.
+   *
+   * \return Mutable reference to the top element.
+   */
+  T& Current() {
+    TVM_FFI_ICHECK(!stack_.empty());
+    return stack_.back();
+  }
+
+  /*! \brief Const access to the current (innermost) scope element. */
+  const T& Current() const {
+    TVM_FFI_ICHECK(!stack_.empty());
+    return stack_.back();
+  }
+
+  /*!
+   * \brief Execute body within a new scope.
+   *
+   * Pushes a new T onto the stack, executes the body, then pops it. The pop
+   * also happens when body() throws.
+   *
+   * \param body A callable to execute within the scope.
+   * \return The return value of body(), if non-void.
+   */
+  template <typename F>
+  auto WithNewScope(F&& body) -> decltype(body()) {
+    stack_.emplace_back();
+    // Pops the scope pushed above on every exit path.
+    struct Guard {
+      std::deque<T>* stack;
+      ~Guard() noexcept(false) {
+        if constexpr (std::is_move_constructible_v<T>) {
+          // Move the innermost element out before popping. Its (possibly
+          // throwing) destructor then runs on this local rather than inside
+          // std::deque::pop_back(). A destructor exiting via an exception from
+          // within a standard container operation is undefined behavior. The
+          // moved-from element left behind must be non-throwing to destroy;
+          // this holds for WithGroup, whose moved-from state is empty.
+          T top(std::move(stack->back()));
+          stack->pop_back();
+        } else {
+          stack->pop_back();
+        }
+      }
+    } guard{&stack_};
+    if constexpr (std::is_void_v<decltype(body())>) {
+      body();
+    } else {
+      return body();
+    }
+  }
+
+ private:
+  /*!
+   * \brief The scope stack.
+   *
+   * We use std::deque rather than std::vector for pointer stability:
+   * references returned by Current() remain valid across push/pop operations.
+   * This is critical because methods called on Current() (e.g., Emplace on
+   * a WithGroup) may trigger re-entrant code that pushes new scopes onto
+   * the same stack. With std::vector the internal buffer reallocation would
+   * invalidate the reference, causing use-after-free.
+   */
+  std::deque<T> stack_;
+};
+
+} // namespace tvm
+
+#endif // TVM_IR_SCOPE_STACK_H_
diff --git a/include/tvm/support/with.h b/include/tvm/support/with.h
index 8d4d48038d..8cf2823e06 100644
--- a/include/tvm/support/with.h
+++ b/include/tvm/support/with.h
@@ -25,8 +25,11 @@
#ifndef TVM_SUPPORT_WITH_H_
#define TVM_SUPPORT_WITH_H_
+#include <exception>
#include <functional>
+#include <memory>
#include <utility>
+#include <vector>
namespace tvm {
@@ -90,5 +93,76 @@ class With {
ContextType ctx_;
};
+/*!
+ * \brief A group of RAII contexts managed together.
+ *
+ * Allows dynamically emplacing multiple context objects that are
+ * all exited (in reverse order) when the group is destroyed.
+ * ContextType must declare `friend class With<ContextType>`
+ * and provide EnterWithScope() / ExitWithScope() methods.
+ *
+ * \code
+ * WithGroup<ConstraintContext> group;
+ * group.Emplace(&analyzer, cond1);  // constructs and enters
+ * group.Emplace(&analyzer, cond2);  // constructs and enters
+ * // destructor: exits cond2, then cond1
+ * \endcode
+ *
+ * \tparam ContextType The context type with EnterWithScope/ExitWithScope.
+ */
+template <typename ContextType>
+class WithGroup {
+ public:
+  WithGroup() = default;
+  /*! \brief Take over all contexts of the source group, leaving it empty. */
+  WithGroup(WithGroup&&) = default;
+  /*!
+   * \brief Move-assign from another group.
+   *
+   * Contexts currently held by this group are exited first, innermost first,
+   * with the same exception policy as the destructor. Only then are the
+   * contexts of \p other taken over, and \p other is left empty.
+   *
+   * A defaulted move-assignment would instead let std::vector destroy the old
+   * entries itself. That guarantees neither reverse order nor survival of a
+   * throwing exit. If an exit throws here, this group is left empty and
+   * \p other is untouched.
+   */
+  WithGroup& operator=(WithGroup&& other) {
+    if (this != &other) {
+      ExitAll();
+      entries_ = std::move(other.entries_);
+      other.entries_.clear();
+    }
+    return *this;
+  }
+  WithGroup(const WithGroup&) = delete;
+  WithGroup& operator=(const WithGroup&) = delete;
+
+  /*!
+   * \brief Construct a context and enter its scope.
+   * \param args Arguments forwarded to ContextType constructor.
+   */
+  template <typename... Args>
+  void Emplace(Args&&... args) {
+    entries_.push_back(std::make_unique<With<ContextType>>(std::forward<Args>(args)...));
+  }
+
+  /*! \brief Number of active contexts in this group. */
+  size_t size() const { return entries_.size(); }
+
+  /*!
+   * \brief Destructor — exits all contexts in reverse order.
+   *
+   * On normal exit: if any ExitWithScope throws, the remaining
+   * contexts are still cleaned up, then the first exception
+   * is re-thrown.
+   *
+   * During stack unwinding: all exceptions are swallowed
+   * to avoid std::terminate.
+   */
+  ~WithGroup() noexcept(false) { ExitAll(); }
+
+ private:
+  /*!
+   * \brief Exit and release every held context, innermost first.
+   *
+   * Every context is exited even if an earlier ExitWithScope throws. If no
+   * exception was already propagating, the first exception raised is
+   * re-thrown after all contexts are exited. During stack unwinding, all
+   * exceptions are swallowed.
+   */
+  void ExitAll() {
+    const bool unwinding = std::uncaught_exceptions() > 0;
+    std::exception_ptr first_exc;
+    while (!entries_.empty()) {
+      // Detach the innermost entry from the vector before destroying it, so
+      // entries_ shrinks even if its exit throws.
+      With<ContextType>* entry = entries_.back().release();
+      entries_.pop_back();
+      try {
+        // Destroy with a plain delete-expression. std::unique_ptr::reset()
+        // and ~unique_ptr() are noexcept, so an ExitWithScope() exception
+        // raised through them would call std::terminate instead of reaching
+        // the catch below. The storage is still freed if the destructor throws.
+        delete entry;  // calls ~With<ContextType>() -> ExitWithScope()
+      } catch (...) {
+        if (!unwinding && !first_exc) {
+          first_exc = std::current_exception();
+        }
+      }
+    }
+    if (first_exc) std::rethrow_exception(first_exc);
+  }
+
+  std::vector<std::unique_ptr<With<ContextType>>> entries_;
+};
+
} // namespace tvm
#endif // TVM_SUPPORT_WITH_H_
diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h
index b41c92d66a..c12e14e0c0 100644
--- a/include/tvm/tir/stmt.h
+++ b/include/tvm/tir/stmt.h
@@ -176,7 +176,7 @@ class AssertStmtNode : public StmtNode {
*/
class AssertStmt : public Stmt {
public:
- TVM_DLL AssertStmt(PrimExpr condition, PrimExpr message, Stmt body, Span
span = Span());
+ TVM_DLL AssertStmt(PrimExpr condition, PrimExpr message, Span span = Span());
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(AssertStmt, Stmt, AssertStmtNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(AssertStmtNode);
diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py
index 6abd96a360..293ebed6a7 100644
--- a/python/tvm/tir/stmt.py
+++ b/python/tvm/tir/stmt.py
@@ -90,26 +90,19 @@ class AssertStmt(Stmt):
message : PrimExpr
The error message.
- body : tvm.tir.Stmt
- The body statement.
-
span : Optional[Span]
The location of the stmt in the source code.
"""
condition: PrimExpr
message: PrimExpr
- body: Stmt
span: Optional[Span]
- def __init__(
- self, condition: PrimExpr, message: PrimExpr, body: Stmt, span:
Optional[Span] = None
- ) -> None:
+ def __init__(self, condition: PrimExpr, message: PrimExpr, span:
Optional[Span] = None) -> None:
self.__init_handle_by_constructor__(
_ffi_api.AssertStmt,
condition,
message,
- body,
span, # type: ignore
)
diff --git a/src/arith/ir_mutator_with_analyzer.cc
b/src/arith/ir_mutator_with_analyzer.cc
index e57e4c0d5a..96a64c5e47 100644
--- a/src/arith/ir_mutator_with_analyzer.cc
+++ b/src/arith/ir_mutator_with_analyzer.cc
@@ -60,19 +60,23 @@ ffi::Array<PrimExpr>
IRMutatorWithAnalyzer::IterMapSimplifyWithContext(
}
Stmt IRMutatorWithAnalyzer::VisitStmt_(const ForNode* op) {
- // record the loop variable as iterators
- Range dom = Range::FromMinExtent(op->min, op->extent);
- analyzer_->Bind(op->loop_var, dom);
- iter_vars_.Set(op->loop_var, dom);
- return StmtExprMutator::VisitStmt_(op);
+ return constraint_scope_.WithNewScope([&]() -> Stmt {  // Fresh constraint scope: asserts visited in the loop body are exited when it is popped after this lambda returns.
+ // record the loop variable as iterators
+ Range dom = Range::FromMinExtent(op->min, op->extent);
+ analyzer_->Bind(op->loop_var, dom);  // Bind goes straight to analyzer_ and is not tied to constraint_scope_.
+ iter_vars_.Set(op->loop_var, dom);
+ return StmtExprMutator::VisitStmt_(op);
+ });
}
Stmt IRMutatorWithAnalyzer::VisitStmt_(const SBlockNode* op) {
- for (const auto& iter_var : op->iter_vars) {
- analyzer_->Bind(iter_var->var, iter_var->dom);
- iter_vars_.Set(iter_var->var, iter_var->dom);
- }
- return StmtExprMutator::VisitStmt_(op);
+ return constraint_scope_.WithNewScope([&]() -> Stmt {  // Fresh constraint scope: asserts inside the block are exited when the block is left.
+ for (const auto& iter_var : op->iter_vars) {  // Bind each block iter var to its declared domain.
+ analyzer_->Bind(iter_var->var, iter_var->dom);
+ iter_vars_.Set(iter_var->var, iter_var->dom);
+ }
+ return StmtExprMutator::VisitStmt_(op);
+ });
}
Stmt IRMutatorWithAnalyzer::VisitStmt_(const LetStmtNode* op) {
@@ -94,88 +98,97 @@ Stmt IRMutatorWithAnalyzer::VisitStmt_(const LetStmtNode*
op) {
}
Stmt IRMutatorWithAnalyzer::VisitStmt_(const IfThenElseNode* op) {
- PrimExpr condition = this->VisitExpr(op->condition);
- PrimExpr real_condition = condition;
- static auto op_likely = Op::Get("tir.likely");
+ return constraint_scope_.WithNewScope([&]() -> Stmt {  // Outer scope bounds all constraint state recorded while mutating this if/else.
+ PrimExpr condition = this->VisitExpr(op->condition);
+ PrimExpr real_condition = condition;  // `condition` with a tir.likely(...) wrapper stripped (below).
+ static auto op_likely = Op::Get("tir.likely");
- if (auto call = condition.as<CallNode>()) {
- if (call->op.same_as(op_likely)) {
- real_condition = call->args[0];
+ if (auto call = condition.as<CallNode>()) {
+ if (call->op.same_as(op_likely)) {
+ real_condition = call->args[0];
+ }
}
- }
- Stmt then_case;
- ffi::Optional<Stmt> else_case;
- {
- With<ConstraintContext> ctx(analyzer_, real_condition);
- WithRecordIterPredicate(real_condition, [&] { then_case =
this->VisitStmt(op->then_case); });
- }
- if (op->else_case) {
- With<ConstraintContext> ctx(analyzer_,
analyzer_->rewrite_simplify(Not(real_condition)));
- else_case = this->VisitStmt(op->else_case.value());
- }
- if (is_one(real_condition)) return then_case;
- if (is_zero(real_condition)) {
- return else_case.value_or(Evaluate(0));
- }
+ Stmt then_case;
+ ffi::Optional<Stmt> else_case;
+ constraint_scope_.WithNewScope([&]() {  // then-branch scope: real_condition is assumed only inside it.
+ constraint_scope_.Current().Emplace(analyzer_, real_condition);
+ WithRecordIterPredicate(real_condition, [&] { then_case =
this->VisitStmt(op->then_case); });
+ });
+ if (op->else_case) {
+ PrimExpr neg_condition =
analyzer_->rewrite_simplify(Not(real_condition));
+ constraint_scope_.WithNewScope([&]() {  // else-branch scope: the simplified negation is assumed only inside it.
+ constraint_scope_.Current().Emplace(analyzer_, neg_condition);
+ else_case = this->VisitStmt(op->else_case.value());
+ });
+ }
+ if (is_one(real_condition)) return then_case;  // Condition is constant true: the then-branch replaces the IfThenElse.
+ if (is_zero(real_condition)) {  // Constant false: keep the else-branch, or a no-op Evaluate(0) when absent.
+ return else_case.value_or(Evaluate(0));
+ }
- if (condition.same_as(op->condition) && then_case.same_as(op->then_case) &&
- else_case.same_as(op->else_case)) {
- return ffi::GetRef<Stmt>(op);
- } else {
- auto n = this->CopyOnWrite(op);
- n->condition = std::move(condition);
- n->then_case = std::move(then_case);
- n->else_case = std::move(else_case);
- return Stmt(n);
- }
+ if (condition.same_as(op->condition) && then_case.same_as(op->then_case) &&
+ else_case.same_as(op->else_case)) {
+ return ffi::GetRef<Stmt>(op);
+ } else {
+ auto n = this->CopyOnWrite(op);
+ n->condition = std::move(condition);
+ n->then_case = std::move(then_case);
+ n->else_case = std::move(else_case);
+ return Stmt(n);
+ }
+ });
}
Stmt IRMutatorWithAnalyzer::VisitStmt_(const AttrStmtNode* op) {
- if (op->attr_key == tir::attr::thread_extent || op->attr_key ==
tir::attr::virtual_thread) {
- IterVar iv = Downcast<IterVar>(op->node);
- TVM_FFI_ICHECK_NE(iv->thread_tag.length(), 0U);
- Range dom = Range::FromMinExtent(make_zero(op->value.dtype()), op->value);
- analyzer_->Bind(iv->var, dom);
- iter_vars_.Set(iv->var, dom);
- Stmt stmt = StmtExprMutator::VisitStmt_(op);
- return stmt;
- } else {
+ return constraint_scope_.WithNewScope([&]() -> Stmt {  // Scope bounds asserts inside the attribute body to this AttrStmt.
+ if (op->attr_key == tir::attr::thread_extent || op->attr_key ==
tir::attr::virtual_thread) {
+ IterVar iv = Downcast<IterVar>(op->node);
+ TVM_FFI_ICHECK_NE(iv->thread_tag.length(), 0U);
+ Range dom = Range::FromMinExtent(make_zero(op->value.dtype()),
op->value);
+ analyzer_->Bind(iv->var, dom);  // The thread/vthread var ranges over [0, op->value).
+ iter_vars_.Set(iv->var, dom);
+ }
return StmtExprMutator::VisitStmt_(op);
- }
+ });
}
Stmt IRMutatorWithAnalyzer::VisitStmt_(const AssertStmtNode* op) {
PrimExpr condition = this->VisitExpr(op->condition);
PrimExpr message = this->VisitExpr(op->message);
- With<ConstraintContext> ctx(analyzer_, condition);
- Stmt body = this->VisitStmt(op->body);
+ constraint_scope_.Current().Emplace(analyzer_, condition);  // Leaf stmt: the visited condition stays assumed until the enclosing scope is popped, so siblings visited later in a SeqStmt see it. NOTE(review): asserts in the root scope stay active in *analyzer_ until this mutator is destroyed, and bodies without their own WithNewScope (e.g. While, unless overridden elsewhere) let the constraint outlive them — confirm both are intended.
- if (condition.same_as(op->condition) && message.same_as(op->message) &&
body.same_as(op->body)) {
+ if (condition.same_as(op->condition) && message.same_as(op->message)) {  // No body any more: only the two expressions can change.
return ffi::GetRef<Stmt>(op);
} else {
auto n = this->CopyOnWrite(op);
n->condition = std::move(condition);
n->message = std::move(message);
- n->body = std::move(body);
return Stmt(n);
}
}
+Stmt IRMutatorWithAnalyzer::VisitStmt_(const SeqStmtNode* op) {  // Explicit override that deliberately opens no constraint scope.
+ // SeqStmt does NOT get WithNewScope — constraints accumulate across
siblings.
+ return StmtExprMutator::VisitStmt_(op);  // Same as the base class: an assert applies to the siblings visited after it.
+}
+}
+
PrimExpr IRMutatorWithAnalyzer::VisitExpr_(const CallNode* op) {
// add condition context to if_then_else
static auto op_if_then_else = Op::Get("tir.if_then_else");
if (op->op.same_as(op_if_then_else)) {
PrimExpr cond = this->VisitExpr(op->args[0]);
PrimExpr true_value, false_value;
- {
- With<ConstraintContext> constraint(analyzer_, cond);
+ constraint_scope_.WithNewScope([&]() {
+ constraint_scope_.Current().Emplace(analyzer_, cond);
WithRecordIterPredicate(cond, [&] { true_value =
this->VisitExpr(op->args[1]); });
- }
+ });
{
PrimExpr not_cond = Not(cond);
- With<ConstraintContext> constraint(analyzer_, not_cond);
- WithRecordIterPredicate(not_cond, [&] { false_value =
this->VisitExpr(op->args[2]); });
+ constraint_scope_.WithNewScope([&]() {
+ constraint_scope_.Current().Emplace(analyzer_, not_cond);
+ WithRecordIterPredicate(not_cond, [&] { false_value =
this->VisitExpr(op->args[2]); });
+ });
}
if (is_zero(cond)) {
return false_value;
@@ -211,13 +224,16 @@ PrimExpr IRMutatorWithAnalyzer::VisitExpr_(const LetNode*
op) {
PrimExpr IRMutatorWithAnalyzer::VisitExpr_(const SelectNode* op) {
PrimExpr cond = this->VisitExpr(op->condition);
PrimExpr true_value, false_value;
- {
- With<ConstraintContext> constraint(analyzer_, cond);
+ constraint_scope_.WithNewScope([&]() {
+ constraint_scope_.Current().Emplace(analyzer_, cond);
true_value = VisitExpr(op->true_value);
- }
+ });
{
- With<ConstraintContext> constraint(analyzer_,
analyzer_->rewrite_simplify(Not(cond)));
- false_value = VisitExpr(op->false_value);
+ PrimExpr neg_cond = analyzer_->rewrite_simplify(Not(cond));
+ constraint_scope_.WithNewScope([&]() {
+ constraint_scope_.Current().Emplace(analyzer_, neg_cond);
+ false_value = VisitExpr(op->false_value);
+ });
}
if (is_zero(cond)) {
return false_value;
diff --git a/src/arith/ir_mutator_with_analyzer.h
b/src/arith/ir_mutator_with_analyzer.h
index 5b5ac7e6cd..8810a8f78f 100644
--- a/src/arith/ir_mutator_with_analyzer.h
+++ b/src/arith/ir_mutator_with_analyzer.h
@@ -25,6 +25,8 @@
#define TVM_ARITH_IR_MUTATOR_WITH_ANALYZER_H_
#include <tvm/arith/analyzer.h>
+#include <tvm/ir/scope_stack.h>
+#include <tvm/support/with.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/stmt_functor.h>
@@ -56,6 +58,7 @@ class IRMutatorWithAnalyzer : public tir::StmtExprMutator {
tir::Stmt VisitStmt_(const tir::IfThenElseNode* op) override;
tir::Stmt VisitStmt_(const tir::AttrStmtNode* op) override;
tir::Stmt VisitStmt_(const tir::AssertStmtNode* op) override;
+ tir::Stmt VisitStmt_(const tir::SeqStmtNode* op) override;
PrimExpr VisitExpr_(const tir::LetNode* op) override;
PrimExpr VisitExpr_(const tir::SelectNode* op) override;
PrimExpr VisitExpr_(const tir::CallNode* op) override;
@@ -79,6 +82,8 @@ class IRMutatorWithAnalyzer : public tir::StmtExprMutator {
/*! \brief internal analyzer field. */
Analyzer* analyzer_;
+ /*! \brief Scope stack for accumulated assert constraints. */
+ ScopeStack<WithGroup<ConstraintContext>> constraint_scope_;
// the following two fields are useful in case we want
// note however that iter map analysis are usually more
// expensive and we only encourage doing them during
diff --git a/src/arith/ir_visitor_with_analyzer.cc
b/src/arith/ir_visitor_with_analyzer.cc
index fada12e9c4..02c194b14e 100644
--- a/src/arith/ir_visitor_with_analyzer.cc
+++ b/src/arith/ir_visitor_with_analyzer.cc
@@ -32,15 +32,19 @@ namespace arith {
using namespace tir;
void IRVisitorWithAnalyzer::VisitStmt_(const ForNode* op) {
- analyzer_.Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent));
- StmtExprVisitor::VisitStmt_(op);
+ constraint_scope_.WithNewScope([&]() {  // Asserts visited in the loop body are exited when this scope is popped.
+ analyzer_.Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent));
+ StmtExprVisitor::VisitStmt_(op);
+ });
}
void IRVisitorWithAnalyzer::VisitStmt_(const SBlockNode* op) {
- for (const auto& iter_var : op->iter_vars) {
- analyzer_.Bind(iter_var->var, iter_var->dom);
- }
- StmtExprVisitor::VisitStmt_(op);
+ constraint_scope_.WithNewScope([&]() {  // Asserts inside the block are exited when the block is left.
+ for (const auto& iter_var : op->iter_vars) {
+ analyzer_.Bind(iter_var->var, iter_var->dom);
+ }
+ StmtExprVisitor::VisitStmt_(op);
+ });
}
void IRVisitorWithAnalyzer::VisitStmt_(const LetStmtNode* op) {
@@ -50,34 +54,45 @@ void IRVisitorWithAnalyzer::VisitStmt_(const LetStmtNode*
op) {
}
void IRVisitorWithAnalyzer::VisitStmt_(const IfThenElseNode* op) {
- this->VisitExpr(op->condition);
-
- PrimExpr real_condition = ExtractRealCondition(op->condition);
-
- {
- With<ConstraintContext> constraint(&analyzer_, real_condition);
- this->VisitStmt(op->then_case);
- }
- if (op->else_case) {
- With<ConstraintContext> constraint(&analyzer_,
analyzer_.rewrite_simplify(Not(real_condition)));
- this->VisitStmt(op->else_case.value());
- }
+ constraint_scope_.WithNewScope([&]() {  // Outer scope bounds everything recorded while visiting this if/else.
+ this->VisitExpr(op->condition);
+
+ PrimExpr real_condition = ExtractRealCondition(op->condition);
+
+ constraint_scope_.WithNewScope([&]() {  // then-branch: real_condition is assumed only in this scope.
+ constraint_scope_.Current().Emplace(&analyzer_, real_condition);
+ this->VisitStmt(op->then_case);
+ });
+ if (op->else_case) {
+ constraint_scope_.WithNewScope([&]() {  // else-branch: the simplified negation is assumed only in this scope.
+ constraint_scope_.Current().Emplace(&analyzer_,
+
analyzer_.rewrite_simplify(Not(real_condition)));
+ this->VisitStmt(op->else_case.value());
+ });
+ }
+ });
}
void IRVisitorWithAnalyzer::VisitStmt_(const AttrStmtNode* op) {
- if (op->attr_key == tir::attr::thread_extent || op->attr_key ==
tir::attr::virtual_thread) {
- IterVar iv = Downcast<IterVar>(op->node);
- TVM_FFI_ICHECK_NE(iv->thread_tag.length(), 0U);
- analyzer_.Bind(iv->var, Range::FromMinExtent(IntImm(op->value->dtype, 0),
op->value));
- }
- StmtExprVisitor::VisitStmt_(op);
+ constraint_scope_.WithNewScope([&]() {  // Scope bounds asserts inside the attribute body to this AttrStmt.
+ if (op->attr_key == tir::attr::thread_extent || op->attr_key ==
tir::attr::virtual_thread) {
+ IterVar iv = Downcast<IterVar>(op->node);
+ TVM_FFI_ICHECK_NE(iv->thread_tag.length(), 0U);
+ analyzer_.Bind(iv->var, Range::FromMinExtent(IntImm(op->value->dtype,
0), op->value));
+ }
+ StmtExprVisitor::VisitStmt_(op);
+ });
}
void IRVisitorWithAnalyzer::VisitStmt_(const AssertStmtNode* op) {
this->VisitExpr(op->condition);
this->VisitExpr(op->message);
- With<ConstraintContext> constraint(&analyzer_, op->condition);
- this->VisitStmt(op->body);
+ constraint_scope_.Current().Emplace(&analyzer_, op->condition);  // Leaf stmt: the condition stays assumed for later siblings until the enclosing scope is popped; root-scope asserts persist until this visitor is destroyed.
+}
+
+void IRVisitorWithAnalyzer::VisitStmt_(const SeqStmtNode* op) {  // Explicit override that deliberately opens no constraint scope.
+ // SeqStmt does NOT get WithNewScope — constraints accumulate across
siblings.
+ StmtExprVisitor::VisitStmt_(op);  // Same as the base class: an assert applies to the siblings visited after it.
}
void IRVisitorWithAnalyzer::VisitExpr_(const CallNode* op) {
@@ -86,14 +101,14 @@ void IRVisitorWithAnalyzer::VisitExpr_(const CallNode* op)
{
if (op->op.same_as(op_if_then_else)) {
PrimExpr cond = op->args[0];
this->VisitExpr(op->args[0]);
- {
- With<ConstraintContext> constraint(&analyzer_, cond);
+ constraint_scope_.WithNewScope([&]() {
+ constraint_scope_.Current().Emplace(&analyzer_, cond);
this->VisitExpr(op->args[1]);
- }
- {
- With<ConstraintContext> constraint(&analyzer_,
analyzer_.rewrite_simplify(Not(cond)));
+ });
+ constraint_scope_.WithNewScope([&]() {
+ constraint_scope_.Current().Emplace(&analyzer_,
analyzer_.rewrite_simplify(Not(cond)));
this->VisitExpr(op->args[2]);
- }
+ });
} else {
StmtExprVisitor::VisitExpr_(op);
}
diff --git a/src/arith/ir_visitor_with_analyzer.h
b/src/arith/ir_visitor_with_analyzer.h
index cd2b9bfdec..f0553a1c42 100644
--- a/src/arith/ir_visitor_with_analyzer.h
+++ b/src/arith/ir_visitor_with_analyzer.h
@@ -26,6 +26,8 @@
#define TVM_ARITH_IR_VISITOR_WITH_ANALYZER_H_
#include <tvm/arith/analyzer.h>
+#include <tvm/ir/scope_stack.h>
+#include <tvm/support/with.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
@@ -45,6 +47,7 @@ class IRVisitorWithAnalyzer : public tir::StmtExprVisitor {
void VisitStmt_(const tir::IfThenElseNode* op);
void VisitStmt_(const tir::AttrStmtNode* op);
void VisitStmt_(const tir::AssertStmtNode* op);
+ void VisitStmt_(const tir::SeqStmtNode* op);
void VisitExpr_(const tir::CallNode* op);
void VisitExpr_(const tir::LetNode* op);
void VisitExpr_(const tir::ReduceNode* op);
@@ -57,6 +60,9 @@ class IRVisitorWithAnalyzer : public tir::StmtExprVisitor {
/*! \brief internal analyzer field. */
arith::Analyzer analyzer_;
+ /*! \brief Scope stack for accumulated assert constraints. */
+ ScopeStack<WithGroup<ConstraintContext>> constraint_scope_;
+
/*! \brief Extract a constraint from a conditional statement
*
* Intended for preparing argument for use in
diff --git a/src/relax/op/tensor/inspect.cc b/src/relax/op/tensor/inspect.cc
index 2e05ea4a81..bd0f963c49 100644
--- a/src/relax/op/tensor/inspect.cc
+++ b/src/relax/op/tensor/inspect.cc
@@ -315,9 +315,11 @@ Expr LegalizeTensorShape(const BlockBuilder& bb, const
Call& call) {
IntImm(DataType::Int(32),
tir::builtin::TVMStructFieldKind::kArrShape)}),
body);
- body = tir::AssertStmt(
- axis < tvm::cast(axis->dtype, ndim),
- tir::StringImm("Specified axis may not be larger than the tensor's
dimensionality"), body);
+ body = tir::SeqStmt(
+ {tir::AssertStmt(
+ axis < tvm::cast(axis->dtype, ndim),
+ tir::StringImm("Specified axis may not be larger than the
tensor's dimensionality")),
+ body});
body = tir::LetStmt(
ndim,
@@ -326,7 +328,8 @@ Expr LegalizeTensorShape(const BlockBuilder& bb, const
Call& call) {
IntImm(DataType::Int(32),
tir::builtin::TVMStructFieldKind::kArrNDim)}),
body);
- body = tir::AssertStmt(0 <= axis, tir::StringImm("Specified axis may not
be negative"), body);
+ body = tir::SeqStmt(
+ {tir::AssertStmt(0 <= axis, tir::StringImm("Specified axis may not be
negative")), body});
DictAttrs attrs({{"tir.is_scheduled", true}, {"tir.is_host", true}});
diff --git a/src/s_tir/analysis/estimate_flops.cc
b/src/s_tir/analysis/estimate_flops.cc
index b11d262281..0f106f5854 100644
--- a/src/s_tir/analysis/estimate_flops.cc
+++ b/src/s_tir/analysis/estimate_flops.cc
@@ -199,7 +199,6 @@ class FlopEstimator : private ExprFunctor<TResult(const
PrimExpr& n)>,
if (op->message.defined()) {
result += VisitExpr(op->message);
}
- result += VisitStmt(op->body);
return result;
}
diff --git a/src/s_tir/transform/bound_checker.cc
b/src/s_tir/transform/bound_checker.cc
index dbee2effdb..62b76fa806 100644
--- a/src/s_tir/transform/bound_checker.cc
+++ b/src/s_tir/transform/bound_checker.cc
@@ -96,9 +96,8 @@ class BoundChecker : public StmtExprMutator {
if (store_scope_bound_collector_.size()) {
PrimExpr condition = MakeCondition();
if (!condition.as<StringImmNode>()) {
- Stmt nop = Evaluate(1);
Stmt then_case = ffi::GetRef<Stmt>(op);
- Stmt else_case = AssertStmt(condition, StringImm(error_message_), nop);
+ Stmt else_case = AssertStmt(condition, StringImm(error_message_));
Stmt body = IfThenElse(condition, then_case, else_case);
return body;
}
diff --git a/src/script/ir_builder/tir/frame.cc
b/src/script/ir_builder/tir/frame.cc
index e5008c74ed..e65be1d45f 100644
--- a/src/script/ir_builder/tir/frame.cc
+++ b/src/script/ir_builder/tir/frame.cc
@@ -128,7 +128,16 @@ void ForFrameNode::ExitWithScope() {
void AssertFrameNode::ExitWithScope() {
TIRFrameNode::ExitWithScope();
- AddToParent(tvm::tir::AssertStmt(condition, message, AsStmt(stmts)));
+ if (stmts.empty()) {  // No statements were collected in this frame: emit the leaf AssertStmt alone.
+ AddToParent(tvm::tir::AssertStmt(condition, message));
+ } else {
+ ffi::Array<tvm::tir::Stmt> seq;  // AssertStmt has no body any more, so the frame's statements follow it as SeqStmt siblings.
+ seq.push_back(tvm::tir::AssertStmt(condition, message));
+ for (const auto& stmt : stmts) {
+ seq.push_back(stmt);
+ }
+ AddToParent(tvm::tir::SeqStmt(seq));
+ }
}
void LetFrameNode::ExitWithScope() {
diff --git a/src/script/printer/tir/stmt.cc b/src/script/printer/tir/stmt.cc
index eff58cf411..633704c164 100644
--- a/src/script/printer/tir/stmt.cc
+++ b/src/script/printer/tir/stmt.cc
@@ -131,20 +131,14 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
});
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
- .set_dispatch<tir::AssertStmt>(
- "", [](tir::AssertStmt stmt, AccessPath p, IRDocsifier d) -> Doc {
- bool concise = AllowConciseScoping(d, stmt);
- ExprDoc cond = d->AsDoc<ExprDoc>(stmt->condition,
p->Attr("condition"));
- ExprDoc msg = d->AsDoc<ExprDoc>(stmt->message, p->Attr("message"));
- With<TIRFrame> f(d, stmt);
- AsDocBody(stmt->body, p->Attr("body"), f->get(), d);
- if (concise) {
- ffi::Array<StmtDoc>* stmts = &(*f)->stmts;
- stmts->insert(stmts->begin(), AssertDoc(cond, msg));
- return StmtBlockDoc(*stmts);
- }
- return ScopeDoc(std::nullopt, TIR(d, "Assert")->Call({cond, msg}),
(*f)->stmts);
- });
+ .set_dispatch<tir::AssertStmt>("",  // Leaf statement: always printed as a single AssertDoc; the T.Assert ScopeDoc form is no longer produced.
+ [](tir::AssertStmt stmt, AccessPath p,
IRDocsifier d) -> Doc {
+ ExprDoc cond =
+ d->AsDoc<ExprDoc>(stmt->condition,
p->Attr("condition"));
+ ExprDoc msg =
+ d->AsDoc<ExprDoc>(stmt->message,
p->Attr("message"));
+ return AssertDoc(cond, msg);  // No body, so no TIRFrame is opened.
+ });
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tir::While>("", [](tir::While stmt, AccessPath p,
IRDocsifier d) -> Doc {
diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc
index 93f0282015..472567efc7 100644
--- a/src/target/llvm/codegen_llvm.cc
+++ b/src/target/llvm/codegen_llvm.cc
@@ -2165,9 +2165,8 @@ void CodeGenLLVM::VisitStmt_(const AttrStmtNode* op) {
void CodeGenLLVM::VisitStmt_(const AssertStmtNode* op) {
EmitDebugLocation(op);
- // auto a_cu =
- With<arith::ConstraintContext> cctx(analyzer_.get(), op->condition);
- this->VisitStmt(op->body);
+ // AssertStmt is a leaf — no body to visit.
+ // NOTE(review): op->condition is no longer registered with analyzer_ here, so analyzer_ queries for later statements do not see it; confirm this loss is acceptable.
}
void CodeGenLLVM::VisitStmt_(const LetStmtNode* op) {
diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc
index d3e1cee46e..95b1260e45 100644
--- a/src/target/source/codegen_c.cc
+++ b/src/target/source/codegen_c.cc
@@ -1088,7 +1088,6 @@ void CodeGenC::VisitStmt_(const AssertStmtNode* op) {
} else {
stream << "assert(" << cond << ");\n";
}
- this->PrintStmt(op->body);
}
void CodeGenC::VisitStmt_(const ForNode* op) {
diff --git a/src/target/source/codegen_c_host.cc
b/src/target/source/codegen_c_host.cc
index cb6ba238ef..ca6b71b4d8 100644
--- a/src/target/source/codegen_c_host.cc
+++ b/src/target/source/codegen_c_host.cc
@@ -332,7 +332,6 @@ void CodeGenCHost::VisitStmt_(const AssertStmtNode* op) {
// NOLINT(*)
PrintIndent();
stream << "}\n";
}
- this->PrintStmt(op->body);
}
void CodeGenCHost::VisitExpr_(const MinNode* op, std::ostream& os) { //
NOLINT(*)
diff --git a/src/target/source/codegen_webgpu.cc
b/src/target/source/codegen_webgpu.cc
index eb5351bd0f..3a9cceb687 100644
--- a/src/target/source/codegen_webgpu.cc
+++ b/src/target/source/codegen_webgpu.cc
@@ -698,8 +698,7 @@ void CodeGenWebGPU::VisitStmt_(const ForNode* op) {
}
void CodeGenWebGPU::VisitStmt_(const AssertStmtNode* op) {
- // skip assert
- PrintStmt(op->body);
+ // Asserts are not emitted for WebGPU, and AssertStmt is now a leaf, so there is no body to print either.
}
void CodeGenWebGPU::VisitStmt_(const WhileNode* op) {
diff --git a/src/target/spirv/codegen_spirv.cc
b/src/target/spirv/codegen_spirv.cc
index ff8053a309..114bdbc806 100644
--- a/src/target/spirv/codegen_spirv.cc
+++ b/src/target/spirv/codegen_spirv.cc
@@ -886,8 +886,7 @@ void CodeGenSPIRV::VisitStmt_(const AttrStmtNode* op) {
}
void CodeGenSPIRV::VisitStmt_(const AssertStmtNode* op) {
- With<arith::ConstraintContext> cctx(analyzer_.get(), op->condition);
- this->VisitStmt(op->body);
+ // AssertStmt is a leaf — no body to visit, and op->condition is no longer pushed into analyzer_ as a ConstraintContext.
}
void CodeGenSPIRV::VisitStmt_(const LetStmtNode* op) {
diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc
index 4f0dbaf121..5eaff0c314 100644
--- a/src/tir/ir/stmt.cc
+++ b/src/tir/ir/stmt.cc
@@ -104,7 +104,7 @@ TVM_FFI_STATIC_INIT_BLOCK() {
}
// AssertStmt
-AssertStmt::AssertStmt(PrimExpr condition, PrimExpr message, Stmt body, Span
span) {
+AssertStmt::AssertStmt(PrimExpr condition, PrimExpr message, Span span) {
TVM_FFI_ICHECK(condition.defined());
TVM_FFI_ICHECK(condition.dtype().is_bool())
<< "AssertStmt should have boolean condition, "
@@ -115,17 +115,15 @@ AssertStmt::AssertStmt(PrimExpr condition, PrimExpr
message, Stmt body, Span spa
ObjectPtr<AssertStmtNode> node = ffi::make_object<AssertStmtNode>();
node->condition = std::move(condition);
node->message = std::move(message);
- node->body = std::move(body);
node->span = std::move(span);
data_ = std::move(node);
}
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
- refl::GlobalDef().def("tir.AssertStmt",
- [](PrimExpr condition, StringImm message, Stmt body,
Span span) {
- return AssertStmt(condition, message, body, span);
- });
+ refl::GlobalDef().def("tir.AssertStmt", [](PrimExpr condition, StringImm
message, Span span) {
+ return AssertStmt(condition, message, span);
+ });
}
// For
diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc
index a3e59914c1..659056a3a2 100644
--- a/src/tir/ir/stmt_functor.cc
+++ b/src/tir/ir/stmt_functor.cc
@@ -81,7 +81,6 @@ void StmtVisitor::VisitStmt_(const IfThenElseNode* op) {
void StmtVisitor::VisitStmt_(const AssertStmtNode* op) {
this->VisitExpr(op->condition);
this->VisitExpr(op->message);
- this->VisitStmt(op->body);
}
void StmtVisitor::VisitStmt_(const SeqStmtNode* op) {
@@ -398,15 +397,13 @@ Stmt StmtMutator::VisitSeqStmt_(const SeqStmtNode* op,
bool flatten_before_visit
Stmt StmtMutator::VisitStmt_(const AssertStmtNode* op) {
PrimExpr condition = this->VisitExpr(op->condition);
PrimExpr message = this->VisitExpr(op->message);
- Stmt body = this->VisitStmt(op->body);
- if (condition.same_as(op->condition) && message.same_as(op->message) &&
body.same_as(op->body)) {
+ if (condition.same_as(op->condition) && message.same_as(op->message)) {
return ffi::GetRef<Stmt>(op);
} else {
auto n = CopyOnWrite(op);
n->condition = std::move(condition);
n->message = std::move(message);
- n->body = std::move(body);
return Stmt(n);
}
}
diff --git a/src/tir/ir/tir_visitor_with_path.cc
b/src/tir/ir/tir_visitor_with_path.cc
index 2e9968790f..0c03a3c2c8 100644
--- a/src/tir/ir/tir_visitor_with_path.cc
+++ b/src/tir/ir/tir_visitor_with_path.cc
@@ -252,7 +252,6 @@ void TIRVisitorWithPath::VisitStmt_(const IfThenElseNode*
op, AccessPath path) {
void TIRVisitorWithPath::VisitStmt_(const AssertStmtNode* op, AccessPath path)
{
Visit(op->condition, path->Attr("condition"));
Visit(op->message, path->Attr("message"));
- Visit(op->body, path->Attr("body"));
}
void TIRVisitorWithPath::VisitStmt_(const SeqStmtNode* op, AccessPath path) {
diff --git a/src/tir/transform/arg_binder.cc b/src/tir/transform/arg_binder.cc
index de44c1449d..bab4cdb3da 100644
--- a/src/tir/transform/arg_binder.cc
+++ b/src/tir/transform/arg_binder.cc
@@ -43,7 +43,7 @@ void BinderAddAssert(arith::Analyzer* ana, PrimExpr cond,
const std::string& arg
if (!is_one(scond)) {
std::ostringstream os;
os << "Argument " << arg_name << " has an unsatisfied constraint: " <<
cond;
- asserts->emplace_back(AssertStmt(scond, tvm::tir::StringImm(os.str()),
Evaluate(0)));
+ asserts->emplace_back(AssertStmt(scond, tvm::tir::StringImm(os.str())));
}
}
@@ -159,7 +159,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const
PrimExpr& device_type,
init_nest_.emplace_back(AssertStmt(
!Call(DataType::Bool(), builtin::isnullptr(), {handle}),
- tvm::tir::StringImm(arg_name + " is expected to have non-NULL DLTensor*
pointer"), nop));
+ tvm::tir::StringImm(arg_name + " is expected to have non-NULL DLTensor*
pointer")));
// dimension checks
PrimExpr v_ndim = TVMArrayGet(tvm_ndim_type, handle, builtin::kArrNDim);
@@ -179,7 +179,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const
PrimExpr& device_type,
std::ostringstream ndim_err_msg;
ndim_err_msg << arg_name << ".ndim is expected to equal " <<
buffer->shape.size();
auto msg = tvm::tir::StringImm(ndim_err_msg.str());
- init_nest_.emplace_back(AssertStmt(a_ndim == v_ndim, msg, nop));
+ init_nest_.emplace_back(AssertStmt(a_ndim == v_ndim, msg));
// type checks
std::ostringstream type_err_msg;
type_err_msg << arg_name << ".dtype is expected to be " << buffer->dtype;
@@ -192,7 +192,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const
PrimExpr& device_type,
if (!(buffer->dtype == DataType::Int(1) || buffer->dtype == DataType::Int(4)
||
buffer->dtype == DataType::UInt(4))) {
auto type_msg = tvm::tir::StringImm(type_err_msg.str());
- asserts_.emplace_back(AssertStmt(cond, type_msg, nop));
+ asserts_.emplace_back(AssertStmt(cond, type_msg));
}
// shape field
@@ -238,7 +238,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const
PrimExpr& device_type,
Stmt check = AssertStmt(
foldl([](PrimExpr a, PrimExpr b, Span span) { return logical_and(a,
b, span); },
const_true(1), conds),
- stride_msg, Evaluate(0));
+ stride_msg);
check = IfThenElse(Not(v_strides_is_null), check);
asserts_.emplace_back(SeqStmt({check, Evaluate(0)}));
}
@@ -314,9 +314,9 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const
PrimExpr& device_type,
}
return product;
}();
- asserts_.emplace_back(AssertStmt(
- alloc_size == 0 || !Call(DataType::Bool(), builtin::isnullptr(),
{vptr}),
- tvm::tir::StringImm(arg_name + " is expected to have non-NULL data
pointer"), nop));
+ asserts_.emplace_back(
+ AssertStmt(alloc_size == 0 || !Call(DataType::Bool(),
builtin::isnullptr(), {vptr}),
+ tvm::tir::StringImm(arg_name + " is expected to have
non-NULL data pointer")));
def_handle_dtype_.Set(vptr, tir::TypeAnnotation(buffer->dtype));
// mark alignment of external bufs
diff --git a/src/tir/transform/ir_utils.cc b/src/tir/transform/ir_utils.cc
index f661722e2f..4e8a590a7d 100644
--- a/src/tir/transform/ir_utils.cc
+++ b/src/tir/transform/ir_utils.cc
@@ -67,11 +67,8 @@ Stmt MergeNest(const std::vector<Stmt>& nest, Stmt body) {
TVM_FFI_ICHECK(n->size() != 0 && is_no_op(n->seq[n->size() - 1]));
n->seq.Set(n->size() - 1, body);
body = Stmt(n);
- } else if (const auto* assert_ = s.as<AssertStmtNode>()) {
- auto n = ffi::make_object<AssertStmtNode>(*assert_);
- TVM_FFI_ICHECK(is_no_op(n->body));
- n->body = body;
- body = Stmt(n);
+ } else if (s.as<AssertStmtNode>()) {
+ body = SeqStmt({s, body});
} else if (const auto* alloc = s.as<AllocateNode>()) {
auto n = ffi::make_object<AllocateNode>(*alloc);
TVM_FFI_ICHECK(is_no_op(n->body));
diff --git a/src/tir/transform/make_packed_api.cc
b/src/tir/transform/make_packed_api.cc
index 4c35b3fdd8..af63d9ba54 100644
--- a/src/tir/transform/make_packed_api.cc
+++ b/src/tir/transform/make_packed_api.cc
@@ -168,12 +168,12 @@ class SubroutineCallRewriter : public StmtExprMutator {
} // namespace
inline Stmt MakeAssertEQ(PrimExpr lhs, PrimExpr rhs, std::string msg) {
- return AssertStmt(lhs == rhs, tvm::tir::StringImm(msg), Evaluate(0));
+ return AssertStmt(lhs == rhs, tvm::tir::StringImm(msg));
}
inline Stmt MakeAssertNotNull(PrimExpr ptr, std::string msg) {
Call isnull(DataType::Bool(), builtin::isnullptr(), {ptr});
- return AssertStmt(!isnull, tvm::tir::StringImm(msg), Evaluate(0));
+ return AssertStmt(!isnull, tvm::tir::StringImm(msg));
}
/* \brief Return the global_symbol of the function, if it should be updated
@@ -300,7 +300,7 @@ PrimFunc MakePackedAPI(PrimFunc func) {
type_index ==
ffi::TypeIndex::kTVMFFIOpaquePtr ||
type_index ==
ffi::TypeIndex::kTVMFFIDLTensorPtr ||
type_index >=
ffi::TypeIndex::kTVMFFIStaticObjectBegin,
- tvm::tir::StringImm(msg.str()), nop));
+ tvm::tir::StringImm(msg.str())));
// if type_index is Tensor, we need to add the offset of the DLTensor
header
// which always equals 16 bytes, this ensures that T.handle always shows
up as a DLTensor*
const int64_t object_cell_offset = sizeof(TVMFFIObject);
@@ -316,7 +316,7 @@ PrimFunc MakePackedAPI(PrimFunc func) {
msg << name_hint << ": Expect arg[" << i << "] to be boolean";
seq_init.emplace_back(AssertStmt(
type_index == ffi::TypeIndex::kTVMFFIBool || type_index ==
ffi::TypeIndex::kTVMFFIInt,
- tvm::tir::StringImm(msg.str()), nop));
+ tvm::tir::StringImm(msg.str())));
arg_value = Cast(DataType::Bool(), f_load_arg_value(DataType::Int(64),
i));
} else if (dtype.is_int() || dtype.is_uint()) {
@@ -324,7 +324,7 @@ PrimFunc MakePackedAPI(PrimFunc func) {
msg << name_hint << ": Expect arg[" << i << "] to be int";
seq_init.emplace_back(AssertStmt(
type_index == ffi::TypeIndex::kTVMFFIInt || type_index ==
ffi::TypeIndex::kTVMFFIBool,
- tvm::tir::StringImm(msg.str()), nop));
+ tvm::tir::StringImm(msg.str())));
arg_value = f_load_arg_value(param.dtype(), i);
} else {
TVM_FFI_ICHECK(dtype.is_float());
@@ -333,7 +333,7 @@ PrimFunc MakePackedAPI(PrimFunc func) {
seq_init.emplace_back(AssertStmt(type_index ==
ffi::TypeIndex::kTVMFFIFloat ||
type_index ==
ffi::TypeIndex::kTVMFFIInt ||
type_index ==
ffi::TypeIndex::kTVMFFIBool,
- tvm::tir::StringImm(msg.str()), nop));
+ tvm::tir::StringImm(msg.str())));
// use select so we can also handle int conversion to bool
arg_value = tir::Select(
type_index == ffi::TypeIndex::kTVMFFIFloat,
diff --git a/src/tir/transform/skip_assert.cc b/src/tir/transform/skip_assert.cc
index b2c473c97c..8e997bc9ee 100644
--- a/src/tir/transform/skip_assert.cc
+++ b/src/tir/transform/skip_assert.cc
@@ -29,9 +29,8 @@ namespace tir {
class AssertSkipper : public StmtMutator {
public:
Stmt VisitStmt_(const AssertStmtNode* op) final {
- Stmt stmt = StmtMutator::VisitStmt_(op);
- op = stmt.as<AssertStmtNode>();
- return op->body;
+ // AssertStmt is a leaf — just remove it.
+ return Evaluate(0);
}
};
diff --git a/src/tir/transform/split_host_device.cc
b/src/tir/transform/split_host_device.cc
index 130cc177f0..43b6701f1f 100644
--- a/src/tir/transform/split_host_device.cc
+++ b/src/tir/transform/split_host_device.cc
@@ -104,7 +104,7 @@ class HostDeviceSplitter : public StmtMutator {
Var kernel_error_code("kernel_error_code", success->dtype);
Call kernel_call(success->dtype, kernel_symbol_global, args);
AssertStmt assert_success(kernel_error_code == success,
- StringImm("Error executing compute kernel"),
Evaluate(0));
+ StringImm("Error executing compute kernel"));
LetStmt let_check(kernel_error_code, kernel_call, assert_success);
return let_check;
diff --git a/tests/python/tir-base/test_tir_constructor.py
b/tests/python/tir-base/test_tir_constructor.py
index 5347146c05..d05201a143 100644
--- a/tests/python/tir-base/test_tir_constructor.py
+++ b/tests/python/tir-base/test_tir_constructor.py
@@ -141,9 +141,8 @@ def test_stmt_constructor():
assert isinstance(x, tvm.tir.AttrStmt)
assert x.value.value == 1
- x = tvm.tir.AssertStmt(tvm.tir.const(1, "bool"),
tvm.runtime.convert("hellow"), nop)
+ x = tvm.tir.AssertStmt(tvm.tir.const(1, "bool"),
tvm.runtime.convert("hellow"))
assert isinstance(x, tvm.tir.AssertStmt)
- assert x.body == nop
x = tvm.tir.For(tvm.tir.Var("x", "int32"), 0, 10, tvm.tir.ForKind.SERIAL,
nop)
assert isinstance(x, tvm.tir.For)
diff --git a/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py
b/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py
index b65c0ec23a..71e8f37803 100644
--- a/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py
+++ b/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py
@@ -298,8 +298,13 @@ def test_ir_builder_tir_assert():
# the assert generated by IRBuilder
assert_actual = ib.get()
- # the expected assert statement
- assert_expected = tir.AssertStmt(T.int32() == 0, tir.StringImm("a is 0"),
tir.Evaluate(0))
+ # AssertStmt is a leaf. The frame emits the assert and then the body stmts
as siblings.
+ assert_expected = tir.SeqStmt(
+ [
+ tir.AssertStmt(T.int32() == 0, tir.StringImm("a is 0")),
+ tir.Evaluate(0),
+ ]
+ )
# Check if the generated ir is expected
assert_structural_equal(assert_actual, assert_expected, map_free_vars=True)
diff --git a/tests/python/tvmscript/test_tvmscript_printer_tir.py
b/tests/python/tvmscript/test_tvmscript_printer_tir.py
index c2091ae5e6..93ddff0fd4 100644
--- a/tests/python/tvmscript/test_tvmscript_printer_tir.py
+++ b/tests/python/tvmscript/test_tvmscript_printer_tir.py
@@ -289,8 +289,8 @@ def test_assert_stmt():
_assert_print(
obj,
"""
-with T.Assert(T.bool(True), "assertion"):
- T.evaluate(0)
+assert T.bool(True), "assertion"
+T.evaluate(0)
""",
)