Lunderberg commented on code in PR #12412:
URL: https://github.com/apache/tvm/pull/12412#discussion_r958848078


##########
src/printer/tvmscript_printer.cc:
##########
@@ -100,13 +100,21 @@ class BufferUsageFinder : public StmtExprVisitor {
     StmtExprVisitor::VisitStmt_(op);
   }
 
+  void VisitStmt_(const DeclBufferNode* op) final {
+    buffers_declared_.insert(op->buffer.get());

Review Comment:
   Should we also track which buffers have gone out of scope?  If I'm 
understanding it correctly, a single `DeclBufferNode` would also allow for 
usage outside of the `DeclBufferNode::body`, where I'd expect it to only apply 
within the scope of the node.



##########
tests/python/contrib/test_ethosu/test_hoist_allocates.py:
##########
@@ -242,7 +266,8 @@ def main(placeholder: T.Buffer[(8192,), "int8"], 
ethosu_write: T.Buffer[(2048,),
             T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 
0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 
16, 16, 2, 16, 0, 16, ethosu_write[2], 0, 0, 0, T.float32(0.25), 14, "NHWC", 
128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global_1[0], 112, 12, 
placeholder_d_global_1[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 
0, dtype="handle"))
             T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_4[0], 112, 
placeholder_global_2[0], dtype="handle"))
             T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_5[0], 32, 
placeholder_d_global_2[0], dtype="handle"))
-            placeholder_d_global_3 = T.allocate([32], "uint8", "global")
+            placeholder_d_global_3_data = T.allocate([32], "uint8", "global")

Review Comment:
   Same question here, whether we can use a single `T.decl_buffer` call.



##########
src/printer/tvmscript_printer.cc:
##########
@@ -1028,59 +1039,52 @@ Doc TVMScriptPrinter::VisitStmt_(const 
BufferRealizeNode* op) {
   return Doc();
 }
 
-namespace {
-struct AllocUsage {
-  Buffer alloc_buffer;
-  Array<Buffer> aliasing_buffers;
-};
-
-template <typename AllocNode>
-AllocUsage FindAllocateUsage(AllocNode* op, Map<Var, Array<Buffer>>* 
cache_ptr) {
-  Map<Var, Array<Buffer>>& cache = *cache_ptr;
-  if (!cache.count(op->buffer_var)) {
-    cache = BufferUsageFinder::FindUsage(std::move(cache), op->body);
+bool IsAllocateDeclBufferPattern(const AllocateNode* allocate, const 
DeclBufferNode* decl_buffer) {

Review Comment:
   Why is the `decl_buffer` argument needed?  It looks like this pattern only 
applies when the `DeclBufferNode` is the immediate child of `AllocateNode`, so 
we could pull that part of the check into this function.  I'm thinking 
something like the following:
   
   ```c++
   bool IsAllocateDeclBufferPattern(const AllocateNode* allocate) {
       const DeclBufferNode* decl_buffer = allocate->body.as<DeclBufferNode>();
       if(!decl_buffer) {
           return false;
       }
       // Continue as normal from here.
   }
   ```
   
   (P.S. I really like this as a way to provide the cleaner TVMScript syntax 
without immediately requiring additional TIR changes, and I'm glad that it gets 
rid of the `FindAllocateUsage`.  That felt like a hack as I was putting it in.)



##########
tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py:
##########
@@ -56,8 +56,8 @@ def main(placeholder_6: T.Buffer[(192,), "int8"], 
ethosu_conv2d_1: T.Buffer[(512
         placeholder_8 = T.buffer_decl([1], "uint8")
         placeholder_5 = T.buffer_decl([1], "uint8")
         # body
-        ethosu_conv2d_2 = T.allocate([1024], "uint8", "global")
-        ethosu_conv2d_3 = T.allocate([2048], "uint8", "global")
+        ethosu_conv2d_2 = T.decl_buffer([1024], "uint8", scope="global")

Review Comment:
   Can we remove the `scope="global"` parameter since it matches the default?



##########
src/printer/tvmscript_printer.cc:
##########
@@ -1028,59 +1039,52 @@ Doc TVMScriptPrinter::VisitStmt_(const 
BufferRealizeNode* op) {
   return Doc();
 }
 
-namespace {
-struct AllocUsage {
-  Buffer alloc_buffer;
-  Array<Buffer> aliasing_buffers;
-};
-
-template <typename AllocNode>
-AllocUsage FindAllocateUsage(AllocNode* op, Map<Var, Array<Buffer>>* 
cache_ptr) {
-  Map<Var, Array<Buffer>>& cache = *cache_ptr;
-  if (!cache.count(op->buffer_var)) {
-    cache = BufferUsageFinder::FindUsage(std::move(cache), op->body);
+bool IsAllocateDeclBufferPattern(const AllocateNode* allocate, const 
DeclBufferNode* decl_buffer) {
+  const Var& buffer_var = allocate->buffer_var;
+  const Buffer& buffer = decl_buffer->buffer;
+  if (!buffer_var.same_as(buffer->data)) {
+    return false;
   }
-  Array<Buffer> buffer_usage = cache.Get(op->buffer_var).value_or({});
-
-  auto is_exact_match = [](Buffer a, Buffer b) {
-    if (a->dtype != b->dtype) return false;
-    if (a->shape.size() != b->shape.size()) return false;
-
-    arith::Analyzer analyzer;
-    for (size_t i = 0; i < a->shape.size(); i++) {
-      if (!analyzer.CanProveEqual(a->shape[i], b->shape[i])) {
-        return false;
-      }
-    }
-    return true;
-  };
-
-  // If the buffer allocated via T.allocate is an exact match to the
-  // usage of the buffer later on, then that buffer is the return
-  // value of T.allocate, and no T.buffer_decl statement is needed.
-  Buffer alloc_buffer(op->buffer_var, op->dtype, op->extents, {}, 0, 
op->buffer_var->name_hint, 0,
-                      0, kDefault);
-  bool found_alloc_buf = false;
-  Array<Buffer> aliasing_buffers;
-  for (const auto& buf : buffer_usage) {
-    if (!found_alloc_buf && is_exact_match(buf, alloc_buffer)) {
-      alloc_buffer = buf;
-      found_alloc_buf = true;
-    } else {
-      aliasing_buffers.push_back(buf);
+  if (allocate->dtype != buffer->dtype) {
+    return false;
+  }
+  if (!is_one(allocate->condition)) {
+    return false;
+  }
+  if (allocate->annotations.size()) {
+    return false;
+  }
+  if (allocate->extents.size() != buffer->shape.size()) {
+    return false;
+  }
+  tir::ExprDeepEqual expr_equal;
+  for (size_t i = 0, n = allocate->extents.size(); i < n; ++i) {
+    if (!expr_equal(allocate->extents[i], buffer->shape[i])) {
+      return false;
     }
   }
-
-  return AllocUsage{alloc_buffer, aliasing_buffers};
+  return true;
 }
-}  // namespace
 
 Doc TVMScriptPrinter::VisitStmt_(const AllocateNode* op) {
-  auto usage = FindAllocateUsage(op, &buffer_var_usage_);
-  Buffer& alloc_buffer = usage.alloc_buffer;
-  Array<Buffer>& aliasing_buffers = usage.aliasing_buffers;
-  buf_not_in_headers_.insert(alloc_buffer.get());
-  var_not_in_headers_.insert(alloc_buffer->data.get());
+  var_not_in_headers_.insert(op->buffer_var.get());
+
+  if (!buffer_var_usage_.count(op->buffer_var)) {
+    buffer_var_usage_ = 
BufferUsageFinder::FindUsage(std::move(buffer_var_usage_), op->body);
+  }
+  Array<Buffer> buffer_usage = 
buffer_var_usage_.Get(op->buffer_var).value_or({});
+
+  if (buffer_usage.empty()) {

Review Comment:
   What is the benefit of the check on `buffer_usage.empty()`?  It looks like 
it would prevent the Allocate/DeclBuffer pattern from being printed whenever 
the buffer var is used.



##########
tests/python/contrib/test_ethosu/test_copy_compute_reordering.py:
##########
@@ -40,14 +40,14 @@ def main() -> None:
         buffer9 = T.buffer_decl([32], "uint8")
         buffer10 = T.buffer_decl([2048], "int8")
         # body
-        p1 = T.allocate([128], "uint8", "global")
-        p2 = T.allocate([112], "uint8", "global")
-        p3 = T.allocate([112], "uint8", "global")
-        p4 = T.allocate([32], "uint8", "global")
-        p5 = T.allocate([32], "uint8", "global")
-        p6 = T.allocate([32], "uint8", "global")
-        p7 = T.allocate([112], "uint8", "global")
-        p8 = T.allocate([32], "uint8", "global")
+        p1 = T.decl_buffer([128], "uint8", scope="global")

Review Comment:
   Since `"global"` is the default value for scope, can we remove the `scope = 
"global"` parameter?  It looks like it was only present before because there 
was no default scope for `allocate()`.



##########
tests/python/contrib/test_ethosu/test_merge_constants.py:
##########
@@ -44,8 +44,10 @@ def main(buffer2: T.Buffer[(128,), "uint8"], buffer3: 
T.Buffer[(32,), "uint8"])
             buffer1 = T.buffer_decl([8192], "int8")
             buffer10 = T.buffer_decl([2048], "int8")
             # body
-            p1 = T.allocate([128], "uint8", "global")
-            p4 = T.allocate([32], "uint8", "global")
+            p1_data = T.allocate([128], "uint8", "global")

Review Comment:
   Can this use `T.decl_buffer`, here and lower in the file?



##########
tests/python/unittest/test_tir_renew_defs.py:
##########
@@ -135,7 +135,8 @@ def test_undefined_buffer():
     @T.prim_func
     def access_alloc():
         # Buffer A should be remapped
-        A = T.allocate([128], "float16", "global")
+        A_data = T.allocate([128], "float16", "global")

Review Comment:
   Can this be done through `T.decl_buffer` without the `T.allocate` statement?



##########
tests/python/contrib/test_ethosu/test_hoist_allocates.py:
##########
@@ -227,13 +244,20 @@ def main(placeholder: T.Buffer[(8192,), "int8"], 
ethosu_write: T.Buffer[(2048,),
             T.preflattened_buffer(placeholder, [1, 16, 16, 32], dtype="int8", 
data=placeholder.data)
             T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", 
data=ethosu_write.data)
             # body
-            placeholder_global = T.allocate([128], "uint8", "global")
-            placeholder_global_1 = T.allocate([112], "uint8", "global")
-            placeholder_global_2 = T.allocate([112], "uint8", "global")
-            placeholder_d_global = T.allocate([32], "uint8", "global")
-            placeholder_d_global_1 = T.allocate([32], "uint8", "global")
-            placeholder_d_global_2 = T.allocate([32], "uint8", "global")
-            placeholder_global_3 = T.allocate([112], "uint8", "global")
+            placeholder_global_data = T.allocate([128], "uint8", "global")

Review Comment:
   Do these require the separate `T.allocate` call?   This section looks like 
it matches the allocate/decl_buffer pattern.



##########
tests/python/unittest/test_tir_transform_unroll_loop.py:
##########
@@ -117,16 +117,19 @@ class before:
         @T.prim_func
         def main():
             for i in T.unroll(2):
-                with T.allocate([16], "float32", "global") as buf:
+                with T.allocate([16], "float32", "global") as buf_data:
+                    buf = T.buffer_decl(shape=[16], dtype="float32", 
scope="global", data=buf_data)
                     buf[0] = 0.0
 
     @tvm.script.ir_module
     class expected:
         @T.prim_func
         def main():
-            with T.allocate([16], "float32", "global") as buf1:
+            with T.allocate([16], "float32", "global") as buf1_data:

Review Comment:
   Since the `T.decl_buffer` is defined as a scope handler, it could be used in 
this context too, correct?



##########
tests/python/unittest/test_tir_transform_flatten_buffer.py:
##########
@@ -33,7 +33,8 @@ def elementwise_func(a: T.handle, c: T.handle) -> None:
     A = T.match_buffer(a, (16, 16), "float32")
     C = T.match_buffer(c, (16, 16), "float32")
     for i in T.serial(0, 16):
-        B_new = T.allocate([1, 16], "float32", "global")
+        B_new_data = T.allocate([1, 16], "float32", "global")

Review Comment:
   Can this be done with `T.decl_buffer` without the `data` argument?



##########
src/printer/tvmscript_printer.cc:
##########
@@ -1028,59 +1039,52 @@ Doc TVMScriptPrinter::VisitStmt_(const 
BufferRealizeNode* op) {
   return Doc();
 }
 
-namespace {
-struct AllocUsage {
-  Buffer alloc_buffer;
-  Array<Buffer> aliasing_buffers;
-};
-
-template <typename AllocNode>
-AllocUsage FindAllocateUsage(AllocNode* op, Map<Var, Array<Buffer>>* 
cache_ptr) {
-  Map<Var, Array<Buffer>>& cache = *cache_ptr;
-  if (!cache.count(op->buffer_var)) {
-    cache = BufferUsageFinder::FindUsage(std::move(cache), op->body);
+bool IsAllocateDeclBufferPattern(const AllocateNode* allocate, const 
DeclBufferNode* decl_buffer) {
+  const Var& buffer_var = allocate->buffer_var;
+  const Buffer& buffer = decl_buffer->buffer;
+  if (!buffer_var.same_as(buffer->data)) {
+    return false;
   }
-  Array<Buffer> buffer_usage = cache.Get(op->buffer_var).value_or({});
-
-  auto is_exact_match = [](Buffer a, Buffer b) {
-    if (a->dtype != b->dtype) return false;
-    if (a->shape.size() != b->shape.size()) return false;
-
-    arith::Analyzer analyzer;
-    for (size_t i = 0; i < a->shape.size(); i++) {
-      if (!analyzer.CanProveEqual(a->shape[i], b->shape[i])) {
-        return false;
-      }
-    }
-    return true;
-  };
-
-  // If the buffer allocated via T.allocate is an exact match to the
-  // usage of the buffer later on, then that buffer is the return
-  // value of T.allocate, and no T.buffer_decl statement is needed.
-  Buffer alloc_buffer(op->buffer_var, op->dtype, op->extents, {}, 0, 
op->buffer_var->name_hint, 0,
-                      0, kDefault);
-  bool found_alloc_buf = false;
-  Array<Buffer> aliasing_buffers;
-  for (const auto& buf : buffer_usage) {
-    if (!found_alloc_buf && is_exact_match(buf, alloc_buffer)) {
-      alloc_buffer = buf;
-      found_alloc_buf = true;
-    } else {
-      aliasing_buffers.push_back(buf);
+  if (allocate->dtype != buffer->dtype) {
+    return false;
+  }
+  if (!is_one(allocate->condition)) {
+    return false;
+  }
+  if (allocate->annotations.size()) {
+    return false;
+  }
+  if (allocate->extents.size() != buffer->shape.size()) {
+    return false;
+  }
+  tir::ExprDeepEqual expr_equal;
+  for (size_t i = 0, n = allocate->extents.size(); i < n; ++i) {
+    if (!expr_equal(allocate->extents[i], buffer->shape[i])) {
+      return false;
     }
   }
-
-  return AllocUsage{alloc_buffer, aliasing_buffers};
+  return true;
 }
-}  // namespace
 
 Doc TVMScriptPrinter::VisitStmt_(const AllocateNode* op) {
-  auto usage = FindAllocateUsage(op, &buffer_var_usage_);
-  Buffer& alloc_buffer = usage.alloc_buffer;
-  Array<Buffer>& aliasing_buffers = usage.aliasing_buffers;
-  buf_not_in_headers_.insert(alloc_buffer.get());
-  var_not_in_headers_.insert(alloc_buffer->data.get());
+  var_not_in_headers_.insert(op->buffer_var.get());
+
+  if (!buffer_var_usage_.count(op->buffer_var)) {
+    buffer_var_usage_ = 
BufferUsageFinder::FindUsage(std::move(buffer_var_usage_), op->body);
+  }
+  Array<Buffer> buffer_usage = 
buffer_var_usage_.Get(op->buffer_var).value_or({});
+
+  if (buffer_usage.empty()) {
+    if (const DeclBufferNode* decl_buffer = op->body.as<DeclBufferNode>()) {
+      if (IsAllocateDeclBufferPattern(op, decl_buffer)) {
+        // As a syntax sugar, we identify the pattern of Allocate and 
DeclBuffer and print a single
+        // DeclBuffer statement. It is intentionally to call `Print` instead 
of `PrintBody` here to
+        // delegate the printing of the current node to `DeclBufferNode` while 
maintaining the
+        // same value of `current_num_` and `num_child_`.
+        return Print(op->body);

Review Comment:
   This branch skips the call to `PrintNonHeaderBufferDeclarations` and the 
checks for the `with` syntax.  It looks like the `with` syntax is handled 
inside the printer for `DeclBufferNode`, but the call to 
`PrintNonHeaderBufferDeclarations` is not.  As a result, the `T.buffer_decl` 
statement for any buffers that alias `op->buffer_var` would erroneously show up 
at the function header, instead of being inside the `DeclBuffer` node.
   
   This won't be an issue once the `DeclBufferNode` is mandatory, but could 
cause bugs until then.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to