You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/03/31 18:22:45 UTC

[GitHub] [tvm] hzfan commented on a change in pull request #7760: [ARITH] Subspace division

hzfan commented on a change in pull request #7760:
URL: https://github.com/apache/tvm/pull/7760#discussion_r605105753



##########
File path: src/arith/iter_affine_map.cc
##########
@@ -1086,5 +1086,302 @@ TVM_REGISTER_GLOBAL("arith.NormalizeIterMapToExpr").set_body_typed([](const Iter
   return NormalizeIterMapToExpr(expr);
 });
 
+/*!
+ * \brief Divider to divide the bindings into two sets of bindings (outer and inner)
+ *   such that binding_i = Y_i * E(X_i) + X_i, where E(X) is the extent of X.
+ *   We do message passing among IterSplitExpr and IterSumExpr.
+ *
+ *   Example
+ *   - If we encounter sum = i*10 + j*5 + k, and i, j, k are splits,
+ *     and we know i = Yi*1 + 0, j = 0*E(Xj) + Xj, k = 0*E(Xk) + Xk through message passing,
+ *     then sum = Yi*10 + (Xj*5 + Xk) = Y*E(X) + X, where Y = Yi, X = Xj*5 + Xk.
+ *   - If we encounter split = (i / 2) % 4 and know i = Y*E(X) + X through message passing,
+ *     we inspect all the splits of i, which are i / 8, (i / 2) % 4 and i % 2.
+ *     Their extents are 2, 4 and 2, so the splits can be divided if E(X) is 2, 8 or 16.
+ */
+class SubspaceDivider {
+ public:
+  /*!
+   * \brief Construct a divider.
+   * \param analyzer Analyzer used to prove extent equalities and to build the
+   *   IterMapToExprNormalizer that turns iter maps into predicates.
+   * \param collector Collector whose mark2splits_ lists, for each IterMark, all of its splits.
+   * \param sub_iters Input iterators spanning the inner subspace; splits of any other input
+   *   iterator are classified as outer.
+   * \note NOTE(review): the member declarations are outside this excerpt; if analyzer_ and
+   *   collector_ are held by pointer/reference, the arguments must outlive the divider --
+   *   confirm.
+   */
+  explicit SubspaceDivider(Analyzer* analyzer, const IterMarkSplitCollector& collector,
+                           const std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual>& sub_iters)
+      : analyzer_(analyzer),
+        collector_(collector),
+        sub_iters_(sub_iters) {}
+
+  /*! \brief Number of sub-expressions that could not be divided (each call to Fail() adds one). */
+  size_t unresolved_count() const {
+    return unresolved_count_;
+  }
+
+  /*!
+   * \brief Message-passing carrier that represents an iterator expression in the form
+   *   outer * inner_extent + inner.
+   */
+  struct DivisionResult {
+    /*! \brief Iterator expression built only from outer iterators. */
+    IterMapExpr outer;
+    /*! \brief Iterator expression built only from inner (subspace) iterators. */
+    IterMapExpr inner;
+    /*! \brief Extent of `outer`. */
+    PrimExpr outer_extent;
+    /*! \brief Extent of `inner`. */
+    PrimExpr inner_extent;
+
+    /*! \note Arguments are ordered (outer, outer_extent, inner, inner_extent). */
+    DivisionResult(IterMapExpr outer_expr, PrimExpr outer_ext, IterMapExpr inner_expr,
+                   PrimExpr inner_ext)
+        : outer(std::move(outer_expr)),
+          inner(std::move(inner_expr)),
+          outer_extent(std::move(outer_ext)),
+          inner_extent(std::move(inner_ext)) {}
+
+    /*! \brief True when the inner part has extent 1, i.e. the result lies entirely in outer. */
+    bool IsOuter() const { return is_one(inner_extent); }
+
+    /*! \brief True when the outer part has extent 1, i.e. the result lies entirely in inner. */
+    bool IsInner() const { return is_one(outer_extent); }
+
+    /*!
+     * \brief View `outer` as one IterSplitExpr: returned as-is if it already is a split,
+     *   otherwise the sum is wrapped into a mark of extent outer_extent.
+     */
+    IterSplitExpr GetOuterAsSplit() const { return GetAsSplit(outer, outer_extent); }
+
+    /*!
+     * \brief View `inner` as one IterSplitExpr: returned as-is if it already is a split,
+     *   otherwise the sum is wrapped into a mark of extent inner_extent.
+     */
+    IterSplitExpr GetInnerAsSplit() const { return GetAsSplit(inner, inner_extent); }
+
+    /*! \brief Purely inner result: 0 * extent + iter (outer is the empty sum, extent 1). */
+    static DivisionResult Inner(const IterMapExpr& iter, const PrimExpr& extent) {
+      return DivisionResult(IterSumExpr({}, 0), 1, iter, extent);
+    }
+
+    /*! \brief Purely outer result: iter * 1 + 0 (inner is the empty sum, extent 1). */
+    static DivisionResult Outer(const IterMapExpr& iter, const PrimExpr& extent) {
+      return DivisionResult(iter, extent, IterSumExpr({}, 0), 1);
+    }
+
+   private:
+    static IterSplitExpr GetAsSplit(const IterMapExpr& expr, const PrimExpr& extent) {
+      if (const auto* split = expr.as<IterSplitExprNode>()) {
+        return GetRef<IterSplitExpr>(split);
+      }
+      if (const auto* sum = expr.as<IterSumExprNode>()) {
+        // A sum has no split form of its own; wrap it into a mark of the given extent.
+        return IterSplitExpr(IterMark(GetRef<IterSumExpr>(sum), extent));
+      }
+      LOG(FATAL) << "Unknown IterMapExpr type";
+      return NullValue<IterSplitExpr>();
+    }
+  };
+
+  /*!
+   * \brief Divide an IterSumExpr into outer * E(inner) + inner.
+   * \param expr The sum to divide (DivideIterSplitExpr passes the source of an IterMark).
+   * \param mark_extent The extent of the mark whose source is `expr`. If the product of the
+   *   divided args' extents cannot be proven equal to it, a predicate is recorded in
+   *   outer_preds_ or inner_preds_ instead.
+   * \return The division result, or the placeholder from Fail() (unresolved_count_ is then
+   *   incremented).
+   */
+  DivisionResult DivideIterSumExpr(const IterSumExpr& expr, const PrimExpr& mark_extent) {
+    if (expr->args.empty()) {
+      // base
+      // Both extents are 1, so the result satisfies IsOuter() and IsInner(); the constant
+      // base is kept in `inner`.
+      return DivisionResult(IterSumExpr({}, 0), 1, IterSumExpr({}, expr->base), 1);
+    } else if (expr->args.size() == 1) {
+      // arg + base, if arg=Y*E(X)+X, then arg+base = Y*E(X)+(X+base)
+      if (!is_one(expr->args[0]->scale)) return Fail();
+      DivisionResult res = DivideIterSplitExpr(expr->args[0]);
+      if (!is_zero(expr->base)) res = AddBase(res, expr->base);
+      return res;
+    }
+    // arg1 + arg2 + ... + argn + base
+    // then we can write it as Y*E(X)+X
+    // if it starts with contiguous outer splits, followed by contiguous inner splits
+    // NOTE(review): apart from requiring at least one unit scale, the original scales of the
+    // args are not checked; MarkFromArgsAndBase recomputes them from the extents. This
+    // presumes the args already have contiguous scales -- TODO confirm against callers.
+    PrimExpr extent = 1;
+    std::vector<IterSplitExpr> outer_args, inner_args;
+    bool inner = true, scale_is_one = false;
+    // we check in inverse order so we can visit from inner to outer
+    for (auto it = expr->args.rbegin(); it != expr->args.rend(); ++it) {
+      const IterSplitExpr& arg = *it;
+      if (is_one(arg->scale)) scale_is_one = true;
+      DivisionResult arg_division = DivideIterSplitExpr(arg);
+      IterSplitExpr new_arg;
+      if (arg_division.IsInner()) {
+        // An inner arg after an outer one (in inner-to-outer order) breaks the Y*E(X)+X form.
+        if (!inner) return Fail();
+        new_arg = arg_division.GetInnerAsSplit();
+        inner_args.push_back(new_arg);
+        inner = true;
+      } else if (arg_division.IsOuter()) {
+        new_arg = arg_division.GetOuterAsSplit();
+        outer_args.push_back(new_arg);
+        inner = false;
+      } else {
+        // The arg itself mixes outer and inner iterators; it cannot be placed on either side.
+        return Fail();
+      }
+      extent *= new_arg->extent;
+    }
+    if (!scale_is_one) return Fail();
+    // A predicate is needed when the args do not provably cover exactly [0, mark_extent).
+    bool need_predicate = !analyzer_->CanProveEqual(extent, mark_extent);
+    // Both vectors are ordered inner to outer, as MarkFromArgsAndBase expects.
+    // NOTE(review): expr->base is placed only in the inner sum; when the result ends up purely
+    // outer (inner_args empty), `outer` does not carry it. Presumably sources reaching here
+    // have base 0 in that case -- TODO confirm.
+    const IterMark& outer_mark = MarkFromArgsAndBase(outer_args, 0);
+    const IterMark& inner_mark = MarkFromArgsAndBase(inner_args, expr->base);
+    IterSumExpr outer_source = Downcast<IterSumExpr>(outer_mark->source);
+    IterSumExpr inner_source = Downcast<IterSumExpr>(inner_mark->source);
+    if (need_predicate) {
+      // if we have a predicate on this sum expr, then we cannot divide it into Y*E+X
+      // it should either be Y*1+0 or 0*E(X)+X
+      IterMapToExprNormalizer converter(analyzer_);
+      if (inner_args.empty()) {
+        // Y*1+0
+        outer_preds_ = outer_preds_ && (converter.Convert(outer_source) < mark_extent);
+        return DivisionResult::Outer(outer_source, mark_extent);
+      } else if (outer_args.empty()) {
+        // 0*E(X)+X
+        inner_preds_ = inner_preds_ && (converter.Convert(inner_source) < mark_extent);
+        return DivisionResult::Inner(inner_source, mark_extent);
+      } else {
+        // A predicate over a sum that mixes outer and inner args cannot be split in two.
+        return Fail();
+      }
+    }
+    return DivisionResult(outer_source, outer_mark->extent, inner_source, inner_mark->extent);
+  }
+
+  /*! \brief Predicate accumulated (combined with &&) on the outer iterators while dividing. */
+  PrimExpr GetOuterPreds() const {
+    return outer_preds_;
+  }
+
+  /*! \brief Predicate accumulated (combined with &&) on the inner iterators while dividing. */
+  PrimExpr GetInnerPreds() const {
+    return inner_preds_;
+  }
+
+ private:
+  /*!
+   * \brief Record one division failure and return a placeholder result.
+   * \return A result whose outer and inner parts are empty sums with extent 0, so neither
+   *   IsOuter() nor IsInner() holds for it.
+   */
+  DivisionResult Fail() {
+    ++unresolved_count_;
+    return DivisionResult(IterSumExpr({}, 0), 0, IterSumExpr({}, 0), 0);
+  }
+
+  /*!
+   * \brief Shift the inner part of a division by a constant: Y*E(X)+X -> Y*E(X)+(X+base).
+   * \param division The result to shift; its outer part and both extents are kept unchanged.
+   * \param base The constant to add to the inner part.
+   * \return A copy of `division` whose inner part is a sum carrying `base`. If the inner part
+   *   is neither a split nor a sum, the copy is returned unchanged.
+   */
+  DivisionResult AddBase(DivisionResult division, PrimExpr base) {
+    DivisionResult shifted = division;
+    if (const auto* sum = division.inner.as<IterSumExprNode>()) {
+      // Already a sum: fold the new constant into its existing base.
+      shifted.inner = IterSumExpr(sum->args, sum->base + base);
+    } else if (const auto* split = division.inner.as<IterSplitExprNode>()) {
+      // A bare split: wrap it into a one-element sum that carries the base.
+      shifted.inner = IterSumExpr({GetRef<IterSplitExpr>(split)}, base);
+    }
+    return shifted;
+  }
+
+  /*!
+   * \brief Build an IterMark whose source is the sum of `args` (with rewritten scales) plus
+   *   `base`.
+   * \param args Splits ordered from innermost to outermost.
+   * \param base Constant offset of the resulting sum.
+   * \return A mark whose extent is the product of the args' extents. Each arg's new scale is
+   *   the product of the extents of the args inside it, so the innermost gets scale 1.
+   */
+  static IterMark MarkFromArgsAndBase(const std::vector<IterSplitExpr>& args, PrimExpr base) {
+    PrimExpr running_extent = 1;
+    std::vector<IterSplitExpr> rescaled;
+    for (IterSplitExpr arg : args) {
+      // `arg` is a local copy, so CopyOnWrite leaves the caller's split untouched.
+      arg.CopyOnWrite()->scale = running_extent;
+      running_extent *= arg->extent;
+      rescaled.push_back(arg);
+    }
+    // The sum lists its args outermost first, so emit them in reverse.
+    Array<IterSplitExpr> sum_args(rescaled.rbegin(), rescaled.rend());
+    return IterMark(IterSumExpr(sum_args, base), running_extent);
+  }
+
+  /*!
+   * \brief Divide an IterSplitExpr into outer * E(inner) + inner.
+   *
+   * Results are memoized in split_map_. The first time any split of an IterMark is seen, the
+   * divisions of all splits of that mark (as listed in collector_.mark2splits_) are computed
+   * and recorded. The one exception is a sum-sourced mark that has exactly one split: its
+   * division is returned directly and is not recorded.
+   * \param expr The split to divide.
+   * \return The division result of `expr`, or the placeholder from Fail().
+   */
+  DivisionResult DivideIterSplitExpr(const IterSplitExpr& expr) {
+    auto it = split_map_.find(expr);
+    if (it != split_map_.end()) {
+      // We will calculate all the splits of an IterMark's division form when we first
+      // encounter one of them. If we encounter another later, we directly return the record.
+      return it->second;
+    }
+    const Array<IterSplitExpr>& splits = collector_.mark2splits_.at(expr->source);
+    if (const auto* iter_ptr = expr->source->source.as<VarNode>()) {
+      // source is input_iter
+      // Every split of an input iterator lies entirely on the side the iterator belongs to.
+      bool inner = sub_iters_.count(GetRef<Var>(iter_ptr));
+      for (const IterSplitExpr& split : splits) {
+        if (inner) {
+          // 0*E(split)+split
+          split_map_.emplace(split, DivisionResult::Inner(split, split->extent));
+        } else {
+          // split*1 + 0
+          split_map_.emplace(split, DivisionResult::Outer(split, split->extent));
+        }
+      }
+    } else if (const auto* iter_ptr = expr->source->source.as<IterSumExprNode>()) {
+      // source = Y*E+X
+      // splits = [s1, s2, ..., sn]
+      // we can divide if there exists i, such that extent(s1)extent(s2)...extent(si)=extent(Y)
+      //                                            extent(si+1)...extent(sn)=extent(X)
+      // For example, if source = Y*3+X \in [0, 12), Y \in [0, 4), X \in [0, 3)
+      // Case 1. splits = [s1, s2, s3] = [source / 6, (source / 3) % 2, source % 3],
+      //         where extent(s1) = 2, extent(s2) = 2, extent(s3) = 3.
+      //         Since extent(s1)extent(s2) = extent(Y), extent(s3) = extent(X), we have
+      //         s1 = (Y / 2)*1 + 0, s2 = (Y % 2)*1 + 0, s3 = 0*3 + X
+      // Case 2. splits = [s1, s2, s3] = [source / 4, (source / 2) % 2, source % 2],
+      //         where extent(s1) = 3, extent(s2) = 2, extent(s3) = 2.
+      //         It's impossible to rewrite s1, s2, s3 in the form of Y*E(X) + X.
+      // NOTE(review): if this division fails (extents 0), the boundary search below can never
+      // match inner_extent and fails again, so unresolved_count_ may be incremented twice for
+      // one failure. This is harmless if callers only compare it against zero -- confirm.
+      DivisionResult mark_division =
+          DivideIterSumExpr(GetRef<IterSumExpr>(iter_ptr), expr->source->extent);
+      if (splits.size() == 1) {
+        // NOTE(review): returns the division of the whole mark without memoizing it. This
+        // presumes the single split covers the whole mark (lower_factor 1, full extent) --
+        // TODO confirm.
+        return mark_division;
+      }
+      // NOTE(review): when the source sum has a single arg and base 0, DivideIterSumExpr can
+      // return an IterSplitExpr in outer/inner, and these Downcasts would then fail. Presumably
+      // such marks never have more than one split -- TODO confirm.
+      IterMark outer_mark(Downcast<IterSumExpr>(mark_division.outer), mark_division.outer_extent);
+      IterMark inner_mark(Downcast<IterSumExpr>(mark_division.inner), mark_division.inner_extent);
+      // A purely outer mark has its boundary at lower factor 1, so every split is outer.
+      bool encountered_boundary = mark_division.IsOuter();
+      std::vector<bool> used(splits.size(), false);
+      std::vector<IterSplitExpr> inner_iters, outer_iters;
+      PrimExpr expected_lower_factor = make_const(expr->source->source->dtype, 1);
+      // find the boundary of outer and inner, like case 1 above
+      // Chain the splits from lower_factor 1 upwards. Each step picks the unused split whose
+      // lower_factor equals the product of the extents chained so far.
+      for (size_t i = 0; i < splits.size(); ++i) {
+        size_t j = 0;
+        for (; j < splits.size(); ++j) {
+          if (!used[j] && analyzer_->CanProveEqual(splits[j]->lower_factor, expected_lower_factor))
+            break;
+        }
+        // A gap in the chain of lower factors: the splits do not tile the mark contiguously.
+        if (j == splits.size()) return Fail();
+        used[j] = true;
+        if (!encountered_boundary) {
+          inner_iters.push_back(splits[j]);
+        } else {
+          outer_iters.push_back(splits[j]);
+        }
+        expected_lower_factor *= splits[j]->extent;
+        // Once the chained extent equals E(X), every later split belongs to the outer part.
+        if (analyzer_->CanProveEqual(expected_lower_factor, mark_division.inner_extent))
+          encountered_boundary = true;
+      }
+      if (!encountered_boundary) return Fail();
+      // Inner splits keep their lower_factor; only their source is redirected to the inner mark.
+      for (const IterSplitExpr& inner_iter : inner_iters) {
+        IterSplitExpr new_iter = inner_iter;
+        new_iter.CopyOnWrite()->source = inner_mark;
+        split_map_.emplace(inner_iter, DivisionResult::Inner(new_iter, inner_iter->extent));
+      }
+      // Outer splits are redirected to the outer mark. Their lower_factor is re-based on the
+      // first outer split's lower_factor, which is the chained product at the boundary.
+      for (const IterSplitExpr& outer_iter : outer_iters) {
+        IterSplitExpr new_iter = outer_iter;
+        new_iter.CopyOnWrite()->source = outer_mark;
+        new_iter.CopyOnWrite()->lower_factor =
+            floordiv(outer_iter->lower_factor, outer_iters[0]->lower_factor);
+        split_map_.emplace(outer_iter, DivisionResult::Outer(new_iter, outer_iter->extent));
+      }
+    } else {
+      return Fail();
+    }
+    // The lookup presumes `expr` itself is one of `splits` recorded above.
+    return split_map_.at(expr);
+  }
+
+  /*! \brief Number of sub-expressions that failed to divide; incremented by every Fail() call. */
+  size_t unresolved_count_{0};

Review comment:
       Could you add comments for these members? The same goes for `analyzer_` and `collector_`.

##########
File path: python/tvm/arith/iter_affine_map.py
##########
@@ -128,3 +128,46 @@ def normalize_iter_map_to_expr(expr):
         the corresponding normal PrimExpr
     """
     return _ffi_api.NormalizeIterMapToExpr(expr)
+
+
+def subspace_divide(bindings, input_iters, sub_iters, predicate=True, require_bijective=False):
+    """Detect if bindings can be written as
+    [a_0*e_0 + b_0 + c_0, a_1*e_1 + b_1, ..., a_n*e_n + b_n]
+    where a = some-quasi-affine-iter-map(input_iters set_minus sub_iters)
+          b = some-quasi-affine-iter-map(sub_iters)
+          c is constant symbols
+          e is the extent of b
+    For example, z*12 + y*3 + x + c = (z*4+y)*3 + x
+                bindings = [z*12 + y*3 + x + c]
+                input_iters = [z, y, x]
+                sub_iter = [x]
+                Then the result will be [a, b] where
+                a = [z*4 + y]
+                b = [x]
+
+    Parameters
+    ----------
+    bindings : List[PrimExpr]
+        The input bindings
+
+    input_iters : Map[Var, Range]
+        The domain of input iterator, which is the basis of the whole space
+
+    sub_iters : Array[Var]
+        The subset of input_iters, which is the basis of the subspace
+
+    predicate : PrimExpr
+        The predicate constraints on the input iterators
+
+    require_bijective : bool
+        A boolean flag that indicates whether the bindings should be bijective
+
+    Returns
+    -------
+    results : List[List[PrimExpr]]

Review comment:
       What's the length of the outer list? It looks like it is always `len(bindings) + 1`; is that right?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org