You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@singa.apache.org by wa...@apache.org on 2018/05/14 05:53:40 UTC
[1/2] incubator-singa git commit: SINGA-337 Add test cases for code
Repository: incubator-singa
Updated Branches:
refs/heads/master 1bee4d2a0 -> 92b892a67
SINGA-337 Add test cases for code
I have implemented as many test cases as I could, and all of them have been run and pass in my local environment.
The coding style and format follow those of the existing test cases.
Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/c5ca8e8b
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/c5ca8e8b
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/c5ca8e8b
Branch: refs/heads/master
Commit: c5ca8e8baf8a78991ce3e6d3c1fd89791e3af485
Parents: c61a0d8
Author: Wentong <as...@163.com>
Authored: Sat May 12 21:06:11 2018 +0800
Committer: GitHub <no...@github.com>
Committed: Sat May 12 21:06:11 2018 +0800
----------------------------------------------------------------------
test/python/test_layer.py | 66 +++++++----
test/python/test_loss.py | 20 +++-
test/python/test_optimizer.py | 230 +++++++++++++++++++++++++++++++++++++
3 files changed, 291 insertions(+), 25 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/c5ca8e8b/test/python/test_layer.py
----------------------------------------------------------------------
diff --git a/test/python/test_layer.py b/test/python/test_layer.py
index cef29e4..2c49961 100644
--- a/test/python/test_layer.py
+++ b/test/python/test_layer.py
@@ -86,11 +86,11 @@ class TestPythonLayer(unittest.TestCase):
dx = tensor.to_numpy(dx).flatten()
dw = tensor.to_numpy(dw).flatten()
dy = dy.flatten()
- self.assertAlmostEquals(dy[0] * w[4], dx[0])
- self.assertAlmostEquals(dy[0] * w[5] + dy[1] * w[3], dx[1])
- self.assertAlmostEquals(dy[1] * w[4], dx[2])
- self.assertAlmostEquals(dy[0] * w[7] + dy[2] * w[1], dx[3])
- self.assertAlmostEquals(
+ self.assertAlmostEqual(dy[0] * w[4], dx[0])
+ self.assertAlmostEqual(dy[0] * w[5] + dy[1] * w[3], dx[1])
+ self.assertAlmostEqual(dy[1] * w[4], dx[2])
+ self.assertAlmostEqual(dy[0] * w[7] + dy[2] * w[1], dx[3])
+ self.assertAlmostEqual(
dy[0] *
w[8] +
dy[1] *
@@ -100,16 +100,16 @@ class TestPythonLayer(unittest.TestCase):
dy[3] *
w[0],
dx[4])
- self.assertAlmostEquals(dy[1] * w[7] + dy[3] * w[1], dx[5])
- self.assertAlmostEquals(dy[2] * w[4], dx[6])
- self.assertAlmostEquals(dy[2] * w[5] + dy[3] * w[3], dx[7])
- self.assertAlmostEquals(dy[3] * w[4], dx[8])
-
- self.assertAlmostEquals(dy[3] * raw_x[4], dw[0])
- self.assertAlmostEquals(dy[3] * raw_x[5] + dy[2] * raw_x[3], dw[1])
- self.assertAlmostEquals(dy[2] * raw_x[4], dw[2])
- self.assertAlmostEquals(dy[1] * raw_x[1] + dy[3] * raw_x[7], dw[3])
- self.assertAlmostEquals(
+ self.assertAlmostEqual(dy[1] * w[7] + dy[3] * w[1], dx[5])
+ self.assertAlmostEqual(dy[2] * w[4], dx[6])
+ self.assertAlmostEqual(dy[2] * w[5] + dy[3] * w[3], dx[7])
+ self.assertAlmostEqual(dy[3] * w[4], dx[8])
+
+ self.assertAlmostEqual(dy[3] * raw_x[4], dw[0])
+ self.assertAlmostEqual(dy[3] * raw_x[5] + dy[2] * raw_x[3], dw[1])
+ self.assertAlmostEqual(dy[2] * raw_x[4], dw[2])
+ self.assertAlmostEqual(dy[1] * raw_x[1] + dy[3] * raw_x[7], dw[3])
+ self.assertAlmostEqual(
dy[0] *
raw_x[0] +
dy[1] *
@@ -119,10 +119,10 @@ class TestPythonLayer(unittest.TestCase):
dy[3] *
raw_x[8],
dw[4], 5)
- self.assertAlmostEquals(dy[0] * raw_x[1] + dy[2] * raw_x[7], dw[5])
- self.assertAlmostEquals(dy[1] * raw_x[4], dw[6])
- self.assertAlmostEquals(dy[0] * raw_x[3] + dy[1] * raw_x[5], dw[7])
- self.assertAlmostEquals(dy[0] * raw_x[4], dw[8])
+ self.assertAlmostEqual(dy[0] * raw_x[1] + dy[2] * raw_x[7], dw[5])
+ self.assertAlmostEqual(dy[1] * raw_x[4], dw[6])
+ self.assertAlmostEqual(dy[0] * raw_x[3] + dy[1] * raw_x[5], dw[7])
+ self.assertAlmostEqual(dy[0] * raw_x[4], dw[8])
def test_conv1D(self):
in_sample_shape = (224,)
@@ -213,12 +213,12 @@ class TestPythonLayer(unittest.TestCase):
lyr = layer.Concat('concat', 0, [(3,), (3,)])
t = lyr.forward(model_pb2.kTrain, [t1, t2])
tnp = tensor.to_numpy(t)
- self.assertEquals(np.sum(tnp), 12)
+ self.assertEqual(np.sum(tnp), 12)
t3 = tensor.Tensor((3, 3))
t3.set_value(1.5)
grads, _ = lyr.backward(model_pb2.kTrain, [t3])
gnp = tensor.to_numpy(grads[0])
- self.assertEquals(np.sum(gnp), 6 * 1.5)
+ self.assertEqual(np.sum(gnp), 6 * 1.5)
def test_slice(self):
t = np.zeros((3, 3))
@@ -228,16 +228,34 @@ class TestPythonLayer(unittest.TestCase):
out = lyr.forward(model_pb2.kTrain, [tensor.from_numpy(t)])
t1 = tensor.to_numpy(out[0])
t2 = tensor.to_numpy(out[1])
- self.assertEquals(np.average(t1), 2)
- self.assertEquals(np.average(t2), 1)
+ self.assertEqual(np.average(t1), 2)
+ self.assertEqual(np.average(t2), 1)
t1 = tensor.Tensor((3, 2))
t2 = tensor.Tensor((3, 1))
t1.set_value(1)
t2.set_value(2)
grad, _ = lyr.backward(model_pb2.kTrain, [t1, t2])
gnp = tensor.to_numpy(grad)
- self.assertEquals(np.sum(gnp), 12)
+ self.assertEqual(np.sum(gnp), 12)
+ def test_l2norm(self):
+ in_sample_shape = (3, 224, 224)
+ l2norm = layer.L2Norm('l2norm', input_sample_shape=in_sample_shape)
+ out_sample_shape = l2norm.get_output_sample_shape()
+ self.check_shape(out_sample_shape, in_sample_shape)
+
+ def test_merge(self):
+ in_sample_shape = (3, 224, 224)
+ merge = layer.Merge('merge', input_sample_shape=in_sample_shape)
+ out_sample_shape = merge.get_output_sample_shape()
+ self.check_shape(out_sample_shape, in_sample_shape)
+
+ def test_split(self):
+ in_sample_shape = (3, 224, 224)
+ split = layer.Split('split', num_output=3,
+ input_sample_shape=in_sample_shape)
+ out_sample_shape = split.get_output_sample_shape()
+ self.check_shape(out_sample_shape, [in_sample_shape] * 3)
if __name__ == '__main__':
unittest.main()
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/c5ca8e8b/test/python/test_loss.py
----------------------------------------------------------------------
diff --git a/test/python/test_loss.py b/test/python/test_loss.py
index 31784ce..f30859b 100644
--- a/test/python/test_loss.py
+++ b/test/python/test_loss.py
@@ -46,10 +46,28 @@ class TestLoss(unittest.TestCase):
l2 = sig.evaluate(True, self.x, self.y)
p = 1.0 / (1 + np.exp(-self.x_np))
- l = - (self.y_np * np.log(p) + (1-self.y_np) * np.log(1-p))
+ l = - (self.y_np * np.log(p) + (1 - self.y_np) * np.log(1 - p))
self.assertAlmostEqual(l1.l1(), l2)
self.assertAlmostEqual(l1.l1(), np.average(l))
+ def test_squared_error(self):
+ sqe = loss.SquaredError()
+ l1 = sqe.forward(True, self.x, self.y)
+ sqe.backward()
+ l2 = sqe.evaluate(True, self.x, self.y)
+
+ l = 0.5 * (self.y_np - self.x_np) ** 2
+ self.assertAlmostEqual(l1.l1(), l2)
+ self.assertAlmostEqual(l1.l1(), np.average(l))
+
+ def test_softmax_cross_entropy(self):
+ sce = loss.SoftmaxCrossEntropy()
+ l1 = sce.forward(True, self.x, self.y)
+ sce.backward()
+ l2 = sce.evaluate(True, self.x, self.y)
+
+ self.assertAlmostEqual(l1.l1(), l2)
+
if __name__ == '__main__':
unittest.main()
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/c5ca8e8b/test/python/test_optimizer.py
----------------------------------------------------------------------
diff --git a/test/python/test_optimizer.py b/test/python/test_optimizer.py
index f5c5471..84da11c 100644
--- a/test/python/test_optimizer.py
+++ b/test/python/test_optimizer.py
@@ -41,6 +41,23 @@ def np_adam(plist, glist, mlist, vlist, lr, t, b1=0.9, b2=0.999):
v += (1-b2) * g * g
alpha = lr * math.sqrt(1. - math.pow(b2, t)) / (1. - math.pow(b1, t))
p -= alpha * m / (np.sqrt(v) + 1e-8)
+
+def np_rmsprop(plist, glist, vlist, lr, t, rho=0.9):
+ for p, g, v in zip(plist, glist, vlist):
+ v *= rho
+ v += (1-rho) * g * g
+ p -= lr * g / (np.sqrt(v + 1e-8))
+
+def np_momentum(plist, glist, vlist, lr, t, momentum=0.9):
+ for p, g, v in zip(plist, glist, vlist):
+ v *= momentum
+ v += lr * g
+ p -= v
+
+def np_adagrad(plist, glist, vlist, lr, t):
+ for p, g, v in zip(plist, glist, vlist):
+ v += g * g
+ p -= lr * g / (np.sqrt(v + 1e-8))
class TestOptimizer(unittest.TestCase):
@@ -147,6 +164,219 @@ class TestOptimizer(unittest.TestCase):
self.assertAlmostEqual(g[i],
self.np_g[i] + coefficient * self.np_W[i])
+ @unittest.skipIf(not singa_wrap.USE_CUDA, 'CUDA is not enabled')
+ def test_adam_cuda(self):
+ lr = 0.1
+ n, m = 4, 6
+ p1 = np.random.rand(n, m)
+ p2 = np.random.rand(n, m)
+ g1 = np.random.rand(n, m) * 0.01
+ g2 = np.random.rand(n, m) * 0.01
+ m1 = np.zeros((n, m))
+ m2 = np.zeros((n, m))
+ v1 = np.zeros((n, m))
+ v2 = np.zeros((n, m))
+ t1 = tensor.from_numpy(p1)
+ t2 = tensor.from_numpy(p2)
+ tg1 = tensor.from_numpy(g1)
+ tg2 = tensor.from_numpy(g2)
+
+ for t in range(1, 10):
+ np_adam([p1, p2], [g1, g2], [m1, m2], [v1, v2], lr, t)
+
+ adam = opt.Adam(lr=lr)
+ self.to_cuda()
+ for t in range(1, 10):
+ adam.apply(0, tg1, t1, 'p1', t)
+ adam.apply(0, tg2, t2, 'p2', t)
+
+ t1 = tensor.to_numpy(t1)
+ t2 = tensor.to_numpy(t2)
+ for t, p in zip([t1, t2], [p1, p2]):
+ for i in range(n):
+ for j in range(m):
+ self.assertAlmostEqual(t[i, j], p[i, j], 6)
+
+ def test_rmsprop(self):
+ lr = 0.1
+ n, m = 2, 2
+ p1 = np.random.rand(n, m)
+ p2 = np.random.rand(n, m)
+ g1 = np.random.rand(n, m) * 0.01
+ g2 = np.random.rand(n, m) * 0.01
+ v1 = np.zeros((n, m))
+ v2 = np.zeros((n, m))
+ t1 = tensor.from_numpy(p1)
+ t2 = tensor.from_numpy(p2)
+ tg1 = tensor.from_numpy(g1)
+ tg2 = tensor.from_numpy(g2)
+
+ for t in range(1, 4):
+ np_rmsprop([p1, p2], [g1, g2], [v1, v2], lr, t)
+
+ rsmprop = opt.RMSProp(lr=lr)
+ for t in range(1, 4):
+ rsmprop.apply(0, tg1, t1, 'p1', t)
+ rsmprop.apply(0, tg2, t2, 'p2', t)
+
+ t1 = tensor.to_numpy(t1)
+ t2 = tensor.to_numpy(t2)
+ for t, p in zip([t1, t2], [p1, p2]):
+ for i in range(n):
+ for j in range(m):
+ self.assertAlmostEqual(t[i, j], p[i, j], 2)
+
+ @unittest.skipIf(not singa_wrap.USE_CUDA, 'CUDA is not enabled')
+ def test_rmsprop_cuda(self):
+ lr = 0.1
+ n, m = 2, 2
+ p1 = np.random.rand(n, m)
+ p2 = np.random.rand(n, m)
+ g1 = np.random.rand(n, m) * 0.01
+ g2 = np.random.rand(n, m) * 0.01
+ v1 = np.zeros((n, m))
+ v2 = np.zeros((n, m))
+ t1 = tensor.from_numpy(p1)
+ t2 = tensor.from_numpy(p2)
+ tg1 = tensor.from_numpy(g1)
+ tg2 = tensor.from_numpy(g2)
+
+ for t in range(1, 4):
+ np_rmsprop([p1, p2], [g1, g2], [v1, v2], lr, t)
+
+ rsmprop = opt.RMSProp(lr=lr)
+ self.to_cuda()
+ for t in range(1, 4):
+ rsmprop.apply(0, tg1, t1, 'p1', t)
+ rsmprop.apply(0, tg2, t2, 'p2', t)
+
+ t1 = tensor.to_numpy(t1)
+ t2 = tensor.to_numpy(t2)
+ for t, p in zip([t1, t2], [p1, p2]):
+ for i in range(n):
+ for j in range(m):
+ self.assertAlmostEqual(t[i, j], p[i, j], 2)
+
+ def test_momentum(self):
+ lr = 0.1
+ n, m = 2, 2
+ p1 = np.random.rand(n, m)
+ p2 = np.random.rand(n, m)
+ g1 = np.random.rand(n, m) * 0.01
+ g2 = np.random.rand(n, m) * 0.01
+ v1 = np.zeros((n, m))
+ v2 = np.zeros((n, m))
+ t1 = tensor.from_numpy(p1)
+ t2 = tensor.from_numpy(p2)
+ tg1 = tensor.from_numpy(g1)
+ tg2 = tensor.from_numpy(g2)
+
+ for t in range(1, 4):
+ np_momentum([p1, p2], [g1, g2], [v1, v2], lr, t)
+
+ momentum = opt.SGD(lr, momentum=0.9)
+ for t in range(1, 4):
+ momentum.apply(0, tg1, t1, 'p1', t)
+ momentum.apply(0, tg2, t2, 'p2', t)
+
+ t1 = tensor.to_numpy(t1)
+ t2 = tensor.to_numpy(t2)
+ for t, p in zip([t1, t2], [p1, p2]):
+ for i in range(n):
+ for j in range(m):
+ self.assertAlmostEqual(t[i, j], p[i, j], 2)
+
+ @unittest.skipIf(not singa_wrap.USE_CUDA, 'CUDA is not enabled')
+ def test_momentum_cuda(self):
+ lr = 0.1
+ n, m = 2, 2
+ p1 = np.random.rand(n, m)
+ p2 = np.random.rand(n, m)
+ g1 = np.random.rand(n, m) * 0.01
+ g2 = np.random.rand(n, m) * 0.01
+ v1 = np.zeros((n, m))
+ v2 = np.zeros((n, m))
+ t1 = tensor.from_numpy(p1)
+ t2 = tensor.from_numpy(p2)
+ tg1 = tensor.from_numpy(g1)
+ tg2 = tensor.from_numpy(g2)
+
+ for t in range(1, 4):
+ np_momentum([p1, p2], [g1, g2], [v1, v2], lr, t)
+
+ momentum = opt.SGD(lr, momentum=0.9)
+ self.to_cuda()
+ for t in range(1, 4):
+ momentum.apply(0, tg1, t1, 'p1', t)
+ momentum.apply(0, tg2, t2, 'p2', t)
+
+ t1 = tensor.to_numpy(t1)
+ t2 = tensor.to_numpy(t2)
+ for t, p in zip([t1, t2], [p1, p2]):
+ for i in range(n):
+ for j in range(m):
+ self.assertAlmostEqual(t[i, j], p[i, j], 2)
+
+ def test_adagrad(self):
+ lr = 0.1
+ n, m = 2, 2
+ p1 = np.random.rand(n, m)
+ p2 = np.random.rand(n, m)
+ g1 = np.random.rand(n, m) * 0.01
+ g2 = np.random.rand(n, m) * 0.01
+ v1 = np.zeros((n, m))
+ v2 = np.zeros((n, m))
+ t1 = tensor.from_numpy(p1)
+ t2 = tensor.from_numpy(p2)
+ tg1 = tensor.from_numpy(g1)
+ tg2 = tensor.from_numpy(g2)
+
+ for t in range(1, 4):
+ np_adagrad([p1, p2], [g1, g2], [v1, v2], lr, t)
+
+ adagrad = opt.AdaGrad(lr=lr)
+ for t in range(1, 4):
+ adagrad.apply(0, tg1, t1, 'p1', t)
+ adagrad.apply(0, tg2, t2, 'p2', t)
+
+ t1 = tensor.to_numpy(t1)
+ t2 = tensor.to_numpy(t2)
+ for t, p in zip([t1, t2], [p1, p2]):
+ for i in range(n):
+ for j in range(m):
+ self.assertAlmostEqual(t[i, j], p[i, j], 2)
+
+ @unittest.skipIf(not singa_wrap.USE_CUDA, 'CUDA is not enabled')
+ def test_adagrad_cuda(self):
+ lr = 0.1
+ n, m = 2, 2
+ p1 = np.random.rand(n, m)
+ p2 = np.random.rand(n, m)
+ g1 = np.random.rand(n, m) * 0.01
+ g2 = np.random.rand(n, m) * 0.01
+ v1 = np.zeros((n, m))
+ v2 = np.zeros((n, m))
+ t1 = tensor.from_numpy(p1)
+ t2 = tensor.from_numpy(p2)
+ tg1 = tensor.from_numpy(g1)
+ tg2 = tensor.from_numpy(g2)
+
+ for t in range(1, 4):
+ np_adagrad([p1, p2], [g1, g2], [v1, v2], lr, t)
+
+ adagrad = opt.AdaGrad(lr=lr)
+ self.to_cuda()
+ for t in range(1, 4):
+ adagrad.apply(0, tg1, t1, 'p1', t)
+ adagrad.apply(0, tg2, t2, 'p2', t)
+
+ t1 = tensor.to_numpy(t1)
+ t2 = tensor.to_numpy(t2)
+ for t, p in zip([t1, t2], [p1, p2]):
+ for i in range(n):
+ for j in range(m):
+ self.assertAlmostEqual(t[i, j], p[i, j], 2)
+
if __name__ == '__main__':
unittest.main()
[2/2] incubator-singa git commit: Merge branch 'pr370'
Posted by wa...@apache.org.
Merge branch 'pr370'
Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/92b892a6
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/92b892a6
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/92b892a6
Branch: refs/heads/master
Commit: 92b892a6706d88bed2bc4fe78e25e1bcc3704d10
Parents: 1bee4d2 c5ca8e8
Author: Wang Wei <dc...@nus.edu.sg>
Authored: Mon May 14 13:53:28 2018 +0800
Committer: Wang Wei <dc...@nus.edu.sg>
Committed: Mon May 14 13:53:28 2018 +0800
----------------------------------------------------------------------
test/python/test_layer.py | 66 +++++++----
test/python/test_loss.py | 20 +++-
test/python/test_optimizer.py | 230 +++++++++++++++++++++++++++++++++++++
3 files changed, 291 insertions(+), 25 deletions(-)
----------------------------------------------------------------------