This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch refactor-s1
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit 1a120df445f22e3bf9ba477a13f18a5a25c75985
Author: tqchen <[email protected]>
AuthorDate: Wed Apr 16 12:13:05 2025 -0400

    Fix MS trace
---
 src/meta_schedule/postproc/rewrite_layout.cc             |  2 +-
 .../postproc/rewrite_parallel_vectorize_unroll.cc        | 16 ++++++++--------
 src/tir/ir/stmt.cc                                       |  9 ++-------
 .../meta_schedule/test_meta_schedule_trace_apply.py      |  4 ++--
 4 files changed, 13 insertions(+), 18 deletions(-)

diff --git a/src/meta_schedule/postproc/rewrite_layout.cc b/src/meta_schedule/postproc/rewrite_layout.cc
index 81681528fb..6c8079fdda 100644
--- a/src/meta_schedule/postproc/rewrite_layout.cc
+++ b/src/meta_schedule/postproc/rewrite_layout.cc
@@ -183,7 +183,7 @@ bool RewriteLayout(const Schedule& sch) {
   std::vector<std::pair<StmtSRef, String>> results;
   auto add_layout_rewrite_block = [&sch](BlockRV consumer_block_rv, int 
buffer_index) {
     BlockRV rewrite_block_rv = sch->CacheRead(consumer_block_rv, buffer_index, 
"global");
-    sch->Annotate(rewrite_block_rv, 
attr::meta_schedule_layout_rewrite_preproc, const_true());
+    sch->Annotate(rewrite_block_rv, 
attr::meta_schedule_layout_rewrite_preproc, true);
   };
 
   for (const auto& [g_var, base_func] : sch->mod()->functions) {
diff --git a/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc b/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc
index 9779603f90..965cc8baef 100644
--- a/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc
+++ b/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc
@@ -106,23 +106,23 @@ bool ParseAnnotation(const Block& block, ParsedAnnotation* parsed) {
   for (const auto& ann : block->annotations) {
     if (ann.first == attr::meta_schedule_parallel) {
       found = true;
-      if (const auto* imm = ann.second.as<tir::IntImmNode>()) {
-        parsed->max_parallel_extent = imm->value;
+      if (auto opt_int_imm = ann.second.as<IntImm>()) {
+        parsed->max_parallel_extent = (*opt_int_imm)->value;
       }
     } else if (ann.first == attr::meta_schedule_vectorize) {
       found = true;
-      if (const auto* imm = ann.second.as<tir::IntImmNode>()) {
-        parsed->max_vectorize_extent = imm->value;
+      if (auto opt_int_imm = ann.second.as<IntImm>()) {
+        parsed->max_vectorize_extent = (*opt_int_imm)->value;
       }
     } else if (ann.first == attr::meta_schedule_unroll_explicit) {
       found = true;
-      if (const auto* imm = ann.second.as<tir::IntImmNode>()) {
-        parsed->unroll_explicit = imm->value;
+      if (auto opt_int_imm = ann.second.as<IntImm>()) {
+        parsed->unroll_explicit = (*opt_int_imm)->value;
       }
     } else if (ann.first == attr::meta_schedule_unroll_implicit) {
       found = true;
-      if (const auto* imm = ann.second.as<tir::IntImmNode>()) {
-        parsed->unroll_implicit = imm->value;
+      if (auto opt_int_imm = ann.second.as<IntImm>()) {
+        parsed->unroll_implicit = (*opt_int_imm)->value;
       }
     }
   }
diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc
index 36b14e9eb5..0a7817c02a 100644
--- a/src/tir/ir/stmt.cc
+++ b/src/tir/ir/stmt.cc
@@ -102,13 +102,8 @@ AssertStmt::AssertStmt(PrimExpr condition, PrimExpr message, Stmt body, Span spa
 TVM_REGISTER_NODE_TYPE(AssertStmtNode);
 
 TVM_REGISTER_GLOBAL("tir.AssertStmt")
-    .set_body_typed([](PrimExpr condition, ObjectRef message, Stmt body, Span 
span) {
-      if (const auto* str = message.as<StringObj>()) {
-        auto msg = StringImm(str->bytes.data);
-        return AssertStmt(condition, msg, body, span);
-      } else {
-        return AssertStmt(condition, Downcast<PrimExpr>(message), body, span);
-      }
+    .set_body_typed([](PrimExpr condition, StringImm message, Stmt body, Span 
span) {
+      return AssertStmt(condition, message, body, span);
     });
 
 // For
diff --git a/tests/python/meta_schedule/test_meta_schedule_trace_apply.py b/tests/python/meta_schedule/test_meta_schedule_trace_apply.py
index 5f0583c890..bbe086c34e 100644
--- a/tests/python/meta_schedule/test_meta_schedule_trace_apply.py
+++ b/tests/python/meta_schedule/test_meta_schedule_trace_apply.py
@@ -1864,7 +1864,7 @@ def test_dense_add_cpu():
             ),
             pad_value=None,
         )
-        sch.annotate(block_or_loop=b59, 
ann_key="meta_schedule.layout_rewrite_preproc", ann_val=1)
+        sch.annotate(block_or_loop=b59, 
ann_key="meta_schedule.layout_rewrite_preproc", ann_val=True)
 
     verify(Dense, apply_anchor_trace, DenseAdd, "llvm", DenseAdd_scheduled_cpu)
 
@@ -1930,7 +1930,7 @@ def test_dense_add_cpu_no_write_cache():
             ),
             pad_value=None,
         )
-        sch.annotate(block_or_loop=b50, 
ann_key="meta_schedule.layout_rewrite_preproc", ann_val=1)
+        sch.annotate(block_or_loop=b50, 
ann_key="meta_schedule.layout_rewrite_preproc", ann_val=True)
 
     verify(Dense, apply_trace, DenseAdd, "llvm", DenseAdd_cpu_no_write_cache)
 

Reply via email to