Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/07/29 07:58:35 UTC

[GitHub] [incubator-tvm] comaniac commented on a change in pull request #6142: [Ansor][AutoTVM v2.0] Phase 1: Add follow_split and follow_fused_split steps

comaniac commented on a change in pull request #6142:
URL: https://github.com/apache/incubator-tvm/pull/6142#discussion_r461082016



##########
File path: python/tvm/auto_scheduler/loop_state.py
##########
@@ -301,6 +310,89 @@ def split(self, stage, iterator, lengths, inner_to_outer=True):
                                                      iterator, lengths, inner_to_outer)
         return res
 
+    def follow_split(self, stage, iterator, src_step_id, n_split):
+        """ Schedule primitive extends to split step.
+
+        This step is used to follow a former SplitStep, keeps their iterator structures to be same.

Review comment:
       ```suggestion
           This step splits the iterator by the same factors as the given SplitStep.
   ```
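A minimal pure-Python sketch of what "the same factors" means for loop extents. It assumes that a split step's lengths give the extents of the new inner iterators from outer to inner, and that the new outermost iterator keeps the remaining extent (ceil division). `split_extents` is a hypothetical helper, not part of the TVM API.

```python
from math import prod
from typing import List


def split_extents(extent: int, lengths: List[int]) -> List[int]:
    """Extents of the iterators produced by splitting an iterator of `extent`.

    Hypothetical model: `lengths` are the inner factors ordered outer to inner,
    and the new outermost iterator covers the rest (ceil division).
    """
    outer = -(-extent // prod(lengths))
    return [outer] + list(lengths)


# A split on one stage and a step that reuses its factors on another
# iterator of the same extent yield identical loop structures.
src_lengths = [8, 4]
followed = split_extents(512, src_lengths)
following = split_extents(512, list(src_lengths))
```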

##########
File path: python/tvm/auto_scheduler/loop_state.py
##########
@@ -301,6 +310,89 @@ def split(self, stage, iterator, lengths, inner_to_outer=True):
                                                      iterator, lengths, inner_to_outer)
         return res
 
+    def follow_split(self, stage, iterator, src_step_id, n_split):
+        """ Schedule primitive extends to split step.
+
+        This step is used to follow a former SplitStep, keeps their iterator structures to be same.
+
+        Example cases:
+            With subgraph: Dense -> Relu
+            Some tiling structures are used in Relu stage and we intend to compute the Dense
+            stage at Relu.
+            The follow_split is used here to keep their outer most few iterators the same for
+            applying compute at.
+
+        Parameters
+        ----------
+        stage : Union[int, Operation, Tensor]
+            The Stage to be split, which can be specified by the integer index, Operation,
+            or output tensor of the stage.
+        iterator : Iterator
+            The iterator to split.
+        src_step_id : int
+            The index of the split step to follow in the history.
+        n_split : int
+            The number of split level.
+
+        Returns
+        -------
+        res_its : List[Iterator]
+            The splitted new Iterators.
+        """
+
+        self.state_object, res = _ffi_api.StateFollowSplit(self.state_object,
+                                                           self._resolve_stage_id(stage),
+                                                           iterator,
+                                                           src_step_id, n_split)
+        return res
+
+    def follow_fused_split(self, stage, iterator, src_step_ids, level,
+                           factor_or_nparts):
+        """ Schedule primitive extends to split step.
+
+        This step is used to follow several former SplitSteps and FuseSteps.
+
+        Example cases:
+            With subgraph in GPU schedule: Input -> Dense
+            for i.0@j.0 = ... : Bind to blockIdx.x
+                for i.1@j.1 = ... : Bind to threadIdx.x
+                    for i.2@j.2 = ...
+                        Input_shared = Input ...
+                        for k = ...
+                            Dense = ...
+            We intend to apply cooperative fetching with the Input stage, while the threadIdx.x
+            axis is binded to a iterator generated by split & fuse step.
+            The follow_fused_step is used here to figure out the final extent of the threadIdx.x
+            binded iterator.

Review comment:
       ```suggestion
           This step is used to split an iterator by the same factors as the given list of SplitSteps and FuseSteps.
           
           Notes
           ------
               This step is useful in a scenario where we have a subgraph in a GPU schedule: Input -> Dense
               for i.0@j.0 = ... : Bind to blockIdx.x
                   for i.1@j.1 = ... : Bind to threadIdx.x
                       for i.2@j.2 = ...
                           Input_shared = Input ...
                           for k = ...
                               Dense = ...
               We intend to apply cooperative fetching with the input stage, while the threadIdx.x
               axis is bound to an iterator generated by split & fuse step.
               The follow_fused_split step is used here to split the iterator into 2 parts, where the split factor matches
               the final extent of the threadIdx.x bound iterator.
   ```
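The threadIdx.x arithmetic in this scenario can be sketched in plain Python. The factors 16 and 8, the copy-loop extent 2048 and the helper `two_way_split` are hypothetical; `factor_or_nparts=True` is assumed to mean the value is used as the inner factor, and `False` as the number of outer parts.

```python
from typing import List


def two_way_split(extent: int, value: int, factor_or_nparts: bool) -> List[int]:
    """Return the [outer, inner] extents of a two-way split of `extent`.

    Assumed semantics: True uses `value` as the inner factor,
    False uses `value` as the number of outer parts.
    """
    rest = -(-extent // value)  # ceil(extent / value)
    return [rest, value] if factor_or_nparts else [value, rest]


# Hypothetical factors of i.1 and j.1, the loops fused and bound to threadIdx.x.
i1_factor, j1_factor = 16, 8
num_threads = i1_factor * j1_factor  # extent of i.1@j.1

# Split the (hypothetical) fused copy loop of Input_shared so that its inner
# part has exactly num_threads iterations and can also be bound to threadIdx.x.
outer, inner = two_way_split(2048, num_threads, factor_or_nparts=True)
```

In this model, each of the 128 threads copies 16 elements of the shared tile.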

##########
File path: python/tvm/auto_scheduler/loop_state.py
##########
@@ -301,6 +310,89 @@ def split(self, stage, iterator, lengths, inner_to_outer=True):
                                                      iterator, lengths, inner_to_outer)
         return res
 
+    def follow_split(self, stage, iterator, src_step_id, n_split):
+        """ Schedule primitive extends to split step.
+
+        This step is used to follow a former SplitStep, keeps their iterator structures to be same.
+
+        Example cases:
+            With subgraph: Dense -> Relu
+            Some tiling structures are used in Relu stage and we intend to compute the Dense
+            stage at Relu.
+            The follow_split is used here to keep their outer most few iterators the same for
+            applying compute at.

Review comment:
       ```suggestion
           Notes
           ------
               This step is useful in a scenario where we have a subgraph Dense -> ReLU,
               and we want to compute the Dense stage at ReLU. In this case, we need them to have
               the same tiling structure for their common outer loops.
               The follow_split step could be used here to split the Dense stage and make sure its
               splitting factors are the same as those of the given split step for the ReLU stage.
   ```
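A small model of the Dense -> ReLU case, assuming follow_split keeps the first `n_split - 1` factors of the source split and folds the remaining ones into a single last factor, as the `ExtractSplitLengths` code in this PR does. The helpers and the numbers (extent 256, factors `[2, 4, 8]`) are hypothetical.

```python
from math import prod
from typing import List


def follow_lengths(src_lengths: List[int], n_split: int) -> List[int]:
    # Copy the first n_split - 1 factors, then fold the rest into one factor.
    return src_lengths[: n_split - 1] + [prod(src_lengths[n_split - 1:])]


def split_extents(extent: int, lengths: List[int]) -> List[int]:
    # The outermost iterator gets the remaining extent (ceil division).
    return [-(-extent // prod(lengths))] + lengths


def common_outer_loops(a: List[int], b: List[int]) -> int:
    # Number of leading loops whose extents match.
    count = 0
    for x, y in zip(a, b):
        if x != y:
            break
        count += 1
    return count


relu_loops = split_extents(256, [2, 4, 8])                      # ReLU tiling
dense_loops = split_extents(256, follow_lengths([2, 4, 8], 2))  # Dense follows it
shared = common_outer_loops(relu_loops, dense_loops)
```

In this model the two outer loops line up, so the Dense stage could be attached at the second ReLU loop.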

##########
File path: src/auto_scheduler/transform_step.cc
##########
@@ -111,6 +115,10 @@ void StepApplyToState(const Step& step, State* state, const ComputeDAG& dag) {
     ps->ApplyToState(state);
   } else if (auto ps = step.as<SplitStepNode>()) {
     ps->ApplyToState(state);
+  } else if (auto ps = step.as<FollowSplitStepNode>()) {

Review comment:
       Done this way, we end up with a very long if-else chain, since every step type needs its own branch. Per the discussion in the cache_read/write PR, we will refactor this part in the future to eliminate the redundant step dispatching.
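One possible shape for that refactoring, sketched in Python: a table keyed by step type replaces the `else if` chain, so adding a step means registering one handler. This is a hypothetical design for illustration only, not the code TVM adopted; the step classes and handlers are stand-ins.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, List

# Hypothetical registry: step type -> function that applies it to a state.
_APPLY_TO_STATE: Dict[type, Callable[[Any, List], None]] = {}


def register_apply(step_type: type):
    def decorator(fn: Callable[[Any, List], None]):
        _APPLY_TO_STATE[step_type] = fn
        return fn
    return decorator


@dataclass
class SplitStep:
    iter_id: int


@dataclass
class FollowSplitStep:
    iter_id: int
    src_step_id: int


@register_apply(SplitStep)
def _apply_split(step: SplitStep, state: List) -> None:
    state.append(("split", step.iter_id))


@register_apply(FollowSplitStep)
def _apply_follow_split(step: FollowSplitStep, state: List) -> None:
    state.append(("follow_split", step.iter_id, step.src_step_id))


def step_apply_to_state(step: Any, state: List) -> None:
    # A single dictionary lookup replaces the chain of type checks.
    handler = _APPLY_TO_STATE.get(type(step))
    if handler is None:
        raise ValueError(f"Unknown step type: {type(step).__name__}")
    handler(step, state)


history: List = []
step_apply_to_state(SplitStep(iter_id=0), history)
step_apply_to_state(FollowSplitStep(iter_id=1, src_step_id=0), history)
```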

##########
File path: src/auto_scheduler/transform_step.cc
##########
@@ -778,6 +795,184 @@ String SplitStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
 
 /********** Steps working on multiple stages **********/
 
+/********** Follow Split **********/
+FollowSplitStep::FollowSplitStep(int stage_id, int iter_id, int src_step_id, int n_split) {
+  auto node = make_object<FollowSplitStepNode>();
+  node->stage_id = stage_id;
+  node->iter_id = iter_id;
+  node->src_step_id = src_step_id;
+  node->n_split = n_split;
+  data_ = std::move(node);
+}
+
+void FollowSplitStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
+  writer->WriteArraySeperator();
+  writer->WriteString(record_prefix_str);
+  writer->WriteArrayItem(stage_id);

Review comment:
       The order here corresponds to the read order defined in the `dmlc::JSONReader*` constructor of this step.
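The requirement can be illustrated with a Python model of the record: the writer and the reader walk the fields in one fixed order, and the reader unpacks positionally, so a mismatch would silently swap integer fields. The class mirrors the step's fields, but the JSON layout here is an assumption, not TVM's exact log format.

```python
import json
from dataclasses import dataclass


@dataclass
class FollowSplitRecord:
    stage_id: int
    iter_id: int
    src_step_id: int
    n_split: int

    PREFIX = "FSP"  # class attribute, not a dataclass field

    def to_record(self) -> list:
        # Same field order as FollowSplitStepNode::WriteToRecord.
        return [self.PREFIX, self.stage_id, self.iter_id, self.src_step_id, self.n_split]

    @classmethod
    def from_record(cls, record: list) -> "FollowSplitRecord":
        prefix, stage_id, iter_id, src_step_id, n_split = record
        if prefix != cls.PREFIX:
            raise ValueError(f"expected prefix {cls.PREFIX!r}, got {prefix!r}")
        # Read back positionally, in exactly the order the fields were written.
        return cls(stage_id, iter_id, src_step_id, n_split)


step = FollowSplitRecord(stage_id=2, iter_id=0, src_step_id=5, n_split=3)
text = json.dumps(step.to_record())
restored = FollowSplitRecord.from_record(json.loads(text))
```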

##########
File path: src/auto_scheduler/transform_step.cc
##########
@@ -778,6 +795,184 @@ String SplitStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
 
 /********** Steps working on multiple stages **********/
 
+/********** Follow Split **********/
+FollowSplitStep::FollowSplitStep(int stage_id, int iter_id, int src_step_id, int n_split) {
+  auto node = make_object<FollowSplitStepNode>();
+  node->stage_id = stage_id;
+  node->iter_id = iter_id;
+  node->src_step_id = src_step_id;
+  node->n_split = n_split;
+  data_ = std::move(node);
+}
+
+void FollowSplitStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
+  writer->WriteArraySeperator();
+  writer->WriteString(record_prefix_str);
+  writer->WriteArrayItem(stage_id);
+  writer->WriteArrayItem(iter_id);
+  writer->WriteArrayItem(src_step_id);
+  writer->WriteArrayItem(n_split);
+}
+
+void FollowSplitStepNode::ExtractSplitLengths(const Array<Step>& transform_steps,
+                                              Array<Optional<Integer>>* lengths) const {
+  CHECK_LT(src_step_id, transform_steps.size());
+  auto ps = transform_steps[src_step_id].as<SplitStepNode>();
+  CHECK(ps != nullptr);
+
+  // get lengths from src step
+  lengths->reserve(n_split);
+  int j = 0;
+  for (; j < n_split - 1; ++j) {
+    lengths->push_back(ps->lengths[j]);

Review comment:
       1. Need to make sure the size of `ps->lengths` is not smaller than `n_split - 1`.
       2. Add comments for this loop and the following loop.
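A Python port of `ExtractSplitLengths` with the requested bound check and loop comments, as a sketch. `None` stands for an undecided factor (`NullOpt`); the lower bound `n_split >= 1` is an extra assumption added here, not something the PR checks.

```python
from typing import List, Optional


def extract_split_lengths(src_lengths: List[Optional[int]],
                          n_split: int) -> List[Optional[int]]:
    # The followed split has len(src_lengths) factors, so at most
    # len(src_lengths) + 1 levels can be followed; i.e. len(src_lengths)
    # must not be smaller than n_split - 1.
    if n_split - 1 > len(src_lengths):
        raise ValueError("n_split exceeds the levels of the followed split")
    # Extra sanity check assumed for this sketch (not in the PR).
    if n_split < 1:
        raise ValueError("n_split must be at least 1")

    # Copy the first n_split - 1 factors of the followed split unchanged.
    lengths = list(src_lengths[: n_split - 1])

    # Fold the remaining factors into one last factor; an undecided
    # factor (None) makes the folded factor undecided as well.
    last_factor: Optional[int] = 1
    for factor in src_lengths[n_split - 1:]:
        if factor is None:
            last_factor = None
            break
        last_factor *= factor
    lengths.append(last_factor)
    return lengths
```

With `n_split` equal to `len(src_lengths) + 1`, the fold loop runs zero times and the last factor is 1.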

##########
File path: src/auto_scheduler/transform_step.cc
##########
@@ -778,6 +795,184 @@ String SplitStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
 
 /********** Steps working on multiple stages **********/
 
+/********** Follow Split **********/
+FollowSplitStep::FollowSplitStep(int stage_id, int iter_id, int src_step_id, int n_split) {
+  auto node = make_object<FollowSplitStepNode>();
+  node->stage_id = stage_id;
+  node->iter_id = iter_id;
+  node->src_step_id = src_step_id;
+  node->n_split = n_split;
+  data_ = std::move(node);
+}
+
+void FollowSplitStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
+  writer->WriteArraySeperator();
+  writer->WriteString(record_prefix_str);
+  writer->WriteArrayItem(stage_id);
+  writer->WriteArrayItem(iter_id);
+  writer->WriteArrayItem(src_step_id);
+  writer->WriteArrayItem(n_split);
+}
+
+void FollowSplitStepNode::ExtractSplitLengths(const Array<Step>& transform_steps,
+                                              Array<Optional<Integer>>* lengths) const {
+  CHECK_LT(src_step_id, transform_steps.size());
+  auto ps = transform_steps[src_step_id].as<SplitStepNode>();
+  CHECK(ps != nullptr);
+
+  // get lengths from src step
+  lengths->reserve(n_split);
+  int j = 0;
+  for (; j < n_split - 1; ++j) {
+    lengths->push_back(ps->lengths[j]);
+  }
+  PrimExpr last_factor = 1;

Review comment:
       Same here. IMO, the following loop is not required.

##########
File path: include/tvm/auto_scheduler/transform_step.h
##########
@@ -489,6 +492,166 @@ class SplitStep : public Step {
 
 /********** Steps working on multiple stages **********/
 
+/*! \brief Similar to SplitStepNode, but uses split factors from another step
+ * (i.e. Follow another split step) */
+class FollowSplitStepNode : public StepNode {
+ public:
+  /*! \brief The id of the iter to be split. */
+  int iter_id;
+  /*! \brief The index of the split step to follow in the history. */
+  int src_step_id;
+  /*! \brief The number of split level. */
+  int n_split;
+
+  void WriteToRecord(dmlc::JSONWriter* writer) const final;
+
+  /*!
+   * \brief Extract split lengths.
+   * \param transform_steps An array record all transform steps.
+   * \param lengths The multiple split factors. Can be None to be filled by search policy.
+   */
+  void ExtractSplitLengths(const Array<Step>& transform_steps,
+                           Array<Optional<Integer>>* lengths) const;
+
+  /*!
+   * \brief Apply the current step to State.
+   * \param state A mutable pointer to State.
+   */
+  Array<Iterator> ApplyToState(State* state) const;
+
+  /*!
+   * \brief Apply the current step to tvm.schedule.
+   * \param stages A pointer to a `te::Stage` Array.
+   * \param stage_to_axes A pointer to a StageToAxesMap.
+   * \param transform_steps An array record all transform steps.
+   */
+  Array<tir::IterVar> ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
+                                      const Array<Step>& transform_steps) const;
+
+  /*!
+   * \brief Print the current step as equivalent python schedule API.
+   * \param stages A pointer to a `te::Stage` Array.
+   * \param stage_to_axes A pointer to a StageToAxesMap.
+   * \param transform_steps An array record all transform steps.
+   * \return Python schedule code.
+   */
+  String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
+                          const Array<Step>& transform_steps) const;
+
+  static constexpr const char* record_prefix_str = "FSP";
+
+  static constexpr const char* _type_key = "auto_scheduler.FollowSplitStep";
+  TVM_DECLARE_FINAL_OBJECT_INFO(FollowSplitStepNode, Object);
+};
+
+/*!
+ * \brief Managed reference to FollowSplitStepNode.
+ * \sa FollowSplitStepNode
+ */
+class FollowSplitStep : public Step {
+ public:
+  /*!
+   * \brief The constructor.
+   * \param stage_id The index of the stage to be split.
+   * \param iter_id The index of the iterator to be split.
+   * \param src_step_id The index of the split step to follow in the history.
+   * \param n_split The number of split level.
+   */
+  FollowSplitStep(int stage_id, int iter_id, int src_step_id, int n_split);
+
+  /*!
+   * \brief The constructor used to read a step record from JSONReader and create the
+   * corresponding step.
+   * \param reader The input JSONReader.
+   */
+  explicit FollowSplitStep(dmlc::JSONReader* reader);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(FollowSplitStep, Step, FollowSplitStepNode);
+};
+
+/*! \brief Similar to FollowSplitStep, but use split factors from multiple steps.
+ *  \note This can be used for the split in cooperative fetching

Review comment:
       ```suggestion
   /*! \brief Similar to FollowSplitStep, but uses split factors from multiple steps.
    *  \note This can be used for the split in cooperative fetching.
   ```

##########
File path: src/auto_scheduler/transform_step.cc
##########
@@ -778,6 +795,184 @@ String SplitStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
 
 /********** Steps working on multiple stages **********/
 
+/********** Follow Split **********/
+FollowSplitStep::FollowSplitStep(int stage_id, int iter_id, int src_step_id, int n_split) {
+  auto node = make_object<FollowSplitStepNode>();
+  node->stage_id = stage_id;
+  node->iter_id = iter_id;
+  node->src_step_id = src_step_id;
+  node->n_split = n_split;
+  data_ = std::move(node);
+}
+
+void FollowSplitStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
+  writer->WriteArraySeperator();
+  writer->WriteString(record_prefix_str);
+  writer->WriteArrayItem(stage_id);
+  writer->WriteArrayItem(iter_id);
+  writer->WriteArrayItem(src_step_id);
+  writer->WriteArrayItem(n_split);
+}
+
+void FollowSplitStepNode::ExtractSplitLengths(const Array<Step>& transform_steps,
+                                              Array<Optional<Integer>>* lengths) const {
+  CHECK_LT(src_step_id, transform_steps.size());
+  auto ps = transform_steps[src_step_id].as<SplitStepNode>();
+  CHECK(ps != nullptr);

Review comment:
       Add comment.

##########
File path: src/auto_scheduler/transform_step.cc
##########
@@ -778,6 +795,184 @@ String SplitStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
 
 /********** Steps working on multiple stages **********/
 
+/********** Follow Split **********/
+FollowSplitStep::FollowSplitStep(int stage_id, int iter_id, int src_step_id, int n_split) {
+  auto node = make_object<FollowSplitStepNode>();
+  node->stage_id = stage_id;
+  node->iter_id = iter_id;
+  node->src_step_id = src_step_id;
+  node->n_split = n_split;
+  data_ = std::move(node);
+}
+
+void FollowSplitStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
+  writer->WriteArraySeperator();
+  writer->WriteString(record_prefix_str);
+  writer->WriteArrayItem(stage_id);
+  writer->WriteArrayItem(iter_id);
+  writer->WriteArrayItem(src_step_id);
+  writer->WriteArrayItem(n_split);
+}
+
+void FollowSplitStepNode::ExtractSplitLengths(const Array<Step>& transform_steps,
+                                              Array<Optional<Integer>>* lengths) const {
+  CHECK_LT(src_step_id, transform_steps.size());
+  auto ps = transform_steps[src_step_id].as<SplitStepNode>();
+  CHECK(ps != nullptr);
+
+  // get lengths from src step
+  lengths->reserve(n_split);
+  int j = 0;
+  for (; j < n_split - 1; ++j) {
+    lengths->push_back(ps->lengths[j]);
+  }
+  PrimExpr last_factor = 1;
+  for (; j < static_cast<int>(ps->lengths.size()); ++j) {
+    if (ps->lengths[j]) {
+      last_factor *= ps->lengths[j].value();
+    } else {
+      last_factor = PrimExpr();
+      break;
+    }
+  }
+  if (last_factor.defined()) {
+    lengths->push_back(Downcast<Integer>(last_factor));
+  } else {
+    lengths->push_back(NullOpt);
+  }
+}
+
+FollowSplitStep::FollowSplitStep(dmlc::JSONReader* reader) {
+  auto node = make_object<FollowSplitStepNode>();
+  bool s;
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->stage_id);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->iter_id);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->src_step_id);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->n_split);
+
+  data_ = std::move(node);
+}
+
+Array<Iterator> FollowSplitStepNode::ApplyToState(State* state) const {
+  Array<Optional<Integer>> lengths;
+  ExtractSplitLengths((*state)->transform_steps, &lengths);
+  return ApplySplitToState(state, stage_id, iter_id, lengths, true);
+}
+
+Array<IterVar> FollowSplitStepNode::ApplyToSchedule(Array<te::Stage>* stages,
+                                                    StageToAxesMap* stage_to_axes,
+                                                    const Array<Step>& transform_steps) const {
+  Array<Optional<Integer>> lengths;
+  ExtractSplitLengths(transform_steps, &lengths);
+  return ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id, lengths, true);
+}
+
+String FollowSplitStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
+                                             StageToAxesMap* stage_to_axes,
+                                             const Array<Step>& transform_steps) const {
+  Array<Optional<Integer>> lengths;
+  ExtractSplitLengths(transform_steps, &lengths);
+  return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id, lengths, true);
+}
+
+/********** Follow Fused Split **********/
+FollowFusedSplitStep::FollowFusedSplitStep(int stage_id, int iter_id,
+                                           const Array<Integer>& src_step_ids, int level,
+                                           bool factor_or_nparts) {
+  auto node = make_object<FollowFusedSplitStepNode>();
+  node->stage_id = stage_id;
+  node->iter_id = iter_id;
+  node->src_step_ids = src_step_ids;
+  node->level = level;
+  node->factor_or_nparts = factor_or_nparts;
+  data_ = std::move(node);
+}
+
+FollowFusedSplitStep::FollowFusedSplitStep(dmlc::JSONReader* reader) {
+  auto node = make_object<FollowFusedSplitStepNode>();
+  bool s;
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->stage_id);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->iter_id);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  std::vector<int> int_list;
+  reader->Read(&int_list);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->level);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->factor_or_nparts);
+
+  ::tvm::Array<::tvm::Integer> src_step_ids;
+  for (const auto& i : int_list) {
+    src_step_ids.push_back(i);
+  }
+  node->src_step_ids = src_step_ids;
+  data_ = std::move(node);
+}
+
+void FollowFusedSplitStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
+  writer->WriteArraySeperator();
+  writer->WriteString(record_prefix_str);
+  writer->WriteArrayItem(stage_id);
+  writer->WriteArrayItem(iter_id);
+  writer->WriteArrayItem(IntArrayToVector(src_step_ids));
+  writer->WriteArrayItem(level);
+  writer->WriteArrayItem(static_cast<int>(factor_or_nparts));
+}
+
+Optional<Integer> FollowFusedSplitStepNode::ExtractSplitLength(
+    const Array<Step>& transform_steps) const {
+  PrimExpr ret(1);
+
+  for (int src_step_id : src_step_ids) {
+    CHECK_LT(src_step_id, transform_steps.size());
+    auto ps = transform_steps[src_step_id].as<SplitStepNode>();
+    CHECK(ps != nullptr);

Review comment:
       Add comments.
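The validation and accumulation in `ExtractSplitLength` could be commented roughly as in this Python sketch. The diff above is cut off after the `CHECK`s, so the accumulation part (multiplying each source step's factor at `level`, with an undecided factor making the result undecided) is an assumption based on the step's description; `SplitStep` here is a stand-in class.

```python
from dataclasses import dataclass
from typing import List, Optional, Sequence


@dataclass
class SplitStep:
    # Stand-in for SplitStepNode: None marks a factor not yet filled in.
    lengths: List[Optional[int]]


def extract_fused_split_length(transform_steps: Sequence[object],
                               src_step_ids: Sequence[int],
                               level: int) -> Optional[int]:
    extent = 1
    for src_step_id in src_step_ids:
        # Mirrors CHECK_LT: the id must point into the recorded history.
        if not 0 <= src_step_id < len(transform_steps):
            raise IndexError(f"src_step_id {src_step_id} is out of range")
        step = transform_steps[src_step_id]
        # Mirrors CHECK(ps != nullptr): every followed step must be a split.
        if not isinstance(step, SplitStep):
            raise TypeError(f"step {src_step_id} is not a SplitStep")
        # The fused iterator's extent is the product of the factors at `level`.
        factor = step.lengths[level]
        if factor is None:
            return None  # any undecided factor leaves the result undecided
        extent *= factor
    return extent
```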

##########
File path: src/auto_scheduler/transform_step.cc
##########
@@ -778,6 +795,184 @@ String SplitStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
 
 /********** Steps working on multiple stages **********/
 
+/********** Follow Split **********/
+FollowSplitStep::FollowSplitStep(int stage_id, int iter_id, int src_step_id, int n_split) {
+  auto node = make_object<FollowSplitStepNode>();
+  node->stage_id = stage_id;
+  node->iter_id = iter_id;
+  node->src_step_id = src_step_id;
+  node->n_split = n_split;
+  data_ = std::move(node);
+}
+
+void FollowSplitStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
+  writer->WriteArraySeperator();
+  writer->WriteString(record_prefix_str);
+  writer->WriteArrayItem(stage_id);
+  writer->WriteArrayItem(iter_id);
+  writer->WriteArrayItem(src_step_id);
+  writer->WriteArrayItem(n_split);
+}
+
+void FollowSplitStepNode::ExtractSplitLengths(const Array<Step>& transform_steps,
+                                              Array<Optional<Integer>>* lengths) const {
+  CHECK_LT(src_step_id, transform_steps.size());
+  auto ps = transform_steps[src_step_id].as<SplitStepNode>();
+  CHECK(ps != nullptr);
+
+  // get lengths from src step
+  lengths->reserve(n_split);
+  int j = 0;
+  for (; j < n_split - 1; ++j) {
+    lengths->push_back(ps->lengths[j]);
+  }
+  PrimExpr last_factor = 1;
+  for (; j < static_cast<int>(ps->lengths.size()); ++j) {
+    if (ps->lengths[j]) {
+      last_factor *= ps->lengths[j].value();
+    } else {
+      last_factor = PrimExpr();
+      break;
+    }
+  }
+  if (last_factor.defined()) {
+    lengths->push_back(Downcast<Integer>(last_factor));
+  } else {
+    lengths->push_back(NullOpt);
+  }
+}
+
+FollowSplitStep::FollowSplitStep(dmlc::JSONReader* reader) {
+  auto node = make_object<FollowSplitStepNode>();
+  bool s;
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->stage_id);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->iter_id);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->src_step_id);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->n_split);
+
+  data_ = std::move(node);
+}
+
+Array<Iterator> FollowSplitStepNode::ApplyToState(State* state) const {
+  Array<Optional<Integer>> lengths;
+  ExtractSplitLengths((*state)->transform_steps, &lengths);
+  return ApplySplitToState(state, stage_id, iter_id, lengths, true);
+}
+
+Array<IterVar> FollowSplitStepNode::ApplyToSchedule(Array<te::Stage>* stages,
+                                                    StageToAxesMap* stage_to_axes,
+                                                    const Array<Step>& transform_steps) const {
+  Array<Optional<Integer>> lengths;
+  ExtractSplitLengths(transform_steps, &lengths);
+  return ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id, lengths, true);
+}
+
+String FollowSplitStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
+                                             StageToAxesMap* stage_to_axes,
+                                             const Array<Step>& transform_steps) const {
+  Array<Optional<Integer>> lengths;
+  ExtractSplitLengths(transform_steps, &lengths);
+  return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id, lengths, true);
+}
+
+/********** Follow Fused Split **********/
+FollowFusedSplitStep::FollowFusedSplitStep(int stage_id, int iter_id,
+                                           const Array<Integer>& src_step_ids, int level,
+                                           bool factor_or_nparts) {
+  auto node = make_object<FollowFusedSplitStepNode>();
+  node->stage_id = stage_id;
+  node->iter_id = iter_id;
+  node->src_step_ids = src_step_ids;
+  node->level = level;
+  node->factor_or_nparts = factor_or_nparts;
+  data_ = std::move(node);
+}
+
+FollowFusedSplitStep::FollowFusedSplitStep(dmlc::JSONReader* reader) {
+  auto node = make_object<FollowFusedSplitStepNode>();
+  bool s;
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->stage_id);

Review comment:
       This is the decision we made in the last PR. Since the (de)serialization logic of each step differs a lot, keeping it self-contained in each step is easier to manage. This way, we can easily trace the logic simply by looking at the step definition class instead of searching around the code base.
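A Python sketch of that self-contained design: each step class owns both its writer and its reader, and a shared table only maps the record prefix to the class. The registry, the `"FFSP"` prefix and the record layout are assumptions for illustration; the bool is stored as 0/1 in the spirit of the `static_cast<int>(factor_or_nparts)` write in the diff.

```python
import json
from dataclasses import dataclass
from typing import Dict, List

# Hypothetical prefix -> step class table used when reading a record back.
_STEP_READERS: Dict[str, type] = {}


def register_step(cls: type) -> type:
    _STEP_READERS[cls.PREFIX] = cls
    return cls


@register_step
@dataclass
class FollowFusedSplitStep:
    stage_id: int
    iter_id: int
    src_step_ids: List[int]
    level: int
    factor_or_nparts: bool

    PREFIX = "FFSP"  # assumed record prefix; class attribute only

    def to_record(self) -> list:
        # Writer and reader live side by side, so their field order stays in sync.
        return [self.PREFIX, self.stage_id, self.iter_id, list(self.src_step_ids),
                self.level, int(self.factor_or_nparts)]  # bool stored as 0/1

    @classmethod
    def from_record(cls, fields: list) -> "FollowFusedSplitStep":
        stage_id, iter_id, src_step_ids, level, factor_or_nparts = fields
        return cls(stage_id, iter_id, list(src_step_ids), level, bool(factor_or_nparts))


def read_step(record: list):
    # The prefix only selects the class; each class parses its own fields.
    prefix, *fields = record
    return _STEP_READERS[prefix].from_record(fields)
```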




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org