echuraev commented on code in PR #19584:
URL: https://github.com/apache/tvm/pull/19584#discussion_r3387231834
##########
src/arith/transitive_comparison_analyzer.cc:
##########
@@ -566,20 +619,38 @@ void TransitiveComparisonAnalyzer::Impl::Bind(const
tirx::Var& var, const Range&
TVM_FFI_ICHECK(allow_override) << "Binding of variable " << var << " as
" << range
<< " conflicts with previous binding as "
<< (*it).second;
if (auto key = ExprToPreviousKey(var)) {
- knowns_.erase(std::remove_if(knowns_.begin(), knowns_.end(),
- [&](const auto& known) { return
known.lhs_ == key.value(); }),
- knowns_.end());
+ Key old_key = key.value();
+
+ // Every entry in `knowns_by_key_[old_key]` involves old_key by
+ // construction (on either side). Remove each from its partner
+ // bucket and then drop the whole old_key bucket in one go.
+ auto idx_it = knowns_by_key_.find(old_key);
+ if (idx_it != knowns_by_key_.end()) {
+ const std::vector<Comparison>& to_remove = idx_it->second;
+ for (const auto& cmp : to_remove) {
+ Key partner_key = (cmp.lhs_ == old_key) ? cmp.rhs_ : cmp.lhs_;
+ // self-comparison (lhs_ == rhs_): only stored once, in
+ // the bucket we are about to erase.
+ if (partner_key == old_key) continue;
+ auto other_it = knowns_by_key_.find(partner_key);
+ if (other_it == knowns_by_key_.end()) continue;
+ other_it->second.erase(
Review Comment:
After erasing, consider dropping `other_it`'s bucket if it has become empty
(`if (other_it->second.empty()) knowns_by_key_.erase(other_it);`). Correctness
does not depend on this, but empty buckets accumulate under repeated
`allow_override` rebinds.
##########
tests/python/arith/test_arith_transitive_comparison.py:
##########
@@ -0,0 +1,178 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+"""Tests for TransitiveComparisonAnalyzer and the per-key index."""
+
+import tvm
+import tvm.ir
+import tvm.testing
+from tvm import tirx
+from tvm.script import tirx as T
+
+
+def test_single_bind_provability():
+ analyzer = tvm.arith.Analyzer()
+ x = tirx.Var("x", "int32")
+ analyzer.bind(x, tvm.ir.Range.from_min_extent(0, 100))
+ assert analyzer.can_prove(x >= 0)
+ assert analyzer.can_prove(x < 100)
+ assert analyzer.can_prove(x <= 99)
+ assert not analyzer.can_prove(x >= 1)
+
+
+def test_many_binds_correctness_preserved():
+ analyzer = tvm.arith.Analyzer()
+ vars_ = [tirx.Var(f"v{i}", "int32") for i in range(2048)]
+ for i, v in enumerate(vars_):
+ analyzer.bind(v, tvm.ir.Range.from_min_extent(i, 10))
+ for i in (0, len(vars_) // 2, len(vars_) - 1):
+ v = vars_[i]
+ assert analyzer.can_prove(v >= i)
+ assert analyzer.can_prove(v < i + 10)
+ assert not analyzer.can_prove(v >= i + 1)
+
+
+def test_bind_override_clears_old_constraints():
+ analyzer = tvm.arith.Analyzer()
+ x = tirx.Var("x", "int32")
+ analyzer.bind(x, tvm.ir.Range.from_min_extent(0, 100))
+ assert analyzer.can_prove(x < 100)
+ analyzer.bind(x, tvm.ir.Range.from_min_extent(200, 100),
allow_override=True)
+ assert analyzer.can_prove(x >= 200)
+ assert analyzer.can_prove(x < 300)
+ assert not analyzer.can_prove(x < 100)
+ assert not analyzer.can_prove(x < 200)
+
+
+def test_bind_override_clears_constraints_where_var_is_rhs():
+ analyzer = tvm.arith.Analyzer()
+ x = tirx.Var("x", "int32")
+ y = tirx.Var("y", "int32")
+ analyzer.bind(y, tvm.ir.Range.from_min_extent(0, 10))
+ analyzer.bind(x, y + 5)
+ assert analyzer.can_prove(x < 15)
+ analyzer.bind(y, tvm.ir.Range.from_min_extent(200, 100),
allow_override=True)
+ assert not analyzer.can_prove(x < 15)
+ assert analyzer.can_prove(x >= 205)
+
+
+def test_scoped_constraint_enter_and_exit():
+ analyzer = tvm.arith.Analyzer()
+ x = tirx.Var("x", "int32")
+ y = tirx.Var("y", "int32")
+ analyzer.bind(x, tvm.ir.Range.from_min_extent(0, 100))
+ with analyzer.constraint_scope(y < x):
+ assert analyzer.can_prove(y < x)
+ assert not analyzer.can_prove(y < x)
+
+
+def test_cross_key_lookup():
+ analyzer = tvm.arith.Analyzer()
+ a = tirx.Var("a", "int32")
+ b = tirx.Var("b", "int32")
+ analyzer.bind(a, tvm.ir.Range.from_min_extent(0, 100))
+ with analyzer.constraint_scope(b > a):
+ assert analyzer.can_prove(a < b)
+
+
+def test_nested_constraint_scopes():
+ analyzer = tvm.arith.Analyzer()
+ x = tirx.Var("x", "int32")
+ y = tirx.Var("y", "int32")
+ z = tirx.Var("z", "int32")
+ analyzer.bind(x, tvm.ir.Range.from_min_extent(0, 100))
+ with analyzer.constraint_scope(y < x):
+ assert analyzer.can_prove(y < x)
+ with analyzer.constraint_scope(z < y):
+ assert analyzer.can_prove(y < x)
+ assert analyzer.can_prove(z < y)
+ assert analyzer.can_prove(y < x)
+ assert not analyzer.can_prove(z < y)
+ assert not analyzer.can_prove(y < x)
+ assert not analyzer.can_prove(z < y)
+
+
+def test_unrelated_binds_do_not_match():
+ analyzer = tvm.arith.Analyzer()
+ a = tirx.Var("a", "int32")
+ b = tirx.Var("b", "int32")
+ c = tirx.Var("c", "int32")
+ d = tirx.Var("d", "int32")
+ analyzer.bind(a, tvm.ir.Range.from_min_extent(0, 10))
+ analyzer.bind(b, tvm.ir.Range.from_min_extent(0, 10))
+ analyzer.bind(c, tvm.ir.Range.from_min_extent(0, 10))
+ assert not analyzer.can_prove(a < b)
+ assert not analyzer.can_prove(b < c)
+ assert not analyzer.can_prove(c < d)
+
+
+def test_scoped_then_global_bind_interaction():
+ analyzer = tvm.arith.Analyzer()
+ y = tirx.Var("y", "int32")
+ x = tirx.Var("x", "int32")
+ with analyzer.constraint_scope(y > 0):
+ analyzer.bind(x, tvm.ir.Range.from_min_extent(0, 100))
+ assert analyzer.can_prove(x < 100)
+ assert analyzer.can_prove(y > 0)
+ assert not analyzer.can_prove(y > 0)
+ assert analyzer.can_prove(x < 100)
+
+
+def test_self_comparison_indexed_once():
+ # `x == x` produces a Comparison with lhs_ == rhs_; IndexKnown
+ # must store it once, not twice.
+ analyzer = tvm.arith.Analyzer()
+ x = tirx.Var("x", "int32")
+ with analyzer.constraint_scope(x == x):
+ assert analyzer.can_prove(x == x)
+ analyzer.bind(x, tvm.ir.Range.from_min_extent(0, 10))
+ assert analyzer.can_prove(x >= 0)
+ assert analyzer.can_prove(x < 10)
+
+
+def test_transitively_prove_inequalities_uses_dfs_path():
+ # `i < j` and `j < k` (from For ranges) compose into `i < k` only
+ # when the DFS path runs (transitively_prove_inequalities=True).
+
+ @T.prim_func
+ def before(A: T.Buffer((1,), "int32")):
+ for i in T.serial(0, 50):
+ for j in T.serial(i + 1, 50):
+ for k in T.serial(j + 1, 50):
+ if i < k:
+ A[0] = 1
+ else:
+ A[0] = 0
+
+ @T.prim_func
+ def after_dfs(A: T.Buffer((1,), "int32")):
+ T.func_attr({"global_symbol": "before"})
+ for i in T.serial(0, 50):
+ for j in T.serial(i + 1, 50):
+ for k in T.serial(j + 1, 50):
+ A[0] = 1
+
+ mod = tvm.IRModule({"main": before})
+ expected = tvm.IRModule({"main": after_dfs})
+
+ with tvm.transform.PassContext(
+ config={"tirx.Simplify": {"transitively_prove_inequalities": True}}
+ ):
+ out_with_dfs = tvm.tirx.transform.Simplify()(mod)
Review Comment:
This branch is behind `main`, where `tvm.tirx.transform.Simplify` and the
config key `"tirx.Simplify"` were renamed to `StmtSimplify` and
`"tirx.StmtSimplify"`, respectively. I think that after rebasing onto current
`main` this test will fail unless `Simplify()` is changed to `StmtSimplify()`
and `"tirx.Simplify"` is changed to `"tirx.StmtSimplify"`.
##########
src/arith/transitive_comparison_analyzer.cc:
##########
@@ -566,20 +619,38 @@ void TransitiveComparisonAnalyzer::Impl::Bind(const
tirx::Var& var, const Range&
TVM_FFI_ICHECK(allow_override) << "Binding of variable " << var << " as
" << range
<< " conflicts with previous binding as "
<< (*it).second;
if (auto key = ExprToPreviousKey(var)) {
- knowns_.erase(std::remove_if(knowns_.begin(), knowns_.end(),
- [&](const auto& known) { return
known.lhs_ == key.value(); }),
- knowns_.end());
+ Key old_key = key.value();
+
+ // Every entry in `knowns_by_key_[old_key]` involves old_key by
+ // construction (on either side). Remove each from its partner
+ // bucket and then drop the whole old_key bucket in one go.
+ auto idx_it = knowns_by_key_.find(old_key);
+ if (idx_it != knowns_by_key_.end()) {
+ const std::vector<Comparison>& to_remove = idx_it->second;
+ for (const auto& cmp : to_remove) {
+ Key partner_key = (cmp.lhs_ == old_key) ? cmp.rhs_ : cmp.lhs_;
+ // self-comparison (lhs_ == rhs_): only stored once, in
+ // the bucket we are about to erase.
+ if (partner_key == old_key) continue;
+ auto other_it = knowns_by_key_.find(partner_key);
+ if (other_it == knowns_by_key_.end()) continue;
+ other_it->second.erase(
+ std::remove(other_it->second.begin(), other_it->second.end(),
cmp),
+ other_it->second.end());
+ }
+ knowns_by_key_.erase(idx_it);
+ }
}
}
}
prev_bindings_.Set(var, range);
Review Comment:
When `it != end` and `differs_from_previous == false` (a re-`Bind` of the same
var with the same range), we skip the cleanup but still fall through to
`prev_bindings_.Set(...)` and `AddKnown(...)`. That re-adds the same
comparisons to `knowns_by_key_` under both keys, so duplicates accumulate
across repeated binds. This inflates exactly the buckets this PR is trying to
keep small, and doubles the cost compared with the old flat `knowns_`. I
suggest gating the `Set` + `AddKnown` block so it runs only when the binding is
new or has actually changed. (gemini-code-assist raised this earlier; it still
looks unaddressed.)
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]