This is an automated email from the ASF dual-hosted git repository.

anijain2305 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new d24634a  TF argmax - handling int64 datatype (#6674)
d24634a is described below

commit d24634a7cc3c502176838d238af42cf7c94defbe
Author: Animesh Jain <anij...@umich.edu>
AuthorDate: Mon Oct 12 20:36:00 2020 -0700

    TF argmax - handling int64 datatype (#6674)
    
    Co-authored-by: Ubuntu <ubu...@ip-172-31-0-202.us-west-2.compute.internal>
---
 python/tvm/relay/frontend/tensorflow.py          |  6 +++++-
 tests/python/frontend/tensorflow/test_forward.py | 12 ++++++------
 2 files changed, 11 insertions(+), 7 deletions(-)

diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py
index c7e8c00..3df582a 100644
--- a/python/tvm/relay/frontend/tensorflow.py
+++ b/python/tvm/relay/frontend/tensorflow.py
@@ -146,7 +146,11 @@ def _argx(func, func_name):
             raise TypeError(
                 "Unsupported argument for `{}` : `axis` should be a constant".format(func_name)
             )
-        return func(inputs[0], axis=axis_input_value, keepdims=False)
+        out = func(inputs[0], axis=axis_input_value, keepdims=False)
+        dtype = attr["output_type"].name
+        if dtype != "int32":
+            out = _op.cast(out, dtype=dtype)
+        return out
 
     return _impl
 
diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py
index fb4c104..8e347e7 100644
--- a/tests/python/frontend/tensorflow/test_forward.py
+++ b/tests/python/frontend/tensorflow/test_forward.py
@@ -1601,16 +1601,16 @@ def _test_argx(func, data, **kwargs):
 
     with tf.Graph().as_default():
         inp = array_ops.placeholder(shape=data.shape, dtype=data.dtype, name="c0")
-        func(inp, name="argx0", output_type=tf.int32, **kwargs)
-
+        func(inp, name="argx0", **kwargs)
         compare_tf_with_tvm(data, "c0:0", "argx0:0")
 
 
 def test_forward_argminmax():
-    for axis in [None, 0, 1, 2]:
-        data = np.random.uniform(size=(8, 4, 9)).astype("float32")
-        _test_argx(tf.argmax, data=data, axis=axis)
-        _test_argx(tf.argmin, data=data, axis=axis)
+    for output_type in [tf.int64, tf.int32]:
+        for axis in [None, 0, 1, 2]:
+            data = np.random.uniform(size=(8, 4, 9)).astype("float32")
+            _test_argx(tf.argmax, data=data, axis=axis, output_type=output_type)
+            _test_argx(tf.argmin, data=data, axis=axis, output_type=output_type)
 
 
 #######################################################################

Reply via email to