Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2022/05/16 15:15:24 UTC

[GitHub] [incubator-mxnet] agrabows commented on a diff in pull request #21024: [FEATURE] Refactor SwapAxis operator.

agrabows commented on code in PR #21024:
URL: https://github.com/apache/incubator-mxnet/pull/21024#discussion_r873851715


##########
src/operator/swapaxis-inl.h:
##########
@@ -53,223 +53,157 @@ struct SwapAxisParam : public dmlc::Parameter<SwapAxisParam> {
   }
 };
 
-template <typename xpu, typename DType>
-class SwapAxisOp : public Operator {
- public:
-  explicit SwapAxisOp(SwapAxisParam p) {
-    this->param_ = p;
+inline void Reshape2Five(mshadow::Shape<5>* inter_shape,
+                         const mxnet::TShape& shape,
+                         int dim1,
+                         int dim2) {
+  using namespace mshadow;
+  using namespace mshadow::expr;
+  int ndim_in = shape.ndim();
+  int si;
+
+  if (dim1 > dim2) {
+    std::swap(dim1, dim2);
   }
 
-  void Reshape2Five(mshadow::Shape<5>* inter_shape,
-                    const mxnet::TShape& shape,
-                    int dim1,
-                    int dim2) {
-    using namespace mshadow;
-    using namespace mshadow::expr;
-    int ndim_in = shape.ndim();
-    int si;
-
-    if (dim1 > dim2) {
-      std::swap(dim1, dim2);
-    }
-
-    for (si = 0; si < 5; si++) {
-      (*inter_shape)[si] = 1;
-    }
-    // dim_0
-    for (si = 0; si < dim1; si++) {
-      (*inter_shape)[0] *= shape[si];
-    }
-    // dim_1
-    (*inter_shape)[1] = shape[dim1];
-    // dim_2
-    for (si = dim1 + 1; si < dim2; si++) {
-      (*inter_shape)[2] *= shape[si];
-    }
-    // dim_3
-    (*inter_shape)[3] = shape[dim2];
-    // dim_4
-    for (si = dim2 + 1; si < ndim_in; si++) {
-      (*inter_shape)[4] *= shape[si];
-    }
+  for (si = 0; si < 5; si++) {
+    (*inter_shape)[si] = 1;
   }
-
-  void SwapAxis(mshadow::Stream<xpu>* s,
-                const std::vector<TBlob>& in_data,
-                const std::vector<TBlob>& out_data,
-                const std::vector<OpReqType>& req) {
-    using namespace mshadow;
-    using namespace mshadow::expr;
-
-    TBlob data_in     = in_data[swapaxisenum::kData];
-    TBlob data_out    = out_data[swapaxisenum::kData];
-    OpReqType out_req = req[swapaxisenum::kData];
-
-    mxnet::TShape shape_in  = data_in.shape_;
-    mxnet::TShape shape_out = data_out.shape_;
-    int axis1               = param_.dim1;
-    if (axis1 < 0) {
-      axis1 += shape_in.ndim();
-    }
-    CHECK(axis1 >= 0 && axis1 < shape_in.ndim())
-        << "axis1: axis " << param_.dim1 << " is out of bounds for array of ndim "
-        << shape_in.ndim();
-
-    int axis2 = param_.dim2;
-    if (axis2 < 0) {
-      axis2 += shape_in.ndim();
-    }
-    CHECK(axis2 >= 0 && axis2 < shape_in.ndim())
-        << "axis2: axis " << param_.dim2 << " is out of bounds for array of ndim "
-        << shape_in.ndim();
-
-    if (shape_in.Size() == 0U)
-      return;
-
-    if (axis1 == axis2) {
-      if (out_req == kAddTo) {
-        mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::identity, kAddTo>, xpu>::Launch(
-            s, data_out.Size(), data_out.dptr<DType>(), data_in.dptr<DType>());
-      } else {
-        mxnet_op::copy(s, data_out, data_in);
-      }
-      return;
-    }
-
-    Shape<5> inter_shape;
-
-    Reshape2Five(&inter_shape, shape_in, axis1, axis2);
-
-    Tensor<xpu, 5, DType> inter_data_in = data_in.get_with_shape<xpu, 5, DType>(inter_shape, s);
-
-    Shape<5> inter_shape2 = inter_shape;
-    std::swap(inter_shape2[1], inter_shape2[3]);
-
-    Tensor<xpu, 5, DType> inter_data_out = data_out.get_with_shape<xpu, 5, DType>(inter_shape2, s);
-
-    if (out_req == kAddTo) {
-      inter_data_out += swapaxis<3, 1>(inter_data_in);
-    } else {
-      inter_data_out = swapaxis<3, 1>(inter_data_in);
-    }
+  // dim_0
+  for (si = 0; si < dim1; si++) {
+    (*inter_shape)[0] *= shape[si];
   }
-
-  virtual void Forward(const OpContext& ctx,
-                       const std::vector<TBlob>& in_data,
-                       const std::vector<OpReqType>& req,
-                       const std::vector<TBlob>& out_data,
-                       const std::vector<TBlob>& aux_args) {
-    using namespace mshadow;
-    Stream<xpu>* s = ctx.get_stream<xpu>();
-
-    SwapAxis(s, in_data, out_data, req);
+  // dim_1
+  (*inter_shape)[1] = shape[dim1];
+  // dim_2
+  for (si = dim1 + 1; si < dim2; si++) {
+    (*inter_shape)[2] *= shape[si];
   }
-
-  virtual void Backward(const OpContext& ctx,
-                        const std::vector<TBlob>& out_grad,
-                        const std::vector<TBlob>& in_data,
-                        const std::vector<TBlob>& out_data,
-                        const std::vector<OpReqType>& req,
-                        const std::vector<TBlob>& in_grad,
-                        const std::vector<TBlob>& aux_args) {
-    using namespace mshadow;
-    Stream<xpu>* s = ctx.get_stream<xpu>();
-
-    SwapAxis(s, out_grad, in_grad, req);
+  // dim_3
+  (*inter_shape)[3] = shape[dim2];
+  // dim_4
+  for (si = dim2 + 1; si < ndim_in; si++) {
+    (*inter_shape)[4] *= shape[si];
   }
+}
 
-  SwapAxisParam param_;
-};
-
-template <typename xpu>
-Operator* CreateOp(SwapAxisParam param, int dtype);
-
-#if DMLC_USE_CXX11
-class SwapAxisProp : public OperatorProperty {
- public:
-  std::vector<std::string> ListArguments() const override {
-    return {"data"};
+template <typename xpu, typename DType>
+void SwapAxis(const nnvm::NodeAttrs& attrs,
+              const OpContext& ctx,
+              const std::vector<TBlob>& in_data,
+              const std::vector<TBlob>& out_data,
+              const std::vector<OpReqType>& req) {
+  using namespace mshadow;
+  using namespace mshadow::expr;
+
+  TBlob data_in              = in_data[swapaxisenum::kData];
+  TBlob data_out             = out_data[swapaxisenum::kOut];
+  OpReqType out_req          = req[swapaxisenum::kData];

Review Comment:
   done
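
For context, the Reshape2Five helper shown above collapses an arbitrary-rank
shape into five groups around the two swapped axes, so the kernel can always
operate on a rank-5 view regardless of the input's true rank. A minimal
standalone sketch of the same grouping logic (GroupToFive and the main below
are hypothetical illustration code, not part of the PR):

    #include <algorithm>
    #include <array>
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    // Collapse an N-d shape into five factors around the two swapped axes:
    // [prod(0..dim1), shape[dim1], prod(dim1+1..dim2), shape[dim2], prod(rest)].
    std::array<int64_t, 5> GroupToFive(const std::vector<int64_t>& shape,
                                       int dim1, int dim2) {
      if (dim1 > dim2) std::swap(dim1, dim2);
      std::array<int64_t, 5> g{1, 1, 1, 1, 1};
      const int ndim = static_cast<int>(shape.size());
      for (int i = 0; i < dim1; ++i) g[0] *= shape[i];
      g[1] = shape[dim1];
      for (int i = dim1 + 1; i < dim2; ++i) g[2] *= shape[i];
      g[3] = shape[dim2];
      for (int i = dim2 + 1; i < ndim; ++i) g[4] *= shape[i];
      return g;
    }

    int main() {
      // Shape (2,3,4,5,6,7) with axes 1 and 4 swapped groups as
      // (2, 3, 4*5, 6, 7) = (2, 3, 20, 6, 7).
      auto g = GroupToFive({2, 3, 4, 5, 6, 7}, 1, 4);
      std::printf("%lld %lld %lld %lld %lld\n",
                  static_cast<long long>(g[0]), static_cast<long long>(g[1]),
                  static_cast<long long>(g[2]), static_cast<long long>(g[3]),
                  static_cast<long long>(g[4]));
      return 0;
    }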

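The SwapAxis body removed here (and reintroduced further down the diff as a
free function) first normalizes negative axes NumPy-style and short-circuits
when the two axes coincide. A sketch of that normalization, assuming a plain
assert in place of MXNet's CHECK macro (NormalizeAxis is a hypothetical name):

    #include <cassert>

    // A negative axis counts from the end: axis -1 on a 4-d array means
    // axis 3.  The normalized result must land in [0, ndim).
    int NormalizeAxis(int axis, int ndim) {
      if (axis < 0) axis += ndim;
      assert(axis >= 0 && axis < ndim);  // the PR reports this via CHECK(...)
      return axis;
    }

When the two normalized axes are equal the operator degenerates to a copy
(or an accumulating copy under a kAddTo request), which is why the diff
handles that case before building any rank-5 view.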

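The exchange itself then views the output with groups 1 and 3 swapped
(inter_shape2) and assigns swapaxis<3, 1> of the input view. The loop below
spells out what that expression computes element by element; SwapGroups13 is
a hypothetical illustration, not MXNet code, and the real kernel evaluates
this lazily as an mshadow expression:

    #include <cstdint>
    #include <vector>

    // out has the grouped input shape s with entries 1 and 3 swapped:
    // out[a][d][c][b][e] = in[a][b][c][d][e].
    void SwapGroups13(const std::vector<float>& in, std::vector<float>* out,
                      const int64_t s[5]) {
      for (int64_t a = 0; a < s[0]; ++a)
        for (int64_t b = 0; b < s[1]; ++b)
          for (int64_t c = 0; c < s[2]; ++c)
            for (int64_t d = 0; d < s[3]; ++d)
              for (int64_t e = 0; e < s[4]; ++e) {
                int64_t src = (((a * s[1] + b) * s[2] + c) * s[3] + d) * s[4] + e;
                int64_t dst = (((a * s[3] + d) * s[2] + c) * s[1] + b) * s[4] + e;
                (*out)[dst] = in[src];  // a kAddTo request accumulates: += in[src]
              }
    }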

-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@mxnet.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org