This is an automated email from the ASF dual-hosted git repository.

csullivan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new fe1090e8aa [TIR] IndexMap Simplification Constraints (#11342)
fe1090e8aa is described below

commit fe1090e8aa6b6307f150f46ab968451765a6a079
Author: Eric Lunderberg <[email protected]>
AuthorDate: Wed May 18 11:38:55 2022 -0500

    [TIR] IndexMap Simplification Constraints (#11342)
    
    * [TIR] Added optional arith::Analyzer argument to IndexMap methods
    
    Simplifications done when applying a transformation may require
    iteration bounds from the caller scope.  This is a C++-only feature,
    because `arith::Analyzer` doesn't inherit from `ObjectRef`, and cannot
    be passed through the FFI.
    
    * [TIR] Pass analyzer from TransformLayoutRewriter to IndexMap
    
    Avoid needing to simplify twice, now that IndexMap can accept the
    analyzer from the calling scope.
    
    * [TIR] Added BlockNode handling to IRMutatorWithAnalyzer
    
    Iteration variables defined in `BlockNode::iter_vars` may be useful
    for simplifications.  This functionality was extracted from
    `TransformLayoutRewriter`.
---
 include/tvm/tir/index_map.h                        | 22 ++++++++++--
 src/arith/ir_mutator_with_analyzer.cc              |  7 ++++
 src/arith/ir_mutator_with_analyzer.h               |  1 +
 src/tir/ir/index_map.cc                            | 42 ++++++++++++++--------
 .../schedule/primitive/layout_transformation.cc    | 28 ++++++++-------
 5 files changed, 70 insertions(+), 30 deletions(-)

diff --git a/include/tvm/tir/index_map.h b/include/tvm/tir/index_map.h
index b6faa67ab5..315bda2599 100644
--- a/include/tvm/tir/index_map.h
+++ b/include/tvm/tir/index_map.h
@@ -33,6 +33,12 @@
 
 #include <utility>
 
+namespace tvm {
+namespace arith {
+class Analyzer;
+}
+}  // namespace tvm
+
 namespace tvm {
 namespace tir {
 
@@ -78,10 +84,14 @@ class IndexMapNode : public Object {
    * \param indices The indices in the input space.  Should contain
    * one value for each variable in `initial_indices`.
    *
+   * \param analyzer An optional analyzer to be used to simplify the
+   * resulting expressions.  If null, will use a fresh analyzer.
+   *
    * \returns The indices in the output space.  Contains one value for
    * each expression in `final_indices`.
    */
-  Array<PrimExpr> MapIndices(const Array<PrimExpr>& indices) const;
+  Array<PrimExpr> MapIndices(const Array<PrimExpr>& indices,
+                             arith::Analyzer* analyzer = nullptr) const;
 
   /*! \brief Map a memory range to the output space
    *
@@ -93,20 +103,26 @@ class IndexMapNode : public Object {
    * \param ranges The ranges in the input space.  Should contain one
    * value for each variable in `initial_indices`.
    *
+   * \param analyzer An optional analyzer to be used to simplify the
+   * resulting expressions.  If null, will use a fresh analyzer.
+   *
    * \returns The ranges in the output space.  Contains one value for
    * each expression in `final_indices`.
    */
-  Array<Range> MapRanges(const Array<Range>& ranges) const;
+  Array<Range> MapRanges(const Array<Range>& ranges, arith::Analyzer* analyzer 
= nullptr) const;
 
   /*! \brief Map a buffer shape to the output space
    *
    * \param shape The buffer shape in the input space.  Should contain
    * one value for each variable in `initial_indices`.
    *
+   * \param analyzer An optional analyzer to be used to simplify the
+   * resulting expressions.  If null, will use a fresh analyzer.
+   *
    * \returns The buffer shape in the output space.  Contains one
    * value for each expression in `final_indices`.
    */
-  Array<PrimExpr> MapShape(const Array<PrimExpr>& shape) const;
+  Array<PrimExpr> MapShape(const Array<PrimExpr>& shape, arith::Analyzer* 
analyzer = nullptr) const;
 
   /*!
    * \brief Convert to string representation in Python.
diff --git a/src/arith/ir_mutator_with_analyzer.cc b/src/arith/ir_mutator_with_analyzer.cc
index 7bc0d946ad..9cae3b7a6a 100644
--- a/src/arith/ir_mutator_with_analyzer.cc
+++ b/src/arith/ir_mutator_with_analyzer.cc
@@ -35,6 +35,13 @@ Stmt IRMutatorWithAnalyzer::VisitStmt_(const ForNode* op) {
   return StmtExprMutator::VisitStmt_(op);
 }
 
+Stmt IRMutatorWithAnalyzer::VisitStmt_(const BlockNode* op) {
+  for (const auto& iter_var : op->iter_vars) {
+    analyzer_->Bind(iter_var->var, iter_var->dom);
+  }
+  return StmtExprMutator::VisitStmt_(op);
+}
+
 Stmt IRMutatorWithAnalyzer::VisitStmt_(const LetStmtNode* op) {
   PrimExpr value = this->VisitExpr(op->value);
   if (SideEffect(value) <= CallEffectKind::kPure) {
diff --git a/src/arith/ir_mutator_with_analyzer.h b/src/arith/ir_mutator_with_analyzer.h
index 004265bbe5..3bd3a98a84 100644
--- a/src/arith/ir_mutator_with_analyzer.h
+++ b/src/arith/ir_mutator_with_analyzer.h
@@ -50,6 +50,7 @@ class IRMutatorWithAnalyzer : public tir::StmtExprMutator {
 
   // override functions that need to populate the context information.
   tir::Stmt VisitStmt_(const tir::ForNode* op) override;
+  tir::Stmt VisitStmt_(const tir::BlockNode* op) override;
   tir::Stmt VisitStmt_(const tir::LetStmtNode* op) override;
   tir::Stmt VisitStmt_(const tir::IfThenElseNode* op) override;
   tir::Stmt VisitStmt_(const tir::AttrStmtNode* op) override;
diff --git a/src/tir/ir/index_map.cc b/src/tir/ir/index_map.cc
index 4c0a7d3508..77678d829a 100644
--- a/src/tir/ir/index_map.cc
+++ b/src/tir/ir/index_map.cc
@@ -159,24 +159,29 @@ IndexMap IndexMap::Inverse(Array<Range> initial_ranges) const {
   return IndexMap(output_vars, inverse_exprs);
 }
 
-Array<PrimExpr> IndexMapNode::MapIndices(const Array<PrimExpr>& indices) const 
{
+Array<PrimExpr> IndexMapNode::MapIndices(const Array<PrimExpr>& indices,
+                                         arith::Analyzer* analyzer) const {
   ICHECK_EQ(indices.size(), initial_indices.size());
 
-  arith::Analyzer analyzer;
+  Map<Var, PrimExpr> vmap;
 
   for (size_t i = 0; i < initial_indices.size(); i++) {
-    analyzer.Bind(initial_indices[i], indices[i]);
+    vmap.Set(initial_indices[i], indices[i]);
   }
 
-  Array<PrimExpr> output;
-  for (const auto& output_dim : final_indices) {
-    output.push_back(analyzer.Simplify(output_dim));
+  arith::Analyzer local_analyzer;
+  if (!analyzer) {
+    analyzer = &local_analyzer;
   }
 
+  Array<PrimExpr> output = final_indices;
+  output.MutateByApply(
+      [&](const PrimExpr& index) { return analyzer->Simplify(Substitute(index, 
vmap)); });
+
   return output;
 }
 
-Array<Range> IndexMapNode::MapRanges(const Array<Range>& ranges) const {
+Array<Range> IndexMapNode::MapRanges(const Array<Range>& ranges, 
arith::Analyzer* analyzer) const {
   ICHECK_EQ(ranges.size(), initial_indices.size());
 
   Map<Var, Range> input_iters;
@@ -189,25 +194,30 @@ Array<Range> IndexMapNode::MapRanges(const Array<Range>& ranges) const {
     dom_map[initial_indices[i].get()] = arith::IntSet::FromRange(ranges[i]);
   }
 
+  arith::Analyzer local_analyzer;
+  if (!analyzer) {
+    analyzer = &local_analyzer;
+  }
+
   Array<Range> output;
-  arith::Analyzer analyzer;
   for (const auto& final_index : final_indices) {
     auto int_set = arith::EvalSet(final_index, dom_map);
-    output.push_back(Range::FromMinExtent(analyzer.Simplify(int_set.min()),
-                                          analyzer.Simplify(int_set.max() - 
int_set.min() + 1)));
+    output.push_back(Range::FromMinExtent(analyzer->Simplify(int_set.min()),
+                                          analyzer->Simplify(int_set.max() - 
int_set.min() + 1)));
   }
 
   return output;
 }
 
-Array<PrimExpr> IndexMapNode::MapShape(const Array<PrimExpr>& shape) const {
+Array<PrimExpr> IndexMapNode::MapShape(const Array<PrimExpr>& shape,
+                                       arith::Analyzer* analyzer) const {
   ICHECK_EQ(shape.size(), initial_indices.size());
 
   Array<Range> ranges;
   for (auto& dim : shape) {
     ranges.push_back(Range(0, dim));
   }
-  Array<Range> mapped = MapRanges(std::move(ranges));
+  Array<Range> mapped = MapRanges(std::move(ranges), analyzer);
 
   Array<PrimExpr> output;
   for (auto& range : mapped) {
@@ -265,8 +275,12 @@ TVM_REGISTER_GLOBAL("tir.IndexMap")
       return IndexMap(initial_indices, final_indices);
     });
 
-TVM_REGISTER_GLOBAL("tir.IndexMapMapIndices").set_body_method<IndexMap>(&IndexMapNode::MapIndices);
-TVM_REGISTER_GLOBAL("tir.IndexMapMapShape").set_body_method<IndexMap>(&IndexMapNode::MapShape);
+TVM_REGISTER_GLOBAL("tir.IndexMapMapIndices")
+    .set_body_typed([](IndexMap map, Array<PrimExpr> indices) { return 
map->MapIndices(indices); });
+
+TVM_REGISTER_GLOBAL("tir.IndexMapMapShape").set_body_typed([](IndexMap map, 
Array<PrimExpr> shape) {
+  return map->MapShape(shape);
+});
 TVM_REGISTER_GLOBAL("tir.IndexMapInverse").set_body_method(&IndexMap::Inverse);
 
 TVM_REGISTER_GLOBAL("tir.IndexMapNonSurjectiveInverse")
diff --git a/src/tir/schedule/primitive/layout_transformation.cc b/src/tir/schedule/primitive/layout_transformation.cc
index 87e09505f5..fb63b1b289 100644
--- a/src/tir/schedule/primitive/layout_transformation.cc
+++ b/src/tir/schedule/primitive/layout_transformation.cc
@@ -16,12 +16,13 @@
  * specific language governing permissions and limitations
  * under the License.
  */
+#include "../../../arith/ir_mutator_with_analyzer.h"
 #include "../utils.h"
 
 namespace tvm {
 namespace tir {
 
-class TransformLayoutRewriter : private StmtExprMutator {
+class TransformLayoutRewriter : private arith::IRMutatorWithAnalyzer {
  public:
   /*!
    * \brief Rewrite the access to the buffer after the transformation
@@ -36,27 +37,32 @@ class TransformLayoutRewriter : private StmtExprMutator {
                                                     const Buffer& old_buffer,
                                                     const Buffer& new_buffer,
                                                     const IndexMap& index_map) 
{
-    TransformLayoutRewriter rewriter(old_buffer, new_buffer, index_map);
+    arith::Analyzer analyzer;
+    TransformLayoutRewriter rewriter(old_buffer, new_buffer, index_map, 
&analyzer);
     Stmt result = rewriter(scope_stmt);
     return {result, rewriter.block_sref_reuse_};
   }
 
  private:
   TransformLayoutRewriter(const Buffer& old_buffer, const Buffer& new_buffer,
-                          const IndexMap& index_map)
-      : old_buffer_(old_buffer),
+                          const IndexMap& index_map, arith::Analyzer* analyzer)
+      : IRMutatorWithAnalyzer(analyzer),
+        old_buffer_(old_buffer),
         new_buffer_(new_buffer),
         index_map_(index_map),
         buffer_data_to_buffer_{{new_buffer->data, new_buffer}} {}
 
   void RewriteBufferAccess(Buffer* buffer, Array<PrimExpr>* indices) {
     *buffer = new_buffer_;
-    *indices = index_map_->MapIndices(*indices);
-    (*indices).MutateByApply([this](const PrimExpr& index) { return 
analyzer_.Simplify(index); });
+    *indices = index_map_->MapIndices(*indices, analyzer_);
   }
 
+  using Parent = arith::IRMutatorWithAnalyzer;
+  using Parent::VisitExpr_;
+  using Parent::VisitStmt_;
+
   PrimExpr VisitExpr_(const BufferLoadNode* op) final {
-    BufferLoad buffer_load = 
Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
+    BufferLoad buffer_load = Downcast<BufferLoad>(Parent::VisitExpr_(op));
     if (buffer_load->buffer.same_as(old_buffer_)) {
       auto* n = buffer_load.CopyOnWrite();
       RewriteBufferAccess(&n->buffer, &n->indices);
@@ -65,7 +71,7 @@ class TransformLayoutRewriter : private StmtExprMutator {
   }
 
   Stmt VisitStmt_(const BufferStoreNode* op) final {
-    BufferStore buffer_store = 
Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
+    BufferStore buffer_store = Downcast<BufferStore>(Parent::VisitStmt_(op));
     if (buffer_store->buffer.same_as(old_buffer_)) {
       auto* n = buffer_store.CopyOnWrite();
       RewriteBufferAccess(&n->buffer, &n->indices);
@@ -86,10 +92,7 @@ class TransformLayoutRewriter : private StmtExprMutator {
   }
 
   Stmt VisitStmt_(const BlockNode* op) final {
-    for (const auto& iter_var : op->iter_vars) {
-      analyzer_.Bind(iter_var->var, iter_var->dom);
-    }
-    Block block = Downcast<Block>(StmtExprMutator::VisitStmt_(op));
+    Block block = Downcast<Block>(Parent::VisitStmt_(op));
     auto infered_access_regions = GetBlockReadWriteRegion(block, 
buffer_data_to_buffer_);
     auto* n = block.CopyOnWrite();
     RewriteAccessRegion(&n->reads, infered_access_regions[0]);
@@ -101,7 +104,6 @@ class TransformLayoutRewriter : private StmtExprMutator {
   const Buffer& old_buffer_;
   const Buffer& new_buffer_;
   const IndexMap& index_map_;
-  arith::Analyzer analyzer_;
   Map<Var, Buffer> buffer_data_to_buffer_;
   Map<Block, Block> block_sref_reuse_;
 };

Reply via email to