vinx13 commented on code in PR #12412:
URL: https://github.com/apache/tvm/pull/12412#discussion_r958892025


##########
tests/python/unittest/test_tir_transform_unroll_loop.py:
##########
@@ -117,16 +117,19 @@ class before:
         @T.prim_func
         def main():
             for i in T.unroll(2):
-                with T.allocate([16], "float32", "global") as buf:
+                with T.allocate([16], "float32", "global") as buf_data:
+                    buf = T.buffer_decl(shape=[16], dtype="float32", 
scope="global", data=buf_data)
                     buf[0] = 0.0
 
     @tvm.script.ir_module
     class expected:
         @T.prim_func
         def main():
-            with T.allocate([16], "float32", "global") as buf1:
+            with T.allocate([16], "float32", "global") as buf1_data:

Review Comment:
   Yes. The reason I kept `T.allocate` here is that the pass also needs some 
updates before we can use `T.decl_buffer`. (The scope of this PR is only to 
update existing TVM scripts, without touching related passes, to minimize changes.)



##########
src/printer/tvmscript_printer.cc:
##########
@@ -1028,59 +1039,52 @@ Doc TVMScriptPrinter::VisitStmt_(const 
BufferRealizeNode* op) {
   return Doc();
 }
 
-namespace {
-struct AllocUsage {
-  Buffer alloc_buffer;
-  Array<Buffer> aliasing_buffers;
-};
-
-template <typename AllocNode>
-AllocUsage FindAllocateUsage(AllocNode* op, Map<Var, Array<Buffer>>* 
cache_ptr) {
-  Map<Var, Array<Buffer>>& cache = *cache_ptr;
-  if (!cache.count(op->buffer_var)) {
-    cache = BufferUsageFinder::FindUsage(std::move(cache), op->body);
+bool IsAllocateDeclBufferPattern(const AllocateNode* allocate, const 
DeclBufferNode* decl_buffer) {
+  const Var& buffer_var = allocate->buffer_var;
+  const Buffer& buffer = decl_buffer->buffer;
+  if (!buffer_var.same_as(buffer->data)) {
+    return false;
   }
-  Array<Buffer> buffer_usage = cache.Get(op->buffer_var).value_or({});
-
-  auto is_exact_match = [](Buffer a, Buffer b) {
-    if (a->dtype != b->dtype) return false;
-    if (a->shape.size() != b->shape.size()) return false;
-
-    arith::Analyzer analyzer;
-    for (size_t i = 0; i < a->shape.size(); i++) {
-      if (!analyzer.CanProveEqual(a->shape[i], b->shape[i])) {
-        return false;
-      }
-    }
-    return true;
-  };
-
-  // If the buffer allocated via T.allocate is an exact match to the
-  // usage of the buffer later on, then that buffer is the return
-  // value of T.allocate, and no T.buffer_decl statement is needed.
-  Buffer alloc_buffer(op->buffer_var, op->dtype, op->extents, {}, 0, 
op->buffer_var->name_hint, 0,
-                      0, kDefault);
-  bool found_alloc_buf = false;
-  Array<Buffer> aliasing_buffers;
-  for (const auto& buf : buffer_usage) {
-    if (!found_alloc_buf && is_exact_match(buf, alloc_buffer)) {
-      alloc_buffer = buf;
-      found_alloc_buf = true;
-    } else {
-      aliasing_buffers.push_back(buf);
+  if (allocate->dtype != buffer->dtype) {
+    return false;
+  }
+  if (!is_one(allocate->condition)) {
+    return false;
+  }
+  if (allocate->annotations.size()) {
+    return false;
+  }
+  if (allocate->extents.size() != buffer->shape.size()) {
+    return false;
+  }
+  tir::ExprDeepEqual expr_equal;
+  for (size_t i = 0, n = allocate->extents.size(); i < n; ++i) {
+    if (!expr_equal(allocate->extents[i], buffer->shape[i])) {
+      return false;
     }
   }
-
-  return AllocUsage{alloc_buffer, aliasing_buffers};
+  return true;
 }
-}  // namespace
 
 Doc TVMScriptPrinter::VisitStmt_(const AllocateNode* op) {
-  auto usage = FindAllocateUsage(op, &buffer_var_usage_);
-  Buffer& alloc_buffer = usage.alloc_buffer;
-  Array<Buffer>& aliasing_buffers = usage.aliasing_buffers;
-  buf_not_in_headers_.insert(alloc_buffer.get());
-  var_not_in_headers_.insert(alloc_buffer->data.get());
+  var_not_in_headers_.insert(op->buffer_var.get());
+
+  if (!buffer_var_usage_.count(op->buffer_var)) {
+    buffer_var_usage_ = 
BufferUsageFinder::FindUsage(std::move(buffer_var_usage_), op->body);
+  }
+  Array<Buffer> buffer_usage = 
buffer_var_usage_.Get(op->buffer_var).value_or({});
+
+  if (buffer_usage.empty()) {
+    if (const DeclBufferNode* decl_buffer = op->body.as<DeclBufferNode>()) {
+      if (IsAllocateDeclBufferPattern(op, decl_buffer)) {
+        // As a syntax sugar, we identify the pattern of Allocate and 
DeclBuffer and print a single
+        // DeclBuffer statement. It is intentionally to call `Print` instead 
of `PrintBody` here to
+        // delegate the printing of the current node to `DeclBufferNode` while 
maintaining the
+        // same value of `current_num_` and `num_child_`.
+        return Print(op->body);

Review Comment:
   That's correct. That is why I check `buffer_usage.empty()` above: it ensures 
that no `T.buffer_decl` is needed.



##########
src/printer/tvmscript_printer.cc:
##########
@@ -1028,59 +1039,52 @@ Doc TVMScriptPrinter::VisitStmt_(const 
BufferRealizeNode* op) {
   return Doc();
 }
 
-namespace {
-struct AllocUsage {
-  Buffer alloc_buffer;
-  Array<Buffer> aliasing_buffers;
-};
-
-template <typename AllocNode>
-AllocUsage FindAllocateUsage(AllocNode* op, Map<Var, Array<Buffer>>* 
cache_ptr) {
-  Map<Var, Array<Buffer>>& cache = *cache_ptr;
-  if (!cache.count(op->buffer_var)) {
-    cache = BufferUsageFinder::FindUsage(std::move(cache), op->body);
+bool IsAllocateDeclBufferPattern(const AllocateNode* allocate, const 
DeclBufferNode* decl_buffer) {
+  const Var& buffer_var = allocate->buffer_var;
+  const Buffer& buffer = decl_buffer->buffer;
+  if (!buffer_var.same_as(buffer->data)) {
+    return false;
   }
-  Array<Buffer> buffer_usage = cache.Get(op->buffer_var).value_or({});
-
-  auto is_exact_match = [](Buffer a, Buffer b) {
-    if (a->dtype != b->dtype) return false;
-    if (a->shape.size() != b->shape.size()) return false;
-
-    arith::Analyzer analyzer;
-    for (size_t i = 0; i < a->shape.size(); i++) {
-      if (!analyzer.CanProveEqual(a->shape[i], b->shape[i])) {
-        return false;
-      }
-    }
-    return true;
-  };
-
-  // If the buffer allocated via T.allocate is an exact match to the
-  // usage of the buffer later on, then that buffer is the return
-  // value of T.allocate, and no T.buffer_decl statement is needed.
-  Buffer alloc_buffer(op->buffer_var, op->dtype, op->extents, {}, 0, 
op->buffer_var->name_hint, 0,
-                      0, kDefault);
-  bool found_alloc_buf = false;
-  Array<Buffer> aliasing_buffers;
-  for (const auto& buf : buffer_usage) {
-    if (!found_alloc_buf && is_exact_match(buf, alloc_buffer)) {
-      alloc_buffer = buf;
-      found_alloc_buf = true;
-    } else {
-      aliasing_buffers.push_back(buf);
+  if (allocate->dtype != buffer->dtype) {
+    return false;
+  }
+  if (!is_one(allocate->condition)) {
+    return false;
+  }
+  if (allocate->annotations.size()) {
+    return false;
+  }
+  if (allocate->extents.size() != buffer->shape.size()) {
+    return false;
+  }
+  tir::ExprDeepEqual expr_equal;
+  for (size_t i = 0, n = allocate->extents.size(); i < n; ++i) {
+    if (!expr_equal(allocate->extents[i], buffer->shape[i])) {
+      return false;
     }
   }
-
-  return AllocUsage{alloc_buffer, aliasing_buffers};
+  return true;
 }
-}  // namespace
 
 Doc TVMScriptPrinter::VisitStmt_(const AllocateNode* op) {
-  auto usage = FindAllocateUsage(op, &buffer_var_usage_);
-  Buffer& alloc_buffer = usage.alloc_buffer;
-  Array<Buffer>& aliasing_buffers = usage.aliasing_buffers;
-  buf_not_in_headers_.insert(alloc_buffer.get());
-  var_not_in_headers_.insert(alloc_buffer->data.get());
+  var_not_in_headers_.insert(op->buffer_var.get());
+
+  if (!buffer_var_usage_.count(op->buffer_var)) {
+    buffer_var_usage_ = 
BufferUsageFinder::FindUsage(std::move(buffer_var_usage_), op->body);
+  }
+  Array<Buffer> buffer_usage = 
buffer_var_usage_.Get(op->buffer_var).value_or({});
+
+  if (buffer_usage.empty()) {

Review Comment:
   Usage in `DeclBuffer` is excluded from the result of `BufferUsageFinder`.



##########
src/printer/tvmscript_printer.cc:
##########
@@ -1028,59 +1039,52 @@ Doc TVMScriptPrinter::VisitStmt_(const 
BufferRealizeNode* op) {
   return Doc();
 }
 
-namespace {
-struct AllocUsage {
-  Buffer alloc_buffer;
-  Array<Buffer> aliasing_buffers;
-};
-
-template <typename AllocNode>
-AllocUsage FindAllocateUsage(AllocNode* op, Map<Var, Array<Buffer>>* 
cache_ptr) {
-  Map<Var, Array<Buffer>>& cache = *cache_ptr;
-  if (!cache.count(op->buffer_var)) {
-    cache = BufferUsageFinder::FindUsage(std::move(cache), op->body);
+bool IsAllocateDeclBufferPattern(const AllocateNode* allocate, const 
DeclBufferNode* decl_buffer) {

Review Comment:
   Nice catch! Indeed, it's clearer without the `decl_buffer` argument.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to