This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 663f7ae77b [Fix][Relay] Fix axis transformation in squeeze shape function (#14135)
663f7ae77b is described below
commit 663f7ae77b1675949062b0574e45420235118a92
Author: Yuchao Zhang <[email protected]>
AuthorDate: Tue Feb 28 03:14:53 2023 +0800
[Fix][Relay] Fix axis transformation in squeeze shape function (#14135)
* fix squeeze shape function issue and add testcase.
* fix lint
---
python/tvm/relay/op/_transform.py | 4 +++-
tests/python/relay/test_any.py | 14 ++++++++++++++
2 files changed, 17 insertions(+), 1 deletion(-)
diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py
index f28c28ce62..12450dc809 100644
--- a/python/tvm/relay/op/_transform.py
+++ b/python/tvm/relay/op/_transform.py
@@ -920,7 +920,9 @@ def squeeze_shape_func(attrs, inputs, _):
keep_axes = []
remove_axes = []
if axis is not None:
- for i in range(inputs[0].shape[0].value):
+ ndim = inputs[0].shape[0].value
+ axis = [i + ndim if i < 0 else i for i in axis]
+ for i in range(ndim):
if i not in axis:
keep_axes.append(i)
else:
diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py
index 37aa2271a5..86af1ad1c3 100644
--- a/tests/python/relay/test_any.py
+++ b/tests/python/relay/test_any.py
@@ -501,6 +501,18 @@ def verify_any_squeeze(data_shape, axis, static_data_shape):
check_result([data_np], mod, ref_out)
+def verify_any_squeeze_sqrt(data_shape, axis, static_data_shape):
+ mod = tvm.IRModule()
+ dtype = "float32"
+ data = relay.var("data", shape=data_shape, dtype=dtype)
+ y = relay.squeeze(data, axis=axis)
+ y = relay.sqrt(y)
+ mod["main"] = relay.Function([data], y)
+ data_np = np.random.uniform(size=static_data_shape).astype(dtype)
+ ref_out = np.sqrt(np.squeeze(data_np, axis))
+ check_result([data_np], mod, ref_out)
+
+
@tvm.testing.uses_gpu
def test_any_squeeze():
verify_any_squeeze((relay.Any(), relay.Any(), relay.Any()), (0,), (1, 9,
8))
@@ -508,6 +520,8 @@ def test_any_squeeze():
verify_any_squeeze(
(1, relay.Any(), relay.Any(), 1, relay.Any(), relay.Any()), (0, 3),
(1, 12, 2, 1, 9, 17)
)
+ verify_any_squeeze_sqrt((1, relay.Any(), 12, 32, 1), (-1,), (1, 100, 12,
32, 1))
+ verify_any_squeeze_sqrt((relay.Any(), relay.Any(), relay.Any(), 1), (-1,),
(1, 9, 8, 1))
@tvm.testing.uses_gpu