This is an automated email from the ASF dual-hosted git repository.
anijain2305 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/main by this push:
new d24634a TF argmax - handling int64 datatype (#6674)
d24634a is described below
commit d24634a7cc3c502176838d238af42cf7c94defbe
Author: Animesh Jain <[email protected]>
AuthorDate: Mon Oct 12 20:36:00 2020 -0700
TF argmax - handling int64 datatype (#6674)
Co-authored-by: Ubuntu <[email protected]>
---
python/tvm/relay/frontend/tensorflow.py | 6 +++++-
tests/python/frontend/tensorflow/test_forward.py | 12 ++++++------
2 files changed, 11 insertions(+), 7 deletions(-)
diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py
index c7e8c00..3df582a 100644
--- a/python/tvm/relay/frontend/tensorflow.py
+++ b/python/tvm/relay/frontend/tensorflow.py
@@ -146,7 +146,11 @@ def _argx(func, func_name):
raise TypeError(
"Unsupported argument for `{}` : `axis` should be a
constant".format(func_name)
)
- return func(inputs[0], axis=axis_input_value, keepdims=False)
+ out = func(inputs[0], axis=axis_input_value, keepdims=False)
+ dtype = attr["output_type"].name
+ if dtype != "int32":
+ out = _op.cast(out, dtype=dtype)
+ return out
return _impl
diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py
index fb4c104..8e347e7 100644
--- a/tests/python/frontend/tensorflow/test_forward.py
+++ b/tests/python/frontend/tensorflow/test_forward.py
@@ -1601,16 +1601,16 @@ def _test_argx(func, data, **kwargs):
with tf.Graph().as_default():
inp = array_ops.placeholder(shape=data.shape, dtype=data.dtype,
name="c0")
- func(inp, name="argx0", output_type=tf.int32, **kwargs)
-
+ func(inp, name="argx0", **kwargs)
compare_tf_with_tvm(data, "c0:0", "argx0:0")
def test_forward_argminmax():
- for axis in [None, 0, 1, 2]:
- data = np.random.uniform(size=(8, 4, 9)).astype("float32")
- _test_argx(tf.argmax, data=data, axis=axis)
- _test_argx(tf.argmin, data=data, axis=axis)
+ for output_type in [tf.int64, tf.int32]:
+ for axis in [None, 0, 1, 2]:
+ data = np.random.uniform(size=(8, 4, 9)).astype("float32")
+ _test_argx(tf.argmax, data=data, axis=axis,
output_type=output_type)
+ _test_argx(tf.argmin, data=data, axis=axis,
output_type=output_type)
#######################################################################