Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/08/18 09:31:51 UTC

[GitHub] [incubator-tvm] minminsun opened a new pull request #6297: [Ansor][AutoTVM v2.0] Phase 2: Layout Rewrite in AutoScheduler

minminsun opened a new pull request #6297:
URL: https://github.com/apache/incubator-tvm/pull/6297


   For full upstream plan, see [Ansor RFC](https://discuss.tvm.ai/t/rfc-ansor-an-auto-scheduler-for-tvm-autotvm-v2-0/7005/32).
   
   In this PR, we enable AutoScheduler to rewrite the layout of placeholders so that it best fits the loop nest of the candidate schedule that will be applied to the ComputeDAG.
   
   Note that the functionality in this PR is only for performance evaluation with layout rewrite. An end-to-end layout-rewrite solution requires close cooperation with Relay passes, which will be provided in future PRs.
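   
   As a concrete illustration (a minimal sketch, not from this PR; the shapes and the tiling factor 16 are hypothetical), a layout-free weight placeholder with logical shape [K][N] may be physically stored as [N/16][K][16] so that the innermost loop of the chosen schedule reads contiguous elements:
   
   ```
   #include <vector>
   
   int main() {
     const int K = 8, N = 32, TILE = 16;
     std::vector<float> w(K * N, 1.0f);
     std::vector<float> w_rewritten(K * N);
     for (int k = 0; k < K; ++k) {
       for (int n = 0; n < N; ++n) {
         // The rewrite moves element w[k][n] to w_rewritten[n / TILE][k][n % TILE].
         w_rewritten[((n / TILE) * K + k) * TILE + (n % TILE)] = w[k * N + n];
       }
     }
     return 0;
   }
   ```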
   
   CC @merrymercy @jcf94 @FrozenGene @comaniac @tqchen 
   


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] minminsun commented on a change in pull request #6297: [Ansor][AutoTVM v2.0] Phase 2: Layout Rewrite in AutoScheduler

Posted by GitBox <gi...@apache.org>.
minminsun commented on a change in pull request #6297:
URL: https://github.com/apache/incubator-tvm/pull/6297#discussion_r477112406



##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -665,9 +666,349 @@ ComputeDAG::ComputeDAG(Array<te::Tensor> tensors) {
   data_ = std::move(node);
 }
 
+/*!
+ * \brief utility function for kernel_layout_transform
+ */
+inline void parse_kernel_layout(const String& layout, Array<PrimExpr>* shape,
+                                std::vector<std::string>* axes) {
+  int32_t factor = 0;
+  std::string axis = "";
+  for (char c : std::string(layout)) {
+    if (c >= 'A' && c <= 'z') {
+      axis += c;
+      if (factor != 0) {
+        shape->push_back(factor);
+        factor = 0;
+      }
+    } else if (c >= '0' && c <= '9') {
+      factor = factor * 10 + c - '0';
+      if (!axis.empty()) {
+        axes->push_back(axis);
+        axis = "";
+      }
+    } else {
+      LOG(FATAL) << "Invalid layout " << layout;
+    }
+  }
+  if (!axis.empty()) {
+    axes->push_back(axis);
+  }
+}
+
+std::string BaseName(const std::string& str) { return str.substr(0, str.rfind("_")); }
+
+class IndexRewriter : public StmtExprMutator {
+ public:
+  IndexRewriter(const te::Operation& placeholder_op, const std::string& new_layout)
+      : placeholder_op_(placeholder_op), new_layout_(new_layout) {}
+
+  PrimExpr Rewrite(PrimExpr expr) { return this->VisitExpr(expr); }
+
+  PrimExpr VisitExpr_(const ProducerLoadNode* op) final {
+    te::Tensor t = Downcast<te::Tensor>(op->producer);
+    if (t->op == placeholder_op_) {
+      Array<PrimExpr> new_shape;
+      std::vector<std::string> new_names;
+      parse_kernel_layout(new_layout_, &new_shape, &new_names);
+      std::unordered_map<std::string, PrimExpr> name_to_arg;
+      for (const auto& arg : op->indices) {
+        std::string axis_name;
+        if (const auto* pimm = arg.as<IntImmNode>()) {
+          CHECK_EQ(pimm->value, 0);
+          axis_name = "IntImm";
+        } else {
+          axis_name = BaseName(CleanName(Downcast<Var>(arg)->name_hint));
+          CHECK_EQ(name_to_arg.count(axis_name), 0);
+          name_to_arg[axis_name] = arg;
+        }
+      }
+
+      std::unordered_map<std::string, PrimExpr> div_factors;
+      std::vector<PrimExpr> r_new_args;
+      for (int i = new_names.size() - 1; i >= 0; --i) {
+        auto ori_iter_name = new_names[i];
+        auto name_it = name_to_arg.find(ori_iter_name);
+        CHECK(name_it != name_to_arg.end());
+        PrimExpr ori_arg = name_it->second;
+
+        PrimExpr mod_factor = new_shape[i];
+
+        PrimExpr div_factor = 1;
+        if (div_factors.count(ori_iter_name)) {
+          div_factor = div_factors[ori_iter_name];
+        }
+        div_factors[ori_iter_name] = div_factor * new_shape[i];
+
+        PrimExpr new_arg = indexmod(indexdiv(ori_arg, div_factor), mod_factor);
+
+        r_new_args.push_back(new_arg);
+      }
+
+      Array<PrimExpr> new_args(std::make_move_iterator(r_new_args.rbegin()),
+                               std::make_move_iterator(r_new_args.rend()));
+      return ProducerLoad(op->producer, new_args);
+    }
+    return GetRef<PrimExpr>(op);
+  }
+
+ private:
+  const te::Operation& placeholder_op_;
+  const std::string& new_layout_;
+};
+
+std::string get_ori_layout(std::set<std::string>* placeholder_axis_names, const te::Operation& op,
+                           const te::Tensor& placeholder) {
+  ReadAccessExtractor extractor;
+  for (const auto& exp : op.as<te::ComputeOpNode>()->body) {
+    extractor.Extract(exp);
+  }
+
+  std::ostringstream os;
+  uint i = 0;
+  const auto& placeholder_op = placeholder->op;
+  CHECK_GT(extractor.read_access.count(placeholder_op), 0);
+  for (const auto& ev : extractor.read_access[placeholder_op]) {
+    for (const auto& e : ev) {
+      std::string axis_name;
+      if (const auto* pimm = e.as<IntImmNode>()) {
+        CHECK_EQ(pimm->value, 0);
+        axis_name = "IntImm";
+      } else {
+        axis_name = BaseName(CleanName(Downcast<Var>(e)->name_hint));
+      }
+
+      placeholder_axis_names->insert(axis_name);
+      os << placeholder->shape[i++] << axis_name;
+    }
+  }
+
+  CHECK_EQ(placeholder_axis_names->size(), placeholder->shape.size());
+  std::string ori_layout = os.str();
+  os.str("");
+  // TODO(minmin): uncomment this line for relay integration
+  // ::tvm::relay::KernelLayoutTransformer::global_ori_layouts_queue.push_back(ori_layout);
+  return ori_layout;
+}
+
+std::string get_new_layout(Array<PrimExpr>* new_shape, const State& state, const int stage_id,
+                           const Stage& stage, const te::Operation& op,
+                           const te::Tensor& placeholder,
+                           const std::set<std::string>& placeholder_axis_names) {
+  std::ostringstream os;
+  Array<Iterator> stage_iters;
+
+  auto attach_it = state->attach_map->stage_to_attach_iter.find(stage_id);
+  int attach_pos = -1;
+  size_t iters_before_attach = 0;
+  if (attach_it != state->attach_map->stage_to_attach_iter.end()) {
+    auto attach = attach_it->second;
+    const auto& attach_stage = state->stages[attach.first];
+    attach_pos = attach.second;
+    stage_iters.insert(stage_iters.end(), attach_stage->iters.begin(),
+                       attach_stage->iters.begin() + attach_pos + 1);
+  }
+
+  stage_iters.insert(stage_iters.end(), stage->iters.begin(), stage->iters.end());
+
+  std::vector<Iterator> iters;
+  for (size_t i = 0; i < stage_iters.size(); ++i) {
+    const auto& iter = stage_iters[i];
+    if (iter->ori_iters.empty()) {
+      iters.push_back(iter);
+    } else {
+      for (const Iterator& ori_iter : iter->ori_iters) {
+        iters.push_back(ori_iter);
+      }
+    }
+    if (static_cast<int>(i) == attach_pos) {
+      iters_before_attach = iters.size();
+    }
+  }
+
+  std::vector<std::string> new_names;
+  std::vector<std::string> new_axis_names;
+  for (const Iterator& iter : iters) {
+    std::set<std::string> ori_iter_names;
+    ExtractOriginalIterators(iter->name, &ori_iter_names);
+    // fused iters have been replaced with iter->ori_iters.
+    // So there should be only one ori iter name extracted from iter->name.
+    CHECK_EQ(ori_iter_names.size(), 1);
+    auto ori_iter_name = BaseName(*ori_iter_names.begin());
+    new_axis_names.push_back(ori_iter_name);
+  }
+  for (size_t i = 0; i < new_axis_names.size(); ++i) {
+    auto iter = iters[i];
+    std::string ori_iter_name;
+    if (i < iters_before_attach) {
+      ori_iter_name = new_axis_names[i + iters_before_attach];
+    } else {
+      ori_iter_name = new_axis_names[i];
+    }
+    if (placeholder_axis_names.count(ori_iter_name)) {
+      os << iter->range->extent << ori_iter_name;
+      new_names.push_back(ori_iter_name);
+      new_shape->push_back(iter->range->extent);
+    }
+  }
+  std::string new_layout = os.str();
+  os.str("");
+  // TODO(minmin): uncomment this line for relay integration
+  // ::tvm::relay::KernelLayoutTransformer::global_new_layouts_queue.push_back(new_layout);
+  return new_layout;
+}
+
+void ComputeDAG::RewriteLayout(const Array<Step>& transform_steps) {
+  ComputeDAGNode* pdag = this->CopyOnWrite();
+  auto node = make_object<StateNode>();
+  node->transform_steps = transform_steps;
+  node->concrete = true;
+  const State& state = InferBound(State(node));
+  OperationSet handled_ops;
+  int stage_id = -1;
+  for (const auto& stage : state->stages) {
+    stage_id += 1;
+    const te::Operation& op = stage->op;
+    if (op->IsInstance<te::ComputeOpNode>()) {
+      const Map<String, ObjectRef>& attrs = op->attrs;
+      if (attrs.count(layout_free_placeholders_key)) {

Review comment:
       Done.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] minminsun commented on pull request #6297: [Ansor][AutoTVM v2.0] Phase 2: Layout Rewrite in AutoScheduler

Posted by GitBox <gi...@apache.org>.
minminsun commented on pull request #6297:
URL: https://github.com/apache/incubator-tvm/pull/6297#issuecomment-689347497


   > @minminsun Please fix some final style comments and we can merge this
   
   Done.


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] merrymercy merged pull request #6297: [Ansor][AutoTVM v2.0] Phase 2: Layout Rewrite in AutoScheduler

Posted by GitBox <gi...@apache.org>.
merrymercy merged pull request #6297:
URL: https://github.com/apache/incubator-tvm/pull/6297


   


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] comaniac commented on a change in pull request #6297: [Ansor][AutoTVM v2.0] Phase 2: Layout Rewrite in AutoScheduler

Posted by GitBox <gi...@apache.org>.
comaniac commented on a change in pull request #6297:
URL: https://github.com/apache/incubator-tvm/pull/6297#discussion_r472339009



##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -40,6 +40,7 @@
 #include <vector>
 
 #include "../arith/pattern_match.h"
+#include "search_policy/utils.h"

Review comment:
       - All utility functions should be moved to `utils.h`.
   - All the function names should follow the C++ naming convention. For example, `ParseKernelLayout` instead of `parse_kernel_layout`.

##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -665,9 +666,349 @@ ComputeDAG::ComputeDAG(Array<te::Tensor> tensors) {
   data_ = std::move(node);
 }
 
+/*!
+ * \brief utility function for kernel_layout_transform

Review comment:
       Provide more comments in this function to help future maintenance.
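   For example (a minimal, self-contained sketch, not the PR code; the layout string "4n16c" is hypothetical), the parser turns a layout string of alternating extent/axis pairs into `shape = [4, 16]` and `axes = ["n", "c"]`:
    ```
    #include <cctype>
    #include <iostream>
    #include <string>
    #include <vector>

    int main() {
      std::string layout = "4n16c";  // hypothetical sample layout string
      std::vector<int> shape;
      std::vector<std::string> axes;
      int factor = 0;
      std::string axis;
      for (char c : layout) {
        if (std::isdigit(static_cast<unsigned char>(c))) {
          // A digit closes the preceding axis name, if any.
          if (!axis.empty()) { axes.push_back(axis); axis.clear(); }
          factor = factor * 10 + (c - '0');
        } else {
          // A letter closes the preceding extent, if any.
          if (factor != 0) { shape.push_back(factor); factor = 0; }
          axis += c;
        }
      }
      if (!axis.empty()) axes.push_back(axis);
      for (size_t i = 0; i < shape.size(); ++i) {
        std::cout << shape[i] << axes[i] << "\n";  // prints "4n" then "16c"
      }
      return 0;
    }
    ```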

##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -665,9 +666,349 @@ ComputeDAG::ComputeDAG(Array<te::Tensor> tensors) {
   data_ = std::move(node);
 }
 
+/*!
+ * \brief utility function for kernel_layout_transform
+ */
+inline void parse_kernel_layout(const String& layout, Array<PrimExpr>* shape,
+                                std::vector<std::string>* axes) {
+  int32_t factor = 0;
+  std::string axis = "";
+  for (char c : std::string(layout)) {
+    if (c >= 'A' && c <= 'z') {
+      axis += c;
+      if (factor != 0) {
+        shape->push_back(factor);
+        factor = 0;
+      }
+    } else if (c >= '0' && c <= '9') {
+      factor = factor * 10 + c - '0';
+      if (!axis.empty()) {
+        axes->push_back(axis);
+        axis = "";
+      }
+    } else {
+      LOG(FATAL) << "Invalid layout " << layout;
+    }
+  }
+  if (!axis.empty()) {
+    axes->push_back(axis);
+  }
+}
+
+std::string BaseName(const std::string& str) { return str.substr(0, str.rfind("_")); }

Review comment:
       - Can we inline this function?
   - `BaseName` is too general.

##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -665,9 +666,349 @@ ComputeDAG::ComputeDAG(Array<te::Tensor> tensors) {
   data_ = std::move(node);
 }
 
+/*!
+ * \brief utility function for kernel_layout_transform
+ */
+inline void parse_kernel_layout(const String& layout, Array<PrimExpr>* shape,
+                                std::vector<std::string>* axes) {
+  int32_t factor = 0;
+  std::string axis = "";
+  for (char c : std::string(layout)) {
+    if (c >= 'A' && c <= 'z') {
+      axis += c;
+      if (factor != 0) {
+        shape->push_back(factor);
+        factor = 0;
+      }
+    } else if (c >= '0' && c <= '9') {
+      factor = factor * 10 + c - '0';
+      if (!axis.empty()) {
+        axes->push_back(axis);
+        axis = "";
+      }
+    } else {
+      LOG(FATAL) << "Invalid layout " << layout;
+    }
+  }
+  if (!axis.empty()) {
+    axes->push_back(axis);
+  }
+}
+
+std::string BaseName(const std::string& str) { return str.substr(0, str.rfind("_")); }
+
+class IndexRewriter : public StmtExprMutator {
+ public:
+  IndexRewriter(const te::Operation& placeholder_op, const std::string& new_layout)
+      : placeholder_op_(placeholder_op), new_layout_(new_layout) {}
+
+  PrimExpr Rewrite(PrimExpr expr) { return this->VisitExpr(expr); }
+
+  PrimExpr VisitExpr_(const ProducerLoadNode* op) final {
+    te::Tensor t = Downcast<te::Tensor>(op->producer);
+    if (t->op == placeholder_op_) {
+      Array<PrimExpr> new_shape;
+      std::vector<std::string> new_names;
+      parse_kernel_layout(new_layout_, &new_shape, &new_names);
+      std::unordered_map<std::string, PrimExpr> name_to_arg;
+      for (const auto& arg : op->indices) {
+        std::string axis_name;
+        if (const auto* pimm = arg.as<IntImmNode>()) {
+          CHECK_EQ(pimm->value, 0);
+          axis_name = "IntImm";
+        } else {
+          axis_name = BaseName(CleanName(Downcast<Var>(arg)->name_hint));
+          CHECK_EQ(name_to_arg.count(axis_name), 0);
+          name_to_arg[axis_name] = arg;
+        }
+      }
+
+      std::unordered_map<std::string, PrimExpr> div_factors;
+      std::vector<PrimExpr> r_new_args;
+      for (int i = new_names.size() - 1; i >= 0; --i) {
+        auto ori_iter_name = new_names[i];
+        auto name_it = name_to_arg.find(ori_iter_name);
+        CHECK(name_it != name_to_arg.end());
+        PrimExpr ori_arg = name_it->second;
+
+        PrimExpr mod_factor = new_shape[i];
+
+        PrimExpr div_factor = 1;
+        if (div_factors.count(ori_iter_name)) {
+          div_factor = div_factors[ori_iter_name];
+        }
+        div_factors[ori_iter_name] = div_factor * new_shape[i];
+
+        PrimExpr new_arg = indexmod(indexdiv(ori_arg, div_factor), mod_factor);
+
+        r_new_args.push_back(new_arg);
+      }
+
+      Array<PrimExpr> new_args(std::make_move_iterator(r_new_args.rbegin()),
+                               std::make_move_iterator(r_new_args.rend()));
+      return ProducerLoad(op->producer, new_args);
+    }
+    return GetRef<PrimExpr>(op);
+  }
+
+ private:
+  const te::Operation& placeholder_op_;
+  const std::string& new_layout_;
+};
+
+std::string get_ori_layout(std::set<std::string>* placeholder_axis_names, const te::Operation& op,
+                           const te::Tensor& placeholder) {
+  ReadAccessExtractor extractor;
+  for (const auto& exp : op.as<te::ComputeOpNode>()->body) {
+    extractor.Extract(exp);
+  }
+
+  std::ostringstream os;
+  uint i = 0;

Review comment:
       s/uint/size_t/

##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -665,9 +666,349 @@ ComputeDAG::ComputeDAG(Array<te::Tensor> tensors) {
   data_ = std::move(node);
 }
 
+/*!
+ * \brief utility function for kernel_layout_transform
+ */
+inline void parse_kernel_layout(const String& layout, Array<PrimExpr>* shape,
+                                std::vector<std::string>* axes) {
+  int32_t factor = 0;
+  std::string axis = "";
+  for (char c : std::string(layout)) {
+    if (c >= 'A' && c <= 'z') {
+      axis += c;
+      if (factor != 0) {
+        shape->push_back(factor);
+        factor = 0;
+      }
+    } else if (c >= '0' && c <= '9') {
+      factor = factor * 10 + c - '0';
+      if (!axis.empty()) {
+        axes->push_back(axis);
+        axis = "";
+      }
+    } else {
+      LOG(FATAL) << "Invalid layout " << layout;
+    }
+  }
+  if (!axis.empty()) {
+    axes->push_back(axis);
+  }
+}
+
+std::string BaseName(const std::string& str) { return str.substr(0, str.rfind("_")); }
+
+class IndexRewriter : public StmtExprMutator {
+ public:
+  IndexRewriter(const te::Operation& placeholder_op, const std::string& new_layout)
+      : placeholder_op_(placeholder_op), new_layout_(new_layout) {}
+
+  PrimExpr Rewrite(PrimExpr expr) { return this->VisitExpr(expr); }
+
+  PrimExpr VisitExpr_(const ProducerLoadNode* op) final {

Review comment:
       This function needs more comments.
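   For instance (a minimal, self-contained sketch of the index arithmetic only, not the PR code; the layout "2k16k" and its extents are hypothetical): if the original access is `B[k]` and the new layout splits `k` into an outer extent 2 and an inner extent 16, walking the layout from right to left with an accumulated `div_factor` rewrites the access to `B[(k / 16) % 2][k % 16]`:
    ```
    #include <cassert>

    int main() {
      const int outer = 2, inner = 16;           // new_shape parsed from "2k16k"
      for (int k = 0; k < outer * inner; ++k) {
        int div_factor = 1;
        int arg_inner = (k / div_factor) % inner;  // rightmost entry ("16k")
        div_factor *= inner;
        int arg_outer = (k / div_factor) % outer;  // leftmost entry ("2k")
        // The rewritten access B[arg_outer][arg_inner] addresses the same
        // element as the original flat access B[k].
        assert(arg_outer * inner + arg_inner == k);
      }
      return 0;
    }
    ```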

##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -665,9 +666,349 @@ ComputeDAG::ComputeDAG(Array<te::Tensor> tensors) {
   data_ = std::move(node);
 }
 
+/*!
+ * \brief utility function for kernel_layout_transform
+ */
+inline void parse_kernel_layout(const String& layout, Array<PrimExpr>* shape,
+                                std::vector<std::string>* axes) {
+  int32_t factor = 0;
+  std::string axis = "";
+  for (char c : std::string(layout)) {
+    if (c >= 'A' && c <= 'z') {
+      axis += c;
+      if (factor != 0) {
+        shape->push_back(factor);
+        factor = 0;
+      }
+    } else if (c >= '0' && c <= '9') {
+      factor = factor * 10 + c - '0';
+      if (!axis.empty()) {
+        axes->push_back(axis);
+        axis = "";
+      }
+    } else {
+      LOG(FATAL) << "Invalid layout " << layout;
+    }
+  }
+  if (!axis.empty()) {
+    axes->push_back(axis);
+  }
+}
+
+std::string BaseName(const std::string& str) { return str.substr(0, str.rfind("_")); }
+
+class IndexRewriter : public StmtExprMutator {
+ public:
+  IndexRewriter(const te::Operation& placeholder_op, const std::string& new_layout)
+      : placeholder_op_(placeholder_op), new_layout_(new_layout) {}
+
+  PrimExpr Rewrite(PrimExpr expr) { return this->VisitExpr(expr); }
+
+  PrimExpr VisitExpr_(const ProducerLoadNode* op) final {
+    te::Tensor t = Downcast<te::Tensor>(op->producer);
+    if (t->op == placeholder_op_) {
+      Array<PrimExpr> new_shape;
+      std::vector<std::string> new_names;
+      parse_kernel_layout(new_layout_, &new_shape, &new_names);
+      std::unordered_map<std::string, PrimExpr> name_to_arg;
+      for (const auto& arg : op->indices) {
+        std::string axis_name;
+        if (const auto* pimm = arg.as<IntImmNode>()) {
+          CHECK_EQ(pimm->value, 0);
+          axis_name = "IntImm";
+        } else {
+          axis_name = BaseName(CleanName(Downcast<Var>(arg)->name_hint));
+          CHECK_EQ(name_to_arg.count(axis_name), 0);
+          name_to_arg[axis_name] = arg;
+        }
+      }
+
+      std::unordered_map<std::string, PrimExpr> div_factors;
+      std::vector<PrimExpr> r_new_args;
+      for (int i = new_names.size() - 1; i >= 0; --i) {
+        auto ori_iter_name = new_names[i];
+        auto name_it = name_to_arg.find(ori_iter_name);
+        CHECK(name_it != name_to_arg.end());
+        PrimExpr ori_arg = name_it->second;
+
+        PrimExpr mod_factor = new_shape[i];
+
+        PrimExpr div_factor = 1;
+        if (div_factors.count(ori_iter_name)) {
+          div_factor = div_factors[ori_iter_name];
+        }
+        div_factors[ori_iter_name] = div_factor * new_shape[i];
+
+        PrimExpr new_arg = indexmod(indexdiv(ori_arg, div_factor), mod_factor);
+
+        r_new_args.push_back(new_arg);
+      }
+
+      Array<PrimExpr> new_args(std::make_move_iterator(r_new_args.rbegin()),
+                               std::make_move_iterator(r_new_args.rend()));
+      return ProducerLoad(op->producer, new_args);
+    }
+    return GetRef<PrimExpr>(op);
+  }
+
+ private:
+  const te::Operation& placeholder_op_;
+  const std::string& new_layout_;
+};
+
+std::string get_ori_layout(std::set<std::string>* placeholder_axis_names, const te::Operation& op,
+                           const te::Tensor& placeholder) {
+  ReadAccessExtractor extractor;
+  for (const auto& exp : op.as<te::ComputeOpNode>()->body) {
+    extractor.Extract(exp);
+  }
+
+  std::ostringstream os;
+  uint i = 0;
+  const auto& placeholder_op = placeholder->op;
+  CHECK_GT(extractor.read_access.count(placeholder_op), 0);
+  for (const auto& ev : extractor.read_access[placeholder_op]) {
+    for (const auto& e : ev) {
+      std::string axis_name;
+      if (const auto* pimm = e.as<IntImmNode>()) {
+        CHECK_EQ(pimm->value, 0);
+        axis_name = "IntImm";
+      } else {
+        axis_name = BaseName(CleanName(Downcast<Var>(e)->name_hint));
+      }
+
+      placeholder_axis_names->insert(axis_name);
+      os << placeholder->shape[i++] << axis_name;
+    }
+  }
+
+  CHECK_EQ(placeholder_axis_names->size(), placeholder->shape.size());
+  std::string ori_layout = os.str();
+  os.str("");
+  // TODO(minmin): uncomment this line for relay integration
+  // ::tvm::relay::KernelLayoutTransformer::global_ori_layouts_queue.push_back(ori_layout);
+  return ori_layout;
+}
+
+std::string get_new_layout(Array<PrimExpr>* new_shape, const State& state, const int stage_id,
+                           const Stage& stage, const te::Operation& op,
+                           const te::Tensor& placeholder,
+                           const std::set<std::string>& placeholder_axis_names) {
+  std::ostringstream os;
+  Array<Iterator> stage_iters;
+
+  auto attach_it = state->attach_map->stage_to_attach_iter.find(stage_id);
+  int attach_pos = -1;
+  size_t iters_before_attach = 0;
+  if (attach_it != state->attach_map->stage_to_attach_iter.end()) {
+    auto attach = attach_it->second;
+    const auto& attach_stage = state->stages[attach.first];
+    attach_pos = attach.second;
+    stage_iters.insert(stage_iters.end(), attach_stage->iters.begin(),
+                       attach_stage->iters.begin() + attach_pos + 1);
+  }
+
+  stage_iters.insert(stage_iters.end(), stage->iters.begin(), stage->iters.end());
+
+  std::vector<Iterator> iters;
+  for (size_t i = 0; i < stage_iters.size(); ++i) {
+    const auto& iter = stage_iters[i];
+    if (iter->ori_iters.empty()) {
+      iters.push_back(iter);
+    } else {
+      for (const Iterator& ori_iter : iter->ori_iters) {
+        iters.push_back(ori_iter);
+      }
+    }
+    if (static_cast<int>(i) == attach_pos) {
+      iters_before_attach = iters.size();
+    }
+  }
+
+  std::vector<std::string> new_names;
+  std::vector<std::string> new_axis_names;
+  for (const Iterator& iter : iters) {
+    std::set<std::string> ori_iter_names;
+    ExtractOriginalIterators(iter->name, &ori_iter_names);
+    // fused iters have been replaced with iter->ori_iters.
+    // So there should be only one ori iter name extracted from iter->name.
+    CHECK_EQ(ori_iter_names.size(), 1);
+    auto ori_iter_name = BaseName(*ori_iter_names.begin());
+    new_axis_names.push_back(ori_iter_name);
+  }
+  for (size_t i = 0; i < new_axis_names.size(); ++i) {
+    auto iter = iters[i];
+    std::string ori_iter_name;
+    if (i < iters_before_attach) {
+      ori_iter_name = new_axis_names[i + iters_before_attach];
+    } else {
+      ori_iter_name = new_axis_names[i];
+    }
+    if (placeholder_axis_names.count(ori_iter_name)) {
+      os << iter->range->extent << ori_iter_name;
+      new_names.push_back(ori_iter_name);
+      new_shape->push_back(iter->range->extent);
+    }
+  }
+  std::string new_layout = os.str();
+  os.str("");
+  // TODO(minmin): uncomment this line for relay integration
+  // ::tvm::relay::KernelLayoutTransformer::global_new_layouts_queue.push_back(new_layout);
+  return new_layout;
+}
+
+void ComputeDAG::RewriteLayout(const Array<Step>& transform_steps) {
+  ComputeDAGNode* pdag = this->CopyOnWrite();
+  auto node = make_object<StateNode>();
+  node->transform_steps = transform_steps;
+  node->concrete = true;
+  const State& state = InferBound(State(node));
+  OperationSet handled_ops;
+  int stage_id = -1;
+  for (const auto& stage : state->stages) {
+    stage_id += 1;

Review comment:
       It'd be better to use `for (size_t stage_id = 0; stage_id < state->stages.size(); ++stage_id)` if you need the ID.

##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -665,9 +666,349 @@ ComputeDAG::ComputeDAG(Array<te::Tensor> tensors) {
   data_ = std::move(node);
 }
 
+/*!
+ * \brief utility function for kernel_layout_transform
+ */
+inline void parse_kernel_layout(const String& layout, Array<PrimExpr>* shape,
+                                std::vector<std::string>* axes) {
+  int32_t factor = 0;
+  std::string axis = "";
+  for (char c : std::string(layout)) {
+    if (c >= 'A' && c <= 'z') {
+      axis += c;
+      if (factor != 0) {
+        shape->push_back(factor);
+        factor = 0;
+      }
+    } else if (c >= '0' && c <= '9') {
+      factor = factor * 10 + c - '0';
+      if (!axis.empty()) {
+        axes->push_back(axis);
+        axis = "";
+      }
+    } else {
+      LOG(FATAL) << "Invalid layout " << layout;
+    }
+  }
+  if (!axis.empty()) {
+    axes->push_back(axis);
+  }
+}
+
+std::string BaseName(const std::string& str) { return str.substr(0, str.rfind("_")); }
+
+class IndexRewriter : public StmtExprMutator {
+ public:
+  IndexRewriter(const te::Operation& placeholder_op, const std::string& new_layout)
+      : placeholder_op_(placeholder_op), new_layout_(new_layout) {}
+
+  PrimExpr Rewrite(PrimExpr expr) { return this->VisitExpr(expr); }
+
+  PrimExpr VisitExpr_(const ProducerLoadNode* op) final {
+    te::Tensor t = Downcast<te::Tensor>(op->producer);
+    if (t->op == placeholder_op_) {
+      Array<PrimExpr> new_shape;
+      std::vector<std::string> new_names;
+      parse_kernel_layout(new_layout_, &new_shape, &new_names);

Review comment:
       It seems like `new_layout_` is fixed after the rewriter is constructed. Accordingly, `new_shape` and `new_names` should also be fixed. IMHO, we should be able to figure out the new shape and names in the constructor as well to make the logic more clear.
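   A minimal sketch of that refactoring (reusing the helpers and member names shown in this diff, so it is only illustrative):
    ```
    class IndexRewriter : public StmtExprMutator {
     public:
      IndexRewriter(const te::Operation& placeholder_op, const std::string& new_layout)
          : placeholder_op_(placeholder_op), new_layout_(new_layout) {
        // Parse the fixed layout once here instead of on every ProducerLoad visit.
        parse_kernel_layout(new_layout_, &new_shape_, &new_names_);
      }

      // VisitExpr_ would then read new_shape_ / new_names_ directly.

     private:
      const te::Operation& placeholder_op_;
      std::string new_layout_;
      Array<PrimExpr> new_shape_;
      std::vector<std::string> new_names_;
    };
    ```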

##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -665,9 +666,349 @@ ComputeDAG::ComputeDAG(Array<te::Tensor> tensors) {
   data_ = std::move(node);
 }
 
+/*!
+ * \brief utility function for kernel_layout_transform
+ */
+inline void parse_kernel_layout(const String& layout, Array<PrimExpr>* shape,
+                                std::vector<std::string>* axes) {
+  int32_t factor = 0;
+  std::string axis = "";
+  for (char c : std::string(layout)) {
+    if (c >= 'A' && c <= 'z') {
+      axis += c;
+      if (factor != 0) {
+        shape->push_back(factor);
+        factor = 0;
+      }
+    } else if (c >= '0' && c <= '9') {
+      factor = factor * 10 + c - '0';
+      if (!axis.empty()) {
+        axes->push_back(axis);
+        axis = "";
+      }
+    } else {
+      LOG(FATAL) << "Invalid layout " << layout;
+    }
+  }
+  if (!axis.empty()) {
+    axes->push_back(axis);
+  }
+}
+
+std::string BaseName(const std::string& str) { return str.substr(0, str.rfind("_")); }
+
+class IndexRewriter : public StmtExprMutator {
+ public:
+  IndexRewriter(const te::Operation& placeholder_op, const std::string& new_layout)
+      : placeholder_op_(placeholder_op), new_layout_(new_layout) {}
+
+  PrimExpr Rewrite(PrimExpr expr) { return this->VisitExpr(expr); }
+
+  PrimExpr VisitExpr_(const ProducerLoadNode* op) final {
+    te::Tensor t = Downcast<te::Tensor>(op->producer);
+    if (t->op == placeholder_op_) {
+      Array<PrimExpr> new_shape;
+      std::vector<std::string> new_names;
+      parse_kernel_layout(new_layout_, &new_shape, &new_names);
+      std::unordered_map<std::string, PrimExpr> name_to_arg;
+      for (const auto& arg : op->indices) {
+        std::string axis_name;
+        if (const auto* pimm = arg.as<IntImmNode>()) {
+          CHECK_EQ(pimm->value, 0);
+          axis_name = "IntImm";
+        } else {
+          axis_name = BaseName(CleanName(Downcast<Var>(arg)->name_hint));
+          CHECK_EQ(name_to_arg.count(axis_name), 0);
+          name_to_arg[axis_name] = arg;
+        }
+      }
+
+      std::unordered_map<std::string, PrimExpr> div_factors;
+      std::vector<PrimExpr> r_new_args;
+      for (int i = new_names.size() - 1; i >= 0; --i) {
+        auto ori_iter_name = new_names[i];
+        auto name_it = name_to_arg.find(ori_iter_name);
+        CHECK(name_it != name_to_arg.end());
+        PrimExpr ori_arg = name_it->second;
+
+        PrimExpr mod_factor = new_shape[i];
+
+        PrimExpr div_factor = 1;
+        if (div_factors.count(ori_iter_name)) {
+          div_factor = div_factors[ori_iter_name];
+        }
+        div_factors[ori_iter_name] = div_factor * new_shape[i];
+
+        PrimExpr new_arg = indexmod(indexdiv(ori_arg, div_factor), mod_factor);
+
+        r_new_args.push_back(new_arg);
+      }
+
+      Array<PrimExpr> new_args(std::make_move_iterator(r_new_args.rbegin()),
+                               std::make_move_iterator(r_new_args.rend()));
+      return ProducerLoad(op->producer, new_args);
+    }
+    return GetRef<PrimExpr>(op);
+  }
+
+ private:
+  const te::Operation& placeholder_op_;
+  const std::string& new_layout_;
+};
+
+std::string get_ori_layout(std::set<std::string>* placeholder_axis_names, const te::Operation& op,

Review comment:
       Ditto: utility function, naming, and comments.

##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -665,9 +666,349 @@ ComputeDAG::ComputeDAG(Array<te::Tensor> tensors) {
   data_ = std::move(node);
 }
 
+/*!
+ * \brief utility function for kernel_layout_transform
+ */
+inline void parse_kernel_layout(const String& layout, Array<PrimExpr>* shape,
+                                std::vector<std::string>* axes) {
+  int32_t factor = 0;
+  std::string axis = "";
+  for (char c : std::string(layout)) {
+    if (c >= 'A' && c <= 'z') {
+      axis += c;
+      if (factor != 0) {
+        shape->push_back(factor);
+        factor = 0;
+      }
+    } else if (c >= '0' && c <= '9') {
+      factor = factor * 10 + c - '0';
+      if (!axis.empty()) {
+        axes->push_back(axis);
+        axis = "";
+      }
+    } else {
+      LOG(FATAL) << "Invalid layout " << layout;
+    }
+  }
+  if (!axis.empty()) {
+    axes->push_back(axis);
+  }
+}
+
+std::string BaseName(const std::string& str) { return str.substr(0, str.rfind("_")); }
+
+class IndexRewriter : public StmtExprMutator {
+ public:
+  IndexRewriter(const te::Operation& placeholder_op, const std::string& new_layout)
+      : placeholder_op_(placeholder_op), new_layout_(new_layout) {}
+
+  PrimExpr Rewrite(PrimExpr expr) { return this->VisitExpr(expr); }
+
+  PrimExpr VisitExpr_(const ProducerLoadNode* op) final {
+    te::Tensor t = Downcast<te::Tensor>(op->producer);
+    if (t->op == placeholder_op_) {
+      Array<PrimExpr> new_shape;
+      std::vector<std::string> new_names;
+      parse_kernel_layout(new_layout_, &new_shape, &new_names);
+      std::unordered_map<std::string, PrimExpr> name_to_arg;
+      for (const auto& arg : op->indices) {
+        std::string axis_name;
+        if (const auto* pimm = arg.as<IntImmNode>()) {
+          CHECK_EQ(pimm->value, 0);
+          axis_name = "IntImm";
+        } else {
+          axis_name = BaseName(CleanName(Downcast<Var>(arg)->name_hint));
+          CHECK_EQ(name_to_arg.count(axis_name), 0);
+          name_to_arg[axis_name] = arg;
+        }
+      }
+
+      std::unordered_map<std::string, PrimExpr> div_factors;
+      std::vector<PrimExpr> r_new_args;
+      for (int i = new_names.size() - 1; i >= 0; --i) {
+        auto ori_iter_name = new_names[i];
+        auto name_it = name_to_arg.find(ori_iter_name);
+        CHECK(name_it != name_to_arg.end());
+        PrimExpr ori_arg = name_it->second;
+
+        PrimExpr mod_factor = new_shape[i];
+
+        PrimExpr div_factor = 1;
+        if (div_factors.count(ori_iter_name)) {
+          div_factor = div_factors[ori_iter_name];
+        }
+        div_factors[ori_iter_name] = div_factor * new_shape[i];
+
+        PrimExpr new_arg = indexmod(indexdiv(ori_arg, div_factor), mod_factor);
+
+        r_new_args.push_back(new_arg);
+      }
+
+      Array<PrimExpr> new_args(std::make_move_iterator(r_new_args.rbegin()),
+                               std::make_move_iterator(r_new_args.rend()));
+      return ProducerLoad(op->producer, new_args);
+    }
+    return GetRef<PrimExpr>(op);
+  }
+
+ private:
+  const te::Operation& placeholder_op_;
+  const std::string& new_layout_;
+};
+
+std::string get_ori_layout(std::set<std::string>* placeholder_axis_names, const te::Operation& op,
+                           const te::Tensor& placeholder) {
+  ReadAccessExtractor extractor;
+  for (const auto& exp : op.as<te::ComputeOpNode>()->body) {
+    extractor.Extract(exp);
+  }
+
+  std::ostringstream os;
+  uint i = 0;
+  const auto& placeholder_op = placeholder->op;
+  CHECK_GT(extractor.read_access.count(placeholder_op), 0);
+  for (const auto& ev : extractor.read_access[placeholder_op]) {
+    for (const auto& e : ev) {
+      std::string axis_name;
+      if (const auto* pimm = e.as<IntImmNode>()) {
+        CHECK_EQ(pimm->value, 0);
+        axis_name = "IntImm";
+      } else {
+        axis_name = BaseName(CleanName(Downcast<Var>(e)->name_hint));
+      }
+
+      placeholder_axis_names->insert(axis_name);
+      os << placeholder->shape[i++] << axis_name;
+    }
+  }
+
+  CHECK_EQ(placeholder_axis_names->size(), placeholder->shape.size());
+  std::string ori_layout = os.str();
+  os.str("");
+  // TODO(minmin): uncomment this line for relay integration
+  // ::tvm::relay::KernelLayoutTransformer::global_ori_layouts_queue.push_back(ori_layout);
+  return ori_layout;
+}
+
+std::string get_new_layout(Array<PrimExpr>* new_shape, const State& state, const int stage_id,

Review comment:
       Ditto: utility function, naming, and comments.

##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -665,9 +666,349 @@ ComputeDAG::ComputeDAG(Array<te::Tensor> tensors) {
   data_ = std::move(node);
 }
 
+/*!
+ * \brief utility function for kernel_layout_transform
+ */
+inline void parse_kernel_layout(const String& layout, Array<PrimExpr>* shape,
+                                std::vector<std::string>* axes) {
+  int32_t factor = 0;
+  std::string axis = "";
+  for (char c : std::string(layout)) {
+    if (c >= 'A' && c <= 'z') {
+      axis += c;
+      if (factor != 0) {
+        shape->push_back(factor);
+        factor = 0;
+      }
+    } else if (c >= '0' && c <= '9') {
+      factor = factor * 10 + c - '0';
+      if (!axis.empty()) {
+        axes->push_back(axis);
+        axis = "";
+      }
+    } else {
+      LOG(FATAL) << "Invalid layout " << layout;
+    }
+  }
+  if (!axis.empty()) {
+    axes->push_back(axis);
+  }
+}
+
+std::string BaseName(const std::string& str) { return str.substr(0, str.rfind("_")); }
+
+class IndexRewriter : public StmtExprMutator {
+ public:
+  IndexRewriter(const te::Operation& placeholder_op, const std::string& new_layout)
+      : placeholder_op_(placeholder_op), new_layout_(new_layout) {}
+
+  PrimExpr Rewrite(PrimExpr expr) { return this->VisitExpr(expr); }
+
+  PrimExpr VisitExpr_(const ProducerLoadNode* op) final {
+    te::Tensor t = Downcast<te::Tensor>(op->producer);
+    if (t->op == placeholder_op_) {
+      Array<PrimExpr> new_shape;
+      std::vector<std::string> new_names;
+      parse_kernel_layout(new_layout_, &new_shape, &new_names);
+      std::unordered_map<std::string, PrimExpr> name_to_arg;
+      for (const auto& arg : op->indices) {
+        std::string axis_name;
+        if (const auto* pimm = arg.as<IntImmNode>()) {
+          CHECK_EQ(pimm->value, 0);
+          axis_name = "IntImm";
+        } else {
+          axis_name = BaseName(CleanName(Downcast<Var>(arg)->name_hint));
+          CHECK_EQ(name_to_arg.count(axis_name), 0);
+          name_to_arg[axis_name] = arg;
+        }
+      }
+
+      std::unordered_map<std::string, PrimExpr> div_factors;
+      std::vector<PrimExpr> r_new_args;
+      for (int i = new_names.size() - 1; i >= 0; --i) {
+        auto ori_iter_name = new_names[i];
+        auto name_it = name_to_arg.find(ori_iter_name);
+        CHECK(name_it != name_to_arg.end());
+        PrimExpr ori_arg = name_it->second;
+
+        PrimExpr mod_factor = new_shape[i];
+
+        PrimExpr div_factor = 1;
+        if (div_factors.count(ori_iter_name)) {
+          div_factor = div_factors[ori_iter_name];
+        }
+        div_factors[ori_iter_name] = div_factor * new_shape[i];
+
+        PrimExpr new_arg = indexmod(indexdiv(ori_arg, div_factor), mod_factor);
+
+        r_new_args.push_back(new_arg);
+      }
+
+      Array<PrimExpr> new_args(std::make_move_iterator(r_new_args.rbegin()),
+                               std::make_move_iterator(r_new_args.rend()));
+      return ProducerLoad(op->producer, new_args);
+    }
+    return GetRef<PrimExpr>(op);
+  }
+
+ private:
+  const te::Operation& placeholder_op_;
+  const std::string& new_layout_;
+};
+
+std::string get_ori_layout(std::set<std::string>* placeholder_axis_names, const te::Operation& op,
+                           const te::Tensor& placeholder) {
+  ReadAccessExtractor extractor;
+  for (const auto& exp : op.as<te::ComputeOpNode>()->body) {
+    extractor.Extract(exp);
+  }
+
+  std::ostringstream os;
+  uint i = 0;
+  const auto& placeholder_op = placeholder->op;
+  CHECK_GT(extractor.read_access.count(placeholder_op), 0);
+  for (const auto& ev : extractor.read_access[placeholder_op]) {
+    for (const auto& e : ev) {
+      std::string axis_name;
+      if (const auto* pimm = e.as<IntImmNode>()) {
+        CHECK_EQ(pimm->value, 0);
+        axis_name = "IntImm";
+      } else {
+        axis_name = BaseName(CleanName(Downcast<Var>(e)->name_hint));
+      }
+
+      placeholder_axis_names->insert(axis_name);
+      os << placeholder->shape[i++] << axis_name;
+    }
+  }
+
+  CHECK_EQ(placeholder_axis_names->size(), placeholder->shape.size());
+  std::string ori_layout = os.str();
+  os.str("");
+  // TODO(minmin): uncomment this line for relay integration
+  // ::tvm::relay::KernelLayoutTransformer::global_ori_layouts_queue.push_back(ori_layout);
+  return ori_layout;
+}
+
+std::string get_new_layout(Array<PrimExpr>* new_shape, const State& state, const int stage_id,
+                           const Stage& stage, const te::Operation& op,
+                           const te::Tensor& placeholder,
+                           const std::set<std::string>& placeholder_axis_names) {
+  std::ostringstream os;
+  Array<Iterator> stage_iters;
+
+  auto attach_it = state->attach_map->stage_to_attach_iter.find(stage_id);
+  int attach_pos = -1;
+  size_t iters_before_attach = 0;
+  if (attach_it != state->attach_map->stage_to_attach_iter.end()) {
+    auto attach = attach_it->second;
+    const auto& attach_stage = state->stages[attach.first];
+    attach_pos = attach.second;
+    stage_iters.insert(stage_iters.end(), attach_stage->iters.begin(),
+                       attach_stage->iters.begin() + attach_pos + 1);
+  }
+
+  stage_iters.insert(stage_iters.end(), stage->iters.begin(), stage->iters.end());
+
+  std::vector<Iterator> iters;
+  for (size_t i = 0; i < stage_iters.size(); ++i) {
+    const auto& iter = stage_iters[i];
+    if (iter->ori_iters.empty()) {
+      iters.push_back(iter);
+    } else {
+      for (const Iterator& ori_iter : iter->ori_iters) {
+        iters.push_back(ori_iter);
+      }
+    }
+    if (static_cast<int>(i) == attach_pos) {
+      iters_before_attach = iters.size();
+    }
+  }
+
+  std::vector<std::string> new_names;
+  std::vector<std::string> new_axis_names;
+  for (const Iterator& iter : iters) {
+    std::set<std::string> ori_iter_names;
+    ExtractOriginalIterators(iter->name, &ori_iter_names);
+    // fused iters have been replaced with iter->ori_iters.
+    // So there should be only one ori iter name extracted from iter->name.
+    CHECK_EQ(ori_iter_names.size(), 1);
+    auto ori_iter_name = BaseName(*ori_iter_names.begin());
+    new_axis_names.push_back(ori_iter_name);
+  }
+  for (size_t i = 0; i < new_axis_names.size(); ++i) {
+    auto iter = iters[i];
+    std::string ori_iter_name;
+    if (i < iters_before_attach) {
+      ori_iter_name = new_axis_names[i + iters_before_attach];
+    } else {
+      ori_iter_name = new_axis_names[i];
+    }
+    if (placeholder_axis_names.count(ori_iter_name)) {
+      os << iter->range->extent << ori_iter_name;
+      new_names.push_back(ori_iter_name);
+      new_shape->push_back(iter->range->extent);
+    }
+  }
+  std::string new_layout = os.str();
+  os.str("");
+  // TODO(minmin): uncomment this line for relay integration
+  // ::tvm::relay::KernelLayoutTransformer::global_new_layouts_queue.push_back(new_layout);
+  return new_layout;
+}
+
+void ComputeDAG::RewriteLayout(const Array<Step>& transform_steps) {
+  ComputeDAGNode* pdag = this->CopyOnWrite();
+  auto node = make_object<StateNode>();
+  node->transform_steps = transform_steps;
+  node->concrete = true;
+  const State& state = InferBound(State(node));
+  OperationSet handled_ops;
+  int stage_id = -1;
+  for (const auto& stage : state->stages) {
+    stage_id += 1;
+    const te::Operation& op = stage->op;
+    if (op->IsInstance<te::ComputeOpNode>()) {

Review comment:
       Since this statement is pretty long, I'd suggest
   ```
   if (!op->IsInstance<te::ComputeOpNode>()) {
     continue;
   }
   ```
   so that we can reduce an indent.

##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -665,9 +666,349 @@ ComputeDAG::ComputeDAG(Array<te::Tensor> tensors) {
   data_ = std::move(node);
 }
 
+/*!
+ * \brief utility function for kernel_layout_transform
+ */
+inline void parse_kernel_layout(const String& layout, Array<PrimExpr>* shape,
+                                std::vector<std::string>* axes) {
+  int32_t factor = 0;
+  std::string axis = "";
+  for (char c : std::string(layout)) {
+    if (c >= 'A' && c <= 'z') {
+      axis += c;
+      if (factor != 0) {
+        shape->push_back(factor);
+        factor = 0;
+      }
+    } else if (c >= '0' && c <= '9') {
+      factor = factor * 10 + c - '0';
+      if (!axis.empty()) {
+        axes->push_back(axis);
+        axis = "";
+      }
+    } else {
+      LOG(FATAL) << "Invalid layout " << layout;
+    }
+  }
+  if (!axis.empty()) {
+    axes->push_back(axis);
+  }
+}
+
+std::string BaseName(const std::string& str) { return str.substr(0, str.rfind("_")); }
+
+class IndexRewriter : public StmtExprMutator {
+ public:
+  IndexRewriter(const te::Operation& placeholder_op, const std::string& new_layout)
+      : placeholder_op_(placeholder_op), new_layout_(new_layout) {}
+
+  PrimExpr Rewrite(PrimExpr expr) { return this->VisitExpr(expr); }
+
+  PrimExpr VisitExpr_(const ProducerLoadNode* op) final {
+    te::Tensor t = Downcast<te::Tensor>(op->producer);
+    if (t->op == placeholder_op_) {
+      Array<PrimExpr> new_shape;
+      std::vector<std::string> new_names;
+      parse_kernel_layout(new_layout_, &new_shape, &new_names);
+      std::unordered_map<std::string, PrimExpr> name_to_arg;
+      for (const auto& arg : op->indices) {
+        std::string axis_name;
+        if (const auto* pimm = arg.as<IntImmNode>()) {
+          CHECK_EQ(pimm->value, 0);
+          axis_name = "IntImm";
+        } else {
+          axis_name = BaseName(CleanName(Downcast<Var>(arg)->name_hint));
+          CHECK_EQ(name_to_arg.count(axis_name), 0);
+          name_to_arg[axis_name] = arg;
+        }
+      }
+
+      std::unordered_map<std::string, PrimExpr> div_factors;
+      std::vector<PrimExpr> r_new_args;
+      for (int i = new_names.size() - 1; i >= 0; --i) {
+        auto ori_iter_name = new_names[i];
+        auto name_it = name_to_arg.find(ori_iter_name);
+        CHECK(name_it != name_to_arg.end());
+        PrimExpr ori_arg = name_it->second;
+
+        PrimExpr mod_factor = new_shape[i];
+
+        PrimExpr div_factor = 1;
+        if (div_factors.count(ori_iter_name)) {
+          div_factor = div_factors[ori_iter_name];
+        }
+        div_factors[ori_iter_name] = div_factor * new_shape[i];
+
+        PrimExpr new_arg = indexmod(indexdiv(ori_arg, div_factor), mod_factor);
+
+        r_new_args.push_back(new_arg);
+      }
+
+      Array<PrimExpr> new_args(std::make_move_iterator(r_new_args.rbegin()),
+                               std::make_move_iterator(r_new_args.rend()));
+      return ProducerLoad(op->producer, new_args);
+    }
+    return GetRef<PrimExpr>(op);
+  }
+
+ private:
+  const te::Operation& placeholder_op_;
+  const std::string& new_layout_;
+};
+
+std::string get_ori_layout(std::set<std::string>* placeholder_axis_names, const te::Operation& op,
+                           const te::Tensor& placeholder) {
+  ReadAccessExtractor extractor;
+  for (const auto& exp : op.as<te::ComputeOpNode>()->body) {
+    extractor.Extract(exp);
+  }
+
+  std::ostringstream os;
+  uint i = 0;
+  const auto& placeholder_op = placeholder->op;
+  CHECK_GT(extractor.read_access.count(placeholder_op), 0);
+  for (const auto& ev : extractor.read_access[placeholder_op]) {
+    for (const auto& e : ev) {
+      std::string axis_name;
+      if (const auto* pimm = e.as<IntImmNode>()) {
+        CHECK_EQ(pimm->value, 0);
+        axis_name = "IntImm";
+      } else {
+        axis_name = BaseName(CleanName(Downcast<Var>(e)->name_hint));
+      }
+
+      placeholder_axis_names->insert(axis_name);
+      os << placeholder->shape[i++] << axis_name;
+    }
+  }
+
+  CHECK_EQ(placeholder_axis_names->size(), placeholder->shape.size());
+  std::string ori_layout = os.str();
+  os.str("");
+  // TODO(minmin): uncomment this line for relay integration
+  // ::tvm::relay::KernelLayoutTransformer::global_ori_layouts_queue.push_back(ori_layout);
+  return ori_layout;
+}
+
+std::string get_new_layout(Array<PrimExpr>* new_shape, const State& state, const int stage_id,
+                           const Stage& stage, const te::Operation& op,
+                           const te::Tensor& placeholder,
+                           const std::set<std::string>& placeholder_axis_names) {
+  std::ostringstream os;
+  Array<Iterator> stage_iters;
+
+  auto attach_it = state->attach_map->stage_to_attach_iter.find(stage_id);
+  int attach_pos = -1;
+  size_t iters_before_attach = 0;
+  if (attach_it != state->attach_map->stage_to_attach_iter.end()) {
+    auto attach = attach_it->second;
+    const auto& attach_stage = state->stages[attach.first];
+    attach_pos = attach.second;
+    stage_iters.insert(stage_iters.end(), attach_stage->iters.begin(),
+                       attach_stage->iters.begin() + attach_pos + 1);
+  }
+
+  stage_iters.insert(stage_iters.end(), stage->iters.begin(), stage->iters.end());
+
+  std::vector<Iterator> iters;
+  for (size_t i = 0; i < stage_iters.size(); ++i) {
+    const auto& iter = stage_iters[i];
+    if (iter->ori_iters.empty()) {
+      iters.push_back(iter);
+    } else {
+      for (const Iterator& ori_iter : iter->ori_iters) {
+        iters.push_back(ori_iter);
+      }
+    }
+    if (static_cast<int>(i) == attach_pos) {
+      iters_before_attach = iters.size();
+    }
+  }
+
+  std::vector<std::string> new_names;
+  std::vector<std::string> new_axis_names;
+  for (const Iterator& iter : iters) {
+    std::set<std::string> ori_iter_names;
+    ExtractOriginalIterators(iter->name, &ori_iter_names);
+    // fused iters have been replaced with iter->ori_iters.
+    // So there should be only one ori iter name extracted from iter->name.
+    CHECK_EQ(ori_iter_names.size(), 1);
+    auto ori_iter_name = BaseName(*ori_iter_names.begin());
+    new_axis_names.push_back(ori_iter_name);
+  }
+  for (size_t i = 0; i < new_axis_names.size(); ++i) {
+    auto iter = iters[i];
+    std::string ori_iter_name;
+    if (i < iters_before_attach) {
+      ori_iter_name = new_axis_names[i + iters_before_attach];
+    } else {
+      ori_iter_name = new_axis_names[i];
+    }
+    if (placeholder_axis_names.count(ori_iter_name)) {
+      os << iter->range->extent << ori_iter_name;
+      new_names.push_back(ori_iter_name);
+      new_shape->push_back(iter->range->extent);
+    }
+  }
+  std::string new_layout = os.str();
+  os.str("");
+  // TODO(minmin): uncomment this line for relay integration
+  // ::tvm::relay::KernelLayoutTransformer::global_new_layouts_queue.push_back(new_layout);
+  return new_layout;
+}
+
+void ComputeDAG::RewriteLayout(const Array<Step>& transform_steps) {
+  ComputeDAGNode* pdag = this->CopyOnWrite();
+  auto node = make_object<StateNode>();
+  node->transform_steps = transform_steps;
+  node->concrete = true;
+  const State& state = InferBound(State(node));
+  OperationSet handled_ops;
+  int stage_id = -1;
+  for (const auto& stage : state->stages) {
+    stage_id += 1;
+    const te::Operation& op = stage->op;
+    if (op->IsInstance<te::ComputeOpNode>()) {
+      const Map<String, ObjectRef>& attrs = op->attrs;
+      if (attrs.count(layout_free_placeholders_key)) {
+        const ObjectRef& attr_value = attrs[layout_free_placeholders_key];
+        Array<te::Tensor> placeholders = Downcast<Array<te::Tensor>>(attr_value);
+        for (const auto& placeholder : placeholders) {
+          const auto& placeholder_op = placeholder->op;
+
+          // Check whether this placeholder has already been handled
+          if (handled_ops.count(placeholder_op)) {
+            continue;
+          }
+
+          // skip the op that is not direct consumer of this placeholder,
+          // mostly due to cache read/write.
+          bool direct_consumer = false;
+          for (auto& t : op->InputTensors()) {
+            if (t->op == placeholder_op) {
+              direct_consumer = true;
+              break;
+            }
+          }
+          if (!direct_consumer) {
+            continue;
+          }
+
+          std::set<std::string> placeholder_axis_names;
+          get_ori_layout(&placeholder_axis_names, op, placeholder);
+
+          Array<PrimExpr> new_shape;
+          std::string new_layout = get_new_layout(&new_shape, state, stage_id, stage, op,
+                                                  placeholder, placeholder_axis_names);
+
+          handled_ops.insert(placeholder_op);
+
+          Array<te::Operation> old_ops = pdag->ops;
+          ArrayNode* pops = pdag->ops.CopyOnWrite();
+
+          // Create new placeholder
+          te::Operation new_placeholder_op;
+          new_placeholder_op = te::PlaceholderOp(placeholder_op->name, new_shape,
+                                                 placeholder_op.as<te::PlaceholderOpNode>()->dtype);
+
+          te::Operation new_compute_op, old_compute_op;
+          Array<PrimExpr> new_body;
+          IndexRewriter index_rewriter(placeholder_op, new_layout);
+          for (auto& op : old_ops) {
+            if (auto* pop = op.as<te::ComputeOpNode>()) {
+              bool need_update = false;
+              for (auto& t : op->InputTensors()) {
+                if (t->op == placeholder_op) {
+                  need_update = true;
+                  break;
+                }
+              }
+              if (need_update) {
+                for (auto& body : pop->body) {
+                  new_body.push_back(index_rewriter.Rewrite(body));
+                }
+                old_compute_op = op;
+                CHECK(!new_compute_op.defined());
+                new_compute_op =
+                    te::ComputeOp(pop->name, pop->tag, pop->attrs, pop->axis, new_body);
+              }
+            }
+          }
+
+          // construct the map from old_op to new_op
+          std::unordered_map<te::Operation, te::Operation> updated_ops;
+          for (size_t i = 0; i < old_ops.size(); ++i) {
+            auto old_op = old_ops[i];
+            if (old_op == placeholder_op) {
+              pops->SetItem(i, new_placeholder_op);
+              updated_ops[placeholder_op] = new_placeholder_op;
+            } else if (old_op == old_compute_op) {
+              pops->SetItem(i, new_compute_op);
+              updated_ops[old_compute_op] = new_compute_op;
+            } else {
+              pops->SetItem(i, old_op);
+            }
+          }
+
+          // Because ops is sorted in topo-order, only do one pass linear scan here.

Review comment:
       Add comments talking about the purpose.
   ```suggestion
             // Because ops is sorted in topo-order, we only need one pass to (what).
   ```

##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -665,9 +666,349 @@ ComputeDAG::ComputeDAG(Array<te::Tensor> tensors) {
   data_ = std::move(node);
 }
 
+/*!
+ * \brief utility function for kernel_layout_transform
+ */
+inline void parse_kernel_layout(const String& layout, Array<PrimExpr>* shape,
+                                std::vector<std::string>* axes) {
+  int32_t factor = 0;
+  std::string axis = "";
+  for (char c : std::string(layout)) {
+    if (c >= 'A' && c <= 'z') {
+      axis += c;
+      if (factor != 0) {
+        shape->push_back(factor);
+        factor = 0;
+      }
+    } else if (c >= '0' && c <= '9') {
+      factor = factor * 10 + c - '0';
+      if (!axis.empty()) {
+        axes->push_back(axis);
+        axis = "";
+      }
+    } else {
+      LOG(FATAL) << "Invalid layout " << layout;
+    }
+  }
+  if (!axis.empty()) {
+    axes->push_back(axis);
+  }
+}
+
+std::string BaseName(const std::string& str) { return str.substr(0, str.rfind("_")); }
+
+class IndexRewriter : public StmtExprMutator {
+ public:
+  IndexRewriter(const te::Operation& placeholder_op, const std::string& new_layout)
+      : placeholder_op_(placeholder_op), new_layout_(new_layout) {}
+
+  PrimExpr Rewrite(PrimExpr expr) { return this->VisitExpr(expr); }
+
+  PrimExpr VisitExpr_(const ProducerLoadNode* op) final {
+    te::Tensor t = Downcast<te::Tensor>(op->producer);
+    if (t->op == placeholder_op_) {
+      Array<PrimExpr> new_shape;
+      std::vector<std::string> new_names;
+      parse_kernel_layout(new_layout_, &new_shape, &new_names);
+      std::unordered_map<std::string, PrimExpr> name_to_arg;
+      for (const auto& arg : op->indices) {
+        std::string axis_name;
+        if (const auto* pimm = arg.as<IntImmNode>()) {
+          CHECK_EQ(pimm->value, 0);
+          axis_name = "IntImm";
+        } else {
+          axis_name = BaseName(CleanName(Downcast<Var>(arg)->name_hint));
+          CHECK_EQ(name_to_arg.count(axis_name), 0);
+          name_to_arg[axis_name] = arg;
+        }
+      }
+
+      std::unordered_map<std::string, PrimExpr> div_factors;
+      std::vector<PrimExpr> r_new_args;
+      for (int i = new_names.size() - 1; i >= 0; --i) {
+        auto ori_iter_name = new_names[i];
+        auto name_it = name_to_arg.find(ori_iter_name);
+        CHECK(name_it != name_to_arg.end());
+        PrimExpr ori_arg = name_it->second;
+
+        PrimExpr mod_factor = new_shape[i];
+
+        PrimExpr div_factor = 1;
+        if (div_factors.count(ori_iter_name)) {
+          div_factor = div_factors[ori_iter_name];
+        }
+        div_factors[ori_iter_name] = div_factor * new_shape[i];
+
+        PrimExpr new_arg = indexmod(indexdiv(ori_arg, div_factor), mod_factor);
+
+        r_new_args.push_back(new_arg);
+      }
+
+      Array<PrimExpr> new_args(std::make_move_iterator(r_new_args.rbegin()),
+                               std::make_move_iterator(r_new_args.rend()));
+      return ProducerLoad(op->producer, new_args);
+    }
+    return GetRef<PrimExpr>(op);
+  }
+
+ private:
+  const te::Operation& placeholder_op_;
+  const std::string& new_layout_;
+};
+
+std::string get_ori_layout(std::set<std::string>* placeholder_axis_names, const te::Operation& op,
+                           const te::Tensor& placeholder) {
+  ReadAccessExtractor extractor;
+  for (const auto& exp : op.as<te::ComputeOpNode>()->body) {
+    extractor.Extract(exp);
+  }
+
+  std::ostringstream os;
+  uint i = 0;
+  const auto& placeholder_op = placeholder->op;
+  CHECK_GT(extractor.read_access.count(placeholder_op), 0);
+  for (const auto& ev : extractor.read_access[placeholder_op]) {
+    for (const auto& e : ev) {
+      std::string axis_name;
+      if (const auto* pimm = e.as<IntImmNode>()) {
+        CHECK_EQ(pimm->value, 0);
+        axis_name = "IntImm";
+      } else {
+        axis_name = BaseName(CleanName(Downcast<Var>(e)->name_hint));
+      }
+
+      placeholder_axis_names->insert(axis_name);
+      os << placeholder->shape[i++] << axis_name;
+    }
+  }
+
+  CHECK_EQ(placeholder_axis_names->size(), placeholder->shape.size());
+  std::string ori_layout = os.str();
+  os.str("");
+  // TODO(minmin): uncomment this line for relay integration
+  // ::tvm::relay::KernelLayoutTransformer::global_ori_layouts_queue.push_back(ori_layout);
+  return ori_layout;
+}
+
+std::string get_new_layout(Array<PrimExpr>* new_shape, const State& state, const int stage_id,
+                           const Stage& stage, const te::Operation& op,
+                           const te::Tensor& placeholder,
+                           const std::set<std::string>& placeholder_axis_names) {
+  std::ostringstream os;
+  Array<Iterator> stage_iters;
+
+  auto attach_it = state->attach_map->stage_to_attach_iter.find(stage_id);
+  int attach_pos = -1;
+  size_t iters_before_attach = 0;
+  if (attach_it != state->attach_map->stage_to_attach_iter.end()) {
+    auto attach = attach_it->second;
+    const auto& attach_stage = state->stages[attach.first];
+    attach_pos = attach.second;
+    stage_iters.insert(stage_iters.end(), attach_stage->iters.begin(),
+                       attach_stage->iters.begin() + attach_pos + 1);
+  }
+
+  stage_iters.insert(stage_iters.end(), stage->iters.begin(), stage->iters.end());
+
+  std::vector<Iterator> iters;
+  for (size_t i = 0; i < stage_iters.size(); ++i) {
+    const auto& iter = stage_iters[i];
+    if (iter->ori_iters.empty()) {
+      iters.push_back(iter);
+    } else {
+      for (const Iterator& ori_iter : iter->ori_iters) {
+        iters.push_back(ori_iter);
+      }
+    }
+    if (static_cast<int>(i) == attach_pos) {
+      iters_before_attach = iters.size();
+    }
+  }
+
+  std::vector<std::string> new_names;
+  std::vector<std::string> new_axis_names;
+  for (const Iterator& iter : iters) {
+    std::set<std::string> ori_iter_names;
+    ExtractOriginalIterators(iter->name, &ori_iter_names);
+    // fused iters have been replaced with iter->ori_iters.
+    // So there should be only one ori iter name extracted from iter->name.
+    CHECK_EQ(ori_iter_names.size(), 1);
+    auto ori_iter_name = BaseName(*ori_iter_names.begin());
+    new_axis_names.push_back(ori_iter_name);
+  }
+  for (size_t i = 0; i < new_axis_names.size(); ++i) {
+    auto iter = iters[i];
+    std::string ori_iter_name;
+    if (i < iters_before_attach) {
+      ori_iter_name = new_axis_names[i + iters_before_attach];
+    } else {
+      ori_iter_name = new_axis_names[i];
+    }
+    if (placeholder_axis_names.count(ori_iter_name)) {
+      os << iter->range->extent << ori_iter_name;
+      new_names.push_back(ori_iter_name);
+      new_shape->push_back(iter->range->extent);
+    }
+  }
+  std::string new_layout = os.str();
+  os.str("");
+  // TODO(minmin): uncomment this line for relay integration
+  // ::tvm::relay::KernelLayoutTransformer::global_new_layouts_queue.push_back(new_layout);
+  return new_layout;
+}
+
+void ComputeDAG::RewriteLayout(const Array<Step>& transform_steps) {
+  ComputeDAGNode* pdag = this->CopyOnWrite();
+  auto node = make_object<StateNode>();
+  node->transform_steps = transform_steps;
+  node->concrete = true;
+  const State& state = InferBound(State(node));
+  OperationSet handled_ops;
+  int stage_id = -1;
+  for (const auto& stage : state->stages) {
+    stage_id += 1;
+    const te::Operation& op = stage->op;
+    if (op->IsInstance<te::ComputeOpNode>()) {
+      const Map<String, ObjectRef>& attrs = op->attrs;
+      if (attrs.count(layout_free_placeholders_key)) {
+        const ObjectRef& attr_value = attrs[layout_free_placeholders_key];
+        Array<te::Tensor> placeholders = Downcast<Array<te::Tensor>>(attr_value);
+        for (const auto& placeholder : placeholders) {
+          const auto& placeholder_op = placeholder->op;
+
+          // Check whether this placeholder has already been handled
+          if (handled_ops.count(placeholder_op)) {
+            continue;
+          }
+
+          // skip the op that is not direct consumer of this placeholder,
+          // mostly due to cache read/write.
+          bool direct_consumer = false;
+          for (auto& t : op->InputTensors()) {
+            if (t->op == placeholder_op) {
+              direct_consumer = true;
+              break;
+            }
+          }
+          if (!direct_consumer) {
+            continue;
+          }
+
+          std::set<std::string> placeholder_axis_names;
+          get_ori_layout(&placeholder_axis_names, op, placeholder);
+
+          Array<PrimExpr> new_shape;
+          std::string new_layout = get_new_layout(&new_shape, state, stage_id, stage, op,
+                                                  placeholder, placeholder_axis_names);
+
+          handled_ops.insert(placeholder_op);
+
+          Array<te::Operation> old_ops = pdag->ops;
+          ArrayNode* pops = pdag->ops.CopyOnWrite();
+
+          // Create new placeholder
+          te::Operation new_placeholder_op;
+          new_placeholder_op = te::PlaceholderOp(placeholder_op->name, new_shape,
+                                                 placeholder_op.as<te::PlaceholderOpNode>()->dtype);
+
+          te::Operation new_compute_op, old_compute_op;

Review comment:
       Add a comment explaining what this loop is for.

##########
File path: include/tvm/auto_scheduler/transform_step.h
##########
@@ -118,6 +123,8 @@ class IteratorNode : public Object {
   IteratorKind iter_kind;
   /*! \brief The annotation type of this iterator. */
   IteratorAnnotation annotation;
+  /*! The original iterators before fusion. */
+  std::vector<Iterator> ori_iters;

Review comment:
       Same opinion.

##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -665,9 +666,349 @@ ComputeDAG::ComputeDAG(Array<te::Tensor> tensors) {
   data_ = std::move(node);
 }
 
+/*!
+ * \brief utility function for kernel_layout_transform
+ */
+inline void parse_kernel_layout(const String& layout, Array<PrimExpr>* shape,
+                                std::vector<std::string>* axes) {
+  int32_t factor = 0;
+  std::string axis = "";
+  for (char c : std::string(layout)) {
+    if (c >= 'A' && c <= 'z') {
+      axis += c;
+      if (factor != 0) {
+        shape->push_back(factor);
+        factor = 0;
+      }
+    } else if (c >= '0' && c <= '9') {
+      factor = factor * 10 + c - '0';
+      if (!axis.empty()) {
+        axes->push_back(axis);
+        axis = "";
+      }
+    } else {
+      LOG(FATAL) << "Invalid layout " << layout;
+    }
+  }
+  if (!axis.empty()) {
+    axes->push_back(axis);
+  }
+}
+
+std::string BaseName(const std::string& str) { return str.substr(0, str.rfind("_")); }
+
+class IndexRewriter : public StmtExprMutator {
+ public:
+  IndexRewriter(const te::Operation& placeholder_op, const std::string& new_layout)
+      : placeholder_op_(placeholder_op), new_layout_(new_layout) {}
+
+  PrimExpr Rewrite(PrimExpr expr) { return this->VisitExpr(expr); }
+
+  PrimExpr VisitExpr_(const ProducerLoadNode* op) final {
+    te::Tensor t = Downcast<te::Tensor>(op->producer);
+    if (t->op == placeholder_op_) {
+      Array<PrimExpr> new_shape;
+      std::vector<std::string> new_names;
+      parse_kernel_layout(new_layout_, &new_shape, &new_names);
+      std::unordered_map<std::string, PrimExpr> name_to_arg;
+      for (const auto& arg : op->indices) {
+        std::string axis_name;
+        if (const auto* pimm = arg.as<IntImmNode>()) {
+          CHECK_EQ(pimm->value, 0);
+          axis_name = "IntImm";
+        } else {
+          axis_name = BaseName(CleanName(Downcast<Var>(arg)->name_hint));
+          CHECK_EQ(name_to_arg.count(axis_name), 0);
+          name_to_arg[axis_name] = arg;
+        }
+      }
+
+      std::unordered_map<std::string, PrimExpr> div_factors;
+      std::vector<PrimExpr> r_new_args;
+      for (int i = new_names.size() - 1; i >= 0; --i) {
+        auto ori_iter_name = new_names[i];
+        auto name_it = name_to_arg.find(ori_iter_name);
+        CHECK(name_it != name_to_arg.end());
+        PrimExpr ori_arg = name_it->second;
+
+        PrimExpr mod_factor = new_shape[i];
+
+        PrimExpr div_factor = 1;
+        if (div_factors.count(ori_iter_name)) {
+          div_factor = div_factors[ori_iter_name];
+        }
+        div_factors[ori_iter_name] = div_factor * new_shape[i];
+
+        PrimExpr new_arg = indexmod(indexdiv(ori_arg, div_factor), mod_factor);
+
+        r_new_args.push_back(new_arg);
+      }
+
+      Array<PrimExpr> new_args(std::make_move_iterator(r_new_args.rbegin()),
+                               std::make_move_iterator(r_new_args.rend()));
+      return ProducerLoad(op->producer, new_args);
+    }
+    return GetRef<PrimExpr>(op);
+  }
+
+ private:
+  const te::Operation& placeholder_op_;
+  const std::string& new_layout_;
+};
+
+std::string get_ori_layout(std::set<std::string>* placeholder_axis_names, const te::Operation& op,
+                           const te::Tensor& placeholder) {
+  ReadAccessExtractor extractor;
+  for (const auto& exp : op.as<te::ComputeOpNode>()->body) {
+    extractor.Extract(exp);
+  }
+
+  std::ostringstream os;
+  uint i = 0;
+  const auto& placeholder_op = placeholder->op;
+  CHECK_GT(extractor.read_access.count(placeholder_op), 0);
+  for (const auto& ev : extractor.read_access[placeholder_op]) {
+    for (const auto& e : ev) {
+      std::string axis_name;
+      if (const auto* pimm = e.as<IntImmNode>()) {
+        CHECK_EQ(pimm->value, 0);
+        axis_name = "IntImm";
+      } else {
+        axis_name = BaseName(CleanName(Downcast<Var>(e)->name_hint));
+      }
+
+      placeholder_axis_names->insert(axis_name);
+      os << placeholder->shape[i++] << axis_name;
+    }
+  }
+
+  CHECK_EQ(placeholder_axis_names->size(), placeholder->shape.size());
+  std::string ori_layout = os.str();
+  os.str("");
+  // TODO(minmin): uncomment this line for relay integration
+  // ::tvm::relay::KernelLayoutTransformer::global_ori_layouts_queue.push_back(ori_layout);
+  return ori_layout;
+}
+
+std::string get_new_layout(Array<PrimExpr>* new_shape, const State& state, const int stage_id,
+                           const Stage& stage, const te::Operation& op,
+                           const te::Tensor& placeholder,
+                           const std::set<std::string>& placeholder_axis_names) {
+  std::ostringstream os;
+  Array<Iterator> stage_iters;
+
+  auto attach_it = state->attach_map->stage_to_attach_iter.find(stage_id);
+  int attach_pos = -1;
+  size_t iters_before_attach = 0;
+  if (attach_it != state->attach_map->stage_to_attach_iter.end()) {
+    auto attach = attach_it->second;
+    const auto& attach_stage = state->stages[attach.first];
+    attach_pos = attach.second;
+    stage_iters.insert(stage_iters.end(), attach_stage->iters.begin(),
+                       attach_stage->iters.begin() + attach_pos + 1);
+  }
+
+  stage_iters.insert(stage_iters.end(), stage->iters.begin(), stage->iters.end());
+
+  std::vector<Iterator> iters;
+  for (size_t i = 0; i < stage_iters.size(); ++i) {
+    const auto& iter = stage_iters[i];
+    if (iter->ori_iters.empty()) {
+      iters.push_back(iter);
+    } else {
+      for (const Iterator& ori_iter : iter->ori_iters) {
+        iters.push_back(ori_iter);
+      }
+    }
+    if (static_cast<int>(i) == attach_pos) {
+      iters_before_attach = iters.size();
+    }
+  }
+
+  std::vector<std::string> new_names;
+  std::vector<std::string> new_axis_names;
+  for (const Iterator& iter : iters) {
+    std::set<std::string> ori_iter_names;
+    ExtractOriginalIterators(iter->name, &ori_iter_names);
+    // fused iters have been replaced with iter->ori_iters.
+    // So there should be only one ori iter name extracted from iter->name.
+    CHECK_EQ(ori_iter_names.size(), 1);
+    auto ori_iter_name = BaseName(*ori_iter_names.begin());
+    new_axis_names.push_back(ori_iter_name);
+  }
+  for (size_t i = 0; i < new_axis_names.size(); ++i) {
+    auto iter = iters[i];
+    std::string ori_iter_name;
+    if (i < iters_before_attach) {
+      ori_iter_name = new_axis_names[i + iters_before_attach];
+    } else {
+      ori_iter_name = new_axis_names[i];
+    }
+    if (placeholder_axis_names.count(ori_iter_name)) {
+      os << iter->range->extent << ori_iter_name;
+      new_names.push_back(ori_iter_name);
+      new_shape->push_back(iter->range->extent);
+    }
+  }
+  std::string new_layout = os.str();
+  os.str("");
+  // TODO(minmin): uncomment this line for relay integration
+  // ::tvm::relay::KernelLayoutTransformer::global_new_layouts_queue.push_back(new_layout);
+  return new_layout;
+}
+
+void ComputeDAG::RewriteLayout(const Array<Step>& transform_steps) {
+  ComputeDAGNode* pdag = this->CopyOnWrite();
+  auto node = make_object<StateNode>();
+  node->transform_steps = transform_steps;
+  node->concrete = true;
+  const State& state = InferBound(State(node));
+  OperationSet handled_ops;
+  int stage_id = -1;
+  for (const auto& stage : state->stages) {
+    stage_id += 1;
+    const te::Operation& op = stage->op;
+    if (op->IsInstance<te::ComputeOpNode>()) {
+      const Map<String, ObjectRef>& attrs = op->attrs;
+      if (attrs.count(layout_free_placeholders_key)) {
+        const ObjectRef& attr_value = attrs[layout_free_placeholders_key];
+        Array<te::Tensor> placeholders = Downcast<Array<te::Tensor>>(attr_value);
+        for (const auto& placeholder : placeholders) {
+          const auto& placeholder_op = placeholder->op;
+
+          // Check whether this placeholder has already been handled
+          if (handled_ops.count(placeholder_op)) {
+            continue;
+          }
+
+          // skip the op that is not direct consumer of this placeholder,
+          // mostly due to cache read/write.

Review comment:
       ```suggestion
             // Skip the op that is not direct consumer of this placeholder.
             // This is usually caused by cache read/write.
   ```

##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -665,9 +666,349 @@ ComputeDAG::ComputeDAG(Array<te::Tensor> tensors) {
   data_ = std::move(node);
 }
 
+/*!
+ * \brief utility function for kernel_layout_transform
+ */
+inline void parse_kernel_layout(const String& layout, Array<PrimExpr>* shape,
+                                std::vector<std::string>* axes) {
+  int32_t factor = 0;
+  std::string axis = "";
+  for (char c : std::string(layout)) {
+    if (c >= 'A' && c <= 'z') {
+      axis += c;
+      if (factor != 0) {
+        shape->push_back(factor);
+        factor = 0;
+      }
+    } else if (c >= '0' && c <= '9') {
+      factor = factor * 10 + c - '0';
+      if (!axis.empty()) {
+        axes->push_back(axis);
+        axis = "";
+      }
+    } else {
+      LOG(FATAL) << "Invalid layout " << layout;
+    }
+  }
+  if (!axis.empty()) {
+    axes->push_back(axis);
+  }
+}
+
+std::string BaseName(const std::string& str) { return str.substr(0, str.rfind("_")); }
+
+class IndexRewriter : public StmtExprMutator {
+ public:
+  IndexRewriter(const te::Operation& placeholder_op, const std::string& new_layout)
+      : placeholder_op_(placeholder_op), new_layout_(new_layout) {}
+
+  PrimExpr Rewrite(PrimExpr expr) { return this->VisitExpr(expr); }
+
+  PrimExpr VisitExpr_(const ProducerLoadNode* op) final {
+    te::Tensor t = Downcast<te::Tensor>(op->producer);
+    if (t->op == placeholder_op_) {
+      Array<PrimExpr> new_shape;
+      std::vector<std::string> new_names;
+      parse_kernel_layout(new_layout_, &new_shape, &new_names);
+      std::unordered_map<std::string, PrimExpr> name_to_arg;
+      for (const auto& arg : op->indices) {
+        std::string axis_name;
+        if (const auto* pimm = arg.as<IntImmNode>()) {
+          CHECK_EQ(pimm->value, 0);
+          axis_name = "IntImm";
+        } else {
+          axis_name = BaseName(CleanName(Downcast<Var>(arg)->name_hint));
+          CHECK_EQ(name_to_arg.count(axis_name), 0);
+          name_to_arg[axis_name] = arg;
+        }
+      }
+
+      std::unordered_map<std::string, PrimExpr> div_factors;
+      std::vector<PrimExpr> r_new_args;
+      for (int i = new_names.size() - 1; i >= 0; --i) {
+        auto ori_iter_name = new_names[i];
+        auto name_it = name_to_arg.find(ori_iter_name);
+        CHECK(name_it != name_to_arg.end());
+        PrimExpr ori_arg = name_it->second;
+
+        PrimExpr mod_factor = new_shape[i];
+
+        PrimExpr div_factor = 1;
+        if (div_factors.count(ori_iter_name)) {
+          div_factor = div_factors[ori_iter_name];
+        }
+        div_factors[ori_iter_name] = div_factor * new_shape[i];
+
+        PrimExpr new_arg = indexmod(indexdiv(ori_arg, div_factor), mod_factor);
+
+        r_new_args.push_back(new_arg);
+      }
+
+      Array<PrimExpr> new_args(std::make_move_iterator(r_new_args.rbegin()),
+                               std::make_move_iterator(r_new_args.rend()));
+      return ProducerLoad(op->producer, new_args);
+    }
+    return GetRef<PrimExpr>(op);
+  }
+
+ private:
+  const te::Operation& placeholder_op_;
+  const std::string& new_layout_;
+};
+
+std::string get_ori_layout(std::set<std::string>* placeholder_axis_names, const te::Operation& op,
+                           const te::Tensor& placeholder) {
+  ReadAccessExtractor extractor;
+  for (const auto& exp : op.as<te::ComputeOpNode>()->body) {
+    extractor.Extract(exp);
+  }
+
+  std::ostringstream os;
+  uint i = 0;
+  const auto& placeholder_op = placeholder->op;
+  CHECK_GT(extractor.read_access.count(placeholder_op), 0);
+  for (const auto& ev : extractor.read_access[placeholder_op]) {
+    for (const auto& e : ev) {
+      std::string axis_name;
+      if (const auto* pimm = e.as<IntImmNode>()) {
+        CHECK_EQ(pimm->value, 0);
+        axis_name = "IntImm";
+      } else {
+        axis_name = BaseName(CleanName(Downcast<Var>(e)->name_hint));
+      }
+
+      placeholder_axis_names->insert(axis_name);
+      os << placeholder->shape[i++] << axis_name;
+    }
+  }
+
+  CHECK_EQ(placeholder_axis_names->size(), placeholder->shape.size());
+  std::string ori_layout = os.str();
+  os.str("");
+  // TODO(minmin): uncomment this line for relay integration
+  // ::tvm::relay::KernelLayoutTransformer::global_ori_layouts_queue.push_back(ori_layout);
+  return ori_layout;
+}
+
+std::string get_new_layout(Array<PrimExpr>* new_shape, const State& state, const int stage_id,
+                           const Stage& stage, const te::Operation& op,
+                           const te::Tensor& placeholder,
+                           const std::set<std::string>& placeholder_axis_names) {
+  std::ostringstream os;
+  Array<Iterator> stage_iters;
+
+  auto attach_it = state->attach_map->stage_to_attach_iter.find(stage_id);
+  int attach_pos = -1;
+  size_t iters_before_attach = 0;
+  if (attach_it != state->attach_map->stage_to_attach_iter.end()) {
+    auto attach = attach_it->second;
+    const auto& attach_stage = state->stages[attach.first];
+    attach_pos = attach.second;
+    stage_iters.insert(stage_iters.end(), attach_stage->iters.begin(),
+                       attach_stage->iters.begin() + attach_pos + 1);
+  }
+
+  stage_iters.insert(stage_iters.end(), stage->iters.begin(), stage->iters.end());
+
+  std::vector<Iterator> iters;
+  for (size_t i = 0; i < stage_iters.size(); ++i) {
+    const auto& iter = stage_iters[i];
+    if (iter->ori_iters.empty()) {
+      iters.push_back(iter);
+    } else {
+      for (const Iterator& ori_iter : iter->ori_iters) {
+        iters.push_back(ori_iter);
+      }
+    }
+    if (static_cast<int>(i) == attach_pos) {
+      iters_before_attach = iters.size();
+    }
+  }
+
+  std::vector<std::string> new_names;
+  std::vector<std::string> new_axis_names;
+  for (const Iterator& iter : iters) {
+    std::set<std::string> ori_iter_names;
+    ExtractOriginalIterators(iter->name, &ori_iter_names);
+    // fused iters have been replaced with iter->ori_iters.
+    // So there should be only one ori iter name extracted from iter->name.
+    CHECK_EQ(ori_iter_names.size(), 1);
+    auto ori_iter_name = BaseName(*ori_iter_names.begin());
+    new_axis_names.push_back(ori_iter_name);
+  }
+  for (size_t i = 0; i < new_axis_names.size(); ++i) {
+    auto iter = iters[i];
+    std::string ori_iter_name;
+    if (i < iters_before_attach) {
+      ori_iter_name = new_axis_names[i + iters_before_attach];
+    } else {
+      ori_iter_name = new_axis_names[i];
+    }
+    if (placeholder_axis_names.count(ori_iter_name)) {
+      os << iter->range->extent << ori_iter_name;
+      new_names.push_back(ori_iter_name);
+      new_shape->push_back(iter->range->extent);
+    }
+  }
+  std::string new_layout = os.str();
+  os.str("");
+  // TODO(minmin): uncomment this line for relay integration
+  // ::tvm::relay::KernelLayoutTransformer::global_new_layouts_queue.push_back(new_layout);
+  return new_layout;
+}
+
+void ComputeDAG::RewriteLayout(const Array<Step>& transform_steps) {
+  ComputeDAGNode* pdag = this->CopyOnWrite();
+  auto node = make_object<StateNode>();
+  node->transform_steps = transform_steps;
+  node->concrete = true;
+  const State& state = InferBound(State(node));
+  OperationSet handled_ops;
+  int stage_id = -1;
+  for (const auto& stage : state->stages) {
+    stage_id += 1;
+    const te::Operation& op = stage->op;
+    if (op->IsInstance<te::ComputeOpNode>()) {
+      const Map<String, ObjectRef>& attrs = op->attrs;
+      if (attrs.count(layout_free_placeholders_key)) {
+        const ObjectRef& attr_value = attrs[layout_free_placeholders_key];
+        Array<te::Tensor> placeholders = Downcast<Array<te::Tensor>>(attr_value);
+        for (const auto& placeholder : placeholders) {
+          const auto& placeholder_op = placeholder->op;
+
+          // Check whether this placeholder has already been handled
+          if (handled_ops.count(placeholder_op)) {
+            continue;
+          }
+
+          // skip the op that is not direct consumer of this placeholder,
+          // mostly due to cache read/write.
+          bool direct_consumer = false;
+          for (auto& t : op->InputTensors()) {
+            if (t->op == placeholder_op) {
+              direct_consumer = true;
+              break;
+            }
+          }
+          if (!direct_consumer) {
+            continue;
+          }
+
+          std::set<std::string> placeholder_axis_names;
+          get_ori_layout(&placeholder_axis_names, op, placeholder);
+
+          Array<PrimExpr> new_shape;
+          std::string new_layout = get_new_layout(&new_shape, state, stage_id, stage, op,
+                                                  placeholder, placeholder_axis_names);
+
+          handled_ops.insert(placeholder_op);
+
+          Array<te::Operation> old_ops = pdag->ops;
+          ArrayNode* pops = pdag->ops.CopyOnWrite();
+
+          // Create new placeholder
+          te::Operation new_placeholder_op;
+          new_placeholder_op = te::PlaceholderOp(placeholder_op->name, new_shape,
+                                                 placeholder_op.as<te::PlaceholderOpNode>()->dtype);
+
+          te::Operation new_compute_op, old_compute_op;
+          Array<PrimExpr> new_body;
+          IndexRewriter index_rewriter(placeholder_op, new_layout);
+          for (auto& op : old_ops) {
+            if (auto* pop = op.as<te::ComputeOpNode>()) {
+              bool need_update = false;
+              for (auto& t : op->InputTensors()) {
+                if (t->op == placeholder_op) {
+                  need_update = true;
+                  break;
+                }
+              }
+              if (need_update) {
+                for (auto& body : pop->body) {
+                  new_body.push_back(index_rewriter.Rewrite(body));
+                }
+                old_compute_op = op;
+                CHECK(!new_compute_op.defined());
+                new_compute_op =
+                    te::ComputeOp(pop->name, pop->tag, pop->attrs, pop->axis, new_body);
+              }
+            }
+          }
+
+          // construct the map from old_op to new_op
+          std::unordered_map<te::Operation, te::Operation> updated_ops;
+          for (size_t i = 0; i < old_ops.size(); ++i) {
+            auto old_op = old_ops[i];
+            if (old_op == placeholder_op) {
+              pops->SetItem(i, new_placeholder_op);
+              updated_ops[placeholder_op] = new_placeholder_op;
+            } else if (old_op == old_compute_op) {
+              pops->SetItem(i, new_compute_op);
+              updated_ops[old_compute_op] = new_compute_op;
+            } else {
+              pops->SetItem(i, old_op);
+            }
+          }
+
+          // Because ops is sorted in topo-order, only do one pass linear scan here.
+          for (size_t i = 0; i < pops->size(); ++i) {
+            auto old_op = Downcast<te::Operation>(pops->at(i));
+            if (auto* pop = old_op.as<te::ComputeOpNode>()) {
+              auto inputs = pop->InputTensors();
+              std::unordered_map<te::Tensor, te::Tensor> rmap;
+              for (auto input : inputs) {
+                auto it = updated_ops.find(input->op);
+                te::Operation new_op;
+                while (it != updated_ops.end()) {
+                  new_op = it->second;
+                  it = updated_ops.find(new_op);
+                }
+                if (new_op.defined()) {
+                  int index = input->value_index;
+                  rmap[input] = new_op.output(index);
+                }
+              }
+              if (!rmap.empty()) {
+                te::Operation new_op = pop->ReplaceInputs(old_op, rmap);
+                updated_ops[old_op] = new_op;
+                pops->SetItem(i, new_op);
+              }
+            }
+          }
+
+          pdag->init_state = State(pdag->ops);
+
+          Array<te::Tensor> old_tensors = pdag->tensors;
+          ArrayNode* ptensors = pdag->tensors.CopyOnWrite();
+
+          for (size_t i = 0; i < old_tensors.size(); ++i) {

Review comment:
       Add a comment explaining what this loop is for.

##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -665,9 +666,349 @@ ComputeDAG::ComputeDAG(Array<te::Tensor> tensors) {
   data_ = std::move(node);
 }
 
+/*!
+ * \brief utility function for kernel_layout_transform
+ */
+inline void parse_kernel_layout(const String& layout, Array<PrimExpr>* shape,
+                                std::vector<std::string>* axes) {
+  int32_t factor = 0;
+  std::string axis = "";
+  for (char c : std::string(layout)) {
+    if (c >= 'A' && c <= 'z') {
+      axis += c;
+      if (factor != 0) {
+        shape->push_back(factor);
+        factor = 0;
+      }
+    } else if (c >= '0' && c <= '9') {
+      factor = factor * 10 + c - '0';
+      if (!axis.empty()) {
+        axes->push_back(axis);
+        axis = "";
+      }
+    } else {
+      LOG(FATAL) << "Invalid layout " << layout;
+    }
+  }
+  if (!axis.empty()) {
+    axes->push_back(axis);
+  }
+}
+
+std::string BaseName(const std::string& str) { return str.substr(0, str.rfind("_")); }
+
+class IndexRewriter : public StmtExprMutator {
+ public:
+  IndexRewriter(const te::Operation& placeholder_op, const std::string& new_layout)
+      : placeholder_op_(placeholder_op), new_layout_(new_layout) {}
+
+  PrimExpr Rewrite(PrimExpr expr) { return this->VisitExpr(expr); }
+
+  PrimExpr VisitExpr_(const ProducerLoadNode* op) final {
+    te::Tensor t = Downcast<te::Tensor>(op->producer);
+    if (t->op == placeholder_op_) {
+      Array<PrimExpr> new_shape;
+      std::vector<std::string> new_names;
+      parse_kernel_layout(new_layout_, &new_shape, &new_names);
+      std::unordered_map<std::string, PrimExpr> name_to_arg;
+      for (const auto& arg : op->indices) {
+        std::string axis_name;
+        if (const auto* pimm = arg.as<IntImmNode>()) {
+          CHECK_EQ(pimm->value, 0);
+          axis_name = "IntImm";

Review comment:
       Why do you need to assign a name to `axis_name`? It seems like this value will never be used in the rest of this function.

##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -665,9 +666,349 @@ ComputeDAG::ComputeDAG(Array<te::Tensor> tensors) {
   data_ = std::move(node);
 }
 
+/*!
+ * \brief utility function for kernel_layout_transform
+ */
+inline void parse_kernel_layout(const String& layout, Array<PrimExpr>* shape,
+                                std::vector<std::string>* axes) {
+  int32_t factor = 0;
+  std::string axis = "";
+  for (char c : std::string(layout)) {
+    if (c >= 'A' && c <= 'z') {
+      axis += c;
+      if (factor != 0) {
+        shape->push_back(factor);
+        factor = 0;
+      }
+    } else if (c >= '0' && c <= '9') {
+      factor = factor * 10 + c - '0';
+      if (!axis.empty()) {
+        axes->push_back(axis);
+        axis = "";
+      }
+    } else {
+      LOG(FATAL) << "Invalid layout " << layout;
+    }
+  }
+  if (!axis.empty()) {
+    axes->push_back(axis);
+  }
+}
+
+std::string BaseName(const std::string& str) { return str.substr(0, str.rfind("_")); }
+
+class IndexRewriter : public StmtExprMutator {
+ public:
+  IndexRewriter(const te::Operation& placeholder_op, const std::string& new_layout)
+      : placeholder_op_(placeholder_op), new_layout_(new_layout) {}
+
+  PrimExpr Rewrite(PrimExpr expr) { return this->VisitExpr(expr); }
+
+  PrimExpr VisitExpr_(const ProducerLoadNode* op) final {
+    te::Tensor t = Downcast<te::Tensor>(op->producer);
+    if (t->op == placeholder_op_) {
+      Array<PrimExpr> new_shape;
+      std::vector<std::string> new_names;
+      parse_kernel_layout(new_layout_, &new_shape, &new_names);
+      std::unordered_map<std::string, PrimExpr> name_to_arg;
+      for (const auto& arg : op->indices) {
+        std::string axis_name;
+        if (const auto* pimm = arg.as<IntImmNode>()) {
+          CHECK_EQ(pimm->value, 0);
+          axis_name = "IntImm";
+        } else {
+          axis_name = BaseName(CleanName(Downcast<Var>(arg)->name_hint));
+          CHECK_EQ(name_to_arg.count(axis_name), 0);
+          name_to_arg[axis_name] = arg;
+        }
+      }
+
+      std::unordered_map<std::string, PrimExpr> div_factors;
+      std::vector<PrimExpr> r_new_args;
+      for (int i = new_names.size() - 1; i >= 0; --i) {
+        auto ori_iter_name = new_names[i];
+        auto name_it = name_to_arg.find(ori_iter_name);
+        CHECK(name_it != name_to_arg.end());
+        PrimExpr ori_arg = name_it->second;
+

Review comment:
       Remove this line.

##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -665,9 +666,349 @@ ComputeDAG::ComputeDAG(Array<te::Tensor> tensors) {
   data_ = std::move(node);
 }
 
+/*!
+ * \brief utility function for kernel_layout_transform
+ */
+inline void parse_kernel_layout(const String& layout, Array<PrimExpr>* shape,
+                                std::vector<std::string>* axes) {
+  int32_t factor = 0;
+  std::string axis = "";
+  for (char c : std::string(layout)) {
+    if (c >= 'A' && c <= 'z') {
+      axis += c;
+      if (factor != 0) {
+        shape->push_back(factor);
+        factor = 0;
+      }
+    } else if (c >= '0' && c <= '9') {
+      factor = factor * 10 + c - '0';
+      if (!axis.empty()) {
+        axes->push_back(axis);
+        axis = "";
+      }
+    } else {
+      LOG(FATAL) << "Invalid layout " << layout;
+    }
+  }
+  if (!axis.empty()) {
+    axes->push_back(axis);
+  }
+}
+
+std::string BaseName(const std::string& str) { return str.substr(0, str.rfind("_")); }
+
+class IndexRewriter : public StmtExprMutator {
+ public:
+  IndexRewriter(const te::Operation& placeholder_op, const std::string& new_layout)
+      : placeholder_op_(placeholder_op), new_layout_(new_layout) {}
+
+  PrimExpr Rewrite(PrimExpr expr) { return this->VisitExpr(expr); }
+
+  PrimExpr VisitExpr_(const ProducerLoadNode* op) final {
+    te::Tensor t = Downcast<te::Tensor>(op->producer);
+    if (t->op == placeholder_op_) {
+      Array<PrimExpr> new_shape;
+      std::vector<std::string> new_names;
+      parse_kernel_layout(new_layout_, &new_shape, &new_names);
+      std::unordered_map<std::string, PrimExpr> name_to_arg;
+      for (const auto& arg : op->indices) {
+        std::string axis_name;
+        if (const auto* pimm = arg.as<IntImmNode>()) {
+          CHECK_EQ(pimm->value, 0);
+          axis_name = "IntImm";
+        } else {
+          axis_name = BaseName(CleanName(Downcast<Var>(arg)->name_hint));
+          CHECK_EQ(name_to_arg.count(axis_name), 0);
+          name_to_arg[axis_name] = arg;
+        }
+      }
+
+      std::unordered_map<std::string, PrimExpr> div_factors;
+      std::vector<PrimExpr> r_new_args;
+      for (int i = new_names.size() - 1; i >= 0; --i) {
+        auto ori_iter_name = new_names[i];
+        auto name_it = name_to_arg.find(ori_iter_name);
+        CHECK(name_it != name_to_arg.end());
+        PrimExpr ori_arg = name_it->second;
+
+        PrimExpr mod_factor = new_shape[i];
+
+        PrimExpr div_factor = 1;
+        if (div_factors.count(ori_iter_name)) {
+          div_factor = div_factors[ori_iter_name];
+        }
+        div_factors[ori_iter_name] = div_factor * new_shape[i];
+
+        PrimExpr new_arg = indexmod(indexdiv(ori_arg, div_factor), mod_factor);
+
+        r_new_args.push_back(new_arg);
+      }
+
+      Array<PrimExpr> new_args(std::make_move_iterator(r_new_args.rbegin()),
+                               std::make_move_iterator(r_new_args.rend()));
+      return ProducerLoad(op->producer, new_args);
+    }
+    return GetRef<PrimExpr>(op);
+  }
+
+ private:
+  const te::Operation& placeholder_op_;
+  const std::string& new_layout_;
+};
+
+std::string get_ori_layout(std::set<std::string>* placeholder_axis_names, const te::Operation& op,
+                           const te::Tensor& placeholder) {
+  ReadAccessExtractor extractor;
+  for (const auto& exp : op.as<te::ComputeOpNode>()->body) {
+    extractor.Extract(exp);
+  }
+
+  std::ostringstream os;
+  uint i = 0;
+  const auto& placeholder_op = placeholder->op;
+  CHECK_GT(extractor.read_access.count(placeholder_op), 0);
+  for (const auto& ev : extractor.read_access[placeholder_op]) {
+    for (const auto& e : ev) {
+      std::string axis_name;
+      if (const auto* pimm = e.as<IntImmNode>()) {
+        CHECK_EQ(pimm->value, 0);

Review comment:
       Be more specific about what the assumption is here.

##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -665,9 +666,349 @@ ComputeDAG::ComputeDAG(Array<te::Tensor> tensors) {
   data_ = std::move(node);
 }
 
+/*!
+ * \brief utility function for kernel_layout_transform
+ */
+inline void parse_kernel_layout(const String& layout, Array<PrimExpr>* shape,
+                                std::vector<std::string>* axes) {
+  int32_t factor = 0;
+  std::string axis = "";
+  for (char c : std::string(layout)) {
+    if (c >= 'A' && c <= 'z') {
+      axis += c;
+      if (factor != 0) {
+        shape->push_back(factor);
+        factor = 0;
+      }
+    } else if (c >= '0' && c <= '9') {
+      factor = factor * 10 + c - '0';
+      if (!axis.empty()) {
+        axes->push_back(axis);
+        axis = "";
+      }
+    } else {
+      LOG(FATAL) << "Invalid layout " << layout;
+    }
+  }
+  if (!axis.empty()) {
+    axes->push_back(axis);
+  }
+}
+
+std::string BaseName(const std::string& str) { return str.substr(0, str.rfind("_")); }
+
+class IndexRewriter : public StmtExprMutator {
+ public:
+  IndexRewriter(const te::Operation& placeholder_op, const std::string& new_layout)
+      : placeholder_op_(placeholder_op), new_layout_(new_layout) {}
+
+  PrimExpr Rewrite(PrimExpr expr) { return this->VisitExpr(expr); }
+
+  PrimExpr VisitExpr_(const ProducerLoadNode* op) final {
+    te::Tensor t = Downcast<te::Tensor>(op->producer);
+    if (t->op == placeholder_op_) {
+      Array<PrimExpr> new_shape;
+      std::vector<std::string> new_names;
+      parse_kernel_layout(new_layout_, &new_shape, &new_names);
+      std::unordered_map<std::string, PrimExpr> name_to_arg;
+      for (const auto& arg : op->indices) {
+        std::string axis_name;
+        if (const auto* pimm = arg.as<IntImmNode>()) {
+          CHECK_EQ(pimm->value, 0);
+          axis_name = "IntImm";
+        } else {
+          axis_name = BaseName(CleanName(Downcast<Var>(arg)->name_hint));
+          CHECK_EQ(name_to_arg.count(axis_name), 0);
+          name_to_arg[axis_name] = arg;
+        }
+      }
+
+      std::unordered_map<std::string, PrimExpr> div_factors;
+      std::vector<PrimExpr> r_new_args;
+      for (int i = new_names.size() - 1; i >= 0; --i) {
+        auto ori_iter_name = new_names[i];
+        auto name_it = name_to_arg.find(ori_iter_name);
+        CHECK(name_it != name_to_arg.end());
+        PrimExpr ori_arg = name_it->second;
+
+        PrimExpr mod_factor = new_shape[i];
+
+        PrimExpr div_factor = 1;
+        if (div_factors.count(ori_iter_name)) {
+          div_factor = div_factors[ori_iter_name];
+        }
+        div_factors[ori_iter_name] = div_factor * new_shape[i];
+
+        PrimExpr new_arg = indexmod(indexdiv(ori_arg, div_factor), mod_factor);
+
+        r_new_args.push_back(new_arg);
+      }
+
+      Array<PrimExpr> new_args(std::make_move_iterator(r_new_args.rbegin()),
+                               std::make_move_iterator(r_new_args.rend()));
+      return ProducerLoad(op->producer, new_args);
+    }
+    return GetRef<PrimExpr>(op);
+  }
+
+ private:
+  const te::Operation& placeholder_op_;
+  const std::string& new_layout_;
+};
+
+std::string get_ori_layout(std::set<std::string>* placeholder_axis_names, const te::Operation& op,
+                           const te::Tensor& placeholder) {
+  ReadAccessExtractor extractor;
+  for (const auto& exp : op.as<te::ComputeOpNode>()->body) {
+    extractor.Extract(exp);
+  }
+
+  std::ostringstream os;
+  uint i = 0;
+  const auto& placeholder_op = placeholder->op;
+  CHECK_GT(extractor.read_access.count(placeholder_op), 0);
+  for (const auto& ev : extractor.read_access[placeholder_op]) {
+    for (const auto& e : ev) {
+      std::string axis_name;
+      if (const auto* pimm = e.as<IntImmNode>()) {
+        CHECK_EQ(pimm->value, 0);
+        axis_name = "IntImm";
+      } else {
+        axis_name = BaseName(CleanName(Downcast<Var>(e)->name_hint));
+      }
+
+      placeholder_axis_names->insert(axis_name);
+      os << placeholder->shape[i++] << axis_name;
+    }
+  }
+
+  CHECK_EQ(placeholder_axis_names->size(), placeholder->shape.size());
+  std::string ori_layout = os.str();
+  os.str("");
+  // TODO(minmin): uncomment this line for relay integration
+  // ::tvm::relay::KernelLayoutTransformer::global_ori_layouts_queue.push_back(ori_layout);
+  return ori_layout;
+}
+
+std::string get_new_layout(Array<PrimExpr>* new_shape, const State& state, const int stage_id,
+                           const Stage& stage, const te::Operation& op,
+                           const te::Tensor& placeholder,
+                           const std::set<std::string>& placeholder_axis_names) {
+  std::ostringstream os;
+  Array<Iterator> stage_iters;
+
+  auto attach_it = state->attach_map->stage_to_attach_iter.find(stage_id);
+  int attach_pos = -1;
+  size_t iters_before_attach = 0;
+  if (attach_it != state->attach_map->stage_to_attach_iter.end()) {
+    auto attach = attach_it->second;
+    const auto& attach_stage = state->stages[attach.first];
+    attach_pos = attach.second;
+    stage_iters.insert(stage_iters.end(), attach_stage->iters.begin(),
+                       attach_stage->iters.begin() + attach_pos + 1);
+  }
+
+  stage_iters.insert(stage_iters.end(), stage->iters.begin(), stage->iters.end());
+
+  std::vector<Iterator> iters;
+  for (size_t i = 0; i < stage_iters.size(); ++i) {
+    const auto& iter = stage_iters[i];
+    if (iter->ori_iters.empty()) {
+      iters.push_back(iter);
+    } else {
+      for (const Iterator& ori_iter : iter->ori_iters) {
+        iters.push_back(ori_iter);
+      }
+    }
+    if (static_cast<int>(i) == attach_pos) {
+      iters_before_attach = iters.size();
+    }
+  }
+
+  std::vector<std::string> new_names;
+  std::vector<std::string> new_axis_names;
+  for (const Iterator& iter : iters) {
+    std::set<std::string> ori_iter_names;
+    ExtractOriginalIterators(iter->name, &ori_iter_names);
+    // fused iters have been replaced with iter->ori_iters.
+    // So there should be only one ori iter name extracted from iter->name.
+    CHECK_EQ(ori_iter_names.size(), 1);
+    auto ori_iter_name = BaseName(*ori_iter_names.begin());
+    new_axis_names.push_back(ori_iter_name);
+  }
+  for (size_t i = 0; i < new_axis_names.size(); ++i) {
+    auto iter = iters[i];
+    std::string ori_iter_name;
+    if (i < iters_before_attach) {
+      ori_iter_name = new_axis_names[i + iters_before_attach];
+    } else {
+      ori_iter_name = new_axis_names[i];
+    }
+    if (placeholder_axis_names.count(ori_iter_name)) {
+      os << iter->range->extent << ori_iter_name;
+      new_names.push_back(ori_iter_name);
+      new_shape->push_back(iter->range->extent);
+    }
+  }
+  std::string new_layout = os.str();
+  os.str("");
+  // TODO(minmin): uncomment this line for relay integration
+  // ::tvm::relay::KernelLayoutTransformer::global_new_layouts_queue.push_back(new_layout);
+  return new_layout;
+}
+
+void ComputeDAG::RewriteLayout(const Array<Step>& transform_steps) {
+  ComputeDAGNode* pdag = this->CopyOnWrite();
+  auto node = make_object<StateNode>();
+  node->transform_steps = transform_steps;
+  node->concrete = true;
+  const State& state = InferBound(State(node));
+  OperationSet handled_ops;
+  int stage_id = -1;
+  for (const auto& stage : state->stages) {
+    stage_id += 1;
+    const te::Operation& op = stage->op;
+    if (op->IsInstance<te::ComputeOpNode>()) {
+      const Map<String, ObjectRef>& attrs = op->attrs;
+      if (attrs.count(layout_free_placeholders_key)) {

Review comment:
       ```suggestion
         if (attrs.count(layout_free_placeholders_key) == 0) {
           continue;
         }
   ```




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] merrymercy commented on a change in pull request #6297: [Ansor][AutoTVM v2.0] Phase 2: Layout Rewrite in AutoScheduler

Posted by GitBox <gi...@apache.org>.
merrymercy commented on a change in pull request #6297:
URL: https://github.com/apache/incubator-tvm/pull/6297#discussion_r485321344



##########
File path: python/tvm/auto_scheduler/compute_dag.py
##########
@@ -81,12 +81,16 @@ def apply_steps_from_state(self, state):
         state : Union[State, StateObject]
             The state from which we get transform steps.
 
+        layout_rewrite: Bool
+            Rewrite the layout of placeholder to make it
+            most frendly for the generated schedule to read from.

Review comment:
       ```suggestion
           layout_rewrite: Bool
               Rewrite the layout of placeholders specified by "layout_free_placeholders" attr
               to make it most friendly for the generated schedule to read from.
   ```
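
    For reference, a rough usage sketch of this flag (not part of the PR itself; `task` is an
    auto_scheduler search task and `state` a tuned state, both assumed to come from a prior
    search):

    ```python
    # Sketch only: `task` and `state` are assumed to exist from a previous tuning run.
    sch, args = task.compute_dag.apply_steps_from_state(state, layout_rewrite=True)
    # The resulting schedule expects the layout-free placeholders in their rewritten
    # layout; end-to-end use still depends on the relay integration noted in the
    # TODOs of compute_dag.cc.
    ```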

##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -665,9 +666,319 @@ ComputeDAG::ComputeDAG(Array<te::Tensor> tensors) {
   data_ = std::move(node);
 }
 
+class IndexRewriter : public StmtExprMutator {
+ public:
+  IndexRewriter(const te::Operation& placeholder_op, const std::string& new_layout)
+      : placeholder_op_(placeholder_op) {
+    ParseKernelLayout(new_layout, &new_shape_, &new_names_);
+  }
+
+  PrimExpr Rewrite(PrimExpr expr) { return this->VisitExpr(expr); }
+
+  PrimExpr VisitExpr_(const ProducerLoadNode* op) final {
+    te::Tensor t = Downcast<te::Tensor>(op->producer);
+    if (t->op == placeholder_op_) {
+      std::unordered_map<std::string, PrimExpr> name_to_arg;
+      for (const auto& arg : op->indices) {
+        std::string axis_name;
+        if (const auto* pimm = arg.as<IntImmNode>()) {
+          CHECK_EQ(pimm->value, 0);
+          axis_name = "IntImm";
+        } else {
+          axis_name = AxisBaseName(CleanName(Downcast<Var>(arg)->name_hint));
+          CHECK_EQ(name_to_arg.count(axis_name), 0);
+          name_to_arg[axis_name] = arg;
+        }
+      }
+
+      std::unordered_map<std::string, PrimExpr> div_factors;
+      std::vector<PrimExpr> r_new_args;
+      for (int i = new_names_.size() - 1; i >= 0; --i) {
+        auto ori_iter_name = new_names_[i];
+        auto name_it = name_to_arg.find(ori_iter_name);
+        CHECK(name_it != name_to_arg.end());
+        PrimExpr ori_arg = name_it->second;
+
+        PrimExpr mod_factor = new_shape_[i];
+
+        PrimExpr div_factor = 1;
+        if (div_factors.count(ori_iter_name)) {
+          div_factor = div_factors[ori_iter_name];
+        }
+        div_factors[ori_iter_name] = div_factor * new_shape_[i];
+
+        PrimExpr new_arg = indexmod(indexdiv(ori_arg, div_factor), mod_factor);
+
+        r_new_args.push_back(new_arg);
+      }
+
+      Array<PrimExpr> new_args(std::make_move_iterator(r_new_args.rbegin()),
+                               std::make_move_iterator(r_new_args.rend()));
+      return ProducerLoad(op->producer, new_args);
+    }
+    return GetRef<PrimExpr>(op);
+  }
+
+ private:
+  const te::Operation& placeholder_op_;
+  Array<PrimExpr> new_shape_;
+  std::vector<std::string> new_names_;
+};
+
+std::string get_orig_layout(std::set<std::string>* placeholder_axis_names, const te::Operation& op,

Review comment:
       Code style
   get_orig_layout -> GetOrigLayout

##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -665,9 +666,319 @@ ComputeDAG::ComputeDAG(Array<te::Tensor> tensors) {
   data_ = std::move(node);
 }
 
+class IndexRewriter : public StmtExprMutator {
+ public:
+  IndexRewriter(const te::Operation& placeholder_op, const std::string& new_layout)
+      : placeholder_op_(placeholder_op) {
+    ParseKernelLayout(new_layout, &new_shape_, &new_names_);
+  }
+
+  PrimExpr Rewrite(PrimExpr expr) { return this->VisitExpr(expr); }
+
+  PrimExpr VisitExpr_(const ProducerLoadNode* op) final {
+    te::Tensor t = Downcast<te::Tensor>(op->producer);
+    if (t->op == placeholder_op_) {
+      std::unordered_map<std::string, PrimExpr> name_to_arg;
+      for (const auto& arg : op->indices) {
+        std::string axis_name;
+        if (const auto* pimm = arg.as<IntImmNode>()) {
+          CHECK_EQ(pimm->value, 0);
+          axis_name = "IntImm";
+        } else {
+          axis_name = AxisBaseName(CleanName(Downcast<Var>(arg)->name_hint));
+          CHECK_EQ(name_to_arg.count(axis_name), 0);
+          name_to_arg[axis_name] = arg;
+        }
+      }
+
+      std::unordered_map<std::string, PrimExpr> div_factors;
+      std::vector<PrimExpr> r_new_args;
+      for (int i = new_names_.size() - 1; i >= 0; --i) {
+        auto ori_iter_name = new_names_[i];
+        auto name_it = name_to_arg.find(ori_iter_name);
+        CHECK(name_it != name_to_arg.end());
+        PrimExpr ori_arg = name_it->second;
+
+        PrimExpr mod_factor = new_shape_[i];
+
+        PrimExpr div_factor = 1;
+        if (div_factors.count(ori_iter_name)) {
+          div_factor = div_factors[ori_iter_name];
+        }
+        div_factors[ori_iter_name] = div_factor * new_shape_[i];
+
+        PrimExpr new_arg = indexmod(indexdiv(ori_arg, div_factor), mod_factor);
+
+        r_new_args.push_back(new_arg);
+      }
+
+      Array<PrimExpr> new_args(std::make_move_iterator(r_new_args.rbegin()),
+                               std::make_move_iterator(r_new_args.rend()));
+      return ProducerLoad(op->producer, new_args);
+    }
+    return GetRef<PrimExpr>(op);
+  }
+
+ private:
+  const te::Operation& placeholder_op_;
+  Array<PrimExpr> new_shape_;
+  std::vector<std::string> new_names_;
+};
+
+std::string get_orig_layout(std::set<std::string>* placeholder_axis_names, const te::Operation& op,
+                            const te::Tensor& placeholder) {
+  ReadAccessExtractor extractor;
+  for (const auto& exp : op.as<te::ComputeOpNode>()->body) {
+    extractor.Extract(exp);
+  }
+
+  std::ostringstream os;
+  uint i = 0;
+  const auto& placeholder_op = placeholder->op;
+  CHECK_GT(extractor.read_access.count(placeholder_op), 0);
+  for (const auto& ev : extractor.read_access[placeholder_op]) {
+    for (const auto& e : ev) {
+      std::string axis_name;
+      if (const auto* pimm = e.as<IntImmNode>()) {
+        CHECK_EQ(pimm->value, 0);
+        axis_name = "IntImm";
+      } else {
+        axis_name = AxisBaseName(CleanName(Downcast<Var>(e)->name_hint));
+      }
+
+      placeholder_axis_names->insert(axis_name);
+      os << placeholder->shape[i++] << axis_name;
+    }
+  }
+
+  CHECK_EQ(placeholder_axis_names->size(), placeholder->shape.size());
+  std::string orig_layout = os.str();
+  os.str("");
+  // TODO(minmin): uncomment this line for relay integration
+  // ::tvm::relay::KernelLayoutTransformer::global_orig_layouts_queue.push_back(orig_layout);
+  return orig_layout;
+}
+
+std::string get_new_layout(Array<PrimExpr>* new_shape, const State& state, const int stage_id,

Review comment:
       Same naming style issue as above: `get_new_layout` -> `GetNewLayout`.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] merrymercy edited a comment on pull request #6297: [Ansor][AutoTVM v2.0] Phase 2: Layout Rewrite in AutoScheduler

Posted by GitBox <gi...@apache.org>.
merrymercy edited a comment on pull request #6297:
URL: https://github.com/apache/incubator-tvm/pull/6297#issuecomment-689286253


   @minminsun Please fix the final style comments and we can merge this.


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] jwfromm commented on a change in pull request #6297: [Ansor][AutoTVM v2.0] Phase 2: Layout Rewrite in AutoScheduler

Posted by GitBox <gi...@apache.org>.
jwfromm commented on a change in pull request #6297:
URL: https://github.com/apache/incubator-tvm/pull/6297#discussion_r484096899



##########
File path: python/tvm/auto_scheduler/compute_dag.py
##########
@@ -81,12 +81,15 @@ def apply_steps_from_state(self, state):
         state : Union[State, StateObject]
             The state from which we get transform steps.
 
+        layout_rewrite: Bool
+            Rewrite the layout of placeholder.

Review comment:
       This description doesn't add much beyond the variable name. What's the benefit of doing this? Maybe add a sentence describing when you'd want to set this.

##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -665,9 +666,349 @@ ComputeDAG::ComputeDAG(Array<te::Tensor> tensors) {
   data_ = std::move(node);
 }
 
+/*!
+ * \brief utility function for kernel_layout_transform
+ */
+inline void parse_kernel_layout(const String& layout, Array<PrimExpr>* shape,
+                                std::vector<std::string>* axes) {
+  int32_t factor = 0;
+  std::string axis = "";
+  for (char c : std::string(layout)) {
+    if (c >= 'A' && c <= 'z') {
+      axis += c;
+      if (factor != 0) {
+        shape->push_back(factor);
+        factor = 0;
+      }
+    } else if (c >= '0' && c <= '9') {
+      factor = factor * 10 + c - '0';
+      if (!axis.empty()) {
+        axes->push_back(axis);
+        axis = "";
+      }
+    } else {
+      LOG(FATAL) << "Invalid layout " << layout;
+    }
+  }
+  if (!axis.empty()) {
+    axes->push_back(axis);
+  }
+}
+
+std::string BaseName(const std::string& str) { return str.substr(0, str.rfind("_")); }
+
+class IndexRewriter : public StmtExprMutator {
+ public:
+  IndexRewriter(const te::Operation& placeholder_op, const std::string& new_layout)
+      : placeholder_op_(placeholder_op), new_layout_(new_layout) {}
+
+  PrimExpr Rewrite(PrimExpr expr) { return this->VisitExpr(expr); }
+
+  PrimExpr VisitExpr_(const ProducerLoadNode* op) final {
+    te::Tensor t = Downcast<te::Tensor>(op->producer);
+    if (t->op == placeholder_op_) {
+      Array<PrimExpr> new_shape;
+      std::vector<std::string> new_names;
+      parse_kernel_layout(new_layout_, &new_shape, &new_names);
+      std::unordered_map<std::string, PrimExpr> name_to_arg;
+      for (const auto& arg : op->indices) {
+        std::string axis_name;
+        if (const auto* pimm = arg.as<IntImmNode>()) {
+          CHECK_EQ(pimm->value, 0);
+          axis_name = "IntImm";
+        } else {
+          axis_name = BaseName(CleanName(Downcast<Var>(arg)->name_hint));
+          CHECK_EQ(name_to_arg.count(axis_name), 0);
+          name_to_arg[axis_name] = arg;
+        }
+      }
+
+      std::unordered_map<std::string, PrimExpr> div_factors;
+      std::vector<PrimExpr> r_new_args;
+      for (int i = new_names.size() - 1; i >= 0; --i) {
+        auto ori_iter_name = new_names[i];
+        auto name_it = name_to_arg.find(ori_iter_name);
+        CHECK(name_it != name_to_arg.end());
+        PrimExpr ori_arg = name_it->second;
+
+        PrimExpr mod_factor = new_shape[i];
+
+        PrimExpr div_factor = 1;
+        if (div_factors.count(ori_iter_name)) {
+          div_factor = div_factors[ori_iter_name];
+        }
+        div_factors[ori_iter_name] = div_factor * new_shape[i];
+
+        PrimExpr new_arg = indexmod(indexdiv(ori_arg, div_factor), mod_factor);
+
+        r_new_args.push_back(new_arg);
+      }
+
+      Array<PrimExpr> new_args(std::make_move_iterator(r_new_args.rbegin()),
+                               std::make_move_iterator(r_new_args.rend()));
+      return ProducerLoad(op->producer, new_args);
+    }
+    return GetRef<PrimExpr>(op);
+  }
+
+ private:
+  const te::Operation& placeholder_op_;
+  const std::string& new_layout_;
+};
+
+std::string get_ori_layout(std::set<std::string>* placeholder_axis_names, const te::Operation& op,

Review comment:
       Instead of using `ori` as shorthand for "original", it's probably worth the extra letter to go with `orig`, which is much less ambiguous.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] jcf94 commented on a change in pull request #6297: [Ansor][AutoTVM v2.0] Phase 2: Layout Rewrite in AutoScheduler

Posted by GitBox <gi...@apache.org>.
jcf94 commented on a change in pull request #6297:
URL: https://github.com/apache/incubator-tvm/pull/6297#discussion_r472052291



##########
File path: include/tvm/auto_scheduler/compute_dag.h
##########
@@ -205,19 +205,28 @@ class ComputeDAG : public ObjectRef {
    */
   TVM_DLL explicit ComputeDAG(Array<te::Tensor> tensors);
 
+  /*!
+   * \brief Rewrite the layout of placeholder specified by attr `layout_free_placeholders`
+   * according to the loop nest derived with `transform_steps`.
+   * \param transform_steps Transform steps of a state.
+   */
+  void RewriteLayout(
+    const Array<Step> &transform_steps);

Review comment:
       Formatting issue. This should be fixed for these changes to pass the CI lint check.

##########
File path: python/tvm/auto_scheduler/compute_dag.py
##########
@@ -72,7 +72,7 @@ def get_init_state(self):
         """
         return State(self.init_state, self)
 
-    def apply_steps_from_state(self, state):
+    def apply_steps_from_state(self, state, layout_rewrite=False):

Review comment:
       Add `layout_rewrite` to the docstring below.
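
   For illustration only, one possible shape for that docstring entry (the wording below is a sketch, not the text that was committed):

```python
def apply_steps_from_state(self, state, layout_rewrite=False):
    """Apply the transform steps of a State to get a TVM schedule.

    Parameters
    ----------
    state : Union[State, StateObject]
        The state from which the transform steps are taken.
    layout_rewrite : bool
        Rewrite the layout of placeholders specified by the
        ``layout_free_placeholders`` attr to fit the loop nest of the
        generated schedule. Defaults to False.
    """
```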

##########
File path: python/tvm/auto_scheduler/measure.py
##########
@@ -419,7 +419,7 @@ def timed_func():
 
         try:
             sch, args = task.compute_dag.apply_steps_from_state(
-                inp.state)
+                inp.state, layout_rewrite=True)

Review comment:
       Consider getting the `layout_rewrite` parameter from outside (e.g. from GLOBAL_BUILD_ARGUMENTS)?
   It's not ideal to just hard-code it to `True` here.
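
   A rough sketch of that suggestion (whether GLOBAL_BUILD_ARGUMENTS carries such a flag is an assumption made here for illustration, not an existing field):

```python
# Sketch: thread layout_rewrite through the builder's global arguments instead
# of hard-coding True. The tuple layout below is assumed for illustration.
GLOBAL_BUILD_ARGUMENTS = None  # set by the builder before forking worker processes

def build_one(task, inp):
    # Unpack an (assumed) layout_rewrite flag alongside the other build arguments.
    *_, layout_rewrite = GLOBAL_BUILD_ARGUMENTS
    sch, args = task.compute_dag.apply_steps_from_state(
        inp.state, layout_rewrite=layout_rewrite)
    return sch, args
```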







[GitHub] [incubator-tvm] minminsun commented on a change in pull request #6297: [Ansor][AutoTVM v2.0] Phase 2: Layout Rewrite in AutoScheduler

Posted by GitBox <gi...@apache.org>.
minminsun commented on a change in pull request #6297:
URL: https://github.com/apache/incubator-tvm/pull/6297#discussion_r477107266



##########
File path: python/tvm/auto_scheduler/measure.py
##########
@@ -419,7 +419,7 @@ def timed_func():
 
         try:
             sch, args = task.compute_dag.apply_steps_from_state(
-                inp.state)
+                inp.state, layout_rewrite=True)

Review comment:
       Whether layout rewrite is done for an op is specified by the `layout_free_placeholders` attr in its compute definition, so it's safe to set this to `True` by default here.
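
   For readers following along, a minimal compute definition that opts into the rewrite looks roughly like this (an illustrative sketch, not code from this PR):

```python
from tvm import te

N, M, K = 128, 128, 128
A = te.placeholder((N, K), name="A")
B = te.placeholder((K, M), name="B")  # weight whose layout may be rewritten
k = te.reduce_axis((0, K), name="k")
# Listing B under `layout_free_placeholders` is what allows the auto-scheduler
# to rewrite its layout for this op.
C = te.compute(
    (N, M),
    lambda i, j: te.sum(A[i, k] * B[k, j], axis=k),
    name="C",
    attrs={"layout_free_placeholders": [B]},
)
```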







[GitHub] [incubator-tvm] minminsun commented on a change in pull request #6297: [Ansor][AutoTVM v2.0] Phase 2: Layout Rewrite in AutoScheduler

Posted by GitBox <gi...@apache.org>.
minminsun commented on a change in pull request #6297:
URL: https://github.com/apache/incubator-tvm/pull/6297#discussion_r477112491



##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -665,9 +666,349 @@ ComputeDAG::ComputeDAG(Array<te::Tensor> tensors) {
   data_ = std::move(node);
 }
 
+/*!
+ * \brief utility function for kernel_layout_transform
+ */
+inline void parse_kernel_layout(const String& layout, Array<PrimExpr>* shape,
+                                std::vector<std::string>* axes) {
+  int32_t factor = 0;
+  std::string axis = "";
+  for (char c : std::string(layout)) {
+    if (c >= 'A' && c <= 'z') {
+      axis += c;
+      if (factor != 0) {
+        shape->push_back(factor);
+        factor = 0;
+      }
+    } else if (c >= '0' && c <= '9') {
+      factor = factor * 10 + c - '0';
+      if (!axis.empty()) {
+        axes->push_back(axis);
+        axis = "";
+      }
+    } else {
+      LOG(FATAL) << "Invalid layout " << layout;
+    }
+  }
+  if (!axis.empty()) {
+    axes->push_back(axis);
+  }
+}
+
+std::string BaseName(const std::string& str) { return str.substr(0, str.rfind("_")); }
+
+class IndexRewriter : public StmtExprMutator {
+ public:
+  IndexRewriter(const te::Operation& placeholder_op, const std::string& new_layout)
+      : placeholder_op_(placeholder_op), new_layout_(new_layout) {}
+
+  PrimExpr Rewrite(PrimExpr expr) { return this->VisitExpr(expr); }
+
+  PrimExpr VisitExpr_(const ProducerLoadNode* op) final {
+    te::Tensor t = Downcast<te::Tensor>(op->producer);
+    if (t->op == placeholder_op_) {
+      Array<PrimExpr> new_shape;
+      std::vector<std::string> new_names;
+      parse_kernel_layout(new_layout_, &new_shape, &new_names);
+      std::unordered_map<std::string, PrimExpr> name_to_arg;
+      for (const auto& arg : op->indices) {
+        std::string axis_name;
+        if (const auto* pimm = arg.as<IntImmNode>()) {
+          CHECK_EQ(pimm->value, 0);
+          axis_name = "IntImm";
+        } else {
+          axis_name = BaseName(CleanName(Downcast<Var>(arg)->name_hint));
+          CHECK_EQ(name_to_arg.count(axis_name), 0);
+          name_to_arg[axis_name] = arg;
+        }
+      }
+
+      std::unordered_map<std::string, PrimExpr> div_factors;
+      std::vector<PrimExpr> r_new_args;
+      for (int i = new_names.size() - 1; i >= 0; --i) {
+        auto ori_iter_name = new_names[i];
+        auto name_it = name_to_arg.find(ori_iter_name);
+        CHECK(name_it != name_to_arg.end());
+        PrimExpr ori_arg = name_it->second;
+
+        PrimExpr mod_factor = new_shape[i];
+
+        PrimExpr div_factor = 1;
+        if (div_factors.count(ori_iter_name)) {
+          div_factor = div_factors[ori_iter_name];
+        }
+        div_factors[ori_iter_name] = div_factor * new_shape[i];
+
+        PrimExpr new_arg = indexmod(indexdiv(ori_arg, div_factor), mod_factor);
+
+        r_new_args.push_back(new_arg);
+      }
+
+      Array<PrimExpr> new_args(std::make_move_iterator(r_new_args.rbegin()),
+                               std::make_move_iterator(r_new_args.rend()));
+      return ProducerLoad(op->producer, new_args);
+    }
+    return GetRef<PrimExpr>(op);
+  }
+
+ private:
+  const te::Operation& placeholder_op_;
+  const std::string& new_layout_;
+};
+
+std::string get_ori_layout(std::set<std::string>* placeholder_axis_names, const te::Operation& op,
+                           const te::Tensor& placeholder) {
+  ReadAccessExtractor extractor;
+  for (const auto& exp : op.as<te::ComputeOpNode>()->body) {
+    extractor.Extract(exp);
+  }
+
+  std::ostringstream os;
+  uint i = 0;
+  const auto& placeholder_op = placeholder->op;
+  CHECK_GT(extractor.read_access.count(placeholder_op), 0);
+  for (const auto& ev : extractor.read_access[placeholder_op]) {
+    for (const auto& e : ev) {
+      std::string axis_name;
+      if (const auto* pimm = e.as<IntImmNode>()) {
+        CHECK_EQ(pimm->value, 0);
+        axis_name = "IntImm";
+      } else {
+        axis_name = BaseName(CleanName(Downcast<Var>(e)->name_hint));
+      }
+
+      placeholder_axis_names->insert(axis_name);
+      os << placeholder->shape[i++] << axis_name;
+    }
+  }
+
+  CHECK_EQ(placeholder_axis_names->size(), placeholder->shape.size());
+  std::string ori_layout = os.str();
+  os.str("");
+  // TODO(minmin): uncomment this line for relay integration
+  // ::tvm::relay::KernelLayoutTransformer::global_ori_layouts_queue.push_back(ori_layout);
+  return ori_layout;
+}
+
+std::string get_new_layout(Array<PrimExpr>* new_shape, const State& state, const int stage_id,
+                           const Stage& stage, const te::Operation& op,
+                           const te::Tensor& placeholder,
+                           const std::set<std::string>& placeholder_axis_names) {
+  std::ostringstream os;
+  Array<Iterator> stage_iters;
+
+  auto attach_it = state->attach_map->stage_to_attach_iter.find(stage_id);
+  int attach_pos = -1;
+  size_t iters_before_attach = 0;
+  if (attach_it != state->attach_map->stage_to_attach_iter.end()) {
+    auto attach = attach_it->second;
+    const auto& attach_stage = state->stages[attach.first];
+    attach_pos = attach.second;
+    stage_iters.insert(stage_iters.end(), attach_stage->iters.begin(),
+                       attach_stage->iters.begin() + attach_pos + 1);
+  }
+
+  stage_iters.insert(stage_iters.end(), stage->iters.begin(), stage->iters.end());
+
+  std::vector<Iterator> iters;
+  for (size_t i = 0; i < stage_iters.size(); ++i) {
+    const auto& iter = stage_iters[i];
+    if (iter->ori_iters.empty()) {
+      iters.push_back(iter);
+    } else {
+      for (const Iterator& ori_iter : iter->ori_iters) {
+        iters.push_back(ori_iter);
+      }
+    }
+    if (static_cast<int>(i) == attach_pos) {
+      iters_before_attach = iters.size();
+    }
+  }
+
+  std::vector<std::string> new_names;
+  std::vector<std::string> new_axis_names;
+  for (const Iterator& iter : iters) {
+    std::set<std::string> ori_iter_names;
+    ExtractOriginalIterators(iter->name, &ori_iter_names);
+    // fused iters have been replaced with iter->ori_iters.
+    // So there should be only one ori iter name extracted from iter->name.
+    CHECK_EQ(ori_iter_names.size(), 1);
+    auto ori_iter_name = BaseName(*ori_iter_names.begin());
+    new_axis_names.push_back(ori_iter_name);
+  }
+  for (size_t i = 0; i < new_axis_names.size(); ++i) {
+    auto iter = iters[i];
+    std::string ori_iter_name;
+    if (i < iters_before_attach) {
+      ori_iter_name = new_axis_names[i + iters_before_attach];
+    } else {
+      ori_iter_name = new_axis_names[i];
+    }
+    if (placeholder_axis_names.count(ori_iter_name)) {
+      os << iter->range->extent << ori_iter_name;
+      new_names.push_back(ori_iter_name);
+      new_shape->push_back(iter->range->extent);
+    }
+  }
+  std::string new_layout = os.str();
+  os.str("");
+  // TODO(minmin): uncomment this line for relay integration
+  // ::tvm::relay::KernelLayoutTransformer::global_new_layouts_queue.push_back(new_layout);
+  return new_layout;
+}
+
+void ComputeDAG::RewriteLayout(const Array<Step>& transform_steps) {
+  ComputeDAGNode* pdag = this->CopyOnWrite();
+  auto node = make_object<StateNode>();
+  node->transform_steps = transform_steps;
+  node->concrete = true;
+  const State& state = InferBound(State(node));
+  OperationSet handled_ops;
+  int stage_id = -1;
+  for (const auto& stage : state->stages) {
+    stage_id += 1;
+    const te::Operation& op = stage->op;
+    if (op->IsInstance<te::ComputeOpNode>()) {
+      const Map<String, ObjectRef>& attrs = op->attrs;
+      if (attrs.count(layout_free_placeholders_key)) {
+        const ObjectRef& attr_value = attrs[layout_free_placeholders_key];
+        Array<te::Tensor> placeholders = Downcast<Array<te::Tensor>>(attr_value);
+        for (const auto& placeholder : placeholders) {
+          const auto& placeholder_op = placeholder->op;
+
+          // Check whether this placeholder has already been handled
+          if (handled_ops.count(placeholder_op)) {
+            continue;
+          }
+
+          // skip the op that is not direct consumer of this placeholder,
+          // mostly due to cache read/write.

Review comment:
       Done.







[GitHub] [incubator-tvm] jcf94 commented on a change in pull request #6297: [Ansor][AutoTVM v2.0] Phase 2: Layout Rewrite in AutoScheduler

Posted by GitBox <gi...@apache.org>.
jcf94 commented on a change in pull request #6297:
URL: https://github.com/apache/incubator-tvm/pull/6297#discussion_r472094744



##########
File path: python/tvm/auto_scheduler/measure.py
##########
@@ -419,7 +419,7 @@ def timed_func():
 
         try:
             sch, args = task.compute_dag.apply_steps_from_state(
-                inp.state)
+                inp.state, layout_rewrite=True)

Review comment:
       Add a separate unit test for `layout_rewrite` to make sure it works well and produces correct results.
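
   As a rough illustration, such a test could look like the sketch below (the iterator indices and the assertion are assumptions, not the unit test that was eventually added):

```python
from tvm import te, auto_scheduler

def test_layout_rewrite_sketch():
    N = 128
    A = te.placeholder((N, N), name="A")
    B = te.placeholder((N, N), name="B")
    k = te.reduce_axis((0, N), name="k")
    C = te.compute((N, N), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k),
                   name="C", attrs={"layout_free_placeholders": [B]})

    dag = auto_scheduler.ComputeDAG([A, B, C])
    state = dag.get_init_state()
    # Split the loops that index into B so the rewritten layout gains dimensions.
    state.split(C, state[C].iters[1], [8])   # j -> j.o, j.i
    state.split(C, state[C].iters[3], [8])   # k -> k.o, k.i

    sch, args = dag.apply_steps_from_state(state, layout_rewrite=True)
    # If the rewrite was applied, the tensor bound for B (assumed to be args[1])
    # should no longer be the original 2-D (N, N) placeholder.
    assert len(args[1].shape) > 2
```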







[GitHub] [incubator-tvm] merrymercy commented on a change in pull request #6297: [Ansor][AutoTVM v2.0] Phase 2: Layout Rewrite in AutoScheduler

Posted by GitBox <gi...@apache.org>.
merrymercy commented on a change in pull request #6297:
URL: https://github.com/apache/incubator-tvm/pull/6297#discussion_r481595812



##########
File path: include/tvm/auto_scheduler/compute_dag.h
##########
@@ -205,19 +205,28 @@ class ComputeDAG : public ObjectRef {
    */
   TVM_DLL explicit ComputeDAG(Array<te::Tensor> tensors);
 
+  /*!
+   * \brief Rewrite the layout of placeholder specified by attr `layout_free_placeholders`
+   * according to the loop nest derived with `transform_steps`.
+   * \param transform_steps Transform steps of a state.
+   */
+  void RewriteLayout(const Array<Step>& transform_steps);
+
   /*!
    * \brief Apply the history transform steps to get a TVM schedule.
    * \param transform_steps Transform steps of a state.
    * \param stages The list of stages after applying the steps.
    * Pass a valid pointer if this information needs to be used outside this function.
    * \param stage_to_axes The map that stores all axes for one stage.
    * Pass a valid pointer if this information needs to be used outside this function.
+   * \param layout_rewrite Rewrite the layout of placeholder.

Review comment:
       ```suggestion
      * \param layout_rewrite Rewrite the layout of placeholders specified by
      * attr `layout_free_placeholders`
   ```







[GitHub] [incubator-tvm] minminsun commented on a change in pull request #6297: [Ansor][AutoTVM v2.0] Phase 2: Layout Rewrite in AutoScheduler

Posted by GitBox <gi...@apache.org>.
minminsun commented on a change in pull request #6297:
URL: https://github.com/apache/incubator-tvm/pull/6297#discussion_r477111962



##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -40,6 +40,7 @@
 #include <vector>
 
 #include "../arith/pattern_match.h"
+#include "search_policy/utils.h"

Review comment:
       Done.







[GitHub] [incubator-tvm] tqchen commented on pull request #6297: [Ansor][AutoTVM v2.0] Phase 2: Layout Rewrite in AutoScheduler

Posted by GitBox <gi...@apache.org>.
tqchen commented on pull request #6297:
URL: https://github.com/apache/incubator-tvm/pull/6297#issuecomment-686729085


   cc @jwfromm @mbrookhart @yzhliu 





[GitHub] [incubator-tvm] jcf94 commented on a change in pull request #6297: [Ansor][AutoTVM v2.0] Phase 2: Layout Rewrite in AutoScheduler

Posted by GitBox <gi...@apache.org>.
jcf94 commented on a change in pull request #6297:
URL: https://github.com/apache/incubator-tvm/pull/6297#discussion_r472109982



##########
File path: include/tvm/auto_scheduler/transform_step.h
##########
@@ -118,6 +123,8 @@ class IteratorNode : public Object {
   IteratorKind iter_kind;
   /*! \brief The annotation type of this iterator. */
   IteratorAnnotation annotation;
+  /*! The original iterators before fusion. */
+  std::vector<Iterator> ori_iters;

Review comment:
       Is there a better way to store the original iterators? Maybe get them directly from the ComputeDAG's init_state or the original op?
   It seems strange to have an Iterator list inside an IteratorNode object.







[GitHub] [incubator-tvm] merrymercy commented on pull request #6297: [Ansor][AutoTVM v2.0] Phase 2: Layout Rewrite in AutoScheduler

Posted by GitBox <gi...@apache.org>.
merrymercy commented on pull request #6297:
URL: https://github.com/apache/incubator-tvm/pull/6297#issuecomment-689286253


   Please fix the final style comments and we can merge this.





[GitHub] [incubator-tvm] minminsun commented on a change in pull request #6297: [Ansor][AutoTVM v2.0] Phase 2: Layout Rewrite in AutoScheduler

Posted by GitBox <gi...@apache.org>.
minminsun commented on a change in pull request #6297:
URL: https://github.com/apache/incubator-tvm/pull/6297#discussion_r477112240



##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -665,9 +666,349 @@ ComputeDAG::ComputeDAG(Array<te::Tensor> tensors) {
   data_ = std::move(node);
 }
 
+/*!
+ * \brief utility function for kernel_layout_transform
+ */
+inline void parse_kernel_layout(const String& layout, Array<PrimExpr>* shape,
+                                std::vector<std::string>* axes) {
+  int32_t factor = 0;
+  std::string axis = "";
+  for (char c : std::string(layout)) {
+    if (c >= 'A' && c <= 'z') {
+      axis += c;
+      if (factor != 0) {
+        shape->push_back(factor);
+        factor = 0;
+      }
+    } else if (c >= '0' && c <= '9') {
+      factor = factor * 10 + c - '0';
+      if (!axis.empty()) {
+        axes->push_back(axis);
+        axis = "";
+      }
+    } else {
+      LOG(FATAL) << "Invalid layout " << layout;
+    }
+  }
+  if (!axis.empty()) {
+    axes->push_back(axis);
+  }
+}
+
+std::string BaseName(const std::string& str) { return str.substr(0, str.rfind("_")); }

Review comment:
       Done.

##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -665,9 +666,349 @@ ComputeDAG::ComputeDAG(Array<te::Tensor> tensors) {
   data_ = std::move(node);
 }
 
+/*!
+ * \brief utility function for kernel_layout_transform
+ */
+inline void parse_kernel_layout(const String& layout, Array<PrimExpr>* shape,
+                                std::vector<std::string>* axes) {
+  int32_t factor = 0;
+  std::string axis = "";
+  for (char c : std::string(layout)) {
+    if (c >= 'A' && c <= 'z') {
+      axis += c;
+      if (factor != 0) {
+        shape->push_back(factor);
+        factor = 0;
+      }
+    } else if (c >= '0' && c <= '9') {
+      factor = factor * 10 + c - '0';
+      if (!axis.empty()) {
+        axes->push_back(axis);
+        axis = "";
+      }
+    } else {
+      LOG(FATAL) << "Invalid layout " << layout;
+    }
+  }
+  if (!axis.empty()) {
+    axes->push_back(axis);
+  }
+}
+
+std::string BaseName(const std::string& str) { return str.substr(0, str.rfind("_")); }
+
+class IndexRewriter : public StmtExprMutator {
+ public:
+  IndexRewriter(const te::Operation& placeholder_op, const std::string& new_layout)
+      : placeholder_op_(placeholder_op), new_layout_(new_layout) {}
+
+  PrimExpr Rewrite(PrimExpr expr) { return this->VisitExpr(expr); }
+
+  PrimExpr VisitExpr_(const ProducerLoadNode* op) final {
+    te::Tensor t = Downcast<te::Tensor>(op->producer);
+    if (t->op == placeholder_op_) {
+      Array<PrimExpr> new_shape;
+      std::vector<std::string> new_names;
+      parse_kernel_layout(new_layout_, &new_shape, &new_names);

Review comment:
       Done.







[GitHub] [incubator-tvm] jcf94 commented on a change in pull request #6297: [Ansor][AutoTVM v2.0] Phase 2: Layout Rewrite in AutoScheduler

Posted by GitBox <gi...@apache.org>.
jcf94 commented on a change in pull request #6297:
URL: https://github.com/apache/incubator-tvm/pull/6297#discussion_r472094744



##########
File path: python/tvm/auto_scheduler/measure.py
##########
@@ -419,7 +419,7 @@ def timed_func():
 
         try:
             sch, args = task.compute_dag.apply_steps_from_state(
-                inp.state)
+                inp.state, layout_rewrite=True)

Review comment:
       Add a separate unit test for `layout_rewrite`.







[GitHub] [incubator-tvm] minminsun commented on a change in pull request #6297: [Ansor][AutoTVM v2.0] Phase 2: Layout Rewrite in AutoScheduler

Posted by GitBox <gi...@apache.org>.
minminsun commented on a change in pull request #6297:
URL: https://github.com/apache/incubator-tvm/pull/6297#discussion_r477105200



##########
File path: python/tvm/auto_scheduler/compute_dag.py
##########
@@ -72,7 +72,7 @@ def get_init_state(self):
         """
         return State(self.init_state, self)
 
-    def apply_steps_from_state(self, state):
+    def apply_steps_from_state(self, state, layout_rewrite=False):

Review comment:
       done



