This is an automated email from the ASF dual-hosted git repository.

zhic pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 7f86987  [Frontend][Tensorflow] Support range like axis in tf.raw_ops.All for TF 2.x (#7502)
7f86987 is described below

commit 7f869879d27c0055168f25b41447f16acf8b58fd
Author: Xingyu Zhou <[email protected]>
AuthorDate: Wed Feb 24 15:25:44 2021 -0800

    [Frontend][Tensorflow] Support range like axis in tf.raw_ops.All for TF 2.x (#7502)
    
    * add TF2.x raw_ops.all axis range support
    
    * apply linting
    
    * fix range() func input
---
 python/tvm/relay/frontend/tensorflow.py          | 10 ++++++
 tests/python/frontend/tensorflow/test_forward.py | 39 ++++++++++++++++++++++++
 2 files changed, 49 insertions(+)

diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py
index ac52ab7..3a3c5fc 100644
--- a/python/tvm/relay/frontend/tensorflow.py
+++ b/python/tvm/relay/frontend/tensorflow.py
@@ -1976,6 +1976,16 @@ def _range():
                 # Symbolic delta
                 delta = inputs[2]
 
+        # if all attributes are constant, evaluate the range function and return relay.const
+        if all(
+            [
+                isinstance(start, (np.int32, np.int64, int, np.float32, np.float64, float)),
+                isinstance(limit, (np.int32, np.int64, int, np.float32, np.float64, float)),
+                isinstance(delta, (np.int32, np.int64, int, np.float32, np.float64, float)),
+            ]
+        ):
+            return tvm.relay.const(list(range(int(start), int(limit), int(delta))))
+
         dtype = attr["Tidx"].name if "Tidx" in attr else str(start.dtype)
         if isinstance(start, (np.int32, np.int64, int, np.float32, np.float64, float)):
             start = _expr.const(start, dtype=dtype)
diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py
index ecf6441..d0038ca 100644
--- a/tests/python/frontend/tensorflow/test_forward.py
+++ b/tests/python/frontend/tensorflow/test_forward.py
@@ -3949,6 +3949,45 @@ def test_forward_reduce():
 
 
 #######################################################################
+# All, Max, Min
+# ------------------------------------------------------------------
+
+
+def test_forward_raw_reduce():
+    def _check_op(tf_op, ishape, axis, keepdims, range_axis=False, dtype="float32"):
+        tf.reset_default_graph()
+        if dtype == "bool":
+            np_data = np.random.choice([True, False], size=ishape)
+        else:
+            np_data = np.random.uniform(size=ishape).astype(dtype)
+        if tf_op == tf.math.reduce_prod:
+            axis = 1
+            np_data = np_data.reshape(1, -1)
+        with tf.Graph().as_default():
+            if range_axis:
+                axis = tf.range(axis[0], axis[1], axis[2], name="range", dtype="int32")
+            in_data = tf.placeholder(dtype, name="in_data")
+            reduce_op = tf_op(input=in_data, axis=axis, keep_dims=keepdims, name="reduce_std")
+            compare_tf_with_tvm([np_data], ["in_data:0"], reduce_op.name)
+
+    def _test_raw_reduce_op(op, dtypes=["int32", "float32"]):
+        for dtype in dtypes:
+            _check_op(op, (3, 10), axis=(-1), keepdims=False, dtype=dtype)
+            _check_op(op, (8, 16, 32), axis=(-1), keepdims=False, dtype=dtype)
+            _check_op(op, (1, 8, 8, 3), axis=(2, 3), keepdims=True, dtype=dtype)
+            _check_op(op, (2, 3, 10, 10), axis=(1, 2), keepdims=True, dtype=dtype)
+            _check_op(op, (1, 8, 8, 3), axis=(2, 4, 1), keepdims=True, range_axis=True, dtype=dtype)
+            _check_op(
+                op, (2, 3, 10, 10), axis=(1, 3, 1), keepdims=True, range_axis=True, dtype=dtype
+            )
+
+    if package_version.parse(tf.VERSION) >= package_version.parse("2.4.1"):
+        _test_raw_reduce_op(tf.raw_ops.All, dtypes=["bool"])
+        _test_raw_reduce_op(tf.raw_ops.Max)
+        _test_raw_reduce_op(tf.raw_ops.Min)
+
+
+#######################################################################
 # Relational operators
 # --------------------
 

Reply via email to