This is an automated email from the ASF dual-hosted git repository.

mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new faab2e7f27 [Relax] Fix the squeeze operator to behave consistently with torch (#18478)
faab2e7f27 is described below

commit faab2e7f27341516b574f5ef1bc00a11a2261d2a
Author: ConvolutedDog <[email protected]>
AuthorDate: Mon Nov 24 15:27:19 2025 +0800

    [Relax] Fix the squeeze operator to behave consistently with torch (#18478)
    
    This commit fixes the squeeze operator to behave consistently with
    PyTorch by implementing no-op behavior when squeezing dimensions that
    are not of size 1.
    
    Previously:
      squeeze(x, [1]) on a tensor with shape [32, 10, 5] would fail

    Now:
      squeeze(x, [1]) on a tensor with shape [32, 10, 5] returns the
      original tensor without modification, matching PyTorch's behavior
    
    This fixes compatibility issues when converting PyTorch models that use
    squeeze with dimensions that may not always be 1 during inference.
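
    For illustration, a minimal PyTorch sketch of the behavior this change
    matches (shapes chosen for illustration, not taken from the test suite):

      import torch

      x = torch.randn(32, 10, 5)
      # Squeezing a dimension whose size is not 1 is a no-op in PyTorch.
      y = x.squeeze(1)
      assert y.shape == (32, 10, 5)
      # Squeezing a size-1 dimension removes it as usual.
      z = torch.randn(3, 1, 4).squeeze(1)
      assert z.shape == (3, 4)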
    
    This work was done in collaboration with guan404ming's commit d87841d.
---
 include/tvm/topi/transform.h                          |  7 ++++---
 .../relax/frontend/torch/base_fx_graph_translator.py  |  2 +-
 src/relax/op/tensor/manipulate.cc                     | 11 +++--------
 .../relax/test_frontend_from_exported_program.py      | 19 ++++++++++++++++++-
 tests/python/relax/test_op_manipulate.py              | 18 +++++++++++++-----
 5 files changed, 39 insertions(+), 18 deletions(-)

diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h
index 2d7096613b..ef4830a46a 100644
--- a/include/tvm/topi/transform.h
+++ b/include/tvm/topi/transform.h
@@ -428,10 +428,11 @@ inline Tensor squeeze(const Tensor& x, ffi::Optional<ffi::Array<Integer>> opt_ax
       if (val < 0) {
         val += static_cast<int>(x->shape.size());
       }
-      if (IsConstInt(x->shape[val])) {
-        ICHECK_EQ(GetConstInt(x->shape[val]), 1) << "Dimension " << val << " must have size 1";
+      // If a dimension is not 1, silently skip it (no-op).
+      bool is_const = IsConstInt(x->shape[val]);
+      if ((is_const && GetConstInt(x->shape[val]) == 1) || !is_const) {
+        axis_val.push_back(val);
       }
-      axis_val.push_back(val);
     }
   }
 
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 3a3e0360af..fb8790322e 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -2003,7 +2003,7 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
             valid_dims = []
             for d in dim:
                 axis = d if d >= 0 else len(shape) + d
-                if axis < len(shape) and shape[axis] == 1:
+                if axis < len(shape):
                     valid_dims.append(d)
             # If no valid dims, use None to squeeze all size-1 dimensions
             dim = valid_dims if valid_dims else None
diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc
index 78244a8bc5..0768e899b1 100644
--- a/src/relax/op/tensor/manipulate.cc
+++ b/src/relax/op/tensor/manipulate.cc
@@ -1234,15 +1234,10 @@ StructInfo InferStructInfoSqueeze(const Call& call, const BlockBuilder& ctx) {
      // Todo(relax-team): revisit here for better check on if the axis being squeezed has length 1.
      // When `axis` is given, the dim lengths at the axes must be integer 1 when it is not symbolic
       const auto* int_len = shape_value.value()[axes[i]].as<IntImmNode>();
-      if (int_len != nullptr && int_len->value != 1) {
-        ctx->ReportFatal(Diagnostic::Error(call)
-                         << "Squeeze expects the input tensor shape values at the given axis "
-                            "positions to be all 1. However, the tensor shape at axis "
-                         << axes[i] << " is " << shape_value.value()[axes[i]]
-                         << " which is not 1. If it is symbolic, please use MatchCast to cast it "
-                            "to 1 before doing Squeeze.");
+      // If a dimension is not 1, silently skip it (no-op), matching PyTorch behavior.
+      if ((int_len != nullptr && int_len->value == 1) || int_len == nullptr) {
+        axis_removal_mask[axes[i]] = true;
       }
-      axis_removal_mask[axes[i]] = true;
     }
   } else {
     // When `axis` is not defined, squeeze all unit-length dimensions.
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 3435ac5670..89017e30a7 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -5482,15 +5482,32 @@ def test_squeeze():
             input: R.Tensor((3, 1, 4, 1), dtype="float32")
         ) -> R.Tuple(R.Tensor((3, 4), dtype="float32")):
             with R.dataflow():
-                lv: R.Tensor((3, 4), dtype="float32") = R.squeeze(input, axis=[1, 3])
+                lv: R.Tensor((3, 4), dtype="float32") = R.squeeze(input, axis=[0, 1, 2, 3])
                 gv: R.Tuple(R.Tensor((3, 4), dtype="float32")) = (lv,)
                 R.output(gv)
             return gv
 
+    class Squeeze3(Module):
+        def forward(self, input):
+            return input.squeeze(2)
+
+    @I.ir_module
+    class Expected3:
+        @R.function
+        def main(
+            inp_0: R.Tensor((3, 1, 4, 1), dtype="float32")
+        ) -> R.Tuple(R.Tensor((3, 1, 4, 1), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((3, 1, 4, 1), dtype="float32") = R.squeeze(inp_0, axis=[2])
+                gv: R.Tuple(R.Tensor((3, 1, 4, 1), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
     example_args = (torch.randn(3, 1, 4, 1, dtype=torch.float32),)
 
     verify_model(Squeeze1(), example_args, {}, Expected1)
     verify_model(Squeeze2(), example_args, {}, Expected2)
+    verify_model(Squeeze3(), example_args, {}, Expected3)
 
 
 def test_stack():
diff --git a/tests/python/relax/test_op_manipulate.py b/tests/python/relax/test_op_manipulate.py
index 004c4b9618..d39584e06b 100644
--- a/tests/python/relax/test_op_manipulate.py
+++ b/tests/python/relax/test_op_manipulate.py
@@ -994,11 +994,19 @@ def test_squeeze_infer_struct_info_axis_length_not_one():
     x2 = relax.Var("x", relax.TensorStructInfo(s0, "float32"))
     x3 = relax.Var("x", relax.TensorStructInfo(s1, "float32"))
 
-    with pytest.raises(TVMError):
-        bb.normalize(relax.op.squeeze(x0, [0]))
-    _check_inference(bb, relax.op.squeeze(x1, [0]), relax.TensorStructInfo((3, 4), "float32"))
-    with pytest.raises(TVMError):
-        bb.normalize(relax.op.squeeze(x2, [0]))
+    # Squeeze concrete shape (2,3,4) at axis=0, but axis length 2 != 1, squeeze is no-op.
+    _check_inference(
+        bb, relax.op.squeeze(x0, [0]), relax.TensorStructInfo(shape=(2, 3, 4), dtype="float32")
+    )
+    # Squeeze symbolic shape (a,3,4) at axis=0, assuming a can achieve successful squeeze.
+    _check_inference(
+        bb, relax.op.squeeze(x1, [0]), relax.TensorStructInfo(shape=(3, 4), dtype="float32")
+    )
+    # Squeeze shape variable s0 (corresponding to (2,3,4)) at axis=0.
+    _check_inference(
+        bb, relax.op.squeeze(x2, [0]), relax.TensorStructInfo(shape=s0, dtype="float32")
+    )
+    # Squeeze shape variable s1 (a,3,4) at axis=0, assuming a can achieve successful squeeze.
     _check_inference(bb, relax.op.squeeze(x3, [0]), relax.TensorStructInfo(dtype="float32", ndim=2))
 
 
