This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 3bd516701c [Unity] handle bf16 in dump_ndarray_cache and load_ndarray_cache (#14514)
3bd516701c is described below

commit 3bd516701c6f43f051f37b75f71b165d0d3245cd
Author: Hongyi Jin <[email protected]>
AuthorDate: Wed Apr 5 18:19:27 2023 -0400

    [Unity] handle bf16 in dump_ndarray_cache and load_ndarray_cache (#14514)
    
    handle bf16 dump and load
---
 python/tvm/contrib/tvmjs.py | 18 +++++++++++-------
 1 file changed, 11 insertions(+), 7 deletions(-)

diff --git a/python/tvm/contrib/tvmjs.py b/python/tvm/contrib/tvmjs.py
index 3783baefe0..48fac1c66d 100644
--- a/python/tvm/contrib/tvmjs.py
+++ b/python/tvm/contrib/tvmjs.py
@@ -179,20 +179,21 @@ def dump_ndarray_cache(
 
     shard_manager = NDArrayCacheShardingManager(cache_dir, "params_shard", shard_cap_nbytes)
 
-    for k, v in params.items():
-        shape = list(v.shape)
-
+    for k, origin_v in params.items():
+        shape = list(origin_v.shape)
+        v = origin_v
         if not isinstance(v, np.ndarray):
             v = v.numpy()
 
+        # prefer to preserve original dtype, especially if the format was bfloat16
+        dtype = str(origin_v.dtype) if isinstance(origin_v, tvm.nd.NDArray) else v.dtype
+
         # convert fp32 to bf16
-        if encode_format == "f32-to-bf16" and v.dtype == "float32":
+        if encode_format == "f32-to-bf16" and dtype == "float32":
             data = _convert_f32_to_bf16(v).tobytes()
-            dtype = "bfloat16"
             f32_to_bf16_triggered = True
         else:
             data = v.tobytes()
-        dtype = str(v.dtype)
 
         shard_manager.append(data, name=k, shape=shape, dtype=dtype, encode_format=encode_format)
 
@@ -263,9 +264,12 @@ def load_ndarray_cache(cachepath: str, device: tvm.runtime.Device):
             arr = tvm.nd.empty(shape, dtype, device=device)
             assert offset + nbytes <= len(raw_data)
             buffer_source = raw_data[offset : offset + nbytes]
-            if encode_format == "f32-to-bf16":
+            if encode_format == "f32-to-bf16" and dtype == "float32":
                 data = np.frombuffer(buffer_source, dtype="uint16").reshape(shape)
                 arr.copyfrom(_convert_bf16_to_f32(data))
+            elif dtype == "bfloat16":
+                data = np.frombuffer(buffer_source, dtype="uint16").reshape(shape)
+                arr.copyfrom(data)
             else:
                 data = np.frombuffer(buffer_source, dtype=dtype).reshape(shape)
                 arr.copyfrom(data)

Reply via email to