Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/04/20 00:11:35 UTC

[GitHub] [tvm] vinx13 commented on a diff in pull request #11050: [TIR] Utility function to decide loop mapping for auto tensorization

vinx13 commented on code in PR #11050:
URL: https://github.com/apache/tvm/pull/11050#discussion_r853618665


##########
src/tir/schedule/analysis/analysis.cc:
##########
@@ -2028,5 +2035,162 @@ bool NeedsRFactorOrCrossThreadReduction(const tir::ScheduleState& self,   //
   }
 }
 
+TVM_REGISTER_NODE_TYPE(TensorizeInfoNode);
+
+Optional<TensorizeInfo> GetTensorizeLoopMapping(const tir::ScheduleState& self,
+                                                const tir::StmtSRef& block_sref,
+                                                const tir::PrimFunc& desc_func) {
+  arith::Analyzer analyzer;
+  const tir::BlockRealize& block = tir::GetBlockRealize(self, block_sref);
+  // Step 1. Analyze desc_func, extract its block, loops and loop vars
+  const tir::BlockRealizeNode* desc_block = nullptr;
+  std::vector<const tir::ForNode*> desc_loops;
+  std::unordered_set<const tir::VarNode*> desc_loop_vars;
+  const auto* desc_scope_realize = desc_func->body.as<tir::BlockRealizeNode>();
+  ICHECK(desc_scope_realize);
+  {
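+    // Walk the description body, recording its block realize and every enclosing loop (and loop var).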
+    auto f_visit = [&desc_block, &desc_loops, &desc_loop_vars,
+                    &analyzer](const ObjectRef& obj) -> bool {
+      // Extract the block
+      if (const auto* block = obj.as<tir::BlockRealizeNode>()) {
+        desc_block = block;
+        return false;
+      }
+      // Extract loops
+      if (const auto* loop = obj.as<tir::ForNode>()) {
+        desc_loops.push_back(loop);
+        desc_loop_vars.insert(loop->loop_var.get());
+        if (!analyzer.CanProve(loop->min == 0)) {
+          return false;
+        }
+      }
+      return true;
+    };
+    tir::PostOrderVisit(desc_scope_realize->block->body, f_visit);
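+    // After reversing, desc_loops runs from the outermost to the innermost loop, matching block_loops below.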
+    std::reverse(desc_loops.begin(), desc_loops.end());
+    ICHECK(desc_block);
+  }
+  // Step 2. Collect loops from block_sref
+  const tir::StmtSRef& scope_sref = GetScopeRoot(self, block_sref, false);
+  const tir::BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_block, scope_sref);
+  std::vector<const tir::ForNode*> block_loops;
+  std::unordered_set<const tir::VarNode*> block_loop_vars;
+  {
+    for (const tir::StmtSRefNode* loop_sref = block_sref->parent;; loop_sref = loop_sref->parent) {
+      const auto* loop = loop_sref->StmtAs<tir::ForNode>();
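+      // Stop at the first non-loop ancestor, or at a loop whose body is a SeqStmt (more than one statement).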
+      if (loop == nullptr || loop->body->IsInstance<tir::SeqStmtNode>()) {
+        break;
+      }
+      block_loops.push_back(loop);
+      block_loop_vars.insert(loop->loop_var.get());
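+      // Give up if the loop cannot be proven to start at 0.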
+      if (!analyzer.CanProve(loop->min == 0)) {
+        return NullOpt;
+      }
+    }
+    std::reverse(block_loops.begin(), block_loops.end());
+  }
+  // Step 3. Map from block loops to desc block loops
+  ObjectPtr<TensorizeInfoNode> ret = make_object<TensorizeInfoNode>();
+  const int n_block_vars = block->iter_values.size();
+  const int n_desc_vars = desc_block->iter_values.size();
+  const int offset = n_block_vars - n_desc_vars;
+
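+  // The description block must not have more block vars than the target block.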
+  if (offset < 0) {
+    return NullOpt;
+  }
+
+  const std::vector<IterVarType> iter_types_block = GetBlockVarTypes(block_sref);
+  const std::vector<IterVarType> iter_types_desc = GetBlockVarTypes(desc_block->block.get());
+
+  ICHECK(desc_loops.size() == static_cast<size_t>(n_desc_vars));
+  ICHECK(block_loops.size() == iter_types_block.size());
+
+  // We assume that the orders of iter_vars in the target and the desc block are consistent.

Review Comment:
   I agree this is a reasonable assumption. Though there might be corner cases, it covers all of the current use cases.
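
   For context, a minimal sketch of how the new utility could be consumed. Only the
   GetTensorizeLoopMapping signature and its Optional<TensorizeInfo> return are taken from
   the hunk above; the helper name below is hypothetical, and the snippet assumes the same
   headers and namespaces as analysis.cc:

       // Hypothetical helper: decide whether `block_sref` can be auto-tensorized with the
       // intrinsic whose computation is described by `desc_func`.
       bool CanAutoTensorize(const tir::ScheduleState& self, const tir::StmtSRef& block_sref,
                             const tir::PrimFunc& desc_func) {
         // NullOpt means no valid loop mapping exists, e.g. a loop that does not start at 0
         // or a block with fewer block vars than the description block.
         Optional<TensorizeInfo> info = GetTensorizeLoopMapping(self, block_sref, desc_func);
         return info.defined();
       }

   A NullOpt result simply means the block's loop nest cannot be matched onto the intrinsic
   description, so a caller can skip tensorization for that block.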


