This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 13f54e0  [BugFix][TVMScript] Fix printer for dependent loops (#9506)
13f54e0 is described below

commit 13f54e03fd3b7066d9ce9b0b0c3e1ba297d627bb
Author: Ruihang Lai <[email protected]>
AuthorDate: Mon Nov 15 15:57:05 2021 +0800

    [BugFix][TVMScript] Fix printer for dependent loops (#9506)
---
 src/printer/tvmscript_printer.cc                  | 18 ++++++++++++++++--
 tests/python/unittest/test_tvmscript_roundtrip.py | 14 ++++++++++++++
 2 files changed, 30 insertions(+), 2 deletions(-)

diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc
index a47712e..f1c47e7 100644
--- a/src/printer/tvmscript_printer.cc
+++ b/src/printer/tvmscript_printer.cc
@@ -228,13 +228,27 @@ class TVMScriptPrinter : public StmtFunctor<Doc(const Stmt&)>,
   /*! \brief Print all simple loops in stack into one line using tir_prefix_.grid(). */
   Doc PrintLoopStack();
   /*!
-   * \brief Print all simple loops in stack into one line using tir_prefix_.grid().
+   * \brief Check whether a loop satisfies:
+   * 1. the loop is serial;
+   * 2. the loop has no annotation;
+   * 3. the loop starts from 0;
+   * 4. there is no optional information.
    * \param for_op the for node to be checked
+   * \return A boolean indicating whether the input loop satisfies the above conditions
    */
   bool IsSimpleLoop(const ForNode* for_op) {
     return for_op->kind == ForKind::kSerial && for_op->annotations.empty() &&
            is_zero(for_op->min) && !ContainsOptionalInfo(GetRef<Stmt>(for_op));
   }
+  /*!
+   * \brief Check whether the `min` or `extent` of a loop depends on previous loops
+   * \param for_op The loop to be checked
+   * \return A boolean indicating whether the input loop depends on previous loops
+   */
+  bool DependOnPrevLoops(const ForNode* for_op) {
+    auto f_check = [&var_map = this->loop_var_map_](const VarNode* v) { return var_map.count(v); };
+    return UsesVar(for_op->min, f_check) || UsesVar(for_op->extent, f_check);
+  }
 
   /*!
    * \brief Print additional info about expr in comment.
@@ -895,7 +909,7 @@ Doc TVMScriptPrinter::VisitStmt_(const ForNode* op) {
   bool simple_loop = IsSimpleLoop(op);
   if (simple_loop) simple_loop_stack_.push_back(GetRef<For>(op));
   // It is a loop that can be compressed, let the loops below print it out
-  if (simple_loop && body != nullptr && IsSimpleLoop(body)) {
+  if (simple_loop && body != nullptr && IsSimpleLoop(body) && !DependOnPrevLoops(body)) {
     doc << Print(GetRef<For>(body));
     TryDeallocVar(op->loop_var);
     loop_var_map_.erase(op->loop_var.get());
diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py
index 4e1308b..d7a8ff9 100644
--- a/tests/python/unittest/test_tvmscript_roundtrip.py
+++ b/tests/python/unittest/test_tvmscript_roundtrip.py
@@ -3176,5 +3176,19 @@ def test_div_mod():
     assert isinstance(func.body[3].value, tvm.tir.Mod)
 
 
+@T.prim_func
+def loop_extent_dependent(a: T.handle) -> None:
+    A = T.match_buffer(a, [], dtype="int32")
+    for i in T.serial(0, 128):
+        for j in T.serial(0, i):
+            A[()] = A[()] + j
+
+
+def test_loop_extent_dependent():
+    func = loop_extent_dependent
+    rt_func = tvm.script.from_source(func.script(show_meta=True))
+    tvm.ir.assert_structural_equal(func, rt_func, True)
+
+
 if __name__ == "__main__":
     sys.exit(pytest.main([__file__] + sys.argv[1:]))

Reply via email to