Posted to commits@singa.apache.org by wa...@apache.org on 2016/06/03 07:48:37 UTC

[32/60] incubator-singa git commit: SINGA-170 Add Dropout layer and CudnnDropout layer

SINGA-170 Add Dropout layer and CudnnDropout layer

Passes compilation; there is still a link error for cudnn.


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/99e0d24d
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/99e0d24d
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/99e0d24d

Branch: refs/heads/dev
Commit: 99e0d24d90fa1c588d73f87f402dfb0ac36ca8a7
Parents: 02851fa
Author: Wei Wang <wa...@comp.nus.edu.sg>
Authored: Mon May 16 21:40:24 2016 +0800
Committer: wangwei <wa...@gmail.com>
Committed: Tue May 17 00:40:24 2016 +0800

----------------------------------------------------------------------
 CMakeLists.txt                     |   7 +-
 include/singa/core/common.h        |  29 ++++-
 include/singa/core/device.h        |   4 +-
 include/singa/core/tensor.h        |  62 ++++++-----
 include/singa/model/layer.h        | 190 +++++++++++++++++++++++++-------
 include/singa/model/param.h        |  97 ----------------
 src/CMakeLists.txt                 |   7 +-
 src/core/device/device.cc          |   4 +-
 src/core/tensor/tensor.cc          | 107 ++++++++++--------
 src/core/tensor/tensor_math.h      |  11 +-
 src/core/tensor/tensor_math_cpp.h  |  29 +++++
 src/core/tensor/tensor_math_cuda.h |  24 ++--
 src/model/layer/conv.cc            |  27 -----
 src/model/layer/cudnn_dropout.cc   | 106 ++++++++++++++++++
 src/model/layer/cudnn_dropout.h    |  54 +++++++++
 src/model/layer/cudnn_utils.h      |  83 ++++++++++++++
 src/model/layer/dropout.cc         |  60 ++++++++++
 src/model/layer/dropout.h          |  49 ++++++++
 src/model/layer/layer.cc           |  30 -----
 src/proto/core.proto               |   3 +-
 src/proto/layer.proto              |  10 +-
 test/singa/test_dropout.cc         |  29 +++++
 test/singa/test_tensor.cc          |  10 +-
 23 files changed, 722 insertions(+), 310 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/99e0d24d/CMakeLists.txt
----------------------------------------------------------------------
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 67a82e5..dd92d03 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -1,6 +1,6 @@
 CMAKE_MINIMUM_REQUIRED(VERSION 2.6)
 PROJECT(singa)
-SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -std=c++11")
+SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -std=c++11 -DUSE_CUDA -DUSE_CUDNN")
 
 # Flags
 IF(UNIX OR APPLE)
@@ -10,12 +10,13 @@ ENDIF()
 # Includes
 SET(singa_include_dir ${PROJECT_SOURCE_DIR}/include)
 INCLUDE_DIRECTORIES(${singa_include_dir} ${PROJECT_BINARY_DIR})
+INCLUDE_DIRECTORIES("/home/wangwei/local/cudnn5/include" "/usr/local/cuda/include")
 
 
 SET(LIBRARY_OUTPUT_PATH ${PROJECT_BINARY_DIR}/lib)
 SET(EXECUTABLE_OUTPUT_PATH ${PROJECT_BINARY_DIR}/bin)
-SET(singa_linker_lib)
-LINK_DIRECTORIES(${LIBRARY_OUTPUT_PATH})
+SET(singa_linker_lib cudnn)
+LINK_DIRECTORIES(${LIBRARY_OUTPUT_PATH} "/home/wangwei/local/cudnn5/lib64/")
 
 INCLUDE(cmake/ProtoBuf.cmake)
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/99e0d24d/include/singa/core/common.h
----------------------------------------------------------------------
diff --git a/include/singa/core/common.h b/include/singa/core/common.h
index 1d73f67..4d783fb 100644
--- a/include/singa/core/common.h
+++ b/include/singa/core/common.h
@@ -18,9 +18,18 @@
 
 #ifndef SINGA_CORE_COMMON_H_
 #define SINGA_CORE_COMMON_H_
-
+#include <random>
+#include <chrono>
 #include "singa/utils/logging.h"
 
+#ifdef USE_CUDA
+#include <cuda_runtime.h>
+#include "cublas_v2.h"
+#ifdef USE_CUDNN
+#include <cudnn.h>
+#endif
+#endif
+
 namespace singa {
 namespace lib {
 /// To implemente functions using cpp libraries
@@ -37,10 +46,10 @@ typedef unsigned char Byte;
 /// Blob reprent a chunk of memory (on device or host) managed by VirtualMemory.
 class Blob {
  public:
-  Blob(void* ptr, int size) : data_(ptr), size_(size), ref_count_(1) {}
+  Blob(void* ptr, size_t size) : data_(ptr), size_(size), ref_count_(1) {}
   void* mutable_data() const { return data_; }
   const void* data() const { return data_; }
-  int size() const { return size_; }
+  size_t size() const { return size_; }
   int IncRefCount() {
     ref_count_++;
     return ref_count_;
@@ -54,11 +63,21 @@ class Blob {
 
  private:
   void* data_ = nullptr;
-  int size_ = 0;
+  size_t size_ = 0;
   int ref_count_ = 0;
 };
 
-class Context {};
+typedef struct _Context {
+  std::mt19937 random_generator;
+  unsigned long long seed;
+#ifdef USE_CUDA
+  cublasHandle_t cublas_handle;
+  cudaStream_t stream;
+#ifdef USE_CUDNN
+  cudnnHandle_t cudnn_handle;
+#endif
+#endif
+} Context;
 
 }  // namespace singa
 #endif  // SINGA_CORE_COMMON_H_
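
The new Context struct bundles the per-device resources an operation needs: a host random generator plus, under USE_CUDA/USE_CUDNN, the cublas/cudnn handles and the stream. As a rough sketch of how a CUDA device might populate it (the helper name, call ordering, and seeding policy are assumptions for illustration, not part of this commit):

    #include "singa/core/common.h"   // assumes USE_CUDA and USE_CUDNN are defined

    // Hypothetical helper: build a Context for one CUDA device/stream.
    singa::Context MakeCudaContext(unsigned long long seed) {
      singa::Context ctx;
      ctx.seed = seed;
      ctx.random_generator.seed(static_cast<std::mt19937::result_type>(seed));
      cudaStreamCreate(&ctx.stream);
      cublasCreate(&ctx.cublas_handle);
      cublasSetStream(ctx.cublas_handle, ctx.stream);
      cudnnCreate(&ctx.cudnn_handle);
      cudnnSetStream(ctx.cudnn_handle, ctx.stream);
      // The owning device would later destroy the handles and the stream.
      return ctx;
    }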

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/99e0d24d/include/singa/core/device.h
----------------------------------------------------------------------
diff --git a/include/singa/core/device.h b/include/singa/core/device.h
index fa30d6d..f3bb5a2 100644
--- a/include/singa/core/device.h
+++ b/include/singa/core/device.h
@@ -79,8 +79,8 @@ class Device {
   void CopyDataFromHostPtr(Blob* dst, const void* src, size_t size);
   /// Submit the operation to the device, which may execute it right now or
   /// delay it depending on the scheduler.
-  void Submit(function<void(Context*)> fn, const vector<Blob*> read_blobs,
-              const vector<Blob*> write_blobs);
+  void Exec(function<void(Context*)> fn, const vector<Blob*> read_blobs,
+              const vector<Blob*> write_blobs, bool use_rand_generator = false);
 
   // Wait for one event.
   // void WaitFor();
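
Besides the Submit-to-Exec rename, the new use_rand_generator flag tells the scheduler that an operation touches the host random generator. A hypothetical caller (illustrative only, not from this commit) would submit work like this:

    #include "singa/core/device.h"

    // Sketch: enqueue an operation that writes a single blob.
    void FillWithOnes(singa::Device* dev, singa::Blob* out, size_t count) {
      dev->Exec(
          [out, count](singa::Context*) {
            float* ptr = static_cast<float*>(out->mutable_data());
            for (size_t i = 0; i < count; i++) ptr[i] = 1.0f;
          },
          {}, {out}, /*use_rand_generator=*/false);
    }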

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/99e0d24d/include/singa/core/tensor.h
----------------------------------------------------------------------
diff --git a/include/singa/core/tensor.h b/include/singa/core/tensor.h
index 4278078..4807123 100644
--- a/include/singa/core/tensor.h
+++ b/include/singa/core/tensor.h
@@ -31,25 +31,23 @@ using std::vector;
 using std::tuple;
 namespace singa {
 
-typedef vector<int> Shape;
-inline int Product(Shape shape) {
-  if (shape.size() == 0)
-    return 0;
-  return Product(shape.begin(), shape.end());
-}
-
-inline int Product(vector<int>::iterator begin, vector<int>::iterator end) {
-  CHECK(begin != end);
-  int v = 1;
-  for (auto it = being; it < end; it++)
-    v* = *it;
+typedef vector<size_t> Shape;
+typedef Shape::iterator ShapeIter;
+inline size_t Product(const Shape& shape, int start = 0, size_t len = 0) {
+  if (len == 0)
+    len = shape.size();
+  CHECK_LE(len, shape.size());
+  size_t v = 1;
+  for (unsigned int i = start; i < len; i ++)
+    v *= shape[i];
   return v;
 }
 
 /// hardcode the width of types defined in DataType
-const int kDataWidth[] = {4, 2, 4, 1};
-inline int SizeOf(DataType t) {
-  static_assert(kNumDataType == sizeof(kDataWidth) / sizeof(int),
+const size_t kDataWidth[] = {sizeof(float), sizeof(float) / 2, sizeof(int),
+                          sizeof(char), sizeof(double)};
+inline size_t SizeOf(DataType t) {
+  static_assert(kNumDataType == sizeof(kDataWidth) / sizeof(size_t),
       "Num of data types not match num of data width");
   CHECK_GT(kNumDataType, t);
   return kDataWidth[t];
@@ -112,18 +110,23 @@ class Tensor {
   }
 
   /// Return number of total elements
-  int Size() const {
+  size_t Size() const {
     return blob_->size() / SizeOf(data_type_);
   }
 
   /// Return memory size (i.e., Bytes)
-  int MemSize() const {
+  size_t MemSize() const {
     return blob_->size();
   }
 
   /// Reset the tensor shape, it may reallocate blob, if MemSize() changes.
   void ReShape(const Shape& shape);
 
+  /// Reset the shape, device, and data type as given tensor.
+  /// If blob size changes, then reallocate a new blob. The previous blob would
+  /// be deleted.
+  void ResetLike(const Tensor& t);
+
   /// Reset the data type, it would reallocate blob if type changes.
   void AsType(DataType type);
 
@@ -136,7 +139,7 @@ class Tensor {
 
   /// For init the tensor values, copy 'num' elements.
   template<typename DType>
-  void CopyDataFromHostPtr(const DType* src, int num);
+  void CopyDataFromHostPtr(const DType* src, size_t num);
 
   /// Copy data from another Tensor which may be on a diff device.
   /// Meta data would not be copied!
@@ -207,17 +210,17 @@ class Tensor {
 /// The first 'src_offset' ('dst_offset') elements will be skipped.
 void CopyData(Tensor* dst,
               const Tensor& src,
-              int num,
-              int src_offset = 0,
-              int dst_offset = 0);
+              size_t num,
+              size_t src_offset = 0,
+              size_t dst_offset = 0);
 
 /// Copy 'nBytes' bytes of src data to dst.
 /// The first 'src_offset' ('dst_offset') bytes will be skipped.
 void CopyRawData(Tensor* dst,
               const Tensor& src,
-              int nBytes,
-              int src_offset = 0,
-              int dst_offset = 0);
+              size_t nBytes,
+              size_t src_offset = 0,
+              size_t dst_offset = 0);
 
 // ==================Simple Linear Algebra Operations=========================
 Tensor Abs(const Tensor& t);
@@ -306,15 +309,15 @@ void Mult(DType alpha, const Tensor& lhs, DType beta, const Tensor& rhs,
 // tempalte<typename DType> T Dot(const Tensor& lhs, const Tensor& rhs);
 
 //================Random operations==========================================
-/// For each element x set x = 0 if random() < p; otherwise x = 1.
-Tensor Bernoulli(float p, Blob* t);
+/// For each element x, set x = 1 if random() < p; otherwise x = 0.
+void Bernoulli(float p, Tensor* t);
 /// Fill in Tensor 't' following uniform distribution.
-Tensor Uniform(float low, DType high, Blob* t);
+void Uniform(float low, float high, Tensor* t);
 /// Fill in Tensor 't' following Gaussian distribution.
-Tensor Gaussian(float mean, DType std, Blob* t);
+void Gaussian(float mean, float std, Tensor* t);
 
 //================Neural Net operations======================================
-// following API of cudnn, e.g., conv, pool, lrn, batchnorm, softmax
+/* following API of cudnn, e.g., conv, pool, lrn, batchnorm, softmax
 void ConvFwd(const ConvConf& conf, const Tensor& x, const Tensor& w, Tensor* y);
 void ConvBwdBias(const ConvConf& conf, const Tensor& dy, Tensor* db);
 void ConvBwdFilter(const ConvConf& conf, const Tensor& dy, const Tensor& x,
@@ -325,6 +328,7 @@ void PoolFwd(const PoolConf& conf, const Tensor& x, Tensor* y,
              Tensor* mask = nullptr);
 void PoolBwd(const PoolConf& conf, const Tensor& y, const Tensor& dy,
              const Tensor& x, Tensor* dx);
+*/
 }  // namespace singa
 
 #endif  // SINGA_CORE_TENSOR_H_
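
A quick usage sketch of the size_t-based Shape helpers and the reworked fill routines (illustrative only; built against this header, not part of the commit):

    #include "singa/core/tensor.h"

    void TensorUsageSketch() {
      singa::Tensor t(singa::Shape{2, 3});     // kFloat32 on the default device
      size_t n = singa::Product(t.shape());    // 2 * 3 = 6 elements
      float data[6] = {0, 1, 2, 3, 4, 5};
      t.CopyDataFromHostPtr(data, n);          // copy n elements from the host
      singa::Uniform(0.0f, 1.0f, &t);          // refill with U(0, 1)
      singa::Tensor mask;
      mask.ResetLike(t);                       // same shape/device/type as t
      singa::Bernoulli(0.5f, &mask);           // each element becomes 1 with prob 0.5
    }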

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/99e0d24d/include/singa/model/layer.h
----------------------------------------------------------------------
diff --git a/include/singa/model/layer.h b/include/singa/model/layer.h
index 7b9b6d4..48fc58f 100644
--- a/include/singa/model/layer.h
+++ b/include/singa/model/layer.h
@@ -21,6 +21,7 @@
 
 #include <vector>
 #include <string>
+#include <stack>
 #include "singa/core/tensor.h"
 #include "singa/proto/layer.pb.h"
 
@@ -28,14 +29,10 @@ namespace singa {
 
 /// The base layer class.
 /// Generally, a layer conducts feature transformation against a set of Tensor
-/// to generate a set of Tensor. Each layer may have some parameters represented
-/// by Param instances.
+/// to generate a set of Tensor. Each layer may have some parameters.
 class Layer {
  public:
   Layer() = default;
-  /// Each layer sub-class would optionaly have a type name.
-  /// Used for debugging and logging.
-  virtual const std::string layer_type() const { return "Unknown"; }
 
   /// Set meta data fields from a string representing a proto message.
   void Setup(const string& proto_str) {
@@ -44,68 +41,183 @@ class Layer {
     this->Setup(conf);
   }
 
+  // ============= Following Functions could be override =====================
+  /// Destruct the objects created by this layer.
+  virtual ~Layer() {
+    for (Tensor * t : param_values_) {
+      delete t;
+    }
+  }
+
+  /// Each layer sub-class would optionally have a type name.
+  /// Used for debugging and logging.
+  virtual const std::string layer_type() const { return "Unknown"; }
+
   /// Set meta data fields configured in 'conf' (a proto message).
   virtual void Setup(const LayerConf& conf) {
     name_ = conf.name();
+    for (const auto& spec : conf.param())
+      param_specs_.push_back(spec);
+    // TODO(wangwei) load param values from checkpoint blobs.
   }
 
-  /// Do feature transformation for given 'input' Tensor.
-  /// It is the forward pass for feed-forward nets and rnn nets.
+  /// Do feature transformation for the given 'input' tensor (denoted as x).
   /// 'flag' is either kPhaseTrain or kPhaseTest for feed-forward nets, and
-  /// would be used for phases of training other nets.
-  /// It will return a set of Tensor.
-  virtual const vector<Tensor> ComputeFeature(int flag,
-                                              const vector<Tensor>& input) {
-    return vector<Tensor>{};
-  }
-  /// Compute gradients of parameters of this layer.
-  /// It would also compute the gradients for other layers, e.g., the
-  /// preceding layers in topology order. It would return an empty vector if
-  /// this layer does not need to compute gradients for other layers.
-  /// 'flag' is either kPhaseTrain or kPhaseTest for feed-forward nets, and
-  /// would be used for phases of training other nets.
-  /// 'input' is a vector of Tensor for gradients from other layers.
-  virtual const vector<Tensor> ComputeGradient(int flag,
-                                               const vector<Tensor>& input) {
-    return vector<Tensor>{};
+  /// would be used for other phases when training other nets. For example, when
+  /// training RBM, we may create an alias of this function as ComputeFeature
+  /// where flag could be kPositivePhase and kNegativePhase.
+  /// It will return a Tensor (denoted as y).
+  /// If the 'input' or 'output' is required for computing the gradients in
+  /// Backward(), then push them into the states_ stack.
+  virtual const Tensor Forward(int flag, const Tensor& input) {
+    LOG(FATAL) << "Not implemented";
+    Tensor t;
+    return t;
+  }
+
+  /// \copydoc Forward(int flag, const Tensor& input)
+  /// Accept multiple input tensors and generate multiple output tensors.
+  virtual const vector<Tensor> Forward(int flag, const vector<Tensor>& inputs) {
+    vector<Tensor> ret;
+    if (inputs.size() == 1)
+      ret.push_back(Forward(flag, inputs.at(0)));
+
+    LOG(FATAL) << "Not implemented";
+    return ret;
+  }
+
+  /// Compute gradients of this layer.
+  /// Specifically, there are two types of gradients:
+  /// 1. gradients of preceding layers, i.e., dx.
+  /// 2. gradients of parameters of this layer.
+  /// 1 and 2 are returned as a pair of vector<Tensor>
+  /// 1 is an empty tensor if there is no preceding layer or there is no need to
+  /// compute dx (e.g., x is from a data layer); 2 is empty if this layer has no
+  /// parameters.
+  /// 'flag' is either kTrainPhase or kTestPhase for feed-forward nets, and
+  /// would be used for other phases when training other nets.
+  /// 'grad' is a Tensor for gradient (dy) from the upper layer.
+  /// Some layer would use 'input' or 'output' from Forward to compute the
+  /// gradients of parameters. Backward() pops out the state data.
+  /// It is useful for RNN layers, where the same layer is used multiple
+  /// times just like unrolling the layer.
+  virtual const std::pair<Tensor, vector<Tensor>> Backward(int flag,
+                                                           const Tensor& grad) {
+    LOG(FATAL) << "Not implemented!";
+    Tensor t;
+    return std::make_pair(t, vector<Tensor>{});
+  }
+
+  /// \copydoc Backward(int, const vector<Tensor>&)
+  /// For Forward(int, const vector<Tensor>&)
+  virtual const std::pair<vector<Tensor>, vector<Tensor>> Backward(
+      int flag, const vector<Tensor>& grads) {
+    vector<Tensor> input_grad, param_grad;
+    if (grads.size() == 1u) {
+      auto ret = Backward(flag, grads.at(0));
+      input_grad.push_back(ret.first);
+      param_grad = ret.second;
+    } else  {
+      LOG(FATAL) << "Not implemented";
+    }
+    return std::make_pair(input_grad, param_grad);
   }
-  // return <dx>  <dw (ParamGrad)>
 
-  /// Move the layer (including its parameters and other Tensor) onto the given
-  /// device
+  /// Move the layer (including its parameters and other internal Tensor) onto
+  /// the given device
   virtual void ToDevice(Device* device) {
-    // for (auto p : params_)
-      // p->ToDevice(device);
+    for (auto p : param_values_) p->ToDevice(device);
   }
 
-  /// Set the data type of Tensor s in this layer.
+  /// Set the data type of Tensor in this layer.
   virtual void AsType(DataType dtype) {
-  //     for (auto p : params_)
-  //     p->AsType(dtype);
+    for (auto p : param_values_) p->AsType(dtype);
   }
 
-  /// Serialize the layer info, including params)_, into a LayerConf message.
-  virtual std::string ToProto(LayerConf* conf) const {
+  /// Serialize the layer info (including params) into a LayerConf proto message
+  virtual void ToProto(LayerConf* conf) const {
     conf->set_name(name_);
+    for (const auto& spec: param_specs_) {
+      ParamSpec* p = conf->add_param();
+      p->CopyFrom(spec);
+    }
+    // TODO(wangwei) add param values into conf;
   }
 
+  // ========================================================================
+
   /// Serialize the layer info, including params_, into a string representing
   /// a LayerParameter message.
-  std::string ToProtoStr() const;
+  std::string ToProtoStr() const {
+    LayerConf conf;
+    ToProto(&conf);
+    string str;
+    conf.SerializeToString(&str);
+    return str;
+  }
+  /// Return specs/configuration of all parameter instances of this layer.
+  /// \ref ParamSpec.
+  const vector<ParamSpec> param_specs() {
+    return param_specs_;
+  }
 
-  /// Return all Param instances of this layer.
-  /// Each layer could cache the Param objects.
-  /// To save memory of , it can also create it when this function
-  /// is called
-  const vector<Param*> GetParam();
+  /// Return the i-th ParamSpec.
+  const ParamSpec& param_specs(int i) {
+    return param_specs_.at(i);
+  }
+
+  /// Return pointers to parameter Tensor s.
+  const vector<Tensor*> param_values() {
+    return param_values_;
+  }
+
+  /// Return a pointer to the 'i'-th parameter Tensor.
+  Tensor* param_value(size_t i) {
+    CHECK_LT(i, param_values_.size());
+    return param_values_[i];
+  }
+
+  /// Return names of all parameters.
+  const vector<string> param_names() {
+    vector<string> pname;
+    for (const auto& spec: param_specs_)
+      pname.push_back(spec.name());
+    return pname;
+  }
+
+  /// Return the 'i'-th parameter name.
+  const string& param_name(size_t i) {
+    CHECK_LT(i, param_specs_.size());
+    return param_specs_.at(i).name();
+  }
 
   /// Each layer instance would optionally have a name.
   /// Used for debugging and logging.
   const std::string name() const { return name_; }
 
+  /*
+  std::stack<Tensor> states() const {
+    return states_;
+  }
+  */
+
  protected:
   std::string name_;
+  vector<Tensor*> param_values_;
+  vector<ParamSpec> param_specs_;
+  /// Used to store input or output of Forward(), which would be used in
+  /// Backward.  Rules:
+  /// 1. push the 'input' or 'output' into states_ if the flag of Forward() is
+  ///    for training.
+  /// 2. pop data out in Backward().
+  /// TODO(wangwei) enable this feature for rnn layers.
+  // std::stack<Tensor*> states_;
 };
 
+// ===========================================================================
+// Order layer sub-classes based on alphabetical order of the first letter.
+// ===========================================================================
+
+
 }  // namespace singa
 #endif  // SINGA_LAYER_H_
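
To make the reworked contract concrete, here is a minimal hypothetical subclass (a sketch only; "Square" is not a layer in this commit). It also illustrates the "keep the input around for Backward" pattern that the states_ comment above describes, using a plain member instead of the still-disabled stack:

    #include "singa/model/layer.h"

    namespace singa {
    // Hypothetical layer computing y = x * x element-wise.
    class Square : public Layer {
     public:
      const std::string layer_type() const override { return "Square"; }

      const Tensor Forward(int flag, const Tensor& input) override {
        if (flag & kTrain) input_ = input;   // x is needed again in Backward()
        return input * input;                // element-wise product
      }

      const std::pair<Tensor, vector<Tensor>> Backward(int flag,
                                                       const Tensor& grad) override {
        Tensor dx = grad * input_;           // dL/dx = dL/dy * x ...
        dx *= 2.0f;                          // ... * 2
        return std::make_pair(dx, vector<Tensor>{});  // no parameter gradients
      }

     protected:
      Tensor input_;
    };
    }  // namespace singa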

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/99e0d24d/include/singa/model/param.h
----------------------------------------------------------------------
diff --git a/include/singa/model/param.h b/include/singa/model/param.h
deleted file mode 100644
index b859b1c..0000000
--- a/include/singa/model/param.h
+++ /dev/null
@@ -1,97 +0,0 @@
-/************************************************************
-*
-* Licensed to the Apache Software Foundation (ASF) under one
-* or more contributor license agreements.  See the NOTICE file
-* distributed with this work for additional information
-* regarding copyright ownership.  The ASF licenses this file
-* to you under the Apache License, Version 2.0 (the
-* "License"); you may not use this file except in compliance
-* with the License.  You may obtain a copy of the License at
-*
-*   http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing,
-* software distributed under the License is distributed on an
-* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-* KIND, either express or implied.  See the License for the
-* specific language governing permissions and limitations
-* under the License.
-*
-*************************************************************/
-
-#ifndef SINGA_MODEL_PARAM_H_
-#define SINGA_MODEL_PARAM_H_
-#include "singa/core/tensor.h"
-#include <vector>
-#include <string>
-using std::vector;
-using std::string;
-namespace singa {
-/// Base Param class for storing set of parameters, e.g., a weight matrix or a
-/// bias vector.
-/// It includes multiple Tensor s for parameter values, gradients, etc.
-class Param {
- public:
-  ~Param();
-  Param(const ParamSpec& conf);
-  Param(Param&& p);
-  Param(const Param& p);
-  void operator=(Param&& p);
-  void operator=(const Param& p);
-
-  Tensor& value() {
-    return value_;
-  }
-
-  Tensor& grad() {
-    return grad_;
-  }
-
-  void set_value(const Tensor& t) {
-    value_ = t;
-  }
-
-  void set_value(Tensor&& t) {
-    value_ = std::move(t);
-  }
-
-  void set_grad(const Tensor& t) {
-    isGradValid_ = true;
-    grad_ = t;
-  }
-
-  void set_grad(Tensor&& t) {
-    grad_ = std::move(t);
-  }
-
-  // void Compress();
-  // string ToString();
-
- protected:
-  string name_;
-  Tensor value_;
-  float lr_mult_ = 1.0f, decay_mult_ = 1.0f;
-};
-
-class ParamGrad {
-// return grad tensor or data to recover the grad tensor, e.g., if W = U * V
-// then, ParamGrad could just store U and V. provide func for serailize and
-// deserialize.
-};
-
-// updater just copy the ParamGrad to a device and submit ops to that device, e.g.,
-// add grad; check update_condidtion; apply sgd; copy back.
-// consider rpc (no rmda).
-
-Param* CreateParam(string type) {
-  Param* p = nullptr;
-  if (type == "default")
-    p = new Param();
-  else
-    LOG(FATAL) << "Currently param type " << type << " is not implemented."
-               << "Pls use the 'default' type";
-  return p;
-}
-#endif  // SINGA_MODEL_PARAM_H_
-
-}  // namespace singa

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/99e0d24d/src/CMakeLists.txt
----------------------------------------------------------------------
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index d8bec8d..e2e923e 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -15,7 +15,12 @@ FILE(GLOB_RECURSE core_source ${CMAKE_CURRENT_SOURCE_DIR}/core/ "*.cc")
 ADD_LIBRARY(singa_core SHARED ${core_source})
 TARGET_LINK_LIBRARIES(singa_core ${singa_linker_libs})
 list(APPEND singa_linker_libs singa_core)
-MESSAGE(STATUS "link libs " ${singa_linker_libs})
+#MESSAGE(STATUS "link libs " ${singa_linker_libs})
+
+FILE(GLOB_RECURSE model_source ${CMAKE_CURRENT_SOURCE_DIR}/model/ "*.cc")
+ADD_LIBRARY(singa_model SHARED ${model_source})
+TARGET_LINK_LIBRARIES(singa_model ${singa_linker_libs})
+list(APPEND singa_linker_libs singa_model)
 
 #ADD_LIBRARY(singa_layer SHARED ${LAYER_SOURCE})
 #ADD_LIBRARY(singa_model SHARED ${MODEL_SOURCE})

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/99e0d24d/src/core/device/device.cc
----------------------------------------------------------------------
diff --git a/src/core/device/device.cc b/src/core/device/device.cc
index 4976a32..b2a8705 100644
--- a/src/core/device/device.cc
+++ b/src/core/device/device.cc
@@ -25,8 +25,8 @@ Device::Device(int id, int num_executors, string scheduler, string vm)
   vm_ = nullptr;
 }
 
-void Device::Submit(function<void(Context*)> fn, const vector<Blob*> read_blobs,
-                    const vector<Blob*> write_blobs) {
+void Device::Exec(function<void(Context*)> fn, const vector<Blob*> read_blobs,
+                    const vector<Blob*> write_blobs, bool use_rand_generator) {
   fn(nullptr);
 }
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/99e0d24d/src/core/tensor/tensor.cc
----------------------------------------------------------------------
diff --git a/src/core/tensor/tensor.cc b/src/core/tensor/tensor.cc
index 51b785e..8352b48 100644
--- a/src/core/tensor/tensor.cc
+++ b/src/core/tensor/tensor.cc
@@ -20,6 +20,7 @@
 #include "./tensor_math_cpp.h"
 #include "./tensor_math_cuda.h"
 #include "./tensor_math_opencl.h"
+#include <utility>
 
 namespace singa {
 
@@ -69,6 +70,16 @@ Tensor::Tensor(Tensor&& t)
   t.blob_ = nullptr;
 }
 
+void Tensor::ResetLike(const Tensor& t) {
+  if (blob_->size() != t.MemSize()) {
+    if (blob_ != nullptr && blob_->DecRefCount() == 0) device_->FreeBlob(blob_);
+    shape_ = t.shape_;
+    device_ = t.device_;
+    data_type_ = t.data_type_;
+    blob_ = device_->NewBlob(Product(shape_) * SizeOf(data_type_));
+  }
+}
+
 void Tensor::ReShape(const Shape& shape) {
   if (shape_ != shape) {
     if (blob_ != nullptr && blob_->DecRefCount() == 0)
@@ -105,7 +116,7 @@ void Tensor::ToHost() {
 }
 
 template<typename DType>
-void Tensor::CopyDataFromHostPtr(const DType* src, int num) {
+void Tensor::CopyDataFromHostPtr(const DType* src, size_t num) {
   CHECK_EQ(sizeof(DType), SizeOf(data_type_)) << "data_type is "
                                               << DataType_Name(data_type_)
                                               << " user given type is of size "
@@ -115,7 +126,7 @@ void Tensor::CopyDataFromHostPtr(const DType* src, int num) {
   else
     LOG(WARNING) << "Copy data from null host ptr";
 }
-template void Tensor::CopyDataFromHostPtr(const float* src, int num);
+template void Tensor::CopyDataFromHostPtr(const float* src, size_t num);
 
 void Tensor::CopyData(const Tensor& src) {
   CHECK_EQ(Size(), src.Size());
@@ -134,10 +145,10 @@ Tensor Tensor::Clone() {
 }
 
 Tensor Tensor::T() const {
-  CHECK_EQ(shape_.size(), 2);
+  CHECK_EQ(shape_.size(), 2u);
   Tensor t(*this);
   t.transpose_ = ~transpose_;
-  std::swap(shape_[0], shape_[1]);
+  std::swap(t.shape_[0], t.shape_[1]);
   return t;
 }
 
@@ -185,21 +196,21 @@ GenUnaryScalarArgMemberFunction(operator/=, Div);
 // ====================Tensor Operations=======================================
 void CopyData(Tensor* dst,
               const Tensor& src,
-              int num,
-              int dst_offset,
-              int src_offset) {
+              size_t num,
+              size_t dst_offset,
+              size_t src_offset) {
   CHECK_GE(src.Size(), src_offset + num);
   CHECK_GE(dst->Size(), dst_offset + num);
-  int width = SizeOf(src.data_type());
+  auto width = SizeOf(src.data_type());
   CHECK_EQ(width, SizeOf(dst->data_type()));
   CopyRawData(dst, src, num * width, dst_offset * width, src_offset * width);
 }
 
 void CopyRawData(Tensor* dst,
               const Tensor& src,
-              int nBytes,
-              int dst_offset,
-              int src_offset) {
+              size_t nBytes,
+              size_t dst_offset,
+              size_t src_offset) {
   CHECK_GE(src.MemSize(), src_offset + nBytes);
   CHECK_GE(dst->MemSize(), dst_offset + nBytes);
   Device* src_dev = src.device(), *dst_dev = dst->device();
@@ -286,7 +297,7 @@ void CopyRawData(Tensor* dst,
 #define EltwiseUnaryTensorFn(fn, t, ret)                                   \
   do {                                                                     \
     TYPE_LIB_SWITCH(t.data_type(), DType, t.device()->device_lib(), Lib, { \
-      ret->device()->Submit(                                               \
+      ret->device()->Exec(                                               \
           [t, ret](Context* ctx) {                                         \
             fn<DType, Lib>(t.Size(), t.blob(), ret->blob(), ctx);          \
           },                                                               \
@@ -320,14 +331,14 @@ Tensor Softmax(const Tensor& t, int axis) {
 void Softmax(const Tensor& t, Tensor* ret, int axis) {
   int nrow = 1, ncol = t.Size(), size = ncol;
   CHECK_GE(axis, -1);
-  CHECK_GT(t.shape().size(), 0);
+  CHECK_GT(t.shape().size(), 0u);
   if (axis > -1) {
-    nrow = Product(t.shape().begin(), t.shape().begin() + axis + 1);
+    nrow = Product(t.shape(), 0, axis + 1);
     CHECK_EQ(size % nrow, 0) << "Size = " << size << " nrow = " << nrow;
     ncol = size / nrow;
   }
   TYPE_LIB_SWITCH(t.data_type(), DType, t.device()->device_lib(), Lib, {
-    ret->device()->Submit(
+    ret->device()->Exec(
         [nrow, ncol, t, ret](Context* ctx) {
           Softmax<DType, Lib>(nrow, ncol, t.blob(), ret->blob(), ctx);
         },
@@ -338,8 +349,8 @@ void Softmax(const Tensor& t, Tensor* ret, int axis) {
 #define EltwiseBinaryTensorFn(fn, lhs, rhs, ret)                               \
   do {                                                                         \
     TYPE_LIB_SWITCH(lhs.data_type(), DType, lhs.device()->device_lib(), Lib, { \
-      ret->device()->Submit(                                                   \
-          CHECK_EQ(sizeof(DType), SizeOf(rhs.data_type()));                    \
+      CHECK_EQ(sizeof(DType), SizeOf(rhs.data_type()));                        \
+      ret->device()->Exec(                                                     \
           [lhs, rhs, ret](Context* ctx) {                                      \
             fn<DType, Lib>(lhs.Size(), lhs.blob(), rhs.blob(), ret->blob(),    \
                            ctx);                                               \
@@ -364,28 +375,28 @@ GenBinaryTensorFunction(operator*, EltwiseMult);
 GenBinaryTensorFunction(operator/, Div);
 GenBinaryTensorFunction(Pow, Pow);
 
-#define EltwiseTensorScalarFn(fn, t, x, ret)                                \
-  do {                                                                      \
-    TYPE_LIB_SWITCH(t.data_type(), DType, t.device()->device_lib(), Lib, {  \
-      ret->device()->Submit(                                                \
-          static_assert(typeid(x) == typeid(DType),                         \
-                        "The Scalar type must match the Tensor data type"); \
-          [t, x, ret](Context* ctx) {                                       \
-            fn<DType, Lib>(t.Size(), t.blob(), x, ret->blob(), ctx);        \
-          },                                                                \
-          {t.blob()}, {ret->blob()});                                       \
-    });                                                                     \
+#define EltwiseTensorScalarFn(fn, t, x, ret)                               \
+  do {                                                                     \
+    TYPE_LIB_SWITCH(t.data_type(), DType, t.device()->device_lib(), Lib, { \
+      static_assert(std::is_same<SType, DType>::value,                             \
+                    "The Scalar type must match the Tensor data type");    \
+      ret->device()->Exec(                                                 \
+          [t, x, ret](Context* ctx) {                                      \
+            fn<DType, Lib>(t.Size(), t.blob(), x, ret->blob(), ctx);       \
+          },                                                               \
+          {t.blob()}, {ret->blob()});                                      \
+    });                                                                    \
   } while (0)
 
 #define GenTensorScalarFunction(op, fn)                \
-  template <typename DType>                                \
-  Tensor op(const Tensor& t, DType x) {                    \
+  template <typename SType>                            \
+  Tensor op(const Tensor& t, SType x) {                \
     Tensor ret(t.shape(), t.device(), t.data_type());  \
     fn(t, x, &ret);                                    \
     return ret;                                        \
   }                                                    \
-  template <typename DType>                                \
-  void fn(const Tensor& t, DType x, Tensor* ret) {   \
+  template <typename SType>                            \
+  void fn(const Tensor& t, SType x, Tensor* ret) {     \
     EltwiseTensorScalarFn(fn, t, x, ret);              \
   }                                                    \
   template Tensor op<float>(const Tensor& t, float x); \
@@ -424,15 +435,15 @@ template Tensor Mult<float>(float alpha, const Tensor& lhs, float beta,
 template <typename SType>
 void Mult(SType alpha, const Tensor& A, SType beta, const Tensor& B, Tensor* C)
 {
-  CHECK_EQ(A.shape().size(), 2);
+  CHECK_EQ(A.shape().size(), 2u);
   bool transA = A.transpose();
-  int m = transA ? A.shape()[1] : A.shape()[0], n = 0;
-  if (B.shape().size() == 1) {
+  size_t m = transA ? A.shape()[1] : A.shape()[0], n = 0;
+  if (B.shape().size() == 1u) {
     n = C->Size();
     TYPE_LIB_SWITCH(A.data_type(), DType, A.device()->device_lib(), Lib, {
       static_assert(std::is_same<SType, DType>::value,
         "The scalar type must be the same as the tensor data type");
-      C->device()->Submit(
+      C->device()->Exec(
         [transA, m, n, alpha, A, beta, B, C](Context* ctx) {
         GEMV<DType, Lib>(transA, m, n, alpha, A.blob(),
           B.blob(), beta, C->blob(), ctx);
@@ -442,7 +453,7 @@ void Mult(SType alpha, const Tensor& A, SType beta, const Tensor& B, Tensor* C)
   } else {
     CHECK(!C->transpose());
     bool transB = B.transpose();
-    int k = transB ? B.shape()[1] : B.shape()[0];
+    size_t k = transB ? B.shape()[1] : B.shape()[0];
     n = C->shape()[1];
     CHECK_EQ(C->shape()[0], m);
     CHECK_EQ(A.Size(), m * k);
@@ -450,7 +461,7 @@ void Mult(SType alpha, const Tensor& A, SType beta, const Tensor& B, Tensor* C)
     TYPE_LIB_SWITCH(A.data_type(), DType, A.device()->device_lib(), Lib, {
         static_assert(std::is_same<SType, DType>::value,
           "The scalar type must be the same as the tensor data type");
-        C->device()->Submit(
+        C->device()->Exec(
           [transA, transB, m, n, k, alpha, A, beta, B, C](Context* ctx) {
           GEMM<DType, Lib>(transA, transB, m, n, k, alpha, A.blob(),
             B.blob(), beta, C->blob(), ctx);
@@ -468,7 +479,7 @@ template void Mult<float>(float alpha, const Tensor& lhs, float beta,
 void Conv(const OpConf* conf, const Tensor& input, const Tensor& W,
           const Tensor& b, Tensor* ret) {
   TYPE_LIB_SWITCH(input.data_type(), DType, input.device()->nn_lib(), Lib, {
-    ret->device()->Submit(
+    ret->device()->Exec(
         [conf, input, W, b, ret](Context* ctx) {
           Conv<DType, Lib>(conf, input.blob(), W.blob(), b.blob(), ret->blob(),
                            ctx);
@@ -477,33 +488,33 @@ void Conv(const OpConf* conf, const Tensor& input, const Tensor& W,
   });
 }
 */
-void Bernoulli(float threshold, Tensor* t) {
+void Bernoulli(float p, Tensor* t) {
   TYPE_LIB_SWITCH(t->data_type(), DType, t->device()->nn_lib(), Lib, {
-    t->device()->Submit(
-        [threshold, t](Context* ctx) {
-          Bernoulli<DType, Lib>(t->Size(), threshold, t->blob(), ctx);
+    t->device()->Exec(
+        [p, t](Context* ctx) {
+          Bernoulli<DType, Lib>(t->Size(), p, t->blob(), ctx);
         },
-        {}, {t->blob()});
+        {}, {t->blob()}, true);
   });
 }
 
 void Uniform(float low, float high, Tensor* t) {
   TYPE_LIB_SWITCH(t->data_type(), DType, t->device()->nn_lib(), Lib, {
-    t->device()->Submit(
+    t->device()->Exec(
         [low, high, t](Context* ctx) {
           Uniform<DType, Lib>(t->Size(), low, high, t->blob(), ctx);
         },
-        {}, {t->blob()});
+        {}, {t->blob()}, true);
   });
 }
 
 void Gaussian(float mean, float std, Tensor* t) {
   TYPE_LIB_SWITCH(t->data_type(), DType, t->device()->nn_lib(), Lib, {
-    t->device()->Submit(
+    t->device()->Exec(
         [mean, std, t](Context* ctx) {
           Gaussian<DType, Lib>(t->Size(), mean, std, t->blob(), ctx);
         },
-        {}, {t->blob()});
+        {}, {t->blob()}, true);
   });
 }
 }  // namespace singa

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/99e0d24d/src/core/tensor/tensor_math.h
----------------------------------------------------------------------
diff --git a/src/core/tensor/tensor_math.h b/src/core/tensor/tensor_math.h
index a4f68e3..aa520c9 100644
--- a/src/core/tensor/tensor_math.h
+++ b/src/core/tensor/tensor_math.h
@@ -96,6 +96,12 @@ void Sigmoid(int count, const Blob* input, Blob* ret, Context* ctx) {
   LOG(FATAL) << "Not Implemented";
 }
 
+/// Do softmax for each row individually
+template <typename DType, typename Lib>
+void Softmax(int nrow, int ncol, const Blob* input, Blob* ret, Context* ctx) {
+  LOG(FATAL) << "Not Implemented";
+}
+
 /// Element-wise operation, do v^x for every v from the input tensor
 template <typename DType, typename Lib>
 void Pow(int count, const Blob* input, DType x, Blob* ret, Context* ctx) {
@@ -258,7 +264,7 @@ void GEMM(bool transA, bool transB, int m, int n, int k, DType alpha,
 // Get the random generator from 'ctx'
 // If DType is not float, then convert the threshold to DType
 template <typename DType, typename Lib>
-void Bernoulli(int count, float threshold, Blob* ret, Context* ctx) {
+void Bernoulli(int count, float p, Blob* ret, Context* ctx) {
   LOG(FATAL) << "Not Implemented";
 }
 // The random generator should be extracted from ctx.
@@ -274,7 +280,7 @@ void Gaussian(int count, float mean, float std, Blob* ret, Context* ctx) {
   LOG(FATAL) << "Not Implemented";
 }
 
-// ================Neural net functions=======================================
+/* ================Neural net functions=======================================
 template <typename DType, typename Lib>
 void ConvFwd(ConvConf* conf, const Blob* x, const Blob* w, Blob* y,
              Context* ctx) {
@@ -296,6 +302,7 @@ void PoolBwd(const PoolConf* conf, const Blob* y, const Blob* dy, const Blob* x,
              Blob* dx, Context* ctx) {
   LOG(FATAL) << "Not Implemented";
 }
+*/
 
 }  // namespace singa
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/99e0d24d/src/core/tensor/tensor_math_cpp.h
----------------------------------------------------------------------
diff --git a/src/core/tensor/tensor_math_cpp.h b/src/core/tensor/tensor_math_cpp.h
index a953085..9e7ed30 100644
--- a/src/core/tensor/tensor_math_cpp.h
+++ b/src/core/tensor/tensor_math_cpp.h
@@ -40,6 +40,35 @@ void Add<float, lib::Cpp>(int count,
   }
 }
 
+template <>
+void Bernoulli<float, lib::Cpp>(int count, float p, Blob* ret,
+                                 Context* ctx) {
+  std::bernoulli_distribution distribution(p);
+  float* ptr = static_cast<float*>(ret->mutable_data());
+  for (int i = 0; i < count; i ++) {
+    ptr[i] = static_cast<float>(distribution(ctx->random_generator));
+  }
+}
+
+template <>
+void Uniform<float, lib::Cpp>(int count, float low, float high, Blob* ret,
+                               Context* ctx) {
+  std::uniform_real_distribution<float> distribution(low, high);
+  float* ptr = static_cast<float*>(ret->mutable_data());
+  for (int i = 0; i < count; i ++) {
+    ptr[i] = static_cast<float>(distribution(ctx->random_generator));
+  }
+}
+
+template <>
+void Gaussian<float, lib::Cpp>(int count, float mean, float std, Blob* ret,
+                              Context* ctx) {
+  std::normal_distribution<float> distribution(mean, std);
+  float* ptr = static_cast<float*>(ret->mutable_data());
+  for (int i = 0; i < count; i++) {
+    ptr[i] = static_cast<float>(distribution(ctx->random_generator));
+  }
+}
 #ifdef USE_CBLAS
 template<>
 void Dot<float, lib::Cpp>(int count,

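The three specializations above draw from the std::mt19937 carried in Context, so a fixed seed gives reproducible draws. A small sketch of calling them directly (the buffer, seed, and Blob wrapping are illustrative assumptions; in normal use Blobs come from Device::NewBlob and ops go through Device::Exec):

    #include "singa/core/common.h"
    // also assumes the specializations above are visible, e.g. via tensor_math_cpp.h

    void CpuRandomSketch() {
      float buf[8] = {};
      singa::Blob blob(buf, sizeof(buf));   // wrap a raw host buffer for the call
      singa::Context ctx;
      ctx.random_generator.seed(1234);      // fixed seed -> reproducible sequence
      singa::Bernoulli<float, singa::lib::Cpp>(8, 0.5f, &blob, &ctx);      // 0/1 values
      singa::Gaussian<float, singa::lib::Cpp>(8, 0.0f, 1.0f, &blob, &ctx); // N(0, 1)
    }
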
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/99e0d24d/src/core/tensor/tensor_math_cuda.h
----------------------------------------------------------------------
diff --git a/src/core/tensor/tensor_math_cuda.h b/src/core/tensor/tensor_math_cuda.h
index e1c72d8..c5ea3c4 100644
--- a/src/core/tensor/tensor_math_cuda.h
+++ b/src/core/tensor/tensor_math_cuda.h
@@ -28,24 +28,16 @@ namespace singa {
 template<>
 void Add<float, lib::Cuda>(int count, const Blob* lhs, const Blob* rhs,
                         Blob* ret, Context* ctx) {
-  cublasSetStream(ctx->handle, ctx->stream);
-  cublasScopy(ctx->handle, count, lhs->data(), 1, ret->mutable_data(), 1);
-  cublasSaxpy(ctx->handle, 1.0f, rhs->data(), 1, ret->mutable_data(), 1);
+  /*
+  cublasSetStream(ctx->cublas_handle, ctx->stream);
+  const float* lptr = static_cast<const float*>(lhs->data());
+  const float* rptr = static_cast<const float*>(rhs->data());
+  float* ptr = static_cast<float*>(ret->mutable_data());
+  cublasScopy(ctx->cublas_handle, count, lptr, 1, ptr, 1);
+  cublasSaxpy(ctx->cublas_handle, 1.0f, rptr, 1, ptr, 1);
+  */
 }
 
-#ifdef USE_CUDNN
-template<>
-void Conv<float, lib::Cudnn>(const OpConf *conf,
-          const Blob* input,
-          const Blob* W,
-          const Blob* b,
-          Blob* ret,
-          Context* ctx) {
-  // auto conv_conf = conf->CastTo<ConvConf>();
-  // conv op
-}
-
-#endif
 #endif
 }
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/99e0d24d/src/model/layer/conv.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/conv.cc b/src/model/layer/conv.cc
deleted file mode 100644
index d1a7d2c..0000000
--- a/src/model/layer/conv.cc
+++ /dev/null
@@ -1,27 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-
-namespace singa {
-
-
-
-
-
-
-}  /* singa */

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/99e0d24d/src/model/layer/cudnn_dropout.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/cudnn_dropout.cc b/src/model/layer/cudnn_dropout.cc
new file mode 100644
index 0000000..926ccb9
--- /dev/null
+++ b/src/model/layer/cudnn_dropout.cc
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifdef USE_CUDNN
+// cudnn dropout is added in cudnn 5
+//#if CUDNN_MAJOR_VERSION >= 5
+#include "./cudnn_utils.h"
+#include "./cudnn_dropout.h"
+#include "singa/utils/logging.h"
+namespace singa {
+CudnnDropout::~CudnnDropout() {
+  if (drop_desc_ != nullptr)
+    CUDNN_CHECK(cudnnDestroyDropoutDescriptor(drop_desc_));
+  if (x_desc_ != nullptr)
+    CUDNN_CHECK(cudnnDestroyTensorDescriptor(x_desc_));
+  if (y_desc_ != nullptr)
+    CUDNN_CHECK(cudnnDestroyTensorDescriptor(y_desc_));
+}
+
+void CudnnDropout::InitCudnn(int size, DataType dtype, Context* ctx) {
+  CHECK(!has_init_cudnn_);
+  CUDNN_CHECK(cudnnCreateTensorDescriptor(&x_desc_));
+  CUDNN_CHECK(cudnnCreateTensorDescriptor(&y_desc_));
+  CUDNN_CHECK(cudnnCreateDropoutDescriptor(&drop_desc_));
+
+  int dim[] = {size};
+  int stride[] = {1};
+  CUDNN_CHECK(cudnnSetTensorNdDescriptor(x_desc_, GetCudnnDataType(dtype), 1,
+      dim, stride));
+  CUDNN_CHECK(cudnnSetTensorNdDescriptor(y_desc_, GetCudnnDataType(dtype), 1,
+      dim, stride));
+
+  cudnnDropoutGetStatesSize(ctx->cudnn_handle, &state_size_);
+  cudnnDropoutGetReserveSpaceSize(x_desc_, &reserve_size_);
+  cudnnSetDropoutDescriptor(drop_desc_, ctx->cudnn_handle, dropout_ratio_,
+    state_.blob()->mutable_data(),
+    state_size_, ctx->seed);
+  has_init_cudnn_ = true;
+}
+
+const Tensor CudnnDropout::Forward(int flag, const Tensor& input) {
+  if (flag & kTrain) {
+    auto size = input.Size();
+    DataType dtype = input.data_type();
+    if (!has_init_cudnn_) {
+      input.device()->Exec(
+          [size, dtype, this](Context* ctx) {
+          this->InitCudnn(size, dtype, ctx);
+          },
+          {}, {state_.blob()});
+      mask_.ResetLike(input);
+      CHECK_EQ(reserve_size_, mask_.MemSize());
+    }
+    Tensor out;
+    out.ResetLike(input);
+    Blob *inblob = input.blob(), *outblob = out.blob(), *mblob = mask_.blob();
+    out.device()->Exec(
+        [inblob, outblob, mblob, this](Context* ctx) {
+        cudnnDropoutForward(
+            ctx->cudnn_handle, this->drop_desc_, this->x_desc_, inblob->data(),
+            this->y_desc_, outblob->mutable_data(), mblob, this->reserve_size_);
+        },
+        {inblob}, {mblob, outblob});
+    return out;
+  } else {
+    return input;
+  }
+}
+
+const std::pair<Tensor, vector<Tensor>> CudnnDropout::Backward(
+    int flag, const Tensor& grad) {
+  vector<Tensor> param_grad;
+  Tensor dx;
+  if (flag & kTrain) {
+    dx.ResetLike(grad);
+    Blob *dyblob = grad.blob(), *dxblob = dx.blob(), *mblob = mask_.blob();
+    dx.device()->Exec(
+        [dyblob, dxblob, mblob, this](Context* ctx) {
+        cudnnDropoutBackward(ctx->cudnn_handle, this->drop_desc_,
+            this->y_desc_, dyblob->data(), this->x_desc_,
+            dxblob->mutable_data(), mblob,
+            this->reserve_size_);
+        },
+        {dyblob, mblob}, {dxblob});
+  } else {
+    LOG(ERROR) << "Do not call backward for evaluation phase";
+  }
+  return std::make_pair(dx, param_grad);
+}
+}  // namespace singa
+//#endif  // CUDNN_VERSION_MAJOR>=5
+#endif  // USE_CUDNN

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/99e0d24d/src/model/layer/cudnn_dropout.h
----------------------------------------------------------------------
diff --git a/src/model/layer/cudnn_dropout.h b/src/model/layer/cudnn_dropout.h
new file mode 100644
index 0000000..0a19214
--- /dev/null
+++ b/src/model/layer/cudnn_dropout.h
@@ -0,0 +1,54 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef SINGA_MODEL_LAYER_CUDNN_DROPOUT_H_
+#define SINGA_MODEL_LAYER_CUDNN_DROPOUT_H_
+#ifdef USE_CUDNN
+// cudnn dropout is added in cudnn 5
+//#if CUDNN_MAJOR_VERSION >= 5
+
+#include "singa/model/layer.h"
+#include "singa/core/common.h"
+#include "singa/proto/core.pb.h"
+#include "./dropout.h"
+
+namespace singa {
+class CudnnDropout : public Dropout {
+ public:
+  ~CudnnDropout();
+  /// \copydoc Layer::layer_type()
+  const std::string layer_type() const override { return "CudnnDropout"; }
+
+  const Tensor Forward(int flag, const Tensor& input) override;
+  const std::pair<Tensor, vector<Tensor>> Backward(
+      int flag, const Tensor& grad) override;
+
+  /// Init cudnn related data structures.
+  void InitCudnn(int size, DataType dtype, Context* ctx);
+
+ private:
+  bool has_init_cudnn_ = false;
+  cudnnDropoutDescriptor_t drop_desc_;
+  cudnnTensorDescriptor_t x_desc_, y_desc_;
+  size_t state_size_, reserve_size_;
+  Tensor state_;
+};
+}  // namespace singa
+//#endif  // CUDNN_VERSION_MAJOR>=5
+#endif  // USE_CUDNN
+#endif // SINGA_MODEL_LAYER_CUDNN_DROPOUT_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/99e0d24d/src/model/layer/cudnn_utils.h
----------------------------------------------------------------------
diff --git a/src/model/layer/cudnn_utils.h b/src/model/layer/cudnn_utils.h
new file mode 100644
index 0000000..735ec13
--- /dev/null
+++ b/src/model/layer/cudnn_utils.h
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef SINGA_MODEL_LAYER_CUDNN_BASE_H_
+#define SINGA_MODEL_LAYER_CUDNN_BASE_H_
+#ifdef USE_CUDNN
+#include "singa/proto/core.pb.h"
+#include "singa/utils/logging.h"
+#include <cudnn.h>
+namespace singa {
+inline cudnnDataType_t GetCudnnDataType(DataType dtype) {
+  cudnnDataType_t ret;
+  switch (dtype) {
+    case kFloat32:
+      ret = CUDNN_DATA_FLOAT;
+      break;
+    case kDouble:
+      ret = CUDNN_DATA_DOUBLE;
+      break;
+    case kFloat16:
+      ret = CUDNN_DATA_HALF;
+      break;
+    default:
+      LOG(FATAL) << "The data type " << DataType_Name(dtype)
+                 << " is not supported by cudnn";
+  }
+  return ret;
+}
+
+#define CUDNN_CHECK(condition) \
+  do { \
+    cudnnStatus_t status = condition; \
+    CHECK_EQ(status, CUDNN_STATUS_SUCCESS) << " "\
+      << cudnnGetErrorString(status); \
+  } while (0)
+
+/*
+inline const char* cudnnGetErrorString(cudnnStatus_t status) {
+  switch (status) {
+    case CUDNN_STATUS_SUCCESS:
+      return "CUDNN_STATUS_SUCCESS";
+    case CUDNN_STATUS_NOT_INITIALIZED:
+      return "CUDNN_STATUS_NOT_INITIALIZED";
+    case CUDNN_STATUS_ALLOC_FAILED:
+      return "CUDNN_STATUS_ALLOC_FAILED";
+    case CUDNN_STATUS_BAD_PARAM:
+      return "CUDNN_STATUS_BAD_PARAM";
+    case CUDNN_STATUS_INTERNAL_ERROR:
+      return "CUDNN_STATUS_INTERNAL_ERROR";
+    case CUDNN_STATUS_INVALID_VALUE:
+      return "CUDNN_STATUS_INVALID_VALUE";
+    case CUDNN_STATUS_ARCH_MISMATCH:
+      return "CUDNN_STATUS_ARCH_MISMATCH";
+    case CUDNN_STATUS_MAPPING_ERROR:
+      return "CUDNN_STATUS_MAPPING_ERROR";
+    case CUDNN_STATUS_EXECUTION_FAILED:
+      return "CUDNN_STATUS_EXECUTION_FAILED";
+    case CUDNN_STATUS_NOT_SUPPORTED:
+      return "CUDNN_STATUS_NOT_SUPPORTED";
+    case CUDNN_STATUS_LICENSE_ERROR:
+      return "CUDNN_STATUS_LICENSE_ERROR";
+  }
+  return "Unknown cudnn status";
+}
+*/
+
+}  // namespace singa
+#endif  // USE_CUDNN
+#endif  // SINGA_MODEL_LAYER_CUDNN_BASE_H_
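
GetCudnnDataType maps SINGA's DataType enum onto cudnn's, and CUDNN_CHECK turns any non-success status into a fatal log with the status string. Typical usage (illustrative only, not part of this commit):

    cudnnHandle_t handle;
    CUDNN_CHECK(cudnnCreate(&handle));
    cudnnDataType_t dt = singa::GetCudnnDataType(singa::kFloat32);  // CUDNN_DATA_FLOAT
    // ... create descriptors and run cudnn calls, each wrapped in CUDNN_CHECK ...
    CUDNN_CHECK(cudnnDestroy(handle));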

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/99e0d24d/src/model/layer/dropout.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/dropout.cc b/src/model/layer/dropout.cc
new file mode 100644
index 0000000..f0fe25b
--- /dev/null
+++ b/src/model/layer/dropout.cc
@@ -0,0 +1,60 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "singa/model/layer.h"
+#include "./dropout.h"
+namespace singa {
+
+void Dropout::Setup(const LayerConf& conf) {
+  Layer::Setup(conf);
+  dropout_ratio_ = conf.dropout_conf().dropout_ratio();
+}
+
+const Tensor Dropout::Forward(int flag, const Tensor& input) {
+  Tensor out;
+  if (flag & kTrain) {
+    mask_.ResetLike(input);
+    // set mask_[i] = 1 with prob 1 - dropout_ratio_
+    Bernoulli(1 - dropout_ratio_, &mask_);
+    mask_ *= 1.0f / (1.0f - dropout_ratio_);
+    out = input * mask_;
+  } else {
+    out = input;
+  }
+  return out;
+}
+
+const std::pair<Tensor, vector<Tensor>> Dropout::Backward(
+    int flag, const Tensor& grad) {
+  vector<Tensor> param_grad;
+  Tensor input_grad;
+  if (flag & kTrain) {
+    // note mask is already scaled by 1/(1-dropout_ratio_)
+    input_grad = grad * mask_;
+  } else {
+    LOG(ERROR) << "Do not call backward for evaluation phase";
+  }
+  return std::make_pair(input_grad, param_grad);
+}
+
+void Dropout::ToDevice(Device* device) {
+  Layer::ToDevice(device);
+  mask_.ToDevice(device);
+}
+
+}  // namespace singa
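
End to end, the layer is driven as follows; this is a hedged sketch (it assumes a DropoutConf message with a dropout_ratio field in layer.proto, which Setup() above reads, and uses the kTrain flag from the Phase enum):

    #include "../src/model/layer/dropout.h"  // same include style as the new unit test

    void DropoutSketch(const singa::Tensor& x) {
      singa::LayerConf conf;
      conf.set_name("drop1");
      conf.mutable_dropout_conf()->set_dropout_ratio(0.25f);
      singa::Dropout drop;
      drop.Setup(conf);
      singa::Tensor y = drop.Forward(singa::kTrain, x);   // inverted dropout, scaled by 1/(1-p)
      auto ret = drop.Backward(singa::kTrain, y /* pretend dy */);
      singa::Tensor dx = ret.first;                       // ret.second is empty: no parameters
    }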

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/99e0d24d/src/model/layer/dropout.h
----------------------------------------------------------------------
diff --git a/src/model/layer/dropout.h b/src/model/layer/dropout.h
new file mode 100644
index 0000000..de349a5
--- /dev/null
+++ b/src/model/layer/dropout.h
@@ -0,0 +1,49 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef SINGA_MODEL_LAYER_DROPOUT_H_
+#define SINGA_MODEL_LAYER_DROPOUT_H_
+#include "singa/model/layer.h"
+namespace singa {
+class Dropout : public Layer {
+ public:
+  /// \copydoc Layer::layer_type()
+  const std::string layer_type() const override { return "Dropout"; }
+
+  /// \copydoc Layer::Setup(const LayerConf&);
+  void Setup(const LayerConf& conf) override;
+
+  /// \copydoc Layer::Forward(int flag, const Tensor&)
+  /// if flag is kTrain, then do dropout with given dropout_ratio;
+  /// otherwise if it is kEval, copy input directly to the output
+  /// TODO(wangwei) There are diff implementations, Caffe vs
+  /// <a href="https://github.com/nitishsrivastava/deepnet/blob/master/deepnet/fastdropoutnet.py">
+  const Tensor Forward(int flag, const Tensor& input) override;
+
+  /// \copydoc Layer::Backward(int, const Tensor&, const Tensor&);
+  const std::pair<Tensor, vector<Tensor>> Backward(int flag,
+                                                   const Tensor& grad) override;
+
+  void ToDevice(Device* device) override;
+
+ protected:
+  /// the probability of setting each element to 0.
+  float dropout_ratio_;
+  Tensor mask_;
+};
+}  // namespace singa
+#endif  // SINGA_MODEL_LAYER_DROPOUT_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/99e0d24d/src/model/layer/layer.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/layer.cc b/src/model/layer/layer.cc
deleted file mode 100644
index 0e83cde..0000000
--- a/src/model/layer/layer.cc
+++ /dev/null
@@ -1,30 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include "singa/model/layer.h"
-
-namespace singa {
-const vector<Tensor> ComputeFeature(int flag, const vector<Tensor>& input) {
-  const vector<Blob*> input_blobs;
-
-}
-
-void ComputeFeature(int flag, const vector<Tensor>& input) {
-  const vector<Blob*> input_blobs;
-
-}
-}  // namespace singa

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/99e0d24d/src/proto/core.proto
----------------------------------------------------------------------
diff --git a/src/proto/core.proto b/src/proto/core.proto
index c137186..f366ed0 100644
--- a/src/proto/core.proto
+++ b/src/proto/core.proto
@@ -26,7 +26,8 @@ enum DataType {
   kFloat16 = 1;
   kInt = 2;
   kChar = 3;
-  kNumDataType = 4;
+  kDouble = 4;
+  kNumDataType = 5;
 }
 
 enum LibType {

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/99e0d24d/src/proto/layer.proto
----------------------------------------------------------------------
diff --git a/src/proto/layer.proto b/src/proto/layer.proto
index 0fbbb5d..3d130ea 100644
--- a/src/proto/layer.proto
+++ b/src/proto/layer.proto
@@ -98,11 +98,15 @@ message ParamSpec {
   // The multiplier on the global weight decay for this parameter.
   optional float decay_mult = 4 [default = 1.0];
 
-  // SINGA field for creating diff Param, e.g. SparseParam or CompressableParam
-  // Curently only have a default param implementation.
-  optional string type = 20 [default = "default"];
+  // SINGA uses this field internally. Users just configure the fillers in
+  // the layer-specific conf message, caffe style.
+  optional FillerConf filler = 20;
 }
 
+enum Phase {
+  kTrain = 4;
+  kEval = 8;
+}
 // NOTE
 // Update the next available ID when you add a new LayerConf field.
 //

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/99e0d24d/test/singa/test_dropout.cc
----------------------------------------------------------------------
diff --git a/test/singa/test_dropout.cc b/test/singa/test_dropout.cc
new file mode 100644
index 0000000..cfe9d73
--- /dev/null
+++ b/test/singa/test_dropout.cc
@@ -0,0 +1,29 @@
+/************************************************************
+*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements.  See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership.  The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License.  You may obtain a copy of the License at
+*
+*   http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied.  See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*
+*************************************************************/
+
+#include "gtest/gtest.h"
+#include "../src/model/layer/dropout.h"
+
+
+TEST(TestDropoutLayer, Setup) {
+
+
+}
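
The test body is still a stub in this commit. A hypothetical fleshed-out version might look like the following (names, assertions, and the proto accessors are assumptions, not part of the commit):

    TEST(TestDropoutLayer, SetupSketch) {
      singa::Dropout drop;
      EXPECT_EQ("Dropout", drop.layer_type());
      singa::LayerConf conf;
      conf.mutable_dropout_conf()->set_dropout_ratio(0.2f);
      drop.Setup(conf);
      // dropout_ratio_ is protected, so verifying it was read from the conf would
      // need an accessor or a friend declaration on the layer.
    }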

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/99e0d24d/test/singa/test_tensor.cc
----------------------------------------------------------------------
diff --git a/test/singa/test_tensor.cc b/test/singa/test_tensor.cc
index 86200a8..ae20823 100644
--- a/test/singa/test_tensor.cc
+++ b/test/singa/test_tensor.cc
@@ -6,19 +6,19 @@ using singa::Device;
 
 TEST(TensorTest, TestConstructor) {
   singa::Tensor float_t(singa::Shape{2,3});
-  EXPECT_EQ(6, float_t.Size());
+  EXPECT_EQ(6u, float_t.Size());
   EXPECT_EQ(sizeof(float) * 6, float_t.MemSize());
   EXPECT_EQ(singa::kFloat32, float_t.data_type());
   auto s = float_t.shape();
-  EXPECT_EQ(s[0], 2);
-  EXPECT_EQ(s[1], 3);
+  EXPECT_EQ(s[0], 2u);
+  EXPECT_EQ(s[1], 3u);
 
   EXPECT_NE(float_t.device(), nullptr);
 
   singa::Tensor float16_t(Shape{2,3}, singa::kFloat16);
   EXPECT_EQ(singa::kFloat16, float16_t.data_type());
-  EXPECT_EQ(6, float16_t.Size());
-  EXPECT_EQ(12, float16_t.blob()->size());
+  EXPECT_EQ(6u, float16_t.Size());
+  EXPECT_EQ(12u, float16_t.blob()->size());
 
   singa::Tensor x(float16_t);
   EXPECT_EQ(float16_t.Size(), x.Size());