Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/06/14 18:46:38 UTC

[GitHub] hcho3 closed pull request #11209: [MXNET-536] implement var/std operators

hcho3 closed pull request #11209: [MXNET-536] implement var/std operators
URL: https://github.com/apache/incubator-mxnet/pull/11209
This is a pull request merged from a forked repository. Because GitHub
hides the original diff once such a pull request is merged, it is
reproduced below for the sake of provenance.
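For orientation before reading the diff: the new operators compute
variance = E[(X - E[X])^2] and std = sqrt(variance), and the backward
kernels below (var_backward, std_backward) implement the corresponding
derivatives. A minimal NumPy sketch of the same math for the
reduce-over-all-elements case (illustrative only, not code from the PR):

    import numpy as np

    def var_forward(x):
        mean = x.mean()                # E[X]; kept as a hidden output
        residual = x - mean            # X - E[X]; kept as a hidden output
        return (residual ** 2).mean()  # E[(X - E[X])^2]

    def var_backward(x, ograd):
        # d var / d x_i = 2 * (x_i - E[X]) / N, chained with the incoming gradient
        return ograd * 2.0 * (x - x.mean()) / x.size

    def std_backward(x, ograd):
        # std = sqrt(var), so d std / d x_i = (x_i - E[X]) / (N * std(x))
        return ograd * (x - x.mean()) / (x.size * np.std(x))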

diff --git a/src/operator/tensor/broadcast_reduce_op.h b/src/operator/tensor/broadcast_reduce_op.h
index e50071bdab7..e2029f6fcb0 100644
--- a/src/operator/tensor/broadcast_reduce_op.h
+++ b/src/operator/tensor/broadcast_reduce_op.h
@@ -84,6 +84,18 @@ struct NormParam : public dmlc::Parameter<NormParam> {
   }
 };
 
+struct VarParam : public dmlc::Parameter<VarParam> {
+  dmlc::optional<TShape> axis;
+  bool keepdims;
+  DMLC_DECLARE_PARAMETER(VarParam) {
+    DMLC_DECLARE_FIELD(axis).set_default(dmlc::optional<TShape>())
+      .describe("The axis along which to perform the reduction");
+    DMLC_DECLARE_FIELD(keepdims).set_default(false)
+      .describe("If this is set to `True`, the reduced axis is left "
+                "in the result as dimension with size one.");
+  }
+};
+
 struct ReduceAxisParam : public dmlc::Parameter<ReduceAxisParam> {
   dmlc::optional<int> axis;
   bool keepdims;
@@ -292,6 +304,20 @@ inline bool NormShape(const nnvm::NodeAttrs& attrs,
   return true;
 }
 
+inline bool VarShape(const nnvm::NodeAttrs& attrs,
+                     std::vector<TShape>* in_attrs,
+                     std::vector<TShape>* out_attrs) {
+  CHECK_EQ(in_attrs->size(), 1U);
+  CHECK_EQ(out_attrs->size(), 3U);
+  const VarParam& param = nnvm::get<VarParam>(attrs.parsed);
+  const TShape oshape = ReduceAxesShapeImpl((*in_attrs)[0], param.axis,
+                                            param.keepdims, false);
+  SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape);
+  SHAPE_ASSIGN_CHECK(*out_attrs, 1, oshape);
+  SHAPE_ASSIGN_CHECK(*out_attrs, 2, (*in_attrs)[0]);
+  return true;
+}
+
 inline bool BroadcastAxesShape(const nnvm::NodeAttrs& attrs,
                                std::vector<TShape> *in_attrs,
                                std::vector<TShape> *out_attrs) {
@@ -767,6 +793,96 @@ void ReduceAxesBackwardUseInOutImpl(const OpContext& ctx,
   });
 }
 
+struct var_backward {
+  template <typename DType>
+  MSHADOW_XINLINE static DType Map(DType a, DType b) {
+    return 2.0 * (a - b);
+  }
+};
+
+template<typename xpu>
+void VarBackwardImpl(const OpContext& ctx,
+                     const TShape& small,
+                     const std::vector<TBlob>& inputs,
+                     const std::vector<OpReqType>& req,
+                     const std::vector<TBlob>& outputs) {
+  using namespace mshadow;
+  using namespace mshadow::expr;
+
+  TShape src_shape, dst_shape;
+  BroadcastReduceShapeCompact(outputs[0].shape_, small, &src_shape, &dst_shape);
+  Stream<xpu> *s = ctx.get_stream<xpu>();
+  MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+    if (dst_shape.ndim() == 2) {
+      Tensor<xpu, 2, DType> igrad =
+        outputs[0].get_with_shape<xpu, 2, DType>(src_shape.get<2>(), s);
+      Tensor<xpu, 2, DType> ograd =
+        inputs[0].get_with_shape<xpu, 2, DType>(dst_shape.get<2>(), s);
+      Tensor<xpu, 2, DType> data =
+        inputs[3].get_with_shape<xpu, 2, DType>(src_shape.get<2>(), s);
+      Tensor<xpu, 2, DType> hidden =
+        inputs[5].get_with_shape<xpu, 2, DType>(dst_shape.get<2>(), s);
+      ASSIGN_DISPATCH(igrad, req[0],
+        broadcast_to(ograd, src_shape)
+        * F<var_backward>(data, broadcast_to(hidden, src_shape)));
+      igrad /= scalar<DType>(src_shape.Size()/dst_shape.Size());
+    } else {
+      const int ndim = MXNET_SPECIAL_MAX_NDIM;
+      Tensor<xpu, ndim, DType> igrad =
+        outputs[0].get_with_shape<xpu, ndim, DType>(src_shape.get<ndim>(), s);
+      Tensor<xpu, ndim, DType> ograd =
+        inputs[0].get_with_shape<xpu, ndim, DType>(dst_shape.get<ndim>(), s);
+      Tensor<xpu, ndim, DType> data =
+        inputs[3].get_with_shape<xpu, ndim, DType>(src_shape.get<ndim>(), s);
+      Tensor<xpu, ndim, DType> hidden =
+        inputs[5].get_with_shape<xpu, ndim, DType>(dst_shape.get<ndim>(), s);
+      ASSIGN_DISPATCH(igrad, req[0],
+        broadcast_to(ograd, src_shape)
+        * F<var_backward>(data, broadcast_to(hidden, src_shape)));
+      igrad /= scalar<DType>(src_shape.Size()/dst_shape.Size());
+    }
+  });
+}
+
+struct std_backward {
+  template <typename DType>
+  MSHADOW_XINLINE static DType Map(DType a, DType b) {
+    return 0.5 * (a / b);
+  }
+};
+
+template<typename xpu>
+void StdBackwardImpl(const OpContext& ctx,
+                     const TShape& small,
+                     const std::vector<TBlob>& inputs,
+                     const std::vector<OpReqType>& req,
+                     const std::vector<TBlob>& outputs) {
+  using namespace mshadow;
+  using namespace mshadow::expr;
+
+  TShape src_shape, dst_shape;
+  BroadcastReduceShapeCompact(outputs[0].shape_, small, &src_shape, &dst_shape);
+  Stream<xpu> *s = ctx.get_stream<xpu>();
+  MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+    if (dst_shape.ndim() == 2) {
+      Tensor<xpu, 2, DType> igrad =
+        outputs[0].get_with_shape<xpu, 2, DType>(src_shape.get<2>(), s);
+      Tensor<xpu, 2, DType> out =
+        inputs[4].get_with_shape<xpu, 2, DType>(dst_shape.get<2>(), s);
+      ASSIGN_DISPATCH(igrad, req[0],
+        F<std_backward>(igrad, broadcast_to(out, src_shape)));
+    } else {
+      const int ndim = MXNET_SPECIAL_MAX_NDIM;
+      Tensor<xpu, ndim, DType> igrad =
+        outputs[0].get_with_shape<xpu, ndim, DType>(src_shape.get<ndim>(), s);
+      Tensor<xpu, ndim, DType> out =
+        inputs[4].get_with_shape<xpu, ndim, DType>(dst_shape.get<ndim>(), s);
+      ASSIGN_DISPATCH(igrad, req[0],
+        F<std_backward>(igrad, broadcast_to(out, src_shape)));
+    }
+  });
+}
+
 // works when shape inference of output is given
 template<typename xpu, typename OP, bool normalize = false>
 void ReduceAxesBackwardUseInOut(const nnvm::NodeAttrs& attrs,
@@ -1004,6 +1120,112 @@ void L2NormCompute(const nnvm::NodeAttrs& attrs,
   SqRootForL2<xpu>(ctx, req[0], outputs[0]);
 }
 
+template <typename xpu>
+void VarCompute(const nnvm::NodeAttrs& attrs,
+                const OpContext& ctx,
+                const std::vector<TBlob>& inputs,
+                const std::vector<OpReqType>& req,
+                const std::vector<TBlob>& outputs) {
+  CHECK_EQ(inputs.size(), 1U);
+  CHECK_EQ(outputs.size(), 3U);
+  CHECK_EQ(req.size(), 3U);
+  const VarParam& param = nnvm::get<VarParam>(attrs.parsed);
+  if (req[0] == kNullOp) return;
+
+  TShape small;
+  if (param.keepdims) {
+    small = outputs[0].shape_;
+  } else {
+    small = ReduceAxesShapeImpl(inputs[0].shape_, param.axis, true, false);
+  }
+  // Compute mean of elements (E[X])
+  ReduceAxesComputeImpl<xpu, mshadow::red::sum, true, mshadow_op::identity>(
+    ctx, inputs, req, {outputs[1]}, small);
+  // Compute X - E[X]
+  TShape src_shape, dst_shape;
+  BroadcastReduceShapeCompact(inputs[0].shape_, small, &src_shape, &dst_shape);
+  using namespace mxnet_op;
+  using namespace mshadow;
+  using namespace mshadow::expr;
+  mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
+  MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+    if (dst_shape.ndim() == 2) {
+      Tensor<xpu, 2, DType> data =
+        inputs[0].get_with_shape<xpu, 2, DType>(src_shape.get<2>(), s);
+      Tensor<xpu, 2, DType> hidden_mean =
+        outputs[1].get_with_shape<xpu, 2, DType>(dst_shape.get<2>(), s);
+      Tensor<xpu, 2, DType> hidden_residual =
+        outputs[2].get_with_shape<xpu, 2, DType>(src_shape.get<2>(), s);
+      ASSIGN_DISPATCH(hidden_residual, req[0],
+        F<mshadow_op::minus>(data, broadcast_to(hidden_mean, src_shape)));
+    } else {
+      const int ndim = MXNET_SPECIAL_MAX_NDIM;
+      Tensor<xpu, ndim, DType> data =
+        inputs[0].get_with_shape<xpu, ndim, DType>(src_shape.get<ndim>(), s);
+      Tensor<xpu, ndim, DType> hidden_mean =
+        outputs[1].get_with_shape<xpu, ndim, DType>(dst_shape.get<ndim>(), s);
+      Tensor<xpu, ndim, DType> hidden_residual =
+        outputs[2].get_with_shape<xpu, ndim, DType>(src_shape.get<ndim>(), s);
+      ASSIGN_DISPATCH(hidden_residual, req[0],
+        F<mshadow_op::minus>(data, broadcast_to(hidden_mean, src_shape)));
+    }
+  });
+  // Compute var = E[(X - E[X])^2]
+  ReduceAxesComputeImpl<xpu, mshadow::red::sum, true, mshadow_op::square>(
+    ctx, {outputs[2]}, req, {outputs[0]}, small);
+}
+
+template <typename xpu>
+void StdCompute(const nnvm::NodeAttrs& attrs,
+                const OpContext& ctx,
+                const std::vector<TBlob>& inputs,
+                const std::vector<OpReqType>& req,
+                const std::vector<TBlob>& outputs) {
+  VarCompute<xpu>(attrs, ctx, inputs, req, outputs);
+  SqRootForL2<xpu>(ctx, req[0], outputs[0]);
+}
+
+template<typename xpu>
+void VarGradCompute(const nnvm::NodeAttrs& attrs,
+                    const OpContext& ctx,
+                    const std::vector<TBlob>& inputs,
+                    const std::vector<OpReqType>& req,
+                    const std::vector<TBlob>& outputs) {
+  CHECK_EQ(inputs.size(), 7U);
+  CHECK_EQ(outputs.size(), 1U);
+  CHECK_EQ(req.size(), 1U);
+  if (req[0] == kNullOp) return;
+  const VarParam& param = nnvm::get<VarParam>(attrs.parsed);
+  TShape small;
+  if (param.keepdims) {
+    small = inputs[0].shape_;
+  } else {
+    small = ReduceAxesShapeImpl(outputs[0].shape_, param.axis, true, false);
+  }
+  VarBackwardImpl<xpu>(ctx, small, inputs, req, outputs);
+}
+
+template<typename xpu>
+void StdGradCompute(const nnvm::NodeAttrs& attrs,
+                    const OpContext& ctx,
+                    const std::vector<TBlob>& inputs,
+                    const std::vector<OpReqType>& req,
+                    const std::vector<TBlob>& outputs) {
+  CHECK_EQ(inputs.size(), 7U);
+  CHECK_EQ(outputs.size(), 1U);
+  CHECK_EQ(req.size(), 1U);
+  if (req[0] == kNullOp) return;
+  const VarParam& param = nnvm::get<VarParam>(attrs.parsed);
+  TShape small;
+  if (param.keepdims) {
+    small = inputs[0].shape_;
+  } else {
+    small = ReduceAxesShapeImpl(outputs[0].shape_, param.axis, true, false);
+  }
+  VarBackwardImpl<xpu>(ctx, small, inputs, req, outputs);
+  StdBackwardImpl<xpu>(ctx, small, inputs, req, outputs);
+}
+
 template<typename xpu>
 void L2NormGradCompute(const nnvm::NodeAttrs& attrs,
                        const OpContext& ctx,
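
A note on the backward plumbing above: the forward ops register one input
and three outputs and take their gradients through ElemwiseGradUseInOut
(see the registrations in the next file), so each backward op receives
seven blobs: the three output gradients (inputs[0..2]), the data
(inputs[3]), and the three forward outputs (inputs[4..6]). That is why
VarBackwardImpl reads the data from inputs[3] and the stored mean from
inputs[5], while StdBackwardImpl reads the std output from inputs[4].
StdGradCompute simply runs VarBackwardImpl and then rescales by 0.5/std,
which is the chain rule through the square root; a quick NumPy check of
that composition (illustrative only, not code from the PR):

    import numpy as np

    x = np.random.rand(4, 5)
    var_grad = 2.0 * (x - x.mean()) / x.size  # what VarBackwardImpl writes for ograd = 1
    std_grad = var_grad * 0.5 / np.std(x)     # what std_backward's 0.5 * (a / b) then applies
    # finite-difference check of d std / d x[0, 0]
    eps = 1e-6
    xp = x.copy()
    xp[0, 0] += eps
    assert abs((np.std(xp) - np.std(x)) / eps - std_grad[0, 0]) < 1e-4
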
diff --git a/src/operator/tensor/broadcast_reduce_op_value.cc b/src/operator/tensor/broadcast_reduce_op_value.cc
index 7bcc3e97e8e..9f3afab7483 100644
--- a/src/operator/tensor/broadcast_reduce_op_value.cc
+++ b/src/operator/tensor/broadcast_reduce_op_value.cc
@@ -28,6 +28,7 @@ namespace mxnet {
 namespace op {
 DMLC_REGISTER_PARAMETER(ReduceAxesParam);
 DMLC_REGISTER_PARAMETER(NormParam);
+DMLC_REGISTER_PARAMETER(VarParam);
 DMLC_REGISTER_PARAMETER(ReduceAxisParam);
 DMLC_REGISTER_PARAMETER(BroadcastAxesParam);
 DMLC_REGISTER_PARAMETER(BroadcastToParam);
@@ -324,6 +325,77 @@ NNVM_REGISTER_OP(_backward_norm)
   })
 .set_attr<FCompute>("FCompute<cpu>", L2NormGradCompute<cpu>);
 
+NNVM_REGISTER_OP(variance)
+.describe(R"code(Computes the variance on an NDArray.
+
+This operator computes the variance on an NDArray with the specified axis.
+By default, it computes the variance on the entire array.
+
+Examples::
+
+  x = [[1, 2],
+       [3, 4]]
+
+  variance(x) = [1.25]
+
+)code" ADD_FILELINE)
+.set_num_inputs(1)
+.set_num_outputs(3)
+.set_attr_parser(ParamParser<VarParam>)
+.set_attr<nnvm::FInferShape>("FInferShape", VarShape)
+.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 3>)
+.set_attr<nnvm::FNumVisibleOutputs>("FNumVisibleOutputs",
+  [](const NodeAttrs& attrs) { return 1; })
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseInOut{"_backward_var"})
+.set_attr<FResourceRequest>("FResourceRequest",
+  [](const NodeAttrs& attrs) {
+    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+  })
+.set_attr<FCompute>("FCompute<cpu>", VarCompute<cpu>)
+.add_argument("data", "NDArray-or-Symbol", "The input")
+.add_arguments(VarParam::__FIELDS__());
+
+NNVM_REGISTER_OP(_backward_var)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<VarParam>)
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_attr<FCompute>("FCompute<cpu>", VarGradCompute<cpu>);
+
+NNVM_REGISTER_OP(std)
+.describe(R"code(Computes the standard deviation on an NDArray.
+
+This operator computes the standard deviation on an NDArray with the specified
+axis. By default, it computes the standard deviation on the entire array.
+
+Examples::
+
+  x = [[1, 2],
+       [3, 4]]
+
+  std(x) = [1.118]
+
+)code" ADD_FILELINE)
+.set_num_inputs(1)
+.set_num_outputs(3)
+.set_attr_parser(ParamParser<VarParam>)
+.set_attr<nnvm::FInferShape>("FInferShape", VarShape)
+.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 3>)
+.set_attr<nnvm::FNumVisibleOutputs>("FNumVisibleOutputs",
+  [](const NodeAttrs& attrs) { return 1; })
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseInOut{"_backward_std"})
+.set_attr<FResourceRequest>("FResourceRequest",
+  [](const NodeAttrs& attrs) {
+    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+  })
+.set_attr<FCompute>("FCompute<cpu>", StdCompute<cpu>)
+.add_argument("data", "NDArray-or-Symbol", "The input")
+.add_arguments(VarParam::__FIELDS__());
+
+NNVM_REGISTER_OP(_backward_std)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<VarParam>)
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_attr<FCompute>("FCompute<cpu>", StdGradCompute<cpu>);
 
 }  // namespace op
 }  // namespace mxnet
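
With the CPU registrations above in place, the front end exposes the
operators as mx.nd.variance / mx.nd.std (and mx.sym.variance /
mx.sym.std, as used in the tests further below), with only the first of
the three outputs visible to the user. A usage sketch, assuming axis
accepts an int or tuple as declared by VarParam (illustrative only, not
an example from the PR):

    import mxnet as mx

    x = mx.nd.array([[1., 2.], [3., 4.]])
    print(mx.nd.variance(x))                    # [1.25], as in the docstring example
    print(mx.nd.std(x))                         # [1.118...]
    print(mx.nd.variance(x, axis=1))            # per-row variance: [0.25, 0.25]
    print(mx.nd.std(x, axis=0, keepdims=True))  # [[1., 1.]], shape (1, 2)
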
diff --git a/src/operator/tensor/broadcast_reduce_op_value.cu b/src/operator/tensor/broadcast_reduce_op_value.cu
index f7fba682878..375f853d10a 100644
--- a/src/operator/tensor/broadcast_reduce_op_value.cu
+++ b/src/operator/tensor/broadcast_reduce_op_value.cu
@@ -107,5 +107,17 @@ NNVM_REGISTER_OP(norm)
 NNVM_REGISTER_OP(_backward_norm)
 .set_attr<FCompute>("FCompute<gpu>", L2NormGradCompute<gpu>);
 
+NNVM_REGISTER_OP(variance)
+.set_attr<FCompute>("FCompute<gpu>", VarCompute<gpu>);
+
+NNVM_REGISTER_OP(_backward_var)
+.set_attr<FCompute>("FCompute<gpu>", VarGradCompute<gpu>);
+
+NNVM_REGISTER_OP(std)
+.set_attr<FCompute>("FCompute<gpu>", StdCompute<gpu>);
+
+NNVM_REGISTER_OP(_backward_std)
+.set_attr<FCompute>("FCompute<gpu>", StdGradCompute<gpu>);
+
 }  // namespace op
 }  // namespace mxnet
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index ab03973e8e8..2c3560f6f0f 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -5691,6 +5691,59 @@ def test_softmax():
     check_softmax_grad(default_context())
     check_smoothed_softmax_grad(default_context())
 
+def test_variance():
+  def true_var(x):
+    if len(x.shape) == 1:
+      return np.var(x, keepdims=True)
+    else:
+      return np.var(x)
+  def true_var_grad(x, ograd):
+    return 2 * (x - np.mean(x)) * ograd / np.prod(x.shape)
+
+  for ndim in range(1, 6):
+    # check forward
+    shape = rand_shape_nd(ndim, 5)
+    data = rand_ndarray(shape=shape, stype='default')
+    data_np = data.asnumpy()
+    expected = true_var(data_np)
+    output = mx.nd.variance(data)
+    assert_almost_equal(output.asnumpy(), expected)
+
+    # check backward
+    data = mx.sym.Variable('data')
+    var_sym = mx.sym.variance(data=data)
+    check_numeric_gradient(var_sym, [data_np], atol=1e-3)
+    ograd = np.random.random(size=output.shape)
+    check_symbolic_backward(var_sym, [data_np], [ograd],
+      [true_var_grad(data_np, ograd)], atol=1e-8)
+
+def test_std():
+  def true_std(x):
+    if len(x.shape) == 1:
+      return np.std(x, keepdims=True)
+    else:
+      return np.std(x)
+  def true_std_grad(x, ograd):
+    return (x - np.mean(x)) * ograd / true_std(x) / np.prod(x.shape)
+
+  for ndim in range(1, 6):
+    # check forward
+    shape = rand_shape_nd(ndim, 5)
+    while np.prod(shape) == 1:  # avoid single-element arrays (their std is 0)
+      shape = rand_shape_nd(ndim, 5)
+    data = rand_ndarray(shape=shape, stype='default')
+    data_np = data.asnumpy()
+    expected = true_std(data_np)
+    output = mx.nd.std(data)
+    assert_almost_equal(output.asnumpy(), expected)
+
+    # check backward
+    data = mx.sym.Variable('data')
+    std_sym = mx.sym.std(data=data)
+    check_numeric_gradient(std_sym, [data_np], atol=1e-3)
+    ograd = np.random.random(size=output.shape)
+    check_symbolic_backward(std_sym, [data_np], [ograd],
+      [true_std_grad(data_np, ograd)], atol=1e-8)
 
 @with_seed()
 def test_slice():


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services