Posted to commits@singa.apache.org by wa...@apache.org on 2015/09/04 12:05:11 UTC

[3/7] incubator-singa git commit: SINGA-21 Code review 4

SINGA-21 Code review 4

review base_layer.h/cc, layer.h/cc, optional_layer.h/cc
  - change basic functions in Layer from virtual to inline
  - make ComputeFeature and ComputeGradient take the same parameters,
    i.e. ComputeGradient(Phase phase, Metric* perf)
  - add middle-level layer categories, ConnectionLayer and NeuronLayer
    (see the sketch below)
  - re-arrange layers in these files:
    base_layer.h contains layer categories that are extended by other layers
    layer.h contains layer implementations
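
For context, here is a minimal self-contained sketch of the resulting API.
All names below are simplified stand-ins (Metric, the flag values, and
MyNeuronLayer are illustrative only, not the real SINGA definitions):

    // Sketch of the layer hierarchy and unified signature after this commit.
    struct Metric {};  // stand-in for singa::Metric
    enum Flag { kForward = 1, kBackward = 2, kTrain = 4 };  // assumed values

    class Layer {
     public:
      virtual ~Layer() {}
      // ComputeFeature and ComputeGradient now share one parameter list:
      virtual void ComputeFeature(int flag, Metric* perf) = 0;
      virtual void ComputeGradient(int flag, Metric* perf) = 0;
    };

    // New middle-level categories that concrete layers extend:
    class NeuronLayer : public Layer {};
    class ConnectionLayer : public Layer {};

    // A hypothetical concrete layer overriding the unified signature:
    class MyNeuronLayer : public NeuronLayer {
     public:
      void ComputeFeature(int flag, Metric* perf) override { /* forward */ }
      void ComputeGradient(int flag, Metric* perf) override { /* backward */ }
    };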


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/134c891a
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/134c891a
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/134c891a

Branch: refs/heads/master
Commit: 134c891abf07b28ed75c2ee403d4698164c12c3e
Parents: f99246e
Author: wang sheng <wa...@gmail.com>
Authored: Sat Aug 22 19:40:55 2015 +0800
Committer: wangwei <wa...@comp.nus.edu.sg>
Committed: Fri Aug 28 18:32:01 2015 +0800

----------------------------------------------------------------------
 include/neuralnet/base_layer.h     | 277 ++++------
 include/neuralnet/layer.h          | 372 +++++++-------
 include/neuralnet/optional_layer.h |  26 +-
 src/neuralnet/base_layer.cc        | 265 +++-------
 src/neuralnet/layer.cc             | 859 +++++++++++++++++++-------------
 src/neuralnet/optional_layer.cc    | 136 +++--
 src/proto/job.proto                |   5 +-
 src/trainer/worker.cc              |   4 +-
 8 files changed, 981 insertions(+), 963 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/134c891a/include/neuralnet/base_layer.h
----------------------------------------------------------------------
diff --git a/include/neuralnet/base_layer.h b/include/neuralnet/base_layer.h
index 9aa207d..5498cd0 100644
--- a/include/neuralnet/base_layer.h
+++ b/include/neuralnet/base_layer.h
@@ -1,27 +1,19 @@
 #ifndef SINGA_NEURALNET_BASE_LAYER_H_
 #define SINGA_NEURALNET_BASE_LAYER_H_
 
-#include <vector>
-#include <string>
 #include <map>
-#include <utility>
-#include <memory>
+#include <string>
 #include <thread>
+#include <vector>
 
-#include "proto/job.pb.h"
 #include "proto/common.pb.h"
-#include "utils/param.h"
+#include "proto/job.pb.h"
 #include "utils/common.h"
 #include "utils/blob.h"
+#include "utils/param.h"
 
 namespace singa {
 
-using std::vector;
-using std::string;
-using std::map;
-
-class Layer;
-
 /**
  * Base layer class.
  *
@@ -31,8 +23,9 @@ class Layer;
  */
 class Layer {
  public:
-  static Layer *Create(const LayerProto& proto);
-  Layer() { }
+  static Layer* Create(const LayerProto& proto);
+
+  Layer() {}
   virtual ~Layer() {}
   /**
    * Setup layer properties.
@@ -44,8 +37,10 @@ class Layer {
    * @param npartitions num of total partitions of the original layer. This
    * layer should be setup as one partition.
    */
-  virtual void Setup(const LayerProto& proto, int npartitions = 1);
-
+  virtual void Setup(const LayerProto& proto, int npartitions = 1) {
+    CHECK_GE(npartitions, 1);
+    layer_proto_ = proto;
+  }
   /**
    * Compute features of this layer based on connected layers.
    *
@@ -57,8 +52,8 @@ class Layer {
    *
    * @param flag kTrain, kTest, kPositive, etc.
    */
-  virtual void ComputeLoss(Metric* perf) {}
-  virtual void ComputeGradient(int flag) = 0;
+  virtual void ComputeFeature(int flag, Metric* perf) = 0;
+  virtual void ComputeGradient(int flag, Metric* perf) = 0;
   /**
   * For printing debug info about each layer, e.g., norm of feature vector,
    * norm of parameters.
@@ -73,8 +68,8 @@ class Layer {
    *
    * @return parameters associated with this layer
    */
-  virtual const vector<Param*> GetParams() const {
-    return vector<Param*> {};
+  virtual const std::vector<Param*> GetParams() const {
+    return std::vector<Param*> {};
   }
   /**
    * Return the connection type between one neuron of this layer and
@@ -82,7 +77,7 @@ class Layer {
    * Currently support two connection types: kOneToOne, and kOneToAll.
    * kOneToOne indicates the neuron depends on only one neuron from src layer.
    * kOneToAll indicates the neuron depends on all neurons from src layer.
-   * TODO support kOneToMany.
+   * TODO(wangwei) support kOneToMany.
    *
   * @param k index of source layer (currently only k = 0 is supported).
   * @return connection type.
@@ -91,7 +86,6 @@ class Layer {
     // CHECK_LT(k, srclayers_.size());
     return kOneToOne;
   }
-
   /**
    * Return the connection type of this layer and all dst layers.
    *
@@ -108,27 +102,30 @@ class Layer {
     return kOneToOne;
   }
   /**
+   * For printing debug info about each layer, e.g., norm of feature vector,
+   * norm of parameters.
+   *
+   * @param step training/test/validation step
+   * @param phase forward/backward/positive/negative...
+   * @return debug info about this layer.
+   */
+  virtual const std::string DebugString(int step, Phase phase);
+  /**
    * @return partition dimension of this layer.
    * -1 for no partition;
    *  0 for partition the mini-batch into sub-mini-batch.
    *  1 for partition the layer feature vector into sub-vector.
    */
-  virtual int partition_dim() const {
+  inline int partition_dim() const {
+    CHECK_LE(layer_proto_.partition_dim(), 1);
     return layer_proto_.partition_dim();
   }
-
-  virtual int partition_id() const {
-    return layer_proto_.partition_id();
-  }
-  virtual int type() const {
-    return layer_proto_.type();
-  }
+  inline int partition_id() const { return layer_proto_.partition_id(); }
+  inline int type() const { return layer_proto_.type(); }
   /**
    * Return name of this layer
    */
-  const std::string &name() const {
-    return layer_proto_.name();
-  }
+  inline const std::string &name() const { return layer_proto_.name(); }
   /**
    * @return name of src data blob, used by prefetch layer to locate the data
    * blob in parser layers; The default value is "unknown"; If the
@@ -147,7 +144,6 @@ class Layer {
   virtual Blob<float>* mutable_data(const Layer* from) {
     return &data_;
   }
-
   virtual const Blob<float>& grad(const Layer* from) const {
     return grad_;
   }
@@ -160,36 +156,17 @@ class Layer {
   /**
    * return LayerS that connected to this layer
    */
-  virtual const vector<Layer*> srclayers() const {
-    return srclayers_;
-  }
+  inline const std::vector<Layer*> srclayers() const { return srclayers_; }
   /**
    * return LayerS that this layer connected to
    */
-  virtual const vector<Layer*> dstlayers() const {
-    return dstlayers_;
-  }
-
-  virtual int srclayers_size() const {
-    return srclayers_.size();
-  }
-  virtual int dstlayers_size() const {
-    return dstlayers_.size();
-  }
-  virtual void clear_dstlayers() {
-    dstlayers_.clear();
-  }
-  virtual void clear_srclayers() {
-    srclayers_.clear();
-  }
-
-  virtual void add_srclayer(Layer* src) {
-    srclayers_.push_back(src);
-  }
-  virtual void add_dstlayer(Layer* dst) {
-    dstlayers_.push_back(dst);
-  }
-
+  inline const std::vector<Layer*> dstlayers() const { return dstlayers_; }
+  inline int srclayers_size() const { return srclayers_.size(); }
+  inline int dstlayers_size() const { return dstlayers_.size(); }
+  inline void clear_dstlayers() { dstlayers_.clear(); }
+  inline void clear_srclayers() { srclayers_.clear(); }
+  inline void add_srclayer(Layer* src) { srclayers_.push_back(src); }
+  inline void add_dstlayer(Layer* dst) { dstlayers_.push_back(dst); }
   virtual bool is_datalayer() const {
     return false;
   }
@@ -217,6 +194,7 @@ class Layer {
 
  protected:
   LayerProto layer_proto_;
   Blob<float> data_, grad_;
-  vector<Layer*> srclayers_, dstlayers_;
+  std::vector<Layer*> srclayers_, dstlayers_;
 };
@@ -304,14 +282,9 @@ class ConcateLayer: public Layer {
 /**
  * Base layer for reading records from local Shard, HDFS, lmdb, etc.
  */
-class DataLayer: public Layer{
+class DataLayer: public Layer {
  public:
-  using Layer::ComputeGradient;
-  using Layer::mutable_data;
-  using Layer::mutable_grad;
-  using Layer::dst_layer_connection;
-
-  void ComputeGradient(int flag) override {}
+  void ComputeGradient(int flag, Metric* perf) override {}
   bool is_datalayer() const override {
     return true;
   }
@@ -324,123 +297,62 @@ class DataLayer: public Layer{
   ConnectionType dst_layer_connection() const override {
     return kOneToMany;
   }
-
-  int batchsize() const {
-    return batchsize_;
-  }
+  inline int batchsize() const { return batchsize_; }
   virtual const Record& sample() const {
     return sample_;
   }
   /**
    * @return the loaded records
    */
-  virtual const vector<Record>& records() const {
+  virtual const std::vector<Record>& records() const {
     return records_;
   }
 
  protected:
-  int random_skip_, batchsize_;
+  int random_skip_;
+  int batchsize_;
   Record sample_;
-  vector<Record> records_;
+  std::vector<Record> records_;
 };
 
 /**
- * Layer for prefetching data records and parsing them.
- *
- * The data loading and parsing work is done by internal DataLayer and
- * ParserLayer respectively. This layer controls the prefetching thread, i.e.,
- * creating and joining the prefetching thread.
+ * Base layer for parsing the input records into Blobs.
  */
-class PrefetchLayer : public Layer {
+class ParserLayer : public Layer {
  public:
-  using Layer::ComputeFeature;
-  using Layer::ComputeGradient;
-
-  void Setup(const LayerProto& proto, int npartitions) override;
-  void ComputeFeature(int flag, Metric* perf) override;
-  void ComputeGradient(int flag) override {};
-
-  const Blob<float>& data(const Layer* from) const override;
-  Blob<float>* mutable_data(const Layer* layer) override;
-
+  void ComputeFeature(Phase phase, Metric* perf) override;
+  void ComputeGradient(Phase phase, Metric* perf) override {}
+  /**
+   * Parse records from DataLayer into blob.
+   */
+  virtual void ParseRecords(Phase phase, const std::vector<Record>& records,
+      Blob<float>* blob) = 0;
+  bool is_parserlayer() const override {
+    return true;
+  }
   Blob<float>* mutable_grad(const Layer* layer) override {
     return nullptr;
   }
-  const Blob<float>& grad(const Layer* from) const override {
-    CHECK(false) << "Loss layer has not gradient blob";
+  const Blob<float>& grad(const Layer* from) const override {
+    CHECK(false) << "Parser layer has no gradient blob";
     return grad_;
   }
-
-  void Prefetch(int flag);
-  virtual ~PrefetchLayer();
-
- protected:
-  vector<Layer*> sublayers_;
-  map<string, Blob<float>> datablobs_;
-  std::thread thread_;
-};
-
-/**
- * Slice the source layer into multiple dst layers on one dimension
- */
-class SliceLayer: public Layer {
- public:
-  using Layer::ComputeFeature;
-  using Layer::ComputeGradient;
-
-  void Setup(const LayerProto& proto, int npartitions) override;
-  void ComputeFeature(int flag, Metric* perf) override;
-  void ComputeGradient(int flag) override;
-  ConnectionType dst_layer_connection() const override {
-    return kOneToMany;
-  }
-  const Blob<float>& data(const Layer* layer) const override;
-  const Blob<float>& grad(const Layer* layer) const override;
-  Blob<float>* mutable_data(const Layer* layer) override;
-  Blob<float>* mutable_grad(const Layer* layer) override;
-
- protected:
-  int SliceID(const Layer* layer) const;
-
- private:
-  vector<Blob<float>> datavec_, gradvec_;
-  int slice_dim_, slice_num_;
 };
 
-/**
- * Connect the source layer with multiple dst layers.
- * Pass source layer's data blob directly to dst layers.
- * Aggregate dst layer's gradients into source layer's gradient.
- */
-class SplitLayer: public Layer {
- public:
-  using Layer::ComputeFeature;
-  using Layer::ComputeGradient;
-
-  void Setup(const LayerProto& proto, int npartitions) override;
-  void ComputeFeature(int flag, Metric* perf) override;
-  void ComputeGradient(int flag) override;
-  ConnectionType dst_layer_connection() const override {
-    return kOneToMany;
-  }
- protected:
-  Blob<float> grads_;
+class NeuronLayer : public Layer {
+  // defined as a layer category
 };
 
 /**
- * Loss layer to calculate loss and other metrics, e.g., precison.
+ * Base layer for calculating loss and other metrics, e.g., precision.
  */
-class LossLayer: public Layer{
+class LossLayer: public Layer {
  public:
-  using Layer::mutable_grad;
-  using Layer::grad;
-  using Layer::is_losslayer;
-
   Blob<float>* mutable_grad(const Layer* layer) override {
     return nullptr;
   }
   const Blob<float>& grad(const Layer* from) const override {
-    CHECK(false) << "Loss layer has not gradient blob";
+    LOG(FATAL) << "Loss layer has no gradient blob";
     return grad_;
   }
   bool is_losslayer() const override {
@@ -452,33 +364,54 @@ class LossLayer: public Layer{
 };
 
 /**
- * parse the input records into Blobs.
+ * Base layer for sending and waiting on remote messages.
  */
-class ParserLayer: public Layer {
+class BridgeLayer : public Layer {
  public:
-  using Layer::ComputeFeature;
-  using Layer::ComputeGradient;
-  using Layer::is_parserlayer;
-  using Layer::mutable_grad;
-  using Layer::grad;
+  inline void set_ready(bool a) { ready_ = a; }
+  inline bool ready() const { return ready_; }
+  bool is_bridgelayer() const override { return true; }
 
-  void ComputeFeature(int flag, Metric* perf) override;
-  void ComputeGradient(int flag) override {};
-  /**
-   * Parse records from DataLayer into blob.
-   */
-  virtual void ParseRecords(int flag, const vector<Record>& records,
-      Blob<float>* blob) = 0;
-  bool is_parserlayer() const override {
-    return true;
-  }
+ protected:
+  //!< true if received grad from BridgeDstLayer
+  bool ready_;
+};
+
+/**
+ * Base layer for connecting layers when neural net is partitioned.
+ */
+class ConnectionLayer : public Layer {
+  // defined as a layer category
+};
+
+/**
+ * Layer for prefetching data records and parsing them.
+ *
+ * The data loading and parsing work is done by internal DataLayer and
+ * ParserLayer respectively. This layer controls the prefetching thread, i.e.,
+ * creating and joining the prefetching thread.
+ */
+class PrefetchLayer : public Layer {
+ public:
+  ~PrefetchLayer();
+  void Setup(const LayerProto& proto, int npartitions) override;
+  void ComputeFeature(Phase phase, Metric* perf) override;
+  void ComputeGradient(Phase phase, Metric* perf) override {}
+  const Blob<float>& data(const Layer* from, Phase phase) const override;
+  void Prefetch(Phase phase);
+  Blob<float>* mutable_data(const Layer* layer, Phase phase) override;
   Blob<float>* mutable_grad(const Layer* layer) override {
     return nullptr;
   }
-  const Blob<float>& grad(const Layer* from) const  override {
-    CHECK(false) << "Parser layer has not gradient blob";
+  const Blob<float>& grad(const Layer* from) const override {
+    CHECK(false) << "Loss layer has not gradient blob";
     return grad_;
   }
+
+ protected:
+  std::vector<Layer*> sublayers_;
+  std::map<std::string, Blob<float>> datablobs_;
+  std::thread thread_;
 };
 
 class RBMLayer: public Layer {
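
To make the ParserLayer contract above concrete: the base class's
ComputeFeature fetches records from the upstream DataLayer and hands them to
ParseRecords, so subclasses only implement the parsing step. A minimal
stand-alone sketch, with Record, Blob and MyLabelParser as simplified
stand-ins rather than the real SINGA types:

    #include <vector>

    struct Record { float label; };            // stand-in record
    struct Blob { std::vector<float> data; };  // stand-in blob

    class ParserLayer {
     public:
      virtual ~ParserLayer() {}
      // Called by the base class with the records loaded by the DataLayer.
      virtual void ParseRecords(int flag, const std::vector<Record>& records,
                                Blob* blob) = 0;
    };

    // Hypothetical parser mirroring LabelLayer: one float label per record.
    class MyLabelParser : public ParserLayer {
     public:
      void ParseRecords(int flag, const std::vector<Record>& records,
                        Blob* blob) override {
        blob->data.clear();
        for (const Record& r : records)
          blob->data.push_back(r.label);
      }
    };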

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/134c891a/include/neuralnet/layer.h
----------------------------------------------------------------------
diff --git a/include/neuralnet/layer.h b/include/neuralnet/layer.h
index 435d854..c64cee9 100644
--- a/include/neuralnet/layer.h
+++ b/include/neuralnet/layer.h
@@ -1,20 +1,10 @@
 #ifndef SINGA_NEURALNET_LAYER_H_
 #define SINGA_NEURALNET_LAYER_H_
 
-#include <lmdb.h>
-
 #include <vector>
-#include <string>
-#include <map>
-#include <functional>
-#include <utility>
-#include <memory>
-#include <chrono>
-#include <random>
-
+#include "neuralnet/base_layer.h"
 #include "proto/job.pb.h"
 #include "utils/data_shard.h"
-#include "neuralnet/base_layer.h"
 
 /**
  * \file this file includes the declarations of neuron layer classes that conduct
@@ -22,26 +12,78 @@
  */
 namespace singa {
 
+/********** Derived from DataLayer **********/
+
+class ShardDataLayer : public DataLayer {
+ public:
+  ~ShardDataLayer();
+
+  void Setup(const LayerProto& proto, int npartitions) override;
+  void ComputeFeature(Phase phase, Metric *perf) override;
+
+ private:
+  DataShard* shard_;
+};
+
+/********** Derived from ParserLayer **********/
+
+class LabelLayer : public ParserLayer {
+ public:
+  void Setup(const LayerProto& proto, int npartitions) override;
+  void ParseRecords(Phase phase, const std::vector<Record>& records,
+                    Blob<float>* blob) override;
+};
+
+class MnistLayer : public ParserLayer {
+ public:
+  void Setup(const LayerProto& proto, int npartitions) override;
+  void ParseRecords(Phase phase, const std::vector<Record>& records,
+                    Blob<float>* blob) override;
+
+ protected:
+  // height and width of the image after deformation
+  // kernel size for elastic distortion
+  // n^2 images are processed as a batch for elastic distortion
+  // conv height and conv width
+  // gauss kernel values, displacements, column image and tmp buffer
+  // float* gauss_, *displacementx_, *displacementy_, *colimg_, *tmpimg_;
+  float  gamma_, beta_, sigma_, kernel_, alpha_, norm_a_, norm_b_;
+  int resize_, elastic_freq_;
+};
+
+class RGBImageLayer : public ParserLayer {
+ public:
+  void Setup(const LayerProto& proto, int npartitions) override;
+  void ParseRecords(Phase phase, const std::vector<Record>& records,
+                    Blob<float>* blob) override;
+
+ private:
+  float scale_;
+  int cropsize_;
+  bool mirror_;
+  Blob<float> mean_;
+};
+
+/********** Derived from NeuronLayer **********/
+
 /**
  * Convolution layer.
  */
-class ConvolutionLayer: public Layer {
+class ConvolutionLayer : public NeuronLayer {
  public:
-  using Layer::ComputeFeature;
-  using Layer::ComputeGradient;
+  ~ConvolutionLayer();
 
   void Setup(const LayerProto& proto, int npartitions) override;
   void ComputeFeature(int flag, Metric *perf) override;
-  void ComputeGradient(int flag) override;
+  void ComputeGradient(int flag, Metric* perf) override;
-  const vector<Param*> GetParams() const override {
-    vector<Param*> params{weight_, bias_};
+  const std::vector<Param*> GetParams() const override {
+    std::vector<Param*> params{weight_, bias_};
     return params;
   }
   ConnectionType src_neuron_connection(int k) const  override {
     // CHECK_LT(k, srclayers_.size());
     return kOneToAll;
   }
-  ~ConvolutionLayer();
 
  protected:
   int kernel_, pad_,  stride_;
@@ -51,15 +93,11 @@ class ConvolutionLayer: public Layer {
   Blob<float> col_data_, col_grad_;
 };
 
-class DropoutLayer: public Layer {
+class DropoutLayer : public NeuronLayer {
  public:
-  using Layer::ComputeFeature;
-  using Layer::ComputeGradient;
-
   void Setup(const LayerProto& proto, int npartitions) override;
-  void ComputeFeature(int flag, Metric *perf) override;
-  void ComputeGradient(int flag) override;
-
+  void ComputeFeature(int flag, Metric* perf) override;
+  void ComputeGradient(int flag, Metric* perf) override;
  protected:
   // drop probability
   float pdrop_;
@@ -87,66 +125,8 @@ class RBMVisLayer: public RBMLayer {
   RBMLayer* hid_layer_;
   Layer* input_layer_;
 };
-/**
- * RBM hidden layer
- */
-class RBMHidLayer: public RBMLayer {
- public:
-  using Layer::ComputeFeature;
-  using Layer::ComputeGradient;
-
-  ~RBMHidLayer();
-  void Setup(const LayerProto& proto, int npartitions) override;
-  void ComputeFeature(int flag, Metric *perf) override;
-  void ComputeGradient(int flag) override;
-  Blob<float>* Sample(int flat) override;
- private:
-  // whether use gaussian sampling
-  bool gaussian_;
-  RBMLayer *vis_layer_;
-};
-/**
-  * fully connected layer
-  */
-class InnerProductLayer: public Layer {
- public:
-  using Layer::ComputeFeature;
-  using Layer::ComputeGradient;
-
-  void Setup(const LayerProto& proto, int npartitions) override;
-  void ComputeFeature(int flag, Metric *perf) override;
-  void ComputeGradient(int flag) override;
-
-  ConnectionType src_neuron_connection(int k) const override {
-    // CHECK_LT(k, srclayers_.size());
-    return kOneToAll;
-  }
-  const vector<Param*> GetParams() const override {
-    vector<Param*> params{weight_, bias_};
-    return params;
-  }
-  ~InnerProductLayer();
-
- private:
-  //! dimension of the hidden layer
-  int hdim_;
-  //! dimension of the visible layer
-  int vdim_;
-  int batchsize_;
-  bool transpose_;
-  Param* weight_, *bias_;
-};
 
-class LabelLayer: public ParserLayer {
- public:
-  using ParserLayer::ParseRecords;
-
-  void Setup(const LayerProto& proto, int npartitions) override;
-  void ParseRecords(int flag, const vector<Record>& records,
-      Blob<float>* blob) override;
-};
-
-class LRNLayer: public Layer {
+class LRNLayer : public NeuronLayer {
 /**
  * Local Response Normalization edge
  * b_i=a_i/x_i^beta
@@ -155,12 +135,9 @@ class LRNLayer: public Layer {
  * a_i, the activation (after ReLU) of a neuron convolved with the i-th kernel.
  * b_i, the neuron after normalization, N is the total num of kernels
  */
-  using Layer::ComputeFeature;
-  using Layer::ComputeGradient;
-
   void Setup(const LayerProto& proto, int npartitions) override;
-  void ComputeFeature(int flag, Metric *perf) override;
-  void ComputeGradient(int flag) override;
+  void ComputeFeature(Phase phase, Metric *perf) override;
+  void ComputeGradient(Phase phase, Metric* perf) override;
 
  protected:
   //! shape of the bottom layer feature
@@ -172,35 +149,11 @@ class LRNLayer: public Layer {
   Blob<float> norm_;
 };
 
-class MnistLayer: public ParserLayer {
- public:
-  using ParserLayer::ParseRecords;
-
-  void Setup(const LayerProto& proto, int npartitions) override;
-  void ParseRecords(int flag, const vector<Record>& records,
-      Blob<float>* blob) override;
-  ConnectionType dst_layer_connection() const override {
-    return kOneToMany;
-  }
- protected:
-  // height and width of the image after deformation
-  // kernel size for elastic distortion
-  // n^2 images are processed as a batch for elastic distortion
-  // conv height and conv width
-  // gauss kernel values, displacements, column image and tmp buffer
-  // float* gauss_, *displacementx_, *displacementy_, *colimg_, *tmpimg_;
-  float  gamma_, beta_, sigma_, kernel_, alpha_, norm_a_, norm_b_;
-  int resize_, elastic_freq_;
-};
-
-class PoolingLayer: public Layer {
+class PoolingLayer : public NeuronLayer {
  public:
-  using Layer::ComputeFeature;
-  using Layer::ComputeGradient;
-
   void Setup(const LayerProto& proto, int npartitions) override;
-  void ComputeFeature(int flag, Metric *perf) override;
-  void ComputeGradient(int flag) override;
+  void ComputeFeature(Phase phase, Metric *perf) override;
+  void ComputeGradient(Phase phase, Metric* perf) override;
 
  protected:
   int kernel_, pad_, stride_;
@@ -208,60 +161,97 @@ class PoolingLayer: public Layer {
   PoolingProto_PoolMethod pool_;
 };
 
-class ReLULayer: public Layer {
+class ReLULayer : public NeuronLayer {
  public:
-  using Layer::ComputeFeature;
-  using Layer::ComputeGradient;
-
-  void Setup(const LayerProto& proto, int npartitions = 1) override;
-  void ComputeFeature(int flag, Metric *perf) override;
-  void ComputeGradient(int flag) override;
+  void Setup(const LayerProto& proto, int npartitions) override;
+  void ComputeFeature(Phase phase, Metric *perf) override;
+  void ComputeGradient(Phase phase, Metric* perf) override;
 };
 
-class EuclideanLossLayer: public LossLayer {
+/**
+ * RBM hidden layer
+ */
+class RBMHidLayer: public RBMLayer {
  public:
-  using Layer::ComputeFeature;
-  using Layer::ComputeGradient;
+  ~RBMHidLayer();
 
   void Setup(const LayerProto& proto, int npartitions) override;
-  void ComputeFeature(int flag, Metric *perf) override;
-  void ComputeGradient(int flag) override;
+  void ComputeFeature(int flag, Metric* perf) override;
+  void ComputeGradient(int flag, Metric* perf) override;
+  Blob<float>* Sample(int flat) override;
+ private:
+  // whether use gaussian sampling
+  bool gaussian_;
+  RBMLayer *vis_layer_;
+};
 
+/**
+  * RBM visible layer
+  */
+class RBMVisLayer : public NeuronLayer {
+ public:
+  ~RBMVisLayer();
+
+  void Setup(const LayerProto& proto, int npartitions) override;
+  void ComputeFeature(int flag, Metric* perf) override;
+  void ComputeGradient(int flag, Metric* perf) override;
 
-  int partition_dim() const override {
-    CHECK_LE(layer_proto_.partition_dim(), 1);
-    return layer_proto_.partition_dim();
-  }
   ConnectionType src_neuron_connection(int k) const override {
     // CHECK_LT(k, srclayers_.size());
     return kOneToAll;
   }
+  const Blob<float>& data(const Layer* from, Phase phase) const override {
+    return (phase == kPositive) ? data_ : vis_sample_;
+  }
+  const std::vector<Param*> GetParams() const override {
+    std::vector<Param*> params{weight_, bias_};
+    return params;
+  }
 
  private:
+  //! dimension of the hidden layer
+  int hdim_;
+  //! dimension of the visible layer
+  int vdim_;
   int batchsize_;
-  int dim_;
+  bool transpose_;
+  Param* weight_, *bias_;
+  // data to store sampling result
+  Blob<float> vis_sample_;
+  // in order to implement Persistent Contrastive Divergence,
+};
+
+/**
+ * This layer applies the tanh function to neuron activations.
+ * f(x)=A tanh(Bx)
+ * f'(x)=B/A (A*A-f(x)*f(x))
+ */
+class TanhLayer : public NeuronLayer {
+ public:
+  void Setup(const LayerProto& proto, int npartitions) override;
+  void ComputeFeature(Phase phase, Metric *perf) override;
+  void ComputeGradient(Phase phase, Metric* perf) override;
+
+ private:
+  float outer_scale_, inner_scale_;
 };
 
-class SoftmaxLossLayer: public LossLayer {
+/********** Derived from LossLayer **********/
+
+class SoftmaxLossLayer : public LossLayer {
   /*
    * connected from the label layer and the last fc layer
    */
  public:
-  using Layer::ComputeFeature;
-  using Layer::ComputeGradient;
-
   void Setup(const LayerProto& proto, int npartitions) override;
-  void ComputeFeature(int flag, Metric *perf) override;
-  void ComputeGradient(int flag) override;
+  void ComputeFeature(int flag, Metric* perf) override;
+  void ComputeGradient(int flag, Metric* perf) override;
 
   /**
   * softmax is not recommended for partition because it requires the whole
    * src layer for normalization.
    */
-  int partition_dim() const override {
-    CHECK_LE(layer_proto_.partition_dim(), 1);
-    return layer_proto_.partition_dim();
-  }
   ConnectionType src_neuron_connection(int k) const override {
     // CHECK_LT(k, srclayers_.size());
     return kOneToAll;
@@ -274,31 +264,77 @@ class SoftmaxLossLayer: public LossLayer {
   int topk_;
 };
 
-class RGBImageLayer: public ParserLayer {
- public:
-  using ParserLayer::ParseRecords;
+/********** Derived from BridgeLayer **********/
 
+/**
+ * For receiving data from layers on other threads, which may reside on other
+ * nodes due to layer/data partition.
+ */
+class BridgeDstLayer : public BridgeLayer {
+ public:
   void Setup(const LayerProto& proto, int npartitions) override;
-  void ParseRecords(int flag, const vector<Record>& records,
-      Blob<float>* blob) override;
+  void ComputeFeature(int flag, Metric* perf) override {
+    // reset ready_ for next iteration.
+    ready_ = false;
+  }
+  void ComputeGradient(int flag, Metric* perf) override {}
+  bool is_bridgedstlayer() const {
+    return true;
+  }
+};
 
- private:
-  float scale_;
-  int cropsize_;
-  bool mirror_;
-  Blob<float> mean_;
+/**
+ * For sending data to layers on other threads, which may reside on other
+ * nodes due to layer/data partition.
+ */
+class BridgeSrcLayer : public BridgeLayer {
+ public:
+  void ComputeFeature(Phase phase, Metric* perf) override {}
+  void ComputeGradient(Phase phase, Metric* perf) override {
+    ready_ = false;
+  }
+  const Blob<float>& data(const Layer* from, Phase phase) const override {
+    return srclayers_[0]->data(this);
+  }
+  Blob<float>* mutable_data(const Layer* from, Phase phase) override {
+    return srclayers_[0]->mutable_data(this);
+  }
+  const Blob<float>& grad(const Layer* from) const override {
+    return srclayers_[0]->grad(this);
+  }
+  Blob<float>* mutable_grad(const Layer* from) override {
+    return srclayers_[0]->mutable_grad(this);
+  }
+  bool is_bridgesrclayer() const override {
+    return true;
+  }
 };
 
-class ShardDataLayer: public DataLayer{
+/********** Derived from ConnectionLayer **********/
+
+/**
+ * Concatenate src layers along one dimension.
+ */
+class ConcateLayer : public ConnectionLayer {
  public:
-  using Layer::ComputeFeature;
+  void Setup(const LayerProto& proto, int npartitions) override;
+  void ComputeFeature(Phase phase, Metric* perf) override;
+  void ComputeGradient(Phase phase, Metric* perf) override;
+};
 
-  ~ShardDataLayer();
+/**
+ * Slice the source layer into multiple dst layers on one dimension
+ */
+class SliceLayer : public ConnectionLayer {
+ public:
   void Setup(const LayerProto& proto, int npartitions) override;
   void ComputeFeature(int flag, Metric *perf) override;
 
  private:
-  DataShard* shard_;
+  std::vector<Blob<float>> datavec_;
+  std::vector<Blob<float>> gradvec_;
+  int slice_dim_;
+  int slice_num_;
 };
 
 /**
@@ -312,29 +348,25 @@ class SigmoidLayer: public Layer {
   using Layer::ComputeGradient;
 
   void Setup(const LayerProto& proto, int npartitions) override;
-  void ComputeFeature(int flag, Metric *perf) override;
-  void ComputeGradient(int flag) override;
+  void ComputeFeature(int flag, Metric* perf) override;
+  void ComputeGradient(int flag, Metric* perf) override;
 };
 
 /**
- * This layer apply Tan function to neuron activations.
- * f(x)=A tanh(Bx)
- * f'(x)=B/A (A*A-f(x)*f(x))
+ * Connect the source layer with multiple dst layers.
+ * Pass source layer's data blob directly to dst layers.
+ * Aggregate dst layer's gradients into source layer's gradient.
  */
-class TanhLayer: public Layer {
+class SplitLayer : public ConnectionLayer {
  public:
-  using Layer::ComputeFeature;
-  using Layer::ComputeGradient;
-
   void Setup(const LayerProto& proto, int npartitions) override;
-  void ComputeFeature(int flag, Metric *perf) override;
-  void ComputeGradient(int flag) override;
+  void ComputeFeature(int flag, Metric* perf) override;
+  void ComputeGradient(int flag, Metric* perf) override;
 
- private:
-  float outer_scale_, inner_scale_;
+ protected:
+  Blob<float> grads_;
 };
 
-
 }  // namespace singa
 
 #endif  // SINGA_NEURALNET_LAYER_H_
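
One note on the TanhLayer comment above: the stated derivative
f'(x) = B/A (A*A - f(x)*f(x)) checks out from tanh'(u) = 1 - tanh^2(u),
since with f(x) = A tanh(Bx):

    f'(x) = A B sech^2(Bx)
          = A B (1 - tanh^2(Bx))
          = A B (1 - f(x)^2 / A^2)
          = (B/A) (A^2 - f(x)^2)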

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/134c891a/include/neuralnet/optional_layer.h
----------------------------------------------------------------------
diff --git a/include/neuralnet/optional_layer.h b/include/neuralnet/optional_layer.h
index 34dd43b..8f64ab4 100644
--- a/include/neuralnet/optional_layer.h
+++ b/include/neuralnet/optional_layer.h
@@ -1,20 +1,24 @@
+#ifndef SINGA_NEURALNET_OPTIONAL_LAYER_H_
+#define SINGA_NEURALNET_OPTIONAL_LAYER_H_
+
 #ifdef USE_LMDB
-#ifndef SINGA_NEURALNET_OPTIONAL_LAYER_
-#define SINGA_NEURALNET_OPTIONAL_LAYER_
-#include "neuralnet/layer.h"
+#include <lmdb.h>
+#endif
+#include <string>
+#include "neuralnet/base_layer.h"
 
 namespace singa {
 
-class LMDBDataLayer: public DataLayer{
+#ifdef USE_LMDB
+class LMDBDataLayer : public DataLayer {
  public:
-  using Layer::ComputeFeature;
-
   ~LMDBDataLayer();
-  void OpenLMDB(const std::string& path);
+
   void Setup(const LayerProto& proto, int npartitions) override;
+  void OpenLMDB(const std::string& path);
   void ComputeFeature(Phase phase, Metric *perf) override;
   void ConvertCaffeDatumToRecord(const CaffeDatum& datum,
-      SingleLabelImageRecord* record);
+                                 SingleLabelImageRecord* record);
 
  private:
   MDB_env* mdb_env_;
@@ -23,8 +27,8 @@ class LMDBDataLayer: public DataLayer{
   MDB_cursor* mdb_cursor_;
   MDB_val mdb_key_, mdb_value_;
 };
-} /* singa */
-
-#endif  // SINGA_NEURALNET_OPTIONAL_LAYER_
 #endif
 
+}  // namespace singa
+
+#endif  // SINGA_NEURALNET_OPTIONAL_LAYER_H_
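
The guard restructuring above is worth noting: the include guard now wraps
the whole header, and USE_LMDB only gates <lmdb.h> plus the LMDBDataLayer
class, so the header is always safe to include. A hypothetical consumer file
(not part of this commit) shows the effect:

    // Compiles with or without -DUSE_LMDB; only the LMDB-specific class
    // disappears from the build when the flag is off.
    #include "neuralnet/optional_layer.h"

    int main() {
    #ifdef USE_LMDB
      singa::LMDBDataLayer* layer = nullptr;  // type exists only with USE_LMDB
      (void)layer;
    #endif
      return 0;
    }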

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/134c891a/src/neuralnet/base_layer.cc
----------------------------------------------------------------------
diff --git a/src/neuralnet/base_layer.cc b/src/neuralnet/base_layer.cc
index 46f8b57..7d94a75 100644
--- a/src/neuralnet/base_layer.cc
+++ b/src/neuralnet/base_layer.cc
@@ -1,15 +1,22 @@
+#include "neuralnet/base_layer.h"
+
 #include <cblas.h>
+#include <glog/logging.h>
 #include <math.h>
+#include <opencv2/highgui/highgui.hpp>
+#include <opencv2/imgproc/imgproc.hpp>
 #include <cfloat>
-#include <glog/logging.h>
-#include "utils/singleton.h"
 #include "utils/factory.h"
-#include "neuralnet/base_layer.h"
+#include "utils/singleton.h"
 
 namespace singa {
-Layer *Layer::Create(const LayerProto& proto) {
+
+using std::string;
+using std::vector;
+
+Layer* Layer::Create(const LayerProto& proto) {
   auto* factory = Singleton<Factory<Layer>>::Instance();
-  Layer * layer = nullptr;
+  Layer* layer = nullptr;
   if (proto.has_user_type())
     layer = factory->Create(proto.user_type());
   else
@@ -17,233 +24,105 @@ Layer *Layer::Create(const LayerProto& proto) {
   return layer;
 }
 
-void Layer::Setup(const LayerProto& proto, int npartitions) {
-  CHECK_GE(npartitions, 1);
-  layer_proto_ = proto;
-}
-
-const string Layer::DebugString(int step, int flag) {
-  string ret =StringPrintf("Layer %10s ", name().c_str());
-  if ((flag & kForward) == kForward && data_.count() !=0) {
-    ret += StringPrintf("data norm1 %13.9f", data_.asum_data());
-  } else if ((flag & kBackward) == kBackward) {
-    if (grad_.count() != 0)
-      ret += StringPrintf("grad norm1 %13.9f\n", grad_.asum_data());
-    for(Param* p: GetParams())
-      ret += StringPrintf("param id %2d, name %10s,\
-          value norm1 %13.9f, grad norm1 %13.9f\n",
-          p->id(), p->name().c_str(),
-          p->data().asum_data(), p->grad().asum_data());
+const string Layer::DebugString(int step, Phase phase) {
+  string ret = StringPrintf("Layer %10s ", name().c_str());
+  if (data_.count() == 0)
+    return ret;
+  if (phase == kForward) {
+    ret += StringPrintf("data norm1 %13.9f", data_.asum_data());
+  } else if (phase == kBackward) {
+    ret += StringPrintf("grad norm1 %13.9f\n", grad_.asum_data());
+    for (Param* p : GetParams()) {
+      ret += StringPrintf(
+          "param id %2d, name %10s, value norm1 %13.9f, grad norm1 %13.9f\n",
+          p->id(), p->name().c_str(), p->data().asum_data(),
+          p->grad().asum_data());
+    }
   }
   return ret;
 }
-/********* Implementation for BridgeDstLayer **************/
-void BridgeDstLayer::Setup(const LayerProto& proto, int npartitions) {
-  Layer::Setup(proto, npartitions);
-  CHECK_EQ(srclayers_.size(),1);
-  data_.Reshape(srclayers_[0]->data(this).shape());
-  grad_.ReshapeLike(data_);
-}
-
-/************* Implementation for ConcateLayer ***********/
-void ConcateLayer::Setup(const LayerProto& proto, int npartitions) {
-  // CHECK_EQ(npartitions, 1);
-  Layer::Setup(proto, npartitions);
-  size_t concate_dim=proto.concate_conf().concate_dim();
-  CHECK_GE(concate_dim,0);
-  CHECK_GT(srclayers_.size(),1);
-  vector<int> shape=srclayers_[0]->data(this).shape();
-  for(size_t i=1;i<srclayers_.size();i++){
-    const vector<int>& srcshape=srclayers_[i]->data(this).shape();
-    for(size_t j=0;j<shape.size();j++)
-      if(j==concate_dim)
-        shape[j]+=srcshape[j];
-      else
-        CHECK_EQ(shape[j], srcshape[j]);
-  }
-  data_.Reshape(shape);
-  grad_.Reshape(shape);
-}
-
-void ConcateLayer::ComputeFeature(int flag, Metric *perf){
-  LOG(FATAL) << "Not implemented for Concate Layer";
-}
-
-void ConcateLayer::ComputeGradient(int flag){
-  LOG(FATAL) << "Not implemented for Concate Layer";
-}
 
 /************* Implementation for ParserLayer ***********/
-void ParserLayer::ComputeFeature(int flag, Metric *perf){
-  CHECK_EQ(srclayers_.size(),1);
-  auto datalayer=static_cast<DataLayer*>(*srclayers_.begin());
-  ParseRecords(flag, datalayer->records(), &data_);
+void ParserLayer::ComputeFeature(Phase phase, Metric *perf) {
+  CHECK_EQ(srclayers_.size(), 1);
+  auto datalayer = static_cast<DataLayer*>(*srclayers_.begin());
+  ParseRecords(phase, datalayer->records(), &data_);
 }
 
 /************* Implementation for PrefetchLayer ***********/
-void PrefetchLayer::Prefetch(int flag){
-  //clock_t s=clock();
-  for(auto layer: sublayers_)
-    layer->ComputeFeature(flag, nullptr);
-  //LOG(ERROR)<<(clock()-s)*1.0/CLOCKS_PER_SEC;
-}
-
-void PrefetchLayer::ComputeFeature(int flag, Metric* perf){
-  if(thread_.joinable())
+PrefetchLayer::~PrefetchLayer() {
+  if (thread_.joinable())
     thread_.join();
-  else{
-    Prefetch(flag);
-  }
-  for(auto layer: sublayers_){
-    if(layer->is_parserlayer())
-      // TODO replace CopyFrom with Swap?
-      datablobs_.at(layer->name()).CopyFrom(layer->data(this));
-  }
-  thread_=std::thread(&PrefetchLayer::Prefetch, this, flag);
+  for (auto layer : sublayers_)
+    delete layer;
 }
 
 void PrefetchLayer::Setup(const LayerProto& proto, int npartitions) {
   Layer::Setup(proto, npartitions);
   // CHECK_EQ(npartitions, 1);
-  Factory<Layer>* factory=Singleton<Factory<Layer>>::Instance();
-  const auto& sublayers=proto.prefetch_conf().sublayers();
+  Factory<Layer>* factory = Singleton<Factory<Layer>>::Instance();
+  const auto& sublayers = proto.prefetch_conf().sublayers();
   CHECK_GE(sublayers.size(), 1);
-  map<string, Layer*> layers;
-  for(auto const &p:sublayers){
-    auto layer=factory->Create(p.type());
+  std::map<string, Layer*> layers;
+  for (auto const &p : sublayers) {
+    auto layer = factory->Create(p.type());
     sublayers_.push_back(layer);
-    layers[p.name()]= layer;
+    layers[p.name()] = layer;
   }
-  // TODO topology sort layers
-  auto layer=sublayers_.begin();
-  for(auto const &p : sublayers){
+  // TODO(wangwei) topology sort layers
+  auto layer = sublayers_.begin();
+  for (auto const &p : sublayers) {
     std::vector<Layer*> src;
-    for(auto const &srcname: p.srclayers()){
+    for (auto const &srcname : p.srclayers()) {
       src.push_back(layers[srcname]);
       (*layer)->add_srclayer(layers[srcname]);
     }
     (*layer)->Setup(p);
     layer++;
   }
-  for(auto layer: sublayers_)
-    if(layer->is_parserlayer())
-      datablobs_[layer->name()]=Blob<float>(layer->data(this).shape());
+  for (auto layer : sublayers_)
+    if (layer->is_parserlayer())
+      datablobs_[layer->name()] = Blob<float>(layer->data(this).shape());
+}
+
+void PrefetchLayer::ComputeFeature(Phase phase, Metric* perf) {
+  if (thread_.joinable())
+    thread_.join();
+  else
+    Prefetch(phase);
+  for (auto layer : sublayers_) {
+    if (layer->is_parserlayer())
+      // TODO(wangwei) replace CopyFrom with Swap?
+      datablobs_.at(layer->name()).CopyFrom(layer->data(this));
+  }
+  thread_ = std::thread(&PrefetchLayer::Prefetch, this, phase);
+}
+
+void PrefetchLayer::Prefetch(Phase phase) {
+  // clock_t s=clock();
+  for (auto layer : sublayers_)
+    layer->ComputeFeature(phase, nullptr);
+  // LOG(ERROR)<<(clock()-s)*1.0/CLOCKS_PER_SEC;
 }
 
 const Blob<float>& PrefetchLayer::data(const Layer* from) const {
   LOG(FATAL) << " needs update";
-  if(from != nullptr) {
+  if (from != nullptr) {
     return datablobs_.at("");
   } else {
-    //CHECK_EQ(datablobs_.size(),1);
+    // CHECK_EQ(datablobs_.size(),1);
     return datablobs_.begin()->second;
   }
 }
 
 Blob<float>* PrefetchLayer::mutable_data(const Layer* from) {
   LOG(FATAL) << " needs update";
-  if(from!=nullptr){
+  if (from != nullptr) {
     return &(datablobs_.at(""));
-  }else{
-    //CHECK_EQ(datablobs_.size(),1);
+  } else {
+    // CHECK_EQ(datablobs_.size(),1);
     return &(datablobs_.begin()->second);
   }
 }
 
-PrefetchLayer::~PrefetchLayer(){
-  if(thread_.joinable())
-    thread_.join();
-  for(auto layer : sublayers_)
-    delete layer;
-}
-/************* Implementation for SliceLayer****************/
-void SliceLayer::Setup(const LayerProto& proto, int npartitions){
-  // CHECK_EQ(npartitions, 1);
-  Layer::Setup(proto, npartitions);
-  slice_dim_=proto.slice_conf().slice_dim();
-  slice_num_= npartitions;
-  CHECK_GE(slice_dim_,0);
-  CHECK_EQ(slice_num_, dstlayers_.size());
-  data_.Reshape(srclayers_[0]->data(this).shape());
-  grad_.ReshapeLike(data_);
-  datavec_.resize(slice_num_);
-  gradvec_.resize(slice_num_);
-  CHECK_EQ(data_.count()%slice_num_, 0); // restrict equal slicing
-  //LOG(ERROR)<<"slice dim "<<slice_dim<<" slice num "<<slice_num;
-  for(int i=0;i<slice_num_;i++){
-    vector<int> newshape(data_.shape());
-    newshape[slice_dim_]=newshape[slice_dim_]/slice_num_+
-      ((i==slice_num_-1)?newshape[slice_dim_]%slice_num_:0);
-    datavec_[i].Reshape(newshape);
-    gradvec_[i].Reshape(newshape);
-    //LOG(ERROR)<<"slice "<<IntVecToString(newshape);
-  }
-}
-
-int SliceLayer::SliceID(const Layer* layer) const {
-  CHECK(layer!= nullptr);
-  for(size_t i=0;i<datavec_.size();i++){
-    //LOG(ERROR)<<"get slice "<<IntVecToString(shapes_[i]);
-    if(dstlayers_[i] == layer)
-      return i;
-  }
-  CHECK(false);
-  return -1;
-}
-
-const Blob<float>& SliceLayer::data(const Layer* layer) const {
-  if(layer==nullptr)
-    return data_;
-  return datavec_[SliceID(layer)];
-}
-const Blob<float>& SliceLayer::grad(const Layer* layer) const {
-  if(layer==nullptr)
-    return grad_;
-  return gradvec_[SliceID(layer)];
-}
-Blob<float>* SliceLayer::mutable_data(const Layer* layer) {
-  if(layer==nullptr)
-    return &data_;
-  return &datavec_[SliceID(layer)];
-}
-Blob<float>* SliceLayer::mutable_grad(const Layer* layer){
-  if(layer==nullptr)
-    return &grad_;
-  return &gradvec_[SliceID(layer)];
-}
-void SliceLayer::ComputeFeature(int flag, Metric *perf) {
-  CHECK_EQ(srclayers_.size(),1);
-  if(slice_dim_==0){
-    const auto& blob=srclayers_.at(0)->data(this);
-    int size=blob.count()/slice_num_;
-    for(int i=0;i<slice_num_;i++){
-      float* dst=datavec_[i].mutable_cpu_data();
-      const float* src=blob.cpu_data()+i*size;
-      memcpy(dst, src, size*sizeof(float));
-    }
-  }
-}
-void SliceLayer::ComputeGradient(int flag) {
-  // LOG(FATAL) << "Not implemented";
-}
-
-/************* Implementation for SplitLayer****************/
-void SplitLayer::Setup(const LayerProto& proto, int npartitions) {
-  // CHECK_EQ(npartitions, 1);
-  Layer::Setup(proto, npartitions);
-
-  CHECK_EQ(srclayers_.size(),1);
-  data_.Reshape(srclayers_[0]->data(this).shape());
-  grad_.Reshape(srclayers_[0]->data(this).shape());
-}
-
-void SplitLayer::ComputeFeature(int flag, Metric *perf) {
-  LOG(FATAL) << "Not implemented";
-
-}
-void SplitLayer::ComputeGradient(int flag) {
-  LOG(FATAL) << "Not implemented";
-}
-
 }  // namespace singa
-
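
The PrefetchLayer code above follows a classic double-buffering pattern: join
the thread that loaded the previous batch, copy out its results, then launch
the next load. A condensed stand-alone sketch of that pattern (Batch and
Prefetcher are illustrative stand-ins, not SINGA classes):

    #include <thread>
    #include <vector>

    struct Batch { std::vector<float> data; };  // stand-in for parsed blobs

    class Prefetcher {
     public:
      Batch Next() {
        if (worker_.joinable())
          worker_.join();   // wait for the load started last iteration
        else
          Load();           // first call: no thread yet, load synchronously
        Batch out = staged_;  // hand over staged data (CopyFrom in SINGA)
        worker_ = std::thread(&Prefetcher::Load, this);  // start next load
        return out;
      }
      ~Prefetcher() {
        if (worker_.joinable()) worker_.join();
      }

     private:
      void Load() { staged_.data.assign(256, 0.0f); }  // stand-in for real I/O
      Batch staged_;
      std::thread worker_;
    };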

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/134c891a/src/neuralnet/layer.cc
----------------------------------------------------------------------
diff --git a/src/neuralnet/layer.cc b/src/neuralnet/layer.cc
index 29a2312..028682c 100644
--- a/src/neuralnet/layer.cc
+++ b/src/neuralnet/layer.cc
@@ -1,15 +1,40 @@
+#include "neuralnet/layer.h"
+
 #include <glog/logging.h>
-#include <memory>
 #include <algorithm>
 #include "mshadow/tensor.h"
 #include "mshadow/cxxnet_op.h"
-#include "neuralnet/layer.h"
 #include "utils/singleton.h"
 
-using namespace mshadow;
-using namespace mshadow::expr;
-
 namespace singa {
+
+using mshadow::cpu;
+using mshadow::expr::broadcast;
+using mshadow::expr::chpool;
+using mshadow::expr::F;
+using mshadow::expr::pool;
+using mshadow::expr::sumall_except_dim;
+using mshadow::expr::unpool;
+using mshadow::op::power;
+using mshadow::op::relu;
+using mshadow::op::relu_grad;
+using mshadow::op::sigmoid;
+using mshadow::op::square;
+using mshadow::op::stanh;
+using mshadow::op::stanh_grad;
+using mshadow::op::threshold;
+using mshadow::Random;
+using mshadow::red::maximum;
+using mshadow::red::sum;
+using mshadow::Shape;
+using mshadow::Shape1;
+using mshadow::Shape2;
+using mshadow::Shape3;
+using mshadow::Shape4;
+using mshadow::Tensor;
+using std::string;
+using std::vector;
+
 inline Tensor<cpu, 4> Tensor4(Blob<float>* blob) {
   const vector<int>& shape = blob->shape();
   Tensor<cpu, 4> tensor(blob->mutable_cpu_data(),
@@ -17,59 +42,279 @@ inline Tensor<cpu, 4> Tensor4(Blob<float>* blob) {
   return tensor;
 }
 
-inline Tensor<cpu, 3> Tensor3(Blob<float>* blob){
+inline Tensor<cpu, 3> Tensor3(Blob<float>* blob) {
   const vector<int>& shape = blob->shape();
   Tensor<cpu, 3> tensor(blob->mutable_cpu_data(),
       Shape3(shape[0], shape[1], blob->count() / shape[0] / shape[1]));
   return tensor;
 }
-inline Tensor<cpu, 2> Tensor2(Blob<float>* blob){
+
+inline Tensor<cpu, 2> Tensor2(Blob<float>* blob) {
   const vector<int>& shape = blob->shape();
   Tensor<cpu, 2> tensor(blob->mutable_cpu_data(),
       Shape2(shape[0], blob->count() / shape[0]));
   return tensor;
 }
-inline Tensor<cpu, 1> Tensor1(Blob<float>* blob){
+
+inline Tensor<cpu, 1> Tensor1(Blob<float>* blob) {
   Tensor<cpu, 1> tensor(blob->mutable_cpu_data(), Shape1(blob->count()));
   return tensor;
 }
 
-/************ Implementation for ConvProductLayer*************************/
+/***************Implementation for ShardDataLayer**************************/
+ShardDataLayer::~ShardDataLayer() {
+  if (shard_ != nullptr)
+    delete shard_;
+  shard_ = nullptr;
+}
+
+void ShardDataLayer::Setup(const LayerProto& proto, int npartitions) {
+  Layer::Setup(proto, npartitions);
+  shard_ = new DataShard(proto.sharddata_conf().path(), DataShard::kRead);
+  string key;
+  shard_->Next(&key, &sample_);
+  delete shard_;
+  shard_ = nullptr;
+  batchsize_ = proto.sharddata_conf().batchsize();
+  if (partition_dim() == 0)
+    batchsize_ /= npartitions;
+  records_.resize(batchsize_);
+  random_skip_ = proto.sharddata_conf().random_skip();
+}
+
+void ShardDataLayer::ComputeFeature(int flag, Metric* perf) {
+  if ((flag & kForward) == 0)
+    return;
+
+  if (shard_ == nullptr)
+    shard_ = new DataShard(layer_proto_.sharddata_conf().path(),
+                           DataShard::kRead);
+  if (random_skip_) {
+    int nskip = rand() % random_skip_;
+    LOG(INFO) << "Random Skip " << nskip << " records, there are "
+              << shard_->Count() << " records in total";
+    string key;
+    for (int i = 0; i < nskip; i++) {
+      shard_->Next(&key, &sample_);
+    }
+    random_skip_ = 0;
+  }
+  for (auto& record : records_) {
+    string key;
+    if (!shard_->Next(&key, &record)) {
+      shard_->SeekToFirst();
+      CHECK(shard_->Next(&key, &record));
+    }
+  }
+}
+
+/********* Implementation for LabelLayer **************/
+void LabelLayer::Setup(const LayerProto& proto, int npartitions) {
+  Layer::Setup(proto, npartitions);
+  CHECK_EQ(srclayers_.size(), 1);
+  int batchsize = static_cast<DataLayer*>(srclayers_[0])->batchsize();
+  data_.Reshape(vector<int>{batchsize});
+}
+
+void LabelLayer::ParseRecords(int flag, const vector<Record>& records,
+                              Blob<float>* blob) {
+  int rid = 0;
+  float *label = blob->mutable_cpu_data();
+  for (const Record& record : records) {
+    label[rid++] = record.image().label();
+    // CHECK_LT(record.image().label(),10);
+  }
+  CHECK_EQ(rid, blob->shape()[0]);
+}
+
+/**************** Implementation for MnistLayer ******************/
+void MnistLayer::ParseRecords(int flag,
+    const vector<Record>& records, Blob<float>* blob){
+  if ((flag & kForward) == 0)
+    return;
+  LOG_IF(ERROR, records.size()==0)<<"Empty records to parse";
+  int ndim=records.at(0).image().shape_size();
+  int inputsize =records.at(0).image().shape(ndim-1);
+  CHECK_EQ(inputsize, blob->shape()[2]);
+
+  float* dptr=blob->mutable_cpu_data();
+  for(const Record& record: records){
+    const SingleLabelImageRecord& imagerecord=record.image();
+    if(imagerecord.pixel().size()) {
+      string pixel=imagerecord.pixel();
+      for(int i = 0, k = 0; i < inputsize; i++) {
+        for(int j = 0; j < inputsize; j++) {
+          // NOTE!!! must cast pixel to uint8_t then to float!!! waste a lot of
+          // time to debug this
+          float x =  static_cast<float>(static_cast<uint8_t>(pixel[k++]));
+          x = x / norm_a_-norm_b_;
+          *dptr = x;
+          dptr++;
+        }
+      }
+    } else {
+      for(int i = 0, k = 0; i < inputsize; i++) {
+        for(int j = 0; j < inputsize; j++) {
+          *dptr = imagerecord.data(k++) / norm_a_ - norm_b_;
+          dptr++;
+        }
+      }
+    }
+  }
+  CHECK_EQ(dptr, blob->mutable_cpu_data()+blob->count());
+}
+void MnistLayer::Setup(const LayerProto& proto, int npartitions) {
+  Layer::Setup(proto, npartitions);
+  CHECK_EQ(srclayers_.size(), 1);
+  int batchsize = static_cast<DataLayer*>(srclayers_[0])->batchsize();
+  Record sample = static_cast<DataLayer*>(srclayers_[0])->sample();
+  kernel_ = proto.mnist_conf().kernel();
+  sigma_ = proto.mnist_conf().sigma();
+  alpha_ = proto.mnist_conf().alpha();
+  beta_ = proto.mnist_conf().beta();
+  gamma_ = proto.mnist_conf().gamma();
+  resize_ = proto.mnist_conf().resize();
+  norm_a_ = proto.mnist_conf().norm_a();
+  norm_b_ = proto.mnist_conf().norm_b();
+  elastic_freq_ = proto.mnist_conf().elastic_freq();
+  int ndim = sample.image().shape_size();
+  CHECK_GE(ndim, 2);
+  if (resize_) {
+    data_.Reshape(vector<int>{batchsize, 1, resize_, resize_});
+  } else {
+    int s = sample.image().shape(ndim - 1);
+    CHECK_EQ(s, sample.image().shape(ndim - 2));
+    data_.Reshape(vector<int>{batchsize, 1, s, s});
+  }
+}
+
+/*************** Implementation for RGBImageLayer *************************/
+void RGBImageLayer::ParseRecords(int flag,
+    const vector<Record>& records, Blob<float>* blob){
+  if ((flag & kForward) == 0)
+    return;
+
+  const vector<int>& s=blob->shape();
+  auto images = Tensor4(&data_);
+  const SingleLabelImageRecord& r=records.at(0).image();
+  Tensor<cpu, 3> raw_image(Shape3(r.shape(0),r.shape(1),r.shape(2)));
+  AllocSpace(raw_image);
+  Tensor<cpu, 3> croped_image(nullptr, Shape3(s[1],s[2],s[3]));
+  if(cropsize_)
+    AllocSpace(croped_image);
+    //CHECK(std::equal(croped_image.shape(), raw_image.shape());
+  int rid=0;
+  const float* meandptr=mean_.cpu_data();
+  for(const Record& record: records){
+    auto image=images[rid];
+    bool do_crop = cropsize_ > 0 && ((flag & kTrain) == kTrain);
+    bool do_mirror = mirror_ && rand() % 2 && ((flag & kTrain) == kTrain);
+    float* dptr=nullptr;
+    if(do_crop||do_mirror)
+      dptr=raw_image.dptr;
+    else
+      dptr=image.dptr;
+    if(record.image().pixel().size()){
+      string pixel=record.image().pixel();
+      for(size_t i=0;i<pixel.size();i++)
+        dptr[i]=static_cast<float>(static_cast<uint8_t>(pixel[i]));
+    }else {
+      memcpy(dptr, record.image().data().data(),
+          sizeof(float)*record.image().data_size());
+    }
+    for(int i=0;i<mean_.count();i++)
+      dptr[i]-=meandptr[i];
+
+    if(do_crop){
+      int hoff=rand()%(r.shape(1)-cropsize_);
+      int woff=rand()%(r.shape(2)-cropsize_);
+      Shape<2> cropshape=Shape2(cropsize_, cropsize_);
+      if(do_mirror){
+        croped_image=crop(raw_image, cropshape, hoff, woff);
+        image=mirror(croped_image);
+      }else{
+        image=crop(raw_image, cropshape, hoff, woff);
+      }
+    }else if(do_mirror){
+      image=mirror(raw_image);
+    }
+    rid++;
+  }
+}
+
+void RGBImageLayer::Setup(const LayerProto& proto, int npartitions) {
+  ParserLayer::Setup(proto, npartitions);
+  CHECK_EQ(srclayers_.size(), 1);
+  scale_ = proto.rgbimage_conf().scale();
+  cropsize_ = proto.rgbimage_conf().cropsize();
+  mirror_ = proto.rgbimage_conf().mirror();
+  int batchsize = static_cast<DataLayer*>(srclayers_[0])->batchsize();
+  Record sample = static_cast<DataLayer*>(srclayers_[0])->sample();
+  vector<int> shape;
+  shape.push_back(batchsize);
+  for (int x : sample.image().shape()) {
+    shape.push_back(x);
+  }
+  CHECK_EQ(shape.size(), 4);
+  if (cropsize_) {
+    shape[2] = cropsize_;
+    shape[3] = cropsize_;
+  }
+  data_.Reshape(shape);
+  mean_.Reshape({shape[1], shape[2], shape[3]});
+  if (proto.rgbimage_conf().has_meanfile()) {
+    if (proto.rgbimage_conf().meanfile().find("binaryproto") != string::npos) {
+      CaffeBlob mean;
+      ReadProtoFromBinaryFile(proto.rgbimage_conf().meanfile().c_str(), &mean);
+      CHECK_EQ(mean_.count(), mean.data_size());
+      memcpy(mean_.mutable_cpu_data(), mean.data().data(),
+             sizeof(float)*mean.data_size());
+    } else {
+      SingleLabelImageRecord mean;
+      ReadProtoFromBinaryFile(proto.rgbimage_conf().meanfile().c_str(), &mean);
+      CHECK_EQ(mean_.count(), mean.data_size());
+      memcpy(mean_.mutable_cpu_data(), mean.data().data(),
+             sizeof(float)*mean.data_size());
+    }
+  } else {
+    memset(mean_.mutable_cpu_data(), 0, sizeof(float) * mean_.count());
+  }
+}
+
+/************ Implementation for ConvolutionLayer*************************/
 ConvolutionLayer::~ConvolutionLayer() {
   delete weight_;
   delete bias_;
 }
 void ConvolutionLayer::Setup(const LayerProto& proto, int npartitions) {
   Layer::Setup(proto, npartitions);
-  ConvolutionProto conv_conf=proto.convolution_conf();
-  kernel_=conv_conf.kernel();
+  ConvolutionProto conv_conf = proto.convolution_conf();
+  kernel_ = conv_conf.kernel();
   CHECK_GT(kernel_, 0) << "Filter size cannot be zero.";
-  pad_=conv_conf.pad();
-  stride_=conv_conf.stride();
-  num_filters_=conv_conf.num_filters();
-  if(partition_dim() > 0)
+  pad_ = conv_conf.pad();
+  stride_ = conv_conf.stride();
+  num_filters_ = conv_conf.num_filters();
+  if (partition_dim() > 0)
     num_filters_ /= npartitions;
-
-  const vector<int>& srcshape=srclayers_[0]->data(this).shape();
-  int dim=srcshape.size();
+  const vector<int>& srcshape = srclayers_[0]->data(this).shape();
+  int dim = srcshape.size();
   CHECK_GT(dim, 2);
-  width_=srcshape[dim-1];
-  height_=srcshape[dim-2];
-  if(dim>3)
-    channels_=srcshape[dim-3];
-  else if(dim>2)
-    channels_=1;
-  batchsize_=srcshape[0];
-  conv_height_=(height_ + 2 * pad_ - kernel_) / stride_ + 1;
-  conv_width_= (width_ + 2 * pad_ - kernel_) / stride_ + 1;
-  col_height_=channels_*kernel_*kernel_;
-  col_width_=conv_height_*conv_width_;
+  width_ = srcshape[dim - 1];
+  height_ = srcshape[dim - 2];
+  if (dim > 3)
+    channels_ = srcshape[dim - 3];
+  else if (dim > 2)
+    channels_ = 1;
+  batchsize_ = srcshape[0];
+  conv_height_ = (height_ + 2 * pad_ - kernel_) / stride_ + 1;
+  conv_width_ = (width_ + 2 * pad_ - kernel_) / stride_ + 1;
+  col_height_ = channels_ * kernel_ * kernel_;
+  col_width_ = conv_height_ * conv_width_;
   vector<int> shape{batchsize_, num_filters_, conv_height_, conv_width_};
   data_.Reshape(shape);
   grad_.Reshape(shape);
   col_data_.Reshape(vector<int>{col_height_, col_width_});
   col_grad_.Reshape(vector<int>{col_height_, col_width_});
-
   weight_ = Param::Create(proto.param(0));
   bias_ = Param::Create(proto.param(1));
   weight_->Setup(vector<int>{num_filters_, col_height_});
@@ -82,46 +327,41 @@ void ConvolutionLayer::ComputeFeature(int flag, Metric* perf){
   auto col = Tensor2(&col_data_);
   auto weight = Tensor2(weight_->mutable_data());
   auto bias = Tensor1(bias_->mutable_data());
-
-  for(int n=0;n<batchsize_;n++){
-    if(pad_>0)
-      col=unpack_patch2col(pad(src[n], pad_), kernel_, stride_);
+  for (int n = 0; n < batchsize_; n++) {
+    if (pad_ > 0)
+      col = unpack_patch2col(pad(src[n], pad_), kernel_, stride_);
     else
-      col=unpack_patch2col(src[n], kernel_, stride_);
-    data[n]=dot(weight, col);
+      col = unpack_patch2col(src[n], kernel_, stride_);
+    data[n] = dot(weight, col);
   }
-  data+=broadcast<1>(bias, data.shape);
+  data += broadcast<1>(bias, data.shape);
 }
 
-void ConvolutionLayer::ComputeGradient(int flag) {
+void ConvolutionLayer::ComputeGradient(int flag, Metric* perf) {
   auto src = Tensor4(srclayers_[0]->mutable_data(this));
   auto col = Tensor2(&col_data_);
   auto weight = Tensor2(weight_->mutable_data());
-
   auto grad = Tensor3(&grad_);
   auto gcol = Tensor2(&col_grad_);
   auto gweight = Tensor2(weight_->mutable_grad());
   auto gbias = Tensor1(bias_->mutable_grad());
-
-  Blob<float>* gsrcblob=srclayers_[0]->mutable_grad(this);
+  Blob<float>* gsrcblob = srclayers_[0]->mutable_grad(this);
   Tensor<cpu, 4> gsrc(nullptr, Shape4(batchsize_, channels_, height_, width_));
-  if(gsrcblob!=nullptr)
-    gsrc.dptr=gsrcblob->mutable_cpu_data();
-  gbias=sumall_except_dim<1>(grad);
-
+  if (gsrcblob != nullptr)
+    gsrc.dptr = gsrcblob->mutable_cpu_data();
+  gbias = sumall_except_dim<1>(grad);
   gweight = 0.0f;
   Shape<3> padshp(gsrc.shape.SubShape());
   padshp[0] += 2 * pad_;
   padshp[1] += 2 * pad_;
   Shape<2> imgshp = Shape2(height_, width_);
-  for(int n=0;n<batchsize_;n++){
-    if(pad_>0)
-      col=unpack_patch2col(pad(src[n], pad_), kernel_, stride_);
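+  // Backward pass reuses the im2col layout: gweight accumulates
+  // grad[n] * col^T over the batch, and the source gradient folds
+  // weight^T * grad[n] back into image layout via pack_col2patch,
+  // cropping off the forward pass's zero padding.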
+  for (int n = 0; n < batchsize_; n++) {
+    if (pad_ > 0)
+      col = unpack_patch2col(pad(src[n], pad_), kernel_, stride_);
     else
-      col=unpack_patch2col(src[n], kernel_, stride_);
+      col = unpack_patch2col(src[n], kernel_, stride_);
     gweight += dot(grad[n], col.T());
-
-    if(gsrcblob!=nullptr){
+    if (gsrcblob != nullptr) {
       gcol = dot(weight.T(), grad[n]);
       gsrc[n] = crop(pack_col2patch(gcol, padshp, kernel_, stride_), imgshp);
     }
@@ -143,16 +383,16 @@ void DropoutLayer::ComputeFeature(int flag, Metric* perf) {
     data_.CopyFrom(srclayers_[0]->data(this));
     return;
   }
-  float pkeep=1-pdrop_;
+  float pkeep = 1 - pdrop_;
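+  // Inverted dropout: each unit survives with probability pkeep and is
+  // scaled by 1 / pkeep, so the expected activation is unchanged and no
+  // extra scaling is needed at test time (pdrop_ = 0.5 doubles survivors).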
   auto mask = Tensor1(&mask_);
-  mask = F<op::threshold>(TSingleton<Random<cpu>>::Instance()\
-      ->uniform(mask.shape), pkeep ) * (1.0f/pkeep);
+  mask = F<threshold>(TSingleton<Random<cpu>>::Instance()
+                      ->uniform(mask.shape), pkeep) * (1.0f / pkeep);
   auto data = Tensor1(&data_);
   auto src = Tensor1(srclayers_[0]->mutable_data(this));
   data = src * mask;
 }
 
-void DropoutLayer::ComputeGradient(int flag)  {
+void DropoutLayer::ComputeGradient(int flag, Metric* perf)  {
   auto mask = Tensor1(&mask_);
   auto grad = Tensor1(&grad_);
   auto gsrc = Tensor1(srclayers_[0]->mutable_grad(this));
@@ -221,7 +461,7 @@ void RBMVisLayer::ComputeFeature(int flag, Metric* perf) {
   }
 }
 
-void RBMVisLayer::ComputeGradient(int flag) {
+void RBMVisLayer::ComputeGradient(int flag, Metric* perf) {
   auto vis_pos = Tensor2(&data_);
   auto vis_neg = Tensor2(&neg_data_);
     auto gbias = Tensor1(bias_->mutable_grad());
@@ -292,7 +532,7 @@ void RBMHidLayer::ComputeFeature(int flag, Metric* perf) {
     data = F<op::sigmoid>(data);
 }
 
-void RBMHidLayer::ComputeGradient(int flag) {
+void RBMHidLayer::ComputeGradient(int flag, Metric* perf) {
   auto hid_pos = Tensor2(&data_);
   auto hid_neg = Tensor2(&neg_data_);
   auto vis_pos = Tensor2(vis_layer_->mutable_data(this));
@@ -313,13 +553,12 @@ InnerProductLayer::~InnerProductLayer() {
   delete weight_;
   delete bias_;
 }
+
 void InnerProductLayer::Setup(const LayerProto& proto, int npartitions) {
   Layer::Setup(proto, npartitions);
   CHECK_EQ(srclayers_.size(), 1);
   const auto& src = srclayers_[0]->data(this);
   batchsize_ = src.shape()[0];
-  vdim_ = src.count()/batchsize_;
-  hdim_ = proto.innerproduct_conf().num_output();
   transpose_ = proto.innerproduct_conf().transpose();
   if (partition_dim() > 0)
     hdim_ /= npartitions;
@@ -344,10 +583,12 @@ void InnerProductLayer::ComputeFeature(int flag, Metric* perf) {
   else
     data = dot(src, weight.T());
   // repmat: repeat bias vector into batchsize rows
-  data+=repmat(bias, batchsize_);
+  data += repmat(bias, batchsize_);
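+  // Shape sketch: src is batchsize_ x vdim_ and weight is oriented (via
+  // transpose_) so that the product is batchsize_ x hdim_; repmat then
+  // adds the bias row once per sample.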
 }
 
-void InnerProductLayer::ComputeGradient(int phas) {
+void InnerProductLayer::ComputeGradient(int flag, Metric* perf) {
+  if ((flag & kForward) != kForward)
+    return;
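+  // Parameter gradients are computed only when flag includes kForward;
+  // other invocations return above without touching them.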
   auto src = Tensor2(srclayers_[0]->mutable_data(this));
   auto grad = Tensor2(&grad_);
   auto weight = Tensor2(weight_->mutable_data());
@@ -385,27 +626,25 @@ void LabelLayer::ParseRecords(int flag, const vector<Record>& records,
     label[rid++]=record.image().label();
     //  CHECK_LT(record.image().label(),10);
   }
-  CHECK_EQ(rid, blob->shape()[0]);
 }
 
 /***************** Implementation for LRNLayer *************************/
 void LRNLayer::Setup(const LayerProto& proto, int npartitions) {
   Layer::Setup(proto, npartitions);
-  CHECK_EQ(srclayers_.size(),1);
+  CHECK_EQ(srclayers_.size(), 1);
   lsize_ = proto.lrn_conf().local_size();
   CHECK_EQ(lsize_ % 2, 1) << "LRN only supports odd values for Localvol";
-  knorm_=proto.lrn_conf().knorm();
+  knorm_ = proto.lrn_conf().knorm();
   alpha_ = proto.lrn_conf().alpha();
   beta_ = proto.lrn_conf().beta();
-
-  const vector<int>& s=srclayers_[0]->data(this).shape();
+  const vector<int>& s = srclayers_[0]->data(this).shape();
   data_.Reshape(s);
   grad_.Reshape(s);
   norm_.Reshape(s);
-  batchsize_=s[0];
-  channels_=s[1];
-  height_=s[2];
-  width_=s[3];
+  batchsize_ = s[0];
+  channels_ = s[1];
+  height_ = s[2];
+  width_ = s[3];
 }
 
 void LRNLayer::ComputeFeature(int flag, Metric* perf) {
@@ -414,11 +653,11 @@ void LRNLayer::ComputeFeature(int flag, Metric* perf) {
   auto data = Tensor4(&data_);
   auto norm = Tensor4(&norm_);
   // stores normalizer without power
-  norm= chpool<red::sum>( F<op::square>(src) , lsize_ ) * salpha + knorm_;
-  data = src * F<op::power>(norm, -beta_ );
+  norm = chpool<sum>(F<square>(src), lsize_) * salpha + knorm_;
+  data = src * F<power>(norm, -beta_);
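+  // i.e. cross-channel LRN: data = src * norm^(-beta_), where norm is
+  // knorm_ + (alpha_ / lsize_) * (sum of src^2 over lsize_ adjacent
+  // channels).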
 }
 
-void LRNLayer::ComputeGradient(int flag) {
+void LRNLayer::ComputeGradient(int flag, Metric* perf) {
   const float salpha = alpha_ / lsize_;
   auto src = Tensor4(srclayers_[0]->mutable_data(this));
   auto norm = Tensor4(&norm_);
@@ -430,125 +669,61 @@ void LRNLayer::ComputeGradient(int flag) {
       grad * src * F<op::power>( norm, -beta_-1.0f ), lsize_ )  * src;
 }
 
-/**************** Implementation for MnistImageLayer******************/
-
-void MnistLayer::ParseRecords(int flag,
-    const vector<Record>& records, Blob<float>* blob){
-  if ((flag & kForward) == 0)
-    return;
-  LOG_IF(ERROR, records.size()==0)<<"Empty records to parse";
-  int ndim=records.at(0).image().shape_size();
-  int inputsize =records.at(0).image().shape(ndim-1);
-  CHECK_EQ(inputsize, blob->shape()[2]);
-
-  float* dptr=blob->mutable_cpu_data();
-  for(const Record& record: records){
-    const SingleLabelImageRecord& imagerecord=record.image();
-    if(imagerecord.pixel().size()) {
-      string pixel=imagerecord.pixel();
-      for(int i = 0, k = 0; i < inputsize; i++) {
-        for(int j = 0; j < inputsize; j++) {
-          // NOTE!!! must cast pixel to uint8_t then to float!!! waste a lot of
-          // time to debug this
-          float x =  static_cast<float>(static_cast<uint8_t>(pixel[k++]));
-          x = x / norm_a_-norm_b_;
-          *dptr = x;
-          dptr++;
-        }
-      }
-    } else {
-      for(int i = 0, k = 0; i < inputsize; i++) {
-        for(int j = 0; j < inputsize; j++) {
-          *dptr = imagerecord.data(k++) / norm_a_ - norm_b_;
-          dptr++;
-        }
-      }
-    }
-  }
-  CHECK_EQ(dptr, blob->mutable_cpu_data()+blob->count());
-}
-void MnistLayer::Setup(const LayerProto& proto, int npartitions) {
-  Layer::Setup(proto, npartitions);
-  CHECK_EQ(srclayers_.size(),1);
-  int batchsize=static_cast<DataLayer*>(srclayers_[0])->batchsize();
-  Record sample=static_cast<DataLayer*>(srclayers_[0])->sample();
-  kernel_=proto.mnist_conf().kernel();
-  sigma_=proto.mnist_conf().sigma();
-  alpha_=proto.mnist_conf().alpha();
-  beta_=proto.mnist_conf().beta();
-  gamma_=proto.mnist_conf().gamma();
-  resize_=proto.mnist_conf().resize();
-  norm_a_=proto.mnist_conf().norm_a();
-  norm_b_=proto.mnist_conf().norm_b();
-  elastic_freq_=proto.mnist_conf().elastic_freq();
-
-  int ndim=sample.image().shape_size();
-  CHECK_GE(ndim,2);
-  if(resize_)
-    data_.Reshape(vector<int>{batchsize, 1, resize_, resize_});
-  else{
-    int s=sample.image().shape(ndim-1);
-    CHECK_EQ(s,sample.image().shape(ndim-2));
-    data_.Reshape(vector<int>{batchsize, 1, s, s });
-  }
-}
-
 /******************** Implementation for PoolingLayer ******************/
 void PoolingLayer::Setup(const LayerProto& proto, int npartitions) {
   Layer::Setup(proto, npartitions);
-  CHECK_EQ(srclayers_.size(),1);
+  CHECK_EQ(srclayers_.size(), 1);
   PoolingProto pool_conf = proto.pooling_conf();
-  kernel_=pool_conf.kernel();
-  stride_=pool_conf.stride();
+  kernel_ = pool_conf.kernel();
+  stride_ = pool_conf.stride();
   CHECK_LT(pad_, kernel_);
-  pool_=proto.pooling_conf().pool();
+  pool_ = proto.pooling_conf().pool();
   CHECK(pool_ == PoolingProto_PoolMethod_AVE
         || pool_ == PoolingProto_PoolMethod_MAX)
-      << "Padding implemented only for average and max pooling.";
-
-  const auto& srcshape=srclayers_[0]->data(this).shape();
-  int dim=srcshape.size();
-  CHECK_GT(dim,2);
-  width_ = srcshape[dim-1];
-  height_ = srcshape[dim-2];
-  if(dim>3)
+        << "Padding implemented only for average and max pooling.";
+  const auto& srcshape = srclayers_[0]->data(this).shape();
+  int dim = srcshape.size();
+  CHECK_GT(dim, 2);
+  width_ = srcshape[dim - 1];
+  height_ = srcshape[dim - 2];
+  if (dim > 3)
     channels_ = srcshape[dim-3];
   else
-    channels_=1;
-  batchsize_=srcshape[0];
+    channels_ = 1;
+  batchsize_ = srcshape[0];
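+  // Pooling output size without padding (illustrative): height_ = 4,
+  // kernel_ = 2, stride_ = 2 gives (4 - 2) / 2 + 1 = 2 pooled rows.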
   pooled_height_ = static_cast<int>((height_ - kernel_) / stride_) + 1;
-  pooled_width_ = static_cast<int>(( width_ - kernel_) / stride_) + 1;
-  data_.Reshape(vector<int>{batchsize_, channels_, pooled_height_, pooled_width_});
+  pooled_width_ = static_cast<int>((width_ - kernel_) / stride_) + 1;
+  data_.Reshape(vector<int>{batchsize_, channels_, pooled_height_,
+                            pooled_width_});
   grad_.ReshapeLike(data_);
 }
 
 void PoolingLayer::ComputeFeature(int flag, Metric* perf) {
   auto src = Tensor4(srclayers_[0]->mutable_data(this));
   auto data = Tensor4(&data_);
-  if(pool_ == PoolingProto_PoolMethod_MAX)
-    data=pool<red::maximum>(src, kernel_, stride_);
-  else if(pool_ == PoolingProto_PoolMethod_AVE)
-    data=pool<red::sum>(src, kernel_, stride_) *(1.0f/(kernel_*kernel_));
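+  // Average pooling is expressed as a sum-pool rescaled by
+  // 1 / (kernel_ * kernel_).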
+  if (pool_ == PoolingProto_PoolMethod_MAX)
+    data = pool<maximum>(src, kernel_, stride_);
+  else if (pool_ == PoolingProto_PoolMethod_AVE)
+    data = pool<sum>(src, kernel_, stride_) * (1.0f / (kernel_ * kernel_));
 }
 
 /*
  * partition only on num/channel dim
  * assume grad and data have the same partition
  */
-void PoolingLayer::ComputeGradient(int flag) {
+void PoolingLayer::ComputeGradient(int flag, Metric* perf) {
   auto src = Tensor4(srclayers_[0]->mutable_data(this));
   auto gsrc = Tensor4(srclayers_[0]->mutable_grad(this));
   auto data = Tensor4(&data_);
   auto grad = Tensor4(&grad_);
-  if(pool_ == PoolingProto_PoolMethod_MAX)
-    gsrc = unpool<red::maximum>(src, data, grad, kernel_, stride_);
-  else if(pool_ == PoolingProto_PoolMethod_AVE)
-    gsrc = unpool<red::sum>(src, data, grad, kernel_, stride_)
-      *(1.0f/(kernel_*kernel_));
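+  // unpool routes each pooled gradient back to the argmax position (MAX)
+  // or spreads it over the whole window (AVE), mirroring the forward
+  // 1 / (kernel_ * kernel_) factor.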
+  if (pool_ == PoolingProto_PoolMethod_MAX)
+    gsrc = unpool<maximum>(src, data, grad, kernel_, stride_);
+  else if (pool_ == PoolingProto_PoolMethod_AVE)
+    gsrc = unpool<sum>(src, data, grad, kernel_, stride_)
+           * (1.0f / (kernel_ * kernel_));
 }
 
 /***************** Implementation for ReLULayer *****************************/
-
 void ReLULayer::Setup(const LayerProto& proto, int npartitions) {
   Layer::Setup(proto, npartitions);
   data_.ReshapeLike(srclayers_[0]->data(this));
@@ -558,162 +733,54 @@ void ReLULayer::Setup(const LayerProto& proto, int npartitions) {
 void ReLULayer::ComputeFeature(int flag, Metric* perf) {
   auto data = Tensor1(&data_);
   auto src = Tensor1(srclayers_[0]->mutable_data(this));
-  data=F<op::relu>(src);
+  data = F<relu>(src);
 }
 
-void ReLULayer::ComputeGradient(int flag) {
+void ReLULayer::ComputeGradient(int flag, Metric* perf) {
   auto data = Tensor1(&data_);
   auto grad = Tensor1(&grad_);
   auto gsrc = Tensor1(srclayers_[0]->mutable_grad(this));
-  gsrc=F<op::relu_grad>(data)*grad;
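+  // relu_grad(data) is 1 where data > 0 and 0 elsewhere, so gradients
+  // flow only through units that were active in the forward pass.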
+  gsrc = F<relu_grad>(data) * grad;
 }
 
-/*************** Implementation for RGBImageLayer *************************/
-
-void RGBImageLayer::ParseRecords(int flag,
-    const vector<Record>& records, Blob<float>* blob){
-  if ((flag & kForward) == 0)
-    return;
-
-  const vector<int>& s=blob->shape();
-  auto images = Tensor4(&data_);
-  const SingleLabelImageRecord& r=records.at(0).image();
-  Tensor<cpu, 3> raw_image(Shape3(r.shape(0),r.shape(1),r.shape(2)));
-  AllocSpace(raw_image);
-  Tensor<cpu, 3> croped_image(nullptr, Shape3(s[1],s[2],s[3]));
-  if(cropsize_)
-    AllocSpace(croped_image);
-    //CHECK(std::equal(croped_image.shape(), raw_image.shape());
-  int rid=0;
-  const float* meandptr=mean_.cpu_data();
-  for(const Record& record: records){
-    auto image=images[rid];
-    bool do_crop = cropsize_ > 0 && ((flag & kTrain) == kTrain);
-    bool do_mirror = mirror_ && rand() % 2 && ((flag & kTrain) == kTrain);
-    float* dptr=nullptr;
-    if(do_crop||do_mirror)
-      dptr=raw_image.dptr;
-    else
-      dptr=image.dptr;
-    if(record.image().pixel().size()){
-      string pixel=record.image().pixel();
-      for(size_t i=0;i<pixel.size();i++)
-        dptr[i]=static_cast<float>(static_cast<uint8_t>(pixel[i]));
-    }else {
-      memcpy(dptr, record.image().data().data(),
-          sizeof(float)*record.image().data_size());
-    }
-    for(int i=0;i<mean_.count();i++)
-      dptr[i]-=meandptr[i];
-
-    if(do_crop){
-      int hoff=rand()%(r.shape(1)-cropsize_);
-      int woff=rand()%(r.shape(2)-cropsize_);
-      Shape<2> cropshape=Shape2(cropsize_, cropsize_);
-      if(do_mirror){
-        croped_image=crop(raw_image, cropshape, hoff, woff);
-        image=mirror(croped_image);
-      }else{
-        image=crop(raw_image, cropshape, hoff, woff);
-      }
-    }else if(do_mirror){
-      image=mirror(raw_image);
-    }
-    rid++;
-  }
-  if(scale_)
-    images=images*scale_;
-
-  FreeSpace(raw_image);
-  if(cropsize_)
-    FreeSpace(croped_image);
+/**************** Implementation for RBMHidLayer ********************/
+RBMHidLayer::~RBMHidLayer() {
+  delete weight_;
+  delete bias_;
 }
-void RGBImageLayer::Setup(const LayerProto& proto, int npartitions) {
-  ParserLayer::Setup(proto, npartitions);
-  CHECK_EQ(srclayers_.size(),1);
-  scale_=proto.rgbimage_conf().scale();
-  cropsize_=proto.rgbimage_conf().cropsize();
-  mirror_=proto.rgbimage_conf().mirror();
-  int batchsize=static_cast<DataLayer*>(srclayers_[0])->batchsize();
-  Record sample=static_cast<DataLayer*>(srclayers_[0])->sample();
-  vector<int> shape;
-  shape.push_back(batchsize);
-  for(int x: sample.image().shape()){
-    shape.push_back(x);
-  }
-  CHECK_EQ(shape.size(),4);
-  if(cropsize_){
-    shape[2]=cropsize_;
-    shape[3]=cropsize_;
-  }
-  data_.Reshape(shape);
-  mean_.Reshape({shape[1],shape[2],shape[3]});
-  if(proto.rgbimage_conf().has_meanfile()){
-    if(proto.rgbimage_conf().meanfile().find("binaryproto") != string::npos) {
-      CaffeBlob mean;
-      ReadProtoFromBinaryFile(proto.rgbimage_conf().meanfile().c_str(), &mean);
-      CHECK_EQ(mean_.count(), mean.data_size());
-      memcpy(mean_.mutable_cpu_data(), mean.data().data(),
-          sizeof(float)*mean.data_size());
-    } else {
-      SingleLabelImageRecord mean;
-      ReadProtoFromBinaryFile(proto.rgbimage_conf().meanfile().c_str(), &mean);
-      CHECK_EQ(mean_.count(), mean.data_size());
-      memcpy(mean_.mutable_cpu_data(), mean.data().data(),
-          sizeof(float)*mean.data_size());
-    }
-  } else {
-    memset(mean_.mutable_cpu_data(),0,sizeof(float)*mean_.count());
-  }
+void RBMHidLayer::Setup(const LayerProto& proto, int npartitions) {
+  Layer::Setup(proto, npartitions);
+  CHECK_EQ(srclayers_.size(), 1);
+  const auto& src_data = srclayers_[0]->data(this, kPositive);
+  const auto& src_sample = srclayers_[0]->data(this, kNegative);
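+  // kPositive/kNegative fetch the data-driven and sampled (model-driven)
+  // versions of the source blob, following contrastive-divergence naming.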
+  scale_ = 1.0f;
+  batchsize_ = src_data.shape()[0];
+  neg_batchsize_ = src_sample.shape()[0];
+  vdim_ = src_data.count() / batchsize_;
+  hdim_ = proto.rbmhid_conf().hid_dim();
+  data_.Reshape(vector<int>{batchsize_, hdim_});
+  hid_sample_.Reshape(vector<int>{neg_batchsize_, hdim_});
+  weight_ = Param::Create(proto.param(0));
+  bias_ = Param::Create(proto.param(1));
+  weight_->Setup(proto.param(0), vector<int>{vdim_, hdim_});
+  bias_->Setup(proto.param(1), vector<int>{hdim_});
 }
 
-/***************Implementation for ShardDataLayer**************************/
-void ShardDataLayer::ComputeFeature(int flag, Metric* perf){
-  if ((flag & kForward) == 0)
-    return;
-
-  if (shard_ == nullptr)
-    shard_ = new DataShard(layer_proto_.sharddata_conf().path(),
-        DataShard::kRead);
-  if(random_skip_){
-    int nskip = rand() % random_skip_;
-    LOG(INFO)<<"Random Skip "<<nskip<<" records, there are "<<shard_->Count()
-      <<" records in total";
-    string key;
-    for(int i=0;i<nskip;i++){
-      shard_->Next(&key, &sample_);
-    }
-    random_skip_=0;
-  }
-  for(auto& record: records_){
-    string key;
-    if(!shard_->Next(&key, &record)){
-      shard_->SeekToFirst();
-      CHECK(shard_->Next(&key, &record));
-    }
-  }
+void RBMHidLayer::ComputeGradient(int flag, Metric* perf) {
+  auto data = Tensor2(&data_);
+  auto hid_sample = Tensor2(&hid_sample_);
+  auto gbias = Tensor1(bias_->mutable_grad());
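+  // CD bias gradient: mean hidden activation under the model (hid_sample)
+  // minus the mean under the data, scaled by scale_ / batchsize_.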
+  gbias = sum_rows(hid_sample);
+  gbias -= sum_rows(data);
+  gbias *= scale_ / (1.0f * batchsize_);
 }
 
-void ShardDataLayer::Setup(const LayerProto& proto, int npartitions) {
-  Layer::Setup(proto, npartitions);
-  shard_= new DataShard(proto.sharddata_conf().path(), DataShard::kRead);
-  string key;
-  shard_->Next(&key, &sample_);
-  delete shard_;
-  shard_ = nullptr;
-  batchsize_=proto.sharddata_conf().batchsize();
-  if(partition_dim() == 0)
-    batchsize_ /= npartitions;
-
-  records_.resize(batchsize_);
-  random_skip_=proto.sharddata_conf().random_skip();
+/**************** Implementation for RBMVisLayer ********************/
+RBMVisLayer::~RBMVisLayer() {
+  delete weight_;
+  delete bias_;
 }
 
-ShardDataLayer::~ShardDataLayer() {
-  if (shard_ != nullptr)
-    delete shard_;
-  shard_ = nullptr;
-}
 /******************* Implementation of SigmoidLayer ***********************/
 void SigmoidLayer::Setup(const LayerProto& proto, int npartitions) {
   Layer::Setup(proto, npartitions);
@@ -727,14 +794,14 @@ void SigmoidLayer::ComputeFeature(int flag, Metric* perf) {
   data = F<op::sigmoid>(src);
 }
 
-void SigmoidLayer::ComputeGradient(int flag) {
+void SigmoidLayer::ComputeGradient(int flag, Metric* perf) {
   auto data = Tensor1(&data_);
   auto grad = Tensor1(&grad_);
   auto gsrc = Tensor1(srclayers_[0]->mutable_grad(this));
   gsrc = F<op::sigmoid_grad>(data)*grad;
 }
 /******************* Implementation of TanhLayer **************************/
-void TanhLayer::Setup(const LayerProto& proto, int npartitions){
+void TanhLayer::Setup(const LayerProto& proto, int npartitions) {
   Layer::Setup(proto, npartitions);
   data_.ReshapeLike(srclayers_[0]->data(this));
   grad_.ReshapeLike(srclayers_[0]->grad(this));
@@ -743,14 +810,14 @@ void TanhLayer::Setup(const LayerProto& proto, int npartitions){
 void TanhLayer::ComputeFeature(int flag, Metric* perf) {
   auto data = Tensor1(&data_);
   auto src = Tensor1(srclayers_[0]->mutable_data(this));
-  data=F<op::stanh>(src);
+  data = F<stanh>(src);
 }
 
-void TanhLayer::ComputeGradient(int flag) {
+void TanhLayer::ComputeGradient(int flag, Metric* perf) {
   auto data = Tensor1(&data_);
   auto grad = Tensor1(&grad_);
   auto gsrc = Tensor1(srclayers_[0]->mutable_grad(this));
-  gsrc=F<op::stanh_grad>(data)*grad;
+  gsrc = F<stanh_grad>(data) * grad;
 }
 /********** Implementation for EuclideanLossLayer *************************/
 void EuclideanLossLayer::Setup(const LayerProto& proto, int npartitions) {
@@ -768,61 +835,61 @@ void EuclideanLossLayer::ComputeFeature(int flag, Metric* perf) {
   for (int n = 0; n < batchsize_; n++) {
     for (int j = 0; j < dim_; ++j) {
       loss += (input_dptr[j] - reconstruct_dptr[j]) *
-                 (input_dptr[j] - reconstruct_dptr[j]);
+        (input_dptr[j] - reconstruct_dptr[j]);
     }
-    reconstruct_dptr+=dim_;
-    input_dptr+=dim_;
+    reconstruct_dptr += dim_;
+    input_dptr += dim_;
   }
   CHECK_EQ(reconstruct_dptr,
-            srclayers_[0]->data(this).cpu_data() + (batchsize_*dim_));
+      srclayers_[0]->data(this).cpu_data() + (batchsize_*dim_));
   CHECK_EQ(input_dptr,
       srclayers_[1]->data(this).cpu_data() + (batchsize_*dim_));
-  perf->Add("loss", loss/(1.0f*batchsize_));
+  perf->Add("loss", loss / batchsize_);
 }
-void EuclideanLossLayer::ComputeGradient(int flag) {
+void EuclideanLossLayer::ComputeGradient(int flag, Metric* perf) {
   const float* reconstruct_dptr = srclayers_[0]->data(this).cpu_data();
   const float* input_dptr = srclayers_[1]->data(this).cpu_data();
   Blob<float>* gsrcblob = srclayers_[0]->mutable_grad(this);
   float* gsrcptr = gsrcblob->mutable_cpu_data();
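+  // d/dx (x - t)^2 = 2 * (x - t); the division by batchsize_ below matches
+  // the mean loss reported in ComputeFeature.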
   for (int n = 0; n < batchsize_; n++) {
     for (int j = 0; j < dim_; j++)
-    gsrcptr[n*dim_+j]= 2 * (reconstruct_dptr[n*dim_+j]-input_dptr[n*dim_+j]);
+      gsrcptr[n*dim_+j] =
+          2 * (reconstruct_dptr[n*dim_+j] - input_dptr[n*dim_+j]);
   }
   Tensor<cpu, 1> gsrc(gsrcptr, Shape1(gsrcblob->count()));
-  gsrc*=1.0f/(1.0f*batchsize_);
+  gsrc /= batchsize_;
 }
+
 /********** Implementation for SoftmaxLossLayer *************************/
 void SoftmaxLossLayer::Setup(const LayerProto& proto, int npartitions) {
   LossLayer::Setup(proto, npartitions);
-  CHECK_EQ(srclayers_.size(),2);
+  CHECK_EQ(srclayers_.size(), 2);
   data_.Reshape(srclayers_[0]->data(this).shape());
-  batchsize_=data_.shape()[0];
-  dim_=data_.count()/batchsize_;
-  topk_=proto.softmaxloss_conf().topk();
+  batchsize_ = data_.shape()[0];
+  dim_ = data_.count() / batchsize_;
+  topk_ = proto.softmaxloss_conf().topk();
   metric_.Reshape(vector<int>{2});
-  scale_=proto.softmaxloss_conf().scale();
+  scale_ = proto.softmaxloss_conf().scale();
 }
 void SoftmaxLossLayer::ComputeFeature(int flag, Metric* perf) {
   Shape<2> s=Shape2(batchsize_, dim_);
   Tensor<cpu, 2> prob(data_.mutable_cpu_data(), s);
   Tensor<cpu, 2> src(srclayers_[0]->mutable_data(this)->mutable_cpu_data(), s);
   Softmax(prob, src);
-  const float* label=srclayers_[1]->data(this).cpu_data();
-  const float* probptr=prob.dptr;
-  float loss=0, precision=0;
-  for(int n=0;n<batchsize_;n++){
-    int ilabel=static_cast<int>(label[n]);
+  const float* label = srclayers_[1]->data(this).cpu_data();
+  const float* probptr = prob.dptr;
+  float loss = 0, precision = 0;
+  for (int n = 0; n < batchsize_; n++) {
+    int ilabel = static_cast<int>(label[n]);
     //  CHECK_LT(ilabel,10);
-    CHECK_GE(ilabel,0);
-    float prob_of_truth=probptr[ilabel];
-    loss-=log(std::max(prob_of_truth, FLT_MIN));
+    CHECK_GE(ilabel, 0);
+    float prob_of_truth = probptr[ilabel];
+    loss -= log(std::max(prob_of_truth, FLT_MIN));
     vector<std::pair<float, int> > probvec;
     for (int j = 0; j < dim_; ++j) {
       probvec.push_back(std::make_pair(probptr[j], j));
     }
-    std::partial_sort(
-        probvec.begin(), probvec.begin() + topk_,
-        probvec.end(), std::greater<std::pair<float, int> >());
+    std::partial_sort(probvec.begin(), probvec.begin() + topk_, probvec.end(),
+                      std::greater<std::pair<float, int> >());
     // check if true label is in top k predictions
     for (int k = 0; k < topk_; k++) {
       if (probvec[k].second == static_cast<int>(label[n])) {
@@ -830,23 +897,127 @@ void SoftmaxLossLayer::ComputeFeature(int flag, Metric* perf) {
         break;
       }
     }
-    probptr+=dim_;
+    probptr += dim_;
   }
   CHECK_EQ(probptr, prob.dptr+prob.shape.Size());
-  perf->Add("loss", loss*scale_/(1.0f*batchsize_));
-  perf->Add("accuracy", precision*scale_/(1.0f*batchsize_));
+  perf->Add("loss", loss * scale_ / (1.0f * batchsize_));
+  perf->Add("accuracy", precision * scale_ / (1.0f * batchsize_));
 }
 
-void SoftmaxLossLayer::ComputeGradient(int flag) {
-  const float* label=srclayers_[1]->data(this).cpu_data();
-  Blob<float>* gsrcblob=srclayers_[0]->mutable_grad(this);
+void SoftmaxLossLayer::ComputeGradient(int flag, Metric* perf) {
+  const float* label = srclayers_[1]->data(this).cpu_data();
+  Blob<float>* gsrcblob = srclayers_[0]->mutable_grad(this);
   gsrcblob->CopyFrom(data_);
-  float* gsrcptr=gsrcblob->mutable_cpu_data();
-  for(int n=0;n<batchsize_;n++){
-    gsrcptr[n*dim_+static_cast<int>(label[n])]-=1.0f;
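+  // Softmax + cross-entropy: the gradient w.r.t. the logits is
+  // prob - onehot(label), hence copying the probabilities and subtracting
+  // 1.0f at each true label.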
+  float* gsrcptr = gsrcblob->mutable_cpu_data();
+  for (int n = 0; n < batchsize_; n++) {
+    gsrcptr[n*dim_ + static_cast<int>(label[n])] -= 1.0f;
   }
   Tensor<cpu, 1> gsrc(gsrcptr, Shape1(gsrcblob->count()));
-  gsrc*=scale_/(1.0f*batchsize_);
+  gsrc *= scale_ / (1.0f * batchsize_);
+}
+
+/********* Implementation for BridgeDstLayer **************/
+void BridgeDstLayer::Setup(const LayerProto& proto, int npartitions) {
+  Layer::Setup(proto, npartitions);
+  CHECK_EQ(srclayers_.size(), 1);
+  data_.Reshape(srclayers_[0]->data(this).shape());
+  grad_.ReshapeLike(data_);
+}
+
+/************* Implementation for ConcateLayer ***********/
+void ConcateLayer::Setup(const LayerProto& proto, int npartitions) {
+  // CHECK_EQ(npartitions, 1);
+  Layer::Setup(proto, npartitions);
+  size_t concate_dim = proto.concate_conf().concate_dim();
+  CHECK_GE(concate_dim, 0);
+  CHECK_GT(srclayers_.size(), 1);
+  vector<int> shape = srclayers_[0]->data(this).shape();
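+  // The output shape matches the sources except along concate_dim, which
+  // accumulates, e.g. two 4x3 blobs concatenated on dim 1 give 4x6.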
+  for (size_t i = 1; i < srclayers_.size(); i++) {
+    const vector<int>& srcshape = srclayers_[i]->data(this).shape();
+    for (size_t j = 0; j < shape.size(); j++)
+      if (j == concate_dim)
+        shape[j] += srcshape[j];
+      else
+        CHECK_EQ(shape[j], srcshape[j]);
+  }
+  data_.Reshape(shape);
+  grad_.Reshape(shape);
+}
+
+void ConcateLayer::ComputeFeature(int flag, Metric *perf) {
+  LOG(FATAL) << "Not implemented for Concate Layer";
+}
+
+void ConcateLayer::ComputeGradient(int flag, Metric* perf) {
+  LOG(FATAL) << "Not implemented for Concate Layer";
+}
+
+/************* Implementation for SliceLayer****************/
+void SliceLayer::Setup(const LayerProto& proto, int npartitions) {
+  Layer::Setup(proto, npartitions);
+  slice_dim_ = proto.slice_conf().slice_dim();
+  slice_num_ = npartitions;
+  CHECK_GE(slice_dim_, 0);
+  CHECK_EQ(slice_num_, dstlayers_.size());
+  data_.Reshape(srclayers_[0]->data(this).shape());
+  grad_.ReshapeLike(data_);
+  datavec_.resize(slice_num_);
+  gradvec_.resize(slice_num_);
+  CHECK_EQ(data_.count() % slice_num_, 0);  // restrict equal slicing
+  // LOG(ERROR)<<"slice dim "<<slice_dim<<" slice num "<<slice_num;
+  for (int i = 0; i < slice_num_; i++) {
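+    // Even split along slice_dim_, with the last slice absorbing any
+    // remainder (e.g. a dimension of 10 over 3 slices gives 3, 3 and 4).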
+    vector<int> newshape(data_.shape());
+    newshape[slice_dim_] = newshape[slice_dim_] / slice_num_ +
+      ((i == slice_num_ - 1) ? newshape[slice_dim_] % slice_num_ : 0);
+    datavec_[i].Reshape(newshape);
+    gradvec_[i].Reshape(newshape);
+    // LOG(ERROR)<<"slice "<<IntVecToString(newshape);
+  }
+}
+
+void SliceLayer::ComputeFeature(int flag, Metric *perf) {
+  CHECK_EQ(srclayers_.size(), 1);
+  if (slice_dim_ == 0) {
+    const auto& blob = srclayers_.at(0)->data(this);
+    int size = blob.count() / slice_num_;
+    for (int i = 0; i < slice_num_; i++) {
+      float* dst = datavec_[i].mutable_cpu_data();
+      const float* src = blob.cpu_data() + i * size;
+      memcpy(dst, src, size*sizeof(float));
+    }
+  }
+}
+
+void SliceLayer::ComputeGradient(int flag, Metric* perf) {
+  // LOG(FATAL) << "Not implemented";
+}
+
+int SliceLayer::SliceID(const Layer* layer) const {
+  CHECK(layer != nullptr);
+  for (size_t i = 0; i < datavec_.size(); i++) {
+    // LOG(ERROR)<<"get slice "<<IntVecToString(shapes_[i]);
+    if (dstlayers_[i] == layer)
+      return i;
+  }
+  CHECK(false);
+  return -1;
+}
+
+/************* Implementation for SplitLayer****************/
+void SplitLayer::Setup(const LayerProto& proto, int npartitions) {
+  // CHECK_EQ(npartitions, 1);
+  Layer::Setup(proto, npartitions);
+  CHECK_EQ(srclayers_.size(), 1);
+  data_.Reshape(srclayers_[0]->data(this).shape());
+  grad_.Reshape(srclayers_[0]->data(this).shape());
+}
+
+void SplitLayer::ComputeFeature(int flag, Metric *perf) {
+  LOG(FATAL) << "Not implemented";
+}
+
+void SplitLayer::ComputeGradient(int flag, Metric* perf) {
+  LOG(FATAL) << "Not implemented";
 }
 
 }  // namespace singa