Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/06/27 23:28:27 UTC

[GitHub] haojin2 commented on a change in pull request #11326: [MXNET-381] Enhancement of take operator

URL: https://github.com/apache/incubator-mxnet/pull/11326#discussion_r198670531
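
For context, the change under review adds an axis parameter and clip/wrap handling of out-of-range indices to the take operator (see the dispatch on param.axis and param.mode in the diff below). The following is a minimal standalone sketch of that indexing behavior for a 1-D index applied along axis 1 of a 2-D row-major array; it is illustrative only and is not the MXNet kernel.

// take_sketch.cc: standalone illustration of numpy-style take along axis 1
// of a 2-D row-major array, with "clip" and "wrap" index modes.
#include <cstdio>
#include <vector>

enum class Mode { kClip, kWrap };

// data has shape (rows, cols); out gets shape (rows, idx.size()).
void Take2DAxis1(const std::vector<float>& data, int rows, int cols,
                 const std::vector<int>& idx, Mode mode,
                 std::vector<float>* out) {
  const int num_idx = static_cast<int>(idx.size());
  out->assign(static_cast<size_t>(rows) * num_idx, 0.f);
  for (int r = 0; r < rows; ++r) {
    for (int j = 0; j < num_idx; ++j) {
      int k = idx[j];
      if (mode == Mode::kClip) {
        // clip: clamp out-of-range indices into [0, cols - 1]
        k = k < 0 ? 0 : (k >= cols ? cols - 1 : k);
      } else {
        // wrap: reduce indices modulo cols (handles negatives as well)
        k = ((k % cols) + cols) % cols;
      }
      (*out)[r * num_idx + j] = data[r * cols + k];
    }
  }
}

int main() {
  // data = [[1, 2, 3],
  //         [4, 5, 6]], indices = [2, 5, -1]
  std::vector<float> data = {1, 2, 3, 4, 5, 6};
  std::vector<int> idx = {2, 5, -1};
  std::vector<float> out;
  Take2DAxis1(data, 2, 3, idx, Mode::kClip, &out);
  for (float v : out) std::printf("%g ", v);  // prints: 3 3 1 6 6 4
  std::printf("\n");
  return 0;
}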
 
 

 ##########
 File path: src/operator/tensor/indexing_op.h
 ##########
 @@ -805,17 +836,259 @@ void TakeOpForward(const nnvm::NodeAttrs& attrs,
   const TShape& oshape = outputs[take_::kOut].shape_;
 
   Stream<xpu> *s = ctx.get_stream<xpu>();
+  const int actual_axis = param.axis + ((param.axis < 0) ? arrshape.ndim() : 0);
+
   MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {  // output data type
     MSHADOW_TYPE_SWITCH(inputs[1].type_flag_, IType, {  // index data type
-      Kernel<Take, xpu>::Launch(s, oshape.Size(),
-                                outputs[take_::kOut].dptr<DType>(),
-                                inputs[take_::kArr].dptr<DType>(),
-                                inputs[take_::kIdx].dptr<IType>(),
-                                oshape.Size()/idxshape.Size(), arrshape[0]);
+      if (actual_axis == 0) {
+        if (param.mode == take_::kClip) {
+          Kernel<Take<true>, xpu>::Launch(s, oshape.Size(),
+                                          outputs[take_::kOut].dptr<DType>(),
+                                          inputs[take_::kArr].dptr<DType>(),
+                                          inputs[take_::kIdx].dptr<IType>(),
+                                          oshape.Size()/idxshape.Size(), arrshape[0]);
+        } else {
+          Kernel<Take<false>, xpu>::Launch(s, oshape.Size(),
+                                           outputs[take_::kOut].dptr<DType>(),
+                                           inputs[take_::kArr].dptr<DType>(),
+                                           inputs[take_::kIdx].dptr<IType>(),
+                                           oshape.Size()/idxshape.Size(), arrshape[0]);
+        }
+      } else {
+        mshadow::Shape<10> in_strides;
+        int stride = 1;
+        for (int i = arrshape.ndim() - 1; i >= 0; stride *= arrshape[i], --i) {
+          in_strides[i] = stride;
+        }
+        mshadow::Shape<10> out_strides;
+        stride = 1;
+        for (int i = oshape.ndim() - 1; i >= 0; stride *= oshape[i], --i) {
+          out_strides[i] = stride;
+        }
+        if (param.mode == take_::kClip) {
+          Kernel<Take<true>, xpu>::Launch(s, oshape.Size(),
+                                          outputs[take_::kOut].dptr<DType>(),
+                                          inputs[take_::kArr].dptr<DType>(),
+                                          inputs[take_::kIdx].dptr<IType>(),
+                                          in_strides, out_strides, arrshape.ndim(), oshape.ndim(),
+                                          idxshape.ndim(), arrshape[actual_axis], actual_axis);
+        } else if (param.mode == take_::kWrap) {
+          Kernel<Take<false>, xpu>::Launch(s, oshape.Size(),
+                                           outputs[take_::kOut].dptr<DType>(),
+                                           inputs[take_::kArr].dptr<DType>(),
+                                           inputs[take_::kIdx].dptr<IType>(),
+                                           in_strides, out_strides, arrshape.ndim(), oshape.ndim(),
+                                           idxshape.ndim(), arrshape[actual_axis], actual_axis);
+        }
+      }
+    });
+  });
+}
+
+struct TakeGradGeneralKernel {
+  /*!
+   * \brief Map function for general case of take grad
+   * \param tid           global thread id
+   * \param arr_grad      ptr to in_grad
+   * \param ograd         ptr to out_grad
+   * \param src_indptr    ptr to indptr to src indices
+   * \param original_idx  ptr to original indices of the inputs
+   * \param in_strides    strides of inputs
+   * \param out_strides   strides of outputs
+   * \param in_ndims      # of dims of input tensor
+   * \param out_ndims     # of dims of output tensor
+   * \param idx_ndims     # of dims of indices tensor
+   * \param axis          axis id
+   */
+  template<typename DType, typename IType>
+  MSHADOW_XINLINE static void Map(int tid, DType* arr_grad, const DType* ograd,
+                                  const IType* src_indptr, const IType* original_idx,
+                                  mshadow::Shape<10> in_strides, mshadow::Shape<10> out_strides,
+                                  const int in_ndims, const int out_ndims, const int idx_ndims,
+                                  const int axis) {
+    const int in_head_index = (axis == 0) ? 0 : tid / in_strides[axis - 1];
+    const int in_rest_index = (axis == 0) ? tid : tid % in_strides[axis - 1];
+    const int in_mid_index = in_rest_index / in_strides[axis];
+    const int in_tail_index = (axis == in_ndims - 1) ?
+                              0 : (in_rest_index % in_strides[axis]);
+    for (IType i = src_indptr[in_mid_index]; i < src_indptr[in_mid_index + 1]; ++i) {
+      const int out_mid_index = original_idx[i];
+      int target = in_tail_index + out_mid_index * in_strides[axis];
+      target += (axis == 0) ? 0 : in_head_index * out_strides[axis - 1];
+      arr_grad[tid] += ograd[target];
+    }
+  }
+};
+
+template<bool clip = true>
+void TakeOpBackwardImpl(mshadow::Stream<cpu>* s,
+                        const OpContext& ctx,
+                        const TBlob& arr,
+                        const TBlob& idx,
+                        const TBlob& ograd,
+                        const int axis) {
+  using namespace mxnet_op;
+  using namespace mshadow;
+  CHECK(axis != 0) << "axis == 0 case should be dispatched to the legacy implementation";
+  const TShape& arrshape = arr.shape_;
+  const TShape& idxshape = idx.shape_;
+  const TShape& oshape = ograd.shape_;
+  MSHADOW_TYPE_SWITCH(idx.type_flag_, IType, {
+    // get size of temporary storage for sort
+    char* temp_storage_ptr = nullptr;
+    int* src_indptr_ptr = nullptr;
+    size_t temp_storage_bytes = SortByKeyWorkspaceSize<int, int, cpu>(idxshape.Size());
+    size_t original_idx_bytes = idxshape.Size() * sizeof(int);
+    size_t src_indptr_bytes = (arrshape[axis] + 1) * sizeof(int);
+    size_t workspace_bytes = src_indptr_bytes + 2 * original_idx_bytes + temp_storage_bytes;
+    Tensor<cpu, 1, char> workspace =
+      ctx.requested[0].get_space_typed<cpu, 1, char>(Shape1(workspace_bytes), s);
+    int* sorted_idx_ptr = reinterpret_cast<int*>(workspace.dptr_);
+    int* original_idx_ptr = reinterpret_cast<int*>(workspace.dptr_ + original_idx_bytes);
+    src_indptr_ptr = reinterpret_cast<int*>(workspace.dptr_ + 2 * original_idx_bytes);
+    temp_storage_ptr = workspace.dptr_ + 2 * original_idx_bytes + src_indptr_bytes;
+    // Reset indptr to zero
+    Kernel<set_zero, cpu>::Launch(s, arrshape[axis] + 1, src_indptr_ptr);
+    // Fill original_idx
+    Kernel<range_fwd, cpu>::Launch(s, idxshape.Size(), 1, 0, 1, kWriteTo, original_idx_ptr);
+    // Fill sorted_idx_ptr with unsorted copy of idx
+    Kernel<mshadow_op::identity_with_cast, cpu>::Launch(
+      s, idxshape.Size(), sorted_idx_ptr, idx.dptr<IType>());
+    if (clip) {
+      Kernel<op_with_req<mshadow_op::clip, kWriteTo>, cpu>::Launch(
+        s, idxshape.Size(), sorted_idx_ptr, sorted_idx_ptr,
+        0, static_cast<int>(arrshape[axis] - 1));
+    } else {
+      Kernel<op_with_req<mshadow_op::mod, kWriteTo>, cpu>::Launch(
+        s, idxshape.Size(), sorted_idx_ptr, sorted_idx_ptr, static_cast<int>(arrshape[axis]));
+    }
+    Tensor<cpu, 1, int> original_idx(original_idx_ptr, Shape1(idxshape.Size()), s);
+    Tensor<cpu, 1, char> temp_storage(temp_storage_ptr, Shape1(temp_storage_bytes), s);
+    int num_bits = ilog2(static_cast<unsigned int>(idxshape.Size()) - 1);
+    Tensor<cpu, 1, int> sorted_idx(sorted_idx_ptr, Shape1(idxshape.Size()), s);
+    SortByKey(sorted_idx, original_idx, true, &temp_storage, 0, num_bits);
+    for (size_t i = 0; i < idxshape.Size(); ++i) {
+      src_indptr_ptr[sorted_idx_ptr[i] + 1] += 1;
+    }
+    for (int i = 0; i < arrshape[axis]; ++i) {
+      src_indptr_ptr[i + 1] += src_indptr_ptr[i];
+    }
+    Shape<10> in_strides;
+    int stride = 1;
+    for (int i = arrshape.ndim() - 1; i >= 0; stride *= arrshape[i], --i) {
+      in_strides[i] = stride;
+    }
+    Shape<10> out_strides;
+    stride = 1;
+    for (int i = oshape.ndim() - 1; i >= 0; stride *= oshape[i], --i) {
+      out_strides[i] = stride;
+    }
+    MSHADOW_TYPE_SWITCH(arr.type_flag_, DType, {
+      Kernel<TakeGradGeneralKernel, cpu>::Launch(
+        s, arrshape.Size(), arr.dptr<DType>(), ograd.dptr<DType>(), src_indptr_ptr,
+        original_idx_ptr, in_strides, out_strides,
+        arrshape.ndim(), oshape.ndim(), idxshape.ndim(), axis);
     });
   });
 }
 
+#ifdef __CUDACC__
 
 Review comment:
   I would like to re-use the kernel here; if I move this to the .cuh file, the CPU compiler will not see that kernel.
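
For readers following the thread, the concern above is about where the element-wise kernel body lives. Below is a minimal sketch of the compilation-visibility issue; file and symbol names are illustrative and do not reflect MXNet's actual layout. A kernel kept in the shared header is seen by both the CPU compiler (via the .cc translation unit) and nvcc (via the .cu translation unit), while anything behind #ifdef __CUDACC__, or in a .cuh included only from the .cu file, is invisible to the CPU side.

// shared_kernel.h: hypothetical header illustrating the trade-off.
#pragma once

// Mirrors what MSHADOW_XINLINE provides: the same body compiles as plain host
// code under g++ and as host+device code under nvcc.
#ifdef __CUDACC__
#define XINLINE inline __host__ __device__
#else
#define XINLINE inline
#endif

struct AddOneKernel {
  // Element-wise body reusable by both backends, because this header is
  // included by both the .cc and the .cu translation units.
  template <typename DType>
  XINLINE static void Map(int i, DType* out, const DType* in) {
    out[i] = in[i] + static_cast<DType>(1);
  }
};

#ifdef __CUDACC__
// Only nvcc defines __CUDACC__. If AddOneKernel lived below this guard, or in
// a .cuh that only the .cu file includes, the CPU translation unit would never
// see it, and the CPU path could not reuse it; that is the trade-off the
// comment above is weighing.
template <typename DType>
__global__ void AddOneCuda(int n, DType* out, const DType* in) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) AddOneKernel::Map(i, out, in);
}
#endif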

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services