Posted to commits@mxnet.apache.org by zh...@apache.org on 2019/02/24 02:56:56 UTC
[incubator-mxnet] branch master updated: Dual stream cudnn Convolution backward() with MXNET_GPU_WORKER_NSTREAMS=2. (#14006)
This is an automated email from the ASF dual-hosted git repository.
zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 5f32f32 Dual stream cudnn Convolution backward() with MXNET_GPU_WORKER_NSTREAMS=2. (#14006)
5f32f32 is described below
commit 5f32f32e8c7ce50f20eae1438bf970eba6830578
Author: Dick Carter <di...@comcast.net>
AuthorDate: Sat Feb 23 18:56:30 2019 -0800
Dual stream cudnn Convolution backward() with MXNET_GPU_WORKER_NSTREAMS=2. (#14006)
* Dual stream conv backward(). Enable with MXNET_GPU_WORKER_NSTREAMS=2.
* Fix for MSVC compiler.
* Fix cpplint.
* Add MXNET_GPU_WORKER_NSTREAMS env var documentation.
* Improve test function and commenting.
* Add description of proper aux stream use via events.
* RAII rework to simplify usage within operators.
* Fix cpplint.
* Expand testing to cover all engines.
* Fix NaiveEngine shutdown segfault on CentOS7.
---
docs/faq/env_var.md | 6 ++
include/mxnet/base.h | 133 ++++++++++++++++++++++++++
include/mxnet/op_attr_types.h | 9 ++
src/engine/naive_engine.cc | 25 ++++-
src/engine/stream_manager.h | 25 +++--
src/engine/threaded_engine_perdevice.cc | 12 ++-
src/imperative/imperative_utils.h | 6 +-
src/ndarray/ndarray.cc | 4 +-
src/operator/nn/cudnn/cudnn_convolution-inl.h | 72 ++++++++++----
tests/python/gpu/test_operator_gpu.py | 50 ++++++++++
10 files changed, 307 insertions(+), 35 deletions(-)
diff --git a/docs/faq/env_var.md b/docs/faq/env_var.md
index c35d4e5..f49cb19 100644
--- a/docs/faq/env_var.md
+++ b/docs/faq/env_var.md
@@ -191,6 +191,12 @@ When USE_PROFILER is enabled in Makefile or CMake, the following environments ca
## Other Environment Variables
+* MXNET_GPU_WORKER_NSTREAMS
+ - Values: 1 or 2 ```(default=1)```
+ - Determines the number of GPU streams available to operators for their functions.
+ - Setting this to 2 may yield a modest performance increase, since ops like the cuDNN convolution op can then calculate their data- and weight-gradients in parallel.
+ - Setting this to 2 may also increase a model's demand for GPU global memory.
+
* MXNET_CUDNN_AUTOTUNE_DEFAULT
- Values: 0, 1, or 2 ```(default=1)```
- The default value of cudnn auto tuning for convolution layers.
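For readers unfamiliar with the mechanism behind the new variable: the overlap it enables is ordinary CUDA stream/event fencing. Below is a minimal standalone sketch, assuming plain CUDA runtime calls; the stub kernels merely stand in for cuDNN's wgrad/dgrad launches and are not the patch's actual code:

// sketch.cu: the ordering that MXNET_GPU_WORKER_NSTREAMS=2 enables.
#include <cuda_runtime.h>

__global__ void wgrad_stub(float* w) { w[threadIdx.x] += 1.0f; }  // stand-in for wgrad
__global__ void dgrad_stub(float* d) { d[threadIdx.x] += 1.0f; }  // stand-in for dgrad

int main() {
  float *w, *d;
  cudaMalloc(&w, 256 * sizeof(float));
  cudaMalloc(&d, 256 * sizeof(float));

  cudaStream_t primary, aux;
  cudaEvent_t sync_event;
  cudaStreamCreate(&primary);
  cudaStreamCreate(&aux);
  cudaEventCreateWithFlags(&sync_event, cudaEventDisableTiming);

  // Pre-use sync: the aux stream waits for all work already queued on primary.
  cudaEventRecord(sync_event, primary);
  cudaStreamWaitEvent(aux, sync_event, 0);

  wgrad_stub<<<1, 256, 0, primary>>>(w);  // weight gradient on the primary stream
  dgrad_stub<<<1, 256, 0, aux>>>(d);      // data gradient overlaps on the aux stream

  // Post-use sync: future primary-stream work waits for the aux stream to drain.
  cudaEventRecord(sync_event, aux);
  cudaStreamWaitEvent(primary, sync_event, 0);

  cudaStreamSynchronize(primary);
  cudaFree(w); cudaFree(d);
  cudaStreamDestroy(primary); cudaStreamDestroy(aux);
  cudaEventDestroy(sync_event);
  return 0;
}

The same single-event record/wait pairing appears below in GPUAuxStream::StreamSync.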
diff --git a/include/mxnet/base.h b/include/mxnet/base.h
index 26c1a1b..2ea6ebb 100644
--- a/include/mxnet/base.h
+++ b/include/mxnet/base.h
@@ -194,6 +194,11 @@ struct Context {
*/
inline static int32_t GetGPUCount();
/*!
+ * \brief Get the number of streams that a GPU worker has available to operations.
+ * \return The number of streams that are available.
+ */
+ inline static int32_t GetGPUStreamsPerWorker();
+ /*!
* \brief get the free and total available memory on a GPU
* \param dev the GPU number to query
* \param free_mem pointer to the uint64_t holding free GPU memory
@@ -221,6 +226,112 @@ struct Context {
inline static Context FromString(const std::string& str);
};
+#if MXNET_USE_CUDA
+/*! \brief Holds an auxiliary mshadow gpu stream that can be synced with a primary stream. */
+class GPUAuxStream {
+ public:
+ /*!
+ * \brief constructor.
+ * \param primary_stream gpu stream that is synced with the created auxiliary stream.
+ */
+ explicit GPUAuxStream(mshadow::Stream<gpu> *primary_stream) :
+ primary_stream_(primary_stream),
+ aux_stream_(primary_stream),
+ gpu_stream_sync_event_(nullptr) {
+ if (Context::GetGPUStreamsPerWorker() >= 2) {
+ // Create auxiliary stream on the same device with the same properties as the primary stream
+ bool primary_has_blas_handle =
+ primary_stream->blas_handle_ownership_ == mshadow::Stream<gpu>::OwnHandle;
+ bool primary_has_dnn_handle =
+ primary_stream->dnn_handle_ownership_ == mshadow::Stream<gpu>::OwnHandle;
+ aux_stream_ = mshadow::NewStream<gpu>(primary_has_blas_handle,
+ primary_has_dnn_handle,
+ primary_stream->dev_id);
+ MSHADOW_CUDA_CALL(cudaEventCreateWithFlags(&gpu_stream_sync_event_, cudaEventDisableTiming));
+ }
+ }
+ /*! \brief destructor */
+ ~GPUAuxStream() {
+ // If aux_stream_ == primary_stream_, no new stream was created, so there is nothing to destroy.
+ if (aux_stream_ != primary_stream_) {
+ MSHADOW_CATCH_ERROR(mshadow::DeleteStream<gpu>(aux_stream_));
+ MSHADOW_CATCH_ERROR(cudaEventDestroy(gpu_stream_sync_event_));
+ }
+ }
+ /*!
+ * \brief Makes future aux stream work wait on the completion of existing primary stream work.
+ */
+ void PreAuxStreamUseSync() {
+ // If aux_stream_ == primary_stream_, no synchronization is necessary.
+ if (aux_stream_ != primary_stream_)
+ StreamSync(primary_stream_, aux_stream_, gpu_stream_sync_event_);
+ }
+ /*!
+ * \brief Makes future primary stream work wait on the completion of existing aux stream work.
+ */
+ void PostAuxStreamUseSync() {
+ // If aux_stream_ == primary_stream_, no synchronization is necessary.
+ if (aux_stream_ != primary_stream_)
+ StreamSync(aux_stream_, primary_stream_, gpu_stream_sync_event_);
+ }
+ /*! \brief Getter for created auxiliary stream. */
+ mshadow::Stream<gpu> *GetStream() { return aux_stream_; }
+ /*!
+ * \brief Make future work enqueued to `s2` wait on completion of current work enqueued to `s1`.
+ * \param s1 stream with work that must be completed before future s2 work can begin.
+ * \param s2 stream whose future work is made to wait on the completion of existing s1 work.
+ * \param event used to pass s1 state to s2.
+ */
+ static void StreamSync(mshadow::Stream<gpu> *s1, mshadow::Stream<gpu> *s2, cudaEvent_t event) {
+ MSHADOW_CUDA_CALL(cudaEventRecord(event, s1->stream_));
+ MSHADOW_CUDA_CALL(cudaStreamWaitEvent(s2->stream_, event, 0));
+ }
+
+ private:
+ mshadow::Stream<gpu> *primary_stream_;
+ mshadow::Stream<gpu> *aux_stream_;
+ cudaEvent_t gpu_stream_sync_event_;
+};
+
+/*!
+ * \brief Provides automatic coordination of an auxiliary stream with a primary one.
+ * This object, upon construction, prepares an aux stream for use by syncing it with enqueued
+ * primary-stream work. Object destruction will sync again so future primary-stream work
+ * will wait on enqueued aux-stream work. If MXNET_GPU_WORKER_NSTREAMS == 1, this degenerates
+ * gracefully: the aux stream equals the primary stream, and the syncs become no-ops.
+ * See ./src/operator/nn/cudnn/cudnn_convolution-inl.h for a usage example.
+ */
+class SyncedGPUAuxStream {
+ public:
+ /*!
+ * \brief constructor.
+ * \param gpu_aux_stream auxiliary gpu stream that is managed by this RAII object.
+ */
+ explicit SyncedGPUAuxStream(GPUAuxStream *gpu_aux_stream) : gpu_aux_stream_(gpu_aux_stream) {
+ gpu_aux_stream_->PreAuxStreamUseSync();
+ }
+ /*! \brief destructor */
+ ~SyncedGPUAuxStream() {
+ gpu_aux_stream_->PostAuxStreamUseSync();
+ }
+ /*! \brief copy constructor deleted to prevent unexpected synchronizations. */
+ SyncedGPUAuxStream(const SyncedGPUAuxStream&) = delete;
+ /*! \brief copy assignment operator deleted to prevent unexpected synchronizations. */
+ void operator=(const SyncedGPUAuxStream&) = delete;
+ /*! \brief move constructor permitted as alternative to copying. */
+ SyncedGPUAuxStream(SyncedGPUAuxStream&&) = default;
+ /*! \brief move assignment operator permitted as alternative to copy assignment. */
+ SyncedGPUAuxStream& operator=(SyncedGPUAuxStream&&) = default;
+ /*! \brief Getter for underlying mshadow::Stream<gpu>. */
+ inline mshadow::Stream<gpu>* GetStream() const {
+ return gpu_aux_stream_->GetStream();
+ }
+
+ private:
+ GPUAuxStream *gpu_aux_stream_;
+};
+#endif // MXNET_USE_CUDA
+
/*!
* \brief execution time context.
* The information needed in runtime for actual execution.
@@ -233,6 +344,10 @@ struct RunContext {
*/
void *stream;
/*!
+ * \brief the auxiliary stream of the device, can be NULL or Stream<gpu>* in GPU mode
+ */
+ void *aux_stream;
+ /*!
* \brief indicator of whether this execution is run in bulk mode
*/
bool is_bulk;
@@ -245,6 +360,15 @@ struct RunContext {
inline mshadow::Stream<xpu>* get_stream() const {
return static_cast<mshadow::Stream<xpu>*>(stream);
}
+#if MXNET_USE_CUDA
+ /*!
+ * \brief get an RAII object that transparently handles the syncing of the auxiliary stream.
+ * \return the aux stream auto-syncing object
+ */
+ inline SyncedGPUAuxStream get_gpu_aux_stream() const {
+ return SyncedGPUAuxStream(static_cast<GPUAuxStream*>(aux_stream));
+ }
+#endif
/*! \brief get the base Context from RunContext */
inline const Context& get_ctx() const {
return ctx;
@@ -309,6 +433,15 @@ inline int32_t Context::GetGPUCount() {
#endif
}
+inline int32_t Context::GetGPUStreamsPerWorker() {
+ // The default number of streams available if the user has not set MXNET_GPU_WORKER_NSTREAMS.
+ const int32_t default_num_streams = 1;
+ // The get_gpu_aux_stream() interface can supply one additional stream beyond the standard one.
+ static int32_t num_streams =
+ dmlc::GetEnv("MXNET_GPU_WORKER_NSTREAMS", default_num_streams) >= 2 ? 2 : 1;
+ return num_streams;
+}
+
inline void Context::GetGPUMemoryInformation(int dev, uint64_t *free_mem,
uint64_t *total_mem) {
#if MXNET_USE_CUDA
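To make the RAII discipline of SyncedGPUAuxStream concrete outside of MXNet, here is a minimal self-contained sketch; the types are illustrative stand-ins rather than MXNet's actual classes, with console prints in place of the real event syncs:

// Pre-sync on construction, post-sync on destruction, no-ops with one stream.
#include <iostream>

struct Stream { const char* name; };

class AuxStream {                    // analog of GPUAuxStream
 public:
  AuxStream(Stream* primary, bool dual)
      : primary_(primary), aux_(dual ? new Stream{"aux"} : primary) {}
  ~AuxStream() { if (aux_ != primary_) delete aux_; }
  void PreUseSync()  { if (aux_ != primary_) std::cout << "aux waits on primary\n"; }
  void PostUseSync() { if (aux_ != primary_) std::cout << "primary waits on aux\n"; }
  Stream* GetStream() { return aux_; }
 private:
  Stream* primary_;
  Stream* aux_;
};

class SyncedAuxStream {              // analog of SyncedGPUAuxStream
 public:
  explicit SyncedAuxStream(AuxStream* s) : s_(s) { s_->PreUseSync(); }
  ~SyncedAuxStream() { s_->PostUseSync(); }
  SyncedAuxStream(const SyncedAuxStream&) = delete;  // no accidental double-sync
  Stream* GetStream() const { return s_->GetStream(); }
 private:
  AuxStream* s_;
};

int main() {
  Stream primary{"primary"};
  AuxStream aux(&primary, /*dual=*/true);   // dual=false models NSTREAMS=1
  {
    SyncedAuxStream synced(&aux);           // prints "aux waits on primary"
    // ... launch dgrad work on synced.GetStream(), wgrad work on &primary ...
  }                                         // prints "primary waits on aux"
  return 0;
}

With dual=false the aux stream aliases the primary stream and both syncs become no-ops, mirroring the MXNET_GPU_WORKER_NSTREAMS=1 path.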
diff --git a/include/mxnet/op_attr_types.h b/include/mxnet/op_attr_types.h
index aba59ce..22bba30 100644
--- a/include/mxnet/op_attr_types.h
+++ b/include/mxnet/op_attr_types.h
@@ -83,6 +83,15 @@ struct OpContext {
inline mshadow::Stream<xpu>* get_stream() const {
return run_ctx.get_stream<xpu>();
}
+#if MXNET_USE_CUDA
+ /*!
+ * \brief get the auxiliary gpu stream auto-syncing object from the RunContext
+ * \return the aux stream auto-syncing object
+ */
+ inline SyncedGPUAuxStream get_gpu_aux_stream() const {
+ return run_ctx.get_gpu_aux_stream();
+ }
+#endif
};
/*! \brief the execution type of the operator */
diff --git a/src/engine/naive_engine.cc b/src/engine/naive_engine.cc
index 05b72d2..db44919 100644
--- a/src/engine/naive_engine.cc
+++ b/src/engine/naive_engine.cc
@@ -62,6 +62,8 @@ class NaiveEngine final : public Engine {
};
NaiveEngine() {
+ objpool_opr_ref_ = common::ObjectPool<NaiveOpr>::_GetSharedRef();
+ objpool_var_ref_ = common::ObjectPool<NaiveVar>::_GetSharedRef();
}
// virtual destructor
virtual ~NaiveEngine() {
@@ -74,6 +76,12 @@ class NaiveEngine final : public Engine {
streams_[i] = nullptr;
}
}
+ for (size_t i = 0; i < aux_streams_.size(); ++i) {
+ if (aux_streams_[i] != nullptr) {
+ delete aux_streams_[i];
+ aux_streams_[i] = nullptr;
+ }
+ }
#endif
}
@@ -169,16 +177,18 @@ class NaiveEngine final : public Engine {
MSHADOW_CATCH_ERROR(mshadow::SetDevice<gpu>(exec_ctx.dev_id));
if (streams_.size() <= dev_id) {
streams_.resize(dev_id + 1, nullptr);
+ aux_streams_.resize(dev_id + 1, nullptr);
}
if (streams_[dev_id] == nullptr) {
streams_[dev_id] = mshadow::NewStream<gpu>(true, MXNET_USE_CUDNN != 0, dev_id);
+ aux_streams_[dev_id] = new GPUAuxStream(streams_[dev_id]);
}
- exec_fun(RunContext{exec_ctx, streams_[dev_id]}, callback);
+ exec_fun(RunContext{exec_ctx, streams_[dev_id], aux_streams_[dev_id], false}, callback);
#else
LOG(FATAL) << "GPU is not enabled";
#endif
} else {
- exec_fun(RunContext{exec_ctx, &cpu_stream_}, callback);
+ exec_fun(RunContext{exec_ctx, &cpu_stream_, nullptr, false}, callback);
}
CHECK(this->req_completed_)
<< "NaiveEngine only support synchronize Push so far";
@@ -220,6 +230,17 @@ class NaiveEngine final : public Engine {
mshadow::Stream<cpu> cpu_stream_;
// GPU streams
std::vector<mshadow::Stream<gpu>*> streams_;
+#if MXNET_USE_CUDA
+ // GPU auxiliary streams
+ std::vector<GPUAuxStream*> aux_streams_;
+#endif
+/*!
+ * \brief Hold shared_ptrs to the object pools to prevent them from being destructed too early.
+ * See also #309 (https://github.com/dmlc/mxnet/issues/309) and the similar fix in threaded_engine.h.
+ * Without this, segfaults were seen on CentOS7 in test_operator_gpu.py:test_convolution_multiple_streams.
+ */
+ std::shared_ptr<common::ObjectPool<NaiveOpr> > objpool_opr_ref_;
+ std::shared_ptr<common::ObjectPool<NaiveVar> > objpool_var_ref_;
}; // class NaiveEngine
Engine *CreateNaiveEngine() {
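The objpool_*_ref_ members above apply a general static-destruction-order pattern: the consumer pins the pool singleton with a shared_ptr so the pool cannot be torn down while the consumer's destructor may still touch it. A minimal sketch, with illustrative names rather than MXNet's actual ObjectPool API:

#include <memory>

template <typename T>
class Pool {
 public:
  // Shared handle to the per-type singleton, analogous to _GetSharedRef().
  static std::shared_ptr<Pool> GetSharedRef() {
    static std::shared_ptr<Pool> inst(new Pool());
    return inst;
  }
  // ... allocate/recycle methods elided ...
};

struct Opr {};

class Engine {
 public:
  Engine() : pool_ref_(Pool<Opr>::GetSharedRef()) {}
  ~Engine() {
    // Safe even during program shutdown: pool_ref_ guarantees the pool
    // outlives this engine regardless of static destruction order.
  }
 private:
  std::shared_ptr<Pool<Opr>> pool_ref_;
};

int main() {
  Engine engine;
  return 0;
}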
diff --git a/src/engine/stream_manager.h b/src/engine/stream_manager.h
index 8d44d9c..42d03e5 100644
--- a/src/engine/stream_manager.h
+++ b/src/engine/stream_manager.h
@@ -55,6 +55,8 @@ class StreamManager {
#if MXNET_USE_CUDA
std::array<std::array<mshadow::Stream<gpu>*, kStreams>, kNumGpus>
gpu_streams_;
+ std::array<std::array<GPUAuxStream*, kStreams>, kNumGpus>
+ gpu_aux_streams_;
std::array<mshadow::Stream<gpu>*, kNumGpus> gpu_io_streams_;
std::array<int, kNumGpus> gpu_cnt_;
#endif // MXNET_USE_CUDA
@@ -67,7 +69,7 @@ RunContext StreamManager<kNumGpus, kStreams>::GetRunContext(
RunContext ret;
switch (ctx.dev_mask()) {
case cpu::kDevMask:
- ret = RunContext{ctx, nullptr, false};
+ ret = RunContext{ctx, nullptr, nullptr, false};
break;
case gpu::kDevMask: {
#if MXNET_USE_CUDA
@@ -77,8 +79,13 @@ RunContext StreamManager<kNumGpus, kStreams>::GetRunContext(
auto&& counter = gpu_cnt_.at(ctx.dev_id);
if (counter == -1) {
mxnet::common::cuda::DeviceStore device_store(ctx.dev_id);
- for (auto&& i : gpu_streams_.at(ctx.dev_id)) {
- i = mshadow::NewStream<gpu>(true, MXNET_USE_CUDNN != 0, ctx.dev_id);
+ for (auto&& primary_stream : gpu_streams_.at(ctx.dev_id)) {
+ primary_stream = mshadow::NewStream<gpu>(true, MXNET_USE_CUDNN != 0, ctx.dev_id);
+ }
+ int idx = 0;
+ for (auto&& aux_stream : gpu_aux_streams_.at(ctx.dev_id)) {
+ auto primary_stream = gpu_streams_.at(ctx.dev_id).at(idx++);
+ aux_stream = new GPUAuxStream(primary_stream);
}
counter = 0;
}
@@ -87,6 +94,7 @@ RunContext StreamManager<kNumGpus, kStreams>::GetRunContext(
}
ret = RunContext{ctx,
gpu_streams_.at(ctx.dev_id).at(use_counter),
+ gpu_aux_streams_.at(ctx.dev_id).at(use_counter),
false};
break;
#else
@@ -105,7 +113,7 @@ RunContext StreamManager<kNumGpus, kStreams>::GetIORunContext(
RunContext ret;
switch (ctx.dev_mask()) {
case cpu::kDevMask:
- ret = RunContext{ctx, nullptr, false};
+ ret = RunContext{ctx, nullptr, nullptr, false};
break;
case gpu::kDevMask: {
#if MXNET_USE_CUDA
@@ -116,7 +124,7 @@ RunContext StreamManager<kNumGpus, kStreams>::GetIORunContext(
gpu_io_streams_.at(ctx.dev_id) = mshadow::NewStream<gpu>(false, false, ctx.dev_id);
}
}
- ret = RunContext{ctx, gpu_io_streams_.at(ctx.dev_id), false};
+ ret = RunContext{ctx, gpu_io_streams_.at(ctx.dev_id), nullptr, false};
break;
#else
LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
@@ -145,9 +153,12 @@ void StreamManager<kNumGpus, kStreams>::Finalize() {
#if MXNET_USE_CUDA
for (std::size_t i = 0; i < kNumGpus; ++i) {
if (gpu_cnt_.at(i) != -1) {
- for (auto&& j : gpu_streams_.at(i)) {
+ for (auto&& primary_stream : gpu_streams_.at(i)) {
// Catch exception for CUDA driver shutdown
- MSHADOW_CATCH_ERROR(mshadow::DeleteStream<gpu>(j));
+ MSHADOW_CATCH_ERROR(mshadow::DeleteStream<gpu>(primary_stream));
+ }
+ for (auto&& aux_stream : gpu_aux_streams_.at(i)) {
+ delete aux_stream;
}
gpu_cnt_.at(i) = -1;
}
diff --git a/src/engine/threaded_engine_perdevice.cc b/src/engine/threaded_engine_perdevice.cc
index b6537da..bcb101e 100644
--- a/src/engine/threaded_engine_perdevice.cc
+++ b/src/engine/threaded_engine_perdevice.cc
@@ -99,7 +99,7 @@ class ThreadedEnginePerDevice : public ThreadedEngine {
MSHADOW_CATCH_ERROR(mshadow::SetDevice<gpu>(ctx.dev_id));
#endif
}
- this->ExecuteOprBlock(RunContext{ctx, nullptr}, opr_block);
+ this->ExecuteOprBlock(RunContext{ctx, nullptr, nullptr, false}, opr_block);
} else {
if (ctx.dev_mask() == Context::kCPU) {
// CPU execution.
@@ -244,7 +244,8 @@ class ThreadedEnginePerDevice : public ThreadedEngine {
this->is_worker_ = true;
#if MXNET_USE_CUDA
CHECK(block != nullptr);
- mshadow::Stream<gpu> *stream;
+ mshadow::Stream<gpu> *stream = nullptr;
+ GPUAuxStream *aux_stream = nullptr;
do {
ThreadPool::SetReadyOnDestroy setReady(ready_event);
// allocate stream
@@ -253,11 +254,12 @@ class ThreadedEnginePerDevice : public ThreadedEngine {
stream = mshadow::NewStream<gpu>(false, false, ctx.dev_id);
} else {
stream = mshadow::NewStream<gpu>(true, MXNET_USE_CUDNN != 0, ctx.dev_id);
+ aux_stream = new GPUAuxStream(stream);
}
} while (false);
// execute task
OprBlock* opr_block;
- RunContext run_ctx{ctx, stream};
+ RunContext run_ctx{ctx, stream, aux_stream, false};
auto* task_queue = &(block->task_queue);
// Don't eat up omp threads for GPU jobs. They're probably best used elsewhere,
@@ -269,6 +271,8 @@ class ThreadedEnginePerDevice : public ThreadedEngine {
}
// Catch exception for CUDA driver shutdown
MSHADOW_CATCH_ERROR(mshadow::DeleteStream<gpu>(stream));
+ if (aux_stream != nullptr)
+ delete aux_stream;
#else
ready_event->signal();
#endif
@@ -283,7 +287,7 @@ class ThreadedEnginePerDevice : public ThreadedEngine {
const std::shared_ptr<dmlc::ManualEvent>& ready_event) {
this->is_worker_ = true;
auto* task_queue = &(block->task_queue);
- RunContext run_ctx{ctx, nullptr};
+ RunContext run_ctx{ctx, nullptr, nullptr, false};
// execute task
OprBlock* opr_block;
diff --git a/src/imperative/imperative_utils.h b/src/imperative/imperative_utils.h
index 7113cb2..a8db481 100644
--- a/src/imperative/imperative_utils.h
+++ b/src/imperative/imperative_utils.h
@@ -448,7 +448,7 @@ inline void PushFComputeEx(const FComputeEx& fn,
};
if (exec_type == ExecType::kCrossDeviceCopy) {
- run(RunContext{ctx, nullptr});
+ run(RunContext{ctx, nullptr, nullptr, false});
} else {
CHECK(exec_type == ExecType::kSync);
Engine::Get()->PushSync(run, ctx, read_vars, write_vars, FnProperty::kNormal,
@@ -498,7 +498,7 @@ inline void PushOperator(const OpStatePtr& state,
// For operators with subgraphs, we need to invoke them in the main thread
// instead of the threaded engine.
if (exec_type == ExecType::kSubgraphExec) {
- RunContext rctx{ctx, nullptr};
+ RunContext rctx{ctx, nullptr, nullptr, false};
run(rctx, engine::CallbackOnComplete());
} else if (exec_type == ExecType::kSync) {
Engine::Get()->PushSync(
@@ -546,7 +546,7 @@ inline void PushOperator(const OpStatePtr& state,
};
if (exec_type == ExecType::kSubgraphExec) {
- RunContext rctx{ctx, nullptr};
+ RunContext rctx{ctx, nullptr, nullptr, false};
run(rctx, engine::CallbackOnComplete());
} else if (exec_type == ExecType::kSync) {
Engine::Get()->PushSync(
diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc
index 0f0fed2..648f958 100644
--- a/src/ndarray/ndarray.cc
+++ b/src/ndarray/ndarray.cc
@@ -1826,7 +1826,7 @@ void NDArray::SyncCopyFromCPU(const void *data, size_t size) const {
if (this->ctx().dev_mask() == cpu::kDevMask) {
this->WaitToWrite();
- RunContext rctx{this->ctx(), nullptr};
+ RunContext rctx{this->ctx(), nullptr, nullptr, false};
TBlob dst = this->data();
ndarray::Copy<cpu, cpu>(src, &dst, Context::CPU(), Context::CPU(), rctx);
} else {
@@ -1957,7 +1957,7 @@ void NDArray::SyncCopyToCPU(void *data, size_t size) const {
if (this->ctx().dev_mask() == cpu::kDevMask) {
this->WaitToRead();
- RunContext rctx{this->ctx(), nullptr};
+ RunContext rctx{this->ctx(), nullptr, nullptr, false};
NDArray src = *this;
#if MXNET_USE_MKLDNN == 1
if (src.IsMKLDNNData())
diff --git a/src/operator/nn/cudnn/cudnn_convolution-inl.h b/src/operator/nn/cudnn/cudnn_convolution-inl.h
index 3bd6c5a..f68d2e3 100644
--- a/src/operator/nn/cudnn/cudnn_convolution-inl.h
+++ b/src/operator/nn/cudnn/cudnn_convolution-inl.h
@@ -53,6 +53,7 @@ class CuDNNConvolutionOp {
CUDNN_CALL(cudnnCreateConvolutionDescriptor(&forward_conv_desc_));
CUDNN_CALL(cudnnCreateConvolutionDescriptor(&back_conv_desc_));
CUDNN_CALL(cudnnCreateConvolutionDescriptor(&back_conv_desc_w_));
+ parallelize_backward_kernels_ = Context::GetGPUStreamsPerWorker() >= 2;
}
void Init(const ConvolutionParam& param,
@@ -110,6 +111,7 @@ class CuDNNConvolutionOp {
// future cuDNN releases.
SelectAlgo(rctx, in_shape, out_shape,
cudnn_forward_compute_type, cudnn_backward_compute_type);
+ GetTempSize(rctx);
}
~CuDNNConvolutionOp() {
@@ -131,7 +133,6 @@ class CuDNNConvolutionOp {
CHECK_EQ(in_data.size(), expected);
CHECK_EQ(out_data.size(), 1U);
Stream<gpu> *s = ctx.get_stream<gpu>();
- GetTempSize(ctx);
Tensor<gpu, 1, DType> workspace = AllocateTempWorkspace(ctx, forward_workspace_byte_);
size_t workspace_size = TensorSizeBytes(workspace);
@@ -224,6 +225,8 @@ class CuDNNConvolutionOp {
CHECK_EQ(in_data.size(), expected);
CHECK_EQ(in_grad.size(), expected);
Stream<gpu> *s = ctx.get_stream<gpu>();
+ // RAII object to handle syncing of the underlying auxiliary stream with the primary stream
+ SyncedGPUAuxStream s_dgrad = ctx.get_gpu_aux_stream();
// I/O's should have 2 more dims than the kernel dim
DType *grad_ptr = GetNdPtr(out_grad[conv::kOut], param_.kernel.ndim() + 2, s);
@@ -232,9 +235,27 @@ class CuDNNConvolutionOp {
DType *data_ptr = GetNdPtr(in_data[conv::kData], param_.kernel.ndim() + 2, s);
DType *gdata_ptr = GetNdPtr(in_grad[conv::kData], param_.kernel.ndim() + 2, s);
- GetTempSize(ctx);
- Tensor<gpu, 1, DType> workspace = AllocateTempWorkspace(ctx, backward_workspace_byte_);
+ size_t backward_workspace_byte =
+ parallelize_backward_kernels_ ? back_workspace_byte_dgrad_ + back_workspace_byte_wgrad_
+ : std::max(back_workspace_byte_dgrad_,
+ back_workspace_byte_wgrad_);
+ Tensor<gpu, 1, DType> workspace = AllocateTempWorkspace(ctx, backward_workspace_byte);
size_t workspace_size = TensorSizeBytes(workspace);
+ DType *workspace_dptr_wgrad = workspace.dptr_;
+ DType *workspace_dptr_dgrad = workspace.dptr_;
+ if (parallelize_backward_kernels_) {
+ CHECK_LE(back_workspace_byte_dgrad_ + back_workspace_byte_wgrad_, workspace_size);
+ // Large allocations at some point will be given their own page. Pass this alignment on to
+ // the larger of the two separate dgrad/wgrad workspaces. This probably doesn't matter, but
+ // corresponds more closely to the workspace alignments used during cudnnFind.
+ if (back_workspace_byte_dgrad_ > back_workspace_byte_wgrad_)
+ workspace_dptr_wgrad = workspace.dptr_ + back_workspace_byte_dgrad_ / sizeof(DType);
+ else
+ workspace_dptr_dgrad = workspace.dptr_ + back_workspace_byte_wgrad_ / sizeof(DType);
+ } else {
+ CHECK_LE(back_workspace_byte_dgrad_, workspace_size);
+ CHECK_LE(back_workspace_byte_wgrad_, workspace_size);
+ }
#if CUDNN_MAJOR >= 7
typename DataType<DType>::ScaleType alpha = 1.0f;
typename DataType<DType>::ScaleType beta = 0.0f;
@@ -259,14 +280,14 @@ class CuDNNConvolutionOp {
grad_ptr,
back_conv_desc_w_,
back_algo_w_.AlgoNumber(),
- workspace.dptr_,
- workspace_size,
+ workspace_dptr_wgrad,
+ back_workspace_byte_wgrad_,
req[conv::kWeight] == kAddTo? &beta_add : &beta,
filter_desc_,
gwmat_ptr));
}
if (req[conv::kData] != kNullOp) {
- CUDNN_CALL(cudnnConvolutionBackwardData(s->dnn_handle_,
+ CUDNN_CALL(cudnnConvolutionBackwardData(s_dgrad.GetStream()->dnn_handle_,
&alpha,
filter_desc_,
wmat_ptr,
@@ -274,8 +295,8 @@ class CuDNNConvolutionOp {
grad_ptr,
back_conv_desc_,
back_algo_.AlgoNumber(),
- workspace.dptr_,
- workspace_size,
+ workspace_dptr_dgrad,
+ back_workspace_byte_dgrad_,
req[conv::kData] == kAddTo? &beta_add : &beta,
in_desc_,
gdata_ptr));
@@ -912,24 +933,30 @@ class CuDNNConvolutionOp {
}
- void GetTempSize(const OpContext& ctx) {
- mshadow::Stream<gpu> *s = ctx.get_stream<gpu>();
- size_t back_size = 0, back_size_w = 0;
+ void GetTempSize(const RunContext& rctx) {
+ mshadow::Stream<gpu> *s = rctx.get_stream<gpu>();
CUDNN_CALL(cudnnGetConvolutionBackwardDataWorkspaceSize(s->dnn_handle_,
filter_desc_,
out_desc_,
back_conv_desc_,
in_desc_,
back_algo_.AlgoNumber(),
- &back_size));
+ &back_workspace_byte_dgrad_));
CUDNN_CALL(cudnnGetConvolutionBackwardFilterWorkspaceSize(s->dnn_handle_,
in_desc_,
out_desc_,
back_conv_desc_w_,
filter_desc_,
back_algo_w_.AlgoNumber(),
- &back_size_w));
- backward_workspace_byte_ = std::max(back_size, back_size_w);
+ &back_workspace_byte_wgrad_));
+ // cudaMalloc returns addresses that are aligned for large accesses (e.g. to 512 bytes).
+ // Since we only make one allocation and divide it into two parts when we parallelize
+ // the dgrad and wgrad kernels, we round the sizes up to this alignment size so the
+ // dptrs respect this alignment, even if the separate areas are stacked.
+ const size_t dptr_alignment = 512;
+ back_workspace_byte_dgrad_ = RoundToMultiple(back_workspace_byte_dgrad_, dptr_alignment);
+ back_workspace_byte_wgrad_ = RoundToMultiple(back_workspace_byte_wgrad_, dptr_alignment);
+
CUDNN_CALL(cudnnGetConvolutionForwardWorkspaceSize(s->dnn_handle_,
in_desc_,
filter_desc_,
@@ -983,11 +1010,18 @@ class CuDNNConvolutionOp {
CastTShapeToIntPtr(param_.pad, ¶m_pad_);
}
+ // Round a value 'x' up to the next multiple of 'multiple'
+ size_t RoundToMultiple(size_t x, size_t multiple) {
+ return ((x + multiple - 1) / multiple) * multiple;
+ }
+
// Allocates a 1D Tensor of words with size in bytes >= `size_bytes`.
// Always allocates at least one word.
mshadow::Tensor<gpu, 1, DType> AllocateTempWorkspace(const OpContext &ctx, size_t size_bytes) {
mshadow::Stream<gpu> *s = ctx.get_stream<gpu>();
- size_t size_words = size_bytes / sizeof(DType) + 1;
+ size_t size_words =
+ std::max<size_t>(1, RoundToMultiple(size_bytes, sizeof(DType)) / sizeof(DType));
return ctx.requested[conv::kTempSpace].get_space_typed<gpu, 1, DType>(
mshadow::Shape1(size_words), s);
}
@@ -1035,8 +1069,10 @@ class CuDNNConvolutionOp {
// Temp workspace size in bytes needed for Forward() operation.
size_t forward_workspace_byte_;
- // Temp workspace size in bytes needed for Backward() operation.
- size_t backward_workspace_byte_;
+ // Temp workspace size in bytes needed for Backward() dgrad (data gradient) operation.
+ size_t back_workspace_byte_dgrad_;
+ // Temp workspace size in bytes needed for Backward() wgrad (weight gradient) operation.
+ size_t back_workspace_byte_wgrad_;
size_t data_offset_;
size_t out_offset_;
size_t weight_offset_;
@@ -1052,6 +1088,8 @@ class CuDNNConvolutionOp {
cudnnConvolutionDescriptor_t back_conv_desc_;
// Convolution descriptor for back-prop operations to the weights
cudnnConvolutionDescriptor_t back_conv_desc_w_;
+ // Should dgrad and wgrad be launched into separate streams
+ bool parallelize_backward_kernels_;
// Algorithm for the forward inference operation
CuDNNAlgo<cudnnConvolutionFwdAlgo_t> forward_algo_;
// Algorithm for the back-prop operation to the data
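The workspace arithmetic in GetTempSize()/Backward() above is easy to verify by hand. A small standalone sketch of the round-up-and-stack layout (the byte counts are made up for illustration):

#include <cstddef>
#include <cstdio>

size_t RoundToMultiple(size_t x, size_t multiple) {
  return ((x + multiple - 1) / multiple) * multiple;
}

int main() {
  const size_t dptr_alignment = 512;
  // Hypothetical sizes reported by cudnnGetConvolutionBackward*WorkspaceSize.
  size_t dgrad_bytes = RoundToMultiple(70000, dptr_alignment);  // -> 70144
  size_t wgrad_bytes = RoundToMultiple(30000, dptr_alignment);  // -> 30208
  // Dual-stream mode: one allocation holds both regions, stacked.
  size_t total = dgrad_bytes + wgrad_bytes;                     // -> 100352
  // dgrad (the larger consumer here) keeps the base pointer; wgrad starts
  // right after it and stays 512-byte aligned because dgrad_bytes is a
  // multiple of 512.
  size_t wgrad_offset = dgrad_bytes;
  printf("total=%zu bytes, wgrad region at byte offset %zu\n", total, wgrad_offset);
  return 0;
}

Single-stream mode instead allocates max(dgrad_bytes, wgrad_bytes) and lets the two serialized kernels reuse the same region.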
diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py
index 010cf50..4dbf82e 100644
--- a/tests/python/gpu/test_operator_gpu.py
+++ b/tests/python/gpu/test_operator_gpu.py
@@ -520,6 +520,56 @@ def test_convolution_options():
sym_no_cudnn = mx.sym.Convolution(num_filter=3, kernel=(1,1,1), pad=(0,0,0), cudnn_off=True, name='conv')
check_consistency_NxM([sym, sym_no_cudnn], ctx_list)
+
+# Helper function to run tests in a subprocess to avoid save/restore of os.environ.
+# Also avoids issues of cached environment variable lookups in the backend.
+def _test_in_separate_process(func, env, *args):
+ try:
+ mpctx = mp.get_context('spawn')
+ except:
+ print('SKIP: python%s.%s lacks the required process fork-exec support ... ' %
+ sys.version_info[0:2], file=sys.stderr, end='')
+ else:
+ seed = np.random.randint(0,1024*1024*1024)
+ for (key, value) in env.items():
+ os.environ[key] = str(value)
+ # Prepend seed as first arg
+ p = mpctx.Process(target=func, args=(seed,)+args)
+ p.start()
+ p.join()
+ assert p.exitcode == 0, "Non-zero exit code %d from %s()." % (p.exitcode, func.__name__)
+
+def _conv_with_num_streams(seed):
+ with random_seed(seed):
+ # Try to expose timing-dependent improper workspace sharing by parallel dgrad and wgrad
+ num_trials = 20
+ for _ in range(num_trials):
+ size = np.random.randint(32, 128)
+ # The cudnn conv operator runs dgrad and wgrad in separate streams if enabled, with possible
+ # kernel overlap. The non-cudnn conv op doesn't do this, so it serves as the 'golden copy'.
+ ctx = {'ctx': mx.gpu(0), 'conv_data': (2, 2, size, size),
+ 'type_dict': {'conv_data': np.float32}}
+ # Adding 'flip' here isolates the model from the input node (which can't use inplace store)
+ flipped = mx.sym.flip(axis=0, name='conv')
+ sym = mx.sym.Convolution(data=flipped, num_filter=3, kernel=(3,3), pad=(1,1), name='conv')
+ flipped_no_cudnn = mx.sym.flip(axis=0, name='conv')
+ sym_no_cudnn = mx.sym.Convolution(data=flipped_no_cudnn, num_filter=3, kernel=(3,3), pad=(1,1),
+ cudnn_off=True, name='conv')
+ try:
+ # tol can be pretty high: we're looking for a large diff due to a corrupted workspace
+ check_consistency([sym, sym_no_cudnn], [ctx, ctx], tol=1e-2)
+ except:
+ print('Failing conv size = {}'.format(size))
+ raise
+
+@with_seed()
+def test_convolution_multiple_streams():
+ for num_streams in [1, 2]:
+ for engine in ['NaiveEngine', 'ThreadedEngine', 'ThreadedEnginePerDevice']:
+ _test_in_separate_process(_conv_with_num_streams,
+ {'MXNET_GPU_WORKER_NSTREAMS' : num_streams, 'MXNET_ENGINE_TYPE' : engine})
+
+
# This test is designed to expose an issue with cudnn v7.1.4 algo find() when invoked with large c.
# Algos returned by find() can fail to run with grad_req='add' (wgrad kernel beta parameter == 1.0f).
@with_seed()