This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 5cf3405033 [Frontend][Tensorflow] Update Select to SelectV2 (#13884)
5cf3405033 is described below

commit 5cf3405033127fe14238b23e287ffe3ab0bb967d
Author: balaram-cadence <[email protected]>
AuthorDate: Thu Feb 9 05:03:51 2023 -0600

    [Frontend][Tensorflow] Update Select to SelectV2 (#13884)
    
    Fixes #13855
---
 python/tvm/relay/frontend/tensorflow_ops.py      | 22 ++++++++++++++++++++--
 tests/python/frontend/tensorflow/test_forward.py | 23 +++++++++++++++++++++++
 2 files changed, 43 insertions(+), 2 deletions(-)

diff --git a/python/tvm/relay/frontend/tensorflow_ops.py b/python/tvm/relay/frontend/tensorflow_ops.py
index e9bb15e1d1..ab773f9a2a 100644
--- a/python/tvm/relay/frontend/tensorflow_ops.py
+++ b/python/tvm/relay/frontend/tensorflow_ops.py
@@ -2380,6 +2380,24 @@ def _transpose():
 
 
 def _where():
+    def _impl(inputs, attr, params, mod):
+        if len(inputs) == 1:
+            return AttrCvt(op_name="argwhere")(inputs, attr)
+        cond_shape = _infer_shape(inputs[0], mod)
+        x_shape = _infer_shape(inputs[1], mod)
+        # Due to difference in broadcast behavior between Select and SelectV2,
+        # we adjust condition dimension with expand_dim and then broadcast.
+        if len(cond_shape) == 1 and cond_shape[0] == x_shape[0]:
+            for _ in range(len(x_shape) - 1):
+                inputs[0] = _op.expand_dims(inputs[0], axis=-1)
+            broadcast_cond = _op.broadcast_to(inputs[0], x_shape)
+            inputs[0] = _op.cast(broadcast_cond, "bool")
+        return AttrCvt(op_name="where")(inputs, attr)
+
+    return _impl
+
+
+def _where_v2():
     def _impl(inputs, attr, params, mod):
         if len(inputs) == 1:
             return AttrCvt(op_name="argwhere")(inputs, attr)
@@ -3088,7 +3106,7 @@ _convert_map = {
     "Round": AttrCvt("round"),
     "Rsqrt": _rsqrt(),
     "Select": _where(),
-    "SelectV2": _where(),
+    "SelectV2": _where_v2(),
     "Selu": _selu(),
     "Shape": _shape(),
     "Sigmoid": AttrCvt("sigmoid"),
@@ -3142,6 +3160,6 @@ _convert_map = {
     "UniqueWithCounts": _unique(True),
     "Unpack": _unpack(),
     "UnravelIndex": _unravel_index(),
-    "Where": _where(),
+    "Where": _where_v2(),
     "ZerosLike": AttrCvt("zeros_like"),
 }
diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py
index 2fb7c74f60..1e1bd435d5 100755
--- a/tests/python/frontend/tensorflow/test_forward.py
+++ b/tests/python/frontend/tensorflow/test_forward.py
@@ -1234,6 +1234,29 @@ def test_forward_argwhere():
     _test_forward_where((5, 5, 5, 5, 5))
 
 
+def _test_forward_where_with_broadcast(in_shape, cond_shape):
+    choice_list = list(np.arange(10).astype("float32"))
+    t1 = np.random.choice(choice_list, size=cond_shape)
+    t2 = np.random.choice(choice_list, size=cond_shape)
+    x = np.random.choice(choice_list, size=in_shape)
+    y = np.random.choice(choice_list, size=in_shape)
+
+    with tf.Graph().as_default():
+        in1 = tf.placeholder(shape=cond_shape, dtype="float32", name="in1")
+        in2 = tf.placeholder(shape=cond_shape, dtype="float32", name="in2")
+        condition = math_ops.less(in1, in2, name="less")
+        lhs = tf.placeholder(shape=in_shape, dtype="float32", name="x")
+        rhs = tf.placeholder(shape=in_shape, dtype="float32", name="y")
+        out = tf.where(condition, lhs, rhs)
+        compare_tf_with_tvm([t1, t2, x, y], ["in1:0", "in2:0", "x:0", "y:0"], out.name)
+
+
+def test_forward_where_with_broadcast():
+    _test_forward_where_with_broadcast((5, 2), (5,))
+    _test_forward_where_with_broadcast((5, 7), (5,))
+    _test_forward_where_with_broadcast((3, 2, 5), (3,))
+
+
 #######################################################################
 # SpaceToBatchND
 # --------------

Reply via email to