You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@singa.apache.org by wa...@apache.org on 2015/05/17 08:20:00 UTC

[2/5] incubator-singa git commit: 1. Move functions in pm_server (pm_worker) into server (trainer) to simplify the logic. Now workers send simple messages to the stub thread, which constructs the real update/get/put requests. The stub thread also handles […]

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b5b943c7/src/utils/param.cc
----------------------------------------------------------------------
diff --git a/src/utils/param.cc b/src/utils/param.cc
index 3d46ee6..ac5566c 100644
--- a/src/utils/param.cc
+++ b/src/utils/param.cc
@@ -11,36 +11,37 @@ using std::string;
 namespace singa {
 
 Msg* Param::GenPutMsg(void* arg){
-  char buf[256];
-  int v=*(int*)arg;
-  sprintf(buf, "%d %d %f %f", v, size(),
+  char buf[128];
+  sprintf(buf, "%d %f %f", size(),
       learning_rate_multiplier(), weight_decay_multiplier());
   Msg* msg=new Msg();
   msg->set_type(kPut);
+  int v=version();
+  if(arg!=nullptr)
+    v=*(int*)arg;
+  msg->set_target(owner(), v);
   msg->add_frame(buf, strlen(buf));
   msg->add_frame(mutable_cpu_data(), size()*sizeof(float));
 	return msg;
 }
 
 Msg* Param::GenGetMsg(void* arg){
-  char buf[12];
-  int v=*(int*)arg;
-  sprintf(buf, "%d", v);
-  LOG(ERROR)<<"gen get version "<<v;
   Msg* msg=new Msg();
   msg->set_type(kGet);
-  msg->add_frame(buf, strlen(buf));
+  int v=version();
+  if(arg!=nullptr)
+    v=*(int*)arg;
+  msg->set_target(owner(), v);
   return msg;
 }
 
 Msg* Param::GenUpdateMsg(void* arg){
-  char buf[10];
-  int v=*(int*)arg;
-  sprintf(buf, "%d", v);
   Msg* msg=new Msg();
   msg->set_type(kUpdate);
-  msg->add_frame(buf, strlen(buf));
-
+  int v=version();
+  if(arg!=nullptr)
+    v=*(int*)arg;
+  msg->set_target(owner(), v);
   msg->add_frame(mutable_cpu_grad(), size()*sizeof(float));
   return msg;
 }
@@ -50,16 +51,16 @@ Msg* Param::GenSyncMsg(void* arg){
 }
 
 Msg* Param::HandlePutMsg(Msg** msg){
-  int v, size;
+  int size;
   float lr, wc;
-  sscanf(static_cast<char*>((*msg)->frame_data()), "%d %d %f %f",
-      &v, &size, &lr, &wc);
+  sscanf(static_cast<char*>((*msg)->frame_data()), "%d %f %f",
+      &size, &lr, &wc);
   proto_.set_learning_rate_multiplier(lr);
   proto_.set_weight_decay_multiplier(wc);
   CHECK((*msg)->next_frame());
   vector<int> shape{size};
   data_=std::make_shared<Blob<float>>(shape);
-  data_->set_version(v);
+  data_->set_version((*msg)->target_second());
   grad_.Reshape(shape);
   history_.Reshape(shape);
   CHECK_EQ(size* sizeof(float), (*msg)->frame_size());
@@ -70,21 +71,16 @@ Msg* Param::HandlePutMsg(Msg** msg){
 }
 
 Msg* Param::HandleGetMsg(Msg** msg){
-  int v;
-  sscanf(static_cast<char*>((*msg)->frame_data()), "%d", &v);
-  CHECK_LE(v, version());
-  CHECK(!(*msg)->next_frame());
-  (*msg)->add_frame(mutable_cpu_data(), sizeof(float)*size());
-  (*msg)->SwapAddr();
-  (*msg)->set_type(kRGet);
+  if((*msg)->target_second()<=version()){
+    (*msg)->add_frame(mutable_cpu_data(), sizeof(float)*size());
+    (*msg)->SwapAddr();
+    (*msg)->set_type(kRGet);
+  }
   return *msg;
 }
 
 int Param::ParseUpdateMsg(Msg** msg){
-  int v;
-  sscanf(static_cast<char*>((*msg)->frame_data()), "%d", &v);
-  CHECK_LE(v, version());
-  CHECK((*msg)->next_frame());
+  CHECK((*msg)->frame_size());
   memcpy(mutable_cpu_grad(), (*msg)->frame_data(),(*msg)->frame_size());
   delete (*msg);
   *msg=nullptr;
@@ -93,16 +89,15 @@ int Param::ParseUpdateMsg(Msg** msg){
 
 Msg* Param::GenUpdateResponseMsg(void* arg){
   Msg* msg=new Msg();
-  char buf[10];
-  sprintf(buf, "%d", version());
   msg->set_type(kRUpdate);
-  msg->set_target(id());
-  msg->add_frame(buf, strlen(buf));
+  int v=version();
+  if(arg!=nullptr)
+    v=*(int*)arg;
+  msg->set_target(owner(), v);
   msg->add_frame(mutable_cpu_data(), size()*sizeof(float));
   return msg;
 }
 
-
 Msg* Param::HandleSyncMsg(Msg** msg){
   delete *msg;
   *msg=nullptr;
@@ -118,12 +113,12 @@ int Param::ParsePutResponseMsg(Msg **msg){
   return ParseSyncResponseMsg(msg);
 }
 int Param::ParseGetResponseMsg(Msg **msg){
-  int v;
-  sscanf(static_cast<char*>((*msg)->frame_data()), "%d", &v);
-  CHECK((*msg)->next_frame());
+  CHECK((*msg)->frame_size());
   memcpy(mutable_cpu_data(), (*msg)->frame_data(), (*msg)->frame_size());
   // must be set after all other settings are done!
-  set_version(v);
+  set_version((*msg)->target_second());
+  delete *msg;
+  *msg=nullptr;
   return 1;
 }
 int Param::ParseUpdateResponseMsg(Msg **msg){
@@ -140,7 +135,6 @@ void Param::Setup(const ParamProto& proto, const vector<int>& shape,
 }
 
 void Param::Init(int v){
-  set_version(v);
   Tensor<cpu, 1> data(mutable_cpu_data(), Shape1(size()));
   unsigned seed = std::chrono::system_clock::now().time_since_epoch().count();
   auto random=ASingleton<Random<cpu>>::Instance(seed);
@@ -178,6 +172,7 @@ void Param::Init(int v){
     LOG(ERROR) << "Illegal parameter init method ";
     break;
   }
+  set_version(v);
 }
 
 /**************************RandomSyncParam********************************