This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 30b34d2521 [TIR] More flexible buffer compaction (#14021)
30b34d2521 is described below
commit 30b34d2521dd8b3fa99b9bd0297114cc05698e02
Author: wrongtest <[email protected]>
AuthorDate: Sat Apr 29 05:06:13 2023 +0800
[TIR] More flexible buffer compaction (#14021)
Hi there, this change enhances the power and flexibility of the
`CompactBufferAllocation` pass in two aspects:
1. Freedom of pass ordering
- The pass now works on both S-TIR (with opaque blocks) and lowered TIR. For
example, one can invoke `LoopPartition` first and then perform buffer
compaction that is aware of the partitioned tiles.
- We test existing cases to ensure that
`(LowerOpaqueBlock . CompactBufferAllocation) (mod) ==
(CompactBufferAllocation . LowerOpaqueBlock) (mod)`
2. Allow "non-strict" compaction
- Add an option `is_strict`, defaulting to `True`, to denote that during
compaction we should respect the original buffer shape bound; thus the
compacted buffer region never exceeds the original.
- If set to `False`, the "compacted" shape is determined solely by the
buffer region accesses it must cover, so it may become larger than the
original shape. This changes the original semantics for out-of-bound accesses
but may be helpful in certain usages.
- If the loop domain changes (e.g., aligning a loop extent or removing an
extra predicate), the accessed buffer region may grow; the pass then provides
a fallback to adapt the buffer regions.
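The strict vs non-strict distinction can be illustrated per dimension with a
small pure-Python sketch (a hypothetical helper for illustration, not TVM's
actual API):

```python
def compact_dim(lo, hi, shape, is_strict=True):
    """Compact one buffer dimension given the accessed index range [lo, hi]
    (inclusive) and the original extent `shape`; return (min, extent)."""
    if is_strict:
        # Strict mode: clamp to the original shape bound, so the compacted
        # buffer never exceeds the declared one.
        lo = max(0, lo)
        hi = min(shape - 1, hi)
    # Non-strict mode keeps the raw accessed range, which may exceed `shape`
    # (e.g. after LoopPartition aligned a loop past the original bound).
    return lo, hi - lo + 1

print(compact_dim(-2, 20, 16, is_strict=True))   # (0, 16)
print(compact_dim(-2, 20, 16, is_strict=False))  # (-2, 23)
```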
Implementation notes:
- To achieve [1]
- Buffer decl point:
- (s-tir) `T.alloc_buffer` in block form is handled without any
change.
- (lowered) `T.decl_buffer` is newly handled. We assume it is at
the proper position to dominate all accesses.
- Predicates
- Predicates in `T.where`, `IfThenElse`, and `T.if_then_else` are
handled uniformly now. We first try to simply resolve the predicate and
update the loop domain, as before; on failure, we now keep it in a pending stack.
- The visit logic closely resembles that of `IRVisitorWithAnalyzer`, but
since `arith::Analyzer` and `arith::IntSet` are independent components now,
the change does not introduce `IRVisitorWithAnalyzer`.
- Buffer accesses
- There is no difference between S-TIR and the lowered form. If there
are pending predicates at an access point, we always try affine iteration
analysis to obtain a tighter relaxed region.
- Thread binding
- To work on the lowered form, `attr::thread_extent` and
`attr::virtual_thread` are handled to record the necessary info for thread relaxation.
- Buffer aliasing
- No change to `T.match_buffer` handling, and we explicitly disable
compaction for aliased buffers in the lowered form.
- Dim alignment
- We utilize the annotation field in `T.allocate` and preserve
`attr::buffer_dim_align` in the `LowerOpaqueBlock` pass when it converts
`T.alloc_buffer` to `T.allocate`, so the compaction can collect the dim
alignment information in the lowered form.
- To achieve [2]
- This one is more direct: `SimplifyAndNarrowBufferRegionFromNDIntSet` does
not intersect the accessed region with the original shape if the `is_strict`
option is overridden to false.
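The dim-alignment handling described above ultimately produces padded strides.
A pure-Python sketch of the stride computation (mirroring the alignment loop
in `compact_buffer_region.cc`; names are illustrative):

```python
def compute_aligned_strides(shape, dim_aligns):
    """dim_aligns[d] = (factor, offset); factor == 0 means no alignment
    request. Enforces strides[d] % factor == offset for aligned dims."""
    strides = [0] * len(shape)
    stride = 1
    for dim in reversed(range(len(shape))):
        factor, offset = dim_aligns[dim]
        if factor != 0:
            # Pad the running stride up to the requested congruence class.
            stride += (factor + offset - stride % factor) % factor
        strides[dim] = stride
        stride *= shape[dim]
    return strides

# Align the outer dim's stride to 16*k + 1, e.g. to avoid bank conflicts:
print(compute_aligned_strides([16, 15], [(16, 1), (0, 0)]))  # [17, 1]
```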
---
include/tvm/tir/transform.h | 5 +-
python/tvm/tir/transform/transform.py | 10 +-
src/tir/analysis/block_access_region_detector.cc | 10 +-
src/tir/schedule/primitive.h | 5 +-
src/tir/schedule/primitive/block_annotate.cc | 1 +
src/tir/transforms/compact_buffer_region.cc | 487 ++++--
src/tir/transforms/ir_utils.cc | 101 +-
src/tir/transforms/ir_utils.h | 44 +-
src/tir/transforms/lower_opaque_block.cc | 34 +-
tests/python/unittest/test_tir_buffer.py | 23 -
.../test_tir_transform_compact_buffer_region.py | 1730 +++++++++++---------
11 files changed, 1470 insertions(+), 980 deletions(-)
diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h
index 35aa392db2..8dee176277 100644
--- a/include/tvm/tir/transform.h
+++ b/include/tvm/tir/transform.h
@@ -456,10 +456,11 @@ TVM_DLL Pass ConvertBlocksToOpaque();
*
* \endcode
*
- *
+ * \param is_strict ensure the compacted shape is always smaller than the
original shape;
+ * otherwise allow the shape to grow to match the actually accessed buffer
regions.
* \return The pass.
*/
-TVM_DLL Pass CompactBufferAllocation();
+TVM_DLL Pass CompactBufferAllocation(bool is_strict = true);
/*!
* This pass legalizes packed calls by wrapping their arguments into TVMValues
diff --git a/python/tvm/tir/transform/transform.py
b/python/tvm/tir/transform/transform.py
index 1df2ac76b5..0437f1a887 100644
--- a/python/tvm/tir/transform/transform.py
+++ b/python/tvm/tir/transform/transform.py
@@ -747,7 +747,7 @@ def ConvertBlocksToOpaque():
return _ffi_api.ConvertBlocksToOpaque() # type: ignore
-def CompactBufferAllocation():
+def CompactBufferAllocation(is_strict: bool = True):
"""Compact the buffer access region. by removing the buffer regions
that are not accessed, i.e. narrowing the buffer shape and adjust
the access region if necessary.
@@ -783,13 +783,19 @@ def CompactBufferAllocation():
for j in range(0, 16):
C[i, j] = B[0, j] + 1
+ Parameters
+ ----------
+ is_strict : bool
+ Ensure the compacted shape is always smaller than the original
shape.
+ Otherwise the shape is allowed to grow to match the actually accessed
buffer regions.
+
Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
- return _ffi_api.CompactBufferAllocation() # type: ignore
+ return _ffi_api.CompactBufferAllocation(is_strict) # type: ignore
def LowerMatchBuffer():
diff --git a/src/tir/analysis/block_access_region_detector.cc
b/src/tir/analysis/block_access_region_detector.cc
index 057cec475d..a15cecabdd 100644
--- a/src/tir/analysis/block_access_region_detector.cc
+++ b/src/tir/analysis/block_access_region_detector.cc
@@ -60,6 +60,8 @@ class BlockReadWriteDetector : public StmtExprVisitor {
std::unordered_map<const VarNode*, arith::IntSet> dom_map_;
/*! \brief Extra iteration range hint for free vars */
std::unordered_map<const VarNode*, arith::IntSet> hint_map_;
+ /*! \brief Unresolved conditions within current scope. */
+ std::vector<PrimExpr> pending_conditions_;
/*! \brief The buffers that the current block reads */
std::vector<Buffer> read_buffers_;
/*! \brief The buffers that the current block writes */
@@ -164,12 +166,12 @@ void BlockReadWriteDetector::VisitStmt_(const
IfThenElseNode* op) {
VisitExpr(op->condition);
{
// Visit then branch
- With<ConditionalBoundsContext> ctx(op->condition, &dom_map_, &hint_map_,
true);
+ With<ConditionalBoundsContext> ctx(op->condition, &dom_map_, &hint_map_,
&pending_conditions_);
StmtExprVisitor::VisitStmt(op->then_case);
}
if (op->else_case) {
// Visit else branch
- With<ConditionalBoundsContext> ctx(op->condition, &dom_map_, &hint_map_,
false);
+ With<ConditionalBoundsContext> ctx(!op->condition, &dom_map_, &hint_map_,
&pending_conditions_);
StmtExprVisitor::VisitStmt(op->else_case.value());
}
}
@@ -207,12 +209,12 @@ void BlockReadWriteDetector::VisitExpr_(const CallNode*
op) {
VisitExpr(op->args[0]);
{
// Visit then branch
- With<ConditionalBoundsContext> ctx(op->args[0], &dom_map_, &hint_map_,
true);
+ With<ConditionalBoundsContext> ctx(op->args[0], &dom_map_, &hint_map_,
&pending_conditions_);
StmtExprVisitor::VisitExpr(op->args[1]);
}
{
// Visit else branch
- With<ConditionalBoundsContext> ctx(op->args[0], &dom_map_, &hint_map_,
false);
+ With<ConditionalBoundsContext> ctx(!op->args[0], &dom_map_, &hint_map_,
&pending_conditions_);
StmtExprVisitor::VisitExpr(op->args[2]);
}
return;
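The conditions pushed into `pending_conditions_` above are later folded into a
single predicate before region evaluation. A pure-Python sketch of that folding
(mirroring the `normalize_pred` lambda added in `compact_buffer_region.cc`):

```python
from functools import reduce

def normalize_pred(p):
    # Non-boolean predicates are interpreted as "value != 0", matching the
    # `pred != make_zero(pred->dtype)` rewrite in the C++ lambda.
    return p if isinstance(p, bool) else (p != 0)

def fold_pending(conds):
    # Conjunction of all pending conditions, starting from const_true().
    return reduce(lambda x, y: normalize_pred(x) and normalize_pred(y),
                  conds, True)

print(fold_pending([True, 1]))     # True
print(fold_pending([True, 1, 0]))  # False
```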
diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h
index 5f5591ac45..78d1cab05c 100644
--- a/src/tir/schedule/primitive.h
+++ b/src/tir/schedule/primitive.h
@@ -481,10 +481,7 @@ TVM_DLL StmtSRef DecomposeReduction(ScheduleState self,
const StmtSRef& block_sr
*/
TVM_DLL StmtSRef RFactor(ScheduleState self, const StmtSRef& loop_sref, int
factor_axis);
/******** Schedule: Block annotation ********/
-/*! \brief The quad used by StorageAlign for (buffer_idx, axis, factor,
offset) */
-using StorageAlignTuple = Array<Integer>;
-/*! \brief A list of StorageAlignTuple, used by StorageAlign */
-using StorageAlignAnnotation = Array<StorageAlignTuple>;
+
/*!
* \brief Set alignment requirement for specific dimension such that
* stride[axis] == k * factor + offset for some k. This is useful to
set memory layout for
diff --git a/src/tir/schedule/primitive/block_annotate.cc
b/src/tir/schedule/primitive/block_annotate.cc
index 3f1789b3d6..917e37c9a7 100644
--- a/src/tir/schedule/primitive/block_annotate.cc
+++ b/src/tir/schedule/primitive/block_annotate.cc
@@ -18,6 +18,7 @@
*/
#include <tvm/tir/expr.h>
+#include "../../transforms/ir_utils.h"
#include "../utils.h"
namespace tvm {
diff --git a/src/tir/transforms/compact_buffer_region.cc
b/src/tir/transforms/compact_buffer_region.cc
index a047191f16..6c6eb169a5 100644
--- a/src/tir/transforms/compact_buffer_region.cc
+++ b/src/tir/transforms/compact_buffer_region.cc
@@ -28,6 +28,7 @@
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
+#include <numeric>
#include <stack>
#include "../../support/arena.h"
@@ -41,51 +42,6 @@ namespace tir {
using support::NDIntSet;
-/*!
- * \brief simplify and return the region collected by NDIntSet. return the
original
- * buffer shape if the int_set is empty.
- */
-Region SimplifyAndNarrowBufferRegionFromNDIntSet(
- const NDIntSet& nd_int_set, const Array<PrimExpr>& original_shape,
arith::Analyzer* analyzer,
- const std::vector<const ForNode*>& ancestor_loops) {
- Array<Range> result;
- result.reserve(nd_int_set.size());
- for (size_t i = 0; i < nd_int_set.size(); ++i) {
- const arith::IntSet& int_set = nd_int_set[i];
- Range range = int_set.CoverRange(Range(/*begin=*/0,
/*end=*/original_shape[i]));
- PrimExpr min = analyzer->Simplify(tvm::max(0, range->min));
- PrimExpr extent = range->extent;
-
- // Apply stronger symbolic proof to help us remove symbolic min here.
- if (!analyzer->CanProveLessEqualThanSymbolicShapeValue(range->extent,
original_shape[i])) {
- extent = tvm::min(original_shape[i], range->extent);
- }
-
- extent = analyzer->Simplify(extent);
-
- // Check the buffer region is not loop dependent, since loop dependent
- // allocation is not supported yet.
- auto is_loop_var = [&ancestor_loops](const VarNode* v) {
- return std::any_of(ancestor_loops.begin(), ancestor_loops.end(),
- [v](const ForNode* n) { return n->loop_var.get() ==
v; });
- };
- if (UsesVar(extent, is_loop_var)) {
- // try estimate a constant upperbound on region's extent
- int64_t upperbound = analyzer->const_int_bound(extent)->max_value;
- if (upperbound != arith::ConstIntBound::kPosInf) {
- extent = make_const(extent->dtype, upperbound);
- } else {
- // or else we have to fallback to full region
- min = make_zero(original_shape[i]->dtype);
- extent = original_shape[i];
- }
- }
-
- result.push_back(Range::FromMinExtent(min, extent));
- }
- return result;
-}
-
/*! \brief a more constrained bound estimate for n-dimentional int set */
NDIntSet NDIntSetEval(Region region, PrimExpr predicate,
const std::unordered_map<const VarNode*, arith::IntSet>&
dom_map,
@@ -103,6 +59,44 @@ NDIntSet NDIntSetEval(Region region, PrimExpr predicate,
return support::NDIntSetEval(support::NDIntSetFromRegion(region), dom_map);
}
+/*!
+ * \brief Collect buffer aliasing information.
+ */
+class Var2BufferCollector : public StmtExprVisitor {
+ public:
+ /*! \brief Map the buffer var to all aliased buffers. */
+ std::unordered_map<Var, std::unordered_set<Buffer, ObjectPtrHash,
ObjectPtrEqual>, ObjectPtrHash,
+ ObjectPtrEqual>
+ var2buffer_;
+
+ private:
+ void VisitStmt_(const BufferStoreNode* op) final {
+ var2buffer_[op->buffer->data].insert(op->buffer);
+ StmtExprVisitor::VisitStmt_(op);
+ }
+
+ void VisitExpr_(const BufferLoadNode* op) final {
+ var2buffer_[op->buffer->data].insert(op->buffer);
+ StmtExprVisitor::VisitExpr_(op);
+ }
+
+ void VisitStmt_(const BlockNode* op) final {
+ for (const Buffer& buffer : op->alloc_buffers) {
+ var2buffer_[buffer->data].insert(buffer);
+ }
+ for (const MatchBufferRegion& region : op->match_buffers) {
+ var2buffer_[region->buffer->data].insert(region->buffer);
+ var2buffer_[region->source->buffer->data].insert(region->source->buffer);
+ }
+ StmtExprVisitor::VisitStmt_(op);
+ }
+
+ void VisitStmt_(const DeclBufferNode* op) final {
+ var2buffer_[op->buffer->data].insert(op->buffer);
+ StmtExprVisitor::VisitStmt_(op);
+ }
+};
+
/*!
* \brief Collect the access region of each buffer.
* \note The param buffer regions will not be collected.
@@ -110,10 +104,17 @@ NDIntSet NDIntSetEval(Region region, PrimExpr predicate,
class BufferAccessRegionCollector : public StmtExprVisitor {
public:
static std::unordered_map<Buffer, Region, ObjectPtrHash, ObjectPtrEqual>
Collect(
- const PrimFunc& f) {
- BufferAccessRegionCollector collector;
- collector(f->body);
- return std::move(collector.buffer_access_region_);
+ const PrimFunc& f, bool collect_inbound) {
+ BufferAccessRegionCollector region_collector(collect_inbound);
+
+ // collect buffer var to aliased buffer mapping
+ Var2BufferCollector var2buffer_collector;
+ var2buffer_collector(f->body);
+ std::swap(region_collector.var2buffer_, var2buffer_collector.var2buffer_);
+
+ // collect buffer access regions
+ region_collector(f->body);
+ return std::move(region_collector.buffer_access_region_);
}
private:
@@ -127,7 +128,7 @@ class BufferAccessRegionCollector : public StmtExprVisitor {
: buffer(buffer), accessed_region(region) {}
};
- BufferAccessRegionCollector() = default;
+ explicit BufferAccessRegionCollector(bool collect_inbound) :
collect_inbound_(collect_inbound) {}
/**************** Visitor overload ****************/
@@ -138,18 +139,23 @@ class BufferAccessRegionCollector : public
StmtExprVisitor {
void VisitExpr_(const BufferLoadNode* op) final {
VisitBufferAccess(BufferRegion::FromPoint(op->buffer, op->indices));
+ StmtExprVisitor::VisitExpr_(op);
}
void VisitExpr_(const VarNode* op) final { VisitBufferVar(GetRef<Var>(op)); }
void VisitStmt_(const ForNode* op) final {
- ancestor_loops_.push_back(op);
Range loop_range = Range::FromMinExtent(op->min, op->extent);
+ IterVar iter = op->kind == ForKind::kThreadBinding
+ ? IterVar(Range(), op->loop_var,
IterVarType::kThreadIndex,
+ op->thread_binding.value()->thread_tag)
+ : IterVar(Range(), op->loop_var, IterVarType::kDataPar);
+ ancestor_iters_.push_back(iter);
dom_analyzer_.Bind(op->loop_var, loop_range);
dom_map_.emplace(op->loop_var.get(), arith::IntSet::FromRange(loop_range));
StmtExprVisitor::VisitStmt_(op);
dom_map_.erase(op->loop_var.get());
- ancestor_loops_.pop_back();
+ ancestor_iters_.pop_back();
}
void VisitStmt_(const LetStmtNode* op) final {
@@ -181,12 +187,14 @@ class BufferAccessRegionCollector : public
StmtExprVisitor {
StmtExprVisitor::VisitExpr(op->condition);
{
// Visit then branch
- With<ConditionalBoundsContext> ctx(op->condition, &dom_map_, &hint_map_,
true);
+ With<ConditionalBoundsContext> ctx(op->condition, &dom_map_, &hint_map_,
+ &pending_conditions_);
StmtExprVisitor::VisitStmt(op->then_case);
}
if (op->else_case) {
// Visit else branch
- With<ConditionalBoundsContext> ctx(op->condition, &dom_map_, &hint_map_,
false);
+ With<ConditionalBoundsContext> ctx(!op->condition, &dom_map_, &hint_map_,
+ &pending_conditions_);
StmtExprVisitor::VisitStmt(op->else_case.value());
}
}
@@ -197,12 +205,14 @@ class BufferAccessRegionCollector : public
StmtExprVisitor {
StmtExprVisitor::VisitExpr(op->args[0]);
{
// Visit then branch
- With<ConditionalBoundsContext> ctx(op->args[0], &dom_map_, &hint_map_,
true);
+ With<ConditionalBoundsContext> ctx(op->args[0], &dom_map_, &hint_map_,
+ &pending_conditions_);
StmtExprVisitor::VisitExpr(op->args[1]);
}
{
// Visit else branch
- With<ConditionalBoundsContext> ctx(op->args[0], &dom_map_, &hint_map_,
false);
+ With<ConditionalBoundsContext> ctx(!op->args[0], &dom_map_, &hint_map_,
+ &pending_conditions_);
StmtExprVisitor::VisitExpr(op->args[2]);
}
return;
@@ -227,9 +237,9 @@ class BufferAccessRegionCollector : public StmtExprVisitor {
auto& regions = access_annotations_[p.first];
p.second.swap(regions);
}
- // Step 2. Record relax position of ancestor_loops_ into
buffer_var_in_scope_
+ // Step 2. Record relax position of ancestor_loops_
for (const Buffer& buffer : op->alloc_buffers) {
- buffer_var_in_scope_.emplace(buffer->data, std::make_pair(buffer,
ancestor_loops_.size()));
+ VisitBufferDef(buffer->data);
}
// Step 3. Visit match buffers
for (const MatchBufferRegion& region : op->match_buffers) {
@@ -248,37 +258,77 @@ class BufferAccessRegionCollector : public
StmtExprVisitor {
}
// Step 6. Update buffer_access_region_ from relaxed_accesses_ for inner
buffers.
for (const Buffer& buffer : op->alloc_buffers) {
- auto it = relaxed_accesses_.find(buffer);
- ICHECK(it != relaxed_accesses_.end())
- << buffer << " is allocated but not accessed within block scope";
- const NDIntSet& nd_int_set = it->second;
- buffer_access_region_[buffer] =
SimplifyAndNarrowBufferRegionFromNDIntSet(
- nd_int_set, buffer->shape, &dom_analyzer_, ancestor_loops_);
+ ICHECK_EQ(var2buffer_[buffer->data].size(), 1)
+ << "Block allocation buffer should not be aliased";
+ SimplifyAndNarrowBufferRegionFromNDIntSet(buffer);
}
}
void VisitStmt_(const BlockRealizeNode* op) final {
- PrimExpr cur_predicate = predicate_in_scope;
- predicate_in_scope = op->predicate;
+ With<ConditionalBoundsContext> ctx(op->predicate, &dom_map_, &hint_map_,
&pending_conditions_);
+ StmtExprVisitor::VisitStmt_(op);
+ }
+
+ void VisitStmt_(const AllocateNode* op) final {
+ auto it = var2buffer_.find(op->buffer_var);
+
+ // Do not perform compaction when the buffer definition and
+ // the allocation are not one-to-one with the same dtype.
+ if (it == var2buffer_.end() || it->second.size() > 1) {
+ return StmtExprVisitor::VisitStmt_(op);
+ }
+ const Buffer& buffer = *it->second.begin();
+ if (buffer->dtype != op->dtype) {
+ return StmtExprVisitor::VisitStmt_(op);
+ }
+
+ // Step 0. Record relax position of ancestor_loops_
+ VisitBufferDef(op->buffer_var);
+ // Step 1. Visit block body recursively
+ StmtExprVisitor::VisitStmt(op->body);
+ // Step 2. Update buffer_access_region_ from relaxed_accesses_ for inner
buffers.
+ SimplifyAndNarrowBufferRegionFromNDIntSet(buffer);
+ }
+
+ void VisitStmt_(const AttrStmtNode* op) final {
+ if (op->attr_key == attr::thread_extent || op->attr_key ==
attr::virtual_thread) {
+ IterVar iter = Downcast<IterVar>(op->node);
+ ancestor_iters_.push_back(iter);
+ Range dom = iter->dom;
+ if (!dom.defined()) { // dom is empty for legacy te schedule
+ dom = Range::FromMinExtent(0, op->value);
+ }
+ dom_analyzer_.Bind(iter->var, dom);
+ dom_map_.emplace(iter->var.get(), arith::IntSet::FromRange(dom));
+ StmtExprVisitor::VisitStmt_(op);
+ dom_map_.erase(iter->var.get());
+ ancestor_iters_.pop_back();
+ return;
+ }
StmtExprVisitor::VisitStmt_(op);
- predicate_in_scope = cur_predicate;
}
/**************** Helper functions ****************/
+ /*! \brief Record information on the buffer defining point. */
+ void VisitBufferDef(const Var& buffer_data) {
+ auto it = buffer_scope_depth_.find(buffer_data);
+ ICHECK(it == buffer_scope_depth_.end()) << buffer_data << " has duplicate
definitions";
+ buffer_scope_depth_.insert(it, {buffer_data, ancestor_iters_.size()});
+ }
+
void VisitBufferAccess(const BufferRegion& buffer_region) {
- const BufferNode* buffer = buffer_region->buffer.get();
- auto it = buffer_var_in_scope_.find(buffer->data);
- if (it != buffer_var_in_scope_.end()) {
- const Buffer& buffer = it->second.first;
- size_t n_ancestor_loops = it->second.second;
+ const Buffer& buffer = buffer_region->buffer;
+ auto it = buffer_scope_depth_.find(buffer->data);
+ if (it != buffer_scope_depth_.end()) {
+ size_t n_ancestor_loops = it->second;
// Step 1. Stop ancestor loop vars out of the allocation block from
// being relaxed unless NeedRelaxThread() is true.
std::vector<arith::IntSet> non_relaxed(n_ancestor_loops);
for (size_t i = 0; i < n_ancestor_loops; ++i) {
- const ForNode* loop = ancestor_loops_[i];
- const VarNode* v = loop->loop_var.get();
- if (NeedRelaxThread(GetRef<For>(loop),
runtime::StorageScope::Create(buffer.scope()))) {
+ const IterVar& iter = ancestor_iters_[i];
+ const VarNode* v = iter->var.get();
+ if (NeedRelaxThread(iter,
runtime::StorageScope::Create(buffer.scope()))) {
continue;
}
auto dom_it = dom_map_.find(v);
@@ -288,12 +338,21 @@ class BufferAccessRegionCollector : public
StmtExprVisitor {
dom_map_.erase(dom_it);
}
// Step 2. Relax the access region
+ auto normalize_pred = [](const PrimExpr& pred) {
+ if (pred->dtype.is_bool()) return pred;
+ return pred != make_zero(pred->dtype);
+ };
+ PrimExpr predicate = dom_analyzer_.Simplify(
+ std::accumulate(pending_conditions_.begin(),
pending_conditions_.end(), const_true(),
+ [normalize_pred](const PrimExpr& x, const PrimExpr&
y) {
+ return normalize_pred(x) && normalize_pred(y);
+ }));
NDIntSet nd_int_set =
- NDIntSetEval(buffer_region->region, predicate_in_scope, dom_map_,
&dom_analyzer_);
+ NDIntSetEval(buffer_region->region, predicate, dom_map_,
&dom_analyzer_);
// Step 3. Restore the non-relaxed ancestor loops domain
for (size_t i = 0; i < n_ancestor_loops; ++i) {
- const VarNode* v = ancestor_loops_[i]->loop_var.get();
+ const VarNode* v = ancestor_iters_[i]->var.get();
dom_map_.emplace(v, non_relaxed[i]);
}
// Step 4. Update relaxed_accesses_ dict
@@ -307,9 +366,11 @@ class BufferAccessRegionCollector : public StmtExprVisitor
{
}
void VisitBufferVar(const Var& var) {
- auto it = buffer_var_in_scope_.find(var);
- if (it != buffer_var_in_scope_.end()) {
- const Buffer& buffer = it->second.first;
+ auto it = var2buffer_.find(var);
+ if (it == var2buffer_.end()) {
+ return;
+ }
+ for (const Buffer& buffer : it->second) {
auto annotation_it = access_annotations_.find(buffer);
if (annotation_it != access_annotations_.end()) {
// opaque buffer has explicit accessed region annotations
@@ -322,92 +383,135 @@ class BufferAccessRegionCollector : public
StmtExprVisitor {
}
}
- /*! \brief Check whether the thread binding loop should be relaxed with
given storage scope. */
- static bool NeedRelaxThread(const For& loop, const runtime::StorageScope&
scope) {
- if (loop->kind != ForKind::kThreadBinding) {
+ /*! \brief Check whether the thread binding iter should be relaxed with
given storage scope. */
+ static bool NeedRelaxThread(const IterVar& iter, const
runtime::StorageScope& scope) {
+ if (iter->iter_type != IterVarType::kThreadIndex) {
return false;
}
- ICHECK(loop->thread_binding.defined());
- const String& thread_tag = loop->thread_binding.value()->thread_tag;
+ ICHECK(iter->thread_tag.defined());
// When there is warp memory
// threadIdx.x must be set to be warp index.
- return CanRelaxStorageUnderThread(scope,
runtime::ThreadScope::Create(thread_tag));
+ return CanRelaxStorageUnderThread(scope,
runtime::ThreadScope::Create((iter->thread_tag)));
+ }
+
+ /*!
+ * \brief simplify and narrow down the region collected by NDIntSet.
+ * Update the `relaxed_accesses_` dict. If `collect_inbound_` is true,
+ * the result region would never exceed the original buffer shape.
+ */
+ void SimplifyAndNarrowBufferRegionFromNDIntSet(const Buffer& buffer) {
+ auto it = relaxed_accesses_.find(buffer);
+ ICHECK(it != relaxed_accesses_.end())
+ << buffer << " is allocated but not accessed within block scope";
+
+ const Array<PrimExpr>& original_shape = buffer->shape;
+ const NDIntSet& nd_int_set = it->second;
+ Array<Range>& result_region = buffer_access_region_[buffer];
+ result_region.resize(nd_int_set.size());
+
+ for (size_t i = 0; i < nd_int_set.size(); ++i) {
+ const arith::IntSet& int_set = nd_int_set[i];
+ Range original =
+ Range(/*begin=*/make_zero(original_shape[i]->dtype),
/*end=*/original_shape[i]);
+ Range range = int_set.CoverRange(original);
+ PrimExpr min, extent;
+ if (collect_inbound_) {
+ min = dom_analyzer_.Simplify(tvm::max(0, range->min));
+ extent = range->extent;
+ // Apply stronger symbolic proof to help us remove symbolic min here.
+ if (!dom_analyzer_.CanProveLessEqualThanSymbolicShapeValue(extent,
original_shape[i])) {
+ extent = tvm::min(original_shape[i], range->extent);
+ }
+ extent = dom_analyzer_.Simplify(extent);
+ } else {
+ min = dom_analyzer_.Simplify(range->min);
+ extent = dom_analyzer_.Simplify(range->extent);
+ }
+
+ // We check the buffer extent is pure and not loop dependent, since loop
dependent
+ // or data dependent allocation is not supported yet. Otherwise we should
+ // fallback to use original buffer shape.
+ if (SideEffect(extent) > CallEffectKind::kPure) {
+ result_region.Set(i, original);
+ continue;
+ }
+ auto is_loop_var = [this](const VarNode* v) {
+ return std::any_of(ancestor_iters_.begin(), ancestor_iters_.end(),
+ [v](const IterVar& n) { return n->var.get() == v;
});
+ };
+ if (UsesVar(extent, is_loop_var)) {
+ // try estimate a constant upperbound on region's extent
+ int64_t upperbound = dom_analyzer_.const_int_bound(extent)->max_value;
+ if (upperbound != arith::ConstIntBound::kPosInf) {
+ extent = make_const(extent->dtype, upperbound);
+ } else {
+ result_region.Set(i, original);
+ continue;
+ }
+ }
+ result_region.Set(i, Range::FromMinExtent(min, extent));
+ }
}
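The loop-dependent fallback in the function above can be illustrated in pure
Python (a sketch under the strict setting; `extent_fn` is a hypothetical
stand-in for the symbolic per-iteration extent):

```python
def narrow_extent(extent_fn, loop_range, original_extent):
    """If the per-dimension extent depends on a loop variable, try a constant
    upper bound over the loop domain; otherwise fall back to the original
    extent, since loop-dependent allocation is not supported."""
    bounds = [extent_fn(i) for i in loop_range]
    if not bounds:
        return original_extent  # no bound information: keep the full region
    # Strict mode: the compacted extent never exceeds the original one.
    return min(max(bounds), original_extent)

print(narrow_extent(lambda i: i + 1, range(8), 16))  # 8
```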
/**************** Class members ****************/
- /*! \brief The loops from the current node up to the root. */
- std::vector<const ForNode*> ancestor_loops_;
+ /*! \brief Only collect accessed region within original buffer shape bound.
*/
+ bool collect_inbound_{true};
+
+ /*! \brief The iteration scopes from the current node up to the root. */
+ std::vector<IterVar> ancestor_iters_;
/*!
- * \brief The vars of the buffer allocated under the current block.
- * Map each buffer var to (buffer_obj, n_ancester_loop) pair, where
- * n_ancester_loop is the loop num out of the current block.
- * Tancestor_loops_[0: n_ancester_loop] should not be relaxed when
+ * \brief Map each buffer var to its n_ancester_loop, which is the loop
depth at the
+ * definition point. ancestor_loops_[0: n_ancester_loop] should not be
relaxed when
* we evaluate this buffer's access regions.
*/
- std::unordered_map<Var, std::pair<Buffer, size_t>, ObjectPtrHash,
ObjectPtrEqual>
- buffer_var_in_scope_;
- /*! \brief The block predicate of current scope */
- PrimExpr predicate_in_scope{true};
+ std::unordered_map<Var, size_t, ObjectPtrHash, ObjectPtrEqual>
buffer_scope_depth_;
+
+ /*! \brief Map the buffer var to all aliased buffers. */
+ std::unordered_map<Var, std::unordered_set<Buffer, ObjectPtrHash,
ObjectPtrEqual>, ObjectPtrHash,
+ ObjectPtrEqual>
+ var2buffer_;
/*! \brief The map from loop vars to their iter range. */
std::unordered_map<const VarNode*, arith::IntSet> dom_map_;
/*! \brief Extra map from free vars to their iter range hints. */
std::unordered_map<const VarNode*, arith::IntSet> hint_map_;
+ /*! \brief Unresolved conditions within current scope. */
+ std::vector<PrimExpr> pending_conditions_;
/*! \brief The analyzer aware of loop domains. */
arith::Analyzer dom_analyzer_;
/*! \brief The map from Buffer to it's relaxed access set. */
std::unordered_map<Buffer, NDIntSet, ObjectPtrHash, ObjectPtrEqual>
relaxed_accesses_;
- /*! \brief The map from Buffer to it entire access region, used for
returning. */
+
+ /*!
+ * \brief The map from Buffer to its entire access region, used for returning.
+ * The entire access region should get updated at the buffer's definition point
+ * and we sanity check that every buffer is defined only once.
+ */
std::unordered_map<Buffer, Region, ObjectPtrHash, ObjectPtrEqual>
buffer_access_region_;
+
/*! \brief The map from Buffer to it's access regions annotated by current
block. */
std::unordered_map<Buffer, std::vector<BufferRegion>, ObjectPtrHash,
ObjectPtrEqual>
access_annotations_;
};
-/*! \brief Collect storage alignment information from block annotations. */
-class StorageAlignCollector : public StmtVisitor {
- public:
- static std::unordered_map<Buffer, StorageAlignAnnotation, ObjectPtrHash,
ObjectPtrEqual> Collect(
- const PrimFunc& f) {
- StorageAlignCollector collector;
- collector(f->body);
- return std::move(collector.storage_align_);
- }
-
- private:
- void VisitStmt_(const BlockNode* op) final {
- auto it = op->annotations.find(attr::buffer_dim_align);
- if (it != op->annotations.end()) {
- auto storage_align_annotation =
Downcast<StorageAlignAnnotation>((*it).second);
- for (const auto& storage_align_tuple : storage_align_annotation) {
- int buffer_index = storage_align_tuple[0]->value;
- const Buffer& buffer = op->writes[buffer_index]->buffer;
- storage_align_[buffer].push_back(storage_align_tuple);
- }
- }
- StmtVisitor::VisitStmt_(op);
- }
-
- /*! \brief The map from Buffer to its storage alignment information. */
- std::unordered_map<Buffer, StorageAlignAnnotation, ObjectPtrHash,
ObjectPtrEqual> storage_align_;
-};
-
/*! \brief Reallocate the buffers with minimal region. */
class BufferCompactor : public StmtExprMutator {
public:
static Stmt Compact(
const PrimFunc& f,
const std::unordered_map<Buffer, Region, ObjectPtrHash, ObjectPtrEqual>&
regions,
- const std::unordered_map<Buffer, StorageAlignAnnotation, ObjectPtrHash,
ObjectPtrEqual>&
+ const std::unordered_map<Var, StorageAlignAnnotation, ObjectPtrHash,
ObjectPtrEqual>&
storage_align) {
- std::unordered_map<Buffer, BufferAllocInfo, ObjectPtrHash, ObjectPtrEqual>
buffer_info;
-
+ // collect buffer allocation info for no-alias buffers
+ std::unordered_map<Var, BufferAllocInfo, ObjectPtrHash, ObjectPtrEqual>
buffer_info;
for (const auto& kv : regions) {
const Buffer& buffer = kv.first;
+
+ // set dim alignment info
Region region = kv.second;
- BufferAllocInfo buffer_alloc_info(std::move(region));
- auto it = storage_align.find(buffer);
+ BufferAllocInfo alloc_info;
+ auto it = storage_align.find(buffer->data);
if (it != storage_align.end()) {
std::vector<DimAlignInfo> dim_aligns(buffer->shape.size());
for (const StorageAlignTuple& dim_align : (*it).second) {
@@ -417,9 +521,33 @@ class BufferCompactor : public StmtExprMutator {
int offset = dim_align[3]->value;
dim_aligns.at(dim) = {factor, offset};
}
- buffer_alloc_info.dim_aligns = std::move(dim_aligns);
+ alloc_info.dim_aligns = std::move(dim_aligns);
+ }
+
+ // prepare new buffer
+ Array<PrimExpr> shape = region.Map([](const Range& range) { return
range->extent; });
+ Array<PrimExpr> strides;
+ if (alloc_info.dim_aligns.size()) {
+ ICHECK(alloc_info.dim_aligns.size() == shape.size());
+ strides.resize(shape.size());
+ PrimExpr stride = make_const(shape[0].dtype(), 1);
+ for (size_t i = shape.size(); i != 0; --i) {
+ size_t dim = i - 1;
+ if (alloc_info.dim_aligns[dim].align_factor != 0) {
+ PrimExpr factor = make_const(stride.dtype(),
alloc_info.dim_aligns[dim].align_factor);
+ PrimExpr offset = make_const(stride.dtype(),
alloc_info.dim_aligns[dim].align_offset);
+ stride = stride + indexmod(factor + offset - indexmod(stride,
factor), factor);
+ }
+ strides.Set(dim, stride);
+ stride = stride * shape[dim];
+ }
}
- buffer_info.emplace(buffer, std::move(buffer_alloc_info));
+ ObjectPtr<BufferNode> n = make_object<BufferNode>(*buffer.get());
+ n->shape = std::move(shape);
+ n->strides = std::move(strides);
+ alloc_info.new_buffer = Buffer(std::move(n));
+ alloc_info.region = region;
+ buffer_info.emplace(buffer->data, std::move(alloc_info));
}
BufferCompactor compactor(std::move(buffer_info));
Stmt stmt = compactor(f->body);
@@ -445,12 +573,10 @@ class BufferCompactor : public StmtExprMutator {
* \note The value if NullOpt if the buffer do not need reallocate (e.g
parameter buffer).
*/
Buffer new_buffer;
-
- explicit BufferAllocInfo(Region region) : region(std::move(region)) {}
};
explicit BufferCompactor(
- std::unordered_map<Buffer, BufferAllocInfo, ObjectPtrHash,
ObjectPtrEqual> buffer_info)
+ std::unordered_map<Var, BufferAllocInfo, ObjectPtrHash, ObjectPtrEqual>
buffer_info)
: buffer_info_(std::move(buffer_info)) {}
Stmt VisitStmt_(const BufferStoreNode* _op) final {
@@ -471,7 +597,8 @@ class BufferCompactor : public StmtExprMutator {
// Step 0. Check there is no Init part.
ICHECK(!op->init.defined());
    // Step 1. Reallocate and rewrite alloc_buffers, also update BufferAllocInfo.
- Array<Buffer> alloc_buffers = RewriteAllocBuffer(op->alloc_buffers);
+    Array<Buffer> alloc_buffers =
+        op->alloc_buffers.Map([this](const Buffer& buf) { return RewriteAllocBuffer(buf); });
// Step 2. Recursively rewrite BufferLoad/BufferStore.
Block block = Downcast<Block>(StmtExprMutator::VisitStmt_(op));
// Step 3. Update block signature.
@@ -483,47 +610,45 @@ class BufferCompactor : public StmtExprMutator {
return std::move(block);
}
- Array<Buffer> RewriteAllocBuffer(const Array<Buffer>& buffers) {
- Array<Buffer> result;
- result.reserve(buffers.size());
- for (const Buffer& buffer : buffers) {
- auto it = buffer_info_.find(buffer);
- ICHECK(it != buffer_info_.end());
- BufferAllocInfo& info = it->second;
- Array<PrimExpr> shape;
- shape.reserve(info.region.size());
- for (const Range& range : info.region) {
- shape.push_back(range->extent);
- }
- Array<PrimExpr> strides;
- if (info.dim_aligns.size()) {
- ICHECK(info.dim_aligns.size() == shape.size());
- strides.resize(shape.size());
- PrimExpr stride = make_const(shape[0].dtype(), 1);
- for (size_t i = shape.size(); i != 0; --i) {
- size_t dim = i - 1;
- if (info.dim_aligns[dim].align_factor != 0) {
-          PrimExpr factor = make_const(stride.dtype(), info.dim_aligns[dim].align_factor);
-          PrimExpr offset = make_const(stride.dtype(), info.dim_aligns[dim].align_offset);
-          stride = stride + indexmod(factor + offset - indexmod(stride, factor), factor);
- }
- strides.Set(dim, stride);
- stride = stride * shape[dim];
- }
- }
- ObjectPtr<BufferNode> n = make_object<BufferNode>(*buffer.get());
- n->shape = std::move(shape);
- n->strides = std::move(strides);
- info.new_buffer = Buffer(std::move(n));
- result.push_back(info.new_buffer);
+ Stmt VisitStmt_(const DeclBufferNode* op) final {
+ Buffer new_buffer = RewriteAllocBuffer(op->buffer);
+ auto n = CopyOnWrite(op);
+ n->buffer = std::move(new_buffer);
+ n->body = VisitStmt(op->body);
+ return DeclBuffer(n);
+ }
+
+ Stmt VisitStmt_(const AllocateNode* op) final {
+ Allocate allocate = Downcast<Allocate>(StmtExprMutator::VisitStmt_(op));
+ auto it = buffer_info_.find(allocate->buffer_var);
+ if (it == buffer_info_.end()) {
+ return std::move(allocate);
}
- return result;
+    // Rewrite the allocation shape if the corresponding buffer is in the buffer_info_
+    // dict and the dtype is consistent, which denotes there is no buffer aliasing
+    // and the compaction is safe.
+ const Buffer& new_buffer = it->second.new_buffer;
+ if (op->dtype != new_buffer->dtype) {
+ return std::move(allocate);
+ }
+ Array<PrimExpr> new_shape = GetBufferAllocationShape(new_buffer);
+ auto n = allocate.CopyOnWrite();
+ ICHECK(n->buffer_var.same_as(new_buffer->data));
+ n->extents = new_shape;
+ return std::move(allocate);
+ }
+
+ Buffer RewriteAllocBuffer(const Buffer& buffer) {
+ auto it = buffer_info_.find(buffer->data);
+ if (it != buffer_info_.end()) {
+ return it->second.new_buffer;
+ }
+ return buffer;
}
void RewriteBufferAccess(Buffer* buffer, Array<PrimExpr>* indices) const {
- auto it = buffer_info_.find(*buffer);
+ auto it = buffer_info_.find((*buffer)->data);
if (it == buffer_info_.end()) {
- // Skip if the buffer is parameter
return;
}
const BufferAllocInfo& info = it->second;
@@ -539,7 +664,7 @@ class BufferCompactor : public StmtExprMutator {
}
void RewriteBufferRegion(Buffer* buffer, Region* region) const {
- auto it = buffer_info_.find(*buffer);
+ auto it = buffer_info_.find((*buffer)->data);
if (it == buffer_info_.end()) {
// Skip if the buffer is parameter
return;
@@ -580,18 +705,16 @@ class BufferCompactor : public StmtExprMutator {
*match_buffers = std::move(result);
}
-  /*! \brief The allocation information about each buffer. */
-  std::unordered_map<Buffer, BufferAllocInfo, ObjectPtrHash, ObjectPtrEqual> buffer_info_;
+  /*! \brief Map buffer var to the allocation information about each buffer. */
+  std::unordered_map<Var, BufferAllocInfo, ObjectPtrHash, ObjectPtrEqual> buffer_info_;
};
-PrimFunc CompactBufferAllocation(PrimFunc f) {
+PrimFunc CompactBufferAllocation(PrimFunc f, bool is_strict) {
// Only apply this pass to TIR that is not from TE schedules
if (!IsFromLegacyTESchedule(f)) {
PrimFuncNode* fptr = f.CopyOnWrite();
-    std::unordered_map<Buffer, Region, ObjectPtrHash, ObjectPtrEqual> region =
-        BufferAccessRegionCollector::Collect(f);
-    std::unordered_map<Buffer, StorageAlignAnnotation, ObjectPtrHash, ObjectPtrEqual>
-        storage_align = StorageAlignCollector::Collect(f);
+    auto region = BufferAccessRegionCollector::Collect(f, /*collect_inbound=*/is_strict);
+    auto storage_align = CollectStorageAlignAnnotation(f->body);
fptr->body = BufferCompactor::Compact(f, region, storage_align);
return f;
} else {
@@ -601,9 +724,9 @@ PrimFunc CompactBufferAllocation(PrimFunc f) {
namespace transform {
-Pass CompactBufferAllocation() {
+Pass CompactBufferAllocation(bool is_strict) {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
- return CompactBufferAllocation(std::move(f));
+ return CompactBufferAllocation(std::move(f), is_strict);
};
return CreatePrimFuncPass(pass_func, 0, "tir.CompactBufferAllocation", {});
}
diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc
index e80772fda3..b591016fd8 100644
--- a/src/tir/transforms/ir_utils.cc
+++ b/src/tir/transforms/ir_utils.cc
@@ -75,6 +75,11 @@ Stmt MergeNest(const std::vector<Stmt>& nest, Stmt body) {
ICHECK(is_no_op(n->body));
n->body = body;
body = Stmt(n);
+ } else if (const auto* decl_buffer = s.as<DeclBufferNode>()) {
+ auto n = make_object<DeclBufferNode>(*decl_buffer);
+ ICHECK(is_no_op(n->body));
+ n->body = body;
+ body = Stmt(n);
} else {
LOG(FATAL) << "not supported nest type";
}
@@ -387,6 +392,18 @@ String GetPtrStorageScope(Var buffer_var) {
return ptr_type->storage_scope;
}
+Array<PrimExpr> GetBufferAllocationShape(const Buffer& buffer) {
+ Array<PrimExpr> alloc_shape = buffer->shape;
+ if (buffer->strides.size()) {
+ ICHECK_EQ(buffer->shape.size(), buffer->strides.size());
+ for (size_t i = buffer->strides.size() - 1; i > 0; --i) {
+ ICHECK(is_zero(floormod(buffer->strides[i - 1], buffer->strides[i])));
+ alloc_shape.Set(i, buffer->strides[i - 1] / buffer->strides[i]);
+ }
+ }
+ return alloc_shape;
+}
+
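The stride-to-shape recovery in `GetBufferAllocationShape` can be illustrated with a small Python stand-in (not the TVM API): each inner dimension of the allocation is the ratio of adjacent strides, while the outermost dimension keeps the logical extent.

```python
def allocation_shape(shape, strides):
    # No strides recorded: the buffer is compact and the logical shape
    # is already the allocation shape.
    if not strides:
        return list(shape)
    assert len(shape) == len(strides)
    alloc = list(shape)
    # Inner dimensions are recovered as ratios of adjacent strides.
    for i in range(len(strides) - 1, 0, -1):
        assert strides[i - 1] % strides[i] == 0
        alloc[i] = strides[i - 1] // strides[i]
    return alloc

# A 16x16 buffer with a padded row stride of 129 is allocated as 16x129.
print(allocation_shape([16, 16], [129, 1]))  # [16, 129]
```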
Array<PrimExpr> ConvertIndices(const MatchBufferRegion& match_buffer,
const Array<PrimExpr>& indices) {
const Buffer& target = match_buffer->buffer;
@@ -438,11 +455,14 @@ Bool IsFromLegacyTESchedule(PrimFunc f) {
return from_legacy_te_schedule.value();
}
-Map<Var, Range> ConditionalBoundsContext::GetVarBoundsFromCondition() {
+Optional<arith::IntConstraints> ConditionalBoundsContext::TrySolveCondition() {
// extract equations and related vars from condition expression.
// currently only extract simple integral equations which could be solvable.
arith::Analyzer analyzer;
-  PrimExpr condition = is_true_branch_ ? condition_ : analyzer.Simplify(!condition_);
+ PrimExpr condition = analyzer.Simplify(condition_);
+ if (is_const_int(condition)) {
+ return NullOpt;
+ }
Array<PrimExpr> equations;
Array<Var> vars;
  std::function<void(const PrimExpr&)> fvisit = [&equations, &vars, &fvisit](const PrimExpr& e) {
@@ -485,7 +505,7 @@ Map<Var, Range> ConditionalBoundsContext::GetVarBoundsFromCondition() {
};
fvisit(condition);
if (equations.empty() || vars.empty()) {
- return Map<Var, Range>();
+ return NullOpt;
}
// build dom ranges for related vars
Map<Var, Range> ranges;
@@ -506,22 +526,35 @@ Map<Var, Range> ConditionalBoundsContext::GetVarBoundsFromCondition() {
}
// solve constraints
arith::IntConstraints constraint(vars, ranges, equations);
- auto result = arith::SolveInequalitiesToRange(constraint);
- return result->ranges;
+ arith::IntConstraints result = arith::SolveInequalitiesToRange(constraint);
+ return std::move(result);
}
ConditionalBoundsContext::ConditionalBoundsContext(
    const PrimExpr& condition, std::unordered_map<const VarNode*, arith::IntSet>* relax_map,
-    std::unordered_map<const VarNode*, arith::IntSet>* hint_map, bool is_true_branch)
+    std::unordered_map<const VarNode*, arith::IntSet>* hint_map,
+    std::vector<PrimExpr>* pending_conditions)
: condition_(condition),
relax_map_(relax_map),
hint_map_(hint_map),
- is_true_branch_(is_true_branch) {}
+ pending_conditions_(pending_conditions),
+ origin_pending_conditions_num_(pending_conditions->size()) {}
void ConditionalBoundsContext::EnterWithScope() {
- for (const auto& p : GetVarBoundsFromCondition()) {
- const auto* var = p.first.get();
- arith::IntSet new_dom = arith::IntSet::FromRange(p.second);
+ Optional<arith::IntConstraints> constraints = TrySolveCondition();
+ if (!constraints.defined()) {
+ // fail to process the condition, add to unresolved
+ pending_conditions_->push_back(condition_);
+ return;
+ }
+ for (const PrimExpr& unresolved : constraints.value()->relations) {
+ // add partially unresolved conditions
+ pending_conditions_->push_back(unresolved);
+ }
+ // update solved var ranges
+ for (const auto& kv : constraints.value()->ranges) {
+ const VarNode* var = kv.first.get();
+ arith::IntSet new_dom = arith::IntSet::FromRange(kv.second);
auto relax_it = relax_map_->find(var);
if (relax_it != relax_map_->end()) {
// this is a bound for relaxed var
@@ -542,6 +575,7 @@ void ConditionalBoundsContext::EnterWithScope() {
}
void ConditionalBoundsContext::ExitWithScope() {
+ pending_conditions_->resize(origin_pending_conditions_num_);
for (const auto& p : origin_map_) {
const auto* var = p.first;
auto relax_it = relax_map_->find(var);
@@ -568,6 +602,53 @@ std::pair<PrimExpr, PrimExpr> GetAsyncWaitAttributes(const AttrStmtNode* op) {
return std::make_pair(op->value, inner->value);
}
+/*! \brief Collect storage alignment information from annotations. */
+class StorageAlignCollector : public StmtVisitor {
+ private:
+  friend std::unordered_map<Var, StorageAlignAnnotation, ObjectPtrHash, ObjectPtrEqual>
+  CollectStorageAlignAnnotation(const Stmt& body);
+
+  /*! \brief For s-tir, the alignment annotations reside in block annotations. */
+ void VisitStmt_(const BlockNode* op) final {
+ auto it = op->annotations.find(attr::buffer_dim_align);
+ if (it != op->annotations.end()) {
+      auto storage_align_annotation = Downcast<StorageAlignAnnotation>((*it).second);
+ for (const auto& storage_align_tuple : storage_align_annotation) {
+ int buffer_index = storage_align_tuple[0]->value;
+ const Buffer& buffer = op->writes[buffer_index]->buffer;
+ storage_align_[buffer->data].push_back(storage_align_tuple);
+ }
+ }
+ StmtVisitor::VisitStmt_(op);
+ }
+
+  /*! \brief For lowered tir, the alignment annotations reside in allocate annotations. */
+ void VisitStmt_(const AllocateNode* op) final {
+ auto it = op->annotations.find(attr::buffer_dim_align);
+ if (it != op->annotations.end()) {
+      auto storage_align_annotation = Downcast<StorageAlignAnnotation>((*it).second);
+ for (const auto& storage_align_tuple : storage_align_annotation) {
+ int buffer_index = storage_align_tuple[0]->value;
+        // the first entry (the buffer index) is meaningless for an allocate
+        // stmt and is intentionally set to -1.
+        ICHECK_EQ(buffer_index, -1);
+ storage_align_[op->buffer_var].push_back(storage_align_tuple);
+ }
+ }
+ StmtVisitor::VisitStmt_(op);
+ }
+
+ /*! \brief The map from buffer var to its storage alignment information. */
+  std::unordered_map<Var, StorageAlignAnnotation, ObjectPtrHash, ObjectPtrEqual> storage_align_;
+};
+
+std::unordered_map<Var, StorageAlignAnnotation, ObjectPtrHash, ObjectPtrEqual>
+CollectStorageAlignAnnotation(const Stmt& body) {
+ StorageAlignCollector collector;
+ collector(body);
+ return std::move(collector.storage_align_);
+}
+
namespace transform {
Pass ConvertSSA() {
auto pass_func = [](IRModule mod, PassContext ctx) {
diff --git a/src/tir/transforms/ir_utils.h b/src/tir/transforms/ir_utils.h
index 6915a0e3ac..afaff34472 100644
--- a/src/tir/transforms/ir_utils.h
+++ b/src/tir/transforms/ir_utils.h
@@ -25,6 +25,7 @@
#define TVM_TIR_TRANSFORMS_IR_UTILS_H_
#include <tvm/arith/int_set.h>
+#include <tvm/arith/int_solver.h>
#include <tvm/runtime/device_api.h>
#include <tvm/support/with.h>
#include <tvm/tir/builtin.h>
@@ -224,6 +225,13 @@ Array<PrimExpr> ConvertIndices(const MatchBufferRegion& match_buffer,
 */
Region ConvertRegion(const MatchBufferRegion& match_buffer, const Region& region);
+/*!
+ * \brief Get the stride-aware buffer allocation shape from a buffer.
+ * \param buffer The buffer object.
+ * \return The allocation shape considering buffer strides.
+ */
+Array<PrimExpr> GetBufferAllocationShape(const Buffer& buffer);
+
/*!
* \brief Check if a given PrimFunc originated from a TE schedule.
*
@@ -235,12 +243,12 @@ Region ConvertRegion(const MatchBufferRegion& match_buffer, const Region& region
Bool IsFromLegacyTESchedule(PrimFunc f);
/*!
- *\brief Context helper to update domain map within conditional scope.
- *
- * Assume the condition is `0 <= i && i < 9` and global domain of i is [0, 20], thus `bounds[i]` is
- * [0, 8]. Then `With<ConditionalBoundsContext> ctx(condition, &relax_map, &hint_map, true)` step
- *into scope where dom_map[i] is [0, 8] and `With<ConditionalBoundsContext> ctx(condition,
- *&relax_map, &hint_map, false)` step into scope where dom_map[i] is [9, 20]
+ * \brief Context helper to update domain map within conditional scope.
+ * Assume the condition is `0 <= i && i < 9` and the domain of i is [0, 20]. Then
+ * `With<ConditionalBoundsContext> ctx(condition, &relax_map, &hint_map, &constraints)`
+ * steps into a scope where dom_map[i] is [0, 8]; and
+ * `With<ConditionalBoundsContext> ctx(!condition, &relax_map, &hint_map, &constraints)`
+ * steps into a scope where dom_map[i] is [9, 20].
*/
class ConditionalBoundsContext {
private:
@@ -250,17 +258,17 @@ class ConditionalBoundsContext {
* \param condition The condition holds on true branch.
* \param relax_map The domain map for relaxed vars to update.
* \param hint_map The domain map for free vars to update.
-   * \param is_true_branch Whether step into the branch where condition bounds holds.
+ * \param pending_conditions The stack of unresolved constraints.
*/
  ConditionalBoundsContext(const PrimExpr& condition,
                           std::unordered_map<const VarNode*, arith::IntSet>* relax_map,
                           std::unordered_map<const VarNode*, arith::IntSet>* hint_map,
-                          bool is_true_branch);
+                          std::vector<PrimExpr>* pending_constraints);
void EnterWithScope();
void ExitWithScope();
  /*! \brief Helper to solve related variable's bound within conditional scope.*/
- Map<Var, Range> GetVarBoundsFromCondition();
+ Optional<arith::IntConstraints> TrySolveCondition();
/*! \brief the condition holds on true branch. */
const PrimExpr& condition_;
@@ -268,10 +276,12 @@ class ConditionalBoundsContext {
std::unordered_map<const VarNode*, arith::IntSet>* relax_map_;
/*! \brief domain map for free vars to update */
std::unordered_map<const VarNode*, arith::IntSet>* hint_map_;
- /*! \brief whether is on true branch */
- bool is_true_branch_;
+ /*! \brief unresolved condition stack */
+ std::vector<PrimExpr>* pending_conditions_;
/*! \brief used to record and restore original var bounds */
std::unordered_map<const VarNode*, arith::IntSet> origin_map_;
+ /*! \brief used to record unresolved conditions num. */
+ size_t origin_pending_conditions_num_;
};
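The domain-splitting behaviour documented above can be mimicked with plain interval arithmetic; the sketch below is hypothetical, using inclusive `(lo, hi)` pairs rather than TVM's `IntSet`, and shows how `0 <= i && i < 9` over domain [0, 20] yields [0, 8] on the true branch and [9, 20] on the false branch:

```python
def split_domain(dom, cond):
    """dom: inclusive (lo, hi) domain of i.
    cond: inclusive (lo, hi) range on which the condition holds."""
    lo, hi = dom
    clo, chi = cond
    true_dom = (max(lo, clo), min(hi, chi))
    # The complement is a single interval only when the condition range
    # touches one end of the domain (the common if/else case); otherwise
    # it would stay as an unresolved (pending) condition.
    if clo <= lo:
        false_dom = (chi + 1, hi)
    elif chi >= hi:
        false_dom = (lo, clo - 1)
    else:
        false_dom = None
    return true_dom, false_dom

print(split_domain((0, 20), (0, 8)))  # ((0, 8), (9, 20))
```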
// Information of tensor core fragment.
@@ -321,6 +331,18 @@ std::pair<PrimExpr, PrimExpr> GetAsyncWaitAttributes(const AttrStmtNode* op);
*/
PrimFunc BindParams(PrimFunc f, const Array<runtime::NDArray>& constants);
+/*! \brief The quad used by StorageAlign for (buffer_idx, axis, factor, offset) */
+using StorageAlignTuple = Array<Integer>;
+/*! \brief A list of StorageAlignTuple, used by StorageAlign */
+using StorageAlignAnnotation = Array<StorageAlignTuple>;
+/*!
+ * \brief Collect storage alignment annotations for all buffer vars within body.
+ * \param body The stmt to collect.
+ * \return The result dict from buffer var to storage align annotations.
+ */
+std::unordered_map<Var, StorageAlignAnnotation, ObjectPtrHash, ObjectPtrEqual>
+CollectStorageAlignAnnotation(const Stmt& body);
+
} // namespace tir
} // namespace tvm
#endif // TVM_TIR_TRANSFORMS_IR_UTILS_H_
diff --git a/src/tir/transforms/lower_opaque_block.cc b/src/tir/transforms/lower_opaque_block.cc
index 9a702db69f..86892433b4 100644
--- a/src/tir/transforms/lower_opaque_block.cc
+++ b/src/tir/transforms/lower_opaque_block.cc
@@ -33,6 +33,13 @@ namespace tir {
* \brief Remove Block to ensure that the TIR can not be scheduled again.
*/
class OpaqueBlockLower : public StmtExprMutator {
+ public:
+ static Stmt Rewrite(Stmt body) {
+ OpaqueBlockLower lower;
+ lower.storage_align_ = CollectStorageAlignAnnotation(body);
+ return lower(std::move(body));
+ }
+
private:
Stmt VisitStmt_(const BlockRealizeNode* op) final {
    // We have converted blocks into opaque blocks in previous passes.
@@ -49,16 +56,22 @@ class OpaqueBlockLower : public StmtExprMutator {
// Step 3. Handle allocations in reverse order
for (size_t i = new_block->alloc_buffers.size(); i > 0; --i) {
const Buffer& buffer = new_block->alloc_buffers[i - 1];
- Array<PrimExpr> new_shape = buffer->shape;
- if (buffer->strides.size()) {
- ICHECK_EQ(buffer->shape.size(), buffer->strides.size());
- for (size_t i = buffer->strides.size() - 1; i > 0; --i) {
-        ICHECK(is_zero(floormod(buffer->strides[i - 1], buffer->strides[i])));
- new_shape.Set(i, buffer->strides[i - 1] / buffer->strides[i]);
+ Array<PrimExpr> allocation_shape = GetBufferAllocationShape(buffer);
+ body = DeclBuffer(buffer, std::move(body));
+ Map<String, ObjectRef> allocate_annotations;
+ auto it = storage_align_.find(buffer->data);
+ if (it != storage_align_.end()) {
+ StorageAlignAnnotation allocate_aligns;
+ for (auto tuple : it->second) {
+ ICHECK_EQ(tuple.size(), 4);
+ tuple.Set(0, -1);
+ allocate_aligns.push_back(tuple);
}
+ allocate_annotations.Set(attr::buffer_dim_align, allocate_aligns);
}
- body = DeclBuffer(buffer, std::move(body));
-      body = Allocate(buffer->data, buffer->dtype, new_shape, const_true(), std::move(body));
+
+      body = Allocate(buffer->data, buffer->dtype, allocation_shape, const_true(), std::move(body),
+                      allocate_annotations);
}
    // Step 4. Handle annotations, block annotations are not preserved by default.
std::vector<std::pair<std::string, PrimExpr>> pragma_attrs;
@@ -181,13 +194,16 @@ class OpaqueBlockLower : public StmtExprMutator {
/*! \brief Attr keys to preserve into loop annotations. */
std::unordered_set<std::string> preserved_annotations_;
+
+ /*! \brief The map from buffer var to its storage alignment information. */
+  std::unordered_map<Var, StorageAlignAnnotation, ObjectPtrHash, ObjectPtrEqual> storage_align_;
};
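The annotation rewrite above, which moves `buffer_dim_align` tuples from a block onto its Allocate while blanking the now-meaningless buffer index, can be sketched with plain tuples instead of TIR `Integer` arrays (hypothetical helper, for illustration only):

```python
def to_allocate_aligns(block_aligns):
    # block_aligns: (buffer_idx, axis, factor, offset) quads taken from a
    # block's "buffer_dim_align" annotation. On an Allocate the buffer index
    # is meaningless, so it is replaced by -1.
    result = []
    for quad in block_aligns:
        assert len(quad) == 4
        result.append((-1,) + tuple(quad[1:]))
    return result

print(to_allocate_aligns([(0, 1, 128, 1)]))  # [(-1, 1, 128, 1)]
```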
PrimFunc LowerOpaqueBlock(PrimFunc f) {
// Only apply this pass to TIR that is not from TE schedules
if (!IsFromLegacyTESchedule(f)) {
auto fptr = f.CopyOnWrite();
- fptr->body = OpaqueBlockLower()(std::move(fptr->body));
+ fptr->body = OpaqueBlockLower::Rewrite(std::move(fptr->body));
return f;
} else {
return f;
diff --git a/tests/python/unittest/test_tir_buffer.py b/tests/python/unittest/test_tir_buffer.py
index 95ad81db88..e3b63d9315 100644
--- a/tests/python/unittest/test_tir_buffer.py
+++ b/tests/python/unittest/test_tir_buffer.py
@@ -99,29 +99,6 @@ def test_buffer_offset_of():
tvm.ir.assert_structural_equal(offset, [n * 2 + 103])
-def test_buffer_vload_nullptr():
- var = tvm.tir.Var("v", dtype="int32")
- buf = tvm.tir.decl_buffer((1,), name="buf")
-    buf_load = tvm.tir.expr.BufferLoad(buffer=buf, indices=tvm.runtime.convert([0]))
- buf_load_stmt = tvm.tir.stmt.Evaluate(buf_load)
-    for_loop = tvm.tir.stmt.For(
-        loop_var=var, kind=0, min_val=0, extent=tvm.tir.Cast("int32", buf_load), body=buf_load_stmt
-    )
- buf_func = tvm.tir.PrimFunc(params={}, body=for_loop)
- mod = tvm.IRModule({"main": buf_func})
- # Trigger nullptr buffer bug by pass
- with pytest.raises(tvm.error.TVMError) as cm:
- mod = tvm.transform.Sequential(
- [
- tvm.tir.transform.PlanAndUpdateBufferAllocationLocation(),
- tvm.tir.transform.CompactBufferAllocation(),
- tvm.tir.transform.LowerOpaqueBlock(),
- tvm.tir.transform.FlattenBuffer(),
- ]
- )(mod)
- assert "(n != nullptr) is false" in str(cm.execption)
-
-
def test_buffer_index_merge_mult_mod():
m = te.size_var("m")
n = te.size_var("n")
diff --git a/tests/python/unittest/test_tir_transform_compact_buffer_region.py b/tests/python/unittest/test_tir_transform_compact_buffer_region.py
index e90539f3ef..c860bc0c55 100644
--- a/tests/python/unittest/test_tir_transform_compact_buffer_region.py
+++ b/tests/python/unittest/test_tir_transform_compact_buffer_region.py
@@ -17,713 +17,736 @@
import tvm
import tvm.testing
from tvm import te
+from tvm import tir
from tvm.script import tir as T
-def _check(original, transformed):
- func = original
- mod = tvm.IRModule.from_expr(func)
- mod = tvm.tir.transform.CompactBufferAllocation()(mod)
- mod = tvm.tir.transform.Simplify()(mod)
-    transformed = tvm.tir.transform.Simplify()(tvm.IRModule.from_expr(transformed))["main"]
- tvm.ir.assert_structural_equal(mod["main"], transformed)
+class BaseCompactTest:
+    """Base testcase class. The inherited testcase should include:
+    - `before` and `expected` primfuncs used to check structural equality for the transformation.
+    - `is_lower_order_free` tag, defaults to True, denotes that we would check
+      (LowerOpaqueBlock . CompactBufferAllocation)(before) ==
+      (CompactBufferAllocation . LowerOpaqueBlock)(before)
+    - `is_strict` tag, defaults to True, controls the `is_strict` option of the compaction pass.
+    """
+
+ def test_compact(self):
+ is_lower_order_free = getattr(self, "is_lower_order_free", True)
+ is_strict = getattr(self, "is_strict_mode", True)
+
+ before = tvm.IRModule.from_expr(self.before)
+ expected = tvm.IRModule.from_expr(self.expected)
+        simplify = tvm.transform.Sequential([tir.transform.Simplify(), tir.transform.RemoveNoOp()])
+        after = simplify(tir.transform.CompactBufferAllocation(is_strict=is_strict)(before))
+ expected = simplify(expected)
+ try:
+ tvm.ir.assert_structural_equal(after, expected)
+ except ValueError as err:
+            script = tvm.IRModule(
+                {"expected": expected["main"], "after": after["main"], "before": before["main"]}
+            ).script()
+            raise ValueError(
+                f"Function after simplification did not match expected:\n{script}"
+            ) from err
+
+ if not is_lower_order_free:
+ return
+ lower_before_compact = tir.transform.LowerOpaqueBlock()(before)
+        lower_before_compact = tir.transform.CompactBufferAllocation(is_strict=is_strict)(
+            lower_before_compact
+        )
+ lower_before_compact = simplify(lower_before_compact)
+ lower_after_compact = tir.transform.LowerOpaqueBlock()(after)
+ lower_after_compact = simplify(lower_after_compact)
+ try:
+            tvm.ir.assert_structural_equal(lower_before_compact, lower_after_compact)
+ except ValueError as err:
+ script = tvm.IRModule(
+ {
+ "lower_before_compact": lower_before_compact["main"],
+ "lower_after_compact": lower_after_compact["main"],
+ "before": before["main"],
+ }
+ ).script()
+            raise ValueError(
+                f"Function after simplification did not match expected:\n{script}"
+            ) from err
+
+
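The pass-order-free property exercised by `test_compact` is ordinary commutation of two functions; a toy Python illustration (not TVM code) of the check `(f . g)(x) == (g . f)(x)`:

```python
def commutes(f, g, x):
    # The property checked above: applying f then g gives the same result
    # as applying g then f, for the given input x.
    return f(g(x)) == g(f(x))

# Scaling by constants commutes; appending different suffixes does not.
print(commutes(lambda v: v * 2, lambda v: v * 3, 5))       # True
print(commutes(lambda s: s + "a", lambda s: s + "b", ""))  # False
```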
+class TestElemwise(BaseCompactTest):
+ @T.prim_func
+ def before(a: T.handle, c: T.handle) -> None:
+ A = T.match_buffer(a, (16, 16), "float32")
+ C = T.match_buffer(c, (16, 16), "float32")
+ for i in range(0, 16):
+ with T.block():
+ T.reads(A[i, 0:16])
+ T.writes(C[i, 0:16])
+ B = T.alloc_buffer((16, 16), "float32")
+ for j in range(0, 16):
+ with T.block():
+ T.reads(A[i, j])
+ T.writes(B[i, j])
+ B[i, j] = A[i, j] + 1.0
+ for j in range(0, 16):
+ with T.block():
+ T.reads(B[i, j])
+ T.writes(C[i, j])
+ C[i, j] = B[i, j] * 2.0
+ @T.prim_func
+ def expected(a: T.handle, c: T.handle) -> None:
+ A = T.match_buffer(a, (16, 16), "float32")
+ C = T.match_buffer(c, (16, 16), "float32")
+ for i in range(0, 16):
+ with T.block():
+ T.reads(A[i, 0:16])
+ T.writes(C[i, 0:16])
+ B = T.alloc_buffer((1, 16), "float32")
+ for j in range(0, 16):
+ with T.block():
+ T.reads(A[i, j])
+ T.writes(B[0, j])
+ B[0, j] = A[i, j] + 1.0
+ for j in range(0, 16):
+ with T.block():
+ T.reads(B[0, j])
+ T.writes(C[i, j])
+ C[i, j] = B[0, j] * 2.0
-@T.prim_func
-def elementwise_func(a: T.handle, c: T.handle) -> None:
- A = T.match_buffer(a, (16, 16), "float32")
- C = T.match_buffer(c, (16, 16), "float32")
- for i in range(0, 16):
- with T.block():
- T.reads(A[i, 0:16])
- T.writes(C[i, 0:16])
- B = T.alloc_buffer((16, 16), "float32")
- for j in range(0, 16):
- with T.block():
- T.reads(A[i, j])
- T.writes(B[i, j])
+
+class TestUnschedulableFunc(BaseCompactTest):
+ @T.prim_func
+ def before(a: T.handle, c: T.handle) -> None:
+ A = T.match_buffer(a, (16, 16), "float32")
+ C = T.match_buffer(c, (16, 16), "float32")
+ for i in range(0, 16):
+ with T.block():
+ T.reads(A[i, 0:16])
+ T.writes(C[i, 0:16])
+ B = T.alloc_buffer((16, 16), "float32")
+ for j in range(0, 16):
+                T.evaluate(T.call_extern("dummy_extern_function", B.data, dtype="int32"))
B[i, j] = A[i, j] + 1.0
- for j in range(0, 16):
- with T.block():
- T.reads(B[i, j])
- T.writes(C[i, j])
+ for j in range(0, 16):
C[i, j] = B[i, j] * 2.0
+ expected = before
-@T.prim_func
-def compacted_elementwise_func(a: T.handle, c: T.handle) -> None:
- A = T.match_buffer(a, (16, 16), "float32")
- C = T.match_buffer(c, (16, 16), "float32")
- for i in range(0, 16):
- with T.block():
- T.reads(A[i, 0:16])
- T.writes(C[i, 0:16])
- B = T.alloc_buffer((1, 16), "float32")
- for j in range(0, 16):
- with T.block():
- T.reads(A[i, j])
- T.writes(B[0, j])
- B[0, j] = A[i, j] + 1.0
- for j in range(0, 16):
- with T.block():
- T.reads(B[0, j])
- T.writes(C[i, j])
- C[i, j] = B[0, j] * 2.0
+class TestParamBufferAccess(BaseCompactTest):
+ @T.prim_func
+ def before(a: T.handle, c: T.handle) -> None:
+ A = T.match_buffer(a, (20, 20), "float32")
+ B = T.match_buffer(c, (20, 20), "float32")
+ for i in range(0, 16):
+ with T.block():
+ T.reads(A[i, 0:16])
+ T.writes(B[i, 0:16])
+ for j in range(0, 16):
+ with T.block():
+ T.reads(A[i, j])
+ T.writes(B[i, j])
+ B[i, j] = A[i, j] + 1.0
-@T.prim_func
-def unschedulable_func(a: T.handle, c: T.handle) -> None:
- A = T.match_buffer(a, (16, 16), "float32")
- C = T.match_buffer(c, (16, 16), "float32")
- for i in range(0, 16):
- with T.block():
- T.reads(A[i, 0:16])
- T.writes(C[i, 0:16])
- B = T.alloc_buffer((16, 16), "float32")
- for j in range(0, 16):
- T.evaluate(T.call_extern("dummy_extern_function", B.data,
dtype="int32"))
- B[i, j] = A[i, j] + 1.0
- for j in range(0, 16):
- C[i, j] = B[i, j] * 2.0
-
-
-@T.prim_func
-def param_buffer_access_func(a: T.handle, c: T.handle) -> None:
- A = T.match_buffer(a, (20, 20), "float32")
- B = T.match_buffer(c, (20, 20), "float32")
- for i in range(0, 16):
- with T.block():
- T.reads(A[i, 0:16])
- T.writes(B[i, 0:16])
- for j in range(0, 16):
- with T.block():
- T.reads(A[i, j])
- T.writes(B[i, j])
- B[i, j] = A[i, j] + 1.0
+ expected = before
-@T.prim_func
-def shared_mem_func(a: T.handle, c: T.handle) -> None:
- A = T.match_buffer(a, (16, 16), "float32")
- C = T.match_buffer(c, (16, 16), "float32")
- for i0 in T.thread_binding(0, 2, thread="blockIdx.x"):
- for i1 in T.thread_binding(0, 2, thread="vthread"):
- for i2 in T.thread_binding(0, 4, thread="threadIdx.x"):
- with T.block():
- T.reads(A[i0 * 8 + i1 * 4 + i2, 0:16])
- T.writes(C[i0 * 8 + i1 * 4 + i2, 0:16])
- B = T.alloc_buffer((16, 16), "float32", scope="shared")
- for j in range(0, 16):
- with T.block():
- T.reads(A[i0 * 8 + i1 * 4 + i2, j])
- T.writes(B[i0 * 8 + i1 * 4 + i2, j])
-                            B[i0 * 8 + i1 * 4 + i2, j] = A[i0 * 8 + i1 * 4 + i2, j] + 1.0
-                    for j in range(0, 16):
-                        with T.block():
-                            T.reads(B[i0 * 8 + i1 * 4 + i2, j])
-                            T.writes(C[i0 * 8 + i1 * 4 + i2, j])
-                            C[i0 * 8 + i1 * 4 + i2, j] = B[i0 * 8 + i1 * 4 + i2, j] * 2.0
-
-
-@T.prim_func
-def compacted_shared_mem_func(a: T.handle, c: T.handle) -> None:
- A = T.match_buffer(a, (16, 16), "float32")
- C = T.match_buffer(c, (16, 16), "float32")
- for i0 in T.thread_binding(0, 2, thread="blockIdx.x"):
- for i1 in T.thread_binding(0, 2, thread="vthread"):
- for i2 in T.thread_binding(0, 4, thread="threadIdx.x"):
- with T.block():
- T.reads(A[i0 * 8 + i1 * 4 + i2, 0:16])
- T.writes(C[i0 * 8 + i1 * 4 + i2, 0:16])
- B = T.alloc_buffer((8, 16), "float32", scope="shared")
- for j in range(0, 16):
- with T.block():
- T.reads(A[i0 * 8 + i1 * 4 + i2, j])
- T.writes(B[i1 * 4 + i2, j])
-                            B[i1 * 4 + i2, j] = A[i0 * 8 + i1 * 4 + i2, j] + 1.0
-                    for j in range(0, 16):
-                        with T.block():
-                            T.reads(B[i1 * 4 + i2, j])
-                            T.writes(C[i0 * 8 + i1 * 4 + i2, j])
-                            C[i0 * 8 + i1 * 4 + i2, j] = B[i1 * 4 + i2, j] * 2.0
-
-
-@T.prim_func
-def warp_mem_func(a: T.handle, c: T.handle) -> None:
- A = T.match_buffer(a, (16, 16), "float32")
- C = T.match_buffer(c, (16, 16), "float32")
- for i0 in T.thread_binding(0, 2, thread="blockIdx.x"):
- for i1 in T.thread_binding(0, 2, thread="vthread"):
- for i2 in T.thread_binding(0, 4, thread="threadIdx.x"):
- with T.block():
- T.reads(A[i0 * 8 + i1 * 4 + i2, 0:16])
- T.writes(C[i0 * 8 + i1 * 4 + i2, 0:16])
- B = T.alloc_buffer((16, 16), "float32", scope="warp")
- for j in range(0, 16):
- with T.block():
- T.reads(A[i0 * 8 + i1 * 4 + i2, j])
- T.writes(B[i0 * 8 + i1 * 4 + i2, j])
-                            B[i0 * 8 + i1 * 4 + i2, j] = A[i0 * 8 + i1 * 4 + i2, j] + 1.0
-                    for j in range(0, 16):
-                        with T.block():
-                            T.reads(B[i0 * 8 + i1 * 4 + i2, j])
-                            T.writes(C[i0 * 8 + i1 * 4 + i2, j])
-                            C[i0 * 8 + i1 * 4 + i2, j] = B[i0 * 8 + i1 * 4 + i2, j] * 2.0
-
-
-@T.prim_func
-def compacted_warp_mem_func(a: T.handle, c: T.handle) -> None:
- A = T.match_buffer(a, (16, 16), "float32")
- C = T.match_buffer(c, (16, 16), "float32")
- for i0 in T.thread_binding(0, 2, thread="blockIdx.x"):
- for i1 in T.thread_binding(0, 2, thread="vthread"):
- for i2 in T.thread_binding(0, 4, thread="threadIdx.x"):
- with T.block():
- T.reads(A[i0 * 8 + i1 * 4 + i2, 0:16])
- T.writes(C[i0 * 8 + i1 * 4 + i2, 0:16])
- B = T.alloc_buffer((4, 16), "float32", scope="warp")
- for j in range(0, 16):
- with T.block():
- T.reads(A[i0 * 8 + i1 * 4 + i2, j])
- T.writes(B[i2, j])
- B[i2, j] = A[i0 * 8 + i1 * 4 + i2, j] + 1.0
- for j in range(0, 16):
- with T.block():
- T.reads(B[i2, j])
- T.writes(C[i0 * 8 + i1 * 4 + i2, j])
- C[i0 * 8 + i1 * 4 + i2, j] = B[i2, j] * 2.0
+class TestSharedMem(BaseCompactTest):
+ @T.prim_func
+ def before(a: T.handle, c: T.handle) -> None:
+ A = T.match_buffer(a, (16, 16), "float32")
+ C = T.match_buffer(c, (16, 16), "float32")
+ for i0 in T.thread_binding(0, 2, thread="blockIdx.x"):
+ for i1 in T.thread_binding(0, 2, thread="vthread"):
+ for i2 in T.thread_binding(0, 4, thread="threadIdx.x"):
+ with T.block():
+ T.reads(A[i0 * 8 + i1 * 4 + i2, 0:16])
+ T.writes(C[i0 * 8 + i1 * 4 + i2, 0:16])
+ B = T.alloc_buffer((16, 16), "float32", scope="shared")
+ for j in range(0, 16):
+ with T.block():
+ T.reads(A[i0 * 8 + i1 * 4 + i2, j])
+ T.writes(B[i0 * 8 + i1 * 4 + i2, j])
+                                B[i0 * 8 + i1 * 4 + i2, j] = A[i0 * 8 + i1 * 4 + i2, j] + 1.0
+                        for j in range(0, 16):
+                            with T.block():
+                                T.reads(B[i0 * 8 + i1 * 4 + i2, j])
+                                T.writes(C[i0 * 8 + i1 * 4 + i2, j])
+                                C[i0 * 8 + i1 * 4 + i2, j] = B[i0 * 8 + i1 * 4 + i2, j] * 2.0
+ @T.prim_func
+ def expected(a: T.handle, c: T.handle) -> None:
+ A = T.match_buffer(a, (16, 16), "float32")
+ C = T.match_buffer(c, (16, 16), "float32")
+ for i0 in T.thread_binding(0, 2, thread="blockIdx.x"):
+ for i1 in T.thread_binding(0, 2, thread="vthread"):
+ for i2 in T.thread_binding(0, 4, thread="threadIdx.x"):
+ with T.block():
+ T.reads(A[i0 * 8 + i1 * 4 + i2, 0:16])
+ T.writes(C[i0 * 8 + i1 * 4 + i2, 0:16])
+ B = T.alloc_buffer((8, 16), "float32", scope="shared")
+ for j in range(0, 16):
+ with T.block():
+ T.reads(A[i0 * 8 + i1 * 4 + i2, j])
+ T.writes(B[i1 * 4 + i2, j])
+                                B[i1 * 4 + i2, j] = A[i0 * 8 + i1 * 4 + i2, j] + 1.0
+                        for j in range(0, 16):
+                            with T.block():
+                                T.reads(B[i1 * 4 + i2, j])
+                                T.writes(C[i0 * 8 + i1 * 4 + i2, j])
+                                C[i0 * 8 + i1 * 4 + i2, j] = B[i1 * 4 + i2, j] * 2.0
-@T.prim_func
-def symbolic_func(a: T.handle, c: T.handle, n: T.int32) -> None:
- A = T.match_buffer(a, (n * 8,), "float32")
- C = T.match_buffer(c, (n * 8,), "float32")
- for i in range(0, n):
- with T.block():
- T.reads(A[i * 8 : i * 8 + 8])
- T.writes(C[i * 8 : i * 8 + 8])
- B = T.alloc_buffer((n * 8,), "float32")
- for j in range(0, 8):
- with T.block():
- T.reads(A[i * 8 + j])
- T.writes(B[i * 8 + j])
- B[i * 8 + j] = A[i * 8 + j] + 1.0
- for j in range(0, 8):
- with T.block():
- T.reads(B[i * 8 + j])
- T.writes(C[i * 8 + j])
- C[i * 8 + j] = B[i * 8 + j] * 2.0
+class TestWarpMem(BaseCompactTest):
+ @T.prim_func
+ def before(a: T.handle, c: T.handle) -> None:
+ A = T.match_buffer(a, (16, 16), "float32")
+ C = T.match_buffer(c, (16, 16), "float32")
+ for i0 in T.thread_binding(0, 2, thread="blockIdx.x"):
+ for i1 in T.thread_binding(0, 2, thread="vthread"):
+ for i2 in T.thread_binding(0, 4, thread="threadIdx.x"):
+ with T.block():
+ T.reads(A[i0 * 8 + i1 * 4 + i2, 0:16])
+ T.writes(C[i0 * 8 + i1 * 4 + i2, 0:16])
+ B = T.alloc_buffer((16, 16), "float32", scope="warp")
+ for j in range(0, 16):
+ with T.block():
+ T.reads(A[i0 * 8 + i1 * 4 + i2, j])
+ T.writes(B[i0 * 8 + i1 * 4 + i2, j])
+                                B[i0 * 8 + i1 * 4 + i2, j] = A[i0 * 8 + i1 * 4 + i2, j] + 1.0
+ for j in range(0, 16):
+ with T.block():
+ T.reads(B[i0 * 8 + i1 * 4 + i2, j])
+ T.writes(C[i0 * 8 + i1 * 4 + i2, j])
+                                C[i0 * 8 + i1 * 4 + i2, j] = B[i0 * 8 + i1 * 4 + i2, j] * 2.0
-@T.prim_func
-def compacted_symbolic_func(a: T.handle, c: T.handle, n: T.int32) -> None:
- A = T.match_buffer(a, (n * 8,), "float32")
- C = T.match_buffer(c, (n * 8,), "float32")
- for i in range(0, n):
- with T.block():
- T.reads(A[i * 8 : i * 8 + 8])
- T.writes(C[i * 8 : i * 8 + 8])
- B = T.alloc_buffer((8,), "float32")
- for j in range(0, 8):
- with T.block():
- T.reads(A[i * 8 + j])
- T.writes(B[j])
- B[j] = A[i * 8 + j] + 1.0
- for j in range(0, 8):
- with T.block():
- T.reads(B[j])
- T.writes(C[i * 8 + j])
- C[i * 8 + j] = B[j] * 2.0
+ @T.prim_func
+ def expected(a: T.handle, c: T.handle) -> None:
+ A = T.match_buffer(a, (16, 16), "float32")
+ C = T.match_buffer(c, (16, 16), "float32")
+ for i0 in T.thread_binding(0, 2, thread="blockIdx.x"):
+ for i1 in T.thread_binding(0, 2, thread="vthread"):
+ for i2 in T.thread_binding(0, 4, thread="threadIdx.x"):
+ with T.block():
+ T.reads(A[i0 * 8 + i1 * 4 + i2, 0:16])
+ T.writes(C[i0 * 8 + i1 * 4 + i2, 0:16])
+ B = T.alloc_buffer((4, 16), "float32", scope="warp")
+ for j in range(0, 16):
+ with T.block():
+ T.reads(A[i0 * 8 + i1 * 4 + i2, j])
+ T.writes(B[i2, j])
+ B[i2, j] = A[i0 * 8 + i1 * 4 + i2, j] + 1.0
+ for j in range(0, 16):
+ with T.block():
+ T.reads(B[i2, j])
+ T.writes(C[i0 * 8 + i1 * 4 + i2, j])
+ C[i0 * 8 + i1 * 4 + i2, j] = B[i2, j] * 2.0
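A reviewer note on the shared/warp pair above: the two expected shapes differ only in which thread bindings are relaxed away during compaction. A rough, illustrative sketch of that rule in plain Python (the `SCOPE_KEEPS` table and helper are hypothetical, not TVM API):

```python
# Sketch (assumed, not TVM internals): a buffer needs private storage only along
# the thread axes that live *inside* its storage scope; outer axes are relaxed
# away and dropped from the compacted index expression.
SCOPE_KEEPS = {
    # shared memory is per-block: blockIdx.x (extent 2) is relaxed away,
    # leaving index i1 * 4 + i2 and first-dim extent 2 * 4 = 8
    "shared": ["vthread", "threadIdx.x"],
    # the warp-scope buffer here is private to the threadIdx group:
    # blockIdx.x and vthread are relaxed away, leaving index i2 and extent 4
    "warp": ["threadIdx.x"],
}

THREAD_EXTENTS = {"vthread": 2, "threadIdx.x": 4}

def compacted_first_dim(scope):
    """Product of the extents of the thread axes kept for this scope."""
    n = 1
    for axis in SCOPE_KEEPS[scope]:
        n *= THREAD_EXTENTS[axis]
    return n
```

This matches the expected allocations above: `(8, 16)` for the shared-memory test and `(4, 16)` for the warp test.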
-@T.prim_func
-def complex_func(a: T.handle, c: T.handle, n: T.int32) -> None:
- A = T.match_buffer(a, (8, 8), "float32")
- C = T.match_buffer(c, (8, 8), "float32")
- for i in range(0, 8):
- with T.block():
- T.reads(A[0, 8])
- T.writes(C[0, 8])
- B = T.alloc_buffer((8, 8), "float32")
- for j in range(0, 4):
- with T.block():
- D = T.alloc_buffer((8, 8), "float32")
- T.reads(A[i, j])
- T.writes(B[i, j])
- for k in range(4, 8):
- D[k, j] = 1.0
- for k in range(2, 4):
- B[i, j] = A[i, j] + D[k, j]
- for j in range(3, 5):
- with T.block():
- T.reads(B[i, j])
- T.writes(C[i, j])
- C[i, j] = B[i, j]
- for j in range(6, 8):
- with T.block():
- T.reads(B[i, j])
- T.writes(C[i, j])
- C[i, j] = B[i, j]
+class TestSymbolic(BaseCompactTest):
+ @T.prim_func
+ def before(a: T.handle, c: T.handle, n: T.int32) -> None:
+ A = T.match_buffer(a, (n * 8,), "float32")
+ C = T.match_buffer(c, (n * 8,), "float32")
+ for i in range(0, n):
+ with T.block():
+ T.reads(A[i * 8 : i * 8 + 8])
+ T.writes(C[i * 8 : i * 8 + 8])
+ B = T.alloc_buffer((n * 8,), "float32")
+ for j in range(0, 8):
+ with T.block():
+ T.reads(A[i * 8 + j])
+ T.writes(B[i * 8 + j])
+ B[i * 8 + j] = A[i * 8 + j] + 1.0
+ for j in range(0, 8):
+ with T.block():
+ T.reads(B[i * 8 + j])
+ T.writes(C[i * 8 + j])
+ C[i * 8 + j] = B[i * 8 + j] * 2.0
+
+ @T.prim_func
+ def expected(a: T.handle, c: T.handle, n: T.int32) -> None:
+ A = T.match_buffer(a, (n * 8,), "float32")
+ C = T.match_buffer(c, (n * 8,), "float32")
+ for i in range(0, n):
+ with T.block():
+ T.reads(A[i * 8 : i * 8 + 8])
+ T.writes(C[i * 8 : i * 8 + 8])
+ B = T.alloc_buffer((8,), "float32")
+ for j in range(0, 8):
+ with T.block():
+ T.reads(A[i * 8 + j])
+ T.writes(B[j])
+ B[j] = A[i * 8 + j] + 1.0
+ for j in range(0, 8):
+ with T.block():
+ T.reads(B[j])
+ T.writes(C[i * 8 + j])
+ C[i * 8 + j] = B[j] * 2.0
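In the symbolic case the compaction reduces to interval relaxation of an affine index over the loops below the allocation site: for each fixed `i`, `B[i * 8 + j]` touches only an extent-8 window. A small self-contained sketch (plain Python, not TVM code) of that corner-based bound computation:

```python
from itertools import product

def relax_affine(index_fn, loop_bounds):
    """Relax an affine index over allocation-local loop vars.
    loop_bounds maps var name -> (min, extent); evaluating at the corners of
    the iteration box suffices because an affine expression attains its
    extrema there. Returns (region_min, region_extent)."""
    names = list(loop_bounds)
    corners = [
        index_fn(dict(zip(names, point)))
        for point in product(*[(lo, lo + ext - 1) for lo, ext in loop_bounds.values()])
    ]
    return min(corners), max(corners) - min(corners) + 1

# Relative to its allocation site (inside the block over i), B is indexed by
# i * 8 + j with only j free, so the touched window has extent 8 and the
# compacted access becomes B[j] -- matching the expected shape (8,) above.
low, extent = relax_affine(lambda v: v["j"], {"j": (0, 8)})
```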
-@T.prim_func
-def compacted_complex_func(a: T.handle, c: T.handle, n: T.int32) -> None:
- A = T.match_buffer(a, (8, 8), "float32")
- C = T.match_buffer(c, (8, 8), "float32")
- for i in range(0, 8):
- with T.block():
- T.reads(A[0, 8])
- T.writes(C[0, 8])
- B = T.alloc_buffer((1, 8), "float32")
- for j in range(0, 4):
- with T.block():
- D = T.alloc_buffer((6, 1), "float32")
- T.reads(A[i, j])
- T.writes(B[0, j])
- for k in range(4, 8):
- D[k - 2, 0] = 1.0
- for k in range(2, 4):
- B[0, j] = A[i, j] + D[k - 2, 0]
- for j in range(3, 5):
- with T.block():
- T.reads(B[0, j])
- T.writes(C[i, j])
- C[i, j] = B[0, j]
- for j in range(6, 8):
- with T.block():
- T.reads(B[0, j])
- T.writes(C[i, j])
- C[i, j] = B[0, j]
+class TestComplexFunc(BaseCompactTest):
+ @T.prim_func
+ def before(a: T.handle, c: T.handle, n: T.int32) -> None:
+ A = T.match_buffer(a, (8, 8), "float32")
+ C = T.match_buffer(c, (8, 8), "float32")
+ for i in range(0, 8):
+ with T.block():
+ T.reads(A[0, 8])
+ T.writes(C[0, 8])
+ B = T.alloc_buffer((8, 8), "float32")
+ for j in range(0, 4):
+ with T.block():
+ D = T.alloc_buffer((8, 8), "float32")
+ T.reads(A[i, j])
+ T.writes(B[i, j])
+ for k in range(4, 8):
+ D[k, j] = 1.0
+ for k in range(2, 4):
+ B[i, j] = A[i, j] + D[k, j]
+ for j in range(3, 5):
+ with T.block():
+ T.reads(B[i, j])
+ T.writes(C[i, j])
+ C[i, j] = B[i, j]
+ for j in range(6, 8):
+ with T.block():
+ T.reads(B[i, j])
+ T.writes(C[i, j])
+ C[i, j] = B[i, j]
+
+ @T.prim_func
+ def expected(a: T.handle, c: T.handle, n: T.int32) -> None:
+ A = T.match_buffer(a, (8, 8), "float32")
+ C = T.match_buffer(c, (8, 8), "float32")
+ for i in range(0, 8):
+ with T.block():
+ T.reads(A[0, 8])
+ T.writes(C[0, 8])
+ B = T.alloc_buffer((1, 8), "float32")
+ for j in range(0, 4):
+ with T.block():
+ D = T.alloc_buffer((6, 1), "float32")
+ T.reads(A[i, j])
+ T.writes(B[0, j])
+ for k in range(4, 8):
+ D[k - 2, 0] = 1.0
+ for k in range(2, 4):
+ B[0, j] = A[i, j] + D[k - 2, 0]
+ for j in range(3, 5):
+ with T.block():
+ T.reads(B[0, j])
+ T.writes(C[i, j])
+ C[i, j] = B[0, j]
+ for j in range(6, 8):
+ with T.block():
+ T.reads(B[0, j])
+ T.writes(C[i, j])
+ C[i, j] = B[0, j]
-@T.prim_func
-def match_buffer_func(a: T.handle, c: T.handle) -> None:
- A = T.match_buffer(a, (16, 16))
- C = T.match_buffer(c, (16, 16))
- for i in range(0, 16):
- with T.block():
- A0 = T.match_buffer(A[i, 0:16], (16))
- C0 = T.match_buffer(C[i, 0:16], (16))
- B = T.alloc_buffer((16, 16))
+class TestMatchBuffer(BaseCompactTest):
+ is_lower_order_free = False
+
+ @T.prim_func
+ def before(a: T.handle, c: T.handle) -> None:
+ A = T.match_buffer(a, (16, 16))
+ C = T.match_buffer(c, (16, 16))
+ for i in range(0, 16):
with T.block():
- B0 = T.match_buffer(B[i, 0:16], (16))
+ A0 = T.match_buffer(A[i, 0:16], (16))
+ C0 = T.match_buffer(C[i, 0:16], (16))
+ B = T.alloc_buffer((16, 16))
+ with T.block():
+ B0 = T.match_buffer(B[i, 0:16], (16))
+ for j in range(0, 16):
+ with T.block():
+ A1 = T.match_buffer(A0[j], ())
+ B1 = T.match_buffer(B0[j], ())
+ B1[()] = A1[()] + 1.0
for j in range(0, 16):
with T.block():
- A1 = T.match_buffer(A0[j], ())
- B1 = T.match_buffer(B0[j], ())
- B1[()] = A1[()] + 1.0
- for j in range(0, 16):
+ C1 = T.match_buffer(C0[j], ())
+ B2 = T.match_buffer(B[i, j], ())
+ C1[()] = B2[()] * 2.0
+
+ @T.prim_func
+ def expected(a: T.handle, c: T.handle) -> None:
+ A = T.match_buffer(a, (16, 16))
+ C = T.match_buffer(c, (16, 16))
+ for i in range(0, 16):
+ with T.block():
+ A0 = T.match_buffer(A[i, 0:16], (16))
+ C0 = T.match_buffer(C[i, 0:16], (16))
+ B = T.alloc_buffer((1, 16))
with T.block():
- C1 = T.match_buffer(C0[j], ())
- B2 = T.match_buffer(B[i, j], ())
- C1[()] = B2[()] * 2.0
+ B0 = T.match_buffer(B[0, 0:16], (16))
+ for j in range(0, 16):
+ with T.block():
+ A1 = T.match_buffer(A0[j], ())
+ B1 = T.match_buffer(B0[j], ())
+ B1[()] = A1[()] + 1.0
+ for j in range(0, 16):
+ with T.block():
+ C1 = T.match_buffer(C0[j], ())
+ B2 = T.match_buffer(B[0, j], ())
+ C1[()] = B2[()] * 2.0
-@T.prim_func
-def compacted_match_buffer_func(a: T.handle, c: T.handle) -> None:
- A = T.match_buffer(a, (16, 16))
- C = T.match_buffer(c, (16, 16))
- for i in range(0, 16):
- with T.block():
- A0 = T.match_buffer(A[i, 0:16], (16))
- C0 = T.match_buffer(C[i, 0:16], (16))
- B = T.alloc_buffer((1, 16))
+class TestStorageAlign(BaseCompactTest):
+ @T.prim_func
+ def before(a: T.handle, c: T.handle) -> None:
+ A = T.match_buffer(a, (16, 16), "float32")
+ C = T.match_buffer(c, (16, 16), "float32")
+ for i in range(0, 16):
with T.block():
- B0 = T.match_buffer(B[0, 0:16], (16))
+ T.reads(A[i, 0:16])
+ T.writes(C[i, 0:16])
+ B = T.alloc_buffer((16, 16), "float32")
for j in range(0, 16):
with T.block():
- A1 = T.match_buffer(A0[j], ())
- B1 = T.match_buffer(B0[j], ())
- B1[()] = A1[()] + 1.0
- for j in range(0, 16):
- with T.block():
- C1 = T.match_buffer(C0[j], ())
- B2 = T.match_buffer(B[0, j], ())
- C1[()] = B2[()] * 2.0
+ T.reads(A[i, j])
+ T.writes(B[i, j])
+ T.block_attr({"buffer_dim_align": [[0, 0, 16, 15]]})
+ B[i, j] = A[i, j] + 1.0
+ for j in range(0, 16):
+ with T.block():
+ T.reads(B[i, j])
+ T.writes(C[i, j])
+ C[i, j] = B[i, j] * 2.0
+ @T.prim_func
+ def expected(a: T.handle, c: T.handle) -> None:
+ A = T.match_buffer(a, (16, 16), "float32")
+ C = T.match_buffer(c, (16, 16), "float32")
+ for i in range(0, 16):
+ with T.block():
+ T.reads(A[i, 0:16])
+ T.writes(C[i, 0:16])
+ B = T.alloc_buffer((1, 16), strides=(31, 1), dtype="float32")
+ for j in range(0, 16):
+ with T.block():
+ T.reads(A[i, j])
+ T.writes(B[0, j])
+ T.block_attr({"buffer_dim_align": [[0, 0, 16, 15]]})
+ B[0, j] = A[i, j] + 1.0
+ for j in range(0, 16):
+ with T.block():
+ T.reads(B[0, j])
+ T.writes(C[i, j])
+ C[i, j] = B[0, j] * 2.0
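The `strides=(31, 1)` in the expected function comes from the `buffer_dim_align` annotation `[[0, 0, 16, 15]]` (buffer 0, axis 0, factor 16, offset 15): the row stride is padded from 16 up to the smallest value congruent to `offset` modulo `factor`. A sketch of that padding rule (my reading of the annotation, assuming `offset < factor`; not TVM's implementation):

```python
def aligned_stride(extent, factor, offset):
    """Smallest stride >= extent with stride % factor == offset, mirroring how
    a buffer_dim_align [axis, factor, offset] annotation pads a row stride
    (e.g. to dodge shared-memory bank conflicts). Assumes offset < factor."""
    stride = extent
    while stride % factor != offset:
        stride += 1
    return stride

# A 16-wide row with factor=16, offset=15 is padded to stride 31, which is
# why the compacted B above gets strides=(31, 1).
```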
-@T.prim_func
-def storage_align_func(a: T.handle, c: T.handle) -> None:
- A = T.match_buffer(a, (16, 16), "float32")
- C = T.match_buffer(c, (16, 16), "float32")
- for i in range(0, 16):
+
+class TestPaddingPattern(BaseCompactTest):
+ @T.prim_func
+ def before(a: T.handle, c: T.handle) -> None:
+ A = T.match_buffer(a, (16, 16), "float32")
+ C = T.match_buffer(c, (20, 20), "float32")
with T.block():
- T.reads(A[i, 0:16])
- T.writes(C[i, 0:16])
- B = T.alloc_buffer((16, 16), "float32")
- for j in range(0, 16):
+ B = T.alloc_buffer((20, 20), dtype="float32")
+ for i, j in T.grid(16, 16):
with T.block():
- T.reads(A[i, j])
- T.writes(B[i, j])
- T.block_attr({"buffer_dim_align": [[0, 0, 16, 15]]})
- B[i, j] = A[i, j] + 1.0
- for j in range(0, 16):
+ B[i, j] = A[i, j]
+ for i, j in T.grid(20, 20):
with T.block():
- T.reads(B[i, j])
- T.writes(C[i, j])
- C[i, j] = B[i, j] * 2.0
-
+ C[i, j] = T.if_then_else(
+ 2 <= i and i < 18 and 2 <= j and j < 18,
+ B[i - 2, j - 2],
+ 0.0,
+ dtype="float32",
+ )
-@T.prim_func
-def compacted_storage_align_func(a: T.handle, c: T.handle) -> None:
- A = T.match_buffer(a, (16, 16), "float32")
- C = T.match_buffer(c, (16, 16), "float32")
- for i in range(0, 16):
+ @T.prim_func
+ def expected(a: T.handle, c: T.handle) -> None:
+ A = T.match_buffer(a, [16, 16], dtype="float32")
+ C = T.match_buffer(c, [20, 20], dtype="float32")
with T.block():
- T.reads(A[i, 0:16])
- T.writes(C[i, 0:16])
- B = T.alloc_buffer((1, 16), strides=(31, 1), dtype="float32")
- for j in range(0, 16):
+ B = T.alloc_buffer([16, 16], dtype="float32")
+ for i, j in T.grid(16, 16):
with T.block():
- T.reads(A[i, j])
- T.writes(B[0, j])
- T.block_attr({"buffer_dim_align": [[0, 0, 16, 15]]})
- B[0, j] = A[i, j] + 1.0
- for j in range(0, 16):
+ B[i, j] = A[i, j]
+ for i, j in T.grid(20, 20):
with T.block():
- T.reads(B[0, j])
- T.writes(C[i, j])
- C[i, j] = B[0, j] * 2.0
-
-
-@T.prim_func
-def padding_pattern_func(a: T.handle, c: T.handle) -> None:
- A = T.match_buffer(a, (16, 16), "float32")
- C = T.match_buffer(c, (20, 20), "float32")
- with T.block():
- B = T.alloc_buffer((20, 20), dtype="float32")
- for i, j in T.grid(16, 16):
- with T.block():
- B[i, j] = A[i, j]
- for i, j in T.grid(20, 20):
- with T.block():
- C[i, j] = T.if_then_else(
- 2 <= i and i < 18 and 2 <= j and j < 18,
- B[i - 2, j - 2],
- 0.0,
- dtype="float32",
- )
+ C[i, j] = T.if_then_else(
+ 2 <= i and i < 18 and 2 <= j and j < 18,
+ B[i - 2, j - 2],
+ 0.0,
+ dtype="float32",
+ )
-@T.prim_func
-def compacted_padding_pattern_func(a: T.handle, c: T.handle) -> None:
- A = T.match_buffer(a, [16, 16], dtype="float32")
- C = T.match_buffer(c, [20, 20], dtype="float32")
- with T.block():
- B = T.alloc_buffer([16, 16], dtype="float32")
- for i, j in T.grid(16, 16):
- with T.block():
- B[i, j] = A[i, j]
- for i, j in T.grid(20, 20):
- with T.block():
- C[i, j] = T.if_then_else(
-                2 <= i and i < 18 and 2 <= j and j < 18, B[i - 2, j - 2], 0.0, dtype="float32"
+class TestPaddingPatternInlined(BaseCompactTest):
+ @T.prim_func
+ def before(a: T.handle, b: T.handle) -> None:
+ X = T.match_buffer(a, [224, 224], dtype="float32")
+ Y = T.match_buffer(b, [224, 224], dtype="float32")
+ cache = T.alloc_buffer([224, 224], dtype="float32")
+ for h, w in T.grid(224, 224):
+ with T.block("cache"):
+ cache[h, w] = X[h, w]
+ for h, w, kh, kw in T.grid(224, 224, 3, 3):
+ with T.block("compute"):
+ Y[h, w] = T.max(
+ Y[h, w],
+ T.if_then_else(
+ T.likely(1 <= h + kh, dtype="bool")
+ and T.likely(h + kh < 225, dtype="bool")
+ and T.likely(1 <= w + kw, dtype="bool")
+ and T.likely(w + kw < 225, dtype="bool"),
+ cache[h + kh - 1, w + kw - 1],
+ 0.0,
+ dtype="float32",
+ ),
)
+ @T.prim_func
+    def expected(X: T.Buffer((224, 224), "float32"), Y: T.Buffer((224, 224), "float32")) -> None:
+ cache = T.alloc_buffer([224, 224], dtype="float32")
+ for h, w in T.grid(224, 224):
+ with T.block("cache"):
+ cache[h, w] = X[h, w]
+ for h, w, kh, kw in T.grid(224, 224, 3, 3):
+ with T.block("compute"):
+ Y[h, w] = T.max(
+ Y[h, w],
+ T.if_then_else(
+ T.likely(1 <= h + kh, dtype="bool")
+ and T.likely(h + kh < 225, dtype="bool")
+ and T.likely(1 <= w + kw, dtype="bool")
+ and T.likely(w + kw < 225, dtype="bool"),
+ cache[h + kh - 1, w + kw - 1],
+ 0.0,
+ dtype="float32",
+ ),
+ )
+
+
+class TestMemAccessInBranch(BaseCompactTest):
+ @T.prim_func
+ def before(a: T.handle) -> None:
+ A = T.match_buffer(a, (224, 224), "float32")
+ with T.block():
+ B1 = T.alloc_buffer((224, 224), dtype="float32")
+ B2 = T.alloc_buffer((224, 224), dtype="float32")
+ B3 = T.alloc_buffer((224, 224), dtype="float32")
+ B4 = T.alloc_buffer((224, 224), dtype="float32")
+ for i in range(0, 224):
+ for j in range(0, 224):
+ with T.block():
+ if i < 112 and j < 112:
+ B1[i, j] = A[i, j] * 2.0
+ else:
+ B2[i, j] = A[i, j] + 3.0
+ for i in range(0, 224):
+ for j in range(0, 224):
+ with T.block():
+ if i < 112 or j < 112:
+ B3[i, j] = A[i, j] * 2.0
+ else:
+ B4[i, j] = A[i, j] + 3.0
-@T.prim_func
-def padding_pattern_inlined(a: T.handle, b: T.handle) -> None:
- X = T.match_buffer(a, [224, 224], dtype="float32")
- Y = T.match_buffer(b, [224, 224], dtype="float32")
- cache = T.alloc_buffer([224, 224], dtype="float32")
- for h, w in T.grid(224, 224):
- with T.block("cache"):
- cache[h, w] = X[h, w]
- for h, w, kh, kw in T.grid(224, 224, 3, 3):
- with T.block("compute"):
- Y[h, w] = T.max(
- Y[h, w],
- T.if_then_else(
- T.likely(1 <= h + kh, dtype="bool")
- and T.likely(h + kh < 225, dtype="bool")
- and T.likely(1 <= w + kw, dtype="bool")
- and T.likely(w + kw < 225, dtype="bool"),
- cache[h + kh - 1, w + kw - 1],
- 0.0,
- dtype="float32",
- ),
- )
-
-
-@T.prim_func
-def compacted_padding_pattern_inlined(
- X: T.Buffer((224, 224), "float32"), Y: T.Buffer((224, 224), "float32")
-) -> None:
- cache = T.alloc_buffer([224, 224], dtype="float32")
- for h, w in T.grid(224, 224):
- with T.block("cache"):
- cache[h, w] = X[h, w]
- for h, w, kh, kw in T.grid(224, 224, 3, 3):
- with T.block("compute"):
- Y[h, w] = T.max(
- Y[h, w],
- T.if_then_else(
- T.likely(1 <= h + kh, dtype="bool")
- and T.likely(h + kh < 225, dtype="bool")
- and T.likely(1 <= w + kw, dtype="bool")
- and T.likely(w + kw < 225, dtype="bool"),
- cache[h + kh - 1, w + kw - 1],
- 0.0,
- dtype="float32",
- ),
- )
-
-
-@T.prim_func
-def mem_access_in_branch_func(a: T.handle) -> None:
- A = T.match_buffer(a, (224, 224), "float32")
- with T.block():
- B1 = T.alloc_buffer((224, 224), dtype="float32")
- B2 = T.alloc_buffer((224, 224), dtype="float32")
- B3 = T.alloc_buffer((224, 224), dtype="float32")
- B4 = T.alloc_buffer((224, 224), dtype="float32")
- for i in range(0, 224):
- for j in range(0, 224):
+ @T.prim_func
+ def expected(a: T.handle) -> None:
+ A = T.match_buffer(a, [224, 224], dtype="float32")
+ with T.block():
+ B1 = T.alloc_buffer([112, 112], dtype="float32")
+ B2 = T.alloc_buffer([224, 224], dtype="float32")
+ B3 = T.alloc_buffer([224, 224], dtype="float32")
+ B4 = T.alloc_buffer([112, 112], dtype="float32")
+ for i, j in T.grid(224, 224):
with T.block():
if i < 112 and j < 112:
B1[i, j] = A[i, j] * 2.0
else:
B2[i, j] = A[i, j] + 3.0
- for i in range(0, 224):
- for j in range(0, 224):
+ for i, j in T.grid(224, 224):
with T.block():
if i < 112 or j < 112:
B3[i, j] = A[i, j] * 2.0
else:
- B4[i, j] = A[i, j] + 3.0
-
-
-@T.prim_func
-def compacted_mem_access_in_branch_func(a: T.handle) -> None:
- A = T.match_buffer(a, [224, 224], dtype="float32")
- with T.block():
- B1 = T.alloc_buffer([112, 112], dtype="float32")
- B2 = T.alloc_buffer([224, 224], dtype="float32")
- B3 = T.alloc_buffer([224, 224], dtype="float32")
- B4 = T.alloc_buffer([112, 112], dtype="float32")
- for i, j in T.grid(224, 224):
- with T.block():
- if i < 112 and j < 112:
- B1[i, j] = A[i, j] * 2.0
- else:
- B2[i, j] = A[i, j] + 3.0
- for i, j in T.grid(224, 224):
- with T.block():
- if i < 112 or j < 112:
- B3[i, j] = A[i, j] * 2.0
- else:
- B4[i - 112, j - 112] = A[i, j] + 3.0
-
-
-@T.prim_func
-def opaque_access_annotated_func(a: T.handle) -> None:
- A = T.match_buffer(a, (1024,), "float32")
- with T.block():
- B = T.alloc_buffer((1024,), dtype="float32")
- C = T.alloc_buffer((1024,), dtype="float32")
- for i in range(0, 512):
- with T.block():
- # no annotation, opaque access will cover full region
- T.reads([])
- T.writes([])
-                T.evaluate(T.call_extern("opaque_extern_function", A.data, B.data, dtype="int32"))
- B[i] = A[i]
- with T.block():
- # treat opaque access only access annotated regions, even if
- # they are not compatible with actual buffer accesses.
- T.reads([B[i]])
- T.writes([C[i : i + 9]])
-                T.evaluate(T.call_extern("opaque_extern_function", B.data, C.data, dtype="int32"))
- C[i] = B[i]
-
-
-@T.prim_func
-def compacted_opaque_access_annotated_func(a: T.handle) -> None:
- A = T.match_buffer(a, (1024,), "float32")
- with T.block():
- B = T.alloc_buffer((1024,), dtype="float32")
- C = T.alloc_buffer((520,), dtype="float32")
- for i in range(0, 512):
- with T.block():
- # no annotation, opaque access will cover full region
- T.reads([])
- T.writes([])
-                T.evaluate(T.call_extern("opaque_extern_function", A.data, B.data, dtype="int32"))
- B[i] = A[i]
- with T.block():
- # treat opaque access only access annotated regions, even if
- # they are not compatible with actual buffer accesses.
- T.reads([B[i]])
- T.writes([C[i : i + 9]])
-                T.evaluate(T.call_extern("opaque_extern_function", B.data, C.data, dtype="int32"))
- C[i] = B[i]
-
-
-@T.prim_func
-def sparse_read_cache(
- A_data: T.Buffer((819,), "float32"),
- B: T.Buffer((128,), "float32"),
- A_indptr: T.Buffer((129,), "int32"),
- A_indices: T.Buffer((819,), "int32"),
-) -> None:
- for i in T.serial(128):
- with T.block("rowsum_outer"):
- T.reads(
- A_indptr[i : i + 1],
-                A_data[A_indptr[i] + 0 : A_indptr[i] + (A_indptr[i + 1] - A_indptr[i])],
- )
- T.writes(B[i])
- with T.block("rowsum_init"):
- T.reads()
- T.writes(B[i])
- B[i] = T.float32(0)
- for k in T.serial(A_indptr[i + 1] - A_indptr[i]):
- with T.block():
- T.reads(A_indptr[i], A_data[A_indptr[i] + k], B[i])
- T.writes(B[i])
-                    A_data_local = T.alloc_buffer([819], dtype="float32", scope="local")
- with T.block("A_data_cache_read"):
- T.reads(A_indptr[i], A_data[A_indptr[i] + k])
- T.writes(A_data_local[A_indptr[i] + k])
- A_data_local[A_indptr[i] + k] = A_data[A_indptr[i] + k]
- with T.block("rowsum_inner"):
- T.reads(B[i], A_indptr[i], A_data[A_indptr[i] + k])
- T.writes(B[i])
- B[i] = B[i] + A_data_local[A_indptr[i] + k]
-
-
-@T.prim_func
-def compacted_sparse_read_cache(
- A_data: T.Buffer((819,), "float32"),
- B: T.Buffer((128,), "float32"),
- A_indptr: T.Buffer((129,), "int32"),
- A_indices: T.Buffer((819,), "int32"),
-) -> None:
- for i in T.serial(128):
- with T.block("rowsum_outer"):
- T.reads(
- A_indptr[i : i + 1],
-                A_data[A_indptr[i] + 0 : A_indptr[i] + 0 + (A_indptr[i + 1] - A_indptr[i])],
- )
- T.writes(B[i])
- with T.block("rowsum_init"):
- T.reads()
- T.writes(B[i])
- B[i] = T.float32(0)
- for k in T.serial(A_indptr[i + 1] - A_indptr[i]):
- with T.block():
- T.reads(A_indptr[i], A_data[A_indptr[i] + k], B[i])
- T.writes(B[i])
-                    A_data_local = T.alloc_buffer([1], dtype="float32", scope="local")
- with T.block("A_data_cache_read"):
- T.reads(A_indptr[i], A_data[A_indptr[i] + k])
- T.writes(A_data_local[T.min(A_indptr[i] + k, 0)])
-                        A_data_local[T.min(A_indptr[i] + k, 0)] = A_data[A_indptr[i] + k]
- with T.block("rowsum_inner"):
- T.reads(B[i], A_indptr[i], A_data[A_indptr[i] + k])
- T.writes(B[i])
- B[i] = B[i] + A_data_local[T.min(A_indptr[i] + k, 0)]
-
-
-@T.prim_func
-def narrow_shape(A: T.Buffer((10,), "float32"), B: T.Buffer((10,), "float32")) -> None:
- B_cache = T.alloc_buffer(10, "float32")
- for j in T.serial(3):
- for k in T.serial(4):
- with T.block("B_cache"):
- T.where(j * 4 + k < 10)
- B_cache[j * 4 + k] = B[j]
- for i in T.serial(10):
- A[i] = B_cache[i] + T.float32(1)
-
-
-@T.prim_func
-def compacted_narrow_shape(A: T.Buffer((10,), "float32"), B: T.Buffer((10,), "float32")) -> None:
- # body
- # with T.block("root")
- B_cache = T.alloc_buffer([10], dtype="float32")
- for j, k in T.grid(3, 4):
- with T.block("B_cache"):
- T.where(j * 4 + k < 10)
- T.reads(B[j])
- T.writes(B_cache[j * 4 + k])
- B_cache[j * 4 + k] = B[j]
- for i in T.serial(10):
- A[i] = B_cache[i] + T.float32(1)
-
-
-def test_elementwise():
- _check(elementwise_func, compacted_elementwise_func)
-
-
-def test_unschedulable_block():
- _check(unschedulable_func, unschedulable_func) # changes nothing
-
-
-def test_param_access():
-    _check(param_buffer_access_func, param_buffer_access_func)  # changes nothing
-
-
-def test_shared_mem():
- _check(shared_mem_func, compacted_shared_mem_func)
-
-
-def test_warp_mem():
- _check(warp_mem_func, compacted_warp_mem_func)
-
-
-def test_symbolic():
- _check(symbolic_func, compacted_symbolic_func)
-
-
-def test_complex():
- _check(complex_func, compacted_complex_func)
+ B4[i - 112, j - 112] = A[i, j] + 3.0
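The `B4` case above is worth spelling out: the pass narrows the allocation by intersecting the access region with the negation of the branch condition. A brute-force check of that reasoning in plain Python:

```python
# The else-branch of "if i < 112 or j < 112" runs exactly when
# i >= 112 and j >= 112, so B4 is only written on [112, 224) x [112, 224):
# it can shrink to (112, 112) with indices rebased to B4[i - 112, j - 112].
touched = [(i, j) for i in range(224) for j in range(224) if not (i < 112 or j < 112)]
low_i = min(i for i, _ in touched)
low_j = min(j for _, j in touched)
extent_i = max(i for i, _ in touched) - low_i + 1
extent_j = max(j for _, j in touched) - low_j + 1
```

By contrast, the `and`-condition blocks touch `B2` on the non-rectangular complement of a quadrant, so `B2` (and `B3`) cannot shrink below the full `(224, 224)`.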
-def test_match_buffer():
- _check(match_buffer_func, compacted_match_buffer_func)
+class TestAnnotatedOpaqueAccess(BaseCompactTest):
+ is_lower_order_free = False
-def test_lower_te():
- x = te.placeholder((1,))
- y = te.compute((1,), lambda i: x[i] + 2)
- s = te.create_schedule(y.op)
- orig_mod = tvm.driver.build_module.schedule_to_module(s, [x, y])
- mod = tvm.tir.transform.CompactBufferAllocation()(orig_mod)
-    tvm.ir.assert_structural_equal(mod, orig_mod)  # CompactBufferAllocation should do nothing on TE
-
-
-def test_storage_align():
- _check(storage_align_func, compacted_storage_align_func)
-
+ @T.prim_func
+ def before(a: T.handle) -> None:
+ A = T.match_buffer(a, (1024,), "float32")
+ with T.block():
+ B = T.alloc_buffer((1024,), dtype="float32")
+ C = T.alloc_buffer((1024,), dtype="float32")
+ for i in range(0, 512):
+ with T.block():
+ # no annotation, opaque access will cover full region
+ T.reads([])
+ T.writes([])
+ T.evaluate(
+                        T.call_extern("opaque_extern_function", A.data, B.data, dtype="int32")
+ )
+ B[i] = A[i]
+ with T.block():
+                    # treat opaque accesses as touching only the annotated regions, even if
+                    # they are not compatible with the actual buffer accesses
+ T.reads([B[i]])
+ T.writes([C[i : i + 9]])
+ T.evaluate(
+                        T.call_extern("opaque_extern_function", B.data, C.data, dtype="int32")
+ )
+ C[i] = B[i]
-def test_padding_pattern():
- _check(padding_pattern_func, compacted_padding_pattern_func)
+ @T.prim_func
+ def expected(a: T.handle) -> None:
+ A = T.match_buffer(a, (1024,), "float32")
+ with T.block():
+ B = T.alloc_buffer((1024,), dtype="float32")
+ C = T.alloc_buffer((520,), dtype="float32")
+ for i in range(0, 512):
+ with T.block():
+ # no annotation, opaque access will cover full region
+ T.reads([])
+ T.writes([])
+ T.evaluate(
+                        T.call_extern("opaque_extern_function", A.data, B.data, dtype="int32")
+ )
+ B[i] = A[i]
+ with T.block():
+                    # treat opaque accesses as touching only the annotated regions, even if
+                    # they are not compatible with the actual buffer accesses
+ T.reads([B[i]])
+ T.writes([C[i : i + 9]])
+ T.evaluate(
+                        T.call_extern("opaque_extern_function", B.data, C.data, dtype="int32")
+ )
+ C[i] = B[i]
-def test_padding_pattern_inlined():
- _check(padding_pattern_inlined, compacted_padding_pattern_inlined)
+class TestSparseReadCache(BaseCompactTest):
+ @T.prim_func
+ def before(
+ A_data: T.Buffer((819,), "float32"),
+ B: T.Buffer((128,), "float32"),
+ A_indptr: T.Buffer((129,), "int32"),
+ ) -> None:
+ for i in T.serial(128):
+ with T.block("rowsum_outer"):
+ T.reads(
+ A_indptr[i : i + 1],
+                    A_data[A_indptr[i] + 0 : A_indptr[i] + (A_indptr[i + 1] - A_indptr[i])],
+ )
+ T.writes(B[i])
+ with T.block("rowsum_init"):
+ T.reads()
+ T.writes(B[i])
+ B[i] = T.float32(0)
+ for k in T.serial(A_indptr[i + 1] - A_indptr[i]):
+ with T.block():
+ T.reads(A_indptr[i], A_data[A_indptr[i] + k], B[i])
+ T.writes(B[i])
+                        A_data_local = T.alloc_buffer([819], dtype="float32", scope="local")
+ with T.block("A_data_cache_read"):
+ T.reads(A_indptr[i], A_data[A_indptr[i] + k])
+ T.writes(A_data_local[A_indptr[i] + k])
+                            A_data_local[A_indptr[i] + k] = A_data[A_indptr[i] + k]
+ with T.block("rowsum_inner"):
+ T.reads(B[i], A_indptr[i], A_data[A_indptr[i] + k])
+ T.writes(B[i])
+ B[i] = B[i] + A_data_local[A_indptr[i] + k]
+ @T.prim_func
+ def expected(
+ A_data: T.Buffer((819,), "float32"),
+ B: T.Buffer((128,), "float32"),
+ A_indptr: T.Buffer((129,), "int32"),
+ ) -> None:
+ for i in T.serial(128):
+ with T.block("rowsum_outer"):
+ T.reads(
+ A_indptr[i : i + 1],
+                    A_data[A_indptr[i] + 0 : A_indptr[i] + 0 + (A_indptr[i + 1] - A_indptr[i])],
+ )
+ T.writes(B[i])
+ with T.block("rowsum_init"):
+ T.reads()
+ T.writes(B[i])
+ B[i] = T.float32(0)
+ for k in T.serial(A_indptr[i + 1] - A_indptr[i]):
+ with T.block():
+ T.reads(A_indptr[i], A_data[A_indptr[i] + k], B[i])
+ T.writes(B[i])
+                        A_data_local = T.alloc_buffer([1], dtype="float32", scope="local")
+ with T.block("A_data_cache_read"):
+ T.reads(A_indptr[i], A_data[A_indptr[i] + k])
+ T.writes(A_data_local[T.min(A_indptr[i] + k, 0)])
+                            A_data_local[T.min(A_indptr[i] + k, 0)] = A_data[A_indptr[i] + k]
+ with T.block("rowsum_inner"):
+ T.reads(B[i], A_indptr[i], A_data[A_indptr[i] + k])
+ T.writes(B[i])
+                            B[i] = B[i] + A_data_local[T.min(A_indptr[i] + k, 0)]
-def test_mem_access_in_branch_func():
- _check(mem_access_in_branch_func, compacted_mem_access_in_branch_func)
+class TestDataDependentRegion(BaseCompactTest):
+    """Partial code of NMS: the region of `argsort_nms_cpu` depends on the value of the
+    inner allocated buffer `nkeep`, so the buffer must not be compacted with a
+    data-dependent region extent."""
-def test_opaque_access_annotated_func():
-    _check(opaque_access_annotated_func, compacted_opaque_access_annotated_func)
+ @T.prim_func
+ def before(
+ p0: T.Buffer((30,), "float32"),
+ p1: T.Buffer((1,), "int32"),
+ hybrid_nms: T.Buffer((30,), "float32"),
+ ):
+ argsort_nms_cpu = T.decl_buffer([5], "int32", scope="global")
+ for i in range(1):
+ nkeep = T.decl_buffer([1], "int32", scope="global")
+ if 0 < p1[i]:
+ nkeep[0] = p1[i]
+ if 2 < nkeep[0]:
+ nkeep[0] = 2
+ for j in T.parallel(nkeep[0]):
+ for k in range(6):
+ hybrid_nms[i * 30 + j * 6 + k] = p0[
+ i * 30 + argsort_nms_cpu[i * 5 + j] * 6 + k
+ ]
+ hybrid_nms[i * 5 + j] = argsort_nms_cpu[i * 5 + j]
+ if 2 < p1[i]:
+ for j in T.parallel(p1[i] - nkeep[0]):
+ for k in range(6):
+                        hybrid_nms[i * 30 + j * 6 + nkeep[0] * 6 + k] = T.float32(-1)
+ hybrid_nms[i * 5 + j + nkeep[0]] = -1
+ expected = before
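`expected = before` here because the parallel loop extent `nkeep[0]` is a runtime buffer load, so no static bound exists for the touched region of `argsort_nms_cpu`. A toy model of that bail-out decision (names and the string encoding of a buffer load are purely illustrative):

```python
def is_static(extent):
    """Model: an extent is analyzable when it is a plain integer; a runtime
    buffer load (encoded here as a string such as "nkeep[0]") is opaque."""
    return isinstance(extent, int)

def compact_shape(original_shape, region_extents):
    # fall back to the original allocation when any extent is data dependent
    if not all(is_static(e) for e in region_extents):
        return tuple(original_shape)
    return tuple(region_extents)
```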
-def test_sparse_read_cache():
- _check(sparse_read_cache, compacted_sparse_read_cache)
+class TestNarrowShape(BaseCompactTest):
+ @T.prim_func
+    def before(A: T.Buffer((10,), "float32"), B: T.Buffer((10,), "float32")) -> None:
+ B_cache = T.alloc_buffer(10, "float32")
+ for j in T.serial(3):
+ for k in T.serial(4):
+ with T.block("B_cache"):
+ T.where(j * 4 + k < 10)
+ B_cache[j * 4 + k] = B[j]
+ for i in T.serial(10):
+ A[i] = B_cache[i] + T.float32(1)
-def test_narrow_shape():
- _check(narrow_shape, compacted_narrow_shape)
+ @T.prim_func
+    def expected(A: T.Buffer((10,), "float32"), B: T.Buffer((10,), "float32")) -> None:
+ B_cache = T.alloc_buffer([10], dtype="float32")
+ for j, k in T.grid(3, 4):
+ with T.block("B_cache"):
+ T.where(j * 4 + k < 10)
+ T.reads(B[j])
+ T.writes(B_cache[j * 4 + k])
+ B_cache[j * 4 + k] = B[j]
+ for i in T.serial(10):
+ A[i] = B_cache[i] + T.float32(1)
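The `T.where(j * 4 + k < 10)` predicate is what keeps `B_cache` at extent 10 rather than the unconstrained 3 * 4 = 12, matching the later full read of `[0, 10)`. A brute-force version of that predicate-aware bound (plain Python, enumerating the small grid):

```python
from itertools import product

def predicated_region(grid_extents, index_fn, predicate):
    """Touched (min, extent) of a 1-D buffer under a T.where-style predicate,
    computed by brute-force enumeration of the iteration grid."""
    touched = [
        index_fn(*point)
        for point in product(*(range(e) for e in grid_extents))
        if predicate(*point)
    ]
    return min(touched), max(touched) - min(touched) + 1

# j in [0, 3), k in [0, 4): without the predicate the max index is 2*4 + 3 = 11;
# with "j*4 + k < 10" the touched region is exactly [0, 10).
low, extent = predicated_region((3, 4), lambda j, k: j * 4 + k, lambda j, k: j * 4 + k < 10)
```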
-def test_compact_with_let_binding():
+class TestLetBinding(BaseCompactTest):
@T.prim_func
- def func_with_let_binding():
+ def before():
A = T.alloc_buffer((64, 8), "float32")
B = T.alloc_buffer((64, 8), "float32")
C = T.alloc_buffer((8, 8), "float32")
@@ -735,10 +758,12 @@ def test_compact_with_let_binding():
rjj: T.int32 = riijj % 8
C[rii, rjj] += A[rk, rii] * B[rk, rjj]
- _check(func_with_let_binding, func_with_let_binding)
+ expected = before
+
+class TestNonIndexLetBinding(BaseCompactTest):
@T.prim_func
- def func_with_non_index_let_binding():
+ def before():
A = T.alloc_buffer((64), "float32")
x1 = T.call_extern("get", dtype="float16")
x2 = T.call_extern("get", dtype="float32")
@@ -750,14 +775,12 @@ def test_compact_with_let_binding():
for rk in range(64):
             A[rk] = T.call_extern("load_ptr", x1, x2, x3, x4, x5, x6, x7, dtype="float32")
- _check(func_with_non_index_let_binding, func_with_non_index_let_binding)
+ expected = before
-def test_compact_spatial_tiled_pad_and_pooling():
+class TestSpatialTiledPadPooling(BaseCompactTest):
@T.prim_func
- def spatial_tiled_pad_and_pooling(
-        X: T.Buffer((64, 112, 112), "int32"), Y: T.Buffer((64, 56, 56), "int32")
- ) -> None:
+    def before(X: T.Buffer((64, 112, 112), "int32"), Y: T.Buffer((64, 56, 56), "int32")) -> None:
for h_o, w_o in T.grid(14, 14):
with T.block():
X_cache = T.alloc_buffer([112, 112, 64], dtype="int32")
@@ -795,9 +818,7 @@ def test_compact_spatial_tiled_pad_and_pooling():
)
@T.prim_func
- def compacted_spatial_tiled_pad_and_pooling(
-        X: T.Buffer((64, 112, 112), "int32"), Y: T.Buffer((64, 56, 56), "int32")
- ) -> None:
+    def expected(X: T.Buffer((64, 112, 112), "int32"), Y: T.Buffer((64, 56, 56), "int32")) -> None:
for h_o, w_o in T.grid(14, 14):
with T.block():
                 T.reads(X[0:64, h_o * 8 - 1 : h_o * 8 + 8, w_o * 8 - 1 : w_o * 8 + 8])
@@ -846,15 +867,13 @@ def test_compact_spatial_tiled_pad_and_pooling():
),
)
-    _check(spatial_tiled_pad_and_pooling, compacted_spatial_tiled_pad_and_pooling)
-
-def test_complex_case_1():
+class TestComplexCase1(BaseCompactTest):
"""Meta-schedule matmul case for compact shared A, B matrix"""
# fmt: off
@T.prim_func
-    def func(A: T.Buffer((960, 770), "float32"), B: T.Buffer((770, 2304), "float32"), C: T.Buffer((960, 2304), "float32")) -> None:
+ def before(A: T.Buffer((960, 770), "float32"), B: T.Buffer((770, 2304),
"float32"), C: T.Buffer((960, 2304), "float32")) -> None:
for bx in T.thread_binding(144, thread="blockIdx.x"):
for vx in T.thread_binding(2, thread="vthread.x"):
for tx_p in T.thread_binding(256, thread="threadIdx.x"):
@@ -880,7 +899,7 @@ def test_complex_case_1():
                                    C[(((bx // 18 + 0) * 8 + tx_p // 32) * 8 + i_3) * 2 + i_4, ((bx % 18 * 2 + vx % 2) * 32 + tx_p % 32 + j_3) * 2 + j_4] = C[(((bx // 18 + 0) * 8 + tx_p // 32) * 8 + i_3) * 2 + i_4, ((bx % 18 * 2 + vx % 2) * 32 + tx_p % 32 + j_3) * 2 + j_4] + A_shared[(((bx // 18 + 0) * 8 + tx_p // 32) * 8 + i_3) * 2 + i_4, (k_0 + k_1) * 4 + k_2] * B_shared[(k_0 + k_1) * 4 + k_2, ((bx % 18 * 2 + vx % 2) * 32 + tx_p % 32 + j_3) * 2 + j_4]
@T.prim_func
-    def compacted_func(A: T.Buffer((960, 770), "float32"), B: T.Buffer((770, 2304), "float32"), C: T.Buffer((960, 2304), "float32")) -> None:
+    def expected(A: T.Buffer((960, 770), "float32"), B: T.Buffer((770, 2304), "float32"), C: T.Buffer((960, 2304), "float32")) -> None:
for bx in T.thread_binding(144, thread="blockIdx.x"):
for vx in T.thread_binding(2, thread="vthread.x"):
for tx_p in T.thread_binding(256, thread="threadIdx.x"):
@@ -906,14 +925,13 @@ def test_complex_case_1():
                                    C[bx // 18 * 128 + tx_p // 32 * 16 + i_3 * 2 + i_4, bx % 18 * 128 + vx * 64 + tx_p % 32 * 2 + j_4] = C[bx // 18 * 128 + tx_p // 32 * 16 + i_3 * 2 + i_4, bx % 18 * 128 + vx * 64 + tx_p % 32 * 2 + j_4] + A_shared[tx_p // 32 * 16 + i_3 * 2 + i_4, k_2] * B_shared[k_2, vx * 64 + tx_p % 32 * 2 + j_4]
# fmt: on
- _check(func, compacted_func)
-
-def test_compact_dependent_buffer_indices():
+class TestDependentBufferIndices(BaseCompactTest):
    """Check that the upper bounds of different indices can be estimated independently."""
@T.prim_func
- def diagonal_access():
+ def before():
+        """This is a diagonal buffer access pattern"""
for i in range(8):
with T.block():
A = T.alloc_buffer((256, 256), "float32")
@@ -923,7 +941,7 @@ def test_compact_dependent_buffer_indices():
A[i * 64 + j * 8 + k, i * 64 + j * 8 + k] = 1.0
@T.prim_func
- def diagonal_access_compacted() -> None:
+ def expected() -> None:
for i in T.serial(8):
with T.block():
A = T.alloc_buffer([60, 60], dtype="float32")
@@ -932,14 +950,12 @@ def test_compact_dependent_buffer_indices():
T.where(j * 8 + k < 60)
A[j * 8 + k, j * 8 + k] = 1.0
- _check(diagonal_access, diagonal_access_compacted)
-
-def test_compact_dependent_buffer_indices_of_packed_matmul():
+class TestDependentBufferIndicesOfPackedMatmul(BaseCompactTest):
    """Check that the outer dimension of the packed M-dim is compacted to 1 w.r.t. the split condition."""
@T.prim_func
- def nonuniform_packed_matmul_write_cache(
+ def before(
A: T.Buffer((1020, 64), "float32"),
B: T.Buffer((1000, 64), "float32"),
C: T.Buffer((1020, 1000), "float32"),
@@ -976,7 +992,7 @@ def test_compact_dependent_buffer_indices_of_packed_matmul():
]
@T.prim_func
- def nonuniform_packed_matmul_write_cache_compacted(
+ def expected(
A: T.Buffer((1020, 64), "float32"),
B: T.Buffer((1000, 64), "float32"),
C: T.Buffer((1020, 1000), "float32"),
@@ -1006,90 +1022,338 @@ def test_compact_dependent_buffer_indices_of_packed_matmul():
(ax0 * 16 + ax1) % 255 % 16,
]
-    _check(nonuniform_packed_matmul_write_cache, nonuniform_packed_matmul_write_cache_compacted)
+class TestTileAwareCompaction(BaseCompactTest):
+    """Each partitioned tile can be compacted independently."""
-def test_compact_symbolic_bound0():
- """Test symbolic bound that get compacted to constant"""
+    # intentionally not an opaque block case
+ is_lower_order_free = False
+
+ @T.prim_func
+ def before(
+ A: T.Buffer((128, 128), "float32"),
+ B: T.Buffer((128, 128), "float32"),
+ C: T.Buffer((128, 128), "float32"),
+ ):
+ for i_0 in range(5, annotations={"pragma_loop_partition_hint": 1}):
+ for j_0 in range(5, annotations={"pragma_loop_partition_hint": 1}):
+ A_local = T.decl_buffer((26, 128), scope="local")
+ B_local = T.decl_buffer((128, 26), scope="local")
+ C_local = T.decl_buffer((26, 26), scope="local")
+ for ax0, ax1 in T.grid(26, 128):
+ if i_0 * 26 + ax0 < 128:
+ A_local[ax0, ax1] = A[i_0 * 26 + ax0, ax1]
+ for ax0, ax1 in T.grid(128, 26):
+ if j_0 * 26 + ax1 < 128:
+ B_local[ax0, ax1] = B[ax0, j_0 * 26 + ax1]
+ for i_1, j_1, k in T.grid(26, 26, 128):
+ if i_0 * 26 + i_1 < 128 and j_0 * 26 + j_1 < 128:
+ if k == 0:
+ C_local[i_1, j_1] = T.float32(0)
+                        C_local[i_1, j_1] = C_local[i_1, j_1] + A_local[i_1, k] * B_local[k, j_1]
+ for ax0, ax1 in T.grid(26, 26):
+ if i_0 * 26 + ax0 < 128 and j_0 * 26 + ax1 < 128:
+ C[i_0 * 26 + ax0, j_0 * 26 + ax1] = C_local[ax0, ax1]
+
+ # Get partitioned workload to compact
+ before_mod = tvm.IRModule.from_expr(before)
+    with tvm.transform.PassContext(config={"tir.LoopPartition": {"partition_const_loop": True}}):
+ before_mod = tvm.tir.transform.LowerOpaqueBlock()(before_mod)
+ before_mod = tvm.tir.transform.LoopPartition()(before_mod)
+ before = before_mod["main"]
+
+ @T.prim_func
+ def expected(
+ A: T.Buffer((128, 128), "float32"),
+ B: T.Buffer((128, 128), "float32"),
+ C: T.Buffer((128, 128), "float32"),
+ ):
+ for i_0 in range(4):
+ for j_0 in range(4):
+ A_local_tile0 = T.decl_buffer((26, 128), scope="local")
+ B_local_tile0 = T.decl_buffer((128, 26), scope="local")
+ C_local_tile0 = T.decl_buffer((26, 26), scope="local")
+ for ax0, ax1 in T.grid(26, 128):
+ A_local_tile0[ax0, ax1] = A[i_0 * 26 + ax0, ax1]
+ for ax0, ax1 in T.grid(128, 26):
+ B_local_tile0[ax0, ax1] = B[ax0, j_0 * 26 + ax1]
+ for i_1, j_1, k in T.grid(26, 26, 128):
+ if k == 0:
+ C_local_tile0[i_1, j_1] = T.float32(0)
+ C_local_tile0[i_1, j_1] = (
+                            C_local_tile0[i_1, j_1] + A_local_tile0[i_1, k] * B_local_tile0[k, j_1]
+ )
+ for ax0, ax1 in T.grid(26, 26):
+ C[i_0 * 26 + ax0, j_0 * 26 + ax1] = C_local_tile0[ax0, ax1]
+
+ A_local_tile1 = T.decl_buffer((26, 128), scope="local")
+ B_local_tile1 = T.decl_buffer((128, 24), scope="local")
+ C_local_tile1 = T.decl_buffer((26, 24), scope="local")
+ for ax0, ax1 in T.grid(26, 128):
+ A_local_tile1[ax0, ax1] = A[i_0 * 26 + ax0, ax1]
+ for ax0, ax1 in T.grid(128, 26):
+ if ax1 < 24:
+ B_local_tile1[ax0, ax1] = B[ax0, ax1 + 104]
+ for i_1, j_1, k in T.grid(26, 26, 128):
+ if j_1 < 24:
+ if k == 0:
+ C_local_tile1[i_1, j_1] = T.float32(0)
+ C_local_tile1[i_1, j_1] = (
+                            C_local_tile1[i_1, j_1] + A_local_tile1[i_1, k] * B_local_tile1[k, j_1]
+ )
+ for ax0, ax1 in T.grid(26, 26):
+ if ax1 < 24:
+ C[i_0 * 26 + ax0, ax1 + 104] = C_local_tile1[ax0, ax1]
+
+ for j_0 in range(4):
+ A_local_tile2 = T.decl_buffer((24, 128), scope="local")
+ B_local_tile2 = T.decl_buffer((128, 26), scope="local")
+ C_local_tile2 = T.decl_buffer((24, 26), scope="local")
+ for ax0, ax1 in T.grid(26, 128):
+ if ax0 < 24:
+ A_local_tile2[ax0, ax1] = A[ax0 + 104, ax1]
+ for ax0, ax1 in T.grid(128, 26):
+ B_local_tile2[ax0, ax1] = B[ax0, j_0 * 26 + ax1]
+ for i_1, j_1, k in T.grid(26, 26, 128):
+ if i_1 < 24:
+ if k == 0:
+ C_local_tile2[i_1, j_1] = T.float32(0)
+ C_local_tile2[i_1, j_1] = (
+                            C_local_tile2[i_1, j_1] + A_local_tile2[i_1, k] * B_local_tile2[k, j_1]
+ )
+ for ax0, ax1 in T.grid(26, 26):
+ if ax0 < 24:
+ C[ax0 + 104, j_0 * 26 + ax1] = C_local_tile2[ax0, ax1]
+
+ A_local_tile3 = T.decl_buffer((24, 128), scope="local")
+ B_local_tile3 = T.decl_buffer((128, 24), scope="local")
+ C_local_tile3 = T.decl_buffer((24, 24), scope="local")
+ for ax0, ax1 in T.grid(26, 128):
+ if ax0 < 24:
+ A_local_tile3[ax0, ax1] = A[ax0 + 104, ax1]
+ for ax0, ax1 in T.grid(128, 26):
+ if ax1 < 24:
+ B_local_tile3[ax0, ax1] = B[ax0, ax1 + 104]
+ for i_1, j_1, k in T.grid(26, 26, 128):
+ if i_1 < 24 and j_1 < 24:
+ if k == 0:
+ C_local_tile3[i_1, j_1] = T.float32(0)
+ C_local_tile3[i_1, j_1] = (
+                        C_local_tile3[i_1, j_1] + A_local_tile3[i_1, k] * B_local_tile3[k, j_1]
+ )
+ for ax0, ax1 in T.grid(26, 26):
+ if ax0 < 24 and ax1 < 24:
+ C[ax0 + 104, ax1 + 104] = C_local_tile3[ax0, ax1]
+
+
+class TestNonStrictCompactionForPaddedMatmul(BaseCompactTest):
+
+ is_strict_mode = False
+
+ @T.prim_func
+ def before(
+ A: T.Buffer((127, 127), "float32"),
+ B: T.Buffer((127, 127), "float32"),
+ C: T.Buffer((127, 127), "float32"),
+ ):
+        """A mock workload where the original intermediate buffer allocations are not large enough"""
+ for i_0, j_0 in T.grid(4, 4):
+ with T.block(""):
+                T.reads(A[i_0 * 32 : i_0 * 32 + 32, 0:128], B[0:128, j_0 * 32 : j_0 * 32 + 32])
+ T.writes(C[i_0 * 32 : i_0 * 32 + 32, j_0 * 32 : j_0 * 32 + 32])
+ A_local = T.alloc_buffer((127, 127), scope="local")
+ B_local = T.alloc_buffer((127, 127), scope="local")
+ C_local = T.alloc_buffer((127, 127), scope="local")
+ for ax0, ax1 in T.grid(32, 128):
+ with T.block("A_local"):
+ A_local[i_0 * 32 + ax0, ax1] = T.if_then_else(
+ i_0 * 32 + ax0 < 127, A[i_0 * 32 + ax0, ax1], 0.0
+ )
+ for ax0, ax1 in T.grid(128, 32):
+ with T.block("B_local"):
+ B_local[ax0, j_0 * 32 + ax1] = T.if_then_else(
+ j_0 * 32 + ax1 < 127, B[ax0, j_0 * 32 + ax1], 0.0
+ )
+ for i_1, j_1, k in T.grid(32, 32, 128):
+ with T.block("compute"):
+ T.where(i_0 * 32 + i_1 < 127 and j_0 * 32 + j_1 < 127)
+ if k == 0:
+                            C_local[i_0 * 32 + i_1, j_0 * 32 + j_1] = T.float32(0)
+ C_local[i_0 * 32 + i_1, j_0 * 32 + j_1] = (
+ C_local[i_0 * 32 + i_1, j_0 * 32 + j_1]
+                            + A_local[i_0 * 32 + i_1, k] * B_local[k, j_0 * 32 + j_1]
+ )
+ for ax0, ax1 in T.grid(32, 32):
+ with T.block("C_local"):
+ T.where(i_0 * 32 + ax0 < 127 and j_0 * 32 + ax1 < 127)
+                        C[i_0 * 32 + ax0, j_0 * 32 + ax1] = C_local[i_0 * 32 + ax0, j_0 * 32 + ax1]
+
+ @T.prim_func
+ def expected(
+ A: T.Buffer((127, 127), "float32"),
+ B: T.Buffer((127, 127), "float32"),
+ C: T.Buffer((127, 127), "float32"),
+ ):
+ for i_0, j_0 in T.grid(4, 4):
+ with T.block(""):
+                T.reads(A[i_0 * 32 : i_0 * 32 + 32, 0:128], B[0:128, j_0 * 32 : j_0 * 32 + 32])
+ T.writes(C[i_0 * 32 : i_0 * 32 + 32, j_0 * 32 : j_0 * 32 + 32])
+ A_local = T.alloc_buffer((32, 128), scope="local")
+ B_local = T.alloc_buffer((128, 32), scope="local")
+ C_local = T.alloc_buffer((32, 32), scope="local")
+ for ax0, ax1 in T.grid(32, 128):
+ with T.block("A_local"):
+ A_local[ax0, ax1] = T.if_then_else(
+                            i_0 * 32 + ax0 < 127, A[i_0 * 32 + ax0, ax1], T.float32(0)
+ )
+ for ax0, ax1 in T.grid(128, 32):
+ with T.block("B_local"):
+ B_local[ax0, ax1] = T.if_then_else(
+                            j_0 * 32 + ax1 < 127, B[ax0, j_0 * 32 + ax1], T.float32(0)
+ )
+ for i_1, j_1, k in T.grid(32, 32, 128):
+ with T.block("compute"):
+ T.where(i_0 * 32 + i_1 < 127 and j_0 * 32 + j_1 < 127)
+ if k == 0:
+ C_local[i_1, j_1] = T.float32(0)
+                        C_local[i_1, j_1] = C_local[i_1, j_1] + A_local[i_1, k] * B_local[k, j_1]
+ for ax0, ax1 in T.grid(32, 32):
+ with T.block("C_local"):
+ T.where(i_0 * 32 + ax0 < 127 and j_0 * 32 + ax1 < 127)
+ C[i_0 * 32 + ax0, j_0 * 32 + ax1] = C_local[ax0, ax1]
+
+
+class TestNotCompactAliasBuffer(BaseCompactTest):
- @tvm.script.ir_module
- class Before:
- @T.prim_func
- def main(x: T.handle, y: T.handle, n: T.int64):
- X = T.match_buffer(x, (T.int64(8), n * T.int64(32)))
- Y = T.match_buffer(y, (T.int64(8), n * T.int64(32)))
- for i, k_0 in T.grid(T.int64(8), n):
- with T.block(""):
- X_global = T.alloc_buffer((T.int64(8), n * T.int64(32)))
- for ax0 in range(T.int64(32)):
- with T.block("X_global"):
-                            X_global[i, k_0 * T.int64(32) + ax0] = X[i, k_0 * T.int64(32) + ax0]
- for k_1 in range(T.int64(32)):
- with T.block("Y"):
-                            Y[i, k_0 * T.int64(32) + k_1] = X_global[i, k_0 * T.int64(32) + k_1]
-
- @tvm.script.ir_module
- class Expected:
- @T.prim_func
- def main(x: T.handle, y: T.handle, n: T.int64):
- X = T.match_buffer(x, (T.int64(8), n * T.int64(32)))
- Y = T.match_buffer(y, (T.int64(8), n * T.int64(32)))
- for i, k_0 in T.grid(T.int64(8), n):
- with T.block(""):
- X_global = T.alloc_buffer((T.int64(1), T.int64(32)))
- for ax0 in range(T.int64(32)):
- with T.block("X_global"):
-                            X_global[T.int64(0), ax0] = X[i, k_0 * T.int64(32) + ax0]
- for k_1 in range(T.int64(32)):
- with T.block("Y"):
-                            Y[i, k_0 * T.int64(32) + k_1] = X_global[T.int64(0), k_1]
-
- mod = Before
- mod = tvm.tir.transform.CompactBufferAllocation()(mod)
- after = tvm.tir.transform.Simplify()(mod)
- tvm.ir.assert_structural_equal(after, Expected)
-
-
-def test_compact_symbolic_bound1():
+ # it is not testcase on block form
+ is_lower_order_free = False
+
+ @T.prim_func
+ def before():
+        """Partially accessed buffer that should not be compacted
+        because the aliasing buffer B exists."""
+ data = T.allocate([1024], "int8")
+ A = T.decl_buffer([1024], "int8", data)
+ B = T.decl_buffer([512], "float16", data)
+ for i in range(10):
+ A[i] = A[i] + T.int8(1)
+ for i in range(10):
+ B[i] = B[i] + T.float16(1)
+
+ expected = before
+
+
+class TestNotCompactBufferWithDifferentDtype(BaseCompactTest):
+
+    # this is not a test case in opaque block form
+
+ @T.prim_func
+ def before():
+        """Partially accessed buffer that should not be compacted
+        because the buffer dtype differs from the underlying allocation's dtype."""
+ data = T.allocate([1024], "int8")
+ A = T.decl_buffer([256], "int32", data)
+ for i in range(10):
+ A[i] = A[i] + 1
+
+ expected = before
+
+
+class TestNonBoolCondition(BaseCompactTest):
+
+    # this is not a test case in opaque block form
+
+ @T.prim_func
+ def before():
+ data = T.allocate([12], "int32")
+ A = T.Buffer([12], "int32", data)
+ for i in range(10):
+ if i:
+ A[i] = A[i] + 1
+
+ @T.prim_func
+ def expected():
+ data = T.allocate([9], "int32")
+ A = T.Buffer([9], "int32", data)
+ for i in range(10):
+ if i:
+ A[i - 1] = A[i - 1] + 1
+
+
+def test_lower_te():
+ x = te.placeholder((1,))
+ y = te.compute((1,), lambda i: x[i] + 2)
+ s = te.create_schedule(y.op)
+ orig_mod = tvm.driver.build_module.schedule_to_module(s, [x, y])
+ mod = tvm.tir.transform.CompactBufferAllocation()(orig_mod)
+    tvm.ir.assert_structural_equal(mod, orig_mod)  # CompactBufferAllocation should do nothing on TE
+
+
+class TestCompactSymbolicBound0:
    """Test symbolic bound that gets compacted to a constant"""
- @tvm.script.ir_module
- class Before:
- @T.prim_func
- def main(x: T.handle, y: T.handle, n: T.int64):
- X = T.match_buffer(x, (T.int64(8), n * T.int64(32)))
- Y = T.match_buffer(y, (T.int64(8), n * T.int64(32)))
- for i, k_0 in T.grid(T.int64(8), n):
- with T.block(""):
- X_global = T.alloc_buffer((T.int64(8), n * T.int64(32)))
+ @T.prim_func
+ def before(x: T.handle, y: T.handle, n: T.int64):
+ X = T.match_buffer(x, (T.int64(8), n * T.int64(32)))
+ Y = T.match_buffer(y, (T.int64(8), n * T.int64(32)))
+ for i, k_0 in T.grid(T.int64(8), n):
+ with T.block(""):
+ X_global = T.alloc_buffer((T.int64(8), n * T.int64(32)))
+ for ax0 in range(T.int64(32)):
with T.block("X_global"):
- for x0 in range(T.int64(32)):
-                            X_global[i, k_0 * T.int64(32) + x0] = X[i, k_0 * T.int64(32) + x0]
+                        X_global[i, k_0 * T.int64(32) + ax0] = X[i, k_0 * T.int64(32) + ax0]
+ for k_1 in range(T.int64(32)):
with T.block("Y"):
- for x1 in range(T.int64(32)):
-                            Y[i, k_0 * T.int64(32) + x1] = X_global[i, k_0 * T.int64(32) + x1]
-
- @tvm.script.ir_module
- class Expected:
- @T.prim_func
- def main(x: T.handle, y: T.handle, n: T.int64):
- X = T.match_buffer(x, (T.int64(8), n * T.int64(32)))
- Y = T.match_buffer(y, (T.int64(8), n * T.int64(32)))
- # with T.block("root"):
- for i, k_0 in T.grid(T.int64(8), n):
- with T.block(""):
- X_global = T.alloc_buffer((T.int64(1), T.int64(32)))
+                        Y[i, k_0 * T.int64(32) + k_1] = X_global[i, k_0 * T.int64(32) + k_1]
+
+ @T.prim_func
+ def expected(x: T.handle, y: T.handle, n: T.int64):
+ X = T.match_buffer(x, (T.int64(8), n * T.int64(32)))
+ Y = T.match_buffer(y, (T.int64(8), n * T.int64(32)))
+ for i, k_0 in T.grid(T.int64(8), n):
+ with T.block(""):
+ X_global = T.alloc_buffer((T.int64(1), T.int64(32)))
+ for ax0 in range(T.int64(32)):
with T.block("X_global"):
- for x0 in range(T.int64(32)):
-                            X_global[T.int64(0), x0] = X[i, k_0 * T.int64(32) + x0]
+                        X_global[T.int64(0), ax0] = X[i, k_0 * T.int64(32) + ax0]
+ for k_1 in range(T.int64(32)):
with T.block("Y"):
- for x1 in range(T.int64(32)):
-                            Y[i, k_0 * T.int64(32) + x1] = X_global[T.int64(0), x1]
+                        Y[i, k_0 * T.int64(32) + k_1] = X_global[T.int64(0), k_1]
- mod = Before
- mod = tvm.tir.transform.CompactBufferAllocation()(mod)
- after = tvm.tir.transform.Simplify()(mod)
- tvm.ir.assert_structural_equal(after, Expected)
+
+class TestCompactSymbolicBound1:
+    """Test symbolic bound that gets compacted to a constant"""
+
+ @T.prim_func
+ def before(x: T.handle, y: T.handle, n: T.int64):
+ X = T.match_buffer(x, (T.int64(8), n * T.int64(32)))
+ Y = T.match_buffer(y, (T.int64(8), n * T.int64(32)))
+ for i, k_0 in T.grid(T.int64(8), n):
+ with T.block(""):
+ X_global = T.alloc_buffer((T.int64(8), n * T.int64(32)))
+ with T.block("X_global"):
+ for x0 in range(T.int64(32)):
+                    X_global[i, k_0 * T.int64(32) + x0] = X[i, k_0 * T.int64(32) + x0]
+ with T.block("Y"):
+ for x1 in range(T.int64(32)):
+                    Y[i, k_0 * T.int64(32) + x1] = X_global[i, k_0 * T.int64(32) + x1]
+
+ @T.prim_func
+ def expected(x: T.handle, y: T.handle, n: T.int64):
+ X = T.match_buffer(x, (T.int64(8), n * T.int64(32)))
+ Y = T.match_buffer(y, (T.int64(8), n * T.int64(32)))
+ # with T.block("root"):
+ for i, k_0 in T.grid(T.int64(8), n):
+ with T.block(""):
+ X_global = T.alloc_buffer((T.int64(1), T.int64(32)))
+ with T.block("X_global"):
+ for x0 in range(T.int64(32)):
+ X_global[T.int64(0), x0] = X[i, k_0 * T.int64(32) + x0]
+ with T.block("Y"):
+ for x1 in range(T.int64(32)):
+ Y[i, k_0 * T.int64(32) + x1] = X_global[T.int64(0), x1]
if __name__ == "__main__":