You are viewing a plain-text version of this content. (The link to the canonical version was not preserved in this copy.)
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/06/16 20:58:15 UTC

[GitHub] [tvm] vinx13 commented on a diff in pull request #11740: [TIR, analysis] Add GetAutoTensorizeMappingInfo to generate transforms for auto tensorization

vinx13 commented on code in PR #11740:
URL: https://github.com/apache/tvm/pull/11740#discussion_r899516179


##########
src/tir/schedule/ir_comparator.cc:
##########
@@ -355,5 +355,135 @@ void TensorizeComparator::EmitError(const std::string& error_message) {
   error_messages_.push_back(error_message);
 }
 
+/******** AutoTensorize Extractor ********/
+
+// Fallback for expression node types that have no dedicated VisitExpr_ overload in this
+// extractor (presumably reached through the expression functor's default dispatch — confirm).
+// Any such node is treated as a mismatch, so unsupported expressions never match.
+bool AutoTensorizeExtractor::VisitExprDefault_(const Object* op, const PrimExpr& other) {
+  return false;
+}
+
+// Fallback for statement node types that have no dedicated VisitStmt_ overload in this
+// extractor (presumably reached through the statement functor's default dispatch — confirm).
+// Any such node is treated as a mismatch, so unsupported statements never match.
+bool AutoTensorizeExtractor::VisitStmtDefault_(const Object* op, const Stmt& other) {
+  return false;
+}
+
+// Element-wise comparison of two arrays using `cmp`, a pointer to a member function of this
+// extractor that is invoked on `this` for each pair lhs[i], rhs[i].
+// Arrays backed by the same object compare equal immediately; arrays of different length
+// never compare equal; otherwise the first failing element pair makes the result false.
+template <typename T, typename F>
+bool AutoTensorizeExtractor::CompareArray(const Array<T>& lhs, const Array<T>& rhs, F cmp) {
+  if (lhs.same_as(rhs)) return true;
+  if (lhs.size() != rhs.size()) return false;
+  for (size_t i = 0; i < lhs.size(); ++i) {
+    if (!(this->*cmp)(lhs[i], rhs[i])) return false;
+  }
+  return true;
+}
+
+// Matches the lhs block `op` against the rhs block `other`, then recurses into both bodies.
+// The first block visited (while is_scope_block is set) is handled as the scope block: its
+// iter vars are bound and collected rather than matched one-to-one. Every later block goes
+// through the strict comparison branch.
+// NOTE(review): `rhs` is dereferenced without a null check, so this assumes the caller has
+// already verified that `other` is a BlockNode — confirm against the dispatching VisitStmt.
+bool AutoTensorizeExtractor::VisitStmt_(const BlockNode* op, const Stmt& other) {
+  const auto* rhs = other.as<BlockNode>();
+  // Check block equality.
+  // All iter vars and buffer regions including the order should match.
+  // When checking iter vars, DefEqual is used to remap variables.
+  // NOTE(review): the strict branch below compares iter_vars, annotations and alloc_buffers
+  // only. The reads/writes regions, match_buffers and init are not compared here — confirm
+  // this is intended.
+  if (!is_scope_block) {
+    if (!CompareArray(op->iter_vars, rhs->iter_vars, &AutoTensorizeExtractor::CompareIterVar)) {
+      return false;
+    }
+    if (!CompareAnnotationMap(op->annotations, rhs->annotations)) {
+      return false;
+    }
+    if (!CompareArray(op->alloc_buffers, rhs->alloc_buffers,
+                      &AutoTensorizeExtractor::CompareBuffer)) {
+      return false;
+    }
+    // Record the domain of each lhs block iter var in inner_iter_dom_map_.
+    for (const IterVar& block_iter : op->iter_vars) {
+      inner_iter_dom_map_.Set(block_iter->var, arith::IntSet::FromRange(block_iter->dom));
+    }
+  } else {
+    // Scope block: bind every iter var's domain in the analyzer and append the iters of each
+    // side to lhs_iters_ / rhs_iters_. Any iter that is neither data-parallel nor a
+    // commutative reduction makes the whole match fail.
+    // NOTE(review): the lambda parameter `op` shadows the outer `op`. Bind() runs before the
+    // iter-type check, and lhs_iters_ keeps its entries if collecting the rhs iters fails.
+    auto collect_iter = [&](const BlockNode* op, std::vector<IterVar>& iters) -> bool {
+      for (const auto& iter : op->iter_vars) {
+        analyzer_.Bind(iter->var, iter->dom);
+        if (iter->iter_type == IterVarType::kDataPar ||
+            iter->iter_type == IterVarType::kCommReduce) {
+          iters.push_back(iter);
+        } else {
+          return false;
+        }
+      }
+      return true;
+    };
+    if (!collect_iter(op, lhs_iters_)) {
+      return false;
+    }
+    if (!collect_iter(rhs, rhs_iters_)) {
+      return false;
+    }
+  }
+  // Cleared unconditionally, so any block nested in this body takes the strict branch above.
+  is_scope_block = false;
+  return VisitStmt(op->body, rhs->body);
+}
+
+// Returns whether buffer `lhs` can be matched with buffer `rhs`.
+// If `rhs` was matched before, it must map to this same `lhs`. Otherwise the buffers match
+// when their data vars can be remapped via DefEqual and their dtypes are equal. On success
+// the pairing is recorded in both rhs_buffer_map_ and lhs_buffer_map_. Shape and scope are
+// intentionally not compared.
+// NOTE(review): identical buffers (same_as) return true without recording any mapping, and
+// lhs_buffer_map_ is written but not consulted in this function.
+bool AutoTensorizeExtractor::CompareBuffer(const Buffer& lhs, const Buffer& rhs) {
+  if (lhs.same_as(rhs)) return true;
+  auto it = rhs_buffer_map_.find(rhs);
+  bool equal;
+  if (it != rhs_buffer_map_.end()) {
+    // rhs already paired: it must be paired with this exact lhs.
+    equal = (*it).second.same_as(lhs);
+  } else {
+    // Remap both buffer itself and buffer data, skip buffer shape and scope
+    equal = DefEqual(lhs->data, rhs->data) && lhs->dtype == rhs->dtype;
+    if (equal) {
+      rhs_buffer_map_[rhs] = lhs;
+      lhs_buffer_map_[lhs] = rhs;
+    }
+  }
+  return equal;
+}
+
+// A buffer store matches when its buffer access matches (via the template helper
+// CompareBufferAccess, defined elsewhere) and the stored values match recursively.
+// NOTE(review): `rhs` is dereferenced without a null check; this assumes `other` is a
+// BufferStoreNode — confirm against the dispatching VisitStmt.
+bool AutoTensorizeExtractor::VisitStmt_(const BufferStoreNode* op, const Stmt& other) {
+  const auto* rhs = other.as<BufferStoreNode>();
+  return CompareBufferAccess(op, rhs) && VisitExpr(op->value, rhs->value);
+}
+
+// A buffer load matches when its buffer access matches, as decided by the template helper
+// CompareBufferAccess (defined elsewhere).
+// NOTE(review): `rhs` is passed on without a null check; this assumes `other` is a
+// BufferLoadNode — confirm against the dispatching VisitExpr.
+bool AutoTensorizeExtractor::VisitExpr_(const BufferLoadNode* op, const PrimExpr& other) {
+  const auto* rhs = other.as<BufferLoadNode>();
+  return CompareBufferAccess(op, rhs);
+}

Review Comment:
   Yes. They call the template function `CompareBufferAccess`, which can't be virtual, so I have to duplicate these two overrides.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org