This is an automated email from the ASF dual-hosted git repository. anijain2305 pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/main by this push:
     new d24634a  TF argmax - handling int64 datatype (#6674)

d24634a is described below

commit d24634a7cc3c502176838d238af42cf7c94defbe
Author: Animesh Jain <anij...@umich.edu>
AuthorDate: Mon Oct 12 20:36:00 2020 -0700

    TF argmax - handling int64 datatype (#6674)

    Co-authored-by: Ubuntu <ubu...@ip-172-31-0-202.us-west-2.compute.internal>
---
 python/tvm/relay/frontend/tensorflow.py          |  6 +++++-
 tests/python/frontend/tensorflow/test_forward.py | 12 ++++++------
 2 files changed, 11 insertions(+), 7 deletions(-)

diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py
index c7e8c00..3df582a 100644
--- a/python/tvm/relay/frontend/tensorflow.py
+++ b/python/tvm/relay/frontend/tensorflow.py
@@ -146,7 +146,11 @@ def _argx(func, func_name):
             raise TypeError(
                 "Unsupported argument for `{}` : `axis` should be a constant".format(func_name)
             )
-        return func(inputs[0], axis=axis_input_value, keepdims=False)
+        out = func(inputs[0], axis=axis_input_value, keepdims=False)
+        dtype = attr["output_type"].name
+        if dtype != "int32":
+            out = _op.cast(out, dtype=dtype)
+        return out
 
     return _impl
 
diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py
index fb4c104..8e347e7 100644
--- a/tests/python/frontend/tensorflow/test_forward.py
+++ b/tests/python/frontend/tensorflow/test_forward.py
@@ -1601,16 +1601,16 @@ def _test_argx(func, data, **kwargs):
 
     with tf.Graph().as_default():
         inp = array_ops.placeholder(shape=data.shape, dtype=data.dtype, name="c0")
-        func(inp, name="argx0", output_type=tf.int32, **kwargs)
-
+        func(inp, name="argx0", **kwargs)
         compare_tf_with_tvm(data, "c0:0", "argx0:0")
 
 
 def test_forward_argminmax():
-    for axis in [None, 0, 1, 2]:
-        data = np.random.uniform(size=(8, 4, 9)).astype("float32")
-        _test_argx(tf.argmax, data=data, axis=axis)
-        _test_argx(tf.argmin, data=data, axis=axis)
+    for output_type in [tf.int64, tf.int32]:
+        for axis in [None, 0, 1, 2]:
+            data = np.random.uniform(size=(8, 4, 9)).astype("float32")
+            _test_argx(tf.argmax, data=data, axis=axis, output_type=output_type)
+            _test_argx(tf.argmin, data=data, axis=axis, output_type=output_type)
 
 
 #######################################################################