Posted to commits@singa.apache.org by ka...@apache.org on 2017/04/11 09:01:50 UTC

incubator-singa git commit: SINGA-309 Update the layer setting/config dynamically

Repository: incubator-singa
Updated Branches:
  refs/heads/master 85dbad744 -> 38da78914


SINGA-309 Update the layer setting/config dynamically

Update the layer settings and internal variables (e.g., cudnn descriptors) if the input data's batch size has changed since the previous iteration.

TODO(wangwei) update the layer settings if the input sample shape changes.
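
Each cudnn layer touched below follows the same pattern: create the cuDNN descriptors once, then on every forward pass query the cached tensor descriptor and re-set it when the batch size differs. The following standalone sketch (a hypothetical BatchAwareDescriptor helper, not part of the SINGA API) illustrates that pattern with plain cuDNN calls; the actual layers additionally require that the per-sample shape (c, h, w) stays fixed across iterations.

// Rough, self-contained sketch (hypothetical helper, not SINGA code) of the
// descriptor handling applied to the cudnn_* layers in this commit: create the
// cuDNN descriptor once, then re-set it whenever the incoming batch size no
// longer matches the one cached in the descriptor.
#include <cudnn.h>
#include <cstdio>
#include <cstdlib>

#define CUDNN_CHECK(expr)                                                     \
  do {                                                                        \
    cudnnStatus_t status = (expr);                                            \
    if (status != CUDNN_STATUS_SUCCESS) {                                     \
      std::fprintf(stderr, "cudnn error: %s\n", cudnnGetErrorString(status)); \
      std::exit(1);                                                           \
    }                                                                         \
  } while (0)

class BatchAwareDescriptor {
 public:
  // Call once per forward pass with the current NCHW input shape.
  void Update(int n, int c, int h, int w) {
    if (!has_init_) {
      CUDNN_CHECK(cudnnCreateTensorDescriptor(&desc_));  // created only once
      has_init_ = true;
    } else {
      int pn, pc, ph, pw, s;
      cudnnDataType_t type;
      CUDNN_CHECK(cudnnGetTensor4dDescriptor(desc_, &type, &pn, &pc, &ph, &pw,
                                             &s, &s, &s, &s));
      if (pn == n && pc == c && ph == h && pw == w) return;  // shape unchanged
      // In the SINGA layers a change in (c, h, w) trips a CHECK failure;
      // only the batch size n is expected to vary between iterations.
    }
    CUDNN_CHECK(cudnnSetTensor4dDescriptor(desc_, CUDNN_TENSOR_NCHW,
                                           CUDNN_DATA_FLOAT, n, c, h, w));
  }

  ~BatchAwareDescriptor() {
    if (has_init_) cudnnDestroyTensorDescriptor(desc_);
  }

 private:
  cudnnTensorDescriptor_t desc_;
  bool has_init_ = false;
};

int main() {
  BatchAwareDescriptor d;
  d.Update(32, 3, 224, 224);  // first iteration: create + set
  d.Update(16, 3, 224, 224);  // last (partial) batch: descriptor is re-set only
  return 0;
}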


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/38da7891
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/38da7891
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/38da7891

Branch: refs/heads/master
Commit: 38da7891440f912845a27b4e77985c6173e1edbe
Parents: 85dbad7
Author: wangwei <wa...@comp.nus.edu.sg>
Authored: Mon Apr 10 19:14:15 2017 +0800
Committer: Wei Wang <wa...@comp.nus.edu.sg>
Committed: Tue Apr 11 13:59:45 2017 +0800

----------------------------------------------------------------------
 src/model/layer/batchnorm.cc          |   1 +
 src/model/layer/concat.cc             |   1 +
 src/model/layer/convolution.cc        |   3 +
 src/model/layer/cudnn_activation.cc   |  38 +++++++----
 src/model/layer/cudnn_batchnorm.cc    |  35 +++++++---
 src/model/layer/cudnn_convolution.cc  |  33 +++++++---
 src/model/layer/cudnn_dropout.cc      |  13 +++-
 src/model/layer/cudnn_lrn.cc          |  34 +++++++---
 src/model/layer/cudnn_pooling.cc      |  27 ++++++--
 src/model/layer/cudnn_softmax.cc      |  21 ++++--
 src/model/layer/opencl_convolution.cc | 100 +++++++++++++++--------------
 src/model/layer/opencl_pooling.cc     |  65 ++++++++++---------
 src/model/layer/pooling.cc            |   5 ++
 src/model/layer/slice.cc              |   1 +
 14 files changed, 248 insertions(+), 129 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/38da7891/src/model/layer/batchnorm.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/batchnorm.cc b/src/model/layer/batchnorm.cc
index e07dfd9..4e74a82 100644
--- a/src/model/layer/batchnorm.cc
+++ b/src/model/layer/batchnorm.cc
@@ -70,6 +70,7 @@ const Tensor BatchNorm::Forward(int flag, const Tensor& input) {
   x.Reshape(Shape{input.shape(0), input.Size() / input.shape(0)});
   Tensor output, mean, var, xnorm;
   output.ResetLike(x);
+  // TODO(wangwei) input sample shape check
 
   if ((flag & kTrain) == kTrain) {  // forward for train
     if (is_2d_) {                   // batchnorm_per_activation mode

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/38da7891/src/model/layer/concat.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/concat.cc b/src/model/layer/concat.cc
index a94b68e..6e071ea 100644
--- a/src/model/layer/concat.cc
+++ b/src/model/layer/concat.cc
@@ -54,6 +54,7 @@ void Concat::Setup(const vector<Shape>& in_shapes, const LayerConf& conf) {
 }
 
 const vector<Tensor> Concat::Forward(int flag, const vector<Tensor>& inputs) {
+  // TODO(wangwei) check the inputs shape to be the same for all iterations
   vector<Tensor> outputs;
   slice_point_.clear();
   size_t offset = 0;

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/38da7891/src/model/layer/convolution.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/convolution.cc b/src/model/layer/convolution.cc
index 8940fb2..3fc7afb 100644
--- a/src/model/layer/convolution.cc
+++ b/src/model/layer/convolution.cc
@@ -111,6 +111,9 @@ const Tensor Convolution::Forward(int flag, const Tensor &input) {
   if (flag & kTrain) buf_.push(input);
   size_t batchsize = input.shape(0);
   size_t imagesize = input.Size() / batchsize;
+  // TODO(wangwei) update the layer config if the input sample shape changes
+  CHECK(input.shape(1) == channels_ && input.shape(2) == height_ &&
+      input.shape(3) == width_) << "input sample shape should not change";
   DataType dtype = input.data_type();
   auto dev = input.device();
   Shape shape{batchsize, num_filters_, conv_height_, conv_width_};

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/38da7891/src/model/layer/cudnn_activation.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/cudnn_activation.cc b/src/model/layer/cudnn_activation.cc
index 756b625..ff520b8 100644
--- a/src/model/layer/cudnn_activation.cc
+++ b/src/model/layer/cudnn_activation.cc
@@ -35,24 +35,26 @@ CudnnActivation::~CudnnActivation() {
 }
 
 void CudnnActivation::InitCudnn(size_t size, DataType dtype) {
-  CHECK(!has_init_cudnn_);
-  CUDNN_CHECK(cudnnCreateTensorDescriptor(&desc_));
-  CUDNN_CHECK(cudnnCreateActivationDescriptor(&acti_desc_));
+  if (!has_init_cudnn_) {
+    CUDNN_CHECK(cudnnCreateTensorDescriptor(&desc_));
+    CUDNN_CHECK(cudnnCreateActivationDescriptor(&acti_desc_));
+
+    if (mode_ == "sigmoid")
+      cudnn_mode_ = CUDNN_ACTIVATION_SIGMOID;
+    else if (mode_ == "tanh")
+      cudnn_mode_ = CUDNN_ACTIVATION_TANH;
+    else if (mode_ == "relu")
+      cudnn_mode_ = CUDNN_ACTIVATION_RELU;
+    else
+      LOG(FATAL) << "Unkown activation: " << mode_;
+
+    CUDNN_CHECK(cudnnSetActivationDescriptor(
+          acti_desc_, cudnn_mode_, CUDNN_PROPAGATE_NAN, 0.0f));
+  }
 
   CUDNN_CHECK(cudnnSetTensor4dDescriptor(
       desc_, CUDNN_TENSOR_NCHW, GetCudnnDataType(dtype), 1, 1, 1, size));
 
-  if (mode_ == "sigmoid")
-    cudnn_mode_ = CUDNN_ACTIVATION_SIGMOID;
-  else if (mode_ == "tanh")
-    cudnn_mode_ = CUDNN_ACTIVATION_TANH;
-  else if (mode_ == "relu")
-    cudnn_mode_ = CUDNN_ACTIVATION_RELU;
-  else
-    LOG(FATAL) << "Unkown activation: " << mode_;
-
-  CUDNN_CHECK(cudnnSetActivationDescriptor(
-        acti_desc_, cudnn_mode_, CUDNN_PROPAGATE_NAN, 0.0f));
   has_init_cudnn_ = true;
 }
 
@@ -62,7 +64,15 @@ const Tensor CudnnActivation::Forward(int flag, const Tensor& input) {
   DataType dtype = input.data_type();
   if (!has_init_cudnn_) {
     InitCudnn(size, dtype);
+  } else {
+    int n, c, h, w, s;
+    cudnnDataType_t type;
+    CUDNN_CHECK(cudnnGetTensor4dDescriptor(desc_,
+          &type, &n, &c, &h, &w, &s, &s, &s, &s));
+    if (size != static_cast<size_t>(w))
+      InitCudnn(size, dtype);
   }
+
   Tensor output;
   output.ResetLike(input);
   output.device()->Exec([input, output, this](Context* ctx) {

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/38da7891/src/model/layer/cudnn_batchnorm.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/cudnn_batchnorm.cc b/src/model/layer/cudnn_batchnorm.cc
index 19a2ccb..1dbb05b 100644
--- a/src/model/layer/cudnn_batchnorm.cc
+++ b/src/model/layer/cudnn_batchnorm.cc
@@ -44,13 +44,14 @@ void CudnnBatchNorm::Setup(const Shape& in_sample, const LayerConf& conf) {
 }
 
 void CudnnBatchNorm::InitCudnn(const Shape& shape, DataType dtype) {
-  CHECK(!has_init_cudnn_);
-  if (is_2d_)
-    mode_ = CUDNN_BATCHNORM_PER_ACTIVATION;
-  else
-    mode_ = CUDNN_BATCHNORM_SPATIAL;
-  CUDNN_CHECK(cudnnCreateTensorDescriptor(&shape_desc_));
-  CUDNN_CHECK(cudnnCreateTensorDescriptor(&param_desc_));
+  if (!has_init_cudnn_) {
+    if (is_2d_)
+      mode_ = CUDNN_BATCHNORM_PER_ACTIVATION;
+    else
+      mode_ = CUDNN_BATCHNORM_SPATIAL;
+    CUDNN_CHECK(cudnnCreateTensorDescriptor(&shape_desc_));
+    CUDNN_CHECK(cudnnCreateTensorDescriptor(&param_desc_));
+  }
   CHECK_EQ(shape.size(), 4u);
   CUDNN_CHECK(cudnnSetTensor4dDescriptor(shape_desc_, CUDNN_TENSOR_NCHW,
                                          GetCudnnDataType(dtype), shape[0],
@@ -70,7 +71,25 @@ const Tensor CudnnBatchNorm::Forward(int flag, const Tensor& input) {
   else
     x = input;
   shape = x.shape();
-  if (!has_init_cudnn_) InitCudnn(shape, dtype);
+  if (!has_init_cudnn_) {
+    InitCudnn(shape, dtype);
+  } else {
+    int n, c, h, w, s;
+    cudnnDataType_t type;
+    CUDNN_CHECK(cudnnGetTensor4dDescriptor(shape_desc_, &type,
+          &n, &c, &h, &w, &s, &s, &s, &s));
+    if (shape[0] != static_cast<size_t>(n))
+      InitCudnn(shape, dtype);
+    CHECK(input.shape(1) == static_cast<size_t>(c)
+        && input.shape(2) == static_cast<size_t>(h)
+        && input.shape(3) == static_cast<size_t>(w))
+      << "input sample shape should not change"
+      << "previous shape " << c << ", " << h << ", " << w
+      << "current shape " << input.shape(1) << ", " << input.shape(2) << ", "
+      << input.shape(3);
+  }
+
+
   // TODO(wangji): check device id of input and params
   output.ResetLike(x);
   if ((flag & kTrain) == kTrain) {

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/38da7891/src/model/layer/cudnn_convolution.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/cudnn_convolution.cc b/src/model/layer/cudnn_convolution.cc
index 03ad8b9..cc78499 100644
--- a/src/model/layer/cudnn_convolution.cc
+++ b/src/model/layer/cudnn_convolution.cc
@@ -53,17 +53,18 @@ void CudnnConvolution::ToDevice(std::shared_ptr<Device> device) {
 }
 
 void CudnnConvolution::InitCudnn(const Tensor &input) {
-  CHECK(!has_init_cudnn_);
   DataType dtype = input.data_type();
   auto dev = input.device();
   Context *ctx = dev->context(0);
   size_t batchsize = input.shape(0);
-  CUDNN_CHECK(cudnnCreateTensorDescriptor(&x_desc_));
-  CUDNN_CHECK(cudnnCreateTensorDescriptor(&y_desc_));
-  if (bias_term_)
-    CUDNN_CHECK(cudnnCreateTensorDescriptor(&bias_desc_));
-  CUDNN_CHECK(cudnnCreateFilterDescriptor(&filter_desc_));
-  CUDNN_CHECK(cudnnCreateConvolutionDescriptor(&conv_desc_));
+  if (!has_init_cudnn_) {
+    CUDNN_CHECK(cudnnCreateTensorDescriptor(&x_desc_));
+    CUDNN_CHECK(cudnnCreateTensorDescriptor(&y_desc_));
+    if (bias_term_)
+      CUDNN_CHECK(cudnnCreateTensorDescriptor(&bias_desc_));
+    CUDNN_CHECK(cudnnCreateFilterDescriptor(&filter_desc_));
+    CUDNN_CHECK(cudnnCreateConvolutionDescriptor(&conv_desc_));
+  }
 
   CUDNN_CHECK(cudnnSetTensor4dDescriptor(x_desc_, CUDNN_TENSOR_NCHW,
                                          GetCudnnDataType(dtype), batchsize,
@@ -170,7 +171,23 @@ const Tensor CudnnConvolution::Forward(int flag, const Tensor &input) {
   DataType dtype = input.data_type();
   auto dev = input.device();
 
-  if (!has_init_cudnn_) InitCudnn(input);
+  if (!has_init_cudnn_) {
+    InitCudnn(input);
+  } else {
+    int n, c, h, w, s;
+    cudnnDataType_t type;
+    CUDNN_CHECK(cudnnGetTensor4dDescriptor(x_desc_, &type, &n, &c, &h, &w,
+          &s, &s, &s, &s));
+    if (batchsize != static_cast<size_t>(n))
+      InitCudnn(input);
+    CHECK(input.shape(1) == static_cast<size_t>(c)
+        && input.shape(2) == static_cast<size_t>(h)
+        && input.shape(3) == static_cast<size_t>(w))
+      << "input sample shape should not change"
+      << "previous shape " << c << ", " << h << ", " << w
+      << "current shape " << input.shape(1) << ", " << input.shape(2) << ", "
+      << input.shape(3);
+  }
 
   Shape shape{batchsize, num_filters_, conv_height_, conv_width_};
   Tensor output(shape, dev, dtype);

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/38da7891/src/model/layer/cudnn_dropout.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/cudnn_dropout.cc b/src/model/layer/cudnn_dropout.cc
index e05a425..65d7b42 100644
--- a/src/model/layer/cudnn_dropout.cc
+++ b/src/model/layer/cudnn_dropout.cc
@@ -69,8 +69,17 @@ const Tensor CudnnDropout::Forward(int flag, const Tensor& input) {
     auto dev = input.device();
     if (!has_init_cudnn_) {
       input.device()->Exec([size, dtype, this, dev](Context* ctx) {
-        this->InitCudnn(size, dtype, dev, ctx);
-      }, {}, {this->state_.block()});
+          this->InitCudnn(size, dtype, dev, ctx);
+          }, {}, {this->state_.block()});
+    } else {
+      int n, c, h, w, s;
+      cudnnDataType_t type;
+      CUDNN_CHECK(cudnnGetTensor4dDescriptor(x_desc_, &type,
+            &n, &c, &h, &w, &s, &s, &s, &s));
+      if (size != static_cast<size_t>(w))
+        input.device()->Exec([size, dtype, this, dev](Context* ctx) {
+            this->InitCudnn(size, dtype, dev, ctx);
+            }, {}, {this->state_.block()});
     }
     Tensor output;
     output.ResetLike(input);

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/38da7891/src/model/layer/cudnn_lrn.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/cudnn_lrn.cc b/src/model/layer/cudnn_lrn.cc
index ac7645e..1a08526 100644
--- a/src/model/layer/cudnn_lrn.cc
+++ b/src/model/layer/cudnn_lrn.cc
@@ -31,21 +31,39 @@ CudnnLRN::~CudnnLRN() {
   }
 }
 void CudnnLRN::InitCudnn(const Shape& shape, DataType dtype) {
-  CHECK(!has_init_cudnn_);
-  mode_ = CUDNN_LRN_CROSS_CHANNEL_DIM1;
-  CUDNN_CHECK(cudnnCreateTensorDescriptor(&shape_desc_));
   CHECK_EQ(shape.size(), 4u);
+  if (!has_init_cudnn_) {
+    mode_ = CUDNN_LRN_CROSS_CHANNEL_DIM1;
+    CUDNN_CHECK(cudnnCreateTensorDescriptor(&shape_desc_));
+    CUDNN_CHECK(cudnnCreateLRNDescriptor(&lrn_desc_));
+    CUDNN_CHECK(cudnnSetLRNDescriptor(lrn_desc_, local_size_, alpha_, beta_, k_));
+  }
   CUDNN_CHECK(cudnnSetTensor4dDescriptor(shape_desc_, CUDNN_TENSOR_NCHW,
-                                         GetCudnnDataType(dtype), shape[0],
-                                         shape[1], shape[2], shape[3]));
-  CUDNN_CHECK(cudnnCreateLRNDescriptor(&lrn_desc_));
-  CUDNN_CHECK(cudnnSetLRNDescriptor(lrn_desc_, local_size_, alpha_, beta_, k_));
+        GetCudnnDataType(dtype), shape[0],
+        shape[1], shape[2], shape[3]));
   has_init_cudnn_ = true;
 }
 const Tensor CudnnLRN::Forward(int flag, const Tensor& input) {
   auto shape = input.shape();
   auto dtype = input.data_type();
-  if (!has_init_cudnn_) InitCudnn(shape, dtype);
+  if (!has_init_cudnn_) {
+    InitCudnn(shape, dtype);
+  } else {
+    int n, c, h, w, s;
+    cudnnDataType_t type;
+    CUDNN_CHECK(cudnnGetTensor4dDescriptor(shape_desc_, &type,
+          &n, &c, &h, &w, &s, &s, &s, &s));
+    if (shape[0] != static_cast<size_t>(n))
+      InitCudnn(shape, dtype);
+    CHECK(input.shape(1) == static_cast<size_t>(c)
+        && input.shape(2) == static_cast<size_t>(h)
+        && input.shape(3) == static_cast<size_t>(w))
+      << "input sample shape should not change"
+      << "previous shape " << c << ", " << h << ", " << w
+      << "current shape " << input.shape(1) << ", " << input.shape(2) << ", "
+      << input.shape(3);
+  }
+
   Tensor output;
   output.ResetLike(input);
   output.device()->Exec([=](Context* ctx) {

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/38da7891/src/model/layer/cudnn_pooling.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/cudnn_pooling.cc b/src/model/layer/cudnn_pooling.cc
index d5b1aa3..364242e 100644
--- a/src/model/layer/cudnn_pooling.cc
+++ b/src/model/layer/cudnn_pooling.cc
@@ -43,12 +43,13 @@ void CudnnPooling::Setup(const Shape& in_sample, const LayerConf &conf) {
 }
 
 void CudnnPooling::InitCudnn(const Tensor &input) {
-  CHECK(!has_init_cudnn_);
   DataType dtype = input.data_type();
   size_t batchsize = input.shape(0);
-  CUDNN_CHECK(cudnnCreateTensorDescriptor(&x_desc_));
-  CUDNN_CHECK(cudnnCreateTensorDescriptor(&y_desc_));
-  CUDNN_CHECK(cudnnCreatePoolingDescriptor(&pool_desc_));
+  if (!has_init_cudnn_) {
+    CUDNN_CHECK(cudnnCreateTensorDescriptor(&x_desc_));
+    CUDNN_CHECK(cudnnCreateTensorDescriptor(&y_desc_));
+    CUDNN_CHECK(cudnnCreatePoolingDescriptor(&pool_desc_));
+  }
 
   CUDNN_CHECK(cudnnSetTensor4dDescriptor(x_desc_, CUDNN_TENSOR_NCHW,
                                          GetCudnnDataType(dtype), batchsize,
@@ -85,7 +86,23 @@ const Tensor CudnnPooling::Forward(int flag, const Tensor &input) {
   size_t batchsize = input.shape(0);
   DataType dtype = input.data_type();
   auto dev = input.device();
-  if (!has_init_cudnn_) InitCudnn(input);
+  if (!has_init_cudnn_) {
+    InitCudnn(input);
+  } else {
+    int n, c, h, w, s;
+    cudnnDataType_t type;
+    CUDNN_CHECK(cudnnGetTensor4dDescriptor(x_desc_, &type, &n, &c, &h, &w,
+          &s, &s, &s, &s));
+    if (batchsize != static_cast<size_t>(n))
+      InitCudnn(input);
+    CHECK(input.shape(1) == static_cast<size_t>(c)
+        && input.shape(2) == static_cast<size_t>(h)
+        && input.shape(3) == static_cast<size_t>(w))
+      << "input sample shape should not change"
+      << "previous shape " << c << ", " << h << ", " << w
+      << "current shape " << input.shape(1) << ", " << input.shape(2) << ", "
+      << input.shape(3);
+  }
 
   Shape shape{batchsize, channels_, pooled_height_, pooled_width_};
   Tensor output = Tensor(shape, dev, dtype);

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/38da7891/src/model/layer/cudnn_softmax.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/cudnn_softmax.cc b/src/model/layer/cudnn_softmax.cc
index f1a4a5b..0824611 100644
--- a/src/model/layer/cudnn_softmax.cc
+++ b/src/model/layer/cudnn_softmax.cc
@@ -43,12 +43,9 @@ void CudnnSoftmax::Setup(const Shape& in_sample, const LayerConf &conf) {
 }
 
 void CudnnSoftmax::InitCudnn(Shape shape, DataType dtype) {
-  CHECK(!has_init_cudnn_);
-  CUDNN_CHECK(cudnnCreateTensorDescriptor(&desc_));
+  if (!has_init_cudnn_)
+    CUDNN_CHECK(cudnnCreateTensorDescriptor(&desc_));
 
-  CHECK_LE(shape.size(), 2u)
-    << "Tensor shape should range from 1 to 2D;"
-    << "otherwise, add flatten layer to transform";
   if (shape.size() == 1u)
     CUDNN_CHECK(cudnnSetTensor4dDescriptor( desc_,
       CUDNN_TENSOR_NCHW, GetCudnnDataType(dtype), 1, shape[0], 1, 1));
@@ -61,10 +58,24 @@ void CudnnSoftmax::InitCudnn(Shape shape, DataType dtype) {
 const Tensor CudnnSoftmax::Forward(int flag, const Tensor& input) {
   CHECK(buf_.empty());
   auto shape = input.shape();
+  CHECK_LE(shape.size(), 2u)
+    << "Tensor shape should range from 1 to 2D;"
+    << "otherwise, add flatten layer to transform";
   DataType dtype = input.data_type();
   if (!has_init_cudnn_) {
     InitCudnn(shape, dtype);
+  } else {
+    int n, c, h, w, s;
+    cudnnDataType_t type;
+    CUDNN_CHECK(cudnnGetTensor4dDescriptor(desc_, &type, &n, &c, &h, &w,
+          &s, &s, &s, &s));
+    if ((shape.size() == 1u && shape[0] != static_cast<size_t>(c)) ||
+        (shape.size() == 2u &&
+         (shape[0] != static_cast<size_t>(n)
+          || shape[1] != static_cast<size_t>(c))))
+      InitCudnn(shape, dtype);
   }
+
   Tensor output;
   output.ResetLike(input);
   output.device()->Exec([input, output, this](Context* ctx) {

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/38da7891/src/model/layer/opencl_convolution.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/opencl_convolution.cc b/src/model/layer/opencl_convolution.cc
index eb37236..063c4c3 100644
--- a/src/model/layer/opencl_convolution.cc
+++ b/src/model/layer/opencl_convolution.cc
@@ -15,7 +15,7 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
- 
+
 #include "opencl_convolution.h"
 
 #ifdef USE_OPENCL
@@ -29,26 +29,30 @@ const Tensor OpenclConvolution::Forward(int flag, const Tensor &input) {
   CHECK(buf_.empty());
   CHECK_EQ(input.device()->lang(), kOpencl);
   CHECK_EQ(input.nDim(), 4u);
-  
+
   if (flag & kTrain) buf_.push(input);
-  
+
   auto batchsize = input.shape(0);
   auto imagesize = input.Size() / batchsize;
   auto data_type = input.data_type();
   auto device = input.device();
-  
+
+   // TODO(wangwei) update the layer config if the input sample shape changes
+  CHECK(input.shape(1) == channels_ && input.shape(2) == height_ &&
+      input.shape(3) == width_) << "input sample shape should not change";
+
   Shape shape{batchsize, num_filters_, conv_height_, conv_width_};
   Tensor output(shape, device, data_type);
   Tensor col_data(Shape{col_height_, col_width_}, device, data_type);
-  
+
   for (size_t b = 0; b < batchsize; b++) {
     int offset = b * imagesize;
-    
+
     col_data.device()->Exec([input, offset, col_data, this](Context* ctx) mutable {
 
-      this->Im2Col(input.block(), offset, 
+      this->Im2Col(input.block(), offset,
                    height_, width_,
-                   kernel_h_, kernel_w_, 
+                   kernel_h_, kernel_w_,
                    pad_h_, pad_w_,
                    stride_h_, stride_w_,
                    conv_height_, conv_width_,
@@ -57,16 +61,16 @@ const Tensor OpenclConvolution::Forward(int flag, const Tensor &input) {
     },
     {input.block()},
     {col_data.block()});
-    
+
     Tensor each = Mult(weight_, col_data);
 
     if (bias_term_) {
       AddColumn(bias_, &each);
     }
-    
+
     CopyDataToFrom(&output, each, each.Size(), b * each.Size());
   }
-  
+
   return output;
 }
 
@@ -77,46 +81,46 @@ OpenclConvolution::Backward(int flag, const Tensor &grad) {
   CHECK(!buf_.empty());
   CHECK_EQ(grad.device()->lang(), kOpencl);
   CHECK_EQ(grad.nDim(), 4u);
-  
+
   std::vector<Tensor> param_grad;
-  
+
   Tensor src_data = buf_.top();
   buf_.pop();
-  
+
   Tensor dx, db, dw;
   dx.ResetLike(src_data);
   db.ResetLike(bias_);
   dw.ResetLike(weight_);
   dw.SetValue(0.0f);
-  
+
   size_t batchsize = grad.shape(0);
   size_t imagesize = src_data.Size() / batchsize;
-  
+
   if (bias_term_) {
     auto tmpshp = Shape{batchsize * num_filters_, grad.Size() / (batchsize * num_filters_)};
     Tensor tmp1 = Reshape(grad, tmpshp);
 
-    Tensor tmp2(Shape{batchsize * num_filters_}, 
+    Tensor tmp2(Shape{batchsize * num_filters_},
                 grad.device(), grad.data_type());
     SumColumns(tmp1, &tmp2);
     Tensor tmp3 = Reshape(tmp2, Shape{batchsize, num_filters_});
 
     SumRows(tmp3, &db);
   }
-  
-  Tensor col_data(Shape{col_height_, col_width_}, 
+
+  Tensor col_data(Shape{col_height_, col_width_},
                   grad.device(), grad.data_type());
-  
+
   for (size_t b = 0; b < batchsize; b++) {
-  
+
     int im_offset = b * imagesize;
     int col_offset = 0; // Always keep this to zero.
-    
+
     col_data.device()->Exec([src_data, col_data, im_offset, col_offset, this](Context* ctx) mutable {
-      
-      this->Im2Col(src_data.block(), im_offset, 
+
+      this->Im2Col(src_data.block(), im_offset,
                    height_, width_,
-                   kernel_h_, kernel_w_, 
+                   kernel_h_, kernel_w_,
                    pad_h_, pad_w_,
                    stride_h_, stride_w_,
                    conv_height_, conv_width_,
@@ -125,19 +129,19 @@ OpenclConvolution::Backward(int flag, const Tensor &grad) {
     },
     {src_data.block()},
     {col_data.block()});
-    
-    Tensor grad_b(Shape{num_filters_, conv_height_ * conv_width_}, 
+
+    Tensor grad_b(Shape{num_filters_, conv_height_ * conv_width_},
                   grad.device(), grad.data_type());
     CopyDataToFrom(&grad_b, grad, grad_b.Size(), 0, b * grad_b.Size());
-    
+
     dw += Mult(grad_b, col_data.T());
     Tensor dcol_b = Mult(weight_.T(), grad_b);
-                         
+
     dx.device()->Exec([dcol_b, dx, im_offset, col_offset, this](Context* ctx) mutable {
-      
-      this->Col2Im(dcol_b.block(), col_offset, 
+
+      this->Col2Im(dcol_b.block(), col_offset,
                    height_, width_,
-                   kernel_h_, kernel_w_, 
+                   kernel_h_, kernel_w_,
                    pad_h_, pad_w_,
                    stride_h_, stride_w_,
                    conv_height_, conv_width_,
@@ -147,10 +151,10 @@ OpenclConvolution::Backward(int flag, const Tensor &grad) {
     {dcol_b.block()},
     {dx.block()});
   }
-  
+
   param_grad.push_back(dw);
   param_grad.push_back(db);
-  
+
   return std::make_pair(dx, param_grad);
 }
 
@@ -164,14 +168,14 @@ void OpenclConvolution::ToDevice(std::shared_ptr<Device> device) {
   Convolution::ToDevice(device);
 }
 
-                           
-void OpenclConvolution::Im2Col(Block* src, int data_im_off, 
+
+void OpenclConvolution::Im2Col(Block* src, int data_im_off,
                                const int height, const int width,
                                const int kernel_h, const int kernel_w,
                                const int pad_h, const int pad_w,
                                const int stride_h, const int stride_w,
                                const int conv_h, const int conv_w,
-                               const int col_data_off, const int channels, 
+                               const int col_data_off, const int channels,
                                Block* dst, Context* ctx) {
 
   auto ocl_ctx = viennacl::ocl::get_context(ctx->vcl_ctx_id);
@@ -179,36 +183,36 @@ void OpenclConvolution::Im2Col(Block* src, int data_im_off,
 
   auto src_buf = WrapHandle(static_cast<cl_mem>(src->mutable_data()), ocl_ctx);
   auto dst_buf = WrapHandle(static_cast<cl_mem>(dst->mutable_data()), ocl_ctx);
-  
+
   int num_kernels = channels * conv_h * conv_w;
-  
+
   viennacl::ocl::enqueue(kernel(num_kernels, src_buf, data_im_off,
-                                height, width, kernel_h, kernel_w, 
+                                height, width, kernel_h, kernel_w,
                                 pad_h, pad_w, stride_h, stride_w,
                                 1, 1, conv_h, conv_w,
                                 dst_buf, col_data_off));
 }
 
-  
-void OpenclConvolution::Col2Im(Block* src, const int col_data_off, 
+
+void OpenclConvolution::Col2Im(Block* src, const int col_data_off,
                                const int height, const int width,
                                const int kernel_h, const int kernel_w,
                                const int pad_h, const int pad_w,
                                const int stride_h, const int stride_w,
                                const int conv_h, const int conv_w,
-                               const int data_im_off, const int channels, 
+                               const int data_im_off, const int channels,
                                Block* dst, Context* ctx) {
-                               
+
   auto ocl_ctx = viennacl::ocl::get_context(ctx->vcl_ctx_id);
   auto kernel = ocl_ctx.get_kernel("opencl_im2col", "col2im");
-  
+
   auto src_buf = WrapHandle(static_cast<cl_mem>(src->mutable_data()), ocl_ctx);
   auto dst_buf = WrapHandle(static_cast<cl_mem>(dst->mutable_data()), ocl_ctx);
-  
+
   int num_kernels = channels * height * width;
-  
+
   viennacl::ocl::enqueue(kernel(num_kernels, src_buf, col_data_off, channels,
-                                height, width, kernel_h, kernel_w, 
+                                height, width, kernel_h, kernel_w,
                                 pad_h, pad_w, stride_h, stride_w,
                                 1, 1, conv_h, conv_w,
                                 dst_buf, data_im_off));

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/38da7891/src/model/layer/opencl_pooling.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/opencl_pooling.cc b/src/model/layer/opencl_pooling.cc
index 155f2bb..4e31289 100644
--- a/src/model/layer/opencl_pooling.cc
+++ b/src/model/layer/opencl_pooling.cc
@@ -15,7 +15,7 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
- 
+
 #include "opencl_pooling.h"
 
 #ifdef USE_OPENCL
@@ -28,14 +28,17 @@ const Tensor OpenclPooling::Forward(int flag, const Tensor &input) {
   CHECK(buf_.empty());
   CHECK_EQ(input.device()->lang(), kOpencl);
   CHECK_EQ(input.nDim(), 4u);
-  
+
   auto batchsize = input.shape(0);
   auto data_type = input.data_type();
   auto device = input.device();
+  // TODO(wangwei) update the layer config if the input sample shape changes
+  CHECK(input.shape(1) == channels_ && input.shape(2) == height_ &&
+      input.shape(3) == width_) << "input sample shape should not change";
 
   Shape shape{batchsize, channels_, pooled_height_, pooled_width_};
   Tensor output = Tensor(shape, device, data_type);
-  
+
   output.device()->Exec([input, output, flag, this](Context *ctx) {
     Block* in_block = input.block();
     Block* outblock = output.block();
@@ -43,18 +46,18 @@ const Tensor OpenclPooling::Forward(int flag, const Tensor &input) {
     if (pool_ == PoolingConf_PoolMethod_MAX) {
       Tensor mask;
       mask.ResetLike(output);
-      
-      Pooling_Forward_Max((int)output.Size(), in_block, mask.block(), 
+
+      Pooling_Forward_Max((int)output.Size(), in_block, mask.block(),
                           height_, width_,
                           pooled_height_, pooled_width_,
                           kernel_h_, kernel_w_,
                           stride_h_, stride_w_,
                           pad_h_, pad_w_,
                           outblock, channels_, ctx);
-      
+
       if (flag & kTrain)
         buf_.push(mask);
-      
+
     } else if (pool_ == PoolingConf_PoolMethod_AVE) {
       Pooling_Forward_Ave((int)output.Size(), in_block, outblock,
                           height_, width_, pooled_height_, pooled_width_,
@@ -62,9 +65,9 @@ const Tensor OpenclPooling::Forward(int flag, const Tensor &input) {
                           pad_h_, pad_w_, channels_, ctx);
     } else
       LOG(FATAL) << "Unknown pooling method.";
-    
+
   }, {input.block()}, {output.block()});
-  
+
   return output;
 }
 
@@ -73,14 +76,14 @@ const std::pair<Tensor, std::vector<Tensor>>
 OpenclPooling::Backward(int flag, const Tensor &grad) {
   CHECK_EQ(grad.device()->lang(), kOpencl);
   CHECK_EQ(grad.nDim(), 4u);
-  
+
   std::vector<Tensor> param_grad;
-  
+
   auto batchsize = grad.shape(0);
   auto data_type = grad.data_type();
   auto device = grad.device();
   Shape shape{batchsize, channels_, height_, width_};
-  
+
   Tensor dx(shape, device, data_type);
 
   dx.device()->Exec([dx, grad, this](Context *ctx) {
@@ -97,19 +100,19 @@ OpenclPooling::Backward(int flag, const Tensor &grad) {
                            pad_h_, pad_w_,
                            stride_h_, stride_w_,
                            dx.block(), ctx);
-                           
+
     } else if (pool_ == PoolingConf_PoolMethod_AVE) {
-      Pooling_Backward_Ave(grad.block(), grad.shape(0), channels_, 
+      Pooling_Backward_Ave(grad.block(), grad.shape(0), channels_,
                            height_, width_,
                            pooled_height_, pooled_width_,
                            kernel_h_, kernel_w_,
                            pad_h_, pad_w_,
                            stride_h_, stride_w_,
                            dx.block(), ctx);
-                           
+
     } else
       LOG(FATAL) << "Unknown pooling method.";
-    
+
   }, {grad.block()}, {dx.block()});
 
   return std::make_pair(dx, param_grad);
@@ -122,7 +125,7 @@ void OpenclPooling::Setup(const Shape& in_sample, const LayerConf &conf) {
 }
 
 
-void OpenclPooling::Pooling_Forward_Max(const int num, Block* src, Block* mask, 
+void OpenclPooling::Pooling_Forward_Max(const int num, Block* src, Block* mask,
                                         const int height, const int width,
                                         const int pooled_h, const int pooled_w,
                                         const int kernel_h, const int kernel_w,
@@ -132,7 +135,7 @@ void OpenclPooling::Pooling_Forward_Max(const int num, Block* src, Block* mask,
                                         Context* ctx) {
   auto ocl_ctx = viennacl::ocl::get_context(ctx->vcl_ctx_id);
   auto kernel = ocl_ctx.get_kernel("opencl_pooling", "max_pool_forward");
-  
+
   auto src_buf = WrapHandle(static_cast<cl_mem>(src->mutable_data()), ocl_ctx);
   auto dst_buf = WrapHandle(static_cast<cl_mem>(dst->mutable_data()), ocl_ctx);
   auto maskbuf = WrapHandle(static_cast<cl_mem>(mask->mutable_data()), ocl_ctx);
@@ -144,7 +147,7 @@ void OpenclPooling::Pooling_Forward_Max(const int num, Block* src, Block* mask,
 }
 
 
-void OpenclPooling::Pooling_Forward_Ave(const int num, Block* src, Block* dst, 
+void OpenclPooling::Pooling_Forward_Ave(const int num, Block* src, Block* dst,
                                         const int height, const int width,
                                         const int pooled_h, const int pooled_w,
                                         const int kernel_h, const int kernel_w,
@@ -153,10 +156,10 @@ void OpenclPooling::Pooling_Forward_Ave(const int num, Block* src, Block* dst,
                                         const int channels, Context* ctx) {
   auto ocl_ctx = viennacl::ocl::get_context(ctx->vcl_ctx_id);
   auto kernel = ocl_ctx.get_kernel("opencl_pooling", "ave_pool_forward");
-  
+
   auto src_buf = WrapHandle(static_cast<cl_mem>(src->mutable_data()), ocl_ctx);
   auto dst_buf = WrapHandle(static_cast<cl_mem>(dst->mutable_data()), ocl_ctx);
-                                   
+
   viennacl::ocl::enqueue(kernel(num, src_buf, channels,
                                 height, width, pooled_h, pooled_w,
                                 kernel_h, kernel_w, stride_h, stride_w,
@@ -169,11 +172,11 @@ void OpenclPooling::Pooling_Forward_Sto_Train(Block* src, Block* rand,
                                               const int pooled_h, const int pooled_w,
                                               const int kernel_h, const int kernel_w,
                                               const int stride_h, const int stride_w,
-                                              const int channels, 
+                                              const int channels,
                                               Block* dst, Context* ctx) {
   auto ocl_ctx = viennacl::ocl::get_context(ctx->vcl_ctx_id);
   auto kernel = ocl_ctx.get_kernel("opencl_pooling", "sto_pool_forward_train");
-  
+
   auto src_buf = WrapHandle(static_cast<cl_mem>(src->mutable_data()), ocl_ctx);
   auto dst_buf = WrapHandle(static_cast<cl_mem>(dst->mutable_data()), ocl_ctx);
   auto randbuf = WrapHandle(static_cast<cl_mem>(rand->mutable_data()), ocl_ctx);
@@ -185,7 +188,7 @@ void OpenclPooling::Pooling_Forward_Sto_Train(Block* src, Block* rand,
 }
 
 
-void OpenclPooling::Pooling_Forward_Sto_Test(Block* src, Block* dst, 
+void OpenclPooling::Pooling_Forward_Sto_Test(Block* src, Block* dst,
                                              const int height, const int width,
                                              const int pooled_h, const int pooled_w,
                                              const int kernel_h, const int kernel_w,
@@ -193,7 +196,7 @@ void OpenclPooling::Pooling_Forward_Sto_Test(Block* src, Block* dst,
                                              const int channels, Context* ctx) {
   auto ocl_ctx = viennacl::ocl::get_context(ctx->vcl_ctx_id);
   auto kernel = ocl_ctx.get_kernel("opencl_pooling", "sto_pool_forward_test");
-  
+
   auto src_buf = WrapHandle(static_cast<cl_mem>(src->mutable_data()), ocl_ctx);
   auto dst_buf = WrapHandle(static_cast<cl_mem>(dst->mutable_data()), ocl_ctx);
 
@@ -214,7 +217,7 @@ void OpenclPooling::Pooling_Backward_Max(Block* top, Block* mask,
                                          Block* bottom, Context* ctx) {
   auto ocl_ctx = viennacl::ocl::get_context(ctx->vcl_ctx_id);
   auto kernel = ocl_ctx.get_kernel("opencl_pooling", "max_pool_backward");
-  
+
   auto src_buf = WrapHandle(static_cast<cl_mem>(top->mutable_data()), ocl_ctx);
   auto dst_buf = WrapHandle(static_cast<cl_mem>(bottom->mutable_data()), ocl_ctx);
   auto mask_buf = WrapHandle(static_cast<cl_mem>(mask->mutable_data()), ocl_ctx);
@@ -227,7 +230,7 @@ void OpenclPooling::Pooling_Backward_Max(Block* top, Block* mask,
 
 
 void OpenclPooling::Pooling_Backward_Ave(Block* bottom,
-                                         const int num, const int channels, 
+                                         const int num, const int channels,
                                          const int height, const int width,
                                          const int pooled_h, const int pooled_w,
                                          const int kernel_h, const int kernel_w,
@@ -236,10 +239,10 @@ void OpenclPooling::Pooling_Backward_Ave(Block* bottom,
                                          Block* top, Context* ctx) {
   auto ocl_ctx = viennacl::ocl::get_context(ctx->vcl_ctx_id);
   auto kernel = ocl_ctx.get_kernel("opencl_pooling", "ave_pool_backward");
-  
+
   auto src_buf = WrapHandle(static_cast<cl_mem>(bottom->mutable_data()), ocl_ctx);
   auto dst_buf = WrapHandle(static_cast<cl_mem>(top->mutable_data()), ocl_ctx);
-                                   
+
   viennacl::ocl::enqueue(kernel(num, src_buf, channels,
                                 height, width, pooled_h, pooled_w,
                                 kernel_h, kernel_w, stride_h, stride_w,
@@ -255,11 +258,11 @@ void OpenclPooling::Pooling_Backward_Sto(Block* src, Block* rand, Block* dst,
                                          const int channels, Context* ctx) {
   auto ocl_ctx = viennacl::ocl::get_context(ctx->vcl_ctx_id);
   auto kernel = ocl_ctx.get_kernel("opencl_pooling", "sto_pool_backward");
-  
+
   auto src_buf = WrapHandle(static_cast<cl_mem>(src->mutable_data()), ocl_ctx);
   auto dst_buf = WrapHandle(static_cast<cl_mem>(dst->mutable_data()), ocl_ctx);
   auto randbuf = WrapHandle(static_cast<cl_mem>(rand->mutable_data()), ocl_ctx);
-                                   
+
   viennacl::ocl::enqueue(kernel(height * width, randbuf, src_buf, channels,
                                 height, width, pooled_h, pooled_w,
                                 kernel_h, kernel_w, stride_h, stride_w,

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/38da7891/src/model/layer/pooling.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/pooling.cc b/src/model/layer/pooling.cc
index a8f3d3d..60a58a9 100644
--- a/src/model/layer/pooling.cc
+++ b/src/model/layer/pooling.cc
@@ -89,6 +89,11 @@ const Tensor Pooling::Forward(int flag, const Tensor& input) {
   CHECK_EQ(input.nDim(), 4u);
   size_t batchsize = input.shape(0);
   DataType dtype = input.data_type();
+
+  // TODO(wangwei) update the layer config if the input sample shape changes
+  CHECK(input.shape(1) == channels_ && input.shape(2) == height_ &&
+      input.shape(3) == width_) << "input sample shape should not change";
+
   auto dev = input.device();
   Shape shape{batchsize, channels_, pooled_height_, pooled_width_};
   Tensor output(shape, dev, dtype);

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/38da7891/src/model/layer/slice.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/slice.cc b/src/model/layer/slice.cc
index 8a3e4bf..c908a1b 100644
--- a/src/model/layer/slice.cc
+++ b/src/model/layer/slice.cc
@@ -54,6 +54,7 @@ void Slice::Setup(const Shape& in_sample, const LayerConf& conf) {
 }
 
 const vector<Tensor> Slice::Forward(int flag, const vector<Tensor>& inputs) {
+  // TODO(wangwei) check the inputs shape to be the same for all iterations
   vector<Tensor> outputs;
   CHECK_EQ(inputs.size(), 1u) << "Split layer only have one input tensor.";
   size_t offset = 0;