You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2019/01/08 03:38:33 UTC

[GitHub] HyperZealot commented on a change in pull request #13687: split_v2 operator

HyperZealot commented on a change in pull request #13687: split_v2 operator
URL: https://github.com/apache/incubator-mxnet/pull/13687#discussion_r245869554
 
 

 ##########
 File path: src/operator/tensor/matrix_op-inl.h
 ##########
 @@ -2520,6 +2520,323 @@ void SpaceToDepthOpForward(const nnvm::NodeAttrs& attrs,
   });
 }
 
+namespace split_enum {
+// Positions of the split operator's inputs; `kData` is the only one.
+enum SplitOpInputs { kData = 0 };
+}  // namespace split_enum
+
+/*!
+ * \brief Parameters of the split operator.
+ */
+struct SplitParam : public dmlc::Parameter<SplitParam> {
+  /*! \brief split boundaries along `axis` */
+  TShape indices;
+  /*! \brief axis along which to split */
+  int axis;
+  /*! \brief whether the split axis is removed from the output shapes */
+  bool squeeze_axis;
+  /*! \brief number of sections; 0 selects splitting by `indices` */
+  int sections;
+  DMLC_DECLARE_PARAMETER(SplitParam) {
+    DMLC_DECLARE_FIELD(indices)
+    .describe("Indices of splits. The elements should denote the boundaries of at which split"
+              " is performed along the `axis`.");
+    DMLC_DECLARE_FIELD(axis)
+    .set_default(1)
+    .describe("Axis along which to split.");
+    DMLC_DECLARE_FIELD(squeeze_axis)
+    .set_default(false)
+    .describe("If true, Removes the axis with length 1 from the shapes of the output arrays."
+              " **Note** that setting `squeeze_axis` to ``true`` removes axis with length 1"
+              " only along the `axis` which it is split."
+              " Also `squeeze_axis` can be set to ``true``"
+              " only if ``input.shape[axis] == num_outputs``.");
+    DMLC_DECLARE_FIELD(sections)
+    .set_default(0)
+    .describe("Number of sections if equally splitted. Default to 0 which means split by indices.");
+  }
+};  // struct SplitParam
+
+/*!
+ * \brief Compute the boundaries of an even split of one axis.
+ * \param ishape   shape of the array being split
+ * \param axis     axis to split; used directly as an index into `ishape`
+ * \param sections number of sections, must be positive
+ * \return tuple with `sections + 1` entries: entry 0 is 0 and entry k is
+ *         k * (ishape[axis] / sections)
+ * \note The quotient is an integer division, so when the axis length is not a multiple
+ *       of `sections` the final entry is smaller than `ishape[axis]`.
+ */
+inline TShape GetSplitIndices(const TShape& ishape, int axis, int sections) {
+  // `sections` is the divisor below and its SplitParam default is 0, so fail loudly
+  // instead of dividing by zero.
+  CHECK_GT(sections, 0) << "sections must be positive, but got " << sections;
+  TShape indices(sections + 1);
+  indices[0] = 0;
+  const int64_t section_size = ishape[axis] / sections;
+  for (int i = 1; i <= sections; ++i) {
+    indices[i] = section_size * i;
+  }
+  return indices;
+}
+
+/*!
+ * \brief Type inference for the split operator: every output takes the input's dtype.
+ * \param attrs     node attributes carrying the parsed SplitParam
+ * \param in_attrs  input dtypes; must hold exactly one entry, and it must not be -1
+ * \param out_attrs output dtypes; emptied and refilled with one entry per output
+ * \return always true
+ */
+inline bool SplitOpType(const nnvm::NodeAttrs& attrs,
+                        std::vector<int>* in_attrs,
+                        std::vector<int>* out_attrs) {
+  CHECK_EQ(in_attrs->size(), 1U);
+  const int in_type = in_attrs->front();
+  CHECK_NE(in_type, -1) << "First input must have specified type";
+  const SplitParam& param = nnvm::get<SplitParam>(attrs.parsed);
+  // A positive `sections` fixes the output count; otherwise there is one output per index.
+  int output_count = param.indices.ndim();
+  if (param.sections > 0) {
+    output_count = param.sections;
+  }
+  out_attrs->assign(static_cast<size_t>(output_count), in_type);
+  return true;
+}
+
+/*!
+ * \brief Shape inference for the split operator.
+ *
+ * Derives every output shape from the input shape and the split boundaries, then
+ * rebuilds the input shape from the outputs and assigns it back to the input.
+ *
+ * \param attrs     node attributes carrying the parsed SplitParam
+ * \param in_attrs  input shapes; must hold exactly one entry
+ * \param out_attrs output shapes, one per split result
+ * \return false when the input shape has no dimensions, true otherwise
+ */
+inline bool SplitOpShape(const nnvm::NodeAttrs& attrs,
+                         std::vector<TShape>* in_attrs,
+                         std::vector<TShape>* out_attrs) {
+  using namespace mshadow;
+  const SplitParam& param = nnvm::get<SplitParam>(attrs.parsed);
+  CHECK_EQ(in_attrs->size(), 1U);
+  TShape dshape = in_attrs->at(split_enum::kData);
+  TShape ishape = in_attrs->at(split_enum::kData);
+  if (dshape.ndim() == 0) return false;
+  // Normalize a negative axis once, in signed arithmetic, then validate the result.
+  const int ndim = static_cast<int>(dshape.ndim());
+  const int real_axis = (param.axis < 0) ? param.axis + ndim : param.axis;
+  CHECK(real_axis >= 0 && real_axis < ndim)
+    << "axis " << param.axis << " is out of bounds for input with " << ndim << " dimensions";
+  const TShape indices =
+    (param.sections > 0) ? GetSplitIndices(ishape, real_axis, param.sections) : param.indices;
+  int num_outputs = (param.sections > 0) ? indices.ndim() - 1 : indices.ndim();
+  // Output 0 is read unconditionally when the input shape is rebuilt below.
+  CHECK_GT(num_outputs, 0) << "split must produce at least one output";
+  // Pre-compute squeezed output shape for future usage
+  TShape squeezed_dshape = dshape;
+  for (int d = real_axis; d < ndim - 1; ++d) {
+    squeezed_dshape[d] = squeezed_dshape[d+1];
+  }
+  squeezed_dshape = TShape(&squeezed_dshape[0], &squeezed_dshape[squeezed_dshape.ndim()-1]);
+  // Assign shape to every output
+  for (int i = 0; i < num_outputs; ++i) {
+    // The last section always extends to the end of the axis.
+    int start = indices[i];
+    int end = (i < num_outputs - 1) ? indices[i + 1] : ishape[real_axis];
+    CHECK(start < end)
+      << "start " << start << " is not less than end " << end << " for subarray " << i;
+    CHECK(end <= ishape[real_axis])
+      << "end " << end << " is greater than the size of the axis " << ishape[real_axis];
+    dshape[real_axis] = (end - start);
+    if (param.squeeze_axis) {
+      CHECK_EQ(end - start, 1) << "expected axis size of 1 but got " << end - start;
+      SHAPE_ASSIGN_CHECK(*out_attrs, i, squeezed_dshape);
+    } else {
+      SHAPE_ASSIGN_CHECK(*out_attrs, i, dshape);
+    }
+  }
+  // Rebuild the input shape from the outputs: leading and trailing dimensions come from
+  // output 0, the split axis is the output count (squeezed) or the sum of section lengths.
+  TShape back_calculate_dshape = ishape;
+  back_calculate_dshape[real_axis] = 0;
+  for (int d = 0; d < real_axis; ++d) {
+    back_calculate_dshape[d] = (*out_attrs)[0][d];
+  }
+  if (param.squeeze_axis) {
+    back_calculate_dshape[real_axis] = num_outputs;
+  } else {
+    for (int i = 0; i < num_outputs; ++i) {
+      back_calculate_dshape[real_axis] += (*out_attrs)[i][real_axis];
+    }
+  }
+  for (int d = real_axis + 1; d < ndim; ++d) {
+    if (param.squeeze_axis) {
+      // Squeezed outputs lack the split axis, so later dimensions sit one slot earlier.
+      back_calculate_dshape[d] = (*out_attrs)[0][d - 1];
+    } else {
+      back_calculate_dshape[d] = (*out_attrs)[0][d];
+    }
+  }
+  SHAPE_ASSIGN_CHECK(*in_attrs, split_enum::kData, back_calculate_dshape);
+  return true;
+}
+
+struct SplitKernel {
+  /*!
+   * \brief Copy one input element into the output section it falls in
+   * \param i              flat index of the element within the input buffer
+   * \param in_data        ptr to input buffer
+   * \param out_data       array of ptrs, one per output buffer
+   * \param indices        section start offsets along the split axis; entry s+1 is read
+   *                       as the end of section s
+   * \param num_sections   # of sections after split
+   * \param axis_size      size of the axis being split
+   * \param trailing_size  number of elements spanned by one step along the split axis
+   */
+  template<typename DType>
+  static MSHADOW_XINLINE void Map(size_t i,
+                                  const DType *in_data, DType** out_data, const size_t* indices,
+                                  const size_t num_sections, const size_t axis_size,
+                                  const size_t trailing_size) {
+    // Position of this element along the split axis.
+    const size_t axis_pos = i / trailing_size % axis_size;
+    // Advance until a section start exceeds axis_pos; `target` is the section before it.
+    size_t target = 0;
+    size_t section = 0;
+    while (section < num_sections && indices[section] <= axis_pos) {
+      target = section;
+      ++section;
+    }
+    const size_t outer = i / (trailing_size * axis_size);
+    const size_t inner = i % trailing_size;
+    const size_t section_begin = indices[target];
+    const size_t section_len = indices[target + 1] - section_begin;
+    // Same coordinates, re-linearized with the section's own extent along the axis.
+    const size_t dst =
+      (outer * section_len + (axis_pos - section_begin)) * trailing_size + inner;
+    out_data[target][dst] = in_data[i];
+  }
+};
+
+struct ConcatenateKernel {
+  /*!
+   * \brief Map function for split operator indices option
 
 Review comment:
   Done.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services