This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new 4cebb1c [ARITH] Remove legacy const pattern functions (#5387)
4cebb1c is described below
commit 4cebb1c76303697df64624b87f4b1e8a654dbbf2
Author: Tianqi Chen <[email protected]>
AuthorDate: Mon Apr 20 20:53:32 2020 -0700
[ARITH] Remove legacy const pattern functions (#5387)
---
src/arith/compute_expr.h | 21 ---------------
src/arith/pattern_match.h | 11 ++++++++
src/target/llvm/codegen_llvm.cc | 39 ++++++++++++++--------------
src/target/source/codegen_c.cc | 19 +++++++-------
src/target/spirv/codegen_spirv.cc | 6 ++---
src/tir/pass/arg_binder.cc | 6 ++---
src/tir/pass/ir_util.h | 16 ------------
src/tir/transforms/inject_virtual_thread.cc | 8 +++---
src/tir/transforms/lower_thread_allreduce.cc | 4 ++-
src/tir/transforms/lower_tvm_builtin.cc | 6 ++---
src/tir/transforms/lower_warp_memory.cc | 38 ++++++++++++++-------------
src/tir/transforms/storage_flatten.cc | 11 ++++----
src/tir/transforms/unroll_loop.cc | 7 ++---
src/tir/transforms/vectorize_loop.cc | 7 +++--
14 files changed, 87 insertions(+), 112 deletions(-)
diff --git a/src/arith/compute_expr.h b/src/arith/compute_expr.h
index f842780..fddd34b 100644
--- a/src/arith/compute_expr.h
+++ b/src/arith/compute_expr.h
@@ -56,27 +56,6 @@ template<typename Op>
inline PrimExpr ComputeReduce(
const Array<PrimExpr>& values, PrimExpr empty_value);
-inline bool GetConst(PrimExpr e, int64_t* out) {
- if (e.dtype().is_vector()) return false;
- const int64_t* v = tir::as_const_int(e);
- if (v) {
- *out = *v; return true;
- } else {
- return false;
- }
-}
-
-// get a small constant int
-inline bool GetConstInt(PrimExpr e, int* out) {
- int64_t v1 = 0;
- if (GetConst(e, &v1)) {
- if (v1 > static_cast<int64_t>(
- std::numeric_limits<int>::max())) return false;
- *out = static_cast<int>(v1); return true;
- }
- return false;
-}
-
template<>
inline PrimExpr Compute<tir::AddNode>(PrimExpr a, PrimExpr b) {
return a + b;
diff --git a/src/arith/pattern_match.h b/src/arith/pattern_match.h
index e81b088..0920ed3 100644
--- a/src/arith/pattern_match.h
+++ b/src/arith/pattern_match.h
@@ -574,6 +574,17 @@ ramp(const Pattern<TBase>& base,
base.derived(), stride.derived(), lanes.derived());
}
+template<typename TBase>
+inline PRampExpr<TBase, PConstWithTypeLike<TBase>, PConst<int>>
+ramp(const Pattern<TBase>& base,
+ int stride,
+ int lanes) {
+ return PRampExpr<TBase, PConstWithTypeLike<TBase>, PConst<int>>(
+ base.derived(),
+ PConstWithTypeLike<TBase>(base.derived(), stride),
+ PConst<int>(lanes));
+}
+
/*!
* \brief Pattern broadcast expression.
* \tparam TA The pattern type of the value.
diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc
index 14302ef..820a20c 100644
--- a/src/target/llvm/codegen_llvm.cc
+++ b/src/target/llvm/codegen_llvm.cc
@@ -30,6 +30,7 @@
#include "codegen_llvm.h"
#include "codegen_cpu.h"
+#include "../../arith/pattern_match.h"
#include "../build_common.h"
namespace tvm {
namespace codegen {
@@ -363,27 +364,27 @@ void CodeGenLLVM::AddAliasInfo(llvm::Instruction* inst,
md_builder_->createTBAAStructTagNode(meta, meta, 0));
return;
}
- int base = 0, width = 0;
+
+ int64_t base = 0, width = 0;
+ arith::PVar<IntImm> pbase, pstride;
+ arith::PVar<int> planes;
// create meta-data for alias analysis
// Use a group of binary tree ranges of memory banks.
if (index.defined()) {
- const RampNode* ramp = index.as<RampNode>();
- if (ramp) {
- int base, stride;
- if (arith::GetConstInt(ramp->base, &base) &&
- arith::GetConstInt(ramp->stride, &stride)) {
- int xwith = ramp->lanes * stride;
- width = 1;
- while (width < xwith) {
- width *= 2;
- }
- while (base % width) {
- base -= base % width;
- width *= 2;
- }
+ if (arith::ramp(pbase, pstride, planes).Match(index)) {
+ base = pbase.Eval()->value;
+ int64_t xwith = planes.Eval() * pstride.Eval()->value;
+ width = 1;
+ while (width < xwith) {
+ width *= 2;
}
- } else {
- if (arith::GetConstInt(index, &base)) width = 1;
+ while (base % width) {
+ base -= base % width;
+ width *= 2;
+ }
+ } else if (auto* ptr = index.as<tir::IntImmNode>()) {
+ width = 1;
+ base = ptr->value;
}
}
llvm::MDNode* meta = md_tbaa_root_;
@@ -394,8 +395,8 @@ void CodeGenLLVM::AddAliasInfo(llvm::Instruction* inst,
meta = md_builder_->createTBAAScalarTypeNode(buffer_type.str(), meta);
// create a tree-shape access structure.
if (width != 0) {
- for (int w = 1024; w >= width; w /= 2) {
- int b = (base / w) * w;
+ for (int64_t w = 1024; w >= width; w /= 2) {
+ int64_t b = (base / w) * w;
std::stringstream os;
os << buffer << ".w" << w << ".b" << b;
meta = md_builder_->createTBAAScalarTypeNode(os.str(), meta);
diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc
index 6e7784c..84604b8 100644
--- a/src/target/source/codegen_c.cc
+++ b/src/target/source/codegen_c.cc
@@ -23,8 +23,8 @@
#include <iomanip>
#include <cctype>
#include "codegen_c.h"
+#include "../../arith/pattern_match.h"
#include "../../arith/compute_expr.h"
-#include "../../tir/pass/ir_util.h"
namespace tvm {
namespace codegen {
@@ -198,8 +198,8 @@ std::string CodeGenC::GetBufferRef(
// optimize for case where it is in register,
if (HandleTypeMatch(buffer, t) && !is_vol) {
// optimize for constant access
- int offset;
- if (arith::GetConstInt(index, &offset)) {
+ if (auto* ptr = index.as<tir::IntImmNode>()) {
+ int64_t offset = ptr->value;
CHECK_EQ(offset % t.lanes(), 0)
<< "Find unaligned vector load to a vector type";
os << vid << '[' << (offset / t.lanes()) << ']';
@@ -663,9 +663,10 @@ void CodeGenC::VisitExpr_(const LoadNode* op,
std::ostream& os) { // NOLINT(*)
} else {
CHECK(is_one(op->predicate))
<< "predicated load is not supported";
- PrimExpr base;
- if (GetRamp1Base(op->index, op->dtype.lanes(), &base)) {
- std::string ref = GetVecLoad(op->dtype, op->buffer_var.get(), base);
+
+ arith::PVar<PrimExpr> base;
+ if (arith::ramp(base, 1, op->dtype.lanes()).Match(op->index)) {
+ std::string ref = GetVecLoad(op->dtype, op->buffer_var.get(),
base.Eval());
HandleVolatileLoads(ref, op, os);
} else {
std::ostringstream svalue_expr;
@@ -708,10 +709,10 @@ void CodeGenC::VisitStmt_(const StoreNode* op) {
} else {
CHECK(is_one(op->predicate))
<< "Predicated store is not supported";
- PrimExpr base;
- if (GetRamp1Base(op->index, t.lanes(), &base)) {
+ arith::PVar<PrimExpr> base;
+ if (arith::ramp(base, 1, t.lanes()).Match(op->index)) {
std::string value = this->PrintExpr(op->value);
- this->PrintVecStore(op->buffer_var.get(), t, base, value);
+ this->PrintVecStore(op->buffer_var.get(), t, base.Eval(), value);
} else {
// The assignment below introduces side-effect, and the resulting value cannot
// be reused across multiple expression, thus a new scope is needed
diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc
index d4631aa..5d05b08 100644
--- a/src/target/spirv/codegen_spirv.cc
+++ b/src/target/spirv/codegen_spirv.cc
@@ -103,11 +103,11 @@ spirv::Value CodeGenSPIRV::GetThreadIndex(
spirv::Value v;
if (ts.rank == 1) {
v = builder_->GetLocalID(ts.dim_index);
- int size = 0;
- CHECK(arith::GetConstInt(extent, &size))
+ auto* sizeptr = extent.as<tir::IntImmNode>();
+ CHECK(sizeptr)
<< "SPIRV only allows constant thread group size " << " get " <<
extent;
CHECK_LT(ts.dim_index, 3);
- workgroup_size_[ts.dim_index] = static_cast<uint32_t>(size);
+ workgroup_size_[ts.dim_index] = static_cast<uint32_t>(sizeptr->value);
} else {
v = builder_->GetWorkgroupID(ts.dim_index);
}
diff --git a/src/tir/pass/arg_binder.cc b/src/tir/pass/arg_binder.cc
index 51a6d8b..2f03047 100644
--- a/src/tir/pass/arg_binder.cc
+++ b/src/tir/pass/arg_binder.cc
@@ -291,9 +291,9 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
}
// Byte_offset field.
int data_bytes = GetVectorBytes(buffer->dtype);
- int64_t const_offset;
- if (arith::GetConst(buffer->elem_offset, &const_offset)) {
- Bind_(make_const(DataType::UInt(64), const_offset * data_bytes),
+
+ if (const auto* const_offset = buffer->elem_offset.as<IntImmNode>()) {
+ Bind_(make_const(DataType::UInt(64), const_offset->value * data_bytes),
TVMArrayGet(DataType::UInt(64), handle,
intrinsic::kArrByteOffset),
arg_name + ".byte_offset", true);
} else {
diff --git a/src/tir/pass/ir_util.h b/src/tir/pass/ir_util.h
index d8da61f..a167433 100644
--- a/src/tir/pass/ir_util.h
+++ b/src/tir/pass/ir_util.h
@@ -174,22 +174,6 @@ inline int GetTempAllocaAlignment(DataType type, int32_t const_size) {
return align;
}
-/*!
- * \brief Pattern match index to Ramp with stride=1
- * This is a common pattern in continuous memory load.
- * \param index The index formula
- * \param lanes number of lanes in the ramp
- * \param base The result base.
- * \return true if pattern match success and store the base to base.
- */
-inline bool GetRamp1Base(PrimExpr index, int lanes, PrimExpr *base) {
- const RampNode* r = index.as<RampNode>();
- if (!r) return false;
- if (!is_one(r->stride)) return false;
- CHECK_EQ(r->lanes, lanes);
- *base = r->base;
- return true;
-}
} // namespace tir
} // namespace tvm
#endif // TVM_TIR_PASS_IR_UTIL_H_
diff --git a/src/tir/transforms/inject_virtual_thread.cc b/src/tir/transforms/inject_virtual_thread.cc
index c70962d..24747a4 100644
--- a/src/tir/transforms/inject_virtual_thread.cc
+++ b/src/tir/transforms/inject_virtual_thread.cc
@@ -57,15 +57,15 @@ class ExprTouched final : public StmtExprVisitor {
}
void VisitExpr_(const CallNode *op) final {
if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
- int rw_mask = 0;
- CHECK(arith::GetConstInt(op->args[4], &rw_mask));
+ const auto* rw_mask = op->args[4].as<IntImmNode>();
const VarNode* buffer_var = op->args[1].as<VarNode>();
CHECK(buffer_var);
+ CHECK(rw_mask);
// read
- if (rw_mask & 1) {
+ if (rw_mask->value & 1) {
HandleUseVar(buffer_var);
}
- if (rw_mask & 2) {
+ if (rw_mask->value & 2) {
HandleWriteVar(buffer_var);
}
this->VisitExpr(op->args[2]);
diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc
index 85744d1..467e220 100644
--- a/src/tir/transforms/lower_thread_allreduce.cc
+++ b/src/tir/transforms/lower_thread_allreduce.cc
@@ -163,8 +163,10 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
CHECK_GE(e.scope.dim_index, 0)
<< "vthread do not work with cross thread reduction";
if (e.scope.rank == 1) {
- CHECK(arith::GetConstInt(attr->value, &(e.extent)))
+ const auto* ptr = attr->value.as<IntImmNode>();
+ CHECK(ptr)
<< "Need constant extent for reduce set " << iv;
+ e.extent = static_cast<int>(ptr->value);
if (reduce_set.count(iv->var.get())) {
vred.push_back(e);
++nmatch;
diff --git a/src/tir/transforms/lower_tvm_builtin.cc b/src/tir/transforms/lower_tvm_builtin.cc
index 71ba468..76cfc43 100644
--- a/src/tir/transforms/lower_tvm_builtin.cc
+++ b/src/tir/transforms/lower_tvm_builtin.cc
@@ -30,7 +30,6 @@
#include <unordered_set>
#include "../pass/ir_util.h"
-#include "../../arith/compute_expr.h"
namespace tvm {
namespace tir {
@@ -94,11 +93,10 @@ class BuiltinLower : public StmtExprMutator {
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<AllocateNode>();
// Get constant allocation bound.
- int64_t dev_type;
int64_t nbytes = GetVectorBytes(op->dtype);
if (device_type_.defined()) {
- if (arith::GetConst(device_type_, &dev_type)) {
- if (dev_type == kDLCPU) {
+ if (const auto* dev_type = device_type_.as<IntImmNode>()) {
+ if (dev_type->value == kDLCPU) {
int32_t constant_size = op->constant_allocation_size();
if (constant_size > 0 && constant_size * nbytes <
runtime::kMaxStackAlloca) {
return stmt;
diff --git a/src/tir/transforms/lower_warp_memory.cc b/src/tir/transforms/lower_warp_memory.cc
index ac08e6f..96a901f 100644
--- a/src/tir/transforms/lower_warp_memory.cc
+++ b/src/tir/transforms/lower_warp_memory.cc
@@ -37,7 +37,7 @@
#include <unordered_set>
-#include "../pass/ir_util.h"
+#include "../../arith/pattern_match.h"
#include "../../arith/compute_expr.h"
#include "../../runtime/thread_storage_scope.h"
@@ -121,11 +121,11 @@ class WarpStoreCoeffFinder : private StmtVisitor {
if (op->value.dtype().lanes() == 1) {
UpdatePattern(op->index);
} else {
- PrimExpr base;
- CHECK(GetRamp1Base(op->index, op->value.dtype().lanes(), &base))
+ arith::PVar<PrimExpr> base;
+ CHECK(arith::ramp(base, 1, op->value.dtype().lanes()).Match(op->index))
<< "LowerWarpMemory failed due to store index=" << op->index
<< ", can only handle continuous store";
- UpdatePattern(base);
+ UpdatePattern(base.Eval());
}
} else {
StmtVisitor::VisitStmt_(op);
@@ -137,19 +137,18 @@ class WarpStoreCoeffFinder : private StmtVisitor {
arith::DetectLinearEquation(index, {warp_index_});
CHECK_EQ(m.size(), 2U)
<< "LowerWarpMemory failed due to store index=" << index;
- int coeff = 0;
PrimExpr mcoeff = analyzer_->canonical_simplify(m[0]);
-
- CHECK(arith::GetConstInt(mcoeff, &coeff) && coeff > 0)
+ const auto* mcoeff_as_int = mcoeff.as<IntImmNode>();
+ CHECK(mcoeff_as_int && mcoeff_as_int->value > 0)
<< "LowerWarpMemory failed due to store index=" << index
<< ", require positive constant coefficient on warp index " <<
warp_index_
<< " but get " << mcoeff;
if (warp_coeff_ != 0) {
- CHECK_EQ(warp_coeff_, coeff)
+ CHECK_EQ(warp_coeff_, mcoeff_as_int->value)
<< "LowerWarpMemory failed due to two different store coefficient to
warp index";
} else {
- warp_coeff_ = coeff;
+ warp_coeff_ = mcoeff_as_int->value;
}
}
@@ -158,7 +157,7 @@ class WarpStoreCoeffFinder : private StmtVisitor {
// the warp index
Var warp_index_;
// the coefficient
- int warp_coeff_{0};
+ int64_t warp_coeff_{0};
// analyzer.
arith::Analyzer* analyzer_;
};
@@ -184,10 +183,10 @@ class WarpIndexFinder : private StmtVisitor {
if (op->attr_key == attr::thread_extent) {
IterVar iv = Downcast<IterVar>(op->node);
if (iv->thread_tag == "threadIdx.x") {
- int value = 0;
- CHECK(arith::GetConstInt(op->value, &value) &&
- value <= warp_size_ &&
- warp_size_ % value == 0)
+ auto* value_as_int = op->value.as<IntImmNode>();
+ CHECK(value_as_int &&
+ value_as_int->value <= warp_size_ &&
+ warp_size_ % value_as_int->value == 0)
<< "Expect threadIdx.x 's size to be no larger than, and a factor
of"
<< " warp size(" << warp_size_ << ")" << " to enable warp memory"
<< " but get " << op->value << " instead";
@@ -198,7 +197,7 @@ class WarpIndexFinder : private StmtVisitor {
<< "Please create it using thread_axis once and reuse the axis "
<< "across multiple binds in the same kernel";
} else {
- width_ = value;
+ width_ = value_as_int->value;
warp_index_ = iv;
}
}
@@ -281,9 +280,12 @@ class WarpAccessRewriter : protected StmtExprMutator {
// in this access pattern.
std::pair<PrimExpr, PrimExpr> SplitIndexByGroup(const PrimExpr& index) {
if (index.dtype().lanes() != 1) {
- PrimExpr base, local_index, group;
- CHECK(GetRamp1Base(index, index.dtype().lanes(), &base));
- std::tie(local_index, group) = SplitIndexByGroup(base);
+ PrimExpr local_index, group;
+
+ arith::PVar<PrimExpr> base;
+ CHECK(arith::ramp(base, 1, index.dtype().lanes()).Match(index));
+
+ std::tie(local_index, group) = SplitIndexByGroup(base.Eval());
local_index =
RampNode::make(local_index, make_const(local_index.dtype(), 1),
index.dtype().lanes());
return std::make_pair(local_index, group);
diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc
index 99d437d..e5b2ad8 100644
--- a/src/tir/transforms/storage_flatten.cc
+++ b/src/tir/transforms/storage_flatten.cc
@@ -326,13 +326,14 @@ class StorageFlattener : public StmtExprMutator {
<< "Prefetch dim should be the same as buffer dim";
int block_size = 1,
- elem_cnt = cache_line_size_ / e.buffer->dtype.bytes(),
- shape = 0;
+ elem_cnt = cache_line_size_ / e.buffer->dtype.bytes();
int starts = op->bounds.size() - 1;
- while (starts > 0 && arith::GetConstInt(e.buffer->shape[starts], &shape)
- && elem_cnt >= block_size * shape) {
- block_size *= shape;
+
+ while (starts > 0) {
+ auto* shape_as_int = e.buffer->shape[starts].as<IntImmNode>();
+ if (shape_as_int == nullptr || block_size * shape_as_int->value >
elem_cnt) break;
+ block_size *= static_cast<int>(shape_as_int->value);
starts--;
}
PrimExpr stride(elem_cnt / block_size);
diff --git a/src/tir/transforms/unroll_loop.cc b/src/tir/transforms/unroll_loop.cc
index 9ff5429..5eb244d 100644
--- a/src/tir/transforms/unroll_loop.cc
+++ b/src/tir/transforms/unroll_loop.cc
@@ -51,16 +51,13 @@ class LoopUnroller : public StmtExprMutator {
Stmt VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == "pragma_auto_unroll_max_step") {
- int value = 0;
- CHECK(arith::GetConstInt(op->value, &value));
+ int value = static_cast<int>(Downcast<Integer>(op->value)->value);
std::swap(value, auto_max_step_);
Stmt ret = this->VisitStmt(op->body);
std::swap(value, auto_max_step_);
return ret;
} else if (op->attr_key == "pragma_unroll_explicit") {
- int value = 0;
- CHECK(arith::GetConstInt(op->value, &value));
- bool explicit_unroll = value;
+ bool explicit_unroll = Downcast<Integer>(op->value)->value;
std::swap(explicit_unroll, explicit_unroll_);
Stmt ret = this->VisitStmt(op->body);
std::swap(explicit_unroll, explicit_unroll_);
diff --git a/src/tir/transforms/vectorize_loop.cc b/src/tir/transforms/vectorize_loop.cc
index cc4361d..2299573 100644
--- a/src/tir/transforms/vectorize_loop.cc
+++ b/src/tir/transforms/vectorize_loop.cc
@@ -519,12 +519,11 @@ class LoopVectorizer : public StmtMutator {
Stmt VisitStmt_(const ForNode* op) final {
if (op->for_type == ForType::Vectorized) {
CHECK(is_zero(op->min));
- int lanes = 0;
- bool succ = arith::GetConstInt(op->extent, &lanes);
- if (!succ || lanes < 1) {
+ auto* extent_as_int = op->extent.as<IntImmNode>();
+ if (!extent_as_int || extent_as_int->value < 1) {
LOG(FATAL) << "Failed to vectorize loop with extent " << op->extent;
}
- return Vectorizer(op->loop_var, lanes)(op->body);
+ return Vectorizer(op->loop_var,
static_cast<int>(extent_as_int->value))(op->body);
} else {
return StmtMutator::VisitStmt_(op);
}