Lunderberg commented on code in PR #14021:
URL: https://github.com/apache/tvm/pull/14021#discussion_r1124618521
##########
src/tir/analysis/block_access_region_detector.cc:
##########
@@ -170,12 +172,12 @@ void BlockReadWriteDetector::VisitStmt_(const
IfThenElseNode* op) {
VisitExpr(op->condition);
{
// Visit then branch
- With<ConditionalBoundsContext> ctx(op->condition, &dom_map_, &hint_map_,
true);
+ With<ConditionalBoundsContext> ctx(op->condition, &dom_map_, &hint_map_,
&pending_conditions_);
Review Comment:
Should these `With<ConditionalBoundsContext>` scopes also be applied to
`BlockRealizeNode::predicate`?
##########
src/tir/transforms/compact_buffer_region.cc:
##########
@@ -481,47 +520,67 @@ class BufferCompactor : public StmtExprMutator {
return std::move(block);
}
- Array<Buffer> RewriteAllocBuffer(const Array<Buffer>& buffers) {
- Array<Buffer> result;
- result.reserve(buffers.size());
- for (const Buffer& buffer : buffers) {
- auto it = buffer_info_.find(buffer);
- ICHECK(it != buffer_info_.end());
- BufferAllocInfo& info = it->second;
- Array<PrimExpr> shape;
- shape.reserve(info.region.size());
- for (const Range& range : info.region) {
- shape.push_back(range->extent);
- }
- Array<PrimExpr> strides;
- if (info.dim_aligns.size()) {
- ICHECK(info.dim_aligns.size() == shape.size());
- strides.resize(shape.size());
- PrimExpr stride = make_const(shape[0].dtype(), 1);
- for (size_t i = shape.size(); i != 0; --i) {
- size_t dim = i - 1;
- if (info.dim_aligns[dim].align_factor != 0) {
- PrimExpr factor = make_const(stride.dtype(),
info.dim_aligns[dim].align_factor);
- PrimExpr offset = make_const(stride.dtype(),
info.dim_aligns[dim].align_offset);
- stride = stride + indexmod(factor + offset - indexmod(stride,
factor), factor);
- }
- strides.Set(dim, stride);
- stride = stride * shape[dim];
+ Stmt VisitStmt_(const DeclBufferNode* op) final {
+ Buffer new_buffer = RewriteAllocBuffer(op->buffer);
+ auto n = CopyOnWrite(op);
+ n->buffer = std::move(new_buffer);
+ n->body = VisitStmt(op->body);
+ return DeclBuffer(n);
+ }
+
+ Stmt VisitStmt_(const AllocateNode* op) final {
Review Comment:
Should we also support `AllocateConstNode`?
##########
src/tir/transforms/compact_buffer_region.cc:
##########
@@ -247,37 +218,61 @@ class BufferAccessRegionCollector : public
StmtExprVisitor {
}
// Step 6. Update buffer_access_region_ from relaxed_accesses_ for inner
buffers.
for (const Buffer& buffer : op->alloc_buffers) {
- auto it = relaxed_accesses_.find(buffer);
- ICHECK(it != relaxed_accesses_.end())
- << buffer << " is allocated but not accessed within block scope";
- const NDIntSet& nd_int_set = it->second;
- buffer_access_region_[buffer] =
SimplifyAndNarrowBufferRegionFromNDIntSet(
- nd_int_set, buffer->shape, &dom_analyzer_, ancestor_loops_);
+ SimplifyAndNarrowBufferRegionFromNDIntSet(buffer);
}
}
void VisitStmt_(const BlockRealizeNode* op) final {
- PrimExpr cur_predicate = predicate_in_scope;
- predicate_in_scope = op->predicate;
+ With<ConditionalBoundsContext> ctx(op->predicate, &dom_map_, &hint_map_,
&pending_conditions_);
+ StmtExprVisitor::VisitStmt_(op);
+ }
+
+ void VisitStmt_(const DeclBufferNode* op) final {
Review Comment:
If there is no DeclBuffer present for a buffer, would the resizing still work
correctly if the pass instead pre-visited the statement to collect all buffers
used within the body? That could avoid depending on DeclBuffer, which is not
present at every point in the lowering flow.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]