This is an automated email from the ASF dual-hosted git repository.

andrewzhaoluo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new eed21eeff2 [ONNX] Fix cast op to/from bfloat16 (#11171)
eed21eeff2 is described below

commit eed21eeff2e1bb746b0ae1e7bb5831903e147e77
Author: Margaret Qian <[email protected]>
AuthorDate: Sun May 15 21:34:37 2022 -0700

    [ONNX] Fix cast op to/from bfloat16 (#11171)
    
    * fix cast from bfloat16
    
    * fix cast to bfloat16 test as well
    
    * clean up comments
    
    * lint
    
    * add comment
    
    Co-authored-by: Margaret Qian <[email protected]>
---
 python/tvm/relay/frontend/onnx.py          | 28 ++++++++++++++++++++++++----
 python/tvm/runtime/ndarray.py              |  2 ++
 tests/python/frontend/onnx/test_forward.py |  8 +++++---
 3 files changed, 31 insertions(+), 7 deletions(-)

diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index 036b5a9146..233067959f 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -102,6 +102,16 @@ def get_type(elem_type):
     except ImportError as e:
         raise ImportError("Unable to import onnx which is required {}".format(e))
 
+    try:
+        from onnx import TensorProto
+    except ImportError as e:
+        raise ImportError("Unable to import TensorProto from onnx {}".format(e))
+
+    # Onnx mapping converts bfloat16 to float16 because
+    # numpy does not have a bfloat16 data type. However,
+    # tvm has one, so we force the return type to be bfloat16
+    if elem_type == int(TensorProto.BFLOAT16):
+        return "bfloat16"
     return str(TENSOR_TYPE_TO_NP_TYPE[elem_type])
 
 
@@ -1703,11 +1713,21 @@ class Cast(OnnxOpConverter):
     @classmethod
     def _impl_v5(cls, inputs, attr, params):
         try:
-            from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
-
-            attr["to"] = str(TENSOR_TYPE_TO_NP_TYPE[attr["to"]])
+            from onnx import TensorProto
         except ImportError as e:
-            raise ImportError("Unable to import onnx.mapping which is required {}".format(e))
+            raise ImportError("Unable to import TensorProto from onnx {}".format(e))
+
+        # If onnx mapping is used, bfloat16 gets converted to float16
+        # which is not the desired behavior
+        if attr["to"] == int(TensorProto.BFLOAT16):
+            attr["to"] = "bfloat16"
+        else:
+            try:
+                from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
+
+                attr["to"] = str(TENSOR_TYPE_TO_NP_TYPE[attr["to"]])
+            except ImportError as e:
+                raise ImportError("Unable to import onnx.mapping which is required {}".format(e))
         return AttrCvt(op_name="cast", transforms={"to": "dtype"})(inputs, attr)
 
 
diff --git a/python/tvm/runtime/ndarray.py b/python/tvm/runtime/ndarray.py
index 97f37c9985..3d4764d616 100644
--- a/python/tvm/runtime/ndarray.py
+++ b/python/tvm/runtime/ndarray.py
@@ -218,6 +218,8 @@ class NDArray(NDArrayBase):
             dtype = str(t)
         if dtype == "int4":
             dtype = "int8"
+        if dtype == "bfloat16":
+            dtype = "uint16"
         np_arr = np.empty(shape, dtype=dtype)
         assert np_arr.flags["C_CONTIGUOUS"]
         data = np_arr.ctypes.data_as(ctypes.c_void_p)
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index 03f0cb3bad..ec5d2b6ae2 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -5033,9 +5033,7 @@ unsupported_onnx_tests = [
     "test_bernoulli_double_expanded",
     "test_bernoulli_seed",
     "test_bernoulli_seed_expanded",
-    "test_cast_BFLOAT16_to_FLOAT",
     "test_cast_DOUBLE_to_FLOAT16",
-    "test_cast_FLOAT_to_BFLOAT16",
     "test_cast_FLOAT_to_STRING",
     "test_cast_STRING_to_FLOAT",
     "test_castlike_BFLOAT16_to_FLOAT",
@@ -5185,6 +5183,11 @@ def test_onnx_nodes(target, dev, onnx_test):
         # roialign results to 4 decimal places
         atol = 1e-4
 
+    if "to_BFLOAT16" in test_dir:
+        # the tolerance here is for the comparison in uint16 space, but is not as significant
+        # of a delta in bfloat16 space because it's representing the mantissa being off by 1
+        atol = 1
+
     if "_sce_" in test_dir:
         # complicated loss functions like SoftmaxCrossEntropy can have minor variations
         # in accuracy depending on implementation
@@ -5205,7 +5208,6 @@ def test_onnx_nodes(target, dev, onnx_test):
                 outputs.append(numpy_helper.to_array(new_tensor))
             else:
                 raise ImportError(str(tensor) + " not labeled as an import or an output")
-
     tvm_val = get_tvm_output_with_vm(onnx_model, inputs, target, dev)
     if len(outputs) == 1:
         tvm.testing.assert_allclose(outputs[0], tvm_val, rtol=rtol, atol=atol)

Reply via email to