This is an automated email from the ASF dual-hosted git repository.

wrongtest pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 946581ab56 [TIR][Compute-at] Utilize InverseAffineIterMap for dom estimation (#14184)
946581ab56 is described below

commit 946581ab56ab0a74aab30338f537ee4dca20aad4
Author: wrongtest <[email protected]>
AuthorDate: Thu Mar 16 16:42:29 2023 +0800

    [TIR][Compute-at] Utilize InverseAffineIterMap for dom estimation (#14184)
    
    Utilize the inverse iter map tool for compute_at iter region estimation.
---
 src/arith/iter_affine_map.cc                       |   3 +-
 src/tir/schedule/primitive/compute_at.cc           | 179 +++++++++++++++------
 .../unittest/test_tir_schedule_compute_at.py       |  74 +++++++++
 3 files changed, 205 insertions(+), 51 deletions(-)

diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc
index af6e47b7a0..05af5b4070 100644
--- a/src/arith/iter_affine_map.cc
+++ b/src/arith/iter_affine_map.cc
@@ -2147,7 +2147,8 @@ class InverseAffineIterMapTransformer {
     // Case 1: Propagate to the input node directly when the sum expression 
has only one components
     if (iter_map_expr->args.size() == 1) {
       const auto& source = iter_map_expr->args[0];
-      backprop_.Set(source, backprop_.at(source) + input);
+      ICHECK(analyzer_->CanProveEqual(abs(source->scale), 1));
+      backprop_.Set(source, (backprop_.at(source) + input) * source->scale);
       return;
     }
 
diff --git a/src/tir/schedule/primitive/compute_at.cc 
b/src/tir/schedule/primitive/compute_at.cc
index 5a5a536157..988c73c3f0 100644
--- a/src/tir/schedule/primitive/compute_at.cc
+++ b/src/tir/schedule/primitive/compute_at.cc
@@ -391,81 +391,163 @@ void RelaxBufferRegions(const Map<Var, PrimExpr>& 
binding,
  * domain
  * \param provided The provided integer set to cover the required domain
  * \param required The required domain to be covered
+ * \param dim_max The maximum valid index along this dimension, bounded by the buffer shape
  * \param analyzer The arithmetic analyzer
  */
-std::pair<Var, arith::IntSet> SolveBlockVarDomain(const arith::IntSet& 
provided,
-                                                  const arith::IntSet& 
required,
-                                                  arith::Analyzer* analyzer) {
+std::pair<Var, BlockVarDomainInfo> SolveBlockVarDomain(const arith::IntSet& 
provided,
+                                                       const arith::IntSet& 
required,
+                                                       PrimExpr dim_max,
+                                                       arith::Analyzer* 
analyzer) {
   PrimExpr provided_min = analyzer->Simplify(provided.min());
   PrimExpr provided_max = analyzer->Simplify(provided.max());
   PrimExpr required_min = analyzer->Simplify(required.min());
   PrimExpr required_max = analyzer->Simplify(required.max());
-  PrimExpr dom_min{nullptr}, dom_max{nullptr};
-  Var dom_var{ObjectPtr<VarNode>{nullptr}};
+  arith::IntSet var_dom, var_bound;
+  Optional<Var> var;
   arith::PVar<Var> p_v;
   arith::PVar<PrimExpr> p_e;
   if ((p_v * p_e).Match(provided_min) || (p_e * p_v).Match(provided_min)) {
     PrimExpr e = p_e.Eval();
-    dom_var = p_v.Eval();
-    dom_min = floordiv(required_min, e);
-    dom_max = floordiv(required_max, e);
+    var = p_v.Eval();
+    var_dom = arith::IntSet::Interval(floordiv(required_min, e), 
floordiv(required_max, e));
+    var_bound = arith::IntSet::Interval(0, floordiv(dim_max, e));
   } else if (analyzer->CanProveEqual(provided_min, provided_max)) {
     if (p_v.Match(provided_min)) {
-      dom_var = p_v.Eval();
-      dom_min = required_min;
-      dom_max = required_max;
+      var = p_v.Eval();
+      var_dom = arith::IntSet::Interval(required_min, required_max);
+      var_bound = arith::IntSet::Interval(0, dim_max);
     } else {
       arith::PVar<PrimExpr> p_f;
       if ((floordiv(p_v, p_f)).Match(provided_min)) {
         // a <= (x // factor) <= b, fac > 0 ==> (a * fac) <= x <= (b * fac + 
fac - 1)
         PrimExpr fac = p_f.Eval();
         if (analyzer->CanProveGreaterEqual(fac, 1)) {
-          dom_var = p_v.Eval();
-          dom_min = required_min * fac;
-          dom_max = analyzer->Simplify(required_max * fac + fac - 1);
+          var = p_v.Eval();
+          var_dom = arith::IntSet::Interval(required_min * fac,
+                                            analyzer->Simplify(required_max * 
fac + fac - 1));
+          var_bound = arith::IntSet::Interval(0, analyzer->Simplify(dim_max * 
fac + fac - 1));
         }
       } else if ((floormod(p_v, p_f).Match(provided_min))) {
         // generally domain of (x % fac) enforce no constraints to domain of x
-        dom_var = p_v.Eval();
-        return std::make_pair(dom_var, arith::IntSet::Nothing());
+        return {p_v.Eval(), BlockVarDomainInfo()};
       }
     }
   }
-  ICHECK(dom_var.defined()) << "ValueError: BufferRegion pattern match failed: 
" << provided_min;
-  return std::make_pair(dom_var, arith::IntSet::Interval(dom_min, dom_max));
+  ICHECK(var.defined()) << "ValueError: BufferRegion pattern match failed: " 
<< provided_min;
+  return {var.value(), BlockVarDomainInfo{var_dom, var_bound}};
 }
 
 /*!
- * \brief Calculate and update the iteration domain info to fully cover the 
required domain
- * \param provided The provided integer set to cover the required domain
- * \param required The required domain to be covered
- * \param required_bound The additional region bound of the required domain to 
be covered
+ * \brief Calculate and update the iteration domain info to fully cover the 
required domain in
+ * dimension-wise fashion. The region relation on each buffer dimension is 
independently estimated.
+ * \param buffer The accessed buffer
+ * \param provided_region The provided NDIntSet to cover the required domain
+ * \param required_region The required NDIntSet domain to be covered
+ * \param analyzer The arithmetic analyzer
  * \param iter_doms The result iteration domains to be updated
+ */
+void UpdateBlockVarDomainDimwise(
+    const BufferNode* buffer, const NDIntSet& provided_region, const NDIntSet& 
required_region,
+    arith::Analyzer* analyzer, std::unordered_map<const VarNode*, 
BlockVarDomainInfo>* iter_doms) {
+  size_t ndim = buffer->shape.size();
+  for (size_t i = 0; i < ndim; ++i) {
+    arith::IntSet provided = provided_region[i];
+    arith::IntSet required = required_region[i];
+    PrimExpr dim_max = max(buffer->shape[i] - 1, 0);
+
+    if (provided.IsSinglePoint() && is_const_int(provided.min())) {
+      ICHECK(required.IsSinglePoint() && 
analyzer->CanProveEqual(provided.min(), required.min()));
+      continue;
+    }
+
+    auto [var, dom_info] = SolveBlockVarDomain(provided, required, dim_max, 
analyzer);
+    auto it = iter_doms->find(var.get());
+    if (it != iter_doms->end()) {
+      it->second.Union(dom_info);
+    } else {
+      ICHECK(analyzer->CanProveEqual(provided.min(), required.min()));
+      ICHECK(analyzer->CanProveEqual(provided.max(), required.max()));
+    }
+  }
+}
+
+/*! \brief IntSet version of `arith::InverseAffineIterMap`, inverting the min/max corner points. */
+Map<Var, arith::IntSet> InverseAffineIterMap(const Array<arith::IterSumExpr>& iter_map,
+                                             const NDIntSet& outputs, arith::Analyzer* analyzer) {
+  Array<PrimExpr> min_point, max_point;
+  min_point.reserve(outputs.size());
+  max_point.reserve(outputs.size());
+  for (const auto& intset : outputs) {
+    ICHECK(intset.HasLowerBound() && intset.HasUpperBound());
+    min_point.push_back(intset.min());
+    max_point.push_back(intset.max());
+  }
+  auto rev_min = arith::InverseAffineIterMap(iter_map, min_point);
+  auto rev_max = arith::InverseAffineIterMap(iter_map, max_point);
+  Map<Var, arith::IntSet> dom_map;
+  for (const auto& kv : rev_min) {
+    const Var& var = kv.first;
+    auto it = rev_max.find(var);
+    ICHECK(it != rev_max.end());  // InverseAffineIterMap's result vars are assumed stable
+    const PrimExpr& rev_min_point = kv.second;
+    const PrimExpr& rev_max_point = (*it).second;
+    dom_map.Set(var,
+                arith::IntSet::Interval(analyzer->Simplify(min(rev_min_point, rev_max_point)),
+                                        analyzer->Simplify(max(rev_min_point, rev_max_point))));
+  }
+  return dom_map;
+}
+
+/*!
+ * \brief Calculate and update the iteration domain info to fully cover the 
required domain
+ * with affine analysis. It requires a bijective mapping from block vars to provided region points.
+ * \param buffer The accessed buffer
+ * \param iter_vars The list of block vars to cover the required region
+ * \param provided_region The provided NDIntSet to cover the required domain
+ * \param required_region The required NDIntSet domain to be covered
  * \param analyzer The arithmetic analyzer
+ * \param iter_doms The result iteration domains to be updated
+ * \returns Whether the update succeeded.
  */
-void UpdateBlockVarDomain(const arith::IntSet& provided, const arith::IntSet& 
required,
-                          const arith::IntSet& required_bound,
-                          std::unordered_map<const VarNode*, 
BlockVarDomainInfo>* iter_doms,
-                          arith::Analyzer* analyzer) {
-  if (provided.IsSinglePoint() && is_const_int(provided.min())) {
-    ICHECK(required.IsSinglePoint() && analyzer->CanProveEqual(provided.min(), 
required.min()));
-    ICHECK(required_bound.IsSinglePoint() &&
-           analyzer->CanProveEqual(provided.min(), required_bound.min()));
-    return;
+bool UpdateBlockVarDomainAffine(const BufferNode* buffer, const 
Array<IterVar>& iter_vars,
+                                const NDIntSet& provided_region, const 
NDIntSet& required_region,
+                                arith::Analyzer* analyzer,
+                                std::unordered_map<const VarNode*, 
BlockVarDomainInfo>* iter_doms) {
+  // Only single-point provided regions are supported for now, which covers most cases
+  for (const auto& intset : provided_region) {
+    if (!intset.IsSinglePoint()) return false;
+  }
+  // calculate forward mapping (block vars -> provided region point)
+  Map<Var, Range> dom_map;
+  for (const IterVar& iter_var : iter_vars) {
+    dom_map.Set(iter_var->var, iter_var->dom);
   }
-  auto var_with_dom = SolveBlockVarDomain(provided, required, analyzer);
-  auto var_with_bound = SolveBlockVarDomain(provided, required_bound, 
analyzer);
-  const Var& var = var_with_dom.first;
-  const auto& var_dom = var_with_dom.second;
-  const auto& var_bound = var_with_bound.second;
-  ICHECK(var.same_as(var_with_bound.first));
-  auto it = iter_doms->find(var.get());
-  if (it != iter_doms->end()) {
-    it->second.Union({var_dom, var_bound});
-  } else {
-    ICHECK(analyzer->CanProveEqual(provided.min(), required.min()));
-    ICHECK(analyzer->CanProveEqual(provided.max(), required.max()));
+  size_t ndim = buffer->shape.size();
+  Array<PrimExpr> provide_indices;
+  provide_indices.reserve(ndim);
+  for (size_t i = 0; i < ndim; ++i) {
+    provide_indices.push_back(provided_region[i].min());
+  }
+  auto res = arith::DetectIterMap(provide_indices, dom_map, const_true(),
+                                  arith::IterMapLevel::Bijective, analyzer, 
false);
+  if (res->indices.empty()) {
+    return false;
   }
+  // calculate backward mapping (required region point -> block vars)
+  NDIntSet required_bound;
+  for (size_t i = 0; i < ndim; ++i) {
+    required_bound.push_back(
+        arith::IntSet::Interval(make_zero(buffer->shape[i]->dtype), 
max(buffer->shape[i] - 1, 0)));
+  }
+  Map<Var, arith::IntSet> var_dom = InverseAffineIterMap(res->indices, 
required_region, analyzer);
+  Map<Var, arith::IntSet> var_bound = InverseAffineIterMap(res->indices, 
required_bound, analyzer);
+  for (const auto& kv : var_dom) {
+    const Var& var = kv.first;
+    auto it = var_bound.find(var);
+    ICHECK(it != var_bound.end());  // InverseAffineIterMap's result vars are 
assumed stable
+    (*iter_doms)[var.get()].Union(BlockVarDomainInfo{kv.second, (*it).second});
+  }
+  return true;
 }
 
 /*!
@@ -501,13 +583,10 @@ std::vector<BlockVarDomainInfo> CalculateBlockVarDomain(
     NDIntSet provided_region = support::NDIntSetUnion(many_provided_regions);
     ICHECK_EQ(provided_region.size(), buffer->shape.size());
     ICHECK_EQ(required_region.size(), buffer->shape.size());
-    // For each dimension, update the iteration domain
-    int ndim = buffer->shape.size();
-    for (int i = 0; i < ndim; ++i) {
-      arith::IntSet provided = provided_region[i];
-      arith::IntSet required = required_region[i];
-      arith::IntSet required_bound = arith::IntSet::FromMinExtent(Integer(0), 
buffer->shape[i]);
-      UpdateBlockVarDomain(provided, required, required_bound, &iter_doms, 
analyzer);
+    // Try to update iter var domains with the current required and provided region pair.
+    if (!UpdateBlockVarDomainAffine(buffer, iter_vars, provided_region, 
required_region, analyzer,
+                                    &iter_doms)) {
+      UpdateBlockVarDomainDimwise(buffer, provided_region, required_region, 
analyzer, &iter_doms);
     }
   }
   // Union the iter var domains, put them in the same order of block vars, and 
return
diff --git a/tests/python/unittest/test_tir_schedule_compute_at.py 
b/tests/python/unittest/test_tir_schedule_compute_at.py
index f94347409a..364a43acda 100644
--- a/tests/python/unittest/test_tir_schedule_compute_at.py
+++ b/tests/python/unittest/test_tir_schedule_compute_at.py
@@ -1174,6 +1174,40 @@ def test_compute_at_tiled_repeat_op(use_block_name):
     verify_trace_roundtrip(sch=sch, mod=tiled_repeat_op)
 
 
+def test_compute_at_rev_iter():
+    @T.prim_func
+    def before(X: T.Buffer[(10, 10), "float32"], Z: T.Buffer[(10, 10), 
"float32"]):
+        Y = T.alloc_buffer([10, 10], "float32")
+        for i, j in T.grid(10, 10):
+            with T.block("b0"):
+                vi, vj = T.axis.remap("SS", [i, j])
+                Y[9 - vi, 9 - vj] = X[vi, vj] + 1.0
+        for i, j in T.grid(10, 10):
+            with T.block("b1"):
+                vi, vj = T.axis.remap("SS", [i, j])
+                Z[vi, vj] = Y[vj, vi] + 2.0
+
+    @T.prim_func
+    def after(X: T.Buffer[(10, 10), "float32"], Z: T.Buffer[(10, 10), 
"float32"]):
+        Y = T.alloc_buffer([10, 10], "float32")
+        for i in range(10):
+            for j in range(10):
+                with T.block("b0"):
+                    vi = T.axis.spatial(10, j)
+                    vj = T.axis.spatial(10, 9 - i)
+                    Y[9 - vi, 9 - vj] = X[vi, vj] + 1.0
+            for j in range(10):
+                with T.block("b1"):
+                    vi, vj = T.axis.remap("SS", [i, j])
+                    Z[vi, vj] = Y[vj, vi] + 2.0
+
+    sch = tir.Schedule(before, debug_mask="all")
+    axis = sch.get_loops(sch.get_block("b1"))[0]
+    sch.compute_at(sch.get_block("b0"), axis)
+    tvm.ir.assert_structural_equal(after, sch.mod["main"])
+    verify_trace_roundtrip(sch=sch, mod=before)
+
+
 def test_reverse_compute_at_tiled(use_block_name):
     sch = tir.Schedule(tiled, debug_mask="all")
     block = sch.get_block("C")
@@ -1557,5 +1591,45 @@ def test_reverse_compute_at_with_unit_loop():
     tvm.ir.assert_structural_equal(main_reverse_compute_at, sch.mod["main"])
 
 
+def test_reverse_compute_at_layout_trans():
+    @T.prim_func
+    def before(A: T.Buffer((1, 3, 5, 5, 16), "float32"), C: T.Buffer((1, 6, 5, 
5, 8), "float32")):
+        B = T.alloc_buffer((1, 3, 5, 5, 16))
+        for i0, i1, i2, i3, i4 in T.grid(1, 3, 5, 5, 16):
+            with T.block("compute"):
+                v_i0, v_i1, v_i2, v_i3, v_i4 = T.axis.remap("SSSSS", [i0, i1, 
i2, i3, i4])
+                B[v_i0, v_i1, v_i2, v_i3, v_i4] = A[v_i0, v_i1, v_i2, v_i3, 
v_i4] + T.float32(1)
+        for ax0, ax1, ax2, ax3, ax4 in T.grid(1, 6, 5, 5, 8):
+            with T.block("T_layout_trans"):
+                v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap("SSSSS", 
[ax0, ax1, ax2, ax3, ax4])
+                C[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = B[
+                    v_ax0, (v_ax1 * 8 + v_ax4) // 16, v_ax2, v_ax3, (v_ax1 * 8 
+ v_ax4) % 16
+                ]
+
+    @T.prim_func
+    def after(A: T.Buffer((1, 3, 5, 5, 16), "float32"), C: T.Buffer((1, 6, 5, 
5, 8), "float32")):
+        B = T.alloc_buffer((1, 3, 5, 5, 16))
+        for i0, i1 in T.grid(1, 3):
+            for i2, i3, i4 in T.grid(5, 5, 16):
+                with T.block("compute"):
+                    v_i0, v_i1, v_i2, v_i3, v_i4 = T.axis.remap("SSSSS", [i0, 
i1, i2, i3, i4])
+                    B[v_i0, v_i1, v_i2, v_i3, v_i4] = A[v_i0, v_i1, v_i2, 
v_i3, v_i4] + T.float32(1)
+            for ax0, ax1, ax2, ax3 in T.grid(2, 5, 5, 8):
+                with T.block("T_layout_trans"):
+                    v_ax0 = T.axis.spatial(1, 0)
+                    v_ax1 = T.axis.spatial(6, i1 * 2 + ax0)
+                    v_ax2, v_ax3, v_ax4 = T.axis.remap("SSS", [ax1, ax2, ax3])
+                    C[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = B[
+                        v_ax0, (v_ax1 * 8 + v_ax4) // 16, v_ax2, v_ax3, (v_ax1 
* 8 + v_ax4) % 16
+                    ]
+
+    sch = tir.Schedule(before, debug_mask="all")
+    trans = sch.get_block("T_layout_trans")
+    axis = sch.get_loops("compute")[1]
+    sch.reverse_compute_at(trans, axis)
+    tvm.ir.assert_structural_equal(after, sch.mod["main"])
+    verify_trace_roundtrip(sch=sch, mod=before)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to