You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by lm...@apache.org on 2020/03/10 15:18:13 UTC

[incubator-tvm] 01/01: Revert "Tighten split's extent (#4931)"

This is an automated email from the ASF dual-hosted git repository.

lmzheng pushed a commit to branch revert-4931-split-node-min-range-only-at-final-pass
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git

commit 13588f0ac709165467472f261f351bd86a17dfe1
Author: Lianmin Zheng <li...@gmail.com>
AuthorDate: Tue Mar 10 08:17:59 2020 -0700

    Revert "Tighten split's extent (#4931)"
    
    This reverts commit 585f9ce6e7bef7d0e8902b1c1e55dcb3bbe84eed.
---
 src/te/schedule/message_passing.cc                 | 76 +---------------------
 .../unittest/test_schedule_bound_inference.py      | 26 --------
 2 files changed, 3 insertions(+), 99 deletions(-)

diff --git a/src/te/schedule/message_passing.cc b/src/te/schedule/message_passing.cc
index a7b2482..5b6fa86 100644
--- a/src/te/schedule/message_passing.cc
+++ b/src/te/schedule/message_passing.cc
@@ -51,66 +51,17 @@ void Update(std::unordered_map<IterVar, Range>* p_state,
   }
 }
 
-/*!
- * \param Upward propagating whether an IterVar derives at least one leaf IterVar that binds to
- * a thread.
- *
- * \param stage The stage to operate on.
- * \param p_state The propagation result of each IterVar.
- */
-void PassUpThreadBinding(const Stage& stage, std::unordered_map<IterVar, bool>* p_state) {
-  auto bound_to_thread = [&stage](const IterVar& iv) {
-    bool bound = false;
-    auto it = stage->iter_var_attrs.find(iv);
-    if (it != stage->iter_var_attrs.end()) {
-      bound = (*it).second->bind_thread.defined();
-    }
-    return bound;
-  };
-
-  auto& state = *p_state;
-  // Fill p_state with leaf itervars
-  for (const IterVar& iv : stage->leaf_iter_vars) {
-    state[iv] = bound_to_thread(iv);
-  }
-  // Traverse the graph bottom-up to propagate thread binding information
-  for (size_t i = stage->relations.size(); i != 0; --i) {
-    IterVarRelation rel = stage->relations[i - 1];
-    if (const SplitNode* s = rel.as<SplitNode>()) {
-      state[s->parent] = state[s->inner] || state[s->outer];
-    } else if (const FuseNode* s = rel.as<FuseNode>()) {
-      state[s->inner] = state[s->fused];
-      state[s->outer] = state[s->fused];
-    } else if (const RebaseNode* s = rel.as<RebaseNode>()) {
-      state[s->parent] = state[s->rebased];
-    } else if (rel.as<SingletonNode>()) {
-    } else {
-      LOG(FATAL) << "unknown relation type";
-    }
-  }
-}
-
 void PassDownDomain(const Stage& stage,
                     std::unordered_map<IterVar, Range>* p_state,
                     arith::Analyzer* actx,
                     bool allow_missing) {
-  auto ceil_div = [actx](const PrimExpr& a, const PrimExpr& b) {
+  auto ceil_div = [actx](PrimExpr a, PrimExpr b) {
     if (actx->CanProve(indexmod(a, b) == 0)) {
       return actx->Simplify(indexdiv(a, b));
     }
     return actx->Simplify(indexdiv(a + (b - 1), b));
   };
 
-  auto minimum_or_later  = [actx](const PrimExpr& a, const PrimExpr& b) {
-    if (actx->CanProve(a < b)) {
-      return actx->Simplify(a);
-    }
-    return actx->Simplify(b);
-  };
-
-  std::unordered_map<IterVar, bool> dominating_thread;
-  PassUpThreadBinding(stage, &dominating_thread);
-
   auto& state = *p_state;
   // forwar iteration on relations
   for (IterVarRelation rel : stage->relations) {
@@ -121,35 +72,14 @@ void PassDownDomain(const Stage& stage,
       }
       CHECK(!state.count(r->inner));
       const Range& range_parent = state.at(r->parent);
-      // Tighten iv's extent to min(parent_extent, factor_or_nparts), only if all of the
-      // following conditions are met:
-      // 1. No leaf IterVar derived from iv binds to any thread.  People may use split
-      // to force an IterVar extent to match the number of allocated threads to fuse stages
-      // that require different number of threads.  We don't want to change these extents.
-      // 2. allow_missing is false, i.e. that PassDownDomain is called by the final InferBound,
-      // rather than by an early compiler phase, such as rfactor().  We don't want to tighten an
-      // IterVar in an early phase allowing missing IterVars, because it may bind to a thread later.
-      // 3. range_parent's extent is not 0.  At lest one Topi test has a case where a tensor has one
-      // zero-sized dimension.  Split creates iv with a positive extent to avoid zero-extent
-      // IterVar.  We don't touch it.
-      auto resolve_min_extent_for_split = [&](const IterVar& iv, const PrimExpr& factor_or_nparts) {
-        return dominating_thread[iv] || allow_missing || is_zero(range_parent->extent)
-                   ? factor_or_nparts
-                   : minimum_or_later(range_parent->extent, factor_or_nparts);
-      };
       if (r->factor.defined()) {
         Update(p_state, r->inner,
-               Range::make_by_min_extent(
-                   0, resolve_min_extent_for_split(r->inner, r->factor)),
-               actx);
+               Range::make_by_min_extent(0, r->factor), actx);
         Update(p_state, r->outer,
                Range::make_by_min_extent(
                    0, ceil_div(range_parent->extent, r->factor)), actx);
       } else {
-        Update(p_state, r->outer,
-               Range::make_by_min_extent(
-                   0, resolve_min_extent_for_split(r->outer, r->nparts)),
-               actx);
+        Update(p_state, r->outer, Range::make_by_min_extent(0, r->nparts), actx);
         Update(p_state, r->inner,
                Range::make_by_min_extent(
                    0, ceil_div(range_parent->extent, r->nparts)), actx);
diff --git a/tests/python/unittest/test_schedule_bound_inference.py b/tests/python/unittest/test_schedule_bound_inference.py
index edae527..484aa50 100644
--- a/tests/python/unittest/test_schedule_bound_inference.py
+++ b/tests/python/unittest/test_schedule_bound_inference.py
@@ -70,32 +70,6 @@ def test_bound3():
     assert(bounds[A1.op.axis[0]].extent.value==32)
     assert(bounds[A1.op.axis[1]].extent.value==16)
 
-def test_bound_split_ext_less_than_factor():
-    m = 8
-    I = te.placeholder((m,), name='I')
-    EF = te.compute((m,), lambda i: I[i] * 2, name = "EF")
-    E = te.compute((m,), lambda i: EF[i] * 2, name = "E")
-    s = te.create_schedule([E.op])
-    xo, xi = s[E].split(s[E].op.axis[0], factor = 32)
-    s[EF].compute_at(s[E], xo)
-
-    bounds = tvm.te.schedule.InferBound(s)
-    assert isinstance(bounds, tvm.container.Map)
-    assert bounds[xi].extent.value == m
-
-def test_bound_split_ext_less_than_naprts():
-    m = 8
-    I = te.placeholder((m,), name='I')
-    EF = te.compute((m,), lambda i: I[i] * 2, name = "EF")
-    E = te.compute((m,), lambda i: EF[i] * 2, name = "E")
-    s = te.create_schedule([E.op])
-    xo, xi = s[E].split(s[E].op.axis[0], nparts = 32)
-    s[EF].compute_at(s[E], xo)
-
-    bounds = tvm.te.schedule.InferBound(s)
-    assert isinstance(bounds, tvm.container.Map)
-    assert bounds[xo].extent.value == m
-
 def test_bound_split_divisible():
     m = te.var('m')
     l = te.var('l')