vinx13 commented on code in PR #12412:
URL: https://github.com/apache/tvm/pull/12412#discussion_r958892025
##########
tests/python/unittest/test_tir_transform_unroll_loop.py:
##########
@@ -117,16 +117,19 @@ class before:
@T.prim_func
def main():
for i in T.unroll(2):
- with T.allocate([16], "float32", "global") as buf:
+ with T.allocate([16], "float32", "global") as buf_data:
+ buf = T.buffer_decl(shape=[16], dtype="float32",
scope="global", data=buf_data)
buf[0] = 0.0
@tvm.script.ir_module
class expected:
@T.prim_func
def main():
- with T.allocate([16], "float32", "global") as buf1:
+ with T.allocate([16], "float32", "global") as buf1_data:
Review Comment:
Yes. The reason I kept `T.allocate` here is that the pass also needs some
updates before we can use `T.decl_buffer`. (The scope of this PR is only to
update existing TVM scripts without touching the related passes, to minimize changes.)
##########
src/printer/tvmscript_printer.cc:
##########
@@ -1028,59 +1039,52 @@ Doc TVMScriptPrinter::VisitStmt_(const
BufferRealizeNode* op) {
return Doc();
}
-namespace {
-struct AllocUsage {
- Buffer alloc_buffer;
- Array<Buffer> aliasing_buffers;
-};
-
-template <typename AllocNode>
-AllocUsage FindAllocateUsage(AllocNode* op, Map<Var, Array<Buffer>>*
cache_ptr) {
- Map<Var, Array<Buffer>>& cache = *cache_ptr;
- if (!cache.count(op->buffer_var)) {
- cache = BufferUsageFinder::FindUsage(std::move(cache), op->body);
+bool IsAllocateDeclBufferPattern(const AllocateNode* allocate, const
DeclBufferNode* decl_buffer) {
+ const Var& buffer_var = allocate->buffer_var;
+ const Buffer& buffer = decl_buffer->buffer;
+ if (!buffer_var.same_as(buffer->data)) {
+ return false;
}
- Array<Buffer> buffer_usage = cache.Get(op->buffer_var).value_or({});
-
- auto is_exact_match = [](Buffer a, Buffer b) {
- if (a->dtype != b->dtype) return false;
- if (a->shape.size() != b->shape.size()) return false;
-
- arith::Analyzer analyzer;
- for (size_t i = 0; i < a->shape.size(); i++) {
- if (!analyzer.CanProveEqual(a->shape[i], b->shape[i])) {
- return false;
- }
- }
- return true;
- };
-
- // If the buffer allocated via T.allocate is an exact match to the
- // usage of the buffer later on, then that buffer is the return
- // value of T.allocate, and no T.buffer_decl statement is needed.
- Buffer alloc_buffer(op->buffer_var, op->dtype, op->extents, {}, 0,
op->buffer_var->name_hint, 0,
- 0, kDefault);
- bool found_alloc_buf = false;
- Array<Buffer> aliasing_buffers;
- for (const auto& buf : buffer_usage) {
- if (!found_alloc_buf && is_exact_match(buf, alloc_buffer)) {
- alloc_buffer = buf;
- found_alloc_buf = true;
- } else {
- aliasing_buffers.push_back(buf);
+ if (allocate->dtype != buffer->dtype) {
+ return false;
+ }
+ if (!is_one(allocate->condition)) {
+ return false;
+ }
+ if (allocate->annotations.size()) {
+ return false;
+ }
+ if (allocate->extents.size() != buffer->shape.size()) {
+ return false;
+ }
+ tir::ExprDeepEqual expr_equal;
+ for (size_t i = 0, n = allocate->extents.size(); i < n; ++i) {
+ if (!expr_equal(allocate->extents[i], buffer->shape[i])) {
+ return false;
}
}
-
- return AllocUsage{alloc_buffer, aliasing_buffers};
+ return true;
}
-} // namespace
Doc TVMScriptPrinter::VisitStmt_(const AllocateNode* op) {
- auto usage = FindAllocateUsage(op, &buffer_var_usage_);
- Buffer& alloc_buffer = usage.alloc_buffer;
- Array<Buffer>& aliasing_buffers = usage.aliasing_buffers;
- buf_not_in_headers_.insert(alloc_buffer.get());
- var_not_in_headers_.insert(alloc_buffer->data.get());
+ var_not_in_headers_.insert(op->buffer_var.get());
+
+ if (!buffer_var_usage_.count(op->buffer_var)) {
+ buffer_var_usage_ =
BufferUsageFinder::FindUsage(std::move(buffer_var_usage_), op->body);
+ }
+ Array<Buffer> buffer_usage =
buffer_var_usage_.Get(op->buffer_var).value_or({});
+
+ if (buffer_usage.empty()) {
+ if (const DeclBufferNode* decl_buffer = op->body.as<DeclBufferNode>()) {
+ if (IsAllocateDeclBufferPattern(op, decl_buffer)) {
+ // As a syntax sugar, we identify the pattern of Allocate and
DeclBuffer and print a single
+ // DeclBuffer statement. It is intentionally to call `Print` instead
of `PrintBody` here to
+ // delegate the printing of the current node to `DeclBufferNode` while
maintaining the
+ // same value of `current_num_` and `num_child_`.
+ return Print(op->body);
Review Comment:
That's correct, which is why I check `buffer_usage.empty()` above: it makes
sure that `T.buffer_decl` is not needed.
##########
src/printer/tvmscript_printer.cc:
##########
@@ -1028,59 +1039,52 @@ Doc TVMScriptPrinter::VisitStmt_(const
BufferRealizeNode* op) {
return Doc();
}
-namespace {
-struct AllocUsage {
- Buffer alloc_buffer;
- Array<Buffer> aliasing_buffers;
-};
-
-template <typename AllocNode>
-AllocUsage FindAllocateUsage(AllocNode* op, Map<Var, Array<Buffer>>*
cache_ptr) {
- Map<Var, Array<Buffer>>& cache = *cache_ptr;
- if (!cache.count(op->buffer_var)) {
- cache = BufferUsageFinder::FindUsage(std::move(cache), op->body);
+bool IsAllocateDeclBufferPattern(const AllocateNode* allocate, const
DeclBufferNode* decl_buffer) {
+ const Var& buffer_var = allocate->buffer_var;
+ const Buffer& buffer = decl_buffer->buffer;
+ if (!buffer_var.same_as(buffer->data)) {
+ return false;
}
- Array<Buffer> buffer_usage = cache.Get(op->buffer_var).value_or({});
-
- auto is_exact_match = [](Buffer a, Buffer b) {
- if (a->dtype != b->dtype) return false;
- if (a->shape.size() != b->shape.size()) return false;
-
- arith::Analyzer analyzer;
- for (size_t i = 0; i < a->shape.size(); i++) {
- if (!analyzer.CanProveEqual(a->shape[i], b->shape[i])) {
- return false;
- }
- }
- return true;
- };
-
- // If the buffer allocated via T.allocate is an exact match to the
- // usage of the buffer later on, then that buffer is the return
- // value of T.allocate, and no T.buffer_decl statement is needed.
- Buffer alloc_buffer(op->buffer_var, op->dtype, op->extents, {}, 0,
op->buffer_var->name_hint, 0,
- 0, kDefault);
- bool found_alloc_buf = false;
- Array<Buffer> aliasing_buffers;
- for (const auto& buf : buffer_usage) {
- if (!found_alloc_buf && is_exact_match(buf, alloc_buffer)) {
- alloc_buffer = buf;
- found_alloc_buf = true;
- } else {
- aliasing_buffers.push_back(buf);
+ if (allocate->dtype != buffer->dtype) {
+ return false;
+ }
+ if (!is_one(allocate->condition)) {
+ return false;
+ }
+ if (allocate->annotations.size()) {
+ return false;
+ }
+ if (allocate->extents.size() != buffer->shape.size()) {
+ return false;
+ }
+ tir::ExprDeepEqual expr_equal;
+ for (size_t i = 0, n = allocate->extents.size(); i < n; ++i) {
+ if (!expr_equal(allocate->extents[i], buffer->shape[i])) {
+ return false;
}
}
-
- return AllocUsage{alloc_buffer, aliasing_buffers};
+ return true;
}
-} // namespace
Doc TVMScriptPrinter::VisitStmt_(const AllocateNode* op) {
- auto usage = FindAllocateUsage(op, &buffer_var_usage_);
- Buffer& alloc_buffer = usage.alloc_buffer;
- Array<Buffer>& aliasing_buffers = usage.aliasing_buffers;
- buf_not_in_headers_.insert(alloc_buffer.get());
- var_not_in_headers_.insert(alloc_buffer->data.get());
+ var_not_in_headers_.insert(op->buffer_var.get());
+
+ if (!buffer_var_usage_.count(op->buffer_var)) {
+ buffer_var_usage_ =
BufferUsageFinder::FindUsage(std::move(buffer_var_usage_), op->body);
+ }
+ Array<Buffer> buffer_usage =
buffer_var_usage_.Get(op->buffer_var).value_or({});
+
+ if (buffer_usage.empty()) {
Review Comment:
Usage inside `DeclBuffer` nodes is excluded from the result of `BufferUsageFinder`.
##########
src/printer/tvmscript_printer.cc:
##########
@@ -1028,59 +1039,52 @@ Doc TVMScriptPrinter::VisitStmt_(const
BufferRealizeNode* op) {
return Doc();
}
-namespace {
-struct AllocUsage {
- Buffer alloc_buffer;
- Array<Buffer> aliasing_buffers;
-};
-
-template <typename AllocNode>
-AllocUsage FindAllocateUsage(AllocNode* op, Map<Var, Array<Buffer>>*
cache_ptr) {
- Map<Var, Array<Buffer>>& cache = *cache_ptr;
- if (!cache.count(op->buffer_var)) {
- cache = BufferUsageFinder::FindUsage(std::move(cache), op->body);
+bool IsAllocateDeclBufferPattern(const AllocateNode* allocate, const
DeclBufferNode* decl_buffer) {
Review Comment:
Nice catch! Indeed, it's clearer without the `decl_buffer` argument.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]