This is an automated email from the ASF dual-hosted git repository.
lmzheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new 585f9ce Tighten split's extent (#4931)
585f9ce is described below
commit 585f9ce6e7bef7d0e8902b1c1e55dcb3bbe84eed
Author: Lianmin Zheng <[email protected]>
AuthorDate: Wed Mar 4 10:52:02 2020 -0800
Tighten split's extent (#4931)
* Set split node's range to minimum of ext and split factor or split
nparts, but only when PassDownDomain is called with allow_missing == false,
i.e. by InferBound. Add a helper PassUpThreadBinding() to get a map telling
whether an IterVar has at least one leaf IterVar deriving from it binding to a
thread. Add two unit tests.
* Enhance LoopVectorizer for vectorizing by 0. Found at least one case
from testtopi/tests/python/test_topi_transform.py::test_tile.
* Revert changes to vectorize_loop.cc; when parent's ext is zero, set split's
range to the factor or nparts.
* Update with comments.
* Refactor the ext tightening predicate.
* Fix reference types.
* Integrate tvm.te changes.
* Trivial comment change to trigger CI.
* Trivial comment correction to trigger testing.
---
src/te/schedule/message_passing.cc | 76 +++++++++++++++++++++-
.../unittest/test_schedule_bound_inference.py | 26 ++++++++
2 files changed, 99 insertions(+), 3 deletions(-)
diff --git a/src/te/schedule/message_passing.cc
b/src/te/schedule/message_passing.cc
index 5b6fa86..a7b2482 100644
--- a/src/te/schedule/message_passing.cc
+++ b/src/te/schedule/message_passing.cc
@@ -51,17 +51,66 @@ void Update(std::unordered_map<IterVar, Range>* p_state,
}
}
+/*!
+ * \brief Upward propagating whether an IterVar derives at least one leaf
IterVar that binds to
+ * a thread.
+ *
+ * Every leaf IterVar of the stage is first recorded as bound (true) when its
+ * entry in stage->iter_var_attrs has a defined bind_thread, and as unbound
+ * (false) otherwise. The stage's relations are then visited from last to
+ * first: a split parent is bound if either its inner or its outer is bound,
+ * both inputs of a fuse take the flag of the fused IterVar, and a rebase
+ * parent takes the flag of the rebased IterVar. Singleton relations update
+ * nothing; any other relation type aborts through LOG(FATAL).
+ *
+ * Flags are read with operator[] on the map, so an IterVar that has not been
+ * recorded yet reads as false.
+ *
+ * \param stage The stage to operate on.
+ * \param p_state The propagation result of each IterVar.
+ */
+void PassUpThreadBinding(const Stage& stage, std::unordered_map<IterVar,
bool>* p_state) {
+ auto bound_to_thread = [&stage](const IterVar& iv) {
+ bool bound = false;
+ auto it = stage->iter_var_attrs.find(iv);
+ if (it != stage->iter_var_attrs.end()) {
+ bound = (*it).second->bind_thread.defined();
+ }
+ return bound;
+ };
+
+ auto& state = *p_state;
+ // Seed p_state with the leaf itervars and their own thread-binding flag
+ for (const IterVar& iv : stage->leaf_iter_vars) {
+ state[iv] = bound_to_thread(iv);
+ }
+ // Traverse the relations in reverse order (bottom-up) to propagate thread binding information
+ for (size_t i = stage->relations.size(); i != 0; --i) {
+ IterVarRelation rel = stage->relations[i - 1];
+ if (const SplitNode* s = rel.as<SplitNode>()) {
+ state[s->parent] = state[s->inner] || state[s->outer];
+ } else if (const FuseNode* s = rel.as<FuseNode>()) {
+ state[s->inner] = state[s->fused];
+ state[s->outer] = state[s->fused];
+ } else if (const RebaseNode* s = rel.as<RebaseNode>()) {
+ state[s->parent] = state[s->rebased];
+ } else if (rel.as<SingletonNode>()) {
+ } else {
+ LOG(FATAL) << "unknown relation type";
+ }
+ }
+}
+
void PassDownDomain(const Stage& stage,
std::unordered_map<IterVar, Range>* p_state,
arith::Analyzer* actx,
bool allow_missing) {
- auto ceil_div = [actx](PrimExpr a, PrimExpr b) {
+ auto ceil_div = [actx](const PrimExpr& a, const PrimExpr& b) {
if (actx->CanProve(indexmod(a, b) == 0)) {
return actx->Simplify(indexdiv(a, b));
}
return actx->Simplify(indexdiv(a + (b - 1), b));
};
+ auto minimum_or_later = [actx](const PrimExpr& a, const PrimExpr& b) {
+ if (actx->CanProve(a < b)) {
+ return actx->Simplify(a);
+ }
+ return actx->Simplify(b);
+ };
+
+ std::unordered_map<IterVar, bool> dominating_thread;
+ PassUpThreadBinding(stage, &dominating_thread);
+
auto& state = *p_state;
// forward iteration on relations
for (IterVarRelation rel : stage->relations) {
@@ -72,14 +121,35 @@ void PassDownDomain(const Stage& stage,
}
CHECK(!state.count(r->inner));
const Range& range_parent = state.at(r->parent);
+ // Tighten iv's extent to min(parent_extent, factor_or_nparts), only if
all of the
+ // following conditions are met:
+ // 1. No leaf IterVar derived from iv binds to any thread. People may
use split
+ // to force an IterVar extent to match the number of allocated threads
to fuse stages
+ // that require different number of threads. We don't want to change
these extents.
+ // 2. allow_missing is false, i.e. that PassDownDomain is called by the
final InferBound,
+ // rather than by an early compiler phase, such as rfactor(). We don't
want to tighten an
+ // IterVar in an early phase allowing missing IterVars, because it may
bind to a thread later.
+ // 3. range_parent's extent is not 0. At least one Topi test has a case
where a tensor has one
+ // zero-sized dimension. Split creates iv with a positive extent to
avoid zero-extent
+ // IterVar. We don't touch it.
+ auto resolve_min_extent_for_split = [&](const IterVar& iv, const
PrimExpr& factor_or_nparts) {
+ return dominating_thread[iv] || allow_missing ||
is_zero(range_parent->extent)
+ ? factor_or_nparts
+ : minimum_or_later(range_parent->extent, factor_or_nparts);
+ };
if (r->factor.defined()) {
Update(p_state, r->inner,
- Range::make_by_min_extent(0, r->factor), actx);
+ Range::make_by_min_extent(
+ 0, resolve_min_extent_for_split(r->inner, r->factor)),
+ actx);
Update(p_state, r->outer,
Range::make_by_min_extent(
0, ceil_div(range_parent->extent, r->factor)), actx);
} else {
- Update(p_state, r->outer, Range::make_by_min_extent(0, r->nparts),
actx);
+ Update(p_state, r->outer,
+ Range::make_by_min_extent(
+ 0, resolve_min_extent_for_split(r->outer, r->nparts)),
+ actx);
Update(p_state, r->inner,
Range::make_by_min_extent(
0, ceil_div(range_parent->extent, r->nparts)), actx);
diff --git a/tests/python/unittest/test_schedule_bound_inference.py
b/tests/python/unittest/test_schedule_bound_inference.py
index 484aa50..edae527 100644
--- a/tests/python/unittest/test_schedule_bound_inference.py
+++ b/tests/python/unittest/test_schedule_bound_inference.py
@@ -70,6 +70,32 @@ def test_bound3():
assert(bounds[A1.op.axis[0]].extent.value==32)
assert(bounds[A1.op.axis[1]].extent.value==16)
+def test_bound_split_ext_less_than_factor():  # split factor (32) is larger than the axis extent (8)
+ m = 8
+ I = te.placeholder((m,), name='I')
+ EF = te.compute((m,), lambda i: I[i] * 2, name = "EF")
+ E = te.compute((m,), lambda i: EF[i] * 2, name = "E")
+ s = te.create_schedule([E.op])
+ xo, xi = s[E].split(s[E].op.axis[0], factor = 32)
+ s[EF].compute_at(s[E], xo)
+
+ bounds = tvm.te.schedule.InferBound(s)
+ assert isinstance(bounds, tvm.container.Map)
+ assert bounds[xi].extent.value == m  # inner extent must be the axis extent m, not the factor 32
+
+def test_bound_split_ext_less_than_naprts():  # NOTE(review): "naprts" looks like a typo for "nparts" - confirm before renaming
+ m = 8
+ I = te.placeholder((m,), name='I')
+ EF = te.compute((m,), lambda i: I[i] * 2, name = "EF")
+ E = te.compute((m,), lambda i: EF[i] * 2, name = "E")
+ s = te.create_schedule([E.op])
+ xo, xi = s[E].split(s[E].op.axis[0], nparts = 32)  # nparts (32) is larger than the axis extent (8)
+ s[EF].compute_at(s[E], xo)
+
+ bounds = tvm.te.schedule.InferBound(s)
+ assert isinstance(bounds, tvm.container.Map)
+ assert bounds[xo].extent.value == m  # outer extent must be the axis extent m, not nparts 32
+
def test_bound_split_divisible():
m = te.var('m')
l = te.var('l')