This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 89bafd5 [RUNTIME] Unify load params interface (#7559)
89bafd5 is described below
commit 89bafd58c27e1dff670dddb3fcf7b4f84dc4eedc
Author: Tristan Konolige <[email protected]>
AuthorDate: Mon Mar 8 21:40:38 2021 -0800
[RUNTIME] Unify load params interface (#7559)
---
apps/android_camera/models/prepare_model.py | 2 +-
apps/bundle_deploy/build_model.py | 6 +-
apps/bundle_deploy/runtime.cc | 1 +
apps/sgx/src/build_model.py | 4 +-
.../wasm-graph/tools/build_graph_lib.py | 4 +-
docs/deploy/android.rst | 2 +-
golang/sample/gen_mobilenet_lib.py | 4 +-
python/tvm/contrib/debugger/debug_result.py | 6 +-
python/tvm/driver/tvmc/compiler.py | 4 +-
python/tvm/driver/tvmc/runner.py | 7 +--
python/tvm/relay/param_dict.py | 28 ++++-----
python/tvm/runtime/__init__.py | 1 +
.../tvm/{relay/param_dict.py => runtime/params.py} | 23 +++----
rust/tvm-graph-rt/src/graph.rs | 2 +-
rust/tvm-graph-rt/tests/build_model.py | 4 +-
.../tests/test_nn/src/build_test_graph.py | 4 +-
rust/tvm/examples/resnet/src/build_resnet.py | 4 +-
src/relay/backend/param_dict.cc | 70 ++++------------------
src/relay/backend/param_dict.h | 27 +--------
src/runtime/file_utils.cc | 67 +++++++++++++++++++++
src/runtime/file_utils.h | 26 ++++++++
src/runtime/graph/graph_runtime.cc | 31 ++--------
src/runtime/graph/graph_runtime.h | 3 -
src/runtime/vm/vm.cc | 2 +
tests/python/contrib/test_tensorrt.py | 4 +-
tests/python/relay/test_cpp_build_module.py | 4 +-
tests/python/relay/test_param_dict.py | 8 +--
tests/python/unittest/test_runtime_graph.py | 6 +-
.../test_runtime_module_based_interface.py | 10 ++--
tutorials/frontend/deploy_sparse.py | 4 +-
30 files changed, 176 insertions(+), 192 deletions(-)
diff --git a/apps/android_camera/models/prepare_model.py
b/apps/android_camera/models/prepare_model.py
index ab20e02..f155d46 100644
--- a/apps/android_camera/models/prepare_model.py
+++ b/apps/android_camera/models/prepare_model.py
@@ -106,7 +106,7 @@ def main(model_str, output_path):
f.write(graph)
print("dumping params...")
with open(output_path_str + "/" + "deploy_param.params", "wb") as f:
- f.write(relay.save_param_dict(params))
+ f.write(runtime.save_param_dict(params))
print("dumping labels...")
synset_url = "".join(
[
diff --git a/apps/bundle_deploy/build_model.py
b/apps/bundle_deploy/build_model.py
index 0991ac9..8fbc01b 100644
--- a/apps/bundle_deploy/build_model.py
+++ b/apps/bundle_deploy/build_model.py
@@ -20,7 +20,7 @@ import argparse
import os
from tvm import relay
import tvm
-from tvm import te
+from tvm import te, runtime
import logging
import json
from tvm.contrib import cc as _cc
@@ -70,7 +70,7 @@ def build_module(opts):
with open(
os.path.join(build_dir, file_format_str.format(name="params",
ext="bin")), "wb"
) as f_params:
- f_params.write(relay.save_param_dict(params))
+ f_params.write(runtime.save_param_dict(params))
def build_test_module(opts):
@@ -113,7 +113,7 @@ def build_test_module(opts):
with open(
os.path.join(build_dir, file_format_str.format(name="test_params",
ext="bin")), "wb"
) as f_params:
- f_params.write(relay.save_param_dict(lowered_params))
+ f_params.write(runtime.save_param_dict(lowered_params))
with open(
os.path.join(build_dir, file_format_str.format(name="test_data",
ext="bin")), "wb"
) as fp:
diff --git a/apps/bundle_deploy/runtime.cc b/apps/bundle_deploy/runtime.cc
index 3224028..2f7e384 100644
--- a/apps/bundle_deploy/runtime.cc
+++ b/apps/bundle_deploy/runtime.cc
@@ -23,6 +23,7 @@
#include <tvm/runtime/registry.h>
#include "../../src/runtime/c_runtime_api.cc"
+#include "../../src/runtime/container.cc"
#include "../../src/runtime/cpu_device_api.cc"
#include "../../src/runtime/file_utils.cc"
#include "../../src/runtime/graph/graph_runtime.cc"
diff --git a/apps/sgx/src/build_model.py b/apps/sgx/src/build_model.py
index 868d3bc..1fc297d 100755
--- a/apps/sgx/src/build_model.py
+++ b/apps/sgx/src/build_model.py
@@ -23,7 +23,7 @@ import os
from os import path as osp
import sys
-from tvm import relay
+from tvm import relay, runtime
from tvm.relay import testing
import tvm
from tvm import te
@@ -49,7 +49,7 @@ def main():
with open(osp.join(build_dir, "graph.json"), "w") as f_graph_json:
f_graph_json.write(graph)
with open(osp.join(build_dir, "params.bin"), "wb") as f_params:
- f_params.write(relay.save_param_dict(params))
+ f_params.write(runtime.save_param_dict(params))
if __name__ == "__main__":
diff --git a/apps/wasm-standalone/wasm-graph/tools/build_graph_lib.py
b/apps/wasm-standalone/wasm-graph/tools/build_graph_lib.py
index 42695d2..3d8a349 100644
--- a/apps/wasm-standalone/wasm-graph/tools/build_graph_lib.py
+++ b/apps/wasm-standalone/wasm-graph/tools/build_graph_lib.py
@@ -24,7 +24,7 @@ import sys
import onnx
import tvm
-from tvm import relay
+from tvm import relay, runtime
def _get_mod_and_params(model_file):
@@ -60,7 +60,7 @@ def build_graph_lib(model_file, opt_level):
f_graph.write(graph_json)
with open(os.path.join(out_dir, "graph.params"), "wb") as f_params:
- f_params.write(relay.save_param_dict(params))
+ f_params.write(runtime.save_param_dict(params))
if __name__ == "__main__":
diff --git a/docs/deploy/android.rst b/docs/deploy/android.rst
index 8c8fcfb..256978d 100644
--- a/docs/deploy/android.rst
+++ b/docs/deploy/android.rst
@@ -31,7 +31,7 @@ The code below will save the compilation output which is
required on android tar
with open("deploy_graph.json", "w") as fo:
fo.write(graph.json())
with open("deploy_param.params", "wb") as fo:
- fo.write(relay.save_param_dict(params))
+ fo.write(runtime.save_param_dict(params))
deploy_lib.so, deploy_graph.json, deploy_param.params will go to android
target.
diff --git a/golang/sample/gen_mobilenet_lib.py
b/golang/sample/gen_mobilenet_lib.py
index b82e0c4..12f215b 100644
--- a/golang/sample/gen_mobilenet_lib.py
+++ b/golang/sample/gen_mobilenet_lib.py
@@ -16,7 +16,7 @@
# under the License.
import os
-from tvm import relay, transform
+from tvm import relay, transform, runtime
from tvm.contrib.download import download_testdata
@@ -94,4 +94,4 @@ with open("./mobilenet.json", "w") as fo:
fo.write(graph)
with open("./mobilenet.params", "wb") as fo:
- fo.write(relay.save_param_dict(params))
+ fo.write(runtime.save_param_dict(params))
diff --git a/python/tvm/contrib/debugger/debug_result.py
b/python/tvm/contrib/debugger/debug_result.py
index 3159ab3..f58947f 100644
--- a/python/tvm/contrib/debugger/debug_result.py
+++ b/python/tvm/contrib/debugger/debug_result.py
@@ -264,8 +264,4 @@ def save_tensors(params):
"""
_save_tensors = tvm.get_global_func("tvm.relay._save_param_dict")
- args = []
- for k, v in params.items():
- args.append(k)
- args.append(tvm.nd.array(v))
- return _save_tensors(*args)
+ return _save_tensors(params)
diff --git a/python/tvm/driver/tvmc/compiler.py
b/python/tvm/driver/tvmc/compiler.py
index fc1805e..83791e5 100644
--- a/python/tvm/driver/tvmc/compiler.py
+++ b/python/tvm/driver/tvmc/compiler.py
@@ -24,7 +24,7 @@ from pathlib import Path
import tvm
from tvm import autotvm, auto_scheduler
-from tvm import relay
+from tvm import relay, runtime
from tvm.contrib import cc
from tvm.contrib import utils
@@ -282,7 +282,7 @@ def save_module(module_path, graph, lib, params,
cross=None):
with open(temp.relpath(param_name), "wb") as params_file:
logger.debug("writing params to file to %s", params_file.name)
- params_file.write(relay.save_param_dict(params))
+ params_file.write(runtime.save_param_dict(params))
logger.debug("saving module as tar file to %s", module_path)
with tarfile.open(module_path, "w") as tar:
diff --git a/python/tvm/driver/tvmc/runner.py b/python/tvm/driver/tvmc/runner.py
index 87ea3be..1d23ccf 100644
--- a/python/tvm/driver/tvmc/runner.py
+++ b/python/tvm/driver/tvmc/runner.py
@@ -24,11 +24,11 @@ import tarfile
import tempfile
import numpy as np
-import tvm
from tvm import rpc
from tvm.autotvm.measure import request_remote
from tvm.contrib import graph_runtime as runtime
from tvm.contrib.debugger import debug_runtime
+from tvm.relay import load_param_dict
from . import common
from .common import TVMCException
@@ -163,9 +163,8 @@ def get_input_info(graph_str, params):
shape_dict = {}
dtype_dict = {}
- # Use a special function to load the binary params back into a dict
- load_arr = tvm.get_global_func("tvm.relay._load_param_dict")(params)
- param_names = [v.name for v in load_arr]
+ params_dict = load_param_dict(params)
+ param_names = [k for (k, v) in params_dict.items()]
graph = json.loads(graph_str)
for node_id in graph["arg_nodes"]:
node = graph["nodes"][node_id]
diff --git a/python/tvm/relay/param_dict.py b/python/tvm/relay/param_dict.py
index 2d0398e..2714607 100644
--- a/python/tvm/relay/param_dict.py
+++ b/python/tvm/relay/param_dict.py
@@ -16,12 +16,7 @@
# under the License.
# pylint: disable=invalid-name
"""Helper utility to save parameter dicts."""
-import tvm
-import tvm._ffi
-
-
-_save_param_dict = tvm._ffi.get_global_func("tvm.relay._save_param_dict")
-_load_param_dict = tvm._ffi.get_global_func("tvm.relay._load_param_dict")
+import tvm.runtime
def save_param_dict(params):
@@ -30,6 +25,9 @@ def save_param_dict(params):
The result binary bytes can be loaded by the
GraphModule with API "load_params".
+ .. deprecated:: 0.9.0
+ Use :py:func:`tvm.runtime.save_param_dict` instead.
+
Parameters
----------
params : dict of str to NDArray
@@ -47,21 +45,20 @@ def save_param_dict(params):
# set up the parameter dict
params = {"param0": arr0, "param1": arr1}
# save the parameters as byte array
- param_bytes = tvm.relay.save_param_dict(params)
+ param_bytes = tvm.runtime.save_param_dict(params)
# We can serialize the param_bytes and load it back later.
# Pass in byte array to module to directly set parameters
- graph_runtime_mod.load_params(param_bytes)
+ tvm.runtime.load_param_dict(param_bytes)
"""
- args = []
- for k, v in params.items():
- args.append(k)
- args.append(tvm.nd.array(v))
- return _save_param_dict(*args)
+ return tvm.runtime.save_param_dict(params)
def load_param_dict(param_bytes):
"""Load parameter dictionary to binary bytes.
+ .. deprecated:: 0.9.0
+ Use :py:func:`tvm.runtime.load_param_dict` instead.
+
Parameters
----------
param_bytes: bytearray
@@ -72,7 +69,4 @@ def load_param_dict(param_bytes):
params : dict of str to NDArray
The parameter dictionary.
"""
- if isinstance(param_bytes, (bytes, str)):
- param_bytes = bytearray(param_bytes)
- load_arr = _load_param_dict(param_bytes)
- return {v.name: v.array for v in load_arr}
+ return tvm.runtime.load_param_dict(param_bytes)
diff --git a/python/tvm/runtime/__init__.py b/python/tvm/runtime/__init__.py
index 21c06c5..7d58af7 100644
--- a/python/tvm/runtime/__init__.py
+++ b/python/tvm/runtime/__init__.py
@@ -29,3 +29,4 @@ from .ndarray import context, cpu, gpu, opencl, cl, vulkan,
metal, mtl
from .ndarray import vpi, rocm, ext_dev, micro_dev
from .module import load_module, enabled, system_lib
from .container import String
+from .params import save_param_dict, load_param_dict
diff --git a/python/tvm/relay/param_dict.py b/python/tvm/runtime/params.py
similarity index 76%
copy from python/tvm/relay/param_dict.py
copy to python/tvm/runtime/params.py
index 2d0398e..78e7456 100644
--- a/python/tvm/relay/param_dict.py
+++ b/python/tvm/runtime/params.py
@@ -15,13 +15,8 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
-"""Helper utility to save parameter dicts."""
-import tvm
-import tvm._ffi
-
-
-_save_param_dict = tvm._ffi.get_global_func("tvm.relay._save_param_dict")
-_load_param_dict = tvm._ffi.get_global_func("tvm.relay._load_param_dict")
+"""Helper utility to save and load parameter dicts."""
+from . import _ffi_api, ndarray
def save_param_dict(params):
@@ -47,16 +42,13 @@ def save_param_dict(params):
# set up the parameter dict
params = {"param0": arr0, "param1": arr1}
# save the parameters as byte array
- param_bytes = tvm.relay.save_param_dict(params)
+ param_bytes = tvm.runtime.save_param_dict(params)
# We can serialize the param_bytes and load it back later.
# Pass in byte array to module to directly set parameters
- graph_runtime_mod.load_params(param_bytes)
+ tvm.runtime.load_param_dict(param_bytes)
"""
- args = []
- for k, v in params.items():
- args.append(k)
- args.append(tvm.nd.array(v))
- return _save_param_dict(*args)
+ transformed = {k: ndarray.array(v) for (k, v) in params.items()}
+ return _ffi_api.SaveParams(transformed)
def load_param_dict(param_bytes):
@@ -74,5 +66,4 @@ def load_param_dict(param_bytes):
"""
if isinstance(param_bytes, (bytes, str)):
param_bytes = bytearray(param_bytes)
- load_arr = _load_param_dict(param_bytes)
- return {v.name: v.array for v in load_arr}
+ return _ffi_api.LoadParams(param_bytes)
diff --git a/rust/tvm-graph-rt/src/graph.rs b/rust/tvm-graph-rt/src/graph.rs
index 646a20d..83fe37e 100644
--- a/rust/tvm-graph-rt/src/graph.rs
+++ b/rust/tvm-graph-rt/src/graph.rs
@@ -483,7 +483,7 @@ named! {
)
}
-/// Loads a param dict saved using `relay.save_param_dict`.
+/// Loads a param dict saved using `runtime.save_param_dict`.
pub fn load_param_dict(bytes: &[u8]) -> Result<HashMap<String, Tensor>,
GraphFormatError> {
match parse_param_dict(bytes) {
Ok((remaining_bytes, param_dict)) => {
diff --git a/rust/tvm-graph-rt/tests/build_model.py
b/rust/tvm-graph-rt/tests/build_model.py
index d34b440..9690759 100755
--- a/rust/tvm-graph-rt/tests/build_model.py
+++ b/rust/tvm-graph-rt/tests/build_model.py
@@ -23,7 +23,7 @@ from os import path as osp
import numpy as np
import tvm
from tvm import te
-from tvm import relay
+from tvm import relay, runtime
from tvm.relay import testing
CWD = osp.dirname(osp.abspath(osp.expanduser(__file__)))
@@ -47,7 +47,7 @@ def main():
with open(osp.join(CWD, "graph.json"), "w") as f_resnet:
f_resnet.write(graph)
with open(osp.join(CWD, "graph.params"), "wb") as f_params:
- f_params.write(relay.save_param_dict(params))
+ f_params.write(runtime.save_param_dict(params))
if __name__ == "__main__":
diff --git a/rust/tvm-graph-rt/tests/test_nn/src/build_test_graph.py
b/rust/tvm-graph-rt/tests/test_nn/src/build_test_graph.py
index e743e48..0045b3b 100755
--- a/rust/tvm-graph-rt/tests/test_nn/src/build_test_graph.py
+++ b/rust/tvm-graph-rt/tests/test_nn/src/build_test_graph.py
@@ -23,7 +23,7 @@ import sys
import numpy as np
import tvm
-from tvm import te
+from tvm import te, runtime
from tvm import relay
from tvm.relay import testing
@@ -49,7 +49,7 @@ def main():
f_resnet.write(graph)
with open(osp.join(out_dir, "graph.params"), "wb") as f_params:
- f_params.write(relay.save_param_dict(params))
+ f_params.write(runtime.save_param_dict(params))
if __name__ == "__main__":
diff --git a/rust/tvm/examples/resnet/src/build_resnet.py
b/rust/tvm/examples/resnet/src/build_resnet.py
index 03ac611..fdacb5b 100644
--- a/rust/tvm/examples/resnet/src/build_resnet.py
+++ b/rust/tvm/examples/resnet/src/build_resnet.py
@@ -27,7 +27,7 @@ import numpy as np
import tvm
from tvm import te
-from tvm import relay
+from tvm import relay, runtime
from tvm.relay import testing
from tvm.contrib import graph_runtime, cc
from PIL import Image
@@ -88,7 +88,7 @@ def build(target_dir):
fo.write(graph)
with open(osp.join(target_dir, "deploy_param.params"), "wb") as fo:
- fo.write(relay.save_param_dict(params))
+ fo.write(runtime.save_param_dict(params))
def download_img_labels():
diff --git a/src/relay/backend/param_dict.cc b/src/relay/backend/param_dict.cc
index 1d7e08a..bb0fad9 100644
--- a/src/relay/backend/param_dict.cc
+++ b/src/relay/backend/param_dict.cc
@@ -31,70 +31,24 @@
#include <utility>
#include <vector>
+#include "../../runtime/file_utils.h"
+
namespace tvm {
namespace relay {
using namespace runtime;
-TVM_REGISTER_GLOBAL("tvm.relay._save_param_dict").set_body([](TVMArgs args,
TVMRetValue* rv) {
- ICHECK_EQ(args.size() % 2, 0u);
- // `args` is in the form "key, value, key, value, ..."
- size_t num_params = args.size() / 2;
- std::vector<std::string> names;
- names.reserve(num_params);
- std::vector<DLTensor*> arrays;
- arrays.reserve(num_params);
- for (size_t i = 0; i < num_params * 2; i += 2) {
- names.emplace_back(args[i].operator String());
- arrays.emplace_back(args[i + 1].operator DLTensor*());
- }
- std::string bytes;
- dmlc::MemoryStringStream strm(&bytes);
- dmlc::Stream* fo = &strm;
- uint64_t header = kTVMNDArrayListMagic, reserved = 0;
- fo->Write(header);
- fo->Write(reserved);
- fo->Write(names);
- {
- uint64_t sz = static_cast<uint64_t>(arrays.size());
- fo->Write(sz);
- for (size_t i = 0; i < sz; ++i) {
- tvm::runtime::SaveDLTensor(fo, arrays[i]);
- }
- }
- TVMByteArray arr;
- arr.data = bytes.c_str();
- arr.size = bytes.length();
- *rv = arr;
-});
-
-TVM_REGISTER_GLOBAL("tvm.relay._load_param_dict").set_body([](TVMArgs args,
TVMRetValue* rv) {
- std::string bytes = args[0];
- std::vector<std::string> names;
- dmlc::MemoryStringStream memstrm(&bytes);
- dmlc::Stream* strm = &memstrm;
- uint64_t header, reserved;
- ICHECK(strm->Read(&header)) << "Invalid parameters file format";
- ICHECK(header == kTVMNDArrayListMagic) << "Invalid parameters file format";
- ICHECK(strm->Read(&reserved)) << "Invalid parameters file format";
- ICHECK(strm->Read(&names)) << "Invalid parameters file format";
- uint64_t sz;
- strm->Read(&sz, sizeof(sz));
- size_t size = static_cast<size_t>(sz);
- ICHECK(size == names.size()) << "Invalid parameters file format";
- tvm::Array<NamedNDArray> ret;
- for (size_t i = 0; i < size; ++i) {
- tvm::runtime::NDArray temp;
- temp.Load(strm);
- auto n = tvm::make_object<NamedNDArrayNode>();
- n->name = std::move(names[i]);
- n->array = temp;
- ret.push_back(NamedNDArray(n));
- }
- *rv = ret;
+TVM_REGISTER_GLOBAL("tvm.relay._save_param_dict")
+ .set_body_typed([](const Map<String, NDArray>& params) {
+ std::string s = ::tvm::runtime::SaveParams(params);
+ // copy return array so it is owned by the ret value
+ TVMRetValue rv;
+ rv = TVMByteArray{s.data(), s.size()};
+ return rv;
+ });
+TVM_REGISTER_GLOBAL("tvm.relay._load_param_dict").set_body_typed([](const
String& s) {
+ return ::tvm::runtime::LoadParams(s);
});
-TVM_REGISTER_NODE_TYPE(NamedNDArrayNode);
-
} // namespace relay
} // namespace tvm
diff --git a/src/relay/backend/param_dict.h b/src/relay/backend/param_dict.h
index 384201f..96e17a9 100644
--- a/src/relay/backend/param_dict.h
+++ b/src/relay/backend/param_dict.h
@@ -32,32 +32,7 @@
#include <string>
namespace tvm {
-namespace relay {
-
-/*! \brief Magic number for NDArray list file */
-constexpr uint64_t kTVMNDArrayListMagic = 0xF7E58D4F05049CB7;
-
-/*!
- * \brief Wrapper node for naming `NDArray`s.
- */
-struct NamedNDArrayNode : public ::tvm::Object {
- std::string name;
- tvm::runtime::NDArray array;
-
- void VisitAttrs(tvm::AttrVisitor* v) {
- v->Visit("name", &name);
- v->Visit("array", &array);
- }
-
- static constexpr const char* _type_key = "NamedNDArray";
- TVM_DECLARE_FINAL_OBJECT_INFO(NamedNDArrayNode, Object);
-};
-
-class NamedNDArray : public ObjectRef {
- public:
- TVM_DEFINE_OBJECT_REF_METHODS(NamedNDArray, ObjectRef, NamedNDArrayNode);
-};
-} // namespace relay
+namespace relay {} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_BACKEND_PARAM_DICT_H_
diff --git a/src/runtime/file_utils.cc b/src/runtime/file_utils.cc
index 3957505..92c398b 100644
--- a/src/runtime/file_utils.cc
+++ b/src/runtime/file_utils.cc
@@ -24,6 +24,7 @@
#include <dmlc/json.h>
#include <dmlc/memory_io.h>
+#include <tvm/runtime/registry.h>
#include <tvm/runtime/serializer.h>
#include <tvm/support/logging.h>
@@ -158,5 +159,71 @@ void LoadMetaDataFromFile(const std::string& file_name,
void RemoveFile(const std::string& file_name) {
std::remove(file_name.c_str()); }
+Map<String, NDArray> LoadParams(const std::string& param_blob) {
+ dmlc::MemoryStringStream strm(const_cast<std::string*>(&param_blob));
+ return LoadParams(&strm);
+}
+Map<String, NDArray> LoadParams(dmlc::Stream* strm) {
+ Map<String, NDArray> params;
+ uint64_t header, reserved;
+ ICHECK(strm->Read(&header)) << "Invalid parameters file format";
+ ICHECK(header == kTVMNDArrayListMagic) << "Invalid parameters file format";
+ ICHECK(strm->Read(&reserved)) << "Invalid parameters file format";
+
+ std::vector<std::string> names;
+ ICHECK(strm->Read(&names)) << "Invalid parameters file format";
+ uint64_t sz;
+ strm->Read(&sz);
+ size_t size = static_cast<size_t>(sz);
+ ICHECK(size == names.size()) << "Invalid parameters file format";
+ for (size_t i = 0; i < size; ++i) {
+ // The data_entry is allocated on device, NDArray.load always load the array into CPU.
+ NDArray temp;
+ temp.Load(strm);
+ params.Set(names[i], temp);
+ }
+ return params;
+}
+
+void SaveParams(dmlc::Stream* strm, const Map<String, NDArray>& params) {
+ std::vector<std::string> names;
+ std::vector<const DLTensor*> arrays;
+ for (auto& p : params) {
+ names.push_back(p.first);
+ arrays.push_back(p.second.operator->());
+ }
+
+ uint64_t header = kTVMNDArrayListMagic, reserved = 0;
+ strm->Write(header);
+ strm->Write(reserved);
+ strm->Write(names);
+ {
+ uint64_t sz = static_cast<uint64_t>(arrays.size());
+ strm->Write(sz);
+ for (size_t i = 0; i < sz; ++i) {
+ tvm::runtime::SaveDLTensor(strm, arrays[i]);
+ }
+ }
+}
+
+std::string SaveParams(const Map<String, NDArray>& params) {
+ std::string bytes;
+ dmlc::MemoryStringStream strm(&bytes);
+ dmlc::Stream* fo = &strm;
+ SaveParams(fo, params);
+ return bytes;
+}
+
+TVM_REGISTER_GLOBAL("runtime.SaveParams").set_body_typed([](const Map<String,
NDArray>& params) {
+ std::string s = ::tvm::runtime::SaveParams(params);
+ // copy return array so it is owned by the ret value
+ TVMRetValue rv;
+ rv = TVMByteArray{s.data(), s.size()};
+ return rv;
+});
+TVM_REGISTER_GLOBAL("runtime.LoadParams").set_body_typed([](const String& s) {
+ return ::tvm::runtime::LoadParams(s);
+});
+
} // namespace runtime
} // namespace tvm
diff --git a/src/runtime/file_utils.h b/src/runtime/file_utils.h
index dfa7d67..718d10d 100644
--- a/src/runtime/file_utils.h
+++ b/src/runtime/file_utils.h
@@ -94,6 +94,32 @@ void LoadMetaDataFromFile(const std::string& file_name,
* \param file_name The file name.
*/
void RemoveFile(const std::string& file_name);
+
+constexpr uint64_t kTVMNDArrayListMagic = 0xF7E58D4F05049CB7;
+/*!
+ * \brief Load parameters from a string.
+ * \param param_blob Serialized string of parameters.
+ * \return Map of parameter name to parameter value.
+ */
+Map<String, NDArray> LoadParams(const std::string& param_blob);
+/*!
+ * \brief Load parameters from a stream.
+ * \param strm Stream to load parameters from.
+ * \return Map of parameter name to parameter value.
+ */
+Map<String, NDArray> LoadParams(dmlc::Stream* strm);
+/*!
+ * \brief Serialize parameters to a byte array.
+ * \param params Parameters to save.
+ * \return String containing binary parameter data.
+ */
+std::string SaveParams(const Map<String, NDArray>& params);
+/*!
+ * \brief Serialize parameters to a stream.
+ * \param strm Stream to write to.
+ * \param params Parameters to save.
+ */
+void SaveParams(dmlc::Stream* strm, const Map<String, NDArray>& params);
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_FILE_UTILS_H_
diff --git a/src/runtime/graph/graph_runtime.cc
b/src/runtime/graph/graph_runtime.cc
index 6d586cf..6c51e71 100644
--- a/src/runtime/graph/graph_runtime.cc
+++ b/src/runtime/graph/graph_runtime.cc
@@ -38,6 +38,8 @@
#include <utility>
#include <vector>
+#include "../file_utils.h"
+
namespace tvm {
namespace runtime {
namespace details {
@@ -196,31 +198,10 @@ void GraphRuntime::LoadParams(const std::string&
param_blob) {
}
void GraphRuntime::LoadParams(dmlc::Stream* strm) {
- uint64_t header, reserved;
- ICHECK(strm->Read(&header)) << "Invalid parameters file format";
- ICHECK(header == kTVMNDArrayListMagic) << "Invalid parameters file format";
- ICHECK(strm->Read(&reserved)) << "Invalid parameters file format";
-
- std::vector<std::string> names;
- ICHECK(strm->Read(&names)) << "Invalid parameters file format";
- uint64_t sz;
- strm->Read(&sz);
- size_t size = static_cast<size_t>(sz);
- ICHECK(size == names.size()) << "Invalid parameters file format";
- for (size_t i = 0; i < size; ++i) {
- int in_idx = GetInputIndex(names[i]);
- if (in_idx < 0) {
- NDArray temp;
- temp.Load(strm);
- continue;
- }
- uint32_t eid = this->entry_id(input_nodes_[in_idx], 0);
- ICHECK_LT(eid, data_entry_.size());
-
- // The data_entry is allocated on device, NDArray.load always load the array into CPU.
- NDArray temp;
- temp.Load(strm);
- data_entry_[eid].CopyFrom(temp);
+ Map<String, NDArray> params = ::tvm::runtime::LoadParams(strm);
+ for (auto& p : params) {
+ uint32_t eid = this->entry_id(input_nodes_[GetInputIndex(p.first)], 0);
+ data_entry_[eid].CopyFrom(p.second);
}
}
diff --git a/src/runtime/graph/graph_runtime.h
b/src/runtime/graph/graph_runtime.h
index 6279118..a1e2ee3 100644
--- a/src/runtime/graph/graph_runtime.h
+++ b/src/runtime/graph/graph_runtime.h
@@ -47,9 +47,6 @@ namespace runtime {
ICHECK_EQ(ret, 0) << TVMGetLastError(); \
}
-/*! \brief Magic number for NDArray list file */
-constexpr uint64_t kTVMNDArrayListMagic = 0xF7E58D4F05049CB7;
-
/*! \brief operator attributes about tvm op */
struct TVMOpParam {
std::string func_name;
diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc
index 3f890ba..6d121aa 100644
--- a/src/runtime/vm/vm.cc
+++ b/src/runtime/vm/vm.cc
@@ -35,6 +35,8 @@
#include <stdexcept>
#include <vector>
+#include "../file_utils.h"
+
using namespace tvm::runtime;
namespace tvm {
diff --git a/tests/python/contrib/test_tensorrt.py
b/tests/python/contrib/test_tensorrt.py
index 60d6b2a..ae8214d 100644
--- a/tests/python/contrib/test_tensorrt.py
+++ b/tests/python/contrib/test_tensorrt.py
@@ -22,7 +22,7 @@ import itertools
import tvm
import tvm.relay.testing
-from tvm import relay
+from tvm import relay, runtime
from tvm.relay.op.contrib import tensorrt
from tvm.contrib import graph_runtime, utils
from tvm.runtime.vm import VirtualMachine
@@ -265,7 +265,7 @@ def test_tensorrt_serialize_graph_runtime():
def compile_graph(mod, params):
with tvm.transform.PassContext(opt_level=3,
config={"relay.ext.tensorrt.options": config}):
graph, lib, params = relay.build(mod, params=params, target="cuda")
- params = relay.save_param_dict(params)
+ params = runtime.save_param_dict(params)
return graph, lib, params
def run_graph(graph, lib, params):
diff --git a/tests/python/relay/test_cpp_build_module.py
b/tests/python/relay/test_cpp_build_module.py
index 67f0621..60f3dfa 100644
--- a/tests/python/relay/test_cpp_build_module.py
+++ b/tests/python/relay/test_cpp_build_module.py
@@ -18,7 +18,7 @@ import numpy as np
import tvm
from tvm import te
-from tvm import relay
+from tvm import relay, runtime
from tvm.contrib.nvcc import have_fp16
import tvm.testing
@@ -86,7 +86,7 @@ def test_fp16_build():
# test
rt = tvm.contrib.graph_runtime.create(g_json, mmod, ctx)
- rt.load_params(relay.save_param_dict(params))
+ rt.load_params(runtime.save_param_dict(params))
rt.run()
out = rt.get_output(0)
diff --git a/tests/python/relay/test_param_dict.py
b/tests/python/relay/test_param_dict.py
index 74c9ebc..29e0b5c 100644
--- a/tests/python/relay/test_param_dict.py
+++ b/tests/python/relay/test_param_dict.py
@@ -17,7 +17,7 @@
import os
import numpy as np
import tvm
-from tvm import te
+from tvm import te, runtime
import json
import base64
from tvm._ffi.base import py_str
@@ -31,7 +31,7 @@ def test_save_load():
x = np.ones((10, 2)).astype("float32")
y = np.ones((1, 2, 3)).astype("float32")
params = {"x": x, "y": y}
- param_bytes = relay.save_param_dict(params)
+ param_bytes = runtime.save_param_dict(params)
assert isinstance(param_bytes, bytearray)
param2 = relay.load_param_dict(param_bytes)
assert len(param2) == 2
@@ -46,7 +46,7 @@ def test_ndarray_reflection():
param_dict = {"x": tvm_array, "y": tvm_array}
assert param_dict["x"].same_as(param_dict["y"])
# Serialize then deserialize `param_dict`.
- deser_param_dict = relay.load_param_dict(relay.save_param_dict(param_dict))
+ deser_param_dict =
relay.load_param_dict(runtime.save_param_dict(param_dict))
# Make sure the data matches the original data and `x` and `y` contain the
same data.
np.testing.assert_equal(deser_param_dict["x"].asnumpy(),
tvm_array.asnumpy())
# Make sure `x` and `y` contain the same data.
@@ -77,7 +77,7 @@ def test_bigendian_rpc_param():
lib = remote.load_module("dev_lib.o")
ctx = remote.cpu(0)
mod = graph_runtime.create(graph, lib, ctx)
- mod.load_params(relay.save_param_dict(params))
+ mod.load_params(runtime.save_param_dict(params))
mod.run()
out = mod.get_output(0, tvm.nd.empty(shape, dtype=dtype, ctx=ctx))
tvm.testing.assert_allclose(x_in + 1, out.asnumpy())
diff --git a/tests/python/unittest/test_runtime_graph.py
b/tests/python/unittest/test_runtime_graph.py
index c43a359..16e9db4 100644
--- a/tests/python/unittest/test_runtime_graph.py
+++ b/tests/python/unittest/test_runtime_graph.py
@@ -16,7 +16,7 @@
# under the License.
import tvm
import tvm.testing
-from tvm import te
+from tvm import te, runtime
import numpy as np
import json
from tvm import rpc
@@ -94,12 +94,12 @@ def test_graph_simple():
graph, lib, params = relay.build(func, target="llvm", params=params)
mod_shared = graph_runtime.create(graph, lib, tvm.cpu(0))
- mod_shared.load_params(relay.save_param_dict(params))
+ mod_shared.load_params(runtime.save_param_dict(params))
num_mods = 10
mods = [graph_runtime.create(graph, lib, tvm.cpu(0)) for _ in
range(num_mods)]
for mod in mods:
- mod.share_params(mod_shared, relay.save_param_dict(params))
+ mod.share_params(mod_shared, runtime.save_param_dict(params))
a = np.random.uniform(size=(1, 10)).astype("float32")
for mod in mods:
diff --git a/tests/python/unittest/test_runtime_module_based_interface.py
b/tests/python/unittest/test_runtime_module_based_interface.py
index 51a5872..a34fe4a 100644
--- a/tests/python/unittest/test_runtime_module_based_interface.py
+++ b/tests/python/unittest/test_runtime_module_based_interface.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import numpy as np
-from tvm import relay
+from tvm import relay, runtime
from tvm.relay import testing
import tvm
from tvm.contrib import graph_runtime
@@ -314,7 +314,7 @@ def test_remove_package_params():
complied_graph_lib_no_params = complied_graph_lib["remove_params"]()
complied_graph_lib_no_params.export_library(path_lib)
with open(temp.relpath("deploy_param.params"), "wb") as fo:
- fo.write(relay.save_param_dict(complied_graph_lib.get_params()))
+ fo.write(runtime.save_param_dict(complied_graph_lib.get_params()))
loaded_lib = tvm.runtime.load_module(path_lib)
data = np.random.uniform(-1, 1,
size=input_shape(mod)).astype("float32")
ctx = tvm.cpu(0)
@@ -361,7 +361,7 @@ def test_remove_package_params():
complied_graph_lib_no_params = complied_graph_lib["remove_params"]()
complied_graph_lib_no_params.export_library(path_lib)
with open(temp.relpath("deploy_param.params"), "wb") as fo:
- fo.write(relay.save_param_dict(complied_graph_lib.get_params()))
+ fo.write(runtime.save_param_dict(complied_graph_lib.get_params()))
loaded_lib = tvm.runtime.load_module(path_lib)
data = np.random.uniform(-1, 1,
size=input_shape(mod)).astype("float32")
ctx = tvm.gpu(0)
@@ -409,7 +409,7 @@ def test_remove_package_params():
complied_graph_lib_no_params.export_library(path_lib)
path_params = temp.relpath("deploy_param.params")
with open(path_params, "wb") as fo:
- fo.write(relay.save_param_dict(complied_graph_lib.get_params()))
+ fo.write(runtime.save_param_dict(complied_graph_lib.get_params()))
from tvm import rpc
@@ -462,7 +462,7 @@ def test_remove_package_params():
complied_graph_lib_no_params.export_library(path_lib)
path_params = temp.relpath("deploy_param.params")
with open(path_params, "wb") as fo:
- fo.write(relay.save_param_dict(complied_graph_lib.get_params()))
+ fo.write(runtime.save_param_dict(complied_graph_lib.get_params()))
from tvm import rpc
diff --git a/tutorials/frontend/deploy_sparse.py
b/tutorials/frontend/deploy_sparse.py
index 9641fb8..98004a9 100644
--- a/tutorials/frontend/deploy_sparse.py
+++ b/tutorials/frontend/deploy_sparse.py
@@ -81,7 +81,7 @@ import time
import itertools
import numpy as np
import tensorflow as tf
-from tvm import relay
+from tvm import relay, runtime
from tvm.contrib import graph_runtime
from tvm.relay import data_dep_optimization as ddo
from tensorflow.python.framework.convert_to_constants import (
@@ -196,7 +196,7 @@ def import_graphdef(
with open(os.path.join(abs_path, relay_file), "w") as fo:
fo.write(tvm.ir.save_json(mod))
with open(os.path.join(abs_path, relay_params), "wb") as fo:
- fo.write(relay.save_param_dict(params))
+ fo.write(runtime.save_param_dict(params))
return mod, params, shape_dict