This is an automated email from the ASF dual-hosted git repository.

mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new fbd9fe4048 Fix division truncation in window size calculation for small dtypes in average_pool (#18014)
fbd9fe4048 is described below

commit fbd9fe40487a395995b3d2745c00bfa981af6a4b
Author: Qingchao Shen <[email protected]>
AuthorDate: Mon May 26 20:45:14 2025 +0800

    Fix division truncation in window size calculation for small dtypes in average_pool (#18014)
    
    * Update pooling.h
    
    * Update test_te_create_primfunc.py
    
    * fix lint error
---
 include/tvm/topi/nn/pooling.h              | 2 +-
 tests/python/te/test_te_create_primfunc.py | 8 ++++++++
 2 files changed, 9 insertions(+), 1 deletion(-)

diff --git a/include/tvm/topi/nn/pooling.h b/include/tvm/topi/nn/pooling.h
index abe26b6c67..8e13ae49af 100644
--- a/include/tvm/topi/nn/pooling.h
+++ b/include/tvm/topi/nn/pooling.h
@@ -383,7 +383,7 @@ inline Tensor adaptive_pool_impl(const Tensor& x, const Array<PrimExpr>& output_
 
           PrimExpr divide_factor = tvm::cast(x->dtype, 1);
           for (size_t i = 0; i < n_dim; ++i) {
-            divide_factor *= tvm::cast(x->dtype, reduce_axes[i]->dom->extent);
+            divide_factor *= tvm::cast(DataType::Int(32), reduce_axes[i]->dom->extent);
           }
 
           return div(pool_sum(indices), divide_factor);
diff --git a/tests/python/te/test_te_create_primfunc.py b/tests/python/te/test_te_create_primfunc.py
index 9925f54be4..b0850a89b5 100644
--- a/tests/python/te/test_te_create_primfunc.py
+++ b/tests/python/te/test_te_create_primfunc.py
@@ -882,6 +882,14 @@ def test_adaptive_pooling_window():
     _check_workload(te_workload, tir_workload)
 
 
+def test_global_pool():
+    # fix the issue-17938
+    data = te.placeholder((1, 1, 32, 32), dtype="int8", name="data")
+    op_output = topi.nn.global_pool(data=data, pool_type="avg", layout="NCHW")
+    f = te.create_prim_func([data, op_output])
+    assert f
+
+
 def test_nested_reduce_domain_dependency():
     @T.prim_func
     def tir_workload(

Reply via email to