This is an automated email from the ASF dual-hosted git repository.
csullivan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new fe1090e8aa [TIR] IndexMap Simplification Constraints (#11342)
fe1090e8aa is described below
commit fe1090e8aa6b6307f150f46ab968451765a6a079
Author: Eric Lunderberg <[email protected]>
AuthorDate: Wed May 18 11:38:55 2022 -0500
[TIR] IndexMap Simplification Constraints (#11342)
* [TIR] Added optional arith::Analyzer argument to IndexMap methods
Simplifications performed when applying a transformation may require
iteration bounds from the caller's scope. This is a C++-only feature:
`arith::Analyzer` does not inherit from `ObjectRef`, so it cannot be
passed through the FFI.
* [TIR] Pass analyzer from TransformLayoutRewriter to IndexMap
This avoids simplifying twice, now that IndexMap can accept the
analyzer from the calling scope.
* [TIR] Added BlockNode handling to IRMutatorWithAnalyzer
Binding the iteration variables defined in `BlockNode::iter_vars` to
their domains makes their bounds available for simplification. This
functionality was extracted from `TransformLayoutRewriter`.
---
include/tvm/tir/index_map.h | 22 ++++++++++--
src/arith/ir_mutator_with_analyzer.cc | 7 ++++
src/arith/ir_mutator_with_analyzer.h | 1 +
src/tir/ir/index_map.cc | 42 ++++++++++++++--------
.../schedule/primitive/layout_transformation.cc | 28 ++++++++-------
5 files changed, 70 insertions(+), 30 deletions(-)
diff --git a/include/tvm/tir/index_map.h b/include/tvm/tir/index_map.h
index b6faa67ab5..315bda2599 100644
--- a/include/tvm/tir/index_map.h
+++ b/include/tvm/tir/index_map.h
@@ -33,6 +33,12 @@
#include <utility>
+namespace tvm {
+namespace arith {
+class Analyzer;
+}
+} // namespace tvm
+
namespace tvm {
namespace tir {
@@ -78,10 +84,14 @@ class IndexMapNode : public Object {
* \param indices The indices in the input space. Should contain
* one value for each variable in `initial_indices`.
*
+ * \param analyzer An optional analyzer to be used to simplify the
+ * resulting expressions. If null, will use a fresh analyzer.
+ *
* \returns The indices in the output space. Contains one value for
* each expression in `final_indices`.
*/
- Array<PrimExpr> MapIndices(const Array<PrimExpr>& indices) const;
+ Array<PrimExpr> MapIndices(const Array<PrimExpr>& indices,
+ arith::Analyzer* analyzer = nullptr) const;
/*! \brief Map a memory range to the output space
*
@@ -93,20 +103,26 @@ class IndexMapNode : public Object {
* \param ranges The ranges in the input space. Should contain one
* value for each variable in `initial_indices`.
*
+ * \param analyzer An optional analyzer to be used to simplify the
+ * resulting expressions. If null, will use a fresh analyzer.
+ *
* \returns The ranges in the output space. Contains one value for
* each expression in `final_indices`.
*/
- Array<Range> MapRanges(const Array<Range>& ranges) const;
+ Array<Range> MapRanges(const Array<Range>& ranges, arith::Analyzer* analyzer
= nullptr) const;
/*! \brief Map a buffer shape to the output space
*
* \param shape The buffer shape in the input space. Should contain
* one value for each variable in `initial_indices`.
*
+ * \param analyzer An optional analyzer to be used to simplify the
+ * resulting expressions. If null, will use a fresh analyzer.
+ *
* \returns The buffer shape in the output space. Contains one
* value for each expression in `final_indices`.
*/
- Array<PrimExpr> MapShape(const Array<PrimExpr>& shape) const;
+ Array<PrimExpr> MapShape(const Array<PrimExpr>& shape, arith::Analyzer*
analyzer = nullptr) const;
/*!
* \brief Convert to string representation in Python.
diff --git a/src/arith/ir_mutator_with_analyzer.cc b/src/arith/ir_mutator_with_analyzer.cc
index 7bc0d946ad..9cae3b7a6a 100644
--- a/src/arith/ir_mutator_with_analyzer.cc
+++ b/src/arith/ir_mutator_with_analyzer.cc
@@ -35,6 +35,13 @@ Stmt IRMutatorWithAnalyzer::VisitStmt_(const ForNode* op) {
return StmtExprMutator::VisitStmt_(op);
}
+Stmt IRMutatorWithAnalyzer::VisitStmt_(const BlockNode* op) {
+ for (const auto& iter_var : op->iter_vars) {
+ analyzer_->Bind(iter_var->var, iter_var->dom);
+ }
+ return StmtExprMutator::VisitStmt_(op);
+}
+
Stmt IRMutatorWithAnalyzer::VisitStmt_(const LetStmtNode* op) {
PrimExpr value = this->VisitExpr(op->value);
if (SideEffect(value) <= CallEffectKind::kPure) {
diff --git a/src/arith/ir_mutator_with_analyzer.h b/src/arith/ir_mutator_with_analyzer.h
index 004265bbe5..3bd3a98a84 100644
--- a/src/arith/ir_mutator_with_analyzer.h
+++ b/src/arith/ir_mutator_with_analyzer.h
@@ -50,6 +50,7 @@ class IRMutatorWithAnalyzer : public tir::StmtExprMutator {
// override functions that need to populate the context information.
tir::Stmt VisitStmt_(const tir::ForNode* op) override;
+ tir::Stmt VisitStmt_(const tir::BlockNode* op) override;
tir::Stmt VisitStmt_(const tir::LetStmtNode* op) override;
tir::Stmt VisitStmt_(const tir::IfThenElseNode* op) override;
tir::Stmt VisitStmt_(const tir::AttrStmtNode* op) override;
diff --git a/src/tir/ir/index_map.cc b/src/tir/ir/index_map.cc
index 4c0a7d3508..77678d829a 100644
--- a/src/tir/ir/index_map.cc
+++ b/src/tir/ir/index_map.cc
@@ -159,24 +159,29 @@ IndexMap IndexMap::Inverse(Array<Range> initial_ranges) const {
return IndexMap(output_vars, inverse_exprs);
}
-Array<PrimExpr> IndexMapNode::MapIndices(const Array<PrimExpr>& indices) const
{
+Array<PrimExpr> IndexMapNode::MapIndices(const Array<PrimExpr>& indices,
+ arith::Analyzer* analyzer) const {
ICHECK_EQ(indices.size(), initial_indices.size());
- arith::Analyzer analyzer;
+ Map<Var, PrimExpr> vmap;
for (size_t i = 0; i < initial_indices.size(); i++) {
- analyzer.Bind(initial_indices[i], indices[i]);
+ vmap.Set(initial_indices[i], indices[i]);
}
- Array<PrimExpr> output;
- for (const auto& output_dim : final_indices) {
- output.push_back(analyzer.Simplify(output_dim));
+ arith::Analyzer local_analyzer;
+ if (!analyzer) {
+ analyzer = &local_analyzer;
}
+ Array<PrimExpr> output = final_indices;
+ output.MutateByApply(
+ [&](const PrimExpr& index) { return analyzer->Simplify(Substitute(index,
vmap)); });
+
return output;
}
-Array<Range> IndexMapNode::MapRanges(const Array<Range>& ranges) const {
+Array<Range> IndexMapNode::MapRanges(const Array<Range>& ranges,
arith::Analyzer* analyzer) const {
ICHECK_EQ(ranges.size(), initial_indices.size());
Map<Var, Range> input_iters;
@@ -189,25 +194,30 @@ Array<Range> IndexMapNode::MapRanges(const Array<Range>& ranges) const {
dom_map[initial_indices[i].get()] = arith::IntSet::FromRange(ranges[i]);
}
+ arith::Analyzer local_analyzer;
+ if (!analyzer) {
+ analyzer = &local_analyzer;
+ }
+
Array<Range> output;
- arith::Analyzer analyzer;
for (const auto& final_index : final_indices) {
auto int_set = arith::EvalSet(final_index, dom_map);
- output.push_back(Range::FromMinExtent(analyzer.Simplify(int_set.min()),
- analyzer.Simplify(int_set.max() -
int_set.min() + 1)));
+ output.push_back(Range::FromMinExtent(analyzer->Simplify(int_set.min()),
+ analyzer->Simplify(int_set.max() -
int_set.min() + 1)));
}
return output;
}
-Array<PrimExpr> IndexMapNode::MapShape(const Array<PrimExpr>& shape) const {
+Array<PrimExpr> IndexMapNode::MapShape(const Array<PrimExpr>& shape,
+ arith::Analyzer* analyzer) const {
ICHECK_EQ(shape.size(), initial_indices.size());
Array<Range> ranges;
for (auto& dim : shape) {
ranges.push_back(Range(0, dim));
}
- Array<Range> mapped = MapRanges(std::move(ranges));
+ Array<Range> mapped = MapRanges(std::move(ranges), analyzer);
Array<PrimExpr> output;
for (auto& range : mapped) {
@@ -265,8 +275,12 @@ TVM_REGISTER_GLOBAL("tir.IndexMap")
return IndexMap(initial_indices, final_indices);
});
-TVM_REGISTER_GLOBAL("tir.IndexMapMapIndices").set_body_method<IndexMap>(&IndexMapNode::MapIndices);
-TVM_REGISTER_GLOBAL("tir.IndexMapMapShape").set_body_method<IndexMap>(&IndexMapNode::MapShape);
+TVM_REGISTER_GLOBAL("tir.IndexMapMapIndices")
+ .set_body_typed([](IndexMap map, Array<PrimExpr> indices) { return
map->MapIndices(indices); });
+
+TVM_REGISTER_GLOBAL("tir.IndexMapMapShape").set_body_typed([](IndexMap map,
Array<PrimExpr> shape) {
+ return map->MapShape(shape);
+});
TVM_REGISTER_GLOBAL("tir.IndexMapInverse").set_body_method(&IndexMap::Inverse);
TVM_REGISTER_GLOBAL("tir.IndexMapNonSurjectiveInverse")
diff --git a/src/tir/schedule/primitive/layout_transformation.cc b/src/tir/schedule/primitive/layout_transformation.cc
index 87e09505f5..fb63b1b289 100644
--- a/src/tir/schedule/primitive/layout_transformation.cc
+++ b/src/tir/schedule/primitive/layout_transformation.cc
@@ -16,12 +16,13 @@
* specific language governing permissions and limitations
* under the License.
*/
+#include "../../../arith/ir_mutator_with_analyzer.h"
#include "../utils.h"
namespace tvm {
namespace tir {
-class TransformLayoutRewriter : private StmtExprMutator {
+class TransformLayoutRewriter : private arith::IRMutatorWithAnalyzer {
public:
/*!
* \brief Rewrite the access to the buffer after the transformation
@@ -36,27 +37,32 @@ class TransformLayoutRewriter : private StmtExprMutator {
const Buffer& old_buffer,
const Buffer& new_buffer,
const IndexMap& index_map)
{
- TransformLayoutRewriter rewriter(old_buffer, new_buffer, index_map);
+ arith::Analyzer analyzer;
+ TransformLayoutRewriter rewriter(old_buffer, new_buffer, index_map,
&analyzer);
Stmt result = rewriter(scope_stmt);
return {result, rewriter.block_sref_reuse_};
}
private:
TransformLayoutRewriter(const Buffer& old_buffer, const Buffer& new_buffer,
- const IndexMap& index_map)
- : old_buffer_(old_buffer),
+ const IndexMap& index_map, arith::Analyzer* analyzer)
+ : IRMutatorWithAnalyzer(analyzer),
+ old_buffer_(old_buffer),
new_buffer_(new_buffer),
index_map_(index_map),
buffer_data_to_buffer_{{new_buffer->data, new_buffer}} {}
void RewriteBufferAccess(Buffer* buffer, Array<PrimExpr>* indices) {
*buffer = new_buffer_;
- *indices = index_map_->MapIndices(*indices);
- (*indices).MutateByApply([this](const PrimExpr& index) { return
analyzer_.Simplify(index); });
+ *indices = index_map_->MapIndices(*indices, analyzer_);
}
+ using Parent = arith::IRMutatorWithAnalyzer;
+ using Parent::VisitExpr_;
+ using Parent::VisitStmt_;
+
PrimExpr VisitExpr_(const BufferLoadNode* op) final {
- BufferLoad buffer_load =
Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
+ BufferLoad buffer_load = Downcast<BufferLoad>(Parent::VisitExpr_(op));
if (buffer_load->buffer.same_as(old_buffer_)) {
auto* n = buffer_load.CopyOnWrite();
RewriteBufferAccess(&n->buffer, &n->indices);
@@ -65,7 +71,7 @@ class TransformLayoutRewriter : private StmtExprMutator {
}
Stmt VisitStmt_(const BufferStoreNode* op) final {
- BufferStore buffer_store =
Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
+ BufferStore buffer_store = Downcast<BufferStore>(Parent::VisitStmt_(op));
if (buffer_store->buffer.same_as(old_buffer_)) {
auto* n = buffer_store.CopyOnWrite();
RewriteBufferAccess(&n->buffer, &n->indices);
@@ -86,10 +92,7 @@ class TransformLayoutRewriter : private StmtExprMutator {
}
Stmt VisitStmt_(const BlockNode* op) final {
- for (const auto& iter_var : op->iter_vars) {
- analyzer_.Bind(iter_var->var, iter_var->dom);
- }
- Block block = Downcast<Block>(StmtExprMutator::VisitStmt_(op));
+ Block block = Downcast<Block>(Parent::VisitStmt_(op));
auto infered_access_regions = GetBlockReadWriteRegion(block,
buffer_data_to_buffer_);
auto* n = block.CopyOnWrite();
RewriteAccessRegion(&n->reads, infered_access_regions[0]);
@@ -101,7 +104,6 @@ class TransformLayoutRewriter : private StmtExprMutator {
const Buffer& old_buffer_;
const Buffer& new_buffer_;
const IndexMap& index_map_;
- arith::Analyzer analyzer_;
Map<Var, Buffer> buffer_data_to_buffer_;
Map<Block, Block> block_sref_reuse_;
};