You are viewing a plain text version of this content. The canonical link for it is here.
Posted to dev@singa.apache.org by GitBox <gi...@apache.org> on 2020/06/09 15:55:24 UTC

[GitHub] [singa] dcslin commented on a change in pull request #722: cudnn lstm

dcslin commented on a change in pull request #722:
URL: https://github.com/apache/singa/pull/722#discussion_r436782678



##########
File path: python/singa/autograd.py
##########
@@ -1190,6 +1190,41 @@ def cross_entropy(y, t):
     return CrossEntropy()(y, t)[0]
 
 
+class QALSTMLoss(Operator):
+
+    def __init__(self, M=0.2):
+        super(QALSTMLoss, self).__init__()
+        self.M = M
+
+    def forward(self, pos, neg):
+        # L = max{0, M - cosine(q, a+) + cosine(q, a-)}
+        zero = singa.Tensor(list(pos.shape()), pos.device())
+        zero.SetFloatValue(0.0)
+        val = singa.AddFloat(singa.__sub__(neg, pos), self.M)
+        gt_zero = singa.__gt__(val, zero)
+        self.inputs = (gt_zero, ) # (BS,)
+        all_loss = singa.__mul__(gt_zero, val)
+        loss = singa.SumAll(all_loss)
+        loss /= (pos.shape()[0])
+        # assert loss.shape(0) == 1
+        return loss

Review comment:
       1. neg and pos are tensors of shape (bs,).
   2. Do you mean replacing singa.__sub__ with "-"?
   3. gt_zero is a tensor of shape (bs,) whose values are floats.

##########
File path: src/model/operation/rnn.cc
##########
@@ -0,0 +1,407 @@
+/*********************************************************
+ *
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ *
+ ************************************************************/
+
+#include "rnn.h"
+namespace singa {
+#ifdef USE_CUDNN
+CudnnRNNHandle::CudnnRNNHandle(const Tensor &x, const int hidden_size,
+                               const int mode, const int num_layers,
+                               const int bias, const float dropout,
+                               const int bidirectional)
+    : bias(bias),
+      dropout(dropout),
+      bidirectional(bidirectional),
+      hidden_size(hidden_size),
+      mode(mode),
+      num_layers(num_layers) {
+  CHECK_EQ(bias, 1) << "Current implementation always include bias";
+  CHECK(bidirectional == 0 || bidirectional == 1)
+      << "bidirectional should be 0 or 1 not " << bidirectional;
+
+  dev = x.device();
+  ctx = x.device()->context(0);
+
+  seq_length = x.shape(0);
+  batch_size = x.shape(1);
+  feature_size = x.shape(2);
+
+  cudnnRNNAlgo = CUDNN_RNN_ALGO_STANDARD;
+  cudnnDataType = CUDNN_DATA_FLOAT;
+
+  cudnnTensorDescriptor_t *xDesc = new cudnnTensorDescriptor_t[seq_length];
+  init_xDesc(xDesc, *this);
+
+  init_dropout_desc();
+  init_rnn_desc();
+  init_parameters_desc(xDesc);
+  init_workspace(xDesc);
+}
+
+void CudnnRNNHandle::init_workspace(cudnnTensorDescriptor_t *xDesc) {
+  /* workspace data */
+  // Need for every pass
+  CUDNN_CHECK(cudnnGetRNNWorkspaceSize(ctx->cudnn_handle, rnnDesc, seq_length,
+                                       xDesc, &workspace_size));
+  // Only needed in training, shouldn't be touched between passes.
+  CUDNN_CHECK(cudnnGetRNNTrainingReserveSize(ctx->cudnn_handle, rnnDesc,
+                                             seq_length, xDesc, &reserve_size));
+
+  workspace = Tensor(Shape{workspace_size}, dev);
+  reserve_space = Tensor(Shape{reserve_size}, dev);
+}
+
+void CudnnRNNHandle::init_parameters_desc(cudnnTensorDescriptor_t *xDesc) {
+  /* weights size
+   *   depends on rnn desc */
+  CUDNN_CHECK(cudnnGetRNNParamsSize(ctx->cudnn_handle, rnnDesc, xDesc[0],
+                                    &weights_size, cudnnDataType));
+  /* weights desc
+   *   depends on weights size */
+  CUDNN_CHECK(cudnnCreateFilterDescriptor(&wDesc));
+  CUDNN_CHECK(cudnnCreateFilterDescriptor(&dwDesc));
+
+  int dimW[3];
+  dimW[0] = weights_size / sizeof(float);  // TODO different types
+  dimW[1] = 1;
+  dimW[2] = 1;
+  CUDNN_CHECK(cudnnSetFilterNdDescriptor(wDesc, cudnnDataType,
+                                         CUDNN_TENSOR_NCHW, 3, dimW));
+  CUDNN_CHECK(cudnnSetFilterNdDescriptor(dwDesc, cudnnDataType,
+                                         CUDNN_TENSOR_NCHW, 3, dimW));
+}
+
+void CudnnRNNHandle::init_rnn_desc() {
+  /* rnn desc */
+  CUDNN_CHECK(cudnnCreateRNNDescriptor(&rnnDesc));
+  if (mode == 0)
+    RNNMode = CUDNN_RNN_RELU;
+  else if (mode == 1)
+    RNNMode = CUDNN_RNN_TANH;
+  else if (mode == 2)
+    RNNMode = CUDNN_LSTM;
+  else if (mode == 3)
+    RNNMode = CUDNN_GRU;
+  CUDNN_CHECK(cudnnSetRNNDescriptor(
+      ctx->cudnn_handle, rnnDesc, hidden_size, num_layers, dropoutDesc,
+      CUDNN_LINEAR_INPUT,
+      bidirectional ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL, RNNMode,
+      cudnnRNNAlgo,  // CUDNN_RNN_ALGO_STANDARD,
+      cudnnDataType));
+}
+void CudnnRNNHandle::init_dropout_desc() {
+  /* drop out */
+  size_t seed = 0x1234567;
+  CUDNN_CHECK(cudnnCreateDropoutDescriptor(&dropoutDesc));
+  size_t stateSize;
+  CUDNN_CHECK(cudnnDropoutGetStatesSize(ctx->cudnn_handle, &stateSize));
+  CUDA_CHECK(cudaMalloc(&states, stateSize));
+  CUDNN_CHECK(cudnnSetDropoutDescriptor(dropoutDesc, ctx->cudnn_handle, dropout,
+                                        states, stateSize, seed));
+}
+
+// reserve for masking
+Tensor CudnnRNNHandle::merge_inputs(size_t num, const vector<Tensor> &in) {
+  if (num == 1) return in.at(0);
+  size_t size = 0;
+  for (size_t i = 0; i < num; i++) size += in.at(i).Size();
+  Tensor out(Shape{size}, in.at(0).device(), in.at(0).data_type());
+  for (size_t i = 0, offset = 0; i < num; i++) {
+    CopyDataToFrom(&out, in.at(i), in.at(i).Size(), offset);
+    offset += in.at(i).Size();
+  }
+  return out;
+}
+vector<Tensor> CudnnRNNHandle::split_output(size_t num, size_t dim,
+                                            const vector<Tensor> &in,
+                                            const Tensor output) {
+  vector<Tensor> outputs;
+  if (num == 1) {
+    outputs.push_back(Reshape(output, Shape{in.at(0).shape(0), dim}));
+  } else {
+    for (size_t i = 0, offset = 0; offset < output.Size(); i++) {
+      Shape s{in.at(i).shape(0), dim};
+      Tensor out(s, output.device(), output.data_type());
+      CopyDataToFrom(&out, output, out.Size(), 0, offset);
+      outputs.push_back(out);
+      offset += out.Size();
+    }
+    CHECK_EQ(num, outputs.size());
+  }
+  return outputs;
+}
+
+void init_yDesc(cudnnTensorDescriptor_t *yDesc, CudnnRNNHandle &h) {
+  int dimA[3];
+  int strideA[3];
+  dimA[0] = h.batch_size;
+  dimA[1] = h.bidirectional ? h.hidden_size * 2 : h.hidden_size;
+  dimA[2] = 1;
+  strideA[0] = dimA[2] * dimA[1];
+  strideA[1] = dimA[2];
+  strideA[2] = 1;
+
+  for (int i = 0; i < h.seq_length; i++) {
+    CUDNN_CHECK(cudnnCreateTensorDescriptor(&yDesc[i]));
+    CUDNN_CHECK(cudnnSetTensorNdDescriptor(yDesc[i], h.cudnnDataType, 3, dimA,
+                                           strideA));
+  }
+}
+
+void init_xDesc(cudnnTensorDescriptor_t *xDesc, CudnnRNNHandle &h) {
+  int dimA[3];
+  int strideA[3];
+  dimA[0] = h.batch_size;
+  dimA[1] = h.feature_size;
+  dimA[2] = 1;
+  strideA[0] = dimA[2] * dimA[1];
+  strideA[1] = dimA[2];
+  strideA[2] = 1;
+
+  for (int i = 0; i < h.seq_length; i++) {
+    CUDNN_CHECK(cudnnCreateTensorDescriptor(&xDesc[i]));
+    CUDNN_CHECK(cudnnSetTensorNdDescriptor(xDesc[i], h.cudnnDataType, 3, dimA,
+                                           strideA));
+  }
+}
+
+void init_hc_Desc(cudnnTensorDescriptor_t &hxDesc, CudnnRNNHandle &h) {
+  int dimA[3];
+  int strideA[3];
+  dimA[0] = h.num_layers * (h.bidirectional ? 2 : 1);
+  dimA[1] = h.batch_size;
+  dimA[2] = h.hidden_size;
+  strideA[0] = dimA[2] * dimA[1];
+  strideA[1] = dimA[2];
+  strideA[2] = 1;
+  CUDNN_CHECK(cudnnCreateTensorDescriptor(&hxDesc));
+  CUDNN_CHECK(
+      cudnnSetTensorNdDescriptor(hxDesc, h.cudnnDataType, 3, dimA, strideA));
+}
+
+vector<Tensor> GpuRNNForwardInference(const Tensor &x, const Tensor &hx,
+                                      const Tensor &cx, const Tensor &W,
+                                      CudnnRNNHandle &h) {
+  CHECK_EQ(h.feature_size, x.shape(2)) << "feature size should not change";
+  h.seq_length = x.shape(0);
+  h.batch_size = x.shape(1);  // update batch size to accomodate bs change
+  Tensor y(Shape{h.seq_length, h.batch_size,
+                 h.hidden_size * (h.bidirectional ? 2 : 1)},
+           x.device());
+  Tensor hy(Shape{h.num_layers * (h.bidirectional ? 2 : 1), h.batch_size,
+                  h.hidden_size},
+            x.device());
+  Tensor cy(Shape{h.num_layers * (h.bidirectional ? 2 : 1), h.batch_size,
+                  h.hidden_size},
+            x.device());
+
+  y.device()->Exec(
+      [&y, &hy, &cy, &x, &hx, &cx, &W, &h](Context *ctx) {

Review comment:
       ok let me check again

##########
File path: examples/qabot/train.py
##########
@@ -0,0 +1,162 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+import os
+import sys
+new_path = r'/root/singa/build/python'
+sys.path.append(new_path)
+
+from data import *
+from model import QAModel, MLP #, QALoss
+
+from singa import autograd
+from singa import layer
+from singa import model
+from singa import tensor
+from singa import device
+from singa import opt
+from tqdm import tqdm
+
+import singa.singa_wrap as singa
+
+import numpy as np
+
+def train(m, tq, ta, id_to_word, label_to_ans_text, wv):
+    print("training")
+    train_data = parse_file('./V2/InsuranceQA.question.anslabel.token.100.pool.solr.train.encoded', id_to_word, label_to_ans_text)
+    train_data = train_data[:100]
+    train_triplets = generate_qa_triplets(train_data) # (q, a+, a-)
+
+
+    for epoch in range(max_epoch):
+        for (q,apos,aneg) in tqdm(train_triplets):
+            q = words_text_to_fixed_seqlen_vec(wv,q,q_seq_length)
+            a = np.array([words_text_to_fixed_seqlen_vec(wv,apos,a_seq_length),words_text_to_fixed_seqlen_vec(wv,aneg,a_seq_length)])
+            q = q.astype(np.float32)
+            a = a.astype(np.float32)
+
+            tq.copy_from_numpy(q)
+            ta.copy_from_numpy(a)
+
+            # train
+            _, l = m(tq,ta)
+
+        print("loss", l)
+
+        # training top1 accuracy
+        top1hit = 0
+        trials = len(train_data)
+
+        for (q, a_pos, a_negs) in train_data:
+            scores = []
+            q_vec = words_text_to_fixed_seqlen_vec(wv,q,q_seq_length)
+            tq.copy_from_numpy(q_vec)
+
+            a_pos_vec = words_text_to_fixed_seqlen_vec(wv,a_pos,a_seq_length)
+            # prepare for <q, a+, a+> input
+            ta.copy_from_numpy(np.array([a_pos_vec]*2))
+            true_score, l = m(tq,ta)
+
+            a_neg_vecs = [words_text_to_fixed_seqlen_vec(wv,a_neg,a_seq_length) for a_neg in a_negs]
+
+            # prepare for triplets <q, a-, a-> input
+            while len(a_neg_vecs) > 1:
+                a_vec=[]
+                a_vec.append(a_neg_vecs.pop(0))
+                a_vec.append(a_neg_vecs.pop(0))
+                ta.copy_from_numpy(np.array(a_vec))
+                score, l = m(tq,ta)
+                scores.extend(score)
+
+            max_neg = np.max(np.array([tensor.to_numpy(s) for s in scores]).flatten())
+            if max_neg < tensor.to_numpy(true_score[0])[0]:
+                top1hit+=1
+
+        print("training top 1 hit accuracy: ", top1hit/trials)
+
+
+def test(m, tq, ta, id_to_word, label_to_ans_text, wv):
+    print("testing")
+    test_data = parse_test_file('./V2/InsuranceQA.question.anslabel.token.100.pool.solr.test.encoded', id_to_word, label_to_ans_text)
+    test_data = test_data[:10]  # run on n samples
+
+    m.eval()
+    top1hit=0
+    trials = len(test_data)
+    for (q, labels, cands) in test_data:
+
+        q_vec = words_text_to_fixed_seqlen_vec(wv, q, q_seq_length)
+        tq.copy_from_numpy(np.array(q_vec))
+
+        cands_vec = [words_text_to_fixed_seqlen_vec(wv, label_to_ans_text[candidate_label], a_seq_length) for candidate_label in cands]
+
+        scores = []
+        # inference all candidates
+        # import pdb; pdb.set_trace()
+        while len(cands_vec) > 1:
+            a_vec=[]
+            a_vec.append(cands_vec.pop(0))
+            a_vec.append(cands_vec.pop(0))
+            ta.copy_from_numpy(np.array(a_vec))
+            score = m(tq,ta) # inference mode only return forward result
+            scores.extend(score)
+
+        # check correct from predict
+        true_idxs = [cands.index(l) for l in labels if l in cands]
+        pred_idx = np.argmax(np.array([tensor.to_numpy(s) for s in scores]).flatten())
+        if pred_idx in true_idxs:
+            top1hit += 1
+
+    print("testing top 1 hit accuracy: ", top1hit/trials)
+
+if __name__ == "__main__":
+    dev = device.create_cuda_gpu(set_default=False)
+
+    q_seq_length = 10
+    a_seq_length = 100
+    embed_size = 300
+    batch_size = 128
+    max_epoch = 30
+    hidden_size = 100
+
+    # build model
+    m = QAModel(hidden_size)
+    print("created qa model")
+    # m = MLP()
+    m.optimizer = opt.SGD()
+
+    tq = tensor.Tensor((1, q_seq_length, embed_size), dev, tensor.float32)
+    ta = tensor.Tensor((2, a_seq_length, embed_size), dev, tensor.float32)
+
+    tq.set_value(0.0)
+    ta.set_value(0.0)
+
+    m.compile([tq, ta], is_train=True, use_graph=False, sequential=False)

Review comment:
       Setting use_graph to True gives the following error:
   ```
   Traceback (most recent call last):
     File "/root/singa/examples/qabot/train.py", line 184, in <module>
       training(m, train_data, max_epoch)
     File "/root/singa/examples/qabot/train.py", line 126, in training
       score, l = m(tq, ta)
     File "build/python/singa/model.py", line 210, in __call__
       return self.train_one_batch(*input, **kwargs)
     File "build/python/singa/model.py", line 63, in wrapper
       self._results = func(self, *args, **kwargs)
     File "/root/singa/examples/qabot/model.py", line 65, in train_one_batch
       self.optimizer.backward_and_update(loss)
     File "/usr/local/lib/python3.6/dist-packages/deprecated/classic.py", line 281, in wrapper_function
       return wrapped_(*args_, **kwargs_)
     File "build/python/singa/opt.py", line 320, in backward_and_update
       super(SGD, self).__call__(loss)
     File "build/python/singa/opt.py", line 103, in __call__
       self.call(loss)
     File "build/python/singa/opt.py", line 110, in call
       self.apply(p.name, p, g)
     File "build/python/singa/opt.py", line 270, in apply
       singa.Axpy(minus_lr.data, param_grad.data, param_value.data)
   TypeError: in method 'Axpy', argument 1 of type 'float'
   ```

##########
File path: examples/qabot/model.py
##########
@@ -0,0 +1,79 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+from singa import autograd
+from singa import layer
+from singa import model
+from singa import tensor
+from singa import device
+
+class QAModel(model.Model):
+    def __init__(self, hidden_size, num_layers=1, rnn_mode="lstm", batch_first=True):
+        super(QAModel, self).__init__()
+        self.lstm_q = layer.CudnnRNN(hidden_size=hidden_size,
+                                   num_layers=num_layers,
+                                   bidirectional=True,
+                                   return_sequences=False,
+                                   rnn_mode=rnn_mode,
+                                   batch_first=batch_first)
+        self.lstm_a = layer.CudnnRNN(hidden_size=hidden_size,
+                                   num_layers=num_layers,
+                                   bidirectional=True,
+                                   return_sequences=False,
+                                   rnn_mode=rnn_mode,
+                                   batch_first=batch_first)
+
+    def forward(self, q, a_batch):
+        q = self.lstm_q(q) # BS, Hidden*2
+        a_batch = self.lstm_a(a_batch) # {2, hidden*2}
+        bs_a = int(a_batch.shape[0]/2) # cut concated a-a+ to half and half

Review comment:
       Most of the time bs_a does not change, but it sometimes does.

##########
File path: examples/qabot/model.py
##########
@@ -0,0 +1,79 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+from singa import autograd
+from singa import layer
+from singa import model
+from singa import tensor
+from singa import device
+
+class QAModel(model.Model):
+    def __init__(self, hidden_size, num_layers=1, rnn_mode="lstm", batch_first=True):
+        super(QAModel, self).__init__()
+        self.lstm_q = layer.CudnnRNN(hidden_size=hidden_size,
+                                   num_layers=num_layers,
+                                   bidirectional=True,
+                                   return_sequences=False,
+                                   rnn_mode=rnn_mode,
+                                   batch_first=batch_first)
+        self.lstm_a = layer.CudnnRNN(hidden_size=hidden_size,
+                                   num_layers=num_layers,
+                                   bidirectional=True,
+                                   return_sequences=False,
+                                   rnn_mode=rnn_mode,
+                                   batch_first=batch_first)
+
+    def forward(self, q, a_batch):
+        q = self.lstm_q(q) # BS, Hidden*2
+        a_batch = self.lstm_a(a_batch) # {2, hidden*2}
+        bs_a = int(a_batch.shape[0]/2) # cut concated a-a+ to half and half
+        a_pos, a_neg = autograd.split(a_batch, 0, [bs_a,bs_a])
+        sim_pos = autograd.cossim(q, a_pos)
+        sim_neg = autograd.cossim(q, a_neg)
+        return sim_pos, sim_neg
+
+    def train_one_batch(self, q, a):
+        out = self.forward(q, a)
+        loss = autograd.qa_lstm_loss(out[0], out[1])
+        self.optimizer.backward_and_update(loss)
+        return out, loss
+
+
+
+class MLP(model.Model):
+    def __init__(self):
+        super(MLP, self).__init__()
+        self.linear1 = layer.Linear(500)
+        self.relu = layer.ReLU()
+        self.linear2 = layer.Linear(2)
+
+    def forward(self, q, a):
+        q=autograd.reshape(q, (q.shape[0], -1))

Review comment:
       ok




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org