You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@singa.apache.org by wa...@apache.org on 2016/10/18 13:07:31 UTC

[2/3] incubator-singa git commit: SINGA-253 Net converter for caffe model

SINGA-253 Net converter for caffe model

Convert a caffe model into a singa model.
It is a very basic implementation. It can currently convert feed-forward nets and supports pysinga.
Implementation method:
1. read the proto file of the caffe model using caffe.proto, and serialize it to a string.
2. parse the string using singa model.proto and get the layer config.
3. set up each layer and add it to a feed-forward net.

Updates and example:
1. put conversion functions into layer.py and converter.py;
2. add a 'caffe' option in the cifar example, which converts alexnet from
caffe to singa.


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/77405ec1
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/77405ec1
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/77405ec1

Branch: refs/heads/master
Commit: 77405ec19565ed82feedb933cf562f61c44522c4
Parents: c967169
Author: XiangruiCAI <ca...@gmail.com>
Authored: Thu Sep 29 22:22:00 2016 +0800
Committer: Xiangrui <ca...@gmail.com>
Committed: Tue Oct 18 11:38:26 2016 +0800

----------------------------------------------------------------------
 examples/cifar10/README.md                      |    4 +-
 examples/cifar10/caffe/__init__.py              |    0
 examples/cifar10/caffe/caffe_net.py             |   42 +
 .../cifar10/caffe/cifar10_full_solver.prototxt  |   30 +
 .../caffe/cifar10_full_train_test.prototxt      |  223 +++
 .../cifar10/caffe/cifar10_quick_solver.prototxt |   29 +
 .../caffe/cifar10_quick_train_test.prototxt     |  225 +++
 examples/cifar10/train.py                       |   25 +-
 python/singa/converter.py                       |  155 ++
 python/singa/layer.py                           |  127 +-
 src/proto/caffe.proto                           | 1402 ++++++++++++++++++
 src/proto/model.proto                           |    5 +-
 12 files changed, 2250 insertions(+), 17 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/77405ec1/examples/cifar10/README.md
----------------------------------------------------------------------
diff --git a/examples/cifar10/README.md b/examples/cifar10/README.md
index 0cf069f..65df5e6 100644
--- a/examples/cifar10/README.md
+++ b/examples/cifar10/README.md
@@ -10,6 +10,7 @@ the best validation accuracy (without data augmentation) we achieved was about 8
 
 2. [VGGNet](http://torch.ch/blog/2015/07/30/cifar.html), the best validation accuracy (without data augmentation) we achieved was about 89%.
 3. [ResNet](https://github.com/facebook/fb.resnet.torch), the best validation accuracy (without data augmentation) we achieved was about 83%.
+4. [AlexNet from Caffe](https://github.com/BVLC/caffe/tree/master/examples/cifar10); SINGA is able to convert the model from Caffe seamlessly.
 
 
 ## Instructions
@@ -40,7 +41,8 @@ version of the Cifar-10 dataset in 'cifar-10-batches-py' folder.
 
         python train.py vgg cifar-10-batches-py
 
-    To train other models, please replace 'vgg' to 'alexnet' or 'resnet'. By default
+    To train other models, please replace 'vgg' with 'alexnet', 'resnet' or 'caffe',
+    where 'caffe' refers to the AlexNet model converted from Caffe. By default
     the training would run on a CudaGPU device, to run it on CppCPU, add an additional
     argument
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/77405ec1/examples/cifar10/caffe/__init__.py
----------------------------------------------------------------------
diff --git a/examples/cifar10/caffe/__init__.py b/examples/cifar10/caffe/__init__.py
new file mode 100644
index 0000000..e69de29

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/77405ec1/examples/cifar10/caffe/caffe_net.py
----------------------------------------------------------------------
diff --git a/examples/cifar10/caffe/caffe_net.py b/examples/cifar10/caffe/caffe_net.py
new file mode 100644
index 0000000..543e3a5
--- /dev/null
+++ b/examples/cifar10/caffe/caffe_net.py
@@ -0,0 +1,42 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+import os
+from singa import converter
+
+
+def create_net(use_cpu):
+    if use_cpu:
+        layer.engine = 'singacpp'
+
+    #net_proto = os.path.abspath('./caffe/cifar10_full_train_test.prototxt')
+    #solver_proto = os.path.abspath('./caffe/cifar10_full_solver.prototxt')
+    net_proto = os.path.abspath('./caffe/cifar10_quick_train_test.prototxt')
+    solver_proto = os.path.abspath('./caffe/cifar10_quick_solver.prototxt')
+    input_sample_shape = [3, 32, 32, ]
+
+    cvt = converter.CaffeConverter(net_proto, solver_proto, input_sample_shape)
+    net = cvt.create_net()
+    for (p, specs) in zip(net.param_values(), net.param_specs()):
+        filler = specs.filler
+        if filler.type == 'gaussian':
+            p.gaussian(filler.mean, filler.std)
+        else:
+            p.set_value(0)
+        print specs.name, filler.type, p.l1()
+
+    return net

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/77405ec1/examples/cifar10/caffe/cifar10_full_solver.prototxt
----------------------------------------------------------------------
diff --git a/examples/cifar10/caffe/cifar10_full_solver.prototxt b/examples/cifar10/caffe/cifar10_full_solver.prototxt
new file mode 100644
index 0000000..1e708d8
--- /dev/null
+++ b/examples/cifar10/caffe/cifar10_full_solver.prototxt
@@ -0,0 +1,30 @@
+# From caffe repository
+# Commits on Sep 2, 2015
+
+# reduce learning rate after 120 epochs (60000 iters) by a factor of 10
+# then another factor of 10 after 10 more epochs (5000 iters)
+
+# The train/test net protocol buffer definition
+net: "examples/cifar10/cifar10_full_train_test.prototxt"
+# test_iter specifies how many forward passes the test should carry out.
+# In the case of CIFAR10, we have test batch size 100 and 100 test iterations,
+# covering the full 10,000 testing images.
+test_iter: 100
+# Carry out testing every 1000 training iterations.
+test_interval: 1000
+# The base learning rate, momentum and the weight decay of the network.
+base_lr: 0.001
+momentum: 0.9
+weight_decay: 0.004
+# The learning rate policy
+lr_policy: "fixed"
+# Display every 200 iterations
+display: 200
+# The maximum number of iterations
+max_iter: 60000
+# snapshot intermediate results
+snapshot: 10000
+snapshot_format: HDF5
+snapshot_prefix: "examples/cifar10/cifar10_full"
+# solver mode: CPU or GPU
+solver_mode: GPU

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/77405ec1/examples/cifar10/caffe/cifar10_full_train_test.prototxt
----------------------------------------------------------------------
diff --git a/examples/cifar10/caffe/cifar10_full_train_test.prototxt b/examples/cifar10/caffe/cifar10_full_train_test.prototxt
new file mode 100644
index 0000000..40b39ec
--- /dev/null
+++ b/examples/cifar10/caffe/cifar10_full_train_test.prototxt
@@ -0,0 +1,223 @@
+# From caffe repository
+# Commits on Feb 6, 2015
+
+name: "CIFAR10_full"
+layer {
+  name: "cifar"
+  type: "Data"
+  top: "data"
+  top: "label"
+  include {
+    phase: TRAIN
+  }
+  transform_param {
+    mean_file: "examples/cifar10/mean.binaryproto"
+  }
+  data_param {
+    source: "examples/cifar10/cifar10_train_lmdb"
+    batch_size: 100
+    backend: LMDB
+  }
+}
+layer {
+  name: "cifar"
+  type: "Data"
+  top: "data"
+  top: "label"
+  include {
+    phase: TEST
+  }
+  transform_param {
+    mean_file: "examples/cifar10/mean.binaryproto"
+  }
+  data_param {
+    source: "examples/cifar10/cifar10_test_lmdb"
+    batch_size: 100
+    backend: LMDB
+  }
+}
+layer {
+  name: "conv1"
+  type: "Convolution"
+  bottom: "data"
+  top: "conv1"
+  param {
+    lr_mult: 1
+  }
+  param {
+    lr_mult: 2
+  }
+  convolution_param {
+    num_output: 32
+    pad: 2
+    kernel_size: 5
+    stride: 1
+    weight_filler {
+      type: "gaussian"
+      std: 0.0001
+    }
+    bias_filler {
+      type: "constant"
+    }
+  }
+}
+layer {
+  name: "pool1"
+  type: "Pooling"
+  bottom: "conv1"
+  top: "pool1"
+  pooling_param {
+    pool: MAX
+    kernel_size: 3
+    stride: 2
+  }
+}
+layer {
+  name: "relu1"
+  type: "ReLU"
+  bottom: "pool1"
+  top: "pool1"
+}
+layer {
+  name: "norm1"
+  type: "LRN"
+  bottom: "pool1"
+  top: "norm1"
+  lrn_param {
+    local_size: 3
+    alpha: 5e-05
+    beta: 0.75
+    norm_region: WITHIN_CHANNEL
+  }
+}
+layer {
+  name: "conv2"
+  type: "Convolution"
+  bottom: "norm1"
+  top: "conv2"
+  param {
+    lr_mult: 1
+  }
+  param {
+    lr_mult: 2
+  }
+  convolution_param {
+    num_output: 32
+    pad: 2
+    kernel_size: 5
+    stride: 1
+    weight_filler {
+      type: "gaussian"
+      std: 0.01
+    }
+    bias_filler {
+      type: "constant"
+    }
+  }
+}
+layer {
+  name: "relu2"
+  type: "ReLU"
+  bottom: "conv2"
+  top: "conv2"
+}
+layer {
+  name: "pool2"
+  type: "Pooling"
+  bottom: "conv2"
+  top: "pool2"
+  pooling_param {
+    pool: AVE
+    kernel_size: 3
+    stride: 2
+  }
+}
+layer {
+  name: "norm2"
+  type: "LRN"
+  bottom: "pool2"
+  top: "norm2"
+  lrn_param {
+    local_size: 3
+    alpha: 5e-05
+    beta: 0.75
+    norm_region: WITHIN_CHANNEL
+  }
+}
+layer {
+  name: "conv3"
+  type: "Convolution"
+  bottom: "norm2"
+  top: "conv3"
+  convolution_param {
+    num_output: 64
+    pad: 2
+    kernel_size: 5
+    stride: 1
+    weight_filler {
+      type: "gaussian"
+      std: 0.01
+    }
+    bias_filler {
+      type: "constant"
+    }
+  }
+}
+layer {
+  name: "relu3"
+  type: "ReLU"
+  bottom: "conv3"
+  top: "conv3"
+}
+layer {
+  name: "pool3"
+  type: "Pooling"
+  bottom: "conv3"
+  top: "pool3"
+  pooling_param {
+    pool: AVE
+    kernel_size: 3
+    stride: 2
+  }
+}
+layer {
+  name: "ip1"
+  type: "InnerProduct"
+  bottom: "pool3"
+  top: "ip1"
+  param {
+    lr_mult: 1
+    decay_mult: 250
+  }
+  param {
+    lr_mult: 2
+    decay_mult: 0
+  }
+  inner_product_param {
+    num_output: 10
+    weight_filler {
+      type: "gaussian"
+      std: 0.01
+    }
+    bias_filler {
+      type: "constant"
+    }
+  }
+}
+layer {
+  name: "accuracy"
+  type: "Accuracy"
+  bottom: "ip1"
+  bottom: "label"
+  top: "accuracy"
+  include {
+    phase: TEST
+  }
+}
+layer {
+  name: "loss"
+  type: "SoftmaxWithLoss"
+  bottom: "ip1"
+  bottom: "label"
+  top: "loss"
+}

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/77405ec1/examples/cifar10/caffe/cifar10_quick_solver.prototxt
----------------------------------------------------------------------
diff --git a/examples/cifar10/caffe/cifar10_quick_solver.prototxt b/examples/cifar10/caffe/cifar10_quick_solver.prototxt
new file mode 100644
index 0000000..3c5ce96
--- /dev/null
+++ b/examples/cifar10/caffe/cifar10_quick_solver.prototxt
@@ -0,0 +1,29 @@
+# From caffe repository
+# Commits on Sep 2, 2015
+
+# reduce the learning rate after 8 epochs (4000 iters) by a factor of 10
+
+# The train/test net protocol buffer definition
+net: "examples/cifar10/cifar10_quick_train_test.prototxt"
+# test_iter specifies how many forward passes the test should carry out.
+# In the case of CIFAR10, we have test batch size 100 and 100 test iterations,
+# covering the full 10,000 testing images.
+test_iter: 100
+# Carry out testing every 500 training iterations.
+test_interval: 500
+# The base learning rate, momentum and the weight decay of the network.
+base_lr: 0.001
+momentum: 0.9
+weight_decay: 0.004
+# The learning rate policy
+lr_policy: "fixed"
+# Display every 100 iterations
+display: 100
+# The maximum number of iterations
+max_iter: 4000
+# snapshot intermediate results
+snapshot: 4000
+snapshot_format: HDF5
+snapshot_prefix: "examples/cifar10/cifar10_quick"
+# solver mode: CPU or GPU
+solver_mode: GPU

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/77405ec1/examples/cifar10/caffe/cifar10_quick_train_test.prototxt
----------------------------------------------------------------------
diff --git a/examples/cifar10/caffe/cifar10_quick_train_test.prototxt b/examples/cifar10/caffe/cifar10_quick_train_test.prototxt
new file mode 100644
index 0000000..f83ca4b
--- /dev/null
+++ b/examples/cifar10/caffe/cifar10_quick_train_test.prototxt
@@ -0,0 +1,225 @@
+# From caffe repository
+# Commits on Feb 6, 2015
+
+name: "CIFAR10_quick"
+layer {
+  name: "cifar"
+  type: "Data"
+  top: "data"
+  top: "label"
+  include {
+    phase: TRAIN
+  }
+  transform_param {
+    mean_file: "examples/cifar10/mean.binaryproto"
+  }
+  data_param {
+    source: "examples/cifar10/cifar10_train_lmdb"
+    batch_size: 100
+    backend: LMDB
+  }
+}
+layer {
+  name: "cifar"
+  type: "Data"
+  top: "data"
+  top: "label"
+  include {
+    phase: TEST
+  }
+  transform_param {
+    mean_file: "examples/cifar10/mean.binaryproto"
+  }
+  data_param {
+    source: "examples/cifar10/cifar10_test_lmdb"
+    batch_size: 100
+    backend: LMDB
+  }
+}
+layer {
+  name: "conv1"
+  type: "Convolution"
+  bottom: "data"
+  top: "conv1"
+  param {
+    lr_mult: 1
+  }
+  param {
+    lr_mult: 2
+  }
+  convolution_param {
+    num_output: 32
+    pad: 2
+    kernel_size: 5
+    stride: 1
+    weight_filler {
+      type: "gaussian"
+      std: 0.0001
+    }
+    bias_filler {
+      type: "constant"
+    }
+  }
+}
+layer {
+  name: "pool1"
+  type: "Pooling"
+  bottom: "conv1"
+  top: "pool1"
+  pooling_param {
+    pool: MAX
+    kernel_size: 3
+    stride: 2
+  }
+}
+layer {
+  name: "relu1"
+  type: "ReLU"
+  bottom: "pool1"
+  top: "pool1"
+}
+layer {
+  name: "conv2"
+  type: "Convolution"
+  bottom: "pool1"
+  top: "conv2"
+  param {
+    lr_mult: 1
+  }
+  param {
+    lr_mult: 2
+  }
+  convolution_param {
+    num_output: 32
+    pad: 2
+    kernel_size: 5
+    stride: 1
+    weight_filler {
+      type: "gaussian"
+      std: 0.01
+    }
+    bias_filler {
+      type: "constant"
+    }
+  }
+}
+layer {
+  name: "relu2"
+  type: "ReLU"
+  bottom: "conv2"
+  top: "conv2"
+}
+layer {
+  name: "pool2"
+  type: "Pooling"
+  bottom: "conv2"
+  top: "pool2"
+  pooling_param {
+    pool: AVE
+    kernel_size: 3
+    stride: 2
+  }
+}
+layer {
+  name: "conv3"
+  type: "Convolution"
+  bottom: "pool2"
+  top: "conv3"
+  param {
+    lr_mult: 1
+  }
+  param {
+    lr_mult: 2
+  }
+  convolution_param {
+    num_output: 64
+    pad: 2
+    kernel_size: 5
+    stride: 1
+    weight_filler {
+      type: "gaussian"
+      std: 0.01
+    }
+    bias_filler {
+      type: "constant"
+    }
+  }
+}
+layer {
+  name: "relu3"
+  type: "ReLU"
+  bottom: "conv3"
+  top: "conv3"
+}
+layer {
+  name: "pool3"
+  type: "Pooling"
+  bottom: "conv3"
+  top: "pool3"
+  pooling_param {
+    pool: AVE
+    kernel_size: 3
+    stride: 2
+  }
+}
+layer {
+  name: "ip1"
+  type: "InnerProduct"
+  bottom: "pool3"
+  top: "ip1"
+  param {
+    lr_mult: 1
+  }
+  param {
+    lr_mult: 2
+  }
+  inner_product_param {
+    num_output: 64
+    weight_filler {
+      type: "gaussian"
+      std: 0.1
+    }
+    bias_filler {
+      type: "constant"
+    }
+  }
+}
+layer {
+  name: "ip2"
+  type: "InnerProduct"
+  bottom: "ip1"
+  top: "ip2"
+  param {
+    lr_mult: 1
+  }
+  param {
+    lr_mult: 2
+  }
+  inner_product_param {
+    num_output: 10
+    weight_filler {
+      type: "gaussian"
+      std: 0.1
+    }
+    bias_filler {
+      type: "constant"
+    }
+  }
+}
+layer {
+  name: "accuracy"
+  type: "Accuracy"
+  bottom: "ip2"
+  bottom: "label"
+  top: "accuracy"
+  include {
+    phase: TEST
+  }
+}
+layer {
+  name: "loss"
+  type: "SoftmaxWithLoss"
+  bottom: "ip2"
+  bottom: "label"
+  top: "loss"
+}

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/77405ec1/examples/cifar10/train.py
----------------------------------------------------------------------
diff --git a/examples/cifar10/train.py b/examples/cifar10/train.py
index 7494b8b..6443cf4 100644
--- a/examples/cifar10/train.py
+++ b/examples/cifar10/train.py
@@ -30,6 +30,7 @@ from singa import optimizer
 from singa import device
 from singa import tensor
 from singa.proto import core_pb2
+from caffe import caffe_net
 
 import alexnet
 import vgg
@@ -105,6 +106,13 @@ def resnet_lr(epoch):
         return 0.001
 
 
+def caffe_lr(epoch):
+    if epoch < 8:
+        return 0.001
+    else:
+        return 0.0001
+
+
 def train(data, net, max_epoch, get_lr, weight_decay, batch_size=100,
           use_cpu=False):
     print 'Start intialization............'
@@ -163,8 +171,8 @@ def train(data, net, max_epoch, get_lr, weight_decay, batch_size=100,
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='Train dcnn for cifar10')
-    parser.add_argument('model', choices=['vgg', 'alexnet', 'resnet'],
-                        default='alexnet')
+    parser.add_argument('model', choices=['vgg', 'alexnet', 'resnet', 'caffe'],
+            default='alexnet')
     parser.add_argument('data', default='cifar-10-batches-py')
     parser.add_argument('--use_cpu', action='store_true')
     args = parser.parse_args()
@@ -173,10 +181,19 @@ if __name__ == '__main__':
     print 'Loading data ..................'
     train_x, train_y = load_train_data(args.data)
     test_x, test_y = load_test_data(args.data)
-    if args.model == 'alexnet':
+    if args.model == 'caffe':
+        train_x, test_x = normalize_for_alexnet(train_x, test_x)
+        net = caffe_net.create_net(args.use_cpu)
+        # for cifar10_full_train_test.prototxt
+        #train((train_x, train_y, test_x, test_y), net, 160, alexnet_lr, 0.004,
+        #      use_cpu=args.use_cpu)
+        # for cifar10_quick_train_test.prototxt
+        train((train_x, train_y, test_x, test_y), net, 18, caffe_lr, 0.004,
+              use_cpu=args.use_cpu)
+    elif args.model == 'alexnet':
         train_x, test_x = normalize_for_alexnet(train_x, test_x)
         net = alexnet.create_net(args.use_cpu)
-        train((train_x, train_y, test_x, test_y), net, 160, alexnet_lr, 0.004,
+        train((train_x, train_y, test_x, test_y), net, 2, alexnet_lr, 0.004,
               use_cpu=args.use_cpu)
     elif args.model == 'vgg':
         train_x, test_x = normalize_for_vgg(train_x, test_x)

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/77405ec1/python/singa/converter.py
----------------------------------------------------------------------
diff --git a/python/singa/converter.py b/python/singa/converter.py
new file mode 100644
index 0000000..1378af0
--- /dev/null
+++ b/python/singa/converter.py
@@ -0,0 +1,155 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+from google.protobuf import text_format
+from singa import layer
+from singa import metric
+from singa import loss
+from singa import net as ffnet
+from .proto import model_pb2
+from .proto import caffe_pb2
+
+
+class CaffeConverter:
+
+    def __init__(self, net_proto, solver_proto = None, input_sample_shape = None):
+        self.caffe_net_path = net_proto
+        self.caffe_solver_path = solver_proto
+        self.input_sample_shape = input_sample_shape
+
+    def read_net_proto(self):
+        net_config = caffe_pb2.NetParameter()
+        return self.read_proto(self.caffe_net_path, net_config)
+
+    def read_solver_proto(self):
+        solver_config = caffe_pb2.SolverParameter()
+        return self.read_proto(self.caffe_solver_path, solver_config)
+
+    def read_proto(self, filepath, parser_object):
+        file = open(filepath, "r")
+        if not file:
+            raise self.ProcessException("ERROR (" + filepath + ")!")
+        # Merges an ASCII representation of a protocol message into a message.
+        text_format.Merge(str(file.read()), parser_object)
+        file.close()
+        return parser_object
+
+    def convert_engine(self, layer_conf, solver_mode):
+        '''
+        Convert caffe engine into singa engine
+        return:
+            a singa engine string
+        '''
+        caffe_engine = ''
+        singa_engine = ''
+
+        # if no 'engine' field in caffe proto, set engine to -1
+        if layer_conf.type == 'Convolution' or layer_conf.type == 4:
+            caffe_engine = layer_conf.convolution_param.engine
+        elif layer_conf.type == 'Pooling' or layer_conf.type == 17:
+            caffe_engine = layer_conf.pooling_param.engine
+        elif layer_conf.type == 'ReLU' or layer_conf.type == 18:
+            caffe_engine = layer_conf.relu_param.engine
+        elif layer_conf.type == 'Sigmoid' or layer_conf.type == 19:
+            caffe_engine = layer_conf.sigmoid_param.engine
+        elif layer_conf.type == 'TanH' or layer_conf.type == 23:
+            caffe_engine = layer_conf.tanh_param.engine
+        elif layer_conf.type == 'LRN' or layer_conf.type == 15:
+            caffe_engine = layer_conf.lrn_param.engine
+        elif layer_conf.type == 'Softmax' or layer_conf.type == 20:
+            caffe_engine = layer_conf.softmax_param.engine
+        elif layer_conf.type == 'InnerProduct' or layer_conf.type == 14:
+            caffe_engine = -1
+        elif layer_conf.type == 'Dropout' or layer_conf.type == 6:
+            caffe_engine = -1
+        elif layer_conf.type == 'Flatten' or layer_conf.type == 8:
+            caffe_engine = -1
+        else:
+            raise Exception('Unknown layer type: ' + layer_conf.type)
+
+        # caffe_engine: -1-no field;  0-DEFAULT; 1-CAFFE; 2-CUDNN
+        # solver_mode: 0-CPU; 1-GPU
+        if solver_mode == 1:
+            singa_engine = 'cudnn'
+        else:
+            if caffe_engine == 2:
+                raise Exception('engine and solver mode mismatch!')
+            else:
+                singa_engine = 'singacpp'
+
+        if ((layer_conf.type == 'InnerProduct' or layer_conf.type == 14) or \
+            (layer_conf.type == 'Flatten' or layer_conf.type == 8)) and \
+            singa_engine == 'cudnn':
+            singa_engine = 'singacuda'
+
+        return singa_engine
+
+
+    def create_net(self):
+        '''
+        Create singa net based on caffe proto files.
+            net_proto: caffe prototxt that describes net
+            solver_proto: caffe prototxt that describe solver
+            input_sample_shape: shape of input data tensor
+        return:
+            a FeedForwardNet object
+        '''
+        caffe_net = self.read_net_proto()
+        if self.caffe_solver_path is not None:
+            caffe_solver = self.read_solver_proto()
+        layer_confs = ''
+        flatten_id = 0
+
+        # If the net proto has the input shape
+        if len(caffe_net.input_dim) > 0:
+            self.input_sample_shape = caffe_net.input_dim
+        if len(caffe_net.layer):
+            layer_confs = caffe_net.layer
+        elif len(caffe_net.layers):
+            layer_confs = caffe_net.layers
+        else:
+            raise Exception('Invalid proto file!')
+
+        net = ffnet.FeedForwardNet()
+        for i in range(len(layer_confs)):
+            if layer_confs[i].type == 'Data' or layer_confs[i].type == 5:
+                continue
+            elif layer_confs[i].type == 'SoftmaxWithLoss' or layer_confs[i].type == 21:
+                net.loss = loss.SoftmaxCrossEntropy()
+            elif layer_confs[i].type == 'EuclideanLoss' or layer_confs[i].type == 7:
+                net.loss = loss.SquareError()
+            elif layer_confs[i].type == 'Accuracy' or layer_confs[i].type == 1:
+                net.metric = metric.Accuracy()
+            else:
+                strConf = layer_confs[i].SerializeToString()
+                conf = model_pb2.LayerConf()
+                conf.ParseFromString(strConf)
+                if caffe_solver:
+                    layer.engine = self.convert_engine(
+                        layer_confs[i], caffe_solver.solver_mode)
+                else:
+                    layer.engine = self.convert_engine(layer_confs[i], 0)
+                lyr = layer.Layer(conf.name, conf)
+                if len(net.layers) == 0:
+                    lyr.setup(self.input_sample_shape)
+                    print lyr.name, lyr.get_output_sample_shape()
+                if layer_confs[i].type == 'InnerProduct' or layer_confs[i].type == 14:
+                    net.add(layer.Flatten('flat' + str(flatten_id)))
+                    flatten_id += 1
+                net.add(lyr)
+
+        return net

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/77405ec1/python/singa/layer.py
----------------------------------------------------------------------
diff --git a/python/singa/layer.py b/python/singa/layer.py
index a22af55..e376bbf 100644
--- a/python/singa/layer.py
+++ b/python/singa/layer.py
@@ -78,14 +78,52 @@ class Layer(object):
         name (str): layer name
     '''
 
-    def __init__(self, name, **kwargs):
-        self.layer = None  # layer converted by swig
-        self.name = name  # TODO(wangwei) duplicate with self.conf.name
-        self.conf = model_pb2.LayerConf()
-        self.conf.name = name
-        self.param_specs = []
+    def __init__(self, name, conf=None, **kwargs):
+        if conf == None:
+            self.layer = None  # layer converted by swig
+            self.name = name  # TODO(wangwei) duplicate with self.conf.name
+            self.conf = model_pb2.LayerConf()
+            self.conf.name = name
+            self.param_specs = []
+        else:
+            self.conf = conf
+            self.name = str(conf.name)
+            self.caffe_layer()
+            self.param_specs = []
+
+            # convert caffe proto into singa proto format
+            #   case1: parameters of conv and dense layers
+            #   case2: type of activation layers
+            if (conf.type == 'Convolution' or conf.type == 4) or \
+                (conf.type == 'InnerProduct' or conf.type == 14):
+                w, b = _construct_param_specs_from_caffe_proto(conf)
+                del conf.param[:]
+                conf.param.extend([w, b])
+                self.param_specs.append(w)
+                self.param_specs.append(b)
+                #print 'conf:\n', conf
+            #if conf.type == 'Pooling':
+            #    print 'conf:\n', conf
+
+            elif (conf.type == 'ReLU' or conf.type == 18) or \
+                (conf.type == 'Sigmoid' or conf.type == 19) or \
+                (conf.type == 'TanH' or conf.type == 23):
+                conf.type = (engine + '_' + conf.type).lower()
+            self.conf = conf
+
         self.has_setup = False
 
+    def caffe_layer(self):
+        '''
+        Create a singa layer based on caffe layer configuration.
+        '''
+        _check_engine(engine, ['cudnn', 'singacpp', 'singacuda'])
+        if self.conf.type == 'InnerProduct' or self.conf.type == 14:
+            self.layer = _create_layer(engine, 'Dense')
+        else:
+            self.layer = _create_layer(engine, str(self.conf.type))
+
+
     def param_names(self):
         '''
         Returns:
@@ -243,6 +281,7 @@ class Conv2D(Layer):
             without the batchsize, e.g., (channel, height, width) or
             (height, width, channel)
     """
+
     def __init__(self, name, nb_kernels, kernel=3, stride=1, border_mode='same',
                  cudnn_prefer='fatest', data_format='NCHW',
                  use_bias=True, W_specs=None, b_specs=None,
@@ -319,6 +358,7 @@ class Pooling2D(Layer):
             model_pb2.PoolingConf.AVE
 
     '''
+
     def __init__(self, name, mode, kernel=3, stride=2, border_mode='same',
                  pad=None, data_format='NCHW', input_sample_shape=None):
         super(Pooling2D, self).__init__(name)
@@ -421,6 +461,7 @@ class BatchNormalization(Layer):
         name (string): layer name
         input_sample_shape (tuple): with at least one integer
     """
+
     def __init__(self, name, momentum=0.9,
                  beta_specs=None, gamma_specs=None, input_sample_shape=None):
         super(BatchNormalization, self).__init__(name)
@@ -434,8 +475,8 @@ class BatchNormalization(Layer):
             beta_specs['name'] = name + '_beta'
         if 'name' not in gamma_specs:
             gamma_specs['name'] = name + '_gamma'
-        mean_specs = {'init': 'constant', 'value': 0, 'name': name+'_mean'}
-        var_specs = {'init': 'constant', 'value': 1, 'name': name+'_var'}
+        mean_specs = {'init': 'constant', 'value': 0, 'name': name + '_mean'}
+        var_specs = {'init': 'constant', 'value': 1, 'name': name + '_var'}
         self.conf.param.extend([_construct_param_specs_from_dict(gamma_specs)])
         self.conf.param.extend([_construct_param_specs_from_dict(beta_specs)])
         self.conf.param.extend([_construct_param_specs_from_dict(mean_specs)])
@@ -499,6 +540,7 @@ class Dense(Layer):
         W_transpose (bool): if true, output=x*W.T+b;
         input_sample_shape (tuple): input feature length
     """
+
     def __init__(self, name, num_output, use_bias=True,
                  W_specs=None, b_specs=None,
                  W_transpose=False, input_sample_shape=None):
@@ -579,6 +621,7 @@ class Activation(Layer):
         mode (string): 'relu', 'sigmoid', or 'tanh'
         input_sample_shape (tuple): shape of a single sample
     """
+
     def __init__(self, name, mode='relu', input_sample_shape=None):
         super(Activation, self).__init__(name)
         _check_engine(engine, ['cudnn', 'singacpp', 'singacuda', 'singacl'])
@@ -596,6 +639,7 @@ class Softmax(Layer):
             [0,axis) as the row, the [axis, -1) as the column.
         input_sample_shape (tuple): shape of a single sample
     """
+
     def __init__(self, name, axis=1, input_sample_shape=None):
         super(Softmax, self).__init__(name)
         # conf = self.conf.softmax_conf
@@ -615,6 +659,7 @@ class Flatten(Layer):
             [0,axis) as the row, the [axis, -1) as the column.
         input_sample_shape (tuple): shape for a single sample
     """
+
     def __init__(self, name, axis=1, input_sample_shape=None):
         super(Flatten, self).__init__(name)
         conf = self.conf.flatten_conf
@@ -635,6 +680,7 @@ class Merge(Layer):
         input_sample_shape: sample shape of the input. The sample shape of all
             inputs should be the same.
     '''
+
     def __init__(self, name, input_sample_shape=None):
         self.in_shape = input_sample_shape
         self.num_input = 1
@@ -669,6 +715,7 @@ class Split(Layer):
         input_sample_shape: includes a single integer for the input sample
             feature size.
     '''
+
     def __init__(self, name, num_output, input_sample_shape=None):
         self.num_output = num_output
         self.in_shape = input_sample_shape
@@ -806,6 +853,7 @@ class RNN(Layer):
 
 
 class LSTM(RNN):
+
     def __init__(self, name, hidden_size, dropout=0.0, num_stacks=1,
                  input_mode='linear', bidirectional=False,
                  param_specs=None, input_sample_shape=None):
@@ -815,6 +863,7 @@ class LSTM(RNN):
 
 
 class GRU(RNN):
+
     def __init__(self, name, hidden_size, dropout=0.0, num_stacks=1,
                  input_mode='linear', bidirectional=False, param_specs=None,
                  input_sample_shape=None):
@@ -825,8 +874,21 @@ class GRU(RNN):
 
 def _check_engine(engine, allowed_engines):
-    assert engine.lower() in Set(allowed_engines), \
-           '%s is not a supported engine. Pls use one of %s' % \
-           (engine, ', '.join(allowed_engines))
+    """Assert that `engine` is one of `allowed_engines`.
+
+    Args:
+        engine (str): engine name; it is lower-cased before the check, so
+            the entries of `allowed_engines` should be lower-case.
+        allowed_engines (list): names of the supported engines.
+
+    Raises:
+        AssertionError: if `engine` is not supported. Being an assert,
+            the check is skipped when Python runs with -O.
+    """
+    # a builtin set is enough for this membership test; the legacy `Set`
+    # type (presumably sets.Set) does not exist in Python 3
+    assert engine.lower() in set(allowed_engines), \
+        '%s is not a supported engine. Pls use one of %s' % \
+        (engine, ', '.join(allowed_engines))
 
 
 def _create_layer(eng, layer):
@@ -927,6 +977,58 @@ def _construct_param_specs_from_dict(specs):
     return conf
 
 
+def _construct_param_specs_from_caffe_proto(lyr_conf):
+    """Convert the param specs of a converted caffe layer into a pair of
+    singa ParamSpec protobuf objects (weight and bias).
+
+    Args:
+        lyr_conf: the layer config (singa LayerConf protobuf) obtained by
+            parsing a caffe layer with singa's model.proto. It reads
+            'name', 'type', the first two entries of 'param' ('name',
+            'lr_mult' and 'decay_mult' of the weight and of the bias) and,
+            for convolution / inner-product layers, the 'weight_filler'
+            and 'bias_filler' ('type', 'min', 'max', 'mean', 'std',
+            'value') of 'convolution_conf' / 'dense_conf'. Caffe models
+            have no 'constraint' or 'regularizer', so those stay unset.
+
+    Returns:
+        a tuple (wparam, bparam) of model_pb2.ParamSpec objects; names
+        missing in the caffe config default to '<layer name>_weight' and
+        '<layer name>_bias'.
+    """
+    wparam = model_pb2.ParamSpec()
+    bparam = model_pb2.ParamSpec()
+    # zip stops at the shorter sequence: with no 'param' entries (or no
+    # second one) the remaining spec keeps its protobuf defaults
+    for spec, caffe_spec in zip((wparam, bparam), lyr_conf.param):
+        spec.name = caffe_spec.name
+        spec.lr_mult = caffe_spec.lr_mult
+        spec.decay_mult = caffe_spec.decay_mult
+    # protobuf string fields default to '' (never None)
+    if not wparam.name:
+        wparam.name = lyr_conf.name + '_weight'
+    if not bparam.name:
+        bparam.name = lyr_conf.name + '_bias'
+
+    # the type may be a name or an integer code (presumably caffe's V1
+    # LayerType enum: CONVOLUTION = 4, INNER_PRODUCT = 14 -- confirm)
+    conf = None
+    if lyr_conf.type in ('Convolution', 4):
+        conf = lyr_conf.convolution_conf
+    elif lyr_conf.type in ('InnerProduct', 14):
+        conf = lyr_conf.dense_conf
+
+    # only convolution and dense layers carry weight/bias fillers
+    if conf is not None:
+        for filler, caffe_filler in ((wparam.filler, conf.weight_filler),
+                                     (bparam.filler, conf.bias_filler)):
+            filler.type = caffe_filler.type.lower()
+            for field in ('min', 'max', 'mean', 'std', 'value'):
+                setattr(filler, field, getattr(caffe_filler, field))
+
+    return (wparam, bparam)
+
+
 def get_layer_list():
     """ Return a list of strings which include the identifiers (tags) of all
     supported layers