You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@singa.apache.org by zh...@apache.org on 2016/06/13 13:20:25 UTC

[32/50] [abbrv] incubator-singa git commit: SINGA-191 Add "autotune" for CudnnConvolution Layer

SINGA-191 Add "autotune" for CudnnConvolution Layer

If users set the preference to "autotune", the layer benchmarks the available
convolution algorithms and selects the fastest ones automatically, using the
following cuDNN functions:
  cudnnFindConvolutionForwardAlgorithm,
  cudnnFindConvolutionBackwardFilterAlgorithm,
  cudnnFindConvolutionBackwardDataAlgorithm


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/01aaf490
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/01aaf490
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/01aaf490

Branch: refs/heads/master
Commit: 01aaf49009d2b0c6c62a484b21a1d06bce575e81
Parents: 04e23d1
Author: XiangruiCAI <ca...@gmail.com>
Authored: Sat Jun 4 16:18:47 2016 +0800
Committer: XiangruiCAI <ca...@gmail.com>
Committed: Wed Jun 8 10:35:51 2016 +0800

----------------------------------------------------------------------
 src/model/layer/cudnn_convolution.cc | 180 +++++++++++++++++------------
 test/singa/test_cudnn_convolution.cc | 181 ++++++++++++++++++++++++++++++
 2 files changed, 287 insertions(+), 74 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/01aaf490/src/model/layer/cudnn_convolution.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/cudnn_convolution.cc b/src/model/layer/cudnn_convolution.cc
index 97aa256..b80c3bd 100644
--- a/src/model/layer/cudnn_convolution.cc
+++ b/src/model/layer/cudnn_convolution.cc
@@ -41,9 +41,9 @@ void CudnnConvolution::Setup(const LayerConf &conf) {
   workspace_byte_limit_ = conv_conf.workspace_byte_limit() << 20;
   prefer_ = ToLowerCase(conv_conf.prefer());
   CHECK(prefer_ == "fastest" || prefer_ == "limited_workspace" ||
-        prefer_ == "no_workspace")
-      << "CudnnConvolution only supports three algorithm preferences: fastest, "
-         "limited_workspace and no_workspace";
+        prefer_ == "no_workspace" || prefer_ == "autotune")
+      << "CudnnConvolution only supports four algorithm preferences: fastest, "
+         "limited_workspace, no_workspace and autotune, but got " << prefer_;
 }
 
 void CudnnConvolution::ToDevice(Device *device) {
@@ -52,7 +52,7 @@ void CudnnConvolution::ToDevice(Device *device) {
   workspace_.ToDevice(device);
 }
 
-void CudnnConvolution::InitCudnn(const Tensor& input) {
+void CudnnConvolution::InitCudnn(const Tensor &input) {
   CHECK(!has_init_cudnn_);
   DataType dtype = input.data_type();
   Device *dev = input.device();
@@ -89,34 +89,65 @@ void CudnnConvolution::InitCudnn(const Tensor& input) {
   LOG(FATAL) << "Not supported CUDNN version = " << CUDNN_VERSION_MAJOR;
 #endif
 
-  cudnnConvolutionFwdPreference_t fwd_pref;
-  cudnnConvolutionBwdFilterPreference_t bwd_filt_pref;
-  cudnnConvolutionBwdDataPreference_t bwd_data_pref;
-  if (prefer_ == "fastest") {
-    fwd_pref = CUDNN_CONVOLUTION_FWD_PREFER_FASTEST;
-    bwd_filt_pref = CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST;
-    bwd_data_pref = CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST;
-  } else if (prefer_ == "limited_workspace") {
-    fwd_pref = CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT;
-    bwd_filt_pref = CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT;
-    bwd_data_pref = CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT;
-  } else if (prefer_ == "no_workspace") {
-    fwd_pref = CUDNN_CONVOLUTION_FWD_NO_WORKSPACE;
-    bwd_filt_pref = CUDNN_CONVOLUTION_BWD_FILTER_NO_WORKSPACE;
-    bwd_data_pref = CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT;
+  // Select fp_alg_, bp_filter_alg_ and bp_data_alg_ according to prefer_.
+  if (prefer_ == "fastest" || prefer_ == "limited_workspace" ||
+      prefer_ == "no_workspace") {
+    cudnnConvolutionFwdPreference_t fwd_pref;
+    cudnnConvolutionBwdFilterPreference_t bwd_filt_pref;
+    cudnnConvolutionBwdDataPreference_t bwd_data_pref;
+    if (prefer_ == "fastest") {
+      fwd_pref = CUDNN_CONVOLUTION_FWD_PREFER_FASTEST;
+      bwd_filt_pref = CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST;
+      bwd_data_pref = CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST;
+    } else if (prefer_ == "limited_workspace") {
+      fwd_pref = CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT;
+      bwd_filt_pref = CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT;
+      bwd_data_pref = CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT;
+    } else {
+      // no_workspace must hold for every pass, backward-data included.
+      fwd_pref = CUDNN_CONVOLUTION_FWD_NO_WORKSPACE;
+      bwd_filt_pref = CUDNN_CONVOLUTION_BWD_FILTER_NO_WORKSPACE;
+      bwd_data_pref = CUDNN_CONVOLUTION_BWD_DATA_NO_WORKSPACE;
+    }
+    CUDNN_CHECK(cudnnGetConvolutionForwardAlgorithm(
+        ctx->cudnn_handle, x_desc_, filter_desc_, conv_desc_, y_desc_, fwd_pref,
+        workspace_byte_limit_, &fp_alg_));
+    CUDNN_CHECK(cudnnGetConvolutionBackwardFilterAlgorithm(
+        ctx->cudnn_handle, x_desc_, y_desc_, conv_desc_, filter_desc_,
+        bwd_filt_pref, workspace_byte_limit_, &bp_filter_alg_));
+    CUDNN_CHECK(cudnnGetConvolutionBackwardDataAlgorithm(
+        ctx->cudnn_handle, filter_desc_, y_desc_, conv_desc_, x_desc_,
+        bwd_data_pref, workspace_byte_limit_, &bp_data_alg_));
+  } else if (prefer_ == "autotune") {
+    // cudnnFind* times the candidate algorithms on the current descriptors
+    // and returns the fastest first; only the best one (topk = 1) is kept.
+    const int topk = 1;
+    int num_fp_alg = 0, num_bp_filt_alg = 0, num_bp_data_alg = 0;
+    cudnnConvolutionFwdAlgoPerf_t fp_alg_perf[topk];
+    cudnnConvolutionBwdFilterAlgoPerf_t bp_filt_perf[topk];
+    cudnnConvolutionBwdDataAlgoPerf_t bp_data_perf[topk];
+    CUDNN_CHECK(cudnnFindConvolutionForwardAlgorithm(
+        ctx->cudnn_handle, x_desc_, filter_desc_, conv_desc_, y_desc_, topk,
+        &num_fp_alg, fp_alg_perf));
+    // The top entry is only usable if cuDNN returned it and it ran OK.
+    CHECK(num_fp_alg > 0 && fp_alg_perf[0].status == CUDNN_STATUS_SUCCESS)
+        << "No usable forward convolution algorithm found";
+    fp_alg_ = fp_alg_perf[0].algo;
+    CUDNN_CHECK(cudnnFindConvolutionBackwardFilterAlgorithm(
+        ctx->cudnn_handle, x_desc_, y_desc_, conv_desc_, filter_desc_, topk,
+        &num_bp_filt_alg, bp_filt_perf));
+    CHECK(num_bp_filt_alg > 0 && bp_filt_perf[0].status == CUDNN_STATUS_SUCCESS)
+        << "No usable backward-filter convolution algorithm found";
+    bp_filter_alg_ = bp_filt_perf[0].algo;
+    CUDNN_CHECK(cudnnFindConvolutionBackwardDataAlgorithm(
+        ctx->cudnn_handle, filter_desc_, y_desc_, conv_desc_, x_desc_, topk,
+        &num_bp_data_alg, bp_data_perf));
+    CHECK(num_bp_data_alg > 0 && bp_data_perf[0].status == CUDNN_STATUS_SUCCESS)
+        << "No usable backward-data convolution algorithm found";
+    bp_data_alg_ = bp_data_perf[0].algo;
   } else {
     LOG(FATAL) << "Preferred algorithm is not available!";
   }
-  CUDNN_CHECK(cudnnGetConvolutionForwardAlgorithm(
-      ctx->cudnn_handle, x_desc_, filter_desc_, conv_desc_, y_desc_, fwd_pref,
-      workspace_byte_limit_, &fp_alg_));
-
-  CUDNN_CHECK(cudnnGetConvolutionBackwardFilterAlgorithm(
-      ctx->cudnn_handle, x_desc_, y_desc_, conv_desc_, filter_desc_,
-      bwd_filt_pref, workspace_byte_limit_, &bp_filter_alg_));
-  CUDNN_CHECK(cudnnGetConvolutionBackwardDataAlgorithm(
-      ctx->cudnn_handle, filter_desc_, y_desc_, conv_desc_, x_desc_,
-      bwd_data_pref, workspace_byte_limit_, &bp_data_alg_));
 
   size_t fp_byte, bp_data_byte, bp_filter_byte;
   CUDNN_CHECK(cudnnGetConvolutionForwardWorkspaceSize(
@@ -147,25 +167,30 @@ const Tensor CudnnConvolution::Forward(int flag, const Tensor &input) {
 
   Shape shape{batchsize, num_filters_, conv_height_, conv_width_};
   Tensor output(shape, dev, dtype);
-  output.device()->Exec([input, output, this](Context *ctx) {
-    Blob *inblob = input.blob(), *outblob = output.blob(),
-         *wblob = this->weight_.blob();
-    float alpha = 1.f, beta = 0.f;
-    cudnnConvolutionForward(ctx->cudnn_handle, &alpha, this->x_desc_,
-                            inblob->data(), this->filter_desc_, wblob->data(),
-                            this->conv_desc_, this->fp_alg_,
-                            this->workspace_.blob()->mutable_data(),
-                            this->workspace_count_ * sizeof(float), &beta,
-                            this->y_desc_, outblob->mutable_data());
-  }, {input.blob(), weight_.blob()}, {output.blob()}, workspace_.blob());
+  output.device()->Exec(
+      [input, output, this](Context *ctx) {
+        Blob *inblob = input.blob(), *outblob = output.blob(),
+             *wblob = this->weight_.blob();
+        float alpha = 1.f, beta = 0.f;
+        CUDNN_CHECK(cudnnConvolutionForward(
+            ctx->cudnn_handle, &alpha, this->x_desc_, inblob->data(),
+            this->filter_desc_, wblob->data(), this->conv_desc_, this->fp_alg_,
+            this->workspace_.blob()->mutable_data(),
+            this->workspace_count_ * sizeof(float), &beta, this->y_desc_,
+            outblob->mutable_data()));
+      },
+      {input.blob(), weight_.blob()}, {output.blob()}, workspace_.blob());
 
   if (bias_term_) {
-    output.device()->Exec([output, this](Context *ctx) {
-      float beta = 1.f, alpha = 1.0f;
-      Blob *outblob = output.blob(), *bblob = this->bias_.blob();
-      cudnnAddTensor(ctx->cudnn_handle, &alpha, this->bias_desc_, bblob->data(),
-                     &beta, this->y_desc_, outblob->mutable_data());
-    }, {output.blob(), bias_.blob()}, {output.blob()});
+    output.device()->Exec(
+        [output, this](Context *ctx) {
+          float beta = 1.f, alpha = 1.0f;
+          Blob *outblob = output.blob(), *bblob = this->bias_.blob();
+          CUDNN_CHECK(cudnnAddTensor(ctx->cudnn_handle, &alpha,
+                                     this->bias_desc_, bblob->data(), &beta,
+                                     this->y_desc_, outblob->mutable_data()));
+        },
+        {output.blob(), bias_.blob()}, {output.blob()});
   }
   return output;
 }
@@ -187,38 +212,45 @@ const std::pair<Tensor, vector<Tensor>> CudnnConvolution::Backward(
 
   // LOG(ERROR) << "backward bias";
   if (bias_term_) {
-    dx.device()->Exec([grad, db, this](Context *ctx) {
-      Blob *dyblob = grad.blob(), *dbblob = db.blob();
-      float alpha = 1.f, beta = 0.f;
-      cudnnConvolutionBackwardBias(ctx->cudnn_handle, &alpha, this->y_desc_,
-                                   dyblob->data(), &beta, this->bias_desc_,
-                                   dbblob->mutable_data());
-    }, {grad.blob()}, {db.blob()});
+    dx.device()->Exec(
+        [grad, db, this](Context *ctx) {
+          Blob *dyblob = grad.blob(), *dbblob = db.blob();
+          float alpha = 1.f, beta = 0.f;
+          CUDNN_CHECK(cudnnConvolutionBackwardBias(
+              ctx->cudnn_handle, &alpha, this->y_desc_, dyblob->data(), &beta,
+              this->bias_desc_, dbblob->mutable_data()));
+        },
+        {grad.blob()}, {db.blob()});
   }
   // LOG(ERROR) << "backward w";
-  dx.device()->Exec([grad, dw, src_data, this](Context *ctx) {
-    Blob *inblob = src_data.blob(), *dyblob = grad.blob(), *dwblob = dw.blob();
-    float alpha = 1.f, beta = 0.f;
-    cudnnConvolutionBackwardFilter(
-        ctx->cudnn_handle, &alpha, this->x_desc_, inblob->data(), this->y_desc_,
-        dyblob->data(), this->conv_desc_, this->bp_filter_alg_,
-        this->workspace_.blob()->mutable_data(),
-        this->workspace_count_ * sizeof(float), &beta, this->filter_desc_,
-        dwblob->mutable_data());
-  }, {grad.blob(), src_data.blob()}, {dw.blob(), workspace_.blob()});
+  dx.device()->Exec(
+      [grad, dw, src_data, this](Context *ctx) {
+        Blob *inblob = src_data.blob(), *dyblob = grad.blob(),
+             *dwblob = dw.blob();
+        float alpha = 1.f, beta = 0.f;
+        CUDNN_CHECK(cudnnConvolutionBackwardFilter(
+            ctx->cudnn_handle, &alpha, this->x_desc_, inblob->data(),
+            this->y_desc_, dyblob->data(), this->conv_desc_,
+            this->bp_filter_alg_, this->workspace_.blob()->mutable_data(),
+            this->workspace_count_ * sizeof(float), &beta, this->filter_desc_,
+            dwblob->mutable_data()));
+      },
+      {grad.blob(), src_data.blob()}, {dw.blob(), workspace_.blob()});
 
   // LOG(ERROR) << "backward src";
-  dx.device()->Exec([dx, grad, this](Context *ctx) {
-    Blob *wblob = this->weight_.blob(), *dyblob = grad.blob(),
-         *dxblob = dx.blob();
-    float alpha = 1.f, beta = 0.f;
-    cudnnConvolutionBackwardData(ctx->cudnn_handle, &alpha, this->filter_desc_,
-                                 wblob->data(), this->y_desc_, dyblob->data(),
-                                 this->conv_desc_, this->bp_data_alg_,
-                                 this->workspace_.blob()->mutable_data(),
-                                 this->workspace_count_ * sizeof(float), &beta,
-                                 this->x_desc_, dxblob->mutable_data());
-  }, {grad.blob(), weight_.blob()}, {dx.blob(), workspace_.blob()});
+  dx.device()->Exec(
+      [dx, grad, this](Context *ctx) {
+        Blob *wblob = this->weight_.blob(), *dyblob = grad.blob(),
+             *dxblob = dx.blob();
+        float alpha = 1.f, beta = 0.f;
+        CUDNN_CHECK(cudnnConvolutionBackwardData(
+            ctx->cudnn_handle, &alpha, this->filter_desc_, wblob->data(),
+            this->y_desc_, dyblob->data(), this->conv_desc_, this->bp_data_alg_,
+            this->workspace_.blob()->mutable_data(),
+            this->workspace_count_ * sizeof(float), &beta, this->x_desc_,
+            dxblob->mutable_data()));
+      },
+      {grad.blob(), weight_.blob()}, {dx.blob(), workspace_.blob()});
   param_grad.push_back(dw);
   param_grad.push_back(db);
   return std::make_pair(dx, param_grad);

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/01aaf490/test/singa/test_cudnn_convolution.cc
----------------------------------------------------------------------
diff --git a/test/singa/test_cudnn_convolution.cc b/test/singa/test_cudnn_convolution.cc
index 73359b4..2a17da2 100644
--- a/test/singa/test_cudnn_convolution.cc
+++ b/test/singa/test_cudnn_convolution.cc
@@ -204,4 +204,185 @@ TEST(CudnnConvolution, Backward) {
   EXPECT_EQ(dy[0] * x[3] + dy[1] * x[5], dwptr[7]);
   EXPECT_EQ(dy[0] * x[4], dwptr[8]);
 }
+// Tests for prefer=autotune
+TEST(CudnnConvolution_AT, Setup) {
+  CudnnConvolution conv;
+  EXPECT_EQ("CudnnConvolution", conv.layer_type());
+  // Setup() must accept and keep the "autotune" preference.
+  singa::LayerConf conf;
+  singa::ConvolutionConf *convconf = conf.mutable_convolution_conf();
+  convconf->set_kernel_h(2);
+  convconf->set_kernel_w(2);
+  convconf->set_pad_h(1);
+  convconf->set_pad_w(1);
+  convconf->set_stride_h(1);
+  convconf->set_stride_w(1);
+  convconf->set_num_output(2);
+  convconf->set_bias_term(true);
+  // Workspace limit in MB; Setup() stores it in bytes (<< 20).
+  convconf->set_workspace_byte_limit(256);
+  convconf->set_prefer("autotune");
+  convconf->set_channels(1);
+  convconf->set_height(3);
+  convconf->set_width(3);
+  conv.Setup(conf);
+
+  EXPECT_EQ(2u, conv.kernel_h());
+  EXPECT_EQ(2u, conv.kernel_w());
+  EXPECT_EQ(1u, conv.pad_h());
+  EXPECT_EQ(1u, conv.pad_w());
+  EXPECT_EQ(1u, conv.stride_h());
+  EXPECT_EQ(1u, conv.stride_w());
+  EXPECT_EQ(2u, conv.num_filters());
+  EXPECT_TRUE(conv.bias_term());
+  EXPECT_EQ(256u << 20, conv.workspace_byte_limit());
+  EXPECT_STREQ("autotune", conv.prefer().c_str());
+  EXPECT_EQ(1u, conv.channels());
+  EXPECT_EQ(3u, conv.height());
+  EXPECT_EQ(3u, conv.width());
+}
+
+TEST(CudnnConvolution_AT, Forward) {
+  const size_t batchsize = 1, c = 1, h = 3, w = 3;
+  const float x[batchsize * c * h * w] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
+                                          6.0f, 7.0f, 8.0f, 9.0f};
+  singa::CudaGPU cuda(0, 1);
+  singa::Tensor in(singa::Shape{batchsize, c, h, w}, &cuda);
+  in.CopyDataFromHostPtr(x, batchsize * c * h * w);
+
+  // Set weight (num_filters x c * kh * kw) and bias manually
+  const size_t num_filters = 1, kh = 3, kw = 3;
+  const float we[num_filters * c * kh * kw] = {
+      1.0f, 1.0f, 0.0f, 0.0f, 0.0f, -1.0f, 0.0f, 1.0f, 0.0f};
+  singa::Tensor weight(singa::Shape{num_filters, c * kh * kw}, &cuda);
+  weight.CopyDataFromHostPtr(we, num_filters * c * kh * kw);
+  const float b[num_filters] = {1.0f};
+  singa::Tensor bias(singa::Shape{num_filters}, &cuda);
+  bias.CopyDataFromHostPtr(b, num_filters);
+  CudnnConvolution conv;
+  conv.set_weight(weight);
+  conv.set_bias(bias);
+
+  singa::LayerConf conf;
+  singa::ConvolutionConf *convconf = conf.mutable_convolution_conf();
+  convconf->set_kernel_h(kh);
+  convconf->set_kernel_w(kw);
+  convconf->set_pad_h(1);
+  convconf->set_pad_w(1);
+  convconf->set_stride_h(2);
+  convconf->set_stride_w(2);
+  convconf->set_num_output(1);
+  convconf->set_bias_term(true);
+  // Workspace limit in MB.
+  convconf->set_workspace_byte_limit(256);
+  convconf->set_prefer("autotune");
+  convconf->set_channels(c);
+  convconf->set_height(h);
+  convconf->set_width(w);
+  conv.Setup(conf);
+
+  // Parameter "flag" does not influence convolution
+  singa::Tensor out1 = conv.Forward(singa::kTrain, in);
+  singa::CppCPU host(0, 1);
+  out1.ToDevice(&host);
+  const float *outptr1 = out1.data<const float *>();
+  // Input: 3*3; kernel: 3*3; stride: 2*2; padding: 1*1.
+  EXPECT_EQ(4u, out1.Size());
+  // autotune may pick any algorithm; compare with a 4-ULP tolerance.
+  EXPECT_FLOAT_EQ(3.0f, outptr1[0]);
+  EXPECT_FLOAT_EQ(7.0f, outptr1[1]);
+  EXPECT_FLOAT_EQ(-3.0f, outptr1[2]);
+  EXPECT_FLOAT_EQ(12.0f, outptr1[3]);
+}
+
+TEST(CudnnConvolution_AT, Backward) {
+  // src_data
+  const size_t batchsize = 1, c = 1, src_h = 3, src_w = 3;
+  const float x[batchsize * c * src_h * src_w] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
+                                                  6.0f, 7.0f, 8.0f, 9.0f};
+  singa::CudaGPU cuda(0, 1);
+  singa::Tensor in(singa::Shape{batchsize, c, src_h, src_w}, &cuda);
+  in.CopyDataFromHostPtr(x, batchsize * c * src_h * src_w);
+
+  // Set weight_ and bias_ manually; the weight holds num_filters kernels
+  // of c * kh * kw elements each.
+  const size_t num_filters = 1, kh = 3, kw = 3;
+  const float we[num_filters * c * kh * kw] = {
+      1.0f, 1.0f, 0.0f, 0.0f, 0.0f, -1.0f, 0.0f, 1.0f, 0.0f};
+  singa::Tensor weight(singa::Shape{num_filters, c * kh * kw}, &cuda);
+  weight.CopyDataFromHostPtr(we, num_filters * c * kh * kw);
+  const float b[num_filters] = {1.0f};
+  singa::Tensor bias(singa::Shape{num_filters}, &cuda);
+  bias.CopyDataFromHostPtr(b, num_filters);
+  CudnnConvolution conv;
+  conv.set_weight(weight);
+  conv.set_bias(bias);
+
+  singa::LayerConf conf;
+  singa::ConvolutionConf *convconf = conf.mutable_convolution_conf();
+  convconf->set_kernel_h(kh);
+  convconf->set_kernel_w(kw);
+  convconf->set_pad_h(1);
+  convconf->set_pad_w(1);
+  convconf->set_stride_h(2);
+  convconf->set_stride_w(2);
+  convconf->set_num_output(1);
+  convconf->set_bias_term(true);
+  convconf->set_workspace_byte_limit(256);
+  convconf->set_prefer("autotune");
+  convconf->set_channels(c);
+  convconf->set_height(src_h);
+  convconf->set_width(src_w);
+  conv.Setup(conf);
+
+  // Parameter "flag" does not influence convolution
+  singa::Tensor out1 = conv.Forward(singa::kTrain, in);
+
+  // grad
+  const size_t grad_h = 2, grad_w = 2;
+  const float dy[batchsize * num_filters * grad_h * grad_w] = {0.1f, 0.2f, 0.3f,
+                                                               0.4f};
+  singa::Tensor grad(singa::Shape{batchsize, num_filters, grad_h, grad_w},
+                     &cuda);
+  grad.CopyDataFromHostPtr(dy, batchsize * num_filters * grad_h * grad_w);
+  // autotune may pick any algorithm; compare floats with a 4-ULP tolerance.
+  const auto ret = conv.Backward(singa::kTrain, grad);
+  singa::CppCPU host(0, 1);
+  singa::Tensor in_grad = ret.first;
+  in_grad.ToDevice(&host);
+  const float *dx = in_grad.data<const float *>();
+  const float *wptr = we;
+  EXPECT_EQ(9u, in_grad.Size());
+  EXPECT_FLOAT_EQ(dy[0] * wptr[4], dx[0]);
+  EXPECT_FLOAT_EQ(dy[0] * wptr[5] + dy[1] * wptr[3], dx[1]);
+  EXPECT_FLOAT_EQ(dy[1] * wptr[4], dx[2]);
+  EXPECT_FLOAT_EQ(dy[0] * wptr[7] + dy[2] * wptr[1], dx[3]);
+  EXPECT_FLOAT_EQ(
+      dy[0] * wptr[8] + dy[1] * wptr[6] + dy[2] * wptr[2] + dy[3] * wptr[0],
+      dx[4]);
+  EXPECT_FLOAT_EQ(dy[1] * wptr[7] + dy[3] * wptr[1], dx[5]);
+  EXPECT_FLOAT_EQ(dy[2] * wptr[4], dx[6]);
+  EXPECT_FLOAT_EQ(dy[2] * wptr[5] + dy[3] * wptr[3], dx[7]);
+  EXPECT_FLOAT_EQ(dy[3] * wptr[4], dx[8]);
+
+  singa::Tensor dw = ret.second[0];
+  singa::Tensor db = ret.second[1];
+  dw.ToDevice(&host);
+  db.ToDevice(&host);
+  const float *dbptr = db.data<const float *>();
+  EXPECT_FLOAT_EQ(dy[0] + dy[1] + dy[2] + dy[3], dbptr[0]);
+
+  const float *dwptr = dw.data<const float *>();
+  EXPECT_EQ(9u, dw.Size());
+  EXPECT_FLOAT_EQ(dy[3] * x[4], dwptr[0]);
+  EXPECT_FLOAT_EQ(dy[3] * x[5] + dy[2] * x[3], dwptr[1]);
+  EXPECT_FLOAT_EQ(dy[2] * x[4], dwptr[2]);
+  EXPECT_FLOAT_EQ(dy[1] * x[1] + dy[3] * x[7], dwptr[3]);
+  EXPECT_FLOAT_EQ(dy[0] * x[0] + dy[1] * x[2] + dy[2] * x[6] + dy[3] * x[8],
+                  dwptr[4]);
+  EXPECT_FLOAT_EQ(dy[0] * x[1] + dy[2] * x[7], dwptr[5]);
+  EXPECT_FLOAT_EQ(dy[1] * x[4], dwptr[6]);
+  EXPECT_FLOAT_EQ(dy[0] * x[3] + dy[1] * x[5], dwptr[7]);
+  EXPECT_FLOAT_EQ(dy[0] * x[4], dwptr[8]);
+}
 #endif  // USE_CUDNN