You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by tq...@apache.org on 2020/04/27 00:01:03 UTC

[incubator-tvm] branch master updated: Improve IntervalSet's floormod (#5367)

This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 37e5754  Improve IntervalSet's floormod (#5367)
37e5754 is described below

commit 37e5754862d438dbaec63fdcd3367993a1070981
Author: yongfeng-nv <49...@users.noreply.github.com>
AuthorDate: Sun Apr 26 20:00:52 2020 -0400

    Improve IntervalSet's floormod (#5367)
---
 include/tvm/arith/analyzer.h                       | 25 +++++++--
 include/tvm/arith/int_set.h                        | 12 +++--
 src/arith/analyzer.cc                              | 29 +++++++----
 src/arith/const_int_bound.cc                       | 18 ++++---
 src/arith/int_set.cc                               | 10 ++++
 src/te/operation/compute_op.cc                     | 14 ++---
 src/te/schedule/bound.cc                           | 15 ++++--
 src/te/schedule/message_passing.cc                 | 12 +++--
 tests/python/unittest/test_arith_intset.py         | 14 +++++
 .../test_te_schedule_bound_inference_tiling.py     | 60 ++++++++++++++++++++++
 10 files changed, 168 insertions(+), 41 deletions(-)

diff --git a/include/tvm/arith/analyzer.h b/include/tvm/arith/analyzer.h
index 6ca3ba9..c08c0d6 100644
--- a/include/tvm/arith/analyzer.h
+++ b/include/tvm/arith/analyzer.h
@@ -138,8 +138,9 @@ class ConstIntBoundAnalyzer {
    *
    * \param var The variable.
    * \param range The range we bind to.
+   * \param override Whether we allow overriding an existing var's range.
    */
-  TVM_DLL void Bind(const Var& var, const Range& range);
+  TVM_DLL void Bind(const Var& var, const Range& range, bool override = false);
 
  private:
   friend class Analyzer;
@@ -411,8 +412,9 @@ class TVM_DLL Analyzer {
    *
    * \param var The variable.
    * \param expr The expression we bind to.
+   * \param override Whether we allow overriding an existing var's expression.
    */
-  void Bind(const Var& var, const PrimExpr& expr);
+  void Bind(const Var& var, const PrimExpr& expr, bool override = false);
   /*!
    * \brief Notify all the sub-analyzers that var
    *        is created and binded to a range.
@@ -421,14 +423,16 @@ class TVM_DLL Analyzer {
    *
    * \param var The variable.
    * \param range The range we bind to.
+   * \param override Whether we allow overriding an existing var's range.
    */
-  void Bind(const Var& var, const Range& range);
+  void Bind(const Var& var, const Range& range, bool override = false);
   /*!
    * \brief Bind all the vars in the Map
    *
    * \param variables The {variable -> range} map.
+   * \param override Whether we allow overriding an existing var's range.
    */
-  void Bind(const Map<Var, Range>& variables);
+  void Bind(const Map<Var, Range>& variables, bool override = false);
   /*!
    * \brief Whether can we prove expr >= val.
 
@@ -443,6 +447,19 @@ class TVM_DLL Analyzer {
    */
   bool CanProveGreaterEqual(const PrimExpr& expr, int64_t lower_bound);
   /*!
+   * \brief Whether we can prove expr < upper_bound.
+   *
+   *  This is the strict upper-bound counterpart of CanProveGreaterEqual,
+   *  e.g. for proving that an index stays below a given extent.
+   *
+   * \param expr The expression.
+   * \param upper_bound The upper bound.
+   * \return Whether we can prove it.
+   *
+   * \note Analyzer will call into sub-analyzers to get the result.
+   */
+  bool CanProveLess(const PrimExpr& expr, int64_t upper_bound);
+  /*!
    * \brief Whether can we prove condition.
    *
    * \param cond The expression to be proved.
diff --git a/include/tvm/arith/int_set.h b/include/tvm/arith/int_set.h
index 86ef906..ab73b07 100644
--- a/include/tvm/arith/int_set.h
+++ b/include/tvm/arith/int_set.h
@@ -153,6 +153,13 @@ class IntSet : public ObjectRef {
 // Integer set legacy API.
 //------------------------------------------------
 /*!
+ * \brief Convert std::unordered_map<const VarNode*, IntSet> to Map<Var, IntSet>
+ *
+ * \param dom_map The domain map to convert.
+ * \return The converted map.
+ */
+Map<Var, IntSet> ConvertDomMap(const std::unordered_map<const VarNode*, IntSet>& dom_map);
+/*!
  * \brief Find an symbolic integer set that contains all possible values of
  *  e given the domain of each iteration variables.
  *
@@ -160,8 +167,7 @@ class IntSet : public ObjectRef {
  * \param dom_map The domain of each variable.
  * \return An integer set that can cover all the possible values of e.
  */
-IntSet EvalSet(PrimExpr e,
-               const Map<IterVar, IntSet>& dom_map);
+IntSet EvalSet(PrimExpr e, const Map<IterVar, IntSet>& dom_map);
 /*!
  * \brief Same as EvalSet, but takes unordered_map
  *
@@ -171,7 +177,6 @@ IntSet EvalSet(PrimExpr e,
  */
 IntSet EvalSet(PrimExpr e,
                const std::unordered_map<const tir::VarNode*, IntSet>& dom_map);
-
 /*!
  * \brief Find an symbolic integer set that contains is union over
  *  all the possible conditional values in dom_map.
@@ -202,7 +207,6 @@ IntSet EvalSet(IntSet s,
  */
 IntSet EvalSet(Range r,
                const std::unordered_map<const VarNode*, IntSet>& dom_map);
-
 /*! \brief Map from Expr to IntSet */
 using ExprIntSetMap = std::unordered_map<PrimExpr, IntSet, ObjectHash, ObjectEqual>;
 /*!
diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc
index 83dfc64..9199bac 100644
--- a/src/arith/analyzer.cc
+++ b/src/arith/analyzer.cc
@@ -36,31 +36,31 @@ Analyzer::Analyzer()
       int_set(this) {
 }
 
-void Analyzer::Bind(const Var& var, const PrimExpr& expr) {
+void Analyzer::Bind(const Var& var, const PrimExpr& expr, bool override) {
   PrimExpr new_expr = expr;
   new_expr = this->canonical_simplify(new_expr);
   new_expr = this->rewrite_simplify(new_expr);
 
-  this->const_int_bound.Update(var, this->const_int_bound(new_expr));
-  this->modular_set.Update(var, this->modular_set(new_expr));
-  this->rewrite_simplify.Update(var, new_expr);
-  this->canonical_simplify.Update(var, new_expr);
+  this->const_int_bound.Update(var, this->const_int_bound(new_expr), override);
+  this->modular_set.Update(var, this->modular_set(new_expr), override);
+  this->rewrite_simplify.Update(var, new_expr, override);
+  this->canonical_simplify.Update(var, new_expr, override);
 }
 
-void Analyzer::Bind(const Var& var, const Range& range) {
+void Analyzer::Bind(const Var& var, const Range& range, bool override) {
   CHECK(range.defined());
   if (tir::is_one(range->extent)) {
-    this->Bind(var, range->min);
+    this->Bind(var, range->min, override);
   } else {
-    this->const_int_bound.Bind(var, range);
+    this->const_int_bound.Bind(var, range, override);
   }
   // skip modular_set
   // skip rewrite simplify
 }
 
-void Analyzer::Bind(const Map<Var, Range>& variables) {
+void Analyzer::Bind(const Map<Var, Range>& variables, bool override) {
   for (const auto& iter : variables) {
-    this->Bind(iter.first, iter.second);
+    this->Bind(iter.first, iter.second, override);
   }
 }
 
@@ -92,6 +92,15 @@ bool Analyzer::CanProveGreaterEqual(const PrimExpr& expr, int64_t lower_bound) {
   return false;
 }
 
+bool Analyzer::CanProveLess(const PrimExpr& expr, int64_t upper_bound) {
+  if (const auto* ptr = expr.as<tir::IntImmNode>()) {
+    return ptr->value < upper_bound;
+  }
+  auto bd = this->const_int_bound(this->rewrite_simplify(expr));
+  if (bd->max_value < upper_bound) return true;
+  return false;
+}
+
 bool Analyzer::CanProve(const PrimExpr& expr) {
   if (const auto* ptr = expr.as<IntImmNode>()) {
     return ptr->value != 0;
diff --git a/src/arith/const_int_bound.cc b/src/arith/const_int_bound.cc
index 57dfc15..bb7c3dd 100644
--- a/src/arith/const_int_bound.cc
+++ b/src/arith/const_int_bound.cc
@@ -99,13 +99,13 @@ class ConstIntBoundAnalyzer::Impl :
     }
   };
 
-  void Bind(const Var& var, const Range& range) {
+  void Bind(const Var& var, const Range& range, bool override) {
     Entry a = VisitExpr(range->min);
     Entry b = VisitExpr(range->extent);
     Entry ret;
     ret.min_value = a.min_value;
     ret.max_value = InfAwareAdd(a.max_value, InfAwareAdd(b.max_value, -1));
-    Update(var, ret, false);
+    Update(var, ret, override);
   }
 
   void Update(const Var& var,
@@ -150,10 +150,12 @@ class ConstIntBoundAnalyzer::Impl :
       const PrimExprNode* op = expr.as<PrimExprNode>();
       auto val = bound_->find(op);
       if (val != bound_->end()) {
-        CHECK(val->second->min_value == res.min_value &&
-              val->second->max_value == res.max_value)
-          << "Detected bound for " << expr
-          << "conflicts with memorization";
+        auto everything = Everything(op->dtype);
+        CHECK(
+            (val->second->min_value == res.min_value && val->second->max_value == res.max_value) ||
+            (val->second->min_value == everything.min_value &&
+             val->second->max_value == everything.max_value))
+            << "Detected bound for " << expr << "conflicts with memorization";
       }
       (*bound_)[op] = ConstIntBound(res.min_value, res.max_value);
     }
@@ -574,8 +576,8 @@ void ConstIntBoundAnalyzer::Update(const Var& var,
   impl_->Update(var, info, override);
 }
 
-void ConstIntBoundAnalyzer::Bind(const Var& var, const Range& range) {
-  impl_->Bind(var, range);
+void ConstIntBoundAnalyzer::Bind(const Var& var, const Range& range, bool override) {
+  impl_->Bind(var, range, override);
 }
 
 std::function<void()> ConstIntBoundAnalyzer::EnterConstraint(const PrimExpr& constraint) {
diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc
index 027259a..d2d43d6 100644
--- a/src/arith/int_set.cc
+++ b/src/arith/int_set.cc
@@ -311,6 +311,16 @@ inline IntervalSet Combine<tir::FloorModNode>(Analyzer* analyzer,
       LOG(FATAL) << "Modular by zero in CombineInterval Mod";
     }
     if (analyzer->CanProveGreaterEqual(divisor, 0)) {
+      if (divisor.as<tir::IntImmNode>()) {
+        // If a_min / b == a_max / b == q (floor division), a mod b = a - q * b on [a_min, a_max].
+        auto qmax = floordiv(a->max_value, divisor);
+        auto qmin = floordiv(a->min_value, divisor);
+        if (analyzer->CanProve(qmax == qmin)) {
+          auto tmax = a->max_value - divisor * qmin;
+          auto tmin = a->min_value - divisor * qmin;
+          return IntervalSet(tmin, tmax);
+        }
+      }
       return IntervalSet(make_zero(divisor.dtype()), divisor - 1);
     } else {
       PrimExpr bound = abs(divisor) - 1;
diff --git a/src/te/operation/compute_op.cc b/src/te/operation/compute_op.cc
index 2d9f13b..1248547 100644
--- a/src/te/operation/compute_op.cc
+++ b/src/te/operation/compute_op.cc
@@ -231,7 +231,7 @@ void ComputeOpNode::PropBoundToInputs(
           // undefined behaviour), so we can intersect the estimated set of the argument with the
           // range expected by the tensor. However, intersection may result in overly complex
           // expressions, so we perform a more relaxed form of intersection.
-          IntSet arg_intset = EvalSet(call->args[i], dom_map);
+          IntSet arg_intset = analyzer->int_set(call->args[i], ConvertDomMap(dom_map));
           const arith::IntervalSetNode* arg_interval = arg_intset.as<arith::IntervalSetNode>();
           if (arg_interval) {
             PrimExpr shape_i_min_value = make_zero(t->shape[i].dtype());
@@ -239,12 +239,14 @@ void ComputeOpNode::PropBoundToInputs(
             PrimExpr min_value = arg_interval->min_value;
             PrimExpr max_value = arg_interval->max_value;
             // Prefer the shape bounds only when we can prove they are tighter.
-            if (arith::is_neg_inf(min_value) ||
-                analyzer->CanProve(shape_i_min_value >= min_value)) {
+            // We must update both ends of the bound together. Here is a counterexample: shape_i is
+            // [0, 0] and arg_interval is [threadIdx.y, threadIdx.y], where threadIdx.y's range is
+            // [0, 7]. If we allowed updating one end, the bound would become [threadIdx.y, 0],
+            // which is awkward for further analysis.
+            if ((arith::is_pos_inf(max_value) && arith::is_neg_inf(min_value)) ||
+                (analyzer->CanProve(shape_i_min_value >= min_value) &&
+                 analyzer->CanProve(shape_i_max_value <= max_value))) {
               min_value = shape_i_min_value;
-            }
-            if (arith::is_pos_inf(max_value) ||
-                analyzer->CanProve(shape_i_max_value <= max_value)) {
               max_value = shape_i_max_value;
             }
             dom.data[i].push_back(IntSet::interval(min_value, max_value));
diff --git a/src/te/schedule/bound.cc b/src/te/schedule/bound.cc
index 4dde945..552d7b7 100644
--- a/src/te/schedule/bound.cc
+++ b/src/te/schedule/bound.cc
@@ -137,7 +137,7 @@ void InferRootBound(const Stage& stage,
   Array<IterVar> stage_attach = ctx.attach_path.at(stage->op);
   // The parent set.
   for (const Operation& op : consumers) {
-    std::unordered_map<const VarNode*, IntSet> relax_set;
+    Map<Var, IntSet> relax_set;
     std::unordered_map<IterVar, IntSet> up_state;
     bool found_attach = false;
     CHECK(ctx.op2stage_.count(op.get()));
@@ -176,9 +176,9 @@ void InferRootBound(const Stage& stage,
           << "InferBound requires every leaf iter var's min equals 0, "
           << "call schedule.normalize to achieve this.";
       if (NeedRelax(iv, found_attach, ctx.bind_map, scope)) {
-        relax_set[iv->var.get()] = IntSet::range(vrange);
+        relax_set.Set(iv->var, IntSet::range(vrange));
         if (ctx.bind_map.count(iv)) {
-          relax_set[ctx.bind_map.at(iv)->var.get()] = IntSet::range(vrange);
+          relax_set.Set(ctx.bind_map.at(iv)->var, IntSet::range(vrange));
         }
       }
     }
@@ -190,6 +190,9 @@ void InferRootBound(const Stage& stage,
     // Relax if needed.
     std::unordered_map<const VarNode*, IntSet> dom_map;
     arith::Analyzer analyzer;
+    for (auto entry : *rmap) {
+      analyzer.Bind(entry.first->var, entry.second);
+    }
     for (auto iv : op->root_iter_vars()) {
       Range r;
       if (up_state.count(iv)) {
@@ -198,11 +201,13 @@ void InferRootBound(const Stage& stage,
         r = iv->dom;
       }
       if (relax_set.size() != 0) {
-        dom_map[iv->var.get()] = EvalSet(r, relax_set);
+        dom_map[iv->var.get()] = IntSet::interval(
+            analyzer.int_set(r->min, relax_set).min(),
+            analyzer.int_set(r->min + r->extent - 1, relax_set).max());
       } else {
         dom_map[iv->var.get()] = IntSet::range(r);
       }
-      analyzer.Bind(iv->var, r);
+      analyzer.Bind(iv->var, r, true);
     }
     op->PropBoundToInputs(op, &analyzer, dom_map, &tmap);
   }
diff --git a/src/te/schedule/message_passing.cc b/src/te/schedule/message_passing.cc
index 1453ed0..6ae7464 100644
--- a/src/te/schedule/message_passing.cc
+++ b/src/te/schedule/message_passing.cc
@@ -579,11 +579,15 @@ std::vector<PrimExpr> MakeBoundCheck(
   PassUpBoundCheck(stage, dom_map, &bound_state, &analyzer);
 
   std::vector<PrimExpr> preds;
-  std::unordered_map<const VarNode*, IntSet> iset_dmap;
+  Map<Var, IntSet> iset_dmap;
 
   // setup domain map for set analysis
   for (const auto& kv : dom_map) {
-    iset_dmap[kv.first->var.get()] = IntSet::range(kv.second);
+    iset_dmap.Set(kv.first->var, IntSet::range(kv.second));
+  }
+
+  for (auto entry : dom_map) {
+    analyzer.Bind(entry.first->var, entry.second);
   }
 
   for (const IterVar& iv : stage->all_iter_vars) {
@@ -591,7 +595,7 @@ std::vector<PrimExpr> MakeBoundCheck(
     if (bound_state.at(iv)) {
       Range dom = dom_map.at(iv);
       PrimExpr value = value_map.at(iv) - dom->min;
-      PrimExpr vmax = EvalSet(value, iset_dmap).max();
+      PrimExpr vmax = analyzer.int_set(value, iset_dmap).max();
       if (vmax.dtype() != value.dtype() || !analyzer.CanProve(vmax < dom->extent)) {
         preds.emplace_back(value < dom->extent);
       }
@@ -603,7 +607,7 @@ std::vector<PrimExpr> MakeBoundCheck(
     CHECK(iv->dom.defined());
     if (!skip_ivar_domain && !IsRangeSame(iv->dom, dom)) {
       PrimExpr value = value_map.at(iv) - iv->dom->min;
-      IntSet s = EvalSet(value, iset_dmap);
+      IntSet s = analyzer.int_set(value, iset_dmap);
       PrimExpr vmin = s.min();
       PrimExpr vmax = s.max();
       // The range of `value` resides in [vmin, vmax]
diff --git a/tests/python/unittest/test_arith_intset.py b/tests/python/unittest/test_arith_intset.py
index e57dcef..9919c7b 100644
--- a/tests/python/unittest/test_arith_intset.py
+++ b/tests/python/unittest/test_arith_intset.py
@@ -90,6 +90,20 @@ def test_mod():
 
     flm = tvm.te.floormod
     ck.verify(flm(x, 10), {x : tvm.arith.IntervalSet(-10, 10)}, (0, 9))
+    ck.verify(flm(x, 10), {x : tvm.arith.IntervalSet(3, 5)}, (3, 5))
+    ck.verify(flm(x, 10), {x : tvm.arith.IntervalSet(13, 15)}, (3, 5))
+    ck.verify(flm(x, 10), {x : tvm.arith.IntervalSet(3, 15)}, (0, 9))
+    ck.verify(flm(x, 10), {x : tvm.arith.IntervalSet(3, 11)}, (0, 9))
+    ck.verify(flm(x, 10), {x : tvm.arith.IntervalSet(1, 21)}, (0, 9))
+
+    floordiv = tvm.te.floordiv
+    z = te.var("z")
+    ck.analyzer.bind(x, tvm.ir.Range.make_by_min_extent(0, 3))
+    ck.verify(flm(y, 8), {y : tvm.arith.IntervalSet(z*8+x*4, z*8+x*4+3)},
+              (0, 7))
+    ck1 = IntSetChecker()
+    ck1.analyzer.bind(x, tvm.ir.Range.make_by_min_extent(0, 2))
+    ck1.verify(flm(y, 8), {y : tvm.arith.IntervalSet(z*8+x*4, z*8+x*4+3)}, (x*4, x*4+3))
 
 
 def test_max_min():
diff --git a/tests/python/unittest/test_te_schedule_bound_inference_tiling.py b/tests/python/unittest/test_te_schedule_bound_inference_tiling.py
new file mode 100644
index 0000000..3893bb6
--- /dev/null
+++ b/tests/python/unittest/test_te_schedule_bound_inference_tiling.py
@@ -0,0 +1,60 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+from tvm import te
+
+def test_bound_tile_mod():
+    def compute(M_tiles, N_tiles, factor, dtype):
+        # Algo
+        M = M_tiles * factor
+        N = N_tiles * factor
+
+        A = tvm.te.placeholder((N, M), name='A', dtype=dtype)
+        C = tvm.te.compute((N, M), lambda n, m: A[n, m], name='C')
+        s = tvm.te.create_schedule(C.op)
+
+        return s, A, C
+
+    def schedule(s, factor, padding, A, C):
+        C_local = s.cache_write(C, "local")
+
+        n, m = C.op.axis
+        bn, bm, ni, mi = s[C].tile(n, m, factor, factor)
+        nio, nii = s[C].split(ni, 2)
+        n = s[C].fuse(nii, mi)
+        C_shared = s.cache_write(C, "shared")
+        bn, bm, ni, mi = C_shared.op.axis       
+        s[C_shared].storage_align(ni, factor * 2, padding)
+
+        n, m = s[C].op.axis
+        bn, bm, ni, mi = s[C].tile(n, m, factor, factor)
+        s[C].set_scope("global")
+        niio, niii = s[C].split(ni, 32)
+        s[C_shared].compute_at(s[C], niio)
+
+        return s
+
+    s, A, C = compute(2, 2, 128, "float16")
+    s = schedule(s, 128, 8, A, C)
+    bounds = tvm.te.schedule.InferBound(s)
+    check = (bounds[s.stages[2].op.axis[2]].extent == 16)
+    if(not check):
+        print(tvm.lower(s, [A, C], simple_mode=True))
+    assert(check)
+
+if __name__ == "__main__":
+    test_bound_tile_mod()