This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new c932f2a7e7 [Relax] Add size heuristic to skip folding large creation ops (#18764)
c932f2a7e7 is described below
commit c932f2a7e70a7bd3f3f662f38b44bede735afc60
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Thu Feb 12 22:14:19 2026 +0800
[Relax] Add size heuristic to skip folding large creation ops (#18764)
## Why
Folding large creation ops (zeros, ones, full, arange) with no tensor inputs
materializes large constants in the binary unnecessarily, since they are cheap
to compute at runtime.
## How
- Add a ShouldBeFolded heuristic that skips folding when the output exceeds
  1024 elements and the op has no tensor inputs (illustrated below)
- Check call arguments for tensor inputs, including tuples for call_tir
- Add tests for large creation ops, small creation ops, and large ops with
  tensor inputs
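
For illustration, a minimal sketch of the intended behavior; the `Example`
module and every name in it are hypothetical, not part of this patch:

```python
import tvm
from tvm import relax
from tvm.script import relax as R


@tvm.script.ir_module
class Example:
    @R.function
    def main():
        with R.dataflow():
            small = R.zeros((4, 4), "float32")   # 16 elements: still folded to a constant
            large = R.zeros((2048,), "float32")  # > 1024 elements, no tensor input: kept
            R.output(small, large)
        return (small, large)


# After this change, FoldConstant folds `small` into a constant but leaves
# the `large` zeros op to be computed at runtime.
folded = relax.transform.FoldConstant()(Example)
```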
Signed-off-by: Guan-Ming Chiu <[email protected]>
---
src/relax/transform/fold_constant.cc | 55 ++++++++++++-
tests/python/relax/test_transform_fold_constant.py | 89 ++++++++++++++++++++++
2 files changed, 141 insertions(+), 3 deletions(-)
diff --git a/src/relax/transform/fold_constant.cc b/src/relax/transform/fold_constant.cc
index 5a26d15850..8665376766 100644
--- a/src/relax/transform/fold_constant.cc
+++ b/src/relax/transform/fold_constant.cc
@@ -136,10 +136,59 @@ class ConstantFolder : public ExprMutator {
* folding iota ops could result in large constants being materialized, thus increasing the size
* of the program.
*/
+  static bool ExprContainsTensor(const Expr& expr) {
+    if (GetStructInfo(expr).as<TensorStructInfoNode>()) {
+      return true;
+    }
+    if (const auto* tuple = expr.as<TupleNode>()) {
+      for (const auto& field : tuple->fields) {
+        if (ExprContainsTensor(field)) {
+          return true;
+        }
+      }
+    }
+    return false;
+  }
+
  bool ShouldBeFolded(Expr expr) {
-    // TODO(prakalp): Implement a heuristic to check if folding this expr is actually useful or
-    // not.
-    return true;
+    // Skip folding for creation ops (no tensor inputs) that produce large outputs.
+    // These ops (e.g., zeros, ones, full, arange) are cheap to compute at runtime,
+    // and folding them would materialize large constants in the binary.
+    static constexpr int64_t kMaxFoldElements = 1024;
+
+    const auto* call = expr.as<CallNode>();
+    if (!call) return true;
+
+    const auto* tensor_sinfo = call->struct_info_.as<TensorStructInfoNode>();
+    if (!tensor_sinfo) return true;
+
+    auto opt_shape = tensor_sinfo->GetShape();
+    if (!opt_shape) return true;
+
+    int64_t num_elements = 1;
+    for (const auto& dim : opt_shape.value()) {
+      const auto* int_dim = dim.as<IntImmNode>();
+      if (!int_dim) return true;  // symbolic dim: fall back to the default (fold)
+      int64_t d = int_dim->value;
+      if (d <= 0) return true;
+      if (num_elements > kMaxFoldElements / d) {  // product would exceed the threshold; clamp without overflowing
+        num_elements = kMaxFoldElements + 1;
+        break;
+      }
+      num_elements *= d;
+    }
+
+    if (num_elements <= kMaxFoldElements) return true;
+
+    // Large output. Only skip if there are no tensor inputs,
+    // i.e., this is a pure creation op.
+    for (const auto& arg : call->args) {
+      if (ExprContainsTensor(arg)) {
+        return true;
+      }
+    }
+
+    return false;
}
// Try constant evaluate a call_tir with a single tensor output.
diff --git a/tests/python/relax/test_transform_fold_constant.py b/tests/python/relax/test_transform_fold_constant.py
index 92125bc351..3e453a0ded 100644
--- a/tests/python/relax/test_transform_fold_constant.py
+++ b/tests/python/relax/test_transform_fold_constant.py
@@ -493,5 +493,94 @@ def test_fold_tuple_output():
tvm.ir.assert_structural_equal(after, expected)
+def test_skip_folding_large_creation_op():
+    @tvm.script.ir_module
+    class Module:
+        @R.function
+        def before():
+            with R.dataflow():
+                # 2048 elements > 1024 threshold, no tensor input
+                gv = R.zeros((2048,), "float32")
+                R.output(gv)
+            return gv
+
+    before = Module
+    after = relax.transform.FoldConstant()(before)
+    # The zeros op should NOT be folded because the output is large
+    tvm.ir.assert_structural_equal(after, before)
+
+
+def test_fold_small_creation_op():
+    @tvm.script.ir_module
+    class Module:
+        @R.function
+        def before():
+            with R.dataflow():
+                # 16 elements <= 1024 threshold
+                gv = R.zeros((4, 4), "float32")
+                R.output(gv)
+            return gv
+
+        @R.function
+        def expected(c0: R.Tensor((4, 4), "float32")):
+            return c0
+
+    before = gen_mod(Module, "before", {})
+    expected = gen_mod(Module, "expected", {"c0": np.zeros((4, 4), dtype="float32")})
+    after = relax.transform.FoldConstant()(before)
+    tvm.ir.assert_structural_equal(after, expected)
+
+
+def test_fold_boundary_creation_op():
+    @tvm.script.ir_module
+    class Module:
+        @R.function
+        def before():
+            with R.dataflow():
+                # Exactly 1024 elements == threshold, should fold
+                gv = R.zeros((1024,), "float32")
+                R.output(gv)
+            return gv
+
+        @R.function
+        def expected(c0: R.Tensor((1024,), "float32")):
+            return c0
+
+    before = gen_mod(Module, "before", {})
+    expected = gen_mod(Module, "expected", {"c0": np.zeros((1024,), dtype="float32")})
+    after = relax.transform.FoldConstant()(before)
+    tvm.ir.assert_structural_equal(after, expected)
+
+
+def test_fold_large_op_with_tensor_input():
+    """Ops with tensor inputs should be folded even if the output is large."""
+
+    @tvm.script.ir_module
+    class Module:
+        @T.prim_func
+        def addone(A: T.Buffer((2048,), "float32"), B: T.Buffer((2048,), "float32")) -> None:
+            for i in range(2048):
+                with T.block("addone"):
+                    vi = T.axis.remap("S", [i])
+                    B[vi] = A[vi] + T.float32(1)
+
+        @R.function
+        def before(c0: R.Tensor((2048,), "float32")):
+            cls = Module
+            lv0 = R.call_tir(cls.addone, (c0,), out_sinfo=R.Tensor((2048,), dtype="float32"))
+            return lv0
+
+        @R.function
+        def expected(c1: R.Tensor((2048,), "float32")):
+            return c1
+
+    c0_np = np.arange(2048).astype("float32")
+    c1_np = c0_np + 1
+    before = gen_mod(Module, "before", {"c0": c0_np})
+    expected = gen_mod(Module, "expected", {"c1": c1_np})
+    after = relax.transform.FoldConstant()(before)
+    tvm.ir.assert_structural_equal(after, expected)
+
+
if __name__ == "__main__":
tvm.testing.main()