This is an automated email from the ASF dual-hosted git repository.

haichen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 529ee1f  [Relay] Fix VM compiler for while loop with free vars  (#4889)
529ee1f is described below

commit 529ee1feb6c96967d8ab28e08b72006c6d7e8887
Author: masahi <[email protected]>
AuthorDate: Sun Feb 16 15:43:22 2020 +0900

    [Relay] Fix VM compiler for while loop with free vars  (#4889)
    
    * add additional switch to handle nested call node
    
    * Fix VM compiler for while loop with free var
---
 src/relay/backend/vm/compiler.cc |  3 +++
 tests/python/relay/test_vm.py    | 27 +++++++++++++++++++++++++++
 2 files changed, 30 insertions(+)

diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc
index 8d4f4ad..73a6450 100644
--- a/src/relay/backend/vm/compiler.cc
+++ b/src/relay/backend/vm/compiler.cc
@@ -637,6 +637,9 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
       // emit invoke closure here.
       VisitExpr(GetRef<Var>(var_node));
       Emit(Instruction::InvokeClosure(last_register_, args_registers, NewRegister()));
+    } else if (auto inner_call_node = op.as<CallNode>()) {
+      VisitExpr(GetRef<Call>(inner_call_node));
+      Emit(Instruction::InvokeClosure(last_register_, args_registers, NewRegister()));
     } else {
       // Finally if there are any other cases this is a bug.
       LOG(FATAL) << "internal error: unreachable code,"
diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py
index c4cd616..8cac656 100644
--- a/tests/python/relay/test_vm.py
+++ b/tests/python/relay/test_vm.py
@@ -23,6 +23,7 @@ from tvm import relay
 from tvm.relay.scope_builder import ScopeBuilder
 from tvm.relay.testing.config import ctx_list
 from tvm.relay.prelude import Prelude
+from tvm.relay.loops import while_loop
 from tvm.relay import testing
 
 def check_result(args, expected_result, mod=None):
@@ -576,5 +577,31 @@ def test_vm_optimize():
     comp = relay.vm.VMCompiler()
     opt_mod, _ = comp.optimize(mod, "llvm", params)
 
+def test_loop_free_var():
+    x = relay.var('x', shape=(), dtype='int32')
+    i = relay.var('i', shape=(), dtype='int32')
+    s = relay.var('s', shape=(), dtype='int32')
+
+    def cond(i, _):
+        return i < relay.const(10, dtype='int32')
+
+    def body_no_free_var(i, acc):
+        incr = relay.const(1, "int32")
+        return i + incr, acc + i
+
+    def body_with_free_var(i, acc):
+        incr = relay.const(1, "int32")
+        return i + incr, acc + x
+
+    for args, body, expected in zip([[], [1]],
+                                    [body_no_free_var, body_with_free_var],
+                                    [45, 10]):
+        loop = while_loop(cond, [i, s], body)
+        tup = loop(relay.const(0, dtype='int32'), relay.zeros(shape=(), dtype='int32'))
+        ret = relay.TupleGetItem(tup, 1)
+        mod = tvm.IRModule()
+        mod["main"] = relay.Function(relay.analysis.free_vars(ret), ret)
+        check_result(args, expected, mod=mod)
+
 if __name__ == "__main__":
     pytest.main([__file__])

Reply via email to