This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 6c5be6fbd0 [TVMScript] `T.axis.remap` syntax sugar for TVMScript
printer (#13743)
6c5be6fbd0 is described below
commit 6c5be6fbd062a5cd431f09a1d87ac614cee73a39
Author: Yaxing Cai <[email protected]>
AuthorDate: Tue Jan 17 21:14:27 2023 -0800
[TVMScript] `T.axis.remap` syntax sugar for TVMScript printer (#13743)
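This PR implements `T.axis.remap` syntax sugar for the new TVMScript
printer. The printer now emits a single `T.axis.remap` call for each run of
two or more consecutive block iteration variables with simple bindings,
i.e. spatial or reduction variables bound directly to an enclosing loop
variable whose min and extent match the block variable's domain. For
example, it changes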
```python
for i, j, k in T.grid(128, 128, 128):
    with T.block("update"):
        vi = T.axis.spatial(128, i)
        vj = T.axis.spatial(128, j)
        vk = T.axis.reduce(128, k)
```
into
```python
for i, j, k in T.grid(128, 128, 128):
    with T.block("update"):
        vi, vj, vk = T.axis.remap("SSR", [i, j, k])
```
Co-authored-by: Junru Shao <[email protected]>
---
src/script/printer/tir/block.cc | 80 +++++++++++++++++++++-
.../test_tvmscript_printer_syntax_sugar.py | 69 +++++++++++++++++++
2 files changed, 147 insertions(+), 2 deletions(-)
diff --git a/src/script/printer/tir/block.cc b/src/script/printer/tir/block.cc
index e7f733864c..069ec7f3ea 100644
--- a/src/script/printer/tir/block.cc
+++ b/src/script/printer/tir/block.cc
@@ -30,8 +30,42 @@ Doc PrintBlock(IRDocsifier d, tir::Block block, ObjectPath
block_p, //
opt_realize.defined() ? opt_realize.value().get() : nullptr;
const ObjectPathNode* realize_p = opt_realize_p.defined() ?
opt_realize_p.get() : nullptr;
// Step 1. Handle block var and block bindings
- int n_vars = block->iter_vars.size();
- for (int i = 0; i < n_vars; ++i) {
+ // Step 1.1. Obtain all loop var defined along path
+ std::unordered_map<const tir::VarNode*, tir::For> loop_vars;
+ for (Frame f : d->frames) {
+ if (const auto* tir_f = f.as<TIRFrameNode>()) {
+ if (const auto* for_loop = tir_f->tir.as<tir::ForNode>()) {
+ for (const tir::ForNode* l = for_loop; l != nullptr; l =
l->body.as<tir::ForNode>()) {
+ loop_vars.insert(std::make_pair(l->loop_var.get(),
GetRef<tir::For>(l)));
+ }
+ }
+ }
+ }
+
+ std::vector<int> remap_vars_indices;
+ auto add_remapped_iter_var = [&](int i) -> bool {
+ if (realize) {
+ tir::ExprDeepEqual expr_equal;
+ tir::IterVar iter_var = block->iter_vars[i];
+ PrimExpr value = realize->iter_values[i];
+ if (iter_var->iter_type == tir::IterVarType::kDataPar ||
+ iter_var->iter_type == tir::IterVarType::kCommReduce) {
+ if (const auto* var = value.as<tir::VarNode>()) {
+ if (loop_vars.count(var)) {
+ tir::For for_loop = loop_vars.at(var);
+ if (expr_equal(for_loop->min, iter_var->dom->min) &&
+ expr_equal(for_loop->extent, iter_var->dom->extent)) {
+ remap_vars_indices.push_back(i);
+ return true;
+ }
+ }
+ }
+ }
+ }
+ return false;
+ };
+
+ auto print_single_iter_var = [&](int i) {
tir::IterVar iter_var = block->iter_vars[i];
ObjectPath iter_var_p = block_p->Attr("iter_var")->ArrayIndex(i);
ExprDoc rhs = TIR("axis");
@@ -66,7 +100,49 @@ Doc PrintBlock(IRDocsifier d, tir::Block block, ObjectPath
block_p, //
rhs = rhs->Call({dom});
}
(*frame)->stmts.push_back(AssignDoc(DefineVar(iter_var->var, *frame, d),
rhs, NullOpt));
+ };
+
+ auto print_remapped_iter_var = [&]() {
+ if (remap_vars_indices.size()) {
+ int m = remap_vars_indices.size();
+ if (!m) {
+ return;
+ }
+ if (m == 1) {
+ print_single_iter_var(remap_vars_indices[0]);
+ remap_vars_indices.clear();
+ return;
+ }
+ Array<ExprDoc> lhs;
+ Array<ExprDoc> loop_var_doc;
+ lhs.reserve(m);
+ loop_var_doc.reserve(m);
+ std::string binding_type = "";
+ for (int i : remap_vars_indices) {
+ tir::IterVar iter_var = block->iter_vars[i];
+ ObjectPath iter_var_p = block_p->Attr("iter_var")->ArrayIndex(i);
+ lhs.push_back(DefineVar(iter_var->var, *frame, d));
+ loop_var_doc.push_back(d->AsDoc<ExprDoc>(realize->iter_values[i],
+
realize_p->Attr("iter_values")->ArrayIndex(i)));
+ binding_type += iter_var->iter_type == tir::IterVarType::kDataPar ?
"S" : "R";
+ }
+ ExprDoc rhs = TIR("axis")->Attr("remap");
+ rhs = rhs->Call({LiteralDoc::Str(binding_type), ListDoc(loop_var_doc)});
+ (*frame)->stmts.push_back(AssignDoc(TupleDoc(lhs), rhs, NullOpt));
+ remap_vars_indices.clear();
+ }
+ };
+
+ // Step 1.2. Construct all block var bindings
+ int n_vars = block->iter_vars.size();
+ for (int i = 0; i < n_vars; ++i) {
+ if (!add_remapped_iter_var(i)) {
+ print_remapped_iter_var();
+ print_single_iter_var(i);
+ }
}
+ print_remapped_iter_var();
+
// Step 2. Handle block predicate
if (realize) {
ICHECK(realize->predicate.defined() &&
realize->predicate->dtype.is_bool());
diff --git a/tests/python/unittest/test_tvmscript_printer_syntax_sugar.py
b/tests/python/unittest/test_tvmscript_printer_syntax_sugar.py
new file mode 100644
index 0000000000..1bccb8188c
--- /dev/null
+++ b/tests/python/unittest/test_tvmscript_printer_syntax_sugar.py
@@ -0,0 +1,69 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest
+import tvm.testing
+from tvm.script.parser import tir as T
+from tvm.script import script
+
+
+def _test(obj, expected: str):
+ assert script(obj).strip() == expected.strip()
+
+
+def test_remap():
+ @T.prim_func
+ def block_with_remap_implicitly():
+ for i0, i1, i2, i3, i4, i5 in T.grid(128, 128, 128, 128, 128, 128):
+ with T.block("update"):
+ v0 = T.axis.spatial(128, i0 + 1)
+ v1 = T.axis.spatial(128, i1)
+ v2 = T.axis.reduce(128, i2)
+ v3 = T.axis.spatial(128, i3 - 1)
+ v4 = T.axis.reduce(128, i4)
+ v5 = T.axis.spatial(128, i5)
+ pass
+
+ @T.prim_func
+ def block_with_remap_explicitly():
+ for i0, i1, i2, i3, i4, i5 in T.grid(128, 128, 128, 128, 128, 128):
+ with T.block("update"):
+ v0 = T.axis.spatial(128, i0 + 1)
+ v1, v2 = T.axis.remap("SR", [i1, i2])
+ v3 = T.axis.spatial(128, i3 - 1)
+ v4, v5 = T.axis.remap("RS", [i4, i5])
+ pass
+
+ expected_output = """@T.prim_func
+def main():
+ with T.block("root"):
+ T.reads()
+ T.writes()
+ for i0, i1, i2, i3, i4, i5 in T.grid(128, 128, 128, 128, 128, 128):
+ with T.block("update"):
+ v0 = T.axis.spatial(128, i0 + 1)
+ v1, v2 = T.axis.remap("SR", [i1, i2])
+ v3 = T.axis.spatial(128, i3 - 1)
+ v4, v5 = T.axis.remap("RS", [i4, i5])
+ T.reads()
+ T.writes()
+ T.evaluate(0)"""
+ _test(block_with_remap_implicitly, expected_output)
+ _test(block_with_remap_explicitly, expected_output)
+
+
+if __name__ == "__main__":
+ tvm.testing.main()