You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by tq...@apache.org on 2020/04/27 00:01:03 UTC
[incubator-tvm] branch master updated: Improve IntervalSet's floormod (#5367)
This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new 37e5754 Improve IntervalSet's floormod (#5367)
37e5754 is described below
commit 37e5754862d438dbaec63fdcd3367993a1070981
Author: yongfeng-nv <49...@users.noreply.github.com>
AuthorDate: Sun Apr 26 20:00:52 2020 -0400
Improve IntervalSet's floormod (#5367)
---
include/tvm/arith/analyzer.h | 25 +++++++--
include/tvm/arith/int_set.h | 12 +++--
src/arith/analyzer.cc | 29 +++++++----
src/arith/const_int_bound.cc | 18 ++++---
src/arith/int_set.cc | 10 ++++
src/te/operation/compute_op.cc | 14 ++---
src/te/schedule/bound.cc | 15 ++++--
src/te/schedule/message_passing.cc | 12 +++--
tests/python/unittest/test_arith_intset.py | 14 +++++
.../test_te_schedule_bound_inference_tiling.py | 60 ++++++++++++++++++++++
10 files changed, 168 insertions(+), 41 deletions(-)
diff --git a/include/tvm/arith/analyzer.h b/include/tvm/arith/analyzer.h
index 6ca3ba9..c08c0d6 100644
--- a/include/tvm/arith/analyzer.h
+++ b/include/tvm/arith/analyzer.h
@@ -138,8 +138,9 @@ class ConstIntBoundAnalyzer {
*
* \param var The variable.
* \param range The range we bind to.
+ * \param override Whether we allow overriding an existing var's range.
*/
- TVM_DLL void Bind(const Var& var, const Range& range);
+ TVM_DLL void Bind(const Var& var, const Range& range, bool override = false);
private:
friend class Analyzer;
@@ -411,8 +412,9 @@ class TVM_DLL Analyzer {
*
* \param var The variable.
* \param expr The expression we bind to.
+ * \param override Whether we allow overriding an existing var's expression.
*/
- void Bind(const Var& var, const PrimExpr& expr);
+ void Bind(const Var& var, const PrimExpr& expr, bool override = false);
/*!
* \brief Notify all the sub-analyzers that var
* is created and binded to a range.
@@ -421,14 +423,16 @@ class TVM_DLL Analyzer {
*
* \param var The variable.
* \param range The range we bind to.
+ * \param override Whether we allow overriding an existing var's range.
*/
- void Bind(const Var& var, const Range& range);
+ void Bind(const Var& var, const Range& range, bool override = false);
/*!
* \brief Bind all the vars in the Map
*
* \param variables The {variable -> range} map.
+ * \param override Whether we allow overriding an existing var's range.
*/
- void Bind(const Map<Var, Range>& variables);
+ void Bind(const Map<Var, Range>& variables, bool override = false);
/*!
* \brief Whether can we prove expr >= val.
@@ -443,6 +447,19 @@ class TVM_DLL Analyzer {
*/
bool CanProveGreaterEqual(const PrimExpr& expr, int64_t lower_bound);
/*!
+ * \brief Whether we can prove expr < upper_bound.
+ *
+ * This complements CanProveGreaterEqual: together they can show that
+ * an expression lies within a half-open constant range [lo, hi).
+ *
+ * \param expr The expression.
+ * \param upper_bound The upper bound.
+ * \return Whether we can prove it.
+ *
+ * \note Analyzer will call into sub-analyzers to get the result.
+ */
+ bool CanProveLess(const PrimExpr& expr, int64_t upper_bound);
+ /*!
* \brief Whether can we prove condition.
*
* \param cond The expression to be proved.
diff --git a/include/tvm/arith/int_set.h b/include/tvm/arith/int_set.h
index 86ef906..ab73b07 100644
--- a/include/tvm/arith/int_set.h
+++ b/include/tvm/arith/int_set.h
@@ -153,6 +153,13 @@ class IntSet : public ObjectRef {
// Integer set legacy API.
//------------------------------------------------
/*!
+ * \brief Convert std::unordered_map<const VarNode*, IntSet> to Map<Var, IntSet>
+ *
+ * \param dom_map The domain map to convert.
+ * \return The converted map.
+ */
+Map<Var, IntSet> ConvertDomMap(const std::unordered_map<const VarNode*, IntSet>& dom_map);
+/*!
* \brief Find an symbolic integer set that contains all possible values of
* e given the domain of each iteration variables.
*
@@ -160,8 +167,7 @@ class IntSet : public ObjectRef {
* \param dom_map The domain of each variable.
* \return An integer set that can cover all the possible values of e.
*/
-IntSet EvalSet(PrimExpr e,
- const Map<IterVar, IntSet>& dom_map);
+IntSet EvalSet(PrimExpr e, const Map<IterVar, IntSet>& dom_map);
/*!
* \brief Same as EvalSet, but takes unordered_map
*
@@ -171,7 +177,6 @@ IntSet EvalSet(PrimExpr e,
*/
IntSet EvalSet(PrimExpr e,
const std::unordered_map<const tir::VarNode*, IntSet>& dom_map);
-
/*!
* \brief Find an symbolic integer set that contains is union over
* all the possible conditional values in dom_map.
@@ -202,7 +207,6 @@ IntSet EvalSet(IntSet s,
*/
IntSet EvalSet(Range r,
const std::unordered_map<const VarNode*, IntSet>& dom_map);
-
/*! \brief Map from Expr to IntSet */
using ExprIntSetMap = std::unordered_map<PrimExpr, IntSet, ObjectHash, ObjectEqual>;
/*!
diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc
index 83dfc64..9199bac 100644
--- a/src/arith/analyzer.cc
+++ b/src/arith/analyzer.cc
@@ -36,31 +36,31 @@ Analyzer::Analyzer()
int_set(this) {
}
-void Analyzer::Bind(const Var& var, const PrimExpr& expr) {
+void Analyzer::Bind(const Var& var, const PrimExpr& expr, bool override) {
PrimExpr new_expr = expr;
new_expr = this->canonical_simplify(new_expr);
new_expr = this->rewrite_simplify(new_expr);
- this->const_int_bound.Update(var, this->const_int_bound(new_expr));
- this->modular_set.Update(var, this->modular_set(new_expr));
- this->rewrite_simplify.Update(var, new_expr);
- this->canonical_simplify.Update(var, new_expr);
+ this->const_int_bound.Update(var, this->const_int_bound(new_expr), override);
+ this->modular_set.Update(var, this->modular_set(new_expr), override);
+ this->rewrite_simplify.Update(var, new_expr, override);
+ this->canonical_simplify.Update(var, new_expr, override);
}
-void Analyzer::Bind(const Var& var, const Range& range) {
+void Analyzer::Bind(const Var& var, const Range& range, bool override) {
CHECK(range.defined());
if (tir::is_one(range->extent)) {
- this->Bind(var, range->min);
+ this->Bind(var, range->min, override);
} else {
- this->const_int_bound.Bind(var, range);
+ this->const_int_bound.Bind(var, range, override);
}
// skip modular_set
// skip rewrite simplify
}
-void Analyzer::Bind(const Map<Var, Range>& variables) {
+void Analyzer::Bind(const Map<Var, Range>& variables, bool override) {
for (const auto& iter : variables) {
- this->Bind(iter.first, iter.second);
+ this->Bind(iter.first, iter.second, override);
}
}
@@ -92,6 +92,15 @@ bool Analyzer::CanProveGreaterEqual(const PrimExpr& expr, int64_t lower_bound) {
return false;
}
+bool Analyzer::CanProveLess(const PrimExpr& expr, int64_t upper_bound) {
+ if (const auto* ptr = expr.as<tir::IntImmNode>()) {
+ return ptr->value < upper_bound;
+ }
+ auto bd = this->const_int_bound(this->rewrite_simplify(expr));
+ if (bd->max_value < upper_bound) return true;
+ return false;
+}
+
bool Analyzer::CanProve(const PrimExpr& expr) {
if (const auto* ptr = expr.as<IntImmNode>()) {
return ptr->value != 0;
diff --git a/src/arith/const_int_bound.cc b/src/arith/const_int_bound.cc
index 57dfc15..bb7c3dd 100644
--- a/src/arith/const_int_bound.cc
+++ b/src/arith/const_int_bound.cc
@@ -99,13 +99,13 @@ class ConstIntBoundAnalyzer::Impl :
}
};
- void Bind(const Var& var, const Range& range) {
+ void Bind(const Var& var, const Range& range, bool override) {
Entry a = VisitExpr(range->min);
Entry b = VisitExpr(range->extent);
Entry ret;
ret.min_value = a.min_value;
ret.max_value = InfAwareAdd(a.max_value, InfAwareAdd(b.max_value, -1));
- Update(var, ret, false);
+ Update(var, ret, override);
}
void Update(const Var& var,
@@ -150,10 +150,12 @@ class ConstIntBoundAnalyzer::Impl :
const PrimExprNode* op = expr.as<PrimExprNode>();
auto val = bound_->find(op);
if (val != bound_->end()) {
- CHECK(val->second->min_value == res.min_value &&
- val->second->max_value == res.max_value)
- << "Detected bound for " << expr
- << "conflicts with memorization";
+ auto everything = Everything(op->dtype);
+ CHECK(
+ (val->second->min_value == res.min_value && val->second->max_value == res.max_value) ||
+ (val->second->min_value == everything.min_value &&
+ val->second->max_value == everything.max_value))
+ << "Detected bound for " << expr << " conflicts with memorization";
}
(*bound_)[op] = ConstIntBound(res.min_value, res.max_value);
}
@@ -574,8 +576,8 @@ void ConstIntBoundAnalyzer::Update(const Var& var,
impl_->Update(var, info, override);
}
-void ConstIntBoundAnalyzer::Bind(const Var& var, const Range& range) {
- impl_->Bind(var, range);
+void ConstIntBoundAnalyzer::Bind(const Var& var, const Range& range, bool override) {
+ impl_->Bind(var, range, override);
}
std::function<void()> ConstIntBoundAnalyzer::EnterConstraint(const PrimExpr& constraint) {
diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc
index 027259a..d2d43d6 100644
--- a/src/arith/int_set.cc
+++ b/src/arith/int_set.cc
@@ -311,6 +311,16 @@ inline IntervalSet Combine<tir::FloorModNode>(Analyzer* analyzer,
LOG(FATAL) << "Modular by zero in CombineInterval Mod";
}
if (analyzer->CanProveGreaterEqual(divisor, 0)) {
+ if (divisor.as<tir::IntImmNode>()) {
+ // floormod(a, b) = a - floordiv(a, b) * b; if floordiv(a_max, b) == floordiv(a_min, b), the quotient q is constant, so the result is [a_min - q * b, a_max - q * b].
+ auto qmax = floordiv(a->max_value, divisor);
+ auto qmin = floordiv(a->min_value, divisor);
+ if (analyzer->CanProve(qmax == qmin)) {
+ auto tmax = a->max_value - divisor * qmin;
+ auto tmin = a->min_value - divisor * qmin;
+ return IntervalSet(tmin, tmax);
+ }
+ }
return IntervalSet(make_zero(divisor.dtype()), divisor - 1);
} else {
PrimExpr bound = abs(divisor) - 1;
diff --git a/src/te/operation/compute_op.cc b/src/te/operation/compute_op.cc
index 2d9f13b..1248547 100644
--- a/src/te/operation/compute_op.cc
+++ b/src/te/operation/compute_op.cc
@@ -231,7 +231,7 @@ void ComputeOpNode::PropBoundToInputs(
// undefined behaviour), so we can intersect the estimated set of the argument with the
// range expected by the tensor. However, intersection may result in overly complex
// expressions, so we perform a more relaxed form of intersection.
- IntSet arg_intset = EvalSet(call->args[i], dom_map);
+ IntSet arg_intset = analyzer->int_set(call->args[i], ConvertDomMap(dom_map));
const arith::IntervalSetNode* arg_interval = arg_intset.as<arith::IntervalSetNode>();
if (arg_interval) {
PrimExpr shape_i_min_value = make_zero(t->shape[i].dtype());
@@ -239,12 +239,14 @@ void ComputeOpNode::PropBoundToInputs(
PrimExpr min_value = arg_interval->min_value;
PrimExpr max_value = arg_interval->max_value;
// Prefer the shape bounds only when we can prove they are tighter.
- if (arith::is_neg_inf(min_value) ||
- analyzer->CanProve(shape_i_min_value >= min_value)) {
+ // We must update both ends of the bound together. Here is a counterexample: shape_i is
+ // [0, 0] and arg_interval is [threadIdx.y, threadIdx.y], where threadIdx.y's range is
+ // [0, 7]. If we allowed updating only one end, the bound would become [threadIdx.y, 0],
+ // which is awkward for further analysis.
+ if ((arith::is_pos_inf(max_value) && arith::is_neg_inf(min_value)) ||
+ (analyzer->CanProve(shape_i_min_value >= min_value) &&
+ analyzer->CanProve(shape_i_max_value <= max_value))) {
min_value = shape_i_min_value;
- }
- if (arith::is_pos_inf(max_value) ||
- analyzer->CanProve(shape_i_max_value <= max_value)) {
max_value = shape_i_max_value;
}
dom.data[i].push_back(IntSet::interval(min_value, max_value));
diff --git a/src/te/schedule/bound.cc b/src/te/schedule/bound.cc
index 4dde945..552d7b7 100644
--- a/src/te/schedule/bound.cc
+++ b/src/te/schedule/bound.cc
@@ -137,7 +137,7 @@ void InferRootBound(const Stage& stage,
Array<IterVar> stage_attach = ctx.attach_path.at(stage->op);
// The parent set.
for (const Operation& op : consumers) {
- std::unordered_map<const VarNode*, IntSet> relax_set;
+ Map<Var, IntSet> relax_set;
std::unordered_map<IterVar, IntSet> up_state;
bool found_attach = false;
CHECK(ctx.op2stage_.count(op.get()));
@@ -176,9 +176,9 @@ void InferRootBound(const Stage& stage,
<< "InferBound requires every leaf iter var's min equals 0, "
<< "call schedule.normalize to achieve this.";
if (NeedRelax(iv, found_attach, ctx.bind_map, scope)) {
- relax_set[iv->var.get()] = IntSet::range(vrange);
+ relax_set.Set(iv->var, IntSet::range(vrange));
if (ctx.bind_map.count(iv)) {
- relax_set[ctx.bind_map.at(iv)->var.get()] = IntSet::range(vrange);
+ relax_set.Set(ctx.bind_map.at(iv)->var, IntSet::range(vrange));
}
}
}
@@ -190,6 +190,9 @@ void InferRootBound(const Stage& stage,
// Relax if needed.
std::unordered_map<const VarNode*, IntSet> dom_map;
arith::Analyzer analyzer;
+ for (const auto& entry : *rmap) {
+ analyzer.Bind(entry.first->var, entry.second);
+ }
for (auto iv : op->root_iter_vars()) {
Range r;
if (up_state.count(iv)) {
@@ -198,11 +201,13 @@ void InferRootBound(const Stage& stage,
r = iv->dom;
}
if (relax_set.size() != 0) {
- dom_map[iv->var.get()] = EvalSet(r, relax_set);
+ dom_map[iv->var.get()] = IntSet::interval(
+ analyzer.int_set(r->min, relax_set).min(),
+ analyzer.int_set(r->min + r->extent - 1, relax_set).max());
} else {
dom_map[iv->var.get()] = IntSet::range(r);
}
- analyzer.Bind(iv->var, r);
+ analyzer.Bind(iv->var, r, true);
}
op->PropBoundToInputs(op, &analyzer, dom_map, &tmap);
}
diff --git a/src/te/schedule/message_passing.cc b/src/te/schedule/message_passing.cc
index 1453ed0..6ae7464 100644
--- a/src/te/schedule/message_passing.cc
+++ b/src/te/schedule/message_passing.cc
@@ -579,11 +579,15 @@ std::vector<PrimExpr> MakeBoundCheck(
PassUpBoundCheck(stage, dom_map, &bound_state, &analyzer);
std::vector<PrimExpr> preds;
- std::unordered_map<const VarNode*, IntSet> iset_dmap;
+ Map<Var, IntSet> iset_dmap;
// setup domain map for set analysis
for (const auto& kv : dom_map) {
- iset_dmap[kv.first->var.get()] = IntSet::range(kv.second);
+ iset_dmap.Set(kv.first->var, IntSet::range(kv.second));
+ }
+
+ for (const auto& entry : dom_map) {
+ analyzer.Bind(entry.first->var, entry.second);
}
for (const IterVar& iv : stage->all_iter_vars) {
@@ -591,7 +595,7 @@ std::vector<PrimExpr> MakeBoundCheck(
if (bound_state.at(iv)) {
Range dom = dom_map.at(iv);
PrimExpr value = value_map.at(iv) - dom->min;
- PrimExpr vmax = EvalSet(value, iset_dmap).max();
+ PrimExpr vmax = analyzer.int_set(value, iset_dmap).max();
if (vmax.dtype() != value.dtype() || !analyzer.CanProve(vmax < dom->extent)) {
preds.emplace_back(value < dom->extent);
}
@@ -603,7 +607,7 @@ std::vector<PrimExpr> MakeBoundCheck(
CHECK(iv->dom.defined());
if (!skip_ivar_domain && !IsRangeSame(iv->dom, dom)) {
PrimExpr value = value_map.at(iv) - iv->dom->min;
- IntSet s = EvalSet(value, iset_dmap);
+ IntSet s = analyzer.int_set(value, iset_dmap);
PrimExpr vmin = s.min();
PrimExpr vmax = s.max();
// The range of `value` resides in [vmin, vmax]
diff --git a/tests/python/unittest/test_arith_intset.py b/tests/python/unittest/test_arith_intset.py
index e57dcef..9919c7b 100644
--- a/tests/python/unittest/test_arith_intset.py
+++ b/tests/python/unittest/test_arith_intset.py
@@ -90,6 +90,20 @@ def test_mod():
flm = tvm.te.floormod
ck.verify(flm(x, 10), {x : tvm.arith.IntervalSet(-10, 10)}, (0, 9))
+ ck.verify(flm(x, 10), {x : tvm.arith.IntervalSet(3, 5)}, (3, 5))
+ ck.verify(flm(x, 10), {x : tvm.arith.IntervalSet(13, 15)}, (3, 5))
+ ck.verify(flm(x, 10), {x : tvm.arith.IntervalSet(3, 15)}, (0, 9))
+ ck.verify(flm(x, 10), {x : tvm.arith.IntervalSet(3, 11)}, (0, 9))
+ ck.verify(flm(x, 10), {x : tvm.arith.IntervalSet(1, 21)}, (0, 9))
+
+ floordiv = tvm.te.floordiv
+ z = te.var("z")
+ ck.analyzer.bind(x, tvm.ir.Range.make_by_min_extent(0, 3))
+ ck.verify(flm(y, 8), {y : tvm.arith.IntervalSet(z*8+x*4, z*8+x*4+3)},
+ (0, 7))
+ ck1 = IntSetChecker()
+ ck1.analyzer.bind(x, tvm.ir.Range.make_by_min_extent(0, 2))
+ ck1.verify(flm(y, 8), {y : tvm.arith.IntervalSet(z*8+x*4, z*8+x*4+3)}, (x*4, x*4+3))
def test_max_min():
diff --git a/tests/python/unittest/test_te_schedule_bound_inference_tiling.py b/tests/python/unittest/test_te_schedule_bound_inference_tiling.py
new file mode 100644
index 0000000..3893bb6
--- /dev/null
+++ b/tests/python/unittest/test_te_schedule_bound_inference_tiling.py
@@ -0,0 +1,60 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+from tvm import te
+
+def test_bound_tile_mod():
+ def compute(M_tiles, N_tiles, factor, dtype):
+ # Algo
+ M = M_tiles * factor
+ N = N_tiles * factor
+
+ A = tvm.te.placeholder((N, M), name='A', dtype=dtype)
+ C = tvm.te.compute((N, M), lambda n, m: A[n, m], name='C')
+ s = tvm.te.create_schedule(C.op)
+
+ return s, A, C
+
+ def schedule(s, factor, padding, A, C):
+ C_local = s.cache_write(C, "local")
+
+ n, m = C.op.axis
+ bn, bm, ni, mi = s[C].tile(n, m, factor, factor)
+ nio, nii = s[C].split(ni, 2)
+ n = s[C].fuse(nii, mi)
+ C_shared = s.cache_write(C, "shared")
+ bn, bm, ni, mi = C_shared.op.axis
+ s[C_shared].storage_align(ni, factor * 2, padding)
+
+ n, m = s[C].op.axis
+ bn, bm, ni, mi = s[C].tile(n, m, factor, factor)
+ s[C].set_scope("global")
+ niio, niii = s[C].split(ni, 32)
+ s[C_shared].compute_at(s[C], niio)
+
+ return s
+
+ s, A, C = compute(2, 2, 128, "float16")
+ s = schedule(s, 128, 8, A, C)
+ bounds = tvm.te.schedule.InferBound(s)
+ check = (bounds[s.stages[2].op.axis[2]].extent == 16)
+ if not check:
+ print(tvm.lower(s, [A, C], simple_mode=True))
+ assert check
+
+if __name__ == "__main__":
+ test_bound_tile_mod()