You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@singa.apache.org by wa...@apache.org on 2016/06/24 06:51:39 UTC

[6/6] incubator-singa git commit: Merge PR #165 for CnMeM

Merge PR #165 for CnMeM

Fix bugs caused by changing the device type (Device* -> std::shared_ptr<Device>).


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/dd08f413
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/dd08f413
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/dd08f413

Branch: refs/heads/dev
Commit: dd08f413015878365fed32e579c1b7f4ecc81270
Parents: 5651383 9abd791
Author: Wei Wang <wa...@comp.nus.edu.sg>
Authored: Fri Jun 24 13:41:02 2016 +0800
Committer: Wei Wang <wa...@comp.nus.edu.sg>
Committed: Fri Jun 24 14:49:56 2016 +0800

----------------------------------------------------------------------
 .gitignore                              |   1 +
 CMakeLists.txt                          |   9 +-
 cmake/Dependencies.cmake                |  12 +
 cmake/Templates/singa_config.h.in       |   2 +
 cmake/Utils.cmake                       |  15 +
 include/singa/core/common.h             |  32 +-
 include/singa/core/device.h             |  18 +-
 include/singa/core/memory.h             |  62 +-
 include/singa/core/tensor.h             | 464 +++++++-------
 include/singa/io/decoder.h              |  56 ++
 include/singa/io/encoder.h              |  61 ++
 include/singa/io/reader.h               |  99 +++
 include/singa/io/writer.h               | 112 ++++
 include/singa/model/initializer.h       | 105 ++++
 include/singa/model/layer.h             |  48 +-
 include/singa/model/loss.h              |  47 ++
 include/singa/model/optimizer.h         |  59 +-
 include/singa/utils/channel.h           |  85 +++
 include/singa/utils/timer.h             |   2 +-
 src/CMakeLists.txt                      |  18 +
 src/core/device/cpp_cpu.cc              |   2 +-
 src/core/device/cuda_gpu.cc             |  88 ++-
 src/core/device/device.cc               |  24 +-
 src/core/memory/memory.cc               |  83 +--
 src/core/tensor/math_kernel.cu          | 682 +++++++++++---------
 src/core/tensor/math_kernel.h           |  98 +--
 src/core/tensor/tensor.cc               | 896 ++++++++++++++++-----------
 src/core/tensor/tensor_math.h           | 418 +++++++------
 src/core/tensor/tensor_math_cpp.h       | 629 ++++++++++++++-----
 src/core/tensor/tensor_math_cuda.h      | 429 ++++++++++---
 src/io/binfile_reader.cc                | 113 ++++
 src/io/binfile_writer.cc                | 136 ++++
 src/io/jpg2proto_encoder.cc             |  83 +++
 src/io/proto2jpg_decoder.cc             |  75 +++
 src/model/layer/activation.cc           |  27 +-
 src/model/layer/activation.h            |   7 +-
 src/model/layer/batchnorm.cc            |  11 +-
 src/model/layer/batchnorm.h             |  12 +-
 src/model/layer/convolution.cc          |  13 +-
 src/model/layer/convolution.h           |   7 +-
 src/model/layer/cudnn_activation.cc     |  33 +-
 src/model/layer/cudnn_activation.h      |  11 +-
 src/model/layer/cudnn_batchnorm.cc      | 132 ++--
 src/model/layer/cudnn_batchnorm.h       |  40 +-
 src/model/layer/cudnn_convolution.cc    | 114 ++--
 src/model/layer/cudnn_convolution.h     |   4 +-
 src/model/layer/cudnn_dropout.cc        |  52 +-
 src/model/layer/cudnn_dropout.h         |   4 +-
 src/model/layer/cudnn_lrn.cc            |  78 +--
 src/model/layer/cudnn_lrn.h             |  32 +-
 src/model/layer/cudnn_pooling.cc        |  48 +-
 src/model/layer/cudnn_pooling.h         |   4 +-
 src/model/layer/cudnn_softmax.cc        |  62 +-
 src/model/layer/cudnn_softmax.h         |  11 +-
 src/model/layer/dense.cc                |   7 +-
 src/model/layer/dense.h                 |   6 +-
 src/model/layer/dropout.cc              |   5 +-
 src/model/layer/dropout.h               |   7 +-
 src/model/layer/flatten.cc              |  53 ++
 src/model/layer/flatten.h               |  56 ++
 src/model/layer/lrn.cc                  |   5 +-
 src/model/layer/lrn.h                   |  13 +-
 src/model/layer/pooling.cc              |  13 +-
 src/model/layer/pooling.h               |   8 +-
 src/model/layer/prelu.cc                | 145 +++++
 src/model/layer/prelu.h                 |  66 ++
 src/model/layer/softmax.cc              |  34 +-
 src/model/layer/softmax.h               |  11 +-
 src/model/loss/mse.cc                   |  41 ++
 src/model/loss/mse.h                    |  66 --
 src/model/loss/softmax_cross_entropy.cc |  53 ++
 src/model/metric/accuracy.h             |   5 +-
 src/model/optimizer/adagrad.cc          |  41 ++
 src/model/optimizer/nesterov.cc         |  49 ++
 src/model/optimizer/optimizer.cc        |   2 +-
 src/model/optimizer/rmsprop.cc          |  45 ++
 src/model/optimizer/sgd.cc              |  10 +-
 src/proto/core.proto                    |   7 +-
 src/proto/io.proto                      |  37 ++
 src/proto/model.proto                   |  26 +-
 src/python/device.py                    |  82 +++
 src/python/example_layer.py             |  25 +
 src/python/layer.py                     |  78 ++-
 src/python/swig/core_device.i           |  60 ++
 src/python/swig/core_tensor.i           | 263 ++++++++
 src/python/swig/model_layer.i           |  83 +++
 src/python/swig/singa.i                 |  27 +
 src/python/tensor.py                    | 370 +++++++++++
 src/utils/channel.cc                    | 104 ++++
 test/CMakeLists.txt                     |   3 +-
 test/python/example_test_device.py      |  36 ++
 test/python/example_test_tensor.py      | 179 ++++++
 test/python/unittest_python.py          | 139 +++++
 test/singa/test_activation.cc           |  13 +-
 test/singa/test_adagrad.cc              |  96 +++
 test/singa/test_binfile_rw.cc           |  95 +++
 test/singa/test_channel.cc              |  39 ++
 test/singa/test_cpp_cpu.cc              |  16 +-
 test/singa/test_cross_entropy.cc        | 116 ++++
 test/singa/test_cudnn_activation.cc     |  36 +-
 test/singa/test_cudnn_batchnorm.cc      |  59 +-
 test/singa/test_cudnn_convolution.cc    | 105 ++--
 test/singa/test_cudnn_dropout.cc        |  35 +-
 test/singa/test_cudnn_lrn.cc            |  28 +-
 test/singa/test_cudnn_pooling.cc        |  36 +-
 test/singa/test_cudnn_softmax.cc        | 130 ++--
 test/singa/test_decoder.cc              |  84 +++
 test/singa/test_dense.cc                | 480 +++++++-------
 test/singa/test_dropout.cc              |  17 +-
 test/singa/test_flatten.cc              | 143 +++++
 test/singa/test_initializer.cc          | 148 +++++
 test/singa/test_memory.cc               | 129 ++--
 test/singa/test_mse.cc                  |  12 +-
 test/singa/test_nesterov.cc             | 101 +++
 test/singa/test_prelu.cc                | 245 ++++++++
 test/singa/test_rmsprop.cc              | 105 ++++
 test/singa/test_sgd.cc                  |  32 +-
 test/singa/test_softmax.cc              |  36 +-
 test/singa/test_tensor.cc               |  14 +-
 test/singa/test_tensor_math.cc          | 505 +++++++++++++--
 120 files changed, 8172 insertions(+), 2708 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/dd08f413/CMakeLists.txt
----------------------------------------------------------------------
diff --cc CMakeLists.txt
index c34b6ce,87b3a5d..7a5caf3
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@@ -10,22 -10,23 +10,23 @@@ LIST(APPEND CMAKE_MODULE_PATH ${PROJECT
  IF(UNIX OR APPLE)
    SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC -Wall")
  ENDIF()
- 
+ IF(CMAKE_BUILD_TYPE=Debug)
+   SET(NVCC_FLAG "${NVCC_FLAG} -g -G ")
+ ENDIF()
  #message(STATUS "${CMAKE_CXX_FLAGS}")
 -SET(SINGA_INCLUDE_DIR "${CMAKE_SOURCE_DIR}/include;${PROJECT_BINARY_DIR}")
 -#message(STATUS "include path: ${SINGA_INCLUDE_DIR}")
 +SET(SINGA_INCLUDE_DIR
-     #"${CMAKE_SOURCE_DIR}/include;${CMAKE_SOURCE_DIR}/lib/cnmem/lib;${CMAKE_SOURCE_DIR}/lib/cnmen/include;${PROJECT_BINARY_DIR}")
 +    "${CMAKE_SOURCE_DIR}/include;${CMAKE_SOURCE_DIR}/lib/cnmem/include;${PROJECT_BINARY_DIR}")
- #message(STATUS "include path: ${SINGA_INCLUDE_DIR}")
  INCLUDE_DIRECTORIES(${SINGA_INCLUDE_DIR})
  
- #OPTION(CPU_ONLY "use GPU libs" OFF)
  OPTION(USE_CBLAS "Use CBlas libs" ON)
  OPTION(USE_CUDA "Use Cuda libs" ON)
 -OPTION(USE_CUDNN "Use Cudnn libs" ON)
 +OPTION(USE_CUDNN "Use Cudnn libs" OFF)
  OPTION(USE_OPENCV "Use opencv" OFF)
  OPTION(USE_LMDB "Use LMDB libs" OFF)
+ OPTION(USE_PYTHON "Generate py wrappers" OFF)
  
  INCLUDE("cmake/Dependencies.cmake")
+ INCLUDE("cmake/Utils.cmake")
  ADD_DEFINITIONS(-DUSE_CMAKE)
  
  CONFIGURE_FILE (

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/dd08f413/cmake/Dependencies.cmake
----------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/dd08f413/include/singa/core/common.h
----------------------------------------------------------------------
diff --cc include/singa/core/common.h
index e19022e,22a2b49..cb1bdca
--- a/include/singa/core/common.h
+++ b/include/singa/core/common.h
@@@ -20,7 -20,9 +20,9 @@@
  #define SINGA_CORE_COMMON_H_
  #include <random>
  #include <chrono>
 +#include "./singa/singa_config.h"
+ #include <atomic>
+ #include <memory>
 -#include "./singa_config.h"
  #include "singa/utils/logging.h"
  
  #ifdef USE_CUDA

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/dd08f413/include/singa/core/device.h
----------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/dd08f413/include/singa/core/memory.h
----------------------------------------------------------------------
diff --cc include/singa/core/memory.h
index e4e1e63,db09043..c35f5d0
--- a/include/singa/core/memory.h
+++ b/include/singa/core/memory.h
@@@ -19,56 -19,10 +19,58 @@@
  #ifndef SINGA_CORE_MEMORY_H_
  #define SINGA_CORE_MEMORY_H_
  
 +#include "cnmem.h"
++#include "singa/singa_config.h"
 +#include <mutex>
 +
  namespace singa {
  
  /// Manage device memory pool including garbage collection, memory opt.
  class VirtualMemory {};
  
 +class DeviceMemPool {
- 	public:
- 	virtual void InitPool() = 0;
- 	virtual void Malloc(void** ptr, const size_t size) = 0;
- 	virtual void Free(void* ptr) = 0;
- 	virtual ~DeviceMemPool(){};
++ public:
++  virtual void InitPool() = 0;
++  virtual void Malloc(void** ptr, const size_t size) = 0;
++  virtual void Free(void* ptr) = 0;
++  virtual ~DeviceMemPool(){};
 +};
 +
++#ifdef USE_CUDA
 +class CnMemPool : public DeviceMemPool {
- 	public:
- 	int status = 1;
++ public:
++  int status = 1;
 +
- 	void InitPool();
++  void InitPool();
 +
- 	/// numDevices: total number of available GPU cards.
- 	/// initSize: all devices will be allocated with this size 
- 	/// manager_flags: pool manager flag (one for all devices)
- 	/// flag = 0; default flag
- 	/// flag = 1: Prevent the manager from growing its memory consumption
- 	/// flag = 2; Prevent the manager from stealing memory.
- 	void InitPool(int numDevices, size_t initSize, unsigned flag);
++  /// numDevices: total number of available GPU cards.
++  /// initSize: all devices will be allocated with this size
++  /// manager_flags: pool manager flag (one for all devices)
++  /// flag = 0; default flag
++  /// flag = 1: Prevent the manager from growing its memory consumption
++  /// flag = 2; Prevent the manager from stealing memory.
++  void InitPool(int numDevices, size_t initSize, unsigned flag);
 +
- 	void Malloc(void** ptr, const size_t size);
- 	void Free(void* ptr);
++  void Malloc(void** ptr, const size_t size);
++  void Free(void* ptr);
 +
- 	// release all memory and set cnmem manager to unintialized 
- 	~CnMemPool();
++  // release all memory and set cnmem manager to unintialized
++  ~CnMemPool();
 +
- 	private:
- 	// whether the (global) memory pool has been initialized
- 	static bool initialized;
- 	// lock on the initialized variable
- 	static std::mutex mtx;
++ private:
++  // whether the (global) memory pool has been initialized
++  static bool initialized;
++  // lock on the initialized variable
++  static std::mutex mtx;
 +};
 +
 +class CudaMemPool : public DeviceMemPool {
- 	public:
- 	void InitPool(){};
- 	void Malloc(void** ptr, const size_t size);
- 	void Free(void* ptr);
- 	~CudaMemPool(){};
++ public:
++  void InitPool(){};
++  void Malloc(void** ptr, const size_t size);
++  void Free(void* ptr);
++  ~CudaMemPool(){};
 +};
- 
++#endif
  }  // namespace singa
  #endif  // SINGA_CORE_MEMORY_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/dd08f413/include/singa/core/tensor.h
----------------------------------------------------------------------
diff --cc include/singa/core/tensor.h
index 8f73047,eb72bd3..a4f42db
--- a/include/singa/core/tensor.h
+++ b/include/singa/core/tensor.h
@@@ -65,27 -54,28 +54,29 @@@ class Tensor 
   public:
    ~Tensor();
    Tensor();
 -  explicit Tensor(Shape &&shape, const DataType dtype = kFloat32);
 -  explicit Tensor(const Shape &shape, const DataType dtype = kFloat32);
 -  Tensor(Shape &&shape, Device *dev, const DataType dtype = kFloat32);
 -  Tensor(const Shape &shape, Device *dev, const DataType dtype = kFloat32);
 +  explicit Tensor(Shape &&shape, DataType dtype = kFloat32);
 +  explicit Tensor(const Shape &shape, DataType dtype = kFloat32);
 +  Tensor(Shape &&shape, std::shared_ptr<Device> dev, DataType dtype = kFloat32);
-   Tensor(const Shape &shape, std::shared_ptr<Device> dev, DataType dtype = kFloat32);
++  Tensor(const Shape &shape, std::shared_ptr<Device> dev,
++         DataType dtype = kFloat32);
  
    /// Copy Tensor to share the internal data.  No deep copy.
    Tensor(const Tensor &from);
    /// Copy Tensor to share the internal data.  No deep copy.
    Tensor(Tensor &&from);
  
-   /// For functions in xx_math.cc to access the blob.
-   /// Users should not operate against Blob directly.
-   /// blob_ is allocated in constructors.
-   Blob *blob() const { return blob_; }
+   /// For functions in xx_math.cc to access the block.
+   /// Users should not operate against Block directly.
+   /// block_ is allocated in constructors.
+   Block *block() const { return block_; }
+   void SetBlock(Block* block);
  
 -  Device *device() const { return device_; }
 +  std::shared_ptr<Device> device() const { return device_; }
  
-   /// Return immutable Tensor values with given type.
-   template <typename DType>
-   DType data() const {
-     return static_cast<DType>(blob()->data());
+   /// return immutable Tensor values with given type.
+   template <typename SType>
+   const SType* data() const {
+     return static_cast<const SType*>(block()->data());
    }
  
    /// data type, including kFloat16, kFloat32, kInt
@@@ -192,13 -179,22 +180,22 @@@
   protected:
    bool transpose_ = false;
    DataType data_type_ = kFloat32;
 -  Device *device_ = nullptr;
 +  std::shared_ptr<Device> device_ = nullptr;
-   /// Note: blob_ is allocated in lazy manner to avoid frequent malloc/free.
-   /// If you want to get an allocated Blob, use blob() instead of blob_.
-   Blob *blob_ = nullptr;
-   Shape shape_;
+   /// Note: block_ is allocated in lazy manner to avoid frequent malloc/free.
+   /// If you want to get an allocated Block, use block() instead of block_.
+   Block *block_ = nullptr;
+   Shape shape_ = {};
  };
  
+ typedef Shape::iterator ShapeIter;
+ inline size_t Product(const Shape &shape, int start = 0, size_t len = 0) {
+   if (len == 0) len = shape.size();
+   CHECK_LE(len, shape.size());
+   size_t v = 1;
+   for (unsigned int i = start; i < len; i++) v *= shape[i];
+   return v;
+ }
+ 
  inline void CheckDataTypeAndLang(const Tensor &in1, const Tensor &in2) {
    CHECK_EQ(in1.data_type(), in2.data_type());
    CHECK_EQ(in1.device()->lang(), in2.device()->lang());

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/dd08f413/include/singa/model/layer.h
----------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/dd08f413/include/singa/model/loss.h
----------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/dd08f413/include/singa/utils/timer.h
----------------------------------------------------------------------
diff --cc include/singa/utils/timer.h
index a54829d,a54829d..bdd6c5c
--- a/include/singa/utils/timer.h
+++ b/include/singa/utils/timer.h
@@@ -19,7 -19,7 +19,7 @@@ class Timer 
    /// Return the duration since last call to Tick() or since the creation of
    /// Timer. The template arg must be from Second or Millisecond or Hour.
    /// The returned value is the count of the time metric.
--  template <typename T>
++  template <typename T = Milliseconds>
    int Elapsed() const {
      static_assert(std::is_same<T, Seconds>::value ||
                        std::is_same<T, Milliseconds>::value ||

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/dd08f413/src/CMakeLists.txt
----------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/dd08f413/src/core/device/cpp_cpu.cc
----------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/dd08f413/src/core/device/cuda_gpu.cc
----------------------------------------------------------------------
diff --cc src/core/device/cuda_gpu.cc
index 4da292f,5d4e1ed..5879c58
--- a/src/core/device/cuda_gpu.cc
+++ b/src/core/device/cuda_gpu.cc
@@@ -32,8 -32,8 +32,7 @@@ const cudaMemcpyKind copyKind[] = {cuda
                                     cudaMemcpyDeviceToDevice};
  
  CudaGPU::~CudaGPU() {
--  if (ctx_.cublas_handle)
--    CUBLAS_CHECK(cublasDestroy(ctx_.cublas_handle));
++  if (ctx_.cublas_handle) CUBLAS_CHECK(cublasDestroy(ctx_.cublas_handle));
    if (ctx_.curand_generator)
      CURAND_CHECK(curandDestroyGenerator(ctx_.curand_generator));
  #ifdef USE_CUDNN
@@@ -42,14 -42,13 +41,12 @@@
      CHECK_EQ(status, CUDNN_STATUS_SUCCESS) << cudnnGetErrorString(status);
    }
  #endif
- 	delete pool;
++  delete pool;
  }
  
--CudaGPU::CudaGPU(int id, int num_executors,
--                       string scheduler, string vm)
++CudaGPU::CudaGPU(int id, int num_executors, string scheduler, string vm)
      : Device(id, num_executors, scheduler, vm) {
--  if (id == -1)
--    id = FindDevice(0);
++  if (id == -1) id = FindDevice(0);
    lang_ = kCuda;
    ctx_.stream = NULL;  // use the default sync stream
    // TODO(wangwei) create one handle for each steam?
@@@ -68,62 -67,20 +65,57 @@@
    auto status = cudnnCreate(&ctx_.cudnn_handle);
    CHECK_EQ(status, CUDNN_STATUS_SUCCESS) << cudnnGetErrorString(status);
  #endif  // USE_CUDNN
- 	
- 	// initialize cnmem memory management as default
- 	pool = new CnMemPool();
- 	((CnMemPool*)pool)->InitPool();
++
++  // initialize cnmem memory management as default
++  pool = new CnMemPool();
++  ((CnMemPool*)pool)->InitPool();
  }
  
- CudaGPU::CudaGPU(const MemPoolConf& mem_conf,int id, int num_executors,
-                        string scheduler)
 -void CudaGPU::SetRandSeed(unsigned seed) {
 -  CHECK(ctx_.curand_generator);
++CudaGPU::CudaGPU(const MemPoolConf& mem_conf, int id, int num_executors,
++                 string scheduler)
 +    : Device(id, num_executors, scheduler, "gc-only") {
-   if (id == -1)
-     id = FindDevice(0);
++  if (id == -1) id = FindDevice(0);
 +  lang_ = kCuda;
 +  ctx_.stream = NULL;  // use the default sync stream
 +  // TODO(wangwei) create one handle for each steam?
 +  CUDA_CHECK(cudaSetDevice(FindDevice(0)));
 +  // use curandCreateGeneratorHost for CudaHost device
    CURAND_CHECK(
 -      curandSetPseudoRandomGeneratorSeed(ctx_.curand_generator, seed));
 +      curandCreateGenerator(&ctx_.curand_generator, CURAND_RNG_PSEUDO_DEFAULT));
 +  auto seed = std::chrono::system_clock::now().time_since_epoch().count();
 +  SetRandSeed(seed);
 +  // TODO(wangwei) if one generator per stream, then need diff offset per gen?
 +  CURAND_CHECK(curandSetGeneratorOffset(ctx_.curand_generator, 0));
 +  CUBLAS_CHECK(cublasCreate(&(ctx_.cublas_handle)));
 +
 +#ifdef USE_CUDNN
 +  // TODO(wangwei) create one handle for each stream?
 +  auto status = cudnnCreate(&ctx_.cudnn_handle);
 +  CHECK_EQ(status, CUDNN_STATUS_SUCCESS) << cudnnGetErrorString(status);
 +#endif  // USE_CUDNN
 +
- 	// initialize memory management for cuda devices
- 	string memoryPoolType = mem_conf.type();
- 	if(memoryPoolType.compare("cnmem") == 0) {
- 		pool = new CnMemPool();
- 		int num_devices = mem_conf.num_devices();
- 		size_t alloc_size = mem_conf.alloc_size();
- 		unsigned flag = mem_conf.cnmemflag();
- 		((CnMemPool*)pool)->InitPool(num_devices, alloc_size, flag);
- 	}
- 	else {
- 		pool = new CudaMemPool();
- 	}
++  // initialize memory management for cuda devices
++  string memoryPoolType = mem_conf.type();
++  if (memoryPoolType.compare("cnmem") == 0) {
++    pool = new CnMemPool();
++    int num_devices = mem_conf.num_devices();
++    size_t alloc_size = mem_conf.alloc_size();
++    unsigned flag = mem_conf.cnmemflag();
++    ((CnMemPool*)pool)->InitPool(num_devices, alloc_size, flag);
++  } else {
++    pool = new CudaMemPool();
++  }
  }
  
 -void CudaGPU::DoExec(function<void(Context*)>&& fn, int executor) {
 -  fn(&ctx_);
 +void CudaGPU::SetRandSeed(unsigned seed) {
 +  CHECK(ctx_.curand_generator);
-   CURAND_CHECK(
-       curandSetPseudoRandomGeneratorSeed(ctx_.curand_generator, seed));
++  CURAND_CHECK(curandSetPseudoRandomGeneratorSeed(ctx_.curand_generator, seed));
  }
  
- void CudaGPU::DoExec(function<void(Context*)>&& fn, int executor) {
-   fn(&ctx_);
- }
++void CudaGPU::DoExec(function<void(Context*)>&& fn, int executor) { fn(&ctx_); }
 +
  void CudaGPU::CopyToFrom(void* dst, const void* src, size_t nBytes,
--                            CopyDirection direction, Context* ctx) {
++                         CopyDirection direction, Context* ctx) {
    cudaMemcpy(dst, src, nBytes, copyKind[direction]);
    // TODO(wangwei) use async copy
    // cudaMemcpyAsync(dst, src, nBytes,cudaMemcpyDefault, ctx_.stream);
@@@ -133,22 -90,19 +125,21 @@@
  void* CudaGPU::Malloc(int size) {
    void* ptr = nullptr;
    if (size > 0) {
- 		//CUDA_CHECK(cudaMalloc((void**)&ptr,size));
- 		pool->Malloc((void**)&ptr,size);
 -    CUDA_CHECK(cudaMalloc(&ptr, size));
++    // CUDA_CHECK(cudaMalloc((void**)&ptr,size));
++    pool->Malloc((void**)&ptr, size);
      CUDA_CHECK(cudaMemset(ptr, 0, size));
    }
    return ptr;
  }
  
--  /// Free cpu memory.
++/// Free cpu memory.
  void CudaGPU::Free(void* ptr) {
 -  if (ptr != nullptr)
 -    CUDA_CHECK(cudaFree(ptr));
 +  if (ptr != nullptr) {
- 		//CUDA_CHECK(cudaFree(ptr));
- 		pool->Free(ptr);
- 	}
++    // CUDA_CHECK(cudaFree(ptr));
++    pool->Free(ptr);
++  }
  }
  
--
  // ==========Following code is from Caffe src/caffe/common.cpp=================
  
  void CudaGPU::DeviceQuery() {
@@@ -169,20 -123,20 +160,18 @@@
    LOG(INFO) << "Warp size:                     " << prop.warpSize;
    LOG(INFO) << "Maximum memory pitch:          " << prop.memPitch;
    LOG(INFO) << "Maximum threads per block:     " << prop.maxThreadsPerBlock;
--  LOG(INFO) << "Maximum dimension of block:    "
--      << prop.maxThreadsDim[0] << ", " << prop.maxThreadsDim[1] << ", "
--      << prop.maxThreadsDim[2];
--  LOG(INFO) << "Maximum dimension of grid:     "
--      << prop.maxGridSize[0] << ", " << prop.maxGridSize[1] << ", "
--      << prop.maxGridSize[2];
++  LOG(INFO) << "Maximum dimension of block:    " << prop.maxThreadsDim[0]
++            << ", " << prop.maxThreadsDim[1] << ", " << prop.maxThreadsDim[2];
++  LOG(INFO) << "Maximum dimension of grid:     " << prop.maxGridSize[0] << ", "
++            << prop.maxGridSize[1] << ", " << prop.maxGridSize[2];
    LOG(INFO) << "Clock rate:                    " << prop.clockRate;
    LOG(INFO) << "Total constant memory:         " << prop.totalConstMem;
    LOG(INFO) << "Texture alignment:             " << prop.textureAlignment;
--  LOG(INFO) << "Concurrent copy and execution: "
--      << (prop.deviceOverlap ? "Yes" : "No");
++  LOG(INFO) << "Concurrent copy and execution: " << (prop.deviceOverlap ? "Yes"
++                                                                        : "No");
    LOG(INFO) << "Number of multiprocessors:     " << prop.multiProcessorCount;
    LOG(INFO) << "Kernel execution timeout:      "
--      << (prop.kernelExecTimeoutEnabled ? "Yes" : "No");
++            << (prop.kernelExecTimeoutEnabled ? "Yes" : "No");
    return;
  }
  
@@@ -203,6 -157,6 +192,5 @@@ int CudaGPU::FindDevice(const int start
    return -1;
  }
  
--
  }  // namespace singa
  #endif  // USE_CUDA

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/dd08f413/src/core/device/device.cc
----------------------------------------------------------------------
diff --cc src/core/device/device.cc
index 1889339,36381e4..6775e40
--- a/src/core/device/device.cc
+++ b/src/core/device/device.cc
@@@ -22,11 -22,11 +22,11 @@@ namespace singa 
  Device::Device(int id, int num_executors, string scheduler, string vm)
      : id_(id), num_executors_(num_executors) {
        // TODO(wangwei) create scheduler and vm.
 -  host_ = &defaultDevice;
 +  host_ = defaultDevice;
  }
  
- void Device::Exec(function<void(Context*)>&& fn, const vector<Blob*> read_blobs,
-                     const vector<Blob*> write_blobs, bool use_rand_generator) {
+ void Device::Exec(function<void(Context*)>&& fn, const vector<Block*> read_blocks,
+                     const vector<Block*> write_blocks, bool use_rand_generator) {
    // TODO(wangwei) execute operations scheduled by the scheduler.
    DoExec(std::move(fn), 0);
  }

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/dd08f413/src/core/memory/memory.cc
----------------------------------------------------------------------
diff --cc src/core/memory/memory.cc
index 304c101,a1cf5db..7ac6792
--- a/src/core/memory/memory.cc
+++ b/src/core/memory/memory.cc
@@@ -16,71 -16,5 +16,74 @@@
   * limitations under the License.
   */
  
--
  #include "singa/core/memory.h"
 +#include "singa/utils/logging.h"
++#include "singa/proto/core.pb.h"
 +#include <iostream>
 +
++#ifdef USE_CUDA
 +namespace singa {
- 
 +bool singa::CnMemPool::initialized = false;
 +std::mutex singa::CnMemPool::mtx;
- 
 +void CnMemPool::InitPool(int numDevices, size_t initSize, unsigned flag) {
- 	mtx.lock();
- 	if(!initialized) {
- 		CHECK_GE(numDevices, 1);
- 		cnmemDevice_t* settingPtr = new cnmemDevice_t[numDevices];
- 		for(int i = 0; i < numDevices; i++) {
- 			settingPtr[i].device = i;
- 			settingPtr[i].size = initSize;
- 			settingPtr[i].numStreams = 0;
- 			settingPtr[i].streams = NULL;
- 			settingPtr[i].streamSizes = 0;
- 		}
- 		cnmemStatus_t status = cnmemInit(numDevices, settingPtr, flag);
- 		CHECK_EQ(status, cnmemStatus_t::CNMEM_STATUS_SUCCESS) << " " << cnmemGetErrorString(status);
- 		delete[] settingPtr;
- 		initialized = true;
- 	}
- 	mtx.unlock();
++  mtx.lock();
++  const size_t kNBytesPerMB = (1u << 20);
++  if (!initialized) {
++    CHECK_GE(numDevices, 1);
++    cnmemDevice_t* settingPtr = new cnmemDevice_t[numDevices];
++    for (int i = 0; i < numDevices; i++) {
++      settingPtr[i].device = i;
++      settingPtr[i].size = initSize * kNBytesPerMB;
++      settingPtr[i].numStreams = 0;
++      settingPtr[i].streams = NULL;
++      settingPtr[i].streamSizes = 0;
++    }
++    cnmemStatus_t status = cnmemInit(numDevices, settingPtr, flag);
++    CHECK_EQ(status, cnmemStatus_t::CNMEM_STATUS_SUCCESS)
++        << " " << cnmemGetErrorString(status);
++    delete[] settingPtr;
++    initialized = true;
++  }
++  mtx.unlock();
 +}
 +
 +void CnMemPool::InitPool() {
- 		int defaultNumDevices = 1;
- 		size_t defaultSize = 1000000U;
- 		InitPool(defaultNumDevices,defaultSize,cnmemManagerFlags_t::CNMEM_FLAGS_DEFAULT);
++  MemPoolConf conf;
++  InitPool(conf.num_devices(), conf.alloc_size(),
++           cnmemManagerFlags_t::CNMEM_FLAGS_DEFAULT);
 +}
 +
 +CnMemPool::~CnMemPool() {
- 	mtx.lock();
- 	if(initialized) {
- 		cnmemStatus_t status = cnmemFinalize();
- 		CHECK_EQ(status, cnmemStatus_t::CNMEM_STATUS_SUCCESS) << " " << cnmemGetErrorString(status);
- 		initialized = false;
- 	}
- 	mtx.unlock();
++  mtx.lock();
++  if (initialized) {
++    cnmemStatus_t status = cnmemFinalize();
++    CHECK_EQ(status, cnmemStatus_t::CNMEM_STATUS_SUCCESS)
++        << " " << cnmemGetErrorString(status);
++    initialized = false;
++  }
++  mtx.unlock();
 +}
 +
- 
 +void CnMemPool::Malloc(void** ptr, const size_t size) {
- 	cnmemStatus_t status = cnmemMalloc(ptr,size,NULL);
- 	CHECK_EQ(status, cnmemStatus_t::CNMEM_STATUS_SUCCESS) << " " << cnmemGetErrorString(status);
++  cnmemStatus_t status = cnmemMalloc(ptr, size, NULL);
++  CHECK_EQ(status, cnmemStatus_t::CNMEM_STATUS_SUCCESS)
++      << " " << cnmemGetErrorString(status);
 +}
 +
 +void CnMemPool::Free(void* ptr) {
- 	cnmemStatus_t status = cnmemFree(ptr,NULL);
- 	CHECK_EQ(status, cnmemStatus_t::CNMEM_STATUS_SUCCESS) << " " << cnmemGetErrorString(status);
++  cnmemStatus_t status = cnmemFree(ptr, NULL);
++  CHECK_EQ(status, cnmemStatus_t::CNMEM_STATUS_SUCCESS)
++      << " " << cnmemGetErrorString(status);
 +}
 +
 +void CudaMemPool::Malloc(void** ptr, const size_t size) {
- 	cudaError_t status = cudaMalloc(ptr,size);
- 	CHECK_EQ(status, cudaError_t::cudaSuccess);
++  cudaError_t status = cudaMalloc(ptr, size);
++  CHECK_EQ(status, cudaError_t::cudaSuccess);
 +}
 +
 +void CudaMemPool::Free(void* ptr) {
- 	cudaError_t status = cudaFree(ptr);
- 	CHECK_EQ(status, cudaError_t::cudaSuccess);
++  cudaError_t status = cudaFree(ptr);
++  CHECK_EQ(status, cudaError_t::cudaSuccess);
 +}
- 
 +}
++#endif

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/dd08f413/src/core/tensor/math_kernel.cu
----------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/dd08f413/src/core/tensor/math_kernel.h
----------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/dd08f413/src/core/tensor/tensor.cc
----------------------------------------------------------------------
diff --cc src/core/tensor/tensor.cc
index a5b43d8,9b3eeff..b852a54
--- a/src/core/tensor/tensor.cc
+++ b/src/core/tensor/tensor.cc
@@@ -25,54 -25,66 +25,65 @@@
  namespace singa {
  
  Tensor::~Tensor() {
-   if (blob_ != nullptr && blob_->DecRefCount() == 0)
-     device_->FreeBlob(blob_);
-   blob_ = nullptr;
 -  // LOG(ERROR) << "~";
+   if (block_ != nullptr && block_->DecRefCount() == 0)
+     device_->FreeBlock(block_);
+   block_ = nullptr;
  }
  
 -Tensor::Tensor() { device_ = &defaultDevice; }
 +Tensor::Tensor() { device_ = defaultDevice; }
  
 -Tensor::Tensor(const Shape &shape, const DataType dtype)
 -    : data_type_(dtype), device_(&defaultDevice), shape_(shape) {
 -  device_ = &defaultDevice;
 +Tensor::Tensor(const Shape &shape, DataType dtype)
 +    : data_type_(dtype), device_(defaultDevice), shape_(shape) {
 +  device_ = defaultDevice;
-   blob_ = device_->NewBlob(Product(shape_) * SizeOf(data_type_));
+   block_ = device_->NewBlock(Product(shape_) * SizeOf(data_type_));
  }
 -Tensor::Tensor(Shape &&shape, const DataType dtype)
 -    : data_type_(dtype), device_(&defaultDevice), shape_(shape) {
 -  device_ = &defaultDevice;
 +Tensor::Tensor(Shape &&shape, DataType dtype)
 +    : data_type_(dtype), device_(defaultDevice), shape_(shape) {
 +  device_ = defaultDevice;
-   blob_ = device_->NewBlob(Product(shape_) * SizeOf(data_type_));
+   block_ = device_->NewBlock(Product(shape_) * SizeOf(data_type_));
  }
 -Tensor::Tensor(const Shape &shape, Device *device, const DataType dtype)
 +Tensor::Tensor(const Shape &shape, std::shared_ptr<Device> device, DataType dtype)
      : data_type_(dtype), device_(device), shape_(shape) {
-   blob_ = device_->NewBlob(Product(shape_) * SizeOf(data_type_));
+   block_ = device_->NewBlock(Product(shape_) * SizeOf(data_type_));
  }
 -Tensor::Tensor(Shape &&shape, Device *device, const DataType dtype)
 +Tensor::Tensor(Shape &&shape, std::shared_ptr<Device> device, DataType dtype)
      : data_type_(dtype), device_(device), shape_(shape) {
-   blob_ = device_->NewBlob(Product(shape_) * SizeOf(data_type_));
- }
- Tensor::Tensor(const Tensor &t)
-     : transpose_(t.transpose_), data_type_(t.data_type_), device_(t.device_),
-       blob_(t.blob()), shape_(t.shape_) {
-   blob_->IncRefCount();
-   // LOG(ERROR) << "const&";
- }
- 
- Tensor::Tensor(Tensor &&t)
-     : transpose_(t.transpose_), data_type_(t.data_type_), device_(t.device_),
-       shape_(std::move(t.shape_)) {
-   blob_ = t.blob_;
-   t.blob_ = nullptr;
-   // LOG(ERROR) << "&&";
- }
- 
- void Tensor::ResetLike(const Tensor &t) {
-   if (blob_ == nullptr || device_ != t.device_ || MemSize() != t.MemSize()) {
-     if (blob_ != nullptr && blob_->DecRefCount() == 0)
-       device_->FreeBlob(blob_);
-     shape_ = t.shape_;
-     device_ = t.device_;
-     data_type_ = t.data_type_;
-     blob_ = device_->NewBlob(t.MemSize());
+   block_ = device_->NewBlock(Product(shape_) * SizeOf(data_type_));
+ }
+ Tensor::Tensor(const Tensor &in)
+     : transpose_(in.transpose_),
+       data_type_(in.data_type_),
+       device_(in.device_),
+       block_(in.block()),
+       shape_(in.shape_) {
+   block_->IncRefCount();
+ }
+ 
+ Tensor::Tensor(Tensor &&in)
+     : transpose_(in.transpose_),
+       data_type_(in.data_type_),
+       device_(in.device_),
+       shape_(std::move(in.shape_)) {
+   block_ = in.block_;
+   in.block_ = nullptr;
+ }
+ 
+ void Tensor::SetBlock(Block* block) {  // replace block_, releasing this tensor's hold on the old one
+   LOG(WARNING) << "Pls avoid using this function, which may have side-effect.";
+   if (block_ != nullptr)  // drop this tensor's reference to the old block
+     if (block_->DecRefCount() == 0)  // free only when no other tensor still shares it
+       device_->FreeBlock(block_);
+   block_ = block;
+ }
+ 
+ void Tensor::ResetLike(const Tensor &in) {
+   if (block_ == nullptr || device_ != in.device_ || MemSize() != in.MemSize()) {
+     if (block_ != nullptr && block_->DecRefCount() == 0)
+       device_->FreeBlock(block_);
+     shape_ = in.shape_;
+     device_ = in.device_;
+     data_type_ = in.data_type_;
+     block_ = device_->NewBlock(in.MemSize());
    }
  }
  
@@@ -228,13 -245,13 +244,13 @@@ void CopyDataToFrom(Tensor *dst, const 
    auto width = SizeOf(src.data_type());
    CHECK_EQ(width, SizeOf(dst->data_type()));
    size_t nBytes = num * width;
-   dst_offset *= width;
-   src_offset *= width;
-   CHECK_GE(src.MemSize(), src_offset + nBytes);
-   CHECK_GE(dst->MemSize(), dst_offset + nBytes);
+   auto d_offset = dst_offset * width;
+   auto s_offset = src_offset * width;
+   CHECK_GE(src.MemSize(), s_offset + nBytes);
+   CHECK_GE(dst->MemSize(), d_offset + nBytes);
  
 -  Device *src_dev = src.device(), *dst_dev = dst->device();
 +  std::shared_ptr<Device> src_dev = src.device(), dst_dev = dst->device();
-   Blob *from = src.blob(), *to = dst->blob();
+   Block *from = src.block(), *to = dst->block();
    if (dst_dev->lang() != src_dev->lang()) {
      // let the none cpp device conduct copy op
      if (dst_dev->lang() == kCpp) {

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/dd08f413/src/core/tensor/tensor_math_cuda.h
----------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/dd08f413/src/model/layer/batchnorm.cc
----------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/dd08f413/src/model/layer/batchnorm.h
----------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/dd08f413/src/model/layer/cudnn_activation.cc
----------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/dd08f413/src/model/layer/cudnn_activation.h
----------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/dd08f413/src/model/layer/cudnn_batchnorm.cc
----------------------------------------------------------------------
diff --cc src/model/layer/cudnn_batchnorm.cc
index 8288a41,0e597fe..a1e9e50
--- a/src/model/layer/cudnn_batchnorm.cc
+++ b/src/model/layer/cudnn_batchnorm.cc
@@@ -30,7 -30,7 +30,7 @@@ CudnnBatchNorm::~CudnnBatchNorm() 
    }
  }
  
--void CudnnBatchNorm::ToDevice(Device* device) {
++void CudnnBatchNorm::ToDevice(std::shared_ptr<Device> device) {
    BatchNorm::ToDevice(device);
    resultSaveMean_.ToDevice(device);
    resultSaveVariance_.ToDevice(device);

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/dd08f413/src/model/layer/cudnn_batchnorm.h
----------------------------------------------------------------------
diff --cc src/model/layer/cudnn_batchnorm.h
index 8598b65,36dbbce..4f46452
--- a/src/model/layer/cudnn_batchnorm.h
+++ b/src/model/layer/cudnn_batchnorm.h
@@@ -29,31 -29,29 +29,29 @@@
  namespace singa {
  class CudnnBatchNorm : public BatchNorm {
   public:
-    ~CudnnBatchNorm();
-    /// \copy doc Layer::layer_type()
-    const std::string layer_type() const override {
-      return "CudnnBatchNorm";
-    }
+   ~CudnnBatchNorm();
+   /// \copy doc Layer::layer_type()
+   const std::string layer_type() const override { return "CudnnBatchNorm"; }
  
-    void Setup(const LayerConf& conf) override;
+   void Setup(const Shape& in_sample, const LayerConf& conf) override;
  
-    const Tensor Forward(int flag, const Tensor& input)
-      override;
-    const std::pair<Tensor, vector<Tensor>> Backward(
-        int flag, const Tensor& grad) override;
+   const Tensor Forward(int flag, const Tensor& input) override;
+   const std::pair<Tensor, vector<Tensor>> Backward(int flag,
+                                                    const Tensor& grad) override;
 -  void ToDevice(Device* device) override;
++  void ToDevice(std::shared_ptr<Device> device) override;
  
-    /// Init cudnn related data structures.
-    void InitCudnn(const Shape& shape, DataType dtype);
-    void ToDevice(Device* device) override;
+  private:
+   /// Init cudnn related data structures.
+   void InitCudnn(const Shape& shape, DataType dtype);
  
   private:
-    bool has_init_cudnn_ = false;
-    cudnnBatchNormMode_t mode_;
-    cudnnLRNDescriptor_t lrn_desc_;
-    cudnnTensorDescriptor_t shape_desc_, param_desc_;
-    Tensor resultSaveMean_, resultSaveVariance_;
-    
- }; // class CudnnBatchNorm
+   bool has_init_cudnn_ = false;
+   cudnnBatchNormMode_t mode_;
+   cudnnLRNDescriptor_t lrn_desc_ = nullptr;
+   cudnnTensorDescriptor_t shape_desc_ = nullptr, param_desc_ = nullptr;
+   Tensor resultSaveMean_, resultSaveVariance_;
+ 
+ };  // class CudnnBatchNorm
  }  // namespace
  
  #endif  // USE_CUDNN

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/dd08f413/src/model/layer/cudnn_convolution.cc
----------------------------------------------------------------------
diff --cc src/model/layer/cudnn_convolution.cc
index b80c3bd,8cdfc07..d5ac2a3
--- a/src/model/layer/cudnn_convolution.cc
+++ b/src/model/layer/cudnn_convolution.cc
@@@ -46,7 -46,7 +46,7 @@@ void CudnnConvolution::Setup(const Shap
           "limited_workspace, no_workspace and autotune";
  }
  
--void CudnnConvolution::ToDevice(Device *device) {
++void CudnnConvolution::ToDevice(std::shared_ptr<Device> device) {
    weight_.ToDevice(device);
    bias_.ToDevice(device);
    workspace_.ToDevice(device);
@@@ -55,7 -55,7 +55,7 @@@
  void CudnnConvolution::InitCudnn(const Tensor &input) {
    CHECK(!has_init_cudnn_);
    DataType dtype = input.data_type();
--  Device *dev = input.device();
++  auto dev = input.device();
    Context *ctx = dev->context(0);
    size_t batchsize = input.shape(0);
    CUDNN_CHECK(cudnnCreateTensorDescriptor(&x_desc_));
@@@ -161,7 -161,7 +161,7 @@@ const Tensor CudnnConvolution::Forward(
    if (flag & kTrain) buf_.push(input);  // buffer the input for backward
    size_t batchsize = input.shape()[0];
    DataType dtype = input.data_type();
--  Device *dev = input.device();
++  auto dev = input.device();
  
    if (!has_init_cudnn_) InitCudnn(input);
  

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/dd08f413/src/model/layer/cudnn_convolution.h
----------------------------------------------------------------------
diff --cc src/model/layer/cudnn_convolution.h
index 152d797,6c15839..cd0471f
--- a/src/model/layer/cudnn_convolution.h
+++ b/src/model/layer/cudnn_convolution.h
@@@ -41,9 -41,9 +41,9 @@@ class CudnnConvolution : public Convolu
                                                     const Tensor &grad) override;
  
    /// \copydoc Layer::Setup(const LayerConf&);
-   void Setup(const LayerConf &conf) override;
+   void Setup(const Shape& in_sample, const LayerConf &conf) override;
  
--  void ToDevice(Device *device) override;
++  void ToDevice(std::shared_ptr<Device> device) override;
  
    size_t workspace_byte_limit() { return workspace_byte_limit_; }
    string prefer() { return prefer_; }

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/dd08f413/src/model/layer/cudnn_dropout.cc
----------------------------------------------------------------------
diff --cc src/model/layer/cudnn_dropout.cc
index 64a581b,877dd12..2e2e12b
--- a/src/model/layer/cudnn_dropout.cc
+++ b/src/model/layer/cudnn_dropout.cc
@@@ -34,8 -34,8 +34,8 @@@ CudnnDropout::~CudnnDropout() 
    if (y_desc_ != nullptr) CUDNN_CHECK(cudnnDestroyTensorDescriptor(y_desc_));
  }
  
--void CudnnDropout::InitCudnn(int size, DataType dtype, Device* dev,
--                             Context* ctx) {
++void CudnnDropout::InitCudnn(int size, DataType dtype,
++                             std::shared_ptr<Device> dev, Context* ctx) {
    CHECK(!has_init_cudnn_);
    CUDNN_CHECK(cudnnCreateTensorDescriptor(&x_desc_));
    CUDNN_CHECK(cudnnCreateTensorDescriptor(&y_desc_));
@@@ -65,13 -65,11 +65,11 @@@ const Tensor CudnnDropout::Forward(int 
    if (flag & kTrain) {
      auto size = input.Size();
      DataType dtype = input.data_type();
--    Device* dev = input.device();
++    auto dev = input.device();
      if (!has_init_cudnn_) {
-       input.device()->Exec(
-           [size, dtype, this, dev](Context* ctx) {
-             this->InitCudnn(size, dtype, dev, ctx);
-           },
-           {}, {this->state_.blob()});
+       input.device()->Exec([size, dtype, this, dev](Context* ctx) {
+         this->InitCudnn(size, dtype, dev, ctx);
+       }, {}, {this->state_.block()});
      }
      Tensor output;
      output.ResetLike(input);

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/dd08f413/src/model/layer/cudnn_dropout.h
----------------------------------------------------------------------
diff --cc src/model/layer/cudnn_dropout.h
index da3d1d2,83572cf..6809653
--- a/src/model/layer/cudnn_dropout.h
+++ b/src/model/layer/cudnn_dropout.h
@@@ -42,8 -42,9 +42,10 @@@ class CudnnDropout : public Dropout 
    const std::pair<Tensor, vector<Tensor>> Backward(int flag,
                                                     const Tensor& grad) override;
  
+  private:
    /// Init cudnn related data structures.
--  void InitCudnn(int size, DataType dtype, Device* dev, Context* ctx);
++  void InitCudnn(int size, DataType dtype, std::shared_ptr<Device> dev,
++                 Context* ctx);
  
   private:
    bool has_init_cudnn_ = false;

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/dd08f413/src/model/layer/cudnn_lrn.h
----------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/dd08f413/src/model/layer/cudnn_pooling.cc
----------------------------------------------------------------------
diff --cc src/model/layer/cudnn_pooling.cc
index 842685d,9d288c0..6d7a5b1
--- a/src/model/layer/cudnn_pooling.cc
+++ b/src/model/layer/cudnn_pooling.cc
@@@ -82,7 -82,7 +82,7 @@@ const Tensor CudnnPooling::Forward(int 
    CHECK_EQ(input.nDim(), 4u);
    size_t batchsize = input.shape(0);
    DataType dtype = input.data_type();
--  Device *dev = input.device();
++  auto dev = input.device();
    if (!has_init_cudnn_) InitCudnn(input);
  
    Shape shape{batchsize, channels_, pooled_height_, pooled_width_};

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/dd08f413/src/model/layer/cudnn_pooling.h
----------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/dd08f413/src/model/layer/cudnn_softmax.cc
----------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/dd08f413/src/model/layer/dense.cc
----------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/dd08f413/src/model/layer/dense.h
----------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/dd08f413/src/model/layer/dropout.cc
----------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/dd08f413/src/model/layer/dropout.h
----------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/dd08f413/src/model/layer/prelu.cc
----------------------------------------------------------------------
diff --cc src/model/layer/prelu.cc
index 0000000,83a56fa..6eb09d9
mode 000000,100644..100644
--- a/src/model/layer/prelu.cc
+++ b/src/model/layer/prelu.cc
@@@ -1,0 -1,145 +1,145 @@@
+ /**
+  * Licensed to the Apache Software Foundation (ASF) under one
+  * or more contributor license agreements.  See the NOTICE file
+  * distributed with this work for additional information
+  * regarding copyright ownership.  The ASF licenses this file
+  * to you under the Apache License, Version 2.0 (the
+  * "License"); you may not use this file except in compliance
+  * with the License.  You may obtain a copy of the License at
+  *
+  *     http://www.apache.org/licenses/LICENSE-2.0
+  *
+  * Unless required by applicable law or agreed to in writing, software
+  * distributed under the License is distributed on an "AS IS" BASIS,
+  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+  * See the License for the specific language governing permissions and
+  * limitations under the License.
+  */
+ 
+ #include "singa/model/layer.h"
+ #include "./prelu.h"
+ namespace singa {
+ 
+ void PReLU::Setup(const Shape& in_sample, const LayerConf &conf) {
+   Layer::Setup(in_sample, conf);
+   out_sample_shape_ = in_sample;
+   channel_shared_ = conf.prelu_conf().channel_shared();
+   format_ = conf.prelu_conf().format();
+   // Push back params into param_values_
+   for (const auto &spec : conf.param()) param_specs_.push_back(spec);
+   param_values_.push_back(&a_);
+ }
+ 
+ const Tensor PReLU::Forward(int flag, const Tensor &input) {
+   Tensor output;
+   if (!channel_shared_) {
+     size_t n, c, h, w;
+     Tensor temp = (input <= 0.f);
+     if (temp.nDim() == 4) {
+       if (format_ == "NCHW") {
+         n = temp.shape(0);
+         c = temp.shape(1);
+         h = temp.shape(2);
+         w = temp.shape(3);
+         temp.Reshape(Shape{n * c, h * w});
+         Tensor temp_a(Shape{n, c}, input.device(), input.data_type());
+         Uniform(1.f, 1.f, &temp_a);
+         MultRow(a_, &temp_a);
+         temp_a.Reshape(Shape{n * c});
+         MultColumn(temp_a, &temp);
+       } else if (format_ == "NHWC") {
+         n = temp.shape(0);
+         h = temp.shape(1);
+         w = temp.shape(2);
+         c = temp.shape(3);
+         temp.Reshape(Shape{n * h * w, c});
+         MultRow(a_, &temp);
+       } else {
+         LOG(FATAL) << "Incorrect input format for prelu layer.";
+       }
+     } else {
+       LOG(FATAL) << "Incorrect input format for prelu layer.";
+     }
+     output = input * ((input > 0.f) + temp);
+   } else {
+     // share the first param of Tensor A along all channels
+     LOG(FATAL) << "Not implemented";
+   // TODO(wangwei) cannot access the data in this way. The data could be on GPU.
+     auto a = a_.data<float>()[0];
+     output = input * ((input > 0.f) + (input <= 0.f) * a);
+   }
+   if (flag & kTrain) buf_.push(input);
+   return output;
+ }
+ 
+ const std::pair<Tensor, vector<Tensor> > PReLU::Backward(int flag,
+                                                          const Tensor &grad) {
+   vector<Tensor> param_grad;  // gradients w.r.t. the params, i.e. a_
+   CHECK(!buf_.empty());
+   Tensor input_grad, input = buf_.top();
+   buf_.pop();
+   Tensor da;
+   da.ResetLike(a_);
+   if (!channel_shared_) {
+     size_t n, c, h, w;
+     Tensor temp1 = (input <= 0.f);  // scaled by a_ below into a * (x <= 0)
+     if (temp1.nDim() == 4) {
+       if (format_ == "NCHW") {
+         n = temp1.shape(0);
+         c = temp1.shape(1);
+         h = temp1.shape(2);
+         w = temp1.shape(3);
+         temp1.Reshape(Shape{n * c, h * w});
+         Tensor temp_a(Shape{n, c}, grad.device(), grad.data_type());
+         Uniform(1.f, 1.f, &temp_a);
+         MultRow(a_, &temp_a);
+         temp_a.Reshape(Shape{n * c});
+         MultColumn(temp_a, &temp1);
+         temp1.Reshape(Shape{n, c, h, w});
+       } else if (format_ == "NHWC") {
+         n = temp1.shape(0);
+         h = temp1.shape(1);
+         w = temp1.shape(2);
+         c = temp1.shape(3);
+         temp1.Reshape(Shape{n * h * w, c});
+         MultRow(a_, &temp1);
+         temp1.Reshape(Shape{n, h, w, c});
+       } else {
+         LOG(FATAL) << "Incorrect input format for prelu layer.";
+       }
+     } else {
+       LOG(FATAL) << "Incorrect input format for prelu layer.";
+     }
+     input_grad = grad * ((input > 0.f) + temp1);  // dy/dx: 1 where x > 0, a where x <= 0
+     Tensor temp2 = grad * input * (input <= 0.f);  // dy/da = x where x <= 0
+     if (format_ == "NCHW") {
+       Tensor temp3(Shape{n * c}, grad.device(), grad.data_type());
+       temp2.Reshape(Shape{n * c, h * w});
+       SumColumns(temp2, &temp3);
+       temp3.Reshape(Shape{n, c});
+       SumRows(temp3, &da);
+     } else if (format_ == "NHWC") {
+       temp2.Reshape(Shape{n * h * w, c});
+       SumRows(temp2, &da);
+     }
+   } else {
+     // share the first param of Tensor A along all channels
+     LOG(FATAL) << "Not Implemented";
+     // TODO(wangwei) cannot access the data in this way. The data could be on GPU.
+     auto a = a_.data<float>()[0];
+     input_grad = grad * ((input > 0.f) + (input <= 0.f) * a);  // dy/dx, no extra x factor
+     Tensor temp = grad * input * (input <= 0.f);
+     float sum = Sum<float>(temp);
+     Uniform(1.f, 1.f, &da);
+     da *= sum;
+   }
+   param_grad.push_back(da);
+   return std::make_pair(input_grad, param_grad);
+ }
+ 
 -void PReLU::ToDevice(Device *device) {
++void PReLU::ToDevice(std::shared_ptr<Device> device) {
+   Layer::ToDevice(device);
+   a_.ToDevice(device);
+ }
+ 
+ } // namespace singa

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/dd08f413/src/model/layer/prelu.h
----------------------------------------------------------------------
diff --cc src/model/layer/prelu.h
index 0000000,ee571e1..70a9dcf
mode 000000,100644..100644
--- a/src/model/layer/prelu.h
+++ b/src/model/layer/prelu.h
@@@ -1,0 -1,66 +1,66 @@@
+ /**
+  * Licensed to the Apache Software Foundation (ASF) under one
+  * or more contributor license agreements.  See the NOTICE file
+  * distributed with this work for additional information
+  * regarding copyright ownership.  The ASF licenses this file
+  * to you under the Apache License, Version 2.0 (the
+  * "License"); you may not use this file except in compliance
+  * with the License.  You may obtain a copy of the License at
+  *
+  *     http://www.apache.org/licenses/LICENSE-2.0
+  *
+  * Unless required by applicable law or agreed to in writing, software
+  * distributed under the License is distributed on an "AS IS" BASIS,
+  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+  * See the License for the specific language governing permissions and
+  * limitations under the License.
+  */
+ #ifndef SINGA_MODEL_LAYER_PRELU_H_
+ #define SINGA_MODEL_LAYER_PRELU_H_
+ #include <utility>
+ #include <string>
+ #include <vector>
+ #include "singa/model/layer.h"
 -#include "singa_config.h"
++#include "singa/singa_config.h"
+ 
+ namespace singa {
+ class PReLU : public Layer {  // parametric ReLU: y = x if x > 0, else a * x
+  public:
+   /// \copydoc Layer::layer_type()
+   const std::string layer_type() const override { return "PReLU"; }
+ 
+ 
+   /// \copydoc Layer::Setup(const LayerConf&);
+   void Setup(const Shape& in_sample, const LayerConf& conf) override;
+   const Shape GetOutputSampleShape() const override {
+     CHECK(out_sample_shape_.size()) << "You may haven't call Setup()";
+     return out_sample_shape_;
+   }
+ 
+   /// \copydoc Layer::Forward(int flag, const Tensor&)
+   const Tensor Forward(int flag, const Tensor &input) override;
+ 
+   /// \copydoc Layer::Backward(int, const Tensor&, const Tensor&);
+   const std::pair<Tensor, vector<Tensor> > Backward(
+       int flag, const Tensor &grad) override;
+ 
++  void ToDevice(std::shared_ptr<Device> device) override;
+ 
+   const bool Channel_shared() const { return channel_shared_; }
+   const Tensor A() const { return a_; }
+   const std::string Format() const { return format_; }
+ 
+   void Set_a(const Tensor& a) {  // deep-copies a into a_ (shape, device, dtype, data)
+     a_.ResetLike(a);
+     a_.CopyData(a);
+   }
+ 
+  protected:
+   bool channel_shared_ = false;  // overwritten from conf in Setup()
+   std::string format_;  // format_ has two valid values, i.e. NCHW, NHWC
+   Tensor a_;            // per-channel slopes; NOTE(review): used via MultRow as length-c — confirm shape
+   std::stack<Tensor> buf_;
+   Shape out_sample_shape_;
+ };
+ }  // namespace singa
+ #endif  // SINGA_MODEL_LAYER_PRELU_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/dd08f413/src/proto/core.proto
----------------------------------------------------------------------
diff --cc src/proto/core.proto
index cf6e193,3031359..b853b30
--- a/src/proto/core.proto
+++ b/src/proto/core.proto
@@@ -44,16 -45,3 +45,16 @@@ enum CopyDirection 
    kDeviceToDevice = 3;
    kNumDirection = 4;
  }
 +
 +// configuration for device memory pool
 +message MemPoolConf {
 +	optional string type = 1 [default = "cnmem"];
 +	optional uint32 num_devices = 2 [default = 1];
- 	// allocation size for each device
- 	optional uint32 alloc_size = 3 [default = 10000000];
++	// allocation size for each device, default is 256 MB
++	optional uint32 alloc_size = 3 [default = 256];
 +	// memory manager flag for cnmem
 +	// cnmemflag = 0: default flag
 +	// cnmemflag = 1: prevent the manager from growing its memory consumption
 +	// cnmemflag = 2: prevent the manager from stealing memory
 +	optional uint32 cnmemflag = 4 [default = 0];
 +}

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/dd08f413/test/singa/test_adagrad.cc
----------------------------------------------------------------------
diff --cc test/singa/test_adagrad.cc
index 0000000,642e929..c45dcef
mode 000000,100644..100644
--- a/test/singa/test_adagrad.cc
+++ b/test/singa/test_adagrad.cc
@@@ -1,0 -1,96 +1,96 @@@
+ /************************************************************
+ *
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ *
+ *************************************************************/
+ 
+ #include "gtest/gtest.h"
+ #include "singa/model/optimizer.h"
 -#include "singa_config.h"
++#include "singa/singa_config.h"
+ #include <cmath>
+ 
+ TEST(Adagrad, ApplyCPU) {
+   singa::Adagrad adagrad;
+   float lr = 0.1f;
+   const float v[4] = {0.1, 0.2, 0.3, 0.4};
+   const float g[4] = {0.01, 0.02, 0.03, 0.04};
+ 
+   singa::Tensor value(singa::Shape{4}), grad(singa::Shape{4});
+   value.CopyDataFromHostPtr(v, 4);
+   grad.CopyDataFromHostPtr(g, 4);
+ 
+   singa::OptimizerConf conf;
+   adagrad.Setup(conf);
+   adagrad.Apply(0, lr, "xx", grad, &value);
+ 
+   singa::Tensor v1 = value.Clone();
+   const float* newv1 = v1.data<float>();
+   float history[4];
+   for (int i = 0; i < 4; ++i) history[i] = g[i] * g[i];
+   for (int i = 0; i < 4; ++i)
+     EXPECT_NEAR(newv1[i], v[i] - lr * g[i] / sqrt(history[i] + conf.delta()),
+                 1e-5);
+ 
+   grad.CopyDataFromHostPtr(g, 4);
+   adagrad.Apply(1, lr, "xx", grad, &value);
+   singa::Tensor v2 = value.Clone();
+   const float* newv2 = v2.data<float>();
+   for (int i = 0; i < 4; ++i) history[i] += g[i] * g[i];
+ 
+   for (int i = 0; i < 4; ++i)
+     EXPECT_NEAR(newv2[i],
+                 newv1[i] - lr * g[i] / sqrt(history[i] + conf.delta()), 1e-5);
+ }
+ 
+ #ifdef USE_CUDA
+ TEST(Adagrad, ApplyCUDA) {
+   singa::Adagrad adagrad;
+   float lr = 0.1f;
+   const float v[4] = {0.1, 0.2, 0.3, 0.4};
+   const float g[4] = {0.01, 0.02, 0.03, 0.04};
+ 
 -  singa::CudaGPU dev;
 -  singa::Tensor value(singa::Shape{4}, &dev), grad(singa::Shape{4}, &dev);
++  auto dev = std::make_shared<singa::CudaGPU>();
++  singa::Tensor value(singa::Shape{4}, dev), grad(singa::Shape{4}, dev);
+   value.CopyDataFromHostPtr(v, 4);
+   grad.CopyDataFromHostPtr(g, 4);
+ 
+   singa::OptimizerConf conf;
+   adagrad.Setup(conf);
+   adagrad.Apply(0, lr, "xx", grad, &value);
+ 
+   singa::Tensor v1 = value.Clone();
+   v1.ToHost();
+   const float* newv1 = v1.data<float>();
+   float history[4];
+   for (int i = 0; i < 4; ++i) history[i] = g[i] * g[i];
+   for (int i = 0; i < 4; ++i)
+     EXPECT_NEAR(newv1[i], v[i] - lr * g[i] / sqrt(history[i] + conf.delta()),
+                 1e-5);
+ 
+   grad.CopyDataFromHostPtr(g, 4);
+   adagrad.Apply(1, lr, "xx", grad, &value);
+   singa::Tensor v2 = value.Clone();
+   v2.ToHost();
+   const float* newv2 = v2.data<float>();
+   for (int i = 0; i < 4; ++i) history[i] += g[i] * g[i];
+ 
+   for (int i = 0; i < 4; ++i)
+     EXPECT_NEAR(newv2[i],  // same 1e-5 tolerance as the other checks
+                 newv1[i] - lr * g[i] / sqrt(history[i] + conf.delta()), 1e-5);
+ }
+ #endif

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/dd08f413/test/singa/test_cross_entropy.cc
----------------------------------------------------------------------
diff --cc test/singa/test_cross_entropy.cc
index 0000000,ce60f7c..d73591f
mode 000000,100644..100644
--- a/test/singa/test_cross_entropy.cc
+++ b/test/singa/test_cross_entropy.cc
@@@ -1,0 -1,116 +1,116 @@@
+ /************************************************************
+ *
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ *
+ *************************************************************/
+ 
+ #include "gtest/gtest.h"
+ #include "singa/core/tensor.h"
+ #include "singa/core/device.h"
+ #include "singa/model/loss.h"
 -#include "singa_config.h"
++#include "singa/singa_config.h"
+ 
+ using singa::Tensor;
+ class TestSoftmaxCrossEntropy : public ::testing::Test {
+  protected:
+   virtual void SetUp() {
+     p.Reshape(singa::Shape{2, 4});
+     t.Reshape(singa::Shape{2, 1});
+   }
+   const float pdat[8] = {0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1};
+   const int tdat[2] = {0, 2};
+ 
+   singa::Tensor p, t;
+ };
+ 
+ TEST_F(TestSoftmaxCrossEntropy, CppForward) {
+   p.CopyDataFromHostPtr(pdat, 8);
+   t.AsType(singa::kInt);
+   t.CopyDataFromHostPtr(tdat, 2);
+ 
+   singa::SoftmaxCrossEntropy cross_entropy;
+   const Tensor& loss = cross_entropy.Forward(p, t);
+   auto ldat = loss.data<float>();
+ 
+   const float result_test = -log(0.25);
+   EXPECT_FLOAT_EQ(ldat[0], result_test);
+   EXPECT_FLOAT_EQ(ldat[1], result_test);
+ }
+ 
+ TEST_F(TestSoftmaxCrossEntropy, CppBackward) {
+   p.CopyDataFromHostPtr(pdat, 8);
+   t.AsType(singa::kInt);
+   t.CopyDataFromHostPtr(tdat, 2);
+ 
+   singa::SoftmaxCrossEntropy cross_entropy;
+   cross_entropy.Forward(p, t);
+   const Tensor& grad = cross_entropy.Backward();
+ 
+   auto gdat = grad.data<float>();
+   EXPECT_FLOAT_EQ(gdat[0], -0.75);
+   EXPECT_FLOAT_EQ(gdat[1], 0.25);
+   EXPECT_FLOAT_EQ(gdat[2], 0.25);
+   EXPECT_FLOAT_EQ(gdat[3], 0.25);
+   EXPECT_FLOAT_EQ(gdat[4], 0.25);
+   EXPECT_FLOAT_EQ(gdat[5], 0.25);
+   EXPECT_FLOAT_EQ(gdat[6], -0.75);
+   EXPECT_FLOAT_EQ(gdat[7], 0.25);
+ }
+ 
+ #ifdef USE_CUDA
+ 
+ TEST_F(TestSoftmaxCrossEntropy, CudaForward) {
+   singa::SoftmaxCrossEntropy cross_entropy;
 -  singa::CudaGPU dev;
 -  p.ToDevice(&dev);
 -  t.ToDevice(&dev);
++  auto dev = std::make_shared<singa::CudaGPU>();
++  p.ToDevice(dev);
++  t.ToDevice(dev);
+   p.CopyDataFromHostPtr(pdat, 8);
+   t.CopyDataFromHostPtr(tdat, 2);
+ 
+   Tensor loss = cross_entropy.Forward(p, t);
+   loss.ToHost();
+   auto ldat = loss.data<float>();
+ 
+   const float result_test = -log(0.25);
+   EXPECT_FLOAT_EQ(ldat[0], result_test);
+   EXPECT_FLOAT_EQ(ldat[1], result_test);
+ }
+ 
+ TEST_F(TestSoftmaxCrossEntropy, CudaBackward) {
+   singa::SoftmaxCrossEntropy cross_entropy;
 -  singa::CudaGPU dev;
 -  p.ToDevice(&dev);
 -  t.ToDevice(&dev);
++  auto dev = std::make_shared<singa::CudaGPU>();
++  p.ToDevice(dev);
++  t.ToDevice(dev);
+   p.CopyDataFromHostPtr(pdat, 8);
+   t.CopyDataFromHostPtr(tdat, 2);
+ 
+   cross_entropy.Forward(p, t);
+   Tensor grad = cross_entropy.Backward();
+ 
+   grad.ToHost();
+   auto gdat = grad.data<float>();
+   EXPECT_FLOAT_EQ(gdat[0], -0.75);
+   EXPECT_FLOAT_EQ(gdat[1], 0.25);
+   EXPECT_FLOAT_EQ(gdat[2], 0.25);
+   EXPECT_FLOAT_EQ(gdat[3], 0.25);
+   EXPECT_FLOAT_EQ(gdat[4], 0.25);
+   EXPECT_FLOAT_EQ(gdat[5], 0.25);
+   EXPECT_FLOAT_EQ(gdat[6], -0.75);
+   EXPECT_FLOAT_EQ(gdat[7], 0.25);
+ }
+ #endif  // USE_CUDA

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/dd08f413/test/singa/test_cudnn_activation.cc
----------------------------------------------------------------------
diff --cc test/singa/test_cudnn_activation.cc
index bed7715,940c6b9..1a619e7
--- a/test/singa/test_cudnn_activation.cc
+++ b/test/singa/test_cudnn_activation.cc
@@@ -46,8 -46,8 +46,8 @@@ TEST(TCudnnActivation, Setup) 
  TEST(TCudnnActivation, Forward) {
    const float x[] = {1.0f, 2.0f, 3.0f, -2.0f, -3.0f, -4.0};
    size_t n = sizeof(x) / sizeof(float);
--  singa::CudaGPU cuda(0, 1);
--  singa::Tensor in(singa::Shape{n}, &cuda);
++  auto cuda = std::make_shared<singa::CudaGPU>(0, 1);
++  singa::Tensor in(singa::Shape{n}, cuda);
    in.CopyDataFromHostPtr<float>(x, n);
  
    float neg_slope = 0.5f;
@@@ -66,9 -65,9 +65,8 @@@
  
      singa::Tensor out = acti.Forward(singa::kTrain, in);
      EXPECT_EQ(n, out.Size());
--    singa::CppCPU host(0, 1);
--    out.ToDevice(&host);
-     const float* yptr = out.data<const float*>();
++    out.ToHost();
+     const float* yptr = out.data<float>();
      float* y = new float[n];
      if (acti.Mode() == "SIGMOID") {
        for (size_t i = 0; i < n; i++) y[i] = 1.f / (1.f + exp(-x[i]));
@@@ -87,8 -86,8 +85,8 @@@
  TEST(TCudnnActivation, Backward) {
    const float x[] = {2.0f, 3.0f, 3.0f, 7.f, 0.0f, 5.0, 1.5, 2.5, -2.5, 1.5};
    size_t n = sizeof(x) / sizeof(float);
--  singa::CudaGPU cuda(0, 1);
--  singa::Tensor in(singa::Shape{n}, &cuda);
++  auto cuda = std::make_shared<singa::CudaGPU>(0, 1);
++  singa::Tensor in(singa::Shape{n}, cuda);
    in.CopyDataFromHostPtr<float>(x, n);
    float neg_slope = 0.5f;
    std::string types[] = {"SIGMOID", "TANH", "RELU"};
@@@ -101,22 -100,21 +99,20 @@@
        singa::ReLUConf* reluconf = conf.mutable_relu_conf();
        reluconf->set_negative_slope(neg_slope);
      }
-     acti.Setup(conf);
-     acti.InitCudnn(n, singa::kFloat32);
+     acti.Setup(Shape{n}, conf);
      singa::Tensor out = acti.Forward(singa::kTrain, in);
      EXPECT_EQ(n, out.Size());
--    singa::CppCPU host(0, 1);
--    out.ToDevice(&host);
-     const float* yptr = out.data<const float*>();
++    out.ToHost();
+     const float* yptr = out.data<float>();
  
      const float grad[] = {2.0f, 1.0f, 2.0f, 0.0f, -2.0f,
                            -1.0, 1.5,  2.5,  -1.5, -2.5};
--    singa::Tensor out_diff(singa::Shape{n}, &cuda);
++    singa::Tensor out_diff(singa::Shape{n}, cuda);
      out_diff.CopyDataFromHostPtr<float>(grad, n);
      const auto ret = acti.Backward(singa::kTrain, out_diff);
      singa::Tensor in_diff = ret.first;
--    in_diff.ToDevice(&host);
-     const float* xptr = in_diff.data<const float*>();
++    in_diff.ToHost();
+     const float* xptr = in_diff.data<float>();
      float* dx = new float[n];
      if (acti.Mode() == "SIGMOID") {
        for (size_t i = 0; i < n; i++) dx[i] = grad[i] * yptr[i] * (1. - yptr[i]);

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/dd08f413/test/singa/test_cudnn_batchnorm.cc
----------------------------------------------------------------------
diff --cc test/singa/test_cudnn_batchnorm.cc
index d38fdaa,b3b6477..7067b16
--- a/test/singa/test_cudnn_batchnorm.cc
+++ b/test/singa/test_cudnn_batchnorm.cc
@@@ -56,34 -53,31 +53,30 @@@ TEST(CudnnBatchNorm, Forward) 
      0.150676, 0.153442, -0.0929899, -0.148675,
      -0.112459, -0.106284, -0.103074, -0.0668811
    };
--  singa::CudaGPU cuda(0, 1);
--  singa::Tensor in(singa::Shape{1,2,4,4}, &cuda);
++  auto cuda = std::make_shared<singa::CudaGPU>(0, 1);
++  singa::Tensor in(singa::Shape{1,2,4,4}, cuda);
    in.CopyDataFromHostPtr(x, 1*2*4*4);
    const float alpha_[] = {1, 1};
--  singa::Tensor alpha(singa::Shape{1,2,1,1}, &cuda);
++  singa::Tensor alpha(singa::Shape{1,2,1,1}, cuda);
    alpha.CopyDataFromHostPtr(alpha_, 1*2*1*1);
  
    const float beta_[] = {0, 0};
--  singa::Tensor beta(singa::Shape{1,2,1,1}, &cuda);
++  singa::Tensor beta(singa::Shape{1,2,1,1}, cuda);
    beta.CopyDataFromHostPtr(beta_, 1*2*1*1);
  
    singa::LayerConf conf;
    singa::BatchNormConf *batchnorm_conf = conf.mutable_batchnorm_conf();
    batchnorm_conf->set_factor(0.9);
-   batchnorm_conf->set_channels(2);
-   batchnorm_conf->set_height(4);
-   batchnorm_conf->set_width(4);
-   batchnorm.Setup(conf);
+   batchnorm.Setup(Shape{2, 4, 4}, conf);
  
--  batchnorm.ToDevice(&cuda);
++  batchnorm.ToDevice(cuda);
    batchnorm.set_bnScale(alpha);
    batchnorm.set_bnBias(beta);
    batchnorm.set_runningMean(beta);
    batchnorm.set_runningVariance(beta);
    singa::Tensor out = batchnorm.Forward(singa::kTrain, in);
--  singa::CppCPU host(0, 1);
    out.ToHost();
-   const float *outptr = out.data<const float *>();
+   const float *outptr = out.data<float>();
    const auto & shape = out.shape();
    EXPECT_EQ(4u, shape.size());
    EXPECT_EQ(1u, shape[0]);
@@@ -136,8 -130,8 +129,8 @@@ TEST(CudnnBatchNorm, Backward) 
      0.150676, 0.153442, -0.0929899, -0.148675,
      -0.112459, -0.106284, -0.103074, -0.0668811
    };
--  singa::CudaGPU cuda(0, 1);
--  singa::Tensor x_tensor(singa::Shape{1,2,4,4}, &cuda);
++  auto cuda = std::make_shared<singa::CudaGPU>(0, 1);
++  singa::Tensor x_tensor(singa::Shape{1,2,4,4}, cuda);
    x_tensor.CopyDataFromHostPtr(x, 1*2*4*4);
  
    singa::LayerConf conf;
@@@ -159,35 -150,35 +149,34 @@@
      0.00468428, 0.00735506, -0.00682525, 0.00342023
    };
  
--  singa::Tensor dy_tensor(singa::Shape{1,2,4,4}, &cuda);
++  singa::Tensor dy_tensor(singa::Shape{1,2,4,4}, cuda);
    dy_tensor.CopyDataFromHostPtr(dy, 1*2*4*4);
    const float alpha_[] = {1, 1};
--  singa::Tensor alpha(singa::Shape{1,2,1,1}, &cuda);
++  singa::Tensor alpha(singa::Shape{1,2,1,1}, cuda);
    alpha.CopyDataFromHostPtr(alpha_, 1*2*1*1);
  
    const float beta_[] = {0, 0};
--  singa::Tensor beta(singa::Shape{1,2,1,1}, &cuda);
++  singa::Tensor beta(singa::Shape{1,2,1,1}, cuda);
    beta.CopyDataFromHostPtr(beta_, 1*2*1*1);
  
    const float mean_[] = {0.0123405, -0.0622333};
--  singa::Tensor mean(singa::Shape{1,2,1,1}, &cuda);
++  singa::Tensor mean(singa::Shape{1,2,1,1}, cuda);
    mean.CopyDataFromHostPtr(mean_, 1*2*1*1);
  
    const float var_[] = {15.9948, 8.68198};
--  singa::Tensor var(singa::Shape{1,2,1,1}, &cuda);
++  singa::Tensor var(singa::Shape{1,2,1,1}, cuda);
    var.CopyDataFromHostPtr(var_, 1*2*1*1);
  
--  batchnorm.ToDevice(&cuda);
++  batchnorm.ToDevice(cuda);
    batchnorm.set_bnScale(alpha);
    batchnorm.set_bnBias(beta);
    batchnorm.set_runningMean(beta);
    batchnorm.set_runningVariance(beta);
    batchnorm.Forward(singa::kTrain, x_tensor);
    const auto ret = batchnorm.Backward(singa::kTrain, dy_tensor);
--  singa::CppCPU host(0, 1);
    singa::Tensor dx = ret.first;
--  dx.ToDevice(&host);
-   const float *dxptr = dx.data<const float *>();
++  dx.ToHost();
+   const float *dxptr = dx.data<float>();
    const auto & shape = dx.shape();
    EXPECT_EQ(4u, shape.size());
    EXPECT_EQ(1u, shape[0]);
@@@ -228,8 -219,8 +217,8 @@@
    EXPECT_NEAR(0.0217477, dxptr[31], 1e-4f);
  
    singa::Tensor dbnScale = ret.second.at(0);
--  dbnScale.ToDevice(&host);
-   const float *dbnScaleptr = dbnScale.data<const float *>();
++  dbnScale.ToHost();
+   const float *dbnScaleptr = dbnScale.data<float>();
    const auto & dbnScaleShape = dbnScale.shape();
    EXPECT_EQ(4u, dbnScaleShape.size());
    EXPECT_EQ(1u, dbnScaleShape[0]);
@@@ -241,8 -232,8 +230,8 @@@
    EXPECT_NEAR(-0.00219431f, dbnScaleptr[1], 1e-4f);
  
    singa::Tensor dbnBias = ret.second.at(1);
--  dbnBias.ToDevice(&host);
-   const float *dbnBiasptr = dbnBias.data<const float *>();
++  dbnBias.ToHost();
+   const float *dbnBiasptr = dbnBias.data<float>();
    const auto & dbnBiasShape = dbnBias.shape();
    EXPECT_EQ(4u, dbnBiasShape.size());
    EXPECT_EQ(1u, dbnBiasShape[0]);

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/dd08f413/test/singa/test_cudnn_convolution.cc
----------------------------------------------------------------------
diff --cc test/singa/test_cudnn_convolution.cc
index 2a17da2,44077b7..3b84645
--- a/test/singa/test_cudnn_convolution.cc
+++ b/test/singa/test_cudnn_convolution.cc
@@@ -65,18 -63,18 +63,18 @@@ TEST(CudnnConvolution, Forward) 
    const size_t batchsize = 1, c = 1, h = 3, w = 3;
    const float x[batchsize * c * h * w] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
                                            6.0f, 7.0f, 8.0f, 9.0f};
--  singa::CudaGPU cuda(0, 1);
--  singa::Tensor in(singa::Shape{batchsize, c, h, w}, &cuda);
++  auto cuda = std::make_shared<singa::CudaGPU>(0, 1);
++  singa::Tensor in(singa::Shape{batchsize, c, h, w}, cuda);
    in.CopyDataFromHostPtr(x, batchsize * c * h * w);
  
    // Set weight and bias manually
    const size_t num_filters = 1;
    const float we[num_filters * batchsize * h * w] = {
        1.0f, 1.0f, 0.0f, 0.0f, 0.0f, -1.0f, 0.0f, 1.0f, 0.0f};
--  singa::Tensor weight(singa::Shape{num_filters, batchsize * h * w}, &cuda);
++  singa::Tensor weight(singa::Shape{num_filters, batchsize * h * w}, cuda);
    weight.CopyDataFromHostPtr(we, batchsize * h * w);
    const float b[num_filters] = {1.0f};
--  singa::Tensor bias(singa::Shape{num_filters}, &cuda);
++  singa::Tensor bias(singa::Shape{num_filters}, cuda);
    bias.CopyDataFromHostPtr(b, num_filters);
    CudnnConvolution conv;
    conv.set_weight(weight);
@@@ -102,9 -97,9 +97,8 @@@
  
    // Parameter "flag" does not influence convolution
    singa::Tensor out1 = conv.Forward(singa::kTrain, in);
--  singa::CppCPU host(0, 1);
--  out1.ToDevice(&host);
-   const float *outptr1 = out1.data<const float *>();
++  out1.ToHost();
+   const float *outptr1 = out1.data<float>();
    // Input: 3*3; kernel: 3*3; stride: 2*2; padding: 1*1.
    EXPECT_EQ(4u, out1.Size());
  
@@@ -119,8 -114,8 +113,8 @@@ TEST(CudnnConvolution, Backward) 
    const size_t batchsize = 1, c = 1, src_h = 3, src_w = 3;
    const float x[batchsize * c * src_h * src_w] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
                                                    6.0f, 7.0f, 8.0f, 9.0f};
--  singa::CudaGPU cuda(0, 1);
--  singa::Tensor in(singa::Shape{batchsize, c, src_h, src_w}, &cuda);
++  auto cuda = std::make_shared<singa::CudaGPU>(0, 1);
++  singa::Tensor in(singa::Shape{batchsize, c, src_h, src_w}, cuda);
    in.CopyDataFromHostPtr(x, batchsize * c * src_h * src_w);
  
    // Set weight_ and bias_ manually
@@@ -128,10 -123,10 +122,10 @@@
    const float we[num_filters * batchsize * src_h * src_w] = {
        1.0f, 1.0f, 0.0f, 0.0f, 0.0f, -1.0f, 0.0f, 1.0f, 0.0f};
    singa::Tensor weight(singa::Shape{num_filters, batchsize * src_h * src_w},
--                       &cuda);
++                       cuda);
    weight.CopyDataFromHostPtr(we, batchsize * src_h * src_w);
    const float b[num_filters] = {1.0f};
--  singa::Tensor bias(singa::Shape{num_filters}, &cuda);
++  singa::Tensor bias(singa::Shape{num_filters}, cuda);
    bias.CopyDataFromHostPtr(b, num_filters);
    CudnnConvolution conv;
    conv.set_weight(weight);
@@@ -162,14 -154,14 +153,13 @@@
    const float dy[batchsize * num_filters * grad_h * grad_w] = {0.1f, 0.2f, 0.3f,
                                                                 0.4f};
    singa::Tensor grad(singa::Shape{batchsize, num_filters, grad_h, grad_w},
--                     &cuda);
++                     cuda);
    grad.CopyDataFromHostPtr(dy, batchsize * num_filters * grad_h * grad_w);
  
    const auto ret = conv.Backward(singa::kTrain, grad);
--  singa::CppCPU host(0, 1);
    singa::Tensor in_grad = ret.first;
--  in_grad.ToDevice(&host);
-   const float *dx = in_grad.data<const float *>();
++  in_grad.ToHost();
+   const float *dx = in_grad.data<float>();
    const float *wptr = we;
    EXPECT_EQ(9u, in_grad.Size());
    EXPECT_EQ(dy[0] * wptr[4], dx[0]);
@@@ -186,12 -178,12 +176,12 @@@
  
    singa::Tensor dw = ret.second[0];
    singa::Tensor db = ret.second[1];
--  dw.ToDevice(&host);
--  db.ToDevice(&host);
-   const float *dbptr = db.data<const float *>();
++  dw.ToHost();
++  db.ToHost();
+   const float *dbptr = db.data<float>();
    EXPECT_EQ(dy[0] + dy[1] + dy[2] + dy[3], dbptr[0]);
  
-   const float *dwptr = dw.data<const float *>();
+   const float *dwptr = dw.data<float>();
    EXPECT_EQ(9u, dw.Size());
    EXPECT_EQ(dy[3] * x[4], dwptr[0]);
    EXPECT_EQ(dy[3] * x[5] + dy[2] * x[3], dwptr[1]);
@@@ -246,18 -235,18 +233,19 @@@ TEST(CudnnConvolution_AT, Forward) 
    const size_t batchsize = 1, c = 1, h = 3, w = 3;
    const float x[batchsize * c * h * w] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
                                            6.0f, 7.0f, 8.0f, 9.0f};
--  singa::CudaGPU cuda(0, 1);
--  singa::Tensor in(singa::Shape{batchsize, c, h, w}, &cuda);
++
++  auto cuda = std::make_shared<singa::CudaGPU>(0, 1);
++  singa::Tensor in(singa::Shape{batchsize, c, h, w}, cuda);
    in.CopyDataFromHostPtr(x, batchsize * c * h * w);
  
    // Set weight and bias manually
    const size_t num_filters = 1;
    const float we[num_filters * batchsize * h * w] = {
        1.0f, 1.0f, 0.0f, 0.0f, 0.0f, -1.0f, 0.0f, 1.0f, 0.0f};
--  singa::Tensor weight(singa::Shape{num_filters, batchsize * h * w}, &cuda);
++  singa::Tensor weight(singa::Shape{num_filters, batchsize * h * w}, cuda);
    weight.CopyDataFromHostPtr(we, batchsize * h * w);
    const float b[num_filters] = {1.0f};
--  singa::Tensor bias(singa::Shape{num_filters}, &cuda);
++  singa::Tensor bias(singa::Shape{num_filters}, cuda);
    bias.CopyDataFromHostPtr(b, num_filters);
    CudnnConvolution conv;
    conv.set_weight(weight);
@@@ -283,9 -269,9 +268,8 @@@
  
    // Parameter "flag" does not influence convolution
    singa::Tensor out1 = conv.Forward(singa::kTrain, in);
--  singa::CppCPU host(0, 1);
--  out1.ToDevice(&host);
-   const float *outptr1 = out1.data<const float *>();
++  out1.ToHost();
+   const float *outptr1 = out1.data<float>();
    // Input: 3*3; kernel: 3*3; stride: 2*2; padding: 1*1.
    EXPECT_EQ(4u, out1.Size());
  
@@@ -300,8 -286,8 +284,9 @@@ TEST(CudnnConvolution_AT, Backward) 
    const size_t batchsize = 1, c = 1, src_h = 3, src_w = 3;
    const float x[batchsize * c * src_h * src_w] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
                                                    6.0f, 7.0f, 8.0f, 9.0f};
--  singa::CudaGPU cuda(0, 1);
--  singa::Tensor in(singa::Shape{batchsize, c, src_h, src_w}, &cuda);
++
++  auto cuda = std::make_shared<singa::CudaGPU>(0, 1);
++  singa::Tensor in(singa::Shape{batchsize, c, src_h, src_w}, cuda);
    in.CopyDataFromHostPtr(x, batchsize * c * src_h * src_w);
  
    // Set weight_ and bias_ manually
@@@ -309,10 -295,10 +294,10 @@@
    const float we[num_filters * batchsize * src_h * src_w] = {
        1.0f, 1.0f, 0.0f, 0.0f, 0.0f, -1.0f, 0.0f, 1.0f, 0.0f};
    singa::Tensor weight(singa::Shape{num_filters, batchsize * src_h * src_w},
--                       &cuda);
++                       cuda);
    weight.CopyDataFromHostPtr(we, batchsize * src_h * src_w);
    const float b[num_filters] = {1.0f};
--  singa::Tensor bias(singa::Shape{num_filters}, &cuda);
++  singa::Tensor bias(singa::Shape{num_filters}, cuda);
    bias.CopyDataFromHostPtr(b, num_filters);
    CudnnConvolution conv;
    conv.set_weight(weight);
@@@ -343,14 -326,14 +325,13 @@@
    const float dy[batchsize * num_filters * grad_h * grad_w] = {0.1f, 0.2f, 0.3f,
                                                                 0.4f};
    singa::Tensor grad(singa::Shape{batchsize, num_filters, grad_h, grad_w},
--                     &cuda);
++                     cuda);
    grad.CopyDataFromHostPtr(dy, batchsize * num_filters * grad_h * grad_w);
  
    const auto ret = conv.Backward(singa::kTrain, grad);
--  singa::CppCPU host(0, 1);
    singa::Tensor in_grad = ret.first;
--  in_grad.ToDevice(&host);
-   const float *dx = in_grad.data<const float *>();
++  in_grad.ToHost();
+   const float *dx = in_grad.data<float>();
    const float *wptr = we;
    EXPECT_EQ(9u, in_grad.Size());
    EXPECT_EQ(dy[0] * wptr[4], dx[0]);
@@@ -367,12 -350,12 +348,12 @@@
  
    singa::Tensor dw = ret.second[0];
    singa::Tensor db = ret.second[1];
--  dw.ToDevice(&host);
--  db.ToDevice(&host);
-   const float *dbptr = db.data<const float *>();
++  dw.ToHost();
++  db.ToHost();
+   const float *dbptr = db.data<float>();
    EXPECT_EQ(dy[0] + dy[1] + dy[2] + dy[3], dbptr[0]);
  
-   const float *dwptr = dw.data<const float *>();
+   const float *dwptr = dw.data<float>();
    EXPECT_EQ(9u, dw.Size());
    EXPECT_EQ(dy[3] * x[4], dwptr[0]);
    EXPECT_EQ(dy[3] * x[5] + dy[2] * x[3], dwptr[1]);

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/dd08f413/test/singa/test_cudnn_dropout.cc
----------------------------------------------------------------------
diff --cc test/singa/test_cudnn_dropout.cc
index 32572d0,419dd0c..d06a254
--- a/test/singa/test_cudnn_dropout.cc
+++ b/test/singa/test_cudnn_dropout.cc
@@@ -48,8 -49,8 +49,8 @@@ TEST(CudnnDropout, Setup) 
  TEST(CudnnDropout, Forward) {
    const float x[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f};
    size_t n = sizeof(x) / sizeof(float);
--  singa::CudaGPU cuda(0, 1);
--  singa::Tensor in(singa::Shape{n}, &cuda);
++  auto cuda = std::make_shared<singa::CudaGPU>(0, 1);
++  singa::Tensor in(singa::Shape{n}, cuda);
    in.CopyDataFromHostPtr(x, n);
  
    float pdrop = 0.5;
@@@ -67,9 -68,9 +68,8 @@@
    for (size_t i = 0; i < n; i++)
      EXPECT_FLOAT_EQ(0, GetBitValue(mptr, i) * (GetBitValue(mptr, i) - 1));
  
--  singa::CppCPU host(0, 1);
--  out1.ToDevice(&host);
-   const float* outptr1 = out1.data<const float*>();
++  out1.ToHost();
+   const float* outptr1 = out1.data<float>();
    EXPECT_EQ(n, out1.Size());
    float scale = 1.0f / (1.0f - pdrop);
    // the output value should be 0 or the same as the input
@@@ -78,9 -79,9 +78,9 @@@
    EXPECT_EQ(0.f, outptr1[7] * (outptr1[7] - scale * x[7]));
  
    singa::Tensor out2 = drop.Forward(singa::kEval, in);
--  out2.ToDevice(&host);
++  out2.ToHost();
    EXPECT_EQ(n, out2.Size());
-   const float* outptr2 = out2.data<const float*>();
+   const float* outptr2 = out2.data<float>();
    // the output value should be the same as the input
    EXPECT_EQ(x[0], outptr2[0]);
    EXPECT_EQ(x[1], outptr2[1]);
@@@ -90,8 -91,8 +90,8 @@@
  TEST(CudnnDropout, Backward) {
    const float x[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f};
    size_t n = sizeof(x) / sizeof(float);
--  singa::CudaGPU cuda(0, 1);
--  singa::Tensor in(singa::Shape{n}, &cuda);
++  auto cuda = std::make_shared<singa::CudaGPU>(0, 1);
++  singa::Tensor in(singa::Shape{n}, cuda);
    in.CopyDataFromHostPtr(x, n);
  
    float pdrop = 0.5;
@@@ -105,14 -106,14 +105,13 @@@
    singa::Tensor out1 = drop.Forward(singa::kTrain, in);
  
    const float dy[] = {4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 1.0f, 2.0f, 3.0f};
--  singa::Tensor grad(singa::Shape{n}, &cuda);
++  singa::Tensor grad(singa::Shape{n}, cuda);
    grad.CopyDataFromHostPtr(dy, n);
  
    const auto ret = drop.Backward(singa::kTrain, grad);
--  singa::CppCPU host(0, 1);
    singa::Tensor in_grad = ret.first;
--  in_grad.ToDevice(&host);
-   const float* dx = in_grad.data<const float*>();
++  in_grad.ToHost();
+   const float* dx = in_grad.data<float>();
  
    singa::Tensor mask(drop.mask().shape(), drop.mask().data_type());
    mask.CopyData(drop.mask());

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/dd08f413/test/singa/test_cudnn_lrn.cc
----------------------------------------------------------------------
diff --cc test/singa/test_cudnn_lrn.cc
index 390c588,f7ec046..4ee0c54
--- a/test/singa/test_cudnn_lrn.cc
+++ b/test/singa/test_cudnn_lrn.cc
@@@ -58,8 -58,8 +58,8 @@@ TEST(CudnnLRN, Forward) 
      0.0597329, -0.0530868, 0.0124246, 0.108429,
      0.0451175, 0.0247055, 0.0304345, 0.0179575
    };
--  singa::CudaGPU cuda(0, 1);
--  singa::Tensor in(singa::Shape{1,2,4,4}, &cuda);
++  auto cuda = std::make_shared<singa::CudaGPU>(0, 1);
++  singa::Tensor in(singa::Shape{1,2,4,4}, cuda);
    in.CopyDataFromHostPtr(x, 1*2*4*4);
  
    singa::LayerConf conf;
@@@ -68,12 -68,12 +68,11 @@@
    lrn_conf->set_local_size(3);
    lrn_conf->set_alpha(0.1);
    lrn_conf->set_beta(0.75);
-   lrn.Setup(conf);
+   lrn.Setup(Shape{2, 4, 4}, conf);
  
    singa::Tensor out = lrn.Forward(singa::kTrain, in);
--  singa::CppCPU host(0, 1);
--  out.ToDevice(&host);
-   const float *outptr = out.data<const float *>();
++  out.ToHost();
+   const float *outptr = out.data<float>();
    const auto & shape = out.shape();
    EXPECT_EQ(4u, shape.size());
    EXPECT_EQ(1u, shape[0]);
@@@ -128,8 -128,8 +127,8 @@@ TEST(CudnnLRN, Backward) 
      0.0597329, -0.0530868, 0.0124246, 0.108429,
      0.0451175, 0.0247055, 0.0304345, 0.0179575
    };
--  singa::CudaGPU cuda(0, 1);
--  singa::Tensor x_tensor(singa::Shape{1,2,4,4}, &cuda);
++  auto cuda = std::make_shared<singa::CudaGPU>(0, 1);
++  singa::Tensor x_tensor(singa::Shape{1,2,4,4}, cuda);
    x_tensor.CopyDataFromHostPtr(x, 1*2*4*4);
  
    const float dy[] = {
@@@ -143,7 -143,7 +142,7 @@@
      0.177807, 0.000892812, -0.00113197, 0.00327798
    };
  
--  singa::Tensor dy_tensor(singa::Shape{1,2,4,4}, &cuda);
++  singa::Tensor dy_tensor(singa::Shape{1,2,4,4}, cuda);
    dy_tensor.CopyDataFromHostPtr(dy, 1*2*4*4);
  
    singa::LayerConf conf;
@@@ -156,10 -156,10 +155,9 @@@
  
    lrn.Forward(singa::kTrain, x_tensor);
    const auto ret = lrn.Backward(singa::kTrain, dy_tensor);
--  singa::CppCPU host(0, 1);
    singa::Tensor dx = ret.first;
--  dx.ToDevice(&host);
-   const float *dxptr = dx.data<const float *>();
++  dx.ToHost();
+   const float *dxptr = dx.data<float>();
    const auto & shape = dx.shape();
    EXPECT_EQ(4u, shape.size());
    EXPECT_EQ(1u, shape[0]);

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/dd08f413/test/singa/test_cudnn_pooling.cc
----------------------------------------------------------------------
diff --cc test/singa/test_cudnn_pooling.cc
index e66f212,2a98ab4..79051a3
--- a/test/singa/test_cudnn_pooling.cc
+++ b/test/singa/test_cudnn_pooling.cc
@@@ -58,8 -56,8 +56,8 @@@ TEST(CudnnPooling, Forward) 
    const size_t batchsize = 1, c = 1, h = 3, w = 3;
    const float x[batchsize * c * h * w] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
                                            6.0f, 7.0f, 8.0f, 9.0f};
--  singa::CudaGPU cuda(0, 1);
--  singa::Tensor in(singa::Shape{batchsize, c, h, w}, &cuda);
++  auto cuda = std::make_shared<singa::CudaGPU>(0, 1);
++  singa::Tensor in(singa::Shape{batchsize, c, h, w}, cuda);
    in.CopyDataFromHostPtr(x, batchsize * c * h * w);
  
    CudnnPooling pool;
@@@ -79,9 -74,9 +74,8 @@@
  
    // Parameter "flag" does not influence pooling
    singa::Tensor out1 = pool.Forward(singa::kTrain, in);
--  singa::CppCPU host(0, 1);
--  out1.ToDevice(&host);
-   const float *outptr1 = out1.data<const float *>();
++  out1.ToHost();
+   const float *outptr1 = out1.data<float>();
    // Input: 3*3; kernel: 2*2; stride: 1*1; no padding.
    EXPECT_EQ(4u, out1.Size());
    EXPECT_EQ(5.0f, outptr1[0]);
@@@ -95,8 -90,8 +89,8 @@@ TEST(CudnnPooling, Backward) 
    const size_t batchsize = 1, c = 1, src_h = 3, src_w = 3;
    const float x[batchsize * src_h * src_w] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
                                                6.0f, 7.0f, 8.0f, 9.0f};
--  singa::CudaGPU cuda(0, 1);
--  singa::Tensor in(singa::Shape{batchsize, c, src_h, src_w}, &cuda);
++  auto cuda = std::make_shared<singa::CudaGPU>(0, 1);
++  singa::Tensor in(singa::Shape{batchsize, c, src_h, src_w}, cuda);
    in.CopyDataFromHostPtr(x, batchsize * c * src_h * src_w);
  
    CudnnPooling pool;
@@@ -119,14 -111,14 +110,13 @@@
    // grad
    const size_t grad_h = 2, grad_w = 2;
    const float dy[batchsize * c * grad_h * grad_w] = {0.1f, 0.2f, 0.3f, 0.4f};
--  singa::Tensor grad(singa::Shape{batchsize, c, grad_h, grad_w}, &cuda);
++  singa::Tensor grad(singa::Shape{batchsize, c, grad_h, grad_w}, cuda);
    grad.CopyDataFromHostPtr(dy, batchsize * c * grad_h * grad_w);
  
    const auto ret = pool.Backward(singa::kTrain, grad);
--  singa::CppCPU host(0, 1);
    singa::Tensor in_grad = ret.first;
--  in_grad.ToDevice(&host);
-   const float *dx = in_grad.data<const float *>();
++  in_grad.ToHost();
+   const float *dx = in_grad.data<float>();
    EXPECT_EQ(9u, in_grad.Size());
    EXPECT_EQ(0.0f, dx[0]);
    EXPECT_EQ(0.0f, dx[1]);