This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 1043136c9f [Runtime] Fix high RAM usage when saving / loading 
parameters of big models   (#14147)
1043136c9f is described below

commit 1043136c9f1d11bddd8890c8b9fc508eae70a343
Author: masahi <[email protected]>
AuthorDate: Wed Mar 1 17:36:38 2023 +0900

    [Runtime] Fix high RAM usage when saving / loading parameters of big models  
 (#14147)
    
    * add load_params_from_file
    
    * add save_params_to_file
    
    * avoid making another copy in save_params
    
    * black
    
    * add test
    
    * update doc
---
 python/tvm/runtime/__init__.py              |  7 ++++-
 python/tvm/runtime/params.py                | 49 ++++++++++++++++++++++++++---
 src/runtime/file_utils.cc                   | 12 +++++++
 tests/python/unittest/test_runtime_graph.py | 16 ++++++++--
 4 files changed, 77 insertions(+), 7 deletions(-)

diff --git a/python/tvm/runtime/__init__.py b/python/tvm/runtime/__init__.py
index 71f71e6c84..eccdcbad95 100644
--- a/python/tvm/runtime/__init__.py
+++ b/python/tvm/runtime/__init__.py
@@ -32,6 +32,11 @@ from .ndarray import device, cpu, cuda, gpu, opencl, cl, 
vulkan, metal, mtl
 from .ndarray import vpi, rocm, ext_dev
 from .module import load_module, enabled, system_lib, load_static_library
 from .container import String, ShapeTuple
-from .params import save_param_dict, load_param_dict
+from .params import (
+    save_param_dict,
+    load_param_dict,
+    save_param_dict_to_file,
+    load_param_dict_from_file,
+)
 
 from . import executor
diff --git a/python/tvm/runtime/params.py b/python/tvm/runtime/params.py
index 78e745686c..4362a4b6a8 100644
--- a/python/tvm/runtime/params.py
+++ b/python/tvm/runtime/params.py
@@ -16,7 +16,19 @@
 # under the License.
 # pylint: disable=invalid-name
 """Helper utility to save and load parameter dicts."""
-from . import _ffi_api, ndarray
+from . import _ffi_api, ndarray, NDArray
+
+
+def _to_ndarray(params):
+    transformed = {}
+
+    for (k, v) in params.items():
+        if not isinstance(v, NDArray):
+            transformed[k] = ndarray.array(v)
+        else:
+            transformed[k] = v
+
+    return transformed
 
 
 def save_param_dict(params):
@@ -47,12 +59,25 @@ def save_param_dict(params):
        # Pass in byte array to module to directly set parameters
        tvm.runtime.load_param_dict(param_bytes)
     """
-    transformed = {k: ndarray.array(v) for (k, v) in params.items()}
-    return _ffi_api.SaveParams(transformed)
+    return _ffi_api.SaveParams(_to_ndarray(params))
+
+
+def save_param_dict_to_file(params, path):
+    """Save parameter dictionary to file.
+
+    Parameters
+    ----------
+    params : dict of str to NDArray
+        The parameter dictionary.
+
+    path: str
+        The path to the parameter file.
+    """
+    return _ffi_api.SaveParamsToFile(_to_ndarray(params), path)
 
 
 def load_param_dict(param_bytes):
-    """Load parameter dictionary to binary bytes.
+    """Load parameter dictionary from binary bytes.
 
     Parameters
     ----------
@@ -67,3 +92,19 @@ def load_param_dict(param_bytes):
     if isinstance(param_bytes, (bytes, str)):
         param_bytes = bytearray(param_bytes)
     return _ffi_api.LoadParams(param_bytes)
+
+
+def load_param_dict_from_file(path):
+    """Load parameter dictionary from file.
+
+    Parameters
+    ----------
+    path: str
+        The path to the parameter file to load from.
+
+    Returns
+    -------
+    params : dict of str to NDArray
+        The parameter dictionary.
+    """
+    return _ffi_api.LoadParamsFromFile(path)
diff --git a/src/runtime/file_utils.cc b/src/runtime/file_utils.cc
index 1e7cc6ad44..1c0e16dbe1 100644
--- a/src/runtime/file_utils.cc
+++ b/src/runtime/file_utils.cc
@@ -243,9 +243,21 @@ 
TVM_REGISTER_GLOBAL("runtime.SaveParams").set_body_typed([](const Map<String, ND
   rv = TVMByteArray{s.data(), s.size()};
   return rv;
 });
+
+TVM_REGISTER_GLOBAL("runtime.SaveParamsToFile")
+    .set_body_typed([](const Map<String, NDArray>& params, const String& path) 
{
+      tvm::runtime::SimpleBinaryFileStream strm(path, "wb");
+      SaveParams(&strm, params);
+    });
+
 TVM_REGISTER_GLOBAL("runtime.LoadParams").set_body_typed([](const String& s) {
   return ::tvm::runtime::LoadParams(s);
 });
 
+TVM_REGISTER_GLOBAL("runtime.LoadParamsFromFile").set_body_typed([](const 
String& path) {
+  tvm::runtime::SimpleBinaryFileStream strm(path, "rb");
+  return LoadParams(&strm);
+});
+
 }  // namespace runtime
 }  // namespace tvm
diff --git a/tests/python/unittest/test_runtime_graph.py 
b/tests/python/unittest/test_runtime_graph.py
index 458952fb56..108784de7e 100644
--- a/tests/python/unittest/test_runtime_graph.py
+++ b/tests/python/unittest/test_runtime_graph.py
@@ -14,6 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import tempfile
 import tvm
 import tvm.testing
 from tvm import te, runtime
@@ -138,6 +139,17 @@ def test_load_unexpected_params():
     rt_mod.load_params(runtime.save_param_dict(new_params))
 
 
+def test_save_load_file():
+    p = np.random.randn(10)
+    params = {"x": p}
+
+    with tempfile.NamedTemporaryFile() as fp:
+        tvm.runtime.save_param_dict_to_file(params, fp.name)
+        params_loaded = tvm.runtime.load_param_dict_from_file(fp.name)
+
+        assert "x" in params_loaded
+        np.testing.assert_equal(p, params_loaded["x"].numpy())
+
+
 if __name__ == "__main__":
-    test_graph_simple()
-    test_load_unexpected_params()
+    tvm.testing.main()

Reply via email to