You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@singa.apache.org by wa...@apache.org on 2016/08/09 16:02:54 UTC

[1/2] incubator-singa git commit: SINGA-231 Batchnormlized VGG model for cifar-10

Repository: incubator-singa
Updated Branches:
  refs/heads/dev db5478efa -> 28678ae83


SINGA-231 Batchnormlized VGG model for cifar-10

In this ticket, we implemented a batch-normalized VGG model for the cifar10
dataset (refer to http://torch.ch/blog/2015/07/30/cifar.html).

*    +vgg-parallel.cc for parallel training
*    +vgg.py using python language
*    fix a bug in the ResetLike() method in tensor.cc, which previously did
     not reset the shape when the existing memory block was reused.
*    fix a bug in local_updater.cc, which could cause a race condition when
     multiple threads try to initialize mutexes concurrently.
*    revise the batch normalization layer to support 2D tensor input


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/bc3b74b3
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/bc3b74b3
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/bc3b74b3

Branch: refs/heads/dev
Commit: bc3b74b3662230f867c42344f0600498368f4785
Parents: db5478e
Author: WANG Ji <ij...@gmail.com>
Authored: Sat Aug 6 17:36:28 2016 +0800
Committer: WANG Ji <ij...@gmail.com>
Committed: Mon Aug 8 11:44:01 2016 +0800

----------------------------------------------------------------------
 examples/cifar10/CMakeLists.txt       |   5 +
 examples/cifar10/train_vgg_cifar10.py | 162 ++++++++++++++
 examples/cifar10/vgg-parallel.cc      | 333 +++++++++++++++++++++++++++++
 examples/cifar10/vgg.py               |  52 +++++
 src/core/tensor/tensor.cc             |   2 +-
 src/model/layer/batchnorm.cc          |  25 ++-
 src/model/layer/batchnorm.h           |   3 +-
 src/model/layer/cudnn_batchnorm.cc    |  31 ++-
 src/model/updater/local_updater.cc    |   1 +
 src/python/singa/layer.py             |  10 +-
 src/python/singa/net.py               |   6 +-
 11 files changed, 613 insertions(+), 17 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/bc3b74b3/examples/cifar10/CMakeLists.txt
----------------------------------------------------------------------
diff --git a/examples/cifar10/CMakeLists.txt b/examples/cifar10/CMakeLists.txt
index 92f884c..76c0b73 100644
--- a/examples/cifar10/CMakeLists.txt
+++ b/examples/cifar10/CMakeLists.txt
@@ -10,4 +10,9 @@ ADD_EXECUTABLE(alexnet-parallel alexnet-parallel.cc)
 ADD_DEPENDENCIES(alexnet-parallel singa_core singa_model singa_utils)
 TARGET_LINK_LIBRARIES(alexnet-parallel singa_core singa_utils singa_model protobuf ${SINGA_LIBKER_LIBS})
 SET_TARGET_PROPERTIES(alexnet-parallel PROPERTIES LINK_FLAGS "${LINK_FLAGS} -pthread")
+
+ADD_EXECUTABLE(vgg-parallel vgg-parallel.cc)
+ADD_DEPENDENCIES(vgg-parallel singa_core singa_model singa_utils)
+TARGET_LINK_LIBRARIES(vgg-parallel singa_core singa_utils singa_model protobuf ${SINGA_LIBKER_LIBS})
+SET_TARGET_PROPERTIES(vgg-parallel PROPERTIES LINK_FLAGS "${LINK_FLAGS} -pthread")
 ENDIF(USE_CUDNN)

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/bc3b74b3/examples/cifar10/train_vgg_cifar10.py
----------------------------------------------------------------------
diff --git a/examples/cifar10/train_vgg_cifar10.py b/examples/cifar10/train_vgg_cifar10.py
new file mode 100644
index 0000000..e9df04e
--- /dev/null
+++ b/examples/cifar10/train_vgg_cifar10.py
@@ -0,0 +1,162 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+""" CIFAR10 dataset is at https://www.cs.toronto.edu/~kriz/cifar.html.
+It includes 5 binary dataset files, each containing 10000 images. 1 row (1 image)
+includes 1 label & 3072 pixels.  3072 pixels are 3 channels of a 32x32 image
+"""
+
+import cPickle
+import numpy as np
+import os
+import sys
+import math
+
+sys.path.append(os.path.join(os.path.dirname(__file__), '../../build/python'))
+from singa import initializer
+from singa import utils
+from singa import optimizer
+from singa import device
+from singa import tensor
+from singa.proto import core_pb2
+
+import vgg
+
+
+def load_dataset(filepath):  # load one pickled CIFAR-10 batch file -> (image, label)
+    print 'Loading data file %s' % filepath
+    with open(filepath, 'rb') as fd:
+        cifar10 = cPickle.load(fd)  # NOTE(review): pickle is unsafe on untrusted files
+    image = cifar10['data'].astype(dtype=np.uint8)
+    image = image.reshape((-1, 3, 32, 32))  # one 3x32x32 array per image
+    label = np.asarray(cifar10['labels'], dtype=np.uint8)
+    label = label.reshape(label.size, 1)  # one label per row, shape (N, 1)
+    return image, label
+
+
+def load_train_data(dir_path, num_batches=5):  # read data_batch_1..num_batches under dir_path
+    labels = []
+    batchsize = 10000  # images per batch file (see the module docstring)
+    images = np.empty((num_batches * batchsize, 3, 32, 32), dtype=np.uint8)
+    for did in range(1, num_batches + 1):
+        fname_train_data = dir_path + "/data_batch_{}".format(did)
+        image, label = load_dataset(fname_train_data)
+        images[(did - 1) * batchsize:did * batchsize] = image  # fill the slice owned by this file
+        labels.extend(label)
+    images = np.array(images, dtype=np.float32)  # copy as float32
+    labels = np.array(labels, dtype=np.int32)
+    return images, labels
+
+
+def load_test_data(dir_path):  # read dir_path/test_batch as (float32 images, int32 labels)
+    images, labels = load_dataset(dir_path + "/test_batch")
+    return np.array(images,  dtype=np.float32), np.array(labels, dtype=np.int32)
+
+
+def get_lr(epoch):  # learning rate to use for the given epoch index
+    return 0.01 / float(1 << ((epoch / 30)))  # 0.01, halved every 30 epochs; relies on Python 2 integer division
+    #if epoch < 100:
+    #    return 0.01
+    #elif epoch < 150:
+    #    return 0.005
+    #elif epoch < 200:
+    #    return 0.001
+    #elif epoch < 250:
+    #    return 0.0001
+
+
+def train(data_dir, net, num_epoch=250, batch_size=128):  # train and evaluate net on data in data_dir, then save it
+    print 'Creating Device............'
+    cuda = device.create_cuda_gpus(2)[1]  # NOTE(review): asks for 2 GPUs and uses the one at index 1 - confirm intended
+    net.to_device(cuda)
+    print 'Start intialization............'
+    opt = optimizer.SGD(momentum=0.9, weight_decay=0.0005)
+    for (p, name) in zip(net.param_values(), net.param_names()):
+        print name, p.shape
+        if len(p.shape) > 1:  # only params with more than one dimension get the name-based init
+            if 'mean' in name  or 'beta' in name:
+                p.set_value(0.0)
+            elif 'var' in name:
+                p.set_value(1.0)
+            elif 'gamma' in name:
+                initializer.uniform(p, 0, 1)
+            elif 'conv' in name:
+                initializer.gaussian(p, 0, math.sqrt(2.0/(9.0 * p.shape[0])))  # std = sqrt(2 / (9 * p.shape[0]))
+            else:
+                initializer.gaussian(p, 0, 0.02)
+
+                #stdv = 1.0/math.sqrt(p.shape[1])
+                #initializer.uniform(p, -stdv, stdv)
+        else:  # NOTE(review): any 1-D 'gamma'/'var' param would land here and be zeroed - confirm shapes
+            p.set_value(0)
+        #print specs.name, filler.type, p.l1()
+        print name, p.l1()
+    print 'Loading data ..................'
+    train_x, train_y = load_train_data(data_dir)
+    test_x, test_y = load_test_data(data_dir)
+    mean = train_x.mean()  # scalar mean over the whole training array
+    std = train_x.std()  # scalar std over the whole training array
+    train_x -= mean
+    test_x -= mean  # test data is normalized with the training statistics
+    train_x /= std
+    test_x /= std
+
+    tx = tensor.Tensor((batch_size, 3, 32, 32), cuda)  # reusable device buffer for one input batch
+    ty = tensor.Tensor((batch_size,), cuda, core_pb2.kInt)  # reusable device buffer for one label batch
+    num_train_batch = train_x.shape[0] / batch_size  # Python 2 integer division: a trailing partial batch is dropped
+    num_test_batch = test_x.shape[0] / batch_size
+    idx = np.arange(train_x.shape[0], dtype=np.int32)
+    for epoch in range(num_epoch):
+        np.random.shuffle(idx)  # new sample order for every epoch
+        loss, acc = 0.0, 0.0
+        print 'Epoch %d' % epoch
+        for b in range(num_train_batch):
+            x = train_x[idx[b * batch_size: (b + 1) * batch_size]]
+            y = train_y[idx[b * batch_size: (b + 1) * batch_size]]
+            tx.copy_from_numpy(x)
+            ty.copy_from_numpy(y)
+            grads, (l, a) = net.train(tx, ty)  # l and a are accumulated below as loss and accuracy
+            loss += l
+            acc += a
+            for (s, p, g) in zip(net.param_specs(), net.param_values(), grads):
+                opt.apply_with_lr(epoch, get_lr(epoch), g, p, str(s.name))  # update each param with this epoch's rate
+            # update progress bar
+            utils.update_progress(b * 1.0 / num_train_batch,
+                                  'training loss = %f, accuracy = %f' % (l, a))
+        info = '\ntraining loss = %f, training accuracy = %f' \
+            % (loss / num_train_batch, acc / num_train_batch)
+        print info
+
+        loss, acc = 0.0, 0.0  # reset the accumulators for the test pass
+        for b in range(num_test_batch):
+            x = test_x[b * batch_size: (b + 1) * batch_size]  # test batches are taken in order, unshuffled
+            y = test_y[b * batch_size: (b + 1) * batch_size]
+            tx.copy_from_numpy(x)
+            ty.copy_from_numpy(y)
+            l, a = net.evaluate(tx, ty)
+            loss += l
+            acc += a
+
+        print 'test loss = %f, test accuracy = %f' \
+            % (loss / num_test_batch, acc / num_test_batch)
+    net.save('model.bin')  # save model params into checkpoint file
+
+if __name__ == '__main__':  # script entry: build the VGG net and train it with the defaults
+    data_dir = 'cifar-10-batches-py'  # dataset directory, relative to the working directory
+    assert os.path.exists(data_dir), \
+        'Pls download the cifar10 dataset via "download_data.py py"'
+    net = vgg.create_net()
+    train(data_dir, net)

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/bc3b74b3/examples/cifar10/vgg-parallel.cc
----------------------------------------------------------------------
diff --git a/examples/cifar10/vgg-parallel.cc b/examples/cifar10/vgg-parallel.cc
new file mode 100644
index 0000000..ba308e9
--- /dev/null
+++ b/examples/cifar10/vgg-parallel.cc
@@ -0,0 +1,333 @@
+/************************************************************
+*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements.  See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership.  The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License.  You may obtain a copy of the License at
+*
+*   http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied.  See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*
+*************************************************************/
+
+#include "cifar10.h"
+#include "singa/model/feed_forward_net.h"
+#include "singa/model/optimizer.h"
+#include "singa/model/updater.h"
+#include "singa/model/initializer.h"
+#include "singa/model/metric.h"
+#include "singa/utils/channel.h"
+#include "singa/utils/string.h"
+#include "singa/core/memory.h"
+#include "../../src/model/layer/cudnn_convolution.h"
+#include "../../src/model/layer/cudnn_activation.h"
+#include "../../src/model/layer/cudnn_pooling.h"
+#include "../../src/model/layer/cudnn_lrn.h"
+#include "../../src/model/layer/cudnn_dropout.h"
+#include "../../src/model/layer/cudnn_batchnorm.h"
+#include "../../src/model/layer/dense.h"
+#include "../../src/model/layer/flatten.h"
+#include <thread>
+#include <memory>
+#include <cmath>
+
+namespace singa {
+
+const float default_wd  = 0.0005f;
+
+LayerConf GenConvConf(string name, int nb_filter, int kernel, int stride,  // builds a "CudnnConvolution" layer conf
+                      int pad, float std = .02f, float bias = .0f) {  // NOTE(review): `std` is never used below
+  LayerConf conf;
+  conf.set_name(name);
+  conf.set_type("CudnnConvolution");
+  ConvolutionConf *conv = conf.mutable_convolution_conf();
+  conv->set_num_output(nb_filter);
+  conv->add_kernel_size(kernel);
+  conv->add_stride(stride);
+  conv->add_pad(pad);
+  conv->set_bias_term(true);
+
+  ParamSpec *wspec = conf.add_param();  // first param spec: "<name>_weight"
+  wspec->set_name(name + "_weight");
+  auto wfill = wspec->mutable_filler();
+  wfill->set_type("Gaussian");
+  wfill->set_std(sqrt(2.0f/(nb_filter*9.0f)));  // std = sqrt(2 / (9 * nb_filter)); the 9 is hard-coded, not derived from `kernel`
+
+  ParamSpec *bspec = conf.add_param();  // second param spec: "<name>_bias"
+  bspec->set_name(name + "_bias");
+  auto bfill = bspec->mutable_filler();
+  bfill->set_value(bias);  // only the value is set; the filler type is left at its default
+  //  bspec->set_lr_mult(2);
+  //  bspec->set_decay_mult(0);
+  return conf;
+}
+
+LayerConf GenPoolingConf(string name, bool max_pool, int kernel, int stride,  // builds a "CudnnPooling" layer conf
+                         int pad) {
+  LayerConf conf;
+  conf.set_name(name);
+  conf.set_type("CudnnPooling");
+  PoolingConf *pool = conf.mutable_pooling_conf();
+  pool->set_kernel_size(kernel);
+  pool->set_stride(stride);
+  pool->set_pad(pad);
+  if (!max_pool) pool->set_pool(PoolingConf_PoolMethod_AVE);  // the method is only overridden for average pooling
+  return conf;
+}
+
+LayerConf GenReLUConf(string name) {  // builds a layer conf of type "RELU" with the given name
+  LayerConf conf;
+  conf.set_name(name);
+  conf.set_type("RELU");
+  return conf;
+}
+
+LayerConf GenDenseConf(string name, int num_output, float std, float wd = default_wd) {  // builds a "Dense" layer conf
+  LayerConf conf;
+  conf.set_name(name);
+  conf.set_type("Dense");
+  DenseConf *dense = conf.mutable_dense_conf();
+  dense->set_num_output(num_output);
+
+  ParamSpec *wspec = conf.add_param();  // first param spec: "<name>_weight"
+  wspec->set_name(name + "_weight");
+  wspec->set_decay_mult(wd);  // NOTE(review): `wd` (default 0.0005) is passed as a decay multiplier - confirm intended
+  auto wfill = wspec->mutable_filler();
+  wfill->set_type("Gaussian");
+  wfill->set_std(std);
+
+  ParamSpec *bspec = conf.add_param();  // second param spec: "<name>_bias"
+  bspec->set_name(name + "_bias");
+  bspec->set_lr_mult(2);
+  bspec->set_decay_mult(0);
+
+  return conf;
+}
+
+LayerConf GenFlattenConf(string name) {  // builds a layer conf of type "Flatten" with the given name
+  LayerConf conf;
+  conf.set_name(name);
+  conf.set_type("Flatten");
+  return conf;
+}
+
+LayerConf GenBatchNormConf(string name) {  // builds a "CudnnBatchNorm" conf with four param specs
+  LayerConf conf;
+  conf.set_name(name);
+  conf.set_type("CudnnBatchNorm");
+  ParamSpec *gammaspec = conf.add_param();  // spec 1: "<name>_gamma", uniform in [0, 1]
+  gammaspec->set_name(name + "_gamma");
+  auto gammafill = gammaspec->mutable_filler();
+  gammafill->set_type("uniform");
+  gammafill->set_min(0);
+  gammafill->set_max(1);
+
+  ParamSpec *betaspec = conf.add_param();  // spec 2: "<name>_beta", constant 0
+  betaspec->set_name(name + "_beta");
+  auto betafill = betaspec->mutable_filler();
+  betafill->set_type("constant");
+  betafill->set_value(0);
+
+  ParamSpec *meanspec = conf.add_param();  // spec 3: "<name>_mean", constant 0
+  meanspec->set_name(name + "_mean");
+  auto meanfill = meanspec->mutable_filler();
+  meanfill->set_type("constant");
+  meanfill->set_value(0);
+
+  ParamSpec *varspec = conf.add_param();  // spec 4: "<name>_var", constant 1
+  varspec->set_name(name + "_var");
+  auto varfill = varspec->mutable_filler();
+  varfill->set_type("constant");
+  varfill->set_value(1);
+
+  return conf;
+}
+
+LayerConf GenDropoutConf(string name, float dropout_ratio) {  // builds a "CudnnDropout" conf with the given ratio
+  LayerConf conf;
+  conf.set_name(name);
+  conf.set_type("CudnnDropout");
+  DropoutConf *dropout = conf.mutable_dropout_conf();
+  dropout->set_dropout_ratio(dropout_ratio);
+
+  return conf;
+}
+
+void ConvBNReLU(FeedForwardNet& net, string name, int nb_filter, Shape* shape = nullptr) {  // appends conv, batch norm, ReLU
+  net.Add(new CudnnConvolution(), GenConvConf(name+"_conv", nb_filter, 3, 1, 1), shape);  // kernel 3, stride 1, pad 1
+  net.Add(new CudnnBatchNorm(), GenBatchNormConf(name+"_bn"));
+  net.Add(new CudnnActivation(), GenReLUConf(name+"_relu"));
+}
+
+FeedForwardNet CreateNet() {  // builds the batch-normalized VGG network layer by layer
+  FeedForwardNet net;
+  Shape s{3, 32, 32};  // input sample shape, handed only to the first conv layer
+  ConvBNReLU(net, "conv1_1", 64, &s);
+  net.Add(new CudnnDropout(), GenDropoutConf("drop1", 0.3));
+  ConvBNReLU(net, "conv1_2", 64);
+  net.Add(new CudnnPooling(), GenPoolingConf("pool1", true, 2, 2, 0));  // max pooling, kernel 2, stride 2, pad 0
+  ConvBNReLU(net, "conv2_1", 128);
+  net.Add(new CudnnDropout(), GenDropoutConf("drop2", 0.4));
+  ConvBNReLU(net, "conv2_2", 128);
+  net.Add(new CudnnPooling(), GenPoolingConf("pool2", true, 2, 2, 0));
+  ConvBNReLU(net, "conv3_1", 256);
+  net.Add(new CudnnDropout(), GenDropoutConf("drop3_1", 0.4));
+  ConvBNReLU(net, "conv3_2", 256);
+  net.Add(new CudnnDropout(), GenDropoutConf("drop3_2", 0.4));
+  ConvBNReLU(net, "conv3_3", 256);
+  net.Add(new CudnnPooling(), GenPoolingConf("pool3", true, 2, 2, 0));
+  ConvBNReLU(net, "conv4_1", 512);
+  net.Add(new CudnnDropout(), GenDropoutConf("drop4_1", 0.4));
+  ConvBNReLU(net, "conv4_2", 512);
+  net.Add(new CudnnDropout(), GenDropoutConf("drop4_2", 0.4));
+  ConvBNReLU(net, "conv4_3", 512);
+  net.Add(new CudnnPooling(), GenPoolingConf("pool4", true, 2, 2, 0));
+  ConvBNReLU(net, "conv5_1", 512);
+  net.Add(new CudnnDropout(), GenDropoutConf("drop5_1", 0.4));
+  ConvBNReLU(net, "conv5_2", 512);
+  net.Add(new CudnnDropout(), GenDropoutConf("drop5_2", 0.4));
+  ConvBNReLU(net, "conv5_3", 512);
+  net.Add(new CudnnPooling(), GenPoolingConf("pool5", true, 2, 2, 0));
+  net.Add(new Flatten(), GenFlattenConf("flat"));  // flatten before the dense layers
+  net.Add(new CudnnDropout(), GenDropoutConf("flat_drop", 0.5));
+  net.Add(new Dense(), GenDenseConf("ip1", 512, 0.02));
+  net.Add(new CudnnBatchNorm(), GenBatchNormConf("ip1_bn"));  // batch norm applied to the dense layer's output
+  net.Add(new CudnnActivation(), GenReLUConf("ip1_relu"));
+  net.Add(new CudnnDropout(), GenDropoutConf("ip1_drop", 0.5));
+  net.Add(new Dense(), GenDenseConf("ip2", 10, 0.02));  // final layer with 10 outputs
+
+  return net;
+}
+
+void Train(float lr, int num_epoch, string data_dir) {  // trains two copies of the net, one per GPU, sharing one updater
+  Cifar10 data(data_dir);
+  Tensor train_x, train_y, test_x, test_y;
+  Tensor train_x_1, train_x_2, train_y_1, train_y_2;  // the two halves of the training set
+  {
+    auto train = data.ReadTrainData();
+    size_t nsamples = train.first.shape(0);
+    auto mtrain =
+        Reshape(train.first, Shape{nsamples, train.first.Size() / nsamples});  // view as (samples, features)
+    const Tensor &mean = Average(mtrain, 0);  // presumably the per-feature mean over samples - TODO confirm axis meaning
+    SubRow(mean, &mtrain);
+    Tensor std = Square(mtrain);
+    std = Average(std, 0);
+    std = Sqrt(std);;
+    std += 1e-6f;  // small offset added before dividing by std
+    DivRow(std, &mtrain);
+
+    train_x = Reshape(mtrain, train.first.shape());  // back to the original sample shape
+    train_y = train.second;
+
+    LOG(INFO) << "Slicing training data...";
+    train_x_1.Reshape(Shape{nsamples / 2, train.first.shape(1),
+        train.first.shape(2), train.first.shape(3)});
+    LOG(INFO) << "Copying first data slice...";
+    CopyDataToFrom(&train_x_1, train_x, train_x.Size() / 2);
+    train_x_2.Reshape(Shape{nsamples / 2, train.first.shape(1),
+        train.first.shape(2), train.first.shape(3)});
+    LOG(INFO) << "Copying second data slice...";
+    CopyDataToFrom(&train_x_2, train_x, train_x.Size() / 2, 0,  // presumably dst offset 0, src offset Size()/2 - TODO confirm
+                   train_x.Size() / 2);
+    train_y_1.Reshape(Shape{nsamples / 2});
+    train_y_1.AsType(kInt);
+    LOG(INFO) << "Copying first label slice...";
+    CopyDataToFrom(&train_y_1, train_y, train_y.Size() / 2);
+    train_y_2.Reshape(Shape{nsamples / 2});
+    train_y_2.AsType(kInt);
+    LOG(INFO) << "Copying second label slice...";
+    CopyDataToFrom(&train_y_2, train_y, train_y.Size() / 2, 0,
+                   train_y.Size() / 2);
+
+    auto test = data.ReadTestData();
+    nsamples = test.first.shape(0);
+    auto mtest =
+        Reshape(test.first, Shape{nsamples, test.first.Size() / nsamples});
+    SubRow(mean, &mtest);  // test data is normalized with the training mean and std
+    DivRow(std, &mtest);
+    test_x = Reshape(mtest, test.first.shape());
+    test_y = test.second;
+  }
+
+  CHECK_EQ(train_x.shape(0), train_y.shape(0));
+  CHECK_EQ(test_x.shape(0), test_y.shape(0));
+  LOG(INFO) << "Total Training samples = " << train_y.shape(0)
+            << ", Total Test samples = " << test_y.shape(0);
+  CHECK_EQ(train_x_1.shape(0), train_y_1.shape(0));
+  LOG(INFO) << "On net 1, Training samples = " << train_y_1.shape(0)
+            << ", Test samples = " << test_y.shape(0);
+  CHECK_EQ(train_x_2.shape(0), train_y_2.shape(0));
+  LOG(INFO) << "On net 2, Training samples = " << train_y_2.shape(0);
+
+  auto net_1 = CreateNet();
+  auto net_2 = CreateNet();
+
+  SGD sgd;
+  OptimizerConf opt_conf;
+  opt_conf.set_momentum(0.9);
+  auto reg = opt_conf.mutable_regularizer();
+  reg->set_coefficient(0.0005);
+  sgd.Setup(opt_conf);
+  sgd.SetLearningRateGenerator([lr](int epoch) {  // NOTE(review): `lr` is captured but unused; the base rate is hard-coded
+    return 0.01f / static_cast<float>(1u << (epoch/30));  // 0.01 halved every 30 epochs; shift count must stay below the bit width of unsigned
+  });
+
+  SoftmaxCrossEntropy loss_1, loss_2;
+  Accuracy acc_1, acc_2;
+  /// Create updater aggregating gradient on CPU
+  std::shared_ptr<Updater> updater = std::make_shared<LocalUpdater>(2, &sgd);
+
+  /// Only need to register parameter once.
+  net_1.Compile(true, true, updater, &loss_1, &acc_1);
+  net_2.Compile(true, false, updater, &loss_2, &acc_2);  // second flag is false here, true for net_1
+
+  MemPoolConf mem_conf;
+  mem_conf.add_device(0);
+  mem_conf.add_device(1);
+  std::shared_ptr<DeviceMemPool> mem_pool(new CnMemPool(mem_conf));  // one pool shared by both CudaGPU objects
+  std::shared_ptr<CudaGPU> cuda_1(new CudaGPU(0, mem_pool));
+  std::shared_ptr<CudaGPU> cuda_2(new CudaGPU(1, mem_pool));
+  net_1.ToDevice(cuda_1);
+  net_2.ToDevice(cuda_2);
+
+  train_x_1.ToDevice(cuda_1);
+  train_y_1.ToDevice(cuda_1);
+  test_x.ToDevice(cuda_1);  // test data goes to cuda_1 only; only net_1 is given it below
+  test_y.ToDevice(cuda_1);
+  train_x_2.ToDevice(cuda_2);
+  train_y_2.ToDevice(cuda_2);
+
+  LOG(INFO) << "Launching thread...";
+  std::thread t1 =
+      net_1.TrainThread(50, num_epoch, train_x_1, train_y_1, test_x, test_y);  // 50 is presumably the batch size - TODO confirm
+  std::thread t2 = net_2.TrainThread(50, num_epoch, train_x_2, train_y_2);
+  t1.join();  // wait for both training threads before returning
+  t2.join();
+}
+}
+
+int main(int argc, char **argv) {  // recognized flags: -epoch <n>, -lr <rate>, -data <dir>
+  singa::InitChannel(nullptr);
+  int pos = singa::ArgPos(argc, argv, "-epoch");
+  int nEpoch = 1;  // value kept when ArgPos returns -1
+  if (pos != -1) nEpoch = atoi(argv[pos + 1]);  // NOTE(review): assumes a value follows the flag - confirm ArgPos checks this
+  pos = singa::ArgPos(argc, argv, "-lr");
+  float lr = 0.001;  // NOTE(review): passed to Train, whose learning-rate lambda does not read it
+  if (pos != -1) lr = atof(argv[pos + 1]);
+  pos = singa::ArgPos(argc, argv, "-data");
+  string data = "cifar-10-batches-bin";  // default dataset directory
+  if (pos != -1) data = argv[pos + 1];
+
+  LOG(INFO) << "Start training";
+  singa::Train(lr, nEpoch, data);
+  LOG(INFO) << "End training";
+}

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/bc3b74b3/examples/cifar10/vgg.py
----------------------------------------------------------------------
diff --git a/examples/cifar10/vgg.py b/examples/cifar10/vgg.py
new file mode 100644
index 0000000..8063307
--- /dev/null
+++ b/examples/cifar10/vgg.py
@@ -0,0 +1,52 @@
+import sys
+import os
+
+sys.path.append(os.path.join(os.path.dirname(__file__), '../../build/python'))
+from singa import layer
+from singa import metric
+from singa import loss
+from singa import net as ffnet
+from singa.proto import core_pb2
+
+def ConvBnReLU(net, name, nb_filers, sample_shape=None):  # append Conv2D, BatchNormalization, Activation to net
+    net.add(layer.Conv2D(name + '_1', nb_filers, 3, 1, pad=1,  # layers are named name_1, name_2, name_3
+                         input_sample_shape=sample_shape))
+    net.add(layer.BatchNormalization(name + '_2'))
+    net.add(layer.Activation(name + '_3'))
+
+def create_net():  # build and return the batch-normalized VGG net
+    net = ffnet.FeedForwardNet(loss.SoftmaxCrossEntropy(), metric.Accuracy())
+    ConvBnReLU(net, 'conv1_1', 64, (3, 32, 32))  # only the first block is given an input sample shape
+    net.add(layer.Dropout('drop1', 0.3, engine='cudnn'))
+    ConvBnReLU(net, 'conv1_2', 64)
+    net.add(layer.MaxPooling2D('pool1', 2, 2, border_mode='valid'))
+    ConvBnReLU(net, 'conv2_1', 128)
+    net.add(layer.Dropout('drop2_1', 0.4, engine='cudnn'))
+    ConvBnReLU(net, 'conv2_2', 128)
+    net.add(layer.MaxPooling2D('pool2', 2, 2, border_mode='valid'))
+    ConvBnReLU(net, 'conv3_1', 256)
+    net.add(layer.Dropout('drop3_1', 0.4, engine='cudnn'))
+    ConvBnReLU(net, 'conv3_2', 256)
+    net.add(layer.Dropout('drop3_2', 0.4, engine='cudnn'))
+    ConvBnReLU(net, 'conv3_3', 256)
+    net.add(layer.MaxPooling2D('pool3', 2, 2, border_mode='valid'))
+    ConvBnReLU(net, 'conv4_1', 512)
+    net.add(layer.Dropout('drop4_1', 0.4, engine='cudnn'))
+    ConvBnReLU(net, 'conv4_2', 512)
+    net.add(layer.Dropout('drop4_2', 0.4, engine='cudnn'))
+    ConvBnReLU(net, 'conv4_3', 512)
+    net.add(layer.MaxPooling2D('pool4', 2, 2, border_mode='valid'))
+    ConvBnReLU(net, 'conv5_1', 512)
+    net.add(layer.Dropout('drop5_1', 0.4, engine='cudnn'))
+    ConvBnReLU(net, 'conv5_2', 512)
+    net.add(layer.Dropout('drop5_2', 0.4, engine='cudnn'))
+    ConvBnReLU(net, 'conv5_3', 512)
+    net.add(layer.MaxPooling2D('pool5', 2, 2, border_mode='valid'))
+    net.add(layer.Flatten('flat'))  # flatten before the dense layers
+    net.add(layer.Dropout('drop_flat', 0.5, engine='cudnn'))
+    net.add(layer.Dense('ip1', 512))
+    net.add(layer.BatchNormalization('batchnorm_ip1'))  # batch norm applied to the dense layer's output
+    net.add(layer.Activation('relu_ip1'))
+    net.add(layer.Dropout('drop_ip2', 0.5, engine='cudnn'))
+    net.add(layer.Dense('ip2', 10))  # final layer with 10 outputs
+    return net

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/bc3b74b3/src/core/tensor/tensor.cc
----------------------------------------------------------------------
diff --git a/src/core/tensor/tensor.cc b/src/core/tensor/tensor.cc
index 4972a86..c16bd29 100644
--- a/src/core/tensor/tensor.cc
+++ b/src/core/tensor/tensor.cc
@@ -80,11 +80,11 @@ void Tensor::ResetLike(const Tensor &in) {
   if (block_ == nullptr || device_ != in.device_ || MemSize() != in.MemSize()) {
     if (block_ != nullptr && block_->DecRefCount() == 0)
       device_->FreeBlock(block_);
-    shape_ = in.shape_;
     device_ = in.device_;
     data_type_ = in.data_type_;
     block_ = device_->NewBlock(in.MemSize());
   }
+  shape_ = in.shape_;
 }
 
 void Tensor::Reshape(const Shape &shape) {

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/bc3b74b3/src/model/layer/batchnorm.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/batchnorm.cc b/src/model/layer/batchnorm.cc
index b6edc9e..6ea9f2a 100644
--- a/src/model/layer/batchnorm.cc
+++ b/src/model/layer/batchnorm.cc
@@ -27,8 +27,18 @@ void BatchNorm::Setup(const Shape& in_sample, const LayerConf& conf) {
   out_sample_shape_ = in_sample;
   factor_ = conf.batchnorm_conf().factor();
   channels_ = in_sample.at(0);
-  height_ = in_sample.at(1);
-  width_ = in_sample.at(2);
+  if (in_sample.size() == 3u)
+    height_ = in_sample.at(1);
+  else
+    height_ = 1;
+  if (in_sample.size() == 3u)
+    width_ = in_sample.at(2);
+  else
+    width_ = 1;
+  if (in_sample.size() == 1u)
+    is_2d_ = true;
+  else
+    is_2d_ = false;
 
   bnScale_.Reshape(Shape{channels_ * height_ * width_});
   bnBias_.ResetLike(bnScale_);
@@ -92,7 +102,8 @@ const Tensor BatchNorm::Forward(int flag, const Tensor& input) {
     AddRow(bnBias_, &output);
   }
 
-  output.Reshape(Shape{output.shape(0), channels_, height_, width_});
+  if (!is_2d_)
+    output.Reshape(Shape{output.shape(0), channels_, height_, width_});
   return output;
 }
 
@@ -170,10 +181,16 @@ const std::pair<Tensor, vector<Tensor>> BatchNorm::Backward(
     SumRows(dy, &dbnBias_);
     param_grad.push_back(dbnScale_);
     param_grad.push_back(dbnBias_);
+    Tensor dummy;
+    dummy.ResetLike(runningMean_);
+    dummy.SetValue(.0f);
+    param_grad.push_back(dummy);
+    param_grad.push_back(dummy);
   } else {
     LOG(ERROR) << "Do not call backward for evaluation phase";
   }
-  dx.Reshape(Shape{dx.shape(0), channels_, height_, width_});
+  if (!is_2d_)
+    dx.Reshape(Shape{dx.shape(0), channels_, height_, width_});
   return std::make_pair(dx, param_grad);
 }
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/bc3b74b3/src/model/layer/batchnorm.h
----------------------------------------------------------------------
diff --git a/src/model/layer/batchnorm.h b/src/model/layer/batchnorm.h
index 6ff818b..f3d83ab 100644
--- a/src/model/layer/batchnorm.h
+++ b/src/model/layer/batchnorm.h
@@ -44,7 +44,7 @@ class BatchNorm : public Layer {
   /// \copydoc Layer::Backward(int, const Tensor&, const Tensor&);
   const std::pair<Tensor, vector<Tensor>> Backward(
       int flag, const Tensor& grad) override;
-  const std::vector<Tensor> param_values() override {
+  virtual const std::vector<Tensor> param_values() override {
     return std::vector<Tensor> { bnScale_, bnBias_, runningMean_,
                                  runningVariance_ };
   }
@@ -77,6 +77,7 @@ class BatchNorm : public Layer {
  protected:
   float factor_;
   size_t channels_, height_, width_;
+  bool is_2d_ = false;
   Tensor bnScale_, bnBias_;
   Tensor dbnScale_, dbnBias_;
   Tensor runningMean_, runningVariance_;

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/bc3b74b3/src/model/layer/cudnn_batchnorm.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/cudnn_batchnorm.cc b/src/model/layer/cudnn_batchnorm.cc
index 9e1e892..461f1b6 100644
--- a/src/model/layer/cudnn_batchnorm.cc
+++ b/src/model/layer/cudnn_batchnorm.cc
@@ -75,14 +75,20 @@ const Tensor CudnnBatchNorm::Forward(int flag, const Tensor& input) {
   auto shape = input.shape();
   auto dtype = input.data_type();
   Tensor output;
+  Tensor x;
+  if(is_2d_)
+    x = Reshape(input, Shape{shape.at(0), shape.at(1), 1, 1});
+  else
+    x = input;
+  shape = x.shape();
   if (!has_init_cudnn_)
     InitCudnn(shape, dtype);
   // TODO(wangji): check device id of input and params
-  output.ResetLike(input);
+  output.ResetLike(x);
   if ((flag & kTrain) == kTrain) {
     output.device()->Exec(
         [=](Context* ctx) {
-          Block *inBlock = input.block(), *outBlock = output.block(),
+          Block *inBlock = x.block(), *outBlock = output.block(),
             *saveMeanBlock = resultSaveMean_.block(),
             *saveVarBlock = resultSaveVariance_.block(),
             *runningMeanBlock = runningMean_.block(),
@@ -110,7 +116,7 @@ const Tensor CudnnBatchNorm::Forward(int flag, const Tensor& input) {
               saveMeanBlock->mutable_data(),
               saveVarBlock->mutable_data()));
         },
-        {input.block(),
+        {x.block(),
          bnScale_.block(),
          bnBias_.block()},
         {output.block(),
@@ -118,11 +124,11 @@ const Tensor CudnnBatchNorm::Forward(int flag, const Tensor& input) {
          runningVariance_.block(),
          resultSaveMean_.block(),
          resultSaveVariance_.block()});
-    buf_.push(input);
+    buf_.push(x);
   } else {
     output.device()->Exec(
         [=](Context* ctx) {
-          Block *inBlock = input.block(), *outBlock = output.block(),
+          Block *inBlock = x.block(), *outBlock = output.block(),
             *runningMeanBlock = runningMean_.block(),
             *runningVarBlock = runningVariance_.block(),
             *bnScaleBlock = bnScale_.block(),
@@ -145,13 +151,15 @@ const Tensor CudnnBatchNorm::Forward(int flag, const Tensor& input) {
               runningVarBlock->data(),
               epsilon));
         },
-        {input.block(),
+        {x.block(),
          bnScale_.block(),
          bnBias_.block(),
          runningMean_.block(),
          runningVariance_.block()},
         {output.block()});
   }
+  if (is_2d_)
+    output.Reshape(Shape{shape.at(0), shape.at(1)});
   return output;
 }
 
@@ -160,13 +168,13 @@ const std::pair<Tensor, vector<Tensor>> CudnnBatchNorm::Backward(
   vector <Tensor> param_grad;
   Tensor dx;
   if ((flag & kTrain) == kTrain) {
-    Tensor input = buf_.top();
+    Tensor x = buf_.top();
     buf_.pop();
     dx.ResetLike(grad);
     dx.device()->Exec(
         [=](Context* ctx) {
           Block *dyblock = grad.block(), *dxblock = dx.block(),
-            *xblock = input.block(),
+            *xblock = x.block(),
             *bnScaleBlock = bnScale_.block(),
             *dbnScaleBlock = dbnScale_.block(),
             *dbnBiasBlock = dbnBias_.block(),
@@ -208,6 +216,13 @@ const std::pair<Tensor, vector<Tensor>> CudnnBatchNorm::Backward(
   }
   param_grad.push_back(dbnScale_);
   param_grad.push_back(dbnBias_);
+  Tensor dummy;
+  dummy.ResetLike(dbnScale_);
+  dummy.SetValue(.0f);
+  param_grad.push_back(dummy);
+  param_grad.push_back(dummy);
+  if (is_2d_)
+    dx.Reshape(Shape{dx.shape().at(0), dx.shape().at(1)});
   return std::make_pair(dx, param_grad);
 }
 }  // namespace

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/bc3b74b3/src/model/updater/local_updater.cc
----------------------------------------------------------------------
diff --git a/src/model/updater/local_updater.cc b/src/model/updater/local_updater.cc
index eab4a7c..c3c6793 100644
--- a/src/model/updater/local_updater.cc
+++ b/src/model/updater/local_updater.cc
@@ -33,6 +33,7 @@ void LocalUpdater::Register(const string& name, const ParamSpec& specs) {
   }
   dev_index_[name] = 0;
   to_updater_finished_[name] = 0;
+  mtx_[name];
 }
 
 void LocalUpdater::Apply(int step, const string& name, Tensor& grad,

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/bc3b74b3/src/python/singa/layer.py
----------------------------------------------------------------------
diff --git a/src/python/singa/layer.py b/src/python/singa/layer.py
index 937a7e1..a443e1a 100644
--- a/src/python/singa/layer.py
+++ b/src/python/singa/layer.py
@@ -327,10 +327,16 @@ class BatchNormalization(Layer):
             beta_specs['name'] = name + '_beta'
         if 'name' not in gamma_specs:
             gamma_specs['name'] = name + '_gamma'
-        self.conf.param.extend([_construct_param_specs_from_dict(beta_specs)])
+        mean_specs = {'init': 'constant', 'value': 0, 'name': name+'_mean'}
+        var_specs = {'init': 'constant', 'value': 1, 'name': name+'_var'}
         self.conf.param.extend([_construct_param_specs_from_dict(gamma_specs)])
-        self.param_specs.append(_construct_param_specs_from_dict(beta_specs))
+        self.conf.param.extend([_construct_param_specs_from_dict(beta_specs)])
+        self.conf.param.extend([_construct_param_specs_from_dict(mean_specs)])
+        self.conf.param.extend([_construct_param_specs_from_dict(var_specs)])
         self.param_specs.append(_construct_param_specs_from_dict(gamma_specs))
+        self.param_specs.append(_construct_param_specs_from_dict(beta_specs))
+        self.param_specs.append(_construct_param_specs_from_dict(mean_specs))
+        self.param_specs.append(_construct_param_specs_from_dict(var_specs))
         _check_engine(engine, ['cudnn'])
         self.layer = _create_layer(engine, 'BatchNorm')
         if input_sample_shape is not None:

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/bc3b74b3/src/python/singa/net.py
----------------------------------------------------------------------
diff --git a/src/python/singa/net.py b/src/python/singa/net.py
index 084db4b..c0ba61d 100644
--- a/src/python/singa/net.py
+++ b/src/python/singa/net.py
@@ -64,6 +64,9 @@ class FeedForwardNet(object):
             specs.extend(lyr.param_specs)
         return specs
 
+    def param_names(self):
+        return [spec.name for spec in self.param_specs()]
+
     def train(self, x, y):
         out = self.forward(kTrain, x)
         l = self.loss.forward(kTrain, out, y)
@@ -89,9 +92,10 @@ class FeedForwardNet(object):
         return tensor.softmax(xx)
 
     def forward(self, flag, x):
+        #print x.l1()
         for lyr in self.layers:
             x = lyr.forward(flag, x)
-            # print lyr.name, x.l1()
+        #    print lyr.name, x.l1()
         return x
 
     def backward(self, flag=kTrain):


[2/2] incubator-singa git commit: SINGA-231 Batch-normalized VGG model for cifar-10

Posted by wa...@apache.org.
SINGA-231 Batch-normalized VGG model for cifar-10

Merge the training of VGG and AlexNet into train.py.
The validation accuracy of VGG can reach 0.89.


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/28678ae8
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/28678ae8
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/28678ae8

Branch: refs/heads/dev
Commit: 28678ae8329112ca1f11086b52ded7149ec9ab2c
Parents: bc3b74b
Author: Wei Wang <wa...@comp.nus.edu.sg>
Authored: Tue Aug 9 20:06:29 2016 +0800
Committer: Wei Wang <wa...@comp.nus.edu.sg>
Committed: Wed Aug 10 00:01:03 2016 +0800

----------------------------------------------------------------------
 examples/cifar10/alexnet.py           |  16 ++-
 examples/cifar10/predict.py           |  14 ++-
 examples/cifar10/run-parallel.sh      |   1 +
 examples/cifar10/train.py             |  63 +++++++----
 examples/cifar10/train_vgg_cifar10.py | 162 -----------------------------
 examples/cifar10/vgg-parallel.cc      |  24 ++---
 examples/cifar10/vgg.py               |  66 ++++++++++--
 7 files changed, 138 insertions(+), 208 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/28678ae8/examples/cifar10/alexnet.py
----------------------------------------------------------------------
diff --git a/examples/cifar10/alexnet.py b/examples/cifar10/alexnet.py
index 4b3daec..96c339a 100644
--- a/examples/cifar10/alexnet.py
+++ b/examples/cifar10/alexnet.py
@@ -14,15 +14,21 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # =============================================================================
+''' This model is created following the structure from
+https://code.google.com/p/cuda-convnet/source/browse/trunk/example-layers/layers-18pct.cfg
+Following the same setting for hyper-parameters and data pre-processing, the final
+validation accuracy would be about 82%.
+'''
+
 import sys
 import os
 
 sys.path.append(os.path.join(os.path.dirname(__file__), '../../build/python'))
 from singa import layer
+from singa import initializer
 from singa import metric
 from singa import loss
 from singa import net as ffnet
-from singa.proto import core_pb2
 
 
 def create_net():
@@ -44,4 +50,12 @@ def create_net():
     net.add(layer.MaxPooling2D('pool3', 3, 2, pad=1))
     net.add(layer.Flatten('flat'))
     net.add(layer.Dense('dense', 10, W_specs=W2_specs.copy(), b_specs=b_specs.copy()))
+    for (p, specs) in zip(net.param_values(), net.param_specs()):
+        filler = specs.filler
+        if filler.type == 'gaussian':
+            initializer.gaussian(p, filler.mean, filler.std)
+        else:
+            p.set_value(0)
+        print specs.name, filler.type, p.l1()
+
     return net

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/28678ae8/examples/cifar10/predict.py
----------------------------------------------------------------------
diff --git a/examples/cifar10/predict.py b/examples/cifar10/predict.py
index d083d0b..07b1145 100644
--- a/examples/cifar10/predict.py
+++ b/examples/cifar10/predict.py
@@ -14,7 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # =============================================================================
-
+import cPickle as pickle
 import numpy as np
 import sys
 import os
@@ -27,6 +27,15 @@ import net as ffnet
 
 
 def predict(net, images, cuda, topk=5):
+    '''Predict the label of each image.
+
+    Args:
+        net, a pretrained neural net
+        images, a batch of images [batch_size, 3, 32, 32], which have been
+            pre-processed
+        cuda, the cuda device
+        topk, return the topk labels for each image.
+    '''
     x = tensor.from_numpy(images.astype(np.float32))
     x.to_device(cuda)
     y = net.predict(x)
@@ -40,7 +49,7 @@ def predict(net, images, cuda, topk=5):
 def load_dataset(filepath):
     print 'Loading data file %s' % filepath
     with open(filepath, 'rb') as fd:
-        cifar10 = cPickle.load(fd)
+        cifar10 = pickle.load(fd)
     image = cifar10['data'].astype(dtype=np.uint8)
     image = image.reshape((-1, 3, 32, 32))
     label = np.asarray(cifar10['labels'], dtype=np.uint8)
@@ -79,4 +88,5 @@ if __name__ == '__main__':
 
     mean = compute_image_mean('cifar-10-batches-py')
     test_images, _ = load_test_data('cifar-10-batches-py')
+    # minus mean is for alexnet; vgg uses a different pre-processing strategy
     print predict(model, test_images - mean, cuda)

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/28678ae8/examples/cifar10/run-parallel.sh
----------------------------------------------------------------------
diff --git a/examples/cifar10/run-parallel.sh b/examples/cifar10/run-parallel.sh
index 6a9109a..18193db 100755
--- a/examples/cifar10/run-parallel.sh
+++ b/examples/cifar10/run-parallel.sh
@@ -1,2 +1,3 @@
 #!/usr/bin/env sh
 ../../build/bin/alexnet-parallel -epoch 4
+#../../build/bin/vgg-parallel -epoch 4

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/28678ae8/examples/cifar10/train.py
----------------------------------------------------------------------
diff --git a/examples/cifar10/train.py b/examples/cifar10/train.py
index f4caca4..cb4110d 100644
--- a/examples/cifar10/train.py
+++ b/examples/cifar10/train.py
@@ -23,9 +23,9 @@ import cPickle
 import numpy as np
 import os
 import sys
+import argparse
 
 sys.path.append(os.path.join(os.path.dirname(__file__), '../../build/python'))
-from singa import initializer
 from singa import utils
 from singa import optimizer
 from singa import device
@@ -33,6 +33,7 @@ from singa import tensor
 from singa.proto import core_pb2
 
 import alexnet
+import vgg
 
 
 def load_dataset(filepath):
@@ -65,7 +66,28 @@ def load_test_data(dir_path):
     return np.array(images,  dtype=np.float32), np.array(labels, dtype=np.int32)
 
 
-def get_lr(epoch):
+def normalize_for_vgg(train_x, test_x):
+    mean = train_x.mean()
+    std = train_x.std()
+    train_x -= mean
+    test_x -= mean
+    train_x /= std
+    test_x /= std
+    return train_x, test_x
+
+
+def normalize_for_alexnet(train_x, test_x):
+    mean = np.average(train_x, axis=0)
+    train_x -= mean
+    test_x -= mean
+    return train_x, test_x
+
+
+def vgg_lr(epoch):
+    return 0.01 / float(1 << ((epoch / 30)))
+
+
+def alexnet_lr(epoch):
     if epoch < 120:
         return 0.001
     elif epoch < 130:
@@ -74,32 +96,21 @@ def get_lr(epoch):
         return 0.00001
 
 
-def train(data_dir, net, num_epoch=140, batch_size=100):
+def train(data, net, max_epoch, get_lr, weight_decay, batch_size=100):
     print 'Start intialization............'
     cuda = device.create_cuda_gpu()
     net.to_device(cuda)
     opt = optimizer.SGD(momentum=0.9, weight_decay=0.004)
     for (p, specs) in zip(net.param_values(), net.param_specs()):
-        filler = specs.filler
-        if filler.type == 'gaussian':
-            initializer.gaussian(p, filler.mean, filler.std)
-        else:
-            p.set_value(0)
         opt.register(p, specs)
-        print specs.name, filler.type, p.l1()
-    print 'Loading data ..................'
-    train_x, train_y = load_train_data(data_dir)
-    test_x, test_y = load_test_data(data_dir)
-    mean = np.average(train_x, axis=0)
-    train_x -= mean
-    test_x -= mean
 
     tx = tensor.Tensor((batch_size, 3, 32, 32), cuda)
     ty = tensor.Tensor((batch_size,), cuda, core_pb2.kInt)
+    train_x, train_y, test_x, test_y = data
     num_train_batch = train_x.shape[0] / batch_size
     num_test_batch = test_x.shape[0] / batch_size
     idx = np.arange(train_x.shape[0], dtype=np.int32)
-    for epoch in range(num_epoch):
+    for epoch in range(max_epoch):
         np.random.shuffle(idx)
         loss, acc = 0.0, 0.0
         print 'Epoch %d' % epoch
@@ -135,8 +146,20 @@ def train(data_dir, net, num_epoch=140, batch_size=100):
     net.save('model.bin')  # save model params into checkpoint file
 
 if __name__ == '__main__':
-    data_dir = 'cifar-10-batches-py'
-    assert os.path.exists(data_dir), \
+    parser = argparse.ArgumentParser(description='Train vgg/alexnet for cifar10')
+    parser.add_argument('model', choices=['vgg', 'alexnet'], default='alexnet')
+    parser.add_argument('data', default='cifar-10-batches-py')
+    args = parser.parse_args()
+    assert os.path.exists(args.data), \
         'Pls download the cifar10 dataset via "download_data.py py"'
-    net = alexnet.create_net()
-    train(data_dir, net)
+    print 'Loading data ..................'
+    train_x, train_y = load_train_data(args.data)
+    test_x, test_y = load_test_data(args.data)
+    if args.model == 'alexnet':
+        train_x, test_x = normalize_for_alexnet(train_x, test_x)
+        net = alexnet.create_net()
+        train((train_x, train_y, test_x, test_y), net, 140, alexnet_lr, 0.004)
+    else:
+        train_x, test_x = normalize_for_vgg(train_x, test_x)
+        net = vgg.create_net()
+        train((train_x, train_y, test_x, test_y), net, 250, vgg_lr, 0.0005)

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/28678ae8/examples/cifar10/train_vgg_cifar10.py
----------------------------------------------------------------------
diff --git a/examples/cifar10/train_vgg_cifar10.py b/examples/cifar10/train_vgg_cifar10.py
deleted file mode 100644
index e9df04e..0000000
--- a/examples/cifar10/train_vgg_cifar10.py
+++ /dev/null
@@ -1,162 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# =============================================================================
-""" CIFAR10 dataset is at https://www.cs.toronto.edu/~kriz/cifar.html.
-It includes 5 binary dataset, each contains 10000 images. 1 row (1 image)
-includes 1 label & 3072 pixels.  3072 pixels are 3 channels of a 32x32 image
-"""
-
-import cPickle
-import numpy as np
-import os
-import sys
-import math
-
-sys.path.append(os.path.join(os.path.dirname(__file__), '../../build/python'))
-from singa import initializer
-from singa import utils
-from singa import optimizer
-from singa import device
-from singa import tensor
-from singa.proto import core_pb2
-
-import vgg
-
-
-def load_dataset(filepath):
-    print 'Loading data file %s' % filepath
-    with open(filepath, 'rb') as fd:
-        cifar10 = cPickle.load(fd)
-    image = cifar10['data'].astype(dtype=np.uint8)
-    image = image.reshape((-1, 3, 32, 32))
-    label = np.asarray(cifar10['labels'], dtype=np.uint8)
-    label = label.reshape(label.size, 1)
-    return image, label
-
-
-def load_train_data(dir_path, num_batches=5):
-    labels = []
-    batchsize = 10000
-    images = np.empty((num_batches * batchsize, 3, 32, 32), dtype=np.uint8)
-    for did in range(1, num_batches + 1):
-        fname_train_data = dir_path + "/data_batch_{}".format(did)
-        image, label = load_dataset(fname_train_data)
-        images[(did - 1) * batchsize:did * batchsize] = image
-        labels.extend(label)
-    images = np.array(images, dtype=np.float32)
-    labels = np.array(labels, dtype=np.int32)
-    return images, labels
-
-
-def load_test_data(dir_path):
-    images, labels = load_dataset(dir_path + "/test_batch")
-    return np.array(images,  dtype=np.float32), np.array(labels, dtype=np.int32)
-
-
-def get_lr(epoch):
-    return 0.01 / float(1 << ((epoch / 30)))
-    #if epoch < 100:
-    #    return 0.01
-    #elif epoch < 150:
-    #    return 0.005
-    #elif epoch < 200:
-    #    return 0.001
-    #elif epoch < 250:
-    #    return 0.0001
-
-
-def train(data_dir, net, num_epoch=250, batch_size=128):
-    print 'Creating Device............'
-    cuda = device.create_cuda_gpus(2)[1]
-    net.to_device(cuda)
-    print 'Start intialization............'
-    opt = optimizer.SGD(momentum=0.9, weight_decay=0.0005)
-    for (p, name) in zip(net.param_values(), net.param_names()):
-        print name, p.shape
-        if len(p.shape) > 1:
-            if 'mean' in name  or 'beta' in name:
-                p.set_value(0.0)
-            elif 'var' in name:
-                p.set_value(1.0)
-            elif 'gamma' in name:
-                initializer.uniform(p, 0, 1)
-            elif 'conv' in name:
-                initializer.gaussian(p, 0, math.sqrt(2.0/(9.0 * p.shape[0])))
-            else:
-                initializer.gaussian(p, 0, 0.02)
-
-                #stdv = 1.0/math.sqrt(p.shape[1])
-                #initializer.uniform(p, -stdv, stdv)
-        else:
-            p.set_value(0)
-        #print specs.name, filler.type, p.l1()
-        print name, p.l1()
-    print 'Loading data ..................'
-    train_x, train_y = load_train_data(data_dir)
-    test_x, test_y = load_test_data(data_dir)
-    mean = train_x.mean()
-    std = train_x.std()
-    train_x -= mean
-    test_x -= mean
-    train_x /= std
-    test_x /= std
-
-    tx = tensor.Tensor((batch_size, 3, 32, 32), cuda)
-    ty = tensor.Tensor((batch_size,), cuda, core_pb2.kInt)
-    num_train_batch = train_x.shape[0] / batch_size
-    num_test_batch = test_x.shape[0] / batch_size
-    idx = np.arange(train_x.shape[0], dtype=np.int32)
-    for epoch in range(num_epoch):
-        np.random.shuffle(idx)
-        loss, acc = 0.0, 0.0
-        print 'Epoch %d' % epoch
-        for b in range(num_train_batch):
-            x = train_x[idx[b * batch_size: (b + 1) * batch_size]]
-            y = train_y[idx[b * batch_size: (b + 1) * batch_size]]
-            tx.copy_from_numpy(x)
-            ty.copy_from_numpy(y)
-            grads, (l, a) = net.train(tx, ty)
-            loss += l
-            acc += a
-            for (s, p, g) in zip(net.param_specs(), net.param_values(), grads):
-                opt.apply_with_lr(epoch, get_lr(epoch), g, p, str(s.name))
-            # update progress bar
-            utils.update_progress(b * 1.0 / num_train_batch,
-                                  'training loss = %f, accuracy = %f' % (l, a))
-        info = '\ntraining loss = %f, training accuracy = %f' \
-            % (loss / num_train_batch, acc / num_train_batch)
-        print info
-
-        loss, acc = 0.0, 0.0
-        for b in range(num_test_batch):
-            x = test_x[b * batch_size: (b + 1) * batch_size]
-            y = test_y[b * batch_size: (b + 1) * batch_size]
-            tx.copy_from_numpy(x)
-            ty.copy_from_numpy(y)
-            l, a = net.evaluate(tx, ty)
-            loss += l
-            acc += a
-
-        print 'test loss = %f, test accuracy = %f' \
-            % (loss / num_test_batch, acc / num_test_batch)
-    net.save('model.bin')  # save model params into checkpoint file
-
-if __name__ == '__main__':
-    data_dir = 'cifar-10-batches-py'
-    assert os.path.exists(data_dir), \
-        'Pls download the cifar10 dataset via "download_data.py py"'
-    net = vgg.create_net()
-    train(data_dir, net)

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/28678ae8/examples/cifar10/vgg-parallel.cc
----------------------------------------------------------------------
diff --git a/examples/cifar10/vgg-parallel.cc b/examples/cifar10/vgg-parallel.cc
index ba308e9..c6b7fa1 100644
--- a/examples/cifar10/vgg-parallel.cc
+++ b/examples/cifar10/vgg-parallel.cc
@@ -32,7 +32,7 @@
 #include "../../src/model/layer/cudnn_activation.h"
 #include "../../src/model/layer/cudnn_pooling.h"
 #include "../../src/model/layer/cudnn_lrn.h"
-#include "../../src/model/layer/cudnn_dropout.h"
+#include "../../src/model/layer/dropout.h"
 #include "../../src/model/layer/cudnn_batchnorm.h"
 #include "../../src/model/layer/dense.h"
 #include "../../src/model/layer/flatten.h"
@@ -155,7 +155,7 @@ LayerConf GenBatchNormConf(string name) {
 LayerConf GenDropoutConf(string name, float dropout_ratio) {
   LayerConf conf;
   conf.set_name(name);
-  conf.set_type("CudnnDropout");
+  conf.set_type("Dropout");
   DropoutConf *dropout = conf.mutable_dropout_conf();
   dropout->set_dropout_ratio(dropout_ratio);
 
@@ -172,37 +172,37 @@ FeedForwardNet CreateNet() {
   FeedForwardNet net;
   Shape s{3, 32, 32};
   ConvBNReLU(net, "conv1_1", 64, &s);
-  net.Add(new CudnnDropout(), GenDropoutConf("drop1", 0.3));
+  net.Add(new Dropout(), GenDropoutConf("drop1", 0.3));
   ConvBNReLU(net, "conv1_2", 64);
   net.Add(new CudnnPooling(), GenPoolingConf("pool1", true, 2, 2, 0));
   ConvBNReLU(net, "conv2_1", 128);
-  net.Add(new CudnnDropout(), GenDropoutConf("drop2", 0.4));
+  net.Add(new Dropout(), GenDropoutConf("drop2", 0.4));
   ConvBNReLU(net, "conv2_2", 128);
   net.Add(new CudnnPooling(), GenPoolingConf("pool2", true, 2, 2, 0));
   ConvBNReLU(net, "conv3_1", 256);
-  net.Add(new CudnnDropout(), GenDropoutConf("drop3_1", 0.4));
+  net.Add(new Dropout(), GenDropoutConf("drop3_1", 0.4));
   ConvBNReLU(net, "conv3_2", 256);
-  net.Add(new CudnnDropout(), GenDropoutConf("drop3_2", 0.4));
+  net.Add(new Dropout(), GenDropoutConf("drop3_2", 0.4));
   ConvBNReLU(net, "conv3_3", 256);
   net.Add(new CudnnPooling(), GenPoolingConf("pool3", true, 2, 2, 0));
   ConvBNReLU(net, "conv4_1", 512);
-  net.Add(new CudnnDropout(), GenDropoutConf("drop4_1", 0.4));
+  net.Add(new Dropout(), GenDropoutConf("drop4_1", 0.4));
   ConvBNReLU(net, "conv4_2", 512);
-  net.Add(new CudnnDropout(), GenDropoutConf("drop4_2", 0.4));
+  net.Add(new Dropout(), GenDropoutConf("drop4_2", 0.4));
   ConvBNReLU(net, "conv4_3", 512);
   net.Add(new CudnnPooling(), GenPoolingConf("pool4", true, 2, 2, 0));
   ConvBNReLU(net, "conv5_1", 512);
-  net.Add(new CudnnDropout(), GenDropoutConf("drop5_1", 0.4));
+  net.Add(new Dropout(), GenDropoutConf("drop5_1", 0.4));
   ConvBNReLU(net, "conv5_2", 512);
-  net.Add(new CudnnDropout(), GenDropoutConf("drop5_2", 0.4));
+  net.Add(new Dropout(), GenDropoutConf("drop5_2", 0.4));
   ConvBNReLU(net, "conv5_3", 512);
   net.Add(new CudnnPooling(), GenPoolingConf("pool5", true, 2, 2, 0));
   net.Add(new Flatten(), GenFlattenConf("flat"));
-  net.Add(new CudnnDropout(), GenDropoutConf("flat_drop", 0.5));
+  net.Add(new Dropout(), GenDropoutConf("flat_drop", 0.5));
   net.Add(new Dense(), GenDenseConf("ip1", 512, 0.02));
   net.Add(new CudnnBatchNorm(), GenBatchNormConf("ip1_bn"));
   net.Add(new CudnnActivation(), GenReLUConf("ip1_relu"));
-  net.Add(new CudnnDropout(), GenDropoutConf("ip1_drop", 0.5));
+  net.Add(new Dropout(), GenDropoutConf("ip1_drop", 0.5));
   net.Add(new Dense(), GenDenseConf("ip2", 10, 0.02));
 
   return net;

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/28678ae8/examples/cifar10/vgg.py
----------------------------------------------------------------------
diff --git a/examples/cifar10/vgg.py b/examples/cifar10/vgg.py
index 8063307..0b9bb56 100644
--- a/examples/cifar10/vgg.py
+++ b/examples/cifar10/vgg.py
@@ -1,12 +1,37 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+""" The VGG model is adapted from http://torch.ch/blog/2015/07/30/cifar.html.
+The best validation accuracy we achieved is about 89% without data augmentation.
+The performance could be improved by tuning some hyper-parameters, including
+learning rate, weight decay, max_epoch, parameter initialization, etc.
+"""
+
 import sys
 import os
+import math
 
 sys.path.append(os.path.join(os.path.dirname(__file__), '../../build/python'))
+
 from singa import layer
+from singa import initializer
 from singa import metric
 from singa import loss
 from singa import net as ffnet
-from singa.proto import core_pb2
+
 
 def ConvBnReLU(net, name, nb_filers, sample_shape=None):
     net.add(layer.Conv2D(name + '_1', nb_filers, 3, 1, pad=1,
@@ -14,39 +39,58 @@ def ConvBnReLU(net, name, nb_filers, sample_shape=None):
     net.add(layer.BatchNormalization(name + '_2'))
     net.add(layer.Activation(name + '_3'))
 
+
 def create_net():
     net = ffnet.FeedForwardNet(loss.SoftmaxCrossEntropy(), metric.Accuracy())
     ConvBnReLU(net, 'conv1_1', 64, (3, 32, 32))
-    net.add(layer.Dropout('drop1', 0.3, engine='cudnn'))
+    net.add(layer.Dropout('drop1', 0.3, engine='cuda'))
     ConvBnReLU(net, 'conv1_2', 64)
     net.add(layer.MaxPooling2D('pool1', 2, 2, border_mode='valid'))
     ConvBnReLU(net, 'conv2_1', 128)
-    net.add(layer.Dropout('drop2_1', 0.4, engine='cudnn'))
+    net.add(layer.Dropout('drop2_1', 0.4, engine='cuda'))
     ConvBnReLU(net, 'conv2_2', 128)
     net.add(layer.MaxPooling2D('pool2', 2, 2, border_mode='valid'))
     ConvBnReLU(net, 'conv3_1', 256)
-    net.add(layer.Dropout('drop3_1', 0.4, engine='cudnn'))
+    net.add(layer.Dropout('drop3_1', 0.4, engine='cuda'))
     ConvBnReLU(net, 'conv3_2', 256)
-    net.add(layer.Dropout('drop3_2', 0.4, engine='cudnn'))
+    net.add(layer.Dropout('drop3_2', 0.4, engine='cuda'))
     ConvBnReLU(net, 'conv3_3', 256)
     net.add(layer.MaxPooling2D('pool3', 2, 2, border_mode='valid'))
     ConvBnReLU(net, 'conv4_1', 512)
-    net.add(layer.Dropout('drop4_1', 0.4, engine='cudnn'))
+    net.add(layer.Dropout('drop4_1', 0.4, engine='cuda'))
     ConvBnReLU(net, 'conv4_2', 512)
-    net.add(layer.Dropout('drop4_2', 0.4, engine='cudnn'))
+    net.add(layer.Dropout('drop4_2', 0.4, engine='cuda'))
     ConvBnReLU(net, 'conv4_3', 512)
     net.add(layer.MaxPooling2D('pool4', 2, 2, border_mode='valid'))
     ConvBnReLU(net, 'conv5_1', 512)
-    net.add(layer.Dropout('drop5_1', 0.4, engine='cudnn'))
+    net.add(layer.Dropout('drop5_1', 0.4, engine='cuda'))
     ConvBnReLU(net, 'conv5_2', 512)
-    net.add(layer.Dropout('drop5_2', 0.4, engine='cudnn'))
+    net.add(layer.Dropout('drop5_2', 0.4, engine='cuda'))
     ConvBnReLU(net, 'conv5_3', 512)
     net.add(layer.MaxPooling2D('pool5', 2, 2, border_mode='valid'))
     net.add(layer.Flatten('flat'))
-    net.add(layer.Dropout('drop_flat', 0.5, engine='cudnn'))
+    net.add(layer.Dropout('drop_flat', 0.5, engine='cuda'))
     net.add(layer.Dense('ip1', 512))
     net.add(layer.BatchNormalization('batchnorm_ip1'))
     net.add(layer.Activation('relu_ip1'))
-    net.add(layer.Dropout('drop_ip2', 0.5, engine='cudnn'))
+    net.add(layer.Dropout('drop_ip2', 0.5, engine='cuda'))
     net.add(layer.Dense('ip2', 10))
+    print 'Start intialization............'
+    for (p, name) in zip(net.param_values(), net.param_names()):
+        print name, p.shape
+        if len(p.shape) > 1:
+            if 'mean' in name or 'beta' in name:
+                p.set_value(0.0)
+            elif 'var' in name:
+                p.set_value(1.0)
+            elif 'gamma' in name:
+                initializer.uniform(p, 0, 1)
+            elif 'conv' in name:
+                initializer.gaussian(p, 0, math.sqrt(2.0/(9.0 * p.shape[0])))
+            else:
+                initializer.gaussian(p, 0, 0.02)
+        else:
+            p.set_value(0)
+        print name, p.l1()
+
     return net