You are viewing a plain-text version of this content; the hyperlink to the canonical version was not preserved in this copy.
Posted to commits@singa.apache.org by wa...@apache.org on 2018/07/08 08:06:12 UTC
[1/2] incubator-singa git commit: SINGA-346 Update cudnn from V5 to V7
Repository: incubator-singa
Updated Branches:
refs/heads/master 56292f1fb -> e16cea129
SINGA-346 Update cudnn from V5 to V7
Support cuDNN 5 (the conv and RNN APIs changed between v5 and v7)
Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/e2092030
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/e2092030
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/e2092030
Branch: refs/heads/master
Commit: e20920309bfeb6ed7e0adf3d529c2fba1d44ad2f
Parents: 56292f1
Author: Wang Wei <wa...@gmail.com>
Authored: Thu Jul 5 22:57:33 2018 +0800
Committer: wang wei <wa...@comp.nus.edu.sg>
Committed: Sun Jul 8 16:00:38 2018 +0800
----------------------------------------------------------------------
src/model/layer/cudnn_convolution.cc | 101 +++++++++---------
src/model/layer/cudnn_rnn.cc | 165 +++++++++++++++---------------
2 files changed, 137 insertions(+), 129 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e2092030/src/model/layer/cudnn_convolution.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/cudnn_convolution.cc b/src/model/layer/cudnn_convolution.cc
index 8846746..1b12f93 100644
--- a/src/model/layer/cudnn_convolution.cc
+++ b/src/model/layer/cudnn_convolution.cc
@@ -44,7 +44,7 @@ void CudnnConvolution::Setup(const Shape& in_sample, const LayerConf &conf) {
CHECK(prefer_ == "fastest" || prefer_ == "limited_workspace" ||
prefer_ == "no_workspace" || prefer_ == "autotune")
<< "CudnnConvolution only supports four algorithm preferences: fastest, "
- "limited_workspace, no_workspace and autotune";
+ "limited_workspace, no_workspace and autotune";
}
void CudnnConvolution::ToDevice(std::shared_ptr<Device> device) {
@@ -70,16 +70,19 @@ void CudnnConvolution::InitCudnn(const Tensor &input) {
GetCudnnDataType(dtype), batchsize,
channels_, height_, width_));
CUDNN_CHECK(cudnnSetTensor4dDescriptor(
- y_desc_, CUDNN_TENSOR_NCHW, GetCudnnDataType(dtype), batchsize,
- num_filters_, conv_height_, conv_width_));
+ y_desc_, CUDNN_TENSOR_NCHW, GetCudnnDataType(dtype), batchsize,
+ num_filters_, conv_height_, conv_width_));
if (bias_term_)
CUDNN_CHECK(cudnnSetTensor4dDescriptor(bias_desc_, CUDNN_TENSOR_NCHW,
GetCudnnDataType(dtype), 1,
num_filters_, 1, 1));
CUDNN_CHECK(cudnnSetConvolution2dDescriptor(conv_desc_, pad_h_, pad_w_,
- stride_h_, stride_w_, 1, 1,
- CUDNN_CROSS_CORRELATION,
- GetCudnnDataType(dtype)));
+ stride_h_, stride_w_, 1, 1, // dilation x and y
+ CUDNN_CROSS_CORRELATION
+#if CUDNN_MAJOR == 5
+ , GetCudnnDataType(dtype)
+#endif // CUDNN_MAJOR
+ ));
CUDNN_CHECK(cudnnSetFilter4dDescriptor(filter_desc_, GetCudnnDataType(dtype),
CUDNN_TENSOR_NCHW, num_filters_,
channels_, kernel_h_, kernel_w_));
@@ -102,15 +105,15 @@ void CudnnConvolution::InitCudnn(const Tensor &input) {
bwd_data_pref = CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT;
}
CUDNN_CHECK(cudnnGetConvolutionForwardAlgorithm(
- ctx->cudnn_handle, x_desc_, filter_desc_, conv_desc_, y_desc_, fwd_pref,
- workspace_byte_limit_, &fp_alg_));
+ ctx->cudnn_handle, x_desc_, filter_desc_, conv_desc_, y_desc_, fwd_pref,
+ workspace_byte_limit_, &fp_alg_));
CUDNN_CHECK(cudnnGetConvolutionBackwardFilterAlgorithm(
- ctx->cudnn_handle, x_desc_, y_desc_, conv_desc_, filter_desc_,
- bwd_filt_pref, workspace_byte_limit_, &bp_filter_alg_));
+ ctx->cudnn_handle, x_desc_, y_desc_, conv_desc_, filter_desc_,
+ bwd_filt_pref, workspace_byte_limit_, &bp_filter_alg_));
// deprecated in cudnn v7
CUDNN_CHECK(cudnnGetConvolutionBackwardDataAlgorithm(
- ctx->cudnn_handle, filter_desc_, y_desc_, conv_desc_, x_desc_,
- bwd_data_pref, workspace_byte_limit_, &bp_data_alg_));
+ ctx->cudnn_handle, filter_desc_, y_desc_, conv_desc_, x_desc_,
+ bwd_data_pref, workspace_byte_limit_, &bp_data_alg_));
} else if (prefer_ == "autotune") {
const int topk = 1;
int num_fp_alg, num_bp_filt_alg, num_bp_data_alg;
@@ -118,16 +121,16 @@ void CudnnConvolution::InitCudnn(const Tensor &input) {
cudnnConvolutionBwdFilterAlgoPerf_t bp_filt_perf[topk];
cudnnConvolutionBwdDataAlgoPerf_t bp_data_perf[topk];
CUDNN_CHECK(cudnnFindConvolutionForwardAlgorithm(
- ctx->cudnn_handle, x_desc_, filter_desc_, conv_desc_, y_desc_, topk,
- &num_fp_alg, fp_alg_perf));
+ ctx->cudnn_handle, x_desc_, filter_desc_, conv_desc_, y_desc_, topk,
+ &num_fp_alg, fp_alg_perf));
fp_alg_ = fp_alg_perf[0].algo;
CUDNN_CHECK(cudnnFindConvolutionBackwardFilterAlgorithm(
- ctx->cudnn_handle, x_desc_, y_desc_, conv_desc_, filter_desc_, topk,
- &num_bp_filt_alg, bp_filt_perf));
+ ctx->cudnn_handle, x_desc_, y_desc_, conv_desc_, filter_desc_, topk,
+ &num_bp_filt_alg, bp_filt_perf));
bp_filter_alg_ = bp_filt_perf[0].algo;
CUDNN_CHECK(cudnnFindConvolutionBackwardDataAlgorithm(
- ctx->cudnn_handle, filter_desc_, y_desc_, conv_desc_, x_desc_, topk,
- &num_bp_data_alg, bp_data_perf));
+ ctx->cudnn_handle, filter_desc_, y_desc_, conv_desc_, x_desc_, topk,
+ &num_bp_data_alg, bp_data_perf));
bp_data_alg_ = bp_data_perf[0].algo;
} else {
LOG(FATAL) << "Preferred algorithm is not available!";
@@ -135,22 +138,22 @@ void CudnnConvolution::InitCudnn(const Tensor &input) {
size_t fp_byte, bp_data_byte, bp_filter_byte;
CUDNN_CHECK(cudnnGetConvolutionForwardWorkspaceSize(
- ctx->cudnn_handle, x_desc_, filter_desc_, conv_desc_, y_desc_, fp_alg_,
- &fp_byte));
+ ctx->cudnn_handle, x_desc_, filter_desc_, conv_desc_, y_desc_, fp_alg_,
+ &fp_byte));
CUDNN_CHECK(cudnnGetConvolutionBackwardDataWorkspaceSize(
- ctx->cudnn_handle, filter_desc_, y_desc_, conv_desc_, x_desc_,
- bp_data_alg_, &bp_data_byte));
+ ctx->cudnn_handle, filter_desc_, y_desc_, conv_desc_, x_desc_,
+ bp_data_alg_, &bp_data_byte));
CUDNN_CHECK(cudnnGetConvolutionBackwardFilterWorkspaceSize(
- ctx->cudnn_handle, x_desc_, y_desc_, conv_desc_, filter_desc_,
- bp_filter_alg_, &bp_filter_byte));
+ ctx->cudnn_handle, x_desc_, y_desc_, conv_desc_, filter_desc_,
+ bp_filter_alg_, &bp_filter_byte));
workspace_count_ = std::max(std::max(fp_byte, bp_data_byte), bp_filter_byte) /
- sizeof(float) +
+ sizeof(float) +
1;
if (workspace_count_ * sizeof(float) > workspace_byte_limit_)
LOG(WARNING) << "The required memory for workspace ("
- << workspace_count_ * sizeof(float)
- << ") is larger than the expected Bytes ("
- << workspace_byte_limit_ << ")";
+ << workspace_count_ * sizeof(float)
+ << ") is larger than the expected Bytes ("
+ << workspace_byte_limit_ << ")";
workspace_ = Tensor(Shape{workspace_count_}, dev, dtype);
has_init_cudnn_ = true;
}
@@ -170,23 +173,23 @@ const Tensor CudnnConvolution::Forward(int flag, const Tensor &input) {
int n, c, h, w, s;
cudnnDataType_t type;
CUDNN_CHECK(cudnnGetTensor4dDescriptor(x_desc_, &type, &n, &c, &h, &w,
- &s, &s, &s, &s));
+ &s, &s, &s, &s));
if (batchsize != static_cast<size_t>(n))
InitCudnn(input);
CHECK(input.shape(1) == static_cast<size_t>(c)
- && input.shape(2) == static_cast<size_t>(h)
- && input.shape(3) == static_cast<size_t>(w))
- << "input sample shape should not change"
- << "previous shape " << c << ", " << h << ", " << w
- << "current shape " << input.shape(1) << ", " << input.shape(2) << ", "
- << input.shape(3);
+ && input.shape(2) == static_cast<size_t>(h)
+ && input.shape(3) == static_cast<size_t>(w))
+ << "input sample shape should not change"
+ << "previous shape " << c << ", " << h << ", " << w
+ << "current shape " << input.shape(1) << ", " << input.shape(2) << ", "
+ << input.shape(3);
}
Shape shape{batchsize, num_filters_, conv_height_, conv_width_};
Tensor output(shape, dev, dtype);
- output.device()->Exec([input, output, this](Context *ctx) {
+ output.device()->Exec([input, output, this](Context * ctx) {
Block *inblock = input.block(), *outblock = output.block(),
- *wblock = this->weight_.block();
+ *wblock = this->weight_.block();
float alpha = 1.f, beta = 0.f;
cudnnConvolutionForward(ctx->cudnn_handle, &alpha, this->x_desc_,
inblock->data(), this->filter_desc_, wblock->data(),
@@ -197,7 +200,7 @@ const Tensor CudnnConvolution::Forward(int flag, const Tensor &input) {
}, {input.block(), weight_.block()}, {output.block()}, workspace_.block());
if (bias_term_) {
- output.device()->Exec([output, this](Context *ctx) {
+ output.device()->Exec([output, this](Context * ctx) {
float beta = 1.f, alpha = 1.0f;
Block *outblock = output.block(), *bblock = this->bias_.block();
cudnnAddTensor(ctx->cudnn_handle, &alpha, this->bias_desc_,
@@ -209,7 +212,7 @@ const Tensor CudnnConvolution::Forward(int flag, const Tensor &input) {
}
const std::pair<Tensor, vector<Tensor>> CudnnConvolution::Backward(
- int flag, const Tensor &grad) {
+int flag, const Tensor &grad) {
CHECK(has_init_cudnn_);
CHECK_EQ(grad.device()->lang(), kCuda);
CHECK_EQ(grad.nDim(), 4u);
@@ -225,7 +228,7 @@ const std::pair<Tensor, vector<Tensor>> CudnnConvolution::Backward(
// LOG(ERROR) << "backward bias";
if (bias_term_) {
db.ResetLike(bias_);
- dx.device()->Exec([grad, db, this](Context *ctx) {
+ dx.device()->Exec([grad, db, this](Context * ctx) {
Block *dyblock = grad.block(), *dbblock = db.block();
float alpha = 1.f, beta = 0.f;
cudnnConvolutionBackwardBias(ctx->cudnn_handle, &alpha, this->y_desc_,
@@ -234,22 +237,22 @@ const std::pair<Tensor, vector<Tensor>> CudnnConvolution::Backward(
}, {grad.block()}, {db.block()});
}
// LOG(ERROR) << "backward w";
- dx.device()->Exec([grad, dw, src_data, this](Context *ctx) {
+ dx.device()->Exec([grad, dw, src_data, this](Context * ctx) {
Block *inblock = src_data.block(), *dyblock = grad.block(),
- *dwblock = dw.block();
+ *dwblock = dw.block();
float alpha = 1.f, beta = 0.f;
cudnnConvolutionBackwardFilter(
- ctx->cudnn_handle, &alpha, this->x_desc_, inblock->data(),
- this->y_desc_, dyblock->data(), this->conv_desc_, this->bp_filter_alg_,
- this->workspace_.block()->mutable_data(),
- this->workspace_count_ * sizeof(float), &beta, this->filter_desc_,
- dwblock->mutable_data());
+ ctx->cudnn_handle, &alpha, this->x_desc_, inblock->data(),
+ this->y_desc_, dyblock->data(), this->conv_desc_, this->bp_filter_alg_,
+ this->workspace_.block()->mutable_data(),
+ this->workspace_count_ * sizeof(float), &beta, this->filter_desc_,
+ dwblock->mutable_data());
}, {grad.block(), src_data.block()}, {dw.block(), workspace_.block()});
// LOG(ERROR) << "backward src";
- dx.device()->Exec([dx, grad, this](Context *ctx) {
+ dx.device()->Exec([dx, grad, this](Context * ctx) {
Block *wblock = this->weight_.block(), *dyblock = grad.block(),
- *dxblock = dx.block();
+ *dxblock = dx.block();
float alpha = 1.f, beta = 0.f;
cudnnConvolutionBackwardData(ctx->cudnn_handle, &alpha, this->filter_desc_,
wblock->data(), this->y_desc_, dyblock->data(),
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e2092030/src/model/layer/cudnn_rnn.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/cudnn_rnn.cc b/src/model/layer/cudnn_rnn.cc
index fb5fee0..28a52c5 100644
--- a/src/model/layer/cudnn_rnn.cc
+++ b/src/model/layer/cudnn_rnn.cc
@@ -125,8 +125,8 @@ void CudnnRNN::SetRNNDescriptor(shared_ptr<Device> dev) {
CUDNN_CHECK(cudnnDropoutGetStatesSize(ctx->cudnn_handle, &state_size));
dropout_state_ = Tensor(Shape{state_size}, dev, kChar);
CUDNN_CHECK(cudnnSetDropoutDescriptor(
- dropout_desc_, ctx->cudnn_handle, 1 - dropout_, // keep probability
- dropout_state_.block()->mutable_data(), state_size, seed_));
+ dropout_desc_, ctx->cudnn_handle, 1 - dropout_, // keep probability
+ dropout_state_.block()->mutable_data(), state_size, seed_));
CUDNN_CHECK(cudnnCreateRNNDescriptor(&rnn_desc_));
cudnnRNNInputMode_t input_mode = CUDNN_LINEAR_INPUT;
@@ -144,10 +144,15 @@ void CudnnRNN::SetRNNDescriptor(shared_ptr<Device> dev) {
rnn_mode = CUDNN_RNN_TANH;
else if (rnn_mode_ == "gru")
rnn_mode = CUDNN_GRU;
+#ifdef CUDNN_MAJOR == 5
+ CUDNN_CHECK(cudnnSetRNNDescriptor(rnn_desc_, hidden_size_, num_stacks_,
+ dropout_desc_, input_mode, direction,
+ rnn_mode, dtype_));
+#else
CUDNN_CHECK(cudnnSetRNNDescriptor(ctx->cudnn_handle, rnn_desc_, hidden_size_, num_stacks_,
dropout_desc_, input_mode, direction,
rnn_mode, CUDNN_RNN_ALGO_STANDARD, dtype_));
-
+#endif
size_t weight_size;
CUDNN_CHECK(cudnnGetRNNParamsSize(ctx->cudnn_handle, rnn_desc_, x_descs_[0],
&weight_size, dtype_));
@@ -199,7 +204,7 @@ void CudnnRNN::UpdateSpaces(size_t seq_length, shared_ptr<Device> dev) {
}
CUDNN_CHECK(cudnnGetRNNTrainingReserveSize(ctx->cudnn_handle, rnn_desc_,
- seq_length, x_descs_, &count));
+ seq_length, x_descs_, &count));
if (reserve_space_.Size() != count) {
reserve_space_ = Tensor(Shape{count}, dev, kChar);
// reserve_space_.SetValue(0);
@@ -263,8 +268,8 @@ const vector<Tensor> CudnnRNN::Forward(int flag, const vector<Tensor> &inputs) {
if (rnn_desc_ != nullptr)
CHECK_EQ(dtype_, GetCudnnDataType(dtype))
- << "Cannot change cudnn data type during training from " << dtype_
- << " to " << GetCudnnDataType(dtype);
+ << "Cannot change cudnn data type during training from " << dtype_
+ << " to " << GetCudnnDataType(dtype);
else
dtype_ = GetCudnnDataType(dtype);
@@ -303,57 +308,57 @@ const vector<Tensor> CudnnRNN::Forward(int flag, const vector<Tensor> &inputs) {
// LOG(INFO) << "hidden size " << hy.Size();
// LOG(INFO) << "weight size " << weight_.Size() << " value " << weight_.L1();
Block *inb = input.block(), *outb = output.block(),
- *wb = this->weight_.block(), *hxb = hx.block(), *cxb = cx.block(),
- *hyb = hy.block(), *cyb = cy.block(),
- *wspace = this->workspace_.block(),
- *rspace = this->reserve_space_.block();
+ *wb = this->weight_.block(), *hxb = hx.block(), *cxb = cx.block(),
+ *hyb = hy.block(), *cyb = cy.block(),
+ *wspace = this->workspace_.block(),
+ *rspace = this->reserve_space_.block();
if (flag & kTrain) {
CHECK_EQ(reserve_space_.device()->lang(), kCuda);
CHECK_EQ(did, reserve_space_.device()->id());
dev->Exec(
- [inb, outb, wb, hxb, cxb, hyb, cyb, wspace, rspace, this](Context *ctx) {
- // clang-format off
- cudnnRNNForwardTraining(
- ctx->cudnn_handle,
- this->rnn_desc_,
- this->seq_length_,
- this->x_descs_, inb->data(),
- this->hx_desc_, hxb == nullptr ? nullptr : hxb->data(),
- this->cx_desc_, cxb == nullptr ? nullptr : cxb->data(),
- this->weight_desc_, wb->data(),
- this->y_descs_, outb->mutable_data(),
- this->hy_desc_, hyb->mutable_data(),
- this->cy_desc_, cyb == nullptr ? nullptr : cyb->mutable_data(),
- wspace->mutable_data(),
- this->workspace_.Size(), rspace->mutable_data(),
- this->reserve_space_.Size());
- // clang-format on
- },
- {inb, wb, hxb, cxb}, {outb, hyb, cyb, wspace, rspace});
+ [inb, outb, wb, hxb, cxb, hyb, cyb, wspace, rspace, this](Context * ctx) {
+ // clang-format off
+ cudnnRNNForwardTraining(
+ ctx->cudnn_handle,
+ this->rnn_desc_,
+ this->seq_length_,
+ this->x_descs_, inb->data(),
+ this->hx_desc_, hxb == nullptr ? nullptr : hxb->data(),
+ this->cx_desc_, cxb == nullptr ? nullptr : cxb->data(),
+ this->weight_desc_, wb->data(),
+ this->y_descs_, outb->mutable_data(),
+ this->hy_desc_, hyb->mutable_data(),
+ this->cy_desc_, cyb == nullptr ? nullptr : cyb->mutable_data(),
+ wspace->mutable_data(),
+ this->workspace_.Size(), rspace->mutable_data(),
+ this->reserve_space_.Size());
+ // clang-format on
+ },
+ {inb, wb, hxb, cxb}, {outb, hyb, cyb, wspace, rspace});
buf_.push(input);
buf_.push(output);
buf_.push(hx);
buf_.push(cx);
} else {
- dev->Exec([inb, outb, wb, hxb, cxb, hyb, cyb, wspace, this](Context *ctx) {
+ dev->Exec([inb, outb, wb, hxb, cxb, hyb, cyb, wspace, this](Context * ctx) {
// clang-format off
cudnnRNNForwardInference(
- ctx->cudnn_handle,
- this->rnn_desc_,
- this->seq_length_,
- this->x_descs_, inb->data(),
- this->hx_desc_, hxb == nullptr ? nullptr : hxb->data(),
- this->cx_desc_, cxb == nullptr ? nullptr : cxb->data(),
- this->weight_desc_, wb->data(),
- this->y_descs_, outb->mutable_data(),
- this->hy_desc_, hyb->mutable_data(),
- this->cy_desc_, cyb == nullptr ? nullptr : cyb->mutable_data(),
- wspace->mutable_data(), this->workspace_.Size());
+ ctx->cudnn_handle,
+ this->rnn_desc_,
+ this->seq_length_,
+ this->x_descs_, inb->data(),
+ this->hx_desc_, hxb == nullptr ? nullptr : hxb->data(),
+ this->cx_desc_, cxb == nullptr ? nullptr : cxb->data(),
+ this->weight_desc_, wb->data(),
+ this->y_descs_, outb->mutable_data(),
+ this->hy_desc_, hyb->mutable_data(),
+ this->cy_desc_, cyb == nullptr ? nullptr : cyb->mutable_data(),
+ wspace->mutable_data(), this->workspace_.Size());
// clang-format on
}, {inb, wb, hxb, cxb}, {outb, hyb, cyb, wspace});
}
auto outputs =
- SplitOutput(num_x, hidden_size_ * num_directions_, inputs, output);
+ SplitOutput(num_x, hidden_size_ * num_directions_, inputs, output);
outputs.push_back(hy);
if (has_cell_) outputs.push_back(cy);
return outputs;
@@ -361,7 +366,7 @@ const vector<Tensor> CudnnRNN::Forward(int flag, const vector<Tensor> &inputs) {
// TODO(wangwei) check Tensor device to be on cuda?
const std::pair<vector<Tensor>, vector<Tensor>> CudnnRNN::Backward(
- int flag, const vector<Tensor> &grads) {
+int flag, const vector<Tensor> &grads) {
// dhy (and dcy) is at last
const Tensor cx = buf_.top(); // cannot use const Tensor& due to pop()
buf_.pop();
@@ -395,45 +400,45 @@ const std::pair<vector<Tensor>, vector<Tensor>> CudnnRNN::Backward(
dcx.ResetLike(dhx);
dw.SetValue(0.0f);
Block *yb = y.block(), *dyb = dy.block(), *dhyb = dhy.block(),
- *dcyb = dcy.block(), *xb = x.block(), *cxb = cx.block(),
- *wb = weight_.block(), *dwb = dw.block(), *hxb = hx.block(),
- *dxb = dx.block(), *dhxb = dhx.block(), *dcxb = dcx.block(),
- *wspace = workspace_.block(), *rspace = reserve_space_.block();
+ *dcyb = dcy.block(), *xb = x.block(), *cxb = cx.block(),
+ *wb = weight_.block(), *dwb = dw.block(), *hxb = hx.block(),
+ *dxb = dx.block(), *dhxb = dhx.block(), *dcxb = dcx.block(),
+ *wspace = workspace_.block(), *rspace = reserve_space_.block();
y.device()->Exec(
- [yb, dyb, dhyb, dcyb, xb, cxb, wb, dwb, hxb, dxb, dhxb, dcxb, wspace,
- rspace, this](Context *ctx) {
- // clang-format off
- cudnnRNNBackwardData(
- ctx->cudnn_handle,
- this->rnn_desc_,
- this->seq_length_,
- this->y_descs_, yb->data(),
- this->dy_descs_, dyb->data(),
- this->dhy_desc_, dhyb == nullptr ? nullptr : dhyb->data(),
- this->dcy_desc_, dcyb == nullptr ? nullptr : dcyb->data(),
- this->weight_desc_, wb->data(),
- this->hx_desc_, hxb == nullptr ? nullptr : hxb->data(),
- this->cx_desc_, cxb == nullptr ? nullptr : cxb->data(),
- this->dx_descs_, dxb->mutable_data(),
- this->dhx_desc_, dhxb->mutable_data(),
- this->dcx_desc_, dcxb == nullptr ? nullptr : dcxb->mutable_data(),
- wspace->mutable_data(), this->workspace_.Size(),
- rspace->mutable_data(), this->reserve_space_.Size());
- cudnnRNNBackwardWeights(
- ctx->cudnn_handle,
- this->rnn_desc_,
- this->seq_length_,
- this->x_descs_, xb->data(),
- this->hx_desc_, hxb == nullptr ? nullptr : hxb->data(),
- this->y_descs_, yb->data(),
- wspace->data(), this->workspace_.Size(),
- this->dweight_desc_, dwb->mutable_data(),
- rspace->data(), this->reserve_space_.Size());
- // clang-format on
- },
- {yb, dyb, dhyb, dcyb, xb, wb, wspace, rspace},
- {dxb, dwb, dhxb, dcxb, wspace, rspace});
+ [yb, dyb, dhyb, dcyb, xb, cxb, wb, dwb, hxb, dxb, dhxb, dcxb, wspace,
+ rspace, this](Context * ctx) {
+ // clang-format off
+ cudnnRNNBackwardData(
+ ctx->cudnn_handle,
+ this->rnn_desc_,
+ this->seq_length_,
+ this->y_descs_, yb->data(),
+ this->dy_descs_, dyb->data(),
+ this->dhy_desc_, dhyb == nullptr ? nullptr : dhyb->data(),
+ this->dcy_desc_, dcyb == nullptr ? nullptr : dcyb->data(),
+ this->weight_desc_, wb->data(),
+ this->hx_desc_, hxb == nullptr ? nullptr : hxb->data(),
+ this->cx_desc_, cxb == nullptr ? nullptr : cxb->data(),
+ this->dx_descs_, dxb->mutable_data(),
+ this->dhx_desc_, dhxb->mutable_data(),
+ this->dcx_desc_, dcxb == nullptr ? nullptr : dcxb->mutable_data(),
+ wspace->mutable_data(), this->workspace_.Size(),
+ rspace->mutable_data(), this->reserve_space_.Size());
+ cudnnRNNBackwardWeights(
+ ctx->cudnn_handle,
+ this->rnn_desc_,
+ this->seq_length_,
+ this->x_descs_, xb->data(),
+ this->hx_desc_, hxb == nullptr ? nullptr : hxb->data(),
+ this->y_descs_, yb->data(),
+ wspace->data(), this->workspace_.Size(),
+ this->dweight_desc_, dwb->mutable_data(),
+ rspace->data(), this->reserve_space_.Size());
+ // clang-format on
+ },
+ {yb, dyb, dhyb, dcyb, xb, wb, wspace, rspace},
+ {dxb, dwb, dhxb, dcxb, wspace, rspace});
vector <Tensor> param_grad{dw};
auto data_grads = SplitOutput(num_dy, input_size_, grads, dx);
[2/2] incubator-singa git commit: SINGA-371 Implement functional operations in c++ for autograd
Posted by wa...@apache.org.
SINGA-371 Implement functional operations in c++ for autograd
fix some bugs and update the example for autograd
Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/e16cea12
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/e16cea12
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/e16cea12
Branch: refs/heads/master
Commit: e16cea129b688c804afe87b3bd1b6a82e5f5ca5f
Parents: e209203
Author: wang wei <wa...@comp.nus.edu.sg>
Authored: Sat Jul 7 22:00:07 2018 +0800
Committer: wang wei <wa...@comp.nus.edu.sg>
Committed: Sun Jul 8 16:01:45 2018 +0800
----------------------------------------------------------------------
examples/autograd/mnist_cnn.py | 25 ++++---
python/singa/autograd.py | 6 +-
src/api/model_operation.i | 29 ++++----
src/core/tensor/tensor.cc | 4 +-
src/core/tensor/tensor_math_cuda.h | 114 +++++++++++++++++-------------
src/model/layer/cudnn_convolution.cc | 2 +-
src/model/operation/convolution.cc | 9 ++-
src/model/operation/convolution.h | 1 +
tool/conda/singa/build.sh | 3 +-
tool/conda/singa/meta.yaml | 2 +-
10 files changed, 112 insertions(+), 83 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e16cea12/examples/autograd/mnist_cnn.py
----------------------------------------------------------------------
diff --git a/examples/autograd/mnist_cnn.py b/examples/autograd/mnist_cnn.py
index a82f64c..5b4e608 100755
--- a/examples/autograd/mnist_cnn.py
+++ b/examples/autograd/mnist_cnn.py
@@ -21,13 +21,11 @@ import numpy as np
import argparse
import os
-import singa
+from singa import device
from singa import tensor
from singa import autograd
from singa import optimizer
-singa.layer.engine = 'singacpp'
-
def load_data(path):
f = np.load(path)
@@ -75,11 +73,18 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Train CNN over MNIST')
parser.add_argument('file_path', type=str, help='the dataset path')
+ parser.add_argument('--use_cpu', action='store_true')
args = parser.parse_args()
assert os.path.exists(args.file_path), \
- 'Pls download the MNIST dataset from ' \
- 'https://github.com/mnielsen/neural-networks-and-deep-learning/raw/master/data/mnist.pkl.gz'
+ 'Pls download the MNIST dataset from '
+
+ if args.use_cpu:
+ print('Using CPU')
+ dev = device.get_default_device()
+ else:
+ print('Using GPU')
+ dev = device.create_cuda_gpu()
train, test = load_data(args.file_path)
@@ -119,16 +124,16 @@ if __name__ == '__main__':
autograd.training = True
for epoch in range(epochs):
for i in range(batch_number):
- inputs = tensor.Tensor(data=x_train[i * 100:(1 + i) * 100, :])
- targets = tensor.Tensor(data=y_train[i * 100:(1 + i) * 100, :])
+ inputs = tensor.Tensor(device=dev, data=x_train[i * 100:(1 + i) * 100])
+ targets = tensor.Tensor(device=dev, data=y_train[i * 100:(1 + i) * 100])
loss, y = forward(inputs, targets)
- accuracy_rate = accuracy(autograd.ctensor2numpy(y.data),
- autograd.ctensor2numpy(targets.data))
+ accuracy_rate = accuracy(tensor.to_numpy(y),
+ tensor.to_numpy(targets))
if (i % 5 == 0):
print('accuracy is:', accuracy_rate, 'loss is:',
- autograd.ctensor2numpy(loss.data)[0])
+ tensor.to_numpy(loss)[0])
in_grads = autograd.backward(loss)
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e16cea12/python/singa/autograd.py
----------------------------------------------------------------------
diff --git a/python/singa/autograd.py b/python/singa/autograd.py
index 80209ff..9fd8b4d 100755
--- a/python/singa/autograd.py
+++ b/python/singa/autograd.py
@@ -741,12 +741,10 @@ class Conv2D(NewLayer):
else:
if not hasattr(self, 'handle'):
self.handle = singa.CudnnConvHandle(x.data, self.kernel_size, self.stride,
- self.padding, self.in_channels, self.out_channels, self.bias,
- self.inner_params['workspace_MB_limit'] * 1024 * 1024, self.inner_params['cudnn_prefer'])
+ self.padding, self.in_channels, self.out_channels, self.bias)
elif x.shape[0] != self.handle.batchsize:
self.handle = singa.CudnnConvHandle(x.data, self.kernel_size, self.stride,
- self.padding, self.in_channels, self.out_channels, self.bias,
- self.inner_params['workspace_MB_limit'] * 1024 * 1024, self.inner_params['cudnn_prefer'])
+ self.padding, self.in_channels, self.out_channels, self.bias)
self.handle.device_id = x.device.id()
y = conv2d(x, self.W, self.b, self.handle)
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e16cea12/src/api/model_operation.i
----------------------------------------------------------------------
diff --git a/src/api/model_operation.i b/src/api/model_operation.i
index 2c13a3b..3858a2b 100755
--- a/src/api/model_operation.i
+++ b/src/api/model_operation.i
@@ -1,9 +1,12 @@
%module model_operation
+%include "config.i"
+%include "std_vector.i"
+%include "std_string.i"
%{
#include "../src/model/operation/convolution.h"
%}
-namespace singa{
+namespace singa {
class ConvHandle {
public:
@@ -15,15 +18,24 @@ class ConvHandle {
size_t batchsize;
};
-struct CudnnConvHandle{
+Tensor CpuConvForward(const Tensor &x, Tensor &W, Tensor &b, const ConvHandle &ch);
+
+Tensor CpuConvBackwardx(const Tensor &dy, Tensor &W, const Tensor &x, const ConvHandle &ch);
+
+Tensor CpuConvBackwardW(const Tensor &dy, const Tensor &x, const Tensor &W, const ConvHandle &ch);
+
+Tensor CpuConvBackwardb(const Tensor &dy, const Tensor &b, const ConvHandle &ch);
+
+#if USE_CUDNN
+class CudnnConvHandle: public ConvHandle {
public:
- CudnnConvHandle(const Tensor &input, const std::vector<size_t>& kernel_size,
+ CudnnConvHandle(const Tensor &input, const std::vector<size_t>& kernel_size,
const std::vector<size_t>& stride, const std::vector<size_t>& padding,
const size_t in_channels, const size_t out_channels,
const bool bias, const size_t workspace_byte_limit = 1024 * 1024 * 1024,
const std::string& prefer = "fastest");
bool bias_term;
- size_t batchsize;
+ size_t batchsize;
};
Tensor GpuConvForward(const Tensor &x, const Tensor &W, const Tensor &b, const CudnnConvHandle &cch);
@@ -34,13 +46,6 @@ Tensor GpuConvBackwardW(const Tensor &dy, const Tensor &x, const Tensor &W, cons
Tensor GpuConvBackwardb(const Tensor &dy, const Tensor &b, const CudnnConvHandle &cch);
-
-Tensor CpuConvForward(const Tensor &x, Tensor &W, Tensor &b, const ConvHandle &ch);
-
-Tensor CpuConvBackwardx(const Tensor &dy, Tensor &W, const Tensor &x, const ConvHandle &ch);
-
-Tensor CpuConvBackwardW(const Tensor &dy, const Tensor &x, const Tensor &W, const ConvHandle &ch);
-
-Tensor CpuConvBackwardb(const Tensor &dy, const Tensor &b, const ConvHandle &ch);
+#endif // USE_CUDNN
}
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e16cea12/src/core/tensor/tensor.cc
----------------------------------------------------------------------
diff --git a/src/core/tensor/tensor.cc b/src/core/tensor/tensor.cc
index e0a9ecb..05db7cf 100644
--- a/src/core/tensor/tensor.cc
+++ b/src/core/tensor/tensor.cc
@@ -346,7 +346,7 @@ Tensor Tensor::Repeat(vector<size_t> repeats, int axis, std::shared_ptr<Device>
} else {
if (repeats.size() == 1){
total_repeats = repeats[0];
- for (int i = 0; i < shape_.size(); i++) {
+ for (size_t i = 0; i < shape_.size(); i++) {
if (i == axis) {
tshape.push_back(shape_[i] * total_repeats);
} else {
@@ -363,7 +363,7 @@ Tensor Tensor::Repeat(vector<size_t> repeats, int axis, std::shared_ptr<Device>
}
total_repeats += repeats[i];
}
- for (int i = 0; i < shape_.size(); i++){
+ for (size_t i = 0; i < shape_.size(); i++){
if (i == axis) {
tshape.push_back(total_repeats);
} else{
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e16cea12/src/core/tensor/tensor_math_cuda.h
----------------------------------------------------------------------
diff --git a/src/core/tensor/tensor_math_cuda.h b/src/core/tensor/tensor_math_cuda.h
index a1b9381..2a43468 100644
--- a/src/core/tensor/tensor_math_cuda.h
+++ b/src/core/tensor/tensor_math_cuda.h
@@ -791,6 +791,12 @@ void Sqrt<float, lang::Cuda>(const Tensor& in, Tensor* out,
const float* inPtr = static_cast<const float*>(in.block()->data());
float* outPtr = static_cast<float*>(out->block()->mutable_data());
+#if CUDNN_MAJOR < 7
+ size_t num = in.Size();
+ cuda::sqrt(num, inPtr, outPtr, ctx->stream);
+
+#else
+
float alpha1 = 1.0;
float alpha2 = 0.0;
float beta = 0.0;
@@ -800,6 +806,7 @@ void Sqrt<float, lang::Cuda>(const Tensor& in, Tensor* out,
(void*)(&alpha2), in_desc, inPtr,
(void*)(&beta), generate_tensor_nd_desc(*out), outPtr
));
+#endif // CUDNN_MAJOR < 7
}
/// Element-wise operation, out[i]=in[i]^2
@@ -833,54 +840,6 @@ void Square<float, lang::Cuda>(const Tensor& in, Tensor* out,
// // cuda::sum(num, inPtr, out, ctx->stream);
// }
-template <>
-void Sum<float, lang::Cuda>(const Tensor& in, float* out,
- Context* ctx) {
- const float* inPtr = static_cast<const float*>(in.block()->data());
-
- //reduce all axes to 1 for cudnnReduce, e.g. Tensor A with shape (2,4) will be reduced to (1)
- Shape reduced_shape = {1};
- Tensor t(reduced_shape, in.device(), in.data_type());
- float* tPtr = static_cast<float*>(t.block()->mutable_data());
- vector<int> reduce_all_axes = generate_shape_cuda(in);
- for (size_t n = 0; n < reduce_all_axes.size(); ++n) {
- reduce_all_axes[n] = 1;
- }
-
- //reduce_desc
- cudnnReduceTensorDescriptor_t reduce_desc;
- cudnnReduceTensorOp_t reduce_op = CUDNN_REDUCE_TENSOR_ADD;
- cudnnDataType_t cudnn_dtype = CUDNN_DATA_FLOAT;
- cudnnNanPropagation_t cudnn_propagation = CUDNN_PROPAGATE_NAN;
- cudnnReduceTensorIndices_t cudnn_indices = CUDNN_REDUCE_TENSOR_NO_INDICES;
- cudnnIndicesType_t cudnn_indices_type = CUDNN_32BIT_INDICES;
- check_cudnn(cudnnCreateReduceTensorDescriptor(&reduce_desc));
- check_cudnn(cudnnSetReduceTensorDescriptor(reduce_desc, reduce_op, cudnn_dtype,
- cudnn_propagation, cudnn_indices, cudnn_indices_type));
-
- //instantiate 2 new tensors to use new blocks as memory instead of cudaMalloc
- size_t reduction_size_int = Product(in.shape());
- Shape reduction_size = {reduction_size_int * 100};
- Tensor indices(reduction_size, in.device(), in.data_type());
- Tensor workspace(reduction_size, in.device(), in.data_type());
- size_t indices_bytes = indices.block()->size() * 100;
- size_t workspace_bytes = workspace.block()->size() * 100;
- size_t* indicesPtr = static_cast<size_t*>(indices.block()->mutable_data());
- float* workspacePtr = static_cast<float*>(workspace.block()->mutable_data());
- //void* indicesPtr{nullptr}; void* workspacePtr{nullptr};
- //cudaMalloc(&indicesPtr, indices_bytes); cudaMalloc(&workspacePtr, workspace_bytes);
-
- float alpha = 1.0;
- float beta = 0.0;
- check_cudnn(cudnnReduceTensor(ctx->cudnn_handle, reduce_desc,
- indicesPtr, indices_bytes, workspacePtr, workspace_bytes,
- (void*)(&alpha), generate_tensor_nd_desc(in), inPtr,
- (void*)(&beta), generate_tensor_nd_desc(t), tPtr
- ));
-
- *out = tPtr[0];
-}
-
/// Element-wise operation, out[i]=tanh([in[i])
// template <>
@@ -949,7 +908,7 @@ void Transform<float, lang::Cuda>(const Tensor& in, Tensor* out,
(void*)(&alpha), generate_tensor_nd_desc(in), inPtr,
(void*)(&beta), generate_tensor_nd_desc(*out), outPtr
));
-
+
}
// ================Random functions===========================================
@@ -1233,6 +1192,63 @@ void RowMax<float, lang::Cuda>(const Tensor& in, Tensor* out,
}
}
+
+/// Sum of all elements of `in`, stored into the host scalar `*out`.
+///
+/// Must stay after the Set and Dot specializations: the pre-cuDNN-7 path
+/// calls them, and using the primary template before the lang::Cuda
+/// specialization is declared makes the compiler instantiate it first and
+/// then reject the later specialization.
+///
+/// @param in   float tensor whose block lives on the CUDA device.
+/// @param out  host pointer receiving the scalar result.
+/// @param ctx  CUDA context; ctx->cudnn_handle is used for the reduction.
+template <>
+void Sum<float, lang::Cuda>(const Tensor& in, float* out,
+                            Context* ctx) {
+#if CUDNN_MAJOR < 7
+  // Fallback without cudnnReduceTensor: sum(in) == dot(in, ones).
+  Tensor one(in.shape(), in.device(), in.data_type());
+  Set<float, lang::Cuda>(float(1), &one, ctx);
+  Dot<float, lang::Cuda>(in, one, out, ctx);
+#else
+  const float* inPtr = static_cast<const float*>(in.block()->data());
+
+  // cudnnReduceTensor wants an output of the same rank as the input with
+  // every reduced axis of extent 1, so reduce all axes into a (1, ..., 1)
+  // tensor of matching rank, e.g. shape (2,4) -> (1,1).
+  Shape reduced_shape(in.shape().size(), size_t{1});
+  if (reduced_shape.empty()) reduced_shape.push_back(1);
+  Tensor t(reduced_shape, in.device(), in.data_type());
+  float* tPtr = static_cast<float*>(t.block()->mutable_data());
+
+  cudnnReduceTensorDescriptor_t reduce_desc;
+  check_cudnn(cudnnCreateReduceTensorDescriptor(&reduce_desc));
+  check_cudnn(cudnnSetReduceTensorDescriptor(
+      reduce_desc, CUDNN_REDUCE_TENSOR_ADD, CUDNN_DATA_FLOAT,
+      CUDNN_PROPAGATE_NAN, CUDNN_REDUCE_TENSOR_NO_INDICES,
+      CUDNN_32BIT_INDICES));
+
+  // Keep the descriptors so they can be destroyed below.
+  // NOTE(review): assumes generate_tensor_nd_desc creates a fresh descriptor
+  // on every call (never a cached/shared one) -- confirm before relying on it.
+  cudnnTensorDescriptor_t in_desc = generate_tensor_nd_desc(in);
+  cudnnTensorDescriptor_t t_desc = generate_tensor_nd_desc(t);
+
+  // Ask cuDNN for the exact scratch sizes. Guessing a size and advertising
+  // more bytes than were allocated lets cuDNN write past the buffers.
+  size_t indices_bytes = 0;
+  size_t workspace_bytes = 0;
+  check_cudnn(cudnnGetReductionIndicesSize(ctx->cudnn_handle, reduce_desc,
+                                           in_desc, t_desc, &indices_bytes));
+  check_cudnn(cudnnGetReductionWorkspaceSize(ctx->cudnn_handle, reduce_desc,
+                                             in_desc, t_desc,
+                                             &workspace_bytes));
+
+  // Back the scratch space with Tensors (device blocks) instead of raw
+  // cudaMalloc. "/ sizeof(float) + 1" rounds up to whole floats and keeps
+  // each allocation non-empty even when cuDNN needs 0 bytes.
+  Tensor indices(Shape{indices_bytes / sizeof(float) + 1}, in.device(),
+                 in.data_type());
+  Tensor workspace(Shape{workspace_bytes / sizeof(float) + 1}, in.device(),
+                   in.data_type());
+  void* indicesPtr = indices.block()->mutable_data();
+  void* workspacePtr = workspace.block()->mutable_data();
+
+  const float alpha = 1.0f;
+  const float beta = 0.0f;
+  check_cudnn(cudnnReduceTensor(ctx->cudnn_handle, reduce_desc,
+                                indicesPtr, indices_bytes,
+                                workspacePtr, workspace_bytes,
+                                &alpha, in_desc, inPtr,
+                                &beta, t_desc, tPtr));
+
+  // Descriptors are host-side objects; they may be released once the
+  // reduction has been enqueued.
+  check_cudnn(cudnnDestroyTensorDescriptor(t_desc));
+  check_cudnn(cudnnDestroyTensorDescriptor(in_desc));
+  check_cudnn(cudnnDestroyReduceTensorDescriptor(reduce_desc));
+
+  // tPtr is device memory and must not be dereferenced on the host. Copy the
+  // scalar back on the same stream the reduction was queued on, then wait.
+  // The wait also keeps t/indices/workspace alive until cuDNN is done with
+  // them.
+  cudaStream_t stream;
+  check_cudnn(cudnnGetStream(ctx->cudnn_handle, &stream));
+  cudaError_t err = cudaMemcpyAsync(out, tPtr, sizeof(float),
+                                    cudaMemcpyDeviceToHost, stream);
+  CHECK_EQ(err, cudaSuccess) << cudaGetErrorString(err);
+  err = cudaStreamSynchronize(stream);
+  CHECK_EQ(err, cudaSuccess) << cudaGetErrorString(err);
+#endif  // CUDNN_MAJOR < 7
+}
+
+
} // namespace singa
#endif // USE_CUDA
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e16cea12/src/model/layer/cudnn_convolution.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/cudnn_convolution.cc b/src/model/layer/cudnn_convolution.cc
index 1b12f93..0aed832 100644
--- a/src/model/layer/cudnn_convolution.cc
+++ b/src/model/layer/cudnn_convolution.cc
@@ -79,7 +79,7 @@ void CudnnConvolution::InitCudnn(const Tensor &input) {
CUDNN_CHECK(cudnnSetConvolution2dDescriptor(conv_desc_, pad_h_, pad_w_,
stride_h_, stride_w_, 1, 1, // dilation x and y
CUDNN_CROSS_CORRELATION
-#if CUDNN_MAJOR == 5
+#if CUDNN_MAJOR >= 7
, GetCudnnDataType(dtype)
#endif // CUDNN_MAJOR
));
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e16cea12/src/model/operation/convolution.cc
----------------------------------------------------------------------
diff --git a/src/model/operation/convolution.cc b/src/model/operation/convolution.cc
index e36df43..f700203 100755
--- a/src/model/operation/convolution.cc
+++ b/src/model/operation/convolution.cc
@@ -199,8 +199,11 @@ CudnnConvHandle::CudnnConvHandle(const Tensor &input, const std::vector<size_t>&
num_filters, 1, 1));
CUDNN_CHECK(cudnnSetConvolution2dDescriptor(conv_desc, pad_h, pad_w,
stride_h, stride_w, 1, 1,
- CUDNN_CROSS_CORRELATION,
- GetCudnnDataType(dtype)));
+ CUDNN_CROSS_CORRELATION
+#if CUDNN_MAJOR >= 7
+ , GetCudnnDataType(dtype)
+#endif
+ ));
CUDNN_CHECK(cudnnSetFilter4dDescriptor(filter_desc, GetCudnnDataType(dtype),
CUDNN_TENSOR_NCHW, num_filters,
channels, kernel_h, kernel_w));
@@ -381,4 +384,4 @@ Tensor GpuConvBackwardb(const Tensor &dy, const Tensor &b, const CudnnConvHandle
}
#endif // USE_CUDNN
-} // namespace singa
\ No newline at end of file
+} // namespace singa
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e16cea12/src/model/operation/convolution.h
----------------------------------------------------------------------
diff --git a/src/model/operation/convolution.h b/src/model/operation/convolution.h
index 62ff254..9da881f 100755
--- a/src/model/operation/convolution.h
+++ b/src/model/operation/convolution.h
@@ -5,6 +5,7 @@
#include <vector>
#include "singa/core/tensor.h"
#include "singa/utils/logging.h"
+#include "singa/singa_config.h"
#ifdef USE_CUDNN
#include <cudnn.h>
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e16cea12/tool/conda/singa/build.sh
----------------------------------------------------------------------
diff --git a/tool/conda/singa/build.sh b/tool/conda/singa/build.sh
index 91a2f3b..b54e451 100644
--- a/tool/conda/singa/build.sh
+++ b/tool/conda/singa/build.sh
@@ -23,12 +23,13 @@ export CMAKE_PREFIX_PATH=$PREFIX:$CMAKE_PREFIX_PATH
export CMAKE_INCLUDE_PATH=$PREFIX/include:$CMAKE_INCLUDE_PATH
export CMAKE_LIBRARY_PATH=$PREFIX/lib:$CMAKE_LIBRARY_PATH
+echo "----------------------$CUDNN_PATH---------------"
if [ -z ${CUDNN_PATH+x} ]; then
USE_CUDA=OFF
else
USE_CUDA=ON
- cp -r $CUDNN_PATH/include $PREFIX/include
+ cp $CUDNN_PATH/include/* $PREFIX/include/
cp -P $CUDNN_PATH/lib64/libcudnn.so* $PREFIX/lib/
fi
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e16cea12/tool/conda/singa/meta.yaml
----------------------------------------------------------------------
diff --git a/tool/conda/singa/meta.yaml b/tool/conda/singa/meta.yaml
index 997341c..ee76636 100644
--- a/tool/conda/singa/meta.yaml
+++ b/tool/conda/singa/meta.yaml
@@ -22,7 +22,7 @@ package:
version: "{{ GIT_DESCRIBE_TAG }}"
source:
- git_url: https://github.com/apache/incubator-singa.git
+ path: /home/wangwei/incubator-singa/
build:
number: 0