Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2022/08/02 14:26:29 UTC

[GitHub] [incubator-mxnet] bgawrych commented on a diff in pull request #21112: [BUGFIX] _npi_repeats with swap

bgawrych commented on code in PR #21112:
URL: https://github.com/apache/incubator-mxnet/pull/21112#discussion_r935664043


##########
src/operator/numpy/np_repeat_op-inl.h:
##########
@@ -219,15 +225,86 @@ void NumpyRepeatsOpForward(const nnvm::NodeAttrs& attrs,
       stride *= ishape[i];
     }
 
-    MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, IType, {
-      MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, OType, {
+    MSHADOW_TYPE_SWITCH_EXT_WITH_BOOL(inputs[0].type_flag_, IType, {
+      MSHADOW_TYPE_SWITCH_EXT_WITH_BOOL(outputs[0].type_flag_, OType, {
         mxnet_op::Kernel<repeat_axis_fwd, xpu>::Launch(
             s, out_data.Size(), out_data.dptr<OType>(), in_data.dptr<IType>(), ind, stride);
       });
     });
   }
 }
 
+template <typename xpu>
+void NumpyRepeatsOpForward(const nnvm::NodeAttrs& attrs,
+                           const OpContext& ctx,
+                           const std::vector<TBlob>& inputs,
+                           const std::vector<OpReqType>& req,
+                           const std::vector<TBlob>& outputs) {
+  using namespace mshadow;
+  const mxnet::TShape& ishape = inputs[0].shape_;
+  int repeats                 = 0;
+  int axis                    = -1;
+  dmlc::optional<int> axisOpt;
+  const RepeatsParam& param = nnvm::get<RepeatsParam>(attrs.parsed);
+  GetRepeatsParams(param, ishape, &repeats, &axisOpt, &axis);
+
+  if (!shape_is_known(ishape) || repeats == 0)
+    return;
+
+  mxnet::Tuple<int> repts = param.repeats.value();
+  if (repts.ndim() == 1) {
+    int len = static_cast<bool>(axisOpt) ? ishape[axis] : ishape.Size();
+    std::vector<int> temp(len, repeats);
+    repts = mxnet::Tuple<int>(temp);
+  }
+
+  // If an axis was specified, perform swapaxes before and after calling the repeat function
+  if (axisOpt.has_value() && axisOpt.value() != 0) {
+    int type_size          = mshadow_sizeof(inputs[0].type_flag_);
+    size_t total_temp_size = inputs[0].shape_.Size() * type_size +
+                             outputs[0].shape_.Size() * type_size + repts.ndim() * sizeof(int);
+    Tensor<xpu, 1, char> temp_space = ctx.requested[0].get_space_typed<xpu, 1, char>(
+        Shape1(total_temp_size), ctx.get_stream<xpu>());
+    void* swap_output_tmp_dptr   = temp_space.dptr_;
+    void* repeat_output_tmp_dptr = temp_space.dptr_ + inputs[0].shape_.Size() * type_size;
+    int* repeat_tmp_dptr =
+        reinterpret_cast<int*>(temp_space.dptr_ + inputs[0].shape_.Size() * type_size +

Review Comment:
   ```suggestion
           reinterpret_cast<int*>(repeat_output_tmp_dptr +
   ```
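
One caveat if the suggestion is applied verbatim: `repeat_output_tmp_dptr` is declared `void*` a few lines up, and arithmetic on `void*` is a GNU extension rather than standard C++. A conforming variant of the resulting statement might cast first, roughly as follows (the trailing `outputs[0].shape_.Size() * type_size` offset is inferred from the `total_temp_size` computation; the hunk is truncated before the original continuation line):

```c++
// Sketch only: cast the void* region pointer to char* before offsetting,
// then reinterpret the result as the int* repeat-counts region.
int* repeat_tmp_dptr = reinterpret_cast<int*>(
    static_cast<char*>(repeat_output_tmp_dptr) + outputs[0].shape_.Size() * type_size);
```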



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@mxnet.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org