You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/04/19 00:27:58 UTC

[GitHub] [tvm] masahi commented on a diff in pull request #11050: [TIR] Utility function to decide loop mapping for auto tensorization

masahi commented on code in PR #11050:
URL: https://github.com/apache/tvm/pull/11050#discussion_r852501884


##########
src/tir/schedule/analysis/analysis.cc:
##########
@@ -2028,5 +2034,107 @@ bool NeedsRFactorOrCrossThreadReduction(const tir::ScheduleState& self,   //
   }
 }
 
+TVM_REGISTER_NODE_TYPE(TensorizeInfoNode);
+
+Optional<TensorizeInfo> GetTensorizeLoopMapping(const tir::ScheduleState& self,
+                                                const tir::StmtSRef& block_sref,
+                                                const tir::PrimFunc& desc_func) {
+  arith::Analyzer analyzer;
+  const tir::BlockRealize& block = tir::GetBlockRealize(self, block_sref);
+  // Step 1. Analyze desc_func, extract its block, loops and loop vars
+  const tir::BlockRealizeNode* desc_block = nullptr;
+  std::vector<const tir::ForNode*> desc_loops;
+  std::unordered_set<const tir::VarNode*> desc_loop_vars;
+  const auto* desc_scope_realize = desc_func->body.as<tir::BlockRealizeNode>();
+  ICHECK(desc_scope_realize);
+  {
+    auto f_visit = [&desc_block, &desc_loops, &desc_loop_vars,
+                    &analyzer](const ObjectRef& obj) -> bool {
+      // Extract the block
+      if (const auto* block = obj.as<tir::BlockRealizeNode>()) {
+        desc_block = block;
+        return false;
+      }
+      // Extract loops
+      if (const auto* loop = obj.as<tir::ForNode>()) {
+        desc_loops.push_back(loop);
+        desc_loop_vars.insert(loop->loop_var.get());
+        if (!analyzer.CanProve(loop->min == 0)) {
+          return false;
+        }
+      }
+      return true;
+    };
+    tir::PostOrderVisit(desc_scope_realize->block->body, f_visit);
+    std::reverse(desc_loops.begin(), desc_loops.end());
+    ICHECK(desc_block);
+  }
+  // Step 2. Collect loops from block_sref
+  const tir::StmtSRef& scope_sref = GetScopeRoot(self, block_sref, false);
+  const tir::BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_block, scope_sref);
+  std::vector<const tir::ForNode*> block_loops;
+  std::unordered_set<const tir::VarNode*> block_loop_vars;
+  {
+    for (const tir::StmtSRefNode* loop_sref = block_sref->parent;; loop_sref = loop_sref->parent) {
+      const auto* loop = loop_sref->StmtAs<tir::ForNode>();
+      if (loop == nullptr || loop->body->IsInstance<tir::SeqStmtNode>()) {
+        break;
+      }
+      block_loops.push_back(loop);
+      block_loop_vars.insert(loop->loop_var.get());
+      if (!analyzer.CanProve(loop->min == 0)) {
+        return NullOpt;
+      }
+    }
+    std::reverse(block_loops.begin(), block_loops.end());
+  }
+  // Step 3. Map from block loops to desc block loops
+  ObjectPtr<TensorizeInfoNode> ret = make_object<TensorizeInfoNode>();
+  const int n_block_vars = block->iter_values.size();
+  const int n_desc_vars = desc_block->iter_values.size();
+  const int offset = n_block_vars - n_desc_vars;
+
+  if (offset < 0) {
+    return NullOpt;
+  }
+
+  const std::vector<IterVarType> iter_types_block = GetBlockVarTypes(block_sref);
+  const std::vector<IterVarType> iter_types_desc = GetBlockVarTypes(desc_block->block.get());
+
+  ICHECK(desc_loops.size() == static_cast<size_t>(n_desc_vars));
+  ICHECK(block_loops.size() == iter_types_block.size());
+
+  int next_block_ind = block_loops.size() - 1;
+  for (int i_desc = n_desc_vars - 1; i_desc >= 0; --i_desc) {
+    const tir::ForNode* desc_loop = desc_loops[i_desc];
+    const IntImmNode* int_desc_extent = desc_loop->extent.as<IntImmNode>();
+    if (!int_desc_extent) continue;
+
+    for (int i_block = next_block_ind; i_block >= 0; --i_block) {
+      const tir::ForNode* block_loop = block_loops[i_block];
+      const IntImmNode* int_block_extent = block_loop->extent.as<IntImmNode>();
+
+      if (!int_block_extent) continue;
+      if (int_block_extent->value % int_desc_extent->value != 0) continue;
+      if (iter_types_block[i_block] != iter_types_desc[i_desc]) continue;
+
+      const tir::StmtSRef& block_loop_sref = self->stmt2ref[block_loop];
+      ret->loop_map.Set(block_loop_sref, GetRef<tir::For>(desc_loop));
+      next_block_ind = i_block - 1;
+      break;
+    }
+  }

Review Comment:
   The logic here is very different from the logic in the original code at https://github.com/spectrometerHBH/tvm/blob/auto-tensorization/src/tir/schedule/analysis/analysis.cc#L1246. I could not understand why the original code was written that way, and it did not work when the matching loops in the target block are not in the innermost positions (e.g. conv2d NCHWc on CPU; see the test at https://github.com/apache/tvm/blob/d6ae84879d4eb7befc3fc07e0f967973f50ece16/tests/python/unittest/test_tir_schedule_analysis.py#L199).
   
   I think my change is simple and straightforward. A block loop matches a desc loop when (1) the block loop's extent is divisible by the desc loop's extent, and (2) both loops have the same iterator type (reduction vs. spatial). The mapping is determined starting from the innermost axis.
   
   Please review this change carefully, and let me know if I need to bring back any logic from the original code @spectrometerHBH @vinx13 



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org