This is an automated email from the ASF dual-hosted git repository.
zhic pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 7f86987 [Frontend][Tensorflow] Support range-like axis in tf.raw_ops.All for TF 2.x (#7502)
7f86987 is described below
commit 7f869879d27c0055168f25b41447f16acf8b58fd
Author: Xingyu Zhou <[email protected]>
AuthorDate: Wed Feb 24 15:25:44 2021 -0800
[Frontend][Tensorflow] Support range-like axis in tf.raw_ops.All for TF 2.x (#7502)
* add TF 2.x raw_ops.All axis range support
* apply linting
* fix range() function input
---
python/tvm/relay/frontend/tensorflow.py | 10 ++++++
tests/python/frontend/tensorflow/test_forward.py | 39 ++++++++++++++++++++++++
2 files changed, 49 insertions(+)
diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py
index ac52ab7..3a3c5fc 100644
--- a/python/tvm/relay/frontend/tensorflow.py
+++ b/python/tvm/relay/frontend/tensorflow.py
@@ -1976,6 +1976,16 @@ def _range():
# Symbolic delta
delta = inputs[2]
+ # if all attributes are constant, evalute the range function and
return relay.const
+ if all(
+ [
+ isinstance(start, (np.int32, np.int64, int, np.float32,
np.float64, float)),
+ isinstance(limit, (np.int32, np.int64, int, np.float32,
np.float64, float)),
+ isinstance(delta, (np.int32, np.int64, int, np.float32,
np.float64, float)),
+ ]
+ ):
+ return tvm.relay.const(list(range(int(start), int(limit),
int(delta))))
+
dtype = attr["Tidx"].name if "Tidx" in attr else str(start.dtype)
if isinstance(start, (np.int32, np.int64, int, np.float32, np.float64,
float)):
start = _expr.const(start, dtype=dtype)
diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py
index ecf6441..d0038ca 100644
--- a/tests/python/frontend/tensorflow/test_forward.py
+++ b/tests/python/frontend/tensorflow/test_forward.py
@@ -3949,6 +3949,45 @@ def test_forward_reduce():
#######################################################################
+# All, Max, Min
+# -------------
+
+
+def test_forward_raw_reduce():
+    """Compare tf.raw_ops.All / Max / Min against TVM.
+
+    Each op is checked with constant axes and with axes produced by
+    ``tf.range``, so that the converter receives a ``Range`` node as the
+    axis input instead of a constant.
+    """
+
+    def _check_op(tf_op, ishape, axis, keepdims, range_axis=False, dtype="float32"):
+        tf.reset_default_graph()
+        if dtype == "bool":
+            np_data = np.random.choice([True, False], size=ishape)
+        elif np.issubdtype(np.dtype(dtype), np.integer):
+            # uniform() samples [0, 1), which truncates to all zeros for integer
+            # dtypes; draw signed integers so Max/Min actually discriminate.
+            np_data = np.random.randint(-100, 100, size=ishape).astype(dtype)
+        else:
+            np_data = np.random.uniform(size=ishape).astype(dtype)
+        with tf.Graph().as_default():
+            if range_axis:
+                # axis holds the (start, limit, delta) arguments for tf.range.
+                axis = tf.range(axis[0], axis[1], axis[2], name="range", dtype="int32")
+            in_data = tf.placeholder(dtype, name="in_data")
+            reduce_op = tf_op(input=in_data, axis=axis, keep_dims=keepdims, name="raw_reduce")
+            compare_tf_with_tvm([np_data], ["in_data:0"], reduce_op.name)
+
+    def _test_raw_reduce_op(op, dtypes=("int32", "float32")):
+        for dtype in dtypes:
+            _check_op(op, (3, 10), axis=-1, keepdims=False, dtype=dtype)
+            _check_op(op, (8, 16, 32), axis=-1, keepdims=False, dtype=dtype)
+            _check_op(op, (1, 8, 8, 3), axis=(2, 3), keepdims=True, dtype=dtype)
+            _check_op(op, (2, 3, 10, 10), axis=(1, 2), keepdims=True, dtype=dtype)
+            # range axes: (2, 4, 1) -> [2, 3] and (1, 3, 1) -> [1, 2]
+            _check_op(op, (1, 8, 8, 3), axis=(2, 4, 1), keepdims=True, range_axis=True, dtype=dtype)
+            _check_op(
+                op, (2, 3, 10, 10), axis=(1, 3, 1), keepdims=True, range_axis=True, dtype=dtype
+            )
+
+    # Only exercised on TF >= 2.4.1; on older versions this test does nothing.
+    if package_version.parse(tf.VERSION) >= package_version.parse("2.4.1"):
+        _test_raw_reduce_op(tf.raw_ops.All, dtypes=["bool"])
+        _test_raw_reduce_op(tf.raw_ops.Max)
+        _test_raw_reduce_op(tf.raw_ops.Min)
+
+
+#######################################################################
# Relational operators
# --------------------