wrongtest-intellif commented on code in PR #13033: URL: https://github.com/apache/tvm/pull/13033#discussion_r994085307
########## src/tir/schedule/primitive/rolling_buffer.cc: ########## @@ -0,0 +1,443 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include <functional> + +#include "../ir_comparator.h" +#include "../utils.h" + +namespace tvm { +namespace tir { + +namespace { + +struct RollingBufferInfo { + Buffer old_buffer; + Buffer new_buffer; + int rolling_axis; + int rolling_extent; + std::vector<int> axis_overlaps; + std::vector<Optional<Var>> axis_iter_vars; + /*! \brief The map used for ScheduleStateNode::Replace. 
*/ + Map<Block, Block> block_reuse; +}; + +BufferRegion GetRelaxedBufferRegion(const BlockRealize& realize, const BufferRegion& buffer_region, + const Map<Var, arith::IntSet>& dom_map) { + Array<arith::IntSet> relaxed_intsets = + arith::EvalSet(Substitute(buffer_region->region, GetBindings(realize)), dom_map); + Region relaxed_region; + relaxed_region.reserve(relaxed_intsets.size()); + for (size_t i = 0; i < relaxed_intsets.size(); ++i) { + relaxed_region.push_back( + relaxed_intsets[i].CoverRange(Range::FromMinExtent(0, buffer_region->buffer->shape[i]))); + } + return BufferRegion(buffer_region->buffer, relaxed_region); +} + +class RollingBufferMatchError : public ScheduleError { + public: + RollingBufferMatchError(IRModule mod, Block block, BufferRegion buffer_region) + : mod_(mod), block_(block), buffer_region_(buffer_region) {} + String FastErrorString() const final { + return "ScheduleError: rolling_buffer expect the buffer region to have at least one dimention" + "matching the rolling pattern such as: hh.outer * stride + hh.inner"; + } + String DetailRenderTemplate() const final { + std::ostringstream os; + os << "The target buffer " << buffer_region_->buffer->name << " with region " + << buffer_region_->region + << " should have at least one dimension range that matches a rolling pattern " + "such as hh.outer * stride + hh.inner. "; + return os.str(); + } + + IRModule mod() const final { return mod_; } + Array<ObjectRef> LocationsOfInterest() const final { return {block_}; } + + private: + IRModule mod_; + Block block_; + BufferRegion buffer_region_; +}; + +class RollingBufferInsertionError : public ScheduleError { + public: + RollingBufferInsertionError(IRModule mod, Buffer buffer, Block block) + : mod_(mod), buffer_(std::move(buffer)), block_(block) {} + String FastErrorString() const final { + return "ScheduleError: rolling_buffer injection is invalid, the lca of the access " + "location of the target buffer is not a for loop. 
"; + } + + String DetailRenderTemplate() const final { + std::ostringstream os; + os << "rolling_buffer injection is invalid. The block {0} should be tiled so that " + << "the lca of the access location of the target buffer " << buffer_->name + << " is a for loop. "; + return os.str(); + } + IRModule mod() const final { return mod_; } + Array<ObjectRef> LocationsOfInterest() const final { return {block_}; } + + private: + IRModule mod_; + Buffer buffer_; + Block block_; +}; + +class RollingBufferInfoCollector { + public: + static RollingBufferInfo CheckAndGetRollingBufferInfo(const IRModule& mod, + const StmtSRef& block_sref, + const BufferRegion& buffer_region) { + RollingBufferInfoCollector collector; + if (!collector.MatchRollingBuffer(block_sref, buffer_region)) { + const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); + throw RollingBufferMatchError(mod, GetRef<Block>(block), buffer_region); + } + return collector.info_; + } + + private: + bool MatchRollingBuffer(const StmtSRef& block_sref, const BufferRegion& buffer_region) { + const Buffer& buffer = buffer_region->buffer; + const Region& region = buffer_region->region; + + std::vector<Optional<Var>> bound_iter_vars; + std::vector<int> bound_overlaps; + auto stride = 0; + auto divisor = 1; + Optional<Var> iter_var; + for (auto bound : region) { + divisor = 1; + if (auto floor_div = bound->min.as<FloorDivNode>()) { + // Handle the case of fractional strides + // They take this form: floordiv(hh.outer, 2) + // Strip the floordiv and keep track of the divisor + divisor = Downcast<IntImm>(floor_div->b)->value; + bound = Range::FromMinExtent(floor_div->a, bound->extent, bound->span); + } + if (bound->min.as<IntImmNode>()) { + // If the bound is an int, we can't roll over it + iter_var = NullOpt; + } else if (auto var = bound->min.as<VarNode>()) { + // If the bound is just a Var, that implies the stride is 1 + iter_var = GetRef<Var>(var); + stride = 1; + } else { + // Otherwise, it's the iter var multiplied by 
the stride + // If not we're in unknown behaviour + if (auto mul = bound->min.as<MulNode>()) { + if (mul->a->IsInstance<VarNode>() && mul->b->IsInstance<IntImmNode>()) { + iter_var = Downcast<Var>(mul->a); + stride = Downcast<IntImm>(mul->b)->value; + } else { + return false; + } + } else { + return false; + } + } + stride = std::ceil(static_cast<float>(stride) / divisor); + auto bound_overlap = 0; + if (iter_var.defined()) { + auto extent = bound->extent.as<IntImmNode>(); + ICHECK(extent); + bound_overlap = extent->value - stride; + // Since Pass CompactBufferAllocation will be responsible for compacting the buffer + // allocation region, there is no need to roll over the axis where the overlap is not + // positive, so reset iter_var to NullOpt. + if (bound_overlap <= 0) { + iter_var = NullOpt; + } + } + bound_iter_vars.push_back(iter_var); + bound_overlaps.push_back(bound_overlap); + } + Array<StmtSRef> loop_srefs = GetLoops(block_sref); + // Pick the outermost iter_var that's mentioned in the bounds + // to be the rolling axis + Optional<Var> roll_iter_var; + int roll_axis; + for (const tir::StmtSRef& loop_sref : loop_srefs) { + auto loop_var = loop_sref->StmtAs<ForNode>()->loop_var; + + auto it{std::find_if(bound_iter_vars.begin(), bound_iter_vars.end(), [&](Optional<Var> var) { + return var && (var.get() == loop_var.get()); + })}; + if (it != bound_iter_vars.end()) { + auto i = std::distance(bound_iter_vars.begin(), it); + roll_iter_var = loop_var; + roll_axis = i; + break; + } + } + + if (!roll_iter_var.defined()) { + return false; + } + Array<PrimExpr> new_shape = buffer->shape; + new_shape.Set(roll_axis, region[roll_axis]->extent); + Buffer new_buffer = buffer; + new_buffer.CopyOnWrite()->shape = new_shape; + + info_.old_buffer = buffer; + info_.new_buffer = new_buffer; + info_.rolling_axis = roll_axis; + info_.rolling_extent = static_cast<int>(Downcast<IntImm>(region[roll_axis]->extent)->value); + info_.axis_overlaps = bound_overlaps; + 
info_.axis_iter_vars = bound_iter_vars; + + return true; + } + + RollingBufferInfo info_; +}; + +class RollingBufferRewriter : public StmtExprMutator { + public: + static Stmt Rewrite(const StmtSRef& scope_sref, RollingBufferInfo* info) { + RollingBufferRewriter rewriter(scope_sref, info); + return rewriter(GetRef<Stmt>(scope_sref->stmt)); + } + + private: + explicit RollingBufferRewriter(const StmtSRef& scope_sref, RollingBufferInfo* info) + : scope_sref_(scope_sref), info_(info) {} + + void RewriteAccessRegion(Array<BufferRegion>* old_access_regions, + const Array<BufferRegion>& infered_access_regions) { + auto fmutate = [this, &infered_access_regions](const BufferRegion& buffer_region) { + if (buffer_region->buffer.same_as(info_->old_buffer)) { + ICHECK(infered_access_regions.size() == 1); + return infered_access_regions[0]; + } + return buffer_region; + }; + (*old_access_regions).MutateByApply(fmutate); + } + + Stmt VisitStmt_(const BlockNode* block) final { + Block old_stmt = GetRef<Block>(block); + Block stmt = Downcast<Block>(StmtExprMutator::VisitStmt_(block)); + if (block == scope_sref_->stmt) { + ObjectPtr<BlockNode> n = make_object<BlockNode>(*stmt.as<BlockNode>()); + + Array<Buffer> new_alloc_buffers; + for (const Buffer& buffer : stmt->alloc_buffers) { + if (buffer != info_->old_buffer) { + new_alloc_buffers.push_back(buffer); + } else { + new_alloc_buffers.push_back(info_->new_buffer); + } + } + n->alloc_buffers = std::move(new_alloc_buffers); + stmt = Block(n); + } else { + Array<IterVar> new_iter_bindings; + for (size_t i = 0; i < stmt->iter_vars.size(); ++i) { + auto old_iter_var = stmt->iter_vars[i]; + if (static_cast<int>(i) == info_->rolling_axis) { + // All inner loops of the rolling axis has a loop carried dependency + // (i.e. 
each iteration calculation of the rolling axis depends on + // the calculation results of all the historical iterations of inner loops), + // so annotate the iteration type of the rolling axis as 'opaque', + // avoid the iterative range of its inner loop from being compressed + // during lowering phase. + IterVar new_iter_var = + IterVar(old_iter_var->dom, old_iter_var->var, IterVarType::kOpaque); + new_iter_bindings.push_back(new_iter_var); + } else { + new_iter_bindings.push_back(old_iter_var); + } + } + Map<Var, Buffer> buffer_data_to_buffer = {{info_->new_buffer->data, info_->new_buffer}}; + auto infered_access_regions = GetBlockReadWriteRegion(stmt, buffer_data_to_buffer); + + BlockNode* n = stmt.CopyOnWrite(); + n->iter_vars = std::move(new_iter_bindings); + RewriteAccessRegion(&n->reads, infered_access_regions[0]); + RewriteAccessRegion(&n->writes, infered_access_regions[1]); + } + info_->block_reuse.Set(old_stmt, stmt); + return std::move(stmt); + } + + Stmt VisitStmt_(const BlockRealizeNode* realize) final { + BlockRealize stmt = Downcast<BlockRealize>(StmtExprMutator::VisitStmt_(realize)); + // Append block predicate to avoid recomputing elements. 
+ if (rewrite_block_predicate_) { + rewrite_block_predicate_ = false; + PrimExpr condition = stmt->predicate; + for (size_t i = 0; i < info_->axis_iter_vars.size(); ++i) { + auto iter_var = info_->axis_iter_vars[i]; + if (iter_var && info_->axis_overlaps[i] > 0) { + Var var = iter_var.value(); + const Map<Var, arith::IntSet> dmap = {std::make_pair(var, arith::IntSet::Interval(0, 0))}; + auto iter_value = realize->iter_values[i]; + arith::Analyzer analyzer; + auto term_2 = analyzer.int_set(iter_value, dmap).min(); + condition = analyzer.Simplify( + And(condition, Or(LT(var, 1), GE(term_2, info_->axis_overlaps[i])))); + } + } + BlockRealizeNode* n = stmt.CopyOnWrite(); + n->predicate = condition; + } + return std::move(stmt); + } + + Stmt VisitStmt_(const BufferStoreNode* op) final { + BufferStore stmt = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op)); + if (stmt->buffer.same_as(info_->old_buffer)) { + Array<PrimExpr> new_indices; + new_indices.reserve(stmt->indices.size()); + // First modify the access indices to use modulo arithmetic + // for the rolling axis + for (size_t i = 0; i < stmt->indices.size(); ++i) { + auto index = stmt->indices[i]; + if (static_cast<int>(i) == info_->rolling_axis) { + new_indices.push_back(FloorMod(index, info_->rolling_extent)); + } else { + new_indices.push_back(index); + } + } + BufferStoreNode* n = stmt.CopyOnWrite(); + // Replace the stored buffer with the new buffer. + n->buffer = info_->new_buffer; + n->indices = std::move(new_indices); + // Need to add predicate to the current block to avoid recomputing elements. 
+ rewrite_block_predicate_ = true; + } + return std::move(stmt); + } + + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + BufferLoad stmt = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op)); + if (stmt->buffer.same_as(info_->old_buffer)) { + Array<PrimExpr> new_indices; + new_indices.reserve(stmt->indices.size()); + for (size_t i{0}; i < stmt->indices.size(); ++i) { + auto index = stmt->indices[i]; + if (static_cast<int>(i) == info_->rolling_axis) { + new_indices.push_back(FloorMod(index, info_->rolling_extent)); + } else { + new_indices.push_back(index); + } + } + BufferLoadNode* n = stmt.CopyOnWrite(); + // Replace the loaded buffer with the new buffer. + n->buffer = info_->new_buffer; + n->indices = std::move(new_indices); + } + return std::move(stmt); + } + + private: + const StmtSRef& scope_sref_; + RollingBufferInfo* info_; + bool rewrite_block_predicate_ = false; +}; + +} // namespace + +void RollingBuffer(ScheduleState self, const StmtSRef& block_sref, int write_buffer_index) { + /*! + * Check + * - The block is not an output block. + * - The block is tiled and there is access overlap between adjacent tiles. + * Mutate + * - Select the outermost rollable axis appeared in the block's loop nest + * as the 'rolling axis', trim the target buffer from the rolling axis. + * - Use modulo arithmetic to modify the target buffer's read and load + * indices to circularize the buffer along the rolling dimension. + * - Append block predicate to avoid recomputing overlapping elements. + */ + Map<Var, arith::IntSet> dom_map; + const BlockRealize& realize = GetBlockRealize(self, block_sref); + const Block& block = realize->block; + + // Step 1. Checking index, getting the target buffer region and the parent scope. + const BufferRegion& buffer_region = + GetNthAccessBufferRegion(self, block, write_buffer_index, BufferIndexType::kWrite); + StmtSRef scope_root_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false); + // Step 2. 
Check the target block is not an output block. + CheckNotOutputBlock(self, block_sref, scope_root_sref); + + // Step 3. Find the lca of the access location of the target buffer and relax the buffer + Array<StmtSRef> loop_srefs = GetLoops(block_sref); + Array<StmtSRef> consumers_sref = GetConsumers(self, block_sref); + consumers_sref.push_back(block_sref); + StmtSRef lca = GetSRefLowestCommonAncestor(consumers_sref); + if (!lca->StmtAs<ForNode>()) { + throw RollingBufferInsertionError(self->mod, buffer_region->buffer, block); + } + + for (auto it = loop_srefs.rbegin(); it != loop_srefs.rend(); ++it) { + auto stmt = *it; + // Stop at the lca of all the rolling_buffer access points; + if (stmt == lca) { + break; + } + For cur_loop = GetRef<For>(stmt->StmtAs<ForNode>()); + Range range = Range::FromMinExtent(cur_loop->min, cur_loop->extent); + dom_map.Set(cur_loop->loop_var, arith::IntSet::FromRange(range)); + } + BufferRegion relaxed_region = GetRelaxedBufferRegion(realize, buffer_region, dom_map); + + // Step 4. Find an valid rolling axis and collect bound overlaps on the target buffer. Review Comment: typo: a valid ########## tests/python/unittest/test_tir_schedule_rolling_buffer.py: ########## @@ -0,0 +1,534 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-function-docstring,missing-module-docstring +import numpy as np +import tvm +import tvm.testing +from tvm import tir +from tvm.script import tir as T +from tvm.tir.schedule.testing import verify_trace_roundtrip +import pytest + + +def check_rolling_buffer( + sch: tir.Schedule, origin: tir.PrimFunc, expected: tir.PrimFunc, check_run=False +): + scheduled = sch.mod["main"] + tvm.ir.assert_structural_equal(scheduled, expected) + verify_trace_roundtrip(sch, origin) + if check_run: + in_buffer = origin.buffer_map[origin.params[0]] + out_buffer = origin.buffer_map[origin.params[1]] + in_shape = [int(_) for _ in in_buffer.shape] + out_shape = [int(_) for _ in out_buffer.shape] + x = tvm.nd.array(np.random.uniform(0, 64, in_shape).astype(in_buffer.dtype)) + y0 = tvm.nd.array(np.zeros(out_shape).astype(out_buffer.dtype)) + y1 = tvm.nd.array(np.zeros(out_shape).astype(out_buffer.dtype)) + f_origin = tvm.build(origin) + f_scheduled = tvm.build(scheduled) + f_origin(x, y0) + f_scheduled(x, y1) + tvm.testing.assert_allclose(y0.numpy(), y1.numpy()) + + +def _tile_nd(s, tile, block_name): + outer_indices = [] + inner_indices = [] + block = s.get_block(block_name) + loops = s.get_loops(block) + for i, size in enumerate(tile): + outer, inner = s.split(loops[i], [None, size]) + outer_indices.append(outer) + inner_indices.append(inner) + + s.reorder(*outer_indices, *inner_indices) + return outer_indices, inner_indices + + +def test_1d_rolling_buffer(): + @T.prim_func + def before(A: T.Buffer[(4, 12), "int32"], C: T.Buffer[(4, 8), "int32"]): + B = T.alloc_buffer((4, 10), "int32") + for c in T.serial(4): + for i in T.serial(0, 10): + for k in T.serial(3): + with T.block("B"): + cc, vi, vk = T.axis.remap("SSR", [c, i, k]) + with T.init(): + B[cc, vi] = 0 + B[cc, vi] = B[cc, vi] + A[cc, vi + vk] + for i in T.serial(0, 8): + for k in T.serial(3): + with 
T.block("C"): + cc, vi, vk = T.axis.remap("SSR", [c, i, k]) + with T.init(): + C[cc, vi] = 0 + C[cc, vi] = C[cc, vi] + B[cc, vi + vk] + + @T.prim_func + def expected(A: T.Buffer[(4, 12), "int32"], C: T.Buffer[(4, 8), "int32"]): + B = T.alloc_buffer([4, 6], dtype="int32") + for c, i_0 in T.grid(4, 2): + for ax0, ax1 in T.grid(6, 3): + with T.block("B"): + T.where(i_0 < 1 or 2 <= ax0) + cc = T.axis.spatial(4, c) + vi = T.axis.opaque(10, i_0 * 4 + ax0) + vk = T.axis.reduce(3, ax1) + T.reads(A[cc, vi + vk]) + T.writes(B[cc, vi % 6]) + with T.init(): + B[cc, vi % 6] = 0 + B[cc, vi % 6] = B[cc, vi % 6] + A[cc, vi + vk] + for i_1, k in T.grid(4, 3): + with T.block("C"): + cc = T.axis.spatial(4, c) + vi = T.axis.opaque(8, i_0 * 4 + i_1) + vk = T.axis.reduce(3, k) + T.reads(B[cc, (vi + vk) % 6]) + T.writes(C[cc, vi]) + with T.init(): + C[cc, vi] = 0 + C[cc, vi] = C[cc, vi] + B[cc, (vi + vk) % 6] + + sch = tir.Schedule(before, debug_mask="all") + _, i, _ = sch.get_loops(sch.get_block("C")) + io, _ = sch.split(i, [2, 4]) + sch.compute_at(sch.get_block("B"), io) + sch.rolling_buffer(sch.get_block("B"), 0) + check_rolling_buffer(sch, before, expected, check_run=True) + + +@T.prim_func +def cascade_2_max_pool2d(A: T.Buffer[(1, 12, 12, 16), "int8"], C: T.Buffer[(1, 8, 8, 16), "int8"]): + B = T.alloc_buffer([1, 10, 10, 16], dtype="int8") + for i0, i1, i2, i3, i4, i5 in T.grid(1, 10, 10, 16, 3, 3): + with T.block("B"): + ax0, ax1, ax2, ax3, rv0, rv1 = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5]) + T.reads(A[ax0, ax1 + rv0, ax2 + rv1, ax3]) + T.writes(B[ax0, ax1, ax2, ax3]) + with T.init(): + B[ax0, ax1, ax2, ax3] = T.int8(-128) + B[ax0, ax1, ax2, ax3] = T.max(B[ax0, ax1, ax2, ax3], A[ax0, ax1 + rv0, ax2 + rv1, ax3]) + for i0, i1, i2, i3, i4, i5 in T.grid(1, 8, 8, 16, 3, 3): + with T.block("C"): + ax0, ax1, ax2, ax3, rv0, rv1 = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5]) + T.reads(B[ax0, ax1 + rv0, ax2 + rv1, ax3]) + T.writes(C[ax0, ax1, ax2, ax3]) + with T.init():
+ C[ax0, ax1, ax2, ax3] = T.int8(-128) + C[ax0, ax1, ax2, ax3] = T.max(C[ax0, ax1, ax2, ax3], B[ax0, ax1 + rv0, ax2 + rv1, ax3]) + + +@T.prim_func +def cascade_3_max_pool2d_with_stride( + A: T.Buffer[(1, 24, 24, 16), "int8"], C: T.Buffer[(1, 8, 8, 16), "int8"] +): + B_0 = T.alloc_buffer([1, 22, 22, 16], dtype="int8") + B_1 = T.alloc_buffer([1, 10, 10, 16], dtype="int8") + for i0, i1, i2, i3, i4, i5 in T.grid(1, 22, 22, 16, 3, 3): + with T.block("B_0"): + ax0, ax1, ax2, ax3, rv0, rv1 = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5]) + T.reads(A[ax0, ax1 + rv0, ax2 + rv1, ax3]) + T.writes(B_0[ax0, ax1, ax2, ax3]) + with T.init(): + B_0[ax0, ax1, ax2, ax3] = T.int8(-128) + B_0[ax0, ax1, ax2, ax3] = T.max( + B_0[ax0, ax1, ax2, ax3], A[ax0, ax1 + rv0, ax2 + rv1, ax3] + ) + for i0, i1, i2, i3, i4, i5 in T.grid(1, 10, 10, 16, 3, 3): + with T.block("B_1"): + ax0, ax1, ax2, ax3, rv0, rv1 = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5]) + T.reads(B_0[ax0, ax1 * 2 + rv0, ax2 * 2 + rv1, ax3]) + T.writes(B_1[ax0, ax1, ax2, ax3]) + with T.init(): + B_1[ax0, ax1, ax2, ax3] = T.int8(-128) + B_1[ax0, ax1, ax2, ax3] = T.max( + B_1[ax0, ax1, ax2, ax3], B_0[ax0, ax1 * 2 + rv0, ax2 * 2 + rv1, ax3] + ) + for i0, i1, i2, i3, i4, i5 in T.grid(1, 8, 8, 16, 3, 3): + with T.block("C"): + ax0, ax1, ax2, ax3, rv0, rv1 = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5]) + T.reads(B_1[ax0, ax1 + rv0, ax2 + rv1, ax3]) + T.writes(C[ax0, ax1, ax2, ax3]) + with T.init(): + C[ax0, ax1, ax2, ax3] = T.int8(-128) + C[ax0, ax1, ax2, ax3] = T.max( + C[ax0, ax1, ax2, ax3], B_1[ax0, ax1 + rv0, ax2 + rv1, ax3] + ) + + +def test_cascade_max_pool2d_w_tiled(): + @T.prim_func + def expected(A: T.Buffer[(1, 12, 12, 16), "int8"], C: T.Buffer[(1, 8, 8, 16), "int8"]): + B = T.alloc_buffer([1, 10, 6, 16], dtype="int8") + for i0_0, i1_0, i2_0, i3_0 in T.grid(1, 1, 2, 1): + for ax0, ax1, ax2, ax3, ax4 in T.grid(10, 6, 16, 3, 3): + with T.block("B"): + T.where(i2_0 < 1 or 2 <= ax1) + ax0_1 = T.axis.spatial(1,
0) + ax1_1 = T.axis.spatial(10, ax0) + ax2_1 = T.axis.opaque(10, i2_0 * 4 + ax1) + ax3_1, rv0, rv1 = T.axis.remap("SRR", [ax2, ax3, ax4]) + T.reads(A[ax0_1, ax1_1 + rv0, ax2_1 + rv1, ax3_1]) + T.writes(B[ax0_1, ax1_1, ax2_1 % 6, ax3_1]) + with T.init(): + B[ax0_1, ax1_1, ax2_1 % 6, ax3_1] = T.int8(-128) + B[ax0_1, ax1_1, ax2_1 % 6, ax3_1] = T.max( + B[ax0_1, ax1_1, ax2_1 % 6, ax3_1], A[ax0_1, ax1_1 + rv0, ax2_1 + rv1, ax3_1] + ) + for i0_1, i1_1, i2_1, i3_1, i4, i5 in T.grid(1, 8, 4, 16, 3, 3): + with T.block("C"): + ax0 = T.axis.spatial(1, i0_0 + i0_1) + ax1 = T.axis.spatial(8, i1_0 * 8 + i1_1) + ax2 = T.axis.opaque(8, i2_0 * 4 + i2_1) + ax3 = T.axis.spatial(16, i3_0 * 16 + i3_1) + rv0, rv1 = T.axis.remap("RR", [i4, i5]) + T.reads(B[ax0, ax1 + rv0, (ax2 + rv1) % 6, ax3]) + T.writes(C[ax0, ax1, ax2, ax3]) + with T.init(): + C[ax0, ax1, ax2, ax3] = T.int8(-128) + C[ax0, ax1, ax2, ax3] = T.max( + C[ax0, ax1, ax2, ax3], B[ax0, ax1 + rv0, (ax2 + rv1) % 6, ax3] + ) + + sch = tir.Schedule(cascade_2_max_pool2d, debug_mask="all") + oi, _ = _tile_nd(sch, [1, 8, 4, 16], "C") + sch.compute_at(sch.get_block("B"), oi[-1]) + sch.rolling_buffer(sch.get_block("B"), 0) + check_rolling_buffer(sch, cascade_2_max_pool2d, expected, check_run=True) + + +def test_cascade_max_pool2d_h_tiled(): + @T.prim_func + def expected(A: T.Buffer[(1, 12, 12, 16), "int8"], C: T.Buffer[(1, 8, 8, 16), "int8"]): + B = T.alloc_buffer([1, 6, 10, 16], dtype="int8") + for i0_0, i1_0, i2_0, i3_0 in T.grid(1, 2, 1, 1): + for ax0, ax1, ax2, ax3, ax4 in T.grid(6, 10, 16, 3, 3): + with T.block("B"): + T.where(i1_0 < 1 or 2 <= ax0) + ax0_1 = T.axis.spatial(1, 0) + ax1_1 = T.axis.opaque(10, i1_0 * 4 + ax0) + ax2_1 = T.axis.spatial(10, ax1) + ax3_1, rv0, rv1 = T.axis.remap("SRR", [ax2, ax3, ax4]) + T.reads(A[ax0_1, ax1_1 + rv0, ax2_1 + rv1, ax3_1]) + T.writes(B[ax0_1, ax1_1 % 6, ax2_1, ax3_1]) + with T.init(): + B[ax0_1, ax1_1 % 6, ax2_1, ax3_1] = T.int8(-128) + B[ax0_1, ax1_1 % 6, ax2_1, ax3_1] = T.max( + B[ax0_1, 
ax1_1 % 6, ax2_1, ax3_1], A[ax0_1, ax1_1 + rv0, ax2_1 + rv1, ax3_1] + ) + for i0_1, i1_1, i2_1, i3_1, i4, i5 in T.grid(1, 4, 8, 16, 3, 3): + with T.block("C"): + ax0 = T.axis.spatial(1, i0_0 + i0_1) + ax1 = T.axis.opaque(8, i1_0 * 4 + i1_1) + ax2 = T.axis.spatial(8, i2_0 * 8 + i2_1) + ax3 = T.axis.spatial(16, i3_0 * 16 + i3_1) + rv0, rv1 = T.axis.remap("RR", [i4, i5]) + T.reads(B[ax0, (ax1 + rv0) % 6, ax2 + rv1, ax3]) + T.writes(C[ax0, ax1, ax2, ax3]) + with T.init(): + C[ax0, ax1, ax2, ax3] = T.int8(-128) + C[ax0, ax1, ax2, ax3] = T.max( + C[ax0, ax1, ax2, ax3], B[ax0, (ax1 + rv0) % 6, ax2 + rv1, ax3] + ) + + sch = tir.Schedule(cascade_2_max_pool2d, debug_mask="all") + io, _ = _tile_nd(sch, [1, 4, 8, 16], "C") + sch.compute_at(sch.get_block("B"), io[-1]) + sch.rolling_buffer(sch.get_block("B"), 0) + check_rolling_buffer(sch, cascade_2_max_pool2d, expected, check_run=True) + + +def test_cascade_max_pool2d_h_w_c_tiled(): + @T.prim_func + def expected(A: T.Buffer[(1, 12, 12, 16), "int8"], C: T.Buffer[(1, 8, 8, 16), "int8"]): + B = T.alloc_buffer([1, 6, 10, 16], dtype="int8") + for i0_0, i1_0, i2_0, i3_0 in T.grid(1, 2, 2, 2): + for ax0, ax1, ax2, ax3, ax4 in T.grid(6, 6, 8, 3, 3): + with T.block("B"): + T.where((i1_0 < 1 or 2 <= ax0) and (i2_0 < 1 or 2 <= ax1)) + ax0_1 = T.axis.spatial(1, 0) + ax1_1 = T.axis.opaque(10, i1_0 * 4 + ax0) + ax2_1 = T.axis.spatial(10, i2_0 * 4 + ax1) + ax3_1 = T.axis.spatial(16, i3_0 * 8 + ax2) + rv0, rv1 = T.axis.remap("RR", [ax3, ax4]) + T.reads(A[ax0_1, ax1_1 + rv0, ax2_1 + rv1, ax3_1]) + T.writes(B[ax0_1, ax1_1 % 6, ax2_1, ax3_1]) + with T.init(): + B[ax0_1, ax1_1 % 6, ax2_1, ax3_1] = T.int8(-128) + B[ax0_1, ax1_1 % 6, ax2_1, ax3_1] = T.max( + B[ax0_1, ax1_1 % 6, ax2_1, ax3_1], A[ax0_1, ax1_1 + rv0, ax2_1 + rv1, ax3_1] + ) + for i0_1, i1_1, i2_1, i3_1, i4, i5 in T.grid(1, 4, 4, 8, 3, 3): + with T.block("C"): + ax0 = T.axis.spatial(1, i0_0 + i0_1) + ax1 = T.axis.opaque(8, i1_0 * 4 + i1_1) + ax2 = T.axis.spatial(8, i2_0 * 4 + i2_1) + 
ax3 = T.axis.spatial(16, i3_0 * 8 + i3_1) + rv0, rv1 = T.axis.remap("RR", [i4, i5]) + T.reads(B[ax0, (ax1 + rv0) % 6, ax2 + rv1, ax3]) + T.writes(C[ax0, ax1, ax2, ax3]) + with T.init(): + C[ax0, ax1, ax2, ax3] = T.int8(-128) + C[ax0, ax1, ax2, ax3] = T.max( + C[ax0, ax1, ax2, ax3], B[ax0, (ax1 + rv0) % 6, ax2 + rv1, ax3] + ) + + sch = tir.Schedule(cascade_2_max_pool2d, debug_mask="all") + io, _ = _tile_nd(sch, [1, 4, 4, 8], "C") + sch.compute_at(sch.get_block("B"), io[-1]) + sch.rolling_buffer(sch.get_block("B"), 0) + check_rolling_buffer(sch, cascade_2_max_pool2d, expected, check_run=True) + + +def test_cascade_max_pool2d_non_perfect_tiled(): + @T.prim_func + def expected(A: T.Buffer[(1, 12, 12, 16), "int8"], C: T.Buffer[(1, 8, 8, 16), "int8"]) -> None: + B = T.alloc_buffer([1, 8, 10, 16], dtype="int8") + for i0_0, i1_0, i2_0, i3_0 in T.grid(1, 2, 2, 1): + for ax0, ax1, ax2, ax3, ax4 in T.grid(8, 8, 16, 3, 3): + with T.block("B"): + T.where( + i1_0 * 6 + ax0 < 10 + and i2_0 * 6 + ax1 < 10 + and (i1_0 < 1 or 2 <= ax0) + and (i2_0 < 1 or 2 <= ax1) + ) + ax0_1 = T.axis.spatial(1, 0) + ax1_1 = T.axis.opaque(10, i1_0 * 6 + ax0) + ax2_1 = T.axis.spatial(10, i2_0 * 6 + ax1) + ax3_1, rv0, rv1 = T.axis.remap("SRR", [ax2, ax3, ax4]) + T.reads(A[ax0_1, ax1_1 + rv0, ax2_1 + rv1, ax3_1]) + T.writes(B[ax0_1, ax1_1 % 8, ax2_1, ax3_1]) + with T.init(): + B[ax0_1, ax1_1 % 8, ax2_1, ax3_1] = T.int8(-128) + B[ax0_1, ax1_1 % 8, ax2_1, ax3_1] = T.max( + B[ax0_1, ax1_1 % 8, ax2_1, ax3_1], A[ax0_1, ax1_1 + rv0, ax2_1 + rv1, ax3_1] + ) + for i0_1, i1_1, i2_1, i3_1, i4, i5 in T.grid(1, 6, 6, 16, 3, 3): + with T.block("C"): + T.where(i1_0 * 6 + i1_1 < 8 and i2_0 * 6 + i2_1 < 8) + ax0 = T.axis.spatial(1, i0_0 + i0_1) + ax1 = T.axis.opaque(8, i1_0 * 6 + i1_1) + ax2 = T.axis.spatial(8, i2_0 * 6 + i2_1) + ax3 = T.axis.spatial(16, i3_0 * 16 + i3_1) + rv0, rv1 = T.axis.remap("RR", [i4, i5]) + T.reads(B[ax0, (ax1 + rv0) % 8, ax2 + rv1, ax3]) + T.writes(C[ax0, ax1, ax2, ax3]) + with T.init(): + 
C[ax0, ax1, ax2, ax3] = T.int8(-128) + C[ax0, ax1, ax2, ax3] = T.max( + C[ax0, ax1, ax2, ax3], B[ax0, (ax1 + rv0) % 8, ax2 + rv1, ax3] + ) + + sch = tir.Schedule(cascade_2_max_pool2d, debug_mask="all") + io, _ = _tile_nd(sch, [1, 6, 6, 16], "C") + sch.compute_at(sch.get_block("B"), io[-1]) + sch.rolling_buffer(sch.get_block("B"), 0) + check_rolling_buffer(sch, cascade_2_max_pool2d, expected, check_run=True) + + +def test_cascade_3_max_pool2d_with_stride(): + @T.prim_func + def expected(A: T.Buffer[(1, 24, 24, 16), "int8"], C: T.Buffer[(1, 8, 8, 16), "int8"]) -> None: + B_0 = T.alloc_buffer([1, 13, 22, 16], dtype="int8") + B_1 = T.alloc_buffer([1, 6, 10, 16], dtype="int8") + for i0_0, i1_0, i2_0, i3_0 in T.grid(1, 2, 2, 1): + for ax0, ax1, ax2, ax3, ax4 in T.grid(13, 13, 16, 3, 3): + with T.block("B_0"): + T.where((i1_0 < 1 or 5 <= ax0) and (i2_0 < 1 or 5 <= ax1)) + ax0_1 = T.axis.spatial(1, 0) + ax1_1 = T.axis.opaque(22, i1_0 * 8 + ax0) + ax2_1 = T.axis.spatial(22, i2_0 * 8 + ax1) + ax3_1, rv0, rv1 = T.axis.remap("SRR", [ax2, ax3, ax4]) + T.reads(A[ax0_1, ax1_1 + rv0, ax2_1 + rv1, ax3_1]) + T.writes(B_0[ax0_1, ax1_1 % 13, ax2_1, ax3_1]) + with T.init(): + B_0[ax0_1, ax1_1 % 13, ax2_1, ax3_1] = T.int8(-128) + B_0[ax0_1, ax1_1 % 13, ax2_1, ax3_1] = T.max( + B_0[ax0_1, ax1_1 % 13, ax2_1, ax3_1], + A[ax0_1, ax1_1 + rv0, ax2_1 + rv1, ax3_1], + ) + for ax0, ax1, ax2, ax3, ax4 in T.grid(6, 6, 16, 3, 3): + with T.block("B_1"): + T.where((i1_0 < 1 or 2 <= ax0) and (i2_0 < 1 or 2 <= ax1)) + ax0_2 = T.axis.spatial(1, 0) + ax1_2 = T.axis.opaque(10, i1_0 * 4 + ax0) + ax2_2 = T.axis.spatial(10, i2_0 * 4 + ax1) + ax3_2, rv0, rv1 = T.axis.remap("SRR", [ax2, ax3, ax4]) + T.reads(B_0[ax0_2, (ax1_2 * 2 + rv0) % 13, ax2_2 * 2 + rv1, ax3_2]) + T.writes(B_1[ax0_2, ax1_2 % 6, ax2_2, ax3_2]) + with T.init(): + B_1[ax0_2, ax1_2 % 6, ax2_2, ax3_2] = T.int8(-128) + B_1[ax0_2, ax1_2 % 6, ax2_2, ax3_2] = T.max( + B_1[ax0_2, ax1_2 % 6, ax2_2, ax3_2], + B_0[ax0_2, (ax1_2 * 2 + rv0) % 13, ax2_2 * 
2 + rv1, ax3_2], + ) + for i0_1, i1_1, i2_1, i3_1, i4, i5 in T.grid(1, 4, 4, 16, 3, 3): + with T.block("C"): + ax0_3 = T.axis.spatial(1, i0_0 + i0_1) + ax1_3 = T.axis.opaque(8, i1_0 * 4 + i1_1) + ax2_3 = T.axis.spatial(8, i2_0 * 4 + i2_1) + ax3_3 = T.axis.spatial(16, i3_0 * 16 + i3_1) + rv0, rv1 = T.axis.remap("RR", [i4, i5]) + T.reads(B_1[ax0_3, (ax1_3 + rv0) % 6, ax2_3 + rv1, ax3_3]) + T.writes(C[ax0_3, ax1_3, ax2_3, ax3_3]) + with T.init(): + C[ax0_3, ax1_3, ax2_3, ax3_3] = T.int8(-128) + C[ax0_3, ax1_3, ax2_3, ax3_3] = T.max( + C[ax0_3, ax1_3, ax2_3, ax3_3], + B_1[ax0_3, (ax1_3 + rv0) % 6, ax2_3 + rv1, ax3_3], + ) + + sch = tir.Schedule(cascade_3_max_pool2d_with_stride, debug_mask="all") + io, _ = _tile_nd(sch, [1, 4, 4, 16], "C") + sch.compute_at(sch.get_block("B_1"), io[-1]) + sch.compute_at(sch.get_block("B_0"), io[-1]) + sch.rolling_buffer(sch.get_block("B_0"), 0) + sch.rolling_buffer(sch.get_block("B_1"), 0) + check_rolling_buffer(sch, cascade_3_max_pool2d_with_stride, expected, check_run=True) + + +def test_upscale(): + @T.prim_func + def before(A: T.Buffer[(1, 16, 16, 16), "int8"], C: T.Buffer[(1, 24, 24, 16), "int8"]) -> None: + B = T.alloc_buffer([1, 14, 14, 16], dtype="int8") + for i0_0, i1_0, i2_0, i3_0 in T.grid(1, 5, 5, 1): + for ax0, ax1, ax2, ax3, ax4 in T.grid(5, 5, 16, 3, 3): + with T.block("B"): + T.where(i1_0 * 5 // 2 + ax0 < 14 and i2_0 * 5 // 2 + ax1 < 14) + ax0_1 = T.axis.spatial(1, 0) + ax1_1 = T.axis.spatial(14, i1_0 * 5 // 2 + ax0) + ax2_1 = T.axis.spatial(14, i2_0 * 5 // 2 + ax1) + ax3_1 = T.axis.spatial(16, ax2) + rv0, rv1 = T.axis.remap("RR", [ax3, ax4]) + T.reads(A[ax0_1, ax1_1 + rv0, ax2_1 + rv1, ax3_1]) + T.writes(B[ax0_1, ax1_1, ax2_1, ax3_1]) + with T.init(): + B[ax0_1, ax1_1, ax2_1, ax3_1] = T.int8(-128) + B[ax0_1, ax1_1, ax2_1, ax3_1] = T.max( + B[ax0_1, ax1_1, ax2_1, ax3_1], A[ax0_1, ax1_1 + rv0, ax2_1 + rv1, ax3_1] + ) + for i0_1, i1_1, i2_1, i3_1, i4, i5 in T.grid(1, 5, 5, 16, 3, 3): + with T.block("C"): + T.where(i1_0 * 5 
+ i1_1 < 24 and i2_0 * 5 + i2_1 < 24) + ax0 = T.axis.spatial(1, i0_0 + i0_1) + ax1 = T.axis.spatial(24, i1_0 * 5 + i1_1) + ax2 = T.axis.spatial(24, i2_0 * 5 + i2_1) + ax3 = T.axis.spatial(16, i3_0 * 16 + i3_1) + rv0, rv1 = T.axis.remap("RR", [i4, i5]) + T.reads(B[ax0, ax1 // 2 + rv0, ax2 // 2 + rv1, ax3]) + T.writes(C[ax0, ax1, ax2, ax3]) + with T.init(): + C[ax0, ax1, ax2, ax3] = T.int8(-128) + C[ax0, ax1, ax2, ax3] = T.max( + C[ax0, ax1, ax2, ax3], B[ax0, ax1 // 2 + rv0, ax2 // 2 + rv1, ax3] + ) + + @T.prim_func + def expected( + A: T.Buffer[(1, 16, 16, 16), "int8"], C: T.Buffer[(1, 24, 24, 16), "int8"] + ) -> None: + B = T.alloc_buffer([1, 5, 14, 16], dtype="int8") + for i0_0, i1_0, i2_0, i3_0 in T.grid(1, 5, 5, 1): + for ax0, ax1, ax2, ax3, ax4 in T.grid(5, 5, 16, 3, 3): + with T.block("B"): + T.where( + i1_0 * 5 // 2 + ax0 < 14 + and i2_0 * 5 // 2 + ax1 < 14 + and (i1_0 < 1 or 2 <= ax0) + and (i2_0 < 1 or 2 <= ax1) + ) + ax0_1 = T.axis.spatial(1, 0) + ax1_1 = T.axis.opaque(14, i1_0 * 5 // 2 + ax0) + ax2_1 = T.axis.spatial(14, i2_0 * 5 // 2 + ax1) + ax3_1 = T.axis.spatial(16, ax2) + rv0, rv1 = T.axis.remap("RR", [ax3, ax4]) + T.reads(A[ax0_1, ax1_1 + rv0, ax2_1 + rv1, ax3_1]) + T.writes(B[ax0_1, ax1_1 % 5, ax2_1, ax3_1]) + with T.init(): + B[ax0_1, ax1_1 % 5, ax2_1, ax3_1] = T.int8(-128) + B[ax0_1, ax1_1 % 5, ax2_1, ax3_1] = T.max( + B[ax0_1, ax1_1 % 5, ax2_1, ax3_1], A[ax0_1, ax1_1 + rv0, ax2_1 + rv1, ax3_1] + ) + for i0_1, i1_1, i2_1, i3_1, i4, i5 in T.grid(1, 5, 5, 16, 3, 3): + with T.block("C"): + T.where(i1_0 * 5 + i1_1 < 24 and i2_0 * 5 + i2_1 < 24) + ax0 = T.axis.spatial(1, i0_0 + i0_1) + ax1 = T.axis.opaque(24, i1_0 * 5 + i1_1) + ax2 = T.axis.spatial(24, i2_0 * 5 + i2_1) + ax3 = T.axis.spatial(16, i3_0 * 16 + i3_1) + rv0, rv1 = T.axis.remap("RR", [i4, i5]) + T.reads(B[ax0, (ax1 // 2 + rv0) % 5, ax2 // 2 + rv1, ax3]) + T.writes(C[ax0, ax1, ax2, ax3]) + with T.init(): + C[ax0, ax1, ax2, ax3] = T.int8(-128) + C[ax0, ax1, ax2, ax3] = T.max( + C[ax0, ax1, 
ax2, ax3], B[ax0, (ax1 // 2 + rv0) % 5, ax2 // 2 + rv1, ax3] + ) + + sch = tir.Schedule(before, debug_mask="all") + sch.rolling_buffer(sch.get_block("B"), 0) + check_rolling_buffer(sch, before, expected, check_run=True) + + +def test_rolling_buffer_match_fail(): + @T.prim_func + def func_non_overlap( + A: T.Buffer[(1, 12, 12, 16), "int8"], C: T.Buffer[(1, 12, 12, 16), "int8"] + ): + B = T.alloc_buffer([1, 12, 12, 16], dtype="int8") + for i0_0, i1_0, i2_0, i3_0 in T.grid(1, 3, 3, 1): + for ax0, ax1, ax2 in T.grid(4, 4, 16): + with T.block("B"): + ax0_1 = T.axis.spatial(1, 0) + ax1_1 = T.axis.spatial(12, i1_0 * 4 + ax0) + ax2_1 = T.axis.spatial(12, i2_0 * 4 + ax1) + ax3 = T.axis.spatial(16, ax2) + T.reads(A[ax0_1, ax1_1, ax2_1, ax3]) + T.writes(B[ax0_1, ax1_1, ax2_1, ax3]) + with T.init(): + B[ax0_1, ax1_1, ax2_1, ax3] = T.int8(-128) + B[ax0_1, ax1_1, ax2_1, ax3] = A[ax0_1, ax1_1, ax2_1, ax3] + for i0_1, i1_1, i2_1, i3_1, i4, i5 in T.grid(1, 4, 4, 16, 1, 1): + with T.block("C"): + ax0 = T.axis.spatial(1, i0_0 + i0_1) + ax1 = T.axis.spatial(12, i1_0 * 4 + i1_1) + ax2 = T.axis.spatial(12, i2_0 * 4 + i2_1) + ax3 = T.axis.spatial(16, i3_0 * 16 + i3_1) + rv0, rv1 = T.axis.remap("RR", [i4, i5]) + T.reads(B[ax0, ax1 + rv0, ax2 + rv1, ax3]) + T.writes(C[ax0, ax1, ax2, ax3]) + with T.init(): + C[ax0, ax1, ax2, ax3] = T.int8(-128) + C[ax0, ax1, ax2, ax3] = T.max( + C[ax0, ax1, ax2, ax3], B[ax0, ax1 + rv0, ax2 + rv1, ax3] + ) + + sch = tir.Schedule(func_non_overlap, debug_mask="all") + with pytest.raises(tvm.tir.ScheduleError): + sch.rolling_buffer(sch.get_block("B"), 0) + + +def test_rolling_buffer_injection_invalid(): + sch = tir.Schedule(cascade_2_max_pool2d, debug_mask="all") + # Block B is not compute_at to Block C, so rolling_buffer injection is invalid. 
+ _, _ = _tile_nd(sch, [1, 4, 8, 16], "C") + _, _ = _tile_nd(sch, [1, 4, 8, 16], "B") + with pytest.raises(tvm.tir.ScheduleError): + sch.rolling_buffer(sch.get_block("B"), 0) Review Comment: Would be great to add ```python if __name__ == "__main__": tvm.testing.main() ``` ########## src/tir/schedule/primitive/rolling_buffer.cc: ########## @@ -0,0 +1,443 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include <functional> + +#include "../ir_comparator.h" +#include "../utils.h" + +namespace tvm { +namespace tir { + +namespace { + +struct RollingBufferInfo { + Buffer old_buffer; + Buffer new_buffer; + int rolling_axis; + int rolling_extent; + std::vector<int> axis_overlaps; + std::vector<Optional<Var>> axis_iter_vars; + /*! \brief The map used for ScheduleStateNode::Replace. 
*/ + Map<Block, Block> block_reuse; +}; + +BufferRegion GetRelaxedBufferRegion(const BlockRealize& realize, const BufferRegion& buffer_region, + const Map<Var, arith::IntSet>& dom_map) { + Array<arith::IntSet> relaxed_intsets = + arith::EvalSet(Substitute(buffer_region->region, GetBindings(realize)), dom_map); + Region relaxed_region; + relaxed_region.reserve(relaxed_intsets.size()); + for (size_t i = 0; i < relaxed_intsets.size(); ++i) { + relaxed_region.push_back( + relaxed_intsets[i].CoverRange(Range::FromMinExtent(0, buffer_region->buffer->shape[i]))); + } + return BufferRegion(buffer_region->buffer, relaxed_region); +} + +class RollingBufferMatchError : public ScheduleError { + public: + RollingBufferMatchError(IRModule mod, Block block, BufferRegion buffer_region) + : mod_(mod), block_(block), buffer_region_(buffer_region) {} + String FastErrorString() const final { + return "ScheduleError: rolling_buffer expect the buffer region to have at least one dimention" + "matching the rolling pattern such as: hh.outer * stride + hh.inner"; + } + String DetailRenderTemplate() const final { + std::ostringstream os; + os << "The target buffer " << buffer_region_->buffer->name << " with region " + << buffer_region_->region + << " should have at least one dimension range that matches a rolling pattern " + "such as hh.outer * stride + hh.inner. "; + return os.str(); + } + + IRModule mod() const final { return mod_; } + Array<ObjectRef> LocationsOfInterest() const final { return {block_}; } + + private: + IRModule mod_; + Block block_; + BufferRegion buffer_region_; +}; + +class RollingBufferInsertionError : public ScheduleError { + public: + RollingBufferInsertionError(IRModule mod, Buffer buffer, Block block) + : mod_(mod), buffer_(std::move(buffer)), block_(block) {} + String FastErrorString() const final { + return "ScheduleError: rolling_buffer injection is invalid, the lca of the access " + "location of the target buffer is not a for loop. 
"; + } + + String DetailRenderTemplate() const final { + std::ostringstream os; + os << "rolling_buffer injection is invalid. The block {0} should be tiled so that " + << "the lca of the access location of the target buffer " << buffer_->name + << " is a for loop. "; + return os.str(); + } + IRModule mod() const final { return mod_; } + Array<ObjectRef> LocationsOfInterest() const final { return {block_}; } + + private: + IRModule mod_; + Buffer buffer_; + Block block_; +}; + +class RollingBufferInfoCollector { + public: + static RollingBufferInfo CheckAndGetRollingBufferInfo(const IRModule& mod, + const StmtSRef& block_sref, + const BufferRegion& buffer_region) { + RollingBufferInfoCollector collector; + if (!collector.MatchRollingBuffer(block_sref, buffer_region)) { + const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); + throw RollingBufferMatchError(mod, GetRef<Block>(block), buffer_region); + } + return collector.info_; + } + + private: + bool MatchRollingBuffer(const StmtSRef& block_sref, const BufferRegion& buffer_region) { + const Buffer& buffer = buffer_region->buffer; + const Region& region = buffer_region->region; + + std::vector<Optional<Var>> bound_iter_vars; + std::vector<int> bound_overlaps; + auto stride = 0; + auto divisor = 1; + Optional<Var> iter_var; + for (auto bound : region) { + divisor = 1; + if (auto floor_div = bound->min.as<FloorDivNode>()) { + // Handle the case of fractional strides + // They take this form: floordiv(hh.outer, 2) + // Strip the floordiv and keep track of the divisor + divisor = Downcast<IntImm>(floor_div->b)->value; + bound = Range::FromMinExtent(floor_div->a, bound->extent, bound->span); + } + if (bound->min.as<IntImmNode>()) { + // If the bound is an int, we can't roll over it + iter_var = NullOpt; + } else if (auto var = bound->min.as<VarNode>()) { + // If the bound is just a Var, that implies the stride is 1 + iter_var = GetRef<Var>(var); + stride = 1; + } else { + // Otherwise, it's the iter var multiplied by 
the stride + // If not we're in unknown behaviour + if (auto mul = bound->min.as<MulNode>()) { + if (mul->a->IsInstance<VarNode>() && mul->b->IsInstance<IntImmNode>()) { + iter_var = Downcast<Var>(mul->a); + stride = Downcast<IntImm>(mul->b)->value; + } else { + return false; + } + } else { + return false; + } + } + stride = std::ceil(static_cast<float>(stride) / divisor); + auto bound_overlap = 0; + if (iter_var.defined()) { + auto extent = bound->extent.as<IntImmNode>(); + ICHECK(extent); + bound_overlap = extent->value - stride; + // Since Pass CompactBufferAllocation will be responsible for compacting the buffer + // allocation region, there is no need to roll over the axis where the overlap is not + // positive, so reset iter_var to NullOpt. + if (bound_overlap <= 0) { + iter_var = NullOpt; + } + } + bound_iter_vars.push_back(iter_var); + bound_overlaps.push_back(bound_overlap); + } + Array<StmtSRef> loop_srefs = GetLoops(block_sref); + // Pick the outermost iter_var that's mentioned in the bounds + // to be the rolling axis + Optional<Var> roll_iter_var; + int roll_axis; + for (const tir::StmtSRef& loop_sref : loop_srefs) { + auto loop_var = loop_sref->StmtAs<ForNode>()->loop_var; + + auto it{std::find_if(bound_iter_vars.begin(), bound_iter_vars.end(), [&](Optional<Var> var) { + return var && (var.get() == loop_var.get()); + })}; + if (it != bound_iter_vars.end()) { + auto i = std::distance(bound_iter_vars.begin(), it); + roll_iter_var = loop_var; + roll_axis = i; + break; + } + } + + if (!roll_iter_var.defined()) { + return false; + } + Array<PrimExpr> new_shape = buffer->shape; + new_shape.Set(roll_axis, region[roll_axis]->extent); + Buffer new_buffer = buffer; + new_buffer.CopyOnWrite()->shape = new_shape; + + info_.old_buffer = buffer; + info_.new_buffer = new_buffer; + info_.rolling_axis = roll_axis; + info_.rolling_extent = static_cast<int>(Downcast<IntImm>(region[roll_axis]->extent)->value); + info_.axis_overlaps = bound_overlaps; + 
info_.axis_iter_vars = bound_iter_vars; + + return true; + } + + RollingBufferInfo info_; +}; + +class RollingBufferRewriter : public StmtExprMutator { + public: + static Stmt Rewrite(const StmtSRef& scope_sref, RollingBufferInfo* info) { + RollingBufferRewriter rewriter(scope_sref, info); + return rewriter(GetRef<Stmt>(scope_sref->stmt)); + } + + private: + explicit RollingBufferRewriter(const StmtSRef& scope_sref, RollingBufferInfo* info) + : scope_sref_(scope_sref), info_(info) {} + + void RewriteAccessRegion(Array<BufferRegion>* old_access_regions, + const Array<BufferRegion>& infered_access_regions) { + auto fmutate = [this, &infered_access_regions](const BufferRegion& buffer_region) { + if (buffer_region->buffer.same_as(info_->old_buffer)) { + ICHECK(infered_access_regions.size() == 1); + return infered_access_regions[0]; + } + return buffer_region; + }; + (*old_access_regions).MutateByApply(fmutate); + } + + Stmt VisitStmt_(const BlockNode* block) final { + Block old_stmt = GetRef<Block>(block); + Block stmt = Downcast<Block>(StmtExprMutator::VisitStmt_(block)); + if (block == scope_sref_->stmt) { + ObjectPtr<BlockNode> n = make_object<BlockNode>(*stmt.as<BlockNode>()); + + Array<Buffer> new_alloc_buffers; + for (const Buffer& buffer : stmt->alloc_buffers) { + if (buffer != info_->old_buffer) { + new_alloc_buffers.push_back(buffer); + } else { + new_alloc_buffers.push_back(info_->new_buffer); + } + } + n->alloc_buffers = std::move(new_alloc_buffers); + stmt = Block(n); + } else { + Array<IterVar> new_iter_bindings; + for (size_t i = 0; i < stmt->iter_vars.size(); ++i) { + auto old_iter_var = stmt->iter_vars[i]; + if (static_cast<int>(i) == info_->rolling_axis) { + // All inner loops of the rolling axis has a loop carried dependency + // (i.e. 
each iteration calculation of the rolling axis depends on + // the calculation results of all the historical iterations of inner loops), + // so annotate the iteration type of the rolling axis as 'opaque', + // avoid the iterative range of its inner loop from being compressed + // during lowering phase. + IterVar new_iter_var = + IterVar(old_iter_var->dom, old_iter_var->var, IterVarType::kOpaque); + new_iter_bindings.push_back(new_iter_var); + } else { + new_iter_bindings.push_back(old_iter_var); + } + } + Map<Var, Buffer> buffer_data_to_buffer = {{info_->new_buffer->data, info_->new_buffer}}; + auto infered_access_regions = GetBlockReadWriteRegion(stmt, buffer_data_to_buffer); + + BlockNode* n = stmt.CopyOnWrite(); + n->iter_vars = std::move(new_iter_bindings); + RewriteAccessRegion(&n->reads, infered_access_regions[0]); + RewriteAccessRegion(&n->writes, infered_access_regions[1]); + } + info_->block_reuse.Set(old_stmt, stmt); + return std::move(stmt); + } + + Stmt VisitStmt_(const BlockRealizeNode* realize) final { + BlockRealize stmt = Downcast<BlockRealize>(StmtExprMutator::VisitStmt_(realize)); + // Append block predicate to avoid recomputing elements. 
+ if (rewrite_block_predicate_) { + rewrite_block_predicate_ = false; + PrimExpr condition = stmt->predicate; + for (size_t i = 0; i < info_->axis_iter_vars.size(); ++i) { + auto iter_var = info_->axis_iter_vars[i]; + if (iter_var && info_->axis_overlaps[i] > 0) { + Var var = iter_var.value(); + const Map<Var, arith::IntSet> dmap = {std::make_pair(var, arith::IntSet::Interval(0, 0))}; + auto iter_value = realize->iter_values[i]; + arith::Analyzer analyzer; + auto term_2 = analyzer.int_set(iter_value, dmap).min(); + condition = analyzer.Simplify( + And(condition, Or(LT(var, 1), GE(term_2, info_->axis_overlaps[i])))); + } + } + BlockRealizeNode* n = stmt.CopyOnWrite(); + n->predicate = condition; + } + return std::move(stmt); + } + + Stmt VisitStmt_(const BufferStoreNode* op) final { + BufferStore stmt = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op)); + if (stmt->buffer.same_as(info_->old_buffer)) { + Array<PrimExpr> new_indices; + new_indices.reserve(stmt->indices.size()); + // First modify the access indices to use modulo arithmetic + // for the rolling axis + for (size_t i = 0; i < stmt->indices.size(); ++i) { + auto index = stmt->indices[i]; + if (static_cast<int>(i) == info_->rolling_axis) { + new_indices.push_back(FloorMod(index, info_->rolling_extent)); + } else { + new_indices.push_back(index); + } + } + BufferStoreNode* n = stmt.CopyOnWrite(); + // Replace the stored buffer with the new buffer. + n->buffer = info_->new_buffer; + n->indices = std::move(new_indices); + // Need to add predicate to the current block to avoid recomputing elements. 
+ rewrite_block_predicate_ = true; + } + return std::move(stmt); + } + + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + BufferLoad stmt = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op)); + if (stmt->buffer.same_as(info_->old_buffer)) { + Array<PrimExpr> new_indices; + new_indices.reserve(stmt->indices.size()); + for (size_t i{0}; i < stmt->indices.size(); ++i) { + auto index = stmt->indices[i]; + if (static_cast<int>(i) == info_->rolling_axis) { + new_indices.push_back(FloorMod(index, info_->rolling_extent)); + } else { + new_indices.push_back(index); + } + } + BufferLoadNode* n = stmt.CopyOnWrite(); + // Replace the loaded buffer with the new buffer. + n->buffer = info_->new_buffer; + n->indices = std::move(new_indices); + } + return std::move(stmt); + } + + private: + const StmtSRef& scope_sref_; + RollingBufferInfo* info_; + bool rewrite_block_predicate_ = false; +}; + +} // namespace + +void RollingBuffer(ScheduleState self, const StmtSRef& block_sref, int write_buffer_index) { + /*! + * Check + * - The block is not an output block. + * - The block is tiled and there is access overlap between adjacent tiles. + * Mutate + * - Select the outermost rollable axis appeared in the block's loop nest + * as the 'rolling axis', trim the target buffer from the rolling axis. + * - Use modulo arithmetic to modify the target buffer's read and load + * indices to circularize the buffer along the rolling dimension. + * - Append block predicate to avoid recomputing overlapping elements. + */ + Map<Var, arith::IntSet> dom_map; + const BlockRealize& realize = GetBlockRealize(self, block_sref); + const Block& block = realize->block; + + // Step 1. Checking index, getting the target buffer region and the parent scope. + const BufferRegion& buffer_region = + GetNthAccessBufferRegion(self, block, write_buffer_index, BufferIndexType::kWrite); + StmtSRef scope_root_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false); + // Step 2. 
Check the target block is not an output block. + CheckNotOutputBlock(self, block_sref, scope_root_sref); + + // Step 3. Find the lca of the access location of the target buffer and relax the buffer + Array<StmtSRef> loop_srefs = GetLoops(block_sref); + Array<StmtSRef> consumers_sref = GetConsumers(self, block_sref); + consumers_sref.push_back(block_sref); + StmtSRef lca = GetSRefLowestCommonAncestor(consumers_sref); + if (!lca->StmtAs<ForNode>()) { + throw RollingBufferInsertionError(self->mod, buffer_region->buffer, block); + } + + for (auto it = loop_srefs.rbegin(); it != loop_srefs.rend(); ++it) { + auto stmt = *it; + // Stop at the lca of all the rolling_buffer access points; + if (stmt == lca) { + break; + } + For cur_loop = GetRef<For>(stmt->StmtAs<ForNode>()); + Range range = Range::FromMinExtent(cur_loop->min, cur_loop->extent); + dom_map.Set(cur_loop->loop_var, arith::IntSet::FromRange(range)); + } + BufferRegion relaxed_region = GetRelaxedBufferRegion(realize, buffer_region, dom_map); + + // Step 4. Find an valid rolling axis and collect bound overlaps on the target buffer. Review Comment: In the TE version, the rolling axis is determined by the same criteria, right? ########## src/tir/schedule/primitive/rolling_buffer.cc: ########## @@ -0,0 +1,443 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include <functional> + +#include "../ir_comparator.h" +#include "../utils.h" + +namespace tvm { +namespace tir { + +namespace { + +struct RollingBufferInfo { + Buffer old_buffer; + Buffer new_buffer; + int rolling_axis; + int rolling_extent; + std::vector<int> axis_overlaps; + std::vector<Optional<Var>> axis_iter_vars; + /*! \brief The map used for ScheduleStateNode::Replace. */ + Map<Block, Block> block_reuse; +}; + +BufferRegion GetRelaxedBufferRegion(const BlockRealize& realize, const BufferRegion& buffer_region, + const Map<Var, arith::IntSet>& dom_map) { + Array<arith::IntSet> relaxed_intsets = + arith::EvalSet(Substitute(buffer_region->region, GetBindings(realize)), dom_map); + Region relaxed_region; + relaxed_region.reserve(relaxed_intsets.size()); + for (size_t i = 0; i < relaxed_intsets.size(); ++i) { + relaxed_region.push_back( + relaxed_intsets[i].CoverRange(Range::FromMinExtent(0, buffer_region->buffer->shape[i]))); + } + return BufferRegion(buffer_region->buffer, relaxed_region); +} + +class RollingBufferMatchError : public ScheduleError { + public: + RollingBufferMatchError(IRModule mod, Block block, BufferRegion buffer_region) + : mod_(mod), block_(block), buffer_region_(buffer_region) {} + String FastErrorString() const final { + return "ScheduleError: rolling_buffer expect the buffer region to have at least one dimention" + "matching the rolling pattern such as: hh.outer * stride + hh.inner"; + } + String DetailRenderTemplate() const final { + std::ostringstream os; + os << "The target buffer " << buffer_region_->buffer->name << " with region " + << buffer_region_->region + << " should have at least one dimension range that matches a rolling pattern " + "such as hh.outer * stride + hh.inner. 
"; + return os.str(); + } + + IRModule mod() const final { return mod_; } + Array<ObjectRef> LocationsOfInterest() const final { return {block_}; } + + private: + IRModule mod_; + Block block_; + BufferRegion buffer_region_; +}; + +class RollingBufferInsertionError : public ScheduleError { + public: + RollingBufferInsertionError(IRModule mod, Buffer buffer, Block block) + : mod_(mod), buffer_(std::move(buffer)), block_(block) {} + String FastErrorString() const final { + return "ScheduleError: rolling_buffer injection is invalid, the lca of the access " + "location of the target buffer is not a for loop. "; + } + + String DetailRenderTemplate() const final { + std::ostringstream os; + os << "rolling_buffer injection is invalid. The block {0} should be tiled so that " + << "the lca of the access location of the target buffer " << buffer_->name + << " is a for loop. "; + return os.str(); + } + IRModule mod() const final { return mod_; } + Array<ObjectRef> LocationsOfInterest() const final { return {block_}; } + + private: + IRModule mod_; + Buffer buffer_; + Block block_; +}; + +class RollingBufferInfoCollector { + public: + static RollingBufferInfo CheckAndGetRollingBufferInfo(const IRModule& mod, + const StmtSRef& block_sref, + const BufferRegion& buffer_region) { + RollingBufferInfoCollector collector; + if (!collector.MatchRollingBuffer(block_sref, buffer_region)) { + const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); + throw RollingBufferMatchError(mod, GetRef<Block>(block), buffer_region); + } + return collector.info_; + } + + private: + bool MatchRollingBuffer(const StmtSRef& block_sref, const BufferRegion& buffer_region) { + const Buffer& buffer = buffer_region->buffer; + const Region& region = buffer_region->region; + + std::vector<Optional<Var>> bound_iter_vars; + std::vector<int> bound_overlaps; + auto stride = 0; + auto divisor = 1; + Optional<Var> iter_var; + for (auto bound : region) { + divisor = 1; + if (auto floor_div = 
bound->min.as<FloorDivNode>()) { Review Comment: Can we use TIR pattern matching here? ########## src/tir/schedule/primitive/rolling_buffer.cc: ########## @@ -0,0 +1,443 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include <functional> + +#include "../ir_comparator.h" +#include "../utils.h" + +namespace tvm { +namespace tir { + +namespace { + +struct RollingBufferInfo { + Buffer old_buffer; + Buffer new_buffer; + int rolling_axis; + int rolling_extent; + std::vector<int> axis_overlaps; + std::vector<Optional<Var>> axis_iter_vars; + /*! \brief The map used for ScheduleStateNode::Replace. 
*/ + Map<Block, Block> block_reuse; +}; + +BufferRegion GetRelaxedBufferRegion(const BlockRealize& realize, const BufferRegion& buffer_region, + const Map<Var, arith::IntSet>& dom_map) { + Array<arith::IntSet> relaxed_intsets = + arith::EvalSet(Substitute(buffer_region->region, GetBindings(realize)), dom_map); + Region relaxed_region; + relaxed_region.reserve(relaxed_intsets.size()); + for (size_t i = 0; i < relaxed_intsets.size(); ++i) { + relaxed_region.push_back( + relaxed_intsets[i].CoverRange(Range::FromMinExtent(0, buffer_region->buffer->shape[i]))); + } + return BufferRegion(buffer_region->buffer, relaxed_region); +} + +class RollingBufferMatchError : public ScheduleError { + public: + RollingBufferMatchError(IRModule mod, Block block, BufferRegion buffer_region) + : mod_(mod), block_(block), buffer_region_(buffer_region) {} + String FastErrorString() const final { + return "ScheduleError: rolling_buffer expect the buffer region to have at least one dimention" + "matching the rolling pattern such as: hh.outer * stride + hh.inner"; + } + String DetailRenderTemplate() const final { + std::ostringstream os; + os << "The target buffer " << buffer_region_->buffer->name << " with region " + << buffer_region_->region + << " should have at least one dimension range that matches a rolling pattern " + "such as hh.outer * stride + hh.inner. "; + return os.str(); + } + + IRModule mod() const final { return mod_; } + Array<ObjectRef> LocationsOfInterest() const final { return {block_}; } + + private: + IRModule mod_; + Block block_; + BufferRegion buffer_region_; +}; + +class RollingBufferInsertionError : public ScheduleError { + public: + RollingBufferInsertionError(IRModule mod, Buffer buffer, Block block) + : mod_(mod), buffer_(std::move(buffer)), block_(block) {} + String FastErrorString() const final { + return "ScheduleError: rolling_buffer injection is invalid, the lca of the access " + "location of the target buffer is not a for loop. 
"; + } + + String DetailRenderTemplate() const final { + std::ostringstream os; + os << "rolling_buffer injection is invalid. The block {0} should be tiled so that " + << "the lca of the access location of the target buffer " << buffer_->name + << " is a for loop. "; + return os.str(); + } + IRModule mod() const final { return mod_; } + Array<ObjectRef> LocationsOfInterest() const final { return {block_}; } + + private: + IRModule mod_; + Buffer buffer_; + Block block_; +}; + +class RollingBufferInfoCollector { + public: + static RollingBufferInfo CheckAndGetRollingBufferInfo(const IRModule& mod, + const StmtSRef& block_sref, + const BufferRegion& buffer_region) { + RollingBufferInfoCollector collector; + if (!collector.MatchRollingBuffer(block_sref, buffer_region)) { + const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); + throw RollingBufferMatchError(mod, GetRef<Block>(block), buffer_region); + } + return collector.info_; + } + + private: + bool MatchRollingBuffer(const StmtSRef& block_sref, const BufferRegion& buffer_region) { + const Buffer& buffer = buffer_region->buffer; + const Region& region = buffer_region->region; + + std::vector<Optional<Var>> bound_iter_vars; + std::vector<int> bound_overlaps; + auto stride = 0; + auto divisor = 1; + Optional<Var> iter_var; + for (auto bound : region) { + divisor = 1; + if (auto floor_div = bound->min.as<FloorDivNode>()) { + // Handle the case of fractional strides + // They take this form: floordiv(hh.outer, 2) + // Strip the floordiv and keep track of the divisor + divisor = Downcast<IntImm>(floor_div->b)->value; + bound = Range::FromMinExtent(floor_div->a, bound->extent, bound->span); + } + if (bound->min.as<IntImmNode>()) { + // If the bound is an int, we can't roll over it + iter_var = NullOpt; + } else if (auto var = bound->min.as<VarNode>()) { + // If the bound is just a Var, that implies the stride is 1 + iter_var = GetRef<Var>(var); + stride = 1; + } else { + // Otherwise, it's the iter var multiplied by 
the stride + // If not we're in unknown behaviour + if (auto mul = bound->min.as<MulNode>()) { + if (mul->a->IsInstance<VarNode>() && mul->b->IsInstance<IntImmNode>()) { + iter_var = Downcast<Var>(mul->a); + stride = Downcast<IntImm>(mul->b)->value; + } else { + return false; + } + } else { + return false; + } + } + stride = std::ceil(static_cast<float>(stride) / divisor); + auto bound_overlap = 0; + if (iter_var.defined()) { + auto extent = bound->extent.as<IntImmNode>(); + ICHECK(extent); + bound_overlap = extent->value - stride; + // Since Pass CompactBufferAllocation will be responsible for compacting the buffer + // allocation region, there is no need to roll over the axis where the overlap is not + // positive, so reset iter_var to NullOpt. + if (bound_overlap <= 0) { + iter_var = NullOpt; + } + } + bound_iter_vars.push_back(iter_var); + bound_overlaps.push_back(bound_overlap); + } + Array<StmtSRef> loop_srefs = GetLoops(block_sref); + // Pick the outermost iter_var that's mentioned in the bounds + // to be the rolling axis + Optional<Var> roll_iter_var; + int roll_axis; + for (const tir::StmtSRef& loop_sref : loop_srefs) { + auto loop_var = loop_sref->StmtAs<ForNode>()->loop_var; + + auto it{std::find_if(bound_iter_vars.begin(), bound_iter_vars.end(), [&](Optional<Var> var) { + return var && (var.get() == loop_var.get()); + })}; + if (it != bound_iter_vars.end()) { + auto i = std::distance(bound_iter_vars.begin(), it); + roll_iter_var = loop_var; + roll_axis = i; + break; + } + } + + if (!roll_iter_var.defined()) { + return false; + } + Array<PrimExpr> new_shape = buffer->shape; + new_shape.Set(roll_axis, region[roll_axis]->extent); + Buffer new_buffer = buffer; + new_buffer.CopyOnWrite()->shape = new_shape; + + info_.old_buffer = buffer; + info_.new_buffer = new_buffer; + info_.rolling_axis = roll_axis; + info_.rolling_extent = static_cast<int>(Downcast<IntImm>(region[roll_axis]->extent)->value); + info_.axis_overlaps = bound_overlaps; + 
info_.axis_iter_vars = bound_iter_vars; + + return true; + } + + RollingBufferInfo info_; +}; + +class RollingBufferRewriter : public StmtExprMutator { + public: + static Stmt Rewrite(const StmtSRef& scope_sref, RollingBufferInfo* info) { + RollingBufferRewriter rewriter(scope_sref, info); + return rewriter(GetRef<Stmt>(scope_sref->stmt)); + } + + private: + explicit RollingBufferRewriter(const StmtSRef& scope_sref, RollingBufferInfo* info) + : scope_sref_(scope_sref), info_(info) {} + + void RewriteAccessRegion(Array<BufferRegion>* old_access_regions, + const Array<BufferRegion>& infered_access_regions) { + auto fmutate = [this, &infered_access_regions](const BufferRegion& buffer_region) { + if (buffer_region->buffer.same_as(info_->old_buffer)) { + ICHECK(infered_access_regions.size() == 1); + return infered_access_regions[0]; + } + return buffer_region; + }; + (*old_access_regions).MutateByApply(fmutate); + } + + Stmt VisitStmt_(const BlockNode* block) final { + Block old_stmt = GetRef<Block>(block); + Block stmt = Downcast<Block>(StmtExprMutator::VisitStmt_(block)); + if (block == scope_sref_->stmt) { + ObjectPtr<BlockNode> n = make_object<BlockNode>(*stmt.as<BlockNode>()); + + Array<Buffer> new_alloc_buffers; + for (const Buffer& buffer : stmt->alloc_buffers) { + if (buffer != info_->old_buffer) { + new_alloc_buffers.push_back(buffer); + } else { + new_alloc_buffers.push_back(info_->new_buffer); + } + } + n->alloc_buffers = std::move(new_alloc_buffers); + stmt = Block(n); + } else { + Array<IterVar> new_iter_bindings; + for (size_t i = 0; i < stmt->iter_vars.size(); ++i) { + auto old_iter_var = stmt->iter_vars[i]; + if (static_cast<int>(i) == info_->rolling_axis) { + // All inner loops of the rolling axis has a loop carried dependency + // (i.e. 
each iteration calculation of the rolling axis depends on + // the calculation results of all the historical iterations of inner loops), + // so annotate the iteration type of the rolling axis as 'opaque', + // avoid the iterative range of its inner loop from being compressed + // during lowering phase. + IterVar new_iter_var = + IterVar(old_iter_var->dom, old_iter_var->var, IterVarType::kOpaque); + new_iter_bindings.push_back(new_iter_var); + } else { + new_iter_bindings.push_back(old_iter_var); + } + } + Map<Var, Buffer> buffer_data_to_buffer = {{info_->new_buffer->data, info_->new_buffer}}; + auto infered_access_regions = GetBlockReadWriteRegion(stmt, buffer_data_to_buffer); + + BlockNode* n = stmt.CopyOnWrite(); + n->iter_vars = std::move(new_iter_bindings); + RewriteAccessRegion(&n->reads, infered_access_regions[0]); + RewriteAccessRegion(&n->writes, infered_access_regions[1]); + } + info_->block_reuse.Set(old_stmt, stmt); + return std::move(stmt); + } + + Stmt VisitStmt_(const BlockRealizeNode* realize) final { + BlockRealize stmt = Downcast<BlockRealize>(StmtExprMutator::VisitStmt_(realize)); + // Append block predicate to avoid recomputing elements. 
+ if (rewrite_block_predicate_) { + rewrite_block_predicate_ = false; + PrimExpr condition = stmt->predicate; + for (size_t i = 0; i < info_->axis_iter_vars.size(); ++i) { + auto iter_var = info_->axis_iter_vars[i]; + if (iter_var && info_->axis_overlaps[i] > 0) { + Var var = iter_var.value(); + const Map<Var, arith::IntSet> dmap = {std::make_pair(var, arith::IntSet::Interval(0, 0))}; + auto iter_value = realize->iter_values[i]; + arith::Analyzer analyzer; + auto term_2 = analyzer.int_set(iter_value, dmap).min(); + condition = analyzer.Simplify( + And(condition, Or(LT(var, 1), GE(term_2, info_->axis_overlaps[i])))); + } + } + BlockRealizeNode* n = stmt.CopyOnWrite(); + n->predicate = condition; + } + return std::move(stmt); + } + + Stmt VisitStmt_(const BufferStoreNode* op) final { + BufferStore stmt = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op)); + if (stmt->buffer.same_as(info_->old_buffer)) { + Array<PrimExpr> new_indices; + new_indices.reserve(stmt->indices.size()); + // First modify the access indices to use modulo arithmetic + // for the rolling axis + for (size_t i = 0; i < stmt->indices.size(); ++i) { + auto index = stmt->indices[i]; + if (static_cast<int>(i) == info_->rolling_axis) { + new_indices.push_back(FloorMod(index, info_->rolling_extent)); + } else { + new_indices.push_back(index); + } + } + BufferStoreNode* n = stmt.CopyOnWrite(); + // Replace the stored buffer with the new buffer. + n->buffer = info_->new_buffer; + n->indices = std::move(new_indices); + // Need to add predicate to the current block to avoid recomputing elements. 
+ rewrite_block_predicate_ = true; + } + return std::move(stmt); + } + + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + BufferLoad stmt = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op)); + if (stmt->buffer.same_as(info_->old_buffer)) { + Array<PrimExpr> new_indices; + new_indices.reserve(stmt->indices.size()); + for (size_t i{0}; i < stmt->indices.size(); ++i) { + auto index = stmt->indices[i]; + if (static_cast<int>(i) == info_->rolling_axis) { + new_indices.push_back(FloorMod(index, info_->rolling_extent)); + } else { + new_indices.push_back(index); + } + } + BufferLoadNode* n = stmt.CopyOnWrite(); + // Replace the loaded buffer with the new buffer. + n->buffer = info_->new_buffer; + n->indices = std::move(new_indices); + } + return std::move(stmt); + } + + private: + const StmtSRef& scope_sref_; + RollingBufferInfo* info_; + bool rewrite_block_predicate_ = false; +}; + +} // namespace + +void RollingBuffer(ScheduleState self, const StmtSRef& block_sref, int write_buffer_index) { + /*! + * Check + * - The block is not an output block. + * - The block is tiled and there is access overlap between adjacent tiles. + * Mutate + * - Select the outermost rollable axis appeared in the block's loop nest + * as the 'rolling axis', trim the target buffer from the rolling axis. + * - Use modulo arithmetic to modify the target buffer's read and load + * indices to circularize the buffer along the rolling dimension. + * - Append block predicate to avoid recomputing overlapping elements. + */ + Map<Var, arith::IntSet> dom_map; + const BlockRealize& realize = GetBlockRealize(self, block_sref); + const Block& block = realize->block; + + // Step 1. Checking index, getting the target buffer region and the parent scope. 
+ const BufferRegion& buffer_region = + GetNthAccessBufferRegion(self, block, write_buffer_index, BufferIndexType::kWrite); + StmtSRef scope_root_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false); Review Comment: Maybe we have to set `require_stage_pipeline` as true to ensure safety? ########## src/tir/schedule/primitive/rolling_buffer.cc: ########## @@ -0,0 +1,443 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include <functional> + +#include "../ir_comparator.h" +#include "../utils.h" + +namespace tvm { +namespace tir { + +namespace { + +struct RollingBufferInfo { + Buffer old_buffer; + Buffer new_buffer; + int rolling_axis; + int rolling_extent; + std::vector<int> axis_overlaps; + std::vector<Optional<Var>> axis_iter_vars; + /*! \brief The map used for ScheduleStateNode::Replace. 
*/ + Map<Block, Block> block_reuse; +}; + +BufferRegion GetRelaxedBufferRegion(const BlockRealize& realize, const BufferRegion& buffer_region, + const Map<Var, arith::IntSet>& dom_map) { + Array<arith::IntSet> relaxed_intsets = + arith::EvalSet(Substitute(buffer_region->region, GetBindings(realize)), dom_map); + Region relaxed_region; + relaxed_region.reserve(relaxed_intsets.size()); + for (size_t i = 0; i < relaxed_intsets.size(); ++i) { + relaxed_region.push_back( + relaxed_intsets[i].CoverRange(Range::FromMinExtent(0, buffer_region->buffer->shape[i]))); + } + return BufferRegion(buffer_region->buffer, relaxed_region); +} + +class RollingBufferMatchError : public ScheduleError { + public: + RollingBufferMatchError(IRModule mod, Block block, BufferRegion buffer_region) + : mod_(mod), block_(block), buffer_region_(buffer_region) {} + String FastErrorString() const final { + return "ScheduleError: rolling_buffer expect the buffer region to have at least one dimention" + "matching the rolling pattern such as: hh.outer * stride + hh.inner"; + } + String DetailRenderTemplate() const final { + std::ostringstream os; + os << "The target buffer " << buffer_region_->buffer->name << " with region " + << buffer_region_->region + << " should have at least one dimension range that matches a rolling pattern " + "such as hh.outer * stride + hh.inner. "; + return os.str(); + } + + IRModule mod() const final { return mod_; } + Array<ObjectRef> LocationsOfInterest() const final { return {block_}; } + + private: + IRModule mod_; + Block block_; + BufferRegion buffer_region_; +}; + +class RollingBufferInsertionError : public ScheduleError { + public: + RollingBufferInsertionError(IRModule mod, Buffer buffer, Block block) + : mod_(mod), buffer_(std::move(buffer)), block_(block) {} + String FastErrorString() const final { + return "ScheduleError: rolling_buffer injection is invalid, the lca of the access " + "location of the target buffer is not a for loop. 
"; + } + + String DetailRenderTemplate() const final { + std::ostringstream os; + os << "rolling_buffer injection is invalid. The block {0} should be tiled so that " + << "the lca of the access location of the target buffer " << buffer_->name + << " is a for loop. "; + return os.str(); + } + IRModule mod() const final { return mod_; } + Array<ObjectRef> LocationsOfInterest() const final { return {block_}; } + + private: + IRModule mod_; + Buffer buffer_; + Block block_; +}; + +class RollingBufferInfoCollector { + public: + static RollingBufferInfo CheckAndGetRollingBufferInfo(const IRModule& mod, + const StmtSRef& block_sref, + const BufferRegion& buffer_region) { + RollingBufferInfoCollector collector; + if (!collector.MatchRollingBuffer(block_sref, buffer_region)) { + const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); + throw RollingBufferMatchError(mod, GetRef<Block>(block), buffer_region); + } + return collector.info_; + } + + private: + bool MatchRollingBuffer(const StmtSRef& block_sref, const BufferRegion& buffer_region) { + const Buffer& buffer = buffer_region->buffer; + const Region& region = buffer_region->region; + + std::vector<Optional<Var>> bound_iter_vars; + std::vector<int> bound_overlaps; + auto stride = 0; + auto divisor = 1; + Optional<Var> iter_var; + for (auto bound : region) { + divisor = 1; + if (auto floor_div = bound->min.as<FloorDivNode>()) { + // Handle the case of fractional strides + // They take this form: floordiv(hh.outer, 2) + // Strip the floordiv and keep track of the divisor + divisor = Downcast<IntImm>(floor_div->b)->value; + bound = Range::FromMinExtent(floor_div->a, bound->extent, bound->span); + } + if (bound->min.as<IntImmNode>()) { Review Comment: is_const_int(bound->min) ########## src/tir/schedule/primitive/rolling_buffer.cc: ########## @@ -0,0 +1,443 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include <functional> + +#include "../ir_comparator.h" +#include "../utils.h" + +namespace tvm { +namespace tir { + +namespace { + +struct RollingBufferInfo { + Buffer old_buffer; + Buffer new_buffer; + int rolling_axis; + int rolling_extent; + std::vector<int> axis_overlaps; + std::vector<Optional<Var>> axis_iter_vars; + /*! \brief The map used for ScheduleStateNode::Replace. 
*/ + Map<Block, Block> block_reuse; +}; + +BufferRegion GetRelaxedBufferRegion(const BlockRealize& realize, const BufferRegion& buffer_region, + const Map<Var, arith::IntSet>& dom_map) { + Array<arith::IntSet> relaxed_intsets = + arith::EvalSet(Substitute(buffer_region->region, GetBindings(realize)), dom_map); + Region relaxed_region; + relaxed_region.reserve(relaxed_intsets.size()); + for (size_t i = 0; i < relaxed_intsets.size(); ++i) { + relaxed_region.push_back( + relaxed_intsets[i].CoverRange(Range::FromMinExtent(0, buffer_region->buffer->shape[i]))); + } + return BufferRegion(buffer_region->buffer, relaxed_region); +} + +class RollingBufferMatchError : public ScheduleError { + public: + RollingBufferMatchError(IRModule mod, Block block, BufferRegion buffer_region) + : mod_(mod), block_(block), buffer_region_(buffer_region) {} + String FastErrorString() const final { + return "ScheduleError: rolling_buffer expect the buffer region to have at least one dimention" + "matching the rolling pattern such as: hh.outer * stride + hh.inner"; + } + String DetailRenderTemplate() const final { + std::ostringstream os; + os << "The target buffer " << buffer_region_->buffer->name << " with region " + << buffer_region_->region + << " should have at least one dimension range that matches a rolling pattern " + "such as hh.outer * stride + hh.inner. "; + return os.str(); + } + + IRModule mod() const final { return mod_; } + Array<ObjectRef> LocationsOfInterest() const final { return {block_}; } + + private: + IRModule mod_; + Block block_; + BufferRegion buffer_region_; +}; + +class RollingBufferInsertionError : public ScheduleError { + public: + RollingBufferInsertionError(IRModule mod, Buffer buffer, Block block) + : mod_(mod), buffer_(std::move(buffer)), block_(block) {} + String FastErrorString() const final { + return "ScheduleError: rolling_buffer injection is invalid, the lca of the access " + "location of the target buffer is not a for loop. 
"; + } + + String DetailRenderTemplate() const final { + std::ostringstream os; + os << "rolling_buffer injection is invalid. The block {0} should be tiled so that " + << "the lca of the access location of the target buffer " << buffer_->name + << " is a for loop. "; + return os.str(); + } + IRModule mod() const final { return mod_; } + Array<ObjectRef> LocationsOfInterest() const final { return {block_}; } + + private: + IRModule mod_; + Buffer buffer_; + Block block_; +}; + +class RollingBufferInfoCollector { + public: + static RollingBufferInfo CheckAndGetRollingBufferInfo(const IRModule& mod, + const StmtSRef& block_sref, + const BufferRegion& buffer_region) { + RollingBufferInfoCollector collector; + if (!collector.MatchRollingBuffer(block_sref, buffer_region)) { + const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); + throw RollingBufferMatchError(mod, GetRef<Block>(block), buffer_region); + } + return collector.info_; + } + + private: + bool MatchRollingBuffer(const StmtSRef& block_sref, const BufferRegion& buffer_region) { + const Buffer& buffer = buffer_region->buffer; + const Region& region = buffer_region->region; + + std::vector<Optional<Var>> bound_iter_vars; + std::vector<int> bound_overlaps; + auto stride = 0; + auto divisor = 1; + Optional<Var> iter_var; + for (auto bound : region) { + divisor = 1; + if (auto floor_div = bound->min.as<FloorDivNode>()) { + // Handle the case of fractional strides + // They take this form: floordiv(hh.outer, 2) + // Strip the floordiv and keep track of the divisor + divisor = Downcast<IntImm>(floor_div->b)->value; + bound = Range::FromMinExtent(floor_div->a, bound->extent, bound->span); + } + if (bound->min.as<IntImmNode>()) { + // If the bound is an int, we can't roll over it + iter_var = NullOpt; + } else if (auto var = bound->min.as<VarNode>()) { + // If the bound is just a Var, that implies the stride is 1 + iter_var = GetRef<Var>(var); + stride = 1; + } else { + // Otherwise, it's the iter var multiplied by 
the stride + // If not we're in unknown behaviour + if (auto mul = bound->min.as<MulNode>()) { + if (mul->a->IsInstance<VarNode>() && mul->b->IsInstance<IntImmNode>()) { + iter_var = Downcast<Var>(mul->a); + stride = Downcast<IntImm>(mul->b)->value; + } else { + return false; + } + } else { + return false; + } + } + stride = std::ceil(static_cast<float>(stride) / divisor); + auto bound_overlap = 0; + if (iter_var.defined()) { + auto extent = bound->extent.as<IntImmNode>(); + ICHECK(extent); + bound_overlap = extent->value - stride; + // Since Pass CompactBufferAllocation will be responsible for compacting the buffer + // allocation region, there is no need to roll over the axis where the overlap is not + // positive, so reset iter_var to NullOpt. + if (bound_overlap <= 0) { + iter_var = NullOpt; + } + } + bound_iter_vars.push_back(iter_var); + bound_overlaps.push_back(bound_overlap); + } + Array<StmtSRef> loop_srefs = GetLoops(block_sref); + // Pick the outermost iter_var that's mentioned in the bounds + // to be the rolling axis + Optional<Var> roll_iter_var; + int roll_axis; + for (const tir::StmtSRef& loop_sref : loop_srefs) { + auto loop_var = loop_sref->StmtAs<ForNode>()->loop_var; + + auto it{std::find_if(bound_iter_vars.begin(), bound_iter_vars.end(), [&](Optional<Var> var) { + return var && (var.get() == loop_var.get()); + })}; + if (it != bound_iter_vars.end()) { + auto i = std::distance(bound_iter_vars.begin(), it); + roll_iter_var = loop_var; + roll_axis = i; + break; + } + } + + if (!roll_iter_var.defined()) { + return false; + } + Array<PrimExpr> new_shape = buffer->shape; + new_shape.Set(roll_axis, region[roll_axis]->extent); + Buffer new_buffer = buffer; + new_buffer.CopyOnWrite()->shape = new_shape; + + info_.old_buffer = buffer; + info_.new_buffer = new_buffer; + info_.rolling_axis = roll_axis; + info_.rolling_extent = static_cast<int>(Downcast<IntImm>(region[roll_axis]->extent)->value); + info_.axis_overlaps = bound_overlaps; + 
info_.axis_iter_vars = bound_iter_vars; + + return true; + } + + RollingBufferInfo info_; +}; + +class RollingBufferRewriter : public StmtExprMutator { + public: + static Stmt Rewrite(const StmtSRef& scope_sref, RollingBufferInfo* info) { + RollingBufferRewriter rewriter(scope_sref, info); + return rewriter(GetRef<Stmt>(scope_sref->stmt)); + } + + private: + explicit RollingBufferRewriter(const StmtSRef& scope_sref, RollingBufferInfo* info) + : scope_sref_(scope_sref), info_(info) {} + + void RewriteAccessRegion(Array<BufferRegion>* old_access_regions, + const Array<BufferRegion>& infered_access_regions) { + auto fmutate = [this, &infered_access_regions](const BufferRegion& buffer_region) { + if (buffer_region->buffer.same_as(info_->old_buffer)) { + ICHECK(infered_access_regions.size() == 1); + return infered_access_regions[0]; + } + return buffer_region; + }; + (*old_access_regions).MutateByApply(fmutate); + } + + Stmt VisitStmt_(const BlockNode* block) final { + Block old_stmt = GetRef<Block>(block); + Block stmt = Downcast<Block>(StmtExprMutator::VisitStmt_(block)); + if (block == scope_sref_->stmt) { + ObjectPtr<BlockNode> n = make_object<BlockNode>(*stmt.as<BlockNode>()); Review Comment: Could be `CopyOnWrite(block)`? ########## include/tvm/tir/schedule/schedule.h: ########## @@ -681,6 +681,9 @@ class ScheduleNode : public runtime::Object { */ virtual void PadEinsum(const BlockRV& block_rv, const Array<Integer>& padding) = 0; + /******** Schedule: Buffer transformation ********/ + virtual void RollingBuffer(const BlockRV& block_rv, int buffer_index) = 0; Review Comment: It would be great to use name `read_buffer_index` or `write_buffer_index` ########## src/tir/schedule/primitive/rolling_buffer.cc: ########## @@ -0,0 +1,443 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include <functional> + +#include "../ir_comparator.h" +#include "../utils.h" + +namespace tvm { +namespace tir { + +namespace { + +struct RollingBufferInfo { + Buffer old_buffer; + Buffer new_buffer; + int rolling_axis; + int rolling_extent; + std::vector<int> axis_overlaps; + std::vector<Optional<Var>> axis_iter_vars; + /*! \brief The map used for ScheduleStateNode::Replace. */ + Map<Block, Block> block_reuse; +}; + +BufferRegion GetRelaxedBufferRegion(const BlockRealize& realize, const BufferRegion& buffer_region, + const Map<Var, arith::IntSet>& dom_map) { + Array<arith::IntSet> relaxed_intsets = + arith::EvalSet(Substitute(buffer_region->region, GetBindings(realize)), dom_map); + Region relaxed_region; + relaxed_region.reserve(relaxed_intsets.size()); + for (size_t i = 0; i < relaxed_intsets.size(); ++i) { + relaxed_region.push_back( + relaxed_intsets[i].CoverRange(Range::FromMinExtent(0, buffer_region->buffer->shape[i]))); + } + return BufferRegion(buffer_region->buffer, relaxed_region); +} + +class RollingBufferMatchError : public ScheduleError { + public: + RollingBufferMatchError(IRModule mod, Block block, BufferRegion buffer_region) + : mod_(mod), block_(block), buffer_region_(buffer_region) {} + String FastErrorString() const final { + return "ScheduleError: rolling_buffer expect the buffer region to have at least one dimention" + "matching the rolling pattern such as: 
hh.outer * stride + hh.inner"; + } + String DetailRenderTemplate() const final { + std::ostringstream os; + os << "The target buffer " << buffer_region_->buffer->name << " with region " + << buffer_region_->region + << " should have at least one dimension range that matches a rolling pattern " + "such as hh.outer * stride + hh.inner. "; + return os.str(); + } + + IRModule mod() const final { return mod_; } + Array<ObjectRef> LocationsOfInterest() const final { return {block_}; } + + private: + IRModule mod_; + Block block_; + BufferRegion buffer_region_; +}; + +class RollingBufferInsertionError : public ScheduleError { + public: + RollingBufferInsertionError(IRModule mod, Buffer buffer, Block block) + : mod_(mod), buffer_(std::move(buffer)), block_(block) {} + String FastErrorString() const final { + return "ScheduleError: rolling_buffer injection is invalid, the lca of the access " + "location of the target buffer is not a for loop. "; + } + + String DetailRenderTemplate() const final { + std::ostringstream os; + os << "rolling_buffer injection is invalid. The block {0} should be tiled so that " + << "the lca of the access location of the target buffer " << buffer_->name + << " is a for loop. 
"; + return os.str(); + } + IRModule mod() const final { return mod_; } + Array<ObjectRef> LocationsOfInterest() const final { return {block_}; } + + private: + IRModule mod_; + Buffer buffer_; + Block block_; +}; + +class RollingBufferInfoCollector { + public: + static RollingBufferInfo CheckAndGetRollingBufferInfo(const IRModule& mod, + const StmtSRef& block_sref, + const BufferRegion& buffer_region) { + RollingBufferInfoCollector collector; + if (!collector.MatchRollingBuffer(block_sref, buffer_region)) { + const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); + throw RollingBufferMatchError(mod, GetRef<Block>(block), buffer_region); + } + return collector.info_; + } + + private: + bool MatchRollingBuffer(const StmtSRef& block_sref, const BufferRegion& buffer_region) { + const Buffer& buffer = buffer_region->buffer; + const Region& region = buffer_region->region; + + std::vector<Optional<Var>> bound_iter_vars; + std::vector<int> bound_overlaps; + auto stride = 0; + auto divisor = 1; + Optional<Var> iter_var; + for (auto bound : region) { + divisor = 1; + if (auto floor_div = bound->min.as<FloorDivNode>()) { + // Handle the case of fractional strides + // They take this form: floordiv(hh.outer, 2) + // Strip the floordiv and keep track of the divisor + divisor = Downcast<IntImm>(floor_div->b)->value; + bound = Range::FromMinExtent(floor_div->a, bound->extent, bound->span); + } + if (bound->min.as<IntImmNode>()) { + // If the bound is an int, we can't roll over it + iter_var = NullOpt; + } else if (auto var = bound->min.as<VarNode>()) { + // If the bound is just a Var, that implies the stride is 1 + iter_var = GetRef<Var>(var); + stride = 1; + } else { + // Otherwise, it's the iter var multiplied by the stride + // If not we're in unknown behaviour + if (auto mul = bound->min.as<MulNode>()) { + if (mul->a->IsInstance<VarNode>() && mul->b->IsInstance<IntImmNode>()) { + iter_var = Downcast<Var>(mul->a); + stride = Downcast<IntImm>(mul->b)->value; + } else { + 
return false; + } + } else { + return false; + } + } + stride = std::ceil(static_cast<float>(stride) / divisor); + auto bound_overlap = 0; + if (iter_var.defined()) { + auto extent = bound->extent.as<IntImmNode>(); + ICHECK(extent); + bound_overlap = extent->value - stride; + // Since Pass CompactBufferAllocation will be responsible for compacting the buffer + // allocation region, there is no need to roll over the axis where the overlap is not + // positive, so reset iter_var to NullOpt. + if (bound_overlap <= 0) { + iter_var = NullOpt; + } + } + bound_iter_vars.push_back(iter_var); + bound_overlaps.push_back(bound_overlap); + } + Array<StmtSRef> loop_srefs = GetLoops(block_sref); + // Pick the outermost iter_var that's mentioned in the bounds + // to be the rolling axis + Optional<Var> roll_iter_var; + int roll_axis; + for (const tir::StmtSRef& loop_sref : loop_srefs) { + auto loop_var = loop_sref->StmtAs<ForNode>()->loop_var; + + auto it{std::find_if(bound_iter_vars.begin(), bound_iter_vars.end(), [&](Optional<Var> var) { + return var && (var.get() == loop_var.get()); + })}; + if (it != bound_iter_vars.end()) { + auto i = std::distance(bound_iter_vars.begin(), it); + roll_iter_var = loop_var; + roll_axis = i; + break; + } + } + + if (!roll_iter_var.defined()) { + return false; + } + Array<PrimExpr> new_shape = buffer->shape; + new_shape.Set(roll_axis, region[roll_axis]->extent); + Buffer new_buffer = buffer; + new_buffer.CopyOnWrite()->shape = new_shape; + + info_.old_buffer = buffer; + info_.new_buffer = new_buffer; + info_.rolling_axis = roll_axis; + info_.rolling_extent = static_cast<int>(Downcast<IntImm>(region[roll_axis]->extent)->value); Review Comment: Could we store rolling_extent as PrimExpr to avoid int cast? ########## src/tir/schedule/primitive/rolling_buffer.cc: ########## @@ -0,0 +1,443 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include <functional> + +#include "../ir_comparator.h" +#include "../utils.h" + +namespace tvm { +namespace tir { + +namespace { + +struct RollingBufferInfo { + Buffer old_buffer; + Buffer new_buffer; + int rolling_axis; + int rolling_extent; + std::vector<int> axis_overlaps; + std::vector<Optional<Var>> axis_iter_vars; + /*! \brief The map used for ScheduleStateNode::Replace. 
*/ + Map<Block, Block> block_reuse; +}; + +BufferRegion GetRelaxedBufferRegion(const BlockRealize& realize, const BufferRegion& buffer_region, + const Map<Var, arith::IntSet>& dom_map) { + Array<arith::IntSet> relaxed_intsets = + arith::EvalSet(Substitute(buffer_region->region, GetBindings(realize)), dom_map); + Region relaxed_region; + relaxed_region.reserve(relaxed_intsets.size()); + for (size_t i = 0; i < relaxed_intsets.size(); ++i) { + relaxed_region.push_back( + relaxed_intsets[i].CoverRange(Range::FromMinExtent(0, buffer_region->buffer->shape[i]))); + } + return BufferRegion(buffer_region->buffer, relaxed_region); +} + +class RollingBufferMatchError : public ScheduleError { + public: + RollingBufferMatchError(IRModule mod, Block block, BufferRegion buffer_region) + : mod_(mod), block_(block), buffer_region_(buffer_region) {} + String FastErrorString() const final { + return "ScheduleError: rolling_buffer expect the buffer region to have at least one dimention" + "matching the rolling pattern such as: hh.outer * stride + hh.inner"; + } + String DetailRenderTemplate() const final { + std::ostringstream os; + os << "The target buffer " << buffer_region_->buffer->name << " with region " + << buffer_region_->region + << " should have at least one dimension range that matches a rolling pattern " + "such as hh.outer * stride + hh.inner. "; + return os.str(); + } + + IRModule mod() const final { return mod_; } + Array<ObjectRef> LocationsOfInterest() const final { return {block_}; } + + private: + IRModule mod_; + Block block_; + BufferRegion buffer_region_; +}; + +class RollingBufferInsertionError : public ScheduleError { + public: + RollingBufferInsertionError(IRModule mod, Buffer buffer, Block block) + : mod_(mod), buffer_(std::move(buffer)), block_(block) {} + String FastErrorString() const final { + return "ScheduleError: rolling_buffer injection is invalid, the lca of the access " + "location of the target buffer is not a for loop. 
"; + } + + String DetailRenderTemplate() const final { + std::ostringstream os; + os << "rolling_buffer injection is invalid. The block {0} should be tiled so that " + << "the lca of the access location of the target buffer " << buffer_->name + << " is a for loop. "; + return os.str(); + } + IRModule mod() const final { return mod_; } + Array<ObjectRef> LocationsOfInterest() const final { return {block_}; } + + private: + IRModule mod_; + Buffer buffer_; + Block block_; +}; + +class RollingBufferInfoCollector { + public: + static RollingBufferInfo CheckAndGetRollingBufferInfo(const IRModule& mod, + const StmtSRef& block_sref, + const BufferRegion& buffer_region) { + RollingBufferInfoCollector collector; + if (!collector.MatchRollingBuffer(block_sref, buffer_region)) { + const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); + throw RollingBufferMatchError(mod, GetRef<Block>(block), buffer_region); + } + return collector.info_; + } + + private: + bool MatchRollingBuffer(const StmtSRef& block_sref, const BufferRegion& buffer_region) { + const Buffer& buffer = buffer_region->buffer; + const Region& region = buffer_region->region; + + std::vector<Optional<Var>> bound_iter_vars; + std::vector<int> bound_overlaps; + auto stride = 0; + auto divisor = 1; + Optional<Var> iter_var; + for (auto bound : region) { + divisor = 1; + if (auto floor_div = bound->min.as<FloorDivNode>()) { + // Handle the case of fractional strides + // They take this form: floordiv(hh.outer, 2) + // Strip the floordiv and keep track of the divisor + divisor = Downcast<IntImm>(floor_div->b)->value; + bound = Range::FromMinExtent(floor_div->a, bound->extent, bound->span); + } + if (bound->min.as<IntImmNode>()) { + // If the bound is an int, we can't roll over it + iter_var = NullOpt; + } else if (auto var = bound->min.as<VarNode>()) { + // If the bound is just a Var, that implies the stride is 1 + iter_var = GetRef<Var>(var); + stride = 1; + } else { + // Otherwise, it's the iter var multiplied by 
the stride + // If not we're in unknown behaviour + if (auto mul = bound->min.as<MulNode>()) { + if (mul->a->IsInstance<VarNode>() && mul->b->IsInstance<IntImmNode>()) { + iter_var = Downcast<Var>(mul->a); + stride = Downcast<IntImm>(mul->b)->value; + } else { + return false; + } + } else { + return false; + } + } + stride = std::ceil(static_cast<float>(stride) / divisor); + auto bound_overlap = 0; + if (iter_var.defined()) { + auto extent = bound->extent.as<IntImmNode>(); + ICHECK(extent); + bound_overlap = extent->value - stride; + // Since Pass CompactBufferAllocation will be responsible for compacting the buffer + // allocation region, there is no need to roll over the axis where the overlap is not + // positive, so reset iter_var to NullOpt. + if (bound_overlap <= 0) { + iter_var = NullOpt; + } + } + bound_iter_vars.push_back(iter_var); + bound_overlaps.push_back(bound_overlap); + } + Array<StmtSRef> loop_srefs = GetLoops(block_sref); + // Pick the outermost iter_var that's mentioned in the bounds + // to be the rolling axis + Optional<Var> roll_iter_var; + int roll_axis; + for (const tir::StmtSRef& loop_sref : loop_srefs) { + auto loop_var = loop_sref->StmtAs<ForNode>()->loop_var; + + auto it{std::find_if(bound_iter_vars.begin(), bound_iter_vars.end(), [&](Optional<Var> var) { + return var && (var.get() == loop_var.get()); + })}; + if (it != bound_iter_vars.end()) { + auto i = std::distance(bound_iter_vars.begin(), it); + roll_iter_var = loop_var; + roll_axis = i; + break; + } + } + + if (!roll_iter_var.defined()) { + return false; + } + Array<PrimExpr> new_shape = buffer->shape; + new_shape.Set(roll_axis, region[roll_axis]->extent); + Buffer new_buffer = buffer; + new_buffer.CopyOnWrite()->shape = new_shape; + + info_.old_buffer = buffer; + info_.new_buffer = new_buffer; + info_.rolling_axis = roll_axis; + info_.rolling_extent = static_cast<int>(Downcast<IntImm>(region[roll_axis]->extent)->value); + info_.axis_overlaps = bound_overlaps; + 
info_.axis_iter_vars = bound_iter_vars; + + return true; + } + + RollingBufferInfo info_; +}; + +class RollingBufferRewriter : public StmtExprMutator { + public: + static Stmt Rewrite(const StmtSRef& scope_sref, RollingBufferInfo* info) { + RollingBufferRewriter rewriter(scope_sref, info); + return rewriter(GetRef<Stmt>(scope_sref->stmt)); + } + + private: + explicit RollingBufferRewriter(const StmtSRef& scope_sref, RollingBufferInfo* info) + : scope_sref_(scope_sref), info_(info) {} + + void RewriteAccessRegion(Array<BufferRegion>* old_access_regions, + const Array<BufferRegion>& infered_access_regions) { + auto fmutate = [this, &infered_access_regions](const BufferRegion& buffer_region) { + if (buffer_region->buffer.same_as(info_->old_buffer)) { + ICHECK(infered_access_regions.size() == 1); + return infered_access_regions[0]; + } + return buffer_region; + }; + (*old_access_regions).MutateByApply(fmutate); + } + + Stmt VisitStmt_(const BlockNode* block) final { + Block old_stmt = GetRef<Block>(block); + Block stmt = Downcast<Block>(StmtExprMutator::VisitStmt_(block)); + if (block == scope_sref_->stmt) { + ObjectPtr<BlockNode> n = make_object<BlockNode>(*stmt.as<BlockNode>()); + + Array<Buffer> new_alloc_buffers; + for (const Buffer& buffer : stmt->alloc_buffers) { + if (buffer != info_->old_buffer) { + new_alloc_buffers.push_back(buffer); + } else { + new_alloc_buffers.push_back(info_->new_buffer); + } + } + n->alloc_buffers = std::move(new_alloc_buffers); + stmt = Block(n); + } else { + Array<IterVar> new_iter_bindings; Review Comment: This variable is actually `new_iter_vars`? ########## src/tir/schedule/primitive/rolling_buffer.cc: ########## @@ -0,0 +1,443 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include <functional> + +#include "../ir_comparator.h" +#include "../utils.h" + +namespace tvm { +namespace tir { + +namespace { + +struct RollingBufferInfo { + Buffer old_buffer; + Buffer new_buffer; + int rolling_axis; + int rolling_extent; + std::vector<int> axis_overlaps; + std::vector<Optional<Var>> axis_iter_vars; + /*! \brief The map used for ScheduleStateNode::Replace. */ + Map<Block, Block> block_reuse; +}; + +BufferRegion GetRelaxedBufferRegion(const BlockRealize& realize, const BufferRegion& buffer_region, + const Map<Var, arith::IntSet>& dom_map) { + Array<arith::IntSet> relaxed_intsets = + arith::EvalSet(Substitute(buffer_region->region, GetBindings(realize)), dom_map); + Region relaxed_region; + relaxed_region.reserve(relaxed_intsets.size()); + for (size_t i = 0; i < relaxed_intsets.size(); ++i) { + relaxed_region.push_back( + relaxed_intsets[i].CoverRange(Range::FromMinExtent(0, buffer_region->buffer->shape[i]))); + } + return BufferRegion(buffer_region->buffer, relaxed_region); +} + +class RollingBufferMatchError : public ScheduleError { + public: + RollingBufferMatchError(IRModule mod, Block block, BufferRegion buffer_region) + : mod_(mod), block_(block), buffer_region_(buffer_region) {} + String FastErrorString() const final { + return "ScheduleError: rolling_buffer expect the buffer region to have at least one dimention" + "matching the rolling pattern such as: 
hh.outer * stride + hh.inner"; + } + String DetailRenderTemplate() const final { + std::ostringstream os; + os << "The target buffer " << buffer_region_->buffer->name << " with region " + << buffer_region_->region + << " should have at least one dimension range that matches a rolling pattern " + "such as hh.outer * stride + hh.inner. "; + return os.str(); + } + + IRModule mod() const final { return mod_; } + Array<ObjectRef> LocationsOfInterest() const final { return {block_}; } + + private: + IRModule mod_; + Block block_; + BufferRegion buffer_region_; +}; + +class RollingBufferInsertionError : public ScheduleError { + public: + RollingBufferInsertionError(IRModule mod, Buffer buffer, Block block) + : mod_(mod), buffer_(std::move(buffer)), block_(block) {} + String FastErrorString() const final { + return "ScheduleError: rolling_buffer injection is invalid, the lca of the access " + "location of the target buffer is not a for loop. "; + } + + String DetailRenderTemplate() const final { + std::ostringstream os; + os << "rolling_buffer injection is invalid. The block {0} should be tiled so that " + << "the lca of the access location of the target buffer " << buffer_->name + << " is a for loop. 
"; + return os.str(); + } + IRModule mod() const final { return mod_; } + Array<ObjectRef> LocationsOfInterest() const final { return {block_}; } + + private: + IRModule mod_; + Buffer buffer_; + Block block_; +}; + +class RollingBufferInfoCollector { + public: + static RollingBufferInfo CheckAndGetRollingBufferInfo(const IRModule& mod, + const StmtSRef& block_sref, + const BufferRegion& buffer_region) { + RollingBufferInfoCollector collector; + if (!collector.MatchRollingBuffer(block_sref, buffer_region)) { + const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); + throw RollingBufferMatchError(mod, GetRef<Block>(block), buffer_region); + } + return collector.info_; + } + + private: + bool MatchRollingBuffer(const StmtSRef& block_sref, const BufferRegion& buffer_region) { + const Buffer& buffer = buffer_region->buffer; + const Region& region = buffer_region->region; + + std::vector<Optional<Var>> bound_iter_vars; + std::vector<int> bound_overlaps; + auto stride = 0; + auto divisor = 1; + Optional<Var> iter_var; + for (auto bound : region) { + divisor = 1; + if (auto floor_div = bound->min.as<FloorDivNode>()) { + // Handle the case of fractional strides + // They take this form: floordiv(hh.outer, 2) + // Strip the floordiv and keep track of the divisor + divisor = Downcast<IntImm>(floor_div->b)->value; + bound = Range::FromMinExtent(floor_div->a, bound->extent, bound->span); + } + if (bound->min.as<IntImmNode>()) { + // If the bound is an int, we can't roll over it + iter_var = NullOpt; + } else if (auto var = bound->min.as<VarNode>()) { + // If the bound is just a Var, that implies the stride is 1 + iter_var = GetRef<Var>(var); + stride = 1; + } else { + // Otherwise, it's the iter var multiplied by the stride + // If not we're in unknown behaviour + if (auto mul = bound->min.as<MulNode>()) { + if (mul->a->IsInstance<VarNode>() && mul->b->IsInstance<IntImmNode>()) { + iter_var = Downcast<Var>(mul->a); + stride = Downcast<IntImm>(mul->b)->value; + } else { + 
return false; + } + } else { + return false; + } + } + stride = std::ceil(static_cast<float>(stride) / divisor); + auto bound_overlap = 0; + if (iter_var.defined()) { + auto extent = bound->extent.as<IntImmNode>(); + ICHECK(extent); + bound_overlap = extent->value - stride; + // Since Pass CompactBufferAllocation will be responsible for compacting the buffer + // allocation region, there is no need to roll over the axis where the overlap is not + // positive, so reset iter_var to NullOpt. + if (bound_overlap <= 0) { + iter_var = NullOpt; + } + } + bound_iter_vars.push_back(iter_var); + bound_overlaps.push_back(bound_overlap); + } + Array<StmtSRef> loop_srefs = GetLoops(block_sref); + // Pick the outermost iter_var that's mentioned in the bounds + // to be the rolling axis + Optional<Var> roll_iter_var; + int roll_axis; + for (const tir::StmtSRef& loop_sref : loop_srefs) { + auto loop_var = loop_sref->StmtAs<ForNode>()->loop_var; + + auto it{std::find_if(bound_iter_vars.begin(), bound_iter_vars.end(), [&](Optional<Var> var) { + return var && (var.get() == loop_var.get()); + })}; + if (it != bound_iter_vars.end()) { + auto i = std::distance(bound_iter_vars.begin(), it); + roll_iter_var = loop_var; + roll_axis = i; + break; + } + } + + if (!roll_iter_var.defined()) { + return false; + } + Array<PrimExpr> new_shape = buffer->shape; + new_shape.Set(roll_axis, region[roll_axis]->extent); + Buffer new_buffer = buffer; + new_buffer.CopyOnWrite()->shape = new_shape; + + info_.old_buffer = buffer; + info_.new_buffer = new_buffer; + info_.rolling_axis = roll_axis; + info_.rolling_extent = static_cast<int>(Downcast<IntImm>(region[roll_axis]->extent)->value); + info_.axis_overlaps = bound_overlaps; + info_.axis_iter_vars = bound_iter_vars; + + return true; + } + + RollingBufferInfo info_; +}; + +class RollingBufferRewriter : public StmtExprMutator { + public: + static Stmt Rewrite(const StmtSRef& scope_sref, RollingBufferInfo* info) { + RollingBufferRewriter 
rewriter(scope_sref, info); + return rewriter(GetRef<Stmt>(scope_sref->stmt)); + } + + private: + explicit RollingBufferRewriter(const StmtSRef& scope_sref, RollingBufferInfo* info) + : scope_sref_(scope_sref), info_(info) {} + + void RewriteAccessRegion(Array<BufferRegion>* old_access_regions, + const Array<BufferRegion>& infered_access_regions) { + auto fmutate = [this, &infered_access_regions](const BufferRegion& buffer_region) { + if (buffer_region->buffer.same_as(info_->old_buffer)) { + ICHECK(infered_access_regions.size() == 1); + return infered_access_regions[0]; + } + return buffer_region; + }; + (*old_access_regions).MutateByApply(fmutate); + } + + Stmt VisitStmt_(const BlockNode* block) final { + Block old_stmt = GetRef<Block>(block); + Block stmt = Downcast<Block>(StmtExprMutator::VisitStmt_(block)); + if (block == scope_sref_->stmt) { + ObjectPtr<BlockNode> n = make_object<BlockNode>(*stmt.as<BlockNode>()); + + Array<Buffer> new_alloc_buffers; + for (const Buffer& buffer : stmt->alloc_buffers) { + if (buffer != info_->old_buffer) { + new_alloc_buffers.push_back(buffer); + } else { + new_alloc_buffers.push_back(info_->new_buffer); + } + } + n->alloc_buffers = std::move(new_alloc_buffers); + stmt = Block(n); + } else { + Array<IterVar> new_iter_bindings; + for (size_t i = 0; i < stmt->iter_vars.size(); ++i) { + auto old_iter_var = stmt->iter_vars[i]; + if (static_cast<int>(i) == info_->rolling_axis) { + // All inner loops of the rolling axis has a loop carried dependency + // (i.e. each iteration calculation of the rolling axis depends on + // the calculation results of all the historical iterations of inner loops), + // so annotate the iteration type of the rolling axis as 'opaque', + // avoid the iterative range of its inner loop from being compressed + // during lowering phase. 
+ IterVar new_iter_var = + IterVar(old_iter_var->dom, old_iter_var->var, IterVarType::kOpaque); + new_iter_bindings.push_back(new_iter_var); + } else { + new_iter_bindings.push_back(old_iter_var); + } + } + Map<Var, Buffer> buffer_data_to_buffer = {{info_->new_buffer->data, info_->new_buffer}}; + auto infered_access_regions = GetBlockReadWriteRegion(stmt, buffer_data_to_buffer); + + BlockNode* n = stmt.CopyOnWrite(); + n->iter_vars = std::move(new_iter_bindings); + RewriteAccessRegion(&n->reads, infered_access_regions[0]); + RewriteAccessRegion(&n->writes, infered_access_regions[1]); + } + info_->block_reuse.Set(old_stmt, stmt); + return std::move(stmt); + } + + Stmt VisitStmt_(const BlockRealizeNode* realize) final { + BlockRealize stmt = Downcast<BlockRealize>(StmtExprMutator::VisitStmt_(realize)); + // Append block predicate to avoid recomputing elements. + if (rewrite_block_predicate_) { + rewrite_block_predicate_ = false; + PrimExpr condition = stmt->predicate; + for (size_t i = 0; i < info_->axis_iter_vars.size(); ++i) { + auto iter_var = info_->axis_iter_vars[i]; + if (iter_var && info_->axis_overlaps[i] > 0) { + Var var = iter_var.value(); + const Map<Var, arith::IntSet> dmap = {std::make_pair(var, arith::IntSet::Interval(0, 0))}; + auto iter_value = realize->iter_values[i]; + arith::Analyzer analyzer; + auto term_2 = analyzer.int_set(iter_value, dmap).min(); + condition = analyzer.Simplify( + And(condition, Or(LT(var, 1), GE(term_2, info_->axis_overlaps[i])))); + } + } + BlockRealizeNode* n = stmt.CopyOnWrite(); + n->predicate = condition; + } + return std::move(stmt); + } + + Stmt VisitStmt_(const BufferStoreNode* op) final { + BufferStore stmt = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op)); + if (stmt->buffer.same_as(info_->old_buffer)) { + Array<PrimExpr> new_indices; + new_indices.reserve(stmt->indices.size()); + // First modify the access indices to use modulo arithmetic + // for the rolling axis + for (size_t i = 0; i < 
stmt->indices.size(); ++i) { + auto index = stmt->indices[i]; + if (static_cast<int>(i) == info_->rolling_axis) { + new_indices.push_back(FloorMod(index, info_->rolling_extent)); + } else { + new_indices.push_back(index); + } + } + BufferStoreNode* n = stmt.CopyOnWrite(); + // Replace the stored buffer with the new buffer. + n->buffer = info_->new_buffer; + n->indices = std::move(new_indices); + // Need to add predicate to the current block to avoid recomputing elements. + rewrite_block_predicate_ = true; + } + return std::move(stmt); + } + + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + BufferLoad stmt = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op)); + if (stmt->buffer.same_as(info_->old_buffer)) { + Array<PrimExpr> new_indices; Review Comment: The implementation here seems to duplicate the index-rewriting logic in the BufferStore visitor. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
