This is an automated email from the ASF dual-hosted git repository. spectrometerHBH pushed a commit to branch tir-bench in repository https://gitbox.apache.org/repos/asf/tvm.git
commit aa035ff69fe8becad4187683c24351b46d5b6ee3 Author: Hongyi Jin <[email protected]> AuthorDate: Thu May 28 15:59:41 2026 -0400 fix(arith): memoize IntervalSet variable relaxation to avoid exponential blowup (#652) IntervalSetEvaluator relaxes variable bounds by recursively evaluating both the min and max sub-expressions of each mapped variable's interval. For diamond-shaped variable dependency chains (var a -> {b, c}, b -> {d, e}, ...) this re-expands shared sub-expressions along every path, costing O(2^depth) work, where depth is the length of the dependency chain (bounded only by dom_map_.size()). This made tirx.LowerTIRxCleanup hang indefinitely (>300s, ~200% CPU, no GPU work) when Analyzer::Bind evaluated the int_set of a small (5-node) bound expression whose variables transitively referenced ~67 other bound vars: the evaluator reached more than 2^20 VisitExpr calls at recursion depth 67 with no end in sight. Repro: GDN prefill v0_0 block-Neumann kernels (B00008). Fix: memoize the fully-relaxed interval per variable and break cyclic dependencies with an in-progress set. Each variable's relaxed interval is deterministic for a given evaluator instance (dom_map_/dom_constraints_ are fixed), so caching collapses the diamonds to linear cost. Short chains (the common case, never hitting the old depth cutoff) are unaffected. LowerTIRxCleanup on the repro now completes in ~0.01s and full tvm.compile(tir_pipeline="tirx") of both Neumann variants succeeds. All tests under tests/python/arith pass (983 passed, 2 skipped, 8 xfailed). 
--- src/arith/int_set.cc | 32 +++++++++++++++++++++++++++++--- tests/python/arith/test_arith_intset.py | 30 ++++++++++++++++++++++++++++++ 2 files changed, 59 insertions(+), 3 deletions(-) diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc index 6b3e2b9532..86a2d949bc 100644 --- a/src/arith/int_set.cc +++ b/src/arith/int_set.cc @@ -33,6 +33,7 @@ #include <algorithm> #include <unordered_map> +#include <unordered_set> #include <utility> #include "constraint_extract.h" @@ -458,9 +459,29 @@ class IntervalSetEvaluator : public ExprFunctor<IntervalSet(const PrimExpr&)> { if (res->min_value.same_as(var) && res->max_value.same_as(var)) { return res; } - // recursively evaluate mapped result - // in case the domain contains variables to be relaxed. - return Eval(res); + // Recursively relax the mapped interval, since the domain bounds may + // themselves reference other variables that need to be relaxed. + // + // Memoize the fully-relaxed interval per variable, and guard against + // cyclic variable dependencies with an in-progress set. Without this, + // diamond-shaped variable dependencies (var a -> {b, c}, b -> {d, e}, ...) + // are re-expanded along every path: each level evaluates both the min and + // max sub-expressions, so the cost is exponential (2^depth) in the length + // of the variable dependency chain rather than linear. + auto memo_it = relax_memo_.find(op); + if (memo_it != relax_memo_.end()) { + return memo_it->second; + } + if (relax_in_progress_.count(op)) { + // Cyclic dependency among variable bounds: stop relaxing here to keep + // the recursion finite, keeping this variable symbolic. 
+ return res; + } + relax_in_progress_.insert(op); + IntervalSet relaxed = Eval(res); + relax_in_progress_.erase(op); + relax_memo_[op] = relaxed; + return relaxed; } IntervalSet VisitExpr_(const AddNode* op) final { return VisitBinaryExpr_<Add>(op); } @@ -606,6 +627,11 @@ class IntervalSetEvaluator : public ExprFunctor<IntervalSet(const PrimExpr&)> { // recursive depth int recur_depth_{0}; + // Memo of fully-relaxed interval sets per variable, to avoid exponential + // re-expansion of diamond-shaped variable dependencies. + std::unordered_map<const VarNode*, IntervalSet> relax_memo_; + // Variables currently being relaxed, used to break cyclic dependencies. + std::unordered_set<const VarNode*> relax_in_progress_; // analyzer Analyzer* analyzer_; const ffi::Map<Var, IntSet>& dom_map_; diff --git a/tests/python/arith/test_arith_intset.py b/tests/python/arith/test_arith_intset.py index 49e09191d6..03ca4d9f3b 100644 --- a/tests/python/arith/test_arith_intset.py +++ b/tests/python/arith/test_arith_intset.py @@ -394,5 +394,35 @@ def test_modular_set(): ) +def test_relax_deep_variable_dependency_chain(): + """Regression test for B00008. + + When a variable's interval bound references another variable that is also in + the domain map, the evaluator relaxes it transitively. A diamond-shaped + chain -- where each variable's bound references the next one in *both* its + min and its max -- used to be re-expanded along every path, costing + O(2^depth) and hanging tirx.LowerTIRxCleanup indefinitely. The relaxation is + now memoized per variable, so this completes in linear time. + """ + ck = IntSetChecker() + n = 64 # 2^64 expansions without memoization; trivially fast with it. + xs = [tvm.tirx.Var(f"x{i}", "int32") for i in range(n + 1)] + dmap = {xs[i]: tvm.arith.IntervalSet(xs[i + 1] - 1, xs[i + 1] + 1) for i in range(n)} + dmap[xs[n]] = tvm.arith.IntervalSet(0, 100) + # x0 relaxes through the whole chain: [0 - n, 100 + n]. 
+ ck.verify(xs[0], dmap, (-n, 100 + n)) + + +def test_relax_cyclic_variable_dependency(): + """A cyclic variable dependency must terminate (and stay symbolic).""" + ana = tvm.arith.Analyzer() + x = tvm.tirx.Var("x", "int32") + y = tvm.tirx.Var("y", "int32") + # x depends on y and y depends on x: relaxation must not loop forever. + dmap = {x: tvm.arith.IntervalSet(y, y), y: tvm.arith.IntervalSet(x, x)} + res = ana.int_set(x, dmap) + assert res is not None + + if __name__ == "__main__": tvm.testing.main()
