You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/09/06 18:00:05 UTC

[GitHub] [incubator-tvm] jwfromm commented on a change in pull request #6297: [Ansor][AutoTVM v2.0] Phase 2: Layout Rewrite in AutoScheduler

jwfromm commented on a change in pull request #6297:
URL: https://github.com/apache/incubator-tvm/pull/6297#discussion_r484096899



##########
File path: python/tvm/auto_scheduler/compute_dag.py
##########
@@ -81,12 +81,15 @@ def apply_steps_from_state(self, state):
         state : Union[State, StateObject]
             The state from which we get transform steps.
 
+        layout_rewrite: Bool
+            Rewrite the layout of placeholder.

Review comment:
       This description doesn't add much beyond the variable name. What's the benefit of doing this? Maybe add a sentence describing when you'd want to set this.

##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -665,9 +666,349 @@ ComputeDAG::ComputeDAG(Array<te::Tensor> tensors) {
   data_ = std::move(node);
 }
 
+/*!
+ * \brief Parse a layout string into parallel lists of extents and axis names
+ *        (utility for kernel_layout_transform).
+ *
+ * Characters in the range 'A'..'z' accumulate into an axis name and decimal
+ * digits accumulate into an integer extent. A pending extent is appended to
+ * `shape` when the next name starts, and a pending name is appended to `axes`
+ * when the next digit run starts. Only a trailing *name* is flushed after the
+ * loop.
+ *
+ * Consequently, `shape` and `axes` line up one-to-one only when the layout is
+ * a sequence of <extent><name> pairs. For example, "16i3j" yields
+ * shape {16, 3} and axes {"i", "j"}. Presumably callers always emit that
+ * form; TODO confirm against the layout producer.
+ *
+ * \param layout The layout string to parse.
+ * \param shape  Output: extents are appended in order of appearance.
+ * \param axes   Output: axis names are appended in order of appearance.
+ *
+ * Any character that is neither in 'A'..'z' nor a digit aborts via
+ * LOG(FATAL).
+ *
+ * NOTE(review): 'A'..'z' also admits '[', backslash, ']', '^', '_' and the
+ * backtick. The '_' may be intentional for axis names containing
+ * underscores; confirm and consider an explicit isalpha/'_' test.
+ *
+ * NOTE(review): an extent of 0 is never pushed (the `factor != 0` check), and
+ * a trailing digit run is silently dropped. Either case desynchronizes
+ * `shape` from `axes`.
+ */
+inline void parse_kernel_layout(const String& layout, Array<PrimExpr>* shape,
+                                std::vector<std::string>* axes) {
+  int32_t factor = 0;
+  std::string axis = "";
+  for (char c : std::string(layout)) {
+    if (c >= 'A' && c <= 'z') {
+      // Start/continue a name; the first letter closes any pending extent.
+      axis += c;
+      if (factor != 0) {
+        shape->push_back(factor);
+        factor = 0;
+      }
+    } else if (c >= '0' && c <= '9') {
+      // Start/continue an extent; the first digit closes any pending name.
+      factor = factor * 10 + c - '0';
+      if (!axis.empty()) {
+        axes->push_back(axis);
+        axis = "";
+      }
+    } else {
+      LOG(FATAL) << "Invalid layout " << layout;
+    }
+  }
+  // Only a trailing name is flushed here; a trailing extent is not.
+  if (!axis.empty()) {
+    axes->push_back(axis);
+  }
+}
+
+/*!
+ * \brief Return `str` truncated before its last '_'. If `str` contains no
+ *        '_', the whole string is returned (rfind yields npos).
+ *
+ * Examples: "i_0" -> "i", "a_b_c" -> "a_b", "i" -> "i".
+ * IndexRewriter uses it to reduce loop-variable names to the base axis names
+ * that appear in the layout string.
+ *
+ * NOTE(review): this has external linkage in a .cc file; consider `static`
+ * or an anonymous namespace to avoid ODR clashes.
+ */
+std::string BaseName(const std::string& str) { return str.substr(0, str.rfind("_")); }
+
+/*!
+ * \brief Rewrites loads from `placeholder_op` so that they index into the
+ *        layout described by `new_layout`.
+ *
+ * Each ProducerLoad whose producer tensor's op equals `placeholder_op` is
+ * handled as follows:
+ *  - Each original index must be either the constant 0 or a plain Var.
+ *  - Var indices are keyed by BaseName(CleanName(var->name_hint)).
+ *  - Every axis name parsed from `new_layout` must match one of those keys.
+ *
+ * Walking the parsed axes from last to first, the new index at position i is
+ *     (orig_index / <product of extents already seen for that axis>) % extent_i
+ * so repeated occurrences of one axis split it into mixed-radix digits, with
+ * the rightmost occurrence least significant. Positions holding the constant
+ * 0 are not carried over; the result has one index per parsed axis.
+ *
+ * Loads of any other tensor are returned unchanged.
+ *
+ * NOTE(review): the members below are const references, so the arguments
+ * passed to the constructor must outlive the rewriter. Passing temporaries
+ * would dangle.
+ */
+class IndexRewriter : public StmtExprMutator {
+ public:
+  IndexRewriter(const te::Operation& placeholder_op, const std::string& new_layout)
+      : placeholder_op_(placeholder_op), new_layout_(new_layout) {}
+
+  /*! \brief Apply the rewrite to `expr` (dispatches through VisitExpr). */
+  PrimExpr Rewrite(PrimExpr expr) { return this->VisitExpr(expr); }
+
+  PrimExpr VisitExpr_(const ProducerLoadNode* op) final {
+    te::Tensor t = Downcast<te::Tensor>(op->producer);
+    if (t->op == placeholder_op_) {
+      // NOTE(review): the layout is re-parsed for every matching load; it
+      // could be parsed once in the constructor.
+      Array<PrimExpr> new_shape;
+      std::vector<std::string> new_names;
+      parse_kernel_layout(new_layout_, &new_shape, &new_names);
+      // Map each base axis name to the original index expression using it.
+      std::unordered_map<std::string, PrimExpr> name_to_arg;
+      for (const auto& arg : op->indices) {
+        std::string axis_name;
+        if (const auto* pimm = arg.as<IntImmNode>()) {
+          // Only a constant-0 index is accepted; it is not recorded.
+          // NOTE(review): the "IntImm" assignment is never read afterwards.
+          CHECK_EQ(pimm->value, 0);
+          axis_name = "IntImm";
+        } else {
+          axis_name = BaseName(CleanName(Downcast<Var>(arg)->name_hint));
+          CHECK_EQ(name_to_arg.count(axis_name), 0);
+          name_to_arg[axis_name] = arg;
+        }
+      }
+
+      // div_factors[name]: product of extents of `name` already consumed
+      // by positions to the right of the current one.
+      std::unordered_map<std::string, PrimExpr> div_factors;
+      std::vector<PrimExpr> r_new_args;
+      for (int i = new_names.size() - 1; i >= 0; --i) {
+        auto ori_iter_name = new_names[i];
+        auto name_it = name_to_arg.find(ori_iter_name);
+        CHECK(name_it != name_to_arg.end());
+        PrimExpr ori_arg = name_it->second;
+
+        // Assumes new_shape[i] pairs with new_names[i]; see the notes on
+        // parse_kernel_layout's format.
+        PrimExpr mod_factor = new_shape[i];
+
+        PrimExpr div_factor = 1;
+        if (div_factors.count(ori_iter_name)) {
+          div_factor = div_factors[ori_iter_name];
+        }
+        div_factors[ori_iter_name] = div_factor * new_shape[i];
+
+        PrimExpr new_arg = indexmod(indexdiv(ori_arg, div_factor), mod_factor);
+
+        r_new_args.push_back(new_arg);
+      }
+
+      // r_new_args was built right-to-left; reverse it into index order.
+      Array<PrimExpr> new_args(std::make_move_iterator(r_new_args.rbegin()),
+                               std::make_move_iterator(r_new_args.rend()));
+      return ProducerLoad(op->producer, new_args);
+    }
+    // NOTE(review): returned as-is; indices of this load are not visited.
+    return GetRef<PrimExpr>(op);
+  }
+
+ private:
+  const te::Operation& placeholder_op_;
+  const std::string& new_layout_;
+};
+
+std::string get_ori_layout(std::set<std::string>* placeholder_axis_names, const te::Operation& op,

Review comment:
       Instead of using `ori` as shorthand for "original", it's probably worth the extra letter to go with `orig`, which is much less ambiguous.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org