Posted to commits@singa.apache.org by wa...@apache.org on 2018/08/22 14:54:57 UTC

[1/5] incubator-singa git commit: SINGA-388 Develop some RNN layers by calling tiny operations like matmul, addbias

Repository: incubator-singa
Updated Branches:
  refs/heads/master 65756e6f6 -> 2224d5f9a


SINGA-388 Develop some RNN layers by calling tiny operations like matmul, addbias

- Add unit test cases for both the vanilla RNN and the LSTM.
  The unit test cases include a gradient shape check as well as a value check against
  numerically computed gradients.

- Add device_check() to Vanilla_RNN and LSTM. This function checks the devices of the inputs and
  parameters; if they are not all on the same device, it transfers them to a common device
  (see the sketch after this list).

- Fix some bugs in the test cases and source code.
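
Note: device_check() itself is not shown in the hunks below. The following is only a minimal
sketch of the behaviour described above, assuming the tensors expose .device, .device.id() and
.to_device() as used elsewhere in this patch; the method body here is illustrative, not the
committed implementation:

    def device_check(self, *inputs):
        # use the first input's device as the reference device
        ref_dev = inputs[0].device
        ref_id = ref_dev.id()
        for t in inputs[1:]:
            # transfer any tensor that lives on a different device
            if t.device.id() != ref_id:
                t.to_device(ref_dev)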


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/a44a01c0
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/a44a01c0
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/a44a01c0

Branch: refs/heads/master
Commit: a44a01c0f291cfca8a688570e3c752b1ef6ec829
Parents: 5dc17b9
Author: xuewanqi <xu...@outlook.com>
Authored: Tue Aug 14 13:37:07 2018 +0000
Committer: xuewanqi <xu...@outlook.com>
Committed: Thu Aug 16 11:40:25 2018 +0000

----------------------------------------------------------------------
 python/singa/autograd.py      |  86 ++++++++++++-------
 python/singa/net.py           |   0
 python/singa/tensor.py        |   6 +-
 test/python/test_layer.py     |   2 +-
 test/python/test_operation.py | 172 ++++++++++++++++++++++++++++++++++++-
 5 files changed, 227 insertions(+), 39 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/a44a01c0/python/singa/autograd.py
----------------------------------------------------------------------
diff --git a/python/singa/autograd.py b/python/singa/autograd.py
old mode 100644
new mode 100755
index b18e08e..7032135
--- a/python/singa/autograd.py
+++ b/python/singa/autograd.py
@@ -55,7 +55,8 @@ def infer_dependency(op):
             if src_op not in dependency_count:
                 # dependency[src_op] = [Counter() for _ in src_op.y_id2idx]
                 if isinstance(src_op, Dummy):
-                    # only when a Dummy operator needs store grads, its dependency needs to be counted.
+                    # only when a Dummy operator needs store grads, its
+                    # dependency needs to be counted.
                     if src_op.stores_grad:
                         dependency_count[src_op] = 0
                         queue.append(src_op)
@@ -107,9 +108,9 @@ def backward(y, dy=None):
     if y.stores_grad:
         #gradients[y] = dy
         if isinstance(dy, float):
-            g=np.array(dy)
+            g = np.array(dy)
         else:
-            g=dy
+            g = dy
         tg = Tensor(device=g.device(), data=g)
         yield (y, tg)
 
@@ -139,7 +140,7 @@ def backward(y, dy=None):
             if isinstance(src_op, Dummy):
                 if not src_op.stores_grad:
                     continue
-                    
+
             y_idx = src_op.y_id2idx[x_id]
             if src_op not in not_ready:
                 # src_op may have mulitple outputs
@@ -153,13 +154,15 @@ def backward(y, dy=None):
                     # add the gradient from another children operation that
                     # uses y_idx'th output of src_op as input arg
                     dxs[y_idx] += dx
-            
+
             dependency[src_op] -= 1
 
             if y_stores_grad:
                 if dependency[src_op] == 0:
                     # store the gradient for final return, e.g. if x is parameter
-                    # may cause a delay output, as only after src_op is ready then output, not the current outlet of src_op is ready then output.
+                    # may cause a delay output, as only after src_op is ready
+                    # then output, not the current outlet of src_op is ready
+                    # then output.
                     g = not_ready[src_op][y_idx]
                     tg = Tensor(device=g.device(), data=g)
                     yield (y, tg)
@@ -167,13 +170,13 @@ def backward(y, dy=None):
             if src_op.requires_grad is True:
                 if dependency[src_op] == 0:
                     if not isinstance(src_op, Dummy):
-                        #Dummy can be in not_ready list but cannot be in ready list.
+                        # Dummy can be in not_ready list but cannot be in ready
+                        # list.
                         ready.append((src_op, not_ready[src_op]))
                     del not_ready[src_op]
         del op  # delete the operation to free all tensors from this op
 
 
-
 class Operation(object):
     '''
     An operation includes the forward and backward function of
@@ -800,7 +803,7 @@ class BatchNorm2d(Layer):
         self.handle.device_id = x.device.id()
 
         y = batchnorm_2d(self.handle, x, self.scale, self.bias,
-                      self.running_mean, self.running_var)
+                         self.running_mean, self.running_var)
         return y
 
 
@@ -985,6 +988,7 @@ class Tanh(Operation):
 def tanh(x):
     return Tanh()(x)[0]
 
+
 class Sigmoid(Operation):
 
     def forward(self, x):
@@ -1021,31 +1025,28 @@ class ElemMatmul(Operation):
 def elemmatmul(x, y):
     return ElemMatmul()(x, y)[0]
 
+
 def add_all(*xs):
     assert len(xs) > 2
-    y=add(xs[0],xs[1])
+    y = add(xs[0], xs[1])
     for x in xs[2:]:
-        y=add(y, x)
+        y = add(y, x)
     return
 
 class RNN(Layer):
     def __init__(self):
         raise NotImplementedError
 
-    def __call__(self, h0, *xs):
-        batchsize=xs[0].shape[0]
-        out=[]
-        h = self.step_forward(xs[0], h0, self.Wx, self.Wh, self.b)
-        out.append(h)
-        for x in xs[1:]:
-            assert x.shape[0] == batchsize
-            h = self.step_forward(x, h, self.Wx, self.Wh, self.b)
-            out.append(h)
-        return out, h
+    def __call__(self):
+        raise NotImplementedError
+
+    def step_forward(self):
+        raise NotImplementedError
 
 class Vanilla_RNN(RNN):
+
     def __init__(self, input_size, hidden_size, num_layers=1, nonlinearity='tanh', bias=True, batch_first=False, dropout=0, bidirectional=False):
-        self.nonlinearity=nonlinearity
+        self.nonlinearity = nonlinearity
 
         Wx_shape = (input_size, hidden_size)
         self.Wx = Tensor(shape=Wx_shape, requires_grad=True, stores_grad=True)
@@ -1055,27 +1056,45 @@ class Vanilla_RNN(RNN):
         self.Wh = Tensor(shape=Wh_shape, requires_grad=True, stores_grad=True)
         self.Wh.gaussian(0.0, 1.0)
 
-        B_shape=(hidden_size,)
+        B_shape = (hidden_size,)
         self.b = Tensor(shape=B_shape, requires_grad=True, stores_grad=True)
         self.b.set_value(0.0)
 
+        #self.params= (self.Wx, self.Wh, self.b)
+
+    def __call__(self, h0, *xs):
+        inputs=xs+(h0,)
+        self.device_check(*inputs)
+        #self.device_check(inputs[0], *self.params)
+        self.device_check(inputs[0], self.Wx, self.Wh, self.b)
+        batchsize = xs[0].shape[0]
+        out = []
+        h = self.step_forward(xs[0], h0, self.Wx, self.Wh, self.b)
+        out.append(h)
+        for x in xs[1:]:
+            assert x.shape[0] == batchsize
+            h = self.step_forward(x, h, self.Wx, self.Wh, self.b)
+            out.append(h)
+        return out, h
+
     def step_forward(self, x, h, Wx, Wh, b):
-        y1=matmul(x, Wx)
-        y2=matmul(h, Wh)
-        y=add(y1,y2)
-        y=add_bias(y,b,axis=0)
+        y1 = matmul(x, Wx)
+        y2 = matmul(h, Wh)
+        y = add(y1, y2)
+        y = add_bias(y, b, axis=0)
         if self.nonlinearity == 'tanh':
-            y=tanh(y)
+            y = tanh(y)
         elif self.nonlinearity == 'relu':
-            y=relu(y)
+            y = relu(y)
         else:
             raise ValueError
         return y
 
+
 class LSTM(RNN):
 
     def __init__(self, input_size, hidden_size, nonlinearity='tanh', num_layers=1, bias=True, batch_first=False, dropout=0, bidirectional=False):
-        self.nonlinearity=nonlinearity
+        self.nonlinearity = nonlinearity
 
         Wx_shape = (input_size, hidden_size)
         self.Wx = []
@@ -1105,7 +1124,13 @@ class LSTM(RNN):
             b.set_value(0.0)
             self.Bh.append(b)
 
+        #self.params=self.Wx + self.Wh + self.Bx + self.Bh
+
     def __call__(self, h0, c0, *xs):
+        inputs=xs+(h0,c0)
+        self.device_check(*inputs)
+        #self.device_check(inputs[0], *self.params)
+        self.device_check(inputs[0], *(self.Wx + self.Wh + self.Bx + self.Bh))
         batchsize = xs[0].shape[0]
         out = []
         h, c = self.step_forward(
@@ -1154,4 +1179,3 @@ class LSTM(RNN):
         hout = tanh(cout)
         hout = elemmatmul(o, hout)
         return hout, cout
-

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/a44a01c0/python/singa/net.py
----------------------------------------------------------------------
diff --git a/python/singa/net.py b/python/singa/net.py
old mode 100644
new mode 100755

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/a44a01c0/python/singa/tensor.py
----------------------------------------------------------------------
diff --git a/python/singa/tensor.py b/python/singa/tensor.py
old mode 100644
new mode 100755
index 441431f..80c9a2e
--- a/python/singa/tensor.py
+++ b/python/singa/tensor.py
@@ -638,7 +638,7 @@ def reshape(tensor, shape):
     Returns:
         the new Tensor
     '''
-    return _call_singa_func(singa.Reshape, t.data, s)
+    return _call_singa_func(singa.Reshape, tensor.data, shape)
 
 
 def transpose(t, axes=None):
@@ -1333,8 +1333,8 @@ def tensordot(A, B, axes=2):
 
     A = transpose(A, newaxes_a)
     B = transpose(B, newaxes_b)
-    at = Reshape(A, newshape_a)
-    bt = Reshape(B, newshape_b)
+    at = reshape(A, newshape_a)
+    bt = reshape(B, newshape_b)
 
     res = mult(at, bt)
     if len(olda + oldb) == 0:

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/a44a01c0/test/python/test_layer.py
----------------------------------------------------------------------
diff --git a/test/python/test_layer.py b/test/python/test_layer.py
old mode 100644
new mode 100755
index 2c49961..4c859f4
--- a/test/python/test_layer.py
+++ b/test/python/test_layer.py
@@ -62,7 +62,7 @@ class TestPythonLayer(unittest.TestCase):
 
         raw_x = np.arange(9, dtype=np.float32) + 1
         x = tensor.from_numpy(raw_x)
-        x.reshape((1, 1, 3, 3))
+        x = x.reshape((1, 1, 3, 3))
         w = np.array([1, 1, 0, 0, 0, -1, 0, 1, 0], dtype=np.float32)
         params[0].copy_from_numpy(w)
         params[1].set_value(1.0)

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/a44a01c0/test/python/test_operation.py
----------------------------------------------------------------------
diff --git a/test/python/test_operation.py b/test/python/test_operation.py
index 67018c1..4975d99 100755
--- a/test/python/test_operation.py
+++ b/test/python/test_operation.py
@@ -6,6 +6,8 @@ from singa import singa_wrap as singa
 from singa import device
 from singa import autograd
 
+import numpy as np
+
 autograd.training = True
 
 CTensor = singa.Tensor
@@ -21,6 +23,31 @@ def _tuple_to_string(t):
     lt = [str(x) for x in t]
     return '(' + ', '.join(lt) + ')'
 
+def prepare_inputs_targets_for_rnn_test():
+        x_0 = np.random.random((2, 3)).astype(np.float32)
+        x_1 = np.random.random((2, 3)).astype(np.float32)
+        x_2 = np.random.random((2, 3)).astype(np.float32)
+
+        h_0 = np.random.random((2, 1)).astype(
+            np.float32)  # (2,1) rather than (2,)
+
+        t_0 = np.random.random((2, 2)).astype(np.float32)
+        t_1 = np.random.random((2, 2)).astype(np.float32)
+        t_2 = np.random.random((2, 2)).astype(np.float32)
+
+        x0 = tensor.Tensor(device=gpu_dev, data=x_0)
+        x1 = tensor.Tensor(device=gpu_dev, data=x_1)
+        x2 = tensor.Tensor(device=gpu_dev, data=x_2)
+
+        h0 = tensor.Tensor(device=gpu_dev, data=h_0)
+
+        t0 = tensor.Tensor(device=gpu_dev, data=t_0)
+        t1 = tensor.Tensor(device=gpu_dev, data=t_1)
+        t2 = tensor.Tensor(device=gpu_dev, data=t_2)
+
+        inputs = [x0, x1, x2]
+        targets = [t0, t1, t2]
+        return inputs, targets, h0
 
 class TestPythonOperation(unittest.TestCase):
 
@@ -32,8 +59,8 @@ class TestPythonOperation(unittest.TestCase):
 
     def test_conv2d_gpu(self):
         # (in_channels, out_channels, kernel_size)
-        conv_0 = autograd.Conv2D(3, 1, 2)
-        conv_without_bias_0 = autograd.Conv2D(3, 1, 2, bias=False)
+        conv_0 = autograd.Conv2d(3, 1, 2)
+        conv_without_bias_0 = autograd.Conv2d(3, 1, 2, bias=False)
 
         gpu_input_tensor = tensor.Tensor(shape=(2, 3, 3, 3), device=gpu_dev)
         gpu_input_tensor.gaussian(0.0, 1.0)
@@ -52,8 +79,8 @@ class TestPythonOperation(unittest.TestCase):
 
     def test_conv2d_cpu(self):
         # (in_channels, out_channels, kernel_size)
-        conv_1 = autograd.Conv2D(3, 1, 2)
-        conv_without_bias_1 = autograd.Conv2D(3, 1, 2, bias=False)
+        conv_1 = autograd.Conv2d(3, 1, 2)
+        conv_without_bias_1 = autograd.Conv2d(3, 1, 2, bias=False)
 
         cpu_input_tensor = tensor.Tensor(shape=(2, 3, 3, 3), device=cpu_dev)
         cpu_input_tensor.gaussian(0.0, 1.0)
@@ -87,5 +114,142 @@ class TestPythonOperation(unittest.TestCase):
         self.check_shape(ds.shape(), (3,))
         self.check_shape(db.shape(), (3,))
 
+    def test_vanillaRNN_gpu_tiny_ops(self):
+        # gradients shape check.
+        inputs, target, h0 = prepare_inputs_targets_for_rnn_test()
+        rnn = autograd.Vanilla_RNN(3, 2)
+
+        hs, _ = rnn(h0, *inputs)
+
+        loss = autograd.softmax_cross_entropy(hs[0], target[0])
+        for i in range(1, len(hs)):
+            l = autograd.softmax_cross_entropy(hs[i], target[i])
+            loss = autograd.add(loss, l)
+        # d=autograd.infer_dependency(loss.creator)
+        # print(d)
+        for t, dt in autograd.backward(loss):
+            self.check_shape(t.shape, dt.shape)
+
+    def test_LSTM_gpu_tiny_ops(self):
+        # gradients shape check.
+        inputs, target, h0 = prepare_inputs_targets_for_rnn_test()
+        c_0 = np.random.random((2, 1)).astype(np.float32)
+        c0 = tensor.Tensor(device=gpu_dev, data=c_0)
+
+        rnn = autograd.LSTM(3, 2)
+
+        hs, _, _ = rnn(h0, c0, *inputs)
+        loss = autograd.softmax_cross_entropy(hs[0], target[0])
+
+        for i in range(1, len(hs)):
+            l = autograd.softmax_cross_entropy(hs[i], target[i])
+            loss = autograd.add(loss, l)
+        # d=autograd.infer_dependency(loss.creator)
+        # print(d)
+        for t, dt in autograd.backward(loss):
+            self.check_shape(t.shape, dt.shape)
+
+    def test_numerical_gradients_check_for_vallina_rnn(self):
+        inputs, target, h0 = prepare_inputs_targets_for_rnn_test()
+
+        rnn = autograd.Vanilla_RNN(3, 2)
+
+        hs, _ = rnn(h0, *inputs)
+
+        loss1 = autograd.softmax_cross_entropy(hs[0], target[0])
+        for i in range(1, len(hs)):
+            l = autograd.softmax_cross_entropy(hs[i], target[i])
+            loss1 = autograd.add(loss1, l)
+        grads = autograd.gradients(loss1)
+
+        # autograd gradients for dL/dWx[0][0]
+        d1 = tensor.to_numpy(grads[rnn.Wx])[0][0]
+        #print('autograd result of dL/dWx[0][0] is ', d1)
+
+
+        length = 0.01
+        diff = np.array([1, 0, 0, 0, 0, 0]) * length
+        diff = np.reshape(diff, (3, 2))
+        diff = tensor.from_numpy(diff)
+        diff.to_device(gpu_dev)
+
+        rnn.Wx += diff
+        hs, _ = rnn(h0, *inputs)
+        #hs=rnn(h0, x0,x1)
+        loss2_p = autograd.softmax_cross_entropy(hs[0], target[0])
+        for i in range(1, len(hs)):
+            l = autograd.softmax_cross_entropy(hs[i], target[i])
+            loss2_p = autograd.add(loss2_p, l)
+
+        rnn.Wx -= diff
+        rnn.Wx -= diff
+        hs, _ = rnn(h0, *inputs)
+        #hs=rnn(h0, x0,x1)
+        loss2_n = autograd.softmax_cross_entropy(hs[0], target[0])
+        for i in range(1, len(hs)):
+            l = autograd.softmax_cross_entropy(hs[i], target[i])
+            loss2_n = autograd.add(loss2_n, l)
+
+        loss2_p_np = tensor.to_numpy(loss2_p)
+        loss2_n_np = tensor.to_numpy(loss2_n)
+        # Numerical gradients for dL/dWx[0][0]
+        d2 = (loss2_p_np - loss2_n_np) / 2 / length
+        #print('numerical calculation dL/dWx[0][0] is ', (loss2_p_np-loss2_n_np)/2/length)
+
+        self.assertAlmostEqual(np.sum(d1 - d2), 0., places=3)
+
+    def test_numerical_gradients_check_for_lstm(self):
+        inputs, target, h0 = prepare_inputs_targets_for_rnn_test()
+        c_0 = np.random.random((2, 1)).astype(np.float32)
+        c0 = tensor.Tensor(device=gpu_dev, data=c_0)
+
+        rnn = autograd.LSTM(3, 2)
+
+        hs, _, _ = rnn(h0, c0, *inputs)
+
+        loss1 = autograd.softmax_cross_entropy(hs[0], target[0])
+        for i in range(1, len(hs)):
+            l = autograd.softmax_cross_entropy(hs[i], target[i])
+            loss1 = autograd.add(loss1, l)
+        grads = autograd.gradients(loss1)
+
+        # autograd gradients for dL/dWx[0][0]
+        d1 = tensor.to_numpy(grads[rnn.Wx[0]])[0][0]
+        #print('autograd result of dL/dWx[0][0] is ', d1)
+
+
+        length = 0.01
+        diff = np.array([1, 0, 0, 0, 0, 0]) * length
+        diff = np.reshape(diff, (3, 2))
+        diff = tensor.from_numpy(diff)
+        diff.to_device(gpu_dev)
+
+        rnn.Wx[0] += diff
+        hs, _, _ = rnn(h0, c0, *inputs)
+        #hs=rnn(h0, x0,x1)
+        loss2_p = autograd.softmax_cross_entropy(hs[0], target[0])
+        for i in range(1, len(hs)):
+            l = autograd.softmax_cross_entropy(hs[i], target[i])
+            loss2_p = autograd.add(loss2_p, l)
+
+        rnn.Wx[0] -= diff
+        rnn.Wx[0] -= diff
+        hs, _, _ = rnn(h0, c0, *inputs)
+        #hs=rnn(h0, x0,x1)
+        loss2_n = autograd.softmax_cross_entropy(hs[0], target[0])
+        for i in range(1, len(hs)):
+            l = autograd.softmax_cross_entropy(hs[i], target[i])
+            loss2_n = autograd.add(loss2_n, l)
+
+        loss2_p_np = tensor.to_numpy(loss2_p)
+        loss2_n_np = tensor.to_numpy(loss2_n)
+        # Numerical gradients for dL/dWx[0][0]
+        d2 = (loss2_p_np - loss2_n_np) / 2 / length
+        #print('numerical calculation dL/dWx[0][0] is ', (loss2_p_np-loss2_n_np)/2/length)
+
+        self.assertAlmostEqual(np.sum(d1 - d2), 0., places=3)
+
+
+
 if __name__ == '__main__':
     unittest.main()


[3/5] incubator-singa git commit: SINGA-388 Develop some RNN layers by calling tiny operations like matmul, addbias

Posted by wa...@apache.org.
SINGA-388 Develop some RNN layers by calling tiny operations like matmul, addbias

- Develop the vanilla RNN by composing smaller operations.
- Add the operations needed by the vanilla RNN layer.
- The developed RNN layer has passed the test (it returns the correct number of gradients);
  the recurrence it implements is sketched below.
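
One step of the layer added below computes h_t = nonlinearity(x_t * Wx + h_(t-1) * Wh + b).
An equivalent plain NumPy sketch of that step (variable names here are illustrative only):

    import numpy as np

    def vanilla_rnn_step(x, h, Wx, Wh, b, nonlinearity=np.tanh):
        # x: (batch, input_size), h: (batch, hidden_size)
        # Wx: (input_size, hidden_size), Wh: (hidden_size, hidden_size), b: (hidden_size,)
        return nonlinearity(x @ Wx + h @ Wh + b)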


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/7df6a5db
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/7df6a5db
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/7df6a5db

Branch: refs/heads/master
Commit: 7df6a5db4b7ef30aca561889e46e2457b35b15c3
Parents: 770d6cd
Author: xuewanqi <xu...@outlook.com>
Authored: Mon Aug 13 13:44:29 2018 +0000
Committer: xuewanqi <xu...@outlook.com>
Committed: Thu Aug 16 11:40:25 2018 +0000

----------------------------------------------------------------------
 python/singa/autograd.py | 74 +++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 74 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/7df6a5db/python/singa/autograd.py
----------------------------------------------------------------------
diff --git a/python/singa/autograd.py b/python/singa/autograd.py
old mode 100755
new mode 100644
index 56b5498..4c7959c
--- a/python/singa/autograd.py
+++ b/python/singa/autograd.py
@@ -87,6 +87,7 @@ def backward(y, dy=None):
         a dictionary storing the gradient tensors of all tensors
         whose stores_grad is true (e.g. parameter tensors)
     '''
+    assert isinstance(y, Tensor), 'wrong input type.'
     dependency = infer_dependency(y.creator)
     assert y.size() == 1, 'y must be a Tensor with a single value;'\
         'size of y is % d' % y.size()
@@ -172,6 +173,7 @@ def backward(y, dy=None):
         del op  # delete the operation to free all tensors from this op
 
 
+
 class Operation(object):
     '''
     An operation includes the forward and backward function of
@@ -962,3 +964,75 @@ class AvgPool1d(Pooling2d):
             stride = kernel_size
         super(MaxPool2d, self).__init__(
             (1, kernel_size), (0, stride), (0, padding), False)
+
+
+class Tanh(Operation):
+
+    def forward(self, x):
+        out = singa.Tanh(x)
+        if training:
+            self.cache = (out,)
+        return out
+
+    def backward(self, dy):
+        dx = singa.__mul__(self.cache[0], self.cache[0])
+        dx = singa.MultFloat(dx, -1.0)
+        dx = singa.AddFloat(dx, 1.0)
+        dx = singa.__mul__(dy, dx)
+        return dx
+
+
+def tanh(x):
+    return Tanh()(x)[0]
+
+
+def add_all(*xs):
+    assert len(xs) > 2
+    y=add(xs[0],xs[1])
+    for x in xs[2:]:
+        y=add(y, x)
+    return
+
+
+class Vanilla_RNN(Layer):
+
+    def __init__(self, input_size, hidden_size, num_layers=1, nonlinearity='tanh', bias=True, batch_first=False, dropout=0, bidirectional=False):
+        self.nonlinearity=nonlinearity
+
+        Wx_shape = (input_size, hidden_size)
+        self.Wx = Tensor(shape=Wx_shape, requires_grad=True, stores_grad=True)
+        self.Wx.gaussian(0.0, 1.0)
+
+        Wh_shape = (hidden_size, hidden_size)
+        self.Wh = Tensor(shape=Wh_shape, requires_grad=True, stores_grad=True)
+        self.Wh.gaussian(0.0, 1.0)
+
+        B_shape=(hidden_size,)
+        self.b = Tensor(shape=B_shape, requires_grad=True, stores_grad=True)
+        self.b.set_value(0.0)
+
+    def __call__(self, h0, *xs):
+        batchsize=xs[0].shape[0]
+        self.out=[]
+        h = self.step_forward(xs[0], h0, self.Wx, self.Wh, self.b)
+        self.out.append(h)
+        for x in xs[1:]:
+            assert x.shape[0] == batchsize
+            h = self.step_forward(x, h, self.Wx, self.Wh, self.b)
+            self.out.append(h)
+        return self.out
+
+    def step_forward(self, x, h, Wx, Wh, b):
+        y1=matmul(x, Wx)
+        y2=matmul(h, Wh)
+        y=add(y1,y2)
+        y=add_bias(y,b,axis=0)
+        if self.nonlinearity == 'tanh':
+            y=tanh(y)
+        elif self.nonlinearity == 'relu':
+            y=relu(y)
+        else:
+            raise ValueError
+        return y
+
+


[2/5] incubator-singa git commit: SINGA-388 Develop some RNN layers by calling tiny operations like matmul, addbias

Posted by wa...@apache.org.
SINGA-388 Develop some RNN layers by calling tiny operations like matmul, addbias

- Develop LSTM by composing tiny operations.
- Add the operations needed by the LSTM layer (Sigmoid, ElemMatmul) to autograd.py.
- Redesign the structure of the RNN layer.
- The LSTM layer has passed the test (it returns the correct number of gradients);
  the gate equations it implements are sketched below.
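
The step_forward() added below implements the standard LSTM gate equations. An equivalent
plain NumPy sketch (names illustrative; gate order matches the committed code: input,
forget, output, cell candidate):

    import numpy as np

    def sigmoid(z):
        return 1.0 / (1.0 + np.exp(-z))

    def lstm_step(x, h, c, Wx, Wh, Bx, Bh):
        # Wx, Wh, Bx, Bh each hold four per-gate parameter arrays
        i = sigmoid(x @ Wx[0] + Bx[0] + h @ Wh[0] + Bh[0])  # input gate
        f = sigmoid(x @ Wx[1] + Bx[1] + h @ Wh[1] + Bh[1])  # forget gate
        o = sigmoid(x @ Wx[2] + Bx[2] + h @ Wh[2] + Bh[2])  # output gate
        g = np.tanh(x @ Wx[3] + Bx[3] + h @ Wh[3] + Bh[3])  # cell candidate
        c_out = f * c + i * g          # new cell state
        h_out = o * np.tanh(c_out)     # new hidden state
        return h_out, c_out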


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/5dc17b91
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/5dc17b91
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/5dc17b91

Branch: refs/heads/master
Commit: 5dc17b91e045e6318d224bebee8912da6c646596
Parents: 7df6a5d
Author: xuewanqi <xu...@outlook.com>
Authored: Mon Aug 13 14:11:37 2018 +0000
Committer: xuewanqi <xu...@outlook.com>
Committed: Thu Aug 16 11:40:25 2018 +0000

----------------------------------------------------------------------
 python/singa/autograd.py | 143 ++++++++++++++++++++++++++++++++++++++----
 1 file changed, 131 insertions(+), 12 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/5dc17b91/python/singa/autograd.py
----------------------------------------------------------------------
diff --git a/python/singa/autograd.py b/python/singa/autograd.py
index 4c7959c..b18e08e 100644
--- a/python/singa/autograd.py
+++ b/python/singa/autograd.py
@@ -985,6 +985,41 @@ class Tanh(Operation):
 def tanh(x):
     return Tanh()(x)[0]
 
+class Sigmoid(Operation):
+
+    def forward(self, x):
+        out = singa.Sigmoid(x)
+        if training:
+            self.cache = (out,)
+        return out
+
+    def backward(self, dy):
+        dx = singa.MultFloat(self.cache[0], -1.0)
+        dx = singa.AddFloat(dx, 1.0)
+        dx = singa.__mul__(self.cache[0], dx)
+        dx = singa.__mul__(dy, dx)
+        return dx
+
+
+def sigmoid(x):
+    return Sigmoid()(x)[0]
+
+
+class ElemMatmul(Operation):
+
+    def forward(self, x1, x2):
+        if training:
+            self.cache = (x1, x2)
+        return singa.__mul__(x1, x2)
+
+    def backward(self, dy):
+        dx1 = singa.__mul__(dy, self.cache[1])
+        dx2 = singa.__mul__(dy, self.cache[0])
+        return dx1, dx2
+
+
+def elemmatmul(x, y):
+    return ElemMatmul()(x, y)[0]
 
 def add_all(*xs):
     assert len(xs) > 2
@@ -993,9 +1028,22 @@ def add_all(*xs):
         y=add(y, x)
     return
 
+class RNN(Layer):
+    def __init__(self):
+        raise NotImplementedError
 
-class Vanilla_RNN(Layer):
+    def __call__(self, h0, *xs):
+        batchsize=xs[0].shape[0]
+        out=[]
+        h = self.step_forward(xs[0], h0, self.Wx, self.Wh, self.b)
+        out.append(h)
+        for x in xs[1:]:
+            assert x.shape[0] == batchsize
+            h = self.step_forward(x, h, self.Wx, self.Wh, self.b)
+            out.append(h)
+        return out, h
 
+class Vanilla_RNN(RNN):
     def __init__(self, input_size, hidden_size, num_layers=1, nonlinearity='tanh', bias=True, batch_first=False, dropout=0, bidirectional=False):
         self.nonlinearity=nonlinearity
 
@@ -1011,17 +1059,6 @@ class Vanilla_RNN(Layer):
         self.b = Tensor(shape=B_shape, requires_grad=True, stores_grad=True)
         self.b.set_value(0.0)
 
-    def __call__(self, h0, *xs):
-        batchsize=xs[0].shape[0]
-        self.out=[]
-        h = self.step_forward(xs[0], h0, self.Wx, self.Wh, self.b)
-        self.out.append(h)
-        for x in xs[1:]:
-            assert x.shape[0] == batchsize
-            h = self.step_forward(x, h, self.Wx, self.Wh, self.b)
-            self.out.append(h)
-        return self.out
-
     def step_forward(self, x, h, Wx, Wh, b):
         y1=matmul(x, Wx)
         y2=matmul(h, Wh)
@@ -1035,4 +1072,86 @@ class Vanilla_RNN(Layer):
             raise ValueError
         return y
 
+class LSTM(RNN):
+
+    def __init__(self, input_size, hidden_size, nonlinearity='tanh', num_layers=1, bias=True, batch_first=False, dropout=0, bidirectional=False):
+        self.nonlinearity=nonlinearity
+
+        Wx_shape = (input_size, hidden_size)
+        self.Wx = []
+        for i in range(4):
+            w = Tensor(shape=Wx_shape, requires_grad=True, stores_grad=True)
+            w.gaussian(0.0, 1.0)
+            self.Wx.append(w)
+
+        Wh_shape = (hidden_size, hidden_size)
+        self.Wh = []
+        for i in range(4):
+            w = Tensor(shape=Wh_shape, requires_grad=True, stores_grad=True)
+            w.gaussian(0.0, 1.0)
+            self.Wh.append(w)
+
+        Bx_shape = (hidden_size,)
+        self.Bx = []
+        for i in range(4):
+            b = Tensor(shape=Bx_shape, requires_grad=True, stores_grad=True)
+            b.set_value(0.0)
+            self.Bx.append(b)
+
+        Bh_shape = (hidden_size,)
+        self.Bh = []
+        for i in range(4):
+            b = Tensor(shape=Bx_shape, requires_grad=True, stores_grad=True)
+            b.set_value(0.0)
+            self.Bh.append(b)
+
+    def __call__(self, h0, c0, *xs):
+        batchsize = xs[0].shape[0]
+        out = []
+        h, c = self.step_forward(
+            xs[0], h0, c0, self.Wx, self.Wh, self.Bx, self.Bh)
+        out.append(h)
+        for x in xs[1:]:
+            assert x.shape[0] == batchsize
+            h, c = self.step_forward(
+                x, h, c, self.Wx, self.Wh, self.Bx, self.Bh)
+            out.append(h)
+        return out, h, c
+
+    def step_forward(self, x, h, c, Wx, Wh, Bx, Bh):
+        y1 = matmul(x, Wx[0])
+        y1 = add_bias(y1, Bx[0], axis=0)
+        y2 = matmul(h, Wh[0])
+        y2 = add_bias(y2, Bh[0], axis=0)
+        i = add(y1, y2)
+        i = sigmoid(i)
+
+        y1 = matmul(x, Wx[1])
+        y1 = add_bias(y1, Bx[1], axis=0)
+        y2 = matmul(h, Wh[1])
+        y2 = add_bias(y2, Bh[1], axis=0)
+        f = add(y1, y2)
+        f = sigmoid(f)
+
+        y1 = matmul(x, Wx[2])
+        y1 = add_bias(y1, Bx[2], axis=0)
+        y2 = matmul(h, Wh[2])
+        y2 = add_bias(y2, Bh[2], axis=0)
+        o = add(y1, y2)
+        o = sigmoid(o)
+
+        y1 = matmul(x, Wx[3])
+        y1 = add_bias(y1, Bx[3], axis=0)
+        y2 = matmul(h, Wh[3])
+        y2 = add_bias(y2, Bh[3], axis=0)
+        g = add(y1, y2)
+        g = tanh(g)
+
+        cout1 = elemmatmul(f, c)
+        cout2 = elemmatmul(i, g)
+        cout = add(cout1, cout2)
+
+        hout = tanh(cout)
+        hout = elemmatmul(o, hout)
+        return hout, cout
 


[4/5] incubator-singa git commit: SINGA-388 Develop some RNN layers by calling tiny operations like matmul, addbias

Posted by wa...@apache.org.
SINGA-388 Develop some RNN layers by calling tiny operations like matmul, addbias

- Improve the test cases for Vanilla_RNN and LSTM: numerically check every element of each
  parameter matrix (see the sketch below).
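
The gradients_check() helper added below sweeps every parameter element with a central
difference and compares it against the autograd gradient. A compact NumPy-only illustration
of the same idea (function and argument names are illustrative only):

    import numpy as np

    def numerical_grad_check(loss_fn, param, autograd_grad, h=0.0005, tol=1e-2):
        # loss_fn: callable returning the scalar loss for the current parameters
        # param: NumPy array read by loss_fn; perturbed in place element by element
        # autograd_grad: gradient of the loss w.r.t. param, as computed by autograd
        it = np.nditer(param, flags=['multi_index'])
        while not it.finished:
            idx = it.multi_index
            original = param[idx]
            param[idx] = original + h
            pos = loss_fn()
            param[idx] = original - h
            neg = loss_fn()
            param[idx] = original                 # restore the element
            numerical = (pos - neg) / (2 * h)     # central difference
            assert abs(autograd_grad[idx] - numerical) < tol
            it.iternext()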


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/0cd4e308
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/0cd4e308
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/0cd4e308

Branch: refs/heads/master
Commit: 0cd4e3084e84a2b1562877dd47ff5ed46fb6aadf
Parents: a44a01c
Author: xuewanqi <xu...@outlook.com>
Authored: Wed Aug 15 12:42:25 2018 +0000
Committer: xuewanqi <xu...@outlook.com>
Committed: Tue Aug 21 15:10:06 2018 +0000

----------------------------------------------------------------------
 python/singa/autograd.py      |   8 +-
 test/python/test_operation.py | 179 ++++++++++++++++---------------------
 2 files changed, 82 insertions(+), 105 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/0cd4e308/python/singa/autograd.py
----------------------------------------------------------------------
diff --git a/python/singa/autograd.py b/python/singa/autograd.py
index 7032135..c0f6a7a 100755
--- a/python/singa/autograd.py
+++ b/python/singa/autograd.py
@@ -1060,7 +1060,7 @@ class Vanilla_RNN(RNN):
         self.b = Tensor(shape=B_shape, requires_grad=True, stores_grad=True)
         self.b.set_value(0.0)
 
-        #self.params= (self.Wx, self.Wh, self.b)
+        self.params= (self.Wx, self.Wh, self.b)
 
     def __call__(self, h0, *xs):
         inputs=xs+(h0,)
@@ -1078,9 +1078,9 @@ class Vanilla_RNN(RNN):
         return out, h
 
     def step_forward(self, x, h, Wx, Wh, b):
-        y1 = matmul(x, Wx)
         y2 = matmul(h, Wh)
-        y = add(y1, y2)
+        y1 = matmul(x, Wx)
+        y = add(y2, y1)
         y = add_bias(y, b, axis=0)
         if self.nonlinearity == 'tanh':
             y = tanh(y)
@@ -1124,7 +1124,7 @@ class LSTM(RNN):
             b.set_value(0.0)
             self.Bh.append(b)
 
-        #self.params=self.Wx + self.Wh + self.Bx + self.Bh
+        self.params=self.Wx + self.Wh + self.Bx + self.Bh
 
     def __call__(self, h0, c0, *xs):
         inputs=xs+(h0,c0)

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/0cd4e308/test/python/test_operation.py
----------------------------------------------------------------------
diff --git a/test/python/test_operation.py b/test/python/test_operation.py
index 4975d99..64562a5 100755
--- a/test/python/test_operation.py
+++ b/test/python/test_operation.py
@@ -23,31 +23,33 @@ def _tuple_to_string(t):
     lt = [str(x) for x in t]
     return '(' + ', '.join(lt) + ')'
 
+
 def prepare_inputs_targets_for_rnn_test():
-        x_0 = np.random.random((2, 3)).astype(np.float32)
-        x_1 = np.random.random((2, 3)).astype(np.float32)
-        x_2 = np.random.random((2, 3)).astype(np.float32)
+    x_0 = np.random.random((2, 3)).astype(np.float32)
+    x_1 = np.random.random((2, 3)).astype(np.float32)
+    x_2 = np.random.random((2, 3)).astype(np.float32)
+
+    h_0 = np.zeros((2, 2)).astype(
+        np.float32)  
 
-        h_0 = np.random.random((2, 1)).astype(
-            np.float32)  # (2,1) rather than (2,)
+    t_0 = np.random.random((2, 2)).astype(np.float32)
+    t_1 = np.random.random((2, 2)).astype(np.float32)
+    t_2 = np.random.random((2, 2)).astype(np.float32)
 
-        t_0 = np.random.random((2, 2)).astype(np.float32)
-        t_1 = np.random.random((2, 2)).astype(np.float32)
-        t_2 = np.random.random((2, 2)).astype(np.float32)
+    x0 = tensor.Tensor(device=gpu_dev, data=x_0)
+    x1 = tensor.Tensor(device=gpu_dev, data=x_1)
+    x2 = tensor.Tensor(device=gpu_dev, data=x_2)
 
-        x0 = tensor.Tensor(device=gpu_dev, data=x_0)
-        x1 = tensor.Tensor(device=gpu_dev, data=x_1)
-        x2 = tensor.Tensor(device=gpu_dev, data=x_2)
+    h0 = tensor.Tensor(device=gpu_dev, data=h_0)
 
-        h0 = tensor.Tensor(device=gpu_dev, data=h_0)
+    t0 = tensor.Tensor(device=gpu_dev, data=t_0)
+    t1 = tensor.Tensor(device=gpu_dev, data=t_1)
+    t2 = tensor.Tensor(device=gpu_dev, data=t_2)
 
-        t0 = tensor.Tensor(device=gpu_dev, data=t_0)
-        t1 = tensor.Tensor(device=gpu_dev, data=t_1)
-        t2 = tensor.Tensor(device=gpu_dev, data=t_2)
+    inputs = [x0, x1, x2]
+    targets = [t0, t1, t2]
+    return inputs, targets, h0
 
-        inputs = [x0, x1, x2]
-        targets = [t0, t1, t2]
-        return inputs, targets, h0
 
 class TestPythonOperation(unittest.TestCase):
 
@@ -114,7 +116,7 @@ class TestPythonOperation(unittest.TestCase):
         self.check_shape(ds.shape(), (3,))
         self.check_shape(db.shape(), (3,))
 
-    def test_vanillaRNN_gpu_tiny_ops(self):
+    def test_vanillaRNN_gpu_tiny_ops_shape_check(self):
         # gradients shape check.
         inputs, target, h0 = prepare_inputs_targets_for_rnn_test()
         rnn = autograd.Vanilla_RNN(3, 2)
@@ -130,7 +132,7 @@ class TestPythonOperation(unittest.TestCase):
         for t, dt in autograd.backward(loss):
             self.check_shape(t.shape, dt.shape)
 
-    def test_LSTM_gpu_tiny_ops(self):
+    def test_LSTM_gpu_tiny_ops_shape_check(self):
         # gradients shape check.
         inputs, target, h0 = prepare_inputs_targets_for_rnn_test()
         c_0 = np.random.random((2, 1)).astype(np.float32)
@@ -149,107 +151,82 @@ class TestPythonOperation(unittest.TestCase):
         for t, dt in autograd.backward(loss):
             self.check_shape(t.shape, dt.shape)
 
+    def gradients_check(self, func, param, autograds, h=0.0005, df=1):
+        # param: PyTensor
+        # autograds: numpy_tensor
+        p = tensor.to_numpy(param)
+        it = np.nditer(p, flags=['multi_index'], op_flags=['readwrite'])
+        while not it.finished:
+            idx = it.multi_index
+            diff = np.zeros_like(p)
+            diff[idx] += h
+            diff = tensor.from_numpy(diff)
+            diff.to_device(gpu_dev)
+
+            param += diff
+            pos = func()
+            pos = tensor.to_numpy(pos)
+
+            param -= diff
+            param -= diff
+            neg = func()
+            neg = tensor.to_numpy(neg)
+
+            numerical_grad = np.sum((pos - neg) * df) / (2 * h)
+            #print((autograds[idx] - numerical_grad)/numerical_grad)
+            # threshold set as -5% to +5%
+            #self.assertAlmostEqual((autograds[idx] - numerical_grad)/(numerical_grad+0.0000001), 0., places=1)
+            self.assertAlmostEqual(
+                autograds[idx] - numerical_grad, 0., places=2)
+
+            it.iternext()
+
     def test_numerical_gradients_check_for_vallina_rnn(self):
         inputs, target, h0 = prepare_inputs_targets_for_rnn_test()
 
         rnn = autograd.Vanilla_RNN(3, 2)
 
-        hs, _ = rnn(h0, *inputs)
-
-        loss1 = autograd.softmax_cross_entropy(hs[0], target[0])
-        for i in range(1, len(hs)):
-            l = autograd.softmax_cross_entropy(hs[i], target[i])
-            loss1 = autograd.add(loss1, l)
-        grads = autograd.gradients(loss1)
-
-        # autograd gradients for dL/dWx[0][0]
-        d1 = tensor.to_numpy(grads[rnn.Wx])[0][0]
-        #print('autograd result of dL/dWx[0][0] is ', d1)
+        def valinna_rnn_forward():
+            hs, _ = rnn(h0, *inputs)
 
+            loss = autograd.softmax_cross_entropy(hs[0], target[0])
+            for i in range(1, len(hs)):
+                l = autograd.softmax_cross_entropy(hs[i], target[i])
+                loss = autograd.add(loss, l)
+            #grads = autograd.gradients(loss)
+            return loss
 
-        length = 0.01
-        diff = np.array([1, 0, 0, 0, 0, 0]) * length
-        diff = np.reshape(diff, (3, 2))
-        diff = tensor.from_numpy(diff)
-        diff.to_device(gpu_dev)
-
-        rnn.Wx += diff
-        hs, _ = rnn(h0, *inputs)
-        #hs=rnn(h0, x0,x1)
-        loss2_p = autograd.softmax_cross_entropy(hs[0], target[0])
-        for i in range(1, len(hs)):
-            l = autograd.softmax_cross_entropy(hs[i], target[i])
-            loss2_p = autograd.add(loss2_p, l)
-
-        rnn.Wx -= diff
-        rnn.Wx -= diff
-        hs, _ = rnn(h0, *inputs)
-        #hs=rnn(h0, x0,x1)
-        loss2_n = autograd.softmax_cross_entropy(hs[0], target[0])
-        for i in range(1, len(hs)):
-            l = autograd.softmax_cross_entropy(hs[i], target[i])
-            loss2_n = autograd.add(loss2_n, l)
+        loss1 = valinna_rnn_forward()
+        auto_grads = autograd.gradients(loss1)
 
-        loss2_p_np = tensor.to_numpy(loss2_p)
-        loss2_n_np = tensor.to_numpy(loss2_n)
-        # Numerical gradients for dL/dWx[0][0]
-        d2 = (loss2_p_np - loss2_n_np) / 2 / length
-        #print('numerical calculation dL/dWx[0][0] is ', (loss2_p_np-loss2_n_np)/2/length)
+        for param in rnn.params:
+            auto_grad = tensor.to_numpy(auto_grads[param])
 
-        self.assertAlmostEqual(np.sum(d1 - d2), 0., places=3)
+            self.gradients_check(valinna_rnn_forward, param, auto_grad)
 
     def test_numerical_gradients_check_for_lstm(self):
         inputs, target, h0 = prepare_inputs_targets_for_rnn_test()
-        c_0 = np.random.random((2, 1)).astype(np.float32)
+        c_0 = np.zeros((2, 2)).astype(np.float32)
         c0 = tensor.Tensor(device=gpu_dev, data=c_0)
 
         rnn = autograd.LSTM(3, 2)
 
-        hs, _, _ = rnn(h0, c0, *inputs)
-
-        loss1 = autograd.softmax_cross_entropy(hs[0], target[0])
-        for i in range(1, len(hs)):
-            l = autograd.softmax_cross_entropy(hs[i], target[i])
-            loss1 = autograd.add(loss1, l)
-        grads = autograd.gradients(loss1)
-
-        # autograd gradients for dL/dWx[0][0]
-        d1 = tensor.to_numpy(grads[rnn.Wx[0]])[0][0]
-        #print('autograd result of dL/dWx[0][0] is ', d1)
-
-
-        length = 0.01
-        diff = np.array([1, 0, 0, 0, 0, 0]) * length
-        diff = np.reshape(diff, (3, 2))
-        diff = tensor.from_numpy(diff)
-        diff.to_device(gpu_dev)
-
-        rnn.Wx[0] += diff
-        hs, _, _ = rnn(h0, c0, *inputs)
-        #hs=rnn(h0, x0,x1)
-        loss2_p = autograd.softmax_cross_entropy(hs[0], target[0])
-        for i in range(1, len(hs)):
-            l = autograd.softmax_cross_entropy(hs[i], target[i])
-            loss2_p = autograd.add(loss2_p, l)
-
-        rnn.Wx[0] -= diff
-        rnn.Wx[0] -= diff
-        hs, _, _ = rnn(h0, c0, *inputs)
-        #hs=rnn(h0, x0,x1)
-        loss2_n = autograd.softmax_cross_entropy(hs[0], target[0])
-        for i in range(1, len(hs)):
-            l = autograd.softmax_cross_entropy(hs[i], target[i])
-            loss2_n = autograd.add(loss2_n, l)
+        def lstm_forward():
+            hs, _, _ = rnn(h0, c0, *inputs)
 
-        loss2_p_np = tensor.to_numpy(loss2_p)
-        loss2_n_np = tensor.to_numpy(loss2_n)
-        # Numerical gradients for dL/dWx[0][0]
-        d2 = (loss2_p_np - loss2_n_np) / 2 / length
-        #print('numerical calculation dL/dWx[0][0] is ', (loss2_p_np-loss2_n_np)/2/length)
+            loss = autograd.softmax_cross_entropy(hs[0], target[0])
+            for i in range(1, len(hs)):
+                l = autograd.softmax_cross_entropy(hs[i], target[i])
+                loss = autograd.add(loss, l)
+            return loss
 
-        self.assertAlmostEqual(np.sum(d1 - d2), 0., places=3)
+        loss1 = lstm_forward()
+        auto_grads = autograd.gradients(loss1)
 
+        for param in rnn.params:
+            auto_grad = tensor.to_numpy(auto_grads[param])
 
+            self.gradients_check(lstm_forward, param, auto_grad)
 
 if __name__ == '__main__':
     unittest.main()


[5/5] incubator-singa git commit: Merge branch 'pr407'

Posted by wa...@apache.org.
Merge branch 'pr407'


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/2224d5f9
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/2224d5f9
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/2224d5f9

Branch: refs/heads/master
Commit: 2224d5f9ae333659a9c6bfcb018031e16520d49f
Parents: 65756e6 0cd4e30
Author: Wang Wei <wa...@gmail.com>
Authored: Wed Aug 22 22:54:26 2018 +0800
Committer: Wang Wei <wa...@gmail.com>
Committed: Wed Aug 22 22:54:26 2018 +0800

----------------------------------------------------------------------
 python/singa/autograd.py      | 233 +++++++++++++++++++++++++++++++++++--
 python/singa/net.py           |   0
 python/singa/tensor.py        |   6 +-
 test/python/test_layer.py     |   2 +-
 test/python/test_operation.py | 149 +++++++++++++++++++++++-
 5 files changed, 374 insertions(+), 16 deletions(-)
----------------------------------------------------------------------