You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@singa.apache.org by wa...@apache.org on 2015/05/03 16:04:08 UTC

[03/12] incubator-singa git commit: Transfer code from nusinga repo to singa apache repo. A new communication framework is implemented to unify the frameworks of existing distributed deep learning systems. Communication is now implemented using ZeroMQ. API

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b2dc51d2/src/proto/model.proto
----------------------------------------------------------------------
diff --git a/src/proto/model.proto b/src/proto/model.proto
new file mode 100644
index 0000000..4ea621d
--- /dev/null
+++ b/src/proto/model.proto
@@ -0,0 +1,382 @@
+package singa;
+// Message type tags.
+// NOTE(review): the meaning of each value (e.g. kGet vs kRGet) is defined by
+// the code that sends/handles these messages, which is not in this file.
+enum MsgType{
+  kGet=0;
+  kPut=1;
+  kSync=2;
+  kUpdate=3;
+  kSyncRequest=4;
+  kSyncResponse=5;
+  kStop=6;
+  kData=7;
+  kRGet=8;
+  kRUpdate=9;
+  kConnect=10;
+};
+
+// Kinds of entities (worker params, worker layers, servers, stubs).
+// NOTE(review): no field in this file uses this enum; confirm usage in code.
+enum EntityType{
+  kWorkerParam=0;
+  kWorkerLayer=1;
+  kServer=2;
+  kStub=3;
+};
+// Execution phase. LayerProto.exclude lists the phases from which a layer is
+// excluded.
+enum Phase {
+  kTrain = 0;
+  kValidation=1;
+  kTest= 2;
+}
+// NOTE(review): presumably chooses between sharing only values or the whole
+// object; no field in this file uses it — confirm.
+enum ShareOption{
+  kValueOnly=0;
+  kWhole=1;
+};
+// Top-level job configuration: data folders, display/validation/test/
+// checkpoint schedules, step counts, the updater and the neural net.
+message ModelProto{
+  optional string name = 1;
+  // relative path to system folder
+  optional string train_folder=2 [default="train"];
+  optional string test_folder=3 [default="test"];
+  optional string validation_folder=4 [default="validation"];
+  // start display after this num steps
+  optional int32 display_after_steps = 6 [default = 0];
+  // frequency of display
+  optional int32 display_frequency = 7 [default = 0];
+
+  // the time of validation
+  //optional int32 validation_step = 9 [default = 0];
+  // start validation after this num steps
+  optional int32 validation_after_steps = 10 [default = 0];
+  // frequency of validation
+  optional int32 validation_frequency = 11 [default = 0];
+
+  // the time of test
+  //optional int32 test_step = 12 [default = 0];
+  // start test after this num steps
+  optional int32 test_after_steps = 13 [default = 0];
+  // frequency of test
+  optional int32 test_frequency = 14 [default = 0];
+  // start checkpointing after this num steps
+  optional int32 checkpoint_after_steps = 15 [default = 0];
+  // frequency of checkpointing
+  optional int32 checkpoint_frequency = 16 [default = 0];
+  // NOTE(review): presumably enables prefetching of input data; confirm.
+  optional bool prefetch=18[default=true];
+
+
+  // total num of steps for training
+  optional int32 train_steps = 20;
+  // total num of steps for validation
+  optional int32 validation_steps=21;
+  // total num of steps for test
+  optional int32 test_steps=22;
+  // last snapshot step
+  optional int32 step=29 [default=0];
+
+  // configuration of the parameter updater (learning rate schedule, etc.)
+  optional UpdaterProto updater=31;
+  // There are two basic algorithms for calculating gradients.
+  // Different deep learning models use different algorithms.
+  enum GradCalcAlg{
+    kBackPropagation = 1;
+    kContrastiveDivergence = 2;
+  }
+  optional GradCalcAlg alg= 32 [default = kBackPropagation];
+  // NOTE(review): presumably enables Hogwild-style lock-free updates; confirm.
+  optional bool hogwild=33 [default=false];
+  // the neural net (layers and how they are partitioned)
+  optional NetProto neuralnet = 40;
+  optional bool debug=41 [default=false];
+}
+
+// Neural net definition: the list of layers and how the net is partitioned
+// (not partitioned by default).
+message NetProto{
+  repeated LayerProto layer=1;
+  optional PartitionType partition_type=3 [default=kNone];
+}
+
+// Configuration of one parameter (e.g. a weight matrix or bias vector): its
+// identity, shape, splitting/partitioning, initialization, and per-parameter
+// learning-rate and weight-decay multipliers.
+message ParamProto {
+  // for the program to identify it and share among layers.
+  // e.g., "conv1_weight","fc_bias"
+  optional string name = 1;
+  optional int32 id=2;
+  // in most situations users do not need to configure this;
+  // the program will calculate it
+  repeated int32 shape = 3;
+
+  // split the parameter into multiple DAryProtos for serialization and
+  // transferring (Google Protobuf has a message size limit)
+  optional int32 split_threshold=4 [default=5000000];
+  // partition dimension, -1 for no partition
+  optional int32 partition_dim=5 [default =-1];
+
+  // NOTE(review): presumably the parameter's update version; confirm.
+  optional int32 version=6;
+
+  // value of the parameter
+  // NOTE(review): field number 6 is now taken by `version`; do not re-enable
+  // the line below without renumbering it.
+  //repeated DAryProto ary = 6;
+
+  enum InitMethod {
+    kConstant = 0;
+    // sample gaussian with std and mean
+    kGaussian = 1;
+    // uniform sampling between low and high
+    kUniform = 2;
+    // copy the content and history which are from previous training
+    kPretrained = 3;
+    // from Toronto Convnet, let a=1/sqrt(fan_in), w*=a after generating from
+    // Gaussian distribution
+    // (the value name is misspelled "Gaussain"; kept for compatibility)
+    kGaussainSqrtFanIn = 4;
+    // from Toronto Convnet, rectified linear activation, let
+    // a=sqrt(3)/sqrt(fan_in), range is [-a, +a]; no need to set value=sqrt(3),
+    // the program will multiply it.
+    kUniformSqrtFanIn = 5;
+    // from Theano MLP tutorial, let a=1/sqrt(fan_in+fan_out). for tanh
+    // activation, range is [-6a, +6a], for sigmoid activation, range is
+    // [-24a, +24a], put the scale factor to value field.
+    // NOTE(review): the Theano tutorial's bound is sqrt(6)*a for tanh and
+    // 4*sqrt(6)*a for sigmoid; check which one the implementation expects.
+    // <a href="http://deeplearning.net/tutorial/mlp.html"> Theano MLP</a>
+    kUniformSqrtFanInOut = 6;
+  }
+  optional InitMethod init_method = 7 [default = kConstant];
+  // constant init
+  optional float value = 8 [default = 1];
+  // for uniform sampling
+  optional float low = 9 [default = -1];
+  optional float high = 10 [default = 1];
+  // for gaussian sampling
+  optional float mean = 11 [default = 0];
+  optional float std = 12 [default = 1];
+  // multiplied on the global learning rate.
+  optional float learning_rate_multiplier =13 [default=1];
+  // multiplied on the global weight decay.
+  optional float weight_decay_multiplier =14 [default=1];
+}
+
+// A collection of blobs.
+// NOTE(review): presumably ids[i] and names[i] describe blobs[i]; confirm.
+message BlobProtos{
+  repeated BlobProto blobs=1;
+  repeated int32 ids=2;
+  repeated string names=3;
+}
+
+
+
+// How a net or layer is partitioned: by data, by layer, or not at all.
+enum PartitionType{
+  kDataPartition=0;
+  kLayerPartition=1;
+  kNone=2;
+}
+// NOTE(review): no field in this file uses ConnectionType; presumably it
+// describes how a layer connects to its source layer(s) — confirm.
+enum ConnectionType{
+  kOneToOne=0;
+  kOneToAll=1;
+}
+
+// Configuration of one layer: identity, source layers, placement and
+// partitioning, parameters, and per-layer-type hyper-parameter messages.
+message LayerProto {
+  optional string name = 1; // the layer name
+  optional string type = 2; // the layer type, given as a string (no enum for it is defined here)
+  // presumably the names of this layer's source (input) layers
+  repeated string srclayers=3;
+  optional int32 locationid=4 [default=0]; // todo make locationID an array
+  optional int32 partitionid=5 [default=0];
+  optional PartitionType partition_type=6;
+  // can be pos/neg neuron value for CD, neuron value/grad for BP
+  //repeated DAryProto ary = 10;
+  // NOTE(review): presumably names of layers whose arrays are shared; confirm.
+  repeated string share_ary =11;
+  // parameters, e.g., weight matrix or bias vector
+  repeated ParamProto param = 12;
+  // names of parameters shared from other layers
+  repeated string share_param=13;
+
+  // All layers are included in the net structure for training phase by default.
+  // Layers, e.g., computing performance metrics for test phase, can be excluded
+  // by this field which defines in which phase this layer should be excluded.
+  repeated Phase exclude = 20;
+
+  // hyper-parameters for layers; presumably only the message matching `type`
+  // is read.
+  optional ConvolutionProto convolution_param = 21;
+  optional ConcateProto concate_param = 31;
+  optional DataProto data_param = 22;
+  optional DropoutProto dropout_param = 23;
+  optional InnerProductProto inner_product_param = 24;
+  optional LRNProto lrn_param = 25;
+  optional MnistProto mnist_param= 26;
+  optional PoolingProto pooling_param = 27;
+  optional SliceProto slice_param = 32;
+  optional SplitProto split_param = 33;
+  optional ReLUProto relu_param = 28;
+  optional RGBImage rgbimage_param=34;
+  optional SoftmaxLossProto softmaxloss_param = 29;
+  optional TanhProto tanh_param=30;
+}
+// Parameters for RGB image preprocessing.
+// NOTE(review): semantics inferred from field names only — presumably pixel
+// scaling, cropping to cropsize (0 = no crop), mirroring, and a mean file to
+// subtract; confirm in the layer implementation.
+message RGBImage {
+  optional float scale=1 [default=1.0];
+  optional int32 cropsize=2 [default=0];
+  optional bool mirror=3 [default=false];
+  optional string meanfile=4;
+}
+// Number of outputs of a split layer (presumably copies of its input).
+message SplitProto{
+  optional int32 num_splits=1;
+}
+// scaled tanh: A*tanh(B*x); presumably A=outer_scale and B=inner_scale.
+message TanhProto{
+  optional float outer_scale=1 [default=1.0];
+  optional float inner_scale=2 [default=1.0];
+}
+
+// Message that stores parameters used by the softmax loss layer
+message SoftmaxLossProto {
+  // When computing accuracy, count as correct by comparing the true label to
+  // the top k scoring classes. Accuracy is only computed when topk>0.
+  // NOTE(review): the default topk=1 is >0, so top-1 accuracy is computed by
+  // default; an older "not computed by default" remark was inconsistent.
+  optional int32 topk = 1 [default=1] ;
+  optional float scale=2 [default=1];
+}
+// Message that stores parameters used by ConvolutionLayer
+message ConvolutionProto {
+  optional uint32 num_filters = 1; // The number of outputs for the layer
+  optional bool bias_term = 2 [default = true]; // whether to have bias terms
+  // Pad, kernel size, and stride are each given as a single value that is
+  // used for both the height and the width.
+  optional uint32 pad = 3 [default = 0]; // The padding size (equal in Y, X)
+  optional uint32 stride = 4 [default = 1]; // The stride (equal in Y, X)
+  required uint32 kernel= 5; // The kernel height/width
+}
+
+// Parameters of a concatenation layer.
+// NOTE(review): presumably the dimension to concatenate along and the number
+// of inputs to concatenate; confirm.
+message ConcateProto{
+  optional int32 concate_dimension=1;
+  optional int32 concate_num=2;
+}
+
+// Message that stores parameters used by DataLayer
+message DataProto {
+  // Specify the data source.
+  optional string source = 1;
+  // path to the data file/folder, absolute or relative to the
+  // ClusterProto::workspace
+  optional string path=2;
+  // Specify the batch size.
+  optional uint32 batchsize = 4;
+  // skip [0,random_skip] records
+  optional uint32 random_skip=5 [default=0];
+}
+
+// Parameters for MNIST input preprocessing (elastic distortion,
+// rotation/shearing, scaling and resizing; see the field comments).
+message MnistProto {
+  // elastic distortion
+  optional int32 kernel=1 [default=0];
+  optional float sigma=2 [default=0];
+  optional float alpha=3 [default=0];
+  // rotation or horizontal shearing
+  optional float beta=4 [default=0];
+  // scaling
+  optional float gamma=5 [default=0];
+  // scale to this size as input for deformation
+  optional int32 resize=6 [default=0] ;
+  // NOTE(review): presumably how often elastic distortion is applied; confirm.
+  optional int32 elastic_freq=7 [default=0];
+  // NOTE(review): presumably a linear normalization x*norm_a+norm_b; confirm.
+  optional float norm_a=8 [default=1];
+  optional float norm_b=9 [default=0];
+}
+// Message that stores parameters used by DropoutLayer
+message DropoutProto {
+  optional float dropout_ratio = 1 [default = 0.5]; // dropout ratio
+}
+// Message that stores parameters used by InnerProductLayer
+message InnerProductProto {
+  optional uint32 num_output = 1; // The number of outputs for the layer
+  optional bool bias_term = 2 [default = true]; // whether to have bias terms
+}
+
+// Message that stores parameters used by LRNLayer (local response
+// normalization)
+message LRNProto {
+  optional uint32 local_size = 1 [default = 5];
+  optional float alpha = 2 [default = 1.];
+  optional float beta = 3 [default = 0.75];
+  enum NormRegion {
+    ACROSS_CHANNELS = 0;
+    WITHIN_CHANNEL = 1;
+  }
+  optional NormRegion norm_region = 4 [default = ACROSS_CHANNELS];
+  optional float knorm =5 [default=1.0];
+}
+
+// Message that stores parameters used by PoolingLayer
+message PoolingProto {
+  enum PoolMethod {
+    MAX = 0;
+    AVE = 1;
+  }
+  optional PoolMethod pool = 1 [default = MAX]; // The pooling method
+  // Pad, kernel size, and stride are each given as a single value that is
+  // used for both the height and the width.
+  required uint32 kernel= 2; // The kernel size (square)
+  optional uint32 pad = 4 [default = 0]; // The padding size (equal in Y, X)
+  optional uint32 stride = 3 [default = 1]; // The stride (equal in Y, X)
+}
+
+// Parameters of a slice layer.
+// NOTE(review): presumably the dimension to slice along and the number of
+// slices; confirm.
+message SliceProto{
+  optional int32 slice_dimension=1;
+  optional int32 slice_num=2;
+}
+// Message that stores parameters used by ReLULayer
+message ReLUProto {
+  // Allow non-zero slope for negative inputs to speed up optimization
+  // Described in:
+  // Maas, A. L., Hannun, A. Y., & Ng, A. Y. (2013). Rectifier nonlinearities
+  // improve neural network acoustic models. In ICML Workshop on Deep Learning
+  // for Audio, Speech, and Language Processing.
+  optional float negative_slope = 1 [default = 0];
+}
+
+
+
+// A generic data record.
+// NOTE(review): presumably `type` tells which payload field is set; only
+// single-label images (`image`) exist so far.
+message Record {
+  enum Type{
+    kSingleLabelImage=0;
+  }
+  optional Type type=1 [default=kSingleLabelImage];
+  optional SingleLabelImageRecord image=2;
+}
+
+// to import caffe's lmdb dataset
+message Datum {
+  optional int32 channels = 1;
+  optional int32 height = 2;
+  optional int32 width = 3;
+  // the actual image data, in bytes
+  optional bytes data = 4;
+  optional int32 label = 5;
+  // Optionally, the datum could also hold float data.
+  repeated float float_data = 6;
+  // If true, data contains an encoded image that needs to be decoded
+  optional bool encoded = 7 [default = false];
+}
+// An image with a single label.
+// NOTE(review): presumably pixels are stored either as raw bytes (pixel) or
+// as floats (data), with their dimensions in shape; confirm.
+message SingleLabelImageRecord{
+  repeated int32 shape=1;
+  optional int32 label=2;
+  optional bytes pixel=3;
+  repeated float data=4;
+}
+
+// Configuration of the parameter updater: momentum, weight decay, the
+// learning rate schedule, and interaction with the parameter servers.
+message UpdaterProto {
+  optional float momentum=4 [default=0];
+  optional float weight_decay=5 [default=0];
+  // used in changing learning rate
+  optional float gamma = 6 [default=1];
+  optional float pow=7 [default=0];
+  // NOTE(review): delta and rho look like AdaDelta/RMSProp-style constants
+  // (numerical epsilon and decay rate); confirm against the updater code.
+  optional float delta=8 [default=0.0000001];
+  optional float rho=9 [default=0.9];
+  optional float base_learning_rate=12;
+  optional float final_learning_rate=13;
+  optional int32 learning_rate_change_frequency = 14;
+  // learning rate schedules, selected by learning_rate_change_method
+  enum ChangeProto {
+    kFixed = 0;
+    kInverse_t= 1;
+    kInverse= 2;
+    kExponential = 3;
+    kLinear = 4;
+    kStep = 5;
+    kFixedStep=6;
+  }
+  optional ChangeProto learning_rate_change_method = 16 [default = kFixed];
+  // NOTE(review): presumably the number of steps between synchronizations;
+  // confirm.
+  optional int32 sync_frequency=17 [default=1];
+  // warmup the parameters and then send to parameter servers.
+  optional int32 warmup_steps=25 [default=10];
+  optional float moving_rate=26 [default=0];
+  optional string param_type=27[default="Param"];
+  // NOTE(review): presumably, for kFixedStep, learning rate step_lr[i] is
+  // used from step[i] on; confirm.
+  repeated int32 step=28;
+  repeated float step_lr=29;
+}
+// A blob whose dimensions are given by num, channels, height and width,
+// holding its values (data) and, presumably, its gradients (diff).
+message BlobProto {
+  optional int32 num = 1 [default = 0];
+  optional int32 channels = 2 [default = 0];
+  optional int32 height = 3 [default = 0];
+  optional int32 width = 4 [default = 0];
+  repeated float data = 5 [packed = true];
+  repeated float diff = 6 [packed = true];
+}

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b2dc51d2/src/test/dist_test/test_consistency.cc
----------------------------------------------------------------------
diff --git a/src/test/dist_test/test_consistency.cc b/src/test/dist_test/test_consistency.cc
new file mode 100644
index 0000000..a4ed9b2
--- /dev/null
+++ b/src/test/dist_test/test_consistency.cc
@@ -0,0 +1,406 @@
+//  Copyright © 2014 Anh Dinh. All Rights Reserved.
+
+//  Tests (im)balance when splitting parameter vectors into chunks.
+
+#include "core/global-table.h"
+#include "core/common.h"
+#include "core/disk-table.h"
+#include "core/table.h"
+#include "core/table_server.h"
+#include "utils/global_context.h"
+#include <gflags/gflags.h>
+#include "proto/model.pb.h"
+#include "proto/common.pb.h"
+#include "worker.h"
+#include "coordinator.h"
+#include "utils/common.h"
+#include "utils/proto_helper.h"
+
+#include <cmath>
+#include <stdlib.h>
+#include <vector>
+#include <iostream>
+#include <fstream>
+
+
+DEFINE_bool(restore_mode, false, "restore from checkpoint file");
+using namespace lapis;
+using std::vector;
+
+//DEFINE_bool(sync_update, false, "Synchronous put/update queue");
+DEFINE_int32(checkpoint_frequency, 5000, "frequency for cp");
+DEFINE_int32(checkpoint_after, 1, "cp after this steps");
+DEFINE_string(par_mode, "hybrid",  "time training algorithm");
+DEFINE_bool(restore, false, "restore from checkpoint file");
+
+DEFINE_string(db_backend, "lmdb", "backend db");
+DEFINE_string(system_conf, "examples/imagenet12/system.conf", "configuration file for node roles");
+DEFINE_string(model_conf, "examples/imagenet12/model.conf", "DL model configuration file");
+DEFINE_string(checkpoint_dir,"/data1/wangwei/lapis/","check point dir");
+// maximum number of floats stored under one key; longer vectors are split
+DEFINE_int32(threshold,1000000, "max # of parameters in a vector");
+DEFINE_int32(iterations,5,"numer of get/put iterations");
+DEFINE_int32(workers,2,"numer of workers doing get/put");
+DECLARE_bool(checkpoint_enabled);
+
+// FLAGS_v (the VLOG verbosity level) is defined by glog itself. The former
+// `#ifndef FLAGS_v` guard could never skip a redefinition: FLAGS_v is a
+// variable, not a preprocessor macro, so DEFINE_int32(v, ...) always ran and
+// registered glog's `v` flag a second time. Pass --v=3 (or GLOG_v=3) to get
+// the verbose output the old compiled-in default asked for.
+
+
+// Update handler for the test table: an update adds the first gradient of
+// the incoming value element-wise onto the stored data vector.
+struct AnhUpdateHandler: BaseUpdateHandler<VKey,SGDValue>{
+  // Accumulates b.grad(0) into a->data() in place; always returns true.
+  bool Update(SGDValue *a, const SGDValue &b){
+    // Without these checks a missing gradient, or a gradient longer than the
+    // stored data, makes the loop below access memory out of bounds.
+    CHECK_GT(b.grad_size(), 0);
+    const auto &grad = b.grad(0).value();
+    CHECK_LE(grad.size(), a->data().value_size());
+    float *adptr = a->mutable_data()->mutable_value()->mutable_data();
+    const float *bdptr = grad.data();
+    for (int i = 0; i < grad.size(); i++)
+      adptr[i] += bdptr[i];
+    return true;
+  }
+
+  // Returns a copy of the stored value.
+  bool Get(const VKey k, const SGDValue &val, SGDValue *ret){
+    *ret = val;
+    return true;
+  }
+
+  bool is_checkpointable(const VKey k, const SGDValue v){
+    return true; //always checkpoint
+  }
+};
+
+// Process-wide state shared by the helpers in this test.
+using Map = map<int, GlobalTable*>;
+Map tables;                              // table id -> table
+shared_ptr<NetworkThread> network;
+shared_ptr<GlobalContext> context;
+std::vector<ServerState*> server_states; // filled by coordinator_assign_tables()
+TableServer *table_server = nullptr;     // set by worker_table_init()
+TableDelegate *delegate = nullptr;       // NOTE(review): unused in this file
+// Builds in-memory global table `id` (VKey -> SGDValue, `num_shards` shards,
+// accumulator AnhUpdateHandler) and registers it in `tables`.
+void create_mem_table(int id, int num_shards){
+  auto *desc = new TableDescriptor(id, num_shards);
+  desc->key_marshal = new Marshal<VKey>();
+  desc->value_marshal = new Marshal<SGDValue>();
+  desc->sharder = new VKeySharder;
+  desc->accum = new AnhUpdateHandler;
+  desc->partition_factory = new typename SparseTable<VKey, SGDValue>::Factory;
+
+  TypedGlobalTable<VKey, SGDValue> *created = new TypedGlobalTable<VKey, SGDValue>();
+  created->Init(desc);
+  tables[id] = created;
+}
+
+// Coordinator side of table setup: waits for a registration message per
+// table server, assigns the shards of table `id` to servers round-robin,
+// records the owners locally, then broadcasts the assignment and waits for
+// every acknowledgement.
+void coordinator_assign_tables(int id){
+  for (int i = 0; i < context->num_procs(); ++i) {
+    RegisterWorkerRequest req;
+    int src = 0;
+    //  adding memory server.
+    if (context->IsTableServer(i)) {
+      network->Read(MPI::ANY_SOURCE, MTYPE_REGISTER_WORKER, &req, &src);
+      server_states.push_back(new ServerState(i));
+    }
+  }
+  LOG(INFO) << " All servers registered and started up. Ready to go";
+  // The round-robin below indexes and takes a modulus by the server count.
+  CHECK(!server_states.empty()) << "no table servers configured";
+  //  set itself as the current worker for the table
+  tables[id]->worker_id_ = network->id();
+
+  // memory servers are specified in global context. Round-robin assignment
+  VLOG(3)<<"num of shards"<<tables[id]->num_shards()<<" for table"<< id;
+
+  size_t server_idx = 0;
+  for (int shard = 0; shard < tables[id]->num_shards(); ++shard) {
+    ServerState &server = *server_states[server_idx];
+    LOG(INFO) << "Assigning table ("<<id<<","<<shard<<") to server "
+              <<server.server_id;
+
+    // TODO(Anh) may overwrite this field if #shards>#table_servers
+    server.shard_id = shard;
+    server.local_shards.insert(new TaskId(id, shard));
+    server_idx = (server_idx + 1) % server_states.size();
+  }
+
+  VLOG(3)<<"table assignment";
+  //  then send table assignment
+  ShardAssignmentRequest req;
+  for (ServerState *server : server_states) {
+    for (auto *task : server->local_shards) {
+      ShardAssignment *s = req.add_assign();
+      s->set_new_worker(server->server_id);
+      s->set_table(task->table);
+      s->set_shard(task->shard);
+      //  update local tables
+      CHECK(tables.find(task->table)!=tables.end());
+      GlobalTable *t = tables.at(task->table);
+      t->get_partition_info(task->shard)->owner = server->server_id;
+      delete task;
+    }
+    // The TaskIds were just deleted; drop the now-dangling pointers so no
+    // later reader of local_shards touches freed memory.
+    server->local_shards.clear();
+  }
+  VLOG(3)<<"finish table assignment, req size "<<req.assign_size();
+  network->SyncBroadcast(MTYPE_SHARD_ASSIGNMENT, MTYPE_SHARD_ASSIGNMENT_DONE, req);
+  VLOG(3)<<"finish table server init";
+}
+
+
+// Creates this process's TableServer over `tables`, stores it in
+// `table_server` (used later by shutdown()), and starts it.
+void worker_table_init(){
+  auto *server = new TableServer();
+  table_server = server;
+  server->StartTableServer(tables);
+  VLOG(3) << "done starting table server";
+}
+
+// Returns a pseudo-random double in [0, 1], derived from rand().
+double random_double(){
+  const double scale = RAND_MAX;
+  return rand() / scale;
+}
+
+// Populates table 0 with deterministic test data (coordinator only).
+// Each entry of `tuples` is the length (in floats) of one parameter vector.
+// A vector is split into chunks of at most FLAGS_threshold floats, and each
+// chunk is put under the next consecutive VKey (0, 1, 2, ...) with both its
+// data and its first gradient set to 0, 1, 2, ... within the chunk.
+// Nothing is put in --restore_mode.
+void coordinator_load_data(const vector<int>& tuples){
+  auto table = static_cast<TypedGlobalTable<VKey,SGDValue>*>(tables[0]);
+
+  int keyid=0;
+  if (!FLAGS_restore_mode){
+    // A non-positive threshold would produce empty chunks and loop forever.
+    CHECK_GT(FLAGS_threshold, 0);
+    for(auto tuple: tuples){
+      for(int offset=0;offset<tuple;){
+        const int chunk = std::min(FLAGS_threshold, tuple-offset);
+        SGDValue x;
+        DAryProto *data=x.mutable_data();
+        DAryProto *grad=x.add_grad();
+        for(int i=0;i<chunk;i++){
+          data->add_value(i*1.0f);
+          grad->add_value(i*1.0f);
+        }
+        offset+=chunk;
+        VKey key;
+        key.set_key(keyid++);
+        table->put(key,x);
+      }
+    }
+    LOG(ERROR)<<"put "<<keyid<<" tuples";
+  }
+  VLOG(3) << "Coordinator done loading ..., from process "<<NetworkThread::Get()->id();
+}
+
+// Issues an async_get for every key written by the chunked loaders
+// (ceil(tuple / FLAGS_threshold) keys per tuple, numbered from 0), then
+// polls until one reply per key has been collected.
+void get(TypedGlobalTable<VKey,SGDValue>* table, const vector<int>& tuples){
+  // The key count below divides by the threshold.
+  CHECK_GT(FLAGS_threshold, 0);
+  int num_keys=0;
+  for(auto tuple: tuples){
+    // ceiling division: one key per (possibly partial) chunk
+    num_keys+=tuple/FLAGS_threshold+(tuple%FLAGS_threshold!=0);
+  }
+  LOG(ERROR)<<"getting "<<num_keys<<" tuples";
+
+  // NOTE(review): the same `v` is passed to every async_get; confirm that
+  // async_get does not retain the pointer (replies are read via
+  // async_get_collect below).
+  SGDValue v;
+  for (int i=0; i<num_keys; i++){
+    VKey key;
+    key.set_key(i);
+    table->async_get(key, &v);
+  }
+
+  LOG(INFO)<<"start collect key";
+  SGDValue val;
+  for (int i=0; i<num_keys; i++){
+    VKey key;
+    while(!table->async_get_collect(&key, &val))
+      Sleep(0.001);
+  }
+}
+
+// Sends one update per chunk, using the same chunking and key numbering as
+// coordinator_load_data(); each update carries a gradient 0, 1, 2, ... .
+// Process 0 first sleeps 2 seconds (presumably to stagger it against the
+// other processes).
+void update(TypedGlobalTable<VKey,SGDValue>* table, const vector<int>& tuples){
+  // A non-positive threshold would produce empty chunks and loop forever.
+  CHECK_GT(FLAGS_threshold, 0);
+  if(NetworkThread::Get()->id()==0)
+    sleep(2);
+  LOG(INFO)<<"start update";
+  int keyid=0;
+  for(auto tuple: tuples){
+    for(int offset=0;offset<tuple;){
+      const int chunk = std::min(FLAGS_threshold, tuple-offset);
+      SGDValue x;
+      DAryProto *grad=x.add_grad();
+      for(int i=0;i<chunk;i++)
+        grad->add_value(i*1.0f);
+      offset+=chunk;
+      VKey key;
+      key.set_key(keyid++);
+      table->update(key,x);
+    }
+  }
+  LOG(ERROR)<<"updated "<<keyid<<" tuples";
+}
+
+// Worker workload on table 0: read every key, apply three rounds of
+// updates, then read every key again.
+void worker_test_data(const vector<int>& tuples){
+  auto *typed = static_cast<TypedGlobalTable<VKey,SGDValue>*>(tables[0]);
+  get(typed, tuples);
+  for (int round = 0; round < 3; ++round)
+    update(typed, tuples);
+  get(typed, tuples);
+}
+
+// Coordinated teardown. A non-coordinator flushes, reports MTYPE_WORKER_END
+// to rank num_procs()-1, waits for MTYPE_SHUTDOWN from it, stops its table
+// server if it runs one, and shuts its network down. The coordinator
+// collects num_procs()-1 MTYPE_WORKER_END messages, sends MTYPE_SHUTDOWN to
+// ranks 0 .. size()-2, then flushes and shuts down.
+void shutdown(){
+  if (!context->AmICoordinator()){
+    network->Flush();
+    network->Send(context->num_procs()-1, MTYPE_WORKER_END, EmptyMessage());
+    EmptyMessage reply;
+    network->Read(context->num_procs()-1, MTYPE_SHUTDOWN, &reply);
+    if (context->AmITableServer())
+      table_server->ShutdownTableServer();
+    network->Shutdown();
+    return;
+  }
+
+  EmptyMessage msg;
+  for (int received = 0; received < context->num_procs()-1; ++received)
+    network->Read(MPI::ANY_SOURCE, MTYPE_WORKER_END, &msg);
+  EmptyMessage shutdown_msg;
+  for (int rank = 0; rank < network->size() - 1; ++rank)
+    network->Send(rank, MTYPE_SHUTDOWN, shutdown_msg);
+  network->Flush();
+  network->Shutdown();
+}
+
+// Reads the shard assignment from the coordinator and records each shard's
+// owner. When checkpointing is enabled, opens a checkpoint LogFile for every
+// local shard: an existing file is reopened for appending (after restoring
+// the shard from it in --restore_mode); otherwise a new file is created.
+// Finally acknowledges with MTYPE_SHARD_ASSIGNMENT_DONE.
+void HandleShardAssignment() {
+  ShardAssignmentRequest shard_req;
+  auto mpi=NetworkThread::Get();
+  mpi->Read(GlobalContext::kCoordinator, MTYPE_SHARD_ASSIGNMENT, &shard_req);
+  //  request read from coordinator
+  for (int i = 0; i < shard_req.assign_size(); i++) {
+    const ShardAssignment &a = shard_req.assign(i);
+    GlobalTable *t = tables.at(a.table());
+    t->get_partition_info(a.shard())->owner = a.new_worker();
+
+    //if local shard, create check-point files
+    if (!(FLAGS_checkpoint_enabled && t->is_local_shard(a.shard())))
+      continue;
+
+    string checkpoint_file = StringPrintf("%s/checkpoint_%d",FLAGS_checkpoint_dir.c_str(), a.shard());
+    // gethostname() need not NUL-terminate a truncated name; keep the last
+    // byte zero. Looked up once and reused for both log lines below.
+    char hostname[256] = {0};
+    gethostname(hostname, sizeof(hostname) - 1);
+    VLOG(3) << "try to open for writing *****"<<checkpoint_file<<" "<<string(hostname);
+
+    auto cp = t->checkpoint_files();
+    FILE *tmp_file = fopen(checkpoint_file.c_str(), "r");
+    if (tmp_file){//exists -> open to reading and writing
+      fclose(tmp_file);
+      if (FLAGS_restore_mode){//open in read mode to restore, then close
+        int table_size = 0;
+        {
+          // scoped so the probe is closed before the restore reopens the file
+          LogFile probe(checkpoint_file,"rw",0);
+          VLOG(3) << "Loaded table " << probe.file_name();
+          table_size = probe.read_latest_table_size();
+        }
+
+        double start=Now();
+        VLOG(3) << "Open checkpoint file to restore";
+        (*cp)[a.shard()] = new LogFile(checkpoint_file,"r",a.shard());
+        t->Restore(a.shard());
+        delete (*cp)[a.shard()];
+        double end=Now();
+        LOG(ERROR)<<"restore time\t"<<end-start<< "\tfor\t"
+          <<table_size<<"\tthreshold\t"<<FLAGS_threshold;
+      }
+      VLOG(3) << "open for writing *****"<<checkpoint_file<<" "<<string(hostname);
+      VLOG(3) << "Open checkpoint file for writing";
+      // replaces the (deleted) restore handle, if any
+      (*cp)[a.shard()] = new LogFile(checkpoint_file,"a",a.shard());
+    }
+    else{// not exist -> open to writing first time
+      (*cp)[a.shard()] = new LogFile(checkpoint_file,"w",a.shard());
+      VLOG(3) << "Added to new checkpoint files for shard "<< a.shard();
+    }
+  }
+  EmptyMessage empty;
+  mpi->Send(GlobalContext::kCoordinator, MTYPE_SHARD_ASSIGNMENT_DONE, empty);
+  VLOG(3)<<"finish handle shard assignment **";
+}
+
+
+// Test driver. Initialises MPI (requesting MPI_THREAD_MULTIPLE), logging and
+// flags, and creates table 0 with one shard per table server. It then runs
+// the part for this process's role. The coordinator assigns shards and
+// loads data. A table server starts serving and handles its shard
+// assignment. A worker handles its assignment and, unless --restore_mode,
+// runs the get/update workload. Every role passes a barrier before the
+// workload and calls shutdown() at the end.
+int main(int argc, char **argv) {
+  FLAGS_logtostderr = 1;
+  int provided = MPI_THREAD_SINGLE;
+  MPI_Init_thread(&argc, &argv, MPI_THREAD_MULTIPLE, &provided);
+  google::InitGoogleLogging(argv[0]);
+  gflags::ParseCommandLineFlags(&argc, &argv, true);
+  // The test asks for full multi-threaded MPI; make a lower granted level
+  // visible instead of silently running without it.
+  if (provided < MPI_THREAD_MULTIPLE)
+    LOG(WARNING) << "MPI_Init_thread: requested MPI_THREAD_MULTIPLE, got "
+                 << provided;
+
+  context = GlobalContext::Get(FLAGS_system_conf);
+  network = NetworkThread::Get();
+
+  ModelProto model;
+  ReadProtoFromTextFile(FLAGS_model_conf.c_str(), &model);
+
+  create_mem_table(0,context->num_table_servers());
+
+  // Lengths (in floats) of the parameter vectors used by the workload; each
+  // is split into chunks of at most --threshold floats.
+  const vector<int> tuple_size{37448736, 16777216, 4096000, 1327104, 884736, 884736, 614400,14112,4096,4096,1000,384,384,256,256,96};
+
+  if (context->AmICoordinator()){
+    VLOG(3) << "Coordinator process rank = " << NetworkThread::Get()->id();
+    coordinator_assign_tables(0);
+    coordinator_load_data(tuple_size);
+    network->barrier();
+  }
+  else{
+    if (context->AmITableServer()){
+      worker_table_init();
+      HandleShardAssignment();
+      network->barrier();
+    }
+    else{
+      VLOG(3) << "Inside worker, waiting for assignemtn";
+      HandleShardAssignment();
+      network->barrier();
+      if(!FLAGS_restore_mode)
+        worker_test_data(tuple_size);
+    }
+  }
+  shutdown();
+
+  VLOG(3) << "Done ...";
+  return 0;
+}
+
+

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b2dc51d2/src/test/dist_test/test_core.cc
----------------------------------------------------------------------
diff --git a/src/test/dist_test/test_core.cc b/src/test/dist_test/test_core.cc
new file mode 100644
index 0000000..35d589b
--- /dev/null
+++ b/src/test/dist_test/test_core.cc
@@ -0,0 +1,192 @@
+//  Copyright © 2014 Anh Dinh. All Rights Reserved.
+
+
+
+#include "core/global-table.h"
+#include "core/common.h"
+#include "core/disk-table.h"
+#include "core/table.h"
+#include "core/table_server.h"
+#include "utils/global_context.h"
+#include <gflags/gflags.h>
+#include "proto/model.pb.h"
+#include "worker.h"
+#include "coordinator.h"
+#include "model_controller/myacc.h"
+#include <cmath>
+
+using namespace lapis;
+
+// Command-line flags for this test.
+DEFINE_bool(sync_update, false, "Synchronous put/update queue");
+DEFINE_string(system_conf, "examples/imagenet12/system.conf", "configuration file for node roles");
+DEFINE_string(model_conf, "examples/imagenet12/model.conf", "DL model configuration file");
+// keys 1..num_keys are loaded by the coordinator and read/updated by workers
+DEFINE_int32(num_keys,10,"");
+
+// Process-wide state shared by the helpers in this test.
+using Map = map<int, GlobalTable*>;
+Map tables;                              // table id -> table
+shared_ptr<NetworkThread> network;
+shared_ptr<GlobalContext> context;
+std::vector<ServerState*> server_states; // filled by coordinator_assign_tables()
+TableServer *table_server = nullptr;     // set by worker_table_init()
+
+// Builds in-memory global table `id` (int -> int, `num_shards` shards,
+// modulo sharding, accumulator TestUpdater) and registers it in `tables`.
+void create_mem_table(int id, int num_shards){
+  auto *desc = new TableDescriptor(id, num_shards);
+  desc->key_marshal = new Marshal<int>();
+  desc->value_marshal = new Marshal<int>();
+  desc->sharder = new Sharding::Mod;
+  desc->accum = new TestUpdater();
+  desc->partition_factory = new typename SparseTable<int, int>::Factory;
+
+  TypedGlobalTable<int, int> *created = new TypedGlobalTable<int, int>();
+  created->Init(desc);
+  tables[id] = created;
+}
+
+// Coordinator side of table setup: reads one registration message from every
+// other process, assigns the shards of table `id` to the table servers
+// round-robin, records the owners locally, then broadcasts the assignment and
+// waits for every acknowledgement.
+void coordinator_assign_tables(int id){
+  for (int i = 0; i < context->num_processes()-1; ++i) {
+    RegisterWorkerRequest req;
+    int src = 0;
+    network->Read(MPI::ANY_SOURCE, MTYPE_REGISTER_WORKER, &req, &src);
+    //  adding memory server.
+    if (context->IsTableServer(i)) {
+      server_states.push_back(new ServerState(i));
+    }
+  }
+  LOG(INFO) << " All servers registered and started up. Ready to go";
+  // The round-robin below indexes and takes a modulus by the server count.
+  CHECK(!server_states.empty()) << "no table servers configured";
+  //  set itself as the current worker for the table
+  tables[id]->worker_id_ = network->id();
+
+  // memory servers are specified in global context. Round-robin assignment
+  VLOG(3)<<"num of shards"<<tables[id]->num_shards()<<" for table"<< id;
+
+  size_t server_idx = 0;
+  for (int shard = 0; shard < tables[id]->num_shards(); ++shard) {
+    ServerState &server = *server_states[server_idx];
+    LOG(INFO) << "Assigning table ("<<id<<","<<shard<<") to server "
+              <<server.server_id;
+
+    // TODO(Anh) may overwrite this field if #shards>#table_servers
+    server.shard_id = shard;
+    server.local_shards.insert(new TaskId(id, shard));
+    server_idx = (server_idx + 1) % server_states.size();
+  }
+
+  VLOG(3)<<"table assignment";
+  //  then send table assignment
+  ShardAssignmentRequest req;
+  for (ServerState *server : server_states) {
+    for (auto *task : server->local_shards) {
+      ShardAssignment *s = req.add_assign();
+      s->set_new_worker(server->server_id);
+      s->set_table(task->table);
+      s->set_shard(task->shard);
+      //  update local tables
+      CHECK(tables.find(task->table)!=tables.end());
+      GlobalTable *t = tables.at(task->table);
+      t->get_partition_info(task->shard)->owner = server->server_id;
+      delete task;
+    }
+    // The TaskIds were just deleted; drop the now-dangling pointers so no
+    // later reader of local_shards touches freed memory.
+    server->local_shards.clear();
+  }
+  VLOG(3)<<"finish table assignment, req size "<<req.assign_size();
+  network->SyncBroadcast(MTYPE_SHARD_ASSIGNMENT, MTYPE_SHARD_ASSIGNMENT_DONE, req);
+  VLOG(3)<<"finish table server init";
+}
+
+// Creates this process's TableServer over `tables`, stores it in
+// `table_server` (used later by shutdown()), and starts it.
+void worker_table_init(){
+  auto *server = new TableServer();
+  table_server = server;
+  server->StartTableServer(tables);
+  VLOG(3) << "done starting table server";
+}
+
+
+// Coordinator only: puts key -> key for every key in 1..FLAGS_num_keys
+// into table 0.
+void coordinator_load_data(){
+  auto *table = static_cast<TypedGlobalTable<int,int>*>(tables[0]);
+  int key = 1;
+  while (key <= FLAGS_num_keys) {
+    table->put(key, key);
+    ++key;
+  }
+  VLOG(3) << "Loaded data successfully ...";
+}
+
+// Worker workload on table 0: log every key's value, then twice send
+// update(key, key) for every key and log all values again.
+void worker_test_data(){
+  auto *table = static_cast<TypedGlobalTable<int,int>*>(tables[0]);
+  // Logs "Worker <rank> got (key,value)" per key. The get() call sits inside
+  // the VLOG statement, so it is only evaluated when VLOG(3) is enabled.
+  auto dump_all = [table]() {
+    for (int k = 1; k <= FLAGS_num_keys; ++k)
+      VLOG(3) << StringPrintf("Worker %d got (%d,%d)",
+                              NetworkThread::Get()->id(), k, table->get(k));
+  };
+
+  dump_all();
+  for (int pass = 0; pass < 2; ++pass) {
+    for (int k = 1; k <= FLAGS_num_keys; ++k)
+      table->update(k, k);
+    dump_all();
+  }
+}
+
+// Coordinated teardown. A non-coordinator flushes, reports MTYPE_WORKER_END
+// to the coordinator, waits for MTYPE_WORKER_SHUTDOWN, stops its table server
+// and shuts its network down. The coordinator collects num_processes()-1
+// MTYPE_WORKER_END messages, sends MTYPE_WORKER_SHUTDOWN to ranks
+// 0 .. size()-2, then flushes and shuts down.
+void shutdown(){
+  if (!context->AmICoordinator()){
+    VLOG(3) << "Worker " << network->id() << " is shutting down ...";
+    network->Flush();
+    VLOG(3) << "Done flushing the network thread";
+    network->Send(GlobalContext::kCoordinatorRank, MTYPE_WORKER_END, EmptyMessage());
+    EmptyMessage reply;
+    network->Read(GlobalContext::kCoordinatorRank, MTYPE_WORKER_SHUTDOWN, &reply);
+    VLOG(3) << "Worker received MTYPE_WORKER_SHUTDOWN";
+    table_server->ShutdownTableServer();
+    VLOG(3) << "Flushing node " << network->id();
+    network->Shutdown();
+    return;
+  }
+
+  VLOG(3) << "Coordinator is shutting down ...";
+  EmptyMessage msg;
+  for (int received = 0; received < context->num_processes()-1; ++received)
+    network->Read(MPI::ANY_SOURCE, MTYPE_WORKER_END, &msg);
+  EmptyMessage shutdown_msg;
+  for (int rank = 0; rank < network->size() - 1; ++rank)
+    network->Send(rank, MTYPE_WORKER_SHUTDOWN, shutdown_msg);
+  network->Flush();
+  network->Shutdown();
+}
+
+
+// Test driver for the memory (table) servers. The coordinator assigns shards
+// and loads keys. Every other process starts a table server, and after the
+// shared barrier it runs the get/update workload. All processes finish in
+// shutdown().
+int main(int argc, char **argv) {
+  FLAGS_logtostderr = 1;
+  google::InitGoogleLogging(argv[0]);
+  gflags::ParseCommandLineFlags(&argc, &argv, true);
+
+  context = GlobalContext::Get(FLAGS_system_conf, FLAGS_model_conf);
+  network = NetworkThread::Get();
+  VLOG(3) << "*** testing memory servers, with "
+          << context->num_table_servers() << " servers";
+  create_mem_table(0, context->num_table_servers());
+
+  const bool is_coordinator = context->AmICoordinator();
+  if (is_coordinator) {
+    coordinator_assign_tables(0);
+    coordinator_load_data();
+  } else {
+    worker_table_init();
+  }
+  network->barrier();
+
+  if (!is_coordinator) {
+    VLOG(3) << "passed the barrier";
+    worker_test_data();
+  }
+
+  shutdown();
+  return 0;
+}
+
+

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b2dc51d2/src/test/dist_test/test_da.cc
----------------------------------------------------------------------
diff --git a/src/test/dist_test/test_da.cc b/src/test/dist_test/test_da.cc
new file mode 100644
index 0000000..51aa93e
--- /dev/null
+++ b/src/test/dist_test/test_da.cc
@@ -0,0 +1,700 @@
+#include <glog/logging.h>
+#include <mpi.h>
+#include <utility>
+#include <vector>
+
+#include "da/gary.h"
+#include "da/dary.h"
+#include "da/ary.h"
+
+
+using std::make_pair;
+using std::vector;
+// Parks the calling process so a debugger can attach to it.
+// Prints this process's PID and host name, then sleeps in 5-second steps
+// until `i` is changed to a non-zero value from the debugger
+// (e.g. `set var i = 1` in gdb). Never returns otherwise.
+void Debug() {
+  // volatile forces a fresh read of `i` on every loop test; without it the
+  // optimizer may hoist the load and never observe the debugger's write.
+  volatile int i = 0;
+  char hostname[256];
+  gethostname(hostname, sizeof(hostname));
+  printf("PID %d on %s ready for attach\n", getpid(), hostname);
+  fflush(stdout);
+  while (0 == i)
+    sleep(5);
+}
+
+
+
+// Checks DAry::Add on two 4x8 arrays that are both partitioned along
+// dimension `pdim`. Rank 0 first adds locally fetched copies ("a<- a+b"),
+// then all ranks run the distributed a1.Add(a2) and rank 0 logs the fetched
+// result ("add then fetch"); the two printouts should match.
+// ARMCI_Barrier() is collective: every rank must reach each call in order.
+void TestPar(int pdim, int rank){
+  lapis::DAry a1, a2;
+  lapis::DAry a3, a4;
+  // Whole-array slice of a 4x8 array.
+  vector<lapis::Range> slice{make_pair(0,4), make_pair(0,8)};
+  a1.SetShape({4,8});
+  a2.SetShape({4,8});
+  a1.Setup(pdim);
+  a2.Setup(pdim);
+  a1.Random();
+  a2.Random();
+  // Make sure all partitions are filled before rank 0 reads them.
+  ARMCI_Barrier();
+
+
+  if(rank==0){
+    //Debug();
+    LOG(ERROR)<<"test simple partition along "<< pdim<<" dim";
+    a3=a1.Fetch(slice);
+    a4=a2.Fetch(slice);
+    LOG(ERROR)<<"fetch a";
+    LOG(ERROR)<<a3.ToString();
+    LOG(ERROR)<<"fetch b";
+    LOG(ERROR)<<a4.ToString();
+    // Reference result computed on the local copies.
+    a3.Add(a4);
+    LOG(ERROR)<<"a<- a+b";
+    LOG(ERROR)<<a3.ToString();
+  }
+  ARMCI_Barrier();
+  // Distributed add; the trailing barrier makes a1 complete before rank 0 reads it.
+  a1.Add(a2);
+  ARMCI_Barrier();
+  if(rank==0){
+    lapis::DAry a5;
+    a5=a1.Fetch(slice);
+    LOG(ERROR)<<"add then fetch";
+    LOG(ERROR)<<a5.ToString();
+  }
+}
+
+
+
+// On rank 0 only: fetch `slice` of `ary` into a local DAry and log `title`
+// followed by the fetched contents. Other ranks return immediately.
+static void LogFetchedSlice(int rank, lapis::DAry& ary,
+                            const vector<lapis::Range>& slice,
+                            const char* title){
+  if(rank!=0)
+    return;
+  lapis::DAry fetched;
+  fetched=ary.Fetch(slice);
+  LOG(ERROR)<<title;
+  LOG(ERROR)<<fetched.ToString();
+}
+
+// Exercises element-wise DAry ops on 3x6x2 arrays a1, a2, a3 partitioned
+// along dims pa, pb, pc respectively. For most ops, rank 0 first applies the
+// op to locally fetched copies ("fetch op ..."). Then all ranks run the
+// distributed op and rank 0 logs the fetched result ("op fetch ..."), so the
+// two printouts can be compared.
+// ARMCI_Barrier() is collective: every rank must reach each call in order.
+void TestMixedParElt(int pa, int pb, int pc, int rank){
+  LOG(ERROR)<<" p dim for a,b,c is "<<pa<<" "<<pb<<" "<<pc;
+  vector<lapis::Range> slice{make_pair(0,3),make_pair(0,6), make_pair(0,2)};
+  lapis::DAry a1, a2, a3;
+  a1.SetShape({3,6,2});
+  a2.SetShape({3,6,2});
+  a3.SetShape({3,6,2});
+  a1.Setup(pa);
+  a2.Setup(pb);
+  a3.Setup(pc);
+  a1.Random();
+  a2.Random();
+  a3.Random();
+
+  //////////////////// Copy
+  ARMCI_Barrier();
+  if(rank==0){
+    LOG(ERROR)<<"test elementwise ops with mixed partition";
+    lapis::DAry a5, a4;
+    a5=a1.Fetch(slice);
+    a4=a2.Fetch(slice);
+    LOG(ERROR)<<"fetch a";
+    LOG(ERROR)<<a5.ToString();
+    LOG(ERROR)<<"fetch b";
+    LOG(ERROR)<<a4.ToString();
+    a5.Copy(a4);
+    LOG(ERROR)<<"fetch op a.Copy(b)";
+    LOG(ERROR)<<a5.ToString();
+  }
+  ARMCI_Barrier();
+  a1.Copy(a2);
+  ARMCI_Barrier();
+  LogFetchedSlice(rank, a1, slice, "op fetch a.Copy(b)");
+
+  //////////////////// Mult(a,b)
+  ARMCI_Barrier();
+  if(rank==0){
+    lapis::DAry a8, a4, a5({3,6,2});
+    a8=a1.Fetch(slice);
+    a4=a2.Fetch(slice);
+    LOG(ERROR)<<"fetch a";
+    LOG(ERROR)<<a8.ToString();
+    LOG(ERROR)<<"fetch b";
+    LOG(ERROR)<<a4.ToString();
+    a5.Mult(a8,a4);
+    LOG(ERROR)<<"fetch op c.mult(a,b)";
+    LOG(ERROR)<<a5.ToString();
+  }
+  ARMCI_Barrier();
+  a3.Mult(a1,a2);
+  ARMCI_Barrier();
+  LogFetchedSlice(rank, a3, slice, "op fetch a.Mult(b,c)");
+
+  //////////////////// Div(a,b)
+  ARMCI_Barrier();
+  if(rank==0){
+    lapis::DAry a8, a4, a5({3,6,2});
+    a8=a1.Fetch(slice);
+    a4=a2.Fetch(slice);
+    LOG(ERROR)<<"fetch a";
+    LOG(ERROR)<<a8.ToString();
+    LOG(ERROR)<<"fetch b";
+    LOG(ERROR)<<a4.ToString();
+    a5.Div(a8,a4);
+    LOG(ERROR)<<"fetch op c.div(a,b)";
+    LOG(ERROR)<<a5.ToString();
+  }
+  ARMCI_Barrier();
+  a3.Div(a1,a2);
+  ARMCI_Barrier();
+  LogFetchedSlice(rank, a3, slice, "op fetch a.div(b,c)");
+
+  //////////////////// Mult(a, scalar)
+  ARMCI_Barrier();
+  if(rank==0){
+    lapis::DAry a8, a5({3,6,2});
+    a8=a1.Fetch(slice);
+    LOG(ERROR)<<"fetch a";
+    LOG(ERROR)<<a8.ToString();
+    a5.Mult(a8, 3.0);
+    LOG(ERROR)<<"fetch op c.mult(a,3)";
+    LOG(ERROR)<<a5.ToString();
+  }
+  ARMCI_Barrier();
+  a3.Mult(a1,3.0);
+  ARMCI_Barrier();
+  LogFetchedSlice(rank, a3, slice, "op fetch a.mult(b,3)");
+
+  //////////////////// Square(a)
+  ARMCI_Barrier();
+  if(rank==0){
+    lapis::DAry a8, a5({3,6,2});
+    a8=a1.Fetch(slice);
+    LOG(ERROR)<<"fetch a";
+    LOG(ERROR)<<a8.ToString();
+    a5.Square(a8);
+    LOG(ERROR)<<"fetch op c.square(a)";
+    LOG(ERROR)<<a5.ToString();
+  }
+  ARMCI_Barrier();
+  a3.Square(a1);
+  ARMCI_Barrier();
+  LogFetchedSlice(rank, a3, slice, "op fetch a.sqaure(b)");
+
+  //////////////////// Pow(a, 3)
+  ARMCI_Barrier();
+  if(rank==0){
+    lapis::DAry a8, a5({3,6,2});
+    a8=a1.Fetch(slice);
+    LOG(ERROR)<<"fetch a";
+    LOG(ERROR)<<a8.ToString();
+    a5.Pow(a8,3.0);
+    LOG(ERROR)<<"fetch op c.pow(a, 3)";
+    LOG(ERROR)<<a5.ToString();
+  }
+  ARMCI_Barrier();
+  a3.Pow(a1,3.0);
+  ARMCI_Barrier();
+  LogFetchedSlice(rank, a3, slice, "op fetch a.pow(b,3)");
+
+  //////////////////// SampleUniform / SampleGaussian / Fill
+  ARMCI_Barrier();
+  a3.SampleUniform(0.0,3.0);
+  ARMCI_Barrier();
+  LogFetchedSlice(rank, a3, slice, "op fetch a.uniform(0,3)");
+
+  ARMCI_Barrier();
+  a3.SampleGaussian(0.0,1.0);
+  ARMCI_Barrier();
+  LogFetchedSlice(rank, a3, slice, "op fetch a.norm(0,1)");
+
+  ARMCI_Barrier();
+  a3.Fill(1.43);
+  ARMCI_Barrier();
+  LogFetchedSlice(rank, a3, slice, "op fetch a.fill(1.43)");
+
+  //////////////////// Threshold(a, 0.3)
+  ARMCI_Barrier();
+  a1.Random();
+  ARMCI_Barrier();
+  if(rank==0){
+    lapis::DAry a4, a5({3,6,2});
+    a4=a1.Fetch(slice);
+    a5.Threshold(a4,0.3);
+    LOG(ERROR)<<"fetch op b=threshold(a,0.3)";
+    LOG(ERROR)<<a4.ToString();
+    LOG(ERROR)<<a5.ToString();
+  }
+  ARMCI_Barrier();
+  a3.Threshold(a1, .30f);
+  ARMCI_Barrier();
+  LogFetchedSlice(rank, a3, slice, "op fetch b=threshold(a,0.3)");
+
+  //////////////////// Max(a, 0.3)
+  ARMCI_Barrier();
+  a1.Random();
+  ARMCI_Barrier();
+  if(rank==0){
+    lapis::DAry a4, a5({3,6,2});
+    a4=a1.Fetch(slice);
+    a5.Max(a4,0.3);
+    LOG(ERROR)<<"fetch op b=max(a,0.3)";
+    LOG(ERROR)<<a4.ToString();
+    LOG(ERROR)<<a5.ToString();
+  }
+  ARMCI_Barrier();
+  a3.Max(a1, .30f);
+  ARMCI_Barrier();
+  LogFetchedSlice(rank, a3, slice, "op fetch b=max(a,0.3)");
+
+  //////////////////// Map(a + 2b)
+  ARMCI_Barrier();
+  if(rank==0){
+    lapis::DAry a6, a4, a5({3,6,2});
+    a6=a1.Fetch(slice);
+    a4=a2.Fetch(slice);
+    a5.Map([](float a, float b) {return a+2*b;}, a6,a4);
+    LOG(ERROR)<<"fetch op b=map(a+2b)";
+    LOG(ERROR)<<a6.ToString();
+    LOG(ERROR)<<a4.ToString();
+    LOG(ERROR)<<a5.ToString();
+  }
+  ARMCI_Barrier();
+  a3.Map([](float a, float b) {return a+2*b;}, a1,a2);
+  // Wait until every rank has finished its part of the distributed Map;
+  // otherwise rank 0 could fetch a partially updated a3.
+  ARMCI_Barrier();
+  LogFetchedSlice(rank, a3, slice, "op fetch b=map(a+2b)");
+  LOG(ERROR)<<"finish elementwise ops";
+}
+
+
+// Times a distributed (256x9216) * (9216x4096) product whose operands are
+// partitioned along dims pa, pb, pc. Each rank logs three numbers:
+//   setup - shape/partition/random fill plus the first barrier,
+//   dot   - the Dot call on this rank,
+//   wait  - time from the end of Dot until after the trailing barrier.
+void TestLargeDot(int pa, int pb, int pc, int rank){
+  if(rank==0)
+    LOG(ERROR)<<"test Dot, partition for a, b, c : "
+      << pa<<" "<<pb<<" "<<pc<<" dim";
+
+  const double setup_begin=MPI_Wtime();
+  lapis::DAry lhs, rhs, prod;
+  lhs.SetShape({256,9216});
+  rhs.SetShape({9216,4096});
+  prod.SetShape({256,4096});
+  lhs.Setup(pa);
+  rhs.Setup(pb);
+  prod.Setup(pc);
+  lhs.Random();
+  rhs.Random();
+  prod.Random();
+  ARMCI_Barrier();
+
+  const double dot_begin=MPI_Wtime();
+  prod.Dot(lhs,rhs);
+  const double dot_end=MPI_Wtime();
+  ARMCI_Barrier();
+
+  const double setup_secs=dot_begin-setup_begin;
+  const double dot_secs=dot_end-dot_begin;
+  LOG(ERROR)<<"setup time: "<<setup_secs<<" dot time: "
+    <<dot_secs<<" wait time:"<<MPI_Wtime()-dot_end;
+}
+
+// Checks DAry::Dot with every transpose combination on a (4x8), b (8x4) and
+// c (4x4), partitioned along dims pa, pb, pc. For each case rank 0 first
+// computes the product on fetched local copies ("fetch dot"). Then all
+// ranks run the distributed Dot and rank 0 logs the fetched result
+// ("dot fetch"); the two printouts should match.
+// ARMCI_Barrier() is collective: every rank must reach each call in order.
+void TestDot(int pa, int pb, int pc, int rank){
+  // Whole-array slices for a (4x8), b (8x4) and c (4x4).
+  vector<lapis::Range> slicea{make_pair(0,4), make_pair(0,8)};
+  vector<lapis::Range> sliceb{make_pair(0,8), make_pair(0,4)};
+  vector<lapis::Range> slicec{make_pair(0,4), make_pair(0,4)};
+  lapis::DAry a,b,c;
+  a.SetShape({4,8});
+  b.SetShape({8,4});
+  c.SetShape({4,4});
+  a.Setup(pa);
+  b.Setup(pb);
+  c.Setup(pc);
+  a.Random();
+  b.Random();
+  c.Random();
+  ////////////////////// c = a * b
+  ARMCI_Barrier();
+  if(rank==0){
+    LOG(ERROR)<<"test Dot, partition for a, b, c : "
+      << pa<<" "<<pb<<" "<<pc<<" dim";
+    LOG(ERROR)<<"c=a*b";
+    lapis::DAry x,y,z;
+    x=a.Fetch(slicea);
+    y=b.Fetch(sliceb);
+    z=c.Fetch(slicec);
+    z.Dot(x,y);
+    LOG(ERROR)<<"fetch dot ";
+    LOG(ERROR)<<z.ToString();
+  }
+  ARMCI_Barrier();
+  //Debug();
+  c.Dot(a,b);
+  ARMCI_Barrier();
+  if(rank==0){
+    lapis::DAry z;
+    z=c.Fetch(slicec);
+    LOG(ERROR)<<"dot fetch";
+    LOG(ERROR)<<z.ToString();
+  }
+  ///////////////////////////// a = c * b^T  (4x4 * 4x8)
+  ARMCI_Barrier();
+
+  if(rank==0){
+    LOG(ERROR)<<"a=c*b^T";
+    lapis::DAry x,y,z;
+    x=a.Fetch(slicea);
+    y=b.Fetch(sliceb);
+    z=c.Fetch(slicec);
+    x.Dot(z,y, false, true);
+    LOG(ERROR)<<"fetch dot ";
+    LOG(ERROR)<<x.ToString();
+  }
+  ARMCI_Barrier();
+  //Debug();
+  a.Dot(c,b, false, true);
+  ARMCI_Barrier();
+  if(rank==0){
+    lapis::DAry z;
+    z=a.Fetch(slicea);
+    LOG(ERROR)<<"dot fetch";
+    LOG(ERROR)<<z.ToString();
+  }
+
+  ///////////////////////////// b = a^T * c  (8x4 * 4x4)
+  ARMCI_Barrier();
+  if(rank==0){
+    LOG(ERROR)<<"b=a^T*c";
+    lapis::DAry x,y,z;
+    x=a.Fetch(slicea);
+    y=b.Fetch(sliceb);
+    z=c.Fetch(slicec);
+    y.Dot(x,z, true, false);
+    LOG(ERROR)<<"fetch dot ";
+    LOG(ERROR)<<y.ToString();
+  }
+  ARMCI_Barrier();
+  //Debug();
+  b.Dot(a,c, true, false);
+  ARMCI_Barrier();
+  if(rank==0){
+    lapis::DAry z;
+    z=b.Fetch(sliceb);
+    LOG(ERROR)<<"dot fetch";
+    LOG(ERROR)<<z.ToString();
+  }
+  // Two consecutive barriers: redundant, but harmless since every rank calls both.
+  ARMCI_Barrier();
+  ///////////////////////////// b = a^T * c^T  (8x4 * 4x4)
+  ARMCI_Barrier();
+  if(rank==0){
+    LOG(ERROR)<<"b=a^T*c^T";
+    lapis::DAry x,y,z;
+    x=a.Fetch(slicea);
+    y=b.Fetch(sliceb);
+    z=c.Fetch(slicec);
+    y.Dot(x,z, true, true);
+    LOG(ERROR)<<"fetch dot ";
+    LOG(ERROR)<<y.ToString();
+  }
+  ARMCI_Barrier();
+  //Debug();
+  b.Dot(a,c, true, true);
+  ARMCI_Barrier();
+  if(rank==0){
+    lapis::DAry z;
+    z=b.Fetch(sliceb);
+    LOG(ERROR)<<"dot fetch";
+    LOG(ERROR)<<z.ToString();
+  }
+  ARMCI_Barrier();
+}
+
+
+// Checks sub-array views taken with operator[]: sb = b[2] (a row of the 8x4
+// b) and sc = c[3] (a row of the 4x4 c). Rank 0 logs the full arrays and the
+// fetched views. Then all ranks compute a = sb + sc into the length-4 a and
+// rank 0 logs the fetched result.
+// ARMCI_Barrier() is collective: every rank must reach each call in order.
+void TestSubarray(int pa, int pb, int pc, int rank){
+  // slicea is not used in this test.
+  vector<lapis::Range> slicea{make_pair(0,4), make_pair(0,8)};
+  vector<lapis::Range> sliceb{make_pair(0,8), make_pair(0,4)};
+  vector<lapis::Range> slicec{make_pair(0,4), make_pair(0,4)};
+  // Whole-range slice for the length-4 arrays (a, sb, sc).
+  vector<lapis::Range> slice{make_pair(0,4)};
+  lapis::DAry a,b,c;
+  a.SetShape({4});
+  b.SetShape({8,4});
+  c.SetShape({4,4});
+  a.Setup(pa);
+  b.Setup(pb);
+  c.Setup(pc);
+  b.Random();
+  c.Random();
+
+  //Debug();
+  lapis::DAry sb=b[2];
+  lapis::DAry sc=c[3];
+
+  ARMCI_Barrier();
+  if(rank==0){
+    LOG(ERROR)<<"test subary, partition for a, b, c : "
+      << pa<<" "<<pb<<" "<<pc<<" dim";
+    lapis::DAry y,z, x({4});
+    LOG(ERROR)<<"fetch full b, c";
+    y=b.Fetch(sliceb);
+    z=c.Fetch(slicec);
+    LOG(ERROR)<<y.ToString();
+    LOG(ERROR)<<z.ToString();
+    LOG(ERROR)<<"fetch sub, sb[2], sc[3]";
+    y=sb.Fetch(slice);
+    z=sc.Fetch(slice);
+    LOG(ERROR)<<y.ToString();
+    LOG(ERROR)<<z.ToString();
+  }
+  ARMCI_Barrier();
+  // Distributed a = sb + sc; the barrier makes a complete before rank 0 reads it.
+  a.Add(sb,sc);
+  ARMCI_Barrier();
+  //Debug();
+  if(rank==0){
+    lapis::DAry z;
+    z=a.Fetch(slice);
+    LOG(ERROR)<<"sub add fetch, sb[2]+sc[3]";
+    LOG(ERROR)<<z.ToString();
+  }
+}
+
+// Checks Reshape() combined with operator[] views of a distributed DAry.
+// b (8x4) is viewed as b3 (2x4x4), b2 = b3[1] (4x4) and b1 = b2[2] (4).
+// Rank 0 logs fetched values while all ranks run Add() on those views. The
+// final scope views b as 4x2x4 and compares c = a * b3[3]^T computed on
+// fetched copies with the distributed result.
+// ARMCI_Barrier() is collective: every rank must reach each call in order.
+void TestReshape(int pa, int pb, int pc, int rank){
+  vector<lapis::Range> sliceb{make_pair(0,8), make_pair(0,4)};
+  vector<lapis::Range> slicec{make_pair(0,4), make_pair(0,4)};
+  vector<lapis::Range> slicea{make_pair(0,4)};
+  lapis::DAry a,b,c,b3,b2,b1;
+  a.SetShape({4});
+  b.SetShape({8,4});
+  c.SetShape({4,4});
+  a.Setup(pa);
+  b.Setup(pb);
+  c.Setup(pc);
+  b.Random();
+  c.Random();
+
+  b3=b.Reshape({2,4,4});
+  b2=b3[1];
+  // Every rank must finish filling its partitions of b and c before rank 0
+  // fetches them, as in TestPar/TestDot/TestSubarray.
+  ARMCI_Barrier();
+  if(rank==0){
+    LOG(ERROR)<<"test reshape+subary, partition for a, b, c : "
+      << pa<<" "<<pb<<" "<<pc<<" dim";
+    lapis::DAry y,z,x;
+    LOG(ERROR)<<"fetch b, b2, c";
+    y=b.Fetch(sliceb);
+    z=b2.Fetch(slicec);
+    x=c.Fetch(slicec);
+    LOG(ERROR)<<y.ToString();
+    LOG(ERROR)<<z.ToString();
+    LOG(ERROR)<<x.ToString();
+    LOG(ERROR)<<"fetch sub, b2+c";
+    z.Add(x);
+    LOG(ERROR)<<z.ToString();
+  }
+
+  ARMCI_Barrier();
+  c.Add(b2);
+  ARMCI_Barrier();
+  if(rank==0){
+    lapis::DAry x;
+    x=c.Fetch(slicec);
+    LOG(ERROR)<<"sub add,fetch c+b2";
+    LOG(ERROR)<<x.ToString();
+  }
+  ARMCI_Barrier();
+  b2.Add(c);
+  ARMCI_Barrier();
+  if(rank==0){
+    lapis::DAry x;
+    x=b2.Fetch(slicec);
+    LOG(ERROR)<<"sub add,fetch b2+c";
+    LOG(ERROR)<<x.ToString();
+  }
+  ARMCI_Barrier();
+  b1=b2[2];
+  if(rank==0){
+    lapis::DAry x;
+    x=b1.Fetch(slicea);
+    LOG(ERROR)<<"fetch b1";
+    LOG(ERROR)<<x.ToString();
+  }
+
+  a.Add(b1);
+  ARMCI_Barrier();
+  if(rank==0){
+    lapis::DAry x;
+    x=a.Fetch(slicea);
+    LOG(ERROR)<<"add fetch a+b1";
+    LOG(ERROR)<<x.ToString();
+  }
+  ARMCI_Barrier();
+  b1.Add(a);
+  ARMCI_Barrier();
+  if(rank==0){
+    lapis::DAry x;
+    x=b1.Fetch(slicea);
+    LOG(ERROR)<<"add fetch b1+a";
+    LOG(ERROR)<<x.ToString();
+  }
+
+  ARMCI_Barrier();
+  {
+    // Second view of b: 4x2x4; c = a (2x4) * b3[3]^T (4x2) -> 2x2.
+    lapis::DAry b3=b.Reshape({4,2,4});
+    lapis::DAry a;
+    a.SetShape({2,4});
+    a.Setup(pa);
+    a.Random();
+    lapis::DAry b1=b3[1];
+    lapis::DAry b2=b3[3];
+    lapis::DAry c;
+    c.SetShape({2,2});
+    c.Setup(pc);
+    ARMCI_Barrier();
+    c.Dot(a,b2,false, true);
+    ARMCI_Barrier();
+    if(rank==0){
+      lapis::DAry x,y,z,zz({2,2});
+      y=b3.Fetch({make_pair(0,4), make_pair(0,2), make_pair(0,4)});
+      x=a.Fetch({make_pair(0,2), make_pair(0,4)});
+      LOG(ERROR)<<"fetch b,a";
+      LOG(ERROR)<<y.ToString();
+      LOG(ERROR)<<x.ToString();
+      z=y[3];
+      zz.Dot(x,z,false, true);
+      LOG(ERROR)<<"fetch dot c=a*b[3]^T";
+      LOG(ERROR)<<zz.ToString();
+
+      x=a.Fetch({make_pair(0,2), make_pair(0,4)});
+      y=b2.Fetch({make_pair(0,2), make_pair(0,4)});
+      z=c.Fetch({make_pair(0,2), make_pair(0,2)});
+      LOG(ERROR)<<"op fetch c=a*b[3]^T";
+      LOG(ERROR)<<x.ToString();
+      LOG(ERROR)<<y.ToString();
+      LOG(ERROR)<<z.ToString();
+
+    }
+    ARMCI_Barrier();
+  }
+}
+
+
+
+// Distributed-array test driver. Starts MPI and the GAry runtime over all
+// ranks, runs TestLargeDot for every partition combination and logs the
+// total wall time on rank 0. The other tests are kept commented out.
+int main(int argc, char**argv){
+  // Initialise glog first so anything logged during MPI/GAry start-up goes
+  // through a configured logger.
+  google::InitGoogleLogging(argv[0]);
+ // MPI_Init_thread(&argc, &argv, MPI_THREAD_MULTIPLE, &provided);
+  MPI_Init(&argc, &argv);
+  int rank, nprocs;
+  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
+  MPI_Comm_size(MPI_COMM_WORLD, &nprocs);
+  // Every MPI rank takes part in the global arrays.
+  vector<int> procs;
+  procs.reserve(nprocs);
+  for (int i = 0; i < nprocs; i++) {
+    procs.push_back(i);
+  }
+  //Debug();
+  lapis::GAry::Init(rank,procs);
+  /*
+  if(nprocs%3==0){
+    TestMixedParElt(0,0,0,rank);
+    TestMixedParElt(0,0,1,rank);
+    TestMixedParElt(0,1,0,rank);
+    TestMixedParElt(1,0,0,rank);
+    TestMixedParElt(1,1,0,rank);
+    TestMixedParElt(1,1,1,rank);
+    TestMixedParElt(0,1,1,rank);
+  }
+  if(nprocs%2==0){
+    TestMixedParElt(1,1,1,rank);
+    TestMixedParElt(1,2,1,rank);
+    TestMixedParElt(2,1,1,rank);
+    TestMixedParElt(1,1,2,rank);
+    TestMixedParElt(2,2,2,rank);
+  }
+  TestDot(0,0,0,rank);
+  TestDot(0,0,1,rank);
+  TestDot(0,1,0,rank);
+  TestDot(0,1,1,rank);
+  TestDot(1,0,0,rank);
+  TestDot(1,0,1,rank);
+  TestDot(1,1,0,rank);
+  TestDot(1,1,1,rank);
+
+  TestPar(0, rank);
+  TestPar(1, rank);
+  */
+  double start, end;
+  start=MPI_Wtime();
+  // All partition combinations (pa, pb, pc) in {0,1}^3, in the order
+  // (0,0,0), (0,0,1), (0,1,0), ..., (1,1,1).
+  for(int pa=0;pa<2;pa++)
+    for(int pb=0;pb<2;pb++)
+      for(int pc=0;pc<2;pc++)
+        TestLargeDot(pa,pb,pc,rank);
+  end=MPI_Wtime();
+  // TestLargeDot multiplies a 256x9216 by a 9216x4096 matrix.
+  if(rank==0)
+    LOG(ERROR)<<"dot time for 8 runs of 256*9216 x 9216*4096 matrix, "<<end-start;
+  /*
+  TestSubarray(0,0,0,rank);
+  TestSubarray(0,0,1,rank);
+  TestSubarray(0,1,0,rank);
+  TestSubarray(0,1,1,rank);
+  TestReshape(0,0,0,rank);
+  TestReshape(0,0,1,rank);
+  TestReshape(0,1,0,rank);
+  TestReshape(0,1,1,rank);
+  */
+
+  LOG(ERROR)<<"finish";
+  lapis::GAry::Finalize();
+  MPI_Finalize();
+  return 0;
+}
+

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b2dc51d2/src/test/dist_test/test_dary.cc
----------------------------------------------------------------------
diff --git a/src/test/dist_test/test_dary.cc b/src/test/dist_test/test_dary.cc
new file mode 100644
index 0000000..ce605e6
--- /dev/null
+++ b/src/test/dist_test/test_dary.cc
@@ -0,0 +1,85 @@
+#include <iostream>
+#include "darray/dary.h"
+#include "utils/timer.h"
+
+
+// Single-process DAry check. First times an element-wise multiply done
+// through raw dptr() pointers against DAry::Mult on arrays of the same
+// length. Then it prints the result of each basic DAry operation for
+// manual inspection.
+int main() {
+  // The timing loops run kIters times over kLen elements; the printed figure
+  // is the average elapsed time of one iteration.
+  const int kIters=100;
+  const int kLen=1000000;
+  lapis::DAry x({kLen});
+  lapis::DAry y({kLen});
+  x.Random();
+  y.Random();
+  lapis::Timer t;
+  for(int i=0;i<kIters;i++){
+    float *dptrx=x.dptr();
+    float *dptry=y.dptr();
+    // Cover the whole array so the work matches m.Mult(m,n) below.
+    for(int k=0;k<kLen;k++)
+      dptrx[k]*=dptry[k];
+  }
+  std::cout<<"dptr loop: "<<t.elapsed()/kIters<<std::endl;
+  lapis::DAry m({kLen});
+  lapis::DAry n({kLen});
+  m.Random();
+  n.Random();
+  t.Reset();
+  for(int i=0;i<kIters;i++)
+    m.Mult(m,n);
+  std::cout<<"arymath: "<<t.elapsed()/kIters<<std::endl;
+
+
+  lapis::DAry a({2,2});
+  lapis::DAry b,c;
+  b.InitLike(a);
+  c.InitLike(a);
+  a.Random();
+  b.Random();
+  std::cout<<a.ToString()<<std::endl;
+  std::cout<<b.ToString()<<std::endl;
+  c.Dot(a,b);
+  std::cout<<"c=a.b"<<c.ToString()<<std::endl;
+  a.Add(b);
+  std::cout<<"a=a+b"<<a.ToString()<<std::endl;
+  a.Mult(a,b);
+  std::cout<<"a=a*b"<<a.ToString()<<std::endl;
+  a.Minus(a,b);
+  std::cout<<"a=a-b"<<a.ToString()<<std::endl;
+
+  c.Random();
+  std::cout<<"random c "<<c.ToString()<<std::endl;
+  a.Threshold(c, 0.3);
+  std::cout<<"a=threshold(c,0.3) "<<a.ToString()<<std::endl;
+
+  a.Pow(c, 0.4);
+  std::cout<<"a=Pow(c,0.4) "<<a.ToString()<<std::endl;
+
+  c.Set(0.5);
+  std::cout<<"c=set(0.5) "<<c.ToString()<<std::endl;
+  a.Square(c);
+  std::cout<<"a=square(c) "<<a.ToString()<<std::endl;
+
+  c.Copy(a);
+  std::cout<<"c=Copy(a) "<<c.ToString()<<std::endl;
+
+  lapis::DAry d({2});
+  d.SumRow(b);
+  std::cout<<"d=SumRow(b) "<<d.ToString()<<std::endl;
+  d.SumCol(b);
+  std::cout<<"d=SumCol(b) "<<d.ToString()<<std::endl;
+  b.AddRow(d);
+  std::cout<<"b=AddRow(d) "<<b.ToString()<<std::endl;
+  b.AddCol(d);
+  std::cout<<"b=AddCol(d) "<<b.ToString()<<std::endl;
+
+  std::cout<<"max(b) "<<b.Max()<<std::endl;
+  std::cout<<"Sum(b) "<<b.Sum()<<std::endl;
+
+  lapis::DAry e({3,3,3});
+  e.SampleGaussian(0.0f,1.0f);
+  std::cout<<"Gaussain e "<<e.ToString()<<std::endl;
+
+  // NOTE(review): the exact semantics of Sum(e, 0, {0,2}) and the expected
+  // size of f are not visible here; confirm against DAry::Sum.
+  lapis::DAry f({9});
+  f.Sum(e, 0, {0,2});
+  std::cout<<"f.sum  "<<f.ToString()<<std::endl;
+
+  return 0;
+}
+

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b2dc51d2/src/test/dist_test/test_disk_table.cc
----------------------------------------------------------------------
diff --git a/src/test/dist_test/test_disk_table.cc b/src/test/dist_test/test_disk_table.cc
new file mode 100644
index 0000000..99987bb
--- /dev/null
+++ b/src/test/dist_test/test_disk_table.cc
@@ -0,0 +1,188 @@
+//  Copyright © 2014 Anh Dinh. All Rights Reserved.
+//  test driver for the distributed disk-table layer (TypedDiskTable)
+//
+//  the command to run this should be:
+//		mpirun -hostfile <host> -bycore -nooversubscribe
+//				-n <num_servers> test -sync_update
+
+
+#include "core/global-table.h"
+#include "core/common.h"
+#include "core/disk-table.h"
+#include "core/table.h"
+#include "core/table_server.h"
+#include "utils/global_context.h"
+#include <gflags/gflags.h>
+#include "proto/model.pb.h"
+#include "worker.h"
+#include <cmath>
+
+// Number of floats written into each FloatVector record.
+DEFINE_int32(record_size,100, "# elements per float vector");
+// Defined elsewhere; passed to the DiskTableDescriptor in create_disk_table().
+DECLARE_int32(block_size);
+// Number of records the coordinator puts (and each worker expects to read).
+DEFINE_int32(table_size, 1000, "# records per table");
+DEFINE_string(system_conf, "examples/imagenet12/system.conf", "configuration file for node roles");
+DEFINE_string(model_conf, "examples/imagenet12/model.conf", "DL model configuration file");
+// true: test put() from the coordinator; false: test reads on the workers.
+DEFINE_bool(is_testing_put,true, "data put vs. data get");
+DECLARE_int32(debug_index);
+// Defined elsewhere; used to compute the expected number of sub blocks.
+DECLARE_int32(table_buffer);
+using namespace lapis;
+
+// Table id -> table; filled by create_disk_table() and served by the workers.
+typedef map<int, GlobalTable*> Map;
+Map tables;
+
+//  put random message to the pointers
+//  Append FLAGS_record_size values to `message`. Record number `count` gets
+//  the consecutive integers count*record_size, ..., count*record_size +
+//  record_size - 1. Despite the name, the content is deterministic.
+void create_random_message(FloatVector* message, const int count){
+	const int base = count*FLAGS_record_size;
+	for (int offset = 0; offset < FLAGS_record_size; ++offset)
+		message->add_data(base + offset);
+}
+
+//  Create a TypedDiskTable<int, FloatVector> named "disk_test" with
+//  FLAGS_block_size and register it in `tables` under `id`. The descriptor,
+//  marshals and table are heap-allocated and not freed in this file.
+void create_disk_table(int id){
+	auto *descriptor = new DiskTableDescriptor(id, "disk_test", FLAGS_block_size);
+	descriptor->key_marshal = new Marshal<int>();
+	descriptor->value_marshal = new Marshal<FloatVector>();
+	GlobalTable *table = new TypedDiskTable<int,FloatVector>(descriptor);
+	tables[id] = table;
+}
+
+
+//  if testing put, write and send data. Else do nothing
+//  Coordinator role (the highest rank, see test_disk).
+//  1. Wait for a registration message from each of the size()-1 workers.
+//  2. If testing put(), write FLAGS_table_size records into table `tid`.
+//  3. Collect MTYPE_WORKER_END from every worker, send MTYPE_WORKER_SHUTDOWN,
+//     and shut the network down.
+//  4. If testing put(), check the sub-block and record counters.
+void run_coordinator(shared_ptr<NetworkThread> network, int tid){
+	// wait for workers to be up
+	RegisterWorkerRequest req;
+	for (int i=0; i<network->size()-1; i++)
+		network->Read(MPI::ANY_SOURCE, MTYPE_REGISTER_WORKER, &req);
+
+	TypedDiskTable<int, FloatVector>* table = static_cast<TypedDiskTable<int,
+			FloatVector>*>(tables[tid]);
+
+	//  if testing put(), write every record then flush the last partial batch
+	if (FLAGS_is_testing_put) {
+		for (int i = 0; i < FLAGS_table_size; i++) {
+			FloatVector message;
+			create_random_message(&message, i);
+			table->put(i, message);
+		}
+		table->finish_put();
+	}
+
+	VLOG(3) << "Coordinator about to shut down";
+	for (int i=0; i<network->size()-1; i++){
+		EmptyMessage end_msg;
+		network->Read(i,MTYPE_WORKER_END, &end_msg);
+	}
+
+	EmptyMessage shutdown_msg;
+	for (int i = 0; i < network->size() - 1; i++) {
+		network->Send(i, MTYPE_WORKER_SHUTDOWN, shutdown_msg);
+	}
+	network->Flush();
+	network->Shutdown();
+	table->PrintStats();
+
+	if (FLAGS_is_testing_put) {
+		// Records are sent in sub blocks of FLAGS_table_buffer records each.
+		int sub_blocks = ceil(((double) FLAGS_table_size / FLAGS_table_buffer));
+		CHECK_EQ(table->stats()["total sub block sent"], sub_blocks);
+		CHECK_EQ(table->stats()["total record sent"], FLAGS_table_size);
+		VLOG(3) << "test coordinator sending: successful";
+	}
+
+}
+
+//  if testing put(), do nothing. Else read() until done()
+//  Worker role. Starts a TableServer over `tables`. When testing get(), it
+//  loads table `tid` and reads every record. It then runs the END/SHUTDOWN
+//  handshake with the coordinator. Finally it checks the stats of table
+//  `tid`; the exact-count checks apply only with a single worker
+//  (size() == 2).
+void run_worker(shared_ptr<NetworkThread> network, int tid){
+	// NOTE(review): ts is never shut down or freed in this function.
+	TableServer* ts = new TableServer();
+	ts->StartTableServer(tables);
+
+	TypedDiskTable<int, FloatVector>* table = static_cast<TypedDiskTable<int,
+			FloatVector>*>(tables[tid]);
+	double total_read = 0;
+	if (!FLAGS_is_testing_put){
+		VLOG(3) << "testing read from table ...";
+		table->Load();
+		while (!table->done()){
+			int k;
+			FloatVector v;
+			table->get(&k,&v);
+			table->Next();
+			total_read++;
+		}
+
+		// The loop stops without reading the record under the cursor once
+		// done() is true; read it as well (presumably done() means "at the
+		// last record" - confirm against TypedDiskTable).
+		int k;
+		FloatVector v;
+		table->get(&k, &v);
+		total_read++;
+		VLOG(3) << "records read = " << total_read;
+	}
+
+	int size = network->size();
+
+	network->Flush();
+	network->Send(GlobalContext::kCoordinatorRank, MTYPE_WORKER_END,
+			EmptyMessage());
+	EmptyMessage msg;
+
+	int src = 0;
+	network->Read(GlobalContext::kCoordinatorRank, MTYPE_WORKER_SHUTDOWN, &msg,
+			&src);
+	network->Flush();
+	network->Shutdown();
+
+	// Inspect the table this worker actually served (`tid`), not always table 0.
+	Stats stats = table->stats();
+
+	if (FLAGS_is_testing_put) {
+		int sub_blocks = ceil(((double) FLAGS_table_size / FLAGS_table_buffer));
+		if (size == 2) {
+			CHECK_EQ(stats["total sub block received"], sub_blocks);
+			CHECK_EQ(stats["total record stored"], FLAGS_table_size);
+		}
+		VLOG(3) << "test table-server writing: successful";
+		VLOG(3) << "number of sub blocks = " << sub_blocks;
+		VLOG(3) << "total data stored = " << stats["total byte stored"];
+	}
+	else{
+		if (size==2)
+			CHECK_EQ(stats["total record read"], FLAGS_table_size);
+		VLOG(3) << "test table-server reading: successful";
+		VLOG(3) << "read bandwidth = "
+				<< (stats["total byte read"]
+						/ (stats["last byte read"] - stats["first byte read"]));
+	}
+
+	network->PrintStats();
+	table->PrintStats();
+}
+
+//  check all the records have been stored to disk
+//  Run the disk-table test on table `tid`. Initialises the GlobalContext from
+//  the flag-given config files and starts the network thread. The highest
+//  rank then acts as coordinator and every other rank as a worker.
+int test_disk(int tid) {
+	// Called for its initialisation side effect; the handle is not used.
+	auto gc = lapis::GlobalContext::Get(FLAGS_system_conf, FLAGS_model_conf);
+	shared_ptr<NetworkThread> network = NetworkThread::Get();
+
+	const bool is_coordinator = (network->id() == network->size() - 1);
+	if (is_coordinator) {
+		run_coordinator(network, tid);
+	} else {
+		run_worker(network, tid);
+	}
+	return 0;
+}
+
+// for debugging use
+//#ifndef FLAGS_v
+//  DEFINE_int32(v, 3, "vlog controller");
+//#endif
+
+// Disk-table test entry point: configure logging and flags, create table 0
+// and run the test on it.
+int main(int argc, char **argv) {
+	FLAGS_logtostderr = 1;
+	google::InitGoogleLogging(argv[0]);
+	gflags::ParseCommandLineFlags(&argc, &argv, true);
+	const int kTableId = 0;
+	create_disk_table(kTableId);
+	return test_disk(kTableId);
+}
+
+

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b2dc51d2/src/test/dist_test/test_mnistlayer.cc
----------------------------------------------------------------------
diff --git a/src/test/dist_test/test_mnistlayer.cc b/src/test/dist_test/test_mnistlayer.cc
new file mode 100644
index 0000000..882e121
--- /dev/null
+++ b/src/test/dist_test/test_mnistlayer.cc
@@ -0,0 +1,165 @@
+#include <gtest/gtest.h>
+#include <sys/stat.h>
+#include <cstdint>
+#include "opencv2/highgui/highgui.hpp"
+#include "opencv2/imgproc/imgproc.hpp"
+
+#include "model/layer.h"
+#include "proto/model.pb.h"
+#include "utils/shard.h"
+using namespace singa;
+// Feeds src/test/data/mnist.png (read as grayscale) through an
+// MnistImageLayer configured only with size=55. The converted image is
+// written to src/test/data/mnist_scale.png for visual inspection.
+TEST(MnistLayerTest, SingleScale){
+  LayerProto proto;
+  MnistProto *mnist=proto.mutable_mnist_param();
+  mnist->set_size(55);
+  MnistImageLayer layer;
+  layer.FromProto(proto);
+  cv::Mat image;
+  image=cv::imread("src/test/data/mnist.png", 0);
+  // imread returns an empty Mat instead of failing when the file is missing
+  // or unreadable; stop here rather than feeding an empty record to the layer.
+  ASSERT_FALSE(image.empty()) << "cannot read src/test/data/mnist.png";
+  string pixel;
+  pixel.resize(image.rows*image.cols);
+  for(int i=0,k=0;i<image.rows;i++)
+    for(int j=0; j<image.cols;j++)
+      pixel[k++]=static_cast<char>(image.at<uint8_t>(i,j));
+  Record rec;
+  rec.set_type(Record_Type_kMnist);
+  MnistRecord *mrec=rec.mutable_mnist();
+  mrec->set_pixel(pixel);
+  layer.Setup(1, rec, kNone);
+  layer.AddInputRecord(rec);
+
+  // Converted output is assumed square: side = floor(sqrt(#pixels)).
+  const vector<uint8_t>& dat=layer.Convert2Image(0);
+  int s=static_cast<int>(sqrt(dat.size()));
+  cv::Mat newimg(s,s,CV_8UC1);
+  int count=0;
+  for(int i=0,k=0;i<newimg.rows;i++)
+    for(int j=0; j<newimg.cols;j++){
+      count+=dat[k]>0;
+      newimg.at<uint8_t>(i,j)=dat[k++];
+    }
+  //LOG(ERROR)<<"image positive "<<count<<" size "<<s;
+  cv::imwrite("src/test/data/mnist_scale.png", newimg);
+}
+
+// Feeds src/test/data/mnist.png (read as grayscale) through an
+// MnistImageLayer with beta=15, gamma=16 and size=55. The converted image is
+// written to src/test/data/mnist_affine.png for visual inspection.
+TEST(MnistLayerTest, SingleAffineTransform){
+  LayerProto proto;
+  MnistProto *mnist=proto.mutable_mnist_param();
+  mnist->set_beta(15);
+  mnist->set_gamma(16);
+  mnist->set_size(55);
+  MnistImageLayer layer;
+  layer.FromProto(proto);
+  cv::Mat image;
+  image=cv::imread("src/test/data/mnist.png", 0);
+  // imread returns an empty Mat instead of failing when the file is missing
+  // or unreadable; stop here rather than feeding an empty record to the layer.
+  ASSERT_FALSE(image.empty()) << "cannot read src/test/data/mnist.png";
+  string pixel;
+  pixel.resize(image.rows*image.cols);
+  for(int i=0,k=0;i<image.rows;i++)
+    for(int j=0; j<image.cols;j++)
+      pixel[k++]=static_cast<char>(image.at<uint8_t>(i,j));
+  Record rec;
+  rec.set_type(Record_Type_kMnist);
+  MnistRecord *mrec=rec.mutable_mnist();
+  mrec->set_pixel(pixel);
+  layer.Setup(1, rec, kNone);
+  layer.AddInputRecord(rec);
+
+  // Converted output is assumed square: side = floor(sqrt(#pixels)).
+  const vector<uint8_t>& dat=layer.Convert2Image(0);
+  int s=static_cast<int>(sqrt(dat.size()));
+  cv::Mat newimg(s,s,CV_8UC1);
+  int count=0;
+  for(int i=0,k=0;i<newimg.rows;i++)
+    for(int j=0; j<newimg.cols;j++){
+      count+=dat[k]>0;
+      newimg.at<uint8_t>(i,j)=dat[k++];
+    }
+  //LOG(ERROR)<<"image positive "<<count<<" size "<<s;
+
+  cv::imwrite("src/test/data/mnist_affine.png", newimg);
+}
+// Feeds src/test/data/mnist.png (read as grayscale) through an
+// MnistImageLayer with elastic distortion on every image (elastic_freq=1,
+// sigma=6, alpha=36, kernel=21) plus beta=15, gamma=16, size=55. The
+// converted image is written to src/test/data/mnist_elastic.png.
+TEST(MnistLayerTest, SingleElasticDistortion){
+  LayerProto proto;
+  MnistProto *mnist=proto.mutable_mnist_param();
+  mnist->set_elastic_freq(1);
+  mnist->set_sigma(6);
+  mnist->set_alpha(36);
+  mnist->set_beta(15);
+  mnist->set_gamma(16);
+  mnist->set_size(55);
+  mnist->set_kernel(21);
+  MnistImageLayer layer;
+  layer.FromProto(proto);
+  cv::Mat image;
+  image=cv::imread("src/test/data/mnist.png", 0);
+  // imread returns an empty Mat instead of failing when the file is missing
+  // or unreadable; stop here rather than feeding an empty record to the layer.
+  ASSERT_FALSE(image.empty()) << "cannot read src/test/data/mnist.png";
+  string pixel;
+  pixel.resize(image.rows*image.cols);
+  for(int i=0,k=0;i<image.rows;i++)
+    for(int j=0; j<image.cols;j++)
+      pixel[k++]=static_cast<char>(image.at<uint8_t>(i,j));
+  Record rec;
+  rec.set_type(Record_Type_kMnist);
+  MnistRecord *mrec=rec.mutable_mnist();
+  mrec->set_pixel(pixel);
+  layer.Setup(1, rec, kNone);
+  layer.AddInputRecord(rec);
+
+  // Converted output is assumed square: side = floor(sqrt(#pixels)).
+  const vector<uint8_t>& dat=layer.Convert2Image(0);
+  int s=static_cast<int>(sqrt(dat.size()));
+  cv::Mat newimg(s,s,CV_8UC1);
+  for(int i=0,k=0;i<newimg.rows;i++)
+    for(int j=0; j<newimg.cols;j++)
+      newimg.at<uint8_t>(i,j)=dat[k++];
+  cv::imwrite("src/test/data/mnist_elastic.png", newimg);
+}
+// Reads kTotal (=100) records from an MNIST shard and tiles the original
+// 28x28 digits into a sqrt(kTotal) x sqrt(kTotal) grid
+// (src/test/data/mnist_big.png). It logs the label grid, then tiles the
+// layer's distorted kSize x kSize outputs into
+// src/test/data/mnist_bigout.png.
+TEST(MnistLayerTest, MultElasticDistortion){
+  LayerProto proto;
+  MnistProto *mnist=proto.mutable_mnist_param();
+  int kTotal=100;
+  int kSize=29;
+  mnist->set_elastic_freq(kTotal);
+  mnist->set_sigma(6);
+  mnist->set_alpha(36);
+  mnist->set_beta(15);
+  mnist->set_gamma(16);
+  mnist->set_size(kSize);
+  mnist->set_kernel(21);
+  MnistImageLayer layer;
+  layer.FromProto(proto);
+  vector<vector<int>> shapes{{kTotal, kSize,kSize}};
+  layer.Setup(shapes, kNone);
+  // NOTE(review): hard-coded absolute path to the shard.
+  shard::Shard source("/data1/wangwei/singa/data/mnist/test/",shard::Shard::kRead);
+  int n=static_cast<int>(sqrt(kTotal));
+  cv::Mat origin(n*28,n*28, CV_8UC1);
+  // Label grid, one row per line. A std::string always starts out empty and
+  // grows as needed, unlike an uninitialised char buffer passed to strlen().
+  string disp;
+  for(int x=0;x<n;x++){
+    disp+="\n";
+    for(int y=0;y<n;y++){
+      Record rec;
+      string key;
+      CHECK(source.Next(&key, &rec));
+      const string pixel=rec.mnist().pixel();
+      cv::Mat img=origin(cv::Rect(y*28, x*28, 28, 28));
+      for(int i=0,k=0;i<28;i++)
+        for(int j=0;j<28;j++)
+          img.at<uint8_t>(i,j)=static_cast<uint8_t>(pixel[k++]);
+      layer.AddInputRecord(rec);
+      disp+=std::to_string(rec.mnist().label())+" ";
+    }
+  }
+  LOG(ERROR)<<disp;
+  cv::imwrite("src/test/data/mnist_big.png", origin);
+
+  cv::Mat output(n*kSize,n*kSize, CV_8UC1);
+  for(int idx=0;idx<kTotal;idx++){
+    const vector<uint8_t>& dat=layer.Convert2Image(idx);
+    int x=(idx/n);
+    int y=idx%n;
+    cv::Mat img=output(cv::Rect(y*kSize, x*kSize, kSize, kSize));
+    for(int r=0,k=0;r<kSize;r++)
+      for(int c=0;c<kSize;c++)
+        img.at<uint8_t>(r,c)=dat[k++];
+  }
+  cv::imwrite("src/test/data/mnist_bigout.png", output);
+}

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b2dc51d2/src/test/dist_test/test_model.cc
----------------------------------------------------------------------
diff --git a/src/test/dist_test/test_model.cc b/src/test/dist_test/test_model.cc
new file mode 100644
index 0000000..c3f98b9
--- /dev/null
+++ b/src/test/dist_test/test_model.cc
@@ -0,0 +1,25 @@
+// Copyright © 2014 Wei Wang. All Rights Reserved.
+// 2014-08-02 14:13
+#include <glog/logging.h>
+#include <gflags/gflags.h>
+
+
+#include "model/sgd_trainer.h"
+#include "model/net.h"
+#include "proto/model.pb.h"
+#include "utils/proto_helper.h"
+
+// Verbosity flag for this binary (default 1).
+// NOTE(review): glog may already define FLAGS_v; confirm this does not clash.
+DEFINE_int32(v, 1, "vlog");
+
+// Loads examples/imagenet12/model.conf, initialises an SGDTrainer and a Net
+// from its trainer/net sections, and runs the trainer on the net.
+int main(int argc, char** argv) {
+  FLAGS_logtostderr=1;
+  google::InitGoogleLogging(argv[0]);
+  gflags::ParseCommandLineFlags(&argc, &argv, true);
+
+  const char* kModelConf="examples/imagenet12/model.conf";
+  lapis::ModelProto model_proto;
+  lapis::ReadProtoFromTextFile(kModelConf, &model_proto);
+
+  lapis::SGDTrainer trainer;
+  trainer.Init(model_proto.trainer());
+
+  lapis::Net net;
+  net.Init(model_proto.net());
+
+  trainer.Run(&net);
+  return 0;
+}

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b2dc51d2/src/test/dist_test/test_neuralnet.cc
----------------------------------------------------------------------
diff --git a/src/test/dist_test/test_neuralnet.cc b/src/test/dist_test/test_neuralnet.cc
new file mode 100644
index 0000000..a857124
--- /dev/null
+++ b/src/test/dist_test/test_neuralnet.cc
@@ -0,0 +1,141 @@
+#include <gtest/gtest.h>
+#include <model/neuralnet.h>
+#include "proto/model.pb.h"
+#include "utils/common.h"
+#include "utils/param_updater.h"
+
+using namespace singa;
+// Returns the neuralnet section of the MNIST MLP example configuration
+// (examples/mnist/mlp.conf). Not used by the tests in this file.
+NetProto CreateMLPProto(){
+  ModelProto mlp_conf;
+  ReadProtoFromTextFile("examples/mnist/mlp.conf", &mlp_conf);
+  const NetProto& net_conf=mlp_conf.neuralnet();
+  return net_conf;
+}
+// Runs three forward/backward passes over the MNIST MLP defined in
+// examples/mnist/mlp.conf, applying AdaGrad updates to each layer's
+// parameters right after that layer's gradient is computed.
+TEST(NeuralnetTest, BP){
+  ModelProto model;
+  ReadProtoFromTextFile("examples/mnist/mlp.conf", &model);
+
+  AdaGradUpdater updater;
+  updater.Init(model.solver().updater());
+
+  NeuralNet net(model.neuralnet());
+  // Bind by reference instead of copying the whole layer list.
+  const auto& layers=net.layers();
+  ASSERT_GT(layers.size(), 0u);
+  // The first layer is driven as the data layer (its prefetch is completed
+  // after its ComputeFeature() in every step); verify the downcast instead
+  // of relying on an unchecked static_cast.
+  DataLayer* dl=dynamic_cast<DataLayer*>(layers[0].get());
+  ASSERT_TRUE(dl!=nullptr) << "first layer is not a DataLayer";
+  for(int step=0;step<3;step++){
+    // Forward pass in layer order.
+    for(size_t k=0;k<layers.size();k++){
+      layers[k]->ComputeFeature();
+      if(k==0)
+        dl->CompletePrefetch();
+    }
+    // Backward pass in reverse layer order, updating parameters as we go.
+    for(size_t k=layers.size();k-->0;){
+      layers[k]->ComputeGradient();
+      for(Param* param: layers[k]->GetParams())
+        updater.Update(step, param);
+    }
+  }
+}
+// Builds a small convolutional network over MNIST shards, used by the
+// partition tests below. Eight layers, in declaration order:
+//   data, mnist, label, conv1, relu1, pool1, fc1, loss
+// wired as data -> mnist -> conv1 -> relu1 -> pool1 -> fc1 -> loss and
+// data -> label -> loss.
+// @param data_path directory of the training shard read by the "data" layer.
+//   The default is the machine-specific path the test was written against;
+//   pass a local path to run the tests elsewhere.
+NetProto CreateConvNetProto(
+    const char* data_path="/data1/wangwei/singa/data/mnist/train/"){
+  NetProto proto;
+  LayerProto *layer;
+
+  // Shard data layer, batchsize 8.
+  layer=proto.add_layer();
+  layer->set_name("data");
+  layer->set_type("kShardData");
+  DataProto *data=layer->mutable_data_param();
+  data->set_batchsize(8);
+  data->set_path(data_path);
+
+  // MNIST image layer fed by "data".
+  layer=proto.add_layer();
+  layer->set_name("mnist");
+  layer->set_type("kMnistImage");
+  layer->add_srclayers("data");
+
+  // Label layer fed by "data".
+  layer=proto.add_layer();
+  layer->set_name("label");
+  layer->set_type("kLabel");
+  layer->add_srclayers("data");
+
+  // Convolution: 8 filters, kernel size 2; two params (presumably weight
+  // and bias).
+  layer=proto.add_layer();
+  layer->set_name("conv1");
+  layer->set_type("kConvolution");
+  layer->add_srclayers("mnist");
+  layer->add_param();
+  layer->add_param();
+  ConvolutionProto *conv=layer->mutable_convolution_param();
+  conv->set_num_filters(8);
+  conv->set_kernel(2);
+
+  // ReLU on conv1's output.
+  layer=proto.add_layer();
+  layer->set_name("relu1");
+  layer->set_type("kReLU");
+  layer->add_srclayers("conv1");
+
+  // Pooling: kernel 4, stride 2.
+  layer=proto.add_layer();
+  layer->set_name("pool1");
+  layer->set_type("kPooling");
+  layer->add_srclayers("relu1");
+  PoolingProto *pool=layer->mutable_pooling_param();
+  pool->set_kernel(4);
+  pool->set_stride(2);
+
+  // Inner-product layer with 10 outputs; two params.
+  layer=proto.add_layer();
+  layer->set_name("fc1");
+  layer->set_type("kInnerProduct");
+  layer->add_srclayers("pool1");
+  layer->add_param();
+  layer->add_param();
+  InnerProductProto *inner=layer->mutable_inner_product_param();
+  inner->set_num_output(10);
+
+  // Softmax loss over fc1's outputs and the labels.
+  layer=proto.add_layer();
+  layer->set_name("loss");
+  layer->set_type("kSoftmaxLoss");
+  layer->add_srclayers("fc1");
+  layer->add_srclayers("label");
+
+  return proto;
+}
+
+// Without partitioning the net should hold exactly the 8 layers declared by
+// CreateConvNetProto, starting with "data" and ending with "loss".
+TEST(NeuralNetTest, NoPartition){
+  NetProto proto=CreateConvNetProto();
+  NeuralNet net(proto);
+  const auto& layers=net.layers();
+  // Unsigned expected value: an int literal compared with size() trips
+  // -Wsign-compare inside gtest's ASSERT_EQ.
+  ASSERT_EQ(8u, layers.size());
+  ASSERT_EQ("data", layers.at(0)->name());
+  ASSERT_EQ("loss", layers.at(7)->name());
+}
+
+// Data-partitioning the conv net 3 ways is expected to produce 28 layers,
+// with "data" first.
+TEST(NeuralNetTest, DataPartition){
+  NetProto proto=CreateConvNetProto();
+  proto.set_partition_type(kDataPartition);
+  NeuralNet net(proto, 3);
+  const auto& layers=net.layers();
+  // Unsigned expected value avoids a signed/unsigned comparison with size().
+  ASSERT_EQ(28u, layers.size());
+  ASSERT_EQ("data", layers.at(0)->name());
+}
+// Layer-partitions the conv net into 2 parts. Currently only checks that
+// construction succeeds; no assertions are made on the resulting layers.
+TEST(NeuralNetTest, LayerPartition){
+  NetProto conf=CreateConvNetProto();
+  conf.set_partition_type(kLayerPartition);
+  const int kNumPartitions=2;
+  NeuralNet partitioned(conf, kNumPartitions);
+}
+// Hybrid partitioning: the net as a whole is layer-partitioned into 2 parts,
+// while its last two layers (fc1 and loss) are marked for data partitioning.
+// Only checks that construction succeeds.
+// NOTE(review): the test name misspells "Hybrid"; it is kept unchanged so
+// existing --gtest_filter patterns still match.
+TEST(NeuralNetTest, HyridPartition){
+  NetProto conf=CreateConvNetProto();
+  const int last=conf.layer_size()-1;
+  conf.mutable_layer(last-1)->set_partition_type(kDataPartition);
+  conf.mutable_layer(last)->set_partition_type(kDataPartition);
+  conf.set_partition_type(kLayerPartition);
+  const int kNumPartitions=2;
+  NeuralNet partitioned(conf, kNumPartitions);
+}
+
+

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b2dc51d2/src/test/dist_test/test_pm.cc
----------------------------------------------------------------------
diff --git a/src/test/dist_test/test_pm.cc b/src/test/dist_test/test_pm.cc
new file mode 100644
index 0000000..67c210a
--- /dev/null
+++ b/src/test/dist_test/test_pm.cc
@@ -0,0 +1,88 @@
+#include <sys/types.h>
+#include <sys/stat.h>
+#include <fcntl.h>
+
+#include <iostream>
+#include <fstream>
+
+#include <gflags/gflags.h>
+#include <glog/logging.h>
+#include "utils/cluster.h"
+#include "utils/common.h"
+#include "proto/model.pb.h"
+#include "proto/cluster.pb.h"
+#include "server/server.h"
+#include "server/pm_server.h"
+#include "worker/pm_client.h"
+#include "worker/worker.h"
+#include "proto/topology.pb.h"
+#include <string.h>
+#include <google/protobuf/text_format.h>
+#include <google/protobuf/io/zero_copy_stream_impl.h>
+
+using namespace google::protobuf::io;
+using google::protobuf::TextFormat;
+
+using std::ifstream;
+
+/**
+ * Tests put/get/update performance of the new ZeroMQ-based parameter
+ * servers.
+ *
+ * Flags read by main() in this file: --topology_config, --hostfile and
+ * --node_id. NOTE(review): procsID, cluster_conf, model_conf,
+ * server_threads, client_threads, mode and primary_set are defined here but
+ * not read anywhere in this file.
+ */
+DEFINE_int32(procsID, 0, "global process ID");
+DEFINE_string(hostfile, "examples/imagenet12/hostfile", "hostfile");
+DEFINE_string(cluster_conf, "examples/imagenet12/cluster.conf",
+    "configuration file for the cluster");
+DEFINE_string(model_conf, "examples/imagenet12/model.conf",
+    "Deep learning model configuration file");
+
+DEFINE_string(topology_config,"examples/imagenet12/topology.conf", "Network of servers");
+DEFINE_int32(server_threads,1,"Number of server's worker threads per process");
+DEFINE_int32(client_threads,1,"Number of client's worker threads per process");
+
+DEFINE_string(mode, "client", "client or server mode");
+DEFINE_int32(node_id, 0, "ID of the node, client or server");
+DEFINE_int32(primary_set, 0, "ID of the primary server set (for client mode only)");
+
+/**
+ * Usage: test_pm --node_id <id>
+ *
+ * Reads the topology file and starts this process as a server or as a
+ * client, depending on its node id.
+ */
+
+
+// NOTE(review): FLAGS_v is normally a variable created by DEFINE_int32, not a
+// preprocessor macro, so this #ifndef is always true and the flag is always
+// defined here. glog documents its own --v flag for VLOG; confirm the two do
+// not collide (duplicate symbol or duplicate flag registration).
+#ifndef FLAGS_v
+  DEFINE_int32(v, 3, "vlog controller");
+#endif
+
+/**
+ * Runs one node of the parameter-server test: node ids below
+ * topology.nservers() start a SingaServer, all others start a SingaClient.
+ */
+int main(int argc, char **argv) {
+	google::InitGoogleLogging(argv[0]);
+	gflags::ParseCommandLineFlags(&argc, &argv, true);
+	FLAGS_logtostderr = 1;
+
+	// Read the topology file. CHECK (unlike assert) still fires in NDEBUG
+	// builds; the stream is a scoped local that closes the descriptor when
+	// it goes away, and a parse failure is reported instead of ignored.
+	int fd = open(FLAGS_topology_config.c_str(), O_RDONLY);
+	CHECK_NE(fd, -1) << "cannot open topology file " << FLAGS_topology_config;
+	singa::Topology topology;
+	{
+		FileInputStream topology_input(fd);
+		topology_input.SetCloseOnDelete(true);
+		CHECK(TextFormat::Parse(&topology_input, &topology))
+			<< "cannot parse topology file " << FLAGS_topology_config;
+	}
+
+	// Read the host file, one host per line.
+	ifstream hostfile(FLAGS_hostfile.c_str());
+	string host;
+	vector<string> hosts;
+	while (getline(hostfile, host))
+		hosts.push_back(host);
+
+	// NOTE(review): server/client are never deleted; they live until the
+	// process exits.
+	if (FLAGS_node_id < topology.nservers()) {
+		singa::SingaServer *server = new singa::SingaServer(FLAGS_node_id, topology, hosts);
+		server->StartServer();
+	} else {
+		singa::SingaClient *client = new singa::SingaClient(FLAGS_node_id, topology, hosts);
+		client->StartClient();
+	}
+
+	return 0;
+}

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b2dc51d2/src/test/dist_test/test_router.cc
----------------------------------------------------------------------
diff --git a/src/test/dist_test/test_router.cc b/src/test/dist_test/test_router.cc
new file mode 100644
index 0000000..bed3d99
--- /dev/null
+++ b/src/test/dist_test/test_router.cc
@@ -0,0 +1,27 @@
+#include <gflags/gflags.h>
+#include <gtest/gtest.h>
+#include "utils/router.h"
+#include "utils/common.h"
+#include "utils/cluster.h"
+// Host file passed to singa::Cluster::Get() in main().
+DEFINE_string(hostfile, "examples/imagenet12/hostfile", "hostfile");
+// Text-format ClusterProto read in main().
+DEFINE_string(cluster_conf, "examples/imagenet12/cluster.conf",
+    "configuration file for the cluster");
+// Global id of this process, also passed to singa::Cluster::Get().
+DEFINE_int32(procsID, 0, "global process ID");
+
+// Router smoke test on port 5732. A server process binds to the first
+// server address, passing the number of workers; any other process
+// connects to that address. Either side aborts via CHECK on failure.
+int main(int argc, char** argv){
+  google::InitGoogleLogging(argv[0]);
+  gflags::ParseCommandLineFlags(&argc, &argv, true);
+
+  // Load the cluster description and this process's role in it.
+  singa::ClusterProto cluster_conf;
+  singa::ReadProtoFromTextFile(FLAGS_cluster_conf.c_str(), &cluster_conf);
+  auto cluster=singa::Cluster::Get(cluster_conf, FLAGS_hostfile, FLAGS_procsID);
+
+  const int kPort=5732;
+  if(!cluster->AmIServer()){
+    singa::Router worker(kPort);
+    CHECK(worker.Connect(cluster->server_addr(0)));
+    return 0;
+  }
+  singa::Router server(kPort);
+  CHECK(server.Bind(cluster->server_addr(0), cluster->nworkers()));
+  return 0;
+}

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b2dc51d2/src/test/dist_test/test_split.cc
----------------------------------------------------------------------
diff --git a/src/test/dist_test/test_split.cc b/src/test/dist_test/test_split.cc
new file mode 100644
index 0000000..674d546
--- /dev/null
+++ b/src/test/dist_test/test_split.cc
@@ -0,0 +1,304 @@
+//  Copyright © 2014 Anh Dinh. All Rights Reserved.
+
+
+//  Testing the load imbalance caused by splitting parameter vectors.
+
+#include "core/global-table.h"
+#include "core/common.h"
+#include "core/disk-table.h"
+#include "core/table.h"
+#include "core/table_server.h"
+#include "utils/global_context.h"
+#include <gflags/gflags.h>
+#include "proto/model.pb.h"
+#include "worker.h"
+#include "coordinator.h"
+//#include "model_controller/myacc.h"
+#include "utils/common.h"
+
+#include <cmath>
+#include <stdlib.h>
+#include <vector>
+#include <iostream>
+#include <fstream>
+
+using namespace lapis;
+using std::vector;
+
+//DEFINE_bool(sync_update, false, "Synchronous put/update queue");
+DEFINE_string(system_conf, "examples/imagenet12/system.conf", "configuration file for node roles");
+DEFINE_string(model_conf, "examples/imagenet12/model.conf", "DL model configuration file");
+// Upper bound on the number of floats per stored vector: each entry of
+// sizes[] is split into chunks of at most this many floats. Must be positive.
+DEFINE_int64(threshold,1000000, "max # of parameters in a vector");
+DEFINE_int32(iterations,5,"number of get/put iterations");
+// Only worker processes whose network id is below this value run the
+// get/update benchmark.
+DEFINE_int32(workers,2,"number of workers doing get/put");
+// NOTE(review): FLAGS_v is normally a variable created by DEFINE_int32, not a
+// preprocessor macro, so this #ifndef is always true and the flag is always
+// defined here. glog documents its own --v flag for VLOG; confirm the two do
+// not collide.
+#ifndef FLAGS_v
+  DEFINE_int32(v, 3, "vlog controller");
+#endif
+
+// Tables created by this test, keyed by table id.
+using Map = map<int, GlobalTable*>;
+Map tables;
+shared_ptr<NetworkThread> network;
+shared_ptr<GlobalContext> context;
+// Table servers that registered with the coordinator.
+std::vector<ServerState*> server_states;
+// This worker's table server, created by worker_table_init().
+TableServer *table_server;
+
+// Sizes (in floats) of the parameter vectors exercised by the test.
+const long sizes[] = { 37448736, 16777216, 4096000, 1327104, 884736, 884736, 614400,
+		14112, 4096, 4096, 1000, 384, 384, 256, 256, 96 };
+// Number of entries in sizes[], derived from the array so the two cannot
+// drift apart.
+const int SIZE = sizeof(sizes) / sizeof(sizes[0]);
+
+// Per-key payloads a worker sends in update(); filled by init_messages().
+vector<FloatVector*> value_msg;
+
+// Number of table keys produced by chunking sizes[]; set by init_messages()
+// on workers and by coordinator_load_data() on the coordinator.
+int num_keys;
+
+// Builds the worker-side update payloads: every entry of sizes[] is split
+// into chunks of at most FLAGS_threshold random floats in [0, 1], and one
+// heap-allocated FloatVector per chunk is appended to value_msg. num_keys
+// ends up as the total number of chunks, using the same chunking as
+// coordinator_load_data().
+void init_messages(){
+	num_keys = 0;
+	// A non-positive threshold would never advance `total` below and loop
+	// forever while allocating; keep it as long so a large flag value cannot
+	// overflow an int.
+	const long threshold = static_cast<long>(FLAGS_threshold);
+	CHECK_GT(threshold, 0L) << "--threshold must be positive";
+	VLOG(3)<<"worker: "<<threshold;
+	for (int i=0; i<SIZE; i++){
+		for (long total=0; total<sizes[i]; total+=threshold){
+			FloatVector* fv = new FloatVector();
+			for (long j=0; j+total<sizes[i] && j<threshold; j++)
+				fv->add_data(static_cast<float>(rand())/static_cast<float>(RAND_MAX));
+			value_msg.push_back(fv);
+			num_keys++;
+		}
+	}
+}
+
+// Creates distributed table `id` with `num_shards` shards mapping int keys to
+// FloatVector values (Sharding::Mod sharder, MyAcc accumulator, SparseTable
+// partitions) and registers it in `tables`.
+// NOTE(review): MyAcc's header include is commented out at the top of this
+// file; confirm it is reachable through another include.
+void create_mem_table(int id, int num_shards){
+	TableDescriptor *desc = new TableDescriptor(id, num_shards);
+	desc->key_marshal = new Marshal<int>();
+	desc->value_marshal = new Marshal<FloatVector>();
+	desc->sharder = new Sharding::Mod;
+	desc->accum = new MyAcc();
+	desc->partition_factory = new typename SparseTable<int, FloatVector>::Factory;
+
+	auto *typed = new TypedGlobalTable<int, FloatVector>();
+	typed->Init(desc);
+	tables[id] = typed;
+}
+
+// Coordinator side of table setup: reads one MTYPE_REGISTER_WORKER message per
+// other process, records a ServerState for each table server, assigns the
+// shards of table `id` to those servers round-robin, updates the local
+// partition owners, and broadcasts the assignment via SyncBroadcast
+// (presumably waiting for MTYPE_SHARD_ASSIGNMENT_DONE replies).
+void coordinator_assign_tables(int id){
+	for (int i = 0; i < context->num_processes()-1; ++i) {
+		RegisterWorkerRequest req;
+		int src = 0;
+		network->Read(MPI::ANY_SOURCE, MTYPE_REGISTER_WORKER, &req, &src);
+		// Adding memory server.
+		// NOTE(review): this tests rank `i`, not `src`, the sender of the
+		// registration just read.
+		if (context->IsTableServer(i)) {
+			server_states.push_back(new ServerState(i));
+		}
+	}
+	LOG(INFO) << " All servers registered and started up. Ready to go";
+	// With no table server the round-robin below would index an empty vector
+	// and take a modulo by zero.
+	CHECK(!server_states.empty()) << "no table servers registered";
+
+	// Set itself as the current worker for the table.
+	GlobalTable *table = tables.at(id);
+	table->worker_id_ = network->id();
+
+	// Memory servers are specified in the global context. Round-robin assignment.
+	VLOG(3)<<"num of shards"<<table->num_shards()<<" for table"<< id;
+	int server_idx = 0;
+	for (int shard = 0; shard < table->num_shards(); ++shard) {
+		ServerState &server = *server_states[server_idx];
+		LOG(INFO) << "Assigning table ("<<id<<","<<shard<<") to server "
+		          <<server.server_id;
+		// TODO(Anh) may overwrite this field if #shards>#table_servers
+		server.shard_id = shard;
+		server.local_shards.insert(new TaskId(id, shard));
+		server_idx = (server_idx + 1) % server_states.size();
+	}
+
+	VLOG(3)<<"table assignment";
+	// Then send the table assignment, updating local partition owners.
+	ShardAssignmentRequest req;
+	for (ServerState *server : server_states) {
+		for (auto * task: server->local_shards) {
+			ShardAssignment *s = req.add_assign();
+			s->set_new_worker(server->server_id);
+			s->set_table(task->table);
+			s->set_shard(task->shard);
+			CHECK(tables.find(task->table)!=tables.end());
+			GlobalTable *t = tables.at(task->table);
+			t->get_partition_info(task->shard)->owner = server->server_id;
+			delete task;
+		}
+		// The TaskIds were freed above; drop the now-dangling pointers.
+		server->local_shards.clear();
+	}
+	VLOG(3)<<"finish table assignment, req size "<<req.assign_size();
+	network->SyncBroadcast(MTYPE_SHARD_ASSIGNMENT, MTYPE_SHARD_ASSIGNMENT_DONE, req);
+	VLOG(3)<<"finish table server init";
+}
+
+// Worker side of table setup: creates the global TableServer and starts it
+// over `tables`.
+void worker_table_init(){
+	table_server = new TableServer();
+	TableServer &server = *table_server;
+	server.StartTableServer(tables);
+	VLOG(3) << "done starting table server";
+}
+
+// Pseudo-random value in [0, 1] derived from rand(). Not called in this file.
+double random_double(){
+	const double numerator = rand();
+	const double denominator = RAND_MAX;
+	return numerator / denominator;
+}
+
+// Populates table 0 from the coordinator with random data: every entry of
+// sizes[] is split into chunks of at most FLAGS_threshold floats in [0, 1],
+// each stored under consecutive keys starting at 0. num_keys ends up as the
+// total number of chunks, using the same chunking as init_messages().
+void coordinator_load_data(){
+	auto table = static_cast<TypedGlobalTable<int,FloatVector>*>(tables[0]);
+
+	num_keys = 0;
+	// A non-positive threshold would never advance `total` below and loop
+	// forever; keep it as long so a large flag value cannot overflow an int.
+	const long threshold = static_cast<long>(FLAGS_threshold);
+	CHECK_GT(threshold, 0L) << "--threshold must be positive";
+	for (int i = 0; i < SIZE; i++) {
+		for (long total = 0; total < sizes[i]; total += threshold) {
+			// A local suffices: put() is handed the vector's contents, and a
+			// heap copy per key was never freed.
+			FloatVector fv;
+			for (long j = 0; j + total < sizes[i] && j < threshold; j++)
+				fv.add_data(static_cast<float>(rand()) / static_cast<float>(RAND_MAX));
+			table->put(num_keys, fv);
+			num_keys++;
+		}
+	}
+	VLOG(3) << "Loaded data successfully ... " << num_keys << " messages";
+}
+
+// Measures asynchronous get latency on `table`: issues async_get for every
+// key, then polls async_get_collect until `state` has no valid keys left.
+// The time spent in each phase is written to `latency`.
+void get(TypedGlobalTable<int,FloatVector>* table, ofstream &latency){
+	StateQueue<int> state(num_keys);
+	FloatVector v;
+
+	double start = Now();
+	for (int i=0; i<num_keys; i++){
+		// NOTE(review): a true return presumably means the value was served
+		// immediately, so the key needs no collection.
+		if(table->async_get(i, &v))
+			state.Invalid(i);
+	}
+	latency << "send get: " << (Now() - start) << endl;
+
+	start = Now();
+	while(state.HasValid()){
+		int key=state.Next();
+		if(table->async_get_collect(&key, &v))
+			state.Invalid(key);
+		// Busy-polls on purpose. The former sleep(0.001) truncated to
+		// sleep(0) because sleep() takes whole seconds, so it never paused.
+		// Adding a real back-off (e.g. usleep) would inflate the measured
+		// collect time.
+	}
+	latency << "collect get: " << (Now() - start) << endl;
+}
+
+// Sends value_msg[key] as the update for every key and writes the time taken
+// by each update() call to `latency`.
+void update(TypedGlobalTable<int,FloatVector>* table, ofstream &latency){
+	for (int key = 0; key < num_keys; key++){
+		const double begin = Now();
+		table->update(key, *value_msg[key]);
+		const double elapsed = Now() - begin;
+		latency << "update: " << elapsed << endl;
+	}
+}
+
+// Worker benchmark. Builds the update payloads, then for FLAGS_iterations
+// rounds times one get pass and one update pass over all keys. Per-op
+// timings go to latency_<id> and per-pass totals to throughput_<id>, with a
+// 10-second sleep after each round.
+void worker_test_data(){
+	init_messages();
+	auto table = static_cast<TypedGlobalTable<int,FloatVector>*>(tables[0]);
+
+	const auto rank = NetworkThread::Get()->id();
+	ofstream latency(StringPrintf("latency_%d", rank));
+	ofstream throughput(StringPrintf("throughput_%d", rank));
+	for (int round = 0; round < FLAGS_iterations; round++){
+		double begin = Now();
+		get(table, latency);
+		const double get_time = Now() - begin;
+		throughput << "get: " << get_time << " over " << num_keys << " ops " << endl;
+
+		begin = Now();
+		update(table, latency);
+		const double update_time = Now() - begin;
+		throughput << "update: " << update_time << " over " << num_keys << " ops " << endl;
+		sleep(10);
+	}
+	// Both streams are closed by their destructors on return.
+}
+
+// Writes this process's TABLE_SIZE statistic for table 0 to log_variance_<id>.
+void print_table_stats(){
+	auto *table = static_cast<TypedGlobalTable<int,FloatVector>*>(tables[0]);
+	const auto rank = NetworkThread::Get()->id();
+	ofstream log_file(StringPrintf("log_variance_%d", rank));
+	log_file << "table size at process " << rank << " = "
+	         << table->stats()["TABLE_SIZE"] << endl;
+	// log_file is closed by its destructor.
+}
+
+// Tears the test down.
+// Worker: flushes, sends MTYPE_WORKER_END to the coordinator, waits for its
+// MTYPE_WORKER_SHUTDOWN, then stops its table server and network thread.
+// Coordinator: reads one MTYPE_WORKER_END per other process, sends
+// MTYPE_WORKER_SHUTDOWN to ranks 0 .. size()-2, then flushes and shuts its
+// network thread down.
+void shutdown(){
+	if (!context->AmICoordinator()){
+		VLOG(3) << "Worker " << network->id() << " is shutting down ...";
+		network->Flush();
+		VLOG(3) << "Done flushing the network thread";
+		network->Send(GlobalContext::kCoordinatorRank, MTYPE_WORKER_END, EmptyMessage());
+		EmptyMessage reply;
+		network->Read(GlobalContext::kCoordinatorRank, MTYPE_WORKER_SHUTDOWN, &reply);
+		VLOG(3) << "Worker received MTYPE_WORKER_SHUTDOWN";
+
+		table_server->ShutdownTableServer();
+		VLOG(3) << "Flushing node " << network->id();
+		network->Shutdown();
+		return;
+	}
+
+	VLOG(3) << "Coordinator is shutting down ...";
+	EmptyMessage worker_end;
+	for (int peer = 0; peer < context->num_processes()-1; peer++)
+		network->Read(MPI::ANY_SOURCE, MTYPE_WORKER_END, &worker_end);
+	EmptyMessage shutdown_msg;
+	for (int rank = 0; rank < network->size() - 1; rank++)
+		network->Send(rank, MTYPE_WORKER_SHUTDOWN, shutdown_msg);
+	network->Flush();
+	network->Shutdown();
+}
+
+
+// Test driver.
+// Every process creates table 0 with one shard per table server. The
+// coordinator then assigns the shards and loads the data. The other
+// processes start their table servers, log table stats after the barrier,
+// and (if their id is below FLAGS_workers) run the get/update benchmark.
+// All processes meet at one barrier after setup and finally call shutdown().
+int main(int argc, char **argv) {
+	FLAGS_logtostderr = 1;
+	google::InitGoogleLogging(argv[0]);
+	gflags::ParseCommandLineFlags(&argc, &argv, true);
+
+	context = GlobalContext::Get(FLAGS_system_conf, FLAGS_model_conf);
+	network = NetworkThread::Get();
+	VLOG(3) << "*** testing memory servers, with "
+			<< context->num_table_servers() << " servers";
+
+	create_mem_table(0, context->num_table_servers());
+
+	LOG(INFO)<<"threshold: "<<FLAGS_threshold<<" nworkers: "<<FLAGS_workers;
+	const bool is_coordinator = context->AmICoordinator();
+	if (is_coordinator) {
+		coordinator_assign_tables(0);
+		coordinator_load_data();
+	} else {
+		worker_table_init();
+	}
+	network->barrier();
+
+	if (!is_coordinator) {
+		VLOG(3) << "passed the barrier";
+		print_table_stats();
+		if (network->id() < FLAGS_workers)
+			worker_test_data();
+	}
+
+	shutdown();
+	return 0;
+}
+
+