You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/10/11 22:53:54 UTC

[GitHub] [incubator-tvm] tqchen opened a new pull request #6667: [ARITH] Introduce iterator (quasi)affine map detection.

tqchen opened a new pull request #6667:
URL: https://github.com/apache/incubator-tvm/pull/6667


   The loop transformations (split, fuse) create bijective
   maps from a collection of source iterators to target iterators.
   
   DetectIterMap is a function that detects such bijective mappings
   from the lowered index expression.
   
   We choose the term quasi-affine to be consistent with the
   terminology used in polyhedral compilation.
   DetectIterMap can handle symbolic integers (in split/fuse) to some extent.
   
   The utility can be useful in detecting loop transformation
   patterns and data layout change patterns in TIR.


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] tqchen commented on a change in pull request #6667: [ARITH] Introduce iterator (quasi)affine map detection.

Posted by GitBox <gi...@apache.org>.
tqchen commented on a change in pull request #6667:
URL: https://github.com/apache/incubator-tvm/pull/6667#discussion_r504330166



##########
File path: src/arith/iter_affine_map.cc
##########
@@ -0,0 +1,703 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/arith/iter_affine_map.cc
+ */
+#include <tvm/arith/analyzer.h>
+#include <tvm/arith/iter_affine_map.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/expr_functor.h>
+#include <tvm/tir/op.h>
+
+#include "../support/util.h"
+#include "const_fold.h"
+
+namespace tvm {
+namespace arith {
+
+using namespace tir;
+
+/*!
+ * \brief Construct an IterMark.
+ * \param source The expression being marked.
+ * \param extent The extent of the marked iterator.
+ */
+IterMark::IterMark(PrimExpr source, PrimExpr extent) {
+  auto n = make_object<IterMarkNode>();
+  n->source = std::move(source);
+  n->extent = std::move(extent);
+  data_ = std::move(n);
+}
+
+TVM_REGISTER_GLOBAL("arith.IterMark").set_body_typed([](PrimExpr source, PrimExpr extent) {
+  return IterMark(source, extent);
+});
+
+TVM_REGISTER_NODE_TYPE(IterMarkNode);
+
+// Prints IterMark(<source>, extent=<extent>).
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<IterMarkNode>([](const ObjectRef& node, ReprPrinter* p) {
+      auto* op = static_cast<const IterMarkNode*>(node.get());
+      // Close the parenthesis opened above (consistent with the IterSum printer).
+      p->stream << "IterMark(" << op->source << ", extent=" << op->extent << ")";
+    });
+
+/*!
+ * \brief Construct a split that covers the whole mark:
+ *        lower_factor = 1, scale = 1 and extent = the mark's extent.
+ *        The dtype is taken from the mark's source expression.
+ * \param source The mark being split.
+ */
+IterSplitExpr::IterSplitExpr(IterMark source) {
+  auto n = make_object<IterSplitExprNode>();
+  auto one = make_const(source->source->dtype, 1);
+  n->dtype = source->source->dtype;
+  n->source = std::move(source);
+  n->extent = n->source->extent;
+  n->lower_factor = one;
+  n->scale = one;
+  data_ = std::move(n);
+}
+
+/*!
+ * \brief Construct a split of a mark with explicit factors.
+ *        The dtype is taken from the mark's source expression.
+ * \param source The mark being split.
+ * \param lower_factor The lower factor of the split.
+ * \param extent The extent of the split.
+ * \param scale The scale applied to the split.
+ */
+IterSplitExpr::IterSplitExpr(IterMark source, PrimExpr lower_factor, PrimExpr extent,
+                             PrimExpr scale) {
+  auto n = make_object<IterSplitExprNode>();
+  n->dtype = source->source->dtype;
+  n->source = std::move(source);
+  n->lower_factor = std::move(lower_factor);
+  n->extent = std::move(extent);
+  n->scale = std::move(scale);
+  data_ = std::move(n);
+}
+
+TVM_REGISTER_GLOBAL("arith.IterSplitExpr")
+    .set_body_typed([](IterMark source, PrimExpr lower_factor, PrimExpr extent, PrimExpr scale) {
+      return IterSplitExpr(source, lower_factor, extent, scale);
+    });
+
+TVM_REGISTER_NODE_TYPE(IterSplitExprNode);
+
+// Prints IterSplit(<source>, lower_factor=..., extent=..., scale=...).
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<IterSplitExprNode>([](const ObjectRef& node, ReprPrinter* p) {
+      auto* op = static_cast<const IterSplitExprNode*>(node.get());
+      // Close the parenthesis opened above (consistent with the IterSum printer).
+      p->stream << "IterSplit(" << op->source << ", lower_factor=" << op->lower_factor
+                << ", extent=" << op->extent << ", scale=" << op->scale << ")";
+    });
+
+/*!
+ * \brief Construct a sum of splits plus a base offset.
+ *        The dtype is taken from base.
+ * \param args The splits being summed.
+ * \param base The constant offset.
+ */
+IterSumExpr::IterSumExpr(Array<IterSplitExpr> args, PrimExpr base) {
+  auto n = make_object<IterSumExprNode>();
+  n->dtype = base->dtype;
+  n->args = std::move(args);
+  n->base = std::move(base);
+  data_ = std::move(n);
+}
+
+TVM_REGISTER_GLOBAL("arith.IterSumExpr")
+    .set_body_typed([](Array<IterSplitExpr> args, PrimExpr base) {
+      return IterSumExpr(args, base);
+    });
+
+TVM_REGISTER_NODE_TYPE(IterSumExprNode);
+
+// Prints IterSum(<args>, <base>).
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<IterSumExprNode>([](const ObjectRef& node, ReprPrinter* p) {
+      auto* op = static_cast<const IterSumExprNode*>(node.get());
+      p->stream << "IterSum(" << op->args << ", " << op->base << ")";
+    });
+
+/*!
+ * \brief Gathers, for every IterMark reachable from a set of indices, the
+ *        splits that reference it. Used to check that the splits are
+ *        independent and complete (cover the whole original iter space).
+ */
+class IterMarkSplitCollector {
+ public:
+  // Every IterMark reached so far.
+  std::unordered_set<IterMark, ObjectPtrHash, ObjectPtrEqual> visited_;
+  // For each reached IterMark, the splits that refer to it.
+  std::unordered_map<IterMark, std::vector<IterSplitExpr>, ObjectPtrHash, ObjectPtrEqual>
+      mark2splits_;
+  /*!
+   * \brief Recursively record which splits refer to which marks.
+   * \param indices The rewritten index expressions to walk.
+   */
+  void Collect(const Array<IterSumExpr>& indices) {
+    for (const IterSumExpr& index : indices) {
+      RecordSplits(index->args);
+    }
+  }
+
+  void CollectInternal(const IterMark& mark) {
+    // insert(...).second is false when the mark had already been reached.
+    if (!visited_.insert(mark).second) return;
+    if (const auto* fused = mark->source.as<IterSumExprNode>()) {
+      RecordSplits(fused->args);
+    }
+  }
+
+ private:
+  // Descend into each split's mark, then log the split against that mark.
+  void RecordSplits(const Array<IterSplitExpr>& splits) {
+    for (const IterSplitExpr& split : splits) {
+      CollectInternal(split->source);
+      mark2splits_[split->source].push_back(split);
+    }
+  }
+};
+
+/*!
+ * \brief Rewrites index expressions over the input iterators into IterMapExpr
+ *        form (IterSumExpr / IterSplitExpr over IterMarks) and checks whether
+ *        the resulting map is bijective.
+ */
+class IterMapRewriter : public ExprMutator {
+ public:
+  using Parent = ExprMutator;
+
+  /*!
+   * \brief Register each input iterator as an IterMark.
+   * \param analyzer Analyzer used to prove equalities and divisibility.
+   * \param input_iters The input iterators with their ranges. An iterator whose
+   *        min is not zero is marked as (var - min) and mapped to a sum whose
+   *        base is min.
+   */
+  explicit IterMapRewriter(Analyzer* analyzer, const Map<Var, Range>& input_iters)
+      : analyzer_(analyzer) {
+    for (auto kv : input_iters) {
+      const auto& vrng = kv.second;
+      if (is_zero(vrng->min)) {
+        IterMark mark(kv.first, vrng->extent);
+        var_map_[kv.first] = IterSplitExpr(mark);
+        input_marks_.push_back(mark);
+      } else {
+        IterMark mark(kv.first - vrng->min, vrng->extent);
+        auto sum_expr = ToIterSumExpr(IterSplitExpr(mark));
+        sum_expr.CopyOnWrite()->base = vrng->min;
+        var_map_[kv.first] = sum_expr;
+        input_marks_.push_back(mark);
+      }
+    }
+  }
+
+  /*! \return The number of unresolved sub-expressions encountered so far. */
+  size_t unresolved_count() const { return unresolved_count_; }
+
+  /*!
+   * \brief Rewrite an index expression into an IterSumExpr, fusing its
+   *        iterators into a single split plus a base offset when possible.
+   * \param expr The index expression.
+   * \return The rewritten expression.
+   */
+  IterSumExpr Rewrite(PrimExpr expr) {
+    return NormalizeToIterWithOffset(ToIterSumExpr(DirectMutate(expr)));
+  }
+
+  /*!
+   * \brief Check that the rewritten indices form a bijective map.
+   *
+   *  For each IterMark reachable from indices, the splits that refer to it must
+   *  cover its extent and must not overlap each other. In addition, every input
+   *  mark must be reachable.
+   * \param indices The rewritten indices.
+   * \return Whether the check passes.
+   */
+  bool CheckBijective(const Array<IterSumExpr>& indices) {
+    IterMarkSplitCollector collector;
+    collector.Collect(indices);
+    for (IterMark mark : collector.visited_) {
+      if (TryNormalizeSplits(mark, collector.mark2splits_[mark]).size() == 0) return false;
+    }
+    // all input marks must be visited
+    for (const auto& mark : input_marks_) {
+      if (collector.visited_.count(mark) == 0) return false;
+    }
+    return true;
+  }
+
+  // Override the original mutate function: any IterMapExpr returned through
+  // this entry point is counted as unresolved.
+  // NOTE(review): presumably this catches iterators that end up inside
+  // expressions the VisitExpr_ overloads (defined elsewhere) cannot represent.
+  PrimExpr VisitExpr(const PrimExpr& input_expr) final {
+    auto expr = Parent::VisitExpr(input_expr);
+    if (expr->IsInstance<IterMapExprNode>()) {
+      ++unresolved_count_;
+    }
+    return expr;
+  }
+
+  // Normal mutation without the unresolved accounting performed by VisitExpr.
+  PrimExpr DirectMutate(PrimExpr expr) { return Parent::VisitExpr(expr); }
+
+  PrimExpr VisitExpr_(const VarNode* op) final;
+  PrimExpr VisitExpr_(const AddNode* op) final;
+  PrimExpr VisitExpr_(const SubNode* op) final;
+  PrimExpr VisitExpr_(const MulNode* op) final;
+  PrimExpr VisitExpr_(const FloorDivNode* op) final;
+  PrimExpr VisitExpr_(const FloorModNode* op) final;
+
+ private:
+  // Hash used to de-duplicate fused sums. Only the source marks are hashed;
+  // IterSumEqual compares the remaining fields.
+  struct IterSumHash {
+    size_t operator()(const IterSumExpr& value) const {
+      size_t hash = value->args.size();
+      for (size_t i = 0; i < value->args.size(); ++i) {
+        hash = support::HashCombine(hash, std::hash<const Object*>()(value->args[i]->source.get()));
+      }
+      return hash;
+    }
+  };
+
+  // Equality matching IterSumHash: same base and, position by position, the
+  // same source mark with deep-equal lower_factor, scale and extent.
+  struct IterSumEqual {
+    bool operator()(const IterSumExpr& lhs, const IterSumExpr& rhs) const {
+      tir::ExprDeepEqual equal;
+      if (lhs->args.size() != rhs->args.size()) return false;
+      if (!equal(lhs->base, rhs->base)) return false;
+      for (size_t i = 0; i < lhs->args.size(); ++i) {
+        auto lvalue = lhs->args[i];
+        auto rvalue = rhs->args[i];
+        if (!lvalue->source.same_as(rvalue->source)) return false;
+        if (!equal(lvalue->lower_factor, rvalue->lower_factor)) return false;
+        if (!equal(lvalue->scale, rvalue->scale)) return false;
+        if (!equal(lvalue->extent, rvalue->extent)) return false;
+      }
+      return true;
+    }
+  };
+
+  // Internal analyzer
+  Analyzer* analyzer_;
+  // Counter to keep track of unresolved cases (same type as unresolved_count()).
+  size_t unresolved_count_{0};
+  // Map from each input iterator to its IterMapExpr.
+  std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> var_map_;
+  // input iter marks
+  std::vector<IterMark> input_marks_;
+  // Cache of fused sums: a canonicalized sum maps to the split of its fused
+  // mark, so the same sum is always fused into the same IterMark.
+  std::unordered_map<IterSumExpr, IterSplitExpr, IterSumHash, IterSumEqual> sum_fuse_map_;
+
+  /*!
+   * \brief Verify that splits fully cover mark in a non-overlapping fashion.
+   *        If verification passes, return splits from outermost to innermost order.
+   *        If not, return an empty array.
+   * \param mark The iterator of interest.
+   * \param splits The splits to be verified.
+   * \return The normalized splits.
+   */
+  Array<IterSplitExpr> TryNormalizeSplits(const IterMark& mark,
+                                          const std::vector<IterSplitExpr>& splits) {
+    std::vector<bool> used(splits.size(), false);
+    std::vector<IterSplitExpr> iters;
+    PrimExpr expected_lower_factor = make_const(mark->source->dtype, 1);
+
+    for (size_t i = 0; i < splits.size(); ++i) {
+      // Pick an unused split whose lower_factor equals the product of the
+      // extents picked so far, i.e. the split directly above the previous one.
+      size_t j = 0;
+      for (; j < splits.size(); ++j) {
+        if (!used[j] && CanProveEqual(splits[j]->lower_factor, expected_lower_factor)) break;
+      }
+      if (j == splits.size()) {
+        return Array<IterSplitExpr>();
+      }
+      used[j] = true;
+      iters.push_back(splits[j]);
+      expected_lower_factor *= splits[j]->extent;
+    }
+    // The product of all extents must equal the mark's extent (full coverage).
+    if (!CanProveEqual(expected_lower_factor, mark->extent)) return Array<IterSplitExpr>();
+    return Array<IterSplitExpr>(iters.rbegin(), iters.rend());
+  }
+
+  /*!
+   * \brief Normalize expr to an iterator + offset.
+   *        Increments the unresolved counter when the iterators cannot be fused.
+   * \param expr The input expression.
+   * \return The Normalized expression.
+   */
+  IterSumExpr NormalizeToIterWithOffset(IterSumExpr expr) {
+    if (expr->args.size() <= 1) return expr;
+    // TryFuseIters only fuses sums with a zero base, so strip it temporarily.
+    PrimExpr base = expr->base;
+    expr.CopyOnWrite()->base = make_zero(expr->dtype);
+    auto opt = TryFuseIters(expr);
+    expr.CopyOnWrite()->base = base;
+    if (opt) {
+      expr.CopyOnWrite()->args = Array<IterSplitExpr>({opt.value()});
+      return expr;
+    } else {
+      ++unresolved_count_;
+      return expr;
+    }
+  }
+
+  // Whether lhs == rhs can be proven; constant operands are compared directly.
+  bool CanProveEqual(PrimExpr lhs, PrimExpr rhs) {
+    const auto* clhs = lhs.as<IntImmNode>();
+    const auto* crhs = rhs.as<IntImmNode>();
+    if (clhs && crhs) return clhs->value == crhs->value;
+    return analyzer_->CanProve(lhs - rhs == 0);
+  }
+
+  /*!
+   * \brief Create a IterSumExpr from expr.
+   * \param expr The input expr.
+   * \return The transformed IterSumExpr.
+   */
+  IterSumExpr ToIterSumExpr(PrimExpr expr) {
+    if (const auto* op = expr.as<IterSumExprNode>()) {
+      return GetRef<IterSumExpr>(op);
+    } else if (const auto* op = expr.as<IterSplitExprNode>()) {
+      return IterSumExpr({GetRef<IterSplitExpr>(op)}, make_zero(expr->dtype));
+    } else {
+      CHECK(!expr->IsInstance<IterMapExprNode>());
+      return IterSumExpr({}, expr);
+    }
+  }
+
+  // Try to normalize IterSum into a fused IterMark.
+  // The args must chain by scale: starting from 1, each next arg's scale must
+  // equal the product of the extents already chosen. Returns the split of the
+  // (cached) fused mark, or NullOpt if the base is non-zero or no chain exists.
+  Optional<IterSplitExpr> TryFuseIters(IterSumExpr expr) {
+    if (!is_zero(expr->base)) return NullOpt;
+    if (expr->args.size() == 1) return expr->args[0];
+    // select the iterators in order
+    std::vector<bool> visited(expr->args.size(), false);
+    std::vector<IterSplitExpr> iters;
+    iters.reserve(expr->args.size());
+    // canonicalize the expression
+    // check if it can be remapped into a fused pattern.
+    PrimExpr expected_scale = make_const(expr->base->dtype, 1);
+    for (size_t i = 0; i < expr->args.size(); ++i) {
+      size_t j = 0;
+      for (; j < expr->args.size(); ++j) {
+        if (!visited[j] && CanProveEqual(expr->args[j]->scale, expected_scale)) break;
+      }
+      if (j == expr->args.size()) {
+        return NullOpt;
+      }
+      visited[j] = true;
+      iters.push_back(expr->args[j]);
+      expected_scale *= expr->args[j]->extent;
+    }
+    // update the iterator to use the canonicalized form
+    expr.CopyOnWrite()->args = Array<IterSplitExpr>(iters.rbegin(), iters.rend());
+    auto it = sum_fuse_map_.find(expr);
+    if (it != sum_fuse_map_.end()) return it->second;
+    auto mark = IterMark(expr, expected_scale);
+    IterSplitExpr split(mark);
+    sum_fuse_map_[expr] = split;
+    return split;
+  }
+
+  // Whether lhs % rhs == 0 can be proven.
+  bool CanProveDivisible(PrimExpr lhs, PrimExpr rhs) {
+    const auto* clhs = lhs.as<IntImmNode>();
+    const auto* crhs = rhs.as<IntImmNode>();
+    if (clhs && crhs) {
+      // `x % 0` is undefined behavior; divisibility by zero is never provable.
+      if (crhs->value == 0) return false;
+      return clhs->value % crhs->value == 0;
+    }
+    return analyzer_->CanProve(floormod(lhs, rhs) == 0);
+  }
+
+  PrimExpr SplitFloorDivConst(IterSplitExpr lhs, PrimExpr rhs);
+  PrimExpr SplitFloorModConst(IterSplitExpr lhs, PrimExpr rhs);
+
+  // lhs += sign * rhs. A split with the same source, lower_factor and extent
+  // as rhs has its scale combined in place; otherwise rhs is appended (with a
+  // negated scale when sign <= 0).
+  static void AddToLhs(IterSumExprNode* lhs, IterSplitExpr rhs, int sign) {
+    tir::ExprDeepEqual equal;
+    for (size_t i = 0; i < lhs->args.size(); ++i) {
+      IterSplitExpr lvalue = lhs->args[i];
+      if (lvalue->source.same_as(rhs->source) && equal(lvalue->lower_factor, rhs->lower_factor) &&
+          equal(lvalue->extent, rhs->extent)) {
+        if (sign > 0) {
+          rhs.CopyOnWrite()->scale = lvalue->scale + rhs->scale;
+        } else {
+          rhs.CopyOnWrite()->scale = lvalue->scale - rhs->scale;
+        }
+        lhs->args.Set(i, rhs);
+        return;
+      }
+    }
+    if (sign > 0) {
+      lhs->args.push_back(rhs);
+    } else {
+      rhs.CopyOnWrite()->scale = make_zero(rhs->scale.dtype()) - rhs->scale;
+      lhs->args.push_back(rhs);
+    }
+  }
+
+  // lhs += sign * rhs for a whole sum: every split, then the base.
+  static void AddToLhs(IterSumExprNode* lhs, IterSumExpr rhs, int sign) {
+    for (size_t i = 0; i < rhs->args.size(); ++i) {
+      AddToLhs(lhs, rhs->args[i], sign);
+    }
+    if (sign > 0) {
+      lhs->base += rhs->base;
+    } else {
+      lhs->base -= rhs->base;
+    }
+  }
+
+  // lhs *= rhs: scales every split and the base.
+  static void MulToLhs(IterSumExprNode* lhs, PrimExpr rhs) {
+    for (size_t i = 0; i < lhs->args.size(); ++i) {
+      IterSplitExpr lvalue = lhs->args[i];
+      lvalue.CopyOnWrite()->scale *= rhs;
+      lhs->args.Set(i, lvalue);
+    }
+    lhs->base *= rhs;
+  }
+};
+
+Array<IterSumExpr> DetectIterMap(const Array<PrimExpr>& indices, const Map<Var, Range>& input_iters,
+                                 arith::Analyzer* analyzer) {
+  // Overall detection algorithm is divided into two steps:
+  // - Step0: IterMapRewriter rewrites the expression to use IterMapExpr patterns.
+  // - Step1: IterIndependenceChecker checks if the iterator are independent.
+  IterMapRewriter rewriter(analyzer, input_iters);
+  Array<IterSumExpr> results;
+
+  for (PrimExpr value : indices) {
+    results.push_back(rewriter.Rewrite(value));
+    if (rewriter.unresolved_count() != 0) return Array<IterSumExpr>();
+  }
+  if (!rewriter.CheckBijective(results)) return Array<IterSumExpr>();

Review comment:
       Yes. In this case, x and x+1 will contribute two Split entries to mark2split of x (one contributed by x and another by x+1):
   ```
   mark2split[x]= [ Split(x, extent), Split(x, extent)]
   ```
   
   This will result in an error because the two splits overlap with each other.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] spectrometerHBH commented on a change in pull request #6667: [ARITH] Introduce iterator (quasi)affine map detection.

Posted by GitBox <gi...@apache.org>.
spectrometerHBH commented on a change in pull request #6667:
URL: https://github.com/apache/incubator-tvm/pull/6667#discussion_r504010733



##########
File path: src/arith/iter_affine_map.cc
##########
@@ -0,0 +1,703 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/arith/iter_affine_map.cc
+ */
+#include <tvm/arith/analyzer.h>
+#include <tvm/arith/iter_affine_map.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/expr_functor.h>
+#include <tvm/tir/op.h>
+
+#include "../support/util.h"
+#include "const_fold.h"
+
+namespace tvm {
+namespace arith {
+
+using namespace tir;
+
+/*!
+ * \brief Construct an IterMark.
+ * \param source The expression being marked.
+ * \param extent The extent of the marked iterator.
+ */
+IterMark::IterMark(PrimExpr source, PrimExpr extent) {
+  auto n = make_object<IterMarkNode>();
+  n->source = std::move(source);
+  n->extent = std::move(extent);
+  data_ = std::move(n);
+}
+
+TVM_REGISTER_GLOBAL("arith.IterMark").set_body_typed([](PrimExpr source, PrimExpr extent) {
+  return IterMark(source, extent);
+});
+
+TVM_REGISTER_NODE_TYPE(IterMarkNode);
+
+// Prints IterMark(<source>, extent=<extent>).
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<IterMarkNode>([](const ObjectRef& node, ReprPrinter* p) {
+      auto* op = static_cast<const IterMarkNode*>(node.get());
+      // Close the parenthesis opened above (consistent with the IterSum printer).
+      p->stream << "IterMark(" << op->source << ", extent=" << op->extent << ")";
+    });
+
+/*!
+ * \brief Construct a split that covers the whole mark:
+ *        lower_factor = 1, scale = 1 and extent = the mark's extent.
+ *        The dtype is taken from the mark's source expression.
+ * \param source The mark being split.
+ */
+IterSplitExpr::IterSplitExpr(IterMark source) {
+  auto n = make_object<IterSplitExprNode>();
+  auto one = make_const(source->source->dtype, 1);
+  n->dtype = source->source->dtype;
+  n->source = std::move(source);
+  n->extent = n->source->extent;
+  n->lower_factor = one;
+  n->scale = one;
+  data_ = std::move(n);
+}
+
+/*!
+ * \brief Construct a split of a mark with explicit factors.
+ *        The dtype is taken from the mark's source expression.
+ * \param source The mark being split.
+ * \param lower_factor The lower factor of the split.
+ * \param extent The extent of the split.
+ * \param scale The scale applied to the split.
+ */
+IterSplitExpr::IterSplitExpr(IterMark source, PrimExpr lower_factor, PrimExpr extent,
+                             PrimExpr scale) {
+  auto n = make_object<IterSplitExprNode>();
+  n->dtype = source->source->dtype;
+  n->source = std::move(source);
+  n->lower_factor = std::move(lower_factor);
+  n->extent = std::move(extent);
+  n->scale = std::move(scale);
+  data_ = std::move(n);
+}
+
+TVM_REGISTER_GLOBAL("arith.IterSplitExpr")
+    .set_body_typed([](IterMark source, PrimExpr lower_factor, PrimExpr extent, PrimExpr scale) {
+      return IterSplitExpr(source, lower_factor, extent, scale);
+    });
+
+TVM_REGISTER_NODE_TYPE(IterSplitExprNode);
+
+// Prints IterSplit(<source>, lower_factor=..., extent=..., scale=...).
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<IterSplitExprNode>([](const ObjectRef& node, ReprPrinter* p) {
+      auto* op = static_cast<const IterSplitExprNode*>(node.get());
+      // Close the parenthesis opened above (consistent with the IterSum printer).
+      p->stream << "IterSplit(" << op->source << ", lower_factor=" << op->lower_factor
+                << ", extent=" << op->extent << ", scale=" << op->scale << ")";
+    });
+
+/*!
+ * \brief Construct a sum of splits plus a base offset.
+ *        The dtype is taken from base.
+ * \param args The splits being summed.
+ * \param base The constant offset.
+ */
+IterSumExpr::IterSumExpr(Array<IterSplitExpr> args, PrimExpr base) {
+  auto n = make_object<IterSumExprNode>();
+  n->dtype = base->dtype;
+  n->args = std::move(args);
+  n->base = std::move(base);
+  data_ = std::move(n);
+}
+
+TVM_REGISTER_GLOBAL("arith.IterSumExpr")
+    .set_body_typed([](Array<IterSplitExpr> args, PrimExpr base) {
+      return IterSumExpr(args, base);
+    });
+
+TVM_REGISTER_NODE_TYPE(IterSumExprNode);
+
+// Prints IterSum(<args>, <base>).
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<IterSumExprNode>([](const ObjectRef& node, ReprPrinter* p) {
+      auto* op = static_cast<const IterSumExprNode*>(node.get());
+      p->stream << "IterSum(" << op->args << ", " << op->base << ")";
+    });
+
+/*!
+ * \brief Gathers, for every IterMark reachable from a set of indices, the
+ *        splits that reference it. Used to check that the splits are
+ *        independent and complete (cover the whole original iter space).
+ */
+class IterMarkSplitCollector {
+ public:
+  // Every IterMark reached so far.
+  std::unordered_set<IterMark, ObjectPtrHash, ObjectPtrEqual> visited_;
+  // For each reached IterMark, the splits that refer to it.
+  std::unordered_map<IterMark, std::vector<IterSplitExpr>, ObjectPtrHash, ObjectPtrEqual>
+      mark2splits_;
+  /*!
+   * \brief Recursively record which splits refer to which marks.
+   * \param indices The rewritten index expressions to walk.
+   */
+  void Collect(const Array<IterSumExpr>& indices) {
+    for (const IterSumExpr& index : indices) {
+      RecordSplits(index->args);
+    }
+  }
+
+  void CollectInternal(const IterMark& mark) {
+    // insert(...).second is false when the mark had already been reached.
+    if (!visited_.insert(mark).second) return;
+    if (const auto* fused = mark->source.as<IterSumExprNode>()) {
+      RecordSplits(fused->args);
+    }
+  }
+
+ private:
+  // Descend into each split's mark, then log the split against that mark.
+  void RecordSplits(const Array<IterSplitExpr>& splits) {
+    for (const IterSplitExpr& split : splits) {
+      CollectInternal(split->source);
+      mark2splits_[split->source].push_back(split);
+    }
+  }
+};
+
+/*!
+ * \brief Rewrites index expressions over the input iterators into IterMapExpr
+ *        form (IterSumExpr / IterSplitExpr over IterMarks) and checks whether
+ *        the resulting map is bijective.
+ */
+class IterMapRewriter : public ExprMutator {
+ public:
+  using Parent = ExprMutator;
+
+  /*!
+   * \brief Register each input iterator as an IterMark.
+   * \param analyzer Analyzer used to prove equalities and divisibility.
+   * \param input_iters The input iterators with their ranges. An iterator whose
+   *        min is not zero is marked as (var - min) and mapped to a sum whose
+   *        base is min.
+   */
+  explicit IterMapRewriter(Analyzer* analyzer, const Map<Var, Range>& input_iters)
+      : analyzer_(analyzer) {
+    for (auto kv : input_iters) {
+      const auto& vrng = kv.second;
+      if (is_zero(vrng->min)) {
+        IterMark mark(kv.first, vrng->extent);
+        var_map_[kv.first] = IterSplitExpr(mark);
+        input_marks_.push_back(mark);
+      } else {
+        IterMark mark(kv.first - vrng->min, vrng->extent);
+        auto sum_expr = ToIterSumExpr(IterSplitExpr(mark));
+        sum_expr.CopyOnWrite()->base = vrng->min;
+        var_map_[kv.first] = sum_expr;
+        input_marks_.push_back(mark);
+      }
+    }
+  }
+
+  /*! \return The number of unresolved sub-expressions encountered so far. */
+  size_t unresolved_count() const { return unresolved_count_; }
+
+  /*!
+   * \brief Rewrite an index expression into an IterSumExpr, fusing its
+   *        iterators into a single split plus a base offset when possible.
+   * \param expr The index expression.
+   * \return The rewritten expression.
+   */
+  IterSumExpr Rewrite(PrimExpr expr) {
+    return NormalizeToIterWithOffset(ToIterSumExpr(DirectMutate(expr)));
+  }
+
+  /*!
+   * \brief Check that the rewritten indices form a bijective map.
+   *
+   *  For each IterMark reachable from indices, the splits that refer to it must
+   *  cover its extent and must not overlap each other. In addition, every input
+   *  mark must be reachable.
+   * \param indices The rewritten indices.
+   * \return Whether the check passes.
+   */
+  bool CheckBijective(const Array<IterSumExpr>& indices) {
+    IterMarkSplitCollector collector;
+    collector.Collect(indices);
+    for (IterMark mark : collector.visited_) {
+      if (TryNormalizeSplits(mark, collector.mark2splits_[mark]).size() == 0) return false;
+    }
+    // all input marks must be visited
+    for (const auto& mark : input_marks_) {
+      if (collector.visited_.count(mark) == 0) return false;
+    }
+    return true;
+  }
+
+  // Override the original mutate function: any IterMapExpr returned through
+  // this entry point is counted as unresolved.
+  // NOTE(review): presumably this catches iterators that end up inside
+  // expressions the VisitExpr_ overloads (defined elsewhere) cannot represent.
+  PrimExpr VisitExpr(const PrimExpr& input_expr) final {
+    auto expr = Parent::VisitExpr(input_expr);
+    if (expr->IsInstance<IterMapExprNode>()) {
+      ++unresolved_count_;
+    }
+    return expr;
+  }
+
+  // Normal mutation without the unresolved accounting performed by VisitExpr.
+  PrimExpr DirectMutate(PrimExpr expr) { return Parent::VisitExpr(expr); }
+
+  PrimExpr VisitExpr_(const VarNode* op) final;
+  PrimExpr VisitExpr_(const AddNode* op) final;
+  PrimExpr VisitExpr_(const SubNode* op) final;
+  PrimExpr VisitExpr_(const MulNode* op) final;
+  PrimExpr VisitExpr_(const FloorDivNode* op) final;
+  PrimExpr VisitExpr_(const FloorModNode* op) final;
+
+ private:
+  // Hash used to de-duplicate fused sums. Only the source marks are hashed;
+  // IterSumEqual compares the remaining fields.
+  struct IterSumHash {
+    size_t operator()(const IterSumExpr& value) const {
+      size_t hash = value->args.size();
+      for (size_t i = 0; i < value->args.size(); ++i) {
+        hash = support::HashCombine(hash, std::hash<const Object*>()(value->args[i]->source.get()));
+      }
+      return hash;
+    }
+  };
+
+  // Equality matching IterSumHash: same base and, position by position, the
+  // same source mark with deep-equal lower_factor, scale and extent.
+  struct IterSumEqual {
+    bool operator()(const IterSumExpr& lhs, const IterSumExpr& rhs) const {
+      tir::ExprDeepEqual equal;
+      if (lhs->args.size() != rhs->args.size()) return false;
+      if (!equal(lhs->base, rhs->base)) return false;
+      for (size_t i = 0; i < lhs->args.size(); ++i) {
+        auto lvalue = lhs->args[i];
+        auto rvalue = rhs->args[i];
+        if (!lvalue->source.same_as(rvalue->source)) return false;
+        if (!equal(lvalue->lower_factor, rvalue->lower_factor)) return false;
+        if (!equal(lvalue->scale, rvalue->scale)) return false;
+        if (!equal(lvalue->extent, rvalue->extent)) return false;
+      }
+      return true;
+    }
+  };
+
+  // Internal analyzer
+  Analyzer* analyzer_;
+  // Counter to keep track of unresolved cases (same type as unresolved_count()).
+  size_t unresolved_count_{0};
+  // Map from each input iterator to its IterMapExpr.
+  std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> var_map_;
+  // input iter marks
+  std::vector<IterMark> input_marks_;
+  // Cache of fused sums: a canonicalized sum maps to the split of its fused
+  // mark, so the same sum is always fused into the same IterMark.
+  std::unordered_map<IterSumExpr, IterSplitExpr, IterSumHash, IterSumEqual> sum_fuse_map_;
+
+  /*!
+   * \brief Verify that splits fully cover mark in a non-overlapping fashion.
+   *        If verification passes, return splits from outermost to innermost order.
+   *        If not, return an empty array.
+   * \param mark The iterator of interest.
+   * \param splits The splits to be verified.
+   * \return The normalized splits.
+   */
+  Array<IterSplitExpr> TryNormalizeSplits(const IterMark& mark,
+                                          const std::vector<IterSplitExpr>& splits) {
+    std::vector<bool> used(splits.size(), false);
+    std::vector<IterSplitExpr> iters;
+    PrimExpr expected_lower_factor = make_const(mark->source->dtype, 1);
+
+    for (size_t i = 0; i < splits.size(); ++i) {
+      // Pick an unused split whose lower_factor equals the product of the
+      // extents picked so far, i.e. the split directly above the previous one.
+      size_t j = 0;
+      for (; j < splits.size(); ++j) {
+        if (!used[j] && CanProveEqual(splits[j]->lower_factor, expected_lower_factor)) break;
+      }
+      if (j == splits.size()) {
+        return Array<IterSplitExpr>();
+      }
+      used[j] = true;
+      iters.push_back(splits[j]);
+      expected_lower_factor *= splits[j]->extent;
+    }
+    // The product of all extents must equal the mark's extent (full coverage).
+    if (!CanProveEqual(expected_lower_factor, mark->extent)) return Array<IterSplitExpr>();
+    return Array<IterSplitExpr>(iters.rbegin(), iters.rend());
+  }
+
+  /*!
+   * \brief Normalize expr to an iterator + offset.
+   *        Increments the unresolved counter when the iterators cannot be fused.
+   * \param expr The input expression.
+   * \return The Normalized expression.
+   */
+  IterSumExpr NormalizeToIterWithOffset(IterSumExpr expr) {
+    if (expr->args.size() <= 1) return expr;
+    // TryFuseIters only fuses sums with a zero base, so strip it temporarily.
+    PrimExpr base = expr->base;
+    expr.CopyOnWrite()->base = make_zero(expr->dtype);
+    auto opt = TryFuseIters(expr);
+    expr.CopyOnWrite()->base = base;
+    if (opt) {
+      expr.CopyOnWrite()->args = Array<IterSplitExpr>({opt.value()});
+      return expr;
+    } else {
+      ++unresolved_count_;
+      return expr;
+    }
+  }
+
+  // Whether lhs == rhs can be proven; constant operands are compared directly.
+  bool CanProveEqual(PrimExpr lhs, PrimExpr rhs) {
+    const auto* clhs = lhs.as<IntImmNode>();
+    const auto* crhs = rhs.as<IntImmNode>();
+    if (clhs && crhs) return clhs->value == crhs->value;
+    return analyzer_->CanProve(lhs - rhs == 0);
+  }
+
+  /*!
+   * \brief Create a IterSumExpr from expr.
+   * \param expr The input expr.
+   * \return The transformed IterSumExpr.
+   */
+  IterSumExpr ToIterSumExpr(PrimExpr expr) {
+    if (const auto* op = expr.as<IterSumExprNode>()) {
+      return GetRef<IterSumExpr>(op);
+    } else if (const auto* op = expr.as<IterSplitExprNode>()) {
+      return IterSumExpr({GetRef<IterSplitExpr>(op)}, make_zero(expr->dtype));
+    } else {
+      CHECK(!expr->IsInstance<IterMapExprNode>());
+      return IterSumExpr({}, expr);
+    }
+  }
+
+  // Try to normalize IterSum into a fused IterMark.
+  // The args must chain by scale: starting from 1, each next arg's scale must
+  // equal the product of the extents already chosen. Returns the split of the
+  // (cached) fused mark, or NullOpt if the base is non-zero or no chain exists.
+  Optional<IterSplitExpr> TryFuseIters(IterSumExpr expr) {
+    if (!is_zero(expr->base)) return NullOpt;
+    if (expr->args.size() == 1) return expr->args[0];
+    // select the iterators in order
+    std::vector<bool> visited(expr->args.size(), false);
+    std::vector<IterSplitExpr> iters;
+    iters.reserve(expr->args.size());
+    // canonicalize the expression
+    // check if it can be remapped into a fused pattern.
+    PrimExpr expected_scale = make_const(expr->base->dtype, 1);
+    for (size_t i = 0; i < expr->args.size(); ++i) {
+      size_t j = 0;
+      for (; j < expr->args.size(); ++j) {
+        if (!visited[j] && CanProveEqual(expr->args[j]->scale, expected_scale)) break;
+      }
+      if (j == expr->args.size()) {
+        return NullOpt;
+      }
+      visited[j] = true;
+      iters.push_back(expr->args[j]);
+      expected_scale *= expr->args[j]->extent;
+    }
+    // update the iterator to use the canonicalized form
+    expr.CopyOnWrite()->args = Array<IterSplitExpr>(iters.rbegin(), iters.rend());
+    auto it = sum_fuse_map_.find(expr);
+    if (it != sum_fuse_map_.end()) return it->second;
+    auto mark = IterMark(expr, expected_scale);
+    IterSplitExpr split(mark);
+    sum_fuse_map_[expr] = split;
+    return split;
+  }
+
+  // Whether lhs % rhs == 0 can be proven.
+  bool CanProveDivisible(PrimExpr lhs, PrimExpr rhs) {
+    const auto* clhs = lhs.as<IntImmNode>();
+    const auto* crhs = rhs.as<IntImmNode>();
+    if (clhs && crhs) {
+      // `x % 0` is undefined behavior; divisibility by zero is never provable.
+      if (crhs->value == 0) return false;
+      return clhs->value % crhs->value == 0;
+    }
+    return analyzer_->CanProve(floormod(lhs, rhs) == 0);
+  }
+
+  PrimExpr SplitFloorDivConst(IterSplitExpr lhs, PrimExpr rhs);
+  PrimExpr SplitFloorModConst(IterSplitExpr lhs, PrimExpr rhs);
+
+  // lhs += sign * rhs. A split with the same source, lower_factor and extent
+  // as rhs has its scale combined in place; otherwise rhs is appended (with a
+  // negated scale when sign <= 0).
+  static void AddToLhs(IterSumExprNode* lhs, IterSplitExpr rhs, int sign) {
+    tir::ExprDeepEqual equal;
+    for (size_t i = 0; i < lhs->args.size(); ++i) {
+      IterSplitExpr lvalue = lhs->args[i];
+      if (lvalue->source.same_as(rhs->source) && equal(lvalue->lower_factor, rhs->lower_factor) &&
+          equal(lvalue->extent, rhs->extent)) {
+        if (sign > 0) {
+          rhs.CopyOnWrite()->scale = lvalue->scale + rhs->scale;
+        } else {
+          rhs.CopyOnWrite()->scale = lvalue->scale - rhs->scale;
+        }
+        lhs->args.Set(i, rhs);
+        return;
+      }
+    }
+    if (sign > 0) {
+      lhs->args.push_back(rhs);
+    } else {
+      rhs.CopyOnWrite()->scale = make_zero(rhs->scale.dtype()) - rhs->scale;
+      lhs->args.push_back(rhs);
+    }
+  }
+
+  // lhs += sign * rhs for a whole sum: every split, then the base.
+  static void AddToLhs(IterSumExprNode* lhs, IterSumExpr rhs, int sign) {
+    for (size_t i = 0; i < rhs->args.size(); ++i) {
+      AddToLhs(lhs, rhs->args[i], sign);
+    }
+    if (sign > 0) {
+      lhs->base += rhs->base;
+    } else {
+      lhs->base -= rhs->base;
+    }
+  }
+
+  // lhs *= rhs: scales every split and the base.
+  static void MulToLhs(IterSumExprNode* lhs, PrimExpr rhs) {
+    for (size_t i = 0; i < lhs->args.size(); ++i) {
+      IterSplitExpr lvalue = lhs->args[i];
+      lvalue.CopyOnWrite()->scale *= rhs;
+      lhs->args.Set(i, lvalue);
+    }
+    lhs->base *= rhs;
+  }
+};
+
+Array<IterSumExpr> DetectIterMap(const Array<PrimExpr>& indices, const Map<Var, Range>& input_iters,
+                                 arith::Analyzer* analyzer) {
+  // Overall detection algorithm is divided into two steps:
+  // - Step0: IterMapRewriter rewrites the expression to use IterMapExpr patterns.
+  // - Step1: IterIndependenceChecker checks if the iterator are independent.
+  IterMapRewriter rewriter(analyzer, input_iters);
+  Array<IterSumExpr> results;
+
+  for (PrimExpr value : indices) {
+    results.push_back(rewriter.Rewrite(value));
+    if (rewriter.unresolved_count() != 0) return Array<IterSumExpr>();
+  }
+  if (!rewriter.CheckBijective(results)) return Array<IterSumExpr>();

Review comment:
       Currently, the independence check seems to only verify that all input marks are visited. Is that sufficient to ensure independence?




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] tqchen commented on a change in pull request #6667: [ARITH] Introduce iterator (quasi)affine map detection.

Posted by GitBox <gi...@apache.org>.
tqchen commented on a change in pull request #6667:
URL: https://github.com/apache/incubator-tvm/pull/6667#discussion_r504330166



##########
File path: src/arith/iter_affine_map.cc
##########
@@ -0,0 +1,703 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/arith/iter_affine_map.cc
+ */
+#include <tvm/arith/analyzer.h>
+#include <tvm/arith/iter_affine_map.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/expr_functor.h>
+#include <tvm/tir/op.h>
+
+#include "../support/util.h"
+#include "const_fold.h"
+
+namespace tvm {
+namespace arith {
+
+using namespace tir;
+
+IterMark::IterMark(PrimExpr source, PrimExpr extent) {
+  auto n = make_object<IterMarkNode>();
+  n->source = std::move(source);
+  n->extent = std::move(extent);
+  data_ = std::move(n);
+}
+
+TVM_REGISTER_GLOBAL("arith.IterMark").set_body_typed([](PrimExpr source, PrimExpr extent) {
+  return IterMark(source, extent);
+});
+
+TVM_REGISTER_NODE_TYPE(IterMarkNode);
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<IterMarkNode>([](const ObjectRef& node, ReprPrinter* p) {
+      auto* op = static_cast<const IterMarkNode*>(node.get());
+      p->stream << "IterMark(" << op->source << ", extent=" << op->extent;
+    });
+
+IterSplitExpr::IterSplitExpr(IterMark source) {
+  auto n = make_object<IterSplitExprNode>();
+  auto one = make_const(source->source->dtype, 1);
+  n->dtype = source->source->dtype;
+  n->source = std::move(source);
+  n->extent = n->source->extent;
+  n->lower_factor = one;
+  n->scale = one;
+  data_ = std::move(n);
+}
+
+IterSplitExpr::IterSplitExpr(IterMark source, PrimExpr lower_factor, PrimExpr extent,
+                             PrimExpr scale) {
+  auto n = make_object<IterSplitExprNode>();
+  n->dtype = source->source->dtype;
+  n->source = std::move(source);
+  n->lower_factor = std::move(lower_factor);
+  n->extent = std::move(extent);
+  n->scale = std::move(scale);
+  data_ = std::move(n);
+}
+
+TVM_REGISTER_GLOBAL("arith.IterSplitExpr")
+    .set_body_typed([](IterMark source, PrimExpr lower_factor, PrimExpr extent, PrimExpr scale) {
+      return IterSplitExpr(source, lower_factor, extent, scale);
+    });
+
+TVM_REGISTER_NODE_TYPE(IterSplitExprNode);
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<IterSplitExprNode>([](const ObjectRef& node, ReprPrinter* p) {
+      auto* op = static_cast<const IterSplitExprNode*>(node.get());
+      p->stream << "IterSplit(" << op->source << ", lower_factor=" << op->lower_factor
+                << ", extent=" << op->extent << ", scale=" << op->scale;
+    });
+
+IterSumExpr::IterSumExpr(Array<IterSplitExpr> args, PrimExpr base) {
+  auto n = make_object<IterSumExprNode>();
+  n->dtype = base->dtype;
+  n->args = std::move(args);
+  n->base = std::move(base);
+  data_ = std::move(n);
+}
+
+TVM_REGISTER_GLOBAL("arith.IterSumExpr")
+    .set_body_typed([](Array<IterSplitExpr> args, PrimExpr base) {
+      return IterSumExpr(args, base);
+    });
+
+TVM_REGISTER_NODE_TYPE(IterSumExprNode);
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<IterSumExprNode>([](const ObjectRef& node, ReprPrinter* p) {
+      auto* op = static_cast<const IterSumExprNode*>(node.get());
+      p->stream << "IterSum(" << op->args << ", " << op->base << ")";
+    });
+
+/*!
+ * \brief Util to check if all splits in the sumexpr are
+ *        independent and complete (covers all the original iter space).
+ *
+ */
+class IterMarkSplitCollector {
+ public:
+  // mark all IterMarks that are visited.
+  std::unordered_set<IterMark, ObjectPtrHash, ObjectPtrEqual> visited_;
+  // each iter mark to its outgoing splits that are referenced.
+  std::unordered_map<IterMark, std::vector<IterSplitExpr>, ObjectPtrHash, ObjectPtrEqual>
+      mark2splits_;
+  /*!
+   * \brief Collect all mark2splits recursively from indices.
+   * \param indices The iterator of interest.
+   */
+  void Collect(const Array<IterSumExpr>& indices) {
+    for (IterSumExpr sum_expr : indices) {
+      for (IterSplitExpr split : sum_expr->args) {
+        this->CollectInternal(split->source);
+        mark2splits_[split->source].push_back(split);
+      }
+    }
+  }
+
+  void CollectInternal(const IterMark& mark) {
+    if (visited_.count(mark)) return;
+    visited_.insert(mark);
+    if (auto* op = mark->source.as<IterSumExprNode>()) {
+      for (IterSplitExpr split : op->args) {
+        this->CollectInternal(split->source);
+        mark2splits_[split->source].push_back(split);
+      }
+    }
+  }
+};
+
+// Rewriter to rewrite oinformations in iter
+class IterMapRewriter : public ExprMutator {
+ public:
+  using Parent = ExprMutator;
+
+  explicit IterMapRewriter(Analyzer* analyzer, const Map<Var, Range>& input_iters)
+      : analyzer_(analyzer) {
+    for (auto kv : input_iters) {
+      const auto& vrng = kv.second;
+      if (is_zero(vrng->min)) {
+        IterMark mark(kv.first, vrng->extent);
+        var_map_[kv.first] = IterSplitExpr(mark);
+        input_marks_.push_back(mark);
+      } else {
+        IterMark mark(kv.first - vrng->min, vrng->extent);
+        auto sum_expr = ToIterSumExpr(IterSplitExpr(mark));
+        sum_expr.CopyOnWrite()->base = vrng->min;
+        var_map_[kv.first] = sum_expr;
+        input_marks_.push_back(mark);
+      }
+    }
+  }
+
+  size_t unresolved_count() const { return unresolved_count_; }
+
+  IterSumExpr Rewrite(PrimExpr expr) {
+    return NormalizeToIterWithOffset(ToIterSumExpr(DirectMutate(expr)));
+  }
+
+  bool CheckBijective(const Array<IterSumExpr>& indices) {
+    IterMarkSplitCollector collector;
+    // We can check that for each iter mark:
+    // All the splits that refers to the itermark covers its extent.
+    // The splits do not overlap with each other.
+    collector.Collect(indices);
+    for (IterMark mark : collector.visited_) {
+      if (TryNormalizeSplits(mark, collector.mark2splits_[mark]).size() == 0) return false;
+    }
+    // all input marks must be visited
+    for (auto mark : input_marks_) {
+      if (collector.visited_.count(mark) == 0) return false;
+    }
+    return true;
+  }
+
+  // override the original mutate function.
+  PrimExpr VisitExpr(const PrimExpr& input_expr) final {
+    auto expr = ExprMutator::VisitExpr(input_expr);
+    if (expr->IsInstance<IterMapExprNode>()) {
+      ++unresolved_count_;
+    }
+    return expr;
+  }
+
+  // Normal mutation without normalization.
+  PrimExpr DirectMutate(PrimExpr expr) { return ExprMutator::VisitExpr(expr); }
+
+  PrimExpr VisitExpr_(const VarNode* op) final;
+  PrimExpr VisitExpr_(const AddNode* op) final;
+  PrimExpr VisitExpr_(const SubNode* op) final;
+  PrimExpr VisitExpr_(const MulNode* op) final;
+  PrimExpr VisitExpr_(const FloorDivNode* op) final;
+  PrimExpr VisitExpr_(const FloorModNode* op) final;
+
+ private:
+  // temp hash for de-duplication purposes.
+  struct IterSumHash {
+    size_t operator()(const IterSumExpr& value) const {
+      // for now only hash on source index.
+      size_t hash = value->args.size();
+      for (size_t i = 0; i < value->args.size(); ++i) {
+        hash = support::HashCombine(hash, std::hash<const Object*>()(value->args[i]->source.get()));
+      }
+      return hash;
+    }
+  };
+
+  struct IterSumEqual {
+    bool operator()(const IterSumExpr& lhs, const IterSumExpr& rhs) const {
+      tir::ExprDeepEqual equal;
+      if (lhs->args.size() != rhs->args.size()) return false;
+      if (!equal(lhs->base, rhs->base)) return false;
+      for (size_t i = 0; i < lhs->args.size(); ++i) {
+        auto lvalue = lhs->args[i];
+        auto rvalue = lhs->args[i];
+        if (!lvalue->source.same_as(rvalue->source)) return false;
+        if (!equal(lvalue->lower_factor, rvalue->lower_factor)) return false;
+        if (!equal(lvalue->scale, rvalue->scale)) return false;
+        if (!equal(lvalue->extent, rvalue->extent)) return false;
+      }
+      return true;
+    }
+  };
+
+  // Internal analyzer
+  Analyzer* analyzer_;
+  // Counter to keep track of unresolved cases.
+  int unresolved_count_{0};
+  // The var map
+  std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> var_map_;
+  // input iter marks
+  std::vector<IterMark> input_marks_;
+  // The canonical map for sum
+  std::unordered_map<IterSumExpr, IterSplitExpr, IterSumHash, IterSumEqual> sum_fuse_map_;
+
+  /*!
+   * \brief Verify that splits fully covers mark in a non-overlapping fashion.
+   *        If verification passes, return splits from outermost to inner most order.
+   *        If not, return an empty array
+   * \param mark The iterator of interest.
+   * \param splits The splits to be verified.
+   * \return The normalized splits.
+   */
+  Array<IterSplitExpr> TryNormalizeSplits(const IterMark& mark,
+                                          const std::vector<IterSplitExpr>& splits) {
+    std::vector<bool> used(splits.size(), false);
+    std::vector<IterSplitExpr> iters;
+    PrimExpr expected_lower_factor = make_const(mark->source->dtype, 1);
+
+    for (size_t i = 0; i < splits.size(); ++i) {
+      size_t j = 0;
+      for (; j < splits.size(); ++j) {
+        if (used[j]) continue;
+        if (!used[j] && CanProveEqual(splits[j]->lower_factor, expected_lower_factor)) break;
+      }
+      if (j == splits.size()) {
+        return Array<IterSplitExpr>();
+      }
+      used[j] = true;
+      iters.push_back(splits[j]);
+      expected_lower_factor *= splits[j]->extent;
+    }
+    if (!CanProveEqual(expected_lower_factor, mark->extent)) return Array<IterSplitExpr>();
+    return Array<IterSplitExpr>(iters.rbegin(), iters.rend());
+  }
+
+  /*!
+   * \brief Normalize expr to an iterator + offset.
+   * \param expr The input expression.
+   * \return The Normalized expression.
+   */
+  IterSumExpr NormalizeToIterWithOffset(IterSumExpr expr) {
+    if (expr->args.size() <= 1) return expr;
+    PrimExpr base = expr->base;
+    expr.CopyOnWrite()->base = make_zero(expr->dtype);
+    auto opt = TryFuseIters(expr);
+    expr.CopyOnWrite()->base = base;
+    if (opt) {
+      expr.CopyOnWrite()->args = Array<IterSplitExpr>({opt.value()});
+      return expr;
+    } else {
+      ++unresolved_count_;
+      return expr;
+    }
+  }
+
+  bool CanProveEqual(PrimExpr lhs, PrimExpr rhs) {
+    const auto* clhs = lhs.as<IntImmNode>();
+    const auto* crhs = rhs.as<IntImmNode>();
+    if (clhs && crhs) return clhs->value == crhs->value;
+    return analyzer_->CanProve(lhs - rhs == 0);
+  }
+
+  /*!
+   * \brief Create a IterSumExpr from expr.
+   * \param expr The input expr.
+   * \return The transformed IterSumExpr.
+   */
+  IterSumExpr ToIterSumExpr(PrimExpr expr) {
+    if (const auto* op = expr.as<IterSumExprNode>()) {
+      return GetRef<IterSumExpr>(op);
+    } else if (const auto* op = expr.as<IterSplitExprNode>()) {
+      return IterSumExpr({GetRef<IterSplitExpr>(op)}, make_zero(expr->dtype));
+    } else {
+      CHECK(!expr->IsInstance<IterMapExprNode>());
+      return IterSumExpr({}, expr);
+    }
+  }
+
+  // Try to normalize IterSum into a fused IterMark
+  // return a corresponding splitexpr if needed.
+  Optional<IterSplitExpr> TryFuseIters(IterSumExpr expr) {
+    if (!is_zero(expr->base)) return NullOpt;
+    if (expr->args.size() == 1) return expr->args[0];
+    // select the iterators in order
+    std::vector<bool> visited(expr->args.size(), false);
+    std::vector<IterSplitExpr> iters;
+    iters.reserve(expr->args.size());
+    // canonicalize the expression
+    // check if it can be remapped into a fused pattern.
+    PrimExpr expected_scale = make_const(expr->base->dtype, 1);
+    for (size_t i = 0; i < expr->args.size(); ++i) {
+      size_t j = 0;
+      for (; j < expr->args.size(); ++j) {
+        if (!visited[j] && CanProveEqual(expr->args[j]->scale, expected_scale)) break;
+      }
+      if (j == expr->args.size()) {
+        return NullOpt;
+      }
+      visited[j] = true;
+      iters.push_back(expr->args[j]);
+      expected_scale *= expr->args[j]->extent;
+    }
+    // update the iterator to use the canonicalized form
+    expr.CopyOnWrite()->args = Array<IterSplitExpr>(iters.rbegin(), iters.rend());
+    auto it = sum_fuse_map_.find(expr);
+    if (it != sum_fuse_map_.end()) return it->second;
+    auto mark = IterMark(expr, expected_scale);
+    IterSplitExpr split(mark);
+    sum_fuse_map_[expr] = split;
+    return split;
+  }
+
+  bool CanProveDivisible(PrimExpr lhs, PrimExpr rhs) {
+    const auto* clhs = lhs.as<IntImmNode>();
+    const auto* crhs = rhs.as<IntImmNode>();
+    if (clhs && crhs) return clhs->value % crhs->value == 0;
+    return analyzer_->CanProve(floormod(lhs, rhs) == 0);
+  }
+
+  PrimExpr SplitFloorDivConst(IterSplitExpr lhs, PrimExpr rhs);
+  PrimExpr SplitFloorModConst(IterSplitExpr lhs, PrimExpr rhs);
+
+  static void AddToLhs(IterSumExprNode* lhs, IterSplitExpr rhs, int sign) {
+    tir::ExprDeepEqual equal;
+    for (size_t i = 0; i < lhs->args.size(); ++i) {
+      IterSplitExpr lvalue = lhs->args[i];
+      if (lvalue->source.same_as(rhs->source) && equal(lvalue->lower_factor, rhs->lower_factor) &&
+          equal(lvalue->extent, rhs->extent)) {
+        if (sign > 0) {
+          rhs.CopyOnWrite()->scale = lvalue->scale + rhs->scale;
+        } else {
+          rhs.CopyOnWrite()->scale = lvalue->scale - rhs->scale;
+        }
+        lhs->args.Set(i, rhs);
+        return;
+      }
+    }
+    if (sign > 0) {
+      lhs->args.push_back(rhs);
+    } else {
+      rhs.CopyOnWrite()->scale = make_zero(rhs->scale.dtype()) - rhs->scale;
+      lhs->args.push_back(rhs);
+    }
+  }
+
+  static void AddToLhs(IterSumExprNode* lhs, IterSumExpr rhs, int sign) {
+    for (size_t i = 0; i < rhs->args.size(); ++i) {
+      AddToLhs(lhs, rhs->args[i], sign);
+    }
+    if (sign > 0) {
+      lhs->base += rhs->base;
+    } else {
+      lhs->base -= rhs->base;
+    }
+  }
+
+  static void MulToLhs(IterSumExprNode* lhs, PrimExpr rhs) {
+    for (size_t i = 0; i < lhs->args.size(); ++i) {
+      IterSplitExpr lvalue = lhs->args[i];
+      lvalue.CopyOnWrite()->scale *= rhs;
+      lhs->args.Set(i, lvalue);
+    }
+    lhs->base *= rhs;
+  }
+};
+
+Array<IterSumExpr> DetectIterMap(const Array<PrimExpr>& indices, const Map<Var, Range>& input_iters,
+                                 arith::Analyzer* analyzer) {
+  // Overall detection algorithm is divided into two steps:
+  // - Step0: IterMapRewriter rewrites the expression to use IterMapExpr patterns.
+  // - Step1: IterIndependenceChecker checks if the iterator are independent.
+  IterMapRewriter rewriter(analyzer, input_iters);
+  Array<IterSumExpr> results;
+
+  for (PrimExpr value : indices) {
+    results.push_back(rewriter.Rewrite(value));
+    if (rewriter.unresolved_count() != 0) return Array<IterSumExpr>();
+  }
+  if (!rewriter.CheckBijective(results)) return Array<IterSumExpr>();

Review comment:
       Yes. In this case, x and x+1 will contribute two Split entries to mark2split of x:
   ```
   mark2split[x]= [ Split(x, extent), Split(x, extent)]
   ```
   
   This will result in an error because the two splits overlap with each other.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] tqchen commented on a change in pull request #6667: [ARITH] Introduce iterator (quasi)affine map detection.

Posted by GitBox <gi...@apache.org>.
tqchen commented on a change in pull request #6667:
URL: https://github.com/apache/incubator-tvm/pull/6667#discussion_r504256910



##########
File path: src/arith/iter_affine_map.cc
##########
@@ -0,0 +1,703 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/arith/iter_affine_map.cc
+ */
+#include <tvm/arith/analyzer.h>
+#include <tvm/arith/iter_affine_map.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/expr_functor.h>
+#include <tvm/tir/op.h>
+
+#include "../support/util.h"
+#include "const_fold.h"
+
+namespace tvm {
+namespace arith {
+
+using namespace tir;
+
+IterMark::IterMark(PrimExpr source, PrimExpr extent) {
+  auto n = make_object<IterMarkNode>();
+  n->source = std::move(source);
+  n->extent = std::move(extent);
+  data_ = std::move(n);
+}
+
+TVM_REGISTER_GLOBAL("arith.IterMark").set_body_typed([](PrimExpr source, PrimExpr extent) {
+  return IterMark(source, extent);
+});
+
+TVM_REGISTER_NODE_TYPE(IterMarkNode);
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<IterMarkNode>([](const ObjectRef& node, ReprPrinter* p) {
+      auto* op = static_cast<const IterMarkNode*>(node.get());
+      p->stream << "IterMark(" << op->source << ", extent=" << op->extent;
+    });
+
+IterSplitExpr::IterSplitExpr(IterMark source) {
+  auto n = make_object<IterSplitExprNode>();
+  auto one = make_const(source->source->dtype, 1);
+  n->dtype = source->source->dtype;
+  n->source = std::move(source);
+  n->extent = n->source->extent;
+  n->lower_factor = one;
+  n->scale = one;
+  data_ = std::move(n);
+}
+
+IterSplitExpr::IterSplitExpr(IterMark source, PrimExpr lower_factor, PrimExpr extent,
+                             PrimExpr scale) {
+  auto n = make_object<IterSplitExprNode>();
+  n->dtype = source->source->dtype;
+  n->source = std::move(source);
+  n->lower_factor = std::move(lower_factor);
+  n->extent = std::move(extent);
+  n->scale = std::move(scale);
+  data_ = std::move(n);
+}
+
+TVM_REGISTER_GLOBAL("arith.IterSplitExpr")
+    .set_body_typed([](IterMark source, PrimExpr lower_factor, PrimExpr extent, PrimExpr scale) {
+      return IterSplitExpr(source, lower_factor, extent, scale);
+    });
+
+TVM_REGISTER_NODE_TYPE(IterSplitExprNode);
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<IterSplitExprNode>([](const ObjectRef& node, ReprPrinter* p) {
+      auto* op = static_cast<const IterSplitExprNode*>(node.get());
+      p->stream << "IterSplit(" << op->source << ", lower_factor=" << op->lower_factor
+                << ", extent=" << op->extent << ", scale=" << op->scale;
+    });
+
+IterSumExpr::IterSumExpr(Array<IterSplitExpr> args, PrimExpr base) {
+  auto n = make_object<IterSumExprNode>();
+  n->dtype = base->dtype;
+  n->args = std::move(args);
+  n->base = std::move(base);
+  data_ = std::move(n);
+}
+
+TVM_REGISTER_GLOBAL("arith.IterSumExpr")
+    .set_body_typed([](Array<IterSplitExpr> args, PrimExpr base) {
+      return IterSumExpr(args, base);
+    });
+
+TVM_REGISTER_NODE_TYPE(IterSumExprNode);
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<IterSumExprNode>([](const ObjectRef& node, ReprPrinter* p) {
+      auto* op = static_cast<const IterSumExprNode*>(node.get());
+      p->stream << "IterSum(" << op->args << ", " << op->base << ")";
+    });
+
+/*!
+ * \brief Util to check if all splits in the sumexpr are
+ *        independent and complete (covers all the original iter space).
+ *
+ */
+class IterMarkSplitCollector {
+ public:
+  // mark all IterMarks that are visited.
+  std::unordered_set<IterMark, ObjectPtrHash, ObjectPtrEqual> visited_;
+  // each iter mark to its outgoing splits that are referenced.
+  std::unordered_map<IterMark, std::vector<IterSplitExpr>, ObjectPtrHash, ObjectPtrEqual>
+      mark2splits_;
+  /*!
+   * \brief Collect all mark2splits recursively from indices.
+   * \param indices The iterator of interest.
+   */
+  void Collect(const Array<IterSumExpr>& indices) {
+    for (IterSumExpr sum_expr : indices) {
+      for (IterSplitExpr split : sum_expr->args) {
+        this->CollectInternal(split->source);
+        mark2splits_[split->source].push_back(split);
+      }
+    }
+  }
+
+  void CollectInternal(const IterMark& mark) {
+    if (visited_.count(mark)) return;
+    visited_.insert(mark);
+    if (auto* op = mark->source.as<IterSumExprNode>()) {
+      for (IterSplitExpr split : op->args) {
+        this->CollectInternal(split->source);
+        mark2splits_[split->source].push_back(split);
+      }
+    }
+  }
+};
+
+// Rewriter to rewrite oinformations in iter
+class IterMapRewriter : public ExprMutator {
+ public:
+  using Parent = ExprMutator;
+
+  explicit IterMapRewriter(Analyzer* analyzer, const Map<Var, Range>& input_iters)
+      : analyzer_(analyzer) {
+    for (auto kv : input_iters) {
+      const auto& vrng = kv.second;
+      if (is_zero(vrng->min)) {
+        IterMark mark(kv.first, vrng->extent);
+        var_map_[kv.first] = IterSplitExpr(mark);
+        input_marks_.push_back(mark);
+      } else {
+        IterMark mark(kv.first - vrng->min, vrng->extent);
+        auto sum_expr = ToIterSumExpr(IterSplitExpr(mark));
+        sum_expr.CopyOnWrite()->base = vrng->min;
+        var_map_[kv.first] = sum_expr;
+        input_marks_.push_back(mark);
+      }
+    }
+  }
+
+  size_t unresolved_count() const { return unresolved_count_; }
+
+  IterSumExpr Rewrite(PrimExpr expr) {
+    return NormalizeToIterWithOffset(ToIterSumExpr(DirectMutate(expr)));
+  }
+
+  bool CheckBijective(const Array<IterSumExpr>& indices) {
+    IterMarkSplitCollector collector;
+    // We can check that for each iter mark:
+    // All the splits that refers to the itermark covers its extent.
+    // The splits do not overlap with each other.
+    collector.Collect(indices);
+    for (IterMark mark : collector.visited_) {
+      if (TryNormalizeSplits(mark, collector.mark2splits_[mark]).size() == 0) return false;
+    }
+    // all input marks must be visited
+    for (auto mark : input_marks_) {
+      if (collector.visited_.count(mark) == 0) return false;
+    }
+    return true;
+  }
+
+  // override the original mutate function.
+  PrimExpr VisitExpr(const PrimExpr& input_expr) final {
+    auto expr = ExprMutator::VisitExpr(input_expr);
+    if (expr->IsInstance<IterMapExprNode>()) {
+      ++unresolved_count_;
+    }
+    return expr;
+  }
+
+  // Normal mutation without normalization.
+  PrimExpr DirectMutate(PrimExpr expr) { return ExprMutator::VisitExpr(expr); }
+
+  PrimExpr VisitExpr_(const VarNode* op) final;
+  PrimExpr VisitExpr_(const AddNode* op) final;
+  PrimExpr VisitExpr_(const SubNode* op) final;
+  PrimExpr VisitExpr_(const MulNode* op) final;
+  PrimExpr VisitExpr_(const FloorDivNode* op) final;
+  PrimExpr VisitExpr_(const FloorModNode* op) final;
+
+ private:
+  // temp hash for de-duplication purposes.
+  struct IterSumHash {
+    size_t operator()(const IterSumExpr& value) const {
+      // for now only hash on source index.
+      size_t hash = value->args.size();
+      for (size_t i = 0; i < value->args.size(); ++i) {
+        hash = support::HashCombine(hash, std::hash<const Object*>()(value->args[i]->source.get()));
+      }
+      return hash;
+    }
+  };
+
+  struct IterSumEqual {
+    bool operator()(const IterSumExpr& lhs, const IterSumExpr& rhs) const {
+      tir::ExprDeepEqual equal;
+      if (lhs->args.size() != rhs->args.size()) return false;
+      if (!equal(lhs->base, rhs->base)) return false;
+      for (size_t i = 0; i < lhs->args.size(); ++i) {
+        auto lvalue = lhs->args[i];
+        auto rvalue = lhs->args[i];
+        if (!lvalue->source.same_as(rvalue->source)) return false;
+        if (!equal(lvalue->lower_factor, rvalue->lower_factor)) return false;
+        if (!equal(lvalue->scale, rvalue->scale)) return false;
+        if (!equal(lvalue->extent, rvalue->extent)) return false;
+      }
+      return true;
+    }
+  };
+
+  // Internal analyzer
+  Analyzer* analyzer_;
+  // Counter to keep track of unresolved cases.
+  int unresolved_count_{0};
+  // The var map
+  std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> var_map_;
+  // input iter marks
+  std::vector<IterMark> input_marks_;
+  // The canonical map for sum
+  std::unordered_map<IterSumExpr, IterSplitExpr, IterSumHash, IterSumEqual> sum_fuse_map_;
+
+  /*!
+   * \brief Verify that splits fully covers mark in a non-overlapping fashion.
+   *        If verification passes, return splits from outermost to inner most order.
+   *        If not, return an empty array
+   * \param mark The iterator of interest.
+   * \param splits The splits to be verified.
+   * \return The normalized splits.
+   */
+  Array<IterSplitExpr> TryNormalizeSplits(const IterMark& mark,
+                                          const std::vector<IterSplitExpr>& splits) {
+    std::vector<bool> used(splits.size(), false);
+    std::vector<IterSplitExpr> iters;
+    PrimExpr expected_lower_factor = make_const(mark->source->dtype, 1);
+
+    for (size_t i = 0; i < splits.size(); ++i) {
+      size_t j = 0;
+      for (; j < splits.size(); ++j) {
+        if (used[j]) continue;
+        if (!used[j] && CanProveEqual(splits[j]->lower_factor, expected_lower_factor)) break;
+      }
+      if (j == splits.size()) {
+        return Array<IterSplitExpr>();
+      }
+      used[j] = true;
+      iters.push_back(splits[j]);
+      expected_lower_factor *= splits[j]->extent;
+    }
+    if (!CanProveEqual(expected_lower_factor, mark->extent)) return Array<IterSplitExpr>();
+    return Array<IterSplitExpr>(iters.rbegin(), iters.rend());
+  }
+
+  /*!
+   * \brief Normalize expr to an iterator + offset.
+   * \param expr The input expression.
+   * \return The Normalized expression.
+   */
+  IterSumExpr NormalizeToIterWithOffset(IterSumExpr expr) {
+    if (expr->args.size() <= 1) return expr;
+    PrimExpr base = expr->base;
+    expr.CopyOnWrite()->base = make_zero(expr->dtype);
+    auto opt = TryFuseIters(expr);
+    expr.CopyOnWrite()->base = base;
+    if (opt) {
+      expr.CopyOnWrite()->args = Array<IterSplitExpr>({opt.value()});
+      return expr;
+    } else {
+      ++unresolved_count_;
+      return expr;
+    }
+  }
+
+  bool CanProveEqual(PrimExpr lhs, PrimExpr rhs) {
+    const auto* clhs = lhs.as<IntImmNode>();
+    const auto* crhs = rhs.as<IntImmNode>();
+    if (clhs && crhs) return clhs->value == crhs->value;
+    return analyzer_->CanProve(lhs - rhs == 0);
+  }
+
+  /*!
+   * \brief Create a IterSumExpr from expr.
+   * \param expr The input expr.
+   * \return The transformed IterSumExpr.
+   */
+  IterSumExpr ToIterSumExpr(PrimExpr expr) {
+    if (const auto* op = expr.as<IterSumExprNode>()) {
+      return GetRef<IterSumExpr>(op);
+    } else if (const auto* op = expr.as<IterSplitExprNode>()) {
+      return IterSumExpr({GetRef<IterSplitExpr>(op)}, make_zero(expr->dtype));
+    } else {
+      CHECK(!expr->IsInstance<IterMapExprNode>());
+      return IterSumExpr({}, expr);
+    }
+  }
+
+  // Try to normalize IterSum into a fused IterMark
+  // return a corresponding splitexpr if needed.
+  Optional<IterSplitExpr> TryFuseIters(IterSumExpr expr) {
+    if (!is_zero(expr->base)) return NullOpt;
+    if (expr->args.size() == 1) return expr->args[0];
+    // select the iterators in order
+    std::vector<bool> visited(expr->args.size(), false);
+    std::vector<IterSplitExpr> iters;
+    iters.reserve(expr->args.size());
+    // canonicalize the expression
+    // check if it can be remapped into a fused pattern.
+    PrimExpr expected_scale = make_const(expr->base->dtype, 1);
+    for (size_t i = 0; i < expr->args.size(); ++i) {
+      size_t j = 0;
+      for (; j < expr->args.size(); ++j) {
+        if (!visited[j] && CanProveEqual(expr->args[j]->scale, expected_scale)) break;
+      }
+      if (j == expr->args.size()) {
+        return NullOpt;
+      }
+      visited[j] = true;
+      iters.push_back(expr->args[j]);
+      expected_scale *= expr->args[j]->extent;
+    }
+    // update the iterator to use the canonicalized form
+    expr.CopyOnWrite()->args = Array<IterSplitExpr>(iters.rbegin(), iters.rend());
+    auto it = sum_fuse_map_.find(expr);
+    if (it != sum_fuse_map_.end()) return it->second;
+    auto mark = IterMark(expr, expected_scale);
+    IterSplitExpr split(mark);
+    sum_fuse_map_[expr] = split;
+    return split;
+  }
+
+  bool CanProveDivisible(PrimExpr lhs, PrimExpr rhs) {
+    const auto* clhs = lhs.as<IntImmNode>();
+    const auto* crhs = rhs.as<IntImmNode>();
+    if (clhs && crhs) return clhs->value % crhs->value == 0;
+    return analyzer_->CanProve(floormod(lhs, rhs) == 0);
+  }
+
+  PrimExpr SplitFloorDivConst(IterSplitExpr lhs, PrimExpr rhs);
+  PrimExpr SplitFloorModConst(IterSplitExpr lhs, PrimExpr rhs);
+
+  static void AddToLhs(IterSumExprNode* lhs, IterSplitExpr rhs, int sign) {
+    tir::ExprDeepEqual equal;
+    for (size_t i = 0; i < lhs->args.size(); ++i) {
+      IterSplitExpr lvalue = lhs->args[i];
+      if (lvalue->source.same_as(rhs->source) && equal(lvalue->lower_factor, rhs->lower_factor) &&
+          equal(lvalue->extent, rhs->extent)) {
+        if (sign > 0) {
+          rhs.CopyOnWrite()->scale = lvalue->scale + rhs->scale;
+        } else {
+          rhs.CopyOnWrite()->scale = lvalue->scale - rhs->scale;
+        }
+        lhs->args.Set(i, rhs);
+        return;
+      }
+    }
+    if (sign > 0) {
+      lhs->args.push_back(rhs);
+    } else {
+      rhs.CopyOnWrite()->scale = make_zero(rhs->scale.dtype()) - rhs->scale;
+      lhs->args.push_back(rhs);
+    }
+  }
+
+  static void AddToLhs(IterSumExprNode* lhs, IterSumExpr rhs, int sign) {
+    for (size_t i = 0; i < rhs->args.size(); ++i) {
+      AddToLhs(lhs, rhs->args[i], sign);
+    }
+    if (sign > 0) {
+      lhs->base += rhs->base;
+    } else {
+      lhs->base -= rhs->base;
+    }
+  }
+
+  static void MulToLhs(IterSumExprNode* lhs, PrimExpr rhs) {
+    for (size_t i = 0; i < lhs->args.size(); ++i) {
+      IterSplitExpr lvalue = lhs->args[i];
+      lvalue.CopyOnWrite()->scale *= rhs;
+      lhs->args.Set(i, lvalue);
+    }
+    lhs->base *= rhs;
+  }
+};
+
+Array<IterSumExpr> DetectIterMap(const Array<PrimExpr>& indices, const Map<Var, Range>& input_iters,
+                                 arith::Analyzer* analyzer) {
+  // Overall detection algorithm is divided into two steps:
+  // - Step0: IterMapRewriter rewrites the expression to use IterMapExpr patterns.
+  // - Step1: IterIndependenceChecker checks if the iterator are independent.
+  IterMapRewriter rewriter(analyzer, input_iters);
+  Array<IterSumExpr> results;
+
+  for (PrimExpr value : indices) {
+    results.push_back(rewriter.Rewrite(value));
+    if (rewriter.unresolved_count() != 0) return Array<IterSumExpr>();
+  }
+  if (!rewriter.CheckBijective(results)) return Array<IterSumExpr>();

Review comment:
       It also checks that all the intermediate marks (including the inputs) are covered without overlap.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] tqchen merged pull request #6667: [ARITH] Introduce iterator (quasi)affine map detection.

Posted by GitBox <gi...@apache.org>.
tqchen merged pull request #6667:
URL: https://github.com/apache/incubator-tvm/pull/6667


   


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] tqchen commented on pull request #6667: [ARITH] Introduce iterator (quasi)affine map detection.

Posted by GitBox <gi...@apache.org>.
tqchen commented on pull request #6667:
URL: https://github.com/apache/incubator-tvm/pull/6667#issuecomment-708084525


   Thanks @spectrometerHBH. I addressed your comment and added more explanation about the bijective check. I also added your example as a test case.


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] spectrometerHBH commented on a change in pull request #6667: [ARITH] Introduce iterator (quasi)affine map detection.

Posted by GitBox <gi...@apache.org>.
spectrometerHBH commented on a change in pull request #6667:
URL: https://github.com/apache/incubator-tvm/pull/6667#discussion_r504328794



##########
File path: src/arith/iter_affine_map.cc
##########
@@ -0,0 +1,703 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/arith/iter_affine_map.cc
+ */
+#include <tvm/arith/analyzer.h>
+#include <tvm/arith/iter_affine_map.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/expr_functor.h>
+#include <tvm/tir/op.h>
+
+#include "../support/util.h"
+#include "const_fold.h"
+
+namespace tvm {
+namespace arith {
+
+using namespace tir;
+
+IterMark::IterMark(PrimExpr source, PrimExpr extent) {
+  auto n = make_object<IterMarkNode>();
+  n->source = std::move(source);
+  n->extent = std::move(extent);
+  data_ = std::move(n);
+}
+
+TVM_REGISTER_GLOBAL("arith.IterMark").set_body_typed([](PrimExpr source, PrimExpr extent) {
+  return IterMark(source, extent);
+});
+
+TVM_REGISTER_NODE_TYPE(IterMarkNode);
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<IterMarkNode>([](const ObjectRef& node, ReprPrinter* p) {
+      auto* op = static_cast<const IterMarkNode*>(node.get());
+      p->stream << "IterMark(" << op->source << ", extent=" << op->extent;
+    });
+
+IterSplitExpr::IterSplitExpr(IterMark source) {
+  auto n = make_object<IterSplitExprNode>();
+  auto one = make_const(source->source->dtype, 1);
+  n->dtype = source->source->dtype;
+  n->source = std::move(source);
+  n->extent = n->source->extent;
+  n->lower_factor = one;
+  n->scale = one;
+  data_ = std::move(n);
+}
+
+IterSplitExpr::IterSplitExpr(IterMark source, PrimExpr lower_factor, PrimExpr extent,
+                             PrimExpr scale) {
+  auto n = make_object<IterSplitExprNode>();
+  n->dtype = source->source->dtype;
+  n->source = std::move(source);
+  n->lower_factor = std::move(lower_factor);
+  n->extent = std::move(extent);
+  n->scale = std::move(scale);
+  data_ = std::move(n);
+}
+
+TVM_REGISTER_GLOBAL("arith.IterSplitExpr")
+    .set_body_typed([](IterMark source, PrimExpr lower_factor, PrimExpr extent, PrimExpr scale) {
+      return IterSplitExpr(source, lower_factor, extent, scale);
+    });
+
+TVM_REGISTER_NODE_TYPE(IterSplitExprNode);
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<IterSplitExprNode>([](const ObjectRef& node, ReprPrinter* p) {
+      auto* op = static_cast<const IterSplitExprNode*>(node.get());
+      p->stream << "IterSplit(" << op->source << ", lower_factor=" << op->lower_factor
+                << ", extent=" << op->extent << ", scale=" << op->scale;
+    });
+
+IterSumExpr::IterSumExpr(Array<IterSplitExpr> args, PrimExpr base) {
+  auto n = make_object<IterSumExprNode>();
+  n->dtype = base->dtype;
+  n->args = std::move(args);
+  n->base = std::move(base);
+  data_ = std::move(n);
+}
+
+TVM_REGISTER_GLOBAL("arith.IterSumExpr")
+    .set_body_typed([](Array<IterSplitExpr> args, PrimExpr base) {
+      return IterSumExpr(args, base);
+    });
+
+TVM_REGISTER_NODE_TYPE(IterSumExprNode);
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<IterSumExprNode>([](const ObjectRef& node, ReprPrinter* p) {
+      auto* op = static_cast<const IterSumExprNode*>(node.get());
+      p->stream << "IterSum(" << op->args << ", " << op->base << ")";
+    });
+
+/*!
+ * \brief Util to check if all splits in the sumexpr are
+ *        independent and complete (covers all the original iter space).
+ *
+ */
+class IterMarkSplitCollector {
+ public:
+  // mark all IterMarks that are visited.
+  std::unordered_set<IterMark, ObjectPtrHash, ObjectPtrEqual> visited_;
+  // each iter mark to its outgoing splits that are referenced.
+  std::unordered_map<IterMark, std::vector<IterSplitExpr>, ObjectPtrHash, ObjectPtrEqual>
+      mark2splits_;
+  /*!
+   * \brief Collect all mark2splits recursively from indices.
+   * \param indices The iterator of interest.
+   */
+  void Collect(const Array<IterSumExpr>& indices) {
+    for (IterSumExpr sum_expr : indices) {
+      for (IterSplitExpr split : sum_expr->args) {
+        this->CollectInternal(split->source);
+        mark2splits_[split->source].push_back(split);
+      }
+    }
+  }
+
+  void CollectInternal(const IterMark& mark) {
+    if (visited_.count(mark)) return;
+    visited_.insert(mark);
+    if (auto* op = mark->source.as<IterSumExprNode>()) {
+      for (IterSplitExpr split : op->args) {
+        this->CollectInternal(split->source);
+        mark2splits_[split->source].push_back(split);
+      }
+    }
+  }
+};
+
+// Rewriter to rewrite oinformations in iter
+class IterMapRewriter : public ExprMutator {
+ public:
+  using Parent = ExprMutator;
+
+  explicit IterMapRewriter(Analyzer* analyzer, const Map<Var, Range>& input_iters)
+      : analyzer_(analyzer) {
+    for (auto kv : input_iters) {
+      const auto& vrng = kv.second;
+      if (is_zero(vrng->min)) {
+        IterMark mark(kv.first, vrng->extent);
+        var_map_[kv.first] = IterSplitExpr(mark);
+        input_marks_.push_back(mark);
+      } else {
+        IterMark mark(kv.first - vrng->min, vrng->extent);
+        auto sum_expr = ToIterSumExpr(IterSplitExpr(mark));
+        sum_expr.CopyOnWrite()->base = vrng->min;
+        var_map_[kv.first] = sum_expr;
+        input_marks_.push_back(mark);
+      }
+    }
+  }
+
+  size_t unresolved_count() const { return unresolved_count_; }
+
+  IterSumExpr Rewrite(PrimExpr expr) {
+    return NormalizeToIterWithOffset(ToIterSumExpr(DirectMutate(expr)));
+  }
+
+  bool CheckBijective(const Array<IterSumExpr>& indices) {
+    IterMarkSplitCollector collector;
+    // We can check that for each iter mark:
+    // All the splits that refers to the itermark covers its extent.
+    // The splits do not overlap with each other.
+    collector.Collect(indices);
+    for (IterMark mark : collector.visited_) {
+      if (TryNormalizeSplits(mark, collector.mark2splits_[mark]).size() == 0) return false;
+    }
+    // all input marks must be visited
+    for (auto mark : input_marks_) {
+      if (collector.visited_.count(mark) == 0) return false;
+    }
+    return true;
+  }
+
+  // override the original mutate function.
+  PrimExpr VisitExpr(const PrimExpr& input_expr) final {
+    auto expr = ExprMutator::VisitExpr(input_expr);
+    if (expr->IsInstance<IterMapExprNode>()) {
+      ++unresolved_count_;
+    }
+    return expr;
+  }
+
+  // Normal mutation without normalization.
+  PrimExpr DirectMutate(PrimExpr expr) { return ExprMutator::VisitExpr(expr); }
+
+  PrimExpr VisitExpr_(const VarNode* op) final;
+  PrimExpr VisitExpr_(const AddNode* op) final;
+  PrimExpr VisitExpr_(const SubNode* op) final;
+  PrimExpr VisitExpr_(const MulNode* op) final;
+  PrimExpr VisitExpr_(const FloorDivNode* op) final;
+  PrimExpr VisitExpr_(const FloorModNode* op) final;
+
+ private:
+  // temp hash for de-duplication purposes.
+  struct IterSumHash {
+    size_t operator()(const IterSumExpr& value) const {
+      // for now only hash on source index.
+      size_t hash = value->args.size();
+      for (size_t i = 0; i < value->args.size(); ++i) {
+        hash = support::HashCombine(hash, std::hash<const Object*>()(value->args[i]->source.get()));
+      }
+      return hash;
+    }
+  };
+
+  struct IterSumEqual {
+    bool operator()(const IterSumExpr& lhs, const IterSumExpr& rhs) const {
+      tir::ExprDeepEqual equal;
+      if (lhs->args.size() != rhs->args.size()) return false;
+      if (!equal(lhs->base, rhs->base)) return false;
+      for (size_t i = 0; i < lhs->args.size(); ++i) {
+        auto lvalue = lhs->args[i];
+        auto rvalue = lhs->args[i];
+        if (!lvalue->source.same_as(rvalue->source)) return false;
+        if (!equal(lvalue->lower_factor, rvalue->lower_factor)) return false;
+        if (!equal(lvalue->scale, rvalue->scale)) return false;
+        if (!equal(lvalue->extent, rvalue->extent)) return false;
+      }
+      return true;
+    }
+  };
+
+  // Internal analyzer
+  Analyzer* analyzer_;
+  // Counter to keep track of unresolved cases.
+  int unresolved_count_{0};
+  // The var map
+  std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> var_map_;
+  // input iter marks
+  std::vector<IterMark> input_marks_;
+  // The canonical map for sum
+  std::unordered_map<IterSumExpr, IterSplitExpr, IterSumHash, IterSumEqual> sum_fuse_map_;
+
+  /*!
+   * \brief Verify that splits fully covers mark in a non-overlapping fashion.
+   *        If verification passes, return splits from outermost to inner most order.
+   *        If not, return an empty array
+   * \param mark The iterator of interest.
+   * \param splits The splits to be verified.
+   * \return The normalized splits.
+   */
+  Array<IterSplitExpr> TryNormalizeSplits(const IterMark& mark,
+                                          const std::vector<IterSplitExpr>& splits) {
+    std::vector<bool> used(splits.size(), false);
+    std::vector<IterSplitExpr> iters;
+    PrimExpr expected_lower_factor = make_const(mark->source->dtype, 1);
+
+    for (size_t i = 0; i < splits.size(); ++i) {
+      size_t j = 0;
+      for (; j < splits.size(); ++j) {
+        if (used[j]) continue;
+        if (!used[j] && CanProveEqual(splits[j]->lower_factor, expected_lower_factor)) break;
+      }
+      if (j == splits.size()) {
+        return Array<IterSplitExpr>();
+      }
+      used[j] = true;
+      iters.push_back(splits[j]);
+      expected_lower_factor *= splits[j]->extent;
+    }
+    if (!CanProveEqual(expected_lower_factor, mark->extent)) return Array<IterSplitExpr>();
+    return Array<IterSplitExpr>(iters.rbegin(), iters.rend());
+  }
+
+  /*!
+   * \brief Normalize expr to an iterator + offset.
+   * \param expr The input expression.
+   * \return The Normalized expression.
+   */
+  IterSumExpr NormalizeToIterWithOffset(IterSumExpr expr) {
+    if (expr->args.size() <= 1) return expr;
+    PrimExpr base = expr->base;
+    expr.CopyOnWrite()->base = make_zero(expr->dtype);
+    auto opt = TryFuseIters(expr);
+    expr.CopyOnWrite()->base = base;
+    if (opt) {
+      expr.CopyOnWrite()->args = Array<IterSplitExpr>({opt.value()});
+      return expr;
+    } else {
+      ++unresolved_count_;
+      return expr;
+    }
+  }
+
+  bool CanProveEqual(PrimExpr lhs, PrimExpr rhs) {
+    const auto* clhs = lhs.as<IntImmNode>();
+    const auto* crhs = rhs.as<IntImmNode>();
+    if (clhs && crhs) return clhs->value == crhs->value;
+    return analyzer_->CanProve(lhs - rhs == 0);
+  }
+
+  /*!
+   * \brief Create a IterSumExpr from expr.
+   * \param expr The input expr.
+   * \return The transformed IterSumExpr.
+   */
+  IterSumExpr ToIterSumExpr(PrimExpr expr) {
+    if (const auto* op = expr.as<IterSumExprNode>()) {
+      return GetRef<IterSumExpr>(op);
+    } else if (const auto* op = expr.as<IterSplitExprNode>()) {
+      return IterSumExpr({GetRef<IterSplitExpr>(op)}, make_zero(expr->dtype));
+    } else {
+      CHECK(!expr->IsInstance<IterMapExprNode>());
+      return IterSumExpr({}, expr);
+    }
+  }
+
+  // Try to normalize IterSum into a fused IterMark
+  // return a corresponding splitexpr if needed.
+  Optional<IterSplitExpr> TryFuseIters(IterSumExpr expr) {
+    if (!is_zero(expr->base)) return NullOpt;
+    if (expr->args.size() == 1) return expr->args[0];
+    // select the iterators in order
+    std::vector<bool> visited(expr->args.size(), false);
+    std::vector<IterSplitExpr> iters;
+    iters.reserve(expr->args.size());
+    // canonicalize the expression
+    // check if it can be remapped into a fused pattern.
+    PrimExpr expected_scale = make_const(expr->base->dtype, 1);
+    for (size_t i = 0; i < expr->args.size(); ++i) {
+      size_t j = 0;
+      for (; j < expr->args.size(); ++j) {
+        if (!visited[j] && CanProveEqual(expr->args[j]->scale, expected_scale)) break;
+      }
+      if (j == expr->args.size()) {
+        return NullOpt;
+      }
+      visited[j] = true;
+      iters.push_back(expr->args[j]);
+      expected_scale *= expr->args[j]->extent;
+    }
+    // update the iterator to use the canonicalized form
+    expr.CopyOnWrite()->args = Array<IterSplitExpr>(iters.rbegin(), iters.rend());
+    auto it = sum_fuse_map_.find(expr);
+    if (it != sum_fuse_map_.end()) return it->second;
+    auto mark = IterMark(expr, expected_scale);
+    IterSplitExpr split(mark);
+    sum_fuse_map_[expr] = split;
+    return split;
+  }
+
+  bool CanProveDivisible(PrimExpr lhs, PrimExpr rhs) {
+    const auto* clhs = lhs.as<IntImmNode>();
+    const auto* crhs = rhs.as<IntImmNode>();
+    if (clhs && crhs) return clhs->value % crhs->value == 0;
+    return analyzer_->CanProve(floormod(lhs, rhs) == 0);
+  }
+
+  PrimExpr SplitFloorDivConst(IterSplitExpr lhs, PrimExpr rhs);
+  PrimExpr SplitFloorModConst(IterSplitExpr lhs, PrimExpr rhs);
+
+  static void AddToLhs(IterSumExprNode* lhs, IterSplitExpr rhs, int sign) {
+    tir::ExprDeepEqual equal;
+    for (size_t i = 0; i < lhs->args.size(); ++i) {
+      IterSplitExpr lvalue = lhs->args[i];
+      if (lvalue->source.same_as(rhs->source) && equal(lvalue->lower_factor, rhs->lower_factor) &&
+          equal(lvalue->extent, rhs->extent)) {
+        if (sign > 0) {
+          rhs.CopyOnWrite()->scale = lvalue->scale + rhs->scale;
+        } else {
+          rhs.CopyOnWrite()->scale = lvalue->scale - rhs->scale;
+        }
+        lhs->args.Set(i, rhs);
+        return;
+      }
+    }
+    if (sign > 0) {
+      lhs->args.push_back(rhs);
+    } else {
+      rhs.CopyOnWrite()->scale = make_zero(rhs->scale.dtype()) - rhs->scale;
+      lhs->args.push_back(rhs);
+    }
+  }
+
+  static void AddToLhs(IterSumExprNode* lhs, IterSumExpr rhs, int sign) {
+    for (size_t i = 0; i < rhs->args.size(); ++i) {
+      AddToLhs(lhs, rhs->args[i], sign);
+    }
+    if (sign > 0) {
+      lhs->base += rhs->base;
+    } else {
+      lhs->base -= rhs->base;
+    }
+  }
+
+  static void MulToLhs(IterSumExprNode* lhs, PrimExpr rhs) {
+    for (size_t i = 0; i < lhs->args.size(); ++i) {
+      IterSplitExpr lvalue = lhs->args[i];
+      lvalue.CopyOnWrite()->scale *= rhs;
+      lhs->args.Set(i, lvalue);
+    }
+    lhs->base *= rhs;
+  }
+};
+
+Array<IterSumExpr> DetectIterMap(const Array<PrimExpr>& indices, const Map<Var, Range>& input_iters,
+                                 arith::Analyzer* analyzer) {
+  // Overall detection algorithm is divided into two steps:
+  // - Step0: IterMapRewriter rewrites the expression to use IterMapExpr patterns.
+  // - Step1: IterIndependenceChecker checks if the iterator are independent.
+  IterMapRewriter rewriter(analyzer, input_iters);
+  Array<IterSumExpr> results;
+
+  for (PrimExpr value : indices) {
+    results.push_back(rewriter.Rewrite(value));
+    if (rewriter.unresolved_count() != 0) return Array<IterSumExpr>();
+  }
+  if (!rewriter.CheckBijective(results)) return Array<IterSumExpr>();

Review comment:
       If the indices are `y, x, x+1` and input_iters are `y, x`, will this be checked?




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] tqchen commented on a change in pull request #6667: [ARITH] Introduce iterator (quasi)affine map detection.

Posted by GitBox <gi...@apache.org>.
tqchen commented on a change in pull request #6667:
URL: https://github.com/apache/incubator-tvm/pull/6667#discussion_r504331608



##########
File path: src/arith/iter_affine_map.cc
##########
@@ -0,0 +1,703 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/arith/iter_affine_map.cc
+ */
+#include <tvm/arith/analyzer.h>
+#include <tvm/arith/iter_affine_map.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/expr_functor.h>
+#include <tvm/tir/op.h>
+
+#include "../support/util.h"
+#include "const_fold.h"
+
+namespace tvm {
+namespace arith {
+
+using namespace tir;
+
+IterMark::IterMark(PrimExpr source, PrimExpr extent) {
+  auto n = make_object<IterMarkNode>();
+  n->source = std::move(source);
+  n->extent = std::move(extent);
+  data_ = std::move(n);
+}
+
+TVM_REGISTER_GLOBAL("arith.IterMark").set_body_typed([](PrimExpr source, PrimExpr extent) {
+  return IterMark(source, extent);
+});
+
+TVM_REGISTER_NODE_TYPE(IterMarkNode);
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<IterMarkNode>([](const ObjectRef& node, ReprPrinter* p) {
+      auto* op = static_cast<const IterMarkNode*>(node.get());
+      p->stream << "IterMark(" << op->source << ", extent=" << op->extent;
+    });
+
+IterSplitExpr::IterSplitExpr(IterMark source) {
+  auto n = make_object<IterSplitExprNode>();
+  auto one = make_const(source->source->dtype, 1);
+  n->dtype = source->source->dtype;
+  n->source = std::move(source);
+  n->extent = n->source->extent;
+  n->lower_factor = one;
+  n->scale = one;
+  data_ = std::move(n);
+}
+
+IterSplitExpr::IterSplitExpr(IterMark source, PrimExpr lower_factor, PrimExpr extent,
+                             PrimExpr scale) {
+  auto n = make_object<IterSplitExprNode>();
+  n->dtype = source->source->dtype;
+  n->source = std::move(source);
+  n->lower_factor = std::move(lower_factor);
+  n->extent = std::move(extent);
+  n->scale = std::move(scale);
+  data_ = std::move(n);
+}
+
+TVM_REGISTER_GLOBAL("arith.IterSplitExpr")
+    .set_body_typed([](IterMark source, PrimExpr lower_factor, PrimExpr extent, PrimExpr scale) {
+      return IterSplitExpr(source, lower_factor, extent, scale);
+    });
+
+TVM_REGISTER_NODE_TYPE(IterSplitExprNode);
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<IterSplitExprNode>([](const ObjectRef& node, ReprPrinter* p) {
+      auto* op = static_cast<const IterSplitExprNode*>(node.get());
+      p->stream << "IterSplit(" << op->source << ", lower_factor=" << op->lower_factor
+                << ", extent=" << op->extent << ", scale=" << op->scale;
+    });
+
+IterSumExpr::IterSumExpr(Array<IterSplitExpr> args, PrimExpr base) {
+  auto n = make_object<IterSumExprNode>();
+  n->dtype = base->dtype;
+  n->args = std::move(args);
+  n->base = std::move(base);
+  data_ = std::move(n);
+}
+
+TVM_REGISTER_GLOBAL("arith.IterSumExpr")
+    .set_body_typed([](Array<IterSplitExpr> args, PrimExpr base) {
+      return IterSumExpr(args, base);
+    });
+
+TVM_REGISTER_NODE_TYPE(IterSumExprNode);
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<IterSumExprNode>([](const ObjectRef& node, ReprPrinter* p) {
+      auto* op = static_cast<const IterSumExprNode*>(node.get());
+      p->stream << "IterSum(" << op->args << ", " << op->base << ")";
+    });
+
+/*!
+ * \brief Util to check if all splits in the sumexpr are
+ *        independent and complete (covers all the original iter space).
+ *
+ */
+class IterMarkSplitCollector {
+ public:
+  // mark all IterMarks that are visited.
+  std::unordered_set<IterMark, ObjectPtrHash, ObjectPtrEqual> visited_;
+  // each iter mark to its outgoing splits that are referenced.
+  std::unordered_map<IterMark, std::vector<IterSplitExpr>, ObjectPtrHash, ObjectPtrEqual>
+      mark2splits_;
+  /*!
+   * \brief Collect all mark2splits recursively from indices.
+   * \param indices The iterator of interest.
+   */
+  void Collect(const Array<IterSumExpr>& indices) {
+    for (IterSumExpr sum_expr : indices) {
+      for (IterSplitExpr split : sum_expr->args) {
+        this->CollectInternal(split->source);
+        mark2splits_[split->source].push_back(split);
+      }
+    }
+  }
+
+  void CollectInternal(const IterMark& mark) {
+    if (visited_.count(mark)) return;
+    visited_.insert(mark);
+    if (auto* op = mark->source.as<IterSumExprNode>()) {
+      for (IterSplitExpr split : op->args) {
+        this->CollectInternal(split->source);
+        mark2splits_[split->source].push_back(split);
+      }
+    }
+  }
+};
+
+// Rewriter to rewrite oinformations in iter
+class IterMapRewriter : public ExprMutator {
+ public:
+  using Parent = ExprMutator;
+
+  explicit IterMapRewriter(Analyzer* analyzer, const Map<Var, Range>& input_iters)
+      : analyzer_(analyzer) {
+    for (auto kv : input_iters) {
+      const auto& vrng = kv.second;
+      if (is_zero(vrng->min)) {
+        IterMark mark(kv.first, vrng->extent);
+        var_map_[kv.first] = IterSplitExpr(mark);
+        input_marks_.push_back(mark);
+      } else {
+        IterMark mark(kv.first - vrng->min, vrng->extent);
+        auto sum_expr = ToIterSumExpr(IterSplitExpr(mark));
+        sum_expr.CopyOnWrite()->base = vrng->min;
+        var_map_[kv.first] = sum_expr;
+        input_marks_.push_back(mark);
+      }
+    }
+  }
+
+  size_t unresolved_count() const { return unresolved_count_; }
+
+  IterSumExpr Rewrite(PrimExpr expr) {
+    return NormalizeToIterWithOffset(ToIterSumExpr(DirectMutate(expr)));
+  }
+
+  bool CheckBijective(const Array<IterSumExpr>& indices) {
+    IterMarkSplitCollector collector;
+    // We can check that for each iter mark:
+    // All the splits that refers to the itermark covers its extent.
+    // The splits do not overlap with each other.
+    collector.Collect(indices);
+    for (IterMark mark : collector.visited_) {
+      if (TryNormalizeSplits(mark, collector.mark2splits_[mark]).size() == 0) return false;
+    }
+    // all input marks must be visited
+    for (auto mark : input_marks_) {
+      if (collector.visited_.count(mark) == 0) return false;
+    }
+    return true;
+  }
+
+  // override the original mutate function.
+  PrimExpr VisitExpr(const PrimExpr& input_expr) final {
+    auto expr = ExprMutator::VisitExpr(input_expr);
+    if (expr->IsInstance<IterMapExprNode>()) {
+      ++unresolved_count_;
+    }
+    return expr;
+  }
+
+  // Normal mutation without normalization.
+  PrimExpr DirectMutate(PrimExpr expr) { return ExprMutator::VisitExpr(expr); }
+
+  PrimExpr VisitExpr_(const VarNode* op) final;
+  PrimExpr VisitExpr_(const AddNode* op) final;
+  PrimExpr VisitExpr_(const SubNode* op) final;
+  PrimExpr VisitExpr_(const MulNode* op) final;
+  PrimExpr VisitExpr_(const FloorDivNode* op) final;
+  PrimExpr VisitExpr_(const FloorModNode* op) final;
+
+ private:
+  // temp hash for de-duplication purposes.
+  struct IterSumHash {
+    size_t operator()(const IterSumExpr& value) const {
+      // for now only hash on source index.
+      size_t hash = value->args.size();
+      for (size_t i = 0; i < value->args.size(); ++i) {
+        hash = support::HashCombine(hash, std::hash<const Object*>()(value->args[i]->source.get()));
+      }
+      return hash;
+    }
+  };
+
+  struct IterSumEqual {
+    bool operator()(const IterSumExpr& lhs, const IterSumExpr& rhs) const {
+      tir::ExprDeepEqual equal;
+      if (lhs->args.size() != rhs->args.size()) return false;
+      if (!equal(lhs->base, rhs->base)) return false;
+      for (size_t i = 0; i < lhs->args.size(); ++i) {
+        auto lvalue = lhs->args[i];
+        auto rvalue = lhs->args[i];
+        if (!lvalue->source.same_as(rvalue->source)) return false;
+        if (!equal(lvalue->lower_factor, rvalue->lower_factor)) return false;
+        if (!equal(lvalue->scale, rvalue->scale)) return false;
+        if (!equal(lvalue->extent, rvalue->extent)) return false;
+      }
+      return true;
+    }
+  };
+
+  // Internal analyzer
+  Analyzer* analyzer_;
+  // Counter to keep track of unresolved cases.
+  int unresolved_count_{0};
+  // The var map
+  std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> var_map_;
+  // input iter marks
+  std::vector<IterMark> input_marks_;
+  // The canonical map for sum
+  std::unordered_map<IterSumExpr, IterSplitExpr, IterSumHash, IterSumEqual> sum_fuse_map_;
+
+  /*!
+   * \brief Verify that splits fully covers mark in a non-overlapping fashion.
+   *        If verification passes, return splits from outermost to inner most order.
+   *        If not, return an empty array
+   * \param mark The iterator of interest.
+   * \param splits The splits to be verified.
+   * \return The normalized splits.
+   */
+  Array<IterSplitExpr> TryNormalizeSplits(const IterMark& mark,
+                                          const std::vector<IterSplitExpr>& splits) {
+    std::vector<bool> used(splits.size(), false);
+    std::vector<IterSplitExpr> iters;
+    PrimExpr expected_lower_factor = make_const(mark->source->dtype, 1);
+
+    for (size_t i = 0; i < splits.size(); ++i) {
+      size_t j = 0;
+      for (; j < splits.size(); ++j) {
+        if (used[j]) continue;
+        if (!used[j] && CanProveEqual(splits[j]->lower_factor, expected_lower_factor)) break;
+      }
+      if (j == splits.size()) {
+        return Array<IterSplitExpr>();
+      }
+      used[j] = true;
+      iters.push_back(splits[j]);
+      expected_lower_factor *= splits[j]->extent;
+    }
+    if (!CanProveEqual(expected_lower_factor, mark->extent)) return Array<IterSplitExpr>();
+    return Array<IterSplitExpr>(iters.rbegin(), iters.rend());
+  }
+
+  /*!
+   * \brief Normalize expr to an iterator + offset.
+   * \param expr The input expression.
+   * \return The Normalized expression.
+   */
+  IterSumExpr NormalizeToIterWithOffset(IterSumExpr expr) {
+    if (expr->args.size() <= 1) return expr;
+    PrimExpr base = expr->base;
+    expr.CopyOnWrite()->base = make_zero(expr->dtype);
+    auto opt = TryFuseIters(expr);
+    expr.CopyOnWrite()->base = base;
+    if (opt) {
+      expr.CopyOnWrite()->args = Array<IterSplitExpr>({opt.value()});
+      return expr;
+    } else {
+      ++unresolved_count_;
+      return expr;
+    }
+  }
+
+  bool CanProveEqual(PrimExpr lhs, PrimExpr rhs) {
+    const auto* clhs = lhs.as<IntImmNode>();
+    const auto* crhs = rhs.as<IntImmNode>();
+    if (clhs && crhs) return clhs->value == crhs->value;
+    return analyzer_->CanProve(lhs - rhs == 0);
+  }
+
+  /*!
+   * \brief Create a IterSumExpr from expr.
+   * \param expr The input expr.
+   * \return The transformed IterSumExpr.
+   */
+  IterSumExpr ToIterSumExpr(PrimExpr expr) {
+    if (const auto* op = expr.as<IterSumExprNode>()) {
+      return GetRef<IterSumExpr>(op);
+    } else if (const auto* op = expr.as<IterSplitExprNode>()) {
+      return IterSumExpr({GetRef<IterSplitExpr>(op)}, make_zero(expr->dtype));
+    } else {
+      CHECK(!expr->IsInstance<IterMapExprNode>());
+      return IterSumExpr({}, expr);
+    }
+  }
+
+  // Try to normalize IterSum into a fused IterMark
+  // return a corresponding splitexpr if needed.
+  Optional<IterSplitExpr> TryFuseIters(IterSumExpr expr) {
+    if (!is_zero(expr->base)) return NullOpt;
+    if (expr->args.size() == 1) return expr->args[0];
+    // select the iterators in order
+    std::vector<bool> visited(expr->args.size(), false);
+    std::vector<IterSplitExpr> iters;
+    iters.reserve(expr->args.size());
+    // canonicalize the expression
+    // check if it can be remapped into a fused pattern.
+    PrimExpr expected_scale = make_const(expr->base->dtype, 1);
+    for (size_t i = 0; i < expr->args.size(); ++i) {
+      size_t j = 0;
+      for (; j < expr->args.size(); ++j) {
+        if (!visited[j] && CanProveEqual(expr->args[j]->scale, expected_scale)) break;
+      }
+      if (j == expr->args.size()) {
+        return NullOpt;
+      }
+      visited[j] = true;
+      iters.push_back(expr->args[j]);
+      expected_scale *= expr->args[j]->extent;
+    }
+    // update the iterator to use the canonicalized form
+    expr.CopyOnWrite()->args = Array<IterSplitExpr>(iters.rbegin(), iters.rend());
+    auto it = sum_fuse_map_.find(expr);
+    if (it != sum_fuse_map_.end()) return it->second;
+    auto mark = IterMark(expr, expected_scale);
+    IterSplitExpr split(mark);
+    sum_fuse_map_[expr] = split;
+    return split;
+  }
+
+  bool CanProveDivisible(PrimExpr lhs, PrimExpr rhs) {
+    const auto* clhs = lhs.as<IntImmNode>();
+    const auto* crhs = rhs.as<IntImmNode>();
+    if (clhs && crhs) return clhs->value % crhs->value == 0;
+    return analyzer_->CanProve(floormod(lhs, rhs) == 0);
+  }
+
+  PrimExpr SplitFloorDivConst(IterSplitExpr lhs, PrimExpr rhs);
+  PrimExpr SplitFloorModConst(IterSplitExpr lhs, PrimExpr rhs);
+
+  static void AddToLhs(IterSumExprNode* lhs, IterSplitExpr rhs, int sign) {
+    tir::ExprDeepEqual equal;
+    for (size_t i = 0; i < lhs->args.size(); ++i) {
+      IterSplitExpr lvalue = lhs->args[i];
+      if (lvalue->source.same_as(rhs->source) && equal(lvalue->lower_factor, rhs->lower_factor) &&
+          equal(lvalue->extent, rhs->extent)) {
+        if (sign > 0) {
+          rhs.CopyOnWrite()->scale = lvalue->scale + rhs->scale;
+        } else {
+          rhs.CopyOnWrite()->scale = lvalue->scale - rhs->scale;
+        }
+        lhs->args.Set(i, rhs);
+        return;
+      }
+    }
+    if (sign > 0) {
+      lhs->args.push_back(rhs);
+    } else {
+      rhs.CopyOnWrite()->scale = make_zero(rhs->scale.dtype()) - rhs->scale;
+      lhs->args.push_back(rhs);
+    }
+  }
+
+  static void AddToLhs(IterSumExprNode* lhs, IterSumExpr rhs, int sign) {
+    for (size_t i = 0; i < rhs->args.size(); ++i) {
+      AddToLhs(lhs, rhs->args[i], sign);
+    }
+    if (sign > 0) {
+      lhs->base += rhs->base;
+    } else {
+      lhs->base -= rhs->base;
+    }
+  }
+
+  static void MulToLhs(IterSumExprNode* lhs, PrimExpr rhs) {
+    for (size_t i = 0; i < lhs->args.size(); ++i) {
+      IterSplitExpr lvalue = lhs->args[i];
+      lvalue.CopyOnWrite()->scale *= rhs;
+      lhs->args.Set(i, lvalue);
+    }
+    lhs->base *= rhs;
+  }
+};
+
+Array<IterSumExpr> DetectIterMap(const Array<PrimExpr>& indices, const Map<Var, Range>& input_iters,
+                                 arith::Analyzer* analyzer) {
+  // Overall detection algorithm is divided into two steps:
+  // - Step0: IterMapRewriter rewrites the expression to use IterMapExpr patterns.
+  // - Step1: IterIndependenceChecker checks if the iterator are independent.
+  IterMapRewriter rewriter(analyzer, input_iters);
+  Array<IterSumExpr> results;
+
+  for (PrimExpr value : indices) {
+    results.push_back(rewriter.Rewrite(value));
+    if (rewriter.unresolved_count() != 0) return Array<IterSumExpr>();
+  }
+  if (!rewriter.CheckBijective(results)) return Array<IterSumExpr>();

Review comment:
       The checks are performed by the TryNormalizeSplits function.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] tqchen commented on pull request #6667: [ARITH] Introduce iterator (quasi)affine map detection.

Posted by GitBox <gi...@apache.org>.
tqchen commented on pull request #6667:
URL: https://github.com/apache/incubator-tvm/pull/6667#issuecomment-706780293


   cc @spectrometerHBH @comaniac @yzhliu @Hzfengsy @merrymercy 


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] spectrometerHBH commented on a change in pull request #6667: [ARITH] Introduce iterator (quasi)affine map detection.

Posted by GitBox <gi...@apache.org>.
spectrometerHBH commented on a change in pull request #6667:
URL: https://github.com/apache/incubator-tvm/pull/6667#discussion_r502984478



##########
File path: tests/python/unittest/test_arith_iter_affine_map.py
##########
@@ -0,0 +1,156 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+import tvm.testing
+from tvm import te
+
+
+def ifuse(inputs):
+    """Fuse iterators"""
+    value, extent = 0, 1
+    for i, ext in inputs:
+        value = value * ext + i
+        extent = extent * ext
+    return (value, extent)
+
+
+def isplit(axis, factor):
+    """Fuse iterators"""

Review comment:
       Typo: the docstring of `isplit` should describe splitting an iterator, not "Fuse iterators".

##########
File path: python/tvm/arith/iter_affine_map.py
##########
@@ -0,0 +1,102 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+""" Iterator (quasi)affine mapping patterns."""
+import tvm._ffi
+from . import _ffi_api
+from tvm.runtime import Object
+from tvm.ir import PrimExpr
+
+
+class IterMapExpr(PrimExpr):
+    """Base class of all IterMap expressions."""
+
+
+@tvm._ffi.register_object("arith.IterMark")
+class IterMark(Object):
+    """Mark the source as an iterator in [0, extent).
+
+    Parameters
+    ----------
+    source : PrimExpr.
+        The source expression.
+
+    extent : PrimExpr
+        The extent of the iterator.
+    """
+
+    def __init__(self, min_value, max_value):
+        self.__init_handle_by_constructor__(_ffi_api.IterMark, source, extent)
+
+
+@tvm._ffi.register_object("arith.IterSplitExpr")
+class IterSplitExpr(IterMapExpr):
+    """Split of an iterator.
+
+    result = floormod(floordiv(source, lower_factor), extent) * scale
+
+    Parameters
+    ----------
+    source : IterMark
+        The source marked iterator.
+
+    lower_factor : PrimExpr
+        The lower factor to split the domain.
+
+    extent : PrimExpr
+        The extent of the split.
+
+    scale : PrimExpr
+        Additional scale to the split.
+    """
+
+    def __init__(self, min_value, max_value):
+        self.__init_handle_by_constructor__(
+            _ffi_api.IterSplitExpr, source, lower_factor, extent, scale
+        )
+
+
+@tvm._ffi.register_object("arith.IterSumExpr")
+class IterSumExpr(IterMapExpr):
+    """Fuse multiple iterators by summing them with scaling.
+
+    result = sum(args) + base
+
+    Parameters
+    ----------
+    args : List[IterSplitExpr]
+        The input to the sum expression.
+
+    base : PrimExpr
+        The base offset.
+    """
+
+    def __init__(self, args, base):
+        self.__init_handle_by_constructor__(_ffi_api.IterSumExpr, args, base)
+
+
+def detect_iter_map(indices, input_iters):
+    """Detect if indices can be written mapped iters from input_iters.
+
+    Parameters
+    ----------
+    indices : List[PrimExpr]
+        The input indices.
+
+    input_iters : Map[Var, Range]
+        The domain of each input iterators.

Review comment:
       It would be good to add a Returns section to the docstring here.

##########
File path: python/tvm/arith/iter_affine_map.py
##########
@@ -0,0 +1,102 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+""" Iterator (quasi)affine mapping patterns."""
+import tvm._ffi
+from . import _ffi_api
+from tvm.runtime import Object
+from tvm.ir import PrimExpr
+
+
+class IterMapExpr(PrimExpr):
+    """Base class of all IterMap expressions."""
+
+
+@tvm._ffi.register_object("arith.IterMark")
+class IterMark(Object):
+    """Mark the source as an iterator in [0, extent).
+
+    Parameters
+    ----------
+    source : PrimExpr.
+        The source expression.
+
+    extent : PrimExpr
+        The extent of the iterator.
+    """
+
+    def __init__(self, min_value, max_value):
+        self.__init_handle_by_constructor__(_ffi_api.IterMark, source, extent)

Review comment:
       Inconsistent with the function arguments: the signature takes `min_value, max_value`, but the body passes `source, extent`.

##########
File path: python/tvm/arith/iter_affine_map.py
##########
@@ -0,0 +1,102 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+""" Iterator (quasi)affine mapping patterns."""
+import tvm._ffi
+from . import _ffi_api
+from tvm.runtime import Object
+from tvm.ir import PrimExpr
+
+
+class IterMapExpr(PrimExpr):
+    """Base class of all IterMap expressions."""
+
+
+@tvm._ffi.register_object("arith.IterMark")
+class IterMark(Object):
+    """Mark the source as an iterator in [0, extent).
+
+    Parameters
+    ----------
+    source : PrimExpr.
+        The source expression.
+
+    extent : PrimExpr
+        The extent of the iterator.
+    """
+
+    def __init__(self, min_value, max_value):
+        self.__init_handle_by_constructor__(_ffi_api.IterMark, source, extent)
+
+
+@tvm._ffi.register_object("arith.IterSplitExpr")
+class IterSplitExpr(IterMapExpr):
+    """Split of an iterator.
+
+    result = floormod(floordiv(source, lower_factor), extent) * scale
+
+    Parameters
+    ----------
+    source : IterMark
+        The source marked iterator.
+
+    lower_factor : PrimExpr
+        The lower factor to split the domain.
+
+    extent : PrimExpr
+        The extent of the split.
+
+    scale : PrimExpr
+        Additional scale to the split.
+    """
+
+    def __init__(self, min_value, max_value):

Review comment:
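       Incorrect arguments: the signature takes `min_value, max_value`, but the body uses `source, lower_factor, extent, scale`.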




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] spectrometerHBH commented on a change in pull request #6667: [ARITH] Introduce iterator (quasi)affine map detection.

Posted by GitBox <gi...@apache.org>.
spectrometerHBH commented on a change in pull request #6667:
URL: https://github.com/apache/incubator-tvm/pull/6667#discussion_r503749804



##########
File path: include/tvm/arith/iter_affine_map.h
##########
@@ -0,0 +1,276 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/arith/iter_affine_map.h
+ * \brief Iterator quasi-affine mapping patterns.
+ *
+ *  This file defines a collection of mapping patterns
+ *  maps a collection of independent iterators to another
+ *  collection of independent iterators.
+ *
+ *  There are two main kinds of mapping patterns:
+ *
+ *  - Fuse: fuse a collection of iterators into a single one
+ *
+ *    domain(x0) = [0, 4), domain(x1) = [0, 3), domain(x2) = [0, 2)
+ *    fuse(x0, x1, x2): y = x2 * 12 + x1 * 4 + x0
+ *    domain(y) = [0, 24)
+ *
+ *  - Split: split an iterator into multiple ones
+ *
+ *    domain(x) = [0, 24)
+ *    split(x, 3, 12): [y0, y1, y2] = [x % 3, (x % 12) / 3, x / 12]
+ *    domain(y0) = [0, 3), domain(y1) = [0, 4), domain(y2) = [0, 2)
+ *
+ *  We use the name "(quasi)affine" to be consistent with
+ *  the terminology used in the polyhedral compilation.
+ *  Notably, fuse is an affine transformation,
+ *  while split corresponds to additional floordiv/mod operations
+ *  that can appear in quasi-affine transformations.
+ */
+#ifndef TVM_ARITH_ITER_AFFINE_MAP_H_
+#define TVM_ARITH_ITER_AFFINE_MAP_H_
+
+#include <tvm/ir/expr.h>
+
+namespace tvm {
+namespace arith {
+
+/*!
+ * \brief Base class of all iter map expressions.
+ *
+ *  An IterMapExpr is a special expression to store
+ *  the result of IterMapDetection, it should not
+ */
+class IterMapExprNode : public PrimExprNode {
+ public:
+  // overrides
+  void VisitAttrs(tvm::AttrVisitor* v) {}
+
+  static constexpr const char* _type_key = "arith.IterMapExpr";
+  static constexpr const uint32_t _type_child_slots = 3;
+  TVM_DECLARE_BASE_OBJECT_INFO(IterMapExprNode, PrimExprNode);
+};
+
+/*!
+ * \brief Managed reference to IterMapExprNode.
+ * \sa IterMapExprNode
+ */
+class IterMapExpr : public PrimExpr {
+ public:
+  TVM_DEFINE_OBJECT_REF_METHODS(IterMapExpr, PrimExpr, IterMapExprNode);
+};
+
+/*!
+ * \brief Mark the source as an iterator in [0, extent).
+ *
+ *  IterMark is used to mark source expression as a valid
+ *  iterator to make future analysis easy.
+ */
+class IterMarkNode : public Object {
+ public:
+  /*!
+   * \brief The source expression, can either be
+   *  a IterSumExpr or a Var.
+   */
+  PrimExpr source;
+  /*!
+   * \brief The extent of the iteration.
+   */
+  PrimExpr extent;
+
+  // overrides
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("source", &source);
+    v->Visit("extent", &extent);
+  }
+
+  bool SEqualReduce(const IterMarkNode* other, SEqualReducer equal) const {
+    equal->MarkGraphNode();
+    return equal(source, other->source) && equal(extent, other->extent);
+  }
+
+  void SHashReduce(SHashReducer hash_reduce) const {
+    hash_reduce->MarkGraphNode();
+    hash_reduce(source);
+    hash_reduce(extent);
+  }
+
+  static constexpr const bool _type_has_method_sequal_reduce = true;
+  static constexpr const bool _type_has_method_shash_reduce = true;
+  static constexpr const char* _type_key = "arith.IterMark";
+  TVM_DECLARE_FINAL_OBJECT_INFO(IterMarkNode, Object);
+};
+
+/*!
+ * \brief Managed reference to IterMarkExprNode.
+ * \sa IterMarkExprNode
+ */
+class IterMark : public ObjectRef {
+ public:
+  /*!
+   * \brief constructor.
+   * \param source The source expression.
+   * \param extent The extent of the iterator.
+   */
+  TVM_DLL IterMark(PrimExpr source, PrimExpr extent);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(IterMark, ObjectRef, IterMarkNode);
+};
+
+/*!
+ * \brief Split of an iterator.
+ *
+ *  result = floormod(floordiv(source, lower_factor), extent) * scale
+ */
+class IterSplitExprNode : public IterMapExprNode {
+ public:
+  /*! \brief The source marked iterator. */
+  IterMark source;

Review comment:
       Can the `source` of an IterSplitExpr be an IterSumExpr, i.e., can we split some iterators, fuse some of the results, and then split them again?




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] tqchen commented on a change in pull request #6667: [ARITH] Introduce iterator (quasi)affine map detection.

Posted by GitBox <gi...@apache.org>.
tqchen commented on a change in pull request #6667:
URL: https://github.com/apache/incubator-tvm/pull/6667#discussion_r504256910



##########
File path: src/arith/iter_affine_map.cc
##########
@@ -0,0 +1,703 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/arith/iter_affine_map.cc
+ */
+#include <tvm/arith/analyzer.h>
+#include <tvm/arith/iter_affine_map.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/expr_functor.h>
+#include <tvm/tir/op.h>
+
+#include "../support/util.h"
+#include "const_fold.h"
+
+namespace tvm {
+namespace arith {
+
+using namespace tir;
+
+IterMark::IterMark(PrimExpr source, PrimExpr extent) {
+  auto n = make_object<IterMarkNode>();
+  n->source = std::move(source);
+  n->extent = std::move(extent);
+  data_ = std::move(n);
+}
+
+TVM_REGISTER_GLOBAL("arith.IterMark").set_body_typed([](PrimExpr source, PrimExpr extent) {
+  return IterMark(source, extent);
+});
+
+TVM_REGISTER_NODE_TYPE(IterMarkNode);
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<IterMarkNode>([](const ObjectRef& node, ReprPrinter* p) {
+      auto* op = static_cast<const IterMarkNode*>(node.get());
+      p->stream << "IterMark(" << op->source << ", extent=" << op->extent;
+    });
+
+IterSplitExpr::IterSplitExpr(IterMark source) {
+  auto n = make_object<IterSplitExprNode>();
+  auto one = make_const(source->source->dtype, 1);
+  n->dtype = source->source->dtype;
+  n->source = std::move(source);
+  n->extent = n->source->extent;
+  n->lower_factor = one;
+  n->scale = one;
+  data_ = std::move(n);
+}
+
+IterSplitExpr::IterSplitExpr(IterMark source, PrimExpr lower_factor, PrimExpr extent,
+                             PrimExpr scale) {
+  auto n = make_object<IterSplitExprNode>();
+  n->dtype = source->source->dtype;
+  n->source = std::move(source);
+  n->lower_factor = std::move(lower_factor);
+  n->extent = std::move(extent);
+  n->scale = std::move(scale);
+  data_ = std::move(n);
+}
+
+TVM_REGISTER_GLOBAL("arith.IterSplitExpr")
+    .set_body_typed([](IterMark source, PrimExpr lower_factor, PrimExpr extent, PrimExpr scale) {
+      return IterSplitExpr(source, lower_factor, extent, scale);
+    });
+
+TVM_REGISTER_NODE_TYPE(IterSplitExprNode);
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<IterSplitExprNode>([](const ObjectRef& node, ReprPrinter* p) {
+      auto* op = static_cast<const IterSplitExprNode*>(node.get());
+      p->stream << "IterSplit(" << op->source << ", lower_factor=" << op->lower_factor
+                << ", extent=" << op->extent << ", scale=" << op->scale;
+    });
+
+IterSumExpr::IterSumExpr(Array<IterSplitExpr> args, PrimExpr base) {
+  auto n = make_object<IterSumExprNode>();
+  n->dtype = base->dtype;
+  n->args = std::move(args);
+  n->base = std::move(base);
+  data_ = std::move(n);
+}
+
+TVM_REGISTER_GLOBAL("arith.IterSumExpr")
+    .set_body_typed([](Array<IterSplitExpr> args, PrimExpr base) {
+      return IterSumExpr(args, base);
+    });
+
+TVM_REGISTER_NODE_TYPE(IterSumExprNode);
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<IterSumExprNode>([](const ObjectRef& node, ReprPrinter* p) {
+      auto* op = static_cast<const IterSumExprNode*>(node.get());
+      p->stream << "IterSum(" << op->args << ", " << op->base << ")";
+    });
+
+/*!
+ * \brief Util to check if all splits in the sumexpr are
+ *        independent and complete (covers all the original iter space).
+ *
+ */
+class IterMarkSplitCollector {
+ public:
+  // mark all IterMarks that are visited.
+  std::unordered_set<IterMark, ObjectPtrHash, ObjectPtrEqual> visited_;
+  // each iter mark to its outgoing splits that are referenced.
+  std::unordered_map<IterMark, std::vector<IterSplitExpr>, ObjectPtrHash, ObjectPtrEqual>
+      mark2splits_;
+  /*!
+   * \brief Collect all mark2splits recursively from indices.
+   * \param indices The iterator of interest.
+   */
+  void Collect(const Array<IterSumExpr>& indices) {
+    for (IterSumExpr sum_expr : indices) {
+      for (IterSplitExpr split : sum_expr->args) {
+        this->CollectInternal(split->source);
+        mark2splits_[split->source].push_back(split);
+      }
+    }
+  }
+
+  void CollectInternal(const IterMark& mark) {
+    if (visited_.count(mark)) return;
+    visited_.insert(mark);
+    if (auto* op = mark->source.as<IterSumExprNode>()) {
+      for (IterSplitExpr split : op->args) {
+        this->CollectInternal(split->source);
+        mark2splits_[split->source].push_back(split);
+      }
+    }
+  }
+};
+
+// Rewriter to rewrite oinformations in iter
+class IterMapRewriter : public ExprMutator {
+ public:
+  using Parent = ExprMutator;
+
+  explicit IterMapRewriter(Analyzer* analyzer, const Map<Var, Range>& input_iters)
+      : analyzer_(analyzer) {
+    for (auto kv : input_iters) {
+      const auto& vrng = kv.second;
+      if (is_zero(vrng->min)) {
+        IterMark mark(kv.first, vrng->extent);
+        var_map_[kv.first] = IterSplitExpr(mark);
+        input_marks_.push_back(mark);
+      } else {
+        IterMark mark(kv.first - vrng->min, vrng->extent);
+        auto sum_expr = ToIterSumExpr(IterSplitExpr(mark));
+        sum_expr.CopyOnWrite()->base = vrng->min;
+        var_map_[kv.first] = sum_expr;
+        input_marks_.push_back(mark);
+      }
+    }
+  }
+
+  size_t unresolved_count() const { return unresolved_count_; }
+
+  IterSumExpr Rewrite(PrimExpr expr) {
+    return NormalizeToIterWithOffset(ToIterSumExpr(DirectMutate(expr)));
+  }
+
+  bool CheckBijective(const Array<IterSumExpr>& indices) {
+    IterMarkSplitCollector collector;
+    // We can check that for each iter mark:
+    // All the splits that refers to the itermark covers its extent.
+    // The splits do not overlap with each other.
+    collector.Collect(indices);
+    for (IterMark mark : collector.visited_) {
+      if (TryNormalizeSplits(mark, collector.mark2splits_[mark]).size() == 0) return false;
+    }
+    // all input marks must be visited
+    for (auto mark : input_marks_) {
+      if (collector.visited_.count(mark) == 0) return false;
+    }
+    return true;
+  }
+
+  // override the original mutate function.
+  PrimExpr VisitExpr(const PrimExpr& input_expr) final {
+    auto expr = ExprMutator::VisitExpr(input_expr);
+    if (expr->IsInstance<IterMapExprNode>()) {
+      ++unresolved_count_;
+    }
+    return expr;
+  }
+
+  // Normal mutation without normalization.
+  PrimExpr DirectMutate(PrimExpr expr) { return ExprMutator::VisitExpr(expr); }
+
+  PrimExpr VisitExpr_(const VarNode* op) final;
+  PrimExpr VisitExpr_(const AddNode* op) final;
+  PrimExpr VisitExpr_(const SubNode* op) final;
+  PrimExpr VisitExpr_(const MulNode* op) final;
+  PrimExpr VisitExpr_(const FloorDivNode* op) final;
+  PrimExpr VisitExpr_(const FloorModNode* op) final;
+
+ private:
+  // temp hash for de-duplication purposes.
+  struct IterSumHash {
+    size_t operator()(const IterSumExpr& value) const {
+      // for now only hash on source index.
+      size_t hash = value->args.size();
+      for (size_t i = 0; i < value->args.size(); ++i) {
+        hash = support::HashCombine(hash, std::hash<const Object*>()(value->args[i]->source.get()));
+      }
+      return hash;
+    }
+  };
+
+  struct IterSumEqual {
+    bool operator()(const IterSumExpr& lhs, const IterSumExpr& rhs) const {
+      tir::ExprDeepEqual equal;
+      if (lhs->args.size() != rhs->args.size()) return false;
+      if (!equal(lhs->base, rhs->base)) return false;
+      for (size_t i = 0; i < lhs->args.size(); ++i) {
+        auto lvalue = lhs->args[i];
+        auto rvalue = lhs->args[i];
+        if (!lvalue->source.same_as(rvalue->source)) return false;
+        if (!equal(lvalue->lower_factor, rvalue->lower_factor)) return false;
+        if (!equal(lvalue->scale, rvalue->scale)) return false;
+        if (!equal(lvalue->extent, rvalue->extent)) return false;
+      }
+      return true;
+    }
+  };
+
+  // Internal analyzer
+  Analyzer* analyzer_;
+  // Counter to keep track of unresolved cases.
+  int unresolved_count_{0};
+  // The var map
+  std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> var_map_;
+  // input iter marks
+  std::vector<IterMark> input_marks_;
+  // The canonical map for sum
+  std::unordered_map<IterSumExpr, IterSplitExpr, IterSumHash, IterSumEqual> sum_fuse_map_;
+
+  /*!
+   * \brief Verify that splits fully covers mark in a non-overlapping fashion.
+   *        If verification passes, return splits from outermost to inner most order.
+   *        If not, return an empty array
+   * \param mark The iterator of interest.
+   * \param splits The splits to be verified.
+   * \return The normalized splits.
+   */
+  Array<IterSplitExpr> TryNormalizeSplits(const IterMark& mark,
+                                          const std::vector<IterSplitExpr>& splits) {
+    std::vector<bool> used(splits.size(), false);
+    std::vector<IterSplitExpr> iters;
+    PrimExpr expected_lower_factor = make_const(mark->source->dtype, 1);
+
+    for (size_t i = 0; i < splits.size(); ++i) {
+      size_t j = 0;
+      for (; j < splits.size(); ++j) {
+        if (used[j]) continue;
+        if (!used[j] && CanProveEqual(splits[j]->lower_factor, expected_lower_factor)) break;
+      }
+      if (j == splits.size()) {
+        return Array<IterSplitExpr>();
+      }
+      used[j] = true;
+      iters.push_back(splits[j]);
+      expected_lower_factor *= splits[j]->extent;
+    }
+    if (!CanProveEqual(expected_lower_factor, mark->extent)) return Array<IterSplitExpr>();
+    return Array<IterSplitExpr>(iters.rbegin(), iters.rend());
+  }
+
+  /*!
+   * \brief Normalize expr to an iterator + offset.
+   * \param expr The input expression.
+   * \return The Normalized expression.
+   */
+  IterSumExpr NormalizeToIterWithOffset(IterSumExpr expr) {
+    if (expr->args.size() <= 1) return expr;
+    PrimExpr base = expr->base;
+    expr.CopyOnWrite()->base = make_zero(expr->dtype);
+    auto opt = TryFuseIters(expr);
+    expr.CopyOnWrite()->base = base;
+    if (opt) {
+      expr.CopyOnWrite()->args = Array<IterSplitExpr>({opt.value()});
+      return expr;
+    } else {
+      ++unresolved_count_;
+      return expr;
+    }
+  }
+
+  bool CanProveEqual(PrimExpr lhs, PrimExpr rhs) {
+    const auto* clhs = lhs.as<IntImmNode>();
+    const auto* crhs = rhs.as<IntImmNode>();
+    if (clhs && crhs) return clhs->value == crhs->value;
+    return analyzer_->CanProve(lhs - rhs == 0);
+  }
+
+  /*!
+   * \brief Create a IterSumExpr from expr.
+   * \param expr The input expr.
+   * \return The transformed IterSumExpr.
+   */
+  IterSumExpr ToIterSumExpr(PrimExpr expr) {
+    if (const auto* op = expr.as<IterSumExprNode>()) {
+      return GetRef<IterSumExpr>(op);
+    } else if (const auto* op = expr.as<IterSplitExprNode>()) {
+      return IterSumExpr({GetRef<IterSplitExpr>(op)}, make_zero(expr->dtype));
+    } else {
+      CHECK(!expr->IsInstance<IterMapExprNode>());
+      return IterSumExpr({}, expr);
+    }
+  }
+
+  // Try to normalize IterSum into a fused IterMark
+  // return a corresponding splitexpr if needed.
+  Optional<IterSplitExpr> TryFuseIters(IterSumExpr expr) {
+    if (!is_zero(expr->base)) return NullOpt;
+    if (expr->args.size() == 1) return expr->args[0];
+    // select the iterators in order
+    std::vector<bool> visited(expr->args.size(), false);
+    std::vector<IterSplitExpr> iters;
+    iters.reserve(expr->args.size());
+    // canonicalize the expression
+    // check if it can be remapped into a fused pattern.
+    PrimExpr expected_scale = make_const(expr->base->dtype, 1);
+    for (size_t i = 0; i < expr->args.size(); ++i) {
+      size_t j = 0;
+      for (; j < expr->args.size(); ++j) {
+        if (!visited[j] && CanProveEqual(expr->args[j]->scale, expected_scale)) break;
+      }
+      if (j == expr->args.size()) {
+        return NullOpt;
+      }
+      visited[j] = true;
+      iters.push_back(expr->args[j]);
+      expected_scale *= expr->args[j]->extent;
+    }
+    // update the iterator to use the canonicalized form
+    expr.CopyOnWrite()->args = Array<IterSplitExpr>(iters.rbegin(), iters.rend());
+    auto it = sum_fuse_map_.find(expr);
+    if (it != sum_fuse_map_.end()) return it->second;
+    auto mark = IterMark(expr, expected_scale);
+    IterSplitExpr split(mark);
+    sum_fuse_map_[expr] = split;
+    return split;
+  }
+
+  bool CanProveDivisible(PrimExpr lhs, PrimExpr rhs) {
+    const auto* clhs = lhs.as<IntImmNode>();
+    const auto* crhs = rhs.as<IntImmNode>();
+    if (clhs && crhs) return clhs->value % crhs->value == 0;
+    return analyzer_->CanProve(floormod(lhs, rhs) == 0);
+  }
+
+  PrimExpr SplitFloorDivConst(IterSplitExpr lhs, PrimExpr rhs);
+  PrimExpr SplitFloorModConst(IterSplitExpr lhs, PrimExpr rhs);
+
+  static void AddToLhs(IterSumExprNode* lhs, IterSplitExpr rhs, int sign) {
+    tir::ExprDeepEqual equal;
+    for (size_t i = 0; i < lhs->args.size(); ++i) {
+      IterSplitExpr lvalue = lhs->args[i];
+      if (lvalue->source.same_as(rhs->source) && equal(lvalue->lower_factor, rhs->lower_factor) &&
+          equal(lvalue->extent, rhs->extent)) {
+        if (sign > 0) {
+          rhs.CopyOnWrite()->scale = lvalue->scale + rhs->scale;
+        } else {
+          rhs.CopyOnWrite()->scale = lvalue->scale - rhs->scale;
+        }
+        lhs->args.Set(i, rhs);
+        return;
+      }
+    }
+    if (sign > 0) {
+      lhs->args.push_back(rhs);
+    } else {
+      rhs.CopyOnWrite()->scale = make_zero(rhs->scale.dtype()) - rhs->scale;
+      lhs->args.push_back(rhs);
+    }
+  }
+
+  static void AddToLhs(IterSumExprNode* lhs, IterSumExpr rhs, int sign) {
+    for (size_t i = 0; i < rhs->args.size(); ++i) {
+      AddToLhs(lhs, rhs->args[i], sign);
+    }
+    if (sign > 0) {
+      lhs->base += rhs->base;
+    } else {
+      lhs->base -= rhs->base;
+    }
+  }
+
+  static void MulToLhs(IterSumExprNode* lhs, PrimExpr rhs) {
+    for (size_t i = 0; i < lhs->args.size(); ++i) {
+      IterSplitExpr lvalue = lhs->args[i];
+      lvalue.CopyOnWrite()->scale *= rhs;
+      lhs->args.Set(i, lvalue);
+    }
+    lhs->base *= rhs;
+  }
+};
+
+Array<IterSumExpr> DetectIterMap(const Array<PrimExpr>& indices, const Map<Var, Range>& input_iters,
+                                 arith::Analyzer* analyzer) {
+  // Overall detection algorithm is divided into two steps:
+  // - Step0: IterMapRewriter rewrites the expression to use IterMapExpr patterns.
+  // - Step1: IterIndependenceChecker checks if the iterator are independent.
+  IterMapRewriter rewriter(analyzer, input_iters);
+  Array<IterSumExpr> results;
+
+  for (PrimExpr value : indices) {
+    results.push_back(rewriter.Rewrite(value));
+    if (rewriter.unresolved_count() != 0) return Array<IterSumExpr>();
+  }
+  if (!rewriter.CheckBijective(results)) return Array<IterSumExpr>();

Review comment:
       It also checks that all the intermediate marks are covered without overlapping.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] spectrometerHBH commented on a change in pull request #6667: [ARITH] Introduce iterator (quasi)affine map detection.

Posted by GitBox <gi...@apache.org>.
spectrometerHBH commented on a change in pull request #6667:
URL: https://github.com/apache/incubator-tvm/pull/6667#discussion_r503749804



##########
File path: include/tvm/arith/iter_affine_map.h
##########
@@ -0,0 +1,276 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/arith/iter_affine_map.h
+ * \brief Iterator quasi-affine mapping patterns.
+ *
+ *  This file defines a collection of mapping patterns
+ *  that map a collection of independent iterators to another
+ *  collection of independent iterators.
+ *
+ *  There are two main kinds of mapping patterns:
+ *
+ *  - Fuse: fuse a collection of iterators into a single one
+ *
+ *    domain(x0) = [0, 4), domain(x1) = [0, 3), domain(x2) = [0, 2)
+ *    fuse(x0, x1, x2): y = x2 * 12 + x1 * 4 + x0
+ *    domain(y) = [0, 24)
+ *
+ *  - Split: split an iterator into multiple ones
+ *
+ *    domain(x) = [0, 24)
+ *    split(x, 3, 12): [y0, y1, y2] = [x % 3, (x % 12) / 3, x / 12]
+ *    domain(y0) = [0, 3), domain(y1) = [0, 4), domain(y2) = [0, 2)
+ *
+ *  We use the name "(quasi)affine" to be consistent with
+ *  the terminology used in polyhedral compilation.
+ *  Notably, fuse is an affine transformation,
+ *  while split corresponds to additional floordiv/mod operations
+ *  that can appear in quasi-affine transformations.
+ */
+#ifndef TVM_ARITH_ITER_AFFINE_MAP_H_
+#define TVM_ARITH_ITER_AFFINE_MAP_H_
+
+#include <tvm/ir/expr.h>
+
+namespace tvm {
+namespace arith {
+
+/*!
+ * \brief Base class of all iter map expressions.
+ *
+ *  An IterMapExpr is a special expression to store
+ *  the result of IterMapDetection.
+ *  NOTE(review): the original sentence ("it should not ...") was left
+ *  unfinished. Presumably it meant that these expressions are only an
+ *  analysis result and should not appear in a legal TIR PrimFunc -- confirm.
+ */
+class IterMapExprNode : public PrimExprNode {
+ public:
+  // overrides
+  // This base node has no fields of its own, so there is nothing to visit.
+  void VisitAttrs(tvm::AttrVisitor* v) {}
+
+  static constexpr const char* _type_key = "arith.IterMapExpr";
+  // Type-index slots reserved for node types derived from this base.
+  static constexpr const uint32_t _type_child_slots = 3;
+  TVM_DECLARE_BASE_OBJECT_INFO(IterMapExprNode, PrimExprNode);
+};
+
+/*!
+ * \brief Managed reference to IterMapExprNode.
+ *
+ *  Declares no constructor of its own; it only carries the standard
+ *  reference methods generated by TVM_DEFINE_OBJECT_REF_METHODS.
+ * \sa IterMapExprNode
+ */
+class IterMapExpr : public PrimExpr {
+ public:
+  TVM_DEFINE_OBJECT_REF_METHODS(IterMapExpr, PrimExpr, IterMapExprNode);
+};
+
+/*!
+ * \brief Mark the source as an iterator in [0, extent).
+ *
+ *  IterMark is used to mark source expression as a valid
+ *  iterator to make future analysis easy.
+ */
+class IterMarkNode : public Object {
+ public:
+  /*!
+   * \brief The source expression, can either be
+   *  an IterSumExpr or a Var.
+   */
+  PrimExpr source;
+  /*!
+   * \brief The extent of the iteration: the marked iterator
+   *  takes values in [0, extent).
+   */
+  PrimExpr extent;
+
+  // overrides
+  // Visits both fields by name.
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("source", &source);
+    v->Visit("extent", &extent);
+  }
+
+  // Structural equality: two marks are equal when both source and extent are equal.
+  // The node is registered as a graph node via MarkGraphNode().
+  // NOTE(review): presumably because one IterMark can be referenced by several
+  // splits, so sharing must be respected -- confirm.
+  bool SEqualReduce(const IterMarkNode* other, SEqualReducer equal) const {
+    equal->MarkGraphNode();
+    return equal(source, other->source) && equal(extent, other->extent);
+  }
+
+  // Structural hash over the same fields (and graph-node marking) as SEqualReduce.
+  void SHashReduce(SHashReducer hash_reduce) const {
+    hash_reduce->MarkGraphNode();
+    hash_reduce(source);
+    hash_reduce(extent);
+  }
+
+  static constexpr const bool _type_has_method_sequal_reduce = true;
+  static constexpr const bool _type_has_method_shash_reduce = true;
+  static constexpr const char* _type_key = "arith.IterMark";
+  TVM_DECLARE_FINAL_OBJECT_INFO(IterMarkNode, Object);
+};
+
+/*!
+ * \brief Managed reference to IterMarkNode.
+ * \sa IterMarkNode
+ */
+class IterMark : public ObjectRef {
+ public:
+  /*!
+   * \brief constructor.
+   * \param source The source expression (an IterSumExpr or a Var).
+   * \param extent The extent of the iterator; it takes values in [0, extent).
+   */
+  TVM_DLL IterMark(PrimExpr source, PrimExpr extent);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(IterMark, ObjectRef, IterMarkNode);
+};
+
+/*!
+ * \brief Split of an iterator.
+ *
+ *  result = floormod(floordiv(source, lower_factor), extent) * scale
+ */
+class IterSplitExprNode : public IterMapExprNode {
+ public:
+  /*! \brief The source marked iterator. */
+  IterMark source;

Review comment:
       Can `source` of IterSplitExpr be an IterSumExpr, i.e. we split some iterators, fuse some of them and split some of them again?

##########
File path: include/tvm/arith/iter_affine_map.h
##########
@@ -0,0 +1,275 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/arith/iter_affine_map.h
+ * \brief Iterator quasi-affine mapping patterns.
+ *
+ *  This file defines a collection of mapping patterns
+ *  that map a collection of independent iterators to another
+ *  collection of independent iterators.
+ *
+ *  There are two main kinds of mapping patterns:
+ *
+ *  - Fuse: fuse a collection of iterators into a single one
+ *
+ *    domain(x0) = [0, 4), domain(x1) = [0, 3), domain(x2) = [0, 2)
+ *    fuse(x0, x1, x2): y = x2 * 12 + x1 * 4 + x0
+ *    domain(y) = [0, 24)
+ *
+ *  - Split: split an iterator into multiple ones
+ *
+ *    domain(x) = [0, 24)
+ *    split(x, 3, 12): [y0, y1, y2] = [x % 3, (x % 12) / 3, x / 12]
+ *    domain(y0) = [0, 3), domain(y1) = [0, 4), domain(y2) = [0, 2)
+ *
+ *  We use the name "(quasi)affine" to be consistent with
+ *  the terminology used in polyhedral compilation.
+ *  Notably, fuse is an affine transformation,
+ *  while split corresponds to additional floordiv/mod operations
+ *  that can appear in quasi-affine transformations.
+ */
+#ifndef TVM_ARITH_ITER_AFFINE_MAP_H_
+#define TVM_ARITH_ITER_AFFINE_MAP_H_
+
+#include <tvm/ir/expr.h>
+
+namespace tvm {
+namespace arith {
+
+/*!
+ * \brief Base class of all iter map expressions.
+ *
+ *  An IterMapExpr is a special expression to store
+ *  the result of IterMapDetection, it should not

Review comment:
       Would be great to complete the comments

##########
File path: src/arith/iter_affine_map.cc
##########
@@ -0,0 +1,704 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/arith/iter_affine_map.cc
+ */
+#include <tvm/arith/analyzer.h>
+#include <tvm/arith/iter_affine_map.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/expr_functor.h>
+#include <tvm/tir/op.h>
+
+#include "../support/util.h"
+#include "const_fold.h"
+
+namespace tvm {
+namespace arith {
+
+using namespace tir;
+
+IterMark::IterMark(PrimExpr source, PrimExpr extent) {
+  // Build the node, hand both expressions over to it, then take ownership.
+  auto node = make_object<IterMarkNode>();
+  node->extent = std::move(extent);
+  node->source = std::move(source);
+  data_ = std::move(node);
+}
+
+// Register the constructor as the global function "arith.IterMark".
+TVM_REGISTER_GLOBAL("arith.IterMark").set_body_typed([](PrimExpr source, PrimExpr extent) {
+  return IterMark(source, extent);
+});
+
+TVM_REGISTER_NODE_TYPE(IterMarkNode);
+
+// Prints a mark as "IterMark(<source>, extent=<extent>)".
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<IterMarkNode>([](const ObjectRef& node, ReprPrinter* p) {
+      auto* op = static_cast<const IterMarkNode*>(node.get());
+      // Emit the closing parenthesis so the repr is balanced (as IterSum's printer does).
+      p->stream << "IterMark(" << op->source << ", extent=" << op->extent << ")";
+    });
+
+IterSplitExpr::IterSplitExpr(IterMark source) {
+  // Default split: the whole mark, with unit lower_factor and unit scale.
+  // A single constant object is shared by both fields.
+  const auto dtype = source->source->dtype;
+  const PrimExpr unit = make_const(dtype, 1);
+  auto node = make_object<IterSplitExprNode>();
+  node->dtype = dtype;
+  node->extent = source->extent;
+  node->lower_factor = unit;
+  node->scale = unit;
+  node->source = std::move(source);
+  data_ = std::move(node);
+}
+
+IterSplitExpr::IterSplitExpr(IterMark source, PrimExpr lower_factor, PrimExpr extent,
+                             PrimExpr scale) {
+  // Represents floormod(floordiv(source, lower_factor), extent) * scale;
+  // the dtype follows the marked source expression.
+  auto node = make_object<IterSplitExprNode>();
+  node->dtype = source->source->dtype;
+  node->scale = std::move(scale);
+  node->extent = std::move(extent);
+  node->lower_factor = std::move(lower_factor);
+  node->source = std::move(source);
+  data_ = std::move(node);
+}
+
+// Register the 4-argument constructor as the global function "arith.IterSplitExpr".
+TVM_REGISTER_GLOBAL("arith.IterSplitExpr")
+    .set_body_typed([](IterMark source, PrimExpr lower_factor, PrimExpr extent, PrimExpr scale) {
+      return IterSplitExpr(source, lower_factor, extent, scale);
+    });
+
+TVM_REGISTER_NODE_TYPE(IterSplitExprNode);
+
+// Prints a split as "IterSplit(<source>, lower_factor=..., extent=..., scale=...)".
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<IterSplitExprNode>([](const ObjectRef& node, ReprPrinter* p) {
+      auto* op = static_cast<const IterSplitExprNode*>(node.get());
+      // Emit the closing parenthesis so the repr is balanced (as IterSum's printer does).
+      p->stream << "IterSplit(" << op->source << ", lower_factor=" << op->lower_factor
+                << ", extent=" << op->extent << ", scale=" << op->scale << ")";
+    });
+
+IterSumExpr::IterSumExpr(Array<IterSplitExpr> args, PrimExpr base) {
+  // The sum takes its dtype from the constant base term (read before base is moved).
+  auto node = make_object<IterSumExprNode>();
+  node->dtype = base->dtype;
+  node->base = std::move(base);
+  node->args = std::move(args);
+  data_ = std::move(node);
+}
+
+// Register the constructor as the global function "arith.IterSumExpr".
+TVM_REGISTER_GLOBAL("arith.IterSumExpr")
+    .set_body_typed([](Array<IterSplitExpr> split_args, PrimExpr base_value) {
+      return IterSumExpr(split_args, base_value);
+    });
+
+TVM_REGISTER_NODE_TYPE(IterSumExprNode);
+
+// Prints a sum as "IterSum(<args>, <base>)".
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<IterSumExprNode>([](const ObjectRef& node, ReprPrinter* p) {
+      const auto* sum = static_cast<const IterSumExprNode*>(node.get());
+      p->stream << "IterSum(" << sum->args;
+      p->stream << ", " << sum->base << ")";
+    });
+
+/*!
+ * \brief Util to collect, for every IterMark reachable from a set of
+ *        IterSumExprs, the splits that reference that mark.
+ *
+ *  This class only gathers information (visited_ and mark2splits_).
+ *  NOTE(review): presumably the result feeds the check that all splits are
+ *  independent and complete (cover the original iter space); that check
+ *  is not performed here.
+ */
+class IterMarkSplitCollector {
+ public:
+  // mark all IterMarks that are visited.
+  std::unordered_set<IterMark, ObjectPtrHash, ObjectPtrEqual> visited_;
+  // each iter mark to its outgoing splits that are referenced.
+  std::unordered_map<IterMark, std::vector<IterSplitExpr>, ObjectPtrHash, ObjectPtrEqual>
+      mark2splits_;
+  /*!
+   * \brief Collect all mark2splits recursively from indices.
+   * \param indices The iterator of interest.
+   */
+  void Collect(const Array<IterSumExpr>& indices) {
+    for (const IterSumExpr& sum_expr : indices) {
+      this->CollectSplits(sum_expr->args);
+    }
+  }
+
+  /*!
+   * \brief Visit mark once; if its source is an IterSumExpr, recurse into its splits.
+   * \param mark The mark to visit.
+   */
+  void CollectInternal(const IterMark& mark) {
+    // insert(...).second is false when the mark has already been visited.
+    if (!visited_.insert(mark).second) return;
+    if (const auto* op = mark->source.as<IterSumExprNode>()) {
+      this->CollectSplits(op->args);
+    }
+  }
+
+ private:
+  // Visit each split's source mark (depth-first), then record the split under that mark.
+  void CollectSplits(const Array<IterSplitExpr>& splits) {
+    for (const IterSplitExpr& split : splits) {
+      this->CollectInternal(split->source);
+      mark2splits_[split->source].push_back(split);
+    }
+  }
+};
+
+// Rewriter to rewrite information in iter

Review comment:
       Typo




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org