You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@singa.apache.org by zh...@apache.org on 2016/08/09 17:01:32 UTC

[1/3] incubator-singa git commit: SINGA-218 Implementation for RNN CUDNN version

Repository: incubator-singa
Updated Branches:
  refs/heads/dev 28678ae83 -> dfc422e5b


SINGA-218 Implementation for RNN CUDNN version

Finish the CudnnRNN layer.
Pass the test for the tanh RNN.

RNN forward accepts a vector of input tensors: <x0, x1, ... x(n-1), hx, cx>.
x(i) is the i-th input tensor, and hx is the initial hidden tensor, which could
be a dummy tensor. A dummy tensor is a tensor created without shape/device/data_type;
during computation, cudnnRNN uses 0s for such a tensor. cx is not necessary
for relu/tanh/gru RNNs. For LSTM, it could also be a dummy tensor like hx.
The output is: <y0, y1, ... y(n-1), hy, cy>.
relu/tanh/gru RNNs do not have cy; LSTM has both hy and cy.

RNN backward accepts a vector of input gradient tensors: <dy0, dy1, ... dy(n-1), dhy, dcy>.
dhy is necessary for all RNNs, but it could be a dummy tensor, in which case
a tensor of 0s is used for dhy during computation. dcy is used
only for LSTM, and it could also be a dummy tensor.
The output is: <dw, <dx0, dx1, ... dx(n-1), dhx, dcx>>,
where dhx is the gradient of hx, and dcx (the gradient of cx) is only used for LSTM.

The CudnnRNN layer must be moved onto a CUDA device; otherwise a memory error will occur (because the weight is on the CPU).


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/8e0b1083
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/8e0b1083
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/8e0b1083

Branch: refs/heads/dev
Commit: 8e0b1083992f471849bb80b0a8e869767ee9edc0
Parents: c51f944
Author: Wei Wang <wa...@gmail.com>
Authored: Fri Aug 5 16:43:40 2016 +0800
Committer: Wei Wang <wa...@gmail.com>
Committed: Wed Aug 10 00:43:11 2016 +0800

----------------------------------------------------------------------
 CMakeLists.txt                     |   2 +-
 cmake/Thirdparty/FindCUDNN.cmake   |   2 +-
 include/singa/core/common.h        |  15 +-
 include/singa/core/device.h        |   9 +-
 include/singa/model/layer.h        |   9 -
 include/singa/utils/context.h      | 291 ---------------
 src/CMakeLists.txt                 |   1 +
 src/core/tensor/tensor.cc          |  22 +-
 src/core/tensor/tensor_math.h      |   2 +-
 src/core/tensor/tensor_math_cpp.h  |   2 +-
 src/core/tensor/tensor_math_cuda.h |   2 +-
 src/model/layer/cudnn_rnn.cc       | 610 +++++++++++++++++++-------------
 src/model/layer/cudnn_rnn.h        |  44 +--
 src/model/layer/rnn.cc             |  59 ++-
 src/model/layer/rnn.h              |  31 +-
 src/model/optimizer/adagrad.cc     |   4 +-
 src/model/optimizer/nesterov.cc    |   4 +-
 src/model/optimizer/rmsprop.cc     |   1 +
 src/model/optimizer/sgd.cc         |   4 +-
 src/proto/model.proto              |  18 +-
 test/singa/test_cudnn_rnn.cc       | 273 +++++++-------
 21 files changed, 625 insertions(+), 780 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/8e0b1083/CMakeLists.txt
----------------------------------------------------------------------
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 23f8ef6..38014ce 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -1,7 +1,7 @@
 CMAKE_MINIMUM_REQUIRED(VERSION 2.6)
 
 PROJECT(singa)
-SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")
+SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11 -g")
 
 LIST(APPEND CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake/Thirdparty)
 #message(STATUS "module path: ${CMAKE_MODULE_PATH}")

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/8e0b1083/cmake/Thirdparty/FindCUDNN.cmake
----------------------------------------------------------------------
diff --git a/cmake/Thirdparty/FindCUDNN.cmake b/cmake/Thirdparty/FindCUDNN.cmake
index eefab9d..cefc4fe 100644
--- a/cmake/Thirdparty/FindCUDNN.cmake
+++ b/cmake/Thirdparty/FindCUDNN.cmake
@@ -27,7 +27,7 @@ IF(CUDNN_FOUND)
     ELSE()
         SET(CUDNN_VERSION "${CUDNN_VERSION_MAJOR}.${CUDNN_VERSION_MINOR}.${CUDNN_VERSION_PATCH}")
     ENDIF()
-    MESSAGE(STATUS "Found Cudnn_v${CUDNN_VERSION} at ${CUDNN_INCLUDE_DIR}")
+    MESSAGE(STATUS "Found Cudnn_v${CUDNN_VERSION} at ${CUDNN_INCLUDE_DIR} ${CUDNN_LIBRARIES}")
     MARK_AS_ADVANCED(CUDNN_INCLUDE_DIR CUDNN_LIBRARIES)
 
 ENDIF()

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/8e0b1083/include/singa/core/common.h
----------------------------------------------------------------------
diff --git a/include/singa/core/common.h b/include/singa/core/common.h
index caa7c67..53a9726 100644
--- a/include/singa/core/common.h
+++ b/include/singa/core/common.h
@@ -65,8 +65,14 @@ class Block {
   // Disabled as it is not used currently.
   // Block(void* ptr, size_t size, size_t offset, std::shared_ptr<atomic<int>>
   //  ref) : data_(ptr), size_(size), offset_(offset), ref_count_(ref) {}
-  void* mutable_data() const { return static_cast<char*>(data_) + offset_; }
-  const void* data() const { return static_cast<char*>(data_) + offset_; }
+  void* mutable_data() {
+    initialized_ = true;
+    return static_cast<char*>(data_) + offset_;
+  }
+  const void* data() const {
+    CHECK(initialized_) << "Must initialize data before reading it";
+    return static_cast<char*>(data_) + offset_;
+  }
   size_t size() const { return size_; }
   size_t offset() const { return offset_; }
   int IncRefCount() {
@@ -77,11 +83,16 @@ class Block {
   }
   int ref_count() const { return ref_count_.load(); }
 
+  bool initialized() const {
+    return initialized_;
+  }
+
  private:
   Block() {}
   void* data_ = nullptr;
   size_t size_ = 0;
   size_t offset_ = 0;
+  bool initialized_ = false;
   // Disabled as it is not used currently.
   // std::shared_ptr<std::atomic<int>> ref_count_ = nullptr;
   std::atomic<int> ref_count_;

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/8e0b1083/include/singa/core/device.h
----------------------------------------------------------------------
diff --git a/include/singa/core/device.h b/include/singa/core/device.h
index cd9a811..778a130 100644
--- a/include/singa/core/device.h
+++ b/include/singa/core/device.h
@@ -100,7 +100,7 @@ class Device {
     return lang_;
   }
 
-  std::shared_ptr<Device> host() const { return host_;}
+  virtual std::shared_ptr<Device> host() const { return host_;}
 
   Context* context(int k) {
     return &ctx_;
@@ -140,6 +140,9 @@ class Device {
   Context ctx_;
 };
 
+/// a singleton CppDevice as the host for all devices.
+extern std::shared_ptr<Device> defaultDevice;
+
 /// Represent a CPU device which may have multiple threads/executors.
 /// It runs cpp code.
 class CppCPU : public Device {
@@ -147,6 +150,7 @@ class CppCPU : public Device {
   ~CppCPU() {};
   CppCPU();
 
+  std::shared_ptr<Device> host() const override { return defaultDevice;}
   void SetRandSeed(unsigned seed) override;
  protected:
   void DoExec(function<void(Context*)>&& fn, int executor) override;
@@ -161,9 +165,6 @@ class CppCPU : public Device {
   void Free(void* ptr) override;
 };
 
-/// a singleton CppDevice as the host for all devices.
-extern std::shared_ptr<Device> defaultDevice;
-
 
 // Implement Device using OpenCL libs.
 // class OpenclDevice : public Device { };

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/8e0b1083/include/singa/model/layer.h
----------------------------------------------------------------------
diff --git a/include/singa/model/layer.h b/include/singa/model/layer.h
index c35f9b8..d31bd95 100644
--- a/include/singa/model/layer.h
+++ b/include/singa/model/layer.h
@@ -158,12 +158,10 @@ class Layer {
   /// Move the layer (including its parameters and other internal Tensor) onto
   /// the given device
   virtual void ToDevice(std::shared_ptr<Device> device) {
-    //for (auto p : param_values_) p->ToDevice(device);
   }
 
   /// Set the data type of Tensor in this layer.
   virtual void AsType(DataType dtype) {
-    //for (auto p : param_values_) p->AsType(dtype);
   }
 
   /// Serialize the layer info (including params) into a LayerConf proto message
@@ -202,12 +200,6 @@ class Layer {
     return vector<Tensor>{};
   }
 
-  /// Return a pointer to the 'i'-th parameter Tensor.
-  Tensor param_value(size_t i) {
-    CHECK_LT(i, param_values_.size());
-    return param_values().at(i);
-  }
-
   /// Return names of all parmaeters.
   const vector<string> param_names() {
     vector<string> pname;
@@ -227,7 +219,6 @@ class Layer {
 
  protected:
   std::string name_;
-  vector<Tensor*> param_values_;
   vector<ParamSpec> param_specs_;
 };
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/8e0b1083/include/singa/utils/context.h
----------------------------------------------------------------------
diff --git a/include/singa/utils/context.h b/include/singa/utils/context.h
deleted file mode 100644
index 6e897e8..0000000
--- a/include/singa/utils/context.h
+++ /dev/null
@@ -1,291 +0,0 @@
-/************************************************************
-*
-* Licensed to the Apache Software Foundation (ASF) under one
-* or more contributor license agreements.  See the NOTICE file
-* distributed with this work for additional information
-* regarding copyright ownership.  The ASF licenses this file
-* to you under the Apache License, Version 2.0 (the
-* "License"); you may not use this file except in compliance
-* with the License.  You may obtain a copy of the License at
-*
-*   http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing,
-* software distributed under the License is distributed on an
-* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-* KIND, either express or implied.  See the License for the
-* specific language governing permissions and limitations
-* under the License.
-*
-*************************************************************/
-
-#ifndef SINGA_UTILS_CONTEXT_H_
-#define SINGA_UTILS_CONTEXT_H_
-
-#include <chrono>
-#include <random>
-#include <thread>
-#include <unordered_map>
-#include <vector>
-
-#include "singa/utils/logging.h"
-
-#ifdef USE_GPU
-#include <cublas_v2.h>
-#include <cuda.h>
-#include <cuda_runtime.h>
-#include <curand.h>
-// CUDA: various checks for different function calls.
-#define CUDA_CHECK(condition) \
-/* Code block avoids redefinition of cudaError_t error */ \
-do { \
-cudaError_t error = condition; \
-CHECK_EQ(error, cudaSuccess) << " " << cudaGetErrorString(error); \
-} while (0)
-
-#ifdef USE_CUDNN
-#include <cudnn.h>
-#endif
-
-#endif // USE_GPU
-
-namespace singa {
-
-/**
- * Context is used as a global singleton, which stores the mapping from CPU
- * thread id to GPU device id. If a thread has no GPU, then its associated
- * device id is -1. It manages (e.g., creating) the handlers for GPU
- * devices. It also manages the GPU and CPU random generators, which are created
- * when accessed. One CPU thread has a CPU random generator. A GPU device
- * has a GPU random generator, which is accessible after assigning the GPU
- * device with a CPU thread via SetupDevice.
- */
-class Context {
- public:
-   /**
-    * Destructor, release random generators and handlers.
-    */
-  ~Context() {
-#ifdef USE_GPU
-    for (auto& entry : device_id_) {
-      if (entry.second != -1) {
-        cudaSetDevice(entry.second);
-        if (cublas_handle_[entry.second] != nullptr) {
-          cublasDestroy(cublas_handle_[entry.second]);
-          cublas_handle_[entry.second] = nullptr;
-        }
-        if (curand_generator_[entry.second] != nullptr) {
-          curandDestroyGenerator(curand_generator_[entry.second]);
-          curand_generator_[entry.second] = nullptr;
-        }
-      }
-    }
-#ifdef USE_CUDNN
-    for (auto& handle : cudnn_handle_) {
-      if (handle != nullptr)
-        CHECK_EQ(cudnnDestroy(handle), CUDNN_STATUS_SUCCESS);
-      handle = nullptr;
-    }
-#endif
-#endif
-    for (auto& entry : rand_generator_) {
-      if (entry.second != nullptr) {
-        delete entry.second;
-        entry.second = nullptr;
-      }
-    }
-  }
-  /**
-   * Constructor, init handlers and GPU rand generators to nullptr.
-   */
-  Context() {
-    for (int i = 0; i < kMaxNumGPU; i++) {
-#ifdef USE_GPU
-      cublas_handle_.push_back(nullptr);
-      curand_generator_.push_back(nullptr);
-#ifdef USE_CUDNN
-      cudnn_handle_.push_back(nullptr);
-#endif
-#endif
-    }
-  }
-
-  /**
-   * @return the device ID of the current thread.
-   */
-  int device_id() {
-    return device_id(std::this_thread::get_id());
-  }
-  /**
-   * @return the ID of the device attached to a given CPU thread, or -1 if this
-   * thread has not been attached GPU device.
-   */
-  int device_id(const std::thread::id& tid) {
-    if (device_id_.find(tid) != device_id_.end())
-      return device_id_[tid];
-    else
-      return -2;
-  }
-  /**
-   * Setup the CPU thread, which may be assigned a GPU device.
-   * If there is no GPU device, then set did to -1.
-   * Set the random seed to -1.
-   * @param[in] thread::id CPU thread ID
-   * @param[in] device_id GPU device ID
-   */
-  void SetupDevice(const std::thread::id& tid, const int did) {
-    SetupDevice(tid, did, -1);
-  }
-  /**
-   * @copy SetupDevice(const int, const int);
-   * @param[in] seed random seed
-   */
-  void SetupDevice(const std::thread::id& tid, const int did, const int seed) {
-    device_id_[tid] = did;
-    seed_[tid] = seed;
-  }
-
-  /**
-   * Activate the GPU device by calling cudaSetDevice.
-   */
-  void ActivateDevice(const int device_id) {
-    CHECK_GE(device_id, 0);
-#ifdef USE_GPU
-    cudaSetDevice(device_id);
-#endif
-  }
-
-  /**
-   * \copybreif rand_generator(const std::thread::id&);
-   * @return the CPU random generator for the calling thread.
-   */
-  std::mt19937* rand_generator() {
-    return rand_generator(std::this_thread::get_id());
-  }
-  /**
-   * Get the CPU random generator.
-   * If the generator does not exist, then create it now.
-   * If the seed is not set, i.e., seed=-1, then get a seed from system time.
-   * @param[in] thread::id CPU thread ID
-   * @return the CPU random generator
-   */
-  std::mt19937* rand_generator(const std::thread::id& tid) {
-    if (rand_generator_.find(tid) == rand_generator_.end()) {
-      // CHECK(seed_.find(tid) != seed_.end());
-      auto seed = static_cast<unsigned>(seed_[tid]);
-      if (seed_.find(tid) == seed_.end() || seed_.at(tid) == -1)
-        seed = std::chrono::system_clock::now().time_since_epoch().count();
-      rand_generator_[tid] = new std::mt19937(seed);
-    }
-    return rand_generator_[tid];
-  }
-#ifdef USE_GPU
-  /**
-   * \copybreif cublas_handle_(const std::thread::id&);
-   * @return cublas handle for the calling thread.
-   */
-  cublasHandle_t cublas_handle() {
-    return cublas_handle(std::this_thread::get_id());
-  }
-  /**
-   * Get the handler of the GPU which is assigned to the given thread.
-   * Calls cublas_handle(const int);
-   */
-  cublasHandle_t cublas_handle(const std::thread::id thread_id) {
-    return cublas_handle(device_id(thread_id));
-  }
-  /**
-   * Get the handler of the GPU device given its device ID. The device
-   * must be set up via SetupDevice(const std::thread::id, const int) before
-   * calling this function.
-   * @param[in] device_id GPU device ID
-   * @return the GPU handler
-   */
-  cublasHandle_t cublas_handle(const int device_id) {
-    CHECK_GE(device_id, 0);
-    if (cublas_handle_.at(device_id) == nullptr) {
-      cudaSetDevice(device_id);
-      cublasCreate(&cublas_handle_[device_id]);
-    }
-    return cublas_handle_[device_id];
-  }
-  /**
-   * Get the rand generator of the GPU device assigned to the given thread.
-   */
-  curandGenerator_t curand_generator(const std::thread::id thread_id) {
-    return curand_generator(device_id(thread_id));
-  }
-  /**
-   * Get the random generator of the GPU device given the device id.
-   * @param[in] device_id GPU device ID
-   * @return random generator. If it does not exist, then create one.
-   * The random seed will be set to CURAND_RNG_PSEUDO_DEFAULT if it is not set.
-   */
-  curandGenerator_t curand_generator(const int device_id) {
-    CHECK_GE(device_id, 0);
-    CHECK_LT(device_id, cudnn_handle_.size());
-    if (curand_generator_.at(device_id) == nullptr) {
-      // TODO(wangwei) handle user set seed
-      /*
-      CHECK(seed_.find(tid) != seed_.end());
-      auto seed = seed_[tid];
-      */
-      ActivateDevice(device_id);
-      curandCreateGenerator(&curand_generator_[device_id],
-          CURAND_RNG_PSEUDO_DEFAULT);
-    }
-    return curand_generator_[device_id];
-  }
-
-#ifdef USE_CUDNN
-  cudnnHandle_t cudnn_handle() {
-    return cudnn_handle(std::this_thread::get_id());
-  }
-
-  cudnnHandle_t cudnn_handle(const std::thread::id thread_id) {
-    return cudnn_handle(device_id(thread_id));
-  }
-
-  cudnnHandle_t cudnn_handle(const int device_id) {
-    CHECK_GE(device_id, 0);
-    CHECK_LT(device_id, cudnn_handle_.size());
-  }
-#endif // USE_CUDNN
-
- protected:
-  //!< max num of GPUs per process
-  const int kMaxNumGPU = 64;
-  //!< map from thread id to device id
-  std::unordered_map<std::thread::id, int> device_id_;
-  //!< map from thread id to cpu rand generator
-  std::unordered_map<std::thread::id, std::mt19937 *> rand_generator_;
-  //!< map from thread id to cpu rand generator seed
-  std::unordered_map<std::thread::id, int> seed_;
-#ifdef USE_GPU
-  //!< cublas handler indexed by GPU device ID
-  std::vector<cublasHandle_t> cublas_handle_;
-  //!< cublas rand generator indexed by GPU device ID
-  std::vector<curandGenerator_t> curand_generator_;
-
-#ifdef USE_CUDNN
-  std::vector<cudnnHandle_t> cudnn_handle_;
-#endif
-#endif // USE_GPU
-};
-
-}  // namespace singa
-
-#endif  // SINGA_UTILS_CONTEXT_H_
-    if (cudnn_handle_.at(device_id) == nullptr) {
-      ActivateDevice(device_id);
-      // LOG(ERROR) << "create cudnn handle for device " << device_id;
-      CHECK_EQ(cudnnCreate(&cudnn_handle_[device_id]), CUDNN_STATUS_SUCCESS);
-    }
-    // LOG(ERROR) << "use cudnn handle from device " << device_id;
-    return cudnn_handle_[device_id];
-  }
-#endif
-
-#endif // USE_GPU
-
-#ifdef USE_OPENCL

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/8e0b1083/src/CMakeLists.txt
----------------------------------------------------------------------
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 65a81fc..38e6aa3 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -64,6 +64,7 @@ AUX_SOURCE_DIRECTORY(model/metric model_source)
 AUX_SOURCE_DIRECTORY(model/updater model_source)
 #MESSAGE(STATUS "MODEL ${model_source}")
 ADD_LIBRARY(singa_model SHARED ${model_source})
+MESSAGE(STATUS "model linker libs ${SINGA_LINKER_LIBS}")
 TARGET_LINK_LIBRARIES(singa_model ${SINGA_LINKER_LIBS})
 LIST(APPEND SINGA_LINKER_LIBS singa_model)
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/8e0b1083/src/core/tensor/tensor.cc
----------------------------------------------------------------------
diff --git a/src/core/tensor/tensor.cc b/src/core/tensor/tensor.cc
index c16bd29..bd3bc70 100644
--- a/src/core/tensor/tensor.cc
+++ b/src/core/tensor/tensor.cc
@@ -35,21 +35,29 @@ Tensor::Tensor() { device_ = defaultDevice; }
 Tensor::Tensor(const Shape &shape, DataType dtype)
     : data_type_(dtype), device_(defaultDevice), shape_(shape) {
   device_ = defaultDevice;
-  block_ = device_->NewBlock(Product(shape_) * SizeOf(data_type_));
+  size_t size = Product(shape_) * SizeOf(data_type_);
+  if (size)
+    block_ = device_->NewBlock(size);
 }
 Tensor::Tensor(Shape &&shape, DataType dtype)
     : data_type_(dtype), device_(defaultDevice), shape_(shape) {
   device_ = defaultDevice;
-  block_ = device_->NewBlock(Product(shape_) * SizeOf(data_type_));
+  size_t size = Product(shape_) * SizeOf(data_type_);
+  if (size)
+    block_ = device_->NewBlock(size);
 }
 Tensor::Tensor(const Shape &shape, std::shared_ptr<Device> device,
                DataType dtype)
     : data_type_(dtype), device_(device), shape_(shape) {
-  block_ = device_->NewBlock(Product(shape_) * SizeOf(data_type_));
+  size_t size = Product(shape_) * SizeOf(data_type_);
+  if (size)
+    block_ = device_->NewBlock(size);
 }
 Tensor::Tensor(Shape &&shape, std::shared_ptr<Device> device, DataType dtype)
     : data_type_(dtype), device_(device), shape_(shape) {
-  block_ = device_->NewBlock(Product(shape_) * SizeOf(data_type_));
+  size_t size = Product(shape_) * SizeOf(data_type_);
+  if (size)
+    block_ = device_->NewBlock(size);
 }
 Tensor::Tensor(const Tensor &in)
     : transpose_(in.transpose_),
@@ -57,7 +65,8 @@ Tensor::Tensor(const Tensor &in)
       device_(in.device_),
       block_(in.block()),
       shape_(in.shape_) {
-  block_->IncRefCount();
+  if (block_ != nullptr)
+    block_->IncRefCount();
 }
 
 Tensor::Tensor(Tensor &&in)
@@ -118,7 +127,8 @@ void Tensor::ToDevice(std::shared_ptr<Device> dst) {
   // TODO(wangwei) the comparison is very strict. May compare against device ID?
   if (device_ != dst) {
     Tensor tmp(shape_, dst, data_type_);
-    if (block_ != nullptr && Size()) tmp.CopyData(*this);
+    if (block_ != nullptr && Size() && block_->initialized())
+      tmp.CopyData(*this);
     if (block_ != nullptr && block_->DecRefCount() == 0)
       device_->FreeBlock(block_);
     block_ = tmp.block_;

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/8e0b1083/src/core/tensor/tensor_math.h
----------------------------------------------------------------------
diff --git a/src/core/tensor/tensor_math.h b/src/core/tensor/tensor_math.h
index 7732dd2..1914ca6 100644
--- a/src/core/tensor/tensor_math.h
+++ b/src/core/tensor/tensor_math.h
@@ -341,7 +341,7 @@ void SoftmaxCrossEntropyBwd(const size_t batchsize, const size_t dim,
 
 template <typename DType, typename Lang>
 void RowMax(const size_t nrow, const size_t ncol, const Block *in,
-    const Block *ret, Context* ctx) {
+    Block *ret, Context* ctx) {
   LOG(FATAL) << "Not Implemented";
 }
 // **************************************

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/8e0b1083/src/core/tensor/tensor_math_cpp.h
----------------------------------------------------------------------
diff --git a/src/core/tensor/tensor_math_cpp.h b/src/core/tensor/tensor_math_cpp.h
index 3e0c8ad..941931d 100644
--- a/src/core/tensor/tensor_math_cpp.h
+++ b/src/core/tensor/tensor_math_cpp.h
@@ -551,7 +551,7 @@ void SoftmaxCrossEntropyBwd<float, lang::Cpp>(const size_t batchsize,
 
 template <>
 void RowMax<float, lang::Cpp>(const size_t nrow, const size_t ncol,
-                              const Block *in, const Block *out, Context *ctx) {
+                              const Block *in, Block *out, Context *ctx) {
   const float *inPtr = static_cast<const float *>(in->data());
   float *outPtr = static_cast<float *>(out->mutable_data());
   for (size_t r = 0; r < nrow; r++) {

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/8e0b1083/src/core/tensor/tensor_math_cuda.h
----------------------------------------------------------------------
diff --git a/src/core/tensor/tensor_math_cuda.h b/src/core/tensor/tensor_math_cuda.h
index 43bfa1b..8b6e939 100644
--- a/src/core/tensor/tensor_math_cuda.h
+++ b/src/core/tensor/tensor_math_cuda.h
@@ -424,7 +424,7 @@ void SoftmaxCrossEntropyBwd<float, lang::Cuda>(const size_t batchsize,
 
 template <>
 void RowMax<float, lang::Cuda>(const size_t nrow, const size_t ncol,
-                               const Block* in, const Block* out,
+                               const Block* in, Block* out,
                                Context* ctx) {
   const float* inPtr = static_cast<const float*>(in->data());
   float* outPtr = static_cast<float*>(out->mutable_data());

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/8e0b1083/src/model/layer/cudnn_rnn.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/cudnn_rnn.cc b/src/model/layer/cudnn_rnn.cc
index 6f04e5c..242a342 100644
--- a/src/model/layer/cudnn_rnn.cc
+++ b/src/model/layer/cudnn_rnn.cc
@@ -30,12 +30,6 @@ CudnnRNN::~CudnnRNN() {
     CUDNN_CHECK(cudnnDestroyDropoutDescriptor(dropout_desc_));
   if (rnn_desc_ != nullptr)
     CUDNN_CHECK(cudnnDestroyRNNDescriptor(rnn_desc_));
-  if (x_descs_ != nullptr)
-    for (size_t i = 0; i < seqLength_; i++) 
-      CUDNN_CHECK(cudnnDestroyTensorDescriptor(x_descs_[i]));
-  if (y_descs_ != nullptr)
-    for (size_t i = 0; i < seqLength_; i++) 
-      CUDNN_CHECK(cudnnDestroyTensorDescriptor(y_descs_[i]));
   if (hx_desc_ != nullptr)
     CUDNN_CHECK(cudnnDestroyTensorDescriptor(hx_desc_));
   if (hy_desc_ != nullptr)
@@ -44,284 +38,392 @@ CudnnRNN::~CudnnRNN() {
     CUDNN_CHECK(cudnnDestroyTensorDescriptor(cx_desc_));
   if (cy_desc_ != nullptr)
     CUDNN_CHECK(cudnnDestroyTensorDescriptor(cy_desc_));
-}
-
-void CudnnRNN::Setup(const Shape& in_sample, const LayerConf &conf) {
-  RNN::Setup(in_sample, conf);
-  RNNConf rnn_conf = conf.rnn_conf();
-  // convert MB to bytes
-  workspace_byte_limit_ = rnn_conf.workspace_byte_limit() << 20;
-  inputMode_ = ToLowerCase(rnn_conf.inputmode());
-  direction_ = ToLowerCase(rnn_conf.direction());
-  mode_ = ToLowerCase(rnn_conf.mode());
-  CHECK(inputMode_ == "cudnn_linear_input" || inputMode_ == "cudnn_skip_input")
-      << "CudnnRNN only supports two inputmodes: cudnn_linear_input, "
-         "cudnn_skip_input";
-  CHECK(direction_ == "cudnn_undirectional" || direction_ == "cudnn_bidirectional")
-      << "CudnnRNN only supports two directions: cudnn_undirectional, "
-         "cudnn_bidirectional";
-  CHECK(mode_ == "cudnn_rnn_relu" || mode_ == "cudnn_rnn_tanh" ||
-        mode_ == "cudnn_lstm" || mode_ == "cudnn_gru")
-      << "CudnnRNN only supports four modes: cudnn_rnn_relu, "
-         "cudnn_rnn_tanh, cudnn_lstm and cudnn_gru";
-  // the first constant (4) is the size of float
-  // the second constant (2, 8, 6) is the number of sets of params
-  if (mode_ == "cudnn_rnn_relu" || mode_ == "cudnn_rnn_tanh")
-    weightSize_ = 4 * 2 * (hiddenSize_ * in_sample[2] + hiddenSize_);
-  else if (mode_ == "cudnn_lstm")
-    weightSize_ = 4 * 8 * (hiddenSize_ * in_sample[2] + hiddenSize_);
-  else if (mode_ == "cudnn_gru")
-    weightSize_ = 4 * 6 * (hiddenSize_ * in_sample[2] + hiddenSize_);
-  if (direction_ == "cudnn_bidirectional")
-    weightSize_ = weightSize_ * 2;
+  if (dhx_desc_ != nullptr)
+    CUDNN_CHECK(cudnnDestroyTensorDescriptor(dhx_desc_));
+  if (dhy_desc_ != nullptr)
+    CUDNN_CHECK(cudnnDestroyTensorDescriptor(dhy_desc_));
+  if (dcx_desc_ != nullptr)
+    CUDNN_CHECK(cudnnDestroyTensorDescriptor(dcx_desc_));
+  if (dcy_desc_ != nullptr)
+    CUDNN_CHECK(cudnnDestroyTensorDescriptor(dcy_desc_));
+  DestroyIODescriptors();
 }
 
 void CudnnRNN::ToDevice(std::shared_ptr<Device> device) {
-  weight_.ToDevice(device);
+  RNN::ToDevice(device);
   workspace_.ToDevice(device);
+  reserve_space_.ToDevice(device);
 }
 
-void CudnnRNN::InitCudnn(const Tensor &input) {
-  CHECK(!has_init_cudnn_);
-  DataType dtype = input.data_type();
-  auto dev = input.device();
-  Context *ctx = dev->context(0);
-  seqLength_ = input.shape(0);
-  size_t batchsize = input.shape(1); /*(seqLength, minibatch, inputSize) !!! */
-  size_t inputSize = input.shape(2);
-  size_t numDirections;
-  if (direction_ == "cudnn_undirectional")
-    numDirections = 1;
-  else 
-    numDirections = 2;
-  x_descs_ = new cudnnTensorDescriptor_t[seqLength_];
-  y_descs_ = new cudnnTensorDescriptor_t[seqLength_];
-  for (size_t i = 0; i < seqLength_; i++)
-    CUDNN_CHECK(cudnnCreateTensorDescriptor(&x_descs_[i]));
-  for (size_t i = 0; i < seqLength_; i++)
-    CUDNN_CHECK(cudnnCreateTensorDescriptor(&y_descs_[i]));
-  CUDNN_CHECK(cudnnCreateTensorDescriptor(&hx_desc_));
-  CUDNN_CHECK(cudnnCreateTensorDescriptor(&cx_desc_));
-  CUDNN_CHECK(cudnnCreateTensorDescriptor(&hy_desc_));
-  CUDNN_CHECK(cudnnCreateTensorDescriptor(&cy_desc_));
-  CUDNN_CHECK(cudnnCreateFilterDescriptor(&weight_desc_));
+void CudnnRNN::DestroyIODescriptors() {
+  if (x_descs_ != nullptr) {
+    for (size_t i = 0; i < seq_length_; i++) {
+      CUDNN_CHECK(cudnnDestroyTensorDescriptor(x_descs_[i]));
+      CUDNN_CHECK(cudnnDestroyTensorDescriptor(dx_descs_[i]));
+    }
+    delete [] x_descs_;
+    delete [] dx_descs_;
+  }
+  if (y_descs_ != nullptr) {
+    for (size_t i = 0; i < seq_length_; i++) {
+      CUDNN_CHECK(cudnnDestroyTensorDescriptor(y_descs_[i]));
+      CUDNN_CHECK(cudnnDestroyTensorDescriptor(dy_descs_[i]));
+    }
+    delete [] y_descs_;
+    delete [] dy_descs_;
+  }
+}
+
+void CudnnRNN::UpdateIODescriptors(size_t len, const vector<Tensor> &inputs) {
+  bool reset = false;
+  if (seq_length_ < len) {
+    DestroyIODescriptors();
+    x_descs_ = new cudnnTensorDescriptor_t[len];
+    dx_descs_ = new cudnnTensorDescriptor_t[len];
+    y_descs_ = new cudnnTensorDescriptor_t[len];
+    dy_descs_ = new cudnnTensorDescriptor_t[len];
+    for (size_t i = 0; i < len; i++) {
+      CUDNN_CHECK(cudnnCreateTensorDescriptor(&x_descs_[i]));
+      CUDNN_CHECK(cudnnCreateTensorDescriptor(&dx_descs_[i]));
+      CUDNN_CHECK(cudnnCreateTensorDescriptor(&y_descs_[i]));
+      CUDNN_CHECK(cudnnCreateTensorDescriptor(&dy_descs_[i]));
+    }
+    reset = true;
+  }
+
+  for (size_t i = 0; i < len; i++) {
+    CHECK_EQ(inputs[i].shape(1), input_dim_);
+    if (inputs[i].shape(0) != batch_size_ || reset) {
+      int d[3] = {1, 1, 1}, s[3] = {1, 1, 1};
+      d[0] = static_cast<int>(inputs[i].shape(0));
+      CHECK_GT(d[0], 0);
+      d[1] = static_cast<int>(inputs[i].shape(1));
+      s[0] = d[1] * d[2];
+      s[1] = d[2];
+      CUDNN_CHECK(cudnnSetTensorNdDescriptor(x_descs_[i], dtype_, 3, d, s));
+      CUDNN_CHECK(cudnnSetTensorNdDescriptor(dx_descs_[i], dtype_, 3, d, s));
+
+      d[0] = static_cast<int>(inputs[i].shape(0));
+      d[1] = static_cast<int>(hidden_dim_ * num_directions_);
+      s[0] = d[1] * d[2];
+      s[1] = d[2];
+      CUDNN_CHECK(cudnnSetTensorNdDescriptor(y_descs_[i], dtype_, 3, d, s));
+      CUDNN_CHECK(cudnnSetTensorNdDescriptor(dy_descs_[i], dtype_, 3, d, s));
+    }
+  }
+}
+
+// must be called after setting IO descriptors (x_descs_[0] is read below)
+void CudnnRNN::SetRNNDescriptor(shared_ptr<Device> dev) {
+  auto ctx = dev->context(0);
   CUDNN_CHECK(cudnnCreateDropoutDescriptor(&dropout_desc_));
+  size_t state_size;
+  CUDNN_CHECK(cudnnDropoutGetStatesSize(ctx->cudnn_handle, &state_size));
+  dropout_state_ = Tensor(Shape{state_size}, dev, kChar);
+  CUDNN_CHECK(cudnnSetDropoutDescriptor(
+      dropout_desc_, ctx->cudnn_handle, dropout_,
+      dropout_state_.block()->mutable_data(), state_size, seed_));
+
   CUDNN_CHECK(cudnnCreateRNNDescriptor(&rnn_desc_));
+  cudnnRNNInputMode_t input_mode = CUDNN_LINEAR_INPUT;  // proto default
+  if (input_mode_ == "linear")
+    input_mode = CUDNN_LINEAR_INPUT;
+  else if (input_mode_ == "skip")
+    input_mode = CUDNN_SKIP_INPUT;
 
+  cudnnDirectionMode_t direction = CUDNN_UNIDIRECTIONAL;  // proto default
+  if (direction_ == "unidirectional")
+    direction = CUDNN_UNIDIRECTIONAL;
+  else if (direction_ == "bidirectional")
+    direction = CUDNN_BIDIRECTIONAL;
+
+  cudnnRNNMode_t rnn_mode = CUDNN_RNN_RELU;  // proto default
+  if (rnn_mode_ == "relu")
+    rnn_mode = CUDNN_RNN_RELU;
+  else if (rnn_mode_ == "tanh")
+    rnn_mode = CUDNN_RNN_TANH;
+  else if (rnn_mode_ == "lstm")
+    rnn_mode = CUDNN_LSTM;
+  else if (rnn_mode_ == "gru")
+    rnn_mode = CUDNN_GRU;
+  CUDNN_CHECK(cudnnSetRNNDescriptor(rnn_desc_, hidden_dim_, num_stacks_,
+                                    dropout_desc_, input_mode, direction,
+                                    rnn_mode, dtype_));
+
+  size_t weight_size;
+  CUDNN_CHECK(cudnnGetRNNParamsSize(ctx->cudnn_handle, rnn_desc_, x_descs_[0],
+                                    &weight_size, dtype_));
+  // weight_size is in bytes; the filter dims below count elements
+  CHECK_EQ(weight_size, weight_.Size() * SizeOf(weight_.data_type()));
+  int filter_dim[3] = {static_cast<int>(weight_.Size()), 1, 1};
+  CUDNN_CHECK(cudnnCreateFilterDescriptor(&weight_desc_));
+  CUDNN_CHECK(cudnnSetFilterNdDescriptor(weight_desc_, dtype_,
+                                         CUDNN_TENSOR_NCHW, 3, filter_dim));
+}
 
-  int dimA[3] = {batchsize, inputSize, 1};
-  int strideA[3] = {dimA[2] * dimA[1], dimA[2], 1};
-  for (size_t i = 0; i < seqLength_; i++){
-    dimA[0] = batchsize;
-    dimA[1] = inputSize;
-    dimA[2] = 1;
-    strideA[0] = dimA[2] * dimA[1];
-    strideA[1] = dimA[2];
-    strideA[2] = 1;
-    CUDNN_CHECK(cudnnSetTensorNdDescriptor(x_descs_[i], GetCudnnDataType(dtype), 3,
-                                         dimA, strideA));
-    dimA[0] = batchsize;
-    dimA[1] = hiddenSize_ * numDirections;
-    dimA[2] = 1;
-    strideA[0] = dimA[2] * dimA[1];
-    strideA[1] = dimA[2];
-    strideA[2] = 1;
-    CUDNN_CHECK(cudnnSetTensorNdDescriptor(y_descs_[i], GetCudnnDataType(dtype), 3,
-                                         dimA, strideA));
+void CudnnRNN::ResetHiddenAndCellDescriptors(size_t batch_size) {
+  if (batch_size_ == 0) {
+    CUDNN_CHECK(cudnnCreateTensorDescriptor(&cx_desc_));
+    CUDNN_CHECK(cudnnCreateTensorDescriptor(&dcx_desc_));
+    CUDNN_CHECK(cudnnCreateTensorDescriptor(&cy_desc_));
+    CUDNN_CHECK(cudnnCreateTensorDescriptor(&dcy_desc_));
+    CUDNN_CHECK(cudnnCreateTensorDescriptor(&hx_desc_));
+    CUDNN_CHECK(cudnnCreateTensorDescriptor(&dhx_desc_));
+    CUDNN_CHECK(cudnnCreateTensorDescriptor(&hy_desc_));
+    CUDNN_CHECK(cudnnCreateTensorDescriptor(&dhy_desc_));
   }
-  
-  dimA[0] = numLayers_;
-  dimA[1] = batchsize;
-  dimA[2] = hiddenSize_ * numDirections;
-  strideA[0] = dimA[2] * dimA[1];
-  strideA[1] = dimA[2];
-  strideA[2] = 1;
-  CUDNN_CHECK(cudnnSetTensorNdDescriptor(hx_desc_, GetCudnnDataType(dtype), 3,
-                                         dimA, strideA));
-  CUDNN_CHECK(cudnnSetTensorNdDescriptor(cx_desc_, GetCudnnDataType(dtype), 3,
-                                         dimA, strideA));
-  CUDNN_CHECK(cudnnSetTensorNdDescriptor(hy_desc_, GetCudnnDataType(dtype), 3,
-                                         dimA, strideA));
-  CUDNN_CHECK(cudnnSetTensorNdDescriptor(cy_desc_, GetCudnnDataType(dtype), 3,
-                                         dimA, strideA));
 
-  size_t dropoutStatesSize;
-  CUDNN_CHECK(cudnnDropoutGetStatesSize(ctx->cudnn_handle, &dropoutStatesSize));
-  dropoutStates_ = Tensor(Shape{dropoutStatesSize}, dev, dtype);
-  CUDNN_CHECK(cudnnSetDropoutDescriptor(dropout_desc_, ctx->cudnn_handle, dropout_, this->dropoutStates_.block()->mutable_data(), dropoutStatesSize, 0x01234567));
-  
-  cudnnRNNInputMode_t inputMode;
-  cudnnDirectionMode_t direction;
-  cudnnRNNMode_t mode;
-  
-  if (inputMode_ == "cudnn_linear_input" || inputMode_ == "cudnn_skip_input"){
-    if (inputMode_ == "cudnn_linear_input")
-      inputMode = CUDNN_LINEAR_INPUT;
-    else if (inputMode_ == "cudnn_skip_input")
-      inputMode = CUDNN_SKIP_INPUT;
+  int dim[3] = {1, 1, 1};
+  dim[0] = static_cast<int>(num_stacks_ * num_directions_);
+  dim[1] = static_cast<int>(batch_size);
+  dim[2] = static_cast<int>(hidden_dim_);
+  int stride[3] = {1, 1, 1};
+  stride[0] = dim[1] * dim[2];
+  stride[1] = dim[2];
+  CUDNN_CHECK(cudnnSetTensorNdDescriptor(hx_desc_, dtype_, 3, dim, stride));
+  CUDNN_CHECK(cudnnSetTensorNdDescriptor(dhx_desc_, dtype_, 3, dim, stride));
+  CUDNN_CHECK(cudnnSetTensorNdDescriptor(hy_desc_, dtype_, 3, dim, stride));
+  CUDNN_CHECK(cudnnSetTensorNdDescriptor(dhy_desc_, dtype_, 3, dim, stride));
+  CUDNN_CHECK(cudnnSetTensorNdDescriptor(cx_desc_, dtype_, 3, dim, stride));
+  CUDNN_CHECK(cudnnSetTensorNdDescriptor(dcx_desc_, dtype_, 3, dim, stride));
+  CUDNN_CHECK(cudnnSetTensorNdDescriptor(cy_desc_, dtype_, 3, dim, stride));
+  CUDNN_CHECK(cudnnSetTensorNdDescriptor(dcy_desc_, dtype_, 3, dim, stride));
+}
+
+void CudnnRNN::UpdateSpaces(size_t seq_length, shared_ptr<Device> dev) {
+  size_t count;
+  auto ctx = dev->context(0);
+  CUDNN_CHECK(cudnnGetRNNWorkspaceSize(ctx->cudnn_handle, rnn_desc_,
+                                       seq_length, x_descs_, &count));
+  if (workspace_.Size() != count) {
+    workspace_ = Tensor(Shape{count}, dev, kChar);
+    // workspace_.SetValue(0);
   }
-  if (direction_ == "cudnn_undirectional" || direction_ == "cudnn_bidirectional"){
-    if (direction_ == "cudnn_undirectional")
-      direction = CUDNN_UNIDIRECTIONAL;
-    else if (direction_ == "cudnn_bidirectional")
-      direction = CUDNN_BIDIRECTIONAL;
+
+  CUDNN_CHECK(cudnnGetRNNTrainingReserveSize(ctx->cudnn_handle, rnn_desc_,
+                                             seq_length, x_descs_, &count));
+  if (reserve_space_.Size() != count) {
+    reserve_space_ = Tensor(Shape{count}, dev, kChar);
+    // reserve_space_.SetValue(0);
   }
-  if (mode_ == "cudnn_rnn_relu" || mode_ == "cudnn_rnn_tanh" ||
-        mode_ == "cudnn_lstm" || mode_ == "cudnn_gru"){
-    if (mode_ == "cudnn_rnn_relu")
-      mode = CUDNN_RNN_RELU;
-    else if (mode_ == "cudnn_rnn_tanh")
-      mode = CUDNN_RNN_TANH;
-    else if (mode_ == "cudnn_lstm")
-      mode = CUDNN_LSTM;
-    else if (mode_ == "cudnn_gru")
-      mode = CUDNN_GRU;
+}
+
+void CudnnRNN::UpdateStates(size_t num_x, const vector<Tensor> &inputs) {
+  UpdateIODescriptors(num_x, inputs);
+  size_t new_batch_size = inputs.at(0).shape(0);
+  if (batch_size_ != new_batch_size)
+    ResetHiddenAndCellDescriptors(new_batch_size);
+  if (rnn_desc_ == nullptr)
+    SetRNNDescriptor(inputs.at(0).device());
+  UpdateSpaces(num_x, inputs.at(0).device());
+  batch_size_ = new_batch_size;
+  seq_length_ = num_x;
+}
+
+Tensor CudnnRNN::MergeInputs(size_t num, const vector<Tensor> &in) {
+  if (num == 1)
+    return in.at(0);
+  size_t size = 0;
+  for (size_t i = 0; i < num; i++) size += in.at(i).Size();
+  Tensor out(Shape{size}, in.at(0).device(), in.at(0).data_type());
+  for (size_t i = 0, offset = 0; i < num; i++) {
+    CopyDataToFrom(&out, in.at(i), in.at(i).Size(), offset);
+    offset += in.at(i).Size();
+  }
+  return out;
+}
+
+vector<Tensor> CudnnRNN::SplitOutput(size_t num, size_t dim,
+                                     const vector<Tensor> &in,
+                                     const Tensor output) {
+  vector<Tensor> outputs;
+  if (num == 1) {
+    outputs.push_back(output);
+  } else {
+    for (size_t i = 0, offset = 0; offset < output.Size(); i++) {
+      Shape s{in.at(i).shape(0), dim};
+      Tensor out(s, output.device(), output.data_type());
+      CopyDataToFrom(&out, output, out.Size(), 0, offset);
+      outputs.push_back(out);
+      offset += out.Size();
+    }
+    CHECK_EQ(num, outputs.size());
   }
-  CUDNN_CHECK(cudnnSetRNNDescriptor(rnn_desc_, hiddenSize_, numLayers_, dropout_desc_, inputMode, direction, mode, GetCudnnDataType(dtype)));
+  return outputs;
+}
 
-  size_t weightSize;
-  CUDNN_CHECK(cudnnGetRNNParamsSize(ctx->cudnn_handle, rnn_desc_, x_descs_[0], &weightSize, GetCudnnDataType(dtype)));
-  CHECK_EQ(weightSize, weightSize_);
+const vector<Tensor> CudnnRNN::Forward(int flag, const vector<Tensor> &inputs) {
+  DataType dtype = inputs.at(0).data_type();
+  auto dev = inputs.at(0).device();
 
-  int filterDimA[3] = {weightSize_, 1, 1};
-  CUDNN_CHECK(cudnnSetFilterNdDescriptor(weight_desc_, GetCudnnDataType(dtype), CUDNN_TENSOR_NCHW, 3, filterDimA));
+  // copy input data into a block of contiguous memory
+  // hx (and cx) is at the end of inputs
+  CHECK_GT(inputs.size(), 1u + has_cell_);
+  size_t num_x = inputs.size() - has_cell_ - 1;
+  Tensor input = MergeInputs(num_x, inputs);
+  // input now holds x0 .. x(num_x - 1) back to back
 
-  
-  CUDNN_CHECK(cudnnGetRNNWorkspaceSize(ctx->cudnn_handle, rnn_desc_, seqLength_, x_descs_, &workspace_count_));
-  workspace_ = Tensor(Shape{workspace_count_}, dev, dtype);
+  if (rnn_desc_ != nullptr)
+    CHECK_EQ(dtype_, GetCudnnDataType(dtype))
+      << "Cannot change cudnn data type during training from " << dtype_
+      << " to " << GetCudnnDataType(dtype);
+  else
+    dtype_ = GetCudnnDataType(dtype);
 
-  CUDNN_CHECK(cudnnGetRNNTrainingReserveSize(ctx->cudnn_handle, rnn_desc_, seqLength_, x_descs_, &ReserveSize_));
-  reserve_ = Tensor(Shape{ReserveSize_}, dev, dtype);
-  has_init_cudnn_ = true;
-}
+  UpdateStates(num_x, inputs);
+  // CheckFowardShapes();
 
-const vector<Tensor> CudnnRNN::Forward(int flag, const vector<Tensor>& inputs) {
-  /*(seqLength, minibatch, inputSize) !!! */
-  singa::Tensor input = inputs[0];
-  singa::Tensor hx = inputs[1];
-  singa:: Tensor cx = inputs[2];
-  CHECK_EQ(input.device()->lang(), kCuda);
-  CHECK_EQ(input.device()->lang(), this->weight_.device()->lang());
-  CHECK_EQ(input.nDim(), 3u);
-  vector<Tensor> data_output;
-  if (flag & kTrain) buf_.push(input);  // buffer the input for backward
-  size_t batchsize = input.shape(1); /*(seqLength, minibatch, inputSize) !!! */
-  DataType dtype = input.data_type();
-  auto dev = input.device();
- 
-  if (!has_init_cudnn_) InitCudnn(input);
- 
-    
-  size_t numDirections;
-  if (direction_ == "cudnn_undirectional")
-    numDirections = 1;
-  else 
-    numDirections = 2;
-  
-  Shape shape{seqLength_, batchsize, hiddenSize_ * numDirections};
-  Tensor output(shape, dev, dtype);
-  Shape shape1{numLayers_, batchsize, hiddenSize_ * numDirections};
-  Tensor hy(shape1, dev, dtype);
-  Tensor cy(shape1, dev, dtype);
-  
-  output.device()->Exec([input, output, hx, hy, cx, cy, this](Context *ctx) {
-    Block *inblock = input.block(), *outblock = output.block(),
-          *wblock = this->weight_.block(), *hxblock = hx.block(), 
-          *hyblock = hy.block(), *cxblock = cx.block(), 
-          *cyblock = cy.block();
-    cudnnRNNForwardTraining(
-        ctx->cudnn_handle, this->rnn_desc_, seqLength_, this->x_descs_, 
-        inblock->data(), this->hx_desc_, hxblock->data(), this->cx_desc_, 
-        cxblock->data(), this->weight_desc_, wblock->data(), this->y_descs_, 
-        outblock->mutable_data(), this->hy_desc_, hyblock->mutable_data(), 
-        cy_desc_, cyblock->mutable_data(), this->workspace_.block()->mutable_data(), 
-        this->workspace_count_ * sizeof(float), this->reserve_.block()->mutable_data(), 
-        this->ReserveSize_ * sizeof(float));
-}, {input.block(), weight_.block(), hx.block(), cx.block()}, 
-   {output.block(), hy.block(), cy.block()}, workspace_.block());
-  buf_.push(output);
-  buf_.push(hx);
-  buf_.push(hy);  // in order to assign shape to dhy
-  buf_.push(cx);
-  buf_.push(cy);  // in order to assign shape to dcy
-  data_output.push_back(output);
-  data_output.push_back(hy);
-  data_output.push_back(cy);
-  return data_output;
+  Shape outshape{input.Size() * hidden_dim_ / input_dim_ * num_directions_};
+  Tensor output(outshape, dev, dtype);
+  // hy (and cy for lstm): (num_stacks * num_directions, batch, hidden_dim)
+  Tensor hx = inputs.at(num_x);
+  Shape state_shape{num_stacks_ * num_directions_, batch_size_, hidden_dim_};
+  Tensor hy(state_shape, dev, dtype);
+  Tensor cy, cx;
+  if (has_cell_) {
+    cx = inputs.at(num_x + 1);
+    cy.ResetLike(hy);
+  }
+
+  // null blocks (default/dummy hx, cx, cy) are passed to cudnn as nullptr
+  // training also fills reserve_space_, which Backward() reads
+  Block *inb = input.block(), *outb = output.block(),
+        *wb = this->weight_.block(), *hxb = hx.block(), *cxb = cx.block(),
+        *hyb = hy.block(), *cyb = cy.block(),
+        *wspace = this->workspace_.block(),
+        *rspace = this->reserve_space_.block();
+  if (flag & kTrain) {
+    dev->Exec(
+        [inb, outb, wb, hxb, cxb, hyb, cyb, wspace, rspace, this](Context *ctx) {
+        // clang-format off
+        CUDNN_CHECK(cudnnRNNForwardTraining(
+            ctx->cudnn_handle,
+            this->rnn_desc_,
+            this->seq_length_,
+            this->x_descs_, inb->data(),
+            this->hx_desc_, hxb == nullptr ? nullptr : hxb->data(),
+            this->cx_desc_, cxb == nullptr ? nullptr : cxb->data(),
+            this->weight_desc_, wb->data(),
+            this->y_descs_, outb->mutable_data(),
+            this->hy_desc_, hyb->mutable_data(),
+            this->cy_desc_, cyb == nullptr ? nullptr : cyb->mutable_data(),
+            wspace->mutable_data(),
+            this->workspace_.Size(), rspace->mutable_data(),
+            this->reserve_space_.Size()));
+        // clang-format on
+        },
+        {inb, wb, hxb, cxb}, {outb, hyb, cyb, wspace, rspace});
+    buf_.push(input);
+    buf_.push(output);
+    buf_.push(hx);
+    buf_.push(cx);
+  } else {
+    dev->Exec([inb, outb, wb, hxb, cxb, hyb, cyb, wspace, this](Context *ctx) {
+      // clang-format off
+      CUDNN_CHECK(cudnnRNNForwardInference(
+          ctx->cudnn_handle,
+          this->rnn_desc_,
+          this->seq_length_,
+          this->x_descs_, inb->data(),
+          this->hx_desc_, hxb == nullptr ? nullptr : hxb->data(),
+          this->cx_desc_, cxb == nullptr ? nullptr : cxb->data(),
+          this->weight_desc_, wb->data(),
+          this->y_descs_, outb->mutable_data(),
+          this->hy_desc_, hyb->mutable_data(),
+          this->cy_desc_, cyb == nullptr ? nullptr : cyb->mutable_data(),
+          wspace->mutable_data(), this->workspace_.Size()));
+      // clang-format on
+    }, {inb, wb, hxb, cxb}, {outb, hyb, cyb, wspace});
+  }
+  auto outputs =
+      SplitOutput(num_x, hidden_dim_ * num_directions_, inputs, output);
+  outputs.push_back(hy);
+  if (has_cell_) outputs.push_back(cy);
+  return outputs;
 }
 
+// TODO(wangwei) check Tensor device to be on cuda?
 const std::pair<vector<Tensor>, vector<Tensor>> CudnnRNN::Backward(
-    int flag, const vector<Tensor>& grads) {
-  CHECK(has_init_cudnn_);
-  singa::Tensor grad = grads[0];
-  singa::Tensor dhy = grads[1];
-  singa::Tensor dcy = grads[2];
-  CHECK_EQ(grad.device()->lang(), kCuda);
-  CHECK_EQ(grad.nDim(), 3u);
-  CHECK(!buf_.empty());
-  Tensor cy = buf_.top();
+    int flag, const vector<Tensor> &grads) {
+  CHECK_GE(buf_.size(), 4u) << "Backward() needs a prior Forward(kTrain)";
+  const Tensor cx = buf_.top();  // cannot use const Tensor& due to pop()
   buf_.pop();
-  CHECK(!buf_.empty());
-  Tensor cx = buf_.top();
+  const Tensor hx = buf_.top();
   buf_.pop();
-  CHECK(!buf_.empty());
-  Tensor hy = buf_.top();
+  const Tensor y = buf_.top();
   buf_.pop();
-  CHECK(!buf_.empty());
-  Tensor hx = buf_.top();
+  const Tensor x = buf_.top();
   buf_.pop();
-  CHECK(!buf_.empty());
-  Tensor src_output = buf_.top();
-  buf_.pop();
-  CHECK(!buf_.empty());
-  Tensor src_data = buf_.top();
-  buf_.pop();
-  vector<Tensor> param_grad;
-  vector<Tensor> data_grad;
-  Tensor dx;
-  dx.ResetLike(src_data);
-  Tensor dw;
-  dw.ResetLike(weight_);
-  Tensor dhx;
-  dhx.ResetLike(hx);
+
+  auto dev = y.device();
+  auto dtype = y.data_type();
+  // grads = <dy0, ..., dy(n-1), dhy[, dcy]>; dhy (and dcy) is at last
+  CHECK_GT(grads.size(), 1u + has_cell_);
+  size_t num_dy = grads.size() - has_cell_ - 1;
+  CHECK_EQ(num_dy, seq_length_);
+  const Tensor dy = MergeInputs(num_dy, grads);
+  CHECK_EQ(dy.Size(), y.Size());
+  const Tensor dhy = grads.at(num_dy);
+  Tensor dcy;
+  if (has_cell_)
+    dcy = grads.at(num_dy + 1);
+  // dw has weight_'s shape, hence weight_desc_ also describes it below
+  Shape xshape{y.Size() * input_dim_ / hidden_dim_ / num_directions_};
+  Tensor dx(xshape, dev, dtype);
+  Tensor dw(weight_.shape(), dev, dtype);
+  Shape state_shape{num_stacks_ * num_directions_, batch_size_, hidden_dim_};
+  Tensor dhx(state_shape, dev, dtype);
   Tensor dcx;
-  dcx.ResetLike(cx);
+  if (has_cell_)
+    dcx.ResetLike(dhx);
+  dw.SetValue(0.0f);
+  Block *yb = y.block(), *dyb = dy.block(), *dhyb = dhy.block(),
+        *dcyb = dcy.block(), *xb = x.block(), *cxb = cx.block(),
+        *wb = weight_.block(), *dwb = dw.block(), *hxb = hx.block(),
+        *dxb = dx.block(), *dhxb = dhx.block(), *dcxb = dcx.block(),
+        *wspace = workspace_.block(), *rspace = reserve_space_.block();
 
+  y.device()->Exec(
+      [yb, dyb, dhyb, dcyb, xb, cxb, wb, dwb, hxb, dxb, dhxb, dcxb, wspace,
+       rspace, this](Context *ctx) {
+        // clang-format off
+        CUDNN_CHECK(cudnnRNNBackwardData(
+            ctx->cudnn_handle,
+            this->rnn_desc_,
+            this->seq_length_,
+            this->y_descs_, yb->data(),
+            this->dy_descs_, dyb->data(),
+            this->dhy_desc_, dhyb == nullptr ? nullptr : dhyb->data(),
+            this->dcy_desc_, dcyb == nullptr ? nullptr : dcyb->data(),
+            this->weight_desc_, wb->data(),
+            this->hx_desc_, hxb == nullptr ? nullptr : hxb->data(),
+            this->cx_desc_, cxb == nullptr ? nullptr : cxb->data(),
+            this->dx_descs_, dxb->mutable_data(),
+            this->dhx_desc_, dhxb->mutable_data(),
+            this->dcx_desc_, dcxb == nullptr ? nullptr : dcxb->mutable_data(),
+            wspace->mutable_data(), this->workspace_.Size(),
+            rspace->mutable_data(), this->reserve_space_.Size()));
+        CUDNN_CHECK(cudnnRNNBackwardWeights(
+            ctx->cudnn_handle,
+            this->rnn_desc_,
+            this->seq_length_,
+            this->x_descs_, xb->data(),
+            this->hx_desc_, hxb == nullptr ? nullptr : hxb->data(),
+            this->y_descs_, yb->data(),
+            wspace->data(), this->workspace_.Size(),
+            this->weight_desc_, dwb->mutable_data(),
+            rspace->data(), this->reserve_space_.Size()));
+        // clang-format on
+      },
+      {yb, dyb, dhyb, dcyb, xb, cxb, wb, hxb, wspace, rspace},
+      {dxb, dwb, dhxb, dcxb, wspace, rspace});
 
-  dx.device()->Exec([grad, dw, src_data, src_output, hx, this](Context *ctx) {
-    Block *inblock = src_data.block(), *srcoutblock = src_output.block(), 
-          *dwblock = dw.block(), *hxblock = hx.block();
-    cudnnRNNBackwardWeights(
-        ctx->cudnn_handle, this->rnn_desc_, seqLength_, this->x_descs_, 
-        inblock->data(), this->hx_desc_, hxblock->data(), this->y_descs_, 
-        srcoutblock->data(), this->workspace_.block()->mutable_data(), 
-        this->workspace_count_ * sizeof(float), this->weight_desc_, 
-        dwblock->mutable_data(), this->reserve_.block()->mutable_data(), 
-        this->ReserveSize_ * sizeof(float));
-  }, {src_data.block(), hx.block(), src_output.block()}, {dw.block(), workspace_.block()}); 
-  
-  // LOG(ERROR) << "backward src";
-  dx.device()->Exec([grad, dw, src_output, dx, cy, cx, hy, hx, dhy, dcy, dhx, dcx, this](Context *ctx) {
-    Block *srcoutblock = src_output.block(), *wblock = this->weight_.block(), *dxblock = dx.block(),
-          *dyblock = grad.block(), *cxblock = cx.block(), *hxblock = hx.block(), *dhyblock = dhy.block(),
-          *dcyblock = dcy.block(), *dhxblock = dhx.block(), *dcxblock = dcx.block();
-    cudnnRNNBackwardData(
-        ctx->cudnn_handle, this->rnn_desc_, seqLength_, this->y_descs_, srcoutblock->data(), 
-        this->y_descs_, dyblock->data(), this->hy_desc_, dhyblock->data(), 
-        this->cy_desc_, dcyblock->data(), this->weight_desc_, wblock->data(), 
-        this->hx_desc_, hxblock->data(), this->cx_desc_, cxblock->data(), 
-        this->x_descs_, dxblock->mutable_data(), this->hx_desc_, dhxblock->mutable_data(), 
-        this->cx_desc_, dcxblock->mutable_data(), this->workspace_.block()->mutable_data(), 
-        this->workspace_count_ * sizeof(float), this->reserve_.block()->mutable_data(), 
-        this->ReserveSize_ * sizeof(float));
-  }, {hx.block(), src_output.block(), grad.block(), grad.block(), dhy.block(), dcy.block(), 
-      this->weight_.block(), hx.block(), cx.block()}, 
-     {dx.block(), dhx.block(), dcx.block(), reserve_.block(), workspace_.block()}); 
-  param_grad.push_back(dw);
-  data_grad.push_back(dx);
-  data_grad.push_back(dhx);
-  data_grad.push_back(dcx);
-  return std::make_pair(data_grad, param_grad);
+  vector <Tensor> param_grad{dw};
+  auto data_grads = SplitOutput(num_dy, input_dim_, grads, dx);
+  data_grads.push_back(dhx);
+  if (has_cell_)
+    data_grads.push_back(dcx);
+  return std::make_pair(data_grads, param_grad);
 }
 
 }  // namespace singa

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/8e0b1083/src/model/layer/cudnn_rnn.h
----------------------------------------------------------------------
diff --git a/src/model/layer/cudnn_rnn.h b/src/model/layer/cudnn_rnn.h
index b1e9f43..d2f8db5 100644
--- a/src/model/layer/cudnn_rnn.h
+++ b/src/model/layer/cudnn_rnn.h
@@ -20,6 +20,7 @@
 #define SRC_MODEL_LAYER_CUDNN_RNN_H_
 #include "singa/singa_config.h"
 #ifdef USE_CUDNN
+#if CUDNN_VERSION_MAJOR >= 5
 #include <string>
 #include <utility>
 #include <vector>
@@ -41,45 +42,46 @@ class CudnnRNN : public RNN {
   const std::string layer_type() const override { return "CudnnRNN"; }
 
   const vector<Tensor> Forward(int flag, const vector<Tensor>& inputs) override;
-  const std::pair<vector<Tensor>, vector<Tensor>> Backward(int flag, const vector<Tensor>& grads) override;
-
-  /// \copydoc Layer::Setup(const LayerConf&);
-  void Setup(const Shape& in_sample, const LayerConf &conf) override;
+  const std::pair<vector<Tensor>, vector<Tensor>> Backward(
+      int flag, const vector<Tensor>& grads) override;
 
   void ToDevice(std::shared_ptr<Device> device) override;
 
-  size_t workspace_byte_limit() { return workspace_byte_limit_; }
-  // string prefer() { return prefer_; }
-  string inputMode() const { return inputMode_; }
-  string direction() const { return direction_; }
-  string mode() const { return mode_; }
-
- protected:
-  /// Init cudnn related data structures.
-  void InitCudnn(const Tensor& input);
+  void SetRNNDescriptor(shared_ptr<Device> dev);
+  void ResetHiddenAndCellDescriptors(size_t batch_size);
+  void DestroyIODescriptors();
+  void UpdateIODescriptors(size_t num, const vector<Tensor>& inputs);
+  void UpdateSpaces(size_t num, shared_ptr<Device> dev);
+  void UpdateStates(size_t num, const vector<Tensor>& inputs);
+  Tensor MergeInputs(size_t num, const vector<Tensor>& in);
+  vector<Tensor> SplitOutput(size_t num, size_t dim, const vector<Tensor>& in,
+                             const Tensor output);
 
  protected:
-  bool has_init_cudnn_ = false;
   cudnnTensorDescriptor_t* x_descs_ = nullptr;
+  cudnnTensorDescriptor_t* dx_descs_ = nullptr;
   cudnnTensorDescriptor_t* y_descs_ = nullptr;
+  cudnnTensorDescriptor_t* dy_descs_ = nullptr;
   cudnnTensorDescriptor_t hx_desc_ = nullptr;
+  cudnnTensorDescriptor_t dhx_desc_ = nullptr;
   cudnnTensorDescriptor_t cx_desc_ = nullptr;
+  cudnnTensorDescriptor_t dcx_desc_ = nullptr;
   cudnnTensorDescriptor_t hy_desc_ = nullptr;
+  cudnnTensorDescriptor_t dhy_desc_ = nullptr;
   cudnnTensorDescriptor_t cy_desc_ = nullptr;
+  cudnnTensorDescriptor_t dcy_desc_ = nullptr;
   cudnnFilterDescriptor_t weight_desc_ = nullptr;
+  cudnnFilterDescriptor_t dweight_desc_ = nullptr;
   cudnnRNNDescriptor_t rnn_desc_ = nullptr;
   cudnnDropoutDescriptor_t dropout_desc_ = nullptr;
-  size_t workspace_byte_limit_, workspace_count_;
-  size_t ReserveSize_;
+  cudnnDataType_t dtype_ = CUDNN_DATA_FLOAT;
   Tensor workspace_;
-  string inputMode_;
-  string direction_;
-  string mode_;
-  Tensor reserve_;
-  Tensor dropoutStates_;
+  Tensor reserve_space_;
+  Tensor dropout_state_;
 };
 
 }  // namespace singa
 
+#endif  // CUDNN_VERSION_MAJOR >= 5
 #endif  // USE_CUDNN
 #endif  // SRC_MODEL_LAYER_CUDNN_RNN_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/8e0b1083/src/model/layer/rnn.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/rnn.cc b/src/model/layer/rnn.cc
index 493a5e4..6b831a7 100644
--- a/src/model/layer/rnn.cc
+++ b/src/model/layer/rnn.cc
@@ -19,20 +19,64 @@
 #include "./rnn.h"
 #include <vector>
 #include "singa/model/layer.h"
+#include "singa/utils/string.h"
 
 namespace singa {
 
 void RNN::Setup(const Shape& in_sample, const LayerConf &conf) {
   Layer::Setup(in_sample, conf);
+
   RNNConf rnn_conf = conf.rnn_conf();
-  hiddenSize_ = rnn_conf.hiddensize();
-  CHECK_GT(hiddenSize_, 0u);
+  hidden_dim_ = rnn_conf.hidden_dim();
+  CHECK_GT(hidden_dim_, 0u);
+  num_stacks_ = rnn_conf.num_stacks();
+  CHECK_GT(num_stacks_, 0u);
+  input_dim_ = Product(in_sample);
+  CHECK_GT(input_dim_, 0u);
+  dropout_ = rnn_conf.dropout();
+  CHECK_GE(dropout_, 0);
 
-  numLayers_ = rnn_conf.numlayers();
-  CHECK_GT(numLayers_, 0u);
+  input_mode_ = ToLowerCase(rnn_conf.input_mode());
+  CHECK(input_mode_ == "linear" || input_mode_ == "skip")
+      << "Input mode of " << input_mode_ << " is not supported; Please use "
+      << "'linear' and 'skip'";
 
-  dropout_ = rnn_conf.dropout();
-  CHECK_GE(dropout_, 0u);
+  direction_ = ToLowerCase(rnn_conf.direction());
+  if (direction_ == "unidirectional")
+    num_directions_ = 1;
+  else if (direction_ == "bidirectional")
+    num_directions_ = 2;
+  else
+    LOG(FATAL) << "Direction of " << direction_
+      << " is not supported; Please use unidirectional or bidirectional";
+
+  rnn_mode_ = ToLowerCase(rnn_conf.rnn_mode());
+  if (rnn_mode_ == "lstm") {
+    has_cell_ = true;
+  } else if (rnn_mode_ !="relu" && rnn_mode_ != "tanh" && rnn_mode_ != "gru") {
+    LOG(FATAL) << "RNN memory unit (mode) of " << rnn_mode_
+      << " is not supported Please use 'relu', 'tanh', 'lstm' and 'gru'";
+  }
+  // mult: number of weight/bias sets per layer, i.e. 1 for relu/tanh,
+  // 4 gates for lstm, 3 gates for gru; doubled for bidirectional rnns
+  int mult = 1;
+  if (rnn_mode_ == "relu" || rnn_mode_ == "tanh")
+    mult *= 1;
+  else if (rnn_mode_ == "lstm")
+    mult *= 4;
+  else if (rnn_mode_ == "gru")
+    mult *= 3;
+  if (direction_ == "bidirectional")
+    mult *= 2;
+  // each set has (hidden x in) + (hidden x hidden) weights and 2 biases
+  size_t weight_size = 0;
+  for (size_t i = 0; i < num_stacks_; i++) {
+    size_t dim = hidden_dim_ * (input_dim_ + hidden_dim_ + 2);
+    if (i > 0)  // upper layers take all directions' outputs as input
+      dim = hidden_dim_ * (hidden_dim_ * num_directions_ + hidden_dim_ + 2);
+    weight_size += mult * dim;
+  }
+  weight_.Reshape(Shape{weight_size});
 }
 
 const vector<Tensor> RNN::Forward(int flag, const vector<Tensor>& inputs) {
@@ -40,7 +84,8 @@ const vector<Tensor> RNN::Forward(int flag, const vector<Tensor>& inputs) {
   return data_output;
 }
 
-const std::pair<vector<Tensor>, vector<Tensor>> RNN::Backward(int flag, const vector<Tensor>& grads) {
+const std::pair<vector<Tensor>, vector<Tensor>> RNN::Backward(int flag,
+    const vector<Tensor>& grads) {
   vector<Tensor> param_grad;
   vector<Tensor> data_grad;
   return std::make_pair(data_grad, param_grad);

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/8e0b1083/src/model/layer/rnn.h
----------------------------------------------------------------------
diff --git a/src/model/layer/rnn.h b/src/model/layer/rnn.h
index ec5a35d..3750021 100644
--- a/src/model/layer/rnn.h
+++ b/src/model/layer/rnn.h
@@ -47,38 +47,37 @@ class RNN : public Layer {
   const std::pair<vector<Tensor>, vector<Tensor>> Backward(
       int flag, const vector<Tensor>& grads) override;
 
-
-  size_t hiddenSize() const { return hiddenSize_; }
-  size_t numLayers() const { return numLayers_; }
-  size_t weightSize() const { return weightSize_; }
-  float dropout() const { return dropout_; }
-  
   void set_weight(Tensor w) {
     weight_.ResetLike(w);
     weight_.CopyData(w);
   }
-
+  const vector<Tensor> param_values() override {
+    return vector<Tensor>{weight_};
+  }
 
   void ToDevice(std::shared_ptr<Device> device) override;
   /// Return the internal state stack, which should be empty at the beginning
-  /// of
-  /// one iteration.
+  /// of one iteration.
   // std::stack<Tensor> states() const { return states_; }
 
+  string input_mode() const { return input_mode_; }
+  string direction() const { return direction_; }
+  string rnn_mode() const { return rnn_mode_; }
+
  protected:
   /// Storing input or output from Forward(), which are used in Backward().
   /// Rules:
   /// 1. push the 'input' or 'output' into states_ if the flag of Forward() is
   ///    for kTrain and 'input' or 'output' is necessary for Backward().
   /// 2. pop data out in Backward().
-  // std::stack<Tensor*> states_;
   std::stack<Tensor> buf_;
-  size_t hiddenSize_;
-  size_t numLayers_;
-  size_t numLinearLayer_;
-  size_t seqLength_;
-  size_t weightSize_; /*all the weights and biases*/
-  float dropout_;
+  bool has_cell_ = false;
+  size_t num_directions_ = 1;
+  size_t input_dim_ = 0, hidden_dim_ = 0, num_stacks_ = 0, seq_length_ = 0;
+  size_t batch_size_ = 0;
+  size_t seed_ = 0x1234567;
+  float dropout_ = 0.0f;
+  string input_mode_, direction_, rnn_mode_;
   Tensor weight_;
 };
 }  // namespace singa

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/8e0b1083/src/model/optimizer/adagrad.cc
----------------------------------------------------------------------
diff --git a/src/model/optimizer/adagrad.cc b/src/model/optimizer/adagrad.cc
index 3ed1855..cdb3fac 100644
--- a/src/model/optimizer/adagrad.cc
+++ b/src/model/optimizer/adagrad.cc
@@ -27,8 +27,10 @@ void AdaGrad::Setup(const OptimizerConf& conf) { delta_ = conf.delta(); }
 // value = value - lr*grad/sqrt(history+delta)
 void AdaGrad::Apply(int step, float lr, const string& name, const Tensor& grad,
                     Tensor& value) {
-  if (history_gradient_.find(name) == history_gradient_.end())
+  if (history_gradient_.find(name) == history_gradient_.end()) {
     history_gradient_[name].ResetLike(value);
+    history_gradient_[name].SetValue(0.0f);
+  }
   Tensor& history = history_gradient_[name];
   Tensor tmp = Square(grad);
   history += tmp;

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/8e0b1083/src/model/optimizer/nesterov.cc
----------------------------------------------------------------------
diff --git a/src/model/optimizer/nesterov.cc b/src/model/optimizer/nesterov.cc
index e5354b1..051499b 100644
--- a/src/model/optimizer/nesterov.cc
+++ b/src/model/optimizer/nesterov.cc
@@ -34,8 +34,10 @@ void Nesterov::Apply(int step, float lr, const string& name, const Tensor& grad,
                      Tensor& value) {
   if (momentum_generator_) {
     float mom = momentum_generator_(step);
-    if (history_gradient_.find(name) == history_gradient_.end())
+    if (history_gradient_.find(name) == history_gradient_.end()) {
       history_gradient_[name].ResetLike(value);
+      history_gradient_[name].SetValue(0.0f);
+    }
     Tensor& history = history_gradient_[name];
     Tensor tmp = history.Clone();
     history *= mom;

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/8e0b1083/src/model/optimizer/rmsprop.cc
----------------------------------------------------------------------
diff --git a/src/model/optimizer/rmsprop.cc b/src/model/optimizer/rmsprop.cc
index 6d77ccd..13e2a75 100644
--- a/src/model/optimizer/rmsprop.cc
+++ b/src/model/optimizer/rmsprop.cc
@@ -32,6 +32,7 @@ void RMSProp::Apply(int step, float lr, const string& name, const Tensor& grad,
                     Tensor& value) {
   if (history_gradient_.find(name) == history_gradient_.end()) {
     history_gradient_[name].ResetLike(value);
+    history_gradient_[name].SetValue(0.0f);
   }
   Tensor& history = history_gradient_[name];
   history *= rho_;

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/8e0b1083/src/model/optimizer/sgd.cc
----------------------------------------------------------------------
diff --git a/src/model/optimizer/sgd.cc b/src/model/optimizer/sgd.cc
index 2797fc6..d78d5b8 100644
--- a/src/model/optimizer/sgd.cc
+++ b/src/model/optimizer/sgd.cc
@@ -36,8 +36,10 @@ void SGD::Apply(int step, float lr, const string& name, const Tensor& grad,
   if (momentum_generator_) {
     float mom = momentum_generator_(step);
     if (mom != 0) {
-      if (history_gradient_.find(name) == history_gradient_.end())
+      if (history_gradient_.find(name) == history_gradient_.end()) {
         history_gradient_[name].ResetLike(value);
+        history_gradient_[name].SetValue(0.0f);
+      }
       Tensor& history = history_gradient_[name];
       history *= mom;
       Axpy(lr, grad, &history);

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/8e0b1083/src/proto/model.proto
----------------------------------------------------------------------
diff --git a/src/proto/model.proto b/src/proto/model.proto
index d8193f1..31ebfc3 100644
--- a/src/proto/model.proto
+++ b/src/proto/model.proto
@@ -393,19 +393,19 @@ message ConvolutionConf {
 }
 
 message RNNConf {
-  optional uint32 hiddensize = 1; // The number of hiddensize
-  optional uint32 numlayers = 2; // The number of stacked RNN layers
+  optional uint32 hidden_dim = 1; // Dimension (size) of the hidden state
+  optional uint32 num_stacks = 2; // The number of stacked RNN layers
   optional float dropout = 3 [default = 0];
-  optional int32 workspace_byte_limit = 50 [default = 512];
+  optional bool remember_state = 4 [default = false];
   // cudnn inputmode
-  // options: "cudnn_linear_input", "cudnn_skip_input"
-  optional string inputmode = 51 [default = "cudnn_linear_input"];
+  // options: "linear", "skip"
+  optional string input_mode = 7 [default = "linear"];
   // cudnn direction
-  // options: "cudnn_undirectional", "cudnn_bidirectional"
-  optional string direction = 52 [default = "cudnn_undirectional"];
+  // options: "unidirectional", "bidirectional"
+  optional string direction = 8 [default = "unidirectional"];
   // cudnn RNN mode
-  // options: "cudnn_rnn_relu", "cudnn_rnn_tanh", "cudnn_lstm", "cudnn_gru"
-  optional string mode = 53 [default = "cudnn_rnn_relu"];
+  // options: "relu", "tanh", "lstm", "gru"
+  optional string rnn_mode = 9 [default = "relu"];
 }
 
 /*

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/8e0b1083/test/singa/test_cudnn_rnn.cc
----------------------------------------------------------------------
diff --git a/test/singa/test_cudnn_rnn.cc b/test/singa/test_cudnn_rnn.cc
index 1a79d7b..ebbf0aa 100644
--- a/test/singa/test_cudnn_rnn.cc
+++ b/test/singa/test_cudnn_rnn.cc
@@ -26,187 +26,154 @@
 
 using singa::CudnnRNN;
 using singa::Shape;
-TEST(CudnnRNN, Setup) {
+using singa::Tensor;
+// Shared fixture: a single-stack, unidirectional tanh RNN with a
+// hidden_dim-wide hidden state and linear input mode.
+// NOTE(review): dropout is set to 1 with a single stack -- confirm intended.
+class TestCudnnRNN : public ::testing::Test {
+  protected:
+    virtual void SetUp() {
+      singa::RNNConf *rnnconf = conf.mutable_rnn_conf();
+      rnnconf->set_hidden_dim(hidden_dim);
+      rnnconf->set_num_stacks(1);
+      rnnconf->set_dropout(1);
+      rnnconf->set_input_mode("linear");
+      rnnconf->set_direction("unidirectional");
+      rnnconf->set_rnn_mode("tanh");
+    }
+    singa::LayerConf conf;
+    size_t hidden_dim = 4;
+};
+
+// Setup must create one weight tensor of hidden_dim * (2 + hidden_dim + 2)
+// elements for input dim 2 -- presumably the input matrix, the recurrent
+// matrix and two bias vectors.
+TEST_F(TestCudnnRNN, Setup) {
   CudnnRNN rnn;
   EXPECT_EQ("CudnnRNN", rnn.layer_type());
-
-  singa::LayerConf conf;
-  singa::RNNConf *rnnconf = conf.mutable_rnn_conf();
-  rnnconf->set_hiddensize(2);
-  rnnconf->set_numlayers(1);
-  rnnconf->set_dropout(0); 
-  rnnconf->set_inputmode("cudnn_linear_input");
-  rnnconf->set_direction("cudnn_undirectional");
-  rnnconf->set_mode("cudnn_rnn_tanh");
-  // MB
-  rnnconf->set_workspace_byte_limit(256);
-  rnn.Setup(Shape{4, 1, 2}, conf);
-
-  EXPECT_EQ(2u, rnn.hiddenSize());
-  EXPECT_EQ(1u, rnn.numLayers());
-  EXPECT_EQ(0u, rnn.dropout());
-  EXPECT_EQ("cudnn_linear_input", rnn.inputMode());
-  EXPECT_EQ("cudnn_undirectional", rnn.direction());
-  EXPECT_EQ("cudnn_rnn_tanh", rnn.mode());
-  EXPECT_EQ(256u << 20, rnn.workspace_byte_limit());
+  rnn.Setup(Shape{2}, conf);
+  auto weight = rnn.param_values().at(0);
+  EXPECT_EQ(weight.Size(), hidden_dim * (2 + hidden_dim + 2));
 }
 
-TEST(CudnnRNN, Forward) {
+// Checks each forward output y_t against a hand-computed tanh recurrence.
+TEST_F(TestCudnnRNN, Forward) {
   auto cuda = std::make_shared<singa::CudaGPU>();
   const size_t seqLength = 4, batchsize = 1, dim = 2;
-  const size_t numLayers = 1, hiddensize = 2, numDirections = 1;
   const float x[seqLength * batchsize * dim] = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
                                            1.0f, 1.0f, 1.0f};
-  singa::Tensor in(singa::Shape{seqLength, batchsize, dim}, cuda);
-  in.CopyDataFromHostPtr(x, seqLength * batchsize * dim);
 
+  // One (batchsize, dim) tensor per time step; hx is appended below.
+  vector<Tensor> inputs;
+  for (size_t i = 0; i < seqLength; i++) {
+    Tensor t(Shape{batchsize, dim}, cuda);
+    t.CopyDataFromHostPtr(x + i * t.Size(), t.Size());
+    inputs.push_back(t);
+  }
 
-  
-  const float hx_data[numLayers * batchsize * hiddensize * numDirections] = {1.0f, 1.0f};
-  singa::Tensor hx(singa::Shape{numLayers, batchsize, hiddensize * numDirections}, cuda);
-  hx.CopyDataFromHostPtr(hx_data, numLayers * batchsize * hiddensize * numDirections);
+  // Dummy hx (no shape/device/dtype): the layer is expected to use zeros.
+  singa::Tensor hx;
+  inputs.push_back(hx);
 
-  const float cx_data[numLayers * batchsize * hiddensize * numDirections] = {1.0f, 1.0f};
-  singa::Tensor cx(singa::Shape{numLayers, batchsize, hiddensize * numDirections}, cuda);
-  cx.CopyDataFromHostPtr(cx_data, numLayers * batchsize * hiddensize * numDirections);
-  
   CudnnRNN rnn;
-  
-  singa::LayerConf conf;
-  singa::RNNConf *rnnconf = conf.mutable_rnn_conf();
-  rnnconf->set_hiddensize(2);
-  rnnconf->set_numlayers(1);
-  rnnconf->set_dropout(0);
-  rnnconf->set_inputmode("cudnn_linear_input");
-  rnnconf->set_direction("cudnn_undirectional");
-  rnnconf->set_mode("cudnn_rnn_tanh");
-  // MB
-  rnnconf->set_workspace_byte_limit(256);
-  rnn.Setup(Shape{4, 1, 2}, conf);
- 
-  
-  size_t weightSize = rnn.weightSize();
+  rnn.Setup(Shape{dim}, conf);
+  // The weight lives on the CPU until the layer is moved to the GPU.
+  rnn.ToDevice(cuda);
+
+  auto weight = rnn.param_values().at(0);
+  size_t weightSize = weight.Size();
   float we[weightSize];
+  // Every weight and bias gets the same value so the expected outputs can be
+  // recomputed by hand below.
+  float wvalue = 0.1f;
   for (size_t i = 0; i < weightSize; i++)
-    we[i] = 1.0f;
-  singa::Tensor weight(singa::Shape{weightSize, 1, 1}, cuda);
+    we[i] = wvalue;
   weight.CopyDataFromHostPtr(we, weightSize);
-  rnn.set_weight(weight);
- 
-  vector<singa::Tensor> input_array;
-  input_array.push_back(in);
-  input_array.push_back(hx);
-  input_array.push_back(cx);
-  const auto ret = rnn.Forward(singa::kTrain, input_array);
-  // singa::CppCPU host(0, 1);
-  singa::Tensor out1 = ret[0];
-  out1.ToHost();
-  const float *outptr1 = out1.data<float>();
-  EXPECT_EQ(8u, out1.Size());
-  EXPECT_NEAR(1.0f, outptr1[0], 0.0001); // tanh 6
-  EXPECT_NEAR(1.0f, outptr1[1], 0.0001);
-  EXPECT_NEAR(1.0f, outptr1[2], 0.0001);
-  EXPECT_NEAR(1.0f, outptr1[3], 0.0001);
-  EXPECT_NEAR(1.0f, outptr1[4], 0.0001);
-  EXPECT_NEAR(1.0f, outptr1[5], 0.0001);
-  EXPECT_NEAR(1.0f, outptr1[6], 0.0001);
-  EXPECT_NEAR(1.0f, outptr1[7], 0.0001);
-
-  singa::Tensor hy1 = ret[1];
-  hy1.ToHost();
-  const float *hyptr1 = hy1.data<float>();
-  EXPECT_EQ(2u, hy1.Size());
-  EXPECT_NEAR(1.0f, hyptr1[0], 0.0001);
-  EXPECT_NEAR(1.0f, hyptr1[1], 0.0001);
+
+  const auto ret = rnn.Forward(singa::kEval, inputs);
+  // y_0 .. y_{seqLength-1} followed by hy (tanh RNNs produce no cy).
+  EXPECT_EQ(ret.size(), seqLength + 1);
+  // Reference: h_t[j] = tanh(sum_k x_t[k]*w + w + sum_k h_{t-1}[k]*w + w),
+  // starting from h = 0, i.e. two bias terms per unit.
+  vector<float> hxptr(hidden_dim, 0.0f);
+  for (size_t i = 0; i < seqLength; i++) {
+    auto y = ret[i];
+    y.ToHost();
+    auto yptr = y.data<float>();
+    vector<float> tmp;
+    for (size_t j = 0; j < hidden_dim; j++) {
+      float ty = 0;
+      for (size_t k = 0; k < dim; k++) {
+        ty += x[i * dim + k] * wvalue;
+      }
+      ty += wvalue;
+      for (size_t k = 0; k < hidden_dim; k++) {
+        ty += hxptr[k] * wvalue;
+      }
+      ty += wvalue;
+      ty = tanh(ty);
+      EXPECT_NEAR(ty, yptr[j], 1e-4);
+      tmp.push_back(ty);
+    }
+    std::copy(tmp.begin(), tmp.end(), hxptr.begin());
+  }
 }
 
-TEST(CudnnRNN, Backward) {
-  // src_data
+// Checks the input gradients dx_t against hand-computed backpropagation
+// through time for the tanh RNN with all weights equal to wvalue.
+TEST_F(TestCudnnRNN, Backward) {
   auto cuda = std::make_shared<singa::CudaGPU>();
   const size_t seqLength = 4, batchsize = 1, dim = 2;
-  const size_t numLayers = 1, hiddensize = 2, numDirections = 1;
   const float x[seqLength * batchsize * dim] = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
                                            1.0f, 1.0f, 1.0f};
-  singa::Tensor in(singa::Shape{seqLength, batchsize, dim}, cuda);
-  in.CopyDataFromHostPtr(x, seqLength * batchsize * dim);
 
-  const float hx_data[numLayers * batchsize * hiddensize * numDirections] = {1.0f, 1.0f};
-  singa::Tensor hx(singa::Shape{numLayers, batchsize, hiddensize * numDirections}, cuda);
-  hx.CopyDataFromHostPtr(hx_data, numLayers * batchsize * hiddensize * numDirections);
+  vector<Tensor> inputs;
+  for (size_t i = 0; i < seqLength; i++) {
+    Tensor t(Shape{batchsize, dim}, cuda);
+    t.CopyDataFromHostPtr(x + i * t.Size(), t.Size());
+    inputs.push_back(t);
+  }
 
-  const float cx_data[numLayers * batchsize * hiddensize * numDirections] = {1.0f, 1.0f};
-  singa::Tensor cx(singa::Shape{numLayers, batchsize, hiddensize * numDirections}, cuda);
-  cx.CopyDataFromHostPtr(cx_data, numLayers * batchsize * hiddensize * numDirections);
+  // Dummy hx: the layer is expected to use zeros.
+  singa::Tensor hx;
+  inputs.push_back(hx);
 
   CudnnRNN rnn;
+  rnn.Setup(Shape{dim}, conf);
+  rnn.ToDevice(cuda);
 
-  singa::LayerConf conf;
-  singa::RNNConf *rnnconf = conf.mutable_rnn_conf();
-  rnnconf->set_hiddensize(2);
-  rnnconf->set_numlayers(1);
-  rnnconf->set_dropout(0);
-  rnnconf->set_inputmode("cudnn_linear_input");
-  rnnconf->set_direction("cudnn_undirectional");
-  rnnconf->set_mode("cudnn_rnn_tanh");
-  // MB
-  rnnconf->set_workspace_byte_limit(256);
-  rnn.Setup(Shape{4, 1, 2}, conf);
-
-  size_t weightSize = rnn.weightSize();
+  auto weight = rnn.param_values().at(0);
+  size_t weightSize = weight.Size();
   float we[weightSize];
+  float wvalue = 0.1f;
   for (size_t i = 0; i < weightSize; i++)
-    we[i] = 1.0f;
-  singa::Tensor weight(singa::Shape{weightSize, 1, 1}, cuda);
+    we[i] = wvalue;
   weight.CopyDataFromHostPtr(we, weightSize);
-  rnn.set_weight(weight);
-
-
-  vector<singa::Tensor> input_array;
-  input_array.push_back(in);
-  input_array.push_back(hx);
-  input_array.push_back(cx);
-  const auto ret = rnn.Forward(singa::kTrain, input_array);
-
-  // grad
-  const float dy[seqLength * batchsize * hiddensize * numDirections] = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f};
-  singa::Tensor grad(singa::Shape{seqLength, batchsize, hiddensize * numDirections},
-                     cuda);
-  grad.CopyDataFromHostPtr(dy, seqLength * batchsize * hiddensize * numDirections);
-
-  const float dhy_data[numLayers * batchsize * hiddensize * numDirections] = {1.0f, 1.0f};
-  singa::Tensor dhy(singa::Shape{numLayers, batchsize, hiddensize * numDirections},
-                     cuda);
-  dhy.CopyDataFromHostPtr(dhy_data, numLayers * batchsize * hiddensize * numDirections);
-
-  const float dcy_data[numLayers * batchsize * hiddensize * numDirections] = {1.0f, 1.0f};
-  singa::Tensor dcy(singa::Shape{numLayers, batchsize, hiddensize * numDirections},
-                     cuda);
-  dcy.CopyDataFromHostPtr(dcy_data, numLayers * batchsize * hiddensize * numDirections);
-
-  vector<singa::Tensor> grad_array;
-  grad_array.push_back(grad);
-  grad_array.push_back(dhy);
-  grad_array.push_back(dcy);
-  const auto ret_back = rnn.Backward(singa::kTrain, grad_array);
-  // singa::CppCPU host(0, 1);
-  singa::Tensor in_grad = ret_back.first[0];
-  in_grad.ToHost();
-  const float *dx = in_grad.data<float>();
-  EXPECT_EQ(8u, in_grad.Size());
-  EXPECT_NEAR(0.14, dx[0], 0.0001);
-  EXPECT_NEAR(0.14, dx[1], 0.0001);
-  EXPECT_NEAR(0.1596, dx[2], 0.0001);
-  EXPECT_NEAR(0.1596, dx[3], 0.0001);
-  EXPECT_NEAR(0.1623, dx[4], 0.0001);
-  EXPECT_NEAR(0.1623, dx[5], 0.0001);
-  EXPECT_NEAR(0.1627, dx[6], 0.0001);
-  EXPECT_NEAR(0.1627, dx[7], 0.0001);
-
-  singa::Tensor dhx_grad = ret_back.first[1];
-  dhx_grad.ToHost();
-  const float *dhx = dhx_grad.data<float>();
-  EXPECT_EQ(2u, dhx_grad.Size());
-  EXPECT_NEAR(0.1627, dhx[0], 0.0001);
-  EXPECT_NEAR(0.1627, dhx[1], 0.0001);
+
+  const auto outs = rnn.Forward(singa::kTrain, inputs);
+
+  // dy values 0.0, 0.1, 0.2, ... laid out step by step. A std::vector is
+  // used because hidden_dim is not a compile-time constant (no VLA).
+  vector<float> dyptr(seqLength * batchsize * hidden_dim);
+  for (size_t i = 0; i < dyptr.size(); i++)
+    dyptr[i] = i * 0.1f;
+  vector<Tensor> grads;
+  for (size_t i = 0; i < seqLength; i++) {
+    Tensor dy(Shape{batchsize, hidden_dim}, cuda);
+    dy.CopyDataFromHostPtr(dyptr.data() + i * dy.Size(), dy.Size());
+    grads.push_back(dy);
+  }
+  // Dummy dhy: the layer is expected to use zeros.
+  Tensor dhy;
+  grads.push_back(dhy);
+  // Running gradient w.r.t. the hidden state, propagated backwards in time.
+  vector<float> dhyptr(hidden_dim, 0.0f);
+  const auto ret = rnn.Backward(singa::kTrain, grads);
+  // Walk back through every time step, including step 0 (the previous
+  // `i > 0` condition never checked dx_0).
+  for (size_t i = seqLength; i-- > 0;) {
+    auto dx = ret.first[i];
+    auto y = outs[i].Clone();
+    y.ToHost();
+    dx.ToHost();
+    auto dxptr = dx.data<float>();
+    auto yptr = y.data<float>();
+    // d(pre-activation) = (dy_t + dh_t) * tanh'(.) = (...) * (1 - y_t^2)
+    for (size_t j = 0; j < hidden_dim; j++) {
+      dhyptr[j] += dyptr[i * hidden_dim + j];
+      dhyptr[j] *= 1 - yptr[j] * yptr[j];
+    }
+    // dx_t = W^T * d(pre-activation), with every entry of W equal to wvalue.
+    for (size_t k = 0; k < dim; k++) {
+      float tdx = 0;
+      for (size_t j = 0; j < hidden_dim; j++) {
+        tdx += dhyptr[j] * wvalue;
+      }
+      EXPECT_NEAR(tdx, dxptr[k], 1e-4);
+    }
+    // dh_{t-1} = R^T * d(pre-activation), with every entry of R equal to wvalue.
+    vector<float> tmp;
+    for (size_t k = 0; k < hidden_dim; k++) {
+      float tdhy = 0;
+      for (size_t j = 0; j < hidden_dim; j++) {
+        tdhy += dhyptr[j] * wvalue;
+      }
+      tmp.push_back(tdhy);
+    }
+    std::copy(tmp.begin(), tmp.end(), dhyptr.begin());
+  }
 }
 #endif  // USE_CUDNN


[2/3] incubator-singa git commit: SINGA-218 Implementation for RNN CUDNN version

Posted by zh...@apache.org.
SINGA-218 Implementation for RNN CUDNN version

- cudnn rnn implementation (cudnn_rnn.h, cudnn_rnn.cc, rnn.cc, rnn.h, test_cudnn_rnn.cc).
- The weight shapes are now calculated manually instead of using the API provided by CUDNN.
- Test for RNN_cudnn_Tanh (unidirectional, 1 hidden layer).


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/c51f9448
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/c51f9448
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/c51f9448

Branch: refs/heads/dev
Commit: c51f9448284ea905db592fb0c09d2bb0e8801828
Parents: 28678ae
Author: zhaojing <zh...@comp.nus.edu.sg>
Authored: Sat Jun 25 00:15:06 2016 +0800
Committer: Wei Wang <wa...@gmail.com>
Committed: Wed Aug 10 00:43:11 2016 +0800

----------------------------------------------------------------------
 src/model/layer/cudnn_rnn.cc | 328 ++++++++++++++++++++++++++++++++++++++
 src/model/layer/cudnn_rnn.h  |  85 ++++++++++
 src/model/layer/rnn.cc       |  53 ++++++
 src/model/layer/rnn.h        |  31 +++-
 src/proto/model.proto        |  17 ++
 test/singa/test_cudnn_rnn.cc | 212 ++++++++++++++++++++++++
 6 files changed, 720 insertions(+), 6 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/c51f9448/src/model/layer/cudnn_rnn.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/cudnn_rnn.cc b/src/model/layer/cudnn_rnn.cc
new file mode 100644
index 0000000..6f04e5c
--- /dev/null
+++ b/src/model/layer/cudnn_rnn.cc
@@ -0,0 +1,328 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "./cudnn_rnn.h"
+#ifdef USE_CUDNN
+#include <cudnn.h>
+#include <chrono>
+#include "./cudnn_utils.h"
+#include "singa/utils/logging.h"
+
+namespace singa {
+/// Releases every cuDNN descriptor created by InitCudnn(), together with the
+/// per-time-step descriptor arrays it allocated with new[]. Descriptors that
+/// were never created are still nullptr (their in-class initial value) and
+/// are skipped.
+CudnnRNN::~CudnnRNN() {
+  if (weight_desc_ != nullptr)
+    CUDNN_CHECK(cudnnDestroyFilterDescriptor(weight_desc_));
+  if (dropout_desc_ != nullptr)
+    CUDNN_CHECK(cudnnDestroyDropoutDescriptor(dropout_desc_));
+  if (rnn_desc_ != nullptr)
+    CUDNN_CHECK(cudnnDestroyRNNDescriptor(rnn_desc_));
+  if (x_descs_ != nullptr) {
+    for (size_t i = 0; i < seqLength_; i++)
+      CUDNN_CHECK(cudnnDestroyTensorDescriptor(x_descs_[i]));
+    delete[] x_descs_;  // allocated with new[] in InitCudnn
+    x_descs_ = nullptr;
+  }
+  if (y_descs_ != nullptr) {
+    for (size_t i = 0; i < seqLength_; i++)
+      CUDNN_CHECK(cudnnDestroyTensorDescriptor(y_descs_[i]));
+    delete[] y_descs_;  // allocated with new[] in InitCudnn
+    y_descs_ = nullptr;
+  }
+  if (hx_desc_ != nullptr)
+    CUDNN_CHECK(cudnnDestroyTensorDescriptor(hx_desc_));
+  if (hy_desc_ != nullptr)
+    CUDNN_CHECK(cudnnDestroyTensorDescriptor(hy_desc_));
+  if (cx_desc_ != nullptr)
+    CUDNN_CHECK(cudnnDestroyTensorDescriptor(cx_desc_));
+  if (cy_desc_ != nullptr)
+    CUDNN_CHECK(cudnnDestroyTensorDescriptor(cy_desc_));
+}
+
+/// Reads and validates the cuDNN-specific options from conf.rnn_conf() and
+/// precomputes weightSize_, the size in bytes of the packed cuDNN parameter
+/// buffer (InitCudnn checks it against cudnnGetRNNParamsSize).
+/// in_sample is indexed as (seqLength, batchsize, inputSize).
+void CudnnRNN::Setup(const Shape& in_sample, const LayerConf &conf) {
+  RNN::Setup(in_sample, conf);
+  const RNNConf& rnn_conf = conf.rnn_conf();
+  // convert MB to bytes; widen first so that large limits cannot overflow
+  // the 32-bit proto field during the shift
+  workspace_byte_limit_ =
+      static_cast<size_t>(rnn_conf.workspace_byte_limit()) << 20;
+  inputMode_ = ToLowerCase(rnn_conf.inputmode());
+  direction_ = ToLowerCase(rnn_conf.direction());
+  mode_ = ToLowerCase(rnn_conf.mode());
+  CHECK(inputMode_ == "cudnn_linear_input" || inputMode_ == "cudnn_skip_input")
+      << "CudnnRNN only supports two inputmodes: cudnn_linear_input, "
+         "cudnn_skip_input";
+  CHECK(direction_ == "cudnn_undirectional" || direction_ == "cudnn_bidirectional")
+      << "CudnnRNN only supports two directions: cudnn_undirectional, "
+         "cudnn_bidirectional";
+  CHECK(mode_ == "cudnn_rnn_relu" || mode_ == "cudnn_rnn_tanh" ||
+        mode_ == "cudnn_lstm" || mode_ == "cudnn_gru")
+      << "CudnnRNN only supports four modes: cudnn_rnn_relu, "
+         "cudnn_rnn_tanh, cudnn_lstm and cudnn_gru";
+
+  // Per gate, cuDNN stores an input matrix (hidden x layer input), a
+  // recurrent matrix (hidden x hidden) and two bias vectors of size hidden.
+  // relu/tanh RNNs have 1 gate, LSTM 4, GRU 3. The previous formula
+  // 2*(hidden*input + hidden) dropped the recurrent matrix and ignored
+  // stacked layers; it only matched when input == hidden and numLayers == 1.
+  size_t num_gates = 1;
+  if (mode_ == "cudnn_lstm")
+    num_gates = 4;
+  else if (mode_ == "cudnn_gru")
+    num_gates = 3;
+  const size_t num_directions = direction_ == "cudnn_bidirectional" ? 2 : 1;
+  size_t num_params = 0;
+  for (size_t layer = 0; layer < numLayers_; layer++) {
+    // Stacked layers consume the previous layer's (concatenated) outputs.
+    // NOTE(review): confirm whether cudnn_skip_input still allocates the
+    // first-layer input matrices; InitCudnn's CHECK_EQ flags any mismatch.
+    const size_t layer_input =
+        layer == 0 ? in_sample[2] : hiddenSize_ * num_directions;
+    num_params += num_gates * hiddenSize_ * (layer_input + hiddenSize_ + 2);
+  }
+  // Kept in bytes of float parameters, like cudnnGetRNNParamsSize reports.
+  weightSize_ = sizeof(float) * num_params * num_directions;
+}
+
+/// Moves the layer to `device`. RNN::ToDevice handles the Layer base state
+/// and weight_; this override adds the cuDNN workspace and reserve buffers.
+/// dropoutStates_ is left in place because InitCudnn registered its device
+/// pointer with the dropout descriptor.
+void CudnnRNN::ToDevice(std::shared_ptr<Device> device) {
+  RNN::ToDevice(device);
+  workspace_.ToDevice(device);
+  reserve_.ToDevice(device);
+}
+
+/// Creates and configures the cuDNN descriptors and device buffers for an
+/// input laid out as (seqLength, minibatch, inputSize). Forward() calls this
+/// once, lazily, so the sequence length and batch size of the first input
+/// are baked into the descriptors.
+void CudnnRNN::InitCudnn(const Tensor &input) {
+  CHECK(!has_init_cudnn_);
+  DataType dtype = input.data_type();
+  const auto cudnn_dtype = GetCudnnDataType(dtype);
+  auto dev = input.device();
+  Context *ctx = dev->context(0);
+  seqLength_ = input.shape(0);
+  size_t batchsize = input.shape(1);
+  size_t inputSize = input.shape(2);
+  // Setup() restricts direction_ to the two accepted values.
+  const size_t numDirections = direction_ == "cudnn_undirectional" ? 1 : 2;
+
+  x_descs_ = new cudnnTensorDescriptor_t[seqLength_];
+  y_descs_ = new cudnnTensorDescriptor_t[seqLength_];
+  for (size_t i = 0; i < seqLength_; i++) {
+    CUDNN_CHECK(cudnnCreateTensorDescriptor(&x_descs_[i]));
+    CUDNN_CHECK(cudnnCreateTensorDescriptor(&y_descs_[i]));
+  }
+  CUDNN_CHECK(cudnnCreateTensorDescriptor(&hx_desc_));
+  CUDNN_CHECK(cudnnCreateTensorDescriptor(&cx_desc_));
+  CUDNN_CHECK(cudnnCreateTensorDescriptor(&hy_desc_));
+  CUDNN_CHECK(cudnnCreateTensorDescriptor(&cy_desc_));
+  CUDNN_CHECK(cudnnCreateFilterDescriptor(&weight_desc_));
+  CUDNN_CHECK(cudnnCreateDropoutDescriptor(&dropout_desc_));
+  CUDNN_CHECK(cudnnCreateRNNDescriptor(&rnn_desc_));
+
+  // cuDNN takes int dimensions. Convert explicitly, because brace-initialising
+  // an int array from size_t values is an ill-formed narrowing conversion.
+  const int batch = static_cast<int>(batchsize);
+  // Per-step, fully packed descriptors: x_t is (batch, inputSize, 1) and
+  // y_t is (batch, hiddenSize * numDirections, 1).
+  const int xDimA[3] = {batch, static_cast<int>(inputSize), 1};
+  const int xStrideA[3] = {xDimA[1] * xDimA[2], xDimA[2], 1};
+  const int yDimA[3] = {batch, static_cast<int>(hiddenSize_ * numDirections), 1};
+  const int yStrideA[3] = {yDimA[1] * yDimA[2], yDimA[2], 1};
+  for (size_t i = 0; i < seqLength_; i++) {
+    CUDNN_CHECK(cudnnSetTensorNdDescriptor(x_descs_[i], cudnn_dtype, 3,
+                                           xDimA, xStrideA));
+    CUDNN_CHECK(cudnnSetTensorNdDescriptor(y_descs_[i], cudnn_dtype, 3,
+                                           yDimA, yStrideA));
+  }
+
+  // cuDNN requires hidden/cell state descriptors of shape
+  // (numLayers * numDirections, batch, hiddenSize). The element count equals
+  // that of the (numLayers, batch, hiddenSize * numDirections) state tensors.
+  const int hDimA[3] = {static_cast<int>(numLayers_ * numDirections), batch,
+                        static_cast<int>(hiddenSize_)};
+  const int hStrideA[3] = {hDimA[1] * hDimA[2], hDimA[2], 1};
+  CUDNN_CHECK(cudnnSetTensorNdDescriptor(hx_desc_, cudnn_dtype, 3, hDimA, hStrideA));
+  CUDNN_CHECK(cudnnSetTensorNdDescriptor(cx_desc_, cudnn_dtype, 3, hDimA, hStrideA));
+  CUDNN_CHECK(cudnnSetTensorNdDescriptor(hy_desc_, cudnn_dtype, 3, hDimA, hStrideA));
+  CUDNN_CHECK(cudnnSetTensorNdDescriptor(cy_desc_, cudnn_dtype, 3, hDimA, hStrideA));
+
+  size_t dropoutStatesSize;
+  CUDNN_CHECK(cudnnDropoutGetStatesSize(ctx->cudnn_handle, &dropoutStatesSize));
+  dropoutStates_ = Tensor(Shape{dropoutStatesSize}, dev, dtype);
+  CUDNN_CHECK(cudnnSetDropoutDescriptor(
+      dropout_desc_, ctx->cudnn_handle, dropout_,
+      this->dropoutStates_.block()->mutable_data(), dropoutStatesSize,
+      0x01234567));
+
+  // Setup() has already restricted the option strings to these values; the
+  // enums are initialised so they can never be read uninitialised.
+  const cudnnRNNInputMode_t inputMode =
+      inputMode_ == "cudnn_skip_input" ? CUDNN_SKIP_INPUT : CUDNN_LINEAR_INPUT;
+  const cudnnDirectionMode_t direction =
+      numDirections == 2 ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL;
+  cudnnRNNMode_t mode = CUDNN_RNN_RELU;
+  if (mode_ == "cudnn_rnn_tanh")
+    mode = CUDNN_RNN_TANH;
+  else if (mode_ == "cudnn_lstm")
+    mode = CUDNN_LSTM;
+  else if (mode_ == "cudnn_gru")
+    mode = CUDNN_GRU;
+  CUDNN_CHECK(cudnnSetRNNDescriptor(rnn_desc_, hiddenSize_, numLayers_,
+                                    dropout_desc_, inputMode, direction, mode,
+                                    cudnn_dtype));
+
+  // Both values are in bytes.
+  size_t weightSize;
+  CUDNN_CHECK(cudnnGetRNNParamsSize(ctx->cudnn_handle, rnn_desc_, x_descs_[0],
+                                    &weightSize, cudnn_dtype));
+  CHECK_EQ(weightSize, weightSize_);
+
+  // Filter dimensions are element counts, whereas weightSize_ is in bytes.
+  const int filterDimA[3] = {static_cast<int>(weightSize_ / sizeof(float)), 1, 1};
+  CUDNN_CHECK(cudnnSetFilterNdDescriptor(weight_desc_, cudnn_dtype,
+                                         CUDNN_TENSOR_NCHW, 3, filterDimA));
+
+  CUDNN_CHECK(cudnnGetRNNWorkspaceSize(ctx->cudnn_handle, rnn_desc_, seqLength_,
+                                       x_descs_, &workspace_count_));
+  workspace_ = Tensor(Shape{workspace_count_}, dev, dtype);
+
+  CUDNN_CHECK(cudnnGetRNNTrainingReserveSize(ctx->cudnn_handle, rnn_desc_,
+                                             seqLength_, x_descs_,
+                                             &ReserveSize_));
+  reserve_ = Tensor(Shape{ReserveSize_}, dev, dtype);
+  has_init_cudnn_ = true;
+}
+
+/// Runs cudnnRNNForwardTraining over the whole sequence.
+/// inputs = <x, hx, cx>, where x is laid out as (seqLength, minibatch,
+/// inputSize). Returns <y, hy, cy>: y is (seqLength, minibatch,
+/// hiddenSize * numDirections), and hy and cy are (numLayers, minibatch,
+/// hiddenSize * numDirections). With kTrain in flag, x, y, hx, hy, cx and cy
+/// are pushed (in that order) for Backward(); otherwise nothing is buffered.
+const vector<Tensor> CudnnRNN::Forward(int flag, const vector<Tensor>& inputs) {
+  CHECK_GE(inputs.size(), 3u) << "CudnnRNN::Forward expects <x, hx, cx>";
+  singa::Tensor input = inputs[0];
+  singa::Tensor hx = inputs[1];
+  singa::Tensor cx = inputs[2];
+  CHECK_EQ(input.device()->lang(), kCuda);
+  CHECK_EQ(input.device()->lang(), this->weight_.device()->lang());
+  CHECK_EQ(input.nDim(), 3u);
+  size_t batchsize = input.shape(1);
+  DataType dtype = input.data_type();
+  auto dev = input.device();
+
+  if (!has_init_cudnn_) InitCudnn(input);
+
+  const size_t numDirections = direction_ == "cudnn_undirectional" ? 1 : 2;
+  Shape shape{seqLength_, batchsize, hiddenSize_ * numDirections};
+  Tensor output(shape, dev, dtype);
+  Shape shape1{numLayers_, batchsize, hiddenSize_ * numDirections};
+  Tensor hy(shape1, dev, dtype);
+  Tensor cy(shape1, dev, dtype);
+
+  output.device()->Exec([input, output, hx, hy, cx, cy, this](Context *ctx) {
+    Block *inblock = input.block(), *outblock = output.block(),
+          *wblock = this->weight_.block(), *hxblock = hx.block(),
+          *hyblock = hy.block(), *cxblock = cx.block(),
+          *cyblock = cy.block();
+    CUDNN_CHECK(cudnnRNNForwardTraining(
+        ctx->cudnn_handle, this->rnn_desc_, seqLength_, this->x_descs_,
+        inblock->data(), this->hx_desc_, hxblock->data(), this->cx_desc_,
+        cxblock->data(), this->weight_desc_, wblock->data(), this->y_descs_,
+        outblock->mutable_data(), this->hy_desc_, hyblock->mutable_data(),
+        this->cy_desc_, cyblock->mutable_data(),
+        this->workspace_.block()->mutable_data(),
+        this->workspace_count_ * sizeof(float),
+        this->reserve_.block()->mutable_data(),
+        this->ReserveSize_ * sizeof(float)));
+  }, {input.block(), weight_.block(), hx.block(), cx.block()},
+     {output.block(), hy.block(), cy.block(), workspace_.block(),
+      reserve_.block()});
+
+  // Buffer everything or nothing. Previously only x depended on kTrain, so
+  // eval calls kept growing buf_ and broke the pop order in Backward().
+  if (flag & kTrain) {
+    buf_.push(input);
+    buf_.push(output);
+    buf_.push(hx);
+    buf_.push(hy);  // in order to assign shape to dhy
+    buf_.push(cx);
+    buf_.push(cy);  // in order to assign shape to dcy
+  }
+  return vector<Tensor>{output, hy, cy};
+}
+
+/// Computes gradients with cuDNN. grads = <dy, dhy, dcy>, where dy has the
+/// forward output's (seqLength, minibatch, hiddenSize * numDirections)
+/// layout. Consumes the six tensors buffered by Forward(kTrain).
+/// Returns <<dx, dhx, dcx>, <dw>>.
+const std::pair<vector<Tensor>, vector<Tensor>> CudnnRNN::Backward(
+    int flag, const vector<Tensor>& grads) {
+  CHECK(has_init_cudnn_);
+  CHECK_GE(grads.size(), 3u) << "CudnnRNN::Backward expects <dy, dhy, dcy>";
+  singa::Tensor grad = grads[0];
+  singa::Tensor dhy = grads[1];
+  singa::Tensor dcy = grads[2];
+  CHECK_EQ(grad.device()->lang(), kCuda);
+  CHECK_EQ(grad.nDim(), 3u);
+
+  // Forward pushed x, y, hx, hy, cx, cy; pop them in reverse order.
+  auto pop_buffered = [this]() {
+    CHECK(!buf_.empty());
+    Tensor t = buf_.top();
+    buf_.pop();
+    return t;
+  };
+  pop_buffered();  // cy: not needed here
+  Tensor cx = pop_buffered();
+  pop_buffered();  // hy: not needed here
+  Tensor hx = pop_buffered();
+  Tensor src_output = pop_buffered();
+  Tensor src_data = pop_buffered();
+
+  Tensor dx;
+  dx.ResetLike(src_data);
+  Tensor dw;
+  dw.ResetLike(weight_);
+  // cudnnRNNBackwardWeights adds into dw, so it must start from zero.
+  dw.SetValue(0.0f);
+  Tensor dhx;
+  dhx.ResetLike(hx);
+  Tensor dcx;
+  dcx.ResetLike(cx);
+
+  // cudnnRNNBackwardData must run before cudnnRNNBackwardWeights: the
+  // weights pass relies on data it leaves in the workspace and reserve space.
+  dx.device()->Exec([grad, src_output, dx, cx, hx, dhy, dcy, dhx, dcx,
+                     this](Context *ctx) {
+    Block *srcoutblock = src_output.block(), *wblock = this->weight_.block(),
+          *dxblock = dx.block(), *dyblock = grad.block(), *cxblock = cx.block(),
+          *hxblock = hx.block(), *dhyblock = dhy.block(),
+          *dcyblock = dcy.block(), *dhxblock = dhx.block(),
+          *dcxblock = dcx.block();
+    CUDNN_CHECK(cudnnRNNBackwardData(
+        ctx->cudnn_handle, this->rnn_desc_, seqLength_, this->y_descs_,
+        srcoutblock->data(), this->y_descs_, dyblock->data(), this->hy_desc_,
+        dhyblock->data(), this->cy_desc_, dcyblock->data(), this->weight_desc_,
+        wblock->data(), this->hx_desc_, hxblock->data(), this->cx_desc_,
+        cxblock->data(), this->x_descs_, dxblock->mutable_data(),
+        this->hx_desc_, dhxblock->mutable_data(), this->cx_desc_,
+        dcxblock->mutable_data(), this->workspace_.block()->mutable_data(),
+        this->workspace_count_ * sizeof(float),
+        this->reserve_.block()->mutable_data(),
+        this->ReserveSize_ * sizeof(float)));
+  }, {src_output.block(), grad.block(), dhy.block(), dcy.block(),
+      this->weight_.block(), hx.block(), cx.block()},
+     {dx.block(), dhx.block(), dcx.block(), reserve_.block(),
+      workspace_.block()});
+
+  dx.device()->Exec([dw, src_data, src_output, hx, this](Context *ctx) {
+    Block *inblock = src_data.block(), *srcoutblock = src_output.block(),
+          *dwblock = dw.block(), *hxblock = hx.block();
+    CUDNN_CHECK(cudnnRNNBackwardWeights(
+        ctx->cudnn_handle, this->rnn_desc_, seqLength_, this->x_descs_,
+        inblock->data(), this->hx_desc_, hxblock->data(), this->y_descs_,
+        srcoutblock->data(), this->workspace_.block()->mutable_data(),
+        this->workspace_count_ * sizeof(float), this->weight_desc_,
+        dwblock->mutable_data(), this->reserve_.block()->mutable_data(),
+        this->ReserveSize_ * sizeof(float)));
+  }, {src_data.block(), hx.block(), src_output.block(), reserve_.block()},
+     {dw.block(), workspace_.block()});
+
+  vector<Tensor> param_grad{dw};
+  vector<Tensor> data_grad{dx, dhx, dcx};
+  return std::make_pair(data_grad, param_grad);
+}
+
+}  // namespace singa
+#endif  // USE_CUDNN

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/c51f9448/src/model/layer/cudnn_rnn.h
----------------------------------------------------------------------
diff --git a/src/model/layer/cudnn_rnn.h b/src/model/layer/cudnn_rnn.h
new file mode 100644
index 0000000..b1e9f43
--- /dev/null
+++ b/src/model/layer/cudnn_rnn.h
@@ -0,0 +1,85 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef SRC_MODEL_LAYER_CUDNN_RNN_H_
+#define SRC_MODEL_LAYER_CUDNN_RNN_H_
+#include "singa/singa_config.h"
+#ifdef USE_CUDNN
+#include <string>
+#include <utility>
+#include <vector>
+#include "./rnn.h"
+#include "singa/core/common.h"
+#include "singa/model/layer.h"
+#include "singa/proto/core.pb.h"
+#include "singa/utils/string.h"
+#include <cudnn.h>
+#include <chrono>
+#include "./cudnn_utils.h"
+#include "singa/utils/logging.h"
+
+namespace singa {
+/// RNN layer implemented with the cuDNN RNN API (relu/tanh RNN, LSTM, GRU).
+/// Its cuDNN descriptors and device buffers are created lazily by
+/// InitCudnn() on the first Forward() call.
+class CudnnRNN : public RNN {
+ public:
+  CudnnRNN() = default;
+  // The layer owns raw cuDNN descriptors and new[]-allocated descriptor
+  // arrays (released in the destructor), so a copy would release them twice.
+  CudnnRNN(const CudnnRNN&) = delete;
+  CudnnRNN& operator=(const CudnnRNN&) = delete;
+  ~CudnnRNN();
+  /// \copydoc Layer::layer_type()
+  const std::string layer_type() const override { return "CudnnRNN"; }
+
+  /// inputs = <x, hx, cx>; returns <y, hy, cy>.
+  const vector<Tensor> Forward(int flag, const vector<Tensor>& inputs) override;
+  /// grads = <dy, dhy, dcy>; returns <<dx, dhx, dcx>, <dw>>.
+  const std::pair<vector<Tensor>, vector<Tensor>> Backward(int flag, const vector<Tensor>& grads) override;
+
+  /// \copydoc Layer::Setup(const LayerConf&);
+  void Setup(const Shape& in_sample, const LayerConf &conf) override;
+
+  void ToDevice(std::shared_ptr<Device> device) override;
+
+  /// Workspace limit in bytes (configured in MB through RNNConf).
+  size_t workspace_byte_limit() const { return workspace_byte_limit_; }
+  // string prefer() { return prefer_; }
+  string inputMode() const { return inputMode_; }
+  string direction() const { return direction_; }
+  string mode() const { return mode_; }
+
+ protected:
+  /// Init cudnn related data structures.
+  void InitCudnn(const Tensor& input);
+
+ protected:
+  bool has_init_cudnn_ = false;
+  /// Per-time-step input/output descriptors (arrays of length seqLength_).
+  cudnnTensorDescriptor_t* x_descs_ = nullptr;
+  cudnnTensorDescriptor_t* y_descs_ = nullptr;
+  cudnnTensorDescriptor_t hx_desc_ = nullptr;
+  cudnnTensorDescriptor_t cx_desc_ = nullptr;
+  cudnnTensorDescriptor_t hy_desc_ = nullptr;
+  cudnnTensorDescriptor_t cy_desc_ = nullptr;
+  cudnnFilterDescriptor_t weight_desc_ = nullptr;
+  cudnnRNNDescriptor_t rnn_desc_ = nullptr;
+  cudnnDropoutDescriptor_t dropout_desc_ = nullptr;
+  size_t workspace_byte_limit_ = 0;  // bytes
+  size_t workspace_count_ = 0;       // from cudnnGetRNNWorkspaceSize
+  size_t ReserveSize_ = 0;           // from cudnnGetRNNTrainingReserveSize
+  Tensor workspace_;
+  string inputMode_;
+  string direction_;
+  string mode_;
+  Tensor reserve_;
+  Tensor dropoutStates_;
+};
+
+}  // namespace singa
+
+#endif  // USE_CUDNN
+#endif  // SRC_MODEL_LAYER_CUDNN_RNN_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/c51f9448/src/model/layer/rnn.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/rnn.cc b/src/model/layer/rnn.cc
new file mode 100644
index 0000000..493a5e4
--- /dev/null
+++ b/src/model/layer/rnn.cc
@@ -0,0 +1,53 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "./rnn.h"
+#include <vector>
+#include "singa/model/layer.h"
+
+namespace singa {
+
+/// Runs the generic Layer setup, then reads and validates the RNN
+/// hyper-parameters (hidden size, number of stacked layers, dropout) from
+/// conf.rnn_conf().
+void RNN::Setup(const Shape& in_sample, const LayerConf &conf) {
+  Layer::Setup(in_sample, conf);
+  // Bind by const reference to avoid copying the proto message.
+  const RNNConf& rnn_conf = conf.rnn_conf();
+  hiddenSize_ = rnn_conf.hiddensize();
+  CHECK_GT(hiddenSize_, 0u);
+
+  numLayers_ = rnn_conf.numlayers();
+  CHECK_GT(numLayers_, 0u);
+
+  // dropout_ is a float ratio, so validate it against float bounds.
+  dropout_ = rnn_conf.dropout();
+  CHECK_GE(dropout_, 0.0f);
+  CHECK_LE(dropout_, 1.0f);
+}
+
+// Base-class placeholder: it produces no outputs. Subclasses such as
+// CudnnRNN override Forward() with a real implementation.
+const vector<Tensor> RNN::Forward(int flag, const vector<Tensor>& inputs) {
+  return vector<Tensor>();
+}
+
+// Base-class placeholder: both the data-gradient list (first) and the
+// parameter-gradient list (second) are returned empty.
+const std::pair<vector<Tensor>, vector<Tensor>> RNN::Backward(int flag, const vector<Tensor>& grads) {
+  return {vector<Tensor>(), vector<Tensor>()};
+}
+
+// Delegate to the base class first, then move the packed weight tensor that
+// this layer holds directly as a member.
+void RNN::ToDevice(std::shared_ptr<Device> device) {
+  Layer::ToDevice(device);
+  this->weight_.ToDevice(device);
+}
+}  /* singa */

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/c51f9448/src/model/layer/rnn.h
----------------------------------------------------------------------
diff --git a/src/model/layer/rnn.h b/src/model/layer/rnn.h
index 35c86bd..ec5a35d 100644
--- a/src/model/layer/rnn.h
+++ b/src/model/layer/rnn.h
@@ -38,21 +38,32 @@ class RNN : public Layer {
   const std::string layer_type() const override { return "RNN"; }
 
   /// \copydoc Layer::Setup(const LayerConf&);
-  void Setup(const LayerConf& conf) override;
+  void Setup(const vector<size_t>& in_shape, const LayerConf& conf) override;
 
   /// \copydoc Layer::Forward(int flag, const vector<Tensor>&)
-  const vector<Tensor> Forward(int flag, const vector<Tensor>& input) override;
+  const vector<Tensor> Forward(int flag, const vector<Tensor>& inputs) override;
 
   /// \copydoc Layer::Backward(int, const vector<Tensor>&);
   const std::pair<vector<Tensor>, vector<Tensor>> Backward(
-      int flag, const vector<Tensor>& grad) override;
+      int flag, const vector<Tensor>& grads) override;
 
-  void ToDevice(Device* device) override;
 
+  /// Size of the hidden state (RNNConf::hiddensize).
+  size_t hiddenSize() const { return hiddenSize_; }
+  /// Number of stacked RNN layers (RNNConf::numlayers).
+  size_t numLayers() const { return numLayers_; }
+  /// Number of elements in the packed tensor of all weights and biases.
+  size_t weightSize() const { return weightSize_; }
+  /// Dropout ratio (RNNConf::dropout).
+  float dropout() const { return dropout_; }
+
+  /// Replace the internal weight tensor: reset it to match `w` (ResetLike),
+  /// then copy w's values into it (CopyData). Taken by const reference since
+  /// `w` is only read.
+  void set_weight(const Tensor& w) {
+    weight_.ResetLike(w);
+    weight_.CopyData(w);
+  }
+
+
+  void ToDevice(std::shared_ptr<Device> device) override;
   /// Return the internal state stack, which should be empty at the beginning
   /// of
   /// one iteration.
-  std::stack<Tensor> states() const { return states_; }
+  // std::stack<Tensor> states() const { return states_; }
 
  protected:
   /// Storing input or output from Forward(), which are used in Backward().
@@ -60,7 +71,15 @@ class RNN : public Layer {
   /// 1. push the 'input' or 'output' into states_ if the flag of Forward() is
   ///    for kTrain and 'input' or 'output' is necessary for Backward().
   /// 2. pop data out in Backward().
-  std::stack<Tensor*> states_;
+  // std::stack<Tensor*> states_;
+  std::stack<Tensor> buf_;
+  size_t hiddenSize_;
+  size_t numLayers_;
+  size_t numLinearLayer_;
+  size_t seqLength_;
+  size_t weightSize_; /*all the weights and biases*/
+  float dropout_;
+  Tensor weight_;
 };
 }  // namespace singa
 #endif  // SRC_MODEL_LAYER_RNN_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/c51f9448/src/proto/model.proto
----------------------------------------------------------------------
diff --git a/src/proto/model.proto b/src/proto/model.proto
index b1318d9..d8193f1 100644
--- a/src/proto/model.proto
+++ b/src/proto/model.proto
@@ -203,6 +203,7 @@ message LayerConf {
   optional ConcatConf concat_conf = 104;
   optional ContrastiveLossConf contrastive_loss_conf = 105;
   optional ConvolutionConf convolution_conf = 106;
+  optional RNNConf rnn_conf = 140;
   // optional DataConf data_conf = 107;
   optional DropoutConf dropout_conf = 108;
   // optional DummyDataConf dummy_data_conf = 109;
@@ -391,6 +392,22 @@ message ConvolutionConf {
   optional string prefer = 51 [default = "fastest"];
 }
 
+// Hyper-parameters of RNN layers (read by RNN::Setup and CudnnRNN).
+message RNNConf {
+  optional uint32 hiddensize = 1; // size of the hidden state; must be > 0
+  optional uint32 numlayers = 2; // number of stacked RNN layers; must be > 0
+  // dropout ratio; RNN::Setup rejects negative values
+  optional float dropout = 3 [default = 0];
+  // cuDNN workspace limit, given in MB despite the field name (the CudnnRNN
+  // test sets 256 and reads back 256 << 20 bytes)
+  optional int32 workspace_byte_limit = 50 [default = 512];
+  // cudnn inputmode
+  // options: "cudnn_linear_input", "cudnn_skip_input"
+  optional string inputmode = 51 [default = "cudnn_linear_input"];
+  // cudnn direction
+  // options: "cudnn_undirectional", "cudnn_bidirectional"
+  // NOTE: "undirectional" (sic) is the literal value used by the default and
+  // by the tests; keep that exact spelling when setting this field.
+  optional string direction = 52 [default = "cudnn_undirectional"];
+  // cudnn RNN mode
+  // options: "cudnn_rnn_relu", "cudnn_rnn_tanh", "cudnn_lstm", "cudnn_gru"
+  optional string mode = 53 [default = "cudnn_rnn_relu"];
+}
+
 /*
 message DataConf {
   enum DB {

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/c51f9448/test/singa/test_cudnn_rnn.cc
----------------------------------------------------------------------
diff --git a/test/singa/test_cudnn_rnn.cc b/test/singa/test_cudnn_rnn.cc
new file mode 100644
index 0000000..1a79d7b
--- /dev/null
+++ b/test/singa/test_cudnn_rnn.cc
@@ -0,0 +1,212 @@
+/************************************************************
+*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements.  See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership.  The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License.  You may obtain a copy of the License at
+*
+*   http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied.  See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*
+*************************************************************/
+
+#include "../src/model/layer/cudnn_rnn.h"
+#ifdef USE_CUDNN
+
+#include "gtest/gtest.h"
+
+using singa::CudnnRNN;
+using singa::Shape;
+// Setup() must copy every RNNConf field into the layer's accessors.
+TEST(CudnnRNN, Setup) {
+  CudnnRNN rnn;
+  EXPECT_EQ("CudnnRNN", rnn.layer_type());
+
+  singa::LayerConf conf;
+  singa::RNNConf *rnnconf = conf.mutable_rnn_conf();
+  rnnconf->set_hiddensize(2);
+  rnnconf->set_numlayers(1);
+  rnnconf->set_dropout(0);
+  rnnconf->set_inputmode("cudnn_linear_input");
+  rnnconf->set_direction("cudnn_undirectional");
+  rnnconf->set_mode("cudnn_rnn_tanh");
+  // workspace limit is configured in MB
+  rnnconf->set_workspace_byte_limit(256);
+  rnn.Setup(Shape{4, 1, 2}, conf);
+
+  EXPECT_EQ(2u, rnn.hiddenSize());
+  EXPECT_EQ(1u, rnn.numLayers());
+  // dropout() returns a float, so use a floating-point comparison.
+  EXPECT_FLOAT_EQ(0.0f, rnn.dropout());
+  EXPECT_EQ("cudnn_linear_input", rnn.inputMode());
+  EXPECT_EQ("cudnn_undirectional", rnn.direction());
+  EXPECT_EQ("cudnn_rnn_tanh", rnn.mode());
+  // ...and reported back in bytes
+  EXPECT_EQ(256u << 20, rnn.workspace_byte_limit());
+}
+
+// Forward pass of a 1-layer tanh RNN where inputs, initial state, weights
+// and biases are all 1.
+TEST(CudnnRNN, Forward) {
+  auto cuda = std::make_shared<singa::CudaGPU>();
+  const size_t seqLength = 4, batchsize = 1, dim = 2;
+  const size_t numLayers = 1, hiddensize = 2, numDirections = 1;
+  const size_t stateSize = numLayers * batchsize * hiddensize * numDirections;
+  const float x[seqLength * batchsize * dim] = {1.0f, 1.0f, 1.0f, 1.0f,
+                                                1.0f, 1.0f, 1.0f, 1.0f};
+  singa::Tensor in(singa::Shape{seqLength, batchsize, dim}, cuda);
+  in.CopyDataFromHostPtr(x, seqLength * batchsize * dim);
+
+  const float hx_data[stateSize] = {1.0f, 1.0f};
+  singa::Tensor hx(singa::Shape{numLayers, batchsize, hiddensize * numDirections}, cuda);
+  hx.CopyDataFromHostPtr(hx_data, stateSize);
+
+  const float cx_data[stateSize] = {1.0f, 1.0f};
+  singa::Tensor cx(singa::Shape{numLayers, batchsize, hiddensize * numDirections}, cuda);
+  cx.CopyDataFromHostPtr(cx_data, stateSize);
+
+  CudnnRNN rnn;
+
+  singa::LayerConf conf;
+  singa::RNNConf *rnnconf = conf.mutable_rnn_conf();
+  rnnconf->set_hiddensize(2);
+  rnnconf->set_numlayers(1);
+  rnnconf->set_dropout(0);
+  rnnconf->set_inputmode("cudnn_linear_input");
+  rnnconf->set_direction("cudnn_undirectional");
+  rnnconf->set_mode("cudnn_rnn_tanh");
+  // MB
+  rnnconf->set_workspace_byte_limit(256);
+  rnn.Setup(Shape{4, 1, 2}, conf);
+
+  // All-ones weights. std::vector instead of a variable-length array, which
+  // is a compiler extension rather than standard C++.
+  const size_t weightSize = rnn.weightSize();
+  const std::vector<float> we(weightSize, 1.0f);
+  singa::Tensor weight(singa::Shape{weightSize, 1, 1}, cuda);
+  weight.CopyDataFromHostPtr(we.data(), weightSize);
+  rnn.set_weight(weight);
+
+  vector<singa::Tensor> input_array;
+  input_array.push_back(in);
+  input_array.push_back(hx);
+  input_array.push_back(cx);
+  const auto ret = rnn.Forward(singa::kTrain, input_array);
+
+  // Each pre-activation is 2 (from x) + 2 (from h) + 1 + 1 (biases) = 6,
+  // and tanh(6) ~= 0.99999, so every output is ~1.
+  singa::Tensor out1 = ret[0];
+  out1.ToHost();
+  const float *outptr1 = out1.data<float>();
+  EXPECT_EQ(8u, out1.Size());
+  for (size_t i = 0; i < out1.Size(); i++)
+    EXPECT_NEAR(1.0f, outptr1[i], 0.0001) << "y index " << i;
+
+  singa::Tensor hy1 = ret[1];
+  hy1.ToHost();
+  const float *hyptr1 = hy1.data<float>();
+  EXPECT_EQ(2u, hy1.Size());
+  for (size_t i = 0; i < hy1.Size(); i++)
+    EXPECT_NEAR(1.0f, hyptr1[i], 0.0001) << "hy index " << i;
+}
+
+// Backward pass of the same all-ones tanh RNN with all-ones output gradients.
+TEST(CudnnRNN, Backward) {
+  // src_data
+  auto cuda = std::make_shared<singa::CudaGPU>();
+  const size_t seqLength = 4, batchsize = 1, dim = 2;
+  const size_t numLayers = 1, hiddensize = 2, numDirections = 1;
+  const size_t stateSize = numLayers * batchsize * hiddensize * numDirections;
+  const float x[seqLength * batchsize * dim] = {1.0f, 1.0f, 1.0f, 1.0f,
+                                                1.0f, 1.0f, 1.0f, 1.0f};
+  singa::Tensor in(singa::Shape{seqLength, batchsize, dim}, cuda);
+  in.CopyDataFromHostPtr(x, seqLength * batchsize * dim);
+
+  const float hx_data[stateSize] = {1.0f, 1.0f};
+  singa::Tensor hx(singa::Shape{numLayers, batchsize, hiddensize * numDirections}, cuda);
+  hx.CopyDataFromHostPtr(hx_data, stateSize);
+
+  const float cx_data[stateSize] = {1.0f, 1.0f};
+  singa::Tensor cx(singa::Shape{numLayers, batchsize, hiddensize * numDirections}, cuda);
+  cx.CopyDataFromHostPtr(cx_data, stateSize);
+
+  CudnnRNN rnn;
+
+  singa::LayerConf conf;
+  singa::RNNConf *rnnconf = conf.mutable_rnn_conf();
+  rnnconf->set_hiddensize(2);
+  rnnconf->set_numlayers(1);
+  rnnconf->set_dropout(0);
+  rnnconf->set_inputmode("cudnn_linear_input");
+  rnnconf->set_direction("cudnn_undirectional");
+  rnnconf->set_mode("cudnn_rnn_tanh");
+  // MB
+  rnnconf->set_workspace_byte_limit(256);
+  rnn.Setup(Shape{4, 1, 2}, conf);
+
+  // All-ones weights. std::vector instead of a variable-length array, which
+  // is a compiler extension rather than standard C++.
+  const size_t weightSize = rnn.weightSize();
+  const std::vector<float> we(weightSize, 1.0f);
+  singa::Tensor weight(singa::Shape{weightSize, 1, 1}, cuda);
+  weight.CopyDataFromHostPtr(we.data(), weightSize);
+  rnn.set_weight(weight);
+
+  vector<singa::Tensor> input_array;
+  input_array.push_back(in);
+  input_array.push_back(hx);
+  input_array.push_back(cx);
+  // A kTrain forward pass precedes Backward(); its outputs are kept alive
+  // for the rest of the test.
+  const auto ret = rnn.Forward(singa::kTrain, input_array);
+
+  // grad
+  const float dy[seqLength * batchsize * hiddensize * numDirections] = {
+      1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f};
+  singa::Tensor grad(singa::Shape{seqLength, batchsize, hiddensize * numDirections},
+                     cuda);
+  grad.CopyDataFromHostPtr(dy, seqLength * batchsize * hiddensize * numDirections);
+
+  const float dhy_data[stateSize] = {1.0f, 1.0f};
+  singa::Tensor dhy(singa::Shape{numLayers, batchsize, hiddensize * numDirections},
+                    cuda);
+  dhy.CopyDataFromHostPtr(dhy_data, stateSize);
+
+  const float dcy_data[stateSize] = {1.0f, 1.0f};
+  singa::Tensor dcy(singa::Shape{numLayers, batchsize, hiddensize * numDirections},
+                    cuda);
+  dcy.CopyDataFromHostPtr(dcy_data, stateSize);
+
+  vector<singa::Tensor> grad_array;
+  grad_array.push_back(grad);
+  grad_array.push_back(dhy);
+  grad_array.push_back(dcy);
+  const auto ret_back = rnn.Backward(singa::kTrain, grad_array);
+
+  // ret_back.first holds the data gradients: dx, then dhx.
+  singa::Tensor in_grad = ret_back.first[0];
+  in_grad.ToHost();
+  const float *dx = in_grad.data<float>();
+  const double expected_dx[] = {0.14,   0.14,   0.1596, 0.1596,
+                                0.1623, 0.1623, 0.1627, 0.1627};
+  // ASSERT (not EXPECT) so a size mismatch stops before dx is over-read.
+  ASSERT_EQ(8u, in_grad.Size());
+  for (size_t i = 0; i < 8u; i++)
+    EXPECT_NEAR(expected_dx[i], dx[i], 0.0001) << "dx index " << i;
+
+  singa::Tensor dhx_grad = ret_back.first[1];
+  dhx_grad.ToHost();
+  const float *dhx = dhx_grad.data<float>();
+  ASSERT_EQ(2u, dhx_grad.Size());
+  EXPECT_NEAR(0.1627, dhx[0], 0.0001);
+  EXPECT_NEAR(0.1627, dhx[1], 0.0001);
+}
+#endif  // USE_CUDNN


[3/3] incubator-singa git commit: SINGA-218 Implementation for RNN CUDNN version

Posted by zh...@apache.org.
SINGA-218 Implementation for RNN CUDNN version

Add an example using the char-rnn model.
The model trained (with 2 stacks of LSTM) over Linux kernel source code
can generate source code with some meaningful patterns, e.g.,
indentation, comments, variable definitions and assignments.


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/dfc422e5
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/dfc422e5
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/dfc422e5

Branch: refs/heads/dev
Commit: dfc422e5b6229de5f598ee3f0226f1a0d082eb16
Parents: 8e0b108
Author: Wei Wang <wa...@gmail.com>
Authored: Mon Aug 8 17:41:26 2016 +0800
Committer: Wei Wang <wa...@gmail.com>
Committed: Wed Aug 10 00:47:10 2016 +0800

----------------------------------------------------------------------
 examples/char-rnn/README.md   |  31 ++++++
 examples/char-rnn/sample.py   | 102 ++++++++++++++++++
 examples/char-rnn/train.py    | 207 +++++++++++++++++++++++++++++++++++++
 src/core/tensor/tensor.cc     |   3 +-
 src/io/csv_encoder.cc         |   2 +-
 src/model/layer/cudnn_rnn.cc  |  34 +++---
 src/model/layer/rnn.cc        |  16 +--
 src/model/layer/rnn.h         |  30 ++++--
 src/proto/model.proto         |   2 +-
 src/python/singa/layer.py     |  71 ++++++++++++-
 src/python/swig/model_layer.i |  60 +++++++----
 test/singa/test_cudnn_rnn.cc  |  34 +++---
 12 files changed, 519 insertions(+), 73 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/dfc422e5/examples/char-rnn/README.md
----------------------------------------------------------------------
diff --git a/examples/char-rnn/README.md b/examples/char-rnn/README.md
new file mode 100644
index 0000000..c5cdbb8
--- /dev/null
+++ b/examples/char-rnn/README.md
@@ -0,0 +1,31 @@
+# Train Char-RNN using SINGA
+
+Recurrent neural networks (RNNs) are widely used for modelling sequential data,
+e.g., natural language sentences. This example describes how to implement an RNN
+application (or model) using SINGA's RNN layers.
+We will use the [char-rnn](https://github.com/karpathy/char-rnn) model as an
+example, which trains over sentences or
+source code, with each character as an input unit. In particular, we will train
+an RNN with stacked LSTM layers over Linux kernel source code. After training, we
+expect to generate meaningful code from the model.
+
+
+## Instructions
+
+* Compile and install SINGA. Currently the RNN implementation depends on cuDNN v5 (or newer).
+
+* Prepare the dataset. Download the [kernel source code](http://cs.stanford.edu/people/karpathy/char-rnn/).
+Other plain text files can also be used.
+
+* Start the training,
+
+    python train.py input_linux.txt
+
+  Some hyper-parameters can be set through the command line; list them with
+
+    python train.py -h
+
+
+* Sample characters from the model by providing the number of characters and (optionally) a seed string.
+
+    python sample.py 100 --seed '#include <std'

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/dfc422e5/examples/char-rnn/sample.py
----------------------------------------------------------------------
diff --git a/examples/char-rnn/sample.py b/examples/char-rnn/sample.py
new file mode 100644
index 0000000..a8fcb73
--- /dev/null
+++ b/examples/char-rnn/sample.py
@@ -0,0 +1,102 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+'''Sample characters from the pre-trained model'''
+import sys
+import os
+import cPickle as pickle
+import numpy as np
+import argparse
+
+sys.path.append(os.path.join(os.path.dirname(__file__), '../../build/python'))
+from singa import layer
+from singa import tensor
+from singa import device
+from singa.proto import model_pb2
+
+
+def sample(model_path, nsamples=100, seed_text='', do_sample=True):
+    with open(model_path, 'rb') as fd:
+        d=pickle.load(fd)
+        rnn_w = tensor.from_numpy(d['rnn_w'])
+        idx_to_char=d['idx_to_char']
+        char_to_idx=d['char_to_idx']
+        vocab_size = len(idx_to_char)
+        dense_w = tensor.from_numpy(d['dense_w'])
+        dense_b = tensor.from_numpy(d['dense_b'])
+        hidden_size = d['hidden_size']
+        num_stacks = d['num_stacks']
+        dropout = d['dropout']
+
+    cuda = device.create_cuda_gpu()
+    rnn = layer.LSTM(name='lstm', hidden_size=hidden_size,
+            num_stacks=num_stacks, dropout=dropout,
+            input_sample_shape=(len(idx_to_char),))
+    rnn.to_device(cuda)
+    rnn.param_values()[0].copy_data(rnn_w)
+    dense = layer.Dense('dense', vocab_size, input_sample_shape=(hidden_size,))
+    dense.to_device(cuda)
+    dense.param_values()[0].copy_data(dense_w)
+    dense.param_values()[1].copy_data(dense_b)
+    hx = tensor.Tensor((num_stacks, 1, hidden_size), cuda)
+    cx = tensor.Tensor((num_stacks, 1, hidden_size), cuda)
+    hx.set_value(0.0)
+    cx.set_value(0.0)
+    if len(seed_text) > 0:
+        for c in seed_text:
+            x = np.zeros((1, vocab_size), dtype=np.float32)
+            x[0, char_to_idx[c]] = 1
+            tx=tensor.from_numpy(x)
+            tx.to_device(cuda)
+            inputs=[tx, hx, cx]
+            outputs=rnn.forward(model_pb2.kEval, inputs)
+            y = dense.forward(model_pb2.kEval, outputs[0])
+            y = tensor.softmax(y)
+            hx = outputs[1]
+            cx = outputs[2]
+        sys.stdout.write(seed_text)
+    else:
+        y = tensor.Tensor((1, vocab_size), cuda)
+        y.set_value(1.0 / vocab_size)
+
+    for i in range(nsamples):
+        y.to_host()
+        prob = tensor.to_numpy(y)[0]
+        if do_sample:
+            cur=np.random.choice(vocab_size, 1, p=prob)[0]
+        else:
+            cur = np.argmax(prob)
+        sys.stdout.write(idx_to_char[cur])
+        x = np.zeros((1, vocab_size), dtype=np.float32)
+        x[0, cur] = 1
+        tx=tensor.from_numpy(x)
+        tx.to_device(cuda)
+        inputs=[tx, hx, cx]
+        outputs=rnn.forward(model_pb2.kEval, inputs)
+        y = dense.forward(model_pb2.kEval, outputs[0])
+        y = tensor.softmax(y)
+        hx = outputs[1]
+        cx = outputs[2]
+    print ''
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='sample chars from char-rnn')
+    parser.add_argument('--seed', help='seed text string which warms up the rnn'\
+            ' states for sampling', default='')
+    parser.add_argument('n', type=int, help='num of characters to sample')
+    args = parser.parse_args()
+    assert args.n > 0, 'n must > 0'
+    sample('model.bin', args.n, seed_text=args.seed)

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/dfc422e5/examples/char-rnn/train.py
----------------------------------------------------------------------
diff --git a/examples/char-rnn/train.py b/examples/char-rnn/train.py
new file mode 100644
index 0000000..22fdc82
--- /dev/null
+++ b/examples/char-rnn/train.py
@@ -0,0 +1,207 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+'''Train a Char-RNN model using plain text files.
+The model is created following https://github.com/karpathy/char-rnn
+The train file could be any text file,
+e.g., http://cs.stanford.edu/people/karpathy/char-rnn/
+'''
+import sys
+import os
+import cPickle as pickle
+import numpy as np
+import argparse
+
+sys.path.append(os.path.join(os.path.dirname(__file__), '../../build/python'))
+from singa import layer
+from singa import loss
+from singa import device
+from singa import tensor
+from singa import optimizer
+from singa import initializer
+from singa.proto import core_pb2
+from singa.proto import model_pb2
+from singa import utils
+
+
+class Data(object):
+    def __init__(self, fpath, batch_size=32, seq_length=100, train_ratio=0.8):
+        '''Data object for loading a plain text file.
+
+        Args:
+            fpath, path to the text file.
+            train_ratio, split the text file into train and test sets, where
+                train_ratio of the characters are in the train set.
+        '''
+        self.raw_data = open(fpath, 'r').read()  # read text file
+        chars = list(set(self.raw_data))
+        self.vocab_size = len(chars)
+        self.char_to_idx = {ch:i for i, ch in enumerate(chars)}
+        self.idx_to_char = {i:ch for i, ch in enumerate(chars)}
+        data = [self.char_to_idx[c] for c in self.raw_data]
+        # seq_length + 1 for the data + label
+        nsamples = len(data) / (1 + seq_length)
+        data = data[0:nsamples * (1 + seq_length)]
+        data = np.asarray(data, dtype=np.int32)
+        data = np.reshape(data, (-1, seq_length + 1))
+        # shuffle all sequences
+        np.random.shuffle(data)
+        self.train_dat = data[0:int(data.shape[0]*train_ratio)]
+        self.num_train_batch = self.train_dat.shape[0] / batch_size
+        self.val_dat = data[self.train_dat.shape[0]:]
+        self.num_test_batch = self.val_dat.shape[0] / batch_size
+        print 'train dat', self.train_dat.shape
+        print 'val dat', self.val_dat.shape
+
+
+def numpy2tensors(npx, npy, dev):
+    '''batch, seq, dim -- > seq, batch, dim'''
+    tmpx=np.swapaxes(npx, 0, 1)
+    tmpy=np.swapaxes(npy, 0, 1)
+    inputs=[]
+    labels=[]
+    for t in range(tmpx.shape[0]):
+        x = tensor.from_numpy(tmpx[t])
+        y = tensor.from_numpy(tmpy[t])
+        x.to_device(dev)
+        y.to_device(dev)
+        inputs.append(x)
+        labels.append(y)
+    return inputs, labels
+
+
+def convert(batch, batch_size, seq_length, vocab_size, dev):
+    '''convert a batch of data into a sequence of input tensors'''
+    y = batch[:, 1:]
+    x1 = batch[:, :seq_length]
+    x = np.zeros((batch_size, seq_length, vocab_size), dtype=np.float32)
+    for b in range(batch_size):
+        for t in range(seq_length):
+            c = x1[b, t]
+            x[b, t, c] = 1
+    return numpy2tensors(x, y, dev)
+
+
+def get_lr(epoch):
+    return 0.001 / float(1 << (epoch / 50))
+
+
+def train(data, max_epoch, hidden_size =100, seq_length=100, batch_size=16,
+        num_stacks=1, lr=0.001, dropout = 0.5, model_path='model.bin'):
+    # SGD with L2 gradient normalization
+    opt = optimizer.SGD(constraint=optimizer.L2Constraint(5))
+    cuda = device.create_cuda_gpu()
+    rnn = layer.LSTM(name='lstm', hidden_size=hidden_size, num_stacks=num_stacks,
+            dropout=dropout, input_sample_shape=(data.vocab_size,))
+    rnn.to_device(cuda)
+    print 'created rnn'
+    rnn_w = rnn.param_values()[0]
+    initializer.uniform(rnn_w, -0.08, 0.08)  # init all rnn parameters
+    print 'rnn weight l1 = %f' % (rnn_w.l1())
+    dense = layer.Dense('dense', data.vocab_size, input_sample_shape=(hidden_size,))
+    dense.to_device(cuda)
+    dense_w = dense.param_values()[0]
+    dense_b = dense.param_values()[1]
+    print 'dense w ', dense_w.shape
+    print 'dense b ', dense_b.shape
+    initializer.xavier(dense_w) # init weight matrix using Xavier
+    print 'dense weight l1 = %f' % (dense_w.l1())
+    dense_b.set_value(0.0)
+    print 'dense b l1 = %f' % (dense_b.l1())
+
+    g_dense_w = tensor.Tensor(dense_w.shape, cuda)
+    g_dense_b = tensor.Tensor(dense_b.shape, cuda)
+
+    lossfun = loss.SoftmaxCrossEntropy();
+    for epoch in range(max_epoch):
+        train_loss = 0
+        for b in range(data.num_train_batch):
+            batch = data.train_dat[b * batch_size: (b + 1) * batch_size]
+            inputs, labels = convert(batch, batch_size, seq_length,
+                    data.vocab_size, cuda)
+            inputs.append(tensor.Tensor())
+            inputs.append(tensor.Tensor())
+
+            outputs = rnn.forward(model_pb2.kTrain, inputs)[0:-2]
+            grads=[]
+            batch_loss = 0
+            g_dense_w.set_value(0.0)
+            g_dense_b.set_value(0.0)
+            for output, label in zip(outputs, labels):
+                act = dense.forward(model_pb2.kTrain, output)
+                lvalue = lossfun.forward(model_pb2.kTrain, act, label)
+                batch_loss += lvalue.l1()
+                grad = lossfun.backward()
+                grad, gwb = dense.backward(model_pb2.kTrain, grad)
+                grads.append(grad)
+                g_dense_w += gwb[0]
+                g_dense_b += gwb[1]
+                #print output.l1(), act.l1()
+            utils.update_progress(b * 1.0 / data.num_train_batch,
+                    'training loss = %f' % (batch_loss / seq_length))
+            train_loss += batch_loss
+
+            grads.append(tensor.Tensor())
+            grads.append(tensor.Tensor())
+            g_rnn_w=rnn.backward(model_pb2.kTrain, grads)[1][0]
+            dense_w, dense_b = dense.param_values()
+            opt.apply_with_lr(epoch, get_lr(epoch), g_rnn_w, rnn_w, 'rnnw')
+            opt.apply_with_lr(epoch, get_lr(epoch), g_dense_w, dense_w, 'dense_w')
+            opt.apply_with_lr(epoch, get_lr(epoch), g_dense_b, dense_b, 'dense_b')
+        print '\nEpoch %d, train loss is %f' % (epoch,
+                train_loss / data.num_train_batch / seq_length)
+        eval_loss = 0
+        for b in range(data.num_test_batch):
+            batch = data.val_dat[b * batch_size: (b + 1) * batch_size]
+            inputs, labels = convert(batch, batch_size, seq_length,
+                    data.vocab_size, cuda)
+            inputs.append(tensor.Tensor())
+            inputs.append(tensor.Tensor())
+            outputs = rnn.forward(model_pb2.kEval, inputs)[0:-2]
+            for output, label in zip(outputs, labels):
+                output = dense.forward(model_pb2.kEval, output)
+                eval_loss += lossfun.forward(model_pb2.kEval, output, label).l1()
+        print 'Epoch %d, evaluation loss is %f' % (epoch,
+                eval_loss / data.num_test_batch / seq_length)
+
+    # checkpoint the file model
+    with open(model_path, 'wb') as fd:
+        print 'saving model to %s' % model_path
+        d={}
+        for name, w in zip(['rnn_w', 'dense_w', 'dense_b'], [rnn_w, dense_w, dense_b]):
+            w.to_host()
+            d[name]=tensor.to_numpy(w)
+        d['idx_to_char']=data.idx_to_char
+        d['char_to_idx']=data.char_to_idx
+        d['hidden_size']=hidden_size
+        d['num_stacks']=num_stacks
+        d['dropout']=dropout
+
+        pickle.dump(d, fd)
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='Train multi-stack LSTM for '\
+            'modeling  character sequence from plain text files')
+    parser.add_argument('data', type=string, help='training file')
+    parser.add_argument('-b', type=int, default=32, help='batch_size')
+    parser.add_argument('-l', type=int, default=64, help='sequence length')
+    parser.add_argument('-d', type=int, default=128, help='hidden size')
+    parser.add_argument('-s', type=int, default=2, help='num of stacks')
+    parser.add_argument('-m', type=int, default=50, help='max num of epoch')
+    args = parser.parse_args()
+    data = Data(args.data, batch_size=args.b, seq_length=args.l)
+    train(data, args.m,  hidden_size=args.d, num_stacks=args.s,
+            seq_length=args.l, batch_size=args.b)

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/dfc422e5/src/core/tensor/tensor.cc
----------------------------------------------------------------------
diff --git a/src/core/tensor/tensor.cc b/src/core/tensor/tensor.cc
index bd3bc70..d2fec53 100644
--- a/src/core/tensor/tensor.cc
+++ b/src/core/tensor/tensor.cc
@@ -299,7 +299,8 @@ Tensor &Tensor::operator=(const Tensor &in) {
   shape_ = in.shape_;
   device_ = in.device_;
   block_ = in.block();
-  block_->IncRefCount();
+  if (block_ != nullptr)
+    block_->IncRefCount();
   return *this;
 }
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/dfc422e5/src/io/csv_encoder.cc
----------------------------------------------------------------------
diff --git a/src/io/csv_encoder.cc b/src/io/csv_encoder.cc
index 1b797a9..6089ab5 100644
--- a/src/io/csv_encoder.cc
+++ b/src/io/csv_encoder.cc
@@ -22,7 +22,7 @@
 namespace singa {
 
 std::string CSVEncoder::Encode(vector<Tensor>& data) {
-  CHECK_GE(data.size(), 1);
+  CHECK_GE(data.size(), 1u);
   size_t size = data[0].Size();
   const float* value = data[0].data<float>();
   std::string des = "";

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/dfc422e5/src/model/layer/cudnn_rnn.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/cudnn_rnn.cc b/src/model/layer/cudnn_rnn.cc
index 242a342..896c1e9 100644
--- a/src/model/layer/cudnn_rnn.cc
+++ b/src/model/layer/cudnn_rnn.cc
@@ -17,6 +17,7 @@
  */
 #include "./cudnn_rnn.h"
 #ifdef USE_CUDNN
+#if CUDNN_VERSION_MAJOR >= 5
 #include <cudnn.h>
 #include <chrono>
 #include "./cudnn_utils.h"
@@ -92,7 +93,7 @@ void CudnnRNN::UpdateIODescriptors(size_t len, const vector<Tensor> &inputs) {
   }
 
   for (size_t i = 0; i < len; i++) {
-    CHECK_EQ(inputs[i].shape(1), input_dim_);
+    CHECK_EQ(inputs[i].shape(1), input_size_);
     if (inputs[i].shape(0) != batch_size_ || reset) {
       int d[3] = {1, 1, 1}, s[3] = {1, 1, 1};
       d[0] = static_cast<int>(inputs[i].shape(0));
@@ -104,7 +105,7 @@ void CudnnRNN::UpdateIODescriptors(size_t len, const vector<Tensor> &inputs) {
       CUDNN_CHECK(cudnnSetTensorNdDescriptor(dx_descs_[i], dtype_, 3, d, s));
 
       d[0] = static_cast<int>(inputs[i].shape(0));
-      d[1] = static_cast<int>(hidden_dim_ * num_directions_);
+      d[1] = static_cast<int>(hidden_size_ * num_directions_);
       s[0] = d[1] * d[2];
       s[1] = d[2];
       CUDNN_CHECK(cudnnSetTensorNdDescriptor(y_descs_[i], dtype_, 3, d, s));
@@ -121,7 +122,7 @@ void CudnnRNN::SetRNNDescriptor(shared_ptr<Device> dev) {
   CUDNN_CHECK(cudnnDropoutGetStatesSize(ctx->cudnn_handle, &state_size));
   dropout_state_ = Tensor(Shape{state_size}, dev, kChar);
   CUDNN_CHECK(cudnnSetDropoutDescriptor(
-      dropout_desc_, ctx->cudnn_handle, dropout_,
+      dropout_desc_, ctx->cudnn_handle, 1 - dropout_,  // keep probability
       dropout_state_.block()->mutable_data(), state_size, seed_));
 
   CUDNN_CHECK(cudnnCreateRNNDescriptor(&rnn_desc_));
@@ -146,7 +147,7 @@ void CudnnRNN::SetRNNDescriptor(shared_ptr<Device> dev) {
     rnn_mode = CUDNN_LSTM;
   else if (rnn_mode_ == "gru")
     rnn_mode = CUDNN_GRU;
-  CUDNN_CHECK(cudnnSetRNNDescriptor(rnn_desc_, hidden_dim_, num_stacks_,
+  CUDNN_CHECK(cudnnSetRNNDescriptor(rnn_desc_, hidden_size_, num_stacks_,
                                     dropout_desc_, input_mode, direction,
                                     rnn_mode, dtype_));
 
@@ -176,7 +177,7 @@ void CudnnRNN::ResetHiddenAndCellDescriptors(size_t batch_size) {
   int dim[3] = {1, 1, 1};
   dim[0] = static_cast<int>(num_stacks_ * num_directions_);
   dim[1] = static_cast<int>(batch_size);
-  dim[2] = static_cast<int>(hidden_dim_);
+  dim[2] = static_cast<int>(hidden_size_);
   int stride[3] = {1, 1, 1};
   stride[0] = dim[1] * dim[2];
   stride[1] = dim[2];
@@ -238,7 +239,7 @@ vector<Tensor> CudnnRNN::SplitOutput(size_t num, size_t dim,
                                      const Tensor output) {
   vector<Tensor> outputs;
   if (num == 1) {
-    outputs.push_back(output);
+    outputs.push_back(Reshape(output, Shape{in.at(0).shape(0), dim}));
   } else {
     for (size_t i = 0, offset = 0; offset < output.Size(); i++) {
       Shape s{in.at(i).shape(0), dim};
@@ -261,7 +262,7 @@ const vector<Tensor> CudnnRNN::Forward(int flag, const vector<Tensor> &inputs) {
   CHECK_GT(inputs.size(), 1u + has_cell_);
   size_t num_x = inputs.size() - has_cell_ - 1;
   Tensor input = MergeInputs(num_x, inputs);
-  LOG(INFO) << "input size " << input.Size() << " value " << input.L1();
+  // LOG(INFO) << "input size " << input.Size() << " value " << input.L1();
 
   if (rnn_desc_ != nullptr)
     CHECK_EQ(dtype_, GetCudnnDataType(dtype))
@@ -273,11 +274,11 @@ const vector<Tensor> CudnnRNN::Forward(int flag, const vector<Tensor> &inputs) {
   UpdateStates(num_x, inputs);
   // CheckFowardShapes();
 
-  Shape outshape{input.Size() * hidden_dim_ / input_dim_ * num_directions_};
+  Shape outshape{input.Size() * hidden_size_ / input_size_ * num_directions_};
   Tensor output(outshape, dev, dtype);
-  LOG(INFO) << "output size " << output.Size();
+  // LOG(INFO) << "output size " << output.Size();
   Tensor hx = inputs.at(num_x);
-  Shape state_shape{num_stacks_ * num_directions_, batch_size_, hidden_dim_};
+  Shape state_shape{num_stacks_ * num_directions_, batch_size_, hidden_size_};
   Tensor hy(state_shape, dev, dtype);
   Tensor cy, cx;
   if (has_cell_) {
@@ -285,8 +286,8 @@ const vector<Tensor> CudnnRNN::Forward(int flag, const vector<Tensor> &inputs) {
     cy.ResetLike(hy);
   }
 
-  LOG(INFO) << "hidden size " << hy.Size();
-  LOG(INFO) << "weight size " << weight_.Size() << " value " << weight_.L1();
+  // LOG(INFO) << "hidden size " << hy.Size();
+  // LOG(INFO) << "weight size " << weight_.Size() << " value " << weight_.L1();
   Block *inb = input.block(), *outb = output.block(),
         *wb = this->weight_.block(), *hxb = hx.block(), *cxb = cx.block(),
         *hyb = hy.block(), *cyb = cy.block(),
@@ -336,7 +337,7 @@ const vector<Tensor> CudnnRNN::Forward(int flag, const vector<Tensor> &inputs) {
     }, {inb, wb, hxb, cxb}, {outb, hyb, cyb, wspace});
   }
   auto outputs =
-      SplitOutput(num_x, hidden_dim_ * num_directions_, inputs, output);
+      SplitOutput(num_x, hidden_size_ * num_directions_, inputs, output);
   outputs.push_back(hy);
   if (has_cell_) outputs.push_back(cy);
   return outputs;
@@ -368,10 +369,10 @@ const std::pair<vector<Tensor>, vector<Tensor>> CudnnRNN::Backward(
   if (has_cell_)
     dcy = grads.at(num_dy + 1);
 
-  Shape xshape{y.Size() * input_dim_ / hidden_dim_ / num_directions_};
+  Shape xshape{y.Size() * input_size_ / hidden_size_ / num_directions_};
   Tensor dx(xshape, dev, dtype);
   Tensor dw(weight_.shape(), dev, dtype);
-  Shape state_shape{num_stacks_ * num_directions_, batch_size_, hidden_dim_};
+  Shape state_shape{num_stacks_ * num_directions_, batch_size_, hidden_size_};
   Tensor dhx(state_shape, dev, dtype);
   Tensor dcx;
   if (has_cell_)
@@ -419,7 +420,7 @@ const std::pair<vector<Tensor>, vector<Tensor>> CudnnRNN::Backward(
       {dxb, dwb, dhxb, dcxb, wspace, rspace});
 
   vector <Tensor> param_grad{dw};
-  auto data_grads = SplitOutput(num_dy, input_dim_, grads, dx);
+  auto data_grads = SplitOutput(num_dy, input_size_, grads, dx);
   data_grads.push_back(dhx);
   if (has_cell_)
     data_grads.push_back(dcx);
@@ -427,4 +428,5 @@ const std::pair<vector<Tensor>, vector<Tensor>> CudnnRNN::Backward(
 }
 
 }  // namespace singa
+#endif  // CUDNN_VERSION_MAJOR >= 5
 #endif  // USE_CUDNN

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/dfc422e5/src/model/layer/rnn.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/rnn.cc b/src/model/layer/rnn.cc
index 6b831a7..424c20b 100644
--- a/src/model/layer/rnn.cc
+++ b/src/model/layer/rnn.cc
@@ -27,13 +27,13 @@ void RNN::Setup(const Shape& in_sample, const LayerConf &conf) {
   Layer::Setup(in_sample, conf);
 
   RNNConf rnn_conf = conf.rnn_conf();
-  hidden_dim_ = rnn_conf.hidden_dim();
-  CHECK_GT(hidden_dim_, 0u);
+  hidden_size_ = rnn_conf.hidden_size();
+  CHECK_GT(hidden_size_, 0u);
   num_stacks_ = rnn_conf.num_stacks();
   CHECK_GT(num_stacks_, 0u);
-  input_dim_ = Product(in_sample);
-  CHECK_GT(input_dim_, 0u);
-  dropout_ = rnn_conf.dropout();
+  input_size_ = Product(in_sample);
+  CHECK_GT(input_size_, 0u);
+  dropout_ = rnn_conf.dropout();  // drop probability
   CHECK_GE(dropout_, 0);
 
   input_mode_ = ToLowerCase(rnn_conf.input_mode());
@@ -71,9 +71,9 @@ void RNN::Setup(const Shape& in_sample, const LayerConf &conf) {
 
   size_t weight_size = 0;
   for (size_t i = 0; i < num_stacks_; i++) {
-    size_t dim = hidden_dim_ * (in_sample[0] +  hidden_dim_ + 2);
+    size_t dim = hidden_size_ * (in_sample[0] +  hidden_size_ + 2);
     if (i > 0)
-      dim = hidden_dim_ * (hidden_dim_ +  hidden_dim_ + 2);
+      dim = hidden_size_ * (hidden_size_ +  hidden_size_ + 2);
     weight_size += mult * dim;
   }
   weight_.Reshape(Shape{weight_size});
@@ -81,6 +81,7 @@ void RNN::Setup(const Shape& in_sample, const LayerConf &conf) {
 
 const vector<Tensor> RNN::Forward(int flag, const vector<Tensor>& inputs) {
   vector<Tensor> data_output;
+  LOG(FATAL) << "CPU RNN is not implemented!";  // base RNN has no CPU path; CudnnRNN overrides Forward
   return data_output;
 }
 
@@ -88,6 +89,7 @@ const std::pair<vector<Tensor>, vector<Tensor>> RNN::Backward(int flag,
     const vector<Tensor>& grads) {
   vector<Tensor> param_grad;
   vector<Tensor> data_grad;
+  LOG(FATAL) << "CPU RNN is not implemented!";
   return std::make_pair(data_grad, param_grad);
 }
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/dfc422e5/src/model/layer/rnn.h
----------------------------------------------------------------------
diff --git a/src/model/layer/rnn.h b/src/model/layer/rnn.h
index 3750021..1b5dad7 100644
--- a/src/model/layer/rnn.h
+++ b/src/model/layer/rnn.h
@@ -37,20 +37,32 @@ class RNN : public Layer {
   /// \copydoc Layer::layer_type()
   const std::string layer_type() const override { return "RNN"; }
 
-  /// \copydoc Layer::Setup(const LayerConf&);
-  void Setup(const vector<size_t>& in_shape, const LayerConf& conf) override;
+  /// Setup the RNN layer.
+  /// in_shape is the shape of a single training instance from one timestep.
+  void Setup(const Shape& in_shape, const LayerConf& conf) override;
 
-  /// \copydoc Layer::Forward(int flag, const vector<Tensor>&)
+  /// The inputs vector includes <x1, ... xn, hx, cx> where xi is the input
+  /// tensor at the i-th time step. hx is used to initialize the hidden tensor,
+  /// which could be a dummy tensor (like Tensor hx;). cx is used to initialize
+  /// the cell tensor, which could be a dummy tensor (like Tensor cx;). For
+  /// dummy tensors, 0's would be used during computation.
+  /// cx is missing for gru/relu/tanh RNNs, and is valid for lstm.
+  /// The dim order of xi is <batch, feature>, and the batchsize of xi must be
+  /// >= that of x(i+1).
+  /// The output vector includes <y1, ... yn, hy, cy> where yi is the output
+  /// tensor at the i-th time step. hy is the final hidden tensor, cy is the
+  /// final cell tensor. cy is missing for gru/relu/tanh RNNs and is valid for
+  /// lstm.
   const vector<Tensor> Forward(int flag, const vector<Tensor>& inputs) override;
 
-  /// \copydoc Layer::Backward(int, const vector<Tensor>&);
+  /// The grads vector includes <dy1, dy2, ... dyn, dhy, dcy>, the symbols are
+  /// similar to those for Forward. dcy is missing for gru/relu/tanh RNNs and is
+  /// valid for lstm.
+  /// The first vector of the output includes <dx1, dx2, ... dxn, dhx, dcx>.
+  /// The second vector of the output includes the gradients of all parameters.
   const std::pair<vector<Tensor>, vector<Tensor>> Backward(
       int flag, const vector<Tensor>& grads) override;
 
-  void set_weight(Tensor w) {
-    weight_.ResetLike(w);
-    weight_.CopyData(w);
-  }
   const vector<Tensor> param_values() override {
     return vector<Tensor>{weight_};
   }
@@ -73,7 +85,7 @@ class RNN : public Layer {
   std::stack<Tensor> buf_;
   bool has_cell_ = false;
   size_t num_directions_ = 1;
-  size_t input_dim_ = 0, hidden_dim_ = 0, num_stacks_ = 0, seq_length_ = 0;
+  size_t input_size_ = 0, hidden_size_ = 0, num_stacks_ = 0, seq_length_ = 0;
   size_t batch_size_ = 0;
   size_t seed_ = 0x1234567;
   float dropout_ = 0.0f;

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/dfc422e5/src/proto/model.proto
----------------------------------------------------------------------
diff --git a/src/proto/model.proto b/src/proto/model.proto
index 31ebfc3..6923820 100644
--- a/src/proto/model.proto
+++ b/src/proto/model.proto
@@ -393,7 +393,7 @@ message ConvolutionConf {
 }
 
 message RNNConf {
-  optional uint32 hidden_dim = 1; // The number of hiddensize
+  optional uint32 hidden_size = 1; // The hidden feature size
   optional uint32 num_stacks = 2; // The number of stacked RNN layers
   optional float dropout = 3 [default = 0];
   optional bool remember_state = 4 [default = false];

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/dfc422e5/src/python/singa/layer.py
----------------------------------------------------------------------
diff --git a/src/python/singa/layer.py b/src/python/singa/layer.py
index a443e1a..a87eb10 100644
--- a/src/python/singa/layer.py
+++ b/src/python/singa/layer.py
@@ -403,7 +403,7 @@ class Dense(Layer):
         if W_specs is None:
             W_specs = {'init': 'xavier'}
         if b_specs is None:
-            b_specs = {'init': 'constant'}
+            b_specs = {'init': 'constant', 'value': 0}
         if 'name' not in W_specs:
             W_specs['name'] = name + '_weight'
         if 'name' not in b_specs:
@@ -502,6 +502,71 @@ class Flatten(Layer):
             self.setup(input_sample_shape)
 
 
+class RNN(Layer):
+    def __init__(self, name, hidden_size, rnn_mode='lstm', engine='cudnn',
+            dropout=0.0, num_stacks=1, input_mode='linear', bidirectional=False,
+            param_specs=None, input_sample_shape=None):
+        super(RNN, self).__init__(name)
+        conf = self.conf.rnn_conf
+        assert hidden_size > 0, 'Hidden feature size must > 0'
+        conf.hidden_size = hidden_size
+        assert rnn_mode in Set(['lstm', 'gru', 'tanh', 'relu']), \
+                'rnn mode %s is not available' %s (rnn_mode)
+        conf.rnn_mode = rnn_mode
+        conf.num_stacks = num_stacks
+        conf.dropout = dropout
+        conf.input_mode = input_mode
+        conf.direction = 'unidirectional'
+        if bidirectional:
+            conf.direction = 'bidirectional'
+        _check_engine(engine, ['cudnn'])
+        if param_specs is None:
+            param_specs = {'name': name + '-weight',
+                    'init': 'uniform', 'low':0, 'high':1};
+        self.conf.param.extend([_construct_param_specs_from_dict(param_specs)])
+        self.param_specs.append(_construct_param_specs_from_dict(param_specs))
+
+        self.layer = singa_wrap.CudnnRNN()
+        if input_sample_shape is not None:
+            self.setup(input_sample_shape)
+
+    def forward(self, flag, inputs):
+        assert self.has_setup, 'Must call setup() before forward()'
+        assert len(inputs) > 1, 'The input to RNN must include at '\
+                'least one input tensor '\
+                'and one hidden state tensor (could be a dummy tensor)'
+        tensors = []
+        for t in inputs:
+            assert isinstance(t, tensor.Tensor), 'input must be py Tensor %s' % (type(t))
+            tensors.append(t.singa_tensor)
+        y = self.layer.Forward(flag, tensors)
+        return tensor.from_raw_tensors(y)
+
+    def backward(self, flag, grad):
+        tensors = []
+        for t in grad:
+            assert isinstance(t, tensor.Tensor), 'grad must be py Tensor'
+            tensors.append(t.singa_tensor)
+        ret = self.layer.Backward(flag, tensors)
+        return tensor.from_raw_tensors(ret[0]), tensor.from_raw_tensors(ret[1])
+
+class LSTM(RNN):
+    def __init__(self, name, hidden_size, engine='cudnn',
+            dropout=0.0, num_stacks=1, input_mode='linear', bidirectional=False,
+            param_specs=None, input_sample_shape=None):
+        super(LSTM, self).__init__(name, hidden_size,  'lstm', engine, dropout,
+                num_stacks, input_mode, bidirectional, param_specs,
+                input_sample_shape)
+
+class GRU(RNN):
+    def __init__(self, name, hidden_size, engine='cudnn',
+            dropout=0.0, num_stacks=1, input_mode='linear', bidirectional=False,
+            param_specs=None, input_sample_shape=None):
+        super(GRU, self).__init__(name,  hidden_size, 'gru', engine, dropout,
+                num_stacks, input_mode, bidirectional, param_specs,
+                input_sample_shape)
+
+
 def _check_engine(engine, allowed_engines):
     assert engine.lower() in Set(allowed_engines), \
            '%s is not a supported engine. Pls use one of %s' % \
@@ -585,8 +650,8 @@ def _construct_param_specs_from_dict(specs):
         if specs['init'].lower() == 'uniform':
             assert 'low' in specs and 'high' in specs, \
                 'low and high are required for "uniform" init method'
-            filler.low = specs['low']
-            filler.high = specs['high']
+            filler.min = specs['low']
+            filler.max = specs['high']
         elif specs['init'].lower() == 'gaussian':
             assert 'mean' in specs and 'std' in specs, \
                 'std and mean are required for "gaussian" init method'

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/dfc422e5/src/python/swig/model_layer.i
----------------------------------------------------------------------
diff --git a/src/python/swig/model_layer.i b/src/python/swig/model_layer.i
index 873ebc9..9d39301 100644
--- a/src/python/swig/model_layer.i
+++ b/src/python/swig/model_layer.i
@@ -30,6 +30,8 @@
 
 %{
 #include "singa/model/layer.h"
+#include "../src/model/layer/rnn.h"
+#include "../src/model/layer/cudnn_rnn.h"
 #include "singa/core/tensor.h"
 #include "singa/proto/model.pb.h"
 using singa::Tensor;
@@ -40,6 +42,8 @@ using singa::LayerConf;
 %}
 
 %shared_ptr(singa::Layer)
+%shared_ptr(singa::RNN)
+%shared_ptr(singa::CudnnRNN)
 
 namespace std {
   %template(strVector) vector<string>;
@@ -52,26 +56,44 @@ namespace std {
 
 namespace singa {
 
-  class Layer {
-    public:
-      Layer();
+class Layer {
+  public:
+    Layer();
 //      virtual void Setup(const std::vector<vector<size_t>>&, const string&);
-      virtual void Setup(const std::vector<size_t>& in_sample_shape,
-                         const std::string& proto_str);
-      const std::vector<Tensor> param_values();
-      virtual const std::vector<size_t> GetOutputSampleShape() const;
-      virtual void ToDevice(std::shared_ptr<Device> device);
-      virtual void AsType(DataType dtype);
-      virtual const Tensor Forward(int flag, const Tensor& input);
-      virtual const std::vector<Tensor> Forward(
-          int flag, const std::vector<Tensor>& inputs);
-      virtual const std::pair<Tensor, std::vector<Tensor>> Backward(
-          int flag, const Tensor& grad);
-      virtual const std::pair<std::vector<Tensor>, std::vector<Tensor>>
-      Backward(int flag, const vector<Tensor>& grads);
+    void Setup(const std::vector<size_t>& in_sample_shape,
+                        const std::string& proto_str);
+    virtual const std::vector<Tensor> param_values();
+    virtual const std::vector<size_t> GetOutputSampleShape() const;
+    virtual void ToDevice(std::shared_ptr<Device> device);
+    virtual void AsType(DataType dtype);
+    virtual const Tensor Forward(int flag, const Tensor& input);
+    virtual const std::vector<Tensor> Forward(
+        int flag, const std::vector<Tensor>& inputs);
+    virtual const std::pair<Tensor, std::vector<Tensor>> Backward(
+        int flag, const Tensor& grad);
+    virtual const std::pair<std::vector<Tensor>, std::vector<Tensor>>
+    Backward(int flag, const vector<Tensor>& grads);
+};
+
+std::shared_ptr<Layer> CreateLayer(const std::string& type);
+const std::vector<std::string> GetRegisteredLayers();
+class RNN : public Layer {
+  /*
+ public:
+  void Setup(const std::vector<size_t>& in_sample_shape,
+                        const std::string& proto_str) override;
+                        */
+};
+class CudnnRNN : public RNN {
+ public:
+ // note: Must use std::vector instead of vector.
+  const std::vector<Tensor> Forward(int flag, const std::vector<Tensor>& inputs) override;
+  const std::pair<std::vector<Tensor>, std::vector<Tensor>> Backward(
+      int flag, const std::vector<Tensor>& grads) override;
+  void ToDevice(std::shared_ptr<Device> device) override;
+    const std::vector<Tensor> param_values() override;
+    const std::vector<size_t> GetOutputSampleShape() const override;
+};
 
-  };
-  std::shared_ptr<Layer> CreateLayer(const std::string& type);
-  const std::vector<std::string> GetRegisteredLayers();
 }
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/dfc422e5/test/singa/test_cudnn_rnn.cc
----------------------------------------------------------------------
diff --git a/test/singa/test_cudnn_rnn.cc b/test/singa/test_cudnn_rnn.cc
index ebbf0aa..e0de02e 100644
--- a/test/singa/test_cudnn_rnn.cc
+++ b/test/singa/test_cudnn_rnn.cc
@@ -21,6 +21,7 @@
 
 #include "../src/model/layer/cudnn_rnn.h"
 #ifdef USE_CUDNN
+#if CUDNN_VERSION_MAJOR >= 5
 
 #include "gtest/gtest.h"
 
@@ -31,15 +32,15 @@ class TestCudnnRNN : public ::testing::Test {
   protected:
     virtual void SetUp() {
       singa::RNNConf *rnnconf = conf.mutable_rnn_conf();
-      rnnconf->set_hidden_dim(hidden_dim);
+      rnnconf->set_hidden_size(hidden_size);
       rnnconf->set_num_stacks(1);
-      rnnconf->set_dropout(1);
+      rnnconf->set_dropout(0);
       rnnconf->set_input_mode("linear");
       rnnconf->set_direction("unidirectional");
       rnnconf->set_rnn_mode("tanh");
     }
     singa::LayerConf conf;
-    size_t hidden_dim = 4;
+    size_t hidden_size = 4;
 };
 
 TEST_F(TestCudnnRNN, Setup) {
@@ -47,7 +48,7 @@ TEST_F(TestCudnnRNN, Setup) {
   EXPECT_EQ("CudnnRNN", rnn.layer_type());
   rnn.Setup(Shape{2}, conf);
   auto weight = rnn.param_values().at(0);
-  EXPECT_EQ(weight.Size(), hidden_dim * (2 + hidden_dim + 2));
+  EXPECT_EQ(weight.Size(), hidden_size * (2 + hidden_size + 2));
 }
 
 TEST_F(TestCudnnRNN, Forward) {
@@ -80,19 +81,19 @@ TEST_F(TestCudnnRNN, Forward) {
 
   const auto ret = rnn.Forward(singa::kEval, inputs);
   EXPECT_EQ(ret.size(), seqLength + 1);
-  vector<float> hxptr(hidden_dim, 0.0f);
+  vector<float> hxptr(hidden_size, 0.0f);
   for (size_t i = 0; i < seqLength; i++) {
     auto y = ret[i];
     y.ToHost();
     auto yptr = y.data<float>();
     vector<float> tmp;
-    for (size_t j = 0; j < hidden_dim; j++) {
+    for (size_t j = 0; j < hidden_size; j++) {
       float ty = 0;
       for (size_t k = 0; k < dim; k++) {
         ty += x[i * dim + k] * wvalue;
       }
       ty += wvalue;
-      for (size_t k = 0; k < hidden_dim; k++) {
+      for (size_t k = 0; k < hidden_size; k++) {
         ty += hxptr[k] * wvalue;
       }
       ty += wvalue;
@@ -134,18 +135,18 @@ TEST_F(TestCudnnRNN, Backward) {
 
   const auto outs = rnn.Forward(singa::kTrain, inputs);
 
-  float dyptr[seqLength * batchsize * hidden_dim];
-  for (size_t i = 0; i < seqLength * batchsize * hidden_dim; i++)
+  float dyptr[seqLength * batchsize * hidden_size];
+  for (size_t i = 0; i < seqLength * batchsize * hidden_size; i++)
     dyptr[i] = i * 0.1f;
   vector<Tensor> grads;
   for (size_t i = 0; i < seqLength; i++) {
-    Tensor dy(Shape{batchsize, hidden_dim}, cuda);
+    Tensor dy(Shape{batchsize, hidden_size}, cuda);
     dy.CopyDataFromHostPtr(dyptr + i * dy.Size(), dy.Size());
     grads.push_back(dy);
   }
   Tensor dhy;
   grads.push_back(dhy);
-  vector<float> dhyptr(hidden_dim, 0.0f);
+  vector<float> dhyptr(hidden_size, 0.0f);
   const auto ret = rnn.Backward(singa::kTrain, grads);
   for (size_t i = seqLength - 1; i > 0 ; i --) {
     auto dx = ret.first[i];
@@ -154,21 +155,21 @@ TEST_F(TestCudnnRNN, Backward) {
     dx.ToHost();
     auto dxptr = dx.data<float>();
     auto yptr = y.data<float>();
-    for (size_t j = 0; j < hidden_dim; j++) {
-      dhyptr[j] += dyptr[i * hidden_dim + j];
+    for (size_t j = 0; j < hidden_size; j++) {
+      dhyptr[j] += dyptr[i * hidden_size + j];
       dhyptr[j] *= 1 - yptr[j] * yptr[j];
     }
     for (size_t k = 0; k < dim; k++) {
       float tdx = 0;
-      for (size_t j = 0; j < hidden_dim; j++) {
+      for (size_t j = 0; j < hidden_size; j++) {
         tdx += dhyptr[j] * wvalue;
       }
       EXPECT_NEAR(tdx, dxptr[k], 1e-4);
     }
     vector<float> tmp;
-    for (size_t k = 0; k < hidden_dim; k++) {
+    for (size_t k = 0; k < hidden_size; k++) {
       float tdhy = 0;
-      for (size_t j = 0; j < hidden_dim; j++) {
+      for (size_t j = 0; j < hidden_size; j++) {
         tdhy += dhyptr[j] * wvalue;
       }
       tmp.push_back(tdhy);
@@ -176,4 +177,5 @@ TEST_F(TestCudnnRNN, Backward) {
     std::copy(tmp.begin(), tmp.end(), dhyptr.begin());
   }
 }
+#endif  // CUDNN_VERSION_MAJOR >= 5
 #endif  // USE_CUDNN