wrongtest-intellif commented on code in PR #17423:
URL: https://github.com/apache/tvm/pull/17423#discussion_r1778229103
##########
src/tir/schedule/primitive/annotate_buffer_access.cc:
##########
@@ -0,0 +1,155 @@
+#include "../utils.h"
+
+namespace tvm {
+namespace tir {
+
+class AnnotateRegionRewriter : public StmtExprMutator {
+ public:
+ AnnotateRegionRewriter(Buffer buffer, int buffer_index, BufferRegion
new_region,
+ BufferIndexType buffer_index_type)
+ : buffer_(buffer),
+ buffer_index_(buffer_index),
+ new_region_(new_region),
+ buffer_index_type_(buffer_index_type) {}
+
+ Stmt VisitStmt_(const BlockNode* op) final {
+ Block block = Downcast<Block>(StmtExprMutator::VisitStmt_(op));
+
+ Array<BufferRegion> regions =
+ buffer_index_type_ == BufferIndexType::kWrite ? block->writes :
block->reads;
+ ICHECK_GE(buffer_index_, 0) << "Buffer index must be non-negative";
+ ICHECK_LT(buffer_index_, static_cast<int>(regions.size())) << "Buffer
index out of range";
+ regions.Set(buffer_index_, new_region_);
+
+ ObjectPtr<BlockNode> n = CopyOnWrite(block.get());
+ if (buffer_index_type_ == BufferIndexType::kWrite) {
+ n->writes = std::move(regions);
+ } else {
+ n->reads = std::move(regions);
+ }
+
+ // Annotate the block with explicit_read_region or explicit_write_region
+ Map<String, ObjectRef> new_annotations = n->annotations;
+ String annotation_key = buffer_index_type_ == BufferIndexType::kWrite
+ ? attr::explicit_write_region
+ : attr::explicit_read_region;
+ if (new_annotations.count(annotation_key)) {
+ Array<Integer> buffer_indices =
Downcast<Array<Integer>>(new_annotations[annotation_key]);
+ bool found = false;
+ for (const Integer& index : buffer_indices) {
+ if (index->value == buffer_index_) {
+ found = true;
+ break;
+ }
+ }
+ if (!found) {
+ buffer_indices.push_back(Integer(buffer_index_));
+ new_annotations.Set(annotation_key, buffer_indices);
+ }
+ } else {
+ new_annotations.Set(annotation_key,
Array<Integer>{Integer(buffer_index_)});
+ }
+ n->annotations = std::move(new_annotations);
+
+ return Block(n);
+ }
+
+ private:
+ Buffer buffer_;
+ int buffer_index_;
+ BufferRegion new_region_;
+ BufferIndexType buffer_index_type_;
+};
+
+void AnnotateBufferAccess(ScheduleState self, const StmtSRef& block_sref, int
buffer_index,
+ BufferIndexType buffer_index_type, const IndexMap&
index_map) {
+ const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
+ Buffer buffer = GetNthAccessBuffer(self, GetRef<Block>(block), buffer_index,
buffer_index_type);
+
+ arith::Analyzer analyzer;
+ Array<PrimExpr> block_iter_vars;
+ for (const IterVar& iter_var : block->iter_vars) {
+ block_iter_vars.push_back(iter_var->var);
+ }
+ Array<PrimExpr> new_indices = index_map->MapIndices(block_iter_vars,
&analyzer);
+ ICHECK_EQ(new_indices.size() % 2, 0) << "The size of new_indices should be
even.";
+ Array<Range> new_ranges;
+ for (size_t i = 0; i < new_indices.size(); i += 2) {
+ // (iter_var, iter_var) represents a single point
+ if (analyzer.CanProveEqual(new_indices[i], new_indices[i + 1])) {
+ new_ranges.push_back(Range::FromMinExtent(new_indices[i], 1));
+ }
+ // (begin, end) represents a region
+ else {
+ new_ranges.push_back(Range::FromMinExtent(
+ new_indices[i], analyzer.Simplify(new_indices[i + 1] -
new_indices[i])));
Review Comment:
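Should the extent be `(end - begin + 1)`? The branch above treats `(v, v)` as a single point, which only makes sense if `end` is inclusive. This branch instead uses `end - begin` as the extent, which treats `end` as exclusive. The two branches should follow the same convention.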
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]