Posted to dev@singa.apache.org by GitBox <gi...@apache.org> on 2020/06/09 15:55:37 UTC

[GitHub] [singa] XJDKC commented on a change in pull request #722: cudnn lstm

XJDKC commented on a change in pull request #722:
URL: https://github.com/apache/singa/pull/722#discussion_r436745362



##########
File path: examples/qabot/model.py
##########
@@ -0,0 +1,79 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+from singa import autograd
+from singa import layer
+from singa import model
+from singa import tensor
+from singa import device
+
+class QAModel(model.Model):
+    def __init__(self, hidden_size, num_layers=1, rnn_mode="lstm", batch_first=True):
+        super(QAModel, self).__init__()
+        self.lstm_q = layer.CudnnRNN(hidden_size=hidden_size,
+                                   num_layers=num_layers,
+                                   bidirectional=True,
+                                   return_sequences=False,
+                                   rnn_mode=rnn_mode,
+                                   batch_first=batch_first)
+        self.lstm_a = layer.CudnnRNN(hidden_size=hidden_size,
+                                   num_layers=num_layers,
+                                   bidirectional=True,
+                                   return_sequences=False,
+                                   rnn_mode=rnn_mode,
+                                   batch_first=batch_first)
+
+    def forward(self, q, a_batch):
+        q = self.lstm_q(q) # BS, Hidden*2
+        a_batch = self.lstm_a(a_batch) # {2, hidden*2}
+        bs_a = int(a_batch.shape[0]/2) # cut concated a-a+ to half and half

Review comment:
       Does bs_a keep changing during training?

##########
File path: examples/qabot/model.py
##########
@@ -0,0 +1,79 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+from singa import autograd
+from singa import layer
+from singa import model
+from singa import tensor
+from singa import device
+
+class QAModel(model.Model):
+    def __init__(self, hidden_size, num_layers=1, rnn_mode="lstm", batch_first=True):
+        super(QAModel, self).__init__()
+        self.lstm_q = layer.CudnnRNN(hidden_size=hidden_size,
+                                   num_layers=num_layers,
+                                   bidirectional=True,
+                                   return_sequences=False,
+                                   rnn_mode=rnn_mode,
+                                   batch_first=batch_first)
+        self.lstm_a = layer.CudnnRNN(hidden_size=hidden_size,
+                                   num_layers=num_layers,
+                                   bidirectional=True,
+                                   return_sequences=False,
+                                   rnn_mode=rnn_mode,
+                                   batch_first=batch_first)
+
+    def forward(self, q, a_batch):
+        q = self.lstm_q(q) # BS, Hidden*2
+        a_batch = self.lstm_a(a_batch) # {2, hidden*2}
+        bs_a = int(a_batch.shape[0]/2) # cut concated a-a+ to half and half
+        a_pos, a_neg = autograd.split(a_batch, 0, [bs_a,bs_a])
+        sim_pos = autograd.cossim(q, a_pos)
+        sim_neg = autograd.cossim(q, a_neg)
+        return sim_pos, sim_neg
+
+    def train_one_batch(self, q, a):
+        out = self.forward(q, a)
+        loss = autograd.qa_lstm_loss(out[0], out[1])
+        self.optimizer.backward_and_update(loss)
+        return out, loss
+
+
+
+class MLP(model.Model):
+    def __init__(self):
+        super(MLP, self).__init__()
+        self.linear1 = layer.Linear(500)
+        self.relu = layer.ReLU()
+        self.linear2 = layer.Linear(2)
+
+    def forward(self, q, a):
+        q=autograd.reshape(q, (q.shape[0], -1))

Review comment:
       You can use the layers in layer.py (Reshape, Cat, and SoftmaxCrossEntropy); see the sketch below.
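
   A minimal sketch of what this could look like for the MLP model. The names layer.Reshape, layer.Cat, and layer.SoftmaxCrossEntropy come from the comment above; their call signatures are assumed here to mirror the corresponding autograd operators, and everything past the first reshape (the concatenation, the classifier wiring, the train_one_batch signature with a label t) is not in the quoted hunk, so it is illustrative only:

       from singa import layer
       from singa import model

       class MLP(model.Model):

           def __init__(self):
               super(MLP, self).__init__()
               # layer-level wrappers named in the review; constructor/call signatures assumed
               self.reshape_q = layer.Reshape()
               self.reshape_a = layer.Reshape()
               self.cat = layer.Cat()
               self.linear1 = layer.Linear(500)
               self.relu = layer.ReLU()
               self.linear2 = layer.Linear(2)
               self.loss = layer.SoftmaxCrossEntropy()

           def forward(self, q, a):
               # flatten both inputs, join them along the feature axis, then classify
               q = self.reshape_q(q, (q.shape[0], -1))
               a = self.reshape_a(a, (a.shape[0], -1))
               qa = self.cat([q, a], 1)
               return self.linear2(self.relu(self.linear1(qa)))

           def train_one_batch(self, q, a, t):
               out = self.forward(q, a)
               loss = self.loss(out, t)  # replaces a hand-written cross entropy
               self.optimizer.backward_and_update(loss)
               return out, loss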

##########
File path: src/model/operation/rnn.cc
##########
@@ -0,0 +1,407 @@
+/*********************************************************
+ *
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ *
+ ************************************************************/
+
+#include "rnn.h"
+namespace singa {
+#ifdef USE_CUDNN
+CudnnRNNHandle::CudnnRNNHandle(const Tensor &x, const int hidden_size,
+                               const int mode, const int num_layers,
+                               const int bias, const float dropout,
+                               const int bidirectional)
+    : bias(bias),
+      dropout(dropout),
+      bidirectional(bidirectional),
+      hidden_size(hidden_size),
+      mode(mode),
+      num_layers(num_layers) {
+  CHECK_EQ(bias, 1) << "Current implementation always include bias";
+  CHECK(bidirectional == 0 || bidirectional == 1)
+      << "bidirectional should be 0 or 1 not " << bidirectional;
+
+  dev = x.device();
+  ctx = x.device()->context(0);
+
+  seq_length = x.shape(0);
+  batch_size = x.shape(1);
+  feature_size = x.shape(2);
+
+  cudnnRNNAlgo = CUDNN_RNN_ALGO_STANDARD;
+  cudnnDataType = CUDNN_DATA_FLOAT;
+
+  cudnnTensorDescriptor_t *xDesc = new cudnnTensorDescriptor_t[seq_length];
+  init_xDesc(xDesc, *this);
+
+  init_dropout_desc();
+  init_rnn_desc();
+  init_parameters_desc(xDesc);
+  init_workspace(xDesc);
+}
+
+void CudnnRNNHandle::init_workspace(cudnnTensorDescriptor_t *xDesc) {
+  /* workspace data */
+  // Need for every pass
+  CUDNN_CHECK(cudnnGetRNNWorkspaceSize(ctx->cudnn_handle, rnnDesc, seq_length,
+                                       xDesc, &workspace_size));
+  // Only needed in training, shouldn't be touched between passes.
+  CUDNN_CHECK(cudnnGetRNNTrainingReserveSize(ctx->cudnn_handle, rnnDesc,
+                                             seq_length, xDesc, &reserve_size));
+
+  workspace = Tensor(Shape{workspace_size}, dev);
+  reserve_space = Tensor(Shape{reserve_size}, dev);
+}
+
+void CudnnRNNHandle::init_parameters_desc(cudnnTensorDescriptor_t *xDesc) {
+  /* weights size
+   *   depends on rnn desc */
+  CUDNN_CHECK(cudnnGetRNNParamsSize(ctx->cudnn_handle, rnnDesc, xDesc[0],
+                                    &weights_size, cudnnDataType));
+  /* weights desc
+   *   depends on weights size */
+  CUDNN_CHECK(cudnnCreateFilterDescriptor(&wDesc));
+  CUDNN_CHECK(cudnnCreateFilterDescriptor(&dwDesc));
+
+  int dimW[3];
+  dimW[0] = weights_size / sizeof(float);  // TODO different types
+  dimW[1] = 1;
+  dimW[2] = 1;
+  CUDNN_CHECK(cudnnSetFilterNdDescriptor(wDesc, cudnnDataType,
+                                         CUDNN_TENSOR_NCHW, 3, dimW));
+  CUDNN_CHECK(cudnnSetFilterNdDescriptor(dwDesc, cudnnDataType,
+                                         CUDNN_TENSOR_NCHW, 3, dimW));
+}
+
+void CudnnRNNHandle::init_rnn_desc() {
+  /* rnn desc */
+  CUDNN_CHECK(cudnnCreateRNNDescriptor(&rnnDesc));
+  if (mode == 0)
+    RNNMode = CUDNN_RNN_RELU;
+  else if (mode == 1)
+    RNNMode = CUDNN_RNN_TANH;
+  else if (mode == 2)
+    RNNMode = CUDNN_LSTM;
+  else if (mode == 3)
+    RNNMode = CUDNN_GRU;
+  CUDNN_CHECK(cudnnSetRNNDescriptor(
+      ctx->cudnn_handle, rnnDesc, hidden_size, num_layers, dropoutDesc,
+      CUDNN_LINEAR_INPUT,
+      bidirectional ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL, RNNMode,
+      cudnnRNNAlgo,  // CUDNN_RNN_ALGO_STANDARD,
+      cudnnDataType));
+}
+void CudnnRNNHandle::init_dropout_desc() {
+  /* drop out */
+  size_t seed = 0x1234567;
+  CUDNN_CHECK(cudnnCreateDropoutDescriptor(&dropoutDesc));
+  size_t stateSize;
+  CUDNN_CHECK(cudnnDropoutGetStatesSize(ctx->cudnn_handle, &stateSize));
+  CUDA_CHECK(cudaMalloc(&states, stateSize));
+  CUDNN_CHECK(cudnnSetDropoutDescriptor(dropoutDesc, ctx->cudnn_handle, dropout,
+                                        states, stateSize, seed));
+}
+
+// reserve for masking
+Tensor CudnnRNNHandle::merge_inputs(size_t num, const vector<Tensor> &in) {
+  if (num == 1) return in.at(0);
+  size_t size = 0;
+  for (size_t i = 0; i < num; i++) size += in.at(i).Size();
+  Tensor out(Shape{size}, in.at(0).device(), in.at(0).data_type());
+  for (size_t i = 0, offset = 0; i < num; i++) {
+    CopyDataToFrom(&out, in.at(i), in.at(i).Size(), offset);
+    offset += in.at(i).Size();
+  }
+  return out;
+}
+vector<Tensor> CudnnRNNHandle::split_output(size_t num, size_t dim,
+                                            const vector<Tensor> &in,
+                                            const Tensor output) {
+  vector<Tensor> outputs;
+  if (num == 1) {
+    outputs.push_back(Reshape(output, Shape{in.at(0).shape(0), dim}));
+  } else {
+    for (size_t i = 0, offset = 0; offset < output.Size(); i++) {
+      Shape s{in.at(i).shape(0), dim};
+      Tensor out(s, output.device(), output.data_type());
+      CopyDataToFrom(&out, output, out.Size(), 0, offset);
+      outputs.push_back(out);
+      offset += out.Size();
+    }
+    CHECK_EQ(num, outputs.size());
+  }
+  return outputs;
+}
+
+void init_yDesc(cudnnTensorDescriptor_t *yDesc, CudnnRNNHandle &h) {
+  int dimA[3];
+  int strideA[3];
+  dimA[0] = h.batch_size;
+  dimA[1] = h.bidirectional ? h.hidden_size * 2 : h.hidden_size;
+  dimA[2] = 1;
+  strideA[0] = dimA[2] * dimA[1];
+  strideA[1] = dimA[2];
+  strideA[2] = 1;
+
+  for (int i = 0; i < h.seq_length; i++) {
+    CUDNN_CHECK(cudnnCreateTensorDescriptor(&yDesc[i]));
+    CUDNN_CHECK(cudnnSetTensorNdDescriptor(yDesc[i], h.cudnnDataType, 3, dimA,
+                                           strideA));
+  }
+}
+
+void init_xDesc(cudnnTensorDescriptor_t *xDesc, CudnnRNNHandle &h) {
+  int dimA[3];
+  int strideA[3];
+  dimA[0] = h.batch_size;
+  dimA[1] = h.feature_size;
+  dimA[2] = 1;
+  strideA[0] = dimA[2] * dimA[1];
+  strideA[1] = dimA[2];
+  strideA[2] = 1;
+
+  for (int i = 0; i < h.seq_length; i++) {
+    CUDNN_CHECK(cudnnCreateTensorDescriptor(&xDesc[i]));
+    CUDNN_CHECK(cudnnSetTensorNdDescriptor(xDesc[i], h.cudnnDataType, 3, dimA,
+                                           strideA));
+  }
+}
+
+void init_hc_Desc(cudnnTensorDescriptor_t &hxDesc, CudnnRNNHandle &h) {
+  int dimA[3];
+  int strideA[3];
+  dimA[0] = h.num_layers * (h.bidirectional ? 2 : 1);
+  dimA[1] = h.batch_size;
+  dimA[2] = h.hidden_size;
+  strideA[0] = dimA[2] * dimA[1];
+  strideA[1] = dimA[2];
+  strideA[2] = 1;
+  CUDNN_CHECK(cudnnCreateTensorDescriptor(&hxDesc));
+  CUDNN_CHECK(
+      cudnnSetTensorNdDescriptor(hxDesc, h.cudnnDataType, 3, dimA, strideA));
+}
+
+vector<Tensor> GpuRNNForwardInference(const Tensor &x, const Tensor &hx,
+                                      const Tensor &cx, const Tensor &W,
+                                      CudnnRNNHandle &h) {
+  CHECK_EQ(h.feature_size, x.shape(2)) << "feature size should not change";
+  h.seq_length = x.shape(0);
+  h.batch_size = x.shape(1);  // update batch size to accomodate bs change
+  Tensor y(Shape{h.seq_length, h.batch_size,
+                 h.hidden_size * (h.bidirectional ? 2 : 1)},
+           x.device());
+  Tensor hy(Shape{h.num_layers * (h.bidirectional ? 2 : 1), h.batch_size,
+                  h.hidden_size},
+            x.device());
+  Tensor cy(Shape{h.num_layers * (h.bidirectional ? 2 : 1), h.batch_size,
+                  h.hidden_size},
+            x.device());
+
+  y.device()->Exec(
+      [&y, &hy, &cy, &x, &hx, &cx, &W, &h](Context *ctx) {
+        // require desc, [x], hx, cx, w, y, hy, cy
+        cudnnTensorDescriptor_t *xDesc =
+            new cudnnTensorDescriptor_t[h.seq_length];
+        cudnnTensorDescriptor_t *yDesc =
+            new cudnnTensorDescriptor_t[h.seq_length];
+        init_xDesc(xDesc, h);
+        init_yDesc(yDesc, h);
+        cudnnTensorDescriptor_t hxDesc;
+        cudnnTensorDescriptor_t cxDesc;
+        cudnnTensorDescriptor_t hyDesc;
+        cudnnTensorDescriptor_t cyDesc;
+        init_hc_Desc(hxDesc, h);
+        init_hc_Desc(cxDesc, h);
+        init_hc_Desc(hyDesc, h);
+        init_hc_Desc(cyDesc, h);
+
+        auto xptr = x.block()->data();
+        auto hxptr = hx.block()->data();
+        auto cxptr = cx.block()->data();
+        auto Wptr = W.block()->data();
+        auto yptr = y.block()->mutable_data();
+        auto hyptr = hy.block()->mutable_data();
+        auto cyptr = cy.block()->mutable_data();
+        auto wsptr = h.workspace.block()->mutable_data();
+
+        CUDNN_CHECK(cudnnRNNForwardInference(
+            ctx->cudnn_handle, h.rnnDesc, h.seq_length, xDesc, xptr, hxDesc,
+            hxptr, cxDesc, cxptr, h.wDesc, Wptr, yDesc, yptr, hyDesc, hyptr,
+            cyDesc, cyptr, wsptr, h.workspace_size));
+
+        delete[] xDesc;
+        delete[] yDesc;
+      },
+      {x.block(), hx.block(), cx.block(), W.block()},
+      {y.block(), hy.block(), cy.block()});
+  return {y, hy, cy};
+}
+
+vector<Tensor> GpuRNNForwardTraining(const Tensor &x, const Tensor &hx,
+                                     const Tensor &cx, const Tensor &W,
+                                     CudnnRNNHandle &h) {
+  CHECK_EQ(h.feature_size, x.shape(2)) << "feature size should not change";
+  h.seq_length = x.shape(0);
+  h.batch_size = x.shape(1);  // update batch size to accomodate bs change
+  Tensor y(Shape{h.seq_length, h.batch_size,
+                 h.hidden_size * (h.bidirectional ? 2 : 1)},
+           x.device());
+  Tensor hy(Shape{h.num_layers * (h.bidirectional ? 2 : 1), h.batch_size,
+                  h.hidden_size},
+            x.device());
+  Tensor cy(Shape{h.num_layers * (h.bidirectional ? 2 : 1), h.batch_size,
+                  h.hidden_size},
+            x.device());
+  y.device()->Exec(
+      [&y, &hy, &cy, &x, &hx, &cx, &W, &h](Context *ctx) {
+        // require desc, [x], hx, cx, w, y, hy, cy
+        cudnnTensorDescriptor_t *xDesc =
+            new cudnnTensorDescriptor_t[h.seq_length];
+        cudnnTensorDescriptor_t *yDesc =
+            new cudnnTensorDescriptor_t[h.seq_length];
+        init_xDesc(xDesc, h);
+        init_yDesc(yDesc, h);
+        cudnnTensorDescriptor_t hxDesc;
+        cudnnTensorDescriptor_t cxDesc;
+        cudnnTensorDescriptor_t hyDesc;
+        cudnnTensorDescriptor_t cyDesc;
+        init_hc_Desc(hxDesc, h);
+        init_hc_Desc(cxDesc, h);
+        init_hc_Desc(hyDesc, h);
+        init_hc_Desc(cyDesc, h);
+
+        auto xptr = x.block()->data();
+        auto hxptr = hx.block()->data();
+        auto cxptr = cx.block()->data();
+        auto Wptr = W.block()->data();
+        auto yptr = y.block()->mutable_data();
+        auto hyptr = hy.block()->mutable_data();
+        auto cyptr = cy.block()->mutable_data();
+        auto wsptr = h.workspace.block()->mutable_data();
+        auto rsptr = h.reserve_space.block()->mutable_data();
+        CUDNN_CHECK(cudnnRNNForwardTraining(
+            ctx->cudnn_handle, h.rnnDesc, h.seq_length, xDesc, xptr, hxDesc,
+            hxptr, cxDesc, cxptr, h.wDesc, Wptr, yDesc, yptr, hyDesc, hyptr,
+            cyDesc, cyptr, wsptr, h.workspace_size, rsptr, h.reserve_size));
+        delete[] xDesc;
+        delete[] yDesc;
+      },
+      {x.block(), hx.block(), cx.block(), W.block()},
+      {y.block(), hy.block(), cy.block()});
+
+  return {y, hy, cy};
+}
+
+vector<Tensor> GpuRNNBackwardx(const Tensor &y, const Tensor &dy,
+                               const Tensor &dhy, const Tensor &dcy,
+                               const Tensor &W, const Tensor &hx,
+                               const Tensor &cx, CudnnRNNHandle &h) {
+  Tensor dx(Shape{h.seq_length, h.batch_size, h.feature_size}, y.device());
+  Tensor dhx(Shape{h.num_layers * (h.bidirectional ? 2 : 1), h.batch_size,
+                   h.hidden_size},
+             y.device());
+  Tensor dcx(Shape{h.num_layers * (h.bidirectional ? 2 : 1), h.batch_size,
+                   h.hidden_size},
+             y.device());
+  dx.device()->Exec(
+      [&dx, &dhx, &dcx, &y, &dy, &dhy, &dcy, &W, &hx, &cx, &h](Context *ctx) {
+        // require desc:
+        //      [dx], hx, dhx, cx, dcx, w,
+        // [y], [dy],     dhy,     dcy
+        cudnnTensorDescriptor_t *dxDesc =
+            new cudnnTensorDescriptor_t[h.seq_length];
+        cudnnTensorDescriptor_t *yDesc =
+            new cudnnTensorDescriptor_t[h.seq_length];
+        cudnnTensorDescriptor_t *dyDesc =
+            new cudnnTensorDescriptor_t[h.seq_length];
+        init_yDesc(yDesc, h);
+        init_xDesc(dxDesc, h);
+        init_yDesc(dyDesc, h);
+        cudnnTensorDescriptor_t hxDesc;
+        cudnnTensorDescriptor_t cxDesc;
+        cudnnTensorDescriptor_t dhxDesc;
+        cudnnTensorDescriptor_t dcxDesc;
+        cudnnTensorDescriptor_t dhyDesc;
+        cudnnTensorDescriptor_t dcyDesc;
+        init_hc_Desc(hxDesc, h);
+        init_hc_Desc(cxDesc, h);
+        init_hc_Desc(dhxDesc, h);
+        init_hc_Desc(dcxDesc, h);
+        init_hc_Desc(dhyDesc, h);
+        init_hc_Desc(dcyDesc, h);
+
+        auto dxptr = dx.block()->mutable_data();
+        auto hxptr = hx.block()->data();
+        auto dhxptr = dhx.block()->mutable_data();
+        auto cxptr = cx.block()->data();
+        auto dcxptr = dcx.block()->mutable_data();
+        auto Wptr = W.block()->data();
+        auto yptr = y.block()->data();
+        auto dyptr = dy.block()->data();
+        auto dhyptr = dhy.block()->data();
+        auto dcyptr = dcy.block()->data();
+        auto wsptr = h.workspace.block()->mutable_data();
+        auto rsptr = h.reserve_space.block()->mutable_data();
+
+        CUDNN_CHECK(cudnnRNNBackwardData(
+            ctx->cudnn_handle, h.rnnDesc, h.seq_length, yDesc, yptr, dyDesc,
+            dyptr, dhyDesc, dhyptr, dcyDesc, dcyptr, h.wDesc, Wptr, hxDesc,
+            hxptr, cxDesc, cxptr, dxDesc, dxptr, dhxDesc, dhxptr, dcxDesc,
+            dcxptr, wsptr, h.workspace_size, rsptr, h.reserve_size));
+        delete[] dxDesc;
+        delete[] yDesc;
+        delete[] dyDesc;
+      },
+      {y.block(), dy.block(), dhy.block(), dcy.block(), hx.block(), cx.block(),
+       W.block()},
+      {dx.block(), dhx.block(), dcx.block()});
+  return {dx, dhx, dcx};
+}
+
+Tensor GpuRNNBackwardW(const Tensor &x, const Tensor &hx, const Tensor &y,
+                       CudnnRNNHandle &h) {
+  Tensor dW(Shape{h.weights_size}, x.device());
+  dW.device()->Exec(
+      [&dW, &x, &hx, &y, &h](Context *ctx) {

Review comment:
       Same as the problem mentioned above. Also, it seems that we have not yet implemented the CPU version.

##########
File path: python/singa/autograd.py
##########
@@ -1190,6 +1190,41 @@ def cross_entropy(y, t):
     return CrossEntropy()(y, t)[0]
 
 
+class QALSTMLoss(Operator):
+
+    def __init__(self, M=0.2):
+        super(QALSTMLoss, self).__init__()
+        self.M = M
+
+    def forward(self, pos, neg):
+        # L = max{0, M - cosine(q, a+) + cosine(q, a-)}
+        zero = singa.Tensor(list(pos.shape()), pos.device())
+        zero.SetFloatValue(0.0)
+        val = singa.AddFloat(singa.__sub__(neg, pos), self.M)
+        gt_zero = singa.__gt__(val, zero)
+        self.inputs = (gt_zero, ) # (BS,)
+        all_loss = singa.__mul__(gt_zero, val)
+        loss = singa.SumAll(all_loss)
+        loss /= (pos.shape()[0])
+        # assert loss.shape(0) == 1
+        return loss

Review comment:
       1. What's the type of neg and pos?
   2. Just use neg - pos (see the sketch below).
   3. What's the type of gt_zero? If its type is boolean, it may not be buffered in the graph.
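
   A minimal sketch of the quoted forward with point 2 applied. It assumes pos and neg support the Python subtraction operator, which point 2 presumes (if they are raw singa_wrap tensors, singa.__sub__ may still be required), and it only annotates point 3 with a comment, since the actual dtype returned by singa.__gt__ is not shown here; the backward pass is unchanged and omitted:

       from singa import autograd
       import singa.singa_wrap as singa

       class QALSTMLoss(autograd.Operator):

           def __init__(self, M=0.2):
               super(QALSTMLoss, self).__init__()
               self.M = M

           def forward(self, pos, neg):
               # hinge loss: L = mean(max(0, M - cos(q, a+) + cos(q, a-)))
               val = singa.AddFloat(neg - pos, self.M)  # point 2: plain operator instead of singa.__sub__
               zero = singa.Tensor(list(pos.shape()), pos.device())
               zero.SetFloatValue(0.0)
               # point 3: assumed to be a float 0/1 mask, not a boolean tensor,
               # so that it can be buffered in the graph
               gt_zero = singa.__gt__(val, zero)
               self.inputs = (gt_zero,)  # kept for backward (unchanged, omitted here)
               loss = singa.SumAll(singa.__mul__(gt_zero, val))
               loss /= pos.shape()[0]
               return loss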
   

##########
File path: examples/qabot/train.py
##########
@@ -0,0 +1,162 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+import os
+import sys
+new_path = r'/root/singa/build/python'
+sys.path.append(new_path)
+
+from data import *
+from model import QAModel, MLP #, QALoss
+
+from singa import autograd
+from singa import layer
+from singa import model
+from singa import tensor
+from singa import device
+from singa import opt
+from tqdm import tqdm
+
+import singa.singa_wrap as singa
+
+import numpy as np
+
+def train(m, tq, ta, id_to_word, label_to_ans_text, wv):
+    print("training")
+    train_data = parse_file('./V2/InsuranceQA.question.anslabel.token.100.pool.solr.train.encoded', id_to_word, label_to_ans_text)
+    train_data = train_data[:100]
+    train_triplets = generate_qa_triplets(train_data) # (q, a+, a-)
+
+
+    for epoch in range(max_epoch):
+        for (q,apos,aneg) in tqdm(train_triplets):
+            q = words_text_to_fixed_seqlen_vec(wv,q,q_seq_length)
+            a = np.array([words_text_to_fixed_seqlen_vec(wv,apos,a_seq_length),words_text_to_fixed_seqlen_vec(wv,aneg,a_seq_length)])
+            q = q.astype(np.float32)
+            a = a.astype(np.float32)
+
+            tq.copy_from_numpy(q)
+            ta.copy_from_numpy(a)
+
+            # train
+            _, l = m(tq,ta)
+
+        print("loss", l)
+
+        # training top1 accuracy
+        top1hit = 0
+        trials = len(train_data)
+
+        for (q, a_pos, a_negs) in train_data:
+            scores = []
+            q_vec = words_text_to_fixed_seqlen_vec(wv,q,q_seq_length)
+            tq.copy_from_numpy(q_vec)
+
+            a_pos_vec = words_text_to_fixed_seqlen_vec(wv,a_pos,a_seq_length)
+            # prepare for <q, a+, a+> input
+            ta.copy_from_numpy(np.array([a_pos_vec]*2))
+            true_score, l = m(tq,ta)
+
+            a_neg_vecs = [words_text_to_fixed_seqlen_vec(wv,a_neg,a_seq_length) for a_neg in a_negs]
+
+            # prepare for triplets <q, a-, a-> input
+            while len(a_neg_vecs) > 1:
+                a_vec=[]
+                a_vec.append(a_neg_vecs.pop(0))
+                a_vec.append(a_neg_vecs.pop(0))
+                ta.copy_from_numpy(np.array(a_vec))
+                score, l = m(tq,ta)
+                scores.extend(score)
+
+            max_neg = np.max(np.array([tensor.to_numpy(s) for s in scores]).flatten())
+            if max_neg < tensor.to_numpy(true_score[0])[0]:
+                top1hit+=1
+
+        print("training top 1 hit accuracy: ", top1hit/trials)
+
+
+def test(m, tq, ta, id_to_word, label_to_ans_text, wv):
+    print("testing")
+    test_data = parse_test_file('./V2/InsuranceQA.question.anslabel.token.100.pool.solr.test.encoded', id_to_word, label_to_ans_text)
+    test_data = test_data[:10]  # run on n samples
+
+    m.eval()
+    top1hit=0
+    trials = len(test_data)
+    for (q, labels, cands) in test_data:
+
+        q_vec = words_text_to_fixed_seqlen_vec(wv, q, q_seq_length)
+        tq.copy_from_numpy(np.array(q_vec))
+
+        cands_vec = [words_text_to_fixed_seqlen_vec(wv, label_to_ans_text[candidate_label], a_seq_length) for candidate_label in cands]
+
+        scores = []
+        # inference all candidates
+        # import pdb; pdb.set_trace()
+        while len(cands_vec) > 1:
+            a_vec=[]
+            a_vec.append(cands_vec.pop(0))
+            a_vec.append(cands_vec.pop(0))
+            ta.copy_from_numpy(np.array(a_vec))
+            score = m(tq,ta) # inference mode only return forward result
+            scores.extend(score)
+
+        # check correct from predict
+        true_idxs = [cands.index(l) for l in labels if l in cands]
+        pred_idx = np.argmax(np.array([tensor.to_numpy(s) for s in scores]).flatten())
+        if pred_idx in true_idxs:
+            top1hit += 1
+
+    print("testing top 1 hit accuracy: ", top1hit/trials)
+
+if __name__ == "__main__":
+    dev = device.create_cuda_gpu(set_default=False)
+
+    q_seq_length = 10
+    a_seq_length = 100
+    embed_size = 300
+    batch_size = 128
+    max_epoch = 30
+    hidden_size = 100
+
+    # build model
+    m = QAModel(hidden_size)
+    print("created qa model")
+    # m = MLP()
+    m.optimizer = opt.SGD()
+
+    tq = tensor.Tensor((1, q_seq_length, embed_size), dev, tensor.float32)
+    ta = tensor.Tensor((2, a_seq_length, embed_size), dev, tensor.float32)
+
+    tq.set_value(0.0)
+    ta.set_value(0.0)
+
+    m.compile([tq, ta], is_train=True, use_graph=False, sequential=False)

Review comment:
       If the graph is enabled, can we train the model correctly?
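
   For context, enabling the graph is a one-flag change to the quoted compile call; the question above is whether the buffered operators then replay correctly during training. A minimal sketch, assuming nothing else in the script changes:

       # buffer operators into a computational graph instead of running them eagerly
       m.compile([tq, ta], is_train=True, use_graph=True, sequential=False)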

##########
File path: src/model/operation/rnn.cc
##########
@@ -0,0 +1,407 @@
+/*********************************************************
+ *
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ *
+ ************************************************************/
+
+#include "rnn.h"
+namespace singa {
+#ifdef USE_CUDNN
+CudnnRNNHandle::CudnnRNNHandle(const Tensor &x, const int hidden_size,
+                               const int mode, const int num_layers,
+                               const int bias, const float dropout,
+                               const int bidirectional)
+    : bias(bias),
+      dropout(dropout),
+      bidirectional(bidirectional),
+      hidden_size(hidden_size),
+      mode(mode),
+      num_layers(num_layers) {
+  CHECK_EQ(bias, 1) << "Current implementation always include bias";
+  CHECK(bidirectional == 0 || bidirectional == 1)
+      << "bidirectional should be 0 or 1 not " << bidirectional;
+
+  dev = x.device();
+  ctx = x.device()->context(0);
+
+  seq_length = x.shape(0);
+  batch_size = x.shape(1);
+  feature_size = x.shape(2);
+
+  cudnnRNNAlgo = CUDNN_RNN_ALGO_STANDARD;
+  cudnnDataType = CUDNN_DATA_FLOAT;
+
+  cudnnTensorDescriptor_t *xDesc = new cudnnTensorDescriptor_t[seq_length];
+  init_xDesc(xDesc, *this);
+
+  init_dropout_desc();
+  init_rnn_desc();
+  init_parameters_desc(xDesc);
+  init_workspace(xDesc);
+}
+
+void CudnnRNNHandle::init_workspace(cudnnTensorDescriptor_t *xDesc) {
+  /* workspace data */
+  // Need for every pass
+  CUDNN_CHECK(cudnnGetRNNWorkspaceSize(ctx->cudnn_handle, rnnDesc, seq_length,
+                                       xDesc, &workspace_size));
+  // Only needed in training, shouldn't be touched between passes.
+  CUDNN_CHECK(cudnnGetRNNTrainingReserveSize(ctx->cudnn_handle, rnnDesc,
+                                             seq_length, xDesc, &reserve_size));
+
+  workspace = Tensor(Shape{workspace_size}, dev);
+  reserve_space = Tensor(Shape{reserve_size}, dev);
+}
+
+void CudnnRNNHandle::init_parameters_desc(cudnnTensorDescriptor_t *xDesc) {
+  /* weights size
+   *   depends on rnn desc */
+  CUDNN_CHECK(cudnnGetRNNParamsSize(ctx->cudnn_handle, rnnDesc, xDesc[0],
+                                    &weights_size, cudnnDataType));
+  /* weights desc
+   *   depends on weights size */
+  CUDNN_CHECK(cudnnCreateFilterDescriptor(&wDesc));
+  CUDNN_CHECK(cudnnCreateFilterDescriptor(&dwDesc));
+
+  int dimW[3];
+  dimW[0] = weights_size / sizeof(float);  // TODO different types
+  dimW[1] = 1;
+  dimW[2] = 1;
+  CUDNN_CHECK(cudnnSetFilterNdDescriptor(wDesc, cudnnDataType,
+                                         CUDNN_TENSOR_NCHW, 3, dimW));
+  CUDNN_CHECK(cudnnSetFilterNdDescriptor(dwDesc, cudnnDataType,
+                                         CUDNN_TENSOR_NCHW, 3, dimW));
+}
+
+void CudnnRNNHandle::init_rnn_desc() {
+  /* rnn desc */
+  CUDNN_CHECK(cudnnCreateRNNDescriptor(&rnnDesc));
+  if (mode == 0)
+    RNNMode = CUDNN_RNN_RELU;
+  else if (mode == 1)
+    RNNMode = CUDNN_RNN_TANH;
+  else if (mode == 2)
+    RNNMode = CUDNN_LSTM;
+  else if (mode == 3)
+    RNNMode = CUDNN_GRU;
+  CUDNN_CHECK(cudnnSetRNNDescriptor(
+      ctx->cudnn_handle, rnnDesc, hidden_size, num_layers, dropoutDesc,
+      CUDNN_LINEAR_INPUT,
+      bidirectional ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL, RNNMode,
+      cudnnRNNAlgo,  // CUDNN_RNN_ALGO_STANDARD,
+      cudnnDataType));
+}
+void CudnnRNNHandle::init_dropout_desc() {
+  /* drop out */
+  size_t seed = 0x1234567;
+  CUDNN_CHECK(cudnnCreateDropoutDescriptor(&dropoutDesc));
+  size_t stateSize;
+  CUDNN_CHECK(cudnnDropoutGetStatesSize(ctx->cudnn_handle, &stateSize));
+  CUDA_CHECK(cudaMalloc(&states, stateSize));
+  CUDNN_CHECK(cudnnSetDropoutDescriptor(dropoutDesc, ctx->cudnn_handle, dropout,
+                                        states, stateSize, seed));
+}
+
+// reserve for masking
+Tensor CudnnRNNHandle::merge_inputs(size_t num, const vector<Tensor> &in) {
+  if (num == 1) return in.at(0);
+  size_t size = 0;
+  for (size_t i = 0; i < num; i++) size += in.at(i).Size();
+  Tensor out(Shape{size}, in.at(0).device(), in.at(0).data_type());
+  for (size_t i = 0, offset = 0; i < num; i++) {
+    CopyDataToFrom(&out, in.at(i), in.at(i).Size(), offset);
+    offset += in.at(i).Size();
+  }
+  return out;
+}
+vector<Tensor> CudnnRNNHandle::split_output(size_t num, size_t dim,
+                                            const vector<Tensor> &in,
+                                            const Tensor output) {
+  vector<Tensor> outputs;
+  if (num == 1) {
+    outputs.push_back(Reshape(output, Shape{in.at(0).shape(0), dim}));
+  } else {
+    for (size_t i = 0, offset = 0; offset < output.Size(); i++) {
+      Shape s{in.at(i).shape(0), dim};
+      Tensor out(s, output.device(), output.data_type());
+      CopyDataToFrom(&out, output, out.Size(), 0, offset);
+      outputs.push_back(out);
+      offset += out.Size();
+    }
+    CHECK_EQ(num, outputs.size());
+  }
+  return outputs;
+}
+
+void init_yDesc(cudnnTensorDescriptor_t *yDesc, CudnnRNNHandle &h) {
+  int dimA[3];
+  int strideA[3];
+  dimA[0] = h.batch_size;
+  dimA[1] = h.bidirectional ? h.hidden_size * 2 : h.hidden_size;
+  dimA[2] = 1;
+  strideA[0] = dimA[2] * dimA[1];
+  strideA[1] = dimA[2];
+  strideA[2] = 1;
+
+  for (int i = 0; i < h.seq_length; i++) {
+    CUDNN_CHECK(cudnnCreateTensorDescriptor(&yDesc[i]));
+    CUDNN_CHECK(cudnnSetTensorNdDescriptor(yDesc[i], h.cudnnDataType, 3, dimA,
+                                           strideA));
+  }
+}
+
+void init_xDesc(cudnnTensorDescriptor_t *xDesc, CudnnRNNHandle &h) {
+  int dimA[3];
+  int strideA[3];
+  dimA[0] = h.batch_size;
+  dimA[1] = h.feature_size;
+  dimA[2] = 1;
+  strideA[0] = dimA[2] * dimA[1];
+  strideA[1] = dimA[2];
+  strideA[2] = 1;
+
+  for (int i = 0; i < h.seq_length; i++) {
+    CUDNN_CHECK(cudnnCreateTensorDescriptor(&xDesc[i]));
+    CUDNN_CHECK(cudnnSetTensorNdDescriptor(xDesc[i], h.cudnnDataType, 3, dimA,
+                                           strideA));
+  }
+}
+
+void init_hc_Desc(cudnnTensorDescriptor_t &hxDesc, CudnnRNNHandle &h) {
+  int dimA[3];
+  int strideA[3];
+  dimA[0] = h.num_layers * (h.bidirectional ? 2 : 1);
+  dimA[1] = h.batch_size;
+  dimA[2] = h.hidden_size;
+  strideA[0] = dimA[2] * dimA[1];
+  strideA[1] = dimA[2];
+  strideA[2] = 1;
+  CUDNN_CHECK(cudnnCreateTensorDescriptor(&hxDesc));
+  CUDNN_CHECK(
+      cudnnSetTensorNdDescriptor(hxDesc, h.cudnnDataType, 3, dimA, strideA));
+}
+
+vector<Tensor> GpuRNNForwardInference(const Tensor &x, const Tensor &hx,
+                                      const Tensor &cx, const Tensor &W,
+                                      CudnnRNNHandle &h) {
+  CHECK_EQ(h.feature_size, x.shape(2)) << "feature size should not change";
+  h.seq_length = x.shape(0);
+  h.batch_size = x.shape(1);  // update batch size to accomodate bs change
+  Tensor y(Shape{h.seq_length, h.batch_size,
+                 h.hidden_size * (h.bidirectional ? 2 : 1)},
+           x.device());
+  Tensor hy(Shape{h.num_layers * (h.bidirectional ? 2 : 1), h.batch_size,
+                  h.hidden_size},
+            x.device());
+  Tensor cy(Shape{h.num_layers * (h.bidirectional ? 2 : 1), h.batch_size,
+                  h.hidden_size},
+            x.device());
+
+  y.device()->Exec(
+      [&y, &hy, &cy, &x, &hx, &cx, &W, &h](Context *ctx) {

Review comment:
       Can't submit the operator like this. All tensors that are not parameters should be passed by value (these tensors will be released, and there will be segmentation faults if they are not passed by value). Make sure the read_blocks and write_blocks are correct.
   You can refer to [convolution.cc](https://github.com/apache/singa/blob/master/src/model/operation/convolution.cc#L590).

##########
File path: src/model/operation/rnn.cc
##########
@@ -0,0 +1,407 @@
+/*********************************************************
+ *
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ *
+ ************************************************************/
+
+#include "rnn.h"
+namespace singa {
+#ifdef USE_CUDNN
+CudnnRNNHandle::CudnnRNNHandle(const Tensor &x, const int hidden_size,
+                               const int mode, const int num_layers,
+                               const int bias, const float dropout,
+                               const int bidirectional)
+    : bias(bias),
+      dropout(dropout),
+      bidirectional(bidirectional),
+      hidden_size(hidden_size),
+      mode(mode),
+      num_layers(num_layers) {
+  CHECK_EQ(bias, 1) << "Current implementation always include bias";
+  CHECK(bidirectional == 0 || bidirectional == 1)
+      << "bidirectional should be 0 or 1 not " << bidirectional;
+
+  dev = x.device();
+  ctx = x.device()->context(0);
+
+  seq_length = x.shape(0);
+  batch_size = x.shape(1);
+  feature_size = x.shape(2);
+
+  cudnnRNNAlgo = CUDNN_RNN_ALGO_STANDARD;
+  cudnnDataType = CUDNN_DATA_FLOAT;
+
+  cudnnTensorDescriptor_t *xDesc = new cudnnTensorDescriptor_t[seq_length];
+  init_xDesc(xDesc, *this);
+
+  init_dropout_desc();
+  init_rnn_desc();
+  init_parameters_desc(xDesc);
+  init_workspace(xDesc);
+}
+
+void CudnnRNNHandle::init_workspace(cudnnTensorDescriptor_t *xDesc) {
+  /* workspace data */
+  // Need for every pass
+  CUDNN_CHECK(cudnnGetRNNWorkspaceSize(ctx->cudnn_handle, rnnDesc, seq_length,
+                                       xDesc, &workspace_size));
+  // Only needed in training, shouldn't be touched between passes.
+  CUDNN_CHECK(cudnnGetRNNTrainingReserveSize(ctx->cudnn_handle, rnnDesc,
+                                             seq_length, xDesc, &reserve_size));
+
+  workspace = Tensor(Shape{workspace_size}, dev);
+  reserve_space = Tensor(Shape{reserve_size}, dev);
+}
+
+void CudnnRNNHandle::init_parameters_desc(cudnnTensorDescriptor_t *xDesc) {
+  /* weights size
+   *   depends on rnn desc */
+  CUDNN_CHECK(cudnnGetRNNParamsSize(ctx->cudnn_handle, rnnDesc, xDesc[0],
+                                    &weights_size, cudnnDataType));
+  /* weights desc
+   *   depends on weights size */
+  CUDNN_CHECK(cudnnCreateFilterDescriptor(&wDesc));
+  CUDNN_CHECK(cudnnCreateFilterDescriptor(&dwDesc));
+
+  int dimW[3];
+  dimW[0] = weights_size / sizeof(float);  // TODO different types
+  dimW[1] = 1;
+  dimW[2] = 1;
+  CUDNN_CHECK(cudnnSetFilterNdDescriptor(wDesc, cudnnDataType,
+                                         CUDNN_TENSOR_NCHW, 3, dimW));
+  CUDNN_CHECK(cudnnSetFilterNdDescriptor(dwDesc, cudnnDataType,
+                                         CUDNN_TENSOR_NCHW, 3, dimW));
+}
+
+void CudnnRNNHandle::init_rnn_desc() {
+  /* rnn desc */
+  CUDNN_CHECK(cudnnCreateRNNDescriptor(&rnnDesc));
+  if (mode == 0)
+    RNNMode = CUDNN_RNN_RELU;
+  else if (mode == 1)
+    RNNMode = CUDNN_RNN_TANH;
+  else if (mode == 2)
+    RNNMode = CUDNN_LSTM;
+  else if (mode == 3)
+    RNNMode = CUDNN_GRU;
+  CUDNN_CHECK(cudnnSetRNNDescriptor(
+      ctx->cudnn_handle, rnnDesc, hidden_size, num_layers, dropoutDesc,
+      CUDNN_LINEAR_INPUT,
+      bidirectional ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL, RNNMode,
+      cudnnRNNAlgo,  // CUDNN_RNN_ALGO_STANDARD,
+      cudnnDataType));
+}
+void CudnnRNNHandle::init_dropout_desc() {
+  /* drop out */
+  size_t seed = 0x1234567;
+  CUDNN_CHECK(cudnnCreateDropoutDescriptor(&dropoutDesc));
+  size_t stateSize;
+  CUDNN_CHECK(cudnnDropoutGetStatesSize(ctx->cudnn_handle, &stateSize));
+  CUDA_CHECK(cudaMalloc(&states, stateSize));
+  CUDNN_CHECK(cudnnSetDropoutDescriptor(dropoutDesc, ctx->cudnn_handle, dropout,
+                                        states, stateSize, seed));
+}
+
+// reserve for masking
+Tensor CudnnRNNHandle::merge_inputs(size_t num, const vector<Tensor> &in) {
+  if (num == 1) return in.at(0);
+  size_t size = 0;
+  for (size_t i = 0; i < num; i++) size += in.at(i).Size();
+  Tensor out(Shape{size}, in.at(0).device(), in.at(0).data_type());
+  for (size_t i = 0, offset = 0; i < num; i++) {
+    CopyDataToFrom(&out, in.at(i), in.at(i).Size(), offset);
+    offset += in.at(i).Size();
+  }
+  return out;
+}
+vector<Tensor> CudnnRNNHandle::split_output(size_t num, size_t dim,
+                                            const vector<Tensor> &in,
+                                            const Tensor output) {
+  vector<Tensor> outputs;
+  if (num == 1) {
+    outputs.push_back(Reshape(output, Shape{in.at(0).shape(0), dim}));
+  } else {
+    for (size_t i = 0, offset = 0; offset < output.Size(); i++) {
+      Shape s{in.at(i).shape(0), dim};
+      Tensor out(s, output.device(), output.data_type());
+      CopyDataToFrom(&out, output, out.Size(), 0, offset);
+      outputs.push_back(out);
+      offset += out.Size();
+    }
+    CHECK_EQ(num, outputs.size());
+  }
+  return outputs;
+}
+
+void init_yDesc(cudnnTensorDescriptor_t *yDesc, CudnnRNNHandle &h) {
+  int dimA[3];
+  int strideA[3];
+  dimA[0] = h.batch_size;
+  dimA[1] = h.bidirectional ? h.hidden_size * 2 : h.hidden_size;
+  dimA[2] = 1;
+  strideA[0] = dimA[2] * dimA[1];
+  strideA[1] = dimA[2];
+  strideA[2] = 1;
+
+  for (int i = 0; i < h.seq_length; i++) {
+    CUDNN_CHECK(cudnnCreateTensorDescriptor(&yDesc[i]));
+    CUDNN_CHECK(cudnnSetTensorNdDescriptor(yDesc[i], h.cudnnDataType, 3, dimA,
+                                           strideA));
+  }
+}
+
+void init_xDesc(cudnnTensorDescriptor_t *xDesc, CudnnRNNHandle &h) {
+  int dimA[3];
+  int strideA[3];
+  dimA[0] = h.batch_size;
+  dimA[1] = h.feature_size;
+  dimA[2] = 1;
+  strideA[0] = dimA[2] * dimA[1];
+  strideA[1] = dimA[2];
+  strideA[2] = 1;
+
+  for (int i = 0; i < h.seq_length; i++) {
+    CUDNN_CHECK(cudnnCreateTensorDescriptor(&xDesc[i]));
+    CUDNN_CHECK(cudnnSetTensorNdDescriptor(xDesc[i], h.cudnnDataType, 3, dimA,
+                                           strideA));
+  }
+}
+
+void init_hc_Desc(cudnnTensorDescriptor_t &hxDesc, CudnnRNNHandle &h) {
+  int dimA[3];
+  int strideA[3];
+  dimA[0] = h.num_layers * (h.bidirectional ? 2 : 1);
+  dimA[1] = h.batch_size;
+  dimA[2] = h.hidden_size;
+  strideA[0] = dimA[2] * dimA[1];
+  strideA[1] = dimA[2];
+  strideA[2] = 1;
+  CUDNN_CHECK(cudnnCreateTensorDescriptor(&hxDesc));
+  CUDNN_CHECK(
+      cudnnSetTensorNdDescriptor(hxDesc, h.cudnnDataType, 3, dimA, strideA));
+}
+
+vector<Tensor> GpuRNNForwardInference(const Tensor &x, const Tensor &hx,
+                                      const Tensor &cx, const Tensor &W,
+                                      CudnnRNNHandle &h) {
+  CHECK_EQ(h.feature_size, x.shape(2)) << "feature size should not change";
+  h.seq_length = x.shape(0);
+  h.batch_size = x.shape(1);  // update batch size to accomodate bs change
+  Tensor y(Shape{h.seq_length, h.batch_size,
+                 h.hidden_size * (h.bidirectional ? 2 : 1)},
+           x.device());
+  Tensor hy(Shape{h.num_layers * (h.bidirectional ? 2 : 1), h.batch_size,
+                  h.hidden_size},
+            x.device());
+  Tensor cy(Shape{h.num_layers * (h.bidirectional ? 2 : 1), h.batch_size,
+                  h.hidden_size},
+            x.device());
+
+  y.device()->Exec(
+      [&y, &hy, &cy, &x, &hx, &cx, &W, &h](Context *ctx) {
+        // require desc, [x], hx, cx, w, y, hy, cy
+        cudnnTensorDescriptor_t *xDesc =
+            new cudnnTensorDescriptor_t[h.seq_length];
+        cudnnTensorDescriptor_t *yDesc =
+            new cudnnTensorDescriptor_t[h.seq_length];
+        init_xDesc(xDesc, h);
+        init_yDesc(yDesc, h);
+        cudnnTensorDescriptor_t hxDesc;
+        cudnnTensorDescriptor_t cxDesc;
+        cudnnTensorDescriptor_t hyDesc;
+        cudnnTensorDescriptor_t cyDesc;
+        init_hc_Desc(hxDesc, h);
+        init_hc_Desc(cxDesc, h);
+        init_hc_Desc(hyDesc, h);
+        init_hc_Desc(cyDesc, h);
+
+        auto xptr = x.block()->data();
+        auto hxptr = hx.block()->data();
+        auto cxptr = cx.block()->data();
+        auto Wptr = W.block()->data();
+        auto yptr = y.block()->mutable_data();
+        auto hyptr = hy.block()->mutable_data();
+        auto cyptr = cy.block()->mutable_data();
+        auto wsptr = h.workspace.block()->mutable_data();
+
+        CUDNN_CHECK(cudnnRNNForwardInference(
+            ctx->cudnn_handle, h.rnnDesc, h.seq_length, xDesc, xptr, hxDesc,
+            hxptr, cxDesc, cxptr, h.wDesc, Wptr, yDesc, yptr, hyDesc, hyptr,
+            cyDesc, cyptr, wsptr, h.workspace_size));
+
+        delete[] xDesc;
+        delete[] yDesc;
+      },
+      {x.block(), hx.block(), cx.block(), W.block()},
+      {y.block(), hy.block(), cy.block()});
+  return {y, hy, cy};
+}
+
+vector<Tensor> GpuRNNForwardTraining(const Tensor &x, const Tensor &hx,
+                                     const Tensor &cx, const Tensor &W,
+                                     CudnnRNNHandle &h) {
+  CHECK_EQ(h.feature_size, x.shape(2)) << "feature size should not change";
+  h.seq_length = x.shape(0);
+  h.batch_size = x.shape(1);  // update batch size to accomodate bs change
+  Tensor y(Shape{h.seq_length, h.batch_size,
+                 h.hidden_size * (h.bidirectional ? 2 : 1)},
+           x.device());
+  Tensor hy(Shape{h.num_layers * (h.bidirectional ? 2 : 1), h.batch_size,
+                  h.hidden_size},
+            x.device());
+  Tensor cy(Shape{h.num_layers * (h.bidirectional ? 2 : 1), h.batch_size,
+                  h.hidden_size},
+            x.device());
+  y.device()->Exec(
+      [&y, &hy, &cy, &x, &hx, &cx, &W, &h](Context *ctx) {

Review comment:
       Same as the problem mentioned above.

##########
File path: src/model/operation/rnn.cc
##########
@@ -0,0 +1,407 @@
+/*********************************************************
+ *
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ *
+ ************************************************************/
+
+#include "rnn.h"
+namespace singa {
+#ifdef USE_CUDNN
+CudnnRNNHandle::CudnnRNNHandle(const Tensor &x, const int hidden_size,
+                               const int mode, const int num_layers,
+                               const int bias, const float dropout,
+                               const int bidirectional)
+    : bias(bias),
+      dropout(dropout),
+      bidirectional(bidirectional),
+      hidden_size(hidden_size),
+      mode(mode),
+      num_layers(num_layers) {
+  CHECK_EQ(bias, 1) << "Current implementation always include bias";
+  CHECK(bidirectional == 0 || bidirectional == 1)
+      << "bidirectional should be 0 or 1 not " << bidirectional;
+
+  dev = x.device();
+  ctx = x.device()->context(0);
+
+  seq_length = x.shape(0);
+  batch_size = x.shape(1);
+  feature_size = x.shape(2);
+
+  cudnnRNNAlgo = CUDNN_RNN_ALGO_STANDARD;
+  cudnnDataType = CUDNN_DATA_FLOAT;
+
+  cudnnTensorDescriptor_t *xDesc = new cudnnTensorDescriptor_t[seq_length];
+  init_xDesc(xDesc, *this);
+
+  init_dropout_desc();
+  init_rnn_desc();
+  init_parameters_desc(xDesc);
+  init_workspace(xDesc);
+}
+
+void CudnnRNNHandle::init_workspace(cudnnTensorDescriptor_t *xDesc) {
+  /* workspace data */
+  // Need for every pass
+  CUDNN_CHECK(cudnnGetRNNWorkspaceSize(ctx->cudnn_handle, rnnDesc, seq_length,
+                                       xDesc, &workspace_size));
+  // Only needed in training, shouldn't be touched between passes.
+  CUDNN_CHECK(cudnnGetRNNTrainingReserveSize(ctx->cudnn_handle, rnnDesc,
+                                             seq_length, xDesc, &reserve_size));
+
+  workspace = Tensor(Shape{workspace_size}, dev);
+  reserve_space = Tensor(Shape{reserve_size}, dev);
+}
+
+void CudnnRNNHandle::init_parameters_desc(cudnnTensorDescriptor_t *xDesc) {
+  /* weights size
+   *   depends on rnn desc */
+  CUDNN_CHECK(cudnnGetRNNParamsSize(ctx->cudnn_handle, rnnDesc, xDesc[0],
+                                    &weights_size, cudnnDataType));
+  /* weights desc
+   *   depends on weights size */
+  CUDNN_CHECK(cudnnCreateFilterDescriptor(&wDesc));
+  CUDNN_CHECK(cudnnCreateFilterDescriptor(&dwDesc));
+
+  int dimW[3];
+  dimW[0] = weights_size / sizeof(float);  // TODO different types
+  dimW[1] = 1;
+  dimW[2] = 1;
+  CUDNN_CHECK(cudnnSetFilterNdDescriptor(wDesc, cudnnDataType,
+                                         CUDNN_TENSOR_NCHW, 3, dimW));
+  CUDNN_CHECK(cudnnSetFilterNdDescriptor(dwDesc, cudnnDataType,
+                                         CUDNN_TENSOR_NCHW, 3, dimW));
+}
+
+void CudnnRNNHandle::init_rnn_desc() {
+  /* rnn desc */
+  CUDNN_CHECK(cudnnCreateRNNDescriptor(&rnnDesc));
+  if (mode == 0)
+    RNNMode = CUDNN_RNN_RELU;
+  else if (mode == 1)
+    RNNMode = CUDNN_RNN_TANH;
+  else if (mode == 2)
+    RNNMode = CUDNN_LSTM;
+  else if (mode == 3)
+    RNNMode = CUDNN_GRU;
+  CUDNN_CHECK(cudnnSetRNNDescriptor(
+      ctx->cudnn_handle, rnnDesc, hidden_size, num_layers, dropoutDesc,
+      CUDNN_LINEAR_INPUT,
+      bidirectional ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL, RNNMode,
+      cudnnRNNAlgo,  // CUDNN_RNN_ALGO_STANDARD,
+      cudnnDataType));
+}
+void CudnnRNNHandle::init_dropout_desc() {
+  /* drop out */
+  size_t seed = 0x1234567;
+  CUDNN_CHECK(cudnnCreateDropoutDescriptor(&dropoutDesc));
+  size_t stateSize;
+  CUDNN_CHECK(cudnnDropoutGetStatesSize(ctx->cudnn_handle, &stateSize));
+  CUDA_CHECK(cudaMalloc(&states, stateSize));
+  CUDNN_CHECK(cudnnSetDropoutDescriptor(dropoutDesc, ctx->cudnn_handle, dropout,
+                                        states, stateSize, seed));
+}
+
+// reserve for masking
+Tensor CudnnRNNHandle::merge_inputs(size_t num, const vector<Tensor> &in) {
+  if (num == 1) return in.at(0);
+  size_t size = 0;
+  for (size_t i = 0; i < num; i++) size += in.at(i).Size();
+  Tensor out(Shape{size}, in.at(0).device(), in.at(0).data_type());
+  for (size_t i = 0, offset = 0; i < num; i++) {
+    CopyDataToFrom(&out, in.at(i), in.at(i).Size(), offset);
+    offset += in.at(i).Size();
+  }
+  return out;
+}
+vector<Tensor> CudnnRNNHandle::split_output(size_t num, size_t dim,
+                                            const vector<Tensor> &in,
+                                            const Tensor output) {
+  vector<Tensor> outputs;
+  if (num == 1) {
+    outputs.push_back(Reshape(output, Shape{in.at(0).shape(0), dim}));
+  } else {
+    for (size_t i = 0, offset = 0; offset < output.Size(); i++) {
+      Shape s{in.at(i).shape(0), dim};
+      Tensor out(s, output.device(), output.data_type());
+      CopyDataToFrom(&out, output, out.Size(), 0, offset);
+      outputs.push_back(out);
+      offset += out.Size();
+    }
+    CHECK_EQ(num, outputs.size());
+  }
+  return outputs;
+}
+
+void init_yDesc(cudnnTensorDescriptor_t *yDesc, CudnnRNNHandle &h) {
+  int dimA[3];
+  int strideA[3];
+  dimA[0] = h.batch_size;
+  dimA[1] = h.bidirectional ? h.hidden_size * 2 : h.hidden_size;
+  dimA[2] = 1;
+  strideA[0] = dimA[2] * dimA[1];
+  strideA[1] = dimA[2];
+  strideA[2] = 1;
+
+  for (int i = 0; i < h.seq_length; i++) {
+    CUDNN_CHECK(cudnnCreateTensorDescriptor(&yDesc[i]));
+    CUDNN_CHECK(cudnnSetTensorNdDescriptor(yDesc[i], h.cudnnDataType, 3, dimA,
+                                           strideA));
+  }
+}
+
+void init_xDesc(cudnnTensorDescriptor_t *xDesc, CudnnRNNHandle &h) {
+  int dimA[3];
+  int strideA[3];
+  dimA[0] = h.batch_size;
+  dimA[1] = h.feature_size;
+  dimA[2] = 1;
+  strideA[0] = dimA[2] * dimA[1];
+  strideA[1] = dimA[2];
+  strideA[2] = 1;
+
+  for (int i = 0; i < h.seq_length; i++) {
+    CUDNN_CHECK(cudnnCreateTensorDescriptor(&xDesc[i]));
+    CUDNN_CHECK(cudnnSetTensorNdDescriptor(xDesc[i], h.cudnnDataType, 3, dimA,
+                                           strideA));
+  }
+}
+
+void init_hc_Desc(cudnnTensorDescriptor_t &hxDesc, CudnnRNNHandle &h) {
+  int dimA[3];
+  int strideA[3];
+  dimA[0] = h.num_layers * (h.bidirectional ? 2 : 1);
+  dimA[1] = h.batch_size;
+  dimA[2] = h.hidden_size;
+  strideA[0] = dimA[2] * dimA[1];
+  strideA[1] = dimA[2];
+  strideA[2] = 1;
+  CUDNN_CHECK(cudnnCreateTensorDescriptor(&hxDesc));
+  CUDNN_CHECK(
+      cudnnSetTensorNdDescriptor(hxDesc, h.cudnnDataType, 3, dimA, strideA));
+}
+
+vector<Tensor> GpuRNNForwardInference(const Tensor &x, const Tensor &hx,
+                                      const Tensor &cx, const Tensor &W,
+                                      CudnnRNNHandle &h) {
+  CHECK_EQ(h.feature_size, x.shape(2)) << "feature size should not change";
+  h.seq_length = x.shape(0);
+  h.batch_size = x.shape(1);  // update batch size to accomodate bs change
+  Tensor y(Shape{h.seq_length, h.batch_size,
+                 h.hidden_size * (h.bidirectional ? 2 : 1)},
+           x.device());
+  Tensor hy(Shape{h.num_layers * (h.bidirectional ? 2 : 1), h.batch_size,
+                  h.hidden_size},
+            x.device());
+  Tensor cy(Shape{h.num_layers * (h.bidirectional ? 2 : 1), h.batch_size,
+                  h.hidden_size},
+            x.device());
+
+  y.device()->Exec(
+      [&y, &hy, &cy, &x, &hx, &cx, &W, &h](Context *ctx) {

Review comment:
       Can't submit the operator like this. All tensors that are not parameters should be passed by value (these tensors will be released, and there will be some segmentation faults if they are not passed by value). Make sure the read_blocks and write_blocks are correct.
   You can refer to [convolution.cc](https://github.com/apache/singa/blob/master/src/model/operation/convolution.cc#L590).




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org