This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new c932f2a7e7 [Relax] Add size heuristic to skip folding large creation ops (#18764)
c932f2a7e7 is described below

commit c932f2a7e70a7bd3f3f662f38b44bede735afc60
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Thu Feb 12 22:14:19 2026 +0800

    [Relax] Add size heuristic to skip folding large creation ops (#18764)
    
    ## Why
    
    Folding large creation ops (zeros, ones, full, arange) with no tensor
    inputs materializes large constants in
    the binary unnecessarily, since they are cheap to compute at runtime.
    
    ## How
    
    - Add ShouldBeFolded heuristic that skips folding when output exceeds
    1024 elements and the op has no tensor inputs
    - Check call arguments for tensor inputs, including tuples for call_tir
    - Add tests for large creation ops, small creation ops, and large ops
    with tensor inputs
    
    Signed-off-by: Guan-Ming Chiu <[email protected]>
---
 src/relax/transform/fold_constant.cc               | 55 ++++++++++++-
 tests/python/relax/test_transform_fold_constant.py | 89 ++++++++++++++++++++++
 2 files changed, 141 insertions(+), 3 deletions(-)

diff --git a/src/relax/transform/fold_constant.cc b/src/relax/transform/fold_constant.cc
index 5a26d15850..8665376766 100644
--- a/src/relax/transform/fold_constant.cc
+++ b/src/relax/transform/fold_constant.cc
@@ -136,10 +136,59 @@ class ConstantFolder : public ExprMutator {
    * folding iota ops could result in large constants being materialized, thus increasing the size
    * of the program.
    */
+  /*!
+   * \brief Check whether \p sinfo describes a tensor, recursing through
+   * nested TupleStructInfo so tuples-of-tuples are handled as well.
+   */
+  static bool StructInfoContainsTensor(const StructInfo& sinfo) {
+    if (sinfo.as<TensorStructInfoNode>()) {
+      return true;
+    }
+    if (const auto* tuple_sinfo = sinfo.as<TupleStructInfoNode>()) {
+      for (const StructInfo& field : tuple_sinfo->fields) {
+        if (StructInfoContainsTensor(field)) {
+          return true;
+        }
+      }
+    }
+    return false;
+  }
+
+  /*!
+   * \brief Check whether \p expr carries any tensor value.
+   *
+   * Inspecting struct info (rather than pattern-matching on TupleNode) also
+   * covers tuple arguments that reach here as bound variables instead of
+   * in-line tuple literals, e.g. the argument tuple of a call_tir.
+   */
+  static bool ExprContainsTensor(const Expr& expr) {
+    return StructInfoContainsTensor(GetStructInfo(expr));
+  }
+
   bool ShouldBeFolded(Expr expr) {
-    // TODO(prakalp): Implement a heuristic to check if folding this expr is actually useful or
-    // not.
-    return true;
+    // Heuristic: skip folding for pure creation ops (no tensor inputs) whose
+    // output is large. Such ops (e.g. zeros, ones, full, arange) are cheap to
+    // recompute at runtime, while folding them would bake a large constant
+    // into the compiled binary.
+    static constexpr int64_t kMaxFoldElements = 1024;  // element-count threshold (inclusive)
+
+    // Only calls producing a single tensor of fully static, positive shape are
+    // candidates for skipping; everything else keeps the fold-always behavior.
+    const auto* call = expr.as<CallNode>();
+    if (!call) return true;
+
+    const auto* tensor_sinfo = call->struct_info_.as<TensorStructInfoNode>();
+    if (!tensor_sinfo) return true;  // non-tensor (e.g. tuple) output: fold as before
+
+    auto opt_shape = tensor_sinfo->GetShape();
+    if (!opt_shape) return true;  // shape unknown: fold as before
+
+    int64_t num_elements = 1;
+    for (const auto& dim : opt_shape.value()) {
+      const auto* int_dim = dim.as<IntImmNode>();
+      if (!int_dim) return true;  // symbolic dimension: fold as before
+      int64_t d = int_dim->value;
+      if (d <= 0) return true;  // degenerate (empty) tensor: folding is free
+      // Overflow-safe test for num_elements * d > kMaxFoldElements.
+      if (num_elements > kMaxFoldElements / d) {
+        num_elements = kMaxFoldElements + 1;  // saturate: already over the threshold
+        break;
+      }
+      num_elements *= d;
+    }
+
+    if (num_elements <= kMaxFoldElements) return true;  // small output: always fold
+
+    // Large output. Only skip if there are no tensor inputs,
+    // i.e., this is a pure creation op.
+    for (const auto& arg : call->args) {
+      if (ExprContainsTensor(arg)) {
+        return true;
+      }
+    }
+
+    return false;
   }
 
   // Try constant evaluate a call_tir with a single tensor output.
diff --git a/tests/python/relax/test_transform_fold_constant.py b/tests/python/relax/test_transform_fold_constant.py
index 92125bc351..3e453a0ded 100644
--- a/tests/python/relax/test_transform_fold_constant.py
+++ b/tests/python/relax/test_transform_fold_constant.py
@@ -493,5 +493,94 @@ def test_fold_tuple_output():
     tvm.ir.assert_structural_equal(after, expected)
 
 
+def test_skip_folding_large_creation_op():
+    @tvm.script.ir_module
+    class Module:
+        @R.function
+        def before():
+            with R.dataflow():
+                # 2048 elements > 1024 threshold, no tensor input
+                gv = R.zeros((2048,), "float32")
+                R.output(gv)
+            return gv
+
+    before = Module
+    after = relax.transform.FoldConstant()(before)
+    # The zeros op should NOT be folded because the output is large
+    tvm.ir.assert_structural_equal(after, before)
+
+
+def test_fold_small_creation_op():
+    @tvm.script.ir_module
+    class Module:
+        @R.function
+        def before():
+            with R.dataflow():
+                # 16 elements <= 1024 threshold
+                gv = R.zeros((4, 4), "float32")
+                R.output(gv)
+            return gv
+
+        @R.function
+        def expected(c0: R.Tensor((4, 4), "float32")):
+            return c0
+
+    before = gen_mod(Module, "before", {})
+    expected = gen_mod(Module, "expected", {"c0": np.zeros((4, 4), 
dtype="float32")})
+    after = relax.transform.FoldConstant()(before)
+    tvm.ir.assert_structural_equal(after, expected)
+
+
+def test_fold_boundary_creation_op():
+    @tvm.script.ir_module
+    class Module:
+        @R.function
+        def before():
+            with R.dataflow():
+                # Exactly 1024 elements == threshold, should fold
+                gv = R.zeros((1024,), "float32")
+                R.output(gv)
+            return gv
+
+        @R.function
+        def expected(c0: R.Tensor((1024,), "float32")):
+            return c0
+
+    before = gen_mod(Module, "before", {})
+    expected = gen_mod(Module, "expected", {"c0": np.zeros((1024,), 
dtype="float32")})
+    after = relax.transform.FoldConstant()(before)
+    tvm.ir.assert_structural_equal(after, expected)
+
+
+def test_fold_large_op_with_tensor_input():
+    """Ops with tensor inputs should be folded even if output is large."""
+
+    @tvm.script.ir_module
+    class Module:
+        @T.prim_func
+        def addone(A: T.Buffer((2048,), "float32"), B: T.Buffer((2048,), 
"float32")) -> None:
+            for i in range(2048):
+                with T.sblock("addone"):
+                    vi = T.axis.remap("S", [i])
+                    B[vi] = A[vi] + T.float32(1)
+
+        @R.function
+        def before(c0: R.Tensor((2048,), "float32")):
+            cls = Module
+            lv0 = relax.call_tir(cls.addone, (c0,), R.Tensor((2048,), 
dtype="float32"))
+            return lv0
+
+        @R.function
+        def expected(c1: R.Tensor((2048,), "float32")):
+            return c1
+
+    c0_np = np.arange(2048).astype("float32")
+    c1_np = c0_np + 1
+    before = gen_mod(Module, "before", {"c0": c0_np})
+    expected = gen_mod(Module, "expected", {"c1": c1_np})
+    after = relax.transform.FoldConstant()(before)
+    tvm.ir.assert_structural_equal(after, expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to