This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 3bd516701c [Unity] handle bf16 in dump_ndarray_cache and load_ndarray_cache (#14514)
3bd516701c is described below
commit 3bd516701c6f43f051f37b75f71b165d0d3245cd
Author: Hongyi Jin <[email protected]>
AuthorDate: Wed Apr 5 18:19:27 2023 -0400
[Unity] handle bf16 in dump_ndarray_cache and load_ndarray_cache (#14514)
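Handle bfloat16 (bf16) arrays when dumping and loading the NDArray cache:
dump_ndarray_cache now records the original dtype of a tvm.nd.NDArray, and
load_ndarray_cache reads "bfloat16" entries as raw uint16 data, applying the
bf16-to-f32 conversion only to entries whose recorded dtype is "float32".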
---
python/tvm/contrib/tvmjs.py | 18 +++++++++++-------
1 file changed, 11 insertions(+), 7 deletions(-)
diff --git a/python/tvm/contrib/tvmjs.py b/python/tvm/contrib/tvmjs.py
index 3783baefe0..48fac1c66d 100644
--- a/python/tvm/contrib/tvmjs.py
+++ b/python/tvm/contrib/tvmjs.py
@@ -179,20 +179,21 @@ def dump_ndarray_cache(
shard_manager = NDArrayCacheShardingManager(cache_dir, "params_shard",
shard_cap_nbytes)
- for k, v in params.items():
- shape = list(v.shape)
-
+ for k, origin_v in params.items():
+ shape = list(origin_v.shape)
+ v = origin_v
if not isinstance(v, np.ndarray):
v = v.numpy()
+ # prefer to preserve original dtype, especially if the format was bfloat16
+ dtype = str(origin_v.dtype) if isinstance(origin_v, tvm.nd.NDArray)
else v.dtype
+
# convert fp32 to bf16
- if encode_format == "f32-to-bf16" and v.dtype == "float32":
+ if encode_format == "f32-to-bf16" and dtype == "float32":
data = _convert_f32_to_bf16(v).tobytes()
- dtype = "bfloat16"
f32_to_bf16_triggered = True
else:
data = v.tobytes()
- dtype = str(v.dtype)
shard_manager.append(data, name=k, shape=shape, dtype=dtype,
encode_format=encode_format)
@@ -263,9 +264,12 @@ def load_ndarray_cache(cachepath: str, device: tvm.runtime.Device):
arr = tvm.nd.empty(shape, dtype, device=device)
assert offset + nbytes <= len(raw_data)
buffer_source = raw_data[offset : offset + nbytes]
- if encode_format == "f32-to-bf16":
+ if encode_format == "f32-to-bf16" and dtype == "float32":
data = np.frombuffer(buffer_source,
dtype="uint16").reshape(shape)
arr.copyfrom(_convert_bf16_to_f32(data))
+ elif dtype == "bfloat16":
+ data = np.frombuffer(buffer_source,
dtype="uint16").reshape(shape)
+ arr.copyfrom(data)
else:
data = np.frombuffer(buffer_source, dtype=dtype).reshape(shape)
arr.copyfrom(data)