Posted to commits@singa.apache.org by wa...@apache.org on 2018/07/05 03:10:06 UTC

[11/18] incubator-singa git commit: SINGA-371 Implement functional operations in c++ for autograd

SINGA-371 Implement functional operations in c++ for autograd

- fixed a bug in convolution_related.cc: InitCudnnConvHandles constructed the cuDNN descriptors but never returned them

- exported device.lang() from C++ to Python so callers can determine the device type (CPU or GPU)

- modified the design of autograd.Conv2D to dispatch on the device type

- modified the test file for Conv2D; the unit tests pass

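The dispatch these changes enable is sketched below: Python code can now ask
a device for its language and branch on the result. This is a minimal
illustration, assuming the rebuilt SWIG wrapper; the helper name backend_of
is hypothetical, and the integer values follow the LangType enum order
(kCpp = 0, kCuda = 1):

    from singa import device

    def backend_of(dev):
        # dev.lang() is the Device::lang() accessor newly exported from C++.
        # By LangType enum order, kCpp == 0 (CPU) and kCuda == 1 (CUDA GPU).
        if dev.lang() == 1:
            return 'gpu'
        elif dev.lang() == 0:
            return 'cpu'
        else:
            raise TypeError('Not implemented yet')

    print(backend_of(device.get_default_device()))  # 'cpu' on a host device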

Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/e68ea2ee
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/e68ea2ee
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/e68ea2ee

Branch: refs/heads/master
Commit: e68ea2ee6640d4124e2f5f32ac16726fa84d10ac
Parents: 189958a
Author: xuewanqi <xu...@outlook.com>
Authored: Thu Jun 28 05:22:30 2018 +0000
Committer: xuewanqi <xu...@outlook.com>
Committed: Thu Jun 28 05:22:30 2018 +0000

----------------------------------------------------------------------
 python/singa/autograd.py                   | 20 +++++++------
 src/api/core_device.i                      |  3 ++
 src/model/operation/convolution_related.cc | 14 ++++++++++
 test/python/test_operation.py              | 37 +++++++++++++++++++------
 4 files changed, 58 insertions(+), 16 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e68ea2ee/python/singa/autograd.py
----------------------------------------------------------------------
diff --git a/python/singa/autograd.py b/python/singa/autograd.py
index e898312..e301e51 100644
--- a/python/singa/autograd.py
+++ b/python/singa/autograd.py
@@ -669,28 +669,30 @@ class Conv2D(Operation):
     	return self._do_forward(*xs)[0]
 
     def forward(self, *xs):
-        if gpu:
-            
+        if self.dev.lang() == 1:  # kCuda = 1
             if not hasattr(self, 'cudnnconvhandles'):
-                self.cudnnconvhandles=InitCudnnConvHandles(xs[0], self.recorder, 
+                self.cudnnconvhandles=singa.InitCudnnConvHandles(xs[0], self.recorder, 
                     self.inner_params['workspace_MB_limit']*1024*1024, self.inner_params['cudnn_prefer'])
             elif self.reset:
-                self.cudnnconvhandles=InitCudnnConvHandles(xs[0], self.recorder, 
+                self.cudnnconvhandles=singa.InitCudnnConvHandles(xs[0], self.recorder, 
                     self.inner_params['workspace_MB_limit']*1024*1024, self.inner_params['cudnn_prefer'])
 
             return singa.GpuConvForward(xs[0], xs[1], xs[2], self.recorder, self.cudnnconvhandles)
 
-        if cpu:
-
+        elif self.dev.lang() == 0:  # kCpp = 0
             return singa.CpuConvForward(xs[0], xs[1], xs[2], self.recorder)
 
+        else:
+            raise TypeError('Not implemented yet')
+
+
     def backward(self, dy):
         assert training is True and hasattr(self, 'x'), 'Please set training to True before doing BP.'
 
         # todo check device?
         dy.ToDevice(self.dev)
 
-        if gpu:
+        if self.dev.lang() == 1:  # kCuda = 1
             dx = singa.GpuConvBackwardx(dy, self.W.data, self.x.data, self.cudnnconvhandles)
             dW = singa.GpuConvBackwardW(dy, self.x.data, self.W.data, self.cudnnconvhandles)
             if self.bias:
@@ -699,7 +701,7 @@ class Conv2D(Operation):
             else:
         	    return dx, dW
 
-        if cpu:
+        elif self.dev.lang() == 0:  # kCpp = 0
             dx = singa.CpuConvBackwardx(dy, self.W.data, self.x.data, self.recorder)
             dW = singa.CpuConvBackwardW(dy, self.x.data, self.W.data, self.recorder)
             if self.bias:
@@ -707,6 +709,8 @@ class Conv2D(Operation):
                 return dx, dW, db
             else:
                 return dx, dW
+        else:
+            raise TypeError('Not implemented yet')
 
 def infer_dependency(op):
     '''

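The two branches that call singa.InitCudnnConvHandles above are identical;
together with the reset flag they implement a "create once, recreate on
reset" cache for the cuDNN handles. A minimal sketch of that pattern, using
only names that appear in the diff (the helper name _get_handles is
hypothetical):

    def _get_handles(self, x):
        # Build the cuDNN handles on first use, and rebuild them whenever
        # self.reset signals that the input signature has changed.
        if not hasattr(self, 'cudnnconvhandles') or self.reset:
            self.cudnnconvhandles = singa.InitCudnnConvHandles(
                x, self.recorder,
                self.inner_params['workspace_MB_limit'] * 1024 * 1024,
                self.inner_params['cudnn_prefer'])
        return self.cudnnconvhandles
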
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e68ea2ee/src/api/core_device.i
----------------------------------------------------------------------
diff --git a/src/api/core_device.i b/src/api/core_device.i
index a5b7de6..381f7c6 100644
--- a/src/api/core_device.i
+++ b/src/api/core_device.i
@@ -43,11 +43,14 @@ namespace std{
 
 namespace singa{
 
+enum LangType {kCpp, kCuda, kOpencl, kNumDeviceType};
+
 class Device {
  public:
   virtual void SetRandSeed(unsigned seed) = 0;
   std::shared_ptr<Device> host();
   int id() const;
+  LangType lang() const;
 };
 
 class Platform {

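Declaring the LangType enum in the interface file also makes its members
visible as module-level integer constants under SWIG's default enum
wrapping, so Python can compare dev.lang() against named values rather than
bare literals. A hedged sketch, assuming the wrapper module is importable
as singa_wrap:

    from singa import singa_wrap as singa
    from singa import device

    dev = device.get_default_device()
    assert dev.lang() == singa.kCpp  # host devices use the C++ backend
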
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e68ea2ee/src/model/operation/convolution_related.cc
----------------------------------------------------------------------
diff --git a/src/model/operation/convolution_related.cc b/src/model/operation/convolution_related.cc
index 1004074..c828f90 100644
--- a/src/model/operation/convolution_related.cc
+++ b/src/model/operation/convolution_related.cc
@@ -318,6 +318,20 @@ CudnnConvHandles InitCudnnConvHandles(const Tensor &input, const Recorder r, con
                      << workspace_byte_limit_ << ")";
     workspace_ = Tensor(Shape{workspace_count_}, dev, dtype);
 
+    return CudnnConvHandles{
+    	x_desc_,
+        y_desc_,
+        bias_desc_,
+        filter_desc_,
+        conv_desc_,
+        fp_alg_,
+        bp_filter_alg_,
+        bp_data_alg_,
+
+        workspace_count_,
+        workspace_,
+    };
+
 };
 
 Tensor GpuConvForward(const Tensor &x, const Tensor &W, const Tensor &b, const Recorder r, const CudnnConvHandles cch){

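The added return statement is the bug fix called out in the commit message:
InitCudnnConvHandles set up the cuDNN descriptors, chose the algorithms and
allocated the workspace, but then fell off the end of the function, so
callers received an unusable CudnnConvHandles. The intended call sequence,
pieced together from the calls visible in autograd.py above (x, W, b, dy
and recorder are placeholders, and a CUDA device is required):

    cch = singa.InitCudnnConvHandles(x, recorder,
                                     workspace_MB_limit * 1024 * 1024,
                                     cudnn_prefer)
    y = singa.GpuConvForward(x, W, b, recorder, cch)    # forward pass
    dx = singa.GpuConvBackwardx(dy, W, x, cch)          # grad w.r.t. input
    dW = singa.GpuConvBackwardW(dy, x, W, cch)          # grad w.r.t. filter
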
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e68ea2ee/test/python/test_operation.py
----------------------------------------------------------------------
diff --git a/test/python/test_operation.py b/test/python/test_operation.py
index 295b2d2..ece537d 100644
--- a/test/python/test_operation.py
+++ b/test/python/test_operation.py
@@ -10,16 +10,14 @@ autograd.training = True
 
 CTensor = singa.Tensor
 
-dev = device.create_cuda_gpu()
-
-gpu_input_tensor = tensor.Tensor(shape=(2, 3, 3, 3), device=dev)
-gpu_input_tensor.gaussian(0.0, 1.0)
+gpu_dev = device.create_cuda_gpu()
+cpu_dev = device.get_default_device()
 
 dy = CTensor([2, 1, 2, 2])
 singa.Gaussian(0.0, 1.0, dy)
-dy.ToDevice(dev)
 
-conv = autograd.Conv2d_GPU(3, 1, 2)  # (in_channels, out_channels, kernel_size)
+conv = autograd.Conv2D(3, 1, 2)  # (in_channels, out_channels, kernel_size)
+conv_without_bias = autograd.Conv2D(3, 1, 2, bias=False)
 
 
 def _tuple_to_string(t):
@@ -35,14 +33,37 @@ class TestPythonOperation(unittest.TestCase):
                                               _tuple_to_string(expect))
                          )
 
-    def test(self):
+    def test_conv2d_gpu(self):
+        gpu_input_tensor = tensor.Tensor(shape=(2, 3, 3, 3), device=gpu_dev)
+        gpu_input_tensor.gaussian(0.0, 1.0)
+
         y = conv(gpu_input_tensor)  # PyTensor
         dx, dW, db = conv.backward(dy)  # CTensor
-        
+
         self.check_shape(y.shape, (2, 1, 2, 2))
         self.check_shape(dx.shape(), (2, 3, 3, 3))
         self.check_shape(dW.shape(), (1, 3, 2, 2))
         self.check_shape(db.shape(), (1,))
 
+        # forward without bias
+        y_without_bias = conv_without_bias(gpu_input_tensor)
+        self.check_shape(y_without_bias.shape, (2, 1, 2, 2))
+
+    def test_conv2d_cpu(self):
+        cpu_input_tensor = tensor.Tensor(shape=(2, 3, 3, 3), device=cpu_dev)
+        cpu_input_tensor.gaussian(0.0, 1.0)
+
+        y = conv(cpu_input_tensor)  # PyTensor
+        dx, dW, db = conv.backward(dy)  # CTensor
+
+        self.check_shape(y.shape, (2, 1, 2, 2))
+        self.check_shape(dx.shape(), (2, 3, 3, 3))
+        self.check_shape(dW.shape(), (1, 3, 2, 2))
+        self.check_shape(db.shape(), (1,))
+
+        # forward without bias
+        y_without_bias = conv_without_bias(cpu_input_tensor)
+        self.check_shape(y_without_bias.shape, (2, 1, 2, 2))
+
 if __name__ == '__main__':
     unittest.main()
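
The updated tests can be run directly from the repository root; note that a
CUDA-enabled build is required, since the module creates a CUDA device at
import time:

    python test/python/test_operation.py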