This is an automated email from the ASF dual-hosted git repository. junrushao pushed a commit to branch revert-9880-encode_conditional_accesses_in_read_write_annotations in repository https://gitbox.apache.org/repos/asf/tvm.git
commit 2e14b8565515b3940f1624de61fc23d3446c4cbe Author: Junru Shao <[email protected]> AuthorDate: Mon Jan 17 22:46:07 2022 -0800 Revert "[TIR] Encode conditional accesses info into block read/write regions (#9880)" This reverts commit 6f6fc68f5a028a92607b2907b9e4144543686639. --- src/tir/analysis/block_access_region_detector.cc | 29 ++-------- src/tir/transforms/compact_buffer_region.cc | 10 ++-- src/tir/transforms/ir_utils.cc | 62 +++++--------------- src/tir/transforms/ir_utils.h | 18 +++--- .../test_tir_analysis_get_block_access_region.py | 66 ---------------------- .../test_tir_transform_compact_buffer_region.py | 1 - 6 files changed, 30 insertions(+), 156 deletions(-) diff --git a/src/tir/analysis/block_access_region_detector.cc b/src/tir/analysis/block_access_region_detector.cc index 07dcace..776538a 100644 --- a/src/tir/analysis/block_access_region_detector.cc +++ b/src/tir/analysis/block_access_region_detector.cc @@ -56,8 +56,6 @@ class BlockReadWriteDetector : public StmtExprVisitor { private: /*! \brief Iteration range for loop_vars */ std::unordered_map<const VarNode*, arith::IntSet> dom_map_; - /*! \brief Extra iteration range hint for free vars */ - std::unordered_map<const VarNode*, arith::IntSet> hint_map_; /*! \brief The buffers that the current block reads */ std::vector<Buffer> read_buffers_; /*! \brief The buffers that the current block writes */ @@ -98,9 +96,6 @@ class BlockReadWriteDetector : public StmtExprVisitor { /*! \brief Helper function to update a opaque access. */ void UpdateOpaque(const Var& buffer_var); - /*! 
\brief Helper function to relax the buffer indices */ - arith::IntSet RelaxAccessIndex(const PrimExpr& index); - void VisitStmt_(const ForNode* op) override; void VisitStmt_(const IfThenElseNode* op) override; void VisitStmt_(const BlockRealizeNode* op) override; @@ -145,22 +140,10 @@ void BlockReadWriteDetector::VisitExpr_(const LoadNode* op) { ExprVisitor::VisitExpr_(op); } -arith::IntSet BlockReadWriteDetector::RelaxAccessIndex(const PrimExpr& index) { - arith::IntSet relaxed = arith::EvalSet(index, dom_map_); - if (!hint_map_.empty()) { - // take non-relaxed var bound hints into considerations - // eg, if i * 4 + j with i >= 10 and j in [0, 4), only j in domain scope - // then the index region can be relaxed to [i*4, i*4+4) ^ [40, inf) - arith::IntSet hint_bound = arith::EvalSet(relaxed, hint_map_); - relaxed = arith::Intersect({relaxed, hint_bound}); - } - return relaxed; -} - void BlockReadWriteDetector::VisitExpr_(const BufferLoadNode* op) { std::vector<arith::IntSet> relaxed_region; for (const PrimExpr& index : op->indices) { - relaxed_region.push_back(RelaxAccessIndex(index)); + relaxed_region.push_back(arith::EvalSet(index, dom_map_)); } Update(&read_buffers_, &read_regions_, op->buffer, relaxed_region); ExprVisitor::VisitExpr_(op); @@ -177,12 +160,12 @@ void BlockReadWriteDetector::VisitStmt_(const IfThenElseNode* op) { VisitExpr(op->condition); { // Visit then branch - With<ConditionalBoundsContext> ctx(op->condition, &dom_map_, &hint_map_, true); + With<ConditionalBoundsContext> ctx(op->condition, &dom_map_, true); StmtExprVisitor::VisitStmt(op->then_case); } if (op->else_case.defined()) { // Visit else branch - With<ConditionalBoundsContext> ctx(op->condition, &dom_map_, &hint_map_, false); + With<ConditionalBoundsContext> ctx(op->condition, &dom_map_, false); StmtExprVisitor::VisitStmt(op->else_case); } } @@ -192,12 +175,12 @@ void BlockReadWriteDetector::VisitExpr_(const CallNode* op) { VisitExpr(op->args[0]); { // Visit then branch - 
With<ConditionalBoundsContext> ctx(op->args[0], &dom_map_, &hint_map_, true); + With<ConditionalBoundsContext> ctx(op->args[0], &dom_map_, true); StmtExprVisitor::VisitExpr(op->args[1]); } { // Visit else branch - With<ConditionalBoundsContext> ctx(op->args[0], &dom_map_, &hint_map_, false); + With<ConditionalBoundsContext> ctx(op->args[0], &dom_map_, false); StmtExprVisitor::VisitExpr(op->args[2]); } return; @@ -213,7 +196,7 @@ void BlockReadWriteDetector::VisitStmt_(const StoreNode* op) { void BlockReadWriteDetector::VisitStmt_(const BufferStoreNode* op) { std::vector<arith::IntSet> relaxed_region; for (const PrimExpr& index : op->indices) { - relaxed_region.push_back(RelaxAccessIndex(index)); + relaxed_region.push_back(arith::EvalSet(index, dom_map_)); } Update(&writes_buffers_, &write_regions_, op->buffer, relaxed_region); StmtVisitor::VisitStmt_(op); diff --git a/src/tir/transforms/compact_buffer_region.cc b/src/tir/transforms/compact_buffer_region.cc index 20ddd7f..07f9778 100644 --- a/src/tir/transforms/compact_buffer_region.cc +++ b/src/tir/transforms/compact_buffer_region.cc @@ -123,12 +123,12 @@ class BufferAccessRegionCollector : public StmtExprVisitor { StmtExprVisitor::VisitExpr(op->condition); { // Visit then branch - With<ConditionalBoundsContext> ctx(op->condition, &dom_map_, &hint_map_, true); + With<ConditionalBoundsContext> ctx(op->condition, &dom_map_, true); StmtExprVisitor::VisitStmt(op->then_case); } if (op->else_case.defined()) { // Visit else branch - With<ConditionalBoundsContext> ctx(op->condition, &dom_map_, &hint_map_, false); + With<ConditionalBoundsContext> ctx(op->condition, &dom_map_, false); StmtExprVisitor::VisitStmt(op->else_case); } } @@ -139,12 +139,12 @@ class BufferAccessRegionCollector : public StmtExprVisitor { StmtExprVisitor::VisitExpr(op->args[0]); { // Visit then branch - With<ConditionalBoundsContext> ctx(op->args[0], &dom_map_, &hint_map_, true); + With<ConditionalBoundsContext> ctx(op->args[0], &dom_map_, true); 
StmtExprVisitor::VisitExpr(op->args[1]); } { // Visit else branch - With<ConditionalBoundsContext> ctx(op->args[0], &dom_map_, &hint_map_, false); + With<ConditionalBoundsContext> ctx(op->args[0], &dom_map_, false); StmtExprVisitor::VisitExpr(op->args[2]); } return; @@ -282,8 +282,6 @@ class BufferAccessRegionCollector : public StmtExprVisitor { /*! \brief The map from loop vars to their iter range. */ std::unordered_map<const VarNode*, arith::IntSet> dom_map_; - /*! \brief Extra map from free vars to their iter range hints. */ - std::unordered_map<const VarNode*, arith::IntSet> hint_map_; /*! \brief The analyzer aware of loop domains. */ arith::Analyzer dom_analyzer_; /*! \brief The map from Buffer to it's relaxed access set. */ diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc index bc2f7ad..2423b09 100644 --- a/src/tir/transforms/ir_utils.cc +++ b/src/tir/transforms/ir_utils.cc @@ -300,18 +300,11 @@ Map<Var, Range> ConditionalBoundsContext::GetVarBoundsFromCondition() { Array<Var> vars = Array<Var>(var_set.begin(), var_set.end()); Map<Var, Range> ranges; for (const Var& v : vars) { - arith::IntSet dom; - auto relax_it = relax_map_->find(v.get()); - if (relax_it != relax_map_->end()) { - dom = relax_it->second; - } else { - auto hint_it = hint_map_->find(v.get()); - if (hint_it != hint_map_->end()) { - dom = hint_it->second; - } - } - if (dom.defined()) { - ranges.Set(v, Range::FromMinExtent(dom.min(), analyzer.Simplify(dom.max() - dom.min() + 1))); + auto it = dom_map_->find(v.get()); + if (it != dom_map_->end()) { + const auto& int_set = it->second; + ranges.Set(v, Range::FromMinExtent(int_set.min(), + analyzer.Simplify(int_set.max() - int_set.min() + 1))); } } // solve constraints @@ -321,53 +314,24 @@ Map<Var, Range> ConditionalBoundsContext::GetVarBoundsFromCondition() { } ConditionalBoundsContext::ConditionalBoundsContext( - const PrimExpr& condition, std::unordered_map<const VarNode*, arith::IntSet>* relax_map, - 
std::unordered_map<const VarNode*, arith::IntSet>* hint_map, bool is_true_branch) - : condition_(condition), - relax_map_(relax_map), - hint_map_(hint_map), - is_true_branch_(is_true_branch) {} + const PrimExpr& condition, std::unordered_map<const VarNode*, arith::IntSet>* dom_map, + bool is_true_branch) + : condition_(condition), dom_map_(dom_map), is_true_branch_(is_true_branch) {} void ConditionalBoundsContext::EnterWithScope() { for (const auto& p : GetVarBoundsFromCondition()) { const auto* var = p.first.get(); - arith::IntSet new_dom = arith::IntSet::FromRange(p.second); - auto relax_it = relax_map_->find(var); - if (relax_it != relax_map_->end()) { - // this is a bound for relaxed var - origin_map_.emplace(var, relax_it->second); - relax_it->second = arith::Intersect({relax_it->second, new_dom}); - } else { - // this is a bound for free var - auto hint_it = hint_map_->find(var); - if (hint_it != hint_map_->end()) { - origin_map_.emplace(var, hint_it->second); - hint_it->second = arith::Intersect({hint_it->second, new_dom}); - } else { - origin_map_.emplace(var, arith::IntSet::Nothing()); - hint_map_->insert(hint_it, {var, new_dom}); - } + auto it = dom_map_->find(var); + if (it != dom_map_->end()) { + origin_map_.emplace(var, it->second); + it->second = arith::Intersect({it->second, arith::IntSet::FromRange(p.second)}); } } } void ConditionalBoundsContext::ExitWithScope() { for (const auto& p : origin_map_) { - const auto* var = p.first; - auto relax_it = relax_map_->find(var); - if (relax_it != relax_map_->end()) { - // recover bound for relaxed var - relax_it->second = p.second; - } else { - // recover bound for free var - auto hint_it = hint_map_->find(var); - ICHECK(hint_it != hint_map_->end()); - if (p.second.IsNothing()) { - hint_map_->erase(hint_it); - } else { - hint_it->second = p.second; - } - } + (*dom_map_)[p.first] = p.second; } } diff --git a/src/tir/transforms/ir_utils.h b/src/tir/transforms/ir_utils.h index da52a82..7b1d34c 100644 --- 
a/src/tir/transforms/ir_utils.h +++ b/src/tir/transforms/ir_utils.h @@ -231,9 +231,9 @@ Bool IsFromLegacyTESchedule(PrimFunc f); *\brief Context helper to update domain map within conditional scope. * * Assume the condition is `0 <= i && i < 9` and global domain of i is [0, 20], thus `bounds[i]` is - * [0, 8]. Then `With<ConditionalBoundsContext> ctx(condition, &relax_map, &hint_map, true)` step - *into scope where dom_map[i] is [0, 8] and `With<ConditionalBoundsContext> ctx(condition, - *&relax_map, &hint_map, false)` step into scope where dom_map[i] is [9, 20] + *[0, 8]. Then `With<ConditionalBoundsContext> ctx(&dom_map, bounds, true)` step into scope where + *dom_map[i] is [0, 8] and `With<ConditionalBoundsContext> ctx(&dom_map, bounds, false)` step into + *scope where dom_map[i] is [9, 20] */ class ConditionalBoundsContext { private: @@ -241,13 +241,11 @@ class ConditionalBoundsContext { /*! * \brief Construct a condition bounds context. * \param condition The condition holds on true branch. - * \param relax_map The domain map for relaxed vars to update. - * \param hint_map The domain map for free vars to update. + * \param dom_map The global domain map to be updated. * \param is_true_branch Whether step into the branch where condition bounds holds. */ ConditionalBoundsContext(const PrimExpr& condition, - std::unordered_map<const VarNode*, arith::IntSet>* relax_map, - std::unordered_map<const VarNode*, arith::IntSet>* hint_map, + std::unordered_map<const VarNode*, arith::IntSet>* dom_map, bool is_true_branch); void EnterWithScope(); void ExitWithScope(); @@ -257,10 +255,8 @@ class ConditionalBoundsContext { /*! \brief the condition holds on true branch. */ const PrimExpr& condition_; - /*! \brief domain map for relaxed vars to update */ - std::unordered_map<const VarNode*, arith::IntSet>* relax_map_; - /*! \brief domain map for free vars to update */ - std::unordered_map<const VarNode*, arith::IntSet>* hint_map_; + /*! 
\brief global domain map to updated */ std::unordered_map<const VarNode*, arith::IntSet>* dom_map_; /*! \brief whether is on true branch */ bool is_true_branch_; /*! \brief used to record and restore original var bounds */ diff --git a/tests/python/unittest/test_tir_analysis_get_block_access_region.py b/tests/python/unittest/test_tir_analysis_get_block_access_region.py index 5403754..e508fbb 100644 --- a/tests/python/unittest/test_tir_analysis_get_block_access_region.py +++ b/tests/python/unittest/test_tir_analysis_get_block_access_region.py @@ -130,41 +130,6 @@ def access_in_branch_func() -> None: B[i] = A[i - 1] -@T.prim_func -def access_of_padding_pattern() -> None: - X = T.alloc_buffer([28, 28]) - X_pad = T.alloc_buffer([32, 32]) - Y = T.alloc_buffer([28, 28]) - for i, j in T.grid(32, 32): - with T.block("padding"): - vi, vj = T.axis.remap("SS", [i, j]) - T.reads( - [ - X[ - T.max(vi - 2, 0) : T.min(vi - 2, 27) + 1, - T.max(vj - 2, 0) : T.min(vj - 2, 27) + 1, - ] - ] - ) - T.writes([X_pad[vi, vj]]) - X_pad[vi, vj] = T.if_then_else( - 2 <= vi and vi < 30 and 2 <= vj and vj < 30, X[vi - 2, vj - 2], 0.0, dtype="float32" - ) - with T.block("padding_reverse"): - vi, vj = T.axis.remap("SS", [i, j]) - T.reads([X_pad[T.max(vi, 2) : T.min(vi, 29) + 1, T.max(vj, 2) : T.min(vj, 29) + 1]]) - T.writes( - [ - Y[ - T.max(vi - 2, 0) : T.min(vi - 2, 27) + 1, - T.max(vj - 2, 0) : T.min(vj - 2, 27) + 1, - ] - ] - ) - if 2 <= vi and vi < 30 and 2 <= vj and vj < 30: - Y[vi - 2, vj - 2] = X_pad[vi, vj] - - def test_block_access_region_detector(): block = func.body.block.body.block alloc_buffers = func.body.block.alloc_buffers @@ -255,36 +220,6 @@ def test_access_in_branch_func(): tvm.ir.assert_structural_equal(ret0[1], ret1[1]) -def test_access_of_padding_pattern(): - s = tvm.tir.schedule.Schedule(access_of_padding_pattern) - alloc_buffers = s.get_sref(s.get_block("root")).stmt.alloc_buffers - buffer_var_map = {buf.data: buf for buf in alloc_buffers} - def
do_compare_buffer_region(region, expect): - assert region.buffer == expect.buffer - analyzer = tvm.arith.Analyzer() - for k, rng in enumerate(region.region): - tvm.ir.assert_structural_equal( - analyzer.simplify(rng.min), analyzer.simplify(expect.region[k].min) - ) - tvm.ir.assert_structural_equal( - analyzer.simplify(rng.extent), analyzer.simplify(expect.region[k].extent) - ) - - def do_check_block(block_name): - block = s.get_sref(s.get_block(block_name)).stmt - expect_reads = block.reads - expect_writes = block.writes - ret = tir.analysis.get_block_access_region(block, buffer_var_map) - for i, read in enumerate(ret[0]): - do_compare_buffer_region(read, expect_reads[i]) - for i, write in enumerate(ret[1]): - do_compare_buffer_region(write, expect_writes[i]) - - do_check_block("padding") - do_check_block("padding_reverse") - - if __name__ == "__main__": test_block_access_region_detector() test_opaque_block() @@ -292,4 +227,3 @@ if __name__ == "__main__": test_match_buffer() test_access_in_if_then_else_func() test_access_in_branch_func() - test_access_of_padding_pattern() diff --git a/tests/python/unittest/test_tir_transform_compact_buffer_region.py b/tests/python/unittest/test_tir_transform_compact_buffer_region.py index 9b84485..57c87e5 100644 --- a/tests/python/unittest/test_tir_transform_compact_buffer_region.py +++ b/tests/python/unittest/test_tir_transform_compact_buffer_region.py @@ -24,7 +24,6 @@ def _check(original, transformed): mod = tvm.IRModule.from_expr(func) mod = tvm.tir.transform.CompactBufferAllocation()(mod) mod = tvm.tir.transform.Simplify()(mod) - transformed = tvm.tir.transform.Simplify()(tvm.IRModule.from_expr(transformed))["main"] tvm.ir.assert_structural_equal(mod["main"], transformed)
