You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@singa.apache.org by wa...@apache.org on 2018/07/05 03:10:12 UTC
[17/18] incubator-singa git commit: SINGA-371 Implement functional
operations in c++ for autograd
SINGA-371 Implement functional operations in c++ for autograd
- tidy up the code and rename member variables (drop the trailing underscore, e.g. bias_term_ -> bias_term)
Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/ac5f4eb2
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/ac5f4eb2
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/ac5f4eb2
Branch: refs/heads/master
Commit: ac5f4eb2a245f7515f01321b4fe259fc4f58146c
Parents: 4a45ee6
Author: xuewanqi <xu...@outlook.com>
Authored: Thu Jul 5 03:03:03 2018 +0000
Committer: xuewanqi <xu...@outlook.com>
Committed: Thu Jul 5 03:03:03 2018 +0000
----------------------------------------------------------------------
python/singa/autograd.py | 6 +-
src/api/model_operation.i | 28 ++--
src/model/operation/convolution.cc | 280 ++++++++++++++++----------------
src/model/operation/convolution.h | 62 +++----
4 files changed, 187 insertions(+), 189 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/ac5f4eb2/python/singa/autograd.py
----------------------------------------------------------------------
diff --git a/python/singa/autograd.py b/python/singa/autograd.py
index b05f701..80209ff 100755
--- a/python/singa/autograd.py
+++ b/python/singa/autograd.py
@@ -463,7 +463,7 @@ class _Conv2D(Operation):
#assert 0 == 0, 'invalid padding'
if training:
- if self.handle.bias_term_:
+ if self.handle.bias_term:
self.inputs = (x, W, b)
else:
self.inputs = (x, W)
@@ -486,7 +486,7 @@ class _Conv2D(Operation):
dy, self.inputs[1], self.inputs[0], self.handle)
dW = singa.CpuConvBackwardW(
dy, self.inputs[0], self.inputs[1], self.handle)
- if self.handle.bias_term_:
+ if self.handle.bias_term:
db = singa.CpuConvBackwardb(dy, self.inputs[2], self.handle)
return dx, dW, db
else:
@@ -496,7 +496,7 @@ class _Conv2D(Operation):
dy, self.inputs[1], self.inputs[0], self.handle)
dW = singa.GpuConvBackwardW(
dy, self.inputs[0], self.inputs[1], self.handle)
- if self.handle.bias_term_:
+ if self.handle.bias_term:
db = singa.GpuConvBackwardb(dy, self.inputs[2], self.handle)
return dx, dW, db
else:
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/ac5f4eb2/src/api/model_operation.i
----------------------------------------------------------------------
diff --git a/src/api/model_operation.i b/src/api/model_operation.i
index 26f5c69..2c13a3b 100755
--- a/src/api/model_operation.i
+++ b/src/api/model_operation.i
@@ -5,28 +5,26 @@
%}
namespace singa{
-struct ConvHandle{
-
- size_t batchsize;
- const bool bias_term_;
-
- ConvHandle(const Tensor &input, const std::vector<size_t>& kernel_size,
+class ConvHandle {
+ public:
+ ConvHandle(const Tensor &input, const std::vector<size_t>& kernel_size,
const std::vector<size_t>& stride, const std::vector<size_t>& padding,
const size_t in_channels, const size_t out_channels,
const bool bias);
- };
+ bool bias_term;
+ size_t batchsize;
+};
struct CudnnConvHandle{
-
- size_t batchsize;
- const bool bias_term_;
-
- CudnnConvHandle(const Tensor &input, const std::vector<size_t>& kernel_size,
+ public:
+ CudnnConvHandle(const Tensor &input, const std::vector<size_t>& kernel_size,
const std::vector<size_t>& stride, const std::vector<size_t>& padding,
const size_t in_channels, const size_t out_channels,
- const bool bias, const size_t workspace_byte_limit_ = 1024 * 1024 * 1024,
- const std::string& prefer_ = "fastest");
- };
+ const bool bias, const size_t workspace_byte_limit = 1024 * 1024 * 1024,
+ const std::string& prefer = "fastest");
+ bool bias_term;
+ size_t batchsize;
+};
Tensor GpuConvForward(const Tensor &x, const Tensor &W, const Tensor &b, const CudnnConvHandle &cch);
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/ac5f4eb2/src/model/operation/convolution.cc
----------------------------------------------------------------------
diff --git a/src/model/operation/convolution.cc b/src/model/operation/convolution.cc
index 9a702fa..e36df43 100755
--- a/src/model/operation/convolution.cc
+++ b/src/model/operation/convolution.cc
@@ -8,32 +8,32 @@ ConvHandle::ConvHandle(const Tensor &input, const std::vector<size_t>& kernel_si
const std::vector<size_t>& stride, const std::vector<size_t>& padding,
const size_t in_channels, const size_t out_channels,
const bool bias) {
- kernel_h_ = kernel_size[0];
- kernel_w_ = kernel_size[1];
+ kernel_h = kernel_size[0];
+ kernel_w = kernel_size[1];
- pad_h_ = padding[0];
- pad_w_ = padding[1];
+ pad_h = padding[0];
+ pad_w = padding[1];
- stride_h_ = stride[0];
- stride_w_ = stride[1];
+ stride_h = stride[0];
+ stride_w = stride[1];
- channels_ = in_channels;
- num_filters_ = out_channels;
+ channels = in_channels;
+ num_filters = out_channels;
- bias_term_ = bias;
+ bias_term = bias;
batchsize = input.shape(0);
CHECK(input.shape(1) == in_channels) << "the number of input channels mismatched.";
- height_ = input.shape(2);
- width_ = input.shape(3);
+ height = input.shape(2);
+ width = input.shape(3);
- conv_height_ = 1;
- if (stride_h_ > 0)
- conv_height_ = (height_ + 2 * pad_h_ - kernel_h_) / stride_h_ + 1;
- conv_width_ = (width_ + 2 * pad_w_ - kernel_w_) / stride_w_ + 1;
+ conv_height = 1;
+ if (stride_h > 0)
+ conv_height = (height + 2 * pad_h - kernel_h) / stride_h + 1;
+ conv_width = (width + 2 * pad_w - kernel_w) / stride_w + 1;
- col_height_ = in_channels * kernel_w_ * kernel_h_;
- col_width_ = conv_height_ * conv_width_;
+ col_height = in_channels * kernel_w * kernel_h;
+ col_width = conv_height * conv_width;
imagesize = input.Size() / batchsize;
}
@@ -42,43 +42,43 @@ ConvHandle::ConvHandle(const Tensor &input, const std::vector<size_t>& kernel_si
Tensor CpuConvForward(const Tensor &x, Tensor &W, Tensor &b, const ConvHandle &ch) {
CHECK_EQ(x.device()->lang(), kCpp);
- CHECK(x.shape(1) == ch.channels_ && x.shape(2) == ch.height_ &&
- x.shape(3) == ch.width_) << "input sample shape should not change";
+ CHECK(x.shape(1) == ch.channels && x.shape(2) == ch.height &&
+ x.shape(3) == ch.width) << "input sample shape should not change";
- CHECK(W.shape(0) == ch.num_filters_ && W.shape(1) == ch.channels_ &&
- W.shape(2) == ch.kernel_h_ && W.shape(3) == ch.kernel_w_) << "weights shape should not change";
+ CHECK(W.shape(0) == ch.num_filters && W.shape(1) == ch.channels &&
+ W.shape(2) == ch.kernel_h && W.shape(3) == ch.kernel_w) << "weights shape should not change";
Shape w_shape = W.shape();
Shape b_shape;
- if (ch.bias_term_)
+ if (ch.bias_term)
b_shape = b.shape();
- W.Reshape(Shape{ch.num_filters_, ch.col_height_});
- if (ch.bias_term_)
- b.Reshape(Shape{ch.num_filters_});
+ W.Reshape(Shape{ch.num_filters, ch.col_height});
+ if (ch.bias_term)
+ b.Reshape(Shape{ch.num_filters});
DataType dtype = x.data_type();
auto dev = x.device();
- Shape shape{ch.batchsize, ch.num_filters_, ch.conv_height_, ch.conv_width_};
+ Shape shape{ch.batchsize, ch.num_filters, ch.conv_height, ch.conv_width};
Tensor output(shape, dev, dtype);
- Tensor col_data(Shape{ch.col_height_, ch.col_width_});//broadcasted image
+ Tensor col_data(Shape{ch.col_height, ch.col_width});//broadcasted image
- float *data_col = new float[ch.col_height_ * ch.col_width_];
+ float *data_col = new float[ch.col_height * ch.col_width];
auto in_data = x.data<float>();
for (size_t num = 0; num < ch.batchsize; num++) {
- Im2col(in_data + num * ch.imagesize, ch.channels_, ch.height_, ch.width_, ch.kernel_h_,
- ch.kernel_w_, ch.pad_h_, ch.pad_w_, ch.stride_h_, ch.stride_w_, data_col);
+ Im2col(in_data + num * ch.imagesize, ch.channels, ch.height, ch.width, ch.kernel_h,
+ ch.kernel_w, ch.pad_h, ch.pad_w, ch.stride_h, ch.stride_w, data_col);
- col_data.CopyDataFromHostPtr(data_col, ch.col_height_ * ch.col_width_);
+ col_data.CopyDataFromHostPtr(data_col, ch.col_height * ch.col_width);
Tensor each = Mult(W, col_data);
- if (ch.bias_term_) {
+ if (ch.bias_term) {
AddColumn(b, &each);
}
CopyDataToFrom(&output, each, each.Size(), num * each.Size());
};
W.Reshape(w_shape);
- if (ch.bias_term_)
+ if (ch.bias_term)
b.Reshape(b_shape);
return output;
}
@@ -86,14 +86,14 @@ Tensor CpuConvForward(const Tensor &x, Tensor &W, Tensor &b, const ConvHandle &
Tensor CpuConvBackwardx(const Tensor &dy, Tensor &W, const Tensor &x, const ConvHandle &ch) {
CHECK_EQ(dy.device()->lang(), kCpp);
- CHECK(dy.shape(1) == ch.num_filters_ && dy.shape(2) == ch.conv_height_ &&
- dy.shape(3) == ch.conv_width_) << "input gradients shape should not change";
+ CHECK(dy.shape(1) == ch.num_filters && dy.shape(2) == ch.conv_height &&
+ dy.shape(3) == ch.conv_width) << "input gradients shape should not change";
- CHECK(W.shape(0) == ch.num_filters_ && W.shape(1) == ch.channels_ &&
- W.shape(2) == ch.kernel_h_ && W.shape(3) == ch.kernel_w_) << "weights shape should not change";
+ CHECK(W.shape(0) == ch.num_filters && W.shape(1) == ch.channels &&
+ W.shape(2) == ch.kernel_h && W.shape(3) == ch.kernel_w) << "weights shape should not change";
Shape w_shape = W.shape();
- W.Reshape(Shape{ch.num_filters_, ch.col_height_});
+ W.Reshape(Shape{ch.num_filters, ch.col_height});
Tensor dx;
dx.ResetLike(x);
@@ -101,12 +101,12 @@ Tensor CpuConvBackwardx(const Tensor &dy, Tensor &W, const Tensor &x, const Conv
float *dx_b = new float[ch.imagesize];
for (size_t num = 0; num < ch.batchsize; num++) {
- Tensor grad_b(Shape{ch.num_filters_, ch.conv_height_ * ch.conv_width_});
+ Tensor grad_b(Shape{ch.num_filters, ch.conv_height * ch.conv_width});
CopyDataToFrom(&grad_b, dy, grad_b.Size(), 0, num * grad_b.Size());
Tensor dcol_b = Mult(W.T(), grad_b);
auto dcol_data = dcol_b.data<float>();
- Col2im(dcol_data, ch.channels_, ch.height_, ch.width_, ch.kernel_h_, ch.kernel_w_, ch.pad_h_,
- ch.pad_w_, ch.stride_h_, ch.stride_w_, dx_b);
+ Col2im(dcol_data, ch.channels, ch.height, ch.width, ch.kernel_h, ch.kernel_w, ch.pad_h,
+ ch.pad_w, ch.stride_h, ch.stride_w, dx_b);
dx.CopyDataFromHostPtr(dx_b, ch.imagesize, num * ch.imagesize);
}
W.Reshape(w_shape);
@@ -116,28 +116,28 @@ Tensor CpuConvBackwardx(const Tensor &dy, Tensor &W, const Tensor &x, const Conv
Tensor CpuConvBackwardW(const Tensor &dy, const Tensor &x, const Tensor &W, const ConvHandle &ch) {
CHECK_EQ(dy.device()->lang(), kCpp);
- CHECK(dy.shape(1) == ch.num_filters_ && dy.shape(2) == ch.conv_height_ &&
- dy.shape(3) == ch.conv_width_) << "input gradients shape should not change";
+ CHECK(dy.shape(1) == ch.num_filters && dy.shape(2) == ch.conv_height &&
+ dy.shape(3) == ch.conv_width) << "input gradients shape should not change";
- CHECK(x.shape(1) == ch.channels_ && x.shape(2) == ch.height_ &&
- x.shape(3) == ch.width_) << "input sample shape should not change";
+ CHECK(x.shape(1) == ch.channels && x.shape(2) == ch.height &&
+ x.shape(3) == ch.width) << "input sample shape should not change";
Tensor dW;
dW.ResetLike(W);
dW.SetValue(0.0f);
Shape w_shape = W.shape();
- dW.Reshape(Shape{ch.num_filters_, ch.col_height_});
+ dW.Reshape(Shape{ch.num_filters, ch.col_height});
- Tensor col_data(Shape{ch.col_height_, ch.col_width_});//broadcasted image
+ Tensor col_data(Shape{ch.col_height, ch.col_width});//broadcasted image
- float *data_col = new float[ch.col_height_ * ch.col_width_];
+ float *data_col = new float[ch.col_height * ch.col_width];
auto in_data = dy.data<float>();
for (size_t num = 0; num < ch.batchsize; num++) {
- Im2col(in_data + num * ch.imagesize, ch.channels_, ch.height_, ch.width_, ch.kernel_h_,
- ch.kernel_w_, ch.pad_h_, ch.pad_w_, ch.stride_h_, ch.stride_w_, data_col);
- col_data.CopyDataFromHostPtr(data_col, ch.col_height_ * ch.col_width_);
- Tensor grad_b(Shape{ch.num_filters_, ch.conv_height_ * ch.conv_width_});
+ Im2col(in_data + num * ch.imagesize, ch.channels, ch.height, ch.width, ch.kernel_h,
+ ch.kernel_w, ch.pad_h, ch.pad_w, ch.stride_h, ch.stride_w, data_col);
+ col_data.CopyDataFromHostPtr(data_col, ch.col_height * ch.col_width);
+ Tensor grad_b(Shape{ch.num_filters, ch.conv_height * ch.conv_width});
CopyDataToFrom(&grad_b, dy, grad_b.Size(), 0, num * grad_b.Size());
dW += Mult(grad_b, col_data.T());
}
@@ -148,20 +148,20 @@ Tensor CpuConvBackwardW(const Tensor &dy, const Tensor &x, const Tensor &W, cons
Tensor CpuConvBackwardb(const Tensor &dy, const Tensor &b, const ConvHandle &ch) {
CHECK_EQ(dy.device()->lang(), kCpp);
- CHECK(dy.shape(1) == ch.num_filters_ && dy.shape(2) == ch.conv_height_ &&
- dy.shape(3) == ch.conv_width_) << "input gradients shape should not change";
+ CHECK(dy.shape(1) == ch.num_filters && dy.shape(2) == ch.conv_height &&
+ dy.shape(3) == ch.conv_width) << "input gradients shape should not change";
- CHECK(b.shape(0) == ch.num_filters_) << "bias shape should not change";
+ CHECK(b.shape(0) == ch.num_filters) << "bias shape should not change";
Tensor db;
db.ResetLike(b);
- auto tmpshp = Shape{ch.batchsize * ch.num_filters_, dy.Size() / (ch.batchsize * ch.num_filters_)};
+ auto tmpshp = Shape{ch.batchsize * ch.num_filters, dy.Size() / (ch.batchsize * ch.num_filters)};
Tensor tmp1 = Reshape(dy, tmpshp);
- Tensor tmp2(Shape{ch.batchsize * ch.num_filters_});
+ Tensor tmp2(Shape{ch.batchsize * ch.num_filters});
SumColumns(tmp1, &tmp2);
- Tensor tmp3 = Reshape(tmp2, Shape{ch.batchsize, ch.num_filters_});
+ Tensor tmp3 = Reshape(tmp2, Shape{ch.batchsize, ch.num_filters});
SumRows(tmp3, &db);
@@ -172,48 +172,48 @@ Tensor CpuConvBackwardb(const Tensor &dy, const Tensor &b, const ConvHandle &ch)
CudnnConvHandle::CudnnConvHandle(const Tensor &input, const std::vector<size_t>& kernel_size,
const std::vector<size_t>& stride, const std::vector<size_t>& padding,
const size_t in_channels, const size_t out_channels, const bool bias,
- const size_t workspace_byte_limit_, const std::string& prefer_)
+ const size_t workspace_byte_limit, const std::string& prefer)
: ConvHandle(input, kernel_size, stride, padding, in_channels, out_channels, bias) {
DataType dtype = input.data_type();
auto dev = input.device();
Context *ctx = dev->context(0);
- CUDNN_CHECK(cudnnCreateTensorDescriptor(&x_desc_));
- CUDNN_CHECK(cudnnCreateTensorDescriptor(&y_desc_));
- if (bias_term_)
- CUDNN_CHECK(cudnnCreateTensorDescriptor(&bias_desc_));
- CUDNN_CHECK(cudnnCreateFilterDescriptor(&filter_desc_));
- CUDNN_CHECK(cudnnCreateConvolutionDescriptor(&conv_desc_));
+ CUDNN_CHECK(cudnnCreateTensorDescriptor(&x_desc));
+ CUDNN_CHECK(cudnnCreateTensorDescriptor(&y_desc));
+ if (bias_term)
+ CUDNN_CHECK(cudnnCreateTensorDescriptor(&bias_desc));
+ CUDNN_CHECK(cudnnCreateFilterDescriptor(&filter_desc));
+ CUDNN_CHECK(cudnnCreateConvolutionDescriptor(&conv_desc));
- CUDNN_CHECK(cudnnSetTensor4dDescriptor(x_desc_, CUDNN_TENSOR_NCHW,
+ CUDNN_CHECK(cudnnSetTensor4dDescriptor(x_desc, CUDNN_TENSOR_NCHW,
GetCudnnDataType(dtype), batchsize,
- channels_, height_, width_));
+ channels, height, width));
CUDNN_CHECK(cudnnSetTensor4dDescriptor(
- y_desc_, CUDNN_TENSOR_NCHW, GetCudnnDataType(dtype), batchsize,
- num_filters_, conv_height_, conv_width_));
- if (bias_term_)
- CUDNN_CHECK(cudnnSetTensor4dDescriptor(bias_desc_, CUDNN_TENSOR_NCHW,
+ y_desc, CUDNN_TENSOR_NCHW, GetCudnnDataType(dtype), batchsize,
+ num_filters, conv_height, conv_width));
+ if (bias_term)
+ CUDNN_CHECK(cudnnSetTensor4dDescriptor(bias_desc, CUDNN_TENSOR_NCHW,
GetCudnnDataType(dtype), 1,
- num_filters_, 1, 1));
- CUDNN_CHECK(cudnnSetConvolution2dDescriptor(conv_desc_, pad_h_, pad_w_,
- stride_h_, stride_w_, 1, 1,
+ num_filters, 1, 1));
+ CUDNN_CHECK(cudnnSetConvolution2dDescriptor(conv_desc, pad_h, pad_w,
+ stride_h, stride_w, 1, 1,
CUDNN_CROSS_CORRELATION,
GetCudnnDataType(dtype)));
- CUDNN_CHECK(cudnnSetFilter4dDescriptor(filter_desc_, GetCudnnDataType(dtype),
- CUDNN_TENSOR_NCHW, num_filters_,
- channels_, kernel_h_, kernel_w_));
- if (prefer_ == "fastest" || prefer_ == "limited_workspace" ||
- prefer_ == "no_workspace") {
+ CUDNN_CHECK(cudnnSetFilter4dDescriptor(filter_desc, GetCudnnDataType(dtype),
+ CUDNN_TENSOR_NCHW, num_filters,
+ channels, kernel_h, kernel_w));
+ if (prefer == "fastest" || prefer == "limited_workspace" ||
+ prefer == "no_workspace") {
cudnnConvolutionFwdPreference_t fwd_pref;
cudnnConvolutionBwdFilterPreference_t bwd_filt_pref;
cudnnConvolutionBwdDataPreference_t bwd_data_pref;
- if (prefer_ == "fastest") {
+ if (prefer == "fastest") {
fwd_pref = CUDNN_CONVOLUTION_FWD_PREFER_FASTEST;
bwd_filt_pref = CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST;
bwd_data_pref = CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST;
- } else if (prefer_ == "limited_workspace") {
+ } else if (prefer == "limited_workspace") {
fwd_pref = CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT;
bwd_filt_pref = CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT;
bwd_data_pref = CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT;
@@ -223,67 +223,67 @@ CudnnConvHandle::CudnnConvHandle(const Tensor &input, const std::vector<size_t>&
bwd_data_pref = CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT;
}
CUDNN_CHECK(cudnnGetConvolutionForwardAlgorithm(
- ctx->cudnn_handle, x_desc_, filter_desc_, conv_desc_, y_desc_, fwd_pref,
- workspace_byte_limit_, &fp_alg_));
+ ctx->cudnn_handle, x_desc, filter_desc, conv_desc, y_desc, fwd_pref,
+ workspace_byte_limit, &fp_alg));
CUDNN_CHECK(cudnnGetConvolutionBackwardFilterAlgorithm(
- ctx->cudnn_handle, x_desc_, y_desc_, conv_desc_, filter_desc_,
- bwd_filt_pref, workspace_byte_limit_, &bp_filter_alg_));
+ ctx->cudnn_handle, x_desc, y_desc, conv_desc, filter_desc,
+ bwd_filt_pref, workspace_byte_limit, &bp_filter_alg));
// deprecated in cudnn v7
CUDNN_CHECK(cudnnGetConvolutionBackwardDataAlgorithm(
- ctx->cudnn_handle, filter_desc_, y_desc_, conv_desc_, x_desc_,
- bwd_data_pref, workspace_byte_limit_, &bp_data_alg_));
- } else if (prefer_ == "autotune") {
+ ctx->cudnn_handle, filter_desc, y_desc, conv_desc, x_desc,
+ bwd_data_pref, workspace_byte_limit, &bp_data_alg));
+ } else if (prefer == "autotune") {
const int topk = 1;
int num_fp_alg, num_bp_filt_alg, num_bp_data_alg;
- cudnnConvolutionFwdAlgoPerf_t fp_alg_perf[topk];
+ cudnnConvolutionFwdAlgoPerf_t fp_algperf[topk];
cudnnConvolutionBwdFilterAlgoPerf_t bp_filt_perf[topk];
cudnnConvolutionBwdDataAlgoPerf_t bp_data_perf[topk];
CUDNN_CHECK(cudnnFindConvolutionForwardAlgorithm(
- ctx->cudnn_handle, x_desc_, filter_desc_, conv_desc_, y_desc_, topk,
- &num_fp_alg, fp_alg_perf));
- fp_alg_ = fp_alg_perf[0].algo;
+ ctx->cudnn_handle, x_desc, filter_desc, conv_desc, y_desc, topk,
+ &num_fp_alg, fp_algperf));
+ fp_alg = fp_algperf[0].algo;
CUDNN_CHECK(cudnnFindConvolutionBackwardFilterAlgorithm(
- ctx->cudnn_handle, x_desc_, y_desc_, conv_desc_, filter_desc_, topk,
+ ctx->cudnn_handle, x_desc, y_desc, conv_desc, filter_desc, topk,
&num_bp_filt_alg, bp_filt_perf));
- bp_filter_alg_ = bp_filt_perf[0].algo;
+ bp_filter_alg = bp_filt_perf[0].algo;
CUDNN_CHECK(cudnnFindConvolutionBackwardDataAlgorithm(
- ctx->cudnn_handle, filter_desc_, y_desc_, conv_desc_, x_desc_, topk,
+ ctx->cudnn_handle, filter_desc, y_desc, conv_desc, x_desc, topk,
&num_bp_data_alg, bp_data_perf));
- bp_data_alg_ = bp_data_perf[0].algo;
+ bp_data_alg = bp_data_perf[0].algo;
} else {
LOG(FATAL) << "Preferred algorithm is not available!";
}
size_t fp_byte, bp_data_byte, bp_filter_byte;
CUDNN_CHECK(cudnnGetConvolutionForwardWorkspaceSize(
- ctx->cudnn_handle, x_desc_, filter_desc_, conv_desc_, y_desc_, fp_alg_,
+ ctx->cudnn_handle, x_desc, filter_desc, conv_desc, y_desc, fp_alg,
&fp_byte));
CUDNN_CHECK(cudnnGetConvolutionBackwardDataWorkspaceSize(
- ctx->cudnn_handle, filter_desc_, y_desc_, conv_desc_, x_desc_,
- bp_data_alg_, &bp_data_byte));
+ ctx->cudnn_handle, filter_desc, y_desc, conv_desc, x_desc,
+ bp_data_alg, &bp_data_byte));
CUDNN_CHECK(cudnnGetConvolutionBackwardFilterWorkspaceSize(
- ctx->cudnn_handle, x_desc_, y_desc_, conv_desc_, filter_desc_,
- bp_filter_alg_, &bp_filter_byte));
- workspace_count_ = std::max(std::max(fp_byte, bp_data_byte), bp_filter_byte) /
+ ctx->cudnn_handle, x_desc, y_desc, conv_desc, filter_desc,
+ bp_filter_alg, &bp_filter_byte));
+ workspace_count = std::max(std::max(fp_byte, bp_data_byte), bp_filter_byte) /
sizeof(float) +
1;
- if (workspace_count_ * sizeof(float) > workspace_byte_limit_)
+ if (workspace_count * sizeof(float) > workspace_byte_limit)
LOG(WARNING) << "The required memory for workspace ("
- << workspace_count_ * sizeof(float)
+ << workspace_count * sizeof(float)
<< ") is larger than the expected Bytes ("
- << workspace_byte_limit_ << ")";
- workspace_ = Tensor(Shape{workspace_count_}, dev, dtype);
+ << workspace_byte_limit << ")";
+ workspace = Tensor(Shape{workspace_count}, dev, dtype);
}
CudnnConvHandle::~CudnnConvHandle() {
- if (bias_desc_ != nullptr)
- CUDNN_CHECK(cudnnDestroyTensorDescriptor(bias_desc_));
- if (filter_desc_ != nullptr)
- CUDNN_CHECK(cudnnDestroyFilterDescriptor(filter_desc_));
- if (conv_desc_ != nullptr)
- CUDNN_CHECK(cudnnDestroyConvolutionDescriptor(conv_desc_));
- if (x_desc_ != nullptr) CUDNN_CHECK(cudnnDestroyTensorDescriptor(x_desc_));
- if (y_desc_ != nullptr) CUDNN_CHECK(cudnnDestroyTensorDescriptor(y_desc_));
+ if (bias_desc != nullptr)
+ CUDNN_CHECK(cudnnDestroyTensorDescriptor(bias_desc));
+ if (filter_desc != nullptr)
+ CUDNN_CHECK(cudnnDestroyFilterDescriptor(filter_desc));
+ if (conv_desc != nullptr)
+ CUDNN_CHECK(cudnnDestroyConvolutionDescriptor(conv_desc));
+ if (x_desc != nullptr) CUDNN_CHECK(cudnnDestroyTensorDescriptor(x_desc));
+ if (y_desc != nullptr) CUDNN_CHECK(cudnnDestroyTensorDescriptor(y_desc));
}
Tensor GpuConvForward(const Tensor &x, const Tensor &W, const Tensor &b, const CudnnConvHandle &cch) {
@@ -292,27 +292,27 @@ Tensor GpuConvForward(const Tensor &x, const Tensor &W, const Tensor &b, const C
DataType dtype = x.data_type();
auto dev = x.device();
- Shape shape{cch.batchsize, cch.num_filters_, cch.conv_height_, cch.conv_width_};
+ Shape shape{cch.batchsize, cch.num_filters, cch.conv_height, cch.conv_width};
Tensor output(shape, dev, dtype);
output.device()->Exec([&output, &x, &W, &cch](Context * ctx) {
Block *inblock = x.block(), *outblock = output.block(),
*wblock = W.block();
float alpha = 1.f, beta = 0.f;
- cudnnConvolutionForward(ctx->cudnn_handle, &alpha, cch.x_desc_,
- inblock->data(), cch.filter_desc_, wblock->data(),
- cch.conv_desc_, cch.fp_alg_,
- cch.workspace_.block()->mutable_data(),
- cch.workspace_count_ * sizeof(float), &beta,
- cch.y_desc_, outblock->mutable_data());
- }, {x.block(), W.block()}, {output.block()}, cch.workspace_.block());
-
- if (cch.bias_term_) {
+ cudnnConvolutionForward(ctx->cudnn_handle, &alpha, cch.x_desc,
+ inblock->data(), cch.filter_desc, wblock->data(),
+ cch.conv_desc, cch.fp_alg,
+ cch.workspace.block()->mutable_data(),
+ cch.workspace_count * sizeof(float), &beta,
+ cch.y_desc, outblock->mutable_data());
+ }, {x.block(), W.block()}, {output.block()}, cch.workspace.block());
+
+ if (cch.bias_term) {
output.device()->Exec([&output, &b, &cch](Context * ctx) {
float beta = 1.f, alpha = 1.0f;
Block *outblock = output.block(), *bblock = b.block();
- cudnnAddTensor(ctx->cudnn_handle, &alpha, cch.bias_desc_,
- bblock->data(), &beta, cch.y_desc_,
+ cudnnAddTensor(ctx->cudnn_handle, &alpha, cch.bias_desc,
+ bblock->data(), &beta, cch.y_desc,
outblock->mutable_data());
}, {output.block(), b.block()}, {output.block()});
}
@@ -330,13 +330,13 @@ Tensor GpuConvBackwardx(const Tensor &dy, const Tensor &W, const Tensor &x, cons
Block *wblock = W.block(), *dyblock = dy.block(),
*dxblock = dx.block();
float alpha = 1.f, beta = 0.f;
- cudnnConvolutionBackwardData(ctx->cudnn_handle, &alpha, cch.filter_desc_,
- wblock->data(), cch.y_desc_, dyblock->data(),
- cch.conv_desc_, cch.bp_data_alg_,
- cch.workspace_.block()->mutable_data(),
- cch.workspace_count_ * sizeof(float), &beta,
- cch.x_desc_, dxblock->mutable_data());
- }, {dy.block(), W.block()}, {dx.block(), cch.workspace_.block()});
+ cudnnConvolutionBackwardData(ctx->cudnn_handle, &alpha, cch.filter_desc,
+ wblock->data(), cch.y_desc, dyblock->data(),
+ cch.conv_desc, cch.bp_data_alg,
+ cch.workspace.block()->mutable_data(),
+ cch.workspace_count * sizeof(float), &beta,
+ cch.x_desc, dxblock->mutable_data());
+ }, {dy.block(), W.block()}, {dx.block(), cch.workspace.block()});
return dx;
}
@@ -352,12 +352,12 @@ Tensor GpuConvBackwardW(const Tensor &dy, const Tensor &x, const Tensor &W, cons
*dwblock = dW.block();
float alpha = 1.f, beta = 0.f;
cudnnConvolutionBackwardFilter(
- ctx->cudnn_handle, &alpha, cch.x_desc_, inblock->data(),
- cch.y_desc_, dyblock->data(), cch.conv_desc_, cch.bp_filter_alg_,
- cch.workspace_.block()->mutable_data(),
- cch.workspace_count_ * sizeof(float), &beta, cch.filter_desc_,
+ ctx->cudnn_handle, &alpha, cch.x_desc, inblock->data(),
+ cch.y_desc, dyblock->data(), cch.conv_desc, cch.bp_filter_alg,
+ cch.workspace.block()->mutable_data(),
+ cch.workspace_count * sizeof(float), &beta, cch.filter_desc,
dwblock->mutable_data());
- }, {dy.block(), x.block()}, {dW.block(), cch.workspace_.block()});
+ }, {dy.block(), x.block()}, {dW.block(), cch.workspace.block()});
return dW;
}
@@ -372,8 +372,8 @@ Tensor GpuConvBackwardb(const Tensor &dy, const Tensor &b, const CudnnConvHandle
dy.device()->Exec([&db, &dy, &cch](Context * ctx) {
Block *dyblock = dy.block(), *dbblock = db.block();
float alpha = 1.f, beta = 0.f;
- cudnnConvolutionBackwardBias(ctx->cudnn_handle, &alpha, cch.y_desc_,
- dyblock->data(), &beta, cch.bias_desc_,
+ cudnnConvolutionBackwardBias(ctx->cudnn_handle, &alpha, cch.y_desc,
+ dyblock->data(), &beta, cch.bias_desc,
dbblock->mutable_data());
}, {dy.block()}, {db.block()});
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/ac5f4eb2/src/model/operation/convolution.h
----------------------------------------------------------------------
diff --git a/src/model/operation/convolution.h b/src/model/operation/convolution.h
index 93f7775..62ff254 100755
--- a/src/model/operation/convolution.h
+++ b/src/model/operation/convolution.h
@@ -22,26 +22,26 @@ class ConvHandle {
const size_t in_channels, const size_t out_channels,
const bool bias);
- size_t kernel_w_;
- size_t pad_w_;
- size_t stride_w_;
- size_t kernel_h_;
- size_t pad_h_;
- size_t stride_h_;
-
- size_t channels_;
- size_t num_filters_;
-
- bool bias_term_;
-
- size_t height_;
- size_t width_;
- size_t conv_height_;
- size_t conv_width_;
+ size_t kernel_w;
+ size_t pad_w;
+ size_t stride_w;
+ size_t kernel_h;
+ size_t pad_h;
+ size_t stride_h;
+
+ size_t channels;
+ size_t num_filters;
+
+ bool bias_term;
+
+ size_t height;
+ size_t width;
+ size_t conv_height;
+ size_t conv_width;
size_t batchsize;
- size_t col_height_;
- size_t col_width_;
+ size_t col_height;
+ size_t col_width;
size_t imagesize;
};
@@ -62,22 +62,22 @@ class CudnnConvHandle: public ConvHandle {
CudnnConvHandle(const Tensor &input, const std::vector<size_t>& kernel_size,
const std::vector<size_t>& stride, const std::vector<size_t>& padding,
const size_t in_channels, const size_t out_channels,
- const bool bias, const size_t workspace_byte_limit_ = 1024 * 1024 * 1024,
- const std::string& prefer_ = "fastest");
+ const bool bias, const size_t workspace_byte_limit = 1024 * 1024 * 1024,
+ const std::string& prefer = "fastest");
~CudnnConvHandle();
// TODO(wangwei) add the destructor
- cudnnTensorDescriptor_t x_desc_ = nullptr;
- cudnnTensorDescriptor_t y_desc_ = nullptr;
- cudnnTensorDescriptor_t bias_desc_ = nullptr;
- cudnnFilterDescriptor_t filter_desc_ = nullptr;
- cudnnConvolutionDescriptor_t conv_desc_ = nullptr;
- cudnnConvolutionFwdAlgo_t fp_alg_;
- cudnnConvolutionBwdFilterAlgo_t bp_filter_alg_;
- cudnnConvolutionBwdDataAlgo_t bp_data_alg_;
-
- size_t workspace_count_;
- Tensor workspace_;
+ cudnnTensorDescriptor_t x_desc = nullptr;
+ cudnnTensorDescriptor_t y_desc = nullptr;
+ cudnnTensorDescriptor_t bias_desc = nullptr;
+ cudnnFilterDescriptor_t filter_desc = nullptr;
+ cudnnConvolutionDescriptor_t conv_desc = nullptr;
+ cudnnConvolutionFwdAlgo_t fp_alg;
+ cudnnConvolutionBwdFilterAlgo_t bp_filter_alg;
+ cudnnConvolutionBwdDataAlgo_t bp_data_alg;
+
+ size_t workspace_count;
+ Tensor workspace;
};
Tensor GpuConvForward(const Tensor &x, const Tensor &W, const Tensor &b, const CudnnConvHandle &cch);