You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@singa.apache.org by wa...@apache.org on 2015/06/15 06:22:14 UTC

[2/3] incubator-singa git commit: SINGA-6 Update the implementation of Singleton to make it thread-safe (using C++11's static construction). The ASingleton (constructor with Argument) was used for sharing the Mshadow::Random among Layers and Params. To m

SINGA-6
Update the implementation of Singleton to make it thread-safe (relying on C++11's thread-safe initialization of function-local statics).
ASingleton (a singleton whose constructor takes an argument) was used to share the Mshadow::Random instance among Layers and Params.
To make it thread-safe:
1. We changed it to TSingleton (thread-specific singleton).
2. We added a constructor without arguments to Mshadow::Random that seeds the generator from the system clock.
3. We replaced rand/srand in Mshadow::Random<cpu> with C++11's random number facilities from the <random> header.
Now each thread that calls TSingleton<Random<cpu>>::Instance() to generate random numbers gets its own Mshadow::Random object, so random number generation is thread-safe.


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/3ba71553
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/3ba71553
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/3ba71553

Branch: refs/heads/master
Commit: 3ba71553685b2dd1aeebda948a36dd087a096e97
Parents: f13e3a7
Author: wang wei <wa...@comp.nus.edu.sg>
Authored: Sun Jun 14 15:58:39 2015 +0800
Committer: wang sheng <wa...@gmail.com>
Committed: Mon Jun 15 12:20:49 2015 +0800

----------------------------------------------------------------------
 examples/cifar10/cluster.conf         |   6 -
 examples/cifar10/cluster.conf.example |   6 +
 examples/cifar10/model.conf           | 218 -----------------------------
 examples/cifar10/model.conf.example   | 218 +++++++++++++++++++++++++++++
 include/mshadow/tensor_random.h       |  37 ++++-
 include/utils/singleton.h             |  34 ++---
 src/neuralnet/layer.cc                |   4 +-
 src/utils/param.cc                    |   3 +-
 8 files changed, 271 insertions(+), 255 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/3ba71553/examples/cifar10/cluster.conf
----------------------------------------------------------------------
diff --git a/examples/cifar10/cluster.conf b/examples/cifar10/cluster.conf
deleted file mode 100644
index 97c64fd..0000000
--- a/examples/cifar10/cluster.conf
+++ /dev/null
@@ -1,6 +0,0 @@
-nworker_groups: 1
-nserver_groups: 1
-nservers_per_group: 1
-nworkers_per_group: 1
-nworkers_per_procs: 1
-workspace: "examples/cifar10/"

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/3ba71553/examples/cifar10/cluster.conf.example
----------------------------------------------------------------------
diff --git a/examples/cifar10/cluster.conf.example b/examples/cifar10/cluster.conf.example
new file mode 100644
index 0000000..97c64fd
--- /dev/null
+++ b/examples/cifar10/cluster.conf.example
@@ -0,0 +1,6 @@
+nworker_groups: 1
+nserver_groups: 1
+nservers_per_group: 1
+nworkers_per_group: 1
+nworkers_per_procs: 1
+workspace: "examples/cifar10/"

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/3ba71553/examples/cifar10/model.conf
----------------------------------------------------------------------
diff --git a/examples/cifar10/model.conf b/examples/cifar10/model.conf
deleted file mode 100644
index 72ebf8e..0000000
--- a/examples/cifar10/model.conf
+++ /dev/null
@@ -1,218 +0,0 @@
-name: "cifar10-convnet"
-train_steps: 70000
-test_steps:100
-test_frequency:1000
-display_frequency:30
-updater{
-  momentum:0.9
-  weight_decay:0.004
-  learning_rate_change_method:kFixedStep
-  step:0
-  step:60000
-  step:65000
-  step_lr:0.001
-  step_lr:0.0001
-  step_lr:0.00001
-}
-neuralnet {
-partition_type: kDataPartition
-layer{
-  name: "data"
-  type: "kShardData"
-  data_param {
-    path: "examples/cifar10/cifar10_train_shard"
-    batchsize: 128
-  }
-  exclude: kTest
-}
-layer{
-  name: "data"
-  type: "kShardData"
-  data_param {
-    path: "examples/cifar10/cifar10_test_shard"
-    batchsize: 128
-  }
-  exclude: kTrain
-}
-layer{
-  name:"rgb"
-  type: "kRGBImage"
-  srclayers: "data"
-  rgbimage_param {
-    meanfile: "examples/cifar10/image_mean.bin"
-  }
-}
-layer{
-  name: "label"
-  type: "kLabel"
-  srclayers: "data"
-}
-
-layer {
-  name: "conv1"
-  type: "kConvolution"
-  srclayers: "rgb"
-  convolution_param {
-    num_filters: 32
-    kernel: 5
-    stride: 1
-    pad:2
-  }
-  param{
-      name: "weight"
-      init_method:kGaussian
-      std:0.0001
-      learning_rate_multiplier:1.0
-    }
-  param{
-      name: "bias"
-      init_method: kConstant
-      learning_rate_multiplier:2.0
-      value:0
-    }
-}
-
-layer {
-  name: "pool1"
-  type: "kPooling"
-  srclayers: "conv1"
-  pooling_param {
-    pool: MAX
-    kernel: 3
-    stride: 2
-  }
-}
-layer {
-  name: "relu1"
-  type: "kReLU"
-  srclayers:"pool1"
-}
-layer {
-  name: "norm1"
-  type: "kLRN"
-  lrn_param {
-    norm_region: WITHIN_CHANNEL
-    local_size: 3
-    alpha: 5e-05
-    beta: 0.75
-  }
-  srclayers:"relu1"
-}
-layer {
-  name: "conv2"
-  type: "kConvolution"
-  srclayers: "norm1"
-  convolution_param {
-    num_filters: 32
-    kernel: 5
-    stride: 1
-    pad:2
-  }
-  param{
-      name: "weight"
-      init_method:kGaussian
-      std:0.01
-      learning_rate_multiplier:1.0
-    }
-  param{
-      name: "bias"
-      init_method: kConstant
-      learning_rate_multiplier:2.0
-      value:0
-    }
-}
-layer {
-  name: "relu2"
-  type: "kReLU"
-  srclayers:"conv2"
-}
-layer {
-  name: "pool2"
-  type: "kPooling"
-  srclayers: "relu2"
-  pooling_param {
-    pool: MAX
-    kernel: 3
-    stride: 2
-  }
-}
-layer {
-  name: "norm2"
-  type: "kLRN"
-  lrn_param {
-    norm_region: WITHIN_CHANNEL
-    local_size: 3
-    alpha: 5e-05
-    beta: 0.75
-  }
-  srclayers:"pool2"
-}
-layer {
-  name: "conv3"
-  type: "kConvolution"
-  srclayers: "norm2"
-  convolution_param {
-    num_filters: 64
-    kernel: 5
-    stride: 1
-    pad:2
-  }
-  param{
-      name: "weight"
-      init_method:kGaussian
-      std:0.01
-    }
-  param{
-      name: "bias"
-      init_method: kConstant
-      value:0
-    }
-}
-layer {
-  name: "relu3"
-  type: "kReLU"
-  srclayers:"conv3"
-}
-layer {
-  name: "pool3"
-  type: "kPooling"
-  srclayers: "relu3"
-  pooling_param {
-    pool: AVE
-    kernel: 3
-    stride: 2
-  }
-}
-layer {
-  name: "ip1"
-  type: "kInnerProduct"
-  srclayers:"pool3"
-  inner_product_param {
-    num_output: 10
-  }
-  param{
-      name: "weight"
-      init_method:kGaussian
-      std:0.01
-      learning_rate_multiplier:1.0
-      weight_decay_multiplier:250
-    }
-  param{
-      name: "bias"
-      init_method: kConstant
-      learning_rate_multiplier:2.0
-      weight_decay_multiplier:0
-      value:0
-  }
-}
-
-layer{
-  name: "loss"
-  type:"kSoftmaxLoss"
-  softmaxloss_param{
-    topk:1
-  }
-  srclayers:"ip1"
-  srclayers: "label"
-}
-}

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/3ba71553/examples/cifar10/model.conf.example
----------------------------------------------------------------------
diff --git a/examples/cifar10/model.conf.example b/examples/cifar10/model.conf.example
new file mode 100644
index 0000000..3f1e2a2
--- /dev/null
+++ b/examples/cifar10/model.conf.example
@@ -0,0 +1,218 @@
+name: "cifar10-convnet"
+train_steps: 1000
+test_steps:100
+test_frequency:300
+display_frequency:30
+updater{
+  momentum:0.9
+  weight_decay:0.004
+  learning_rate_change_method:kFixedStep
+  step:0
+  step:60000
+  step:65000
+  step_lr:0.001
+  step_lr:0.0001
+  step_lr:0.00001
+}
+neuralnet {
+partition_type: kDataPartition
+layer{
+  name: "data"
+  type: "kShardData"
+  data_param {
+    path: "examples/cifar10/cifar10_train_shard"
+    batchsize: 16
+  }
+  exclude: kTest
+}
+layer{
+  name: "data"
+  type: "kShardData"
+  data_param {
+    path: "examples/cifar10/cifar10_test_shard"
+    batchsize: 128
+  }
+  exclude: kTrain
+}
+layer{
+  name:"rgb"
+  type: "kRGBImage"
+  srclayers: "data"
+  rgbimage_param {
+    meanfile: "examples/cifar10/image_mean.bin"
+  }
+}
+layer{
+  name: "label"
+  type: "kLabel"
+  srclayers: "data"
+}
+
+layer {
+  name: "conv1"
+  type: "kConvolution"
+  srclayers: "rgb"
+  convolution_param {
+    num_filters: 32
+    kernel: 5
+    stride: 1
+    pad:2
+  }
+  param{
+      name: "weight"
+      init_method:kGaussian
+      std:0.0001
+      learning_rate_multiplier:1.0
+    }
+  param{
+      name: "bias"
+      init_method: kConstant
+      learning_rate_multiplier:2.0
+      value:0
+    }
+}
+
+layer {
+  name: "pool1"
+  type: "kPooling"
+  srclayers: "conv1"
+  pooling_param {
+    pool: MAX
+    kernel: 3
+    stride: 2
+  }
+}
+layer {
+  name: "relu1"
+  type: "kReLU"
+  srclayers:"pool1"
+}
+layer {
+  name: "norm1"
+  type: "kLRN"
+  lrn_param {
+    norm_region: WITHIN_CHANNEL
+    local_size: 3
+    alpha: 5e-05
+    beta: 0.75
+  }
+  srclayers:"relu1"
+}
+layer {
+  name: "conv2"
+  type: "kConvolution"
+  srclayers: "norm1"
+  convolution_param {
+    num_filters: 32
+    kernel: 5
+    stride: 1
+    pad:2
+  }
+  param{
+      name: "weight"
+      init_method:kGaussian
+      std:0.01
+      learning_rate_multiplier:1.0
+    }
+  param{
+      name: "bias"
+      init_method: kConstant
+      learning_rate_multiplier:2.0
+      value:0
+    }
+}
+layer {
+  name: "relu2"
+  type: "kReLU"
+  srclayers:"conv2"
+}
+layer {
+  name: "pool2"
+  type: "kPooling"
+  srclayers: "relu2"
+  pooling_param {
+    pool: MAX
+    kernel: 3
+    stride: 2
+  }
+}
+layer {
+  name: "norm2"
+  type: "kLRN"
+  lrn_param {
+    norm_region: WITHIN_CHANNEL
+    local_size: 3
+    alpha: 5e-05
+    beta: 0.75
+  }
+  srclayers:"pool2"
+}
+layer {
+  name: "conv3"
+  type: "kConvolution"
+  srclayers: "norm2"
+  convolution_param {
+    num_filters: 64
+    kernel: 5
+    stride: 1
+    pad:2
+  }
+  param{
+      name: "weight"
+      init_method:kGaussian
+      std:0.01
+    }
+  param{
+      name: "bias"
+      init_method: kConstant
+      value:0
+    }
+}
+layer {
+  name: "relu3"
+  type: "kReLU"
+  srclayers:"conv3"
+}
+layer {
+  name: "pool3"
+  type: "kPooling"
+  srclayers: "relu3"
+  pooling_param {
+    pool: AVE
+    kernel: 3
+    stride: 2
+  }
+}
+layer {
+  name: "ip1"
+  type: "kInnerProduct"
+  srclayers:"pool3"
+  inner_product_param {
+    num_output: 10
+  }
+  param{
+      name: "weight"
+      init_method:kGaussian
+      std:0.01
+      learning_rate_multiplier:1.0
+      weight_decay_multiplier:250
+    }
+  param{
+      name: "bias"
+      init_method: kConstant
+      learning_rate_multiplier:2.0
+      weight_decay_multiplier:0
+      value:0
+  }
+}
+
+layer{
+  name: "loss"
+  type:"kSoftmaxLoss"
+  softmaxloss_param{
+    topk:1
+  }
+  srclayers:"ip1"
+  srclayers: "label"
+}
+}

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/3ba71553/include/mshadow/tensor_random.h
----------------------------------------------------------------------
diff --git a/include/mshadow/tensor_random.h b/include/mshadow/tensor_random.h
index b3f0b84..717d32c 100644
--- a/include/mshadow/tensor_random.h
+++ b/include/mshadow/tensor_random.h
@@ -7,13 +7,17 @@
  *   Based on curand|MKL|stdlib
  */
 #include <cstdlib>
+#include <random>
+#include <chrono>
 #include "tensor.h"
 #include "tensor_container.h"
 
 namespace mshadow {
-    /*! 
-     * \brief random number generator 
+    /*!
+     * \brief random number generator
      * \tparam Device the device of random number generator
+     *
+     * Note: replaced rand (srand) with c++11's random functions.
      */
     template<typename Device>
     class Random {};
@@ -23,6 +27,14 @@ namespace mshadow {
     class Random<cpu> {
     public:
         /*!
+         * \brief constructor of random engine using default seed
+         */
+        Random<cpu> (){
+          // obtain a seed from the system clock:
+          unsigned s= std::chrono::system_clock::now().time_since_epoch().count();
+          Seed(s);
+        }
+        /*!
          * \brief constructor of random engine
          * \param seed random number seed
          */
@@ -31,7 +43,8 @@ namespace mshadow {
             int status = vslNewStream(&vStream_, VSL_BRNG_MT19937, seed);
             utils::Assert( status == VSL_STATUS_OK, "MKL VSL Random engine failed to be initialized.\n" );
             #else
-            srand(seed);
+            //srand(seed);
+            gen_.seed(seed);
             #endif
             buffer_.Resize( Shape1( kRandBufferSize ) );
         }
@@ -51,7 +64,8 @@ namespace mshadow {
             status = vslNewStream(&vStream_, VSL_BRNG_MT19937, seed);
             utils::Assert(status == VSL_STATUS_OK);
             #else
-            srand( seed );
+            // srand( seed );
+            gen_.seed(seed);
             #endif
         }
         /*!
@@ -64,6 +78,7 @@ namespace mshadow {
         template<int dim>
         inline void SampleUniform( Tensor<cpu, dim> &dst, real_t a=0.0f, real_t b=1.0f ) {
             Tensor<cpu, 2> mat = dst.FlatTo2D();
+            std::uniform_real_distribution<real_t> distribution (a,b);
             for ( index_t i = 0; i < mat.shape[1]; ++i ) {
                 #if MSHADOW_USE_MKL
                 #if MSHADOW_SINGLE_PRECISION
@@ -74,9 +89,14 @@ namespace mshadow {
                 utils::Assert(status == VSL_STATUS_OK, "Failed to generate random number by MKL.\n" );
                 #else
                 // use stdlib
+                /*
                 for ( index_t j = 0; j < mat.shape[0]; ++j ) {
                     mat[i][j] = this->RandNext()*(b-a) + a;
                 }
+                */
+                for ( index_t j = 0; j < mat.shape[0]; ++j ) {
+                    mat[i][j] = distribution(gen_);
+                }
                 #endif
             }
         }
@@ -93,6 +113,7 @@ namespace mshadow {
                 dst = mu; return;
             }
             Tensor<cpu, 2> mat = dst.FlatTo2D();
+            std::normal_distribution<real_t> distribution (mu, sigma);
             for (index_t i = 0; i < mat.shape[1]; ++i) {
                 #if MSHADOW_USE_MKL
                 #if MSHADOW_SINGLE_PRECISION
@@ -102,6 +123,7 @@ namespace mshadow {
                 #endif
                 utils::Assert(status == VSL_STATUS_OK, "Failed to generate random number by MKL.\n" );
                 #else
+                /*
                 real_t g1 = 0.0f, g2 = 0.0f;
                 for (index_t j = 0; j < mat.shape[0]; ++j) {
                     if( (j & 1) == 0 ){
@@ -111,6 +133,10 @@ namespace mshadow {
                         mat[i][j] = mu + g2 * sigma;
                     }
                 }
+                */
+                for (index_t j = 0; j < mat.shape[0]; ++j) {
+                  mat[i][j] = distribution(gen_);
+                }
                 #endif
             }
         }
@@ -177,6 +203,9 @@ namespace mshadow {
         #endif
         /*! \brief temporal space used to store random numbers */
         TensorContainer<cpu,1> buffer_;
+
+        /*! \brief c++11 random generator, added for SINGA use */
+        std::mt19937 gen_;
     }; // class Random<cpu>
 
 #ifdef __CUDACC__

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/3ba71553/include/utils/singleton.h
----------------------------------------------------------------------
diff --git a/include/utils/singleton.h b/include/utils/singleton.h
index 2e2bdfb..3c2022b 100644
--- a/include/utils/singleton.h
+++ b/include/utils/singleton.h
@@ -1,41 +1,31 @@
 #ifndef INCLUDE_UTILS_SINGLETON_H_
 #define INCLUDE_UTILS_SINGLETON_H_
-
+/**
+  * thread-safe implementation for C++11 according to
+  * http://stackoverflow.com/questions/2576022/efficient-thread-safe-singleton-in-c
+  */
 template<typename T>
 class Singleton {
  public:
+
   static T* Instance() {
-    if (data_==nullptr) {
-      data_ = new T();
-    }
+    static T* data_=new T();
     return data_;
   }
- private:
-  static T* data_;
 };
 
-template<typename T> T* Singleton<T>::data_ = nullptr;
-
-
 /**
- * Singleton initiated with argument
+ * Thread Specific Singleton
+ *
+ * Each thread will have its own data_ storage.
  */
-template<typename T, typename X=int>
-class ASingleton {
+template<typename T>
+class TSingleton {
  public:
   static T* Instance(){
+    static thread_local T* data_=new T();
     return data_;
   }
-  static T* Instance(X x) {
-    if (data_==nullptr) {
-      data_ = new T(x);
-    }
-    return data_;
-  }
- private:
-  static T* data_;
 };
 
-template<typename T, typename X> T* ASingleton<T,X>::data_ = nullptr;
-
 #endif // INCLUDE_UTILS_SINGLETON_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/3ba71553/src/neuralnet/layer.cc
----------------------------------------------------------------------
diff --git a/src/neuralnet/layer.cc b/src/neuralnet/layer.cc
index 25cae42..a374511 100644
--- a/src/neuralnet/layer.cc
+++ b/src/neuralnet/layer.cc
@@ -129,8 +129,6 @@ void DropoutLayer::Setup(const LayerProto& proto,
   grad_.ReshapeLike(*srclayers[0]->mutable_grad(this));
   mask_.Reshape(srclayers[0]->data(this).shape());
   pdrop_=proto.dropout_param().dropout_ratio();
-  unsigned seed = std::chrono::system_clock::now().time_since_epoch().count();
-  ASingleton<Random<cpu>>::Instance(seed);
 }
 
 void DropoutLayer::SetupAfterPartition(const LayerProto& proto,
@@ -147,7 +145,7 @@ void DropoutLayer::ComputeFeature(bool training, const vector<SLayer>& srclayers
   }
   float pkeep=1-pdrop_;
   Tensor<cpu, 1> mask(mask_.mutable_cpu_data(), Shape1(mask_.count()));
-  mask = F<op::threshold>(ASingleton<Random<cpu>>::Instance()\
+  mask = F<op::threshold>(TSingleton<Random<cpu>>::Instance()\
       ->uniform(mask.shape), pkeep ) * (1.0f/pkeep);
   Tensor<cpu, 1> data(data_.mutable_cpu_data(), Shape1(data_.count()));
   Blob<float>* srcblob=srclayers[0]->mutable_data(this);

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/3ba71553/src/utils/param.cc
----------------------------------------------------------------------
diff --git a/src/utils/param.cc b/src/utils/param.cc
index ac5566c..89743e5 100644
--- a/src/utils/param.cc
+++ b/src/utils/param.cc
@@ -136,8 +136,7 @@ void Param::Setup(const ParamProto& proto, const vector<int>& shape,
 
 void Param::Init(int v){
   Tensor<cpu, 1> data(mutable_cpu_data(), Shape1(size()));
-  unsigned seed = std::chrono::system_clock::now().time_since_epoch().count();
-  auto random=ASingleton<Random<cpu>>::Instance(seed);
+  auto random=TSingleton<Random<cpu>>::Instance();
   switch (proto_.init_method()) {
   case ParamProto::kConstant:
     data=proto_.value();