You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@singa.apache.org by wa...@apache.org on 2015/05/17 08:20:00 UTC
[2/5] incubator-singa git commit: 1. Move functions in pm_server
(pm_worker) into server (trainer) to simplify the logic. Now workers send
simple messages to the stub thread, which constructs the real update/get/put
requests. The stub thread also handles […] (commit message truncated in this archive)
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b5b943c7/src/utils/param.cc
----------------------------------------------------------------------
diff --git a/src/utils/param.cc b/src/utils/param.cc
index 3d46ee6..ac5566c 100644
--- a/src/utils/param.cc
+++ b/src/utils/param.cc
@@ -11,36 +11,37 @@ using std::string;
namespace singa {
Msg* Param::GenPutMsg(void* arg){
- char buf[256];
- int v=*(int*)arg;
- sprintf(buf, "%d %d %f %f", v, size(),
+ char buf[128];
+ sprintf(buf, "%d %f %f", size(),
learning_rate_multiplier(), weight_decay_multiplier());
Msg* msg=new Msg();
msg->set_type(kPut);
+ int v=version();
+ if(arg!=nullptr)
+ v=*(int*)arg;
+ msg->set_target(owner(), v);
msg->add_frame(buf, strlen(buf));
msg->add_frame(mutable_cpu_data(), size()*sizeof(float));
return msg;
}
Msg* Param::GenGetMsg(void* arg){
- char buf[12];
- int v=*(int*)arg;
- sprintf(buf, "%d", v);
- LOG(ERROR)<<"gen get version "<<v;
Msg* msg=new Msg();
msg->set_type(kGet);
- msg->add_frame(buf, strlen(buf));
+ int v=version();
+ if(arg!=nullptr)
+ v=*(int*)arg;
+ msg->set_target(owner(), v);
return msg;
}
Msg* Param::GenUpdateMsg(void* arg){
- char buf[10];
- int v=*(int*)arg;
- sprintf(buf, "%d", v);
Msg* msg=new Msg();
msg->set_type(kUpdate);
- msg->add_frame(buf, strlen(buf));
-
+ int v=version();
+ if(arg!=nullptr)
+ v=*(int*)arg;
+ msg->set_target(owner(), v);
msg->add_frame(mutable_cpu_grad(), size()*sizeof(float));
return msg;
}
@@ -50,16 +51,16 @@ Msg* Param::GenSyncMsg(void* arg){
}
Msg* Param::HandlePutMsg(Msg** msg){
- int v, size;
+ int size;
float lr, wc;
- sscanf(static_cast<char*>((*msg)->frame_data()), "%d %d %f %f",
- &v, &size, &lr, &wc);
+ sscanf(static_cast<char*>((*msg)->frame_data()), "%d %f %f",
+ &size, &lr, &wc);
proto_.set_learning_rate_multiplier(lr);
proto_.set_weight_decay_multiplier(wc);
CHECK((*msg)->next_frame());
vector<int> shape{size};
data_=std::make_shared<Blob<float>>(shape);
- data_->set_version(v);
+ data_->set_version((*msg)->target_second());
grad_.Reshape(shape);
history_.Reshape(shape);
CHECK_EQ(size* sizeof(float), (*msg)->frame_size());
@@ -70,21 +71,16 @@ Msg* Param::HandlePutMsg(Msg** msg){
}
Msg* Param::HandleGetMsg(Msg** msg){
- int v;
- sscanf(static_cast<char*>((*msg)->frame_data()), "%d", &v);
- CHECK_LE(v, version());
- CHECK(!(*msg)->next_frame());
- (*msg)->add_frame(mutable_cpu_data(), sizeof(float)*size());
- (*msg)->SwapAddr();
- (*msg)->set_type(kRGet);
+ if((*msg)->target_second()<=version()){
+ (*msg)->add_frame(mutable_cpu_data(), sizeof(float)*size());
+ (*msg)->SwapAddr();
+ (*msg)->set_type(kRGet);
+ }
return *msg;
}
int Param::ParseUpdateMsg(Msg** msg){
- int v;
- sscanf(static_cast<char*>((*msg)->frame_data()), "%d", &v);
- CHECK_LE(v, version());
- CHECK((*msg)->next_frame());
+ CHECK((*msg)->frame_size());
memcpy(mutable_cpu_grad(), (*msg)->frame_data(),(*msg)->frame_size());
delete (*msg);
*msg=nullptr;
@@ -93,16 +89,15 @@ int Param::ParseUpdateMsg(Msg** msg){
Msg* Param::GenUpdateResponseMsg(void* arg){
Msg* msg=new Msg();
- char buf[10];
- sprintf(buf, "%d", version());
msg->set_type(kRUpdate);
- msg->set_target(id());
- msg->add_frame(buf, strlen(buf));
+ int v=version();
+ if(arg!=nullptr)
+ v=*(int*)arg;
+ msg->set_target(owner(), v);
msg->add_frame(mutable_cpu_data(), size()*sizeof(float));
return msg;
}
-
Msg* Param::HandleSyncMsg(Msg** msg){
delete *msg;
*msg=nullptr;
@@ -118,12 +113,12 @@ int Param::ParsePutResponseMsg(Msg **msg){
return ParseSyncResponseMsg(msg);
}
int Param::ParseGetResponseMsg(Msg **msg){
- int v;
- sscanf(static_cast<char*>((*msg)->frame_data()), "%d", &v);
- CHECK((*msg)->next_frame());
+ CHECK((*msg)->frame_size());
memcpy(mutable_cpu_data(), (*msg)->frame_data(), (*msg)->frame_size());
// must be set after all other settings are done!
- set_version(v);
+ set_version((*msg)->target_second());
+ delete *msg;
+ *msg=nullptr;
return 1;
}
int Param::ParseUpdateResponseMsg(Msg **msg){
@@ -140,7 +135,6 @@ void Param::Setup(const ParamProto& proto, const vector<int>& shape,
}
void Param::Init(int v){
- set_version(v);
Tensor<cpu, 1> data(mutable_cpu_data(), Shape1(size()));
unsigned seed = std::chrono::system_clock::now().time_since_epoch().count();
auto random=ASingleton<Random<cpu>>::Instance(seed);
@@ -178,6 +172,7 @@ void Param::Init(int v){
LOG(ERROR) << "Illegal parameter init method ";
break;
}
+ set_version(v);
}
/**************************RandomSyncParam********************************