This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 06a7cda8aa [Relax][Op] Fixed incorrect output shape of Pool op when ceil_mode = true (#18641)
06a7cda8aa is described below

commit 06a7cda8aa7c7c6e7a5e0de082dad0dbb027a35b
Author: Nguyen Duy Loc <[email protected]>
AuthorDate: Tue Jan 6 21:29:22 2026 +0700

    [Relax][Op] Fixed incorrect output shape of Pool op when ceil_mode = true (#18641)
    
    ### Summary
    Fixes the incorrect output shape inferred for pooling ops when `ceil_mode=True`: with ceil mode, the last pooling window must be dropped when it would start entirely inside the padded region, matching PyTorch semantics.
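
    Concretely, for a 17-wide axis with kernel 2, stride 2, and padding 1, the naive ceil-mode formula gives ceil((17 + 2*1 - 2) / 2) + 1 = 10 windows, but the 10th window would start at index 9*2 - 1 = 17, entirely inside the bottom padding, so the correct output size is 9.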
    
    ### Steps to Reproduce
    Example: create a pooling operator from PyTorch (the positional arguments are kernel_size=2, stride=2, padding=1, ceil_mode=True):
    ```
    class PoolModule(nn.Module):
        def forward(self, x):
            return torch.nn.AvgPool2d(2, 2, 1, True)(x)
    ```
    Importing this module into Relax currently infers an incorrect 10x10 output shape:
    ```
    class Module:
        def main(x: R.Tensor((1, 3, 17, 17), dtype="float32")) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.avg_pool2d(x, pool_size=[2, 2], strides=[2, 2], dilation=[1, 1], padding=[1, 1, 1, 1], ceil_mode=True, count_include_pad=True, layout="NCHW", out_layout="NCHW")
                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
                R.output(gv)
            return gv
    ```
    
    ### Expected
    ```
    class Module:
        def main(x: R.Tensor((1, 3, 17, 17), dtype="float32")) -> R.Tuple(R.Tensor((1, 3, 9, 9), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((1, 3, 9, 9), dtype="float32") = R.nn.avg_pool2d(x, pool_size=[2, 2], strides=[2, 2], dilation=[1, 1], padding=[1, 1, 1, 1], ceil_mode=True, count_include_pad=True, layout="NCHW", out_layout="NCHW")
                gv: R.Tuple(R.Tensor((1, 3, 9, 9), dtype="float32")) = (lv,)
                R.output(gv)
            return gv
    ```
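
    For reference, PyTorch itself produces the 9x9 shape for this module (a quick sanity check, runnable with stock PyTorch):
    ```
    import torch

    x = torch.randn(1, 3, 17, 17)
    y = torch.nn.AvgPool2d(2, 2, 1, ceil_mode=True)(x)
    print(y.shape)  # torch.Size([1, 3, 9, 9])
    ```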
    ### Resolution
    - Citation: https://docs.pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html (the ceil-mode shape rule; see the sketch after this list)
    
    - Fixes: #18594
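
    The PyTorch docs define the ceil-mode output size with an extra rule: the last window is dropped when it would start entirely in the padded region. A minimal Python sketch of that rule, mirroring the fix below (`pool_out_dim` is an illustrative helper, not a TVM API):
    ```
    def pool_out_dim(size, kernel, stride, pad, dilation=1, ceil_mode=False):
        # Standard pooling shape formula (floor division).
        numerator = size + 2 * pad - (kernel - 1) * dilation - 1
        if ceil_mode:
            # Rounding up == adding (stride - 1) before the floor division.
            numerator += stride - 1
        out = numerator // stride + 1
        if ceil_mode and (out - 1) * stride >= size + pad:
            # The last window would start inside the padding: drop it.
            out -= 1
        return out

    print(pool_out_dim(17, 2, 2, 1, ceil_mode=True))  # 9, matching PyTorch
    ```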
---
 include/tvm/topi/nn/pooling.h | 14 ++++++++++++--
 src/relax/op/nn/pooling.cc    | 39 +++++++++++++++++++++++++++++++++------
 2 files changed, 45 insertions(+), 8 deletions(-)

diff --git a/include/tvm/topi/nn/pooling.h b/include/tvm/topi/nn/pooling.h
index b977a54a59..3caf7bf1f7 100644
--- a/include/tvm/topi/nn/pooling.h
+++ b/include/tvm/topi/nn/pooling.h
@@ -563,8 +563,18 @@ inline Tensor pool_impl_nd(const Tensor& x, const ffi::Array<PrimExpr>& kernel_s
 
     PrimExpr numerator =
         data_shape[ii] - (kernel[i] - 1) * dilation[i] - 1 + pad_head[i] + pad_tail[i];
-    auto out_dim = analyzer.Simplify(indexdiv(numerator, stride[i]) + 1);
-    out_shape.Set(ii, out_dim);
+    auto raw_out = indexdiv(numerator, stride[i]) + 1;
+    if (ceil_mode) {
+      // With ceil_mode=True, check whether the last pooling window is valid.
+      // If it would start entirely in the bottom padded region, skip it by
+      // subtracting 1 from the output dimension.
+      auto invalid_last = (raw_out - 1) * stride[i] >= data_shape[ii] + pad_head[i];
+      auto out_dim = analyzer.Simplify(if_then_else(invalid_last, raw_out - 1, raw_out));
+      out_shape.Set(ii, out_dim);
+    } else {
+      auto out_dim = analyzer.Simplify(raw_out);
+      out_shape.Set(ii, out_dim);
+    }
   }
 
   ffi::Map<ffi::String, ffi::Any> attrs;
diff --git a/src/relax/op/nn/pooling.cc b/src/relax/op/nn/pooling.cc
index 5841355200..1a19872c27 100644
--- a/src/relax/op/nn/pooling.cc
+++ b/src/relax/op/nn/pooling.cc
@@ -111,7 +111,13 @@ StructInfo InferStructInfoPool1D(const Call& call, const BlockBuilder& ctx) {
   if (attrs->ceil_mode) {
     numerator_w += attrs->strides[0] - 1;
   }
-  out_NCW_shape[2] = analyzer->Simplify(floordiv(numerator_w, attrs->strides[0]) + 1);
+  PrimExpr raw_out_w = floordiv(numerator_w, attrs->strides[0]) + 1;
+  if (attrs->ceil_mode) {
+    PrimExpr invalid_last_w = (raw_out_w - 1) * attrs->strides[0] >= input_w + attrs->padding[0];
+    out_NCW_shape[2] = analyzer->Simplify(if_then_else(invalid_last_w, raw_out_w - 1, raw_out_w));
+  } else {
+    out_NCW_shape[2] = analyzer->Simplify(raw_out_w);
+  }
 
   ffi::Array<PrimExpr> out_shape = out2NCW.BackwardShape(out_NCW_shape);
   return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype, data_sinfo->vdevice);
@@ -232,8 +238,17 @@ StructInfo InferStructInfoPool2D(const Call& call, const BlockBuilder& ctx) {
     numerator_h += attrs->strides[0] - 1;
     numerator_w += attrs->strides[1] - 1;
   }
-  out_NCHW_shape[2] = analyzer->Simplify(floordiv(numerator_h, attrs->strides[0]) + 1);
-  out_NCHW_shape[3] = analyzer->Simplify(floordiv(numerator_w, attrs->strides[1]) + 1);
+  PrimExpr raw_out_h = floordiv(numerator_h, attrs->strides[0]) + 1;
+  PrimExpr raw_out_w = floordiv(numerator_w, attrs->strides[1]) + 1;
+  if (attrs->ceil_mode) {
+    PrimExpr invalid_last_h = (raw_out_h - 1) * attrs->strides[0] >= input_h + attrs->padding[0];
+    PrimExpr invalid_last_w = (raw_out_w - 1) * attrs->strides[1] >= input_w + attrs->padding[1];
+    out_NCHW_shape[2] = analyzer->Simplify(if_then_else(invalid_last_h, raw_out_h - 1, raw_out_h));
+    out_NCHW_shape[3] = analyzer->Simplify(if_then_else(invalid_last_w, raw_out_w - 1, raw_out_w));
+  } else {
+    out_NCHW_shape[2] = analyzer->Simplify(raw_out_h);
+    out_NCHW_shape[3] = analyzer->Simplify(raw_out_w);
+  }
 
   ffi::Array<PrimExpr> out_shape = out2NCHW.BackwardShape(out_NCHW_shape);
   return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype, data_sinfo->vdevice);
@@ -380,9 +395,21 @@ StructInfo InferStructInfoPool3D(const Call& call, const BlockBuilder& ctx) {
     numerator_h += attrs->strides[1] - 1;
     numerator_w += attrs->strides[2] - 1;
   }
-  out_NCDHW_shape[2] = analyzer->Simplify(floordiv(numerator_d, attrs->strides[0]) + 1);
-  out_NCDHW_shape[3] = analyzer->Simplify(floordiv(numerator_h, attrs->strides[1]) + 1);
-  out_NCDHW_shape[4] = analyzer->Simplify(floordiv(numerator_w, attrs->strides[2]) + 1);
+  PrimExpr raw_out_d = floordiv(numerator_d, attrs->strides[0]) + 1;
+  PrimExpr raw_out_h = floordiv(numerator_h, attrs->strides[1]) + 1;
+  PrimExpr raw_out_w = floordiv(numerator_w, attrs->strides[2]) + 1;
+  if (attrs->ceil_mode) {
+    PrimExpr invalid_last_d = (raw_out_d - 1) * attrs->strides[0] >= input_d + attrs->padding[0];
+    PrimExpr invalid_last_h = (raw_out_h - 1) * attrs->strides[1] >= input_h + attrs->padding[1];
+    PrimExpr invalid_last_w = (raw_out_w - 1) * attrs->strides[2] >= input_w + attrs->padding[2];
+    out_NCDHW_shape[2] = analyzer->Simplify(if_then_else(invalid_last_d, raw_out_d - 1, raw_out_d));
+    out_NCDHW_shape[3] = analyzer->Simplify(if_then_else(invalid_last_h, raw_out_h - 1, raw_out_h));
+    out_NCDHW_shape[4] = analyzer->Simplify(if_then_else(invalid_last_w, raw_out_w - 1, raw_out_w));
+  } else {
+    out_NCDHW_shape[2] = analyzer->Simplify(raw_out_d);
+    out_NCDHW_shape[3] = analyzer->Simplify(raw_out_h);
+    out_NCDHW_shape[4] = analyzer->Simplify(raw_out_w);
+  }
 
   ffi::Array<PrimExpr> out_shape = out2NCDHW.BackwardShape(out_NCDHW_shape);
   return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype, data_sinfo->vdevice);
