This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new c95d16e  [Frontend][Tensorflow2] Stridedslice and concat_v2 fix (#8483)
c95d16e is described below

commit c95d16e097b8c8be5322a6c92d063bf3ae78eddb
Author: srinidhigoud <[email protected]>
AuthorDate: Fri Jul 16 23:26:41 2021 -0700

    [Frontend][Tensorflow2] Stridedslice and concat_v2 fix (#8483)
    
    * fix for strided_slice with shrink_axis_mask when begin is negative (e.g. begin > end)
    
    * fix for missing name_hint error in the concat_v2 op
    
    * removing a local fix
    
    * adding more test coverage for concat_v2 (non-constant axis)
---
 python/tvm/relay/frontend/tensorflow_ops.py                 | 10 ++++++++--
 tests/python/frontend/tensorflow/test_forward.py            |  2 ++
 tests/python/frontend/tensorflow2/test_functional_models.py |  3 ++-
 3 files changed, 12 insertions(+), 3 deletions(-)

diff --git a/python/tvm/relay/frontend/tensorflow_ops.py 
b/python/tvm/relay/frontend/tensorflow_ops.py
index 797ff51..ba0fcca 100644
--- a/python/tvm/relay/frontend/tensorflow_ops.py
+++ b/python/tvm/relay/frontend/tensorflow_ops.py
@@ -1483,7 +1483,13 @@ def _identityn():
 def _concatV2():
     def _impl(inputs, attr, params, mod):
         pop_node = inputs.pop(len(inputs) - 1)
-        axis = int(_get_num_param(params, pop_node))
+        try:
+            axis = int(_get_num_param(params, pop_node))
+        except (IndexError, KeyError, AttributeError):
+            try:
+                axis = int(_infer_value(pop_node, params, mod).numpy())
+            except Exception:
+                axis = int(pop_node)
         return AttrCvt(op_name="concatenate", ignores=["T", "N", "Tidx"], 
extras={"axis": axis})(
             [inputs], attr
         )
@@ -2244,7 +2250,7 @@ def _stridedSlice():
                             if begin[index] < 0
                             else begin[index]
                         )
-                        m_end[final_index] = begin[index] + 1
+                        m_end[final_index] = m_begin[final_index] + 1
                         m_stride[final_index] = 1
                         fshape_indices.append(-2)
                     else:
diff --git a/tests/python/frontend/tensorflow/test_forward.py 
b/tests/python/frontend/tensorflow/test_forward.py
index 583014f..9bbb6ca 100644
--- a/tests/python/frontend/tensorflow/test_forward.py
+++ b/tests/python/frontend/tensorflow/test_forward.py
@@ -2553,7 +2553,9 @@ def test_forward_stridedslice():
 
     _test_stridedslice([], [0], [0], [1], "float32", new_axis_mask=1)
     _test_stridedslice([2], [1], [1], [1], "float32", shrink_axis_mask=1)
+    _test_stridedslice([4], [-1], [0], [1], "float32", shrink_axis_mask=1)
     _test_stridedslice([2, 1], [0], [1], [1], "float32", shrink_axis_mask=1)
+    _test_stridedslice([2, 3, 4], [-2], [0], [1], "float32", 
shrink_axis_mask=8)
     _test_stridedslice([2, 3, 4], [0], [1], [1], "float32", shrink_axis_mask=8)
     _test_stridedslice([3, 4, 3], [1, -1, 0], [4, -5, 3], [2, -1, 1], 
"float32")
     _test_stridedslice([3, 4, 3], [1, 0], [4, 3], [2, 1], "float32", 
ellipsis_mask=8)
diff --git a/tests/python/frontend/tensorflow2/test_functional_models.py 
b/tests/python/frontend/tensorflow2/test_functional_models.py
index b3504ff..a39ecb4 100644
--- a/tests/python/frontend/tensorflow2/test_functional_models.py
+++ b/tests/python/frontend/tensorflow2/test_functional_models.py
@@ -354,7 +354,8 @@ def test_concat_v2():
         @tf.function(input_signature=[tf.TensorSpec(shape=(1, 30), 
dtype=tf.float32)])
         def func(self, x):
             a, b, c = tf.split(x, 3, axis=1)
-            return tf.raw_ops.ConcatV2(values=[a, b, c], axis=1)
+            axis = tf.add(tf.constant(1, dtype="int32"), tf.constant(0, 
dtype="int32"))
+            return tf.raw_ops.ConcatV2(values=[a, b, c], axis=axis)
 
     run_all(ConcatV2)
 

Reply via email to