This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new 7312934 [RELAY][VM] Add shape_of instruction (#5855)
7312934 is described below
commit 731293462cd54b83566dfabf35daa36a75d56839
Author: Zhi <[email protected]>
AuthorDate: Sun Jun 28 10:05:50 2020 -0700
[RELAY][VM] Add shape_of instruction (#5855)
---
include/tvm/runtime/vm.h | 12 ++++
python/tvm/relay/op/__init__.py | 1 +
python/tvm/relay/op/vm/__init__.py | 20 ++++++
python/tvm/relay/op/vm/_ffi_api.py | 20 ++++++
python/tvm/relay/op/vm/vm.py | 35 +++++++++
python/tvm/relay/transform/memory_alloc.py | 4 +-
src/relay/backend/vm/compiler.cc | 13 ++++
src/relay/op/tensor/unary.cc | 14 ----
src/relay/op/type_relations.cc | 15 ++++
src/relay/op/type_relations.h | 12 ++++
src/relay/op/vm/vm.cc | 58 +++++++++++++++
src/relay/transforms/fold_constant.cc | 4 +-
src/runtime/vm/executable.cc | 10 +++
src/runtime/vm/vm.cc | 31 ++++++++
tests/python/relay/test_vm_serialization.py | 107 ++++++++++------------------
15 files changed, 268 insertions(+), 88 deletions(-)
diff --git a/include/tvm/runtime/vm.h b/include/tvm/runtime/vm.h
index b9ccbf9..0cce533 100644
--- a/include/tvm/runtime/vm.h
+++ b/include/tvm/runtime/vm.h
@@ -114,6 +114,7 @@ enum class Opcode {
LoadConsti = 14U,
Fatal = 15U,
AllocStorage = 16U,
+ ShapeOf = 17U,
};
/*! \brief A single virtual machine instruction.
@@ -245,6 +246,9 @@ struct Instruction {
/*! \brief The hint of the dtype. */
DLDataType dtype_hint;
} alloc_storage;
+ struct /* ShapeOf Operands */ {
+ RegName tensor;
+ } shape_of;
};
/*!
@@ -389,6 +393,14 @@ struct Instruction {
static Instruction AllocStorage(RegName size, Index alignment, DLDataType
dtype_hint,
RegName dst);
+ /*!
+ * \brief Get the shape of an input tensor.
+ * \param tensor The input tensor.
+ * \param dst The destination to store the shape of the given tensor.
+ * \return The shape of instruction.
+ */
+ static Instruction ShapeOf(RegName tensor, RegName dst);
+
Instruction();
Instruction(const Instruction& instr);
Instruction& operator=(const Instruction& instr);
diff --git a/python/tvm/relay/op/__init__.py b/python/tvm/relay/op/__init__.py
index ce0df95..a45d466 100644
--- a/python/tvm/relay/op/__init__.py
+++ b/python/tvm/relay/op/__init__.py
@@ -27,6 +27,7 @@ from .reduce import *
from .tensor import *
from .transform import *
from .algorithm import *
+from .vm import *
from . import nn
from . import annotation
from . import memory
diff --git a/python/tvm/relay/op/vm/__init__.py b/python/tvm/relay/op/vm/__init__.py
new file mode 100644
index 0000000..2ac1e57
--- /dev/null
+++ b/python/tvm/relay/op/vm/__init__.py
@@ -0,0 +1,20 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=wildcard-import
+"""Dialect operators for Relay VM."""
+from __future__ import absolute_import as _abs
+from . import vm
diff --git a/python/tvm/relay/op/vm/_ffi_api.py b/python/tvm/relay/op/vm/_ffi_api.py
new file mode 100644
index 0000000..3eeeeb8
--- /dev/null
+++ b/python/tvm/relay/op/vm/_ffi_api.py
@@ -0,0 +1,20 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""FFI APIs for relay.op.vm"""
+import tvm._ffi
+
+tvm._ffi._init_api("relay.op.vm", __name__)
diff --git a/python/tvm/relay/op/vm/vm.py b/python/tvm/relay/op/vm/vm.py
new file mode 100644
index 0000000..680729d
--- /dev/null
+++ b/python/tvm/relay/op/vm/vm.py
@@ -0,0 +1,35 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint:
disable=no-else-return,invalid-name,len-as-condition,too-many-nested-blocks
+"""Dialect operators for Relay VM."""
+from . import _ffi_api
+
+
+def shape_of(expr):
+ """Invoke a function to get the shape of a tensor.
+
+ Parameters
+ ----------
+ expr : tvm.relay.Expr
+ The expr used to evaluate its tensor shape.
+
+ Returns
+ -------
+ result : tvm.relay.Expr
+ The expression with the evaluated tensor shape.
+ """
+ return _ffi_api.shape_of(expr)
diff --git a/python/tvm/relay/transform/memory_alloc.py b/python/tvm/relay/transform/memory_alloc.py
index 6c081cb..a7ba2a8 100644
--- a/python/tvm/relay/transform/memory_alloc.py
+++ b/python/tvm/relay/transform/memory_alloc.py
@@ -44,6 +44,7 @@ class ManifestAllocPass(ExprMutator):
def __init__(self, target_host):
self.invoke_tvm = op.memory.invoke_tvm_op
self.shape_func = op.memory.shape_func
+ self.shape_of = op.vm.shape_of
self.scopes = [ScopeBuilder()]
self.target_host = target_host
self.default_context = cpu(0)
@@ -53,9 +54,6 @@ class ManifestAllocPass(ExprMutator):
def current_scope(self):
return self.scopes[-1]
- def shape_of(self, e):
- return op.shape_of(e, self.compute_dtype)
-
def visit_tuple(self, tup):
scope = self.current_scope()
new_fields = []
diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc
index 0b839a2..2151acf 100644
--- a/src/relay/backend/vm/compiler.cc
+++ b/src/relay/backend/vm/compiler.cc
@@ -283,6 +283,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
case Opcode::Invoke:
case Opcode::AllocClosure:
case Opcode::AllocStorage:
+ case Opcode::ShapeOf:
case Opcode::Move:
case Opcode::InvokeClosure:
last_register_ = instr.dst;
@@ -588,6 +589,18 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
auto outputs = Downcast<Tuple>(args[2]);
EmitShapeFunc(shape_func, inputs->fields, outputs->fields);
})
+ .Match("vm.shape_of",
+ [this](const Array<Expr>& args, const Attrs& attrs, const
Array<Type>& type_arg) {
+ CHECK_EQ(args.size(), 1U);
+ // Get the attributes.
+ const auto* shape_of_attrs = attrs.as<ShapeOfAttrs>();
+ CHECK(shape_of_attrs) << "Must be the shape_of attrs";
+ CHECK_EQ(shape_of_attrs->dtype.bits(), 64)
+ << "The dtype of shape of must be int64, but got"
+ << DLDataType2String(shape_of_attrs->dtype);
+ this->VisitExpr(args[0]);
+ Emit(Instruction::ShapeOf(last_register_, NewRegister()));
+ })
.Match("memory.kill",
[](const Array<Expr>& args, const Attrs& attrs, const
Array<Type>& type_arg) {
LOG(FATAL) << "memory.kill is not yet supported";
diff --git a/src/relay/op/tensor/unary.cc b/src/relay/op/tensor/unary.cc
index 6b72670..99e6c02 100644
--- a/src/relay/op/tensor/unary.cc
+++ b/src/relay/op/tensor/unary.cc
@@ -396,20 +396,6 @@ RELAY_REGISTER_UNARY_OP("bitwise_not")
// shape_of
TVM_REGISTER_NODE_TYPE(ShapeOfAttrs);
-bool ShapeOfRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
- const TypeReporter& reporter) {
- CHECK_EQ(num_inputs, 1);
- auto tt = types[0].as<TensorTypeNode>();
- if (tt == nullptr) {
- return false;
- }
- const auto* param = attrs.as<ShapeOfAttrs>();
- CHECK(param != nullptr);
- auto rank_shape = RankShape(tt->shape);
- reporter->Assign(types[1], TensorType(rank_shape, param->dtype));
- return true;
-}
-
Array<te::Tensor> ShapeOfCompute(const Attrs& attrs, const Array<te::Tensor>&
inputs,
const Type& out_type) {
CHECK_EQ(inputs.size(), 1);
diff --git a/src/relay/op/type_relations.cc b/src/relay/op/type_relations.cc
index 46143d1..0647ec9 100644
--- a/src/relay/op/type_relations.cc
+++ b/src/relay/op/type_relations.cc
@@ -25,6 +25,7 @@
#include "./type_relations.h"
#include <tvm/arith/analyzer.h>
+#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/op.h>
#include <tvm/tir/op.h>
@@ -146,5 +147,19 @@ Array<IndexExpr> RankShape(const Array<IndexExpr>& shape) {
}
}
+bool ShapeOfRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+ const TypeReporter& reporter) {
+ CHECK_EQ(num_inputs, 1);
+ auto tt = types[0].as<TensorTypeNode>();
+ if (tt == nullptr) {
+ return false;
+ }
+ const auto* param = attrs.as<ShapeOfAttrs>();
+ CHECK(param != nullptr);
+ auto rank_shape = RankShape(tt->shape);
+ reporter->Assign(types[1], TensorType(rank_shape, param->dtype));
+ return true;
+}
+
} // namespace relay
} // namespace tvm
diff --git a/src/relay/op/type_relations.h b/src/relay/op/type_relations.h
index acd4b2d..5ab8b12 100644
--- a/src/relay/op/type_relations.h
+++ b/src/relay/op/type_relations.h
@@ -79,6 +79,18 @@ bool IdentityCompRel(const Array<Type>& types, int num_inputs, const Attrs& attr
Array<IndexExpr> RankShape(const Array<IndexExpr>& shape);
+/*!
+ * \brief The shape of type relation.
+ *
+ * \param types The input and output types to the relation.
+ * \param num_inputs The number of input arguments.
+ * \param attrs The attributes
+ * \param reporter The reporter.
+ * \return true whether relation has been resolved.
+ */
+bool ShapeOfRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+ const TypeReporter& reporter);
+
} // namespace relay
} // namespace tvm
diff --git a/src/relay/op/vm/vm.cc b/src/relay/op/vm/vm.cc
new file mode 100644
index 0000000..af33100
--- /dev/null
+++ b/src/relay/op/vm/vm.cc
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/op/vm/vm.cc
+ * \brief Dialect operators for Relay VM.
+ */
+
+#include <topi/elemwise.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/op.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/runtime/data_type.h>
+
+#include "../../transforms/infer_layout_util.h"
+#include "../op_common.h"
+#include "../type_relations.h"
+
+namespace tvm {
+namespace relay {
+
+RELAY_REGISTER_OP("vm.shape_of")
+ .describe(R"code(Get the shape of an input tensor.
+)code" TVM_ADD_FILELINE)
+ .set_num_inputs(1)
+ .add_argument("tensor", "Tensor", "The input tensor")
+ .add_type_rel("ShapeOf", ShapeOfRel)
+ .set_support_level(10)
+ .set_attr<TOpPattern>("TOpPattern", kOpaque)
+ .set_attr<TOpIsStateful>("TOpIsStateful", false)
+ .set_attr<TNonComputational>("TNonComputational", true)
+ .set_attr<FInferCorrectLayout>("FInferCorrectLayout",
ElemwiseArbitraryLayout);
+
+TVM_REGISTER_GLOBAL("relay.op.vm.shape_of").set_body_typed([](Expr expr) {
+ auto attrs = make_object<ShapeOfAttrs>();
+ attrs->dtype = DataType::Int(64);
+ static const Op& op = Op::Get("vm.shape_of");
+ return Call(op, {expr}, Attrs(attrs), {});
+});
+
+} // namespace relay
+} // namespace tvm
diff --git a/src/relay/transforms/fold_constant.cc b/src/relay/transforms/fold_constant.cc
index b2eab8f..50de871 100644
--- a/src/relay/transforms/fold_constant.cc
+++ b/src/relay/transforms/fold_constant.cc
@@ -81,6 +81,7 @@ class ConstantFolder : public ExprMutator {
: executor_(executor),
module_(module),
shape_of_op_(Op::Get("shape_of")),
+ vm_shape_of_op_(Op::Get("vm.shape_of")),
invoke_tvm_op_(Op::Get("memory.invoke_tvm_op")),
shape_func_op_(Op::Get("memory.shape_func")),
alloc_tensor_op_(Op::Get("memory.alloc_tensor")),
@@ -123,7 +124,7 @@ class ConstantFolder : public ExprMutator {
// skip stateful ops.
if (op_stateful.get(GetRef<Op>(op), false)) return res;
// Try to evaluate shape_of op
- if (call->op == shape_of_op_) {
+ if (call->op == shape_of_op_ || call->op == vm_shape_of_op_) {
return EvaluateShapeOf(res, origin_args, call->attrs);
}
@@ -166,6 +167,7 @@ class ConstantFolder : public ExprMutator {
// Cache the following ops for equivalence checking in this pass.
const Op& shape_of_op_;
+ const Op& vm_shape_of_op_;
const Op& invoke_tvm_op_;
const Op& shape_func_op_;
const Op& alloc_tensor_op_;
diff --git a/src/runtime/vm/executable.cc b/src/runtime/vm/executable.cc
index 65b1a2f..f520404 100644
--- a/src/runtime/vm/executable.cc
+++ b/src/runtime/vm/executable.cc
@@ -417,6 +417,11 @@ VMInstructionSerializer SerializeInstruction(const Instruction& instr) {
fields.push_back(instr.pc_offset);
break;
}
+ case Opcode::ShapeOf: {
+ // Number of fields = 2
+ fields.assign({instr.shape_of.tensor, instr.dst});
+ break;
+ }
default:
LOG(FATAL) << "Invalid opcode" << static_cast<int>(instr.op);
break;
@@ -683,6 +688,11 @@ Instruction DeserializeInstruction(const VMInstructionSerializer& instr) {
DCHECK_EQ(instr.fields.size(), 1U);
return Instruction::Goto(instr.fields[0]);
}
+ case Opcode::ShapeOf: {
+ // Number of fields = 2
+ DCHECK_EQ(instr.fields.size(), 2U);
+ return Instruction::ShapeOf(instr.fields[0], instr.fields[1]);
+ }
default:
LOG(FATAL) << "Invalid opcode" << instr.opcode;
return Instruction();
diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc
index 0c0ca35..6b10a89 100644
--- a/src/runtime/vm/vm.cc
+++ b/src/runtime/vm/vm.cc
@@ -145,6 +145,9 @@ Instruction::Instruction(const Instruction& instr) {
case Opcode::AllocStorage:
this->alloc_storage = instr.alloc_storage;
return;
+ case Opcode::ShapeOf:
+ this->shape_of.tensor = instr.shape_of.tensor;
+ return;
default:
std::ostringstream out;
out << "Invalid instruction " << static_cast<int>(instr.op);
@@ -239,6 +242,9 @@ Instruction& Instruction::operator=(const Instruction& instr) {
case Opcode::AllocStorage:
this->alloc_storage = instr.alloc_storage;
return *this;
+ case Opcode::ShapeOf:
+ this->shape_of.tensor = instr.shape_of.tensor;
+ return *this;
default:
std::ostringstream out;
out << "Invalid instruction " << static_cast<int>(instr.op);
@@ -258,6 +264,7 @@ Instruction::~Instruction() {
case Opcode::Goto:
case Opcode::LoadConsti:
case Opcode::AllocStorage:
+ case Opcode::ShapeOf:
case Opcode::Fatal:
return;
case Opcode::AllocTensor:
@@ -351,6 +358,14 @@ Instruction Instruction::AllocStorage(RegName size, Index alignment, DLDataType
return instr;
}
+Instruction Instruction::ShapeOf(RegName tensor, Index dst) {
+ Instruction instr;
+ instr.op = Opcode::ShapeOf;
+ instr.dst = dst;
+ instr.shape_of.tensor = tensor;
+ return instr;
+}
+
Instruction Instruction::AllocADT(Index tag, Index num_fields,
const std::vector<RegName>& datatype_fields,
Index dst) {
Instruction instr;
@@ -585,6 +600,10 @@ void InstructionPrint(std::ostream& os, const Instruction& instr) {
<< DLDataType2String(instr.alloc_storage.dtype_hint);
break;
}
+ case Opcode::ShapeOf: {
+ os << "shape_of $" << instr.dst << " $" << instr.shape_of.tensor;
+ break;
+ }
default:
LOG(FATAL) << "should never hit this case" << static_cast<int>(instr.op);
break;
@@ -1057,6 +1076,18 @@ void VirtualMachine::RunLoop() {
pc_++;
goto main_loop;
}
+ case Opcode::ShapeOf: {
+ auto input = ReadRegister(instr.shape_of.tensor);
+ NDArray input_array = Downcast<NDArray>(input);
+ int ndim = input_array->ndim;
+ auto out_tensor = NDArray::Empty({ndim}, {kDLInt, 64, 1}, {kDLCPU, 0});
+ for (int i = 0; i < ndim; ++i) {
+ reinterpret_cast<int64_t*>(out_tensor->data)[i] =
input_array->shape[i];
+ }
+ WriteRegister(instr.dst, out_tensor);
+ pc_++;
+ goto main_loop;
+ }
case Opcode::Ret: {
// If we have hit the point from which we started
// running, we should return to the caller breaking
diff --git a/tests/python/relay/test_vm_serialization.py b/tests/python/relay/test_vm_serialization.py
index 5d20651..95e6c6f 100644
--- a/tests/python/relay/test_vm_serialization.py
+++ b/tests/python/relay/test_vm_serialization.py
@@ -19,7 +19,6 @@
import numpy as np
import tvm
-from tvm import te
from tvm.runtime import vm as _vm
from tvm.relay import vm as rly_vm
from tvm import relay
@@ -41,11 +40,15 @@ def create_exec(f, target="llvm", params=None):
return executable
-def veval(vm, *args, ctx=tvm.cpu()):
- assert isinstance(vm, _vm.VirtualMachine), "expected VirtualMachine"
- ret = vm.run(*args)
- return ret
-
+def get_serialized_output(mod, *data, params=None, target="llvm",
+ ctx=tvm.cpu()):
+ exe = create_exec(mod, target, params=params)
+ code, lib = exe.save()
+ des_exec = _vm.Executable.load_exec(code, lib)
+ des_vm = _vm.VirtualMachine(des_exec)
+ des_vm.init(ctx)
+ result = des_vm.run(*data)
+ return result
def run_network(mod,
params,
@@ -56,24 +59,16 @@ def run_network(mod,
result = ex.evaluate()(data, **params)
return result.asnumpy().astype(dtype)
- def get_serialized_output(mod, data, params, target, ctx, dtype='float32'):
- exe = create_exec(mod, target, params=params)
- code, lib = exe.save()
- des_exec = _vm.Executable.load_exec(code, lib)
- des_vm = _vm.VirtualMachine(des_exec)
- des_vm.init(ctx)
- result = des_vm.run(data)
- return result.asnumpy().astype(dtype)
-
data = np.random.uniform(size=data_shape).astype(dtype)
target = "llvm"
ctx = tvm.cpu(0)
tvm_out = get_vm_output(mod, tvm.nd.array(data.astype(dtype)), params,
target, ctx, dtype)
- vm_out = get_serialized_output(mod, tvm.nd.array(data.astype(dtype)),
params,
- target, ctx, dtype)
- tvm.testing.assert_allclose(vm_out, tvm_out, rtol=1e-5, atol=1e-5)
+ vm_out = get_serialized_output(mod, tvm.nd.array(data.astype(dtype)),
+ params=params, target=target, ctx=ctx)
+ tvm.testing.assert_allclose(vm_out.asnumpy().astype(dtype), tvm_out,
+ rtol=1e-5, atol=1e-5)
def test_serializer():
@@ -143,7 +138,7 @@ def test_save_load():
des_vm = _vm.VirtualMachine(des_exec)
des_vm.init(tvm.cpu())
- res = veval(des_vm, x_data)
+ res = des_vm.run(x_data)
tvm.testing.assert_allclose(res.asnumpy(), x_data + x_data)
@@ -151,14 +146,8 @@ def test_const():
c = relay.const(1.0, "float32")
x = relay.var('x', shape=(10, 10), dtype='float32')
f = relay.Function([x], x + c)
- exe = create_exec(f)
- code, lib = exe.save()
- assert isinstance(code, bytearray)
- des_exec = _vm.Executable.load_exec(code, lib)
- des_vm = _vm.VirtualMachine(des_exec)
- des_vm.init(tvm.cpu())
x_data = np.random.rand(10, 10).astype('float32')
- res = veval(des_vm, x_data)
+ res = get_serialized_output(f, x_data)
tvm.testing.assert_allclose(res.asnumpy(), x_data + 1)
@@ -172,18 +161,12 @@ def test_if():
x_data = np.random.rand(10, 10).astype('float32')
y_data = np.random.rand(10, 10).astype('float32')
- exe = create_exec(f)
- code, lib = exe.save()
- des_exec = _vm.Executable.load_exec(code, lib)
- des_vm = _vm.VirtualMachine(des_exec)
- des_vm.init(tvm.cpu())
-
# same
- res = veval(des_vm, x_data, x_data)
+ res = get_serialized_output(f, x_data, x_data)
tvm.testing.assert_allclose(res.asnumpy(), x_data)
# diff
- res = veval(des_vm, x_data, y_data)
+ res = get_serialized_output(f, x_data, y_data)
tvm.testing.assert_allclose(res.asnumpy(), y_data)
@@ -208,13 +191,7 @@ def test_loop():
aarg = relay.var('accum', shape=[], dtype='int32')
mod["main"] = relay.Function([iarg, aarg], sum_up(iarg, aarg))
- exe = create_exec(mod)
- code, lib = exe.save()
- des_exec = _vm.Executable.load_exec(code, lib)
- des_vm = _vm.VirtualMachine(des_exec)
- des_vm.init(tvm.cpu())
-
- result = veval(des_vm, i_data, accum_data)
+ result = get_serialized_output(mod, i_data, accum_data)
tvm.testing.assert_allclose(result.asnumpy(), sum(range(1, loop_bound +
1)))
@@ -225,13 +202,7 @@ def test_tuple():
i_data = np.random.rand(41).astype('float32')
j_data = np.random.rand(10).astype('float32')
- exe = create_exec(f)
- code, lib = exe.save()
- des_exec = _vm.Executable.load_exec(code, lib)
- des_vm = _vm.VirtualMachine(des_exec)
- des_vm.init(tvm.cpu())
-
- result = veval(des_vm, (i_data, j_data))
+ result = get_serialized_output(f, (i_data, j_data))
tvm.testing.assert_allclose(result.asnumpy(), j_data)
@@ -246,13 +217,7 @@ def test_adt_list():
f = relay.Function([], l321)
mod["main"] = f
- exe = create_exec(mod)
- code, lib = exe.save()
- des_exec = _vm.Executable.load_exec(code, lib)
- des_vm = _vm.VirtualMachine(des_exec)
- des_vm.init(tvm.cpu())
-
- result = veval(des_vm)
+ result = get_serialized_output(mod)
assert len(result) == 2
assert len(result[1]) == 2
assert len(result[1][1]) == 2
@@ -292,15 +257,8 @@ def test_adt_compose():
f = relay.Function([y], add_two_body)
mod["main"] = f
- exe = create_exec(mod)
- code, lib = exe.save()
- des_exec = _vm.Executable.load_exec(code, lib)
- des_vm = _vm.VirtualMachine(des_exec)
- des_vm.init(tvm.cpu())
-
x_data = np.array(np.random.rand()).astype('float32')
- result = veval(des_vm, x_data)
-
+ result = get_serialized_output(mod, x_data)
tvm.testing.assert_allclose(result.asnumpy(), x_data + 2.0)
@@ -312,13 +270,7 @@ def test_closure():
clo = ff(relay.const(1.0))
main = clo(relay.const(2.0))
- exe = create_exec(main)
- code, lib = exe.save()
- des_exec = _vm.Executable.load_exec(code, lib)
- des_vm = _vm.VirtualMachine(des_exec)
- des_vm.init(tvm.cpu())
-
- res = veval(des_vm)
+ res = get_serialized_output(main)
tvm.testing.assert_allclose(res.asnumpy(), 3.0)
@@ -332,6 +284,20 @@ def test_mobilenet():
run_network(mod, params)
+def test_vm_shape_of():
+ x = relay.var('x', shape=(relay.Any(), relay.Any(), relay.Any()),
dtype="float32")
+ relu_x = relay.nn.relu(x)
+ data = np.random.uniform(size=(2, 3, 4)).astype('float32')
+ args = [data]
+
+ newshape_var = relay.var('newshape', shape=(2,), dtype='int64')
+ args.append(np.array((1, -1), dtype='int64'))
+ main = relay.reshape(relu_x, newshape=newshape_var)
+
+ res = get_serialized_output(main, *args).asnumpy()
+ tvm.testing.assert_allclose(res.flatten(), data.flatten())
+
+
if __name__ == "__main__":
test_serializer()
test_save_load()
@@ -344,3 +310,4 @@ if __name__ == "__main__":
test_closure()
test_resnet()
test_mobilenet()
+ test_vm_shape_of()