This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 55b6be598a [TIR] Support affine expressions as indices in reverse compute inline (#11317)
55b6be598a is described below

commit 55b6be598a17f3ffc6845db834a1182d447c4c3c
Author: Wuwei Lin <[email protected]>
AuthorDate: Mon May 16 15:20:39 2022 -0700

    [TIR] Support affine expressions as indices in reverse compute inline (#11317)
    
    * [TIR] Support affine expressions as indices in reverse compute inline
    
    * fix trivial iterators
---
 src/tir/schedule/primitive/compute_inline.cc       | 174 ++++++++++++++-------
 .../unittest/test_tir_schedule_compute_inline.py   | 140 +++++++++++++++++
 2 files changed, 256 insertions(+), 58 deletions(-)

diff --git a/src/tir/schedule/primitive/compute_inline.cc b/src/tir/schedule/primitive/compute_inline.cc
index 630a72cede..452f72e722 100644
--- a/src/tir/schedule/primitive/compute_inline.cc
+++ b/src/tir/schedule/primitive/compute_inline.cc
@@ -24,12 +24,13 @@ namespace tir {
 static const char kErrBodyInline[] = R"(The body of the inlined block should be in form of
     'A[i, j, k, ...] = f(i, j, k, ...)',
 where the indices on the left are distinct atomic variables,
-and there should not no variables other than the index variables)";
+and there should be no variables other than the index variables)";
 
 static const char kErrBodyReverseInline[] = R"(The body of the inlined block should be in form of
-    `B[...] = g(i, j, k, A[i, j, k, ...] ...)`,
+    `B[...] = g(i, j, k, A[f(i, j, k, ...)] ...)`,
 where A is the only buffer the block consumes, whose indices are distinct atomic variables,
-and there should not no variables other than the index variables)";
+and there should be no variables other than the index variables), and f is a bijective affine
+mapping)";
 
 class HasInitBlock : public ScheduleError {
  public:
@@ -257,57 +258,6 @@ class BaseInliner : public StmtExprMutator {
     return std::move(tgt_block);
   }
 
-  /*!
-   * \brief Check if the indices are atomic distinct variables and the access is n-dimensional.
-   * If so, set `self->idx_vars_` properly.
-   * \param indices The indices to be extracted
-   * \param expected_ndim The expected ndim of the access
-   * \return A boolean flag indicating if the check is successful
-   */
-  bool UpdateAndCheckIndexVars(const Array<PrimExpr>& indices, int expected_ndim) {
-    int n = indices.size();
-    if (n != expected_ndim) {
-      // Failure: dimension mismatch
-      return false;
-    }
-    std::vector<const VarNode*> result;
-    result.reserve(n);
-    for (const PrimExpr& i : indices) {
-      if (const auto* var = i.as<VarNode>()) {
-        result.push_back(var);
-      } else {
-        // Failure: indexing expression is not a variable
-        return false;
-      }
-    }
-    using DistinctSet = std::unordered_set<const VarNode*>;
-    int n_distinct = DistinctSet(result.begin(), result.end()).size();
-    if (n != n_distinct) {
-      // Failure: indexing variables are not distinct
-      return false;
-    }
-    if (idx_vars_.empty()) {
-      idx_vars_ = std::move(result);
-    } else if (!support::ArrayWithSameContent(idx_vars_, result)) {
-      // Failure: indexing variables are not consitent in different BufferLoads
-      return false;
-    }
-    return true;
-  }
-
-  /*!
-   * \brief Set the mapping of index substitution `self->idx_sub_`
-   * \param indices The expressions that the corresponding index variables are replaced to
-   */
-  void SetIndexSubstitution(const Array<PrimExpr>& indices) {
-    ICHECK_EQ(indices.size(), idx_vars_.size());
-    int n = idx_vars_.size();
-    idx_sub_.reserve(n);
-    for (int i = 0; i < n; ++i) {
-      idx_sub_[idx_vars_[i]] = indices[i];
-    }
-  }
-
   /*!
    * \brief Count the number of undefined variables that are not used
    * as buffer objects.
@@ -490,6 +440,57 @@ class ComputeInliner : public BaseInliner {
     SetIndexSubstitution(load->indices);
     return Substitute(inlined_store_->value, idx_sub_);
   }
+
+  /*!
+   * \brief Check if the indices are atomic distinct variables and the access is n-dimensional.
+   * If so, set `self->idx_vars_` properly.
+   * \param indices The indices to be extracted
+   * \param expected_ndim The expected ndim of the access
+   * \return A boolean flag indicating if the check is successful
+   */
+  bool UpdateAndCheckIndexVars(const Array<PrimExpr>& indices, int expected_ndim) {
+    int n = indices.size();
+    if (n != expected_ndim) {
+      // Failure: dimension mismatch
+      return false;
+    }
+    std::vector<const VarNode*> result;
+    result.reserve(n);
+    for (const PrimExpr& i : indices) {
+      if (const auto* var = i.as<VarNode>()) {
+        result.push_back(var);
+      } else {
+        // Failure: indexing expression is not a variable
+        return false;
+      }
+    }
+    using DistinctSet = std::unordered_set<const VarNode*>;
+    int n_distinct = DistinctSet(result.begin(), result.end()).size();
+    if (n != n_distinct) {
+      // Failure: indexing variables are not distinct
+      return false;
+    }
+    if (idx_vars_.empty()) {
+      idx_vars_ = std::move(result);
+    } else if (!support::ArrayWithSameContent(idx_vars_, result)) {
+      // Failure: indexing variables are not consitent in different BufferLoads
+      return false;
+    }
+    return true;
+  }
+
+  /*!
+   * \brief Set the mapping of index substitution `self->idx_sub_`
+   * \param indices The expressions that the corresponding index variables are replaced to
+   */
+  void SetIndexSubstitution(const Array<PrimExpr>& indices) {
+    ICHECK_EQ(indices.size(), idx_vars_.size());
+    int n = idx_vars_.size();
+    idx_sub_.reserve(n);
+    for (int i = 0; i < n; ++i) {
+      idx_sub_[idx_vars_[i]] = indices[i];
+    }
+  }
 };
 
 /*!
@@ -534,13 +535,34 @@ class ReverseComputeInliner : public BaseInliner {
       // Failure: no BufferLoad from the `inlined_buffer_`
       return false;
     }
-    int n_vars = GetNumUndefinedNonpointerVars(GetRef<Stmt>(inlined_store_));
+
+    // Collect block iter domains and update the substition map
+    Map<Var, Range> consumer_iter_doms;
+    for (const auto& iter_var : consumer_block->iter_vars) {
+      consumer_iter_doms.Set(iter_var->var, iter_var->dom);
+      // Set default mapping for unit iters
+      if (is_const_int(iter_var->dom->extent, 1) && is_const_int(iter_var->dom->min)) {
+        idx_sub_[iter_var->var.get()] = iter_var->dom->min;
+      }
+    }
+
     for (const BufferLoadNode* load : loads) {
-      if (!UpdateAndCheckIndexVars(load->indices, n_vars)) {
-        // Failure: incorrect of inconsistent index vars
+      if (!UpdateAndCheckIndexExprs(load->indices)) {
         return false;
       }
     }
+
+    buffer_load_iter_map_ = arith::DetectIterMap(
+        /*indices=*/buffer_load_indices_,
+        /*input_iters=*/consumer_iter_doms,
+        /*predicate=*/true,
+        /*require_bijective=*/true,
+        /*analyzer=*/&analyzer,
+        /*simplify_trivial_iterators=*/false);
+    if (buffer_load_iter_map_.empty()) {
+      // Failure: indices of BufferLoad are not bijective affine
+      return false;
+    }
     return true;
   }
 
@@ -556,8 +578,20 @@ class ReverseComputeInliner : public BaseInliner {
     return ReplaceInlinedBuffer(std::move(store));
   }
 
+  /*!
+   * \brief Apply the inverse of `buffer_load_iter_map_` to producer indices. Update `idx_sub_` with
+   *        the result. It will be later used to transform the BufferStore indices of the producer.
+   * \param producer_indices The BufferStore indices of the producer.
+   */
+  void CreateInverseMapping(const Array<PrimExpr> producer_indices) {
+    auto inverse_iter_map = arith::InverseAffineIterMap(buffer_load_iter_map_, producer_indices);
+    for (const auto& pair : inverse_iter_map) {
+      idx_sub_[pair.first.get()] = pair.second;
+    }
+  }
+
   Stmt ReplaceInlinedBuffer(BufferStore producer) {
-    SetIndexSubstitution(producer->indices);
+    CreateInverseMapping(producer->indices);
     producer_rhs_ = producer->value;
     return Substituter(this)(GetRef<BufferStore>(inlined_store_));
   }
@@ -588,8 +622,32 @@ class ReverseComputeInliner : public BaseInliner {
     return std::move(extractor.result);
   }
 
+  /*!
+   * \brief Update `buffer_load_indices_` with the given indices. If `buffer_load_indices_` is
+   *        already non-empty, check it is consistent with the given indices.
+   * \param indices The indices
+   * \param expected_ndim The expected ndim of the access
+   * \return A boolean flag indicating if the check is successful
+   */
+  bool UpdateAndCheckIndexExprs(const Array<PrimExpr>& indices) {
+    if (buffer_load_indices_.empty()) {
+      buffer_load_indices_ = indices;
+    } else if (!std::equal(buffer_load_indices_.begin(), buffer_load_indices_.end(),
+                           indices.begin(), indices.end(), ExprDeepEqual())) {
+      // Failure: indices are not consistent in different BufferLoads
+      return false;
+    }
+    return true;
+  }
+
   /*! \brief The RHS value of the producer's BufferStore statement */
   PrimExpr producer_rhs_{nullptr};
+  /*! \brief The indices of the consumer's BufferLoad */
+  Array<PrimExpr> buffer_load_indices_;
+  /*! \brief The IterMap representing the indices of the consumer's BufferLoad */
+  Array<arith::IterSumExpr> buffer_load_iter_map_{nullptr};
+  /*! \brief The arithmetic analyzer */
+  arith::Analyzer analyzer;
 };
 
 void ComputeInlineImpl(ScheduleState self, const StmtSRef& producer_block_sref,
diff --git a/tests/python/unittest/test_tir_schedule_compute_inline.py b/tests/python/unittest/test_tir_schedule_compute_inline.py
index 8894cd4d9f..057d808ca4 100644
--- a/tests/python/unittest/test_tir_schedule_compute_inline.py
+++ b/tests/python/unittest/test_tir_schedule_compute_inline.py
@@ -169,6 +169,112 @@ def elementwise_multi_reverse_loads_inlined(a: T.handle, c: T.handle) -> None:
             C[vi, vj] = (A[vi, vj] * 2.0 + 1.0) * (A[vi, vj] * 2.0 * 2.0) + 3.0
 
 
+@T.prim_func
+def elementwise_reverse_affine_load(
+    A: T.Buffer[(128, 128), "float32"], C: T.Buffer[(8, 32, 8, 8), "float32"]
+) -> None:
+    B = T.alloc_buffer((128, 128))
+    for i, j in T.grid(128, 128):
+        with T.block("B"):
+            vi, vj = T.axis.remap("SS", [i, j])
+            B[vi, vj] = A[vi, vj] * 2.0
+    for i, j, k, l in T.grid(8, 32, 8, 8):
+        with T.block("C"):
+            vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l])
+            C[vi, vj, vk, vl] = B[
+                ((((vi * 32) + vj) * 8 + vk) * 8 + vl) // 128,
+                ((((vi * 32) + vj) * 8 + vk) * 8 + vl) % 128,
+            ]
+
+
+@T.prim_func
+def elementwise_reverse_affine_load_inlined(
+    A: T.Buffer[(128, 128), "float32"], C: T.Buffer[(8, 32, 8, 8), "float32"]
+) -> None:
+    for i, j in T.grid(128, 128):
+        with T.block("B"):
+            vi, vj = T.axis.remap("SS", [i, j])
+            C[
+                (vj + vi * 128) // 2048,
+                (vj + vi * 128) // 64 % 32,
+                ((vj + vi * 128) // 8) % 8,
+                (vj + vi * 128) % 8,
+            ] = (
+                A[vi, vj] * 2.0
+            )
+
+
+@T.prim_func
+def elementwise_reverse_affine_load_unit_iter(
+    A: T.Buffer[(128, 128), "float32"],
+    B: T.Buffer[(8, 16, 1), "float32"],
+    D: T.Buffer[(1, 8, 16, 128), "float32"],
+) -> None:
+    C = T.alloc_buffer((128, 128))
+    for i, j in T.grid(128, 128):
+        with T.block("B"):
+            vi, vj = T.axis.remap("SS", [i, j])
+            C[vi, vj] = A[vi, vj] * 2.0
+    for i, j, k in T.grid(8, 16, 128):
+        with T.block("C"):
+            vi, vj, vk = T.axis.remap("SSS", [i, j, k])
+            D[0, vi, vj, vk] = C[vi * 16 + vj, vk] + B[vi, vj, 0]
+
+
+@T.prim_func
+def elementwise_reverse_affine_load_unit_iter_inlined(
+    A: T.Buffer[(128, 128), "float32"],
+    B: T.Buffer[(8, 16, 1), "float32"],
+    D: T.Buffer[(1, 8, 16, 128), "float32"],
+) -> None:
+    for i, j in T.grid(128, 128):
+        with T.block("B"):
+            vi, vj = T.axis.remap("SS", [i, j])
+            D[0, vi // 16, vi % 16, vj] = A[vi, vj] * 2.0 + B[vi // 16, vi % 16, 0]
+
+
+@T.prim_func
+def elementwise_multi_reverse_affine_load(
+    A: T.Buffer[(128, 128), "float32"],
+    C: T.Buffer[(8, 16, 128), "float32"],
+) -> None:
+    B = T.alloc_buffer((128, 128))
+    for i, j in T.grid(128, 128):
+        with T.block("B"):
+            vi, vj = T.axis.remap("SS", [i, j])
+            B[vi, vj] = A[vi, vj] * 2.0
+    for i, j, k in T.grid(8, 16, 128):
+        with T.block("C"):
+            vi, vj, vk = T.axis.remap("SSS", [i, j, k])
+            C[vi, vj, vk] = B[vi * 16 + vj, vk] + B[vi * 16 + vj, vk]
+
+
+@T.prim_func
+def elementwise_multi_reverse_affine_load_inlined(
+    A: T.Buffer[(128, 128), "float32"],
+    C: T.Buffer[(8, 16, 128), "float32"],
+) -> None:
+    for i, j in T.grid(128, 128):
+        with T.block("B"):
+            vi, vj = T.axis.remap("SS", [i, j])
+            C[vi // 16, vi % 16, vj] = A[vi, vj] * 2.0 + A[vi, vj] * 2.0
+
+
+@T.prim_func
+def elementwise_reverse_non_affine_load(
+    A: T.Buffer[(128, 128), "float32"], C: T.Buffer[(8, 16, 128), "float32"]
+) -> None:
+    B = T.alloc_buffer((128, 128))
+    for i, j in T.grid(128, 128):
+        with T.block("B"):
+            vi, vj = T.axis.remap("SS", [i, j])
+            B[vi, vj] = A[vi, vj] * 2.0
+    for i, j, k in T.grid(8, 16, 128):
+        with T.block("C"):
+            vi, vj, vk = T.axis.remap("SSS", [i, j, k])
+            C[vi, vj, vk] = B[vi * 16 + vj, vi * 16 + vj]
+
+
 @T.prim_func
 def opaque_access_load(a: T.handle, c: T.handle) -> None:
     A = T.match_buffer(a, (128, 128))
@@ -520,6 +626,40 @@ def test_reverse_compute_multi_reverse_loads():
     verify_trace_roundtrip(sch=sch, mod=elementwise_multi_reverse_loads)
 
 
+def test_reverse_compute_inline_affine_load():
+    sch = tir.Schedule(elementwise_reverse_affine_load, debug_mask="all")
+    block_c = sch.get_block("C")
+    sch.reverse_compute_inline(block_c)
+    tvm.ir.assert_structural_equal(elementwise_reverse_affine_load_inlined, sch.mod["main"])
+    verify_trace_roundtrip(sch=sch, mod=elementwise_reverse_affine_load)
+
+
+def test_reverse_compute_inline_multi_affine_load():
+    sch = tir.Schedule(elementwise_multi_reverse_affine_load, debug_mask="all")
+    block_c = sch.get_block("C")
+    sch.reverse_compute_inline(block_c)
+    tvm.ir.assert_structural_equal(elementwise_multi_reverse_affine_load_inlined, sch.mod["main"])
+    verify_trace_roundtrip(sch=sch, mod=elementwise_multi_reverse_affine_load)
+
+
+def test_reverse_compute_inline_affine_load_unit_iter():
+    sch = tir.Schedule(elementwise_reverse_affine_load_unit_iter, debug_mask="all")
+    block_c = sch.get_block("C")
+    sch.reverse_compute_inline(block_c)
+    print(sch.mod.script())
+    tvm.ir.assert_structural_equal(
+        elementwise_reverse_affine_load_unit_iter_inlined, sch.mod["main"]
+    )
+    verify_trace_roundtrip(sch=sch, mod=elementwise_reverse_affine_load_unit_iter)
+
+
+def test_reverse_compute_fail_non_affine_load():
+    sch = tir.Schedule(elementwise_reverse_non_affine_load, debug_mask="all")
+    block_c = sch.get_block("C")
+    with pytest.raises(tvm.tir.ScheduleError):
+        sch.reverse_compute_inline(block_c)
+
+
 def test_reverse_compute_fail_multi_reverse_loads():
     sch = tir.Schedule(elementwise_multi_loads, debug_mask="all")
     block_c = sch.get_block("C")

Reply via email to