This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch refactor-s1 in repository https://gitbox.apache.org/repos/asf/tvm.git
commit 1a120df445f22e3bf9ba477a13f18a5a25c75985 Author: tqchen <[email protected]> AuthorDate: Wed Apr 16 12:13:05 2025 -0400 Fix MS trace --- src/meta_schedule/postproc/rewrite_layout.cc | 2 +- .../postproc/rewrite_parallel_vectorize_unroll.cc | 16 ++++++++-------- src/tir/ir/stmt.cc | 9 ++------- .../meta_schedule/test_meta_schedule_trace_apply.py | 4 ++-- 4 files changed, 13 insertions(+), 18 deletions(-) diff --git a/src/meta_schedule/postproc/rewrite_layout.cc b/src/meta_schedule/postproc/rewrite_layout.cc index 81681528fb..6c8079fdda 100644 --- a/src/meta_schedule/postproc/rewrite_layout.cc +++ b/src/meta_schedule/postproc/rewrite_layout.cc @@ -183,7 +183,7 @@ bool RewriteLayout(const Schedule& sch) { std::vector<std::pair<StmtSRef, String>> results; auto add_layout_rewrite_block = [&sch](BlockRV consumer_block_rv, int buffer_index) { BlockRV rewrite_block_rv = sch->CacheRead(consumer_block_rv, buffer_index, "global"); - sch->Annotate(rewrite_block_rv, attr::meta_schedule_layout_rewrite_preproc, const_true()); + sch->Annotate(rewrite_block_rv, attr::meta_schedule_layout_rewrite_preproc, true); }; for (const auto& [g_var, base_func] : sch->mod()->functions) { diff --git a/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc b/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc index 9779603f90..965cc8baef 100644 --- a/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc +++ b/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc @@ -106,23 +106,23 @@ bool ParseAnnotation(const Block& block, ParsedAnnotation* parsed) { for (const auto& ann : block->annotations) { if (ann.first == attr::meta_schedule_parallel) { found = true; - if (const auto* imm = ann.second.as<tir::IntImmNode>()) { - parsed->max_parallel_extent = imm->value; + if (auto opt_int_imm = ann.second.as<IntImm>()) { + parsed->max_parallel_extent = (*opt_int_imm)->value; } } else if (ann.first == attr::meta_schedule_vectorize) { found = true; - if (const auto* imm = ann.second.as<tir::IntImmNode>()) { - parsed->max_vectorize_extent = imm->value; + if (auto opt_int_imm = ann.second.as<IntImm>()) { + parsed->max_vectorize_extent = (*opt_int_imm)->value; } } else if (ann.first == attr::meta_schedule_unroll_explicit) { found = true; - if (const auto* imm = ann.second.as<tir::IntImmNode>()) { - parsed->unroll_explicit = imm->value; + if (auto opt_int_imm = ann.second.as<IntImm>()) { + parsed->unroll_explicit = (*opt_int_imm)->value; } } else if (ann.first == attr::meta_schedule_unroll_implicit) { found = true; - if (const auto* imm = ann.second.as<tir::IntImmNode>()) { - parsed->unroll_implicit = imm->value; + if (auto opt_int_imm = ann.second.as<IntImm>()) { + parsed->unroll_implicit = (*opt_int_imm)->value; } } } diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index 36b14e9eb5..0a7817c02a 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -102,13 +102,8 @@ AssertStmt::AssertStmt(PrimExpr condition, PrimExpr message, Stmt body, Span spa TVM_REGISTER_NODE_TYPE(AssertStmtNode); TVM_REGISTER_GLOBAL("tir.AssertStmt") - .set_body_typed([](PrimExpr condition, ObjectRef message, Stmt body, Span span) { - if (const auto* str = message.as<StringObj>()) { - auto msg = StringImm(str->bytes.data); - return AssertStmt(condition, msg, body, span); - } else { - return AssertStmt(condition, Downcast<PrimExpr>(message), body, span); - } + .set_body_typed([](PrimExpr condition, StringImm message, Stmt body, Span span) { + return AssertStmt(condition, message, body, span); }); // For diff --git a/tests/python/meta_schedule/test_meta_schedule_trace_apply.py b/tests/python/meta_schedule/test_meta_schedule_trace_apply.py index 5f0583c890..bbe086c34e 100644 --- a/tests/python/meta_schedule/test_meta_schedule_trace_apply.py +++ b/tests/python/meta_schedule/test_meta_schedule_trace_apply.py @@ -1864,7 +1864,7 @@ def test_dense_add_cpu(): ), pad_value=None, ) - sch.annotate(block_or_loop=b59, ann_key="meta_schedule.layout_rewrite_preproc", ann_val=1) + sch.annotate(block_or_loop=b59, ann_key="meta_schedule.layout_rewrite_preproc", ann_val=True) verify(Dense, apply_anchor_trace, DenseAdd, "llvm", DenseAdd_scheduled_cpu) @@ -1930,7 +1930,7 @@ def test_dense_add_cpu_no_write_cache(): ), pad_value=None, ) - sch.annotate(block_or_loop=b50, ann_key="meta_schedule.layout_rewrite_preproc", ann_val=1) + sch.annotate(block_or_loop=b50, ann_key="meta_schedule.layout_rewrite_preproc", ann_val=True) verify(Dense, apply_trace, DenseAdd, "llvm", DenseAdd_cpu_no_write_cache)
