You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ha...@apache.org on 2018/12/04 17:48:58 UTC

[incubator-mxnet] branch master updated: [MXNET-1234] Fix shape inference problems in Activation backward (#13409)

This is an automated email from the ASF dual-hosted git repository.

haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 7dde0eb  [MXNET-1234] Fix shape inference problems in Activation backward (#13409)
7dde0eb is described below

commit 7dde0eb0e4dc910beabc023b45317bdb82d52a0f
Author: Pedro Larroy <92...@users.noreply.github.com>
AuthorDate: Tue Dec 4 18:48:39 2018 +0100

    [MXNET-1234] Fix shape inference problems in Activation backward (#13409)
    
    * Provide a failing test for ReLU activation shape inference bug
    
    * Fix Activation backward shape inference
    
    fixes: #13333
    
    * Add softsign Activation to test_gluon.py
    
    * Use activation on GPU if we are using CUDNN and not MKLDNN, as is happening right now
    
    * Don't disable MKLDNN
---
 src/operator/elemwise_op_common.h     | 20 +++++----
 src/operator/nn/activation-inl.h      | 12 +++---
 src/operator/nn/activation.cc         | 79 +++++++++++++++++++++--------------
 src/operator/nn/activation.cu         | 30 +++++++------
 tests/cpp/operator/activation_perf.cc | 26 +++++++++---
 tests/python/unittest/test_gluon.py   | 12 +++---
 6 files changed, 109 insertions(+), 70 deletions(-)

diff --git a/src/operator/elemwise_op_common.h b/src/operator/elemwise_op_common.h
index 4b8663b..e622ce2 100644
--- a/src/operator/elemwise_op_common.h
+++ b/src/operator/elemwise_op_common.h
@@ -128,29 +128,33 @@ inline bool ElemwiseAttr(const nnvm::NodeAttrs& attrs,
   if (n_out != -1)
     out_size = static_cast<size_t>(n_out);
 
-  auto deduce = [&](std::vector<AttrType> *vec, size_t size, const char *name) {
+  CHECK_LE(in_size, in_attrs->size());
+  CHECK_LE(out_size, out_attrs->size());
+  auto deduce = [&](const std::vector<AttrType>& vec, size_t size, const char *name) {
       for (size_t i = 0; i < size; ++i) {
-        CHECK(assign(&dattr, (*vec)[i]))
+        CHECK(assign(&dattr, vec.at(i)))
           << "Incompatible attr in node " << attrs.name << " at " << i << "-th "
           << name << ": " << "expected " << attr_string(dattr)
-          << ", got " << attr_string((*vec)[i]);
+          << ", got " << attr_string(vec.at(i));
       }
     };
-  deduce(in_attrs, in_size, "input");
-  if (reverse_infer) deduce(out_attrs, out_size, "output");
+  deduce(*in_attrs, in_size, "input");
+  if (reverse_infer)
+      deduce(*out_attrs, out_size, "output");
 
   auto write = [&](std::vector<AttrType> *vec, size_t size, const char *name) {
       for (size_t i = 0; i < size; ++i) {
-        CHECK(assign(&(*vec)[i], dattr))
+        CHECK(assign(&(vec->at(i)), dattr))
           << "Incompatible attr in node " << attrs.name << " at " << i << "-th "
           << name << ": " << "expected " << attr_string(dattr)
-          << ", got " << attr_string((*vec)[i]);
+          << ", got " << attr_string(vec->at(i));
       }
     };
   write(in_attrs, in_size, "input");
   write(out_attrs, out_size, "output");
 
-  if (is_none(dattr)) return false;
+  if (is_none(dattr))
+      return false;
   return true;
 }
 
diff --git a/src/operator/nn/activation-inl.h b/src/operator/nn/activation-inl.h
index 2705177..1d8e4c2 100644
--- a/src/operator/nn/activation-inl.h
+++ b/src/operator/nn/activation-inl.h
@@ -48,6 +48,9 @@ enum ActivationOpInputs {kData};
 enum ActivationOpOutputs {kOut};
 enum ActivationOpResource {kTempSpace};
 enum ActivationOpType {kReLU, kSigmoid, kTanh, kSoftReLU, kSoftSign};
+
+// Get the number of inputs to the gradient depending on the activation type
+int GradNumInputs(int act_type);
 }  // activation
 
 struct ActivationParam : public dmlc::Parameter<ActivationParam> {
@@ -199,13 +202,8 @@ void ActivationGradCompute(const nnvm::NodeAttrs& attrs,
                            const std::vector<OpReqType>& req,
                            const std::vector<TBlob>& outputs) {
   const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);
-#if (MXNET_USE_CUDNN == 1 || MXNET_USE_MKLDNN == 1)
-  bool relu = param.act_type == activation::kReLU;
-  CHECK_EQ(inputs.size(), relu ? 2U : 3U);
-#else
-  bool softsign = param.act_type == activation::kSoftSign;
-  CHECK_EQ(inputs.size(), softsign ? 3U : 2U);
-#endif
+  const int act_type = param.act_type;
+  CHECK_EQ(inputs.size(), activation::GradNumInputs(act_type));
   CHECK_EQ(outputs.size(), 1U);
   CHECK_EQ(req.size(), 1U);
   ActivationGradComputeImpl<xpu>(attrs, ctx, inputs, req, outputs);
diff --git a/src/operator/nn/activation.cc b/src/operator/nn/activation.cc
index ba44ebd..305eeab 100644
--- a/src/operator/nn/activation.cc
+++ b/src/operator/nn/activation.cc
@@ -30,13 +30,34 @@
 #if MXNET_USE_MKLDNN == 1
 #include "./mkldnn/mkldnn_base-inl.h"
 #include "./mkldnn/mkldnn_ops-inl.h"
-#endif  // MXNET_USE_MKLDNN
+#endif  // MXNET_USE_MKLDNN == 1
 #include "../operator_common.h"
 #include "../../common/utils.h"
 
 namespace mxnet {
 namespace op {
 
+namespace activation {
+
+int GradNumInputs(int act_type) {
+    // check activation.cu \sa ActivationGradCompute
+    switch (act_type) {
+        case kReLU:
+            return 2;
+        case kSoftReLU:
+        case kSoftSign:
+        case kTanh:
+        case kSigmoid:
+            return 3;
+        default:
+            CHECK(false) << "missing activation type";
+    }
+    // unreachable
+    return -1;
+}
+
+}  // namespace activation
+
 DMLC_REGISTER_PARAMETER(ActivationParam);
 
 // This will determine the order of the inputs for backward computation.
@@ -44,24 +65,28 @@ struct ActivationGrad {
   const char *op_name;
   std::vector<nnvm::NodeEntry> operator()(const nnvm::NodePtr& n,
                                           const std::vector<nnvm::NodeEntry>& ograds) const {
+    // ograds, output...
     std::vector<nnvm::NodeEntry> heads(ograds.begin(), ograds.end());
     heads.emplace_back(nnvm::NodeEntry{n, activation::kOut, 0});
 
     const NodeAttrs& attrs = n->attrs;
+    using namespace activation;
     int act_type = dmlc::get<ActivationParam>(attrs.parsed).act_type;
-    if (act_type == activation::kSoftSign) {
-      // for softsign need the inputs to compute the activation.
-      heads.push_back(n->inputs[activation::kData]);
-    }
-
-#if (MXNET_USE_CUDNN == 1 || MXNET_USE_MKLDNN == 1)
     // for ReLU, no need to pass input data. This enables inplace optimization during the
     // forward pass.
-    if (act_type != activation::kReLU &&
-        act_type != activation::kSoftSign) {
-      heads.push_back(n->inputs[activation::kData]);
+    // check activation.cu \sa ActivationGradCompute
+    switch (act_type) {
+        case kReLU:
+            break;
+        case kSoftReLU:
+        case kSoftSign:
+        case kTanh:
+        case kSigmoid:
+            heads.push_back(n->inputs[activation::kData]);
+            break;
+        default:
+            CHECK(false) << "missing activation type";
     }
-#endif
     return MakeGradNode(op_name, n, heads, n->attrs.dict);
   }
 };
@@ -89,21 +114,19 @@ void ActivationGradComputeExCPU(const nnvm::NodeAttrs& attrs,
                                 const std::vector<OpReqType>& req,
                                 const std::vector<NDArray>& outputs) {
   const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);
-  bool relu = param.act_type == activation::kReLU;
-  CHECK_EQ(inputs.size(), relu ? 2U : 3U);
+  CHECK_EQ(inputs.size(), activation::GradNumInputs(param.act_type));
   if (SupportMKLDNN(inputs[0])) {
     MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs);
     // XXX: for y = relu(x), y is passed as "in_data" to Backward()
-    MKLDNNActivationBackward(attrs, ctx, inputs[0], relu ? inputs[1] : inputs[2], req[0],
+    const bool relu = param.act_type == activation::kReLU;
+    MKLDNNActivationBackward(attrs, ctx, inputs.at(0), relu ? inputs.at(1) : inputs.at(2), req[0],
                              outputs[0]);
-     MKLDNN_OPCHECK_RUN(ActivationGradCompute<cpu>, attrs, ctx, inputs, req, outputs);
+    MKLDNN_OPCHECK_RUN(ActivationGradCompute<cpu>, attrs, ctx, inputs, req, outputs);
     return;
   }
   FallBackCompute(ActivationGradComputeImpl<cpu>, attrs, ctx, inputs, req, outputs);
 }
-#endif
 
-#if MXNET_USE_MKLDNN == 1
 inline static bool ActivationStorageType(const nnvm::NodeAttrs& attrs,
                                          const int dev_mask,
                                          DispatchMode* dispatch_mode,
@@ -122,16 +145,12 @@ inline static bool BackwardActStorageType(const nnvm::NodeAttrs& attrs,
                                           std::vector<int> *in_attrs,
                                           std::vector<int> *out_attrs) {
   const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);
-  if (param.act_type != activation::kReLU) {
-    CHECK_EQ(in_attrs->size(), 3U);
-  } else {
-    // for ReLU activation, the backward pass only needs ograd and output
-    CHECK_EQ(in_attrs->size(), 2U);
-  }
+  CHECK_EQ(in_attrs->size(), activation::GradNumInputs(param.act_type));
   return MKLDNNStorageType(attrs, dev_mask, SupportMKLDNNAct(param),
                            dispatch_mode, in_attrs, out_attrs);
 }
-#endif
+#endif  // MXNET_USE_MKLDNN == 1
+
 
 MXNET_OPERATOR_REGISTER_UNARY(Activation)
 .describe(R"code(Applies an activation function element-wise to the input.
@@ -163,18 +182,16 @@ The following activation functions are supported:
 
 NNVM_REGISTER_OP(_backward_Activation)
 .set_num_inputs([](const nnvm::NodeAttrs& attrs) {
-    int act_type = dmlc::get<ActivationParam>(attrs.parsed).act_type;
-    // for ReLU activation, the backward pass only needs ograd and output
-    if (act_type == activation::kReLU) return 2;
-    return 3;
-  })
+    const int act_type = dmlc::get<ActivationParam>(attrs.parsed).act_type;
+    return activation::GradNumInputs(act_type);
+})
 .set_num_outputs(1)
 .set_attr<nnvm::TIsBackward>("TIsBackward", true)
 #if MXNET_USE_MKLDNN == 1
 .set_attr<FInferStorageType>("FInferStorageType", BackwardActStorageType)
 #endif
-.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<3, 1>)
-.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<3, 1>)
+.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<-1, 1>)
+.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<-1, 1>)
 .set_attr<nnvm::FInplaceOption>("FInplaceOption", [](const NodeAttrs& attrs){
   return std::vector<std::pair<int, int> >{{0, 0}};
 })
diff --git a/src/operator/nn/activation.cu b/src/operator/nn/activation.cu
index 8892cc3..ec7db84 100644
--- a/src/operator/nn/activation.cu
+++ b/src/operator/nn/activation.cu
@@ -54,12 +54,13 @@ void ActivationCompute<gpu>(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(inputs.size(), 1U);
   CHECK_EQ(outputs.size(), 1U);
   const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);
+  const int act_type = param.act_type;
 
   // SoftReLU and kSoftSign are both not supported by CUDNN yet
-  if (param.act_type == activation::kSoftReLU) {
+  if (act_type == activation::kSoftReLU) {
     ActivationForward<gpu, mshadow_op::softrelu, mshadow_op::softrelu_grad>(ctx,
       inputs[0], req[0], outputs[0]);
-  } else if (param.act_type == activation::kSoftSign) {
+  } else if (act_type == activation::kSoftSign) {
     ActivationForward<gpu, mshadow_op::softsign, mshadow_op::softsign_grad>(ctx,
       inputs[0], req[0], outputs[0]);
   } else {
@@ -76,23 +77,28 @@ void ActivationGradCompute<gpu>(const nnvm::NodeAttrs& attrs,
                                 const std::vector<OpReqType>& req,
                                 const std::vector<TBlob>& outputs) {
   const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);
-  bool relu = param.act_type == activation::kReLU;
-  CHECK_EQ(inputs.size(), relu ? 2U : 3U);
+  const int act_type = param.act_type;
+  CHECK_EQ(inputs.size(), activation::GradNumInputs(act_type));
   CHECK_EQ(outputs.size(), 1U);
   CHECK_EQ(req.size(), 1U);
 
   // both SoftReLU and SoftSign not supported by CUDNN yet
-  if (param.act_type == activation::kSoftReLU) {
+  if (act_type == activation::kSoftReLU) {
     ActivationBackward<gpu, mshadow_op::softrelu, mshadow_op::softrelu_grad>(
-      ctx, inputs[0], inputs[1], req[0], outputs[0]);
-  } else if (param.act_type == activation::kSoftSign) {
+      ctx, inputs.at(0), inputs.at(1), req[0], outputs[0]);
+  } else if (act_type == activation::kSoftSign) {
     ActivationBackward<gpu, mshadow_op::softsign, mshadow_op::softsign_grad>(
-      ctx, inputs[0], inputs[2], req[0], outputs[0]);
-  } else {
-    MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+      ctx, inputs.at(0), inputs.at(2), req[0], outputs[0]);
+  } else if (act_type == activation::kReLU) {
+    MSHADOW_REAL_TYPE_SWITCH(inputs.at(0).type_flag_, DType, {
       // XXX: for y = relu(x), y is passed as "in_data" to Backward()
-      get_cudnn_op<DType>(param).Backward(ctx, inputs[0], relu ? inputs[1] : inputs[2],
-                                          inputs[1], req[0], outputs[0]);
+      get_cudnn_op<DType>(param).Backward(ctx, inputs.at(0), inputs.at(1),
+                                          inputs.at(1), req[0], outputs[0]);
+    });
+  } else {
+    MSHADOW_REAL_TYPE_SWITCH(inputs.at(0).type_flag_, DType, {
+      get_cudnn_op<DType>(param).Backward(ctx, inputs.at(0), inputs.at(2),
+                                          inputs.at(1), req[0], outputs[0]);
     });
   }
 }
diff --git a/tests/cpp/operator/activation_perf.cc b/tests/cpp/operator/activation_perf.cc
index 1bd8ca8..bba8a3e 100644
--- a/tests/cpp/operator/activation_perf.cc
+++ b/tests/cpp/operator/activation_perf.cc
@@ -38,13 +38,27 @@ const kwargs_t basic_activation_args = { };
  * \brief Generic bidirectional sanity test
  */
 TEST(ACTIVATION_PERF, ExecuteBidirectional) {
+  using namespace std;
   TShape shape({5, 5});
-  kwargs_t kwargs = basic_activation_args;
-  kwargs.push_back({"act_type", "tanh"});
-
-  test::op::CoreOperatorRunner<float> runner;
-  runner.RunBidirectional(false, { shape }, test::op::CoreOpExecutor<float>::ArgsWithOpName(
-          kwargs, "Activation", "_backward_Activation"), 1);
+  vector<string> activations = {
+    "relu",
+    "sigmoid",
+    "tanh",
+    "softrelu",
+    "softsign"
+  };
+  for (const string& activation : activations) {
+    kwargs_t activation_args = {{"act_type", activation}};
+    test::op::CoreOperatorRunner<float> runner;
+    runner.RunBidirectional(false, { shape }, test::op::CoreOpExecutor<float>::ArgsWithOpName(
+            activation_args, "Activation", "_backward_Activation"), 1);
+  }
+  for (const string& activation : activations) {
+    kwargs_t activation_args = {{"act_type", activation}};
+    test::op::CoreOperatorRunner<float> runner;
+    runner.RunBidirectional(true, { shape }, test::op::CoreOpExecutor<float>::ArgsWithOpName(
+            activation_args, "Activation", "_backward_Activation"), 1);
+  }
 }
 
 /*!
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index 3049674..abe6b13 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -2411,7 +2411,7 @@ def test_reshape_activation():
             x_reshape = x.reshape(self.reshape)
             out = self.act(x_reshape)
             return out
-    acts = ["relu", "sigmoid", "tanh", "softrelu"]
+    acts = ["relu", "sigmoid", "tanh", "softrelu", "softsign"]
     for act in acts:
         x = mx.nd.random.uniform(-1, 1, shape=(4, 16, 32, 32))
         shape = (4, 32, 32, -1)
@@ -2433,7 +2433,7 @@ def test_slice_activation():
             out = self.act(x_slice)
             return out
 
-    acts = ["relu", "sigmoid", "tanh", "softrelu"]
+    acts = ["relu", "sigmoid", "tanh", "softrelu", "softsign"]
     for act in acts:
         x = mx.nd.random.uniform(-1, 1, shape=(8, 32, 64, 64))
         slice = [(0, 16, 32, 32), (4, 32, 64, 64)]
@@ -2457,7 +2457,7 @@ def test_reshape_activation_reshape_activation():
             y_reshape = y.reshape(self.reshape[1])
             out = self.act1(y_reshape)
             return out
-    acts = ["relu", "sigmoid", "tanh", "softrelu"]
+    acts = ["relu", "sigmoid", "tanh", "softrelu", "softsign"]
     for idx0, act0 in enumerate(acts):
         for idx1, act1 in enumerate(acts):
             if idx1 == idx0:
@@ -2484,7 +2484,7 @@ def test_slice_activation_slice_activation():
             y_slice = y.slice(begin=self.slice[1][0], end=self.slice[1][1])
             out = self.act1(y_slice)
             return out
-    acts = ["relu", "sigmoid", "tanh", "softrelu"]
+    acts = ["relu", "sigmoid", "tanh", "softrelu", "softsign"]
     for idx0, act0 in enumerate(acts):
         for idx1, act1 in enumerate(acts):
             if idx1 == idx0:
@@ -2512,7 +2512,7 @@ def test_reshape_activation_slice_activation():
             y_slice = y.slice(begin=self.slice[0], end=self.slice[1])
             out = self.act1(y_slice)
             return out
-    acts = ["relu", "sigmoid", "tanh", "softrelu"]
+    acts = ["relu", "sigmoid", "tanh", "softrelu", "softsign"]
     for idx0, act0 in enumerate(acts):
         for idx1, act1 in enumerate(acts):
             if idx1 == idx0:
@@ -2541,7 +2541,7 @@ def test_slice_activation_reshape_activation():
             y_reshape = y.reshape(self.reshape)
             out = self.act1(y_reshape)
             return out
-    acts = ["relu", "sigmoid", "tanh", "softrelu"]
+    acts = ["relu", "sigmoid", "tanh", "softrelu", "softsign"]
     for idx0, act0 in enumerate(acts):
         for idx1, act1 in enumerate(acts):
             if idx1 == idx0: