This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 9f907f5cd0 [Arith] Add Analyzer::Clone for deep-copying analyzer state (#19836)
9f907f5cd0 is described below
commit 9f907f5cd099891031a6e8eecb0bb22b963d1dc2
Author: Shushi Hong <[email protected]>
AuthorDate: Fri Jun 19 07:16:19 2026 -0400
[Arith] Add Analyzer::Clone for deep-copying analyzer state (#19836)
Copying an Analyzer handle only shares the same mutable AnalyzerObj, so a
pass previously had no way to snapshot the accumulated facts (variable
bounds, modular sets, rewrite/canonical bindings, integer-set domains,
literal constraints, transitive comparisons) and continue exploring
without mutating the original.
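This PR adds AnalyzerObj::Clone(), which allocates a fresh AnalyzerObj
and copies each sub-analyzer's persistent state through a new CopyFrom
method on every sub-analyzer. Parent back-pointers are re-established by
the fresh constructor rather than copied, and per-query/recursion
scratch state (including the rewrite-simplifier statistics) is left at
its defaults. Cloning while a With<ConstraintContext> scope is active is
not supported, because the clone would keep the scoped constraints
without the recovery functions that remove them. The new method is
exposed to Python as Analyzer.clone().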
---
include/tvm/arith/analyzer.h | 26 ++++++++
python/tvm/arith/analyzer.py | 18 ++++++
src/arith/analyzer.cc | 12 ++++
src/arith/canonical_simplify.cc | 4 ++
src/arith/const_int_bound.cc | 9 +++
src/arith/int_set.cc | 7 +++
src/arith/modular_set.cc | 6 ++
src/arith/rewrite_simplify.cc | 2 +
src/arith/rewrite_simplify.h | 7 +++
src/arith/transitive_comparison_analyzer.cc | 11 ++++
tests/cpp/arith_simplify_test.cc | 21 +++++++
tests/python/arith/test_arith_analyzer_object.py | 79 ++++++++++++++++++++++++
12 files changed, 202 insertions(+)
diff --git a/include/tvm/arith/analyzer.h b/include/tvm/arith/analyzer.h
index e635315e67..9aca5c1189 100644
--- a/include/tvm/arith/analyzer.h
+++ b/include/tvm/arith/analyzer.h
@@ -181,6 +181,7 @@ class ConstIntBoundAnalyzer {
friend class ConstraintContext;
explicit ConstIntBoundAnalyzer(AnalyzerObj* parent);
TVM_DLL ~ConstIntBoundAnalyzer();
+ void CopyFrom(const ConstIntBoundAnalyzer& other);
/*!
* \brief Update the internal state to enter constraint.
* \param constraint A constraint expression.
@@ -260,6 +261,7 @@ class ModularSetAnalyzer {
friend class ConstraintContext;
explicit ModularSetAnalyzer(AnalyzerObj* parent);
TVM_DLL ~ModularSetAnalyzer();
+ void CopyFrom(const ModularSetAnalyzer& other);
/*!
* \brief Update the internal state to enter constraint.
* \param constraint A constraint expression.
@@ -414,6 +416,7 @@ class RewriteSimplifier {
friend class CanonicalSimplifier;
explicit RewriteSimplifier(AnalyzerObj* parent);
TVM_DLL ~RewriteSimplifier();
+ void CopyFrom(const RewriteSimplifier& other);
class Impl;
/*! \brief Internal impl */
Impl* impl_;
@@ -445,6 +448,7 @@ class CanonicalSimplifier {
friend class ConstraintContext;
explicit CanonicalSimplifier(AnalyzerObj* parent);
TVM_DLL ~CanonicalSimplifier();
+ void CopyFrom(const CanonicalSimplifier& other);
class Impl;
/*! \brief Internal impl */
Impl* impl_;
@@ -530,6 +534,7 @@ class TransitiveComparisonAnalyzer {
friend class ConstraintContext;
TransitiveComparisonAnalyzer();
TVM_DLL ~TransitiveComparisonAnalyzer();
+ void CopyFrom(const TransitiveComparisonAnalyzer& other);
class Impl;
/*! \brief Internal impl */
std::unique_ptr<Impl> impl_;
@@ -584,6 +589,7 @@ class IntSetAnalyzer {
friend class AnalyzerObj;
explicit IntSetAnalyzer(AnalyzerObj* parent);
TVM_DLL ~IntSetAnalyzer();
+ void CopyFrom(const IntSetAnalyzer& other);
class Impl;
/*! \brief Internal impl */
Impl* impl_;
@@ -854,6 +860,26 @@ class TVM_DLL AnalyzerObj : public ffi::Object {
*/
PrimExpr Simplify(const PrimExpr& expr, int steps = 2);
+ /*!
+ * \brief Deep-copy this analyzer into a new, independent Analyzer.
+ *
+ * The returned analyzer carries the same accumulated facts (variable
+ * bounds, modular sets, rewrite/canonical bindings, integer-set domains,
+ * literal constraints and transitive comparisons) as this one, but owns
+ * its own state: binding or simplifying on either analyzer afterwards does
+ * not affect the other. This is the deep copy that handle-copying an
+ * Analyzer does not provide.
+ *
+ * \note Do not call this while a `With<ConstraintContext>` scope is active
+ * on this analyzer. The clone would inherit the scoped constraints
+ * but not the recovery functions that pop them on scope exit, so the
+ * constraints would leak as if they were global facts. Clone at a
+ * point where no constraint scope is in effect.
+ *
+ * \return A new Analyzer holding an independent copy of the facts.
+ */
+ Analyzer Clone() const;
+
/*!
* \brief Analyzer methods update facts, constraints, caches, and stats.
*
diff --git a/python/tvm/arith/analyzer.py b/python/tvm/arith/analyzer.py
index 78e93395c3..d82cae0129 100644
--- a/python/tvm/arith/analyzer.py
+++ b/python/tvm/arith/analyzer.py
@@ -278,6 +278,24 @@ class Analyzer(Object):
"""
return _ffi_api.AnalyzerSimplify(self, expr, steps)
+ def clone(self) -> "Analyzer":
+ """Return a deep copy of this analyzer with independent state.
+
+ The returned analyzer carries the same accumulated facts (variable
+ bounds, modular sets, bindings, integer-set domains, literal
+ constraints and transitive comparisons) as this one, but owns its own
+ state: binding or simplifying on either analyzer afterwards does not
+ affect the other. Unlike copying the handle, this is a true deep copy.
+
+ Do not call this while a constraint scope is active on this analyzer.
+
+ Returns
+ -------
+ result : Analyzer
+ A new analyzer holding an independent copy of the facts.
+ """
+ return _ffi_api.AnalyzerClone(self)
+
def rewrite_simplify(self, expr: tirx.PrimExpr) -> tirx.PrimExpr:
"""Simplify expression via rewriting rules.
diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc
index b66ecb0fd1..fc59f891e1 100644
--- a/src/arith/analyzer.cc
+++ b/src/arith/analyzer.cc
@@ -265,11 +265,23 @@ PrimExpr AnalyzerObj::Simplify(const PrimExpr& expr, int
steps) {
return res;
}
+Analyzer AnalyzerObj::Clone() const {
+ Analyzer cloned;
+ cloned->const_int_bound.CopyFrom(this->const_int_bound);
+ cloned->modular_set.CopyFrom(this->modular_set);
+ cloned->rewrite_simplify.CopyFrom(this->rewrite_simplify);
+ cloned->canonical_simplify.CopyFrom(this->canonical_simplify);
+ cloned->int_set.CopyFrom(this->int_set);
+ cloned->transitive_comparisons.CopyFrom(this->transitive_comparisons);
+ return cloned;
+}
+
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<AnalyzerObj>();
refl::GlobalDef()
.def("arith.Analyzer", []() { return Analyzer(); })
+ .def("arith.AnalyzerClone", [](Analyzer analyzer) { return
analyzer->Clone(); })
.def("arith.AnalyzerConstIntBound",
[](Analyzer analyzer, const PrimExpr& expr) { return
analyzer->const_int_bound(expr); })
.def("arith.AnalyzerConstIntBoundUpdate",
diff --git a/src/arith/canonical_simplify.cc b/src/arith/canonical_simplify.cc
index f1dd1a63c5..7806c23445 100644
--- a/src/arith/canonical_simplify.cc
+++ b/src/arith/canonical_simplify.cc
@@ -1454,5 +1454,9 @@ CanonicalSimplifier::CanonicalSimplifier(AnalyzerObj*
parent) : impl_(new Impl(p
CanonicalSimplifier::~CanonicalSimplifier() { delete impl_; }
+void CanonicalSimplifier::CopyFrom(const CanonicalSimplifier& other) {
+ impl_->CopyFrom(*other.impl_);
+}
+
} // namespace arith
} // namespace tvm
diff --git a/src/arith/const_int_bound.cc b/src/arith/const_int_bound.cc
index 8ff1a8b17e..4d700564ea 100644
--- a/src/arith/const_int_bound.cc
+++ b/src/arith/const_int_bound.cc
@@ -498,6 +498,11 @@ class ConstIntBoundAnalyzer::Impl
return frecover;
}
+ void CopyFrom(const Impl& other) {
+ var_map_ = other.var_map_;
+ additional_info_ = other.additional_info_;
+ }
+
private:
friend class ConstIntBoundAnalyzer;
// parent analyzer
@@ -859,5 +864,9 @@ ConstIntBoundAnalyzer::ConstIntBoundAnalyzer(AnalyzerObj*
parent) : impl_(new Im
ConstIntBoundAnalyzer::~ConstIntBoundAnalyzer() { delete impl_; }
+void ConstIntBoundAnalyzer::CopyFrom(const ConstIntBoundAnalyzer& other) {
+ impl_->CopyFrom(*other.impl_);
+}
+
} // namespace arith
} // namespace tvm
diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc
index a1e01d3e86..b68042e2af 100644
--- a/src/arith/int_set.cc
+++ b/src/arith/int_set.cc
@@ -658,6 +658,11 @@ class IntSetAnalyzer::Impl {
void Bind(const Var& var, const PrimExpr& expr, bool override_info);
std::function<void()> EnterConstraint(const PrimExpr& constraint);
+ void CopyFrom(const Impl& other) {
+ dom_map_ = other.dom_map_;
+ dom_constraints_ = other.dom_constraints_;
+ }
+
private:
// Utility function to split a boolean condition into the domain
// bounds implied by that condition.
@@ -681,6 +686,8 @@ IntSetAnalyzer::IntSetAnalyzer(AnalyzerObj* parent) :
impl_(new Impl(parent)) {}
IntSetAnalyzer::~IntSetAnalyzer() { delete impl_; }
+void IntSetAnalyzer::CopyFrom(const IntSetAnalyzer& other) {
impl_->CopyFrom(*other.impl_); }
+
IntSet IntSetAnalyzer::operator()(const PrimExpr& expr, const ffi::Map<Var,
IntSet>& dom_map) {
return impl_->Eval(expr, dom_map);
}
diff --git a/src/arith/modular_set.cc b/src/arith/modular_set.cc
index 5f66356e1a..856f5df0b7 100644
--- a/src/arith/modular_set.cc
+++ b/src/arith/modular_set.cc
@@ -310,6 +310,8 @@ class ModularSetAnalyzer::Impl : public
ExprFunctor<ModularSetAnalyzer::Entry(co
return Everything();
}
+ void CopyFrom(const Impl& other) { var_map_ = other.var_map_; }
+
private:
/*! \brief pointer to parent. */
AnalyzerObj* parent_{nullptr};
@@ -407,5 +409,9 @@ ModularSetAnalyzer::ModularSetAnalyzer(AnalyzerObj* parent)
: impl_(new Impl(par
ModularSetAnalyzer::~ModularSetAnalyzer() { delete impl_; }
+void ModularSetAnalyzer::CopyFrom(const ModularSetAnalyzer& other) {
+ impl_->CopyFrom(*other.impl_);
+}
+
} // namespace arith
} // namespace tvm
diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc
index 6d6ce03016..a1bbce1072 100644
--- a/src/arith/rewrite_simplify.cc
+++ b/src/arith/rewrite_simplify.cc
@@ -2466,6 +2466,8 @@ RewriteSimplifier::RewriteSimplifier(AnalyzerObj* parent)
: impl_(new Impl(paren
RewriteSimplifier::~RewriteSimplifier() { delete impl_; }
+void RewriteSimplifier::CopyFrom(const RewriteSimplifier& other) {
impl_->CopyFrom(*other.impl_); }
+
// Pattern A (RM): auto-default repr from reflection.
} // namespace arith
diff --git a/src/arith/rewrite_simplify.h b/src/arith/rewrite_simplify.h
index b42b73336a..719aa5ec07 100644
--- a/src/arith/rewrite_simplify.h
+++ b/src/arith/rewrite_simplify.h
@@ -135,6 +135,13 @@ class RewriteSimplifier::Impl : public
IRMutatorWithAnalyzer {
void SetMaximumRewriteSteps(int64_t maximum) { maximum_rewrite_steps_ =
maximum; }
+ void CopyFrom(const Impl& other) {
+ var_map_ = other.var_map_;
+ literal_constraints_ = other.literal_constraints_;
+ enabled_extensions_ = other.enabled_extensions_;
+ maximum_rewrite_steps_ = other.maximum_rewrite_steps_;
+ }
+
protected:
int64_t maximum_rewrite_steps_{0};
RewriteSimplifierStatsNode stats_;
diff --git a/src/arith/transitive_comparison_analyzer.cc
b/src/arith/transitive_comparison_analyzer.cc
index e7deea4cfd..20fd05169f 100644
--- a/src/arith/transitive_comparison_analyzer.cc
+++ b/src/arith/transitive_comparison_analyzer.cc
@@ -82,6 +82,13 @@ class TransitiveComparisonAnalyzer::Impl {
*/
std::function<void()> EnterConstraint(const PrimExpr& expr);
+ void CopyFrom(const Impl& other) {
+ expr_to_key = other.expr_to_key;
+ prev_bindings_ = other.prev_bindings_;
+ knowns_ = other.knowns_;
+ scoped_knowns_ = other.scoped_knowns_;
+ }
+
private:
/* \brief Internal representation of a PrimExpr
*
@@ -528,6 +535,10 @@ bool
TransitiveComparisonAnalyzer::Impl::Comparison::Implies(
TransitiveComparisonAnalyzer::TransitiveComparisonAnalyzer() :
impl_(std::make_unique<Impl>()) {}
TransitiveComparisonAnalyzer::~TransitiveComparisonAnalyzer() {}
+void TransitiveComparisonAnalyzer::CopyFrom(const
TransitiveComparisonAnalyzer& other) {
+ impl_->CopyFrom(*other.impl_);
+}
+
CompareResult TransitiveComparisonAnalyzer::TryCompare(const PrimExpr& lhs,
const PrimExpr& rhs,
bool
propagate_inequalities) {
return impl_->TryCompare(lhs, rhs, propagate_inequalities);
diff --git a/tests/cpp/arith_simplify_test.cc b/tests/cpp/arith_simplify_test.cc
index ba5305e9dd..d5050446d6 100644
--- a/tests/cpp/arith_simplify_test.cc
+++ b/tests/cpp/arith_simplify_test.cc
@@ -75,6 +75,27 @@ TEST(AnalyzerObjectRef,
ConstHandleRefCanMutateAnalyzerState) {
TVM_FFI_ICHECK(analyzer->CanProve(x < 8));
}
+TEST(AnalyzerObjectRef, CloneIsIndependent) {
+ tvm::arith::Analyzer analyzer;
+ auto x = tvm::te::var("x");
+ auto y = tvm::te::var("y");
+
+ analyzer->Bind(x, tvm::Range::FromMinExtent(0, 8));
+ analyzer->modular_set.Update(x, tvm::arith::ModularSet(4, 0));
+
+ tvm::arith::Analyzer clone = analyzer->Clone();
+ TVM_FFI_ICHECK(clone->CanProve(x < 8));
+ TVM_FFI_ICHECK(clone->modular_set(x)->coeff == 4);
+
+ clone->Bind(y, tvm::Range::FromMinExtent(0, 4));
+ clone->modular_set.Update(x, tvm::arith::ModularSet(8, 0), true);
+ TVM_FFI_ICHECK(clone->CanProve(y < 4));
+ TVM_FFI_ICHECK(!analyzer->CanProve(y < 4));
+ TVM_FFI_ICHECK(analyzer->CanProve(x < 8));
+ TVM_FFI_ICHECK(analyzer->modular_set(x)->coeff == 4);
+ TVM_FFI_ICHECK(clone->modular_set(x)->coeff == 8);
+}
+
TEST(ConstantFold, Broadcast) {
tvm::ffi::StructuralEqual checker;
auto i32x4 = tvm::tirx::Broadcast(tvm::IntImm::Int32(10), 4);
diff --git a/tests/python/arith/test_arith_analyzer_object.py
b/tests/python/arith/test_arith_analyzer_object.py
index 4b4c4134b9..9edd75d7aa 100644
--- a/tests/python/arith/test_arith_analyzer_object.py
+++ b/tests/python/arith/test_arith_analyzer_object.py
@@ -204,5 +204,84 @@ def test_analyzer_object_state_persists_across_ffi_calls():
tvm.ir.assert_structural_equal(analyzer.simplify(tile), tvm.tirx.const(8,
"int32"))
+def test_analyzer_object_clone_is_independent():
+ analyzer = tvm.arith.Analyzer()
+ x = tirx.Var("x", "int64")
+ y = tirx.Var("y", "int64")
+ z = tirx.Var("z", "int64")
+
+ analyzer.bind(x, tvm.ir.Range(0, 8))
+
+ clone = analyzer.clone()
+ assert clone is not analyzer
+ assert clone.can_prove(x < 8)
+
+ clone.bind(y, tvm.ir.Range(0, 4))
+ assert clone.can_prove(y < 4)
+ assert not analyzer.can_prove(y < 4)
+
+ analyzer.bind(z, tvm.ir.Range(0, 4))
+ assert analyzer.can_prove(z < 4)
+ assert not clone.can_prove(z < 4)
+
+ assert analyzer.can_prove(x < 8)
+ assert clone.can_prove(x < 8)
+
+
+def test_analyzer_object_clone_copies_every_sub_analyzer():
+ analyzer = tvm.arith.Analyzer()
+ x = tirx.Var("x", "int64")
+ w = tirx.Var("w", "int64")
+ v = tirx.Var("v", "int64")
+
+ analyzer.bind(x, tvm.ir.Range(0, 8))
+ analyzer.update(x, tvm.arith.ModularSet(4, 0))
+ analyzer.bind(w, tirx.const(4, "int64"))
+ analyzer.update(v, tvm.arith.IntervalSet(2, 9))
+ analyzer.enabled_extensions = Extension.ComparisonOfProductAndSum
+
+ clone = analyzer.clone()
+
+ assert clone.can_prove(x < 8)
+ assert clone.modular_set(x).coeff == 4
+ tvm.ir.assert_structural_equal(clone.simplify(w + 1), tirx.const(5,
"int64"))
+ assert clone.int_set(v).max_value.value == 9
+ assert clone.enabled_extensions == Extension.ComparisonOfProductAndSum
+ assert clone.try_compare(x, tirx.const(0, "int64")) == CompareResult.GE
+
+ t = tirx.Var("t", "int64")
+ clone.update(x, tvm.arith.ModularSet(8, 0), override=True)
+ clone.update(v, tvm.arith.IntervalSet(0, 3), override=True)
+ clone.bind(w, tirx.const(8, "int64"), allow_override=True)
+ clone.bind(t, tvm.ir.Range(0, 4))
+ clone.enabled_extensions = Extension.NoExtensions
+
+ assert analyzer.modular_set(x).coeff == 4
+ assert clone.modular_set(x).coeff == 8
+ assert analyzer.int_set(v).max_value.value == 9
+ assert clone.int_set(v).max_value.value == 3
+ tvm.ir.assert_structural_equal(analyzer.simplify(w + 1), tirx.const(5,
"int64"))
+ tvm.ir.assert_structural_equal(clone.simplify(w + 1), tirx.const(9,
"int64"))
+ assert analyzer.enabled_extensions == Extension.ComparisonOfProductAndSum
+ assert clone.enabled_extensions == Extension.NoExtensions
+ assert clone.try_compare(t, tirx.const(0, "int64")) == CompareResult.GE
+ assert analyzer.try_compare(t, tirx.const(0, "int64")) ==
CompareResult.UNKNOWN
+
+
+def test_analyzer_object_clone_resets_rewrite_stats():
+ analyzer = tvm.arith.Analyzer()
+ x = tirx.Var("x", "int64")
+ y = tirx.Var("y", "int64")
+ analyzer.bind(x, tvm.ir.Range(0, 8))
+ analyzer.bind(y, tvm.ir.Range(0, 8))
+ analyzer.simplify((x + y) * 2 - x - y)
+ source_attempts = analyzer.rewrite_simplify_stats.rewrites_attempted
+ assert source_attempts > 0
+
+ clone = analyzer.clone()
+ assert clone.rewrite_simplify_stats.rewrites_attempted == 0
+ assert analyzer.rewrite_simplify_stats.rewrites_attempted ==
source_attempts
+
+
if __name__ == "__main__":
tvm.testing.main()