This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new e7b87fe6fc [ARITH] Add optional Z3-backed proving to Analyzer (#19667)
e7b87fe6fc is described below
commit e7b87fe6fca575ffc1a1cd5c2be0d829bb2e3f3b
Author: Yixin Dong <[email protected]>
AuthorDate: Wed Jun 17 20:39:39 2026 -0400
[ARITH] Add optional Z3-backed proving to Analyzer (#19667)
## Summary
This PR adds a Z3 SMT solver backend to `tvm::arith::Analyzer` for
stronger integer arithmetic proving.
The integration is guarded by `USE_Z3`, which defaults to `AUTO`. In the
default mode, TVM enables Z3 when the static Z3 development artifacts
are available and otherwise builds the conservative stub implementation.
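
The `USE_Z3` handling can be summarized as a small decision function. The Python below is a hypothetical model written for illustration. It is not code from this commit: `resolve_z3_build` and `FALSE_VALUES` are invented names, and `FALSE_VALUES` only approximates the `IS_FALSE_PATTERN` check used by `cmake/modules/contrib/Z3.cmake`.

```python
# Hypothetical Python model of the USE_Z3 decision in cmake/modules/contrib/Z3.cmake.
# FALSE_VALUES is an assumption standing in for TVM's IS_FALSE_PATTERN regex.
FALSE_VALUES = {"off", "0", "false", "no", "n", ""}


def resolve_z3_build(use_z3: str, z3_found: bool) -> str:
    """Return "real" (TVM_USE_Z3 defined) or "stub"; raise if ON but Z3 is missing."""
    value = use_z3.strip().lower()
    if value in FALSE_VALUES:
        return "stub"  # Z3.cmake returns early, so TVM_USE_Z3 is never defined
    required = value != "auto"  # any other non-false value behaves like ON
    if z3_found:
        return "real"  # add_compile_definitions(TVM_USE_Z3)
    if required:
        raise RuntimeError("USE_Z3 is ON, but Z3 was not found.")
    return "stub"  # AUTO without Z3 artifacts: keep building with the stub


for mode in ("OFF", "AUTO", "ON"):
    for found in (False, True):
        try:
            outcome = resolve_z3_build(mode, found)
        except RuntimeError as err:
            outcome = f"error: {err}"
        print(f"USE_Z3={mode:<4} z3_found={found!s:<5} -> {outcome}")
```

Note that `pyproject.toml` in this commit sets `USE_Z3 = "ON"` for pip builds. In that path a missing `z3-static` is a hard configure error, not a silent fallback to the stub.
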
When Z3 is enabled, `Analyzer::CanProve` runs the existing TVM
arithmetic analysis path first, then falls back to Z3 only when the
existing analyzers cannot prove the predicate and the requested strength
is at least `kSymbolicBound`. Z3 is linked statically from the PyPI `z3-static`
package, so `libtvm` does not need a runtime `libz3` dependency.
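
The control flow of the Z3 fallback in `AnalyzerObj::CanProve` can be modeled in a few lines of Python. This is a simplified sketch under stated assumptions:

- `can_prove`, `Prover`, and `fake_z3` are hypothetical names.
- Expressions are plain strings.
- The native analyzers are collapsed into a list of predicates.

The real code also simplifies the expression before handing it to Z3, and it turns any `z3::exception` into "cannot prove".

```python
from enum import IntEnum
from typing import Callable, Sequence

# Hypothetical stand-in: a "prover" is any predicate over an expression string.
Prover = Callable[[str], bool]


class ProofStrength(IntEnum):
    """Ordered like the C++ enum used in the diff: kDefault < kSymbolicBound."""
    DEFAULT = 0
    SYMBOLIC_BOUND = 1


def can_prove(expr: str, strength: ProofStrength,
              native_provers: Sequence[Prover], z3_prover: Prover) -> bool:
    # 1. The existing TVM analyzers always run first, at every strength.
    if any(prove(expr) for prove in native_provers):
        return True
    # 2. Z3 is an expensive fallback, consulted only at SYMBOLIC_BOUND or above.
    return strength >= ProofStrength.SYMBOLIC_BOUND and z3_prover(expr)


z3_calls = []


def fake_z3(expr: str) -> bool:
    z3_calls.append(expr)  # record every time the "solver" is invoked
    return expr == "x * x >= 0"


native = [lambda e: e == "x + 1 > x"]
print(can_prove("x * x >= 0", ProofStrength.DEFAULT, native, fake_z3), z3_calls)
print(can_prove("x * x >= 0", ProofStrength.SYMBOLIC_BOUND, native, fake_z3), z3_calls)
```

At the default strength the fake solver is never called. At `SYMBOLIC_BOUND` it is called once, and only because the native predicate failed. This mirrors why the common `kDefault` path pays no solver cost.
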
## Features
- Z3 build support through `USE_Z3`, defaulting to `AUTO`.
- A new `arith::Z3Prover` sub-analyzer owned by `arith::Analyzer`.
- SMT-LIB2 export for debugging and external solver reproduction.
- Python debug/config APIs: `Analyzer.get_smtlib2`,
`Analyzer.set_z3_timeout_ms`, `Analyzer.set_z3_rlimit`, and
`Analyzer.get_z3_stats`.
- C++ APIs for proving, binding, constraints, stats, model inspection,
and satisfying-value counting.
- Scalar integer, unsigned integer, and boolean expression translation
to Z3.
- Support for arithmetic, comparisons, boolean operators, `min`, `max`,
`select`, `if_then_else`, `let`, casts, truncated division/modulo, floor
division/modulo, and selected bitwise/shift operations.
- Deterministic solver control using Z3 `rlimit`, with `random_seed`
fixed to `42`.
- Thread-local Z3 context sharing to reduce initialization overhead
while keeping thread safety.
- A disabled-mode stub implementation that returns conservative results
when Z3 is not built.
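
The SMT-LIB2 export poses a validity question in the standard way. The negation of the goal is asserted on top of the current constraints, and an `unsat` answer means the goal always holds.

The sketch below builds such a query by hand for variables bound to half-open ranges. `to_smtlib2` is a hypothetical helper, not the `Analyzer.get_smtlib2` implementation. The real output (variable names, comments, declarations) may differ.

```python
from typing import Dict, Tuple


def to_smtlib2(var_ranges: Dict[str, Tuple[int, int]], goal: str) -> str:
    """Build a validity query: `goal` always holds iff the query is unsat."""
    lines = []
    for name, (lo, hi) in var_ranges.items():
        lines.append(f"(declare-fun {name} () Int)")
        # Half-open [lo, hi), the bound Z3Prover::Bind adds for a constant Range.
        lines.append(f"(assert (and (<= {lo} {name}) (< {name} {hi})))")
    # Assert the negation of the goal, as get_smtlib2(expr) does for `expr`.
    lines.append(f"(assert (not {goal}))")
    lines.append("(check-sat)")
    return "\n".join(lines)


print(to_smtlib2({"x": (0, 16)}, "(> (+ x 1) x)"))
```

Save the printed text to a file and run an external solver on it, for example `z3 query.smt2`. It should report `unsat`, since `x + 1 > x` holds for every integer `x`.
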
## Implementation Notes
- The real and stub implementations live in `src/arith/z3_prover.cc`,
selected by the `TVM_USE_Z3` macro from
`cmake/modules/contrib/Z3.cmake`.
- `cmake/modules/contrib/Z3.cmake` resolves the PIC static `libz3`
layout provided by `z3-static` by running
`python -m z3_static.config --cmake-dir`, unless a custom `Z3_DIR` is
already set. If neither is available, `find_package` falls back to a Z3
installation on `CMAKE_PREFIX_PATH`.
- `USE_Z3=ON` requires Z3 to be found, while `USE_Z3=AUTO` allows source
builds and CI jobs without Z3 artifacts to continue with the stub.
- The Z3 fallback is exception-safe and gated behind `kSymbolicBound`,
so the common `kDefault` path does not pay solver cost.
- TVM `Div` and `Mod` are translated with truncating helpers rather than
Z3's Euclidean operators to stay sound for negative dividends.
- Shift handling relies on Z3's native bit-vector semantics and does not
add hard assertions to the shared solver.
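
The `Div` and `Mod` note matters because the two operators disagree on negative dividends:

- SMT-LIB integer `div`/`mod`, which Z3 implements, choose the quotient so that the remainder is never negative.
- TVM's `Div`/`Mod` truncate toward zero, like C.

The two agree for non-negative dividends. The Python sketch below shows the difference and one way to build truncating division on top of the Euclidean operator. The function names are hypothetical, and the actual helpers in `src/arith/z3_prover.cc` are not part of the excerpt shown here.

```python
def euclid_div(a: int, b: int) -> int:
    """SMT-LIB `div`: a == b * q + r with 0 <= r < |b| (b != 0)."""
    r = a % abs(b)  # Python % with a positive modulus is never negative
    return (a - r) // b  # exact: a - r is a multiple of b


def euclid_mod(a: int, b: int) -> int:
    """SMT-LIB `mod`: always in [0, |b|)."""
    return a % abs(b)


def trunc_div(a: int, b: int) -> int:
    """TVM/C-style division rounding toward zero, built on euclid_div."""
    # For a >= 0 the two operators agree; otherwise negate around the dividend.
    return euclid_div(a, b) if a >= 0 else -euclid_div(-a, b)


def trunc_mod(a: int, b: int) -> int:
    """TVM/C-style remainder: it takes the sign of the dividend."""
    return a - b * trunc_div(a, b)


for a, b in [(7, 2), (-7, 2), (7, -2), (-7, -2)]:
    print(f"a={a:>3} b={b:>3}  div={euclid_div(a, b):>3} mod={euclid_mod(a, b)}"
          f"  truncdiv={trunc_div(a, b):>3} truncmod={trunc_mod(a, b):>3}")
```

For `a = -7, b = 2`, Z3's `div`/`mod` give `-4` and `1`, while TVM's `Div`/`Mod` must give `-3` and `-1`. Using Z3's operators directly would make the prover reason about the wrong values.
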
## References
The implementation is based on the Z3 analyzer integration used in
TileLang's TVM fork, with the upstream port kept scoped to TVM's
arithmetic analyzer.
- [tile-ai/tilelang#1367](https://github.com/tile-ai/tilelang/pull/1367)
- [tile-ai/tilelang#1458](https://github.com/tile-ai/tilelang/pull/1458)
- [tile-ai/tilelang#2216](https://github.com/tile-ai/tilelang/pull/2216)
- [TileLang/tvm#22](https://github.com/TileLang/tvm/pull/22)
- [TileLang/tvm#24](https://github.com/TileLang/tvm/pull/24)
- [Original TileLang TVM commit](https://github.com/tile-ai/tvm/commit/e633295de994a89668d7a9930dbbd455af3efc66)
---------
Signed-off-by: Ubospica <[email protected]>
---
CMakeLists.txt | 5 +
cmake/modules/LLVM.cmake | 13 +
cmake/modules/contrib/Z3.cmake | 93 ++++
include/tvm/arith/analyzer.h | 130 +++++-
pyproject.toml | 9 +-
python/tvm/arith/analyzer.py | 89 +++-
src/arith/analyzer.cc | 28 +-
src/arith/z3_prover.cc | 864 ++++++++++++++++++++++++++++++++++++
tests/python/arith/test_arith_z3.py | 756 +++++++++++++++++++++++++++++++
9 files changed, 1982 insertions(+), 5 deletions(-)
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 57d76ceb66..567edc1dc6 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -89,6 +89,7 @@ tvm_option(COMPILER_RT_PATH "Path to COMPILER-RT" "3rdparty/compiler-rt")
# Contrib library options
tvm_option(USE_BLAS "The blas library to be linked" none)
tvm_option(USE_AMX "Enable Intel AMX" OFF)
+tvm_option(USE_Z3 "Build with Z3 SMT solver support" AUTO)
tvm_option(USE_MKL "MKL root path when use MKL blas" OFF)
tvm_option(USE_DNNL "Enable DNNL codegen" OFF)
tvm_option(USE_CUDNN "Build with cuDNN" OFF)
@@ -459,6 +460,7 @@ include(cmake/modules/contrib/AMX.cmake)
include(cmake/modules/contrib/CUTLASS.cmake)
include(cmake/modules/contrib/Random.cmake)
include(cmake/modules/contrib/Sort.cmake)
+include(cmake/modules/contrib/Z3.cmake)
include(cmake/modules/contrib/CoreML.cmake)
include(cmake/modules/contrib/TensorRT.cmake)
include(cmake/modules/contrib/NNAPI.cmake)
@@ -545,6 +547,9 @@ add_library(tvm_objs OBJECT ${COMPILER_SRCS})
add_library(tvm_runtime_objs OBJECT ${RUNTIME_SRCS})
target_link_libraries(tvm_objs PUBLIC tvm_ffi_header)
target_link_libraries(tvm_runtime_objs PUBLIC tvm_ffi_header)
+if(TARGET tvm_llvm_header)
+ target_link_libraries(tvm_objs PUBLIC tvm_llvm_header)
+endif()
include(GNUInstallDirs)
diff --git a/cmake/modules/LLVM.cmake b/cmake/modules/LLVM.cmake
index 8ad8bfc53c..a7fa4233ac 100644
--- a/cmake/modules/LLVM.cmake
+++ b/cmake/modules/LLVM.cmake
@@ -34,6 +34,19 @@ if(NOT ${USE_LLVM} MATCHES ${IS_FALSE_PATTERN})
endif()
include_directories(SYSTEM ${LLVM_INCLUDE_DIRS})
add_definitions(${LLVM_DEFINITIONS})
+ add_library(tvm_llvm_header INTERFACE)
+ if(MSVC)
+ # MSVC treats GCC-style -isystem operands as source files.
+ target_include_directories(tvm_llvm_header SYSTEM INTERFACE ${LLVM_INCLUDE_DIRS})
+ target_compile_options(tvm_llvm_header INTERFACE ${LLVM_DEFINITIONS})
+ else()
+ set(TVM_LLVM_INCLUDE_FLAGS "")
+ foreach(__llvm_include_dir IN LISTS LLVM_INCLUDE_DIRS)
+ string(STRIP "${__llvm_include_dir}" __llvm_include_dir)
+ list(APPEND TVM_LLVM_INCLUDE_FLAGS "-isystem" "${__llvm_include_dir}")
+ endforeach()
+ target_compile_options(tvm_llvm_header INTERFACE ${TVM_LLVM_INCLUDE_FLAGS} ${LLVM_DEFINITIONS})
+ endif()
message(STATUS "Build with LLVM " ${LLVM_PACKAGE_VERSION})
message(STATUS "Set TVM_LLVM_VERSION=" ${TVM_LLVM_VERSION})
# Set flags that are only needed for LLVM target
diff --git a/cmake/modules/contrib/Z3.cmake b/cmake/modules/contrib/Z3.cmake
new file mode 100644
index 0000000000..5d9af4408f
--- /dev/null
+++ b/cmake/modules/contrib/Z3.cmake
@@ -0,0 +1,93 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# src/arith/z3_prover.cc is always part of COMPILER_SRCS (picked up by the
+# src/arith/*.cc glob). It compiles a conservative stub by default and switches
+# to the real Z3 implementation only when the TVM_USE_Z3 macro is defined below.
+if(${USE_Z3} MATCHES ${IS_FALSE_PATTERN})
+ return()
+endif()
+
+set(TVM_Z3_REQUIRED TRUE)
+if("${USE_Z3}" MATCHES "^[Aa][Uu][Tt][Oo]$")
+ set(TVM_Z3_REQUIRED FALSE)
+endif()
+
+# Default lookup: the PIC static Z3 library shipped by the PyPI `z3-static`
+# package (headers + libz3.a + Z3 CMake package files). Linking it statically
+# keeps libtvm free of a runtime libz3 dependency. Users can override the
+# lookup by setting Z3_DIR/CMAKE_PREFIX_PATH to any Z3 installation (e.g. a
+# shared system Z3).
+if(NOT Z3_DIR)
+ find_package(Python3 COMPONENTS Interpreter QUIET)
+ if(Python3_EXECUTABLE)
+ execute_process(
+ COMMAND
+ "${Python3_EXECUTABLE}" -m z3_static.config --cmake-dir
+ OUTPUT_VARIABLE Z3_STATIC_CMAKE_DIR
+ OUTPUT_STRIP_TRAILING_WHITESPACE
+ ERROR_QUIET
+ RESULT_VARIABLE Z3_STATIC_RESULT
+ )
+ if(Z3_STATIC_RESULT EQUAL 0 AND EXISTS "${Z3_STATIC_CMAKE_DIR}")
+ set(Z3_DIR "${Z3_STATIC_CMAKE_DIR}")
+ endif()
+ endif()
+endif()
+
+find_package(Z3 CONFIG QUIET)
+if(NOT Z3_FOUND AND NOT TARGET z3::libz3 AND NOT TARGET Z3::libz3)
+ find_package(Z3 QUIET)
+endif()
+
+if(TARGET z3::libz3 OR TARGET Z3::libz3)
+ if(TARGET z3::libz3)
+ set(Z3_TARGET z3::libz3)
+ else()
+ set(Z3_TARGET Z3::libz3)
+ endif()
+ get_target_property(Z3_TARGET_INCLUDE_DIRS ${Z3_TARGET} INTERFACE_INCLUDE_DIRECTORIES)
+ if(Z3_TARGET_INCLUDE_DIRS)
+ include_directories(SYSTEM ${Z3_TARGET_INCLUDE_DIRS})
+ endif()
+ list(APPEND TVM_LINKER_LIBS ${Z3_TARGET})
+elseif(Z3_FOUND OR (Z3_INCLUDE_DIR AND Z3_LIBRARY))
+ if(NOT Z3_INCLUDE_DIR AND Z3_CXX_INCLUDE_DIRS)
+ set(Z3_INCLUDE_DIR ${Z3_CXX_INCLUDE_DIRS})
+ endif()
+ if(NOT Z3_LIBRARY AND Z3_LIBRARIES)
+ set(Z3_LIBRARY ${Z3_LIBRARIES})
+ endif()
+ if(NOT Z3_INCLUDE_DIR OR NOT Z3_LIBRARY)
+ message(FATAL_ERROR "USE_Z3 is ON, but Z3 include directory or library was not found.")
+ endif()
+ include_directories(SYSTEM ${Z3_INCLUDE_DIR})
+ list(APPEND TVM_LINKER_LIBS ${Z3_LIBRARY})
+else()
+ if(TVM_Z3_REQUIRED)
+ message(FATAL_ERROR
+ "USE_Z3 is ON, but Z3 was not found. Install the static Z3 development "
+ "package with `pip install 'z3-static>=4.16.0.post1'`, or point "
+ "Z3_DIR/CMAKE_PREFIX_PATH at a Z3 installation.")
+ endif()
+ message(STATUS "Build without Z3 SMT solver support")
+ return()
+endif()
+
+# Enable the real Z3 implementation inside the single src/arith/z3_prover.cc file.
+add_compile_definitions(TVM_USE_Z3)
+message(STATUS "Build with Z3 SMT solver support")
diff --git a/include/tvm/arith/analyzer.h b/include/tvm/arith/analyzer.h
index 924cc29927..e635315e67 100644
--- a/include/tvm/arith/analyzer.h
+++ b/include/tvm/arith/analyzer.h
@@ -27,6 +27,7 @@
#include <tvm/arith/int_set.h>
#include <tvm/ffi/cast.h>
#include <tvm/ffi/reflection/registry.h>
+#include <tvm/ffi/string.h>
#include <tvm/ir/expr.h>
#include <tvm/ir/with_context.h>
@@ -588,6 +589,110 @@ class IntSetAnalyzer {
Impl* impl_;
};
+class Z3Prover {
+ public:
+ /*!
+ * \brief Update binding of var to a new range.
+ *
+ * \param var The variable of interest.
+ * \param new_range The range of allowed values for this var.
+ * \param allow_override whether we allow override of existing information.
+ */
+ TVM_DLL void Bind(const Var& var, const Range& new_range, bool allow_override = false);
+
+ /*!
+ * \brief Update binding of var to a new expression.
+ *
+ * \param var The variable of interest.
+ * \param expr The bound expression.
+ * \param allow_override whether we allow override of existing information.
+ */
+ TVM_DLL void Bind(const Var& var, const PrimExpr& expr, bool allow_override = false);
+
+ /*!
+ * \brief Whether the Z3 backend is compiled into this build (USE_Z3=ON).
+ *
+ * \return true if the real Z3 prover is available, false for the stub.
+ */
+ TVM_DLL bool IsEnabled() const;
+
+ /*!
+ * \brief Whether we can prove that expr is always true.
+ *
+ * \param expr The expression.
+ * \return Whether we can prove it.
+ */
+ TVM_DLL bool CanProve(const PrimExpr& expr);
+
+ /*!
+ * \brief Update the internal state to enter constraint.
+ *
+ * \param constraint A constraint expression.
+ * \return An exit function that must be called to clean up the constraint; it can be nullptr.
+ */
+ std::function<void()> EnterConstraint(const PrimExpr& constraint);
+
+ /*!
+ * \brief Get the SMTLIB2 representation of the current context.
+ *
+ * \param expr The optional expression to check.
+ * \return The SMTLIB2 string.
+ */
+ ffi::String GetSMTLIB2(const ffi::Optional<PrimExpr> expr);
+
+ /*!
+ * \brief Get statistics about Z3 prover.
+ *
+ * \return The statistics string.
+ */
+ ffi::String GetStats();
+
+ /*!
+ * \brief Set timeout in milliseconds for Z3 prover.
+ *
+ * \param timeout_ms The timeout in milliseconds.
+ */
+ void SetTimeoutMs(unsigned timeout_ms);
+
+ /*!
+ * \brief Set resource limitation for Z3 prover.
+ *
+ * \param rlimit the resource limitation.
+ */
+ void SetRLimit(unsigned rlimit);
+
+ /*!
+ * \brief Get the Z3 model for the given expression if satisfiable.
+ *
+ * \param expr The expression to get the model for.
+ * \return The model as a string.
+ */
+ ffi::String GetModel(const PrimExpr& expr);
+
+ /*!
+ * \brief Count the number of integer values that satisfy the current constraints.
+ *
+ * This method uses Z3's model enumeration to count how many distinct values of
+ * the given variable satisfy all current constraints.
+ *
+ * \param var The variable to count satisfying values for.
+ * \param max_count Maximum number of solutions to enumerate.
+ * \param min_consecutive Minimum consecutive count requirement.
+ * \return The number of distinct values that satisfy the constraints, or a negative error code.
+ */
+ TVM_DLL int64_t CountSatisfyingValues(const Var& var, int64_t max_count = 2048,
+ int64_t min_consecutive = 1);
+
+ private:
+ friend class AnalyzerObj;
+ friend class Analyzer;
+ explicit Z3Prover(AnalyzerObj* parent);
+ TVM_DLL ~Z3Prover();
+ void CopyFrom(const Z3Prover& other);
+ class Impl;
+ Impl* impl_;
+};
+
/*!
* \brief Analyzer that contains bunch of sub-analyzers.
*
@@ -612,6 +717,8 @@ class TVM_DLL AnalyzerObj : public ffi::Object {
IntSetAnalyzer int_set;
/*! \brief sub-analyzer transitive comparisons */
TransitiveComparisonAnalyzer transitive_comparisons;
+ /*! \brief sub-analyzer using Z3 */
+ Z3Prover z3_prover;
/*! \brief constructor */
AnalyzerObj();
/*!
@@ -810,7 +917,16 @@ class ConstraintContext {
* \param constraint The constraint to be applied.
*/
ConstraintContext(const Analyzer& analyzer, PrimExpr constraint)
- : analyzer_(analyzer), constraint_(constraint) {}
+ : ConstraintContext(analyzer, std::move(constraint), false) {}
+ /*!
+ * \brief Construct a constraint context.
+ * \param analyzer The analyzer whose context is updated. The context
+ * keeps a reference to the analyzer while the scope is active.
+ * \param constraint The constraint to be applied.
+ * \param is_assume Whether the constraint comes from an assumption.
+ */
+ ConstraintContext(const Analyzer& analyzer, PrimExpr constraint, bool is_assume)
+ : analyzer_(analyzer), constraint_(std::move(constraint)), is_assume_(is_assume) {}
/*!
* \brief Construct a constraint context from a borrowed analyzer object.
* \param analyzer The borrowed analyzer object.
@@ -819,7 +935,15 @@ class ConstraintContext {
* This overload is for internal callers that already operate on AnalyzerObj*.
*/
ConstraintContext(AnalyzerObj* analyzer, PrimExpr constraint)
- : ConstraintContext(ffi::GetRef<Analyzer>(analyzer), std::move(constraint)) {}
+ : ConstraintContext(ffi::GetRef<Analyzer>(analyzer), std::move(constraint), false) {}
+ /*!
+ * \brief Construct a constraint context from a borrowed analyzer object.
+ * \param analyzer The borrowed analyzer object.
+ * \param constraint The constraint to be applied.
+ * \param is_assume Whether the constraint comes from an assumption.
+ */
+ ConstraintContext(AnalyzerObj* analyzer, PrimExpr constraint, bool is_assume)
+ : ConstraintContext(ffi::GetRef<Analyzer>(analyzer), std::move(constraint), is_assume) {}
// enter the scope.
void EnterWithScope();
// exit the scope.
@@ -830,6 +954,8 @@ class ConstraintContext {
PrimExpr constraint_;
/*! \brief functions to be called in recovery */
std::vector<std::function<void()>> recovery_functions_;
+ /*! \brief Whether the constraint comes from an assumption. */
+ bool is_assume_;
};
} // namespace arith
diff --git a/pyproject.toml b/pyproject.toml
index e3f1038f22..2c38e0b21b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -16,7 +16,12 @@
# under the License.
[build-system]
-requires = ["scikit-build-core>=0.11", "setuptools-scm>=8"]
+# z3-static ships the PIC static libz3 + headers consumed by USE_Z3=ON.
+requires = [
+ "scikit-build-core>=0.11",
+ "setuptools-scm>=8",
+ "z3-static>=4.16.0.post1",
+]
build-backend = "scikit_build_core.build"
[project]
@@ -141,6 +146,8 @@ logging.level = "INFO"
[tool.scikit-build.cmake.define]
TVM_BUILD_PYTHON_MODULE = "ON"
USE_CUDA = "OFF"
+# Statically link Z3 from the z3-static build dependency by default.
+USE_Z3 = "ON"
BUILD_TESTING = "OFF"
[tool.setuptools_scm]
diff --git a/python/tvm/arith/analyzer.py b/python/tvm/arith/analyzer.py
index 0aa6a75eba..78e93395c3 100644
--- a/python/tvm/arith/analyzer.py
+++ b/python/tvm/arith/analyzer.py
@@ -128,6 +128,91 @@ class Analyzer(Object):
def __init__(self):
self.__init_handle_by_constructor__(_ffi_api.Analyzer)
+ @property
+ def is_z3_enabled(self) -> bool:
+ """Whether this build includes the Z3 backend (``USE_Z3=ON``).
+
+ The Z3-specific methods (:py:meth:`get_smtlib2`, :py:meth:`get_z3_stats`,
+ :py:meth:`set_z3_timeout_ms`, :py:meth:`set_z3_rlimit`) only work when
+ this is ``True``.
+ """
+ return bool(_ffi_api.AnalyzerIsZ3Enabled(self))
+
+ def _check_z3_enabled(self) -> None:
+ if not self.is_z3_enabled:
+ raise RuntimeError(
+ "The Z3 backend is not available in this build. "
+ "Rebuild TVM with USE_Z3=ON to use Z3-specific Analyzer APIs."
+ )
+
+ def get_smtlib2(self, expr: tirx.PrimExpr | None = None) -> str:
+ """Get the current Z3 problem in SMT-LIB2 format.
+
+ Parameters
+ ----------
+ expr : Optional[PrimExpr]
+ The expression to prove. If provided, its negation is added to the problem.
+
+ Returns
+ -------
+ smtlib2 : str
+ The SMT-LIB2 representation of the current solver context.
+
+ Raises
+ ------
+ RuntimeError
+ If TVM was built without Z3 (``USE_Z3=OFF``), since there is no
+ solver state to export. Use :py:attr:`is_z3_enabled` to check first.
+ """
+ self._check_z3_enabled()
+ return _ffi_api.AnalyzerGetSMTLIB2(self, expr)
+
+ def set_z3_timeout_ms(self, timeout_ms: int) -> None:
+ """Set Z3 timeout in milliseconds.
+
+ Parameters
+ ----------
+ timeout_ms : int
+ The timeout in milliseconds.
+
+ Raises
+ ------
+ RuntimeError
+ If TVM was built without Z3 (``USE_Z3=OFF``).
+ """
+ self._check_z3_enabled()
+ _ffi_api.AnalyzerSetZ3TimeoutMs(self, timeout_ms)
+
+ def set_z3_rlimit(self, rlimit: int) -> None:
+ """Set Z3 resource limit.
+
+ The resource limit gives deterministic solver budgeting (unlike a wall
+ clock timeout). A value of ``0`` disables the limit.
+
+ Parameters
+ ----------
+ rlimit : int
+ The resource limit.
+
+ Raises
+ ------
+ RuntimeError
+ If TVM was built without Z3 (``USE_Z3=OFF``).
+ """
+ self._check_z3_enabled()
+ _ffi_api.AnalyzerSetZ3RLimit(self, rlimit)
+
+ def get_z3_stats(self) -> str:
+ """Get Z3 solver statistics.
+
+ Returns
+ -------
+ stats : str
+ The Z3 statistics.
+
+ Raises
+ ------
+ RuntimeError
+ If TVM was built without Z3 (``USE_Z3=OFF``).
+ """
+ self._check_z3_enabled()
+ return _ffi_api.AnalyzerGetZ3Stats(self)
+
def const_int_bound(self, expr: tirx.PrimExpr) -> ConstIntBound:
"""Find constant integer bound for expr.
@@ -260,7 +345,9 @@ class Analyzer(Object):
The expression.
strength: ProofStrength
- The proof strength
+ The proof strength. When TVM is built with Z3 (``USE_Z3=ON``), the
+ optional Z3 fallback is only consulted at ``SYMBOLIC_BOUND`` or
+ higher, after the native analyzers fail to prove the predicate.
Returns
-------
diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc
index 69dbe97f5e..b66ecb0fd1 100644
--- a/src/arith/analyzer.cc
+++ b/src/arith/analyzer.cc
@@ -39,7 +39,8 @@ AnalyzerObj::AnalyzerObj()
modular_set(this),
rewrite_simplify(this),
canonical_simplify(this),
- int_set(this) {}
+ int_set(this),
+ z3_prover(this) {}
void AnalyzerObj::Bind(const Var& var, const PrimExpr& expr, bool allow_override) {
PrimExpr new_expr = expr;
@@ -52,6 +53,7 @@ void AnalyzerObj::Bind(const Var& var, const PrimExpr& expr, bool allow_override
this->canonical_simplify.Update(var, new_expr, allow_override);
this->int_set.Update(var, this->int_set(new_expr), allow_override);
this->transitive_comparisons.Bind(var, expr, allow_override);
+ this->z3_prover.Bind(var, expr, allow_override);
}
void AnalyzerObj::Bind(const Var& var, const Range& range, bool allow_override) {
@@ -62,6 +64,7 @@ void AnalyzerObj::Bind(const Var& var, const Range& range, bool allow_override)
this->const_int_bound.Bind(var, range, allow_override);
this->int_set.Bind(var, range, allow_override);
this->transitive_comparisons.Bind(var, range, allow_override);
+ this->z3_prover.Bind(var, range, allow_override);
}
// skip modular_set
// skip rewrite simplify
@@ -131,6 +134,7 @@ void ConstraintContext::EnterWithScope() {
recovery_functions_.push_back(analyzer_->rewrite_simplify.EnterConstraint(constraint_));
recovery_functions_.push_back(analyzer_->int_set.EnterConstraint(constraint_));
recovery_functions_.push_back(analyzer_->transitive_comparisons.EnterConstraint(constraint_));
+ recovery_functions_.push_back(analyzer_->z3_prover.EnterConstraint(constraint_));
}
void ConstraintContext::ExitWithScope() {
@@ -231,6 +235,12 @@ bool AnalyzerObj::CanProve(const PrimExpr& expr, ProofStrength strength) {
}
}
+ // Z3 is an expensive best-effort fallback. Gate it behind the higher
+ // kSymbolicBound strength so the common kDefault path (including deeply
+ // recursive internal CanProve calls) never pays the prover cost.
+ if (strength >= ProofStrength::kSymbolicBound && z3_prover.CanProve(simplified)) {
+ return true;
+ }
return false;
}
@@ -334,6 +344,22 @@ TVM_FFI_STATIC_INIT_BLOCK() {
return static_cast<int64_t>(
analyzer->transitive_comparisons.TryCompare(lhs, rhs, propagate_inequalities));
})
+ .def("arith.AnalyzerIsZ3Enabled",
+ [](Analyzer analyzer) { return analyzer->z3_prover.IsEnabled(); })
+ .def("arith.AnalyzerGetSMTLIB2",
+ [](Analyzer analyzer, ffi::Optional<PrimExpr> expr) {
+ return analyzer->z3_prover.GetSMTLIB2(expr);
+ })
+ .def("arith.AnalyzerSetZ3TimeoutMs",
+ [](Analyzer analyzer, int64_t timeout_ms) {
+ analyzer->z3_prover.SetTimeoutMs(static_cast<unsigned>(timeout_ms));
+ })
+ .def("arith.AnalyzerSetZ3RLimit",
+ [](Analyzer analyzer, int64_t rlimit) {
+ analyzer->z3_prover.SetRLimit(static_cast<unsigned>(rlimit));
+ })
+ .def("arith.AnalyzerGetZ3Stats",
+ [](Analyzer analyzer) { return analyzer->z3_prover.GetStats(); })
.def("arith.AnalyzerGetEnabledExtensions",
[](Analyzer analyzer) {
return static_cast<std::int64_t>(analyzer->rewrite_simplify.GetEnabledExtensions());
diff --git a/src/arith/z3_prover.cc b/src/arith/z3_prover.cc
new file mode 100644
index 0000000000..aab4b485dd
--- /dev/null
+++ b/src/arith/z3_prover.cc
@@ -0,0 +1,864 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/arith/z3_prover.cc
+ * \brief Optional Z3 SMT solver backend for arith::Analyzer.
+ *
+ * The real implementation is compiled only when TVM_USE_Z3 is defined (set by
+ * the USE_Z3 CMake option). Otherwise a conservative stub is compiled so the
+ * C++ and Python APIs stay available without a Z3 dependency.
+ */
+#ifdef TVM_USE_Z3
+
+#include <tvm/arith/analyzer.h>
+#include <tvm/ffi/extra/structural_equal.h>
+#include <tvm/ffi/extra/structural_hash.h>
+#include <tvm/runtime/logging.h>
+#include <tvm/tirx/analysis.h>
+#include <tvm/tirx/builtin.h>
+#include <tvm/tirx/expr.h>
+#include <tvm/tirx/expr_functor.h>
+#include <tvm/tirx/op.h>
+#include <tvm/tirx/op_attr_types.h>
+
+#include <algorithm>
+#include <climits>
+#include <map>
+#include <memory>
+#include <optional>
+#include <sstream>
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+#include "tvm/ffi/cast.h"
+#include "tvm/ffi/object.h"
+#include "tvm/ffi/string.h"
+#include "tvm/ir/expr.h"
+#include "tvm/runtime/data_type.h"
+#include "z3++.h"
+
+namespace tvm::arith {
+
+using namespace tirx;
+using namespace ffi;
+
+namespace {
+
+struct Namespace {
+ std::unordered_set<std::string> used_names;
+ /// @brief Get a new name that has not been used before.
+ /// This function is used to generate z3 variable names
+ ///
+ /// Z3 may deduplicate variables with the same name, which
+ /// causes issues when different TVM variables are mapped to
+ /// the same z3 variable.
+ ///
+ /// This function generates unique names by appending
+ /// suffixes to the original expression string representation.
+ ///
+ /// e.g. "x", "x$1", "x$2", ...
+ std::string GetNewName(const PrimExpr& expr) {
+ std::stringstream ss;
+ ss << expr;
+ auto name = ss.str();
+ if (used_names.count(name) == 0) {
+ used_names.insert(name);
+ return name;
+ }
+ int idx = 1;
+ std::string check_name = name + "$" + std::to_string(idx);
+ while (used_names.count(check_name)) {
+ idx++;
+ check_name = name + "$" + std::to_string(idx);
+ }
+ used_names.insert(check_name);
+ return check_name;
+ }
+};
+
+} // namespace
+
+class Z3Prover::Impl : ExprFunctor<z3::expr(const PrimExpr&)> {
+ public:
+ using Base = ExprFunctor<z3::expr(const PrimExpr&)>;
+ using Self = Z3Prover::Impl;
+
+ AnalyzerObj* analyzer;
+ // Keep a reference to the thread-local context for the whole lifetime of this
+ // prover. Schedules created on worker threads may be destroyed after the
+ // worker exits, so storing only a raw reference in z3::solver is not enough.
+ static std::shared_ptr<z3::context> GetThreadLocalContext() {
+ static thread_local std::shared_ptr<z3::context> local_ctx = std::make_shared<z3::context>();
+ return local_ctx;
+ }
+ std::shared_ptr<z3::context> ctx{GetThreadLocalContext()};
+
+ /// @brief Z3 solver instance
+ z3::solver solver{*ctx};
+
+ /// @brief Memoize pure expressions
+ std::unordered_map<PrimExpr, z3::expr, StructuralHash, ExprDeepEqual> memo_;
+
+ /// @brief Namespace for variable naming
+ Namespace ns;
+
+ /// @brief Timeout in milliseconds
+ unsigned timeout_ms{UINT_MAX};
+
+ /// @brief Max steps
+ unsigned rlimit{UINT_MAX};
+
+ /// @brief Create a z3 solver with custom options
+ static z3::solver CreateSolver(z3::context& ctx) {
+ z3::solver solver(ctx);
+ // here we disable model generation to speed up the solving process
+ solver.set("model", false);
+ // ensure deterministic behavior
+ solver.set("random_seed", (unsigned)42);
+ return solver;
+ }
+
+ Impl(AnalyzerObj* parent) : analyzer(parent) {
+ scope_stack_.push_back({});
+ solver = CreateSolver(*ctx);
+ // use rlimit, not timeout to ensure deterministic behavior
+ SetRLimit(10000U);
+ }
+
+ /// @brief Create a free z3 expression from a PrimExprNode
+ z3::expr Create(const PrimExprNode* op) {
+ auto ref = ffi::GetRef<PrimExpr>(op);
+ auto dtype = op->dtype;
+ std::string name = ns.GetNewName(ref);
+ /// TVM max_val can't handle uint64 max correctly, so we special case it here
+ if (dtype.is_bool()) {
+ return ctx->bool_const(name.c_str());
+ } else {
+ z3::expr e = ctx->int_const(name.c_str());
+ if (dtype.is_uint() && dtype.bits() == 64) {
+ solver.add(ctx->int_val(0) <= e && e <= ctx->int_val((uint64_t)UINT64_MAX));
+ } else {
+ auto min_val = Downcast<IntImm>(min_value(dtype))->value;
+ auto max_val = Downcast<IntImm>(max_value(dtype))->value;
+ solver.add(ctx->int_val(min_val) <= e && e <= ctx->int_val(max_val));
+ }
+ return e;
+ }
+ }
+
+ struct Scope {
+ enum Kind {
+ BindValue,
+ BindRange,
+ Constraint,
+ } kind;
+ Var var;
+ PrimExpr value;
+ PrimExpr min;
+ PrimExpr extent;
+ PrimExpr constraint;
+ };
+
+ /// @brief scope_stack memorizes existing constraint and bindings
+ /// to generate SMTLIB2 representation with comments
+ std::vector<std::vector<Scope>> scope_stack_;
+
+ /// @brief Enter a constraint scope
+ std::function<void()> EnterConstraint(const PrimExpr& constraint) {
+ scope_stack_.push_back({});
+ scope_stack_.back().push_back(
+ Scope{Scope::Constraint, Var(), PrimExpr(), PrimExpr(), PrimExpr(), constraint});
+ solver.push();
+ solver.add(VisitBool(constraint));
+ auto side_effect_exprs = std::move(side_effect_exprs_);
+ side_effect_exprs_.clear();
+ for (const auto& expr : side_effect_exprs) {
+ memo_.erase(expr);
+ }
+ return [this]() {
+ solver.pop();
+ scope_stack_.pop_back();
+ };
+ }
+
+ /// @brief Check trivial bad cases; return true if the expr is a bad case.
+ /// Z3 may take a long time to initialize (at least 200us), so this
+ /// shortcut speeds up 30% of the test cases in our unit tests.
+ bool CheckTrivilBadCases(const PrimExpr& expr) {
+ if (IsFreeNode(expr)) {
+ return true;
+ }
+ auto checkTrivilCmp = [this](const PrimExpr& lhs, const PrimExpr& rhs) {
+ if (IsFreeNode(lhs) && rhs->IsInstance<IntImmNode>()) {
+ return true;
+ }
+ if (IsFreeNode(rhs) && lhs->IsInstance<IntImmNode>()) {
+ return true;
+ }
+ if (IsFreeNode(lhs) && IsFreeNode(rhs)) {
+ return true;
+ }
+ // cast('xxx', free_var) == constant
+ if (auto cast = lhs.as<CastNode>()) {
+ if (IsFreeNode(cast->value) && rhs->IsInstance<IntImmNode>()) {
+ return true;
+ }
+ }
+ // constant == cast('xxx', free_var)
+ if (auto cast = rhs.as<CastNode>()) {
+ if (IsFreeNode(cast->value) && lhs->IsInstance<IntImmNode>()) {
+ return true;
+ }
+ }
+ return false;
+ };
+ if (auto eq = expr.as<EQNode>()) {
+ auto lhs = eq->a;
+ auto rhs = eq->b;
+ return checkTrivilCmp(lhs, rhs);
+ } else if (auto ne = expr.as<NENode>()) {
+ auto lhs = ne->a;
+ auto rhs = ne->b;
+ return checkTrivilCmp(lhs, rhs);
+ }
+ return false;
+ }
+
+ /// @brief Check if the expression can be proved
+ bool CanProve(const PrimExpr& expr) {
+ // Z3 is only a fallback. Any failure (including z3::exception thrown by the
+ // solver) must degrade to "cannot prove" instead of escaping to the caller.
+ try {
+ if (CheckTrivilBadCases(expr)) return false;
+ if (!IsValidDType(expr->dtype)) return false;
+ z3::expr_vector constr(*ctx);
+ constr.push_back(!ConvertBool(expr));
+ auto result = solver.check(constr);
+ constr.pop_back();
+ return result == z3::unsat;
+ } catch (const z3::exception&) {
+ return false;
+ }
+ }
+
+ /// @brief Bind a variable to a value
+ void Bind(const Var& var, const PrimExpr& value, bool allow_override = false) {
+ if (!IsValidDType(var->dtype)) return;
+ scope_stack_.back().push_back(Scope{Scope::BindValue, var, value});
+ // we add the binding whenever the value is pure,
+ // because non-pure parts are handled by creating free variables in VisitExpr
+ memo_.emplace(var, ConvertInt(value));
+ }
+
+ /// @brief Bind a variable to a range
+ void Bind(const Var& var, const Range& range, bool allow_override = false) {
+ if (!IsValidDType(var->dtype)) return;
+ scope_stack_.back().push_back(
+ Scope{Scope::BindRange, var, PrimExpr(), range->min, range->extent});
+ // 1. Create a placeholder for the var, and save it in the memo
+ // if the var is overridden later, we can just update the memo, and the old placeholder will
+ // be ignored
+ auto var_expr = Create(var.as<PrimExprNode>());
+ memo_.emplace(var, var_expr);
+
+ // 2. Add constraint on the placeholder
+ // When min_expr >= max_expr, the range is empty, which is undefined behavior.
+ // Instead of adding an unsat constraint, we skip the range constraint and leave
+ // the var free.
+ //
+ // NOTE: range->min + range->extent builds a fresh AddNode that is not folded, so we must
+ // test is_const_int on range->min and range->extent individually and add the two constants
+ // in C++. Otherwise this fast path is never taken and we always emit the more expensive
+ // symbolic constraint below.
+ if (tirx::is_const_int(range->min) && tirx::is_const_int(range->extent)) {
+ int64_t min_value = *tirx::as_const_int(range->min);
+ int64_t extent_value = *tirx::as_const_int(range->extent);
+ int64_t max_value = min_value + extent_value;
+ if (min_value < max_value) {
+ solver.add(ctx->int_val(min_value) <= var_expr);
+ solver.add(var_expr < ctx->int_val(max_value));
+ }
+ } else {
+ solver.add(ConvertBool(range->extent <= 0 ||
+ (range->min <= var && var < range->min + range->extent)));
+ }
+ }
+
+ void CopyFrom(const Self& other_) {
+ // 1. create a new solver
+ // because this->solver depends on this->ctx, we must destroy the old solver
+ // and create a new one that depends on this->ctx
+ solver = CreateSolver(*ctx);
+ // 2. ctx is owned by this Impl and pins the underlying thread-local context for the lifetime
+ // of solver and memoized expressions.
+ // 3. copy other objects
+ ns = other_.ns;
+ for (auto& item : other_.memo_) {
+ memo_.emplace(item.first, item.second);
+ }
+ for (auto a : other_.solver.assertions()) {
+ solver.add(a);
+ }
+ // 4. copy timeout options
+ // but other solver options are not copied
+ SetTimeoutMs(other_.timeout_ms);
+ SetRLimit(other_.rlimit);
+ // 5. copy the scope stack, which contains comments for SMTLIB2 generation
+ scope_stack_ = other_.scope_stack_;
+ }
+
+ /// @brief Set timeout in milliseconds
+ void SetTimeoutMs(unsigned timeout_ms) {
+ this->timeout_ms = timeout_ms;
+ solver.set("timeout", timeout_ms);
+ }
+
+ /// @brief Set max steps
+ void SetRLimit(unsigned rlimit) {
+ this->rlimit = rlimit;
+ solver.set("rlimit", rlimit);
+ }
+
+ /// @brief Get the SMTLIB2 representation of the current solver state
+ ffi::String GetSMTLIB2() {
+ std::stringstream ss;
+ ss << "(set-option :timeout " << timeout_ms << ")\n";
+ AddScopeDebugMsg(ss);
+ ss << solver.to_smt2();
+ return ss.str();
+ }
+
+ void AddScopeDebugMsg(std::ostream& ss) {
+ for (const auto& scope : scope_stack_) {
+ ss << "; Entering Scope\n";
+ for (const auto& s : scope) {
+ switch (s.kind) {
+ case Scope::Constraint:
+ ss << "; constraint: " << s.constraint << "\n";
+ break;
+ case Scope::BindValue:
+ ss << "; bind value: " << s.var << " = " << s.value << "\n";
+ break;
+ case Scope::BindRange:
+ ss << "; bind range: " << s.var << " in [" << s.min << ", " << s.min + s.extent
+ << ")\n";
+ break;
+ }
+ }
+ }
+ }
+
+ /// @brief Get the SMTLIB2 representation of the current solver state plus the query for
+ /// proving `expr` (the negation of `expr` is asserted, so unsat means `expr` holds)
+ ffi::String GetSMTLIB2(const PrimExpr& expr) {
+ std::stringstream ss;
+ ss << "(set-option :timeout " << timeout_ms << ")\n";
+ AddScopeDebugMsg(ss);
+ ss << "; Trying to prove: " << expr << "\n";
+ solver.push();
+ solver.add(!ConvertBool(expr));
+ ss << solver.to_smt2();
+ solver.pop();
+ return ss.str();
+ }
+
+ /// @brief Get the statistics of the solver
+ ffi::String GetStats() {
+ std::stringstream ss;
+ ss << solver.statistics();
+ return ss.str();
+ }
+
+ ffi::String GetModel(const PrimExpr& expr) {
+ solver.set("model", true);
+ solver.push();
+ solver.add(!ConvertBool(expr));
+ auto result = solver.check();
+ ffi::String model_str;
+ if (result == z3::sat) {
+ z3::model m = solver.get_model();
+ std::map<std::string, z3::expr> model_map;
+ for (unsigned i = 0; i < m.size(); i++) {
+ z3::func_decl d = m[i];
+ model_map.emplace(d.name().str(), m.get_const_interp(d));
+ }
+ std::stringstream ss;
+ for (const auto& [k, v] : model_map) {
+ ss << " " << k << " = " << v << "\n";
+ }
+ model_str = ss.str();
+ }
+ solver.pop();
+ solver.set("model", false);
+ return model_str;
+ }
+
+ /*!
+ * \brief Count the number of distinct integer values of `var` satisfying the current constraints.
+ *
+ * Uses Z3 model enumeration (the AllSAT pattern) to count solutions:
+ * 1. Find a satisfying assignment
+ * 2. Add a blocking clause to exclude it
+ * 3. Repeat until UNSAT
+ *
+ * \param var The variable to count values for
+ * \param max_count Safety limit on the number of enumerated values
+ * \param min_consecutive Minimum length of every run of consecutive values (0 to disable)
+ * \return Number of satisfying values, -1 on error (e.g. unsupported dtype), -2 if the min_consecutive constraint is not met
+ */
+ int64_t CountSatisfyingValues(const Var& var, int64_t max_count, int64_t min_consecutive = 1) {
+ if (!IsValidDType(var->dtype)) {
+ return -1;
+ }
+
+ solver.set("model", true);
+ solver.push();
+
+ // Convert the TVM variable to Z3 expression
+ z3::expr z3_var = VisitInt(var);
+
+ int64_t count = 0;
+ std::vector<int64_t> found_values;
+
+ while (count < max_count) {
+ auto result = solver.check();
+ if (result != z3::sat) {
+ break; // No more solutions
+ }
+
+ z3::model m = solver.get_model();
+ z3::expr val_expr = m.eval(z3_var, true);
+
+ // Extract the integer value from Z3 expression
+ int64_t val;
+ if (val_expr.is_numeral()) {
+ val = val_expr.get_numeral_int64();
+ } else {
+ // If we can't get a concrete value, stop enumeration
+ break;
+ }
+
+ found_values.push_back(val);
+ count++;
+
+ // Add blocking clause: var != val (exclude this solution)
+ solver.add(z3_var != ctx->int_val(val));
+ }
+
+ solver.pop();
+ solver.set("model", false);
+
+ // Clear any side effects from visiting the variable
+ for (const auto& expr : side_effect_exprs_) {
+ memo_.erase(expr);
+ }
+ side_effect_exprs_.clear();
+
+ // Check minimum consecutive constraint if enabled
+ if (min_consecutive > 0 && count > 0) {
+ // Sort the values to check consecutive groups
+ std::sort(found_values.begin(), found_values.end());
+
+ // Check that all values form groups of at least min_consecutive consecutive numbers
+ int64_t consecutive_count = 1;
+ for (size_t i = 1; i < found_values.size(); i++) {
+ if (found_values[i] == found_values[i - 1] + 1) {
+ // Consecutive value
+ consecutive_count++;
+ } else {
+ // Gap found, check if the previous group meets the minimum
+ if (consecutive_count < min_consecutive) {
+ return -2; // Previous group too small
+ }
+ consecutive_count = 1; // Start new group
+ }
+ }
+ // Check the last group
+ if (consecutive_count < min_consecutive) {
+ return -2; // Last group too small
+ }
+ }
+
+ return count;
+ }
+
+ private:
+ using Z3BinOp = z3::expr (*)(const z3::expr&, const z3::expr&);
+
+ std::vector<PrimExpr> side_effect_exprs_;
+
+ z3::expr ConvertBool(const PrimExpr& e) {
+ auto res = VisitBool(e);
+ for (auto& expr : side_effect_exprs_) {
+ memo_.erase(expr);
+ }
+ side_effect_exprs_.clear();
+ return res;
+ }
+
+ z3::expr ConvertInt(const PrimExpr& e) {
+ auto res = VisitInt(e);
+ for (auto& expr : side_effect_exprs_) {
+ memo_.erase(expr);
+ }
+ side_effect_exprs_.clear();
+ return res;
+ }
+
+ /// @brief Visit expression with memoization
+ z3::expr VisitExpr(const PrimExpr& e) override {
+ if (memo_.count(e)) {
+ return memo_.at(e);
+ }
+ auto res = Base::VisitExpr(e);
+ auto side_effect = SideEffect(e);
+ if (side_effect <= CallEffectKind::kPure) {
+ memo_.emplace(e, res);
+ } else if (side_effect <= CallEffectKind::kReadState) {
+ memo_.emplace(e, res);
+ side_effect_exprs_.emplace_back(e);
+ } else {
+ side_effect_exprs_.emplace_back(e);
+ }
+ return res;
+ }
+
+ /// @brief Check if the expression is a free node having no constraints
+ bool IsFreeNode(const PrimExpr& e) {
+ if (memo_.count(e)) {
+ return false;
+ }
+ return e->IsInstance<CallNode>() || e->IsInstance<BufferLoadNode>() ||
+ e->IsInstance<ProducerLoadNode>() || e->IsInstance<ReduceNode>() ||
+ (e->IsInstance<CastNode>() && !IsValidDType(Downcast<Cast>(e)->value->dtype));
+ }
+
+ /// @brief Check if the dtype is valid for z3 integer operations
+ static bool IsValidDType(const DataType& dtype) {
+ return (dtype.is_int() || dtype.is_uint() || dtype.is_bool()) && dtype.lanes() == 1;
+ }
+
+ /// @brief Visit the expression and convert it into z3 integer expression
+ z3::expr VisitInt(const PrimExpr& expr) {
+ auto e = VisitExpr(expr);
+ if (e.is_bool()) {
+ return z3::ite(e, ctx->int_val(1), ctx->int_val(0));
+ } else {
+ return e;
+ }
+ }
+
+ /// @brief Visit the expression and convert it into z3 boolean expression
+ z3::expr VisitBool(const PrimExpr& e) {
+ auto expr = VisitExpr(e);
+ if (expr.is_bool()) {
+ return expr;
+ } else {
+ return expr != ctx->int_val(0);
+ }
+ }
+
+ /// @brief Helper function to visit binary arithmetic operations
+ z3::expr VisitArith(Z3BinOp signed_op, const PrimExprNode* op, const PrimExpr& a,
+ const PrimExpr& b) {
+ if (IsValidDType(a->dtype) && IsValidDType(b->dtype)) {
+ return signed_op(VisitInt(a), VisitInt(b));
+ } else {
+ return Create(op);
+ }
+ }
+
+ z3::expr VisitExpr_(const LetNode* op) override {
+ if (IsValidDType(op->var->dtype)) {
+ memo_.emplace(op->var, VisitInt(op->value));
+ }
+ return VisitExpr(op->body);
+ }
+ z3::expr VisitExpr_(const CastNode* op) override {
+ // if both the source and target dtypes are valid, we just visit the inner value
+ if (IsValidDType(op->value->dtype) && IsValidDType(op->dtype)) {
+ return VisitInt(op->value);
+ } else {
+ // otherwise, we create a new free z3 variable
+ return Create(op);
+ }
+ }
+ z3::expr VisitExpr_(const VarNode* op) override { return Create(op); }
+ z3::expr VisitExpr_(const BufferLoadNode* op) override { return Create(op); }
+ z3::expr VisitExpr_(const ProducerLoadNode* op) override { return Create(op); }
+ z3::expr VisitExpr_(const ReduceNode* op) override { return Create(op); }
+ z3::expr VisitExpr_(const MinNode* op) override {
+ auto a = VisitInt(op->a);
+ auto b = VisitInt(op->b);
+ return z3::ite(a < b, a, b);
+ }
+ z3::expr VisitExpr_(const MaxNode* op) override {
+ auto a = VisitInt(op->a);
+ auto b = VisitInt(op->b);
+ return z3::ite(a > b, a, b);
+ }
+ // TVM Div/Mod are truncated (round toward zero), while Z3's native operator/
+ // and operator% are Euclidean. Using the raw operators is unsound once the
+ // dividend can be negative, so we implement truncating helpers explicitly.
+ static z3::expr truncdiv(const z3::expr& a, const z3::expr& b) {
+ z3::expr abs_a = z3::ite(a >= 0, a, -a);
+ z3::expr abs_b = z3::ite(b >= 0, b, -b);
+ // |a| / |b| is exact (Euclidean == truncated for non-negative operands).
+ z3::expr q = abs_a / abs_b;
+ return z3::ite((a >= 0) == (b >= 0), q, -q);
+ }
+ static z3::expr truncmod(const z3::expr& a, const z3::expr& b) {
+ // TVM Mod follows the sign of the dividend: a - b * truncdiv(a, b).
+ return a - b * truncdiv(a, b);
+ }
+ static z3::expr floordiv(const z3::expr& a, const z3::expr& b) {
+ return z3::ite(b > 0, a / b, -((-a) / b));
+ }
+ static z3::expr floormod(const z3::expr& a, const z3::expr& b) {
+ return z3::ite(b > 0, a % b, -((-a) % b));
+ }
+ z3::expr VisitExpr_(const AddNode* op) override {
+ return VisitArith(z3::operator+, op, op->a, op->b);
+ }
+ z3::expr VisitExpr_(const SubNode* op) override {
+ return VisitArith(z3::operator-, op, op->a, op->b);
+ }
+ z3::expr VisitExpr_(const MulNode* op) override {
+ return VisitArith(z3::operator*, op, op->a, op->b);
+ }
+ z3::expr VisitExpr_(const DivNode* op) override { return VisitArith(truncdiv, op, op->a, op->b); }
+ z3::expr VisitExpr_(const ModNode* op) override { return VisitArith(truncmod, op, op->a, op->b); }
+ z3::expr VisitExpr_(const FloorDivNode* op) override {
+ return VisitArith(floordiv, op, op->a, op->b);
+ }
+ z3::expr VisitExpr_(const FloorModNode* op) override {
+ return VisitArith(floormod, op, op->a, op->b);
+ }
+ z3::expr VisitExpr_(const EQNode* op) override {
+ return VisitArith(z3::operator==, op, op->a, op->b);
+ }
+ z3::expr VisitExpr_(const NENode* op) override {
+ return VisitArith(z3::operator!=, op, op->a, op->b);
+ }
+ z3::expr VisitExpr_(const LTNode* op) override {
+ return VisitArith(z3::operator<, op, op->a, op->b);
+ }
+ z3::expr VisitExpr_(const LENode* op) override {
+ return VisitArith(z3::operator<=, op, op->a, op->b);
+ }
+ z3::expr VisitExpr_(const GTNode* op) override {
+ return VisitArith(z3::operator>, op, op->a, op->b);
+ }
+ z3::expr VisitExpr_(const GENode* op) override {
+ return VisitArith(z3::operator>=, op, op->a, op->b);
+ }
+ z3::expr VisitExpr_(const AndNode* op) override { return VisitBool(op->a) && VisitBool(op->b); }
+ z3::expr VisitExpr_(const OrNode* op) override { return VisitBool(op->a) || VisitBool(op->b); }
+ z3::expr VisitExpr_(const NotNode* op) override { return !VisitBool(op->a); }
+ z3::expr VisitExpr_(const SelectNode* op) override {
+ return z3::ite(VisitBool(op->condition), VisitInt(op->true_value), VisitInt(op->false_value));
+ }
+ z3::expr VisitExpr_(const IntImmNode* op) override { return ctx->int_val(op->value); }
+
+ // Call nodes: bitwise, shift, and if_then_else builtins; other calls become free variables
+ z3::expr VisitExpr_(const CallNode* op) override {
+ // Check if this is a bitwise operation
+ if (op->op.same_as(tirx::builtin::bitwise_and())) {
+ return VisitBitwiseOp(z3::operator&, op);
+ } else if (op->op.same_as(tirx::builtin::bitwise_or())) {
+ return VisitBitwiseOp(z3::operator|, op);
+ } else if (op->op.same_as(tirx::builtin::bitwise_xor())) {
+ return VisitBitwiseOp(z3::operator^, op);
+ } else if (op->op.same_as(tirx::builtin::bitwise_not())) {
+ return VisitBitwiseNotOp(op);
+ } else if (op->op.same_as(tirx::builtin::shift_left())) {
+ return VisitShiftOp(z3::shl, op);
+ } else if (op->op.same_as(tirx::builtin::shift_right())) {
+ return VisitShiftOp(z3::ashr, op);
+ } else if (op->op.same_as(tirx::builtin::if_then_else()) && op->args.size() == 3 &&
+ IsValidDType(op->args[1]->dtype) && IsValidDType(op->args[2]->dtype)) {
+ // tir.if_then_else(cond, a, b) is a select-like ternary.
+ return z3::ite(VisitBool(op->args[0]), VisitInt(op->args[1]), VisitInt(op->args[2]));
+ } else {
+ // For other call nodes, create a free variable
+ return Create(op);
+ }
+ }
+
+ /// @brief Helper function to visit binary bitwise operations
+ z3::expr VisitBitwiseOp(z3::expr (*op_func)(const z3::expr&, const z3::expr&),
+ const CallNode* op) {
+ if (op->args.size() != 2) {
+ LOG(FATAL) << "Binary bitwise operation expects 2 arguments, got " << op->args.size();
+ TVM_FFI_UNREACHABLE();
+ }
+
+ const PrimExpr& a = op->args[0];
+ const PrimExpr& b = op->args[1];
+ unsigned bit_width = std::max(op->args[0].dtype().bits(), op->args[1].dtype().bits());
+
+ if (IsValidDType(a->dtype) && IsValidDType(b->dtype)) {
+ return z3::bv2int(
+ op_func(z3::int2bv(bit_width, VisitInt(a)), z3::int2bv(bit_width, VisitInt(b))), true);
+ } else {
+ return Create(op);
+ }
+ }
+
+ /// @brief Helper function to visit unary bitwise not operation
+ z3::expr VisitBitwiseNotOp(const CallNode* op) {
+ if (op->args.size() != 1) {
+ LOG(FATAL) << "Bitwise not operation expects 1 argument, got " << op->args.size();
+ TVM_FFI_UNREACHABLE();
+ }
+
+ const PrimExpr& a = op->args[0];
+
+ if (IsValidDType(a->dtype)) {
+ // Cast integer to bit-vector, apply bitwise not, then cast back.
+ unsigned bit_width = a.dtype().bits();
+ z3::expr a_int = VisitInt(a);
+ z3::expr a_bv = z3::int2bv(bit_width, a_int);
+ return z3::bv2int(~a_bv, true);
+ } else {
+ return Create(op);
+ }
+ }
+
+ /// @brief Helper function to visit shift operations
+ z3::expr VisitShiftOp(z3::expr (*op_func)(const z3::expr&, const z3::expr&), const CallNode* op) {
+ if (op->args.size() != 2) {
+ LOG(FATAL) << "Shift operation expects 2 arguments, got " << op->args.size();
+ TVM_FFI_UNREACHABLE();
+ }
+
+ const PrimExpr& a = op->args[0];
+ const PrimExpr& b = op->args[1];
+
+ // Shift operations require integer types for both operands
+ if (IsValidDType(a->dtype) && IsValidDType(b->dtype)) {
+ z3::expr a_expr = VisitInt(a);
+ z3::expr b_expr = VisitInt(b);
+
+ // Rely on Z3's native bit-vector shift behavior. We must NOT add hard
+ // assertions such as `b_expr >= 0` to the solver here: solver.add() has no
+ // matching push/pop in this path, so the assertion would permanently
+ // poison the shared solver and make all subsequent unrelated proofs about
+ // `b` unsound.
+ unsigned bit_width = std::max(a.dtype().bits(), b.dtype().bits());
+ z3::expr a_bv = z3::int2bv(bit_width, a_expr);
+ z3::expr b_bv = z3::int2bv(bit_width, b_expr);
+
+ // Perform the shift in bit-vector domain, then cast back to int.
+ z3::expr result_bv = op_func(a_bv, b_bv);
+ return z3::bv2int(result_bv, true);
+ } else {
+ return Create(op);
+ }
+ }
+
+ z3::expr VisitExprDefault_(const Object* op) override {
+ // Z3 is a best-effort fallback that runs only after the native analyzers
+ // have already failed. An unsupported node must not crash the build, so we
+ // model it as a fresh unconstrained free variable, which keeps the proof
+ // sound (it can only make CanProve more conservative).
+ return Create(static_cast<const PrimExprNode*>(op));
+ }
+};
+
+TVM_DLL bool Z3Prover::IsEnabled() const { return true; }
+TVM_DLL bool Z3Prover::CanProve(const PrimExpr& expr) { return impl_->CanProve(expr); }
+TVM_DLL void Z3Prover::Bind(const Var& var, const Range& new_range, bool allow_override) {
+ return impl_->Bind(var, new_range, allow_override);
+}
+TVM_DLL void Z3Prover::Bind(const Var& var, const PrimExpr& expr, bool allow_override) {
+ return impl_->Bind(var, expr, allow_override);
+}
+std::function<void()> Z3Prover::EnterConstraint(const PrimExpr& constraint) {
+ return impl_->EnterConstraint(constraint);
+}
+ffi::String Z3Prover::GetSMTLIB2(const ffi::Optional<PrimExpr> expr) {
+ if (expr.has_value()) {
+ return impl_->GetSMTLIB2(expr.value());
+ } else {
+ return impl_->GetSMTLIB2();
+ }
+}
+void Z3Prover::SetTimeoutMs(unsigned timeout_ms) { impl_->SetTimeoutMs(timeout_ms); }
+void Z3Prover::SetRLimit(unsigned max_step) { impl_->SetRLimit(max_step); }
+void Z3Prover::CopyFrom(const Z3Prover& other) { impl_->CopyFrom(*other.impl_); }
+ffi::String Z3Prover::GetStats() { return impl_->GetStats(); }
+ffi::String Z3Prover::GetModel(const PrimExpr& expr) { return impl_->GetModel(expr); }
+TVM_DLL int64_t Z3Prover::CountSatisfyingValues(const Var& var, int64_t max_count,
+ int64_t min_consecutive) {
+ return impl_->CountSatisfyingValues(var, max_count, min_consecutive);
+}
+Z3Prover::Z3Prover(AnalyzerObj* parent) : impl_(new Impl{parent}) {}
+TVM_DLL Z3Prover::~Z3Prover() { delete impl_; }
+
+} // namespace tvm::arith
+
+#else // TVM_USE_Z3
+
+#include <tvm/arith/analyzer.h>
+#include <tvm/tirx/expr.h>
+#include <tvm/tirx/op.h>
+
+#include "tvm/ffi/string.h"
+#include "tvm/ir/expr.h"
+
+namespace tvm::arith {
+
+using namespace tirx;
+using namespace ffi;
+
+// Stub implementation used when Z3 support is not built. All proving queries
+// conservatively report "cannot prove" while keeping the public API available.
+class Z3Prover::Impl {};
+
+TVM_DLL bool Z3Prover::IsEnabled() const { return false; }
+TVM_DLL bool Z3Prover::CanProve(const PrimExpr& expr) { return false; }
+TVM_DLL void Z3Prover::Bind(const Var& var, const Range& new_range, bool allow_override) {}
+TVM_DLL void Z3Prover::Bind(const Var& var, const PrimExpr& expr, bool allow_override) {}
+std::function<void()> Z3Prover::EnterConstraint(const PrimExpr& constraint) {
+ return []() {};
+}
+ffi::String Z3Prover::GetSMTLIB2(const ffi::Optional<PrimExpr> expr) {
+ return "; Z3 Prover is disabled.";
+}
+void Z3Prover::SetTimeoutMs(unsigned timeout_ms) {}
+void Z3Prover::SetRLimit(unsigned rlimit) {}
+ffi::String Z3Prover::GetModel(const PrimExpr& expr) { return "; Z3 Prover is disabled."; }
+TVM_DLL int64_t Z3Prover::CountSatisfyingValues(const Var& var, int64_t max_count,
+ int64_t min_consecutive) {
+ return -1; // Z3 disabled, return error
+}
+
+void Z3Prover::CopyFrom(const Z3Prover& other) {}
+ffi::String Z3Prover::GetStats() { return "; Z3 Prover is disabled."; }
+Z3Prover::Z3Prover(AnalyzerObj*) : impl_(nullptr) {}
+TVM_DLL Z3Prover::~Z3Prover() {}
+
+} // namespace tvm::arith
+
+#endif // TVM_USE_Z3
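The truncated and floor division helpers in the prover (`truncdiv`, `truncmod`, `floordiv`, `floormod`) rebuild TVM's rounding rules on top of Z3's Euclidean integer `/` and `%`. The sketch below is a pure-Python model, not Z3 code, that can be used to sanity-check those identities. It assumes Z3's integer division follows the Euclidean convention stated in the C++ comment (`a == b * q + r` with `0 <= r < |b|`), models it with hypothetical helpers `euclid_div`/`euclid_mod`, mirrors each C++ helper, and compares the results against Python's built-in floor division and a truncating reference for every nonzero divisor in a small range.

```python
import math


# Hypothetical stand-ins for Z3's integer `div`/`mod`, assumed to follow
# Euclidean semantics: a == b * q + r with 0 <= r < |b|.
def euclid_mod(a: int, b: int) -> int:
    return a % abs(b)  # Python's % with a positive modulus lies in [0, |b|)


def euclid_div(a: int, b: int) -> int:
    return (a - euclid_mod(a, b)) // b  # exact, since b divides a - r


# Python mirrors of the C++ helpers, written over the Euclidean primitives.
def truncdiv(a: int, b: int) -> int:
    q = euclid_div(abs(a), abs(b))  # Euclidean == truncated for non-negative operands
    return q if (a >= 0) == (b >= 0) else -q


def truncmod(a: int, b: int) -> int:
    return a - b * truncdiv(a, b)  # result takes the sign of the dividend


def floordiv(a: int, b: int) -> int:
    return euclid_div(a, b) if b > 0 else -euclid_div(-a, b)


def floormod(a: int, b: int) -> int:
    return euclid_mod(a, b) if b > 0 else -euclid_mod(-a, b)


def check_encodings(limit: int = 20) -> int:
    """Compare each encoding with a reference on [-limit, limit]; return pairs checked."""
    checked = 0
    for a in range(-limit, limit + 1):
        for b in range(-limit, limit + 1):
            if b == 0:
                continue  # division by zero is excluded from the comparison
            assert floordiv(a, b) == a // b and floormod(a, b) == a % b
            assert truncdiv(a, b) == int(a / b)  # int() truncates toward zero
            assert truncmod(a, b) == int(math.fmod(a, b))  # fmod keeps the dividend's sign
            checked += 1
    return checked


if __name__ == "__main__":
    print(check_encodings())
    # The raw Euclidean quotient disagrees with truncated Div on a negative dividend.
    print(euclid_div(-7, 2), truncdiv(-7, 2))
```

Run as a script, it prints `1640` (the number of operand pairs checked) followed by `-4 -3`. On a negative dividend the raw Euclidean quotient differs from TVM's truncated `Div`, which is why the prover cannot pass `Div` straight to Z3's `/`.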
diff --git a/tests/python/arith/test_arith_z3.py b/tests/python/arith/test_arith_z3.py
new file mode 100644
index 0000000000..a64afd76c5
--- /dev/null
+++ b/tests/python/arith/test_arith_z3.py
@@ -0,0 +1,756 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import gc
+import queue
+import threading
+
+import pytest
+
+import tvm
+import tvm.testing
+from tvm import tirx
+from tvm.arith import Analyzer, ProofStrength
+
+# The Z3 prover is only consulted at the kSymbolicBound strength, so the common
+# default path never pays the prover cost.
+SB = ProofStrength.SYMBOLIC_BOUND
+
+
+def _require_z3(analyzer):
+ if not analyzer.is_z3_enabled:
+ pytest.skip("Z3 prover is disabled in this build")
+
+
+def implies(x, y):
+ return tirx.Or(tirx.Not(x), y)
+
+
+# ---------------------------------------------------------------------------
+# API availability (works regardless of whether Z3 is built)
+# ---------------------------------------------------------------------------
+
+
+def test_z3_capability_query():
+ # `is_z3_enabled` is the supported way to detect the build configuration.
+ # The Z3-specific debug/config methods work only when it is True, and raise
+ # a clear error otherwise.
+ analyzer = Analyzer()
+ assert isinstance(analyzer.is_z3_enabled, bool)
+
+ if analyzer.is_z3_enabled:
+ assert isinstance(analyzer.get_smtlib2(), str)
+ assert isinstance(analyzer.get_z3_stats(), str)
+ else:
+ with pytest.raises(RuntimeError):
+ analyzer.get_smtlib2()
+ with pytest.raises(RuntimeError):
+ analyzer.get_z3_stats()
+ with pytest.raises(RuntimeError):
+ analyzer.set_z3_timeout_ms(1000)
+ with pytest.raises(RuntimeError):
+ analyzer.set_z3_rlimit(0)
+
+
+def test_z3_context_lifetime_outlives_worker_thread():
+ _require_z3(Analyzer())
+
+ result_queue = queue.Queue()
+
+ def worker():
+ try:
+ analyzer = Analyzer()
+ x = tirx.Var("x", "int32")
+ analyzer.bind(x, tvm.ir.Range(0, 16))
+ assert analyzer.can_prove(x >= 0, SB)
+ result_queue.put(("analyzer", analyzer))
+ except BaseException as err: # pylint: disable=broad-exception-caught
+ result_queue.put(("error", err))
+
+ thread = threading.Thread(target=worker)
+ thread.start()
+ thread.join()
+
+ kind, payload = result_queue.get_nowait()
+ if kind == "error":
+ raise payload
+
+ del payload
+ gc.collect()
+
+
+# ---------------------------------------------------------------------------
+# Examples the native analyzer cannot prove but Z3 can.
+#
+# Each case asserts both that the native analyzers (kDefault, Z3 gated off)
+# fail and that Z3 (kSymbolicBound) succeeds. This demonstrates the added value
+# of the Z3 backend and that it is correctly gated behind kSymbolicBound.
+# ---------------------------------------------------------------------------
+
+
+def test_z3_floor_division_identity_constraint():
+ analyzer = Analyzer()
+ _require_z3(analyzer)
+
+ a = tirx.Var("a", "int32")
+ b = tirx.Var("b", "int32")
+ c = tirx.Var("c", "int32")
+
+ expr = ((b - a) // c) * c + a <= b
+ with analyzer.constraint_scope(tirx.all(a > 0, b > 0, c > 0)):
+ assert not analyzer.can_prove(expr)
+ assert analyzer.can_prove(expr, SB)
+
+
+def test_z3_floor_division_identity_via_bind_range():
+ analyzer = Analyzer()
+ _require_z3(analyzer)
+
+ a = tirx.Var("a", "int32")
+ b = tirx.Var("b", "int32")
+ c = tirx.Var("c", "int32")
+
+ analyzer.bind(a, tvm.ir.Range(1, 100000))
+ analyzer.bind(b, tvm.ir.Range(1, 100000))
+ analyzer.bind(c, tvm.ir.Range(1, 100000))
+
+ expr = ((b - a) // c) * c + a <= b
+ assert analyzer.can_prove(expr, SB)
+
+
+def test_z3_multiplication_monotonicity():
+ analyzer = Analyzer()
+ _require_z3(analyzer)
+
+ a = tirx.Var("a", "int32")
+ b = tirx.Var("b", "int32")
+ c = tirx.Var("c", "int32")
+ d = tirx.Var("d", "int32")
+
+ expr = implies(tirx.all(a < b, b < c, a * d < b * d), b * d < c * d)
+ assert not analyzer.can_prove(expr)
+ assert analyzer.can_prove(expr, SB)
+
+
+def test_z3_nested_floor_division_collapse():
+ analyzer = Analyzer()
+ _require_z3(analyzer)
+
+ a = tirx.Var("a", "int32")
+ expr = implies(
+ tirx.all(a >= 0, a < 128),
+ a // 128 == (a // 64 * 32 + a % 32 // 16 * 8) // 64,
+ )
+ assert not analyzer.can_prove(expr)
+ assert analyzer.can_prove(expr, SB)
+
+
+def test_z3_deeply_nested_floor_division_identity():
+ analyzer = Analyzer()
+ _require_z3(analyzer)
+
+ a = tirx.Var("a", "int32")
+ expr = implies(
+ tirx.all(a >= 0, a < 128),
+ (
+ a % 16 * 64
+ + a // 64 * 32
+ + a % 8 // 4 * 32
+ + (a % 32 // 16 + a % 2) % 2 * 8
+ + 16
+ - (a // 64 + a % 8 // 4) // 2 * 64
+ )
+ // 512
+ == (
+ a % 16 * 64
+ + a // 64 * 32
+ + a % 8 // 4 * 32
+ + (a % 32 // 16 + a % 2) % 2 * 8
+ - (a // 64 + a % 8 // 4) // 2 * 64
+ )
+ // 512,
+ )
+ assert analyzer.can_prove(expr, SB)
+
+
+def test_z3_min_max_sum_identity():
+ analyzer = Analyzer()
+ _require_z3(analyzer)
+
+ x = tirx.Var("x", "int32")
+ y = tirx.Var("y", "int32")
+ expr = tirx.max(x, y) + tirx.min(x, y) == x + y
+ assert analyzer.can_prove(expr, SB)
+
+
+def test_z3_select_absolute_value_nonneg():
+ analyzer = Analyzer()
+ _require_z3(analyzer)
+
+ x = tirx.Var("x", "int32")
+ expr = tirx.Select(x >= 0, x, -x) >= 0
+ assert not analyzer.can_prove(expr)
+ assert analyzer.can_prove(expr, SB)
+
+
+def test_z3_transitive_inequality():
+ analyzer = Analyzer()
+ _require_z3(analyzer)
+
+ a = tirx.Var("a", "int32")
+ b = tirx.Var("b", "int32")
+ c = tirx.Var("c", "int32")
+ expr = implies(tirx.all(a <= b, b <= c), a <= c)
+ assert analyzer.can_prove(expr, SB)
+
+
+def test_z3_square_expansion_nonneg():
+ analyzer = Analyzer()
+ _require_z3(analyzer)
+
+ a = tirx.Var("a", "int32")
+ b = tirx.Var("b", "int32")
+ expr = (a + b) * (a + b) >= a * a + b * b
+ with analyzer.constraint_scope(tirx.all(a >= 0, b >= 0)):
+ assert not analyzer.can_prove(expr)
+ assert analyzer.can_prove(expr, SB)
+
+
+def test_z3_square_monotonicity():
+ analyzer = Analyzer()
+ _require_z3(analyzer)
+
+ a = tirx.Var("a", "int32")
+ b = tirx.Var("b", "int32")
+ expr = implies(tirx.all(0 <= a, a <= b), a * a <= b * b)
+ assert not analyzer.can_prove(expr)
+ assert analyzer.can_prove(expr, SB)
+
+
+def test_z3_strict_multiplication():
+ analyzer = Analyzer()
+ _require_z3(analyzer)
+
+ a = tirx.Var("a", "int32")
+ b = tirx.Var("b", "int32")
+ d = tirx.Var("d", "int32")
+ expr = implies(tirx.all(a < b, d > 0), a * d < b * d)
+ assert not analyzer.can_prove(expr)
+ assert analyzer.can_prove(expr, SB)
+
+
+def test_z3_floor_division_monotonicity():
+ analyzer = Analyzer()
+ _require_z3(analyzer)
+
+ a = tirx.Var("a", "int32")
+ b = tirx.Var("b", "int32")
+ c = tirx.Var("c", "int32")
+ expr = implies(tirx.all(a <= b, c > 0), tirx.floordiv(a, c) <= tirx.floordiv(b, c))
+ assert not analyzer.can_prove(expr)
+ analyzer.set_z3_rlimit(0)
+ assert analyzer.can_prove(expr, SB)
+
+
+def test_z3_floor_division_lower_bound():
+ analyzer = Analyzer()
+ _require_z3(analyzer)
+
+ a = tirx.Var("a", "int32")
+ b = tirx.Var("b", "int32")
+ expr = implies(b > 0, tirx.floordiv(a, b) * b <= a)
+ assert not analyzer.can_prove(expr)
+ assert analyzer.can_prove(expr, SB)
+
+
+def test_z3_floor_modulo_range():
+ analyzer = Analyzer()
+ _require_z3(analyzer)
+
+ a = tirx.Var("a", "int32")
+ b = tirx.Var("b", "int32")
+ expr = implies(b > 0, tirx.all(0 <= tirx.floormod(a, b), tirx.floormod(a, b) < b))
+ assert not analyzer.can_prove(expr)
+ assert analyzer.can_prove(expr, SB)
+
+
+def test_z3_flattened_index_bound():
+ # Classic index-flattening bound used throughout TVM: for a row index i in
+ # [0, m) and a column index j in [0, n), the flattened index i * n + j stays
+ # within [0, m * n).
+ analyzer = Analyzer()
+ _require_z3(analyzer)
+
+ i = tirx.Var("i", "int32")
+ j = tirx.Var("j", "int32")
+ m = tirx.Var("m", "int32")
+ n = tirx.Var("n", "int32")
+ expr = tirx.all(0 <= i * n + j, i * n + j < m * n)
+ with analyzer.constraint_scope(tirx.all(0 <= i, i < m, 0 <= j, j < n, m > 0, n > 0)):
+ assert not analyzer.can_prove(expr)
+ analyzer.set_z3_rlimit(0)
+ assert analyzer.can_prove(expr, SB)
+
+
+def test_z3_modular_combination():
+ # Native modular_set tracks single-variable moduli, but combining two
+ # independent modular facts to reason about their sum is left to Z3.
+ analyzer = Analyzer()
+ _require_z3(analyzer)
+
+ x = tirx.Var("x", "int32")
+ y = tirx.Var("y", "int32")
+ expr = tirx.floormod(x + y, 2) == 0
+ with analyzer.constraint_scope(tirx.all(tirx.floormod(x, 6) == 0, tirx.floormod(y, 6) == 0)):
+ assert not analyzer.can_prove(expr)
+ assert analyzer.can_prove(expr, SB)
+
+
+def test_z3_square_non_negative():
+ analyzer = Analyzer()
+ _require_z3(analyzer)
+
+ a = tirx.Var("a", "int32")
+ assert not analyzer.can_prove(a * a >= 0)
+ assert analyzer.can_prove(a * a >= 0, SB)
+
+
+def test_z3_min_max_average_bounds():
+ analyzer = Analyzer()
+ _require_z3(analyzer)
+
+ a = tirx.Var("a", "int32")
+ b = tirx.Var("b", "int32")
+ assert not analyzer.can_prove(tirx.max(a, b) * 2 >= a + b)
+ assert analyzer.can_prove(tirx.max(a, b) * 2 >= a + b, SB)
+ assert analyzer.can_prove(tirx.min(a, b) * 2 <= a + b, SB)
+
+
+def test_z3_symbolic_bind_range_with_constraint():
+ # Combine a symbolic range binding (x in [0, n)) with a constraint on the
+ # extent to derive a concrete bound on x.
+ analyzer = Analyzer()
+ _require_z3(analyzer)
+
+ x = tirx.Var("x", "int32")
+ n = tirx.Var("n", "int32")
+ analyzer.bind(x, tvm.ir.Range(0, n))
+ with analyzer.constraint_scope(n <= 8):
+ assert not analyzer.can_prove(x < 8)
+ assert analyzer.can_prove(x < 8, SB)
+
+
+def test_z3_equality_congruence():
+ analyzer = Analyzer()
+ _require_z3(analyzer)
+
+ a = tirx.Var("a", "int32")
+ b = tirx.Var("b", "int32")
+ expr = implies(a == b, a * a == b * b)
+ assert not analyzer.can_prove(expr)
+ assert analyzer.can_prove(expr, SB)
+
+
+def test_z3_integer_strict_transitivity():
+ analyzer = Analyzer()
+ _require_z3(analyzer)
+
+ a = tirx.Var("a", "int32")
+ b = tirx.Var("b", "int32")
+ c = tirx.Var("c", "int32")
+ # Over the integers, a < b and b < c implies a + 1 < c.
+ expr = implies(tirx.all(a < b, b < c), a + 1 < c)
+ assert not analyzer.can_prove(expr)
+ assert analyzer.can_prove(expr, SB)
+
+
+def test_z3_if_then_else_absolute_value():
+ analyzer = Analyzer()
+ _require_z3(analyzer)
+
+ x = tirx.Var("x", "int32")
+ expr = tirx.if_then_else(x >= 0, x, -x) >= 0
+ assert not analyzer.can_prove(expr)
+ assert analyzer.can_prove(expr, SB)
+
+
+def test_z3_unsigned_non_negative():
+ analyzer = Analyzer()
+ _require_z3(analyzer)
+
+ u = tirx.Var("u", "uint32")
+ assert not analyzer.can_prove(u >= 0)
+ assert analyzer.can_prove(u >= 0, SB)
+
+
+def test_z3_unsigned64_non_negative():
+ # Exercises the special-cased uint64 range handling (UINT64_MAX bound).
+ analyzer = Analyzer()
+ _require_z3(analyzer)
+
+ u = tirx.Var("u", "uint64")
+ assert not analyzer.can_prove(u >= 0)
+ assert analyzer.can_prove(u >= 0, SB)
+
+
+def test_z3_int64_square_expansion():
+ analyzer = Analyzer()
+ _require_z3(analyzer)
+
+ a = tirx.Var("a", "int64")
+ b = tirx.Var("b", "int64")
+ expr = (a + b) * (a + b) >= a * a + b * b
+ with analyzer.constraint_scope(tirx.all(a >= 0, b >= 0)):
+ assert not analyzer.can_prove(expr)
+ assert analyzer.can_prove(expr, SB)
+
+
+def test_z3_boolean_variable_reasoning():
+    analyzer = Analyzer()
+    _require_z3(analyzer)
+
+    p = tirx.Var("p", "bool")
+    q = tirx.Var("q", "bool")
+    expr = implies(tirx.And(p, q), tirx.Or(p, q))
+    assert not analyzer.can_prove(expr)
+    assert analyzer.can_prove(expr, SB)
+
+
+def test_z3_not_equal_from_strict_less():
+    analyzer = Analyzer()
+    _require_z3(analyzer)
+
+    x = tirx.Var("x", "int32")
+    y = tirx.Var("y", "int32")
+    expr = implies(x < y, tirx.NE(x, y))
+    assert not analyzer.can_prove(expr)
+    assert analyzer.can_prove(expr, SB)
+
+
+def test_z3_let_expression():
+    analyzer = Analyzer()
+    _require_z3(analyzer)
+
+    y = tirx.Var("y", "int32")
+    t = tirx.Var("t", "int32")
+    let = tirx.Let(t, y * 2, t)
+    assert not analyzer.can_prove(let == y * 2)
+    assert analyzer.can_prove(let == y * 2, SB)
+
+
+def test_z3_cast_preserves_bounds():
+    analyzer = Analyzer()
+    _require_z3(analyzer)
+
+    s = tirx.Var("s", "int16")
+    widened = tirx.Cast("int32", s)
+    assert analyzer.can_prove(widened <= 32767, SB)
+    assert analyzer.can_prove(widened >= -32768, SB)
+
+
+def test_z3_bitwise_and_mask_bound():
+    analyzer = Analyzer()
+    _require_z3(analyzer)
+
+    x = tirx.Var("x", "int32")
+    analyzer.bind(x, tvm.ir.Range(0, 256))
+    assert analyzer.can_prove(tirx.bitwise_and(x, tirx.IntImm("int32", 7)) < 8, SB)
+
+
+def test_z3_bitwise_and_le_operand():
+    analyzer = Analyzer()
+    _require_z3(analyzer)
+
+    x = tirx.Var("x", "int32")
+    y = tirx.Var("y", "int32")
+    analyzer.bind(x, tvm.ir.Range(0, 256))
+    analyzer.bind(y, tvm.ir.Range(0, 256))
+    # Bit-vector reasoning over two variables exceeds the default deterministic
+    # rlimit; lift it (0 == unlimited, still deterministic) for this proof.
+    analyzer.set_z3_rlimit(0)
+    assert analyzer.can_prove(tirx.bitwise_and(x, y) <= x, SB)
+
+
+def test_z3_bitwise_or_ge_operand():
+    analyzer = Analyzer()
+    _require_z3(analyzer)
+
+    x = tirx.Var("x", "int32")
+    y = tirx.Var("y", "int32")
+    analyzer.bind(x, tvm.ir.Range(0, 256))
+    analyzer.bind(y, tvm.ir.Range(0, 256))
+    analyzer.set_z3_rlimit(0)
+    assert analyzer.can_prove(tirx.bitwise_or(x, y) >= x, SB)
+
+
+def test_z3_bitwise_xor_bound():
+    analyzer = Analyzer()
+    _require_z3(analyzer)
+
+    x = tirx.Var("x", "int32")
+    y = tirx.Var("y", "int32")
+    analyzer.bind(x, tvm.ir.Range(0, 256))
+    analyzer.bind(y, tvm.ir.Range(0, 256))
+    analyzer.set_z3_rlimit(0)
+    assert analyzer.can_prove(tirx.bitwise_xor(x, y) < 256, SB)
+
+
+def test_z3_bitwise_not_identity():
+    analyzer = Analyzer()
+    _require_z3(analyzer)
+
+    x = tirx.Var("x", "int32")
+    analyzer.bind(x, tvm.ir.Range(0, 256))
+    analyzer.set_z3_rlimit(0)
+    # Two's complement: ~x == -x - 1.
+    assert analyzer.can_prove(tirx.bitwise_not(x) == -x - 1, SB)
+
+
+def test_z3_shift_right_halves():
+    analyzer = Analyzer()
+    _require_z3(analyzer)
+
+    x = tirx.Var("x", "int32")
+    analyzer.bind(x, tvm.ir.Range(0, 256))
+    analyzer.set_z3_rlimit(0)
+    # For non-negative x, (x >> 1) * 2 <= x.
+    assert analyzer.can_prove(tirx.shift_right(x, tirx.IntImm("int32", 1)) * 2 <= x, SB)
+
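Each of the bitwise and right-shift tests binds its operands to [0, 256). Because that range is small, the expected results can be confirmed by exhaustive enumeration, which is a cheap sanity check. This sketch uses Python's unbounded integers, not the Z3 bit-vector encoding, so it only checks the claims and not how the prover reaches them:

```python
# Exhaustively check the bit-twiddling facts over the bound range [0, 256).
R = range(256)

and_mask_ok = all((x & 7) < 8 for x in R)
and_le_ok = all((x & y) <= x for x in R for y in R)
or_ge_ok = all((x | y) >= x for x in R for y in R)
xor_lt_ok = all((x ^ y) < 256 for x in R for y in R)
# Python's ~ on ints follows two's-complement semantics: ~x == -x - 1.
not_ok = all(~x == -x - 1 for x in R)
# Dropping the low bit and doubling never exceeds the original value.
shr_ok = all((x >> 1) * 2 <= x for x in R)

print(all([and_mask_ok, and_le_ok, or_ge_ok, xor_lt_ok, not_ok, shr_ok]))  # True
```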
+
+def test_z3_shift_left_lower_bound():
+    analyzer = Analyzer()
+    _require_z3(analyzer)
+
+    x = tirx.Var("x", "int32")
+    n = tirx.Var("n", "int32")
+    # Keep operands small so the 32-bit left shift cannot overflow; then
+    # x << n == x * 2 ** n >= x for x >= 1.
+    analyzer.bind(x, tvm.ir.Range(1, 16))
+    analyzer.bind(n, tvm.ir.Range(0, 4))
+    # Bit-vector shift reasoning exceeds the default deterministic rlimit.
+    analyzer.set_z3_rlimit(0)
+    assert analyzer.can_prove(tirx.shift_left(x, n) >= x, SB)
+
+
+# ---------------------------------------------------------------------------
+# Soundness / negative tests (Z3 must NOT prove false predicates)
+# ---------------------------------------------------------------------------
+
+
+def test_z3_negative_unprovable_inequality():
+    analyzer = Analyzer()
+    _require_z3(analyzer)
+
+    a = tirx.Var("a", "int32")
+    b = tirx.Var("b", "int32")
+    # a < b does not hold for arbitrary a, b.
+    assert not analyzer.can_prove(a < b, SB)
+    # a * a > a does not hold for every a (it fails for a == 0 and a == 1).
+    assert not analyzer.can_prove(a * a > a, SB)
+
+
+def test_z3_truncmod_can_be_negative():
+    # Regression test for truncated div/mod semantics: TVM Div/Mod round toward
+    # zero, so truncmod(a, 4) can be negative (e.g. truncmod(-5, 4) == -1). A
+    # solver that modeled them as Euclidean would unsoundly "prove"
+    # truncmod(a, 4) >= 0.
+    analyzer = Analyzer()
+    _require_z3(analyzer)
+
+    a = tirx.Var("a", "int32")
+    assert not analyzer.can_prove(tirx.truncmod(a, 4) >= 0, SB)
+
+
+def test_z3_truncdiv_truncmod_identity():
+    analyzer = Analyzer()
+    _require_z3(analyzer)
+
+    a = tirx.Var("a", "int32")
+    b = tirx.Var("b", "int32")
+    expr = tirx.truncdiv(a, b) * b + tirx.truncmod(a, b) == a
+    with analyzer.constraint_scope(b != 0):
+        assert analyzer.can_prove(expr, SB)
+
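The truncated and floored conventions differ only when an operand is negative. In the plain-Python sketch below, the helper names mirror `tirx.truncdiv`, `tirx.truncmod` and `tirx.floormod`, but they are local stand-ins and not TVM calls. The sketch shows why `truncmod(a, 4) >= 0` is unprovable. It also shows why the floormod bounds and the truncdiv/truncmod identity do hold:

```python
def truncdiv(a, b):
    """Stand-in: quotient rounded toward zero (C-style division)."""
    q = abs(a) // abs(b)
    return q if (a >= 0) == (b > 0) else -q


def truncmod(a, b):
    """Stand-in: remainder that takes the sign of the dividend."""
    r = abs(a) % abs(b)
    return r if a >= 0 else -r


def floormod(a, b):
    """Stand-in: Python's % already floors, so the result takes b's sign."""
    return a % b


# The truncated quotient/remainder pair reconstructs a for every b != 0.
identity_ok = all(
    truncdiv(a, b) * b + truncmod(a, b) == a
    for a in range(-20, 21)
    for b in range(-6, 7)
    if b != 0
)
# Floored modulo by a positive divisor always lands in [0, divisor).
floormod_range_ok = all(0 <= floormod(a, 4) < 4 for a in range(-20, 21))

print(truncmod(-5, 4), floormod(-5, 4), identity_ok, floormod_range_ok)  # -1 3 True True
```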
+
+def test_z3_floormod_nested_identities():
+    # Ported from TileLang's test_divmod. Here `%` is floormod, whose result
+    # takes the sign of the divisor. Taking floormod by -2 and then by 2
+    # recovers a % 2, but the reverse order does not: for odd a,
+    # a % 2 % -2 == -1 while a % 2 == 1.
+    analyzer = Analyzer()
+    _require_z3(analyzer)
+
+    a = tirx.Var("a", "int32")
+    assert not analyzer.can_prove(a % 2 % -2 - a % 2 == 0, SB)
+    assert analyzer.can_prove(a % -2 % 2 - a % 2 == 0, SB)
+
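Python's `%` is also floored, so its result takes the sign of the divisor. That means the two nested-floormod claims can be sanity-checked with a few lines of standard-library Python:

```python
# Python's % is floored modulo, matching the floormod semantics above.
window = range(-50, 51)

# floormod by -2 then by 2 recovers a % 2 for every a ...
reverse_ok = all(a % -2 % 2 == a % 2 for a in window)

# ... but floormod by 2 then by -2 does not: odd a gives -1 instead of 1.
counterexamples = [a for a in window if a % 2 % -2 != a % 2]

print(reverse_ok, counterexamples[:3])  # True [-49, -47, -45]
```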
+
+def test_z3_floormod_nonnegative():
+    # In contrast to truncmod, floormod with a positive divisor is always in
+    # [0, divisor), which Z3 should be able to prove.
+    analyzer = Analyzer()
+    _require_z3(analyzer)
+
+    a = tirx.Var("a", "int32")
+    assert analyzer.can_prove(tirx.floormod(a, 4) >= 0, SB)
+    assert analyzer.can_prove(tirx.floormod(a, 4) < 4, SB)
+
+
+def test_z3_shift_does_not_poison_solver():
+    # Regression test: evaluating a shift expression must not add permanent
+    # assertions (such as `b >= 0` / `b < 64`) to the shared solver. Otherwise
+    # an unrelated, unbounded `b` would be wrongly provable to be < 100.
+    analyzer = Analyzer()
+    _require_z3(analyzer)
+
+    a = tirx.Var("a", "int32")
+    b = tirx.Var("b", "int32")
+
+    # Touch a shift expression so the prover visits the shift amount `b`.
+    analyzer.can_prove(tirx.shift_left(a, b) >= 0, SB)
+
+    # `b` is otherwise unconstrained, so this must remain unprovable.
+    assert not analyzer.can_prove(b < 100, SB)
+    assert not analyzer.can_prove(b >= 0, SB)
+
+
+def test_z3_constraint_scope_is_popped():
+    # Constraints entered through a scope must be removed once the scope exits,
+    # i.e. EnterConstraint's solver.push()/pop() must be balanced.
+    analyzer = Analyzer()
+    _require_z3(analyzer)
+
+    x = tirx.Var("x", "int32")
+    with analyzer.constraint_scope(x > 5):
+        assert analyzer.can_prove(x > 0, SB)
+    # The constraint is gone; x is unconstrained again.
+    assert not analyzer.can_prove(x > 0, SB)
+
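The scope test and the shift-poisoning test both check one invariant. A query may read the active constraints but must never leave new ones behind, and every push must be matched by a pop. The toy model below illustrates that discipline. The `ToyProver` class is hypothetical: it enumerates a finite domain instead of calling Z3, and it is not how the PR implements `Z3Prover`.

```python
from contextlib import contextmanager


class ToyProver:
    """Hypothetical stand-in: brute-forces x over a finite domain, but scopes
    constraints with a push/pop stack the way an SMT solver would."""

    def __init__(self, domain):
        self.domain = domain
        self.stack = [[]]  # one list of constraints per scope level

    def push(self):
        self.stack.append([])

    def pop(self):
        self.stack.pop()

    def add(self, constraint):
        self.stack[-1].append(constraint)

    @contextmanager
    def constraint_scope(self, constraint):
        self.push()
        self.add(constraint)
        try:
            yield
        finally:
            self.pop()  # balanced with push(), even if the body raises

    def can_prove(self, predicate):
        # A query only reads the active constraints and never appends to the
        # stack, so it cannot "poison" later queries.
        active = [c for level in self.stack for c in level]
        return all(predicate(x) for x in self.domain if all(c(x) for c in active))


prover = ToyProver(range(-100, 101))
with prover.constraint_scope(lambda x: x > 5):
    inside = prover.can_prove(lambda x: x > 0)
outside = prover.can_prove(lambda x: x > 0)
print(inside, outside)  # True False
```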
+
+def test_z3_opaque_call_is_safe():
+    # An opaque/unsupported sub-expression is modeled as a fresh free variable.
+    # It must neither crash nor be provable on its own, yet still be usable as
+    # a constraint.
+    analyzer = Analyzer()
+    _require_z3(analyzer)
+
+    x = tirx.Var("x", "int32")
+    call = tirx.call_extern("int32", "foo", x)
+    assert not analyzer.can_prove(call > 0, SB)
+    with analyzer.constraint_scope(call > 0):
+        assert analyzer.can_prove(call > 0, SB)
+    assert not analyzer.can_prove(call > 0, SB)
+
+
+def test_z3_shift_overflow_is_not_proven():
+    # Z3 models fixed-width shifts via bit-vectors, so it correctly refuses to
+    # prove `x << n >= x` when `x` has no upper bound (a large `x` overflows
+    # int32 and wraps around to a smaller, possibly negative, value). This
+    # guards against unsound shift modeling.
+    analyzer = Analyzer()
+    _require_z3(analyzer)
+
+    x = tirx.Var("x", "int32")
+    n = tirx.Var("n", "int32")
+    analyzer.set_z3_rlimit(0)
+    expr = implies(tirx.all(x >= 1, n >= 0, n < 8), tirx.shift_left(x, n) >= x)
+    assert not analyzer.can_prove(expr, SB)
+
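A counterexample of the kind this test relies on can be reproduced by emulating a 32-bit shift in plain Python. The sketch assumes two's-complement wraparound, which is the bit-vector behavior the test comment describes. `shl_i32` is a hypothetical helper and not a TVM function.

```python
def shl_i32(x, n):
    """Hypothetical helper: left shift with int32 two's-complement wraparound."""
    raw = (x << n) & 0xFFFFFFFF  # keep only the low 32 bits
    return raw - (1 << 32) if raw & 0x80000000 else raw  # reinterpret as signed


# Without an upper bound on x, a large value wraps negative.
big = 1 << 30
print(shl_i32(big, 1), shl_i32(big, 1) >= big)  # -2147483648 False

# Bounded like test_z3_shift_left_lower_bound: x in [1, 16), n in [0, 4).
bounded_ok = all(shl_i32(x, n) >= x for x in range(1, 16) for n in range(4))
print(bounded_ok)  # True
```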
+
+def test_z3_analyzers_are_isolated():
+    # Analyzers share a thread-local Z3 context but own separate solvers, so
+    # constraints and bindings in one must never leak into another.
+    analyzer_a = Analyzer()
+    analyzer_b = Analyzer()
+    _require_z3(analyzer_a)
+
+    x = tirx.Var("x", "int32")
+    with analyzer_a.constraint_scope(x > 100):
+        assert analyzer_a.can_prove(x > 50, SB)
+        assert not analyzer_b.can_prove(x > 50, SB)
+
+    analyzer_c = Analyzer()
+    analyzer_d = Analyzer()
+    analyzer_c.bind(x, tvm.ir.Range(0, 10))
+    assert analyzer_c.can_prove(x < 10, SB)
+    assert not analyzer_d.can_prove(x < 10, SB)
+
+
+def test_z3_repeated_can_prove_is_consistent():
+    # Repeated queries must be stateless: a CanProve call must not pollute the
+    # solver and change the result of a subsequent call.
+    analyzer = Analyzer()
+    _require_z3(analyzer)
+
+    x = tirx.Var("x", "int32")
+    assert analyzer.can_prove(x > 0, SB) == analyzer.can_prove(x > 0, SB)
+
+    analyzer.bind(x, tvm.ir.Range(5, 10))
+    assert analyzer.can_prove(x >= 5, SB)
+    assert analyzer.can_prove(x >= 5, SB)
+
+
+def test_z3_is_gated_behind_symbolic_bound():
+    # The Z3 fallback must not run at the default strength.
+    analyzer = Analyzer()
+    _require_z3(analyzer)
+
+    a = tirx.Var("a", "int32")
+    b = tirx.Var("b", "int32")
+    c = tirx.Var("c", "int32")
+    expr = ((b - a) // c) * c + a <= b
+    with analyzer.constraint_scope(tirx.all(a > 0, b > 0, c > 0)):
+        assert not analyzer.can_prove(expr, ProofStrength.DEFAULT)
+        assert analyzer.can_prove(expr, SB)
+
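The predicate used in the gating and SMT-LIB2 tests holds because, for c > 0, floor((b - a) / c) * c never exceeds b - a. An exhaustive check over a small box is a quick sanity check. It is only an illustration and not what the analyzer does. The sketch assumes, as the tests do, that `//` on these expressions is floor division, which is also what Python's `//` computes:

```python
# Brute-force ((b - a) // c) * c + a <= b for a, b, c > 0 on a small box.
# For c > 0, ((b - a) // c) * c rounds b - a down to a multiple of c, so it
# is at most b - a; adding a back therefore gives at most b.
N = range(1, 30)
holds = all(((b - a) // c) * c + a <= b for a in N for b in N for c in N)
print(holds)  # True
```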
+
+# ---------------------------------------------------------------------------
+# SMT-LIB2 export
+# ---------------------------------------------------------------------------
+
+
+def test_z3_smtlib2_roundtrip():
+    z3 = pytest.importorskip("z3")
+    analyzer = Analyzer()
+    _require_z3(analyzer)
+
+    a = tirx.Var("a", "int32")
+    b = tirx.Var("b", "int32")
+    c = tirx.Var("c", "int32")
+    expr = ((b - a) // c) * c + a <= b
+
+    solver = z3.Solver()
+    with analyzer.constraint_scope(tirx.all(a > 0, b > 0, c > 0)):
+        solver.from_string(analyzer.get_smtlib2(expr))
+        assert solver.check() == z3.unsat
+
+
+def test_z3_smtlib2_roundtrip_with_timeout():
+    z3 = pytest.importorskip("z3")
+    analyzer = Analyzer()
+    _require_z3(analyzer)
+
+    a = tirx.Var("a", "int32")
+    b = tirx.Var("b", "int32")
+    c = tirx.Var("c", "int32")
+    analyzer.set_z3_timeout_ms(1000)
+
+    expr = implies(tirx.all(a > 0, b > 0, c > 0), ((b - a) // c) * c + a <= b)
+    solver = z3.Solver()
+    solver.from_string(analyzer.get_smtlib2(expr))
+    assert solver.check() == z3.unsat
+
+
+if __name__ == "__main__":
+    tvm.testing.main()