Posted to commits@singa.apache.org by wa...@apache.org on 2020/07/29 02:04:08 UTC

[singa] branch dev updated: updated qabot training and data scripts, added max, mean, mlp qabot models, updated ranking loss fn in autograd, fix bug in cudnn rnn in autograd, added some utils tensor fn(random, zeros, ones), added cudnn rnn set param api, fixed and added test to autograd mse loss, cos sim, reduce mean

This is an automated email from the ASF dual-hosted git repository.

wangwei pushed a commit to branch dev
in repository https://gitbox.apache.org/repos/asf/singa.git


The following commit(s) were added to refs/heads/dev by this push:
     new 12161b3  updated qabot training and data scripts, added max, mean, mlp qabot models, updated ranking loss fn in autograd, fix bug in cudnn rnn in autograd, added some utils tensor fn(random, zeros, ones), added cudnn rnn set param api, fixed and added test to autograd mse loss, cos sim, reduce mean
     new 493f185  Merge pull request #772 from dcslin/qabot4
12161b3 is described below

commit 12161b35ed653cb2e0244939344b05bb612a1a67
Author: dcslin <13...@users.noreply.github.com>
AuthorDate: Sun Jul 5 02:42:04 2020 +0000

    updated qabot training and data scripts, added max, mean, mlp qabot
    models, updated ranking loss fn in autograd, fix bug in cudnn rnn in
    autograd, added some utils tensor fn(random, zeros, ones), added cudnn
    rnn set param api, fixed and added test to autograd mse loss, cos sim,
    reduce mean
---
 examples/qabot/V2             |   1 -
 examples/qabot/data.py        | 142 ---------------------
 examples/qabot/model.py       |  92 --------------
 examples/qabot/qabot_data.py  | 282 ++++++++++++++++++++++++++++++++++++++++++
 examples/qabot/qabot_model.py | 152 +++++++++++++++++++++++
 examples/qabot/qabot_train.py | 159 ++++++++++++++++++++++++
 examples/qabot/train.py       | 185 ---------------------------
 python/singa/autograd.py      |  72 ++++++-----
 python/singa/layer.py         |  25 ++--
 python/singa/tensor.py        |  31 ++++-
 src/api/model_operation.i     |   3 +
 src/model/operation/rnn.cc    |  91 +++++++++++++-
 src/model/operation/rnn.h     |  12 +-
 test/python/test_operation.py | 151 ++++++++++++++++++++--
 14 files changed, 919 insertions(+), 479 deletions(-)
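
For quick orientation, a minimal sketch (not part of this commit) of how the new tensor helpers (tensor.random, tensor.zeros, tensor.ones) and the renamed ranking loss / cosine similarity ops in autograd might be exercised; it assumes a local SINGA build whose ops run on the default host device:

    from singa import autograd, tensor

    bs, feat = 4, 8
    q = tensor.random((bs, feat))            # new helper: uniform(0, 1) values
    a_pos = tensor.ones((bs, feat))          # new helper: all ones
    a_neg = tensor.zeros((bs, feat)) + 0.1   # new helper: all zeros (shifted to avoid a zero norm)

    sim_pos = autograd.cossim(q, a_pos)      # cosine similarity, 2d inputs of equal shape
    sim_neg = autograd.cossim(q, a_neg)
    loss = autograd.ranking_loss(sim_pos, sim_neg, M=0.2)   # max{0, M - pos + neg}, averaged over the batch
    print(tensor.to_numpy(loss))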

diff --git a/examples/qabot/V2 b/examples/qabot/V2
deleted file mode 120000
index f3a03bc..0000000
--- a/examples/qabot/V2
+++ /dev/null
@@ -1 +0,0 @@
-/root/QA-LSTM/insuranceQA/V2/
\ No newline at end of file
diff --git a/examples/qabot/data.py b/examples/qabot/data.py
deleted file mode 100644
index 40e998c..0000000
--- a/examples/qabot/data.py
+++ /dev/null
@@ -1,142 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-
-from gensim.models.keyedvectors import KeyedVectors
-import numpy as np
-import random
-
-
-def load_vocabulary(vocab_path, label_path):
-    id_to_word = {}
-    with open(vocab_path, 'rb') as f:
-        lines = f.readlines()
-        for l in lines:
-            d = l.rstrip().decode("utf-8").split("\t")
-            if d[0] not in id_to_word:
-                id_to_word[d[0]] = d[1]
-
-    label_to_ans = {}
-    label_to_ans_text = {}
-    with open(label_path) as f:
-        lines = f.readlines()
-        for l in lines:
-            label, answer = l.rstrip().split('\t')
-            if label not in label_to_ans:
-                label_to_ans[label] = answer
-                label_to_ans_text[label] = [
-                    id_to_word[t] for t in answer.split(' ')
-                ]
-    return id_to_word, label_to_ans, label_to_ans_text
-
-
-def parse_file(fpath, id_to_word, label_to_ans_text):
-    data = []
-    with open(fpath) as f:
-        lines = f.readlines()
-        for l in lines:
-            d = l.rstrip().split('\t')
-            q = [id_to_word[t] for t in d[1].split(' ')]  # question
-            poss = [label_to_ans_text[t] for t in d[2].split(' ')
-                   ]  # ground-truth
-            negs = [
-                label_to_ans_text[t] for t in d[3].split(' ') if t not in d[2]
-            ]  # candidate-pool without ground-truth
-            for pos in poss:
-                data.append((q, pos, negs))
-    return data
-
-
-def parse_test_file(fpath, id_to_word, label_to_ans_text):
-    data = []
-    with open(fpath) as f:
-        lines = f.readlines()
-        for l in lines[12:]:
-            d = l.rstrip().split('\t')
-            q = [id_to_word[t] for t in d[1].split(' ')]  # question
-            poss = [t for t in d[2].split(' ')]  # ground-truth
-            cands = [t for t in d[3].split(' ')]  # candidate-pool
-            data.append((q, poss, cands))
-    return data
-
-def words_text_to_fixed_seqlen_vec(word2vec, words, sentence_length=10):
-    sentence_vec = []
-    for word in words:
-        if len(sentence_vec) >= sentence_length:
-            break
-        if word in word2vec:
-            sentence_vec.append(word2vec[word])
-        else:
-            sentence_vec.append(np.zeros((300,)))
-    while len(sentence_vec) < sentence_length:
-        sentence_vec.append(np.zeros((300,)))
-    return np.array(sentence_vec, dtype=np.float32)
-
-
-def generate_qa_triplets(data, num_negs=10):
-    tuples = []
-    for (q, a_pos, a_negs) in data:
-        for i in range(num_negs):
-            tpl = (q, a_pos, random.choice(a_negs))
-            tuples.append(tpl)
-    return tuples
-
-def qa_tuples_to_naive_training_format(wv, tuples):
-    training = []
-    q_len_limit = 10
-    a_len_limit = 100
-    for tpl in tuples:
-        q, a_pos, a_neg = tpl
-        q_vec = words_text_to_fixed_seqlen_vec(wv, q, q_len_limit)
-        training.append(
-            (q_vec, words_text_to_fixed_seqlen_vec(wv, a_pos, a_len_limit), 1))
-        training.append(
-            (q_vec, words_text_to_fixed_seqlen_vec(wv, a_neg, a_len_limit), 0))
-    return training
-
-def triplet_text_to_vec(triplet, wv, q_max_len, a_max_len):
-    return [
-        words_text_to_fixed_seqlen_vec(wv, triplet[0], q_max_len),
-        words_text_to_fixed_seqlen_vec(wv, triplet[1], a_max_len),
-        words_text_to_fixed_seqlen_vec(wv, triplet[2], a_max_len)
-    ]
-
-def train_data_gen_fn(train_triplet_vecs, bs=32):
-    q = []
-    ap = []
-    an = []
-    for t in train_triplet_vecs:
-        q.append(t[0])
-        ap.append(t[1])
-        an.append(t[2])
-        if len(q) >= bs:
-            q = np.array(q)
-            ap = np.array(ap)
-            an = np.array(an)
-            a = np.concatenate([ap, an])
-            assert 2 * q.shape[0] == a.shape[0]
-            yield q, a
-            q = []
-            ap = []
-            an = []
-    # return the rest
-    # return np.array(q), np.concatenate([np.array(ap), np.array(an)])
-
-
-if __name__ == "__main__":
-    pass
diff --git a/examples/qabot/model.py b/examples/qabot/model.py
deleted file mode 100644
index 9252fe5..0000000
--- a/examples/qabot/model.py
+++ /dev/null
@@ -1,92 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-
-from singa import autograd
-from singa import layer
-from singa import model
-from singa import tensor
-from singa import device
-
-
-class QAModel(model.Model):
-
-    def __init__(self,
-                 hidden_size,
-                 num_layers=1,
-                 rnn_mode="lstm",
-                 batch_first=True):
-        super(QAModel, self).__init__()
-        return_sequences = False
-        self.lstm_q = layer.CudnnRNN(hidden_size=hidden_size,
-                                     num_layers=num_layers,
-                                     bidirectional=True,
-                                     return_sequences=return_sequences,
-                                     rnn_mode=rnn_mode,
-                                     batch_first=batch_first)
-        self.lstm_a = layer.CudnnRNN(hidden_size=hidden_size,
-                                     num_layers=num_layers,
-                                     bidirectional=True,
-                                     return_sequences=return_sequences,
-                                     rnn_mode=rnn_mode,
-                                     batch_first=batch_first)
-
-    def forward(self, q, a_batch):
-        q = self.lstm_q(q)  # BS, Hidden*2
-        a_batch = self.lstm_a(a_batch)  # {2, hidden*2}
-
-        # full sequences {2bs, seqlength, hidden*2}
-        # a_batch = autograd.reduce_mean(a_batch, [1]) # to {2bs, hidden*2}
-
-        bs_a = int(a_batch.shape[0] / 2)  # cut concated a-a+ to half and half
-        a_pos, a_neg = autograd.split(a_batch, 0, [bs_a, bs_a])
-
-        sim_pos = autograd.cossim(q, a_pos)
-        sim_neg = autograd.cossim(q, a_neg)
-        return sim_pos, sim_neg
-
-    def train_one_batch(self, q, a):
-        out = self.forward(q, a)
-        loss = autograd.qa_lstm_loss(out[0], out[1])
-        self.optimizer.backward_and_update(loss)
-
-        return out, loss
-
-
-class MLP(model.Model):
-
-    def __init__(self):
-        super(MLP, self).__init__()
-        self.linear1 = layer.Linear(500)
-        self.relu = layer.ReLU()
-        self.linear2 = layer.Linear(2)
-
-    def forward(self, q, a):
-        q = autograd.reshape(q, (q.shape[0], -1))
-        a = autograd.reshape(a, (q.shape[0], -1))
-        qa = autograd.cat([q, a], 1)
-        y = self.linear1(qa)
-        y = self.relu(y)
-        y = self.linear2(y)
-        return y
-
-    def train_one_batch(self, q, a, y):
-        out = self.forward(q, a)
-        loss = autograd.softmax_cross_entropy(out, y)
-        self.optimizer.backward_and_update(loss)
-        return out, loss
diff --git a/examples/qabot/qabot_data.py b/examples/qabot/qabot_data.py
new file mode 100644
index 0000000..a00d736
--- /dev/null
+++ b/examples/qabot/qabot_data.py
@@ -0,0 +1,282 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+import numpy as np
+import random
+
+download_dir = "/tmp/"
+import os
+import urllib.request
+
+
+def check_exist_or_download(url):
+    ''' download data into tmp '''
+    name = url.rsplit('/', 1)[-1]
+    filename = os.path.join(download_dir, name)
+    if not os.path.isfile(filename):
+        print("Downloading %s" % url)
+        urllib.request.urlretrieve(url, filename)
+    return filename
+
+
+def unzip_data(download_dir, data_zip):
+    data_dir = download_dir + "insuranceQA-master/V2/"
+    if not os.path.exists(data_dir):
+        print("extracting %s to %s" % (download_dir, data_dir))
+        from zipfile import ZipFile
+        with ZipFile(data_zip, 'r') as zipObj:
+            zipObj.extractall(download_dir)
+    return data_dir
+
+
+def get_label2answer(data_dir):
+    import gzip
+    label2answer = dict()
+    with gzip.open(data_dir +
+                   "/InsuranceQA.label2answer.token.encoded.gz") as fin:
+        for line in fin:
+            pair = line.decode().strip().split("\t")
+            idxs = pair[1].split(" ")
+            idxs = [int(idx.replace("idx_", "")) for idx in idxs]
+            label2answer[int(pair[0])] = idxs
+    return label2answer
+
+
+pad_idx = 0
+pad_string = "<pad>"
+pad_embed = np.zeros((300,))
+
+insuranceqa_train_filename = "/InsuranceQA.question.anslabel.token.100.pool.solr.train.encoded.gz"
+insuranceqa_test_filename = "/InsuranceQA.question.anslabel.token.100.pool.solr.test.encoded.gz"
+insuranceQA_url = "https://github.com/shuzi/insuranceQA/archive/master.zip"
+insuranceQA_cache_fp = download_dir + "insuranceQA_cache.pickle"
+google_news_pretrain_embeddings_link = "https://s3.amazonaws.com/dl4j-distribution/GoogleNews-vectors-negative300.bin.gz"
+
+
+def get_idx2word(data_dir):
+    idx2word = dict()
+    with open(data_dir + "vocabulary") as vc_f:
+        for line in vc_f:
+            pair = line.strip().split("\t")
+            idx = int(pair[0].replace("idx_", ""))
+            idx2word[idx] = pair[1]
+
+    # add padding string to idx2word lookup
+    idx2word[pad_idx] = pad_string
+
+    return idx2word
+
+
+def get_train_raw(data_dir, data_filename):
+    ''' deserialize training data file
+        args:
+            data_dir: dir of the data file; data_filename: file to parse
+        return:
+            train_raw: list of QnA tuples, length of list == number of samples,
+                each tuple has 3 fields:
+                    0 is the question sentence, idx encoded; use idx2word to
+                        decode and idx2vec to get embeddings.
+                    1 is the ans labels; each label corresponds to an ans
+                        sentence, use label2answer to decode.
+                    2 is the top K candidate ans; these serve as negative ans
+                        for training.
+    '''
+    train_raw = []
+    import gzip
+    with gzip.open(data_dir + data_filename) as fin:
+        for line in fin:
+            tpl = line.decode().strip().split("\t")
+            question = [
+                int(idx.replace("idx_", "")) for idx in tpl[1].split(" ")
+            ]
+            ans = [int(label) for label in tpl[2].split(" ")]
+            candis = [int(label) for label in tpl[3].split(" ")]
+            train_raw.append((question, ans, candis))
+    return train_raw
+
+
+def limit_encode_train(train_raw, label2answer, idx2word, q_seq_limit,
+                       ans_seq_limit, idx2vec):
+    ''' prepare train data as embedded word vector sequences, given sequence limits
+        return:
+            questions_encoded: np ndarray, shape
+                (number samples, seq length, vector size)
+            poss_encoded: same layout, sequence for positive answer
+            negs_encoded: same layout, sequence for negative answer
+    '''
+    questions = [question for question, answers, candis in train_raw]
+    # choose 1 answer from answer pool
+    poss = [
+        label2answer[random.choice(answers)]
+        for question, answers, candis in train_raw
+    ]
+    # choose 1 candidate from candidate pool
+    negs = [
+        label2answer[random.choice(candis)]
+        for question, answers, candis in train_raw
+    ]
+
+    # filter out words not in idx2vec
+    questions_filtered = [
+        [idx for idx in q if idx in idx2vec] for q in questions
+    ]
+    poss_filtered = [[idx for idx in ans if idx in idx2vec] for ans in poss]
+    negs_filtered = [[idx for idx in ans if idx in idx2vec] for ans in negs]
+
+    # crop/pad to seq limit
+    questions_crop = [
+        q[:q_seq_limit] + [0] * max(0, q_seq_limit - len(q))
+        for q in questions_filtered
+    ]
+    poss_crop = [
+        ans[:ans_seq_limit] + [0] * max(0, ans_seq_limit - len(ans))
+        for ans in poss_filtered
+    ]
+    negs_crop = [
+        ans[:ans_seq_limit] + [0] * max(0, ans_seq_limit - len(ans))
+        for ans in negs_filtered
+    ]
+
+    # encoded, word idx to word vector
+    questions_encoded = [[idx2vec[idx] for idx in q] for q in questions_crop]
+    poss_encoded = [[idx2vec[idx] for idx in ans] for ans in poss_crop]
+    negs_encoded = [[idx2vec[idx] for idx in ans] for ans in negs_crop]
+
+    # make nd array
+    questions_encoded = np.array(questions_encoded).astype(np.float32)
+    poss_encoded = np.array(poss_encoded).astype(np.float32)
+    negs_encoded = np.array(negs_encoded).astype(np.float32)
+    return questions_encoded, poss_encoded, negs_encoded
+
+
+def get_idx2vec_weights(wv, idx2word):
+    idx2vec = {k: wv[v] for k, v in idx2word.items() if v in wv}
+
+    # add padding embedding (all zeros) to idx2vec lookup
+    idx2vec[pad_idx] = pad_embed
+    return idx2vec
+
+
+def prepare_data(use_cache=True):
+    import pickle
+    if not os.path.isfile(insuranceQA_cache_fp) or not use_cache:
+        # no cache is found, preprocess data from scratch
+        print("prepare data from scratch")
+
+        # get pretrained word vectors
+        from gensim.models.keyedvectors import KeyedVectors
+        google_news_pretrain_fp = check_exist_or_download(
+            google_news_pretrain_embeddings_link)
+        wv = KeyedVectors.load_word2vec_format(google_news_pretrain_fp,
+                                               binary=True)
+
+        # prepare insurance QA dataset
+        data_zip = check_exist_or_download(insuranceQA_url)
+        data_dir = unzip_data(download_dir, data_zip)
+
+        label2answer = get_label2answer(data_dir)
+        idx2word = get_idx2word(data_dir)
+        idx2vec = get_idx2vec_weights(wv, idx2word)
+
+        train_raw = get_train_raw(data_dir, insuranceqa_train_filename)
+        test_raw = get_train_raw(data_dir, insuranceqa_test_filename)
+        with open(insuranceQA_cache_fp, 'wb') as handle:
+            pickle.dump((train_raw, test_raw, label2answer, idx2word, idx2vec),
+                        handle,
+                        protocol=pickle.HIGHEST_PROTOCOL)
+    else:
+        # load from cached pickle
+        with open(insuranceQA_cache_fp, 'rb') as handle:
+            (train_raw, test_raw, label2answer, idx2word,
+             idx2vec) = pickle.load(handle)
+
+    return train_raw, test_raw, label2answer, idx2word, idx2vec
+
+
+def limit_encode_eval(train_raw,
+                      label2answer,
+                      idx2word,
+                      q_seq_limit,
+                      ans_seq_limit,
+                      idx2vec,
+                      top_k_candi_limit=6):
+    ''' prepare eval data as embedded word vector sequences, given sequence limits
+        return:
+            questions_encoded: np ndarray, shape
+                (number samples, seq length, vector size)
+            candi_pools_encoded: np ndarray, shape (number samples, pool size,
+                seq length, vector size); ans_count: number of truth ans per sample
+    '''
+    questions = [question for question, answers, candis in train_raw]
+
+    # combine ground-truth and candidate answer labels
+    candi_pools = [
+        list(answers + candis)[:top_k_candi_limit]
+        for question, answers, candis in train_raw
+    ]
+    assert all([len(pool) == top_k_candi_limit for pool in candi_pools])
+
+    ans_count = [len(answers) for question, answers, candis in train_raw]
+    assert all([c > 0 for c in ans_count])
+
+    # encode ans
+    candi_pools_encoded = [[label2answer[candi_label]
+                            for candi_label in pool]
+                           for pool in candi_pools]
+
+    # filter out words not in idx2vec
+    questions_filtered = [
+        [idx for idx in q if idx in idx2vec] for q in questions
+    ]
+    candi_pools_filtered = [[[idx
+                              for idx in candi_encoded
+                              if idx in idx2vec]
+                             for candi_encoded in pool]
+                            for pool in candi_pools_encoded]
+
+    # crop/pad to seq limit
+    questions_crop = [
+        q[:q_seq_limit] + [0] * max(0, q_seq_limit - len(q))
+        for q in questions_filtered
+    ]
+    candi_pools_crop = [[
+        candi[:ans_seq_limit] + [0] * max(0, ans_seq_limit - len(candi))
+        for candi in pool
+    ]
+                        for pool in candi_pools_filtered]
+
+    # encoded, word idx to word vector
+    questions_encoded = [[idx2vec[idx] for idx in q] for q in questions_crop]
+    candi_pools_encoded = [[[idx2vec[idx]
+                             for idx in candi]
+                            for candi in pool]
+                           for pool in candi_pools_crop]
+    questions_encoded = np.array(questions_encoded).astype(np.float32)
+    candi_pools_encoded = np.array(candi_pools_encoded).astype(np.float32)
+
+    # candi_pools_encoded shape:
+    #    (number of QnA samples,
+    #     number of candidates in pool,
+    #     number of word idx per candidate,
+    #     300-dim word embedding per word idx)
+    #  e.g. (10, 5, 8, 300):
+    #      10 QnA samples to test
+    #      each question has 5 candidate ans
+    #      each ans has 8 words, each word a 300-dim vector
+    return questions_encoded, candi_pools_encoded, ans_count
diff --git a/examples/qabot/qabot_model.py b/examples/qabot/qabot_model.py
new file mode 100644
index 0000000..d5a9d88
--- /dev/null
+++ b/examples/qabot/qabot_model.py
@@ -0,0 +1,152 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+from singa import autograd, layer, model
+
+
+class QAModel_mlp(model.Model):
+
+    def __init__(self, hidden_size):
+        super().__init__()
+        self.linear_q = layer.Linear(hidden_size)
+        self.linear_a = layer.Linear(hidden_size)
+
+    def forward(self, q, a_batch):
+        q = autograd.reshape(q, (q.shape[0], -1))  # bs, seq_q*data_s
+        a_batch = autograd.reshape(a_batch,
+                                   (a_batch.shape[0], -1))  # 2bs, seq_a*data_s
+
+        q = self.linear_q(q)  # bs, hid_s
+        a_batch = self.linear_a(a_batch)  # 2bs, hid_s
+
+        a_pos, a_neg = autograd.split(a_batch, 0,
+                                      [q.shape[0], q.shape[0]])  # 2*(bs, hid)
+
+        sim_pos = autograd.cossim(q, a_pos)
+        sim_neg = autograd.cossim(q, a_neg)
+        return sim_pos, sim_neg
+
+
+class QAModel(model.Model):
+
+    def __init__(self,
+                 hidden_size,
+                 num_layers=1,
+                 bidirectional=True,
+                 return_sequences=False):
+        super(QAModel, self).__init__()
+        self.hidden_size = hidden_size
+        self.lstm_q = layer.CudnnRNN(hidden_size=hidden_size,
+                                     bidirectional=bidirectional,
+                                     return_sequences=return_sequences)
+        self.lstm_a = layer.CudnnRNN(hidden_size=hidden_size,
+                                     bidirectional=bidirectional,
+                                     return_sequences=return_sequences)
+
+    def forward(self, q, a_batch):
+        q = self.lstm_q(q)  # bs, Hidden*2
+        a_batch = self.lstm_a(a_batch)  # 2bs, Hidden*2
+
+        bs_a = q.shape[0]
+        # bs, hid*2
+        a_pos, a_neg = autograd.split(a_batch, 0, [bs_a, bs_a])
+
+        sim_pos = autograd.cossim(q, a_pos)
+        sim_neg = autograd.cossim(q, a_neg)
+        return sim_pos, sim_neg
+
+
+class QAModel_mean(model.Model):
+
+    def __init__(self, hidden_size, bidirectional=True, return_sequences=True):
+        super(QAModel_mean, self).__init__()
+        self.hidden_size = hidden_size
+        self.lstm_q = layer.CudnnRNN(hidden_size=hidden_size,
+                                     batch_first=True,
+                                     bidirectional=bidirectional,
+                                     return_sequences=return_sequences)
+        self.lstm_a = layer.CudnnRNN(hidden_size=hidden_size,
+                                     batch_first=True,
+                                     bidirectional=bidirectional,
+                                     return_sequences=return_sequences)
+
+    def forward(self, q, a_batch):
+        q = self.lstm_q(q)  # bs, seq, Hidden*2
+        a_batch = self.lstm_a(a_batch)  # 2bs, seq, Hidden*2
+
+        # bs, hid*2
+        q = autograd.reduce_mean(q, [1], keepdims=0)
+        # (2bs, hid*2)
+        a_batch = autograd.reduce_mean(a_batch, [1], keepdims=0)
+
+        # 2*(bs, hid*2)
+        a_pos, a_neg = autograd.split(a_batch, 0, [q.shape[0], q.shape[0]])
+
+        sim_pos = autograd.cossim(q, a_pos)
+        sim_neg = autograd.cossim(q, a_neg)
+        return sim_pos, sim_neg
+
+
+class QAModel_maxpooling(model.Model):
+
+    def __init__(self,
+                 hidden_size,
+                 q_seq,
+                 a_seq,
+                 num_layers=1,
+                 bidirectional=True,
+                 return_sequences=True):
+        super(QAModel_maxpooling, self).__init__()
+        self.hidden_size = hidden_size
+        self.lstm_q = layer.CudnnRNN(hidden_size=hidden_size,
+                                     bidirectional=bidirectional,
+                                     return_sequences=return_sequences)
+        self.lstm_a = layer.CudnnRNN(hidden_size=hidden_size,
+                                     bidirectional=bidirectional,
+                                     return_sequences=return_sequences)
+        self.q_pool = layer.MaxPool2d((q_seq, 1))
+        self.a_pool = layer.MaxPool2d((a_seq, 1))
+
+    def forward(self, q, a_batch):
+        # bs, seq, Hidden*2
+        q = self.lstm_q(q)
+        # bs, 1, seq, hid*2
+        q = autograd.reshape(q, (q.shape[0], 1, q.shape[1], q.shape[2]))
+        # bs, 1, 1, hid*2
+        q = self.q_pool(q)
+        # bs, hid*2
+        q = autograd.reshape(q, (q.shape[0], q.shape[3]))
+
+        # 2bs, seq, Hidden*2
+        a_batch = self.lstm_a(a_batch)
+        # 2bs, 1, seq, hid*2
+        a_batch = autograd.reshape(
+            a_batch, (a_batch.shape[0], 1, a_batch.shape[1], a_batch.shape[2]))
+        # 2bs, 1, 1, hid*2
+        a_batch = self.a_pool(a_batch)
+        # 2bs, hid*2
+        a_batch = autograd.reshape(a_batch,
+                                   (a_batch.shape[0], a_batch.shape[3]))
+
+        # 2*(bs, hid*2)
+        a_pos, a_neg = autograd.split(a_batch, 0, [q.shape[0], q.shape[0]])
+
+        sim_pos = autograd.cossim(q, a_pos)
+        sim_neg = autograd.cossim(q, a_neg)
+        return sim_pos, sim_neg
\ No newline at end of file
diff --git a/examples/qabot/qabot_train.py b/examples/qabot/qabot_train.py
new file mode 100644
index 0000000..71e0c5c
--- /dev/null
+++ b/examples/qabot/qabot_train.py
@@ -0,0 +1,159 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+import numpy as np
+import time
+import random
+from tqdm import tqdm
+import argparse
+
+from singa import autograd, tensor, device, opt
+from qabot_data import limit_encode_train, limit_encode_eval, prepare_data
+from qabot_model import QAModel_maxpooling
+
+
+def do_train(m, tq, ta, train, meta_data, args):
+    '''
+    batch size needs to be large enough to see all negative ans
+    '''
+    m.train()
+    for epoch in range(args.epochs):
+        total_loss = 0
+        start = time.time()
+
+        q, ans_p, ans_n = limit_encode_train(train, meta_data['label2answer'],
+                                             meta_data['idx2word'],
+                                             args.q_seq_limit,
+                                             args.ans_seq_limit,
+                                             meta_data['idx2vec'])
+        bs = args.bs
+
+        for i in tqdm(range(len(q) // bs)):
+            tq.copy_from_numpy(q[i * bs:(i + 1) * bs])
+            a_batch = np.concatenate(
+                [ans_p[i * bs:(i + 1) * bs], ans_n[i * bs:(i + 1) * bs]])
+            ta.copy_from_numpy(a_batch)
+
+            p_sim, n_sim = m.forward(tq, ta)
+            l = autograd.ranking_loss(p_sim, n_sim)
+            m.optimizer(l)
+
+            total_loss += tensor.to_numpy(l)
+        print(
+            "epoch %d, time used %d sec, loss: " % (epoch, time.time() - start),
+            total_loss * bs / len(q))
+
+
+def do_eval(m, tq, ta, test, meta_data, args):
+    q, candis, ans_count = limit_encode_eval(test, meta_data['label2answer'],
+                                             meta_data['idx2word'],
+                                             args.q_seq_limit,
+                                             args.ans_seq_limit,
+                                             meta_data['idx2vec'],
+                                             args.number_of_candidates)
+    m.eval()
+    candi_pool_size = candis.shape[1]
+    correct = 0
+    start = time.time()
+    for i in tqdm(range(len(q))):
+        # batch size bs must satisfy: bs == number of repeated q == number of answers//2
+        # i.e. 1 question is repeated n times, where n == number of answers//2
+        _q = np.repeat([q[i]], candi_pool_size // 2, axis=0)
+        tq.copy_from_numpy(_q)
+        ta.copy_from_numpy(candis[i])
+
+        (first_half_score, second_half_score) = m.forward(tq, ta)
+
+        first_half_score = tensor.to_numpy(first_half_score)
+        second_half_score = tensor.to_numpy(second_half_score)
+        scores = np.concatenate((first_half_score, second_half_score))
+        pred_max_idx = np.argmax(scores)
+
+        if pred_max_idx < ans_count[i]:
+            correct += 1
+
+    print("eval top %s " % (candi_pool_size), " accuracy", correct / len(q),
+          " time used %d sec" % (time.time() - start))
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('-m',
+                        '--max-epoch',
+                        default=30,
+                        type=int,
+                        help='maximum epochs',
+                        dest='epochs')
+    parser.add_argument('-b',
+                        '--batch-size',
+                        default=50,
+                        type=int,
+                        help='batch size',
+                        dest='bs')
+    parser.add_argument('-l',
+                        '--learning-rate',
+                        default=0.01,
+                        type=float,
+                        help='initial learning rate',
+                        dest='lr')
+    parser.add_argument('-i',
+                        '--device-id',
+                        default=0,
+                        type=int,
+                        help='which GPU to use',
+                        dest='device_id')
+
+    args = parser.parse_args()
+
+    args.hid_s = 64
+    args.q_seq_limit = 10
+    args.ans_seq_limit = 50
+    args.embed_size = 300
+    args.number_of_candidates = args.bs * 2
+    assert args.number_of_candidates <= 100, "number_of_candidates should be <= 100"
+
+    dev = device.create_cuda_gpu_on(args.device_id)
+
+    # tensor container
+    tq = tensor.random((args.bs, args.q_seq_limit, args.embed_size), dev)
+    ta = tensor.random((args.bs * 2, args.ans_seq_limit, args.embed_size), dev)
+
+    # model
+    m = QAModel_maxpooling(args.hid_s,
+                           q_seq=args.q_seq_limit,
+                           a_seq=args.ans_seq_limit)
+    m.compile([tq, ta], is_train=True, use_graph=False, sequential=False)
+    m.optimizer = opt.SGD(args.lr, 0.9)
+
+    # get data
+    train_raw, test_raw, label2answer, idx2word, idx2vec = prepare_data()
+    meta_data = {
+        'label2answer': label2answer,
+        'idx2word': idx2word,
+        'idx2vec': idx2vec
+    }
+
+    print("training...")
+    do_train(m, tq, ta, train_raw, meta_data, args)
+
+    print("Eval with train data...")
+    do_eval(m, tq, ta, random.sample(train_raw, 2000), meta_data, args)
+
+    print("Eval with test data...")
+    do_eval(m, tq, ta, test_raw, meta_data, args)
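
For reference, the new example above can be launched roughly as follows (a sketch of usage: it assumes a CUDA GPU, network access for qabot_data.py to download the GoogleNews embeddings and the insuranceQA archive into /tmp, and gensim/tqdm installed):

    python examples/qabot/qabot_train.py --max-epoch 30 --batch-size 50 --device-id 0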
diff --git a/examples/qabot/train.py b/examples/qabot/train.py
deleted file mode 100644
index f9007a1..0000000
--- a/examples/qabot/train.py
+++ /dev/null
@@ -1,185 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-
-import sys
-build_path = r'build/python'
-sys.path.append(build_path)
-model_path = r'examples/qabot'
-sys.path.append(model_path)
-
-import time
-from singa import model
-from singa import tensor
-from singa import device
-from singa import opt
-
-from gensim.models.keyedvectors import KeyedVectors
-import numpy as np
-
-from data import parse_file, parse_test_file, load_vocabulary, generate_qa_triplets, words_text_to_fixed_seqlen_vec, triplet_text_to_vec, train_data_gen_fn
-from model import QAModel
-
-# params
-q_max_len = 15
-a_max_len = 150
-bs = 50 # as tq, ta use fix bs, bs should be factor of test size - 100
-embed_size = 300
-hidden_size = 100
-max_epoch = 2
-
-dev = device.create_cuda_gpu()
-# dev = device.create_cuda_gpu_on(7)
-
-# embeding
-embed_path = 'GoogleNews-vectors-negative300.bin'
-wv = KeyedVectors.load_word2vec_format(embed_path, binary=True)
-print("successfully loaded word2vec model")
-
-# vocab
-id_to_word, label_to_ans, label_to_ans_text = load_vocabulary(
-    './V2/vocabulary', './V2/InsuranceQA.label2answer.token.encoded')
-print("loaded vocab")
-
-train_data = parse_file(
-    './V2/InsuranceQA.question.anslabel.token.100.pool.solr.train.encoded',
-    id_to_word, label_to_ans_text)
-test_data = parse_test_file(
-    './V2/InsuranceQA.question.anslabel.token.100.pool.solr.test.encoded',
-    id_to_word, label_to_ans_text)
-# train_data = train_data[:100]
-# test_data = test_data[:100]
-print("loaded train data")
-
-
-def load_model(max_bs, hidden_size):
-    m = QAModel(hidden_size)
-    m.optimizer = opt.SGD()
-    tq = tensor.Tensor((max_bs, q_max_len, embed_size), dev, tensor.float32)
-    ta = tensor.Tensor((max_bs * 2, a_max_len, embed_size), dev, tensor.float32)
-    tq.set_value(0.0)
-    ta.set_value(0.0)
-    m.compile([tq, ta], is_train=True, use_graph=False, sequential=False)
-    # m.compile([tq, ta], is_train=True, use_graph=True, sequential=True)
-    # m.compile([tq, ta], is_train=True, use_graph=True, sequential=False)
-    return m
-
-
-def training_top1_hits(m, wv, q_max_len, a_max_len, train_data):
-    m.eval()
-    hits = 0
-    train_eval_data = [
-        train_eval_format(r, wv, q_max_len, a_max_len) for r in train_data
-    ]
-    trials = len(train_eval_data)
-    for q, a in train_eval_data:
-        tq = tensor.from_numpy(q.astype(np.float32))
-        ta = tensor.from_numpy(a.astype(np.float32))
-        tq.to_device(dev)
-        ta.to_device(dev)
-        out = m.forward(tq, ta)
-        sim_first_half = tensor.to_numpy(out[0])
-        sim_second_half = tensor.to_numpy(out[1])
-        sim = np.concatenate([sim_first_half, sim_second_half]).flatten()
-        if np.argmax(sim) == 0:
-            hits += 1
-    # print("training top1 hits rate: ", hits/trials)
-    return hits / trials
-
-
-def training(m, all_train_data, max_epoch, eval_split_ratio=0.8):
-    split_num = int(eval_split_ratio * len(all_train_data))
-    train_data = all_train_data[:split_num]
-    eval_data = all_train_data[split_num:]
-
-    train_triplets = generate_qa_triplets(train_data) # triplet = <q, a+, a->
-    train_triplet_vecs = [
-        triplet_text_to_vec(t, wv, q_max_len, a_max_len) for t in train_triplets
-    ] # triplet vecs = <q_vec, a+_vec, a-_vec>
-    train_data_gen = train_data_gen_fn(train_triplet_vecs, bs)
-    m.train()
-
-    tq = tensor.Tensor((bs, q_max_len, embed_size), dev, tensor.float32)
-    ta = tensor.Tensor((bs * 2, a_max_len, embed_size), dev, tensor.float32)
-    for epoch in range(max_epoch):
-        start = time.time()
-        for q, a in train_data_gen: 
-            #     print(tq.shape) # (bs,seq,embed)
-            #     print(ta.shape) # (bs*2, seq, embed)
-            tq.copy_from_numpy(q)
-            ta.copy_from_numpy(a)
-            score, l = m(tq, ta)
-        top1hits = training_top1_hits(m, wv, q_max_len, a_max_len, train_data)
-        print(
-            "epoch %d, time used %d sec, top1 hits: %f, loss: " %
-            (epoch, time.time() - start, top1hits), l)
-
-
-def train_eval_format(row, wv, q_max_len, a_max_len):
-    q, apos, anegs = row
-    all_a = [apos] + anegs
-    a_vecs = [words_text_to_fixed_seqlen_vec(wv, a, a_max_len) for a in all_a]
-    if len(a_vecs) % 2 == 1:
-        a_vecs.pop(-1)
-    assert len(a_vecs) % 2 == 0
-    q_repeat = int(len(a_vecs) / 2)
-    q_vecs = [words_text_to_fixed_seqlen_vec(wv, q, q_max_len)] * q_repeat
-    return np.array(q_vecs), np.array(a_vecs)
-
-
-def test_format(r, wv, q_max_len, a_max_len):
-    q_text, labels, candis = r
-    candis_vecs = [
-        words_text_to_fixed_seqlen_vec(wv, label_to_ans_text[a_label],
-                                       a_max_len) for a_label in candis
-    ]
-    if len(candis_vecs) % 2 == 1:
-        candis_vecs.pop(-1)
-    assert len(candis_vecs) % 2 == 0
-    q_repeat = int(len(candis_vecs) / 2)
-    q_vecs = [words_text_to_fixed_seqlen_vec(wv, q_text, q_max_len)] * q_repeat
-    labels_idx = [candis.index(l) for l in labels if l in candis]
-    return np.array(q_vecs), np.array(candis_vecs), labels, labels_idx
-
-def testing(m, test_data):
-    test_tuple_vecs = [
-        test_format(r, wv, q_max_len, a_max_len) for r in test_data
-    ]
-    m.eval()
-    hits = 0
-    trials = len(test_tuple_vecs)
-
-    tq = tensor.Tensor((bs, q_max_len, embed_size), dev, tensor.float32)
-    ta = tensor.Tensor((bs * 2, a_max_len, embed_size), dev, tensor.float32)
-    for q, a, labels, labels_idx in test_tuple_vecs:
-        # print(q.shape) # (50, seq, embed)
-        # print(a.shape) # (100, seq, embed)
-        tq.copy_from_numpy(q)
-        ta.copy_from_numpy(a)
-        out = m.forward(tq, ta)
-        sim_first_half = tensor.to_numpy(out[0])
-        sim_second_half = tensor.to_numpy(out[1])
-        sim = np.concatenate([sim_first_half, sim_second_half]).flatten()
-        if np.argmax(sim) in labels_idx:
-            hits += 1
-    print("training top1 hits rate: ", hits / trials)
-
-
-m = load_model(bs, hidden_size)
-training(m, train_data, max_epoch)
-testing(m, test_data)
diff --git a/python/singa/autograd.py b/python/singa/autograd.py
index f29ae62..92039f2 100644
--- a/python/singa/autograd.py
+++ b/python/singa/autograd.py
@@ -1204,42 +1204,44 @@ def cross_entropy(y, t):
     return CrossEntropy()(y, t)[0]
 
 
-class QALSTMLoss(Operator):
+class RankingLoss(Operator):
 
     def __init__(self, M=0.2):
-        super(QALSTMLoss, self).__init__()
+        super().__init__()
+        # margin
         self.M = M
 
     def forward(self, pos, neg):
-        # L = max{0, M - cosine(q, a+) + cosine(q, a-)}
+        # L = max{0, M - fn(pos) + fn(neg)}
         zero = singa.Tensor(list(pos.shape()), pos.device())
         zero.SetFloatValue(0.0)
         val = singa.AddFloat(singa.__sub__(neg, pos), self.M)
         gt_zero = singa.__gt__(val, zero)
-        self.inputs = (gt_zero,)  # (BS,)
+        if training:
+            self.inputs = (gt_zero,)  # (BS,)
         all_loss = singa.__mul__(gt_zero, val)
         loss = singa.SumAll(all_loss)
         loss /= (pos.shape()[0])
-        # assert loss.shape(0) == 1
         return loss
 
     def backward(self, dy=1.0):
+        assert training, "enable training mode to do backward"
         # dpos = -1 if M-pos+neg > 0 else 0
         # dneg =  1 if M-pos+neg > 0 else 0
         gt_zero = self.inputs[0]
         dpos_factor = singa.Tensor(list(gt_zero.shape()), gt_zero.device())
-        dpos_factor.SetFloatValue(-1.0)
+        dpos_factor.SetFloatValue(-1.0 / gt_zero.Size())
         dneg_factor = singa.Tensor(list(gt_zero.shape()), gt_zero.device())
-        dneg_factor.SetFloatValue(1.0)
+        dneg_factor.SetFloatValue(1.0 / gt_zero.Size())
         dpos = singa.__mul__(gt_zero, dpos_factor)
         dneg = singa.__mul__(gt_zero, dneg_factor)
         return dpos, dneg
 
 
-def qa_lstm_loss(pos, neg, M=0.2):
+def ranking_loss(pos, neg, M=0.2):
     assert pos.shape == neg.shape, "input and target shape different: %s, %s" % (
         pos.shape, neg.shape)
-    return QALSTMLoss(M)(pos, neg)[0]
+    return RankingLoss(M)(pos, neg)[0]
 
 
 class SoftMaxCrossEntropy(Operator):
@@ -1279,12 +1281,15 @@ class MeanSquareError(Operator):
         self.err = singa.__sub__(x, t)
         sqr = singa.Square(self.err)
         loss = singa.SumAll(sqr)
-        loss /= (x.shape()[0] * 2)
+        self.n = 1
+        for s in x.shape():
+            self.n *= s
+        loss /= self.n
         return loss
 
     def backward(self, dy=1.0):
         dx = self.err
-        dx *= float(1 / self.err.shape()[0])
+        dx *= float(2 / self.n)
         dx *= dy
         return dx, None
 
@@ -1292,10 +1297,6 @@ class MeanSquareError(Operator):
 def mse_loss(x, t):
     assert x.shape == t.shape, "input and target shape different: %s, %s" % (
         x.shape, t.shape)
-    assert x.ndim() == 2, "2d input required, input shapes: %s, %s" % (x.shape,
-                                                                       t.shape)
-    assert t.ndim() == 2, "2d input required, input shapes: %s, %s" % (x.shape,
-                                                                       t.shape)
     return MeanSquareError()(x, t)[0]
 
 
@@ -3930,6 +3931,7 @@ class ReduceMean(Operator):
             _x = tensor.reshape(_x, x_shape)
         self.cache = (x_shape, x)
         scale = np.prod(x_shape) / np.prod(x.shape())
+        self.scale = scale
         _x = singa.MultFloat(_x.data, scale)
         return _x
 
@@ -3946,6 +3948,7 @@ class ReduceMean(Operator):
         mask = singa.Tensor(list(x.shape()), x.device())
         mask.SetFloatValue(1.0)
         dy = singa.__mul__(mask, dy)
+        dy = singa.MultFloat(dy, self.scale)
         return dy
 
 
@@ -4585,17 +4588,17 @@ class _RNN(Operator):
     """ RNN operation with c++ backend
     """
 
-    def __init__(self,
-                 handle,
-                 return_sequences=False,
-                 batch_first=True,
-                 use_mask=False,
-                 seq_lengths=None):
+    def __init__(
+            self,
+            handle,
+            return_sequences=False,
+            #  batch_first=True,
+            use_mask=False,
+            seq_lengths=None):
         assert singa.USE_CUDA, "Not able to run without CUDA"
         super(_RNN, self).__init__()
         self.handle = handle
         self.return_sequences = return_sequences
-        self.batch_first = batch_first
         self.use_mask = use_mask
         if use_mask:
             assert type(seq_lengths) == Tensor, "wrong type for seq_lengths"
@@ -4631,10 +4634,11 @@ class _RNN(Operator):
                  cy) = singa.GpuRNNForwardInference(x, hx, cx, w, self.handle)
 
         if self.return_sequences:
-            # return full y {seq, bs, data_size}
+            # (seq, bs, data)
             return y
         else:
             # return last time step of y
+            # (seq, bs, data)[-1] -> (bs, data)
             last_y_shape = (y.shape()[1], y.shape()[2])
             last_y = singa.Tensor(list(last_y_shape), x.device())
 
@@ -4647,19 +4651,16 @@ class _RNN(Operator):
         assert training is True and hasattr(
             self, "inputs"), "Please set training as True before do BP. "
 
+        # (seq, bs, hid)
         dy = None
         if self.return_sequences:
-            # from: dy shape {bs, seq, ..}
-            # to:   dy shape {seq, bs, ..}
             assert grad.shape() == self.inputs['y'].shape(), (
                 "grad shape %s != y shape %s" %
                 (grad.shape(), self.inputs['y'].shape()))
             dy = grad
-            dy = singa.Transpose(dy, (1, 0, 2))
         else:
-            # from: grad shape (bs, directions*hidden)
-            # to:     dy shape (seq, bs, directions*hidden)
-            #                  empty space filled by zeros
+            # grad (bs, directions*hidden) -> dy (seq, bs, directions*hidden)
+            #   empty space filled by zeros
             assert grad.shape() == (self.inputs['y'].shape()[1],
                                     self.inputs['y'].shape()[2]), (
                                         "grad y shape %s != last y shape %s" %
@@ -4694,9 +4695,6 @@ class _RNN(Operator):
             dW = singa.GpuRNNBackwardW(self.inputs['x'], self.inputs['hx'],
                                        self.inputs['y'], self.handle)
 
-        # dx {seq, bs, ..} => {bat, seq, ..}
-        if self.batch_first:
-            dx = singa.Transpose(dx, list((1, 0, 2)))
 
         return dx, dhx, dcx, dW
 
@@ -4764,10 +4762,13 @@ class CosSim(Operator):
         ad = singa.Reshape(ad, list(ad.shape()) + [1])  # b * 1
         bd = singa.Reshape(bd, list(bd.shape()) + [1])  # b * 1
         ret = singa.Reshape(ret, list(ret.shape()) + [1])  # b * 1
+        dy = singa.Reshape(dy, list(dy.shape()) + [1])  # broadcast
         da = singa.__sub__(singa.__div__(b, ab),
-                           singa.__mul__(ret, singa.__div__(a, ad)))
+                           singa.__div__(singa.__mul__(ret, a), ad))
         db = singa.__sub__(singa.__div__(a, ab),
-                           singa.__mul__(ret, singa.__div__(b, bd)))
+                           singa.__div__(singa.__mul__(ret, b), bd))
+        da = singa.__mul__(dy, da)
+        db = singa.__mul__(dy, db)
         return da, db
 
 
@@ -4780,6 +4781,9 @@ def cossim(a, b):
     Returns:
         the output Tensor.
     """
+    assert a.shape == b.shape, "shape not match for cossim"
+    assert a.ndim() == 2, "shape should be in 2d for cossim"
+    assert b.ndim() == 2, "shape should be in 2d for cossim"
     return CosSim()(a, b)[0]
 
 
diff --git a/python/singa/layer.py b/python/singa/layer.py
index 23f23f9..3002333 100644
--- a/python/singa/layer.py
+++ b/python/singa/layer.py
@@ -1414,12 +1414,12 @@ class CudnnRNN(Layer):
                  activation="tanh",
                  num_layers=1,
                  bias=True,
-                 batch_first=False,
+                 batch_first=True,
                  dropout=0,
                  bidirectional=False,
                  rnn_mode="lstm",
                  use_mask=False,
-                 return_sequences=False):
+                 return_sequences=True):
         """
             Args:
                 hidden_size: hidden feature dim
@@ -1467,13 +1467,15 @@ class CudnnRNN(Layer):
                         requires_grad=True,
                         stores_grad=True,
                         device=x.device)
-        self.W.gaussian(0, 1)
+
+        k = 1 / self.hidden_size
+        self.W.uniform(-math.sqrt(k), math.sqrt(k))
 
     def forward(self, x, hx=None, cx=None, seq_lengths=None):
 
         self.device_check(x, self.W)
-        if self.batch_first:
-            x = x.transpose((1, 0, 2))
+        if self.batch_first:  # (bs,seq,data) -> (seq,bs,data)
+            x = autograd.transpose(x, (1, 0, 2))
 
         batch_size = x.shape[1]
         directions = 2 if self.bidirectional else 1
@@ -1496,17 +1498,16 @@ class CudnnRNN(Layer):
             assert type(seq_lengths) == Tensor, "wrong type for seq_lengths"
             y = autograd._RNN(self.handle,
                               return_sequences=self.return_sequences,
-                              batch_first=self.batch_first,
                               use_mask=self.use_mask,
                               seq_lengths=seq_lengths)(x, hx, cx, self.W)[0]
         else:
-            y = autograd._RNN(self.handle,
-                              return_sequences=self.return_sequences,
-                              batch_first=self.batch_first)(x, hx, cx,
-                                                            self.W)[0]
+            y = autograd._RNN(
+                self.handle,
+                return_sequences=self.return_sequences,
+            )(x, hx, cx, self.W)[0]
         if self.return_sequences and self.batch_first:
-            #   outputs has shape of {sequence length, batch size, hidden size}
-            y = y.transpose((1, 0, 2))  # to {bs, seq, hid}
+            # (seq, bs, hid) -> (bs, seq, hid)
+            y = autograd.transpose(y, (1, 0, 2))
         return y
 
     def get_params(self):
diff --git a/python/singa/tensor.py b/python/singa/tensor.py
index b49771a..0057804 100755
--- a/python/singa/tensor.py
+++ b/python/singa/tensor.py
@@ -844,7 +844,7 @@ def copy_data_to_from(dst, src, size, dst_offset=0, src_offset=0):
     singa.CopyDataToFrom(dst.data, src.data, size, dst_offset, src_offset)
 
 
-def from_numpy(np_array):
+def from_numpy(np_array, dev=None):
     '''Create a Tensor instance with the shape, dtype and values from the numpy
     array.
 
@@ -870,6 +870,8 @@ def from_numpy(np_array):
         dtype = core_pb2.kInt
     ret = Tensor(np_array.shape, dtype=dtype)
     ret.copy_from_numpy(np_array)
+    if dev:
+        ret.to_device(dev)
     return ret
 
 
@@ -1774,3 +1776,30 @@ def concatenate(tensors, axis):
     for t in tensors:
         ctensors.append(t.data)
     return _call_singa_func(singa.ConcatOn, ctensors, axis)
+
+
+def random(shape, device=get_default_device()):
+    ''' return a random tensor with given shape
+
+    Args:
+        shape: shape of generated tensor
+        device: device of generated tensor, default is cpu
+
+    Returns:
+        new tensor generated
+    '''
+    ret = Tensor(shape, device=device)
+    ret.uniform(0, 1)
+    return ret
+
+
+def zeros(shape, device=get_default_device()):
+    ret = Tensor(shape, device=device)
+    ret.set_value(0.0)
+    return ret
+
+
+def ones(shape, device=get_default_device()):
+    ret = Tensor(shape, device=device)
+    ret.set_value(1.0)
+    return ret
diff --git a/src/api/model_operation.i b/src/api/model_operation.i
index d7fd0dc..b0d95a0 100755
--- a/src/api/model_operation.i
+++ b/src/api/model_operation.i
@@ -218,6 +218,9 @@ std::vector<Tensor> GpuRNNForwardInference(const Tensor &x, const Tensor &hx, co
 std::vector<Tensor> GpuRNNBackwardx(const Tensor &y, const Tensor &dy, const Tensor &dhy, const Tensor &dcy, const Tensor &W, const Tensor &hx, const Tensor &cx, CudnnRNNHandle &h);
 Tensor GpuRNNBackwardW(const Tensor &x, const Tensor &hx, const Tensor &y, CudnnRNNHandle &h);
 
+void GpuRNNSetParam(int linLayerID, int pseudoLayer, Tensor &weights, Tensor &paramValues, bool is_bias, CudnnRNNHandle &h);
+Tensor GpuRNNGetParamCopy(int linLayerID, int pseudoLayer, Tensor &weights, bool is_bias, CudnnRNNHandle &h);
+
 std::vector<Tensor> GpuRNNForwardTrainingEx(const Tensor &x, const Tensor &hx, const Tensor &cx, const Tensor &W, const Tensor &seq_lengths, CudnnRNNHandle &h);
 std::vector<Tensor> GpuRNNForwardInferenceEx(const Tensor &x, const Tensor &hx, const Tensor &cx, const Tensor &W, const Tensor &seq_lengths, CudnnRNNHandle &h);
 std::vector<Tensor> GpuRNNBackwardxEx(const Tensor &y, const Tensor &dy, const Tensor &dhy, const Tensor &dcy, const Tensor &W, const Tensor &hx, const Tensor &cx, const Tensor &seq_lengths, CudnnRNNHandle &h);
diff --git a/src/model/operation/rnn.cc b/src/model/operation/rnn.cc
index 5448eba..f774bea 100644
--- a/src/model/operation/rnn.cc
+++ b/src/model/operation/rnn.cc
@@ -20,6 +20,8 @@
  ************************************************************/
 
 #include "rnn.h"
+
+#include <map>
 namespace singa {
 #ifdef USE_CUDNN
 CudnnRNNHandle::CudnnRNNHandle(const Tensor &x, const int hidden_size,
@@ -32,6 +34,7 @@ CudnnRNNHandle::CudnnRNNHandle(const Tensor &x, const int hidden_size,
       hidden_size(hidden_size),
       mode(mode),
       num_layers(num_layers) {
+  // an option to disable the cudnn rnn bias is not available in cudnn v7.4.5 (not found in cudnn.h)
   CHECK_EQ(bias, 1) << "Current implementation always include bias";
   CHECK(bidirectional == 0 || bidirectional == 1)
       << "bidirectional should be 0 or 1 not " << bidirectional;
@@ -57,7 +60,7 @@ CudnnRNNHandle::CudnnRNNHandle(const Tensor &x, const int hidden_size,
   init_rnn_desc();
   init_parameters_desc(xDesc);
   init_workspace(xDesc);
-
+  init_param_mapping(xDesc);
   delete[] xDesc;
 }
 
@@ -126,7 +129,6 @@ void CudnnRNNHandle::init_dropout_desc() {
                                         states, stateSize, seed));
 }
 
-
 void init_yDesc(cudnnTensorDescriptor_t *yDesc, CudnnRNNHandle &h) {
   int dimA[] = {h.batch_size,
                 h.bidirectional ? h.hidden_size * 2 : h.hidden_size, 1};
@@ -394,6 +396,7 @@ Tensor GpuRNNBackwardW(const Tensor &x, const Tensor &hx, const Tensor &y,
   // x shape {seq, bs}
   // y shape {seq, bs}
   dW.SetValue(0.0f);
+  h.workspace.SetValue(0.0f);
   dW.device()->Exec(
       [dW, x, hx, y, &h](Context *ctx) {
         cudnnTensorDescriptor_t *xDesc =
@@ -427,6 +430,90 @@ Tensor GpuRNNBackwardW(const Tensor &x, const Tensor &hx, const Tensor &y,
   return dW;
 }
 
+void CudnnRNNHandle::init_param_mapping(cudnnTensorDescriptor_t *xDesc) {
+  int linLayerIDRange = 2;
+  if (mode == 0 || mode == 1) {
+    // vanilla relu/tanh
+    linLayerIDRange = 2;
+  } else if (mode == 2) {
+    // lstm
+    linLayerIDRange = 8;
+  } else if (mode == 3) {
+    // gru
+    linLayerIDRange = 6;
+  }
+  int pseudoLayerRange = (bidirectional ? 2 : 1) * num_layers;
+
+  // dummy weights for getting the offset
+  Tensor weights(
+      Shape{
+          weights_size,
+      },
+      dev);
+  weights.SetValue(0.0f);
+  const void *W_ptr = weights.block()->data();
+
+  void *param_ptr = nullptr;
+  int dims[] = {1, 1, 1};
+  cudnnDataType_t data_type;
+  cudnnTensorFormat_t tensor_format;
+  int n_dims;
+  cudnnFilterDescriptor_t paramDesc;
+  CUDNN_CHECK(cudnnCreateFilterDescriptor(&paramDesc));
+
+  vector<bool> paramTypes{false, true};
+  for (int linLayerID = 0; linLayerID < linLayerIDRange; linLayerID++) {
+    for (int pseudoLayer = 0; pseudoLayer < pseudoLayerRange; pseudoLayer++) {
+      for (const bool &is_bias : paramTypes) {
+        // get param ptr
+        if (is_bias) {
+          CUDNN_CHECK(cudnnGetRNNLinLayerBiasParams(
+              ctx->cudnn_handle, rnnDesc, pseudoLayer, xDesc[0], wDesc, W_ptr,
+              linLayerID, paramDesc, &param_ptr));
+        } else {
+          CUDNN_CHECK(cudnnGetRNNLinLayerMatrixParams(
+              ctx->cudnn_handle, rnnDesc, pseudoLayer, xDesc[0], wDesc, W_ptr,
+              linLayerID, paramDesc, &param_ptr));
+        }
+
+        // get param dims
+        CUDNN_CHECK(cudnnGetFilterNdDescriptor(paramDesc, 3, &data_type,
+                                               &tensor_format, &n_dims, dims));
+
+        // offset of this param within the packed weights, in elements (float count)
+        size_t offset = (float *)param_ptr - (float *)W_ptr;
+
+        // save in map
+        weights_mapping[std::make_tuple(linLayerID, pseudoLayer, is_bias)] =
+            std::make_tuple(offset, dims[0] * dims[1] * dims[2]);
+      }
+    }
+  }
+}
+
+void GpuRNNSetParam(int linLayerID, int pseudoLayer, Tensor &weights,
+                    Tensor &paramValues, bool is_bias, CudnnRNNHandle &h) {
+  size_t offset, size;
+  std::tie(offset, size) =
+      h.weights_mapping[std::make_tuple(linLayerID, pseudoLayer, is_bias)];
+  CHECK_EQ(size, paramValues.size()) << "unexpected param size";
+  CopyDataToFrom(&weights, paramValues, size, offset, 0);
+}
+
+Tensor GpuRNNGetParamCopy(int linLayerID, int pseudoLayer, Tensor &weights,
+                          bool is_bias, CudnnRNNHandle &h) {
+  size_t offset, size;
+  std::tie(offset, size) =
+      h.weights_mapping[std::make_tuple(linLayerID, pseudoLayer, is_bias)];
+  Tensor paramCopy(
+      Shape{
+          size,
+      },
+      weights.device());
+  CopyDataToFrom(&paramCopy, weights, size, 0, offset);
+  return paramCopy;
+}
+
 /*
 vector<Tensor> GpuRNNForwardTrainingEx();
 vector<Tensor> GpuRNNForwardInferenceEx();
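
init_param_mapping above enumerates every (linLayerID, pseudoLayer, is_bias) slot and records its offset and size inside the packed weight blob. The number of slots follows directly from the loop bounds; a small illustrative Python sketch of that count (not part of the commit, mode names stand in for the integer mode codes):

    # 2 linear-layer ids for relu/tanh, 8 for lstm, 6 for gru;
    # each id exists once as a matrix and once as a bias per pseudo layer
    def num_param_slots(mode, num_layers, bidirectional):
        lin_ids = {"relu": 2, "tanh": 2, "lstm": 8, "gru": 6}[mode]
        pseudo_layers = num_layers * (2 if bidirectional else 1)
        return lin_ids * pseudo_layers * 2  # matrix + bias

    assert num_param_slots("lstm", 1, False) == 16
    assert num_param_slots("gru", 2, True) == 48
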
diff --git a/src/model/operation/rnn.h b/src/model/operation/rnn.h
index 59f096e..bbc9266 100644
--- a/src/model/operation/rnn.h
+++ b/src/model/operation/rnn.h
@@ -84,7 +84,12 @@ class CudnnRNNHandle {
   void init_rnn_desc();
   void init_parameters_desc(cudnnTensorDescriptor_t *xDesc);
   void init_workspace(cudnnTensorDescriptor_t *xDesc);
-  Tensor get_weight(size_t pseudo_layer, const Tensor &w, size_t lin_layer_id);
+  void init_param_mapping(cudnnTensorDescriptor_t *xDesc);
+
+  // linLayerID, pseudoLayer, is_bias => offset, size
+  // e.g. Wx of the 1st layer: key <0,0,false> => offset 0, size input_size*hidden_size
+  std::map<std::tuple<int, int, bool>, std::tuple<size_t, size_t>>
+      weights_mapping;
 };
 
 void init_xDesc(cudnnTensorDescriptor_t *xDesc, CudnnRNNHandle &h);
@@ -104,6 +109,11 @@ vector<Tensor> GpuRNNBackwardx(const Tensor &y, const Tensor &dy,
 Tensor GpuRNNBackwardW(const Tensor &x, const Tensor &hx, const Tensor &y,
                        CudnnRNNHandle &h);
 
+void GpuRNNSetParam(int linLayerID, int pseudoLayer, Tensor &weights,
+                    Tensor &paramValues, bool is_bias, CudnnRNNHandle &h);
+Tensor GpuRNNGetParamCopy(int linLayerID, int pseudoLayer, Tensor &weights,
+                          bool is_bias, CudnnRNNHandle &h);
+
 vector<Tensor> GpuRNNForwardTrainingEx(const Tensor &x, const Tensor &hx,
                                        const Tensor &cx, const Tensor &W,
                                        const Tensor &seq_lengths,
diff --git a/test/python/test_operation.py b/test/python/test_operation.py
index 6cdbace..9a4f2e5 100755
--- a/test/python/test_operation.py
+++ b/test/python/test_operation.py
@@ -455,6 +455,72 @@ class TestPythonOperation(unittest.TestCase):
 
             self.gradients_check(valinna_rnn_forward, param, auto_grad, dev=dev)
 
+    def _gradient_check_cudnn_rnn(self, mode="vanilla", dev=gpu_dev):
+        seq = 10
+        bs = 2
+        fea = 10
+        hid = 10
+        x = np.random.random((seq, bs, fea)).astype(np.float32)
+        tx = tensor.Tensor(device=dev, data=x)
+        y = np.random.random((seq, bs, hid)).astype(np.float32)
+        y = np.reshape(y, (-1, hid))
+        ty = tensor.Tensor(device=dev, data=y)
+        rnn = layer.CudnnRNN(hid, rnn_mode=mode, return_sequences=True)
+
+        def vanilla_rnn_forward():
+            out = rnn(tx)
+            out = autograd.reshape(out, (-1, hid))
+            loss = autograd.softmax_cross_entropy(out, ty)
+            return loss
+
+        loss = vanilla_rnn_forward()
+        auto_grads = autograd.gradients(loss)
+
+        params = rnn.get_params()
+        for key, param in params.items():
+            auto_grad = tensor.to_numpy(auto_grads[id(param)])
+            self.gradients_check(vanilla_rnn_forward, param, auto_grad, dev=dev)
+
+    @unittest.skipIf(not singa_wrap.USE_CUDA, 'CUDA is not enabled')
+    def test_gradient_check_cudnn_rnn_vanilla(self):
+        self._gradient_check_cudnn_rnn(mode="vanilla", dev=gpu_dev)
+
+    @unittest.skipIf(not singa_wrap.USE_CUDA, 'CUDA is not enabled')
+    def test_gradient_check_cudnn_rnn_lstm(self):
+        self._gradient_check_cudnn_rnn(mode="lstm", dev=gpu_dev)
+
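
The gradient checks above compare autograd's analytic gradients with numerical ones; such checks typically rely on the central-difference approximation, perturbing one parameter entry at a time:

    \[
    \frac{\partial L}{\partial \theta_i} \approx
    \frac{L(\theta_i + \epsilon) - L(\theta_i - \epsilon)}{2\epsilon}
    \]
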
+    # Cos Sim Gradient Check
+    def _gradient_check_cossim(self, dev=gpu_dev):
+        bs = 2
+        vec = 3
+        ta = tensor.random((bs, vec), dev)
+        tb = tensor.random((bs, vec), dev)
+        # treat ta, tb as params
+        ta.stores_grad = True
+        tb.stores_grad = True
+        ty = tensor.random((bs,), dev)
+
+        def _forward():
+            out = autograd.cossim(ta, tb)
+            loss = autograd.mse_loss(out, ty)
+            return loss
+
+        loss = _forward()
+        auto_grads = autograd.gradients(loss)
+
+        params = {id(ta): ta, id(tb): tb}
+
+        for key, param in params.items():
+            auto_grad = tensor.to_numpy(auto_grads[id(param)])
+            self.gradients_check(_forward, param, auto_grad, dev=dev)
+
+    @unittest.skipIf(not singa_wrap.USE_CUDA, 'CUDA is not enabled')
+    def test_gradient_check_cossim_gpu(self):
+        self._gradient_check_cossim(dev=gpu_dev)
+
+    def test_gradient_check_cossim_cpu(self):
+        self._gradient_check_cossim(dev=cpu_dev)
+
     def test_numerical_gradients_check_for_vallina_rnn_cpu(self):
         self._numerical_gradients_check_for_vallina_rnn_helper(cpu_dev)
 
@@ -3220,18 +3286,15 @@ class TestPythonOperation(unittest.TestCase):
     def test_upsample_gpu(self):
         self.upsample_helper(gpu_dev)
 
-    def test_invalid_inputs(self,dev=cpu_dev):
+    def test_invalid_inputs(self, dev=cpu_dev):
         _1d = tensor.Tensor((10,), dev)
-        _2d = tensor.Tensor((10,10), dev)
-        _3d = tensor.Tensor((10,10,10), dev)
+        _2d = tensor.Tensor((10, 10), dev)
+        _3d = tensor.Tensor((10, 10, 10), dev)
         self.assertRaises(AssertionError, autograd.softmax_cross_entropy, _2d,
                           _3d)
-        self.assertRaises(AssertionError, autograd.mse_loss, _2d,
-                          _3d)
-        self.assertRaises(AssertionError, autograd.add_bias, _2d,
-                          _1d, 3)
-        self.assertRaises(AssertionError, autograd.qa_lstm_loss, _2d,
-                          _1d)
+        self.assertRaises(AssertionError, autograd.mse_loss, _2d, _3d)
+        self.assertRaises(AssertionError, autograd.add_bias, _2d, _1d, 3)
+        self.assertRaises(AssertionError, autograd.ranking_loss, _2d, _1d)
 
     def where_helper(self, dev):
         X = np.array([[1, 2], [3, 4]], dtype=np.float32)
@@ -3308,6 +3371,76 @@ class TestPythonOperation(unittest.TestCase):
     def test_round_gpu(self):
         self.round_helper(gpu_dev)
 
+    def _cossim_value(self, dev=gpu_dev):
+        # numpy val
+        np.random.seed(0)
+        bs = 1000
+        vec_s = 1200
+        a = np.random.random((bs, vec_s)).astype(np.float32)
+        b = np.random.random((bs, vec_s)).astype(np.float32)
+        dy = np.random.random((bs,)).astype(np.float32)
+
+        # singa tensor
+        ta = tensor.from_numpy(a)
+        tb = tensor.from_numpy(b)
+        tdy = tensor.from_numpy(dy)
+        ta.to_device(dev)
+        tb.to_device(dev)
+        tdy.to_device(dev)
+
+        # singa forward and backward
+        ty = autograd.cossim(ta, tb)
+        tda, tdb = ty.creator.backward(tdy.data)
+
+        np_forward = list()
+        for i in range(len(a)):
+            a_norm = np.linalg.norm(a[i])
+            b_norm = np.linalg.norm(b[i])
+            ab_dot = np.dot(a[i], b[i])
+            out = ab_dot / (a_norm * b_norm)
+            np_forward.append(out)
+
+        np_backward_a = list()
+        np_backward_b = list()
+        for i in range(len(a)):
+            a_norm = np.linalg.norm(a[i])
+            b_norm = np.linalg.norm(b[i])
+            da = dy[i] * (b[i] / (a_norm * b_norm) - (np_forward[i] * a[i]) /
+                          (a_norm * a_norm))
+            db = dy[i] * (a[i] / (a_norm * b_norm) - (np_forward[i] * b[i]) /
+                          (b_norm * b_norm))
+            np_backward_a.append(da)
+            np_backward_b.append(db)
+
+        np.testing.assert_array_almost_equal(tensor.to_numpy(ty),
+                                             np.array(np_forward))
+        np.testing.assert_array_almost_equal(
+            tensor.to_numpy(tensor.from_raw_tensor(tda)), np_backward_a)
+        np.testing.assert_array_almost_equal(
+            tensor.to_numpy(tensor.from_raw_tensor(tdb)), np_backward_b)
+
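
For reference, the numpy loops in _cossim_value above implement the analytic cosine-similarity gradient; per batch row, with the incoming gradient dy applied by the chain rule:

    \[
    s = \frac{a \cdot b}{\lVert a \rVert \, \lVert b \rVert}, \qquad
    \frac{\partial s}{\partial a} = \frac{b}{\lVert a \rVert \, \lVert b \rVert} - \frac{s\,a}{\lVert a \rVert^{2}}, \qquad
    \frac{\partial s}{\partial b} = \frac{a}{\lVert a \rVert \, \lVert b \rVert} - \frac{s\,b}{\lVert b \rVert^{2}}
    \]
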
+    @unittest.skipIf(not singa_wrap.USE_CUDA, 'CUDA is not enabled')
+    def test_cossim_value_gpu(self):
+        self._cossim_value(gpu_dev)
+
+    def test_cossim_value_cpu(self):
+        self._cossim_value(cpu_dev)
+
+    def test_mse_loss_value(self, dev=cpu_dev):
+        y = np.random.random((1000, 1200)).astype(np.float32)
+        tar = np.random.random((1000, 1200)).astype(np.float32)
+        # get singa value
+        sy = tensor.from_numpy(y, dev)
+        starget = tensor.from_numpy(tar, dev)
+        sloss = autograd.mse_loss(sy, starget)
+        sgrad = sloss.creator.backward()[0]
+        # get np value result
+        np_loss = np.mean(np.square(tar - y))
+        np_grad = -2 * (tar - y) / np.prod(tar.shape)
+        # value check
+        np.testing.assert_array_almost_equal(
+            tensor.to_numpy(tensor.from_raw_tensor(sgrad)), np_grad)
+        np.testing.assert_array_almost_equal(tensor.to_numpy(sloss), np_loss)
+
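
The expected values in test_mse_loss_value follow the usual mean-squared-error definition over all N elements:

    \[
    L = \frac{1}{N} \sum_{i} (t_i - y_i)^{2}, \qquad
    \frac{\partial L}{\partial y_i} = -\frac{2\,(t_i - y_i)}{N}
    \]
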
 
 if __name__ == '__main__':
     unittest.main()