This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity-staging
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity-staging by this push:
     new 525f5aee5f [Unity][TIR] Allow symbolic bounds in IndexMap analysis (#15262)
525f5aee5f is described below

commit 525f5aee5fa1b8989b3c7cb0284aa60d196d2460
Author: Junru Shao <[email protected]>
AuthorDate: Sun Jul 9 06:42:32 2023 -0700

    [Unity][TIR] Allow symbolic bounds in IndexMap analysis (#15262)
    
    Following #15264, this PR makes changes accordingly to the Unity branch
    to enable symbolic bounds in IndexMap analysis.
---
 src/relax/analysis/layout_transformation.cc |  3 ++-
 src/relax/op/tensor/manipulate.cc           |  3 ++-
 src/relax/transform/alter_op_impl.cc        | 12 ++++++++----
 3 files changed, 12 insertions(+), 6 deletions(-)

diff --git a/src/relax/analysis/layout_transformation.cc b/src/relax/analysis/layout_transformation.cc
index 44538fea98..8348365761 100644
--- a/src/relax/analysis/layout_transformation.cc
+++ b/src/relax/analysis/layout_transformation.cc
@@ -22,6 +22,7 @@
  * \brief Analyze the PrimFunc and suggest layout transformation on it's blocks and buffers based on
  * the user provided layout transformations on it's outputs.
  */
+#include <tvm/arith/analyzer.h>
 #include <tvm/arith/iter_affine_map.h>
 #include <tvm/relax/analysis.h>
 #include <tvm/tir/analysis.h>
@@ -172,8 +173,8 @@ static bool AreIdenticalTransforms(const IndexMap& t0, const IndexMap& t1) {
   // Create a new shape expression.
   Array<PrimExpr> t1_initial_indices =
       t1->initial_indices.Map([](tir::Var i) -> PrimExpr { return i; });
-  auto t0_output = t0->MapIndices(t1_initial_indices);
   arith::Analyzer analyzer;
+  auto t0_output = t0->MapIndices(t1_initial_indices, &analyzer);
   for (size_t i = 0; i < t0_output.size(); ++i) {
     if (!analyzer.CanProveEqual(t0_output[i], t1->final_indices[i])) return false;
   }
diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc
index 5b298110be..a55d199822 100644
--- a/src/relax/op/tensor/manipulate.cc
+++ b/src/relax/op/tensor/manipulate.cc
@@ -474,7 +474,8 @@ StructInfo InferStructInfoLayoutTransform(const Call& call, const BlockBuilder&
     return TensorStructInfo(data_sinfo->dtype, /*ndim=*/index_map->final_indices.size());
   }
 
-  Array<PrimExpr> output_shape = index_map->MapShape(shape_sinfo->values.value());
+  arith::Analyzer analyzer;
+  Array<PrimExpr> output_shape = index_map->MapShape(shape_sinfo->values.value(), &analyzer);
   return TensorStructInfo(ShapeExpr(output_shape), data_sinfo->dtype);
 }
 
diff --git a/src/relax/transform/alter_op_impl.cc b/src/relax/transform/alter_op_impl.cc
index 1fadb86d71..f40ee3b3bf 100644
--- a/src/relax/transform/alter_op_impl.cc
+++ b/src/relax/transform/alter_op_impl.cc
@@ -23,6 +23,7 @@
  * identify PrimFuncs to be replaced. Marks the new PrimFuncs with kFrozenLayout attribute set to
  * true.
  */
+#include <tvm/arith/analyzer.h>
 #include <tvm/ir/attrs.h>
 #include <tvm/node/serialization.h>
 #include <tvm/relax/analysis.h>
@@ -60,9 +61,9 @@ static IndexMap DeepCopyIndexMap(const IndexMap& index_map) {
 bool IsTransformBijective(const Expr& expr, const IndexMap& transform) {
   Array<PrimExpr> input_shape = GetShapeFromTensor(expr);
   Array<Range> initial_ranges = ConstructRangeFromShape(input_shape);
-  auto [inverse, padding_predicate] = transform.NonSurjectiveInverse(initial_ranges);
-  (void)inverse;  // to avoid unused variable warning;
   arith::Analyzer analyzer;
+  auto [inverse, padding_predicate] = transform.NonSurjectiveInverse(initial_ranges, &analyzer);
+  (void)inverse;  // to avoid unused variable warning;
   if (!analyzer.CanProve(!padding_predicate)) return false;
   return true;
 }
@@ -169,7 +170,9 @@ class AlterOpImplMutator : public ExprMutator {
                               const TensorStructInfo& old_tensor_sinfo) {
     Array<PrimExpr> old_shape = GetShapeFromTensorStructInfo(old_tensor_sinfo);
     Array<Range> initial_ranges = ConstructRangeFromShape(old_shape);
-    auto [inverse_index_map, padding_predicate] = index_map.NonSurjectiveInverse(initial_ranges);
+    arith::Analyzer analyzer;
+    auto [inverse_index_map, padding_predicate] =
+        index_map.NonSurjectiveInverse(initial_ranges, &analyzer);
     ICHECK(tir::is_zero(padding_predicate))
         << "Only bijective transformations on input/output buffers are supported, but found "
            "padding predicate "
@@ -245,7 +248,8 @@ class AlterOpImplMutator : public ExprMutator {
   /*! \brief Returns the TensorStructInfo after applying the \p transform on its shape */
   StructInfo UpdateStructInfo(const TensorStructInfo& tensor_sinfo, const IndexMap& transform) {
     auto shape = GetShapeFromTensorStructInfo(tensor_sinfo);
-    auto new_shape = transform->MapShape(shape);
+    arith::Analyzer analyzer;
+    auto new_shape = transform->MapShape(shape, &analyzer);
     return TensorStructInfo(ShapeExpr(new_shape), tensor_sinfo->dtype);
   }
 

Reply via email to