You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2019/09/18 10:51:54 UTC

[GitHub] [incubator-mxnet] gyshi commented on a change in pull request #15902: Numpy add numpy op roll

gyshi commented on a change in pull request #15902: Numpy add numpy op roll
URL: https://github.com/apache/incubator-mxnet/pull/15902#discussion_r325607539
 
 

 ##########
 File path: src/operator/numpy/np_matrix_op.cc
 ##########
 @@ -345,5 +346,73 @@ Examples::
 .add_argument("data", "NDArray-or-Symbol[]", "List of arrays to stack")
 .add_arguments(StackParam::__FIELDS__());
 
+inline bool NumpyRollShape(const nnvm::NodeAttrs& attrs,
+                           mxnet::ShapeVector *in_attrs,
+                           mxnet::ShapeVector *out_attrs) {
+  using namespace mshadow;
+  const NumpyRollParam& param = nnvm::get<NumpyRollParam>(attrs.parsed);
+
+  if (!param.shift.has_value()) {
+    LOG(FATAL) << "roll missing 1 required positional argument: 'shift'.";
+  }
+  if (param.shift.value().ndim() > 1 &&
+      param.axis.has_value() &&
+      param.axis.value().ndim() != param.shift.value().ndim()) {
+    LOG(FATAL) << "shift and `axis` must be a tuple of the same size.";
+  }
+  if (!param.axis.has_value() && param.shift.has_value() && param.shift.value().ndim() > 1) {
+    LOG(FATAL) << "shift must be an int.";
+  }
+  if (param.axis.has_value()) {
+    mxnet::TShape axes(param.axis.value());
+    const index_t ndim = (*in_attrs)[0].ndim();
+    for (index_t i = 0; i < axes.ndim(); i++) {
+      if (axes[i] < 0) {
+        axes[i] += ndim;
+      }
+    }
+    std::sort(axes.begin(), axes.end());
+    for (index_t i = 1; i < axes.ndim(); i++) {
+      CHECK_LT(axes[i - 1], axes[i])
+        << "axes have duplicates " << axes;
+    }
+    CHECK_LT(axes[axes.ndim() - 1], ndim)
+      << "axis " << axes[axes.ndim() - 1]
+      << " Exceeds input dimensions " << (*in_attrs)[0];
+    CHECK_GE(axes[0], 0)
+      << "Reduction axis " << param.axis.value()
+      << " Exceeds input dimensions " << (*in_attrs)[0];
+  }
+  return ElemwiseShape<1, 1>(attrs, in_attrs, out_attrs);
+}
+
+NNVM_REGISTER_OP(_np_roll)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<NumpyRollParam>)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const NodeAttrs& attrs) {
+     return std::vector<std::string>{"data"};
+})
+.set_attr<mxnet::FInferShape>("FInferShape", NumpyRollShape)
+.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
+.set_attr<mxnet::FCompute>("FCompute<cpu>", NumpyRollCompute<cpu>)
+.set_attr<nnvm::FGradient>("FGradient",
+  [](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
+     const NumpyRollParam& param = nnvm::get<NumpyRollParam>(n->attrs.parsed);
+     std::ostringstream os1;
+     os1 << param.shift;
+     std::ostringstream os2;
+     os2 << param.axis;
+     return MakeNonlossGradNode("_np_roll", n, ograds, {},
+                                {{"shift", os1.str()}, {"axis", os2.str()}});
 
 Review comment:
   I think the value of backward is 1, or it has no backward.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services