This is an automated email from the ASF dual-hosted git repository.
andrewzhaoluo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new eed21eeff2 [ONNX] Fix cast op to/from bfloat16 (#11171)
eed21eeff2 is described below
commit eed21eeff2e1bb746b0ae1e7bb5831903e147e77
Author: Margaret Qian <[email protected]>
AuthorDate: Sun May 15 21:34:37 2022 -0700
[ONNX] Fix cast op to/from bfloat16 (#11171)
* fix cast from bfloat16
* fix cast to bfloat16 test as well
* clean up comments
* lint
* add comment
Co-authored-by: Margaret Qian <[email protected]>
---
python/tvm/relay/frontend/onnx.py | 28 ++++++++++++++++++++++++----
python/tvm/runtime/ndarray.py | 2 ++
tests/python/frontend/onnx/test_forward.py | 8 +++++---
3 files changed, 31 insertions(+), 7 deletions(-)
diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index 036b5a9146..233067959f 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -102,6 +102,16 @@ def get_type(elem_type):
except ImportError as e:
raise ImportError("Unable to import onnx which is required {}".format(e))
+ try:
+ from onnx import TensorProto
+ except ImportError as e:
+ raise ImportError("Unable to import TensorProto from onnx {}".format(e))
+
+ # Onnx mapping converts bfloat16 to float16 because
+ # numpy does not have a bfloat16 data type. However,
+ # tvm has one, so we force the return type to be bfloat16
+ if elem_type == int(TensorProto.BFLOAT16):
+ return "bfloat16"
return str(TENSOR_TYPE_TO_NP_TYPE[elem_type])
@@ -1703,11 +1713,21 @@ class Cast(OnnxOpConverter):
@classmethod
def _impl_v5(cls, inputs, attr, params):
try:
- from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
-
- attr["to"] = str(TENSOR_TYPE_TO_NP_TYPE[attr["to"]])
+ from onnx import TensorProto
except ImportError as e:
- raise ImportError("Unable to import onnx.mapping which is required {}".format(e))
+ raise ImportError("Unable to import TensorProto from onnx {}".format(e))
+
+ # If onnx mapping is used, bfloat16 gets converted to float16
+ # which is not the desired behavior
+ if attr["to"] == int(TensorProto.BFLOAT16):
+ attr["to"] = "bfloat16"
+ else:
+ try:
+ from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
+
+ attr["to"] = str(TENSOR_TYPE_TO_NP_TYPE[attr["to"]])
+ except ImportError as e:
+ raise ImportError("Unable to import onnx.mapping which is required {}".format(e))
return AttrCvt(op_name="cast", transforms={"to": "dtype"})(inputs, attr)
diff --git a/python/tvm/runtime/ndarray.py b/python/tvm/runtime/ndarray.py
index 97f37c9985..3d4764d616 100644
--- a/python/tvm/runtime/ndarray.py
+++ b/python/tvm/runtime/ndarray.py
@@ -218,6 +218,8 @@ class NDArray(NDArrayBase):
dtype = str(t)
if dtype == "int4":
dtype = "int8"
+ if dtype == "bfloat16":
+ dtype = "uint16"
np_arr = np.empty(shape, dtype=dtype)
assert np_arr.flags["C_CONTIGUOUS"]
data = np_arr.ctypes.data_as(ctypes.c_void_p)
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index 03f0cb3bad..ec5d2b6ae2 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -5033,9 +5033,7 @@ unsupported_onnx_tests = [
"test_bernoulli_double_expanded",
"test_bernoulli_seed",
"test_bernoulli_seed_expanded",
- "test_cast_BFLOAT16_to_FLOAT",
"test_cast_DOUBLE_to_FLOAT16",
- "test_cast_FLOAT_to_BFLOAT16",
"test_cast_FLOAT_to_STRING",
"test_cast_STRING_to_FLOAT",
"test_castlike_BFLOAT16_to_FLOAT",
@@ -5185,6 +5183,11 @@ def test_onnx_nodes(target, dev, onnx_test):
# roialign results to 4 decimal places
atol = 1e-4
+ if "to_BFLOAT16" in test_dir:
+ # the tolerance here is for the comparison in uint16 space, but is not as significant
+ # of a delta in bfloat16 space because it's representing the mantissa being off by 1
+ atol = 1
+
if "_sce_" in test_dir:
# complicated loss functions like SoftmaxCrossEntropy can have minor variations
# in accuracy depending on implementation
@@ -5205,7 +5208,6 @@ def test_onnx_nodes(target, dev, onnx_test):
outputs.append(numpy_helper.to_array(new_tensor))
else:
raise ImportError(str(tensor) + " not labeled as an import or an output")
-
tvm_val = get_tvm_output_with_vm(onnx_model, inputs, target, dev)
if len(outputs) == 1:
tvm.testing.assert_allclose(outputs[0], tvm_val, rtol=rtol, atol=atol)