This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 6c5be6fbd0 [TVMScript] `T.axis.remap` syntax sugar for TVMScript 
printer (#13743)
6c5be6fbd0 is described below

commit 6c5be6fbd062a5cd431f09a1d87ac614cee73a39
Author: Yaxing Cai <[email protected]>
AuthorDate: Tue Jan 17 21:14:27 2023 -0800

    [TVMScript] `T.axis.remap` syntax sugar for TVMScript printer (#13743)
    
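    This PR implements the `T.axis.remap` syntax sugar for the new TVMScript
    printer. The printer now synthesizes a single `T.axis.remap` call whenever
    two or more consecutive block iteration variables have simple bindings,
    i.e. each is a spatial or reduce axis bound directly to a loop variable
    whose range matches the axis domain. For example, it will change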
    ```python
    for i, j, k in T.grid(128, 128, 128):
      with T.block("update"):
        vi = T.axis.spatial(128, i)
        vj = T.axis.spatial(128, j)
        vk = T.axis.reduce(128, k)
    ```
    into
    ```python
    for i, j, k in T.grid(128, 128, 128):
      with T.block("update"):
        vi, vj, vk = T.axis.remap("SSR", [i, j, k])
    ```
    
    Co-authored-by: Junru Shao <[email protected]>
---
 src/script/printer/tir/block.cc                    | 80 +++++++++++++++++++++-
 .../test_tvmscript_printer_syntax_sugar.py         | 69 +++++++++++++++++++
 2 files changed, 147 insertions(+), 2 deletions(-)

diff --git a/src/script/printer/tir/block.cc b/src/script/printer/tir/block.cc
index e7f733864c..069ec7f3ea 100644
--- a/src/script/printer/tir/block.cc
+++ b/src/script/printer/tir/block.cc
@@ -30,8 +30,42 @@ Doc PrintBlock(IRDocsifier d, tir::Block block, ObjectPath 
block_p,  //
       opt_realize.defined() ? opt_realize.value().get() : nullptr;
   const ObjectPathNode* realize_p = opt_realize_p.defined() ? 
opt_realize_p.get() : nullptr;
   // Step 1. Handle block var and block bindings
-  int n_vars = block->iter_vars.size();
-  for (int i = 0; i < n_vars; ++i) {
+  // Step 1.1. Obtain all loop var defined along path
+  std::unordered_map<const tir::VarNode*, tir::For> loop_vars;
+  for (Frame f : d->frames) {
+    if (const auto* tir_f = f.as<TIRFrameNode>()) {
+      if (const auto* for_loop = tir_f->tir.as<tir::ForNode>()) {
+        for (const tir::ForNode* l = for_loop; l != nullptr; l = 
l->body.as<tir::ForNode>()) {
+          loop_vars.insert(std::make_pair(l->loop_var.get(), 
GetRef<tir::For>(l)));
+        }
+      }
+    }
+  }
+
+  std::vector<int> remap_vars_indices;
+  auto add_remapped_iter_var = [&](int i) -> bool {
+    if (realize) {
+      tir::ExprDeepEqual expr_equal;
+      tir::IterVar iter_var = block->iter_vars[i];
+      PrimExpr value = realize->iter_values[i];
+      if (iter_var->iter_type == tir::IterVarType::kDataPar ||
+          iter_var->iter_type == tir::IterVarType::kCommReduce) {
+        if (const auto* var = value.as<tir::VarNode>()) {
+          if (loop_vars.count(var)) {
+            tir::For for_loop = loop_vars.at(var);
+            if (expr_equal(for_loop->min, iter_var->dom->min) &&
+                expr_equal(for_loop->extent, iter_var->dom->extent)) {
+              remap_vars_indices.push_back(i);
+              return true;
+            }
+          }
+        }
+      }
+    }
+    return false;
+  };
+
+  auto print_single_iter_var = [&](int i) {
     tir::IterVar iter_var = block->iter_vars[i];
     ObjectPath iter_var_p = block_p->Attr("iter_var")->ArrayIndex(i);
     ExprDoc rhs = TIR("axis");
@@ -66,7 +100,49 @@ Doc PrintBlock(IRDocsifier d, tir::Block block, ObjectPath 
block_p,  //
       rhs = rhs->Call({dom});
     }
     (*frame)->stmts.push_back(AssignDoc(DefineVar(iter_var->var, *frame, d), 
rhs, NullOpt));
+  };
+
+  auto print_remapped_iter_var = [&]() {
+    if (remap_vars_indices.size()) {
+      int m = remap_vars_indices.size();
+      if (!m) {
+        return;
+      }
+      if (m == 1) {
+        print_single_iter_var(remap_vars_indices[0]);
+        remap_vars_indices.clear();
+        return;
+      }
+      Array<ExprDoc> lhs;
+      Array<ExprDoc> loop_var_doc;
+      lhs.reserve(m);
+      loop_var_doc.reserve(m);
+      std::string binding_type = "";
+      for (int i : remap_vars_indices) {
+        tir::IterVar iter_var = block->iter_vars[i];
+        ObjectPath iter_var_p = block_p->Attr("iter_var")->ArrayIndex(i);
+        lhs.push_back(DefineVar(iter_var->var, *frame, d));
+        loop_var_doc.push_back(d->AsDoc<ExprDoc>(realize->iter_values[i],
+                                                 
realize_p->Attr("iter_values")->ArrayIndex(i)));
+        binding_type += iter_var->iter_type == tir::IterVarType::kDataPar ? 
"S" : "R";
+      }
+      ExprDoc rhs = TIR("axis")->Attr("remap");
+      rhs = rhs->Call({LiteralDoc::Str(binding_type), ListDoc(loop_var_doc)});
+      (*frame)->stmts.push_back(AssignDoc(TupleDoc(lhs), rhs, NullOpt));
+      remap_vars_indices.clear();
+    }
+  };
+
+  // Step 1.2. Construct all block var bindings
+  int n_vars = block->iter_vars.size();
+  for (int i = 0; i < n_vars; ++i) {
+    if (!add_remapped_iter_var(i)) {
+      print_remapped_iter_var();
+      print_single_iter_var(i);
+    }
   }
+  print_remapped_iter_var();
+
   // Step 2. Handle block predicate
   if (realize) {
     ICHECK(realize->predicate.defined() && 
realize->predicate->dtype.is_bool());
diff --git a/tests/python/unittest/test_tvmscript_printer_syntax_sugar.py 
b/tests/python/unittest/test_tvmscript_printer_syntax_sugar.py
new file mode 100644
index 0000000000..1bccb8188c
--- /dev/null
+++ b/tests/python/unittest/test_tvmscript_printer_syntax_sugar.py
@@ -0,0 +1,69 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest
+import tvm.testing
+from tvm.script.parser import tir as T
+from tvm.script import script
+
+
+def _test(obj, expected: str):
+    assert script(obj).strip() == expected.strip()
+
+
+def test_remap():
+    @T.prim_func
+    def block_with_remap_implicitly():
+        for i0, i1, i2, i3, i4, i5 in T.grid(128, 128, 128, 128, 128, 128):
+            with T.block("update"):
+                v0 = T.axis.spatial(128, i0 + 1)
+                v1 = T.axis.spatial(128, i1)
+                v2 = T.axis.reduce(128, i2)
+                v3 = T.axis.spatial(128, i3 - 1)
+                v4 = T.axis.reduce(128, i4)
+                v5 = T.axis.spatial(128, i5)
+                pass
+
+    @T.prim_func
+    def block_with_remap_explicitly():
+        for i0, i1, i2, i3, i4, i5 in T.grid(128, 128, 128, 128, 128, 128):
+            with T.block("update"):
+                v0 = T.axis.spatial(128, i0 + 1)
+                v1, v2 = T.axis.remap("SR", [i1, i2])
+                v3 = T.axis.spatial(128, i3 - 1)
+                v4, v5 = T.axis.remap("RS", [i4, i5])
+                pass
+
+    expected_output = """@T.prim_func
+def main():
+    with T.block("root"):
+        T.reads()
+        T.writes()
+        for i0, i1, i2, i3, i4, i5 in T.grid(128, 128, 128, 128, 128, 128):
+            with T.block("update"):
+                v0 = T.axis.spatial(128, i0 + 1)
+                v1, v2 = T.axis.remap("SR", [i1, i2])
+                v3 = T.axis.spatial(128, i3 - 1)
+                v4, v5 = T.axis.remap("RS", [i4, i5])
+                T.reads()
+                T.writes()
+                T.evaluate(0)"""
+    _test(block_with_remap_implicitly, expected_output)
+    _test(block_with_remap_explicitly, expected_output)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()

Reply via email to