Posted to commits@singa.apache.org by wa...@apache.org on 2018/07/08 08:06:12 UTC

[1/2] incubator-singa git commit: SINGA-346 Update cudnn from V5 to V7

Repository: incubator-singa
Updated Branches:
  refs/heads/master 56292f1fb -> e16cea129


SINGA-346 Update cudnn from V5 to V7

support cudnn5 (conv and rnn have API changes from v5 to v7)
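
Background for the conv change: cudnnSetConvolution2dDescriptor gained a
trailing computeType argument after v5, so the call has to be guarded on
CUDNN_MAJOR. A minimal sketch of the guard, using the CUDNN_CHECK macro from
the SINGA sources and the >= 7 condition that the second commit in this pair
settles on (the helper name and parameters here are illustrative):

    void SetConvDesc(cudnnConvolutionDescriptor_t conv_desc,
                     int pad_h, int pad_w, int stride_h, int stride_w,
                     cudnnDataType_t dtype) {
    #if CUDNN_MAJOR < 7
      (void)dtype;  // the v5 signature has no computeType parameter
    #endif
      CUDNN_CHECK(cudnnSetConvolution2dDescriptor(conv_desc, pad_h, pad_w,
                  stride_h, stride_w, 1, 1,  // dilation x and y
                  CUDNN_CROSS_CORRELATION
    #if CUDNN_MAJOR >= 7
                  , dtype  // computeType, added after v5
    #endif
                  ));
    }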


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/e2092030
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/e2092030
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/e2092030

Branch: refs/heads/master
Commit: e20920309bfeb6ed7e0adf3d529c2fba1d44ad2f
Parents: 56292f1
Author: Wang Wei <wa...@gmail.com>
Authored: Thu Jul 5 22:57:33 2018 +0800
Committer: wang wei <wa...@comp.nus.edu.sg>
Committed: Sun Jul 8 16:00:38 2018 +0800

----------------------------------------------------------------------
 src/model/layer/cudnn_convolution.cc | 101 +++++++++---------
 src/model/layer/cudnn_rnn.cc         | 165 +++++++++++++++---------------
 2 files changed, 137 insertions(+), 129 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e2092030/src/model/layer/cudnn_convolution.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/cudnn_convolution.cc b/src/model/layer/cudnn_convolution.cc
index 8846746..1b12f93 100644
--- a/src/model/layer/cudnn_convolution.cc
+++ b/src/model/layer/cudnn_convolution.cc
@@ -44,7 +44,7 @@ void CudnnConvolution::Setup(const Shape& in_sample, const LayerConf &conf) {
   CHECK(prefer_ == "fastest" || prefer_ == "limited_workspace" ||
         prefer_ == "no_workspace" || prefer_ == "autotune")
       << "CudnnConvolution only supports four algorithm preferences: fastest, "
-         "limited_workspace, no_workspace and autotune";
+      "limited_workspace, no_workspace and autotune";
 }
 
 void CudnnConvolution::ToDevice(std::shared_ptr<Device> device) {
@@ -70,16 +70,19 @@ void CudnnConvolution::InitCudnn(const Tensor &input) {
                                          GetCudnnDataType(dtype), batchsize,
                                          channels_, height_, width_));
   CUDNN_CHECK(cudnnSetTensor4dDescriptor(
-      y_desc_, CUDNN_TENSOR_NCHW, GetCudnnDataType(dtype), batchsize,
-      num_filters_, conv_height_, conv_width_));
+                y_desc_, CUDNN_TENSOR_NCHW, GetCudnnDataType(dtype), batchsize,
+                num_filters_, conv_height_, conv_width_));
   if (bias_term_)
     CUDNN_CHECK(cudnnSetTensor4dDescriptor(bias_desc_, CUDNN_TENSOR_NCHW,
                                            GetCudnnDataType(dtype), 1,
                                            num_filters_, 1, 1));
   CUDNN_CHECK(cudnnSetConvolution2dDescriptor(conv_desc_, pad_h_, pad_w_,
-                                              stride_h_, stride_w_, 1, 1,
-                                              CUDNN_CROSS_CORRELATION, 
-                                              GetCudnnDataType(dtype)));
+              stride_h_, stride_w_, 1, 1,  // dilation x and y
+              CUDNN_CROSS_CORRELATION
+#if CUDNN_MAJOR == 5
+              , GetCudnnDataType(dtype)
+#endif  // CUDNN_MAJOR
+                                             ));
   CUDNN_CHECK(cudnnSetFilter4dDescriptor(filter_desc_, GetCudnnDataType(dtype),
                                          CUDNN_TENSOR_NCHW, num_filters_,
                                          channels_, kernel_h_, kernel_w_));
@@ -102,15 +105,15 @@ void CudnnConvolution::InitCudnn(const Tensor &input) {
       bwd_data_pref = CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT;
     }
     CUDNN_CHECK(cudnnGetConvolutionForwardAlgorithm(
-        ctx->cudnn_handle, x_desc_, filter_desc_, conv_desc_, y_desc_, fwd_pref,
-        workspace_byte_limit_, &fp_alg_));
+                  ctx->cudnn_handle, x_desc_, filter_desc_, conv_desc_, y_desc_, fwd_pref,
+                  workspace_byte_limit_, &fp_alg_));
     CUDNN_CHECK(cudnnGetConvolutionBackwardFilterAlgorithm(
-        ctx->cudnn_handle, x_desc_, y_desc_, conv_desc_, filter_desc_,
-        bwd_filt_pref, workspace_byte_limit_, &bp_filter_alg_));
+                  ctx->cudnn_handle, x_desc_, y_desc_, conv_desc_, filter_desc_,
+                  bwd_filt_pref, workspace_byte_limit_, &bp_filter_alg_));
     // deprecated in cudnn v7
     CUDNN_CHECK(cudnnGetConvolutionBackwardDataAlgorithm(
-        ctx->cudnn_handle, filter_desc_, y_desc_, conv_desc_, x_desc_,
-        bwd_data_pref, workspace_byte_limit_, &bp_data_alg_));
+                  ctx->cudnn_handle, filter_desc_, y_desc_, conv_desc_, x_desc_,
+                  bwd_data_pref, workspace_byte_limit_, &bp_data_alg_));
   } else if (prefer_ == "autotune") {
     const int topk = 1;
     int num_fp_alg, num_bp_filt_alg, num_bp_data_alg;
@@ -118,16 +121,16 @@ void CudnnConvolution::InitCudnn(const Tensor &input) {
     cudnnConvolutionBwdFilterAlgoPerf_t bp_filt_perf[topk];
     cudnnConvolutionBwdDataAlgoPerf_t bp_data_perf[topk];
     CUDNN_CHECK(cudnnFindConvolutionForwardAlgorithm(
-        ctx->cudnn_handle, x_desc_, filter_desc_, conv_desc_, y_desc_, topk,
-        &num_fp_alg, fp_alg_perf));
+                  ctx->cudnn_handle, x_desc_, filter_desc_, conv_desc_, y_desc_, topk,
+                  &num_fp_alg, fp_alg_perf));
     fp_alg_ = fp_alg_perf[0].algo;
     CUDNN_CHECK(cudnnFindConvolutionBackwardFilterAlgorithm(
-        ctx->cudnn_handle, x_desc_, y_desc_, conv_desc_, filter_desc_, topk,
-        &num_bp_filt_alg, bp_filt_perf));
+                  ctx->cudnn_handle, x_desc_, y_desc_, conv_desc_, filter_desc_, topk,
+                  &num_bp_filt_alg, bp_filt_perf));
     bp_filter_alg_ = bp_filt_perf[0].algo;
     CUDNN_CHECK(cudnnFindConvolutionBackwardDataAlgorithm(
-        ctx->cudnn_handle, filter_desc_, y_desc_, conv_desc_, x_desc_, topk,
-        &num_bp_data_alg, bp_data_perf));
+                  ctx->cudnn_handle, filter_desc_, y_desc_, conv_desc_, x_desc_, topk,
+                  &num_bp_data_alg, bp_data_perf));
     bp_data_alg_ = bp_data_perf[0].algo;
   } else {
     LOG(FATAL) << "Preferred algorithm is not available!";
@@ -135,22 +138,22 @@ void CudnnConvolution::InitCudnn(const Tensor &input) {
 
   size_t fp_byte, bp_data_byte, bp_filter_byte;
   CUDNN_CHECK(cudnnGetConvolutionForwardWorkspaceSize(
-      ctx->cudnn_handle, x_desc_, filter_desc_, conv_desc_, y_desc_, fp_alg_,
-      &fp_byte));
+                ctx->cudnn_handle, x_desc_, filter_desc_, conv_desc_, y_desc_, fp_alg_,
+                &fp_byte));
   CUDNN_CHECK(cudnnGetConvolutionBackwardDataWorkspaceSize(
-      ctx->cudnn_handle, filter_desc_, y_desc_, conv_desc_, x_desc_,
-      bp_data_alg_, &bp_data_byte));
+                ctx->cudnn_handle, filter_desc_, y_desc_, conv_desc_, x_desc_,
+                bp_data_alg_, &bp_data_byte));
   CUDNN_CHECK(cudnnGetConvolutionBackwardFilterWorkspaceSize(
-      ctx->cudnn_handle, x_desc_, y_desc_, conv_desc_, filter_desc_,
-      bp_filter_alg_, &bp_filter_byte));
+                ctx->cudnn_handle, x_desc_, y_desc_, conv_desc_, filter_desc_,
+                bp_filter_alg_, &bp_filter_byte));
   workspace_count_ = std::max(std::max(fp_byte, bp_data_byte), bp_filter_byte) /
-                         sizeof(float) +
+                     sizeof(float) +
                      1;
   if (workspace_count_ * sizeof(float) > workspace_byte_limit_)
     LOG(WARNING) << "The required memory for workspace ("
-      << workspace_count_ * sizeof(float)
-      << ") is larger than the expected Bytes ("
-      << workspace_byte_limit_ << ")";
+                 << workspace_count_ * sizeof(float)
+                 << ") is larger than the expected Bytes ("
+                 << workspace_byte_limit_ << ")";
   workspace_ = Tensor(Shape{workspace_count_}, dev, dtype);
   has_init_cudnn_ = true;
 }
@@ -170,23 +173,23 @@ const Tensor CudnnConvolution::Forward(int flag, const Tensor &input) {
     int n, c, h, w, s;
     cudnnDataType_t type;
     CUDNN_CHECK(cudnnGetTensor4dDescriptor(x_desc_, &type, &n, &c, &h, &w,
-          &s, &s, &s, &s));
+                                           &s, &s, &s, &s));
     if (batchsize != static_cast<size_t>(n))
       InitCudnn(input);
     CHECK(input.shape(1) == static_cast<size_t>(c)
-        && input.shape(2) == static_cast<size_t>(h)
-        && input.shape(3) == static_cast<size_t>(w))
-      << "input sample shape should not change"
-      << "previous shape " << c << ", " << h << ", " << w
-      << "current shape " << input.shape(1) << ", " << input.shape(2) << ", "
-      << input.shape(3);
+          && input.shape(2) == static_cast<size_t>(h)
+          && input.shape(3) == static_cast<size_t>(w))
+        << "input sample shape should not change"
+        << "previous shape " << c << ", " << h << ", " << w
+        << "current shape " << input.shape(1) << ", " << input.shape(2) << ", "
+        << input.shape(3);
   }
 
   Shape shape{batchsize, num_filters_, conv_height_, conv_width_};
   Tensor output(shape, dev, dtype);
-  output.device()->Exec([input, output, this](Context *ctx) {
+  output.device()->Exec([input, output, this](Context * ctx) {
     Block *inblock = input.block(), *outblock = output.block(),
-          *wblock = this->weight_.block();
+           *wblock = this->weight_.block();
     float alpha = 1.f, beta = 0.f;
     cudnnConvolutionForward(ctx->cudnn_handle, &alpha, this->x_desc_,
                             inblock->data(), this->filter_desc_, wblock->data(),
@@ -197,7 +200,7 @@ const Tensor CudnnConvolution::Forward(int flag, const Tensor &input) {
   }, {input.block(), weight_.block()}, {output.block()}, workspace_.block());
 
   if (bias_term_) {
-    output.device()->Exec([output, this](Context *ctx) {
+    output.device()->Exec([output, this](Context * ctx) {
       float beta = 1.f, alpha = 1.0f;
       Block *outblock = output.block(), *bblock = this->bias_.block();
       cudnnAddTensor(ctx->cudnn_handle, &alpha, this->bias_desc_,
@@ -209,7 +212,7 @@ const Tensor CudnnConvolution::Forward(int flag, const Tensor &input) {
 }
 
 const std::pair<Tensor, vector<Tensor>> CudnnConvolution::Backward(
-    int flag, const Tensor &grad) {
+int flag, const Tensor &grad) {
   CHECK(has_init_cudnn_);
   CHECK_EQ(grad.device()->lang(), kCuda);
   CHECK_EQ(grad.nDim(), 4u);
@@ -225,7 +228,7 @@ const std::pair<Tensor, vector<Tensor>> CudnnConvolution::Backward(
   // LOG(ERROR) << "backward bias";
   if (bias_term_) {
     db.ResetLike(bias_);
-    dx.device()->Exec([grad, db, this](Context *ctx) {
+    dx.device()->Exec([grad, db, this](Context * ctx) {
       Block *dyblock = grad.block(), *dbblock = db.block();
       float alpha = 1.f, beta = 0.f;
       cudnnConvolutionBackwardBias(ctx->cudnn_handle, &alpha, this->y_desc_,
@@ -234,22 +237,22 @@ const std::pair<Tensor, vector<Tensor>> CudnnConvolution::Backward(
     }, {grad.block()}, {db.block()});
   }
   // LOG(ERROR) << "backward w";
-  dx.device()->Exec([grad, dw, src_data, this](Context *ctx) {
+  dx.device()->Exec([grad, dw, src_data, this](Context * ctx) {
     Block *inblock = src_data.block(), *dyblock = grad.block(),
-          *dwblock = dw.block();
+           *dwblock = dw.block();
     float alpha = 1.f, beta = 0.f;
     cudnnConvolutionBackwardFilter(
-        ctx->cudnn_handle, &alpha, this->x_desc_, inblock->data(),
-        this->y_desc_, dyblock->data(), this->conv_desc_, this->bp_filter_alg_,
-        this->workspace_.block()->mutable_data(),
-        this->workspace_count_ * sizeof(float), &beta, this->filter_desc_,
-        dwblock->mutable_data());
+      ctx->cudnn_handle, &alpha, this->x_desc_, inblock->data(),
+      this->y_desc_, dyblock->data(), this->conv_desc_, this->bp_filter_alg_,
+      this->workspace_.block()->mutable_data(),
+      this->workspace_count_ * sizeof(float), &beta, this->filter_desc_,
+      dwblock->mutable_data());
   }, {grad.block(), src_data.block()}, {dw.block(), workspace_.block()});
 
   // LOG(ERROR) << "backward src";
-  dx.device()->Exec([dx, grad, this](Context *ctx) {
+  dx.device()->Exec([dx, grad, this](Context * ctx) {
     Block *wblock = this->weight_.block(), *dyblock = grad.block(),
-          *dxblock = dx.block();
+           *dxblock = dx.block();
     float alpha = 1.f, beta = 0.f;
     cudnnConvolutionBackwardData(ctx->cudnn_handle, &alpha, this->filter_desc_,
                                  wblock->data(), this->y_desc_, dyblock->data(),
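
A note on the algorithm-selection block above: "fastest", "limited_workspace"
and "no_workspace" map to cudnnGetConvolutionForwardAlgorithm-style heuristic
preferences (deprecated in cuDNN v7), while "autotune" benchmarks candidates
via cudnnFindConvolutionForwardAlgorithm. A minimal sketch of the autotune
path, assuming the handle and descriptors are initialized as in InitCudnn:

    cudnnConvolutionFwdAlgoPerf_t perf[1];
    int returned = 0;
    CUDNN_CHECK(cudnnFindConvolutionForwardAlgorithm(
        handle, x_desc, filter_desc, conv_desc, y_desc,
        1 /* request top-1 */, &returned, perf));
    cudnnConvolutionFwdAlgo_t algo = perf[0].algo;  // fastest measured algorithm

    // the workspace is then sized for whichever algorithm was chosen
    size_t fwd_bytes = 0;
    CUDNN_CHECK(cudnnGetConvolutionForwardWorkspaceSize(
        handle, x_desc, filter_desc, conv_desc, y_desc, algo, &fwd_bytes));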

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e2092030/src/model/layer/cudnn_rnn.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/cudnn_rnn.cc b/src/model/layer/cudnn_rnn.cc
index fb5fee0..28a52c5 100644
--- a/src/model/layer/cudnn_rnn.cc
+++ b/src/model/layer/cudnn_rnn.cc
@@ -125,8 +125,8 @@ void CudnnRNN::SetRNNDescriptor(shared_ptr<Device> dev) {
   CUDNN_CHECK(cudnnDropoutGetStatesSize(ctx->cudnn_handle, &state_size));
   dropout_state_ = Tensor(Shape{state_size}, dev, kChar);
   CUDNN_CHECK(cudnnSetDropoutDescriptor(
-      dropout_desc_, ctx->cudnn_handle, 1 - dropout_,  // keep probability
-      dropout_state_.block()->mutable_data(), state_size, seed_));
+                dropout_desc_, ctx->cudnn_handle, 1 - dropout_,  // keep probability
+                dropout_state_.block()->mutable_data(), state_size, seed_));
 
   CUDNN_CHECK(cudnnCreateRNNDescriptor(&rnn_desc_));
   cudnnRNNInputMode_t input_mode = CUDNN_LINEAR_INPUT;
@@ -144,10 +144,15 @@ void CudnnRNN::SetRNNDescriptor(shared_ptr<Device> dev) {
     rnn_mode = CUDNN_RNN_TANH;
   else if (rnn_mode_ == "gru")
     rnn_mode = CUDNN_GRU;
+#if CUDNN_MAJOR == 5
+  CUDNN_CHECK(cudnnSetRNNDescriptor(rnn_desc_, hidden_size_, num_stacks_,
+                                    dropout_desc_, input_mode, direction,
+                                    rnn_mode, dtype_));
+#else
   CUDNN_CHECK(cudnnSetRNNDescriptor(ctx->cudnn_handle, rnn_desc_, hidden_size_, num_stacks_,
                                     dropout_desc_, input_mode, direction,
                                     rnn_mode, CUDNN_RNN_ALGO_STANDARD, dtype_));
-
+#endif
   size_t weight_size;
   CUDNN_CHECK(cudnnGetRNNParamsSize(ctx->cudnn_handle, rnn_desc_, x_descs_[0],
                                     &weight_size, dtype_));
@@ -199,7 +204,7 @@ void CudnnRNN::UpdateSpaces(size_t seq_length, shared_ptr<Device> dev) {
   }
 
   CUDNN_CHECK(cudnnGetRNNTrainingReserveSize(ctx->cudnn_handle, rnn_desc_,
-                                             seq_length, x_descs_, &count));
+              seq_length, x_descs_, &count));
   if (reserve_space_.Size() != count) {
     reserve_space_ = Tensor(Shape{count}, dev, kChar);
     // reserve_space_.SetValue(0);
@@ -263,8 +268,8 @@ const vector<Tensor> CudnnRNN::Forward(int flag, const vector<Tensor> &inputs) {
 
   if (rnn_desc_ != nullptr)
     CHECK_EQ(dtype_, GetCudnnDataType(dtype))
-      << "Cannot change cudnn data type during training from " << dtype_
-      << " to " << GetCudnnDataType(dtype);
+        << "Cannot change cudnn data type during training from " << dtype_
+        << " to " << GetCudnnDataType(dtype);
   else
     dtype_ = GetCudnnDataType(dtype);
 
@@ -303,57 +308,57 @@ const vector<Tensor> CudnnRNN::Forward(int flag, const vector<Tensor> &inputs) {
   // LOG(INFO) << "hidden size " << hy.Size();
   // LOG(INFO) << "weight size " << weight_.Size() << " value " << weight_.L1();
   Block *inb = input.block(), *outb = output.block(),
-        *wb = this->weight_.block(), *hxb = hx.block(), *cxb = cx.block(),
-        *hyb = hy.block(), *cyb = cy.block(),
-        *wspace = this->workspace_.block(),
-        *rspace = this->reserve_space_.block();
+         *wb = this->weight_.block(), *hxb = hx.block(), *cxb = cx.block(),
+          *hyb = hy.block(), *cyb = cy.block(),
+           *wspace = this->workspace_.block(),
+            *rspace = this->reserve_space_.block();
   if (flag & kTrain) {
     CHECK_EQ(reserve_space_.device()->lang(), kCuda);
     CHECK_EQ(did, reserve_space_.device()->id());
     dev->Exec(
-        [inb, outb, wb, hxb, cxb, hyb, cyb, wspace, rspace, this](Context *ctx) {
-        // clang-format off
-        cudnnRNNForwardTraining(
-            ctx->cudnn_handle,
-            this->rnn_desc_,
-            this->seq_length_,
-            this->x_descs_, inb->data(),
-            this->hx_desc_, hxb == nullptr ? nullptr : hxb->data(),
-            this->cx_desc_, cxb == nullptr ? nullptr : cxb->data(),
-            this->weight_desc_, wb->data(),
-            this->y_descs_, outb->mutable_data(),
-            this->hy_desc_, hyb->mutable_data(),
-            this->cy_desc_, cyb == nullptr ? nullptr : cyb->mutable_data(),
-            wspace->mutable_data(),
-            this->workspace_.Size(), rspace->mutable_data(),
-            this->reserve_space_.Size());
-        // clang-format on
-        },
-        {inb, wb, hxb, cxb}, {outb, hyb, cyb, wspace, rspace});
+    [inb, outb, wb, hxb, cxb, hyb, cyb, wspace, rspace, this](Context * ctx) {
+      // clang-format off
+      cudnnRNNForwardTraining(
+        ctx->cudnn_handle,
+        this->rnn_desc_,
+        this->seq_length_,
+        this->x_descs_, inb->data(),
+        this->hx_desc_, hxb == nullptr ? nullptr : hxb->data(),
+        this->cx_desc_, cxb == nullptr ? nullptr : cxb->data(),
+        this->weight_desc_, wb->data(),
+        this->y_descs_, outb->mutable_data(),
+        this->hy_desc_, hyb->mutable_data(),
+        this->cy_desc_, cyb == nullptr ? nullptr : cyb->mutable_data(),
+        wspace->mutable_data(),
+        this->workspace_.Size(), rspace->mutable_data(),
+        this->reserve_space_.Size());
+      // clang-format on
+    },
+    {inb, wb, hxb, cxb}, {outb, hyb, cyb, wspace, rspace});
     buf_.push(input);
     buf_.push(output);
     buf_.push(hx);
     buf_.push(cx);
   } else {
-    dev->Exec([inb, outb, wb, hxb, cxb, hyb, cyb, wspace, this](Context *ctx) {
+    dev->Exec([inb, outb, wb, hxb, cxb, hyb, cyb, wspace, this](Context * ctx) {
       // clang-format off
       cudnnRNNForwardInference(
-          ctx->cudnn_handle,
-          this->rnn_desc_,
-          this->seq_length_,
-          this->x_descs_, inb->data(),
-          this->hx_desc_, hxb == nullptr ? nullptr : hxb->data(),
-          this->cx_desc_, cxb == nullptr ? nullptr : cxb->data(),
-          this->weight_desc_, wb->data(),
-          this->y_descs_, outb->mutable_data(),
-          this->hy_desc_, hyb->mutable_data(),
-          this->cy_desc_, cyb == nullptr ? nullptr : cyb->mutable_data(),
-          wspace->mutable_data(), this->workspace_.Size());
+        ctx->cudnn_handle,
+        this->rnn_desc_,
+        this->seq_length_,
+        this->x_descs_, inb->data(),
+        this->hx_desc_, hxb == nullptr ? nullptr : hxb->data(),
+        this->cx_desc_, cxb == nullptr ? nullptr : cxb->data(),
+        this->weight_desc_, wb->data(),
+        this->y_descs_, outb->mutable_data(),
+        this->hy_desc_, hyb->mutable_data(),
+        this->cy_desc_, cyb == nullptr ? nullptr : cyb->mutable_data(),
+        wspace->mutable_data(), this->workspace_.Size());
       // clang-format on
     }, {inb, wb, hxb, cxb}, {outb, hyb, cyb, wspace});
   }
   auto outputs =
-      SplitOutput(num_x, hidden_size_ * num_directions_, inputs, output);
+    SplitOutput(num_x, hidden_size_ * num_directions_, inputs, output);
   outputs.push_back(hy);
   if (has_cell_) outputs.push_back(cy);
   return outputs;
@@ -361,7 +366,7 @@ const vector<Tensor> CudnnRNN::Forward(int flag, const vector<Tensor> &inputs) {
 
 // TODO(wangwei) check Tensor device to be on cuda?
 const std::pair<vector<Tensor>, vector<Tensor>> CudnnRNN::Backward(
-    int flag, const vector<Tensor> &grads) {
+int flag, const vector<Tensor> &grads) {
   // dhy (and dcy) is at last
   const Tensor cx = buf_.top();  // cannot use const Tensor& due to pop()
   buf_.pop();
@@ -395,45 +400,45 @@ const std::pair<vector<Tensor>, vector<Tensor>> CudnnRNN::Backward(
     dcx.ResetLike(dhx);
   dw.SetValue(0.0f);
   Block *yb = y.block(), *dyb = dy.block(), *dhyb = dhy.block(),
-        *dcyb = dcy.block(), *xb = x.block(), *cxb = cx.block(),
-        *wb = weight_.block(), *dwb = dw.block(), *hxb = hx.block(),
-        *dxb = dx.block(), *dhxb = dhx.block(), *dcxb = dcx.block(),
-        *wspace = workspace_.block(), *rspace = reserve_space_.block();
+         *dcyb = dcy.block(), *xb = x.block(), *cxb = cx.block(),
+          *wb = weight_.block(), *dwb = dw.block(), *hxb = hx.block(),
+           *dxb = dx.block(), *dhxb = dhx.block(), *dcxb = dcx.block(),
+            *wspace = workspace_.block(), *rspace = reserve_space_.block();
 
   y.device()->Exec(
-      [yb, dyb, dhyb, dcyb, xb, cxb, wb, dwb, hxb, dxb, dhxb, dcxb, wspace,
-       rspace, this](Context *ctx) {
-        // clang-format off
-        cudnnRNNBackwardData(
-            ctx->cudnn_handle,
-            this->rnn_desc_,
-            this->seq_length_,
-            this->y_descs_, yb->data(),
-            this->dy_descs_, dyb->data(),
-            this->dhy_desc_, dhyb == nullptr ? nullptr : dhyb->data(),
-            this->dcy_desc_, dcyb == nullptr ? nullptr : dcyb->data(),
-            this->weight_desc_, wb->data(),
-            this->hx_desc_, hxb == nullptr ? nullptr : hxb->data(),
-            this->cx_desc_, cxb == nullptr ? nullptr : cxb->data(),
-            this->dx_descs_, dxb->mutable_data(),
-            this->dhx_desc_, dhxb->mutable_data(),
-            this->dcx_desc_, dcxb == nullptr ? nullptr : dcxb->mutable_data(),
-            wspace->mutable_data(), this->workspace_.Size(),
-            rspace->mutable_data(), this->reserve_space_.Size());
-        cudnnRNNBackwardWeights(
-            ctx->cudnn_handle,
-            this->rnn_desc_,
-            this->seq_length_,
-            this->x_descs_, xb->data(),
-            this->hx_desc_, hxb == nullptr ? nullptr : hxb->data(),
-            this->y_descs_, yb->data(),
-            wspace->data(), this->workspace_.Size(),
-            this->dweight_desc_, dwb->mutable_data(),
-            rspace->data(), this->reserve_space_.Size());
-        // clang-format on
-      },
-      {yb, dyb, dhyb, dcyb, xb, wb, wspace, rspace},
-      {dxb, dwb, dhxb, dcxb, wspace, rspace});
+    [yb, dyb, dhyb, dcyb, xb, cxb, wb, dwb, hxb, dxb, dhxb, dcxb, wspace,
+  rspace, this](Context * ctx) {
+    // clang-format off
+    cudnnRNNBackwardData(
+      ctx->cudnn_handle,
+      this->rnn_desc_,
+      this->seq_length_,
+      this->y_descs_, yb->data(),
+      this->dy_descs_, dyb->data(),
+      this->dhy_desc_, dhyb == nullptr ? nullptr : dhyb->data(),
+      this->dcy_desc_, dcyb == nullptr ? nullptr : dcyb->data(),
+      this->weight_desc_, wb->data(),
+      this->hx_desc_, hxb == nullptr ? nullptr : hxb->data(),
+      this->cx_desc_, cxb == nullptr ? nullptr : cxb->data(),
+      this->dx_descs_, dxb->mutable_data(),
+      this->dhx_desc_, dhxb->mutable_data(),
+      this->dcx_desc_, dcxb == nullptr ? nullptr : dcxb->mutable_data(),
+      wspace->mutable_data(), this->workspace_.Size(),
+      rspace->mutable_data(), this->reserve_space_.Size());
+    cudnnRNNBackwardWeights(
+      ctx->cudnn_handle,
+      this->rnn_desc_,
+      this->seq_length_,
+      this->x_descs_, xb->data(),
+      this->hx_desc_, hxb == nullptr ? nullptr : hxb->data(),
+      this->y_descs_, yb->data(),
+      wspace->data(), this->workspace_.Size(),
+      this->dweight_desc_, dwb->mutable_data(),
+      rspace->data(), this->reserve_space_.Size());
+    // clang-format on
+  },
+  {yb, dyb, dhyb, dcyb, xb, wb, wspace, rspace},
+  {dxb, dwb, dhxb, dcxb, wspace, rspace});
 
   vector <Tensor> param_grad{dw};
   auto data_grads = SplitOutput(num_dy, input_size_, grads, dx);
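
For the RNN calls above, cuDNN distinguishes two buffers: the workspace is
per-call scratch, while the reserve space written by cudnnRNNForwardTraining
must stay alive until cudnnRNNBackwardData/cudnnRNNBackwardWeights have
consumed it, which is why the layer keeps reserve_space_ as a member. A
minimal sizing sketch, assuming rnn_desc_, x_descs_ and seq_length are set up
as in UpdateSpaces:

    size_t ws_bytes = 0, rs_bytes = 0;
    CUDNN_CHECK(cudnnGetRNNWorkspaceSize(handle, rnn_desc_, seq_length,
                                         x_descs_, &ws_bytes));
    CUDNN_CHECK(cudnnGetRNNTrainingReserveSize(handle, rnn_desc_, seq_length,
                                               x_descs_, &rs_bytes));
    // ws_bytes of scratch may be reused across calls; the rs_bytes buffer
    // must persist from forward-training until backward finishes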


[2/2] incubator-singa git commit: SINGA-371 Implement functional operations in c++ for autograd

Posted by wa...@apache.org.
SINGA-371 Implement functional operations in c++ for autograd

fix some bugs and update the example for autograd
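
The functional interface exported below (see the model_operation.i diff)
pairs a handle object with stateless forward/backward calls. A minimal C++
usage sketch built only from the declarations in this commit; the
kernel/stride/padding values and tensor setup are illustrative assumptions:

    using singa::Tensor;

    // x: NCHW input on a CUDA device; W, b: initialized parameters
    Tensor ConvOnce(const Tensor &x, const Tensor &W, const Tensor &b) {
      singa::CudnnConvHandle h(x, {3, 3} /*kernel*/, {1, 1} /*stride*/,
                               {1, 1} /*padding*/, /*in_channels=*/1,
                               /*out_channels=*/32, /*bias=*/true);
      return singa::GpuConvForward(x, W, b, h);
    }

As the autograd.py change below shows, the Python layer recreates the handle
whenever the batch size changes, since the handle caches shape-dependent
descriptors.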


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/e16cea12
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/e16cea12
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/e16cea12

Branch: refs/heads/master
Commit: e16cea129b688c804afe87b3bd1b6a82e5f5ca5f
Parents: e209203
Author: wang wei <wa...@comp.nus.edu.sg>
Authored: Sat Jul 7 22:00:07 2018 +0800
Committer: wang wei <wa...@comp.nus.edu.sg>
Committed: Sun Jul 8 16:01:45 2018 +0800

----------------------------------------------------------------------
 examples/autograd/mnist_cnn.py       |  25 ++++---
 python/singa/autograd.py             |   6 +-
 src/api/model_operation.i            |  29 ++++----
 src/core/tensor/tensor.cc            |   4 +-
 src/core/tensor/tensor_math_cuda.h   | 114 +++++++++++++++++-------------
 src/model/layer/cudnn_convolution.cc |   2 +-
 src/model/operation/convolution.cc   |   9 ++-
 src/model/operation/convolution.h    |   1 +
 tool/conda/singa/build.sh            |   3 +-
 tool/conda/singa/meta.yaml           |   2 +-
 10 files changed, 112 insertions(+), 83 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e16cea12/examples/autograd/mnist_cnn.py
----------------------------------------------------------------------
diff --git a/examples/autograd/mnist_cnn.py b/examples/autograd/mnist_cnn.py
index a82f64c..5b4e608 100755
--- a/examples/autograd/mnist_cnn.py
+++ b/examples/autograd/mnist_cnn.py
@@ -21,13 +21,11 @@ import numpy as np
 import argparse
 import os
 
-import singa
+from singa import device
 from singa import tensor
 from singa import autograd
 from singa import optimizer
 
-singa.layer.engine = 'singacpp'
-
 
 def load_data(path):
     f = np.load(path)
@@ -75,11 +73,18 @@ if __name__ == '__main__':
 
     parser = argparse.ArgumentParser(description='Train CNN over MNIST')
     parser.add_argument('file_path', type=str, help='the dataset path')
+    parser.add_argument('--use_cpu', action='store_true')
     args = parser.parse_args()
 
     assert os.path.exists(args.file_path), \
-        'Pls download the MNIST dataset from ' \
-        'https://github.com/mnielsen/neural-networks-and-deep-learning/raw/master/data/mnist.pkl.gz'
+        'Pls download the MNIST dataset from '
+
+    if args.use_cpu:
+        print('Using CPU')
+        dev = device.get_default_device()
+    else:
+        print('Using GPU')
+        dev = device.create_cuda_gpu()
 
     train, test = load_data(args.file_path)
 
@@ -119,16 +124,16 @@ if __name__ == '__main__':
     autograd.training = True
     for epoch in range(epochs):
         for i in range(batch_number):
-            inputs = tensor.Tensor(data=x_train[i * 100:(1 + i) * 100, :])
-            targets = tensor.Tensor(data=y_train[i * 100:(1 + i) * 100, :])
+            inputs = tensor.Tensor(device=dev, data=x_train[i * 100:(1 + i) * 100])
+            targets = tensor.Tensor(device=dev, data=y_train[i * 100:(1 + i) * 100])
 
             loss, y = forward(inputs, targets)
 
-            accuracy_rate = accuracy(autograd.ctensor2numpy(y.data),
-                                     autograd.ctensor2numpy(targets.data))
+            accuracy_rate = accuracy(tensor.to_numpy(y),
+                                     tensor.to_numpy(targets))
             if (i % 5 == 0):
                 print('accuracy is:', accuracy_rate, 'loss is:',
-                      autograd.ctensor2numpy(loss.data)[0])
+                      tensor.to_numpy(loss)[0])
 
             in_grads = autograd.backward(loss)
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e16cea12/python/singa/autograd.py
----------------------------------------------------------------------
diff --git a/python/singa/autograd.py b/python/singa/autograd.py
index 80209ff..9fd8b4d 100755
--- a/python/singa/autograd.py
+++ b/python/singa/autograd.py
@@ -741,12 +741,10 @@ class Conv2D(NewLayer):
         else:
             if not hasattr(self, 'handle'):
                 self.handle = singa.CudnnConvHandle(x.data, self.kernel_size, self.stride,
-                                                    self.padding, self.in_channels, self.out_channels, self.bias,
-                                                    self.inner_params['workspace_MB_limit'] * 1024 * 1024, self.inner_params['cudnn_prefer'])
+                                                    self.padding, self.in_channels, self.out_channels, self.bias)
             elif x.shape[0] != self.handle.batchsize:
                 self.handle = singa.CudnnConvHandle(x.data, self.kernel_size, self.stride,
-                                                    self.padding, self.in_channels, self.out_channels, self.bias,
-                                                    self.inner_params['workspace_MB_limit'] * 1024 * 1024, self.inner_params['cudnn_prefer'])
+                                                    self.padding, self.in_channels, self.out_channels, self.bias)
         self.handle.device_id = x.device.id()
 
         y = conv2d(x, self.W, self.b, self.handle)

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e16cea12/src/api/model_operation.i
----------------------------------------------------------------------
diff --git a/src/api/model_operation.i b/src/api/model_operation.i
index 2c13a3b..3858a2b 100755
--- a/src/api/model_operation.i
+++ b/src/api/model_operation.i
@@ -1,9 +1,12 @@
 %module model_operation
 
+%include "config.i"
+%include "std_vector.i"
+%include "std_string.i"
 %{
 #include "../src/model/operation/convolution.h"
 %}
-namespace singa{
+namespace singa {
 
 class ConvHandle {
  public:
@@ -15,15 +18,24 @@ class ConvHandle {
   size_t batchsize;
 };
 
-struct CudnnConvHandle{
+Tensor CpuConvForward(const Tensor &x, Tensor &W,  Tensor &b, const ConvHandle &ch);
+
+Tensor CpuConvBackwardx(const Tensor &dy, Tensor &W, const Tensor &x, const ConvHandle &ch);
+
+Tensor CpuConvBackwardW(const Tensor &dy, const Tensor &x, const Tensor &W, const ConvHandle &ch);
+
+Tensor CpuConvBackwardb(const Tensor &dy, const Tensor &b, const ConvHandle &ch);
+
+#if USE_CUDNN
+class CudnnConvHandle: public ConvHandle {
  public:
-	CudnnConvHandle(const Tensor &input, const std::vector<size_t>& kernel_size,
+  CudnnConvHandle(const Tensor &input, const std::vector<size_t>& kernel_size,
                   const std::vector<size_t>& stride, const std::vector<size_t>& padding,
                   const size_t in_channels, const size_t out_channels,
                   const bool bias, const size_t workspace_byte_limit = 1024 * 1024 * 1024,
                   const std::string& prefer = "fastest");
   bool bias_term;
-  size_t batchsize; 
+  size_t batchsize;
 };
 
 Tensor GpuConvForward(const Tensor &x, const Tensor &W, const Tensor &b, const CudnnConvHandle &cch);
@@ -34,13 +46,6 @@ Tensor GpuConvBackwardW(const Tensor &dy, const Tensor &x, const Tensor &W, cons
 
 Tensor GpuConvBackwardb(const Tensor &dy, const Tensor &b, const CudnnConvHandle &cch);
 
-
-Tensor CpuConvForward(const Tensor &x, Tensor &W,  Tensor &b, const ConvHandle &ch);
-
-Tensor CpuConvBackwardx(const Tensor &dy, Tensor &W, const Tensor &x, const ConvHandle &ch);
-
-Tensor CpuConvBackwardW(const Tensor &dy, const Tensor &x, const Tensor &W, const ConvHandle &ch);
-
-Tensor CpuConvBackwardb(const Tensor &dy, const Tensor &b, const ConvHandle &ch);
+#endif  // USE_CUDNN
 
 }

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e16cea12/src/core/tensor/tensor.cc
----------------------------------------------------------------------
diff --git a/src/core/tensor/tensor.cc b/src/core/tensor/tensor.cc
index e0a9ecb..05db7cf 100644
--- a/src/core/tensor/tensor.cc
+++ b/src/core/tensor/tensor.cc
@@ -346,7 +346,7 @@ Tensor Tensor::Repeat(vector<size_t> repeats, int axis, std::shared_ptr<Device>
   } else {
     if (repeats.size() == 1){
       total_repeats = repeats[0];
-      for (int i = 0; i < shape_.size(); i++) {
+      for (size_t i = 0; i < shape_.size(); i++) {
         if (i == axis) {
           tshape.push_back(shape_[i] * total_repeats);
         } else {
@@ -363,7 +363,7 @@ Tensor Tensor::Repeat(vector<size_t> repeats, int axis, std::shared_ptr<Device>
         }
         total_repeats += repeats[i];
       }
-      for (int i = 0; i < shape_.size(); i++){
+      for (size_t i = 0; i < shape_.size(); i++){
         if (i == axis) {
           tshape.push_back(total_repeats);
         } else{

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e16cea12/src/core/tensor/tensor_math_cuda.h
----------------------------------------------------------------------
diff --git a/src/core/tensor/tensor_math_cuda.h b/src/core/tensor/tensor_math_cuda.h
index a1b9381..2a43468 100644
--- a/src/core/tensor/tensor_math_cuda.h
+++ b/src/core/tensor/tensor_math_cuda.h
@@ -791,6 +791,12 @@ void Sqrt<float, lang::Cuda>(const Tensor& in, Tensor* out,
   const float* inPtr = static_cast<const float*>(in.block()->data());
   float* outPtr = static_cast<float*>(out->block()->mutable_data());
 
+#if CUDNN_MAJOR < 7
+  size_t num = in.Size();
+  cuda::sqrt(num, inPtr, outPtr, ctx->stream);
+
+#else
+
   float alpha1 = 1.0;
   float alpha2 = 0.0;
   float beta = 0.0;
@@ -800,6 +806,7 @@ void Sqrt<float, lang::Cuda>(const Tensor& in, Tensor* out,
                 (void*)(&alpha2), in_desc, inPtr,
                 (void*)(&beta), generate_tensor_nd_desc(*out), outPtr
                ));
+#endif  // CUDNN_MAJOR < 7
 }
 
 /// Element-wise operation, out[i]=in[i]^2
@@ -833,54 +840,6 @@ void Square<float, lang::Cuda>(const Tensor& in, Tensor* out,
 //   // cuda::sum(num, inPtr, out, ctx->stream);
 // }
 
-template <>
-void Sum<float, lang::Cuda>(const Tensor& in, float* out,
-                            Context* ctx) {
-  const float* inPtr = static_cast<const float*>(in.block()->data());
-
-  //reduce all axes to 1 for cudnnReduce, e.g. Tensor A with shape (2,4) will be reduced to (1)
-  Shape reduced_shape = {1};
-  Tensor t(reduced_shape, in.device(), in.data_type());
-  float* tPtr = static_cast<float*>(t.block()->mutable_data());
-  vector<int> reduce_all_axes = generate_shape_cuda(in);
-  for (size_t n = 0; n < reduce_all_axes.size(); ++n) {
-    reduce_all_axes[n] = 1;
-  }
-
-  //reduce_desc
-  cudnnReduceTensorDescriptor_t reduce_desc;
-  cudnnReduceTensorOp_t reduce_op = CUDNN_REDUCE_TENSOR_ADD;
-  cudnnDataType_t cudnn_dtype = CUDNN_DATA_FLOAT;
-  cudnnNanPropagation_t cudnn_propagation = CUDNN_PROPAGATE_NAN;
-  cudnnReduceTensorIndices_t cudnn_indices = CUDNN_REDUCE_TENSOR_NO_INDICES;
-  cudnnIndicesType_t cudnn_indices_type = CUDNN_32BIT_INDICES;
-  check_cudnn(cudnnCreateReduceTensorDescriptor(&reduce_desc));
-  check_cudnn(cudnnSetReduceTensorDescriptor(reduce_desc, reduce_op, cudnn_dtype,
-                                 cudnn_propagation, cudnn_indices, cudnn_indices_type));
-
-  //instantiate 2 new tensors to use new blocks as memory instead of cudaMalloc
-  size_t reduction_size_int = Product(in.shape());
-  Shape reduction_size = {reduction_size_int * 100};
-  Tensor indices(reduction_size, in.device(), in.data_type());
-  Tensor workspace(reduction_size, in.device(), in.data_type());
-  size_t indices_bytes = indices.block()->size() * 100;
-  size_t workspace_bytes = workspace.block()->size() * 100;
-  size_t* indicesPtr = static_cast<size_t*>(indices.block()->mutable_data());
-  float* workspacePtr = static_cast<float*>(workspace.block()->mutable_data());
-  //void* indicesPtr{nullptr}; void* workspacePtr{nullptr};
-  //cudaMalloc(&indicesPtr, indices_bytes); cudaMalloc(&workspacePtr, workspace_bytes);
-
-  float alpha = 1.0;
-  float beta = 0.0;
-  check_cudnn(cudnnReduceTensor(ctx->cudnn_handle, reduce_desc,
-                    indicesPtr, indices_bytes, workspacePtr, workspace_bytes,
-                    (void*)(&alpha), generate_tensor_nd_desc(in), inPtr,
-                    (void*)(&beta), generate_tensor_nd_desc(t), tPtr
-                   ));
-
-  *out = tPtr[0];
-}
-
 
 /// Element-wise operation, out[i]=tanh([in[i])
 // template <>
@@ -949,7 +908,7 @@ void Transform<float, lang::Cuda>(const Tensor& in, Tensor* out,
                          (void*)(&alpha), generate_tensor_nd_desc(in), inPtr,
                          (void*)(&beta), generate_tensor_nd_desc(*out), outPtr
                         ));
-  
+
 }
 
 // ================Random functions===========================================
@@ -1233,6 +1192,63 @@ void RowMax<float, lang::Cuda>(const Tensor& in, Tensor* out,
   }
 }
 
+
+// must put this function after Set and Dot functions due to the error from
+// instantiation before specialization
+template <>
+void Sum<float, lang::Cuda>(const Tensor& in, float* out,
+                            Context* ctx) {
+#if CUDNN_MAJOR < 7
+  Tensor one(in.shape(), in.device(), in.data_type());
+  Set<float, lang::Cuda>(float(1), &one, ctx);
+  Dot<float, lang::Cuda>(in, one, out, ctx);
+#else
+  const float* inPtr = static_cast<const float*>(in.block()->data());
+  //reduce all axes to 1 for cudnnReduce, e.g. Tensor A with shape (2,4) will be reduced to (1)
+  Shape reduced_shape = {1};
+  Tensor t(reduced_shape, in.device(), in.data_type());
+  float* tPtr = static_cast<float*>(t.block()->mutable_data());
+  vector<int> reduce_all_axes = generate_shape_cuda(in);
+  for (size_t n = 0; n < reduce_all_axes.size(); ++n) {
+    reduce_all_axes[n] = 1;
+  }
+
+  //reduce_desc
+  cudnnReduceTensorDescriptor_t reduce_desc;
+  cudnnReduceTensorOp_t reduce_op = CUDNN_REDUCE_TENSOR_ADD;
+  cudnnDataType_t cudnn_dtype = CUDNN_DATA_FLOAT;
+  cudnnNanPropagation_t cudnn_propagation = CUDNN_PROPAGATE_NAN;
+  cudnnReduceTensorIndices_t cudnn_indices = CUDNN_REDUCE_TENSOR_NO_INDICES;
+  cudnnIndicesType_t cudnn_indices_type = CUDNN_32BIT_INDICES;
+  check_cudnn(cudnnCreateReduceTensorDescriptor(&reduce_desc));
+  check_cudnn(cudnnSetReduceTensorDescriptor(reduce_desc, reduce_op, cudnn_dtype,
+                                 cudnn_propagation, cudnn_indices, cudnn_indices_type));
+
+  //instantiate 2 new tensors to use new blocks as memory instead of cudaMalloc
+  size_t reduction_size_int = Product(in.shape());
+  Shape reduction_size = {reduction_size_int * 100};
+  Tensor indices(reduction_size, in.device(), in.data_type());
+  Tensor workspace(reduction_size, in.device(), in.data_type());
+  size_t indices_bytes = indices.block()->size() * 100;
+  size_t workspace_bytes = workspace.block()->size() * 100;
+  size_t* indicesPtr = static_cast<size_t*>(indices.block()->mutable_data());
+  float* workspacePtr = static_cast<float*>(workspace.block()->mutable_data());
+  //void* indicesPtr{nullptr}; void* workspacePtr{nullptr};
+  //cudaMalloc(&indicesPtr, indices_bytes); cudaMalloc(&workspacePtr, workspace_bytes);
+
+  float alpha = 1.0;
+  float beta = 0.0;
+  check_cudnn(cudnnReduceTensor(ctx->cudnn_handle, reduce_desc,
+                    indicesPtr, indices_bytes, workspacePtr, workspace_bytes,
+                    (void*)(&alpha), generate_tensor_nd_desc(in), inPtr,
+                    (void*)(&beta), generate_tensor_nd_desc(t), tPtr
+                   ));
+
+  *out = tPtr[0];
+#endif  // CUDNN_MAJOR < 7
+}
+
+
 }  // namespace singa
 
 #endif  // USE_CUDA
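
The relocated Sum above also gains a pre-v7 fallback: where cudnnReduceTensor
is unavailable, sum(x) is computed as the dot product of x with an all-ones
tensor via SINGA's Set and Dot kernels. A hypothetical standalone sketch of
the same identity using cuBLAS directly (d_ones is assumed to be a device
buffer pre-filled with 1.0f):

    #include <cublas_v2.h>

    float GpuSum(cublasHandle_t handle, const float *d_x,
                 const float *d_ones, int n) {
      float result = 0.f;
      cublasSdot(handle, n, d_x, 1, d_ones, 1, &result);  // sum(x) = <x, 1>
      return result;
    }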

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e16cea12/src/model/layer/cudnn_convolution.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/cudnn_convolution.cc b/src/model/layer/cudnn_convolution.cc
index 1b12f93..0aed832 100644
--- a/src/model/layer/cudnn_convolution.cc
+++ b/src/model/layer/cudnn_convolution.cc
@@ -79,7 +79,7 @@ void CudnnConvolution::InitCudnn(const Tensor &input) {
   CUDNN_CHECK(cudnnSetConvolution2dDescriptor(conv_desc_, pad_h_, pad_w_,
               stride_h_, stride_w_, 1, 1,  // dilation x and y
               CUDNN_CROSS_CORRELATION
-#if CUDNN_MAJOR == 5
+#if CUDNN_MAJOR >= 7
               , GetCudnnDataType(dtype)
 #endif  // CUDNN_MAJOR
                                              ));

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e16cea12/src/model/operation/convolution.cc
----------------------------------------------------------------------
diff --git a/src/model/operation/convolution.cc b/src/model/operation/convolution.cc
index e36df43..f700203 100755
--- a/src/model/operation/convolution.cc
+++ b/src/model/operation/convolution.cc
@@ -199,8 +199,11 @@ CudnnConvHandle::CudnnConvHandle(const Tensor &input, const std::vector<size_t>&
                                            num_filters, 1, 1));
   CUDNN_CHECK(cudnnSetConvolution2dDescriptor(conv_desc, pad_h, pad_w,
               stride_h, stride_w, 1, 1,
-              CUDNN_CROSS_CORRELATION,
-              GetCudnnDataType(dtype)));
+              CUDNN_CROSS_CORRELATION
+#if CUDNN_MAJOR >= 7
+              , GetCudnnDataType(dtype)
+#endif
+              ));
   CUDNN_CHECK(cudnnSetFilter4dDescriptor(filter_desc, GetCudnnDataType(dtype),
                                          CUDNN_TENSOR_NCHW, num_filters,
                                          channels, kernel_h, kernel_w));
@@ -381,4 +384,4 @@ Tensor GpuConvBackwardb(const Tensor &dy, const Tensor &b, const CudnnConvHandle
 }
 #endif  // USE_CUDNN
 
-}  // namespace singa
\ No newline at end of file
+}  // namespace singa

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e16cea12/src/model/operation/convolution.h
----------------------------------------------------------------------
diff --git a/src/model/operation/convolution.h b/src/model/operation/convolution.h
index 62ff254..9da881f 100755
--- a/src/model/operation/convolution.h
+++ b/src/model/operation/convolution.h
@@ -5,6 +5,7 @@
 #include <vector>
 #include "singa/core/tensor.h"
 #include "singa/utils/logging.h"
+#include "singa/singa_config.h"
 
 #ifdef USE_CUDNN
 #include <cudnn.h>
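
The new include matters because this header guards its cuDNN section with
#ifdef USE_CUDNN: unless singa_config.h, the generated header that defines
the build flags, is included first, the macro is undefined and the cuDNN
declarations silently compile out. A minimal sketch of the pattern:

    #include "singa/singa_config.h"  // defines USE_CUDNN / USE_CUDA at build time
    #ifdef USE_CUDNN
    #include <cudnn.h>               // only pulled in for cuDNN builds
    #endif  // USE_CUDNN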

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e16cea12/tool/conda/singa/build.sh
----------------------------------------------------------------------
diff --git a/tool/conda/singa/build.sh b/tool/conda/singa/build.sh
index 91a2f3b..b54e451 100644
--- a/tool/conda/singa/build.sh
+++ b/tool/conda/singa/build.sh
@@ -23,12 +23,13 @@ export CMAKE_PREFIX_PATH=$PREFIX:$CMAKE_PREFIX_PATH
 export CMAKE_INCLUDE_PATH=$PREFIX/include:$CMAKE_INCLUDE_PATH
 export CMAKE_LIBRARY_PATH=$PREFIX/lib:$CMAKE_LIBRARY_PATH
 
+echo "----------------------$CUDNN_PATH---------------"
 
 if [ -z ${CUDNN_PATH+x} ]; then
 	USE_CUDA=OFF
 else
 	USE_CUDA=ON
-	cp -r $CUDNN_PATH/include $PREFIX/include 
+	cp $CUDNN_PATH/include/* $PREFIX/include/ 
 	cp -P $CUDNN_PATH/lib64/libcudnn.so* $PREFIX/lib/
 fi
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e16cea12/tool/conda/singa/meta.yaml
----------------------------------------------------------------------
diff --git a/tool/conda/singa/meta.yaml b/tool/conda/singa/meta.yaml
index 997341c..ee76636 100644
--- a/tool/conda/singa/meta.yaml
+++ b/tool/conda/singa/meta.yaml
@@ -22,7 +22,7 @@ package:
   version: "{{ GIT_DESCRIBE_TAG }}"
 
 source:
-  git_url: https://github.com/apache/incubator-singa.git
+  path: /home/wangwei/incubator-singa/
 
 build:
   number: 0