You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@singa.apache.org by wa...@apache.org on 2018/07/05 03:10:05 UTC

[10/18] incubator-singa git commit: SINGA-371 Implement functional operations in c++ for autograd

SINGA-371 Implement functional operations in c++ for autograd

- merge the CPU and GPU parts of the conv2d operation in the Python part (not complete)

- redesign handles (recorder)

- redesign the API

- parts of the code have passed unit tests


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/189958ab
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/189958ab
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/189958ab

Branch: refs/heads/master
Commit: 189958ab53bff686f39e599a36f21867cd4a265f
Parents: dfe4478
Author: xuewanqi <xu...@outlook.com>
Authored: Wed Jun 27 14:57:21 2018 +0000
Committer: xuewanqi <xu...@outlook.com>
Committed: Wed Jun 27 14:57:21 2018 +0000

----------------------------------------------------------------------
 python/singa/autograd.py                   |  79 ++--
 src/CMakeLists.txt                         |   1 +
 src/api/model_operation.i                  |  43 +--
 src/model/convolution_functions.cc         | 457 ------------------------
 src/model/convolution_functions.h          |  95 -----
 src/model/operation/convolution_related.cc | 417 +++++++++++++++++++++
 src/model/operation/convolution_related.h  |  75 ++++
 7 files changed, 561 insertions(+), 606 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/189958ab/python/singa/autograd.py
----------------------------------------------------------------------
diff --git a/python/singa/autograd.py b/python/singa/autograd.py
index 4f45bf1..e898312 100644
--- a/python/singa/autograd.py
+++ b/python/singa/autograd.py
@@ -583,7 +583,7 @@ class Flatten(Operation):
 def flatten(x):
     return Flatten()(x)[0]
 
-class Conv2d_GPU(Operation):
+class Conv2D(Operation):
     def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                  padding=0, dilation=1, groups=1, bias=True, **kwargs):
 
@@ -616,20 +616,14 @@ class Conv2d_GPU(Operation):
 
         self.bias = bias
 
-        inner_params = {'cudnn_prefer': 'fastest', 'workspace_MB_limit': 1024}
+        self.inner_params = {'cudnn_prefer': 'fastest', 'workspace_MB_limit': 1024}
         # TODO valid value of inner_params check
 
         for kwarg in kwargs:
-            if kwarg not in inner_params:
+            if kwarg not in self.inner_params:
                 raise TypeError('Keyword argument not understood:', kwarg)
             else:
-                inner_params[kwarg] = kwargs[kwarg]
-
-        self.convhandle = singa.SetupConv(self.kernel_size[0], self.kernel_size[1],
-        			self.padding[0], self.padding[1], self.stride[0], self.stride[1],
-        			self.in_channels, self.out_channels, self.bias, 
-                                inner_params['workspace_MB_limit']*1024*1024,
-        			inner_params['cudnn_prefer'])
+                self.inner_params[kwarg] = kwargs[kwarg]
         
         w_shape = (self.out_channels, self.in_channels, self.kernel_size[0], self.kernel_size[1])
         self.W = Tensor(shape=w_shape, requires_grad=True, stores_grad=True)
@@ -639,11 +633,13 @@ class Conv2d_GPU(Operation):
 
         if self.bias:
             b_shape = (self.out_channels,)
+            self.b = Tensor(shape=b_shape, requires_grad=True, stores_grad=True)
+            self.b.set_value(0.0)
         else:
-            b_shape = (1,) #to keep consistency when to do forward.
-        self.b = Tensor(shape=b_shape, requires_grad=True, stores_grad=True)
-        self.b.set_value(0.0)
-
+            #to keep consistency when to do forward.
+            self.b = Tensor(data=CTensor([]), requires_grad=False, stores_grad=False)
+        
+        self.reset = False
 
     def __call__(self, x):
         assert x.ndim() == 4, 'The dimensions of input should be 4D.'
@@ -651,10 +647,13 @@ class Conv2d_GPU(Operation):
         assert 0 == 0, 'invalid padding.'
     	# TODO valid padding check.
 
-    	if not hasattr (self, 'cudnnconvhandle'):
-    	    self.cudnnconvhandle = singa.InitCudnn(x.data, self.convhandle)
-    	elif x.shape[0] != self.cudnnconvhandle.batchsize:
-    	    self.cudnnconvhandle = singa.InitCudnn(x.data, self.convhandle)
+    	if not hasattr (self, 'recorder'):
+    	    self.recorder = singa.SetupRecorder(x.data, self.kernel_size, self.stride,
+                                self.padding, self.in_channels, self.out_channels, self.bias)
+    	elif x.shape[0] != self.recorder.batchsize:
+    	    self.recorder = singa.SetupRecorder(x.data, self.kernel_size, self.stride,
+                                self.padding, self.in_channels, self.out_channels, self.bias)
+            self.reset = True
         
         if training:
             self.x = x
@@ -664,26 +663,50 @@ class Conv2d_GPU(Operation):
     	self.W.to_device(self.dev)
     	xs = [x, self.W]
     	
-    	self.b.to_device(self.dev)
+        if self.bias:
+    	   self.b.to_device(self.dev)
     	xs.append(self.b)
     	return self._do_forward(*xs)[0]
 
     def forward(self, *xs):
-        return singa.CudnnConvForward(xs[0], xs[1], xs[2], self.convhandle, self.cudnnconvhandle)
+        if gpu:
+            
+            if not hasattr(self, 'cudnnconvhandles'):
+                self.cudnnconvhandles=InitCudnnConvHandles(xs[0], self.recorder, 
+                    self.inner_params['workspace_MB_limit']*1024*1024, self.inner_params['cudnn_prefer'])
+            elif self.reset:
+                self.cudnnconvhandles=InitCudnnConvHandles(xs[0], self.recorder, 
+                    self.inner_params['workspace_MB_limit']*1024*1024, self.inner_params['cudnn_prefer'])
+
+            return singa.GpuConvForward(xs[0], xs[1], xs[2], self.recorder, self.cudnnconvhandles)
+
+        if cpu:
+
+            return singa.CpuConvForward(xs[0], xs[1], xs[2], self.recorder)
 
     def backward(self, dy):
-        assert training is True and hasattr(self, 'x'), 'Please set \'training\' as True before do BP. '
+        assert training is True and hasattr(self, 'x'), 'Please set training as True before do BP. '
 
         # todo check device?
         dy.ToDevice(self.dev)
 
-        dx = singa.CudnnConvBackwardx(dy, self.W.data, self.x.data, self.cudnnconvhandle)
-        dW = singa.CudnnConvBackwardW(dy, self.x.data, self.W.data, self.cudnnconvhandle)
-        if self.bias:
-    	    db = singa.CudnnConvBackwardb(dy, self.b.data, self.cudnnconvhandle)
-    	    return dx, dW, db
-        else:
-    	    return dx, dW
+        if gpu:
+            dx = singa.GpuConvBackwardx(dy, self.W.data, self.x.data, self.cudnnconvhandles)
+            dW = singa.GpuConvBackwardW(dy, self.x.data, self.W.data, self.cudnnconvhandles)
+            if self.bias:
+        	    db = singa.GpuConvBackwardb(dy, self.b.data, self.cudnnconvhandles)
+        	    return dx, dW, db
+            else:
+        	    return dx, dW
+
+        if cpu:
+            dx = singa.CpuConvBackwardx(dy, self.W.data, self.x.data, self.recorder)
+            dW = singa.CpuConvBackwardW(dy, self.x.data, self.W.data, self.recorder)
+            if self.bias:
+                db = singa.CpuConvBackwardb(dy, self.b.data, self.recorder)
+                return dx, dW, db
+            else:
+                return dx, dW
 
 def infer_dependency(op):
     '''

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/189958ab/src/CMakeLists.txt
----------------------------------------------------------------------
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 709894b..7dd9bf7 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -58,6 +58,7 @@ AUX_SOURCE_DIRECTORY(model/optimizer model_source)
 AUX_SOURCE_DIRECTORY(model/loss model_source)
 AUX_SOURCE_DIRECTORY(model/metric model_source)
 AUX_SOURCE_DIRECTORY(model/updater model_source)
+AUX_SOURCE_DIRECTORY(model/operation model_source)
 LIST(APPEND singa_sources ${model_source})
 
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/189958ab/src/api/model_operation.i
----------------------------------------------------------------------
diff --git a/src/api/model_operation.i b/src/api/model_operation.i
index 20f112e..1d31b9d 100644
--- a/src/api/model_operation.i
+++ b/src/api/model_operation.i
@@ -1,47 +1,38 @@
 %module model_operation
 
 %{
-#include "../src/model/convolution_functions.h"
+#include "../src/model/operation/convolution_related.h"
 %}
 namespace singa{
 
-struct ConvHandle{};
+struct Recorder{size_t batchsize;};
 
-struct CudnnConvHandle{size_t batchsize;};
+struct CudnnConvHandles{};
 
-struct CpuConvHandle{};
 
-ConvHandle SetupConv(
-    const size_t kernel_h_, const size_t kernel_w_,
-    const size_t pad_h_, const size_t pad_w_,
-    const size_t stride_h_,const size_t stride_w_,
-    const size_t channels_, const size_t num_filters_,
-    const bool bias_term_ = true, const size_t workspace_byte_limit_ =1024*1024*1024,
-    const std::string prefer_="fastest");
+Recorder SetupRecorder(const Tensor &input, const std::vector<size_t> kernel_size, 
+	                const std::vector<size_t> stride, const std::vector<size_t> padding,
+	                const size_t in_channels, const size_t out_channels,
+	                const bool bias_term_);
 
-CudnnConvHandle InitCudnn(const Tensor &input, const ConvHandle ch);
+CudnnConvHandles InitCudnnConvHandles(const Tensor &input, const Recorder r, 
+     const size_t workspace_byte_limit_=1024*1024*1024, const std::string prefer_="fastest");
 
-Tensor CudnnConvForward(const Tensor &x, const Tensor &W, const Tensor &b,
-                        const ConvHandle ch, const CudnnConvHandle cch);
+Tensor GpuConvForward(const Tensor &x, const Tensor &W, const Tensor &b, const Recorder r, const CudnnConvHandles cch);
 
-Tensor CudnnConvBackwardW(const Tensor &dy, const Tensor &x, const Tensor &W, const CudnnConvHandle cch);
+Tensor GpuConvBackwardx(const Tensor &dy, const Tensor &W, const Tensor &x, const CudnnConvHandles cch);
 
-Tensor CudnnConvBackwardb(const Tensor &dy, const Tensor &b, const CudnnConvHandle cch);
+Tensor GpuConvBackwardW(const Tensor &dy, const Tensor &x, const Tensor &W, const CudnnConvHandles cch);
 
-Tensor CudnnConvBackwardx(const Tensor &dy, const Tensor &W, const Tensor &x, const CudnnConvHandle cch);
+Tensor GpuConvBackwardb(const Tensor &dy, const Tensor &b, const CudnnConvHandles cch);
 
 
-CpuConvHandle InitCpuHandle(const Tensor &input, const ConvHandle ch);
+Tensor CpuConvForward(const Tensor &x, Tensor &W,  Tensor &b, const Recorder r);
 
-Tensor CpuConvForward(const Tensor &x, Tensor &W,  Tensor &b,
-                        const ConvHandle ch, const CpuConvHandle cch);
+Tensor CpuConvBackwardx(const Tensor &dy, Tensor &W, const Tensor &x, const Recorder r);
 
-Tensor CpuConvBackwardx(const Tensor &dy, Tensor &W, const Tensor &x, 
-    const ConvHandle ch, const CpuConvHandle cch);
+Tensor CpuConvBackwardW(const Tensor &dy, const Tensor &x, const Tensor &W, const Recorder r);
 
-Tensor CpuConvBackwardW(const Tensor &dy, const Tensor &x, const Tensor &W, 
-    const ConvHandle ch, const CpuConvHandle cch);
-
-Tensor CpuConvBackwardb(const Tensor &dy, const Tensor &b, const ConvHandle ch, const CpuConvHandle cch);
+Tensor CpuConvBackwardb(const Tensor &dy, const Tensor &b, const Recorder r);
 
 }

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/189958ab/src/model/convolution_functions.cc
----------------------------------------------------------------------
diff --git a/src/model/convolution_functions.cc b/src/model/convolution_functions.cc
deleted file mode 100644
index 6e4b195..0000000
--- a/src/model/convolution_functions.cc
+++ /dev/null
@@ -1,457 +0,0 @@
-//#include <string>
-//#include <cudnn.h>
-//#include "./layer/cudnn_convolution.h"
-//#include "./layer/cudnn_utils.h"
-//#include "singa/utils/logging.h"
-#include "./convolution_functions.h"
-#include "./layer/convolution.h"
-#include<iostream>
-namespace singa{
-
-// Done in conv2d.__init__()
-ConvHandle SetupConv(
-    const size_t kernel_h_, const size_t kernel_w_,
-    const size_t pad_h_, const size_t pad_w_,
-    const size_t stride_h_,const size_t stride_w_,
-    const size_t channels_, const size_t num_filters_,
-    const bool bias_term_ , const size_t workspace_byte_limit_,
-    const std::string prefer_){
-	 return ConvHandle{
-            kernel_w_,
-            pad_w_,
-            stride_w_,
-            kernel_h_,
-            pad_h_,
-            stride_h_,
-
-            channels_,
-            num_filters_,
-
-            bias_term_,
-
-            workspace_byte_limit_,
-            prefer_,
-    };
-};
-
-
-// Done in conv2d.__call__():
-// if self.cudnnconvhandle is None:
-//     self.cudnnconvhandle= InitCudnn(...)
-// elif x.shape(0) != self.cudnnconvhandle.batchsize:
-//     self.cudnnconvhandle= InitCudnn(...)
-CudnnConvHandle InitCudnn(const Tensor &input, const ConvHandle ch){
-
-    cudnnTensorDescriptor_t x_desc_ = nullptr;
-    cudnnTensorDescriptor_t y_desc_ = nullptr;
-    cudnnTensorDescriptor_t bias_desc_ = nullptr;
-    cudnnFilterDescriptor_t filter_desc_ = nullptr;
-    cudnnConvolutionDescriptor_t conv_desc_ = nullptr;
-    cudnnConvolutionFwdAlgo_t fp_alg_;
-    cudnnConvolutionBwdFilterAlgo_t bp_filter_alg_;
-    cudnnConvolutionBwdDataAlgo_t bp_data_alg_;
-    size_t workspace_count_;
-    Tensor workspace_;
-
-    size_t height_;
-    size_t width_;
-    size_t conv_height_;
-    size_t conv_width_;
-    
-    DataType dtype = input.data_type();
-    auto dev = input.device();
-    Context *ctx = dev->context(0);
-    
-    size_t batchsize, channels_;
-    batchsize = input.shape(0);
-    channels_ = input.shape(1);
-    height_ = input.shape(2);
-    width_ = input.shape(3);
-
-    CHECK(channels_ == ch.channels_)<<"the number of input channels mismatched.";
-
-    conv_height_ = 1;
-    if (ch.stride_h_ > 0)
-        conv_height_ = (height_ + 2 * ch.pad_h_ - ch.kernel_h_) / ch.stride_h_ + 1;
-    conv_width_ = (width_ + 2 * ch.pad_w_ - ch.kernel_w_) / ch.stride_w_ + 1;
-    
-    CUDNN_CHECK(cudnnCreateTensorDescriptor(&x_desc_));
-    CUDNN_CHECK(cudnnCreateTensorDescriptor(&y_desc_));
-    if (ch.bias_term_)
-        CUDNN_CHECK(cudnnCreateTensorDescriptor(&bias_desc_));
-    CUDNN_CHECK(cudnnCreateFilterDescriptor(&filter_desc_));
-    CUDNN_CHECK(cudnnCreateConvolutionDescriptor(&conv_desc_));
-
-
-    CUDNN_CHECK(cudnnSetTensor4dDescriptor(x_desc_, CUDNN_TENSOR_NCHW,
-                                           GetCudnnDataType(dtype), batchsize,
-                                           ch.channels_, height_, width_));
-    CUDNN_CHECK(cudnnSetTensor4dDescriptor(
-            y_desc_, CUDNN_TENSOR_NCHW, GetCudnnDataType(dtype), batchsize,
-            ch.num_filters_, conv_height_, conv_width_));
-    if (ch.bias_term_)
-        CUDNN_CHECK(cudnnSetTensor4dDescriptor(bias_desc_, CUDNN_TENSOR_NCHW,
-                                               GetCudnnDataType(dtype), 1,
-                                               ch.num_filters_, 1, 1));
-    CUDNN_CHECK(cudnnSetConvolution2dDescriptor(conv_desc_, ch.pad_h_, ch.pad_w_,
-                                                ch.stride_h_, ch.stride_w_, 1, 1,
-                                                CUDNN_CROSS_CORRELATION,
-                                                GetCudnnDataType(dtype)));
-    CUDNN_CHECK(cudnnSetFilter4dDescriptor(filter_desc_, GetCudnnDataType(dtype),
-                                           CUDNN_TENSOR_NCHW, ch.num_filters_,
-                                           channels_, ch.kernel_h_, ch.kernel_w_));
-    if (ch.prefer_ == "fastest" || ch.prefer_ == "limited_workspace" ||
-        ch.prefer_ == "no_workspace") {
-        cudnnConvolutionFwdPreference_t fwd_pref;
-        cudnnConvolutionBwdFilterPreference_t bwd_filt_pref;
-        cudnnConvolutionBwdDataPreference_t bwd_data_pref;
-        if (ch.prefer_ == "fastest") {
-            fwd_pref = CUDNN_CONVOLUTION_FWD_PREFER_FASTEST;
-            bwd_filt_pref = CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST;
-            bwd_data_pref = CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST;
-        } else if (ch.prefer_ == "limited_workspace") {
-            fwd_pref = CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT;
-            bwd_filt_pref = CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT;
-            bwd_data_pref = CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT;
-        } else {
-            fwd_pref = CUDNN_CONVOLUTION_FWD_NO_WORKSPACE;
-            bwd_filt_pref = CUDNN_CONVOLUTION_BWD_FILTER_NO_WORKSPACE;
-            bwd_data_pref = CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT;
-        }
-        CUDNN_CHECK(cudnnGetConvolutionForwardAlgorithm(
-                ctx->cudnn_handle, x_desc_, filter_desc_, conv_desc_, y_desc_, fwd_pref,
-                ch.workspace_byte_limit_, &fp_alg_));
-        CUDNN_CHECK(cudnnGetConvolutionBackwardFilterAlgorithm(
-                ctx->cudnn_handle, x_desc_, y_desc_, conv_desc_, filter_desc_,
-                bwd_filt_pref, ch.workspace_byte_limit_, &bp_filter_alg_));
-        // deprecated in cudnn v7
-        CUDNN_CHECK(cudnnGetConvolutionBackwardDataAlgorithm(
-                ctx->cudnn_handle, filter_desc_, y_desc_, conv_desc_, x_desc_,
-                bwd_data_pref, ch.workspace_byte_limit_, &bp_data_alg_));
-        } else if (ch.prefer_ == "autotune") {
-        const int topk = 1;
-        int num_fp_alg, num_bp_filt_alg, num_bp_data_alg;
-        cudnnConvolutionFwdAlgoPerf_t fp_alg_perf[topk];
-        cudnnConvolutionBwdFilterAlgoPerf_t bp_filt_perf[topk];
-        cudnnConvolutionBwdDataAlgoPerf_t bp_data_perf[topk];
-        CUDNN_CHECK(cudnnFindConvolutionForwardAlgorithm(
-                ctx->cudnn_handle, x_desc_, filter_desc_, conv_desc_, y_desc_, topk,
-                &num_fp_alg, fp_alg_perf));
-        fp_alg_ = fp_alg_perf[0].algo;
-        CUDNN_CHECK(cudnnFindConvolutionBackwardFilterAlgorithm(
-                ctx->cudnn_handle, x_desc_, y_desc_, conv_desc_, filter_desc_, topk,
-                &num_bp_filt_alg, bp_filt_perf));
-        bp_filter_alg_ = bp_filt_perf[0].algo;
-        CUDNN_CHECK(cudnnFindConvolutionBackwardDataAlgorithm(
-                ctx->cudnn_handle, filter_desc_, y_desc_, conv_desc_, x_desc_, topk,
-                &num_bp_data_alg, bp_data_perf));
-        bp_data_alg_ = bp_data_perf[0].algo;
-    } else {
-        LOG(FATAL) << "Preferred algorithm is not available!";
-    }
-
-    size_t fp_byte, bp_data_byte, bp_filter_byte;
-    CUDNN_CHECK(cudnnGetConvolutionForwardWorkspaceSize(
-            ctx->cudnn_handle, x_desc_, filter_desc_, conv_desc_, y_desc_, fp_alg_,
-            &fp_byte));
-    CUDNN_CHECK(cudnnGetConvolutionBackwardDataWorkspaceSize(
-            ctx->cudnn_handle, filter_desc_, y_desc_, conv_desc_, x_desc_,
-            bp_data_alg_, &bp_data_byte));
-    CUDNN_CHECK(cudnnGetConvolutionBackwardFilterWorkspaceSize(
-            ctx->cudnn_handle, x_desc_, y_desc_, conv_desc_, filter_desc_,
-            bp_filter_alg_, &bp_filter_byte));
-    workspace_count_ = std::max(std::max(fp_byte, bp_data_byte), bp_filter_byte) /
-                       sizeof(float) +
-                       1;
-    if (workspace_count_ * sizeof(float) > ch.workspace_byte_limit_)
-        LOG(WARNING) << "The required memory for workspace ("
-                     << workspace_count_ * sizeof(float)
-                     << ") is larger than the expected Bytes ("
-                     << ch.workspace_byte_limit_ << ")";
-    workspace_ = Tensor(Shape{workspace_count_}, dev, dtype);
-
-    return CudnnConvHandle{
-            x_desc_,
-            y_desc_,
-            bias_desc_,
-            filter_desc_,
-            conv_desc_,
-            fp_alg_,
-            bp_filter_alg_,
-            bp_data_alg_,
-
-            workspace_count_,
-            workspace_,
-
-            height_,
-            width_,
-            conv_height_,
-            conv_width_,
-            batchsize,
-    };
-};
-
-Tensor CudnnConvForward(const Tensor &x, const Tensor &W, const Tensor &b,
-                        const ConvHandle ch, const CudnnConvHandle cch){
-    CHECK_EQ(x.device()->lang(), kCuda);
-    CHECK_EQ(x.nDim(), 4u);
-    CHECK_EQ(x.shape()[0],cch.batchsize);
-    CHECK_EQ(x.shape()[1],ch.channels_);
-    CHECK_EQ(x.shape()[2],cch.height_);
-    CHECK_EQ(x.shape()[3],cch.width_);
-
-    DataType dtype = x.data_type();
-    auto dev = x.device();
-
-    Shape shape{cch.batchsize, ch.num_filters_, cch.conv_height_, cch.conv_width_};
-    Tensor output(shape, dev, dtype);
-
-    output.device()->Exec([output, x, W, cch](Context *ctx) {
-        Block *inblock = x.block(), *outblock = output.block(),
-                *wblock = W.block();
-        float alpha = 1.f, beta = 0.f;
-        cudnnConvolutionForward(ctx->cudnn_handle, &alpha, cch.x_desc_,
-                                inblock->data(), cch.filter_desc_, wblock->data(),
-                                cch.conv_desc_, cch.fp_alg_,
-                                cch.workspace_.block()->mutable_data(),
-                                cch.workspace_count_ * sizeof(float), &beta,
-                                cch.y_desc_, outblock->mutable_data());
-    }, {x.block(), W.block()}, {output.block()}, cch.workspace_.block());
-
-    if (ch.bias_term_) {
-        output.device()->Exec([output, b, cch](Context *ctx) {
-            float beta = 1.f, alpha = 1.0f;
-            Block *outblock = output.block(), *bblock = b.block();
-            cudnnAddTensor(ctx->cudnn_handle, &alpha, cch.bias_desc_,
-                           bblock->data(), &beta, cch.y_desc_,
-                           outblock->mutable_data());
-        }, {output.block(), b.block()}, {output.block()});
-    }
-    return output;
-};
-
-// input Tensor W for Reset dW purpose, can avoid this later.
-Tensor CudnnConvBackwardW(const Tensor &dy, const Tensor &x, const Tensor &W, const CudnnConvHandle cch){
-    CHECK_EQ(dy.device()->lang(), kCuda);
-    CHECK_EQ(dy.nDim(), 4u);
-
-    Tensor dW;
-    dW.ResetLike(W);
-
-    dy.device()->Exec([dW, dy, x, W, cch](Context *ctx) {
-    Block *inblock = x.block(), *dyblock = dy.block(),
-            *dwblock = dW.block();
-    float alpha = 1.f, beta = 0.f;
-    cudnnConvolutionBackwardFilter(
-            ctx->cudnn_handle, &alpha, cch.x_desc_, inblock->data(),
-            cch.y_desc_, dyblock->data(), cch.conv_desc_, cch.bp_filter_alg_,
-            cch.workspace_.block()->mutable_data(),
-            cch.workspace_count_ * sizeof(float), &beta, cch.filter_desc_,
-            dwblock->mutable_data());
-    }, {dy.block(), x.block()}, {dW.block(), cch.workspace_.block()});
-
-    return dW;
-};
-
-// input Tensor b for Reset db purpose, can avoid this later.
-Tensor CudnnConvBackwardb(const Tensor &dy, const Tensor &b, const CudnnConvHandle cch){
-    CHECK_EQ(dy.device()->lang(), kCuda);
-    CHECK_EQ(dy.nDim(), 4u);
-
-    Tensor db;
-    db.ResetLike(b);
-
-    dy.device()->Exec([db, dy, b, cch](Context *ctx) {
-        Block *dyblock = dy.block(), *dbblock = db.block();
-        float alpha = 1.f, beta = 0.f;
-        cudnnConvolutionBackwardBias(ctx->cudnn_handle, &alpha, cch.y_desc_,
-                                     dyblock->data(), &beta, cch.bias_desc_,
-                                     dbblock->mutable_data());
-    }, {dy.block()}, {db.block()});
-    return db;
-};
-
-Tensor CudnnConvBackwardx(const Tensor &dy, const Tensor &W, const Tensor &x, const CudnnConvHandle cch){
-    CHECK_EQ(dy.device()->lang(), kCuda);
-    CHECK_EQ(dy.nDim(), 4u);
-
-    Tensor dx;
-    dx.ResetLike(x);
-
-    dy.device()->Exec([dx, dy, W, cch](Context *ctx) {
-        Block *wblock = W.block(), *dyblock = dy.block(),
-                *dxblock = dx.block();
-        float alpha = 1.f, beta = 0.f;
-        cudnnConvolutionBackwardData(ctx->cudnn_handle, &alpha, cch.filter_desc_,
-                                     wblock->data(), cch.y_desc_, dyblock->data(),
-                                     cch.conv_desc_, cch.bp_data_alg_,
-                                     cch.workspace_.block()->mutable_data(),
-                                     cch.workspace_count_ * sizeof(float), &beta,
-                                     cch.x_desc_, dxblock->mutable_data());
-    }, {dy.block(), W.block()}, {dx.block(), cch.workspace_.block()});
-
-    return dx;
-};
-
-CpuConvHandle InitCpuHandle(const Tensor &input, const ConvHandle ch){
-    size_t height_;
-    size_t width_;
-    size_t conv_height_;
-    size_t conv_width_;    
-    size_t batchsize;
-    size_t channels_;
-
-    size_t col_height_;
-    size_t col_width_;
-
-    batchsize = input.shape(0);
-    channels_ = input.shape(1);
-    height_ = input.shape(2);
-    width_ = input.shape(3);
-
-    CHECK(channels_ == ch.channels_)<<"the number of input channels mismatched.";
-
-    conv_height_ = 1;
-    if (ch.stride_h_ > 0)
-        conv_height_ = (height_ + 2 * ch.pad_h_ - ch.kernel_h_) / ch.stride_h_ + 1;
-    conv_width_ = (width_ + 2 * ch.pad_w_ - ch.kernel_w_) / ch.stride_w_ + 1;
-
-    col_height_ = ch.channels_ * ch.kernel_w_ * ch.kernel_h_;
-    col_width_ = conv_height_ * conv_width_;
-
-    return CpuConvHandle{
-        height_,
-        width_,
-        conv_height_,
-        conv_width_,
-        batchsize,
-
-        col_height_,
-        col_width_
-    };
-};
-
-Convolution C;
-
-Tensor CpuConvForward(const Tensor &x, Tensor &W,  Tensor &b,
-                        const ConvHandle ch, const CpuConvHandle cch){
-    CHECK_EQ(x.device()->lang(), kCpp);
-    CHECK_EQ(x.nDim(), 4u);
-    CHECK_EQ(x.shape()[0],cch.batchsize);
-    CHECK_EQ(x.shape()[1],ch.channels_);
-    CHECK_EQ(x.shape()[2],cch.height_);
-    CHECK_EQ(x.shape()[3],cch.width_);
-
-    size_t imagesize = x.Size() / cch.batchsize;
-
-    Shape w_shape= W.shape();
-    Shape b_shape= b.shape();
-
-    W.Reshape(Shape{ch.num_filters_, cch.col_height_});
-    if (ch.bias_term_)
-      b.Reshape(Shape{ch.num_filters_});
-
-    DataType dtype = x.data_type();
-    auto dev = x.device();
-    Shape shape{cch.batchsize, ch.num_filters_, cch.conv_height_, cch.conv_width_};
-    Tensor output(shape, dev, dtype);
-
-    Tensor col_data(Shape{cch.col_height_, cch.col_width_});//broadcasted image
-
-    float *data_col = new float[cch.col_height_ * cch.col_width_];
-    auto in_data = x.data<float>();
-    for (size_t num = 0; num < cch.batchsize; num++) {
-      C.Im2col(in_data + num * imagesize, ch.channels_, cch.height_, cch.width_, ch.kernel_h_,
-            ch.kernel_w_, ch.pad_h_, ch.pad_w_, ch.stride_h_, ch.stride_w_, data_col);
-      col_data.CopyDataFromHostPtr(data_col, cch.col_height_ * cch.col_width_);
-      Tensor each = Mult(W, col_data);
-      if (ch.bias_term_) {
-          AddColumn(b, &each);
-        }
-      CopyDataToFrom(&output, each, each.Size(), num * each.Size());
-  }
-  W.Reshape(w_shape);
-  b.Reshape(b_shape);
-  return output;
-}; 
-
-Tensor CpuConvBackwardx(const Tensor &dy, Tensor &W, const Tensor &x, 
-    const ConvHandle ch, const CpuConvHandle cch){
-    CHECK_EQ(dy.device()->lang(), kCpp);
-    CHECK_EQ(dy.nDim(), 4u);
-
-    Shape w_shape= W.shape();
-    W.Reshape(Shape{ch.num_filters_, cch.col_height_});
-
-    Tensor dx;
-    dx.ResetLike(x);
-    
-    size_t imagesize = x.Size() / cch.batchsize;
-    float *dx_b = new float[imagesize];
-
-    for (size_t num = 0; num < cch.batchsize; num++) {
-      Tensor grad_b(Shape{ch.num_filters_, cch.conv_height_ * cch.conv_width_});
-      CopyDataToFrom(&grad_b, dy, grad_b.Size(), 0, num * grad_b.Size());
-      Tensor dcol_b = Mult(W.T(), grad_b);
-      auto dcol_data = dcol_b.data<float>();
-      C.Col2im(dcol_data, ch.channels_, cch.height_, cch.width_, ch.kernel_h_, ch.kernel_w_, ch.pad_h_,
-           ch.pad_w_, ch.stride_h_, ch.stride_w_, dx_b);
-      dx.CopyDataFromHostPtr(dx_b, imagesize, num * imagesize);
-    }
-  W.Reshape(w_shape); 
-  return dx;
-};
-
-Tensor CpuConvBackwardW(const Tensor &dy, const Tensor &x, const Tensor &W, 
-    const ConvHandle ch, const CpuConvHandle cch){
-    CHECK_EQ(dy.device()->lang(), kCpp);
-    CHECK_EQ(dy.nDim(), 4u);
-
-    size_t imagesize = x.Size() / cch.batchsize;
-
-    Tensor dW;
-    dW.ResetLike(W);
-    dW.SetValue(0.0f);
-    
-    Shape w_shape= W.shape();
-    dW.Reshape(Shape{ch.num_filters_, cch.col_height_});
-
-    Tensor col_data(Shape{cch.col_height_, cch.col_width_});//broadcasted image
-
-    float *data_col = new float[cch.col_height_ * cch.col_width_];
-    auto in_data = dy.data<float>();
-    for (size_t num = 0; num < cch.batchsize; num++) {
-      C.Im2col(in_data + num * imagesize, ch.channels_, cch.height_, cch.width_, ch.kernel_h_,
-            ch.kernel_w_, ch.pad_h_, ch.pad_w_, ch.stride_h_, ch.stride_w_, data_col);
-      col_data.CopyDataFromHostPtr(data_col, cch.col_height_ * cch.col_width_);
-      Tensor grad_b(Shape{ch.num_filters_, cch.conv_height_ * cch.conv_width_});
-      CopyDataToFrom(&grad_b, dy, grad_b.Size(), 0, num * grad_b.Size());
-      dW += Mult(grad_b, col_data.T());
-    }
-   dW.Reshape(w_shape);
-    //dW.Reshape(Shape{ch.num_filters_,ch.channels_ , ch.kernel_w_ , ch.kernel_h_});
-   return dW;
-};
-
-Tensor CpuConvBackwardb(const Tensor &dy, const Tensor &b, const ConvHandle ch, const CpuConvHandle cch){
-    CHECK_EQ(dy.device()->lang(), kCpp);
-    CHECK_EQ(dy.nDim(), 4u);
-
-    Tensor db;
-    db.ResetLike(b);
-
-    auto tmpshp = Shape{cch.batchsize * ch.num_filters_, dy.Size() / (cch.batchsize * ch.num_filters_)};
-    Tensor tmp1 = Reshape(dy, tmpshp);
-
-    Tensor tmp2(Shape{cch.batchsize * ch.num_filters_});
-    SumColumns(tmp1, &tmp2);
-    Tensor tmp3 = Reshape(tmp2, Shape{cch.batchsize, ch.num_filters_});
-
-    SumRows(tmp3, &db);
-
-    return db;
-};
-
-} //namespace_singa
-
-

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/189958ab/src/model/convolution_functions.h
----------------------------------------------------------------------
diff --git a/src/model/convolution_functions.h b/src/model/convolution_functions.h
deleted file mode 100644
index 1b90941..0000000
--- a/src/model/convolution_functions.h
+++ /dev/null
@@ -1,95 +0,0 @@
-#include <string>
-#include <cudnn.h>
-#include "./layer/cudnn_convolution.h"
-#include "./layer/cudnn_utils.h"
-#include "singa/utils/logging.h"
-
-namespace singa{
-
-struct ConvHandle{
-    size_t kernel_w_;
-    size_t pad_w_;
-    size_t stride_w_;
-    size_t kernel_h_;
-    size_t pad_h_;
-    size_t stride_h_;
-
-    size_t channels_;
-    size_t num_filters_;
-
-    bool bias_term_;
-
-    size_t workspace_byte_limit_;
-    string prefer_;
-};
-
-struct CudnnConvHandle{
-    cudnnTensorDescriptor_t x_desc_ ;
-    cudnnTensorDescriptor_t y_desc_ ;
-    cudnnTensorDescriptor_t bias_desc_ ;
-    cudnnFilterDescriptor_t filter_desc_ ;
-    cudnnConvolutionDescriptor_t conv_desc_ ;
-    cudnnConvolutionFwdAlgo_t fp_alg_;
-    cudnnConvolutionBwdFilterAlgo_t bp_filter_alg_;
-    cudnnConvolutionBwdDataAlgo_t bp_data_alg_;
-
-    size_t workspace_count_;
-    Tensor workspace_;
-
-    size_t height_;
-    size_t width_;
-    size_t conv_height_;
-    size_t conv_width_;
-    size_t batchsize;
-};
-
-struct CpuConvHandle{
-    size_t height_;
-    size_t width_;
-    size_t conv_height_;
-    size_t conv_width_;
-    size_t batchsize;
-
-    size_t col_height_;
-    size_t col_width_;
-
-};
-
-    
-
-ConvHandle SetupConv(
-    const size_t kernel_h_, const size_t kernel_w_,
-    const size_t pad_h_, const size_t pad_w_,
-    const size_t stride_h_,const size_t stride_w_,
-    const size_t channels_, const size_t num_filters_,
-    const bool bias_term_ = true ,const size_t workspace_byte_limit_=1024*1024*1024,
-    const std::string prefer_="fastest");
-
-void testInitCudnn(const Tensor &input, const ConvHandle ch);
-
-CudnnConvHandle InitCudnn(const Tensor &input, const ConvHandle ch);
-
-Tensor CudnnConvForward(const Tensor &x, const Tensor &W, const Tensor &b,
-                        const ConvHandle ch, const CudnnConvHandle cch);
-
-Tensor CudnnConvBackwardW(const Tensor &dy, const Tensor &x, const Tensor &W, const CudnnConvHandle cch);
-
-Tensor CudnnConvBackwardb(const Tensor &dy, const Tensor &b, const CudnnConvHandle cch);
-
-Tensor CudnnConvBackwardx(const Tensor &dy, const Tensor &W, const Tensor &x, const CudnnConvHandle cch);
-
-
-CpuConvHandle InitCpuHandle(const Tensor &input, const ConvHandle ch);
-
-Tensor CpuConvForward(const Tensor &x, Tensor &W,  Tensor &b,
-                        const ConvHandle ch, const CpuConvHandle cch);
-
-Tensor CpuConvBackwardx(const Tensor &dy, Tensor &W, const Tensor &x, 
-    const ConvHandle ch, const CpuConvHandle cch);
-
-Tensor CpuConvBackwardW(const Tensor &dy, const Tensor &x, const Tensor &W, 
-    const ConvHandle ch, const CpuConvHandle cch);
-
-Tensor CpuConvBackwardb(const Tensor &dy, const Tensor &b, const ConvHandle ch, const CpuConvHandle cch);
-
-}

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/189958ab/src/model/operation/convolution_related.cc
----------------------------------------------------------------------
diff --git a/src/model/operation/convolution_related.cc b/src/model/operation/convolution_related.cc
new file mode 100644
index 0000000..1004074
--- /dev/null
+++ b/src/model/operation/convolution_related.cc
@@ -0,0 +1,417 @@
+#include "./convolution_related.h"
+#include "../layer/convolution.h"
+#include<iostream>
+
+namespace singa{
+
+// Derives the convolution geometry for `input` and packs it into a Recorder.
+//
+// input        : tensor indexed as shape(0)=batch, shape(1)=channels,
+//                shape(2)=height, shape(3)=width.
+// kernel_size  : {kernel_h, kernel_w}
+// stride       : {stride_h, stride_w}
+// padding      : {pad_h, pad_w}
+// in_channels  : must equal input.shape(1).
+// out_channels : stored as the number of filters.
+// bias_term_   : stored as-is; tells the other routines whether a bias is used.
+//
+// A zero stride_h keeps the output height at 1; stride_w must be positive.
+Recorder SetupRecorder(const Tensor &input, const std::vector<size_t> kernel_size,
+                       const std::vector<size_t> stride, const std::vector<size_t> padding,
+                       const size_t in_channels, const size_t out_channels,
+                       const bool bias_term_){
+    // Each vector is read at [0] and [1]; reject shorter ones instead of
+    // indexing out of bounds.
+    CHECK(kernel_size.size() >= 2) << "kernel_size should hold {kernel_h, kernel_w}.";
+    CHECK(stride.size() >= 2) << "stride should hold {stride_h, stride_w}.";
+    CHECK(padding.size() >= 2) << "padding should hold {pad_h, pad_w}.";
+
+    const size_t kernel_h_ = kernel_size[0];
+    const size_t kernel_w_ = kernel_size[1];
+
+    const size_t pad_h_ = padding[0];
+    const size_t pad_w_ = padding[1];
+
+    const size_t stride_h_ = stride[0];
+    const size_t stride_w_ = stride[1];
+
+    const size_t batchsize = input.shape(0);
+    CHECK(input.shape(1) == in_channels)<<"the number of input channels mismatched.";
+    const size_t height_ = input.shape(2);
+    const size_t width_ = input.shape(3);
+
+    // Both values are used as divisors below.
+    CHECK(batchsize > 0) << "the batch size should be positive.";
+    CHECK(stride_w_ > 0) << "stride_w should be positive.";
+
+    // All arithmetic is unsigned: a kernel larger than the padded input would
+    // wrap around to a huge output size rather than fail.
+    size_t conv_height_ = 1;
+    if (stride_h_ > 0) {
+        CHECK(height_ + 2 * pad_h_ >= kernel_h_) << "kernel_h exceeds the padded input height.";
+        conv_height_ = (height_ + 2 * pad_h_ - kernel_h_) / stride_h_ + 1;
+    }
+    CHECK(width_ + 2 * pad_w_ >= kernel_w_) << "kernel_w exceeds the padded input width.";
+    const size_t conv_width_ = (width_ + 2 * pad_w_ - kernel_w_) / stride_w_ + 1;
+
+    // Size of the unfolded (im2col) matrix of one image, and the element
+    // count of one image in the batch.
+    const size_t col_height_ = in_channels * kernel_w_ * kernel_h_;
+    const size_t col_width_ = conv_height_ * conv_width_;
+    const size_t imagesize = input.Size() / batchsize;
+
+    // Aggregate initialisation: the order must match the field order of Recorder.
+    return Recorder{
+        kernel_w_,
+        pad_w_,
+        stride_w_,
+        kernel_h_,
+        pad_h_,
+        stride_h_,
+
+        in_channels,
+        out_channels,
+
+        bias_term_,
+
+        height_,
+        width_,
+        conv_height_,
+        conv_width_,
+        batchsize,
+
+        col_height_,
+        col_width_,
+        imagesize
+    };
+}
+
+// File-scope Convolution layer object; the CPU routines below call its
+// Im2col and Col2im members.
+Convolution C;
+
+// Convolution forward pass on the host (kCpp) device using im2col + matrix multiply.
+//
+// x : input batch; its channel/height/width must match the Recorder.
+// W : filters, shape {num_filters, channels, kernel_h, kernel_w}. Reshaped to
+//     2-D while the function runs and restored before returning.
+// b : bias. Reshaped to {num_filters} when r.bias_term_ is set and restored
+//     before returning.
+// r : geometry produced by SetupRecorder.
+//
+// Returns a tensor of shape {batchsize, num_filters, conv_height, conv_width}
+// created on x's device with x's data type.
+Tensor CpuConvForward(const Tensor &x, Tensor &W,  Tensor &b, const Recorder r){
+    CHECK_EQ(x.device()->lang(), kCpp);
+
+    CHECK(x.shape(1) == r.channels_ && x.shape(2) == r.height_ &&
+    x.shape(3) == r.width_) << "input sample shape should not change";
+
+    CHECK(W.shape(0) == r.num_filters_ && W.shape(1) == r.channels_ &&
+    W.shape(2) == r.kernel_h_ && W.shape(3) == r.kernel_w_) << "weights shape should not change";
+
+    // Remember the callers' shapes so they can be restored on exit.
+    const Shape w_shape = W.shape();
+    const Shape b_shape = b.shape();
+
+    W.Reshape(Shape{r.num_filters_, r.col_height_});
+    if (r.bias_term_)
+      b.Reshape(Shape{r.num_filters_});
+
+    Tensor output(Shape{r.batchsize, r.num_filters_, r.conv_height_, r.conv_width_},
+                  x.device(), x.data_type());
+
+    Tensor col_data(Shape{r.col_height_, r.col_width_});  // unfolded image
+
+    // Host staging buffer for Im2col. A std::vector releases the memory on
+    // every exit path (the previous raw new[] was never freed).
+    std::vector<float> data_col(r.col_height_ * r.col_width_);
+    const float *in_data = x.data<float>();
+    for (size_t num = 0; num < r.batchsize; num++) {
+      C.Im2col(in_data + num * r.imagesize, r.channels_, r.height_, r.width_, r.kernel_h_,
+               r.kernel_w_, r.pad_h_, r.pad_w_, r.stride_h_, r.stride_w_, data_col.data());
+
+      col_data.CopyDataFromHostPtr(data_col.data(), data_col.size());
+      Tensor each = Mult(W, col_data);
+      if (r.bias_term_) {
+        AddColumn(b, &each);
+      }
+      // Place this image's result at its offset inside the batch output.
+      CopyDataToFrom(&output, each, each.Size(), num * each.Size());
+    }
+
+    W.Reshape(w_shape);
+    b.Reshape(b_shape);
+    return output;
+}
+
+// Gradient of the convolution with respect to its input, on the host (kCpp) device.
+//
+// dy : gradient of the output; channel/height/width must match the Recorder's
+//      num_filters_/conv_height_/conv_width_.
+// W  : filters, shape {num_filters, channels, kernel_h, kernel_w}. Reshaped to
+//      2-D while the function runs and restored before returning.
+// x  : forward input; only used as the template for dx (ResetLike).
+// r  : geometry produced by SetupRecorder.
+//
+// Returns dx.
+Tensor CpuConvBackwardx(const Tensor &dy, Tensor &W, const Tensor &x, const Recorder r){
+    CHECK_EQ(dy.device()->lang(), kCpp);
+
+    CHECK(dy.shape(1) == r.num_filters_ && dy.shape(2) == r.conv_height_ &&
+    dy.shape(3) == r.conv_width_) << "input gradients shape should not change";
+
+    CHECK(W.shape(0) == r.num_filters_ && W.shape(1) == r.channels_ &&
+    W.shape(2) == r.kernel_h_ && W.shape(3) == r.kernel_w_) << "weights shape should not change";
+
+    const Shape w_shape = W.shape();
+    W.Reshape(Shape{r.num_filters_, r.col_height_});
+
+    Tensor dx;
+    dx.ResetLike(x);
+
+    // Host staging buffer for one image's gradient. A std::vector releases
+    // the memory on exit (the previous raw new[] was never freed).
+    std::vector<float> dx_b(r.imagesize);
+
+    for (size_t num = 0; num < r.batchsize; num++) {
+      // Slice image `num` out of dy as a {num_filters, conv_h * conv_w} matrix.
+      Tensor grad_b(Shape{r.num_filters_, r.conv_height_ * r.conv_width_});
+      CopyDataToFrom(&grad_b, dy, grad_b.Size(), 0, num * grad_b.Size());
+      Tensor dcol_b = Mult(W.T(), grad_b);
+      auto dcol_data = dcol_b.data<float>();
+      // Fold the column-form gradient back into image layout.
+      C.Col2im(dcol_data, r.channels_, r.height_, r.width_, r.kernel_h_, r.kernel_w_, r.pad_h_,
+               r.pad_w_, r.stride_h_, r.stride_w_, dx_b.data());
+      dx.CopyDataFromHostPtr(dx_b.data(), r.imagesize, num * r.imagesize);
+    }
+
+    W.Reshape(w_shape);
+    return dx;
+}
+
+// Gradient of the convolution with respect to the filters, on the host (kCpp) device.
+//
+// dy : gradient of the output; channel/height/width must match the Recorder's
+//      num_filters_/conv_height_/conv_width_.
+// x  : forward input; channel/height/width must match the Recorder.
+// W  : filters; only used as the template for dW (ResetLike) and for its shape.
+// r  : geometry produced by SetupRecorder.
+//
+// Returns dW with the same shape as W, accumulated over the whole batch.
+Tensor CpuConvBackwardW(const Tensor &dy, const Tensor &x, const Tensor &W, const Recorder r){
+    CHECK_EQ(dy.device()->lang(), kCpp);
+
+    CHECK(dy.shape(1) == r.num_filters_ && dy.shape(2) == r.conv_height_ &&
+    dy.shape(3) == r.conv_width_) << "input gradients shape should not change";
+
+    CHECK(x.shape(1) == r.channels_ && x.shape(2) == r.height_ &&
+    x.shape(3) == r.width_) << "input sample shape should not change";
+
+    Tensor dW;
+    dW.ResetLike(W);
+    dW.SetValue(0.0f);
+
+    const Shape w_shape = W.shape();
+    dW.Reshape(Shape{r.num_filters_, r.col_height_});
+
+    Tensor col_data(Shape{r.col_height_, r.col_width_});  // unfolded image
+
+    // Host staging buffer for Im2col. A std::vector releases the memory on
+    // exit (the previous raw new[] was never freed).
+    std::vector<float> data_col(r.col_height_ * r.col_width_);
+
+    // Im2col must unfold the forward INPUT: it is given the input geometry
+    // (channels_/height_/width_) and advances by r.imagesize, the per-image
+    // size of x. Reading dy here would use the wrong data and the wrong stride.
+    const float *in_data = x.data<float>();
+    for (size_t num = 0; num < r.batchsize; num++) {
+      C.Im2col(in_data + num * r.imagesize, r.channels_, r.height_, r.width_, r.kernel_h_,
+               r.kernel_w_, r.pad_h_, r.pad_w_, r.stride_h_, r.stride_w_, data_col.data());
+      col_data.CopyDataFromHostPtr(data_col.data(), data_col.size());
+      // Slice image `num` out of dy as a {num_filters, conv_h * conv_w} matrix.
+      Tensor grad_b(Shape{r.num_filters_, r.conv_height_ * r.conv_width_});
+      CopyDataToFrom(&grad_b, dy, grad_b.Size(), 0, num * grad_b.Size());
+      dW += Mult(grad_b, col_data.T());
+    }
+
+    dW.Reshape(w_shape);
+    return dW;
+}
+
+// Gradient of the convolution with respect to the bias, on the host (kCpp) device.
+//
+// dy : gradient of the output; channel/height/width must match the Recorder's
+//      num_filters_/conv_height_/conv_width_.
+// b  : bias; checked against num_filters_ and used as the template for db.
+// r  : geometry produced by SetupRecorder.
+//
+// Returns db, obtained by reducing dy in two steps (SumColumns, then SumRows).
+Tensor CpuConvBackwardb(const Tensor &dy, const Tensor &b, const Recorder r){
+    CHECK_EQ(dy.device()->lang(), kCpp);
+
+    CHECK(dy.shape(1) == r.num_filters_ && dy.shape(2) == r.conv_height_ &&
+    dy.shape(3) == r.conv_width_) << "input gradients shape should not change";
+
+    CHECK(b.shape(0) == r.num_filters_)<< "bias shape should not change";
+
+    // View dy as one row per (image, filter) pair.
+    const size_t rows = r.batchsize * r.num_filters_;
+    Tensor dy_2d = Reshape(dy, Shape{rows, dy.Size() / rows});
+
+    // First reduction: one value per (image, filter) pair.
+    Tensor per_pair(Shape{rows});
+    SumColumns(dy_2d, &per_pair);
+
+    // Second reduction over the {batchsize, num_filters} view into db.
+    Tensor db;
+    db.ResetLike(b);
+    SumRows(Reshape(per_pair, Shape{r.batchsize, r.num_filters_}), &db);
+
+    return db;
+}
+
+// Creates and configures the cuDNN descriptors, selects the forward/backward
+// algorithms and allocates the shared workspace for one convolution.
+//
+// input                 : sample batch; its shape must match the Recorder.
+//                         Its device and data type are used for the
+//                         descriptors and the workspace tensor.
+// r                     : geometry produced by SetupRecorder.
+// workspace_byte_limit_ : limit handed to cuDNN's algorithm selection; a
+//                         warning is logged when the chosen algorithms need more.
+// prefer_               : "fastest", "limited_workspace", "no_workspace" or
+//                         "autotune"; anything else is fatal.
+//
+// Returns the populated CudnnConvHandles. bias_desc_ is null when
+// r.bias_term_ is false.
+CudnnConvHandles InitCudnnConvHandles(const Tensor &input, const Recorder r, const size_t workspace_byte_limit_,
+                                      const std::string prefer_){
+
+    CHECK(input.shape(0) == r.batchsize && input.shape(1) == r.channels_ && input.shape(2) == r.height_ &&
+    input.shape(3) == r.width_) << "input sample shape dismatched";
+
+    cudnnTensorDescriptor_t x_desc_ ;
+    cudnnTensorDescriptor_t y_desc_ ;
+    // Only created when a bias is used; start from null so the returned
+    // handle never carries an indeterminate pointer.
+    cudnnTensorDescriptor_t bias_desc_ = nullptr;
+    cudnnFilterDescriptor_t filter_desc_ ;
+    cudnnConvolutionDescriptor_t conv_desc_ ;
+    cudnnConvolutionFwdAlgo_t fp_alg_;
+    cudnnConvolutionBwdFilterAlgo_t bp_filter_alg_;
+    cudnnConvolutionBwdDataAlgo_t bp_data_alg_;
+
+    size_t workspace_count_;
+    Tensor workspace_;
+
+    DataType dtype = input.data_type();
+    auto dev = input.device();
+    Context *ctx = dev->context(0);
+
+    CUDNN_CHECK(cudnnCreateTensorDescriptor(&x_desc_));
+    CUDNN_CHECK(cudnnCreateTensorDescriptor(&y_desc_));
+    if (r.bias_term_)
+        CUDNN_CHECK(cudnnCreateTensorDescriptor(&bias_desc_));
+    CUDNN_CHECK(cudnnCreateFilterDescriptor(&filter_desc_));
+    CUDNN_CHECK(cudnnCreateConvolutionDescriptor(&conv_desc_));
+
+
+    CUDNN_CHECK(cudnnSetTensor4dDescriptor(x_desc_, CUDNN_TENSOR_NCHW,
+                                           GetCudnnDataType(dtype), r.batchsize,
+                                           r.channels_, r.height_, r.width_));
+    CUDNN_CHECK(cudnnSetTensor4dDescriptor(
+            y_desc_, CUDNN_TENSOR_NCHW, GetCudnnDataType(dtype), r.batchsize,
+            r.num_filters_, r.conv_height_, r.conv_width_));
+    if (r.bias_term_)
+        CUDNN_CHECK(cudnnSetTensor4dDescriptor(bias_desc_, CUDNN_TENSOR_NCHW,
+                                               GetCudnnDataType(dtype), 1,
+                                               r.num_filters_, 1, 1));
+    CUDNN_CHECK(cudnnSetConvolution2dDescriptor(conv_desc_, r.pad_h_, r.pad_w_,
+                                                r.stride_h_, r.stride_w_, 1, 1,
+                                                CUDNN_CROSS_CORRELATION,
+                                                GetCudnnDataType(dtype)));
+    CUDNN_CHECK(cudnnSetFilter4dDescriptor(filter_desc_, GetCudnnDataType(dtype),
+                                           CUDNN_TENSOR_NCHW, r.num_filters_,
+                                           r.channels_, r.kernel_h_, r.kernel_w_));
+    if (prefer_ == "fastest" || prefer_ == "limited_workspace" ||
+        prefer_ == "no_workspace") {
+        cudnnConvolutionFwdPreference_t fwd_pref;
+        cudnnConvolutionBwdFilterPreference_t bwd_filt_pref;
+        cudnnConvolutionBwdDataPreference_t bwd_data_pref;
+        if (prefer_ == "fastest") {
+            fwd_pref = CUDNN_CONVOLUTION_FWD_PREFER_FASTEST;
+            bwd_filt_pref = CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST;
+            bwd_data_pref = CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST;
+        } else if (prefer_ == "limited_workspace") {
+            fwd_pref = CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT;
+            bwd_filt_pref = CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT;
+            bwd_data_pref = CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT;
+        } else {
+            // "no_workspace": all three passes use the NO_WORKSPACE preference.
+            fwd_pref = CUDNN_CONVOLUTION_FWD_NO_WORKSPACE;
+            bwd_filt_pref = CUDNN_CONVOLUTION_BWD_FILTER_NO_WORKSPACE;
+            bwd_data_pref = CUDNN_CONVOLUTION_BWD_DATA_NO_WORKSPACE;
+        }
+        CUDNN_CHECK(cudnnGetConvolutionForwardAlgorithm(
+                ctx->cudnn_handle, x_desc_, filter_desc_, conv_desc_, y_desc_, fwd_pref,
+                workspace_byte_limit_, &fp_alg_));
+        CUDNN_CHECK(cudnnGetConvolutionBackwardFilterAlgorithm(
+                ctx->cudnn_handle, x_desc_, y_desc_, conv_desc_, filter_desc_,
+                bwd_filt_pref, workspace_byte_limit_, &bp_filter_alg_));
+        // deprecated in cudnn v7
+        CUDNN_CHECK(cudnnGetConvolutionBackwardDataAlgorithm(
+                ctx->cudnn_handle, filter_desc_, y_desc_, conv_desc_, x_desc_,
+                bwd_data_pref, workspace_byte_limit_, &bp_data_alg_));
+    } else if (prefer_ == "autotune") {
+        // Ask cuDNN for its single best-ranked algorithm per pass.
+        const int topk = 1;
+        int num_fp_alg, num_bp_filt_alg, num_bp_data_alg;
+        cudnnConvolutionFwdAlgoPerf_t fp_alg_perf[topk];
+        cudnnConvolutionBwdFilterAlgoPerf_t bp_filt_perf[topk];
+        cudnnConvolutionBwdDataAlgoPerf_t bp_data_perf[topk];
+        CUDNN_CHECK(cudnnFindConvolutionForwardAlgorithm(
+                ctx->cudnn_handle, x_desc_, filter_desc_, conv_desc_, y_desc_, topk,
+                &num_fp_alg, fp_alg_perf));
+        fp_alg_ = fp_alg_perf[0].algo;
+        CUDNN_CHECK(cudnnFindConvolutionBackwardFilterAlgorithm(
+                ctx->cudnn_handle, x_desc_, y_desc_, conv_desc_, filter_desc_, topk,
+                &num_bp_filt_alg, bp_filt_perf));
+        bp_filter_alg_ = bp_filt_perf[0].algo;
+        CUDNN_CHECK(cudnnFindConvolutionBackwardDataAlgorithm(
+                ctx->cudnn_handle, filter_desc_, y_desc_, conv_desc_, x_desc_, topk,
+                &num_bp_data_alg, bp_data_perf));
+        bp_data_alg_ = bp_data_perf[0].algo;
+    } else {
+        LOG(FATAL) << "Preferred algorithm is not available!";
+    }
+
+    // One workspace serves all three passes, so size it for the largest request.
+    size_t fp_byte, bp_data_byte, bp_filter_byte;
+    CUDNN_CHECK(cudnnGetConvolutionForwardWorkspaceSize(
+            ctx->cudnn_handle, x_desc_, filter_desc_, conv_desc_, y_desc_, fp_alg_,
+            &fp_byte));
+    CUDNN_CHECK(cudnnGetConvolutionBackwardDataWorkspaceSize(
+            ctx->cudnn_handle, filter_desc_, y_desc_, conv_desc_, x_desc_,
+            bp_data_alg_, &bp_data_byte));
+    CUDNN_CHECK(cudnnGetConvolutionBackwardFilterWorkspaceSize(
+            ctx->cudnn_handle, x_desc_, y_desc_, conv_desc_, filter_desc_,
+            bp_filter_alg_, &bp_filter_byte));
+    // Counted in float-sized elements; the +1 keeps the count non-zero.
+    workspace_count_ = std::max(std::max(fp_byte, bp_data_byte), bp_filter_byte) /
+                       sizeof(float) +
+                       1;
+    if (workspace_count_ * sizeof(float) > workspace_byte_limit_)
+        LOG(WARNING) << "The required memory for workspace ("
+                     << workspace_count_ * sizeof(float)
+                     << ") is larger than the expected Bytes ("
+                     << workspace_byte_limit_ << ")";
+    workspace_ = Tensor(Shape{workspace_count_}, dev, dtype);
+
+    // Aggregate initialisation: the order must match the field order of
+    // CudnnConvHandles. The function previously fell off the end without
+    // returning, which is undefined behaviour for a non-void function.
+    return CudnnConvHandles{
+        x_desc_,
+        y_desc_,
+        bias_desc_,
+        filter_desc_,
+        conv_desc_,
+        fp_alg_,
+        bp_filter_alg_,
+        bp_data_alg_,
+
+        workspace_count_,
+        workspace_
+    };
+}
+
+// Convolution forward pass on a CUDA device through cuDNN.
+//
+// x   : input batch (must live on a kCuda device).
+// W   : filters.
+// b   : bias; only read when r.bias_term_ is set.
+// r   : geometry produced by SetupRecorder.
+// cch : cuDNN descriptors, algorithms and workspace for this convolution.
+//
+// Returns a tensor of shape {batchsize, num_filters, conv_height, conv_width}
+// created on x's device with x's data type.
+Tensor GpuConvForward(const Tensor &x, const Tensor &W, const Tensor &b, const Recorder r, const CudnnConvHandles cch){
+    CHECK_EQ(x.device()->lang(), kCuda);
+
+    DataType dtype = x.data_type();
+    auto dev = x.device();
+
+    Shape shape{r.batchsize, r.num_filters_, r.conv_height_, r.conv_width_};
+    Tensor output(shape, dev, dtype);
+
+    // The workspace is scratch memory written by cuDNN, so it is listed among
+    // the written blocks exactly as the backward routines do. Passing it as a
+    // separate trailing argument would not register it as a written block.
+    output.device()->Exec([output, x, W, cch](Context *ctx) {
+        Block *inblock = x.block(), *outblock = output.block(),
+                *wblock = W.block();
+        float alpha = 1.f, beta = 0.f;
+        cudnnConvolutionForward(ctx->cudnn_handle, &alpha, cch.x_desc_,
+                                inblock->data(), cch.filter_desc_, wblock->data(),
+                                cch.conv_desc_, cch.fp_alg_,
+                                cch.workspace_.block()->mutable_data(),
+                                cch.workspace_count_ * sizeof(float), &beta,
+                                cch.y_desc_, outblock->mutable_data());
+    }, {x.block(), W.block()}, {output.block(), cch.workspace_.block()});
+
+    if (r.bias_term_) {
+        // alpha = beta = 1: the bias is added onto the convolution result
+        // already stored in `output`.
+        output.device()->Exec([output, b, cch](Context *ctx) {
+            float beta = 1.f, alpha = 1.0f;
+            Block *outblock = output.block(), *bblock = b.block();
+            cudnnAddTensor(ctx->cudnn_handle, &alpha, cch.bias_desc_,
+                           bblock->data(), &beta, cch.y_desc_,
+                           outblock->mutable_data());
+        }, {output.block(), b.block()}, {output.block()});
+    }
+
+    return output;
+}
+
+// Gradient of the convolution with respect to its input, on a CUDA device through cuDNN.
+//
+// dy  : gradient of the output (must live on a kCuda device).
+// W   : filters.
+// x   : forward input; only used as the template for dx (ResetLike).
+// cch : cuDNN descriptors, algorithms and workspace for this convolution.
+//
+// Returns dx.
+Tensor GpuConvBackwardx(const Tensor &dy, const Tensor &W, const Tensor &x, const CudnnConvHandles cch){
+    CHECK_EQ(dy.device()->lang(), kCuda);
+
+    Tensor dx;
+    dx.ResetLike(x);
+
+    auto backward_data = [dx, dy, W, cch](Context *ctx) {
+        // Scaling factors handed to cuDNN.
+        float alpha = 1.f;
+        float beta = 0.f;
+
+        Block *filters = W.block();
+        Block *grad_out = dy.block();
+        Block *grad_in = dx.block();
+
+        cudnnConvolutionBackwardData(
+                ctx->cudnn_handle, &alpha,
+                cch.filter_desc_, filters->data(),
+                cch.y_desc_, grad_out->data(),
+                cch.conv_desc_, cch.bp_data_alg_,
+                cch.workspace_.block()->mutable_data(),
+                cch.workspace_count_ * sizeof(float),
+                &beta,
+                cch.x_desc_, grad_in->mutable_data());
+    };
+
+    // Reads dy and W; writes dx and the shared workspace.
+    dy.device()->Exec(backward_data,
+                      {dy.block(), W.block()},
+                      {dx.block(), cch.workspace_.block()});
+
+    return dx;
+}
+
+// Gradient of the convolution with respect to the filters, on a CUDA device through cuDNN.
+//
+// dy  : gradient of the output (must live on a kCuda device).
+// x   : forward input.
+// W   : filters; only used as the template for dW (ResetLike).
+// cch : cuDNN descriptors, algorithms and workspace for this convolution.
+//
+// Returns dW.
+Tensor GpuConvBackwardW(const Tensor &dy, const Tensor &x, const Tensor &W, const CudnnConvHandles cch){
+    CHECK_EQ(dy.device()->lang(), kCuda);
+
+    Tensor dW;
+    dW.ResetLike(W);
+
+    auto backward_filter = [dW, dy, x, W, cch](Context *ctx) {
+        // Scaling factors handed to cuDNN.
+        float alpha = 1.f;
+        float beta = 0.f;
+
+        Block *input = x.block();
+        Block *grad_out = dy.block();
+        Block *grad_filters = dW.block();
+
+        cudnnConvolutionBackwardFilter(
+                ctx->cudnn_handle, &alpha,
+                cch.x_desc_, input->data(),
+                cch.y_desc_, grad_out->data(),
+                cch.conv_desc_, cch.bp_filter_alg_,
+                cch.workspace_.block()->mutable_data(),
+                cch.workspace_count_ * sizeof(float),
+                &beta,
+                cch.filter_desc_, grad_filters->mutable_data());
+    };
+
+    // Reads dy and x; writes dW and the shared workspace.
+    dy.device()->Exec(backward_filter,
+                      {dy.block(), x.block()},
+                      {dW.block(), cch.workspace_.block()});
+
+    return dW;
+}
+
+// Gradient of the convolution with respect to the bias, on a CUDA device through cuDNN.
+//
+// dy  : gradient of the output (must live on a kCuda device).
+// b   : bias; only used as the template for db (ResetLike). This parameter
+//       could be dropped later.
+// cch : cuDNN descriptors for this convolution.
+//
+// Returns db.
+Tensor GpuConvBackwardb(const Tensor &dy, const Tensor &b, const CudnnConvHandles cch){
+    CHECK_EQ(dy.device()->lang(), kCuda);
+
+    Tensor db;
+    db.ResetLike(b);
+
+    auto backward_bias = [db, dy, b, cch](Context *ctx) {
+        // Scaling factors handed to cuDNN.
+        float alpha = 1.f;
+        float beta = 0.f;
+
+        Block *grad_out = dy.block();
+        Block *grad_bias = db.block();
+
+        cudnnConvolutionBackwardBias(
+                ctx->cudnn_handle, &alpha,
+                cch.y_desc_, grad_out->data(),
+                &beta,
+                cch.bias_desc_, grad_bias->mutable_data());
+    };
+
+    // Reads dy; writes db.
+    dy.device()->Exec(backward_bias, {dy.block()}, {db.block()});
+
+    return db;
+}
+
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/189958ab/src/model/operation/convolution_related.h
----------------------------------------------------------------------
diff --git a/src/model/operation/convolution_related.h b/src/model/operation/convolution_related.h
new file mode 100644
index 0000000..49aab5b
--- /dev/null
+++ b/src/model/operation/convolution_related.h
@@ -0,0 +1,75 @@
+#include <string>
+#include <vector>
+#include <cudnn.h>
+#include "../layer/cudnn_convolution.h"
+#include "../layer/cudnn_utils.h"
+#include "singa/utils/logging.h"
+
+namespace singa{
+
+// Plain record of the geometry of one 2-D convolution, filled by
+// SetupRecorder. It is built by aggregate initialisation, so the field order
+// below must not change.
+struct Recorder{
+    // Kernel size, padding and stride along the width (index 1 of the
+    // corresponding SetupRecorder vectors) ...
+    size_t kernel_w_;
+    size_t pad_w_;
+    size_t stride_w_;
+    // ... and along the height (index 0).
+    size_t kernel_h_;
+    size_t pad_h_;
+    size_t stride_h_;
+
+    // Number of input channels and of filters (output channels).
+    size_t channels_;
+    size_t num_filters_;
+
+    // Whether a bias is applied.
+    bool bias_term_;
+
+    // Input height/width (input.shape(2)/shape(3)), the output height/width
+    // computed from them, and the batch size (input.shape(0)).
+    size_t height_;
+    size_t width_;
+    size_t conv_height_;
+    size_t conv_width_;
+    size_t batchsize;
+
+    // col_height_ = channels * kernel_w * kernel_h and
+    // col_width_ = conv_height * conv_width: the size of the unfolded (im2col)
+    // matrix of one image. imagesize = input.Size() / batchsize.
+    size_t col_height_;
+    size_t col_width_;
+    size_t imagesize;
+};
+
+// cuDNN state for one convolution: descriptors, the selected algorithms and
+// the workspace used by the Gpu* routines. It is the declared return type of
+// InitCudnnConvHandles.
+struct CudnnConvHandles{
+    // Descriptors for the input (x), the output (y), the bias, the filter and
+    // the convolution itself.
+	cudnnTensorDescriptor_t x_desc_ ;
+    cudnnTensorDescriptor_t y_desc_ ;
+    cudnnTensorDescriptor_t bias_desc_ ;
+    cudnnFilterDescriptor_t filter_desc_ ;
+    cudnnConvolutionDescriptor_t conv_desc_ ;
+    // Algorithms for the forward, backward-filter and backward-data passes.
+    cudnnConvolutionFwdAlgo_t fp_alg_;
+    cudnnConvolutionBwdFilterAlgo_t bp_filter_alg_;
+    cudnnConvolutionBwdDataAlgo_t bp_data_alg_;
+
+    // Workspace length in float-sized elements (callers multiply by
+    // sizeof(float) to get bytes) and the tensor that backs it.
+    size_t workspace_count_;
+    Tensor workspace_;  
+};
+
+
+// Derives the convolution geometry from `input` and the kernel/stride/padding
+// settings. The three vectors are read as {h, w}.
+Recorder SetupRecorder(const Tensor &input, const std::vector<size_t> kernel_size, 
+	                const std::vector<size_t> stride, const std::vector<size_t> padding,
+	                const size_t in_channels, const size_t out_channels,
+	                const bool bias_term_);
+
+// Sets up the cuDNN descriptors, algorithms and workspace for `input` and `r`.
+// prefer_ is one of "fastest", "limited_workspace", "no_workspace", "autotune".
+CudnnConvHandles InitCudnnConvHandles(const Tensor &input, const Recorder r, const size_t workspace_byte_limit_=1024*1024*1024,
+    				const std::string prefer_="fastest");
+
+// cuDNN implementations; each checks that its first tensor is on a kCuda device.
+// Forward pass; b is only used when r.bias_term_ is set.
+Tensor GpuConvForward(const Tensor &x, const Tensor &W, const Tensor &b, const Recorder r, const CudnnConvHandles cch);
+
+// Gradient w.r.t. the input x.
+Tensor GpuConvBackwardx(const Tensor &dy, const Tensor &W, const Tensor &x, const CudnnConvHandles cch);
+
+// Gradient w.r.t. the filters W.
+Tensor GpuConvBackwardW(const Tensor &dy, const Tensor &x, const Tensor &W, const CudnnConvHandles cch);
+
+// Gradient w.r.t. the bias b.
+Tensor GpuConvBackwardb(const Tensor &dy, const Tensor &b, const CudnnConvHandles cch);
+
+
+
+// Host implementations; each checks that its first tensor is on a kCpp device.
+// Forward pass; W and b are reshaped internally and restored before returning.
+Tensor CpuConvForward(const Tensor &x, Tensor &W,  Tensor &b, const Recorder r);
+
+// Gradient w.r.t. the input x; W is reshaped internally and restored before returning.
+Tensor CpuConvBackwardx(const Tensor &dy, Tensor &W, const Tensor &x, const Recorder r);
+
+// Gradient w.r.t. the filters W.
+Tensor CpuConvBackwardW(const Tensor &dy, const Tensor &x, const Tensor &W, const Recorder r);
+
+// Gradient w.r.t. the bias b.
+Tensor CpuConvBackwardb(const Tensor &dy, const Tensor &b, const Recorder r);
+
+}
\ No newline at end of file