You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@singa.apache.org by zh...@apache.org on 2016/06/13 13:20:14 UTC

[21/50] [abbrv] incubator-singa git commit: SINGA-178 Add Convolution layer and Pooling layer

SINGA-178 Add Convolution layer and Pooling layer

Minor updates to variable names and InitCudnn arguments.
Fix compiler warnings about signed/unsigned integer comparisons.
Format code. Pass all tests.


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/7d149ecf
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/7d149ecf
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/7d149ecf

Branch: refs/heads/master
Commit: 7d149ecf786f816cf2da47ea9e5bb86f8fecdd6b
Parents: 152056d
Author: Wei Wang <wa...@comp.nus.edu.sg>
Authored: Mon May 30 16:53:40 2016 +0800
Committer: Wei Wang <wa...@comp.nus.edu.sg>
Committed: Mon May 30 18:17:57 2016 +0800

----------------------------------------------------------------------
 include/singa/core/device.h          |  12 +--
 include/singa/core/tensor.h          |   2 +-
 include/singa/model/layer.h          |   7 +-
 include/singa/utils/string.h         |  81 ++++++++++++++++++
 include/singa/utils/tokenizer.h      |  65 --------------
 src/core/tensor/tensor.cc            |   8 +-
 src/model/layer/convolution.cc       |  33 ++++++--
 src/model/layer/convolution.h        |   3 +-
 src/model/layer/cudnn_convolution.cc | 135 ++++++++++++++----------------
 src/model/layer/cudnn_convolution.h  |  11 ++-
 src/model/layer/cudnn_pooling.cc     |  40 ++++-----
 src/model/layer/cudnn_pooling.h      |   2 +-
 src/model/layer/pooling.cc           |   8 +-
 src/model/layer/pooling.h            |   3 +-
 src/proto/model.proto                |  15 ++--
 test/singa/test_cudnn_convolution.cc |  50 +++++------
 test/singa/test_cudnn_pooling.cc     |  26 +++---
 17 files changed, 274 insertions(+), 227 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/7d149ecf/include/singa/core/device.h
----------------------------------------------------------------------
diff --git a/include/singa/core/device.h b/include/singa/core/device.h
index a4b3f6d..56eda70 100644
--- a/include/singa/core/device.h
+++ b/include/singa/core/device.h
@@ -77,6 +77,10 @@ class Device {
 
   Device* host() const { return host_;}
 
+  Context* context(int k) {
+    return &ctx_;
+  }
+
   int id() const { return id_; }
 
  protected:
@@ -104,6 +108,8 @@ class Device {
   // SafeQueue<Operation> op_log_;
   /// The host device
   Device* host_;
+  // TODO(wangwei) define multiple contexts, one per executor
+  Context ctx_;
 };
 
 /// Represent a CPU device which may have multiple threads/executors.
@@ -125,9 +131,6 @@ class CppCPU : public Device {
 
   /// Free cpu memory.
   void Free(void* ptr) override;
-
- protected:
-  Context ctx_;
 };
 
 /// a singleton CppDevice as the host for all devices.
@@ -177,9 +180,6 @@ class CudaGPU : public Device {
 
   /// Free cpu memory.
   void Free(void* ptr) override;
-
- protected:
-  Context ctx_;
 };
 
 /// CudaCPU which uses cudaMallocHost to allocate pinned memory for host.

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/7d149ecf/include/singa/core/tensor.h
----------------------------------------------------------------------
diff --git a/include/singa/core/tensor.h b/include/singa/core/tensor.h
index f51c899..8682bca 100644
--- a/include/singa/core/tensor.h
+++ b/include/singa/core/tensor.h
@@ -97,7 +97,7 @@ public:
     return shape_.at(idx);
   }
 
-  int nDim() const { return shape_.size(); }
+  size_t nDim() const { return shape_.size(); }
 
   bool transpose() const { return transpose_; }
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/7d149ecf/include/singa/model/layer.h
----------------------------------------------------------------------
diff --git a/include/singa/model/layer.h b/include/singa/model/layer.h
index c6a3bd1..82c8edc 100644
--- a/include/singa/model/layer.h
+++ b/include/singa/model/layer.h
@@ -44,7 +44,7 @@ class Layer {
 
   // ============= Following Functions could be override =====================
   /// Destruct objects created by this layer.
-  virtual ~Layer() {}; 
+  virtual ~Layer() {};
 
   /// Each layer sub-class would optionaly have a type name.
   /// Used for debugging and logging.
@@ -160,7 +160,10 @@ class Layer {
   const vector<ParamSpec> param_specs() { return param_specs_; }
 
   /// Return the i-th ParamSpec.
-  const ParamSpec& param_specs(int i) { return param_specs_.at(i); }
+  const ParamSpec& param_specs(size_t i) {
+    CHECK_LT(i, param_specs_.size());
+    return param_specs_.at(i);
+  }
 
   /// Return pointers to parameter Tensor s.
   const vector<Tensor*> param_values() { return param_values_; }

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/7d149ecf/include/singa/utils/string.h
----------------------------------------------------------------------
diff --git a/include/singa/utils/string.h b/include/singa/utils/string.h
new file mode 100644
index 0000000..b739afc
--- /dev/null
+++ b/include/singa/utils/string.h
@@ -0,0 +1,81 @@
+/************************************************************
+*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements.  See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership.  The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License.  You may obtain a copy of the License at
+*
+*   http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied.  See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*
+*************************************************************/
+
+#ifndef SINGA_UTILS_STRING_H_
+#define SINGA_UTILS_STRING_H_
+
+#include <algorithm>
+#include <cctype>
+#include <string>
+#include "singa/utils/logging.h"
+
+namespace singa {
+/// Case-insensitive (C locale) equality test of two strings.
+inline bool icasecmp(const std::string& l, const std::string& r) {
+  // toupper() on a negative char is UB, so compare as unsigned char.
+  return l.size() == r.size() &&
+         std::equal(l.cbegin(), l.cend(), r.cbegin(),
+                    [](unsigned char l1, unsigned char r1) {
+                      return std::toupper(l1) == std::toupper(r1);
+                    });
+}
+
+/// Return a copy of `input` with every character lower-cased.
+inline std::string ToLowerCase(const std::string& input) {
+  std::string out(input);
+  std::transform(out.begin(), out.end(), out.begin(), [](unsigned char c) {
+    return static_cast<char>(std::tolower(c));
+  });
+  return out;
+}
+
+/**
+ * Split a string into tokens delimited by any character in `sep`.
+ *
+ * Example:
+ *   Tokenizer t("assa,asf;wes", ",;");
+ *   std::string x;
+ *   while (t.Valid()) t >> x;  // x: "assa", then "asf", then "wes"
+ * Calling >> once Valid() is false fails the CHECK.
+ */
+class Tokenizer {
+ public:
+  // `str` is copied, so passing a temporary (e.g. a literal) is safe.
+  Tokenizer(const std::string& str, const std::string& sep)
+      : start_(0), sep_(sep), buf_(str) {}
+  /// Store the next token in `out` and skip past its separator.
+  Tokenizer & operator>>(std::string& out) {
+    CHECK_LT(start_, buf_.length());
+    size_t pos = buf_.find_first_of(sep_, start_);
+    if (pos == std::string::npos) pos = buf_.length();
+    out = buf_.substr(start_, pos - start_);  // 2nd arg is a length.
+    start_ = pos + 1;
+    return *this;
+  }
+  bool Valid() const { return start_ < buf_.length(); }
+
+ private:
+  size_t start_;
+  std::string sep_;
+  std::string buf_;  // owned copy: a reference member could dangle.
+};
+}  // namespace singa
+#endif  // SINGA_UTILS_STRING_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/7d149ecf/include/singa/utils/tokenizer.h
----------------------------------------------------------------------
diff --git a/include/singa/utils/tokenizer.h b/include/singa/utils/tokenizer.h
deleted file mode 100644
index 92c24b6..0000000
--- a/include/singa/utils/tokenizer.h
+++ /dev/null
@@ -1,65 +0,0 @@
-/************************************************************
-*
-* Licensed to the Apache Software Foundation (ASF) under one
-* or more contributor license agreements.  See the NOTICE file
-* distributed with this work for additional information
-* regarding copyright ownership.  The ASF licenses this file
-* to you under the Apache License, Version 2.0 (the
-* "License"); you may not use this file except in compliance
-* with the License.  You may obtain a copy of the License at
-*
-*   http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing,
-* software distributed under the License is distributed on an
-* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-* KIND, either express or implied.  See the License for the
-* specific language governing permissions and limitations
-* under the License.
-*
-*************************************************************/
-
-#ifndef SINGA_UTILS_TOKENIZER_H_
-#define SINGA_UTILS_TOKENIZER_H_
-
-#include <string>
-#include "singa/utils/logging.h"
-
-namespace singa {
-/**
- * Tokenize a string.
- *
- * example:
- * Tokenizer t("assa,asf;wes", ",;");
- * string x;
- * t >> x; // x is assa
- * t >> x; // x is asf
- * t >> x; // x is wes
- * cout << (t >> x); // print 0.
- */
-
-class Tokenizer {
- public:
-  Tokenizer(const std::string& str, const std::string& sep): start_(0),
-  sep_(sep), buf_(str) {}
-  Tokenizer & operator>>(std::string& out) {
-    CHECK_LT(start_, buf_.length());
-    int start = start_;
-    auto pos = buf_.find_first_of(sep_, start);
-    if (pos == std::string::npos)
-      pos = buf_.length();
-    start_ = pos + 1;
-    out = buf_.substr(start, pos);
-    return *this;
-  }
-  bool Valid() { return start_ < buf_.length(); }
-
- private:
-  unsigned start_;
-  std::string sep_;
-  const std::string& buf_;
-};
-
-}  // namespace singa
-
-#endif  // SINGA_UTILS_TOKENIZER_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/7d149ecf/src/core/tensor/tensor.cc
----------------------------------------------------------------------
diff --git a/src/core/tensor/tensor.cc b/src/core/tensor/tensor.cc
index 0e47a4f..fcf42c2 100644
--- a/src/core/tensor/tensor.cc
+++ b/src/core/tensor/tensor.cc
@@ -562,8 +562,8 @@ void AddColumn(const float alpha, const float beta, const Tensor &v,
     Tensor X = M->T();
     AddRow(v, &X);
   } else {
-    CHECK_EQ(M->nDim(), 2);
-    CHECK_EQ(v.nDim(), 1);
+    CHECK_EQ(M->nDim(), 2u);
+    CHECK_EQ(v.nDim(), 1u);
     size_t nb_row = M->shape(0), nb_col = M->shape(1);
     CHECK_EQ(nb_row, v.Size());
 
@@ -581,8 +581,8 @@ void AddRow(const float alpha, const float beta, const Tensor &v, Tensor *M) {
     Tensor X = M->T();
     AddColumn(v, &X);
   } else {
-    CHECK_EQ(M->nDim(), 2);
-    CHECK_EQ(v.nDim(), 1);
+    CHECK_EQ(M->nDim(), 2u);
+    CHECK_EQ(v.nDim(), 1u);
     size_t nb_row = M->shape(0), nb_col = M->shape(1);
     CHECK_EQ(nb_col, v.Size());
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/7d149ecf/src/model/layer/convolution.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/convolution.cc b/src/model/layer/convolution.cc
index 6406a31..50ee3c8 100644
--- a/src/model/layer/convolution.cc
+++ b/src/model/layer/convolution.cc
@@ -28,32 +28,51 @@ void Convolution::Setup(const LayerConf &conf) {
   ConvolutionConf conv_conf = conf.convolution_conf();
   // kernel_size, pad, and stride are repeated fields.
   if (conv_conf.kernel_size_size() > 0) {
-    kernel_w_ = kernel_h_ = conv_conf.kernel_size(0);
+    if (conv_conf.kernel_size_size() == 1) {
+      kernel_w_ = kernel_h_ = conv_conf.kernel_size(0);
+    } else {
+      kernel_w_ = conv_conf.kernel_size(0);
+      kernel_h_ = conv_conf.kernel_size(1);
+    }
   } else {
     kernel_w_ = conv_conf.kernel_w();
     kernel_h_ = conv_conf.kernel_h();
   }
-  CHECK_NE(kernel_w_, 0);
-  CHECK_NE(kernel_h_, 0);
+  CHECK_GT(kernel_w_, 0u);
+  CHECK_GT(kernel_h_, 0u);
 
   if (conv_conf.pad_size() > 0) {
-    pad_w_ = pad_h_ = conv_conf.pad(0);
+    if (conv_conf.pad_size() == 1) {
+      pad_w_ = pad_h_ = conv_conf.pad(0);
+    } else {
+      pad_w_ = conv_conf.pad(0);
+      pad_h_ = conv_conf.pad(1);
+    }
   } else {
     pad_w_ = conv_conf.pad_w();
     pad_h_ = conv_conf.pad_h();
   }
+  CHECK_GE(pad_w_, 0u);
+  CHECK_GE(pad_h_, 0u);
 
   if (conv_conf.stride_size() > 0) {
-    stride_w_ = stride_h_ = conv_conf.stride(0);
+    if (conv_conf.stride_size() == 1) {
+      stride_w_ = stride_h_ = conv_conf.stride(0);
+    } else {
+      stride_w_ = conv_conf.stride(0);
+      stride_h_ = conv_conf.stride(1);
+    }
   } else {
     stride_w_ = conv_conf.stride_w();
     stride_h_ = conv_conf.stride_h();
   }
+  CHECK_GT(stride_w_, 0u);
+  CHECK_GT(stride_h_, 0u);
 
   num_filters_ = conv_conf.num_output();
   bias_term_ = conv_conf.bias_term();
 
-  // Shape of src
+  // Shape of input image
   channels_ = conv_conf.channels();
   height_ = conv_conf.height();
   width_ = conv_conf.width();
@@ -68,7 +87,7 @@ void Convolution::Setup(const LayerConf &conf) {
   bias_.Reshape(Shape{num_filters_});
   // Push back params into param_values_
   // Assume the order of param is: weight, bias
-  for (const auto& spec : conf.param()) param_specs_.push_back(spec);
+  for (const auto &spec : conf.param()) param_specs_.push_back(spec);
   param_values_.push_back(&weight_);
   param_values_.push_back(&bias_);
 }

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/7d149ecf/src/model/layer/convolution.h
----------------------------------------------------------------------
diff --git a/src/model/layer/convolution.h b/src/model/layer/convolution.h
index a9bf833..477efb3 100644
--- a/src/model/layer/convolution.h
+++ b/src/model/layer/convolution.h
@@ -47,7 +47,6 @@ class Convolution : public Layer {
   size_t stride_w() const { return stride_w_; }
   size_t stride_h() const { return stride_h_; }
   size_t num_filters() const { return num_filters_; }
-  size_t batchsize() const { return batchsize_; }
   size_t channels() const { return channels_; }
   size_t height() const { return height_; }
   size_t width() const { return width_; }
@@ -67,7 +66,7 @@ class Convolution : public Layer {
  protected:
   size_t kernel_w_, pad_w_, stride_w_;
   size_t kernel_h_, pad_h_, stride_h_;
-  size_t batchsize_, channels_, height_, width_;
+  size_t channels_, height_, width_;
   size_t col_height_, col_width_, conv_height_, conv_width_, num_filters_;
   Tensor weight_, bias_;
   // store intermediate data, i.e., input tensor

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/7d149ecf/src/model/layer/cudnn_convolution.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/cudnn_convolution.cc b/src/model/layer/cudnn_convolution.cc
index ec7cd6a..922b7e0 100644
--- a/src/model/layer/cudnn_convolution.cc
+++ b/src/model/layer/cudnn_convolution.cc
@@ -39,9 +39,9 @@ void CudnnConvolution::Setup(const LayerConf &conf) {
   ConvolutionConf conv_conf = conf.convolution_conf();
   // convert MB to bytes
   workspace_byte_limit_ = conv_conf.workspace_byte_limit() << 20;
-  pref_ = conv_conf.algo_pref();
-  CHECK(pref_ == "fastest" || pref_ == "limited_workspace" ||
-        pref_ == "no_workspace")
+  prefer_ = ToLowerCase(conv_conf.prefer());
+  CHECK(prefer_ == "fastest" || prefer_ == "limited_workspace" ||
+        prefer_ == "no_workspace")
       << "CudnnConvolution only supports three algorithm preferences: fastest, "
          "limited_workspace and no_workspace";
 }
@@ -52,8 +52,12 @@ void CudnnConvolution::ToDevice(Device *device) {
   workspace_.ToDevice(device);
 }
 
-void CudnnConvolution::InitCudnn(DataType dtype, Device *dev, Context *ctx) {
+void CudnnConvolution::InitCudnn(const Tensor& input) {
   CHECK(!has_init_cudnn_);
+  DataType dtype = input.data_type();
+  Device *dev = input.device();
+  Context *ctx = dev->context(0);
+  size_t batchsize = input.shape(0);
   CUDNN_CHECK(cudnnCreateTensorDescriptor(&x_desc_));
   CUDNN_CHECK(cudnnCreateTensorDescriptor(&y_desc_));
   CUDNN_CHECK(cudnnCreateTensorDescriptor(&bias_desc_));
@@ -61,10 +65,10 @@ void CudnnConvolution::InitCudnn(DataType dtype, Device *dev, Context *ctx) {
   CUDNN_CHECK(cudnnCreateConvolutionDescriptor(&conv_desc_));
 
   CUDNN_CHECK(cudnnSetTensor4dDescriptor(x_desc_, CUDNN_TENSOR_NCHW,
-                                         GetCudnnDataType(dtype), batchsize_,
+                                         GetCudnnDataType(dtype), batchsize,
                                          channels_, height_, width_));
   CUDNN_CHECK(cudnnSetTensor4dDescriptor(
-      y_desc_, CUDNN_TENSOR_NCHW, GetCudnnDataType(dtype), batchsize_,
+      y_desc_, CUDNN_TENSOR_NCHW, GetCudnnDataType(dtype), batchsize,
       num_filters_, conv_height_, conv_width_));
   if (bias_term_)
     CUDNN_CHECK(cudnnSetTensor4dDescriptor(bias_desc_, CUDNN_TENSOR_NCHW,
@@ -88,20 +92,20 @@ void CudnnConvolution::InitCudnn(DataType dtype, Device *dev, Context *ctx) {
   cudnnConvolutionFwdPreference_t fwd_pref;
   cudnnConvolutionBwdFilterPreference_t bwd_filt_pref;
   cudnnConvolutionBwdDataPreference_t bwd_data_pref;
-  if (pref_ == "fastest") {
+  if (prefer_ == "fastest") {
     fwd_pref = CUDNN_CONVOLUTION_FWD_PREFER_FASTEST;
     bwd_filt_pref = CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST;
     bwd_data_pref = CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST;
-  } else if (pref_ == "limited_workspace") {
+  } else if (prefer_ == "limited_workspace") {
     fwd_pref = CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT;
     bwd_filt_pref = CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT;
     bwd_data_pref = CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT;
-  } else if (pref_ == "no_workspace") {
+  } else if (prefer_ == "no_workspace") {
     fwd_pref = CUDNN_CONVOLUTION_FWD_NO_WORKSPACE;
     bwd_filt_pref = CUDNN_CONVOLUTION_BWD_FILTER_NO_WORKSPACE;
     bwd_data_pref = CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT;
   } else {
-    LOG(FATAL) << "Algorithm preference is not implemented!";
+    LOG(FATAL) << "Preferred algorithm is not available!";
   }
   CUDNN_CHECK(cudnnGetConvolutionForwardAlgorithm(
       ctx->cudnn_handle, x_desc_, filter_desc_, conv_desc_, y_desc_, fwd_pref,
@@ -133,51 +137,46 @@ void CudnnConvolution::InitCudnn(DataType dtype, Device *dev, Context *ctx) {
 
 const Tensor CudnnConvolution::Forward(int flag, const Tensor &input) {
   CHECK_EQ(input.device()->lang(), kCuda);
-  CHECK_EQ(input.shape().size(), 4);
+  CHECK_EQ(input.nDim(), 4u);
   buf_.push(input);
-  batchsize_ = input.shape()[0];
+  size_t batchsize = input.shape()[0];
   DataType dtype = input.data_type();
   Device *dev = input.device();
 
-  if (!has_init_cudnn_) InitCudnn(dtype, dev, dev->context(0));
+  if (!has_init_cudnn_) InitCudnn(input);
 
-  Shape shape{batchsize_, num_filters_, conv_height_, conv_width_};
+  Shape shape{batchsize, num_filters_, conv_height_, conv_width_};
   Tensor output(shape, dev, dtype);
-  float alpha = 1.f, beta = 0.f;
-  output.device()->Exec(
-      [input, output, alpha, beta, this](Context *ctx) {
-        Blob *inblob = input.blob(), *outblob = output.blob(),
-             *wblob = this->weight_.blob();
-        cudnnConvolutionForward(ctx->cudnn_handle, &alpha, this->x_desc_,
-                                inblob->data(), this->filter_desc_,
-                                wblob->data(), this->conv_desc_, this->fp_alg_,
-                                this->workspace_.blob()->mutable_data(),
-                                this->workspace_count_ * sizeof(float), &beta,
-                                this->y_desc_, outblob->mutable_data());
-      },
-      {input.blob(), weight_.blob()}, {output.blob()}, workspace_.blob());
+  output.device()->Exec([input, output, this](Context *ctx) {
+    Blob *inblob = input.blob(), *outblob = output.blob(),
+         *wblob = this->weight_.blob();
+    float alpha = 1.f, beta = 0.f;
+    cudnnConvolutionForward(ctx->cudnn_handle, &alpha, this->x_desc_,
+                            inblob->data(), this->filter_desc_, wblob->data(),
+                            this->conv_desc_, this->fp_alg_,
+                            this->workspace_.blob()->mutable_data(),
+                            this->workspace_count_ * sizeof(float), &beta,
+                            this->y_desc_, outblob->mutable_data());
+  }, {input.blob(), weight_.blob()}, {output.blob()}, workspace_.blob());
 
   if (bias_term_) {
-    beta = 1.f;
-    output.device()->Exec(
-        [output, alpha, beta, this](Context *ctx) {
-          Blob *outblob = output.blob(), *bblob = this->bias_.blob();
-          cudnnAddTensor(ctx->cudnn_handle, &alpha, this->bias_desc_,
-                         bblob->data(), &beta, this->y_desc_,
-                         outblob->mutable_data());
-        },
-        {output.blob(), bias_.blob()}, {output.blob()});
+    output.device()->Exec([output, this](Context *ctx) {
+      float beta = 1.f, alpha = 1.0f;
+      Blob *outblob = output.blob(), *bblob = this->bias_.blob();
+      cudnnAddTensor(ctx->cudnn_handle, &alpha, this->bias_desc_, bblob->data(),
+                     &beta, this->y_desc_, outblob->mutable_data());
+    }, {output.blob(), bias_.blob()}, {output.blob()});
   }
   return output;
 }
 
 const std::pair<Tensor, vector<Tensor>> CudnnConvolution::Backward(
     int flag, const Tensor &grad) {
+  CHECK(has_init_cudnn_);
   CHECK_EQ(grad.device()->lang(), kCuda);
-  CHECK_EQ(grad.shape().size(), 4);
+  CHECK_EQ(grad.nDim(), 4u);
   Tensor src_data = buf_.top();
   buf_.pop();
-  float alpha = 1.f, beta = 0.f;
   vector<Tensor> param_grad;
   Tensor dx;
   dx.ResetLike(src_data);
@@ -187,42 +186,38 @@ const std::pair<Tensor, vector<Tensor>> CudnnConvolution::Backward(
 
   // LOG(ERROR) << "backward bias";
   if (bias_term_) {
-    dx.device()->Exec(
-        [grad, db, alpha, beta, this](Context *ctx) {
-          Blob *dyblob = grad.blob(), *dbblob = db.blob();
-          cudnnConvolutionBackwardBias(ctx->cudnn_handle, &alpha, this->y_desc_,
-                                       dyblob->data(), &beta, this->bias_desc_,
-                                       dbblob->mutable_data());
-        },
-        {grad.blob()}, {db.blob()});
+    dx.device()->Exec([grad, db, this](Context *ctx) {
+      Blob *dyblob = grad.blob(), *dbblob = db.blob();
+      float alpha = 1.f, beta = 0.f;
+      cudnnConvolutionBackwardBias(ctx->cudnn_handle, &alpha, this->y_desc_,
+                                   dyblob->data(), &beta, this->bias_desc_,
+                                   dbblob->mutable_data());
+    }, {grad.blob()}, {db.blob()});
   }
   // LOG(ERROR) << "backward w";
-  dx.device()->Exec(
-      [grad, dw, src_data, alpha, beta, this](Context *ctx) {
-        Blob *inblob = src_data.blob(), *dyblob = grad.blob(),
-             *dwblob = dw.blob();
-        cudnnConvolutionBackwardFilter(
-            ctx->cudnn_handle, &alpha, this->x_desc_, inblob->data(),
-            this->y_desc_, dyblob->data(), this->conv_desc_,
-            this->bp_filter_alg_, this->workspace_.blob()->mutable_data(),
-            this->workspace_count_ * sizeof(float), &beta, this->filter_desc_,
-            dwblob->mutable_data());
-      },
-      {grad.blob(), src_data.blob()}, {dw.blob(), workspace_.blob()});
+  dx.device()->Exec([grad, dw, src_data, this](Context *ctx) {
+    Blob *inblob = src_data.blob(), *dyblob = grad.blob(), *dwblob = dw.blob();
+    float alpha = 1.f, beta = 0.f;
+    cudnnConvolutionBackwardFilter(
+        ctx->cudnn_handle, &alpha, this->x_desc_, inblob->data(), this->y_desc_,
+        dyblob->data(), this->conv_desc_, this->bp_filter_alg_,
+        this->workspace_.blob()->mutable_data(),
+        this->workspace_count_ * sizeof(float), &beta, this->filter_desc_,
+        dwblob->mutable_data());
+  }, {grad.blob(), src_data.blob()}, {dw.blob(), workspace_.blob()});
 
   // LOG(ERROR) << "backward src";
-  dx.device()->Exec(
-      [dx, grad, alpha, beta, this](Context *ctx) {
-        Blob *wblob = this->weight_.blob(), *dyblob = grad.blob(),
-             *dxblob = dx.blob();
-        cudnnConvolutionBackwardData(
-            ctx->cudnn_handle, &alpha, this->filter_desc_, wblob->data(),
-            this->y_desc_, dyblob->data(), this->conv_desc_, this->bp_data_alg_,
-            this->workspace_.blob()->mutable_data(),
-            this->workspace_count_ * sizeof(float), &beta, this->x_desc_,
-            dxblob->mutable_data());
-      },
-      {grad.blob(), weight_.blob()}, {dx.blob(), workspace_.blob()});
+  dx.device()->Exec([dx, grad, this](Context *ctx) {
+    Blob *wblob = this->weight_.blob(), *dyblob = grad.blob(),
+         *dxblob = dx.blob();
+    float alpha = 1.f, beta = 0.f;
+    cudnnConvolutionBackwardData(ctx->cudnn_handle, &alpha, this->filter_desc_,
+                                 wblob->data(), this->y_desc_, dyblob->data(),
+                                 this->conv_desc_, this->bp_data_alg_,
+                                 this->workspace_.blob()->mutable_data(),
+                                 this->workspace_count_ * sizeof(float), &beta,
+                                 this->x_desc_, dxblob->mutable_data());
+  }, {grad.blob(), weight_.blob()}, {dx.blob(), workspace_.blob()});
   param_grad.push_back(dw);
   param_grad.push_back(db);
   return std::make_pair(dx, param_grad);

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/7d149ecf/src/model/layer/cudnn_convolution.h
----------------------------------------------------------------------
diff --git a/src/model/layer/cudnn_convolution.h b/src/model/layer/cudnn_convolution.h
index cf04be0..b86c576 100644
--- a/src/model/layer/cudnn_convolution.h
+++ b/src/model/layer/cudnn_convolution.h
@@ -27,6 +27,7 @@
 #include "singa/core/common.h"
 #include "singa/model/layer.h"
 #include "singa/proto/core.pb.h"
+#include "singa/utils/string.h"
 
 namespace singa {
 class CudnnConvolution : public Convolution {
@@ -41,13 +42,15 @@ class CudnnConvolution : public Convolution {
 
   /// \copydoc Layer::Setup(const LayerConf&);
   void Setup(const LayerConf &conf) override;
-  /// Init cudnn related data structures.
-  void InitCudnn(DataType dtype, Device *dev, Context *ctx);
 
   void ToDevice(Device *device) override;
 
   size_t workspace_byte_limit() { return workspace_byte_limit_; }
-  string pref() { return pref_; }
+  string prefer() { return prefer_; }
+
+ protected:
+  /// Init cudnn related data structures.
+  void InitCudnn(const Tensor& input);
 
  protected:
   bool has_init_cudnn_ = false;
@@ -61,7 +64,7 @@ class CudnnConvolution : public Convolution {
   cudnnConvolutionBwdDataAlgo_t bp_data_alg_;
   size_t workspace_byte_limit_, workspace_count_;
   Tensor workspace_;
-  string pref_;
+  string prefer_;
 };
 
 }  // namespace singa

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/7d149ecf/src/model/layer/cudnn_pooling.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/cudnn_pooling.cc b/src/model/layer/cudnn_pooling.cc
index d68bcd2..afbc490 100644
--- a/src/model/layer/cudnn_pooling.cc
+++ b/src/model/layer/cudnn_pooling.cc
@@ -41,17 +41,19 @@ void CudnnPooling::Setup(const LayerConf &conf) {
     nan_prop_ = CUDNN_NOT_PROPAGATE_NAN;
 }
 
-void CudnnPooling::InitCudnn(DataType dtype) {
+void CudnnPooling::InitCudnn(const Tensor& input) {
   CHECK(!has_init_cudnn_);
+  DataType dtype = input.data_type();
+  size_t batchsize = input.shape(0);
   CUDNN_CHECK(cudnnCreateTensorDescriptor(&x_desc_));
   CUDNN_CHECK(cudnnCreateTensorDescriptor(&y_desc_));
   CUDNN_CHECK(cudnnCreatePoolingDescriptor(&pool_desc_));
 
   CUDNN_CHECK(cudnnSetTensor4dDescriptor(x_desc_, CUDNN_TENSOR_NCHW,
-                                         GetCudnnDataType(dtype), batchsize_,
+                                         GetCudnnDataType(dtype), batchsize,
                                          channels_, height_, width_));
   CUDNN_CHECK(cudnnSetTensor4dDescriptor(
-      y_desc_, CUDNN_TENSOR_NCHW, GetCudnnDataType(dtype), batchsize_,
+      y_desc_, CUDNN_TENSOR_NCHW, GetCudnnDataType(dtype), batchsize,
       channels_, pooled_height_, pooled_width_));
   auto pool_method = CUDNN_POOLING_MAX;
   if (pool_ == PoolingConf_PoolMethod_MAX)
@@ -77,19 +79,19 @@ void CudnnPooling::InitCudnn(DataType dtype) {
 
 const Tensor CudnnPooling::Forward(int flag, const Tensor &input) {
   CHECK_EQ(input.device()->lang(), kCuda);
-  CHECK_EQ(input.shape().size(), 4);
+  CHECK_EQ(input.nDim(), 4u);
   buf_.push(input);
-  batchsize_ = input.shape()[0];
+  size_t batchsize = input.shape(0);
   DataType dtype = input.data_type();
   Device *dev = input.device();
-  float alpha = 1.0f, beta = 0.0f;
-  if (!has_init_cudnn_) InitCudnn(dtype);
+  if (!has_init_cudnn_) InitCudnn(input);
 
-  Shape shape{batchsize_, channels_, pooled_height_, pooled_width_};
+  Shape shape{batchsize, channels_, pooled_height_, pooled_width_};
   Tensor output = Tensor(shape, dev, dtype);
   output.device()->Exec(
-      [input, output, alpha, beta, this](Context *ctx) {
+      [input, output, this](Context *ctx) {
         Blob *inblob = input.blob(), *outblob = output.blob();
+        float alpha = 1.0f, beta = 0.0f;
         cudnnPoolingForward(ctx->cudnn_handle, this->pool_desc_, &alpha,
                             this->x_desc_, inblob->data(), &beta, this->y_desc_,
                             outblob->mutable_data());
@@ -102,26 +104,26 @@ const Tensor CudnnPooling::Forward(int flag, const Tensor &input) {
 const std::pair<Tensor, vector<Tensor>> CudnnPooling::Backward(
     int flag, const Tensor &grad) {
   CHECK_EQ(grad.device()->lang(), kCuda);
-  CHECK_EQ(grad.shape().size(), 4);
+  CHECK_EQ(grad.nDim(), 4u);
   vector<Tensor> param_grad;
-  Tensor dx;
-  Tensor data = buf_.top();
+  Tensor y = buf_.top();
   buf_.pop();
-  Tensor src_data = buf_.top();
+  Tensor x = buf_.top();
   buf_.pop();
-  dx.ResetLike(src_data);
+  Tensor dx;
+  dx.ResetLike(x);
 
-  float alpha = 1.0f, beta = 0.0f;
   dx.device()->Exec(
-      [dx, grad, src_data, data, alpha, beta, this](Context *ctx) {
-        Blob *dyblob = grad.blob(), *dxblob = dx.blob(),
-             *yblob = data.blob(), *xblob = src_data.blob();
+      [dx, grad, x, y, this](Context *ctx) {
+        Blob *dyblob = grad.blob(), *dxblob = dx.blob(), *yblob = y.blob(),
+             *xblob = x.blob();
+        float alpha = 1.0f, beta = 0.0f;
         cudnnPoolingBackward(ctx->cudnn_handle, this->pool_desc_, &alpha,
                              this->y_desc_, yblob->data(), this->y_desc_,
                              dyblob->data(), this->x_desc_, xblob->data(),
                              &beta, this->x_desc_, dxblob->mutable_data());
       },
-      {grad.blob(), data.blob(), src_data.blob()}, {dx.blob()});
+      {grad.blob(), y.blob(), x.blob()}, {dx.blob()});
 
   return std::make_pair(dx, param_grad);
 }

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/7d149ecf/src/model/layer/cudnn_pooling.h
----------------------------------------------------------------------
diff --git a/src/model/layer/cudnn_pooling.h b/src/model/layer/cudnn_pooling.h
index 14bdf40..1a38cd5 100644
--- a/src/model/layer/cudnn_pooling.h
+++ b/src/model/layer/cudnn_pooling.h
@@ -43,7 +43,7 @@ class CudnnPooling : public Pooling {
                                                    const Tensor &grad) override;
 
   /// Init cudnn related data structures.
-  void InitCudnn(DataType dtype);
+  void InitCudnn(const Tensor& input);
 
  private:
   bool has_init_cudnn_ = false;

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/7d149ecf/src/model/layer/pooling.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/pooling.cc b/src/model/layer/pooling.cc
index 05c6bc9..2655369 100644
--- a/src/model/layer/pooling.cc
+++ b/src/model/layer/pooling.cc
@@ -30,8 +30,8 @@ void Pooling::Setup(const LayerConf& conf) {
     kernel_w_ = pool_conf.kernel_w();
     kernel_h_ = pool_conf.kernel_h();
   }
-  CHECK_NE(kernel_w_, 0);
-  CHECK_NE(kernel_h_, 0);
+  CHECK_GT(kernel_w_, 0u);
+  CHECK_GT(kernel_h_, 0u);
 
   if (pool_conf.has_pad()) {
     pad_w_ = pad_h_ = pool_conf.pad();
@@ -39,6 +39,8 @@ void Pooling::Setup(const LayerConf& conf) {
     pad_w_ = pool_conf.pad_w();
     pad_h_ = pool_conf.pad_h();
   }
+  CHECK_GE(pad_w_, 0u);
+  CHECK_GE(pad_h_, 0u);
 
   if (pool_conf.has_stride()) {
     stride_w_ = stride_h_ = pool_conf.stride();
@@ -46,6 +48,8 @@ void Pooling::Setup(const LayerConf& conf) {
     stride_w_ = pool_conf.stride_w();
     stride_h_ = pool_conf.stride_h();
   }
+  CHECK_GT(stride_w_, 0u);
+  CHECK_GT(stride_h_, 0u);
 
   pool_ = pool_conf.pool();
   CHECK(pool_ == PoolingConf_PoolMethod_AVE ||

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/7d149ecf/src/model/layer/pooling.h
----------------------------------------------------------------------
diff --git a/src/model/layer/pooling.h b/src/model/layer/pooling.h
index ce6670d..522b603 100644
--- a/src/model/layer/pooling.h
+++ b/src/model/layer/pooling.h
@@ -46,7 +46,6 @@ class Pooling : public Layer {
   size_t stride_w() const { return stride_w_; }
   size_t stride_h() const { return stride_h_; }
   PoolingConf_PoolMethod pool_method() const { return pool_; }
-  size_t batchsize() const { return batchsize_; }
   size_t channels() const { return channels_; }
   size_t height() const { return height_; }
   size_t width() const { return width_; }
@@ -54,7 +53,7 @@ class Pooling : public Layer {
  protected:
   size_t kernel_w_, pad_w_, stride_w_;
   size_t kernel_h_, pad_h_, stride_h_;
-  size_t batchsize_, channels_, height_, width_, pooled_height_, pooled_width_;
+  size_t channels_, height_, width_, pooled_height_, pooled_width_;
   PoolingConf_PoolMethod pool_;
   // To store the input and output(of forward) tensors
   std::stack<Tensor> buf_;

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/7d149ecf/src/proto/model.proto
----------------------------------------------------------------------
diff --git a/src/proto/model.proto b/src/proto/model.proto
index 03ad6ad..66296d5 100644
--- a/src/proto/model.proto
+++ b/src/proto/model.proto
@@ -306,7 +306,8 @@ message ConvolutionConf {
   optional uint32 stride_h = 13; // The stride height (2D only)
   optional uint32 stride_w = 14; // The stride width (2D only)
 
-  optional uint32 group = 5 [default = 1]; // The group size for group conv
+  // SINGA: not supported.
+  // optional uint32 group = 5 [default = 1]; // The group size for group conv
 
   optional FillerConf weight_filler = 7; // The filler for the weight
   optional FillerConf bias_filler = 8; // The filler for the bias
@@ -326,20 +327,24 @@ message ConvolutionConf {
   // With (N, C, D, H, W) inputs, and axis == 1, we perform
   // N independent 3D convolutions, sliding (C/g)-channels
   // filters across the spatial axes (D, H, W) of the input.
-  optional int32 axis = 16 [default = 1];
+  // SINGA: not supported.
+  // optional int32 axis = 16 [default = 1];
 
   // Whether to force use of the general ND convolution, even if a specific
   // implementation for blobs of the appropriate number of spatial dimensions
   // is available. (Currently, there is only a 2D-specific convolution
   // implementation; for input blobs with num_axes != 2, this option is
   // ignored and the ND implementation will be used.)
-  optional bool force_nd_im2col = 17 [default = false];
-  // add by xiangrui
+  // SINGA: not supported.
+  // optional bool force_nd_im2col = 17 [default = false];
+
+
+  // SINGA: added by xiangrui
   // cudnn workspace size in MB
   optional int32 workspace_byte_limit = 50 [default = 512];
   // cudnn algorithm preference
   // options: "fastest", "limited_workspace", "no_workspace"
-  optional string algo_pref = 51 [default = "fastest"];
+  optional string prefer = 51 [default = "fastest"];
   // input shape
   optional int32 channels = 52;
   optional int32 height = 53;

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/7d149ecf/test/singa/test_cudnn_convolution.cc
----------------------------------------------------------------------
diff --git a/test/singa/test_cudnn_convolution.cc b/test/singa/test_cudnn_convolution.cc
index 0955c82..73359b4 100644
--- a/test/singa/test_cudnn_convolution.cc
+++ b/test/singa/test_cudnn_convolution.cc
@@ -40,31 +40,31 @@ TEST(CudnnConvolution, Setup) {
   convconf->set_bias_term(true);
   // MB
   convconf->set_workspace_byte_limit(256);
-  convconf->set_algo_pref("fastest");
+  convconf->set_prefer("fastest");
   convconf->set_channels(1);
   convconf->set_height(3);
   convconf->set_width(3);
   conv.Setup(conf);
 
-  EXPECT_EQ(2, conv.kernel_h());
-  EXPECT_EQ(2, conv.kernel_w());
-  EXPECT_EQ(1, conv.pad_h());
-  EXPECT_EQ(1, conv.pad_w());
-  EXPECT_EQ(1, conv.stride_h());
-  EXPECT_EQ(1, conv.stride_w());
-  EXPECT_EQ(2, conv.num_filters());
+  EXPECT_EQ(2u, conv.kernel_h());
+  EXPECT_EQ(2u, conv.kernel_w());
+  EXPECT_EQ(1u, conv.pad_h());
+  EXPECT_EQ(1u, conv.pad_w());
+  EXPECT_EQ(1u, conv.stride_h());
+  EXPECT_EQ(1u, conv.stride_w());
+  EXPECT_EQ(2u, conv.num_filters());
   EXPECT_EQ(true, conv.bias_term());
-  EXPECT_EQ(256 << 20, conv.workspace_byte_limit());
-  EXPECT_STREQ("fastest", conv.pref().c_str());
-  EXPECT_EQ(1, conv.channels());
-  EXPECT_EQ(3, conv.height());
-  EXPECT_EQ(3, conv.width());
+  EXPECT_EQ(256u << 20, conv.workspace_byte_limit());
+  EXPECT_STREQ("fastest", conv.prefer().c_str());
+  EXPECT_EQ(1u, conv.channels());
+  EXPECT_EQ(3u, conv.height());
+  EXPECT_EQ(3u, conv.width());
 }
 
 TEST(CudnnConvolution, Forward) {
   const size_t batchsize = 1, c = 1, h = 3, w = 3;
   const float x[batchsize * c * h * w] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
-                                      6.0f, 7.0f, 8.0f, 9.0f};
+                                          6.0f, 7.0f, 8.0f, 9.0f};
   singa::CudaGPU cuda(0, 1);
   singa::Tensor in(singa::Shape{batchsize, c, h, w}, &cuda);
   in.CopyDataFromHostPtr(x, batchsize * c * h * w);
@@ -94,7 +94,7 @@ TEST(CudnnConvolution, Forward) {
   convconf->set_bias_term(true);
   // MB
   convconf->set_workspace_byte_limit(256);
-  convconf->set_algo_pref("fastest");
+  convconf->set_prefer("fastest");
   convconf->set_channels(1);
   convconf->set_height(3);
   convconf->set_width(3);
@@ -106,7 +106,7 @@ TEST(CudnnConvolution, Forward) {
   out1.ToDevice(&host);
   const float *outptr1 = out1.data<const float *>();
   // Input: 3*3; kernel: 3*3; stride: 2*2; padding: 1*1.
-  EXPECT_EQ(4, out1.Size());
+  EXPECT_EQ(4u, out1.Size());
 
   EXPECT_EQ(3.0f, outptr1[0]);
   EXPECT_EQ(7.0f, outptr1[1]);
@@ -118,7 +118,7 @@ TEST(CudnnConvolution, Backward) {
   // src_data
   const size_t batchsize = 1, c = 1, src_h = 3, src_w = 3;
   const float x[batchsize * c * src_h * src_w] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
-                                              6.0f, 7.0f, 8.0f, 9.0f};
+                                                  6.0f, 7.0f, 8.0f, 9.0f};
   singa::CudaGPU cuda(0, 1);
   singa::Tensor in(singa::Shape{batchsize, c, src_h, src_w}, &cuda);
   in.CopyDataFromHostPtr(x, batchsize * c * src_h * src_w);
@@ -148,7 +148,7 @@ TEST(CudnnConvolution, Backward) {
   convconf->set_num_output(1);
   convconf->set_bias_term(true);
   convconf->set_workspace_byte_limit(256);
-  convconf->set_algo_pref("fastest");
+  convconf->set_prefer("fastest");
   convconf->set_channels(1);
   convconf->set_height(3);
   convconf->set_width(3);
@@ -159,8 +159,10 @@ TEST(CudnnConvolution, Backward) {
 
   // grad
   const size_t grad_h = 2, grad_w = 2;
-  const float dy[batchsize * num_filters * grad_h * grad_w] = {0.1f, 0.2f, 0.3f, 0.4f};
-  singa::Tensor grad(singa::Shape{batchsize, num_filters, grad_h, grad_w}, &cuda);
+  const float dy[batchsize * num_filters * grad_h * grad_w] = {0.1f, 0.2f, 0.3f,
+                                                               0.4f};
+  singa::Tensor grad(singa::Shape{batchsize, num_filters, grad_h, grad_w},
+                     &cuda);
   grad.CopyDataFromHostPtr(dy, batchsize * num_filters * grad_h * grad_w);
 
   const auto ret = conv.Backward(singa::kTrain, grad);
@@ -169,7 +171,7 @@ TEST(CudnnConvolution, Backward) {
   in_grad.ToDevice(&host);
   const float *dx = in_grad.data<const float *>();
   const float *wptr = we;
-  EXPECT_EQ(9, in_grad.Size());
+  EXPECT_EQ(9u, in_grad.Size());
   EXPECT_EQ(dy[0] * wptr[4], dx[0]);
   EXPECT_EQ(dy[0] * wptr[5] + dy[1] * wptr[3], dx[1]);
   EXPECT_EQ(dy[1] * wptr[4], dx[2]);
@@ -190,7 +192,7 @@ TEST(CudnnConvolution, Backward) {
   EXPECT_EQ(dy[0] + dy[1] + dy[2] + dy[3], dbptr[0]);
 
   const float *dwptr = dw.data<const float *>();
-  EXPECT_EQ(9, dw.Size());
+  EXPECT_EQ(9u, dw.Size());
   EXPECT_EQ(dy[3] * x[4], dwptr[0]);
   EXPECT_EQ(dy[3] * x[5] + dy[2] * x[3], dwptr[1]);
   EXPECT_EQ(dy[2] * x[4], dwptr[2]);
@@ -201,5 +203,5 @@ TEST(CudnnConvolution, Backward) {
   EXPECT_EQ(dy[1] * x[4], dwptr[6]);
   EXPECT_EQ(dy[0] * x[3] + dy[1] * x[5], dwptr[7]);
   EXPECT_EQ(dy[0] * x[4], dwptr[8]);
-}  // USE_CUDNN
-#endif
+}
+#endif  // USE_CUDNN

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/7d149ecf/test/singa/test_cudnn_pooling.cc
----------------------------------------------------------------------
diff --git a/test/singa/test_cudnn_pooling.cc b/test/singa/test_cudnn_pooling.cc
index 0bfd620..e66f212 100644
--- a/test/singa/test_cudnn_pooling.cc
+++ b/test/singa/test_cudnn_pooling.cc
@@ -43,23 +43,23 @@ TEST(CudnnPooling, Setup) {
   pool.Setup(conf);
 
   EXPECT_EQ(singa::PoolingConf_PoolMethod_MAX, pool.pool_method());
-  EXPECT_EQ(1, pool.kernel_h());
-  EXPECT_EQ(2, pool.kernel_w());
-  EXPECT_EQ(1, pool.pad_h());
-  EXPECT_EQ(0, pool.pad_w());
-  EXPECT_EQ(2, pool.stride_h());
-  EXPECT_EQ(1, pool.stride_w());
-  EXPECT_EQ(1, pool.channels());
-  EXPECT_EQ(3, pool.height());
-  EXPECT_EQ(3, pool.width());
+  EXPECT_EQ(1u, pool.kernel_h());
+  EXPECT_EQ(2u, pool.kernel_w());
+  EXPECT_EQ(1u, pool.pad_h());
+  EXPECT_EQ(0u, pool.pad_w());
+  EXPECT_EQ(2u, pool.stride_h());
+  EXPECT_EQ(1u, pool.stride_w());
+  EXPECT_EQ(1u, pool.channels());
+  EXPECT_EQ(3u, pool.height());
+  EXPECT_EQ(3u, pool.width());
 }
 
 TEST(CudnnPooling, Forward) {
   const size_t batchsize = 1, c = 1, h = 3, w = 3;
   const float x[batchsize * c * h * w] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
-                                      6.0f, 7.0f, 8.0f, 9.0f};
+                                          6.0f, 7.0f, 8.0f, 9.0f};
   singa::CudaGPU cuda(0, 1);
-  singa::Tensor in(singa::Shape{batchsize, c,  h, w}, &cuda);
+  singa::Tensor in(singa::Shape{batchsize, c, h, w}, &cuda);
   in.CopyDataFromHostPtr(x, batchsize * c * h * w);
 
   CudnnPooling pool;
@@ -83,7 +83,7 @@ TEST(CudnnPooling, Forward) {
   out1.ToDevice(&host);
   const float *outptr1 = out1.data<const float *>();
   // Input: 3*3; kernel: 2*2; stride: 1*1; no padding.
-  EXPECT_EQ(4, out1.Size());
+  EXPECT_EQ(4u, out1.Size());
   EXPECT_EQ(5.0f, outptr1[0]);
   EXPECT_EQ(6.0f, outptr1[1]);
   EXPECT_EQ(8.0f, outptr1[2]);
@@ -127,7 +127,7 @@ TEST(CudnnPooling, Backward) {
   singa::Tensor in_grad = ret.first;
   in_grad.ToDevice(&host);
   const float *dx = in_grad.data<const float *>();
-  EXPECT_EQ(9, in_grad.Size());
+  EXPECT_EQ(9u, in_grad.Size());
   EXPECT_EQ(0.0f, dx[0]);
   EXPECT_EQ(0.0f, dx[1]);
   EXPECT_EQ(0.0f, dx[2]);