You are viewing a plain text version of this content. The canonical link for it was a hyperlink ("here") that is not preserved in this plain-text version.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/03/27 00:29:38 UTC

[GitHub] [tvm] MasterJH5574 commented on a change in pull request #7752: [ARITH] detect iter affine map with predicate

MasterJH5574 commented on a change in pull request #7752:
URL: https://github.com/apache/tvm/pull/7752#discussion_r602639137



##########
File path: src/arith/iter_affine_map.cc
##########
@@ -161,59 +160,116 @@ class IterMarkSplitCollector {
   }
 };
 
-// Rewriter to rewrite PrimExpr to IterMapExpr
-// when possible
+/*! \brief Rewriter to rewrite PrimExpr to IterMapExpr when possible */
 class IterMapRewriter : public ExprMutator {
  public:
   using Parent = ExprMutator;
 
   explicit IterMapRewriter(Analyzer* analyzer, const Map<Var, Range>& input_iters)
       : analyzer_(analyzer) {
     for (auto kv : input_iters) {
-      const auto& vrng = kv.second;
-      if (is_zero(vrng->min)) {
-        IterMark mark(kv.first, vrng->extent);
-        var_map_[kv.first] = IterSplitExpr(mark);
+      const Var& var = kv.first;
+      const Range& vrng = kv.second;
+      if (is_one(vrng->extent)) {
+        var_map_[kv.first] = IterSumExpr({}, vrng->min);

Review comment:
       ```suggestion
           var_map_[var] = IterSumExpr({}, vrng->min);
   ```

##########
File path: src/arith/iter_affine_map.cc
##########
@@ -161,59 +160,116 @@ class IterMarkSplitCollector {
   }
 };
 
-// Rewriter to rewrite PrimExpr to IterMapExpr
-// when possible
+/*! \brief Rewriter to rewrite PrimExpr to IterMapExpr when possible */
 class IterMapRewriter : public ExprMutator {
  public:
   using Parent = ExprMutator;
 
   explicit IterMapRewriter(Analyzer* analyzer, const Map<Var, Range>& input_iters)
       : analyzer_(analyzer) {
     for (auto kv : input_iters) {
-      const auto& vrng = kv.second;
-      if (is_zero(vrng->min)) {
-        IterMark mark(kv.first, vrng->extent);
-        var_map_[kv.first] = IterSplitExpr(mark);
+      const Var& var = kv.first;
+      const Range& vrng = kv.second;
+      if (is_one(vrng->extent)) {
+        var_map_[kv.first] = IterSumExpr({}, vrng->min);
+      } else if (is_zero(vrng->min)) {
+        IterMark mark(var, vrng->extent);
+        var_map_[var] = IterSplitExpr(mark);
         input_marks_.push_back(mark);
       } else {
-        IterMark mark(kv.first - vrng->min, vrng->extent);
-        auto sum_expr = ToIterSumExpr(IterSplitExpr(mark));
+        IterMark mark(var - vrng->min, vrng->extent);
+        IterSumExpr sum_expr = ToIterSumExpr(IterSplitExpr(mark));
         sum_expr.CopyOnWrite()->base = vrng->min;
-        var_map_[kv.first] = sum_expr;
+        var_map_[var] = sum_expr;
         input_marks_.push_back(mark);
       }
     }
   }
 
   size_t unresolved_count() const { return unresolved_count_; }
 
-  IterSumExpr Rewrite(PrimExpr expr) {
+  IterSumExpr Rewrite(const PrimExpr& expr) {
     return NormalizeToIterWithOffset(ToIterSumExpr(DirectMutate(expr)));
   }
 
-  bool CheckBijective(const Array<IterSumExpr>& indices) {
-    // This function checks two conditions:
-    // - C0: Each iter mark should be fully covered by non-overlapping splits.
-    // - C1: All of the input iterators are used.
-    //
-    // Example: given x in [0, 8) y in [0, 6)
-    // - indices = [x, x+1, y] won't pass because x and x+1 contribute
-    //   two splits that overlaps with each other.
-    // - indices = [x / 4, x % 4, y] will pass because x / 4 and x % 4
-    //   contribute two non-overlapping splits that covers x.
-    // - indices = [x / 4, x % 4] won't pass because y is not used.
-    //
+  IterSumExpr RewriteIterConstraint(const PrimExpr& expr,
+                                    const PrimExpr& predicate_induced_extent) {
+    return NormalizeToIterOnBoundExpr(ToIterSumExpr(DirectMutate(expr)), predicate_induced_extent);
+  }
+
+  /*!
+   * \brief If require_bijective is true, this function checks two conditions:
+   *   - C0: Each iter mark should be fully covered by non-overlapping splits.
+   *   - C1: All of the input iterators are used.
+   *   Example: given x in [0, 8) y in [0, 6)
+   *   - bindings = [x, x + 1, y] won't pass because x and x+1 contribute
+   *     two splits that overlaps with each other.
+   *   - bindings = [x / 4, x % 4, y] will pass because x / 4 and x % 4
+   *     contribute two non-overlapping splits that covers x.
+   *   - bindings = [x / 4, x % 4] won't pass because y is not used.
+   *
+   *   If require_bijective is false, this function checks one condition:
+   *   - C0: Each iter mark exists chance to be fully covered by non-overlapping splits.

Review comment:
       ```suggestion
      *   - C0: Each iter mark has a chance to be fully covered by non-overlapping splits.
   ```




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org