This is an automated email from the ASF dual-hosted git repository.
wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new e1ae821 Add while node support in TVMScript (#9004)
e1ae821 is described below
commit e1ae821c7d3991a246689fc12f629fe80ed8671a
Author: Siyuan Feng <[email protected]>
AuthorDate: Wed Sep 15 01:55:42 2021 +0800
Add while node support in TVMScript (#9004)
* support while
* update synr version
---
docker/install/ubuntu_install_python_package.sh | 2 +-
python/gen_requirements.py | 2 +-
python/tvm/script/parser.py | 13 +++++++++++++
python/tvm/tir/__init__.py | 2 +-
src/printer/tvmscript_printer.cc | 8 ++++++++
tests/python/unittest/test_tvmscript_roundtrip.py | 17 +++++++++++++++++
tests/scripts/task_ci_setup.sh | 2 +-
7 files changed, 42 insertions(+), 4 deletions(-)
diff --git a/docker/install/ubuntu_install_python_package.sh b/docker/install/ubuntu_install_python_package.sh
index 88d6840..eff86a9 100755
--- a/docker/install/ubuntu_install_python_package.sh
+++ b/docker/install/ubuntu_install_python_package.sh
@@ -36,6 +36,6 @@ pip3 install \
pytest-xdist \
requests \
scipy \
- synr==0.3.0 \
+ synr==0.4.0 \
six \
tornado
diff --git a/python/gen_requirements.py b/python/gen_requirements.py
index a9a8607..7470ccc 100755
--- a/python/gen_requirements.py
+++ b/python/gen_requirements.py
@@ -244,7 +244,7 @@ CONSTRAINTS = [
("sphinx_autodoc_annotation", None),
("sphinx_gallery", None),
("sphinx_rtd_theme", None),
- ("synr", "==0.3.0"),
+ ("synr", "==0.4.0"),
("tensorflow", None),
("tensorflow-estimator", None),
("tflite", None),
diff --git a/python/tvm/script/parser.py b/python/tvm/script/parser.py
index 60fc496..51ee0ae 100644
--- a/python/tvm/script/parser.py
+++ b/python/tvm/script/parser.py
@@ -594,6 +594,19 @@ class TVMScriptParser(Transformer):
self.current_lineno, self.current_col_offset = old_lineno, old_col_offset
return res
+ def transform_While(self, node):
+ """While visitor
+ AST abstract grammar:
+ While(expr condition, stmt* body)
+ """
+ condition = self.transform(node.condition)
+ # body
+ self.context.enter_scope(nodes=node.body.stmts)
+ body = self.parse_body(node)
+ self.context.exit_scope()
+
+ return tvm.tir.While(condition, body, span=tvm_span_from_synr(node.span))
+
def transform_With(self, node):
"""With visitor
AST abstract grammar:
diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py
index eb200df..4400623 100644
--- a/python/tvm/tir/__init__.py
+++ b/python/tvm/tir/__init__.py
@@ -27,7 +27,7 @@ from .expr import Min, Max, EQ, NE, LT, LE, GT, GE, And, Or, Not
from .expr import Select, BufferLoad, ProducerLoad, Load, Ramp, Broadcast, Shuffle
from .expr import Call, CallEffectKind, Let, IterVar, Any
-from .stmt import Stmt, LetStmt, AssertStmt, ForKind, For
+from .stmt import Stmt, LetStmt, AssertStmt, ForKind, For, While
from .stmt import BufferStore, BufferRealize, Store, ProducerStore, Allocate, AttrStmt
from .stmt import ProducerRealize, SeqStmt
from .stmt import IfThenElse, Evaluate, Prefetch, stmt_seq, stmt_list
diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc
index be31961..906dc25 100644
--- a/src/printer/tvmscript_printer.cc
+++ b/src/printer/tvmscript_printer.cc
@@ -170,6 +170,7 @@ class TVMScriptPrinter : public StmtFunctor<Doc(const Stmt&)>,
Doc VisitStmt_(const IfThenElseNode* op) override;
Doc VisitStmt_(const SeqStmtNode* op) override;
Doc VisitStmt_(const ForNode* op) override;
+ Doc VisitStmt_(const WhileNode* op) override;
Doc VisitStmt_(const PrefetchNode* op) override;
Doc VisitStmt_(const EvaluateNode* op) override;
Doc VisitStmt_(const BlockRealizeNode* op) override;
@@ -830,6 +831,13 @@ Doc TVMScriptPrinter::VisitStmt_(const PrefetchNode* op) {
return doc;
}
+Doc TVMScriptPrinter::VisitStmt_(const WhileNode* op) {
+ Doc doc;
+ doc << "while " << Print(op->condition) << ":";
+ doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body));
+ return doc;
+}
+
Doc TVMScriptPrinter::VisitType_(const PrimTypeNode* node) {
Doc doc;
doc << "ty." << runtime::DLDataType2String(node->dtype);
diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py
index e0f0c6d..7c123af 100644
--- a/tests/python/unittest/test_tvmscript_roundtrip.py
+++ b/tests/python/unittest/test_tvmscript_roundtrip.py
@@ -3066,5 +3066,22 @@ def test_same_name_var():
assert out_str.find("i_") == -1
+@tvm.script.tir
+def while_loop(a: ty.handle, b: ty.handle) -> None:
+ A = tir.match_buffer(a, (16,), "float32")
+ B = tir.match_buffer(b, (16,), "float32")
+ i = tir.alloc_buffer((), "int32", scope="local")
+ with tir.block([16]) as [vi]:
+ B[vi] = 0
+ while i[()] < 10:
+ for j in range(16):
+ B[j] += A[j]
+
+
+def test_while_loop():
+ rt_func = tvm.script.from_source(tvm.script.asscript(while_loop, True))
+ tvm.ir.assert_structural_equal(while_loop, rt_func)
+
+
if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))
diff --git a/tests/scripts/task_ci_setup.sh b/tests/scripts/task_ci_setup.sh
index 753d17d..01d5587 100755
--- a/tests/scripts/task_ci_setup.sh
+++ b/tests/scripts/task_ci_setup.sh
@@ -30,7 +30,7 @@ set -o pipefail
#
echo "Addtiional setup in" ${CI_IMAGE_NAME}
-python3 -m pip install --user tlcpack-sphinx-addon==0.2.1 synr==0.3.0
+python3 -m pip install --user tlcpack-sphinx-addon==0.2.1 synr==0.4.0
# Rebuild standalone_crt in build/ tree. This file is not currently archived by pack_lib() in
# Jenkinsfile. We expect config.cmake to be present from pack_lib().