This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 2cc0451  added support for arg type of numeric float16 and testcase, fixed the CI error (#10797)
2cc0451 is described below

commit 2cc0451c282e7de9e529c635dfa774f63329e359
Author: mawnja <[email protected]>
AuthorDate: Mon Mar 28 05:14:09 2022 +0800

    added support for arg type of numeric float16 and testcase, fixed the
    CI error (#10797)
---
 python/tvm/_ffi/base.py       |  2 +-
 python/tvm/runtime/vm.py      | 10 +++++++++-
 tests/python/relay/test_vm.py | 20 ++++++++++++++++++++
 3 files changed, 30 insertions(+), 2 deletions(-)

diff --git a/python/tvm/_ffi/base.py b/python/tvm/_ffi/base.py
index 164ea53..e4e1fb1 100644
--- a/python/tvm/_ffi/base.py
+++ b/python/tvm/_ffi/base.py
@@ -28,7 +28,7 @@ from . import libinfo
 # ----------------------------
 string_types = (str,)
 integer_types = (int, np.int32)
-numeric_types = integer_types + (float, np.float32)
+numeric_types = integer_types + (float, np.float16, np.float32)
 
 # this function is needed for python3
 # to convert ctypes.char_p .value back to python str
diff --git a/python/tvm/runtime/vm.py b/python/tvm/runtime/vm.py
index 27fd5af..0592368 100644
--- a/python/tvm/runtime/vm.py
+++ b/python/tvm/runtime/vm.py
@@ -32,6 +32,14 @@ from ..rpc.base import RPC_SESS_MASK
 
 
 def _convert(arg, cargs):
+    def _gettype(arg):
+        if isinstance(arg, np.float16):
+            return "float16"
+        elif isinstance(arg, (_base.integer_types, bool)):
+            return "int32"
+        else:
+            return "float32"
+
     if isinstance(arg, Object):
         cargs.append(arg)
     elif isinstance(arg, np.ndarray):
@@ -45,7 +53,7 @@ def _convert(arg, cargs):
             _convert(field, field_args)
         cargs.append(container.tuple_object(field_args))
     elif isinstance(arg, (_base.numeric_types, bool)):
-        dtype = "int32" if isinstance(arg, (_base.integer_types, bool)) else "float32"
+        dtype = _gettype(arg)
         value = tvm.nd.array(np.array(arg, dtype=dtype), device=tvm.cpu(0))
         cargs.append(value)
     elif isinstance(arg, str):
diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py
index e4666c6..cde7806 100644
--- a/tests/python/relay/test_vm.py
+++ b/tests/python/relay/test_vm.py
@@ -619,6 +619,26 @@ def test_add_op_scalar(target, dev):
         check_result(target, dev, [x_data, y_data], x_data + y_data, mod)
 
 
+def test_add_op_scalar_float16(target, dev):
+    """
+    test_add_op_scalar_float16:
+        fn (x, y) {
+            return x + y;
+        }
+    """
+    mod = tvm.IRModule()
+    x = relay.var("x", shape=(), dtype="float16")  # Default to float16
+    y = relay.var("y", shape=(), dtype="float16")  # Default to float16
+    func = relay.Function([x, y], relay.op.add(x, y))
+    x_y_data = [
+        (np.array(10.0, dtype="float16"), np.array(1.0, dtype="float16")),
+        (np.float16(10.0), np.float16(1.0)),
+    ]
+    for (x_data, y_data) in x_y_data:
+        mod["main"] = func
+        check_result(target, dev, [x_data, y_data], x_data + y_data, mod)
+
+
 def test_add_op_scalar_int(target, dev):
     """
     test_add_op_scalar_int:

Reply via email to