This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 13f54e0 [BugFix][TVMScript] Fix printer for dependent loops (#9506)
13f54e0 is described below
commit 13f54e03fd3b7066d9ce9b0b0c3e1ba297d627bb
Author: Ruihang Lai <[email protected]>
AuthorDate: Mon Nov 15 15:57:05 2021 +0800
[BugFix][TVMScript] Fix printer for dependent loops (#9506)
---
src/printer/tvmscript_printer.cc | 18 ++++++++++++++++--
tests/python/unittest/test_tvmscript_roundtrip.py | 14 ++++++++++++++
2 files changed, 30 insertions(+), 2 deletions(-)
diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc
index a47712e..f1c47e7 100644
--- a/src/printer/tvmscript_printer.cc
+++ b/src/printer/tvmscript_printer.cc
@@ -228,13 +228,27 @@ class TVMScriptPrinter : public StmtFunctor<Doc(const Stmt&)>,
/*! \brief Print all simple loops in stack into one line using tir_prefix_.grid(). */
Doc PrintLoopStack();
/*!
- * \brief Print all simple loops in stack into one line using tir_prefix_.grid().
+ * \brief Check whether a loop satisfies:
+ * 1. the loop is serial;
+ * 2. the loop has no annotation;
+ * 3. the loop starts from 0;
+ * 4. there is no optional information.
* \param for_op the for node to be checked
+ * \return A boolean indicating whether the input loop satisfies the above conditions
*/
bool IsSimpleLoop(const ForNode* for_op) {
return for_op->kind == ForKind::kSerial && for_op->annotations.empty() &&
is_zero(for_op->min) && !ContainsOptionalInfo(GetRef<Stmt>(for_op));
}
+ /*!
+ * \brief Check whether the `min` or `extent` of a loop depends on previous loops
+ * \param for_op The loop to be checked
+ * \return A boolean indicating whether the input loop depends on previous loops
+ */
+ bool DependOnPrevLoops(const ForNode* for_op) {
+ auto f_check = [&var_map = this->loop_var_map_](const VarNode* v) { return var_map.count(v); };
+ return UsesVar(for_op->min, f_check) || UsesVar(for_op->extent, f_check);
+ }
/*!
* \brief Print additional info about expr in comment.
@@ -895,7 +909,7 @@ Doc TVMScriptPrinter::VisitStmt_(const ForNode* op) {
bool simple_loop = IsSimpleLoop(op);
if (simple_loop) simple_loop_stack_.push_back(GetRef<For>(op));
// It is a loop that can be compressed, let the loops below print it out
- if (simple_loop && body != nullptr && IsSimpleLoop(body)) {
+ if (simple_loop && body != nullptr && IsSimpleLoop(body) && !DependOnPrevLoops(body)) {
doc << Print(GetRef<For>(body));
TryDeallocVar(op->loop_var);
loop_var_map_.erase(op->loop_var.get());
diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py
index 4e1308b..d7a8ff9 100644
--- a/tests/python/unittest/test_tvmscript_roundtrip.py
+++ b/tests/python/unittest/test_tvmscript_roundtrip.py
@@ -3176,5 +3176,19 @@ def test_div_mod():
assert isinstance(func.body[3].value, tvm.tir.Mod)
+@T.prim_func
+def loop_extent_dependent(a: T.handle) -> None:
+ A = T.match_buffer(a, [], dtype="int32")
+ for i in T.serial(0, 128):
+ for j in T.serial(0, i):
+ A[()] = A[()] + j
+
+
+def test_loop_extent_dependent():
+ func = loop_extent_dependent
+ rt_func = tvm.script.from_source(func.script(show_meta=True))
+ tvm.ir.assert_structural_equal(func, rt_func, True)
+
+
if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))