Posted to commits@singa.apache.org by zh...@apache.org on 2018/08/24 02:58:36 UTC

[1/3] incubator-singa git commit: SINGA-383 Add Separable Convolution for autograd

Repository: incubator-singa
Updated Branches:
  refs/heads/master 2224d5f9a -> 8aac80e42


SINGA-383 Add Separable Convolution for autograd

- let the Conv2d layer support the 'groups' parameter, for grouped convolution.

- implement the Separable Convolution layer.

- add a unit test case for the newly developed SeparableConv2d layer.

- the implemented SeparableConv2d layer has passed both the unit test and the network test (a parameter-count sketch follows this list).
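
For context, a depthwise separable convolution factors a standard convolution into a per-channel spatial convolution (groups == in_channels) followed by a 1x1 pointwise convolution, which cuts the parameter count roughly by the square of the kernel size. A quick arithmetic sketch (illustration only, not part of the commit):

# parameter counts for a 3x3 convolution mapping 128 -> 256 channels (bias ignored)
c_in, c_out, k = 128, 256, 3

standard = c_in * c_out * k * k          # dense 3x3 convolution
separable = c_in * k * k + c_in * c_out  # depthwise 3x3 + pointwise 1x1

print(standard)              # 294912
print(separable)             # 33920
print(standard / separable)  # ~8.7, i.e. close to the k*k upper bound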


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/ca70bdf3
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/ca70bdf3
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/ca70bdf3

Branch: refs/heads/master
Commit: ca70bdf3f02412f216d10e8d4ba6c265bdd139ee
Parents: 2224d5f
Author: xuewanqi <xu...@outlook.com>
Authored: Mon Aug 20 08:16:02 2018 +0000
Committer: xuewanqi <xu...@outlook.com>
Committed: Thu Aug 23 06:30:06 2018 +0000

----------------------------------------------------------------------
 python/singa/autograd.py           | 52 ++++++++++++++++++++++++++-------
 src/api/model_operation.i          |  2 +-
 src/model/operation/convolution.cc | 10 +++++--
 src/model/operation/convolution.h  | 10 +++----
 test/python/test_operation.py      | 22 ++++++++++++++
 5 files changed, 77 insertions(+), 19 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/ca70bdf3/python/singa/autograd.py
----------------------------------------------------------------------
diff --git a/python/singa/autograd.py b/python/singa/autograd.py
index c0f6a7a..938b813 100755
--- a/python/singa/autograd.py
+++ b/python/singa/autograd.py
@@ -684,6 +684,14 @@ class Conv2d(Layer):
         self.in_channels = in_channels
         self.out_channels = out_channels
 
+        self.groups = groups
+
+        assert self.groups >= 1 and self.in_channels % self.groups == 0, 'please set a valid number of groups: groups must divide in_channels.'
+
+        # each group should contribute equally to the output feature maps, as checked
+        # by the second condition of the following assertion.
+        assert self.out_channels >= self.groups and self.out_channels % self.groups == 0, 'out_channels and groups mismatched.'
+
         if isinstance(kernel_size, int):
             self.kernel_size = (kernel_size, kernel_size)
         elif isinstance(kernel_size, tuple):
@@ -705,7 +713,7 @@ class Conv2d(Layer):
         else:
             raise TypeError('Wrong padding type.')
 
-        if dilation != 1 or groups != 1:
+        if dilation != 1:
             raise ValueError('Not implemented yet')
 
         self.bias = bias
@@ -720,11 +728,15 @@ class Conv2d(Layer):
             else:
                 self.inner_params[kwarg] = kwargs[kwarg]
 
-        w_shape = (self.out_channels, self.in_channels,
+        w_shape = (self.out_channels, int(self.in_channels / self.groups),
                    self.kernel_size[0], self.kernel_size[1])
+
         self.W = Tensor(shape=w_shape, requires_grad=True, stores_grad=True)
+        # std = math.sqrt(
+        # 2.0 / (self.in_channels * self.kernel_size[0] * self.kernel_size[1] +
+        # self.out_channels))
         std = math.sqrt(
-            2.0 / (self.in_channels * self.kernel_size[0] * self.kernel_size[1] + self.out_channels))
+            2.0 / (w_shape[1] * self.kernel_size[0] * self.kernel_size[1] + self.out_channels))
         self.W.gaussian(0.0, std)
 
         if self.bias:
@@ -743,25 +755,43 @@ class Conv2d(Layer):
         self.device_check(x, self.W, self.b)
 
         if x.device.id() == -1:
-            if not hasattr(self, 'handle'):
-                self.handle = singa.ConvHandle(x.data, self.kernel_size, self.stride,
-                                               self.padding, self.in_channels, self.out_channels, self.bias)
-            elif x.shape[0] != self.handle.batchsize:
-                self.handle = singa.ConvHandle(x.data, self.kernel_size, self.stride,
-                                               self.padding, self.in_channels, self.out_channels, self.bias)
+            if self.groups != 1:
+                raise ValueError('Not implemented yet')
+            else:
+                if not hasattr(self, 'handle'):
+                    self.handle = singa.ConvHandle(x.data, self.kernel_size, self.stride,
+                                                   self.padding, self.in_channels, self.out_channels, self.bias)
+                elif x.shape[0] != self.handle.batchsize:
+                    self.handle = singa.ConvHandle(x.data, self.kernel_size, self.stride,
+                                                   self.padding, self.in_channels, self.out_channels, self.bias)
         else:
             if not hasattr(self, 'handle'):
                 self.handle = singa.CudnnConvHandle(x.data, self.kernel_size, self.stride,
-                                                    self.padding, self.in_channels, self.out_channels, self.bias)
+                                                    self.padding, self.in_channels, self.out_channels, self.bias, self.groups)
             elif x.shape[0] != self.handle.batchsize:
                 self.handle = singa.CudnnConvHandle(x.data, self.kernel_size, self.stride,
-                                                    self.padding, self.in_channels, self.out_channels, self.bias)
+                                                    self.padding, self.in_channels, self.out_channels, self.bias, self.groups)
         self.handle.device_id = x.device.id()
 
         y = conv2d(self.handle, x, self.W, self.b)
         return y
 
 
+class SeparableConv2d(Layer):
+
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
+
+        self.mapping_spacial_conv = Conv2d(
+            in_channels, in_channels, kernel_size, stride, padding, groups=in_channels, bias=False)
+
+        self.mapping_depth_conv = Conv2d(in_channels, out_channels, 1, bias=False)
+
+    def __call__(self, x):
+        y = self.mapping_spacial_conv(x)
+        y = self.mapping_depth_conv(y)
+        return y
+
+
 class BatchNorm2d(Layer):
 
     def __init__(self, num_features, momentum=0.9):
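
For reference, the semantics exposed by the new 'groups' argument: the input channels are split into 'groups' contiguous slices, and each slice is convolved only with its own share of the filters, so the weight tensor shrinks to (out_channels, in_channels // groups, kH, kW). A minimal NumPy sketch of this behaviour (stride 1, no padding, cross-correlation as in cuDNN; illustration only, not part of the commit):

import numpy as np

def grouped_conv2d(x, w, groups):
    # x: (N, C_in, H, W); w: (C_out, C_in // groups, kH, kW)
    N, C_in, H, W = x.shape
    C_out, cpg, kH, kW = w.shape
    assert C_in % groups == 0 and C_out % groups == 0 and cpg == C_in // groups
    out_h, out_w = H - kH + 1, W - kW + 1
    opg = C_out // groups  # output channels per group
    y = np.zeros((N, C_out, out_h, out_w), dtype=x.dtype)
    for g in range(groups):
        xs = x[:, g * cpg:(g + 1) * cpg]  # this group's input channels
        ws = w[g * opg:(g + 1) * opg]     # this group's filters
        for i in range(out_h):
            for j in range(out_w):
                patch = xs[:, :, i:i + kH, j:j + kW]
                # contract over (channel, kH, kW) -> (N, opg)
                y[:, g * opg:(g + 1) * opg, i, j] = np.tensordot(
                    patch, ws, axes=([1, 2, 3], [1, 2, 3]))
    return y

# depthwise case used by SeparableConv2d: groups == in_channels
x = np.random.randn(2, 8, 5, 5).astype(np.float32)
w = np.random.randn(8, 1, 3, 3).astype(np.float32)
print(grouped_conv2d(x, w, groups=8).shape)  # (2, 8, 3, 3)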

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/ca70bdf3/src/api/model_operation.i
----------------------------------------------------------------------
diff --git a/src/api/model_operation.i b/src/api/model_operation.i
index 435ff1c..56141d8 100755
--- a/src/api/model_operation.i
+++ b/src/api/model_operation.i
@@ -58,7 +58,7 @@ class CudnnConvHandle: public ConvHandle {
   CudnnConvHandle(const Tensor &input, const std::vector<size_t>& kernel_size,
                   const std::vector<size_t>& stride, const std::vector<size_t>& padding,
                   const size_t in_channels, const size_t out_channels,
-                  const bool bias, const size_t workspace_byte_limit = 1024 * 1024 * 1024,
+                  const bool bias, const size_t groups = 1, const size_t workspace_byte_limit = 1024 * 1024 * 1024,
                   const std::string& prefer = "fastest");
   bool bias_term;
   size_t batchsize;

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/ca70bdf3/src/model/operation/convolution.cc
----------------------------------------------------------------------
diff --git a/src/model/operation/convolution.cc b/src/model/operation/convolution.cc
index 7c71d7c..beb824d 100755
--- a/src/model/operation/convolution.cc
+++ b/src/model/operation/convolution.cc
@@ -184,6 +184,7 @@ CudnnConvHandle::CudnnConvHandle(const Tensor &input,
                                  const std::vector<size_t>& kernel_size,
                                  const std::vector<size_t>& stride, const std::vector<size_t>& padding,
                                  const size_t in_channels, const size_t out_channels, const bool bias,
+                                 const size_t groups,
                                  const size_t workspace_byte_limit, const std::string& prefer)
   : ConvHandle(input, kernel_size, stride, padding, in_channels, out_channels,
                bias) {
@@ -199,7 +200,6 @@ CudnnConvHandle::CudnnConvHandle(const Tensor &input,
   CUDNN_CHECK(cudnnCreateFilterDescriptor(&filter_desc));
   CUDNN_CHECK(cudnnCreateConvolutionDescriptor(&conv_desc));
 
-
   CUDNN_CHECK(cudnnSetTensor4dDescriptor(x_desc, CUDNN_TENSOR_NCHW,
                                          GetCudnnDataType(dtype), batchsize,
                                          channels, height, width));
@@ -217,9 +217,14 @@ CudnnConvHandle::CudnnConvHandle(const Tensor &input,
               , GetCudnnDataType(dtype)
 #endif
                                              ));
+  if (CUDNN_MAJOR >= 7 && groups > 1) {
+    CUDNN_CHECK(cudnnSetConvolutionGroupCount(conv_desc, groups));
+  }
+  else if (groups > 1) {LOG(FATAL) << "The current version of cuDNN does not support grouped convolution.";}
+
   CUDNN_CHECK(cudnnSetFilter4dDescriptor(filter_desc, GetCudnnDataType(dtype),
                                          CUDNN_TENSOR_NCHW, num_filters,
-                                         channels, kernel_h, kernel_w));
+                                         channels / groups, kernel_h, kernel_w));
   if (prefer == "fastest" || prefer == "limited_workspace" ||
       prefer == "no_workspace") {
     cudnnConvolutionFwdPreference_t fwd_pref;
@@ -289,6 +294,7 @@ CudnnConvHandle::CudnnConvHandle(const Tensor &input,
                  << ") is larger than the expected Bytes ("
                  << workspace_byte_limit << ")";
   workspace = Tensor(Shape{workspace_count}, dev, dtype);
+
 }
 
 CudnnConvHandle::~CudnnConvHandle() {

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/ca70bdf3/src/model/operation/convolution.h
----------------------------------------------------------------------
diff --git a/src/model/operation/convolution.h b/src/model/operation/convolution.h
index 9da881f..7fd1ce7 100755
--- a/src/model/operation/convolution.h
+++ b/src/model/operation/convolution.h
@@ -17,12 +17,12 @@ namespace singa {
 
 class ConvHandle {
 
- public:
+public:
   ConvHandle(const Tensor &input, const std::vector<size_t>& kernel_size,
              const std::vector<size_t>& stride, const std::vector<size_t>& padding,
              const size_t in_channels, const size_t out_channels,
              const bool bias);
- 
+
   size_t kernel_w;
   size_t pad_w;
   size_t stride_w;
@@ -59,15 +59,15 @@ Tensor CpuConvBackwardb(const Tensor &dy, const Tensor &b, const ConvHandle &ch)
 
 #ifdef USE_CUDNN
 class CudnnConvHandle: public ConvHandle {
- public:
+public:
   CudnnConvHandle(const Tensor &input, const std::vector<size_t>& kernel_size,
                   const std::vector<size_t>& stride, const std::vector<size_t>& padding,
                   const size_t in_channels, const size_t out_channels,
-                  const bool bias, const size_t workspace_byte_limit = 1024 * 1024 * 1024,
+                  const bool bias, const size_t groups = 1, const size_t workspace_byte_limit = 1024 * 1024 * 1024,
                   const std::string& prefer = "fastest");
   ~CudnnConvHandle();
   // TODO(wangwei) add the destructor
- 
+
   cudnnTensorDescriptor_t x_desc = nullptr;
   cudnnTensorDescriptor_t y_desc = nullptr;
   cudnnTensorDescriptor_t bias_desc = nullptr;

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/ca70bdf3/test/python/test_operation.py
----------------------------------------------------------------------
diff --git a/test/python/test_operation.py b/test/python/test_operation.py
index 64562a5..d20e764 100755
--- a/test/python/test_operation.py
+++ b/test/python/test_operation.py
@@ -99,6 +99,28 @@ class TestPythonOperation(unittest.TestCase):
         y_without_bias = conv_without_bias_1(cpu_input_tensor)
         self.check_shape(y_without_bias.shape, (2, 1, 2, 2))
 
+    def test_SeparableConv2d_gpu(self):
+        separ_conv=autograd.SeparableConv2d(8, 16, 3, padding=1)
+
+        x=np.random.random((10,8,28,28)).astype(np.float32)
+        x=tensor.Tensor(device=gpu_dev, data=x)
+
+        #y = separ_conv(x)
+        y1 = separ_conv.mapping_spacial_conv(x)
+        y2 = separ_conv.mapping_depth_conv(y1)
+        
+        dy1, dW_depth, _ = y2.creator.backward(y2.data)
+        dx, dW_spacial, _ = y1.creator.backward(dy1)
+
+        self.check_shape(y2.shape, (10, 16, 28, 28))
+
+        self.check_shape(dy1.shape(), (10, 8, 28, 28))
+        self.check_shape(dW_depth.shape(), (16, 8, 1, 1)) 
+
+        self.check_shape(dx.shape(), (10, 8, 28, 28))
+        self.check_shape(dW_spacial.shape(), (8, 1, 3, 3))
+
+
     def test_batchnorm2d_gpu(self):
         batchnorm_0 = autograd.BatchNorm2d(3)
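
The expected shapes in the test above follow directly from the grouped weight layout w_shape = (out_channels, in_channels // groups, kH, kW) introduced in this commit. A quick self-contained check (helper name hypothetical, not part of the commit):

def conv_w_shape(out_channels, in_channels, kernel, groups=1):
    # mirrors w_shape in Conv2d.__init__ above
    return (out_channels, in_channels // groups, kernel, kernel)

# depthwise 3x3 of SeparableConv2d(8, 16, 3): groups == in_channels == 8
assert conv_w_shape(8, 8, 3, groups=8) == (8, 1, 3, 3)
# pointwise 1x1: an ordinary convolution with groups == 1
assert conv_w_shape(16, 8, 1) == (16, 8, 1, 1)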
 


[2/3] incubator-singa git commit: SINGA-383 Add Separable Convolution for autograd

Posted by zh...@apache.org.
SINGA-383 Add Separable Convolution for autograd

- Implement the Xception net by calling the SeparableConv2d layer. The file is added to the /examples/autograd folder.

- Modify the SeparableConv2d layer API (a usage sketch follows below).
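
A hypothetical usage of the revised API (assumes a SINGA build from this commit with CUDA available, since the CPU path rejects groups != 1):

from singa import autograd

# 'bias' is now exposed and defaults to False; the Xception blocks below pass bias=False
conv = autograd.SeparableConv2d(1024, 1536, 3, stride=1, padding=1, bias=False)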


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/d5422a43
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/d5422a43
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/d5422a43

Branch: refs/heads/master
Commit: d5422a432d1ad4261f3b44b1f61af5f4c2a651ec
Parents: ca70bdf
Author: xuewanqi <xu...@outlook.com>
Authored: Tue Aug 21 14:17:00 2018 +0000
Committer: xuewanqi <xu...@outlook.com>
Committed: Thu Aug 23 06:49:53 2018 +0000

----------------------------------------------------------------------
 examples/autograd/xceptionnet.py | 202 ++++++++++++++++++++++++++++++++++
 python/singa/autograd.py         |   6 +-
 2 files changed, 205 insertions(+), 3 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/d5422a43/examples/autograd/xceptionnet.py
----------------------------------------------------------------------
diff --git a/examples/autograd/xceptionnet.py b/examples/autograd/xceptionnet.py
new file mode 100755
index 0000000..f52a8ac
--- /dev/null
+++ b/examples/autograd/xceptionnet.py
@@ -0,0 +1,202 @@
+from singa import autograd
+from singa import tensor
+from singa import device
+from singa import opt
+
+import numpy as np
+from tqdm import trange
+
+
+# the code is modified from
+# https://github.com/Cadene/pretrained-models.pytorch/blob/master/pretrainedmodels/models/xception.py
+
+__all__ = ['xception']
+
+
+class Block(autograd.Layer):
+
+    def __init__(self, in_filters, out_filters, reps, strides=1, padding=0, start_with_relu=True, grow_first=True):
+        super(Block, self).__init__()
+
+        if out_filters != in_filters or strides != 1:
+            self.skip = autograd.Conv2d(in_filters, out_filters,
+                                        1, stride=strides, padding=padding, bias=False)
+            self.skipbn = autograd.BatchNorm2d(out_filters)
+        else:
+            self.skip = None
+
+        self.layers = []
+
+        filters = in_filters
+        if grow_first:
+            self.layers.append(autograd.ReLU())
+            self.layers.append(autograd.SeparableConv2d(in_filters, out_filters,
+                                                        3, stride=1, padding=1, bias=False))
+            self.layers.append(autograd.BatchNorm2d(out_filters))
+            filters = out_filters
+
+        for i in range(reps - 1):
+            self.layers.append(autograd.ReLU())
+            self.layers.append(autograd.SeparableConv2d(filters, filters,
+                                                        3, stride=1, padding=1, bias=False))
+            self.layers.append(autograd.BatchNorm2d(filters))
+
+        if not grow_first:
+            self.layers.append(autograd.ReLU())
+            self.layers.append(autograd.SeparableConv2d(in_filters, out_filters,
+                                                        3, stride=1, padding=1, bias=False))
+            self.layers.append(autograd.BatchNorm2d(out_filters))
+
+        if not start_with_relu:
+            self.layers = self.layers[1:]
+        else:
+            self.layers[0] = autograd.ReLU()
+
+        if strides != 1:
+            self.layers.append(autograd.MaxPool2d(3, strides, padding + 1))
+
+    def __call__(self, x):
+        y = self.layers[0](x)
+        for layer in self.layers[1:]:
+            if isinstance(y, tuple):
+                y = y[0]
+            y = layer(y)
+
+        if self.skip is not None:
+            skip = self.skip(x)
+            skip = self.skipbn(skip)
+        else:
+            skip = x
+        y = autograd.add(y, skip)
+        return y
+
+
+class Xception(autograd.Layer):
+    """
+    Xception optimized for the ImageNet dataset, as specified in
+    https://arxiv.org/pdf/1610.02357.pdf
+    """
+
+    def __init__(self, num_classes=1000):
+        """ Constructor
+        Args:
+            num_classes: number of classes
+        """
+        super(Xception, self).__init__()
+        self.num_classes = num_classes
+
+        self.conv1 = autograd.Conv2d(3, 32, 3, 2, 0, bias=False)
+        self.bn1 = autograd.BatchNorm2d(32)
+
+        self.conv2 = autograd.Conv2d(32, 64, 3, 1, 1, bias=False)
+        self.bn2 = autograd.BatchNorm2d(64)
+        # do relu here
+
+        self.block1 = Block(
+            64, 128, 2, 2, padding=0, start_with_relu=False, grow_first=True)
+        self.block2 = Block(
+            128, 256, 2, 2, padding=0, start_with_relu=True, grow_first=True)
+        self.block3 = Block(
+            256, 728, 2, 2, padding=0, start_with_relu=True, grow_first=True)
+
+        self.block4 = Block(
+            728, 728, 3, 1, start_with_relu=True, grow_first=True)
+        self.block5 = Block(
+            728, 728, 3, 1, start_with_relu=True, grow_first=True)
+        self.block6 = Block(
+            728, 728, 3, 1, start_with_relu=True, grow_first=True)
+        self.block7 = Block(
+            728, 728, 3, 1, start_with_relu=True, grow_first=True)
+
+        self.block8 = Block(
+            728, 728, 3, 1, start_with_relu=True, grow_first=True)
+        self.block9 = Block(
+            728, 728, 3, 1, start_with_relu=True, grow_first=True)
+        self.block10 = Block(
+            728, 728, 3, 1, start_with_relu=True, grow_first=True)
+        self.block11 = Block(
+            728, 728, 3, 1, start_with_relu=True, grow_first=True)
+
+        self.block12 = Block(
+            728, 1024, 2, 2, start_with_relu=True, grow_first=False)
+
+        self.conv3 = autograd.SeparableConv2d(1024, 1536, 3, 1, 1)
+        self.bn3 = autograd.BatchNorm2d(1536)
+
+        # do relu here
+        self.conv4 = autograd.SeparableConv2d(1536, 2048, 3, 1, 1)
+        self.bn4 = autograd.BatchNorm2d(2048)
+
+        self.globalpooling = autograd.MaxPool2d(10, 1)
+        self.fc = autograd.Linear(2048, num_classes)
+
+    def features(self, input):
+        x = self.conv1(input)
+        x = self.bn1(x)
+        x = autograd.relu(x)
+
+        x = self.conv2(x)
+        x = self.bn2(x)
+        x = autograd.relu(x)
+
+        x = self.block1(x)
+        x = self.block2(x)
+        x = self.block3(x)
+        x = self.block4(x)
+        x = self.block5(x)
+        x = self.block6(x)
+        x = self.block7(x)
+        x = self.block8(x)
+        x = self.block9(x)
+        x = self.block10(x)
+        x = self.block11(x)
+        x = self.block12(x)
+
+        x = self.conv3(x)
+        x = self.bn3(x)
+        x = autograd.relu(x)
+
+        x = self.conv4(x)
+        x = self.bn4(x)
+        return x
+
+    def logits(self, features):
+        x = autograd.relu(features)
+        x = self.globalpooling(x)
+        x = autograd.flatten(x)
+        x = self.fc(x)
+        return x
+
+    def __call__(self, input):
+        x = self.features(input)
+        x = self.logits(x)
+        return x
+
+
+if __name__ == '__main__':
+    model = Xception(num_classes=1000)
+    print('Start initialization............')
+    dev = device.create_cuda_gpu_on(0)
+    #dev = device.create_cuda_gpu()
+
+    niters = 20
+    batch_size = 16
+    IMG_SIZE = 299
+    sgd = opt.SGD(lr=0.1, momentum=0.9, weight_decay=1e-5)
+
+    tx = tensor.Tensor((batch_size, 3, IMG_SIZE, IMG_SIZE), dev)
+    ty = tensor.Tensor((batch_size,), dev, tensor.int32)
+    autograd.training = True
+    x = np.random.randn(batch_size, 3, IMG_SIZE, IMG_SIZE).astype(np.float32)
+    y = np.random.randint(0, 1000, batch_size, dtype=np.int32)
+    tx.copy_from_numpy(x)
+    ty.copy_from_numpy(y)
+
+    with trange(niters) as t:
+        for b in t:
+            x = model(tx)
+            loss = autograd.softmax_cross_entropy(x, ty)
+            for p, g in autograd.backward(loss):
+                # print(p.shape, g.shape)
+                sgd.update(p, g)
+                # pass
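
A minimal inference-only sketch of driving this example (hypothetical usage, not part of the commit; assumes a CUDA device and that examples/autograd/xceptionnet.py is importable as xceptionnet):

import numpy as np
from singa import autograd, device, tensor
from xceptionnet import Xception

autograd.training = False           # inference mode
dev = device.create_cuda_gpu_on(0)
model = Xception(num_classes=1000)

x = np.random.randn(1, 3, 299, 299).astype(np.float32)
tx = tensor.Tensor((1, 3, 299, 299), dev)
tx.copy_from_numpy(x)

logits = model(tx)
print(tensor.to_numpy(logits).shape)  # expected: (1, 1000)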

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/d5422a43/python/singa/autograd.py
----------------------------------------------------------------------
diff --git a/python/singa/autograd.py b/python/singa/autograd.py
index 938b813..84afcd1 100755
--- a/python/singa/autograd.py
+++ b/python/singa/autograd.py
@@ -779,12 +779,12 @@ class Conv2d(Layer):
 
 class SeparableConv2d(Layer):
 
-    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=False):
 
         self.mapping_spacial_conv = Conv2d(
-            in_channels, in_channels, kernel_size, stride, padding, groups=in_channels, bias=False)
+            in_channels, in_channels, kernel_size, stride, padding, groups=in_channels, bias=bias)
 
-        self.mapping_depth_conv = Conv2d(in_channels, out_channels, 1, bias=False)
+        self.mapping_depth_conv = Conv2d(in_channels, out_channels, 1, bias=bias)
 
     def __call__(self, x):
         y = self.mapping_spacial_conv(x)


[3/3] incubator-singa git commit: SINGA-383 Add Separable Convolution for autograd

Posted by zh...@apache.org.
SINGA-383 Add Separable Convolution for autograd

- rename two sub-layers of the SeparableConv2d layer: mapping_spacial_conv to spacial_conv and mapping_depth_conv to depth_conv.


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/8aac80e4
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/8aac80e4
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/8aac80e4

Branch: refs/heads/master
Commit: 8aac80e425d9a146f8a86b6e38a023edca542099
Parents: d5422a4
Author: xuewanqi <xu...@outlook.com>
Authored: Fri Aug 24 02:19:18 2018 +0000
Committer: xuewanqi <xu...@outlook.com>
Committed: Fri Aug 24 02:19:18 2018 +0000

----------------------------------------------------------------------
 python/singa/autograd.py      | 8 ++++----
 test/python/test_operation.py | 4 ++--
 2 files changed, 6 insertions(+), 6 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/8aac80e4/python/singa/autograd.py
----------------------------------------------------------------------
diff --git a/python/singa/autograd.py b/python/singa/autograd.py
index 84afcd1..b521126 100755
--- a/python/singa/autograd.py
+++ b/python/singa/autograd.py
@@ -781,14 +781,14 @@ class SeparableConv2d(Layer):
 
     def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=False):
 
-        self.mapping_spacial_conv = Conv2d(
+        self.spacial_conv = Conv2d(
             in_channels, in_channels, kernel_size, stride, padding, groups=in_channels, bias=bias)
 
-        self.mapping_depth_conv = Conv2d(in_channels, out_channels, 1, bias=bias)
+        self.depth_conv = Conv2d(in_channels, out_channels, 1, bias=bias)
 
     def __call__(self, x):
-        y = self.mapping_spacial_conv(x)
-        y = self.mapping_depth_conv(y)
+        y = self.spacial_conv(x)
+        y = self.depth_conv(y)
         return y
 
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/8aac80e4/test/python/test_operation.py
----------------------------------------------------------------------
diff --git a/test/python/test_operation.py b/test/python/test_operation.py
index d20e764..2fdd9fb 100755
--- a/test/python/test_operation.py
+++ b/test/python/test_operation.py
@@ -106,8 +106,8 @@ class TestPythonOperation(unittest.TestCase):
         x=tensor.Tensor(device=gpu_dev, data=x)
 
         #y = separ_conv(x)
-        y1 = separ_conv.mapping_spacial_conv(x)
-        y2 = separ_conv.mapping_depth_conv(y1)
+        y1 = separ_conv.spacial_conv(x)
+        y2 = separ_conv.depth_conv(y1)
         
         dy1, dW_depth, _ = y2.creator.backward(y2.data)
         dx, dW_spacial, _ = y1.creator.backward(dy1)