You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/01/31 20:30:04 UTC

[GitHub] piiswrong closed pull request #8254: Fix softmax_cross_entropy

piiswrong closed pull request #8254: Fix softmax_cross_entropy
URL: https://github.com/apache/incubator-mxnet/pull/8254
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (GitHub would not otherwise display it once the pull request is merged):

diff --git a/src/operator/loss_binary_op-inl.h b/src/operator/loss_binary_op-inl.h
index 1362997231..2beb70c141 100644
--- a/src/operator/loss_binary_op-inl.h
+++ b/src/operator/loss_binary_op-inl.h
@@ -33,7 +33,7 @@
 namespace mxnet {
 namespace op {
 
-// return a shape of scalar
+// return a shape of batch_size
 inline bool SoftmaxCrossEntropyShape(const nnvm::NodeAttrs& attrs,
                                      std::vector<TShape> *in_attrs,
                                      std::vector<TShape> *out_attrs) {
@@ -43,7 +43,7 @@ inline bool SoftmaxCrossEntropyShape(const nnvm::NodeAttrs& attrs,
       << "SoftmaxCrossEntropy only accept 1D label";
   CHECK_EQ((*in_attrs)[0][0], (*in_attrs)[1][0])
       << "SoftmaxCrossEntropy: data label shape mismatch";
-  SHAPE_ASSIGN_CHECK(*out_attrs, 0, TShape(1));
+  SHAPE_ASSIGN_CHECK(*out_attrs, 0, TShape(mshadow::Shape1((*in_attrs)[0][0])));
   return true;
 }
 
@@ -66,18 +66,17 @@ void SoftmaxCrossEntropyForward(const nnvm::NodeAttrs& attrs,
     mshadow::Tensor<xpu, 1, DType> workspace = ctx.requested[0].get_space_typed<xpu, 1, DType>(
         mshadow::Shape1(mdata.shape_.Size() + mlabel.size(0)), s);
     mshadow::Tensor<xpu, 2, DType> temp1(workspace.dptr_, mdata.shape_, s);
-    mshadow::Tensor<xpu, 2, DType> temp2(workspace.dptr_ + mdata.shape_.Size(),
-        mshadow::Shape2(1, mlabel.size(0)), s);
+    mshadow::Tensor<xpu, 1, DType> temp2(workspace.dptr_ + mdata.shape_.Size(),
+        mshadow::Shape1(mlabel.size(0)), s);
     // calculate softmax on temp
     // TODO(tqchen): change to SoftmaxLog later
     mshadow::Softmax(temp1, mdata);
     // choose the softmax rows
-    mshadow::Tensor<xpu, 1, DType> tdst = temp2[0];
-    tdst = F<mshadow_op::negation>(
+    temp2 = F<mshadow_op::negation>(
         F<mshadow_op::log>(
             F<mshadow_op::maximum>(mat_choose_row_element(temp1, mlabel),
                                    scalar<DType>(1e-8f))));
-    ASSIGN_DISPATCH(out, req[0], sumall_except_dim<0>(temp2));
+    ASSIGN_DISPATCH(out, req[0], F<mshadow_op::identity>(temp2));
   });
 }
 
@@ -100,7 +99,7 @@ void SoftmaxCrossEntropyBackward(const nnvm::NodeAttrs& attrs,
         mdata.shape_, s);
     mshadow::Softmax(temp, mdata);
     mshadow::SoftmaxGrad(temp, temp, mlabel);
-    ASSIGN_DISPATCH(mdata_grad, req[0], broadcast_scalar(mscale, temp.shape_) * temp);
+    ASSIGN_DISPATCH(mdata_grad, req[0], broadcast<0>(mscale, temp.shape_) * temp);
   });
 }
 
diff --git a/src/operator/loss_binary_op.cc b/src/operator/loss_binary_op.cc
index c1fedb3de6..9bfea76258 100644
--- a/src/operator/loss_binary_op.cc
+++ b/src/operator/loss_binary_op.cc
@@ -59,6 +59,10 @@ Example::
 )code" ADD_FILELINE)
 .set_num_inputs(2)
 .set_num_outputs(1)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const NodeAttrs& attrs) {
+    return std::vector<std::string>{"data", "label"};
+  })
 .set_attr<nnvm::FInferShape>("FInferShape", SoftmaxCrossEntropyShape)
 .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>)
 .set_attr<FResourceRequest>("FResourceRequest",
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 19c4e65d3d..2529a7022a 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -333,6 +333,40 @@ def check_softmax_with_shape(shape, xpu, preserve_shape=False):
     assert_almost_equal(grad.asnumpy(), np_softmax(x.asnumpy()) - l.asnumpy(), rtol=1e-4)
 
 
+def test_softmax_cross_entropy():
+    batch_size = 16
+    num_class = 20
+
+    data = mx.sym.Variable('data')
+    label = mx.sym.Variable('softmax_label')
+
+    ce = mx.sym.softmax_cross_entropy(data=data, label=label)
+    out = mx.sym.make_loss(ce, name='softmax')
+
+    data_shape = (batch_size, num_class)
+    label_shape = (batch_size,)
+    mod = mx.mod.Module(symbol=out)
+    mod.bind(data_shapes=[('data', data_shape)],
+             label_shapes=[('softmax_label', label_shape)],
+             inputs_need_grad=True)
+    mod.init_params()
+    mod.init_optimizer(optimizer_params={'learning_rate': 0.01})
+
+    data_array = mx.nd.random.uniform(0, 10, shape=data_shape)
+    label_array = mx.nd.array([i for i in range(batch_size)])
+
+    mod.forward(mx.io.DataBatch(data=[data_array],
+                                label=[label_array]))
+    assert mod.get_outputs()[0].shape == (data_shape[0],)
+    expected_out = -mx.nd.pick(mx.nd.log_softmax(data_array), label_array, axis=1)
+    assert_almost_equal(mod.get_outputs()[0].asnumpy(), expected_out.asnumpy())
+
+    mod.backward()
+    grad = mod.get_input_grads()
+    expected_grad = mx.nd.softmax(data_array) - mx.nd.one_hot(label_array, depth=num_class)
+    assert_almost_equal(grad[0].asnumpy(), expected_grad.asnumpy())
+
+
 def test_python_op():
     X = mx.symbol.Variable('X')
     op = mx.operator.NumpyOp()


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services