You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/03/30 00:50:37 UTC

[GitHub] [tvm] Hzfengsy commented on a change in pull request #7760: [ARITH] Subspace division

Hzfengsy commented on a change in pull request #7760:
URL: https://github.com/apache/tvm/pull/7760#discussion_r603704512



##########
File path: src/arith/iter_affine_map.cc
##########
@@ -1028,5 +1028,358 @@ PrimExpr IterMapRewriter::VisitExpr_(const FloorModNode* op) {
   }
 }
 
+/*!
+ * \brief Given an IterMapExpr, transform it to a normal PrimExpr.
+ *
+ * An IterSumExpr becomes the sum of its converted args plus its base; an
+ * IterSplitExpr becomes a floordiv/floormod expression over its converted
+ * source, multiplied by its scale.
+ */
+class IterMapToExprNormalizer {
+ public:
+  /*!
+   * \brief Construct the normalizer.
+   * \param analyzer Analyzer used to prove extent equalities while converting splits.
+   */
+  explicit IterMapToExprNormalizer(Analyzer* analyzer) : analyzer_(analyzer) {}
+
+  /*!
+   * \brief Convert an IterMapExpr by dispatching on its concrete node type.
+   * \param expr The expression to convert; must be an IterSplitExpr or an IterSumExpr.
+   * \return The converted PrimExpr.
+   */
+  PrimExpr Convert(const IterMapExpr& expr) {
+    if (const auto* op = expr.as<IterSplitExprNode>()) {
+      return ConvertIterSplitExpr(GetRef<IterSplitExpr>(op));
+    } else if (const auto* op = expr.as<IterSumExprNode>()) {
+      return ConvertIterSumExpr(GetRef<IterSumExpr>(op));
+    } else {
+      ICHECK(expr.defined());
+      LOG(FATAL) << "Unknown IterMapExpr type " << expr->GetTypeKey();
+      return 0;
+    }
+  }
+
+  /*!
+   * \brief Convert an IterSumExpr into sum(converted args) + base.
+   * \param expr The sum expression to convert.
+   * \return The converted PrimExpr.
+   */
+  PrimExpr ConvertIterSumExpr(const IterSumExpr& expr) {
+    PrimExpr res = 0;
+    for (const IterSplitExpr& arg : expr->args) {
+      res += ConvertIterSplitExpr(arg);
+    }
+    res += expr->base;
+    return res;
+  }
+
+  /*!
+   * \brief Convert an IterSplitExpr into
+   *        floormod(floordiv(source, lower_factor), extent) * scale,
+   *        dropping the floormod and/or floordiv when the analyzer proves them redundant.
+   * \param expr The split expression to convert.
+   * \return The converted PrimExpr.
+   */
+  PrimExpr ConvertIterSplitExpr(const IterSplitExpr& expr) {
+    PrimExpr source;
+    if (const auto* op = expr->source->source.as<VarNode>()) {
+      source = GetRef<Var>(op);
+    } else if (const auto* op = expr->source->source.as<IterSumExprNode>()) {
+      source = ConvertIterSumExpr(GetRef<IterSumExpr>(op));
+    } else {
+      // Fail loudly instead of continuing with an undefined `source` below.
+      LOG(FATAL) << "IterMark source must be a Var or an IterSumExpr, but got "
+                 << expr->source->source;
+    }
+    if (analyzer_->CanProve(expr->extent == expr->source->extent) && is_one(expr->lower_factor)) {
+      // The split covers the whole source: neither floordiv nor floormod is needed.
+      return source * expr->scale;
+    } else if (analyzer_->CanProve(expr->source->extent == expr->lower_factor * expr->extent)) {
+      // The split is the outermost part of the source: the floormod is redundant.
+      return floordiv(source, expr->lower_factor) * expr->scale;
+    } else {
+      return floormod(floordiv(source, expr->lower_factor), expr->extent) * expr->scale;
+    }
+  }
+
+ private:
+  /*! \brief Analyzer used for the CanProve queries; not owned by this class. */
+  Analyzer* analyzer_;
+};
+
+/*!
+ * \brief Convert an IterMapExpr to a normal PrimExpr using a freshly created Analyzer.
+ * \param expr The IterMapExpr to convert.
+ * \return The converted PrimExpr.
+ */
+PrimExpr NormalizeIterMapToExpr(const IterMapExpr& expr) {
+  arith::Analyzer local_analyzer;
+  return IterMapToExprNormalizer(&local_analyzer).Convert(expr);
+}
+
+// Register NormalizeIterMapToExpr as the global function "arith.NormalizeIterMapToExpr".
+TVM_REGISTER_GLOBAL("arith.NormalizeIterMapToExpr")
+    .set_body_typed([](const IterMapExpr& iter_map) { return NormalizeIterMapToExpr(iter_map); });
+
+/*!
+ * \brief Divider to divide the bindings into two sets of bindings(outer and inner)
+ *   such that binding_i = Y_i * E(Xi) + Xi, where E(X) is the extent of X.
+ *   We do message passing among IterSplitExpr and IterSumExpr.
+ *
+ *   Example
+ *   - If we encounter sum = i*10 + j*5 + k, and i, j, k are splits,
+ *     and we know i = Yi*1 + 0, j = 0*E(Xj) + Xj, k = 0*E(Xk) + Xk through message passing,
+ *     then sum = Yi*10 + (Xj*5 + Xk) = Y*E(X) + X, where Y = Yi, X = Xj*5 + Xk.
+ *   - If we encounter split = (i / 2) % 4, and we know i = Y*E(X) + X through message passing.
+ *     We inspect all the splits of i, which are i / 8, (i / 2) % 4, i % 2.
+ *     Their extents are 2, 4, 2; if E(X) is 2, 8, or 16, the splits can be divided.
+ */
+class SubspaceDivider {
+ public:
+  explicit SubspaceDivider(Analyzer* analyzer, const IterMarkSplitCollector& collector,
+                           const std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual>& sub_iters)
+      : analyzer_(analyzer), collector_(collector), sub_iters_(sub_iters) {}
+
+  size_t unresolved_count() const { return unresolved_count_; }
+
+  // Denotes outer*inner_extent + inner, used as message passing carrier
+  struct DivisionResult {
+   public:
+    // IterMapExpr of outer iters
+    IterMapExpr outer;
+    // IterMapExpr of inner iters
+    IterMapExpr inner;
+    // extent of outer
+    PrimExpr outer_extent;
+    // extent of inner
+    PrimExpr inner_extent;
+
+    DivisionResult(IterMapExpr outer, PrimExpr outer_extent, IterMapExpr inner,
+                   PrimExpr inner_extent)
+        : outer(std::move(outer)),
+          inner(std::move(inner)),
+          outer_extent(std::move(outer_extent)),
+          inner_extent(std::move(inner_extent)) {}
+
+    // whether the division result is totally in outer subspace
+    bool IsOuter() const { return is_one(inner_extent); }
+
+    // whether the division result is totally in inner subspace
+    bool IsInner() const { return is_one(outer_extent); }
+
+    IterSplitExpr GetOuterAsSplit() const { return GetAsSplit(outer, outer_extent); }
+
+    IterSplitExpr GetInnerAsSplit() const { return GetAsSplit(inner, inner_extent); }
+
+    static DivisionResult Inner(const IterMapExpr& iter, const PrimExpr& extent) {
+      return DivisionResult(IterSumExpr({}, 0), 1, iter, extent);
+    }
+
+    static DivisionResult Outer(const IterMapExpr& iter, const PrimExpr& extent) {
+      return DivisionResult(iter, extent, IterSumExpr({}, 0), 1);
+    }
+
+   private:
+    static IterSplitExpr GetAsSplit(const IterMapExpr& expr, const PrimExpr& extent) {
+      if (const auto* op = expr.as<IterSplitExprNode>()) {
+        return GetRef<IterSplitExpr>(op);
+      } else if (const auto* op = expr.as<IterSumExprNode>()) {
+        return IterSplitExpr(IterMark(GetRef<IterSumExpr>(op), extent));
+      } else {
+        LOG(FATAL);

Review comment:
       It would be great if you can add some error message here.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org