You are viewing a plain-text version of this content. The link to the canonical version was a hyperlink on the original page and is not preserved in this copy.
Posted to commits@singa.apache.org by wa...@apache.org on 2020/01/29 02:29:29 UTC

[singa] branch master updated: hotfix: bugs in autograd.py and also update test case

This is an automated email from the ASF dual-hosted git repository.

wangwei pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/singa.git


The following commit(s) were added to refs/heads/master by this push:
     new eee9e44  hotfix: bugs in autograd.py and also update test case
     new 5057712  Merge pull request #579 from chrishkchris/hotfix_autograd
eee9e44 is described below

commit eee9e4414fedeea93a7b1f5cd9c11d45952685e3
Author: chrishkchris <ch...@yahoo.com.hk>
AuthorDate: Wed Jan 22 06:54:34 2020 +0000

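    hotfix: bugs in autograd.py and also update test case

    - Reshape.backward: reshape dy back to the shape of the forward input
      (now saved as self._shape) instead of the forward output shape
      (self.cache), so the gradient has the same shape as x.
    - _Conv2d and _Pooling2d: choose the GPU or CPU kernel from the type of
      the handle (singa.ConvHandle / singa.PoolingHandle select the CPU path)
      instead of the global singa.USE_CUDA flag.
    - test_operation.py: in the reshape tests, make dy match the (2, 3)
      shape of y and use distinct values so a wrong reshape is detected;
      relax the div/pow assertions from decimal=5 to decimal=2.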
---
 python/singa/autograd.py      | 16 ++++++++--------
 test/python/test_operation.py | 28 ++++++++++++++--------------
 2 files changed, 22 insertions(+), 22 deletions(-)

diff --git a/python/singa/autograd.py b/python/singa/autograd.py
index 8e245c6..e2c3e1c 100644
--- a/python/singa/autograd.py
+++ b/python/singa/autograd.py
@@ -632,18 +632,18 @@ class Reshape(Operation):
             self.shape = list(shape)
 
     def forward(self, x):
-        _shape = x.shape()
+        self._shape = x.shape()
         shape = self.shape
         # handle the shape with 0
-        shape = [_shape[i] if i < len(_shape) and shape[i] == 0 else shape[i] for i in range(len(shape))]
+        shape = [self._shape[i] if i < len(self._shape) and shape[i] == 0 else shape[i] for i in range(len(shape))]
         # handle the shape with -1
-        hidden_shape = int(np.prod(_shape) // np.abs(np.prod(shape)))
+        hidden_shape = int(np.prod(self._shape) // np.abs(np.prod(shape)))
         self.cache=[s if s != -1 else hidden_shape for s in shape]
 
         return singa.Reshape(x, self.cache)
 
     def backward(self, dy):
-        return singa.Reshape(dy, self.cache)
+        return singa.Reshape(dy, self._shape)
 
 
 def reshape(a,shape):
@@ -1199,7 +1199,7 @@ class _Conv2d(Operation):
             b = CTensor((self.handle.num_filters,), x.device())
             b.SetFloatValue(0.0)
 
-        if singa.USE_CUDA:
+        if (type(self.handle) != singa.ConvHandle):
             return singa.GpuConvForward(x, W, b, self.handle)
         else:
             return singa.CpuConvForward(x, W, b, self.handle)
@@ -1209,7 +1209,7 @@ class _Conv2d(Operation):
             self, "inputs"
         ), "Please set training as True before do BP. "
         
-        if singa.USE_CUDA:
+        if (type(self.handle) != singa.ConvHandle):
             dx = singa.GpuConvBackwardx(
                 dy, self.inputs[1], self.inputs[0], self.handle
             )
@@ -1572,7 +1572,7 @@ class _Pooling2d(Operation):
         self.handle = handle
 
     def forward(self, x):
-        if singa.USE_CUDA:
+        if (type(self.handle) != singa.PoolingHandle):
             y = singa.GpuPoolingForward(self.handle, x)
         else:
             y = singa.CpuPoolingForward(self.handle, x)
@@ -1583,7 +1583,7 @@ class _Pooling2d(Operation):
         return y
 
     def backward(self, dy):
-        if singa.USE_CUDA:
+        if (type(self.handle) != singa.PoolingHandle):
             dx = singa.GpuPoolingBackward(
                 self.handle, dy, self.cache[0], self.cache[1]
             )
diff --git a/test/python/test_operation.py b/test/python/test_operation.py
index 6c40f1b..31c6007 100755
--- a/test/python/test_operation.py
+++ b/test/python/test_operation.py
@@ -1438,7 +1438,7 @@ class TestPythonOperation(unittest.TestCase):
     def test_reshape_cpu(self):
         x = np.array([0.1,-1.0,0.4,4.0,-0.9,9.0]).reshape(3,2).astype(np.float32)
         y = x.reshape(2,3)
-        dy = np.ones((3, 2), dtype = np.float32)
+        dy = np.array([1,2,3,4,5,6]).reshape(2,3).astype(np.float32)
         grad = dy.reshape(3,2)
 
 
@@ -1458,7 +1458,7 @@ class TestPythonOperation(unittest.TestCase):
     def test_reshape_gpu(self):
         x = np.array([0.1,-1.0,0.4,4.0,-0.9,9.0]).reshape(3,2).astype(np.float32)
         y = x.reshape(2,3)
-        dy = np.ones((3, 2), dtype = np.float32)
+        dy = np.array([1,2,3,4,5,6]).reshape(2,3).astype(np.float32)
         grad = dy.reshape(3,2)
 
 
@@ -2579,9 +2579,9 @@ class TestPythonOperation(unittest.TestCase):
 
             result = autograd.div(x,x1)
             dx0,dx1 = result.creator.backward(dy.data)
-            np.testing.assert_array_almost_equal(tensor.to_numpy(result), y, decimal=5)
-            np.testing.assert_array_almost_equal(tensor.to_numpy(tensor.from_raw_tensor(dx0)), grad0, decimal=5)
-            np.testing.assert_array_almost_equal(tensor.to_numpy(tensor.from_raw_tensor(dx1)), grad1, decimal=5)
+            np.testing.assert_array_almost_equal(tensor.to_numpy(result), y, decimal=2)
+            np.testing.assert_array_almost_equal(tensor.to_numpy(tensor.from_raw_tensor(dx0)), grad0, decimal=2)
+            np.testing.assert_array_almost_equal(tensor.to_numpy(tensor.from_raw_tensor(dx1)), grad1, decimal=2)
             break
 
     def test_div_broadcast_cpu(self):
@@ -2611,9 +2611,9 @@ class TestPythonOperation(unittest.TestCase):
 
             result = autograd.div(x,x1)
             dx0,dx1 = result.creator.backward(dy.data)
-            np.testing.assert_array_almost_equal(tensor.to_numpy(result), y, decimal=5)
-            np.testing.assert_array_almost_equal(tensor.to_numpy(tensor.from_raw_tensor(dx0)), grad0, decimal=5)
-            np.testing.assert_array_almost_equal(tensor.to_numpy(tensor.from_raw_tensor(dx1)), grad1, decimal=5)
+            np.testing.assert_array_almost_equal(tensor.to_numpy(result), y, decimal=2)
+            np.testing.assert_array_almost_equal(tensor.to_numpy(tensor.from_raw_tensor(dx0)), grad0, decimal=2)
+            np.testing.assert_array_almost_equal(tensor.to_numpy(tensor.from_raw_tensor(dx1)), grad1, decimal=2)
 
     def test_pow_broadcast_gpu(self):
         dev = gpu_dev
@@ -2642,9 +2642,9 @@ class TestPythonOperation(unittest.TestCase):
 
             result = autograd.pow(x,x1)
             dx0,dx1 = result.creator.backward(dy.data)
-            np.testing.assert_array_almost_equal(tensor.to_numpy(result), y, decimal=5)
-            np.testing.assert_array_almost_equal(tensor.to_numpy(tensor.from_raw_tensor(dx0)), grad0, decimal=5)
-            np.testing.assert_array_almost_equal(tensor.to_numpy(tensor.from_raw_tensor(dx1)), grad1, decimal=5)
+            np.testing.assert_array_almost_equal(tensor.to_numpy(result), y, decimal=2)
+            np.testing.assert_array_almost_equal(tensor.to_numpy(tensor.from_raw_tensor(dx0)), grad0, decimal=2)
+            np.testing.assert_array_almost_equal(tensor.to_numpy(tensor.from_raw_tensor(dx1)), grad1, decimal=2)
 
     def test_pow_broadcast_cpu(self):
         dev = cpu_dev
@@ -2673,9 +2673,9 @@ class TestPythonOperation(unittest.TestCase):
 
             result = autograd.pow(x,x1)
             dx0,dx1 = result.creator.backward(dy.data)
-            np.testing.assert_array_almost_equal(tensor.to_numpy(result), y, decimal=5)
-            np.testing.assert_array_almost_equal(tensor.to_numpy(tensor.from_raw_tensor(dx0)), grad0, decimal=5)
-            np.testing.assert_array_almost_equal(tensor.to_numpy(tensor.from_raw_tensor(dx1)), grad1, decimal=5)
+            np.testing.assert_array_almost_equal(tensor.to_numpy(result), y, decimal=2)
+            np.testing.assert_array_almost_equal(tensor.to_numpy(tensor.from_raw_tensor(dx0)), grad0, decimal=2)
+            np.testing.assert_array_almost_equal(tensor.to_numpy(tensor.from_raw_tensor(dx1)), grad1, decimal=2)
 
     def test_prelu_broadcast_gpu(self):
         dev = gpu_dev