You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2019/01/23 21:09:51 UTC

[incubator-mxnet] branch master updated: split_v2 (#13687)

This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 45d1a1e  split_v2 (#13687)
45d1a1e is described below

commit 45d1a1e6ff9c58cfc75f72651c1bf671ac7f1885
Author: HyperZealot <40...@users.noreply.github.com>
AuthorDate: Wed Jan 23 13:09:33 2019 -0800

    split_v2 (#13687)
---
 python/mxnet/ndarray/ndarray.py        |  61 ++++++-
 python/mxnet/symbol/symbol.py          |  55 +++++-
 src/operator/tensor/matrix_op-inl.h    | 313 +++++++++++++++++++++++++++++++++
 src/operator/tensor/matrix_op.cc       | 103 ++++++++++-
 src/operator/tensor/matrix_op.cu       |   6 +
 tests/python/unittest/test_ndarray.py  |   4 +-
 tests/python/unittest/test_operator.py |  19 ++
 7 files changed, 556 insertions(+), 5 deletions(-)

diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py
index 9a62620..fb329f1 100644
--- a/python/mxnet/ndarray/ndarray.py
+++ b/python/mxnet/ndarray/ndarray.py
@@ -47,7 +47,7 @@ __all__ = ["NDArray", "concatenate", "_DTYPE_NP_TO_MX", "_DTYPE_MX_TO_NP", "_GRA
            "imdecode", "lesser", "lesser_equal", "logical_and", "logical_or", "logical_xor",
            "maximum", "minimum", "moveaxis", "modulo", "multiply", "not_equal", "onehot_encode",
            "power", "subtract", "true_divide", "waitall", "_new_empty_handle", "histogram",
-           "to_dlpack_for_read", "to_dlpack_for_write", "from_dlpack"]
+           "split_v2", "to_dlpack_for_read", "to_dlpack_for_write", "from_dlpack"]
 
 _STORAGE_TYPE_UNDEFINED = -1
 _STORAGE_TYPE_DEFAULT = 0
@@ -1133,6 +1133,14 @@ fixed-size items.
         """
         return op.split(self, *args, **kwargs)
 
+    def split_v2(self, *args, **kwargs):
+        """Convenience fluent method for :py:func:`split_v2`.
+
+        The arguments are the same as for :py:func:`split_v2`, with
+        this array as data.
+        """
+        return split_v2(self, *args, **kwargs)
+
     def slice(self, *args, **kwargs):
         """Convenience fluent method for :py:func:`slice`.
 
@@ -3901,6 +3909,12 @@ def histogram(a, bins=10, range=None):
         Values outside the range are ignored. The first element of the range must be less than or
         equal to the second. range affects the automatic bin computation as well, the range will
         be equally divided by the number of bins.
+
+    Returns
+    -------
+    NDArray
+        A created array.
+
     """
 
     # pylint: disable= no-member, protected-access
@@ -3916,6 +3930,51 @@ def histogram(a, bins=10, range=None):
     raise ValueError("bins argument should be either an integer or an NDArray")
     # pylint: enable= no-member, protected-access, redefined-builtin
 
+def split_v2(ary, indices_or_sections, axis=0, squeeze_axis=False):
+    """Split an array into multiple sub-arrays.
+
+    Parameters
+    ----------
+    ary : NDArray
+        Array to be divided into sub-arrays.
+    indices_or_sections : int or tuple of ints
+        If `indices_or_sections` is an integer, N, the array will be divided
+        into N equal arrays along `axis`.  If such a split is not possible,
+        an error is raised.
+        If `indices_or_sections` is a tuple of sorted integers, the entries
+        indicate where along `axis` the array is split.  For example,
+        ``(2, 3)`` would, for ``axis=0``, result in
+        - ary[:2]
+        - ary[2:3]
+        - ary[3:]
+        If an index exceeds the dimension of the array along `axis`,
+        an error is raised.
+    axis : int, optional
+        The axis along which to split, default is 0.
+    squeeze_axis : bool, optional
+        Whether to squeeze the axis of sub-arrays or not, only useful when the
+        size of the sub-arrays is 1 on the `axis`. Default is False.
+
+    Returns
+    -------
+    NDArray or list of NDArrays
+        The sub-array(s) produced by the split.
+
+    """
+    indices = []
+    axis_size = ary.shape[axis]
+    if isinstance(indices_or_sections, int):
+        sections = indices_or_sections
+        if axis_size % sections:
+            raise ValueError('array split does not result in an equal division')
+        section_size = int(axis_size / sections)
+        indices = [i * section_size for i in range(sections)]
+    elif isinstance(indices_or_sections, tuple):
+        indices = [0] + list(indices_or_sections)
+    else:
+        raise ValueError('indices_or_sections must either int or tuple of ints')
+    return _internal._split_v2(ary, indices, axis, squeeze_axis)
+
 PyCapsuleDestructor = ctypes.CFUNCTYPE(None, ctypes.c_void_p)
 _c_str_dltensor = c_str('dltensor')
 _c_str_used_dltensor = c_str('used_dltensor')
diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py
index 530d727..43de0c9 100644
--- a/python/mxnet/symbol/symbol.py
+++ b/python/mxnet/symbol/symbol.py
@@ -48,7 +48,7 @@ from ._internal import SymbolBase, _set_symbol_class
 
 __all__ = ["Symbol", "var", "Variable", "Group", "load", "load_json",
            "pow", "maximum", "minimum", "hypot", "eye", "zeros", "ones", "full", "arange",
-           "histogram"]
+           "histogram", "split_v2"]
 
 
 class Symbol(SymbolBase):
@@ -1855,6 +1855,14 @@ class Symbol(SymbolBase):
         """
         return op.split(self, *args, **kwargs)
 
+    def split_v2(self, *args, **kwargs):
+        """Convenience fluent method for :py:func:`split_v2`.
+
+        The arguments are the same as for :py:func:`split_v2`, with
+        this array as data.
+        """
+        return split_v2(self, *args, **kwargs)
+
     def slice(self, *args, **kwargs):
         """Convenience fluent method for :py:func:`slice`.
 
@@ -2958,6 +2966,11 @@ def histogram(a, bins=10, range=None, **kwargs):
         Values outside the range are ignored. The first element of the range must be less than or
         equal to the second. range affects the automatic bin computation as well, the range will
         be equally divided by the number of bins.
+
+    Returns
+    -------
+    out : Symbol
+        The created Symbol
     """
     if isinstance(bins, Symbol):
         return _internal._histogram(data=a, bins=bins, **kwargs)
@@ -2967,4 +2980,44 @@ def histogram(a, bins=10, range=None, **kwargs):
         return _internal._histogram(data=a, bin_cnt=bins, range=range, **kwargs)
     raise ValueError("bins argument should be either an integer or an NDArray")
 
+def split_v2(ary, indices_or_sections, axis=0, squeeze_axis=False):
+    """Split an array into multiple sub-arrays.
+
+    Parameters
+    ----------
+    ary : Symbol
+        Array to be divided into sub-arrays.
+    indices_or_sections : int or tuple of ints
+        If `indices_or_sections` is an integer, N, the array will be divided
+        into N equal arrays along `axis`.  If such a split is not possible,
+        an error is raised.
+        If `indices_or_sections` is a tuple of sorted integers, the entries
+        indicate where along `axis` the array is split.  For example,
+        ``(2, 3)`` would, for ``axis=0``, result in
+        - ary[:2]
+        - ary[2:3]
+        - ary[3:]
+        If an index exceeds the dimension of the array along `axis`,
+        an error is raised.
+    axis : int, optional
+        The axis along which to split, default is 0.
+    squeeze_axis : bool, optional
+        Whether to squeeze the axis of sub-arrays or not, only useful when the
+        size of the sub-arrays is 1 on the `axis`. Default is False.
+
+    Returns
+    -------
+    out : Symbol
+        The created Symbol
+    """
+    indices = []
+    sections = 0
+    if isinstance(indices_or_sections, int):
+        sections = indices_or_sections
+    elif isinstance(indices_or_sections, tuple):
+        indices = [0] + list(indices_or_sections)
+    else:
+        raise ValueError('indices_or_sections must either int or tuple of ints')
+    return _internal._split_v2(ary, indices, axis, squeeze_axis, sections)
+
 _set_symbol_class(Symbol)
diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h
index 8b575ca..97c4fa5 100644
--- a/src/operator/tensor/matrix_op-inl.h
+++ b/src/operator/tensor/matrix_op-inl.h
@@ -2527,6 +2527,319 @@ void SpaceToDepthOpForward(const nnvm::NodeAttrs& attrs,
   });
 }
 
+namespace split_enum {
+enum SplitOpInputs {kData};
+}  // namespace split_enum
+
+struct SplitParam : public dmlc::Parameter<SplitParam> {
+  TShape indices;
+  int axis;
+  bool squeeze_axis;
+  int sections;
+  DMLC_DECLARE_PARAMETER(SplitParam) {
+    DMLC_DECLARE_FIELD(indices)
+    .describe("Indices of splits. The elements should denote the boundaries of at which split"
+              " is performed along the `axis`.");
+    DMLC_DECLARE_FIELD(axis).set_default(1)
+    .describe("Axis along which to split.");
+    DMLC_DECLARE_FIELD(squeeze_axis).set_default(0)
+    .describe("If true, Removes the axis with length 1 from the shapes of the output arrays."
+              " **Note** that setting `squeeze_axis` to ``true`` removes axis with length 1"
+              " only along the `axis` which it is split."
+              " Also `squeeze_axis` can be set to ``true``"
+              " only if ``input.shape[axis] == num_outputs``.");
+    DMLC_DECLARE_FIELD(sections).set_default(0)
+    .describe("Number of sections if equally splitted. Default to 0 which means split by indices.");
+  }
+};  // struct SplitParam
+
+inline TShape GetSplitIndices(const TShape& ishape, int axis, int sections) {
+  TShape indices(sections+1);
+  indices[0] = 0;
+  int64_t section_size = ishape[axis] / sections;
+  for (int i = 0; i < sections; ++i) {
+    indices[i+1] = section_size * (i + 1);
+  }
+  return indices;
+}
+
+inline bool SplitOpType(const nnvm::NodeAttrs& attrs,
+                        std::vector<int>* in_attrs,
+                        std::vector<int>* out_attrs) {
+  CHECK_EQ(in_attrs->size(), 1U);
+  int dtype = (*in_attrs)[0];
+  CHECK_NE(dtype, -1) << "First input must have specified type";
+  const SplitParam& param = nnvm::get<SplitParam>(attrs.parsed);
+  out_attrs->clear();
+  int num_outputs = (param.sections > 0) ? param.sections : param.indices.ndim();
+  for (int i = 0; i < num_outputs; ++i) {
+    out_attrs->push_back(dtype);
+  }
+  return true;
+}
+
+inline bool SplitOpShape(const nnvm::NodeAttrs& attrs,
+                         std::vector<TShape>* in_attrs,
+                         std::vector<TShape>* out_attrs) {
+  using namespace mshadow;
+  const SplitParam& param = nnvm::get<SplitParam>(attrs.parsed);
+  CHECK_EQ(in_attrs->size(), 1U);
+  TShape dshape = in_attrs->at(split_enum::kData);
+  TShape ishape = in_attrs->at(split_enum::kData);
+  if (dshape.ndim() == 0) return false;
+  if (param.axis >= 0) {
+    CHECK_LT(static_cast<size_t>(param.axis), dshape.ndim());
+  } else {
+    CHECK_LT(param.axis + dshape.ndim(), dshape.ndim());
+  }
+  int real_axis = param.axis;
+  if (real_axis < 0) {
+    real_axis += dshape.ndim();
+  }
+  const TShape indices =
+    (param.sections > 0) ? GetSplitIndices(ishape, real_axis, param.sections) : param.indices;
+  int num_outputs = (param.sections > 0) ? indices.ndim() - 1 : indices.ndim();
+  // Pre-compute squeezed output shape for future usage
+  TShape squeezed_dshape = dshape;
+  for (int d = real_axis; d < static_cast<int>(squeezed_dshape.ndim()) - 1; ++d) {
+    squeezed_dshape[d] = squeezed_dshape[d+1];
+  }
+  squeezed_dshape = TShape(&squeezed_dshape[0], &squeezed_dshape[squeezed_dshape.ndim()-1]);
+  // Assign shape to every output
+  for (int i = 0; i < num_outputs; ++i) {
+    int start = indices[i];
+    int end = (i < num_outputs - 1) ? indices[i + 1] : ishape[real_axis];
+    CHECK(start < end)
+      << "start " << start << " is not less than end " << end << "for subarray " << i;
+    CHECK(end <= ishape[real_axis])
+      << "end " << end << " is no less than the size of the axis " << ishape[real_axis];
+    dshape[real_axis] = (end - start);
+    if (param.squeeze_axis) {
+      CHECK_EQ(end - start, 1U) << "expected axis size of 1 but got " << end - start;
+      SHAPE_ASSIGN_CHECK(*out_attrs, i, squeezed_dshape);
+    } else {
+      SHAPE_ASSIGN_CHECK(*out_attrs, i, dshape);
+    }
+  }
+  TShape back_calculate_dshape = ishape;
+  back_calculate_dshape[real_axis] = 0;
+  for (int d = 0; d < real_axis; ++d) {
+    back_calculate_dshape[d] = (*out_attrs)[0][d];
+  }
+  if (param.squeeze_axis) {
+    back_calculate_dshape[real_axis] = num_outputs;
+  } else {
+    for (int i = 0; i < num_outputs; ++i) {
+      back_calculate_dshape[real_axis] += (*out_attrs)[i][real_axis];
+    }
+  }
+  for (int d = real_axis + 1; d < static_cast<int>(ishape.ndim()); ++d) {
+    if (param.squeeze_axis) {
+      back_calculate_dshape[d] = (*out_attrs)[0][d - 1];
+    } else {
+      back_calculate_dshape[d] = (*out_attrs)[0][d];
+    }
+  }
+  SHAPE_ASSIGN_CHECK(*in_attrs, split_enum::kData, back_calculate_dshape);
+  return true;
+}
+
+struct SplitKernel {
+  /*!
+   * \brief Map function for forward split_v2 operator
+   * \param i              global thread id
+   * \param in_data        ptr to input buffer
+   * \param out_data       ptr to ptr of outputs buffer
+   * \param indices        ptr to indices buffer
+   * \param num_sections   # of sections after split
+   * \param axis_size      size of the axis to be split on
+   * \param trailing_size  step size within the data buffer of the axis to be split on
+   */
+  template<typename DType>
+  static MSHADOW_XINLINE void Map(size_t i,
+                                  const DType *in_data, DType** out_data, const size_t* indices,
+                                  const size_t num_sections, const size_t axis_size,
+                                  const size_t trailing_size) {
+    size_t idx = i / trailing_size % axis_size;
+    size_t target = 0;
+    for (size_t section = 0;
+         section < num_sections && indices[section] <= idx;
+         target = section++) {}
+    DType* target_data = out_data[target];
+    const size_t mid_idx = idx - indices[target];
+    const size_t head_idx = i / (trailing_size * axis_size);
+    const size_t tail_idx = i % trailing_size;
+    const size_t section_size = indices[target + 1] - indices[target];
+    const size_t target_idx =
+      head_idx * trailing_size * section_size + mid_idx * trailing_size + tail_idx;
+    target_data[target_idx] = in_data[i];
+  }
+};
+
+struct ConcatenateKernel {
+  /*!
+   * \brief Map function for backward split_v2 operator
+   * \param i              global thread id
+   * \param out_grad       ptr to ptr of out grads buffer
+   * \param in_grad        ptr to input grad buffer
+   * \param indices        ptr to indices buffer
+   * \param num_sections   # of sections after split
+   * \param axis_size      size of the axis to be split on
+   * \param trailing_size  step size within the data buffer of the axis to be split on
+   */
+  template<typename DType>
+  static MSHADOW_XINLINE void Map(size_t i,
+                                  DType** out_grad, DType* in_grad, const size_t* indices,
+                                  const size_t num_sections, const size_t axis_size,
+                                  const size_t trailing_size) {
+    size_t idx = i / trailing_size % axis_size;
+    size_t src = 0;
+    for (size_t section = 0;
+         section < num_sections && indices[section] <= idx;
+         src = section++) {}
+    DType* src_grad = out_grad[src];
+    const size_t mid_idx = idx - indices[src];
+    const size_t head_idx = i / (trailing_size * axis_size);
+    const size_t tail_idx = i % trailing_size;
+    const size_t section_size = indices[src + 1] - indices[src];
+    const size_t src_idx =
+      head_idx * trailing_size * section_size + mid_idx * trailing_size + tail_idx;
+    in_grad[i] = src_grad[src_idx];
+  }
+};
+
+template<typename xpu>
+inline void SplitOpForward(const nnvm::NodeAttrs& attrs,
+                           const OpContext& ctx,
+                           const std::vector<TBlob>& inputs,
+                           const std::vector<OpReqType>& req,
+                           const std::vector<TBlob>& outputs) {
+  using namespace mshadow;
+  using namespace mshadow::expr;
+  using namespace mxnet_op;
+  const SplitParam& param = nnvm::get<SplitParam>(attrs.parsed);
+  CHECK_EQ(inputs.size(), 1U);
+  CHECK_EQ(outputs.size(), (param.sections > 0) ? param.sections : param.indices.ndim());
+  Stream<xpu> *s = ctx.get_stream<xpu>();
+  const TBlob& input_data = inputs[split_enum::kData];
+  size_t leading = 1, trailing = 1;
+  int real_axis = param.axis;
+  if (real_axis < 0) {
+    real_axis += input_data.ndim();
+  }
+  CHECK_LT(real_axis, input_data.ndim());
+  size_t mid = input_data.shape_[real_axis];
+  for (int i = 0; i < real_axis; ++i) {
+    leading *= input_data.shape_[i];
+  }
+  for (int i = real_axis + 1; i < input_data.ndim(); ++i) {
+    trailing *= input_data.shape_[i];
+  }
+
+  size_t workspace_size = 0;
+  const TShape& ishape = input_data.shape_;
+  const TShape split_pts =
+    (param.sections > 0) ? GetSplitIndices(ishape, real_axis, param.sections) : param.indices;
+  std::vector<size_t> indices;
+  for (const auto& section : split_pts) {
+    indices.push_back(section);
+  }
+  if (param.sections == 0) {
+    indices.push_back(ishape[real_axis]);
+  }
+  workspace_size += indices.size() * sizeof(size_t);
+  MSHADOW_TYPE_SWITCH(input_data.type_flag_, DType, {
+    std::vector<DType*> output_data;
+    for (const TBlob& data : outputs) {
+      output_data.push_back(data.dptr<DType>());
+    }
+    workspace_size += output_data.size() * sizeof(DType*);
+    Tensor<xpu, 1, char> workspace =
+      ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);
+    Tensor<cpu, 1, size_t> indices_cpu_tensor(indices.data(), Shape1(indices.size()));
+    Tensor<xpu, 1, size_t> indices_xpu_tensor(
+      reinterpret_cast<size_t*>(workspace.dptr_), Shape1(indices.size()));
+    Tensor<cpu, 1, DType*> ptrs_cpu_tensor(output_data.data(), Shape1(output_data.size()));
+    Tensor<xpu, 1, DType*> ptrs_xpu_tensor(
+      reinterpret_cast<DType**>(workspace.dptr_ + indices.size() * sizeof(size_t)),
+      Shape1(output_data.size()));
+    mshadow::Copy(indices_xpu_tensor, indices_cpu_tensor, s);
+    mshadow::Copy(ptrs_xpu_tensor, ptrs_cpu_tensor, s);
+    Kernel<SplitKernel, xpu>::Launch(
+      s, input_data.Size(), input_data.dptr<DType>(), ptrs_xpu_tensor.dptr_,
+      indices_xpu_tensor.dptr_, indices.size() - 1, mid, trailing);
+  });
+}
+
+template<typename xpu>
+inline void SplitOpBackward(const nnvm::NodeAttrs& attrs,
+                            const OpContext& ctx,
+                            const std::vector<TBlob>& inputs,
+                            const std::vector<OpReqType>& req,
+                            const std::vector<TBlob>& outputs) {
+  using namespace mshadow;
+  using namespace mshadow::expr;
+  using namespace mxnet_op;
+  const SplitParam& param = nnvm::get<SplitParam>(attrs.parsed);
+  CHECK_EQ(inputs.size(), (param.sections > 0) ? param.sections : param.indices.ndim())
+    << "out grad vector size mush match the output size";
+  CHECK_EQ(outputs.size(), 1U);
+  Stream<xpu> *s = ctx.get_stream<xpu>();
+  TBlob input_grad = outputs[split_enum::kData];
+  size_t leading = 1, trailing = 1;
+  int real_axis = param.axis;
+  if (real_axis < 0) {
+      real_axis += input_grad.ndim();
+  }
+  CHECK_LT(real_axis, input_grad.ndim());
+  size_t mid = input_grad.shape_[real_axis];
+  for (int i = 0; i < real_axis; ++i) {
+    leading *= input_grad.shape_[i];
+  }
+  for (int i = real_axis + 1; i < input_grad.ndim(); ++i) {
+    trailing *= input_grad.shape_[i];
+  }
+
+  size_t workspace_size = 0;
+  const TShape& ishape = input_grad.shape_;
+  const TShape split_pts =
+    (param.sections > 0) ? GetSplitIndices(ishape, real_axis, param.sections) : param.indices;
+  std::vector<size_t> indices;
+  for (const auto& section : split_pts) {
+    indices.push_back(section);
+  }
+  if (param.sections == 0) {
+    indices.push_back(ishape[real_axis]);
+  }
+  workspace_size += indices.size() * sizeof(size_t);
+  MSHADOW_TYPE_SWITCH(input_grad.type_flag_, DType, {
+    std::vector<DType*> out_grads;
+    for (const TBlob& output_grad : inputs) {
+      out_grads.push_back(output_grad.dptr<DType>());
+    }
+    workspace_size += out_grads.size() * sizeof(DType*);
+    Tensor<xpu, 1, char> workspace =
+      ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);
+    Tensor<cpu, 1, size_t> indices_cpu_tensor(indices.data(), Shape1(indices.size()));
+    Tensor<xpu, 1, size_t> indices_xpu_tensor(
+      reinterpret_cast<size_t*>(workspace.dptr_), Shape1(indices.size()));
+    Tensor<cpu, 1, DType*> ptrs_cpu_tensor(out_grads.data(), Shape1(inputs.size()));
+    Tensor<xpu, 1, DType*> ptrs_xpu_tensor(
+      reinterpret_cast<DType**>(workspace.dptr_ + indices.size() * sizeof(size_t)),
+      Shape1(inputs.size()));
+    mshadow::Copy(indices_xpu_tensor, indices_cpu_tensor, s);
+    mshadow::Copy(ptrs_xpu_tensor, ptrs_cpu_tensor, s);
+    Kernel<ConcatenateKernel, xpu>::Launch(
+      s, input_grad.Size(), ptrs_xpu_tensor.dptr_, input_grad.dptr<DType>(),
+      indices_xpu_tensor.dptr_, indices.size() - 1, mid, trailing);
+  });
+}
+
+inline uint32_t SplitNumOutputs(const NodeAttrs& attrs) {
+  const SplitParam& param = nnvm::get<SplitParam>(attrs.parsed);
+  return (param.sections > 0) ? param.sections : param.indices.ndim();
+}
+
 }  // namespace op
 }  // namespace mxnet
 
diff --git a/src/operator/tensor/matrix_op.cc b/src/operator/tensor/matrix_op.cc
index ed8912f..e5d354b 100644
--- a/src/operator/tensor/matrix_op.cc
+++ b/src/operator/tensor/matrix_op.cc
@@ -103,6 +103,7 @@ DMLC_REGISTER_PARAMETER(ReverseParam);
 DMLC_REGISTER_PARAMETER(StackParam);
 DMLC_REGISTER_PARAMETER(SqueezeParam);
 DMLC_REGISTER_PARAMETER(DepthToSpaceParam);
+DMLC_REGISTER_PARAMETER(SplitParam);
 
 #if MXNET_USE_MKLDNN == 1
 void MKLDNNReshape(const NDArray &in_data, const NDArray &out_data) {
@@ -1071,8 +1072,8 @@ Example::
          [12, 18, 13, 19, 14, 20],
          [3, 9, 4, 10, 5, 11],
          [15, 21, 16, 22, 17, 23]]]]
-  
-  
+
+
   space_to_depth(x, 2) = [[[[0, 1, 2],
                             [3, 4, 5]],
                            [[6, 7, 8],
@@ -1100,5 +1101,103 @@ Example::
 .add_argument("data", "NDArray-or-Symbol", "Input ndarray")
 .add_arguments(DepthToSpaceParam::__FIELDS__());
 
+NNVM_REGISTER_OP(_split_v2)
+.describe(R"code(Splits an array along a particular axis into multiple sub-arrays.
+
+Example::
+
+   x  = [[[ 1.]
+          [ 2.]]
+         [[ 3.]
+          [ 4.]]
+         [[ 5.]
+          [ 6.]]]
+   x.shape = (3, 2, 1)
+
+   y = split_v2(x, axis=1, indices_or_sections=2) // a list of 2 arrays with shape (3, 1, 1)
+   y = [[[ 1.]]
+        [[ 3.]]
+        [[ 5.]]]
+
+       [[[ 2.]]
+        [[ 4.]]
+        [[ 6.]]]
+
+   y[0].shape = (3, 1, 1)
+
+   z = split_v2(x, axis=0, indices_or_sections=3) // a list of 3 arrays with shape (1, 2, 1)
+   z = [[[ 1.]
+         [ 2.]]]
+
+       [[[ 3.]
+         [ 4.]]]
+
+       [[[ 5.]
+         [ 6.]]]
+
+   z[0].shape = (1, 2, 1)
+
+   w = split_v2(x, axis=0, indices_or_sections=(1,)) // a list of 2 arrays with shape [(1, 2, 1), (2, 2, 1)]
+   w = [[[ 1.]
+         [ 2.]]]
+
+       [[[3.]
+         [4.]]
+
+        [[5.]
+         [6.]]]
+
+  w[0].shape = (1, 2, 1)
+  w[1].shape = (2, 2, 1)
+
+`squeeze_axis=True` removes the axis with length 1 from the shapes of the output arrays.
+**Note** that setting `squeeze_axis` to ``1`` removes axis with length 1 only
+along the `axis` which it is split.
+Also `squeeze_axis` can be set to true only if ``input.shape[axis] == indices_or_sections``.
+
+Example::
+
+   z = split_v2(x, axis=0, indices_or_sections=3, squeeze_axis=1) // a list of 3 arrays with shape (2, 1)
+   z = [[ 1.]
+        [ 2.]]
+
+       [[ 3.]
+        [ 4.]]
+
+       [[ 5.]
+        [ 6.]]
+   z[0].shape = (2, 1)
+
+)code" ADD_FILELINE)
+.set_attr_parser(ParamParser<SplitParam>)
+.set_num_inputs(1)
+.set_num_outputs(SplitNumOutputs)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const NodeAttrs& attrs) {
+    return std::vector<std::string>{"data"};
+  })
+.set_attr<nnvm::FInferShape>("FInferShape", SplitOpShape)
+.set_attr<nnvm::FInferType>("FInferType", SplitOpType)
+.set_attr<FCompute>("FCompute<cpu>", SplitOpForward<cpu>)
+.set_attr<FResourceRequest>("FResourceRequest",
+  [](const NodeAttrs& n) {
+    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+})
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_split_v2_backward"})
+.add_argument("data", "NDArray-or-Symbol", "The input")
+.add_arguments(SplitParam::__FIELDS__());
+
+NNVM_REGISTER_OP(_split_v2_backward)
+.set_attr_parser(ParamParser<SplitParam>)
+.set_num_inputs(SplitNumOutputs)
+.set_num_outputs(1)
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_attr<FResourceRequest>("FResourceRequest",
+  [](const NodeAttrs& n) {
+    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+})
+.set_attr<FCompute>("FCompute<cpu>", SplitOpBackward<cpu>);
+
+
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/tensor/matrix_op.cu b/src/operator/tensor/matrix_op.cu
index 4e31a4c..8731127 100644
--- a/src/operator/tensor/matrix_op.cu
+++ b/src/operator/tensor/matrix_op.cu
@@ -217,5 +217,11 @@ NNVM_REGISTER_OP(depth_to_space)
 NNVM_REGISTER_OP(space_to_depth)
 .set_attr<FCompute>("FCompute<gpu>", SpaceToDepthOpForward<gpu>);
 
+NNVM_REGISTER_OP(_split_v2)
+.set_attr<FCompute>("FCompute<gpu>", SplitOpForward<gpu>);
+
+NNVM_REGISTER_OP(_split_v2_backward)
+.set_attr<FCompute>("FCompute<gpu>", SplitOpBackward<gpu>);
+
 }  // namespace op
 }  // namespace mxnet
diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py
index 0aa4855..7176b18 100644
--- a/tests/python/unittest/test_ndarray.py
+++ b/tests/python/unittest/test_ndarray.py
@@ -1056,7 +1056,7 @@ def test_output():
 @with_seed()
 def test_ndarray_fluent():
     has_grad = set(['flatten', 'expand_dims', 'flip', 'tile', 'transpose', 'sum', 'nansum', 'prod',
-                    'nanprod', 'mean', 'max', 'min', 'reshape', 'broadcast_to', 'split',
+                    'nanprod', 'mean', 'max', 'min', 'reshape', 'broadcast_to', 'split', 'split_v2',
                     'broadcast_axes', 'pad', 'swapaxes', 'slice', 'slice_axis', 'slice_like',
                     'take', 'one_hot', 'pick', 'sort', 'topk', 'argsort', 'argmax', 'argmin',
                     'clip', 'abs', 'sign', 'sin', 'cos', 'tan', 'arcsin', 'arccos', 'arctan',
@@ -1093,6 +1093,8 @@ def test_ndarray_fluent():
     check_fluent_regular('repeat', {'repeats': 3})
     check_fluent_regular('transpose', {'axes': (1,0,2)})
     check_fluent_regular('split', {'axis': 2, 'num_outputs': 3}, shape=(5, 17, 6))
+    check_fluent_regular('split_v2', {'axis': 2, 'indices_or_sections': 3}, shape=(5, 17, 6))
+    check_fluent_regular('split_v2', {'axis': 2, 'indices_or_sections': (1, 3, 5)}, shape=(5, 17, 6))
     check_fluent_regular('slice', {'begin': (2, 5, 1), 'end': (4, 7, 6)}, shape=(5, 17, 6))
     check_fluent_regular('slice_axis', {'axis': 1, 'begin': 5, 'end': 7})
     check_fluent_regular('slice_like', {'axes': (0, -2), 'shape_like': mx.nd.zeros((3, 3))})
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index cda801c..67aeddf 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -7249,6 +7249,25 @@ def test_softmax_cross_entropy():
 
 
 @with_seed()
+def test_split_v2():
+    dim = random.randint(2, 6)
+    shape = rand_shape_nd(dim)
+    axis = random.randint(-dim, dim-1)
+    axis_size = shape[axis]
+    samples = random.randint(0, axis_size - 1)
+    indices = sorted(random.sample([i for i in range(1, axis_size)], samples))
+    indices = tuple(indices)
+    mx_data = rand_ndarray(shape)
+    np_data = mx_data.asnumpy()
+    np_out = np.split(np_data, indices_or_sections=indices, axis=axis)
+    data = mx.sym.Variable("data")
+    sym = mx.sym.split_v2(data, indices_or_sections=indices, axis=axis)
+    check_symbolic_forward(sym, {"data": mx_data}, np_out, rtol=1e-3, atol=1e-5)
+    out_grad = [np.ones(arr.shape) for arr in np_out]
+    check_symbolic_backward(sym, {"data": mx_data}, out_grad, [np.concatenate(out_grad, axis=axis)])
+
+
+@with_seed()
 def test_invalid_kernel_size():
     invalid_kernel_size = 28
     assert_exception(