Posted to commits@singa.apache.org by wa...@apache.org on 2018/08/22 14:55:00 UTC
[4/5] incubator-singa git commit: SINGA-388 Develop some RNN layers by calling tiny operations like matmul, addbias
SINGA-388 Develop some RNN layers by calling tiny operations like matmul, addbias
- Improve test cases for vanilla RNN and LSTM: check all elements of the parameter matrix.
Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/0cd4e308
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/0cd4e308
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/0cd4e308
Branch: refs/heads/master
Commit: 0cd4e3084e84a2b1562877dd47ff5ed46fb6aadf
Parents: a44a01c
Author: xuewanqi <xu...@outlook.com>
Authored: Wed Aug 15 12:42:25 2018 +0000
Committer: xuewanqi <xu...@outlook.com>
Committed: Tue Aug 21 15:10:06 2018 +0000
----------------------------------------------------------------------
python/singa/autograd.py | 8 +-
test/python/test_operation.py | 179 ++++++++++++++++---------------------
2 files changed, 82 insertions(+), 105 deletions(-)
----------------------------------------------------------------------
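The headline change below is the new gradients_check helper in test_operation.py: it perturbs one parameter element at a time and compares the autograd gradient against a central difference, (f(p+h) - f(p-h)) / (2h), with h = 0.0005. Here is a minimal NumPy-only sketch of that technique, with a toy objective standing in for the real forward pass (numerical_grad_check, f, and grad_f are illustrative names, not SINGA API):

    import numpy as np

    def numerical_grad_check(f, grad_f, p, h=0.0005, tol=1e-2):
        # Compare an analytic gradient against a central difference,
        # one parameter element at a time: (f(p+h) - f(p-h)) / (2h).
        analytic = grad_f(p)
        for idx in np.ndindex(p.shape):
            old = p[idx]
            p[idx] = old + h
            pos = f(p)
            p[idx] = old - h
            neg = f(p)
            p[idx] = old                      # restore before the next element
            numerical = (pos - neg) / (2 * h)
            assert abs(analytic[idx] - numerical) < tol

    # Toy objective: f(p) = sum(p**2), whose analytic gradient is 2*p,
    # so every element's central difference should agree to within tol.
    p = np.random.random((3, 2))
    numerical_grad_check(lambda q: np.sum(q ** 2), lambda q: 2 * q, p)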
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/0cd4e308/python/singa/autograd.py
----------------------------------------------------------------------
diff --git a/python/singa/autograd.py b/python/singa/autograd.py
index 7032135..c0f6a7a 100755
--- a/python/singa/autograd.py
+++ b/python/singa/autograd.py
@@ -1060,7 +1060,7 @@ class Vanilla_RNN(RNN):
self.b = Tensor(shape=B_shape, requires_grad=True, stores_grad=True)
self.b.set_value(0.0)
- #self.params= (self.Wx, self.Wh, self.b)
+ self.params= (self.Wx, self.Wh, self.b)
def __call__(self, h0, *xs):
inputs=xs+(h0,)
@@ -1078,9 +1078,9 @@ class Vanilla_RNN(RNN):
return out, h
def step_forward(self, x, h, Wx, Wh, b):
- y1 = matmul(x, Wx)
y2 = matmul(h, Wh)
- y = add(y1, y2)
+ y1 = matmul(x, Wx)
+ y = add(y2, y1)
y = add_bias(y, b, axis=0)
if self.nonlinearity == 'tanh':
y = tanh(y)
@@ -1124,7 +1124,7 @@ class LSTM(RNN):
b.set_value(0.0)
self.Bh.append(b)
- #self.params=self.Wx + self.Wh + self.Bx + self.Bh
+ self.params=self.Wx + self.Wh + self.Bx + self.Bh
def __call__(self, h0, c0, *xs):
inputs=xs+(h0,c0)
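The reordered hunk above leaves the math of Vanilla_RNN.step_forward unchanged: h_t = tanh(x * Wx + h * Wh + b), with the bias broadcast across the batch; the patch only swaps the order in which the two matmul results are added. A standalone NumPy sketch of that step, using the shapes from the tests below (batch 2, input size 3, hidden size 2; all names here are illustrative):

    import numpy as np

    def vanilla_rnn_step(x, h, Wx, Wh, b):
        # h_t = tanh(x @ Wx + h @ Wh + b), matching step_forward above
        y = h @ Wh + x @ Wx        # same sum, addend order as in the patched code
        return np.tanh(y + b)      # bias broadcast over rows (the batch axis)

    x  = np.random.random((2, 3)).astype(np.float32)   # (batch, input)
    h  = np.zeros((2, 2), dtype=np.float32)            # (batch, hidden)
    Wx = np.random.random((3, 2)).astype(np.float32)
    Wh = np.random.random((2, 2)).astype(np.float32)
    b  = np.zeros(2, dtype=np.float32)
    print(vanilla_rnn_step(x, h, Wx, Wh, b).shape)     # (2, 2)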
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/0cd4e308/test/python/test_operation.py
----------------------------------------------------------------------
diff --git a/test/python/test_operation.py b/test/python/test_operation.py
index 4975d99..64562a5 100755
--- a/test/python/test_operation.py
+++ b/test/python/test_operation.py
@@ -23,31 +23,33 @@ def _tuple_to_string(t):
lt = [str(x) for x in t]
return '(' + ', '.join(lt) + ')'
+
def prepare_inputs_targets_for_rnn_test():
- x_0 = np.random.random((2, 3)).astype(np.float32)
- x_1 = np.random.random((2, 3)).astype(np.float32)
- x_2 = np.random.random((2, 3)).astype(np.float32)
+ x_0 = np.random.random((2, 3)).astype(np.float32)
+ x_1 = np.random.random((2, 3)).astype(np.float32)
+ x_2 = np.random.random((2, 3)).astype(np.float32)
+
+ h_0 = np.zeros((2, 2)).astype(
+ np.float32)
- h_0 = np.random.random((2, 1)).astype(
- np.float32) # (2,1) rather than (2,)
+ t_0 = np.random.random((2, 2)).astype(np.float32)
+ t_1 = np.random.random((2, 2)).astype(np.float32)
+ t_2 = np.random.random((2, 2)).astype(np.float32)
- t_0 = np.random.random((2, 2)).astype(np.float32)
- t_1 = np.random.random((2, 2)).astype(np.float32)
- t_2 = np.random.random((2, 2)).astype(np.float32)
+ x0 = tensor.Tensor(device=gpu_dev, data=x_0)
+ x1 = tensor.Tensor(device=gpu_dev, data=x_1)
+ x2 = tensor.Tensor(device=gpu_dev, data=x_2)
- x0 = tensor.Tensor(device=gpu_dev, data=x_0)
- x1 = tensor.Tensor(device=gpu_dev, data=x_1)
- x2 = tensor.Tensor(device=gpu_dev, data=x_2)
+ h0 = tensor.Tensor(device=gpu_dev, data=h_0)
- h0 = tensor.Tensor(device=gpu_dev, data=h_0)
+ t0 = tensor.Tensor(device=gpu_dev, data=t_0)
+ t1 = tensor.Tensor(device=gpu_dev, data=t_1)
+ t2 = tensor.Tensor(device=gpu_dev, data=t_2)
- t0 = tensor.Tensor(device=gpu_dev, data=t_0)
- t1 = tensor.Tensor(device=gpu_dev, data=t_1)
- t2 = tensor.Tensor(device=gpu_dev, data=t_2)
+ inputs = [x0, x1, x2]
+ targets = [t0, t1, t2]
+ return inputs, targets, h0
- inputs = [x0, x1, x2]
- targets = [t0, t1, t2]
- return inputs, targets, h0
class TestPythonOperation(unittest.TestCase):
@@ -114,7 +116,7 @@ class TestPythonOperation(unittest.TestCase):
self.check_shape(ds.shape(), (3,))
self.check_shape(db.shape(), (3,))
- def test_vanillaRNN_gpu_tiny_ops(self):
+ def test_vanillaRNN_gpu_tiny_ops_shape_check(self):
# gradients shape check.
inputs, target, h0 = prepare_inputs_targets_for_rnn_test()
rnn = autograd.Vanilla_RNN(3, 2)
@@ -130,7 +132,7 @@ class TestPythonOperation(unittest.TestCase):
for t, dt in autograd.backward(loss):
self.check_shape(t.shape, dt.shape)
- def test_LSTM_gpu_tiny_ops(self):
+ def test_LSTM_gpu_tiny_ops_shape_check(self):
# gradients shape check.
inputs, target, h0 = prepare_inputs_targets_for_rnn_test()
c_0 = np.random.random((2, 1)).astype(np.float32)
@@ -149,107 +151,82 @@ class TestPythonOperation(unittest.TestCase):
for t, dt in autograd.backward(loss):
self.check_shape(t.shape, dt.shape)
+ def gradients_check(self, func, param, autograds, h=0.0005, df=1):
+ # param: PyTensor
+ # autograds: numpy_tensor
+ p = tensor.to_numpy(param)
+ it = np.nditer(p, flags=['multi_index'], op_flags=['readwrite'])
+ while not it.finished:
+ idx = it.multi_index
+ diff = np.zeros_like(p)
+ diff[idx] += h
+ diff = tensor.from_numpy(diff)
+ diff.to_device(gpu_dev)
+
+ param += diff
+ pos = func()
+ pos = tensor.to_numpy(pos)
+
+ param -= diff
+ param -= diff
+ neg = func()
+ neg = tensor.to_numpy(neg)
+
+ numerical_grad = np.sum((pos - neg) * df) / (2 * h)
+ #print((autograds[idx] - numerical_grad)/numerical_grad)
+ # threshold set as -5% to +5%
+ #self.assertAlmostEqual((autograds[idx] - numerical_grad)/(numerical_grad+0.0000001), 0., places=1)
+ self.assertAlmostEqual(
+ autograds[idx] - numerical_grad, 0., places=2)
+
+ it.iternext()
+
def test_numerical_gradients_check_for_vallina_rnn(self):
inputs, target, h0 = prepare_inputs_targets_for_rnn_test()
rnn = autograd.Vanilla_RNN(3, 2)
- hs, _ = rnn(h0, *inputs)
-
- loss1 = autograd.softmax_cross_entropy(hs[0], target[0])
- for i in range(1, len(hs)):
- l = autograd.softmax_cross_entropy(hs[i], target[i])
- loss1 = autograd.add(loss1, l)
- grads = autograd.gradients(loss1)
-
- # autograd gradients for dL/dWx[0][0]
- d1 = tensor.to_numpy(grads[rnn.Wx])[0][0]
- #print('autograd result of dL/dWx[0][0] is ', d1)
+ def valinna_rnn_forward():
+ hs, _ = rnn(h0, *inputs)
+ loss = autograd.softmax_cross_entropy(hs[0], target[0])
+ for i in range(1, len(hs)):
+ l = autograd.softmax_cross_entropy(hs[i], target[i])
+ loss = autograd.add(loss, l)
+ #grads = autograd.gradients(loss)
+ return loss
- length = 0.01
- diff = np.array([1, 0, 0, 0, 0, 0]) * length
- diff = np.reshape(diff, (3, 2))
- diff = tensor.from_numpy(diff)
- diff.to_device(gpu_dev)
-
- rnn.Wx += diff
- hs, _ = rnn(h0, *inputs)
- #hs=rnn(h0, x0,x1)
- loss2_p = autograd.softmax_cross_entropy(hs[0], target[0])
- for i in range(1, len(hs)):
- l = autograd.softmax_cross_entropy(hs[i], target[i])
- loss2_p = autograd.add(loss2_p, l)
-
- rnn.Wx -= diff
- rnn.Wx -= diff
- hs, _ = rnn(h0, *inputs)
- #hs=rnn(h0, x0,x1)
- loss2_n = autograd.softmax_cross_entropy(hs[0], target[0])
- for i in range(1, len(hs)):
- l = autograd.softmax_cross_entropy(hs[i], target[i])
- loss2_n = autograd.add(loss2_n, l)
+ loss1 = valinna_rnn_forward()
+ auto_grads = autograd.gradients(loss1)
- loss2_p_np = tensor.to_numpy(loss2_p)
- loss2_n_np = tensor.to_numpy(loss2_n)
- # Numerical gradients for dL/dWx[0][0]
- d2 = (loss2_p_np - loss2_n_np) / 2 / length
- #print('numerical calculation dL/dWx[0][0] is ', (loss2_p_np-loss2_n_np)/2/length)
+ for param in rnn.params:
+ auto_grad = tensor.to_numpy(auto_grads[param])
- self.assertAlmostEqual(np.sum(d1 - d2), 0., places=3)
+ self.gradients_check(valinna_rnn_forward, param, auto_grad)
def test_numerical_gradients_check_for_lstm(self):
inputs, target, h0 = prepare_inputs_targets_for_rnn_test()
- c_0 = np.random.random((2, 1)).astype(np.float32)
+ c_0 = np.zeros((2, 2)).astype(np.float32)
c0 = tensor.Tensor(device=gpu_dev, data=c_0)
rnn = autograd.LSTM(3, 2)
- hs, _, _ = rnn(h0, c0, *inputs)
-
- loss1 = autograd.softmax_cross_entropy(hs[0], target[0])
- for i in range(1, len(hs)):
- l = autograd.softmax_cross_entropy(hs[i], target[i])
- loss1 = autograd.add(loss1, l)
- grads = autograd.gradients(loss1)
-
- # autograd gradients for dL/dWx[0][0]
- d1 = tensor.to_numpy(grads[rnn.Wx[0]])[0][0]
- #print('autograd result of dL/dWx[0][0] is ', d1)
-
-
- length = 0.01
- diff = np.array([1, 0, 0, 0, 0, 0]) * length
- diff = np.reshape(diff, (3, 2))
- diff = tensor.from_numpy(diff)
- diff.to_device(gpu_dev)
-
- rnn.Wx[0] += diff
- hs, _, _ = rnn(h0, c0, *inputs)
- #hs=rnn(h0, x0,x1)
- loss2_p = autograd.softmax_cross_entropy(hs[0], target[0])
- for i in range(1, len(hs)):
- l = autograd.softmax_cross_entropy(hs[i], target[i])
- loss2_p = autograd.add(loss2_p, l)
-
- rnn.Wx[0] -= diff
- rnn.Wx[0] -= diff
- hs, _, _ = rnn(h0, c0, *inputs)
- #hs=rnn(h0, x0,x1)
- loss2_n = autograd.softmax_cross_entropy(hs[0], target[0])
- for i in range(1, len(hs)):
- l = autograd.softmax_cross_entropy(hs[i], target[i])
- loss2_n = autograd.add(loss2_n, l)
+ def lstm_forward():
+ hs, _, _ = rnn(h0, c0, *inputs)
- loss2_p_np = tensor.to_numpy(loss2_p)
- loss2_n_np = tensor.to_numpy(loss2_n)
- # Numerical gradients for dL/dWx[0][0]
- d2 = (loss2_p_np - loss2_n_np) / 2 / length
- #print('numerical calculation dL/dWx[0][0] is ', (loss2_p_np-loss2_n_np)/2/length)
+ loss = autograd.softmax_cross_entropy(hs[0], target[0])
+ for i in range(1, len(hs)):
+ l = autograd.softmax_cross_entropy(hs[i], target[i])
+ loss = autograd.add(loss, l)
+ return loss
- self.assertAlmostEqual(np.sum(d1 - d2), 0., places=3)
+ loss1 = lstm_forward()
+ auto_grads = autograd.gradients(loss1)
+ for param in rnn.params:
+ auto_grad = tensor.to_numpy(auto_grads[param])
+ self.gradients_check(lstm_forward, param, auto_grad)
if __name__ == '__main__':
unittest.main()
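Both refactored tests build their scalar loss the same way: softmax cross-entropy at each time step, summed over the sequence, which is what the forward closures return to gradients_check. A NumPy-only sketch of that objective, under the assumption of a row-wise softmax with batch-mean reduction (SINGA's exact reduction may differ):

    import numpy as np

    def softmax_cross_entropy(logits, target):
        # Row-wise softmax, then cross-entropy against soft targets,
        # averaged over the batch (sketch; not the SINGA autograd op).
        e = np.exp(logits - logits.max(axis=1, keepdims=True))
        p = e / e.sum(axis=1, keepdims=True)
        return -np.mean(np.sum(target * np.log(p + 1e-8), axis=1))

    # Three time steps of (batch 2, hidden 2) outputs and random soft targets,
    # mirroring prepare_inputs_targets_for_rnn_test above.
    hs      = [np.random.random((2, 2)).astype(np.float32) for _ in range(3)]
    targets = [np.random.random((2, 2)).astype(np.float32) for _ in range(3)]

    loss = softmax_cross_entropy(hs[0], targets[0])
    for i in range(1, len(hs)):
        loss += softmax_cross_entropy(hs[i], targets[i])   # summed over steps
    print(loss)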