Posted to commits@singa.apache.org by mo...@apache.org on 2018/05/18 04:52:14 UTC
[08/14] incubator-singa git commit: SINGA-349 Create layer operations for autograd
SINGA-349 Create layer operations for autograd
1. rewrite Linear operation
2. avoid absolute path
3. modify mnist_cnn example
4. delete unnecessary code
Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/40e609a4
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/40e609a4
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/40e609a4
Branch: refs/heads/master
Commit: 40e609a4e807595d335adaae17966daa8adac04c
Parents: ed464ef
Author: xuewanqi <xu...@u.nus.edu>
Authored: Tue May 15 14:58:41 2018 +0800
Committer: Wang Wei <dc...@nus.edu.sg>
Committed: Thu May 17 21:19:07 2018 +0800
----------------------------------------------------------------------
examples/autograd/mnist_cnn.py | 60 +++++----
python/singa/autograd.py | 260 +++++++++++-------------------------
2 files changed, 109 insertions(+), 211 deletions(-)
----------------------------------------------------------------------
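In short, this commit replaces the hand-built weight/bias tensors and the
autograd.dense() call in the example with a parameter-owning autograd.Linear
operation. A minimal before/after sketch of the usage, paraphrasing the diff
below (shapes taken from the example script; requires the singa package):

    from singa import autograd, tensor

    # Before: parameters created by hand and threaded through dense()
    w0 = tensor.Tensor(shape=(25088, 10), requires_grad=True, stores_grad=True)
    w0.gaussian(0.0, 0.1)
    b0 = tensor.Tensor(shape=(1, 10), requires_grad=True, stores_grad=True)
    b0.set_value(0.0)
    # y = autograd.dense(y, w0, b0)

    # After: the operation owns its parameters and initializes them lazily
    linear = autograd.Linear(32 * 28 * 28, 10)
    # y = linear(y)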
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/40e609a4/examples/autograd/mnist_cnn.py
----------------------------------------------------------------------
diff --git a/examples/autograd/mnist_cnn.py b/examples/autograd/mnist_cnn.py
index bc717c7..7afbb9e 100644
--- a/examples/autograd/mnist_cnn.py
+++ b/examples/autograd/mnist_cnn.py
@@ -1,11 +1,12 @@
import numpy as np
+import argparse
+import os
from singa import tensor
from singa import autograd
from singa import optimizer
-
def load_data(path):
f = np.load(path)
x_train, y_train = f['x_train'], f['y_train']
@@ -13,6 +14,7 @@ def load_data(path):
f.close()
return (x_train, y_train), (x_test, y_test)
+
def to_categorical(y, num_classes):
'''
Converts a class vector (integers) to binary class matrix.
@@ -32,12 +34,14 @@ def to_categorical(y, num_classes):
categorical=categorical.astype(np.float32)
return categorical
+
def preprocess(data):
data=data.astype(np.float32)
data /= 255
data=np.expand_dims(data, axis=1)
return data
+
def accuracy(pred,target):
y = np.argmax(pred, axis=1)
t = np.argmax(target, axis=1)
@@ -47,47 +51,51 @@ def accuracy(pred,target):
if __name__ == '__main__':
- batch_number=600
+ parser = argparse.ArgumentParser(description='Train CNN over MNIST')
+ parser.add_argument('file_path', type=str, help='the dataset path')
+ args = parser.parse_args()
+
+ assert os.path.exists(args.file_path), 'Please download the MNIST dataset from ' \
+ 'https://github.com/mnielsen/neural-networks-and-deep-learning/raw/master/data/mnist.pkl.gz'
+
+ train, test = load_data(args.file_path)
+
+ batch_number = 600
num_classes = 10
epochs = 1
sgd = optimizer.SGD(0.05)
- train,test=load_data('/Users/wanqixue/Downloads/mnist.npz')
- x_train=preprocess(train[0])
+ x_train = preprocess(train[0])
y_train = to_categorical(train[1], num_classes)
x_test=preprocess(test[0])
y_test=to_categorical(test[1],num_classes)
- print ('the shape of training data is',x_train.shape)
- print ('the shape of training label is',y_train.shape)
+ print ('the shape of training data is', x_train.shape)
+ print ('the shape of training label is', y_train.shape)
print ('the shape of testing data is', x_test.shape)
print ('the shape of testing label is', y_test.shape)
# operations initialization
- conv1=autograd.Conv2d(3, 32)
- conv2=autograd.Conv2d(32, 32)
-
- w0 = tensor.Tensor(shape=(25088, 10), requires_grad=True, stores_grad=True)
- w0.gaussian(0.0, 0.1)
- b0 = tensor.Tensor(shape=(1, 10), requires_grad=True, stores_grad=True)
- b0.set_value(0.0)
-
-
- def forward(x,t):
- y=conv1(x)
- y=autograd.relu(y)
- y=conv2(y)
- y=autograd.relu(y)
- y=autograd.max_pool_2d(y)
- y=autograd.flatten(y)
- y=autograd.dense(y, w0, b0)
- y=autograd.soft_max(y)
- loss=autograd.cross_entropy(y, t)
+ conv1 = autograd.Conv2d(3, 32)
+ conv2 = autograd.Conv2d(32, 32)
+ linear = autograd.Linear(32*28*28, 10)
+
+
+ def forward(x, t):
+ y = conv1(x)
+ y = autograd.relu(y)
+ y = conv2(y)
+ y = autograd.relu(y)
+ y = autograd.max_pool_2d(y)
+ y = autograd.flatten(y)
+ y = linear(y)
+ y = autograd.soft_max(y)
+ loss = autograd.cross_entropy(y, t)
return loss, y
for epoch in range(epochs):
- for i in range(16):
+ for i in range(batch_number):
inputs = tensor.Tensor(data=x_train[i * 100:(1 + i) * 100, :], requires_grad=False, stores_grad=False)
targets = tensor.Tensor(data=y_train[i * 100:(1 + i) * 100, :], requires_grad=False, stores_grad=False)
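The in_features value 32*28*28 passed to Linear matches the shape of the old
hand-built weight w0 (25088, 10): with the defaults this commit sets for
Conv2d and MaxPool2d (border_mode 'same', stride 1), the 28x28 MNIST spatial
size is preserved through both conv layers and the pooling, so flatten()
yields 32*28*28 features per image. A quick sanity check, assuming those
defaults:

    # 32 channels out of conv2, 28x28 spatial size kept by 'same' border mode
    channels, height, width = 32, 28, 28
    assert channels * height * width == 25088  # old w0 shape was (25088, 10)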
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/40e609a4/python/singa/autograd.py
----------------------------------------------------------------------
diff --git a/python/singa/autograd.py b/python/singa/autograd.py
index 35211de..daae43c 100644
--- a/python/singa/autograd.py
+++ b/python/singa/autograd.py
@@ -84,8 +84,8 @@ class Matmul(Operation):
singa.Mult(self.input[0].T(), dy)
-def matmul(x, w):
- return Matmul()(x, w)[0]
+def matmul(x, w, flag=True):
+ return Matmul()(x, w, flag)[0]
class AddBias(Operation):
@@ -246,27 +246,19 @@ def ctensor2numpy(x):
np_array = x.GetFloatValue(int(x.Size()))
return np_array.reshape(x.shape())
+
class Conv2d(Operation):
- def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0, dilation=1, groups=1, bias=True, **kwargs):
-
- name = 'Conv2d'
- border_mode = 'same'
- cudnn_prefer = 'fastest'
- workspace_byte_limit = 1024
- data_format = 'NCHW'
- W_specs ={'init': 'xavier'}
- b_specs = {'init': 'constant'}
- input_sample_shape = None
-
- inner_params = {'name':name,
- 'border_mode':border_mode,
- 'cudnn_prefer':cudnn_prefer,
- 'workspace_byte_limit':workspace_byte_limit,
- 'data_format':data_format,
- 'W_specs':W_specs,
- 'b_specs':b_specs,
- 'input_sample_shape':input_sample_shape
- }
+ def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0, dilation=1, groups=1, bias=True,
+ **kwargs):
+
+ inner_params = {'name': 'Conv2d',
+ 'border_mode': 'same',
+ 'cudnn_prefer': 'fastest',
+ 'workspace_byte_limit': 1024,
+ 'data_format': 'NCHW',
+ 'W_specs': {'init': 'xavier'},
+ 'b_specs': {'init': 'constant'},
+ 'input_sample_shape': None}
# TODO valid value of inner_params check
for kwarg in kwargs:
@@ -277,8 +269,13 @@ class Conv2d(Operation):
self.in_channels = in_channels
self.out_channels = out_channels
- self.W_specs=inner_params['W_specs']
- self.b_specs=inner_params['b_specs']
+ self.W_specs = inner_params['W_specs']
+ self.b_specs = inner_params['b_specs']
+
+ if isinstance(kernel_size, int):
+ self.kernel_size = (kernel_size, kernel_size)
+ else:
+ self.kernel_size = kernel_size
if padding == 0:
pad = None
@@ -294,6 +291,12 @@ class Conv2d(Operation):
data_format=inner_params['data_format'], use_bias=bias, W_specs=self.W_specs, b_specs=self.b_specs,
pad=pad, input_sample_shape=inner_params['input_sample_shape'])
+ def get_params(self):
+ assert self.has_setup, \
+ 'Must call setup() before get_params()'
+ params = self.PyLayer.layer.param_values()
+ return params
+
def __call__(self, x, flag=True):
assert type(flag) is bool, 'flag can only be bool.'
if flag:
@@ -305,68 +308,74 @@ class Conv2d(Operation):
self.PyLayer.setup(x.shape[1:])
param_data = self.PyLayer.layer.param_values()
+
if not hasattr(self, 'w'):
self.w = Tensor(device=param_data[0].device, data=param_data[0], requires_grad=True, stores_grad=True)
- if self.W_specs['init'] == 'xavier':
- std = math.sqrt(2.0/(self.in_channels+self.out_channels))
- self.w.gaussian(0.0, std)
- elif self.W_specs['init'] == 'gaussian':
- if 'std' not in self.W_specs or 'mean' not in self.W_specs:
- self.w.gaussian(0.0, 0.1)
- else:
- self.w.gaussian(self.W_specs['mean'],self.W_specs['std'])
- elif self.W_specs['init'] == 'uniform':
- if 'low' not in self.W_specs or 'high' not in self.W_specs:
- self.w.uniform(0.0, 0.1)
- else:
- self.w.uniform(self.W_specs['low'],self.W_specs['high'])
+ std = math.sqrt(2.0/(self.in_channels*self.kernel_size[0]*self.kernel_size[1]+self.out_channels))
+ self.w.gaussian(0.0, std)
xs = [x, self.w]
if len(param_data) == 2:
if not hasattr(self, 'b'):
self.b = Tensor(device=param_data[1].device, data=param_data[1], requires_grad=True, stores_grad=True)
- if self.b_specs['init'] == 'gaussian':
- if 'std' not in self.b_specs or 'mean' not in self.b_specs:
- self.b.gaussian(0.0, 0.1)
- else:
- self.b.gaussian(self.b_specs['mean'], self.b_specs['std'])
- elif self.b_specs['init'] == 'uniform':
- if 'low' not in self.b_specs or 'high' not in self.b_specs:
- self.b.uniform(0.0, 0.1)
- else:
- self.b.uniform(self.b_specs['low'], self.b_specs['high'])
- elif self.b_specs['init'] == 'constant':
- self.b.set_value(0.0)
+ self.b.set_value(0.0)
xs.append(self.b)
xs = tuple(xs)
- return self._do_forward_0(*xs)
-
- def _do_forward_0(self, *xs):
return self._do_forward(*xs)[0]
def forward(self, *xs):
return self.PyLayer.layer.Forward(self.flag, xs[0])
def backward(self, dy):
- ret = self.PyLayer.layer.Backward(0, dy)
+ ret = self.PyLayer.layer.Backward(self.flag, dy)
return (ret[0],)+ret[1]
+class Linear(Operation):
+ def __init__(self, in_features, out_features, bias=True):
+ self.in_features = in_features
+ self.out_features = out_features
+ w_shape = (in_features, out_features)
+ self.w = Tensor(shape=w_shape, requires_grad=True, stores_grad=True)
+ if bias:
+ b_shape = (1, out_features)
+ self.b = Tensor(shape=b_shape, requires_grad=True, stores_grad=True)
+ self.init_value = False
+
+ def get_params(self):
+ if hasattr(self, 'b'):
+ return (self.w, self.b)
+ else:
+ return self.w
+
+ def __call__(self, x, flag=True):
+ assert type(flag) is bool, 'flag can only be bool.'
+ self.flag = flag
+ if self.init_value is False:
+ std = math.sqrt(2.0 / (self.in_features + self.out_features))
+ self.w.gaussian(0.0, std)
+ if hasattr(self, 'b'):
+ self.b.set_value(0.0)
+ self.init_value = True
+ return self._do_forward(x)
+
+ def _do_forward(self, x):
+ y = matmul(x, self.w, self.flag)
+ if hasattr(self, 'b'):
+ y = add_bias(y, self.b, axis=0)
+ return y
+
+
class MaxPool2d(Operation):
def __init__(self, kernel_size=3, stride=1, padding=0, dilation=1, return_indices=False, ceil_mode=False, **kwargs):
- name = 'MaxPool2d'
- border_mode = 'same'
- data_format = 'NCHW'
- input_sample_shape = None
-
- inner_params = {'name': name,
- 'border_mode': border_mode,
- 'data_format': data_format,
- 'input_sample_shape': input_sample_shape
+ inner_params = {'name': 'MaxPool2d',
+ 'border_mode': 'same',
+ 'data_format': 'NCHW',
+ 'input_sample_shape': None
}
for kwarg in kwargs:
@@ -405,29 +414,9 @@ class MaxPool2d(Operation):
def backward(self, dy):
return self.PyLayer.layer.Backward(0, dy)[0]
-def max_pool_2d(x,kernel_size=3, stride=1, padding=0, dilation=1, return_indices=False, ceil_mode=False, **kwargs):
- return MaxPool2d(kernel_size, stride, padding, dilation, return_indices, ceil_mode, **kwargs)(x)[0]
-
-'''class ReLU_Layer(Operation):
- def __init__(self, name='ReLU', mode='relu',input_sample_shape=None):
- self.PyLayer = layer.Activation(name, mode, input_sample_shape)
-
- def __call__(self, x, flag=True):
- assert type(flag) is bool, 'flag can only be bool.'
- if flag:
- self.flag = model_pb2.kTrain
- else:
- self.flag = model_pb2.kEval
- if not self.PyLayer.has_setup:
- self.PyLayer.setup(x.shape[1:])
- return self._do_forward(x)
-
- def forward(self, *xs):
- return self.PyLayer.layer.Forward(self.flag, xs[0])
-
- def backward(self, dy):
- return self.PyLayer.layer.Backward(0, dy)[0]'''
+def max_pool_2d(x, kernel_size=3, stride=1, padding=0, dilation=1, return_indices=False, ceil_mode=False, **kwargs):
+ return MaxPool2d(kernel_size, stride, padding, dilation, return_indices, ceil_mode, **kwargs)(x)[0]
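The Conv2d rewrite above also drops the per-spec initializers in favour of a
single Glorot-style gaussian draw whose fan-in counts the kernel area:
std = sqrt(2 / (in_channels * k_h * k_w + out_channels)). A quick numeric
check for the example's conv2 layer (32 -> 32 channels, 3x3 kernel), following
the formula in the diff:

    import math

    # fan-in = 32 * 3 * 3 = 288, fan-out = 32
    std = math.sqrt(2.0 / (32 * 3 * 3 + 32))
    print(round(std, 4))  # 0.0791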
class Flatten(Operation):
@@ -450,109 +439,10 @@ class Flatten(Operation):
def backward(self, dy):
return self.PyLayer.layer.Backward(0, dy)[0]
-def flatten(x, name='Flatten', axis=1, input_sample_shape=None):
- return Flatten(name,axis,input_sample_shape)(x)[0]
-
-def dense(x, w, b=None, bias=True, axis=0):
- if bias:
- if b is None:
- raise ValueError('must input bias value.')
- else:
- y= matmul(x, w)
- y= add_bias(y, b, axis)
- return y
- else:
- return matmul(x, w)
-
-'''class Linear(Operation):
- def __init__(self, in_features, out_features, bias=True, **kwargs):
-
- name = 'Linear'
- W_transpose=False
- W_specs = {'init': 'xavier'}
- b_specs = {'init': 'constant'}
- input_sample_shape = in_features
-
- inner_params = {'name': name,
- 'W_transpose': W_transpose,
- 'W_specs': W_specs,
- 'b_specs': b_specs,
- 'input_sample_shape': input_sample_shape
- }
- # TODO valid value of inner_params check
-
- for kwarg in kwargs:
- if kwarg not in allowed_kwargs:
- raise TypeError('Keyword argument not understood:', kwarg)
- else:
- inner_params[kwarg] = kwargs[kwarg]
-
- self.in_features = in_features
- self.out_features = out_features
- self.W_specs = W_specs
- self.b_specs = b_specs
-
- self.PyLayer = layer.Dense(inner_params['name'], num_output=out_features, use_bias=bias,
- W_specs=self.W_specs, b_specs=self.b_specs,
- W_transpose=inner_params['W_transpose'], input_sample_shape=inner_params['input_sample_shape'])
-
- def __call__(self, x, flag=True):
- assert type(flag) is bool, 'flag can only be bool.'
- if flag:
- self.flag = model_pb2.kTrain
- else:
- self.flag = model_pb2.kEval
-
- if not self.PyLayer.has_setup:
- self.PyLayer.setup(x.shape[1:])
-
- param_data = self.PyLayer.layer.param_values()
- if not hasattr(self, 'w'):
- self.w = Tensor(device=param_data[0].device, data=param_data[0], requires_grad=True, stores_grad=True)
- if self.W_specs['init'] == 'xavier':
- std = math.sqrt(2.0/(self.in_channels+self.out_channels))
- self.w.gaussian(0.0, std)
- elif self.W_specs['init'] == 'gaussian':
- if 'std' not in self.W_specs or 'mean' not in self.W_specs:
- self.w.gaussian(0.0, 0.1)
- else:
- self.w.gaussian(self.W_specs['mean'],self.W_specs['std'])
- elif self.W_specs['init'] == 'uniform':
- if 'low' not in self.W_specs or 'high' not in self.W_specs:
- self.w.uniform(0.0, 0.1)
- else:
- self.w.uniform(self.W_specs['low'],self.W_specs['high'])
-
- xs = [x, self.w]
-
- if len(param_data) == 2:
- if not hasattr(self, 'b'):
- self.b = Tensor(device=param_data[1].device, data=param_data[1], requires_grad=True, stores_grad=True)
- if self.b_specs['init'] == 'gaussian':
- if 'std' not in self.b_specs or 'mean' not in self.b_specs:
- self.b.gaussian(0.0, 0.1)
- else:
- self.b.gaussian(self.b_specs['mean'], self.b_specs['std'])
- elif self.b_specs['init'] == 'uniform':
- if 'low' not in self.b_specs or 'high' not in self.b_specs:
- self.b.uniform(0.0, 0.1)
- else:
- self.b.uniform(self.b_specs['low'], self.b_specs['high'])
- elif self.b_specs['init'] == 'constant':
- self.b.set_value(0.0)
-
- xs.append(self.b)
-
- xs = tuple(xs)
- return self._do_forward(*xs)
-
- def forward(self, *xs):
- return self.PyLayer.layer.Forward(self.flag, xs[0])
+def flatten(x, name='Flatten', axis=1, input_sample_shape=None):
+ return Flatten(name, axis, input_sample_shape)(x)[0]
- def backward(self, dy):
- ret = self.PyLayer.layer.Backward(0, dy)
- return (ret[0],)+ret[1]'''
def infer_dependency(op):
'''