This is an automated email from the ASF dual-hosted git repository.
haichen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new 2586b4d [Relay][VM] Fix compilation of If-Elses (#5040)
2586b4d is described below
commit 2586b4d32501bec8ff9b7cea145376615ecb00c9
Author: Wei Chen <[email protected]>
AuthorDate: Thu Mar 12 03:26:28 2020 +0800
[Relay][VM] Fix compilation of If-Elses (#5040)
---
src/relay/backend/vm/compiler.cc | 8 +++++---
tests/python/relay/test_vm.py | 19 +++++++++++++++++++
2 files changed, 24 insertions(+), 3 deletions(-)
diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc
index fc52a8e..e3c8d12 100644
--- a/src/relay/backend/vm/compiler.cc
+++ b/src/relay/backend/vm/compiler.cc
@@ -366,7 +366,9 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
this->Emit(Instruction::If(test_register, target_register, 0, 0));
this->VisitExpr(if_node->true_branch);
- size_t true_register = last_register_;
+ // It saves the result of If-Else expression.
+ auto merge_register = NewRegister();
+ Emit(Instruction::Move(last_register_, merge_register));
Emit(Instruction::Goto(0));
// Finally store how many instructions there are in the
@@ -378,7 +380,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
size_t false_register = last_register_;
// In else-branch, override the then-branch register
- Emit(Instruction::Move(false_register, true_register));
+ Emit(Instruction::Move(false_register, merge_register));
// Compute the total number of instructions
// after generating false.
auto after_false = this->instructions_.size();
@@ -397,7 +399,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
// Patch the Goto.
this->instructions_[after_true - 1].pc_offset = (after_false - after_true)
+ 1;
- this->last_register_ = true_register;
+ this->last_register_ = merge_register;
}
void EmitShapeFunc(Function func, Array<Expr> inputs, Array<Expr> outputs) {
diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py
index 02f1e5b..a8ac27a 100644
--- a/tests/python/relay/test_vm.py
+++ b/tests/python/relay/test_vm.py
@@ -142,6 +142,25 @@ def test_simple_if():
# diff
check_result([x_data, y_data], y_data, mod=mod)
+def test_multiple_ifs():  # regression test (#5040): an If whose branches are let-bound vars must not overwrite those vars' registers
+ mod = tvm.IRModule({})
+ b = relay.var('b')
+ v0 = relay.var('v0')
+ v1 = relay.var('v1')
+ v2 = relay.var('v2')
+ v3 = relay.var('v3')
+ out = relay.Tuple([v2, v3])  # the let-chain is built inside-out: body first, outer bindings wrapped around it below
+ out = relay.Let(v3, relay.If(b, v1, v0), out)  # v3 = b ? v1 : v0
+ out = relay.Let(v2, relay.If(b, v0, v1), out)  # v2 = b ? v0 : v1 (evaluated before v3, so its else-branch must not clobber v1 or v0)
+ out = relay.Let(v1, relay.Tuple([relay.const(1)]), out)  # v1 = (1,)
+ out = relay.Let(v0, relay.Tuple([relay.const(0)]), out)  # v0 = (0,)
+ fn = relay.Function([b], out)
+ mod['main'] = fn
+ ctx = tvm.runtime.ndarray.context('llvm', 0)  # LLVM target, device 0
+ vm = relay.create_executor(ctx=ctx, mod=mod, kind='vm')
+ res = vmobj_to_list(vm.evaluate()(False))  # b = False selects the else-branch of both Ifs
+ assert(res == [1, 0])  # v2 -> v1 = (1,), v3 -> v0 = (0,); vmobj_to_list is assumed to flatten the nested tuples — confirm against its definition
+
def test_simple_call():
mod = tvm.IRModule({})
sum_up = relay.GlobalVar('sum_up')