You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@singa.apache.org by wa...@apache.org on 2016/03/30 09:34:39 UTC

[1/3] incubator-singa git commit: SINGA-148 Race condition between Worker threads and Driver

Repository: incubator-singa
Updated Branches:
  refs/heads/master a6eea9c4a -> b00dc32fb


SINGA-148 Race condition between Worker threads and Driver

The worker may query its device ID before the driver has set it up (via
Context).  This is fixed by having the worker sleep until the driver finishes
the setup.  All devices (GPU and CPU) must now be set up via
Context::SetupDevice; otherwise the worker would sleep forever.
The Blob math functions now check device_id < 0 to decide whether to call the
CPU or the GPU implementation.


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/96794175
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/96794175
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/96794175

Branch: refs/heads/master
Commit: 9679417509baf62e0c565e0cec140844b778d827
Parents: 259422c
Author: Wei Wang <wa...@comp.nus.edu.sg>
Authored: Tue Mar 29 11:49:01 2016 +0800
Committer: Wei Wang <wa...@comp.nus.edu.sg>
Committed: Tue Mar 29 11:49:01 2016 +0800

----------------------------------------------------------------------
 include/singa/utils/context.h   |  2 +-
 include/singa/utils/math_blob.h | 44 ++++++++++++++++++------------------
 src/worker.cc                   |  7 +++++-
 3 files changed, 29 insertions(+), 24 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/96794175/include/singa/utils/context.h
----------------------------------------------------------------------
diff --git a/include/singa/utils/context.h b/include/singa/utils/context.h
index b1128c1..3490d29 100644
--- a/include/singa/utils/context.h
+++ b/include/singa/utils/context.h
@@ -113,7 +113,7 @@ class Context {
     if (device_id_.find(tid) != device_id_.end())
       return device_id_[tid];
     else
-      return -1;
+      return -2;
   }
   /**
    * Setup the CPU thread, which may be assigned a GPU device.

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/96794175/include/singa/utils/math_blob.h
----------------------------------------------------------------------
diff --git a/include/singa/utils/math_blob.h b/include/singa/utils/math_blob.h
index 778824e..93967b4 100644
--- a/include/singa/utils/math_blob.h
+++ b/include/singa/utils/math_blob.h
@@ -46,7 +46,7 @@ template<typename Dtype>
 void Scale(Dtype alpha, Blob<Dtype> * B) {
   auto context = Singleton<Context>::Instance();
   int device = context->device_id(std::this_thread::get_id());
-  if (device == -1) {
+  if (device < 0) {
     cpu_scale(B->count(), alpha, B->mutable_cpu_data());
   } else {
 #ifdef USE_GPU
@@ -64,7 +64,7 @@ void AXPY(Dtype alpha, const Blob<Dtype> & A, Blob<Dtype> * B) {
   CHECK_EQ(A.count(), B->count());
   auto context = Singleton<Context>::Instance();
   int device = context->device_id(std::this_thread::get_id());
-  if (device == -1) {
+  if (device < 0) {
     cpu_axpy(A.count(), alpha, A.cpu_data(), B->mutable_cpu_data());
   } else {
 #ifdef USE_GPU
@@ -104,7 +104,7 @@ void GEMV(Dtype alpha, Dtype beta, const Blob<Dtype>& A,
   bool TranA = A.transpose();
   auto context = Singleton<Context>::Instance();
   int device = context->device_id(std::this_thread::get_id());
-  if (device == -1) {
+  if (device < 0) {
     cpu_gemv(A.cpu_data(), B.cpu_data(), m, n, alpha, beta, TranA,
         C->mutable_cpu_data());
   } else {
@@ -169,7 +169,7 @@ void GEMM(Dtype alpha, Dtype beta, const Blob<Dtype>& A, const Blob<Dtype>& B,
   bool TranB = B.transpose();
   auto context = Singleton<Context>::Instance();
   int device = context->device_id(std::this_thread::get_id());
-  if (device == -1) {
+  if (device < 0) {
     cpu_gemm(A.cpu_data(), B.cpu_data(), m, n, k, alpha, beta, TranA, TranB,
         C->mutable_cpu_data());
   } else {
@@ -212,7 +212,7 @@ Dtype VVDot(const Blob<Dtype> & A, const Blob<Dtype> & B) {
   int n = A.count();
   auto context = Singleton<Context>::Instance();
   int device = context->device_id(std::this_thread::get_id());
-  if (device == -1) {
+  if (device < 0) {
     res = cpu_dot(A.cpu_data(), B.cpu_data(), n);
   } else {
 #ifdef USE_GPU
@@ -241,7 +241,7 @@ void OuterProduct(const Blob<Dtype>& A, const Blob<Dtype>& B, Blob<Dtype> * C) {
   CHECK_EQ(C->count(), m * n);
   auto context = Singleton<Context>::Instance();
   int device = context->device_id(std::this_thread::get_id());
-  if (device == -1) {
+  if (device < 0) {
     cpu_gemm(A.cpu_data(), B.cpu_data(), m, n, 1, 1, 0, false, false,
         C->mutable_cpu_data());
   } else {
@@ -262,7 +262,7 @@ void Map(const Blob<Dtype> & A, Blob<Dtype> * B) {
   CHECK_EQ(A.count(), B->count()) << "Blobs must have the same size";
   auto context = Singleton<Context>::Instance();
   int device = context->device_id(std::this_thread::get_id());
-  if (device == -1) {
+  if (device < 0) {
     cpu_e_f<Op>(A.count(), A.cpu_data(), B->mutable_cpu_data());
   } else {
 #ifdef USE_GPU
@@ -285,7 +285,7 @@ void Map(const Blob<Dtype> & A, const Blob<Dtype> & B, Blob<Dtype> * C) {
   //cpu_e_f<Op>(A.count(), A.cpu_data(), B.cpu_data(), C->mutable_cpu_data());
   auto context = Singleton<Context>::Instance();
   int device = context->device_id(std::this_thread::get_id());
-  if (device == -1) {
+  if (device < 0) {
     cpu_e_f<Op>(A.count(), A.cpu_data(), B.cpu_data(), C->mutable_cpu_data());
   } else {
 #ifdef USE_GPU
@@ -304,7 +304,7 @@ void Map(Dtype alpha, const Blob<Dtype>& A, Blob<Dtype>* B) {
   CHECK_EQ(A.count(), B->count()) << "Blobs must have the same size";
   auto context = Singleton<Context>::Instance();
   int device = context->device_id(std::this_thread::get_id());
-  if (device == -1) {
+  if (device < 0) {
     cpu_e_f<Op>(A.count(), alpha, A.cpu_data(), B->mutable_cpu_data());
   } else {
 #ifdef USE_GPU
@@ -324,7 +324,7 @@ void Map(Dtype alpha, const Blob<Dtype>& A, const Blob<Dtype>& B,
   CHECK_EQ(A.count(), B->count()) << "Blobs must have the same size";
   auto context = Singleton<Context>::Instance();
   int device = context->device_id(std::this_thread::get_id());
-  if (device == -1) {
+  if (device < 0) {
     cpu_e_f<Op>(A.count(), alpha, A.cpu_data(), B->cpu_data(),
         C->mutable_cpu_data());
   } else {
@@ -346,7 +346,7 @@ void Copy(const Blob<Dtype>& A, Blob<Dtype>* B) {
   CHECK_EQ(A.count(), B->count()) << "Blobs must have the same size";
   auto context = Singleton<Context>::Instance();
   int device = context->device_id(std::this_thread::get_id());
-  if (device == -1) {
+  if (device < 0) {
     std::copy(A.cpu_data(), A.cpu_data() + A.count(), B->mutable_cpu_data());
   } else {
 #ifdef USE_GPU
@@ -474,7 +474,7 @@ void MVAddCol(Dtype alpha, Dtype beta, const Blob<Dtype> & A, Blob<Dtype> * B) {
     one.SetValue(1);
     auto context = Singleton<Context>::Instance();
     int device = context->device_id(std::this_thread::get_id());
-    if (device == -1) {
+    if (device < 0) {
       cpu_gemm(A.cpu_data(), one.cpu_data(), m, n, 1, alpha, beta, false, false,
           B->mutable_cpu_data());
     } else {
@@ -511,7 +511,7 @@ void MVAddRow(Dtype alpha, Dtype beta, const Blob<Dtype> & A, Blob<Dtype> * B) {
     int n = A.count(), m = B->count() / n;
     auto context = Singleton<Context>::Instance();
     int device = context->device_id(std::this_thread::get_id());
-    if (device == -1) {
+    if (device < 0) {
       Blob<Dtype> one(m);
       one.SetValue(1);
       cpu_gemm(one.cpu_data(), A.cpu_data(), m, n, 1, alpha, beta,
@@ -566,7 +566,7 @@ void MVSumCol(Dtype alpha, Dtype beta, const Blob<Dtype> & A, Blob<Dtype> * B) {
   int m = B->count(), n = A.count() / m;
   auto context = Singleton<Context>::Instance();
   int device = context->device_id(std::this_thread::get_id());
-  if (device == -1) {
+  if (device < 0) {
     Blob<Dtype> one(n);
     one.SetValue(1);
     cpu_gemm(A.cpu_data(), one.cpu_data(), m, 1, n, alpha, beta,
@@ -591,7 +591,7 @@ void MVSumRow(Dtype alpha, Dtype beta, const Blob<Dtype> & A, Blob<Dtype> * B) {
   int n = B->count(), m = A.count() / n;
   auto context = Singleton<Context>::Instance();
   int device = context->device_id(std::this_thread::get_id());
-  if (device == -1) {
+  if (device < 0) {
     Blob<Dtype> one(m);
     one.SetValue(1);
     cpu_gemm(one.cpu_data(), A.cpu_data(), 1, n, m, alpha, beta, false,
@@ -615,7 +615,7 @@ void Reduce2D(const Blob<Dtype> & A, Blob<Dtype> * B) {
   int m = B->count(), n = A.count() / m;
   auto context = Singleton<Context>::Instance();
   int device = context->device_id(std::this_thread::get_id());
-  if (device == -1) {
+  if (device < 0) {
     cpu_reduce_f<Op>(A.cpu_data(), m, n, B->mutable_cpu_data());
   } else {
 #ifdef USE_GPU
@@ -635,7 +635,7 @@ void Expand2D(const Blob<Dtype> & A, Blob<Dtype> * B) {
   int m = A.count(), n = B->count() / m;
   auto context = Singleton<Context>::Instance();
   int device = context->device_id(std::this_thread::get_id());
-  if (device == -1) {
+  if (device < 0) {
     cpu_expand_f<Op>(A.cpu_data(), m, n, B->mutable_cpu_data());
   } else {
 #ifdef USE_GPU
@@ -653,7 +653,7 @@ Dtype Asum(const Blob<Dtype>& A) {
   auto context = Singleton<Context>::Instance();
   int device = context->device_id(std::this_thread::get_id());
   Dtype ret = Dtype(0);
-  if (device == -1) {
+  if (device < 0) {
     ret = cpu_asum(A.count(), A.cpu_data(), 1) / A.count();
   } else {
 #ifdef USE_GPU
@@ -671,7 +671,7 @@ void SampleUniform(Dtype low, Dtype high, Blob<Dtype>* A) {
   auto context = Singleton<Context>::Instance();
   const auto& thread = std::this_thread::get_id();
   int device = context->device_id(thread);
-  if (device == -1) {
+  if (device < 0) {
     cpu_sample_uniform(*context->rand_generator(thread), A->count(), low, high,
         A->mutable_cpu_data());
   } else {
@@ -689,7 +689,7 @@ void SampleGaussian(Dtype mean, Dtype std, Blob<Dtype>* A) {
   auto context = Singleton<Context>::Instance();
   const auto& thread = std::this_thread::get_id();
   int device = context->device_id(thread);
-  if (device == -1) {
+  if (device < 0) {
     cpu_sample_gaussian(*context->rand_generator(thread), A->count(), mean, std,
         A->mutable_cpu_data());
   } else {
@@ -708,7 +708,7 @@ void Softmax(int nb_rows, const Blob<Dtype>& A, Blob<Dtype>* B) {
   CHECK_EQ(A.count(), B->count());
   auto context = Singleton<Context>::Instance();
   int device = context->device_id(std::this_thread::get_id());
-  if (device == -1) {
+  if (device < 0) {
     cpu_softmax(nb_rows, A.count() / nb_rows, A.cpu_data(),
       B->mutable_cpu_data());
   } else {
@@ -721,7 +721,7 @@ template<typename Dtype>
 void Zero(Blob<Dtype>* B) {
   auto context = Singleton<Context>::Instance();
   int device = context->device_id(std::this_thread::get_id());
-  if (device == -1) {
+  if (device < 0) {
     B->SetValue(0);
   } else {
 #ifdef USE_GPU

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/96794175/src/worker.cc
----------------------------------------------------------------------
diff --git a/src/worker.cc b/src/worker.cc
index 2afa8b0..1e35ff9 100644
--- a/src/worker.cc
+++ b/src/worker.cc
@@ -64,7 +64,12 @@ Worker::~Worker() {
 void Worker::Run() {
   // setup gpu device
   auto context = Singleton<Context>::Instance();
-  int device = context->device_id(std::this_thread::get_id());
+  // TODO(wangwei) -2 for uninitial device; -1 for CPU; >=0 for GPU now.
+  int device = -2;
+  while (device == -2) {
+    device = context->device_id(std::this_thread::get_id());
+    std::this_thread::sleep_for(std::chrono::milliseconds(1000));
+  }
   LOG(ERROR) << "Worker (group = " << grp_id_ <<", id = " << id_ << ") "
     << " start on " << (device >= 0 ? "GPU " + std::to_string(device) : "CPU");
   if (device >= 0)


[3/3] incubator-singa git commit: SINGA-148 Race condition between Worker threads and Driver

Posted by wa...@apache.org.
SINGA-148 Race condition between Worker threads and Driver

Minor bug fixes in slice.cc and concate.cc.
Pass the cpplint check.


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/b00dc32f
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/b00dc32f
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/b00dc32f

Branch: refs/heads/master
Commit: b00dc32fb1562d3b71bc346021d1918059d5e6cc
Parents: 77d5c5c
Author: WANG Sheng <wa...@gmail.com>
Authored: Wed Mar 30 15:22:41 2016 +0800
Committer: WANG Sheng <wa...@gmail.com>
Committed: Wed Mar 30 15:22:41 2016 +0800

----------------------------------------------------------------------
 include/singa/utils/math_blob.h           | 2 +-
 src/neuralnet/connection_layer/concate.cc | 8 ++++----
 src/neuralnet/connection_layer/slice.cc   | 8 ++++----
 3 files changed, 9 insertions(+), 9 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b00dc32f/include/singa/utils/math_blob.h
----------------------------------------------------------------------
diff --git a/include/singa/utils/math_blob.h b/include/singa/utils/math_blob.h
index 93967b4..55ba44b 100644
--- a/include/singa/utils/math_blob.h
+++ b/include/singa/utils/math_blob.h
@@ -282,7 +282,7 @@ template<typename Op, typename Dtype>
 void Map(const Blob<Dtype> & A, const Blob<Dtype> & B, Blob<Dtype> * C) {
   CHECK_EQ(A.count(), B.count()) << "Blobs must have the same size";
   CHECK_EQ(A.count(), C->count()) << "Blobs must have the same size";
-  //cpu_e_f<Op>(A.count(), A.cpu_data(), B.cpu_data(), C->mutable_cpu_data());
+  // cpu_e_f<Op>(A.count(), A.cpu_data(), B.cpu_data(), C->mutable_cpu_data());
   auto context = Singleton<Context>::Instance();
   int device = context->device_id(std::this_thread::get_id());
   if (device < 0) {

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b00dc32f/src/neuralnet/connection_layer/concate.cc
----------------------------------------------------------------------
diff --git a/src/neuralnet/connection_layer/concate.cc b/src/neuralnet/connection_layer/concate.cc
index 0cdd812..9d3fd0c 100644
--- a/src/neuralnet/connection_layer/concate.cc
+++ b/src/neuralnet/connection_layer/concate.cc
@@ -60,7 +60,7 @@ void ConcateLayer::ComputeFeature(int flag, const vector<Layer*>& srclayers) {
   int device = context->device_id(std::this_thread::get_id());
   while (concate_offset < data_.count()) {
     for (size_t i = 0; i < srclayers.size(); ++i) {
-      if (device == -1) {
+      if (device < 0) {
         const float* src = srclayers[i]->data(this).cpu_data()
           + srclayer_offset;
         float* dst = data_.mutable_cpu_data() + concate_offset;
@@ -72,7 +72,7 @@ void ConcateLayer::ComputeFeature(int flag, const vector<Layer*>& srclayers) {
         float* dst = data_.mutable_gpu_data() + concate_offset;
         cudaMemcpy(dst, src, step * sizeof(float), cudaMemcpyDefault);
 #else
-        LOG(FATAL) << "GPU is supported";
+        LOG(FATAL) << "GPU is not supported";
 #endif
       }
       concate_offset += step;
@@ -94,7 +94,7 @@ void ConcateLayer::ComputeGradient(int flag, const vector<Layer*>& srclayers) {
   int device = context->device_id(std::this_thread::get_id());
   while (concate_offset < grad_.count()) {
     for (size_t i = 0; i < srclayers.size(); ++i) {
-      if (device == -1) {
+      if (device < 0) {
         const float* src = grad_.cpu_data() + concate_offset;
         float* dst = srclayers[i]->mutable_grad(this)->mutable_cpu_data()
           + srclayer_offset;
@@ -106,7 +106,7 @@ void ConcateLayer::ComputeGradient(int flag, const vector<Layer*>& srclayers) {
           + srclayer_offset;
         cudaMemcpy(dst, src, step * sizeof(float), cudaMemcpyDefault);
 #else
-        LOG(FATAL) << "GPU is supported";
+        LOG(FATAL) << "GPU is not supported";
 #endif
       }
       concate_offset += step;

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b00dc32f/src/neuralnet/connection_layer/slice.cc
----------------------------------------------------------------------
diff --git a/src/neuralnet/connection_layer/slice.cc b/src/neuralnet/connection_layer/slice.cc
index efa33a4..3cca3fd 100644
--- a/src/neuralnet/connection_layer/slice.cc
+++ b/src/neuralnet/connection_layer/slice.cc
@@ -73,7 +73,7 @@ void SliceLayer::ComputeFeature(int flag, const vector<Layer*>& srclayers) {
   int device = context->device_id(std::this_thread::get_id());
   while (srclayer_offset < blob.count()) {
     for (int i = 0; i < num_slices_; ++i) {
-      if (device == -1) {
+      if (device < 0) {
         const float* src = blob.cpu_data() + srclayer_offset;
         float* dst = datavec_[i]->mutable_cpu_data() + slice_offset;
         memcpy(dst, src, step * sizeof(float));
@@ -83,7 +83,7 @@ void SliceLayer::ComputeFeature(int flag, const vector<Layer*>& srclayers) {
         float* dst = datavec_[i]->mutable_gpu_data() + slice_offset;
         cudaMemcpy(dst, src, step * sizeof(float), cudaMemcpyDefault);
 #else
-        LOG(FATAL) << "GPU is supported";
+        LOG(FATAL) << "GPU is not supported";
 #endif
       }
       srclayer_offset += step;
@@ -105,7 +105,7 @@ void SliceLayer::ComputeGradient(int flag, const vector<Layer*>& srclayers) {
   int device = context->device_id(std::this_thread::get_id());
   while (srclayer_offset < blob->count()) {
     for (int i = 0; i < num_slices_; ++i) {
-      if (device == -1) {
+      if (device < 0) {
         const float* src = gradvec_[i]->cpu_data() + slice_offset;
         float* dst = blob->mutable_cpu_data() + srclayer_offset;
         memcpy(dst, src, step * sizeof(float));
@@ -115,7 +115,7 @@ void SliceLayer::ComputeGradient(int flag, const vector<Layer*>& srclayers) {
         float* dst = blob->mutable_gpu_data() + srclayer_offset;
         cudaMemcpy(dst, src, step * sizeof(float), cudaMemcpyDefault);
 #else
-        LOG(FATAL) << "GPU is supported";
+        LOG(FATAL) << "GPU is not supported";
 #endif
       }
       srclayer_offset += step;


[2/3] incubator-singa git commit: SINGA-148 Race condition between Worker threads and Driver

Posted by wa...@apache.org.
SINGA-148 Race condition between Worker threads and Driver

Merge PR-140 into master.


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/77d5c5ce
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/77d5c5ce
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/77d5c5ce

Branch: refs/heads/master
Commit: 77d5c5ce357b55144fbe25998fad333a0010d40b
Parents: 9679417 a6eea9c
Author: WANG Sheng <wa...@gmail.com>
Authored: Wed Mar 30 14:37:35 2016 +0800
Committer: WANG Sheng <wa...@gmail.com>
Committed: Wed Mar 30 14:37:35 2016 +0800

----------------------------------------------------------------------
 tool/docker/mesos/Dockerfile     |  9 ++---
 tool/docker/singa/Dockerfile     |  9 ++---
 tool/docker/singa/Dockerfile_gpu | 62 +++++++++++++++++++++++++++++++++++
 tool/mesos/singa_scheduler.cc    |  1 -
 4 files changed, 72 insertions(+), 9 deletions(-)
----------------------------------------------------------------------