You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/07/12 18:35:49 UTC

[GitHub] [tvm] vinx13 commented on a diff in pull request #12070: [TIR][Schedule] Refactor Tensorize

vinx13 commented on code in PR #12070:
URL: https://github.com/apache/tvm/pull/12070#discussion_r919237660


##########
src/tir/schedule/primitive/blockize_tensorize.cc:
##########
@@ -65,15 +79,15 @@ class SubspaceNotDivisibleError : public ScheduleError {
  * \param iter_vars The input iterators
  * \param bindings The values of iter_vars
  * \param outer_loops Iterators outside the subspace.
- * \param inner_loops Iterators of the subspace
+ * \param loops Iterators of the subspace

Review Comment:
   The parameter names in this doc comment are not consistent with the function signature.



##########
src/tir/schedule/primitive/blockize_tensorize.cc:
##########
@@ -122,531 +136,461 @@ Array<Array<arith::IterMark>> TrivialSubspaceDivision(const Array<IterVar>& iter
 }
 
 /*!
- * \brief Generate the blockized init block.
- * \param block The original block with init.
- * \param inner_block_realize The block realize of the inner block after blockize.
- * \param inner_loops The inner loops after blockize.
- * \return The subtree of the init block and its outer loops.
+ * \brief Subspace division. The space is divided into two subspaces:
+ *  1. The subspace represented by the outer loops above `loop_sref` (exclusive).
+ *  2. The subspace represented by the inner loops below `loop_sref` (inclusive).
+ * \param realize The inner block
+ * \param block_sref The sref to the inner block
+ * \param loop_sref The loop that is the root of the second subspace.
+ * \param loops The loops that represents the second part of the subspace.
+ * \param analyzer The arithmetic analyzer to use.
  */
-Stmt GenerateBlockizedInit(const Block& block, const BlockRealize& inner_block_realize,
-                           const std::vector<const ForNode*>& inner_loops) {
-  Array<IterVar> init_block_iters;
-  Array<PrimExpr> init_bindings;
-  const Block& inner_block = inner_block_realize->block;
-
-  // Step 1: Collect data-parallel block iters
-  for (size_t i = 0; i < inner_block->iter_vars.size(); i++) {
-    const IterVar& iter_var = inner_block->iter_vars[i];
-    const PrimExpr& binding = inner_block_realize->iter_values[i];
-    if (iter_var->iter_type == IterVarType::kDataPar &&
-        UsesVar(block->init.value(),
-                [tgt_var = iter_var->var.get()](const VarNode* var) { return var == tgt_var; })) {
-      init_block_iters.push_back(iter_var);
-      init_bindings.push_back(binding);
+Array<Array<arith::IterMark>> SubspaceDivide(const BlockRealize& realize,
+                                             const StmtSRef& block_sref,  //
+                                             const StmtSRef& loop_sref,   //
+                                             std::vector<const ForNode*>* loops,
+                                             arith::Analyzer* analyzer) {
+  Array<Var> inner_vars;
+  Array<Var> outer_vars;
+  Map<Var, Range> loop_var_domain;
+  bool inner = true;
+  for (StmtSRefNode* sref = block_sref->parent;    //
+       sref && sref->stmt->IsInstance<ForNode>();  //
+       sref = sref->parent) {
+    const ForNode* loop = static_cast<const ForNode*>(sref->stmt);
+    if (inner) {
+      loops->push_back(loop);
+      inner_vars.push_back(loop->loop_var);
+    } else {
+      outer_vars.push_back(loop->loop_var);
     }
-  }
-
-  // Step 2: Collect loops related to iters of the init block
-  std::vector<const ForNode*> init_loops;
-  for (const ForNode* inner_loop : inner_loops) {
-    for (const PrimExpr& init_binding : init_bindings) {
-      if (UsesVar(init_binding, [tgt_var = inner_loop->loop_var.get()](const VarNode* var) {
-            return var == tgt_var;
-          })) {
-        init_loops.push_back(inner_loop);
-        break;
-      }
+    loop_var_domain.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent));
+    if (sref == loop_sref.get()) {
+      inner = false;
     }
   }
-
-  // Step 3: Create new block iters for the init block
-  Map<Var, PrimExpr> subst_map;
-  for (size_t i = 0; i < init_block_iters.size(); i++) {
-    IterVar new_iter_var = init_block_iters[i];
-    Var old_var = new_iter_var->var;
-    Var new_var = old_var.copy_with_suffix("_init");
-    new_iter_var.CopyOnWrite()->var = new_var;
-    subst_map.Set(old_var, new_var);
-    init_block_iters.Set(i, std::move(new_iter_var));
-  }
-
-  // Step 4: Generate loop nests and the init block
-  Stmt new_init = BlockRealize(
-      /*iter_values=*/init_bindings,
-      /*predicate=*/inner_block_realize->predicate,
-      /*block=*/
-      Block{/*iter_vars=*/init_block_iters,
-            /*reads=*/{},
-            /*writes=*/block->writes,
-            /*name_hint=*/block->name_hint + "_init",
-            /*body=*/block->init.value(),
-            /*init=*/NullOpt});
-
-  // Step 5: Generate the parent loops for the init block
-  for (const ForNode* init_loop : init_loops) {
-    ObjectPtr<ForNode> new_loop = make_object<ForNode>(*init_loop);
-    new_loop->loop_var = init_loop->loop_var.copy_with_suffix("");
-    subst_map.Set(init_loop->loop_var, new_loop->loop_var);
-    new_loop->body = std::move(new_init);
-    new_init = For(new_loop);
+  Array<Array<arith::IterMark>> result =
+      arith::SubspaceDivide(realize->iter_values, loop_var_domain, inner_vars, realize->predicate,
+                            arith::IterMapLevel::Surjective, analyzer);
+  if (!result.empty()) {
+    return result;
   }
-
-  // Step 6: Substitute with new loop variables and block iters to prevent duplication of
-  // variables in the outer block.
-  new_init = Substitute(new_init, subst_map);
-
-  return new_init;
+  return TrivialSubspaceDivision(realize->block->iter_vars,
+                                 realize->iter_values,  //
+                                 realize->predicate,    //
+                                 outer_vars, inner_vars);
 }
 
 /*!
- * \brief A helper to collect the parent loops of the block. The loops are divided into two groups,
- * 'outer_loops', and 'inner_loops', by a specified loop as the separator. 'outer_loops' are the
- * ancestor loops of the separator loop. 'inner_loops' include the separator loop itself, and its
- * successor loops. It is possible that 'outer_loops' is empty.
+ * \brief Derive the block bindings for both inner and outer block
+ * \param iter_vars The original block iterators to the inner block
+ * \param division The subspace division.
+ * \param outer_iter_vars The outer block iterators.
+ * \param outer_bindings The outer block bindings.
+ * \param inner_iter_vars The inner block iterators.
+ * \param inner_bindings The inner block bindings.
+ * \return A substitution plan to the iterators in the original inner block.
  */
-class LoopSubspaceCollector {
- public:
-  /*!
-   * \brief Collect the parent loops of the block and store the result in the corresponding fields.
-   * \param block_sref The sref to the target block.
-   * \param loop_sref The sref to the separator loop. The loop itself is counted as an inner loop.
-   */
-  void Collect(const StmtSRef& block_sref, const StmtSRef& loop_sref) {
-    bool inner = true;
-    for (StmtSRefNode* current_sref = block_sref->parent;
-         current_sref && current_sref->stmt->IsInstance<ForNode>();
-         current_sref = current_sref->parent) {
-      const auto* current_loop = current_sref->StmtAs<ForNode>();
-      ICHECK(current_loop);
-      if (inner) {
-        inner_loops.push_back(current_loop);
-        inner_loop_vars.push_back(current_loop->loop_var);
-      } else {
-        outer_loops.push_back(current_loop);
-        outer_loop_vars.push_back(current_loop->loop_var);
-      }
-      loop_var_domain.Set(current_loop->loop_var,
-                          Range::FromMinExtent(current_loop->min, current_loop->extent));
-      if (current_sref == loop_sref.get()) inner = false;
+Map<Var, PrimExpr> DeriveBlockBinding(const Array<IterVar>& iter_vars,                //
+                                      const Array<Array<arith::IterMark>>& division,  //
+                                      Array<IterVar>* outer_iter_vars,                //
+                                      Array<PrimExpr>* outer_bindings,                //
+                                      Array<IterVar>* inner_iter_vars,                //
+                                      Array<PrimExpr>* inner_bindings) {
+  using arith::IterMapExpr;
+  using arith::IterMapExprNode;
+  using arith::NormalizeIterMapToExpr;
+  Map<Var, PrimExpr> block_var_subst;
+  ICHECK_EQ(iter_vars.size() + 1, division.size());
+  for (int i = 0, n = iter_vars.size(); i < n; ++i) {
+    const IterVar& iter_var = iter_vars[i];
+    arith::IterMark outer_mark = division[i][0];
+    arith::IterMark inner_mark = division[i][1];
+    IterMapExpr outer_binding = Downcast<IterMapExpr>(outer_mark->source);
+    IterMapExpr inner_binding = Downcast<IterMapExpr>(inner_mark->source);
+    // After computing the subspace division, bindings[i] can be written as
+    // outer_binding * inner_binding->extent + inner_binding
+    // The outer block will have binding: iter_outer -> outer_binding
+    // The inner block will have binding: iter_inner -> inner_binding
+    // The iter in the original block will be substituted with base + iter_inner where
+    // base == iter_outer * iter_inner_extent
+    if (is_one(inner_mark->extent)) {  // IsOuter
+      // extract this iter var to outer block directly
+      outer_bindings->push_back(NormalizeIterMapToExpr(outer_binding));
+      outer_iter_vars->push_back(iter_var);
+      continue;
     }
+    // create iter var for the outer block
+    IterVar outer_iter(/*dom=*/RangeFromExtent(outer_mark->extent),
+                       /*var=*/iter_var->var.copy_with_suffix("_o"),
+                       /*iter_type=*/iter_var->iter_type);
+    outer_bindings->push_back(NormalizeIterMapToExpr(outer_binding));
+    outer_iter_vars->push_back(outer_iter);
+    // create iter var for the inner block
+    IterVar inner_iter(/*dom=*/RangeFromExtent(inner_mark->extent),
+                       /*var=*/iter_var->var.copy_with_suffix("_i"),
+                       /*iter_type=*/iter_var->iter_type);
+    inner_bindings->push_back(NormalizeIterMapToExpr(inner_binding));
+    inner_iter_vars->push_back(inner_iter);
+    // substitution
+    PrimExpr sub{nullptr};
+    if (is_one(outer_mark->extent)) {
+      sub = inner_iter->var;
+    } else {
+      sub = outer_iter * inner_mark->extent + inner_iter->var;
+    }
+    block_var_subst.Set(iter_var->var, sub);
   }
-  /*! \brief Outer loops which are ancestors of the separator. */
-  std::vector<const ForNode*> outer_loops;
-  /*! \brief Inner loops which are the separator itself or its successors. */
-  std::vector<const ForNode*> inner_loops;
-  /*! \brief Loop variables of the outer loops. */
-  Array<Var> outer_loop_vars;
-  /*! \brief Loop variables of the inner loops. */
-  Array<Var> inner_loop_vars;
-  /*! \brief Domain of the loop variables. */
-  Map<Var, Range> loop_var_domain;
-};
+  return block_var_subst;
+}
 
 /*!
- * \brief Check the bindings of the block iters can be divided by a subspace collected by the
- * collector.
- * \param mod The current IR module.
- * \param block_realize The block realize to be checked.
- * \param collector The collector which has collected the loops of the block.
- * \param analyzer The arithmetic analyzer.
- * \return The result of the subspace division.
- * \throws ScheduleError If the bindings are not divisible by the subspace.
+ * \brief Generate the inner block for blockization
+ * \param iter_vars IterVars used in the inner block.

Review Comment:
   Please also document the `is_write_reduction` parameter in this doc comment.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org