This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch unity-staging
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity-staging by this push:
new 525f5aee5f [Unity][TIR] Allow symbolic bounds in IndexMap analysis (#15262)
525f5aee5f is described below
commit 525f5aee5fa1b8989b3c7cb0284aa60d196d2460
Author: Junru Shao <[email protected]>
AuthorDate: Sun Jul 9 06:42:32 2023 -0700
[Unity][TIR] Allow symbolic bounds in IndexMap analysis (#15262)
Following #15264, this PR makes the corresponding changes on the Unity branch
to enable symbolic bounds in IndexMap analysis.
---
src/relax/analysis/layout_transformation.cc | 3 ++-
src/relax/op/tensor/manipulate.cc | 3 ++-
src/relax/transform/alter_op_impl.cc | 12 ++++++++----
3 files changed, 12 insertions(+), 6 deletions(-)
diff --git a/src/relax/analysis/layout_transformation.cc
b/src/relax/analysis/layout_transformation.cc
index 44538fea98..8348365761 100644
--- a/src/relax/analysis/layout_transformation.cc
+++ b/src/relax/analysis/layout_transformation.cc
@@ -22,6 +22,7 @@
* \brief Analyze the PrimFunc and suggest layout transformation on it's
blocks and buffers based on
* the user provided layout transformations on it's outputs.
*/
+#include <tvm/arith/analyzer.h>
#include <tvm/arith/iter_affine_map.h>
#include <tvm/relax/analysis.h>
#include <tvm/tir/analysis.h>
@@ -172,8 +173,8 @@ static bool AreIdenticalTransforms(const IndexMap& t0,
const IndexMap& t1) {
// Create a new shape expression.
Array<PrimExpr> t1_initial_indices =
t1->initial_indices.Map([](tir::Var i) -> PrimExpr { return i; });
- auto t0_output = t0->MapIndices(t1_initial_indices);
arith::Analyzer analyzer;
+ auto t0_output = t0->MapIndices(t1_initial_indices, &analyzer);
for (size_t i = 0; i < t0_output.size(); ++i) {
if (!analyzer.CanProveEqual(t0_output[i], t1->final_indices[i])) return
false;
}
diff --git a/src/relax/op/tensor/manipulate.cc
b/src/relax/op/tensor/manipulate.cc
index 5b298110be..a55d199822 100644
--- a/src/relax/op/tensor/manipulate.cc
+++ b/src/relax/op/tensor/manipulate.cc
@@ -474,7 +474,8 @@ StructInfo InferStructInfoLayoutTransform(const Call& call,
const BlockBuilder&
return TensorStructInfo(data_sinfo->dtype,
/*ndim=*/index_map->final_indices.size());
}
- Array<PrimExpr> output_shape =
index_map->MapShape(shape_sinfo->values.value());
+ arith::Analyzer analyzer;
+ Array<PrimExpr> output_shape =
index_map->MapShape(shape_sinfo->values.value(), &analyzer);
return TensorStructInfo(ShapeExpr(output_shape), data_sinfo->dtype);
}
diff --git a/src/relax/transform/alter_op_impl.cc
b/src/relax/transform/alter_op_impl.cc
index 1fadb86d71..f40ee3b3bf 100644
--- a/src/relax/transform/alter_op_impl.cc
+++ b/src/relax/transform/alter_op_impl.cc
@@ -23,6 +23,7 @@
* identify PrimFuncs to be replaced. Marks the new PrimFuncs with
kFrozenLayout attribute set to
* true.
*/
+#include <tvm/arith/analyzer.h>
#include <tvm/ir/attrs.h>
#include <tvm/node/serialization.h>
#include <tvm/relax/analysis.h>
@@ -60,9 +61,9 @@ static IndexMap DeepCopyIndexMap(const IndexMap& index_map) {
bool IsTransformBijective(const Expr& expr, const IndexMap& transform) {
Array<PrimExpr> input_shape = GetShapeFromTensor(expr);
Array<Range> initial_ranges = ConstructRangeFromShape(input_shape);
- auto [inverse, padding_predicate] =
transform.NonSurjectiveInverse(initial_ranges);
- (void)inverse; // to avoid unused variable warning;
arith::Analyzer analyzer;
+ auto [inverse, padding_predicate] =
transform.NonSurjectiveInverse(initial_ranges, &analyzer);
+ (void)inverse; // to avoid unused variable warning;
if (!analyzer.CanProve(!padding_predicate)) return false;
return true;
}
@@ -169,7 +170,9 @@ class AlterOpImplMutator : public ExprMutator {
const TensorStructInfo& old_tensor_sinfo) {
Array<PrimExpr> old_shape = GetShapeFromTensorStructInfo(old_tensor_sinfo);
Array<Range> initial_ranges = ConstructRangeFromShape(old_shape);
- auto [inverse_index_map, padding_predicate] =
index_map.NonSurjectiveInverse(initial_ranges);
+ arith::Analyzer analyzer;
+ auto [inverse_index_map, padding_predicate] =
+ index_map.NonSurjectiveInverse(initial_ranges, &analyzer);
ICHECK(tir::is_zero(padding_predicate))
<< "Only bijective transformations on input/output buffers are
supported, but found "
"padding predicate "
@@ -245,7 +248,8 @@ class AlterOpImplMutator : public ExprMutator {
/*! \brief Returns the TensorStructInfo after applying the \p transform on
its shape */
StructInfo UpdateStructInfo(const TensorStructInfo& tensor_sinfo, const
IndexMap& transform) {
auto shape = GetShapeFromTensorStructInfo(tensor_sinfo);
- auto new_shape = transform->MapShape(shape);
+ arith::Analyzer analyzer;
+ auto new_shape = transform->MapShape(shape, &analyzer);
return TensorStructInfo(ShapeExpr(new_shape), tensor_sinfo->dtype);
}