This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 521440ea8e [REFACTOR][TIR] Cleanup AttrStmt attributes (#18862)
521440ea8e is described below
commit 521440ea8e025492b7a81ca906be5016b8ef9095
Author: Tianqi Chen <[email protected]>
AuthorDate: Mon Mar 2 14:03:20 2026 -0500
[REFACTOR][TIR] Cleanup AttrStmt attributes (#18862)
## Summary
- Phase out 12 unused `tir::attr` constants (`scan_*`, `channel_*`,
`pipeline_*`, `buffer_bind_scope`, `coproc_*`, `loop_scope`) and remove
their dead code paths
- Move 11 S-TIR-owned attributes (`async_*`, `double_buffer_*`,
`fragment_*`, `pragma_loop_partition_hint`, `reduce_scope`,
`virtual_thread`) from `tir::attr` to `s_tir::attr`
- Alphabetize the remaining 15 `tir::attr` constants
---
include/tvm/s_tir/stmt.h | 43 +++++++
include/tvm/tir/stmt.h | 135 ++++-----------------
src/arith/ir_mutator_with_analyzer.cc | 3 +-
src/arith/ir_visitor_with_analyzer.cc | 3 +-
src/s_tir/analysis/verify_gpu_code.cc | 3 +-
.../postproc/disallow_async_strided_mem_copy.cc | 3 +-
.../multi_level_tiling_tensor_core.cc | 4 +-
src/s_tir/transform/annotate_irregular_loop.cc | 2 +-
src/s_tir/transform/compact_buffer_region.cc | 2 +-
src/s_tir/transform/inject_double_buffer.cc | 9 +-
src/s_tir/transform/inject_ptx_async_copy.cc | 3 +-
src/s_tir/transform/inject_software_pipeline.cc | 10 +-
src/s_tir/transform/inject_virtual_thread.cc | 7 +-
src/s_tir/transform/loop_partition.cc | 5 +-
src/s_tir/transform/lower_async_dma.cc | 7 +-
.../transform/lower_cross_thread_reduction.cc | 3 +-
src/s_tir/transform/lower_opaque_block.cc | 2 +-
src/s_tir/transform/lower_thread_allreduce.cc | 3 +-
.../transform/merge_shared_memory_allocations.cc | 3 +-
src/s_tir/transform/storage_access.cc | 7 +-
src/s_tir/transform/tensorcore_infer_fragment.cc | 5 +-
src/s_tir/transform/thread_storage_sync.cc | 10 +-
src/s_tir/transform/unify_thread_binding.cc | 3 +-
src/target/llvm/codegen_cpu.cc | 6 +-
src/target/source/codegen_cuda.cc | 9 +-
src/target/spirv/codegen_spirv.cc | 6 +-
src/tir/analysis/verify_memory.cc | 3 +-
src/tir/analysis/verify_well_formed.cc | 1 -
src/tir/ir/data_type_rewriter.cc | 5 +-
src/tir/ir/tir_visitor_with_path.cc | 24 +---
src/tir/transform/annotate_device_regions.cc | 3 +-
src/tir/transform/bind_target.cc | 5 +-
src/tir/transform/ir_utils.cc | 4 +-
src/tir/transform/ir_utils.h | 3 +-
src/tir/transform/lower_warp_memory.cc | 3 +-
src/tir/transform/narrow_datatype.cc | 3 +-
src/tir/transform/remove_no_op.cc | 3 +-
src/tir/transform/storage_rewrite.cc | 7 +-
.../codegen/test_target_codegen_static_init.py | 26 ----
.../test_tir_analysis_verify_well_formed.py | 53 --------
40 files changed, 149 insertions(+), 290 deletions(-)
diff --git a/include/tvm/s_tir/stmt.h b/include/tvm/s_tir/stmt.h
index 2e6b2278f0..86435f94b6 100644
--- a/include/tvm/s_tir/stmt.h
+++ b/include/tvm/s_tir/stmt.h
@@ -31,6 +31,49 @@ namespace tvm {
namespace s_tir {
namespace attr {
+/*!
+ * \brief Annotations for invoking and synchronizing asynchronous operations.
+ */
+constexpr const char* async_commit_queue_scope = "async_commit_queue_scope";
+constexpr const char* async_wait_queue_scope = "async_wait_queue_scope";
+constexpr const char* async_wait_inflight_count = "async_wait_inflight_count";
+
+/*!
+ * \brief Mark that the attached statement runs asynchronously.
+ */
+constexpr const char* async_scope = "async_scope";
+
+/*!
+ * \brief Marks production of double buffer data
+ */
+constexpr const char* double_buffer_scope = "double_buffer_scope";
+
+/*!
+ * \brief Marks region used by double buffer write
+ */
+constexpr const char* double_buffer_write = "double_buffer_write";
+
+/*!
+ * \brief Mark the shape of a TensorCore fragment
+ */
+constexpr const char* fragment_shape = "fragment_shape";
+
+/*!
+ * \brief Mark the layout of a TensorCore fragment
+ */
+constexpr const char* fragment_layout = "fragment_layout";
+
+/*!
+ * \brief Mark that the loop should be partitioned.
+ */
+constexpr const char* pragma_loop_partition_hint =
"pragma_loop_partition_hint";
+
+/*! \brief Mark of reduce scope */
+constexpr const char* reduce_scope = "reduce_scope";
+
+/*! \brief Mark launching of a virtual thread. */
+constexpr const char* virtual_thread = "virtual_thread";
+
// -----------------------------------------------------------------------
// meta_schedule annotations
// -----------------------------------------------------------------------
diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h
index b6424c1da3..dcf1274bcf 100644
--- a/include/tvm/tir/stmt.h
+++ b/include/tvm/tir/stmt.h
@@ -916,136 +916,43 @@ class SBlockRealize : public Stmt {
/*! \brief namespace of possible attributes in AttrStmt.attr_key */
namespace attr {
-// The above attr does not pass to ir stage.
-/*! \brief Mark launching extent of thread, used by device API. */
-constexpr const char* thread_extent = "thread_extent";
-/*! \brief Mark launching of a virtual thread. */
-constexpr const char* virtual_thread = "virtual_thread";
-/*! \brief Mark region is processed by a co-processor */
-constexpr const char* coproc_scope = "coproc_scope";
-/*!
- * \brief Mark region creates coprocessor micro ops,
- * can be reused if corresponding variable is independent.
- */
-constexpr const char* coproc_uop_scope = "coproc_uop_scope";
-/*! \brief Mark the scope as volatile access for certain handle. */
-constexpr const char* volatile_scope = "volatile_scope";
-/*!
- * \brief Mark the scope as generated by extern primitive.
- * such scope can contain arbitrary ir program and we need to be careful
- * when make certain assumptions about the structure of the program.
- */
-constexpr const char* extern_scope = "extern_scope";
+/*! \brief Mark stores/loads with their bounds. */
+constexpr const char* buffer_bound = "buffer_bound";
/*!
- * \brief Mark the scope as when computation start to happen
+ * \brief Mark the scope where computation starts to happen.
* This can hint some code generator to create a new function for compute.
*/
constexpr const char* compute_scope = "compute_scope";
-/*! \brief Mark storage alignment requirement of buffers */
-constexpr const char* storage_alignment = "storage_alignment";
/*! \brief The allocation device for global malloc in host. */
constexpr const char* device_id = "device_id";
+/*! \brief Mark that it is in the device scope. */
+constexpr const char* device_scope = "device_scope";
/*! \brief The device type. */
constexpr const char* device_type = "device_type";
-/*! \brief Mark of loop scope */
-constexpr const char* loop_scope = "loop_scope";
-/*! \brief Mark of reduce scope */
-constexpr const char* reduce_scope = "reduce_scope";
+/*!
+ * \brief Mark the scope as generated by extern primitive.
+ * Such scope can contain arbitrary ir program and we need to be careful
+ * when making certain assumptions about the structure of the program.
+ */
+constexpr const char* extern_scope = "extern_scope";
/*! \brief Pragma: auto-unroll, max_step */
constexpr const char* pragma_auto_unroll_max_step =
"pragma_auto_unroll_max_step";
-/*! \brief Pragma: unroll explicit */
-constexpr const char* pragma_unroll_explicit = "pragma_unroll_explicit";
-/*! \brief Mark region is guarded by the pragma extension */
-constexpr const char* pragma_scope_prefix = "pragma_";
/*! \brief Import C source or file into the final code gen module */
constexpr const char* pragma_import_c = "pragma_import_c";
/*! \brief Import llvm source or file into the final code gen module */
constexpr const char* pragma_import_llvm = "pragma_import_llvm";
+/*! \brief Mark region is guarded by the pragma extension */
+constexpr const char* pragma_scope_prefix = "pragma_";
/*! \brief Try to modify the AST to support Tensor Core */
constexpr const char* pragma_tensor_core = "pragma_tensor_core";
-/*!
- * \brief Marks production of double buffer data
- */
-constexpr const char* double_buffer_scope = "double_buffer_scope";
-/*!
- * \brief Marks region used by double buffer write
- */
-constexpr const char* double_buffer_write = "double_buffer_write";
-/*! \brief Mark of scan update scope */
-constexpr const char* scan_update_scope = "scan_update_scope";
-/*! \brief Mark of scan init scope */
-constexpr const char* scan_init_scope = "scan_init_scope";
-/*! \brief Mark stores/loads with theirs bounds. */
-constexpr const char* buffer_bound = "buffer_bound";
-/*!
- * \brief Bind the buffer specification to the region of the op
- * When this scope occurs, the stmt.node is a ffi::Array<NodeRef> = [buffer,
tensor]
- * stmt.value is a tvm_tuple(min0, extent0, min1, extent1, ...).
- * The scope represents that we need to bind the storage region of tensor to
buffer.
- * This will affect replacement of some variables inside the scope that
- * corresponds to field of buffer to be the actual expressions of tensor
during
- * storage flattening phase.
- */
-constexpr const char* buffer_bind_scope = "buffer_bind_scope";
-// Pipeline related attributes
-/*! \brief channel read scope */
-constexpr const char* channel_read_scope = "channel_read_scope";
-/*! \brief Advance step of channel after end of scope */
-constexpr const char* channel_read_advance = "channel_read_advance";
-/*! \brief channel write scope */
-constexpr const char* channel_write_scope = "channel_write_scope";
-/*! \brief Advance step of channel after end of scope */
-constexpr const char* channel_write_advance = "channel_write_advance";
-/*! \brief pipeline stage scope, implies always execution */
-constexpr const char* pipeline_stage_scope = "pipeline_stage_scope";
-/*! \brief pipeline execution scope, implies the scope can be pipelined. */
-constexpr const char* pipeline_exec_scope = "pipeline_exec_scope";
-
-/*!
- * \brief Mark that it is in the device scope.
- */
-constexpr const char* device_scope = "device_scope";
-
-/*!
- * \brief Mark that the attached statement runs asynchronously.
- */
-constexpr const char* async_scope = "async_scope";
-
-/*!
- * \brief Annotations for invoking and synchronizing asynchronous operations.
-
- * Synchronization is done in terms of "queue": It is an abstract entity
associated
- * with each asynchronous unit, and it tracks invocations and completions of
asynchronous
- * operations in the FIFO order.
- *
- * Similarly to PTX instructions commit_group and wait_group, these
annotations express
- * synchronization by "counting":
- *
- * async_commit_queue(i): Group one or more invocations of async operations in
the given scope,
- * and "commit" (or push) them to the queue i. A group of operations committed
together is
- * awaited as one chunk. Groups committed to the same queue complete in the
FIFO order.
- *
- * async_wait_queue(i, N): Block until only N most recent committed groups are
still in-flight at
- * the queue i. N does not have to be a constant, but some backends may
require a constant count.
-*/
-constexpr const char* async_commit_queue_scope = "async_commit_queue_scope";
-constexpr const char* async_wait_queue_scope = "async_wait_queue_scope";
-constexpr const char* async_wait_inflight_count = "async_wait_inflight_count";
-
-/*!
- * \brief Mark that the shape of TensorCore fragment
- */
-constexpr const char* fragment_shape = "fragment_shape";
-
-/*!
- * \brief Mark that the layout of TensorCore fragment
- */
-constexpr const char* fragment_layout = "fragment_layout";
-
-/*!
- * \brief Mark that the loop should be partitioned.
- */
-constexpr const char* pragma_loop_partition_hint =
"pragma_loop_partition_hint";
+/*! \brief Pragma: unroll explicit */
+constexpr const char* pragma_unroll_explicit = "pragma_unroll_explicit";
+/*! \brief Mark storage alignment requirement of buffers */
+constexpr const char* storage_alignment = "storage_alignment";
+/*! \brief Mark launching extent of thread, used by device API. */
+constexpr const char* thread_extent = "thread_extent";
+/*! \brief Mark the scope as volatile access for certain handle. */
+constexpr const char* volatile_scope = "volatile_scope";
/*!
* \brief Check if attr_key is a pragma key extension
diff --git a/src/arith/ir_mutator_with_analyzer.cc
b/src/arith/ir_mutator_with_analyzer.cc
index 7ba8059336..f6c8db0161 100644
--- a/src/arith/ir_mutator_with_analyzer.cc
+++ b/src/arith/ir_mutator_with_analyzer.cc
@@ -23,6 +23,7 @@
#include "ir_mutator_with_analyzer.h"
#include <tvm/arith/iter_affine_map.h>
+#include <tvm/s_tir/stmt.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/op.h>
@@ -142,7 +143,7 @@ Stmt IRMutatorWithAnalyzer::VisitStmt_(const
IfThenElseNode* op) {
Stmt IRMutatorWithAnalyzer::VisitStmt_(const AttrStmtNode* op) {
return constraint_scope_.WithNewScope([&]() -> Stmt {
- if (op->attr_key == tir::attr::thread_extent || op->attr_key ==
tir::attr::virtual_thread) {
+ if (op->attr_key == tir::attr::thread_extent || op->attr_key ==
s_tir::attr::virtual_thread) {
IterVar iv = Downcast<IterVar>(op->node);
TVM_FFI_ICHECK_NE(iv->thread_tag.length(), 0U);
Range dom = Range::FromMinExtent(make_zero(op->value.dtype()),
op->value);
diff --git a/src/arith/ir_visitor_with_analyzer.cc
b/src/arith/ir_visitor_with_analyzer.cc
index 54464ad70f..736e148d7a 100644
--- a/src/arith/ir_visitor_with_analyzer.cc
+++ b/src/arith/ir_visitor_with_analyzer.cc
@@ -22,6 +22,7 @@
*/
#include "ir_visitor_with_analyzer.h"
+#include <tvm/s_tir/stmt.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
@@ -75,7 +76,7 @@ void IRVisitorWithAnalyzer::VisitStmt_(const IfThenElseNode*
op) {
void IRVisitorWithAnalyzer::VisitStmt_(const AttrStmtNode* op) {
constraint_scope_.WithNewScope([&]() {
- if (op->attr_key == tir::attr::thread_extent || op->attr_key ==
tir::attr::virtual_thread) {
+ if (op->attr_key == tir::attr::thread_extent || op->attr_key ==
s_tir::attr::virtual_thread) {
IterVar iv = Downcast<IterVar>(op->node);
TVM_FFI_ICHECK_NE(iv->thread_tag.length(), 0U);
analyzer_.Bind(iv->var, Range::FromMinExtent(IntImm(op->value->dtype,
0), op->value));
diff --git a/src/s_tir/analysis/verify_gpu_code.cc
b/src/s_tir/analysis/verify_gpu_code.cc
index ed7853a99c..3cfb41769d 100644
--- a/src/s_tir/analysis/verify_gpu_code.cc
+++ b/src/s_tir/analysis/verify_gpu_code.cc
@@ -26,6 +26,7 @@
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
+#include <tvm/s_tir/stmt.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt.h>
@@ -86,7 +87,7 @@ class GPUCodeVerifier : public StmtExprVisitor {
}
void VisitStmt_(const AttrStmtNode* op) final {
- if (op->attr_key == tir::attr::thread_extent || op->attr_key ==
tir::attr::virtual_thread) {
+ if (op->attr_key == tir::attr::thread_extent || op->attr_key ==
s_tir::attr::virtual_thread) {
if (nest_level_ == 0) {
// enter a new kernel, reset statistics
Reset_();
diff --git
a/src/s_tir/meta_schedule/postproc/disallow_async_strided_mem_copy.cc
b/src/s_tir/meta_schedule/postproc/disallow_async_strided_mem_copy.cc
index 753c929dd6..2222289d5f 100644
--- a/src/s_tir/meta_schedule/postproc/disallow_async_strided_mem_copy.cc
+++ b/src/s_tir/meta_schedule/postproc/disallow_async_strided_mem_copy.cc
@@ -17,6 +17,7 @@
* under the License.
*/
#include <tvm/ffi/reflection/registry.h>
+#include <tvm/s_tir/stmt.h>
#include <tvm/s_tir/transform.h>
#include "../utils.h"
@@ -51,7 +52,7 @@ struct AsyncStridedMemCopyFinder : private StmtExprVisitor {
void VisitStmt_(const AttrStmtNode* attrStmt) final {
if (!found_) {
- if (attrStmt->attr_key == tir::attr::async_commit_queue_scope) {
+ if (attrStmt->attr_key == s_tir::attr::async_commit_queue_scope) {
auto async_scope = attrStmt->body.as<AttrStmtNode>();
if (!async_scope) {
StmtExprVisitor::VisitStmt_(attrStmt);
diff --git
a/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc
b/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc
index 6b2338210d..37b6d14f2a 100644
--- a/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc
+++ b/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc
@@ -670,12 +670,12 @@ std::vector<State>
MultiLevelTilingTensorCoreNode::AddSoftwarePipeline(
sch->Annotate(cache_read, s_tir::attr::vector_bytes, Integer(16));
if (!state->use_async) {
sch->Annotate(cache_read, s_tir::attr::local_stage, Integer(1));
- sch->Annotate(cache_read, tir::attr::double_buffer_scope, Integer(0));
+ sch->Annotate(cache_read, s_tir::attr::double_buffer_scope,
Integer(0));
}
} else {
// Add local stage and double buffering
sch->Annotate(cache_read,
s_tir::attr::manifest_shared_memory_local_stage, Integer(1));
- sch->Annotate(cache_read, tir::attr::double_buffer_scope, Integer(0));
+ sch->Annotate(cache_read, s_tir::attr::double_buffer_scope, Integer(0));
}
}
diff --git a/src/s_tir/transform/annotate_irregular_loop.cc
b/src/s_tir/transform/annotate_irregular_loop.cc
index f496a55041..711d87c3af 100644
--- a/src/s_tir/transform/annotate_irregular_loop.cc
+++ b/src/s_tir/transform/annotate_irregular_loop.cc
@@ -46,7 +46,7 @@ class IrregularLoopAnnotator : public StmtMutator {
<< "Loop kind " << op->kind << " is invalid for irregular loop " <<
op->loop_var;
for (const char* key :
{tir::attr::pragma_auto_unroll_max_step,
tir::attr::pragma_unroll_explicit,
- tir::attr::pragma_loop_partition_hint,
s_tir::attr::software_pipeline_stage}) {
+ s_tir::attr::pragma_loop_partition_hint,
s_tir::attr::software_pipeline_stage}) {
TVM_FFI_ICHECK(!res->annotations.count(key))
<< "Annotation `" << key << "` is invalid for irregular loop " <<
op->loop_var;
}
diff --git a/src/s_tir/transform/compact_buffer_region.cc
b/src/s_tir/transform/compact_buffer_region.cc
index 6f78dfafc9..eaebca0f4a 100644
--- a/src/s_tir/transform/compact_buffer_region.cc
+++ b/src/s_tir/transform/compact_buffer_region.cc
@@ -320,7 +320,7 @@ class BufferAccessRegionCollector : public StmtExprVisitor {
}
void VisitStmt_(const AttrStmtNode* op) final {
- if (op->attr_key == tir::attr::thread_extent || op->attr_key ==
tir::attr::virtual_thread) {
+ if (op->attr_key == tir::attr::thread_extent || op->attr_key ==
s_tir::attr::virtual_thread) {
IterVar iter = Downcast<IterVar>(op->node);
ancestor_iters_.push_back(iter);
Range dom = iter->dom;
diff --git a/src/s_tir/transform/inject_double_buffer.cc
b/src/s_tir/transform/inject_double_buffer.cc
index 99da2ac51b..a10a338bfe 100644
--- a/src/s_tir/transform/inject_double_buffer.cc
+++ b/src/s_tir/transform/inject_double_buffer.cc
@@ -23,6 +23,7 @@
*/
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
+#include <tvm/s_tir/stmt.h>
#include <tvm/s_tir/transform.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
@@ -60,7 +61,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION("s_tir.InjectDoubleBuffer",
InjectDoubleBufferCo
class DoubleBufferDetector : public StmtExprVisitor {
public:
void VisitStmt_(const AttrStmtNode* op) final {
- if (op->attr_key == tir::attr::double_buffer_scope) {
+ if (op->attr_key == s_tir::attr::double_buffer_scope) {
touched_.insert(op->node.as<VarNode>());
StmtExprVisitor::VisitStmt_(op);
} else {
@@ -80,7 +81,7 @@ class DoubleBufferDetector : public StmtExprVisitor {
class StripDoubleBufferWrite : public StmtMutator {
public:
Stmt VisitStmt_(const AttrStmtNode* op) final {
- if (op->attr_key == tir::attr::double_buffer_write) {
+ if (op->attr_key == s_tir::attr::double_buffer_write) {
return VisitStmt(op->body);
} else {
return StmtMutator::VisitStmt_(op);
@@ -103,7 +104,7 @@ class DoubleBufferInjector : public StmtExprMutator {
}
Stmt VisitStmt_(const AttrStmtNode* op) final {
- if (op->attr_key == tir::attr::double_buffer_scope) {
+ if (op->attr_key == s_tir::attr::double_buffer_scope) {
return MakeProducer(op);
} else {
return StmtExprMutator::VisitStmt_(op);
@@ -279,7 +280,7 @@ class DoubleBufferInjector : public StmtExprMutator {
vmap[e.loop->loop_var.get()] = loop_shift;
vmap[e.switch_write_var.get()] = indexmod(loop_shift, two);
body = Substitute(body, vmap);
- body = AttrStmt(buffer, tir::attr::double_buffer_write, 1, body);
+ body = AttrStmt(buffer, s_tir::attr::double_buffer_write, 1, body);
body = IfThenElse(loop_shift < e.loop->extent, body);
return body;
}
diff --git a/src/s_tir/transform/inject_ptx_async_copy.cc
b/src/s_tir/transform/inject_ptx_async_copy.cc
index 6e0257d248..c0632fdd23 100644
--- a/src/s_tir/transform/inject_ptx_async_copy.cc
+++ b/src/s_tir/transform/inject_ptx_async_copy.cc
@@ -31,6 +31,7 @@
#include "../../tir/ir/buffer_common.h"
#include "storage_access.h"
+#include "tvm/s_tir/stmt.h"
#include "tvm/tir/stmt.h"
namespace tvm {
@@ -40,7 +41,7 @@ using namespace tvm::tir;
class PTXAsyncCopyInjector : public StmtMutator {
public:
Stmt VisitStmt_(const AttrStmtNode* attr) {
- if (attr->attr_key == tir::attr::async_scope) {
+ if (attr->attr_key == s_tir::attr::async_scope) {
TVM_FFI_ICHECK(in_async == false) << "Nested async scopes not supported";
in_async = true;
auto body = this->VisitStmt(attr->body);
diff --git a/src/s_tir/transform/inject_software_pipeline.cc
b/src/s_tir/transform/inject_software_pipeline.cc
index 1e1bb446e4..ac61801881 100644
--- a/src/s_tir/transform/inject_software_pipeline.cc
+++ b/src/s_tir/transform/inject_software_pipeline.cc
@@ -755,8 +755,8 @@ class PipelineRewriter : public StmtExprMutator {
SBlockNode* n = block.CopyOnWrite();
auto zero = make_zero(DataType::Int(32));
n->body =
- AttrStmt(zero, tir::attr::async_wait_queue_scope, stage_id,
- AttrStmt(zero, tir::attr::async_wait_inflight_count,
wait_count, n->body));
+ AttrStmt(zero, s_tir::attr::async_wait_queue_scope, stage_id,
+ AttrStmt(zero, s_tir::attr::async_wait_inflight_count,
wait_count, n->body));
};
if (state.predicate &&
!ana_normalized->CanProve(state.predicate.value())) {
@@ -798,7 +798,7 @@ class PipelineRewriter : public StmtExprMutator {
for (auto body : group_bodies) {
auto commit_queue_scope = AttrStmt(make_zero(DataType::Int(32)),
-
tir::attr::async_commit_queue_scope, stage_id, body);
+
s_tir::attr::async_commit_queue_scope, stage_id, body);
auto new_block = MakeSBlock(commit_queue_scope,
buffer_data_to_buffer_);
stmts.push_back(SBlockRealize({}, predicate, new_block));
}
@@ -914,7 +914,7 @@ class PipelineRewriter : public StmtExprMutator {
}
SBlockNode* n = new_block.CopyOnWrite();
- n->body = AttrStmt(make_zero(DataType::Int(32)),
tir::attr::async_scope, 1, n->body);
+ n->body = AttrStmt(make_zero(DataType::Int(32)),
s_tir::attr::async_scope, 1, n->body);
}
new_blocks.push_back(
@@ -1212,7 +1212,7 @@ class PipelineInjector : private StmtExprMutator {
buffer_data_to_buffer_.Set(buffer->data, buffer);
}
- auto it = op->annotations.find(tir::attr::double_buffer_scope);
+ auto it = op->annotations.find(s_tir::attr::double_buffer_scope);
if (it != op->annotations.end()) {
int buffer_index = Downcast<Integer>((*it).second).IntValue();
TVM_FFI_CHECK(buffer_index >= 0 && static_cast<size_t>(buffer_index) <
op->writes.size(),
diff --git a/src/s_tir/transform/inject_virtual_thread.cc
b/src/s_tir/transform/inject_virtual_thread.cc
index b644dffe8d..cd1bca9a0b 100644
--- a/src/s_tir/transform/inject_virtual_thread.cc
+++ b/src/s_tir/transform/inject_virtual_thread.cc
@@ -22,6 +22,7 @@
*/
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
+#include <tvm/s_tir/stmt.h>
#include <tvm/s_tir/transform.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
@@ -290,10 +291,6 @@ class VTInjector : public arith::IRMutatorWithAnalyzer {
PrimExpr value = this->VisitExpr(op->value);
if (visit_touched_var_ && !vt_loop_injected_) {
return InjectVTLoop(ffi::GetRef<Stmt>(op), true);
- } else if (!allow_share_ && !vt_loop_injected_ &&
- (op->attr_key == tir::attr::coproc_uop_scope ||
- op->attr_key == tir::attr::coproc_scope)) {
- return InjectVTLoop(ffi::GetRef<Stmt>(op), true);
} else {
Stmt body = this->VisitStmt(op->body);
if (value.same_as(op->value) && body.same_as(op->body)) {
@@ -498,7 +495,7 @@ class VirtualThreadInjector : public
arith::IRMutatorWithAnalyzer {
Stmt VisitStmt_(const AttrStmtNode* op) final {
Stmt stmt = StmtMutator::VisitStmt_(op);
op = stmt.as<AttrStmtNode>();
- if (op->attr_key == tir::attr::virtual_thread) {
+ if (op->attr_key == s_tir::attr::virtual_thread) {
IterVar iv = Downcast<IterVar>(op->node);
bool allow_share = std::string(iv->thread_tag).substr(0, 7) == "vthread";
int nthread = static_cast<int>(op->value.as<IntImmNode>()->value);
diff --git a/src/s_tir/transform/loop_partition.cc
b/src/s_tir/transform/loop_partition.cc
index 8020e97867..2afa173319 100644
--- a/src/s_tir/transform/loop_partition.cc
+++ b/src/s_tir/transform/loop_partition.cc
@@ -24,6 +24,7 @@
#include <tvm/arith/bound.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
+#include <tvm/s_tir/stmt.h>
#include <tvm/s_tir/transform.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
@@ -152,7 +153,7 @@ class CandidateSelector final : public StmtExprVisitor {
record_.erase(var.get());
return;
}
- } else if (op->attr_key == tir::attr::pragma_loop_partition_hint) {
+ } else if (op->attr_key == s_tir::attr::pragma_loop_partition_hint) {
if (analyzer_.CanProve(op->value)) {
const VarNode* var = nullptr;
if (op->node.as<VarNode>()) {
@@ -791,7 +792,7 @@ class RemoveLikelyTagsAndHints : public StmtExprMutator {
}
Stmt VisitStmt_(const AttrStmtNode* op) final {
- if (op->attr_key == tir::attr::pragma_loop_partition_hint) {
+ if (op->attr_key == s_tir::attr::pragma_loop_partition_hint) {
return VisitStmt(op->body);
}
return StmtExprMutator::VisitStmt_(op);
diff --git a/src/s_tir/transform/lower_async_dma.cc
b/src/s_tir/transform/lower_async_dma.cc
index d00824eed4..fb6e9260ee 100644
--- a/src/s_tir/transform/lower_async_dma.cc
+++ b/src/s_tir/transform/lower_async_dma.cc
@@ -26,6 +26,7 @@
#include <tvm/arith/iter_affine_map.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/s_tir/analysis.h>
+#include <tvm/s_tir/stmt.h>
#include <tvm/s_tir/transform.h>
#include <tvm/tir/buffer.h>
#include <tvm/tir/stmt.h>
@@ -93,7 +94,7 @@ class AsyncDMALowerer : public arith::IRMutatorWithAnalyzer {
// 0, /* in flight count */
// dtype=int32
// )
- if (op->attr_key == tir::attr::async_wait_queue_scope) {
+ if (op->attr_key == s_tir::attr::async_wait_queue_scope) {
// get queue ID
auto queue_id_node = op->value.as<IntImmNode>();
TVM_FFI_ICHECK(queue_id_node);
@@ -108,7 +109,7 @@ class AsyncDMALowerer : public arith::IRMutatorWithAnalyzer
{
}
auto async_wait = op->body.as<AttrStmtNode>();
- if (!async_wait || async_wait->attr_key !=
tir::attr::async_wait_inflight_count) {
+ if (!async_wait || async_wait->attr_key !=
s_tir::attr::async_wait_inflight_count) {
DLOG(INFO) << "AsyncDMALowerer exiting because the body of the
`AttrStmtNode` with key "
"`async_wait_queue_scope` does not contain an
`AttrStmtNode` with key "
"`async_wait_inflight_count`";
@@ -135,7 +136,7 @@ class AsyncDMALowerer : public arith::IRMutatorWithAnalyzer
{
// 128, /* size */
// dtype=int32
// )
- } else if (op->attr_key == tir::attr::async_commit_queue_scope) {
+ } else if (op->attr_key == s_tir::attr::async_commit_queue_scope) {
// get queue ID
auto queue_id_node = op->value.as<IntImmNode>();
TVM_FFI_ICHECK(queue_id_node);
diff --git a/src/s_tir/transform/lower_cross_thread_reduction.cc
b/src/s_tir/transform/lower_cross_thread_reduction.cc
index c3c1f2ab3a..03df34b617 100644
--- a/src/s_tir/transform/lower_cross_thread_reduction.cc
+++ b/src/s_tir/transform/lower_cross_thread_reduction.cc
@@ -22,6 +22,7 @@
*/
#include <tvm/arith/analyzer.h>
#include <tvm/ffi/reflection/registry.h>
+#include <tvm/s_tir/stmt.h>
#include <tvm/s_tir/transform.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/stmt_functor.h>
@@ -418,7 +419,7 @@ Stmt TransformReductionBlock(const SBlockRealizeNode*
realize,
/*name_hint=*/block->name_hint + "_cross_thread",
/*body=*/
AttrStmt(/*node=*/reducer,
- /*attr_key=*/tir::attr::reduce_scope,
+ /*attr_key=*/s_tir::attr::reduce_scope,
/*value=*/make_zero(DataType::Handle()),
/*body=*/
Evaluate(Call(/*dtype=*/DataType::Handle(),
diff --git a/src/s_tir/transform/lower_opaque_block.cc
b/src/s_tir/transform/lower_opaque_block.cc
index cde6c28c8d..507dae9014 100644
--- a/src/s_tir/transform/lower_opaque_block.cc
+++ b/src/s_tir/transform/lower_opaque_block.cc
@@ -146,7 +146,7 @@ class OpaqueBlockLower : public StmtExprMutator {
/*thread_tag=*/thread_tag);
ffi::String attr_key = (thread_tag == "vthread" || thread_tag ==
"vthread.x" ||
thread_tag == "vthread.y" || thread_tag ==
"vthread.z")
- ? tir::attr::virtual_thread
+ ? s_tir::attr::virtual_thread
: tir::attr::thread_extent;
return AttrStmt(/*node=*/std::move(iter_var),
/*attr_key=*/std::move(attr_key),
diff --git a/src/s_tir/transform/lower_thread_allreduce.cc
b/src/s_tir/transform/lower_thread_allreduce.cc
index 39ccef4721..b3fb6c6d6c 100644
--- a/src/s_tir/transform/lower_thread_allreduce.cc
+++ b/src/s_tir/transform/lower_thread_allreduce.cc
@@ -24,6 +24,7 @@
#include <tvm/arith/analyzer.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
+#include <tvm/s_tir/stmt.h>
#include <tvm/s_tir/transform.h>
#include <tvm/target/target.h>
#include <tvm/tir/builtin.h>
@@ -53,7 +54,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
Stmt ret = StmtExprMutator::VisitStmt_(op);
thread_extents_.pop_back();
return ret;
- } else if (op->attr_key == tir::attr::reduce_scope) {
+ } else if (op->attr_key == s_tir::attr::reduce_scope) {
const CommReducerNode* combiner = op->node.as<CommReducerNode>();
TVM_FFI_ICHECK(combiner);
reduce_combiner_.push_back(combiner);
diff --git a/src/s_tir/transform/merge_shared_memory_allocations.cc
b/src/s_tir/transform/merge_shared_memory_allocations.cc
index 6463f43125..1975091c24 100644
--- a/src/s_tir/transform/merge_shared_memory_allocations.cc
+++ b/src/s_tir/transform/merge_shared_memory_allocations.cc
@@ -25,6 +25,7 @@
*/
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
+#include <tvm/s_tir/stmt.h>
#include <tvm/s_tir/transform.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
@@ -217,7 +218,7 @@ class SharedMemLinearAccessPatternFinder final : public
StmtExprVisitor {
in_thread_env_ = false;
} else if (op->attr_key == tir::attr::extern_scope) {
VisitNewScope(op);
- } else if (op->attr_key == tir::attr::virtual_thread) {
+ } else if (op->attr_key == s_tir::attr::virtual_thread) {
VisitNewScope(op);
} else {
StmtExprVisitor::VisitStmt_(op);
diff --git a/src/s_tir/transform/storage_access.cc
b/src/s_tir/transform/storage_access.cc
index af45cd781b..dae797486d 100644
--- a/src/s_tir/transform/storage_access.cc
+++ b/src/s_tir/transform/storage_access.cc
@@ -110,7 +110,7 @@ void StorageAccessVisitor::VisitStmt_(const LetStmtNode*
op) {
}
void StorageAccessVisitor::VisitStmt_(const AttrStmtNode* op) {
- if (op->attr_key == tir::attr::double_buffer_write) {
+ if (op->attr_key == s_tir::attr::double_buffer_write) {
TVM_FFI_ICHECK(double_buffer_write_ == nullptr);
double_buffer_write_ = op->node.as<VarNode>();
scope_.push_back(std::vector<StmtEntry>());
@@ -128,11 +128,6 @@ void StorageAccessVisitor::VisitStmt_(const AttrStmtNode*
op) {
scope_.back().emplace_back(std::move(s));
}
double_buffer_write_ = nullptr;
- } else if (op->attr_key == tir::attr::coproc_scope) {
- IterVar iv = Downcast<IterVar>(op->node);
- env_threads_.push_back(iv);
- StmtExprVisitor::VisitStmt_(op);
- env_threads_.pop_back();
} else if (op->attr_key == tir::attr::thread_extent) {
IterVar iv = Downcast<IterVar>(op->node);
env_threads_.push_back(iv);
diff --git a/src/s_tir/transform/tensorcore_infer_fragment.cc
b/src/s_tir/transform/tensorcore_infer_fragment.cc
index 428f8f6f54..26bd96cf40 100644
--- a/src/s_tir/transform/tensorcore_infer_fragment.cc
+++ b/src/s_tir/transform/tensorcore_infer_fragment.cc
@@ -23,6 +23,7 @@
*/
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
+#include <tvm/s_tir/stmt.h>
#include <tvm/s_tir/transform.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
@@ -187,10 +188,10 @@ class InferFragmenter : public StmtMutator {
std::string shape =
std::to_string(info.m) + ", " + std::to_string(info.n) + ", " +
std::to_string(info.k);
PrimExpr shape_expr = StringImm(shape);
- Stmt shape_attr = AttrStmt(op->buffer_var, tir::attr::fragment_shape,
shape_expr, stmt);
+ Stmt shape_attr = AttrStmt(op->buffer_var, s_tir::attr::fragment_shape,
shape_expr, stmt);
if (info.layout != "") {
// Add shape attribute to matrix_a and matrix_b
- Stmt layout_attr = AttrStmt(op->buffer_var, tir::attr::fragment_layout,
+ Stmt layout_attr = AttrStmt(op->buffer_var,
s_tir::attr::fragment_layout,
StringImm(info.layout), shape_attr);
return layout_attr;
} else {
diff --git a/src/s_tir/transform/thread_storage_sync.cc
b/src/s_tir/transform/thread_storage_sync.cc
index 6892f83a29..2ce3ddd814 100644
--- a/src/s_tir/transform/thread_storage_sync.cc
+++ b/src/s_tir/transform/thread_storage_sync.cc
@@ -22,6 +22,7 @@
*/
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
+#include <tvm/s_tir/stmt.h>
#include <tvm/s_tir/transform.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
@@ -292,15 +293,16 @@ class ThreadSyncAfterWaitQueueInserter : public
StmtExprMutator {
explicit ThreadSyncAfterWaitQueueInserter(StorageScope sync_scope) :
sync_scope_(sync_scope) {}
Stmt VisitStmt_(const AttrStmtNode* op) final {
- if (op->attr_key == tir::attr::async_wait_queue_scope) {
+ if (op->attr_key == s_tir::attr::async_wait_queue_scope) {
auto sync = Evaluate(Call(DataType::Int(32), builtin::tvm_storage_sync(),
{StringImm(sync_scope_.to_string())}));
auto inner = op->body.as<AttrStmtNode>();
- TVM_FFI_ICHECK(inner && inner->attr_key ==
tir::attr::async_wait_inflight_count);
+ TVM_FFI_ICHECK(inner && inner->attr_key ==
s_tir::attr::async_wait_inflight_count);
auto zero = make_zero(DataType::Int(32));
auto new_body = SeqStmt({sync, inner->body});
- return AttrStmt(zero, tir::attr::async_wait_queue_scope, op->value,
- AttrStmt(zero, tir::attr::async_wait_inflight_count,
inner->value, new_body));
+ return AttrStmt(
+ zero, s_tir::attr::async_wait_queue_scope, op->value,
+ AttrStmt(zero, s_tir::attr::async_wait_inflight_count, inner->value,
new_body));
}
return StmtExprMutator::VisitStmt_(op);
}
diff --git a/src/s_tir/transform/unify_thread_binding.cc
b/src/s_tir/transform/unify_thread_binding.cc
index f9b2d131cd..c33380175f 100644
--- a/src/s_tir/transform/unify_thread_binding.cc
+++ b/src/s_tir/transform/unify_thread_binding.cc
@@ -23,6 +23,7 @@
#include <tvm/arith/analyzer.h>
#include <tvm/ffi/reflection/registry.h>
+#include <tvm/s_tir/stmt.h>
#include <tvm/s_tir/transform.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/stmt_functor.h>
@@ -48,7 +49,7 @@ class ThreadBindingUnifier : public StmtExprMutator {
private:
Stmt VisitStmt_(const AttrStmtNode* op) final {
// If this AttrStmt is not thread binding attribute, return as usual.
- if (op->attr_key != tir::attr::thread_extent && op->attr_key !=
tir::attr::virtual_thread) {
+ if (op->attr_key != tir::attr::thread_extent && op->attr_key !=
s_tir::attr::virtual_thread) {
return StmtMutator::VisitStmt_(op);
}
IterVar old_iter_var = Downcast<IterVar>(op->node);
diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc
index 6ffb4fc030..8f9ee62cc8 100644
--- a/src/target/llvm/codegen_cpu.cc
+++ b/src/target/llvm/codegen_cpu.cc
@@ -1127,11 +1127,7 @@ void CodeGenCPU::VisitStmt_(const AssertStmtNode* op) {
void CodeGenCPU::VisitStmt_(const AttrStmtNode* op) {
EmitDebugLocation(op);
- if (op->attr_key == tir::attr::coproc_uop_scope) {
- const StringImmNode* value = op->value.as<StringImmNode>();
- TVM_FFI_ICHECK(value != nullptr);
- this->CreateStaticInit(value->value, op->body);
- } else if (op->attr_key == tir::attr::compute_scope) {
+ if (op->attr_key == tir::attr::compute_scope) {
this->CreateComputeScope(op);
} else if (tir::attr::IsPragmaKey(op->attr_key)) {
if (op->attr_key == "pragma_parallel_stride_pattern") {
diff --git a/src/target/source/codegen_cuda.cc
b/src/target/source/codegen_cuda.cc
index 3d5beacc63..c577f1d5b6 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -25,6 +25,7 @@
#include <tvm/arith/analyzer.h>
#include <tvm/ffi/function.h>
+#include <tvm/s_tir/stmt.h>
#include <tvm/tir/index_map.h>
#include <tvm/tir/stmt_functor.h>
@@ -1350,15 +1351,15 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op,
std::ostream& os) {
}
void CodeGenCUDA::VisitStmt_(const AttrStmtNode* op) {
- if (op->attr_key == tir::attr::fragment_shape) {
+ if (op->attr_key == s_tir::attr::fragment_shape) {
const VarNode* buffer = op->node.as<VarNode>();
const StringImmNode* shape_str = op->value.as<StringImmNode>();
fragment_shapes[buffer] = shape_str->value;
- } else if (op->attr_key == tir::attr::fragment_layout) {
+ } else if (op->attr_key == s_tir::attr::fragment_layout) {
const VarNode* buffer = op->node.as<VarNode>();
const StringImmNode* layout_str = op->value.as<StringImmNode>();
fragment_layouts[buffer] = layout_str->value;
- } else if (op->attr_key == tir::attr::async_commit_queue_scope) {
+ } else if (op->attr_key == s_tir::attr::async_commit_queue_scope) {
const IntImmNode* queue_id = op->value.as<IntImmNode>();
TVM_FFI_ICHECK(queue_id && queue_id->value == 0)
<< "For CUDA, the index of an async queue must be 0.";
@@ -1366,7 +1367,7 @@ void CodeGenCUDA::VisitStmt_(const AttrStmtNode* op) {
auto commit_group = Call(DataType::Void(), builtin::ptx_commit_group(),
{});
this->VisitExpr(commit_group, this->stream);
return;
- } else if (op->attr_key == tir::attr::async_wait_queue_scope) {
+ } else if (op->attr_key == s_tir::attr::async_wait_queue_scope) {
auto wait_attrs = GetAsyncWaitAttributes(op);
auto queue_id = wait_attrs.first.as<IntImmNode>();
TVM_FFI_ICHECK(queue_id && queue_id->value == 0)
diff --git a/src/target/spirv/codegen_spirv.cc
b/src/target/spirv/codegen_spirv.cc
index 114bdbc806..d91367bcf6 100644
--- a/src/target/spirv/codegen_spirv.cc
+++ b/src/target/spirv/codegen_spirv.cc
@@ -23,6 +23,7 @@
*/
#include "codegen_spirv.h"
+#include <tvm/s_tir/stmt.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
@@ -874,10 +875,7 @@ void CodeGenSPIRV::VisitStmt_(const AttrStmtNode* op) {
const VarNode* v = op->node.as<VarNode>();
TVM_FFI_ICHECK(v);
storage_info_[v].is_volatile = true;
- } else if (op->attr_key == tir::attr::buffer_bind_scope) {
- const VarNode* v = op->node.as<VarNode>();
- TVM_FFI_ICHECK(v);
- } else if (op->attr_key == tir::attr::fragment_shape) {
+ } else if (op->attr_key == s_tir::attr::fragment_shape) {
const VarNode* buffer = op->node.as<VarNode>();
const StringImmNode* shape_str = op->value.as<StringImmNode>();
fragment_info_[buffer] = {shape_str->value};
diff --git a/src/tir/analysis/verify_memory.cc
b/src/tir/analysis/verify_memory.cc
index 10642f1c70..19bc55bf64 100644
--- a/src/tir/analysis/verify_memory.cc
+++ b/src/tir/analysis/verify_memory.cc
@@ -79,8 +79,7 @@ class MemoryAccessVerifier final : protected StmtExprVisitor {
}
void VisitStmt_(const AttrStmtNode* op) final {
- if (!InThreadEnv() &&
- (op->attr_key == attr::thread_extent || op->attr_key ==
attr::pipeline_exec_scope)) {
+ if (!InThreadEnv() && op->attr_key == attr::thread_extent) {
EnterThreadEnv();
StmtExprVisitor::VisitStmt_(op);
ExitThreadEnv();
diff --git a/src/tir/analysis/verify_well_formed.cc
b/src/tir/analysis/verify_well_formed.cc
index e400cea4e4..0ff363a547 100644
--- a/src/tir/analysis/verify_well_formed.cc
+++ b/src/tir/analysis/verify_well_formed.cc
@@ -314,7 +314,6 @@ class UndefinedVarVerifier : public
Verifier<UndefinedVarVerifier> {
* - DeclBuffer statement
* - SBlock::alloc_buffers
* - SBlock::match_buffers
- * - AttrStmt with key "buffer_bind_scope"
*
* it must not appear in a BufferLoad, BufferStore, or BufferRegion outside
that declaration's
* scope.
diff --git a/src/tir/ir/data_type_rewriter.cc b/src/tir/ir/data_type_rewriter.cc
index 7e8b2ed2a6..00d8b786aa 100644
--- a/src/tir/ir/data_type_rewriter.cc
+++ b/src/tir/ir/data_type_rewriter.cc
@@ -24,6 +24,7 @@
#include "data_type_rewriter.h"
+#include <tvm/s_tir/stmt.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
@@ -94,7 +95,7 @@ Stmt DataTypeLegalizer::VisitStmt_(const SBlockNode* op) {
}
Stmt DataTypeLegalizer::VisitStmt_(const AttrStmtNode* op) {
- if (op->attr_key == attr::thread_extent || op->attr_key ==
attr::virtual_thread) {
+ if (op->attr_key == attr::thread_extent || op->attr_key ==
s_tir::attr::virtual_thread) {
Stmt s = StmtExprMutator::VisitStmt_(op);
op = s.as<AttrStmtNode>();
TVM_FFI_ICHECK(op != nullptr) << "Expected type to be AttrStmtNode"
@@ -287,7 +288,7 @@ Stmt IndexDataTypeRewriter::VisitStmt_(const AllocateNode*
op) {
}
Stmt IndexDataTypeRewriter::VisitStmt_(const AttrStmtNode* op) {
- if (op->attr_key == attr::thread_extent || op->attr_key ==
attr::virtual_thread) {
+ if (op->attr_key == attr::thread_extent || op->attr_key ==
s_tir::attr::virtual_thread) {
bool is_enabled = is_enabled_;
is_enabled_ = true;
auto stmt = DataTypeLegalizer::VisitStmt_(op);
diff --git a/src/tir/ir/tir_visitor_with_path.cc
b/src/tir/ir/tir_visitor_with_path.cc
index eca926f547..9e10ad8957 100644
--- a/src/tir/ir/tir_visitor_with_path.cc
+++ b/src/tir/ir/tir_visitor_with_path.cc
@@ -24,6 +24,7 @@
#include "tir_visitor_with_path.h"
#include <tvm/ffi/reflection/access_path.h>
+#include <tvm/s_tir/stmt.h>
#include <algorithm>
#include <optional>
@@ -179,31 +180,12 @@ void TIRVisitorWithPath::VisitStmt_(const AttrStmtNode*
op, AccessPath path) {
std::vector<std::variant<DefContext<IterVar>, DefContext<Var>,
DefContext<Buffer>>> context;
if (auto iter_var = op->node.as<IterVar>();
- iter_var && (op->attr_key == attr::thread_extent || op->attr_key ==
attr::virtual_thread)) {
+ iter_var &&
+ (op->attr_key == attr::thread_extent || op->attr_key ==
s_tir::attr::virtual_thread)) {
// Some attributes serve as a source of definition for the
// tir::Var they annotate.
context.push_back(WithDef(iter_var.value(), path->Attr("node")));
- } else if (op->attr_key == attr::buffer_bind_scope) {
- // The `attr::buffer_bind_scope` attribute defines a view into an
- // existing buffer, similar to the newer
- // `BlockNode::match_buffers` field. It requires the buffer being
- // viewed to be defined prior to the attribute. The
- // `attr::buffer_bind_scope` is the point of definition for the
- // `tir::Buffer buffer_view`, its `tir::Var` data pointer, and any
- // symbolic shapes used within `buffer_view that are not already
- // defined.
- ffi::Array<ObjectRef> arr = Downcast<ffi::Array<ObjectRef>>(op->node);
- TVM_FFI_ICHECK_EQ(arr.size(), 2U);
- Buffer buffer_view = Downcast<Buffer>(arr[0]);
- Buffer orig_buffer = Downcast<Buffer>(arr[1]);
- Visit(orig_buffer, path->Attr("node")->ArrayItem(1));
-
- for (auto& var : WithMatchBufferDefs(buffer_view,
path->Attr("node")->ArrayItem(0))) {
- context.push_back(std::move(var));
- }
- context.push_back(WithDef(buffer_view, path->Attr("node")->ArrayItem(0)));
-
} else if (auto expr = op->node.as<PrimExpr>()) {
Visit(expr.value(), path->Attr("node"));
}
diff --git a/src/tir/transform/annotate_device_regions.cc
b/src/tir/transform/annotate_device_regions.cc
index 755adade0c..22ddd07495 100644
--- a/src/tir/transform/annotate_device_regions.cc
+++ b/src/tir/transform/annotate_device_regions.cc
@@ -41,8 +41,7 @@ class DeviceRegionAnnotater : public StmtMutator {
if (op->attr_key == tvm::attr::kTarget) {
// If a target attribute already exists, use it as-is.
return ffi::GetRef<Stmt>(op);
- } else if (op->attr_key == attr::thread_extent || op->attr_key ==
attr::pipeline_exec_scope ||
- op->attr_key == attr::device_scope) {
+ } else if (op->attr_key == attr::thread_extent || op->attr_key ==
attr::device_scope) {
// These attributes are only allowed in device-side code, so
// they should be annotated with the function's default target.
Stmt body = ffi::GetRef<Stmt>(op);
diff --git a/src/tir/transform/bind_target.cc b/src/tir/transform/bind_target.cc
index 8ff15e09ca..eaf13c44bd 100644
--- a/src/tir/transform/bind_target.cc
+++ b/src/tir/transform/bind_target.cc
@@ -36,6 +36,7 @@
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ir/global_var_supply.h>
+#include <tvm/s_tir/stmt.h>
#include <tvm/s_tir/transform.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
@@ -112,7 +113,7 @@ class FunctionClassifierVisitor : public StmtExprVisitor {
}
void VisitStmt_(const AttrStmtNode* op) final {
- if (op->attr_key == attr::thread_extent || op->attr_key ==
attr::virtual_thread) {
+ if (op->attr_key == attr::thread_extent || op->attr_key ==
s_tir::attr::virtual_thread) {
// Enter GPU scope for thread extent and virtual thread attributes
bool last_is_under_gpu_scope = is_under_gpu_scope_;
is_under_gpu_scope_ = true;
@@ -197,7 +198,7 @@ class CallSubstitutor : public StmtExprMutator {
}
Stmt VisitStmt_(const AttrStmtNode* op) final {
- if (op->attr_key == attr::thread_extent || op->attr_key ==
attr::virtual_thread) {
+ if (op->attr_key == attr::thread_extent || op->attr_key ==
s_tir::attr::virtual_thread) {
// Enter GPU scope for thread extent and virtual thread attributes
bool last_is_under_gpu_scope = is_under_gpu_scope_;
is_under_gpu_scope_ = true;
diff --git a/src/tir/transform/ir_utils.cc b/src/tir/transform/ir_utils.cc
index 4e8a590a7d..90d53d4df7 100644
--- a/src/tir/transform/ir_utils.cc
+++ b/src/tir/transform/ir_utils.cc
@@ -728,9 +728,9 @@ void ConditionalBoundsContext::ExitWithScope() {
}
std::pair<PrimExpr, PrimExpr> GetAsyncWaitAttributes(const AttrStmtNode* op) {
- TVM_FFI_ICHECK(op && op->attr_key == tir::attr::async_wait_queue_scope);
+ TVM_FFI_ICHECK(op && op->attr_key == s_tir::attr::async_wait_queue_scope);
auto inner = op->body.as<AttrStmtNode>();
- TVM_FFI_ICHECK(inner && inner->attr_key ==
tir::attr::async_wait_inflight_count);
+ TVM_FFI_ICHECK(inner && inner->attr_key ==
s_tir::attr::async_wait_inflight_count);
return std::make_pair(op->value, inner->value);
}
diff --git a/src/tir/transform/ir_utils.h b/src/tir/transform/ir_utils.h
index 51e2293321..8077ebdea8 100644
--- a/src/tir/transform/ir_utils.h
+++ b/src/tir/transform/ir_utils.h
@@ -28,6 +28,7 @@
#include <tvm/arith/int_solver.h>
#include <tvm/ffi/container/tuple.h>
#include <tvm/runtime/device_api.h>
+#include <tvm/s_tir/stmt.h>
#include <tvm/support/with.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
@@ -311,7 +312,7 @@ struct FragmentInfo {
std::unordered_map<const VarNode*, FragmentInfo>
GetTensorCoreFragmentInfo(const Stmt& stmt);
// Return the queue id and the in-flight count associated with the given
-// attr::async_wait_queue_scope annotation.
+// s_tir::attr::async_wait_queue_scope annotation.
std::pair<PrimExpr, PrimExpr> GetAsyncWaitAttributes(const AttrStmtNode* op);
/*! \brief The quad used by StorageAlign for (buffer_idx, axis, factor,
offset) */
diff --git a/src/tir/transform/lower_warp_memory.cc
b/src/tir/transform/lower_warp_memory.cc
index 073759967b..d138b1efae 100644
--- a/src/tir/transform/lower_warp_memory.cc
+++ b/src/tir/transform/lower_warp_memory.cc
@@ -29,6 +29,7 @@
#include <tvm/arith/pattern.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
+#include <tvm/s_tir/stmt.h>
#include <tvm/target/target.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
@@ -395,7 +396,7 @@ class BindVarBoundInfo : public StmtVisitor {
}
void VisitStmt_(const AttrStmtNode* op) {
- if (op->attr_key == attr::thread_extent || op->attr_key ==
attr::virtual_thread) {
+ if (op->attr_key == attr::thread_extent || op->attr_key ==
s_tir::attr::virtual_thread) {
IterVar iv = Downcast<IterVar>(op->node);
TVM_FFI_ICHECK_NE(iv->thread_tag.length(), 0U);
if (!var_dom_.count(iv->var.get())) {
diff --git a/src/tir/transform/narrow_datatype.cc
b/src/tir/transform/narrow_datatype.cc
index 5ecd46d7be..0ca492c003 100644
--- a/src/tir/transform/narrow_datatype.cc
+++ b/src/tir/transform/narrow_datatype.cc
@@ -24,6 +24,7 @@
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
+#include <tvm/s_tir/stmt.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/transform.h>
@@ -120,7 +121,7 @@ class DataTypeVisitor final : public StmtExprVisitor {
}
void VisitStmt_(const AttrStmtNode* op) {
- if (op->attr_key == attr::thread_extent || op->attr_key ==
attr::virtual_thread) {
+ if (op->attr_key == attr::thread_extent || op->attr_key ==
s_tir::attr::virtual_thread) {
IterVar iv = Downcast<IterVar>(op->node);
TVM_FFI_ICHECK_NE(iv->thread_tag.length(), 0U);
analyzer_.Bind(iv->var, Range::FromMinExtent(0, op->value));
diff --git a/src/tir/transform/remove_no_op.cc
b/src/tir/transform/remove_no_op.cc
index 010d189d89..f75df8ef04 100644
--- a/src/tir/transform/remove_no_op.cc
+++ b/src/tir/transform/remove_no_op.cc
@@ -24,6 +24,7 @@
#include <tvm/arith/analyzer.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
+#include <tvm/s_tir/stmt.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt.h>
@@ -110,7 +111,7 @@ class NoOpRemover : public arith::IRMutatorWithAnalyzer {
Stmt VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == "pragma_debug_skip_region") {
return MakeEvaluate(0);
- } else if (op->attr_key == attr::async_wait_queue_scope) {
+ } else if (op->attr_key == s_tir::attr::async_wait_queue_scope) {
auto wait_attrs = GetAsyncWaitAttributes(op);
auto wait_cnt = wait_attrs.second;
arith::Analyzer ana;
diff --git a/src/tir/transform/storage_rewrite.cc
b/src/tir/transform/storage_rewrite.cc
index 48f9cb67a0..9f2f3ce61d 100644
--- a/src/tir/transform/storage_rewrite.cc
+++ b/src/tir/transform/storage_rewrite.cc
@@ -26,6 +26,7 @@
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ir/type.h>
+#include <tvm/s_tir/stmt.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
@@ -200,7 +201,7 @@ class LinearAccessPatternFinder final : public
StmtExprVisitor {
in_thread_env_ = false;
} else if (op->attr_key == attr::extern_scope) {
VisitNewScope(op);
- } else if (op->attr_key == attr::virtual_thread) {
+ } else if (op->attr_key == s_tir::attr::virtual_thread) {
VisitNewScope(op);
} else {
StmtExprVisitor::VisitStmt_(op);
@@ -481,7 +482,7 @@ class StoragePlanRewriter : public StmtExprMutator {
}
Stmt VisitStmt_(const AttrStmtNode* op) final {
- if (op->attr_key == attr::thread_extent || op->attr_key ==
attr::virtual_thread ||
+ if (op->attr_key == attr::thread_extent || op->attr_key ==
s_tir::attr::virtual_thread ||
attr::IsPragmaKey(op->attr_key)) {
// remake all the allocation at the attach scope.
if (attach_map_.count(op)) {
@@ -855,7 +856,7 @@ class StoragePlanRewriter : public StmtExprMutator {
// enter/exit new scope
if (s.stmt->IsInstance<AttrStmtNode>()) {
const auto* op = static_cast<const AttrStmtNode*>(s.stmt);
- if (op->attr_key == attr::thread_extent || op->attr_key ==
attr::virtual_thread ||
+ if (op->attr_key == attr::thread_extent || op->attr_key ==
s_tir::attr::virtual_thread ||
attr::IsPragmaKey(op->attr_key)) {
PlanNewScope(op);
} else {
diff --git a/tests/python/codegen/test_target_codegen_static_init.py
b/tests/python/codegen/test_target_codegen_static_init.py
index dee6a88d61..5dbe5d1315 100644
--- a/tests/python/codegen/test_target_codegen_static_init.py
+++ b/tests/python/codegen/test_target_codegen_static_init.py
@@ -23,31 +23,6 @@ from tvm.script import ir as I
from tvm.script import tir as T
-def test_static_callback():
- @I.ir_module
- class Module:
- @T.prim_func
- def ramp(A: T.handle):
- T.func_attr({"global_symbol": "ramp"})
- n = T.int64()
- Ab = T.match_buffer(A, (n,), "int64")
- # coproc_uop_scope with TVMBackendRunOnce ensures body runs only
once
- with T.attr(
- T.iter_var(T.int32(), (0, 1), "DataPar", "cop"),
- "coproc_uop_scope",
- "TVMBackendRunOnce",
- ):
- for i in T.parallel(n):
- Ab[i] = Ab[i] + T.int64(1)
-
- mod = Module
- f = tvm.driver.build(mod, target="llvm")
- a = tvm.runtime.tensor(np.zeros(10, dtype="int64"))
- f(a)
- f(a)
- np.testing.assert_equal(a.numpy(), np.ones(a.shape[0]))
-
-
def test_static_init():
@tvm.register_global_func("test_static_callback")
def test_cb(sh, A):
@@ -74,5 +49,4 @@ def test_static_init():
if __name__ == "__main__":
- test_static_callback()
test_static_init()
diff --git a/tests/python/tir-analysis/test_tir_analysis_verify_well_formed.py
b/tests/python/tir-analysis/test_tir_analysis_verify_well_formed.py
index 519fa3794b..fe87c775aa 100644
--- a/tests/python/tir-analysis/test_tir_analysis_verify_well_formed.py
+++ b/tests/python/tir-analysis/test_tir_analysis_verify_well_formed.py
@@ -221,59 +221,6 @@ def test_multiple_buffer_arguments_may_share_allocation():
tvm.tir.analysis.verify_well_formed(mod)
-def test_buffer_bind_scope_defines_buffer_obj():
- """The "buffer_bind_scope" attribute defines a buffer view"""
-
- @I.ir_module
- class mod:
- @T.prim_func
- def func(A: T.Buffer([256, 256], "float32")):
- for tile_i, tile_j in T.grid(16, 16):
- B = T.Buffer([16, 16], "float32")
- T.attr(
- [B, A],
- "buffer_bind_scope",
- T.tvm_tuple(
- tile_i * 16,
- 16,
- tile_j * 16,
- 16,
- dtype="handle",
- ),
- )
- for i, j in T.grid(16, 16):
- B[i, j] = 0.0
-
- tvm.tir.analysis.verify_well_formed(mod)
-
-
-def test_buffer_bind_scope_defines_symbolic_variables():
- """The "buffer_bind_scope" attribute may define symbolic variables"""
-
- @I.ir_module
- class mod:
- @T.prim_func
- def func(A: T.Buffer([256, 256], "int32")):
- for tile_i, tile_j in T.grid(16, 16):
- elem_offset = T.int32()
- B = T.Buffer([16, 16], "int32", elem_offset=elem_offset)
- T.attr(
- [B, A],
- "buffer_bind_scope",
- T.tvm_tuple(
- tile_i * 16,
- 16,
- tile_j * 16,
- 16,
- dtype="handle",
- ),
- )
- for i, j in T.grid(16, 16):
- B[i, j] = elem_offset
-
- tvm.tir.analysis.verify_well_formed(mod)
-
-
def test_block_match_buffer_defines_buffer_obj():
"""In a block, T.match_buffer defines a buffer view"""