You are viewing a plain text version of this content. The canonical link was omitted from this plain-text rendering; see the original archive entry for it.
Posted to commits@mxnet.apache.org by jx...@apache.org on 2017/12/12 22:49:34 UTC
[incubator-mxnet] branch master updated: fix small memory leak of
sparse embedding (#9025)
This is an automated email from the ASF dual-hosted git repository.
jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 167871a fix small memory leak of sparse embedding (#9025)
167871a is described below
commit 167871a135308971c22cb8f6bdc2c8e7477fda6e
Author: Ziyue Huang <zy...@gmail.com>
AuthorDate: Wed Dec 13 06:49:31 2017 +0800
fix small memory leak of sparse embedding (#9025)
* disable empty output of ndarray.slice & fix small mem leak of sparse embedding
* revert
* replace cudamalloc with resource request
---
src/operator/tensor/indexing_op.cc | 3 ++-
src/operator/tensor/indexing_op.cu | 8 +++++---
src/operator/tensor/indexing_op.h | 5 ++---
3 files changed, 9 insertions(+), 7 deletions(-)
diff --git a/src/operator/tensor/indexing_op.cc b/src/operator/tensor/indexing_op.cc
index 7d885ad..735da31 100644
--- a/src/operator/tensor/indexing_op.cc
+++ b/src/operator/tensor/indexing_op.cc
@@ -29,7 +29,7 @@ namespace mxnet {
namespace op {
template<>
-void SparseEmbeddingOpForwardRspImpl<cpu>(mshadow::Stream<cpu>* s,
+void SparseEmbeddingOpForwardRspImpl<cpu>(const OpContext& ctx,
const TBlob& data,
const NDArray& weight,
const OpReqType req,
@@ -37,6 +37,7 @@ void SparseEmbeddingOpForwardRspImpl<cpu>(mshadow::Stream<cpu>* s,
if (req == kNullOp) return;
using namespace rowsparse;
using namespace mxnet_op;
+ mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
// zeros weight
if (req == kWriteTo && !weight.storage_initialized()) {
size_t out_size = output.shape_.Size();
diff --git a/src/operator/tensor/indexing_op.cu b/src/operator/tensor/indexing_op.cu
index f029f02..4021f2b 100644
--- a/src/operator/tensor/indexing_op.cu
+++ b/src/operator/tensor/indexing_op.cu
@@ -61,7 +61,7 @@ struct AddTakeGradRspGPUKernel {
};
template<>
-void SparseEmbeddingOpForwardRspImpl<gpu>(mshadow::Stream<gpu>* s,
+void SparseEmbeddingOpForwardRspImpl<gpu>(const OpContext& ctx,
const TBlob& data,
const NDArray& weight,
const OpReqType req,
@@ -69,6 +69,7 @@ void SparseEmbeddingOpForwardRspImpl<gpu>(mshadow::Stream<gpu>* s,
if (req == kNullOp) return;
using namespace rowsparse;
using namespace mxnet_op;
+ mshadow::Stream<gpu>* s = ctx.get_stream<gpu>();
// zeros weight
if (req == kWriteTo && !weight.storage_initialized()) {
size_t out_size = output.shape_.Size();
@@ -85,8 +86,9 @@ void SparseEmbeddingOpForwardRspImpl<gpu>(mshadow::Stream<gpu>* s,
DType max = static_cast<DType>(weight.shape()[0] - 1);
DType* data_ptr = data.dptr<DType>();
size_t data_size = data.shape_.Size();
- int32_t* is_valid_ptr = NULL;
- CUDA_CALL(cudaMalloc(&is_valid_ptr, sizeof(int32_t)));
+ Tensor<gpu, 1, char> workspace = ctx.requested[0]
+ .get_space_typed<gpu, 1, char>(Shape1(sizeof(int32_t)), s);
+ int32_t* is_valid_ptr = reinterpret_cast<int32_t*>(workspace.dptr_);
Kernel<set_zero, gpu>::Launch(s, 1, is_valid_ptr);
Kernel<is_valid_check, gpu>::Launch(s, data_size, is_valid_ptr, data_ptr, min, max);
CUDA_CALL(cudaMemcpy(&is_valid, is_valid_ptr, sizeof(int32_t),
diff --git a/src/operator/tensor/indexing_op.h b/src/operator/tensor/indexing_op.h
index b0f06de..4043e76 100644
--- a/src/operator/tensor/indexing_op.h
+++ b/src/operator/tensor/indexing_op.h
@@ -364,7 +364,7 @@ inline void EmbeddingOpForwardRspImpl(mshadow::Stream<xpu>* s,
// Embedding forward implementation with row_sparse weight
template<typename xpu>
-void SparseEmbeddingOpForwardRspImpl(mshadow::Stream<xpu>* s,
+void SparseEmbeddingOpForwardRspImpl(const OpContext& ctx,
const TBlob& data,
const NDArray& weight,
const OpReqType req,
@@ -406,10 +406,9 @@ void SparseEmbeddingOpForwardEx(const nnvm::NodeAttrs& attrs,
const auto data_stype = data.storage_type();
const auto weight_stype = weight.storage_type();
const auto out_stype = out.storage_type();
- mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
if (data_stype == kDefaultStorage && weight_stype == kRowSparseStorage &&
out_stype == kDefaultStorage) {
- SparseEmbeddingOpForwardRspImpl<xpu>(s, data.data(), weight, req[0], out.data());
+ SparseEmbeddingOpForwardRspImpl<xpu>(ctx, data.data(), weight, req[0], out.data());
} else {
LOG(FATAL) << "Not implemented: " << operator_string(attrs, ctx, inputs, req, outputs);
}
--
To stop receiving notification emails like this one, please contact
['"commits@mxnet.apache.org" <co...@mxnet.apache.org>'].