This is an automated email from the ASF dual-hosted git repository.
wrongtest pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 946581ab56 [TIR][Compute-at] Utilize InverseAffineIterMap for dom estimation (#14184)
946581ab56 is described below
commit 946581ab56ab0a74aab30338f537ee4dca20aad4
Author: wrongtest <[email protected]>
AuthorDate: Thu Mar 16 16:42:29 2023 +0800
[TIR][Compute-at] Utilize InverseAffineIterMap for dom estimation (#14184)
Utilize the inverse affine iter-map tool (InverseAffineIterMap) to estimate block iteration regions in compute_at.
---
src/arith/iter_affine_map.cc | 3 +-
src/tir/schedule/primitive/compute_at.cc | 179 +++++++++++++++------
.../unittest/test_tir_schedule_compute_at.py | 74 +++++++++
3 files changed, 205 insertions(+), 51 deletions(-)
diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc
index af6e47b7a0..05af5b4070 100644
--- a/src/arith/iter_affine_map.cc
+++ b/src/arith/iter_affine_map.cc
@@ -2147,7 +2147,8 @@ class InverseAffineIterMapTransformer {
// Case 1: Propagate to the input node directly when the sum expression
has only one components
if (iter_map_expr->args.size() == 1) {
const auto& source = iter_map_expr->args[0];
- backprop_.Set(source, backprop_.at(source) + input);
+ ICHECK(analyzer_->CanProveEqual(abs(source->scale), 1));
+ backprop_.Set(source, (backprop_.at(source) + input) * source->scale);
return;
}
diff --git a/src/tir/schedule/primitive/compute_at.cc b/src/tir/schedule/primitive/compute_at.cc
index 5a5a536157..988c73c3f0 100644
--- a/src/tir/schedule/primitive/compute_at.cc
+++ b/src/tir/schedule/primitive/compute_at.cc
@@ -391,81 +391,163 @@ void RelaxBufferRegions(const Map<Var, PrimExpr>& binding,
* domain
* \param provided The provided integer set to cover the required domain
* \param required The required domain to be covered
+ * \param dim_max The maximum index bound by the buffer shape
* \param analyzer The arithmetic analyzer
*/
-std::pair<Var, arith::IntSet> SolveBlockVarDomain(const arith::IntSet&
provided,
- const arith::IntSet&
required,
- arith::Analyzer* analyzer) {
+std::pair<Var, BlockVarDomainInfo> SolveBlockVarDomain(const arith::IntSet&
provided,
+ const arith::IntSet&
required,
+ PrimExpr dim_max,
+ arith::Analyzer*
analyzer) {
PrimExpr provided_min = analyzer->Simplify(provided.min());
PrimExpr provided_max = analyzer->Simplify(provided.max());
PrimExpr required_min = analyzer->Simplify(required.min());
PrimExpr required_max = analyzer->Simplify(required.max());
- PrimExpr dom_min{nullptr}, dom_max{nullptr};
- Var dom_var{ObjectPtr<VarNode>{nullptr}};
+ arith::IntSet var_dom, var_bound;
+ Optional<Var> var;
arith::PVar<Var> p_v;
arith::PVar<PrimExpr> p_e;
if ((p_v * p_e).Match(provided_min) || (p_e * p_v).Match(provided_min)) {
PrimExpr e = p_e.Eval();
- dom_var = p_v.Eval();
- dom_min = floordiv(required_min, e);
- dom_max = floordiv(required_max, e);
+ var = p_v.Eval();
+ var_dom = arith::IntSet::Interval(floordiv(required_min, e),
floordiv(required_max, e));
+ var_bound = arith::IntSet::Interval(0, floordiv(dim_max, e));
} else if (analyzer->CanProveEqual(provided_min, provided_max)) {
if (p_v.Match(provided_min)) {
- dom_var = p_v.Eval();
- dom_min = required_min;
- dom_max = required_max;
+ var = p_v.Eval();
+ var_dom = arith::IntSet::Interval(required_min, required_max);
+ var_bound = arith::IntSet::Interval(0, dim_max);
} else {
arith::PVar<PrimExpr> p_f;
if ((floordiv(p_v, p_f)).Match(provided_min)) {
// a <= (x // factor) <= b, fac > 0 ==> (a * fac) <= x <= (b * fac +
fac - 1)
PrimExpr fac = p_f.Eval();
if (analyzer->CanProveGreaterEqual(fac, 1)) {
- dom_var = p_v.Eval();
- dom_min = required_min * fac;
- dom_max = analyzer->Simplify(required_max * fac + fac - 1);
+ var = p_v.Eval();
+ var_dom = arith::IntSet::Interval(required_min * fac,
+ analyzer->Simplify(required_max *
fac + fac - 1));
+ var_bound = arith::IntSet::Interval(0, analyzer->Simplify(dim_max *
fac + fac - 1));
}
} else if ((floormod(p_v, p_f).Match(provided_min))) {
// generally domain of (x % fac) enforce no constraints to domain of x
- dom_var = p_v.Eval();
- return std::make_pair(dom_var, arith::IntSet::Nothing());
+ return {p_v.Eval(), BlockVarDomainInfo()};
}
}
}
- ICHECK(dom_var.defined()) << "ValueError: BufferRegion pattern match failed:
" << provided_min;
- return std::make_pair(dom_var, arith::IntSet::Interval(dom_min, dom_max));
+ ICHECK(var.defined()) << "ValueError: BufferRegion pattern match failed: "
<< provided_min;
+ return {var.value(), BlockVarDomainInfo{var_dom, var_bound}};
}
/*!
- * \brief Calculate and update the iteration domain info to fully cover the
required domain
- * \param provided The provided integer set to cover the required domain
- * \param required The required domain to be covered
- * \param required_bound The additional region bound of the required domain to
be covered
+ * \brief Calculate and update the iteration domain info to fully cover the
required domain in
+ * dimension-wise fashion. The region relation on each buffer dimension is
independently estimated.
+ * \param buffer The accessed buffer
+ * \param provided_region The provided NDIntSet to cover the required domain
+ * \param required_region The required NDIntSet domain to be covered
+ * \param analyzer The arithmetic analyzer
* \param iter_doms The result iteration domains to be updated
+ */
+void UpdateBlockVarDomainDimwise(
+ const BufferNode* buffer, const NDIntSet& provided_region, const NDIntSet&
required_region,
+ arith::Analyzer* analyzer, std::unordered_map<const VarNode*,
BlockVarDomainInfo>* iter_doms) {
+ size_t ndim = buffer->shape.size();
+ for (size_t i = 0; i < ndim; ++i) {
+ arith::IntSet provided = provided_region[i];
+ arith::IntSet required = required_region[i];
+ PrimExpr dim_max = max(buffer->shape[i] - 1, 0);
+
+ if (provided.IsSinglePoint() && is_const_int(provided.min())) {
+ ICHECK(required.IsSinglePoint() &&
analyzer->CanProveEqual(provided.min(), required.min()));
+ continue;
+ }
+
+ auto [var, dom_info] = SolveBlockVarDomain(provided, required, dim_max,
analyzer);
+ auto it = iter_doms->find(var.get());
+ if (it != iter_doms->end()) {
+ it->second.Union(dom_info);
+ } else {
+ ICHECK(analyzer->CanProveEqual(provided.min(), required.min()));
+ ICHECK(analyzer->CanProveEqual(provided.max(), required.max()));
+ }
+ }
+}
+
+/*! \brief Helper function to implement intset version of
`InverseAffineIterMap`. */
+Map<Var, arith::IntSet> InverseAffineIterMap(const Array<arith::IterSumExpr>&
iter_map,
+ const NDIntSet& outputs,
arith::Analyzer* analyzer) {
+ Array<PrimExpr> min_point, max_point;
+ min_point.reserve(outputs.size());
+ max_point.reserve(outputs.size());
+ for (const auto& intset : outputs) {
+ ICHECK(intset.HasLowerBound() && intset.HasUpperBound());
+ min_point.push_back(intset.min());
+ max_point.push_back(intset.max());
+ }
+ auto rev_min = InverseAffineIterMap(iter_map, min_point);
+ auto rev_max = InverseAffineIterMap(iter_map, max_point);
+ Map<Var, arith::IntSet> dom_map;
+ for (const auto& kv : rev_min) {
+ const Var& var = kv.first;
+ auto it = rev_max.find(var);
+ ICHECK(it != rev_max.end()); // InverseAffineIterMap's result vars are
assumed stable
+ const PrimExpr& rev_min_point = kv.second;
+ const PrimExpr& rev_max_point = (*it).second;
+ dom_map.Set(var,
+ arith::IntSet::Interval(analyzer->Simplify(min(rev_min_point,
rev_max_point)),
+ analyzer->Simplify(max(rev_min_point,
rev_max_point))));
+ }
+ return dom_map;
+}
+
+/*!
+ * \brief Calculate and update the iteration domain info to fully cover the
required domain
+ * with affine analysis. It requires bijective mapping of block var to
provided region points.
+ * \param buffer The accessed buffer
+ * \param iter_vars The list of block vars to cover the required region
+ * \param provided_region The provided NDIntSet to cover the required domain
+ * \param required_region The required NDIntSet domain to be covered
* \param analyzer The arithmetic analyzer
+ * \param iter_doms The result iteration domains to be updated
+ * \returns bool. Denotes whether update success
*/
-void UpdateBlockVarDomain(const arith::IntSet& provided, const arith::IntSet&
required,
- const arith::IntSet& required_bound,
- std::unordered_map<const VarNode*,
BlockVarDomainInfo>* iter_doms,
- arith::Analyzer* analyzer) {
- if (provided.IsSinglePoint() && is_const_int(provided.min())) {
- ICHECK(required.IsSinglePoint() && analyzer->CanProveEqual(provided.min(),
required.min()));
- ICHECK(required_bound.IsSinglePoint() &&
- analyzer->CanProveEqual(provided.min(), required_bound.min()));
- return;
+bool UpdateBlockVarDomainAffine(const BufferNode* buffer, const
Array<IterVar>& iter_vars,
+ const NDIntSet& provided_region, const
NDIntSet& required_region,
+ arith::Analyzer* analyzer,
+ std::unordered_map<const VarNode*,
BlockVarDomainInfo>* iter_doms) {
+ // we only support single point provided region now, which could cover most
cases
+ for (const auto& intset : provided_region) {
+ if (!intset.IsSinglePoint()) return false;
+ }
+ // calculate forward mapping (block vars -> provided region point)
+ Map<Var, Range> dom_map;
+ for (const IterVar& iter_var : iter_vars) {
+ dom_map.Set(iter_var->var, iter_var->dom);
}
- auto var_with_dom = SolveBlockVarDomain(provided, required, analyzer);
- auto var_with_bound = SolveBlockVarDomain(provided, required_bound,
analyzer);
- const Var& var = var_with_dom.first;
- const auto& var_dom = var_with_dom.second;
- const auto& var_bound = var_with_bound.second;
- ICHECK(var.same_as(var_with_bound.first));
- auto it = iter_doms->find(var.get());
- if (it != iter_doms->end()) {
- it->second.Union({var_dom, var_bound});
- } else {
- ICHECK(analyzer->CanProveEqual(provided.min(), required.min()));
- ICHECK(analyzer->CanProveEqual(provided.max(), required.max()));
+ size_t ndim = buffer->shape.size();
+ Array<PrimExpr> provide_indices;
+ provide_indices.reserve(ndim);
+ for (size_t i = 0; i < ndim; ++i) {
+ provide_indices.push_back(provided_region[i].min());
+ }
+ auto res = arith::DetectIterMap(provide_indices, dom_map, const_true(),
+ arith::IterMapLevel::Bijective, analyzer,
false);
+ if (res->indices.empty()) {
+ return false;
}
+ // calculate backward mapping (required region point -> block vars)
+ NDIntSet required_bound;
+ for (size_t i = 0; i < ndim; ++i) {
+ required_bound.push_back(
+ arith::IntSet::Interval(make_zero(buffer->shape[i]->dtype),
max(buffer->shape[i] - 1, 0)));
+ }
+ Map<Var, arith::IntSet> var_dom = InverseAffineIterMap(res->indices,
required_region, analyzer);
+ Map<Var, arith::IntSet> var_bound = InverseAffineIterMap(res->indices,
required_bound, analyzer);
+ for (const auto& kv : var_dom) {
+ const Var& var = kv.first;
+ auto it = var_bound.find(var);
+ ICHECK(it != var_bound.end()); // InverseAffineIterMap's result vars are
assumed stable
+ (*iter_doms)[var.get()].Union(BlockVarDomainInfo{kv.second, (*it).second});
+ }
+ return true;
}
/*!
@@ -501,13 +583,10 @@ std::vector<BlockVarDomainInfo> CalculateBlockVarDomain(
NDIntSet provided_region = support::NDIntSetUnion(many_provided_regions);
ICHECK_EQ(provided_region.size(), buffer->shape.size());
ICHECK_EQ(required_region.size(), buffer->shape.size());
- // For each dimension, update the iteration domain
- int ndim = buffer->shape.size();
- for (int i = 0; i < ndim; ++i) {
- arith::IntSet provided = provided_region[i];
- arith::IntSet required = required_region[i];
- arith::IntSet required_bound = arith::IntSet::FromMinExtent(Integer(0),
buffer->shape[i]);
- UpdateBlockVarDomain(provided, required, required_bound, &iter_doms,
analyzer);
+ // Try update iter var domains with current required and provided region
pair.
+ if (!UpdateBlockVarDomainAffine(buffer, iter_vars, provided_region,
required_region, analyzer,
+ &iter_doms)) {
+ UpdateBlockVarDomainDimwise(buffer, provided_region, required_region,
analyzer, &iter_doms);
}
}
// Union the iter var domains, put them in the same order of block vars, and
return
diff --git a/tests/python/unittest/test_tir_schedule_compute_at.py b/tests/python/unittest/test_tir_schedule_compute_at.py
index f94347409a..364a43acda 100644
--- a/tests/python/unittest/test_tir_schedule_compute_at.py
+++ b/tests/python/unittest/test_tir_schedule_compute_at.py
@@ -1174,6 +1174,40 @@ def test_compute_at_tiled_repeat_op(use_block_name):
verify_trace_roundtrip(sch=sch, mod=tiled_repeat_op)
+def test_compute_at_rev_iter():
+ @T.prim_func
+ def before(X: T.Buffer[(10, 10), "float32"], Z: T.Buffer[(10, 10),
"float32"]):
+ Y = T.alloc_buffer([10, 10], "float32")
+ for i, j in T.grid(10, 10):
+ with T.block("b0"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ Y[9 - vi, 9 - vj] = X[vi, vj] + 1.0
+ for i, j in T.grid(10, 10):
+ with T.block("b1"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ Z[vi, vj] = Y[vj, vi] + 2.0
+
+ @T.prim_func
+ def after(X: T.Buffer[(10, 10), "float32"], Z: T.Buffer[(10, 10),
"float32"]):
+ Y = T.alloc_buffer([10, 10], "float32")
+ for i in range(10):
+ for j in range(10):
+ with T.block("b0"):
+ vi = T.axis.spatial(10, j)
+ vj = T.axis.spatial(10, 9 - i)
+ Y[9 - vi, 9 - vj] = X[vi, vj] + 1.0
+ for j in range(10):
+ with T.block("b1"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ Z[vi, vj] = Y[vj, vi] + 2.0
+
+ sch = tir.Schedule(before, debug_mask="all")
+ axis = sch.get_loops(sch.get_block("b1"))[0]
+ sch.compute_at(sch.get_block("b0"), axis)
+ tvm.ir.assert_structural_equal(after, sch.mod["main"])
+ verify_trace_roundtrip(sch=sch, mod=before)
+
+
def test_reverse_compute_at_tiled(use_block_name):
sch = tir.Schedule(tiled, debug_mask="all")
block = sch.get_block("C")
@@ -1557,5 +1591,45 @@ def test_reverse_compute_at_with_unit_loop():
tvm.ir.assert_structural_equal(main_reverse_compute_at, sch.mod["main"])
+def test_reverse_compute_at_layout_trans():
+ @T.prim_func
+ def before(A: T.Buffer((1, 3, 5, 5, 16), "float32"), C: T.Buffer((1, 6, 5,
5, 8), "float32")):
+ B = T.alloc_buffer((1, 3, 5, 5, 16))
+ for i0, i1, i2, i3, i4 in T.grid(1, 3, 5, 5, 16):
+ with T.block("compute"):
+ v_i0, v_i1, v_i2, v_i3, v_i4 = T.axis.remap("SSSSS", [i0, i1,
i2, i3, i4])
+ B[v_i0, v_i1, v_i2, v_i3, v_i4] = A[v_i0, v_i1, v_i2, v_i3,
v_i4] + T.float32(1)
+ for ax0, ax1, ax2, ax3, ax4 in T.grid(1, 6, 5, 5, 8):
+ with T.block("T_layout_trans"):
+ v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap("SSSSS",
[ax0, ax1, ax2, ax3, ax4])
+ C[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = B[
+ v_ax0, (v_ax1 * 8 + v_ax4) // 16, v_ax2, v_ax3, (v_ax1 * 8
+ v_ax4) % 16
+ ]
+
+ @T.prim_func
+ def after(A: T.Buffer((1, 3, 5, 5, 16), "float32"), C: T.Buffer((1, 6, 5,
5, 8), "float32")):
+ B = T.alloc_buffer((1, 3, 5, 5, 16))
+ for i0, i1 in T.grid(1, 3):
+ for i2, i3, i4 in T.grid(5, 5, 16):
+ with T.block("compute"):
+ v_i0, v_i1, v_i2, v_i3, v_i4 = T.axis.remap("SSSSS", [i0,
i1, i2, i3, i4])
+ B[v_i0, v_i1, v_i2, v_i3, v_i4] = A[v_i0, v_i1, v_i2,
v_i3, v_i4] + T.float32(1)
+ for ax0, ax1, ax2, ax3 in T.grid(2, 5, 5, 8):
+ with T.block("T_layout_trans"):
+ v_ax0 = T.axis.spatial(1, 0)
+ v_ax1 = T.axis.spatial(6, i1 * 2 + ax0)
+ v_ax2, v_ax3, v_ax4 = T.axis.remap("SSS", [ax1, ax2, ax3])
+ C[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = B[
+ v_ax0, (v_ax1 * 8 + v_ax4) // 16, v_ax2, v_ax3, (v_ax1
* 8 + v_ax4) % 16
+ ]
+
+ sch = tir.Schedule(before, debug_mask="all")
+ trans = sch.get_block("T_layout_trans")
+ axis = sch.get_loops("compute")[1]
+ sch.reverse_compute_at(trans, axis)
+ tvm.ir.assert_structural_equal(after, sch.mod["main"])
+ verify_trace_roundtrip(sch=sch, mod=before)
+
+
if __name__ == "__main__":
tvm.testing.main()