Posted to commits@singa.apache.org by mo...@apache.org on 2018/05/18 04:52:17 UTC

[11/14] incubator-singa git commit: SINGA-349 Create layer operations for autograd

SINGA-349 Create layer operations for autograd

1. Change the API of the Conv2d operation to PyTorch style; the next step is to confirm that the new design works.

2. Add a training/evaluation flag to the Conv2d forward function.

3. Delete the extra file.
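
For reference, a minimal sketch of the API change (signatures taken from the
diffs below; the calls themselves are illustrative, not part of this commit):

    # Old, Keras-style: layer name and kernel count lead the signature.
    conv = Convolution2D('conv', nb_kernels=4, kernel=3, stride=1,
                         border_mode='same', use_bias=True)

    # New, PyTorch-style: channel counts and kernel size lead; remaining
    # SINGA-specific options travel through **kwargs, and unknown keyword
    # arguments raise a TypeError.
    conv = Conv2D(in_channels=2, out_channels=4, kernel_size=3,
                  stride=1, padding=0, bias=True)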


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/6402a53d
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/6402a53d
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/6402a53d

Branch: refs/heads/master
Commit: 6402a53d31185bb455c9d796b03d01f3dc476de3
Parents: 5abcc6e
Author: xuewanqi <36...@users.noreply.github.com>
Authored: Sat May 5 17:00:22 2018 +0800
Committer: Wang Wei <dc...@nus.edu.sg>
Committed: Thu May 17 21:19:07 2018 +0800

----------------------------------------------------------------------
 python/singa/convolution_operation.py | 158 -----------------------------
 python/singa/layer_ops.py             |  59 ++++++-----
 2 files changed, 36 insertions(+), 181 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/6402a53d/python/singa/convolution_operation.py
----------------------------------------------------------------------
diff --git a/python/singa/convolution_operation.py b/python/singa/convolution_operation.py
deleted file mode 100644
index 8475c21..0000000
--- a/python/singa/convolution_operation.py
+++ /dev/null
@@ -1,158 +0,0 @@
-from singa import tensor
-from singa import layer
-from singa.proto import model_pb2
-from singa import autograd
-
-
-
-def ctensor2numpy(x):
-    '''
-    For test use only.
-
-
-    To be used in the SoftMax operation.
-    Convert a singa tensor to a numpy array.
-    '''
-    np_array = x.GetFloatValue(int(x.Size()))
-    return np_array.reshape(x.shape())
-
-class Convolution2D(tensor.Operation):
-    def __init__(self, name, nb_kernels, kernel=3, stride=1, border_mode='same',
-                 cudnn_prefer='fastest', workspace_byte_limit=1024,
-                 data_format='NCHW', use_bias=True, W_specs=None, b_specs=None,
-                 pad=None,input_sample_shape=None):
-        '''
-        How to match Keras:
-
-        in Keras Conv2D, self.kernel records how to generate the kernel (shape, initializer, name, regularizer, constraint);
-        it can be interpreted as
-        shape -> kernel+input_sample_shape[0](nb_channels)+nb_kernels,
-        initializer, name, regularizer, constraint -> W_specs.
-        '''
-        self.PyLayer = layer.Conv2D(name, nb_kernels, kernel=kernel, stride=stride, border_mode=border_mode,
-                 cudnn_prefer=cudnn_prefer, workspace_byte_limit=workspace_byte_limit,
-                 data_format=data_format, use_bias=use_bias, W_specs=W_specs, b_specs=b_specs,
-                 pad=pad, input_sample_shape=input_sample_shape)
-
-
-    def __call__(self, x):
-        if not self.PyLayer.has_setup:
-            self.PyLayer.setup(x.shape[1:])
-        param_data = self.PyLayer.layer.param_values()
-
-        if not hasattr(self, 'w'):
-            self.w = tensor.Tensor(data=param_data[0], requires_grad=True, stores_grad=True)
-            self.w.gaussian(0.0, 0.1)  # TODO realize other initialization method according to W_specs
-
-        xs = [x, self.w]
-
-        if len(param_data) == 2:
-            self.b = tensor.Tensor(data=param_data[1], requires_grad=True, stores_grad=True)
-            self.b.set_value(0.0)  # TODO realize other initialization method according to b_specs
-            xs.append(self.b)
-
-        xs = tuple(xs)
-        return self._do_forward(*xs)
-
-    def forward(self, *xs):
-        return self.PyLayer.layer.Forward(4, xs[0])  # TODO: how does the kTrain flag work?
-
-    def backward(self, dy):
-        ret = self.PyLayer.layer.Backward(True, dy)
-        return (ret[0],)+ret[1]
-
-
-class MaxPooling2D(tensor.Operation):
-    def __init__(self, name, kernel=3, stride=2, border_mode='same', pad=None,
-                 data_format='NCHW', input_sample_shape=None):
-
-        self.PyLayer = layer.Pooling2D(name, model_pb2.PoolingConf.MAX,
-                                           kernel, stride, border_mode,
-                                           pad, data_format, input_sample_shape)
-
-    def __call__(self, x):
-        if not self.PyLayer.has_setup:
-            self.PyLayer.setup(x.shape[1:])
-        return self._do_forward(x)
-
-    def forward(self, x):
-        return self.PyLayer.layer.Forward(4, x)
-
-    def backward(self, dy):
-        return self.PyLayer.layer.Backward(True, dy)[0]   # TODO: what does Backward() return?
-
-
-class Activation(tensor.Operation):
-    def __init__(self,name, mode='relu',input_sample_shape=None):
-        self.PyLayer = layer.Activation(name, mode, input_sample_shape)
-
-    def __call__(self, x):
-        if not self.PyLayer.has_setup:
-            self.PyLayer.setup(x.shape[1:])
-        return self._do_forward(x)
-
-    def forward(self, x):
-        return self.PyLayer.layer.Forward(4, x)
-
-    def backward(self, dy):
-        return self.PyLayer.layer.Backward(True, dy)[0]
-
-
-class Flatten(tensor.Operation):
-    def __init__(self, name, axis=1, input_sample_shape=None):
-        self.PyLayer = layer.Flatten(name, axis, input_sample_shape)
-
-    def __call__(self, x):
-        if not self.PyLayer.has_setup:
-            self.PyLayer.setup(x.shape[1:])
-        return self._do_forward(x)
-
-    def forward(self, x):
-        return self.PyLayer.layer.Forward(4, x)
-
-    def backward(self, dy):
-        return self.PyLayer.layer.Backward(True, dy)[0]
-
-
-class Dense(tensor.Operation):
-    '''
-    Does this need to be implemented?
-    '''
-    pass
-
-
-inputs=tensor.Tensor(shape=(10, 2, 3, 3), requires_grad=False, stores_grad=False)
-inputs.gaussian(0.0, 1.0)
-
-x = Convolution2D('conv',4)(inputs)[0]
-print(x.shape)
-
-x = MaxPooling2D('pooling')(x)[0]
-print(x.shape)
-
-x = Activation('relu')(x)[0]
-print(x.shape)
-
-x = Flatten('flatten')(x)[0]
-print(x.shape)
-
-w0 = tensor.Tensor(shape=(4, 10), requires_grad=True, stores_grad=True)
-w0.gaussian(0.0, 0.1)
-x = tensor.matmul(x, w0)
-print(x.shape)
-
-x = tensor.softmax(x)
-
-target=tensor.Tensor(shape=(10, 10), requires_grad=False, stores_grad=False)
-target.gaussian(0.0, 0.1)
-loss = tensor.cross_entropy(x, target)
-
-grad=autograd.backward(loss)
-print(grad)
-
-
-
-
-
-
-
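
The deleted file above also doubled as a smoke test for the autograd path.
Condensed, the flow it exercised looks like this (the input shape is copied
from the deleted script; each operation returns a tuple of outputs, hence
the [0]):

    from singa import tensor, autograd

    inputs = tensor.Tensor(shape=(10, 2, 3, 3))
    inputs.gaussian(0.0, 1.0)

    x = Convolution2D('conv', 4)(inputs)[0]   # w (and b) created on first call
    x = MaxPooling2D('pooling')(x)[0]
    x = Activation('relu')(x)[0]
    x = Flatten('flatten')(x)[0]

    target = tensor.Tensor(shape=x.shape)     # dummy target, same shape as x
    target.gaussian(0.0, 0.1)
    loss = tensor.cross_entropy(tensor.softmax(x), target)
    grads = autograd.backward(loss)           # gradients w.r.t. params and input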

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/6402a53d/python/singa/layer_ops.py
----------------------------------------------------------------------
diff --git a/python/singa/layer_ops.py b/python/singa/layer_ops.py
index 1ca888f..e5ef45f 100644
--- a/python/singa/layer_ops.py
+++ b/python/singa/layer_ops.py
@@ -1,26 +1,36 @@
 from singa import tensor
 from singa import layer
 from singa.proto import model_pb2
-from singa import autograd
 
 
+class Conv2D(tensor.Operation):
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, **kwargs):
+
+        name = 'Conv2d'
+        border_mode = 'same'
+        cudnn_prefer = 'fastest'
+        workspace_byte_limit = 1024
+        data_format = 'NCHW'
+        W_specs = None
+        b_specs = None
+        input_sample_shape = None
+
+        allowed_kwargs = {'name': name,
+                          'border_mode': border_mode,
+                          'cudnn_prefer': cudnn_prefer,
+                          'workspace_byte_limit': workspace_byte_limit,
+                          'data_format': data_format,
+                          'W_specs': W_specs,
+                          'b_specs': b_specs,
+                          'input_sample_shape': input_sample_shape
+                          }
+
+        for kwarg in kwargs:
+            if kwarg not in allowed_kwargs:
+                raise TypeError('Keyword argument not understood:', kwarg)
+            else:
+                allowed_kwargs[kwarg] = kwargs[kwarg]
 
-def ctensor2numpy(x):
-    '''
-    For test use only.
-
-
-    To be used in the SoftMax operation.
-    Convert a singa tensor to a numpy array.
-    '''
-    np_array = x.GetFloatValue(int(x.Size()))
-    return np_array.reshape(x.shape())
-
-class Convolution2D(tensor.Operation):
-    def __init__(self, name, nb_kernels, kernel=3, stride=1, border_mode='same',
-                 cudnn_prefer='fastest', workspace_byte_limit=1024,
-                 data_format='NCHW', use_bias=True, W_specs=None, b_specs=None,
-                 pad=None,input_sample_shape=None):
         '''
         How to match Keras:
 
@@ -29,10 +39,10 @@ class Convolution2D(tensor.Operation):
         shape -> kernel+input_sample_shape[0](nb_channels)+nb_kernels,
         initializer, name, regularizer, constraint -> W_specs.
         '''
-        self.PyLayer = layer.Conv2D(name, nb_kernels, kernel=kernel, stride=stride, border_mode=border_mode,
+        self.PyLayer = layer.Conv2D(name, nb_kernels=out_channels, kernel=kernel_size, stride=stride, border_mode=border_mode,
                  cudnn_prefer=cudnn_prefer, workspace_byte_limit=workspace_byte_limit,
-                 data_format=data_format, use_bias=use_bias, W_specs=W_specs, b_specs=b_specs,
-                 pad=pad, input_sample_shape=input_sample_shape)
+                 data_format=data_format, use_bias=bias, W_specs=W_specs, b_specs=b_specs,
+                 pad=padding, input_sample_shape=input_sample_shape)
 
 
     def __call__(self, x):
@@ -53,8 +63,11 @@ class Convolution2D(tensor.Operation):
         xs = tuple(xs)
         return self._do_forward(*xs)
 
-    def forward(self, *xs):
-        return self.PyLayer.layer.Forward(4, xs[0])  # TODO: how does the kTrain flag work?
+    def forward(self, *xs, flag=True):
+        if flag:
+            return self.PyLayer.layer.Forward(4, xs[0])  # 4: training phase
+        else:
+            return self.PyLayer.layer.Forward(8, xs[0])  # 8: evaluation phase
 
     def backward(self, dy):
         ret = self.PyLayer.layer.Backward(0, dy)
@@ -78,7 +91,7 @@ class MaxPooling2D(tensor.Operation):
         return self.PyLayer.layer.Forward(4, x)
 
     def backward(self, dy):
-        return self.PyLayer.layer.Backward(0, dy)[0]   # TODO: what does Backward() return?
+        return self.PyLayer.layer.Backward(0, dy)[0]
 
 
 class Activation(tensor.Operation):
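
A note on the flag added to Conv2D.forward above: the literal 4 selects the
training phase of the underlying layer and 8 the evaluation phase (presumably
mirroring model_pb2's kTrain/kEval constants; the diff itself only shows the
literals). A minimal sketch of the intent:

    conv = Conv2D(2, 4, kernel_size=3)
    y = conv(x)[0]   # __call__ -> _do_forward -> forward(..., flag=True),
                     # i.e. Forward(4, ...): the training phase
    # An evaluation pass would invoke forward with flag=False, so the
    # underlying layer runs Forward(8, ...) instead.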