You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ha...@apache.org on 2018/08/04 17:22:58 UTC

[incubator-mxnet] branch master updated: [MXNET-23] Adding support to profile kvstore server during distributed training (#11215)

This is an automated email from the ASF dual-hosted git repository.

haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 4649bfa  [MXNET-23] Adding support to profile kvstore server during distributed training  (#11215)
4649bfa is described below

commit 4649bfa641ad4129b3e83ea0af14b489e512f8f4
Author: Rahul Huilgol <ra...@gmail.com>
AuthorDate: Sat Aug 4 10:22:50 2018 -0700

    [MXNET-23] Adding support to profile kvstore server during distributed training  (#11215)
    
    * server profiling
    
    merge with master
    
    cleanup old code
    
    added a check and better info message
    
    add functions for C compatibility
    
    fix doc
    
    lint fixes
    
    fix compile issues
    
    lint fix
    
    build error
    
    update function signatures to preserve compatibility
    
    fix comments
    
    lint
    
    * add part1 of test
    
    * add integration test
---
 ci/docker/runtime_functions.sh             |   1 +
 example/image-classification/common/fit.py |  23 +++++-
 include/mxnet/c_api.h                      |  59 ++++++++++++--
 include/mxnet/kvstore.h                    |  26 +++++++
 python/mxnet/kvstore.py                    |   8 +-
 python/mxnet/profiler.py                   |  79 +++++++++++++++----
 src/c_api/c_api_profile.cc                 |  87 ++++++++++++++++++---
 src/kvstore/gradient_compression.cc        |  21 +----
 src/kvstore/kvstore_dist.h                 |   9 +++
 src/kvstore/kvstore_dist_server.h          | 121 ++++++++++++++++++++++-------
 src/kvstore/kvstore_local.h                |  16 ++++
 tests/nightly/test_server_profiling.py     |  69 ++++++++++++++++
 12 files changed, 434 insertions(+), 85 deletions(-)

diff --git a/ci/docker/runtime_functions.sh b/ci/docker/runtime_functions.sh
index 2147190..1c861be 100755
--- a/ci/docker/runtime_functions.sh
+++ b/ci/docker/runtime_functions.sh
@@ -731,6 +731,7 @@ integrationtest_ubuntu_cpu_dist_kvstore() {
     ../../tools/launch.py -n 7 --launcher local python dist_sync_kvstore.py --no-multiprecision
     ../../tools/launch.py -n 7 --launcher local python dist_sync_kvstore.py --type=compressed_cpu
     ../../tools/launch.py -n 7 --launcher local python dist_sync_kvstore.py --type=compressed_cpu --no-multiprecision
+    ../../tools/launch.py -n 3 --launcher local python test_server_profiling.py
 }
 
 integrationtest_ubuntu_gpu_scala() {
diff --git a/example/image-classification/common/fit.py b/example/image-classification/common/fit.py
index 67cda78..b3b1305 100755
--- a/example/image-classification/common/fit.py
+++ b/example/image-classification/common/fit.py
@@ -135,6 +135,12 @@ def add_fit_args(parser):
                        help='the epochs to ramp-up lr to scaled large-batch value')
     train.add_argument('--warmup-strategy', type=str, default='linear',
                        help='the ramping-up strategy for large batch sgd')
+    train.add_argument('--profile-worker-suffix', type=str, default='',
+                       help='profile workers actions into this file. During distributed training '
+                            'filename saved will be rank1_ followed by this suffix')
+    train.add_argument('--profile-server-suffix', type=str, default='',
+                       help='profile server actions into a file with name like rank1_ followed by this suffix '
+                            'during distributed training')
     return train
 
 
@@ -150,6 +156,17 @@ def fit(args, network, data_loader, **kwargs):
     if args.gc_type != 'none':
         kv.set_gradient_compression({'type': args.gc_type,
                                      'threshold': args.gc_threshold})
+    if args.profile_server_suffix:
+        mx.profiler.set_config(filename=args.profile_server_suffix, profile_all=True, profile_process='server')
+        mx.profiler.set_state(state='run', profile_process='server')
+
+    if args.profile_worker_suffix:
+        if kv.num_workers > 1:
+            filename = 'rank' + str(kv.rank) + '_' + args.profile_worker_suffix
+        else:
+            filename = args.profile_worker_suffix
+        mx.profiler.set_config(filename=filename, profile_all=True, profile_process='worker')
+        mx.profiler.set_state(state='run', profile_process='worker')
 
     # logging
     head = '%(asctime)-15s Node[' + str(kv.rank) + '] %(message)s'
@@ -180,7 +197,6 @@ def fit(args, network, data_loader, **kwargs):
                 logging.info('Batch [%d]\tSpeed: %.2f samples/sec', i,
                              args.disp_batches * args.batch_size / (time.time() - tic))
                 tic = time.time()
-
         return
 
     # load model
@@ -314,3 +330,8 @@ def fit(args, network, data_loader, **kwargs):
               epoch_end_callback=checkpoint,
               allow_missing=True,
               monitor=monitor)
+
+    if args.profile_server_suffix:
+        mx.profiler.set_state(state='stop', profile_process='server')
+    if args.profile_worker_suffix:
+        mx.profiler.set_state(state='stop', profile_process='worker')
\ No newline at end of file
diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index 75147cf..6bbe9df 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -230,7 +230,19 @@ MXNET_DLL int MXRandomSeedContext(int seed, int dev_type, int dev_id);
 MXNET_DLL int MXNotifyShutdown();
 
 /*!
- * \brief Set up configuration of profiler
+ * \brief Set up configuration of profiler for the process passed as profile_process in keys
+ * \param num_params Number of parameters
+ * \param keys array of parameter keys
+ * \param vals array of parameter values
+ * \param kvstoreHandle handle to kvstore
+ * \return 0 when success, -1 when failure happens.
+ */
+MXNET_DLL int MXSetProcessProfilerConfig(int num_params, const char* const* keys,
+                                         const char* const* vals,
+                                         KVStoreHandle kvstoreHandle);
+
+/*!
+ * \brief Set up configuration of profiler for worker/current process
  * \param num_params Number of parameters
  * \param keys array of parameter keys
  * \param vals array of parameter values
@@ -239,7 +251,21 @@ MXNET_DLL int MXNotifyShutdown();
 MXNET_DLL int MXSetProfilerConfig(int num_params, const char* const* keys, const char* const* vals);
 
 /*!
- * \brief Set up state of profiler
+ * \brief Set up state of profiler for either worker or server process
+ * \param state indicate the working state of profiler,
+ *  profiler not running when state == 0,
+ *  profiler running when state == 1
+ * \param profile_process an int,
+ * when 0 command is for worker/current process,
+ * when 1 command is for server process
+ * \param kvStoreHandle handle to kvstore, needed for server process profiling
+ * \return 0 when success, -1 when failure happens.
+ */
+MXNET_DLL int MXSetProcessProfilerState(int state, int profile_process,
+                                        KVStoreHandle kvStoreHandle);
+
+/*!
+ * \brief Set up state of profiler for current process
  * \param state indicate the working state of profiler,
  *  profiler not running when state == 0,
  *  profiler running when state == 1
@@ -250,12 +276,23 @@ MXNET_DLL int MXSetProfilerState(int state);
 /*!
  * \brief Save profile and stop profiler
  * \param finished true if stat output should stop after this point
+ * \param profile_process an int,
+ * when 0 command is for worker/current process,
+ * when 1 command is for server process
+ * \param kvStoreHandle handle to kvstore, needed for server process profiling
  * \return 0 when success, -1 when failure happens.
  */
-MXNET_DLL int MXDumpProfile(int finished);
+MXNET_DLL int MXDumpProcessProfile(int finished, int profile_process, KVStoreHandle kvStoreHandle);
 
 
 /*!
+ * \brief Save profile and stop profiler for worker/current process
+ * \param finished true if stat output should stop after this point
+ * \return 0 when success, -1 when failure happens.
+ */
+MXNET_DLL int MXDumpProfile(int finished);
+
+/*!
  * \brief Print aggregate stats to the a string
  * \param out_str Will receive a pointer to the output string
  * \param reset Clear the aggregate stats after printing
@@ -267,6 +304,16 @@ MXNET_DLL int MXAggregateProfileStatsPrint(const char **out_str, int reset);
 /*!
  * \brief Pause profiler tuning collection
  * \param paused If nonzero, profiling pauses. Otherwise, profiling resumes/continues
+ * \param profile_process 0 for the worker/current process, 1 for the server process
+ * \param kvStoreHandle handle to kvstore, needed for server process profiling
+ * \return 0 when success, -1 when failure happens.
+ * \note pausing and resuming is global and not recursive
+ */
+MXNET_DLL int MXProcessProfilePause(int paused, int profile_process, KVStoreHandle kvStoreHandle);
+
+/*!
+ * \brief Pause profiler tuning collection for worker/current process
+ * \param paused If nonzero, profiling pauses. Otherwise, profiling resumes/continues
  * \return 0 when success, -1 when failure happens.
  * \note pausing and resuming is global and not recursive
  */
@@ -2145,8 +2192,7 @@ typedef void (MXKVStoreServerController)(int head,
                                          void *controller_handle);
 
 /**
- * \return Run as server (or scheduler)
- *
+ * \brief Run as server (or scheduler)
  * \param handle handle to the KVStore
  * \param controller the user-defined server controller
  * \param controller_handle helper handle for implementing controller
@@ -2157,8 +2203,7 @@ MXNET_DLL int MXKVStoreRunServer(KVStoreHandle handle,
                                  void *controller_handle);
 
 /**
- * \return Send a command to all server nodes
- *
+ * \brief Send a command to all server nodes
  * \param handle handle to the KVStore
  * \param cmd_id the head of the command
  * \param cmd_body the body of the command
diff --git a/include/mxnet/kvstore.h b/include/mxnet/kvstore.h
index e10bd21..a73d963 100644
--- a/include/mxnet/kvstore.h
+++ b/include/mxnet/kvstore.h
@@ -38,6 +38,18 @@
 #endif  // MXNET_USE_DIST_KVSTORE
 
 namespace mxnet {
+
+/*!
+ * \brief enum to denote types of commands kvstore sends to server regarding profiler
+ * kSetConfig sets profiler configs. Similar to mx.profiler.set_config()
+ * kState allows changing state of profiler to stop or run
+ * kPause allows pausing and resuming of profiler
+ * kDump asks profiler to dump output
+ */
+enum class KVStoreServerProfilerCommand {
+  kSetConfig, kState, kPause, kDump
+};
+
 /*!
  * \brief distributed key-value store
  *
@@ -365,6 +377,20 @@ class KVStore {
   virtual void SendCommandToServers(int cmd_id, const std::string& cmd_body) { }
 
   /**
+   * \brief Sends server profiler commands to all server nodes
+   * Only the worker with rank=0 sends the command which will be received by all servers
+   * \param type ProfilerCommand type
+   * \param params parameters for that command in the form of a string
+   */
+  virtual void SetServerProfilerCommand(const KVStoreServerProfilerCommand type,
+                                        const std::string& params) {
+    LOG(INFO) << "Unable to pass server the profiler command. If you are using "
+              << "distributed kvstore, you need to compile with USE_DIST_KVSTORE=1. "
+              << "If you are training on single machine, then there is no server process "
+              << "to profile. Please profile the worker process instead.";
+  }
+
+  /**
    * \brief the prototype of a server controller
    */
   typedef std::function<void(int, const std::string&)> Controller;
diff --git a/python/mxnet/kvstore.py b/python/mxnet/kvstore.py
index 6097336..a548175 100644
--- a/python/mxnet/kvstore.py
+++ b/python/mxnet/kvstore.py
@@ -28,6 +28,7 @@ from .base import _LIB, c_str_array, c_handle_array, c_array, c_array_buf, c_str
 from .base import check_call, string_types, mx_uint, py_str
 from .base import NDArrayHandle, KVStoreHandle
 from . import optimizer as opt
+from .profiler import set_kvstore_handle
 
 def _ctype_key_value(keys, vals):
     """
@@ -88,7 +89,8 @@ def _get_kvstore_server_command_type(command):
                      'kSetMultiPrecision': 1,
                      'kStopServer': 2,
                      'kSyncMode': 3,
-                     'kSetGradientCompression': 4}
+                     'kSetGradientCompression': 4,
+                     'kSetProfilerParams': 5}
     assert (command in command_types), "Unknown command type to send to server"
     return command_types[command]
 
@@ -670,4 +672,6 @@ def create(name='local'):
     handle = KVStoreHandle()
     check_call(_LIB.MXKVStoreCreate(c_str(name),
                                     ctypes.byref(handle)))
-    return KVStore(handle)
+    kv = KVStore(handle)
+    set_kvstore_handle(kv.handle)
+    return kv
diff --git a/python/mxnet/profiler.py b/python/mxnet/profiler.py
index 0e7a31c..0b5e85b 100644
--- a/python/mxnet/profiler.py
+++ b/python/mxnet/profiler.py
@@ -22,8 +22,13 @@
 from __future__ import absolute_import
 import ctypes
 import warnings
-from .base import _LIB, check_call, c_str, ProfileHandle, c_str_array, py_str
+from .base import _LIB, check_call, c_str, ProfileHandle, c_str_array, py_str, KVStoreHandle
 
+profiler_kvstore_handle = KVStoreHandle()
+
+def set_kvstore_handle(handle):
+    global profiler_kvstore_handle
+    profiler_kvstore_handle = handle
 
 def set_config(**kwargs):
     """Set up the configure of profiler (only accepts keyword arguments).
@@ -49,12 +54,17 @@ def set_config(**kwargs):
     aggregate_stats : boolean,
         whether to maintain aggregate stats in memory for console
         dump.  Has some negative performance impact.
+    profile_process : string
+        whether to profile kvstore `server` or `worker`.
+        server can only be profiled when kvstore is of type dist.
+        if this is not passed, defaults to `worker`
     """
     kk = kwargs.keys()
     vv = kwargs.values()
-    check_call(_LIB.MXSetProfilerConfig(len(kwargs),
-                                        c_str_array([key for key in kk]),
-                                        c_str_array([str(val) for val in vv])))
+    check_call(_LIB.MXSetProcessProfilerConfig(len(kwargs),
+                                               c_str_array([key for key in kk]),
+                                               c_str_array([str(val) for val in vv]),
+                                               profiler_kvstore_handle))
 
 
 def profiler_set_config(mode='symbolic', filename='profile.json'):
@@ -73,10 +83,10 @@ def profiler_set_config(mode='symbolic', filename='profile.json'):
     keys = c_str_array([key for key in ["profile_" + mode, "filename"]])
     values = c_str_array([str(val) for val in [True, filename]])
     assert len(keys) == len(values)
-    check_call(_LIB.MXSetProfilerConfig(len(keys), keys, values))
+    check_call(_LIB.MXSetProcessProfilerConfig(len(keys), keys, values, profiler_kvstore_handle))
 
 
-def set_state(state='stop'):
+def set_state(state='stop', profile_process='worker'):
     """Set up the profiler state to 'run' or 'stop'.
 
     Parameters
@@ -84,9 +94,16 @@ def set_state(state='stop'):
     state : string, optional
         Indicates whether to run the profiler, can
         be 'stop' or 'run'. Default is `stop`.
+    profile_process : string
+        whether to profile kvstore `server` or `worker`.
+        server can only be profiled when kvstore is of type dist.
+        if this is not passed, defaults to `worker`
     """
     state2int = {'stop': 0, 'run': 1}
-    check_call(_LIB.MXSetProfilerState(ctypes.c_int(state2int[state])))
+    profile_process2int = {'worker': 0, 'server': 1}
+    check_call(_LIB.MXSetProcessProfilerState(ctypes.c_int(state2int[state]),
+                                              profile_process2int[profile_process],
+                                              profiler_kvstore_handle))
 
 
 def profiler_set_state(state='stop'):
@@ -102,7 +119,7 @@ def profiler_set_state(state='stop'):
                   'Please use profiler.set_state() instead')
     set_state(state)
 
-def dump(finished=True):
+def dump(finished=True, profile_process='worker'):
     """Dump profile and stop profiler. Use this to save profile
     in advance in case your program cannot exit normally.
 
@@ -111,9 +128,16 @@ def dump(finished=True):
     finished : boolean
         Indicates whether to stop statistic output (dumping) after this dump.
         Default is True
+    profile_process : string
+        whether to profile kvstore `server` or `worker`.
+        server can only be profiled when kvstore is of type dist.
+        if this is not passed, defaults to `worker`
     """
-    fin = 1 if finished is True else False
-    check_call(_LIB.MXDumpProfile(fin))
+    fin = 1 if finished is True else 0
+    profile_process2int = {'worker': 0, 'server': 1}
+    check_call(_LIB.MXDumpProcessProfile(fin,
+                                         profile_process2int[profile_process],
+                                         profiler_kvstore_handle))
 
 
 def dump_profile():
@@ -138,14 +162,37 @@ def dumps(reset=False):
     return py_str(debug_str.value)
 
 
-def pause():
-    """Pause profiling."""
-    check_call(_LIB.MXProfilePause(int(1)))
+def pause(profile_process='worker'):
+    """Pause profiling.
+
+    Parameters
+    ----------
+    profile_process : string
+        whether to profile kvstore `server` or `worker`.
+        server can only be profiled when kvstore is of type dist.
+        if this is not passed, defaults to `worker`
+    """
+    profile_process2int = {'worker': 0, 'server': 1}
+    check_call(_LIB.MXProcessProfilePause(int(1),
+                                          profile_process2int[profile_process],
+                                          profiler_kvstore_handle))
+
 
+def resume(profile_process='worker'):
+    """
+    Resume paused profiling.
 
-def resume():
-    """Resume paused profiling."""
-    check_call(_LIB.MXProfilePause(int(0)))
+    Parameters
+    ----------
+    profile_process : string
+        whether to profile kvstore `server` or `worker`.
+        server can only be profiled when kvstore is of type dist.
+        if this is not passed, defaults to `worker`
+    """
+    profile_process2int = {'worker': 0, 'server': 1}
+    check_call(_LIB.MXProcessProfilePause(int(0),
+                                          profile_process2int[profile_process],
+                                          profiler_kvstore_handle))
 
 
 class Domain(object):
diff --git a/src/c_api/c_api_profile.cc b/src/c_api/c_api_profile.cc
index c584177..9c03b33 100644
--- a/src/c_api/c_api_profile.cc
+++ b/src/c_api/c_api_profile.cc
@@ -29,6 +29,7 @@
 #include <dmlc/base.h>
 #include <dmlc/logging.h>
 #include <dmlc/thread_group.h>
+#include <mxnet/kvstore.h>
 #include <stack>
 #include "./c_api_common.h"
 #include "../profiler/profiler.h"
@@ -197,6 +198,10 @@ struct PythonProfileObjects {
 };
 static PythonProfileObjects python_profile_objects;
 
+enum class ProfileProcess {
+  kWorker, kServer
+};
+
 struct ProfileConfigParam : public dmlc::Parameter<ProfileConfigParam> {
   bool profile_all;
   bool profile_symbolic;
@@ -207,6 +212,7 @@ struct ProfileConfigParam : public dmlc::Parameter<ProfileConfigParam> {
   bool continuous_dump;
   float dump_period;
   bool aggregate_stats;
+  int profile_process;
   DMLC_DECLARE_PARAMETER(ProfileConfigParam) {
     DMLC_DECLARE_FIELD(profile_all).set_default(false)
       .describe("Profile all.");
@@ -228,6 +234,13 @@ struct ProfileConfigParam : public dmlc::Parameter<ProfileConfigParam> {
     DMLC_DECLARE_FIELD(aggregate_stats).set_default(false)
       .describe("Maintain aggregate stats, required for MXDumpAggregateStats.  Note that "
       "this can have anegative performance impact.");
+    DMLC_DECLARE_FIELD(profile_process)
+      .add_enum("worker", static_cast<int>(ProfileProcess::kWorker))
+      .add_enum("server", static_cast<int>(ProfileProcess::kServer))
+      .set_default(static_cast<int>(ProfileProcess::kWorker))
+      .describe("Specifies which process to profile: "
+                "worker: this is default. for single node training it should always be worker. "
+                "server: for distributed training, this profiles server process");
   }
 };
 
@@ -248,7 +261,8 @@ struct ProfileMarkerScopeParam : public dmlc::Parameter<ProfileMarkerScopeParam>
 
 DMLC_REGISTER_PARAMETER(ProfileMarkerScopeParam);
 
-int MXSetProfilerConfig(int num_params, const char* const* keys, const char* const* vals) {
+int MXSetProcessProfilerConfig(int num_params, const char* const* keys, const char* const* vals,
+                               KVStoreHandle kvstoreHandle) {
     mxnet::IgnoreProfileCallScope ignore;
   API_BEGIN();
     std::vector<std::pair<std::string, std::string>> kwargs;
@@ -260,19 +274,37 @@ int MXSetProfilerConfig(int num_params, const char* const* keys, const char* con
     }
     ProfileConfigParam param;
     param.Init(kwargs);
-    int mode = 0;
-    if (param.profile_api || param.profile_all)        { mode |= profiler::Profiler::kAPI; }
-    if (param.profile_symbolic || param.profile_all)   { mode |= profiler::Profiler::kSymbolic; }
-    if (param.profile_imperative || param.profile_all) { mode |= profiler::Profiler::kImperative; }
-    if (param.profile_memory || param.profile_all)     { mode |= profiler::Profiler::kMemory; }
-    profiler::Profiler::Get()->SetConfig(profiler::Profiler::ProfilerMode(mode),
-                                         std::string(param.filename),
-                                         param.continuous_dump,
-                                         param.dump_period,
-                                         param.aggregate_stats);
+    if (static_cast<ProfileProcess>(param.profile_process) == ProfileProcess::kServer) {
+      std::ostringstream os;
+      for (int i = 0; i < num_params; ++i) {
+        // forwarded to the servers, which must not receive profile_process again
+        if (strcmp(keys[i], "profile_process") == 0) continue;
+        if (!os.str().empty()) os << ",";  // separator only between emitted pairs
+        os << keys[i] << ":" << vals[i];
+      }
+      CHECK(kvstoreHandle) << "KVStoreHandle passed to profiler is null";
+      static_cast<KVStore*>(kvstoreHandle)->SetServerProfilerCommand(
+        mxnet::KVStoreServerProfilerCommand::kSetConfig, os.str());
+    } else {
+      int mode = 0;
+      if (param.profile_api || param.profile_all)        { mode |= profiler::Profiler::kAPI; }
+      if (param.profile_symbolic || param.profile_all)   { mode |= profiler::Profiler::kSymbolic; }
+      if (param.profile_imperative ||
+          param.profile_all) { mode |= profiler::Profiler::kImperative; }
+      if (param.profile_memory || param.profile_all)     { mode |= profiler::Profiler::kMemory; }
+      profiler::Profiler::Get()->SetConfig(profiler::Profiler::ProfilerMode(mode),
+                                           std::string(param.filename),
+                                           param.continuous_dump,
+                                           param.dump_period,
+                                           param.aggregate_stats);
+    }
   API_END();
 }
 
+int MXSetProfilerConfig(int num_params, const char* const* keys, const char* const* vals) {
+  return MXSetProcessProfilerConfig(num_params, keys, vals, nullptr);
+}
+
 int MXAggregateProfileStatsPrint(const char **out_str, int reset) {
   MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
   API_BEGIN();
@@ -293,19 +325,40 @@ int MXAggregateProfileStatsPrint(const char **out_str, int reset) {
 }
 
 int MXDumpProfile(int finished) {
+  return MXDumpProcessProfile(finished, static_cast<int>(ProfileProcess::kWorker), nullptr);
+}
+
+int MXDumpProcessProfile(int finished, int profile_process, KVStoreHandle kvStoreHandle) {
   mxnet::IgnoreProfileCallScope ignore;
   API_BEGIN();
+  if (static_cast<ProfileProcess>(profile_process) == ProfileProcess::kServer) {
+    CHECK(kvStoreHandle) << "Kvstore Handle passed to profiler is null";
+    static_cast<KVStore*>(kvStoreHandle)->SetServerProfilerCommand(
+      mxnet::KVStoreServerProfilerCommand::kDump,
+      std::to_string(finished));
+  } else {
     profiler::Profiler *profiler = profiler::Profiler::Get();
     CHECK(profiler->IsEnableOutput())
       << "Profiler hasn't been run. Config and start profiler first";
     profiler->DumpProfile(finished != 0);
+  }
   API_END()
 }
 
 int MXSetProfilerState(int state) {
+  return MXSetProcessProfilerState(state, static_cast<int>(ProfileProcess::kWorker), nullptr);
+}
+
+int MXSetProcessProfilerState(int state, int profile_process, KVStoreHandle kvStoreHandle) {
   mxnet::IgnoreProfileCallScope ignore;
   // state, kNotRunning: 0, kRunning: 1
   API_BEGIN();
+  if (static_cast<ProfileProcess>(profile_process) == ProfileProcess::kServer) {
+    CHECK(kvStoreHandle) << "Kvstore Handle passed to profiler is null";
+    static_cast<KVStore*>(kvStoreHandle)->SetServerProfilerCommand(
+      mxnet::KVStoreServerProfilerCommand::kState,
+      std::to_string(state));
+  } else {
     switch (state) {
       case profiler::Profiler::kNotRunning:
         profiler::vtune::vtune_pause();
@@ -315,6 +368,7 @@ int MXSetProfilerState(int state) {
         break;
     }
     profiler::Profiler::Get()->SetState(profiler::Profiler::ProfilerState(state));
+  }
   API_END();
 }
 
@@ -450,8 +504,18 @@ int MXProfileDurationStop(ProfileHandle duration_handle) {
 }
 
 int MXProfilePause(int paused) {
+  return MXProcessProfilePause(paused, static_cast<int>(ProfileProcess::kWorker), nullptr);
+}
+
+int MXProcessProfilePause(int paused, int profile_process, KVStoreHandle kvStoreHandle) {
   mxnet::IgnoreProfileCallScope ignore;
   API_BEGIN();
+  if (static_cast<ProfileProcess>(profile_process) == ProfileProcess::kServer) {
+    CHECK(kvStoreHandle) << "Kvstore Handle passed to profiler is null";
+    static_cast<KVStore*>(kvStoreHandle)->SetServerProfilerCommand(
+      mxnet::KVStoreServerProfilerCommand::kPause,
+      std::to_string(paused));
+  } else {
     if (paused) {
       profiler::vtune::vtune_pause();
       profiler::Profiler::Get()->set_paused(true);
@@ -459,6 +523,7 @@ int MXProfilePause(int paused) {
       profiler::Profiler::Get()->set_paused(false);
       profiler::vtune::vtune_resume();
     }
+  }
   API_END();
 }
 
diff --git a/src/kvstore/gradient_compression.cc b/src/kvstore/gradient_compression.cc
index e94a057..e4a06fa 100644
--- a/src/kvstore/gradient_compression.cc
+++ b/src/kvstore/gradient_compression.cc
@@ -23,31 +23,14 @@
  * \author Rahul Huilgol
  */
 
-#include <sstream>
 #include <vector>
+#include "kvstore_local.h"
 #include "gradient_compression.h"
 #include "gradient_compression-inl.h"
 
 namespace mxnet {
 namespace kvstore {
 
-/*!
- * \brief Splits a string into smaller strings using char as delimiter
- * Example: "a,b,c,,d" is split into ["a","b","c","","d"]
- * \param s string to split
- * \param delim char to split string around
- * \param result container for tokens extracted after splitting
- */
-template<typename Out>
-void split(const std::string &s, const char delim, Out result) {
-  std::stringstream ss;
-  ss.str(s);
-  std::string item;
-  while (std::getline(ss, item, delim)) {
-    *(result++) = item;
-  }
-}
-
 DMLC_REGISTER_PARAMETER(GradientCompressionParam);
 
 GradientCompression::GradientCompression() {
@@ -90,7 +73,7 @@ std::string GradientCompression::EncodeParams() {
 
 void GradientCompression::DecodeParams(const std::string &s) {
   std::vector<std::string> elems;
-  split(s, ',', std::back_inserter(elems));
+  mxnet::kvstore::split(s, ',', std::back_inserter(elems));
   type_ = static_cast<CompressionType>(stoi(elems[0]));
   if (elems.size() > 1) {
     if (!elems[1].empty()) {
diff --git a/src/kvstore/kvstore_dist.h b/src/kvstore/kvstore_dist.h
index 7e2f5cb..23fbf67 100644
--- a/src/kvstore/kvstore_dist.h
+++ b/src/kvstore/kvstore_dist.h
@@ -93,6 +93,15 @@ class KVStoreDist : public KVStoreLocal {
     }
   }
 
+  void SetServerProfilerCommand(const KVStoreServerProfilerCommand type,
+                                const std::string& params) override {
+    if (get_rank() == 0) {
+      SendCommandToServers(static_cast<int>(CommandType::kSetProfilerParams),
+                           params + std::to_string(static_cast<int>(type)));
+    }
+  }
+
+
   void Barrier() override {
     ps::Postoffice::Get()->Barrier(ps_worker_->get_customer()->customer_id(), ps::kWorkerGroup);
   }
diff --git a/src/kvstore/kvstore_dist_server.h b/src/kvstore/kvstore_dist_server.h
index 451fb78..372b58d 100644
--- a/src/kvstore/kvstore_dist_server.h
+++ b/src/kvstore/kvstore_dist_server.h
@@ -24,6 +24,9 @@
  */
 #ifndef MXNET_KVSTORE_KVSTORE_DIST_SERVER_H_
 #define MXNET_KVSTORE_KVSTORE_DIST_SERVER_H_
+#include <mxnet/c_api.h>
+#include <mxnet/kvstore.h>
+#include <ps/ps.h>
 #include <queue>
 #include <string>
 #include <mutex>
@@ -32,8 +35,7 @@
 #include <functional>
 #include <future>
 #include <vector>
-#include "ps/ps.h"
-#include "mxnet/kvstore.h"
+#include "../profiler/profiler.h"
 #include "../operator/tensor/elemwise_binary_op-inl.h"
 #include "../operator/tensor/init_op.h"
 
@@ -42,7 +44,8 @@ namespace kvstore {
 
 // maintain same order in frontend.
 enum class CommandType {
-  kController, kSetMultiPrecision, kStopServer, kSyncMode, kSetGradientCompression,
+  kController, kSetMultiPrecision, kStopServer, kSyncMode,
+  kSetGradientCompression, kSetProfilerParams
 };
 
 enum class RequestType {
@@ -164,6 +167,7 @@ class KVStoreDistServer {
   }
 
   ~KVStoreDistServer() {
+    profiler::Profiler::Get()->SetState(profiler::Profiler::ProfilerState(0));
     delete ps_server_;
   }
 
@@ -194,27 +198,39 @@ class KVStoreDistServer {

   void CommandHandle(const ps::SimpleData& recved, ps::SimpleApp* app) {
     CommandType recved_type = static_cast<CommandType>(recved.head);
-    if (recved_type == CommandType::kStopServer) {
-      exec_.Stop();
-    } else if (recved_type == CommandType::kSyncMode) {
-      sync_mode_ = true;
-    } else if (recved_type == CommandType::kSetGradientCompression) {
-      gradient_compression_->DecodeParams(recved.body);
-    } else if (recved_type == CommandType::kSetMultiPrecision) {
-      // uses value 1 for message id from frontend
-      if (!multi_precision_) {
-        multi_precision_ = true;
-        CreateMultiPrecisionCopies();
-      }
-    } else if (recved_type == CommandType::kController) {
-      // value of 0
-      // let the main thread to execute ctrl, which is necessary for python
-      exec_.Exec([this, recved]() {
-          CHECK(controller_);
-          controller_(recved.head, recved.body);
-        });
-    } else {
-      LOG(FATAL) << "Unknown command type received " << recved.head;
+    switch (recved_type) {
+      case CommandType::kStopServer:
+        exec_.Stop();
+        break;
+      case CommandType::kSyncMode:
+        sync_mode_ = true;
+        break;
+      case CommandType::kSetGradientCompression:
+        gradient_compression_->DecodeParams(recved.body);
+        break;
+      case CommandType::kSetProfilerParams:
+        // last char is the type of profiler command
+        ProcessServerProfilerCommands(static_cast<KVStoreServerProfilerCommand>
+                                                  (recved.body.back() - '0'),
+                                      recved.body);
+        break;
+      case CommandType::kSetMultiPrecision:
+        // uses value 1 for message id from frontend
+        if (!multi_precision_) {
+          multi_precision_ = true;
+          CreateMultiPrecisionCopies();
+        }
+        break;
+      case CommandType::kController:
+        // this uses value 0 for message id from frontend
+        // let the main thread to execute ctrl, which is necessary for python
+        exec_.Exec([this, recved]() {
+            CHECK(controller_);
+            controller_(recved.head, recved.body);
+          });
+        break;
+      default:
+        LOG(FATAL) << "Unknown command type received " << recved.head;
     }
     app->Response(recved);
   }
@@ -225,11 +239,11 @@ class KVStoreDistServer {
    * some keys are initialized before optimizer is set.
    */
   void CreateMultiPrecisionCopies() {
-    for (auto const& stored_entry : store_) {
+    for (auto const &stored_entry : store_) {
       const int key = stored_entry.first;
-      const NDArray& stored = stored_entry.second;
+      const NDArray &stored = stored_entry.second;
       if (stored.dtype() != mshadow::kFloat32) {
-        auto& stored_realt = store_realt_[key];
+        auto &stored_realt = store_realt_[key];
         if (stored.storage_type() == kRowSparseStorage) {
           stored_realt = NDArray(kRowSparseStorage, stored.shape(), stored.ctx(),
                                  true, mshadow::kFloat32);
@@ -237,7 +251,7 @@ class KVStoreDistServer {
           stored_realt = NDArray(stored.shape(), stored.ctx(), false, mshadow::kFloat32);
         }
 
-        auto& update = update_buf_[key];
+        auto &update = update_buf_[key];
         if (!update.merged.is_none()) {
           if (update.merged.storage_type() == kRowSparseStorage) {
             update.merged = NDArray(kRowSparseStorage, update.merged.shape(), update.merged.ctx(),
@@ -254,11 +268,60 @@ class KVStoreDistServer {
         CopyFromTo(stored, stored_realt);
       }
     }
-    for (auto const& stored_realt_entry : store_realt_) {
+    for (auto const &stored_realt_entry : store_realt_) {
       stored_realt_entry.second.WaitToRead();
     }
   }
 
+  /*!
+   * \brief Executes a profiler command that a worker sent to this server.
+   * \param type which profiler operation to run (decoded by the command handler from
+   *        the last character of body)
+   * \param body raw command payload. For kSetConfig everything but the last character is
+   *        the config string; for kState, kPause and kDump the first character is a digit
+   *        whose value ('0'-based) is forwarded to the corresponding MX* profiler call.
+   */
+  void ProcessServerProfilerCommands(KVStoreServerProfilerCommand type, const std::string& body) {
+    // body.front()/back() are undefined on an empty string
+    CHECK(!body.empty()) << "Empty profiler command received from worker";
+    const int arg = body.front() - '0';
+    int ret = 0;
+    switch (type) {
+      case KVStoreServerProfilerCommand::kSetConfig:
+        SetProfilerConfig(body.substr(0, body.size() - 1));
+        break;
+      case KVStoreServerProfilerCommand::kState:
+        ret = MXSetProfilerState(arg);
+        break;
+      case KVStoreServerProfilerCommand::kPause:
+        ret = MXProfilePause(arg);
+        break;
+      case KVStoreServerProfilerCommand::kDump:
+        ret = MXDumpProfile(arg);
+        break;
+      default:
+        // a value outside the enum (e.g. a corrupted command char) must not be ignored
+        LOG(FATAL) << "Unknown profiler command received " << static_cast<int>(type);
+    }
+    // MXNet C API functions return nonzero on failure instead of throwing
+    if (ret != 0) {
+      LOG(WARNING) << "Profiler command failed on server: " << MXGetLastError();
+    }
+  }
+
+  /*!
+   * \brief Applies a profiler configuration received from a worker.
+   * \param params_str comma-separated "key:value" pairs with the trailing command-type
+   *        character already stripped, e.g. "filename:profile.json".
+   *        Each pair is split on its first ':' so values may themselves contain ':';
+   *        both key and value must be non-empty. A "filename" value is prefixed with
+   *        "rank<N>_", N being ps::MyRank() on this server, so each server writes its own file.
+   */
+  void SetProfilerConfig(std::string params_str) {
+    std::vector<std::string> elems;
+    mxnet::kvstore::split(params_str, ',', std::back_inserter(elems));
+    // Own the key/value strings here; the C API only borrows the char pointers for the
+    // call, so no manual new[]/delete[] is needed and nothing leaks if a CHECK fails.
+    std::vector<std::string> keys;
+    std::vector<std::string> vals;
+    keys.reserve(elems.size());
+    vals.reserve(elems.size());
+    for (const std::string &elem : elems) {
+      const size_t sep = elem.find(':');
+      CHECK_NE(sep, std::string::npos) << "Improper profiler config passed from worker";
+      std::string param = elem.substr(0, sep);
+      std::string value = elem.substr(sep + 1);
+      CHECK(!param.empty()) << "ProfilerConfig parameter is empty";
+      CHECK(!value.empty()) << "ProfilerConfig value is empty for parameter " << param;
+      if (param == "filename") {
+        value = "rank" + std::to_string(ps::MyRank()) + "_" + value;
+      }
+      keys.push_back(param);
+      vals.push_back(value);
+    }
+    std::vector<const char*> ckeys;
+    std::vector<const char*> cvals;
+    ckeys.reserve(keys.size());
+    cvals.reserve(vals.size());
+    for (size_t i = 0; i < keys.size(); ++i) {
+      ckeys.push_back(keys[i].c_str());
+      cvals.push_back(vals[i].c_str());
+    }
+    // data() is well-defined for an empty vector, unlike &v[0]
+    const int ret = MXSetProfilerConfig(static_cast<int>(ckeys.size()),
+                                        ckeys.data(), cvals.data());
+    if (ret != 0) {
+      LOG(WARNING) << "Failed to set profiler config on server: " << MXGetLastError();
+    }
+  }
+
   void DataHandleEx(const ps::KVMeta& req_meta,
                     const ps::KVPairs<char>& req_data,
                     ps::KVServer<char>* server) {
diff --git a/src/kvstore/kvstore_local.h b/src/kvstore/kvstore_local.h
index 324bc2c..4e004a3 100644
--- a/src/kvstore/kvstore_local.h
+++ b/src/kvstore/kvstore_local.h
@@ -40,6 +40,22 @@
 
 namespace mxnet {
 namespace kvstore {
+/*!
+ * \brief Splits a string into tokens separated by a single-character delimiter.
+ * Example: "a,b,c,,d" is split into ["a","b","c","","d"]. Interior empty tokens
+ * are kept, but a trailing delimiter does not yield a final empty token
+ * ("a," gives ["a"]) and an empty input yields no tokens at all.
+ * \param s string to split
+ * \param delim character separating the tokens
+ * \param result output iterator receiving each token in order (e.g. a back_inserter)
+ */
+template<typename Out>
+void split(const std::string &s, const char delim, Out result) {
+  std::istringstream tokens(s);
+  for (std::string token; std::getline(tokens, token, delim); ++result) {
+    *result = token;
+  }
+}
 
 enum KeyType {
   kUndefinedKey = -1,
diff --git a/tests/nightly/test_server_profiling.py b/tests/nightly/test_server_profiling.py
new file mode 100644
index 0000000..7d157a3
--- /dev/null
+++ b/tests/nightly/test_server_profiling.py
@@ -0,0 +1,69 @@
+#!/usr/bin/env python
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import mxnet as mx
+import json
+
+key = '99'
+shape = (1200, 1200)        # bigger than MXNET_KVSTORE_BIGARRAY_BOUND
+kv = mx.kv.create('dist_sync')
+
+def init_kv():
+    # init kv dns keys
+    kv.init(key, mx.nd.ones(shape))
+    kv.set_optimizer(mx.optimizer.create('sgd'))
+    return kv, kv.rank, kv.num_workers
+
+def test_sync_push_pull():
+    kv, my_rank, nworker = init_kv()
+    def check_default_keys(kv, my_rank):
+        nrepeat = 10
+        # checks pull after push in loop, because behavior during
+        # consecutive pushes doesn't offer any guarantees
+        for i in range(nrepeat):
+            kv.push(key, mx.nd.ones(shape, dtype='float32') * (my_rank+1))
+            val = mx.nd.zeros(shape, dtype='float32')
+            kv.pull(key, out=val)
+            mx.nd.waitall()
+    check_default_keys(kv, my_rank)
+
+if __name__ == "__main__":
+    server_filename_suffix = 'test_profile_server.json'
+    worker_filename_suffix = 'test_profile_worker.json'
+    mx.profiler.set_config(filename=server_filename_suffix, profile_all=True, profile_process='server')
+    mx.profiler.set_config(filename='rank' + str(kv.rank) + '_' + worker_filename_suffix, profile_all=True, profile_process='worker')
+    mx.profiler.set_state(state='run', profile_process='server')
+    mx.profiler.set_state(state='run', profile_process='worker')
+    test_sync_push_pull()
+    mx.profiler.set_state(state='stop', profile_process='server')
+    mx.profiler.set_state(state='stop', profile_process='worker')
+
+    import glob, os
+
+    # will only work when launcher mode is local, as used for integration test
+    if kv.rank == 0:
+        for rank in range(kv.num_workers):
+            for suffix in [worker_filename_suffix, server_filename_suffix]:
+                # throws value error if file is not proper json
+                filename = 'rank' + str(rank) + '_' + suffix
+                print(glob.glob('*'), os.getcwd())
+                with open(filename, 'r') as f:
+                    j = json.load(f)
+
+
+