You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/03/26 14:05:15 UTC

[GitHub] [tvm] spectrometerHBH opened a new pull request #7752: [ARITH] detect iter affine map with predicate

spectrometerHBH opened a new pull request #7752:
URL: https://github.com/apache/tvm/pull/7752


   Update on previous PR https://github.com/apache/tvm/pull/6667
   
   Enhance the detect_iter_map utility to be able to detect split patterns with predicates.
   
   cc @tqchen 


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] tqchen merged pull request #7752: [ARITH] detect iter affine map with predicate

Posted by GitBox <gi...@apache.org>.
tqchen merged pull request #7752:
URL: https://github.com/apache/tvm/pull/7752


   


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] spectrometerHBH commented on pull request #7752: [ARITH] detect iter affine map with predicate

Posted by GitBox <gi...@apache.org>.
spectrometerHBH commented on pull request #7752:
URL: https://github.com/apache/tvm/pull/7752#issuecomment-808608709


   > Thanks @spectrometerHBH There seems to be three places where the ExprComplexity is involved.
   > 
   > * https://github.com/apache/tvm/blob/main/src/arith/solve_linear_inequality.cc#L55
   > * #7469
   > 
   > Given your PrimExprSizeCounter counting is simpler, can we create an ExprComplexity function in analysis and call from there?
   
   I'm not sure whether the current ExprComplexity is compatible with mine. It looks like the current ExprComplexity doesn't count nodes like Load/BufferLoad/Cast, etc.


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] MasterJH5574 commented on a change in pull request #7752: [ARITH] detect iter affine map with predicate

Posted by GitBox <gi...@apache.org>.
MasterJH5574 commented on a change in pull request #7752:
URL: https://github.com/apache/tvm/pull/7752#discussion_r602639137



##########
File path: src/arith/iter_affine_map.cc
##########
@@ -161,59 +160,116 @@ class IterMarkSplitCollector {
   }
 };
 
-// Rewriter to rewrite PrimExpr to IterMapExpr
-// when possible
+/*! \brief Rewriter to rewrite PrimExpr to IterMapExpr when possible */
 class IterMapRewriter : public ExprMutator {
  public:
   using Parent = ExprMutator;
 
   explicit IterMapRewriter(Analyzer* analyzer, const Map<Var, Range>& input_iters)
       : analyzer_(analyzer) {
     for (auto kv : input_iters) {
-      const auto& vrng = kv.second;
-      if (is_zero(vrng->min)) {
-        IterMark mark(kv.first, vrng->extent);
-        var_map_[kv.first] = IterSplitExpr(mark);
+      const Var& var = kv.first;
+      const Range& vrng = kv.second;
+      if (is_one(vrng->extent)) {
+        var_map_[kv.first] = IterSumExpr({}, vrng->min);

Review comment:
       ```suggestion
           var_map_[var] = IterSumExpr({}, vrng->min);
   ```

##########
File path: src/arith/iter_affine_map.cc
##########
@@ -161,59 +160,116 @@ class IterMarkSplitCollector {
   }
 };
 
-// Rewriter to rewrite PrimExpr to IterMapExpr
-// when possible
+/*! \brief Rewriter to rewrite PrimExpr to IterMapExpr when possible */
 class IterMapRewriter : public ExprMutator {
  public:
   using Parent = ExprMutator;
 
   explicit IterMapRewriter(Analyzer* analyzer, const Map<Var, Range>& input_iters)
       : analyzer_(analyzer) {
     for (auto kv : input_iters) {
-      const auto& vrng = kv.second;
-      if (is_zero(vrng->min)) {
-        IterMark mark(kv.first, vrng->extent);
-        var_map_[kv.first] = IterSplitExpr(mark);
+      const Var& var = kv.first;
+      const Range& vrng = kv.second;
+      if (is_one(vrng->extent)) {
+        var_map_[kv.first] = IterSumExpr({}, vrng->min);
+      } else if (is_zero(vrng->min)) {
+        IterMark mark(var, vrng->extent);
+        var_map_[var] = IterSplitExpr(mark);
         input_marks_.push_back(mark);
       } else {
-        IterMark mark(kv.first - vrng->min, vrng->extent);
-        auto sum_expr = ToIterSumExpr(IterSplitExpr(mark));
+        IterMark mark(var - vrng->min, vrng->extent);
+        IterSumExpr sum_expr = ToIterSumExpr(IterSplitExpr(mark));
         sum_expr.CopyOnWrite()->base = vrng->min;
-        var_map_[kv.first] = sum_expr;
+        var_map_[var] = sum_expr;
         input_marks_.push_back(mark);
       }
     }
   }
 
   size_t unresolved_count() const { return unresolved_count_; }
 
-  IterSumExpr Rewrite(PrimExpr expr) {
+  IterSumExpr Rewrite(const PrimExpr& expr) {
     return NormalizeToIterWithOffset(ToIterSumExpr(DirectMutate(expr)));
   }
 
-  bool CheckBijective(const Array<IterSumExpr>& indices) {
-    // This function checks two conditions:
-    // - C0: Each iter mark should be fully covered by non-overlapping splits.
-    // - C1: All of the input iterators are used.
-    //
-    // Example: given x in [0, 8) y in [0, 6)
-    // - indices = [x, x+1, y] won't pass because x and x+1 contribute
-    //   two splits that overlaps with each other.
-    // - indices = [x / 4, x % 4, y] will pass because x / 4 and x % 4
-    //   contribute two non-overlapping splits that covers x.
-    // - indices = [x / 4, x % 4] won't pass because y is not used.
-    //
+  IterSumExpr RewriteIterConstraint(const PrimExpr& expr,
+                                    const PrimExpr& predicate_induced_extent) {
+    return NormalizeToIterOnBoundExpr(ToIterSumExpr(DirectMutate(expr)), predicate_induced_extent);
+  }
+
+  /*!
+   * \brief If require_bijective is true, this function checks two conditions:
+   *   - C0: Each iter mark should be fully covered by non-overlapping splits.
+   *   - C1: All of the input iterators are used.
+   *   Example: given x in [0, 8) y in [0, 6)
+   *   - bindings = [x, x + 1, y] won't pass because x and x+1 contribute
+   *     two splits that overlaps with each other.
+   *   - bindings = [x / 4, x % 4, y] will pass because x / 4 and x % 4
+   *     contribute two non-overlapping splits that covers x.
+   *   - bindings = [x / 4, x % 4] won't pass because y is not used.
+   *
+   *   If require_bijective is false, this function checks one condition:
+   *   - C0: Each iter mark exists chance to be fully covered by non-overlapping splits.

Review comment:
       ```suggestion
      *   - C0: Each iter mark has a chance to be fully covered by non-overlapping splits.
   ```




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] tqchen commented on pull request #7752: [ARITH] detect iter affine map with predicate

Posted by GitBox <gi...@apache.org>.
tqchen commented on pull request #7752:
URL: https://github.com/apache/tvm/pull/7752#issuecomment-808724955


   Thanks @spectrometerHBH @FrozenGene @MasterJH5574 this PR is now merged


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] FrozenGene commented on a change in pull request #7752: [ARITH] detect iter affine map with predicate

Posted by GitBox <gi...@apache.org>.
FrozenGene commented on a change in pull request #7752:
URL: https://github.com/apache/tvm/pull/7752#discussion_r602653607



##########
File path: src/arith/iter_affine_map.cc
##########
@@ -459,27 +665,107 @@ class IterMapRewriter : public ExprMutator {
   }
 };
 
+/*! \brief An internal struct to represent range extent on iterators(iter < upper_bound). */
+struct IterConstraint {
+  // The expr of the iter
+  PrimExpr iter;
+  // The expr of the upper_bound
+  PrimExpr upper_bound;
+  // The size of the iter, which is the number of nodes
+  size_t expr_size = 0;
+
+  IterConstraint(PrimExpr iter, PrimExpr upper_bound, size_t size)
+      : iter(std::move(iter)), upper_bound(std::move(upper_bound)), expr_size(size) {}
+};
+
+/*!
+ * \brief Split the predicate into `(a < b) && (c < d) && ...`
+ * \param pred The predicate to be split.
+ * \return A list of pairs, each element of which are lhs and rhs of the '<' sign,
+ *         empty if the split failed.
+ */
+std::vector<IterConstraint> MatchUpperBoundConstraints(PrimExpr pred) {
+  std::vector<IterConstraint> result;
+  arith::PVar<PrimExpr> lhs, rhs, rest;
+  for (;;) {
+    if ((rest && (lhs < rhs)).Match(pred)) {
+      result.emplace_back(lhs.Eval(), rhs.Eval(), 0);
+      pred = rest.Eval();
+    } else if ((lhs < rhs).Match(pred)) {
+      result.emplace_back(lhs.Eval(), rhs.Eval(), 0);
+      break;
+    } else {
+      return std::vector<IterConstraint>();
+    }
+  }
+  return result;
+}
+
+/*! \brief Count the size of the PrimExpr. */
+class PrimExprSizeCounter : public ExprVisitor {

Review comment:
       Yes. In my previous PR #7469 I have done something similar, and we could borrow it. My PR still has the remaining problem of the max ops counter; however, IMO it is not an urgent problem, since it is just a safer guard for complex expr analysis.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] tqchen commented on pull request #7752: [ARITH] detect iter affine map with predicate

Posted by GitBox <gi...@apache.org>.
tqchen commented on pull request #7752:
URL: https://github.com/apache/tvm/pull/7752#issuecomment-808273920






-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] tqchen commented on pull request #7752: [ARITH] detect iter affine map with predicate

Posted by GitBox <gi...@apache.org>.
tqchen commented on pull request #7752:
URL: https://github.com/apache/tvm/pull/7752#issuecomment-808724866


   Thanks @spectrometerHBH I think counting all exprs is the right definition. Skipping loads does make sense for index expressions, but my guess is that we can always try the other way if it turns out not to be good


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] spectrometerHBH edited a comment on pull request #7752: [ARITH] detect iter affine map with predicate

Posted by GitBox <gi...@apache.org>.
spectrometerHBH edited a comment on pull request #7752:
URL: https://github.com/apache/tvm/pull/7752#issuecomment-808608709


   > Thanks @spectrometerHBH There seems to be three places where the ExprComplexity is involved.
   > 
   > * https://github.com/apache/tvm/blob/main/src/arith/solve_linear_inequality.cc#L55
   > * #7469
   > 
   > Given your PrimExprSizeCounter counting is simpler, can we create an ExprComplexity function in analysis and call from there?
   
   @tqchen I'm not sure whether the current ExprComplexity is compatible with mine. It looks like the current ExprComplexity doesn't count nodes like Load/BufferLoad/Cast, etc.


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] tqchen commented on a change in pull request #7752: [ARITH] detect iter affine map with predicate

Posted by GitBox <gi...@apache.org>.
tqchen commented on a change in pull request #7752:
URL: https://github.com/apache/tvm/pull/7752#discussion_r602342033



##########
File path: src/arith/iter_affine_map.cc
##########
@@ -459,27 +665,107 @@ class IterMapRewriter : public ExprMutator {
   }
 };
 
+/*! \brief An internal struct to represent range extent on iterators(iter < upper_bound). */
+struct IterConstraint {
+  // The expr of the iter
+  PrimExpr iter;
+  // The expr of the upper_bound
+  PrimExpr upper_bound;
+  // The size of the iter, which is the number of nodes
+  size_t expr_size = 0;
+
+  IterConstraint(PrimExpr iter, PrimExpr upper_bound, size_t size)
+      : iter(std::move(iter)), upper_bound(std::move(upper_bound)), expr_size(size) {}
+};
+
+/*!
+ * \brief Split the predicate into `(a < b) && (c < d) && ...`
+ * \param pred The predicate to be split.
+ * \return A list of pairs, each element of which are lhs and rhs of the '<' sign,
+ *         empty if the split failed.
+ */
+std::vector<IterConstraint> MatchUpperBoundConstraints(PrimExpr pred) {
+  std::vector<IterConstraint> result;
+  arith::PVar<PrimExpr> lhs, rhs, rest;
+  for (;;) {
+    if ((rest && (lhs < rhs)).Match(pred)) {
+      result.emplace_back(lhs.Eval(), rhs.Eval(), 0);
+      pred = rest.Eval();
+    } else if ((lhs < rhs).Match(pred)) {
+      result.emplace_back(lhs.Eval(), rhs.Eval(), 0);
+      break;
+    } else {
+      return std::vector<IterConstraint>();
+    }
+  }
+  return result;
+}
+
+/*! \brief Count the size of the PrimExpr. */
+class PrimExprSizeCounter : public ExprVisitor {

Review comment:
       move to src/analysis/expr_size.cc analysis.h  and expose as a function `size_t ExprSize(const PrimExpr& expr)`; document as number of expressions in the child

##########
File path: src/arith/iter_affine_map.cc
##########
@@ -381,31 +534,84 @@ class IterMapRewriter : public ExprMutator {
     if (!base_scale) return NullOpt;
     // check if it can be remapped into a fused pattern.
     PrimExpr expected_scale = base_scale.value();
-    for (size_t i = 0; i < expr->args.size(); ++i) {
+    for (size_t i = 0; i < expr->args.size();) {
+      // find j such that expr->args[j] has expected scale
       size_t j = i == 0 ? base_index : 0;
       for (; j < expr->args.size(); ++j) {
-        if (!visited[j] && CanProveEqual(expr->args[j]->scale, expected_scale)) break;
+        if (!visited[j] && analyzer_->CanProveEqual(expr->args[j]->scale, expected_scale)) break;
       }
-      if (j == expr->args.size()) {
-        return NullOpt;
+      if (j == expr->args.size()) return NullOpt;
+      // look for the longest constrained iter started from expr->args[j]
+      // Example: expr = i*9 + j*2 + k, i in [0, 4) j in [0, 5) k in [0, 2)
+      //          predicate: j*2 + k < 9
+      // We need to match the predicate in expr and adjust the expected scale,
+      // otherwise we expect the scale of i to be 2*5=10
+      Optional<IterSumExpr> constraint_to_match;
+      for (const IterSumExpr& iter : constrained_iters_flattened_) {
+        if (IterSplitEqual(expr->args[j], iter->args.back(), false)) {
+          // find a predicate started from expr->args[j]
+          if (!constraint_to_match ||
+              constraint_to_match.value()->args.size() < iter->args.size()) {
+            constraint_to_match = iter;
+          }
+        }
       }
-      visited[j] = true;
-      auto arg = expr->args[j];
-      arg.CopyOnWrite()->scale = div(expr->args[j]->scale, base_scale.value());
-      iters.push_back(arg);
-      expected_scale *= expr->args[j]->extent;
-    }
-    // update the iterator to use the canonicalized form
-    expr.CopyOnWrite()->args = Array<IterSplitExpr>(iters.rbegin(), iters.rend());
-    auto it = sum_fuse_map_.find(expr);
-    if (it != sum_fuse_map_.end()) return it->second;
-    auto mark = IterMark(expr, div(expected_scale, base_scale.value()));
-    IterSplitExpr split(mark, base_scale.value());
-    sum_fuse_map_[expr] = split;
-    return split;
+      if (constraint_to_match) {
+        // match the predicate and mark the iterators in the constraint_to_match as visited
+        // Example: expr = i*9 + j*2 + k, i in [0, 4) j in [0, 5) k in [0, 2)
+        //          predicate = j*2 + k < 9
+        //          then j*2 + k matches the lower two splits of expr
+        for (auto it = constraint_to_match.value()->args.rbegin();
+             it != constraint_to_match.value()->args.rend(); ++it) {
+          size_t k = 0;
+          for (; k < expr->args.size(); ++k) {
+            if (!visited[k] && IterSplitEqual(expr->args[k], *it, false)) {
+              if (analyzer_->CanProveEqual((*it)->scale * expected_scale, expr->args[k]->scale))
+                break;
+            }
+          }
+          if (k == expr->args.size()) return NullOpt;
+          visited[k] = true;
+          flattened_iters.push_back(expr->args[k]);
+        }
+        auto iter = sum_fuse_map_.find(constraint_to_match.value());
+        ICHECK(iter != sum_fuse_map_.end());
+        IterMark iter_matched = iter->second;
+        grouped_iters.emplace_back(iter_matched, expected_scale);
+        expected_scale *= iter_matched->extent;
+        // move forward
+        i += constraint_to_match.value()->args.size();
+      } else {
+        // constraint_to_match not found, skip this iterator
+        visited[j] = true;
+        flattened_iters.push_back(expr->args[j]);
+        grouped_iters.push_back(expr->args[j]);
+        expected_scale *= expr->args[j]->extent;
+        ++i;
+      }
+    }
+    // Get the flattened form and structured form
+    // both forms have splits from outermost to innermost
+    IterSumExpr structured_form = expr, flattened_form = expr;
+    flattened_form.CopyOnWrite()->args =
+        Array<IterSplitExpr>(flattened_iters.rbegin(), flattened_iters.rend());
+    structured_form.CopyOnWrite()->args =
+        Array<IterSplitExpr>(grouped_iters.rbegin(), grouped_iters.rend());
+    auto it = sum_fuse_map_.find(flattened_form);
+    if (it != sum_fuse_map_.end()) {
+      // old iter
+      return IterSplitExpr(it->second, base_scale.value());
+    } else {
+      // new iter, form a new mark
+      IterMark mark = IterMark(structured_form, div(expected_scale, base_scale.value()));
+      sum_fuse_map_[flattened_form] = mark;
+      flattened_map_[structured_form] = flattened_form;
+      return IterSplitExpr(mark, base_scale.value());
+    }
   }
 
   bool CanProveDivisible(const PrimExpr& lhs, const PrimExpr& rhs) {
+    if (analyzer_->CanProveEqual(lhs, rhs)) return true;

Review comment:
       this is not needed. Mainly because the const integer comparison is a faster path while CanProveEqual contains a slower path

##########
File path: tests/python/unittest/test_arith_iter_affine_map.py
##########
@@ -107,32 +106,20 @@ def test_fuse():
     res = tvm.arith.detect_iter_map([y * 4 + x], var_dom([(x, 3), (y, 4)]))
     assert len(res) == 0
 
-    # simple stride pattern
-    res = tvm.arith.detect_iter_map([x * 4 + y * 2], var_dom([(x, 3), (y, 2)]))
-    assert len(res) == 1

Review comment:
       Please check whether these test cases were intentionally deleted. I see they were moved; I just want to make sure.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] tqchen commented on a change in pull request #7752: [ARITH] detect iter affine map with predicate

Posted by GitBox <gi...@apache.org>.
tqchen commented on a change in pull request #7752:
URL: https://github.com/apache/tvm/pull/7752#discussion_r602342033



##########
File path: src/arith/iter_affine_map.cc
##########
@@ -459,27 +665,107 @@ class IterMapRewriter : public ExprMutator {
   }
 };
 
+/*! \brief An internal struct to represent range extent on iterators(iter < upper_bound). */
+struct IterConstraint {
+  // The expr of the iter
+  PrimExpr iter;
+  // The expr of the upper_bound
+  PrimExpr upper_bound;
+  // The size of the iter, which is the number of nodes
+  size_t expr_size = 0;
+
+  IterConstraint(PrimExpr iter, PrimExpr upper_bound, size_t size)
+      : iter(std::move(iter)), upper_bound(std::move(upper_bound)), expr_size(size) {}
+};
+
+/*!
+ * \brief Split the predicate into `(a < b) && (c < d) && ...`
+ * \param pred The predicate to be split.
+ * \return A list of pairs, each element of which are lhs and rhs of the '<' sign,
+ *         empty if the split failed.
+ */
+std::vector<IterConstraint> MatchUpperBoundConstraints(PrimExpr pred) {
+  std::vector<IterConstraint> result;
+  arith::PVar<PrimExpr> lhs, rhs, rest;
+  for (;;) {
+    if ((rest && (lhs < rhs)).Match(pred)) {
+      result.emplace_back(lhs.Eval(), rhs.Eval(), 0);
+      pred = rest.Eval();
+    } else if ((lhs < rhs).Match(pred)) {
+      result.emplace_back(lhs.Eval(), rhs.Eval(), 0);
+      break;
+    } else {
+      return std::vector<IterConstraint>();
+    }
+  }
+  return result;
+}
+
+/*! \brief Count the size of the PrimExpr. */
+class PrimExprSizeCounter : public ExprVisitor {

Review comment:
       move to src/analysis/expr_complexity.cc analysis.h  and expose as a function `size_t ExprComplexity(const PrimExpr& expr)`; document as number of expressions in the child




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] tqchen commented on pull request #7752: [ARITH] detect iter affine map with predicate

Posted by GitBox <gi...@apache.org>.
tqchen commented on pull request #7752:
URL: https://github.com/apache/tvm/pull/7752#issuecomment-808248236


   Thanks @spectrometerHBH . @MasterJH5574 @hzfan @Hzfengsy please also help to take a look when you have time


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org