This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new 73a9e99 [Relay] Expose vm OptimizeModule to Python (#4800)
73a9e99 is described below
commit 73a9e997b097ff09d2896e74355ccf6d16ccd254
Author: masahi <[email protected]>
AuthorDate: Sun Feb 2 11:04:44 2020 +0900
[Relay] Expose vm OptimizeModule to Python (#4800)
* Expose VM OptimizeModule to python
* added missing imports
* fix import
---
python/tvm/relay/backend/vm.py | 38 ++++++++++++++++++++++++++++++++++++++
python/tvm/relay/scope_builder.py | 1 +
src/relay/backend/vm/compiler.cc | 13 +++++++++++++
tests/python/relay/test_vm.py | 5 +++++
4 files changed, 57 insertions(+)
diff --git a/python/tvm/relay/backend/vm.py b/python/tvm/relay/backend/vm.py
index 3100900..f1cdefc 100644
--- a/python/tvm/relay/backend/vm.py
+++ b/python/tvm/relay/backend/vm.py
@@ -23,6 +23,7 @@ Implements a Python interface to compiling and executing on
the Relay VM.
import numpy as np
import tvm
+import tvm.ndarray as _nd
from tvm import autotvm, container
from tvm.object import Object
from tvm.relay import expr as _expr
@@ -409,6 +410,8 @@ class VMCompiler(object):
self._codegen = self.mod["codegen"]
self._get_exec = self.mod["get_executable"]
self._set_params_func = self.mod["set_params"]
+ self._get_params_func = self.mod["get_params"]
+ self._optimize = self.mod["optimize"]
def set_params(self, params):
"""Set constant parameters for the model.
@@ -426,6 +429,14 @@ class VMCompiler(object):
inputs[name] = _expr.const(param)
self._set_params_func(inputs)
+    def get_params(self):
+        """Return the constant parameters currently bound to the compiler.
+
+        Returns
+        -------
+        params : dict of str to NDArray
+            Maps each parameter name to its tensor value, i.e. the ``.data``
+            of each Constant node handed back by the backend ``get_params``
+            function.
+        """
+        params = self._get_params_func()
+        return {key: value.data for key, value in params.items()}
+
def lower(self, mod, target=None, target_host=None):
"""Lower the module to VM bytecode.
@@ -458,6 +469,33 @@ class VMCompiler(object):
"""Generate the kernel library."""
self._codegen()
+    def optimize(self, mod, target=None, params=None):
+        """Helper method that optimizes a Relay module via VM.
+
+        Parameters
+        ----------
+        mod : relay.Module
+            The module to optimize.
+
+        target : str, :any:`tvm.target.Target`, or dict of str (i.e.
+            device/context name) to str/tvm.target.Target, optional
+            The compilation target(s). The value is normalized by
+            ``self._update_target`` before being handed to the backend
+            optimizer; how ``None`` is resolved is decided there.
+
+        params : dict of str to NDArray, optional
+            Input parameters to the graph that do not change during
+            inference time. When non-empty they are registered through
+            :py:meth:`set_params` before optimization runs.
+
+        Returns
+        -------
+        mod : relay.Module
+            The optimized relay module.
+
+        params : dict of str to NDArray
+            The parameters held by the compiler after optimization, as
+            reported by :py:meth:`get_params`. If ``params`` was None or
+            empty, this reflects whatever an earlier ``set_params`` call
+            registered on this compiler (possibly nothing).
+        """
+        target = self._update_target(target)
+        if params:
+            self.set_params(params)
+        # Run the optimizer before reading the parameters back out.
+        optimized_mod = self._optimize(mod, target)
+        return optimized_mod, self.get_params()
+
def get_exec(self):
"""Get the VM executable.
diff --git a/python/tvm/relay/scope_builder.py b/python/tvm/relay/scope_builder.py
index 16044c1..43c6532 100644
--- a/python/tvm/relay/scope_builder.py
+++ b/python/tvm/relay/scope_builder.py
@@ -18,6 +18,7 @@
"""The scope builder interface."""
from __future__ import absolute_import
+from . import ty as _ty
from . import expr as _expr
from .._ffi import base as _base
diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc
index cc5d6bc..8d4f4ad 100644
--- a/src/relay/backend/vm/compiler.cc
+++ b/src/relay/backend/vm/compiler.cc
@@ -772,6 +772,19 @@ PackedFunc VMCompiler::GetFunction(const std::string& name,
this->SetParam(kv.first, kv.second->data);
}
});
+ } else if (name == "get_params") {
+ return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+ Map<std::string, Constant> ret;
+ for (const auto& kv : params_) {
+ ret.Set(kv.first, ConstantNode::make(kv.second));
+ }
+ *rv = ret;
+ });
+ } else if (name == "optimize") {
+ return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+ CHECK_EQ(args.num_args, 2);
+ *rv = this->OptimizeModule(args[0], args[1]);
+ });
} else {
LOG(FATAL) << "Unknown packed function: " << name;
return PackedFunc([sptr_to_self, name](TVMArgs args, TVMRetValue* rv) {});
diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py
index d4a7a1a..9ea939c 100644
--- a/tests/python/relay/test_vm.py
+++ b/tests/python/relay/test_vm.py
@@ -22,6 +22,7 @@ from tvm import relay
from tvm.relay.scope_builder import ScopeBuilder
from tvm.relay.testing.config import ctx_list
from tvm.relay.prelude import Prelude
+from tvm.relay import testing
import pytest
def check_result(args, expected_result, mod=None):
@@ -570,6 +571,10 @@ def test_add_op_broadcast():
mod["main"] = func
check_result([x_data, y_data], x_data + y_data, mod=mod)
+def test_vm_optimize():
+    mod, params = testing.resnet.get_workload(batch_size=1, num_layers=18)
+    comp = relay.backend.vm.VMCompiler()
+    opt_mod, opt_params = comp.optimize(mod, "llvm", params)
+    # The optimized module must still provide the entry function.
+    assert opt_mod["main"] is not None
+    # optimize() returns the parameters registered via set_params, so the
+    # workload's parameter names are expected to come back unchanged.
+    assert set(opt_params.keys()) == set(params.keys())
if __name__ == "__main__":
pytest.main([__file__])