This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 5cf3405033 [Frontend][Tensorflow] Update Select to SelectV2 (#13884)
5cf3405033 is described below
commit 5cf3405033127fe14238b23e287ffe3ab0bb967d
Author: balaram-cadence <[email protected]>
AuthorDate: Thu Feb 9 05:03:51 2023 -0600
[Frontend][Tensorflow] Update Select to SelectV2 (#13884)
Fixes #13855
---
python/tvm/relay/frontend/tensorflow_ops.py | 22 ++++++++++++++++++++--
tests/python/frontend/tensorflow/test_forward.py | 23 +++++++++++++++++++++++
2 files changed, 43 insertions(+), 2 deletions(-)
diff --git a/python/tvm/relay/frontend/tensorflow_ops.py b/python/tvm/relay/frontend/tensorflow_ops.py
index e9bb15e1d1..ab773f9a2a 100644
--- a/python/tvm/relay/frontend/tensorflow_ops.py
+++ b/python/tvm/relay/frontend/tensorflow_ops.py
@@ -2380,6 +2380,24 @@ def _transpose():
def _where():
+ # Builds the converter used for the TF "Select" op (mapped in _convert_map).
+ # The returned _impl lowers a single-input node (condition only) to argwhere,
+ # and a three-input node (cond, x, y) to where(cond, x, y).
+ # A rank-1 condition whose length equals x's first dimension is aligned with
+ # axis 0 of x: trailing unit axes are appended (one per extra dim of x), the
+ # result is broadcast to x's shape and cast to bool before lowering.
+ # Note: inputs[0] is overwritten in the caller-supplied inputs list.
+ # NOTE(review): indentation was lost in this copy of the patch; the
+ # broadcast_to/cast statements after the for-loop are presumably inside the
+ # if-branch (per the explanatory comment) - confirm against upstream.
+ # NOTE(review): x_shape[0] is indexed whenever cond has rank 1, so this
+ # assumes x has rank >= 1 in that case - confirm for all Select graphs.
+ def _impl(inputs, attr, params, mod):
+ if len(inputs) == 1:
+ return AttrCvt(op_name="argwhere")(inputs, attr)
+ cond_shape = _infer_shape(inputs[0], mod)
+ x_shape = _infer_shape(inputs[1], mod)
+ # Due to difference in broadcast behavior between Select and SelectV2,
+ # we adjust condition dimension with expand_dim and then broadcast.
+ if len(cond_shape) == 1 and cond_shape[0] == x_shape[0]:
+ for _ in range(len(x_shape) - 1):
+ inputs[0] = _op.expand_dims(inputs[0], axis=-1)
+ broadcast_cond = _op.broadcast_to(inputs[0], x_shape)
+ inputs[0] = _op.cast(broadcast_cond, "bool")
+ return AttrCvt(op_name="where")(inputs, attr)
+
+ return _impl
+
+
+def _where_v2():
def _impl(inputs, attr, params, mod):
if len(inputs) == 1:
return AttrCvt(op_name="argwhere")(inputs, attr)
@@ -3088,7 +3106,7 @@ _convert_map = {
"Round": AttrCvt("round"),
"Rsqrt": _rsqrt(),
"Select": _where(),
- "SelectV2": _where(),
+ "SelectV2": _where_v2(),
"Selu": _selu(),
"Shape": _shape(),
"Sigmoid": AttrCvt("sigmoid"),
@@ -3142,6 +3160,6 @@ _convert_map = {
"UniqueWithCounts": _unique(True),
"Unpack": _unpack(),
"UnravelIndex": _unravel_index(),
- "Where": _where(),
+ "Where": _where_v2(),
"ZerosLike": AttrCvt("zeros_like"),
}
diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py
index 2fb7c74f60..1e1bd435d5 100755
--- a/tests/python/frontend/tensorflow/test_forward.py
+++ b/tests/python/frontend/tensorflow/test_forward.py
@@ -1234,6 +1234,29 @@ def test_forward_argwhere():
_test_forward_where((5, 5, 5, 5, 5))
+def _test_forward_where_with_broadcast(in_shape, cond_shape):
+ # Compare TF and TVM on tf.where(cond, x, y), where cond = (in1 < in2) has
+ # shape cond_shape and x, y have shape in_shape. Values are drawn from 0..9,
+ # so equal in1/in2 entries (cond False) can occur.
+ # NOTE(review): presumably tf here is tensorflow.compat.v1, so tf.where emits
+ # the v1 "Select" op this test targets - confirm.
+ choice_list = list(np.arange(10).astype("float32"))
+ t1 = np.random.choice(choice_list, size=cond_shape)
+ t2 = np.random.choice(choice_list, size=cond_shape)
+ x = np.random.choice(choice_list, size=in_shape)
+ y = np.random.choice(choice_list, size=in_shape)
+
+ with tf.Graph().as_default():
+ in1 = tf.placeholder(shape=cond_shape, dtype="float32", name="in1")
+ in2 = tf.placeholder(shape=cond_shape, dtype="float32", name="in2")
+ condition = math_ops.less(in1, in2, name="less")
+ lhs = tf.placeholder(shape=in_shape, dtype="float32", name="x")
+ rhs = tf.placeholder(shape=in_shape, dtype="float32", name="y")
+ out = tf.where(condition, lhs, rhs)
+ compare_tf_with_tvm([t1, t2, x, y], ["in1:0", "in2:0", "x:0", "y:0"], out.name)
+
+
+def test_forward_where_with_broadcast():
+ # Each case is (x/y shape, condition shape): a rank-1 condition whose
+ # length matches the first dimension of a 2-D or 3-D x.
+ _test_forward_where_with_broadcast((5, 2), (5,))
+ _test_forward_where_with_broadcast((5, 7), (5,))
+ _test_forward_where_with_broadcast((3, 2, 5), (3,))
+
+
#######################################################################
# SpaceToBatchND
# --------------