This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 2cc0451 added support for arg type of numeric float16 and testcase,
fixed the CI error (#10797)
2cc0451 is described below
commit 2cc0451c282e7de9e529c635dfa774f63329e359
Author: mawnja <[email protected]>
AuthorDate: Mon Mar 28 05:14:09 2022 +0800
added support for arg type of numeric float16 and testcase, fixed the
CI error (#10797)
---
python/tvm/_ffi/base.py | 2 +-
python/tvm/runtime/vm.py | 10 +++++++++-
tests/python/relay/test_vm.py | 20 ++++++++++++++++++++
3 files changed, 30 insertions(+), 2 deletions(-)
diff --git a/python/tvm/_ffi/base.py b/python/tvm/_ffi/base.py
index 164ea53..e4e1fb1 100644
--- a/python/tvm/_ffi/base.py
+++ b/python/tvm/_ffi/base.py
@@ -28,7 +28,7 @@ from . import libinfo
# ----------------------------
string_types = (str,)
integer_types = (int, np.int32)
-numeric_types = integer_types + (float, np.float32)
+numeric_types = integer_types + (float, np.float16, np.float32)
# this function is needed for python3
# to convert ctypes.char_p .value back to python str
diff --git a/python/tvm/runtime/vm.py b/python/tvm/runtime/vm.py
index 27fd5af..0592368 100644
--- a/python/tvm/runtime/vm.py
+++ b/python/tvm/runtime/vm.py
@@ -32,6 +32,14 @@ from ..rpc.base import RPC_SESS_MASK
def _convert(arg, cargs):
+ def _gettype(arg):
+ if isinstance(arg, np.float16):
+ return "float16"
+ elif isinstance(arg, (_base.integer_types, bool)):
+ return "int32"
+ else:
+ return "float32"
+
if isinstance(arg, Object):
cargs.append(arg)
elif isinstance(arg, np.ndarray):
@@ -45,7 +53,7 @@ def _convert(arg, cargs):
_convert(field, field_args)
cargs.append(container.tuple_object(field_args))
elif isinstance(arg, (_base.numeric_types, bool)):
- dtype = "int32" if isinstance(arg, (_base.integer_types, bool)) else "float32"
+ dtype = _gettype(arg)
value = tvm.nd.array(np.array(arg, dtype=dtype), device=tvm.cpu(0))
cargs.append(value)
elif isinstance(arg, str):
diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py
index e4666c6..cde7806 100644
--- a/tests/python/relay/test_vm.py
+++ b/tests/python/relay/test_vm.py
@@ -619,6 +619,26 @@ def test_add_op_scalar(target, dev):
check_result(target, dev, [x_data, y_data], x_data + y_data, mod)
+def test_add_op_scalar_float16(target, dev):
+ """
+ test_add_op_scalar_float16:
+ fn (x, y) {
+ return x + y;
+ }
+ """
+ mod = tvm.IRModule()
+ x = relay.var("x", shape=(), dtype="float16") # Default to float16
+ y = relay.var("y", shape=(), dtype="float16") # Default to float16
+ func = relay.Function([x, y], relay.op.add(x, y))
+ x_y_data = [
+ (np.array(10.0, dtype="float16"), np.array(1.0, dtype="float16")),
+ (np.float16(10.0), np.float16(1.0)),
+ ]
+ for (x_data, y_data) in x_y_data:
+ mod["main"] = func
+ check_result(target, dev, [x_data, y_data], x_data + y_data, mod)
+
+
def test_add_op_scalar_int(target, dev):
"""
test_add_op_scalar_int: