You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/07/21 03:49:25 UTC

[GitHub] [incubator-tvm] merrymercy opened a new pull request #6103: [Ansor][AutoTVM v2.0] Phase 1: Access Analyzer

merrymercy opened a new pull request #6103:
URL: https://github.com/apache/incubator-tvm/pull/6103


   For full upstream plan, see [Ansor RFC](https://discuss.tvm.ai/t/rfc-ansor-an-auto-scheduler-for-tvm-autotvm-v2-0/7005/32).
   
   In this PR, we introduce the access analyzer which will analyze the read-write relations in a compute declaration.
   The search policy will use the analysis results to decide whether to do multi-level tiling or inline an op.
   


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] merrymercy commented on a change in pull request #6103: [Ansor][AutoTVM v2.0] Phase 1: Access Analyzer

Posted by GitBox <gi...@apache.org>.
merrymercy commented on a change in pull request #6103:
URL: https://github.com/apache/incubator-tvm/pull/6103#discussion_r459749490



##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -114,7 +118,432 @@ Array<te::Operation> TopoSortOps(const Array<te::Tensor>& tensors) {
   return ops;
 }
 
-// Estimate number of float operations in an expression
+// Extract all tensor accesses in an expr
+class TensorAccessExtractor : public StmtExprVisitor {
+ public:
+  void Extract(PrimExpr expr) { this->VisitExpr(expr); }
+
+  void VisitExpr_(const CallNode* op) final {
+    if (op->op.same_as(builtin::if_then_else())) {
+      has_branch = true;
+    }
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitExpr_(const ProducerLoadNode* op) final {
+    buf_accesses[Downcast<te::Tensor>(op->producer)->op].emplace_back(op->indices.begin(),
+                                                                      op->indices.end());
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitStmt_(const IfThenElseNode* op) final {
+    has_branch = true;
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitExpr_(const SelectNode* op) final {
+    has_branch = true;
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  OperationMap<std::vector<std::vector<PrimExpr>>> buf_accesses;
+  bool has_branch{false};
+};
+
+// Returns whether the expr equals to the var with a const shift
+bool IsConstShiftEqual(const Var& var, const PrimExpr& expr) {
+  if (auto pv = expr.as<VarNode>()) {
+    return pv == var.get();
+  } else if (auto padd = expr.as<AddNode>()) {
+    return ((padd->a.get() == var.get() && padd->b->IsInstance<IntImmNode>()) ||
+            (padd->b.get() == var.get() && padd->a->IsInstance<IntImmNode>()));
+  } else if (auto psub = expr.as<SubNode>()) {
+    return ((psub->a.get() == var.get() && psub->b->IsInstance<IntImmNode>()) ||
+            (psub->b.get() == var.get() && psub->a->IsInstance<IntImmNode>()));
+  } else {
+    return false;
+  }
+}
+
+// Return whether the access is injective
+bool IsInjective(const te::Operation& op, const std::vector<PrimExpr>& index, bool* axis_missing,
+                 bool* axis_duplicated, bool* same_order) {
+  auto cop = op.as<te::ComputeOpNode>();
+  if (cop == nullptr) {
+    return false;
+  }
+
+  std::vector<int> index_to_var_idx;
+  std::vector<int> var_idx_ct(cop->axis.size(), 0);
+
+  for (const auto& expr : index) {
+    if (!is_const_int(expr)) {
+      bool found = false;
+      for (size_t i = 0; i < cop->axis.size(); ++i) {
+        if (IsConstShiftEqual(cop->axis[i]->var, expr)) {
+          index_to_var_idx.push_back(i);
+          var_idx_ct[i]++;
+          found = true;
+          break;
+        }
+      }
+      if (!found) {
+        return false;
+      }
+    }
+  }
+
+  *axis_missing = false;     // Some axes are missing
+  *axis_duplicated = false;  // Some axes appear more than once
+  *same_order = true;        // The axis order is the same as op->axis
+  for (int ct : var_idx_ct) {
+    if (ct == 0) {
+      *axis_missing = true;
+    } else if (ct > 1) {
+      *axis_duplicated = true;
+    }
+  }
+  for (size_t i = 1; i < index_to_var_idx.size(); ++i) {
+    if (index_to_var_idx[i] < index_to_var_idx[i - 1]) {
+      *same_order = false;
+      break;
+    }
+  }
+
+  return true;
+}
+
+// Gather all VarNodes in an expr
+static void GatherVars(const PrimExpr& expr, std::unordered_set<const VarNode*>* vars) {
+  PostOrderVisit(expr, [&vars](const ObjectRef& node) {
+    if (const VarNode* op = node.as<VarNode>()) {
+      vars->insert(op);
+    }
+  });
+}
+
+// Check whether an expr has expensive operations (e.g. exp)
+static bool HasExpensiveOp(const PrimExpr& expr) {
+  bool found = false;
+  PostOrderVisit(expr, [&found](const ObjectRef& node) {
+    if (const CallNode* op = node.as<CallNode>()) {
+      if (op->op.as<OpNode>()->name == "tir.exp") {
+        found = true;
+      }
+    }
+  });
+  return found;
+}
+
+AccessAnalyzer::AccessAnalyzer(const Array<te::Tensor>& tensors) {
+  auto node = make_object<AccessAnalyzerNode>();
+  OperationMap<bool> has_branch;
+
+  // get all ops
+  node->ops_topo_order = TopoSortOps(tensors);
+
+  arith::Analyzer analyzer;
+
+  // build read & write access map
+  for (const auto& op : node->ops_topo_order) {
+    if (op->IsInstance<te::PlaceholderOpNode>()) {
+      node->read_from[op] = OperationMap<std::vector<std::vector<PrimExpr>>>();
+    } else if (auto cop = op.as<te::ComputeOpNode>()) {
+      TensorAccessExtractor extractor;
+      for (const auto& exp : cop->body) {

Review comment:
       No. It will be addressed later.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] junrushao1994 commented on a change in pull request #6103: [Ansor][AutoTVM v2.0] Phase 1: Access Analyzer

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on a change in pull request #6103:
URL: https://github.com/apache/incubator-tvm/pull/6103#discussion_r458573185



##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -114,7 +118,432 @@ Array<te::Operation> TopoSortOps(const Array<te::Tensor>& tensors) {
   return ops;
 }
 
-// Estimate number of float operations in an expression
+// Extract all tensor accesses in an expr
+class TensorAccessExtractor : public StmtExprVisitor {

Review comment:
       The extractor basically does two things:
   1) Check if there are branches
   2) For each op, figure out where it is read and save as a list of multi-dimensional indices
   
   So I think the class name might be misleading, because writing a tensor is not counted. Let's find a better name.

##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -114,7 +118,432 @@ Array<te::Operation> TopoSortOps(const Array<te::Tensor>& tensors) {
   return ops;
 }
 
-// Estimate number of float operations in an expression
+// Extract all tensor accesses in an expr
+class TensorAccessExtractor : public StmtExprVisitor {
+ public:
+  void Extract(PrimExpr expr) { this->VisitExpr(expr); }
+
+  void VisitExpr_(const CallNode* op) final {
+    if (op->op.same_as(builtin::if_then_else())) {
+      has_branch = true;
+    }
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitExpr_(const ProducerLoadNode* op) final {
+    buf_accesses[Downcast<te::Tensor>(op->producer)->op].emplace_back(op->indices.begin(),
+                                                                      op->indices.end());
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitStmt_(const IfThenElseNode* op) final {
+    has_branch = true;
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitExpr_(const SelectNode* op) final {
+    has_branch = true;
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  OperationMap<std::vector<std::vector<PrimExpr>>> buf_accesses;
+  bool has_branch{false};
+};
+
+// Returns whether the expr equals to the var with a const shift
+bool IsConstShiftEqual(const Var& var, const PrimExpr& expr) {
+  if (auto pv = expr.as<VarNode>()) {
+    return pv == var.get();
+  } else if (auto padd = expr.as<AddNode>()) {
+    return ((padd->a.get() == var.get() && padd->b->IsInstance<IntImmNode>()) ||
+            (padd->b.get() == var.get() && padd->a->IsInstance<IntImmNode>()));
+  } else if (auto psub = expr.as<SubNode>()) {
+    return ((psub->a.get() == var.get() && psub->b->IsInstance<IntImmNode>()) ||
+            (psub->b.get() == var.get() && psub->a->IsInstance<IntImmNode>()));
+  } else {
+    return false;
+  }
+}
+
+// Return whether the access is injective
+bool IsInjective(const te::Operation& op, const std::vector<PrimExpr>& index, bool* axis_missing,
+                 bool* axis_duplicated, bool* same_order) {
+  auto cop = op.as<te::ComputeOpNode>();
+  if (cop == nullptr) {
+    return false;
+  }
+
+  std::vector<int> index_to_var_idx;
+  std::vector<int> var_idx_ct(cop->axis.size(), 0);
+
+  for (const auto& expr : index) {
+    if (!is_const_int(expr)) {
+      bool found = false;
+      for (size_t i = 0; i < cop->axis.size(); ++i) {
+        if (IsConstShiftEqual(cop->axis[i]->var, expr)) {
+          index_to_var_idx.push_back(i);
+          var_idx_ct[i]++;
+          found = true;
+          break;
+        }
+      }
+      if (!found) {
+        return false;
+      }
+    }
+  }
+
+  *axis_missing = false;     // Some axes are missing
+  *axis_duplicated = false;  // Some axes appear more than once
+  *same_order = true;        // The axis order is the same as op->axis
+  for (int ct : var_idx_ct) {
+    if (ct == 0) {
+      *axis_missing = true;
+    } else if (ct > 1) {
+      *axis_duplicated = true;
+    }
+  }
+  for (size_t i = 1; i < index_to_var_idx.size(); ++i) {
+    if (index_to_var_idx[i] < index_to_var_idx[i - 1]) {
+      *same_order = false;
+      break;
+    }
+  }
+
+  return true;
+}
+
+// Gather all VarNodes in an expr
+static void GatherVars(const PrimExpr& expr, std::unordered_set<const VarNode*>* vars) {

Review comment:
       It doesn't have to be static (it might interfere with backtrace in error reporting)
   
   ```suggestion
   std::unordered_set<const VarNode*> ExprGatherVars(const PrimExpr& expr) {
   ```

##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -114,7 +118,432 @@ Array<te::Operation> TopoSortOps(const Array<te::Tensor>& tensors) {
   return ops;
 }
 
-// Estimate number of float operations in an expression
+// Extract all tensor accesses in an expr
+class TensorAccessExtractor : public StmtExprVisitor {
+ public:
+  void Extract(PrimExpr expr) { this->VisitExpr(expr); }
+
+  void VisitExpr_(const CallNode* op) final {
+    if (op->op.same_as(builtin::if_then_else())) {
+      has_branch = true;
+    }
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitExpr_(const ProducerLoadNode* op) final {
+    buf_accesses[Downcast<te::Tensor>(op->producer)->op].emplace_back(op->indices.begin(),
+                                                                      op->indices.end());
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitStmt_(const IfThenElseNode* op) final {
+    has_branch = true;
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitExpr_(const SelectNode* op) final {
+    has_branch = true;
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  OperationMap<std::vector<std::vector<PrimExpr>>> buf_accesses;
+  bool has_branch{false};
+};
+
+// Returns whether the expr equals to the var with a const shift
+bool IsConstShiftEqual(const Var& var, const PrimExpr& expr) {
+  if (auto pv = expr.as<VarNode>()) {
+    return pv == var.get();
+  } else if (auto padd = expr.as<AddNode>()) {
+    return ((padd->a.get() == var.get() && padd->b->IsInstance<IntImmNode>()) ||
+            (padd->b.get() == var.get() && padd->a->IsInstance<IntImmNode>()));
+  } else if (auto psub = expr.as<SubNode>()) {
+    return ((psub->a.get() == var.get() && psub->b->IsInstance<IntImmNode>()) ||
+            (psub->b.get() == var.get() && psub->a->IsInstance<IntImmNode>()));
+  } else {
+    return false;
+  }
+}
+
+// Return whether the access is injective
+bool IsInjective(const te::Operation& op, const std::vector<PrimExpr>& index, bool* axis_missing,
+                 bool* axis_duplicated, bool* same_order) {
+  auto cop = op.as<te::ComputeOpNode>();
+  if (cop == nullptr) {
+    return false;
+  }
+
+  std::vector<int> index_to_var_idx;
+  std::vector<int> var_idx_ct(cop->axis.size(), 0);
+
+  for (const auto& expr : index) {
+    if (!is_const_int(expr)) {
+      bool found = false;
+      for (size_t i = 0; i < cop->axis.size(); ++i) {
+        if (IsConstShiftEqual(cop->axis[i]->var, expr)) {
+          index_to_var_idx.push_back(i);
+          var_idx_ct[i]++;
+          found = true;
+          break;
+        }
+      }
+      if (!found) {
+        return false;
+      }
+    }
+  }
+
+  *axis_missing = false;     // Some axes are missing
+  *axis_duplicated = false;  // Some axes appear more than once
+  *same_order = true;        // The axis order is the same as op->axis
+  for (int ct : var_idx_ct) {
+    if (ct == 0) {
+      *axis_missing = true;
+    } else if (ct > 1) {
+      *axis_duplicated = true;
+    }
+  }
+  for (size_t i = 1; i < index_to_var_idx.size(); ++i) {
+    if (index_to_var_idx[i] < index_to_var_idx[i - 1]) {
+      *same_order = false;
+      break;
+    }
+  }
+
+  return true;
+}
+
+// Gather all VarNodes in an expr
+static void GatherVars(const PrimExpr& expr, std::unordered_set<const VarNode*>* vars) {
+  PostOrderVisit(expr, [&vars](const ObjectRef& node) {
+    if (const VarNode* op = node.as<VarNode>()) {
+      vars->insert(op);
+    }
+  });
+}
+
+// Check whether an expr has expensive operations (e.g. exp)
+static bool HasExpensiveOp(const PrimExpr& expr) {

Review comment:
       ditto

##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -114,7 +118,432 @@ Array<te::Operation> TopoSortOps(const Array<te::Tensor>& tensors) {
   return ops;
 }
 
-// Estimate number of float operations in an expression
+// Extract all tensor accesses in an expr
+class TensorAccessExtractor : public StmtExprVisitor {
+ public:
+  void Extract(PrimExpr expr) { this->VisitExpr(expr); }
+
+  void VisitExpr_(const CallNode* op) final {
+    if (op->op.same_as(builtin::if_then_else())) {
+      has_branch = true;
+    }
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitExpr_(const ProducerLoadNode* op) final {
+    buf_accesses[Downcast<te::Tensor>(op->producer)->op].emplace_back(op->indices.begin(),
+                                                                      op->indices.end());
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitStmt_(const IfThenElseNode* op) final {
+    has_branch = true;
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitExpr_(const SelectNode* op) final {
+    has_branch = true;
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  OperationMap<std::vector<std::vector<PrimExpr>>> buf_accesses;
+  bool has_branch{false};
+};
+
+// Returns whether the expr equals to the var with a const shift
+bool IsConstShiftEqual(const Var& var, const PrimExpr& expr) {
+  if (auto pv = expr.as<VarNode>()) {
+    return pv == var.get();
+  } else if (auto padd = expr.as<AddNode>()) {
+    return ((padd->a.get() == var.get() && padd->b->IsInstance<IntImmNode>()) ||
+            (padd->b.get() == var.get() && padd->a->IsInstance<IntImmNode>()));
+  } else if (auto psub = expr.as<SubNode>()) {
+    return ((psub->a.get() == var.get() && psub->b->IsInstance<IntImmNode>()) ||
+            (psub->b.get() == var.get() && psub->a->IsInstance<IntImmNode>()));
+  } else {
+    return false;
+  }
+}
+
+// Return whether the access is injective
+bool IsInjective(const te::Operation& op, const std::vector<PrimExpr>& index, bool* axis_missing,
+                 bool* axis_duplicated, bool* same_order) {
+  auto cop = op.as<te::ComputeOpNode>();
+  if (cop == nullptr) {
+    return false;
+  }
+
+  std::vector<int> index_to_var_idx;
+  std::vector<int> var_idx_ct(cop->axis.size(), 0);
+
+  for (const auto& expr : index) {
+    if (!is_const_int(expr)) {
+      bool found = false;
+      for (size_t i = 0; i < cop->axis.size(); ++i) {
+        if (IsConstShiftEqual(cop->axis[i]->var, expr)) {
+          index_to_var_idx.push_back(i);
+          var_idx_ct[i]++;
+          found = true;
+          break;
+        }
+      }
+      if (!found) {
+        return false;
+      }
+    }
+  }
+
+  *axis_missing = false;     // Some axes are missing
+  *axis_duplicated = false;  // Some axes appear more than once
+  *same_order = true;        // The axis order is the same as op->axis
+  for (int ct : var_idx_ct) {
+    if (ct == 0) {
+      *axis_missing = true;
+    } else if (ct > 1) {
+      *axis_duplicated = true;
+    }
+  }
+  for (size_t i = 1; i < index_to_var_idx.size(); ++i) {
+    if (index_to_var_idx[i] < index_to_var_idx[i - 1]) {
+      *same_order = false;
+      break;
+    }
+  }
+
+  return true;
+}
+
+// Gather all VarNodes in an expr
+static void GatherVars(const PrimExpr& expr, std::unordered_set<const VarNode*>* vars) {

Review comment:
       consider moving to utils.h?

##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -114,7 +118,432 @@ Array<te::Operation> TopoSortOps(const Array<te::Tensor>& tensors) {
   return ops;
 }
 
-// Estimate number of float operations in an expression
+// Extract all tensor accesses in an expr
+class TensorAccessExtractor : public StmtExprVisitor {
+ public:
+  void Extract(PrimExpr expr) { this->VisitExpr(expr); }
+
+  void VisitExpr_(const CallNode* op) final {
+    if (op->op.same_as(builtin::if_then_else())) {
+      has_branch = true;
+    }
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitExpr_(const ProducerLoadNode* op) final {
+    buf_accesses[Downcast<te::Tensor>(op->producer)->op].emplace_back(op->indices.begin(),
+                                                                      op->indices.end());
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitStmt_(const IfThenElseNode* op) final {
+    has_branch = true;
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitExpr_(const SelectNode* op) final {
+    has_branch = true;
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  OperationMap<std::vector<std::vector<PrimExpr>>> buf_accesses;
+  bool has_branch{false};
+};
+
+// Returns whether the expr equals to the var with a const shift
+bool IsConstShiftEqual(const Var& var, const PrimExpr& expr) {
+  if (auto pv = expr.as<VarNode>()) {
+    return pv == var.get();
+  } else if (auto padd = expr.as<AddNode>()) {
+    return ((padd->a.get() == var.get() && padd->b->IsInstance<IntImmNode>()) ||
+            (padd->b.get() == var.get() && padd->a->IsInstance<IntImmNode>()));
+  } else if (auto psub = expr.as<SubNode>()) {
+    return ((psub->a.get() == var.get() && psub->b->IsInstance<IntImmNode>()) ||
+            (psub->b.get() == var.get() && psub->a->IsInstance<IntImmNode>()));

Review comment:
       It is a bit counter-intuitive. Let's use pattern matching in tvm/arith instead.

##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -126,6 +555,7 @@ class FlopEstimator : public ExprFunctor<double(const PrimExpr& n)> {
           fail_ = true;
           break;
         }
+        cur_type_code_ = pop->output_dtype(0).code();

Review comment:
       could you elaborate why we need `cur_type_code_`? how do we deal with the case that computation is mixed with int8 and fp32?

##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -114,7 +118,432 @@ Array<te::Operation> TopoSortOps(const Array<te::Tensor>& tensors) {
   return ops;
 }
 
-// Estimate number of float operations in an expression
+// Extract all tensor accesses in an expr
+class TensorAccessExtractor : public StmtExprVisitor {
+ public:
+  void Extract(PrimExpr expr) { this->VisitExpr(expr); }
+
+  void VisitExpr_(const CallNode* op) final {
+    if (op->op.same_as(builtin::if_then_else())) {
+      has_branch = true;
+    }
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitExpr_(const ProducerLoadNode* op) final {
+    buf_accesses[Downcast<te::Tensor>(op->producer)->op].emplace_back(op->indices.begin(),
+                                                                      op->indices.end());
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitStmt_(const IfThenElseNode* op) final {
+    has_branch = true;
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitExpr_(const SelectNode* op) final {
+    has_branch = true;
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  OperationMap<std::vector<std::vector<PrimExpr>>> buf_accesses;
+  bool has_branch{false};
+};
+
+// Returns whether the expr equals to the var with a const shift
+bool IsConstShiftEqual(const Var& var, const PrimExpr& expr) {
+  if (auto pv = expr.as<VarNode>()) {
+    return pv == var.get();
+  } else if (auto padd = expr.as<AddNode>()) {
+    return ((padd->a.get() == var.get() && padd->b->IsInstance<IntImmNode>()) ||
+            (padd->b.get() == var.get() && padd->a->IsInstance<IntImmNode>()));
+  } else if (auto psub = expr.as<SubNode>()) {
+    return ((psub->a.get() == var.get() && psub->b->IsInstance<IntImmNode>()) ||
+            (psub->b.get() == var.get() && psub->a->IsInstance<IntImmNode>()));
+  } else {
+    return false;
+  }
+}
+
+// Return whether the access is injective
+bool IsInjective(const te::Operation& op, const std::vector<PrimExpr>& index, bool* axis_missing,
+                 bool* axis_duplicated, bool* same_order) {
+  auto cop = op.as<te::ComputeOpNode>();
+  if (cop == nullptr) {
+    return false;
+  }
+
+  std::vector<int> index_to_var_idx;
+  std::vector<int> var_idx_ct(cop->axis.size(), 0);
+
+  for (const auto& expr : index) {
+    if (!is_const_int(expr)) {
+      bool found = false;
+      for (size_t i = 0; i < cop->axis.size(); ++i) {
+        if (IsConstShiftEqual(cop->axis[i]->var, expr)) {
+          index_to_var_idx.push_back(i);
+          var_idx_ct[i]++;
+          found = true;
+          break;
+        }
+      }
+      if (!found) {
+        return false;
+      }
+    }
+  }
+
+  *axis_missing = false;     // Some axes are missing
+  *axis_duplicated = false;  // Some axes appear more than once
+  *same_order = true;        // The axis order is the same as op->axis
+  for (int ct : var_idx_ct) {
+    if (ct == 0) {
+      *axis_missing = true;
+    } else if (ct > 1) {
+      *axis_duplicated = true;
+    }
+  }
+  for (size_t i = 1; i < index_to_var_idx.size(); ++i) {
+    if (index_to_var_idx[i] < index_to_var_idx[i - 1]) {
+      *same_order = false;
+      break;
+    }
+  }
+
+  return true;
+}
+
+// Gather all VarNodes in an expr
+static void GatherVars(const PrimExpr& expr, std::unordered_set<const VarNode*>* vars) {
+  PostOrderVisit(expr, [&vars](const ObjectRef& node) {
+    if (const VarNode* op = node.as<VarNode>()) {
+      vars->insert(op);
+    }
+  });
+}
+
+// Check whether an expr has expensive operations (e.g. exp)
+static bool HasExpensiveOp(const PrimExpr& expr) {
+  bool found = false;
+  PostOrderVisit(expr, [&found](const ObjectRef& node) {
+    if (const CallNode* op = node.as<CallNode>()) {
+      if (op->op.as<OpNode>()->name == "tir.exp") {
+        found = true;
+      }
+    }
+  });
+  return found;
+}
+
+AccessAnalyzer::AccessAnalyzer(const Array<te::Tensor>& tensors) {
+  auto node = make_object<AccessAnalyzerNode>();
+  OperationMap<bool> has_branch;
+
+  // get all ops
+  node->ops_topo_order = TopoSortOps(tensors);
+
+  arith::Analyzer analyzer;
+
+  // build read & write access map
+  for (const auto& op : node->ops_topo_order) {
+    if (op->IsInstance<te::PlaceholderOpNode>()) {
+      node->read_from[op] = OperationMap<std::vector<std::vector<PrimExpr>>>();
+    } else if (auto cop = op.as<te::ComputeOpNode>()) {
+      TensorAccessExtractor extractor;
+      for (const auto& exp : cop->body) {
+        extractor.Extract(exp);
+      }
+
+      // read_by and read_from map
+      for (const auto& iter : extractor.buf_accesses) {
+        std::vector<std::vector<PrimExpr>>& accesses = node->read_by[iter.first][op];
+        accesses.insert(accesses.begin(), iter.second.begin(), iter.second.end());
+      }
+
+      node->read_from[op] = std::move(extractor.buf_accesses);
+      has_branch[op] = extractor.has_branch;
+
+      // compute number of common outer iterators
+      for (const auto& pair : node->read_from[op]) {
+        const te::Operation& producer = pair.first;
+        const std::vector<std::vector<PrimExpr>>& access_list = pair.second;
+        const Array<PrimExpr>& output_shape = op->output_shape(0);
+        const Array<PrimExpr>& producer_shape = producer->output_shape(0);
+
+        int n_common;
+        for (n_common = 0;
+             n_common < static_cast<int>(std::min(output_shape.size(), producer_shape.size()));
+             n_common++) {
+          if (!is_zero(analyzer.Simplify(output_shape[n_common] - producer_shape[n_common]))) {
+            break;
+          }
+
+          bool direct_access = true;
+          for (const auto& access : access_list) {
+            if (!IsConstShiftEqual(cop->axis[n_common]->var, access[n_common])) {
+              direct_access = false;
+              break;
+            }
+          }
+
+          if (!direct_access) {
+            break;
+          }
+        }
+
+        node->num_common_outer_iterators[op][producer] = n_common;
+        node->num_common_outer_iterators[producer][op] = n_common;
+      }
+    } else {
+      LOG(FATAL) << "Invalid op: " << op;
+    }
+  }
+
+  // do some static analysis
+  for (const auto& op : node->ops_topo_order) {
+    if (op->IsInstance<te::PlaceholderOpNode>()) {
+      node->is_injective[op] = true;
+      node->needs_multi_level_tiling[op] = false;
+      node->is_strict_inlineable[op] = false;
+      node->is_output[op] = false;
+    } else if (auto pop = op.as<te::ComputeOpNode>()) {
+      // check whether this op is element-wise and strict-inlineable
+      bool is_injective = true;
+      bool is_strict_inlineable = true;
+
+      bool axis_missing, axis_duplicated, same_order;
+      for (const auto& pair : node->read_from[op]) {
+        const std::vector<std::vector<PrimExpr>>& access = pair.second;
+        for (const auto& index : access) {
+          if (!auto_scheduler::IsInjective(op, index, &axis_missing, &axis_duplicated,
+                                           &same_order)) {
+            is_injective = false;
+            is_strict_inlineable = false;
+            break;
+          }
+          if (!same_order || axis_duplicated) {
+            // do not strictly inline transpose
+            is_strict_inlineable = false;
+          }
+        }
+        if (!is_injective) {
+          break;
+        }
+      }
+      if (has_branch[op]) {
+        is_strict_inlineable = false;
+      }
+
+      // don't strictly inline expensive op (e.g. exp)
+      bool has_expensive_op = false;
+      for (const auto& expr : pop->body) {
+        has_expensive_op |= HasExpensiveOp(expr);
+      }
+
+      node->is_injective[op] = is_injective;
+      node->is_strict_inlineable[op] = is_strict_inlineable && !has_expensive_op;
+
+      // check whether the op needs multi-level tiling
+      bool needs_multi_level_tiling = false;
+      int n_missing = 0;
+
+      for (const auto& pair : node->read_from[op]) {
+        const std::vector<std::vector<PrimExpr>>& access = pair.second;
+        std::unordered_set<const VarNode*> vars;
+        for (const std::vector<PrimExpr>& indices : access) {
+          for (const PrimExpr& expr : indices) {
+            GatherVars(expr, &vars);
+          }
+        }
+        bool missing = false;
+        for (const auto& axis : pop->axis) {
+          if (GetIntImm(axis->dom->extent) > 1 && vars.count(axis->var.get()) == 0) {
+            missing = true;
+          }
+        }
+        if (missing) {
+          n_missing++;
+        }
+
+        if (n_missing >= 2 || (n_missing >= 1 && !pop->reduce_axis.empty())) {
+          needs_multi_level_tiling = true;
+          break;
+        }
+      }
+
+      node->needs_multi_level_tiling[op] = needs_multi_level_tiling;
+
+      // check whether is output
+      node->is_output[op] = node->read_by[op].empty();
+    } else {
+      LOG(FATAL) << "Invalid op" << op;
+    }
+  }
+
+  data_ = std::move(node);
+}
+
+bool AccessAnalyzer::NeedsMultiLevelTiling(const te::Operation& op) const {
+  return operator->()->needs_multi_level_tiling.at(op);
+}
+
+bool AccessAnalyzer::IsOutput(const te::Operation& op) const {
+  return operator->()->is_output.at(op);
+}
+
+bool AccessAnalyzer::IsInjective(const te::Operation& op) const {
+  return operator->()->is_injective.at(op);
+}
+
+bool AccessAnalyzer::IsStrictInlineable(const te::Operation& op) const {
+  return operator->()->is_strict_inlineable.at(op);
+}
+
+void AccessAnalyzer::GetConsumers(const State& state, const te::Operation& op,
+                                  OperationSet* consumers) const {
+  OperationSet inlined_ops;
+  for (const auto& stage : state->stages) {
+    if (stage->compute_at == ComputeAtKind::kInlined) {
+      inlined_ops.insert(stage->op);
+    }
+  }
+
+  std::function<void(const te::Operation&)> collect;
+  collect = [this, &collect, &inlined_ops, &consumers](const te::Operation& op) {
+    for (const auto& iter : operator->()->read_by.at(op)) {
+      if (inlined_ops.count(iter.first)) {
+        collect(iter.first);
+      } else {
+        consumers->insert(iter.first);
+      }
+    }
+  };
+
+  consumers->clear();
+  collect(op);
+}
+
+void AccessAnalyzer::GetDirectProducers(const te::Operation& op, OperationSet* producers) const {
+  producers->clear();
+  for (const auto& iter : operator->()->read_from.at(op)) {
+    producers->insert(iter.first);
+  }
+}
+
+void AccessAnalyzer::GetProducers(const State& state, const te::Operation& op,
+                                  OperationSet* producers) const {
+  OperationSet inlined_ops;
+  for (const auto& stage : state->stages) {
+    if (stage->compute_at == ComputeAtKind::kInlined) {
+      inlined_ops.insert(stage->op);
+    }
+  }
+
+  std::function<void(const te::Operation&)> collect;
+  collect = [this, &collect, &inlined_ops, &producers](const te::Operation& op) {
+    for (const auto& iter : operator->()->read_from.at(op)) {
+      if (inlined_ops.count(iter.first)) {
+        collect(iter.first);
+      } else {
+        producers->insert(iter.first);
+      }
+    }
+  };
+
+  producers->clear();
+  collect(op);
+}
+
+int AccessAnalyzer::GetNumCommonOuterIterator(const te::Operation& op,
+                                              const te::Operation& target_op) const {
+  int ret = INT32_MAX;
+  bool meet = false;
+
+  std::function<void(const te::Operation&, int)> traverse;
+  traverse = [this, &traverse, &target_op, &ret, &meet](const te::Operation& cur_op, int cur_num) {
+    if (cur_op == target_op) {
+      ret = std::min(ret, cur_num);
+      meet = true;
+      return;
+    }
+
+    for (const auto& iter : operator->()->read_by.at(cur_op)) {
+      traverse(
+          iter.first,
+          std::min(cur_num, operator->()->num_common_outer_iterators.at(cur_op).at(iter.first)));
+    }
+  };
+
+  traverse(op, op->output_shape(0).size());
+  return meet ? ret : 0;
+}
+
+// Return whether two int arrays are elementwise-equal
+bool IntArrayEqual(const Array<PrimExpr>& arr1, const Array<PrimExpr>& arr2) {

Review comment:
       Moved to utils.h?




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] merrymercy commented on a change in pull request #6103: [Ansor][AutoTVM v2.0] Phase 1: Access Analyzer

Posted by GitBox <gi...@apache.org>.
merrymercy commented on a change in pull request #6103:
URL: https://github.com/apache/incubator-tvm/pull/6103#discussion_r457820848



##########
File path: src/auto_scheduler/compute_dag.h
##########
@@ -37,13 +37,126 @@
 
 #include <tvm/te/schedule.h>
 
+#include <unordered_map>
+#include <unordered_set>
 #include <utility>
+#include <vector>
 
 #include "loop_state.h"
 
 namespace tvm {
 namespace auto_scheduler {
 
+/*! \brief Static analysis result for a ComputeDAG */
+class AccessAnalyzerNode : public Object {
+ public:
+  template <class T>
+  using OperationMap = std::unordered_map<te::Operation, T, ObjectHash, ObjectEqual>;
+
+  /*! \brief Map an operation to all operations it reads from.
+   * For each operation pair, use a two-dimentional array to multiple multi-dimentional accesses*/
+  OperationMap<OperationMap<std::vector<std::vector<PrimExpr>>>> read_from;

Review comment:
       These fields will only be accessed by c++ code, so it is okay to use std::unordered_map and std::vector.
   Nobody should write a search policy in python.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] junrushao1994 commented on a change in pull request #6103: [Ansor][AutoTVM v2.0] Phase 1: Access Analyzer

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on a change in pull request #6103:
URL: https://github.com/apache/incubator-tvm/pull/6103#discussion_r459036902



##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -114,7 +118,432 @@ Array<te::Operation> TopoSortOps(const Array<te::Tensor>& tensors) {
   return ops;
 }
 
-// Estimate number of float operations in an expression
+// Extract all tensor accesses in an expr
+class TensorAccessExtractor : public StmtExprVisitor {
+ public:
+  void Extract(PrimExpr expr) { this->VisitExpr(expr); }
+
+  void VisitExpr_(const CallNode* op) final {
+    if (op->op.same_as(builtin::if_then_else())) {
+      has_branch = true;
+    }
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitExpr_(const ProducerLoadNode* op) final {
+    buf_accesses[Downcast<te::Tensor>(op->producer)->op].emplace_back(op->indices.begin(),
+                                                                      op->indices.end());
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitStmt_(const IfThenElseNode* op) final {
+    has_branch = true;
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitExpr_(const SelectNode* op) final {
+    has_branch = true;
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  OperationMap<std::vector<std::vector<PrimExpr>>> buf_accesses;
+  bool has_branch{false};
+};
+
+// Returns whether the expr equals to the var with a const shift
+bool IsConstShiftEqual(const Var& var, const PrimExpr& expr) {
+  if (auto pv = expr.as<VarNode>()) {
+    return pv == var.get();
+  } else if (auto padd = expr.as<AddNode>()) {
+    return ((padd->a.get() == var.get() && padd->b->IsInstance<IntImmNode>()) ||
+            (padd->b.get() == var.get() && padd->a->IsInstance<IntImmNode>()));
+  } else if (auto psub = expr.as<SubNode>()) {
+    return ((psub->a.get() == var.get() && psub->b->IsInstance<IntImmNode>()) ||
+            (psub->b.get() == var.get() && psub->a->IsInstance<IntImmNode>()));
+  } else {
+    return false;
+  }
+}
+
+// Return whether the access is injective
+bool IsInjective(const te::Operation& op, const std::vector<PrimExpr>& index, bool* axis_missing,
+                 bool* axis_duplicated, bool* same_order) {
+  auto cop = op.as<te::ComputeOpNode>();
+  if (cop == nullptr) {
+    return false;
+  }
+
+  std::vector<int> index_to_var_idx;
+  std::vector<int> var_idx_ct(cop->axis.size(), 0);
+
+  for (const auto& expr : index) {
+    if (!is_const_int(expr)) {
+      bool found = false;
+      for (size_t i = 0; i < cop->axis.size(); ++i) {
+        if (IsConstShiftEqual(cop->axis[i]->var, expr)) {
+          index_to_var_idx.push_back(i);
+          var_idx_ct[i]++;
+          found = true;
+          break;
+        }
+      }
+      if (!found) {
+        return false;
+      }
+    }
+  }
+
+  *axis_missing = false;     // Some axes are missing
+  *axis_duplicated = false;  // Some axes appear more than once
+  *same_order = true;        // The axis order is the same as op->axis
+  for (int ct : var_idx_ct) {
+    if (ct == 0) {
+      *axis_missing = true;
+    } else if (ct > 1) {
+      *axis_duplicated = true;
+    }
+  }
+  for (size_t i = 1; i < index_to_var_idx.size(); ++i) {
+    if (index_to_var_idx[i] < index_to_var_idx[i - 1]) {
+      *same_order = false;
+      break;
+    }
+  }
+
+  return true;
+}
+
+// Gather all VarNodes in an expr
+static void GatherVars(const PrimExpr& expr, std::unordered_set<const VarNode*>* vars) {
+  PostOrderVisit(expr, [&vars](const ObjectRef& node) {
+    if (const VarNode* op = node.as<VarNode>()) {
+      vars->insert(op);
+    }
+  });
+}
+
+// Check whether an expr has expensive operations (e.g. exp)
+static bool HasExpensiveOp(const PrimExpr& expr) {
+  bool found = false;
+  PostOrderVisit(expr, [&found](const ObjectRef& node) {
+    if (const CallNode* op = node.as<CallNode>()) {
+      if (op->op.as<OpNode>()->name == "tir.exp") {
+        found = true;
+      }
+    }
+  });
+  return found;
+}
+
+AccessAnalyzer::AccessAnalyzer(const Array<te::Tensor>& tensors) {
+  auto node = make_object<AccessAnalyzerNode>();
+  OperationMap<bool> has_branch;
+
+  // get all ops
+  node->ops_topo_order = TopoSortOps(tensors);
+
+  arith::Analyzer analyzer;
+
+  // build read & write access map
+  for (const auto& op : node->ops_topo_order) {
+    if (op->IsInstance<te::PlaceholderOpNode>()) {
+      node->read_from[op] = OperationMap<std::vector<std::vector<PrimExpr>>>();
+    } else if (auto cop = op.as<te::ComputeOpNode>()) {
+      TensorAccessExtractor extractor;
+      for (const auto& exp : cop->body) {

Review comment:
       just curious: did we test with ComputeOp with multiple bodies?

##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -114,7 +118,432 @@ Array<te::Operation> TopoSortOps(const Array<te::Tensor>& tensors) {
   return ops;
 }
 
-// Estimate number of float operations in an expression
+// Extract all tensor accesses in an expr
+class TensorAccessExtractor : public StmtExprVisitor {
+ public:
+  void Extract(PrimExpr expr) { this->VisitExpr(expr); }
+
+  void VisitExpr_(const CallNode* op) final {
+    if (op->op.same_as(builtin::if_then_else())) {
+      has_branch = true;
+    }
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitExpr_(const ProducerLoadNode* op) final {
+    buf_accesses[Downcast<te::Tensor>(op->producer)->op].emplace_back(op->indices.begin(),
+                                                                      op->indices.end());
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitStmt_(const IfThenElseNode* op) final {
+    has_branch = true;
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitExpr_(const SelectNode* op) final {
+    has_branch = true;
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  OperationMap<std::vector<std::vector<PrimExpr>>> buf_accesses;
+  bool has_branch{false};
+};
+
+// Returns whether the expr equals to the var with a const shift
+bool IsConstShiftEqual(const Var& var, const PrimExpr& expr) {

Review comment:
       I don't know if it is a good name tho...otherwise maybe it is good to add more docs

##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -114,7 +118,432 @@ Array<te::Operation> TopoSortOps(const Array<te::Tensor>& tensors) {
   return ops;
 }
 
-// Estimate number of float operations in an expression
+// Extract all tensor accesses in an expr
+class TensorAccessExtractor : public StmtExprVisitor {
+ public:
+  void Extract(PrimExpr expr) { this->VisitExpr(expr); }
+
+  void VisitExpr_(const CallNode* op) final {
+    if (op->op.same_as(builtin::if_then_else())) {
+      has_branch = true;
+    }
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitExpr_(const ProducerLoadNode* op) final {
+    buf_accesses[Downcast<te::Tensor>(op->producer)->op].emplace_back(op->indices.begin(),
+                                                                      op->indices.end());
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitStmt_(const IfThenElseNode* op) final {
+    has_branch = true;
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitExpr_(const SelectNode* op) final {
+    has_branch = true;
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  OperationMap<std::vector<std::vector<PrimExpr>>> buf_accesses;
+  bool has_branch{false};
+};
+
+// Returns whether the expr equals to the var with a const shift
+bool IsConstShiftEqual(const Var& var, const PrimExpr& expr) {
+  if (auto pv = expr.as<VarNode>()) {
+    return pv == var.get();
+  } else if (auto padd = expr.as<AddNode>()) {
+    return ((padd->a.get() == var.get() && padd->b->IsInstance<IntImmNode>()) ||
+            (padd->b.get() == var.get() && padd->a->IsInstance<IntImmNode>()));
+  } else if (auto psub = expr.as<SubNode>()) {
+    return ((psub->a.get() == var.get() && psub->b->IsInstance<IntImmNode>()) ||
+            (psub->b.get() == var.get() && psub->a->IsInstance<IntImmNode>()));
+  } else {
+    return false;
+  }
+}
+
+// Return whether the access is injective
+bool IsInjective(const te::Operation& op, const std::vector<PrimExpr>& index, bool* axis_missing,

Review comment:
       maybe index => indices?

##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -114,7 +118,432 @@ Array<te::Operation> TopoSortOps(const Array<te::Tensor>& tensors) {
   return ops;
 }
 
-// Estimate number of float operations in an expression
+// Extract all tensor accesses in an expr
+class TensorAccessExtractor : public StmtExprVisitor {
+ public:
+  void Extract(PrimExpr expr) { this->VisitExpr(expr); }
+
+  void VisitExpr_(const CallNode* op) final {
+    if (op->op.same_as(builtin::if_then_else())) {
+      has_branch = true;
+    }
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitExpr_(const ProducerLoadNode* op) final {
+    buf_accesses[Downcast<te::Tensor>(op->producer)->op].emplace_back(op->indices.begin(),
+                                                                      op->indices.end());
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitStmt_(const IfThenElseNode* op) final {
+    has_branch = true;
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitExpr_(const SelectNode* op) final {
+    has_branch = true;
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  OperationMap<std::vector<std::vector<PrimExpr>>> buf_accesses;
+  bool has_branch{false};
+};
+
+// Returns whether the expr equals to the var with a const shift
+bool IsConstShiftEqual(const Var& var, const PrimExpr& expr) {
+  if (auto pv = expr.as<VarNode>()) {
+    return pv == var.get();
+  } else if (auto padd = expr.as<AddNode>()) {
+    return ((padd->a.get() == var.get() && padd->b->IsInstance<IntImmNode>()) ||
+            (padd->b.get() == var.get() && padd->a->IsInstance<IntImmNode>()));
+  } else if (auto psub = expr.as<SubNode>()) {
+    return ((psub->a.get() == var.get() && psub->b->IsInstance<IntImmNode>()) ||
+            (psub->b.get() == var.get() && psub->a->IsInstance<IntImmNode>()));
+  } else {
+    return false;
+  }
+}
+
+// Return whether the access is injective
+bool IsInjective(const te::Operation& op, const std::vector<PrimExpr>& index, bool* axis_missing,

Review comment:
       Document this function, because it is relatively important

##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -114,7 +118,432 @@ Array<te::Operation> TopoSortOps(const Array<te::Tensor>& tensors) {
   return ops;
 }
 
-// Estimate number of float operations in an expression
+// Extract all tensor accesses in an expr
+class TensorAccessExtractor : public StmtExprVisitor {
+ public:
+  void Extract(PrimExpr expr) { this->VisitExpr(expr); }
+
+  void VisitExpr_(const CallNode* op) final {
+    if (op->op.same_as(builtin::if_then_else())) {
+      has_branch = true;
+    }
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitExpr_(const ProducerLoadNode* op) final {
+    buf_accesses[Downcast<te::Tensor>(op->producer)->op].emplace_back(op->indices.begin(),
+                                                                      op->indices.end());
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitStmt_(const IfThenElseNode* op) final {
+    has_branch = true;
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitExpr_(const SelectNode* op) final {
+    has_branch = true;
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  OperationMap<std::vector<std::vector<PrimExpr>>> buf_accesses;

Review comment:
       ```suggestion
     OperationMap<std::vector<std::vector<PrimExpr>>> read_from;
   ```

##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -114,7 +118,432 @@ Array<te::Operation> TopoSortOps(const Array<te::Tensor>& tensors) {
   return ops;
 }
 
-// Estimate number of float operations in an expression
+// Extract all tensor accesses in an expr
+class TensorAccessExtractor : public StmtExprVisitor {
+ public:
+  void Extract(PrimExpr expr) { this->VisitExpr(expr); }
+
+  void VisitExpr_(const CallNode* op) final {
+    if (op->op.same_as(builtin::if_then_else())) {
+      has_branch = true;
+    }
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitExpr_(const ProducerLoadNode* op) final {
+    buf_accesses[Downcast<te::Tensor>(op->producer)->op].emplace_back(op->indices.begin(),
+                                                                      op->indices.end());
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitStmt_(const IfThenElseNode* op) final {
+    has_branch = true;
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitExpr_(const SelectNode* op) final {
+    has_branch = true;
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  OperationMap<std::vector<std::vector<PrimExpr>>> buf_accesses;
+  bool has_branch{false};
+};
+
+// Returns whether the expr equals to the var with a const shift
+bool IsConstShiftEqual(const Var& var, const PrimExpr& expr) {
+  if (auto pv = expr.as<VarNode>()) {
+    return pv == var.get();
+  } else if (auto padd = expr.as<AddNode>()) {
+    return ((padd->a.get() == var.get() && padd->b->IsInstance<IntImmNode>()) ||
+            (padd->b.get() == var.get() && padd->a->IsInstance<IntImmNode>()));
+  } else if (auto psub = expr.as<SubNode>()) {
+    return ((psub->a.get() == var.get() && psub->b->IsInstance<IntImmNode>()) ||
+            (psub->b.get() == var.get() && psub->a->IsInstance<IntImmNode>()));

Review comment:
       a second look: do we really want to include the case of "-var + const"?

##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -114,7 +118,432 @@ Array<te::Operation> TopoSortOps(const Array<te::Tensor>& tensors) {
   return ops;
 }
 
-// Estimate number of float operations in an expression
+// Extract all tensor accesses in an expr
+class TensorAccessExtractor : public StmtExprVisitor {

Review comment:
       A second look: looks like it is only applied to PrimExpr
   ```suggestion
   class ReadFromAndHasBranchExtractor : public ExprVisitor {
   ```




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] merrymercy commented on a change in pull request #6103: [Ansor][AutoTVM v2.0] Phase 1: Access Analyzer

Posted by GitBox <gi...@apache.org>.
merrymercy commented on a change in pull request #6103:
URL: https://github.com/apache/incubator-tvm/pull/6103#discussion_r458605272



##########
File path: include/tvm/auto_scheduler/auto_schedule.h
##########
@@ -18,29 +18,27 @@
  */
 
 /*!
- * \file auto_scheduler/auto_schedule.h

Review comment:
       1. Because all other components locate their header files to the public header.
   2. For cpp test




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] junrushao1994 commented on a change in pull request #6103: [Ansor][AutoTVM v2.0] Phase 1: Access Analyzer

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on a change in pull request #6103:
URL: https://github.com/apache/incubator-tvm/pull/6103#discussion_r458574335



##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -114,7 +118,432 @@ Array<te::Operation> TopoSortOps(const Array<te::Tensor>& tensors) {
   return ops;
 }
 
-// Estimate number of float operations in an expression
+// Extract all tensor accesses in an expr
+class TensorAccessExtractor : public StmtExprVisitor {
+ public:
+  void Extract(PrimExpr expr) { this->VisitExpr(expr); }
+
+  void VisitExpr_(const CallNode* op) final {
+    if (op->op.same_as(builtin::if_then_else())) {
+      has_branch = true;
+    }
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitExpr_(const ProducerLoadNode* op) final {
+    buf_accesses[Downcast<te::Tensor>(op->producer)->op].emplace_back(op->indices.begin(),
+                                                                      op->indices.end());
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitStmt_(const IfThenElseNode* op) final {
+    has_branch = true;
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitExpr_(const SelectNode* op) final {
+    has_branch = true;
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  OperationMap<std::vector<std::vector<PrimExpr>>> buf_accesses;
+  bool has_branch{false};
+};
+
+// Returns whether the expr equals to the var with a const shift
+bool IsConstShiftEqual(const Var& var, const PrimExpr& expr) {
+  if (auto pv = expr.as<VarNode>()) {
+    return pv == var.get();
+  } else if (auto padd = expr.as<AddNode>()) {
+    return ((padd->a.get() == var.get() && padd->b->IsInstance<IntImmNode>()) ||
+            (padd->b.get() == var.get() && padd->a->IsInstance<IntImmNode>()));
+  } else if (auto psub = expr.as<SubNode>()) {
+    return ((psub->a.get() == var.get() && psub->b->IsInstance<IntImmNode>()) ||
+            (psub->b.get() == var.get() && psub->a->IsInstance<IntImmNode>()));
+  } else {
+    return false;
+  }
+}
+
+// Return whether the access is injective
+bool IsInjective(const te::Operation& op, const std::vector<PrimExpr>& index, bool* axis_missing,
+                 bool* axis_duplicated, bool* same_order) {
+  auto cop = op.as<te::ComputeOpNode>();
+  if (cop == nullptr) {
+    return false;
+  }
+
+  std::vector<int> index_to_var_idx;
+  std::vector<int> var_idx_ct(cop->axis.size(), 0);
+
+  for (const auto& expr : index) {
+    if (!is_const_int(expr)) {
+      bool found = false;
+      for (size_t i = 0; i < cop->axis.size(); ++i) {
+        if (IsConstShiftEqual(cop->axis[i]->var, expr)) {
+          index_to_var_idx.push_back(i);
+          var_idx_ct[i]++;
+          found = true;
+          break;
+        }
+      }
+      if (!found) {
+        return false;
+      }
+    }
+  }
+
+  *axis_missing = false;     // Some axes are missing
+  *axis_duplicated = false;  // Some axes appear more than once
+  *same_order = true;        // The axis order is the same as op->axis
+  for (int ct : var_idx_ct) {
+    if (ct == 0) {
+      *axis_missing = true;
+    } else if (ct > 1) {
+      *axis_duplicated = true;
+    }
+  }
+  for (size_t i = 1; i < index_to_var_idx.size(); ++i) {
+    if (index_to_var_idx[i] < index_to_var_idx[i - 1]) {
+      *same_order = false;
+      break;
+    }
+  }
+
+  return true;
+}
+
+// Gather all VarNodes in an expr
+static void GatherVars(const PrimExpr& expr, std::unordered_set<const VarNode*>* vars) {

Review comment:
       It doesn't have to be static (it might interfere with backtrace in error reporting)
   
   ```suggestion
   std::unordered_set<const VarNode*> ExprGatherVars(const PrimExpr& expr) {
   ```

##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -114,7 +118,432 @@ Array<te::Operation> TopoSortOps(const Array<te::Tensor>& tensors) {
   return ops;
 }
 
-// Estimate number of float operations in an expression
+// Extract all tensor accesses in an expr
+class TensorAccessExtractor : public StmtExprVisitor {
+ public:
+  void Extract(PrimExpr expr) { this->VisitExpr(expr); }
+
+  void VisitExpr_(const CallNode* op) final {
+    if (op->op.same_as(builtin::if_then_else())) {
+      has_branch = true;
+    }
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitExpr_(const ProducerLoadNode* op) final {
+    buf_accesses[Downcast<te::Tensor>(op->producer)->op].emplace_back(op->indices.begin(),
+                                                                      op->indices.end());
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitStmt_(const IfThenElseNode* op) final {
+    has_branch = true;
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitExpr_(const SelectNode* op) final {
+    has_branch = true;
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  OperationMap<std::vector<std::vector<PrimExpr>>> buf_accesses;
+  bool has_branch{false};
+};
+
+// Returns whether the expr equals to the var with a const shift
+bool IsConstShiftEqual(const Var& var, const PrimExpr& expr) {
+  if (auto pv = expr.as<VarNode>()) {
+    return pv == var.get();
+  } else if (auto padd = expr.as<AddNode>()) {
+    return ((padd->a.get() == var.get() && padd->b->IsInstance<IntImmNode>()) ||
+            (padd->b.get() == var.get() && padd->a->IsInstance<IntImmNode>()));
+  } else if (auto psub = expr.as<SubNode>()) {
+    return ((psub->a.get() == var.get() && psub->b->IsInstance<IntImmNode>()) ||
+            (psub->b.get() == var.get() && psub->a->IsInstance<IntImmNode>()));
+  } else {
+    return false;
+  }
+}
+
+// Return whether the access is injective
+bool IsInjective(const te::Operation& op, const std::vector<PrimExpr>& index, bool* axis_missing,
+                 bool* axis_duplicated, bool* same_order) {
+  auto cop = op.as<te::ComputeOpNode>();
+  if (cop == nullptr) {
+    return false;
+  }
+
+  std::vector<int> index_to_var_idx;
+  std::vector<int> var_idx_ct(cop->axis.size(), 0);
+
+  for (const auto& expr : index) {
+    if (!is_const_int(expr)) {
+      bool found = false;
+      for (size_t i = 0; i < cop->axis.size(); ++i) {
+        if (IsConstShiftEqual(cop->axis[i]->var, expr)) {
+          index_to_var_idx.push_back(i);
+          var_idx_ct[i]++;
+          found = true;
+          break;
+        }
+      }
+      if (!found) {
+        return false;
+      }
+    }
+  }
+
+  *axis_missing = false;     // Some axes are missing
+  *axis_duplicated = false;  // Some axes appear more than once
+  *same_order = true;        // The axis order is the same as op->axis
+  for (int ct : var_idx_ct) {
+    if (ct == 0) {
+      *axis_missing = true;
+    } else if (ct > 1) {
+      *axis_duplicated = true;
+    }
+  }
+  for (size_t i = 1; i < index_to_var_idx.size(); ++i) {
+    if (index_to_var_idx[i] < index_to_var_idx[i - 1]) {
+      *same_order = false;
+      break;
+    }
+  }
+
+  return true;
+}
+
+// Gather all VarNodes in an expr
+static void GatherVars(const PrimExpr& expr, std::unordered_set<const VarNode*>* vars) {

Review comment:
       consider moving to utils.h?




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] merrymercy commented on a change in pull request #6103: [Ansor][AutoTVM v2.0] Phase 1: Access Analyzer

Posted by GitBox <gi...@apache.org>.
merrymercy commented on a change in pull request #6103:
URL: https://github.com/apache/incubator-tvm/pull/6103#discussion_r459132211



##########
File path: include/tvm/auto_scheduler/compute_dag.h
##########
@@ -0,0 +1,237 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/auto_scheduler/compute_dag.h
+ * \brief The auto-scheduler's computational graph and related program analyses.
+ *
+ * We convert a compute declaration described by `tvm.compute` (could be a single operator or a
+ * subgraph) to a ComputeDAG. It keeps the input/output tensors of the compute declaration,
+ * a list of all operations in the DAG as well as static analysis results for the DAG (e.g. the
+ * total float operation count, consumer/producer relations of each operation stage, whether an
+ * operation stage should be tiled/compute inlined ...). These analyses can help the search policy
+ * to make decisions during search process.
+ * ComputeDAG is also responsible for the interaction between TVM Auto-scheduler `LoopState` and
+ * TVM schedule (e.g. applying the `LoopState` transform steps to TVM schedule, providing
+ * `LoopState` with extra information got from TVM schedule ...).
+ */
+
+#ifndef TVM_AUTO_SCHEDULER_COMPUTE_DAG_H_
+#define TVM_AUTO_SCHEDULER_COMPUTE_DAG_H_
+
+#include <tvm/auto_scheduler/loop_state.h>
+#include <tvm/te/schedule.h>
+
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+namespace tvm {
+namespace auto_scheduler {
+
+/*! \brief Static analysis result for a ComputeDAG */
+class AccessAnalyzerNode : public Object {
+ public:
+  template <class T>
+  using OperationMap = std::unordered_map<te::Operation, T, ObjectHash, ObjectEqual>;
+
+  /*! \brief Map an operation to all operations it reads from.
+   * For each operation pair, use a two-dimentional array to multiple multi-dimentional accesses*/
+  OperationMap<OperationMap<std::vector<std::vector<PrimExpr>>>> read_from;
+  /*! \brief Map an operation to all operations it is read by.
+   * For each operation pair, use a two-dimentional array to multiple multi-dimentional accesses*/
+  OperationMap<OperationMap<std::vector<std::vector<PrimExpr>>>> read_by;
+  /*! \brief Store the number of common outer iterators for operation pairs that have
+   * read-write relations. */
+  OperationMap<OperationMap<int>> num_common_outer_iterators;
+  /*! \brief Store whether the operation is injective */
+  OperationMap<bool> is_injective;
+  /*! \brief Store whether the operation is strictly-inlineable */
+  OperationMap<bool> is_strict_inlineable;
+  /*! \brief Store whether the operation needs multi-level tiling */
+  OperationMap<bool> needs_multi_level_tiling;
+  /*! \brief Store whether the operation is an output operation */
+  OperationMap<bool> is_output;
+  /*! \brief Store the topological order of operations */
+  Array<te::Operation> ops_topo_order;

Review comment:
       The are the same. I store it in AccessAnalyzer because it is used first here.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] junrushao1994 commented on pull request #6103: [Ansor][AutoTVM v2.0] Phase 1: Access Analyzer

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on pull request #6103:
URL: https://github.com/apache/incubator-tvm/pull/6103#issuecomment-661629932


   It is one of the core part of the system. Will take a look later tomorrow.


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] junrushao1994 commented on a change in pull request #6103: [Ansor][AutoTVM v2.0] Phase 1: Access Analyzer

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on a change in pull request #6103:
URL: https://github.com/apache/incubator-tvm/pull/6103#discussion_r458504576



##########
File path: include/tvm/auto_scheduler/loop_state.h
##########
@@ -159,10 +159,16 @@ using IterKey = std::pair<int, int>;
  */
 class AttachMapNode : public Object {
  public:
+  struct key_hash : public std::function<std::size_t(IterKey)> {

Review comment:
       any specific reason that we inherit `std::function`?

##########
File path: include/tvm/auto_scheduler/auto_schedule.h
##########
@@ -18,29 +18,27 @@
  */
 
 /*!
- * \file auto_scheduler/auto_schedule.h

Review comment:
       Could you elaborate why we move all of these to public headers? For cpptests?

##########
File path: include/tvm/auto_scheduler/loop_state.h
##########
@@ -159,10 +159,16 @@ using IterKey = std::pair<int, int>;
  */
 class AttachMapNode : public Object {
  public:
+  struct key_hash : public std::function<std::size_t(IterKey)> {

Review comment:
       Also, suggest a better name: `IterKeyHash`

##########
File path: include/tvm/auto_scheduler/compute_dag.h
##########
@@ -0,0 +1,237 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/auto_scheduler/compute_dag.h
+ * \brief The auto-scheduler's computational graph and related program analyses.
+ *
+ * We convert a compute declaration described by `tvm.compute` (could be a single operator or a
+ * subgraph) to a ComputeDAG. It keeps the input/output tensors of the compute declaration,
+ * a list of all operations in the DAG as well as static analysis results for the DAG (e.g. the
+ * total float operation count, consumer/producer relations of each operation stage, whether an
+ * operation stage should be tiled/compute inlined ...). These analyses can help the search policy
+ * to make decisions during search process.
+ * ComputeDAG is also responsible for the interaction between TVM Auto-scheduler `LoopState` and
+ * TVM schedule (e.g. applying the `LoopState` transform steps to TVM schedule, providing
+ * `LoopState` with extra information got from TVM schedule ...).
+ */
+
+#ifndef TVM_AUTO_SCHEDULER_COMPUTE_DAG_H_
+#define TVM_AUTO_SCHEDULER_COMPUTE_DAG_H_
+
+#include <tvm/auto_scheduler/loop_state.h>
+#include <tvm/te/schedule.h>
+
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+namespace tvm {
+namespace auto_scheduler {
+
+/*! \brief Static analysis result for a ComputeDAG */
+class AccessAnalyzerNode : public Object {
+ public:
+  template <class T>
+  using OperationMap = std::unordered_map<te::Operation, T, ObjectHash, ObjectEqual>;
+
+  /*! \brief Map an operation to all operations it reads from.
+   * For each operation pair, use a two-dimentional array to multiple multi-dimentional accesses*/
+  OperationMap<OperationMap<std::vector<std::vector<PrimExpr>>>> read_from;
+  /*! \brief Map an operation to all operations it is read by.
+   * For each operation pair, use a two-dimentional array to multiple multi-dimentional accesses*/
+  OperationMap<OperationMap<std::vector<std::vector<PrimExpr>>>> read_by;
+  /*! \brief Store the number of common outer iterators for operation pairs that have
+   * read-write relations. */
+  OperationMap<OperationMap<int>> num_common_outer_iterators;
+  /*! \brief Store whether the operation is injective */
+  OperationMap<bool> is_injective;
+  /*! \brief Store whether the operation is strictly-inlineable */
+  OperationMap<bool> is_strict_inlineable;
+  /*! \brief Store whether the operation needs multi-level tiling */
+  OperationMap<bool> needs_multi_level_tiling;
+  /*! \brief Store whether the operation is an output operation */
+  OperationMap<bool> is_output;
+  /*! \brief Store the topological order of operations */
+  Array<te::Operation> ops_topo_order;

Review comment:
       What's the relationship between this array and `ComputeDAG::ops`?

##########
File path: include/tvm/auto_scheduler/compute_dag.h
##########
@@ -0,0 +1,242 @@
+/*r
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/auto_scheduler/compute_dag.h
+ * \brief The auto-scheduler's computational graph and related program analyses.
+ *
+ * We convert a compute declaration described by `tvm.compute` (could be a single operator or a
+ * subgraph) to a ComputeDAG. It keeps the input/output tensors of the compute declaration,
+ * a list of all operations in the DAG as well as static analysis results for the DAG (e.g. the
+ * total float operation count, consumer/producer relations of each operation stage, whether an
+ * operation stage should be tiled/compute inlined ...). These analyses can help the search policy
+ * to make decisions during search process.
+ * ComputeDAG is also responsible for the interaction between TVM Auto-scheduler `LoopState` and
+ * TVM schedule (e.g. applying the `LoopState` transform steps to TVM schedule, providing
+ * `LoopState` with extra information got from TVM schedule ...).
+ */
+
+#ifndef TVM_AUTO_SCHEDULER_COMPUTE_DAG_H_
+#define TVM_AUTO_SCHEDULER_COMPUTE_DAG_H_
+
+#include <tvm/auto_scheduler/loop_state.h>
+#include <tvm/runtime/c_runtime_api.h>
+#include <tvm/te/schedule.h>
+
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+namespace tvm {
+namespace auto_scheduler {
+
+/*! \brief Static analysis result for a ComputeDAG */
+class AccessAnalyzerNode : public Object {
+ public:
+  template <class T>
+  using OperationMap = std::unordered_map<te::Operation, T, ObjectHash, ObjectEqual>;
+
+  /*! \brief Map an operation to all operations it reads from.
+   * For each operation pair, use a two-dimentional array to multiple multi-dimentional accesses*/
+  OperationMap<OperationMap<std::vector<std::vector<PrimExpr>>>> read_from;
+  /*! \brief Map an operation to all operations it is read by.
+   * For each operation pair, use a two-dimentional array to multiple multi-dimentional accesses*/
+  OperationMap<OperationMap<std::vector<std::vector<PrimExpr>>>> read_by;
+  /*! \brief Store the number of common outer iterators for operation pairs that have
+   * read-write relations. */
+  OperationMap<OperationMap<int>> num_common_outer_iterators;
+  /*! \brief Store whether the operation is injective */
+  OperationMap<bool> is_injective;
+  /*! \brief Store whether the operation is strictly-inlineable */
+  OperationMap<bool> is_strict_inlineable;
+  /*! \brief Store whether the operation needs multi-level tiling */
+  OperationMap<bool> needs_multi_level_tiling;
+  /*! \brief Store whether the operation is an output operation */
+  OperationMap<bool> is_output;
+  /*! \brief Store the topological order of operations */
+  Array<te::Operation> ops_topo_order;
+
+  static constexpr const char* _type_key = "auto_scheduler.AccessAnalyzer";
+  TVM_DECLARE_FINAL_OBJECT_INFO(AccessAnalyzerNode, Object);
+};
+
+/*!
+ * \brief Managed reference to AccessAnalyzerNode.
+ * \sa AccessAnalyzerNode
+ */
+class AccessAnalyzer : public ObjectRef {
+ public:
+  explicit AccessAnalyzer(const Array<te::Tensor>& tensors);
+
+  /*!
+   * \brief Return whether this operation needs multi-level tiling
+   * \param op The operation
+   */
+  TVM_DLL bool NeedsMultiLevelTiling(const te::Operation& op) const;
+
+  /*!
+   * \brief Return whether this operation is an injective operation
+   * \param op The operation
+   */
+  TVM_DLL bool IsInjective(const te::Operation& op) const;
+
+  /*!
+   * \brief Return whether this operation is strictly inlinable
+   * \param op The operation
+   */
+  TVM_DLL bool IsStrictInlineable(const te::Operation& op) const;
+
+  /*!
+   * \brief Return whether this operation is an output op
+   * \param op The operation
+   */
+  TVM_DLL bool IsOutput(const te::Operation& op) const;
+
+  /*!
+   * \brief Get all consumers of on operation
+   * \param state The current loop state
+   * \param op The operation
+   * \param consumers The return consumer set
+   * \note This function propagates the relation for inlined ops
+   */
+  TVM_DLL void GetConsumers(
+      const State& state, const te::Operation& op,
+      std::unordered_set<te::Operation, ObjectHash, ObjectEqual>* consumers) const;
+
+  /*!
+   * \brief Get all producers of on operation
+   * \param state The current loop state
+   * \param op The operation
+   * \param producers The return producer set
+   * \note This function propagates the relation for inlined ops
+   */
+  TVM_DLL void GetProducers(
+      const State& state, const te::Operation& op,
+      std::unordered_set<te::Operation, ObjectHash, ObjectEqual>* producers) const;

Review comment:
       ditto

##########
File path: include/tvm/auto_scheduler/compute_dag.h
##########
@@ -0,0 +1,237 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/auto_scheduler/compute_dag.h
+ * \brief The auto-scheduler's computational graph and related program analyses.
+ *
+ * We convert a compute declaration described by `tvm.compute` (could be a single operator or a
+ * subgraph) to a ComputeDAG. It keeps the input/output tensors of the compute declaration,
+ * a list of all operations in the DAG as well as static analysis results for the DAG (e.g. the
+ * total float operation count, consumer/producer relations of each operation stage, whether an
+ * operation stage should be tiled/compute inlined ...). These analyses can help the search policy
+ * to make decisions during search process.
+ * ComputeDAG is also responsible for the interaction between TVM Auto-scheduler `LoopState` and
+ * TVM schedule (e.g. applying the `LoopState` transform steps to TVM schedule, providing
+ * `LoopState` with extra information got from TVM schedule ...).
+ */
+
+#ifndef TVM_AUTO_SCHEDULER_COMPUTE_DAG_H_
+#define TVM_AUTO_SCHEDULER_COMPUTE_DAG_H_
+
+#include <tvm/auto_scheduler/loop_state.h>
+#include <tvm/te/schedule.h>
+
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+namespace tvm {
+namespace auto_scheduler {
+
+/*! \brief Static analysis result for a ComputeDAG */
+class AccessAnalyzerNode : public Object {
+ public:
+  template <class T>
+  using OperationMap = std::unordered_map<te::Operation, T, ObjectHash, ObjectEqual>;
+
+  /*! \brief Map an operation to all operations it reads from.
+   * For each operation pair, use a two-dimentional array to multiple multi-dimentional accesses*/
+  OperationMap<OperationMap<std::vector<std::vector<PrimExpr>>>> read_from;

Review comment:
       It is somewhat annoying to have such nested data structure. What about we defining some other types for `OperationMap<OperationMap<T>>`?

##########
File path: include/tvm/auto_scheduler/compute_dag.h
##########
@@ -0,0 +1,237 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/auto_scheduler/compute_dag.h
+ * \brief The auto-scheduler's computational graph and related program analyses.
+ *
+ * We convert a compute declaration described by `tvm.compute` (could be a single operator or a
+ * subgraph) to a ComputeDAG. It keeps the input/output tensors of the compute declaration,
+ * a list of all operations in the DAG as well as static analysis results for the DAG (e.g. the
+ * total float operation count, consumer/producer relations of each operation stage, whether an
+ * operation stage should be tiled/compute inlined ...). These analyses can help the search policy
+ * to make decisions during search process.
+ * ComputeDAG is also responsible for the interaction between TVM Auto-scheduler `LoopState` and
+ * TVM schedule (e.g. applying the `LoopState` transform steps to TVM schedule, providing
+ * `LoopState` with extra information got from TVM schedule ...).
+ */
+
+#ifndef TVM_AUTO_SCHEDULER_COMPUTE_DAG_H_
+#define TVM_AUTO_SCHEDULER_COMPUTE_DAG_H_
+
+#include <tvm/auto_scheduler/loop_state.h>
+#include <tvm/te/schedule.h>
+
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+namespace tvm {
+namespace auto_scheduler {
+
+/*! \brief Static analysis result for a ComputeDAG */
+class AccessAnalyzerNode : public Object {
+ public:
+  template <class T>
+  using OperationMap = std::unordered_map<te::Operation, T, ObjectHash, ObjectEqual>;
+
+  /*! \brief Map an operation to all operations it reads from.
+   * For each operation pair, use a two-dimentional array to multiple multi-dimentional accesses*/

Review comment:
       Elaborate "multi-dimensional access", like, the inner `std::vector` represents the multi-dimensional indices. Also use a specific type alias for `std::vector<PrimExpr>`

##########
File path: include/tvm/auto_scheduler/compute_dag.h
##########
@@ -0,0 +1,242 @@
+/*r
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/auto_scheduler/compute_dag.h
+ * \brief The auto-scheduler's computational graph and related program analyses.
+ *
+ * We convert a compute declaration described by `tvm.compute` (could be a single operator or a
+ * subgraph) to a ComputeDAG. It keeps the input/output tensors of the compute declaration,
+ * a list of all operations in the DAG as well as static analysis results for the DAG (e.g. the
+ * total float operation count, consumer/producer relations of each operation stage, whether an
+ * operation stage should be tiled/compute inlined ...). These analyses can help the search policy
+ * to make decisions during search process.
+ * ComputeDAG is also responsible for the interaction between TVM Auto-scheduler `LoopState` and
+ * TVM schedule (e.g. applying the `LoopState` transform steps to TVM schedule, providing
+ * `LoopState` with extra information got from TVM schedule ...).
+ */
+
+#ifndef TVM_AUTO_SCHEDULER_COMPUTE_DAG_H_
+#define TVM_AUTO_SCHEDULER_COMPUTE_DAG_H_
+
+#include <tvm/auto_scheduler/loop_state.h>
+#include <tvm/runtime/c_runtime_api.h>
+#include <tvm/te/schedule.h>
+
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+namespace tvm {
+namespace auto_scheduler {
+
+/*! \brief Static analysis result for a ComputeDAG */
+class AccessAnalyzerNode : public Object {
+ public:
+  template <class T>
+  using OperationMap = std::unordered_map<te::Operation, T, ObjectHash, ObjectEqual>;
+
+  /*! \brief Map an operation to all operations it reads from.
+   * For each operation pair, use a two-dimentional array to multiple multi-dimentional accesses*/
+  OperationMap<OperationMap<std::vector<std::vector<PrimExpr>>>> read_from;
+  /*! \brief Map an operation to all operations it is read by.
+   * For each operation pair, use a two-dimentional array to multiple multi-dimentional accesses*/
+  OperationMap<OperationMap<std::vector<std::vector<PrimExpr>>>> read_by;
+  /*! \brief Store the number of common outer iterators for operation pairs that have
+   * read-write relations. */
+  OperationMap<OperationMap<int>> num_common_outer_iterators;
+  /*! \brief Store whether the operation is injective */
+  OperationMap<bool> is_injective;
+  /*! \brief Store whether the operation is strictly-inlineable */
+  OperationMap<bool> is_strict_inlineable;
+  /*! \brief Store whether the operation needs multi-level tiling */
+  OperationMap<bool> needs_multi_level_tiling;
+  /*! \brief Store whether the operation is an output operation */
+  OperationMap<bool> is_output;
+  /*! \brief Store the topological order of operations */
+  Array<te::Operation> ops_topo_order;
+
+  static constexpr const char* _type_key = "auto_scheduler.AccessAnalyzer";
+  TVM_DECLARE_FINAL_OBJECT_INFO(AccessAnalyzerNode, Object);
+};
+
+/*!
+ * \brief Managed reference to AccessAnalyzerNode.
+ * \sa AccessAnalyzerNode
+ */
+class AccessAnalyzer : public ObjectRef {
+ public:
+  explicit AccessAnalyzer(const Array<te::Tensor>& tensors);
+
+  /*!
+   * \brief Return whether this operation needs multi-level tiling
+   * \param op The operation
+   */
+  TVM_DLL bool NeedsMultiLevelTiling(const te::Operation& op) const;
+
+  /*!
+   * \brief Return whether this operation is an injective operation
+   * \param op The operation
+   */
+  TVM_DLL bool IsInjective(const te::Operation& op) const;
+
+  /*!
+   * \brief Return whether this operation is strictly inlinable
+   * \param op The operation
+   */
+  TVM_DLL bool IsStrictInlineable(const te::Operation& op) const;
+
+  /*!
+   * \brief Return whether this operation is an output op
+   * \param op The operation
+   */
+  TVM_DLL bool IsOutput(const te::Operation& op) const;
+
+  /*!
+   * \brief Get all consumers of on operation
+   * \param state The current loop state
+   * \param op The operation
+   * \param consumers The return consumer set
+   * \note This function propagates the relation for inlined ops
+   */
+  TVM_DLL void GetConsumers(
+      const State& state, const te::Operation& op,
+      std::unordered_set<te::Operation, ObjectHash, ObjectEqual>* consumers) const;

Review comment:
       Shall we just return consumers?
   ```suggestion
     TVM_DLL std::unordered_set<te::Operation, ObjectHash, ObjectEqual>* GetConsumers(
       const State& state, const te::Operation& op) const;
   ```
   

##########
File path: include/tvm/auto_scheduler/compute_dag.h
##########
@@ -0,0 +1,237 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/auto_scheduler/compute_dag.h
+ * \brief The auto-scheduler's computational graph and related program analyses.
+ *
+ * We convert a compute declaration described by `tvm.compute` (could be a single operator or a
+ * subgraph) to a ComputeDAG. It keeps the input/output tensors of the compute declaration,
+ * a list of all operations in the DAG as well as static analysis results for the DAG (e.g. the
+ * total float operation count, consumer/producer relations of each operation stage, whether an
+ * operation stage should be tiled/compute inlined ...). These analyses can help the search policy
+ * to make decisions during search process.
+ * ComputeDAG is also responsible for the interaction between TVM Auto-scheduler `LoopState` and
+ * TVM schedule (e.g. applying the `LoopState` transform steps to TVM schedule, providing
+ * `LoopState` with extra information got from TVM schedule ...).
+ */
+
+#ifndef TVM_AUTO_SCHEDULER_COMPUTE_DAG_H_
+#define TVM_AUTO_SCHEDULER_COMPUTE_DAG_H_
+
+#include <tvm/auto_scheduler/loop_state.h>
+#include <tvm/te/schedule.h>
+
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+namespace tvm {
+namespace auto_scheduler {
+
+/*! \brief Static analysis result for a ComputeDAG */
+class AccessAnalyzerNode : public Object {
+ public:
+  template <class T>
+  using OperationMap = std::unordered_map<te::Operation, T, ObjectHash, ObjectEqual>;

Review comment:
       seems like pointer hashing and equality
   ```suggestion
     using OperationMap = std::unordered_map<te::Operation, T, ObjectPtrHash, ObjectPtrEqual>;
   ```

##########
File path: include/tvm/auto_scheduler/compute_dag.h
##########
@@ -0,0 +1,242 @@
+/*r
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/auto_scheduler/compute_dag.h
+ * \brief The auto-scheduler's computational graph and related program analyses.
+ *
+ * We convert a compute declaration described by `tvm.compute` (could be a single operator or a
+ * subgraph) to a ComputeDAG. It keeps the input/output tensors of the compute declaration,
+ * a list of all operations in the DAG as well as static analysis results for the DAG (e.g. the
+ * total float operation count, consumer/producer relations of each operation stage, whether an
+ * operation stage should be tiled/compute inlined ...). These analyses can help the search policy
+ * to make decisions during search process.
+ * ComputeDAG is also responsible for the interaction between TVM Auto-scheduler `LoopState` and
+ * TVM schedule (e.g. applying the `LoopState` transform steps to TVM schedule, providing
+ * `LoopState` with extra information got from TVM schedule ...).
+ */
+
+#ifndef TVM_AUTO_SCHEDULER_COMPUTE_DAG_H_
+#define TVM_AUTO_SCHEDULER_COMPUTE_DAG_H_
+
+#include <tvm/auto_scheduler/loop_state.h>
+#include <tvm/runtime/c_runtime_api.h>
+#include <tvm/te/schedule.h>
+
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+namespace tvm {
+namespace auto_scheduler {
+
+/*! \brief Static analysis result for a ComputeDAG */
+class AccessAnalyzerNode : public Object {
+ public:
+  template <class T>
+  using OperationMap = std::unordered_map<te::Operation, T, ObjectHash, ObjectEqual>;
+
+  /*! \brief Map an operation to all operations it reads from.
+   * For each operation pair, use a two-dimentional array to multiple multi-dimentional accesses*/
+  OperationMap<OperationMap<std::vector<std::vector<PrimExpr>>>> read_from;
+  /*! \brief Map an operation to all operations it is read by.
+   * For each operation pair, use a two-dimentional array to multiple multi-dimentional accesses*/
+  OperationMap<OperationMap<std::vector<std::vector<PrimExpr>>>> read_by;
+  /*! \brief Store the number of common outer iterators for operation pairs that have
+   * read-write relations. */
+  OperationMap<OperationMap<int>> num_common_outer_iterators;
+  /*! \brief Store whether the operation is injective */
+  OperationMap<bool> is_injective;
+  /*! \brief Store whether the operation is strictly-inlineable */
+  OperationMap<bool> is_strict_inlineable;
+  /*! \brief Store whether the operation needs multi-level tiling */
+  OperationMap<bool> needs_multi_level_tiling;
+  /*! \brief Store whether the operation is an output operation */
+  OperationMap<bool> is_output;
+  /*! \brief Store the topological order of operations */
+  Array<te::Operation> ops_topo_order;
+
+  static constexpr const char* _type_key = "auto_scheduler.AccessAnalyzer";
+  TVM_DECLARE_FINAL_OBJECT_INFO(AccessAnalyzerNode, Object);
+};
+
+/*!
+ * \brief Managed reference to AccessAnalyzerNode.
+ * \sa AccessAnalyzerNode
+ */
+class AccessAnalyzer : public ObjectRef {
+ public:
+  explicit AccessAnalyzer(const Array<te::Tensor>& tensors);
+
+  /*!
+   * \brief Return whether this operation needs multi-level tiling
+   * \param op The operation
+   */
+  TVM_DLL bool NeedsMultiLevelTiling(const te::Operation& op) const;
+
+  /*!
+   * \brief Return whether this operation is an injective operation
+   * \param op The operation
+   */
+  TVM_DLL bool IsInjective(const te::Operation& op) const;
+
+  /*!
+   * \brief Return whether this operation is strictly inlinable
+   * \param op The operation
+   */
+  TVM_DLL bool IsStrictInlineable(const te::Operation& op) const;
+
+  /*!
+   * \brief Return whether this operation is an output op
+   * \param op The operation
+   */
+  TVM_DLL bool IsOutput(const te::Operation& op) const;
+
+  /*!
+   * \brief Get all consumers of on operation
+   * \param state The current loop state
+   * \param op The operation
+   * \param consumers The return consumer set
+   * \note This function propagates the relation for inlined ops
+   */
+  TVM_DLL void GetConsumers(
+      const State& state, const te::Operation& op,
+      std::unordered_set<te::Operation, ObjectHash, ObjectEqual>* consumers) const;
+
+  /*!
+   * \brief Get all producers of on operation
+   * \param state The current loop state
+   * \param op The operation
+   * \param producers The return producer set
+   * \note This function propagates the relation for inlined ops
+   */
+  TVM_DLL void GetProducers(
+      const State& state, const te::Operation& op,
+      std::unordered_set<te::Operation, ObjectHash, ObjectEqual>* producers) const;
+
+  /*!
+   * \brief Get all direct producers of on operation
+   * \param op The operation
+   * \param producers The return producer set
+   * \note This function DOES NOT propagate the relation for inlined ops
+   */
+  TVM_DLL void GetDirectProducers(
+      const te::Operation& op,
+      std::unordered_set<te::Operation, ObjectHash, ObjectEqual>* producers) const;

Review comment:
       ditto




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] merrymercy commented on a change in pull request #6103: [Ansor][AutoTVM v2.0] Phase 1: Access Analyzer

Posted by GitBox <gi...@apache.org>.
merrymercy commented on a change in pull request #6103:
URL: https://github.com/apache/incubator-tvm/pull/6103#discussion_r457820848



##########
File path: src/auto_scheduler/compute_dag.h
##########
@@ -37,13 +37,126 @@
 
 #include <tvm/te/schedule.h>
 
+#include <unordered_map>
+#include <unordered_set>
 #include <utility>
+#include <vector>
 
 #include "loop_state.h"
 
 namespace tvm {
 namespace auto_scheduler {
 
+/*! \brief Static analysis result for a ComputeDAG */
+class AccessAnalyzerNode : public Object {
+ public:
+  template <class T>
+  using OperationMap = std::unordered_map<te::Operation, T, ObjectHash, ObjectEqual>;
+
+  /*! \brief Map an operation to all operations it reads from.
+   * For each operation pair, use a two-dimentional array to multiple multi-dimentional accesses*/
+  OperationMap<OperationMap<std::vector<std::vector<PrimExpr>>>> read_from;

Review comment:
       These fields will only be accessed by c++ code, so it is okay to use std::unordered_map and std::vector.
   No body should write a search policy in python.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] junrushao1994 commented on a change in pull request #6103: [Ansor][AutoTVM v2.0] Phase 1: Access Analyzer

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on a change in pull request #6103:
URL: https://github.com/apache/incubator-tvm/pull/6103#discussion_r459183926



##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -126,6 +555,7 @@ class FlopEstimator : public ExprFunctor<double(const PrimExpr& n)> {
           fail_ = true;
           break;
         }
+        cur_type_code_ = pop->output_dtype(0).code();

Review comment:
       @merrymercy I agree that it is totally okay if we can just use rough information (in fact it is highly non-trivial to get accurate info without backend info). My point is that `cur_type_code_` comes from the dtype of output, but it is totally possible that a compute dag contains computation of different type code (int8, fp16)




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] merrymercy commented on a change in pull request #6103: [Ansor][AutoTVM v2.0] Phase 1: Access Analyzer

Posted by GitBox <gi...@apache.org>.
merrymercy commented on a change in pull request #6103:
URL: https://github.com/apache/incubator-tvm/pull/6103#discussion_r459143595



##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -126,6 +555,7 @@ class FlopEstimator : public ExprFunctor<double(const PrimExpr& n)> {
           fail_ = true;
           break;
         }
+        cur_type_code_ = pop->output_dtype(0).code();

Review comment:
       The FLOP information is only used for debug. It is okay to just give a rough estimation.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] junrushao1994 commented on a change in pull request #6103: [Ansor][AutoTVM v2.0] Phase 1: Access Analyzer

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on a change in pull request #6103:
URL: https://github.com/apache/incubator-tvm/pull/6103#discussion_r458567027



##########
File path: include/tvm/auto_scheduler/compute_dag.h
##########
@@ -0,0 +1,242 @@
+/*r
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/auto_scheduler/compute_dag.h
+ * \brief The auto-scheduler's computational graph and related program analyses.
+ *
+ * We convert a compute declaration described by `tvm.compute` (could be a single operator or a
+ * subgraph) to a ComputeDAG. It keeps the input/output tensors of the compute declaration,
+ * a list of all operations in the DAG as well as static analysis results for the DAG (e.g. the
+ * total float operation count, consumer/producer relations of each operation stage, whether an
+ * operation stage should be tiled/compute inlined ...). These analyses can help the search policy
+ * to make decisions during search process.
+ * ComputeDAG is also responsible for the interaction between TVM Auto-scheduler `LoopState` and
+ * TVM schedule (e.g. applying the `LoopState` transform steps to TVM schedule, providing
+ * `LoopState` with extra information got from TVM schedule ...).
+ */
+
+#ifndef TVM_AUTO_SCHEDULER_COMPUTE_DAG_H_
+#define TVM_AUTO_SCHEDULER_COMPUTE_DAG_H_
+
+#include <tvm/auto_scheduler/loop_state.h>
+#include <tvm/runtime/c_runtime_api.h>
+#include <tvm/te/schedule.h>
+
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+namespace tvm {
+namespace auto_scheduler {
+
+/*! \brief Static analysis result for a ComputeDAG */
+class AccessAnalyzerNode : public Object {

Review comment:
       I feel like AccessAnalyzer itself can be a much more principled and extensible component of the system, so shall we put it in a separate file instead?




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] merrymercy commented on a change in pull request #6103: [Ansor][AutoTVM v2.0] Phase 1: Access Analyzer

Posted by GitBox <gi...@apache.org>.
merrymercy commented on a change in pull request #6103:
URL: https://github.com/apache/incubator-tvm/pull/6103#discussion_r459691514



##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -126,6 +555,7 @@ class FlopEstimator : public ExprFunctor<double(const PrimExpr& n)> {
           fail_ = true;
           break;
         }
+        cur_type_code_ = pop->output_dtype(0).code();

Review comment:
       It is not common at all. The common case is either int8->int16/32 or fp16/fp16->fp32




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] junrushao1994 edited a comment on pull request #6103: [Ansor][AutoTVM v2.0] Phase 1: Access Analyzer

Posted by GitBox <gi...@apache.org>.
junrushao1994 edited a comment on pull request #6103:
URL: https://github.com/apache/incubator-tvm/pull/6103#issuecomment-661629932


   It is one of the core part of the system. Will take a look later tomorrow night.


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] comaniac commented on a change in pull request #6103: [Ansor][AutoTVM v2.0] Phase 1: Access Analyzer

Posted by GitBox <gi...@apache.org>.
comaniac commented on a change in pull request #6103:
URL: https://github.com/apache/incubator-tvm/pull/6103#discussion_r458995260



##########
File path: include/tvm/auto_scheduler/compute_dag.h
##########
@@ -0,0 +1,242 @@
+/*r
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/auto_scheduler/compute_dag.h
+ * \brief The auto-scheduler's computational graph and related program analyses.
+ *
+ * We convert a compute declaration described by `tvm.compute` (could be a single operator or a
+ * subgraph) to a ComputeDAG. It keeps the input/output tensors of the compute declaration,
+ * a list of all operations in the DAG as well as static analysis results for the DAG (e.g. the
+ * total float operation count, consumer/producer relations of each operation stage, whether an
+ * operation stage should be tiled/compute inlined ...). These analyses can help the search policy
+ * to make decisions during search process.
+ * ComputeDAG is also responsible for the interaction between TVM Auto-scheduler `LoopState` and
+ * TVM schedule (e.g. applying the `LoopState` transform steps to TVM schedule, providing
+ * `LoopState` with extra information got from TVM schedule ...).
+ */
+
+#ifndef TVM_AUTO_SCHEDULER_COMPUTE_DAG_H_
+#define TVM_AUTO_SCHEDULER_COMPUTE_DAG_H_
+
+#include <tvm/auto_scheduler/loop_state.h>
+#include <tvm/runtime/c_runtime_api.h>
+#include <tvm/te/schedule.h>
+
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+namespace tvm {
+namespace auto_scheduler {
+
+/*! \brief Static analysis result for a ComputeDAG */
+class AccessAnalyzerNode : public Object {

Review comment:
       Agree. Maybe have an `analysis.h` to expect more analyzers in the future. On the other hand, another direction might be renaming AccessAnalyzer to ComputeDAGAnalyzer, because it provides some APIs for the ops in a compute DAG, such as `NeedMultiLevelTiling`, `IsOutput`, etc.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] merrymercy merged pull request #6103: [Ansor][AutoTVM v2.0] Phase 1: Access Analyzer

Posted by GitBox <gi...@apache.org>.
merrymercy merged pull request #6103:
URL: https://github.com/apache/incubator-tvm/pull/6103


   


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] merrymercy commented on a change in pull request #6103: [Ansor][AutoTVM v2.0] Phase 1: Access Analyzer

Posted by GitBox <gi...@apache.org>.
merrymercy commented on a change in pull request #6103:
URL: https://github.com/apache/incubator-tvm/pull/6103#discussion_r459143595



##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -126,6 +555,7 @@ class FlopEstimator : public ExprFunctor<double(const PrimExpr& n)> {
           fail_ = true;
           break;
         }
+        cur_type_code_ = pop->output_dtype(0).code();

Review comment:
       The FLOP information is only used for printing and debugging. It is okay to just give a rough estimation.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] merrymercy commented on a change in pull request #6103: [Ansor][AutoTVM v2.0] Phase 1: Access Analyzer

Posted by GitBox <gi...@apache.org>.
merrymercy commented on a change in pull request #6103:
URL: https://github.com/apache/incubator-tvm/pull/6103#discussion_r459132211



##########
File path: include/tvm/auto_scheduler/compute_dag.h
##########
@@ -0,0 +1,237 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/auto_scheduler/compute_dag.h
+ * \brief The auto-scheduler's computational graph and related program analyses.
+ *
+ * We convert a compute declaration described by `tvm.compute` (could be a single operator or a
+ * subgraph) to a ComputeDAG. It keeps the input/output tensors of the compute declaration,
+ * a list of all operations in the DAG as well as static analysis results for the DAG (e.g. the
+ * total float operation count, consumer/producer relations of each operation stage, whether an
+ * operation stage should be tiled/compute inlined ...). These analyses can help the search policy
+ * to make decisions during search process.
+ * ComputeDAG is also responsible for the interaction between TVM Auto-scheduler `LoopState` and
+ * TVM schedule (e.g. applying the `LoopState` transform steps to TVM schedule, providing
+ * `LoopState` with extra information got from TVM schedule ...).
+ */
+
+#ifndef TVM_AUTO_SCHEDULER_COMPUTE_DAG_H_
+#define TVM_AUTO_SCHEDULER_COMPUTE_DAG_H_
+
+#include <tvm/auto_scheduler/loop_state.h>
+#include <tvm/te/schedule.h>
+
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+namespace tvm {
+namespace auto_scheduler {
+
+/*! \brief Static analysis result for a ComputeDAG */
+class AccessAnalyzerNode : public Object {
+ public:
+  template <class T>
+  using OperationMap = std::unordered_map<te::Operation, T, ObjectHash, ObjectEqual>;
+
+  /*! \brief Map an operation to all operations it reads from.
+   * For each operation pair, use a two-dimentional array to multiple multi-dimentional accesses*/
+  OperationMap<OperationMap<std::vector<std::vector<PrimExpr>>>> read_from;
+  /*! \brief Map an operation to all operations it is read by.
+   * For each operation pair, use a two-dimentional array to multiple multi-dimentional accesses*/
+  OperationMap<OperationMap<std::vector<std::vector<PrimExpr>>>> read_by;
+  /*! \brief Store the number of common outer iterators for operation pairs that have
+   * read-write relations. */
+  OperationMap<OperationMap<int>> num_common_outer_iterators;
+  /*! \brief Store whether the operation is injective */
+  OperationMap<bool> is_injective;
+  /*! \brief Store whether the operation is strictly-inlineable */
+  OperationMap<bool> is_strict_inlineable;
+  /*! \brief Store whether the operation needs multi-level tiling */
+  OperationMap<bool> needs_multi_level_tiling;
+  /*! \brief Store whether the operation is an output operation */
+  OperationMap<bool> is_output;
+  /*! \brief Store the topological order of operations */
+  Array<te::Operation> ops_topo_order;

Review comment:
       Good catch. I found it is not used in the code..




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] merrymercy commented on pull request #6103: [Ansor][AutoTVM v2.0] Phase 1: Access Analyzer

Posted by GitBox <gi...@apache.org>.
merrymercy commented on pull request #6103:
URL: https://github.com/apache/incubator-tvm/pull/6103#issuecomment-661910948


   This PR is ready for review. But It has some dependency on #6073. I will make it pass the ci test after #6073 is merged


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] junrushao1994 commented on a change in pull request #6103: [Ansor][AutoTVM v2.0] Phase 1: Access Analyzer

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on a change in pull request #6103:
URL: https://github.com/apache/incubator-tvm/pull/6103#discussion_r459182796



##########
File path: include/tvm/auto_scheduler/compute_dag.h
##########
@@ -0,0 +1,237 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/auto_scheduler/compute_dag.h
+ * \brief The auto-scheduler's computational graph and related program analyses.
+ *
+ * We convert a compute declaration described by `tvm.compute` (could be a single operator or a
+ * subgraph) to a ComputeDAG. It keeps the input/output tensors of the compute declaration,
+ * a list of all operations in the DAG as well as static analysis results for the DAG (e.g. the
+ * total float operation count, consumer/producer relations of each operation stage, whether an
+ * operation stage should be tiled/compute inlined ...). These analyses can help the search policy
+ * to make decisions during search process.
+ * ComputeDAG is also responsible for the interaction between TVM Auto-scheduler `LoopState` and
+ * TVM schedule (e.g. applying the `LoopState` transform steps to TVM schedule, providing
+ * `LoopState` with extra information got from TVM schedule ...).
+ */
+
+#ifndef TVM_AUTO_SCHEDULER_COMPUTE_DAG_H_
+#define TVM_AUTO_SCHEDULER_COMPUTE_DAG_H_
+
+#include <tvm/auto_scheduler/loop_state.h>
+#include <tvm/te/schedule.h>
+
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+namespace tvm {
+namespace auto_scheduler {
+
+/*! \brief Static analysis result for a ComputeDAG */
+class AccessAnalyzerNode : public Object {
+ public:
+  template <class T>
+  using OperationMap = std::unordered_map<te::Operation, T, ObjectHash, ObjectEqual>;
+
+  /*! \brief Map an operation to all operations it reads from.
+   * For each operation pair, use a two-dimentional array to multiple multi-dimentional accesses*/

Review comment:
       @merrymercy let's do "using MultiDimIdx = std::vector<PrimExpr>"




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] merrymercy commented on a change in pull request #6103: [Ansor][AutoTVM v2.0] Phase 1: Access Analyzer

Posted by GitBox <gi...@apache.org>.
merrymercy commented on a change in pull request #6103:
URL: https://github.com/apache/incubator-tvm/pull/6103#discussion_r459132211



##########
File path: include/tvm/auto_scheduler/compute_dag.h
##########
@@ -0,0 +1,237 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/auto_scheduler/compute_dag.h
+ * \brief The auto-scheduler's computational graph and related program analyses.
+ *
+ * We convert a compute declaration described by `tvm.compute` (could be a single operator or a
+ * subgraph) to a ComputeDAG. It keeps the input/output tensors of the compute declaration,
+ * a list of all operations in the DAG as well as static analysis results for the DAG (e.g. the
+ * total float operation count, consumer/producer relations of each operation stage, whether an
+ * operation stage should be tiled/compute inlined ...). These analyses can help the search policy
+ * to make decisions during search process.
+ * ComputeDAG is also responsible for the interaction between TVM Auto-scheduler `LoopState` and
+ * TVM schedule (e.g. applying the `LoopState` transform steps to TVM schedule, providing
+ * `LoopState` with extra information got from TVM schedule ...).
+ */
+
+#ifndef TVM_AUTO_SCHEDULER_COMPUTE_DAG_H_
+#define TVM_AUTO_SCHEDULER_COMPUTE_DAG_H_
+
+#include <tvm/auto_scheduler/loop_state.h>
+#include <tvm/te/schedule.h>
+
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+namespace tvm {
+namespace auto_scheduler {
+
+/*! \brief Static analysis result for a ComputeDAG */
+class AccessAnalyzerNode : public Object {
+ public:
+  template <class T>
+  using OperationMap = std::unordered_map<te::Operation, T, ObjectHash, ObjectEqual>;
+
+  /*! \brief Map an operation to all operations it reads from.
+   * For each operation pair, use a two-dimentional array to multiple multi-dimentional accesses*/
+  OperationMap<OperationMap<std::vector<std::vector<PrimExpr>>>> read_from;
+  /*! \brief Map an operation to all operations it is read by.
+   * For each operation pair, use a two-dimentional array to multiple multi-dimentional accesses*/
+  OperationMap<OperationMap<std::vector<std::vector<PrimExpr>>>> read_by;
+  /*! \brief Store the number of common outer iterators for operation pairs that have
+   * read-write relations. */
+  OperationMap<OperationMap<int>> num_common_outer_iterators;
+  /*! \brief Store whether the operation is injective */
+  OperationMap<bool> is_injective;
+  /*! \brief Store whether the operation is strictly-inlineable */
+  OperationMap<bool> is_strict_inlineable;
+  /*! \brief Store whether the operation needs multi-level tiling */
+  OperationMap<bool> needs_multi_level_tiling;
+  /*! \brief Store whether the operation is an output operation */
+  OperationMap<bool> is_output;
+  /*! \brief Store the topological order of operations */
+  Array<te::Operation> ops_topo_order;

Review comment:
       They are the same




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] merrymercy commented on pull request #6103: [Ansor][AutoTVM v2.0] Phase 1: Access Analyzer

Posted by GitBox <gi...@apache.org>.
merrymercy commented on pull request #6103:
URL: https://github.com/apache/incubator-tvm/pull/6103#issuecomment-661615992


   cc @junrushao1994 @comaniac @jcf94 @FrozenGene @tqchen 


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] merrymercy commented on a change in pull request #6103: [Ansor][AutoTVM v2.0] Phase 1: Access Analyzer

Posted by GitBox <gi...@apache.org>.
merrymercy commented on a change in pull request #6103:
URL: https://github.com/apache/incubator-tvm/pull/6103#discussion_r459132211



##########
File path: include/tvm/auto_scheduler/compute_dag.h
##########
@@ -0,0 +1,237 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/auto_scheduler/compute_dag.h
+ * \brief The auto-scheduler's computational graph and related program analyses.
+ *
+ * We convert a compute declaration described by `tvm.compute` (could be a single operator or a
+ * subgraph) to a ComputeDAG. It keeps the input/output tensors of the compute declaration,
+ * a list of all operations in the DAG as well as static analysis results for the DAG (e.g. the
+ * total float operation count, consumer/producer relations of each operation stage, whether an
+ * operation stage should be tiled/compute inlined ...). These analyses can help the search policy
+ * to make decisions during search process.
+ * ComputeDAG is also responsible for the interaction between TVM Auto-scheduler `LoopState` and
+ * TVM schedule (e.g. applying the `LoopState` transform steps to TVM schedule, providing
+ * `LoopState` with extra information got from TVM schedule ...).
+ */
+
+#ifndef TVM_AUTO_SCHEDULER_COMPUTE_DAG_H_
+#define TVM_AUTO_SCHEDULER_COMPUTE_DAG_H_
+
+#include <tvm/auto_scheduler/loop_state.h>
+#include <tvm/te/schedule.h>
+
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+namespace tvm {
+namespace auto_scheduler {
+
+/*! \brief Static analysis result for a ComputeDAG */
+class AccessAnalyzerNode : public Object {
+ public:
+  template <class T>
+  using OperationMap = std::unordered_map<te::Operation, T, ObjectHash, ObjectEqual>;
+
+  /*! \brief Map an operation to all operations it reads from.
+   * For each operation pair, use a two-dimentional array to multiple multi-dimentional accesses*/
+  OperationMap<OperationMap<std::vector<std::vector<PrimExpr>>>> read_from;
+  /*! \brief Map an operation to all operations it is read by.
+   * For each operation pair, use a two-dimentional array to multiple multi-dimentional accesses*/
+  OperationMap<OperationMap<std::vector<std::vector<PrimExpr>>>> read_by;
+  /*! \brief Store the number of common outer iterators for operation pairs that have
+   * read-write relations. */
+  OperationMap<OperationMap<int>> num_common_outer_iterators;
+  /*! \brief Store whether the operation is injective */
+  OperationMap<bool> is_injective;
+  /*! \brief Store whether the operation is strictly-inlineable */
+  OperationMap<bool> is_strict_inlineable;
+  /*! \brief Store whether the operation needs multi-level tiling */
+  OperationMap<bool> needs_multi_level_tiling;
+  /*! \brief Store whether the operation is an output operation */
+  OperationMap<bool> is_output;
+  /*! \brief Store the topological order of operations */
+  Array<te::Operation> ops_topo_order;

Review comment:
       They are the same. I store it in AccessAnalyzer because it is used first here.
   In the constructor of ComputeDAG, it copies `ops_topo_order` as its `ops`




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] merrymercy commented on a change in pull request #6103: [Ansor][AutoTVM v2.0] Phase 1: Access Analyzer

Posted by GitBox <gi...@apache.org>.
merrymercy commented on a change in pull request #6103:
URL: https://github.com/apache/incubator-tvm/pull/6103#discussion_r459749921



##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -114,7 +118,432 @@ Array<te::Operation> TopoSortOps(const Array<te::Tensor>& tensors) {
   return ops;
 }
 
-// Estimate number of float operations in an expression
+// Extract all tensor accesses in an expr
+class TensorAccessExtractor : public StmtExprVisitor {
+ public:
+  void Extract(PrimExpr expr) { this->VisitExpr(expr); }
+
+  void VisitExpr_(const CallNode* op) final {
+    if (op->op.same_as(builtin::if_then_else())) {
+      has_branch = true;
+    }
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitExpr_(const ProducerLoadNode* op) final {
+    buf_accesses[Downcast<te::Tensor>(op->producer)->op].emplace_back(op->indices.begin(),
+                                                                      op->indices.end());
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitStmt_(const IfThenElseNode* op) final {
+    has_branch = true;
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitExpr_(const SelectNode* op) final {
+    has_branch = true;
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  OperationMap<std::vector<std::vector<PrimExpr>>> buf_accesses;
+  bool has_branch{false};
+};
+
+// Returns whether the expr equals to the var with a const shift
+bool IsConstShiftEqual(const Var& var, const PrimExpr& expr) {
+  if (auto pv = expr.as<VarNode>()) {
+    return pv == var.get();
+  } else if (auto padd = expr.as<AddNode>()) {
+    return ((padd->a.get() == var.get() && padd->b->IsInstance<IntImmNode>()) ||
+            (padd->b.get() == var.get() && padd->a->IsInstance<IntImmNode>()));
+  } else if (auto psub = expr.as<SubNode>()) {
+    return ((psub->a.get() == var.get() && psub->b->IsInstance<IntImmNode>()) ||
+            (psub->b.get() == var.get() && psub->a->IsInstance<IntImmNode>()));
+  } else {
+    return false;
+  }
+}
+
+// Return whether the access is injective
+bool IsInjective(const te::Operation& op, const std::vector<PrimExpr>& index, bool* axis_missing,
+                 bool* axis_duplicated, bool* same_order) {
+  auto cop = op.as<te::ComputeOpNode>();
+  if (cop == nullptr) {
+    return false;
+  }
+
+  std::vector<int> index_to_var_idx;
+  std::vector<int> var_idx_ct(cop->axis.size(), 0);
+
+  for (const auto& expr : index) {
+    if (!is_const_int(expr)) {
+      bool found = false;
+      for (size_t i = 0; i < cop->axis.size(); ++i) {
+        if (IsConstShiftEqual(cop->axis[i]->var, expr)) {
+          index_to_var_idx.push_back(i);
+          var_idx_ct[i]++;
+          found = true;
+          break;
+        }
+      }
+      if (!found) {
+        return false;
+      }
+    }
+  }
+
+  *axis_missing = false;     // Some axes are missing
+  *axis_duplicated = false;  // Some axes appear more than once
+  *same_order = true;        // The axis order is the same as op->axis
+  for (int ct : var_idx_ct) {
+    if (ct == 0) {
+      *axis_missing = true;
+    } else if (ct > 1) {
+      *axis_duplicated = true;
+    }
+  }
+  for (size_t i = 1; i < index_to_var_idx.size(); ++i) {
+    if (index_to_var_idx[i] < index_to_var_idx[i - 1]) {
+      *same_order = false;
+      break;
+    }
+  }
+
+  return true;
+}
+
+// Gather all VarNodes in an expr
+static void GatherVars(const PrimExpr& expr, std::unordered_set<const VarNode*>* vars) {
+  PostOrderVisit(expr, [&vars](const ObjectRef& node) {
+    if (const VarNode* op = node.as<VarNode>()) {
+      vars->insert(op);
+    }
+  });
+}
+
+// Check whether an expr has expensive operations (e.g. exp)
+static bool HasExpensiveOp(const PrimExpr& expr) {
+  bool found = false;
+  PostOrderVisit(expr, [&found](const ObjectRef& node) {
+    if (const CallNode* op = node.as<CallNode>()) {
+      if (op->op.as<OpNode>()->name == "tir.exp") {
+        found = true;
+      }
+    }
+  });
+  return found;
+}
+
+AccessAnalyzer::AccessAnalyzer(const Array<te::Tensor>& tensors) {
+  auto node = make_object<AccessAnalyzerNode>();
+  OperationMap<bool> has_branch;
+
+  // get all ops
+  node->ops_topo_order = TopoSortOps(tensors);
+
+  arith::Analyzer analyzer;
+
+  // build read & write access map
+  for (const auto& op : node->ops_topo_order) {
+    if (op->IsInstance<te::PlaceholderOpNode>()) {
+      node->read_from[op] = OperationMap<std::vector<std::vector<PrimExpr>>>();
+    } else if (auto cop = op.as<te::ComputeOpNode>()) {
+      TensorAccessExtractor extractor;
+      for (const auto& exp : cop->body) {
+        extractor.Extract(exp);
+      }
+
+      // read_by and read_from map
+      for (const auto& iter : extractor.buf_accesses) {
+        std::vector<std::vector<PrimExpr>>& accesses = node->read_by[iter.first][op];
+        accesses.insert(accesses.begin(), iter.second.begin(), iter.second.end());
+      }
+
+      node->read_from[op] = std::move(extractor.buf_accesses);
+      has_branch[op] = extractor.has_branch;
+
+      // compute number of common outer iterators
+      for (const auto& pair : node->read_from[op]) {
+        const te::Operation& producer = pair.first;
+        const std::vector<std::vector<PrimExpr>>& access_list = pair.second;
+        const Array<PrimExpr>& output_shape = op->output_shape(0);
+        const Array<PrimExpr>& producer_shape = producer->output_shape(0);

Review comment:
       Even with multiple outputs, the shape will be the same




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] merrymercy commented on pull request #6103: [Ansor][AutoTVM v2.0] Phase 1: Access Analyzer

Posted by GitBox <gi...@apache.org>.
merrymercy commented on pull request #6103:
URL: https://github.com/apache/incubator-tvm/pull/6103#issuecomment-663276179


   @junrushao1994 @jcf94 @comaniac  Most of the comments are addressed. I added more doc and make the name convention more consistent and meaningful. Please take another look


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] jcf94 edited a comment on pull request #6103: [Ansor][AutoTVM v2.0] Phase 1: Access Analyzer

Posted by GitBox <gi...@apache.org>.
jcf94 edited a comment on pull request #6103:
URL: https://github.com/apache/incubator-tvm/pull/6103#issuecomment-661922011


   > This PR is ready for review. But It has some dependency on #6073. I will make it pass the ci test after #6073 is merged
   
   #6073 's macOS build failure seems caused by some infrastructure error, I think it's ready for merge? I'll continue to upstream the left steps after #6073 .


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] jcf94 commented on a change in pull request #6103: [Ansor][AutoTVM v2.0] Phase 1: Access Analyzer

Posted by GitBox <gi...@apache.org>.
jcf94 commented on a change in pull request #6103:
URL: https://github.com/apache/incubator-tvm/pull/6103#discussion_r458057282



##########
File path: tests/python/unittest/test_auto_scheduler_common.py
##########
@@ -33,7 +33,7 @@ def matmul_auto_scheduler_test(N, M, K):
 
 
 @auto_scheduler.register_workload("matmul_auto_scheduler_test_rename_1")
-def matmul_auto_scheduler_test_rename_0(N, M, K):
+def matmul_auto_scheduler_test_rename_1(N, M, K):

Review comment:
       I was intent to register this function with a different name to test the register method ... Any way, it's unimportant.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] merrymercy commented on a change in pull request #6103: [Ansor][AutoTVM v2.0] Phase 1: Access Analyzer

Posted by GitBox <gi...@apache.org>.
merrymercy commented on a change in pull request #6103:
URL: https://github.com/apache/incubator-tvm/pull/6103#discussion_r459749921



##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -114,7 +118,432 @@ Array<te::Operation> TopoSortOps(const Array<te::Tensor>& tensors) {
   return ops;
 }
 
-// Estimate number of float operations in an expression
+// Extract all tensor accesses in an expr
+class TensorAccessExtractor : public StmtExprVisitor {
+ public:
+  void Extract(PrimExpr expr) { this->VisitExpr(expr); }
+
+  void VisitExpr_(const CallNode* op) final {
+    if (op->op.same_as(builtin::if_then_else())) {
+      has_branch = true;
+    }
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitExpr_(const ProducerLoadNode* op) final {
+    buf_accesses[Downcast<te::Tensor>(op->producer)->op].emplace_back(op->indices.begin(),
+                                                                      op->indices.end());
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitStmt_(const IfThenElseNode* op) final {
+    has_branch = true;
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitExpr_(const SelectNode* op) final {
+    has_branch = true;
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  OperationMap<std::vector<std::vector<PrimExpr>>> buf_accesses;
+  bool has_branch{false};
+};
+
+// Returns whether the expr equals to the var with a const shift
+bool IsConstShiftEqual(const Var& var, const PrimExpr& expr) {
+  if (auto pv = expr.as<VarNode>()) {
+    return pv == var.get();
+  } else if (auto padd = expr.as<AddNode>()) {
+    return ((padd->a.get() == var.get() && padd->b->IsInstance<IntImmNode>()) ||
+            (padd->b.get() == var.get() && padd->a->IsInstance<IntImmNode>()));
+  } else if (auto psub = expr.as<SubNode>()) {
+    return ((psub->a.get() == var.get() && psub->b->IsInstance<IntImmNode>()) ||
+            (psub->b.get() == var.get() && psub->a->IsInstance<IntImmNode>()));
+  } else {
+    return false;
+  }
+}
+
+// Return whether the access is injective
+bool IsInjective(const te::Operation& op, const std::vector<PrimExpr>& index, bool* axis_missing,
+                 bool* axis_duplicated, bool* same_order) {
+  auto cop = op.as<te::ComputeOpNode>();
+  if (cop == nullptr) {
+    return false;
+  }
+
+  std::vector<int> index_to_var_idx;
+  std::vector<int> var_idx_ct(cop->axis.size(), 0);
+
+  for (const auto& expr : index) {
+    if (!is_const_int(expr)) {
+      bool found = false;
+      for (size_t i = 0; i < cop->axis.size(); ++i) {
+        if (IsConstShiftEqual(cop->axis[i]->var, expr)) {
+          index_to_var_idx.push_back(i);
+          var_idx_ct[i]++;
+          found = true;
+          break;
+        }
+      }
+      if (!found) {
+        return false;
+      }
+    }
+  }
+
+  *axis_missing = false;     // Some axes are missing
+  *axis_duplicated = false;  // Some axes appear more than once
+  *same_order = true;        // The axis order is the same as op->axis
+  for (int ct : var_idx_ct) {
+    if (ct == 0) {
+      *axis_missing = true;
+    } else if (ct > 1) {
+      *axis_duplicated = true;
+    }
+  }
+  for (size_t i = 1; i < index_to_var_idx.size(); ++i) {
+    if (index_to_var_idx[i] < index_to_var_idx[i - 1]) {
+      *same_order = false;
+      break;
+    }
+  }
+
+  return true;
+}
+
+// Gather all VarNodes in an expr
+static void GatherVars(const PrimExpr& expr, std::unordered_set<const VarNode*>* vars) {
+  PostOrderVisit(expr, [&vars](const ObjectRef& node) {
+    if (const VarNode* op = node.as<VarNode>()) {
+      vars->insert(op);
+    }
+  });
+}
+
+// Check whether an expr has expensive operations (e.g. exp)
+static bool HasExpensiveOp(const PrimExpr& expr) {
+  bool found = false;
+  PostOrderVisit(expr, [&found](const ObjectRef& node) {
+    if (const CallNode* op = node.as<CallNode>()) {
+      if (op->op.as<OpNode>()->name == "tir.exp") {
+        found = true;
+      }
+    }
+  });
+  return found;
+}
+
+AccessAnalyzer::AccessAnalyzer(const Array<te::Tensor>& tensors) {
+  auto node = make_object<AccessAnalyzerNode>();
+  OperationMap<bool> has_branch;
+
+  // get all ops
+  node->ops_topo_order = TopoSortOps(tensors);
+
+  arith::Analyzer analyzer;
+
+  // build read & write access map
+  for (const auto& op : node->ops_topo_order) {
+    if (op->IsInstance<te::PlaceholderOpNode>()) {
+      node->read_from[op] = OperationMap<std::vector<std::vector<PrimExpr>>>();
+    } else if (auto cop = op.as<te::ComputeOpNode>()) {
+      TensorAccessExtractor extractor;
+      for (const auto& exp : cop->body) {
+        extractor.Extract(exp);
+      }
+
+      // read_by and read_from map
+      for (const auto& iter : extractor.buf_accesses) {
+        std::vector<std::vector<PrimExpr>>& accesses = node->read_by[iter.first][op];
+        accesses.insert(accesses.begin(), iter.second.begin(), iter.second.end());
+      }
+
+      node->read_from[op] = std::move(extractor.buf_accesses);
+      has_branch[op] = extractor.has_branch;
+
+      // compute number of common outer iterators
+      for (const auto& pair : node->read_from[op]) {
+        const te::Operation& producer = pair.first;
+        const std::vector<std::vector<PrimExpr>>& access_list = pair.second;
+        const Array<PrimExpr>& output_shape = op->output_shape(0);
+        const Array<PrimExpr>& producer_shape = producer->output_shape(0);

Review comment:
       The shape will always be the same




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] merrymercy commented on a change in pull request #6103: [Ansor][AutoTVM v2.0] Phase 1: Access Analyzer

Posted by GitBox <gi...@apache.org>.
merrymercy commented on a change in pull request #6103:
URL: https://github.com/apache/incubator-tvm/pull/6103#discussion_r459750828



##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -114,7 +118,432 @@ Array<te::Operation> TopoSortOps(const Array<te::Tensor>& tensors) {
   return ops;
 }
 
-// Estimate number of float operations in an expression
+// Extract all tensor accesses in an expr
+class TensorAccessExtractor : public StmtExprVisitor {
+ public:
+  void Extract(PrimExpr expr) { this->VisitExpr(expr); }
+
+  void VisitExpr_(const CallNode* op) final {
+    if (op->op.same_as(builtin::if_then_else())) {
+      has_branch = true;
+    }
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitExpr_(const ProducerLoadNode* op) final {
+    buf_accesses[Downcast<te::Tensor>(op->producer)->op].emplace_back(op->indices.begin(),
+                                                                      op->indices.end());
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitStmt_(const IfThenElseNode* op) final {
+    has_branch = true;
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitExpr_(const SelectNode* op) final {
+    has_branch = true;
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  OperationMap<std::vector<std::vector<PrimExpr>>> buf_accesses;
+  bool has_branch{false};
+};
+
+// Returns whether the expr equals to the var with a const shift
+bool IsConstShiftEqual(const Var& var, const PrimExpr& expr) {
+  if (auto pv = expr.as<VarNode>()) {
+    return pv == var.get();
+  } else if (auto padd = expr.as<AddNode>()) {
+    return ((padd->a.get() == var.get() && padd->b->IsInstance<IntImmNode>()) ||
+            (padd->b.get() == var.get() && padd->a->IsInstance<IntImmNode>()));
+  } else if (auto psub = expr.as<SubNode>()) {
+    return ((psub->a.get() == var.get() && psub->b->IsInstance<IntImmNode>()) ||
+            (psub->b.get() == var.get() && psub->a->IsInstance<IntImmNode>()));
+  } else {
+    return false;
+  }
+}
+
+// Return whether the access is injective
+bool IsInjective(const te::Operation& op, const std::vector<PrimExpr>& index, bool* axis_missing,
+                 bool* axis_duplicated, bool* same_order) {
+  auto cop = op.as<te::ComputeOpNode>();
+  if (cop == nullptr) {
+    return false;
+  }
+
+  std::vector<int> index_to_var_idx;
+  std::vector<int> var_idx_ct(cop->axis.size(), 0);
+
+  for (const auto& expr : index) {
+    if (!is_const_int(expr)) {
+      bool found = false;
+      for (size_t i = 0; i < cop->axis.size(); ++i) {
+        if (IsConstShiftEqual(cop->axis[i]->var, expr)) {
+          index_to_var_idx.push_back(i);
+          var_idx_ct[i]++;
+          found = true;
+          break;
+        }
+      }
+      if (!found) {
+        return false;
+      }
+    }
+  }
+
+  *axis_missing = false;     // Some axes are missing
+  *axis_duplicated = false;  // Some axes appear more than once
+  *same_order = true;        // The axis order is the same as op->axis
+  for (int ct : var_idx_ct) {
+    if (ct == 0) {
+      *axis_missing = true;
+    } else if (ct > 1) {
+      *axis_duplicated = true;
+    }
+  }
+  for (size_t i = 1; i < index_to_var_idx.size(); ++i) {
+    if (index_to_var_idx[i] < index_to_var_idx[i - 1]) {
+      *same_order = false;
+      break;
+    }
+  }
+
+  return true;
+}
+
+// Gather all VarNodes in an expr
+static void GatherVars(const PrimExpr& expr, std::unordered_set<const VarNode*>* vars) {
+  PostOrderVisit(expr, [&vars](const ObjectRef& node) {
+    if (const VarNode* op = node.as<VarNode>()) {
+      vars->insert(op);
+    }
+  });
+}
+
+// Check whether an expr has expensive operations (e.g. exp)
+static bool HasExpensiveOp(const PrimExpr& expr) {
+  bool found = false;
+  PostOrderVisit(expr, [&found](const ObjectRef& node) {
+    if (const CallNode* op = node.as<CallNode>()) {
+      if (op->op.as<OpNode>()->name == "tir.exp") {
+        found = true;
+      }
+    }
+  });
+  return found;
+}
+
+AccessAnalyzer::AccessAnalyzer(const Array<te::Tensor>& tensors) {
+  auto node = make_object<AccessAnalyzerNode>();
+  OperationMap<bool> has_branch;
+
+  // get all ops
+  node->ops_topo_order = TopoSortOps(tensors);
+
+  arith::Analyzer analyzer;
+
+  // build read & write access map
+  for (const auto& op : node->ops_topo_order) {
+    if (op->IsInstance<te::PlaceholderOpNode>()) {
+      node->read_from[op] = OperationMap<std::vector<std::vector<PrimExpr>>>();
+    } else if (auto cop = op.as<te::ComputeOpNode>()) {
+      TensorAccessExtractor extractor;
+      for (const auto& exp : cop->body) {
+        extractor.Extract(exp);
+      }
+
+      // read_by and read_from map
+      for (const auto& iter : extractor.buf_accesses) {
+        std::vector<std::vector<PrimExpr>>& accesses = node->read_by[iter.first][op];
+        accesses.insert(accesses.begin(), iter.second.begin(), iter.second.end());
+      }
+
+      node->read_from[op] = std::move(extractor.buf_accesses);
+      has_branch[op] = extractor.has_branch;
+
+      // compute number of common outer iterators
+      for (const auto& pair : node->read_from[op]) {
+        const te::Operation& producer = pair.first;
+        const std::vector<std::vector<PrimExpr>>& access_list = pair.second;
+        const Array<PrimExpr>& output_shape = op->output_shape(0);
+        const Array<PrimExpr>& producer_shape = producer->output_shape(0);
+
+        int n_common;
+        for (n_common = 0;
+             n_common < static_cast<int>(std::min(output_shape.size(), producer_shape.size()));
+             n_common++) {
+          if (!is_zero(analyzer.Simplify(output_shape[n_common] - producer_shape[n_common]))) {
+            break;
+          }
+
+          bool direct_access = true;
+          for (const auto& access : access_list) {
+            if (!IsConstShiftEqual(cop->axis[n_common]->var, access[n_common])) {
+              direct_access = false;
+              break;
+            }
+          }
+
+          if (!direct_access) {
+            break;
+          }
+        }
+
+        node->num_common_outer_iterators[op][producer] = n_common;
+        node->num_common_outer_iterators[producer][op] = n_common;
+      }
+    } else {
+      LOG(FATAL) << "Invalid op: " << op;
+    }
+  }
+
+  // do some static analysis
+  for (const auto& op : node->ops_topo_order) {
+    if (op->IsInstance<te::PlaceholderOpNode>()) {
+      node->is_injective[op] = true;
+      node->needs_multi_level_tiling[op] = false;
+      node->is_strict_inlineable[op] = false;
+      node->is_output[op] = false;
+    } else if (auto pop = op.as<te::ComputeOpNode>()) {
+      // check whether this op is element-wise and strict-inlineable
+      bool is_injective = true;
+      bool is_strict_inlineable = true;
+
+      bool axis_missing, axis_duplicated, same_order;
+      for (const auto& pair : node->read_from[op]) {
+        const std::vector<std::vector<PrimExpr>>& access = pair.second;
+        for (const auto& index : access) {
+          if (!auto_scheduler::IsInjective(op, index, &axis_missing, &axis_duplicated,
+                                           &same_order)) {
+            is_injective = false;
+            is_strict_inlineable = false;
+            break;
+          }
+          if (!same_order || axis_duplicated) {
+            // do not strictly inline transpose

Review comment:
       The formal definition will be very long and tedious. I can try to make more comments




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] junrushao1994 commented on a change in pull request #6103: [Ansor][AutoTVM v2.0] Phase 1: Access Analyzer

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on a change in pull request #6103:
URL: https://github.com/apache/incubator-tvm/pull/6103#discussion_r459046943



##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -114,7 +118,432 @@ Array<te::Operation> TopoSortOps(const Array<te::Tensor>& tensors) {
   return ops;
 }
 
-// Estimate number of float operations in an expression
+// Extract all tensor accesses in an expr
+class TensorAccessExtractor : public StmtExprVisitor {
+ public:
+  void Extract(PrimExpr expr) { this->VisitExpr(expr); }
+
+  void VisitExpr_(const CallNode* op) final {
+    if (op->op.same_as(builtin::if_then_else())) {
+      has_branch = true;
+    }
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitExpr_(const ProducerLoadNode* op) final {
+    buf_accesses[Downcast<te::Tensor>(op->producer)->op].emplace_back(op->indices.begin(),
+                                                                      op->indices.end());
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitStmt_(const IfThenElseNode* op) final {
+    has_branch = true;
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitExpr_(const SelectNode* op) final {
+    has_branch = true;
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  OperationMap<std::vector<std::vector<PrimExpr>>> buf_accesses;
+  bool has_branch{false};
+};
+
+// Returns whether the expr equals to the var with a const shift
+bool IsConstShiftEqual(const Var& var, const PrimExpr& expr) {
+  if (auto pv = expr.as<VarNode>()) {
+    return pv == var.get();
+  } else if (auto padd = expr.as<AddNode>()) {
+    return ((padd->a.get() == var.get() && padd->b->IsInstance<IntImmNode>()) ||
+            (padd->b.get() == var.get() && padd->a->IsInstance<IntImmNode>()));
+  } else if (auto psub = expr.as<SubNode>()) {
+    return ((psub->a.get() == var.get() && psub->b->IsInstance<IntImmNode>()) ||
+            (psub->b.get() == var.get() && psub->a->IsInstance<IntImmNode>()));
+  } else {
+    return false;
+  }
+}
+
+// Return whether the access is injective
+bool IsInjective(const te::Operation& op, const std::vector<PrimExpr>& index, bool* axis_missing,
+                 bool* axis_duplicated, bool* same_order) {
+  auto cop = op.as<te::ComputeOpNode>();
+  if (cop == nullptr) {
+    return false;
+  }
+
+  std::vector<int> index_to_var_idx;
+  std::vector<int> var_idx_ct(cop->axis.size(), 0);
+
+  for (const auto& expr : index) {
+    if (!is_const_int(expr)) {
+      bool found = false;
+      for (size_t i = 0; i < cop->axis.size(); ++i) {
+        if (IsConstShiftEqual(cop->axis[i]->var, expr)) {
+          index_to_var_idx.push_back(i);
+          var_idx_ct[i]++;
+          found = true;
+          break;
+        }
+      }
+      if (!found) {
+        return false;
+      }
+    }
+  }
+
+  *axis_missing = false;     // Some axes are missing
+  *axis_duplicated = false;  // Some axes appear more than once
+  *same_order = true;        // The axis order is the same as op->axis
+  for (int ct : var_idx_ct) {
+    if (ct == 0) {
+      *axis_missing = true;
+    } else if (ct > 1) {
+      *axis_duplicated = true;
+    }
+  }
+  for (size_t i = 1; i < index_to_var_idx.size(); ++i) {
+    if (index_to_var_idx[i] < index_to_var_idx[i - 1]) {
+      *same_order = false;
+      break;
+    }
+  }
+
+  return true;
+}
+
+// Gather all VarNodes in an expr
+static void GatherVars(const PrimExpr& expr, std::unordered_set<const VarNode*>* vars) {
+  PostOrderVisit(expr, [&vars](const ObjectRef& node) {
+    if (const VarNode* op = node.as<VarNode>()) {
+      vars->insert(op);
+    }
+  });
+}
+
+// Check whether an expr has expensive operations (e.g. exp)
+static bool HasExpensiveOp(const PrimExpr& expr) {
+  bool found = false;
+  PostOrderVisit(expr, [&found](const ObjectRef& node) {
+    if (const CallNode* op = node.as<CallNode>()) {
+      if (op->op.as<OpNode>()->name == "tir.exp") {
+        found = true;
+      }
+    }
+  });
+  return found;
+}
+
+AccessAnalyzer::AccessAnalyzer(const Array<te::Tensor>& tensors) {
+  auto node = make_object<AccessAnalyzerNode>();
+  OperationMap<bool> has_branch;
+
+  // get all ops
+  node->ops_topo_order = TopoSortOps(tensors);
+
+  arith::Analyzer analyzer;
+
+  // build read & write access map
+  for (const auto& op : node->ops_topo_order) {
+    if (op->IsInstance<te::PlaceholderOpNode>()) {
+      node->read_from[op] = OperationMap<std::vector<std::vector<PrimExpr>>>();
+    } else if (auto cop = op.as<te::ComputeOpNode>()) {
+      TensorAccessExtractor extractor;
+      for (const auto& exp : cop->body) {
+        extractor.Extract(exp);
+      }
+
+      // read_by and read_from map
+      for (const auto& iter : extractor.buf_accesses) {
+        std::vector<std::vector<PrimExpr>>& accesses = node->read_by[iter.first][op];
+        accesses.insert(accesses.begin(), iter.second.begin(), iter.second.end());
+      }
+
+      node->read_from[op] = std::move(extractor.buf_accesses);
+      has_branch[op] = extractor.has_branch;
+
+      // compute number of common outer iterators
+      for (const auto& pair : node->read_from[op]) {
+        const te::Operation& producer = pair.first;
+        const std::vector<std::vector<PrimExpr>>& access_list = pair.second;
+        const Array<PrimExpr>& output_shape = op->output_shape(0);
+        const Array<PrimExpr>& producer_shape = producer->output_shape(0);

Review comment:
       Do it only work for `te::Operation` with a single output? Do we have a fallback solution for operators with multiple outputs like `argmax`?

##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -114,7 +118,432 @@ Array<te::Operation> TopoSortOps(const Array<te::Tensor>& tensors) {
   return ops;
 }
 
-// Estimate number of float operations in an expression
+// Extract all tensor accesses in an expr
+class TensorAccessExtractor : public StmtExprVisitor {
+ public:
+  void Extract(PrimExpr expr) { this->VisitExpr(expr); }
+
+  void VisitExpr_(const CallNode* op) final {
+    if (op->op.same_as(builtin::if_then_else())) {
+      has_branch = true;
+    }
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitExpr_(const ProducerLoadNode* op) final {
+    buf_accesses[Downcast<te::Tensor>(op->producer)->op].emplace_back(op->indices.begin(),
+                                                                      op->indices.end());
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitStmt_(const IfThenElseNode* op) final {
+    has_branch = true;
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitExpr_(const SelectNode* op) final {
+    has_branch = true;
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  OperationMap<std::vector<std::vector<PrimExpr>>> buf_accesses;
+  bool has_branch{false};
+};
+
+// Returns whether the expr equals to the var with a const shift
+bool IsConstShiftEqual(const Var& var, const PrimExpr& expr) {
+  if (auto pv = expr.as<VarNode>()) {
+    return pv == var.get();
+  } else if (auto padd = expr.as<AddNode>()) {
+    return ((padd->a.get() == var.get() && padd->b->IsInstance<IntImmNode>()) ||
+            (padd->b.get() == var.get() && padd->a->IsInstance<IntImmNode>()));
+  } else if (auto psub = expr.as<SubNode>()) {
+    return ((psub->a.get() == var.get() && psub->b->IsInstance<IntImmNode>()) ||
+            (psub->b.get() == var.get() && psub->a->IsInstance<IntImmNode>()));
+  } else {
+    return false;
+  }
+}
+
+// Return whether the access is injective
+bool IsInjective(const te::Operation& op, const std::vector<PrimExpr>& index, bool* axis_missing,
+                 bool* axis_duplicated, bool* same_order) {
+  auto cop = op.as<te::ComputeOpNode>();
+  if (cop == nullptr) {
+    return false;
+  }
+
+  std::vector<int> index_to_var_idx;
+  std::vector<int> var_idx_ct(cop->axis.size(), 0);
+
+  for (const auto& expr : index) {
+    if (!is_const_int(expr)) {
+      bool found = false;
+      for (size_t i = 0; i < cop->axis.size(); ++i) {
+        if (IsConstShiftEqual(cop->axis[i]->var, expr)) {
+          index_to_var_idx.push_back(i);
+          var_idx_ct[i]++;
+          found = true;
+          break;
+        }
+      }
+      if (!found) {
+        return false;
+      }
+    }
+  }
+
+  *axis_missing = false;     // Some axes are missing
+  *axis_duplicated = false;  // Some axes appear more than once
+  *same_order = true;        // The axis order is the same as op->axis
+  for (int ct : var_idx_ct) {
+    if (ct == 0) {
+      *axis_missing = true;
+    } else if (ct > 1) {
+      *axis_duplicated = true;
+    }
+  }
+  for (size_t i = 1; i < index_to_var_idx.size(); ++i) {
+    if (index_to_var_idx[i] < index_to_var_idx[i - 1]) {
+      *same_order = false;
+      break;
+    }
+  }
+
+  return true;
+}
+
+// Gather all VarNodes in an expr
+static void GatherVars(const PrimExpr& expr, std::unordered_set<const VarNode*>* vars) {
+  PostOrderVisit(expr, [&vars](const ObjectRef& node) {
+    if (const VarNode* op = node.as<VarNode>()) {
+      vars->insert(op);
+    }
+  });
+}
+
+// Check whether an expr has expensive operations (e.g. exp)
+static bool HasExpensiveOp(const PrimExpr& expr) {
+  bool found = false;
+  PostOrderVisit(expr, [&found](const ObjectRef& node) {
+    if (const CallNode* op = node.as<CallNode>()) {
+      if (op->op.as<OpNode>()->name == "tir.exp") {
+        found = true;
+      }
+    }
+  });
+  return found;
+}
+
+AccessAnalyzer::AccessAnalyzer(const Array<te::Tensor>& tensors) {
+  auto node = make_object<AccessAnalyzerNode>();
+  OperationMap<bool> has_branch;
+
+  // get all ops
+  node->ops_topo_order = TopoSortOps(tensors);
+
+  arith::Analyzer analyzer;
+
+  // build read & write access map
+  for (const auto& op : node->ops_topo_order) {
+    if (op->IsInstance<te::PlaceholderOpNode>()) {
+      node->read_from[op] = OperationMap<std::vector<std::vector<PrimExpr>>>();
+    } else if (auto cop = op.as<te::ComputeOpNode>()) {
+      TensorAccessExtractor extractor;
+      for (const auto& exp : cop->body) {
+        extractor.Extract(exp);
+      }
+
+      // read_by and read_from map
+      for (const auto& iter : extractor.buf_accesses) {
+        std::vector<std::vector<PrimExpr>>& accesses = node->read_by[iter.first][op];
+        accesses.insert(accesses.begin(), iter.second.begin(), iter.second.end());
+      }
+
+      node->read_from[op] = std::move(extractor.buf_accesses);
+      has_branch[op] = extractor.has_branch;
+
+      // compute number of common outer iterators
+      for (const auto& pair : node->read_from[op]) {
+        const te::Operation& producer = pair.first;
+        const std::vector<std::vector<PrimExpr>>& access_list = pair.second;
+        const Array<PrimExpr>& output_shape = op->output_shape(0);
+        const Array<PrimExpr>& producer_shape = producer->output_shape(0);
+
+        int n_common;
+        for (n_common = 0;
+             n_common < static_cast<int>(std::min(output_shape.size(), producer_shape.size()));

Review comment:
       it's a long line...maybe consider move rhs out

##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -114,7 +118,432 @@ Array<te::Operation> TopoSortOps(const Array<te::Tensor>& tensors) {
   return ops;
 }
 
-// Estimate number of float operations in an expression
+// Extract all tensor accesses in an expr
+class TensorAccessExtractor : public StmtExprVisitor {
+ public:
+  void Extract(PrimExpr expr) { this->VisitExpr(expr); }
+
+  void VisitExpr_(const CallNode* op) final {
+    if (op->op.same_as(builtin::if_then_else())) {
+      has_branch = true;
+    }
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitExpr_(const ProducerLoadNode* op) final {
+    buf_accesses[Downcast<te::Tensor>(op->producer)->op].emplace_back(op->indices.begin(),
+                                                                      op->indices.end());
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitStmt_(const IfThenElseNode* op) final {
+    has_branch = true;
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitExpr_(const SelectNode* op) final {
+    has_branch = true;
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  OperationMap<std::vector<std::vector<PrimExpr>>> buf_accesses;
+  bool has_branch{false};
+};
+
+// Returns whether the expr equals to the var with a const shift
+bool IsConstShiftEqual(const Var& var, const PrimExpr& expr) {
+  if (auto pv = expr.as<VarNode>()) {
+    return pv == var.get();
+  } else if (auto padd = expr.as<AddNode>()) {
+    return ((padd->a.get() == var.get() && padd->b->IsInstance<IntImmNode>()) ||
+            (padd->b.get() == var.get() && padd->a->IsInstance<IntImmNode>()));
+  } else if (auto psub = expr.as<SubNode>()) {
+    return ((psub->a.get() == var.get() && psub->b->IsInstance<IntImmNode>()) ||
+            (psub->b.get() == var.get() && psub->a->IsInstance<IntImmNode>()));
+  } else {
+    return false;
+  }
+}
+
+// Return whether the access is injective
+bool IsInjective(const te::Operation& op, const std::vector<PrimExpr>& index, bool* axis_missing,
+                 bool* axis_duplicated, bool* same_order) {
+  auto cop = op.as<te::ComputeOpNode>();
+  if (cop == nullptr) {
+    return false;
+  }
+
+  std::vector<int> index_to_var_idx;
+  std::vector<int> var_idx_ct(cop->axis.size(), 0);
+
+  for (const auto& expr : index) {
+    if (!is_const_int(expr)) {
+      bool found = false;
+      for (size_t i = 0; i < cop->axis.size(); ++i) {
+        if (IsConstShiftEqual(cop->axis[i]->var, expr)) {
+          index_to_var_idx.push_back(i);
+          var_idx_ct[i]++;
+          found = true;
+          break;
+        }
+      }
+      if (!found) {
+        return false;
+      }
+    }
+  }
+
+  *axis_missing = false;     // Some axes are missing
+  *axis_duplicated = false;  // Some axes appear more than once
+  *same_order = true;        // The axis order is the same as op->axis
+  for (int ct : var_idx_ct) {
+    if (ct == 0) {
+      *axis_missing = true;
+    } else if (ct > 1) {
+      *axis_duplicated = true;
+    }
+  }
+  for (size_t i = 1; i < index_to_var_idx.size(); ++i) {
+    if (index_to_var_idx[i] < index_to_var_idx[i - 1]) {
+      *same_order = false;
+      break;
+    }
+  }
+
+  return true;
+}
+
+// Gather all VarNodes in an expr
+static void GatherVars(const PrimExpr& expr, std::unordered_set<const VarNode*>* vars) {
+  PostOrderVisit(expr, [&vars](const ObjectRef& node) {
+    if (const VarNode* op = node.as<VarNode>()) {
+      vars->insert(op);
+    }
+  });
+}
+
+// Check whether an expr has expensive operations (e.g. exp)
+static bool HasExpensiveOp(const PrimExpr& expr) {
+  bool found = false;
+  PostOrderVisit(expr, [&found](const ObjectRef& node) {
+    if (const CallNode* op = node.as<CallNode>()) {
+      if (op->op.as<OpNode>()->name == "tir.exp") {
+        found = true;
+      }
+    }
+  });
+  return found;
+}
+
+AccessAnalyzer::AccessAnalyzer(const Array<te::Tensor>& tensors) {
+  auto node = make_object<AccessAnalyzerNode>();
+  OperationMap<bool> has_branch;
+
+  // get all ops
+  node->ops_topo_order = TopoSortOps(tensors);
+
+  arith::Analyzer analyzer;
+
+  // build read & write access map
+  for (const auto& op : node->ops_topo_order) {
+    if (op->IsInstance<te::PlaceholderOpNode>()) {
+      node->read_from[op] = OperationMap<std::vector<std::vector<PrimExpr>>>();
+    } else if (auto cop = op.as<te::ComputeOpNode>()) {
+      TensorAccessExtractor extractor;
+      for (const auto& exp : cop->body) {
+        extractor.Extract(exp);
+      }
+
+      // read_by and read_from map
+      for (const auto& iter : extractor.buf_accesses) {
+        std::vector<std::vector<PrimExpr>>& accesses = node->read_by[iter.first][op];
+        accesses.insert(accesses.begin(), iter.second.begin(), iter.second.end());
+      }
+
+      node->read_from[op] = std::move(extractor.buf_accesses);
+      has_branch[op] = extractor.has_branch;
+
+      // compute number of common outer iterators
+      for (const auto& pair : node->read_from[op]) {
+        const te::Operation& producer = pair.first;
+        const std::vector<std::vector<PrimExpr>>& access_list = pair.second;
+        const Array<PrimExpr>& output_shape = op->output_shape(0);
+        const Array<PrimExpr>& producer_shape = producer->output_shape(0);
+
+        int n_common;
+        for (n_common = 0;
+             n_common < static_cast<int>(std::min(output_shape.size(), producer_shape.size()));
+             n_common++) {
+          if (!is_zero(analyzer.Simplify(output_shape[n_common] - producer_shape[n_common]))) {
+            break;
+          }
+
+          bool direct_access = true;
+          for (const auto& access : access_list) {
+            if (!IsConstShiftEqual(cop->axis[n_common]->var, access[n_common])) {
+              direct_access = false;
+              break;
+            }
+          }
+
+          if (!direct_access) {
+            break;
+          }
+        }
+
+        node->num_common_outer_iterators[op][producer] = n_common;
+        node->num_common_outer_iterators[producer][op] = n_common;
+      }
+    } else {
+      LOG(FATAL) << "Invalid op: " << op;
+    }
+  }
+
+  // do some static analysis
+  for (const auto& op : node->ops_topo_order) {
+    if (op->IsInstance<te::PlaceholderOpNode>()) {
+      node->is_injective[op] = true;
+      node->needs_multi_level_tiling[op] = false;
+      node->is_strict_inlineable[op] = false;
+      node->is_output[op] = false;
+    } else if (auto pop = op.as<te::ComputeOpNode>()) {
+      // check whether this op is element-wise and strict-inlineable
+      bool is_injective = true;
+      bool is_strict_inlineable = true;
+
+      bool axis_missing, axis_duplicated, same_order;
+      for (const auto& pair : node->read_from[op]) {
+        const std::vector<std::vector<PrimExpr>>& access = pair.second;
+        for (const auto& index : access) {
+          if (!auto_scheduler::IsInjective(op, index, &axis_missing, &axis_duplicated,
+                                           &same_order)) {
+            is_injective = false;
+            is_strict_inlineable = false;
+            break;
+          }
+          if (!same_order || axis_duplicated) {
+            // do not strictly inline transpose

Review comment:
       i kinda understand this sentence: `transpose` doesn't give `same_order`, so it is not strictly inlinable. Can we give a formal definition of strict inlineable?

##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -114,7 +118,432 @@ Array<te::Operation> TopoSortOps(const Array<te::Tensor>& tensors) {
   return ops;
 }
 
-// Estimate number of float operations in an expression
+// Extract all tensor accesses in an expr
+class TensorAccessExtractor : public StmtExprVisitor {
+ public:
+  void Extract(PrimExpr expr) { this->VisitExpr(expr); }
+
+  void VisitExpr_(const CallNode* op) final {
+    if (op->op.same_as(builtin::if_then_else())) {
+      has_branch = true;
+    }
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitExpr_(const ProducerLoadNode* op) final {
+    buf_accesses[Downcast<te::Tensor>(op->producer)->op].emplace_back(op->indices.begin(),
+                                                                      op->indices.end());
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitStmt_(const IfThenElseNode* op) final {
+    has_branch = true;
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitExpr_(const SelectNode* op) final {
+    has_branch = true;
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  OperationMap<std::vector<std::vector<PrimExpr>>> buf_accesses;
+  bool has_branch{false};
+};
+
+// Returns whether the expr equals to the var with a const shift
+bool IsConstShiftEqual(const Var& var, const PrimExpr& expr) {
+  if (auto pv = expr.as<VarNode>()) {
+    return pv == var.get();
+  } else if (auto padd = expr.as<AddNode>()) {
+    return ((padd->a.get() == var.get() && padd->b->IsInstance<IntImmNode>()) ||
+            (padd->b.get() == var.get() && padd->a->IsInstance<IntImmNode>()));
+  } else if (auto psub = expr.as<SubNode>()) {
+    return ((psub->a.get() == var.get() && psub->b->IsInstance<IntImmNode>()) ||
+            (psub->b.get() == var.get() && psub->a->IsInstance<IntImmNode>()));
+  } else {
+    return false;
+  }
+}
+
+// Return whether the access is injective
+bool IsInjective(const te::Operation& op, const std::vector<PrimExpr>& index, bool* axis_missing,
+                 bool* axis_duplicated, bool* same_order) {
+  auto cop = op.as<te::ComputeOpNode>();
+  if (cop == nullptr) {
+    return false;
+  }
+
+  std::vector<int> index_to_var_idx;
+  std::vector<int> var_idx_ct(cop->axis.size(), 0);
+
+  for (const auto& expr : index) {
+    if (!is_const_int(expr)) {
+      bool found = false;
+      for (size_t i = 0; i < cop->axis.size(); ++i) {
+        if (IsConstShiftEqual(cop->axis[i]->var, expr)) {
+          index_to_var_idx.push_back(i);
+          var_idx_ct[i]++;
+          found = true;
+          break;
+        }
+      }
+      if (!found) {
+        return false;
+      }
+    }
+  }
+
+  *axis_missing = false;     // Some axes are missing
+  *axis_duplicated = false;  // Some axes appear more than once
+  *same_order = true;        // The axis order is the same as op->axis
+  for (int ct : var_idx_ct) {
+    if (ct == 0) {
+      *axis_missing = true;
+    } else if (ct > 1) {
+      *axis_duplicated = true;
+    }
+  }
+  for (size_t i = 1; i < index_to_var_idx.size(); ++i) {
+    if (index_to_var_idx[i] < index_to_var_idx[i - 1]) {
+      *same_order = false;
+      break;
+    }
+  }
+
+  return true;
+}
+
+// Gather all VarNodes in an expr
+static void GatherVars(const PrimExpr& expr, std::unordered_set<const VarNode*>* vars) {
+  PostOrderVisit(expr, [&vars](const ObjectRef& node) {
+    if (const VarNode* op = node.as<VarNode>()) {
+      vars->insert(op);
+    }
+  });
+}
+
+// Check whether an expr has expensive operations (e.g. exp)
+static bool HasExpensiveOp(const PrimExpr& expr) {
+  bool found = false;
+  PostOrderVisit(expr, [&found](const ObjectRef& node) {
+    if (const CallNode* op = node.as<CallNode>()) {
+      if (op->op.as<OpNode>()->name == "tir.exp") {
+        found = true;
+      }
+    }
+  });
+  return found;
+}
+
+AccessAnalyzer::AccessAnalyzer(const Array<te::Tensor>& tensors) {
+  auto node = make_object<AccessAnalyzerNode>();
+  OperationMap<bool> has_branch;
+
+  // get all ops
+  node->ops_topo_order = TopoSortOps(tensors);
+
+  arith::Analyzer analyzer;
+
+  // build read & write access map
+  for (const auto& op : node->ops_topo_order) {
+    if (op->IsInstance<te::PlaceholderOpNode>()) {
+      node->read_from[op] = OperationMap<std::vector<std::vector<PrimExpr>>>();
+    } else if (auto cop = op.as<te::ComputeOpNode>()) {
+      TensorAccessExtractor extractor;
+      for (const auto& exp : cop->body) {
+        extractor.Extract(exp);
+      }
+
+      // read_by and read_from map
+      for (const auto& iter : extractor.buf_accesses) {
+        std::vector<std::vector<PrimExpr>>& accesses = node->read_by[iter.first][op];
+        accesses.insert(accesses.begin(), iter.second.begin(), iter.second.end());
+      }
+
+      node->read_from[op] = std::move(extractor.buf_accesses);
+      has_branch[op] = extractor.has_branch;
+
+      // compute number of common outer iterators
+      for (const auto& pair : node->read_from[op]) {
+        const te::Operation& producer = pair.first;
+        const std::vector<std::vector<PrimExpr>>& access_list = pair.second;
+        const Array<PrimExpr>& output_shape = op->output_shape(0);
+        const Array<PrimExpr>& producer_shape = producer->output_shape(0);
+
+        int n_common;
+        for (n_common = 0;
+             n_common < static_cast<int>(std::min(output_shape.size(), producer_shape.size()));
+             n_common++) {
+          if (!is_zero(analyzer.Simplify(output_shape[n_common] - producer_shape[n_common]))) {
+            break;
+          }
+
+          bool direct_access = true;
+          for (const auto& access : access_list) {
+            if (!IsConstShiftEqual(cop->axis[n_common]->var, access[n_common])) {
+              direct_access = false;
+              break;
+            }
+          }
+
+          if (!direct_access) {
+            break;
+          }
+        }
+
+        node->num_common_outer_iterators[op][producer] = n_common;
+        node->num_common_outer_iterators[producer][op] = n_common;
+      }
+    } else {
+      LOG(FATAL) << "Invalid op: " << op;
+    }
+  }
+
+  // do some static analysis
+  for (const auto& op : node->ops_topo_order) {
+    if (op->IsInstance<te::PlaceholderOpNode>()) {
+      node->is_injective[op] = true;
+      node->needs_multi_level_tiling[op] = false;
+      node->is_strict_inlineable[op] = false;
+      node->is_output[op] = false;
+    } else if (auto pop = op.as<te::ComputeOpNode>()) {
+      // check whether this op is element-wise and strict-inlineable
+      bool is_injective = true;
+      bool is_strict_inlineable = true;
+
+      bool axis_missing, axis_duplicated, same_order;
+      for (const auto& pair : node->read_from[op]) {
+        const std::vector<std::vector<PrimExpr>>& access = pair.second;
+        for (const auto& index : access) {
+          if (!auto_scheduler::IsInjective(op, index, &axis_missing, &axis_duplicated,
+                                           &same_order)) {
+            is_injective = false;
+            is_strict_inlineable = false;
+            break;
+          }
+          if (!same_order || axis_duplicated) {
+            // do not strictly inline transpose
+            is_strict_inlineable = false;
+          }
+        }
+        if (!is_injective) {
+          break;
+        }
+      }
+      if (has_branch[op]) {
+        is_strict_inlineable = false;
+      }
+
+      // don't strictly inline expensive op (e.g. exp)
+      bool has_expensive_op = false;
+      for (const auto& expr : pop->body) {
+        has_expensive_op |= HasExpensiveOp(expr);
+      }
+
+      node->is_injective[op] = is_injective;
+      node->is_strict_inlineable[op] = is_strict_inlineable && !has_expensive_op;
+
+      // check whether the op needs multi-level tiling
+      bool needs_multi_level_tiling = false;
+      int n_missing = 0;
+
+      for (const auto& pair : node->read_from[op]) {
+        const std::vector<std::vector<PrimExpr>>& access = pair.second;
+        std::unordered_set<const VarNode*> vars;
+        for (const std::vector<PrimExpr>& indices : access) {
+          for (const PrimExpr& expr : indices) {
+            GatherVars(expr, &vars);
+          }
+        }
+        bool missing = false;
+        for (const auto& axis : pop->axis) {
+          if (GetIntImm(axis->dom->extent) > 1 && vars.count(axis->var.get()) == 0) {
+            missing = true;
+          }
+        }
+        if (missing) {
+          n_missing++;
+        }

Review comment:
       you don't need this flag, just break
   ```suggestion
           for (const auto& axis : pop->axis) {
             if (GetIntImm(axis->dom->extent) > 1 && vars.count(axis->var.get()) == 0) {
               ++n_missing;
               break;
             }
           }
   ```

##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -114,7 +118,432 @@ Array<te::Operation> TopoSortOps(const Array<te::Tensor>& tensors) {
   return ops;
 }
 
-// Estimate number of float operations in an expression
+// Extract all tensor accesses in an expr
+class TensorAccessExtractor : public StmtExprVisitor {
+ public:
+  void Extract(PrimExpr expr) { this->VisitExpr(expr); }
+
+  void VisitExpr_(const CallNode* op) final {
+    if (op->op.same_as(builtin::if_then_else())) {
+      has_branch = true;
+    }
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitExpr_(const ProducerLoadNode* op) final {
+    buf_accesses[Downcast<te::Tensor>(op->producer)->op].emplace_back(op->indices.begin(),
+                                                                      op->indices.end());
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitStmt_(const IfThenElseNode* op) final {
+    has_branch = true;
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitExpr_(const SelectNode* op) final {
+    has_branch = true;
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  OperationMap<std::vector<std::vector<PrimExpr>>> buf_accesses;
+  bool has_branch{false};
+};
+
+// Returns whether the expr equals to the var with a const shift
+bool IsConstShiftEqual(const Var& var, const PrimExpr& expr) {
+  if (auto pv = expr.as<VarNode>()) {
+    return pv == var.get();
+  } else if (auto padd = expr.as<AddNode>()) {
+    return ((padd->a.get() == var.get() && padd->b->IsInstance<IntImmNode>()) ||
+            (padd->b.get() == var.get() && padd->a->IsInstance<IntImmNode>()));
+  } else if (auto psub = expr.as<SubNode>()) {
+    return ((psub->a.get() == var.get() && psub->b->IsInstance<IntImmNode>()) ||
+            (psub->b.get() == var.get() && psub->a->IsInstance<IntImmNode>()));
+  } else {
+    return false;
+  }
+}
+
+// Return whether the access is injective
+bool IsInjective(const te::Operation& op, const std::vector<PrimExpr>& index, bool* axis_missing,
+                 bool* axis_duplicated, bool* same_order) {
+  auto cop = op.as<te::ComputeOpNode>();
+  if (cop == nullptr) {
+    return false;
+  }
+
+  std::vector<int> index_to_var_idx;
+  std::vector<int> var_idx_ct(cop->axis.size(), 0);
+
+  for (const auto& expr : index) {
+    if (!is_const_int(expr)) {
+      bool found = false;
+      for (size_t i = 0; i < cop->axis.size(); ++i) {
+        if (IsConstShiftEqual(cop->axis[i]->var, expr)) {
+          index_to_var_idx.push_back(i);
+          var_idx_ct[i]++;
+          found = true;
+          break;
+        }
+      }
+      if (!found) {
+        return false;
+      }
+    }
+  }
+
+  *axis_missing = false;     // Some axes are missing
+  *axis_duplicated = false;  // Some axes appear more than once
+  *same_order = true;        // The axis order is the same as op->axis
+  for (int ct : var_idx_ct) {
+    if (ct == 0) {
+      *axis_missing = true;
+    } else if (ct > 1) {
+      *axis_duplicated = true;
+    }
+  }
+  for (size_t i = 1; i < index_to_var_idx.size(); ++i) {
+    if (index_to_var_idx[i] < index_to_var_idx[i - 1]) {
+      *same_order = false;
+      break;
+    }
+  }
+
+  return true;
+}
+
+// Gather all VarNodes in an expr
+static void GatherVars(const PrimExpr& expr, std::unordered_set<const VarNode*>* vars) {
+  PostOrderVisit(expr, [&vars](const ObjectRef& node) {
+    if (const VarNode* op = node.as<VarNode>()) {
+      vars->insert(op);
+    }
+  });
+}
+
+// Check whether an expr has expensive operations (e.g. exp)
+static bool HasExpensiveOp(const PrimExpr& expr) {
+  bool found = false;
+  PostOrderVisit(expr, [&found](const ObjectRef& node) {
+    if (const CallNode* op = node.as<CallNode>()) {
+      if (op->op.as<OpNode>()->name == "tir.exp") {
+        found = true;
+      }
+    }
+  });
+  return found;
+}
+
+AccessAnalyzer::AccessAnalyzer(const Array<te::Tensor>& tensors) {
+  auto node = make_object<AccessAnalyzerNode>();
+  OperationMap<bool> has_branch;
+
+  // get all ops
+  node->ops_topo_order = TopoSortOps(tensors);

Review comment:
       this function is way too large...consider decomposing it into several smaller ones.

##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -114,7 +118,432 @@ Array<te::Operation> TopoSortOps(const Array<te::Tensor>& tensors) {
   return ops;
 }
 
-// Estimate number of float operations in an expression
+// Extract all tensor accesses in an expr
+class TensorAccessExtractor : public StmtExprVisitor {
+ public:
+  void Extract(PrimExpr expr) { this->VisitExpr(expr); }
+
+  void VisitExpr_(const CallNode* op) final {
+    if (op->op.same_as(builtin::if_then_else())) {
+      has_branch = true;
+    }
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitExpr_(const ProducerLoadNode* op) final {
+    buf_accesses[Downcast<te::Tensor>(op->producer)->op].emplace_back(op->indices.begin(),
+                                                                      op->indices.end());
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitStmt_(const IfThenElseNode* op) final {
+    has_branch = true;
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitExpr_(const SelectNode* op) final {
+    has_branch = true;
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  OperationMap<std::vector<std::vector<PrimExpr>>> buf_accesses;
+  bool has_branch{false};
+};
+
+// Returns whether the expr equals to the var with a const shift
+bool IsConstShiftEqual(const Var& var, const PrimExpr& expr) {
+  if (auto pv = expr.as<VarNode>()) {
+    return pv == var.get();
+  } else if (auto padd = expr.as<AddNode>()) {
+    return ((padd->a.get() == var.get() && padd->b->IsInstance<IntImmNode>()) ||
+            (padd->b.get() == var.get() && padd->a->IsInstance<IntImmNode>()));
+  } else if (auto psub = expr.as<SubNode>()) {
+    return ((psub->a.get() == var.get() && psub->b->IsInstance<IntImmNode>()) ||
+            (psub->b.get() == var.get() && psub->a->IsInstance<IntImmNode>()));
+  } else {
+    return false;
+  }
+}
+
+// Return whether the access is injective
+bool IsInjective(const te::Operation& op, const std::vector<PrimExpr>& index, bool* axis_missing,
+                 bool* axis_duplicated, bool* same_order) {
+  auto cop = op.as<te::ComputeOpNode>();
+  if (cop == nullptr) {
+    return false;
+  }
+
+  std::vector<int> index_to_var_idx;
+  std::vector<int> var_idx_ct(cop->axis.size(), 0);
+
+  for (const auto& expr : index) {
+    if (!is_const_int(expr)) {
+      bool found = false;
+      for (size_t i = 0; i < cop->axis.size(); ++i) {
+        if (IsConstShiftEqual(cop->axis[i]->var, expr)) {
+          index_to_var_idx.push_back(i);
+          var_idx_ct[i]++;
+          found = true;
+          break;
+        }
+      }
+      if (!found) {
+        return false;
+      }
+    }
+  }
+
+  *axis_missing = false;     // Some axes are missing
+  *axis_duplicated = false;  // Some axes appear more than once
+  *same_order = true;        // The axis order is the same as op->axis
+  for (int ct : var_idx_ct) {
+    if (ct == 0) {
+      *axis_missing = true;
+    } else if (ct > 1) {
+      *axis_duplicated = true;
+    }
+  }
+  for (size_t i = 1; i < index_to_var_idx.size(); ++i) {
+    if (index_to_var_idx[i] < index_to_var_idx[i - 1]) {
+      *same_order = false;
+      break;
+    }
+  }
+
+  return true;
+}
+
+// Gather all VarNodes in an expr
+static void GatherVars(const PrimExpr& expr, std::unordered_set<const VarNode*>* vars) {

Review comment:
       It don't have to be static. static may sometimes interfere with backtrace printing stuff.

##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -114,7 +118,432 @@ Array<te::Operation> TopoSortOps(const Array<te::Tensor>& tensors) {
   return ops;
 }
 
-// Estimate number of float operations in an expression
+// Extract all tensor accesses in an expr
+class TensorAccessExtractor : public StmtExprVisitor {
+ public:
+  void Extract(PrimExpr expr) { this->VisitExpr(expr); }
+
+  void VisitExpr_(const CallNode* op) final {
+    if (op->op.same_as(builtin::if_then_else())) {
+      has_branch = true;
+    }
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitExpr_(const ProducerLoadNode* op) final {
+    buf_accesses[Downcast<te::Tensor>(op->producer)->op].emplace_back(op->indices.begin(),
+                                                                      op->indices.end());
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitStmt_(const IfThenElseNode* op) final {
+    has_branch = true;
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitExpr_(const SelectNode* op) final {
+    has_branch = true;
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  OperationMap<std::vector<std::vector<PrimExpr>>> buf_accesses;
+  bool has_branch{false};
+};
+
+// Returns whether the expr equals to the var with a const shift
+bool IsConstShiftEqual(const Var& var, const PrimExpr& expr) {
+  if (auto pv = expr.as<VarNode>()) {
+    return pv == var.get();
+  } else if (auto padd = expr.as<AddNode>()) {
+    return ((padd->a.get() == var.get() && padd->b->IsInstance<IntImmNode>()) ||
+            (padd->b.get() == var.get() && padd->a->IsInstance<IntImmNode>()));
+  } else if (auto psub = expr.as<SubNode>()) {
+    return ((psub->a.get() == var.get() && psub->b->IsInstance<IntImmNode>()) ||
+            (psub->b.get() == var.get() && psub->a->IsInstance<IntImmNode>()));
+  } else {
+    return false;
+  }
+}
+
+// Return whether the access is injective
+bool IsInjective(const te::Operation& op, const std::vector<PrimExpr>& index, bool* axis_missing,
+                 bool* axis_duplicated, bool* same_order) {
+  auto cop = op.as<te::ComputeOpNode>();
+  if (cop == nullptr) {
+    return false;
+  }
+
+  std::vector<int> index_to_var_idx;
+  std::vector<int> var_idx_ct(cop->axis.size(), 0);
+
+  for (const auto& expr : index) {
+    if (!is_const_int(expr)) {
+      bool found = false;
+      for (size_t i = 0; i < cop->axis.size(); ++i) {
+        if (IsConstShiftEqual(cop->axis[i]->var, expr)) {
+          index_to_var_idx.push_back(i);
+          var_idx_ct[i]++;
+          found = true;
+          break;
+        }
+      }
+      if (!found) {
+        return false;
+      }
+    }
+  }
+
+  *axis_missing = false;     // Some axes are missing
+  *axis_duplicated = false;  // Some axes appear more than once
+  *same_order = true;        // The axis order is the same as op->axis
+  for (int ct : var_idx_ct) {
+    if (ct == 0) {
+      *axis_missing = true;
+    } else if (ct > 1) {
+      *axis_duplicated = true;
+    }
+  }
+  for (size_t i = 1; i < index_to_var_idx.size(); ++i) {
+    if (index_to_var_idx[i] < index_to_var_idx[i - 1]) {
+      *same_order = false;
+      break;
+    }
+  }
+
+  return true;
+}
+
+// Gather all VarNodes in an expr
+static void GatherVars(const PrimExpr& expr, std::unordered_set<const VarNode*>* vars) {
+  PostOrderVisit(expr, [&vars](const ObjectRef& node) {
+    if (const VarNode* op = node.as<VarNode>()) {
+      vars->insert(op);
+    }
+  });
+}
+
+// Check whether an expr has expensive operations (e.g. exp)
+static bool HasExpensiveOp(const PrimExpr& expr) {
+  bool found = false;
+  PostOrderVisit(expr, [&found](const ObjectRef& node) {
+    if (const CallNode* op = node.as<CallNode>()) {
+      if (op->op.as<OpNode>()->name == "tir.exp") {
+        found = true;
+      }
+    }
+  });
+  return found;
+}
+
+AccessAnalyzer::AccessAnalyzer(const Array<te::Tensor>& tensors) {
+  auto node = make_object<AccessAnalyzerNode>();
+  OperationMap<bool> has_branch;
+
+  // get all ops
+  node->ops_topo_order = TopoSortOps(tensors);
+
+  arith::Analyzer analyzer;
+
+  // build read & write access map
+  for (const auto& op : node->ops_topo_order) {
+    if (op->IsInstance<te::PlaceholderOpNode>()) {
+      node->read_from[op] = OperationMap<std::vector<std::vector<PrimExpr>>>();
+    } else if (auto cop = op.as<te::ComputeOpNode>()) {
+      TensorAccessExtractor extractor;
+      for (const auto& exp : cop->body) {
+        extractor.Extract(exp);
+      }
+
+      // read_by and read_from map
+      for (const auto& iter : extractor.buf_accesses) {
+        std::vector<std::vector<PrimExpr>>& accesses = node->read_by[iter.first][op];
+        accesses.insert(accesses.begin(), iter.second.begin(), iter.second.end());
+      }
+
+      node->read_from[op] = std::move(extractor.buf_accesses);
+      has_branch[op] = extractor.has_branch;
+
+      // compute number of common outer iterators
+      for (const auto& pair : node->read_from[op]) {
+        const te::Operation& producer = pair.first;
+        const std::vector<std::vector<PrimExpr>>& access_list = pair.second;
+        const Array<PrimExpr>& output_shape = op->output_shape(0);
+        const Array<PrimExpr>& producer_shape = producer->output_shape(0);
+
+        int n_common;
+        for (n_common = 0;
+             n_common < static_cast<int>(std::min(output_shape.size(), producer_shape.size()));
+             n_common++) {
+          if (!is_zero(analyzer.Simplify(output_shape[n_common] - producer_shape[n_common]))) {
+            break;
+          }
+
+          bool direct_access = true;
+          for (const auto& access : access_list) {
+            if (!IsConstShiftEqual(cop->axis[n_common]->var, access[n_common])) {
+              direct_access = false;
+              break;
+            }
+          }
+
+          if (!direct_access) {
+            break;
+          }
+        }
+
+        node->num_common_outer_iterators[op][producer] = n_common;
+        node->num_common_outer_iterators[producer][op] = n_common;
+      }
+    } else {
+      LOG(FATAL) << "Invalid op: " << op;
+    }
+  }
+
+  // do some static analysis
+  for (const auto& op : node->ops_topo_order) {
+    if (op->IsInstance<te::PlaceholderOpNode>()) {
+      node->is_injective[op] = true;
+      node->needs_multi_level_tiling[op] = false;
+      node->is_strict_inlineable[op] = false;
+      node->is_output[op] = false;
+    } else if (auto pop = op.as<te::ComputeOpNode>()) {
+      // check whether this op is element-wise and strict-inlineable
+      bool is_injective = true;
+      bool is_strict_inlineable = true;
+
+      bool axis_missing, axis_duplicated, same_order;
+      for (const auto& pair : node->read_from[op]) {
+        const std::vector<std::vector<PrimExpr>>& access = pair.second;
+        for (const auto& index : access) {
+          if (!auto_scheduler::IsInjective(op, index, &axis_missing, &axis_duplicated,
+                                           &same_order)) {
+            is_injective = false;
+            is_strict_inlineable = false;
+            break;
+          }
+          if (!same_order || axis_duplicated) {
+            // do not strictly inline transpose
+            is_strict_inlineable = false;
+          }
+        }
+        if (!is_injective) {
+          break;
+        }
+      }
+      if (has_branch[op]) {
+        is_strict_inlineable = false;
+      }
+
+      // don't strictly inline expensive op (e.g. exp)
+      bool has_expensive_op = false;
+      for (const auto& expr : pop->body) {
+        has_expensive_op |= HasExpensiveOp(expr);
+      }
+
+      node->is_injective[op] = is_injective;
+      node->is_strict_inlineable[op] = is_strict_inlineable && !has_expensive_op;
+
+      // check whether the op needs multi-level tiling
+      bool needs_multi_level_tiling = false;
+      int n_missing = 0;
+
+      for (const auto& pair : node->read_from[op]) {
+        const std::vector<std::vector<PrimExpr>>& access = pair.second;

Review comment:
       I feel like we should have consistent naming convention, at least inside a function. There are three places in this function, where "pair.second" is called "access" or "access_list", each element of which is called "index", "indices" and "access" - maybe it is better to come up with a consistent naming for them....

##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -114,7 +118,432 @@ Array<te::Operation> TopoSortOps(const Array<te::Tensor>& tensors) {
   return ops;
 }
 
-// Estimate number of float operations in an expression
+// Extract all tensor accesses in an expr
+class TensorAccessExtractor : public StmtExprVisitor {
+ public:
+  void Extract(PrimExpr expr) { this->VisitExpr(expr); }
+
+  void VisitExpr_(const CallNode* op) final {
+    if (op->op.same_as(builtin::if_then_else())) {
+      has_branch = true;
+    }
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitExpr_(const ProducerLoadNode* op) final {
+    buf_accesses[Downcast<te::Tensor>(op->producer)->op].emplace_back(op->indices.begin(),
+                                                                      op->indices.end());
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitStmt_(const IfThenElseNode* op) final {
+    has_branch = true;
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitExpr_(const SelectNode* op) final {
+    has_branch = true;
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  OperationMap<std::vector<std::vector<PrimExpr>>> buf_accesses;
+  bool has_branch{false};
+};
+
+// Returns whether the expr equals to the var with a const shift
+bool IsConstShiftEqual(const Var& var, const PrimExpr& expr) {
+  if (auto pv = expr.as<VarNode>()) {
+    return pv == var.get();
+  } else if (auto padd = expr.as<AddNode>()) {
+    return ((padd->a.get() == var.get() && padd->b->IsInstance<IntImmNode>()) ||
+            (padd->b.get() == var.get() && padd->a->IsInstance<IntImmNode>()));
+  } else if (auto psub = expr.as<SubNode>()) {
+    return ((psub->a.get() == var.get() && psub->b->IsInstance<IntImmNode>()) ||
+            (psub->b.get() == var.get() && psub->a->IsInstance<IntImmNode>()));
+  } else {
+    return false;
+  }
+}
+
+// Return whether the access is injective
+bool IsInjective(const te::Operation& op, const std::vector<PrimExpr>& index, bool* axis_missing,
+                 bool* axis_duplicated, bool* same_order) {
+  auto cop = op.as<te::ComputeOpNode>();
+  if (cop == nullptr) {
+    return false;
+  }
+
+  std::vector<int> index_to_var_idx;
+  std::vector<int> var_idx_ct(cop->axis.size(), 0);
+
+  for (const auto& expr : index) {
+    if (!is_const_int(expr)) {
+      bool found = false;
+      for (size_t i = 0; i < cop->axis.size(); ++i) {
+        if (IsConstShiftEqual(cop->axis[i]->var, expr)) {
+          index_to_var_idx.push_back(i);
+          var_idx_ct[i]++;
+          found = true;
+          break;
+        }
+      }
+      if (!found) {
+        return false;
+      }
+    }
+  }
+
+  *axis_missing = false;     // Some axes are missing
+  *axis_duplicated = false;  // Some axes appear more than once
+  *same_order = true;        // The axis order is the same as op->axis
+  for (int ct : var_idx_ct) {
+    if (ct == 0) {
+      *axis_missing = true;
+    } else if (ct > 1) {
+      *axis_duplicated = true;
+    }
+  }
+  for (size_t i = 1; i < index_to_var_idx.size(); ++i) {
+    if (index_to_var_idx[i] < index_to_var_idx[i - 1]) {
+      *same_order = false;
+      break;
+    }
+  }
+
+  return true;
+}
+
+// Gather all VarNodes in an expr
+static void GatherVars(const PrimExpr& expr, std::unordered_set<const VarNode*>* vars) {
+  PostOrderVisit(expr, [&vars](const ObjectRef& node) {
+    if (const VarNode* op = node.as<VarNode>()) {
+      vars->insert(op);
+    }
+  });
+}
+
+// Check whether an expr has expensive operations (e.g. exp)
+static bool HasExpensiveOp(const PrimExpr& expr) {
+  bool found = false;
+  PostOrderVisit(expr, [&found](const ObjectRef& node) {
+    if (const CallNode* op = node.as<CallNode>()) {
+      if (op->op.as<OpNode>()->name == "tir.exp") {
+        found = true;
+      }
+    }
+  });
+  return found;
+}
+
+AccessAnalyzer::AccessAnalyzer(const Array<te::Tensor>& tensors) {
+  auto node = make_object<AccessAnalyzerNode>();
+  OperationMap<bool> has_branch;
+
+  // get all ops
+  node->ops_topo_order = TopoSortOps(tensors);
+
+  arith::Analyzer analyzer;
+
+  // build read & write access map
+  for (const auto& op : node->ops_topo_order) {
+    if (op->IsInstance<te::PlaceholderOpNode>()) {
+      node->read_from[op] = OperationMap<std::vector<std::vector<PrimExpr>>>();
+    } else if (auto cop = op.as<te::ComputeOpNode>()) {
+      TensorAccessExtractor extractor;
+      for (const auto& exp : cop->body) {
+        extractor.Extract(exp);
+      }
+
+      // read_by and read_from map
+      for (const auto& iter : extractor.buf_accesses) {
+        std::vector<std::vector<PrimExpr>>& accesses = node->read_by[iter.first][op];
+        accesses.insert(accesses.begin(), iter.second.begin(), iter.second.end());
+      }
+
+      node->read_from[op] = std::move(extractor.buf_accesses);
+      has_branch[op] = extractor.has_branch;
+
+      // compute number of common outer iterators
+      for (const auto& pair : node->read_from[op]) {
+        const te::Operation& producer = pair.first;
+        const std::vector<std::vector<PrimExpr>>& access_list = pair.second;
+        const Array<PrimExpr>& output_shape = op->output_shape(0);
+        const Array<PrimExpr>& producer_shape = producer->output_shape(0);
+
+        int n_common;
+        for (n_common = 0;
+             n_common < static_cast<int>(std::min(output_shape.size(), producer_shape.size()));
+             n_common++) {
+          if (!is_zero(analyzer.Simplify(output_shape[n_common] - producer_shape[n_common]))) {
+            break;
+          }
+
+          bool direct_access = true;
+          for (const auto& access : access_list) {
+            if (!IsConstShiftEqual(cop->axis[n_common]->var, access[n_common])) {
+              direct_access = false;
+              break;
+            }
+          }
+
+          if (!direct_access) {
+            break;
+          }
+        }
+
+        node->num_common_outer_iterators[op][producer] = n_common;
+        node->num_common_outer_iterators[producer][op] = n_common;
+      }
+    } else {
+      LOG(FATAL) << "Invalid op: " << op;
+    }
+  }
+
+  // do some static analysis
+  for (const auto& op : node->ops_topo_order) {
+    if (op->IsInstance<te::PlaceholderOpNode>()) {
+      node->is_injective[op] = true;
+      node->needs_multi_level_tiling[op] = false;
+      node->is_strict_inlineable[op] = false;
+      node->is_output[op] = false;
+    } else if (auto pop = op.as<te::ComputeOpNode>()) {

Review comment:
       why it is named pop...i thought the convention is cop...
   ```suggestion
       } else if (const auto* cop = op.as<te::ComputeOpNode>()) {
   ```

##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -114,7 +118,432 @@ Array<te::Operation> TopoSortOps(const Array<te::Tensor>& tensors) {
   return ops;
 }
 
-// Estimate number of float operations in an expression
+// Extract all tensor accesses in an expr
+class TensorAccessExtractor : public StmtExprVisitor {
+ public:
+  void Extract(PrimExpr expr) { this->VisitExpr(expr); }
+
+  void VisitExpr_(const CallNode* op) final {
+    if (op->op.same_as(builtin::if_then_else())) {
+      has_branch = true;
+    }
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitExpr_(const ProducerLoadNode* op) final {
+    buf_accesses[Downcast<te::Tensor>(op->producer)->op].emplace_back(op->indices.begin(),
+                                                                      op->indices.end());
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitStmt_(const IfThenElseNode* op) final {
+    has_branch = true;
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitExpr_(const SelectNode* op) final {
+    has_branch = true;
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  OperationMap<std::vector<std::vector<PrimExpr>>> buf_accesses;
+  bool has_branch{false};
+};
+
+// Returns whether the expr equals to the var with a const shift
+bool IsConstShiftEqual(const Var& var, const PrimExpr& expr) {
+  if (auto pv = expr.as<VarNode>()) {
+    return pv == var.get();
+  } else if (auto padd = expr.as<AddNode>()) {
+    return ((padd->a.get() == var.get() && padd->b->IsInstance<IntImmNode>()) ||
+            (padd->b.get() == var.get() && padd->a->IsInstance<IntImmNode>()));
+  } else if (auto psub = expr.as<SubNode>()) {
+    return ((psub->a.get() == var.get() && psub->b->IsInstance<IntImmNode>()) ||
+            (psub->b.get() == var.get() && psub->a->IsInstance<IntImmNode>()));
+  } else {
+    return false;
+  }
+}
+
+// Return whether the access is injective
+bool IsInjective(const te::Operation& op, const std::vector<PrimExpr>& index, bool* axis_missing,
+                 bool* axis_duplicated, bool* same_order) {
+  auto cop = op.as<te::ComputeOpNode>();
+  if (cop == nullptr) {
+    return false;
+  }
+
+  std::vector<int> index_to_var_idx;
+  std::vector<int> var_idx_ct(cop->axis.size(), 0);
+
+  for (const auto& expr : index) {
+    if (!is_const_int(expr)) {
+      bool found = false;
+      for (size_t i = 0; i < cop->axis.size(); ++i) {
+        if (IsConstShiftEqual(cop->axis[i]->var, expr)) {
+          index_to_var_idx.push_back(i);
+          var_idx_ct[i]++;
+          found = true;
+          break;
+        }
+      }
+      if (!found) {
+        return false;
+      }
+    }
+  }
+
+  *axis_missing = false;     // Some axes are missing
+  *axis_duplicated = false;  // Some axes appear more than once
+  *same_order = true;        // The axis order is the same as op->axis
+  for (int ct : var_idx_ct) {
+    if (ct == 0) {
+      *axis_missing = true;
+    } else if (ct > 1) {
+      *axis_duplicated = true;
+    }
+  }
+  for (size_t i = 1; i < index_to_var_idx.size(); ++i) {
+    if (index_to_var_idx[i] < index_to_var_idx[i - 1]) {
+      *same_order = false;
+      break;
+    }
+  }
+
+  return true;
+}
+
+// Gather all VarNodes in an expr
+static void GatherVars(const PrimExpr& expr, std::unordered_set<const VarNode*>* vars) {
+  PostOrderVisit(expr, [&vars](const ObjectRef& node) {
+    if (const VarNode* op = node.as<VarNode>()) {
+      vars->insert(op);
+    }
+  });
+}
+
+// Check whether an expr has expensive operations (e.g. exp)
+static bool HasExpensiveOp(const PrimExpr& expr) {
+  bool found = false;
+  PostOrderVisit(expr, [&found](const ObjectRef& node) {
+    if (const CallNode* op = node.as<CallNode>()) {
+      if (op->op.as<OpNode>()->name == "tir.exp") {
+        found = true;
+      }
+    }
+  });
+  return found;
+}
+
+AccessAnalyzer::AccessAnalyzer(const Array<te::Tensor>& tensors) {
+  auto node = make_object<AccessAnalyzerNode>();
+  OperationMap<bool> has_branch;
+
+  // get all ops
+  node->ops_topo_order = TopoSortOps(tensors);
+
+  arith::Analyzer analyzer;
+
+  // build read & write access map
+  for (const auto& op : node->ops_topo_order) {
+    if (op->IsInstance<te::PlaceholderOpNode>()) {
+      node->read_from[op] = OperationMap<std::vector<std::vector<PrimExpr>>>();
+    } else if (auto cop = op.as<te::ComputeOpNode>()) {
+      TensorAccessExtractor extractor;
+      for (const auto& exp : cop->body) {
+        extractor.Extract(exp);
+      }
+
+      // read_by and read_from map
+      for (const auto& iter : extractor.buf_accesses) {
+        std::vector<std::vector<PrimExpr>>& accesses = node->read_by[iter.first][op];
+        accesses.insert(accesses.begin(), iter.second.begin(), iter.second.end());
+      }
+
+      node->read_from[op] = std::move(extractor.buf_accesses);
+      has_branch[op] = extractor.has_branch;
+
+      // compute number of common outer iterators
+      for (const auto& pair : node->read_from[op]) {
+        const te::Operation& producer = pair.first;
+        const std::vector<std::vector<PrimExpr>>& access_list = pair.second;
+        const Array<PrimExpr>& output_shape = op->output_shape(0);
+        const Array<PrimExpr>& producer_shape = producer->output_shape(0);
+
+        int n_common;
+        for (n_common = 0;
+             n_common < static_cast<int>(std::min(output_shape.size(), producer_shape.size()));
+             n_common++) {
+          if (!is_zero(analyzer.Simplify(output_shape[n_common] - producer_shape[n_common]))) {
+            break;
+          }
+
+          bool direct_access = true;

Review comment:
       hmm just curious why it is named direct_access?




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] merrymercy commented on a change in pull request #6103: [Ansor][AutoTVM v2.0] Phase 1: Access Analyzer

Posted by GitBox <gi...@apache.org>.
merrymercy commented on a change in pull request #6103:
URL: https://github.com/apache/incubator-tvm/pull/6103#discussion_r459132211



##########
File path: include/tvm/auto_scheduler/compute_dag.h
##########
@@ -0,0 +1,237 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/auto_scheduler/compute_dag.h
+ * \brief The auto-scheduler's computational graph and related program analyses.
+ *
+ * We convert a compute declaration described by `tvm.compute` (could be a single operator or a
+ * subgraph) to a ComputeDAG. It keeps the input/output tensors of the compute declaration,
+ * a list of all operations in the DAG as well as static analysis results for the DAG (e.g. the
+ * total float operation count, consumer/producer relations of each operation stage, whether an
+ * operation stage should be tiled/compute inlined ...). These analyses can help the search policy
+ * to make decisions during search process.
+ * ComputeDAG is also responsible for the interaction between TVM Auto-scheduler `LoopState` and
+ * TVM schedule (e.g. applying the `LoopState` transform steps to TVM schedule, providing
+ * `LoopState` with extra information got from TVM schedule ...).
+ */
+
+#ifndef TVM_AUTO_SCHEDULER_COMPUTE_DAG_H_
+#define TVM_AUTO_SCHEDULER_COMPUTE_DAG_H_
+
+#include <tvm/auto_scheduler/loop_state.h>
+#include <tvm/te/schedule.h>
+
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+namespace tvm {
+namespace auto_scheduler {
+
+/*! \brief Static analysis result for a ComputeDAG */
+class AccessAnalyzerNode : public Object {
+ public:
+  template <class T>
+  using OperationMap = std::unordered_map<te::Operation, T, ObjectHash, ObjectEqual>;
+
+  /*! \brief Map an operation to all operations it reads from.
+   * For each operation pair, use a two-dimentional array to multiple multi-dimentional accesses*/
+  OperationMap<OperationMap<std::vector<std::vector<PrimExpr>>>> read_from;
+  /*! \brief Map an operation to all operations it is read by.
+   * For each operation pair, use a two-dimentional array to multiple multi-dimentional accesses*/
+  OperationMap<OperationMap<std::vector<std::vector<PrimExpr>>>> read_by;
+  /*! \brief Store the number of common outer iterators for operation pairs that have
+   * read-write relations. */
+  OperationMap<OperationMap<int>> num_common_outer_iterators;
+  /*! \brief Store whether the operation is injective */
+  OperationMap<bool> is_injective;
+  /*! \brief Store whether the operation is strictly-inlineable */
+  OperationMap<bool> is_strict_inlineable;
+  /*! \brief Store whether the operation needs multi-level tiling */
+  OperationMap<bool> needs_multi_level_tiling;
+  /*! \brief Store whether the operation is an output operation */
+  OperationMap<bool> is_output;
+  /*! \brief Store the topological order of operations */
+  Array<te::Operation> ops_topo_order;

Review comment:
       They are the same. I store it in AccessAnalyzer because it is used first here.
   In the constructor of ComputeDAG, it copies this `ops`




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] jcf94 commented on pull request #6103: [Ansor][AutoTVM v2.0] Phase 1: Access Analyzer

Posted by GitBox <gi...@apache.org>.
jcf94 commented on pull request #6103:
URL: https://github.com/apache/incubator-tvm/pull/6103#issuecomment-661922011


   > This PR is ready for review. But It has some dependency on #6073. I will make it pass the ci test after #6073 is merged
   
   #6073 's macOS build failure seems caused by some infrastructure error, I think it's ready to merge? I'll continue to upstream the left steps after #6073 .


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] tqchen commented on pull request #6103: [Ansor][AutoTVM v2.0] Phase 1: Access Analyzer

Posted by GitBox <gi...@apache.org>.
tqchen commented on pull request #6103:
URL: https://github.com/apache/incubator-tvm/pull/6103#issuecomment-662541150


   cc @jwfromm @mbrookhart @jroesch 


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org