This is an automated email from the ASF dual-hosted git repository.

lmzheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 585f9ce  Tighten split's extent (#4931)
585f9ce is described below

commit 585f9ce6e7bef7d0e8902b1c1e55dcb3bbe84eed
Author: Lianmin Zheng <[email protected]>
AuthorDate: Wed Mar 4 10:52:02 2020 -0800

    Tighten split's extent (#4931)
    
    * Set the split node's range to the minimum of the parent's extent and the split
    factor (or nparts), but only when PassDownDomain is called with allow_missing == false,
    i.e. by InferBound.  Add a helper, PassUpThreadBinding(), that builds a map telling
    whether an IterVar has at least one derived leaf IterVar bound to a thread. Add two
    unit tests.
    
    * Enhance LoopVectorizer to handle vectorizing by 0.  At least one such case was
    found in topi/tests/python/test_topi_transform.py::test_tile.
    
    * Revert the changes to vectorize_loop.cc; instead, when the parent's extent is zero,
    set the split's range to the factor or nparts.
    
    * Update with comments.
    
    * Refactor the ext tightening predicate.
    
    * Fix reference types.
    
    * Integrate tvm.te changes.
    
    * Trivial comment change to trigger CI.
    
    * Trivial comment correction to trigger testing.
---
 src/te/schedule/message_passing.cc                 | 76 +++++++++++++++++++++-
 .../unittest/test_schedule_bound_inference.py      | 26 ++++++++
 2 files changed, 99 insertions(+), 3 deletions(-)

diff --git a/src/te/schedule/message_passing.cc b/src/te/schedule/message_passing.cc
index 5b6fa86..a7b2482 100644
--- a/src/te/schedule/message_passing.cc
+++ b/src/te/schedule/message_passing.cc
@@ -51,17 +51,66 @@ void Update(std::unordered_map<IterVar, Range>* p_state,
   }
 }
 
+/*!
+ * \brief Upward propagate, for the stage's IterVars, whether at least one leaf IterVar
+ * derived from each of them is bound to a thread.
+ *
+ * \param stage The stage to operate on.
+ * \param p_state Output map; an IterVar maps to true iff some leaf IterVar derived from it binds to a thread.
+ */
+void PassUpThreadBinding(const Stage& stage, std::unordered_map<IterVar, 
bool>* p_state) {
+  auto bound_to_thread = [&stage](const IterVar& iv) {  // true iff iv has attrs with a defined bind_thread
+    bool bound = false;
+    auto it = stage->iter_var_attrs.find(iv);
+    if (it != stage->iter_var_attrs.end()) {
+      bound = (*it).second->bind_thread.defined();
+    }
+    return bound;
+  };
+
+  auto& state = *p_state;
+  // Seed the map: every leaf IterVar records its own thread-binding status.
+  for (const IterVar& iv : stage->leaf_iter_vars) {
+    state[iv] = bound_to_thread(iv);
+  }
+  // Visit relations in reverse (bottom-up) order, pushing binding info from derived vars to their sources.
+  for (size_t i = stage->relations.size(); i != 0; --i) {  // state[...] default-inserts false for unseen vars
+    IterVarRelation rel = stage->relations[i - 1];
+    if (const SplitNode* s = rel.as<SplitNode>()) {
+      state[s->parent] = state[s->inner] || state[s->outer];  // split: parent is bound if either part is
+    } else if (const FuseNode* s = rel.as<FuseNode>()) {
+      state[s->inner] = state[s->fused];  // fuse: both inputs inherit the fused var's status
+      state[s->outer] = state[s->fused];
+    } else if (const RebaseNode* s = rel.as<RebaseNode>()) {
+      state[s->parent] = state[s->rebased];  // rebase: parent inherits the rebased var's status
+    } else if (rel.as<SingletonNode>()) {  // nothing to propagate for singleton relations
+    } else {
+      LOG(FATAL) << "unknown relation type";
+    }
+  }
+}
+
 void PassDownDomain(const Stage& stage,
                     std::unordered_map<IterVar, Range>* p_state,
                     arith::Analyzer* actx,
                     bool allow_missing) {
-  auto ceil_div = [actx](PrimExpr a, PrimExpr b) {
+  auto ceil_div = [actx](const PrimExpr& a, const PrimExpr& b) {
     if (actx->CanProve(indexmod(a, b) == 0)) {
       return actx->Simplify(indexdiv(a, b));
     }
     return actx->Simplify(indexdiv(a + (b - 1), b));
   };
 
+  auto minimum_or_later  = [actx](const PrimExpr& a, const PrimExpr& b) {
+    if (actx->CanProve(a < b)) {
+      return actx->Simplify(a);
+    }
+    return actx->Simplify(b);
+  };
+
+  std::unordered_map<IterVar, bool> dominating_thread;
+  PassUpThreadBinding(stage, &dominating_thread);
+
   auto& state = *p_state;
   // forwar iteration on relations
   for (IterVarRelation rel : stage->relations) {
@@ -72,14 +121,35 @@ void PassDownDomain(const Stage& stage,
       }
       CHECK(!state.count(r->inner));
       const Range& range_parent = state.at(r->parent);
+      // Tighten iv's extent to min(parent_extent, factor_or_nparts), only if 
all of the
+      // following conditions are met:
+      // 1. No leaf IterVar derived from iv binds to any thread.  People may 
use split
+      // to force an IterVar extent to match the number of allocated threads 
to fuse stages
+      // that require different number of threads.  We don't want to change 
these extents.
+      // 2. allow_missing is false, i.e. that PassDownDomain is called by the 
final InferBound,
+      // rather than by an early compiler phase, such as rfactor().  We don't 
want to tighten an
+      // IterVar in an early phase allowing missing IterVars, because it may 
bind to a thread later.
+      // 3. range_parent's extent is not 0.  At lest one Topi test has a case 
where a tensor has one
+      // zero-sized dimension.  Split creates iv with a positive extent to 
avoid zero-extent
+      // IterVar.  We don't touch it.
+      auto resolve_min_extent_for_split = [&](const IterVar& iv, const 
PrimExpr& factor_or_nparts) {
+        return dominating_thread[iv] || allow_missing || 
is_zero(range_parent->extent)
+                   ? factor_or_nparts
+                   : minimum_or_later(range_parent->extent, factor_or_nparts);
+      };
       if (r->factor.defined()) {
         Update(p_state, r->inner,
-               Range::make_by_min_extent(0, r->factor), actx);
+               Range::make_by_min_extent(
+                   0, resolve_min_extent_for_split(r->inner, r->factor)),
+               actx);
         Update(p_state, r->outer,
                Range::make_by_min_extent(
                    0, ceil_div(range_parent->extent, r->factor)), actx);
       } else {
-        Update(p_state, r->outer, Range::make_by_min_extent(0, r->nparts), 
actx);
+        Update(p_state, r->outer,
+               Range::make_by_min_extent(
+                   0, resolve_min_extent_for_split(r->outer, r->nparts)),
+               actx);
         Update(p_state, r->inner,
                Range::make_by_min_extent(
                    0, ceil_div(range_parent->extent, r->nparts)), actx);
diff --git a/tests/python/unittest/test_schedule_bound_inference.py b/tests/python/unittest/test_schedule_bound_inference.py
index 484aa50..edae527 100644
--- a/tests/python/unittest/test_schedule_bound_inference.py
+++ b/tests/python/unittest/test_schedule_bound_inference.py
@@ -70,6 +70,32 @@ def test_bound3():
     assert(bounds[A1.op.axis[0]].extent.value==32)
     assert(bounds[A1.op.axis[1]].extent.value==16)
 
+def test_bound_split_ext_less_than_factor():  # split factor (32) is larger than the axis extent (8)
+    m = 8
+    I = te.placeholder((m,), name='I')
+    EF = te.compute((m,), lambda i: I[i] * 2, name = "EF")
+    E = te.compute((m,), lambda i: EF[i] * 2, name = "E")
+    s = te.create_schedule([E.op])
+    xo, xi = s[E].split(s[E].op.axis[0], factor = 32)  # factor exceeds m on purpose
+    s[EF].compute_at(s[E], xo)
+
+    bounds = tvm.te.schedule.InferBound(s)
+    assert isinstance(bounds, tvm.container.Map)
+    assert bounds[xi].extent.value == m  # inner extent tightened to min(m, factor) == m, not 32
+
+def test_bound_split_ext_less_than_naprts():  # split nparts (32) is larger than the axis extent (8)
+    m = 8
+    I = te.placeholder((m,), name='I')
+    EF = te.compute((m,), lambda i: I[i] * 2, name = "EF")
+    E = te.compute((m,), lambda i: EF[i] * 2, name = "E")
+    s = te.create_schedule([E.op])
+    xo, xi = s[E].split(s[E].op.axis[0], nparts = 32)  # nparts exceeds m on purpose
+    s[EF].compute_at(s[E], xo)
+
+    bounds = tvm.te.schedule.InferBound(s)
+    assert isinstance(bounds, tvm.container.Map)
+    assert bounds[xo].extent.value == m  # outer extent tightened to min(m, nparts) == m, not 32
+
 def test_bound_split_divisible():
     m = te.var('m')
     l = te.var('l')

Reply via email to