Posted to commits@mxnet.apache.org by zh...@apache.org on 2017/11/19 20:11:05 UTC
[incubator-mxnet] 01/01: Revert "2bit gradient compression (#8662)"
This is an automated email from the ASF dual-hosted git repository.
zhasheng pushed a commit to branch revert-8662-gc-pr
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
commit 2e58c0e162e081f7240db24f251d65f1d60b5f86
Author: Sheng Zha <sz...@users.noreply.github.com>
AuthorDate: Sat Nov 18 22:59:38 2017 -0800
Revert "2bit gradient compression (#8662)"
This reverts commit a499f892c9ee6f59ccfb57c9e431c91014078891.
---
example/image-classification/common/fit.py | 44 ++--
example/rnn/lstm_bucketing.py | 1 +
include/mxnet/c_api.h | 13 -
include/mxnet/kvstore.h | 15 --
python/mxnet/gluon/trainer.py | 12 +-
python/mxnet/kvstore.py | 62 -----
python/mxnet/module/bucketing_module.py | 17 +-
python/mxnet/module/module.py | 11 +-
src/c_api/c_api.cc | 14 --
src/kvstore/comm.h | 87 +------
src/kvstore/gradient_compression-inl.h | 155 ------------
src/kvstore/gradient_compression.cc | 193 --------------
src/kvstore/gradient_compression.cu | 40 ---
src/kvstore/gradient_compression.h | 138 ----------
src/kvstore/kvstore.cc | 2 +-
src/kvstore/kvstore_dist.h | 388 ++++++++---------------------
src/kvstore/kvstore_dist_server.h | 143 ++---------
src/kvstore/kvstore_local.h | 7 -
tests/nightly/dist_sync_kvstore.py | 120 +--------
tests/nightly/test_kvstore.py | 200 ++-------------
tools/bandwidth/measure.py | 6 +-
21 files changed, 167 insertions(+), 1501 deletions(-)
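
For reference, the change being reverted exposed 2bit gradient compression through a small Python surface, visible in the removed lines of the diff below. A minimal, illustrative sketch of that (now removed) API, assuming a build that still includes it:

    import mxnet as mx

    # KVStore path, as in the removed lines of example/image-classification/common/fit.py:
    kv = mx.kvstore.create('device')
    kv.set_gradient_compression({'type': '2bit', 'threshold': 0.5})

    # Gluon path, via the removed Trainer keyword argument (net is a hypothetical
    # gluon.Block with initialized parameters):
    # trainer = mx.gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.1},
    #                            kvstore='device',
    #                            compression_params={'type': '2bit', 'threshold': 0.5})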
diff --git a/example/image-classification/common/fit.py b/example/image-classification/common/fit.py
index 2b002c7..51a1abe 100755
--- a/example/image-classification/common/fit.py
+++ b/example/image-classification/common/fit.py
@@ -103,11 +103,6 @@ def add_fit_args(parser):
help='1 means test reading speed without training')
train.add_argument('--dtype', type=str, default='float32',
help='precision: float32 or float16')
- train.add_argument('--gc-type', type=str, default='none',
- help='type of gradient compression to use, \
- takes `2bit` or `none` for now')
- train.add_argument('--gc-threshold', type=float, default=0.5,
- help='threshold for 2bit gradient compression')
return train
def fit(args, network, data_loader, **kwargs):
@@ -119,9 +114,6 @@ def fit(args, network, data_loader, **kwargs):
"""
# kvstore
kv = mx.kvstore.create(args.kv_store)
- if args.gc_type != 'none':
- kv.set_gradient_compression({'type': args.gc_type,
- 'threshold': args.gc_threshold})
# logging
head = '%(asctime)-15s Node[' + str(kv.rank) + '] %(message)s'
@@ -170,10 +162,10 @@ def fit(args, network, data_loader, **kwargs):
lr_scheduler = lr_scheduler
optimizer_params = {
- 'learning_rate': lr,
- 'wd' : args.wd,
- 'lr_scheduler': lr_scheduler,
- 'multi_precision': True}
+ 'learning_rate': lr,
+ 'wd' : args.wd,
+ 'lr_scheduler': lr_scheduler,
+ 'multi_precision': True}
# Only a limited number of optimizers have 'momentum' property
has_momentum = {'sgd', 'dcasgd', 'nag'}
@@ -203,17 +195,17 @@ def fit(args, network, data_loader, **kwargs):
# run
model.fit(train,
- begin_epoch = args.load_epoch if args.load_epoch else 0,
- num_epoch = args.num_epochs,
- eval_data = val,
- eval_metric = eval_metrics,
- kvstore = kv,
- optimizer = args.optimizer,
- optimizer_params = optimizer_params,
- initializer = initializer,
- arg_params = arg_params,
- aux_params = aux_params,
- batch_end_callback = batch_end_callbacks,
- epoch_end_callback = checkpoint,
- allow_missing = True,
- monitor = monitor)
+ begin_epoch = args.load_epoch if args.load_epoch else 0,
+ num_epoch = args.num_epochs,
+ eval_data = val,
+ eval_metric = eval_metrics,
+ kvstore = kv,
+ optimizer = args.optimizer,
+ optimizer_params = optimizer_params,
+ initializer = initializer,
+ arg_params = arg_params,
+ aux_params = aux_params,
+ batch_end_callback = batch_end_callbacks,
+ epoch_end_callback = checkpoint,
+ allow_missing = True,
+ monitor = monitor)
diff --git a/example/rnn/lstm_bucketing.py b/example/rnn/lstm_bucketing.py
index 0e7f064..2e7bc65 100644
--- a/example/rnn/lstm_bucketing.py
+++ b/example/rnn/lstm_bucketing.py
@@ -48,6 +48,7 @@ parser.add_argument('--batch-size', type=int, default=32,
parser.add_argument('--disp-batches', type=int, default=50,
help='show progress for every n batches')
+
def tokenize_text(fname, vocab=None, invalid_label=-1, start_label=0):
if not os.path.isfile(fname):
raise IOError("Please use get_ptb_data.sh to download requied file (data/ptb.train.txt)")
diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index a81193e..77fc6a5 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -1550,19 +1550,6 @@ MXNET_DLL int MXInitPSEnv(mx_uint num_vars,
*/
MXNET_DLL int MXKVStoreCreate(const char *type,
KVStoreHandle *out);
-
-/*!
- * \brief Set parameters to use low-bit compressed gradients
- * \param handle handle to the kvstore
- * \param keys keys for compression parameters
- * \param vals values for compression parameters
- * \return 0 when success, -1 when failure happens
- */
-MXNET_DLL int MXKVStoreSetGradientCompression(KVStoreHandle handle,
- mx_uint num_params,
- const char** keys,
- const char** vals);
-
/*!
* \brief Delete a KVStore handle.
* \param handle handle to the kvstore
diff --git a/include/mxnet/kvstore.h b/include/mxnet/kvstore.h
index 4e99a9c..1649c43 100644
--- a/include/mxnet/kvstore.h
+++ b/include/mxnet/kvstore.h
@@ -31,7 +31,6 @@
#include <string>
#include <functional>
#include <atomic>
-#include "../../src/kvstore/gradient_compression.h"
#include "./ndarray.h"
#if MXNET_USE_DIST_KVSTORE
#include "ps/ps.h"
@@ -66,14 +65,6 @@ class KVStore {
*/
inline const std::string& type() { return type_; }
- /**
- * \brief Set parameters to use low-bit compressed gradients
- * \param compression_type type of compression
- * \param threshold threshold for 2bit compression
- */
- virtual void SetGradientCompression(const std::vector<std::pair<std::string, std::string> >
- & kwargs) = 0;
-
/*!
* \brief Initialize a list of key-value pair to the store.
*
@@ -397,12 +388,6 @@ class KVStore {
*/
std::string type_;
- /** \brief Gradient compression object starts with GC_NONE mode
- * Used if SetGradientCompression sets the type.
- * Currently there is no support for un-setting gradient compression
- */
- std::shared_ptr<kvstore::GradientCompression> gradient_compression_;
-
/**
* \brief whether to do barrier when finalize
*/
diff --git a/python/mxnet/gluon/trainer.py b/python/mxnet/gluon/trainer.py
index f3a1460..115d1ff 100644
--- a/python/mxnet/gluon/trainer.py
+++ b/python/mxnet/gluon/trainer.py
@@ -44,11 +44,6 @@ class Trainer(object):
kvstore : str or KVStore
kvstore type for multi-gpu and distributed training. See help on
:any:`mxnet.kvstore.create` for more information.
- compression_params : dict
- Specifies type of gradient compression and additional arguments depending
- on the type of compression being used. For example, 2bit compression requires a threshold.
- Arguments would then be {'type':'2bit', 'threshold':0.5}
- See mxnet.KVStore.set_gradient_compression method for more details on gradient compression.
Properties
----------
@@ -56,8 +51,7 @@ class Trainer(object):
The current learning rate of the optimizer. Given an Optimizer object
optimizer, its learning rate can be accessed as optimizer.learning_rate.
"""
- def __init__(self, params, optimizer, optimizer_params=None, kvstore='device',
- compression_params=None):
+ def __init__(self, params, optimizer, optimizer_params=None, kvstore='device'):
if isinstance(params, (dict, ParameterDict)):
params = list(params.values())
if not isinstance(params, (list, tuple)):
@@ -71,7 +65,7 @@ class Trainer(object):
"First argument must be a list or dict of Parameters, " \
"got list of %s."%(type(param)))
self._params.append(param)
- self._compression_params = compression_params
+
optimizer_params = optimizer_params if optimizer_params else {}
self._scale = optimizer_params.get('rescale_grad', 1.0)
self._contexts = self._check_contexts()
@@ -110,8 +104,6 @@ class Trainer(object):
kvstore, update_on_kvstore = _create_kvstore(self._kvstore, len(self._contexts),
arg_arrays)
if kvstore:
- if self._compression_params:
- kvstore.set_gradient_compression(self._compression_params)
if 'dist' in kvstore.type:
update_on_kvstore = False
for i, param in enumerate(self._params):
diff --git a/python/mxnet/kvstore.py b/python/mxnet/kvstore.py
index bf42455..8625303 100644
--- a/python/mxnet/kvstore.py
+++ b/python/mxnet/kvstore.py
@@ -64,16 +64,6 @@ def _ctype_key_value(keys, vals):
else c_array_buf(ctypes.c_int, array('i', [keys] * len(vals)))
return (c_keys, c_handle_array(vals), use_str_keys)
-def _ctype_dict(param_dict):
- """
- Returns ctype arrays for keys and values(converted to strings) in a dictionary
- """
- assert(isinstance(param_dict, dict)), \
- "unexpected type for param_dict: " + str(type(param_dict))
- c_keys = c_array(ctypes.c_char_p, [c_str(k) for k in param_dict.keys()])
- c_vals = c_array(ctypes.c_char_p, [c_str(str(v)) for v in param_dict.values()])
- return (c_keys, c_vals)
-
def _updater_wrapper(updater):
"""A wrapper for the user-defined handle."""
def updater_handle(key, lhs_handle, rhs_handle, _):
@@ -360,58 +350,6 @@ class KVStore(object):
check_call(_LIB.MXKVStorePullRowSparse(
self.handle, mx_uint(len(ckeys)), ckeys, cvals, crow_ids, ctypes.c_int(priority)))
- def set_gradient_compression(self, compression_params):
- """ Specifies type of low-bit quantization for gradient compression \
- and additional arguments depending on the type of compression being used.
-
- 2bit Gradient Compression takes a positive float `threshold`.
- The technique works by thresholding values such that positive values in the
- gradient above threshold will be set to threshold. Negative values whose absolute
- values are higher than threshold, will be set to the negative of threshold.
- Values whose absolute values are less than threshold will be set to 0.
- By doing so, each value in the gradient is in one of three states. 2bits are
- used to represent these states, and every 16 float values in the original
- gradient can be represented using one float. This compressed representation
- can reduce communication costs. The difference between these thresholded values and
- original values is stored at the sender's end as residual and added to the
- gradient in the next iteration.
-
- When kvstore is 'local', gradient compression is used to reduce communication
- between multiple devices (gpus). Gradient is quantized on each GPU which
- computed the gradients, then sent to the GPU which merges the gradients. This
- receiving GPU dequantizes the gradients and merges them. Note that this
- increases memory usage on each GPU because of the residual array stored.
-
- When kvstore is 'dist', gradient compression is used to reduce communication
- from worker to sender. Gradient is quantized on each worker which
- computed the gradients, then sent to the server which dequantizes
- this data and merges the gradients from each worker. Note that this
- increases CPU memory usage on each worker because of the residual array stored.
- Only worker to server communication is compressed in this setting.
- If each machine has multiple GPUs, currently this GPU to GPU or GPU to CPU communication
- is not compressed. Server to worker communication (in the case of pull)
- is also not compressed.
-
- To use 2bit compression, we need to specify `type` as `2bit`.
- Only specifying `type` would use default value for the threshold.
- To completely specify the arguments for 2bit compression, we would need to pass
- a dictionary which includes `threshold` like:
- {'type': '2bit', 'threshold': 0.5}
-
- Parameters
- ----------
- compression_params : dict
- A dictionary specifying the type and parameters for gradient compression.
- The key `type` in this dictionary is a
- required string argument and specifies the type of gradient compression.
- Currently `type` can be only `2bit`
- Other keys in this dictionary are optional and specific to the type
- of gradient compression.
- """
- ckeys, cvals = _ctype_dict(compression_params)
- check_call(_LIB.MXKVStoreSetGradientCompression(self.handle,
- mx_uint(len(compression_params)),
- ckeys, cvals))
def set_optimizer(self, optimizer):
""" Registers an optimizer with the kvstore.
diff --git a/python/mxnet/module/bucketing_module.py b/python/mxnet/module/bucketing_module.py
index 4a5330e..dd6cafb 100644
--- a/python/mxnet/module/bucketing_module.py
+++ b/python/mxnet/module/bucketing_module.py
@@ -54,16 +54,10 @@ class BucketingModule(BaseModule):
Instead they are initialized to 0 and can be set by set_states()
group2ctxs : list of dict of str to context
Default is `None`. Mapping the `ctx_group` attribute to the context assignment.
- compression_params : dict
- Specifies type of gradient compression and additional arguments depending
- on the type of compression being used. For example, 2bit compression requires a threshold.
- Arguments would then be {'type':'2bit', 'threshold':0.5}
- See mxnet.KVStore.set_gradient_compression method for more details on gradient compression.
"""
def __init__(self, sym_gen, default_bucket_key=None, logger=logging,
context=ctx.cpu(), work_load_list=None,
- fixed_param_names=None, state_names=None, group2ctxs=None,
- compression_params=None):
+ fixed_param_names=None, state_names=None, group2ctxs=None):
super(BucketingModule, self).__init__(logger=logger)
assert default_bucket_key is not None
@@ -81,7 +75,6 @@ class BucketingModule(BaseModule):
_check_input_names(symbol, state_names, "state", True)
_check_input_names(symbol, fixed_param_names, "fixed_param", True)
- self._compression_params = compression_params
self._fixed_param_names = fixed_param_names
self._state_names = state_names
self._context = context
@@ -329,9 +322,7 @@ class BucketingModule(BaseModule):
module = Module(symbol, data_names, label_names, logger=self.logger,
context=self._context, work_load_list=self._work_load_list,
fixed_param_names=self._fixed_param_names,
- state_names=self._state_names,
- group2ctxs=self._group2ctxs,
- compression_params=self._compression_params)
+ state_names=self._state_names, group2ctxs=self._group2ctxs)
module.bind(data_shapes, label_shapes, for_training, inputs_need_grad,
force_rebind=False, shared_module=None, grad_req=grad_req)
self._curr_module = module
@@ -361,9 +352,7 @@ class BucketingModule(BaseModule):
logger=self.logger, context=self._context,
work_load_list=self._work_load_list,
fixed_param_names=self._fixed_param_names,
- state_names=self._state_names,
- group2ctxs=self._group2ctxs,
- compression_params=self._compression_params)
+ state_names=self._state_names, group2ctxs=self._group2ctxs)
module.bind(data_shapes, label_shapes, self._curr_module.for_training,
self._curr_module.inputs_need_grad,
force_rebind=False, shared_module=self._buckets[self._default_bucket_key])
diff --git a/python/mxnet/module/module.py b/python/mxnet/module/module.py
index a9c6516..8301330 100644
--- a/python/mxnet/module/module.py
+++ b/python/mxnet/module/module.py
@@ -61,16 +61,10 @@ class Module(BaseModule):
Instead they are initialized to 0 and can be set by `set_states()`.
group2ctxs : list of dict of str to context
Default is `None`. Mapping the `ctx_group` attribute to the context assignment.
- compression_params : dict
- Specifies type of gradient compression and additional arguments depending
- on the type of compression being used. For example, 2bit compression requires a threshold.
- Arguments would then be {'type':'2bit', 'threshold':0.5}
- See mxnet.KVStore.set_gradient_compression method for more details on gradient compression.
"""
def __init__(self, symbol, data_names=('data',), label_names=('softmax_label',),
logger=logging, context=ctx.cpu(), work_load_list=None,
- fixed_param_names=None, state_names=None, group2ctxs=None,
- compression_params=None):
+ fixed_param_names=None, state_names=None, group2ctxs=None):
super(Module, self).__init__(logger=logger)
if isinstance(context, ctx.Context):
@@ -109,7 +103,6 @@ class Module(BaseModule):
self._aux_params = None
self._params_dirty = False
- self._compression_params = compression_params
self._optimizer = None
self._kvstore = None
self._update_on_kvstore = None
@@ -532,8 +525,6 @@ class Module(BaseModule):
self._updater = None
if kvstore:
- if self._compression_params:
- kvstore.set_gradient_compression(self._compression_params)
# copy initialized local parameters to kvstore
_initialize_kvstore(kvstore=kvstore,
param_arrays=self._exec_group.param_arrays,
diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc
index 027f00b..0dde004 100644
--- a/src/c_api/c_api.cc
+++ b/src/c_api/c_api.cc
@@ -748,20 +748,6 @@ int MXKVStoreCreate(const char *type,
API_END();
}
-int MXKVStoreSetGradientCompression(KVStoreHandle handle, mx_uint num_params,
- const char** keys, const char** vals) {
- API_BEGIN();
- std::vector<std::pair<std::string, std::string> > params;
- for (mx_uint i = 0; i < num_params; ++i) {
- std::pair<std::string, std::string> p;
- p.first = keys[i];
- p.second = vals[i];
- params.push_back(p);
- }
- static_cast<KVStore*>(handle)->SetGradientCompression(params);
- API_END();
-}
-
int MXKVStoreFree(KVStoreHandle handle) {
API_BEGIN();
delete static_cast<KVStore*>(handle);
diff --git a/src/kvstore/comm.h b/src/kvstore/comm.h
index 5e15c2a..fcf1e6b 100644
--- a/src/kvstore/comm.h
+++ b/src/kvstore/comm.h
@@ -31,7 +31,6 @@
#include <tuple>
#include <thread>
#include "mxnet/ndarray.h"
-#include "gradient_compression.h"
#include "../ndarray/ndarray_function.h"
#include "../operator/tensor/sparse_retain-inl.h"
namespace mxnet {
@@ -81,18 +80,8 @@ class Comm {
return pinned_ctx_;
}
- /**
- * \brief Sets gradient compression parameters to be able to
- * perform reduce with compressed gradients
- */
- void SetGradientCompression(std::shared_ptr<GradientCompression> gc) {
- gc_ = gc;
- }
-
protected:
Context pinned_ctx_;
-
- std::shared_ptr<GradientCompression> gc_;
};
/**
@@ -496,7 +485,14 @@ class CommDevice : public Comm {
}
}
- void InitBuffersAndComm(const std::vector<NDArray>& src) {
+ const NDArray& Reduce(int key, const std::vector<NDArray>& src,
+ int priority) override {
+ // avoid extra copy for single device, but it may bring problems for
+ // abnormal usage of kvstore
+ if (src.size() == 1) {
+ return src[0];
+ }
+
if (!inited_) {
std::vector<Context> devs;
for (const auto& a : src) {
@@ -507,23 +503,7 @@ class CommDevice : public Comm {
EnableP2P(devs);
}
}
- }
-
- const NDArray& Reduce(int key, const std::vector<NDArray>& src,
- int priority) override {
- // when this reduce is called from kvstore_dist, gc is not set
- // we don't do compression twice in dist_sync_device
- if ((gc_ != nullptr) && (gc_->get_type() != CompressionType::kNone)) {
- return ReduceCompressed(key, src, priority);
- }
-
- // avoid extra copy for single device, but it may bring problems for
- // abnormal usage of kvstore
- if (src.size() == 1) {
- return src[0];
- }
- InitBuffersAndComm(src);
auto& buf = merge_buf_[key];
std::vector<NDArray> reduce(src.size());
CopyFromTo(src[0], &(buf.merged), priority);
@@ -546,52 +526,7 @@ class CommDevice : public Comm {
}
ElementwiseSum(reduce, &buf.merged);
- return buf.merged;
- }
-
- const NDArray& ReduceCompressed(int key, const std::vector<NDArray>& src,
- int priority) {
- InitBuffersAndComm(src);
- auto& buf = merge_buf_[key];
- std::vector<NDArray> reduce(src.size());
- if (buf.copy_buf.empty()) {
- // one buf for each context
- buf.copy_buf.resize(src.size());
- buf.compressed_recv_buf.resize(src.size());
- buf.compressed_send_buf.resize(src.size());
- buf.residual.resize(src.size());
- for (size_t i = 0; i < src.size(); ++i) {
- buf.copy_buf[i] = NDArray(buf.merged.shape(), buf.merged.ctx(),
- false, buf.merged.dtype());
- buf.residual[i] = NDArray(buf.merged.shape(), src[i].ctx(),
- false, buf.merged.dtype());
- buf.residual[i] = 0;
- int64_t small_size = gc_->GetCompressedSize(buf.merged.shape().Size());
- buf.compressed_recv_buf[i] = NDArray(TShape{small_size}, buf.merged.ctx(),
- false, buf.merged.dtype());
- buf.compressed_send_buf[i] = NDArray(TShape{small_size}, src[i].ctx(),
- false, buf.merged.dtype());
- }
- }
-
- for (size_t i = 0; i < src.size(); ++i) {
- // compress before copy
- // this is done even if the data is on same context as copy_buf because
- // we don't want the training to be biased towards data on this GPU
- gc_->Quantize(src[i], &(buf.compressed_send_buf[i]), &(buf.residual[i]), priority);
-
- if (buf.compressed_send_buf[i].ctx() != buf.compressed_recv_buf[i].ctx()) {
- CopyFromTo(buf.compressed_send_buf[i], &(buf.compressed_recv_buf[i]), priority);
- } else {
- // avoid memory copy when they are on same context
- buf.compressed_recv_buf[i] = buf.compressed_send_buf[i];
- }
-
- gc_->Dequantize(buf.compressed_recv_buf[i], &(buf.copy_buf[i]), priority);
- reduce[i] = buf.copy_buf[i];
- }
- ElementwiseSum(reduce, &buf.merged);
return buf.merged;
}
@@ -704,12 +639,6 @@ class CommDevice : public Comm {
NDArray merged;
/// \brief the gpu buffer
std::vector<NDArray> copy_buf;
- /// \brief the residual buffer for gradient compression
- std::vector<NDArray> residual;
- /// \brief the small buffer for compressed data in sender
- std::vector<NDArray> compressed_send_buf;
- /// \brief the small buffer for compressed data in receiver
- std::vector<NDArray> compressed_recv_buf;
};
std::unordered_map<int, BufferEntry> merge_buf_;
bool inited_;
diff --git a/src/kvstore/gradient_compression-inl.h b/src/kvstore/gradient_compression-inl.h
deleted file mode 100644
index 9b69bd1..0000000
--- a/src/kvstore/gradient_compression-inl.h
+++ /dev/null
@@ -1,155 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file gradient_compression-inl.h
- * \author Rahul Huilgol
- * \brief Declares and defines functions used to quantize and dequantize data
- */
-#ifndef MXNET_KVSTORE_GRADIENT_COMPRESSION_INL_H_
-#define MXNET_KVSTORE_GRADIENT_COMPRESSION_INL_H_
-
-#include <vector>
-#include "../operator/mxnet_op.h"
-
-namespace mxnet {
-namespace kvstore {
-
-// these gpu functions are defined in gradient_compression.cu
-void Quantize2BitImpl(mshadow::Stream<mshadow::gpu> *s, const std::vector<mxnet::TBlob> &inputs,
- const float threshold);
-void Dequantize2BitImpl(mshadow::Stream<mshadow::gpu> *s, const std::vector<mxnet::TBlob> &inputs,
- const float threshold);
-
-struct quantize_2bit {
- MSHADOW_XINLINE static void Map(int out_block_id,
- int original_size,
- float *out,
- float *grad,
- float *residual,
- const float neg_threshold,
- const float pos_threshold) {
- // this block contains the compressed representation of
- // upto 16 values starting from out_block_id*16
- float *compr_block = out + out_block_id;
- // init to 0
- *compr_block = 0;
- // start and end are indices in original grad array
- const int start = out_block_id << 4;
- const int end = (start + 16 <= original_size) ? start + 16 : original_size;
- // cast as char* to manipulate bits of float addresses
- char *block_ptr = reinterpret_cast < char * > (compr_block);
- // masks to set bits when value meets pos_threshold
- // 0xc0 is mask when value is to be represented by the first two bits in a char*
- // 0xc0 means first two bits are set to 11
- const uint8_t posbits[] = {0xc0, 0x30, 0x0c, 0x03};
- // masks to set bits when value meets neg_threshold
- const uint8_t negbits[] = {0x80, 0x20, 0x08, 0x02};
- for (int i = start; i < end; i++) {
- // adds offset to reach appropriate byte
- char *curr_byte = block_ptr + ((i - start) >> 2);
- // adds gradient to existing residual to get updated grad
- residual[i] += grad[i];
- if (residual[i] >= pos_threshold) {
- // set data to 11
- *curr_byte |= posbits[(i & 3)];
- // reduce residual by pos_threshold
- residual[i] -= pos_threshold;
- } else if (residual[i] <= neg_threshold) {
- // set data to 10
- *curr_byte |= negbits[(i & 3)];
- residual[i] -= neg_threshold;
- }
- }
- }
-};
-
-template<typename xpu>
-void Quantize2BitKernelLaunch(mshadow::Stream<xpu> *s, const std::vector<mxnet::TBlob> &inputs,
- const float threshold) {
- mxnet::op::mxnet_op::Kernel<quantize_2bit, xpu>
- ::Launch(s,
- inputs[2].Size(), // compressed array size
- inputs[0].Size(), // original size
- inputs[2].dptr<float>(), // compressed array
- inputs[0].dptr<float>(), // original array
- inputs[1].dptr<float>(), // residual array
- -1 *threshold, // negative threshold
- threshold); // positive threshold
-}
-
-struct dequantize_2bit {
- MSHADOW_XINLINE static void Map(int i,
- float *out,
- float *in,
- const float neg_threshold,
- const float pos_threshold) {
- // get position of dequantized value to fill
- float *outval = out + i;
- // gets byte which holds quantized value for this position
- char *ch_ptr = reinterpret_cast<char *>(in + (i >> 4));
- ch_ptr += ((i & 15) >> 2);
- // masks used to quantize data
- const uint8_t posbits[] = {0xc0, 0x30, 0x0c, 0x03};
- const uint8_t negbits[] = {0x80, 0x20, 0x08, 0x02};
- // col denotes which two bits of a byte are set for this value
- // col=0 implies first two bits, col=3 implies last two bits,...
- const int col = i & 3;
- const uint8_t mask = posbits[col];
- const uint8_t negmask = negbits[col];
- const uint8_t masked = *ch_ptr & mask;
- if (masked == mask) {
- *outval = pos_threshold;
- } else if (masked == negmask) {
- // use posbits for mask as posbits are both 1s
- // then compare masked with negbits to see if only negbits were set
- *outval = neg_threshold;
- } else {
- *outval = 0;
- }
- }
-};
-
-template<typename xpu>
-void Dequantize2BitKernelLaunch(mshadow::Stream<xpu> *s, const std::vector<mxnet::TBlob> &inputs,
- const float threshold) {
- mxnet::op::mxnet_op::Kernel<dequantize_2bit, xpu>
- ::Launch(s,
- inputs[1].Size(), // original size
- inputs[1].dptr<float>(), // out array
- inputs[0].dptr<float>(), // compressed array
- -1 *threshold, // negative threshold
- threshold); // positive threshold
-}
-
-inline void Quantize2BitImpl(mshadow::Stream<mshadow::cpu> *s,
- const std::vector<mxnet::TBlob> &inputs,
- const float threshold) {
- Quantize2BitKernelLaunch(s, inputs, threshold);
-}
-
-inline void Dequantize2BitImpl(mshadow::Stream<mshadow::cpu> *s,
- const std::vector<mxnet::TBlob> &inputs,
- const float threshold) {
- Dequantize2BitKernelLaunch(s, inputs, threshold);
-}
-} // namespace kvstore
-} // namespace mxnet
-
-#endif // MXNET_KVSTORE_GRADIENT_COMPRESSION_INL_H_
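
The removed kernel above packs sixteen ternary values into each 32-bit float of the compressed array, setting two bits per value through the posbits/negbits masks. A hypothetical pure-Python helper that mirrors only that mask arithmetic (the real kernel writes the bits through a reinterpreted char* into the compressed float array):

    POSBITS = (0xc0, 0x30, 0x0c, 0x03)   # '11' in each of the four 2-bit slots of a byte
    NEGBITS = (0x80, 0x20, 0x08, 0x02)   # '10' in the same slots

    def pack_block(states):
        """Pack up to 16 values from {+1, 0, -1} into the 4 bytes of one block."""
        block = bytearray(4)
        for i, s in enumerate(states):
            byte_idx = (i & 15) >> 2          # which byte of the block holds value i
            if s > 0:
                block[byte_idx] |= POSBITS[i & 3]
            elif s < 0:
                block[byte_idx] |= NEGBITS[i & 3]
        return bytes(block)

    # 16 states -> 4 bytes, i.e. one float32 worth of storage per block
    print(pack_block([1, -1, 0, 0] * 4).hex())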
diff --git a/src/kvstore/gradient_compression.cc b/src/kvstore/gradient_compression.cc
deleted file mode 100644
index b8c626c..0000000
--- a/src/kvstore/gradient_compression.cc
+++ /dev/null
@@ -1,193 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file gradient_compression.cc
- * \brief Gradient compression for kvstore
- * \author Rahul Huilgol
- */
-
-#include <sstream>
-#include <vector>
-#include "gradient_compression.h"
-#include "gradient_compression-inl.h"
-
-namespace mxnet {
-namespace kvstore {
-
-/*!
- * \brief Splits a string into smaller strings using char as delimiter
- * Example: "a,b,c,,d" is split into ["a","b","c","","d"]
- * \param s string to split
- * \param delim char to split string around
- * \param result container for tokens extracted after splitting
- */
-template<typename Out>
-void split(const std::string &s, const char delim, Out result) {
- std::stringstream ss;
- ss.str(s);
- std::string item;
- while (std::getline(ss, item, delim)) {
- *(result++) = item;
- }
-}
-
-DMLC_REGISTER_PARAMETER(GradientCompressionParam);
-
-GradientCompression::GradientCompression() {
- type_ = CompressionType::kNone;
-}
-
-void GradientCompression::SetParams(const std::vector<std::pair<std::string, std::string> >
- & kwargs) {
- GradientCompressionParam params;
- params.InitAllowUnknown(kwargs);
- CHECK_GT(params.threshold, 0) << "threshold must be greater than 0";
- if (params.type == "2bit") {
- SetTwoBitCompression(params.threshold);
- } else {
- LOG(FATAL) << "Unknown type for gradient compression " << params.type;
- }
-}
-
-CompressionType GradientCompression::get_type() {
- return type_;
-}
-
-std::string GradientCompression::get_type_str() {
- return std::to_string(static_cast<int>(type_));
-}
-
-void GradientCompression::SetTwoBitCompression(const float threshold) {
- type_ = CompressionType::kTwoBit;
- threshold_ = threshold;
-}
-
-std::string GradientCompression::EncodeParams() {
- using namespace std; // to reduce length of next line
- string rval = get_type_str();
- if (type_ == CompressionType::kTwoBit) {
- rval += "," + to_string(threshold_);
- }
- return rval;
-}
-
-void GradientCompression::DecodeParams(const std::string &s) {
- std::vector<std::string> elems;
- split(s, ',', std::back_inserter(elems));
- type_ = static_cast<CompressionType>(stoi(elems[0]));
- if (elems.size() > 1) {
- if (!elems[1].empty()) {
- threshold_ = stof(elems[1]);
- }
- }
-}
-
-int GradientCompression::GetCompressionFactor() {
- if (type_ == CompressionType::kTwoBit) {
- return 16;
- } else {
- LOG(FATAL) << "Unsupported compression type: " << get_type_str();
- return 0;
- }
-}
-
-int64_t GradientCompression::GetCompressedSize(const int64_t original_size) {
- const int bits = GetCompressionFactor();
- return ((original_size % bits == 0) ?
- original_size / bits :
- original_size / bits + 1);
-}
-
-void GradientCompression::Quantize(const mxnet::NDArray &from, mxnet::NDArray *to,
- mxnet::NDArray *residual, const int priority) {
- CHECK(from.shape().ndim() != 0) << "source operand has zero dimension shape";
- CHECK(to->shape().ndim() != 0) << "destination operand has zero dimension shape";
- CHECK(residual->shape().ndim() != 0) << "residual operand has zero dimension shape";
- const int a = from.ctx().dev_mask();
- const int b = to->ctx().dev_mask();
- const float threshold = threshold_;
- if (type_ == CompressionType::kTwoBit) {
- if (a == mshadow::cpu::kDevMask && b == mshadow::cpu::kDevMask) {
- mxnet::Engine::Get()->PushSync([from, to, residual, threshold](mxnet::RunContext ctx) {
- std::vector<mxnet::TBlob> inputs = {from.data(), residual->data(), to->data()};
- Quantize2BitImpl(ctx.get_stream<mshadow::cpu>(), inputs, threshold);
- }, from.ctx(), {from.var()}, {to->var(), residual->var()},
- mxnet::FnProperty::kNormal, priority, PROFILER_MESSAGE("QuantizeCPU"));
- } else {
-#if MXNET_USE_CUDA
- if (a == mshadow::gpu::kDevMask && b == mshadow::gpu::kDevMask) {
- mxnet::Engine::Get()->PushSync([from, to, residual, threshold](mxnet::RunContext ctx) {
- std::vector<mxnet::TBlob> inputs = {from.data(), residual->data(), to->data()};
- Quantize2BitImpl(ctx.get_stream<mshadow::gpu>(), inputs, threshold);
- // Wait GPU kernel to complete
- ctx.get_stream<mshadow::gpu>()->Wait();
- }, from.ctx(), {from.var()}, {to->var(), residual->var()},
- mxnet::FnProperty::kNormal, priority, PROFILER_MESSAGE("QuantizeGPU"));
- } else {
- LOG(FATAL) << "unknown device mask";
- }
-#else
- LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
-#endif
- }
- } else {
- LOG(FATAL) << "Unsupported quantization of type " << get_type_str();
- }
-}
-
-void GradientCompression::Dequantize(const mxnet::NDArray &from, mxnet::NDArray *to,
- const int priority) {
- CHECK(from.shape().ndim() != 0) << "source operands has zero dimension shape";
- CHECK(to->shape().ndim() != 0) << "destination operand has zero dimension shape";
- const int a = from.ctx().dev_mask();
- const int b = to->ctx().dev_mask();
- const float threshold = threshold_;
- if (type_ == CompressionType::kTwoBit) {
- if (a == mshadow::cpu::kDevMask && b == mshadow::cpu::kDevMask) {
- mxnet::Engine::Get()->PushSync([from, to, threshold](mxnet::RunContext ctx) {
- std::vector<mxnet::TBlob> inputs = {from.data(), to->data()};
- Dequantize2BitImpl(ctx.get_stream<mshadow::cpu>(), inputs, threshold);
- }, from.ctx(), {from.var()}, {to->var()},
- mxnet::FnProperty::kNormal, priority, PROFILER_MESSAGE("DequantizeCPU"));
- } else {
-#if MXNET_USE_CUDA
- if (a == mshadow::gpu::kDevMask && b == mshadow::gpu::kDevMask) {
- mxnet::Engine::Get()->PushSync([from, to, threshold](mxnet::RunContext ctx) {
- std::vector<mxnet::TBlob> inputs = {from.data(), to->data()};
- Dequantize2BitImpl(ctx.get_stream<mshadow::gpu>(), inputs, threshold);
- // Wait GPU kernel to complete
- ctx.get_stream<mshadow::gpu>()->Wait();
- }, from.ctx(), {from.var()}, {to->var()},
- mxnet::FnProperty::kNormal, priority, PROFILER_MESSAGE("DequantizeGPU"));
- } else {
- LOG(FATAL) << "unknown device mask";
- }
-#else
- LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
-#endif
- }
- } else {
- LOG(FATAL) << "Unsupported dequantization of type " << get_type_str();
- }
-}
-
-} // namespace kvstore
-} // namespace mxnet
-
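
The removed GetCompressedSize above is a ceiling division by the compression factor of 16 (sixteen original float32 values per compressed float32). The same arithmetic, for reference:

    def compressed_size(original_size, factor=16):
        # ceil(original_size / factor), matching the removed C++ branch
        return (original_size + factor - 1) // factor

    assert compressed_size(32) == 2
    assert compressed_size(33) == 3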
diff --git a/src/kvstore/gradient_compression.cu b/src/kvstore/gradient_compression.cu
deleted file mode 100644
index b0d9662..0000000
--- a/src/kvstore/gradient_compression.cu
+++ /dev/null
@@ -1,40 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file gradient_compression.cu
- * \author Rahul Huilgol
- * \brief Implementation for gpu version of code
- */
-
-#include "gradient_compression-inl.h"
-
-namespace mxnet {
-namespace kvstore {
-void Quantize2BitImpl(mshadow::Stream<gpu>* s, const std::vector<TBlob>& inputs,
- const float threshold) {
- Quantize2BitKernelLaunch(s, inputs, threshold);
-}
-
-void Dequantize2BitImpl(mshadow::Stream<gpu>* s, const std::vector<TBlob>& inputs,
- const float threshold) {
- Dequantize2BitKernelLaunch(s, inputs, threshold);
-}
-} // namespace kvstore
-} // namespace mxnet
diff --git a/src/kvstore/gradient_compression.h b/src/kvstore/gradient_compression.h
deleted file mode 100644
index f40b45f..0000000
--- a/src/kvstore/gradient_compression.h
+++ /dev/null
@@ -1,138 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file gradient_compression.h
- * \brief Gradient compression for kvstore
- * \author Rahul Huilgol
- */
-
-#ifndef MXNET_KVSTORE_GRADIENT_COMPRESSION_H_
-#define MXNET_KVSTORE_GRADIENT_COMPRESSION_H_
-#include <dmlc/parameter.h>
-#include <string>
-#include <utility>
-#include <vector>
-#include "mxnet/ndarray.h"
-
-namespace mxnet {
-namespace kvstore {
-
-enum class CompressionType {
- kNone, kTwoBit
-};
-
-struct GradientCompressionParam : public dmlc::Parameter<GradientCompressionParam> {
- std::string type;
- float threshold;
- DMLC_DECLARE_PARAMETER(GradientCompressionParam) {
- DMLC_DECLARE_FIELD(type)
- .describe("Type of gradient compression to use, like `2bit` for example");
- DMLC_DECLARE_FIELD(threshold).set_default(0.5)
- .describe("Threshold to use for 2bit gradient compression");
- }
-};
-
-class GradientCompression {
- public:
- GradientCompression();
-
- virtual ~GradientCompression() {}
-
- /*!
- * \brief sets parameters for gradient compression
- * \param kwargs a vector of pair of strings. A pair represents key and value
- * of the parameter. Will be parsed by GradientCompressionParam
- */
- void SetParams(const std::vector<std::pair<std::string, std::string> >& kwargs);
-
- /*!
- * \brief returns type of compression if any
- */
- CompressionType get_type();
-
- /*!
- * \brief returns as string the enum value of compression type
- */
- std::string get_type_str();
-
- /*!
- * \brief sets two bit gradient compression
- * \param threshold float value used for thresholding gradients
- */
- void SetTwoBitCompression(const float threshold);
-
- /*!
- * \brief encodes parameters of gc into a string
- */
- std::string EncodeParams();
-
- /*!
- * \brief decodes parameters of gc from a string and assigns them to member variables
- */
- void DecodeParams(const std::string &s);
-
- /*!
- * \brief returns compression factor, which is the factor by which size of gradient
- * reduces when using a particular type of compression
- */
- int GetCompressionFactor();
-
- /*!
- * \brief returns the size of compressed gradients given an original sized gradient array
- */
- int64_t GetCompressedSize(const int64_t original_size);
-
- /*!
- * \brief Issues quantize operation to be scheduled by the engine
- * Compresses `from` into `to` and accumulates the quantization error
- * into 'residual', using the quantization of type `type_`
- * \param from the ndarray containing original data to be quantized
- * \param to the target ndarray which contains quantized data
- * \param residual the ndarray which accumulates quantization error
- * \param priority Priority of the action.
- */
- void Quantize(const mxnet::NDArray &from, mxnet::NDArray *to,
- mxnet::NDArray *residual, const int priority);
-
- /*!
- * \brief Issues dequantize operation to be scheduled by the engine
- * Decompresses `from` into `to` using current parameters of `type` and `threshold`
- * \param from the ndarray containing quantized data
- * \param to the target ndarray which contains final dequantized data
- * \param priority Priority of the action.
- */
- void Dequantize(const mxnet::NDArray &from, mxnet::NDArray *to, const int priority);
-
- private:
- /*!
- * \brief denotes the type of gradient compression which has been set
- */
- CompressionType type_;
-
- /*!
- * \brief denotes threshold used for quantization and dequantization
- * Must be a positive value. All positive gradients will be thresholded to `threshold_` and
- * all negative gradients will be thresholded to -1*`threshold_`
- */
- float threshold_ = 0;
-};
-} // namespace kvstore
-} // namespace mxnet
-#endif // MXNET_KVSTORE_GRADIENT_COMPRESSION_H_
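
EncodeParams/DecodeParams, declared above, carried the compression settings from worker rank 0 to the servers as a comma-separated string of the type enum and the threshold. A toy Python equivalent of that round trip (using 0 for kNone and 1 for kTwoBit, per the enum order above):

    def encode_params(ctype, threshold):
        s = str(ctype)
        if ctype == 1:                        # kTwoBit carries its threshold
            s += ',' + str(threshold)
        return s

    def decode_params(s):
        parts = s.split(',')
        ctype = int(parts[0])
        threshold = float(parts[1]) if len(parts) > 1 and parts[1] else None
        return ctype, threshold

    assert decode_params(encode_params(1, 0.5)) == (1, 0.5)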
diff --git a/src/kvstore/kvstore.cc b/src/kvstore/kvstore.cc
index ac15873..ac37d5d 100644
--- a/src/kvstore/kvstore.cc
+++ b/src/kvstore/kvstore.cc
@@ -49,7 +49,7 @@ KVStore* KVStore::Create(const char *type_name) {
kv = new kvstore::KVStoreDist(use_device_comm);
if (!has("_async") && kv->IsWorkerNode() && kv->get_rank() == 0) {
// configure the server to be the sync mode
- kv->SendCommandToServers(static_cast<int>(kvstore::CommandType::kSyncMode), "");
+ kv->SendCommandToServers(kvstore::kSyncMode, "");
}
#else
LOG(FATAL) << "compile with USE_DIST_KVSTORE=1 to use " << tname;
diff --git a/src/kvstore/kvstore_dist.h b/src/kvstore/kvstore_dist.h
index b00d0de..571767d 100644
--- a/src/kvstore/kvstore_dist.h
+++ b/src/kvstore/kvstore_dist.h
@@ -69,7 +69,7 @@ class KVStoreDist : public KVStoreLocal {
Barrier();
if (get_rank() == 0) {
// stop the executor at servers
- SendCommandToServers(static_cast<int>(CommandType::kStopServer), "");
+ SendCommandToServers(kStopServer, "");
}
}
ps::Finalize(barrier_before_exit_);
@@ -86,15 +86,6 @@ class KVStoreDist : public KVStoreLocal {
}
}
- void SetGradientCompression(const std::vector<std::pair<std::string, std::string> >
- & kwargs) override {
- KVStoreLocal::SetGradientCompression(kwargs);
- if (get_rank() == 0) {
- SendCommandToServers(static_cast<int>(CommandType::kSetGradientCompression),
- gradient_compression_->EncodeParams());
- }
- }
-
void Barrier() override {
ps::Postoffice::Get()->Barrier(ps::kWorkerGroup);
}
@@ -141,38 +132,6 @@ class KVStoreDist : public KVStoreLocal {
}
private:
- /**
- * \brief struct for ps keys and lens
- */
- struct PSKV {
- ps::SArray<ps::Key> keys; // n keys
- ps::SArray<int> lens; // the length of the i-th value
- int size;
- };
-
- struct ComprPSKV {
- PSKV push;
- PSKV pull;
- };
-
- /**
- * \brief cache all key partitions
- *
- * `ps_kv_` is used for pushes and pulls without gradient compression
- * `compr_ps_kv_` is used for gradient compression. It contains different
- * pskv for push and pull because sizes would be different in both cases.
- * Note: `ps_kv_[k]` for some key k may not be the same as `compr_ps_kv_[k].pull`
- * This is because sharding may cause slightly different divisions when size is
- * not perfectly divisible.
- */
- std::unordered_map<int, PSKV> ps_kv_;
- std::unordered_map<int, ComprPSKV> compr_ps_kv_;
-
- /**
- * \brief serialize access to ps_kv_ or push_ps_kv_/pull_ps_kv_ while encoding keys
- */
- std::mutex mu_;
-
void InitImpl(const std::vector<int>& keys,
const std::vector<NDArray>& values) override {
CheckUnique(keys);
@@ -184,7 +143,6 @@ class KVStoreDist : public KVStoreLocal {
// wait until the push is finished
for (const int key : keys) {
comm_buf_[key].WaitToWrite();
- compr_buf_[key].WaitToWrite();
}
} else {
// do nothing
@@ -224,10 +182,7 @@ class KVStoreDist : public KVStoreLocal {
RunContext rctx, Engine::CallbackOnComplete cb) {
// convert to ps keys
size_t size = recv_buf.shape().Size();
-
- PSKV& pskv = (gradient_compression_->get_type() == CompressionType::kNone) ?
- EncodeDefaultKey(key, size, false) :
- EncodeCompressedKey(key, size, false);
+ PSKV& pskv = EncodeKey(key, size);
#if MKL_EXPERIMENTAL == 1
mkl_set_tblob_eager_mode(recv_buf.data());
#endif
@@ -235,11 +190,8 @@ class KVStoreDist : public KVStoreLocal {
// false means not to delete data when SArray is deleted
auto vals = new ps::SArray<real_t>(data, size, false);
// issue pull
- int cmd = (gradient_compression_->get_type() != CompressionType::kNone) ?
- static_cast<int>(DataHandleType::kCompressedPushPull) :
- static_cast<int>(DataHandleType::kDefaultPushPull);
CHECK_NOTNULL(ps_worker_)->ZPull(
- pskv.keys, vals, &pskv.lens, cmd, [vals, cb](){ delete vals; cb(); });
+ pskv.keys, vals, &pskv.lens, kDefaultPushPull, [vals, cb](){ delete vals; cb(); });
};
CHECK_NOTNULL(Engine::Get())->PushAsync(
@@ -249,7 +201,7 @@ class KVStoreDist : public KVStoreLocal {
{recv_buf.var()},
FnProperty::kNormal,
priority,
- PROFILER_MESSAGE("KVStoreDistDefaultStoragePull"));
+ PROFILER_MESSAGE("KVStoreDistDefaultPull"));
comm_->Broadcast(key, recv_buf, grouped_vals[i], priority);
}
@@ -309,121 +261,103 @@ class KVStoreDist : public KVStoreLocal {
GroupKVPairsPush(keys, values, &uniq_keys, &grouped_vals);
for (size_t i = 0; i < uniq_keys.size(); ++i) {
- // merge over devices
+ // merge over devcies
int key = uniq_keys[i];
const auto& vals = grouped_vals[i];
NDArray merged = do_merge ? comm_->Reduce(key, vals, priority) : vals[0];
+ auto& send_buf = comm_buf_[key];
const auto storage_type = merged.storage_type();
- auto &comm_buf = comm_buf_[key];
if (merged.ctx().dev_mask() == cpu::kDevMask) {
// Start of a push doesn't guarantee that the previous pushes are completed.
// This shouldn't affect training of networks though because training involves
// a sequence of push, pull, then push. This imposes ordering that the
// second push happens after the first pull, and the pull happens after first push.
- comm_buf = merged; // avoid memory copy
+ send_buf = merged; // avoid memory copy
} else {
- if (comm_buf.is_none()) {
+ if (send_buf.is_none()) {
if (storage_type == kDefaultStorage) {
- comm_buf = NDArray(merged.shape(), pinned_ctx_, true, merged.dtype());
+ send_buf = NDArray(merged.shape(), pinned_ctx_, true, merged.dtype());
} else {
- comm_buf = NDArray(storage_type, merged.shape(), pinned_ctx_, true, merged.dtype());
+ send_buf = NDArray(storage_type, merged.shape(), pinned_ctx_, true, merged.dtype());
}
}
- CopyFromTo(merged, &comm_buf);
+ CopyFromTo(merged, &send_buf);
}
// push to servers
if (storage_type == kDefaultStorage) {
- if (gradient_compression_->get_type() == CompressionType::kNone) {
- PSKV& pskv = EncodeDefaultKey(key, comm_buf.shape().Size(), true);
- PushDefault(key, comm_buf, pskv, priority);
- } else {
- // Note: gradient compression uses `do_merge` as proxy to
- // detect whether the push is initialization of a key or not.
- // is_active is false when push is initialization of key
- bool is_active = do_merge;
- PSKV &pskv = EncodeCompressedKey(key, comm_buf.shape().Size(), is_active);
- // Returns push_pskv if active, else pull_pskv
- // we want inactive gc to send uncompressed gradients,
- // but sharded in the same way as later pushes would when gc becomes active
- if (is_active) {
- PushCompressed(key, comm_buf, pskv, priority);
- } else {
- PushDefault(key, comm_buf, pskv, priority);
- }
- }
- } else if (storage_type == kRowSparseStorage) {
- CHECK(gradient_compression_->get_type() == CompressionType::kNone)
- << "Gradient compression for row sparse storage type is not supported";
- PushRowSparse(key, comm_buf, priority);
- } else {
- LOG(FATAL) << "unknown storage type";
- }
- }
- }
-
- void PushCompressed(int key, const NDArray& comm_buf, const PSKV& pskv, int priority) {
- auto &small_buf = compr_buf_[key];
- auto &res_buf = residual_[key];
- size_t original_size = comm_buf.shape().Size();
-
- // Init the small buffer and residual_ buffer for quantize
- if (small_buf.is_none()) {
- small_buf = NDArray(TShape{pskv.size}, comm_buf.ctx(), false, comm_buf.dtype());
- res_buf = NDArray(TShape{(int64_t) original_size}, comm_buf.ctx(),
- false, comm_buf.dtype());
- res_buf = 0;
- }
- gradient_compression_->Quantize(comm_buf, &small_buf, &res_buf, priority);
- auto push_to_servers =
- [this, key, pskv, small_buf](RunContext rctx, Engine::CallbackOnComplete cb) {
- size_t size = small_buf.shape().Size();
- real_t* data = small_buf.data().dptr<real_t>();
-#if MKL_EXPERIMENTAL == 1
- mkl_set_tblob_eager_mode(small_buf.data());
-#endif
- // do push. false means no delete
- ps::SArray<real_t> vals(data, size, false);
- CHECK_NOTNULL(ps_worker_)->ZPush(
- pskv.keys, vals, pskv.lens,
- static_cast<int>(DataHandleType::kCompressedPushPull), [cb]() { cb(); });
- };
- // acquire locks on both comm_buf and small_buf so that
- // pull (which uses comm_buf) for the same key waits till push finishes
- Engine::Get()->PushAsync(
- push_to_servers,
- pinned_ctx_,
- {small_buf.var(), comm_buf.var()},
- {},
- FnProperty::kNormal,
- priority,
- PROFILER_MESSAGE("KVStoreDistCompressedPush"));
- }
-
- void PushDefault(int key, const NDArray &send_buf, const PSKV& pskv, int priority) {
- auto push_to_servers =
- [this, key, pskv, send_buf](RunContext rctx, Engine::CallbackOnComplete cb) {
+ auto push_to_servers =
+ [this, key, send_buf](RunContext rctx, Engine::CallbackOnComplete cb) {
// convert to ps keys
size_t size = send_buf.shape().Size();
- real_t* data = send_buf.data().dptr<real_t>();
+ PSKV& pskv = EncodeKey(key, size);
+
#if MKL_EXPERIMENTAL == 1
mkl_set_tblob_eager_mode(send_buf.data());
#endif
+ real_t* data = send_buf.data().dptr<real_t>();
// do push. false means no delete
ps::SArray<real_t> vals(data, size, false);
CHECK_NOTNULL(ps_worker_)->ZPush(
- pskv.keys, vals, pskv.lens,
- static_cast<int>(DataHandleType::kDefaultPushPull), [cb]() { cb(); });
+ pskv.keys, vals, pskv.lens, 0, [cb]() { cb(); });
};
- Engine::Get()->PushAsync(
- push_to_servers,
+ Engine::Get()->PushAsync(
+ push_to_servers,
+ pinned_ctx_,
+ {send_buf.var()},
+ {},
+ FnProperty::kNormal,
+ priority,
+ PROFILER_MESSAGE("KVStoreDistDefaultPush"));
+ } else if (storage_type == kRowSparseStorage) {
+ PushRowSparse(key, send_buf, priority);
+ } else {
+ LOG(FATAL) << "unknown storage type";
+ }
+ }
+ }
+
+ // pull row sparse weight into `recv_buf` based on indices given by `indices`
+ void PullRowSparse_(const int key, const NDArray& recv_buf,
+ const NDArray& indices, int priority) {
+ using namespace rowsparse;
+ auto pull_from_servers = [this, key, recv_buf, indices]
+ (RunContext rctx, Engine::CallbackOnComplete cb) {
+ // allocate memory for the buffer
+ size_t num_rows = indices.shape().Size();
+ recv_buf.CheckAndAlloc({mshadow::Shape1(num_rows)});
+#if MKL_EXPERIMENTAL == 1
+ mkl_set_tblob_eager_mode(recv_buf.data());
+#endif
+ real_t* data = recv_buf.data().dptr<real_t>();
+ const auto offsets = indices.data().dptr<int64_t>();
+ const auto unit_len = recv_buf.shape().ProdShape(1, recv_buf.shape().ndim());
+ const int64_t size = num_rows * unit_len;
+ // convert to ps keys in row sparse format
+ PSKV& pskv = EncodeRowSparseKey(key, size, num_rows, offsets,
+ unit_len, recv_buf.shape()[0]);
+ if (this->log_verbose_) {
+ LOG(INFO) << "worker " << get_rank() << " pull lens: " << pskv.lens << " keys: "
+ << pskv.keys << " size: " << size;
+ }
+ auto vals = new ps::SArray<real_t>(data, size, false);
+ // copy indices to recv_buf. this needs to be done before ZPull
+ // because after pull is done, the callback function returns and locks are released.
+ // at this point, later functions may access the indices variable while copy happens
+ mshadow::Copy(recv_buf.aux_data(kIdx).FlatTo1D<cpu, int64_t>(),
+ indices.data().FlatTo1D<cpu, int64_t>());
+ CHECK_NOTNULL(ps_worker_)->ZPull(pskv.keys, vals, &pskv.lens, kRowSparsePushPull,
+ [vals, cb]() { delete vals; cb(); });
+ };
+ CHECK_NOTNULL(Engine::Get())->PushAsync(
+ pull_from_servers,
pinned_ctx_,
- {send_buf.var()},
- {},
+ {indices.var()},
+ {recv_buf.var()},
FnProperty::kNormal,
priority,
- PROFILER_MESSAGE("KVStoreDistDefaultPush"));
+ PROFILER_MESSAGE("KVStoreDistRowSparsePull"));
}
// push row sparse gradient
@@ -448,9 +382,9 @@ class KVStoreDist : public KVStoreLocal {
<< pskv.keys << " size: " << size;
}
ps::SArray<real_t> vals(data, size, false);
- CHECK_NOTNULL(ps_worker_)->ZPush(pskv.keys, vals, pskv.lens,
- static_cast<int>(DataHandleType::kRowSparsePushPull),
- [cb]() { cb(); });
+ CHECK_NOTNULL(ps_worker_)->ZPush(pskv.keys, vals, pskv.lens, kRowSparsePushPull, [cb]() {
+ cb();
+ });
};
Engine::Get()->PushAsync(
push_to_servers,
@@ -462,50 +396,6 @@ class KVStoreDist : public KVStoreLocal {
PROFILER_MESSAGE("KVStoreDistRowSparsePush"));
}
-
- // pull row sparse weight into `recv_buf` based on indices given by `indices`
- void PullRowSparse_(const int key, const NDArray& recv_buf,
- const NDArray& indices, int priority) {
- using namespace rowsparse;
- auto pull_from_servers = [this, key, recv_buf, indices]
- (RunContext rctx, Engine::CallbackOnComplete cb) {
- // allocate memory for the buffer
- size_t num_rows = indices.shape().Size();
- recv_buf.CheckAndAlloc({mshadow::Shape1(num_rows)});
-#if MKL_EXPERIMENTAL == 1
- mkl_set_tblob_eager_mode(recv_buf.data());
-#endif
- real_t* data = recv_buf.data().dptr<real_t>();
- const auto offsets = indices.data().dptr<int64_t>();
- const auto unit_len = recv_buf.shape().ProdShape(1, recv_buf.shape().ndim());
- const int64_t size = num_rows * unit_len;
- // convert to ps keys in row sparse format
- PSKV& pskv = EncodeRowSparseKey(key, size, num_rows, offsets,
- unit_len, recv_buf.shape()[0]);
- if (this->log_verbose_) {
- LOG(INFO) << "worker " << get_rank() << " pull lens: " << pskv.lens << " keys: "
- << pskv.keys << " size: " << size;
- }
- auto vals = new ps::SArray<real_t>(data, size, false);
- // copy indices to recv_buf. this needs to be done before ZPull
- // because after pull is done, the callback function returns and locks are released.
- // at this point, later functions may access the indices variable while copy happens
- mshadow::Copy(recv_buf.aux_data(kIdx).FlatTo1D<cpu, int64_t>(),
- indices.data().FlatTo1D<cpu, int64_t>());
- CHECK_NOTNULL(ps_worker_)->ZPull(pskv.keys, vals, &pskv.lens,
- static_cast<int>(DataHandleType::kRowSparsePushPull),
- [vals, cb]() { delete vals; cb(); });
- };
- CHECK_NOTNULL(Engine::Get())->PushAsync(
- pull_from_servers,
- pinned_ctx_,
- {indices.var()},
- {recv_buf.var()},
- FnProperty::kNormal,
- priority,
- PROFILER_MESSAGE("KVStoreDistRowSparsePull"));
- }
-
/**
* \brief check if the keys are all unique
*/
@@ -517,12 +407,32 @@ class KVStoreDist : public KVStoreLocal {
}
/**
+ * \brief struct for ps keys and lens
+ */
+ struct PSKV {
+ ps::SArray<ps::Key> keys; // n keys
+ ps::SArray<int> lens; // the length of the i-th value
+ int size;
+ };
+
+ /**
+ * \brief cache all key partitions
+ */
+ std::unordered_map<int, PSKV> ps_kv_;
+
+ /**
+ * \brief serizelize EncodeRowSparseKey and EncodeKey
+ */
+ std::mutex mu_;
+
+ /**
* \brief convert to keys in ps
*/
- inline PSKV& EncodeDefaultKey(int key, size_t size, bool is_push) {
+ inline PSKV& EncodeKey(int key, size_t size) {
mu_.lock();
PSKV& pskv = ps_kv_[key];
mu_.unlock();
+
if (!pskv.keys.empty()) {
CHECK_EQ(static_cast<size_t>(pskv.size), size) << "The value size cannot be changed";
} else {
@@ -544,8 +454,8 @@ class KVStoreDist : public KVStoreLocal {
pskv.size = 0;
for (int i = 0; i < num_servers; ++i) {
size_t part_size =
- static_cast<size_t>(round(static_cast<double>(size)/num_servers*(i+1))) -
- static_cast<size_t>(round(static_cast<double>(size)/num_servers*i));
+ static_cast<size_t>(round(static_cast<double>(size)/num_servers*(i+1))) -
+ static_cast<size_t>(round(static_cast<double>(size)/num_servers*i));
ps::Key ps_key = krs[i].begin() + key;
CHECK_LT(ps_key, krs[i].end());
pskv.keys.push_back(ps_key);
@@ -558,94 +468,6 @@ class KVStoreDist : public KVStoreLocal {
return pskv;
}
- /**
- * \brief Convert to keys in ps for compressed values
- * Divides original array into equal parts for each server
- * Populates both push and pull pskv on first call
- */
- inline PSKV& EncodeCompressedKey(int key, size_t original_size, bool is_push) {
- auto krs = ps::Postoffice::Get()->GetServerKeyRanges();
- int num_servers = krs.size();
- CHECK_GT(num_servers, 0);
-
- // represents size of data to be sent
- size_t compr_size = gradient_compression_->GetCompressedSize(original_size);
-
- mu_.lock();
- PSKV& pskv = (is_push) ? compr_ps_kv_[key].push : compr_ps_kv_[key].pull;
- mu_.unlock();
-
- if (!pskv.keys.empty()) {
- size_t size = (is_push) ? compr_size : original_size;
- CHECK_EQ(static_cast<size_t >(pskv.size), size)<< "The value size can't be changed";
- } else {
- // populate both pull and push pskvs
- // push pskv has sizes corresponding to compressed data
- // pull pskv has decompressed sizes for parts in push_pskv
- mu_.lock();
- PSKV& pull_pskv = compr_ps_kv_[key].pull;
- PSKV& push_pskv = compr_ps_kv_[key].push;
- mu_.unlock();
-
- if (original_size < bigarray_bound_) {
- // a simple heuristic for load balancing
- // send it to a single random picked server
- int server = (key * 9973) % num_servers;
- ps::Key ps_key = krs[server].begin() + key;
- CHECK_LT(ps_key, krs[server].end());
- // meta info
- push_pskv.keys.push_back(krs[server].begin() + original_size);
- push_pskv.lens.push_back(0);
- // data
- push_pskv.keys.push_back(ps_key);
- pull_pskv.keys.push_back(ps_key);
- push_pskv.lens.push_back(compr_size);
- pull_pskv.lens.push_back(original_size);
- push_pskv.size = compr_size;
- pull_pskv.size = original_size;
- } else {
- // partition it to all servers
- push_pskv.size = 0;
- pull_pskv.size = 0;
-
- for (int i = 0; i < num_servers; ++i) {
- size_t part_compr, part_orig;
- if (i == num_servers-1) {
- part_compr = compr_size - push_pskv.size;
- part_orig = original_size - pull_pskv.size;
- } else {
- part_compr =
- static_cast<size_t> (round(static_cast<double>(compr_size)/num_servers*(i+1))) -
- static_cast<size_t> (round(static_cast<double>(compr_size)/num_servers*(i)));
- part_orig = part_compr * gradient_compression_->GetCompressionFactor();
- }
-
- // meta info
- ps::Key ps_key_dummy = krs[i].begin() + part_orig;
- CHECK_LT(ps_key_dummy, krs[i].end());
- push_pskv.keys.push_back(ps_key_dummy);
- push_pskv.lens.push_back(0);
-
- // data
- ps::Key ps_key = krs[i].begin() + key;
- CHECK_LT(ps_key, krs[i].end());
- push_pskv.keys.push_back(ps_key);
- pull_pskv.keys.push_back(ps_key);
- // push_pskv stores lengths of compressed blocks
- push_pskv.lens.push_back(part_compr);
- // pull_pskv stores lengths of original data
- pull_pskv.lens.push_back(part_orig);
- push_pskv.size += part_compr;
- pull_pskv.size += part_orig;
- }
- CHECK_EQ(static_cast<size_t>(push_pskv.size), compr_size);
- CHECK_EQ(static_cast<size_t>(pull_pskv.size), original_size);
- CHECK_EQ(push_pskv.lens.size(), num_servers*2);
- }
- }
- return pskv;
- }
-
// Note: this encoding method for row sparse keys doesn't allow cross-layer batching
inline PSKV& EncodeRowSparseKey(const int key, const int64_t size, const int64_t num_rows,
const int64_t *offsets, const size_t unit_len,
@@ -706,6 +528,7 @@ class KVStoreDist : public KVStoreLocal {
return pskv;
}
+
/**
* \brief for worker to push and pull data
*/
@@ -718,23 +541,8 @@ class KVStoreDist : public KVStoreLocal {
* \brief threshold for partition
*/
size_t bigarray_bound_;
- /**
- * \brief buffer for non-compressed data.
- * When gradient compression is active, this is used
- * for the data in pull and for original data in push
- */
+ /// \brief send & receive buffer
std::unordered_map<int, NDArray> comm_buf_;
- /**
- * \brief buffer for compressed data
- * Used when gradient compression is active and action
- * is push
- */
- std::unordered_map<int, NDArray> compr_buf_;
- /**
- * \brief residual buffer to accumulate quantization error
- * during gradient compression
- */
- std::unordered_map<int, NDArray> residual_;
bool log_verbose_;
};
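
As context for the key-encoding hunks above: both the removed EncodeDefaultKey and the restored EncodeKey slice a value of `size` elements across the server key ranges with the same rounding expression. A minimal Python sketch of that partitioning (illustrative only; partition_lengths is a made-up helper name):

def partition_lengths(size, num_servers):
    # part i gets round(size/num_servers*(i+1)) - round(size/num_servers*i)
    # elements, so the parts are nearly equal and always sum to `size`
    return [int(round(size / num_servers * (i + 1))) -
            int(round(size / num_servers * i))
            for i in range(num_servers)]

# e.g. a (1200, 1200) tensor spread over 3 servers
assert sum(partition_lengths(1200 * 1200, 3)) == 1200 * 1200
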
diff --git a/src/kvstore/kvstore_dist_server.h b/src/kvstore/kvstore_dist_server.h
index de94c86..f2123e7 100644
--- a/src/kvstore/kvstore_dist_server.h
+++ b/src/kvstore/kvstore_dist_server.h
@@ -40,13 +40,10 @@
namespace mxnet {
namespace kvstore {
-enum class CommandType {
- kController, kStopServer, kSyncMode, kSetGradientCompression
-};
-
-enum class DataHandleType {
- kDefaultPushPull, kCompressedPushPull, kRowSparsePushPull
-};
+static const int kRowSparsePushPull = 1;
+static const int kDefaultPushPull = 0;
+static const int kStopServer = -1;
+static const int kSyncMode = -2;
/**
* \brief executor runs a function using the thread called \ref Start
@@ -120,7 +117,6 @@ class KVStoreDistServer {
ps_server_->set_request_handle(
std::bind(&KVStoreDistServer::DataHandleEx, this, _1, _2, _3));
sync_mode_ = false;
- gradient_compression_ = std::make_shared<GradientCompression>();
log_verbose_ = dmlc::GetEnv("MXNET_KVSTORE_DIST_ROW_SPARSE_VERBOSE", false);
}
@@ -152,15 +148,11 @@ class KVStoreDistServer {
};
void CommandHandle(const ps::SimpleData& recved, ps::SimpleApp* app) {
- CommandType recved_type = static_cast<CommandType>(recved.head);
- if (recved_type == CommandType::kStopServer) {
+ if (recved.head == kStopServer) {
exec_.Stop();
- } else if (recved_type == CommandType::kSyncMode) {
+ } else if (recved.head == kSyncMode) {
sync_mode_ = true;
- } else if (recved_type == CommandType::kSetGradientCompression) {
- gradient_compression_->DecodeParams(recved.body);
} else {
- // this uses value 0 for message id from frontend
// let the main thread to execute ctrl, which is necessary for python
exec_.Exec([this, recved]() {
CHECK(controller_);
@@ -173,11 +165,8 @@ class KVStoreDistServer {
void DataHandleEx(const ps::KVMeta& req_meta,
const ps::KVPairs<real_t>& req_data,
ps::KVServer<real_t>* server) {
- DataHandleType recved_type = static_cast<DataHandleType>(req_meta.cmd);
- if (recved_type == DataHandleType::kRowSparsePushPull) {
+ if (req_meta.cmd == kRowSparsePushPull) {
DataHandleRowSparse(req_meta, req_data, server);
- } else if (recved_type == DataHandleType::kCompressedPushPull) {
- DataHandleCompressed(req_meta, req_data, server);
} else {
DataHandleDefault(req_meta, req_data, server);
}
@@ -370,91 +359,10 @@ class KVStoreDistServer {
}
}
- void DefaultStorageResponse(int key, const NDArray& stored,
- const ps::KVMeta& req_meta,
- const ps::KVPairs<real_t> &req_data,
- ps::KVServer<real_t>* server) {
- ps::KVPairs<real_t> response;
- CHECK(!stored.is_none()) << "init " << key << " first";
- auto len = stored.shape().Size();
- response.keys = req_data.keys;
- response.lens = {len};
- // TODO(mli) try to remove this CopyFrom
- response.vals.CopyFrom(static_cast<const float*>(stored.data().dptr_), len);
- server->Response(req_meta, response);
- }
-
- void DataHandleCompressed(const ps::KVMeta& req_meta,
- const ps::KVPairs<real_t> &req_data,
- ps::KVServer<real_t>* server) {
- if (req_meta.push) {
- // there used several WaitToRead, this is because \a recved's memory
- // could be deallocated when this function returns. so we need to make sure
- // the operators with \a NDArray are actually finished
-
- // first for dummy key which represents original size of array, whose len is 0
- CHECK_EQ(req_data.keys.size(), (size_t)2);
- CHECK_EQ(req_data.lens.size(), (size_t)2);
- CHECK_EQ(req_data.vals.size(), (size_t)req_data.lens[1]);
-
- int original_size = DecodeKey(req_data.keys[0]);
- int key = DecodeKey(req_data.keys[1]);
- auto& stored = store_[key];
-
- size_t ds[] = {(size_t)req_data.lens[1]};
- TShape dshape(ds, ds + 1);
- TBlob recv_blob((real_t*) req_data.vals.data(), // NOLINT(*)
- dshape, cpu::kDevMask);
- NDArray recved = NDArray(recv_blob, 0);
-
- NDArray decomp_buf = decomp_buf_[key];
- dshape = TShape{(int64_t) original_size};
-
- if (decomp_buf.is_none()) {
- decomp_buf = NDArray(dshape, Context());
- }
-
- if (stored.is_none()) {
- stored = NDArray(dshape, Context());
- gradient_compression_->Dequantize(recved, &stored, 0);
- server->Response(req_meta);
- stored.WaitToRead();
- } else if (sync_mode_) {
- // synced push
- auto& merged = merge_buf_[key];
- if (merged.array.is_none()) {
- merged.array = NDArray(dshape, Context());
- }
- if (merged.request.size() == 0) {
- gradient_compression_->Dequantize(recved, &merged.array, 0);
- } else {
- gradient_compression_->Dequantize(recved, &decomp_buf, 0);
- merged.array += decomp_buf;
- }
- merged.request.push_back(req_meta);
- ApplyUpdates(key, &merged, &stored, server);
- } else {
- // async push
- gradient_compression_->Dequantize(recved, &decomp_buf, 0);
- exec_.Exec([this, key, &decomp_buf, &stored]() {
- CHECK(updater_);
- updater_(key, decomp_buf, &stored);
- });
- server->Response(req_meta);
- stored.WaitToRead();
- }
- } else { // pull
- CHECK_EQ(req_data.keys.size(), (size_t)1);
- CHECK_EQ(req_data.lens.size(), (size_t)0);
- int key = DecodeKey(req_data.keys[0]);
- DefaultStorageResponse(key, store_[key], req_meta, req_data, server);
- }
- }
-
void DataHandleDefault(const ps::KVMeta& req_meta,
const ps::KVPairs<real_t> &req_data,
ps::KVServer<real_t>* server) {
- CHECK_EQ(req_meta.cmd, static_cast<int>(DataHandleType::kDefaultPushPull));
+ CHECK_EQ(req_meta.cmd, kDefaultPushPull);
// do some check
CHECK_EQ(req_data.keys.size(), (size_t)1);
if (req_meta.push) {
@@ -503,7 +411,15 @@ class KVStoreDistServer {
stored.WaitToRead();
}
} else {
- DefaultStorageResponse(key, stored, req_meta, req_data, server);
+ // pull
+ ps::KVPairs<real_t> response;
+ CHECK(!stored.is_none()) << "init " << key << " first";
+ auto len = stored.shape().Size();
+ response.keys = req_data.keys;
+ response.lens = {len};
+ // TODO(mli) try to remove this CopyFrom
+ response.vals.CopyFrom(static_cast<const float*>(stored.data().dptr_), len);
+ server->Response(req_meta, response);
}
}
@@ -512,44 +428,21 @@ class KVStoreDistServer {
return key - kr.begin();
}
-
/**
- * \brief user defined mode for push
+ * \brief user defined
*/
bool sync_mode_;
KVStore::Controller controller_;
KVStore::Updater updater_;
- /**
- * \brief store_ contains the value at kvstore for each key
- */
std::unordered_map<int, NDArray> store_;
-
- /**
- * \brief merge_buf_ is a buffer used if sync_mode is true. It represents
- * values from different workers being merged. The store will be updated
- * to this value when values from all workers are pushed into this buffer.
- */
std::unordered_map<int, MergeBuf> merge_buf_;
- /**
- * \brief decomp_buf_ is a buffer into which compressed values are
- * decompressed before merging to the store. used when compress_!='none'
- */
- std::unordered_map<int, NDArray> decomp_buf_;
-
Executor exec_;
ps::KVServer<float>* ps_server_;
// whether to LOG verbose information
bool log_verbose_;
-
- /**
- * \brief gradient compression object.
- * starts with none, used after SetGradientCompression sets the type
- * currently there is no support for unsetting gradient compression
- */
- std::shared_ptr<kvstore::GradientCompression> gradient_compression_;
};
} // namespace kvstore
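
The DataHandleCompressed handler removed above followed the same sync-mode pattern the server keeps for dense pushes: the first worker's dequantized gradient seeds the merge buffer, later pushes are dequantized into a scratch buffer and accumulated, and the store is updated only after every worker has pushed. A rough Python sketch of that flow (the function name and dict-based buffers are illustrative, not the server's actual API):

def handle_sync_push(recved, merged, stored, dequantize, update, num_workers):
    if merged['pending'] == 0:
        dequantize(recved, merged['array'])      # first push seeds the merge buffer
    else:
        dequantize(recved, merged['scratch'])    # later pushes are decompressed separately
        merged['array'] += merged['scratch']     # and accumulated
    merged['pending'] += 1
    if merged['pending'] == num_workers:         # every worker has pushed:
        update(merged['array'], stored)          # apply the merged gradient to the store
        merged['pending'] = 0
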
diff --git a/src/kvstore/kvstore_local.h b/src/kvstore/kvstore_local.h
index 9fe161c..1a4ced8 100644
--- a/src/kvstore/kvstore_local.h
+++ b/src/kvstore/kvstore_local.h
@@ -59,7 +59,6 @@ class KVStoreLocal : public KVStore {
comm_ = new CommCPU();
}
pinned_ctx_ = comm_->pinned_ctx();
- gradient_compression_ = std::make_shared<GradientCompression>();
}
virtual ~KVStoreLocal() {
@@ -137,11 +136,6 @@ class KVStoreLocal : public KVStore {
PullRowSparseImpl(keys, val_rowids, priority);
}
- void SetGradientCompression(const std::vector<std::pair<std::string, std::string> >
- & kwargs) override {
- gradient_compression_->SetParams(kwargs);
- }
-
private:
virtual void InitImpl(const std::vector<int>& keys,
const std::vector<NDArray>& values) {
@@ -151,7 +145,6 @@ class KVStoreLocal : public KVStore {
local_[keys[i]] = values[i].Copy(pinned_ctx_);
comm_->Init(keys[i], values[i].storage_type(), values[i].shape(), values[i].dtype());
}
- comm_->SetGradientCompression(gradient_compression_);
}
virtual void PushImpl(const std::vector<int>& keys,
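
On the Python side, the entry point removed from KVStoreLocal above is set_gradient_compression, the call exercised by the nightly tests that follow. A minimal usage sketch (assumes the pre-revert API and a running dist_sync launcher):

import mxnet as mx

kv = mx.kv.create('dist_sync')
kv.set_gradient_compression({'type': '2bit', 'threshold': 0.5})
kv.init('1121', mx.nd.zeros((2, 3)))
kv.push('1121', mx.nd.ones((2, 3)) * 0.4)   # below the 0.5 threshold, held back as residual
out = mx.nd.zeros((2, 3))
kv.pull('1121', out)                         # pulls zeros until the residual crosses the threshold
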
diff --git a/tests/nightly/dist_sync_kvstore.py b/tests/nightly/dist_sync_kvstore.py
index df85fe5..900d6bb 100644
--- a/tests/nightly/dist_sync_kvstore.py
+++ b/tests/nightly/dist_sync_kvstore.py
@@ -23,8 +23,7 @@ sys.path.insert(0, "../../python/")
import mxnet as mx
import numpy as np
import numpy.random as rnd
-from mxnet.test_utils import assert_almost_equal
-from test_kvstore import compute_expected_2bit_quantization
+import time
def check_diff_to_scalar(A, x, rank=None):
""" assert A == x"""
@@ -40,7 +39,6 @@ init_test_keys_device_big = [str(i) for i in range(500,600)]
rate = 2
shape = (2, 3)
-irregular_shape = (1211,1211)
big_shape = (1200, 1200) # bigger than MXNET_KVSTORE_BIGARRAY_BOUND
kv = mx.kv.create('dist_sync')
@@ -59,17 +57,6 @@ def init_kv():
kv.set_optimizer(mx.optimizer.create('test', rescale_grad=rate))
return kv, my_rank, nworker
-def init_kv_compressed(kv):
- threshold = 0.5
- kv.set_gradient_compression({'type': '2bit', 'threshold':threshold})
- # init kv compression keys
- kv.init('11221', mx.nd.zeros(big_shape))
- kv.init('112221', mx.nd.zeros(irregular_shape))
- kv.init('1121', mx.nd.zeros(shape))
- # to test inactive mode
- kv.init('1122', mx.nd.ones(shape))
- return kv, threshold
-
def test_sync_push_pull():
kv, my_rank, nworker = init_kv()
def check_default_keys(kv, my_rank, nworker):
@@ -186,114 +173,11 @@ def test_sync_push_pull():
expected[row] = updated_val[row]
check_diff_to_scalar(val, expected, rank=my_rank)
- def check_compr_residual(kv, threshold, nworker):
- for k,s in [('1121', shape),('112221',irregular_shape),('11221', big_shape)]:
- # doesn't meet threshold
- kv.push(k, mx.nd.ones(s)*0.4)
- val=mx.nd.zeros(s)
- kv.pull(k,val)
- check_diff_to_scalar(val, 0)
-
- # just meets threshold with residual
- kv.push(k, mx.nd.ones(s)*(threshold - 0.4))
- val2 = mx.nd.zeros(s)
- kv.pull(k,val2)
- curval = threshold * rate * nworker
- check_diff_to_scalar(val2, curval)
-
- # doesn't meet threshold
- kv.push(k, mx.nd.ones(s)*0.2)
- val3= mx.nd.zeros(s)
- kv.pull(k, val3)
- check_diff_to_scalar(val3, curval)
-
- # exceeds again
- kv.push(k, mx.nd.ones(s)*(threshold-0.2))
- val4 = mx.nd.zeros(s)
- kv.pull(k,val4)
- curval += threshold*rate*nworker
- check_diff_to_scalar(val4, curval)
- # residual is 0 now
-
- def check_compr_ones(kv, threshold, nworker):
- for k,s in [('1121', shape),('112221',irregular_shape),('11221', big_shape)]:
- val = mx.nd.zeros(s)
- kv.pull(k, val)
- curval = val[0][0].asnumpy()[0]
- kv.push(k,mx.nd.ones(s)*threshold)
- val2 = mx.nd.zeros(s)
- kv.pull(k, val2)
- newval = curval + rate*nworker*threshold
- check_diff_to_scalar(val2, newval)
- # residual = 0 again
-
- def check_compr_pull_before_push(kv):
- for k,s in [('1121', shape),('112221',irregular_shape),
- ('11221', big_shape), ('1122',shape)]:
- if k=='1122':
- # tests that GC is not used for init of a key
- val = mx.nd.zeros(s)
- kv.pull(k, val)
- check_diff_to_scalar(val, 1)
- else:
- val = mx.nd.ones(s)
- kv.pull(k, val)
- check_diff_to_scalar(val, 0)
-
- def check_compr_zero(kv):
- for k,s in [('1121', shape),('112221',irregular_shape),('11221', big_shape)]:
- kv.push(k, mx.nd.zeros(s))
- # to check that all are set to 0s
- val = mx.nd.ones(s)
- kv.pull(k, val)
- check_diff_to_scalar(val, 0)
-
- def check_compr_random(kv, threshold, nworker):
- # set a seed so all workers generate same data. knowing this helps
- # calculate expected value after pull
- mx.random.seed(123)
- rnd.seed(123)
- nrepeat = 5
- compr_random_keys_shapes = [('2121', shape),('212221',irregular_shape),('21221', big_shape)]
- # use new keys so residual is 0 for calculation of expected
- for k,s in compr_random_keys_shapes:
- kv.init(k, mx.nd.zeros(s))
- for k,s in compr_random_keys_shapes:
- curr_residual = np.zeros(s)
- for l in range(nrepeat):
- orig_val = mx.nd.zeros(s)
- kv.pull(k, orig_val)
-
- grad = mx.nd.array(rnd.rand(s[0], s[1]))
- # creates a copy because push changes grad because of assignment
- grad_cpy = mx.nd.array(grad)
- kv.push(k, grad)
- val = mx.nd.zeros(s)
- kv.pull(k, val)
-
- diff = val - orig_val
-
- # compute expected by using simulation of operator
- compr, curr_residual, decompr = compute_expected_2bit_quantization(grad_cpy, curr_residual, threshold)
- decompr *= nworker * rate
- assert_almost_equal(diff.asnumpy(), decompr)
-
- print ('worker '+str(my_rank)+' started with non compression tests')
check_default_keys(kv, my_rank, nworker)
check_row_sparse_keys(kv, my_rank, nworker)
check_row_sparse_keys_with_zeros(kv, my_rank, nworker)
check_big_row_sparse_keys(kv, my_rank, nworker)
- print('worker ' + str(my_rank) + ' is done with non compression tests')
-
- # don't run non compressed keys after this as kvstore now is set to compressed
- print ('worker '+str(my_rank)+' started with compression tests')
- kv, threshold = init_kv_compressed(kv)
- check_compr_pull_before_push(kv)
- check_compr_zero(kv)
- check_compr_residual(kv, threshold, nworker)
- check_compr_ones(kv, threshold, nworker)
- check_compr_random(kv, threshold, nworker)
- print('worker ' + str(my_rank) + ' is done with compression tests')
+ print('worker ' + str(my_rank) + ' is done')
def test_sync_init():
def check_init(kv, cur_keys, cur_shape, device=False):
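
The expected values in check_compr_residual above follow from simple residual bookkeeping: pushes below the threshold are held back locally, and once the running residual reaches the threshold each worker contributes exactly the threshold, which the test optimizer scales by rate. A worked example (nworker = 2 is assumed only for concreteness):

threshold, rate, nworker = 0.5, 2, 2
residual = 0.0
residual += 0.4                    # push of 0.4: below threshold, nothing reaches the servers
assert residual < threshold        # pull still returns 0
residual += threshold - 0.4        # push of 0.1 brings the residual up to the threshold
sent, residual = threshold, residual - threshold   # quantized as +threshold, residual back to 0
expected = sent * rate * nworker   # 0.5 * 2 * 2 = 2.0, the value check_diff_to_scalar expects
assert expected == threshold * rate * nworker
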
diff --git a/tests/nightly/test_kvstore.py b/tests/nightly/test_kvstore.py
index a14feac..081bc9c 100644
--- a/tests/nightly/test_kvstore.py
+++ b/tests/nightly/test_kvstore.py
@@ -21,59 +21,17 @@ import sys
sys.path.insert(0, "../../python/")
import mxnet as mx
import numpy as np
-import numpy.random as rnd
-import copy
-from mxnet.test_utils import assert_almost_equal
+keys = [3, 5, 7]
+# let the last shape exceed MXNET_KVSTORE_BIGARRAY_BOUND
+shapes = [(4, 4), (100, 100), (2000, 2000)]
-def check_diff_to_scalar(A, x, rank=None):
- """ assert A == x"""
- assert(np.sum(np.abs((A - x).asnumpy())) == 0), (rank, A.asnumpy(), x)
+lr = .1
+nworker = 4
+nrepeat = 10
-def compute_expected_2bit_quantization(arr, curr_residual, threshold):
- from struct import pack,unpack
- def bits2int(bits):
- bits = [int(x) for x in bits[::-1]]
- x = 0
- for i in range(len(bits)):
- x += bits[i]*2**i
- return x
-
- def as_float32(s):
- return unpack("f",pack("I", bits2int(s)))[0]
-
- # str_quant stores the quantized representation as a sequence of bits
- str_quant = ''
- new_residual = []
- decompr = []
-
- arr_npy = arr.asnumpy()
- for i, a in np.ndenumerate(arr_npy):
- a += curr_residual[i]
- if a >= threshold:
- str_quant += '11'
- new_residual.append(a - threshold)
- decompr.append(threshold)
- elif a <= (-1*threshold):
- str_quant += '10'
- new_residual.append(a + threshold)
- decompr.append(-1*threshold)
- else:
- str_quant += '00'
- new_residual.append(a)
- decompr.append(0)
- # append extra bits when size of array not a factor of 16
- if len(str_quant)%16 != 0:
- str_quant += '0'*(16 - len(str_quant)%16)
-
- compr = []
- # converts the string generated into integers 32chars at a time
- i = 0
- while i<len(str_quant):
- cur_float = str_quant[i+24:i+32] + str_quant[i+16:i+24] + str_quant[i+8:i+16] + str_quant[i:i+8]
- compr.append(as_float32(cur_float))
- i+=32
- return np.array(compr), np.array(new_residual).reshape(arr.shape), np.array(decompr).reshape(arr.shape)
+## generate data
+data = [[[np.random.random(s)*2-1 for i in range(nworker)] for s in shapes] for j in range(nrepeat)]
## individual key interface
def test_kvstore(kv_type):
@@ -97,118 +55,9 @@ def test_kvstore(kv_type):
err = sum(err) / np.sum(np.abs(res[j]))
assert(err < 1e-6), (err, shapes[j])
-def test_compress_kvstore(kv_type, compression='2bit', threshold=0.5):
- print(kv_type + ' with ' + compression + ' compression')
- rate = 2
- kv = mx.kv.create(kv_type)
- kv.set_gradient_compression({'type':compression, 'threshold':threshold})
- kv.set_optimizer(mx.optimizer.create('test', rescale_grad=rate))
- for k, s in zip(keys, shapes):
- kv.init(k, mx.nd.zeros(s))
- # init one key with 1s so we can check if it was compressed during init
- kv.init(gc_init_test_key, mx.nd.ones(shapes[0]))
- # use different keys for random tests so that
- # we can track residual from start
- random_keys = [13, 15, 17]
- for k, s in zip(random_keys, shapes):
- kv.init(k, mx.nd.zeros(s))
-
- def pull_init_test(kv):
- # checks that compression is not applied to init of key
- out = [mx.nd.zeros(shapes[0], mx.gpu(g)) for g in range(nworker)]
- kv.pull(gc_init_test_key, out=out)
- exp = np.ones_like(out[0].asnumpy())
- for o in out:
- assert_almost_equal(o.asnumpy(), exp)
-
- def pull_before_push(kv):
- for i in range(nrepeat):
- for j in range(len(keys)):
- out = [mx.nd.ones(shapes[j], mx.gpu(g)) for g in range(nworker)]
- kv.pull(keys[j], out=out)
- exp = np.zeros_like(out[0].asnumpy())
- for o in out:
- assert_almost_equal(o.asnumpy(), exp)
-
- def push_zeros(kv):
- for i in range(nrepeat):
- for j in range(len(keys)):
- kv.push(keys[j], [mx.nd.zeros(shapes[j], mx.gpu(g)) for g in range(nworker)])
- out = [mx.nd.ones(shapes[j], mx.gpu(g)) for g in range(nworker)]
- kv.pull(keys[j], out=out)
- exp = np.zeros_like(out[0].asnumpy())
- for o in out:
- assert_almost_equal(o.asnumpy(), exp)
-
- def verify_residual(kv, threshold, rate):
- for j in range(len(keys)):
- kv.push(keys[j], [mx.nd.ones(shapes[j], mx.gpu(g))*0.4 for g in range(nworker)])
- out = [mx.nd.zeros(shapes[j], mx.gpu(g)) for g in range(nworker)]
- kv.pull(keys[j],out=out)
- for o in out:
- check_diff_to_scalar(o, 0)
-
- kv.push(keys[j], [mx.nd.ones(shapes[j], mx.gpu(g))*(threshold-0.3) for g in range(nworker)])
- out = [mx.nd.zeros(shapes[j], mx.gpu(g)) for g in range(nworker)]
- kv.pull(keys[j],out=out)
- curval = threshold * rate * nworker
- for o in out:
- check_diff_to_scalar(o, curval)
-
- kv.push(keys[j], [mx.nd.ones(shapes[j], mx.gpu(g))*(0.2) for g in range(nworker)])
- out = [mx.nd.zeros(shapes[j], mx.gpu(g)) for g in range(nworker)]
- kv.pull(keys[j],out=out)
- for o in out:
- check_diff_to_scalar(o, curval)
-
- kv.push(keys[j], [mx.nd.ones(shapes[j], mx.gpu(g))*(threshold-0.3) for g in range(nworker)])
- out = [mx.nd.zeros(shapes[j], mx.gpu(g)) for g in range(nworker)]
- kv.pull(keys[j],out=out)
- curval += threshold*rate*nworker
- for o in out:
- check_diff_to_scalar(o, curval)
- # residual would be 0 now
- return curval
-
- def check_neg(kv, neg, rate, curval):
- for r in range(nrepeat):
- curval = curval + rate*nworker*neg
- for j in range(len(keys)):
- kv.push(keys[j], [mx.nd.ones(shapes[j], mx.gpu(g))*neg for g in range(nworker)])
- out = [mx.nd.ones(shapes[j], mx.gpu(g)) for g in range(nworker)]
- kv.pull(keys[j], out=out)
- for o in out:
- check_diff_to_scalar(o, curval)
- # residual would be 0 again
-
- def check_compr_random(kv, threshold):
- for k, s in zip(random_keys, shapes):
- curr_residual = [np.zeros(s) for g in range(nworker)]
- orig_val = [mx.nd.zeros(s, mx.gpu(g)) for g in range(nworker)]
- kv.pull(k, out=orig_val)
- grads = [mx.nd.random_uniform(-0.6, 0.6, shape=s, ctx=mx.gpu(g)) for g in range(nworker)]
- grads_cpy = copy.deepcopy(grads)
- kv.push(k, grads)
- val = [mx.nd.zeros(s, mx.gpu(g)) for g in range(nworker)]
- kv.pull(k, out=val)
- diffs = [val[g] - orig_val[g] for g in range(nworker)]
- # compute expected by using simulation of operator
- # on cpu
- sum_dequantized_vals = np.zeros(s)
- for g in range(nworker):
- compr, curr_residual[g], decompr = compute_expected_2bit_quantization(
- grads_cpy[g], curr_residual[g], threshold)
- sum_dequantized_vals += (decompr * rate)
-
- for g in range(nworker):
- assert_almost_equal(diffs[g].asnumpy(), sum_dequantized_vals)
-
- pull_init_test(kv)
- pull_before_push(kv)
- push_zeros(kv)
- curval = verify_residual(kv, threshold, rate)
- check_neg(kv, -1*threshold, rate, curval)
- check_compr_random(kv, threshold)
+test_kvstore('local_update_cpu')
+test_kvstore('local_allreduce_cpu')
+test_kvstore('local_allreduce_device')
## group keys interface
def test_group_kvstore(kv_type):
@@ -230,27 +79,6 @@ def test_group_kvstore(kv_type):
err = sum(err) / np.sum(np.abs(a))
assert(err < 1e-6), (err, a.shape)
-if __name__ == "__main__":
- keys = [3, 5, 7]
- # let the last shape exceed MXNET_KVSTORE_BIGARRAY_BOUND
- shapes = [(4, 4), (100, 100), (2000, 2000)]
-
- gc_init_test_key = 9
-
- lr = .1
- nworker = 4
- nrepeat = 10
-
- ## generate data
- data = [[[np.random.random(s)*2-1 for i in range(nworker)] for s in shapes] for j in range(nrepeat)]
-
- test_kvstore('local_update_cpu')
- test_kvstore('local_allreduce_cpu')
- test_kvstore('local_allreduce_device')
-
- # compression for local kvstore happens only when reduce is on device
- test_compress_kvstore('local_allreduce_device')
-
- test_group_kvstore('local_update_cpu')
- test_group_kvstore('local_allreduce_cpu')
- test_group_kvstore('local_allreduce_device')
+test_group_kvstore('local_update_cpu')
+test_group_kvstore('local_allreduce_cpu')
+test_group_kvstore('local_allreduce_device')
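
The reference simulation removed above reduces to a per-element mapping: after adding the carried residual, a value at or above the threshold encodes as '11' and dequantizes to +threshold, a value at or below -threshold encodes as '10' and dequantizes to -threshold, and everything else encodes as '00' and dequantizes to zero, with the leftover kept as the new residual. A standalone sketch of just that mapping (the packing of codes into float32 words is left out):

def two_bit_code(value, residual, threshold):
    v = value + residual
    if v >= threshold:
        return '11', v - threshold, threshold
    if v <= -threshold:
        return '10', v + threshold, -threshold
    return '00', v, 0.0

# e.g. two_bit_code(0.4, 0.2, 0.5) -> ('11', ~0.1, 0.5): the element is sent as
# +threshold and the overshoot of about 0.1 is carried forward as the next residual
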
diff --git a/tools/bandwidth/measure.py b/tools/bandwidth/measure.py
index cd4f0fe..66ef737 100644
--- a/tools/bandwidth/measure.py
+++ b/tools/bandwidth/measure.py
@@ -53,8 +53,6 @@ def parse_args():
help='number of classes')
parser.add_argument('--optimizer', type=str, default='None',
help='the optimizer set to kvstore. None means no optimizer')
- parser.add_argument('--gc-type', type=str, default='none',
- help='type of gradient compression')
args = parser.parse_args()
logging.info(args)
return args
@@ -74,12 +72,10 @@ def error(gpu_res, cpu_res):
return res
def run(network, optimizer, gpus, kv_store, image_shape, disp_batches,
- num_batches, test_results, gc_type, **kwargs):
+ num_batches, test_results, **kwargs):
# create kvstore and optimizer
devs = [mx.gpu(int(i)) for i in gpus.split(',')]
kv = mx.kv.create(kv_store)
- if gc_type != 'none':
- kv.set_gradient_compression({'type': gc_type})
if optimizer is None or optimizer == 'None':
opt = None
else: