You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/04/19 02:52:23 UTC

[GitHub] [tvm] junrushao1994 commented on a diff in pull request #11050: [TIR] Utility function to decide loop mapping for auto tensorization

junrushao1994 commented on code in PR #11050:
URL: https://github.com/apache/tvm/pull/11050#discussion_r852550848


##########
src/tir/schedule/analysis/analysis.cc:
##########
@@ -2028,5 +2034,107 @@ bool NeedsRFactorOrCrossThreadReduction(const tir::ScheduleState& self,   //
   }
 }
 
+TVM_REGISTER_NODE_TYPE(TensorizeInfoNode);
+
+Optional<TensorizeInfo> GetTensorizeLoopMapping(const tir::ScheduleState& self,
+                                                const tir::StmtSRef& block_sref,
+                                                const tir::PrimFunc& desc_func) {
+  arith::Analyzer analyzer;
+  const tir::BlockRealize& block = tir::GetBlockRealize(self, block_sref);
+  // Step 1. Analyze desc_func, extract its block, loops and loop vars
+  const tir::BlockRealizeNode* desc_block = nullptr;
+  std::vector<const tir::ForNode*> desc_loops;
+  std::unordered_set<const tir::VarNode*> desc_loop_vars;
+  const auto* desc_scope_realize = desc_func->body.as<tir::BlockRealizeNode>();
+  ICHECK(desc_scope_realize);
+  {
+    auto f_visit = [&desc_block, &desc_loops, &desc_loop_vars,
+                    &analyzer](const ObjectRef& obj) -> bool {
+      // Extract the block
+      if (const auto* block = obj.as<tir::BlockRealizeNode>()) {
+        desc_block = block;
+        return false;
+      }
+      // Extract loops
+      if (const auto* loop = obj.as<tir::ForNode>()) {
+        desc_loops.push_back(loop);
+        desc_loop_vars.insert(loop->loop_var.get());
+        if (!analyzer.CanProve(loop->min == 0)) {
+          return false;
+        }
+      }
+      return true;
+    };
+    tir::PostOrderVisit(desc_scope_realize->block->body, f_visit);
+    std::reverse(desc_loops.begin(), desc_loops.end());
+    ICHECK(desc_block);
+  }
+  // Step 2. Collect loops from block_sref
+  const tir::StmtSRef& scope_sref = GetScopeRoot(self, block_sref, false);
+  const tir::BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_block, scope_sref);
+  std::vector<const tir::ForNode*> block_loops;
+  std::unordered_set<const tir::VarNode*> block_loop_vars;
+  {
+    for (const tir::StmtSRefNode* loop_sref = block_sref->parent;; loop_sref = loop_sref->parent) {
+      const auto* loop = loop_sref->StmtAs<tir::ForNode>();
+      if (loop == nullptr || loop->body->IsInstance<tir::SeqStmtNode>()) {
+        break;
+      }
+      block_loops.push_back(loop);
+      block_loop_vars.insert(loop->loop_var.get());
+      if (!analyzer.CanProve(loop->min == 0)) {
+        return NullOpt;
+      }
+    }
+    std::reverse(block_loops.begin(), block_loops.end());
+  }
+  // Step 3. Map from block loops to desc block loops
+  ObjectPtr<TensorizeInfoNode> ret = make_object<TensorizeInfoNode>();
+  const int n_block_vars = block->iter_values.size();
+  const int n_desc_vars = desc_block->iter_values.size();
+  const int offset = n_block_vars - n_desc_vars;
+
+  if (offset < 0) {
+    return NullOpt;
+  }
+
+  const std::vector<IterVarType> iter_types_block = GetBlockVarTypes(block_sref);
+  const std::vector<IterVarType> iter_types_desc = GetBlockVarTypes(desc_block->block.get());
+
+  ICHECK(desc_loops.size() == static_cast<size_t>(n_desc_vars));
+  ICHECK(block_loops.size() == iter_types_block.size());
+
+  int next_block_ind = block_loops.size() - 1;
+  for (int i_desc = n_desc_vars - 1; i_desc >= 0; --i_desc) {
+    const tir::ForNode* desc_loop = desc_loops[i_desc];
+    const IntImmNode* int_desc_extent = desc_loop->extent.as<IntImmNode>();
+    if (!int_desc_extent) continue;
+
+    for (int i_block = next_block_ind; i_block >= 0; --i_block) {
+      const tir::ForNode* block_loop = block_loops[i_block];
+      const IntImmNode* int_block_extent = block_loop->extent.as<IntImmNode>();
+
+      if (!int_block_extent) continue;
+      if (int_block_extent->value % int_desc_extent->value != 0) continue;
+      if (iter_types_block[i_block] != iter_types_desc[i_desc]) continue;
+
+      const tir::StmtSRef& block_loop_sref = self->stmt2ref[block_loop];
+      ret->loop_map.Set(block_loop_sref, GetRef<tir::For>(desc_loop));
+      next_block_ind = i_block - 1;
+      break;
+    }
+  }

Review Comment:
   I would love to have @spectrometerHBH review this change before merging.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org