You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@singa.apache.org by mo...@apache.org on 2018/05/18 04:52:14 UTC

[08/14] incubator-singa git commit: SINGA-349 Create layer operations for autograd

SINGA-349 Create layer operations for autograd

1. Rewrite the Linear operation

2. Avoid a hard-coded absolute dataset path (take it as a command-line argument instead)

3. Modify the mnist_cnn example

4. Delete unnecessary code


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/40e609a4
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/40e609a4
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/40e609a4

Branch: refs/heads/master
Commit: 40e609a4e807595d335adaae17966daa8adac04c
Parents: ed464ef
Author: xuewanqi <xu...@u.nus.edu>
Authored: Tue May 15 14:58:41 2018 +0800
Committer: Wang Wei <dc...@nus.edu.sg>
Committed: Thu May 17 21:19:07 2018 +0800

----------------------------------------------------------------------
 examples/autograd/mnist_cnn.py |  60 +++++----
 python/singa/autograd.py       | 260 +++++++++++-------------------------
 2 files changed, 109 insertions(+), 211 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/40e609a4/examples/autograd/mnist_cnn.py
----------------------------------------------------------------------
diff --git a/examples/autograd/mnist_cnn.py b/examples/autograd/mnist_cnn.py
index bc717c7..7afbb9e 100644
--- a/examples/autograd/mnist_cnn.py
+++ b/examples/autograd/mnist_cnn.py
@@ -1,11 +1,12 @@
 import numpy as np
+import argparse
+import os
 
 from singa import tensor
 from singa import autograd
 from singa import optimizer
 
 
-
 def load_data(path):
     f = np.load(path)
     x_train, y_train = f['x_train'], f['y_train']
@@ -13,6 +14,7 @@ def load_data(path):
     f.close()
     return (x_train, y_train), (x_test, y_test)
 
+
 def to_categorical(y, num_classes):
     '''
     Converts a class vector (integers) to binary class matrix.
@@ -32,12 +34,14 @@ def to_categorical(y, num_classes):
     categorical=categorical.astype(np.float32)
     return categorical
 
+
 def preprocess(data):
     data=data.astype(np.float32)
     data /= 255
     data=np.expand_dims(data, axis=1)
     return data
 
+
 def accuracy(pred,target):
     y = np.argmax(pred, axis=1)
     t = np.argmax(target, axis=1)
@@ -47,47 +51,51 @@ def accuracy(pred,target):
 
 if __name__ == '__main__':
 
-    batch_number=600
+    parser = argparse.ArgumentParser(description='Train CNN over MNIST')
+    parser.add_argument('file_path', type=str, help='the dataset path')
+    args = parser.parse_args()
+
+    assert os.path.exists(args.file_path), 'Pls download the MNIST dataset from' \
+     'https://github.com/mnielsen/neural-networks-and-deep-learning/raw/master/data/mnist.pkl.gz'
+
+    train, test = load_data(args.file_path)
+
+    batch_number = 600
     num_classes = 10
     epochs = 1
 
     sgd = optimizer.SGD(0.05)
 
-    train,test=load_data('/Users/wanqixue/Downloads/mnist.npz')
-    x_train=preprocess(train[0])
+    x_train = preprocess(train[0])
     y_train = to_categorical(train[1], num_classes)
 
     x_test=preprocess(test[0])
     y_test=to_categorical(test[1],num_classes)
-    print ('the shape of training data is',x_train.shape)
-    print ('the shape of training label is',y_train.shape)
+    print ('the shape of training data is', x_train.shape)
+    print ('the shape of training label is', y_train.shape)
     print ('the shape of testing data is', x_test.shape)
     print ('the shape of testing label is', y_test.shape)
 
     # operations initialization
-    conv1=autograd.Conv2d(3, 32)
-    conv2=autograd.Conv2d(32, 32)
-
-    w0 = tensor.Tensor(shape=(25088, 10), requires_grad=True, stores_grad=True)
-    w0.gaussian(0.0, 0.1)
-    b0 = tensor.Tensor(shape=(1, 10), requires_grad=True, stores_grad=True)
-    b0.set_value(0.0)
-
-
-    def forward(x,t):
-        y=conv1(x)
-        y=autograd.relu(y)
-        y=conv2(y)
-        y=autograd.relu(y)
-        y=autograd.max_pool_2d(y)
-        y=autograd.flatten(y)
-        y=autograd.dense(y, w0, b0)
-        y=autograd.soft_max(y)
-        loss=autograd.cross_entropy(y, t)
+    conv1 = autograd.Conv2d(3, 32)
+    conv2 = autograd.Conv2d(32, 32)
+    linear = autograd.Linear(32*28*28, 10)
+
+
+    def forward(x, t):
+        y = conv1(x)
+        y = autograd.relu(y)
+        y = conv2(y)
+        y = autograd.relu(y)
+        y = autograd.max_pool_2d(y)
+        y = autograd.flatten(y)
+        y = linear(y)
+        y = autograd.soft_max(y)
+        loss = autograd.cross_entropy(y, t)
         return loss, y
 
     for epoch in range(epochs):
-        for i in range(16):
+        for i in range(batch_number):
             inputs = tensor.Tensor(data=x_train[i * 100:(1 + i) * 100, :], requires_grad=False, stores_grad=False)
             targets = tensor.Tensor(data=y_train[i * 100:(1 + i) * 100, :], requires_grad=False, stores_grad=False)
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/40e609a4/python/singa/autograd.py
----------------------------------------------------------------------
diff --git a/python/singa/autograd.py b/python/singa/autograd.py
index 35211de..daae43c 100644
--- a/python/singa/autograd.py
+++ b/python/singa/autograd.py
@@ -84,8 +84,8 @@ class Matmul(Operation):
             singa.Mult(self.input[0].T(), dy)
 
 
-def matmul(x, w):
-    return Matmul()(x, w)[0]
+def matmul(x, w, flag=True):
+    return Matmul()(x, w, flag)[0]
 
 
 class AddBias(Operation):
@@ -246,27 +246,19 @@ def ctensor2numpy(x):
     np_array = x.GetFloatValue(int(x.Size()))
     return np_array.reshape(x.shape())
 
+
 class Conv2d(Operation):
-    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0, dilation=1, groups=1, bias=True, **kwargs):
-
-        name = 'Conv2d'
-        border_mode = 'same'
-        cudnn_prefer = 'fastest'
-        workspace_byte_limit = 1024
-        data_format = 'NCHW'
-        W_specs ={'init': 'xavier'}
-        b_specs = {'init': 'constant'}
-        input_sample_shape = None
-
-        inner_params = {'name':name,
-                          'border_mode':border_mode,
-                          'cudnn_prefer':cudnn_prefer,
-                          'workspace_byte_limit':workspace_byte_limit,
-                          'data_format':data_format,
-                          'W_specs':W_specs,
-                          'b_specs':b_specs,
-                          'input_sample_shape':input_sample_shape
-                          }
+    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0, dilation=1, groups=1, bias=True,
+                 **kwargs):
+
+        inner_params = {'name': 'Conv2d',
+                          'border_mode': 'same',
+                          'cudnn_prefer': 'fastest',
+                          'workspace_byte_limit': 1024,
+                          'data_format': 'NCHW',
+                          'W_specs': {'init': 'xavier'},
+                          'b_specs': {'init': 'constant'},
+                          'input_sample_shape': None}
         # TODO valid value of inner_params check
 
         for kwarg in kwargs:
@@ -277,8 +269,13 @@ class Conv2d(Operation):
                 
         self.in_channels = in_channels
         self.out_channels = out_channels
-        self.W_specs=inner_params['W_specs']
-        self.b_specs=inner_params['b_specs']
+        self.W_specs = inner_params['W_specs']
+        self.b_specs = inner_params['b_specs']
+
+        if isinstance(kernel_size, int):
+            self.kernel_size = (kernel_size, kernel_size)
+        else:
+            self.kernel_size = kernel_size
 
         if padding == 0:
             pad = None
@@ -294,6 +291,12 @@ class Conv2d(Operation):
                  data_format=inner_params['data_format'], use_bias=bias, W_specs=self.W_specs, b_specs=self.b_specs,
                  pad=pad, input_sample_shape=inner_params['input_sample_shape'])
 
+    def get_params(self):
+        assert self.has_setup, \
+            'Must call setup() before get_params()'
+        params = self.PyLayer.layer.param_values()
+        return params
+
     def __call__(self, x, flag=True):
         assert type(flag) is bool, 'flag can only be bool.'
         if flag:
@@ -305,68 +308,74 @@ class Conv2d(Operation):
             self.PyLayer.setup(x.shape[1:])
 
         param_data = self.PyLayer.layer.param_values()
+
         if not hasattr(self, 'w'):
             self.w = Tensor(device=param_data[0].device, data=param_data[0], requires_grad=True, stores_grad=True)
-            if self.W_specs['init'] == 'xavier':
-                std = math.sqrt(2.0/(self.in_channels+self.out_channels))
-                self.w.gaussian(0.0, std)
-            elif self.W_specs['init'] == 'gaussian':
-                if 'std' not in self.W_specs or 'mean' not in self.W_specs:
-                    self.w.gaussian(0.0, 0.1)
-                else:
-                    self.w.gaussian(self.W_specs['mean'],self.W_specs['std'])
-            elif self.W_specs['init'] == 'uniform':
-                if 'low' not in self.W_specs or 'high' not in self.W_specs:
-                    self.w.uniform(0.0, 0.1)
-                else:
-                    self.w.uniform(self.W_specs['low'],self.W_specs['high'])
+            std = math.sqrt(2.0/(self.in_channels*self.kernel_size[0]*self.kernel_size[1]+self.out_channels))
+            self.w.gaussian(0.0, std)
 
         xs = [x, self.w]
 
         if len(param_data) == 2:
             if not hasattr(self, 'b'):
                 self.b = Tensor(device=param_data[1].device, data=param_data[1], requires_grad=True, stores_grad=True)
-                if self.b_specs['init'] == 'gaussian':
-                    if 'std' not in self.b_specs or 'mean' not in self.b_specs:
-                        self.b.gaussian(0.0, 0.1)
-                    else:
-                        self.b.gaussian(self.b_specs['mean'], self.b_specs['std'])
-                elif self.b_specs['init'] == 'uniform':
-                    if 'low' not in self.b_specs or 'high' not in self.b_specs:
-                        self.b.uniform(0.0, 0.1)
-                    else:
-                        self.b.uniform(self.b_specs['low'], self.b_specs['high'])
-                elif self.b_specs['init'] == 'constant':
-                    self.b.set_value(0.0)
+                self.b.set_value(0.0)
 
             xs.append(self.b)
 
         xs = tuple(xs)
-        return self._do_forward_0(*xs)
-
-    def _do_forward_0(self, *xs):
         return self._do_forward(*xs)[0]
 
     def forward(self, *xs):
         return self.PyLayer.layer.Forward(self.flag, xs[0])
 
     def backward(self, dy):
-        ret = self.PyLayer.layer.Backward(0, dy)
+        ret = self.PyLayer.layer.Backward(self.flag, dy)
         return (ret[0],)+ret[1]
 
 
+class Linear(Operation):
+    def __init__(self, in_features, out_features, bias=True):
+        self.in_features = in_features
+        self.out_features = out_features
+        w_shape = (in_features, out_features)
+        self.w = Tensor(shape=w_shape, requires_grad=True, stores_grad=True)
+        if bias:
+            b_shape = (1, out_features)
+            self.b = Tensor(shape=b_shape, requires_grad=True, stores_grad=True)
+        self.init_value = False
+
+    def get_params(self):
+        if hasattr(self, 'b'):
+            return (self.w, self.b)
+        else:
+            return self.w
+
+    def __call__(self, x, flag=True):
+        assert type(flag) is bool, 'flag can only be bool.'
+        self.flag = flag
+        if self.init_value is False:
+            std = math.sqrt(2.0 / (self.in_features + self.out_features))
+            self.w.gaussian(0.0, std)
+            if hasattr(self, 'b'):
+                self.b.set_value(0.0)
+            self.init_value = True
+        return self._do_forward(x)
+
+    def _do_forward(self, x):
+        y = matmul(x, self.w, self.flag)
+        if hasattr(self, 'b'):
+            y = add_bias(y, self.b, axis=0)
+        return y
+
+
 class MaxPool2d(Operation):
     def __init__(self, kernel_size=3, stride=1, padding=0, dilation=1, return_indices=False, ceil_mode=False, **kwargs):
 
-        name = 'MaxPool2d'
-        border_mode = 'same'
-        data_format = 'NCHW'
-        input_sample_shape = None
-
-        inner_params = {'name': name,
-                          'border_mode': border_mode,
-                          'data_format': data_format,
-                          'input_sample_shape': input_sample_shape
+        inner_params = {'name': 'MaxPool2d',
+                          'border_mode': 'same',
+                          'data_format': 'NCHW',
+                          'input_sample_shape': None
                           }
 
         for kwarg in kwargs:
@@ -405,29 +414,9 @@ class MaxPool2d(Operation):
     def backward(self, dy):
         return self.PyLayer.layer.Backward(0, dy)[0]
 
-def max_pool_2d(x,kernel_size=3, stride=1, padding=0, dilation=1, return_indices=False, ceil_mode=False, **kwargs):
-    return MaxPool2d(kernel_size, stride, padding, dilation, return_indices, ceil_mode, **kwargs)(x)[0]
-
 
-'''class ReLU_Layer(Operation):
-    def __init__(self, name='ReLU', mode='relu',input_sample_shape=None):
-        self.PyLayer = layer.Activation(name, mode, input_sample_shape)
-
-    def __call__(self, x, flag=True):
-        assert type(flag) is bool, 'flag can only be bool.'
-        if flag:
-            self.flag = model_pb2.kTrain
-        else:
-            self.flag = model_pb2.kEval
-        if not self.PyLayer.has_setup:
-            self.PyLayer.setup(x.shape[1:])
-        return self._do_forward(x)
-
-    def forward(self, *xs):
-        return self.PyLayer.layer.Forward(self.flag, xs[0])
-
-    def backward(self, dy):
-        return self.PyLayer.layer.Backward(0, dy)[0]'''
+def max_pool_2d(x, kernel_size=3, stride=1, padding=0, dilation=1, return_indices=False, ceil_mode=False, **kwargs):
+    return MaxPool2d(kernel_size, stride, padding, dilation, return_indices, ceil_mode, **kwargs)(x)[0]
 
 
 class Flatten(Operation):
@@ -450,109 +439,10 @@ class Flatten(Operation):
     def backward(self, dy):
         return self.PyLayer.layer.Backward(0, dy)[0]
 
-def flatten(x, name='Flatten', axis=1, input_sample_shape=None):
-    return Flatten(name,axis,input_sample_shape)(x)[0]
-
-def dense(x, w, b=None, bias=True, axis=0):
-    if bias:
-        if b is None:
-            raise ValueError('must input bias value.')
-        else:
-            y= matmul(x, w)
-            y= add_bias(y, b, axis)
-            return y
-    else:
-        return matmul(x, w)
-
-'''class Linear(Operation):
-    def __init__(self, in_features, out_features, bias=True, **kwargs):
-
-        name = 'Linear'
-        W_transpose=False
-        W_specs = {'init': 'xavier'}
-        b_specs = {'init': 'constant'}
-        input_sample_shape = in_features
-
-        inner_params = {'name': name,
-                          'W_transpose': W_transpose,
-                          'W_specs': W_specs,
-                          'b_specs': b_specs,
-                          'input_sample_shape': input_sample_shape
-                          }
 
-        # TODO valid value of inner_params check
-
-        for kwarg in kwargs:
-            if kwarg not in allowed_kwargs:
-                raise TypeError('Keyword argument not understood:', kwarg)
-            else:
-                inner_params[kwarg] = kwargs[kwarg]
-
-        self.in_features = in_features
-        self.out_features = out_features
-        self.W_specs = W_specs
-        self.b_specs = b_specs
-
-        self.PyLayer = layer.Dense(inner_params['name'], num_output=out_features, use_bias=bias,
-                     W_specs=self.W_specs, b_specs=self.b_specs,
-                     W_transpose=inner_params['W_transpose'], input_sample_shape=inner_params['input_sample_shape'])
-
-    def __call__(self, x, flag=True):
-        assert type(flag) is bool, 'flag can only be bool.'
-        if flag:
-            self.flag = model_pb2.kTrain
-        else:
-            self.flag = model_pb2.kEval
-
-        if not self.PyLayer.has_setup:
-            self.PyLayer.setup(x.shape[1:])
-
-        param_data = self.PyLayer.layer.param_values()
-        if not hasattr(self, 'w'):
-            self.w = Tensor(device=param_data[0].device, data=param_data[0], requires_grad=True, stores_grad=True)
-            if self.W_specs['init'] == 'xavier':
-                std = math.sqrt(2.0/(self.in_channels+self.out_channels))
-                self.w.gaussian(0.0, std)
-            elif self.W_specs['init'] == 'gaussian':
-                if 'std' not in self.W_specs or 'mean' not in self.W_specs:
-                    self.w.gaussian(0.0, 0.1)
-                else:
-                    self.w.gaussian(self.W_specs['mean'],self.W_specs['std'])
-            elif self.W_specs['init'] == 'uniform':
-                if 'low' not in self.W_specs or 'high' not in self.W_specs:
-                    self.w.uniform(0.0, 0.1)
-                else:
-                    self.w.uniform(self.W_specs['low'],self.W_specs['high'])
-
-        xs = [x, self.w]
-
-        if len(param_data) == 2:
-            if not hasattr(self, 'b'):
-                self.b = Tensor(device=param_data[1].device, data=param_data[1], requires_grad=True, stores_grad=True)
-                if self.b_specs['init'] == 'gaussian':
-                    if 'std' not in self.b_specs or 'mean' not in self.b_specs:
-                        self.b.gaussian(0.0, 0.1)
-                    else:
-                        self.b.gaussian(self.b_specs['mean'], self.b_specs['std'])
-                elif self.b_specs['init'] == 'uniform':
-                    if 'low' not in self.b_specs or 'high' not in self.b_specs:
-                        self.b.uniform(0.0, 0.1)
-                    else:
-                        self.b.uniform(self.b_specs['low'], self.b_specs['high'])
-                elif self.b_specs['init'] == 'constant':
-                    self.b.set_value(0.0)
-
-            xs.append(self.b)
-
-        xs = tuple(xs)
-        return self._do_forward(*xs)
-
-    def forward(self, *xs):
-        return self.PyLayer.layer.Forward(self.flag, xs[0])
+def flatten(x, name='Flatten', axis=1, input_sample_shape=None):
+    return Flatten(name, axis, input_sample_shape)(x)[0]
 
-    def backward(self, dy):
-        ret = self.PyLayer.layer.Backward(0, dy)
-        return (ret[0],)+ret[1]'''
 
 def infer_dependency(op):
     '''