You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/06/09 02:00:54 UTC

[GitHub] rahul003 closed pull request #9933: [MXNET-23] Adding support to profile kvstore server during distributed training

rahul003 closed pull request #9933: [MXNET-23] Adding support to profile kvstore server during distributed training
URL: https://github.com/apache/incubator-mxnet/pull/9933
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

Because this pull request comes from a fork, GitHub does not display its
diff after the merge, so the diff is reproduced below:

diff --git a/example/image-classification/common/fit.py b/example/image-classification/common/fit.py
index 3f37ad3ac59..1baab606db4 100755
--- a/example/image-classification/common/fit.py
+++ b/example/image-classification/common/fit.py
@@ -135,6 +135,12 @@ def add_fit_args(parser):
                        help='the epochs to ramp-up lr to scaled large-batch value')
     train.add_argument('--warmup-strategy', type=str, default='linear',
                        help='the ramping-up strategy for large batch sgd')
+    train.add_argument('--profile-worker-suffix', type=str, default='',
+                       help='profile workers actions into this file. During distributed training\
+                             filename saved will be rank1_ followed by this suffix')
+    train.add_argument('--profile-server-suffix', type=str, default='',
+                       help='profile server actions into a file with name like rank1_ followed by this suffix \
+                             during distributed training')
     return train
 
 
@@ -150,6 +156,17 @@ def fit(args, network, data_loader, **kwargs):
     if args.gc_type != 'none':
         kv.set_gradient_compression({'type': args.gc_type,
                                      'threshold': args.gc_threshold})
+    if args.profile_server_suffix:
+        mx.profiler.set_config(filename=args.profile_server_suffix, profile_all=True, profile_process='server')
+        mx.profiler.set_state(state='run', profile_process='server')
+
+    if args.profile_worker_suffix:
+        if kv.num_workers > 1:
+            filename = 'rank' + str(kv.rank) + '_' + args.profile_worker_suffix
+        else:
+            filename = args.profile_worker_suffix
+        mx.profiler.set_config(filename=filename, profile_all=True, profile_process='worker')
+        mx.profiler.set_state(state='run', profile_process='worker')
 
     # logging
     head = '%(asctime)-15s Node[' + str(kv.rank) + '] %(message)s'
@@ -175,7 +192,6 @@ def fit(args, network, data_loader, **kwargs):
                 logging.info('Batch [%d]\tSpeed: %.2f samples/sec', i,
                              args.disp_batches * args.batch_size / (time.time() - tic))
                 tic = time.time()
-
         return
 
     # load model
@@ -309,3 +325,8 @@ def fit(args, network, data_loader, **kwargs):
               epoch_end_callback=checkpoint,
               allow_missing=True,
               monitor=monitor)
+
+    if args.profile_server_suffix:
+        mx.profiler.set_state(state='run', profile_process='server')
+    if args.profile_worker_suffix:
+        mx.profiler.set_state(state='run', profile_process='worker')
\ No newline at end of file
diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index 6b7cf4407ed..d95f7dd8e8c 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -229,6 +229,17 @@ MXNET_DLL int MXRandomSeedContext(int seed, int dev_type, int dev_id);
  */
 MXNET_DLL int MXNotifyShutdown();
 
+/*!
+ * \brief Set up configuration of profiler
+ * \param num_params Number of parameters
+ * \param keys array of parameter keys
+ * \param vals array of parameter values
+ * \param kvstoreHandle handle to kvstore
+ * \return 0 when success, -1 when failure happens.
+ */
+MXNET_DLL int MXSetProfilerConfig(int num_params, const char* const* keys, const char* const* vals,
+                                  KVStoreHandle kvstoreHandle);
+
 /*!
  * \brief Set up configuration of profiler
  * \param num_params Number of parameters
@@ -243,18 +254,48 @@ MXNET_DLL int MXSetProfilerConfig(int num_params, const char* const* keys, const
  * \param state indicate the working state of profiler,
  *  profiler not running when state == 0,
  *  profiler running when state == 1
+ * \param profile_process an int,
+ * when 0 command is for worker process,
+ * when 1 command is for server process
+ * \param kvstoreHandle handle to kvstore
+ * \return 0 when success, -1 when failure happens.
+ */
+MXNET_DLL int MXSetProfilerState(int state, int profile_process, KVStoreHandle kvStoreHandle);
+
+/*!
+ * \brief Set up state of profiler
+ * \param state indicate the working state of profiler,
+ *  profiler not running when state == 0,
+ *  profiler running when state == 1
+ * \param profile_process an int,
+ * when 0 command is for worker process,
+ * when 1 command is for server process
  * \return 0 when success, -1 when failure happens.
  */
-MXNET_DLL int MXSetProfilerState(int state);
+MXNET_DLL int MXSetProfilerState(int state, int profile_process);
 
 /*!
  * \brief Save profile and stop profiler
  * \param finished true if stat output should stop after this point
+ * \param profile_process an int,
+ * when 0 command is for worker process,
+ * when 1 command is for server process
+ * \param kvstoreHandle handle to kvstore
  * \return 0 when success, -1 when failure happens.
  */
-MXNET_DLL int MXDumpProfile(int finished);
+MXNET_DLL int MXDumpProfile(int finished, int profile_process, KVStoreHandle kvStoreHandle);
 
 
+/*!
+ * \brief Save profile and stop profiler
+ * \param finished true if stat output should stop after this point
+ * \param profile_process an int,
+ * when 0 command is for worker process,
+ * when 1 command is for server process
+ * \return 0 when success, -1 when failure happens.
+ */
+MXNET_DLL int MXDumpProfile(int finished, int profile_process);
+
 /*!
  * \brief Print aggregate stats to the a string
  * \param out_str Will receive a pointer to the output string
@@ -267,10 +308,21 @@ MXNET_DLL int MXAggregateProfileStatsPrint(const char **out_str, int reset);
 /*!
  * \brief Pause profiler tuning collection
  * \param paused If nonzero, profiling pauses. Otherwise, profiling resumes/continues
+ * \param profile_process integer which denotes whether to process worker or server process
+ * \param kvstoreHandle handle to kvstore
  * \return 0 when success, -1 when failure happens.
  * \note pausing and resuming is global and not recursive
  */
-MXNET_DLL int MXProfilePause(int paused);
+MXNET_DLL int MXProfilePause(int paused, int profile_process, KVStoreHandle kvStoreHandle);
+
+/*!
+ * \brief Pause profiler tuning collection
+ * \param paused If nonzero, profiling pauses. Otherwise, profiling resumes/continues
+ * \param profile_process integer which denotes whether to process worker or server process
+ * \return 0 when success, -1 when failure happens.
+ * \note pausing and resuming is global and not recursive
+ */
+MXNET_DLL int MXProfilePause(int paused, int profile_process);
 
 /*!
  * \brief Create profiling domain
@@ -2089,8 +2141,7 @@ typedef void (MXKVStoreServerController)(int head,
                                          void *controller_handle);
 
 /**
- * \return Run as server (or scheduler)
- *
+ * \brief Run as server (or scheduler)
  * \param handle handle to the KVStore
  * \param controller the user-defined server controller
  * \param controller_handle helper handle for implementing controller
@@ -2101,8 +2152,7 @@ MXNET_DLL int MXKVStoreRunServer(KVStoreHandle handle,
                                  void *controller_handle);
 
 /**
- * \return Send a command to all server nodes
- *
+ * \brief Send a command to all server nodes
  * \param handle handle to the KVStore
  * \param cmd_id the head of the command
  * \param cmd_body the body of the command
diff --git a/include/mxnet/kvstore.h b/include/mxnet/kvstore.h
index 4e99a9c861f..7be8ef995ea 100644
--- a/include/mxnet/kvstore.h
+++ b/include/mxnet/kvstore.h
@@ -38,6 +38,18 @@
 #endif  // MXNET_USE_DIST_KVSTORE
 
 namespace mxnet {
+
+/*!
+ * \brief enum to denote types of commands kvstore sends to server regarding profiler
+ * kSetConfig sets profiler configs. Similar to mx.profiler.set_config()
+ * kState allows changing state of profiler to stop or run
+ * kPause allows pausing and resuming of profiler
+ * kDump asks profiler to dump output
+ */
+enum class KVStoreServerProfilerCommand {
+  kSetConfig, kState, kPause, kDump
+};
+
 /*!
  * \brief distributed key-value store
  *
@@ -361,6 +373,20 @@ class KVStore {
    */
   virtual void SendCommandToServers(int cmd_id, const std::string& cmd_body) { }
 
+  /**
+   * \brief Sends server profiler commands to all server nodes
+   * Only the worker with rank=0 sends the command which will be received by all servers
+   * \param type ProfilerCommand type
+   * \param params parameters for that command in the form of a string
+   */
+  virtual void SetServerProfilerCommand(const KVStoreServerProfilerCommand type,
+                                        const std::string& params) {
+    LOG(INFO) << "Unable to pass server the profiler command. If you are using "
+              << "distributed kvstore, you need to compile with USE_DIST_KVSTORE=1."
+              << "If you are training on single machine, then there is no server process"
+              << "to profile. Please profile the worker process instead.";
+  }
+
   /**
    * \brief the prototype of a server controller
    */
diff --git a/python/mxnet/kvstore.py b/python/mxnet/kvstore.py
index f31dac01cd1..850314c8612 100644
--- a/python/mxnet/kvstore.py
+++ b/python/mxnet/kvstore.py
@@ -28,6 +28,7 @@
 from .base import check_call, string_types, mx_uint, py_str
 from .base import NDArrayHandle, KVStoreHandle
 from . import optimizer as opt
+from .profiler import set_kvstore_handle
 
 def _ctype_key_value(keys, vals):
     """
@@ -88,7 +89,8 @@ def _get_kvstore_server_command_type(command):
                      'kSetMultiPrecision': 1,
                      'kStopServer': 2,
                      'kSyncMode': 3,
-                     'kSetGradientCompression': 4}
+                     'kSetGradientCompression': 4,
+                     'kSetProfilerParams': 5}
     assert (command in command_types), "Unknown command type to send to server"
     return command_types[command]
 
@@ -665,4 +667,6 @@ def create(name='local'):
     handle = KVStoreHandle()
     check_call(_LIB.MXKVStoreCreate(c_str(name),
                                     ctypes.byref(handle)))
-    return KVStore(handle)
+    kv = KVStore(handle)
+    set_kvstore_handle(kv.handle)
+    return kv
diff --git a/python/mxnet/profiler.py b/python/mxnet/profiler.py
index 0e7a31c687e..aa1ea38a340 100644
--- a/python/mxnet/profiler.py
+++ b/python/mxnet/profiler.py
@@ -22,8 +22,13 @@
 from __future__ import absolute_import
 import ctypes
 import warnings
-from .base import _LIB, check_call, c_str, ProfileHandle, c_str_array, py_str
+from .base import _LIB, check_call, c_str, ProfileHandle, c_str_array, py_str, KVStoreHandle
 
+profiler_kvstore_handle = KVStoreHandle()
+
+def set_kvstore_handle(handle):
+    global profiler_kvstore_handle
+    profiler_kvstore_handle = handle
 
 def set_config(**kwargs):
     """Set up the configure of profiler (only accepts keyword arguments).
@@ -49,12 +54,17 @@ def set_config(**kwargs):
     aggregate_stats : boolean,
         whether to maintain aggregate stats in memory for console
         dump.  Has some negative performance impact.
+    profile_process : string
+        whether to profile kvstore `server` or `worker`.
+        server can only be profiled when kvstore is of type dist.
+        if this is not passed, defaults to `worker`
     """
     kk = kwargs.keys()
     vv = kwargs.values()
     check_call(_LIB.MXSetProfilerConfig(len(kwargs),
                                         c_str_array([key for key in kk]),
-                                        c_str_array([str(val) for val in vv])))
+                                        c_str_array([str(val) for val in vv]),
+                                        profiler_kvstore_handle))
 
 
 def profiler_set_config(mode='symbolic', filename='profile.json'):
@@ -73,10 +83,10 @@ def profiler_set_config(mode='symbolic', filename='profile.json'):
     keys = c_str_array([key for key in ["profile_" + mode, "filename"]])
     values = c_str_array([str(val) for val in [True, filename]])
     assert len(keys) == len(values)
-    check_call(_LIB.MXSetProfilerConfig(len(keys), keys, values))
+    check_call(_LIB.MXSetProfilerConfig(len(keys), keys, values, profiler_kvstore_handle))
 
 
-def set_state(state='stop'):
+def set_state(state='stop', profile_process='worker'):
     """Set up the profiler state to 'run' or 'stop'.
 
     Parameters
@@ -84,9 +94,16 @@ def set_state(state='stop'):
     state : string, optional
         Indicates whether to run the profiler, can
         be 'stop' or 'run'. Default is `stop`.
+    profile_process : string
+        whether to profile kvstore `server` or `worker`.
+        server can only be profiled when kvstore is of type dist.
+        if this is not passed, defaults to `worker`
     """
     state2int = {'stop': 0, 'run': 1}
-    check_call(_LIB.MXSetProfilerState(ctypes.c_int(state2int[state])))
+    profile_process2int = {'worker': 0, 'server': 1}
+    check_call(_LIB.MXSetProfilerState(ctypes.c_int(state2int[state]),
+                                       profile_process2int[profile_process],
+                                       profiler_kvstore_handle))
 
 
 def profiler_set_state(state='stop'):
@@ -102,7 +119,7 @@ def profiler_set_state(state='stop'):
                   'Please use profiler.set_state() instead')
     set_state(state)
 
-def dump(finished=True):
+def dump(finished=True, profile_process='worker'):
     """Dump profile and stop profiler. Use this to save profile
     in advance in case your program cannot exit normally.
 
@@ -111,9 +128,16 @@ def dump(finished=True):
     finished : boolean
         Indicates whether to stop statistic output (dumping) after this dump.
         Default is True
+    profile_process : string
+        whether to profile kvstore `server` or `worker`.
+        server can only be profiled when kvstore is of type dist.
+        if this is not passed, defaults to `worker`
     """
-    fin = 1 if finished is True else False
-    check_call(_LIB.MXDumpProfile(fin))
+    fin = 1 if finished is True else 0
+    profile_process2int = {'worker': 0, 'server': 1}
+    check_call(_LIB.MXDumpProfile(fin,
+                                  profile_process2int[profile_process],
+                                  profiler_kvstore_handle))
 
 
 def dump_profile():
@@ -138,14 +162,37 @@ def dumps(reset=False):
     return py_str(debug_str.value)
 
 
-def pause():
-    """Pause profiling."""
-    check_call(_LIB.MXProfilePause(int(1)))
+def pause(profile_process='worker'):
+    """Pause profiling.
+
+    Parameters
+    ----------
+    profile_process : string
+        whether to profile kvstore `server` or `worker`.
+        server can only be profiled when kvstore is of type dist.
+        if this is not passed, defaults to `worker`
+    """
+    profile_process2int = {'worker': 0, 'server': 1}
+    check_call(_LIB.MXProfilePause(int(1),
+                                   profile_process2int[profile_process],
+                                   profiler_kvstore_handle))
+
 
+def resume(profile_process='worker'):
+    """
+    Resume paused profiling.
 
-def resume():
-    """Resume paused profiling."""
-    check_call(_LIB.MXProfilePause(int(0)))
+    Parameters
+    ----------
+    profile_process : string
+        whether to profile kvstore `server` or `worker`.
+        server can only be profiled when kvstore is of type dist.
+        if this is not passed, defaults to `worker`
+    """
+    profile_process2int = {'worker': 0, 'server': 1}
+    check_call(_LIB.MXProfilePause(int(0),
+                                   profile_process2int[profile_process],
+                                   profiler_kvstore_handle))
 
 
 class Domain(object):
diff --git a/src/c_api/c_api_profile.cc b/src/c_api/c_api_profile.cc
index c5841775794..d023a50542b 100644
--- a/src/c_api/c_api_profile.cc
+++ b/src/c_api/c_api_profile.cc
@@ -29,6 +29,7 @@
 #include <dmlc/base.h>
 #include <dmlc/logging.h>
 #include <dmlc/thread_group.h>
+#include <mxnet/kvstore.h>
 #include <stack>
 #include "./c_api_common.h"
 #include "../profiler/profiler.h"
@@ -197,6 +198,10 @@ struct PythonProfileObjects {
 };
 static PythonProfileObjects python_profile_objects;
 
+enum class ProfileProcess {
+  kWorker, kServer
+};
+
 struct ProfileConfigParam : public dmlc::Parameter<ProfileConfigParam> {
   bool profile_all;
   bool profile_symbolic;
@@ -207,6 +212,7 @@ struct ProfileConfigParam : public dmlc::Parameter<ProfileConfigParam> {
   bool continuous_dump;
   float dump_period;
   bool aggregate_stats;
+  int profile_process;
   DMLC_DECLARE_PARAMETER(ProfileConfigParam) {
     DMLC_DECLARE_FIELD(profile_all).set_default(false)
       .describe("Profile all.");
@@ -228,6 +234,13 @@ struct ProfileConfigParam : public dmlc::Parameter<ProfileConfigParam> {
     DMLC_DECLARE_FIELD(aggregate_stats).set_default(false)
       .describe("Maintain aggregate stats, required for MXDumpAggregateStats.  Note that "
       "this can have anegative performance impact.");
+    DMLC_DECLARE_FIELD(profile_process)
+      .add_enum("worker", static_cast<int>(ProfileProcess::kWorker))
+      .add_enum("server", static_cast<int>(ProfileProcess::kServer))
+      .set_default(static_cast<int>(ProfileProcess::kWorker))
+      .describe("Specifies which process to profile: "
+                "worker: this is default. for single node training it should always be worker."
+                "server: for distributed training, this profiles server process");
   }
 };
 
@@ -248,7 +261,8 @@ struct ProfileMarkerScopeParam : public dmlc::Parameter<ProfileMarkerScopeParam>
 
 DMLC_REGISTER_PARAMETER(ProfileMarkerScopeParam);
 
-int MXSetProfilerConfig(int num_params, const char* const* keys, const char* const* vals) {
+int MXSetProfilerConfig(int num_params, const char* const* keys, const char* const* vals,
+                        KVStoreHandle kvstoreHandle) {
     mxnet::IgnoreProfileCallScope ignore;
   API_BEGIN();
     std::vector<std::pair<std::string, std::string>> kwargs;
@@ -260,19 +274,36 @@ int MXSetProfilerConfig(int num_params, const char* const* keys, const char* con
     }
     ProfileConfigParam param;
     param.Init(kwargs);
-    int mode = 0;
-    if (param.profile_api || param.profile_all)        { mode |= profiler::Profiler::kAPI; }
-    if (param.profile_symbolic || param.profile_all)   { mode |= profiler::Profiler::kSymbolic; }
-    if (param.profile_imperative || param.profile_all) { mode |= profiler::Profiler::kImperative; }
-    if (param.profile_memory || param.profile_all)     { mode |= profiler::Profiler::kMemory; }
-    profiler::Profiler::Get()->SetConfig(profiler::Profiler::ProfilerMode(mode),
-                                         std::string(param.filename),
-                                         param.continuous_dump,
-                                         param.dump_period,
-                                         param.aggregate_stats);
+    if (static_cast<ProfileProcess>(param.profile_process) == ProfileProcess::kServer) {
+      std::ostringstream os;
+      for (int i = 0; i < num_params; ++i) {
+        // this will be sent to the server now, those configs shouldn't have profile server again
+        if (strcmp(keys[i], "profile_process") == 0) continue;
+        os << keys[i] << ":" << vals[i];
+        if (i != num_params - 1) os << ",";
+      }
+      CHECK(kvstoreHandle) << "KVStoreHandle passed to profiler is null";
+      static_cast<KVStore*>(kvstoreHandle)->SetServerProfilerCommand(
+      mxnet::KVStoreServerProfilerCommand::kSetConfig, os.str());
+    } else {
+      int mode = 0;
+      if (param.profile_api || param.profile_all)        { mode |= profiler::Profiler::kAPI; }
+      if (param.profile_symbolic || param.profile_all)   { mode |= profiler::Profiler::kSymbolic; }
+      if (param.profile_imperative || param.profile_all) { mode |= profiler::Profiler::kImperative; }
+      if (param.profile_memory || param.profile_all)     { mode |= profiler::Profiler::kMemory; }
+      profiler::Profiler::Get()->SetConfig(profiler::Profiler::ProfilerMode(mode),
+                                           std::string(param.filename),
+                                           param.continuous_dump,
+                                           param.dump_period,
+                                           param.aggregate_stats);
+    }
   API_END();
 }
 
+int MXSetProfilerConfig(int num_params, const char* const* keys, const char* const* vals) {
+  MXSetProfilerConfig(num_params, keys, vals, nullptr);
+}
+
 int MXAggregateProfileStatsPrint(const char **out_str, int reset) {
   MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
   API_BEGIN();
@@ -292,29 +323,51 @@ int MXAggregateProfileStatsPrint(const char **out_str, int reset) {
   API_END();
 }
 
-int MXDumpProfile(int finished) {
+int MXDumpProfile(int finished, int profile_process) {
+  MXDumpProfile(finished, profile_process, nullptr);
+}
+
+int MXDumpProfile(int finished, int profile_process, KVStoreHandle kvStoreHandle) {
   mxnet::IgnoreProfileCallScope ignore;
   API_BEGIN();
-    profiler::Profiler *profiler = profiler::Profiler::Get();
-    CHECK(profiler->IsEnableOutput())
-      << "Profiler hasn't been run. Config and start profiler first";
-    profiler->DumpProfile(finished != 0);
+    if (static_cast<ProfileProcess>(profile_process) == ProfileProcess::kServer) {
+      CHECK(kvStoreHandle) << "Kvstore Handle passed to profiler is null";
+      static_cast<KVStore*>(kvStoreHandle)->SetServerProfilerCommand(
+        mxnet::KVStoreServerProfilerCommand::kDump,
+        std::to_string(finished));
+    } else {
+      profiler::Profiler *profiler = profiler::Profiler::Get();
+      CHECK(profiler->IsEnableOutput())
+        << "Profiler hasn't been run. Config and start profiler first";
+      profiler->DumpProfile(finished != 0);
+    }
   API_END()
 }
 
-int MXSetProfilerState(int state) {
+int MXSetProfilerState(int state, int profile_process) {
+  MXSetProfilerState(state, profile_process, nullptr);
+}
+
+int MXSetProfilerState(int state, int profile_process, KVStoreHandle kvStoreHandle) {
   mxnet::IgnoreProfileCallScope ignore;
   // state, kNotRunning: 0, kRunning: 1
   API_BEGIN();
-    switch (state) {
-      case profiler::Profiler::kNotRunning:
-        profiler::vtune::vtune_pause();
-        break;
-      case profiler::Profiler::kRunning:
-        profiler::vtune::vtune_resume();
-        break;
+    if (static_cast<ProfileProcess>(profile_process) == ProfileProcess::kServer) {
+      CHECK(kvStoreHandle) << "Kvstore Handle passed to profiler is null";
+      static_cast<KVStore*>(kvStoreHandle)->SetServerProfilerCommand(
+        mxnet::KVStoreServerProfilerCommand::kState,
+        std::to_string(state));
+    } else {
+      switch (state) {
+        case profiler::Profiler::kNotRunning:
+          profiler::vtune::vtune_pause();
+          break;
+        case profiler::Profiler::kRunning:
+          profiler::vtune::vtune_resume();
+          break;
+      }
+      profiler::Profiler::Get()->SetState(profiler::Profiler::ProfilerState(state));
     }
-    profiler::Profiler::Get()->SetState(profiler::Profiler::ProfilerState(state));
   API_END();
 }
 
@@ -449,15 +502,26 @@ int MXProfileDurationStop(ProfileHandle duration_handle) {
   API_END();
 }
 
-int MXProfilePause(int paused) {
+int MXProfilePause(int paused, int profile_process) {
+  MXProfilePause(paused, profile_process, nullptr);
+}
+
+int MXProfilePause(int paused, int profile_process, KVStoreHandle kvStoreHandle) {
   mxnet::IgnoreProfileCallScope ignore;
   API_BEGIN();
-    if (paused) {
-      profiler::vtune::vtune_pause();
-      profiler::Profiler::Get()->set_paused(true);
+    if (static_cast<ProfileProcess>(profile_process) == ProfileProcess::kServer) {
+      CHECK(kvStoreHandle) << "Kvstore Handle passed to profiler is null";
+      static_cast<KVStore*>(kvStoreHandle)->SetServerProfilerCommand(
+        mxnet::KVStoreServerProfilerCommand::kPause,
+        std::to_string(paused));
     } else {
-      profiler::Profiler::Get()->set_paused(false);
-      profiler::vtune::vtune_resume();
+      if (paused) {
+        profiler::vtune::vtune_pause();
+        profiler::Profiler::Get()->set_paused(true);
+      } else {
+        profiler::Profiler::Get()->set_paused(false);
+        profiler::vtune::vtune_resume();
+      }
     }
   API_END();
 }
diff --git a/src/kvstore/gradient_compression.cc b/src/kvstore/gradient_compression.cc
index e94a0570d1f..e4a06fa9a1f 100644
--- a/src/kvstore/gradient_compression.cc
+++ b/src/kvstore/gradient_compression.cc
@@ -23,31 +23,14 @@
  * \author Rahul Huilgol
  */
 
-#include <sstream>
 #include <vector>
+#include "kvstore_local.h"
 #include "gradient_compression.h"
 #include "gradient_compression-inl.h"
 
 namespace mxnet {
 namespace kvstore {
 
-/*!
- * \brief Splits a string into smaller strings using char as delimiter
- * Example: "a,b,c,,d" is split into ["a","b","c","","d"]
- * \param s string to split
- * \param delim char to split string around
- * \param result container for tokens extracted after splitting
- */
-template<typename Out>
-void split(const std::string &s, const char delim, Out result) {
-  std::stringstream ss;
-  ss.str(s);
-  std::string item;
-  while (std::getline(ss, item, delim)) {
-    *(result++) = item;
-  }
-}
-
 DMLC_REGISTER_PARAMETER(GradientCompressionParam);
 
 GradientCompression::GradientCompression() {
@@ -90,7 +73,7 @@ std::string GradientCompression::EncodeParams() {
 
 void GradientCompression::DecodeParams(const std::string &s) {
   std::vector<std::string> elems;
-  split(s, ',', std::back_inserter(elems));
+  mxnet::kvstore::split(s, ',', std::back_inserter(elems));
   type_ = static_cast<CompressionType>(stoi(elems[0]));
   if (elems.size() > 1) {
     if (!elems[1].empty()) {
diff --git a/src/kvstore/kvstore_dist.h b/src/kvstore/kvstore_dist.h
index 373081bc7b1..71b1a5a8ad9 100644
--- a/src/kvstore/kvstore_dist.h
+++ b/src/kvstore/kvstore_dist.h
@@ -92,6 +92,15 @@ class KVStoreDist : public KVStoreLocal {
     }
   }
 
+  void SetServerProfilerCommand(const KVStoreServerProfilerCommand type,
+                                const std::string& params) override {
+    if (get_rank() == 0) {
+      SendCommandToServers(static_cast<int>(CommandType::kSetProfilerParams),
+                           params + std::to_string(static_cast<int>(type)));
+    }
+  }
+
+
   void Barrier() override {
     ps::Postoffice::Get()->Barrier(ps_worker_->get_customer()->customer_id(), ps::kWorkerGroup);
   }
diff --git a/src/kvstore/kvstore_dist_server.h b/src/kvstore/kvstore_dist_server.h
index 421de27b39d..92924fc94d4 100644
--- a/src/kvstore/kvstore_dist_server.h
+++ b/src/kvstore/kvstore_dist_server.h
@@ -24,6 +24,9 @@
  */
 #ifndef MXNET_KVSTORE_KVSTORE_DIST_SERVER_H_
 #define MXNET_KVSTORE_KVSTORE_DIST_SERVER_H_
+#include <mxnet/c_api.h>
+#include <mxnet/kvstore.h>
+#include <ps/ps.h>
 #include <queue>
 #include <string>
 #include <mutex>
@@ -32,8 +35,7 @@
 #include <functional>
 #include <future>
 #include <vector>
-#include "ps/ps.h"
-#include "mxnet/kvstore.h"
+#include "../profiler/profiler.h"
 #include "../operator/tensor/elemwise_binary_op-inl.h"
 #include "../operator/tensor/init_op.h"
 
@@ -42,7 +44,7 @@ namespace kvstore {
 
 // maintain same order in frontend.
 enum class CommandType {
-  kController, kSetMultiPrecision, kStopServer, kSyncMode, kSetGradientCompression,
+  kController, kSetMultiPrecision, kStopServer, kSyncMode, kSetGradientCompression, kSetProfilerParams
 };
 
 enum class RequestType {
@@ -163,6 +165,7 @@ class KVStoreDistServer {
   }
 
   ~KVStoreDistServer() {
+    profiler::Profiler::Get()->SetState(profiler::Profiler::ProfilerState(0));
     delete ps_server_;
   }
 
@@ -193,31 +196,93 @@ class KVStoreDistServer {
 
   void CommandHandle(const ps::SimpleData& recved, ps::SimpleApp* app) {
     CommandType recved_type = static_cast<CommandType>(recved.head);
-    if (recved_type == CommandType::kStopServer) {
-      exec_.Stop();
-    } else if (recved_type == CommandType::kSyncMode) {
-      sync_mode_ = true;
-    } else if (recved_type == CommandType::kSetGradientCompression) {
-      gradient_compression_->DecodeParams(recved.body);
-    } else if (recved_type == CommandType::kSetMultiPrecision) {
-      // uses value 1 for message id from frontend
-      if (!multi_precision_) {
-        multi_precision_ = true;
-        CreateMultiPrecisionCopies();
-      }
-    } else if (recved_type == CommandType::kController) {
-      // value of 0
-      // let the main thread to execute ctrl, which is necessary for python
-      exec_.Exec([this, recved]() {
-          CHECK(controller_);
-          controller_(recved.head, recved.body);
-        });
-    } else {
-      LOG(FATAL) << "Unknown command type received " << recved.head;
+    switch (recved_type) {
+      case CommandType::kStopServer:
+        exec_.Stop();
+        break;
+      case CommandType::kSyncMode:
+        sync_mode_ = true;
+        break;
+      case CommandType::kSetGradientCompression:
+        gradient_compression_->DecodeParams(recved.body);
+        break;
+      case CommandType::kSetMultiPrecision:
+        if (!multi_precision_) {
+          multi_precision_ = true;
+          CreateMultiPrecisionCopies();
+        }
+      case CommandType::kSetProfilerParams:
+        // last char is the type of profiler command
+        ProcessServerProfilerCommands(static_cast<KVStoreServerProfilerCommand>
+                                                  (recved.body.back() - '0'),
+                                      recved.body);
+        break;
+      case CommandType::kController:
+        // this uses value 0 for message id from frontend
+        // let the main thread to execute ctrl, which is necessary for python
+        exec_.Exec([this, recved]() {
+            CHECK(controller_);
+            controller_(recved.head, recved.body);
+          });
+        break;
+      default:
+        LOG(FATAL) << "Unknown command type received " << recved.head;
+        break;
     }
     app->Response(recved);
   }
 
+  void ProcessServerProfilerCommands(KVStoreServerProfilerCommand type, const std::string& body) {
+    switch (type) {
+      case KVStoreServerProfilerCommand::kSetConfig:
+        SetProfilerConfig(body.substr(0, body.size() - 1));
+        break;
+      case KVStoreServerProfilerCommand::kState:
+        MXSetProfilerState(static_cast<int>(body.front() - '0'), 0, nullptr);
+        break;
+      case KVStoreServerProfilerCommand::kPause:
+        MXProfilePause(static_cast<int>(body.front() - '0'), 0, nullptr);
+        break;
+      case KVStoreServerProfilerCommand::kDump:
+        MXDumpProfile(static_cast<int>(body.front() - '0'), 0, nullptr);
+        break;
+    }
+  }
+
+  void SetProfilerConfig(std::string params_str) {
+    std::vector<std::string> elems;
+    mxnet::kvstore::split(params_str, ',', std::back_inserter(elems));
+    std::vector<const char*> ckeys;
+    std::vector<const char*> cvals;
+    ckeys.reserve(elems.size());
+    cvals.reserve(elems.size());
+
+    for (size_t i=0; i < elems.size(); i++) {
+      std::vector<std::string> parts;
+      mxnet::kvstore::split(elems[i], ':', std::back_inserter(parts));
+      CHECK_EQ(parts.size(), 2) << "Improper profiler config passed from worker";
+      CHECK(!parts[0].empty()) << "ProfilerConfig parameter is empty";
+      CHECK(!parts[1].empty()) << "ProfilerConfig value is empty for parameter "<< parts[0];
+      if (parts[0] == "filename") {
+        parts[1] = "rank" + std::to_string(ps::MyRank()) + "_" + parts[1];
+      }
+
+      char* ckey = new char[parts[0].length() + 1];
+      std::snprintf(ckey, parts[0].length() + 1, "%s", parts[0].c_str());
+      ckeys.push_back(ckey);
+
+      char* cval = new char[parts[1].length() + 1];
+      std::snprintf(cval, parts[1].length() + 1, "%s", parts[1].c_str());
+      cvals.push_back(cval);
+    }
+
+    MXSetProfilerConfig(elems.size(), &ckeys[0], &cvals[0], nullptr);
+    for (size_t i=0; i < ckeys.size(); i++) {
+      delete[] ckeys[i];
+      delete[] cvals[i];
+    }
+  }
+
   /*
    * For keys already initialized, if necessary create stored_realt.
    * This will only be used if by some wrong usage of kvstore,
diff --git a/src/kvstore/kvstore_local.h b/src/kvstore/kvstore_local.h
index 38ecf121dfe..ce3276e24eb 100644
--- a/src/kvstore/kvstore_local.h
+++ b/src/kvstore/kvstore_local.h
@@ -39,6 +39,22 @@
 
 namespace mxnet {
 namespace kvstore {
+/*!
+ * \brief Splits a string into smaller strings using char as delimiter
+ * Example: "a,b,c,,d" is split into ["a","b","c","","d"]
+ * \param s string to split
+ * \param delim char to split string around
+ * \param result container for tokens extracted after splitting
+ */
+template<typename Out>
+void split(const std::string &s, const char delim, Out result) {
+  std::stringstream ss;
+  ss.str(s);
+  std::string item;
+  while (std::getline(ss, item, delim)) {
+    *(result++) = item;
+  }
+}
 
 enum KeyType {
   kUndefinedKey = -1,


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services