This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new e47bf761d4 [Unity] Smart parameter fetching (#14708)
e47bf761d4 is described below
commit e47bf761d4945dc34f788070d5751446c2e2bb4c
Author: Tianqi Chen <[email protected]>
AuthorDate: Mon Apr 24 21:35:53 2023 -0400
[Unity] Smart parameter fetching (#14708)
This PR adds a smart mode of parameter fetching: passing in -1 as the
number of parameters fetches all parameters with the given prefix.
---
src/runtime/relax_vm/ndarray_cache_support.cc | 3 ++-
tests/python/relax/test_runtime_builtin.py | 2 +-
2 files changed, 3 insertions(+), 2 deletions(-)
diff --git a/src/runtime/relax_vm/ndarray_cache_support.cc
b/src/runtime/relax_vm/ndarray_cache_support.cc
index 311681228e..3e4bcce20e 100644
--- a/src/runtime/relax_vm/ndarray_cache_support.cc
+++ b/src/runtime/relax_vm/ndarray_cache_support.cc
@@ -174,12 +174,13 @@ class ParamModuleNode : public runtime::ModuleNode {
static Array<NDArray> GetParams(const std::string& prefix, int num_params) {
Array<NDArray> params;
- for (int i = 0; i < num_params; ++i) {
+ for (int i = 0; i < num_params || num_params == -1; ++i) {
std::string name = prefix + "_" + std::to_string(i);
auto opt = NDArrayCache::Get(name);
if (opt) {
params.push_back(opt.value());
} else {
+ if (num_params == -1) return params;
LOG(FATAL) << "Cannot find " << name << " in cache";
}
}
diff --git a/tests/python/relax/test_runtime_builtin.py
b/tests/python/relax/test_runtime_builtin.py
index f4ab3a2b54..d25841a71f 100644
--- a/tests/python/relax/test_runtime_builtin.py
+++ b/tests/python/relax/test_runtime_builtin.py
@@ -180,7 +180,7 @@ def test_ndarray_cache():
temp = utils.tempdir()
tvmjs.dump_ndarray_cache(param_dict, temp.path,
encode_format="f32-to-bf16")
fload(str(temp.path), tvm.cpu().device_type, 0)
- res = fget_params("x", 2)
+ res = fget_params("x", -1)
for i, v in enumerate(res):
v_np = param_dict[f"x_{i}"]
if v_np.dtype == "float32":