gemini-code-assist[bot] commented on code in PR #19584:
URL: https://github.com/apache/tvm/pull/19584#discussion_r3258206678


##########
src/arith/transitive_comparison_analyzer.cc:
##########
@@ -566,20 +619,45 @@ void TransitiveComparisonAnalyzer::Impl::Bind(const tirx::Var& var, const Range&
       TVM_FFI_ICHECK(allow_override) << "Binding of variable " << var << " as " << range
                                      << " conflicts with previous binding as " << (*it).second;
       if (auto key = ExprToPreviousKey(var)) {
-        knowns_.erase(std::remove_if(knowns_.begin(), knowns_.end(),
-                                     [&](const auto& known) { return known.lhs_ == key.value(); }),
-                      knowns_.end());
+        Key old_key = key.value();
+
+        // For every Comparison previously stored with lhs_==old_key,
+        // drop it from old_key's bucket AND from its companion entry
+        // stored under `cmp.rhs_`.  Snapshot the matches first because
+        // the source bucket and the partner buckets may overlap.
+        auto idx_it = knowns_by_key_.find(old_key);
+        if (idx_it != knowns_by_key_.end()) {
+          std::vector<Comparison> to_remove;
+          for (const auto& cmp : idx_it->second) {
+            if (cmp.lhs_ == old_key) to_remove.push_back(cmp);
+          }
+          for (const auto& cmp : to_remove) {
+            // self-comparison (lhs_==rhs_): only stored once, handled
+            // by the bucket-erase below.
+            if (cmp.rhs_ == old_key) continue;
+            auto other_it = knowns_by_key_.find(cmp.rhs_);
+            if (other_it == knowns_by_key_.end()) continue;
+            other_it->second.erase(
+                std::remove(other_it->second.begin(), other_it->second.end(), cmp),
+                other_it->second.end());
+          }
+          idx_it->second.erase(std::remove_if(idx_it->second.begin(), idx_it->second.end(),
+                                              [&](const Comparison& c) {
+                                                return c.lhs_ == old_key;
+                                              }),
+                               idx_it->second.end());

Review Comment:
   ![high](https://www.gstatic.com/codereviewagent/high-priority.svg)
   
   The current cleanup logic only removes comparisons where the overridden variable (`old_key`) is the `lhs_`. However, `IndexKnown` stores each comparison under both of its keys, and a range passed to `Bind` can, in principle, be an expression over other variables, so `old_key` might appear as the `rhs_` of comparisons recorded for some other variable. To keep the index consistent and correct, every comparison involving the overridden variable should be removed from both its own bucket and its partner's bucket.
   
   Additionally, since every entry in `idx_it->second` by definition involves `old_key`, we can simplify the logic by clearing the entire bucket after updating the partner buckets.
   
   ```cpp
             std::vector<Comparison> to_remove = idx_it->second;
             for (const auto& cmp : to_remove) {
               Key partner_key = (cmp.lhs_ == old_key) ? cmp.rhs_ : cmp.lhs_;
               if (partner_key == old_key) continue;
               auto other_it = knowns_by_key_.find(partner_key);
               if (other_it != knowns_by_key_.end()) {
                 other_it->second.erase(
                     std::remove(other_it->second.begin(), other_it->second.end(), cmp),
                     other_it->second.end());
               }
             }
             knowns_by_key_.erase(idx_it);
   ```
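
   To check that the suggested cleanup leaves no dangling entries, here is a minimal, self-contained model of the two-way index. It is a sketch under stated assumptions, not TVM code: `Key` is simplified to `int`, `Comparison` keeps only `lhs_`, `rhs_` and a hypothetical `offset_`, and the `KnownIndex` class with its `Add`, `RemoveAllInvolving`, `CountFor` and `Mentions` methods is hypothetical. Only the `knowns_by_key_` name and the removal loop follow the PR and the suggestion above.

   ```cpp
   #include <algorithm>
   #include <cassert>
   #include <cstddef>
   #include <unordered_map>
   #include <vector>

   // Hypothetical stand-ins for the analyzer's internal types (not TVM's).
   using Key = int;

   struct Comparison {
     Key lhs_;
     Key rhs_;
     int offset_;
     bool operator==(const Comparison& other) const {
       return lhs_ == other.lhs_ && rhs_ == other.rhs_ && offset_ == other.offset_;
     }
   };

   class KnownIndex {
    public:
     // Store each comparison under both of its keys (only once if lhs_ == rhs_).
     void Add(const Comparison& cmp) {
       knowns_by_key_[cmp.lhs_].push_back(cmp);
       if (cmp.rhs_ != cmp.lhs_) knowns_by_key_[cmp.rhs_].push_back(cmp);
     }

     // Drop every comparison that mentions old_key, as lhs_ or as rhs_:
     // first from each partner's bucket, then old_key's whole bucket.
     void RemoveAllInvolving(Key old_key) {
       auto idx_it = knowns_by_key_.find(old_key);
       if (idx_it == knowns_by_key_.end()) return;
       // Defensive snapshot, as in the suggestion; the loop below only
       // modifies partner buckets, never old_key's own bucket.
       std::vector<Comparison> to_remove = idx_it->second;
       for (const Comparison& cmp : to_remove) {
         Key partner_key = (cmp.lhs_ == old_key) ? cmp.rhs_ : cmp.lhs_;
         if (partner_key == old_key) continue;  // self-comparison, stored once
         auto other_it = knowns_by_key_.find(partner_key);
         if (other_it == knowns_by_key_.end()) continue;
         std::vector<Comparison>& bucket = other_it->second;
         bucket.erase(std::remove(bucket.begin(), bucket.end(), cmp), bucket.end());
       }
       knowns_by_key_.erase(idx_it);
     }

     // Number of comparisons stored in key's bucket (0 if there is no bucket).
     std::size_t CountFor(Key key) const {
       auto it = knowns_by_key_.find(key);
       return it == knowns_by_key_.end() ? 0 : it->second.size();
     }

     // True if any bucket still holds a comparison that mentions key.
     bool Mentions(Key key) const {
       for (const auto& entry : knowns_by_key_) {
         for (const Comparison& cmp : entry.second) {
           if (cmp.lhs_ == key || cmp.rhs_ == key) return true;
         }
       }
       return false;
     }

    private:
     std::unordered_map<Key, std::vector<Comparison>> knowns_by_key_;
   };
   ```

   For example, after adding `{1, 2, 0}`, `{3, 1, 5}` (key 1 as `rhs_`), `{1, 1, 0}` and `{2, 3, 1}`, calling `RemoveAllInvolving(1)` leaves only `{2, 3, 1}` in the buckets of keys 2 and 3. A cleanup that filters on `lhs_ == old_key` alone would keep `{3, 1, 5}` in the buckets of both keys 1 and 3.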



##########
tests/python/arith/test_arith_transitive_comparison.py:
##########
@@ -0,0 +1,170 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+"""Tests for TransitiveComparisonAnalyzer and the per-key index."""
+
+import pytest
+
+import tvm
+import tvm.ir
+import tvm.testing
+from tvm import tirx
+from tvm.script import tirx as T
+
+
+def test_single_bind_provability():
+    analyzer = tvm.arith.Analyzer()
+    x = tirx.Var("x", "int32")
+    analyzer.bind(x, tvm.ir.Range.from_min_extent(0, 100))
+    assert analyzer.can_prove(x >= 0)
+    assert analyzer.can_prove(x < 100)
+    assert analyzer.can_prove(x <= 99)
+    assert not analyzer.can_prove(x >= 1)
+
+
+def test_many_binds_correctness_preserved():
+    analyzer = tvm.arith.Analyzer()
+    vars_ = [tirx.Var(f"v{i}", "int32") for i in range(2048)]
+    for i, v in enumerate(vars_):
+        analyzer.bind(v, tvm.ir.Range.from_min_extent(i, 10))
+    for i in (0, len(vars_) // 2, len(vars_) - 1):
+        v = vars_[i]
+        assert analyzer.can_prove(v >= i)
+        assert analyzer.can_prove(v < i + 10)
+        assert not analyzer.can_prove(v >= i + 1)
+
+
+def test_bind_override_clears_old_constraints():
+    analyzer = tvm.arith.Analyzer()
+    x = tirx.Var("x", "int32")
+    analyzer.bind(x, tvm.ir.Range.from_min_extent(0, 100))
+    assert analyzer.can_prove(x < 100)
+    analyzer.bind(x, tvm.ir.Range.from_min_extent(50, 10), allow_override=True)
+    assert analyzer.can_prove(x >= 50)
+    assert analyzer.can_prove(x < 60)
+    assert analyzer.can_prove(x >= 0)
+    assert analyzer.can_prove(x < 100)

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   This test case does not effectively verify that old constraints are cleared. Because the new range `[50, 60)` is a subset of the old range `[0, 100)`, the old constraints (`x >= 0` and `x < 100`) follow from the new binding alone, so these assertions pass whether or not the old constraints were actually removed from the analyzer.
   
   To properly test the cleanup logic, consider overriding with a new range that does not imply the old one, for example replacing `[0, 100)` with `[200, 300)`. In that case, `analyzer.can_prove(x < 100)` should return `False` after the override.
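
   The reasoning above can be checked on a toy model. The sketch below is hypothetical and much simpler than TVM's `Analyzer`: `ToyAnalyzer`, `HalfOpenRange` and `CanProveLessThan` are illustrative names, and a variable's known facts are just the intersection of its retained ranges. One variant drops old ranges on override; the other keeps them, mimicking a missing cleanup.

   ```cpp
   #include <algorithm>
   #include <cassert>
   #include <vector>

   // A binding covers the half-open range [min, min + extent).
   struct HalfOpenRange {
     long min;
     long extent;
   };

   // Toy model, not TVM: the facts about x are the intersection of all
   // retained ranges.
   class ToyAnalyzer {
    public:
     explicit ToyAnalyzer(bool drop_old_on_override)
         : drop_old_on_override_(drop_old_on_override) {}

     void Bind(HalfOpenRange r) {
       if (drop_old_on_override_) ranges_.clear();  // correct cleanup
       ranges_.push_back(r);  // the stale variant keeps every earlier range
     }

     // Proves x < bound if every value allowed by all retained ranges is
     // below bound. An empty intersection (contradictory facts) proves anything.
     bool CanProveLessThan(long bound) const {
       if (ranges_.empty()) return false;
       long lo = ranges_.front().min;
       long hi = ranges_.front().min + ranges_.front().extent;
       for (const HalfOpenRange& r : ranges_) {
         lo = std::max(lo, r.min);
         hi = std::min(hi, r.min + r.extent);
       }
       if (lo >= hi) return true;  // contradiction
       return hi <= bound;         // the largest allowed value is hi - 1
     }

    private:
     bool drop_old_on_override_;
     std::vector<HalfOpenRange> ranges_;
   };
   ```

   With the subset override (`[0, 100)` then `[50, 60)`), both variants prove `x < 100`, so that assertion cannot detect stale constraints. With the disjoint override (`[0, 100)` then `[200, 300)`), only the variant that keeps stale ranges still proves `x < 100`, which is exactly the failure the suggested assertion would catch.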



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
