This is an automated email from the ASF dual-hosted git repository.
mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new faab2e7f27 [Relax] Fix the squeeze operator to behave consistently
with torch (#18478)
faab2e7f27 is described below
commit faab2e7f27341516b574f5ef1bc00a11a2261d2a
Author: ConvolutedDog <[email protected]>
AuthorDate: Mon Nov 24 15:27:19 2025 +0800
[Relax] Fix the squeeze operator to behave consistently with torch (#18478)
This commit fixes the squeeze operator to behave consistently with
PyTorch
by implementing no-op behavior when squeezing dimensions that are not of
size 1.
Previously:
squeeze(x, [1]) on a tensor with shape [32, 10, 5] would fail
Now:
squeeze(x, [1]) on a tensor with shape [32, 10, 5] returns the original
tensor
without modification, matching PyTorch's behavior.
This fixes compatibility issues when converting PyTorch models that use
squeeze with dimensions that may not always be 1 during inference.
This work was done in collaboration with guan404ming's commit d87841d.
---
include/tvm/topi/transform.h | 7 ++++---
.../relax/frontend/torch/base_fx_graph_translator.py | 2 +-
src/relax/op/tensor/manipulate.cc | 11 +++--------
.../relax/test_frontend_from_exported_program.py | 19 ++++++++++++++++++-
tests/python/relax/test_op_manipulate.py | 18 +++++++++++++-----
5 files changed, 39 insertions(+), 18 deletions(-)
diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h
index 2d7096613b..ef4830a46a 100644
--- a/include/tvm/topi/transform.h
+++ b/include/tvm/topi/transform.h
@@ -428,10 +428,11 @@ inline Tensor squeeze(const Tensor& x,
ffi::Optional<ffi::Array<Integer>> opt_ax
if (val < 0) {
val += static_cast<int>(x->shape.size());
}
- if (IsConstInt(x->shape[val])) {
- ICHECK_EQ(GetConstInt(x->shape[val]), 1) << "Dimension " << val << "
must have size 1";
+ // If a dimension is not 1, silently skip it (no-op).
+ bool is_const = IsConstInt(x->shape[val]);
+ if ((is_const && GetConstInt(x->shape[val]) == 1) || !is_const) {
+ axis_val.push_back(val);
}
- axis_val.push_back(val);
}
}
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 3a3e0360af..fb8790322e 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -2003,7 +2003,7 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
valid_dims = []
for d in dim:
axis = d if d >= 0 else len(shape) + d
- if axis < len(shape) and shape[axis] == 1:
+ if axis < len(shape):
valid_dims.append(d)
# If no valid dims, use None to squeeze all size-1 dimensions
dim = valid_dims if valid_dims else None
diff --git a/src/relax/op/tensor/manipulate.cc
b/src/relax/op/tensor/manipulate.cc
index 78244a8bc5..0768e899b1 100644
--- a/src/relax/op/tensor/manipulate.cc
+++ b/src/relax/op/tensor/manipulate.cc
@@ -1234,15 +1234,10 @@ StructInfo InferStructInfoSqueeze(const Call& call,
const BlockBuilder& ctx) {
// Todo(relax-team): revisit here for better check on if the axis being
squeezed has length 1.
// When `axis` is given, the dim lengths at the axes must be integer 1
when it is not symbolic
const auto* int_len = shape_value.value()[axes[i]].as<IntImmNode>();
- if (int_len != nullptr && int_len->value != 1) {
- ctx->ReportFatal(Diagnostic::Error(call)
- << "Squeeze expects the input tensor shape values at
the given axis "
- "positions to be all 1. However, the tensor shape
at axis "
- << axes[i] << " is " << shape_value.value()[axes[i]]
- << " which is not 1. If it is symbolic, please use
MatchCast to cast it "
- "to 1 before doing Squeeze.");
+ // If a dimension is not 1, silently skip it (no-op), matching PyTorch
behavior.
+ if ((int_len != nullptr && int_len->value == 1) || int_len == nullptr) {
+ axis_removal_mask[axes[i]] = true;
}
- axis_removal_mask[axes[i]] = true;
}
} else {
// When `axis` is not defined, squeeze all unit-length dimensions.
diff --git a/tests/python/relax/test_frontend_from_exported_program.py
b/tests/python/relax/test_frontend_from_exported_program.py
index 3435ac5670..89017e30a7 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -5482,15 +5482,32 @@ def test_squeeze():
input: R.Tensor((3, 1, 4, 1), dtype="float32")
) -> R.Tuple(R.Tensor((3, 4), dtype="float32")):
with R.dataflow():
- lv: R.Tensor((3, 4), dtype="float32") = R.squeeze(input,
axis=[1, 3])
+ lv: R.Tensor((3, 4), dtype="float32") = R.squeeze(input,
axis=[0, 1, 2, 3])
gv: R.Tuple(R.Tensor((3, 4), dtype="float32")) = (lv,)
R.output(gv)
return gv
+ class Squeeze3(Module):
+ def forward(self, input):
+ return input.squeeze(2)
+
+ @I.ir_module
+ class Expected3:
+ @R.function
+ def main(
+ inp_0: R.Tensor((3, 1, 4, 1), dtype="float32")
+ ) -> R.Tuple(R.Tensor((3, 1, 4, 1), dtype="float32")):
+ with R.dataflow():
+ lv: R.Tensor((3, 1, 4, 1), dtype="float32") = R.squeeze(inp_0,
axis=[2])
+ gv: R.Tuple(R.Tensor((3, 1, 4, 1), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
example_args = (torch.randn(3, 1, 4, 1, dtype=torch.float32),)
verify_model(Squeeze1(), example_args, {}, Expected1)
verify_model(Squeeze2(), example_args, {}, Expected2)
+ verify_model(Squeeze3(), example_args, {}, Expected3)
def test_stack():
diff --git a/tests/python/relax/test_op_manipulate.py
b/tests/python/relax/test_op_manipulate.py
index 004c4b9618..d39584e06b 100644
--- a/tests/python/relax/test_op_manipulate.py
+++ b/tests/python/relax/test_op_manipulate.py
@@ -994,11 +994,19 @@ def test_squeeze_infer_struct_info_axis_length_not_one():
x2 = relax.Var("x", relax.TensorStructInfo(s0, "float32"))
x3 = relax.Var("x", relax.TensorStructInfo(s1, "float32"))
- with pytest.raises(TVMError):
- bb.normalize(relax.op.squeeze(x0, [0]))
- _check_inference(bb, relax.op.squeeze(x1, [0]), relax.TensorStructInfo((3,
4), "float32"))
- with pytest.raises(TVMError):
- bb.normalize(relax.op.squeeze(x2, [0]))
+ # Squeeze concrete shape (2,3,4) at axis=0, but axis length 2 != 1,
squeeze is no-op.
+ _check_inference(
+ bb, relax.op.squeeze(x0, [0]), relax.TensorStructInfo(shape=(2, 3, 4),
dtype="float32")
+ )
+ # Squeeze symbolic shape (a,3,4) at axis=0, assuming a can achieve
successful squeeze.
+ _check_inference(
+ bb, relax.op.squeeze(x1, [0]), relax.TensorStructInfo(shape=(3, 4),
dtype="float32")
+ )
+ # Squeeze shape variable s0 (corresponding to (2,3,4)) at axis=0.
+ _check_inference(
+ bb, relax.op.squeeze(x2, [0]), relax.TensorStructInfo(shape=s0,
dtype="float32")
+ )
+ # Squeeze shape variable s1 (a,3,4) at axis=0, assuming a can achieve
successful squeeze.
_check_inference(bb, relax.op.squeeze(x3, [0]),
relax.TensorStructInfo(dtype="float32", ndim=2))