This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 022299b51f [Arith] MLIR PresburgerSet compile fix mlir >= 160 (#15638)
022299b51f is described below

commit 022299b51fa71c42f180b0c8e3afcac4eb50d71d
Author: Balint Cristian <[email protected]>
AuthorDate: Thu Aug 31 08:30:49 2023 +0300

    [Arith] MLIR PresburgerSet compile fix mlir >= 160 (#15638)
    
    Hi folks,
    
    Some fixes for the MLIR-based analyzer module introduced by #14690.
    
    ---
    
    * Make the CMake status output on par with the LLVM info:
    ```
    {...}
    -- Use llvm-config=llvm-config-64
    -- LLVM libdir: /usr/lib64
    -- Found MLIR
    -- Build with MLIR
    -- Set TVM_MLIR_VERSION=160
    -- Found LLVM_INCLUDE_DIRS=/usr/include
    {...}
    --    USE_MKL                            : OFF
    --    USE_MLIR                           : ON
    --    USE_MSVC_MT                        : OFF
    {...}
    ```
    
    * Fix several compilation errors:
    ```
    error: cannot convert 'llvm::SmallVector<long int>' to 'llvm::ArrayRef<mlir::presburger::MPInt>'
    error: no matching function for call to 'tvm::IntImm::IntImm(tvm::runtime::DataType, mlir::presburger::MPInt&)'
    note:   no known conversion for argument 2 from 'mlir::presburger::MPInt' to 'int64_t' {aka 'long int'}
    ```
    
    Tested using: `llvm/mlir 16.0.6`, `llvm/mlir 15.0.7`, `llvm/mlir 17.0.0rc3`
---
 CMakeLists.txt              |  1 +
 cmake/modules/LibInfo.cmake |  7 +++++++
 cmake/utils/FindLLVM.cmake  |  2 ++
 src/arith/presburger_set.cc | 51 +++++++++++++++++++++++++++++++++++++--------
 src/support/libinfo.cc      |  6 ++++++
 5 files changed, 58 insertions(+), 9 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 75f80bfaac..4f989a3d90 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -54,6 +54,7 @@ tvm_option(USE_HEXAGON_EXTERNAL_LIBS "Path to git repo 
containing external Hexag
 tvm_option(USE_RPC "Build with RPC" ON)
 tvm_option(USE_THREADS "Build with thread support" ON)
 tvm_option(USE_LLVM "Build with LLVM, can be set to specific llvm-config path" 
OFF)
+tvm_option(USE_MLIR "Build with MLIR support" OFF)
 tvm_option(USE_STACKVM_RUNTIME "Include stackvm into the runtime" OFF)
 tvm_option(USE_GRAPH_EXECUTOR "Build with tiny graph executor" ON)
 tvm_option(USE_GRAPH_EXECUTOR_CUDA_GRAPH "Build with tiny graph executor with 
CUDA Graph for GPUs" OFF)
diff --git a/cmake/modules/LibInfo.cmake b/cmake/modules/LibInfo.cmake
index ad153cce04..fd12b9d038 100644
--- a/cmake/modules/LibInfo.cmake
+++ b/cmake/modules/LibInfo.cmake
@@ -24,6 +24,11 @@ function(add_lib_info src_file)
   else()
     string(STRIP ${TVM_INFO_LLVM_VERSION} TVM_INFO_LLVM_VERSION)
   endif()
+  if (NOT DEFINED TVM_INFO_MLIR_VERSION)
+    set(TVM_INFO_MLIR_VERSION "NOT-FOUND")
+  else()
+    string(STRIP ${TVM_INFO_MLIR_VERSION} TVM_INFO_MLIR_VERSION)
+  endif()
   if (NOT DEFINED CUDA_VERSION)
     set(TVM_INFO_CUDA_VERSION "NOT-FOUND")
   else()
@@ -47,6 +52,7 @@ function(add_lib_info src_file)
     TVM_INFO_INDEX_DEFAULT_I64="${INDEX_DEFAULT_I64}"
     TVM_INFO_INSTALL_DEV="${INSTALL_DEV}"
     TVM_INFO_LLVM_VERSION="${TVM_INFO_LLVM_VERSION}"
+    TVM_INFO_MLIR_VERSION="${TVM_INFO_MLIR_VERSION}"
     TVM_INFO_PICOJSON_PATH="${PICOJSON_PATH}"
     TVM_INFO_RANG_PATH="${RANG_PATH}"
     TVM_INFO_ROCM_PATH="${ROCM_PATH}"
@@ -86,6 +92,7 @@ function(add_lib_info src_file)
     TVM_INFO_USE_LIBBACKTRACE="${USE_LIBBACKTRACE}"
     TVM_INFO_USE_LIBTORCH="${USE_LIBTORCH}"
     TVM_INFO_USE_LLVM="${USE_LLVM}"
+    TVM_INFO_USE_MLIR="${USE_MLIR}"
     TVM_INFO_USE_METAL="${USE_METAL}"
     TVM_INFO_USE_MICRO_STANDALONE_RUNTIME="${USE_MICRO_STANDALONE_RUNTIME}"
     TVM_INFO_USE_MICRO="${USE_MICRO}"
diff --git a/cmake/utils/FindLLVM.cmake b/cmake/utils/FindLLVM.cmake
index f10e5f1eb8..f40d97d9ba 100644
--- a/cmake/utils/FindLLVM.cmake
+++ b/cmake/utils/FindLLVM.cmake
@@ -150,6 +150,8 @@ macro(find_llvm use_llvm)
           list(APPEND LLVM_LIBS "${__llvm_libdir}/libMLIRPresburger.a")
           list(APPEND LLVM_LIBS "${__llvm_libdir}/libMLIRSupport.a")
           set(TVM_MLIR_VERSION ${TVM_LLVM_VERSION})
+          message(STATUS "Build with MLIR")
+          message(STATUS "Set TVM_MLIR_VERSION=" ${TVM_MLIR_VERSION})
         endif()
       endif()
     endif()
diff --git a/src/arith/presburger_set.cc b/src/arith/presburger_set.cc
index f1d86c861a..3798ba1904 100644
--- a/src/arith/presburger_set.cc
+++ b/src/arith/presburger_set.cc
@@ -126,38 +126,54 @@ PrimExpr PresburgerSetNode::GenerateConstraint() const {
   for (const IntegerRelation& disjunct : disjuncts) {
     PrimExpr union_entry = Bool(1);
     for (unsigned i = 0, e = disjunct.getNumEqualities(); i < e; ++i) {
-      PrimExpr linear_eq = IntImm(DataType::Int(32), 0);
+      PrimExpr linear_eq = IntImm(DataType::Int(64), 0);
       if (disjunct.getNumCols() > 1) {
         for (unsigned j = 0, f = disjunct.getNumCols() - 1; j < f; ++j) {
+#if TVM_MLIR_VERSION >= 160
+          auto coeff = int64_t(disjunct.atEq(i, j));
+#else
           auto coeff = disjunct.atEq(i, j);
+#endif
           if (coeff >= 0 || is_zero(linear_eq)) {
-            linear_eq = linear_eq + IntImm(DataType::Int(32), coeff) * vars[j];
+            linear_eq = linear_eq + IntImm(DataType::Int(64), coeff) * vars[j];
           } else {
-            linear_eq = linear_eq - IntImm(DataType::Int(32), -coeff) * 
vars[j];
+            linear_eq = linear_eq - IntImm(DataType::Int(64), -coeff) * 
vars[j];
           }
         }
       }
+#if TVM_MLIR_VERSION >= 160
+      auto c0 = int64_t(disjunct.atEq(i, disjunct.getNumCols() - 1));
+#else
       auto c0 = disjunct.atEq(i, disjunct.getNumCols() - 1);
-      linear_eq = linear_eq + IntImm(DataType::Int(32), c0);
+#endif
+      linear_eq = linear_eq + IntImm(DataType::Int(64), c0);
       union_entry = (union_entry && (linear_eq == 0));
     }
     for (unsigned i = 0, e = disjunct.getNumInequalities(); i < e; ++i) {
-      PrimExpr linear_eq = IntImm(DataType::Int(32), 0);
+      PrimExpr linear_eq = IntImm(DataType::Int(64), 0);
       if (disjunct.getNumCols() > 1) {
         for (unsigned j = 0, f = disjunct.getNumCols() - 1; j < f; ++j) {
+#if TVM_MLIR_VERSION >= 160
+          auto coeff = int64_t(disjunct.atIneq(i, j));
+#else
           auto coeff = disjunct.atIneq(i, j);
+#endif
           if (coeff >= 0 || is_zero(linear_eq)) {
-            linear_eq = linear_eq + IntImm(DataType::Int(32), coeff) * vars[j];
+            linear_eq = linear_eq + IntImm(DataType::Int(64), coeff) * vars[j];
           } else {
-            linear_eq = linear_eq - IntImm(DataType::Int(32), -coeff) * 
vars[j];
+            linear_eq = linear_eq - IntImm(DataType::Int(64), -coeff) * 
vars[j];
           }
         }
       }
+#if TVM_MLIR_VERSION >= 160
+      auto c0 = int64_t(disjunct.atIneq(i, disjunct.getNumCols() - 1));
+#else
       auto c0 = disjunct.atIneq(i, disjunct.getNumCols() - 1);
+#endif
       if (c0 >= 0) {
-        linear_eq = linear_eq + IntImm(DataType::Int(32), c0);
+        linear_eq = linear_eq + IntImm(DataType::Int(64), c0);
       } else {
-        linear_eq = linear_eq - IntImm(DataType::Int(32), -c0);
+        linear_eq = linear_eq - IntImm(DataType::Int(64), -c0);
       }
       union_entry = (union_entry && (linear_eq >= 0));
     }
@@ -199,10 +215,19 @@ PresburgerSet Intersect(const Array<PresburgerSet>& sets) 
{
 
 IntSet EvalSet(const PrimExpr& e, const PresburgerSet& set) {
   Array<PrimExpr> tvm_coeffs = DetectLinearEquation(e, set->GetVars());
+#if TVM_MLIR_VERSION >= 160
+  SmallVector<mlir::presburger::MPInt> coeffs;
+#else
   SmallVector<int64_t> coeffs;
+#endif
+
   coeffs.reserve(tvm_coeffs.size());
   for (const PrimExpr& it : tvm_coeffs) {
+#if TVM_MLIR_VERSION >= 160
+    coeffs.push_back(mlir::presburger::MPInt(*as_const_int(it)));
+#else
     coeffs.push_back(*as_const_int(it));
+#endif
   }
 
   IntSet result = IntSet().Nothing();
@@ -211,9 +236,17 @@ IntSet EvalSet(const PrimExpr& e, const PresburgerSet& 
set) {
     auto range = simplex.computeIntegerBounds(coeffs);
     auto maxRoundedDown(simplex.computeOptimum(Simplex::Direction::Up, 
coeffs));
     auto opt = range.first.getOptimumIfBounded();
+#if TVM_MLIR_VERSION >= 160
+    auto min = opt.has_value() ? IntImm(DataType::Int(64), 
int64_t(opt.value())) : neg_inf();
+#else
     auto min = opt.hasValue() ? IntImm(DataType::Int(64), opt.getValue()) : 
neg_inf();
+#endif
     opt = range.second.getOptimumIfBounded();
+#if TVM_MLIR_VERSION >= 160
+    auto max = opt.has_value() ? IntImm(DataType::Int(64), 
int64_t(opt.value())) : pos_inf();
+#else
     auto max = opt.hasValue() ? IntImm(DataType::Int(64), opt.getValue()) : 
pos_inf();
+#endif
     auto interval = IntervalSet(min, max);
     result = Union({result, interval});
   }
diff --git a/src/support/libinfo.cc b/src/support/libinfo.cc
index 53f9292d16..d94c74b5bb 100644
--- a/src/support/libinfo.cc
+++ b/src/support/libinfo.cc
@@ -31,6 +31,10 @@
 #define TVM_INFO_LLVM_VERSION "NOT-FOUND"
 #endif
 
+#ifndef TVM_INFO_MLIR_VERSION
+#define TVM_INFO_MLIR_VERSION "NOT-FOUND"
+#endif
+
 #ifndef TVM_INFO_USE_CUDA
 #define TVM_INFO_USE_CUDA "NOT-FOUND"
 #endif
@@ -271,6 +275,7 @@ TVM_DLL Map<String, String> GetLibInfo() {
       {"INDEX_DEFAULT_I64", TVM_INFO_INDEX_DEFAULT_I64},
       {"INSTALL_DEV", TVM_INFO_INSTALL_DEV},
       {"LLVM_VERSION", TVM_INFO_LLVM_VERSION},
+      {"MLIR_VERSION", TVM_INFO_MLIR_VERSION},
       {"PICOJSON_PATH", TVM_INFO_PICOJSON_PATH},
       {"RANG_PATH", TVM_INFO_RANG_PATH},
       {"ROCM_PATH", TVM_INFO_ROCM_PATH},
@@ -311,6 +316,7 @@ TVM_DLL Map<String, String> GetLibInfo() {
       {"USE_LIBBACKTRACE", TVM_INFO_USE_LIBBACKTRACE},
       {"USE_LIBTORCH", TVM_INFO_USE_LIBTORCH},
       {"USE_LLVM", TVM_INFO_USE_LLVM},
+      {"USE_MLIR", TVM_INFO_USE_MLIR},
       {"USE_METAL", TVM_INFO_USE_METAL},
       {"USE_MICRO_STANDALONE_RUNTIME", TVM_INFO_USE_MICRO_STANDALONE_RUNTIME},
       {"USE_MICRO", TVM_INFO_USE_MICRO},

Reply via email to