Posted to commits@singa.apache.org by ch...@apache.org on 2020/04/12 15:37:52 UTC

[singa] branch master updated: Update rnn example to use the Module API from v3.0

This is an automated email from the ASF dual-hosted git repository.

chrishkchris pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/singa.git


The following commit(s) were added to refs/heads/master by this push:
     new e2fedf9  Update rnn example to use the Module API from v3.0
     new 536f7e4  Merge pull request #676 from nudles/master
e2fedf9 is described below

commit e2fedf9f98a6755d60f60b67555f5e09947ae800
Author: wang wei <wa...@gmail.com>
AuthorDate: Sat Apr 11 22:50:50 2020 +0800

    Update rnn example to use the Module API from v3.0
    
    Note the placeholder tensors for the input data and labels
    should be reused after the first iteration.
    TODO test model.graph(True, False).
---
 RELEASE_NOTES            |   2 +-
 examples/rnn/README.md   |  12 +-
 examples/rnn/train.py    | 285 ++++++++++++++++++++++++-----------------------
 python/singa/autograd.py |  54 ++++-----
 tool/release/release.py  |   2 +-
 5 files changed, 176 insertions(+), 179 deletions(-)
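
The placeholder reuse mentioned in the commit message is implemented in
numpy2tensors() in train.py below: tensors are created with
tensor.from_numpy() only on the first iteration and refilled in place
afterwards, so the buffers recorded in the computational graph stay valid.
A minimal sketch of the pattern (dev and the iterable of numpy batches are
assumed to exist):

    from singa import tensor

    x = None
    for npx in batches:            # hypothetical iterable of numpy arrays
        if x is None:              # 1st iteration: create the placeholder
            x = tensor.from_numpy(npx)
            x.to_device(dev)
        else:                      # later iterations: reuse the same buffer
            x.copy_from_numpy(npx)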

diff --git a/RELEASE_NOTES b/RELEASE_NOTES
index 5fc4df4..8491a6b 100644
--- a/RELEASE_NOTES
+++ b/RELEASE_NOTES
@@ -61,7 +61,7 @@ This release includes following changes:
     After analyzing the dependency, the computational graph is created, which is further analyzed for
     speed and memory optimization. To enable this feature, use the [Module API](./python/singa/module.py).
 
-  * New website based on Docusaurus. The documentation files are moved to a separate repo [singa-doc]](https://github.com/apache/singa-doc).
+  * New website based on Docusaurus. The documentation files are moved to a separate repo [singa-doc](https://github.com/apache/singa-doc).
     The static website files are stored at [singa-site](https://github.com/apache/singa-site).
 
   * DNNL([Deep Neural Network Library](https://github.com/intel/mkl-dnn)), powered by Intel, 
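
The Module API referenced above is what the updated train.py (further below)
adopts; the minimal enabling calls, taken from that file, are:

    model = CharRNN(vocab_size)    # a module.Module subclass
    model.on_device(dev)           # place parameters on the device
    model.graph(True, True)        # run through the computational graph

The semantics of the two boolean flags are not spelled out in this commit;
the TODO in the commit message suggests the second one changes the execution
order, so treat that reading as an assumption.
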
diff --git a/examples/rnn/README.md b/examples/rnn/README.md
index 6a3a9bd..7c1c697 100644
--- a/examples/rnn/README.md
+++ b/examples/rnn/README.md
@@ -24,14 +24,10 @@ application (or model) using SINGA's RNN layers.
 We will use the [char-rnn](https://github.com/karpathy/char-rnn) model as an
 example, which trains over sentences or
 source code, with each character as an input unit. Particularly, we will train
-a RNN using GRU over Linux kernel source code. After training, we expect to
-generate meaningful code from the model.
-
+an RNN over Linux kernel source code.
 
 ## Instructions
 
-* Compile and install SINGA. Currently the RNN implementation depends on Cudnn with version >= 5.05.
-
 * Prepare the dataset. Download the [kernel source code](http://cs.stanford.edu/people/karpathy/char-rnn/).
 Other plain text files can also be used.
 
@@ -42,9 +38,3 @@ Other plain text files can also be used.
   Some hyper-parameters could be set through command line,
 
         python train.py -h
-
-* Sample characters from the model by providing the number of characters to sample and the seed string.
-
-        python sample.py 'model.bin' 100 --seed '#include <std'
-
-  Please replace 'model.bin' with the path to one of the checkpoint paths.
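
With the dataset downloaded, a typical run using the defaults from the
argument parser in train.py would be (the input file name here is only an
example):

    python train.py linux_kernel.txt -b 32 -l 64 -d 128 -m 50
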
diff --git a/examples/rnn/train.py b/examples/rnn/train.py
index c2440d7..30ce680 100644
--- a/examples/rnn/train.py
+++ b/examples/rnn/train.py
@@ -20,26 +20,55 @@ The train file could be any text file,
 e.g., http://cs.stanford.edu/people/karpathy/char-rnn/
 '''
 
-
 from __future__ import division
 from __future__ import print_function
-from builtins import zip
 from builtins import range
-from builtins import object
-import pickle as pickle
 import numpy as np
+import sys
 import argparse
+from tqdm import tqdm
 
-from singa import layer
-from singa import loss
 from singa import device
 from singa import tensor
-from singa import optimizer
-from singa import initializer
-from singa import utils
+from singa import autograd
+from singa import module
+from singa import opt
+
+
+class CharRNN(module.Module):
+
+    def __init__(self, vocab_size, hidden_size=32):
+        super(CharRNN, self).__init__()
+        self.rnn = autograd.LSTM(vocab_size, hidden_size)
+        self.dense = autograd.Linear(hidden_size, vocab_size)
+        self.optimizer = opt.SGD(0.01)
+        self.hidden_size = hidden_size
+        self.vocab_size = vocab_size
+        self.hx = tensor.Tensor((1, self.hidden_size))
+        self.cx = tensor.Tensor((1, self.hidden_size))
+
+    def reset_states(self, dev):
+        self.hx.to_device(dev)
+        self.cx.to_device(dev)
+        self.hx.set_value(0.0)
+        self.cx.set_value(0.0)
+
+    def forward(self, inputs):
+        x, self.hx, self.cx = self.rnn(inputs, (self.hx, self.cx))
+        x = autograd.cat(x)
+        x = autograd.reshape(x, (-1, self.hidden_size))
+        return self.dense(x)
+
+    def loss(self, out, ty):
+        ty = autograd.reshape(ty, (-1, 1))
+        return autograd.softmax_cross_entropy(out, ty)
+
+    def optim(self, loss):
+        self.optimizer.backward_and_update(loss)
 
 
 class Data(object):
+
     def __init__(self, fpath, batch_size=32, seq_length=100, train_ratio=0.8):
         '''Data object for loading a plain text file.
 
@@ -48,7 +77,8 @@ class Data(object):
             train_ratio, split the text file into train and test sets, where
                 train_ratio of the characters are in the train set.
         '''
-        self.raw_data = open(fpath, 'r',encoding='iso-8859-1').read()  # read text file
+        self.raw_data = open(fpath, 'r',
+                             encoding='iso-8859-1').read()  # read text file
         chars = list(set(self.raw_data))
         self.vocab_size = len(chars)
         self.char_to_idx = {ch: i for i, ch in enumerate(chars)}
@@ -56,12 +86,12 @@ class Data(object):
         data = [self.char_to_idx[c] for c in self.raw_data]
         # seq_length + 1 for the data + label
         nsamples = len(data) // (1 + seq_length)
-        data = data[0:nsamples * (1 + seq_length)]
+        data = data[0:nsamples * (1 + seq_length)]  # drop trailing partial sequence
         data = np.asarray(data, dtype=np.int32)
         data = np.reshape(data, (-1, seq_length + 1))
         # shuffle all sequences
         np.random.shuffle(data)
-        self.train_dat = data[0:int(data.shape[0]*train_ratio)]
+        self.train_dat = data[0:int(data.shape[0] * train_ratio)]
         self.num_train_batch = self.train_dat.shape[0] // batch_size
         self.val_dat = data[self.train_dat.shape[0]:]
         self.num_test_batch = self.val_dat.shape[0] // batch_size
@@ -69,23 +99,35 @@ class Data(object):
         print('val dat', self.val_dat.shape)
 
 
-def numpy2tensors(npx, npy, dev):
+def numpy2tensors(npx, npy, dev, inputs=None, labels=None):
     '''batch, seq, dim -- > seq, batch, dim'''
+    tmpy = np.swapaxes(npy, 0, 1).reshape((-1, 1))
+    if labels is not None:
+        labels.copy_from_numpy(tmpy)
+    else:
+        labels = tensor.from_numpy(tmpy)
+    labels.to_device(dev)
     tmpx = np.swapaxes(npx, 0, 1)
-    tmpy = np.swapaxes(npy, 0, 1)
-    inputs = []
-    labels = []
+    inputs_ = []
     for t in range(tmpx.shape[0]):
-        x = tensor.from_numpy(tmpx[t])
-        y = tensor.from_numpy(tmpy[t])
-        x.to_device(dev)
-        y.to_device(dev)
-        inputs.append(x)
-        labels.append(y)
+        if inputs is not None:
+            inputs[t].copy_from_numpy(tmpx[t])
+        else:
+            x = tensor.from_numpy(tmpx[t])
+            x.to_device(dev)
+            inputs_.append(x)
+    if inputs is None:
+        inputs = inputs_
     return inputs, labels
 
 
-def convert(batch, batch_size, seq_length, vocab_size, dev):
+def convert(batch,
+            batch_size,
+            seq_length,
+            vocab_size,
+            dev,
+            inputs=None,
+            labels=None):
     '''convert a batch of data into a sequence of input tensors'''
     y = batch[:, 1:]
     x1 = batch[:, :seq_length]
@@ -94,127 +136,90 @@ def convert(batch, batch_size, seq_length, vocab_size, dev):
         for t in range(seq_length):
             c = x1[b, t]
             x[b, t, c] = 1
-    return numpy2tensors(x, y, dev)
-
-
-def get_lr(epoch):
-    return 0.001 / float(1 << (epoch // 50))
-
-
-def train(data, max_epoch, hidden_size=100, seq_length=100, batch_size=16,
-          num_stacks=1, dropout=0.5, model_path='model'):
+    return numpy2tensors(x, y, dev, inputs, labels)
+
+
+def sample(model, data, dev, nsamples=100, use_max=False):
+    while True:
+        cmd = input('Do you want to sample text from the model? [y/n] ')
+        if cmd == 'n':
+            return
+        else:
+            seed = input('Please input some seeding text, e.g., #include <c: ')
+            inputs = []
+            for c in seed:
+                x = np.zeros((1, data.vocab_size), dtype=np.float32)
+                x[0, data.char_to_idx[c]] = 1
+                tx = tensor.from_numpy(x)
+                tx.to_device(dev)
+                inputs.append(tx)
+            model.reset_states(dev)
+            outputs = model(inputs)
+            y = tensor.softmax(outputs[-1])
+            sys.stdout.write(seed)
+            for i in range(nsamples):
+                prob = tensor.to_numpy(y)[0]
+                if use_max:
+                    cur = np.argmax(prob)
+                else:
+                    cur = np.random.choice(data.vocab_size, 1, p=prob)[0]
+                sys.stdout.write(data.idx_to_char[cur])
+                x = np.zeros((1, data.vocab_size), dtype=np.float32)
+                x[0, cur] = 1
+                tx = tensor.from_numpy(x)
+                tx.to_device(dev)
+                outputs = model([tx])
+                y = tensor.softmax(outputs[-1])
+
+
+def evaluate(model, data, batch_size, seq_length, dev):
+    model.eval()
+    val_loss = 0.0
+    for b in range(data.num_test_batch):
+        batch = data.val_dat[b * batch_size:(b + 1) * batch_size]
+        inputs, labels = convert(batch, batch_size, seq_length, data.vocab_size,
+                                 dev)
+        model.reset_states(dev)
+        y = model(inputs)
+        loss = model.loss(y, labels)
+        val_loss += tensor.to_numpy(loss)[0]
+    print('            validation loss is %f' %
+          (val_loss / data.num_test_batch / seq_length))
+
+
+def train(data,
+          max_epoch,
+          hidden_size=100,
+          seq_length=100,
+          batch_size=16,
+          model_path='model'):
     # SGD with L2 gradient normalization
-    opt = optimizer.RMSProp(constraint=optimizer.L2Constraint(5))
     cuda = device.create_cuda_gpu()
-    rnn = layer.LSTM(
-        name='lstm',
-        hidden_size=hidden_size,
-        num_stacks=num_stacks,
-        dropout=dropout,
-        input_sample_shape=(
-            data.vocab_size,
-        ))
-    rnn.to_device(cuda)
-    print('created rnn')
-    rnn_w = rnn.param_values()[0]
-    rnn_w.uniform(-0.08, 0.08)  # init all rnn parameters
-    print('rnn weight l1 = %f' % (rnn_w.l1()))
-    dense = layer.Dense(
-        'dense',
-        data.vocab_size,
-        input_sample_shape=(
-            hidden_size,
-        ))
-    dense.to_device(cuda)
-    dense_w = dense.param_values()[0]
-    dense_b = dense.param_values()[1]
-    print('dense w ', dense_w.shape)
-    print('dense b ', dense_b.shape)
-    initializer.uniform(dense_w, dense_w.shape[0], 0)
-    print('dense weight l1 = %f' % (dense_w.l1()))
-    dense_b.set_value(0)
-    print('dense b l1 = %f' % (dense_b.l1()))
-
-    g_dense_w = tensor.Tensor(dense_w.shape, cuda)
-    g_dense_b = tensor.Tensor(dense_b.shape, cuda)
-
-    lossfun = loss.SoftmaxCrossEntropy()
+    model = CharRNN(data.vocab_size, hidden_size)
+    model.on_device(cuda)
+    model.graph(True, True)
+
+    inputs, labels = None, None
+
     for epoch in range(max_epoch):
+        model.train()
         train_loss = 0
-        for b in range(data.num_train_batch):
-            batch = data.train_dat[b * batch_size: (b + 1) * batch_size]
+        for b in tqdm(range(data.num_train_batch)):
+            batch = data.train_dat[b * batch_size:(b + 1) * batch_size]
             inputs, labels = convert(batch, batch_size, seq_length,
-                                     data.vocab_size, cuda)
-            inputs.append(tensor.Tensor())
-            inputs.append(tensor.Tensor())
-
-            outputs = rnn.forward(True, inputs)[0:-2]
-            grads = []
-            batch_loss = 0
-            g_dense_w.set_value(0.0)
-            g_dense_b.set_value(0.0)
-            for output, label in zip(outputs, labels):
-                act = dense.forward(True, output)
-                lvalue = lossfun.forward(True, act, label)
-                batch_loss += lvalue.l1()
-                grad = lossfun.backward()
-                grad /= batch_size
-                grad, gwb = dense.backward(True, grad)
-                grads.append(grad)
-                g_dense_w += gwb[0]
-                g_dense_b += gwb[1]
-                # print output.l1(), act.l1()
-            utils.update_progress(
-                b * 1.0 / data.num_train_batch, 'training loss = %f' %
-                (batch_loss / seq_length))
-            train_loss += batch_loss
-
-            grads.append(tensor.Tensor())
-            grads.append(tensor.Tensor())
-            g_rnn_w = rnn.backward(True, grads)[1][0]
-            dense_w, dense_b = dense.param_values()
-            opt.apply_with_lr(epoch, get_lr(epoch), g_rnn_w, rnn_w, 'rnnw')
-            opt.apply_with_lr(
-                epoch, get_lr(epoch),
-                g_dense_w, dense_w, 'dense_w')
-            opt.apply_with_lr(
-                epoch, get_lr(epoch),
-                g_dense_b, dense_b, 'dense_b')
+                                     data.vocab_size, cuda, inputs, labels)
+            model.reset_states(cuda)
+            y = model(inputs)
+            loss = model.loss(y, labels)
+            model.optim(loss)
+            train_loss += tensor.to_numpy(loss)[0]
+
         print('\nEpoch %d, train loss is %f' %
               (epoch, train_loss / data.num_train_batch / seq_length))
 
-        eval_loss = 0
-        for b in range(data.num_test_batch):
-            batch = data.val_dat[b * batch_size: (b + 1) * batch_size]
-            inputs, labels = convert(batch, batch_size, seq_length,
-                                     data.vocab_size, cuda)
-            inputs.append(tensor.Tensor())
-            inputs.append(tensor.Tensor())
-            outputs = rnn.forward(False, inputs)[0:-2]
-            for output, label in zip(outputs, labels):
-                output = dense.forward(True, output)
-                eval_loss += lossfun.forward(True, output, label).l1()
-        print('Epoch %d, evaluation loss is %f' %
-              (epoch, eval_loss / data.num_test_batch / seq_length))
-
-        if (epoch + 1) % 30 == 0:
-            # checkpoint the file model
-            with open('%s_%d.bin' % (model_path, epoch), 'wb') as fd:
-                print('saving model to %s' % model_path)
-                d = {}
-                for name, w in zip(
-                        ['rnn_w', 'dense_w', 'dense_b'],
-                        [rnn_w, dense_w, dense_b]):
-                    w.to_host()
-                    d[name] = tensor.to_numpy(w)
-                    w.to_device(cuda)
-                d['idx_to_char'] = data.idx_to_char
-                d['char_to_idx'] = data.char_to_idx
-                d['hidden_size'] = hidden_size
-                d['num_stacks'] = num_stacks
-                d['dropout'] = dropout
-
-                pickle.dump(d, fd)
+        # evaluate(model, data, batch_size, seq_length, cuda)
+        # sample(model, data, cuda)
+
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(
@@ -224,9 +229,11 @@ if __name__ == '__main__':
     parser.add_argument('-b', type=int, default=32, help='batch_size')
     parser.add_argument('-l', type=int, default=64, help='sequence length')
     parser.add_argument('-d', type=int, default=128, help='hidden size')
-    parser.add_argument('-s', type=int, default=2, help='num of stacks')
     parser.add_argument('-m', type=int, default=50, help='max num of epoch')
     args = parser.parse_args()
     data = Data(args.data, batch_size=args.b, seq_length=args.l)
-    train(data, args.m,  hidden_size=args.d, num_stacks=args.s,
-          seq_length=args.l, batch_size=args.b)
+    train(data,
+          args.m,
+          hidden_size=args.d,
+          seq_length=args.l,
+          batch_size=args.b)
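
Condensing the updated train.py above, one epoch of training with the
Module API reduces to the following pattern (a sketch of the code in this
commit, using the parser defaults for the hyper-parameters):

    batch_size, seq_length, hidden_size = 32, 64, 128
    model = CharRNN(data.vocab_size, hidden_size)
    model.on_device(cuda)
    model.graph(True, True)        # buffer operations into a graph

    inputs, labels = None, None    # placeholders, created on the 1st batch
    model.train()
    for b in range(data.num_train_batch):
        batch = data.train_dat[b * batch_size:(b + 1) * batch_size]
        inputs, labels = convert(batch, batch_size, seq_length,
                                 data.vocab_size, cuda, inputs, labels)
        model.reset_states(cuda)   # zero hx/cx before each sequence batch
        y = model(inputs)          # forward; graph is reused after 1st iter
        loss = model.loss(y, labels)
        model.optim(loss)          # backward pass + SGD update
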
diff --git a/python/singa/autograd.py b/python/singa/autograd.py
index d26e794..e0813da 100644
--- a/python/singa/autograd.py
+++ b/python/singa/autograd.py
@@ -1806,13 +1806,13 @@ class SeparableConv2d(Layer):
     """
 
     def __init__(
-        self,
-        in_channels,
-        out_channels,
-        kernel_size,
-        stride=1,
-        padding=0,
-        bias=False,
+            self,
+            in_channels,
+            out_channels,
+            kernel_size,
+            stride=1,
+            padding=0,
+            bias=False,
     ):
         """
         Args:
@@ -3132,15 +3132,15 @@ class RNN(RNN_Base):
     """
 
     def __init__(
-        self,
-        input_size,
-        hidden_size,
-        num_layers=1,
-        nonlinearity="tanh",
-        bias=True,
-        batch_first=False,
-        dropout=0,
-        bidirectional=False,
+            self,
+            input_size,
+            hidden_size,
+            num_layers=1,
+            nonlinearity="tanh",
+            bias=True,
+            batch_first=False,
+            dropout=0,
+            bidirectional=False,
     ):
         """
         Args:
@@ -3212,15 +3212,15 @@ class LSTM(RNN_Base):
     """
 
     def __init__(
-        self,
-        input_size,
-        hidden_size,
-        nonlinearity="tanh",
-        num_layers=1,
-        bias=True,
-        batch_first=False,
-        dropout=0,
-        bidirectional=False,
+            self,
+            input_size,
+            hidden_size,
+            nonlinearity="tanh",
+            num_layers=1,
+            bias=True,
+            batch_first=False,
+            dropout=0,
+            bidirectional=False,
     ):
         """
         Args:
@@ -3244,14 +3244,14 @@ class LSTM(RNN_Base):
         self.Wx = []
         for i in range(4):
             w = Tensor(shape=Wx_shape, requires_grad=True, stores_grad=True)
-            w.gaussian(0.0, 1.0)
+            w.gaussian(0.0, 0.01)
             self.Wx.append(w)
 
         Wh_shape = (hidden_size, hidden_size)
         self.Wh = []
         for i in range(4):
             w = Tensor(shape=Wh_shape, requires_grad=True, stores_grad=True)
-            w.gaussian(0.0, 1.0)
+            w.gaussian(0.0, 0.01)
             self.Wh.append(w)
 
         Bx_shape = (hidden_size,)
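
The initialization change above (standard deviation 1.0 -> 0.01) keeps the
initial pre-activations of the LSTM small, so the tanh and sigmoid gates
start away from their saturated regions. A quick standalone numpy check
(illustrative only, not SINGA code):

    import numpy as np

    x = np.random.randn(100)                    # unit-scale input
    for std in (1.0, 0.01):
        w = np.random.normal(0, std, (100, 100))
        pre = x @ w                             # pre-activation, std ~ 10*std
        frac = np.mean(np.abs(np.tanh(pre)) > 0.99)
        print('std=%.2f saturated=%.2f' % (std, frac))
    # std=1.00 saturates most tanh units; std=0.01 saturates none
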
diff --git a/tool/release/release.py b/tool/release/release.py
index 5de1655..28a53db 100755
--- a/tool/release/release.py
+++ b/tool/release/release.py
@@ -49,7 +49,7 @@ def main(args):
         default=False,
         dest='confirmed',
         action='store_true',
-        help="In interactive mode, for user to confirm. Could be used in sript")
+        help="In interactive mode, for user to confirm. Could be used in script")
     parser.add_argument('type',
                         choices=['major', 'minor', 'patch', 'rc', 'stable'],
                         help="Release types")