Posted to commits@singa.apache.org by wa...@apache.org on 2016/08/11 03:18:35 UTC

[1/2] incubator-singa git commit: SINGA-235 - Unify the engines for cudnn and singa layers

Repository: incubator-singa
Updated Branches:
  refs/heads/dev 53639b7ce -> 05720c216


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/05720c21/src/python/singa/layer.py
----------------------------------------------------------------------
diff --git a/src/python/singa/layer.py b/src/python/singa/layer.py
index a87eb10..c8c8c05 100644
--- a/src/python/singa/layer.py
+++ b/src/python/singa/layer.py
@@ -22,6 +22,12 @@ from . import singa_wrap
 from .proto import model_pb2
 import tensor
 
+# engine could be 'cudnn' or 'singa'; it is used to create layers.
+# e.g., CudnnConvolution layer is identified by 'cudnn_convolution'
+# Convolution layer is identified by 'singa_convolution'
+# engine is case insensitive
+engine = 'cudnn'
+
 
 class Layer(object):
     """Base Python layer class.
@@ -78,12 +84,31 @@ class Layer(object):
         return tensor.from_raw_tensors(self.layer.param_values())
 
     def forward(self, flag, input):
+        '''Forward propagate through this layer.
+
+        Args:
+            flag, kTrain or kEval
+            input, an input tensor
+
+        Return:
+            a tensor for the transformed feature
+        '''
         assert self.has_setup, 'Must call setup() before forward()'
         assert isinstance(input, tensor.Tensor), 'input must be py Tensor'
         y = self.layer.Forward(flag, input.singa_tensor)
         return tensor.from_raw_tensor(y)
 
     def backward(self, flag, grad):
+        '''Backward propagate through this layer.
+
+        Args:
+            flag, for future use.
+            grad, gradient of the returned values of the forward function.
+
+        Return:
+            <dx, <dp1, dp2..>>, dx is the gradient of the input of the
+            forward function, dpi is the gradient of the i-th parameter
+        '''
         assert isinstance(grad, tensor.Tensor), 'grad must be py Tensor'
         ret = self.layer.Backward(flag, grad.singa_tensor)
         return tensor.from_raw_tensor(ret[0]), tensor.from_raw_tensors(ret[1])
@@ -104,7 +129,7 @@ class Layer(object):
 class Conv2D(Layer):
 
     def __init__(self, name, nb_kernels, kernel=3, stride=1, border_mode='same',
-                 engine='cudnn', cudnn_prefer='fatest', data_format='NCHW',
+                 cudnn_prefer='fatest', data_format='NCHW',
                  use_bias=True, W_specs=None, b_specs=None,
                  pad=None, input_sample_shape=None):
         """Construct a layer for 2D convolution.
@@ -117,8 +142,6 @@ class Conv2D(Layer):
                 'valid' -> padding is 0 for height and width
                 'same' -> padding is half of the kernel (floor),
                     the kernel must be odd number.
-            engine (string): implementation engin, could be 'cudnn'
-                (case insensitive)
             cudnn_prefer (string): the preferred algorithm for cudnn convolution
                 which could be 'fatest', 'autotune', 'limited_workspace' and
                 'no_workspace'
@@ -165,7 +188,7 @@ class Conv2D(Layer):
         self.conf.param.extend([bspecs])
         self.param_specs.append(bspecs)
 
-        _check_engine(engine, ['cudnn'])
+        _check_engine(engine, ['cudnn', 'singa'])
         self.layer = _create_layer(engine, 'Convolution')
         if input_sample_shape is not None:
             self.setup(input_sample_shape)
@@ -174,7 +197,7 @@ class Conv2D(Layer):
 class Conv1D(Conv2D):
 
     def __init__(self, name, nb_kernels, kernel=3, stride=1,
-                 border_mode='same', engine='cudnn', cudnn_prefer='fatest',
+                 border_mode='same', cudnn_prefer='fatest',
                  use_bias=True, W_specs={'init': 'Xavier'},
                  b_specs={'init': 'Constant', 'value': 0}, pad=None,
                  input_sample_shape=None):
@@ -191,7 +214,7 @@ class Conv1D(Conv2D):
         if input_sample_shape is not None:
             input_sample_shape = (1, 1, input_sample_shape[0])
         super(Conv1D, self).__init__(name, nb_kernels, (1, kernel), (0, stride),
-                                     border_mode, engine, cudnn_prefer,
+                                     border_mode, cudnn_prefer,
                                      use_bias=use_bias, pad=pad,
                                      W_specs=W_specs, b_specs=b_specs,
                                      input_sample_shape=input_sample_shape)
@@ -206,15 +229,14 @@ class Conv1D(Conv2D):
 class Pooling2D(Layer):
 
     def __init__(self, name, mode, kernel=3, stride=2, border_mode='same',
-                 pad=None, data_format='NCHW', engine='cudnn',
-                 input_sample_shape=None):
+                 pad=None, data_format='NCHW', input_sample_shape=None):
         super(Pooling2D, self).__init__(name)
         assert data_format == 'NCHW', 'Not supported data format: %s ' \
             'only "NCHW" is enabled currently' % (data_format)
         conf = self.conf.pooling_conf
         conf = _set_kernel_stride_pad(conf, kernel, stride, border_mode, pad)
         conf.pool = mode
-        _check_engine(engine, ['cudnn'])
+        _check_engine(engine, ['cudnn', 'singa'])
         self.layer = _create_layer(engine, 'Pooling')
         if input_sample_shape is not None:
             self.setup(input_sample_shape)
@@ -223,27 +245,25 @@ class Pooling2D(Layer):
 class MaxPooling2D(Pooling2D):
 
     def __init__(self, name, kernel=3, stride=2, border_mode='same', pad=None,
-                 data_format='NCHW', engine='cudnn', input_sample_shape=None):
+                 data_format='NCHW', input_sample_shape=None):
         super(MaxPooling2D, self).__init__(name, model_pb2.PoolingConf.MAX,
                                            kernel, stride, border_mode,
-                                           pad, data_format, engine,
-                                           input_sample_shape)
+                                           pad, data_format, input_sample_shape)
 
 
 class AvgPooling2D(Pooling2D):
 
     def __init__(self, name, kernel=3, stride=2, border_mode='same', pad=None,
-                 data_format='NCHW', engine='cudnn', input_sample_shape=None):
+                 data_format='NCHW', input_sample_shape=None):
         super(AvgPooling2D, self).__init__(name, model_pb2.PoolingConf.AVE,
                                            kernel, stride, border_mode,
-                                           pad, data_format, engine,
-                                           input_sample_shape)
+                                           pad, data_format, input_sample_shape)
 
 
 class MaxPooling1D(MaxPooling2D):
 
     def __init__(self, name, kernel=3, stride=2, border_mode='same', pad=None,
-                 data_format='NCHW', engine='cudnn', input_sample_shape=None):
+                 data_format='NCHW', input_sample_shape=None):
         """Max pooling for 1D feature.
 
         Args:
@@ -260,8 +280,7 @@ class MaxPooling1D(MaxPooling2D):
             input_sample_shape = None
         super(MaxPooling1D, self).__init__(name, (1, kernel), (0, stride),
                                            border_mode, pad,
-                                           data_format, engine,
-                                           input_sample_shape)
+                                           data_format, input_sample_shape)
 
     def get_output_sample_shape(self):
         shape = self.layer.GetOutputSampleShape()
@@ -271,7 +290,7 @@ class MaxPooling1D(MaxPooling2D):
 class AvgPooling1D(AvgPooling2D):
 
     def __init__(self, name, kernel=3, stride=2, border_mode='same', pad=None,
-                 data_format='NCHW', engine='cudnn', input_sample_shape=None):
+                 data_format='NCHW', input_sample_shape=None):
         """input_feature_length is a scalar value"""
         pad2 = None
         if pad is not None:
@@ -285,8 +304,7 @@ class AvgPooling1D(AvgPooling2D):
 
         super(AvgPooling1D, self).__init__(name, (kernel, 1), (0, stride),
                                            border_mode, pad2,
-                                           data_format, engine,
-                                           input_sample_shape)
+                                           data_format, input_sample_shape)
 
     def get_output_sample_shape(self):
         shape = self.layer.GetOutputSampleShape()
@@ -296,7 +314,7 @@ class AvgPooling1D(AvgPooling2D):
 class BatchNormalization(Layer):
     # TODO(wangwei) add mode and epsilon arguments
 
-    def __init__(self, name, momentum=0.9, engine='cudnn',
+    def __init__(self, name, momentum=0.9,
                  beta_specs=None, gamma_specs=None, input_sample_shape=None):
         """Batch-normalization.
 
@@ -337,16 +355,15 @@ class BatchNormalization(Layer):
         self.param_specs.append(_construct_param_specs_from_dict(beta_specs))
         self.param_specs.append(_construct_param_specs_from_dict(mean_specs))
         self.param_specs.append(_construct_param_specs_from_dict(var_specs))
-        _check_engine(engine, ['cudnn'])
+        _check_engine(engine, ['cudnn', 'singa'])
         self.layer = _create_layer(engine, 'BatchNorm')
         if input_sample_shape is not None:
             self.setup(input_sample_shape)
 
 
 class LRN(Layer):
-
     def __init__(self, name, size=5, alpha=1, beta=0.75, mode='cross_channel',
-                 k=1, engine='cudnn', input_sample_shape=None):
+                 k=1, input_sample_shape=None):
         """Local response normalization.
 
         Args:
@@ -364,7 +381,7 @@ class LRN(Layer):
         # TODO(wangwei) enable mode = 'within_channel'
         assert mode == 'cross_channel', 'only support mode="across_channel"'
         conf.norm_region = model_pb2.LRNConf.ACROSS_CHANNELS
-        _check_engine(engine, ['cudnn'])
+        _check_engine(engine, ['cudnn', 'singa'])
         self.layer = _create_layer(engine, 'LRN')
         if input_sample_shape is not None:
             self.setup(input_sample_shape)
@@ -374,7 +391,7 @@ class Dense(Layer):
 
     def __init__(self, name, num_output, use_bias=True,
                  W_specs=None, b_specs=None,
-                 W_transpose=True, engine='cuda', input_sample_shape=None):
+                 W_transpose=True, input_sample_shape=None):
         """Apply linear/affine transformation, also called inner-product or
         fully connected layer.
 
@@ -392,7 +409,6 @@ class Dense(Layer):
                 'regularizer' for regularization, currently support 'l2'
             b_specs (dict): specs for the bias vector, same fields as W_specs.
             W_transpose (bool): if true, output=x*W.T+b;
-            engine (string): could be 'cudnn', 'cuda'
             input_sample_shape (tuple): input feature length
         """
         super(Dense, self).__init__(name)
@@ -412,22 +428,19 @@ class Dense(Layer):
         self.param_specs.append(_construct_param_specs_from_dict(W_specs))
         self.conf.param.extend([_construct_param_specs_from_dict(b_specs)])
         self.param_specs.append(_construct_param_specs_from_dict(b_specs))
-        if engine == 'cudnn':
-            engine = 'cuda'
-        _check_engine(engine, ['cuda', 'cpp'])
-        self.layer = _create_layer(engine, 'Dense')
+        # dense layer is transparent to engine.
+        self.layer = _create_layer('singa', 'Dense')
         if input_sample_shape is not None:
             self.setup(input_sample_shape)
 
 
 class Dropout(Layer):
 
-    def __init__(self, name, p=0.5, engine='cuda', input_sample_shape=None):
+    def __init__(self, name, p=0.5, input_sample_shape=None):
         """Droput layer.
 
         Args:
             p (float): probability for dropping out the element, i.e., set to 0
-            engine (string): 'cudnn' for cudnn version>=5; or 'cuda'
             name (string): layer name
         """
         super(Dropout, self).__init__(name)
@@ -436,7 +449,7 @@ class Dropout(Layer):
         # 'cudnn' works for v>=5.0
         #  if engine.lower() == 'cudnn':
         #      engine = 'cuda'
-        _check_engine(engine, ['cudnn', 'cuda', 'cpp'])
+        _check_engine(engine, ['cudnn', 'singa'])
         self.layer = _create_layer(engine, 'Dropout')
         if input_sample_shape is not None:
             self.setup(input_sample_shape)
@@ -444,28 +457,25 @@ class Dropout(Layer):
 
 class Activation(Layer):
 
-    def __init__(self, name, mode='relu', engine='cudnn',
-                 input_sample_shape=None):
+    def __init__(self, name, mode='relu', input_sample_shape=None):
         """Activation layers.
 
         Args:
-            engine (string): 'cudnn'
             name (string): layer name
             mode (string): 'relu', 'sigmoid', or 'tanh'
             input_sample_shape (tuple): shape of a single sample
         """
         super(Activation, self).__init__(name)
-        _check_engine(engine, ['cudnn', 'cuda', 'cpp'])
-        mode_dict = {'relu': 'RELU', 'sigmoid': 'SIGMOID', 'tanh': 'TANH'}
-        self.conf.type = mode_dict[mode.lower()]
-        self.layer = _create_layer(engine, 'Activation')
+        self.conf.type = (engine + '_' + mode).lower()
+        _check_engine(engine, ['cudnn', 'singa'])
+        self.layer = _create_layer(engine, mode)
         if input_sample_shape is not None:
             self.setup(input_sample_shape)
 
 
 class Softmax(Layer):
 
-    def __init__(self, name, axis=1, engine='cudnn', input_sample_shape=None):
+    def __init__(self, name, axis=1, input_sample_shape=None):
         """Apply softmax.
 
         Args:
@@ -476,7 +486,7 @@ class Softmax(Layer):
         super(Softmax, self).__init__(name)
         # conf = self.conf.softmax_conf
         # conf.axis = axis
-        _check_engine(engine, ['cudnn', 'cuda', 'cpp'])
+        _check_engine(engine, ['cudnn', 'singa'])
         self.layer = _create_layer(engine, 'Softmax')
         if input_sample_shape is not None:
             self.setup(input_sample_shape)
@@ -484,7 +494,7 @@ class Softmax(Layer):
 
 class Flatten(Layer):
 
-    def __init__(self, name, axis=1, engine='cudnn', input_sample_shape=None):
+    def __init__(self, name, axis=1, input_sample_shape=None):
         """Reshape the input tensor into a matrix.
         Args:
             axis (int): reshape the input as a matrix with the dimension
@@ -494,24 +504,39 @@ class Flatten(Layer):
         super(Flatten, self).__init__(name)
         conf = self.conf.flatten_conf
         conf.axis = axis
-        _check_engine(engine, ['cudnn', 'cuda', 'cpp'])
-        if engine == 'cudnn':
-            engine = 'cuda'
-        self.layer = _create_layer(engine, 'Flatten')
+        # flatten layer is transparent to the engine
+        self.layer = _create_layer('singa', 'Flatten')
         if input_sample_shape is not None:
             self.setup(input_sample_shape)
 
 
 class RNN(Layer):
-    def __init__(self, name, hidden_size, rnn_mode='lstm', engine='cudnn',
-            dropout=0.0, num_stacks=1, input_mode='linear', bidirectional=False,
-            param_specs=None, input_sample_shape=None):
+    def __init__(self, name, hidden_size, rnn_mode='lstm', dropout=0.0,
+                 num_stacks=1, input_mode='linear', bidirectional=False,
+                 param_specs=None, input_sample_shape=None):
+        '''Wrapper for singa::RNN class.
+
+        Args:
+            hidden_size, hidden feature size, the same for all stacks of layers.
+            rnn_mode, decides the rnn unit, which could be one of 'lstm', 'gru',
+                'tanh' and 'relu'; refer to the cudnn manual for each mode.
+            num_stacks, number of stacked rnn layers. It is different from
+                the unrolled sequence length.
+            input_mode, 'linear' converts the input feature x by a linear
+                transformation into a feature vector of size hidden_size;
+                'skip' does nothing but requires the input feature size to
+                equal hidden_size
+            bidirectional, True for bidirectional RNN
+            param_specs, config for initializing the RNN parameters.
+            input_sample_shape, includes a single integer for the input sample
+                feature size.
+        '''
         super(RNN, self).__init__(name)
         conf = self.conf.rnn_conf
         assert hidden_size > 0, 'Hidden feature size must > 0'
         conf.hidden_size = hidden_size
-        assert rnn_mode in Set(['lstm', 'gru', 'tanh', 'relu']), \
-                'rnn mode %s is not available' %s (rnn_mode)
+        assert rnn_mode in Set(['lstm', 'gru', 'tanh', 'relu']),  \
+            'rnn mode %s is not available' % (rnn_mode)
         conf.rnn_mode = rnn_mode
         conf.num_stacks = num_stacks
         conf.dropout = dropout
@@ -519,10 +544,11 @@ class RNN(Layer):
         conf.direction = 'unidirectional'
         if bidirectional:
             conf.direction = 'bidirectional'
+        # currently only has rnn layer implemented using cudnn
         _check_engine(engine, ['cudnn'])
         if param_specs is None:
             param_specs = {'name': name + '-weight',
-                    'init': 'uniform', 'low':0, 'high':1};
+                           'init': 'uniform', 'low': 0, 'high': 1}
         self.conf.param.extend([_construct_param_specs_from_dict(param_specs)])
         self.param_specs.append(_construct_param_specs_from_dict(param_specs))
 
@@ -531,18 +557,59 @@ class RNN(Layer):
             self.setup(input_sample_shape)
 
     def forward(self, flag, inputs):
+        '''Forward inputs through the RNN.
+
+        Args:
+            flag, kTrain or kEval.
+            inputs, <x1, x2,...xn, hx, cx>, where xi is the input tensor for the
+                i-th position, its shape is (batch_size, input_feature_length);
+                the batch_size of xi must >= that of xi+1; hx is the initial
+                hidden state of shape (num_stacks * bidirection?2:1, batch_size,
+                hidden_size). cx is the initial cell state tensor of the same
+                shape as hx. cx is valid only for lstm. For other RNNs there is
+                no cx. Both hx and cx could be dummy tensors without shape and
+                data.
+
+        Returns:
+            <y1, y2, ... yn, hy, cy>, where yi is the output tensor for the i-th
+                position, its shape is (batch_size,
+                hidden_size * bidirection?2:1). hy is the final hidden state
+                tensor. cy is the final cell state tensor. cy is only used for
+                lstm.
+        '''
         assert self.has_setup, 'Must call setup() before forward()'
         assert len(inputs) > 1, 'The input to RNN must include at '\
-                'least one input tensor '\
-                'and one hidden state tensor (could be a dummy tensor)'
+            'least one input tensor '\
+            'and one hidden state tensor (could be a dummy tensor)'
         tensors = []
         for t in inputs:
-            assert isinstance(t, tensor.Tensor), 'input must be py Tensor %s' % (type(t))
+            assert isinstance(t, tensor.Tensor), \
+                'input must be py Tensor %s' % (type(t))
             tensors.append(t.singa_tensor)
         y = self.layer.Forward(flag, tensors)
         return tensor.from_raw_tensors(y)
 
     def backward(self, flag, grad):
+        '''Backward gradients through the RNN.
+
+        Args:
+            flag, for future use.
+            grad, <dy1, dy2,...dyn, dhy, dcy>, where dyi is the gradient for the
+            i-th output, its shape is (batch_size, hidden_size*bidirection?2:1);
+                dhy is the gradient for the final hidden state, its shape is
+                (num_stacks * bidirection?2:1, batch_size,
+                hidden_size). dcy is the gradient for the final cell state.
+                dcy is valid only for lstm. For other RNNs there is
+                no dcy. Both dhy and dcy could be dummy tensors without shape and
+                data.
+
+        Returns:
+            <dx1, dx2, ... dxn, dhx, dcx>, where dxi is the gradient tensor for
+            the i-th input, its shape is (batch_size,
+                input_feature_length). dhx is the gradient for the initial
+                hidden state. dcx is the gradient for the initial cell state,
+                which is valid only for lstm.
+        '''
         tensors = []
         for t in grad:
             assert isinstance(t, tensor.Tensor), 'grad must be py Tensor'
@@ -550,21 +617,23 @@ class RNN(Layer):
         ret = self.layer.Backward(flag, tensors)
         return tensor.from_raw_tensors(ret[0]), tensor.from_raw_tensors(ret[1])
 
+
 class LSTM(RNN):
-    def __init__(self, name, hidden_size, engine='cudnn',
-            dropout=0.0, num_stacks=1, input_mode='linear', bidirectional=False,
-            param_specs=None, input_sample_shape=None):
-        super(LSTM, self).__init__(name, hidden_size,  'lstm', engine, dropout,
-                num_stacks, input_mode, bidirectional, param_specs,
-                input_sample_shape)
+    def __init__(self, name, hidden_size, dropout=0.0, num_stacks=1,
+                 input_mode='linear', bidirectional=False,
+                 param_specs=None, input_sample_shape=None):
+        super(LSTM, self).__init__(name, hidden_size,  'lstm',  dropout,
+                                   num_stacks, input_mode, bidirectional,
+                                   param_specs, input_sample_shape)
+
 
 class GRU(RNN):
-    def __init__(self, name, hidden_size, engine='cudnn',
-            dropout=0.0, num_stacks=1, input_mode='linear', bidirectional=False,
-            param_specs=None, input_sample_shape=None):
-        super(GRU, self).__init__(name,  hidden_size, 'gru', engine, dropout,
-                num_stacks, input_mode, bidirectional, param_specs,
-                input_sample_shape)
+    def __init__(self, name, hidden_size, dropout=0.0, num_stacks=1,
+                 input_mode='linear', bidirectional=False, param_specs=None,
+                 input_sample_shape=None):
+        super(GRU, self).__init__(name,  hidden_size, 'gru',  dropout,
+                                  num_stacks, input_mode, bidirectional,
+                                  param_specs, input_sample_shape)
 
 
 def _check_engine(engine, allowed_engines):
@@ -573,12 +642,17 @@ def _check_engine(engine, allowed_engines):
            (engine, ', '.join(allowed_engines))
 
 
-def _create_layer(engine, layer):
-    if engine == 'cuda' or engine == 'cpp':
-        layer_type = layer
-    else:
-        layer_type = engine.title() + layer
-    return singa_wrap.CreateLayer(layer_type)
+def _create_layer(eng, layer):
+    ''' create singa wrap layer.
+
+    Both arguments are case insensitive.
+    Args:
+        eng, implementation engine, either 'singa' or 'cudnn'
+        layer, layer type, e.g., 'convolution', 'pooling'; for activation
+        layers, use the specific activation mode, e.g. 'relu', 'tanh'.
+    '''
+    layer_type = eng + '_' + layer
+    return singa_wrap.CreateLayer(layer_type.lower())
 
 
 def _set_kernel_stride_pad(conf, kernel, stride, border_mode, pad):
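
To make the new naming convention concrete, here is a tiny Python sketch
mirroring _create_layer above (make_identifier is a hypothetical helper used
only for illustration, not part of the module):

    # the registered identifier is engine + '_' + layer type, lowercased
    def make_identifier(eng, layer):
        return (eng + '_' + layer).lower()

    assert make_identifier('cudnn', 'Convolution') == 'cudnn_convolution'
    assert make_identifier('SINGA', 'Relu') == 'singa_relu'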

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/05720c21/src/python/singa/net.py
----------------------------------------------------------------------
diff --git a/src/python/singa/net.py b/src/python/singa/net.py
index c0ba61d..1617717 100644
--- a/src/python/singa/net.py
+++ b/src/python/singa/net.py
@@ -92,17 +92,17 @@ class FeedForwardNet(object):
         return tensor.softmax(xx)
 
     def forward(self, flag, x):
-        #print x.l1()
+        # print x.l1()
         for lyr in self.layers:
             x = lyr.forward(flag, x)
         #    print lyr.name, x.l1()
         return x
 
-    def backward(self, flag=kTrain):
+    def backward(self):
         grad = self.loss.backward()
         pgrads = []
         for lyr in reversed(self.layers):
-            grad, _pgrads = lyr.backward(flag, grad)
+            grad, _pgrads = lyr.backward(kTrain, grad)
             for g in reversed(_pgrads):
                 pgrads.append(g)
         return reversed(pgrads)

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/05720c21/src/python/singa/tensor.py
----------------------------------------------------------------------
diff --git a/src/python/singa/tensor.py b/src/python/singa/tensor.py
index 2d6fa5a..6e84a4f 100644
--- a/src/python/singa/tensor.py
+++ b/src/python/singa/tensor.py
@@ -39,7 +39,7 @@ class Tensor(object):
             return
         else:
             assert isinstance(shape, tuple), 'shape should be tuple'
-            vs = _tuple_to_vector(shape)
+            vs = list(shape)
             if device is None:
                 self.singa_tensor = singa.Tensor(vs, dtype)
             else:
@@ -111,8 +111,9 @@ class Tensor(object):
         return self.singa_tensor.L1()
 
     def set_value(self, x):
-        if isinstance(x, float):
-            self.singa_tensor.floatSetValue(x)
+        # assert type(x) == float, 'set value only accepts float input'
+        # if isinstance(x, float):
+        self.singa_tensor.floatSetValue(x)
 
     def copy_data(self, t):
         self.singa_tensor.CopyData(t.singa_tensor)
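
For context, a minimal usage sketch of the updated tensor constructor and
set_value (assuming the singa Python package is built; with device=None the
tensor is placed on the default device):

    from singa import tensor

    # the shape must be a tuple; it is converted to a list internally
    t = tensor.Tensor((2, 3))
    t.set_value(0.5)      # fill the tensor with a single float value

    u = tensor.Tensor((2, 3))
    u.copy_data(t)        # copy the values of t into u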

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/05720c21/src/python/swig/core_device.i
----------------------------------------------------------------------
diff --git a/src/python/swig/core_device.i b/src/python/swig/core_device.i
index b79d37e..a5d0731 100644
--- a/src/python/swig/core_device.i
+++ b/src/python/swig/core_device.i
@@ -58,6 +58,10 @@ class Platform {
   static const std::string DeviceQuery(int id, bool verbose = false);
   static const std::vector<std::shared_ptr<Device> >
   CreateCudaGPUs(const size_t num_devices, size_t init_size = 0);
+  static const std::vector<std::shared_ptr<Device>>
+  CreateCudaGPUsOn(const std::vector<int> &devices, size_t init_size = 0);
+  static std::shared_ptr<Device> GetDefaultDevice();
 };
+
 }
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/05720c21/src/python/swig/model_layer.i
----------------------------------------------------------------------
diff --git a/src/python/swig/model_layer.i b/src/python/swig/model_layer.i
index 6cbfe8f..a6cdad1 100644
--- a/src/python/swig/model_layer.i
+++ b/src/python/swig/model_layer.i
@@ -81,7 +81,6 @@ const std::vector<std::string> GetRegisteredLayers();
 class RNN : public Layer {
 };
 
-#if CUDNN_VERSION_MINOR >= 5 && CUDNN_VERSION_PATCH >= 5
 class CudnnRNN : public RNN {
  public:
  // note: Must use std::vector instead of vector.
@@ -93,7 +92,5 @@ class CudnnRNN : public RNN {
     const std::vector<size_t> GetOutputSampleShape() const override;
 };
 
-#endif  // CUDNN_VERSION_MINOR >= 5 && CUDNN_VERSION_PATCH >= 5
-
 }
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/05720c21/test/singa/test_activation.cc
----------------------------------------------------------------------
diff --git a/test/singa/test_activation.cc b/test/singa/test_activation.cc
index 001c49c..bb8ad84 100644
--- a/test/singa/test_activation.cc
+++ b/test/singa/test_activation.cc
@@ -27,15 +27,15 @@ using singa::Activation;
 using singa::Shape;
 TEST(Activation, Setup) {
   Activation acti;
-  EXPECT_EQ("Activation", acti.layer_type());
+  // EXPECT_EQ("Activation", acti.layer_type());
 
   singa::LayerConf conf;
-  conf.set_type("RELU");
+  conf.set_type("singa_relu");
   singa::ReLUConf* reluconf = conf.mutable_relu_conf();
   reluconf->set_negative_slope(0.5);
 
   acti.Setup(Shape{3}, conf);
-  EXPECT_EQ("RELU", acti.Mode());
+  EXPECT_EQ("relu", acti.Mode());
   EXPECT_EQ(0.5f, acti.Negative_slope());
 }
 
@@ -46,13 +46,13 @@ TEST(Activation, Forward) {
   in.CopyDataFromHostPtr<float>(x, n);
 
   float neg_slope = 0.5f;
-  std::string types[] = {"SIGMOID","TANH","RELU"};
+  std::string types[] = {"singa_sigmoid", "singa_tanh", "singa_relu"};
   for (int j = 0; j < 3; j++) {
     Activation acti;
     singa::LayerConf conf;
     std::string layertype = types[j];
     conf.set_type(layertype);
-    if (layertype == "RELU") {
+    if (layertype == "relu") {
       singa::ReLUConf* reluconf = conf.mutable_relu_conf();
       reluconf->set_negative_slope(neg_slope);
     }
@@ -64,15 +64,15 @@ TEST(Activation, Forward) {
     EXPECT_EQ(n, out.Size());
 
     float* y = new float[n];
-    if (acti.Mode() == "SIGMOID") {
+    if (acti.Mode() == "sigmoid") {
       for (size_t i = 0; i < n; i++)
         y[i] = 1.f / (1.f + exp(-x[i]));
     }
-    else if (acti.Mode() == "TANH") {
+    else if (acti.Mode() == "tanh") {
       for (size_t i = 0; i < n; i++)
         y[i] = tanh(x[i]);
     }
-    else if (acti.Mode() == "RELU") {
+    else if (acti.Mode() == "relu") {
       for (size_t i = 0; i < n; i++)
         y[i] = (x[i] >= 0.f) ? x[i] : 0.f;
     }
@@ -92,13 +92,13 @@ TEST(Activation, Backward) {
   in.CopyDataFromHostPtr<float>(x, n);
 
   float neg_slope = 0.5f;
-  std::string types[] = {"SIGMOID","TANH","RELU"};
+  std::string types[] = {"singa_sigmoid", "singa_tanh", "singa_relu"};
   for (int j = 0; j < 3; j++) {
     Activation acti;
     singa::LayerConf conf;
     std::string layertype = types[j];
     conf.set_type(layertype);
-    if (layertype == "RELU") {
+    if (layertype == "relu") {
       singa::ReLUConf* reluconf = conf.mutable_relu_conf();
       reluconf->set_negative_slope(neg_slope);
     }
@@ -114,15 +114,15 @@ TEST(Activation, Backward) {
     const float* xptr = in_diff.first.data<float>();
 
     float* dx = new float[n];
-    if (acti.Mode() == "SIGMOID") {
+    if (acti.Mode() == "sigmoid") {
       for (size_t i = 0; i < n; i++)
         dx[i] = grad[i] * yptr[i] * (1. - yptr[i]);
     }
-    else if (acti.Mode() == "TANH") {
+    else if (acti.Mode() == "tanh") {
       for (size_t i = 0; i < n; i++)
         dx[i] = grad[i] * (1 - yptr[i] * yptr[i]);
     }
-    else if (acti.Mode() == "RELU") {
+    else if (acti.Mode() == "relu") {
       for (size_t i = 0; i < n; i++)
         dx[i] = grad[i] * (x[i] > 0.f) + acti.Negative_slope() * (x[i] <= 0.f);
     }

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/05720c21/test/singa/test_batchnorm.cc
----------------------------------------------------------------------
diff --git a/test/singa/test_batchnorm.cc b/test/singa/test_batchnorm.cc
index c72dc0f..a61f6f3 100644
--- a/test/singa/test_batchnorm.cc
+++ b/test/singa/test_batchnorm.cc
@@ -27,7 +27,7 @@ using namespace singa;
 
 TEST(BatchNorm, Setup) {
   BatchNorm batchnorm;
-  EXPECT_EQ("BatchNorm", batchnorm.layer_type());
+  // EXPECT_EQ("BatchNorm", batchnorm.layer_type());
 
   singa::LayerConf conf;
   singa::BatchNormConf *batchnorm_conf = conf.mutable_batchnorm_conf();
@@ -68,10 +68,10 @@ TEST(BatchNorm, Forward) {
   EXPECT_EQ(1u, shape[1]);
   EXPECT_EQ(2u, shape[2]);
   EXPECT_EQ(1u, shape[3]);
-  EXPECT_NEAR(1.0f, outptr[0], 1e-6f);
-  EXPECT_NEAR(1.0f, outptr[1], 1e-6f);
-  EXPECT_NEAR(3.0f, outptr[2], 1e-6f);
-  EXPECT_NEAR(3.0f, outptr[3], 1e-6f);
+  EXPECT_NEAR(1.0f, outptr[0], 1e-4f);
+  EXPECT_NEAR(1.0f, outptr[1], 1e-4f);
+  EXPECT_NEAR(3.0f, outptr[2], 1e-4f);
+  EXPECT_NEAR(3.0f, outptr[3], 1e-4f);
 }
 
 TEST(BatchNorm, Backward) {
@@ -107,10 +107,10 @@ TEST(BatchNorm, Backward) {
   EXPECT_EQ(2u, shape[2]);
   EXPECT_EQ(1u, shape[3]);
   const float *dxptr = ret.first.data<float>();
-  EXPECT_NEAR(.0f, dxptr[0], 1e-6f);
-  EXPECT_NEAR(.0f, dxptr[1], 1e-6f);
-  EXPECT_NEAR(.0f, dxptr[2], 1e-6f);
-  EXPECT_NEAR(.0f, dxptr[3], 1e-6f);
+  EXPECT_NEAR(.0f, dxptr[0], 1e-4f);
+  EXPECT_NEAR(.0f, dxptr[1], 1e-4f);
+  EXPECT_NEAR(.0f, dxptr[2], 1e-4f);
+  EXPECT_NEAR(.0f, dxptr[3], 1e-4f);
 
   Tensor dbnScale = ret.second.at(0);
   const float *dbnScaleptr = dbnScale.data<float>();
@@ -118,8 +118,8 @@ TEST(BatchNorm, Backward) {
   EXPECT_EQ(1u, dbnScaleShape.size());
   EXPECT_EQ(2u, dbnScaleShape[0]);
 
-  EXPECT_NEAR(-2.0f, dbnScaleptr[0], 1e-6f);
-  EXPECT_NEAR(-2.0f, dbnScaleptr[1], 1e-6f);
+  EXPECT_NEAR(-2.0f, dbnScaleptr[0], 1e-4f);
+  EXPECT_NEAR(-2.0f, dbnScaleptr[1], 1e-4f);
 
   Tensor dbnBias = ret.second.at(1);
   const float *dbnBiasptr = dbnBias.data<float>();
@@ -127,6 +127,6 @@ TEST(BatchNorm, Backward) {
   EXPECT_EQ(1u, dbnBiasShape.size());
   EXPECT_EQ(2u, dbnBiasShape[0]);
 
-  EXPECT_NEAR(6.0f, dbnBiasptr[0], 1e-6f);
-  EXPECT_NEAR(4.0f, dbnBiasptr[1], 1e-6f);
+  EXPECT_NEAR(6.0f, dbnBiasptr[0], 1e-4f);
+  EXPECT_NEAR(4.0f, dbnBiasptr[1], 1e-4f);
 }

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/05720c21/test/singa/test_convolution.cc
----------------------------------------------------------------------
diff --git a/test/singa/test_convolution.cc b/test/singa/test_convolution.cc
index c3ddcee..4cfb38d 100644
--- a/test/singa/test_convolution.cc
+++ b/test/singa/test_convolution.cc
@@ -29,7 +29,7 @@ using singa::Convolution;
 using singa::Shape;
 TEST(Convolution, Setup) {
   Convolution conv;
-  EXPECT_EQ("Convolution", conv.layer_type());
+  // EXPECT_EQ("Convolution", conv.layer_type());
 
   singa::LayerConf conf;
   singa::ConvolutionConf *convconf = conf.mutable_convolution_conf();

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/05720c21/test/singa/test_cudnn_activation.cc
----------------------------------------------------------------------
diff --git a/test/singa/test_cudnn_activation.cc b/test/singa/test_cudnn_activation.cc
index 9279d6c..6a989d1 100644
--- a/test/singa/test_cudnn_activation.cc
+++ b/test/singa/test_cudnn_activation.cc
@@ -29,12 +29,12 @@
 
 using singa::CudnnActivation;
 using singa::Shape;
-TEST(TCudnnActivation, Setup) {
+TEST(CudnnActivation, Setup) {
   CudnnActivation acti;
-  EXPECT_EQ("CudnnActivation", acti.layer_type());
+  // EXPECT_EQ("CudnnActivation", acti.layer_type());
 
   singa::LayerConf conf;
-  conf.set_type("RELU");
+  conf.set_type("cudnn_relu");
   singa::ReLUConf* reluconf = conf.mutable_relu_conf();
   reluconf->set_negative_slope(0.5f);
 
@@ -43,7 +43,7 @@ TEST(TCudnnActivation, Setup) {
   EXPECT_EQ(0.5f, acti.Negative_slope());
 }
 
-TEST(TCudnnActivation, Forward) {
+TEST(CudnnActivation, Forward) {
   const float x[] = {1.0f, 2.0f, 3.0f, -2.0f, -3.0f, -4.0};
   size_t n = sizeof(x) / sizeof(float);
   auto cuda = std::make_shared<singa::CudaGPU>();
@@ -51,13 +51,13 @@ TEST(TCudnnActivation, Forward) {
   in.CopyDataFromHostPtr<float>(x, n);
 
   float neg_slope = 0.5f;
-  std::string types[] = {"SIGMOID", "TANH", "RELU"};
+  std::string types[] = {"cudnn_sigmoid", "cudnn_tanh", "cudnn_relu"};
   for (int j = 0; j < 3; j++) {
     CudnnActivation acti;
     singa::LayerConf conf;
     std::string layertype = types[j];
     conf.set_type(layertype);
-    if (layertype == "RELU") {
+    if (layertype == "relu") {
       singa::ReLUConf* reluconf = conf.mutable_relu_conf();
       reluconf->set_negative_slope(neg_slope);
     }
@@ -68,11 +68,11 @@ TEST(TCudnnActivation, Forward) {
     out.ToHost();
     const float* yptr = out.data<float>();
     float* y = new float[n];
-    if (acti.Mode() == "SIGMOID") {
+    if (acti.Mode() == "sigmoid") {
       for (size_t i = 0; i < n; i++) y[i] = 1.f / (1.f + exp(-x[i]));
-    } else if (acti.Mode() == "TANH") {
+    } else if (acti.Mode() == "tanh") {
       for (size_t i = 0; i < n; i++) y[i] = tanh(x[i]);
-    } else if (acti.Mode() == "RELU") {
+    } else if (acti.Mode() == "relu") {
       for (size_t i = 0; i < n; i++) y[i] = (x[i] >= 0.f) ? x[i] : 0.f;
     } else
       LOG(FATAL) << "Unkown activation: " << acti.Mode();
@@ -83,14 +83,14 @@ TEST(TCudnnActivation, Forward) {
   }
 }
 
-TEST(TCudnnActivation, Backward) {
+TEST(CudnnActivation, Backward) {
   const float x[] = {2.0f, 3.0f, 3.0f, 7.f, 0.0f, 5.0, 1.5, 2.5, -2.5, 1.5};
   size_t n = sizeof(x) / sizeof(float);
   auto cuda = std::make_shared<singa::CudaGPU>();
   singa::Tensor in(singa::Shape{n}, cuda);
   in.CopyDataFromHostPtr<float>(x, n);
   float neg_slope = 0.5f;
-  std::string types[] = {"SIGMOID", "TANH", "RELU"};
+  std::string types[] = {"cudnn_sigmoid", "cudnn_tanh", "cudnn_relu"};
   for (int j = 0; j < 3; j++) {
     CudnnActivation acti;
     singa::LayerConf conf;
@@ -115,11 +115,11 @@ TEST(TCudnnActivation, Backward) {
     in_diff.ToHost();
     const float* xptr = in_diff.data<float>();
     float* dx = new float[n];
-    if (acti.Mode() == "SIGMOID") {
+    if (acti.Mode() == "sigmoid") {
       for (size_t i = 0; i < n; i++) dx[i] = grad[i] * yptr[i] * (1. - yptr[i]);
-    } else if (acti.Mode() == "TANH") {
+    } else if (acti.Mode() == "tanh") {
       for (size_t i = 0; i < n; i++) dx[i] = grad[i] * (1. - yptr[i] * yptr[i]);
-    } else if (acti.Mode() == "RELU") {
+    } else if (acti.Mode() == "relu") {
       for (size_t i = 0; i < n; i++)
         dx[i] =
             grad[i] * (x[i] > 0.f);  //+ acti.Negative_slope() * (x[i] <= 0.f);

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/05720c21/test/singa/test_cudnn_batchnorm.cc
----------------------------------------------------------------------
diff --git a/test/singa/test_cudnn_batchnorm.cc b/test/singa/test_cudnn_batchnorm.cc
index 4f6a38b..b2746dc 100644
--- a/test/singa/test_cudnn_batchnorm.cc
+++ b/test/singa/test_cudnn_batchnorm.cc
@@ -28,7 +28,7 @@ using singa::CudnnBatchNorm;
 using singa::Shape;
 TEST(CudnnBatchNorm, Setup) {
   CudnnBatchNorm batchnorm;
-  EXPECT_EQ("CudnnBatchNorm", batchnorm.layer_type());
+  // EXPECT_EQ("CudnnBatchNorm", batchnorm.layer_type());
 
   singa::LayerConf conf;
   singa::BatchNormConf *batchnorm_conf = conf.mutable_batchnorm_conf();

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/05720c21/test/singa/test_cudnn_convolution.cc
----------------------------------------------------------------------
diff --git a/test/singa/test_cudnn_convolution.cc b/test/singa/test_cudnn_convolution.cc
index 66c62f6..8dbee63 100644
--- a/test/singa/test_cudnn_convolution.cc
+++ b/test/singa/test_cudnn_convolution.cc
@@ -27,7 +27,7 @@ using singa::CudnnConvolution;
 using singa::Shape;
 TEST(CudnnConvolution, Setup) {
   CudnnConvolution conv;
-  EXPECT_EQ("CudnnConvolution", conv.layer_type());
+  // EXPECT_EQ("CudnnConvolution", conv.layer_type());
 
   singa::LayerConf conf;
   singa::ConvolutionConf *convconf = conf.mutable_convolution_conf();
@@ -199,7 +199,7 @@ TEST(CudnnConvolution, Backward) {
 // Tests for prefer=autotune
 TEST(CudnnConvolution_AT, Setup) {
   CudnnConvolution conv;
-  EXPECT_EQ("CudnnConvolution", conv.layer_type());
+  // EXPECT_EQ("CudnnConvolution", conv.layer_type());
 
   singa::LayerConf conf;
   singa::ConvolutionConf *convconf = conf.mutable_convolution_conf();

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/05720c21/test/singa/test_cudnn_dropout.cc
----------------------------------------------------------------------
diff --git a/test/singa/test_cudnn_dropout.cc b/test/singa/test_cudnn_dropout.cc
index 7f28aca..4a89235 100644
--- a/test/singa/test_cudnn_dropout.cc
+++ b/test/singa/test_cudnn_dropout.cc
@@ -36,7 +36,7 @@ using singa::CudnnDropout;
 using singa::Shape;
 TEST(CudnnDropout, Setup) {
   CudnnDropout drop;
-  EXPECT_EQ("CudnnDropout", drop.layer_type());
+  // EXPECT_EQ("CudnnDropout", drop.layer_type());
 
   singa::LayerConf conf;
   singa::DropoutConf* dropconf = conf.mutable_dropout_conf();

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/05720c21/test/singa/test_cudnn_lrn.cc
----------------------------------------------------------------------
diff --git a/test/singa/test_cudnn_lrn.cc b/test/singa/test_cudnn_lrn.cc
index 23fbe2e..04ca5f2 100644
--- a/test/singa/test_cudnn_lrn.cc
+++ b/test/singa/test_cudnn_lrn.cc
@@ -30,7 +30,7 @@ using singa::CudnnLRN;
 using singa::Shape;
 TEST(CudnnLRN, Setup) {
   CudnnLRN lrn;
-  EXPECT_EQ("CudnnLRN", lrn.layer_type());
+  // EXPECT_EQ("CudnnLRN", lrn.layer_type());
 
   singa::LayerConf conf;
   singa::LRNConf *lrn_conf = conf.mutable_lrn_conf();

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/05720c21/test/singa/test_cudnn_pooling.cc
----------------------------------------------------------------------
diff --git a/test/singa/test_cudnn_pooling.cc b/test/singa/test_cudnn_pooling.cc
index 5c01889..0e3314e 100644
--- a/test/singa/test_cudnn_pooling.cc
+++ b/test/singa/test_cudnn_pooling.cc
@@ -27,7 +27,7 @@ using singa::CudnnPooling;
 using singa::Shape;
 TEST(CudnnPooling, Setup) {
   CudnnPooling pool;
-  EXPECT_EQ("CudnnPooling", pool.layer_type());
+  //  EXPECT_EQ("CudnnPooling", pool.layer_type());
 
   singa::LayerConf conf;
   singa::PoolingConf *poolconf = conf.mutable_pooling_conf();

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/05720c21/test/singa/test_cudnn_rnn.cc
----------------------------------------------------------------------
diff --git a/test/singa/test_cudnn_rnn.cc b/test/singa/test_cudnn_rnn.cc
index effb3b1..e293cf7 100644
--- a/test/singa/test_cudnn_rnn.cc
+++ b/test/singa/test_cudnn_rnn.cc
@@ -45,7 +45,7 @@ class TestCudnnRNN : public ::testing::Test {
 
 TEST_F(TestCudnnRNN, Setup) {
   CudnnRNN rnn;
-  EXPECT_EQ("CudnnRNN", rnn.layer_type());
+  // EXPECT_EQ("CudnnRNN", rnn.layer_type());
   rnn.Setup(Shape{2}, conf);
   auto weight = rnn.param_values().at(0);
   EXPECT_EQ(weight.Size(), hidden_size * (2 + hidden_size + 2));

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/05720c21/test/singa/test_cudnn_softmax.cc
----------------------------------------------------------------------
diff --git a/test/singa/test_cudnn_softmax.cc b/test/singa/test_cudnn_softmax.cc
index 2b88843..6e0d5ab 100644
--- a/test/singa/test_cudnn_softmax.cc
+++ b/test/singa/test_cudnn_softmax.cc
@@ -31,7 +31,7 @@ using singa::CudnnSoftmax;
 using singa::Shape;
 TEST(CudnnSoftmax, Setup) {
   CudnnSoftmax sft;
-  EXPECT_EQ("CudnnSoftmax", sft.layer_type());
+  // EXPECT_EQ("CudnnSoftmax", sft.layer_type());
 
   singa::LayerConf conf;
   singa::SoftmaxConf* softmaxconf = conf.mutable_softmax_conf();

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/05720c21/test/singa/test_dense.cc
----------------------------------------------------------------------
diff --git a/test/singa/test_dense.cc b/test/singa/test_dense.cc
index f4ecdfc..17e161a 100644
--- a/test/singa/test_dense.cc
+++ b/test/singa/test_dense.cc
@@ -26,7 +26,7 @@ using singa::Dense;
 using singa::Shape;
 TEST(Dense, Setup) {
   Dense dense;
-  EXPECT_EQ("Dense", dense.layer_type());
+  // EXPECT_EQ("Dense", dense.layer_type());
 
   singa::LayerConf conf;
   singa::DenseConf *denseconf = conf.mutable_dense_conf();

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/05720c21/test/singa/test_dropout.cc
----------------------------------------------------------------------
diff --git a/test/singa/test_dropout.cc b/test/singa/test_dropout.cc
index 3dd988a..b0c34a3 100644
--- a/test/singa/test_dropout.cc
+++ b/test/singa/test_dropout.cc
@@ -26,7 +26,7 @@ using singa::Dropout;
 using singa::Shape;
 TEST(Dropout, Setup) {
   Dropout drop;
-  EXPECT_EQ("Dropout", drop.layer_type());
+  // EXPECT_EQ("Dropout", drop.layer_type());
 
   singa::LayerConf conf;
   singa::DropoutConf* dropconf = conf.mutable_dropout_conf();

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/05720c21/test/singa/test_flatten.cc
----------------------------------------------------------------------
diff --git a/test/singa/test_flatten.cc b/test/singa/test_flatten.cc
index 25e00c4..65748f7 100644
--- a/test/singa/test_flatten.cc
+++ b/test/singa/test_flatten.cc
@@ -26,7 +26,7 @@ using singa::Flatten;
 using singa::Shape;
 TEST(Flatten, Setup) {
   Flatten flt;
-  EXPECT_EQ("Flatten", flt.layer_type());
+  // EXPECT_EQ("Flatten", flt.layer_type());
 
   singa::LayerConf conf;
   singa::FlattenConf *flattenconf = conf.mutable_flatten_conf();

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/05720c21/test/singa/test_layer.cc
----------------------------------------------------------------------
diff --git a/test/singa/test_layer.cc b/test/singa/test_layer.cc
index 4071762..aa01746 100644
--- a/test/singa/test_layer.cc
+++ b/test/singa/test_layer.cc
@@ -4,26 +4,25 @@
 
 TEST(Layer, CreateLayer) {
   std::vector<std::string> types{
-      "Convolution", "Dense", "Dropout", "Activation", "BatchNorm",
-      "Flatten",     "LRN",   "Pooling", "PReLU",      "Softmax"};
+      "convolution", "dense", "dropout", "relu", "batchnorm",
+      "flatten",     "lrn",   "pooling", "prelu",      "softmax"};
   for (auto type : types) {
-    auto layer = singa::CreateLayer(type);
-    EXPECT_EQ(layer->layer_type(), type);
+    auto layer = singa::CreateLayer("singa_" + type);
+    // EXPECT_EQ(layer->layer_type(), type);
   }
 }
 
 #ifdef USE_CUDNN
 TEST(Layer, CreateCudnnLayer) {
   std::vector<std::string> types{
-      "CudnnConvolution", "CudnnActivation",
-      "CudnnBatchNorm",   "Flatten",      "CudnnLRN",
-      "CudnnPooling",     "PReLU",        "CudnnSoftmax"};
+      "convolution", "dropout", "relu", "batchnorm",
+      "lrn",   "pooling", "softmax"};
 #if CUDNN_VERSION_MAJOR >= 5
-  types.push_back("CudnnDropout");
+  types.push_back("dropout");
 #endif
   for (auto type : types) {
-    auto layer = singa::CreateLayer(type);
-    EXPECT_EQ(layer->layer_type(), type);
+    auto layer = singa::CreateLayer("cudnn_" + type);
+    // EXPECT_EQ(layer->layer_type(), type);
   }
 }
 #endif

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/05720c21/test/singa/test_lrn.cc
----------------------------------------------------------------------
diff --git a/test/singa/test_lrn.cc b/test/singa/test_lrn.cc
index 5de4535..454e1a9 100644
--- a/test/singa/test_lrn.cc
+++ b/test/singa/test_lrn.cc
@@ -26,7 +26,7 @@ using namespace singa;
 
 TEST(LRN, Setup) {
   LRN lrn;
-  EXPECT_EQ("LRN", lrn.layer_type());
+  // EXPECT_EQ("LRN", lrn.layer_type());
 
   LayerConf conf;
   LRNConf *lrn_conf = conf.mutable_lrn_conf();

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/05720c21/test/singa/test_pooling.cc
----------------------------------------------------------------------
diff --git a/test/singa/test_pooling.cc b/test/singa/test_pooling.cc
index 3089a90..7ba56d1 100644
--- a/test/singa/test_pooling.cc
+++ b/test/singa/test_pooling.cc
@@ -26,7 +26,7 @@ using singa::Pooling;
 using singa::Shape;
 TEST(Pooling, Setup) {
   Pooling pool;
-  EXPECT_EQ("Pooling", pool.layer_type());
+  //  EXPECT_EQ("Pooling", pool.layer_type());
 
   singa::LayerConf conf;
   singa::PoolingConf *poolconf = conf.mutable_pooling_conf();

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/05720c21/test/singa/test_prelu.cc
----------------------------------------------------------------------
diff --git a/test/singa/test_prelu.cc b/test/singa/test_prelu.cc
index dbb7cde..77b4b74 100644
--- a/test/singa/test_prelu.cc
+++ b/test/singa/test_prelu.cc
@@ -27,7 +27,7 @@ using singa::PReLU;
 using singa::Shape;
 TEST(PReLU, Setup) {
   PReLU prelu;
-  EXPECT_EQ("PReLU", prelu.layer_type());
+  // EXPECT_EQ("PReLU", prelu.layer_type());
 
   singa::LayerConf conf;
   singa::PReLUConf *preluconf = conf.mutable_prelu_conf();

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/05720c21/test/singa/test_softmax.cc
----------------------------------------------------------------------
diff --git a/test/singa/test_softmax.cc b/test/singa/test_softmax.cc
index 00b8378..8064b80 100644
--- a/test/singa/test_softmax.cc
+++ b/test/singa/test_softmax.cc
@@ -27,7 +27,7 @@ using singa::Softmax;
 using singa::Shape;
 TEST(Softmax, Setup) {
   Softmax sft;
-  EXPECT_EQ("Softmax", sft.layer_type());
+  // EXPECT_EQ("Softmax", sft.layer_type());
 
   singa::LayerConf conf;
   sft.Setup(Shape{3}, conf);



[2/2] incubator-singa git commit: SINGA-235 - Unify the engines for cudnn and singa layers

Posted by wa...@apache.org.
SINGA-235 - Unify the engines for cudnn and singa layers

For most layers we would have multiple implementations, e.g., cudnn for
NVIDIA GPUs, cpp for CPUs, and opencl for other GPUs.

These implementations are separate classes, registered with different
identifiers. This ticket unifies the layer identifiers for each
engine:
1. cudnn layers are registered with identifier = cudnn_xxx, e.g.,
cudnn_convolution for the CudnnConvolution layer.
2. singa layers are registered with identifier = singa_xxx, e.g.,
singa_convolution for the Convolution layer.

The cudnn engine must run on cuda devices, while the singa engine can run on
a cuda-gpu device or a cpp-cpu device depending on the layer type. For
instance, the Convolution layer must run on a cpp-cpu device, while the Dense
layer can run on both devices and selects the correct device
automatically.
Users need to make sure that the engine matches the device of the tensors.

Both the CPP and Python code are updated. For the CPP version, users have to
compose the layer identifier manually. For the Python version, users can set
layer.engine='cudnn' or 'singa'.
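
For example, a minimal sketch of the Python usage (assuming the singa Python
package is built and, for the cudnn engine, a CUDA-capable device is
available; the layer names and shapes below are illustrative only):

    from singa import layer

    layer.engine = 'singa'   # or 'cudnn'; the value is case insensitive

    # layers no longer take a per-layer engine argument
    conv = layer.Conv2D('conv1', nb_kernels=32, kernel=3, stride=1,
                        input_sample_shape=(3, 32, 32))
    # (32, 32, 32) is the conv output sample shape for 'same' border mode
    pool = layer.MaxPooling2D('pool1', kernel=3, stride=2,
                              input_sample_shape=(32, 32, 32))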

All identifiers are case insensitive.


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/05720c21
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/05720c21
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/05720c21

Branch: refs/heads/dev
Commit: 05720c21636c0fd55770206176a60a9ab20ae16c
Parents: 53639b7
Author: Wei Wang <wa...@comp.nus.edu.sg>
Authored: Wed Aug 10 21:05:22 2016 +0800
Committer: Wei Wang <wa...@comp.nus.edu.sg>
Committed: Thu Aug 11 10:52:01 2016 +0800

----------------------------------------------------------------------
 examples/char-rnn/README.md            |  17 +--
 examples/char-rnn/train.py             |   2 +-
 examples/cifar10/alexnet-parallel.cc   |  84 ++++-------
 examples/cifar10/alexnet.cc            |  58 ++++---
 examples/cifar10/alexnet.py            |   5 +-
 examples/cifar10/train.py              |  33 ++--
 examples/cifar10/vgg-parallel.cc       |  91 ++++++-----
 examples/cifar10/vgg.py                |  24 +--
 examples/imagenet/alexnet.cc           |  22 +--
 include/singa/core/device.h            |  30 ++--
 include/singa/model/feed_forward_net.h |  12 +-
 include/singa/model/layer.h            |   4 +-
 include/singa/utils/integer.h          |  73 +++++++++
 src/core/device/platform.cc            |   4 +-
 src/core/tensor/tensor.cc              |   4 +-
 src/core/tensor/tensor_math_cpp.h      |   2 +-
 src/model/feed_forward_net.cc          |  17 +--
 src/model/layer/activation.cc          |  27 ++--
 src/model/layer/activation.h           |   2 +-
 src/model/layer/batchnorm.cc           |   6 +-
 src/model/layer/batchnorm.h            |   2 +-
 src/model/layer/convolution.cc         |   2 +-
 src/model/layer/convolution.h          |   2 +-
 src/model/layer/cudnn_activation.cc    |  10 +-
 src/model/layer/cudnn_activation.h     |   2 +-
 src/model/layer/cudnn_batchnorm.cc     |   2 +-
 src/model/layer/cudnn_batchnorm.h      |   2 +-
 src/model/layer/cudnn_convolution.cc   |   2 +-
 src/model/layer/cudnn_convolution.h    |   2 +-
 src/model/layer/cudnn_dropout.cc       |   2 +-
 src/model/layer/cudnn_dropout.h        |   2 +-
 src/model/layer/cudnn_lrn.cc           |   2 +-
 src/model/layer/cudnn_lrn.h            |   2 +-
 src/model/layer/cudnn_pooling.cc       |   2 +-
 src/model/layer/cudnn_pooling.h        |   2 +-
 src/model/layer/cudnn_rnn.cc           |  17 +--
 src/model/layer/cudnn_rnn.h            |   2 +-
 src/model/layer/cudnn_softmax.cc       |   2 +-
 src/model/layer/cudnn_softmax.h        |   2 +-
 src/model/layer/cudnn_utils.h          |   2 +-
 src/model/layer/dense.cc               |   2 +-
 src/model/layer/dense.h                |   2 +-
 src/model/layer/dropout.cc             |   2 +-
 src/model/layer/dropout.h              |   2 +-
 src/model/layer/flatten.cc             |   2 +-
 src/model/layer/flatten.h              |   2 +-
 src/model/layer/lrn.cc                 |   2 +-
 src/model/layer/lrn.h                  |   4 +-
 src/model/layer/pooling.cc             |   2 +-
 src/model/layer/pooling.h              |   2 +-
 src/model/layer/prelu.cc               |   4 +-
 src/model/layer/prelu.h                |   2 +-
 src/model/layer/rnn.cc                 |   2 +-
 src/model/layer/rnn.h                  |   2 +-
 src/model/layer/softmax.cc             |   2 +-
 src/model/layer/softmax.h              |   2 +-
 src/python/singa/device.py             |  13 ++
 src/python/singa/layer.py              | 226 ++++++++++++++++++----------
 src/python/singa/net.py                |   6 +-
 src/python/singa/tensor.py             |   7 +-
 src/python/swig/core_device.i          |   4 +
 src/python/swig/model_layer.i          |   3 -
 test/singa/test_activation.cc          |  26 ++--
 test/singa/test_batchnorm.cc           |  26 ++--
 test/singa/test_convolution.cc         |   2 +-
 test/singa/test_cudnn_activation.cc    |  28 ++--
 test/singa/test_cudnn_batchnorm.cc     |   2 +-
 test/singa/test_cudnn_convolution.cc   |   4 +-
 test/singa/test_cudnn_dropout.cc       |   2 +-
 test/singa/test_cudnn_lrn.cc           |   2 +-
 test/singa/test_cudnn_pooling.cc       |   2 +-
 test/singa/test_cudnn_rnn.cc           |   2 +-
 test/singa/test_cudnn_softmax.cc       |   2 +-
 test/singa/test_dense.cc               |   2 +-
 test/singa/test_dropout.cc             |   2 +-
 test/singa/test_flatten.cc             |   2 +-
 test/singa/test_layer.cc               |  19 ++-
 test/singa/test_lrn.cc                 |   2 +-
 test/singa/test_pooling.cc             |   2 +-
 test/singa/test_prelu.cc               |   2 +-
 test/singa/test_softmax.cc             |   2 +-
 81 files changed, 573 insertions(+), 433 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/05720c21/examples/char-rnn/README.md
----------------------------------------------------------------------
diff --git a/examples/char-rnn/README.md b/examples/char-rnn/README.md
index c5cdbb8..d4cfa30 100644
--- a/examples/char-rnn/README.md
+++ b/examples/char-rnn/README.md
@@ -1,10 +1,10 @@
 # Train Char-RNN using SINGA
 
 Recurrent neural networks (RNN) are widely used for modelling sequential data,
-e.g., natural language sentences. This example describe how to implement a RNN
+e.g., natural language sentences. This example describes how to implement a RNN
 application (or model) using SINGA's RNN layers.
-We will use the [char-rnn](https://github.com/karpathy/char-rnn) modle as an
-example, which trains over setences or
+We will use the [char-rnn](https://github.com/karpathy/char-rnn) model as an
+example, which trains over sentences or
 source code, with each character as an input unit. Particularly, we will train
 a RNN using GRU over Linux kernel source code. After training, we expect to
 generate meaningful code from the model.
@@ -12,20 +12,19 @@ generate meaningful code from the model.
 
 ## Instructions
 
-* Compile and install SINGA. Currently the RNN implmentation depends on Cudnn V5.
+* Compile and install SINGA. Currently the RNN implementation depends on Cudnn with version >= 5.05.
 
 * Prepare the dataset. Download the [kernel source code](http://cs.stanford.edu/people/karpathy/char-rnn/).
 Other plain text files can also be used.
 
 * Start the training,
 
-    python train.py input_linux.txt
+        python train.py input_linux.txt
 
   Some hyper-parameters could be set through command line,
 
-    python train.py -h
+        python train.py -h
 
+* Sample characters from the model by providing the number of characters to sample and the seed string.
 
-* Sample characters from the model by providing num of characters and the seed string.
-
-    python sample.py 100 --seed '#include <std'
+        python sample.py 100 --seed '#include <std'

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/05720c21/examples/char-rnn/train.py
----------------------------------------------------------------------
diff --git a/examples/char-rnn/train.py b/examples/char-rnn/train.py
index 22fdc82..3dfa0d9 100644
--- a/examples/char-rnn/train.py
+++ b/examples/char-rnn/train.py
@@ -195,7 +195,7 @@ def train(data, max_epoch, hidden_size =100, seq_length=100, batch_size=16,
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='Train multi-stack LSTM for '\
             'modeling  character sequence from plain text files')
-    parser.add_argument('data', type=string, help='training file')
+    parser.add_argument('data', type=str, help='training file')
     parser.add_argument('-b', type=int, default=32, help='batch_size')
     parser.add_argument('-l', type=int, default=64, help='sequence length')
     parser.add_argument('-d', type=int, default=128, help='hidden size')

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/05720c21/examples/cifar10/alexnet-parallel.cc
----------------------------------------------------------------------
diff --git a/examples/cifar10/alexnet-parallel.cc b/examples/cifar10/alexnet-parallel.cc
index 15ef58e..8cc3352 100644
--- a/examples/cifar10/alexnet-parallel.cc
+++ b/examples/cifar10/alexnet-parallel.cc
@@ -28,21 +28,17 @@
 #include "singa/utils/channel.h"
 #include "singa/utils/string.h"
 #include "singa/core/memory.h"
-#include "../../src/model/layer/cudnn_convolution.h"
-#include "../../src/model/layer/cudnn_activation.h"
-#include "../../src/model/layer/cudnn_pooling.h"
-#include "../../src/model/layer/cudnn_lrn.h"
-#include "../../src/model/layer/dense.h"
-#include "../../src/model/layer/flatten.h"
 #include <thread>
 #include <memory>
+
 namespace singa {
+const std::string engine = "cudnn";
 
 LayerConf GenConvConf(string name, int nb_filter, int kernel, int stride,
                       int pad, float std) {
   LayerConf conf;
   conf.set_name(name);
-  conf.set_type("CudnnConvolution");
+  conf.set_type(engine + "_convolution");
   ConvolutionConf *conv = conf.mutable_convolution_conf();
   conv->set_num_output(nb_filter);
   conv->add_kernel_size(kernel);
@@ -67,7 +63,7 @@ LayerConf GenPoolingConf(string name, bool max_pool, int kernel, int stride,
                          int pad) {
   LayerConf conf;
   conf.set_name(name);
-  conf.set_type("CudnnPooling");
+  conf.set_type(engine + "_pooling");
   PoolingConf *pool = conf.mutable_pooling_conf();
   pool->set_kernel_size(kernel);
   pool->set_stride(stride);
@@ -79,14 +75,14 @@ LayerConf GenPoolingConf(string name, bool max_pool, int kernel, int stride,
 LayerConf GenReLUConf(string name) {
   LayerConf conf;
   conf.set_name(name);
-  conf.set_type("RELU");
+  conf.set_type(engine + "_relu");
   return conf;
 }
 
 LayerConf GenDenseConf(string name, int num_output, float std, float wd) {
   LayerConf conf;
   conf.set_name(name);
-  conf.set_type("Dense");
+  conf.set_type("singa_dense");
   DenseConf *dense = conf.mutable_dense_conf();
   dense->set_num_output(num_output);
 
@@ -108,7 +104,7 @@ LayerConf GenDenseConf(string name, int num_output, float std, float wd) {
 LayerConf GenLRNConf(string name) {
   LayerConf conf;
   conf.set_name(name);
-  conf.set_type("CudnnLRN");
+  conf.set_type(engine + "_lrn");
   LRNConf *lrn = conf.mutable_lrn_conf();
   lrn->set_local_size(3);
   lrn->set_alpha(5e-05);
@@ -119,7 +115,7 @@ LayerConf GenLRNConf(string name) {
 LayerConf GenFlattenConf(string name) {
   LayerConf conf;
   conf.set_name(name);
-  conf.set_type("Flatten");
+  conf.set_type("singa_flatten");
   return conf;
 }
 
@@ -127,20 +123,19 @@ FeedForwardNet CreateNet() {
   FeedForwardNet net;
   Shape s{3, 32, 32};
 
-  net.Add(new CudnnConvolution(), GenConvConf("conv1", 32, 5, 1, 2, 0.0001),
-          &s);
-  net.Add(new CudnnPooling(), GenPoolingConf("pool1", true, 3, 2, 1));
-  net.Add(new CudnnActivation(), GenReLUConf("relu1"));
-  net.Add(new CudnnLRN(), GenLRNConf("lrn1"));
-  net.Add(new CudnnConvolution(), GenConvConf("conv2", 32, 5, 1, 2, 0.01));
-  net.Add(new CudnnActivation(), GenReLUConf("relu2"));
-  net.Add(new CudnnPooling(), GenPoolingConf("pool2", false, 3, 2, 1));
-  net.Add(new CudnnLRN(), GenLRNConf("lrn2"));
-  net.Add(new CudnnConvolution, GenConvConf("conv3", 64, 5, 1, 2, 0.01));
-  net.Add(new CudnnActivation(), GenReLUConf("relu3"));
-  net.Add(new CudnnPooling(), GenPoolingConf("pool3", false, 3, 2, 1));
-  net.Add(new Flatten(), GenFlattenConf("flat"));
-  net.Add(new Dense(), GenDenseConf("ip", 10, 0.01, 250));
+  net.Add(GenConvConf("conv1", 32, 5, 1, 2, 0.0001), &s);
+  net.Add(GenPoolingConf("pool1", true, 3, 2, 1));
+  net.Add(GenReLUConf("relu1"));
+  net.Add(GenLRNConf("lrn1"));
+  net.Add(GenConvConf("conv2", 32, 5, 1, 2, 0.01));
+  net.Add(GenReLUConf("relu2"));
+  net.Add(GenPoolingConf("pool2", false, 3, 2, 1));
+  net.Add(GenLRNConf("lrn2"));
+  net.Add(GenConvConf("conv3", 64, 5, 1, 2, 0.01));
+  net.Add(GenReLUConf("relu3"));
+  net.Add(GenPoolingConf("pool3", false, 3, 2, 1));
+  net.Add(GenFlattenConf("flat"));
+  net.Add(GenDenseConf("ip", 10, 0.01, 250));
   return net;
 }
 
@@ -228,35 +223,18 @@ void Train(float lr, int num_epoch, string data_dir) {
   mem_conf.add_device(0);
   mem_conf.add_device(1);
   std::shared_ptr<DeviceMemPool> mem_pool(new CnMemPool(mem_conf));
-  std::shared_ptr<CudaGPU> cuda_1(new CudaGPU(0, mem_pool));
-  std::shared_ptr<CudaGPU> cuda_2(new CudaGPU(1, mem_pool));
-  net_1.ToDevice(cuda_1);
-  net_2.ToDevice(cuda_2);
-
-  /*
-  // this does not work for net_2
-  train_x_2.ResetLike(train_x);
-  train_y_2.ResetLike(train_y);
-  test_x_2.ResetLike(test_x);
-  test_y_2.ResetLike(test_y);
-
-  train_x.ToDevice(cuda_1);
-  train_y.ToDevice(cuda_1);
-  test_x.ToDevice(cuda_1);
-  test_y.ToDevice(cuda_1);
+  std::shared_ptr<CudaGPU> dev_1(new CudaGPU(0, mem_pool));
+  std::shared_ptr<CudaGPU> dev_2(new CudaGPU(1, mem_pool));
 
-  train_x_2.ToDevice(cuda_2);
-  train_y_2.ToDevice(cuda_2);
-  test_x_2.ToDevice(cuda_2);
-  test_y_2.ToDevice(cuda_2);
-  */
+  net_1.ToDevice(dev_1);
+  net_2.ToDevice(dev_2);
 
-  train_x_1.ToDevice(cuda_1);
-  train_y_1.ToDevice(cuda_1);
-  test_x.ToDevice(cuda_1);
-  test_y.ToDevice(cuda_1);
-  train_x_2.ToDevice(cuda_2);
-  train_y_2.ToDevice(cuda_2);
+  train_x_1.ToDevice(dev_1);
+  train_y_1.ToDevice(dev_1);
+  test_x.ToDevice(dev_1);
+  test_y.ToDevice(dev_1);
+  train_x_2.ToDevice(dev_2);
+  train_y_2.ToDevice(dev_2);
 
   // net.Train(100, num_epoch, train_x, train_y, test_x, test_y);
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/05720c21/examples/cifar10/alexnet.cc
----------------------------------------------------------------------
diff --git a/examples/cifar10/alexnet.cc b/examples/cifar10/alexnet.cc
index 6480557..e1363e4 100644
--- a/examples/cifar10/alexnet.cc
+++ b/examples/cifar10/alexnet.cc
@@ -26,19 +26,14 @@
 #include "singa/model/metric.h"
 #include "singa/utils/channel.h"
 #include "singa/utils/string.h"
-#include "../../src/model/layer/cudnn_convolution.h"
-#include "../../src/model/layer/cudnn_activation.h"
-#include "../../src/model/layer/cudnn_pooling.h"
-#include "../../src/model/layer/cudnn_lrn.h"
-#include "../../src/model/layer/dense.h"
-#include "../../src/model/layer/flatten.h"
 namespace singa {
 
+const std::string engine = "cudnn";
 LayerConf GenConvConf(string name, int nb_filter, int kernel, int stride,
                       int pad, float std) {
   LayerConf conf;
   conf.set_name(name);
-  conf.set_type("CudnnConvolution");
+  conf.set_type(engine + "_convolution");
   ConvolutionConf *conv = conf.mutable_convolution_conf();
   conv->set_num_output(nb_filter);
   conv->add_kernel_size(kernel);
@@ -63,7 +58,7 @@ LayerConf GenPoolingConf(string name, bool max_pool, int kernel, int stride,
                          int pad) {
   LayerConf conf;
   conf.set_name(name);
-  conf.set_type("CudnnPooling");
+  conf.set_type(engine + "_pooling");
   PoolingConf *pool = conf.mutable_pooling_conf();
   pool->set_kernel_size(kernel);
   pool->set_stride(stride);
@@ -75,14 +70,14 @@ LayerConf GenPoolingConf(string name, bool max_pool, int kernel, int stride,
 LayerConf GenReLUConf(string name) {
   LayerConf conf;
   conf.set_name(name);
-  conf.set_type("RELU");
+  conf.set_type(engine + "_relu");
   return conf;
 }
 
 LayerConf GenDenseConf(string name, int num_output, float std, float wd) {
   LayerConf conf;
   conf.set_name(name);
-  conf.set_type("Dense");
+  conf.set_type("singa_dense");
   DenseConf *dense = conf.mutable_dense_conf();
   dense->set_num_output(num_output);
 
@@ -104,7 +99,7 @@ LayerConf GenDenseConf(string name, int num_output, float std, float wd) {
 LayerConf GenLRNConf(string name) {
   LayerConf conf;
   conf.set_name(name);
-  conf.set_type("CudnnLRN");
+  conf.set_type(engine + "_lrn");
   LRNConf *lrn = conf.mutable_lrn_conf();
   lrn->set_local_size(3);
   lrn->set_alpha(5e-05);
@@ -115,7 +110,7 @@ LayerConf GenLRNConf(string name) {
 LayerConf GenFlattenConf(string name) {
   LayerConf conf;
   conf.set_name(name);
-  conf.set_type("Flatten");
+  conf.set_type("singa_flatten");
   return conf;
 }
 
@@ -123,20 +118,19 @@ FeedForwardNet CreateNet() {
   FeedForwardNet net;
   Shape s{3, 32, 32};
 
-  net.Add(new CudnnConvolution(), GenConvConf("conv1", 32, 5, 1, 2, 0.0001),
-          &s);
-  net.Add(new CudnnPooling(), GenPoolingConf("pool1", true, 3, 2, 1));
-  net.Add(new CudnnActivation(), GenReLUConf("relu1"));
-  net.Add(new CudnnLRN(), GenLRNConf("lrn1"));
-  net.Add(new CudnnConvolution(), GenConvConf("conv2", 32, 5, 1, 2, 0.01));
-  net.Add(new CudnnActivation(), GenReLUConf("relu2"));
-  net.Add(new CudnnPooling(), GenPoolingConf("pool2", false, 3, 2, 1));
-  net.Add(new CudnnLRN(), GenLRNConf("lrn2"));
-  net.Add(new CudnnConvolution, GenConvConf("conv3", 64, 5, 1, 2, 0.01));
-  net.Add(new CudnnActivation(), GenReLUConf("relu3"));
-  net.Add(new CudnnPooling(), GenPoolingConf("pool3", false, 3, 2, 1));
-  net.Add(new Flatten(), GenFlattenConf("flat"));
-  net.Add(new Dense(), GenDenseConf("ip", 10, 0.01, 250));
+  net.Add(GenConvConf("conv1", 32, 5, 1, 2, 0.0001), &s);
+  net.Add(GenPoolingConf("pool1", true, 3, 2, 1));
+  net.Add(GenReLUConf("relu1"));
+  net.Add(GenLRNConf("lrn1"));
+  net.Add(GenConvConf("conv2", 32, 5, 1, 2, 0.01));
+  net.Add(GenReLUConf("relu2"));
+  net.Add(GenPoolingConf("pool2", false, 3, 2, 1));
+  net.Add(GenLRNConf("lrn2"));
+  net.Add(GenConvConf("conv3", 64, 5, 1, 2, 0.01));
+  net.Add(GenReLUConf("relu3"));
+  net.Add(GenPoolingConf("pool3", false, 3, 2, 1));
+  net.Add(GenFlattenConf("flat"));
+  net.Add(GenDenseConf("ip", 10, 0.01, 250));
   return net;
 }
 
@@ -184,12 +178,12 @@ void Train(float lr, int num_epoch, string data_dir) {
   Accuracy acc;
   net.Compile(true, &sgd, &loss, &acc);
 
-  auto cuda = std::make_shared<CudaGPU>();
-  net.ToDevice(cuda);
-  train_x.ToDevice(cuda);
-  train_y.ToDevice(cuda);
-  test_x.ToDevice(cuda);
-  test_y.ToDevice(cuda);
+  auto dev = std::make_shared<CudaGPU>();
+  net.ToDevice(dev);
+  train_x.ToDevice(dev);
+  train_y.ToDevice(dev);
+  test_x.ToDevice(dev);
+  test_y.ToDevice(dev);
   net.Train(100, num_epoch, train_x, train_y, test_x, test_y);
 }
 }
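
Editorial aside (not part of this commit): because the net above is now built purely from LayerConf objects whose type strings carry the engine prefix, the same definition can target the CPU kernels by changing only the engine constant, provided the corresponding singa_* types are registered (as they are elsewhere in this commit). A minimal sketch, with the include path assumed and the helper name hypothetical:

    #include <string>
    #include "singa/model/layer.h"   // LayerConf and the layer factory

    namespace singa {
    // Hypothetical helper: build a ReLU conf for a chosen engine
    // ("cudnn" or "singa"); the factory resolves the prefixed type name.
    LayerConf GenReLUConfFor(const std::string &engine, const std::string &name) {
      LayerConf conf;
      conf.set_name(name);
      conf.set_type(engine + "_relu");  // e.g. "singa_relu" for a CPU-only build
      return conf;
    }
    }  // namespace singa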

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/05720c21/examples/cifar10/alexnet.py
----------------------------------------------------------------------
diff --git a/examples/cifar10/alexnet.py b/examples/cifar10/alexnet.py
index 96c339a..9ed5599 100644
--- a/examples/cifar10/alexnet.py
+++ b/examples/cifar10/alexnet.py
@@ -31,7 +31,10 @@ from singa import loss
 from singa import net as ffnet
 
 
-def create_net():
+def create_net(use_cpu=False):
+    if use_cpu:
+        layer.engine = 'singa'
+
     net = ffnet.FeedForwardNet(loss.SoftmaxCrossEntropy(), metric.Accuracy())
     W0_specs = {'init': 'gaussian', 'mean': 0, 'std': 0.0001}
     W1_specs = {'init': 'gaussian', 'mean': 0, 'std': 0.01}

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/05720c21/examples/cifar10/train.py
----------------------------------------------------------------------
diff --git a/examples/cifar10/train.py b/examples/cifar10/train.py
index cb4110d..3285651 100644
--- a/examples/cifar10/train.py
+++ b/examples/cifar10/train.py
@@ -96,16 +96,23 @@ def alexnet_lr(epoch):
         return 0.00001
 
 
-def train(data, net, max_epoch, get_lr, weight_decay, batch_size=100):
+def train(data, net, max_epoch, get_lr, weight_decay, batch_size=100,
+          use_cpu=False):
     print 'Start intialization............'
-    cuda = device.create_cuda_gpu()
-    net.to_device(cuda)
+    if use_cpu:
+        print 'Using CPU'
+        dev = device.get_default_device()
+    else:
+        print 'Using GPU'
+        dev = device.create_cuda_gpu()
+
+    net.to_device(dev)
     opt = optimizer.SGD(momentum=0.9, weight_decay=0.004)
     for (p, specs) in zip(net.param_values(), net.param_specs()):
         opt.register(p, specs)
 
-    tx = tensor.Tensor((batch_size, 3, 32, 32), cuda)
-    ty = tensor.Tensor((batch_size,), cuda, core_pb2.kInt)
+    tx = tensor.Tensor((batch_size, 3, 32, 32), dev)
+    ty = tensor.Tensor((batch_size,), dev, core_pb2.kInt)
     train_x, train_y, test_x, test_y = data
     num_train_batch = train_x.shape[0] / batch_size
     num_test_batch = test_x.shape[0] / batch_size
@@ -127,7 +134,7 @@ def train(data, net, max_epoch, get_lr, weight_decay, batch_size=100):
             # update progress bar
             utils.update_progress(b * 1.0 / num_train_batch,
                                   'training loss = %f, accuracy = %f' % (l, a))
-        info = 'training loss = %f, training accuracy = %f' \
+        info = '\ntraining loss = %f, training accuracy = %f' \
             % (loss / num_train_batch, acc / num_train_batch)
         print info
 
@@ -146,9 +153,11 @@ def train(data, net, max_epoch, get_lr, weight_decay, batch_size=100):
     net.save('model.bin')  # save model params into checkpoint file
 
 if __name__ == '__main__':
-    parser = argparse.ArgumentParser(description='Train vgg/alexnet for cifar10')
+    parser = argparse.ArgumentParser(description='Train vgg/alexnet for '
+                                     'cifar10 dataset')
     parser.add_argument('model', choices=['vgg', 'alexnet'], default='alexnet')
     parser.add_argument('data', default='cifar-10-batches-py')
+    parser.add_argument('--use_cpu', action='store_true')
     args = parser.parse_args()
     assert os.path.exists(args.data), \
         'Pls download the cifar10 dataset via "download_data.py py"'
@@ -157,9 +166,11 @@ if __name__ == '__main__':
     test_x, test_y = load_test_data(args.data)
     if args.model == 'alexnet':
         train_x, test_x = normalize_for_alexnet(train_x, test_x)
-        net = alexnet.create_net()
-        train((train_x, train_y, test_x, test_y), net, 140, alexnet_lr, 0.004)
+        net = alexnet.create_net(args.use_cpu)
+        train((train_x, train_y, test_x, test_y), net, 140, alexnet_lr, 0.004,
+              use_cpu=args.use_cpu)
     else:
         train_x, test_x = normalize_for_vgg(train_x, test_x)
-        net = vgg.create_net()
-        train((train_x, train_y, test_x, test_y), net, 250, vgg_lr, 0.0005)
+        net = vgg.create_net(args.use_cpu)
+        train((train_x, train_y, test_x, test_y), net, 250, vgg_lr, 0.0005,
+              use_cpu=args.use_cpu)

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/05720c21/examples/cifar10/vgg-parallel.cc
----------------------------------------------------------------------
diff --git a/examples/cifar10/vgg-parallel.cc b/examples/cifar10/vgg-parallel.cc
index c6b7fa1..149cb21 100644
--- a/examples/cifar10/vgg-parallel.cc
+++ b/examples/cifar10/vgg-parallel.cc
@@ -28,27 +28,20 @@
 #include "singa/utils/channel.h"
 #include "singa/utils/string.h"
 #include "singa/core/memory.h"
-#include "../../src/model/layer/cudnn_convolution.h"
-#include "../../src/model/layer/cudnn_activation.h"
-#include "../../src/model/layer/cudnn_pooling.h"
-#include "../../src/model/layer/cudnn_lrn.h"
-#include "../../src/model/layer/dropout.h"
-#include "../../src/model/layer/cudnn_batchnorm.h"
-#include "../../src/model/layer/dense.h"
-#include "../../src/model/layer/flatten.h"
 #include <thread>
 #include <memory>
 #include <cmath>
 
 namespace singa {
 
+const std::string engine = "cudnn";
 const float default_wd  = 0.0005f;
 
 LayerConf GenConvConf(string name, int nb_filter, int kernel, int stride,
                       int pad, float std = .02f, float bias = .0f) {
   LayerConf conf;
   conf.set_name(name);
-  conf.set_type("CudnnConvolution");
+  conf.set_type(engine + "_convolution");
   ConvolutionConf *conv = conf.mutable_convolution_conf();
   conv->set_num_output(nb_filter);
   conv->add_kernel_size(kernel);
@@ -75,7 +68,7 @@ LayerConf GenPoolingConf(string name, bool max_pool, int kernel, int stride,
                          int pad) {
   LayerConf conf;
   conf.set_name(name);
-  conf.set_type("CudnnPooling");
+  conf.set_type(engine + "_pooling");
   PoolingConf *pool = conf.mutable_pooling_conf();
   pool->set_kernel_size(kernel);
   pool->set_stride(stride);
@@ -87,14 +80,14 @@ LayerConf GenPoolingConf(string name, bool max_pool, int kernel, int stride,
 LayerConf GenReLUConf(string name) {
   LayerConf conf;
   conf.set_name(name);
-  conf.set_type("RELU");
+  conf.set_type(engine + "_relu");
   return conf;
 }
 
 LayerConf GenDenseConf(string name, int num_output, float std, float wd = default_wd) {
   LayerConf conf;
   conf.set_name(name);
-  conf.set_type("Dense");
+  conf.set_type("singa_dense");
   DenseConf *dense = conf.mutable_dense_conf();
   dense->set_num_output(num_output);
 
@@ -116,14 +109,14 @@ LayerConf GenDenseConf(string name, int num_output, float std, float wd = defaul
 LayerConf GenFlattenConf(string name) {
   LayerConf conf;
   conf.set_name(name);
-  conf.set_type("Flatten");
+  conf.set_type("singa_flatten");
   return conf;
 }
 
 LayerConf GenBatchNormConf(string name) {
   LayerConf conf;
   conf.set_name(name);
-  conf.set_type("CudnnBatchNorm");
+  conf.set_type(engine + "_batchnorm");
   ParamSpec *gammaspec = conf.add_param();
   gammaspec->set_name(name + "_gamma");
   auto gammafill = gammaspec->mutable_filler();
@@ -155,7 +148,7 @@ LayerConf GenBatchNormConf(string name) {
 LayerConf GenDropoutConf(string name, float dropout_ratio) {
   LayerConf conf;
   conf.set_name(name);
-  conf.set_type("Dropout");
+  conf.set_type(engine + "_dropout");
   DropoutConf *dropout = conf.mutable_dropout_conf();
   dropout->set_dropout_ratio(dropout_ratio);
 
@@ -163,47 +156,47 @@ LayerConf GenDropoutConf(string name, float dropout_ratio) {
 }
 
 void ConvBNReLU(FeedForwardNet& net, string name, int nb_filter, Shape* shape = nullptr) {
-  net.Add(new CudnnConvolution(), GenConvConf(name+"_conv", nb_filter, 3, 1, 1), shape);
-  net.Add(new CudnnBatchNorm(), GenBatchNormConf(name+"_bn"));
-  net.Add(new CudnnActivation(), GenReLUConf(name+"_relu"));
+  net.Add(GenConvConf(name+"_conv", nb_filter, 3, 1, 1), shape);
+  net.Add(GenBatchNormConf(name+"_bn"));
+  net.Add(GenReLUConf(name+"_relu"));
 }
 
 FeedForwardNet CreateNet() {
   FeedForwardNet net;
   Shape s{3, 32, 32};
   ConvBNReLU(net, "conv1_1", 64, &s);
-  net.Add(new Dropout(), GenDropoutConf("drop1", 0.3));
+  net.Add(GenDropoutConf("drop1", 0.3));
   ConvBNReLU(net, "conv1_2", 64);
-  net.Add(new CudnnPooling(), GenPoolingConf("pool1", true, 2, 2, 0));
+  net.Add(GenPoolingConf("pool1", true, 2, 2, 0));
   ConvBNReLU(net, "conv2_1", 128);
-  net.Add(new Dropout(), GenDropoutConf("drop2", 0.4));
+  net.Add(GenDropoutConf("drop2", 0.4));
   ConvBNReLU(net, "conv2_2", 128);
-  net.Add(new CudnnPooling(), GenPoolingConf("pool2", true, 2, 2, 0));
+  net.Add(GenPoolingConf("pool2", true, 2, 2, 0));
   ConvBNReLU(net, "conv3_1", 256);
-  net.Add(new Dropout(), GenDropoutConf("drop3_1", 0.4));
+  net.Add(GenDropoutConf("drop3_1", 0.4));
   ConvBNReLU(net, "conv3_2", 256);
-  net.Add(new Dropout(), GenDropoutConf("drop3_2", 0.4));
+  net.Add(GenDropoutConf("drop3_2", 0.4));
   ConvBNReLU(net, "conv3_3", 256);
-  net.Add(new CudnnPooling(), GenPoolingConf("pool3", true, 2, 2, 0));
+  net.Add(GenPoolingConf("pool3", true, 2, 2, 0));
   ConvBNReLU(net, "conv4_1", 512);
-  net.Add(new Dropout(), GenDropoutConf("drop4_1", 0.4));
+  net.Add(GenDropoutConf("drop4_1", 0.4));
   ConvBNReLU(net, "conv4_2", 512);
-  net.Add(new Dropout(), GenDropoutConf("drop4_2", 0.4));
+  net.Add(GenDropoutConf("drop4_2", 0.4));
   ConvBNReLU(net, "conv4_3", 512);
-  net.Add(new CudnnPooling(), GenPoolingConf("pool4", true, 2, 2, 0));
+  net.Add(GenPoolingConf("pool4", true, 2, 2, 0));
   ConvBNReLU(net, "conv5_1", 512);
-  net.Add(new Dropout(), GenDropoutConf("drop5_1", 0.4));
+  net.Add(GenDropoutConf("drop5_1", 0.4));
   ConvBNReLU(net, "conv5_2", 512);
-  net.Add(new Dropout(), GenDropoutConf("drop5_2", 0.4));
+  net.Add(GenDropoutConf("drop5_2", 0.4));
   ConvBNReLU(net, "conv5_3", 512);
-  net.Add(new CudnnPooling(), GenPoolingConf("pool5", true, 2, 2, 0));
-  net.Add(new Flatten(), GenFlattenConf("flat"));
-  net.Add(new Dropout(), GenDropoutConf("flat_drop", 0.5));
-  net.Add(new Dense(), GenDenseConf("ip1", 512, 0.02));
-  net.Add(new CudnnBatchNorm(), GenBatchNormConf("ip1_bn"));
-  net.Add(new CudnnActivation(), GenReLUConf("ip1_relu"));
-  net.Add(new Dropout(), GenDropoutConf("ip1_drop", 0.5));
-  net.Add(new Dense(), GenDenseConf("ip2", 10, 0.02));
+  net.Add(GenPoolingConf("pool5", true, 2, 2, 0));
+  net.Add(GenFlattenConf("flat"));
+  net.Add(GenDropoutConf("flat_drop", 0.5));
+  net.Add(GenDenseConf("ip1", 512, 0.02));
+  net.Add(GenBatchNormConf("ip1_bn"));
+  net.Add(GenReLUConf("ip1_relu"));
+  net.Add(GenDropoutConf("ip1_drop", 0.5));
+  net.Add(GenDenseConf("ip2", 10, 0.02));
 
   return net;
 }
@@ -294,17 +287,17 @@ void Train(float lr, int num_epoch, string data_dir) {
   mem_conf.add_device(0);
   mem_conf.add_device(1);
   std::shared_ptr<DeviceMemPool> mem_pool(new CnMemPool(mem_conf));
-  std::shared_ptr<CudaGPU> cuda_1(new CudaGPU(0, mem_pool));
-  std::shared_ptr<CudaGPU> cuda_2(new CudaGPU(1, mem_pool));
-  net_1.ToDevice(cuda_1);
-  net_2.ToDevice(cuda_2);
-
-  train_x_1.ToDevice(cuda_1);
-  train_y_1.ToDevice(cuda_1);
-  test_x.ToDevice(cuda_1);
-  test_y.ToDevice(cuda_1);
-  train_x_2.ToDevice(cuda_2);
-  train_y_2.ToDevice(cuda_2);
+  std::shared_ptr<CudaGPU> dev_1(new CudaGPU(0, mem_pool));
+  std::shared_ptr<CudaGPU> dev_2(new CudaGPU(1, mem_pool));
+  net_1.ToDevice(dev_1);
+  net_2.ToDevice(dev_2);
+
+  train_x_1.ToDevice(dev_1);
+  train_y_1.ToDevice(dev_1);
+  test_x.ToDevice(dev_1);
+  test_y.ToDevice(dev_1);
+  train_x_2.ToDevice(dev_2);
+  train_y_2.ToDevice(dev_2);
 
   LOG(INFO) << "Launching thread...";
   std::thread t1 =

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/05720c21/examples/cifar10/vgg.py
----------------------------------------------------------------------
diff --git a/examples/cifar10/vgg.py b/examples/cifar10/vgg.py
index 0b9bb56..97e690c 100644
--- a/examples/cifar10/vgg.py
+++ b/examples/cifar10/vgg.py
@@ -40,40 +40,42 @@ def ConvBnReLU(net, name, nb_filers, sample_shape=None):
     net.add(layer.Activation(name + '_3'))
 
 
-def create_net():
+def create_net(use_cpu=False):
+    if use_cpu:
+        layer.engine = 'singa'
     net = ffnet.FeedForwardNet(loss.SoftmaxCrossEntropy(), metric.Accuracy())
     ConvBnReLU(net, 'conv1_1', 64, (3, 32, 32))
-    net.add(layer.Dropout('drop1', 0.3, engine='cuda'))
+    net.add(layer.Dropout('drop1', 0.3))
     ConvBnReLU(net, 'conv1_2', 64)
     net.add(layer.MaxPooling2D('pool1', 2, 2, border_mode='valid'))
     ConvBnReLU(net, 'conv2_1', 128)
-    net.add(layer.Dropout('drop2_1', 0.4, engine='cuda'))
+    net.add(layer.Dropout('drop2_1', 0.4))
     ConvBnReLU(net, 'conv2_2', 128)
     net.add(layer.MaxPooling2D('pool2', 2, 2, border_mode='valid'))
     ConvBnReLU(net, 'conv3_1', 256)
-    net.add(layer.Dropout('drop3_1', 0.4, engine='cuda'))
+    net.add(layer.Dropout('drop3_1', 0.4))
     ConvBnReLU(net, 'conv3_2', 256)
-    net.add(layer.Dropout('drop3_2', 0.4, engine='cuda'))
+    net.add(layer.Dropout('drop3_2', 0.4))
     ConvBnReLU(net, 'conv3_3', 256)
     net.add(layer.MaxPooling2D('pool3', 2, 2, border_mode='valid'))
     ConvBnReLU(net, 'conv4_1', 512)
-    net.add(layer.Dropout('drop4_1', 0.4, engine='cuda'))
+    net.add(layer.Dropout('drop4_1', 0.4))
     ConvBnReLU(net, 'conv4_2', 512)
-    net.add(layer.Dropout('drop4_2', 0.4, engine='cuda'))
+    net.add(layer.Dropout('drop4_2', 0.4))
     ConvBnReLU(net, 'conv4_3', 512)
     net.add(layer.MaxPooling2D('pool4', 2, 2, border_mode='valid'))
     ConvBnReLU(net, 'conv5_1', 512)
-    net.add(layer.Dropout('drop5_1', 0.4, engine='cuda'))
+    net.add(layer.Dropout('drop5_1', 0.4))
     ConvBnReLU(net, 'conv5_2', 512)
-    net.add(layer.Dropout('drop5_2', 0.4, engine='cuda'))
+    net.add(layer.Dropout('drop5_2', 0.4))
     ConvBnReLU(net, 'conv5_3', 512)
     net.add(layer.MaxPooling2D('pool5', 2, 2, border_mode='valid'))
     net.add(layer.Flatten('flat'))
-    net.add(layer.Dropout('drop_flat', 0.5, engine='cuda'))
+    net.add(layer.Dropout('drop_flat', 0.5))
     net.add(layer.Dense('ip1', 512))
     net.add(layer.BatchNormalization('batchnorm_ip1'))
     net.add(layer.Activation('relu_ip1'))
-    net.add(layer.Dropout('drop_ip2', 0.5, engine='cuda'))
+    net.add(layer.Dropout('drop_ip2', 0.5))
     net.add(layer.Dense('ip2', 10))
     print 'Start intialization............'
     for (p, name) in zip(net.param_values(), net.param_names()):

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/05720c21/examples/imagenet/alexnet.cc
----------------------------------------------------------------------
diff --git a/examples/imagenet/alexnet.cc b/examples/imagenet/alexnet.cc
index 270312c..3fb5d04 100644
--- a/examples/imagenet/alexnet.cc
+++ b/examples/imagenet/alexnet.cc
@@ -22,13 +22,6 @@
 #include "singa/singa_config.h"
 #ifdef USE_OPENCV
 #include <cmath>
-#include "../../src/model/layer/cudnn_activation.h"
-#include "../../src/model/layer/cudnn_convolution.h"
-#include "../../src/model/layer/dropout.h"
-#include "../../src/model/layer/cudnn_lrn.h"
-#include "../../src/model/layer/cudnn_pooling.h"
-#include "../../src/model/layer/dense.h"
-#include "../../src/model/layer/flatten.h"
 #include "./ilsvrc12.h"
 #include "singa/io/snapshot.h"
 #include "singa/model/feed_forward_net.h"
@@ -40,11 +33,12 @@
 #include "singa/utils/timer.h"
 namespace singa {
 
+const std::string engine = "cudnn";
 LayerConf GenConvConf(string name, int nb_filter, int kernel, int stride,
                       int pad, float std, float bias = .0f) {
   LayerConf conf;
   conf.set_name(name);
-  conf.set_type("CudnnConvolution");
+  conf.set_type(engine + "_convolution");
   ConvolutionConf *conv = conf.mutable_convolution_conf();
   conv->set_num_output(nb_filter);
   conv->add_kernel_size(kernel);
@@ -71,7 +65,7 @@ LayerConf GenPoolingConf(string name, bool max_pool, int kernel, int stride,
                          int pad) {
   LayerConf conf;
   conf.set_name(name);
-  conf.set_type("CudnnPooling");
+  conf.set_type(engine + "_pooling");
   PoolingConf *pool = conf.mutable_pooling_conf();
   pool->set_kernel_size(kernel);
   pool->set_stride(stride);
@@ -83,7 +77,7 @@ LayerConf GenPoolingConf(string name, bool max_pool, int kernel, int stride,
 LayerConf GenReLUConf(string name) {
   LayerConf conf;
   conf.set_name(name);
-  conf.set_type("RELU");
+  conf.set_type(engine + "_relu");
   return conf;
 }
 
@@ -91,7 +85,7 @@ LayerConf GenDenseConf(string name, int num_output, float std, float wd,
                        float bias = .0f) {
   LayerConf conf;
   conf.set_name(name);
-  conf.set_type("Dense");
+  conf.set_type("singa_dense");
   DenseConf *dense = conf.mutable_dense_conf();
   dense->set_num_output(num_output);
 
@@ -115,7 +109,7 @@ LayerConf GenDenseConf(string name, int num_output, float std, float wd,
 LayerConf GenLRNConf(string name) {
   LayerConf conf;
   conf.set_name(name);
-  conf.set_type("CudnnLRN");
+  conf.set_type(engine + "_lrn");
   LRNConf *lrn = conf.mutable_lrn_conf();
   lrn->set_local_size(5);
   lrn->set_alpha(1e-04);
@@ -126,14 +120,14 @@ LayerConf GenLRNConf(string name) {
 LayerConf GenFlattenConf(string name) {
   LayerConf conf;
   conf.set_name(name);
-  conf.set_type("Flatten");
+  conf.set_type("singa_flatten");
   return conf;
 }
 
 LayerConf GenDropoutConf(string name, float dropout_ratio) {
   LayerConf conf;
   conf.set_name(name);
-  conf.set_type("Dropout");
+  conf.set_type(engine + "_dropout");
   DropoutConf *dropout = conf.mutable_dropout_conf();
   dropout->set_dropout_ratio(dropout_ratio);
   return conf;

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/05720c21/include/singa/core/device.h
----------------------------------------------------------------------
diff --git a/include/singa/core/device.h b/include/singa/core/device.h
index 778a130..4c46114 100644
--- a/include/singa/core/device.h
+++ b/include/singa/core/device.h
@@ -321,23 +321,33 @@ public:
   /// Return a string containing all hardware info, e.g., version, memory size.
   static const std::string DeviceQuery(int id, bool verbose = false);
 
+  /// Return the default host device
+  static std::shared_ptr<Device> GetDefaultDevice() {
+    return defaultDevice;
+  }
+
   /// Create a set of CudaGPU Device using 'num_devices' free GPUs.
   static const std::vector<std::shared_ptr<Device>>
   CreateCudaGPUs(const size_t num_devices, size_t init_size = 0);
 
   /// Create a set of CudaGPU Device using given GPU IDs.
   static const std::vector<std::shared_ptr<Device>>
-  CreateCudaGPUs(const std::vector<int> &devices, size_t init_size = 0);
-
-  /// Create a \p num_devices set of valid OpenCL devices, regardless of platforms.
-  /// If there are fewer valid devices than requested, then this method will return as many as possible.
-  /// If OpenCL is not in use, this method will return an empty array.
-  const std::vector<std::shared_ptr<Device>> CreateOpenclDevices(const size_t num_devices);
-
-  /// Create a set of valid OpenCL devices, regardless of platforms, assigning \p id to each device in sequence.
-  /// If there are fewer valid devices than requested, then this method will return as many as possible.
+  CreateCudaGPUsOn(const std::vector<int> &devices, size_t init_size = 0);
+
+  /// Create a \p num_devices set of valid OpenCL devices, regardless of
+  /// platforms.  If there are fewer valid devices than requested, then this
+  /// method will return as many as possible. If OpenCL is not in use, this
+  /// method will return an empty array.
+  const std::vector<std::shared_ptr<Device> > CreateOpenclDevices(
+             const size_t num_devices);
+
+  /// Create a set of valid OpenCL devices, regardless of platforms, assigning
+  /// \p id to each device in sequence.
+  /// If there are fewer valid devices than requested, then this method will
+  /// return as many as possible.
   /// If OpenCL is not in use, this method will return an empty array.
-  const std::vector<std::shared_ptr<Device>> CreateOpenclDevices(const vector<int>& id);
+  const std::vector<std::shared_ptr<Device> >
+  CreateOpenclDevices(const vector<int> &id);
 
   /// This function is implementd by Caffe (http://caffe.berkeleyvision.org/).
   /// This function checks the availability of GPU #device_id.

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/05720c21/include/singa/model/feed_forward_net.h
----------------------------------------------------------------------
diff --git a/include/singa/model/feed_forward_net.h b/include/singa/model/feed_forward_net.h
index 8adc259..1bf112c 100644
--- a/include/singa/model/feed_forward_net.h
+++ b/include/singa/model/feed_forward_net.h
@@ -39,7 +39,7 @@ class FeedForwardNet {
   ///    following the topological order.
   /// 2. this layer has already been setup (Setup function is called outside).
   /// The layer will be freed in the destructor of FeedForwardNet.
-  Layer* Add(Layer* layer);
+  std::shared_ptr<Layer> Add(std::shared_ptr<Layer> layer);
 
   // TODO(wangwei) add ConcatenateLayer and SliceLayer
   // AddConcatenateLayer(vector<Layer*> src, Layer *dst);
@@ -49,11 +49,9 @@ class FeedForwardNet {
   /// Assume the layer is added in corret order.
   /// For the first layer, 'sample_shape' (the input sample shape) is necessary
   /// for calling Setup().
-  Layer* Add(const LayerConf& conf, const Shape* sample_shape = nullptr);
+  std::shared_ptr<Layer> Add(const LayerConf& conf,
+      const Shape* sample_shape = nullptr);
 
-  /// Add a layer, and call its Setup function.
-  Layer* Add(Layer* layer, const LayerConf& conf,
-             const Shape* sample_shape = nullptr);
   /// Set some fields used for training and evaluating the neural net.
   /// This method will instantiate an Updater ,then wrap the Optimier into
   /// Updater and always register the parameters of the net instance.
@@ -147,13 +145,13 @@ class FeedForwardNet {
     return std::thread([=]() { Train(batchsize, nb_epoch, x, y); });
   }
 
-  const vector<Layer*> layers() const { return layers_; }
+  const vector<std::shared_ptr<Layer>> layers() const { return layers_; }
   const vector<string> GetParamNames() const;
   const vector<ParamSpec> GetParamSpecs() const;
   const vector<Tensor> GetParamValues() const;
 
  protected:
-  vector<Layer*> layers_;
+  vector<std::shared_ptr<Layer>> layers_;
   std::shared_ptr<Updater> updater_;
   Loss* loss_;
   Metric* metric_;
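
For reference, a minimal sketch (not from the commit) of the new conf-only Add() in use, assuming the GenConvConf/GenReLUConf helpers from the cifar10 examples above and their usual includes. The net now creates and owns each layer via shared_ptr, so callers pass only the LayerConf and, for the first layer, the input sample shape:

    namespace singa {
    FeedForwardNet CreateTinyNet() {
      FeedForwardNet net;
      Shape s{3, 32, 32};
      // Add() instantiates the layer from conf.type() and stores a shared_ptr,
      // so there is no explicit `new` and no manual delete in a destructor.
      net.Add(GenConvConf("conv1", 32, 5, 1, 2, 0.0001), &s);
      net.Add(GenReLUConf("relu1"));  // later layers infer their input shape
      return net;
    }
    }  // namespace singa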

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/05720c21/include/singa/model/layer.h
----------------------------------------------------------------------
diff --git a/include/singa/model/layer.h b/include/singa/model/layer.h
index d31bd95..58f0f4b 100644
--- a/include/singa/model/layer.h
+++ b/include/singa/model/layer.h
@@ -222,8 +222,8 @@ class Layer {
   vector<ParamSpec> param_specs_;
 };
 
-#define RegisterLayerClass(SubLayer) \
-  static Registra<Layer, SubLayer> _##SubLayer##Layer(#SubLayer);
+#define RegisterLayerClass(Name, SubLayer) \
+  static Registra<Layer, SubLayer> Name##SubLayer(#Name);
 
 inline std::shared_ptr<Layer> CreateLayer(const std::string type) {
   std::shared_ptr<Layer> layer(Factory<Layer>::Create(type));
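
A small usage sketch of the renamed registration macro (again, an illustration rather than commit content, and assuming a build with cuDNN enabled so the cudnn_* registrations are compiled in): the factory is now keyed by the lower-case, engine-prefixed identifier instead of the class name.

    #include "singa/model/layer.h"

    int main() {
      // Both identifiers were registered via RegisterLayerClass(Name, SubLayer).
      auto gpu_relu = singa::CreateLayer("cudnn_relu");   // -> CudnnActivation
      auto cpu_relu = singa::CreateLayer("singa_relu");   // -> Activation (CPU)
      return (gpu_relu != nullptr && cpu_relu != nullptr) ? 0 : 1;
    }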

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/05720c21/include/singa/utils/integer.h
----------------------------------------------------------------------
diff --git a/include/singa/utils/integer.h b/include/singa/utils/integer.h
new file mode 100644
index 0000000..9c2799d
--- /dev/null
+++ b/include/singa/utils/integer.h
@@ -0,0 +1,73 @@
+/************************************************************
+ *
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ *
+ *************************************************************/
+
+#ifndef INTEGER_H_
+#define INTEGER_H_
+
+#include <cstdint>
+
+namespace singa{
+static bool isNetworkOrder() {
+    int test = 1;
+    return (1 != *(uint8_t*)&test);
+}
+
+template <typename T>
+static inline T byteSwap(const T& v) {
+    int size = sizeof(v);
+    T ret;
+    uint8_t *dest = reinterpret_cast<uint8_t *>(&ret);
+    uint8_t *src = const_cast<uint8_t*>(reinterpret_cast<const uint8_t*>(&v));
+    for (int i = 0; i < size; ++i) {
+        dest[i] = src[size - i - 1];
+    }
+    return ret;
+}
+
+template <typename T>
+static inline T hton(const T& v)
+{
+    return isNetworkOrder() ? v : byteSwap(v);
+}
+
+template <typename T>
+static inline T ntoh(const T& v) 
+{
+    return hton(v);
+}
+
+static inline int appendInteger(char* buf) {return 0;}
+static inline int readInteger(char* buf) {return 0;}
+
+template<typename Type, typename... Types>
+static int appendInteger(char* buf, Type value, Types... values) {
+    *(Type*)buf = hton(value);
+    return sizeof(Type) + appendInteger(buf + sizeof(Type), values...);
+}
+
+template<typename Type, typename... Types>
+static int readInteger(char* buf, Type& value, Types&... values) {
+    value = ntoh(*(Type*)buf);
+    return sizeof(Type) + readInteger(buf + sizeof(Type), values...);
+}
+
+}
+#endif
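
A quick usage sketch for the new header (my own example, not from the commit): appendInteger packs its arguments into the buffer in network byte order and returns the number of bytes written, and readInteger reverses the operation.

    #include <cassert>
    #include <cstdint>
    #include "singa/utils/integer.h"

    int main() {
      char buf[16];
      // Pack two 32-bit integers, then unpack them; both calls return 8 bytes.
      int written = singa::appendInteger(buf, int32_t(7), int32_t(42));
      int32_t a = 0, b = 0;
      int read = singa::readInteger(buf, a, b);
      assert(written == 8 && read == 8 && a == 7 && b == 42);
      return 0;
    }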

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/05720c21/src/core/device/platform.cc
----------------------------------------------------------------------
diff --git a/src/core/device/platform.cc b/src/core/device/platform.cc
index a4561de..a3661f2 100644
--- a/src/core/device/platform.cc
+++ b/src/core/device/platform.cc
@@ -113,11 +113,11 @@ Platform::CreateCudaGPUs(const size_t num_devices, size_t init_size) {
   const vector<int> gpus = GetGPUIDs();
   CHECK_LE(num_devices, gpus.size());
   vector<int> use_gpus(gpus.begin(), gpus.begin() + num_devices);
-  return CreateCudaGPUs(use_gpus, init_size);
+  return CreateCudaGPUsOn(use_gpus, init_size);
 }
 
 const vector<shared_ptr<Device> >
-Platform::CreateCudaGPUs(const vector<int> &devices, size_t init_size) {
+Platform::CreateCudaGPUsOn(const vector<int> &devices, size_t init_size) {
   MemPoolConf conf;
   if (init_size > 0)
     conf.set_init_size(init_size);
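
Usage of the renamed call, for reference (a sketch assuming a machine with GPUs 0 and 1 visible): CreateCudaGPUs(n) still picks the first n free GPUs, while CreateCudaGPUsOn takes explicit device IDs.

    #include "singa/core/device.h"

    int main() {
      // Explicit device IDs; the pool init size keeps its default value.
      auto devs = singa::Platform::CreateCudaGPUsOn({0, 1});
      return devs.size() == 2 ? 0 : 1;
    }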

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/05720c21/src/core/tensor/tensor.cc
----------------------------------------------------------------------
diff --git a/src/core/tensor/tensor.cc b/src/core/tensor/tensor.cc
index e260f9e..dfb1eb2 100644
--- a/src/core/tensor/tensor.cc
+++ b/src/core/tensor/tensor.cc
@@ -452,7 +452,7 @@ float Tensor::L1() const {
   float nrm = 0.0f;
   TYPE_LANG_SWITCH(data_type_, DType, device_->lang(), Lang, {
     device_->Exec([&nrm, this](Context *ctx) {
-      DType ret;
+      DType ret = DType(0);
       Asum<DType, Lang>(this->Size(), this->block(), &ret, ctx);
       nrm = TypeCast<DType, float>(ret);
     }, {this->block()}, {});
@@ -465,7 +465,7 @@ float Tensor::L2() const {
   float nrm = 0.0f;
   TYPE_LANG_SWITCH(data_type_, DType, device_->lang(), Lang, {
     device_->Exec([&nrm, this](Context *ctx) {
-      DType ret;
+      DType ret = DType(0);
       Nrm2<DType, Lang>(this->Size(), this->block(), &ret, ctx);
       nrm = TypeCast<DType, float>(ret);
     }, {this->block()}, {});

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/05720c21/src/core/tensor/tensor_math_cpp.h
----------------------------------------------------------------------
diff --git a/src/core/tensor/tensor_math_cpp.h b/src/core/tensor/tensor_math_cpp.h
index 941931d..a2802d5 100644
--- a/src/core/tensor/tensor_math_cpp.h
+++ b/src/core/tensor/tensor_math_cpp.h
@@ -239,7 +239,7 @@ void Sqrt<float, lang::Cpp>(const size_t num, const Block *in, Block *out,
   float *outPtr = static_cast<float *>(out->mutable_data());
   const float *inPtr = static_cast<const float *>(in->data());
   for (size_t i = 0; i < num; i++) {
-    CHECK_GT(inPtr[i], 0.f);
+    CHECK_GE(inPtr[i], 0.f);
     outPtr[i] = sqrt(inPtr[i]);
   }
 }

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/05720c21/src/model/feed_forward_net.cc
----------------------------------------------------------------------
diff --git a/src/model/feed_forward_net.cc b/src/model/feed_forward_net.cc
index 9450c9e..514d6e2 100644
--- a/src/model/feed_forward_net.cc
+++ b/src/model/feed_forward_net.cc
@@ -26,23 +26,16 @@
 namespace singa {
 
 FeedForwardNet::~FeedForwardNet() {
-  for (auto layer : layers_) delete layer;
-}
-Layer* FeedForwardNet::Add(Layer* layer) {
-  layers_.push_back(layer);
-  return layer;
 }
 
-Layer* FeedForwardNet::Add(const LayerConf& conf, const Shape* sample_shape) {
-  CHECK(sample_shape != nullptr || layers_.size())
-      << "Must provide the input sample shape for the first layer";
-  Layer* layer = nullptr;  // TODO(wangwei) use CreateLayer(conf.type());
-  Add(layer, conf, sample_shape);
+std::shared_ptr<Layer> FeedForwardNet::Add(std::shared_ptr<Layer> layer) {
+  layers_.push_back(layer);
   return layer;
 }
 
-Layer* FeedForwardNet::Add(Layer* layer, const LayerConf& conf,
-                           const Shape* sample_shape) {
+std::shared_ptr<Layer> FeedForwardNet::Add(const LayerConf& conf,
+    const Shape* sample_shape) {
+  std::shared_ptr<Layer> layer(CreateLayer(conf.type()));
   CHECK(conf.has_name()) << "Must set layer name";
   if (sample_shape == nullptr)
     layer->Setup(layers_.back()->GetOutputSampleShape(), conf);

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/05720c21/src/model/layer/activation.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/activation.cc b/src/model/layer/activation.cc
index 2497c31..aa40edb 100644
--- a/src/model/layer/activation.cc
+++ b/src/model/layer/activation.cc
@@ -18,14 +18,23 @@
 
 #include "singa/model/layer.h"
 #include "./activation.h"
+#include "singa/utils/string.h"
 namespace singa {
 
-RegisterLayerClass(Activation);
+RegisterLayerClass(singa_relu, Activation);
+RegisterLayerClass(singa_sigmoid, Activation);
+RegisterLayerClass(singa_tanh, Activation);
 
 void Activation::Setup(const Shape& in_sample, const LayerConf& conf) {
   Layer::Setup(in_sample, conf);
-  mode_ = conf.type();
-  if (mode_ == "RELU") {
+  auto pos = conf.type().find_first_of('_');
+  CHECK_NE(pos, string::npos) << "There should be a '_' in the layer type "
+    << conf.type();
+  mode_ = ToLowerCase(conf.type().substr(pos + 1));
+  if (mode_ != "relu" && mode_ != "sigmoid" && mode_ != "tanh")
+    LOG(FATAL) << "Unknown activation type: " << conf.type() << " " << mode_
+      << ". Please use singa_relu, singa_sigmoid, or singa_tanh";
+  if (mode_ == "relu") {
     neg_slope_ = conf.relu_conf().negative_slope();
   }
   out_sample_shape_ = in_sample;
@@ -33,13 +42,13 @@ void Activation::Setup(const Shape& in_sample, const LayerConf& conf) {
 
 const Tensor Activation::Forward(int flag, const Tensor& input) {
   Tensor output;
-  if (mode_ == "SIGMOID") {
+  if (mode_ == "sigmoid") {
     output = Sigmoid(input);
     if (flag & kTrain) buf_.push(output);
-  } else if (mode_ == "TANH") {
+  } else if (mode_ == "tanh") {
     output = Tanh(input);
     if (flag & kTrain) buf_.push(output);
-  } else if (mode_ == "RELU") {
+  } else if (mode_ == "relu") {
     output = ReLU(input);
     if (flag & kTrain) buf_.push(input);
   } else
@@ -55,11 +64,11 @@ const std::pair<Tensor, vector<Tensor>> Activation::Backward(
   // activation.
   Tensor input_grad, inout = buf_.top();
   buf_.pop();
-  if (mode_ == "SIGMOID")
+  if (mode_ == "sigmoid")
     input_grad = grad * inout * (inout * (-1.f) + 1.f);
-  else if (mode_ == "TANH")
+  else if (mode_ == "tanh")
     input_grad = grad * (inout * inout * (-1.f) + 1.f);
-  else if (mode_ == "RELU")
+  else if (mode_ == "relu")
     input_grad = grad * (inout > 0.f) + (inout <= 0.f) * neg_slope_;
   else LOG(FATAL) << "Unkown activation: " << mode_;
   return std::make_pair(input_grad, param_grad);

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/05720c21/src/model/layer/activation.h
----------------------------------------------------------------------
diff --git a/src/model/layer/activation.h b/src/model/layer/activation.h
index e3fb657..7d15979 100644
--- a/src/model/layer/activation.h
+++ b/src/model/layer/activation.h
@@ -26,7 +26,7 @@ namespace singa {
 class Activation : public Layer {
  public:
   /// \copydoc Layer::layer_type()
-  const std::string layer_type() const override { return "Activation"; }
+  // const std::string layer_type() const override { return "Activation"; }
 
   /// \copydoc Layer::Setup(const LayerConf&);
   void Setup(const Shape& in_sample, const LayerConf& conf) override;

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/05720c21/src/model/layer/batchnorm.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/batchnorm.cc b/src/model/layer/batchnorm.cc
index 6ea9f2a..f348661 100644
--- a/src/model/layer/batchnorm.cc
+++ b/src/model/layer/batchnorm.cc
@@ -21,7 +21,7 @@
 #include "batchnorm.h"
 
 namespace singa {
-RegisterLayerClass(BatchNorm);
+RegisterLayerClass(singa_batchnorm, BatchNorm);
 void BatchNorm::Setup(const Shape& in_sample, const LayerConf& conf) {
   Layer::Setup(in_sample, conf);
   out_sample_shape_ = in_sample;
@@ -78,8 +78,8 @@ const Tensor BatchNorm::Forward(int flag, const Tensor& input) {
     runningVariance_ *= 1.0f - factor_;
     Axpy(factor_, var, &runningVariance_);
     Tensor tmp = var.Clone();
-    tmp += 1e-6f;
     tmp = Sqrt(tmp);
+    tmp += 1e-6f;
     xnorm = x.Clone();
     SubRow(mean, &xnorm);
     DivRow(tmp, &xnorm);
@@ -94,8 +94,8 @@ const Tensor BatchNorm::Forward(int flag, const Tensor& input) {
     xnorm = x.Clone();
     SubRow(runningMean_, &xnorm);
     Tensor tmp = runningVariance_.Clone();
-    tmp += 1e-6f;
     tmp = Sqrt(tmp);
+    tmp += 1e-6f;
     DivRow(tmp, &xnorm);
     output = xnorm.Clone();
     MultRow(bnScale_, &output);

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/05720c21/src/model/layer/batchnorm.h
----------------------------------------------------------------------
diff --git a/src/model/layer/batchnorm.h b/src/model/layer/batchnorm.h
index f3d83ab..c2cfde9 100644
--- a/src/model/layer/batchnorm.h
+++ b/src/model/layer/batchnorm.h
@@ -29,7 +29,7 @@ namespace singa {
 class BatchNorm : public Layer {
  public:
   /// \copydoc Layer::layer_type()
-  const std::string layer_type() const override { return "BatchNorm"; }
+  // const std::string layer_type() const override { return "BatchNorm"; }
 
   /// \copydoc Layer::Setup(const LayerConf&);
   void Setup(const Shape& in_sample, const LayerConf& conf) override;

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/05720c21/src/model/layer/convolution.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/convolution.cc b/src/model/layer/convolution.cc
index 1bf6b39..4fc209f 100644
--- a/src/model/layer/convolution.cc
+++ b/src/model/layer/convolution.cc
@@ -23,7 +23,7 @@
 namespace singa {
 using std::vector;
 
-RegisterLayerClass(Convolution);
+RegisterLayerClass(singa_convolution, Convolution);
 void Convolution::Setup(const Shape &in_sample, const LayerConf &conf) {
   Layer::Setup(in_sample, conf);
   ConvolutionConf conv_conf = conf.convolution_conf();

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/05720c21/src/model/layer/convolution.h
----------------------------------------------------------------------
diff --git a/src/model/layer/convolution.h b/src/model/layer/convolution.h
index 1383a66..d85a17b 100644
--- a/src/model/layer/convolution.h
+++ b/src/model/layer/convolution.h
@@ -27,7 +27,7 @@ namespace singa {
 class Convolution : public Layer {
  public:
   /// \copydoc Layer::layer_type()
-  const std::string layer_type() const override { return "Convolution"; }
+  // const std::string layer_type() const override { return "Convolution"; }
 
   /// \copydoc Layer::Setup(const LayerConf&);
   void Setup(const vector<size_t>& in_shape, const LayerConf& conf) override;

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/05720c21/src/model/layer/cudnn_activation.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/cudnn_activation.cc b/src/model/layer/cudnn_activation.cc
index c86539d..4ecb375 100644
--- a/src/model/layer/cudnn_activation.cc
+++ b/src/model/layer/cudnn_activation.cc
@@ -25,7 +25,9 @@
 #include "singa/utils/logging.h"
 
 namespace singa {
-RegisterLayerClass(CudnnActivation);
+RegisterLayerClass(cudnn_relu, CudnnActivation);
+RegisterLayerClass(cudnn_sigmoid, CudnnActivation);
+RegisterLayerClass(cudnn_tanh, CudnnActivation);
 CudnnActivation::~CudnnActivation() {
   if (acti_desc_ != nullptr)
     CUDNN_CHECK(cudnnDestroyActivationDescriptor(acti_desc_));
@@ -40,11 +42,11 @@ void CudnnActivation::InitCudnn(size_t size, DataType dtype) {
   CUDNN_CHECK(cudnnSetTensor4dDescriptor(
       desc_, CUDNN_TENSOR_NCHW, GetCudnnDataType(dtype), 1, 1, 1, size));
 
-  if (mode_ == "SIGMOID")
+  if (mode_ == "sigmoid")
     cudnn_mode_ = CUDNN_ACTIVATION_SIGMOID;
-  else if (mode_ == "TANH")
+  else if (mode_ == "tanh")
     cudnn_mode_ = CUDNN_ACTIVATION_TANH;
-  else if (mode_ == "RELU")
+  else if (mode_ == "relu")
     cudnn_mode_ = CUDNN_ACTIVATION_RELU;
   else
     LOG(FATAL) << "Unkown activation: " << mode_;

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/05720c21/src/model/layer/cudnn_activation.h
----------------------------------------------------------------------
diff --git a/src/model/layer/cudnn_activation.h b/src/model/layer/cudnn_activation.h
index 526e03f..c69d157 100644
--- a/src/model/layer/cudnn_activation.h
+++ b/src/model/layer/cudnn_activation.h
@@ -35,7 +35,7 @@ class CudnnActivation : public Activation {
  public:
   ~CudnnActivation();
   /// \copydoc Layer::layer_type()
-  const std::string layer_type() const override { return "CudnnActivation"; }
+  // const std::string layer_type() const override { return "CudnnActivation"; }
 
   const Tensor Forward(int flag, const Tensor& input) override;
   const std::pair<Tensor, vector<Tensor>> Backward(int flag,

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/05720c21/src/model/layer/cudnn_batchnorm.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/cudnn_batchnorm.cc b/src/model/layer/cudnn_batchnorm.cc
index 461f1b6..01682b7 100644
--- a/src/model/layer/cudnn_batchnorm.cc
+++ b/src/model/layer/cudnn_batchnorm.cc
@@ -23,7 +23,7 @@
 
 namespace singa {
 
-RegisterLayerClass(CudnnBatchNorm);
+RegisterLayerClass(cudnn_batchnorm, CudnnBatchNorm);
 CudnnBatchNorm::~CudnnBatchNorm() {
   if (has_init_cudnn_) {
     CUDNN_CHECK(cudnnDestroyTensorDescriptor(shape_desc_));

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/05720c21/src/model/layer/cudnn_batchnorm.h
----------------------------------------------------------------------
diff --git a/src/model/layer/cudnn_batchnorm.h b/src/model/layer/cudnn_batchnorm.h
index 4f46452..c4390a1 100644
--- a/src/model/layer/cudnn_batchnorm.h
+++ b/src/model/layer/cudnn_batchnorm.h
@@ -31,7 +31,7 @@ class CudnnBatchNorm : public BatchNorm {
  public:
   ~CudnnBatchNorm();
   /// \copy doc Layer::layer_type()
-  const std::string layer_type() const override { return "CudnnBatchNorm"; }
+  // const std::string layer_type() const override { return "CudnnBatchNorm"; }
 
   void Setup(const Shape& in_sample, const LayerConf& conf) override;
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/05720c21/src/model/layer/cudnn_convolution.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/cudnn_convolution.cc b/src/model/layer/cudnn_convolution.cc
index e5efec0..ffd2ab7 100644
--- a/src/model/layer/cudnn_convolution.cc
+++ b/src/model/layer/cudnn_convolution.cc
@@ -23,7 +23,7 @@
 #include "singa/utils/logging.h"
 
 namespace singa {
-RegisterLayerClass(CudnnConvolution);
+RegisterLayerClass(cudnn_convolution, CudnnConvolution);
 CudnnConvolution::~CudnnConvolution() {
   if (bias_desc_ != nullptr)
     CUDNN_CHECK(cudnnDestroyTensorDescriptor(bias_desc_));

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/05720c21/src/model/layer/cudnn_convolution.h
----------------------------------------------------------------------
diff --git a/src/model/layer/cudnn_convolution.h b/src/model/layer/cudnn_convolution.h
index cd0471f..545fd5c 100644
--- a/src/model/layer/cudnn_convolution.h
+++ b/src/model/layer/cudnn_convolution.h
@@ -34,7 +34,7 @@ class CudnnConvolution : public Convolution {
  public:
   ~CudnnConvolution();
   /// \copydoc Layer::layer_type()
-  const std::string layer_type() const override { return "CudnnConvolution"; }
+  // const std::string layer_type() const override { return "CudnnConvolution";}
 
   const Tensor Forward(int flag, const Tensor &input) override;
   const std::pair<Tensor, vector<Tensor>> Backward(int flag,

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/05720c21/src/model/layer/cudnn_dropout.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/cudnn_dropout.cc b/src/model/layer/cudnn_dropout.cc
index e6950ca..c5b62cf 100644
--- a/src/model/layer/cudnn_dropout.cc
+++ b/src/model/layer/cudnn_dropout.cc
@@ -27,7 +27,7 @@
 #include "singa/utils/logging.h"
 
 namespace singa {
-RegisterLayerClass(CudnnDropout);
+RegisterLayerClass(cudnn_dropout, CudnnDropout);
 CudnnDropout::~CudnnDropout() {
   if (drop_desc_ != nullptr)
     CUDNN_CHECK(cudnnDestroyDropoutDescriptor(drop_desc_));

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/05720c21/src/model/layer/cudnn_dropout.h
----------------------------------------------------------------------
diff --git a/src/model/layer/cudnn_dropout.h b/src/model/layer/cudnn_dropout.h
index 9e0cb9e..1241911 100644
--- a/src/model/layer/cudnn_dropout.h
+++ b/src/model/layer/cudnn_dropout.h
@@ -36,7 +36,7 @@ class CudnnDropout : public Dropout {
  public:
   ~CudnnDropout();
   /// \copydoc Layer::layer_type()
-  const std::string layer_type() const override { return "CudnnDropout"; }
+  // const std::string layer_type() const override { return "CudnnDropout"; }
 
   const Tensor Forward(int flag, const Tensor& input) override;
   const std::pair<Tensor, vector<Tensor>> Backward(int flag,

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/05720c21/src/model/layer/cudnn_lrn.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/cudnn_lrn.cc b/src/model/layer/cudnn_lrn.cc
index 540beb1..ac7645e 100644
--- a/src/model/layer/cudnn_lrn.cc
+++ b/src/model/layer/cudnn_lrn.cc
@@ -23,7 +23,7 @@
 #include "cudnn_utils.h"
 
 namespace singa {
-RegisterLayerClass(CudnnLRN);
+RegisterLayerClass(cudnn_lrn, CudnnLRN);
 CudnnLRN::~CudnnLRN() {
   if (has_init_cudnn_) {
     CUDNN_CHECK(cudnnDestroyLRNDescriptor(lrn_desc_));

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/05720c21/src/model/layer/cudnn_lrn.h
----------------------------------------------------------------------
diff --git a/src/model/layer/cudnn_lrn.h b/src/model/layer/cudnn_lrn.h
index e2a5e54..c48571d 100644
--- a/src/model/layer/cudnn_lrn.h
+++ b/src/model/layer/cudnn_lrn.h
@@ -31,7 +31,7 @@ class CudnnLRN : public LRN {
  public:
   ~CudnnLRN();
   /// \copy doc Layer::layer_type()
-  const std::string layer_type() const override { return "CudnnLRN"; }
+  // const std::string layer_type() const override { return "CudnnLRN"; }
 
   const Tensor Forward(int flag, const Tensor& input) override;
   const std::pair<Tensor, vector<Tensor>> Backward(int flag,

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/05720c21/src/model/layer/cudnn_pooling.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/cudnn_pooling.cc b/src/model/layer/cudnn_pooling.cc
index 984427c..895ce3c 100644
--- a/src/model/layer/cudnn_pooling.cc
+++ b/src/model/layer/cudnn_pooling.cc
@@ -25,7 +25,7 @@
 #include "singa/utils/logging.h"
 
 namespace singa {
-RegisterLayerClass(CudnnPooling);
+RegisterLayerClass(cudnn_pooling, CudnnPooling);
 CudnnPooling::~CudnnPooling() {
   if (pool_desc_ != nullptr)
     CUDNN_CHECK(cudnnDestroyPoolingDescriptor(pool_desc_));

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/05720c21/src/model/layer/cudnn_pooling.h
----------------------------------------------------------------------
diff --git a/src/model/layer/cudnn_pooling.h b/src/model/layer/cudnn_pooling.h
index 90779f5..2080db3 100644
--- a/src/model/layer/cudnn_pooling.h
+++ b/src/model/layer/cudnn_pooling.h
@@ -35,7 +35,7 @@ class CudnnPooling : public Pooling {
  public:
   ~CudnnPooling();
   /// \copydoc Layer::layer_type()
-  const std::string layer_type() const override { return "CudnnPooling"; }
+  // const std::string layer_type() const override { return "CudnnPooling"; }
 
   void Setup(const Shape& in_sample, const LayerConf &conf) override;
   const Tensor Forward(int flag, const Tensor &input) override;

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/05720c21/src/model/layer/cudnn_rnn.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/cudnn_rnn.cc b/src/model/layer/cudnn_rnn.cc
index bfbfa48..9961df2 100644
--- a/src/model/layer/cudnn_rnn.cc
+++ b/src/model/layer/cudnn_rnn.cc
@@ -24,6 +24,7 @@
 #include "singa/utils/logging.h"
 
 namespace singa {
+RegisterLayerClass(cudnn_rnn, CudnnRNN);
 CudnnRNN::~CudnnRNN() {
   if (weight_desc_ != nullptr)
     CUDNN_CHECK(cudnnDestroyFilterDescriptor(weight_desc_));
@@ -126,25 +127,19 @@ void CudnnRNN::SetRNNDescriptor(shared_ptr<Device> dev) {
       dropout_state_.block()->mutable_data(), state_size, seed_));
 
   CUDNN_CHECK(cudnnCreateRNNDescriptor(&rnn_desc_));
-  cudnnRNNInputMode_t input_mode;
-  if (input_mode_ == "linear")
-    input_mode = CUDNN_LINEAR_INPUT;
-  else if (input_mode_ == "skip")
+  cudnnRNNInputMode_t input_mode = CUDNN_LINEAR_INPUT;
+  if (input_mode_ == "skip")
     input_mode = CUDNN_SKIP_INPUT;
 
-  cudnnDirectionMode_t direction;
-  if (direction_ == "unidirectional")
-    direction = CUDNN_UNIDIRECTIONAL;
-  else if (direction_ == "bidirectional")
+  cudnnDirectionMode_t direction = CUDNN_UNIDIRECTIONAL;
+  if (direction_ == "bidirectional")
     direction = CUDNN_BIDIRECTIONAL;
 
-  cudnnRNNMode_t rnn_mode;
+  cudnnRNNMode_t rnn_mode = CUDNN_LSTM;
   if (rnn_mode_ == "relu")
     rnn_mode = CUDNN_RNN_RELU;
   else if (rnn_mode_ == "tanh")
     rnn_mode = CUDNN_RNN_TANH;
-  else if (rnn_mode_ == "lstm")
-    rnn_mode = CUDNN_LSTM;
   else if (rnn_mode_ == "gru")
     rnn_mode = CUDNN_GRU;
   CUDNN_CHECK(cudnnSetRNNDescriptor(rnn_desc_, hidden_size_, num_stacks_,

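The cudnn_rnn.cc hunk above drops the uninitialized enum plus exhaustive if/else chain in favour of a default value (CUDNN_LINEAR_INPUT, CUDNN_UNIDIRECTIONAL, CUDNN_LSTM) with overrides only for the non-default strings; unrecognized rnn_mode values now silently fall back to LSTM rather than leaving the enum uninitialized. The same default-initialization pattern appears in cudnn_utils.h and prelu.cc further down. A rough Python restatement of the pattern (not SINGA code, enum values shown as strings):

    def to_cudnn_rnn_mode(rnn_mode):
        # 'lstm' and any unrecognized string fall back to the LSTM default,
        # mirroring the C++ initialization rnn_mode = CUDNN_LSTM above.
        modes = {'relu': 'CUDNN_RNN_RELU',
                 'tanh': 'CUDNN_RNN_TANH',
                 'gru': 'CUDNN_GRU'}
        return modes.get(rnn_mode, 'CUDNN_LSTM')

    assert to_cudnn_rnn_mode('lstm') == 'CUDNN_LSTM'
    assert to_cudnn_rnn_mode('gru') == 'CUDNN_GRU'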
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/05720c21/src/model/layer/cudnn_rnn.h
----------------------------------------------------------------------
diff --git a/src/model/layer/cudnn_rnn.h b/src/model/layer/cudnn_rnn.h
index cfb8aac..82c68b0 100644
--- a/src/model/layer/cudnn_rnn.h
+++ b/src/model/layer/cudnn_rnn.h
@@ -39,7 +39,7 @@ class CudnnRNN : public RNN {
  public:
   ~CudnnRNN();
   /// \copydoc Layer::layer_type()
-  const std::string layer_type() const override { return "CudnnRNN"; }
+  // const std::string layer_type() const override { return "CudnnRNN"; }
 
   const vector<Tensor> Forward(int flag, const vector<Tensor>& inputs) override;
   const std::pair<vector<Tensor>, vector<Tensor>> Backward(

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/05720c21/src/model/layer/cudnn_softmax.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/cudnn_softmax.cc b/src/model/layer/cudnn_softmax.cc
index 6dce68f..f1a4a5b 100644
--- a/src/model/layer/cudnn_softmax.cc
+++ b/src/model/layer/cudnn_softmax.cc
@@ -23,7 +23,7 @@
 #include "singa/utils/logging.h"
 namespace singa {
 
-RegisterLayerClass(CudnnSoftmax);
+RegisterLayerClass(cudnn_softmax, CudnnSoftmax);
 CudnnSoftmax::~CudnnSoftmax() {
   if (desc_ != nullptr) CUDNN_CHECK(cudnnDestroyTensorDescriptor(desc_));
 }

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/05720c21/src/model/layer/cudnn_softmax.h
----------------------------------------------------------------------
diff --git a/src/model/layer/cudnn_softmax.h b/src/model/layer/cudnn_softmax.h
index aca3729..532a643 100644
--- a/src/model/layer/cudnn_softmax.h
+++ b/src/model/layer/cudnn_softmax.h
@@ -34,7 +34,7 @@ class CudnnSoftmax : public Softmax {
  public:
   ~CudnnSoftmax();
   /// \copydoc Layer::layer_type()
-  const std::string layer_type() const override { return "CudnnSoftmax"; }
+  // const std::string layer_type() const override { return "CudnnSoftmax"; }
 
   /// \copydoc Layer::Setup(const LayerConf&);
   void Setup(const Shape& in_sample_shape, const LayerConf &conf) override;

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/05720c21/src/model/layer/cudnn_utils.h
----------------------------------------------------------------------
diff --git a/src/model/layer/cudnn_utils.h b/src/model/layer/cudnn_utils.h
index 19c72ec..64ee758 100644
--- a/src/model/layer/cudnn_utils.h
+++ b/src/model/layer/cudnn_utils.h
@@ -26,7 +26,7 @@
 #include "singa/utils/logging.h"
 namespace singa {
 inline cudnnDataType_t GetCudnnDataType(DataType dtype) {
-  cudnnDataType_t ret;
+  cudnnDataType_t ret = CUDNN_DATA_FLOAT;
   switch (dtype) {
     case kFloat32:
       ret = CUDNN_DATA_FLOAT;

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/05720c21/src/model/layer/dense.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/dense.cc b/src/model/layer/dense.cc
index 557d8bd..1a2d16e 100644
--- a/src/model/layer/dense.cc
+++ b/src/model/layer/dense.cc
@@ -23,7 +23,7 @@
 namespace singa {
 using std::vector;
 
-RegisterLayerClass(Dense);
+RegisterLayerClass(singa_dense, Dense);
 Dense::~Dense() {
   // delete weight_;
   // delete bias_;

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/05720c21/src/model/layer/dense.h
----------------------------------------------------------------------
diff --git a/src/model/layer/dense.h b/src/model/layer/dense.h
index bb5db66..8a149a5 100644
--- a/src/model/layer/dense.h
+++ b/src/model/layer/dense.h
@@ -28,7 +28,7 @@ class Dense : public Layer {
  public:
   ~Dense();
   /// \copydoc Layer::layer_type()
-  const std::string layer_type() const override { return "Dense"; }
+  // const std::string layer_type() const override { return "Dense"; }
 
   /// \copydoc Layer::Setup(const LayerConf&);
   void Setup(const Shape& in_sample, const LayerConf& conf) override;

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/05720c21/src/model/layer/dropout.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/dropout.cc b/src/model/layer/dropout.cc
index 0a4b1df..35801b4 100644
--- a/src/model/layer/dropout.cc
+++ b/src/model/layer/dropout.cc
@@ -20,7 +20,7 @@
 #include "./dropout.h"
 namespace singa {
 
-RegisterLayerClass(Dropout);
+RegisterLayerClass(singa_dropout, Dropout);
 void Dropout::Setup(const Shape& in_sample, const LayerConf& conf) {
   Layer::Setup(in_sample, conf);
   dropout_ratio_ = conf.dropout_conf().dropout_ratio();

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/05720c21/src/model/layer/dropout.h
----------------------------------------------------------------------
diff --git a/src/model/layer/dropout.h b/src/model/layer/dropout.h
index 1a4bdbf..711c86b 100644
--- a/src/model/layer/dropout.h
+++ b/src/model/layer/dropout.h
@@ -26,7 +26,7 @@ namespace singa {
 class Dropout : public Layer {
  public:
   /// \copydoc Layer::layer_type()
-  const std::string layer_type() const override { return "Dropout"; }
+  // const std::string layer_type() const override { return "Dropout"; }
 
   /// \copydoc Layer::Setup(const LayerConf&);
   void Setup(const Shape& in_sample, const LayerConf& conf) override;

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/05720c21/src/model/layer/flatten.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/flatten.cc b/src/model/layer/flatten.cc
index e7d8fa0..d89361e 100644
--- a/src/model/layer/flatten.cc
+++ b/src/model/layer/flatten.cc
@@ -20,7 +20,7 @@
 #include "./flatten.h"
 namespace singa {
 
-RegisterLayerClass(Flatten);
+RegisterLayerClass(singa_flatten, Flatten);
 void Flatten::Setup(const Shape& in_sample, const LayerConf &conf) {
   Layer::Setup(in_sample, conf);
   axis_ = conf.flatten_conf().axis();

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/05720c21/src/model/layer/flatten.h
----------------------------------------------------------------------
diff --git a/src/model/layer/flatten.h b/src/model/layer/flatten.h
index 6ac90c2..8bbf481 100644
--- a/src/model/layer/flatten.h
+++ b/src/model/layer/flatten.h
@@ -26,7 +26,7 @@ namespace singa {
 class Flatten : public Layer {
  public:
   /// \copydoc Layer::layer_type();
-  const std::string layer_type() const override { return "Flatten"; }
+  // const std::string layer_type() const override { return "Flatten"; }
 
   /// \copydoc Layer::Setup(const LayerConf&);
   void Setup(const Shape& in_sample, const LayerConf& conf) override;

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/05720c21/src/model/layer/lrn.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/lrn.cc b/src/model/layer/lrn.cc
index a624147..6b5a618 100644
--- a/src/model/layer/lrn.cc
+++ b/src/model/layer/lrn.cc
@@ -22,7 +22,7 @@
 #include <vector>
 
 namespace singa {
-RegisterLayerClass(LRN);
+RegisterLayerClass(singa_lrn, LRN);
 void LRN::Setup(const Shape& in_sample, const LayerConf& conf) {
   Layer::Setup(in_sample, conf);
   out_sample_shape_ = in_sample;

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/05720c21/src/model/layer/lrn.h
----------------------------------------------------------------------
diff --git a/src/model/layer/lrn.h b/src/model/layer/lrn.h
index 0632f8c..57e26ba 100644
--- a/src/model/layer/lrn.h
+++ b/src/model/layer/lrn.h
@@ -27,9 +27,7 @@ namespace singa {
 class LRN : public Layer {
  public:
   /// \copydoc Layer::layer_type()
-  const std::string layer_type() const override {
-    return "LRN";
-  }
+  // const std::string layer_type() const override { return "LRN"; }
 
   /// \copydoc Layer::Setup(const LayerConf&);
   void Setup(const Shape& in_sample, const LayerConf& conf) override;

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/05720c21/src/model/layer/pooling.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/pooling.cc b/src/model/layer/pooling.cc
index 943f9b2..5e7ba1d 100644
--- a/src/model/layer/pooling.cc
+++ b/src/model/layer/pooling.cc
@@ -20,7 +20,7 @@
 #include "singa/model/layer.h"
 namespace singa {
 
-RegisterLayerClass(Pooling);
+RegisterLayerClass(singa_pooling, Pooling);
 void Pooling::Setup(const Shape& in_sample, const LayerConf& conf) {
   Layer::Setup(in_sample, conf);
   PoolingConf pool_conf = conf.pooling_conf();

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/05720c21/src/model/layer/pooling.h
----------------------------------------------------------------------
diff --git a/src/model/layer/pooling.h b/src/model/layer/pooling.h
index 6df292a..f844799 100644
--- a/src/model/layer/pooling.h
+++ b/src/model/layer/pooling.h
@@ -28,7 +28,7 @@ namespace singa {
 class Pooling : public Layer {
  public:
   /// \copydoc Layer::layer_type()
-  const std::string layer_type() const override { return "Pooling"; }
+  // const std::string layer_type() const override { return "Pooling"; }
 
   /// \copydoc Layer::Setup(const LayerConf&);
   void Setup(const Shape& in_sample, const LayerConf& conf) override;

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/05720c21/src/model/layer/prelu.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/prelu.cc b/src/model/layer/prelu.cc
index 421bcaa..a20972c 100644
--- a/src/model/layer/prelu.cc
+++ b/src/model/layer/prelu.cc
@@ -20,7 +20,7 @@
 #include "./prelu.h"
 namespace singa {
 
-RegisterLayerClass(PReLU);
+RegisterLayerClass(singa_prelu, PReLU);
 void PReLU::Setup(const Shape& in_sample, const LayerConf &conf) {
   Layer::Setup(in_sample, conf);
   out_sample_shape_ = in_sample;
@@ -82,7 +82,7 @@ const std::pair<Tensor, vector<Tensor> > PReLU::Backward(int flag,
   Tensor da;
   da.ResetLike(a_);
   if (!channel_shared_) {
-    size_t n, c, h, w;
+    size_t n = 0, c = 0, h = 0, w = 0;
     Tensor temp1 = (input <= 0.f);
     if (temp1.nDim() == 4) {
       if (format_ == "NCHW") {

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/05720c21/src/model/layer/prelu.h
----------------------------------------------------------------------
diff --git a/src/model/layer/prelu.h b/src/model/layer/prelu.h
index 70a9dcf..3041d1e 100644
--- a/src/model/layer/prelu.h
+++ b/src/model/layer/prelu.h
@@ -27,7 +27,7 @@ namespace singa {
 class PReLU : public Layer {
  public:
   /// \copydoc Layer::layer_type()
-  const std::string layer_type() const override { return "PReLU"; }
+  //  const std::string layer_type() const override { return "PReLU"; }
 
 
   /// \copydoc Layer::Setup(const LayerConf&);

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/05720c21/src/model/layer/rnn.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/rnn.cc b/src/model/layer/rnn.cc
index 424c20b..524b462 100644
--- a/src/model/layer/rnn.cc
+++ b/src/model/layer/rnn.cc
@@ -22,7 +22,7 @@
 #include "singa/utils/string.h"
 
 namespace singa {
-
+RegisterLayerClass(singa_rnn, RNN);
 void RNN::Setup(const Shape& in_sample, const LayerConf &conf) {
   Layer::Setup(in_sample, conf);
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/05720c21/src/model/layer/rnn.h
----------------------------------------------------------------------
diff --git a/src/model/layer/rnn.h b/src/model/layer/rnn.h
index 1b5dad7..3369a00 100644
--- a/src/model/layer/rnn.h
+++ b/src/model/layer/rnn.h
@@ -35,7 +35,7 @@ namespace singa {
 class RNN : public Layer {
  public:
   /// \copydoc Layer::layer_type()
-  const std::string layer_type() const override { return "RNN"; }
+  // const std::string layer_type() const override { return "RNN"; }
 
   /// Setup the RNN layer.
   /// in_shape is the shape of a single training instance from one timestep,

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/05720c21/src/model/layer/softmax.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/softmax.cc b/src/model/layer/softmax.cc
index 6b1785c..6a49131 100644
--- a/src/model/layer/softmax.cc
+++ b/src/model/layer/softmax.cc
@@ -19,7 +19,7 @@
 #include "./softmax.h"
 namespace singa {
 
-RegisterLayerClass(Softmax);
+RegisterLayerClass(singa_softmax, Softmax);
 void Softmax::Setup(const Shape& in_sample, const LayerConf& conf) {
   Layer::Setup(in_sample, conf);
   CHECK_EQ(in_sample.size(), 1u);

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/05720c21/src/model/layer/softmax.h
----------------------------------------------------------------------
diff --git a/src/model/layer/softmax.h b/src/model/layer/softmax.h
index 837b23a..cf71587 100644
--- a/src/model/layer/softmax.h
+++ b/src/model/layer/softmax.h
@@ -24,7 +24,7 @@ namespace singa {
 class Softmax : public Layer {
  public:
   /// \copydoc Layer::layer_type()
-  const std::string layer_type() const override { return "Softmax"; }
+  // const std::string layer_type() const override { return "Softmax"; }
 
   /// \copydoc Layer::Setup(const LayerConf&);
   void Setup(const Shape& in_sample, const LayerConf& conf) override;

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/05720c21/src/python/singa/device.py
----------------------------------------------------------------------
diff --git a/src/python/singa/device.py b/src/python/singa/device.py
index 3db90bf..aff3587 100644
--- a/src/python/singa/device.py
+++ b/src/python/singa/device.py
@@ -73,3 +73,16 @@ def create_cuda_gpus(num):
 
 def create_cuda_gpu():
     return singa.Platform.CreateCudaGPUs(1)[0]
+
+
+def create_cuda_gpus_on(device_ids):
+    return singa.Platform.CreateCudaGPUsOn(device_ids)
+
+
+def create_cuda_gpu_on(device_id):
+    devices = create_cuda_gpus_on([device_id])
+    return devices[0]
+
+
+def get_default_device():
+    return singa.Platform.GetDefaultDevice()