You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/08/04 17:22:52 UTC

[GitHub] eric-haibin-lin closed pull request #11215: [MXNET-23] Adding support to profile kvstore server during distributed training

eric-haibin-lin closed pull request #11215: [MXNET-23] Adding support to profile kvstore server during distributed training 
URL: https://github.com/apache/incubator-mxnet/pull/11215
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (GitHub would not otherwise display it):

diff --git a/ci/docker/runtime_functions.sh b/ci/docker/runtime_functions.sh
index 8805850e314..428d71960bb 100755
--- a/ci/docker/runtime_functions.sh
+++ b/ci/docker/runtime_functions.sh
@@ -759,6 +759,7 @@ integrationtest_ubuntu_cpu_dist_kvstore() {
     ../../tools/launch.py -n 7 --launcher local python dist_sync_kvstore.py --no-multiprecision
     ../../tools/launch.py -n 7 --launcher local python dist_sync_kvstore.py --type=compressed_cpu
     ../../tools/launch.py -n 7 --launcher local python dist_sync_kvstore.py --type=compressed_cpu --no-multiprecision
+    ../../tools/launch.py -n 3 --launcher local python test_server_profiling.py
 }
 
 integrationtest_ubuntu_gpu_scala() {
diff --git a/example/image-classification/common/fit.py b/example/image-classification/common/fit.py
index 67cda78172b..b3b13053add 100755
--- a/example/image-classification/common/fit.py
+++ b/example/image-classification/common/fit.py
@@ -135,6 +135,12 @@ def add_fit_args(parser):
                        help='the epochs to ramp-up lr to scaled large-batch value')
     train.add_argument('--warmup-strategy', type=str, default='linear',
                        help='the ramping-up strategy for large batch sgd')
+    train.add_argument('--profile-worker-suffix', type=str, default='',
+                       help='profile workers actions into this file. During distributed training '
+                            'filename saved will be rank1_ followed by this suffix')
+    train.add_argument('--profile-server-suffix', type=str, default='',
+                       help='profile server actions into a file with name like rank1_ '
+                            'followed by this suffix during distributed training')
     return train
 
 
@@ -150,6 +156,17 @@ def fit(args, network, data_loader, **kwargs):
     if args.gc_type != 'none':
         kv.set_gradient_compression({'type': args.gc_type,
                                      'threshold': args.gc_threshold})
+    if args.profile_server_suffix:
+        mx.profiler.set_config(filename=args.profile_server_suffix, profile_all=True, profile_process='server')
+        mx.profiler.set_state(state='run', profile_process='server')
+
+    if args.profile_worker_suffix:
+        if kv.num_workers > 1:
+            filename = 'rank' + str(kv.rank) + '_' + args.profile_worker_suffix
+        else:
+            filename = args.profile_worker_suffix
+        mx.profiler.set_config(filename=filename, profile_all=True, profile_process='worker')
+        mx.profiler.set_state(state='run', profile_process='worker')
 
     # logging
     head = '%(asctime)-15s Node[' + str(kv.rank) + '] %(message)s'
@@ -180,7 +197,6 @@ def fit(args, network, data_loader, **kwargs):
                 logging.info('Batch [%d]\tSpeed: %.2f samples/sec', i,
                              args.disp_batches * args.batch_size / (time.time() - tic))
                 tic = time.time()
-
         return
 
     # load model
@@ -314,3 +330,8 @@ def fit(args, network, data_loader, **kwargs):
               epoch_end_callback=checkpoint,
               allow_missing=True,
               monitor=monitor)
+
+    if args.profile_server_suffix:  # training is done: stop the profilers started before fit
+        mx.profiler.set_state(state='stop', profile_process='server')
+    if args.profile_worker_suffix:
+        mx.profiler.set_state(state='stop', profile_process='worker')
\ No newline at end of file
diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index 75147cfd706..6bbe9dfe8f0 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -230,7 +230,19 @@ MXNET_DLL int MXRandomSeedContext(int seed, int dev_type, int dev_id);
 MXNET_DLL int MXNotifyShutdown();
 
 /*!
- * \brief Set up configuration of profiler
+ * \brief Set up configuration of profiler for the process passed as profile_process in keys
+ * \param num_params Number of parameters
+ * \param keys array of parameter keys
+ * \param vals array of parameter values
+ * \param kvstoreHandle handle to kvstore
+ * \return 0 when success, -1 when failure happens.
+ */
+MXNET_DLL int MXSetProcessProfilerConfig(int num_params, const char* const* keys,
+                                         const char* const* vals,
+                                         KVStoreHandle kvstoreHandle);
+
+/*!
+ * \brief Set up configuration of profiler for worker/current process
  * \param num_params Number of parameters
  * \param keys array of parameter keys
  * \param vals array of parameter values
@@ -239,7 +251,21 @@ MXNET_DLL int MXNotifyShutdown();
 MXNET_DLL int MXSetProfilerConfig(int num_params, const char* const* keys, const char* const* vals);
 
 /*!
- * \brief Set up state of profiler
+ * \brief Set up state of profiler for either worker or server process
+ * \param state indicates the working state of profiler,
+ *  profiler not running when state == 0,
+ *  profiler running when state == 1
+ * \param profile_process an int,
+ * when 0 the command is for the worker/current process,
+ * when 1 the command is for the server process
+ * \param kvStoreHandle handle to kvstore, needed for server process profiling
+ * \return 0 when success, -1 when failure happens.
+ */
+MXNET_DLL int MXSetProcessProfilerState(int state, int profile_process,
+                                        KVStoreHandle kvStoreHandle);
+
+/*!
+ * \brief Set up state of profiler for current process
  * \param state indicate the working state of profiler,
  *  profiler not running when state == 0,
  *  profiler running when state == 1
@@ -250,11 +276,22 @@ MXNET_DLL int MXSetProfilerState(int state);
 /*!
  * \brief Save profile and stop profiler
  * \param finished true if stat output should stop after this point
+ * \param profile_process an int,
+ * when 0 the command is for the worker/current process,
+ * when 1 the command is for the server process
+ * \param kvStoreHandle handle to kvstore, needed when profile_process is 1
  * \return 0 when success, -1 when failure happens.
  */
-MXNET_DLL int MXDumpProfile(int finished);
+MXNET_DLL int MXDumpProcessProfile(int finished, int profile_process, KVStoreHandle kvStoreHandle);
 
 
+/*!
+ * \brief Save profile and stop profiler for worker/current process
+ * \param finished true if stat output should stop after this point
+ * \return 0 when success, -1 when failure happens.
+ */
+MXNET_DLL int MXDumpProfile(int finished);
+
 /*!
  * \brief Print aggregate stats to the a string
  * \param out_str Will receive a pointer to the output string
@@ -267,6 +304,16 @@ MXNET_DLL int MXAggregateProfileStatsPrint(const char **out_str, int reset);
 /*!
  * \brief Pause profiler tuning collection
  * \param paused If nonzero, profiling pauses. Otherwise, profiling resumes/continues
+ * \param profile_process 0 for the worker/current process, 1 for the server process
+ * \param kvStoreHandle handle to kvstore, needed when profile_process is 1
+ * \return 0 when success, -1 when failure happens.
+ * \note pausing and resuming is global and not recursive
+ */
+MXNET_DLL int MXProcessProfilePause(int paused, int profile_process, KVStoreHandle kvStoreHandle);
+
+/*!
+ * \brief Pause profiler tuning collection for worker/current process
+ * \param paused If nonzero, profiling pauses. Otherwise, profiling resumes/continues
  * \return 0 when success, -1 when failure happens.
  * \note pausing and resuming is global and not recursive
  */
@@ -2145,8 +2192,7 @@ typedef void (MXKVStoreServerController)(int head,
                                          void *controller_handle);
 
 /**
- * \return Run as server (or scheduler)
- *
+ * \brief Run as server (or scheduler)
  * \param handle handle to the KVStore
  * \param controller the user-defined server controller
  * \param controller_handle helper handle for implementing controller
@@ -2157,8 +2203,7 @@ MXNET_DLL int MXKVStoreRunServer(KVStoreHandle handle,
                                  void *controller_handle);
 
 /**
- * \return Send a command to all server nodes
- *
+ * \brief Send a command to all server nodes
  * \param handle handle to the KVStore
  * \param cmd_id the head of the command
  * \param cmd_body the body of the command
diff --git a/include/mxnet/kvstore.h b/include/mxnet/kvstore.h
index e10bd213aa2..a73d9635613 100644
--- a/include/mxnet/kvstore.h
+++ b/include/mxnet/kvstore.h
@@ -38,6 +38,18 @@
 #endif  // MXNET_USE_DIST_KVSTORE
 
 namespace mxnet {
+
+/*!
+ * \brief enum to denote types of commands kvstore sends to server regarding profiler
+ * kSetConfig sets profiler configs. Similar to mx.profiler.set_config()
+ * kState allows changing state of profiler to stop or run
+ * kPause allows pausing and resuming of profiler
+ * kDump asks profiler to dump output
+ */
+enum class KVStoreServerProfilerCommand {
+  kSetConfig, kState, kPause, kDump
+};
+
 /*!
  * \brief distributed key-value store
  *
@@ -364,6 +376,20 @@ class KVStore {
    */
   virtual void SendCommandToServers(int cmd_id, const std::string& cmd_body) { }
 
+  /**
+   * \brief Sends server profiler commands to all server nodes
+   * Distributed kvstores send it from the rank-0 worker only; this default only logs
+   * \param type KVStoreServerProfilerCommand type
+   * \param params parameters for that command in the form of a string
+   */
+  virtual void SetServerProfilerCommand(const KVStoreServerProfilerCommand type,
+                                        const std::string& params) {
+    LOG(INFO) << "Unable to pass server the profiler command. If you are using "
+              << "distributed kvstore, you need to compile with USE_DIST_KVSTORE=1. "
+              << "If you are training on single machine, then there is no server process "
+              << "to profile. Please profile the worker process instead.";
+  }
+
   /**
    * \brief the prototype of a server controller
    */
diff --git a/python/mxnet/kvstore.py b/python/mxnet/kvstore.py
index 60973365975..a5481750139 100644
--- a/python/mxnet/kvstore.py
+++ b/python/mxnet/kvstore.py
@@ -28,6 +28,7 @@
 from .base import check_call, string_types, mx_uint, py_str
 from .base import NDArrayHandle, KVStoreHandle
 from . import optimizer as opt
+from .profiler import set_kvstore_handle
 
 def _ctype_key_value(keys, vals):
     """
@@ -88,7 +89,8 @@ def _get_kvstore_server_command_type(command):
                      'kSetMultiPrecision': 1,
                      'kStopServer': 2,
                      'kSyncMode': 3,
-                     'kSetGradientCompression': 4}
+                     'kSetGradientCompression': 4,
+                     'kSetProfilerParams': 5}
     assert (command in command_types), "Unknown command type to send to server"
     return command_types[command]
 
@@ -670,4 +672,6 @@ def create(name='local'):
     handle = KVStoreHandle()
     check_call(_LIB.MXKVStoreCreate(c_str(name),
                                     ctypes.byref(handle)))
-    return KVStore(handle)
+    kv = KVStore(handle)
+    set_kvstore_handle(kv.handle)
+    return kv
diff --git a/python/mxnet/profiler.py b/python/mxnet/profiler.py
index 0e7a31c687e..0b5e85b1eb5 100644
--- a/python/mxnet/profiler.py
+++ b/python/mxnet/profiler.py
@@ -22,8 +22,13 @@
 from __future__ import absolute_import
 import ctypes
 import warnings
-from .base import _LIB, check_call, c_str, ProfileHandle, c_str_array, py_str
+from .base import _LIB, check_call, c_str, ProfileHandle, c_str_array, py_str, KVStoreHandle
 
+profiler_kvstore_handle = KVStoreHandle()
+
+def set_kvstore_handle(handle):
+    global profiler_kvstore_handle
+    profiler_kvstore_handle = handle
 
 def set_config(**kwargs):
     """Set up the configure of profiler (only accepts keyword arguments).
@@ -49,12 +54,17 @@ def set_config(**kwargs):
     aggregate_stats : boolean,
         whether to maintain aggregate stats in memory for console
         dump.  Has some negative performance impact.
+    profile_process : string
+        whether to profile kvstore `server` or `worker`.
+        server can only be profiled when kvstore is of type dist.
+        if this is not passed, defaults to `worker`
     """
     kk = kwargs.keys()
     vv = kwargs.values()
-    check_call(_LIB.MXSetProfilerConfig(len(kwargs),
-                                        c_str_array([key for key in kk]),
-                                        c_str_array([str(val) for val in vv])))
+    check_call(_LIB.MXSetProcessProfilerConfig(len(kwargs),
+                                               c_str_array([key for key in kk]),
+                                               c_str_array([str(val) for val in vv]),
+                                               profiler_kvstore_handle))
 
 
 def profiler_set_config(mode='symbolic', filename='profile.json'):
@@ -73,10 +83,10 @@ def profiler_set_config(mode='symbolic', filename='profile.json'):
     keys = c_str_array([key for key in ["profile_" + mode, "filename"]])
     values = c_str_array([str(val) for val in [True, filename]])
     assert len(keys) == len(values)
-    check_call(_LIB.MXSetProfilerConfig(len(keys), keys, values))
+    check_call(_LIB.MXSetProcessProfilerConfig(len(keys), keys, values, profiler_kvstore_handle))
 
 
-def set_state(state='stop'):
+def set_state(state='stop', profile_process='worker'):
     """Set up the profiler state to 'run' or 'stop'.
 
     Parameters
@@ -84,9 +94,16 @@ def set_state(state='stop'):
     state : string, optional
         Indicates whether to run the profiler, can
         be 'stop' or 'run'. Default is `stop`.
+    profile_process : string
+        whether to profile kvstore `server` or `worker`.
+        server can only be profiled when kvstore is of type dist.
+        if this is not passed, defaults to `worker`
     """
     state2int = {'stop': 0, 'run': 1}
-    check_call(_LIB.MXSetProfilerState(ctypes.c_int(state2int[state])))
+    profile_process2int = {'worker': 0, 'server': 1}
+    check_call(_LIB.MXSetProcessProfilerState(ctypes.c_int(state2int[state]),
+                                              profile_process2int[profile_process],
+                                              profiler_kvstore_handle))
 
 
 def profiler_set_state(state='stop'):
@@ -102,7 +119,7 @@ def profiler_set_state(state='stop'):
                   'Please use profiler.set_state() instead')
     set_state(state)
 
-def dump(finished=True):
+def dump(finished=True, profile_process='worker'):
     """Dump profile and stop profiler. Use this to save profile
     in advance in case your program cannot exit normally.
 
@@ -111,9 +128,16 @@ def dump(finished=True):
     finished : boolean
         Indicates whether to stop statistic output (dumping) after this dump.
         Default is True
+    profile_process : string
+        whether to profile kvstore `server` or `worker`.
+        server can only be profiled when kvstore is of type dist.
+        if this is not passed, defaults to `worker`
     """
-    fin = 1 if finished is True else False
-    check_call(_LIB.MXDumpProfile(fin))
+    fin = 1 if finished else 0  # any truthy value (e.g. 1 or numpy.bool_) means finished
+    profile_process2int = {'worker': 0, 'server': 1}
+    check_call(_LIB.MXDumpProcessProfile(fin,
+                                         profile_process2int[profile_process],
+                                         profiler_kvstore_handle))
 
 
 def dump_profile():
@@ -138,14 +162,37 @@ def dumps(reset=False):
     return py_str(debug_str.value)
 
 
-def pause():
-    """Pause profiling."""
-    check_call(_LIB.MXProfilePause(int(1)))
+def pause(profile_process='worker'):
+    """Pause profiling.
+
+    Parameters
+    ----------
+    profile_process : string
+        whether to profile kvstore `server` or `worker`.
+        server can only be profiled when kvstore is of type dist.
+        if this is not passed, defaults to `worker`
+    """
+    profile_process2int = {'worker': 0, 'server': 1}
+    check_call(_LIB.MXProcessProfilePause(int(1),
+                                          profile_process2int[profile_process],
+                                          profiler_kvstore_handle))
+
 
+def resume(profile_process='worker'):
+    """
+    Resume paused profiling.
 
-def resume():
-    """Resume paused profiling."""
-    check_call(_LIB.MXProfilePause(int(0)))
+    Parameters
+    ----------
+    profile_process : string
+        whether to profile kvstore `server` or `worker`.
+        server can only be profiled when kvstore is of type dist.
+        if this is not passed, defaults to `worker`
+    """
+    profile_process2int = {'worker': 0, 'server': 1}
+    check_call(_LIB.MXProcessProfilePause(int(0),
+                                          profile_process2int[profile_process],
+                                          profiler_kvstore_handle))
 
 
 class Domain(object):
diff --git a/src/c_api/c_api_profile.cc b/src/c_api/c_api_profile.cc
index c5841775794..9c03b339e3c 100644
--- a/src/c_api/c_api_profile.cc
+++ b/src/c_api/c_api_profile.cc
@@ -29,6 +29,7 @@
 #include <dmlc/base.h>
 #include <dmlc/logging.h>
 #include <dmlc/thread_group.h>
+#include <mxnet/kvstore.h>
 #include <stack>
 #include "./c_api_common.h"
 #include "../profiler/profiler.h"
@@ -197,6 +198,10 @@ struct PythonProfileObjects {
 };
 static PythonProfileObjects python_profile_objects;
 
+enum class ProfileProcess {
+  kWorker, kServer
+};
+
 struct ProfileConfigParam : public dmlc::Parameter<ProfileConfigParam> {
   bool profile_all;
   bool profile_symbolic;
@@ -207,6 +212,7 @@ struct ProfileConfigParam : public dmlc::Parameter<ProfileConfigParam> {
   bool continuous_dump;
   float dump_period;
   bool aggregate_stats;
+  int profile_process;
   DMLC_DECLARE_PARAMETER(ProfileConfigParam) {
     DMLC_DECLARE_FIELD(profile_all).set_default(false)
       .describe("Profile all.");
@@ -228,6 +234,13 @@ struct ProfileConfigParam : public dmlc::Parameter<ProfileConfigParam> {
     DMLC_DECLARE_FIELD(aggregate_stats).set_default(false)
       .describe("Maintain aggregate stats, required for MXDumpAggregateStats.  Note that "
       "this can have anegative performance impact.");
+    DMLC_DECLARE_FIELD(profile_process)
+      .add_enum("worker", static_cast<int>(ProfileProcess::kWorker))
+      .add_enum("server", static_cast<int>(ProfileProcess::kServer))
+      .set_default(static_cast<int>(ProfileProcess::kWorker))
+      .describe("Specifies which process to profile: "
+                "worker: this is default. for single node training it should always be worker."
+                "server: for distributed training, this profiles server process");
   }
 };
 
@@ -248,7 +261,8 @@ struct ProfileMarkerScopeParam : public dmlc::Parameter<ProfileMarkerScopeParam>
 
 DMLC_REGISTER_PARAMETER(ProfileMarkerScopeParam);
 
-int MXSetProfilerConfig(int num_params, const char* const* keys, const char* const* vals) {
+int MXSetProcessProfilerConfig(int num_params, const char* const* keys, const char* const* vals,
+                               KVStoreHandle kvstoreHandle) {
     mxnet::IgnoreProfileCallScope ignore;
   API_BEGIN();
     std::vector<std::pair<std::string, std::string>> kwargs;
@@ -260,19 +274,37 @@ int MXSetProfilerConfig(int num_params, const char* const* keys, const char* con
     }
     ProfileConfigParam param;
     param.Init(kwargs);
-    int mode = 0;
-    if (param.profile_api || param.profile_all)        { mode |= profiler::Profiler::kAPI; }
-    if (param.profile_symbolic || param.profile_all)   { mode |= profiler::Profiler::kSymbolic; }
-    if (param.profile_imperative || param.profile_all) { mode |= profiler::Profiler::kImperative; }
-    if (param.profile_memory || param.profile_all)     { mode |= profiler::Profiler::kMemory; }
-    profiler::Profiler::Get()->SetConfig(profiler::Profiler::ProfilerMode(mode),
-                                         std::string(param.filename),
-                                         param.continuous_dump,
-                                         param.dump_period,
-                                         param.aggregate_stats);
+    if (static_cast<ProfileProcess>(param.profile_process) == ProfileProcess::kServer) {
+      std::ostringstream os;
+      for (int i = 0; i < num_params; ++i) {
+        // sent to servers as "k:v,k:v"; skip profile_process so it isn't forwarded again
+        if (strcmp(keys[i], "profile_process") == 0) continue;
+        if (os.tellp() > 0) os << ",";  // separator only between emitted pairs
+        os << keys[i] << ":" << vals[i];
+      }
+      CHECK(kvstoreHandle) << "KVStoreHandle passed to profiler is null";
+      static_cast<KVStore*>(kvstoreHandle)->SetServerProfilerCommand(
+          mxnet::KVStoreServerProfilerCommand::kSetConfig, os.str());
+    } else {
+      int mode = 0;
+      if (param.profile_api || param.profile_all)        { mode |= profiler::Profiler::kAPI; }
+      if (param.profile_symbolic || param.profile_all)   { mode |= profiler::Profiler::kSymbolic; }
+      if (param.profile_imperative ||
+          param.profile_all) { mode |= profiler::Profiler::kImperative; }
+      if (param.profile_memory || param.profile_all)     { mode |= profiler::Profiler::kMemory; }
+      profiler::Profiler::Get()->SetConfig(profiler::Profiler::ProfilerMode(mode),
+                                           std::string(param.filename),
+                                           param.continuous_dump,
+                                           param.dump_period,
+                                           param.aggregate_stats);
+    }
   API_END();
 }
 
+int MXSetProfilerConfig(int num_params, const char* const* keys, const char* const* vals) {
+  return MXSetProcessProfilerConfig(num_params, keys, vals, nullptr);
+}
+
 int MXAggregateProfileStatsPrint(const char **out_str, int reset) {
   MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
   API_BEGIN();
@@ -293,19 +325,40 @@ int MXAggregateProfileStatsPrint(const char **out_str, int reset) {
 }
 
 int MXDumpProfile(int finished) {
+  return MXDumpProcessProfile(finished, static_cast<int>(ProfileProcess::kWorker), nullptr);
+}
+
+int MXDumpProcessProfile(int finished, int profile_process, KVStoreHandle kvStoreHandle) {
   mxnet::IgnoreProfileCallScope ignore;
   API_BEGIN();
+  if (static_cast<ProfileProcess>(profile_process) == ProfileProcess::kServer) {
+    CHECK(kvStoreHandle) << "Kvstore Handle passed to profiler is null";
+    static_cast<KVStore*>(kvStoreHandle)->SetServerProfilerCommand(
+      mxnet::KVStoreServerProfilerCommand::kDump,
+      std::to_string(finished));
+  } else {
     profiler::Profiler *profiler = profiler::Profiler::Get();
     CHECK(profiler->IsEnableOutput())
       << "Profiler hasn't been run. Config and start profiler first";
     profiler->DumpProfile(finished != 0);
+  }
   API_END()
 }
 
 int MXSetProfilerState(int state) {
+  return MXSetProcessProfilerState(state, static_cast<int>(ProfileProcess::kWorker), nullptr);
+}
+
+int MXSetProcessProfilerState(int state, int profile_process, KVStoreHandle kvStoreHandle) {
   mxnet::IgnoreProfileCallScope ignore;
   // state, kNotRunning: 0, kRunning: 1
   API_BEGIN();
+  if (static_cast<ProfileProcess>(profile_process) == ProfileProcess::kServer) {
+    CHECK(kvStoreHandle) << "Kvstore Handle passed to profiler is null";
+    static_cast<KVStore*>(kvStoreHandle)->SetServerProfilerCommand(
+      mxnet::KVStoreServerProfilerCommand::kState,
+      std::to_string(state));
+  } else {
     switch (state) {
       case profiler::Profiler::kNotRunning:
         profiler::vtune::vtune_pause();
@@ -315,6 +368,7 @@ int MXSetProfilerState(int state) {
         break;
     }
     profiler::Profiler::Get()->SetState(profiler::Profiler::ProfilerState(state));
+  }
   API_END();
 }
 
@@ -450,8 +504,18 @@ int MXProfileDurationStop(ProfileHandle duration_handle) {
 }
 
 int MXProfilePause(int paused) {
+  return MXProcessProfilePause(paused, static_cast<int>(ProfileProcess::kWorker), nullptr);
+}
+
+int MXProcessProfilePause(int paused, int profile_process, KVStoreHandle kvStoreHandle) {
   mxnet::IgnoreProfileCallScope ignore;
   API_BEGIN();
+  if (static_cast<ProfileProcess>(profile_process) == ProfileProcess::kServer) {
+    CHECK(kvStoreHandle) << "Kvstore Handle passed to profiler is null";
+    static_cast<KVStore*>(kvStoreHandle)->SetServerProfilerCommand(
+      mxnet::KVStoreServerProfilerCommand::kPause,
+      std::to_string(paused));
+  } else {
     if (paused) {
       profiler::vtune::vtune_pause();
       profiler::Profiler::Get()->set_paused(true);
@@ -459,6 +523,7 @@ int MXProfilePause(int paused) {
       profiler::Profiler::Get()->set_paused(false);
       profiler::vtune::vtune_resume();
     }
+  }
   API_END();
 }
 
diff --git a/src/kvstore/gradient_compression.cc b/src/kvstore/gradient_compression.cc
index e94a0570d1f..e4a06fa9a1f 100644
--- a/src/kvstore/gradient_compression.cc
+++ b/src/kvstore/gradient_compression.cc
@@ -23,31 +23,14 @@
  * \author Rahul Huilgol
  */
 
-#include <sstream>
 #include <vector>
+#include "kvstore_local.h"
 #include "gradient_compression.h"
 #include "gradient_compression-inl.h"
 
 namespace mxnet {
 namespace kvstore {
 
-/*!
- * \brief Splits a string into smaller strings using char as delimiter
- * Example: "a,b,c,,d" is split into ["a","b","c","","d"]
- * \param s string to split
- * \param delim char to split string around
- * \param result container for tokens extracted after splitting
- */
-template<typename Out>
-void split(const std::string &s, const char delim, Out result) {
-  std::stringstream ss;
-  ss.str(s);
-  std::string item;
-  while (std::getline(ss, item, delim)) {
-    *(result++) = item;
-  }
-}
-
 DMLC_REGISTER_PARAMETER(GradientCompressionParam);
 
 GradientCompression::GradientCompression() {
@@ -90,7 +73,7 @@ std::string GradientCompression::EncodeParams() {
 
 void GradientCompression::DecodeParams(const std::string &s) {
   std::vector<std::string> elems;
-  split(s, ',', std::back_inserter(elems));
+  mxnet::kvstore::split(s, ',', std::back_inserter(elems));
   type_ = static_cast<CompressionType>(stoi(elems[0]));
   if (elems.size() > 1) {
     if (!elems[1].empty()) {
diff --git a/src/kvstore/kvstore_dist.h b/src/kvstore/kvstore_dist.h
index 7e2f5cb5faa..23fbf67474e 100644
--- a/src/kvstore/kvstore_dist.h
+++ b/src/kvstore/kvstore_dist.h
@@ -93,6 +93,15 @@ class KVStoreDist : public KVStoreLocal {
     }
   }
 
+  void SetServerProfilerCommand(const KVStoreServerProfilerCommand type,
+                                const std::string& params) override {
+    if (get_rank() == 0) {  // rank 0 only; servers decode the type from the body's last char
+      SendCommandToServers(static_cast<int>(CommandType::kSetProfilerParams),
+                           params + std::to_string(static_cast<int>(type)));
+    }
+  }
+
+
   void Barrier() override {
     ps::Postoffice::Get()->Barrier(ps_worker_->get_customer()->customer_id(), ps::kWorkerGroup);
   }
diff --git a/src/kvstore/kvstore_dist_server.h b/src/kvstore/kvstore_dist_server.h
index 451fb78a622..372b58dbbf3 100644
--- a/src/kvstore/kvstore_dist_server.h
+++ b/src/kvstore/kvstore_dist_server.h
@@ -24,6 +24,9 @@
  */
 #ifndef MXNET_KVSTORE_KVSTORE_DIST_SERVER_H_
 #define MXNET_KVSTORE_KVSTORE_DIST_SERVER_H_
+#include <mxnet/c_api.h>
+#include <mxnet/kvstore.h>
+#include <ps/ps.h>
 #include <queue>
 #include <string>
 #include <mutex>
@@ -32,8 +35,7 @@
 #include <functional>
 #include <future>
 #include <vector>
-#include "ps/ps.h"
-#include "mxnet/kvstore.h"
+#include "../profiler/profiler.h"
 #include "../operator/tensor/elemwise_binary_op-inl.h"
 #include "../operator/tensor/init_op.h"
 
@@ -42,7 +44,8 @@ namespace kvstore {
 
 // maintain same order in frontend.
 enum class CommandType {
-  kController, kSetMultiPrecision, kStopServer, kSyncMode, kSetGradientCompression,
+  kController, kSetMultiPrecision, kStopServer, kSyncMode,
+  kSetGradientCompression, kSetProfilerParams
 };
 
 enum class RequestType {
@@ -164,6 +167,7 @@ class KVStoreDistServer {
   }
 
   ~KVStoreDistServer() {
+    profiler::Profiler::Get()->SetState(profiler::Profiler::ProfilerState(0));
     delete ps_server_;
   }
 
@@ -194,27 +198,37 @@ class KVStoreDistServer {
 
   void CommandHandle(const ps::SimpleData& recved, ps::SimpleApp* app) {
     CommandType recved_type = static_cast<CommandType>(recved.head);
-    if (recved_type == CommandType::kStopServer) {
-      exec_.Stop();
-    } else if (recved_type == CommandType::kSyncMode) {
-      sync_mode_ = true;
-    } else if (recved_type == CommandType::kSetGradientCompression) {
-      gradient_compression_->DecodeParams(recved.body);
-    } else if (recved_type == CommandType::kSetMultiPrecision) {
-      // uses value 1 for message id from frontend
-      if (!multi_precision_) {
-        multi_precision_ = true;
-        CreateMultiPrecisionCopies();
-      }
-    } else if (recved_type == CommandType::kController) {
-      // value of 0
-      // let the main thread to execute ctrl, which is necessary for python
-      exec_.Exec([this, recved]() {
-          CHECK(controller_);
-          controller_(recved.head, recved.body);
-        });
-    } else {
-      LOG(FATAL) << "Unknown command type received " << recved.head;
+    switch (recved_type) {
+      case CommandType::kStopServer:
+        exec_.Stop();
+        break;
+      case CommandType::kSyncMode:
+        sync_mode_ = true;
+        break;
+      case CommandType::kSetGradientCompression:
+        gradient_compression_->DecodeParams(recved.body);
+        break;
+      case CommandType::kSetProfilerParams:
+        // body = profiler params followed by one trailing digit: the profiler command type
+        ProcessServerProfilerCommands(
+            static_cast<KVStoreServerProfilerCommand>(recved.body.back() - '0'), recved.body);
+        break;
+      case CommandType::kSetMultiPrecision:
+        // uses value 1 for message id from frontend
+        if (!multi_precision_) {
+          multi_precision_ = true;
+          CreateMultiPrecisionCopies();
+        }
+        break;
+      case CommandType::kController:
+        // value 0 from frontend; run ctrl on the main thread, which is necessary for python
+        exec_.Exec([this, recved]() {
+            CHECK(controller_);
+            controller_(recved.head, recved.body);
+          });
+        break;
+      default:
+        LOG(FATAL) << "Unknown command type received " << recved.head;
     }
     app->Response(recved);
   }
@@ -225,11 +239,11 @@ class KVStoreDistServer {
    * some keys are initialized before optimizer is set.
    */
   void CreateMultiPrecisionCopies() {
-    for (auto const& stored_entry : store_) {
+    for (auto const &stored_entry : store_) {
       const int key = stored_entry.first;
-      const NDArray& stored = stored_entry.second;
+      const NDArray &stored = stored_entry.second;
       if (stored.dtype() != mshadow::kFloat32) {
-        auto& stored_realt = store_realt_[key];
+        auto &stored_realt = store_realt_[key];
         if (stored.storage_type() == kRowSparseStorage) {
           stored_realt = NDArray(kRowSparseStorage, stored.shape(), stored.ctx(),
                                  true, mshadow::kFloat32);
@@ -237,7 +251,7 @@ class KVStoreDistServer {
           stored_realt = NDArray(stored.shape(), stored.ctx(), false, mshadow::kFloat32);
         }
 
-        auto& update = update_buf_[key];
+        auto &update = update_buf_[key];
         if (!update.merged.is_none()) {
           if (update.merged.storage_type() == kRowSparseStorage) {
             update.merged = NDArray(kRowSparseStorage, update.merged.shape(), update.merged.ctx(),
@@ -254,11 +268,60 @@ class KVStoreDistServer {
         CopyFromTo(stored, stored_realt);
       }
     }
-    for (auto const& stored_realt_entry : store_realt_) {
+    for (auto const &stored_realt_entry : store_realt_) {
       stored_realt_entry.second.WaitToRead();
     }
   }
 
+  /*!
+   * \brief Runs a profiler command that a worker sent to this server process.
+   * \param type command decoded by the caller from the last character of body
+   * \param body raw payload: the command argument followed by the one-character command id
+   */
+  void ProcessServerProfilerCommands(KVStoreServerProfilerCommand type, const std::string& body) {
+    // State/pause/dump carry a single decimal digit ahead of the trailing command id;
+    // verify it is actually present before reading it.
+    auto digit_arg = [&body]() {
+      CHECK_GE(body.size(), 2U) << "Profiler command from worker is missing its argument";
+      return static_cast<int>(body.front() - '0');
+    };
+    switch (type) {
+      case KVStoreServerProfilerCommand::kSetConfig:
+        // drop the trailing command id, leaving the "key:value,..." config string
+        SetProfilerConfig(body.substr(0, body.size() - 1));
+        break;
+      case KVStoreServerProfilerCommand::kState:
+        MXSetProfilerState(digit_arg());
+        break;
+      case KVStoreServerProfilerCommand::kPause:
+        MXProfilePause(digit_arg());
+        break;
+      case KVStoreServerProfilerCommand::kDump:
+        MXDumpProfile(digit_arg());
+        break;
+      default:
+        // The caller builds `type` with an unchecked cast of a received character,
+        // so it may not name any known command; fail loudly instead of ignoring it.
+        LOG(FATAL) << "Unknown server profiler command " << static_cast<int>(type);
+    }
+  }
+
+  /*!
+   * \brief Applies a profiler configuration that a worker sent to this server process.
+   * \param params_str comma-separated "key:value" pairs
+   * A "filename" value is prefixed with "rank<server rank>_" so each server writes
+   * to a distinct file. Malformed pairs, or a configuration the profiler rejects,
+   * abort via CHECK.
+   */
+  void SetProfilerConfig(std::string params_str) {
+    std::vector<std::string> elems;
+    mxnet::kvstore::split(params_str, ',', std::back_inserter(elems));
+    // Own the key/value text in std::strings; the C API only borrows the pointers.
+    // This replaces manual new[]/delete[], which leaked whenever a CHECK below threw.
+    std::vector<std::string> keys;
+    std::vector<std::string> vals;
+    keys.reserve(elems.size());
+    vals.reserve(elems.size());
+
+    for (const std::string &elem : elems) {
+      std::vector<std::string> parts;
+      mxnet::kvstore::split(elem, ':', std::back_inserter(parts));
+      CHECK_EQ(parts.size(), 2U) << "Improper profiler config passed from worker";
+      CHECK(!parts[0].empty()) << "ProfilerConfig parameter is empty";
+      CHECK(!parts[1].empty()) << "ProfilerConfig value is empty for parameter "<< parts[0];
+      if (parts[0] == "filename") {
+        parts[1] = "rank" + std::to_string(ps::MyRank()) + "_" + parts[1];
+      }
+      keys.push_back(parts[0]);
+      vals.push_back(parts[1]);
+    }
+
+    // Take c_str() pointers only after keys/vals are final, so no reallocation can
+    // invalidate them. data() is also well-defined for an empty configuration,
+    // unlike &ckeys[0].
+    std::vector<const char*> ckeys;
+    std::vector<const char*> cvals;
+    ckeys.reserve(keys.size());
+    cvals.reserve(vals.size());
+    for (size_t i = 0; i < keys.size(); ++i) {
+      ckeys.push_back(keys[i].c_str());
+      cvals.push_back(vals[i].c_str());
+    }
+    // MXNet C API calls return non-zero on failure instead of throwing; surface it.
+    CHECK_EQ(MXSetProfilerConfig(static_cast<int>(ckeys.size()), ckeys.data(), cvals.data()), 0)
+        << MXGetLastError();
+  }
+
   void DataHandleEx(const ps::KVMeta& req_meta,
                     const ps::KVPairs<char>& req_data,
                     ps::KVServer<char>* server) {
diff --git a/src/kvstore/kvstore_local.h b/src/kvstore/kvstore_local.h
index 324bc2c9558..4e004a3a300 100644
--- a/src/kvstore/kvstore_local.h
+++ b/src/kvstore/kvstore_local.h
@@ -40,6 +40,22 @@
 
 namespace mxnet {
 namespace kvstore {
+/*!
+ * \brief Breaks a string into tokens separated by a single delimiter character.
+ * Empty tokens between consecutive delimiters are kept, but a trailing delimiter
+ * does not yield a final empty token, and an empty input yields no tokens.
+ * Example: "a,b,c,,d" becomes ["a","b","c","","d"]; "a," becomes ["a"]
+ * \param s string to split
+ * \param delim character separating the tokens
+ * \param result output iterator that receives each token in order
+ */
+template<typename Out>
+void split(const std::string &s, const char delim, Out result) {
+  std::istringstream stream(s);
+  for (std::string token; std::getline(stream, token, delim);) {
+    *result++ = token;
+  }
+}
 
 enum KeyType {
   kUndefinedKey = -1,
diff --git a/tests/nightly/test_server_profiling.py b/tests/nightly/test_server_profiling.py
new file mode 100644
index 00000000000..7d157a3e418
--- /dev/null
+++ b/tests/nightly/test_server_profiling.py
@@ -0,0 +1,69 @@
+#!/usr/bin/env python
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import mxnet as mx
+import json
+
+# Key and array shape pushed/pulled by every worker in this test.
+key = '99'
+shape = (1200, 1200)        # bigger than MXNET_KVSTORE_BIGARRAY_BOUND
+# Created at import time; kv.rank / kv.num_workers are also read by the __main__ section.
+kv = mx.kv.create('dist_sync')
+
+def init_kv():
+    # init kv dns keys
+    kv.init(key, mx.nd.ones(shape))
+    kv.set_optimizer(mx.optimizer.create('sgd'))
+    return kv, kv.rank, kv.num_workers
+
+def test_sync_push_pull():
+    kv, my_rank, nworker = init_kv()
+    def check_default_keys(kv, my_rank):
+        nrepeat = 10
+        # checks pull after push in loop, because behavior during
+        # consecutive pushes doesn't offer any guarantees
+        for i in range(nrepeat):
+            kv.push(key, mx.nd.ones(shape, dtype='float32') * (my_rank+1))
+            val = mx.nd.zeros(shape, dtype='float32')
+            kv.pull(key, out=val)
+            mx.nd.waitall()
+    check_default_keys(kv, my_rank)
+
+if __name__ == "__main__":
+    server_filename_suffix = 'test_profile_server.json'
+    worker_filename_suffix = 'test_profile_worker.json'
+    mx.profiler.set_config(filename=server_filename_suffix, profile_all=True, profile_process='server')
+    mx.profiler.set_config(filename='rank' + str(kv.rank) + '_' + worker_filename_suffix, profile_all=True, profile_process='worker')
+    mx.profiler.set_state(state='run', profile_process='server')
+    mx.profiler.set_state(state='run', profile_process='worker')
+    test_sync_push_pull()
+    mx.profiler.set_state(state='stop', profile_process='server')
+    mx.profiler.set_state(state='stop', profile_process='worker')
+
+    import glob, os
+
+    # will only work when launcher mode is local, as used for integration test
+    if kv.rank == 0:
+        for rank in range(kv.num_workers):
+            for suffix in [worker_filename_suffix, server_filename_suffix]:
+                # throws value error if file is not proper json
+                filename = 'rank' + str(rank) + '_' + suffix
+                print(glob.glob('*'), os.getcwd())
+                with open(filename, 'r') as f:
+                    j = json.load(f)
+
+
+


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services