You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ha...@apache.org on 2018/08/04 17:22:58 UTC
[incubator-mxnet] branch master updated: [MXNET-23] Adding support
to profile kvstore server during distributed training (#11215)
This is an automated email from the ASF dual-hosted git repository.
haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 4649bfa [MXNET-23] Adding support to profile kvstore server during distributed training (#11215)
4649bfa is described below
commit 4649bfa641ad4129b3e83ea0af14b489e512f8f4
Author: Rahul Huilgol <ra...@gmail.com>
AuthorDate: Sat Aug 4 10:22:50 2018 -0700
[MXNET-23] Adding support to profile kvstore server during distributed training (#11215)
* server profiling
merge with master
cleanup old code
added a check and better info message
add functions for C compatibility
fix doc
lint fixes
fix compile issues
lint fix
build error
update function signatures to preserve compatibility
fix comments
lint
* add part1 of test
* add integration test
---
ci/docker/runtime_functions.sh | 1 +
example/image-classification/common/fit.py | 23 +++++-
include/mxnet/c_api.h | 59 ++++++++++++--
include/mxnet/kvstore.h | 26 +++++++
python/mxnet/kvstore.py | 8 +-
python/mxnet/profiler.py | 79 +++++++++++++++----
src/c_api/c_api_profile.cc | 87 ++++++++++++++++++---
src/kvstore/gradient_compression.cc | 21 +----
src/kvstore/kvstore_dist.h | 9 +++
src/kvstore/kvstore_dist_server.h | 121 ++++++++++++++++++++++-------
src/kvstore/kvstore_local.h | 16 ++++
tests/nightly/test_server_profiling.py | 69 ++++++++++++++++
12 files changed, 434 insertions(+), 85 deletions(-)
diff --git a/ci/docker/runtime_functions.sh b/ci/docker/runtime_functions.sh
index 2147190..1c861be 100755
--- a/ci/docker/runtime_functions.sh
+++ b/ci/docker/runtime_functions.sh
@@ -731,6 +731,7 @@ integrationtest_ubuntu_cpu_dist_kvstore() {
../../tools/launch.py -n 7 --launcher local python dist_sync_kvstore.py --no-multiprecision
../../tools/launch.py -n 7 --launcher local python dist_sync_kvstore.py --type=compressed_cpu
../../tools/launch.py -n 7 --launcher local python dist_sync_kvstore.py --type=compressed_cpu --no-multiprecision
+ ../../tools/launch.py -n 3 --launcher local python test_server_profiling.py
}
integrationtest_ubuntu_gpu_scala() {
diff --git a/example/image-classification/common/fit.py b/example/image-classification/common/fit.py
index 67cda78..b3b1305 100755
--- a/example/image-classification/common/fit.py
+++ b/example/image-classification/common/fit.py
@@ -135,6 +135,12 @@ def add_fit_args(parser):
help='the epochs to ramp-up lr to scaled large-batch value')
train.add_argument('--warmup-strategy', type=str, default='linear',
help='the ramping-up strategy for large batch sgd')
+ train.add_argument('--profile-worker-suffix', type=str, default='',
+ help='profile workers actions into this file. During distributed training\
+ filename saved will be rank1_ followed by this suffix')
+ train.add_argument('--profile-server-suffix', type=str, default='',
+ help='profile server actions into a file with name like rank1_ followed by this suffix \
+ during distributed training')
return train
@@ -150,6 +156,17 @@ def fit(args, network, data_loader, **kwargs):
if args.gc_type != 'none':
kv.set_gradient_compression({'type': args.gc_type,
'threshold': args.gc_threshold})
+ if args.profile_server_suffix:
+ mx.profiler.set_config(filename=args.profile_server_suffix, profile_all=True, profile_process='server')
+ mx.profiler.set_state(state='run', profile_process='server')
+
+ if args.profile_worker_suffix:
+ if kv.num_workers > 1:
+ filename = 'rank' + str(kv.rank) + '_' + args.profile_worker_suffix
+ else:
+ filename = args.profile_worker_suffix
+ mx.profiler.set_config(filename=filename, profile_all=True, profile_process='worker')
+ mx.profiler.set_state(state='run', profile_process='worker')
# logging
head = '%(asctime)-15s Node[' + str(kv.rank) + '] %(message)s'
@@ -180,7 +197,6 @@ def fit(args, network, data_loader, **kwargs):
logging.info('Batch [%d]\tSpeed: %.2f samples/sec', i,
args.disp_batches * args.batch_size / (time.time() - tic))
tic = time.time()
-
return
# load model
@@ -314,3 +330,8 @@ def fit(args, network, data_loader, **kwargs):
epoch_end_callback=checkpoint,
allow_missing=True,
monitor=monitor)
+
+ if args.profile_server_suffix:
+ mx.profiler.set_state(state='run', profile_process='server')
+ if args.profile_worker_suffix:
+ mx.profiler.set_state(state='run', profile_process='worker')
\ No newline at end of file
diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index 75147cf..6bbe9df 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -230,7 +230,19 @@ MXNET_DLL int MXRandomSeedContext(int seed, int dev_type, int dev_id);
MXNET_DLL int MXNotifyShutdown();
/*!
- * \brief Set up configuration of profiler
+ * \brief Set up configuration of profiler for the process passed as profile_process in keys
+ * \param num_params Number of parameters
+ * \param keys array of parameter keys
+ * \param vals array of parameter values
+ * \param kvstoreHandle handle to kvstore
+ * \return 0 when success, -1 when failure happens.
+ */
+MXNET_DLL int MXSetProcessProfilerConfig(int num_params, const char* const* keys,
+ const char* const* vals,
+ KVStoreHandle kvstoreHandle);
+
+/*!
+ * \brief Set up configuration of profiler for worker/current process
* \param num_params Number of parameters
* \param keys array of parameter keys
* \param vals array of parameter values
@@ -239,7 +251,21 @@ MXNET_DLL int MXNotifyShutdown();
MXNET_DLL int MXSetProfilerConfig(int num_params, const char* const* keys, const char* const* vals);
/*!
- * \brief Set up state of profiler
+ * \brief Set up state of profiler for either worker or server process
+ * \param state indicate the working state of profiler,
+ * profiler not running when state == 0,
+ * profiler running when state == 1
+ * \param profile_process an int,
+ * when 0 command is for worker/current process,
+ * when 1 command is for server process
+ * \param kvStoreHandle handle to kvstore, needed for server process profiling
+ * \return 0 when success, -1 when failure happens.
+ */
+MXNET_DLL int MXSetProcessProfilerState(int state, int profile_process,
+ KVStoreHandle kvStoreHandle);
+
+/*!
+ * \brief Set up state of profiler for current process
* \param state indicate the working state of profiler,
* profiler not running when state == 0,
* profiler running when state == 1
@@ -250,12 +276,23 @@ MXNET_DLL int MXSetProfilerState(int state);
/*!
* \brief Save profile and stop profiler
* \param finished true if stat output should stop after this point
+ * \param profile_process an int,
+ * when 0 command is for worker/current process,
+ * when 1 command is for server process
+ * \param kvStoreHandle handle to kvstore, needed for server process profiling
* \return 0 when success, -1 when failure happens.
*/
-MXNET_DLL int MXDumpProfile(int finished);
+MXNET_DLL int MXDumpProcessProfile(int finished, int profile_process, KVStoreHandle kvStoreHandle);
/*!
+ * \brief Save profile and stop profiler for worker/current process
+ * \param finished true if stat output should stop after this point
+ * \return 0 when success, -1 when failure happens.
+ */
+MXNET_DLL int MXDumpProfile(int finished);
+
+/*!
* \brief Print aggregate stats to the a string
* \param out_str Will receive a pointer to the output string
* \param reset Clear the aggregate stats after printing
@@ -267,6 +304,16 @@ MXNET_DLL int MXAggregateProfileStatsPrint(const char **out_str, int reset);
/*!
* \brief Pause profiler tuning collection
* \param paused If nonzero, profiling pauses. Otherwise, profiling resumes/continues
+ * \param profile_process an int: 0 for worker/current process, 1 for server process
+ * \param kvStoreHandle handle to kvstore, needed for server process profiling
+ * \return 0 when success, -1 when failure happens.
+ * \note pausing and resuming is global and not recursive
+ */
+MXNET_DLL int MXProcessProfilePause(int paused, int profile_process, KVStoreHandle kvStoreHandle);
+
+/*!
+ * \brief Pause profiler tuning collection for worker/current process
+ * \param paused If nonzero, profiling pauses. Otherwise, profiling resumes/continues
* \return 0 when success, -1 when failure happens.
* \note pausing and resuming is global and not recursive
*/
@@ -2145,8 +2192,7 @@ typedef void (MXKVStoreServerController)(int head,
void *controller_handle);
/**
- * \return Run as server (or scheduler)
- *
+ * \brief Run as server (or scheduler)
* \param handle handle to the KVStore
* \param controller the user-defined server controller
* \param controller_handle helper handle for implementing controller
@@ -2157,8 +2203,7 @@ MXNET_DLL int MXKVStoreRunServer(KVStoreHandle handle,
void *controller_handle);
/**
- * \return Send a command to all server nodes
- *
+ * \brief Send a command to all server nodes
* \param handle handle to the KVStore
* \param cmd_id the head of the command
* \param cmd_body the body of the command
diff --git a/include/mxnet/kvstore.h b/include/mxnet/kvstore.h
index e10bd21..a73d963 100644
--- a/include/mxnet/kvstore.h
+++ b/include/mxnet/kvstore.h
@@ -38,6 +38,18 @@
#endif // MXNET_USE_DIST_KVSTORE
namespace mxnet {
+
+/*!
+ * \brief enum to denote types of commands kvstore sends to server regarding profiler
+ * kSetConfig sets profiler configs. Similar to mx.profiler.set_config()
+ * kState allows changing state of profiler to stop or run
+ * kPause allows pausing and resuming of profiler
+ * kDump asks profiler to dump output
+ */
+enum class KVStoreServerProfilerCommand {
+ kSetConfig, kState, kPause, kDump
+};
+
/*!
* \brief distributed key-value store
*
@@ -365,6 +377,20 @@ class KVStore {
virtual void SendCommandToServers(int cmd_id, const std::string& cmd_body) { }
/**
+ * \brief Sends server profiler commands to all server nodes
+ * Only the worker with rank=0 sends the command which will be received by all servers
+ * \param type ProfilerCommand type
+ * \param params parameters for that command in the form of a string
+ */
+ virtual void SetServerProfilerCommand(const KVStoreServerProfilerCommand type,
+ const std::string& params) {
+ LOG(INFO) << "Unable to pass server the profiler command. If you are using "
+ << "distributed kvstore, you need to compile with USE_DIST_KVSTORE=1."
+ << "If you are training on single machine, then there is no server process"
+ << "to profile. Please profile the worker process instead.";
+ }
+
+ /**
* \brief the prototype of a server controller
*/
typedef std::function<void(int, const std::string&)> Controller;
diff --git a/python/mxnet/kvstore.py b/python/mxnet/kvstore.py
index 6097336..a548175 100644
--- a/python/mxnet/kvstore.py
+++ b/python/mxnet/kvstore.py
@@ -28,6 +28,7 @@ from .base import _LIB, c_str_array, c_handle_array, c_array, c_array_buf, c_str
from .base import check_call, string_types, mx_uint, py_str
from .base import NDArrayHandle, KVStoreHandle
from . import optimizer as opt
+from .profiler import set_kvstore_handle
def _ctype_key_value(keys, vals):
"""
@@ -88,7 +89,8 @@ def _get_kvstore_server_command_type(command):
'kSetMultiPrecision': 1,
'kStopServer': 2,
'kSyncMode': 3,
- 'kSetGradientCompression': 4}
+ 'kSetGradientCompression': 4,
+ 'kSetProfilerParams': 5}
assert (command in command_types), "Unknown command type to send to server"
return command_types[command]
@@ -670,4 +672,6 @@ def create(name='local'):
handle = KVStoreHandle()
check_call(_LIB.MXKVStoreCreate(c_str(name),
ctypes.byref(handle)))
- return KVStore(handle)
+ kv = KVStore(handle)
+ set_kvstore_handle(kv.handle)
+ return kv
diff --git a/python/mxnet/profiler.py b/python/mxnet/profiler.py
index 0e7a31c..0b5e85b 100644
--- a/python/mxnet/profiler.py
+++ b/python/mxnet/profiler.py
@@ -22,8 +22,13 @@
from __future__ import absolute_import
import ctypes
import warnings
-from .base import _LIB, check_call, c_str, ProfileHandle, c_str_array, py_str
+from .base import _LIB, check_call, c_str, ProfileHandle, c_str_array, py_str, KVStoreHandle
+profiler_kvstore_handle = KVStoreHandle()
+
+def set_kvstore_handle(handle):
+ global profiler_kvstore_handle
+ profiler_kvstore_handle = handle
def set_config(**kwargs):
"""Set up the configure of profiler (only accepts keyword arguments).
@@ -49,12 +54,17 @@ def set_config(**kwargs):
aggregate_stats : boolean,
whether to maintain aggregate stats in memory for console
dump. Has some negative performance impact.
+ profile_process : string
+ whether to profile kvstore `server` or `worker`.
+ server can only be profiled when kvstore is of type dist.
+ if this is not passed, defaults to `worker`
"""
kk = kwargs.keys()
vv = kwargs.values()
- check_call(_LIB.MXSetProfilerConfig(len(kwargs),
- c_str_array([key for key in kk]),
- c_str_array([str(val) for val in vv])))
+ check_call(_LIB.MXSetProcessProfilerConfig(len(kwargs),
+ c_str_array([key for key in kk]),
+ c_str_array([str(val) for val in vv]),
+ profiler_kvstore_handle))
def profiler_set_config(mode='symbolic', filename='profile.json'):
@@ -73,10 +83,10 @@ def profiler_set_config(mode='symbolic', filename='profile.json'):
keys = c_str_array([key for key in ["profile_" + mode, "filename"]])
values = c_str_array([str(val) for val in [True, filename]])
assert len(keys) == len(values)
- check_call(_LIB.MXSetProfilerConfig(len(keys), keys, values))
+ check_call(_LIB.MXSetProcessProfilerConfig(len(keys), keys, values, profiler_kvstore_handle))
-def set_state(state='stop'):
+def set_state(state='stop', profile_process='worker'):
"""Set up the profiler state to 'run' or 'stop'.
Parameters
@@ -84,9 +94,16 @@ def set_state(state='stop'):
state : string, optional
Indicates whether to run the profiler, can
be 'stop' or 'run'. Default is `stop`.
+ profile_process : string
+ whether to profile kvstore `server` or `worker`.
+ server can only be profiled when kvstore is of type dist.
+ if this is not passed, defaults to `worker`
"""
state2int = {'stop': 0, 'run': 1}
- check_call(_LIB.MXSetProfilerState(ctypes.c_int(state2int[state])))
+ profile_process2int = {'worker': 0, 'server': 1}
+ check_call(_LIB.MXSetProcessProfilerState(ctypes.c_int(state2int[state]),
+ profile_process2int[profile_process],
+ profiler_kvstore_handle))
def profiler_set_state(state='stop'):
@@ -102,7 +119,7 @@ def profiler_set_state(state='stop'):
'Please use profiler.set_state() instead')
set_state(state)
-def dump(finished=True):
+def dump(finished=True, profile_process='worker'):
"""Dump profile and stop profiler. Use this to save profile
in advance in case your program cannot exit normally.
@@ -111,9 +128,16 @@ def dump(finished=True):
finished : boolean
Indicates whether to stop statistic output (dumping) after this dump.
Default is True
+ profile_process : string
+ whether to profile kvstore `server` or `worker`.
+ server can only be profiled when kvstore is of type dist.
+ if this is not passed, defaults to `worker`
"""
- fin = 1 if finished is True else False
- check_call(_LIB.MXDumpProfile(fin))
+ fin = 1 if finished is True else 0
+ profile_process2int = {'worker': 0, 'server': 1}
+ check_call(_LIB.MXDumpProcessProfile(fin,
+ profile_process2int[profile_process],
+ profiler_kvstore_handle))
def dump_profile():
@@ -138,14 +162,37 @@ def dumps(reset=False):
return py_str(debug_str.value)
-def pause():
- """Pause profiling."""
- check_call(_LIB.MXProfilePause(int(1)))
+def pause(profile_process='worker'):
+ """Pause profiling.
+
+ Parameters
+ ----------
+ profile_process : string
+ whether to profile kvstore `server` or `worker`.
+ server can only be profiled when kvstore is of type dist.
+ if this is not passed, defaults to `worker`
+ """
+ profile_process2int = {'worker': 0, 'server': 1}
+ check_call(_LIB.MXProcessProfilePause(int(1),
+ profile_process2int[profile_process],
+ profiler_kvstore_handle))
+
+def resume(profile_process='worker'):
+ """
+ Resume paused profiling.
-def resume():
- """Resume paused profiling."""
- check_call(_LIB.MXProfilePause(int(0)))
+ Parameters
+ ----------
+ profile_process : string
+ whether to profile kvstore `server` or `worker`.
+ server can only be profiled when kvstore is of type dist.
+ if this is not passed, defaults to `worker`
+ """
+ profile_process2int = {'worker': 0, 'server': 1}
+ check_call(_LIB.MXProcessProfilePause(int(0),
+ profile_process2int[profile_process],
+ profiler_kvstore_handle))
class Domain(object):
diff --git a/src/c_api/c_api_profile.cc b/src/c_api/c_api_profile.cc
index c584177..9c03b33 100644
--- a/src/c_api/c_api_profile.cc
+++ b/src/c_api/c_api_profile.cc
@@ -29,6 +29,7 @@
#include <dmlc/base.h>
#include <dmlc/logging.h>
#include <dmlc/thread_group.h>
+#include <mxnet/kvstore.h>
#include <stack>
#include "./c_api_common.h"
#include "../profiler/profiler.h"
@@ -197,6 +198,10 @@ struct PythonProfileObjects {
};
static PythonProfileObjects python_profile_objects;
+enum class ProfileProcess {
+ kWorker, kServer
+};
+
struct ProfileConfigParam : public dmlc::Parameter<ProfileConfigParam> {
bool profile_all;
bool profile_symbolic;
@@ -207,6 +212,7 @@ struct ProfileConfigParam : public dmlc::Parameter<ProfileConfigParam> {
bool continuous_dump;
float dump_period;
bool aggregate_stats;
+ int profile_process;
DMLC_DECLARE_PARAMETER(ProfileConfigParam) {
DMLC_DECLARE_FIELD(profile_all).set_default(false)
.describe("Profile all.");
@@ -228,6 +234,13 @@ struct ProfileConfigParam : public dmlc::Parameter<ProfileConfigParam> {
DMLC_DECLARE_FIELD(aggregate_stats).set_default(false)
.describe("Maintain aggregate stats, required for MXDumpAggregateStats. Note that "
"this can have anegative performance impact.");
+ DMLC_DECLARE_FIELD(profile_process)
+ .add_enum("worker", static_cast<int>(ProfileProcess::kWorker))
+ .add_enum("server", static_cast<int>(ProfileProcess::kServer))
+ .set_default(static_cast<int>(ProfileProcess::kWorker))
+ .describe("Specifies which process to profile: "
+ "worker: this is default. for single node training it should always be worker."
+ "server: for distributed training, this profiles server process");
}
};
@@ -248,7 +261,8 @@ struct ProfileMarkerScopeParam : public dmlc::Parameter<ProfileMarkerScopeParam>
DMLC_REGISTER_PARAMETER(ProfileMarkerScopeParam);
-int MXSetProfilerConfig(int num_params, const char* const* keys, const char* const* vals) {
+int MXSetProcessProfilerConfig(int num_params, const char* const* keys, const char* const* vals,
+ KVStoreHandle kvstoreHandle) {
mxnet::IgnoreProfileCallScope ignore;
API_BEGIN();
std::vector<std::pair<std::string, std::string>> kwargs;
@@ -260,19 +274,37 @@ int MXSetProfilerConfig(int num_params, const char* const* keys, const char* con
}
ProfileConfigParam param;
param.Init(kwargs);
- int mode = 0;
- if (param.profile_api || param.profile_all) { mode |= profiler::Profiler::kAPI; }
- if (param.profile_symbolic || param.profile_all) { mode |= profiler::Profiler::kSymbolic; }
- if (param.profile_imperative || param.profile_all) { mode |= profiler::Profiler::kImperative; }
- if (param.profile_memory || param.profile_all) { mode |= profiler::Profiler::kMemory; }
- profiler::Profiler::Get()->SetConfig(profiler::Profiler::ProfilerMode(mode),
- std::string(param.filename),
- param.continuous_dump,
- param.dump_period,
- param.aggregate_stats);
+ if (static_cast<ProfileProcess>(param.profile_process) == ProfileProcess::kServer) {
+ std::ostringstream os;
+ for (int i = 0; i < num_params; ++i) {
+ // skip profile_process: the server applies the forwarded config to its own process
+ if (strcmp(keys[i], "profile_process") == 0) continue;
+ os << keys[i] << ":" << vals[i];
+ if (i != num_params - 1) os << ",";
+ }
+ CHECK(kvstoreHandle) << "KVStoreHandle passed to profiler is null";
+ static_cast<KVStore*>(kvstoreHandle)->SetServerProfilerCommand(
+ mxnet::KVStoreServerProfilerCommand::kSetConfig, os.str());
+ } else {
+ int mode = 0;
+ if (param.profile_api || param.profile_all) { mode |= profiler::Profiler::kAPI; }
+ if (param.profile_symbolic || param.profile_all) { mode |= profiler::Profiler::kSymbolic; }
+ if (param.profile_imperative ||
+ param.profile_all) { mode |= profiler::Profiler::kImperative; }
+ if (param.profile_memory || param.profile_all) { mode |= profiler::Profiler::kMemory; }
+ profiler::Profiler::Get()->SetConfig(profiler::Profiler::ProfilerMode(mode),
+ std::string(param.filename),
+ param.continuous_dump,
+ param.dump_period,
+ param.aggregate_stats);
+ }
API_END();
}
+int MXSetProfilerConfig(int num_params, const char* const* keys, const char* const* vals) {
+ return MXSetProcessProfilerConfig(num_params, keys, vals, nullptr);
+}
+
int MXAggregateProfileStatsPrint(const char **out_str, int reset) {
MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
API_BEGIN();
@@ -293,19 +325,40 @@ int MXAggregateProfileStatsPrint(const char **out_str, int reset) {
}
int MXDumpProfile(int finished) {
+ return MXDumpProcessProfile(finished, static_cast<int>(ProfileProcess::kWorker), nullptr);
+}
+
+int MXDumpProcessProfile(int finished, int profile_process, KVStoreHandle kvStoreHandle) {
mxnet::IgnoreProfileCallScope ignore;
API_BEGIN();
+ if (static_cast<ProfileProcess>(profile_process) == ProfileProcess::kServer) {
+ CHECK(kvStoreHandle) << "Kvstore Handle passed to profiler is null";
+ static_cast<KVStore*>(kvStoreHandle)->SetServerProfilerCommand(
+ mxnet::KVStoreServerProfilerCommand::kDump,
+ std::to_string(finished));
+ } else {
profiler::Profiler *profiler = profiler::Profiler::Get();
CHECK(profiler->IsEnableOutput())
<< "Profiler hasn't been run. Config and start profiler first";
profiler->DumpProfile(finished != 0);
+ }
API_END()
}
int MXSetProfilerState(int state) {
+ return MXSetProcessProfilerState(state, static_cast<int>(ProfileProcess::kWorker), nullptr);
+}
+
+int MXSetProcessProfilerState(int state, int profile_process, KVStoreHandle kvStoreHandle) {
mxnet::IgnoreProfileCallScope ignore;
// state, kNotRunning: 0, kRunning: 1
API_BEGIN();
+ if (static_cast<ProfileProcess>(profile_process) == ProfileProcess::kServer) {
+ CHECK(kvStoreHandle) << "Kvstore Handle passed to profiler is null";
+ static_cast<KVStore*>(kvStoreHandle)->SetServerProfilerCommand(
+ mxnet::KVStoreServerProfilerCommand::kState,
+ std::to_string(state));
+ } else {
switch (state) {
case profiler::Profiler::kNotRunning:
profiler::vtune::vtune_pause();
@@ -315,6 +368,7 @@ int MXSetProfilerState(int state) {
break;
}
profiler::Profiler::Get()->SetState(profiler::Profiler::ProfilerState(state));
+ }
API_END();
}
@@ -450,8 +504,18 @@ int MXProfileDurationStop(ProfileHandle duration_handle) {
}
int MXProfilePause(int paused) {
+ return MXProcessProfilePause(paused, static_cast<int>(ProfileProcess::kWorker), nullptr);
+}
+
+int MXProcessProfilePause(int paused, int profile_process, KVStoreHandle kvStoreHandle) {
mxnet::IgnoreProfileCallScope ignore;
API_BEGIN();
+ if (static_cast<ProfileProcess>(profile_process) == ProfileProcess::kServer) {
+ CHECK(kvStoreHandle) << "Kvstore Handle passed to profiler is null";
+ static_cast<KVStore*>(kvStoreHandle)->SetServerProfilerCommand(
+ mxnet::KVStoreServerProfilerCommand::kPause,
+ std::to_string(paused));
+ } else {
if (paused) {
profiler::vtune::vtune_pause();
profiler::Profiler::Get()->set_paused(true);
@@ -459,6 +523,7 @@ int MXProfilePause(int paused) {
profiler::Profiler::Get()->set_paused(false);
profiler::vtune::vtune_resume();
}
+ }
API_END();
}
diff --git a/src/kvstore/gradient_compression.cc b/src/kvstore/gradient_compression.cc
index e94a057..e4a06fa 100644
--- a/src/kvstore/gradient_compression.cc
+++ b/src/kvstore/gradient_compression.cc
@@ -23,31 +23,14 @@
* \author Rahul Huilgol
*/
-#include <sstream>
#include <vector>
+#include "kvstore_local.h"
#include "gradient_compression.h"
#include "gradient_compression-inl.h"
namespace mxnet {
namespace kvstore {
-/*!
- * \brief Splits a string into smaller strings using char as delimiter
- * Example: "a,b,c,,d" is split into ["a","b","c","","d"]
- * \param s string to split
- * \param delim char to split string around
- * \param result container for tokens extracted after splitting
- */
-template<typename Out>
-void split(const std::string &s, const char delim, Out result) {
- std::stringstream ss;
- ss.str(s);
- std::string item;
- while (std::getline(ss, item, delim)) {
- *(result++) = item;
- }
-}
-
DMLC_REGISTER_PARAMETER(GradientCompressionParam);
GradientCompression::GradientCompression() {
@@ -90,7 +73,7 @@ std::string GradientCompression::EncodeParams() {
void GradientCompression::DecodeParams(const std::string &s) {
std::vector<std::string> elems;
- split(s, ',', std::back_inserter(elems));
+ mxnet::kvstore::split(s, ',', std::back_inserter(elems));
type_ = static_cast<CompressionType>(stoi(elems[0]));
if (elems.size() > 1) {
if (!elems[1].empty()) {
diff --git a/src/kvstore/kvstore_dist.h b/src/kvstore/kvstore_dist.h
index 7e2f5cb..23fbf67 100644
--- a/src/kvstore/kvstore_dist.h
+++ b/src/kvstore/kvstore_dist.h
@@ -93,6 +93,15 @@ class KVStoreDist : public KVStoreLocal {
}
}
+ void SetServerProfilerCommand(const KVStoreServerProfilerCommand type,
+ const std::string& params) override {
+ if (get_rank() == 0) {
+ SendCommandToServers(static_cast<int>(CommandType::kSetProfilerParams),
+ params + std::to_string(static_cast<int>(type)));
+ }
+ }
+
+
void Barrier() override {
ps::Postoffice::Get()->Barrier(ps_worker_->get_customer()->customer_id(), ps::kWorkerGroup);
}
diff --git a/src/kvstore/kvstore_dist_server.h b/src/kvstore/kvstore_dist_server.h
index 451fb78..372b58d 100644
--- a/src/kvstore/kvstore_dist_server.h
+++ b/src/kvstore/kvstore_dist_server.h
@@ -24,6 +24,9 @@
*/
#ifndef MXNET_KVSTORE_KVSTORE_DIST_SERVER_H_
#define MXNET_KVSTORE_KVSTORE_DIST_SERVER_H_
+#include <mxnet/c_api.h>
+#include <mxnet/kvstore.h>
+#include <ps/ps.h>
#include <queue>
#include <string>
#include <mutex>
@@ -32,8 +35,7 @@
#include <functional>
#include <future>
#include <vector>
-#include "ps/ps.h"
-#include "mxnet/kvstore.h"
+#include "../profiler/profiler.h"
#include "../operator/tensor/elemwise_binary_op-inl.h"
#include "../operator/tensor/init_op.h"
@@ -42,7 +44,8 @@ namespace kvstore {
// maintain same order in frontend.
enum class CommandType {
- kController, kSetMultiPrecision, kStopServer, kSyncMode, kSetGradientCompression,
+ kController, kSetMultiPrecision, kStopServer, kSyncMode,
+ kSetGradientCompression, kSetProfilerParams
};
enum class RequestType {
@@ -164,6 +167,7 @@ class KVStoreDistServer {
}
~KVStoreDistServer() {
+ profiler::Profiler::Get()->SetState(profiler::Profiler::ProfilerState(0));
delete ps_server_;
}
@@ -194,27 +198,37 @@ class KVStoreDistServer {
void CommandHandle(const ps::SimpleData& recved, ps::SimpleApp* app) {
CommandType recved_type = static_cast<CommandType>(recved.head);
- if (recved_type == CommandType::kStopServer) {
- exec_.Stop();
- } else if (recved_type == CommandType::kSyncMode) {
- sync_mode_ = true;
- } else if (recved_type == CommandType::kSetGradientCompression) {
- gradient_compression_->DecodeParams(recved.body);
- } else if (recved_type == CommandType::kSetMultiPrecision) {
- // uses value 1 for message id from frontend
- if (!multi_precision_) {
- multi_precision_ = true;
- CreateMultiPrecisionCopies();
- }
- } else if (recved_type == CommandType::kController) {
- // value of 0
- // let the main thread to execute ctrl, which is necessary for python
- exec_.Exec([this, recved]() {
- CHECK(controller_);
- controller_(recved.head, recved.body);
- });
- } else {
- LOG(FATAL) << "Unknown command type received " << recved.head;
+ switch (recved_type) {
+ case CommandType::kStopServer:
+ exec_.Stop();
+ break;
+ case CommandType::kSyncMode:
+ sync_mode_ = true;
+ break;
+ case CommandType::kSetGradientCompression:
+ gradient_compression_->DecodeParams(recved.body);
+ break;
+ case CommandType::kSetProfilerParams:
+ // last char is the type of profiler command
+ ProcessServerProfilerCommands(static_cast<KVStoreServerProfilerCommand>
+ (recved.body.back() - '0'),
+ recved.body);
+ break;
+ case CommandType::kSetMultiPrecision:
+ // uses value 1 for message id from frontend
+ if (!multi_precision_) {
+ multi_precision_ = true;
+ CreateMultiPrecisionCopies();
+ }
+ break;
+ case CommandType::kController:
+ // this uses value 0 for message id from frontend
+ // let the main thread to execute ctrl, which is necessary for python
+ exec_.Exec([this, recved]() {
+ CHECK(controller_);
+ controller_(recved.head, recved.body);
+ });
+ break;
}
app->Response(recved);
}
@@ -225,11 +239,11 @@ class KVStoreDistServer {
* some keys are initialized before optimizer is set.
*/
void CreateMultiPrecisionCopies() {
- for (auto const& stored_entry : store_) {
+ for (auto const &stored_entry : store_) {
const int key = stored_entry.first;
- const NDArray& stored = stored_entry.second;
+ const NDArray &stored = stored_entry.second;
if (stored.dtype() != mshadow::kFloat32) {
- auto& stored_realt = store_realt_[key];
+ auto &stored_realt = store_realt_[key];
if (stored.storage_type() == kRowSparseStorage) {
stored_realt = NDArray(kRowSparseStorage, stored.shape(), stored.ctx(),
true, mshadow::kFloat32);
@@ -237,7 +251,7 @@ class KVStoreDistServer {
stored_realt = NDArray(stored.shape(), stored.ctx(), false, mshadow::kFloat32);
}
- auto& update = update_buf_[key];
+ auto &update = update_buf_[key];
if (!update.merged.is_none()) {
if (update.merged.storage_type() == kRowSparseStorage) {
update.merged = NDArray(kRowSparseStorage, update.merged.shape(), update.merged.ctx(),
@@ -254,11 +268,60 @@ class KVStoreDistServer {
CopyFromTo(stored, stored_realt);
}
}
- for (auto const& stored_realt_entry : store_realt_) {
+ for (auto const &stored_realt_entry : store_realt_) {
stored_realt_entry.second.WaitToRead();
}
}
+ void ProcessServerProfilerCommands(KVStoreServerProfilerCommand type, const std::string& body) {
+ switch (type) {
+ case KVStoreServerProfilerCommand::kSetConfig:
+ SetProfilerConfig(body.substr(0, body.size() - 1));
+ break;
+ case KVStoreServerProfilerCommand::kState:
+ MXSetProfilerState(static_cast<int>(body.front() - '0'));
+ break;
+ case KVStoreServerProfilerCommand::kPause:
+ MXProfilePause(static_cast<int>(body.front() - '0'));
+ break;
+ case KVStoreServerProfilerCommand::kDump:
+ MXDumpProfile(static_cast<int>(body.front() - '0'));
+ break;
+ }
+ }
+
+ void SetProfilerConfig(std::string params_str) {
+ std::vector<std::string> elems;
+ mxnet::kvstore::split(params_str, ',', std::back_inserter(elems));
+ std::vector<const char*> ckeys;
+ std::vector<const char*> cvals;
+ ckeys.reserve(elems.size());
+ cvals.reserve(elems.size());
+
+ for (size_t i=0; i < elems.size(); i++) {
+ std::vector<std::string> parts;
+ mxnet::kvstore::split(elems[i], ':', std::back_inserter(parts));
+ CHECK_EQ(parts.size(), 2) << "Improper profiler config passed from worker";
+ CHECK(!parts[0].empty()) << "ProfilerConfig parameter is empty";
+ CHECK(!parts[1].empty()) << "ProfilerConfig value is empty for parameter "<< parts[0];
+ if (parts[0] == "filename") {
+ parts[1] = "rank" + std::to_string(ps::MyRank()) + "_" + parts[1];
+ }
+ char* ckey = new char[parts[0].length() + 1];
+ std::snprintf(ckey, parts[0].length() + 1, "%s", parts[0].c_str());
+ ckeys.push_back(ckey);
+
+ char* cval = new char[parts[1].length() + 1];
+ std::snprintf(cval, parts[1].length() + 1, "%s", parts[1].c_str());
+ cvals.push_back(cval);
+ }
+ MXSetProfilerConfig(elems.size(), &ckeys[0], &cvals[0]);
+ for (size_t i=0; i < ckeys.size(); i++) {
+ delete[] ckeys[i];
+ delete[] cvals[i];
+ }
+ }
+
void DataHandleEx(const ps::KVMeta& req_meta,
const ps::KVPairs<char>& req_data,
ps::KVServer<char>* server) {
diff --git a/src/kvstore/kvstore_local.h b/src/kvstore/kvstore_local.h
index 324bc2c..4e004a3 100644
--- a/src/kvstore/kvstore_local.h
+++ b/src/kvstore/kvstore_local.h
@@ -40,6 +40,22 @@
namespace mxnet {
namespace kvstore {
+/*!
+ * \brief Splits a string into smaller strings using char as delimiter
+ * Example: "a,b,c,,d" is split into ["a","b","c","","d"]
+ * \param s string to split
+ * \param delim char to split string around
+ * \param result container for tokens extracted after splitting
+ */
+template<typename Out>
+void split(const std::string &s, const char delim, Out result) {
+ std::stringstream ss;
+ ss.str(s);
+ std::string item;
+ while (std::getline(ss, item, delim)) {
+ *(result++) = item;
+ }
+}
enum KeyType {
kUndefinedKey = -1,
diff --git a/tests/nightly/test_server_profiling.py b/tests/nightly/test_server_profiling.py
new file mode 100644
index 0000000..7d157a3
--- /dev/null
+++ b/tests/nightly/test_server_profiling.py
@@ -0,0 +1,69 @@
+#!/usr/bin/env python
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import mxnet as mx
+import json
+
+key = '99'
+shape = (1200, 1200) # bigger than MXNET_KVSTORE_BIGARRAY_BOUND
+kv = mx.kv.create('dist_sync')
+
+def init_kv():
+ # init kv dns keys
+ kv.init(key, mx.nd.ones(shape))
+ kv.set_optimizer(mx.optimizer.create('sgd'))
+ return kv, kv.rank, kv.num_workers
+
+def test_sync_push_pull():
+ kv, my_rank, nworker = init_kv()
+ def check_default_keys(kv, my_rank):
+ nrepeat = 10
+ # checks pull after push in loop, because behavior during
+ # consecutive pushes doesn't offer any guarantees
+ for i in range(nrepeat):
+ kv.push(key, mx.nd.ones(shape, dtype='float32') * (my_rank+1))
+ val = mx.nd.zeros(shape, dtype='float32')
+ kv.pull(key, out=val)
+ mx.nd.waitall()
+ check_default_keys(kv, my_rank)
+
+if __name__ == "__main__":
+ server_filename_suffix = 'test_profile_server.json'
+ worker_filename_suffix = 'test_profile_worker.json'
+ mx.profiler.set_config(filename=server_filename_suffix, profile_all=True, profile_process='server')
+ mx.profiler.set_config(filename='rank' + str(kv.rank) + '_' + worker_filename_suffix, profile_all=True, profile_process='worker')
+ mx.profiler.set_state(state='run', profile_process='server')
+ mx.profiler.set_state(state='run', profile_process='worker')
+ test_sync_push_pull()
+ mx.profiler.set_state(state='stop', profile_process='server')
+ mx.profiler.set_state(state='stop', profile_process='worker')
+
+ import glob, os
+
+ # will only work when launcher mode is local, as used for integration test
+ if kv.rank == 0:
+ for rank in range(kv.num_workers):
+ for suffix in [worker_filename_suffix, server_filename_suffix]:
+ # throws value error if file is not proper json
+ filename = 'rank' + str(rank) + '_' + suffix
+ print(glob.glob('*'), os.getcwd())
+ with open(filename, 'r') as f:
+ j = json.load(f)
+
+
+