Lunderberg commented on code in PR #12412:
URL: https://github.com/apache/tvm/pull/12412#discussion_r958848078
##########
src/printer/tvmscript_printer.cc:
##########
@@ -100,13 +100,21 @@ class BufferUsageFinder : public StmtExprVisitor {
StmtExprVisitor::VisitStmt_(op);
}
+ void VisitStmt_(const DeclBufferNode* op) final {
+ buffers_declared_.insert(op->buffer.get());
Review Comment:
Should we also track which buffers have gone out of scope? If I'm
understanding it correctly, a single `DeclBufferNode` would also allow the
buffer to be used outside of `DeclBufferNode::body`, whereas I'd expect the
declaration to apply only within the scope of the node.
##########
tests/python/contrib/test_ethosu/test_hoist_allocates.py:
##########
@@ -242,7 +266,8 @@ def main(placeholder: T.Buffer[(8192,), "int8"],
ethosu_write: T.Buffer[(2048,),
T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16,
0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8",
16, 16, 2, 16, 0, 16, ethosu_write[2], 0, 0, 0, T.float32(0.25), 14, "NHWC",
128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global_1[0], 112, 12,
placeholder_d_global_1[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0,
0, dtype="handle"))
T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_4[0], 112,
placeholder_global_2[0], dtype="handle"))
T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_5[0], 32,
placeholder_d_global_2[0], dtype="handle"))
- placeholder_d_global_3 = T.allocate([32], "uint8", "global")
+ placeholder_d_global_3_data = T.allocate([32], "uint8", "global")
Review Comment:
Same question here, whether we can use a single `T.decl_buffer` call.
##########
src/printer/tvmscript_printer.cc:
##########
@@ -1028,59 +1039,52 @@ Doc TVMScriptPrinter::VisitStmt_(const
BufferRealizeNode* op) {
return Doc();
}
-namespace {
-struct AllocUsage {
- Buffer alloc_buffer;
- Array<Buffer> aliasing_buffers;
-};
-
-template <typename AllocNode>
-AllocUsage FindAllocateUsage(AllocNode* op, Map<Var, Array<Buffer>>*
cache_ptr) {
- Map<Var, Array<Buffer>>& cache = *cache_ptr;
- if (!cache.count(op->buffer_var)) {
- cache = BufferUsageFinder::FindUsage(std::move(cache), op->body);
+bool IsAllocateDeclBufferPattern(const AllocateNode* allocate, const
DeclBufferNode* decl_buffer) {
Review Comment:
Why is the `decl_buffer` argument needed? It looks like this pattern only
applies when the `DeclBufferNode` is the immediate child of `AllocateNode`, so
we could pull that part of the check into this function. I'm thinking
something like the following:
```c++
bool IsAllocateDeclBufferPattern(const AllocateNode* allocate) {
const DeclBufferNode* decl_buffer = allocate->body.as<DeclBufferNode>();
if(!decl_buffer) {
return false;
}
// Continue as normal from here.
}
```
(P.S. I really like this as a way to provide the cleaner TVMScript syntax
without immediately requiring additional TIR changes, and I'm glad that it gets
rid of the `FindAllocateUsage`. That felt like a hack as I was putting it in.)
##########
tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py:
##########
@@ -56,8 +56,8 @@ def main(placeholder_6: T.Buffer[(192,), "int8"],
ethosu_conv2d_1: T.Buffer[(512
placeholder_8 = T.buffer_decl([1], "uint8")
placeholder_5 = T.buffer_decl([1], "uint8")
# body
- ethosu_conv2d_2 = T.allocate([1024], "uint8", "global")
- ethosu_conv2d_3 = T.allocate([2048], "uint8", "global")
+ ethosu_conv2d_2 = T.decl_buffer([1024], "uint8", scope="global")
Review Comment:
Can we remove the `scope="global"` parameter since it matches the default?
##########
src/printer/tvmscript_printer.cc:
##########
@@ -1028,59 +1039,52 @@ Doc TVMScriptPrinter::VisitStmt_(const
BufferRealizeNode* op) {
return Doc();
}
-namespace {
-struct AllocUsage {
- Buffer alloc_buffer;
- Array<Buffer> aliasing_buffers;
-};
-
-template <typename AllocNode>
-AllocUsage FindAllocateUsage(AllocNode* op, Map<Var, Array<Buffer>>*
cache_ptr) {
- Map<Var, Array<Buffer>>& cache = *cache_ptr;
- if (!cache.count(op->buffer_var)) {
- cache = BufferUsageFinder::FindUsage(std::move(cache), op->body);
+bool IsAllocateDeclBufferPattern(const AllocateNode* allocate, const
DeclBufferNode* decl_buffer) {
+ const Var& buffer_var = allocate->buffer_var;
+ const Buffer& buffer = decl_buffer->buffer;
+ if (!buffer_var.same_as(buffer->data)) {
+ return false;
}
- Array<Buffer> buffer_usage = cache.Get(op->buffer_var).value_or({});
-
- auto is_exact_match = [](Buffer a, Buffer b) {
- if (a->dtype != b->dtype) return false;
- if (a->shape.size() != b->shape.size()) return false;
-
- arith::Analyzer analyzer;
- for (size_t i = 0; i < a->shape.size(); i++) {
- if (!analyzer.CanProveEqual(a->shape[i], b->shape[i])) {
- return false;
- }
- }
- return true;
- };
-
- // If the buffer allocated via T.allocate is an exact match to the
- // usage of the buffer later on, then that buffer is the return
- // value of T.allocate, and no T.buffer_decl statement is needed.
- Buffer alloc_buffer(op->buffer_var, op->dtype, op->extents, {}, 0,
op->buffer_var->name_hint, 0,
- 0, kDefault);
- bool found_alloc_buf = false;
- Array<Buffer> aliasing_buffers;
- for (const auto& buf : buffer_usage) {
- if (!found_alloc_buf && is_exact_match(buf, alloc_buffer)) {
- alloc_buffer = buf;
- found_alloc_buf = true;
- } else {
- aliasing_buffers.push_back(buf);
+ if (allocate->dtype != buffer->dtype) {
+ return false;
+ }
+ if (!is_one(allocate->condition)) {
+ return false;
+ }
+ if (allocate->annotations.size()) {
+ return false;
+ }
+ if (allocate->extents.size() != buffer->shape.size()) {
+ return false;
+ }
+ tir::ExprDeepEqual expr_equal;
+ for (size_t i = 0, n = allocate->extents.size(); i < n; ++i) {
+ if (!expr_equal(allocate->extents[i], buffer->shape[i])) {
+ return false;
}
}
-
- return AllocUsage{alloc_buffer, aliasing_buffers};
+ return true;
}
-} // namespace
Doc TVMScriptPrinter::VisitStmt_(const AllocateNode* op) {
- auto usage = FindAllocateUsage(op, &buffer_var_usage_);
- Buffer& alloc_buffer = usage.alloc_buffer;
- Array<Buffer>& aliasing_buffers = usage.aliasing_buffers;
- buf_not_in_headers_.insert(alloc_buffer.get());
- var_not_in_headers_.insert(alloc_buffer->data.get());
+ var_not_in_headers_.insert(op->buffer_var.get());
+
+ if (!buffer_var_usage_.count(op->buffer_var)) {
+ buffer_var_usage_ =
BufferUsageFinder::FindUsage(std::move(buffer_var_usage_), op->body);
+ }
+ Array<Buffer> buffer_usage =
buffer_var_usage_.Get(op->buffer_var).value_or({});
+
+ if (buffer_usage.empty()) {
Review Comment:
What is the benefit of the check on `buffer_usage.empty()`? It looks like
it would prevent the Allocate/DeclBuffer pattern from being printed whenever
the buffer var is used.
##########
tests/python/contrib/test_ethosu/test_copy_compute_reordering.py:
##########
@@ -40,14 +40,14 @@ def main() -> None:
buffer9 = T.buffer_decl([32], "uint8")
buffer10 = T.buffer_decl([2048], "int8")
# body
- p1 = T.allocate([128], "uint8", "global")
- p2 = T.allocate([112], "uint8", "global")
- p3 = T.allocate([112], "uint8", "global")
- p4 = T.allocate([32], "uint8", "global")
- p5 = T.allocate([32], "uint8", "global")
- p6 = T.allocate([32], "uint8", "global")
- p7 = T.allocate([112], "uint8", "global")
- p8 = T.allocate([32], "uint8", "global")
+ p1 = T.decl_buffer([128], "uint8", scope="global")
Review Comment:
Since `"global"` is the default value for scope, can we remove the `scope =
"global"` parameter? It looks like it was only present before because there
was no default scope for `allocate()`.
##########
tests/python/contrib/test_ethosu/test_merge_constants.py:
##########
@@ -44,8 +44,10 @@ def main(buffer2: T.Buffer[(128,), "uint8"], buffer3:
T.Buffer[(32,), "uint8"])
buffer1 = T.buffer_decl([8192], "int8")
buffer10 = T.buffer_decl([2048], "int8")
# body
- p1 = T.allocate([128], "uint8", "global")
- p4 = T.allocate([32], "uint8", "global")
+ p1_data = T.allocate([128], "uint8", "global")
Review Comment:
Can this use `T.decl_buffer`, here and lower in the file?
##########
tests/python/unittest/test_tir_renew_defs.py:
##########
@@ -135,7 +135,8 @@ def test_undefined_buffer():
@T.prim_func
def access_alloc():
# Buffer A should be remapped
- A = T.allocate([128], "float16", "global")
+ A_data = T.allocate([128], "float16", "global")
Review Comment:
Can this be done through `T.decl_buffer` without the `T.allocate` statement?
##########
tests/python/contrib/test_ethosu/test_hoist_allocates.py:
##########
@@ -227,13 +244,20 @@ def main(placeholder: T.Buffer[(8192,), "int8"],
ethosu_write: T.Buffer[(2048,),
T.preflattened_buffer(placeholder, [1, 16, 16, 32], dtype="int8",
data=placeholder.data)
T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8",
data=ethosu_write.data)
# body
- placeholder_global = T.allocate([128], "uint8", "global")
- placeholder_global_1 = T.allocate([112], "uint8", "global")
- placeholder_global_2 = T.allocate([112], "uint8", "global")
- placeholder_d_global = T.allocate([32], "uint8", "global")
- placeholder_d_global_1 = T.allocate([32], "uint8", "global")
- placeholder_d_global_2 = T.allocate([32], "uint8", "global")
- placeholder_global_3 = T.allocate([112], "uint8", "global")
+ placeholder_global_data = T.allocate([128], "uint8", "global")
Review Comment:
Do these require the separate `T.allocate` call? This section looks like
it matches the allocate/decl_buffer pattern.
##########
tests/python/unittest/test_tir_transform_unroll_loop.py:
##########
@@ -117,16 +117,19 @@ class before:
@T.prim_func
def main():
for i in T.unroll(2):
- with T.allocate([16], "float32", "global") as buf:
+ with T.allocate([16], "float32", "global") as buf_data:
+ buf = T.buffer_decl(shape=[16], dtype="float32",
scope="global", data=buf_data)
buf[0] = 0.0
@tvm.script.ir_module
class expected:
@T.prim_func
def main():
- with T.allocate([16], "float32", "global") as buf1:
+ with T.allocate([16], "float32", "global") as buf1_data:
Review Comment:
Since the `T.decl_buffer` is defined as a scope handler, it could be used in
this context too, correct?
##########
tests/python/unittest/test_tir_transform_flatten_buffer.py:
##########
@@ -33,7 +33,8 @@ def elementwise_func(a: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, (16, 16), "float32")
C = T.match_buffer(c, (16, 16), "float32")
for i in T.serial(0, 16):
- B_new = T.allocate([1, 16], "float32", "global")
+ B_new_data = T.allocate([1, 16], "float32", "global")
Review Comment:
Can this be done with `T.decl_buffer` without the `data` argument?
##########
src/printer/tvmscript_printer.cc:
##########
@@ -1028,59 +1039,52 @@ Doc TVMScriptPrinter::VisitStmt_(const
BufferRealizeNode* op) {
return Doc();
}
-namespace {
-struct AllocUsage {
- Buffer alloc_buffer;
- Array<Buffer> aliasing_buffers;
-};
-
-template <typename AllocNode>
-AllocUsage FindAllocateUsage(AllocNode* op, Map<Var, Array<Buffer>>*
cache_ptr) {
- Map<Var, Array<Buffer>>& cache = *cache_ptr;
- if (!cache.count(op->buffer_var)) {
- cache = BufferUsageFinder::FindUsage(std::move(cache), op->body);
+bool IsAllocateDeclBufferPattern(const AllocateNode* allocate, const
DeclBufferNode* decl_buffer) {
+ const Var& buffer_var = allocate->buffer_var;
+ const Buffer& buffer = decl_buffer->buffer;
+ if (!buffer_var.same_as(buffer->data)) {
+ return false;
}
- Array<Buffer> buffer_usage = cache.Get(op->buffer_var).value_or({});
-
- auto is_exact_match = [](Buffer a, Buffer b) {
- if (a->dtype != b->dtype) return false;
- if (a->shape.size() != b->shape.size()) return false;
-
- arith::Analyzer analyzer;
- for (size_t i = 0; i < a->shape.size(); i++) {
- if (!analyzer.CanProveEqual(a->shape[i], b->shape[i])) {
- return false;
- }
- }
- return true;
- };
-
- // If the buffer allocated via T.allocate is an exact match to the
- // usage of the buffer later on, then that buffer is the return
- // value of T.allocate, and no T.buffer_decl statement is needed.
- Buffer alloc_buffer(op->buffer_var, op->dtype, op->extents, {}, 0,
op->buffer_var->name_hint, 0,
- 0, kDefault);
- bool found_alloc_buf = false;
- Array<Buffer> aliasing_buffers;
- for (const auto& buf : buffer_usage) {
- if (!found_alloc_buf && is_exact_match(buf, alloc_buffer)) {
- alloc_buffer = buf;
- found_alloc_buf = true;
- } else {
- aliasing_buffers.push_back(buf);
+ if (allocate->dtype != buffer->dtype) {
+ return false;
+ }
+ if (!is_one(allocate->condition)) {
+ return false;
+ }
+ if (allocate->annotations.size()) {
+ return false;
+ }
+ if (allocate->extents.size() != buffer->shape.size()) {
+ return false;
+ }
+ tir::ExprDeepEqual expr_equal;
+ for (size_t i = 0, n = allocate->extents.size(); i < n; ++i) {
+ if (!expr_equal(allocate->extents[i], buffer->shape[i])) {
+ return false;
}
}
-
- return AllocUsage{alloc_buffer, aliasing_buffers};
+ return true;
}
-} // namespace
Doc TVMScriptPrinter::VisitStmt_(const AllocateNode* op) {
- auto usage = FindAllocateUsage(op, &buffer_var_usage_);
- Buffer& alloc_buffer = usage.alloc_buffer;
- Array<Buffer>& aliasing_buffers = usage.aliasing_buffers;
- buf_not_in_headers_.insert(alloc_buffer.get());
- var_not_in_headers_.insert(alloc_buffer->data.get());
+ var_not_in_headers_.insert(op->buffer_var.get());
+
+ if (!buffer_var_usage_.count(op->buffer_var)) {
+ buffer_var_usage_ =
BufferUsageFinder::FindUsage(std::move(buffer_var_usage_), op->body);
+ }
+ Array<Buffer> buffer_usage =
buffer_var_usage_.Get(op->buffer_var).value_or({});
+
+ if (buffer_usage.empty()) {
+ if (const DeclBufferNode* decl_buffer = op->body.as<DeclBufferNode>()) {
+ if (IsAllocateDeclBufferPattern(op, decl_buffer)) {
+ // As a syntax sugar, we identify the pattern of Allocate and
DeclBuffer and print a single
+ // DeclBuffer statement. It is intentionally to call `Print` instead
of `PrintBody` here to
+ // delegate the printing of the current node to `DeclBufferNode` while
maintaining the
+ // same value of `current_num_` and `num_child_`.
+ return Print(op->body);
Review Comment:
This branch skips the call to `PrintNonHeaderBufferDeclarations` and the
checks for the `with` syntax. It looks like the `with` syntax is handled
inside the printer for `DeclBufferNode`, but the call to
`PrintNonHeaderBufferDeclarations` is not. As a result, the `T.buffer_decl`
statement for any buffers that alias `op->buffer_var` would erroneously show up
at the function header, instead of being inside the `DeclBuffer` node.
This won't be an issue once the `DeclBufferNode` is mandatory, but could
cause bugs until then.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]