This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 022299b51f [Arith] MLIR PresburgerSet compile fix mlir >= 160 (#15638)
022299b51f is described below
commit 022299b51fa71c42f180b0c8e3afcac4eb50d71d
Author: Balint Cristian <[email protected]>
AuthorDate: Thu Aug 31 08:30:49 2023 +0300
[Arith] MLIR PresburgerSet compile fix mlir >= 160 (#15638)
Hi folks,
This commit contains some fixes for the MLIR-based analyzer module introduced by #14690.
---
* Bring the CMake status output and build info for MLIR on par with LLVM's:
```
{...}
-- Use llvm-config=llvm-config-64
-- LLVM libdir: /usr/lib64
-- Found MLIR
-- Build with MLIR
-- Set TVM_MLIR_VERSION=160
-- Found LLVM_INCLUDE_DIRS=/usr/include
{...}
-- USE_MKL : OFF
-- USE_MLIR : ON
-- USE_MSVC_MT : OFF
{...}
```
* Fix several compilation errors:
```
error: cannot convert 'llvm::SmallVector<long int>' to
'llvm::ArrayRef<mlir::presburger::MPInt>'
error: no matching function for call to
'tvm::IntImm::IntImm(tvm::runtime::DataType, mlir::presburger::MPInt&)'
note: no known conversion for argument 2 from 'mlir::presburger::MPInt'
to 'int64_t' {aka 'long int'}
```
Tested using: `llvm/mlir 16.0.6`, `llvm/mlir 15.0.7`, and `llvm/mlir 17.0.0rc3`.
---
CMakeLists.txt | 1 +
cmake/modules/LibInfo.cmake | 7 +++++++
cmake/utils/FindLLVM.cmake | 2 ++
src/arith/presburger_set.cc | 51 +++++++++++++++++++++++++++++++++++++--------
src/support/libinfo.cc | 6 ++++++
5 files changed, 58 insertions(+), 9 deletions(-)
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 75f80bfaac..4f989a3d90 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -54,6 +54,7 @@ tvm_option(USE_HEXAGON_EXTERNAL_LIBS "Path to git repo
containing external Hexag
tvm_option(USE_RPC "Build with RPC" ON)
tvm_option(USE_THREADS "Build with thread support" ON)
tvm_option(USE_LLVM "Build with LLVM, can be set to specific llvm-config path"
OFF)
+tvm_option(USE_MLIR "Build with MLIR support" OFF)
tvm_option(USE_STACKVM_RUNTIME "Include stackvm into the runtime" OFF)
tvm_option(USE_GRAPH_EXECUTOR "Build with tiny graph executor" ON)
tvm_option(USE_GRAPH_EXECUTOR_CUDA_GRAPH "Build with tiny graph executor with
CUDA Graph for GPUs" OFF)
diff --git a/cmake/modules/LibInfo.cmake b/cmake/modules/LibInfo.cmake
index ad153cce04..fd12b9d038 100644
--- a/cmake/modules/LibInfo.cmake
+++ b/cmake/modules/LibInfo.cmake
@@ -24,6 +24,11 @@ function(add_lib_info src_file)
else()
string(STRIP ${TVM_INFO_LLVM_VERSION} TVM_INFO_LLVM_VERSION)
endif()
+ if (NOT DEFINED TVM_INFO_MLIR_VERSION)
+ set(TVM_INFO_MLIR_VERSION "NOT-FOUND")
+ else()
+ string(STRIP ${TVM_INFO_MLIR_VERSION} TVM_INFO_MLIR_VERSION)
+ endif()
if (NOT DEFINED CUDA_VERSION)
set(TVM_INFO_CUDA_VERSION "NOT-FOUND")
else()
@@ -47,6 +52,7 @@ function(add_lib_info src_file)
TVM_INFO_INDEX_DEFAULT_I64="${INDEX_DEFAULT_I64}"
TVM_INFO_INSTALL_DEV="${INSTALL_DEV}"
TVM_INFO_LLVM_VERSION="${TVM_INFO_LLVM_VERSION}"
+ TVM_INFO_MLIR_VERSION="${TVM_INFO_MLIR_VERSION}"
TVM_INFO_PICOJSON_PATH="${PICOJSON_PATH}"
TVM_INFO_RANG_PATH="${RANG_PATH}"
TVM_INFO_ROCM_PATH="${ROCM_PATH}"
@@ -86,6 +92,7 @@ function(add_lib_info src_file)
TVM_INFO_USE_LIBBACKTRACE="${USE_LIBBACKTRACE}"
TVM_INFO_USE_LIBTORCH="${USE_LIBTORCH}"
TVM_INFO_USE_LLVM="${USE_LLVM}"
+ TVM_INFO_USE_MLIR="${USE_MLIR}"
TVM_INFO_USE_METAL="${USE_METAL}"
TVM_INFO_USE_MICRO_STANDALONE_RUNTIME="${USE_MICRO_STANDALONE_RUNTIME}"
TVM_INFO_USE_MICRO="${USE_MICRO}"
diff --git a/cmake/utils/FindLLVM.cmake b/cmake/utils/FindLLVM.cmake
index f10e5f1eb8..f40d97d9ba 100644
--- a/cmake/utils/FindLLVM.cmake
+++ b/cmake/utils/FindLLVM.cmake
@@ -150,6 +150,8 @@ macro(find_llvm use_llvm)
list(APPEND LLVM_LIBS "${__llvm_libdir}/libMLIRPresburger.a")
list(APPEND LLVM_LIBS "${__llvm_libdir}/libMLIRSupport.a")
set(TVM_MLIR_VERSION ${TVM_LLVM_VERSION})
+ message(STATUS "Build with MLIR")
+ message(STATUS "Set TVM_MLIR_VERSION=" ${TVM_MLIR_VERSION})
endif()
endif()
endif()
diff --git a/src/arith/presburger_set.cc b/src/arith/presburger_set.cc
index f1d86c861a..3798ba1904 100644
--- a/src/arith/presburger_set.cc
+++ b/src/arith/presburger_set.cc
@@ -126,38 +126,54 @@ PrimExpr PresburgerSetNode::GenerateConstraint() const {
for (const IntegerRelation& disjunct : disjuncts) {
PrimExpr union_entry = Bool(1);
for (unsigned i = 0, e = disjunct.getNumEqualities(); i < e; ++i) {
- PrimExpr linear_eq = IntImm(DataType::Int(32), 0);
+ PrimExpr linear_eq = IntImm(DataType::Int(64), 0);
if (disjunct.getNumCols() > 1) {
for (unsigned j = 0, f = disjunct.getNumCols() - 1; j < f; ++j) {
+#if TVM_MLIR_VERSION >= 160
+ auto coeff = int64_t(disjunct.atEq(i, j));
+#else
auto coeff = disjunct.atEq(i, j);
+#endif
if (coeff >= 0 || is_zero(linear_eq)) {
- linear_eq = linear_eq + IntImm(DataType::Int(32), coeff) * vars[j];
+ linear_eq = linear_eq + IntImm(DataType::Int(64), coeff) * vars[j];
} else {
- linear_eq = linear_eq - IntImm(DataType::Int(32), -coeff) *
vars[j];
+ linear_eq = linear_eq - IntImm(DataType::Int(64), -coeff) *
vars[j];
}
}
}
+#if TVM_MLIR_VERSION >= 160
+ auto c0 = int64_t(disjunct.atEq(i, disjunct.getNumCols() - 1));
+#else
auto c0 = disjunct.atEq(i, disjunct.getNumCols() - 1);
- linear_eq = linear_eq + IntImm(DataType::Int(32), c0);
+#endif
+ linear_eq = linear_eq + IntImm(DataType::Int(64), c0);
union_entry = (union_entry && (linear_eq == 0));
}
for (unsigned i = 0, e = disjunct.getNumInequalities(); i < e; ++i) {
- PrimExpr linear_eq = IntImm(DataType::Int(32), 0);
+ PrimExpr linear_eq = IntImm(DataType::Int(64), 0);
if (disjunct.getNumCols() > 1) {
for (unsigned j = 0, f = disjunct.getNumCols() - 1; j < f; ++j) {
+#if TVM_MLIR_VERSION >= 160
+ auto coeff = int64_t(disjunct.atIneq(i, j));
+#else
auto coeff = disjunct.atIneq(i, j);
+#endif
if (coeff >= 0 || is_zero(linear_eq)) {
- linear_eq = linear_eq + IntImm(DataType::Int(32), coeff) * vars[j];
+ linear_eq = linear_eq + IntImm(DataType::Int(64), coeff) * vars[j];
} else {
- linear_eq = linear_eq - IntImm(DataType::Int(32), -coeff) *
vars[j];
+ linear_eq = linear_eq - IntImm(DataType::Int(64), -coeff) *
vars[j];
}
}
}
+#if TVM_MLIR_VERSION >= 160
+ auto c0 = int64_t(disjunct.atIneq(i, disjunct.getNumCols() - 1));
+#else
auto c0 = disjunct.atIneq(i, disjunct.getNumCols() - 1);
+#endif
if (c0 >= 0) {
- linear_eq = linear_eq + IntImm(DataType::Int(32), c0);
+ linear_eq = linear_eq + IntImm(DataType::Int(64), c0);
} else {
- linear_eq = linear_eq - IntImm(DataType::Int(32), -c0);
+ linear_eq = linear_eq - IntImm(DataType::Int(64), -c0);
}
union_entry = (union_entry && (linear_eq >= 0));
}
@@ -199,10 +215,19 @@ PresburgerSet Intersect(const Array<PresburgerSet>& sets)
{
IntSet EvalSet(const PrimExpr& e, const PresburgerSet& set) {
Array<PrimExpr> tvm_coeffs = DetectLinearEquation(e, set->GetVars());
+#if TVM_MLIR_VERSION >= 160
+ SmallVector<mlir::presburger::MPInt> coeffs;
+#else
SmallVector<int64_t> coeffs;
+#endif
+
coeffs.reserve(tvm_coeffs.size());
for (const PrimExpr& it : tvm_coeffs) {
+#if TVM_MLIR_VERSION >= 160
+ coeffs.push_back(mlir::presburger::MPInt(*as_const_int(it)));
+#else
coeffs.push_back(*as_const_int(it));
+#endif
}
IntSet result = IntSet().Nothing();
@@ -211,9 +236,17 @@ IntSet EvalSet(const PrimExpr& e, const PresburgerSet&
set) {
auto range = simplex.computeIntegerBounds(coeffs);
auto maxRoundedDown(simplex.computeOptimum(Simplex::Direction::Up,
coeffs));
auto opt = range.first.getOptimumIfBounded();
+#if TVM_MLIR_VERSION >= 160
+ auto min = opt.has_value() ? IntImm(DataType::Int(64),
int64_t(opt.value())) : neg_inf();
+#else
auto min = opt.hasValue() ? IntImm(DataType::Int(64), opt.getValue()) :
neg_inf();
+#endif
opt = range.second.getOptimumIfBounded();
+#if TVM_MLIR_VERSION >= 160
+ auto max = opt.has_value() ? IntImm(DataType::Int(64),
int64_t(opt.value())) : pos_inf();
+#else
auto max = opt.hasValue() ? IntImm(DataType::Int(64), opt.getValue()) :
pos_inf();
+#endif
auto interval = IntervalSet(min, max);
result = Union({result, interval});
}
diff --git a/src/support/libinfo.cc b/src/support/libinfo.cc
index 53f9292d16..d94c74b5bb 100644
--- a/src/support/libinfo.cc
+++ b/src/support/libinfo.cc
@@ -31,6 +31,10 @@
#define TVM_INFO_LLVM_VERSION "NOT-FOUND"
#endif
+#ifndef TVM_INFO_MLIR_VERSION
+#define TVM_INFO_MLIR_VERSION "NOT-FOUND"
+#endif
+
#ifndef TVM_INFO_USE_CUDA
#define TVM_INFO_USE_CUDA "NOT-FOUND"
#endif
@@ -271,6 +275,7 @@ TVM_DLL Map<String, String> GetLibInfo() {
{"INDEX_DEFAULT_I64", TVM_INFO_INDEX_DEFAULT_I64},
{"INSTALL_DEV", TVM_INFO_INSTALL_DEV},
{"LLVM_VERSION", TVM_INFO_LLVM_VERSION},
+ {"MLIR_VERSION", TVM_INFO_MLIR_VERSION},
{"PICOJSON_PATH", TVM_INFO_PICOJSON_PATH},
{"RANG_PATH", TVM_INFO_RANG_PATH},
{"ROCM_PATH", TVM_INFO_ROCM_PATH},
@@ -311,6 +316,7 @@ TVM_DLL Map<String, String> GetLibInfo() {
{"USE_LIBBACKTRACE", TVM_INFO_USE_LIBBACKTRACE},
{"USE_LIBTORCH", TVM_INFO_USE_LIBTORCH},
{"USE_LLVM", TVM_INFO_USE_LLVM},
+ {"USE_MLIR", TVM_INFO_USE_MLIR},
{"USE_METAL", TVM_INFO_USE_METAL},
{"USE_MICRO_STANDALONE_RUNTIME", TVM_INFO_USE_MICRO_STANDALONE_RUNTIME},
{"USE_MICRO", TVM_INFO_USE_MICRO},