Posted to commits@mxnet.apache.org by ap...@apache.org on 2019/07/03 05:54:04 UTC
[incubator-mxnet] branch master updated: [MXNET-978] Higher order
gradient for sigmoid (#15288)
This is an automated email from the ASF dual-hosted git repository.
apeforest pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 6a8d9eb [MXNET-978] Higher order gradient for sigmoid (#15288)
6a8d9eb is described below
commit 6a8d9eb5fd4f7133c094149dc80a3a236534f223
Author: Lin Yuan <ap...@gmail.com>
AuthorDate: Tue Jul 2 22:53:39 2019 -0700
[MXNET-978] Higher order gradient for sigmoid (#15288)
* try to add support for some ops
* add unit test for second order grad
* implement grad for relu and add unit test
* fix lint
* register FGradient attribute for backward relu
* resolve conflict
* remove unused imports
* change gradient using set_attr
* remove higher order grad test for negative(x)
* fix lint
* reverse indent
* remove unused backward operator
* refactor backward for sin(x) and cos(x)
* change value init to list init
* change to list initialization
* generate random shape in test
* fix a bug in second order backward
* fix lint
* fix lint
* address reviewer comment and renaming
* test 2nd order gradient for sigmoid
* higher order grads for sigmoid
* add unit test
* remove blank lines
* update test
* fix lint
* fix third order gradient for sigmoid
---
src/common/exec_utils.h | 5 ++---
src/imperative/imperative.cc | 4 ++++
src/operator/tensor/elemwise_unary_op_basic.cc | 30 ++++++++++++++++++++++++-
src/operator/tensor/elemwise_unary_op_trig.cc | 4 ++--
tests/python/unittest/test_higher_order_grad.py | 17 ++++++++++++++
5 files changed, 54 insertions(+), 6 deletions(-)
diff --git a/src/common/exec_utils.h b/src/common/exec_utils.h
index 0551b42..d8b7a33 100644
--- a/src/common/exec_utils.h
+++ b/src/common/exec_utils.h
@@ -286,7 +286,6 @@ inline void LogMemoryPlan(const nnvm::Graph& g) {
const auto &idx = g.indexed_graph();
const auto& vshape = g.GetAttr<mxnet::ShapeVector>("shape");
const auto& vtype = g.GetAttr<nnvm::DTypeVector>("dtype");
- const auto& vstorage = g.GetAttr<nnvm::StorageVector>("storage_id");
// find node range
uint32_t node_start = 0, node_end = idx.num_nodes();
if (g.attrs.count("node_range")) {
@@ -304,13 +303,13 @@ inline void LogMemoryPlan(const nnvm::Graph& g) {
auto eid = idx.entry_id(e);
size_t kilo_bytes = vshape[eid].Size() * mshadow::mshadow_sizeof(vtype[eid]) / 1024;
LOG(INFO) << "\t\tinput " << eid << ": " << vshape[eid] << " ("
- << kilo_bytes << " KB) -> " << storage_str(vstorage[eid]);
+ << kilo_bytes << " KB)";
}
for (uint32_t index = 0; index < inode.source->num_outputs(); ++index) {
uint32_t eid = idx.entry_id(nid, index);
size_t kilo_bytes = vshape[eid].Size() * mshadow::mshadow_sizeof(vtype[eid]) / 1024;
LOG(INFO) << "\t\toutput " << eid << ": " << vshape[eid] << " ("
- << kilo_bytes << " KB) -> " << storage_str(vstorage[eid]);
+ << kilo_bytes << " KB)";
}
}
}
diff --git a/src/imperative/imperative.cc b/src/imperative/imperative.cc
index d8fba1c..e2c0c9d 100644
--- a/src/imperative/imperative.cc
+++ b/src/imperative/imperative.cc
@@ -501,6 +501,10 @@ std::vector<NDArray*> Imperative::Backward(
}
}
+ if (dmlc::GetEnv("MXNET_MEM_PLAN_VERBOSE_LOGGING", false)) {
+ common::LogMemoryPlan(graph);
+ }
+
// Execution
bool prev_recording = set_is_recording(create_graph);
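With this check in place, memory-plan logging during Imperative::Backward is opt-in via the MXNET_MEM_PLAN_VERBOSE_LOGGING environment variable shown above. A minimal sketch of turning it on from Python (the variable name comes from the diff; the autograd calls are only illustrative and not part of this patch):

    import os
    # dmlc::GetEnv reads the process environment when Backward runs, so setting
    # the variable before the backward call is sufficient.
    os.environ["MXNET_MEM_PLAN_VERBOSE_LOGGING"] = "1"

    import mxnet as mx
    from mxnet import nd, autograd

    x = nd.ones((2, 3))
    x.attach_grad()
    with autograd.record():
        y = nd.sigmoid(x)
    y.backward()  # Imperative::Backward now calls common::LogMemoryPlan(graph)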
diff --git a/src/operator/tensor/elemwise_unary_op_basic.cc b/src/operator/tensor/elemwise_unary_op_basic.cc
index 98dc8da..26c7408 100644
--- a/src/operator/tensor/elemwise_unary_op_basic.cc
+++ b/src/operator/tensor/elemwise_unary_op_basic.cc
@@ -121,7 +121,35 @@ The storage type of ``sigmoid`` output is always dense
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseOut{"_backward_sigmoid"});
MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU(_backward_sigmoid,
- unary_bwd<mshadow_op::sigmoid_grad>);
+ unary_bwd<mshadow_op::sigmoid_grad>)
+.set_attr<nnvm::FGradient>("FGradient",
+ [](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
+ // n->inputs[0] : y_grad
+ // n->inputs[1] : f(x) = sigmoid(x)
+ // ograds[0] : head_grads
+ // f''(x) = f'(x) * (1 - 2*f(x))
+ // NodeEntry{n} : y_grad * f'(x)
+ auto ones = MakeNode("ones_like", n->attrs.name + "_grad_ones", {n->inputs[1]}, nullptr, &n);
+ const std::unordered_map<std::string, std::string> args = {{"scalar", "2.0"}};
+ auto two_y = MakeNode("_mul_scalar", n->attrs.name + "_mul_two", {n->inputs[1]}, &args, &n);
+ auto one_minus_two_y = MakeNode("elemwise_sub", n->attrs.name + "_grad_sub",
+ {nnvm::NodeEntry{ones}, nnvm::NodeEntry{two_y}}, nullptr, &n);
+ auto grad_grad_mid = MakeNode("elemwise_mul", n->attrs.name + "_grad_mul",
+ {n->inputs[0], nnvm::NodeEntry{one_minus_two_y}}, nullptr, &n);
+ auto dydx = MakeNode("elemwise_div", n->attrs.name + "_grad_div",
+ {nnvm::NodeEntry{n}, n->inputs[0]}, nullptr, &n);
+
+ // When building the gradient graph, the backward node of n->inputs[1] will be
+ // added to the graph again, so f'(x) gets multiplied in once more.
+ std::vector<nnvm::NodeEntry> ret;
+ ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "backward_grad_grad",
+ {ograds[0], nnvm::NodeEntry{dydx}}, nullptr, &n));
+ ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "backward_grad_grad_in",
+ {ograds[0], nnvm::NodeEntry{grad_grad_mid}}, nullptr, &n));
+ return ret;
+ });
+
+
DMLC_REGISTER_PARAMETER(HardSigmoidParam);
MXNET_OPERATOR_REGISTER_UNARY(hard_sigmoid)
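For reference, the identities the new FGradient lambda encodes, with f(x) = sigmoid(x) = 1 / (1 + exp(-x)) (standard sigmoid calculus in the same ASCII notation as the comments above; not itself part of the patch):

    f'(x)   = f(x) * (1 - f(x))
    f''(x)  = f'(x) * (1 - 2*f(x))
    f'''(x) = f''(x) * (1 - 2*f(x)) - 2 * f'(x)^2

The lambda returns two entries: ograds[0] * f'(x) for the gradient with respect to the incoming head gradient y_grad (f'(x) is recovered as NodeEntry{n} / y_grad), and ograds[0] * y_grad * (1 - 2*f(x)) for the gradient with respect to f(x); differentiating through f(x)'s own backward node contributes the remaining f'(x) factor, which is how the product matches f''(x).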
diff --git a/src/operator/tensor/elemwise_unary_op_trig.cc b/src/operator/tensor/elemwise_unary_op_trig.cc
index b7cf76e..13410e9 100644
--- a/src/operator/tensor/elemwise_unary_op_trig.cc
+++ b/src/operator/tensor/elemwise_unary_op_trig.cc
@@ -49,7 +49,7 @@ The storage type of ``sin`` output depends upon the input storage type:
MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_DR(_backward_sin, unary_bwd<mshadow_op::sin_grad>)
.set_attr<nnvm::FGradient>("FGradient",
[](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
- // ograds[0]: d^2L/dx^2
+ // ograds[0]: head_grad_grads (dL/dx_grad)
// inputs[0]: dL/dy
// inputs[1]: x (ElemwiseUseIn)
// f(x) = sin(x)
@@ -92,7 +92,7 @@ The storage type of ``cos`` output is always dense
MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU(_backward_cos, unary_bwd<mshadow_op::cos_grad>)
.set_attr<nnvm::FGradient>("FGradient",
[](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
- // ograds[0]: d^2L/dx^2
+ // ograds[0]: head_grad_grads (dL/dx_grad)
// inputs[0]: dL/dy
// inputs[1]: x (ElemwiseUseIn)
// f(x) = cos(x)
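The comment change clarifies what ograds[0] carries in these second-order lambdas: it is the gradient flowing into the first-order backward node (head_grad_grads), not d^2L/dx^2. Taking _backward_sin as an example, its output is dL/dx = dL/dy * cos(x), so the quantities its FGradient scales by ograds[0] are (standard calculus, not code from this patch):

    d(dL/dx) / d(dL/dy) = cos(x)
    d(dL/dx) / dx       = -dL/dy * sin(x)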
diff --git a/tests/python/unittest/test_higher_order_grad.py b/tests/python/unittest/test_higher_order_grad.py
index 4f1ea9a..ad14c50 100644
--- a/tests/python/unittest/test_higher_order_grad.py
+++ b/tests/python/unittest/test_higher_order_grad.py
@@ -106,6 +106,23 @@ def test_log10():
check_second_order_unary(array, log10, grad_grad_op)
+@with_seed()
+def test_sigmoid():
+ def sigmoid(x):
+ return nd.sigmoid(x)
+
+ def grad_op(x):
+ return sigmoid(x) * (1 - sigmoid(x))
+
+ def grad_grad_op(x):
+ return grad_op(x) * (1 - 2 * sigmoid(x))
+
+ for dim in range(1, 5):
+ shape = rand_shape_nd(dim)
+ array = random_arrays(shape)
+ check_second_order_unary(array, sigmoid, grad_grad_op)
+
+
def check_second_order_unary(x, op, grad_grad_op):
x = nd.array(x)
grad_grad_x = grad_grad_op(x)
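The diff truncates check_second_order_unary, so as a standalone illustration of what the new test exercises, the second-order path can be driven directly with autograd (a minimal sketch, not the literal helper):

    import mxnet as mx
    from mxnet import nd, autograd

    x = nd.random.normal(shape=(3, 4))
    sig = nd.sigmoid(x)
    expected = sig * (1 - sig) * (1 - 2 * sig)  # f''(x) = f'(x) * (1 - 2*f(x))

    x.attach_grad()
    with autograd.record():
        y = nd.sigmoid(x)
        # first-order gradient dy/dx, kept differentiable so it can be backpropagated through
        y_grad = autograd.grad(y, x, create_graph=True, retain_graph=True)[0]
    y_grad.backward()  # second backward pass
    # x.grad now holds f''(x) elementwise and should match `expected`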