You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@singa.apache.org by wa...@apache.org on 2016/06/12 14:35:11 UTC

[1/3] incubator-singa git commit: SINGA-190 - Add prelu layer and flatten layer

Repository: incubator-singa
Updated Branches:
  refs/heads/dev 6d69047ad -> 26df5ac03


SINGA-190 - Add prelu layer and flatten layer

Implement the PReLU and Flatten layers (CPU version).

Write gtests for the PReLU and Flatten layers.

Pass all tests.


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/5afd81b7
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/5afd81b7
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/5afd81b7

Branch: refs/heads/dev
Commit: 5afd81b7f4841b15ce292b5b2e3e26c25a79b912
Parents: 04e23d1
Author: jixin <ji...@comp.nus.edu.sg>
Authored: Wed Jun 8 15:58:09 2016 +0800
Committer: jixin <ji...@comp.nus.edu.sg>
Committed: Sat Jun 11 16:46:59 2016 +0800

----------------------------------------------------------------------
 src/model/layer/flatten.cc |  62 ++++++++
 src/model/layer/flatten.h  |  54 +++++++
 src/model/layer/prelu.cc   | 169 +++++++++++++++++++++
 src/model/layer/prelu.h    |  60 ++++++++
 src/proto/model.proto      | 321 ++++++++++++++++++----------------------
 test/singa/test_flatten.cc | 156 +++++++++++++++++++
 test/singa/test_prelu.cc   | 149 +++++++++++++++++++
 7 files changed, 793 insertions(+), 178 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/5afd81b7/src/model/layer/flatten.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/flatten.cc b/src/model/layer/flatten.cc
new file mode 100644
index 0000000..3ed37fe
--- /dev/null
+++ b/src/model/layer/flatten.cc
@@ -0,0 +1,62 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "singa/model/layer.h"
+#include "./flatten.h"
+namespace singa {
+
+/// Runs the generic Layer setup, then reads the flatten axis from the
+/// layer's FlattenConf.
+void Flatten::Setup(const LayerConf &conf) {
+  Layer::Setup(conf);
+  const auto &flatten_conf = conf.flatten_conf();
+  axis_ = flatten_conf.axis();
+}
+
+/// Reshapes `input` without copying its data, according to the axis:
+///  - axis == 0: the output is 1D with input.Size() elements;
+///  - otherwise: the output is 2D, (product of dims [0, axis),
+///    product of dims [axis, nDim)).
+/// A negative axis counts from the end (axis + input.nDim()), as documented
+/// for FlattenConf.axis.
+/// The input and output shapes are recorded; Backward uses the input shape.
+const Tensor Flatten::Forward(int flag, const Tensor &input) {
+  Tensor output = input;
+  input_shape_ = input.shape();
+  const int ndim = static_cast<int>(input.nDim());
+  // Resolve a negative axis relative to the input rank; without this a
+  // negative axis silently produced a (1, Size()) matrix.
+  const int axis = Axis() < 0 ? Axis() + ndim : Axis();
+  CHECK_GE(axis, 0) << "Flatten axis " << Axis()
+                    << " is out of range for a " << ndim << "-D input";
+  CHECK_LE(axis, ndim) << "Flatten axis " << Axis()
+                       << " is out of range for a " << ndim << "-D input";
+  if (axis == 0) {
+    // reshape to 1D
+    output_shape_ = Shape{output.Size()};
+  } else {
+    // reshape to 2D: dims [0, axis) form the rows, dims [axis, nDim) the
+    // columns. Both products are computed directly (no division), so a
+    // zero-sized leading dimension cannot cause a division by zero.
+    size_t rows = 1, cols = 1;
+    for (int i = 0; i < axis; i++) rows *= output.shape(i);
+    for (int i = axis; i < ndim; i++) cols *= output.shape(i);
+    output_shape_ = Shape{rows, cols};
+  }
+  output.Reshape(output_shape_);
+  return output;
+}
+
+/// Reshapes the incoming gradient back to the input shape recorded by the
+/// last Forward call. Flatten owns no parameters, so the returned parameter
+/// gradient list is empty.
+const std::pair<Tensor, vector<Tensor> > Flatten::Backward(int flag,
+                                                           const Tensor &grad) {
+  Tensor dx = grad;
+  dx.Reshape(Input_shape());
+  return std::make_pair(dx, vector<Tensor>());
+}
+
+} // namespace singa

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/5afd81b7/src/model/layer/flatten.h
----------------------------------------------------------------------
diff --git a/src/model/layer/flatten.h b/src/model/layer/flatten.h
new file mode 100644
index 0000000..cb36542
--- /dev/null
+++ b/src/model/layer/flatten.h
@@ -0,0 +1,54 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef SRC_MODEL_LAYER_FLATTEN_H_
+#define SRC_MODEL_LAYER_FLATTEN_H_
+#include <utility>
+#include <string>
+#include <vector>
+#include "singa/model/layer.h"
+
+namespace singa {
+/// Reshapes its input into a 1D or 2D tensor; Backward reshapes the gradient
+/// back to the input shape recorded during Forward.
+class Flatten : public Layer {
+ public:
+  /// \copydoc Layer::layer_type()
+  const std::string layer_type() const override { return "Flatten"; }
+
+  /// \copydoc Layer::Setup(const LayerConf&)
+  void Setup(const LayerConf &conf) override;
+
+  /// \copydoc Layer::Forward(int flag, const Tensor&)
+  const Tensor Forward(int flag, const Tensor &input) override;
+
+  /// \copydoc Layer::Backward(int, const Tensor&)
+  const std::pair<Tensor, vector<Tensor> > Backward(int flag,
+                                                    const Tensor &grad)
+      override;
+
+  /// The configured flatten axis (FlattenConf.axis).
+  const int Axis() const { return axis_; }
+  /// Shape of the input recorded by the most recent Forward call.
+  const Shape Input_shape() const { return input_shape_; }
+  /// Shape of the output recorded by the most recent Forward call.
+  const Shape Output_shape() const { return output_shape_; }
+
+ protected:
+  /// The flatten layer reshapes the input to 2D: the first dimension spans
+  /// axes [0, axis_) and the second spans axes [axis_, end).
+  /// If axis_ is 0, the input is reshaped to 1D.
+  /// Holds FlattenConf's default (1) until Setup() overwrites it.
+  int axis_ = 1;
+  Shape input_shape_, output_shape_;
+};
+}  // namespace singa
+#endif  // SRC_MODEL_LAYER_FLATTEN_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/5afd81b7/src/model/layer/prelu.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/prelu.cc b/src/model/layer/prelu.cc
new file mode 100644
index 0000000..1d6a2e7
--- /dev/null
+++ b/src/model/layer/prelu.cc
@@ -0,0 +1,169 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "singa/model/layer.h"
+#include "./prelu.h"
+namespace singa {
+
+/// Runs the generic Layer setup and then reads the PReLU options (channel
+/// sharing and input layout). It records every ParamSpec from `conf` and
+/// registers the slope tensor a_ in param_values_.
+void PReLU::Setup(const LayerConf &conf) {
+  Layer::Setup(conf);
+  const auto &prelu_conf = conf.prelu_conf();
+  channel_shared_ = prelu_conf.channel_shared();
+  format_ = prelu_conf.format();
+  for (const auto &param_spec : conf.param()) {
+    param_specs_.push_back(param_spec);
+  }
+  param_values_.push_back(&a_);
+}
+
+/// Computes y = x for x > 0 and y = a * x for x <= 0.
+///
+/// - When channel_shared_ is false, a_ supplies one slope per channel, and
+///   the input must be 4D in the layout named by format_ ("NCHW" or "NHWC").
+/// - Otherwise the first element of a_ is used for every element.
+///
+/// When (flag & kTrain) is set, the input is pushed onto buf_ for Backward.
+const Tensor PReLU::Forward(int flag, const Tensor &input) {
+  Tensor output;
+  if (!channel_shared_) {
+    size_t n = 0, c = 0, h = 0, w = 0;
+    // temp becomes a * (x <= 0), with a broadcast along the channel axis.
+    Tensor temp = (input <= 0.f);
+    if (temp.nDim() == 4) {
+      if (format_ == "NCHW") {
+        n = temp.shape(0);
+        c = temp.shape(1);
+        h = temp.shape(2);
+        w = temp.shape(3);
+        // Each row of the (n*c, h*w) view belongs to one (sample, channel).
+        temp.Reshape(Shape{n * c, h * w});
+        // Build the per-row slopes: an (n, c) matrix of ones combined with
+        // a_ via MultRow, flattened to n*c entries.
+        Tensor temp_a(Shape{n, c});
+        Uniform(1.f, 1.f, &temp_a);
+        MultRow(a_, &temp_a);
+        temp_a.Reshape(Shape{n * c});
+        MultColumn(temp_a, &temp);
+      } else if (format_ == "NHWC") {
+        n = temp.shape(0);
+        h = temp.shape(1);
+        w = temp.shape(2);
+        c = temp.shape(3);
+        // The channel is the innermost axis of the (n*h*w, c) view.
+        temp.Reshape(Shape{n * h * w, c});
+        MultRow(a_, &temp);
+      } else {
+        LOG(FATAL) << "Incorrect input format for prelu layer.";
+      }
+      // Restore the input layout before the element-wise combination, so the
+      // operands agree in shape and not just in element count (Backward
+      // restores its slope tensor's layout too).
+      temp.Reshape(input.shape());
+    } else {
+      LOG(FATAL) << "Incorrect input format for prelu layer.";
+    }
+    output = input * ((input > 0.f) + temp);
+  } else {
+    // share the first param of Tensor A along all channels
+    const float a = a_.data<const float *>()[0];
+    output = input * ((input > 0.f) + (input <= 0.f) * a);
+  }
+  if (flag & kTrain)
+    buf_.push(input);
+  return output;
+}
+
+/// Back-propagates through PReLU using the input cached by Forward.
+///
+/// For y = x (x > 0) and y = a * x (x <= 0):
+///   dL/dx = dL/dy * ((x > 0) + a * (x <= 0))
+///   dL/da = sum of dL/dy * x over the elements with x <= 0.
+/// When channel_shared_ is false, the sum is taken per channel. Otherwise a
+/// single sum is written to every element of the slope gradient.
+/// Returns {dL/dx, {dL/da}}.
+const std::pair<Tensor, vector<Tensor> > PReLU::Backward(int flag,
+                                                         const Tensor &grad) {
+  vector<Tensor> param_grad;
+  CHECK(!buf_.empty());
+  Tensor input_grad, input = buf_.top();
+  buf_.pop();
+  Tensor da;
+  da.ResetLike(a_);
+  if (!channel_shared_) {
+    size_t n = 0, c = 0, h = 0, w = 0;
+    // temp1 becomes a * (x <= 0), with a broadcast along the channel axis.
+    Tensor temp1 = (input <= 0.f);
+    if (temp1.nDim() == 4) {
+      if (format_ == "NCHW") {
+        n = temp1.shape(0);
+        c = temp1.shape(1);
+        h = temp1.shape(2);
+        w = temp1.shape(3);
+        temp1.Reshape(Shape{n * c, h * w});
+        Tensor temp_a(Shape{n, c});
+        Uniform(1.f, 1.f, &temp_a);
+        MultRow(a_, &temp_a);
+        temp_a.Reshape(Shape{n * c});
+        MultColumn(temp_a, &temp1);
+        temp1.Reshape(Shape{n, c, h, w});
+      } else if (format_ == "NHWC") {
+        n = temp1.shape(0);
+        h = temp1.shape(1);
+        w = temp1.shape(2);
+        c = temp1.shape(3);
+        temp1.Reshape(Shape{n * h * w, c});
+        MultRow(a_, &temp1);
+        temp1.Reshape(Shape{n, h, w, c});
+      } else {
+        LOG(FATAL) << "Incorrect input format for prelu layer.";
+      }
+    } else {
+      LOG(FATAL) << "Incorrect input format for prelu layer.";
+    }
+    // dy/dx is 1 for x > 0 and a for x <= 0. It does not depend on the
+    // magnitude of x, so the input must not be multiplied in here.
+    input_grad = grad * ((input > 0.f) + temp1);
+    // dy/da is x for x <= 0 and 0 otherwise; reduce it per channel.
+    Tensor temp2 = grad * input * (input <= 0.f);
+    if (format_ == "NCHW") {
+      temp2.Reshape(Shape{n * c, h * w});
+      Tensor temp3(Shape{n * c});
+      SumColumns(temp2, &temp3);
+      temp3.Reshape(Shape{n, c});
+      SumRows(temp3, &da);
+    } else if (format_ == "NHWC") {
+      temp2.Reshape(Shape{n * h * w, c});
+      SumRows(temp2, &da);
+    }
+  } else {
+    // share the first param of Tensor A along all channels
+    const float a = a_.data<const float *>()[0];
+    input_grad = grad * ((input > 0.f) + (input <= 0.f) * a);
+    Tensor temp = grad * input * (input <= 0.f);
+    float sum = Sum<float>(temp);
+    Uniform(1.f, 1.f, &da);
+    da *= sum;
+  }
+  param_grad.push_back(da);
+  return std::make_pair(input_grad, param_grad);
+}
+
+/// Delegates to Layer::ToDevice first, then also moves the slope tensor a_
+/// onto `device`.
+void PReLU::ToDevice(Device *device) {
+  Layer::ToDevice(device);
+  this->a_.ToDevice(device);
+}
+
+} // namespace singa

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/5afd81b7/src/model/layer/prelu.h
----------------------------------------------------------------------
diff --git a/src/model/layer/prelu.h b/src/model/layer/prelu.h
new file mode 100644
index 0000000..1a01d98
--- /dev/null
+++ b/src/model/layer/prelu.h
@@ -0,0 +1,60 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef SINGA_MODEL_LAYER_PRELU_H_
+#define SINGA_MODEL_LAYER_PRELU_H_
+#include <stack>
+#include <string>
+#include <utility>
+#include <vector>
+#include "singa/model/layer.h"
+
+namespace singa {
+/// Parametric ReLU: y = x for x > 0 and y = a * x otherwise. The slope `a`
+/// comes from a_, either per channel or shared by all channels.
+class PReLU : public Layer {
+ public:
+  /// \copydoc Layer::layer_type()
+  const std::string layer_type() const override { return "PReLU"; }
+
+  /// \copydoc Layer::Setup(const LayerConf&)
+  void Setup(const LayerConf &conf) override;
+
+  /// \copydoc Layer::Forward(int flag, const Tensor&)
+  const Tensor Forward(int flag, const Tensor &input) override;
+
+  /// \copydoc Layer::Backward(int, const Tensor&)
+  const std::pair<Tensor, vector<Tensor> > Backward(int flag,
+                                                    const Tensor &grad)
+      override;
+
+  /// Moves the layer, including the slope tensor a_, onto `device`.
+  void ToDevice(Device *device);
+
+  /// Whether a single slope (the first element of a_) is used for all
+  /// channels.
+  const bool Channel_shared() const { return channel_shared_; }
+  /// The slope tensor.
+  const Tensor A() const { return a_; }
+  /// The configured input layout.
+  const std::string Format() const { return format_; }
+
+  /// Replaces the slopes with a copy of `a`; a_ is reset to a's shape first.
+  void Set_a(Tensor a) {
+    a_.ResetLike(a);
+    a_.CopyData(a);
+  }
+
+ protected:
+  /// Set from PReLUConf in Setup(); false until then.
+  bool channel_shared_ = false;
+  /// Input layout; the two valid values are "NCHW" and "NHWC".
+  std::string format_;
+  /// Slope parameters. Only the first element is used when channel_shared_
+  /// is true. NOTE(review): the original comment described a 2D
+  /// (channels, 1) shape. The layer code combines a_ with MultRow against a
+  /// matrix with one column per channel, so confirm the intended shape.
+  Tensor a_;
+  /// Inputs cached by Forward (training mode) and consumed by Backward.
+  std::stack<Tensor> buf_;
+};
+}  // namespace singa
+#endif  // SINGA_MODEL_LAYER_PRELU_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/5afd81b7/src/proto/model.proto
----------------------------------------------------------------------
diff --git a/src/proto/model.proto b/src/proto/model.proto
index d368296..1d1f3cf 100644
--- a/src/proto/model.proto
+++ b/src/proto/model.proto
@@ -33,64 +33,59 @@ package singa;
 /// using Python (or C++/Java).
 
 // Specifies the shape (dimensions) of a Blob.
-message BlobShape {
-  repeated int64 dim = 1 [packed = true];
-}
+message BlobShape { repeated int64 dim = 1[packed = true]; }
 
 message BlobProto {
   optional BlobShape shape = 7;
-  repeated float data = 5 [packed = true];
-  repeated float diff = 6 [packed = true];
-  repeated double double_data = 8 [packed = true];
-  repeated double double_diff = 9 [packed = true];
+  repeated float data = 5[packed = true];
+  repeated float diff = 6[packed = true];
+  repeated double double_data = 8[packed = true];
+  repeated double double_diff = 9[packed = true];
 
   // 4D dimensions -- deprecated.  Use "shape" instead.
-  optional int32 num = 1 [default = 0];
-  optional int32 channels = 2 [default = 0];
-  optional int32 height = 3 [default = 0];
-  optional int32 width = 4 [default = 0];
+  optional int32 num = 1[default = 0];
+  optional int32 channels = 2[default = 0];
+  optional int32 height = 3[default = 0];
+  optional int32 width = 4[default = 0];
 }
 
 message FillerConf {
   // The filler type, case insensitive
-  optional string type = 1 [default = 'constant'];
-  optional float value = 2 [default = 0]; // the value in constant filler
-  optional float min = 3 [default = 0]; // the min value in uniform filler
-  optional float max = 4 [default = 1]; // the max value in uniform filler
-  optional float mean = 5 [default = 0]; // the mean value in Gaussian filler
-  optional float std = 6 [default = 1]; // the std value in Gaussian filler
+  optional string type = 1[default = 'constant'];
+  optional float value = 2[default = 0]; // the value in constant filler
+  optional float min = 3[default = 0];   // the min value in uniform filler
+  optional float max = 4[default = 1];   // the max value in uniform filler
+  optional float mean = 5[default = 0];  // the mean value in Gaussian filler
+  optional float std = 6[default = 1];   // the std value in Gaussian filler
   // The expected number of non-zero output weights for a given input in
   // Gaussian filler -- the default -1 means don't perform sparsification.
   /* optional int32 sparse = 7 [default = -1]; */
   // Normalize the filler variance by fan_in, fan_out, or their average.
   // Applies to 'xavier' and 'msra' fillers.
   enum VarianceNorm {
-    FAN_IN = 0;
-    FAN_OUT = 1;
-    AVERAGE = 2;
-  }
-  optional VarianceNorm variance_norm = 8 [default = FAN_IN];
+    FAN_IN = 0; FAN_OUT = 1; AVERAGE = 2;
+  } optional VarianceNorm variance_norm = 8[default = FAN_IN];
 }
 
 /// SINGA message
 message OptimizerConf {
   // case insensitive
-  optional string type = 1 [default = "sgd"];
+  optional string type = 1[default = "sgd"];
 
   // used by RMSprop and Adadelta
-  optional float rho = 2 [default = 0.001];
+  optional float rho = 2[default = 0.001];
 
   // used by Adam and AdamMax
-  optional float beta_1 = 3 [default = 0.9];
-  optional float beta_2 = 4 [default = 0.999];
+  optional float beta_1 = 3[default = 0.9];
+  optional float beta_2 = 4[default = 0.999];
 
   // used by vanilla sgd and nesterov
-  optional float momentum = 5 [default = 0.9];
+  optional float momentum = 5[default = 0.9];
 }
 
 message ConstraintConf {
   // case insensitive to limit the parameter value/gradient scale
-  optional string type = 1 [default = "l2"];
+  optional string type = 1[default = "l2"];
   // e.g., the threshold for limiting the parameter scale.
   optional float threshold = 2;
 }
@@ -98,7 +93,7 @@ message ConstraintConf {
 /// SINGA message
 message RegularizerConf {
   // case insensitive to regularize the parameters, e.g., L2.
-  optional string type = 1 [default = "l2"];
+  optional string type = 1[default = "l2"];
   // e.g., the weight decay for L2 regularizer
   optional float coefficient = 2;
 }
@@ -124,10 +119,10 @@ message ParamSpec {
   */
 
   // The multiplier on the global learning rate for this parameter.
-  optional float lr_mult = 3 [default = 1.0];
+  optional float lr_mult = 3[default = 1.0];
 
   // The multiplier on the global weight decay for this parameter.
-  optional float decay_mult = 4 [default = 1.0];
+  optional float decay_mult = 4[default = 1.0];
 
   // SINGA uses this filed internally. Users just configure the fillers in
   // Layer specific conf message as caffe (style).
@@ -137,14 +132,13 @@ message ParamSpec {
 }
 
 enum Phase {
-  kTrain = 4;
-  kEval = 8;
-}
-// NOTE
-// Update the next available ID when you add a new LayerConf field.
-//
-// LayerConf next available layer-specific ID: 139 (last added: tile_param)
-message LayerConf {
+  kTrain = 4; kEval = 8;
+}
+    // NOTE
+    // Update the next available ID when you add a new LayerConf field.
+    //
+    // LayerConf next available layer-specific ID: 139 (last added: tile_param)
+    message LayerConf {
   optional string name = 1; // the layer name
   optional string type = 2; // the layer type
   /* repeated string bottom = 3; // the name of each bottom blob */
@@ -248,7 +242,8 @@ message TransformationConf {
   optional uint32 crop_size = 3 [default = 0];
   // mean_file and mean_value cannot be specified at the same time
   optional string mean_file = 4;
-  // if specified can be repeated once (would substract it from all the channels)
+  // if specified can be repeated once (would subtract it from all the
+  // channels)
   // or can be repeated the same number of times as channels
   // (would subtract them from the corresponding channel)
   repeated float mean_value = 5;
@@ -265,34 +260,33 @@ message LossConf {
   optional int32 ignore_label = 1;
   // If true, normalize each batch across all instances (including spatial
   // dimesions, but not ignored instances); else, divide by batch size only.
-  optional bool normalize = 2 [default = true];
+  optional bool normalize = 2[default = true];
 }
 
 message MetricConf {
   // When computing accuracy, count as correct by comparing the true label to
   // the top k scoring classes.  By default, only compare to the top scoring
   // class (i.e. argmax).
-  optional uint32 top_k = 1 [default = 1];
+  optional uint32 top_k = 1[default = 1];
 
   // The "label" axis of the prediction blob, whose argmax corresponds to the
   // predicted label -- may be negative to index from the end (e.g., -1 for the
   // last axis).  For example, if axis == 1 and the predictions are
   // (N x C x H x W), the label blob is expected to contain N*H*W ground truth
   // labels with integer values in {0, 1, ..., C-1}.
-  optional int32 axis = 2 [default = 1];
+  optional int32 axis = 2[default = 1];
 
   // If specified, ignore instances with the given label.
   optional int32 ignore_label = 3;
 }
-// Messages that store hyper-parameters used by individual layer types follow, in
+// Messages that store hyper-parameters used by individual layer types follow,
+// in
 // alphabetical order.
 
-
-
 message ArgMaxConf {
   // If true produce pairs (argmax, maxval)
-  optional bool out_max_val = 1 [default = false];
-  optional uint32 top_k = 2 [default = 1];
+  optional bool out_max_val = 1[default = false];
+  optional uint32 top_k = 2[default = 1];
   // The axis along which to maximise -- may be negative to index from the
   // end (e.g., -1 for the last axis).
   // By default ArgMaxLayer maximizes over the flattened trailing dimensions
@@ -305,54 +299,51 @@ message ConcatConf {
   // end (e.g., -1 for the last axis).  Other axes must have the
   // same dimension for all the bottom blobs.
   // By default, ConcatLayer concatenates blobs along the "channels" axis (1).
-  optional int32 axis = 2 [default = 1];
+  optional int32 axis = 2[default = 1];
 
   // DEPRECATED: alias for "axis" -- does not support negative indexing.
-  optional uint32 concat_dim = 1 [default = 1];
+  optional uint32 concat_dim = 1[default = 1];
 }
 
 message ContrastiveLossConf {
   // margin for dissimilar pair
-  optional float margin = 1 [default = 1.0];
+  optional float margin = 1[default = 1.0];
   // The first implementation of this cost did not exactly match the cost of
   // Hadsell et al 2006 -- using (margin - d^2) instead of (margin - d)^2.
   // legacy_version = false (the default) uses (margin - d)^2 as proposed in the
   // Hadsell paper. New models should probably use this version.
   // legacy_version = true uses (margin - d^2). This is kept to support /
   // reproduce existing models and results
-  optional bool legacy_version = 2 [default = false];
+  optional bool legacy_version = 2[default = false];
 }
 
 message ConvolutionConf {
   optional uint32 num_output = 1; // The number of outputs for the layer
-  optional bool bias_term = 2 [default = true]; // whether to have bias terms
+  optional bool bias_term = 2[default = true]; // whether to have bias terms
 
   // Pad, kernel size, and stride are all given as a single value for equal
   // dimensions in all spatial dimensions, or once per spatial dimension.
-  repeated uint32 pad = 3; // The padding size; defaults to 0
+  repeated uint32 pad = 3;         // The padding size; defaults to 0
   repeated uint32 kernel_size = 4; // The kernel size
-  repeated uint32 stride = 6; // The stride; defaults to 1
+  repeated uint32 stride = 6;      // The stride; defaults to 1
 
   // For 2D convolution only, the *_h and *_w versions may also be used to
   // specify both spatial dimensions.
-  optional uint32 pad_h = 9 [default = 0]; // The padding height (2D only)
-  optional uint32 pad_w = 10 [default = 0]; // The padding width (2D only)
-  optional uint32 kernel_h = 11; // The kernel height (2D only)
-  optional uint32 kernel_w = 12; // The kernel width (2D only)
-  optional uint32 stride_h = 13; // The stride height (2D only)
-  optional uint32 stride_w = 14; // The stride width (2D only)
+  optional uint32 pad_h = 9[default = 0];  // The padding height (2D only)
+  optional uint32 pad_w = 10[default = 0]; // The padding width (2D only)
+  optional uint32 kernel_h = 11;           // The kernel height (2D only)
+  optional uint32 kernel_w = 12;           // The kernel width (2D only)
+  optional uint32 stride_h = 13;           // The stride height (2D only)
+  optional uint32 stride_w = 14;           // The stride width (2D only)
 
   // SINGA: not supported.
   // optional uint32 group = 5 [default = 1]; // The group size for group conv
 
   optional FillerConf weight_filler = 7; // The filler for the weight
-  optional FillerConf bias_filler = 8; // The filler for the bias
+  optional FillerConf bias_filler = 8;   // The filler for the bias
   enum Engine {
-    DEFAULT = 0;
-    CAFFE = 1;
-    CUDNN = 2;
-  }
-  optional Engine engine = 15 [default = DEFAULT];
+    DEFAULT = 0; CAFFE = 1; CUDNN = 2;
+  } optional Engine engine = 15[default = DEFAULT];
 
   // The axis to interpret as "channels" when performing convolution.
   // Preceding dimensions are treated as independent inputs;
@@ -374,13 +365,12 @@ message ConvolutionConf {
   // SINGA: not supported;
   // optional bool force_nd_im2col = 17 [default = false];
 
-
   // SINGA: add by xiangrui
   // cudnn workspace size in MB
-  optional int32 workspace_byte_limit = 50 [default = 512];
+  optional int32 workspace_byte_limit = 50[default = 512];
   // cudnn algorithm preference
   // options: "fastest", "limited_workspace", "no_workspace"
-  optional string prefer = 51 [default = "fastest"];
+  optional string prefer = 51[default = "fastest"];
   // input shape
   optional int32 channels = 52;
   optional int32 height = 53;
@@ -424,7 +414,7 @@ message DataConf {
 */
 
 message DropoutConf {
-  optional float dropout_ratio = 1 [default = 0.5]; // dropout ratio
+  optional float dropout_ratio = 1[default = 0.5]; // dropout ratio
 }
 
 // DummyDataLayer fills any number of arbitrarily shaped blobs with random
@@ -448,16 +438,13 @@ message DummyDataConf {
 
 message EltwiseConf {
   enum EltwiseOp {
-    PROD = 0;
-    SUM = 1;
-    MAX = 2;
-  }
-  optional EltwiseOp operation = 1 [default = SUM]; // element-wise operation
+    PROD = 0; SUM = 1; MAX = 2;
+  } optional EltwiseOp operation = 1[default = SUM]; // element-wise operation
   repeated float coeff = 2; // blob-wise coefficient for SUM operation
 
   // Whether to use an asymptotically slower (for >2 inputs) but stabler method
   // of computing the gradient for the PROD operation. (No effect for SUM op.)
-  optional bool stable_prod_grad = 3 [default = true];
+  optional bool stable_prod_grad = 3[default = true];
 }
 
 // Message that stores hyper-parameters used by EmbedLayer
@@ -468,9 +455,9 @@ message EmbedConf {
   // 1 greater than the maximum possible input value.
   optional uint32 input_dim = 2;
 
-  optional bool bias_term = 3 [default = true]; // Whether to use a bias term
-  optional FillerConf weight_filler = 4; // The filler for the weight
-  optional FillerConf bias_filler = 5; // The filler for the bias
+  optional bool bias_term = 3[default = true]; // Whether to use a bias term
+  optional FillerConf weight_filler = 4;       // The filler for the weight
+  optional FillerConf bias_filler = 5;         // The filler for the bias
 
 }
 
@@ -479,21 +466,21 @@ message ExpConf {
   // ExpLayer computes outputs y = base ^ (shift + scale * x), for base > 0.
   // Or if base is set to the default (-1), base is set to e,
   // so y = exp(shift + scale * x).
-  optional float base = 1 [default = -1.0];
-  optional float scale = 2 [default = 1.0];
-  optional float shift = 3 [default = 0.0];
+  optional float base = 1[default = -1.0];
+  optional float scale = 2[default = 1.0];
+  optional float shift = 3[default = 0.0];
 }
 
 /// Message that stores hyper-parameters used by FlattenLayer
 message FlattenConf {
   // The first axis to flatten: all preceding axes are retained in the output.
   // May be negative to index from the end (e.g., -1 for the last axis).
-  optional int32 axis = 1 [default = 1];
+  optional int32 axis = 1[default = 1];
 
   // The last axis to flatten: all following axes are retained in the output.
   // May be negative to index from the end (e.g., the default -1 for the last
   // axis).
-  optional int32 end_axis = 2 [default = -1];
+  optional int32 end_axis = 2[default = -1];
 }
 
 /*
@@ -519,11 +506,10 @@ message HDF5OutputConf {
 
 message HingeLossConf {
   enum Norm {
-    L1 = 1;
-    L2 = 2;
+    L1 = 1; L2 = 2;
   }
-  // Specify the Norm to use L1 or L2
-  optional Norm norm = 1 [default = L1];
+      // Specify the Norm to use L1 or L2
+      optional Norm norm = 1[default = L1];
 }
 
 /*
@@ -566,29 +552,29 @@ message InfogainLossConf {
 
 message InnerProductConf {
   optional uint32 num_output = 1; // The number of outputs for the layer
-  optional bool bias_term = 2 [default = true]; // whether to have bias terms
-  optional FillerConf weight_filler = 3; // The filler for the weight
-  optional FillerConf bias_filler = 4; // The filler for the bias
+  optional bool bias_term = 2[default = true]; // whether to have bias terms
+  optional FillerConf weight_filler = 3;       // The filler for the weight
+  optional FillerConf bias_filler = 4;         // The filler for the bias
 
   // The first axis to be lumped into a single inner product computation;
   // all preceding axes are retained in the output.
   // May be negative to index from the end (e.g., -1 for the last axis).
-  optional int32 axis = 5 [default = 1];
+  optional int32 axis = 5[default = 1];
 }
 
 message DenseConf {
   optional uint32 num_output = 1; // The number of outputs for the layer
-  optional bool bias_term = 2 [default = true]; // whether to have bias terms
-  optional FillerConf weight_filler = 3; // The filler for the weight
-  optional FillerConf bias_filler = 4; // The filler for the bias
+  optional bool bias_term = 2[default = true]; // whether to have bias terms
+  optional FillerConf weight_filler = 3;       // The filler for the weight
+  optional FillerConf bias_filler = 4;         // The filler for the bias
 
   // The first axis to be lumped into a single inner product computation;
   // all preceding axes are retained in the output.
   // May be negative to index from the end (e.g., -1 for the last axis).
-  optional int32 axis = 5 [default = 1];
+  optional int32 axis = 5[default = 1];
 
   optional uint32 num_input = 20; // The number of inputs for the layer
-  optional bool transpose = 21 [default = false]; // whether transpose or not
+  optional bool transpose = 21[default = false]; // whether transpose or not
 }
 
 // Message that stores hyper-parameters used by LogLayer
@@ -596,22 +582,20 @@ message LogConf {
   // LogLayer computes outputs y = log_base(shift + scale * x), for base > 0.
   // Or if base is set to the default (-1), base is set to e,
   // so y = ln(shift + scale * x) = log_e(shift + scale * x)
-  optional float base = 1 [default = -1.0];
-  optional float scale = 2 [default = 1.0];
-  optional float shift = 3 [default = 0.0];
+  optional float base = 1[default = -1.0];
+  optional float scale = 2[default = 1.0];
+  optional float shift = 3[default = 0.0];
 }
 
 // Message that stores hyper-parameters used by LRNLayer
 message LRNConf {
-  optional uint32 local_size = 1 [default = 5];
-  optional float alpha = 2 [default = 1.];
-  optional float beta = 3 [default = 0.75];
+  optional uint32 local_size = 1[default = 5];
+  optional float alpha = 2[default = 1.];
+  optional float beta = 3[default = 0.75];
   enum NormRegion {
-    ACROSS_CHANNELS = 0;
-    WITHIN_CHANNEL = 1;
-  }
-  optional NormRegion norm_region = 4 [default = ACROSS_CHANNELS];
-  optional float k = 5 [default = 1.];
+    ACROSS_CHANNELS = 0; WITHIN_CHANNEL = 1;
+  } optional NormRegion norm_region = 4[default = ACROSS_CHANNELS];
+  optional float k = 5[default = 1.];
 }
 
 message MemoryDataConf {
@@ -623,33 +607,30 @@ message MemoryDataConf {
 
 message MVNConf {
   // This parameter can be set to false to normalize mean only
-  optional bool normalize_variance = 1 [default = true];
+  optional bool normalize_variance = 1[default = true];
 
   // This parameter can be set to true to perform DNN-like MVN
-  optional bool across_channels = 2 [default = false];
+  optional bool across_channels = 2[default = false];
 
   // Epsilon for not dividing by zero while normalizing variance
-  optional float eps = 3 [default = 1e-9];
+  optional float eps = 3[default = 1e-9];
 }
 
 message PoolingConf {
   enum PoolMethod {
-    MAX = 0;
-    AVE = 1;
-    STOCHASTIC = 2;
-  }
-  optional PoolMethod pool = 1 [default = MAX]; // The pooling method
+    MAX = 0; AVE = 1; STOCHASTIC = 2;
+  } optional PoolMethod pool = 1[default = MAX]; // The pooling method
   // Pad, kernel size, and stride are all given as a single value for equal
   // dimensions in height and width or as Y, X pairs.
-  optional uint32 pad = 4 [default = 0]; // The padding size (equal in Y, X)
-  optional uint32 pad_h = 9 [default = 0]; // The padding height
-  optional uint32 pad_w = 10 [default = 0]; // The padding width
-  optional uint32 kernel_size = 2; // The kernel size (square)
-  optional uint32 kernel_h = 5; // The kernel height
-  optional uint32 kernel_w = 6; // The kernel width
-  optional uint32 stride = 3 [default = 1]; // The stride (equal in Y, X)
-  optional uint32 stride_h = 7; // The stride height
-  optional uint32 stride_w = 8; // The stride width
+  optional uint32 pad = 4[default = 0];    // The padding size (equal in Y, X)
+  optional uint32 pad_h = 9[default = 0];  // The padding height
+  optional uint32 pad_w = 10[default = 0]; // The padding width
+  optional uint32 kernel_size = 2;         // The kernel size (square)
+  optional uint32 kernel_h = 5;            // The kernel height
+  optional uint32 kernel_w = 6;            // The kernel width
+  optional uint32 stride = 3[default = 1]; // The stride (equal in Y, X)
+  optional uint32 stride_h = 7;            // The stride height
+  optional uint32 stride_w = 8;            // The stride width
   /*
   enum Engine {
     DEFAULT = 0;
@@ -660,20 +641,20 @@ message PoolingConf {
   */
   // If global_pooling then it will pool over the size of the bottom by doing
   // kernel_h = bottom->height and kernel_w = bottom->width
-  optional bool global_pooling = 12 [default = false];
+  optional bool global_pooling = 12[default = false];
   // Shape of source
   optional int32 channels = 50;
   optional int32 height = 51;
   optional int32 width = 52;
   // whether to propagate nan
-  optional bool nan_prop = 53 [default = false];
+  optional bool nan_prop = 53[default = false];
 }
 
 message PowerConf {
   // PowerLayer computes outputs y = (shift + scale * x) ^ power.
-  optional float power = 1 [default = 1.0];
-  optional float scale = 2 [default = 1.0];
-  optional float shift = 3 [default = 0.0];
+  optional float power = 1[default = 1.0];
+  optional float scale = 2[default = 1.0];
+  optional float shift = 3[default = 0.0];
 }
 /*
 message PythonConf {
@@ -684,7 +665,8 @@ message PythonConf {
   // string, dictionary in Python dict format, JSON, etc. You may parse this
   // string in `setup` method and use it in `forward` and `backward`.
   optional string param_str = 3 [default = ''];
-  // Whether this PythonLayer is shared among worker solvers during data parallelism.
+  // Whether this PythonLayer is shared among worker solvers during data
+parallelism.
   // If true, each worker solver sequentially run forward from this layer.
   // This value should be set true if you are using it as a data layer.
   optional bool share_in_parallel = 4 [default = false];
@@ -694,13 +676,8 @@ message PythonConf {
 // Message that stores hyper-parameters used by ReductionLayer
 message ReductionConf {
   enum ReductionOp {
-    SUM = 1;
-    ASUM = 2;
-    SUMSQ = 3;
-    MEAN = 4;
-  }
-
-  optional ReductionOp operation = 1 [default = SUM]; // reduction operation
+    SUM = 1; ASUM = 2; SUMSQ = 3; MEAN = 4;
+  } optional ReductionOp operation = 1[default = SUM]; // reduction operation
 
   // The first axis to reduce to a scalar -- may be negative to index from the
   // end (e.g., -1 for the last axis).
@@ -715,9 +692,9 @@ message ReductionConf {
   // If axis == 0 (the default), the output Blob always has the empty shape
   // (count 1), performing reduction across the entire input --
   // often useful for creating new loss functions.
-  optional int32 axis = 2 [default = 0];
+  optional int32 axis = 2[default = 0];
 
-  optional float coeff = 3 [default = 1.0]; // coefficient for output
+  optional float coeff = 3[default = 1.0]; // coefficient for output
 }
 
 // Message that stores hyper-parameters used by ReLULayer
@@ -727,7 +704,7 @@ message ReLUConf {
   // Maas, A. L., Hannun, A. Y., & Ng, A. Y. (2013). Rectifier nonlinearities
   // improve neural network acoustic models. In ICML Workshop on Deep Learning
   // for Audio, Speech, and Language Processing.
-  optional float negative_slope = 1 [default = 0];
+  optional float negative_slope = 1[default = 0];
   /*
   enum Engine {
     DEFAULT = 0;
@@ -798,58 +775,50 @@ message ReshapeConf {
   //   reshape_param { shape { dim: 2  dim: 1  dim: 8  }  }
   //   reshape_param { shape { dim: 1 }  axis: 1  num_axes: 0 }
   //
-  optional int32 axis = 2 [default = 0];
-  optional int32 num_axes = 3 [default = -1];
+  optional int32 axis = 2[default = 0];
+  optional int32 num_axes = 3[default = -1];
 }
 
 message SigmoidConf {
   enum Engine {
-    DEFAULT = 0;
-    CAFFE = 1;
-    CUDNN = 2;
-  }
-  optional Engine engine = 1 [default = DEFAULT];
+    DEFAULT = 0; CAFFE = 1; CUDNN = 2;
+  } optional Engine engine = 1[default = DEFAULT];
 }
 
 message SliceConf {
   // The axis along which to slice -- may be negative to index from the end
   // (e.g., -1 for the last axis).
   // By default, SliceLayer concatenates blobs along the "channels" axis (1).
-  optional int32 axis = 3 [default = 1];
+  optional int32 axis = 3[default = 1];
   repeated uint32 slice_point = 2;
 
   // DEPRECATED: alias for "axis" -- does not support negative indexing.
-  optional uint32 slice_dim = 1 [default = 1];
+  optional uint32 slice_dim = 1[default = 1];
 }
 
-// Message that stores hyper-parameters used by SoftmaxLayer, SoftmaxWithLossLayer
+// Message that stores hyper-parameters used by SoftmaxLayer,
+// SoftmaxWithLossLayer
 message SoftmaxConf {
   enum Engine {
-    DEFAULT = 0;
-    CAFFE = 1;
-    CUDNN = 2;
-  }
-  optional Engine engine = 1 [default = DEFAULT];
+    DEFAULT = 0; CAFFE = 1; CUDNN = 2;
+  } optional Engine engine = 1[default = DEFAULT];
 
   // The axis along which to perform the softmax -- may be negative to index
   // from the end (e.g., -1 for the last axis).
   // Any other axes will be evaluated as independent softmaxes.
-  optional int32 axis = 2 [default = 1];
+  optional int32 axis = 2[default = 1];
 }
 
 message TanHConf {
   enum Engine {
-    DEFAULT = 0;
-    CAFFE = 1;
-    CUDNN = 2;
-  }
-  optional Engine engine = 1 [default = DEFAULT];
+    DEFAULT = 0; CAFFE = 1; CUDNN = 2;
+  } optional Engine engine = 1[default = DEFAULT];
 }
 
 // Message that stores hyper-parameters used by TileLayer
 message TileConf {
   // The index of the axis to tile.
-  optional int32 axis = 1 [default = 1];
+  optional int32 axis = 1[default = 1];
 
   // The number of copies (tiles) of the blob to output.
   optional int32 tiles = 2;
@@ -857,7 +826,7 @@ message TileConf {
 
 // Message that stores hyper-parameters used by ThresholdLayer
 message ThresholdConf {
-  optional float threshold = 1 [default = 0]; // Strictly positive values
+  optional float threshold = 1[default = 0]; // Strictly positive values
 }
 
 /*
@@ -897,18 +866,12 @@ message WindowDataConf {
 
 message SPPConf {
   enum PoolMethod {
-    MAX = 0;
-    AVE = 1;
-    STOCHASTIC = 2;
-  }
-  optional uint32 pyramid_height = 1;
-  optional PoolMethod pool = 2 [default = MAX]; // The pooling method
+    MAX = 0; AVE = 1; STOCHASTIC = 2;
+  } optional uint32 pyramid_height = 1;
+  optional PoolMethod pool = 2[default = MAX]; // The pooling method
   enum Engine {
-    DEFAULT = 0;
-    CAFFE = 1;
-    CUDNN = 2;
-  }
-  optional Engine engine = 6 [default = DEFAULT];
+    DEFAULT = 0; CAFFE = 1; CUDNN = 2;
+  } optional Engine engine = 6[default = DEFAULT];
 }
 
 message PReLUConf {
@@ -918,13 +881,15 @@ message PReLUConf {
   // Initial value of a_i. Default is a_i=0.25 for all i.
   optional FillerConf filler = 1;
   // Whether or not slope paramters are shared across channels.
-  optional bool channel_shared = 2 [default = false];
+  optional bool channel_shared = 2[default = false];
+  // format of the input. Default is NCHW.
+  optional string format = 50[default = "NCHW"];
 }
 
 message BatchNormConf {
   // Used in the moving average computation runningMean =
   // newMean*factor + runningMean*(1-factor).
-  optional double factor = 1 [default = 0.9];
+  optional double factor = 1[default = 0.9];
   // input shape
   optional int32 channels = 2;
   optional int32 height = 3;

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/5afd81b7/test/singa/test_flatten.cc
----------------------------------------------------------------------
diff --git a/test/singa/test_flatten.cc b/test/singa/test_flatten.cc
new file mode 100644
index 0000000..906e4b8
--- /dev/null
+++ b/test/singa/test_flatten.cc
@@ -0,0 +1,156 @@
+/************************************************************
+*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements.  See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership.  The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License.  You may obtain a copy of the License at
+*
+*   http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied.  See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*
+*************************************************************/
+
+#include "../src/model/layer/flatten.h"
+#include "gtest/gtest.h"
+
+using singa::Flatten;
+TEST(Flatten, Setup) {
+  Flatten flt;
+  EXPECT_EQ("Flatten", flt.layer_type());
+
+  singa::LayerConf conf;
+  singa::FlattenConf *flattenconf = conf.mutable_flatten_conf();
+  flattenconf->set_axis(1);
+
+  flt.Setup(conf);
+  EXPECT_EQ(1, flt.Axis());
+}
+
+TEST(Flatten, ForwardCPU) {
+  const float x[] = { 1.f, 2.f, 3.f, -2.f, -3.f, -4.f, 1.5f, -1.5f, 0.f, -0.5f,
+                      -2.f, -1.f };
+  size_t n = sizeof(x) / sizeof(float);
+  singa::Shape s = { 2, 1, 3, 2 };
+  singa::Tensor in(s);
+  in.CopyDataFromHostPtr<float>(x, n);
+
+  int axis = 3;
+  Flatten flt;
+  singa::LayerConf conf;
+  singa::FlattenConf *flattenconf = conf.mutable_flatten_conf();
+  flattenconf->set_axis(axis);
+  flt.Setup(conf);
+
+  singa::Tensor out = flt.Forward(singa::kTrain, in);
+  EXPECT_EQ(n, out.Size());
+  EXPECT_EQ(6, out.shape(0));
+  EXPECT_EQ(2, out.shape(1));
+  const float *yptr = out.data<const float *>();
+  for (size_t i = 0; i < n; i++)
+    EXPECT_FLOAT_EQ(x[i], yptr[i]);
+}
+
+TEST(Flatten, BackwardCPU) {
+  // directly use input as the output_grad for backward
+  // note that only the shape of input really matters
+  const float dy[] = { 1.f, 2.f, 3.f, -2.f, -3.f, -4.f, 1.5f, -1.5f, 0.f, -0.5f,
+                       -2.f, -1.f };
+  size_t n = sizeof(dy) / sizeof(float);
+  singa::Tensor in(singa::Shape {
+    2, 1, 3, 2
+  });
+  in.CopyDataFromHostPtr<float>(dy, n);
+
+  int axis = 2;
+  Flatten flt;
+  singa::LayerConf conf;
+  singa::FlattenConf *flattenconf = conf.mutable_flatten_conf();
+  flattenconf->set_axis(axis);
+  flt.Setup(conf);
+
+  singa::Tensor temp = flt.Forward(singa::kTrain, in);
+  const auto out = flt.Backward(singa::kTrain, temp);
+  const float *xptr = out.first.data<const float *>();
+  EXPECT_EQ(n, out.first.Size());
+  EXPECT_EQ(2, out.first.shape(0));
+  EXPECT_EQ(1, out.first.shape(1));
+  EXPECT_EQ(3, out.first.shape(2));
+  EXPECT_EQ(2, out.first.shape(3));
+  for (size_t i = 0; i < n; i++)
+    EXPECT_FLOAT_EQ(dy[i], xptr[i]);
+}
+
+#ifdef USE_CUDA
+TEST(Flatten, ForwardGPU) {
+  const float x[] = { 1.f, 2.f, 3.f, -2.f, -3.f, -4.f, 1.5f, -1.5f, 0.f, -0.5f,
+                      -2.f, -1.f };
+  size_t n = sizeof(x) / sizeof(float);
+  singa::CudaGPU cuda(0, 1);
+  singa::Tensor in(singa::Shape {
+    2, 1, 3, 2
+  },
+                   &cuda);
+  in.CopyDataFromHostPtr<float>(x, n);
+
+  int axis = 3;
+  Flatten flt;
+  singa::LayerConf conf;
+  singa::FlattenConf *flattenconf = conf.mutable_flatten_conf();
+  flattenconf->set_axis(axis);
+  flt.Setup(conf);
+
+  singa::Tensor out = flt.Forward(singa::kTrain, in);
+  singa::CppCPU host(0, 1);
+  out.ToDevice(&host);
+  EXPECT_EQ(n, out.Size());
+  EXPECT_EQ(6, out.shape(0));
+  EXPECT_EQ(2, out.shape(1));
+  const float *yptr = out.data<const float *>();
+  for (size_t i = 0; i < n; i++)
+    EXPECT_FLOAT_EQ(x[i], yptr[i]);
+}
+
+TEST(Flatten, BackwardGPU) {
+  // directly use input as the output_grad for backward
+  // note that only the shape of input really matters
+  const float dy[] = { 1.f, 2.f, 3.f, -2.f, -3.f, -4.f, 1.5f, -1.5f, 0.f, -0.5f,
+                       -2.f, -1.f };
+  size_t n = sizeof(dy) / sizeof(float);
+  singa::CudaGPU cuda(0, 1);
+  singa::Tensor in(singa::Shape {
+    2, 1, 3, 2
+  },
+                   &cuda);
+  in.CopyDataFromHostPtr<float>(dy, n);
+
+  int axis = 2;
+  Flatten flt;
+  singa::LayerConf conf;
+  singa::FlattenConf *flattenconf = conf.mutable_flatten_conf();
+  flattenconf->set_axis(axis);
+  flt.Setup(conf);
+
+  singa::Tensor out = flt.Forward(singa::kTrain, in);
+  const auto ret = flt.Backward(singa::kTrain, out);
+  singa::CppCPU host(0, 1);
+  singa::Tensor in_diff = ret.first;
+  in_diff.ToDevice(&host);
+  const float *xptr = in_diff.data<const float *>();
+  EXPECT_EQ(n, in_diff.Size());
+  EXPECT_EQ(2, in_diff.shape(0));
+  EXPECT_EQ(1, in_diff.shape(1));
+  EXPECT_EQ(3, in_diff.shape(2));
+  EXPECT_EQ(2, in_diff.shape(3));
+  for (size_t i = 0; i < n; i++)
+    EXPECT_FLOAT_EQ(dy[i], xptr[i]);
+}
+#endif // USE_CUDA

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/5afd81b7/test/singa/test_prelu.cc
----------------------------------------------------------------------
diff --git a/test/singa/test_prelu.cc b/test/singa/test_prelu.cc
new file mode 100644
index 0000000..2dde9e9
--- /dev/null
+++ b/test/singa/test_prelu.cc
@@ -0,0 +1,149 @@
+/************************************************************
+*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements.  See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership.  The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License.  You may obtain a copy of the License at
+*
+*   http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied.  See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*
+*************************************************************/
+
+#include "../src/model/layer/prelu.h"
+#include "gtest/gtest.h"
+#include "singa_config.h"
+
+using singa::PReLU;
+TEST(PReLU, Setup) {
+  PReLU prelu;
+  EXPECT_EQ("PReLU", prelu.layer_type());
+
+  singa::LayerConf conf;
+  singa::PReLUConf *preluconf = conf.mutable_prelu_conf();
+  preluconf->set_channel_shared(true);
+  preluconf->set_format("NHWC");
+
+  prelu.Setup(conf);
+  EXPECT_EQ(true, prelu.Channel_shared());
+  EXPECT_EQ("NHWC", prelu.Format());
+}
+
+TEST(PReLU, ForwardCPU) {
+  const float x[] = { 1.f, 2.f, 3.f, -2.f, -3.f, -1.f, -1.f, 2.f, -1.f, -2.f,
+                      -2.f, -1.f };
+  size_t n = sizeof(x) / sizeof(float);
+  size_t batchsize = 2, c = 3, h = 2, w = 1;
+  singa::Tensor in(singa::Shape {
+    batchsize, h, w, c
+  });
+  in.CopyDataFromHostPtr<float>(x, n);
+
+  PReLU prelu;
+  singa::LayerConf conf;
+  singa::PReLUConf *preluconf = conf.mutable_prelu_conf();
+  preluconf->set_channel_shared(false);
+  preluconf->set_format("NHWC");
+  prelu.Setup(conf);
+
+  const float neg_slope[] = { 0.25f, 0.5f, 0.75f };
+  singa::Tensor a(singa::Shape {
+    c
+  });
+  a.CopyDataFromHostPtr<float>(neg_slope, c);
+  prelu.Set_a(a);
+
+  singa::Tensor out = prelu.Forward(singa::kTrain, in);
+  const float *yptr = out.data<const float *>();
+  EXPECT_EQ(n, out.Size());
+
+  float *y = new float[n];
+  size_t div_factor = prelu.Channel_shared() ? c : 1;
+  if (prelu.Format() == "NCHW") {
+    for (size_t i = 0; i < n; i++) {
+      size_t pos = i / (h * w) % c / div_factor;
+      y[i] = std::max(x[i], 0.f) + neg_slope[pos] * std::min(x[i], 0.f);
+    }
+  } else if (prelu.Format() == "NHWC") {
+    for (size_t i = 0; i < n; i++) {
+      size_t pos = i % c / div_factor;
+      y[i] = std::max(x[i], 0.f) + neg_slope[pos] * std::min(x[i], 0.f);
+    }
+  }
+  for (size_t i = 0; i < n; i++)
+    EXPECT_FLOAT_EQ(y[i], yptr[i]);
+}
+
+TEST(PReLU, BackwardCPU) {
+  const float x[] = {1.f, 2.f, 3.f, -2.f, -3.f, -1.f, -1.f, 2.f, -1.f, -2.f, -2.f, -1.f};
+  size_t n = sizeof(x) / sizeof(float);
+  size_t batchsize = 2, c = 3, h = 2, w = 1;
+  singa::Tensor in(singa::Shape {
+    batchsize, c, h, w
+  });
+  in.CopyDataFromHostPtr<float>(x, n);
+
+  PReLU prelu;
+  singa::LayerConf conf;
+  singa::PReLUConf *preluconf = conf.mutable_prelu_conf();
+  preluconf->set_channel_shared(false);
+  preluconf->set_format("NCHW");
+  prelu.Setup(conf);
+
+  const float neg_slope[] = { 0.25f, 0.5f, 0.75f };
+  singa::Tensor a(singa::Shape {
+    c
+  });
+  a.CopyDataFromHostPtr<float>(neg_slope, c);
+  prelu.Set_a(a);
+
+  singa::Tensor out = prelu.Forward(singa::kTrain, in);
+
+  const float grad[] = { 1.f, 2.f, -2.f, -1.f, -1.f, -3.f, 2.f, -2.f, 1.f, 1.f,
+                         -2.f, 0.f };
+  singa::Tensor out_diff(singa::Shape {
+    batchsize, c, h, w
+  });
+  out_diff.CopyDataFromHostPtr<float>(grad, n);
+  const auto ret = prelu.Backward(singa::kTrain, out_diff);
+  const float *xptr = ret.first.data<const float *>();
+  const float *aptr = ret.second.at(0).data<const float *>();
+  float *dx = new float[n];
+  size_t div_factor = prelu.Channel_shared() ? c : 1;
+  size_t params = prelu.Channel_shared() ? 1 : c;
+  float da[] = { 0.f, 0.f, 0.f };
+  if (prelu.Format() == "NCHW") {
+    for (size_t i = 0; i < n; i++) {
+      size_t pos = i / (h * w) % c / div_factor;
+      dx[i] = grad[i] *
+              (std::max(x[i], 0.f) + neg_slope[pos] * std::min(x[i], 0.f));
+    }
+    for (size_t i = 0; i < n; i++) {
+      size_t pos = i / (h * w) % c / div_factor;
+      da[pos] += grad[i] * std::min(x[i], 0.f);
+    }
+  } else if (prelu.Format() == "NHWC") {
+    for (size_t i = 0; i < n; i++) {
+      size_t pos = i % c / div_factor;
+      dx[i] = grad[i] *
+              (std::max(x[i], 0.f) + neg_slope[pos] * std::min(x[i], 0.f));
+    }
+    for (size_t i = 0; i < n; i++) {
+      size_t pos = i % c / div_factor;
+      da[pos] += grad[i] * std::min(x[i], 0.f);
+    }
+  }
+  for (size_t i = 0; i < n; i++)
+    EXPECT_FLOAT_EQ(dx[i], xptr[i]);
+  for (size_t i = 0; i < params; i++)
+    EXPECT_FLOAT_EQ(da[i], aptr[i]);
+}


[2/3] incubator-singa git commit: SINGA-190 - Add prelu layer and flatten layer

Posted by wa...@apache.org.
SINGA-190 - Add prelu layer and flatten layer

Format code. Fix warning info from compilation.


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/58be3f80
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/58be3f80
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/58be3f80

Branch: refs/heads/dev
Commit: 58be3f8079e8d00c9fee4e1ce319786cc4e9f225
Parents: 5afd81b
Author: Wei Wang <wa...@comp.nus.edu.sg>
Authored: Sun Jun 12 22:31:46 2016 +0800
Committer: Wei Wang <wa...@comp.nus.edu.sg>
Committed: Sun Jun 12 22:31:46 2016 +0800

----------------------------------------------------------------------
 src/model/layer/flatten.cc |  19 +--
 src/model/layer/flatten.h  |  13 +-
 src/model/layer/prelu.cc   |  62 ++------
 src/model/layer/prelu.h    |  11 +-
 src/proto/model.proto      | 326 ++++++++++++++++++++++------------------
 test/singa/test_flatten.cc |  68 ++++-----
 test/singa/test_prelu.cc   |  46 +++---
 7 files changed, 261 insertions(+), 284 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/58be3f80/src/model/layer/flatten.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/flatten.cc b/src/model/layer/flatten.cc
index 3ed37fe..7341394 100644
--- a/src/model/layer/flatten.cc
+++ b/src/model/layer/flatten.cc
@@ -31,22 +31,15 @@ const Tensor Flatten::Forward(int flag, const Tensor &input) {
   if (!Axis()) {
     // reshape to 1D
     size_t dim = output.Size();
-    output.Reshape(Shape {
-      dim
-    });
-    output_shape_ = Shape { dim }
-    ;
+    output.Reshape(Shape{dim});
+    output_shape_ = Shape{dim};
   } else {
     // reshape to 2D
     size_t dim1 = 1, dim2;
-    for (int i = 0; i < Axis(); i++)
-      dim1 *= output.shape(i);
+    for (int i = 0; i < Axis(); i++) dim1 *= output.shape(i);
     dim2 = output.Size() / dim1;
-    output.Reshape(Shape {
-      dim1, dim2
-    });
-    output_shape_ = Shape { dim1, dim2 }
-    ;
+    output.Reshape(Shape{dim1, dim2});
+    output_shape_ = Shape{dim1, dim2};
   }
   return output;
 }
@@ -55,7 +48,7 @@ const std::pair<Tensor, vector<Tensor> > Flatten::Backward(int flag,
                                                            const Tensor &grad) {
   vector<Tensor> param_grad;
   Tensor input_grad = grad;
-  input_grad.Reshape(Input_shape());
+  input_grad.Reshape(input_shape_);
   return std::make_pair(input_grad, param_grad);
 }
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/58be3f80/src/model/layer/flatten.h
----------------------------------------------------------------------
diff --git a/src/model/layer/flatten.h b/src/model/layer/flatten.h
index cb36542..580b2ba 100644
--- a/src/model/layer/flatten.h
+++ b/src/model/layer/flatten.h
@@ -24,7 +24,7 @@
 
 namespace singa {
 class Flatten : public Layer {
-public:
+ public:
   /// \copydoc Layer::layer_type();
   const std::string layer_type() const override { return "Flatten"; }
 
@@ -35,15 +35,14 @@ public:
   const Tensor Forward(int flag, const Tensor &input) override;
 
   /// \copydoc Layer::Backward(int, const Tensor&, const Tensor&);
-  const std::pair<Tensor, vector<Tensor> > Backward(int flag,
-                                                    const Tensor &grad)
-      override;
+  const std::pair<Tensor, vector<Tensor> > Backward(
+      int flag, const Tensor &grad) override;
 
   const int Axis() const { return axis_; }
-  const Shape Input_shape() const { return input_shape_; }
-  const Shape Output_shape() const { return output_shape_; }
+  const Shape input_shape() const { return input_shape_; }
+  const Shape output_shape() const { return output_shape_; }
 
-protected:
+ protected:
   /// flatten layer reshape the input to 2D, one from 0 to axis_-1, one from
   /// axis_ to end.
   /// if axis_ is 0, reshape the input to 1D.

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/58be3f80/src/model/layer/prelu.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/prelu.cc b/src/model/layer/prelu.cc
index 1d6a2e7..b916bed 100644
--- a/src/model/layer/prelu.cc
+++ b/src/model/layer/prelu.cc
@@ -25,8 +25,7 @@ void PReLU::Setup(const LayerConf &conf) {
   channel_shared_ = conf.prelu_conf().channel_shared();
   format_ = conf.prelu_conf().format();
   // Push back params into param_values_
-  for (const auto &spec : conf.param())
-    param_specs_.push_back(spec);
+  for (const auto &spec : conf.param()) param_specs_.push_back(spec);
   param_values_.push_back(&a_);
 }
 
@@ -41,26 +40,18 @@ const Tensor PReLU::Forward(int flag, const Tensor &input) {
         c = temp.shape(1);
         h = temp.shape(2);
         w = temp.shape(3);
-        temp.Reshape(Shape {
-          n *c, h *w
-        });
-        Tensor temp_a(Shape {
-          n, c
-        });
+        temp.Reshape(Shape{n * c, h * w});
+        Tensor temp_a(Shape{n, c});
         Uniform(1.f, 1.f, &temp_a);
         MultRow(a_, &temp_a);
-        temp_a.Reshape(Shape {
-          n *c
-        });
+        temp_a.Reshape(Shape{n * c});
         MultColumn(temp_a, &temp);
       } else if (format_ == "NHWC") {
         n = temp.shape(0);
         h = temp.shape(1);
         w = temp.shape(2);
         c = temp.shape(3);
-        temp.Reshape(Shape {
-          n *h *w, c
-        });
+        temp.Reshape(Shape{n * h * w, c});
         MultRow(a_, &temp);
       } else {
         LOG(FATAL) << "Incorrect input format for prelu layer.";
@@ -74,8 +65,7 @@ const Tensor PReLU::Forward(int flag, const Tensor &input) {
     const float a = a_.data<const float *>()[0];
     output = input * ((input > 0.f) + (input <= 0.f) * a);
   }
-  if (flag & kTrain)
-    buf_.push(input);
+  if (flag & kTrain) buf_.push(input);
   return output;
 }
 
@@ -96,33 +86,21 @@ const std::pair<Tensor, vector<Tensor> > PReLU::Backward(int flag,
         c = temp1.shape(1);
         h = temp1.shape(2);
         w = temp1.shape(3);
-        temp1.Reshape(Shape {
-          n *c, h *w
-        });
-        Tensor temp_a(Shape {
-          n, c
-        });
+        temp1.Reshape(Shape{n * c, h * w});
+        Tensor temp_a(Shape{n, c});
         Uniform(1.f, 1.f, &temp_a);
         MultRow(a_, &temp_a);
-        temp_a.Reshape(Shape {
-          n *c
-        });
+        temp_a.Reshape(Shape{n * c});
         MultColumn(temp_a, &temp1);
-        temp1.Reshape(Shape {
-          n, c, h, w
-        });
+        temp1.Reshape(Shape{n, c, h, w});
       } else if (format_ == "NHWC") {
         n = temp1.shape(0);
         h = temp1.shape(1);
         w = temp1.shape(2);
         c = temp1.shape(3);
-        temp1.Reshape(Shape {
-          n *h *w, c
-        });
+        temp1.Reshape(Shape{n * h * w, c});
         MultRow(a_, &temp1);
-        temp1.Reshape(Shape {
-          n, h, w, c
-        });
+        temp1.Reshape(Shape{n, h, w, c});
       } else {
         LOG(FATAL) << "Incorrect input format for prelu layer.";
       }
@@ -130,22 +108,14 @@ const std::pair<Tensor, vector<Tensor> > PReLU::Backward(int flag,
       LOG(FATAL) << "Incorrect input format for prelu layer.";
     }
     input_grad = grad * input * ((input > 0.f) + temp1);
-    Tensor temp2 = grad * input * (input <= 0.f), temp3(Shape {
-      n *c
-    });
+    Tensor temp2 = grad * input * (input <= 0.f), temp3(Shape{n * c});
     if (format_ == "NCHW") {
-      temp2.Reshape(Shape {
-        n *c, h *w
-      });
+      temp2.Reshape(Shape{n * c, h * w});
       SumColumns(temp2, &temp3);
-      temp3.Reshape(Shape {
-        n, c
-      });
+      temp3.Reshape(Shape{n, c});
       SumRows(temp3, &da);
     } else if (format_ == "NHWC") {
-      temp2.Reshape(Shape {
-        n *h *w, c
-      });
+      temp2.Reshape(Shape{n * h * w, c});
       SumRows(temp2, &da);
     }
   } else {

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/58be3f80/src/model/layer/prelu.h
----------------------------------------------------------------------
diff --git a/src/model/layer/prelu.h b/src/model/layer/prelu.h
index 1a01d98..d165fe2 100644
--- a/src/model/layer/prelu.h
+++ b/src/model/layer/prelu.h
@@ -26,7 +26,7 @@ namespace singa {
 class PReLU : public Layer {
  public:
   /// \copydoc Layer::layer_type()
-   const std::string layer_type() const override { return "PReLU"; }
+  const std::string layer_type() const override { return "PReLU"; }
 
   /// \copydoc Layer::Setup(const LayerConf&);
   void Setup(const LayerConf &conf) override;
@@ -35,9 +35,8 @@ class PReLU : public Layer {
   const Tensor Forward(int flag, const Tensor &input) override;
 
   /// \copydoc Layer::Backward(int, const Tensor&, const Tensor&);
-  const std::pair<Tensor, vector<Tensor> > Backward(int flag,
-                                                    const Tensor &grad)
-      override;
+  const std::pair<Tensor, vector<Tensor> > Backward(
+      int flag, const Tensor &grad) override;
 
   void ToDevice(Device *device);
 
@@ -52,8 +51,8 @@ class PReLU : public Layer {
 
  protected:
   bool channel_shared_;
-  std::string format_; // format_ has two valid value, i.e. NCHW, NHWC
-  Tensor a_; // shape of a_ is 2D, i.e. (channels, 1)
+  std::string format_;  // format_ has two valid value, i.e. NCHW, NHWC
+  Tensor a_;            // shape of a_ is 2D, i.e. (channels, 1)
   std::stack<Tensor> buf_;
 };
 }  // namespace singa

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/58be3f80/src/proto/model.proto
----------------------------------------------------------------------
diff --git a/src/proto/model.proto b/src/proto/model.proto
index 1d1f3cf..590fdd6 100644
--- a/src/proto/model.proto
+++ b/src/proto/model.proto
@@ -33,59 +33,67 @@ package singa;
 /// using Python (or C++/Java).
 
 // Specifies the shape (dimensions) of a Blob.
-message BlobShape { repeated int64 dim = 1[packed = true]; }
+message BlobShape {
+  repeated int64 dim = 1 [packed = true];
+}
 
 message BlobProto {
   optional BlobShape shape = 7;
-  repeated float data = 5[packed = true];
-  repeated float diff = 6[packed = true];
-  repeated double double_data = 8[packed = true];
-  repeated double double_diff = 9[packed = true];
+  repeated float data = 5 [packed = true];
+  repeated float diff = 6 [packed = true];
+  repeated double double_data = 8 [packed = true];
+  repeated double double_diff = 9 [packed = true];
 
   // 4D dimensions -- deprecated.  Use "shape" instead.
-  optional int32 num = 1[default = 0];
-  optional int32 channels = 2[default = 0];
-  optional int32 height = 3[default = 0];
-  optional int32 width = 4[default = 0];
+  optional int32 num = 1 [default = 0];
+  optional int32 channels = 2 [default = 0];
+  optional int32 height = 3 [default = 0];
+  optional int32 width = 4 [default = 0];
 }
 
 message FillerConf {
   // The filler type, case insensitive
-  optional string type = 1[default = 'constant'];
-  optional float value = 2[default = 0]; // the value in constant filler
-  optional float min = 3[default = 0];   // the min value in uniform filler
-  optional float max = 4[default = 1];   // the max value in uniform filler
-  optional float mean = 5[default = 0];  // the mean value in Gaussian filler
-  optional float std = 6[default = 1];   // the std value in Gaussian filler
+  optional string type = 1 [default = 'constant'];
+  optional float value = 2 [default = 0]; // the value in constant filler
+  optional float min = 3 [default = 0]; // the min value in uniform filler
+  optional float max = 4 [default = 1]; // the max value in uniform filler
+  optional float mean = 5 [default = 0]; // the mean value in Gaussian filler
+  optional float std = 6 [default = 1]; // the std value in Gaussian filler
   // The expected number of non-zero output weights for a given input in
   // Gaussian filler -- the default -1 means don't perform sparsification.
   /* optional int32 sparse = 7 [default = -1]; */
   // Normalize the filler variance by fan_in, fan_out, or their average.
   // Applies to 'xavier' and 'msra' fillers.
   enum VarianceNorm {
-    FAN_IN = 0; FAN_OUT = 1; AVERAGE = 2;
-  } optional VarianceNorm variance_norm = 8[default = FAN_IN];
+    FAN_IN = 0;
+    FAN_OUT = 1;
+    AVERAGE = 2;
+  }
+  optional VarianceNorm variance_norm = 8 [default = FAN_IN];
 }
 
 /// SINGA message
 message OptimizerConf {
   // case insensitive
-  optional string type = 1[default = "sgd"];
+  optional string type = 1 [default = "sgd"];
 
   // used by RMSprop and Adadelta
-  optional float rho = 2[default = 0.001];
+  optional float rho = 2 [default = 0.001];
 
   // used by Adam and AdamMax
-  optional float beta_1 = 3[default = 0.9];
-  optional float beta_2 = 4[default = 0.999];
+  optional float beta_1 = 3 [default = 0.9];
+  optional float beta_2 = 4 [default = 0.999];
 
   // used by vanilla sgd and nesterov
-  optional float momentum = 5[default = 0.9];
+  optional float momentum = 5 [default = 0.9];
+
+  // delta is used to avoid division by zero
+  optional float delta = 6 [default = 1e-8];
 }
 
 message ConstraintConf {
   // case insensitive to limit the parameter value/gradient scale
-  optional string type = 1[default = "l2"];
+  optional string type = 1 [default = "l2"];
   // e.g., the threshold for limiting the parameter scale.
   optional float threshold = 2;
 }
@@ -93,7 +101,7 @@ message ConstraintConf {
 /// SINGA message
 message RegularizerConf {
   // case insensitive to regularize the parameters, e.g., L2.
-  optional string type = 1[default = "l2"];
+  optional string type = 1 [default = "l2"];
   // e.g., the weight decay for L2 regularizer
   optional float coefficient = 2;
 }
@@ -119,10 +127,10 @@ message ParamSpec {
   */
 
   // The multiplier on the global learning rate for this parameter.
-  optional float lr_mult = 3[default = 1.0];
+  optional float lr_mult = 3 [default = 1.0];
 
   // The multiplier on the global weight decay for this parameter.
-  optional float decay_mult = 4[default = 1.0];
+  optional float decay_mult = 4 [default = 1.0];
 
   // SINGA uses this filed internally. Users just configure the fillers in
   // Layer specific conf message as caffe (style).
@@ -132,13 +140,14 @@ message ParamSpec {
 }
 
 enum Phase {
-  kTrain = 4; kEval = 8;
-}
-    // NOTE
-    // Update the next available ID when you add a new LayerConf field.
-    //
-    // LayerConf next available layer-specific ID: 139 (last added: tile_param)
-    message LayerConf {
+  kTrain = 4;
+  kEval = 8;
+}
+// NOTE
+// Update the next available ID when you add a new LayerConf field.
+//
+// LayerConf next available layer-specific ID: 139 (last added: tile_param)
+message LayerConf {
   optional string name = 1; // the layer name
   optional string type = 2; // the layer type
   /* repeated string bottom = 3; // the name of each bottom blob */
@@ -242,8 +251,7 @@ message TransformationConf {
   optional uint32 crop_size = 3 [default = 0];
   // mean_file and mean_value cannot be specified at the same time
   optional string mean_file = 4;
-  // if specified can be repeated once (would substract it from all the
-channels)
+  // if specified can be repeated once (would subtract it from all the channels)
   // or can be repeated the same number of times as channels
   // (would subtract them from the corresponding channel)
   repeated float mean_value = 5;
@@ -260,33 +268,34 @@ message LossConf {
   optional int32 ignore_label = 1;
   // If true, normalize each batch across all instances (including spatial
   // dimesions, but not ignored instances); else, divide by batch size only.
-  optional bool normalize = 2[default = true];
+  optional bool normalize = 2 [default = true];
 }
 
 message MetricConf {
   // When computing accuracy, count as correct by comparing the true label to
   // the top k scoring classes.  By default, only compare to the top scoring
   // class (i.e. argmax).
-  optional uint32 top_k = 1[default = 1];
+  optional uint32 top_k = 1 [default = 1];
 
   // The "label" axis of the prediction blob, whose argmax corresponds to the
   // predicted label -- may be negative to index from the end (e.g., -1 for the
   // last axis).  For example, if axis == 1 and the predictions are
   // (N x C x H x W), the label blob is expected to contain N*H*W ground truth
   // labels with integer values in {0, 1, ..., C-1}.
-  optional int32 axis = 2[default = 1];
+  optional int32 axis = 2 [default = 1];
 
   // If specified, ignore instances with the given label.
   optional int32 ignore_label = 3;
 }
-// Messages that store hyper-parameters used by individual layer types follow,
-// in
+// Messages that store hyper-parameters used by individual layer types follow, in
 // alphabetical order.
 
+
+
 message ArgMaxConf {
   // If true produce pairs (argmax, maxval)
-  optional bool out_max_val = 1[default = false];
-  optional uint32 top_k = 2[default = 1];
+  optional bool out_max_val = 1 [default = false];
+  optional uint32 top_k = 2 [default = 1];
   // The axis along which to maximise -- may be negative to index from the
   // end (e.g., -1 for the last axis).
   // By default ArgMaxLayer maximizes over the flattened trailing dimensions
@@ -299,51 +308,54 @@ message ConcatConf {
   // end (e.g., -1 for the last axis).  Other axes must have the
   // same dimension for all the bottom blobs.
   // By default, ConcatLayer concatenates blobs along the "channels" axis (1).
-  optional int32 axis = 2[default = 1];
+  optional int32 axis = 2 [default = 1];
 
   // DEPRECATED: alias for "axis" -- does not support negative indexing.
-  optional uint32 concat_dim = 1[default = 1];
+  optional uint32 concat_dim = 1 [default = 1];
 }
 
 message ContrastiveLossConf {
   // margin for dissimilar pair
-  optional float margin = 1[default = 1.0];
+  optional float margin = 1 [default = 1.0];
   // The first implementation of this cost did not exactly match the cost of
   // Hadsell et al 2006 -- using (margin - d^2) instead of (margin - d)^2.
   // legacy_version = false (the default) uses (margin - d)^2 as proposed in the
   // Hadsell paper. New models should probably use this version.
   // legacy_version = true uses (margin - d^2). This is kept to support /
   // reproduce existing models and results
-  optional bool legacy_version = 2[default = false];
+  optional bool legacy_version = 2 [default = false];
 }
 
 message ConvolutionConf {
   optional uint32 num_output = 1; // The number of outputs for the layer
-  optional bool bias_term = 2[default = true]; // whether to have bias terms
+  optional bool bias_term = 2 [default = true]; // whether to have bias terms
 
   // Pad, kernel size, and stride are all given as a single value for equal
   // dimensions in all spatial dimensions, or once per spatial dimension.
-  repeated uint32 pad = 3;         // The padding size; defaults to 0
+  repeated uint32 pad = 3; // The padding size; defaults to 0
   repeated uint32 kernel_size = 4; // The kernel size
-  repeated uint32 stride = 6;      // The stride; defaults to 1
+  repeated uint32 stride = 6; // The stride; defaults to 1
 
   // For 2D convolution only, the *_h and *_w versions may also be used to
   // specify both spatial dimensions.
-  optional uint32 pad_h = 9[default = 0];  // The padding height (2D only)
-  optional uint32 pad_w = 10[default = 0]; // The padding width (2D only)
-  optional uint32 kernel_h = 11;           // The kernel height (2D only)
-  optional uint32 kernel_w = 12;           // The kernel width (2D only)
-  optional uint32 stride_h = 13;           // The stride height (2D only)
-  optional uint32 stride_w = 14;           // The stride width (2D only)
+  optional uint32 pad_h = 9 [default = 0]; // The padding height (2D only)
+  optional uint32 pad_w = 10 [default = 0]; // The padding width (2D only)
+  optional uint32 kernel_h = 11; // The kernel height (2D only)
+  optional uint32 kernel_w = 12; // The kernel width (2D only)
+  optional uint32 stride_h = 13; // The stride height (2D only)
+  optional uint32 stride_w = 14; // The stride width (2D only)
 
   // SINGA: not supported.
   // optional uint32 group = 5 [default = 1]; // The group size for group conv
 
   optional FillerConf weight_filler = 7; // The filler for the weight
-  optional FillerConf bias_filler = 8;   // The filler for the bias
+  optional FillerConf bias_filler = 8; // The filler for the bias
   enum Engine {
-    DEFAULT = 0; CAFFE = 1; CUDNN = 2;
-  } optional Engine engine = 15[default = DEFAULT];
+    DEFAULT = 0;
+    CAFFE = 1;
+    CUDNN = 2;
+  }
+  optional Engine engine = 15 [default = DEFAULT];
 
   // The axis to interpret as "channels" when performing convolution.
   // Preceding dimensions are treated as independent inputs;
@@ -365,12 +377,13 @@ message ConvolutionConf {
   // SINGA: not supported;
   // optional bool force_nd_im2col = 17 [default = false];
 
+
   // SINGA: add by xiangrui
   // cudnn workspace size in MB
-  optional int32 workspace_byte_limit = 50[default = 512];
+  optional int32 workspace_byte_limit = 50 [default = 512];
   // cudnn algorithm preference
   // options: "fastest", "limited_workspace", "no_workspace"
-  optional string prefer = 51[default = "fastest"];
+  optional string prefer = 51 [default = "fastest"];
   // input shape
   optional int32 channels = 52;
   optional int32 height = 53;
@@ -414,7 +427,7 @@ message DataConf {
 */
 
 message DropoutConf {
-  optional float dropout_ratio = 1[default = 0.5]; // dropout ratio
+  optional float dropout_ratio = 1 [default = 0.5]; // dropout ratio
 }
 
 // DummyDataLayer fills any number of arbitrarily shaped blobs with random
@@ -438,13 +451,16 @@ message DummyDataConf {
 
 message EltwiseConf {
   enum EltwiseOp {
-    PROD = 0; SUM = 1; MAX = 2;
-  } optional EltwiseOp operation = 1[default = SUM]; // element-wise operation
+    PROD = 0;
+    SUM = 1;
+    MAX = 2;
+  }
+  optional EltwiseOp operation = 1 [default = SUM]; // element-wise operation
   repeated float coeff = 2; // blob-wise coefficient for SUM operation
 
   // Whether to use an asymptotically slower (for >2 inputs) but stabler method
   // of computing the gradient for the PROD operation. (No effect for SUM op.)
-  optional bool stable_prod_grad = 3[default = true];
+  optional bool stable_prod_grad = 3 [default = true];
 }
 
 // Message that stores hyper-parameters used by EmbedLayer
@@ -455,9 +471,9 @@ message EmbedConf {
   // 1 greater than the maximum possible input value.
   optional uint32 input_dim = 2;
 
-  optional bool bias_term = 3[default = true]; // Whether to use a bias term
-  optional FillerConf weight_filler = 4;       // The filler for the weight
-  optional FillerConf bias_filler = 5;         // The filler for the bias
+  optional bool bias_term = 3 [default = true]; // Whether to use a bias term
+  optional FillerConf weight_filler = 4; // The filler for the weight
+  optional FillerConf bias_filler = 5; // The filler for the bias
 
 }
 
@@ -466,21 +482,21 @@ message ExpConf {
   // ExpLayer computes outputs y = base ^ (shift + scale * x), for base > 0.
   // Or if base is set to the default (-1), base is set to e,
   // so y = exp(shift + scale * x).
-  optional float base = 1[default = -1.0];
-  optional float scale = 2[default = 1.0];
-  optional float shift = 3[default = 0.0];
+  optional float base = 1 [default = -1.0];
+  optional float scale = 2 [default = 1.0];
+  optional float shift = 3 [default = 0.0];
 }
 
 /// Message that stores hyper-parameters used by FlattenLayer
 message FlattenConf {
   // The first axis to flatten: all preceding axes are retained in the output.
   // May be negative to index from the end (e.g., -1 for the last axis).
-  optional int32 axis = 1[default = 1];
+  optional int32 axis = 1 [default = 1];
 
   // The last axis to flatten: all following axes are retained in the output.
   // May be negative to index from the end (e.g., the default -1 for the last
   // axis).
-  optional int32 end_axis = 2[default = -1];
+  optional int32 end_axis = 2 [default = -1];
 }
 
 /*
@@ -506,10 +522,11 @@ message HDF5OutputConf {
 
 message HingeLossConf {
   enum Norm {
-    L1 = 1; L2 = 2;
+    L1 = 1;
+    L2 = 2;
   }
-      // Specify the Norm to use L1 or L2
-      optional Norm norm = 1[default = L1];
+  // Specify the Norm to use L1 or L2
+  optional Norm norm = 1 [default = L1];
 }
 
 /*
@@ -552,29 +569,29 @@ message InfogainLossConf {
 
 message InnerProductConf {
   optional uint32 num_output = 1; // The number of outputs for the layer
-  optional bool bias_term = 2[default = true]; // whether to have bias terms
-  optional FillerConf weight_filler = 3;       // The filler for the weight
-  optional FillerConf bias_filler = 4;         // The filler for the bias
+  optional bool bias_term = 2 [default = true]; // whether to have bias terms
+  optional FillerConf weight_filler = 3; // The filler for the weight
+  optional FillerConf bias_filler = 4; // The filler for the bias
 
   // The first axis to be lumped into a single inner product computation;
   // all preceding axes are retained in the output.
   // May be negative to index from the end (e.g., -1 for the last axis).
-  optional int32 axis = 5[default = 1];
+  optional int32 axis = 5 [default = 1];
 }
 
 message DenseConf {
   optional uint32 num_output = 1; // The number of outputs for the layer
-  optional bool bias_term = 2[default = true]; // whether to have bias terms
-  optional FillerConf weight_filler = 3;       // The filler for the weight
-  optional FillerConf bias_filler = 4;         // The filler for the bias
+  optional bool bias_term = 2 [default = true]; // whether to have bias terms
+  optional FillerConf weight_filler = 3; // The filler for the weight
+  optional FillerConf bias_filler = 4; // The filler for the bias
 
   // The first axis to be lumped into a single inner product computation;
   // all preceding axes are retained in the output.
   // May be negative to index from the end (e.g., -1 for the last axis).
-  optional int32 axis = 5[default = 1];
+  optional int32 axis = 5 [default = 1];
 
   optional uint32 num_input = 20; // The number of inputs for the layer
-  optional bool transpose = 21[default = false]; // whether transpose or not
+  optional bool transpose = 21 [default = false]; // whether transpose or not
 }
 
 // Message that stores hyper-parameters used by LogLayer
@@ -582,20 +599,22 @@ message LogConf {
   // LogLayer computes outputs y = log_base(shift + scale * x), for base > 0.
   // Or if base is set to the default (-1), base is set to e,
   // so y = ln(shift + scale * x) = log_e(shift + scale * x)
-  optional float base = 1[default = -1.0];
-  optional float scale = 2[default = 1.0];
-  optional float shift = 3[default = 0.0];
+  optional float base = 1 [default = -1.0];
+  optional float scale = 2 [default = 1.0];
+  optional float shift = 3 [default = 0.0];
 }
 
 // Message that stores hyper-parameters used by LRNLayer
 message LRNConf {
-  optional uint32 local_size = 1[default = 5];
-  optional float alpha = 2[default = 1.];
-  optional float beta = 3[default = 0.75];
+  optional uint32 local_size = 1 [default = 5];
+  optional float alpha = 2 [default = 1.];
+  optional float beta = 3 [default = 0.75];
   enum NormRegion {
-    ACROSS_CHANNELS = 0; WITHIN_CHANNEL = 1;
-  } optional NormRegion norm_region = 4[default = ACROSS_CHANNELS];
-  optional float k = 5[default = 1.];
+    ACROSS_CHANNELS = 0;
+    WITHIN_CHANNEL = 1;
+  }
+  optional NormRegion norm_region = 4 [default = ACROSS_CHANNELS];
+  optional float k = 5 [default = 1.];
 }
 
 message MemoryDataConf {
@@ -607,30 +626,33 @@ message MemoryDataConf {
 
 message MVNConf {
   // This parameter can be set to false to normalize mean only
-  optional bool normalize_variance = 1[default = true];
+  optional bool normalize_variance = 1 [default = true];
 
   // This parameter can be set to true to perform DNN-like MVN
-  optional bool across_channels = 2[default = false];
+  optional bool across_channels = 2 [default = false];
 
   // Epsilon for not dividing by zero while normalizing variance
-  optional float eps = 3[default = 1e-9];
+  optional float eps = 3 [default = 1e-9];
 }
 
 message PoolingConf {
   enum PoolMethod {
-    MAX = 0; AVE = 1; STOCHASTIC = 2;
-  } optional PoolMethod pool = 1[default = MAX]; // The pooling method
+    MAX = 0;
+    AVE = 1;
+    STOCHASTIC = 2;
+  }
+  optional PoolMethod pool = 1 [default = MAX]; // The pooling method
   // Pad, kernel size, and stride are all given as a single value for equal
   // dimensions in height and width or as Y, X pairs.
-  optional uint32 pad = 4[default = 0];    // The padding size (equal in Y, X)
-  optional uint32 pad_h = 9[default = 0];  // The padding height
-  optional uint32 pad_w = 10[default = 0]; // The padding width
-  optional uint32 kernel_size = 2;         // The kernel size (square)
-  optional uint32 kernel_h = 5;            // The kernel height
-  optional uint32 kernel_w = 6;            // The kernel width
-  optional uint32 stride = 3[default = 1]; // The stride (equal in Y, X)
-  optional uint32 stride_h = 7;            // The stride height
-  optional uint32 stride_w = 8;            // The stride width
+  optional uint32 pad = 4 [default = 0]; // The padding size (equal in Y, X)
+  optional uint32 pad_h = 9 [default = 0]; // The padding height
+  optional uint32 pad_w = 10 [default = 0]; // The padding width
+  optional uint32 kernel_size = 2; // The kernel size (square)
+  optional uint32 kernel_h = 5; // The kernel height
+  optional uint32 kernel_w = 6; // The kernel width
+  optional uint32 stride = 3 [default = 1]; // The stride (equal in Y, X)
+  optional uint32 stride_h = 7; // The stride height
+  optional uint32 stride_w = 8; // The stride width
   /*
   enum Engine {
     DEFAULT = 0;
@@ -641,20 +663,20 @@ message PoolingConf {
   */
   // If global_pooling then it will pool over the size of the bottom by doing
   // kernel_h = bottom->height and kernel_w = bottom->width
-  optional bool global_pooling = 12[default = false];
+  optional bool global_pooling = 12 [default = false];
   // Shape of source
   optional int32 channels = 50;
   optional int32 height = 51;
   optional int32 width = 52;
   // whether to propagate nan
-  optional bool nan_prop = 53[default = false];
+  optional bool nan_prop = 53 [default = false];
 }
 
 message PowerConf {
   // PowerLayer computes outputs y = (shift + scale * x) ^ power.
-  optional float power = 1[default = 1.0];
-  optional float scale = 2[default = 1.0];
-  optional float shift = 3[default = 0.0];
+  optional float power = 1 [default = 1.0];
+  optional float scale = 2 [default = 1.0];
+  optional float shift = 3 [default = 0.0];
 }
 /*
 message PythonConf {
@@ -665,8 +687,7 @@ message PythonConf {
   // string, dictionary in Python dict format, JSON, etc. You may parse this
   // string in `setup` method and use it in `forward` and `backward`.
   optional string param_str = 3 [default = ''];
-  // Whether this PythonLayer is shared among worker solvers during data
-parallelism.
+  // Whether this PythonLayer is shared among worker solvers during data parallelism.
   // If true, each worker solver sequentially run forward from this layer.
   // This value should be set true if you are using it as a data layer.
   optional bool share_in_parallel = 4 [default = false];
@@ -676,8 +697,13 @@ parallelism.
 // Message that stores hyper-parameters used by ReductionLayer
 message ReductionConf {
   enum ReductionOp {
-    SUM = 1; ASUM = 2; SUMSQ = 3; MEAN = 4;
-  } optional ReductionOp operation = 1[default = SUM]; // reduction operation
+    SUM = 1;
+    ASUM = 2;
+    SUMSQ = 3;
+    MEAN = 4;
+  }
+
+  optional ReductionOp operation = 1 [default = SUM]; // reduction operation
 
   // The first axis to reduce to a scalar -- may be negative to index from the
   // end (e.g., -1 for the last axis).
@@ -692,9 +718,9 @@ message ReductionConf {
   // If axis == 0 (the default), the output Blob always has the empty shape
   // (count 1), performing reduction across the entire input --
   // often useful for creating new loss functions.
-  optional int32 axis = 2[default = 0];
+  optional int32 axis = 2 [default = 0];
 
-  optional float coeff = 3[default = 1.0]; // coefficient for output
+  optional float coeff = 3 [default = 1.0]; // coefficient for output
 }
 
 // Message that stores hyper-parameters used by ReLULayer
@@ -704,7 +730,7 @@ message ReLUConf {
   // Maas, A. L., Hannun, A. Y., & Ng, A. Y. (2013). Rectifier nonlinearities
   // improve neural network acoustic models. In ICML Workshop on Deep Learning
   // for Audio, Speech, and Language Processing.
-  optional float negative_slope = 1[default = 0];
+  optional float negative_slope = 1 [default = 0];
   /*
   enum Engine {
     DEFAULT = 0;
@@ -775,50 +801,58 @@ message ReshapeConf {
   //   reshape_param { shape { dim: 2  dim: 1  dim: 8  }  }
   //   reshape_param { shape { dim: 1 }  axis: 1  num_axes: 0 }
   //
-  optional int32 axis = 2[default = 0];
-  optional int32 num_axes = 3[default = -1];
+  optional int32 axis = 2 [default = 0];
+  optional int32 num_axes = 3 [default = -1];
 }
 
 message SigmoidConf {
   enum Engine {
-    DEFAULT = 0; CAFFE = 1; CUDNN = 2;
-  } optional Engine engine = 1[default = DEFAULT];
+    DEFAULT = 0;
+    CAFFE = 1;
+    CUDNN = 2;
+  }
+  optional Engine engine = 1 [default = DEFAULT];
 }
 
 message SliceConf {
   // The axis along which to slice -- may be negative to index from the end
   // (e.g., -1 for the last axis).
   // By default, SliceLayer concatenates blobs along the "channels" axis (1).
-  optional int32 axis = 3[default = 1];
+  optional int32 axis = 3 [default = 1];
   repeated uint32 slice_point = 2;
 
   // DEPRECATED: alias for "axis" -- does not support negative indexing.
-  optional uint32 slice_dim = 1[default = 1];
+  optional uint32 slice_dim = 1 [default = 1];
 }
 
-// Message that stores hyper-parameters used by SoftmaxLayer,
-// SoftmaxWithLossLayer
+// Message that stores hyper-parameters used by SoftmaxLayer, SoftmaxWithLossLayer
 message SoftmaxConf {
   enum Engine {
-    DEFAULT = 0; CAFFE = 1; CUDNN = 2;
-  } optional Engine engine = 1[default = DEFAULT];
+    DEFAULT = 0;
+    CAFFE = 1;
+    CUDNN = 2;
+  }
+  optional Engine engine = 1 [default = DEFAULT];
 
   // The axis along which to perform the softmax -- may be negative to index
   // from the end (e.g., -1 for the last axis).
   // Any other axes will be evaluated as independent softmaxes.
-  optional int32 axis = 2[default = 1];
+  optional int32 axis = 2 [default = 1];
 }
 
 message TanHConf {
   enum Engine {
-    DEFAULT = 0; CAFFE = 1; CUDNN = 2;
-  } optional Engine engine = 1[default = DEFAULT];
+    DEFAULT = 0;
+    CAFFE = 1;
+    CUDNN = 2;
+  }
+  optional Engine engine = 1 [default = DEFAULT];
 }
 
 // Message that stores hyper-parameters used by TileLayer
 message TileConf {
   // The index of the axis to tile.
-  optional int32 axis = 1[default = 1];
+  optional int32 axis = 1 [default = 1];
 
   // The number of copies (tiles) of the blob to output.
   optional int32 tiles = 2;
@@ -826,7 +860,7 @@ message TileConf {
 
 // Message that stores hyper-parameters used by ThresholdLayer
 message ThresholdConf {
-  optional float threshold = 1[default = 0]; // Strictly positive values
+  optional float threshold = 1 [default = 0]; // Strictly positive values
 }
 
 /*
@@ -866,12 +900,18 @@ message WindowDataConf {
 
 message SPPConf {
   enum PoolMethod {
-    MAX = 0; AVE = 1; STOCHASTIC = 2;
-  } optional uint32 pyramid_height = 1;
-  optional PoolMethod pool = 2[default = MAX]; // The pooling method
+    MAX = 0;
+    AVE = 1;
+    STOCHASTIC = 2;
+  }
+  optional uint32 pyramid_height = 1;
+  optional PoolMethod pool = 2 [default = MAX]; // The pooling method
   enum Engine {
-    DEFAULT = 0; CAFFE = 1; CUDNN = 2;
-  } optional Engine engine = 6[default = DEFAULT];
+    DEFAULT = 0;
+    CAFFE = 1;
+    CUDNN = 2;
+  }
+  optional Engine engine = 6 [default = DEFAULT];
 }
 
 message PReLUConf {
@@ -881,15 +921,15 @@ message PReLUConf {
   // Initial value of a_i. Default is a_i=0.25 for all i.
   optional FillerConf filler = 1;
   // Whether or not slope paramters are shared across channels.
-  optional bool channel_shared = 2[default = false];
-  // format of the input. Default is NCHW.
-  optional string format = 50[default = "NCHW"];
+  optional bool channel_shared = 2 [default = false];
+  // format of the input, either "NCHW" or "NHWC". Default is NCHW.
+  optional string format = 20 [default = "NCHW"];
 }
 
 message BatchNormConf {
   // Used in the moving average computation runningMean =
   // newMean*factor + runningMean*(1-factor).
-  optional double factor = 1[default = 0.9];
+  optional double factor = 1 [default = 0.9];
   // input shape
   optional int32 channels = 2;
   optional int32 height = 3;

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/58be3f80/test/singa/test_flatten.cc
----------------------------------------------------------------------
diff --git a/test/singa/test_flatten.cc b/test/singa/test_flatten.cc
index 906e4b8..0ba8d3c 100644
--- a/test/singa/test_flatten.cc
+++ b/test/singa/test_flatten.cc
@@ -36,10 +36,10 @@ TEST(Flatten, Setup) {
 }
 
 TEST(Flatten, ForwardCPU) {
-  const float x[] = { 1.f, 2.f, 3.f, -2.f, -3.f, -4.f, 1.5f, -1.5f, 0.f, -0.5f,
-                      -2.f, -1.f };
+  const float x[] = {1.f,  2.f,   3.f, -2.f,  -3.f, -4.f,
+                     1.5f, -1.5f, 0.f, -0.5f, -2.f, -1.f};
   size_t n = sizeof(x) / sizeof(float);
-  singa::Shape s = { 2, 1, 3, 2 };
+  singa::Shape s = {2, 1, 3, 2};
   singa::Tensor in(s);
   in.CopyDataFromHostPtr<float>(x, n);
 
@@ -52,22 +52,19 @@ TEST(Flatten, ForwardCPU) {
 
   singa::Tensor out = flt.Forward(singa::kTrain, in);
   EXPECT_EQ(n, out.Size());
-  EXPECT_EQ(6, out.shape(0));
-  EXPECT_EQ(2, out.shape(1));
+  EXPECT_EQ(6u, out.shape(0));
+  EXPECT_EQ(2u, out.shape(1));
   const float *yptr = out.data<const float *>();
-  for (size_t i = 0; i < n; i++)
-    EXPECT_FLOAT_EQ(x[i], yptr[i]);
+  for (size_t i = 0; i < n; i++) EXPECT_FLOAT_EQ(x[i], yptr[i]);
 }
 
 TEST(Flatten, BackwardCPU) {
   // directly use input as the output_grad for backward
   // note that only the shape of input really matters
-  const float dy[] = { 1.f, 2.f, 3.f, -2.f, -3.f, -4.f, 1.5f, -1.5f, 0.f, -0.5f,
-                       -2.f, -1.f };
+  const float dy[] = {1.f,  2.f,   3.f, -2.f,  -3.f, -4.f,
+                      1.5f, -1.5f, 0.f, -0.5f, -2.f, -1.f};
   size_t n = sizeof(dy) / sizeof(float);
-  singa::Tensor in(singa::Shape {
-    2, 1, 3, 2
-  });
+  singa::Tensor in(singa::Shape{2, 1, 3, 2});
   in.CopyDataFromHostPtr<float>(dy, n);
 
   int axis = 2;
@@ -81,24 +78,20 @@ TEST(Flatten, BackwardCPU) {
   const auto out = flt.Backward(singa::kTrain, temp);
   const float *xptr = out.first.data<const float *>();
   EXPECT_EQ(n, out.first.Size());
-  EXPECT_EQ(2, out.first.shape(0));
-  EXPECT_EQ(1, out.first.shape(1));
-  EXPECT_EQ(3, out.first.shape(2));
-  EXPECT_EQ(2, out.first.shape(3));
-  for (size_t i = 0; i < n; i++)
-    EXPECT_FLOAT_EQ(dy[i], xptr[i]);
+  EXPECT_EQ(2u, out.first.shape(0));
+  EXPECT_EQ(1u, out.first.shape(1));
+  EXPECT_EQ(3u, out.first.shape(2));
+  EXPECT_EQ(2u, out.first.shape(3));
+  for (size_t i = 0; i < n; i++) EXPECT_FLOAT_EQ(dy[i], xptr[i]);
 }
 
 #ifdef USE_CUDA
 TEST(Flatten, ForwardGPU) {
-  const float x[] = { 1.f, 2.f, 3.f, -2.f, -3.f, -4.f, 1.5f, -1.5f, 0.f, -0.5f,
-                      -2.f, -1.f };
+  const float x[] = {1.f,  2.f,   3.f, -2.f,  -3.f, -4.f,
+                     1.5f, -1.5f, 0.f, -0.5f, -2.f, -1.f};
   size_t n = sizeof(x) / sizeof(float);
   singa::CudaGPU cuda(0, 1);
-  singa::Tensor in(singa::Shape {
-    2, 1, 3, 2
-  },
-                   &cuda);
+  singa::Tensor in(singa::Shape{2, 1, 3, 2}, &cuda);
   in.CopyDataFromHostPtr<float>(x, n);
 
   int axis = 3;
@@ -112,24 +105,20 @@ TEST(Flatten, ForwardGPU) {
   singa::CppCPU host(0, 1);
   out.ToDevice(&host);
   EXPECT_EQ(n, out.Size());
-  EXPECT_EQ(6, out.shape(0));
-  EXPECT_EQ(2, out.shape(1));
+  EXPECT_EQ(6u, out.shape(0));
+  EXPECT_EQ(2u, out.shape(1));
   const float *yptr = out.data<const float *>();
-  for (size_t i = 0; i < n; i++)
-    EXPECT_FLOAT_EQ(x[i], yptr[i]);
+  for (size_t i = 0; i < n; i++) EXPECT_FLOAT_EQ(x[i], yptr[i]);
 }
 
 TEST(Flatten, BackwardGPU) {
   // directly use input as the output_grad for backward
   // note that only the shape of input really matters
-  const float dy[] = { 1.f, 2.f, 3.f, -2.f, -3.f, -4.f, 1.5f, -1.5f, 0.f, -0.5f,
-                       -2.f, -1.f };
+  const float dy[] = {1.f,  2.f,   3.f, -2.f,  -3.f, -4.f,
+                      1.5f, -1.5f, 0.f, -0.5f, -2.f, -1.f};
   size_t n = sizeof(dy) / sizeof(float);
   singa::CudaGPU cuda(0, 1);
-  singa::Tensor in(singa::Shape {
-    2, 1, 3, 2
-  },
-                   &cuda);
+  singa::Tensor in(singa::Shape{2, 1, 3, 2}, &cuda);
   in.CopyDataFromHostPtr<float>(dy, n);
 
   int axis = 2;
@@ -146,11 +135,10 @@ TEST(Flatten, BackwardGPU) {
   in_diff.ToDevice(&host);
   const float *xptr = in_diff.data<const float *>();
   EXPECT_EQ(n, in_diff.Size());
-  EXPECT_EQ(2, in_diff.shape(0));
-  EXPECT_EQ(1, in_diff.shape(1));
-  EXPECT_EQ(3, in_diff.shape(2));
-  EXPECT_EQ(2, in_diff.shape(3));
-  for (size_t i = 0; i < n; i++)
-    EXPECT_FLOAT_EQ(dy[i], xptr[i]);
+  EXPECT_EQ(2u, in_diff.shape(0));
+  EXPECT_EQ(1u, in_diff.shape(1));
+  EXPECT_EQ(3u, in_diff.shape(2));
+  EXPECT_EQ(2u, in_diff.shape(3));
+  for (size_t i = 0; i < n; i++) EXPECT_FLOAT_EQ(dy[i], xptr[i]);
 }
 #endif // USE_CUDA

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/58be3f80/test/singa/test_prelu.cc
----------------------------------------------------------------------
diff --git a/test/singa/test_prelu.cc b/test/singa/test_prelu.cc
index 2dde9e9..6fc372b 100644
--- a/test/singa/test_prelu.cc
+++ b/test/singa/test_prelu.cc
@@ -39,13 +39,11 @@ TEST(PReLU, Setup) {
 }
 
 TEST(PReLU, ForwardCPU) {
-  const float x[] = { 1.f, 2.f, 3.f, -2.f, -3.f, -1.f, -1.f, 2.f, -1.f, -2.f,
-                      -2.f, -1.f };
+  const float x[] = {1.f,  2.f, 3.f,  -2.f, -3.f, -1.f,
+                     -1.f, 2.f, -1.f, -2.f, -2.f, -1.f};
   size_t n = sizeof(x) / sizeof(float);
   size_t batchsize = 2, c = 3, h = 2, w = 1;
-  singa::Tensor in(singa::Shape {
-    batchsize, h, w, c
-  });
+  singa::Tensor in(singa::Shape{batchsize, h, w, c});
   in.CopyDataFromHostPtr<float>(x, n);
 
   PReLU prelu;
@@ -55,10 +53,8 @@ TEST(PReLU, ForwardCPU) {
   preluconf->set_format("NHWC");
   prelu.Setup(conf);
 
-  const float neg_slope[] = { 0.25f, 0.5f, 0.75f };
-  singa::Tensor a(singa::Shape {
-    c
-  });
+  const float neg_slope[] = {0.25f, 0.5f, 0.75f};
+  singa::Tensor a(singa::Shape{c});
   a.CopyDataFromHostPtr<float>(neg_slope, c);
   prelu.Set_a(a);
 
@@ -79,17 +75,15 @@ TEST(PReLU, ForwardCPU) {
       y[i] = std::max(x[i], 0.f) + neg_slope[pos] * std::min(x[i], 0.f);
     }
   }
-  for (size_t i = 0; i < n; i++)
-    EXPECT_FLOAT_EQ(y[i], yptr[i]);
+  for (size_t i = 0; i < n; i++) EXPECT_FLOAT_EQ(y[i], yptr[i]);
 }
 
 TEST(PReLU, BackwardCPU) {
-  const float x[] = {1.f, 2.f, 3.f, -2.f, -3.f, -1.f, -1.f, 2.f, -1.f, -2.f, -2.f, -1.f};
+  const float x[] = {1.f,  2.f, 3.f,  -2.f, -3.f, -1.f,
+                     -1.f, 2.f, -1.f, -2.f, -2.f, -1.f};
   size_t n = sizeof(x) / sizeof(float);
   size_t batchsize = 2, c = 3, h = 2, w = 1;
-  singa::Tensor in(singa::Shape {
-    batchsize, c, h, w
-  });
+  singa::Tensor in(singa::Shape{batchsize, c, h, w});
   in.CopyDataFromHostPtr<float>(x, n);
 
   PReLU prelu;
@@ -99,20 +93,16 @@ TEST(PReLU, BackwardCPU) {
   preluconf->set_format("NCHW");
   prelu.Setup(conf);
 
-  const float neg_slope[] = { 0.25f, 0.5f, 0.75f };
-  singa::Tensor a(singa::Shape {
-    c
-  });
+  const float neg_slope[] = {0.25f, 0.5f, 0.75f};
+  singa::Tensor a(singa::Shape{c});
   a.CopyDataFromHostPtr<float>(neg_slope, c);
   prelu.Set_a(a);
 
   singa::Tensor out = prelu.Forward(singa::kTrain, in);
 
-  const float grad[] = { 1.f, 2.f, -2.f, -1.f, -1.f, -3.f, 2.f, -2.f, 1.f, 1.f,
-                         -2.f, 0.f };
-  singa::Tensor out_diff(singa::Shape {
-    batchsize, c, h, w
-  });
+  const float grad[] = {1.f, 2.f,  -2.f, -1.f, -1.f, -3.f,
+                        2.f, -2.f, 1.f,  1.f,  -2.f, 0.f};
+  singa::Tensor out_diff(singa::Shape{batchsize, c, h, w});
   out_diff.CopyDataFromHostPtr<float>(grad, n);
   const auto ret = prelu.Backward(singa::kTrain, out_diff);
   const float *xptr = ret.first.data<const float *>();
@@ -120,7 +110,7 @@ TEST(PReLU, BackwardCPU) {
   float *dx = new float[n];
   size_t div_factor = prelu.Channel_shared() ? c : 1;
   size_t params = prelu.Channel_shared() ? 1 : c;
-  float da[] = { 0.f, 0.f, 0.f };
+  float da[] = {0.f, 0.f, 0.f};
   if (prelu.Format() == "NCHW") {
     for (size_t i = 0; i < n; i++) {
       size_t pos = i / (h * w) % c / div_factor;
@@ -142,8 +132,6 @@ TEST(PReLU, BackwardCPU) {
       da[pos] += grad[i] * std::min(x[i], 0.f);
     }
   }
-  for (size_t i = 0; i < n; i++)
-    EXPECT_FLOAT_EQ(dx[i], xptr[i]);
-  for (size_t i = 0; i < params; i++)
-    EXPECT_FLOAT_EQ(da[i], aptr[i]);
+  for (size_t i = 0; i < n; i++) EXPECT_FLOAT_EQ(dx[i], xptr[i]);
+  for (size_t i = 0; i < params; i++) EXPECT_FLOAT_EQ(da[i], aptr[i]);
 }



[3/3] incubator-singa git commit: SINGA-190 - Add prelu layer and flatten layer

Posted by wa...@apache.org.
SINGA-190 - Add prelu layer and flatten layer

Merge PR#162 into dev


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/26df5ac0
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/26df5ac0
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/26df5ac0

Branch: refs/heads/dev
Commit: 26df5ac03326576cebcca516da3b27ba1fd0dbd8
Parents: 58be3f8 6d69047
Author: Wei Wang <wa...@comp.nus.edu.sg>
Authored: Sun Jun 12 22:33:02 2016 +0800
Committer: Wei Wang <wa...@comp.nus.edu.sg>
Committed: Sun Jun 12 22:33:02 2016 +0800

----------------------------------------------------------------------
 include/singa/core/tensor.h          | 396 ++++++++---------
 src/core/tensor/math_kernel.cu       | 656 +++++++++++++++-------------
 src/core/tensor/math_kernel.h        |  93 ++--
 src/core/tensor/tensor.cc            | 702 ++++++++++++++++--------------
 src/core/tensor/tensor_math.h        | 393 +++++++++--------
 src/core/tensor/tensor_math_cpp.h    | 585 +++++++++++++++++++------
 src/core/tensor/tensor_math_cuda.h   | 412 ++++++++++++++----
 src/model/layer/cudnn_convolution.cc | 180 ++++----
 test/singa/test_cudnn_convolution.cc | 181 ++++++++
 test/singa/test_tensor_math.cc       | 295 ++++++++++++-
 10 files changed, 2470 insertions(+), 1423 deletions(-)
----------------------------------------------------------------------