You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/07/29 07:19:01 UTC

[GitHub] [incubator-tvm] merrymercy commented on a change in pull request #6141: [Ansor][AutoTVM v2.0] Phase 1: Add pragma/storage_align/rfactor steps

merrymercy commented on a change in pull request #6141:
URL: https://github.com/apache/incubator-tvm/pull/6141#discussion_r461151066



##########
File path: include/tvm/auto_scheduler/transform_step.h
##########
@@ -812,6 +937,74 @@ class CacheWriteStep : public Step {
   TVM_DEFINE_OBJECT_REF_METHODS(CacheWriteStep, Step, CacheWriteStepNode);
 };
 
+/*! \brief Reduction factor step that corresponds to te::Schedule::rfactor */
+class RfactorStepNode : public StepNode {
+ public:
+  /*! \brief The index of the iterator to be factored. */
+  int iter_id;
+  /*! \brief The position where the new iterator is placed. */
+  int factor_iter_id;
+
+  void WriteToRecord(dmlc::JSONWriter* writer) const final;
+
+  /*!
+   * \brief Apply the current step to State.
+   * \param state A mutable pointer to State, which will be updated.
+   * \param dag The original ComputeDAG of this state.
+   * \return The index of the new added stage.
+   */
+  int ApplyToState(State* state, const ComputeDAG& dag) const;
+
+  /*!
+   * \brief Apply the current step to tvm.schedule.
+   * \param stages The `te::Stage`s used in TVM scheduler applying.
+   * \param stage_to_axes The `te::Stage` and `tir::IterVar` map..
+   * \param schedule A mutable pointer to a te::Schedule.
+   * \return The output Tensors of the new added stage.
+   */
+  Array<te::Tensor> ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
+                                    te::Schedule* schedule) const;
+
+  /*!
+   * \brief Print the current step as equivalent python schedule API.
+   * \param stages The `te::Stage`s used in TVM scheduler applying.
+   * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
+   * \param schedule A mutable pointer to a te::Schedule.
+   * \return Python schedule code.
+   */
+  String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
+                          te::Schedule* schedule) const;
+
+  static constexpr const char* record_prefix_str = "RF";
+
+  static constexpr const char* _type_key = "auto_scheduler.RfactorStep";
+  TVM_DECLARE_FINAL_OBJECT_INFO(RfactorStepNode, Object);
+};
+
+/*!
+ * \brief Managed reference to RfactorStepNode.
+ * \sa RfactorStepNode
+ */
+class RfactorStep : public Step {
+ public:
+  /*!
+   * \brief The constructor.
+   * \param stage_id The index of the iterator to be factored.

Review comment:
       ```suggestion
      * \param stage_id The index of the stage to be factored.
   ```

##########
File path: src/auto_scheduler/transform_step.cc
##########
@@ -967,11 +1162,27 @@ String ComputeRootStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
  */
 Array<Step> GetFormerStageModifiableSteps(Step current_step, const Array<Step>& transform_steps) {
   Array<Step> ret_steps;
-  for (const Step& step : transform_steps) {
+  for (size_t i = 0; i < transform_steps.size(); ++i) {
+    const Step& step = transform_steps[i];
     if (step->IsInstance<CacheWriteStepNode>() || step->IsInstance<CacheReadStepNode>()) {
       ret_steps.push_back(step);
+    } else if (step->IsInstance<RfactorStepNode>()) {
+      // add FuseStepNode required by rfactor
+      if (i >= 2 && transform_steps[i - 2]->IsInstance<FuseStepNode>()) {
+        const Step& fuse_step = transform_steps[i - 2];
+        if (fuse_step->stage_id == step->stage_id) {
+          ret_steps.push_back(fuse_step);
+        }
+      }
+      // add SplitStepNode required by rfactor
+      CHECK_GE(i, 1);
+      CHECK(transform_steps[i - 1]->IsInstance<SplitStepNode>());
+      const Step& split_step = transform_steps[i - 1];
+      CHECK_EQ(split_step->stage_id, step->stage_id);
+      ret_steps.push_back(split_step);
+      // add RfactorStepNode
+      ret_steps.push_back(step);

Review comment:
       yes. Otherwise, there will be errors when creating the new dag

##########
File path: src/auto_scheduler/transform_step.cc
##########
@@ -967,11 +1162,27 @@ String ComputeRootStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
  */
 Array<Step> GetFormerStageModifiableSteps(Step current_step, const Array<Step>& transform_steps) {
   Array<Step> ret_steps;
-  for (const Step& step : transform_steps) {
+  for (size_t i = 0; i < transform_steps.size(); ++i) {
+    const Step& step = transform_steps[i];
     if (step->IsInstance<CacheWriteStepNode>() || step->IsInstance<CacheReadStepNode>()) {
       ret_steps.push_back(step);
+    } else if (step->IsInstance<RfactorStepNode>()) {
+      // add FuseStepNode required by rfactor
+      if (i >= 2 && transform_steps[i - 2]->IsInstance<FuseStepNode>()) {
+        const Step& fuse_step = transform_steps[i - 2];
+        if (fuse_step->stage_id == step->stage_id) {
+          ret_steps.push_back(fuse_step);
+        }
+      }
+      // add SplitStepNode required by rfactor
+      CHECK_GE(i, 1);
+      CHECK(transform_steps[i - 1]->IsInstance<SplitStepNode>());
+      const Step& split_step = transform_steps[i - 1];
+      CHECK_EQ(split_step->stage_id, step->stage_id);
+      ret_steps.push_back(split_step);
+      // add RfactorStepNode
+      ret_steps.push_back(step);

Review comment:
       yes. Otherwise, there will be errors when creating the new dag

##########
File path: tests/python/unittest/test_auto_scheduler_measure.py
##########
@@ -71,6 +73,13 @@ def test_record():
     s.compute_at(D_global, E, s[E].iters[2])
     # Cache Write
     s.cache_write(D, "shared")
+    # Pragma

Review comment:
       This test is too long. Split it into smaller functions.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org