This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 8fe6c8ad79 [REFACTOR][S-TIR] Lift STIR-only attributes out of
tir::attr namespace (#18816)
8fe6c8ad79 is described below
commit 8fe6c8ad79f98055c4e6b489ce9c024539fae233
Author: Tianqi Chen <[email protected]>
AuthorDate: Tue Feb 24 18:04:30 2026 -0500
[REFACTOR][S-TIR] Lift STIR-only attributes out of tir::attr namespace
(#18816)
---
include/tvm/s_tir/stmt.h | 199 +++++++++++++++++++++
include/tvm/tir/stmt.h | 150 ----------------
.../transform/split_layout_rewrite_preproc.cc | 5 +-
src/s_tir/meta_schedule/mutator/mutate_parallel.cc | 3 +-
.../meta_schedule/mutator/mutate_tile_size.cc | 3 +-
src/s_tir/meta_schedule/mutator/mutate_unroll.cc | 5 +-
.../postproc/rewrite_cooperative_fetch.cc | 7 +-
src/s_tir/meta_schedule/postproc/rewrite_layout.cc | 5 +-
.../postproc/rewrite_parallel_vectorize_unroll.cc | 17 +-
.../postproc/rewrite_reduction_block.cc | 13 +-
.../meta_schedule/postproc/rewrite_tensorize.cc | 5 +-
.../meta_schedule/postproc/verify_gpu_code.cc | 7 +-
.../meta_schedule/schedule_rule/add_rfactor.cc | 3 +-
.../meta_schedule/schedule_rule/auto_inline.cc | 5 +-
.../schedule_rule/multi_level_tiling.cc | 15 +-
.../multi_level_tiling_tensor_core.cc | 37 ++--
.../multi_level_tiling_with_intrin.cc | 3 +-
.../schedule_rule/parallel_vectorize_unroll.cc | 9 +-
.../schedule_rule/random_compute_location.cc | 5 +-
src/s_tir/schedule/analysis/analysis.cc | 3 +-
.../schedule/primitive/annotate_buffer_access.cc | 6 +-
src/s_tir/schedule/primitive/block_annotate.cc | 9 +-
src/s_tir/schedule/primitive/compute_inline.cc | 6 +-
src/s_tir/schedule/primitive/read_write_at.cc | 6 +-
src/s_tir/transform/annotate_irregular_loop.cc | 5 +-
src/s_tir/transform/compact_buffer_region.cc | 5 +-
src/s_tir/transform/inject_software_pipeline.cc | 17 +-
src/s_tir/transform/lower_opaque_block.cc | 5 +-
.../manifest_shared_memory_local_stage.cc | 5 +-
src/s_tir/transform/memhammer_lower_auto_copy.cc | 3 +-
.../remove_weight_layout_rewrite_block.cc | 3 +-
src/s_tir/transform/storage_access.cc | 3 +-
src/te/operation/create_primfunc.cc | 5 +-
src/tir/ir/script/script_complete.cc | 5 +-
src/tir/transform/ir_utils.cc | 5 +-
35 files changed, 336 insertions(+), 251 deletions(-)
diff --git a/include/tvm/s_tir/stmt.h b/include/tvm/s_tir/stmt.h
new file mode 100644
index 0000000000..2e6b2278f0
--- /dev/null
+++ b/include/tvm/s_tir/stmt.h
@@ -0,0 +1,199 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/s_tir/stmt.h
+ * \brief S-TIR (Schedulable TIR) statement attribute declarations.
+ *
+ * This file contains attribute keys that are specific to the schedulable TIR
+ * (S-TIR) layer, including meta_schedule annotations and schedule primitive /
+ * SBlock annotations.
+ */
+#ifndef TVM_S_TIR_STMT_H_
+#define TVM_S_TIR_STMT_H_
+
+namespace tvm {
+namespace s_tir {
+namespace attr {
+
+// -----------------------------------------------------------------------
+// meta_schedule annotations
+// -----------------------------------------------------------------------
+
+/*! \brief Mark the tiling structure of blocks that are applied by rule
Multi-Level-Tiling */
+constexpr const char* meta_schedule_tiling_structure =
"meta_schedule.tiling_structure";
+
+/*!
+ * \brief Mark that the loop should be further split and bound to environment
threads to enable
+ * cooperative fetching.
+ */
+constexpr const char* meta_schedule_cooperative_fetch =
"meta_schedule.cooperative_fetch";
+
+/*! \brief The allowed range of thread extent in thread bindings */
+constexpr const char* meta_schedule_thread_extent_low_inclusive =
+ "meta_schedule.thread_extent_low_inclusive";
+
+/*! \brief The allowed range of thread extent in thread bindings */
+constexpr const char* meta_schedule_thread_extent_high_inclusive =
+ "meta_schedule.thread_extent_high_inclusive";
+
+/*! \brief Mark the block whose producer needs to be applied by rule
Random-Compute-Location */
+constexpr const char* meta_schedule_random_compute_producer =
+ "meta_schedule.random_compute_producer";
+
+/*! \brief Mark auto-parallel setting on the block. */
+constexpr const char* meta_schedule_parallel = "meta_schedule.parallel";
+
+/*! \brief Mark auto-vectorize setting on the block. */
+constexpr const char* meta_schedule_vectorize = "meta_schedule.vectorize";
+
+/*! \brief Mark auto-unroll setting on the block. */
+constexpr const char* meta_schedule_unroll_explicit =
"meta_schedule.unroll_explicit";
+
+/*! \brief Mark auto-unroll setting on the block. */
+constexpr const char* meta_schedule_unroll_implicit =
"meta_schedule.unroll_implicit";
+
+/*! \brief Mark that a block should be further rewritten using tensorization.
*/
+constexpr const char* meta_schedule_auto_tensorize =
"meta_schedule.auto_tensorize";
+
+/*! \brief Mark that a block is a preprocessor block for layout rewrite. */
+constexpr const char* meta_schedule_layout_rewrite_preproc =
"meta_schedule.layout_rewrite_preproc";
+
+/*!
+ * \brief Mark that the init statement of a block should be further rewritten
using tensorization.
+ */
+constexpr const char* meta_schedule_auto_tensorize_init =
"meta_schedule.auto_tensorize_init";
+
+/*! \brief Mark that tensor core is enabled in the PrimExpr */
+constexpr const char* meta_schedule_tensor_core_enabled =
"meta_schedule.tensor_core_enabled";
+
+/*!
+ * \brief Mark a block as generated by cache_read or cache_write block.
+ * 0 means cache_read; 1 means cache_write.
+ * \sa meta_schedule_cache_type_read
+ * \sa meta_schedule_cache_type_write
+ */
+constexpr const char* meta_schedule_cache_type = "meta_schedule.cache_type";
+
+/*! \sa meta_schedule_cache_type */
+constexpr const int meta_schedule_cache_type_read = 0;
+
+/*! \sa meta_schedule_cache_type */
+constexpr const int meta_schedule_cache_type_write = 1;
+
+/*! \brief Mark that a block is disallowed in auto inline. */
+constexpr const char* meta_schedule_inline_rule = "meta_schedule.inline_rule";
+
+// -----------------------------------------------------------------------
+// Schedule primitive / SBlock annotations
+// -----------------------------------------------------------------------
+
+/*!
+ * \brief Mark whether the script-completer need to fill in missing access
region
+ * during script parsing.
+ * \note The result should be an integer mask with range [0, 4).
+ * if (mask & 1) the read region should be detected,
+ * if (mask & 2) the write region should be detected.
+ */
+constexpr const char* script_parsing_detect_access =
"tir.script_parsing_detect_access";
+
+/*!
+ * \brief Mark that the block need to add predicate for block var bounds
during lowering
+ */
+constexpr const char* require_block_var_bound_predicate =
"require_bound_predicate";
+
+/*! \brief Mark the stage of a statement in the software pipeline */
+constexpr const char* software_pipeline_stage = "software_pipeline_stage";
+
+/*! \brief Mark the order of a statement in the software pipeline */
+constexpr const char* software_pipeline_order = "software_pipeline_order";
+
+/*! \brief List stages in the software pipeline that should run asynchronously
+ * \note All statements in the provided stages are assumed to have asynchronous
+ * semantics (e.g. CUDA async global to shared memory copy).
+ */
+constexpr const char* software_pipeline_async_stages =
"software_pipeline_async_stages";
+
+/*! \brief Mark the buffers that are accessed as constants and whose layout
can be transformed. */
+constexpr const char* layout_free_buffers = "layout_free_buffers";
+
+/*! \brief Mark the local stage for the shared memory access should be added.
*/
+constexpr const char* manifest_shared_memory_local_stage =
"tir.manifest_shared_memory_local_stage";
+
+/*!
+ * \brief Mark alignment of buffer dimension
+ * stmt.node is Tensor
+ * stmt.value is tvm_tuple(dim, align, offset)
+ * This gives hint to require stride of dim to be k * align + offset.
+ */
+constexpr const char* buffer_dim_align = "buffer_dim_align";
+
+/*! \brief Mark that a block has an explicitly specified read region.
+ * This is used to override the default read region inference in TIR.
+ */
+constexpr const char* explicit_read_region = "explicit_read_region";
+
+/*! \brief Mark that a block has an explicitly specified write region.
+ * This is used to override the default write region inference in TIR.
+ */
+constexpr const char* explicit_write_region = "explicit_write_region";
+
+/*! \brief Mark a ForNode representing an irregular loop of non-structural
control flow edges. */
+constexpr const char* irregular_loop_mark = "irregular_loop_mark";
+
+/*! \brief Mark auto copy for memhammer */
+constexpr const char* auto_copy = "auto_copy";
+
+/*! \brief Mark local stage constraint on data copy */
+constexpr const char* local_stage = "local_stage";
+
+/*! \brief Mark vectorization length constraint on block */
+constexpr const char* vector_bytes = "vector_bytes";
+
+/*!
+ * \brief Mark that a block is executed by a warp. This implies the extent of
threadIdx.x is
+ * warp size.
+ */
+constexpr const char* warp_execution = "warp_execution";
+
+/*!
+ * \brief Marks the layout transforms to be used for a tensor.
+ *
+ * Only applies to a DataProducer, as it should be made part of the
+ * PrimFunc attributes for TIR.
+ */
+constexpr const char* layout_transforms = "layout_transforms";
+
+/*!
+ * \brief Marks the physical axis separators
+ *
+ * Only applies to a DataProducer, as it should be made part of the
+ * Buffer definition in a PrimFunc. See `BufferNode::axis_separators`
+ * for more details.
+ */
+constexpr const char* axis_separators = "axis_separators";
+
+/*!
+ * \brief Mark that the kernel is hand threaded and doesn't need syncs inserted
+ */
+constexpr const char* hand_threaded = "hand_threaded";
+
+} // namespace attr
+} // namespace s_tir
+} // namespace tvm
+#endif // TVM_S_TIR_STMT_H_
diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h
index a26ab2dc32..4fcb91403f 100644
--- a/include/tvm/tir/stmt.h
+++ b/include/tvm/tir/stmt.h
@@ -958,21 +958,6 @@ constexpr const char* pragma_import_c = "pragma_import_c";
constexpr const char* pragma_import_llvm = "pragma_import_llvm";
/*! \brief Try to modify the AST to support Tensor Core */
constexpr const char* pragma_tensor_core = "pragma_tensor_core";
-/*!
- * \brief Marks the layout transforms to be used for a tensor.
- *
- * Only applies to a DataProducer, as it should be made part of the
- * PrimFunc attributes for TIR.
- */
-constexpr const char* layout_transforms = "layout_transforms";
-/*!
- * \brief Marks the physical axis separators
- *
- * Only applies to a DataProducer, as it should be made part of the
- * Buffer definition in a PrimFunc. See `BufferNode::axis_separators`
- * for more details.
- */
-constexpr const char* axis_separators = "axis_separators";
/*!
* \brief Marks production of double buffer data
*/
@@ -985,13 +970,6 @@ constexpr const char* double_buffer_write =
"double_buffer_write";
constexpr const char* scan_update_scope = "scan_update_scope";
/*! \brief Mark of scan init scope */
constexpr const char* scan_init_scope = "scan_init_scope";
-/*!
- * \brief Mark alignment of buffer dimension
- * stmt.node is Tensor
- * stmt.value is tvm_tuple(dim, align, offset)
- * This gives hint to require stride of dim to be k * align + offset.
- */
-constexpr const char* buffer_dim_align = "buffer_dim_align";
/*! \brief Mark stores/loads with theirs bounds. */
constexpr const char* buffer_bound = "buffer_bound";
/*!
@@ -1059,139 +1037,11 @@ constexpr const char* fragment_shape =
"fragment_shape";
*/
constexpr const char* fragment_layout = "fragment_layout";
-/*!
- * \brief Mark that the kernel is hand threaded and doesn't need syncs inserted
- */
-constexpr const char* hand_threaded = "hand_threaded";
-
-/*!
- * \brief Mark whether the script-completer need to fill in missing access
region
- * during script parsing.
- * \note The result should be a integer mask with range [0, 4).
- * if (mask & 1) the read region should be detected,
- * if (mask & 2) the write region should be detected.
- */
-constexpr const char* script_parsing_detect_access =
"tir.script_parsing_detect_access";
-
/*!
* \brief Mark that the loop should be partitioned.
*/
constexpr const char* pragma_loop_partition_hint =
"pragma_loop_partition_hint";
-/*! \brief Mark the stage of a statement in the software pipeline */
-constexpr const char* software_pipeline_stage = "software_pipeline_stage";
-
-/*! \brief Mark the order of a statement in the software pipeline */
-constexpr const char* software_pipeline_order = "software_pipeline_order";
-
-/*! \brief List stages in the software pipeline that should run asynchronously
- * \note All statements in the provided stages are assumed to have asynchronous
- * semantics (e.g. CUDA async global to shared memory copy).
- */
-constexpr const char* software_pipeline_async_stages =
"software_pipeline_async_stages";
-
-/*! \brief Mark the buffers which is const access and can be transformed
layout. */
-constexpr const char* layout_free_buffers = "layout_free_buffers";
-
-/*! \brief Mark the local stage for the shared memory access should be added.
*/
-constexpr const char* manifest_shared_memory_local_stage =
"tir.manifest_shared_memory_local_stage";
-
-/*! \brief Mark the tiling structure of blocks that are applied by rule
Multi-Level-Tiling */
-constexpr const char* meta_schedule_tiling_structure =
"meta_schedule.tiling_structure";
-
-/*!
- * \brief Mark that the loop should be further skip and bound to environment
threads to enable
- * cooperative fetching.
- */
-constexpr const char* meta_schedule_cooperative_fetch =
"meta_schedule.cooperative_fetch";
-
-/*! \brief The allowed range of thread extent in thread bindings */
-constexpr const char* meta_schedule_thread_extent_low_inclusive =
- "meta_schedule.thread_extent_low_inclusive";
-
-/*! \brief The allowed range of thread extent in thread bindings */
-constexpr const char* meta_schedule_thread_extent_high_inclusive =
- "meta_schedule.thread_extent_high_inclusive";
-
-/*! \brief Mark the block whose producer needs to be applied by rule
Random-Compute-Location */
-constexpr const char* meta_schedule_random_compute_producer =
- "meta_schedule.random_compute_producer";
-
-/*! \brief Mark auto-parallel setting on the block. */
-constexpr const char* meta_schedule_parallel = "meta_schedule.parallel";
-
-/*! \brief Mark auto-vectorize setting on the block. */
-constexpr const char* meta_schedule_vectorize = "meta_schedule.vectorize";
-
-/*! \brief Mark auto-unroll setting on the block. */
-constexpr const char* meta_schedule_unroll_explicit =
"meta_schedule.unroll_explicit";
-
-/*! \brief Mark auto-unroll setting on the block. */
-constexpr const char* meta_schedule_unroll_implicit =
"meta_schedule.unroll_implicit";
-
-/*! \brief Mark that a block should be further rewritten using tensorization.
*/
-constexpr const char* meta_schedule_auto_tensorize =
"meta_schedule.auto_tensorize";
-
-/*! \brief Mark that a block is a preprocessor block for layout rewrite. */
-constexpr const char* meta_schedule_layout_rewrite_preproc =
"meta_schedule.layout_rewrite_preproc";
-/*!
- * \brief Mark that the init statement of a block should be further rewritten
using tensorization.
- */
-constexpr const char* meta_schedule_auto_tensorize_init =
"meta_schedule.auto_tensorize_init";
-
-/*!
- * \brief Mark that the block need to add predicate for block var bounds
during lowering
- */
-constexpr const char* require_block_var_bound_predicate =
"require_bound_predicate";
-
-/*! \brief Mark that tensor core is enabled in the PrimExpr */
-constexpr const char* meta_schedule_tensor_core_enabled =
"meta_schedule.tensor_core_enabled";
-
-/*!
- * \brief Mark a block as generated by cache_read or cache_write block.
- * 0 means cache_read; 1 means cache_write.
- * \sa meta_schedule_cache_type_read
- * \sa meta_schedule_cache_type_write
- */
-constexpr const char* meta_schedule_cache_type = "meta_schedule.cache_type";
-
-/*! \sa meta_schedule_cache_type */
-constexpr const int meta_schedule_cache_type_read = 0;
-
-/*! \sa meta_schedule_cache_type */
-constexpr const int meta_schedule_cache_type_write = 1;
-
-/*! \brief Mark auto copy for memhammer */
-constexpr const char* auto_copy = "auto_copy";
-
-/*! \brief Mark local stage constraint on data copy */
-constexpr const char* local_stage = "local_stage";
-
-/*! \brief Mark vectorization length constraint on block */
-constexpr const char* vector_bytes = "vector_bytes";
-
-/*!
- * \brief Mark that a block is executed by a warp. This implies the extend of
threadIdx.x is
- * warp size.
- */
-constexpr const char* warp_execution = "warp_execution";
-
-/*! \brief Mark that a block is disallowed in auto inline. */
-constexpr const char* meta_schedule_inline_rule = "meta_schedule.inline_rule";
-
-/*! \brief Mark that a block has an explicitly specified read region.
- * This is used to override the default read region inference in TIR.
- */
-constexpr const char* explicit_read_region = "explicit_read_region";
-
-/*! \brief Mark that a block has an explicitly specified write region.
- * This is used to override the default write region inference in TIR.
- */
-constexpr const char* explicit_write_region = "explicit_write_region";
-
-/*! \brief ,ark a ForNode represent an irregular loop of non-structural
control flow edges. */
-constexpr const char* irregular_loop_mark = "irregular_loop_mark";
-
/*!
* \brief Check if attr_key is a pragma key extension
* \param attr_key The attr key to be compared
diff --git a/src/relax/transform/split_layout_rewrite_preproc.cc
b/src/relax/transform/split_layout_rewrite_preproc.cc
index 66f706e478..535cf36149 100644
--- a/src/relax/transform/split_layout_rewrite_preproc.cc
+++ b/src/relax/transform/split_layout_rewrite_preproc.cc
@@ -25,6 +25,7 @@
#include <tvm/ir/transform.h>
#include <tvm/relax/expr_functor.h>
#include <tvm/relax/transform.h>
+#include <tvm/s_tir/stmt.h>
#include <tvm/s_tir/transform.h>
#include <tvm/tir/stmt_functor.h>
@@ -166,7 +167,7 @@ class SplitPrimFuncLayoutRewrite : public StmtMutator {
}
Stmt VisitStmt_(const SBlockNode* op) final {
SBlock block = Downcast<SBlock>(StmtMutator::VisitStmt_(op));
- auto it = op->annotations.find(attr::meta_schedule_layout_rewrite_preproc);
+ auto it =
op->annotations.find(s_tir::attr::meta_schedule_layout_rewrite_preproc);
bool is_layout_rewrite_preproc =
it != op->annotations.end() &&
is_one(Downcast<PrimExpr>((*it).second));
@@ -204,7 +205,7 @@ class SplitPrimFuncLayoutRewrite : public StmtMutator {
RewriteInfo{buffer_index, op->reads[0]->buffer,
op->writes[0]->buffer});
auto new_annotations = op->annotations;
- new_annotations.erase(attr::meta_schedule_layout_rewrite_preproc);
+ new_annotations.erase(s_tir::attr::meta_schedule_layout_rewrite_preproc);
auto n = ffi::make_object<SBlockNode>(*block.get());
n->annotations = new_annotations;
return SBlock(n);
diff --git a/src/s_tir/meta_schedule/mutator/mutate_parallel.cc
b/src/s_tir/meta_schedule/mutator/mutate_parallel.cc
index 2bd43e0a65..fa540b7231 100644
--- a/src/s_tir/meta_schedule/mutator/mutate_parallel.cc
+++ b/src/s_tir/meta_schedule/mutator/mutate_parallel.cc
@@ -17,6 +17,7 @@
* under the License.
*/
#include <tvm/ffi/reflection/registry.h>
+#include <tvm/s_tir/stmt.h>
#include <algorithm>
#include <map>
@@ -40,7 +41,7 @@ bool IsAnnotateWithParallel(const Instruction& inst) {
}
TVM_FFI_ICHECK_EQ(inst->attrs.size(), 1);
ffi::String ann_key = Downcast<ffi::String>(inst->attrs[0]);
- return ann_key == tir::attr::meta_schedule_parallel;
+ return ann_key == s_tir::attr::meta_schedule_parallel;
}
/*!
diff --git a/src/s_tir/meta_schedule/mutator/mutate_tile_size.cc
b/src/s_tir/meta_schedule/mutator/mutate_tile_size.cc
index d6a43607c0..9014f23de8 100644
--- a/src/s_tir/meta_schedule/mutator/mutate_tile_size.cc
+++ b/src/s_tir/meta_schedule/mutator/mutate_tile_size.cc
@@ -17,6 +17,7 @@
* under the License.
*/
#include <tvm/ffi/reflection/registry.h>
+#include <tvm/s_tir/stmt.h>
#include <mutex>
#include <unordered_map>
@@ -119,7 +120,7 @@ void FindSampleVectorize(const Trace& trace,
std::vector<Instruction>* inst,
if (inst->kind.same_as(inst_annotate)) {
TVM_FFI_ICHECK_EQ(inst->attrs.size(), 1);
TVM_FFI_ICHECK_EQ(inst->inputs.size(), 2);
- if (Downcast<ffi::String>(inst->attrs[0]) ==
tir::attr::meta_schedule_cooperative_fetch) {
+ if (Downcast<ffi::String>(inst->attrs[0]) ==
s_tir::attr::meta_schedule_cooperative_fetch) {
const auto* ann_val = inst->inputs[1].as<s_tir::ExprRVNode>();
TVM_FFI_ICHECK(ann_val);
annotated.insert(ann_val);
diff --git a/src/s_tir/meta_schedule/mutator/mutate_unroll.cc
b/src/s_tir/meta_schedule/mutator/mutate_unroll.cc
index 47fed617b6..13d80f322f 100644
--- a/src/s_tir/meta_schedule/mutator/mutate_unroll.cc
+++ b/src/s_tir/meta_schedule/mutator/mutate_unroll.cc
@@ -17,6 +17,7 @@
* under the License.
*/
#include <tvm/ffi/reflection/registry.h>
+#include <tvm/s_tir/stmt.h>
#include "../utils.h"
@@ -37,8 +38,8 @@ bool IsAnnotateWithUnroll(const Instruction& inst) {
}
TVM_FFI_ICHECK_EQ(inst->attrs.size(), 1);
ffi::String ann_key = Downcast<ffi::String>(inst->attrs[0]);
- return ann_key == tir::attr::meta_schedule_unroll_explicit ||
- ann_key == tir::attr::meta_schedule_unroll_implicit;
+ return ann_key == s_tir::attr::meta_schedule_unroll_explicit ||
+ ann_key == s_tir::attr::meta_schedule_unroll_implicit;
}
} // namespace s_tir
diff --git a/src/s_tir/meta_schedule/postproc/rewrite_cooperative_fetch.cc
b/src/s_tir/meta_schedule/postproc/rewrite_cooperative_fetch.cc
index 151a50b4a0..2b2242d4f0 100644
--- a/src/s_tir/meta_schedule/postproc/rewrite_cooperative_fetch.cc
+++ b/src/s_tir/meta_schedule/postproc/rewrite_cooperative_fetch.cc
@@ -17,6 +17,7 @@
* under the License.
*/
#include <tvm/ffi/reflection/registry.h>
+#include <tvm/s_tir/stmt.h>
#include "../utils.h"
@@ -62,7 +63,7 @@ ffi::Optional<SBlockRV> ParseAnnotate(const Schedule& sch,
const Instruction& in
TVM_FFI_ICHECK_EQ(inst->inputs.size(), 2);
TVM_FFI_ICHECK_EQ(inst->attrs.size(), 1);
ffi::String ann_key = Downcast<ffi::String>(inst->attrs[0]);
- if (ann_key != tir::attr::meta_schedule_cooperative_fetch) {
+ if (ann_key != s_tir::attr::meta_schedule_cooperative_fetch) {
return std::nullopt;
}
*vector_lane =
Downcast<Integer>(sch->Get(Downcast<ExprRV>(inst->inputs[1])))->value;
@@ -83,7 +84,7 @@ bool ParseWarpExecutionAnn(const Schedule& sch, const
Instruction& inst) {
TVM_FFI_ICHECK_EQ(inst->inputs.size(), 2);
TVM_FFI_ICHECK_EQ(inst->attrs.size(), 1);
ffi::String ann_key = Downcast<ffi::String>(inst->attrs[0]);
- return ann_key == tir::attr::warp_execution;
+ return ann_key == s_tir::attr::warp_execution;
}
size_t GetMaxUsedDtypeBytes(SBlock block) {
@@ -176,7 +177,7 @@ bool RewriteCooperativeFetchNode::Apply(const
s_tir::Schedule& sch) {
}
auto task = [thread_extent_x, thread_extent_y, vector_lane, sch,
block = opt_block_rv.value()]() mutable -> void {
- sch->Unannotate(block, tir::attr::meta_schedule_cooperative_fetch);
+ sch->Unannotate(block, s_tir::attr::meta_schedule_cooperative_fetch);
s_tir::LoopRV fused = sch->GetLoops(block).back();
int64_t fused_extent = -1;
if (const int64_t* extent =
s_tir::GetLoopIntExtent(sch->Get(fused).get())) {
diff --git a/src/s_tir/meta_schedule/postproc/rewrite_layout.cc
b/src/s_tir/meta_schedule/postproc/rewrite_layout.cc
index 78000e54ed..6f2fe7741c 100644
--- a/src/s_tir/meta_schedule/postproc/rewrite_layout.cc
+++ b/src/s_tir/meta_schedule/postproc/rewrite_layout.cc
@@ -17,6 +17,7 @@
* under the License.
*/
#include <tvm/ffi/reflection/registry.h>
+#include <tvm/s_tir/stmt.h>
#include <optional>
#include <unordered_set>
@@ -121,7 +122,7 @@ class LayoutFreeBufferCollector : public StmtVisitor {
ffi::Array<Buffer> CollectLayoutFreeBuffers(const PrimFuncNode* func) {
// Only rewrite PrimFuncs with attr "layout_free_buffers"
ffi::Array<Integer> layout_free_buffer_index =
- func->GetAttr(tir::attr::layout_free_buffers,
ffi::Array<Integer>()).value();
+ func->GetAttr(s_tir::attr::layout_free_buffers,
ffi::Array<Integer>()).value();
ffi::Array<Buffer> layout_free_buffers;
for (const Integer& index : layout_free_buffer_index) {
@@ -186,7 +187,7 @@ bool RewriteLayout(const Schedule& sch) {
std::vector<std::pair<StmtSRef, ffi::String>> results;
auto add_layout_rewrite_block = [&sch](SBlockRV consumer_block_rv, int
buffer_index) {
SBlockRV rewrite_block_rv = sch->CacheRead(consumer_block_rv,
buffer_index, "global");
- sch->Annotate(rewrite_block_rv,
tir::attr::meta_schedule_layout_rewrite_preproc, true);
+ sch->Annotate(rewrite_block_rv,
s_tir::attr::meta_schedule_layout_rewrite_preproc, true);
};
for (const auto& [g_var, base_func] : sch->mod()->functions) {
diff --git
a/src/s_tir/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc
b/src/s_tir/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc
index 099f2f449c..d095d98612 100644
--- a/src/s_tir/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc
+++ b/src/s_tir/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc
@@ -17,6 +17,7 @@
* under the License.
*/
#include <tvm/ffi/reflection/registry.h>
+#include <tvm/s_tir/stmt.h>
#include "../utils.h"
@@ -107,22 +108,22 @@ bool ParseAnnotation(const SBlock& block,
ParsedAnnotation* parsed) {
bool found = false;
*parsed = ParsedAnnotation{-1, -1, -1, -1, -1, -1};
for (const auto& ann : block->annotations) {
- if (ann.first == tir::attr::meta_schedule_parallel) {
+ if (ann.first == s_tir::attr::meta_schedule_parallel) {
found = true;
if (auto opt_int_imm = ann.second.try_cast<IntImm>()) {
parsed->max_parallel_extent = (*opt_int_imm)->value;
}
- } else if (ann.first == tir::attr::meta_schedule_vectorize) {
+ } else if (ann.first == s_tir::attr::meta_schedule_vectorize) {
found = true;
if (auto opt_int_imm = ann.second.try_cast<IntImm>()) {
parsed->max_vectorize_extent = (*opt_int_imm)->value;
}
- } else if (ann.first == tir::attr::meta_schedule_unroll_explicit) {
+ } else if (ann.first == s_tir::attr::meta_schedule_unroll_explicit) {
found = true;
if (auto opt_int_imm = ann.second.try_cast<IntImm>()) {
parsed->unroll_explicit = (*opt_int_imm)->value;
}
- } else if (ann.first == tir::attr::meta_schedule_unroll_implicit) {
+ } else if (ann.first == s_tir::attr::meta_schedule_unroll_implicit) {
found = true;
if (auto opt_int_imm = ann.second.try_cast<IntImm>()) {
parsed->unroll_implicit = (*opt_int_imm)->value;
@@ -135,16 +136,16 @@ bool ParseAnnotation(const SBlock& block,
ParsedAnnotation* parsed) {
void RemoveParsedAnn(const Schedule& sch, const SBlockRV& block_rv,
const ParsedAnnotation& parsed) {
if (parsed.max_parallel_extent != -1) {
- sch->Unannotate(block_rv, tir::attr::meta_schedule_parallel);
+ sch->Unannotate(block_rv, s_tir::attr::meta_schedule_parallel);
}
if (parsed.max_vectorize_extent != -1) {
- sch->Unannotate(block_rv, tir::attr::meta_schedule_vectorize);
+ sch->Unannotate(block_rv, s_tir::attr::meta_schedule_vectorize);
}
if (parsed.unroll_explicit != -1) {
- sch->Unannotate(block_rv, tir::attr::meta_schedule_unroll_explicit);
+ sch->Unannotate(block_rv, s_tir::attr::meta_schedule_unroll_explicit);
}
if (parsed.unroll_implicit != -1) {
- sch->Unannotate(block_rv, tir::attr::meta_schedule_unroll_implicit);
+ sch->Unannotate(block_rv, s_tir::attr::meta_schedule_unroll_implicit);
}
}
diff --git a/src/s_tir/meta_schedule/postproc/rewrite_reduction_block.cc
b/src/s_tir/meta_schedule/postproc/rewrite_reduction_block.cc
index b6bc04c3f6..67e15602c2 100644
--- a/src/s_tir/meta_schedule/postproc/rewrite_reduction_block.cc
+++ b/src/s_tir/meta_schedule/postproc/rewrite_reduction_block.cc
@@ -17,6 +17,7 @@
* under the License.
*/
#include <tvm/ffi/reflection/registry.h>
+#include <tvm/s_tir/stmt.h>
#include "../utils.h"
@@ -149,20 +150,20 @@ bool RewriteReductionBlockNode::Apply(const
s_tir::Schedule& sch) {
s_tir::SBlockRV init_block_rv = sch->DecomposeReduction(block_rv,
loop_rvs[decompose_point]);
// Rewrite auto tensorization related annotations
- if (s_tir::GetAnn<ffi::String>(block_sref,
tir::attr::meta_schedule_auto_tensorize)
+ if (s_tir::GetAnn<ffi::String>(block_sref,
s_tir::attr::meta_schedule_auto_tensorize)
.has_value()) {
// Remove tensorization annotation as it shouldn't be propagated to
the init block.
- sch->Unannotate(init_block_rv,
tir::attr::meta_schedule_auto_tensorize);
+ sch->Unannotate(init_block_rv,
s_tir::attr::meta_schedule_auto_tensorize);
ffi::Optional<ffi::String> tensorize_init =
- s_tir::GetAnn<ffi::String>(block_sref,
tir::attr::meta_schedule_auto_tensorize_init);
+ s_tir::GetAnn<ffi::String>(block_sref,
s_tir::attr::meta_schedule_auto_tensorize_init);
// The annotation of tensorization of the init statement should be
moved to the init block
// after 'DecomposeReduction'.
// Annotate to hint `RewriteTensorize` postprocessor even if
tensorize_init is std::nullopt.
- sch->Annotate(init_block_rv, tir::attr::meta_schedule_auto_tensorize,
+ sch->Annotate(init_block_rv, s_tir::attr::meta_schedule_auto_tensorize,
tensorize_init.value_or(""));
if (tensorize_init.has_value()) {
- sch->Unannotate(block_rv,
tir::attr::meta_schedule_auto_tensorize_init);
- sch->Unannotate(init_block_rv,
tir::attr::meta_schedule_auto_tensorize_init);
+ sch->Unannotate(block_rv,
s_tir::attr::meta_schedule_auto_tensorize_init);
+ sch->Unannotate(init_block_rv,
s_tir::attr::meta_schedule_auto_tensorize_init);
}
}
++rewritten;
diff --git a/src/s_tir/meta_schedule/postproc/rewrite_tensorize.cc
b/src/s_tir/meta_schedule/postproc/rewrite_tensorize.cc
index 926aed03cd..0090b3a95b 100644
--- a/src/s_tir/meta_schedule/postproc/rewrite_tensorize.cc
+++ b/src/s_tir/meta_schedule/postproc/rewrite_tensorize.cc
@@ -18,6 +18,7 @@
*/
#include <tvm/ffi/reflection/registry.h>
#include <tvm/s_tir/meta_schedule/postproc.h>
+#include <tvm/s_tir/stmt.h>
#include <algorithm>
@@ -39,7 +40,7 @@ void CollectTensorizationJobs(
tir::StmtSRef block_sref = sch->GetSRef(block);
std::string block_name =
block_sref->StmtAs<tir::SBlockNode>()->name_hint;
if (ffi::Optional<ffi::String> intrin_name =
- s_tir::GetAnn<ffi::String>(block_sref,
tir::attr::meta_schedule_auto_tensorize)) {
+ s_tir::GetAnn<ffi::String>(block_sref,
s_tir::attr::meta_schedule_auto_tensorize)) {
if (intrin_name.value() != "") {
jobs->emplace_back(block_name, func_name, [sch,
intrin_name](s_tir::SBlockRV block) {
try {
@@ -99,7 +100,7 @@ bool RewriteTensorizeNode::Apply(const s_tir::Schedule& sch)
{
const ffi::String& func_name = std::get<1>(job);
const auto& job_func = std::get<2>(job);
SBlockRV block = sch->GetSBlock(block_name, func_name);
- sch->Unannotate(block, tir::attr::meta_schedule_auto_tensorize);
+ sch->Unannotate(block, s_tir::attr::meta_schedule_auto_tensorize);
job_func(block);
}
return true;
diff --git a/src/s_tir/meta_schedule/postproc/verify_gpu_code.cc
b/src/s_tir/meta_schedule/postproc/verify_gpu_code.cc
index fa71fb0131..f05c28fca0 100644
--- a/src/s_tir/meta_schedule/postproc/verify_gpu_code.cc
+++ b/src/s_tir/meta_schedule/postproc/verify_gpu_code.cc
@@ -17,6 +17,7 @@
* under the License.
*/
#include <tvm/ffi/reflection/registry.h>
+#include <tvm/s_tir/stmt.h>
#include <tvm/s_tir/transform.h>
#include <tvm/tir/transform.h>
@@ -72,13 +73,13 @@ class ThreadExtentChecker : private StmtVisitor {
void VisitStmt_(const SBlockNode* block) {
int old_thread_idx_x = thread_idx_x;
- if (block->annotations.count(tir::attr::warp_execution)) {
+ if (block->annotations.count(s_tir::attr::warp_execution)) {
thread_idx_x = thread_warp_size_;
}
if (ffi::Optional<Integer> low_inclusive =
- GetAnn<Integer>(block,
tir::attr::meta_schedule_thread_extent_low_inclusive)) {
+ GetAnn<Integer>(block,
s_tir::attr::meta_schedule_thread_extent_low_inclusive)) {
if (ffi::Optional<Integer> high_inclusive =
- GetAnn<Integer>(block,
tir::attr::meta_schedule_thread_extent_high_inclusive)) {
+ GetAnn<Integer>(block,
s_tir::attr::meta_schedule_thread_extent_high_inclusive)) {
int64_t low = low_inclusive.value()->value;
int64_t high = high_inclusive.value()->value;
int64_t thread_extent_product = thread_idx_x * thread_idx_y *
thread_idx_z;
diff --git a/src/s_tir/meta_schedule/schedule_rule/add_rfactor.cc
b/src/s_tir/meta_schedule/schedule_rule/add_rfactor.cc
index f26fa9dc51..d618661ea0 100644
--- a/src/s_tir/meta_schedule/schedule_rule/add_rfactor.cc
+++ b/src/s_tir/meta_schedule/schedule_rule/add_rfactor.cc
@@ -17,6 +17,7 @@
* under the License.
*/
#include <tvm/ffi/reflection/registry.h>
+#include <tvm/s_tir/stmt.h>
#include "../utils.h"
@@ -114,7 +115,7 @@ ffi::Array<s_tir::Schedule> AddRFactorNode::Apply(const
s_tir::Schedule& sch,
// Annotate that the rfactor block, which is now the producer of the
original block, needs to
// be considered by the rule Random-Compute-Location.
- sch_tmp->Annotate(block_rv,
tir::attr::meta_schedule_random_compute_producer, Integer(1));
+ sch_tmp->Annotate(block_rv,
s_tir::attr::meta_schedule_random_compute_producer, Integer(1));
res.push_back(sch_tmp);
} catch (const tvm::runtime::Error& e) {
}
diff --git a/src/s_tir/meta_schedule/schedule_rule/auto_inline.cc
b/src/s_tir/meta_schedule/schedule_rule/auto_inline.cc
index b4433f160c..3096da373f 100644
--- a/src/s_tir/meta_schedule/schedule_rule/auto_inline.cc
+++ b/src/s_tir/meta_schedule/schedule_rule/auto_inline.cc
@@ -17,6 +17,7 @@
* under the License.
*/
#include <tvm/ffi/reflection/registry.h>
+#include <tvm/s_tir/stmt.h>
#include "../utils.h"
@@ -159,7 +160,7 @@ inline InlineType AutoInlineNode::CheckInline(const
s_tir::Schedule& sch,
}
// Cond 6. The block is disallowed for auto inline
if (ffi::Optional<ffi::String> ann =
- s_tir::GetAnn<ffi::String>(block_sref,
tir::attr::meta_schedule_inline_rule)) {
+ s_tir::GetAnn<ffi::String>(block_sref,
s_tir::attr::meta_schedule_inline_rule)) {
if (ann.value() == "disable") return InlineType::kNoInline;
}
// Last cond: Check inline into the consumers or the spatial producer
@@ -176,7 +177,7 @@ inline InlineType AutoInlineNode::CheckInline(const
s_tir::Schedule& sch,
if (producer_srefs.size() == 1 &&
s_tir::IsCompleteBlock(sch->state(), producer_srefs[0], scope_block) &&
CanReverseComputeInline(state, block_sref) &&
- !GetAnn<ffi::String>(producer_srefs[0],
tir::attr::meta_schedule_auto_tensorize)
+ !GetAnn<ffi::String>(producer_srefs[0],
s_tir::attr::meta_schedule_auto_tensorize)
.has_value()) {
return InlineType::kInlineIntoProducer;
}
diff --git a/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling.cc
b/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling.cc
index f9823e7ad6..613fcfb43c 100644
--- a/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling.cc
+++ b/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling.cc
@@ -20,6 +20,7 @@
#include <tvm/ffi/reflection/registry.h>
#include <tvm/s_tir/meta_schedule/schedule_rule.h>
+#include <tvm/s_tir/stmt.h>
#include <algorithm>
#include <utility>
@@ -111,7 +112,7 @@ void MultiLevelTilingNode::InitializeWithTuneContext(const
TuneContext& context)
ffi::Array<Schedule> MultiLevelTilingNode::Apply(const Schedule& sch, const
SBlockRV& block_rv) {
if ((filter_fn_ && filter_fn_.value()(sch,
sch->GetSRef(block_rv)).cast<bool>()) ||
NeedsMultiLevelTiling(sch->state(), sch->GetSRef(block_rv))) {
- sch->Annotate(block_rv, tir::attr::meta_schedule_tiling_structure,
structure);
+ sch->Annotate(block_rv, s_tir::attr::meta_schedule_tiling_structure,
structure);
ffi::Array<Schedule> results;
for (auto&& state : ApplySubRules({State(sch, block_rv)})) {
@@ -281,9 +282,9 @@ std::vector<State> MultiLevelTilingNode::TileLoopNest(State
state,
if (spatial_loop_product > 2 * this->thread_warp_size_) {
low_inclusive = this->thread_warp_size_;
}
- sch->Annotate(block_rv,
tir::attr::meta_schedule_thread_extent_low_inclusive,
+ sch->Annotate(block_rv,
s_tir::attr::meta_schedule_thread_extent_low_inclusive,
Integer(low_inclusive));
- sch->Annotate(block_rv,
tir::attr::meta_schedule_thread_extent_high_inclusive,
+ sch->Annotate(block_rv,
s_tir::attr::meta_schedule_thread_extent_high_inclusive,
Integer(high_inclusive));
}
return {state};
@@ -351,11 +352,11 @@ std::vector<State>
MultiLevelTilingNode::AddAsyncPipeline(State state) const {
for (int stage : this->stages) {
State new_state = state->Copy();
LoopRV r_loop_fused =
new_state->sch->Fuse(new_state->tiles[r_indices_[0]]);
- new_state->sch->Annotate(r_loop_fused, tir::attr::software_pipeline_stage,
+ new_state->sch->Annotate(r_loop_fused,
s_tir::attr::software_pipeline_stage,
ffi::Array<Integer>{0, 0, stage - 2});
- new_state->sch->Annotate(r_loop_fused, tir::attr::software_pipeline_order,
+ new_state->sch->Annotate(r_loop_fused,
s_tir::attr::software_pipeline_order,
ffi::Array<Integer>{0, 1, 2});
- new_state->sch->Annotate(r_loop_fused,
tir::attr::software_pipeline_async_stages,
+ new_state->sch->Annotate(r_loop_fused,
s_tir::attr::software_pipeline_async_stages,
ffi::Array<Integer>{0});
ret.push_back(std::move(new_state));
}
@@ -393,7 +394,7 @@ void
MultiLevelTilingNode::AnnotateCooperativeFetching(Schedule* sch,
s_tir::ExprRV vector_load_len =
(*sch)->SampleCategorical(support::AsArray<int,
Integer>(valid_vector_lens),
ffi::Array<FloatImm>(n,
FloatImm(DataType::Float(32), prob)));
- (*sch)->Annotate(block, tir::attr::meta_schedule_cooperative_fetch,
vector_load_len);
+ (*sch)->Annotate(block, s_tir::attr::meta_schedule_cooperative_fetch,
vector_load_len);
}
}
diff --git
a/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc
b/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc
index b8162acdbb..6b2338210d 100644
--- a/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc
+++ b/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc
@@ -18,6 +18,7 @@
*/
#include <tvm/ffi/reflection/registry.h>
#include <tvm/s_tir/meta_schedule/schedule_rule.h>
+#include <tvm/s_tir/stmt.h>
#include <tvm/tir/op.h>
#include <algorithm>
@@ -244,7 +245,7 @@ ffi::Array<Schedule>
MultiLevelTilingTensorCoreNode::Apply(const Schedule& sch,
const TensorCoreIntrinGroup& intrin_group = intrin_groups[kv.first];
const s_tir::AutoTensorizeMappingInfo& mapping_info = kv.second;
Schedule new_sch = sch->Copy();
- new_sch->Annotate(block_rv, tir::attr::meta_schedule_tiling_structure,
structure);
+ new_sch->Annotate(block_rv, s_tir::attr::meta_schedule_tiling_structure,
structure);
initial_states.push_back(TensorCoreState(intrin_group, mapping_info,
new_sch, block_rv, true));
}
ffi::Array<Schedule> results;
@@ -294,7 +295,7 @@ void
MultiLevelTilingTensorCoreNode::TileAndAnnotateTensorize(
ffi::Optional<LoopRV> loop = s_tir::TileWithTensorIntrin(*sch, block_rv,
intrin_name).value();
TVM_FFI_ICHECK(loop.defined());
SBlockRV blockized_outer = (*sch)->Blockize(loop.value());
- (*sch)->Annotate(blockized_outer, tir::attr::meta_schedule_auto_tensorize,
intrin_name);
+ (*sch)->Annotate(blockized_outer, s_tir::attr::meta_schedule_auto_tensorize,
intrin_name);
if (!permuted_layout_annotate_value.empty()) {
(*sch)->Annotate(blockized_outer, "permuted_layout",
permuted_layout_annotate_value);
}
@@ -422,9 +423,9 @@ std::vector<State>
MultiLevelTilingTensorCoreNode::MMATileLoopNest(TensorCoreSta
if (spatial_loop_product > 2 * this->thread_warp_size_) {
low_inclusive = this->thread_warp_size_;
}
- sch->Annotate(block_rv,
tir::attr::meta_schedule_thread_extent_low_inclusive,
+ sch->Annotate(block_rv,
s_tir::attr::meta_schedule_thread_extent_low_inclusive,
Integer(low_inclusive));
- sch->Annotate(block_rv,
tir::attr::meta_schedule_thread_extent_high_inclusive,
+ sch->Annotate(block_rv,
s_tir::attr::meta_schedule_thread_extent_high_inclusive,
Integer(high_inclusive));
}
return {state};
@@ -586,7 +587,7 @@ std::vector<State>
MultiLevelTilingTensorCoreNode::AddWriteReuseTensorCore(
sch->ReverseComputeInline(state->tensor_core_reindex_store);
auto loops = sch->GetLoops(cache_write);
auto blockized_store = sch->Blockize(loops[loops.size() - 2]);
- sch->Annotate(blockized_store, tir::attr::meta_schedule_auto_tensorize,
+ sch->Annotate(blockized_store, s_tir::attr::meta_schedule_auto_tensorize,
state->intrin_group.store_intrin);
ffi::Array<LoopRV> buffer_loops = sch->GetLoops(state->write_reuse[0]);
@@ -666,14 +667,14 @@ std::vector<State>
MultiLevelTilingTensorCoreNode::AddSoftwarePipeline(
const s_tir::SBlockRV cache_read = state->read_reuse.at(i);
if (state->is_mma) {
// Add vector bytes for memhammer
- sch->Annotate(cache_read, tir::attr::vector_bytes, Integer(16));
+ sch->Annotate(cache_read, s_tir::attr::vector_bytes, Integer(16));
if (!state->use_async) {
- sch->Annotate(cache_read, tir::attr::local_stage, Integer(1));
+ sch->Annotate(cache_read, s_tir::attr::local_stage, Integer(1));
sch->Annotate(cache_read, tir::attr::double_buffer_scope, Integer(0));
}
} else {
// Add local stage and double buffering
- sch->Annotate(cache_read, tir::attr::manifest_shared_memory_local_stage,
Integer(1));
+ sch->Annotate(cache_read,
s_tir::attr::manifest_shared_memory_local_stage, Integer(1));
sch->Annotate(cache_read, tir::attr::double_buffer_scope, Integer(0));
}
}
@@ -705,16 +706,16 @@ std::vector<State>
MultiLevelTilingTensorCoreNode::AddSoftwarePipeline(
// epilogue:
// compute matmul with fragment K1 - 1
//
- sch->Annotate(state->tiles[r_indices_[1]].back(),
tir::attr::software_pipeline_stage,
+ sch->Annotate(state->tiles[r_indices_[1]].back(),
s_tir::attr::software_pipeline_stage,
ffi::Array<Integer>{0, 0, 1});
- sch->Annotate(state->tiles[r_indices_[1]].back(),
tir::attr::software_pipeline_order,
+ sch->Annotate(state->tiles[r_indices_[1]].back(),
s_tir::attr::software_pipeline_order,
ffi::Array<Integer>{0, 1, 2});
if (state->is_mma && state->use_async) {
- sch->Annotate(state->tiles[r_indices_[0]].back(),
tir::attr::software_pipeline_async_stages,
+ sch->Annotate(state->tiles[r_indices_[0]].back(),
s_tir::attr::software_pipeline_async_stages,
ffi::Array<Integer>{0});
- sch->Annotate(state->tiles[r_indices_[0]].back(),
tir::attr::software_pipeline_stage,
+ sch->Annotate(state->tiles[r_indices_[0]].back(),
s_tir::attr::software_pipeline_stage,
ffi::Array<Integer>{0, 0, 1, 2, 2});
- sch->Annotate(state->tiles[r_indices_[0]].back(),
tir::attr::software_pipeline_order,
+ sch->Annotate(state->tiles[r_indices_[0]].back(),
s_tir::attr::software_pipeline_order,
ffi::Array<Integer>{0, 1, 3, 2, 4});
} else {
// Outer software pipeline: Interleave the outer loop with the (pipelined)
inner loop.
@@ -757,9 +758,9 @@ std::vector<State>
MultiLevelTilingTensorCoreNode::AddSoftwarePipeline(
// // epilogue of the inner pipeline
// compute matmul with fragment K1 - 1 of tile K0 - 1
//
- sch->Annotate(state->tiles[r_indices_[0]].back(),
tir::attr::software_pipeline_stage,
+ sch->Annotate(state->tiles[r_indices_[0]].back(),
s_tir::attr::software_pipeline_stage,
ffi::Array<Integer>{0, 0, 0, 0, 0, 1, 1});
- sch->Annotate(state->tiles[r_indices_[0]].back(),
tir::attr::software_pipeline_order,
+ sch->Annotate(state->tiles[r_indices_[0]].back(),
s_tir::attr::software_pipeline_order,
ffi::Array<Integer>{0, 3, 1, 4, 5, 2, 6});
}
@@ -900,11 +901,11 @@ inline std::vector<State>
MultiLevelTilingTensorCoreNode::TransformForTensorizat
state->block_rv = state->sch->Blockize(transformed_loop_rv.value());
// Add annotations for post processors.
- state->sch->Annotate(state->block_rv,
tir::attr::meta_schedule_auto_tensorize,
+ state->sch->Annotate(state->block_rv,
s_tir::attr::meta_schedule_auto_tensorize,
state->intrin_group.compute_intrin);
- state->sch->Annotate(state->block_rv,
tir::attr::meta_schedule_auto_tensorize_init,
+ state->sch->Annotate(state->block_rv,
s_tir::attr::meta_schedule_auto_tensorize_init,
state->intrin_group.init_intrin);
- state->sch->Annotate(state->block_rv, tir::attr::warp_execution, Integer(1));
+ state->sch->Annotate(state->block_rv, s_tir::attr::warp_execution,
Integer(1));
return {std::move(state)};
}
diff --git
a/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc
b/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc
index ed1ffceb6f..a790a8fa0b 100644
--- a/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc
+++ b/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc
@@ -18,6 +18,7 @@
*/
#include <tvm/ffi/reflection/registry.h>
+#include <tvm/s_tir/stmt.h>
#include "../../schedule/analysis.h"
#include "../../schedule/transform.h"
@@ -40,7 +41,7 @@ ffi::Optional<s_tir::SBlockRV> TileForIntrin(s_tir::Schedule
sch, s_tir::SBlockR
}
TVM_FFI_ICHECK(tiled_loop_rv.defined());
s_tir::SBlockRV outer_block = sch->Blockize(tiled_loop_rv.value());
- sch->Annotate(outer_block, tir::attr::meta_schedule_auto_tensorize,
ffi::String(intrin_name));
+ sch->Annotate(outer_block, s_tir::attr::meta_schedule_auto_tensorize,
ffi::String(intrin_name));
return outer_block;
}
diff --git a/src/s_tir/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc
b/src/s_tir/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc
index 036eec6bc2..6d7fd3063f 100644
--- a/src/s_tir/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc
+++ b/src/s_tir/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc
@@ -17,6 +17,7 @@
* under the License.
*/
#include <tvm/ffi/reflection/registry.h>
+#include <tvm/s_tir/stmt.h>
#include "../utils.h"
@@ -61,12 +62,12 @@ class ParallelizeVectorizeUnrollNode : public
ScheduleRuleNode {
// Parallelization
if (max_jobs_per_core != -1) {
- sch->Annotate(root_rv, tir::attr::meta_schedule_parallel,
+ sch->Annotate(root_rv, s_tir::attr::meta_schedule_parallel,
Integer(this->max_parallel_extent_));
}
// Vectorization
if (max_vectorize_extent != -1) {
- sch->Annotate(root_rv, tir::attr::meta_schedule_vectorize,
Integer(max_vectorize_extent));
+ sch->Annotate(root_rv, s_tir::attr::meta_schedule_vectorize,
Integer(max_vectorize_extent));
}
// Unroll
if (!unroll_max_steps.empty() && !s_tir::CheckSpatialPrimFunc(sch,
root_rv)) {
@@ -75,9 +76,9 @@ class ParallelizeVectorizeUnrollNode : public
ScheduleRuleNode {
ffi::Array<FloatImm> probs(n, FloatImm(DataType::Float(32), prob));
PrimExpr max_step = sch->SampleCategorical(unroll_max_steps, probs);
if (unroll_explicit) {
- sch->Annotate(root_rv, tir::attr::meta_schedule_unroll_explicit,
max_step);
+ sch->Annotate(root_rv, s_tir::attr::meta_schedule_unroll_explicit,
max_step);
} else {
- sch->Annotate(root_rv, tir::attr::meta_schedule_unroll_implicit,
max_step);
+ sch->Annotate(root_rv, s_tir::attr::meta_schedule_unroll_implicit,
max_step);
}
}
return {sch};
diff --git a/src/s_tir/meta_schedule/schedule_rule/random_compute_location.cc
b/src/s_tir/meta_schedule/schedule_rule/random_compute_location.cc
index 4fb9034d1a..f00c1c87b4 100644
--- a/src/s_tir/meta_schedule/schedule_rule/random_compute_location.cc
+++ b/src/s_tir/meta_schedule/schedule_rule/random_compute_location.cc
@@ -17,6 +17,7 @@
* under the License.
*/
#include <tvm/ffi/reflection/registry.h>
+#include <tvm/s_tir/stmt.h>
#include "../utils.h"
@@ -43,10 +44,10 @@ class RandomComputeLocationNode : public ScheduleRuleNode {
// access the input block. Hence we collect its producer ahead of time.
// - Note that only single producer is allowed in this case.
ffi::Array<s_tir::SBlockRV> producers{nullptr};
- if (s_tir::HasAnn(sch->GetSRef(block_rv),
tir::attr::meta_schedule_random_compute_producer,
+ if (s_tir::HasAnn(sch->GetSRef(block_rv),
s_tir::attr::meta_schedule_random_compute_producer,
true)) {
producers = sch->GetProducers(block_rv);
- sch->Unannotate(block_rv,
tir::attr::meta_schedule_random_compute_producer);
+ sch->Unannotate(block_rv,
s_tir::attr::meta_schedule_random_compute_producer);
TVM_FFI_ICHECK_EQ(producers.size(), 1);
}
diff --git a/src/s_tir/schedule/analysis/analysis.cc
b/src/s_tir/schedule/analysis/analysis.cc
index 9a7763660e..e210881b2b 100644
--- a/src/s_tir/schedule/analysis/analysis.cc
+++ b/src/s_tir/schedule/analysis/analysis.cc
@@ -17,6 +17,7 @@
* under the License.
*/
#include <tvm/ffi/reflection/registry.h>
+#include <tvm/s_tir/stmt.h>
#include "../ir_comparator.h"
#include "../utils.h"
@@ -957,7 +958,7 @@ StmtSRef GetSRefLowestCommonAncestor(const
ffi::Array<StmtSRef>& srefs) {
}
bool HasBeenMultiLevelTiled(const StmtSRef& block_sref) {
- return GetAnn<ffi::String>(block_sref,
tir::attr::meta_schedule_tiling_structure).has_value();
+ return GetAnn<ffi::String>(block_sref,
s_tir::attr::meta_schedule_tiling_structure).has_value();
}
std::pair<ffi::Array<StmtSRef>, std::vector<int>> CollectComputeLocation(
diff --git a/src/s_tir/schedule/primitive/annotate_buffer_access.cc
b/src/s_tir/schedule/primitive/annotate_buffer_access.cc
index 2e77570e4e..9962cc47fc 100644
--- a/src/s_tir/schedule/primitive/annotate_buffer_access.cc
+++ b/src/s_tir/schedule/primitive/annotate_buffer_access.cc
@@ -16,6 +16,8 @@
* specific language governing permissions and limitations
* under the License.
*/
+#include <tvm/s_tir/stmt.h>
+
#include "../utils.h"
namespace tvm {
@@ -51,8 +53,8 @@ class AnnotateRegionRewriter : public StmtExprMutator {
// Annotate the block with explicit_read_region or explicit_write_region
ffi::Map<ffi::String, ffi::Any> new_annotations = n->annotations;
ffi::String annotation_key = buffer_index_type_ == BufferIndexType::kWrite
- ? tir::attr::explicit_write_region
- : tir::attr::explicit_read_region;
+ ? s_tir::attr::explicit_write_region
+ : s_tir::attr::explicit_read_region;
if (new_annotations.count(annotation_key)) {
ffi::Array<Integer> buffer_indices =
Downcast<ffi::Array<Integer>>(new_annotations[annotation_key]);
diff --git a/src/s_tir/schedule/primitive/block_annotate.cc
b/src/s_tir/schedule/primitive/block_annotate.cc
index ba8b8ccb83..f6e9fa8bab 100644
--- a/src/s_tir/schedule/primitive/block_annotate.cc
+++ b/src/s_tir/schedule/primitive/block_annotate.cc
@@ -17,6 +17,7 @@
* under the License.
*/
#include <tvm/ffi/container/tuple.h>
+#include <tvm/s_tir/stmt.h>
#include <tvm/tir/expr.h>
#include "../../../tir/transform/ir_utils.h"
@@ -149,14 +150,14 @@ class StorageAlignInvalidAnnotationError : public
ScheduleError {
std::ostringstream os;
os << "The block annotation for storage align is expected to be an array
of 4-integer-tuples "
"(buffer_index, axis, factor, offset). However, the block annotation
with key "
- << tir::attr::buffer_dim_align << " of the block {0} is "
- << block_->annotations.at(tir::attr::buffer_dim_align) << ", which is
unexpected.";
+ << s_tir::attr::buffer_dim_align << " of the block {0} is "
+ << block_->annotations.at(s_tir::attr::buffer_dim_align) << ", which is
unexpected.";
return os.str();
}
static StorageAlignAnnotation CheckAndGetAnnotation(const IRModule& mod,
const SBlock& block) {
// Get existing annotation value.
- auto it = block->annotations.find(tir::attr::buffer_dim_align);
+ auto it = block->annotations.find(s_tir::attr::buffer_dim_align);
if (it != block->annotations.end()) {
if (!IsValidAnnotation(block, (*it).second)) {
throw StorageAlignInvalidAnnotationError(mod, block);
@@ -252,7 +253,7 @@ void StorageAlign(ScheduleState self, const StmtSRef&
block_sref, int buffer_ind
// Step 3: Replace the block with the new annotation
SBlock new_block =
- WithAnnotation(block_ptr, tir::attr::buffer_dim_align,
storage_align_annotation);
+ WithAnnotation(block_ptr, s_tir::attr::buffer_dim_align,
storage_align_annotation);
self->Replace(block_sref, new_block, {{ffi::GetRef<SBlock>(block_ptr),
new_block}});
}
diff --git a/src/s_tir/schedule/primitive/compute_inline.cc
b/src/s_tir/schedule/primitive/compute_inline.cc
index ccc5ea3ccd..4ceb444ecd 100644
--- a/src/s_tir/schedule/primitive/compute_inline.cc
+++ b/src/s_tir/schedule/primitive/compute_inline.cc
@@ -16,6 +16,8 @@
* specific language governing permissions and limitations
* under the License.
*/
+#include <tvm/s_tir/stmt.h>
+
#include "../utils.h"
namespace tvm {
@@ -684,7 +686,7 @@ class ReverseComputeInliner : public BaseInliner {
producer_store = producer_if->then_case.as<BufferStoreNode>();
} else {
producer_store = producer_block_->body.as<BufferStoreNode>();
- if (producer_block_->annotations.count(tir::attr::auto_copy) != 0) {
+ if (producer_block_->annotations.count(s_tir::attr::auto_copy) != 0) {
const ForNode* producer_inner_loop =
producer_block_->body.as<ForNode>();
while (producer_inner_loop->body.as<ForNode>()) {
producer_inner_loop = producer_inner_loop->body.as<ForNode>();
@@ -720,7 +722,7 @@ class ReverseComputeInliner : public BaseInliner {
subst_map.Set(iter->var, binding);
analyzer_.Bind(iter->var, Range::FromMinExtent(iter->dom->min,
iter->dom->extent));
}
- if (producer_block->annotations.count(tir::attr::auto_copy) != 0) {
+ if (producer_block->annotations.count(s_tir::attr::auto_copy) != 0) {
auto bind = [&](const ForNode* loop) {
analyzer_.Bind(loop->loop_var,
Range::FromMinExtent(make_zero(loop->extent->dtype),
loop->extent));
diff --git a/src/s_tir/schedule/primitive/read_write_at.cc
b/src/s_tir/schedule/primitive/read_write_at.cc
index 8231876bb9..f560de517b 100644
--- a/src/s_tir/schedule/primitive/read_write_at.cc
+++ b/src/s_tir/schedule/primitive/read_write_at.cc
@@ -17,6 +17,8 @@
* under the License.
*/
+#include <tvm/s_tir/stmt.h>
+
#include <string>
#include "../utils.h"
@@ -345,13 +347,13 @@ struct ReadWriteAtImpl {
StmtSRef ReadAt(ScheduleState self, const StmtSRef& loop_sref, const StmtSRef&
block_sref,
int read_buffer_index, const ffi::String& storage_scope) {
return ReadWriteAtImpl::Main<true>(self, loop_sref, block_sref,
read_buffer_index, storage_scope,
- {{tir::attr::auto_copy, true}});
+ {{s_tir::attr::auto_copy, true}});
}
StmtSRef WriteAt(ScheduleState self, const StmtSRef& loop_sref, const
StmtSRef& block_sref,
int write_buffer_index, const ffi::String& storage_scope) {
return ReadWriteAtImpl::Main<false>(self, loop_sref, block_sref,
write_buffer_index,
- storage_scope, {{tir::attr::auto_copy,
true}});
+ storage_scope, {{s_tir::attr::auto_copy,
true}});
}
/******** Instruction Registration ********/
diff --git a/src/s_tir/transform/annotate_irregular_loop.cc
b/src/s_tir/transform/annotate_irregular_loop.cc
index b21245ea15..f496a55041 100644
--- a/src/s_tir/transform/annotate_irregular_loop.cc
+++ b/src/s_tir/transform/annotate_irregular_loop.cc
@@ -20,6 +20,7 @@
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ir/transform.h>
+#include <tvm/s_tir/stmt.h>
#include <tvm/s_tir/transform.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
@@ -45,11 +46,11 @@ class IrregularLoopAnnotator : public StmtMutator {
<< "Loop kind " << op->kind << " is invalid for irregular loop " <<
op->loop_var;
for (const char* key :
{tir::attr::pragma_auto_unroll_max_step,
tir::attr::pragma_unroll_explicit,
- tir::attr::pragma_loop_partition_hint,
tir::attr::software_pipeline_stage}) {
+ tir::attr::pragma_loop_partition_hint,
s_tir::attr::software_pipeline_stage}) {
TVM_FFI_ICHECK(!res->annotations.count(key))
<< "Annotation `" << key << "` is invalid for irregular loop " <<
op->loop_var;
}
- res.CopyOnWrite()->annotations.Set(tir::attr::irregular_loop_mark, 1);
+ res.CopyOnWrite()->annotations.Set(s_tir::attr::irregular_loop_mark, 1);
}
std::swap(cur_has_jump, has_jump_);
return res;
diff --git a/src/s_tir/transform/compact_buffer_region.cc
b/src/s_tir/transform/compact_buffer_region.cc
index c9a4a583bc..6f78dfafc9 100644
--- a/src/s_tir/transform/compact_buffer_region.cc
+++ b/src/s_tir/transform/compact_buffer_region.cc
@@ -25,6 +25,7 @@
#include <tvm/arith/int_set.h>
#include <tvm/arith/int_solver.h>
#include <tvm/ffi/reflection/registry.h>
+#include <tvm/s_tir/stmt.h>
#include <tvm/s_tir/transform.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
@@ -260,8 +261,8 @@ class BufferAccessRegionCollector : public StmtExprVisitor {
}
};
- record_explicit_region(tir::attr::explicit_read_region,
BufferIndexType::kRead);
- record_explicit_region(tir::attr::explicit_write_region,
BufferIndexType::kWrite);
+ record_explicit_region(s_tir::attr::explicit_read_region,
BufferIndexType::kRead);
+ record_explicit_region(s_tir::attr::explicit_write_region,
BufferIndexType::kWrite);
// Step 3. Record relax position of ancestor_loops_
for (const Buffer& buffer : op->alloc_buffers) {
diff --git a/src/s_tir/transform/inject_software_pipeline.cc
b/src/s_tir/transform/inject_software_pipeline.cc
index 9330f49c4b..6e749dbe64 100644
--- a/src/s_tir/transform/inject_software_pipeline.cc
+++ b/src/s_tir/transform/inject_software_pipeline.cc
@@ -22,6 +22,7 @@
* \brief Transform annotated loops into pipelined one that parallelize
producers and consumers
*/
#include <tvm/ffi/reflection/registry.h>
+#include <tvm/s_tir/stmt.h>
#include <tvm/s_tir/transform.h>
#include <tvm/target/target.h>
#include <tvm/tir/builtin.h>
@@ -1134,9 +1135,9 @@ class PipelineInjector : private StmtExprMutator {
}
auto pipeline_stages =
-
Downcast<ffi::Array<Integer>>(op->annotations.at(tir::attr::software_pipeline_stage));
+
Downcast<ffi::Array<Integer>>(op->annotations.at(s_tir::attr::software_pipeline_stage));
auto pipeline_orders =
-
Downcast<ffi::Array<Integer>>(op->annotations.at(tir::attr::software_pipeline_order));
+
Downcast<ffi::Array<Integer>>(op->annotations.at(s_tir::attr::software_pipeline_order));
TVM_FFI_ICHECK_EQ(pipeline_stages.size(), original_order.size())
<< "PrimFunc " << global_symbol_ << " has original order "
<< original_order.Map([](const auto& block) { return block->name_hint;
})
@@ -1147,7 +1148,7 @@ class PipelineInjector : private StmtExprMutator {
<< ", but pipeline annotation is " << pipeline_orders << " with
different size";
std::unordered_set<int> pipeline_async_stages;
- if (auto annot =
op->annotations.Get(tir::attr::software_pipeline_async_stages)) {
+ if (auto annot =
op->annotations.Get(s_tir::attr::software_pipeline_async_stages)) {
for (auto s : Downcast<ffi::Array<Integer>>(annot.value())) {
pipeline_async_stages.insert(s->value);
}
@@ -1156,9 +1157,9 @@ class PipelineInjector : private StmtExprMutator {
ffi::Map<ffi::String, ffi::Any> preserved_annotations;
for (const auto& kv : op->annotations) {
const ffi::String& key = kv.first;
- if (kv.first != tir::attr::software_pipeline_stage &&
- kv.first != tir::attr::software_pipeline_order &&
- kv.first != tir::attr::software_pipeline_async_stages) {
+ if (kv.first != s_tir::attr::software_pipeline_stage &&
+ kv.first != s_tir::attr::software_pipeline_order &&
+ kv.first != s_tir::attr::software_pipeline_async_stages) {
preserved_annotations.Set(key, kv.second);
}
}
@@ -1228,8 +1229,8 @@ class PipelineInjector : private StmtExprMutator {
}
bool HasPipelineAnnotation(const ForNode* op) const {
- auto it1 = op->annotations.find(tir::attr::software_pipeline_stage);
- auto it2 = op->annotations.find(tir::attr::software_pipeline_order);
+ auto it1 = op->annotations.find(s_tir::attr::software_pipeline_stage);
+ auto it2 = op->annotations.find(s_tir::attr::software_pipeline_order);
bool has_stage = it1 != op->annotations.end();
bool has_order = it2 != op->annotations.end();
if (has_stage && has_order) {
diff --git a/src/s_tir/transform/lower_opaque_block.cc
b/src/s_tir/transform/lower_opaque_block.cc
index f06d70bc53..cde6c28c8d 100644
--- a/src/s_tir/transform/lower_opaque_block.cc
+++ b/src/s_tir/transform/lower_opaque_block.cc
@@ -22,6 +22,7 @@
*/
#include <tvm/ffi/reflection/registry.h>
+#include <tvm/s_tir/stmt.h>
#include <tvm/s_tir/transform.h>
#include <tvm/tir/stmt_functor.h>
@@ -69,7 +70,7 @@ class OpaqueBlockLower : public StmtExprMutator {
tuple.Set<0>(-1);
allocate_aligns.push_back(tuple);
}
- allocate_annotations.Set(tir::attr::buffer_dim_align, allocate_aligns);
+ allocate_annotations.Set(s_tir::attr::buffer_dim_align,
allocate_aligns);
}
body = Allocate(buffer->data, buffer->dtype, allocation_shape,
const_true(), std::move(body),
@@ -107,7 +108,7 @@ class OpaqueBlockLower : public StmtExprMutator {
ffi::String thread_tag = op->thread_binding.value()->thread_tag;
body = MakeLaunchThread(min, extent, op->loop_var, thread_tag, body);
} else if (is_one(extent) && op->annotations.empty() &&
- !op->annotations.count(tir::attr::irregular_loop_mark)) {
+ !op->annotations.count(s_tir::attr::irregular_loop_mark)) {
// Case 2. Unit loop
return body;
} else {
diff --git a/src/s_tir/transform/manifest_shared_memory_local_stage.cc
b/src/s_tir/transform/manifest_shared_memory_local_stage.cc
index 6b4089bd29..5222d53f53 100644
--- a/src/s_tir/transform/manifest_shared_memory_local_stage.cc
+++ b/src/s_tir/transform/manifest_shared_memory_local_stage.cc
@@ -28,6 +28,7 @@
*/
#include <tvm/arith/analyzer.h>
#include <tvm/ffi/reflection/registry.h>
+#include <tvm/s_tir/stmt.h>
#include <tvm/s_tir/transform.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
@@ -189,7 +190,7 @@ class SharedMemoryLocalStageInserter : public StmtMutator {
}
Stmt VisitStmt_(const SBlockNode* op) final {
- if (op->annotations.count(tir::attr::manifest_shared_memory_local_stage)) {
+ if
(op->annotations.count(s_tir::attr::manifest_shared_memory_local_stage)) {
// Rewrite the shared memory access to load from the intermediate buffer.
// The annotated block must be a leaf block (will be checked during
rewriting). No need to
// visit its body recursively.
@@ -198,7 +199,7 @@ class SharedMemoryLocalStageInserter : public StmtMutator {
auto [target_buffer, new_buffer, new_block, local_stage] =
rewriter.Rewrite(op);
buffer_remap_.Set(target_buffer, new_buffer);
-
new_block.CopyOnWrite()->annotations.erase(tir::attr::manifest_shared_memory_local_stage);
+
new_block.CopyOnWrite()->annotations.erase(s_tir::attr::manifest_shared_memory_local_stage);
buffer_local_stage_.Set(target_buffer, local_stage);
target_buffers_.push_back(target_buffer);
diff --git a/src/s_tir/transform/memhammer_lower_auto_copy.cc
b/src/s_tir/transform/memhammer_lower_auto_copy.cc
index 988702eb96..59b5357399 100644
--- a/src/s_tir/transform/memhammer_lower_auto_copy.cc
+++ b/src/s_tir/transform/memhammer_lower_auto_copy.cc
@@ -20,6 +20,7 @@
#include <tvm/arith/iter_affine_map.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
+#include <tvm/s_tir/stmt.h>
#include <tvm/s_tir/transform.h>
#include <tvm/target/target.h>
#include <tvm/tir/expr.h>
@@ -665,7 +666,7 @@ class AutoCopyMutator : public StmtExprMutator {
Stmt VisitStmt_(const SBlockNode* op) final {
SBlock block = Downcast<SBlock>(StmtMutator::VisitStmt_(op));
// only rewrite the block annotated with "auto_copy"
- if (!GetAnn<bool>(op, tir::attr::auto_copy).value_or(false)) {
+ if (!GetAnn<bool>(op, s_tir::attr::auto_copy).value_or(false)) {
SBlockNode* n = block.CopyOnWrite();
n->alloc_buffers = padder.PadSharedMemory(std::move(n->alloc_buffers));
return block;
diff --git a/src/s_tir/transform/remove_weight_layout_rewrite_block.cc
b/src/s_tir/transform/remove_weight_layout_rewrite_block.cc
index 609733f057..bfeee5d850 100644
--- a/src/s_tir/transform/remove_weight_layout_rewrite_block.cc
+++ b/src/s_tir/transform/remove_weight_layout_rewrite_block.cc
@@ -23,6 +23,7 @@
*/
#include <tvm/ffi/reflection/registry.h>
+#include <tvm/s_tir/stmt.h>
#include <tvm/s_tir/transform.h>
#include <tvm/tir/index_map.h>
#include <tvm/tir/op.h>
@@ -52,7 +53,7 @@ class RemoveLayoutRewriteBlock : public StmtMutator {
Stmt VisitStmt_(const SBlockNode* op) final {
SBlock block = Downcast<SBlock>(StmtMutator::VisitStmt_(op));
- auto it = block->annotations.find(tir::attr::meta_schedule_layout_rewrite_preproc);
+ auto it = block->annotations.find(s_tir::attr::meta_schedule_layout_rewrite_preproc);
if (it == block->annotations.end() || !is_one(Downcast<PrimExpr>((*it).second))) {
// The block is not a weight layout block
// The block is not a weight layout block
// Remove allocates if needed
diff --git a/src/s_tir/transform/storage_access.cc b/src/s_tir/transform/storage_access.cc
index f3fc337ddb..af45cd781b 100644
--- a/src/s_tir/transform/storage_access.cc
+++ b/src/s_tir/transform/storage_access.cc
@@ -22,6 +22,7 @@
*/
#include "storage_access.h"
+#include <tvm/s_tir/stmt.h>
#include <tvm/tir/op.h>
#include <string>
@@ -147,7 +148,7 @@ void StorageAccessVisitor::VisitStmt_(const AttrStmtNode* op) {
StmtExprVisitor::VisitStmt_(op);
}
env_threads_.pop_back();
- } else if (op->attr_key == tir::attr::hand_threaded) {
+ } else if (op->attr_key == s_tir::attr::hand_threaded) {
// skip this pass on blocks that were hand_threaded
// this avoids control flow and read/write conflicts
// between hand-threaded kernels and automatic threading
diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc
index 0734310c30..831abb9299 100644
--- a/src/te/operation/create_primfunc.cc
+++ b/src/te/operation/create_primfunc.cc
@@ -23,6 +23,7 @@
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ir/name_supply.h>
+#include <tvm/s_tir/stmt.h>
#include <tvm/te/operation.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/function.h>
@@ -145,7 +146,7 @@ class LayoutFreePlaceholdersNormalizer : public StmtMutator {
for (int i : this->layout_free_buffer_indices_) {
indices.push_back(i);
}
- return WithAttr(std::move(func), tir::attr::layout_free_buffers, indices);
+ return WithAttr(std::move(func), s_tir::attr::layout_free_buffers, indices);
}
Stmt VisitStmt_(const SBlockNode* _block) final {
@@ -318,7 +319,7 @@ ffi::Map<ffi::String, ffi::Any> GenerateBlockAnnotations(const te::ComputeOp& co
}
}
// Set script_parsing_detect_access
- annotations.Set(tir::attr::script_parsing_detect_access, IntImm(DataType::Int(32), 3));
+ annotations.Set(s_tir::attr::script_parsing_detect_access, IntImm(DataType::Int(32), 3));
return annotations;
}
diff --git a/src/tir/ir/script/script_complete.cc b/src/tir/ir/script/script_complete.cc
index 1dc2e7bdff..836694166f 100644
--- a/src/tir/ir/script/script_complete.cc
+++ b/src/tir/ir/script/script_complete.cc
@@ -26,6 +26,7 @@
#include <tvm/arith/int_set.h>
#include <tvm/ffi/reflection/registry.h>
+#include <tvm/s_tir/stmt.h>
#include <tvm/tir/analysis.h>
#include <utility>
@@ -75,7 +76,7 @@ class ScriptCompleter : public StmtMutator {
// Get access detection mask
// 0 for provided region, 1 and 3 for need detect read, 2 and 3 for need detect write
int mask = 0;
- auto it = op->annotations.find(attr::script_parsing_detect_access);
+ auto it = op->annotations.find(s_tir::attr::script_parsing_detect_access);
if (it != op->annotations.end()) {
mask = Downcast<IntImm>((*it).second)->value;
}
@@ -94,7 +95,7 @@ class ScriptCompleter : public StmtMutator {
if (mask & 2) n->writes = writes;
}
n->annotations = op->annotations;
- n->annotations.erase(attr::script_parsing_detect_access);
+ n->annotations.erase(s_tir::attr::script_parsing_detect_access);
return SBlock(n);
} else {
return block;
diff --git a/src/tir/transform/ir_utils.cc b/src/tir/transform/ir_utils.cc
index 4134202b78..f661722e2f 100644
--- a/src/tir/transform/ir_utils.cc
+++ b/src/tir/transform/ir_utils.cc
@@ -26,6 +26,7 @@
#include <tvm/arith/analyzer.h>
#include <tvm/arith/int_solver.h>
#include <tvm/ffi/reflection/registry.h>
+#include <tvm/s_tir/stmt.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
@@ -744,7 +745,7 @@ class StorageAlignCollector : public StmtVisitor {
/*! \brief For s-stir, the alignment annotations reside in block annotations. */
void VisitStmt_(const SBlockNode* op) final {
- auto it = op->annotations.find(attr::buffer_dim_align);
+ auto it = op->annotations.find(s_tir::attr::buffer_dim_align);
if (it != op->annotations.end()) {
auto storage_align_annotation = Downcast<StorageAlignAnnotation>((*it).second);
for (const auto& storage_align_tuple : storage_align_annotation) {
@@ -758,7 +759,7 @@ class StorageAlignCollector : public StmtVisitor {
/*! \brief For lowered tir, the alignment annotations reside in allocate annotations. */
void VisitStmt_(const AllocateNode* op) final {
- auto it = op->annotations.find(attr::buffer_dim_align);
+ auto it = op->annotations.find(s_tir::attr::buffer_dim_align);
if (it != op->annotations.end()) {
auto storage_align_annotation = Downcast<StorageAlignAnnotation>((*it).second);
for (const auto& storage_align_tuple : storage_align_annotation) {