This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 4cebb1c  [ARITH] Remove legacy const pattern functions (#5387)
4cebb1c is described below

commit 4cebb1c76303697df64624b87f4b1e8a654dbbf2
Author: Tianqi Chen <[email protected]>
AuthorDate: Mon Apr 20 20:53:32 2020 -0700

    [ARITH] Remove legacy const pattern functions (#5387)
---
 src/arith/compute_expr.h                     | 21 ---------------
 src/arith/pattern_match.h                    | 11 ++++++++
 src/target/llvm/codegen_llvm.cc              | 39 ++++++++++++++--------------
 src/target/source/codegen_c.cc               | 19 +++++++-------
 src/target/spirv/codegen_spirv.cc            |  6 ++---
 src/tir/pass/arg_binder.cc                   |  6 ++---
 src/tir/pass/ir_util.h                       | 16 ------------
 src/tir/transforms/inject_virtual_thread.cc  |  8 +++---
 src/tir/transforms/lower_thread_allreduce.cc |  4 ++-
 src/tir/transforms/lower_tvm_builtin.cc      |  6 ++---
 src/tir/transforms/lower_warp_memory.cc      | 38 ++++++++++++++-------------
 src/tir/transforms/storage_flatten.cc        | 11 ++++----
 src/tir/transforms/unroll_loop.cc            |  7 ++---
 src/tir/transforms/vectorize_loop.cc         |  7 +++--
 14 files changed, 87 insertions(+), 112 deletions(-)

diff --git a/src/arith/compute_expr.h b/src/arith/compute_expr.h
index f842780..fddd34b 100644
--- a/src/arith/compute_expr.h
+++ b/src/arith/compute_expr.h
@@ -56,27 +56,6 @@ template<typename Op>
 inline PrimExpr ComputeReduce(
     const Array<PrimExpr>& values, PrimExpr empty_value);
 
-inline bool GetConst(PrimExpr e, int64_t* out) {
-  if (e.dtype().is_vector()) return false;
-  const int64_t* v = tir::as_const_int(e);
-  if (v) {
-    *out = *v; return true;
-  } else {
-    return false;
-  }
-}
-
-// get a small constant int
-inline bool GetConstInt(PrimExpr e, int* out) {
-  int64_t v1 = 0;
-  if (GetConst(e, &v1)) {
-    if (v1 > static_cast<int64_t>(
-            std::numeric_limits<int>::max())) return false;
-    *out = static_cast<int>(v1); return true;
-  }
-  return false;
-}
-
 template<>
 inline PrimExpr Compute<tir::AddNode>(PrimExpr a, PrimExpr b) {
   return a + b;
diff --git a/src/arith/pattern_match.h b/src/arith/pattern_match.h
index e81b088..0920ed3 100644
--- a/src/arith/pattern_match.h
+++ b/src/arith/pattern_match.h
@@ -574,6 +574,17 @@ ramp(const Pattern<TBase>& base,
       base.derived(), stride.derived(), lanes.derived());
 }
 
+template<typename TBase>
+inline PRampExpr<TBase, PConstWithTypeLike<TBase>, PConst<int>>
+ramp(const Pattern<TBase>& base,
+     int stride,
+     int lanes) {
+  return PRampExpr<TBase, PConstWithTypeLike<TBase>, PConst<int>>(
+      base.derived(),
+      PConstWithTypeLike<TBase>(base.derived(), stride),
+      PConst<int>(lanes));
+}
+
 /*!
  * \brief Pattern broadcast expression.
  * \tparam TA The pattern type of the value.
diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc
index 14302ef..820a20c 100644
--- a/src/target/llvm/codegen_llvm.cc
+++ b/src/target/llvm/codegen_llvm.cc
@@ -30,6 +30,7 @@
 
 #include "codegen_llvm.h"
 #include "codegen_cpu.h"
+#include "../../arith/pattern_match.h"
 #include "../build_common.h"
 namespace tvm {
 namespace codegen {
@@ -363,27 +364,27 @@ void CodeGenLLVM::AddAliasInfo(llvm::Instruction* inst,
         md_builder_->createTBAAStructTagNode(meta, meta, 0));
     return;
   }
-  int base = 0, width = 0;
+
+  int64_t base = 0, width = 0;
+  arith::PVar<IntImm> pbase, pstride;
+  arith::PVar<int> planes;
   // create meta-data for alias analysis
   // Use a group of binary tree ranges of memory banks.
   if (index.defined()) {
-    const RampNode* ramp = index.as<RampNode>();
-    if (ramp) {
-      int base, stride;
-      if (arith::GetConstInt(ramp->base, &base) &&
-          arith::GetConstInt(ramp->stride, &stride)) {
-        int xwith = ramp->lanes * stride;
-        width = 1;
-        while (width < xwith) {
-          width *= 2;
-        }
-        while (base % width) {
-          base -= base % width;
-          width *= 2;
-        }
+    if (arith::ramp(pbase, pstride, planes).Match(index)) {
+      base = pbase.Eval()->value;
+      int64_t xwith = planes.Eval() * pstride.Eval()->value;
+      width = 1;
+      while (width < xwith) {
+        width *= 2;
       }
-    } else {
-      if (arith::GetConstInt(index, &base)) width = 1;
+      while (base % width) {
+        base -= base % width;
+        width *= 2;
+      }
+    } else if (auto* ptr = index.as<tir::IntImmNode>()) {
+      width = 1;
+      base = ptr->value;
     }
   }
   llvm::MDNode* meta = md_tbaa_root_;
@@ -394,8 +395,8 @@ void CodeGenLLVM::AddAliasInfo(llvm::Instruction* inst,
   meta = md_builder_->createTBAAScalarTypeNode(buffer_type.str(), meta);
   // create a tree-shape access structure.
   if (width != 0) {
-    for (int w = 1024; w >= width; w /= 2) {
-      int b = (base / w) * w;
+    for (int64_t w = 1024; w >= width; w /= 2) {
+      int64_t b = (base / w) * w;
       std::stringstream os;
       os << buffer << ".w" << w << ".b" << b;
       meta = md_builder_->createTBAAScalarTypeNode(os.str(), meta);
diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc
index 6e7784c..84604b8 100644
--- a/src/target/source/codegen_c.cc
+++ b/src/target/source/codegen_c.cc
@@ -23,8 +23,8 @@
 #include <iomanip>
 #include <cctype>
 #include "codegen_c.h"
+#include "../../arith/pattern_match.h"
 #include "../../arith/compute_expr.h"
-#include "../../tir/pass/ir_util.h"
 
 namespace tvm {
 namespace codegen {
@@ -198,8 +198,8 @@ std::string CodeGenC::GetBufferRef(
     // optimize for case where it is in register,
     if (HandleTypeMatch(buffer, t) && !is_vol) {
       // optimize for constant access
-      int offset;
-      if (arith::GetConstInt(index, &offset)) {
+      if (auto* ptr = index.as<tir::IntImmNode>()) {
+        int64_t offset = ptr->value;
         CHECK_EQ(offset % t.lanes(), 0)
             << "Find unaligned vector load to a vector type";
         os << vid << '[' << (offset / t.lanes()) << ']';
@@ -663,9 +663,10 @@ void CodeGenC::VisitExpr_(const LoadNode* op, std::ostream& os) {  // NOLINT(*)
   } else {
     CHECK(is_one(op->predicate))
         << "predicated load is not supported";
-    PrimExpr base;
-    if (GetRamp1Base(op->index, op->dtype.lanes(), &base)) {
-      std::string ref = GetVecLoad(op->dtype, op->buffer_var.get(), base);
+
+    arith::PVar<PrimExpr> base;
+    if (arith::ramp(base, 1, op->dtype.lanes()).Match(op->index)) {
+      std::string ref = GetVecLoad(op->dtype, op->buffer_var.get(), base.Eval());
       HandleVolatileLoads(ref, op, os);
     } else {
       std::ostringstream svalue_expr;
@@ -708,10 +709,10 @@ void CodeGenC::VisitStmt_(const StoreNode* op) {
   } else {
     CHECK(is_one(op->predicate))
         << "Predicated store is not supported";
-    PrimExpr base;
-    if (GetRamp1Base(op->index, t.lanes(), &base)) {
+    arith::PVar<PrimExpr> base;
+    if (arith::ramp(base, 1, t.lanes()).Match(op->index)) {
       std::string value = this->PrintExpr(op->value);
-      this->PrintVecStore(op->buffer_var.get(), t, base, value);
+      this->PrintVecStore(op->buffer_var.get(), t, base.Eval(), value);
     } else {
       // The assignment below introduces side-effect, and the resulting value cannot
       // be reused across multiple expression, thus a new scope is needed
diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc
index d4631aa..5d05b08 100644
--- a/src/target/spirv/codegen_spirv.cc
+++ b/src/target/spirv/codegen_spirv.cc
@@ -103,11 +103,11 @@ spirv::Value CodeGenSPIRV::GetThreadIndex(
   spirv::Value v;
   if (ts.rank == 1) {
     v = builder_->GetLocalID(ts.dim_index);
-    int size = 0;
-    CHECK(arith::GetConstInt(extent, &size))
+    auto* sizeptr = extent.as<tir::IntImmNode>();
+    CHECK(sizeptr)
         << "SPIRV only allows constant thread group size " << " get " << extent;
     CHECK_LT(ts.dim_index, 3);
-    workgroup_size_[ts.dim_index] = static_cast<uint32_t>(size);
+    workgroup_size_[ts.dim_index] = static_cast<uint32_t>(sizeptr->value);
   } else {
     v = builder_->GetWorkgroupID(ts.dim_index);
   }
diff --git a/src/tir/pass/arg_binder.cc b/src/tir/pass/arg_binder.cc
index 51a6d8b..2f03047 100644
--- a/src/tir/pass/arg_binder.cc
+++ b/src/tir/pass/arg_binder.cc
@@ -291,9 +291,9 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
   }
   // Byte_offset field.
   int data_bytes = GetVectorBytes(buffer->dtype);
-  int64_t const_offset;
-  if (arith::GetConst(buffer->elem_offset, &const_offset)) {
-    Bind_(make_const(DataType::UInt(64), const_offset * data_bytes),
+
+  if (const auto* const_offset = buffer->elem_offset.as<IntImmNode>()) {
+    Bind_(make_const(DataType::UInt(64), const_offset->value * data_bytes),
                TVMArrayGet(DataType::UInt(64), handle, intrinsic::kArrByteOffset),
           arg_name + ".byte_offset", true);
   } else {
diff --git a/src/tir/pass/ir_util.h b/src/tir/pass/ir_util.h
index d8da61f..a167433 100644
--- a/src/tir/pass/ir_util.h
+++ b/src/tir/pass/ir_util.h
@@ -174,22 +174,6 @@ inline int GetTempAllocaAlignment(DataType type, int32_t const_size) {
   return align;
 }
 
-/*!
- * \brief Pattern match index to Ramp with stride=1
- *        This is a common pattern in continuous memory load.
- * \param index The index formula
- * \param lanes number of lanes in the ramp
- * \param base The result base.
- * \return true if pattern match success and store the base to base.
- */
-inline bool GetRamp1Base(PrimExpr index, int lanes, PrimExpr *base) {
-  const RampNode* r = index.as<RampNode>();
-  if (!r) return false;
-  if (!is_one(r->stride)) return false;
-  CHECK_EQ(r->lanes, lanes);
-  *base = r->base;
-  return true;
-}
 }  // namespace tir
 }  // namespace tvm
 #endif  // TVM_TIR_PASS_IR_UTIL_H_
diff --git a/src/tir/transforms/inject_virtual_thread.cc b/src/tir/transforms/inject_virtual_thread.cc
index c70962d..24747a4 100644
--- a/src/tir/transforms/inject_virtual_thread.cc
+++ b/src/tir/transforms/inject_virtual_thread.cc
@@ -57,15 +57,15 @@ class ExprTouched final : public StmtExprVisitor {
   }
   void VisitExpr_(const CallNode *op) final {
     if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
-      int rw_mask = 0;
-      CHECK(arith::GetConstInt(op->args[4], &rw_mask));
+      const auto* rw_mask = op->args[4].as<IntImmNode>();
       const VarNode* buffer_var = op->args[1].as<VarNode>();
       CHECK(buffer_var);
+      CHECK(rw_mask);
       // read
-      if (rw_mask & 1) {
+      if (rw_mask->value & 1) {
         HandleUseVar(buffer_var);
       }
-      if (rw_mask & 2) {
+      if (rw_mask->value & 2) {
         HandleWriteVar(buffer_var);
       }
       this->VisitExpr(op->args[2]);
diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc
index 85744d1..467e220 100644
--- a/src/tir/transforms/lower_thread_allreduce.cc
+++ b/src/tir/transforms/lower_thread_allreduce.cc
@@ -163,8 +163,10 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
       CHECK_GE(e.scope.dim_index, 0)
           << "vthread do not work with cross thread reduction";
       if (e.scope.rank == 1) {
-        CHECK(arith::GetConstInt(attr->value, &(e.extent)))
+        const auto* ptr = attr->value.as<IntImmNode>();
+        CHECK(ptr)
             << "Need constant extent for reduce set " << iv;
+        e.extent = static_cast<int>(ptr->value);
         if (reduce_set.count(iv->var.get())) {
           vred.push_back(e);
           ++nmatch;
diff --git a/src/tir/transforms/lower_tvm_builtin.cc b/src/tir/transforms/lower_tvm_builtin.cc
index 71ba468..76cfc43 100644
--- a/src/tir/transforms/lower_tvm_builtin.cc
+++ b/src/tir/transforms/lower_tvm_builtin.cc
@@ -30,7 +30,6 @@
 #include <unordered_set>
 
 #include "../pass/ir_util.h"
-#include "../../arith/compute_expr.h"
 
 namespace tvm {
 namespace tir {
@@ -94,11 +93,10 @@ class BuiltinLower : public StmtExprMutator {
     Stmt stmt = StmtExprMutator::VisitStmt_(op);
     op = stmt.as<AllocateNode>();
     // Get constant allocation bound.
-    int64_t dev_type;
     int64_t nbytes = GetVectorBytes(op->dtype);
     if (device_type_.defined()) {
-      if (arith::GetConst(device_type_, &dev_type)) {
-        if (dev_type == kDLCPU) {
+      if (const auto* dev_type = device_type_.as<IntImmNode>()) {
+        if (dev_type->value == kDLCPU) {
           int32_t constant_size = op->constant_allocation_size();
           if (constant_size > 0 && constant_size * nbytes < runtime::kMaxStackAlloca) {
             return stmt;
diff --git a/src/tir/transforms/lower_warp_memory.cc b/src/tir/transforms/lower_warp_memory.cc
index ac08e6f..96a901f 100644
--- a/src/tir/transforms/lower_warp_memory.cc
+++ b/src/tir/transforms/lower_warp_memory.cc
@@ -37,7 +37,7 @@
 
 #include <unordered_set>
 
-#include "../pass/ir_util.h"
+#include "../../arith/pattern_match.h"
 #include "../../arith/compute_expr.h"
 #include "../../runtime/thread_storage_scope.h"
 
@@ -121,11 +121,11 @@ class WarpStoreCoeffFinder : private StmtVisitor {
       if (op->value.dtype().lanes() == 1) {
         UpdatePattern(op->index);
       } else {
-        PrimExpr base;
-        CHECK(GetRamp1Base(op->index, op->value.dtype().lanes(), &base))
+        arith::PVar<PrimExpr> base;
+        CHECK(arith::ramp(base, 1, op->value.dtype().lanes()).Match(op->index))
             << "LowerWarpMemory failed due to store index=" << op->index
             << ", can only handle continuous store";
-        UpdatePattern(base);
+        UpdatePattern(base.Eval());
       }
     } else {
       StmtVisitor::VisitStmt_(op);
@@ -137,19 +137,18 @@ class WarpStoreCoeffFinder : private StmtVisitor {
         arith::DetectLinearEquation(index, {warp_index_});
     CHECK_EQ(m.size(), 2U)
         << "LowerWarpMemory failed due to store index=" << index;
-    int coeff = 0;
     PrimExpr mcoeff = analyzer_->canonical_simplify(m[0]);
-
-    CHECK(arith::GetConstInt(mcoeff, &coeff) && coeff > 0)
+    const auto* mcoeff_as_int = mcoeff.as<IntImmNode>();
+    CHECK(mcoeff_as_int && mcoeff_as_int->value > 0)
         << "LowerWarpMemory failed due to store index=" << index
         << ", require positive constant coefficient on warp index " << warp_index_
         << " but get " << mcoeff;
 
     if (warp_coeff_ != 0) {
-      CHECK_EQ(warp_coeff_, coeff)
+      CHECK_EQ(warp_coeff_, mcoeff_as_int->value)
           << "LowerWarpMemory failed due to two different store coefficient to warp index";
     } else {
-      warp_coeff_ = coeff;
+      warp_coeff_ = mcoeff_as_int->value;
     }
   }
 
@@ -158,7 +157,7 @@ class WarpStoreCoeffFinder : private StmtVisitor {
   // the warp index
   Var warp_index_;
   // the coefficient
-  int warp_coeff_{0};
+  int64_t warp_coeff_{0};
   // analyzer.
   arith::Analyzer* analyzer_;
 };
@@ -184,10 +183,10 @@ class WarpIndexFinder : private StmtVisitor {
     if (op->attr_key == attr::thread_extent) {
       IterVar iv = Downcast<IterVar>(op->node);
       if (iv->thread_tag == "threadIdx.x") {
-        int value = 0;
-        CHECK(arith::GetConstInt(op->value, &value) &&
-              value <= warp_size_ &&
-              warp_size_ % value == 0)
+        auto* value_as_int = op->value.as<IntImmNode>();
+        CHECK(value_as_int &&
+              value_as_int->value <= warp_size_ &&
+              warp_size_ % value_as_int->value == 0)
             << "Expect threadIdx.x 's size to be no larger than, and a factor of"
             << " warp size(" << warp_size_ << ")" << " to enable warp memory"
             << " but get " << op->value << " instead";
@@ -198,7 +197,7 @@ class WarpIndexFinder : private StmtVisitor {
               << "Please create it using thread_axis once and reuse the axis "
               << "across multiple binds in the same kernel";
         } else {
-          width_ = value;
+          width_ = value_as_int->value;
           warp_index_ = iv;
         }
       }
@@ -281,9 +280,12 @@ class WarpAccessRewriter : protected StmtExprMutator {
   // in this access pattern.
   std::pair<PrimExpr, PrimExpr> SplitIndexByGroup(const PrimExpr& index) {
     if (index.dtype().lanes() != 1) {
-      PrimExpr base, local_index, group;
-      CHECK(GetRamp1Base(index, index.dtype().lanes(), &base));
-      std::tie(local_index, group) = SplitIndexByGroup(base);
+      PrimExpr local_index, group;
+
+      arith::PVar<PrimExpr> base;
+      CHECK(arith::ramp(base, 1, index.dtype().lanes()).Match(index));
+
+      std::tie(local_index, group) = SplitIndexByGroup(base.Eval());
       local_index =
           RampNode::make(local_index, make_const(local_index.dtype(), 1), index.dtype().lanes());
       return std::make_pair(local_index, group);
diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc
index 99d437d..e5b2ad8 100644
--- a/src/tir/transforms/storage_flatten.cc
+++ b/src/tir/transforms/storage_flatten.cc
@@ -326,13 +326,14 @@ class StorageFlattener : public StmtExprMutator {
       << "Prefetch dim should be the same as buffer dim";
 
     int block_size = 1,
-        elem_cnt = cache_line_size_ / e.buffer->dtype.bytes(),
-        shape = 0;
+        elem_cnt = cache_line_size_ / e.buffer->dtype.bytes();
 
     int starts = op->bounds.size() - 1;
-    while (starts > 0 && arith::GetConstInt(e.buffer->shape[starts], &shape)
-        && elem_cnt >= block_size * shape) {
-      block_size *= shape;
+
+    while (starts > 0) {
+      auto* shape_as_int = e.buffer->shape[starts].as<IntImmNode>();
+      if (shape_as_int == nullptr || block_size * shape_as_int->value > elem_cnt) break;
+      block_size *= static_cast<int>(shape_as_int->value);
       starts--;
     }
     PrimExpr stride(elem_cnt / block_size);
diff --git a/src/tir/transforms/unroll_loop.cc b/src/tir/transforms/unroll_loop.cc
index 9ff5429..5eb244d 100644
--- a/src/tir/transforms/unroll_loop.cc
+++ b/src/tir/transforms/unroll_loop.cc
@@ -51,16 +51,13 @@ class LoopUnroller : public StmtExprMutator {
 
   Stmt VisitStmt_(const AttrStmtNode* op) final {
     if (op->attr_key == "pragma_auto_unroll_max_step") {
-      int value = 0;
-      CHECK(arith::GetConstInt(op->value, &value));
+      int value = static_cast<int>(Downcast<Integer>(op->value)->value);
       std::swap(value, auto_max_step_);
       Stmt ret = this->VisitStmt(op->body);
       std::swap(value, auto_max_step_);
       return ret;
     } else if (op->attr_key == "pragma_unroll_explicit") {
-      int value = 0;
-      CHECK(arith::GetConstInt(op->value, &value));
-      bool explicit_unroll = value;
+      bool explicit_unroll = Downcast<Integer>(op->value)->value;
       std::swap(explicit_unroll, explicit_unroll_);
       Stmt ret = this->VisitStmt(op->body);
       std::swap(explicit_unroll, explicit_unroll_);
diff --git a/src/tir/transforms/vectorize_loop.cc b/src/tir/transforms/vectorize_loop.cc
index cc4361d..2299573 100644
--- a/src/tir/transforms/vectorize_loop.cc
+++ b/src/tir/transforms/vectorize_loop.cc
@@ -519,12 +519,11 @@ class LoopVectorizer : public StmtMutator {
   Stmt VisitStmt_(const ForNode* op) final {
     if (op->for_type == ForType::Vectorized) {
       CHECK(is_zero(op->min));
-      int lanes = 0;
-      bool succ = arith::GetConstInt(op->extent, &lanes);
-      if (!succ || lanes < 1) {
+      auto* extent_as_int = op->extent.as<IntImmNode>();
+      if (!extent_as_int || extent_as_int->value < 1) {
         LOG(FATAL) << "Failed to vectorize loop with extent " << op->extent;
       }
-      return Vectorizer(op->loop_var, lanes)(op->body);
+      return Vectorizer(op->loop_var, static_cast<int>(extent_as_int->value))(op->body);
     } else {
       return StmtMutator::VisitStmt_(op);
     }

Reply via email to