This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new c95d16e [Frontend][Tensorflow2] Stridedslice and concat_v2 fix (#8483)
c95d16e is described below
commit c95d16e097b8c8be5322a6c92d063bf3ae78eddb
Author: srinidhigoud <[email protected]>
AuthorDate: Fri Jul 16 23:26:41 2021 -0700
[Frontend][Tensorflow2] Stridedslice and concat_v2 fix (#8483)
* fix for strided_slice when begin > end in case of shrinkaxis_mask
* fix for name_hint missing error for concat_v2 op
* removing a local fix
* adding more testing capability to concat_v2
---
python/tvm/relay/frontend/tensorflow_ops.py | 10 ++++++++--
tests/python/frontend/tensorflow/test_forward.py | 2 ++
tests/python/frontend/tensorflow2/test_functional_models.py | 3 ++-
3 files changed, 12 insertions(+), 3 deletions(-)
diff --git a/python/tvm/relay/frontend/tensorflow_ops.py b/python/tvm/relay/frontend/tensorflow_ops.py
index 797ff51..ba0fcca 100644
--- a/python/tvm/relay/frontend/tensorflow_ops.py
+++ b/python/tvm/relay/frontend/tensorflow_ops.py
@@ -1483,7 +1483,13 @@ def _identityn():
def _concatV2():
def _impl(inputs, attr, params, mod):
pop_node = inputs.pop(len(inputs) - 1)
- axis = int(_get_num_param(params, pop_node))
+ try:
+ axis = int(_get_num_param(params, pop_node))
+ except (IndexError, KeyError, AttributeError):
+ try:
+ axis = int(_infer_value(pop_node, params, mod).numpy())
+ except Exception:
+ axis = int(pop_node)
return AttrCvt(op_name="concatenate", ignores=["T", "N", "Tidx"],
extras={"axis": axis})(
[inputs], attr
)
@@ -2244,7 +2250,7 @@ def _stridedSlice():
if begin[index] < 0
else begin[index]
)
- m_end[final_index] = begin[index] + 1
+ m_end[final_index] = m_begin[final_index] + 1
m_stride[final_index] = 1
fshape_indices.append(-2)
else:
diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py
index 583014f..9bbb6ca 100644
--- a/tests/python/frontend/tensorflow/test_forward.py
+++ b/tests/python/frontend/tensorflow/test_forward.py
@@ -2553,7 +2553,9 @@ def test_forward_stridedslice():
_test_stridedslice([], [0], [0], [1], "float32", new_axis_mask=1)
_test_stridedslice([2], [1], [1], [1], "float32", shrink_axis_mask=1)
+ _test_stridedslice([4], [-1], [0], [1], "float32", shrink_axis_mask=1)
_test_stridedslice([2, 1], [0], [1], [1], "float32", shrink_axis_mask=1)
+ _test_stridedslice([2, 3, 4], [-2], [0], [1], "float32",
shrink_axis_mask=8)
_test_stridedslice([2, 3, 4], [0], [1], [1], "float32", shrink_axis_mask=8)
_test_stridedslice([3, 4, 3], [1, -1, 0], [4, -5, 3], [2, -1, 1],
"float32")
_test_stridedslice([3, 4, 3], [1, 0], [4, 3], [2, 1], "float32",
ellipsis_mask=8)
diff --git a/tests/python/frontend/tensorflow2/test_functional_models.py b/tests/python/frontend/tensorflow2/test_functional_models.py
index b3504ff..a39ecb4 100644
--- a/tests/python/frontend/tensorflow2/test_functional_models.py
+++ b/tests/python/frontend/tensorflow2/test_functional_models.py
@@ -354,7 +354,8 @@ def test_concat_v2():
@tf.function(input_signature=[tf.TensorSpec(shape=(1, 30),
dtype=tf.float32)])
def func(self, x):
a, b, c = tf.split(x, 3, axis=1)
- return tf.raw_ops.ConcatV2(values=[a, b, c], axis=1)
+ axis = tf.add(tf.constant(1, dtype="int32"), tf.constant(0,
dtype="int32"))
+ return tf.raw_ops.ConcatV2(values=[a, b, c], axis=axis)
run_all(ConcatV2)