This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 06a7cda8aa [Relax][Op] Fixed incorrect output shape of Pool op when
ceil_mode = true (#18641)
06a7cda8aa is described below
commit 06a7cda8aa7c7c6e7a5e0de082dad0dbb027a35b
Author: Nguyen Duy Loc <[email protected]>
AuthorDate: Tue Jan 6 21:29:22 2026 +0700
[Relax][Op] Fixed incorrect output shape of Pool op when ceil_mode = true
(#18641)
### Summary
Fixed incorrect output shape of Pool op when ceil_mode = true
### Steps to Reproduce
Example: Create Pool Operator from PyTorch
```
class PoolModule(nn.Module):
def forward(self, x):
return torch.nn.AvgPool2d(2, 2, 1, True)(x)
```
```
class Module:
def main(x: R.Tensor((1, 3, 17, 17), dtype="float32")) ->
R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
with R.dataflow():
lv: R.Tensor((1, 3, 10, 10), dtype="float32") =
R.nn.avg_pool2d(x, pool_size=[2, 2], strides=[2, 2], dilation=[1, 1],
padding=[1, 1, 1, 1], ceil_mode=True, count_include_pad=True, layout="NCHW",
out_layout="NCHW")
gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
R.output(gv)
return gv
```
### Expected
```
class Module:
def main(x: R.Tensor((1, 3, 17, 17), dtype="float32")) ->
R.Tuple(R.Tensor((1, 3, 9, 9), dtype="float32")):
with R.dataflow():
lv: R.Tensor((1, 3, 9, 9), dtype="float32") =
R.nn.avg_pool2d(x, pool_size=[2, 2], strides=[2, 2], dilation=[1, 1],
padding=[1, 1, 1, 1], ceil_mode=True, count_include_pad=True, layout="NCHW",
out_layout="NCHW")
gv: R.Tuple(R.Tensor((1, 3, 9, 9), dtype="float32")) = (lv,)
R.output(gv)
return gv
```
### Resolution
- Citation:
https://docs.pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html
<img width="500" height="200" alt="PR1"
src="https://github.com/user-attachments/assets/52a27448-006f-409e-b8b4-65f49e908d5f"
/>
- Fixes: #18594
---
include/tvm/topi/nn/pooling.h | 14 ++++++++++++--
src/relax/op/nn/pooling.cc | 39 +++++++++++++++++++++++++++++++++------
2 files changed, 45 insertions(+), 8 deletions(-)
diff --git a/include/tvm/topi/nn/pooling.h b/include/tvm/topi/nn/pooling.h
index b977a54a59..3caf7bf1f7 100644
--- a/include/tvm/topi/nn/pooling.h
+++ b/include/tvm/topi/nn/pooling.h
@@ -563,8 +563,18 @@ inline Tensor pool_impl_nd(const Tensor& x, const
ffi::Array<PrimExpr>& kernel_s
PrimExpr numerator =
data_shape[ii] - (kernel[i] - 1) * dilation[i] - 1 + pad_head[i] +
pad_tail[i];
- auto out_dim = analyzer.Simplify(indexdiv(numerator, stride[i]) + 1);
- out_shape.Set(ii, out_dim);
+ auto raw_out = indexdiv(numerator, stride[i]) + 1;
+ if (ceil_mode) {
+ // In the case of ceil_mode=True, we need to check if the last pooling
window is valid.
+ // If not, we skip the last window as it would start entirely in the bottom
padded region,
+ // so we subtract 1 to get the correct output shape.
+ auto invalid_last = (raw_out - 1) * stride[i] >= data_shape[ii] +
pad_head[i];
+ auto out_dim = analyzer.Simplify(if_then_else(invalid_last, raw_out - 1,
raw_out));
+ out_shape.Set(ii, out_dim);
+ } else {
+ auto out_dim = analyzer.Simplify(raw_out);
+ out_shape.Set(ii, out_dim);
+ }
}
ffi::Map<ffi::String, ffi::Any> attrs;
diff --git a/src/relax/op/nn/pooling.cc b/src/relax/op/nn/pooling.cc
index 5841355200..1a19872c27 100644
--- a/src/relax/op/nn/pooling.cc
+++ b/src/relax/op/nn/pooling.cc
@@ -111,7 +111,13 @@ StructInfo InferStructInfoPool1D(const Call& call, const
BlockBuilder& ctx) {
if (attrs->ceil_mode) {
numerator_w += attrs->strides[0] - 1;
}
- out_NCW_shape[2] = analyzer->Simplify(floordiv(numerator_w,
attrs->strides[0]) + 1);
+ PrimExpr raw_out_w = floordiv(numerator_w, attrs->strides[0]) + 1;
+ if (attrs->ceil_mode) {
+ PrimExpr invalid_last_w = (raw_out_w - 1) * attrs->strides[0] >= input_w +
attrs->padding[0];
+ out_NCW_shape[2] = analyzer->Simplify(if_then_else(invalid_last_w,
raw_out_w - 1, raw_out_w));
+ } else {
+ out_NCW_shape[2] = analyzer->Simplify(raw_out_w);
+ }
ffi::Array<PrimExpr> out_shape = out2NCW.BackwardShape(out_NCW_shape);
return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype,
data_sinfo->vdevice);
@@ -232,8 +238,17 @@ StructInfo InferStructInfoPool2D(const Call& call, const
BlockBuilder& ctx) {
numerator_h += attrs->strides[0] - 1;
numerator_w += attrs->strides[1] - 1;
}
- out_NCHW_shape[2] = analyzer->Simplify(floordiv(numerator_h,
attrs->strides[0]) + 1);
- out_NCHW_shape[3] = analyzer->Simplify(floordiv(numerator_w,
attrs->strides[1]) + 1);
+ PrimExpr raw_out_h = floordiv(numerator_h, attrs->strides[0]) + 1;
+ PrimExpr raw_out_w = floordiv(numerator_w, attrs->strides[1]) + 1;
+ if (attrs->ceil_mode) {
+ PrimExpr invalid_last_h = (raw_out_h - 1) * attrs->strides[0] >= input_h +
attrs->padding[0];
+ PrimExpr invalid_last_w = (raw_out_w - 1) * attrs->strides[1] >= input_w +
attrs->padding[1];
+ out_NCHW_shape[2] = analyzer->Simplify(if_then_else(invalid_last_h,
raw_out_h - 1, raw_out_h));
+ out_NCHW_shape[3] = analyzer->Simplify(if_then_else(invalid_last_w,
raw_out_w - 1, raw_out_w));
+ } else {
+ out_NCHW_shape[2] = analyzer->Simplify(raw_out_h);
+ out_NCHW_shape[3] = analyzer->Simplify(raw_out_w);
+ }
ffi::Array<PrimExpr> out_shape = out2NCHW.BackwardShape(out_NCHW_shape);
return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype,
data_sinfo->vdevice);
@@ -380,9 +395,21 @@ StructInfo InferStructInfoPool3D(const Call& call, const
BlockBuilder& ctx) {
numerator_h += attrs->strides[1] - 1;
numerator_w += attrs->strides[2] - 1;
}
- out_NCDHW_shape[2] = analyzer->Simplify(floordiv(numerator_d,
attrs->strides[0]) + 1);
- out_NCDHW_shape[3] = analyzer->Simplify(floordiv(numerator_h,
attrs->strides[1]) + 1);
- out_NCDHW_shape[4] = analyzer->Simplify(floordiv(numerator_w,
attrs->strides[2]) + 1);
+ PrimExpr raw_out_d = floordiv(numerator_d, attrs->strides[0]) + 1;
+ PrimExpr raw_out_h = floordiv(numerator_h, attrs->strides[1]) + 1;
+ PrimExpr raw_out_w = floordiv(numerator_w, attrs->strides[2]) + 1;
+ if (attrs->ceil_mode) {
+ PrimExpr invalid_last_d = (raw_out_d - 1) * attrs->strides[0] >= input_d +
attrs->padding[0];
+ PrimExpr invalid_last_h = (raw_out_h - 1) * attrs->strides[1] >= input_h +
attrs->padding[1];
+ PrimExpr invalid_last_w = (raw_out_w - 1) * attrs->strides[2] >= input_w +
attrs->padding[2];
+ out_NCDHW_shape[2] = analyzer->Simplify(if_then_else(invalid_last_d,
raw_out_d - 1, raw_out_d));
+ out_NCDHW_shape[3] = analyzer->Simplify(if_then_else(invalid_last_h,
raw_out_h - 1, raw_out_h));
+ out_NCDHW_shape[4] = analyzer->Simplify(if_then_else(invalid_last_w,
raw_out_w - 1, raw_out_w));
+ } else {
+ out_NCDHW_shape[2] = analyzer->Simplify(raw_out_d);
+ out_NCDHW_shape[3] = analyzer->Simplify(raw_out_h);
+ out_NCDHW_shape[4] = analyzer->Simplify(raw_out_w);
+ }
ffi::Array<PrimExpr> out_shape = out2NCDHW.BackwardShape(out_NCDHW_shape);
return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype,
data_sinfo->vdevice);