This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 663f7ae77b [Fix][Relay] Fix axis transformation in squeeze shape function (#14135)
663f7ae77b is described below

commit 663f7ae77b1675949062b0574e45420235118a92
Author: Yuchao Zhang <[email protected]>
AuthorDate: Tue Feb 28 03:14:53 2023 +0800

    [Fix][Relay] Fix axis transformation in squeeze shape function (#14135)
    
    * fix squeeze shape function issue and add testcase.
    
    * fix lint
---
 python/tvm/relay/op/_transform.py |  4 +++-
 tests/python/relay/test_any.py    | 14 ++++++++++++++
 2 files changed, 17 insertions(+), 1 deletion(-)

diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py
index f28c28ce62..12450dc809 100644
--- a/python/tvm/relay/op/_transform.py
+++ b/python/tvm/relay/op/_transform.py
@@ -920,7 +920,9 @@ def squeeze_shape_func(attrs, inputs, _):
     keep_axes = []
     remove_axes = []
     if axis is not None:
-        for i in range(inputs[0].shape[0].value):
+        ndim = inputs[0].shape[0].value
+        axis = [i + ndim if i < 0 else i for i in axis]
+        for i in range(ndim):
             if i not in axis:
                 keep_axes.append(i)
             else:
diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py
index 37aa2271a5..86af1ad1c3 100644
--- a/tests/python/relay/test_any.py
+++ b/tests/python/relay/test_any.py
@@ -501,6 +501,18 @@ def verify_any_squeeze(data_shape, axis, static_data_shape):
     check_result([data_np], mod, ref_out)
 
 
+def verify_any_squeeze_sqrt(data_shape, axis, static_data_shape):
+    mod = tvm.IRModule()
+    dtype = "float32"
+    data = relay.var("data", shape=data_shape, dtype=dtype)
+    y = relay.squeeze(data, axis=axis)
+    y = relay.sqrt(y)
+    mod["main"] = relay.Function([data], y)
+    data_np = np.random.uniform(size=static_data_shape).astype(dtype)
+    ref_out = np.sqrt(np.squeeze(data_np, axis))
+    check_result([data_np], mod, ref_out)
+
+
 @tvm.testing.uses_gpu
 def test_any_squeeze():
     verify_any_squeeze((relay.Any(), relay.Any(), relay.Any()), (0,), (1, 9, 
8))
@@ -508,6 +520,8 @@ def test_any_squeeze():
     verify_any_squeeze(
         (1, relay.Any(), relay.Any(), 1, relay.Any(), relay.Any()), (0, 3), 
(1, 12, 2, 1, 9, 17)
     )
+    verify_any_squeeze_sqrt((1, relay.Any(), 12, 32, 1), (-1,), (1, 100, 12, 
32, 1))
+    verify_any_squeeze_sqrt((relay.Any(), relay.Any(), relay.Any(), 1), (-1,), 
(1, 9, 8, 1))
 
 
 @tvm.testing.uses_gpu

Reply via email to