This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 1043136c9f [Runtime] Fix high RAM usage when saving / loading
parameters of big models (#14147)
1043136c9f is described below
commit 1043136c9f1d11bddd8890c8b9fc508eae70a343
Author: masahi <[email protected]>
AuthorDate: Wed Mar 1 17:36:38 2023 +0900
[Runtime] Fix high RAM usage when saving / loading parameters of big models
(#14147)
* add load_params_from_file
* add save_params_to_file
* avoid making another copy in save_params
* black
* add test
* update doc
---
python/tvm/runtime/__init__.py | 7 ++++-
python/tvm/runtime/params.py | 49 ++++++++++++++++++++++++++---
src/runtime/file_utils.cc | 12 +++++++
tests/python/unittest/test_runtime_graph.py | 16 ++++++++--
4 files changed, 77 insertions(+), 7 deletions(-)
diff --git a/python/tvm/runtime/__init__.py b/python/tvm/runtime/__init__.py
index 71f71e6c84..eccdcbad95 100644
--- a/python/tvm/runtime/__init__.py
+++ b/python/tvm/runtime/__init__.py
@@ -32,6 +32,11 @@ from .ndarray import device, cpu, cuda, gpu, opencl, cl, vulkan, metal, mtl
from .ndarray import vpi, rocm, ext_dev
from .module import load_module, enabled, system_lib, load_static_library
from .container import String, ShapeTuple
-from .params import save_param_dict, load_param_dict
+from .params import (
+ save_param_dict,
+ load_param_dict,
+ save_param_dict_to_file,
+ load_param_dict_from_file,
+)
from . import executor
diff --git a/python/tvm/runtime/params.py b/python/tvm/runtime/params.py
index 78e745686c..4362a4b6a8 100644
--- a/python/tvm/runtime/params.py
+++ b/python/tvm/runtime/params.py
@@ -16,7 +16,19 @@
# under the License.
# pylint: disable=invalid-name
"""Helper utility to save and load parameter dicts."""
-from . import _ffi_api, ndarray
+from . import _ffi_api, ndarray, NDArray
+
+
+def _to_ndarray(params):
+    """Return a new dict whose values are all TVM NDArrays.
+
+    Values that are already ``NDArray`` instances are passed through as-is
+    rather than being re-wrapped with ``ndarray.array``, so saving parameters
+    does not make another copy of arrays the caller already holds as NDArrays.
+    Every other value is converted with ``ndarray.array``.
+
+    Parameters
+    ----------
+    params : dict of str to NDArray or numpy.ndarray
+        The parameter dictionary.
+
+    Returns
+    -------
+    transformed : dict of str to NDArray
+        A new dictionary with the same keys as ``params``.
+    """
+    return {
+        name: value if isinstance(value, NDArray) else ndarray.array(value)
+        for name, value in params.items()
+    }
def save_param_dict(params):
@@ -47,12 +59,25 @@ def save_param_dict(params):
# Pass in byte array to module to directly set parameters
tvm.runtime.load_param_dict(param_bytes)
"""
- transformed = {k: ndarray.array(v) for (k, v) in params.items()}
- return _ffi_api.SaveParams(transformed)
+ return _ffi_api.SaveParams(_to_ndarray(params))
+
+
+def save_param_dict_to_file(params, path):
+    """Save a parameter dictionary directly to a file.
+
+    Unlike :py:func:`save_param_dict`, which returns the serialized bytes to
+    Python, this hands ``path`` to the runtime, which serializes straight into
+    the file. The serialized blob is therefore never materialized on the
+    Python side.
+
+    Parameters
+    ----------
+    params : dict of str to NDArray or numpy.ndarray
+        The parameter dictionary. Values that are not already NDArrays are
+        converted with ``tvm.runtime.ndarray.array``; existing NDArrays are
+        used without copying.
+
+    path : str
+        The path to the parameter file. The runtime opens it with mode
+        ``"wb"``.
+    """
+    _ffi_api.SaveParamsToFile(_to_ndarray(params), path)
def load_param_dict(param_bytes):
- """Load parameter dictionary to binary bytes.
+ """Load parameter dictionary from binary bytes.
Parameters
----------
@@ -67,3 +92,19 @@ def load_param_dict(param_bytes):
if isinstance(param_bytes, (bytes, str)):
param_bytes = bytearray(param_bytes)
return _ffi_api.LoadParams(param_bytes)
+
+
+def load_param_dict_from_file(path):
+    """Read a parameter dictionary back from a file on disk.
+
+    The runtime reads the file itself, so, unlike
+    :py:func:`load_param_dict`, the caller does not need to load the raw
+    bytes into Python first.
+
+    Parameters
+    ----------
+    path : str
+        Location of the parameter file to read.
+
+    Returns
+    -------
+    params : dict of str to NDArray
+        The deserialized parameter dictionary.
+    """
+    loaded = _ffi_api.LoadParamsFromFile(path)
+    return loaded
diff --git a/src/runtime/file_utils.cc b/src/runtime/file_utils.cc
index 1e7cc6ad44..1c0e16dbe1 100644
--- a/src/runtime/file_utils.cc
+++ b/src/runtime/file_utils.cc
@@ -243,9 +243,21 @@ TVM_REGISTER_GLOBAL("runtime.SaveParams").set_body_typed([](const Map<String, ND
rv = TVMByteArray{s.data(), s.size()};
return rv;
});
+
+TVM_REGISTER_GLOBAL("runtime.SaveParamsToFile")
+    .set_body_typed([](const Map<String, NDArray>& params, const String& path) {
+      // Serialize straight into a file stream instead of into an intermediate
+      // buffer that is then returned (as runtime.SaveParams does above), so a
+      // second full copy of large parameter sets is not held in memory.
+      tvm::runtime::SimpleBinaryFileStream strm(path, "wb");
+      SaveParams(&strm, params);
+    });
+
TVM_REGISTER_GLOBAL("runtime.LoadParams").set_body_typed([](const String& s) {
return ::tvm::runtime::LoadParams(s);
});
+TVM_REGISTER_GLOBAL("runtime.LoadParamsFromFile").set_body_typed([](const String& path) {
+  // Deserialize directly from the file stream, so the caller does not first
+  // have to pass the whole file contents in as a String (which
+  // runtime.LoadParams requires).
+  tvm::runtime::SimpleBinaryFileStream strm(path, "rb");
+  return LoadParams(&strm);
+});
+
} // namespace runtime
} // namespace tvm
diff --git a/tests/python/unittest/test_runtime_graph.py b/tests/python/unittest/test_runtime_graph.py
index 458952fb56..108784de7e 100644
--- a/tests/python/unittest/test_runtime_graph.py
+++ b/tests/python/unittest/test_runtime_graph.py
@@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+import tempfile
import tvm
import tvm.testing
from tvm import te, runtime
@@ -138,6 +139,17 @@ def test_load_unexpected_params():
rt_mod.load_params(runtime.save_param_dict(new_params))
+def test_save_load_file():
+    """Round-trip parameters through save/load_param_dict_from_file."""
+    import os  # pylint: disable=import-outside-toplevel
+
+    p = np.random.randn(10)
+    q = np.arange(4, dtype="int32")
+    # Mix a numpy array with an existing NDArray so both the conversion and
+    # the pass-through paths of the saver are exercised.
+    params = {"x": p, "y": tvm.nd.array(q)}
+
+    # A NamedTemporaryFile cannot be reopened by name while still open on
+    # Windows, so write into a temporary directory instead.
+    with tempfile.TemporaryDirectory() as tmpdir:
+        path = os.path.join(tmpdir, "params.bin")
+        tvm.runtime.save_param_dict_to_file(params, path)
+        params_loaded = tvm.runtime.load_param_dict_from_file(path)
+
+    assert "x" in params_loaded
+    assert "y" in params_loaded
+    np.testing.assert_equal(p, params_loaded["x"].numpy())
+    np.testing.assert_equal(q, params_loaded["y"].numpy())
+
+
if __name__ == "__main__":
- test_graph_simple()
- test_load_unexpected_params()
+ tvm.testing.main()