Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2017/11/08 22:41:09 UTC

[GitHub] piiswrong closed pull request #8558: slice operator supporting arbitrary values of step

piiswrong closed pull request #8558: slice operator supporting arbitrary values of step
URL: https://github.com/apache/incubator-mxnet/pull/8558
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:


diff --git a/src/common/static_array.h b/src/common/static_array.h
new file mode 100644
index 0000000000..8d51967b17
--- /dev/null
+++ b/src/common/static_array.h
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file static_array.h
+ */
+#ifndef MXNET_COMMON_STATIC_ARRAY_H_
+#define MXNET_COMMON_STATIC_ARRAY_H_
+
+#include <mshadow/base.h>
+
+namespace mxnet {
+namespace common {
+
+/*! \brief
+ * Static array. This code is borrowed from struct Shape<ndim>,
+ * except that users can specify the type of the elements of
+ * the statically allocated array.
+ * The object instance of the struct is copyable between CPU and GPU.
+ * \tparam T element type of the array, must be copyable between CPU and GPU
+ * \tparam num number of elements in the array
+ */
+template<typename T, int num>
+struct StaticArray {
+  static const int kNum = num;
+
+  T array_[kNum];
+
+  /*! \brief default constructor, do nothing */
+  MSHADOW_XINLINE StaticArray(void) {}
+
+  /*! \brief constructor, fill in the array with the input value */
+  MSHADOW_XINLINE StaticArray(const T& val) {
+    #pragma unroll
+    for (int i = 0; i < num; ++i) {
+      this->array_[i] = val;
+    }
+  }
+
+  /*! \brief copy constructor */
+  MSHADOW_XINLINE StaticArray(const StaticArray<T, num>& sa) {
+    #pragma unroll
+    for (int i = 0; i < num; ++i) {
+      this->array_[i] = sa[i];
+    }
+  }
+
+  MSHADOW_XINLINE T& operator[](const index_t idx) {
+    return array_[idx];
+  }
+
+  MSHADOW_XINLINE const T& operator[](const index_t idx) const {
+    return array_[idx];
+  }
+};  // StaticArray
+
+}  // namespace common
+}  // namespace mxnet
+#endif  // MXNET_COMMON_STATIC_ARRAY_H_
diff --git a/src/operator/mxnet_op.h b/src/operator/mxnet_op.h
index 06e2393524..5b8e109d7d 100644
--- a/src/operator/mxnet_op.h
+++ b/src/operator/mxnet_op.h
@@ -108,6 +108,28 @@ inline int get_num_threads<cpu>(const int N) {
   }
 
 
+#define MXNET_NDIM_SWITCH(NDim, ndim, ...)         \
+  if (NDim == 0) {                                 \
+  } else if (NDim == 1) {                          \
+    const int ndim = 1;                            \
+    {__VA_ARGS__}                                  \
+  } else if (NDim == 2) {                          \
+    const int ndim = 2;                            \
+    {__VA_ARGS__}                                  \
+  } else if (NDim == 3) {                          \
+    const int ndim = 3;                            \
+    {__VA_ARGS__}                                  \
+  } else if (NDim == 4) {                          \
+    const int ndim = 4;                            \
+    {__VA_ARGS__}                                  \
+  } else if (NDim == 5) {                          \
+    const int ndim = 5;                            \
+    {__VA_ARGS__}                                  \
+  } else {                                         \
+    LOG(FATAL) << "ndim=" << NDim << "too large "; \
+  }
+
+
 /*!
  * \brief assign the val to out according
  * to request in Kernel::Launch
diff --git a/src/operator/tensor/indexing_op.cc b/src/operator/tensor/indexing_op.cc
index bbcad70d69..7c8e53e529 100644
--- a/src/operator/tensor/indexing_op.cc
+++ b/src/operator/tensor/indexing_op.cc
@@ -404,7 +404,8 @@ Examples::
 
   data = [2, 3, 0]
   indices = [[1, 1, 0], [0, 1, 0]]
-  scatter_nd(data, indices) = [[0, 0], [2, 3]]
+  shape = (2, 2)
+  scatter_nd(data, indices, shape) = [[0, 0], [2, 3]]
 
 )code")
 .set_num_outputs(1)
diff --git a/src/operator/tensor/indexing_op.h b/src/operator/tensor/indexing_op.h
index 262a431007..6a27b7296c 100644
--- a/src/operator/tensor/indexing_op.h
+++ b/src/operator/tensor/indexing_op.h
@@ -740,7 +740,7 @@ inline bool TakeOpShape(const nnvm::NodeAttrs& attrs,
     using namespace mshadow;
     const TShape &arrshape = (*in_attrs)[take_::kArr];
     const TShape &idxshape = (*in_attrs)[take_::kIdx];
-    if (idxshape.ndim() == 0) return false;
+    if (idxshape.ndim() == 0U || idxshape.Size() == 0U) return false;
 
     out_attrs->clear();
 
diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h
index e012efe25b..6940cfa71a 100644
--- a/src/operator/tensor/matrix_op-inl.h
+++ b/src/operator/tensor/matrix_op-inl.h
@@ -34,6 +34,8 @@
 #include "../channel_op_common.h"
 #include "../mxnet_op.h"
 #include "broadcast_reduce_op.h"
+#include "./init_op.h"
+#include "../../common/static_array.h"
 
 #if MXNET_USE_CUDA
 #include <thrust/device_vector.h>
@@ -370,12 +372,16 @@ inline bool ExpandDimShape(const nnvm::NodeAttrs& attrs,
 }
 
 struct SliceParam : public dmlc::Parameter<SliceParam> {
-  nnvm::Tuple<dmlc::optional<int> > begin, end;
+  nnvm::Tuple<dmlc::optional<int>> begin, end;
+  nnvm::Tuple<dmlc::optional<int>> step;
   DMLC_DECLARE_PARAMETER(SliceParam) {
     DMLC_DECLARE_FIELD(begin)
     .describe("starting indices for the slice operation, supports negative indices.");
     DMLC_DECLARE_FIELD(end)
     .describe("ending indices for the slice operation, supports negative indices.");
+    DMLC_DECLARE_FIELD(step)
+    .set_default(nnvm::Tuple<dmlc::optional<int>>())
+    .describe("step for the slice operation, supports negative values.");
   }
 };
 
@@ -414,16 +420,6 @@ inline TShape GetSliceShape(const SliceParam& param, const TShape& dshape) {
   return oshape;
 }
 
-inline bool SliceShape(const nnvm::NodeAttrs& attrs,
-                       std::vector<TShape> *in_attrs,
-                       std::vector<TShape> *out_attrs) {
-  const TShape& dshape = (*in_attrs)[0];
-  if (dshape.ndim() == 0) return false;
-  const SliceParam& param = nnvm::get<SliceParam>(attrs.parsed);
-  SHAPE_ASSIGN_CHECK(*out_attrs, 0, GetSliceShape(param, dshape));
-  return true;
-}
-
 inline bool SliceForwardInferStorageType(const nnvm::NodeAttrs& attrs,
                                          const int dev_mask,
                                          DispatchMode* dispatch_mode,
@@ -438,12 +434,20 @@ inline bool SliceForwardInferStorageType(const nnvm::NodeAttrs& attrs,
   const bool invalid_ctx = dev_mask != mshadow::cpu::kDevMask;
   const auto dispatch_ex = invalid_ctx ? DispatchMode::kFComputeFallback :
                                          DispatchMode::kFComputeEx;
+  // If step = 1 (or not given), no need to fall back; otherwise fall back to dense
+  bool trivial_step = false;
+  if (param.step.ndim() == 0U) {
+    trivial_step = true;
+  } else if (param.step.ndim() == 1U
+      && (!param.step[0].has_value() || param.step[0].value() == 1)) {
+    trivial_step = true;
+  }
   if (!dispatched && in_stype == kDefaultStorage) {
     dispatched = storage_type_assign(&out_stype, kDefaultStorage,
                                      dispatch_mode, DispatchMode::kFCompute);
   }
 
-  if (!dispatched && in_stype == kCSRStorage) {
+  if (!dispatched && in_stype == kCSRStorage && trivial_step) {
     dispatched = storage_type_assign(&out_stype, kCSRStorage,
                                      dispatch_mode, dispatch_ex);
   }
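
For reference, a minimal Python sketch of the "trivial step" rule above (illustrative only, not MXNet API): a csr input keeps the csr slicing path only when step is empty, (None,), or (1,).

    def is_trivial_step(step):
        # step is a tuple of optional ints, mirroring param.step
        if len(step) == 0:
            return True
        return len(step) == 1 and (step[0] is None or step[0] == 1)

    assert is_trivial_step(())        # step not specified
    assert is_trivial_step((None,))
    assert is_trivial_step((1,))
    assert not is_trivial_step((2,))  # falls back to dense slicing
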
@@ -458,75 +462,6 @@ inline bool SliceForwardInferStorageType(const nnvm::NodeAttrs& attrs,
   return true;
 }
 
-// matrix crop for multi dimensional cropping: see also slice
-template<typename xpu>
-void Slice(const nnvm::NodeAttrs& attrs,
-          const OpContext& ctx,
-          const std::vector<TBlob>& inputs,
-          const std::vector<OpReqType>& req,
-          const std::vector<TBlob>& outputs) {
-  using namespace mshadow;
-  using namespace mshadow::expr;
-  const SliceParam& param = nnvm::get<SliceParam>(attrs.parsed);
-  index_t N = inputs[0].ndim();
-  TShape begin(N), end(N);
-  for (index_t i = 0; i < N; ++i) {
-    int s = 0;
-    if (i < param.begin.ndim() && param.begin[i]) {
-      s = *param.begin[i];
-      if (s < 0) {
-        s += inputs[0].size(i);
-        CHECK(s >= 0)
-            << "Invalid slicing begin " << param.begin << " and end "
-            << param.end << " for data of shape " << inputs[0].shape_;
-      }
-    }
-    begin[i] = s;
-    end[i] = s + outputs[0].size(i);
-  }
-
-  Stream<xpu> *s = ctx.get_stream<xpu>();
-  MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
-    switch (inputs[0].ndim()) {
-     case 0:
-      break;
-     case 1: {
-      Tensor<xpu, 1, DType> in = inputs[0].get<xpu, 1, DType>(s);
-      Tensor<xpu, 1, DType> out = outputs[0].get<xpu, 1, DType>(s);
-      out = slice(in, begin.get<1>(), end.get<1>());
-      break;
-     }
-     case 2: {
-      Tensor<xpu, 2, DType> in = inputs[0].get<xpu, 2, DType>(s);
-      Tensor<xpu, 2, DType> out = outputs[0].get<xpu, 2, DType>(s);
-      out = slice(in, begin.get<2>(), end.get<2>());
-      break;
-     }
-     case 3: {
-      Tensor<xpu, 3, DType> in = inputs[0].get<xpu, 3, DType>(s);
-      Tensor<xpu, 3, DType> out = outputs[0].get<xpu, 3, DType>(s);
-      out = slice(in, begin.get<3>(), end.get<3>());
-      break;
-     }
-     case 4: {
-      Tensor<xpu, 4, DType> in = inputs[0].get<xpu, 4, DType>(s);
-      Tensor<xpu, 4, DType> out = outputs[0].get<xpu, 4, DType>(s);
-      out = slice(in, begin.get<4>(), end.get<4>());
-      break;
-     }
-     case 5: {
-      Tensor<xpu, 5, DType> in = inputs[0].get<xpu, 5, DType>(s);
-      Tensor<xpu, 5, DType> out = outputs[0].get<xpu, 5, DType>(s);
-      out = slice(in, begin.get<5>(), end.get<5>());
-      break;
-     }
-     default:
-      LOG(FATAL) << "slice supports at most 5 dimensions";
-      break;
-    }
-  });
-}
-
 // slice the indptr of a csr
 struct SliceCsrIndPtr {
   template<typename IType>
@@ -747,6 +682,227 @@ void SliceEx(const nnvm::NodeAttrs& attrs,
   }
 }
 
+template<int ndim>
+inline void GetIndexRange(const SliceParam& param,
+                          const TShape& dshape,
+                          common::StaticArray<int, ndim>* begin,
+                          common::StaticArray<int, ndim>* end,
+                          common::StaticArray<int, ndim>* step) {
+  CHECK_NE(dshape.ndim(), 0U);
+  CHECK_NE(dshape.Size(), 0U);
+  CHECK_LE(param.begin.ndim(), dshape.ndim())
+    << "Slicing axis exceeds data dimensions";
+  CHECK_LE(param.end.ndim(), dshape.ndim())
+    << "Slicing axis exceeds data dimensions";
+  CHECK_EQ(param.begin.ndim(), param.end.ndim())
+    << "begin and end must have the same length";
+  CHECK_EQ(ndim, dshape.ndim())
+    << "Static array size=" << ndim
+    << " is not equal to data shape ndim=" << dshape.ndim();
+
+  if (param.step.ndim() != 0U) {
+    CHECK_EQ(param.step.ndim(), param.begin.ndim())
+      << "step and begin must have the same length";
+  }
+
+  for (index_t i = 0; i < param.begin.ndim(); ++i) {
+    int b = 0, e = dshape[i], s = 1;
+    const int len = dshape[i];
+    if (param.step.ndim() != 0U) {
+      const auto& opt_step_val = param.step[i];
+      if (opt_step_val.has_value()) {
+        s = opt_step_val.value();
+        CHECK_NE(s, 0) << "slice op step[" << i << "] cannot be 0";
+      }
+    }
+
+    if (param.begin[i].has_value()) {
+      b = param.begin[i].value();
+      if (b < 0) {
+        b += len;
+        CHECK_GE(b, 0) << "slicing with begin[" << i << "]="
+                       << b - len << " exceeds limit of " << len;
+      }
+    } else if (s < 0) {
+      b = len - 1;
+    }
+    CHECK_LT(b, len) << "slicing with begin[" << i << "]="
+                     << b << " exceeds limit of " << len;
+
+    if (param.end[i].has_value()) {
+      e = param.end[i].value();
+      if (e < 0) {
+        e += len;
+        CHECK_GE(e, 0) << "slicing with end[" << i << "]="
+                       << e - len << " exceeds limit of " << len;
+      }
+    } else if (s < 0) {
+      e = -1;
+    }
+    CHECK_LE(e, len) << "slicing with end[" << i << "]="
+                     << e << " exceeds limit of " << len;
+
+    (*begin)[i] = b;
+    (*end)[i] = e;
+    (*step)[i] = s;
+  }
+  for (index_t i = param.begin.ndim(); i < dshape.ndim(); ++i) {
+    (*begin)[i] = 0;
+    (*end)[i] = dshape[i];
+    (*step)[i] = 1;
+  }
+}
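
A minimal single-axis Python sketch of the normalization done by GetIndexRange above (illustrative only): negative begin/end values wrap around, and missing values default according to the sign of the step.

    def get_index_range_1d(length, begin=None, end=None, step=None):
        s = 1 if step is None else step
        assert s != 0, "slice op step cannot be 0"
        if begin is None:
            b = length - 1 if s < 0 else 0
        else:
            b = begin + length if begin < 0 else begin
        if end is None:
            e = -1 if s < 0 else length
        else:
            e = end + length if end < 0 else end
        return b, e, s

    # For an axis of length 4:
    #   get_index_range_1d(4)                   -> (0, 4, 1)
    #   get_index_range_1d(4, step=-1)          -> (3, -1, -1)
    #   get_index_range_1d(4, begin=-3, end=-1) -> (1, 3, 1)
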
+
+inline bool SliceOpShape(const nnvm::NodeAttrs& attrs,
+                         std::vector<TShape>* in_attrs,
+                         std::vector<TShape>* out_attrs) {
+  CHECK_EQ(in_attrs->size(), 1U);
+  CHECK_EQ(out_attrs->size(), 1U);
+  const TShape& dshape = (*in_attrs)[0];
+  if (dshape.ndim() == 0 || dshape.Size() == 0) return false;
+  const SliceParam& param = nnvm::get<SliceParam>(attrs.parsed);
+  TShape oshape = dshape;
+  MXNET_NDIM_SWITCH(dshape.ndim(), ndim, {
+    common::StaticArray<int, ndim> begin, end, step;
+    GetIndexRange(param, dshape, &begin, &end, &step);
+
+    for (index_t i = 0; i < param.begin.ndim(); ++i) {
+      const int b = begin[i], e = end[i], s = step[i];
+      if (s > 0) {
+        CHECK_LT(b, e) << "slicing with begin[" << i << "]=" << b << ", end[" << i << "]="
+                       << e << ", and step[" << i << "]=" << s << " is invalid";
+        oshape[i] = (e - b - 1) / s + 1;
+      } else {
+        CHECK_LT(e, b) << "slicing with begin[" << i << "]=" << b << ", end[" << i << "]="
+                       << e << ", and step[" << i << "]=" << s << " is invalid";
+        oshape[i] = (b - e - 1) / (-s) + 1;
+      }
+    }
+  });
+
+  SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape);
+  return oshape.ndim() != 0 && oshape.Size() != 0;
+}
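
The per-axis output length computed above equals ceil(|e - b| / |s|); a small Python check of both branches (illustrative only):

    def out_dim(b, e, s):
        # mirrors the shape rule in SliceOpShape above
        return (e - b - 1) // s + 1 if s > 0 else (b - e - 1) // (-s) + 1

    assert out_dim(1, 10, 2) == 5   # picks indices 1, 3, 5, 7, 9
    assert out_dim(9, -1, -2) == 5  # picks indices 9, 7, 5, 3, 1
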
+
+template<int ndim>
+struct slice_forward {
+  // i is the i-th row after flattening out into 2D tensor
+  template<typename DType>
+  MSHADOW_XINLINE static void Map(int i, DType* out, const DType* data,
+                                  const OpReqType req,
+                                  const mshadow::Shape<ndim> dshape,
+                                  const mshadow::Shape<ndim> oshape,
+                                  const common::StaticArray<int, ndim> begin,
+                                  const common::StaticArray<int, ndim> step) {
+    const int data_last_dim_size = dshape[ndim-1];
+    const int out_last_dim_size = oshape[ndim-1];
+    const int step_last_dim = step[ndim-1];
+    const int begin_last_dim = begin[ndim-1];
+    int out_offset = i * out_last_dim_size;
+    for (int j = 0; j < out_last_dim_size; ++j) {
+      int irow = 0;  // row id of flattened 2D data
+      int stride = 1;
+      int idx = i;
+      #pragma unroll
+      for (int k = ndim - 2; k >= 0; --k) {
+        irow += stride * ((idx % oshape[k]) * step[k] + begin[k]);
+        idx /= oshape[k];
+        stride *= dshape[k];
+      }
+      KERNEL_ASSIGN(out[out_offset++], req,
+                    data[irow * data_last_dim_size + j * step_last_dim + begin_last_dim]);
+    }
+  }
+};
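
A pure numpy sketch of the index mapping performed by slice_forward above (illustrative only; it assumes begin and step were already normalized by GetIndexRange, and flattens both tensors to 2D keeping the last axis):

    import numpy as np

    def slice_forward_ref(data, begin, step, oshape):
        dshape = data.shape
        ndim = len(dshape)
        flat_in = data.reshape(-1, dshape[-1])
        flat_out = np.empty(oshape, dtype=data.dtype).reshape(-1, oshape[-1])
        for i in range(flat_out.shape[0]):      # i-th row of the flattened output
            irow, stride, idx = 0, 1, i         # locate the source row in the input
            for k in range(ndim - 2, -1, -1):
                irow += stride * ((idx % oshape[k]) * step[k] + begin[k])
                idx //= oshape[k]
                stride *= dshape[k]
            for j in range(oshape[-1]):         # gather along the last axis
                flat_out[i, j] = flat_in[irow, j * step[-1] + begin[-1]]
        return flat_out.reshape(oshape)

    x = np.arange(24).reshape(2, 3, 4)
    np.testing.assert_array_equal(
        slice_forward_ref(x, begin=[0, 2, 3], step=[1, -1, -2], oshape=(2, 3, 2)),
        x[0:2, 2::-1, 3::-2])
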
+
+template<typename xpu>
+void SliceOpForward(const nnvm::NodeAttrs& attrs,
+                    const OpContext& ctx,
+                    const std::vector<TBlob>& inputs,
+                    const std::vector<OpReqType>& req,
+                    const std::vector<TBlob>& outputs) {
+  CHECK_EQ(inputs.size(), 1U);
+  CHECK_EQ(outputs.size(), 1U);
+  CHECK_EQ(req.size(), 1U);
+  if (req[0] == kNullOp) return;
+  using namespace mshadow;
+  Stream<xpu>* s = ctx.get_stream<xpu>();
+  const TBlob& data = inputs[0];
+  const TBlob& out = outputs[0];
+  const SliceParam& param = nnvm::get<SliceParam>(attrs.parsed);
+  MXNET_NDIM_SWITCH(data.ndim(), ndim, {
+    common::StaticArray<int, ndim> begin, end, step;
+    GetIndexRange(param, data.shape_, &begin, &end, &step);
+    MSHADOW_TYPE_SWITCH(out.type_flag_, DType, {
+      mxnet_op::Kernel<slice_forward<ndim>, xpu>::Launch(s, out.shape_.FlatTo2D()[0],
+          out.dptr<DType>(), data.dptr<DType>(), req[0],
+          data.shape_.get<ndim>(), out.shape_.get<ndim>(), begin, step);
+    })
+  })
+}
+
+template<int ndim>
+struct slice_backward {
+  // i is the i-th row after flattening out into 2D tensor
+  template<typename DType>
+  MSHADOW_XINLINE static void Map(int i, DType* igrad, const DType* ograd,
+                                  const OpReqType req,
+                                  const mshadow::Shape<ndim> dshape,
+                                  const mshadow::Shape<ndim> oshape,
+                                  const common::StaticArray<int, ndim> begin,
+                                  const common::StaticArray<int, ndim> step) {
+    const int data_last_dim_size = dshape[ndim-1];
+    const int out_last_dim_size = oshape[ndim-1];
+    const int step_last_dim = step[ndim-1];
+    const int begin_last_dim = begin[ndim-1];
+    int ograd_offset = i * out_last_dim_size;
+    for (int j = 0; j < out_last_dim_size; ++j) {
+      int irow = 0;  // row id of flattened 2D igrad
+      int stride = 1;
+      int idx = i;
+      #pragma unroll
+      for (int k = ndim - 2; k >= 0; --k) {
+        irow += stride * ((idx % oshape[k]) * step[k] + begin[k]);
+        idx /= oshape[k];
+        stride *= dshape[k];
+      }
+      KERNEL_ASSIGN(igrad[irow * data_last_dim_size + j * step_last_dim + begin_last_dim],
+                    req, ograd[ograd_offset++]);
+    }
+  }
+};
+
+template<typename xpu>
+void SliceOpBackward(const nnvm::NodeAttrs& attrs,
+                     const OpContext& ctx,
+                     const std::vector<TBlob>& inputs,
+                     const std::vector<OpReqType>& req,
+                     const std::vector<TBlob>& outputs) {
+  CHECK_EQ(inputs.size(), 1U);
+  CHECK_EQ(outputs.size(), 1U);
+  CHECK_EQ(req.size(), 1U);
+  if (req[0] == kNullOp) return;
+  using namespace mshadow;
+  Stream<xpu>* s = ctx.get_stream<xpu>();
+  const TBlob& ograd = inputs[0];
+  const TBlob& igrad = outputs[0];
+  const SliceParam& param = nnvm::get<SliceParam>(attrs.parsed);
+  if (req[0] == kWriteTo) {
+    Fill(s, igrad, req[0], 0);
+  } else if (req[0] == kWriteInplace) {
+    LOG(FATAL) << "_slice_backward does not support kWriteInplace";
+  }
+  MXNET_NDIM_SWITCH(ograd.ndim(), ndim, {
+    common::StaticArray<int, ndim> begin, end, step;
+    GetIndexRange(param, igrad.shape_, &begin, &end, &step);
+    MSHADOW_TYPE_SWITCH(ograd.type_flag_, DType, {
+      mxnet_op::Kernel<slice_backward<ndim>, xpu>::Launch(s, ograd.shape_.FlatTo2D()[0],
+          igrad.dptr<DType>(), ograd.dptr<DType>(), req[0],
+          igrad.shape_.get<ndim>(), ograd.shape_.get<ndim>(), begin, step);
+    })
+  })
+}
+
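
The backward pass is the scatter counterpart of the forward gather. A minimal numpy sketch (illustrative only) of what SliceOpBackward computes for the kWriteTo case, where index is the tuple of Python slices corresponding to begin/end/step:

    import numpy as np

    def slice_backward_ref(ograd, dshape, index):
        igrad = np.zeros(dshape, dtype=ograd.dtype)  # Fill(s, igrad, req[0], 0)
        igrad[index] = ograd                         # scatter ograd back
        return igrad

    g = slice_backward_ref(np.ones((2, 2)), (3, 4), np.s_[0:2, 1:4:2])
    # g is 1 at rows 0-1, columns 1 and 3, and 0 elsewhere.
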
 inline bool SliceAssignShape(const nnvm::NodeAttrs& attrs,
                              std::vector<TShape> *in_attrs,
                              std::vector<TShape> *out_attrs) {
@@ -843,32 +999,6 @@ void SliceAssign(const nnvm::NodeAttrs& attrs,
   SliceAssignImpl<xpu>(s, param, outputs[0], inputs[1]);
 }
 
-template<typename xpu>
-void SliceBackward(const nnvm::NodeAttrs& attrs,
-                   const OpContext& ctx,
-                   const std::vector<TBlob>& inputs,
-                   const std::vector<OpReqType>& req,
-                   const std::vector<TBlob>& outputs) {
-  using namespace mshadow;
-  using namespace mshadow::expr;
-
-  const SliceParam& param = nnvm::get<SliceParam>(attrs.parsed);
-  Stream<xpu> *s = ctx.get_stream<xpu>();
-
-  if (req[0] == kNullOp) {
-    return;
-  } else if (req[0] == kWriteTo) {
-    MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
-      Tensor<xpu, 1, DType> out = outputs[0].FlatTo1D<xpu, DType>(s);
-      out = DType(0);
-    });
-  } else {
-    LOG(FATAL) << "CropAssign only supports kWriteTo";
-  }
-
-  SliceAssignImpl<xpu>(s, param, outputs[0], inputs[0]);
-}
-
 struct SimpleCropAssignScalarParam : public dmlc::Parameter<SimpleCropAssignScalarParam> {
   real_t scalar;
   TShape begin, end;
diff --git a/src/operator/tensor/matrix_op.cc b/src/operator/tensor/matrix_op.cc
index b76f9e94b5..7f109b69d8 100644
--- a/src/operator/tensor/matrix_op.cc
+++ b/src/operator/tensor/matrix_op.cc
@@ -249,25 +249,39 @@ will return a new array with shape ``(2,1,3,4)``.
 NNVM_REGISTER_OP(slice)
 .add_alias("_sparse_slice")
 .add_alias("crop")
-.describe(R"code(Slices a contiguous region of the array.
+.describe(R"code(Slices a region of the array.
 
 .. note:: ``crop`` is deprecated. Use ``slice`` instead.
 
-This function returns a sliced continuous region of the array between the indices given
-by `begin` and `end`.
+This function returns a sliced array between the indices given
+by `begin` and `end` with the corresponding `step`.
 
-For an input array of `n` dimensions, slice operation with ``begin=(b_0, b_1...b_n-1)`` indices
-and ``end=(e_1, e_2, ... e_n)`` indices will result in an array with the shape
-``(e_1-b_0, ..., e_n-b_n-1)``.
+For an input array of ``shape=(d_0, d_1, ..., d_n-1)``,
+slice operation with ``begin=(b_0, b_1...b_m-1)``,
+``end=(e_0, e_1, ..., e_m-1)``, and ``step=(s_0, s_1, ..., s_m-1)``,
+where m <= n, results in an array with the shape
+``(ceil(|e_0-b_0|/|s_0|), ..., ceil(|e_m-1-b_m-1|/|s_m-1|), d_m, ..., d_n-1)``.
 
 The resulting array's *k*-th dimension contains elements
-from the *k*-th dimension of the input array with the open range ``[b_k, e_k)``.
+from the *k*-th dimension of the input array starting
+from index ``b_k`` (inclusive) with step ``s_k``
+until reaching ``e_k`` (exclusive).
+
+If the *k*-th element of `begin`, `end`, or `step` is `None`, the
+following defaults are used: if `s_k` is `None`, set `s_k=1`;
+if `b_k` is `None`, set `b_k=0` when `s_k > 0` and `b_k=d_k-1` otherwise;
+if `e_k` is `None`, set `e_k=d_k` when `s_k > 0` and `e_k=-1` otherwise.
 
 The storage type of ``slice`` output depends on storage types of inputs
 
 - slice(csr) = csr
 - otherwise, ``slice`` generates output with default storage
 
+.. note:: When the input data storage type is csr, a csr output is
+   generated only for step=(), step=(None,), or step=(1,).
+   For any other step value, the operator falls back to slicing
+   a dense tensor.
+
 Example::
 
   x = [[  1.,   2.,   3.,   4.],
@@ -276,14 +290,16 @@ Example::
 
   slice(x, begin=(0,1), end=(2,4)) = [[ 2.,  3.,  4.],
                                      [ 6.,  7.,  8.]]
-
+  slice(x, begin=(None, 0), end=(None, 3), step=(-1, 2)) = [[9., 11.],
+                                                            [5.,  7.],
+                                                            [1.,  3.]]
 )code" ADD_FILELINE)
 .set_attr_parser(ParamParser<SliceParam>)
-.set_attr<nnvm::FInferShape>("FInferShape", SliceShape)
+.set_attr<nnvm::FInferShape>("FInferShape", SliceOpShape)
 .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
 .set_attr<FInferStorageType>("FInferStorageType", SliceForwardInferStorageType)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_slice"})
-.set_attr<FCompute>("FCompute<cpu>", Slice<cpu>)
+.set_attr<FCompute>("FCompute<cpu>", SliceOpForward<cpu>)
 .set_attr<FComputeEx>("FComputeEx<cpu>", SliceEx<cpu>)
 .add_argument("data", "NDArray-or-Symbol", "Source input")
 .add_arguments(SliceParam::__FIELDS__());
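
A quick usage sketch of the new step parameter (illustrative; assumes x is the 3x4 array from the docstring example above):

    import mxnet as mx
    import numpy as np

    x = mx.nd.array(np.arange(1, 13).reshape(3, 4))
    # With begin/end left as None and a negative step, the defaults
    # b_k = d_k - 1, e_k = -1 reverse the first axis.
    y = mx.nd.slice(x, begin=(None, None), end=(None, None), step=(-1, 2))
    # y = [[ 9. 11.]
    #      [ 5.  7.]
    #      [ 1.  3.]]
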
@@ -291,7 +307,7 @@ Example::
 NNVM_REGISTER_OP(_backward_slice)
 .set_attr_parser(ParamParser<SliceParam>)
 .set_attr<nnvm::TIsBackward>("TIsBackward", true)
-.set_attr<FCompute>("FCompute<cpu>", SliceBackward<cpu>);
+.set_attr<FCompute>("FCompute<cpu>", SliceOpBackward<cpu>);
 
 NNVM_REGISTER_OP(_slice_assign)
 .add_alias("_crop_assign")
diff --git a/src/operator/tensor/matrix_op.cu b/src/operator/tensor/matrix_op.cu
index 85d81a79cc..3866fc419f 100644
--- a/src/operator/tensor/matrix_op.cu
+++ b/src/operator/tensor/matrix_op.cu
@@ -39,10 +39,10 @@ NNVM_REGISTER_OP(expand_dims)
 .set_attr<FCompute>("FCompute<gpu>", UnaryOp::IdentityCompute<gpu>);
 
 NNVM_REGISTER_OP(slice)
-.set_attr<FCompute>("FCompute<gpu>", Slice<gpu>);
+.set_attr<FCompute>("FCompute<gpu>", SliceOpForward<gpu>);
 
 NNVM_REGISTER_OP(_backward_slice)
-.set_attr<FCompute>("FCompute<gpu>", SliceBackward<gpu>);
+.set_attr<FCompute>("FCompute<gpu>", SliceOpBackward<gpu>);
 
 NNVM_REGISTER_OP(_slice_assign)
 .set_attr<FCompute>("FCompute<gpu>", SliceAssign<gpu>);
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index ef866dd768..93dc4a0534 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -4611,6 +4611,43 @@ def test_softmax():
     check_smoothed_softmax_grad(default_context())
 
 
+def test_slice():
+    def test_slice_forward_backward(a, index):
+        a_np = a.asnumpy()
+        begin = []
+        end = []
+        step = []
+        for slice_i in index:
+            begin.append(slice_i.start)
+            end.append(slice_i.stop)
+            step.append(slice_i.step)
+        b = mx.nd.slice(a, begin=begin, end=end, step=step)
+        b_np = a_np[index]
+        assert same(b.asnumpy(), b_np)
+
+        data = mx.sym.Variable('data')
+        slice_sym = mx.sym.slice(data, begin=begin, end=end, step=step)
+        expected_in_grad = np.zeros_like(a_np)
+        expected_in_grad[index] = b_np
+        check_symbolic_backward(slice_sym, [a_np], [b_np], [expected_in_grad])
+
+    shape = (16, 14, 17, 20)
+    arr = mx.nd.arange(np.prod(shape)).reshape(shape=shape)
+    index_list = [(slice(None),), (slice(None), slice(None)), (slice(1, 10),), (slice(1, 10), slice(3, 9)),
+                  (slice(1, 10), slice(2, 5), slice(3, 6), slice(7, 10)),
+                  (slice(1, 10, 2), slice(2, 9, 3), slice(3, 6, 5), slice(7, 10, 2)),
+                  (slice(None, None, -1), slice(None, None, -1), slice(None, None, -1)),
+                  (slice(10, 0, -2), slice(5, 2, -1), slice(7, None, 3), slice(None, 12, 4))]
+    for index in index_list:
+        test_slice_forward_backward(arr, index)
+
+    # check numeric gradient
+    in_data = np.arange(36).reshape(2, 2, 3, 3)
+    data = mx.sym.Variable('data')
+    slice_sym = mx.sym.slice(data, begin=[0, None], end=[1, None], step=[2, -1])
+    check_numeric_gradient(slice_sym, [in_data])
+
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()
diff --git a/tests/python/unittest/test_sparse_ndarray.py b/tests/python/unittest/test_sparse_ndarray.py
index 7f24799abf..7576050f55 100644
--- a/tests/python/unittest/test_sparse_ndarray.py
+++ b/tests/python/unittest/test_sparse_ndarray.py
@@ -131,6 +131,21 @@ def test_sparse_nd_slice():
     result_dense = mx.nd.slice(mx.nd.array(A2), begin=(start, start_col), end=(end, end_col))
     assert same(result_dense.asnumpy(), result.asnumpy())
 
+    def check_slice_nd_csr_fallback(shape):
+        stype = 'csr'
+        A, _ = rand_sparse_ndarray(shape, stype)
+        A2 = A.asnumpy()
+        start = rnd.randint(0, shape[0] - 1)
+        end = rnd.randint(start + 1, shape[0])
+
+        # non-trivial step should fallback to dense slice op
+        result = mx.nd.sparse.slice(A, begin=(start,), end=(end + 1,), step=(2,))
+        result_dense = mx.nd.slice(mx.nd.array(A2), begin=(start,), end=(end + 1,), step=(2,))
+        assert same(result_dense.asnumpy(), result.asnumpy())
+
+    shape = (rnd.randint(2, 10), rnd.randint(1, 10))
+    check_slice_nd_csr_fallback(shape)
+
 
 def test_sparse_nd_equal():
     for stype in ['row_sparse', 'csr']:


 
