Posted to commits@mxnet.apache.org by jx...@apache.org on 2018/01/04 19:30:00 UTC
[incubator-mxnet] branch master updated: Fix the gradient of gather_nd (#9200)
This is an automated email from the ASF dual-hosted git repository.
jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new d918868 Fix the gradient of gather_nd (#9200)
d918868 is described below
commit d918868006a02824ac536e65e0d03d677f9b5af7
Author: Xingjian Shi <xs...@ust.hk>
AuthorDate: Thu Jan 4 11:29:57 2018 -0800
Fix the gradient of gather_nd (#9200)
* try to implement scatter_nd_acc
fix
fix
fix
update
only support real_type
update
update
try to fix
update
fix
update
revise test
fix lint
* fix
* mark line as no lint
* fix test
* revise test
* fix test case
* revise
* remove openmp
* update
* update
* update
* update test
* Revert "update test"
This reverts commit 3eb3ac6b2757ba8facb9387cd8b0080e0d496f46.
* Revert "update"
This reverts commit a28fa53a61e13bcffd0dc4503804d8704ea200a0.
* Revert "update"
This reverts commit e99ffd075832881348ff6cf7d1524fca9e614a2d.
* Revert "update"
This reverts commit 399ba0216bc21f279d46c688282fbbd37b0126c8.
* add atomic and specialize the behavior of half_t
* use "!" instead of not
* add test
* fix test
* fix test
* fix test
* rename to backward_gather_nd
* fix
* fix
* fix doc
---
src/common/cuda_utils.h | 5 ++
src/operator/mxnet_op.h | 44 ++++++++++++
src/operator/tensor/indexing_op.cc | 118 +++++++++++++++++++++++++++++++--
src/operator/tensor/indexing_op.cu | 29 ++++++++
src/operator/tensor/indexing_op.h | 64 +++++++++++++++++-
tests/python/unittest/test_operator.py | 41 +++++++-----
6 files changed, 277 insertions(+), 24 deletions(-)
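The core of the change: `scatter_nd` had been used as the gradient of
`gather_nd`, but a plain scatter overwrites duplicate index positions instead
of accumulating them. A minimal NumPy sketch of the difference (illustration
only, not part of the commit), mirroring the docstring example added below::

    import numpy as np

    # Gradient flowing back into gather_nd's output, with a repeated index.
    ograd = np.array([2.0, 3.0, 0.0])
    idx = np.array([[1, 1, 0], [1, 1, 0]])     # M = 2; (1, 1) appears twice

    # scatter_nd-style assignment: duplicate writes collide, last write wins.
    wrong = np.zeros((2, 2))
    wrong[idx[0], idx[1]] = ograd              # -> [[0, 0], [0, 3]], the 2 is lost

    # Correct gradient: duplicates accumulate, as _backward_gather_nd now does.
    right = np.zeros((2, 2))
    np.add.at(right, (idx[0], idx[1]), ograd)  # -> [[0, 0], [0, 5]]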
diff --git a/src/common/cuda_utils.h b/src/common/cuda_utils.h
index a1c37a9..9d3388b 100644
--- a/src/common/cuda_utils.h
+++ b/src/common/cuda_utils.h
@@ -479,6 +479,11 @@ static inline __device__ void atomicAdd(mshadow::half::half_t *address,
} while (assumed != old);
}
+// Overload atomicAdd to work for signed int64 on all architectures
+static inline __device__ void atomicAdd(int64_t *address, int64_t val) {
+ atomicAdd(reinterpret_cast<unsigned long long*>(address), static_cast<unsigned long long>(val)); // NOLINT
+}
+
template <typename DType>
__device__ inline DType ldg(const DType* address) {
#if __CUDA_ARCH__ >= 350
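The int64 overload added above relies on a two's-complement identity: signed
and unsigned 64-bit addition produce the same bit pattern, so the operands can
be reinterpreted as unsigned, added with the existing unsigned atomicAdd, and
read back as signed. A quick NumPy check of that identity (illustration only,
reusing values from the new int64 test)::

    import numpy as np

    a = np.array([2123162361283621, -31231236374787], dtype=np.int64)
    b = np.array([-112372937128970, -1378278798172378], dtype=np.int64)

    signed_sum = a + b
    # Reinterpret the same 64 bits as unsigned, add (wrapping mod 2**64),
    # then reinterpret the result back as signed.
    unsigned_sum = (a.view(np.uint64) + b.view(np.uint64)).view(np.int64)

    assert (signed_sum == unsigned_sum).all()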
diff --git a/src/operator/mxnet_op.h b/src/operator/mxnet_op.h
index 15ad59f..081e40a 100644
--- a/src/operator/mxnet_op.h
+++ b/src/operator/mxnet_op.h
@@ -132,6 +132,50 @@ inline int get_num_threads<cpu>(const int N) {
LOG(FATAL) << "ndim=" << NDim << "too large "; \
}
+#define MXNET_NO_INT8_TYPE_SWITCH(type, DType, ...) \
+ switch (type) { \
+ case mshadow::kFloat32: \
+ { \
+ typedef float DType; \
+ {__VA_ARGS__} \
+ } \
+ break; \
+ case mshadow::kFloat64: \
+ { \
+ typedef double DType; \
+ {__VA_ARGS__} \
+ } \
+ break; \
+ case mshadow::kFloat16: \
+ { \
+ typedef mshadow::half::half_t DType; \
+ {__VA_ARGS__} \
+ } \
+ break; \
+ case mshadow::kUint8: \
+ LOG(FATAL) << "This operation does not " \
+ "support int8 or uint8"; \
+ break; \
+ case mshadow::kInt8: \
+ LOG(FATAL) << "This operation does not " \
+ "support int8 or uint8"; \
+ break; \
+ case mshadow::kInt32: \
+ { \
+ typedef int32_t DType; \
+ {__VA_ARGS__} \
+ } \
+ break; \
+ case mshadow::kInt64: \
+ { \
+ typedef int64_t DType; \
+ {__VA_ARGS__} \
+ } \
+ break; \
+ default: \
+ LOG(FATAL) << "Unknown type enum " << type; \
+ }
+
/*!
* \brief assign the val to out according
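The macro is the standard mshadow type switch with the 8-bit branches turned
into hard failures. A rough Python analogue of the dispatch, with hypothetical
names, just to show its shape (illustration only)::

    import numpy as np

    # Types the switch binds to DType; int8/uint8 are rejected up front.
    _SUPPORTED = {np.float16, np.float32, np.float64, np.int32, np.int64}

    def no_int8_type_switch(dtype, body, *args):
        if dtype in (np.int8, np.uint8):
            raise TypeError("This operation does not support int8 or uint8")
        if dtype not in _SUPPORTED:
            raise TypeError("Unknown type enum %r" % (dtype,))
        return body(dtype, *args)  # body plays the macro body with DType bound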
diff --git a/src/operator/tensor/indexing_op.cc b/src/operator/tensor/indexing_op.cc
index 735da31..10905b5 100644
--- a/src/operator/tensor/indexing_op.cc
+++ b/src/operator/tensor/indexing_op.cc
@@ -137,6 +137,46 @@ inline void SparseEmbeddingOpBackwardRspImpl<cpu>(const OpContext& ctx,
}
+template<typename DType, typename IType>
+inline typename std::enable_if<(!std::is_same<DType, mshadow::half::half_t>::value), void>::type
+GatherNDBackwardImpl(int N, int M, int K,
+ const mshadow::Shape<10> strides,
+ DType* out,
+ const DType* data,
+ const IType* indices,
+ mshadow::Stream<cpu> *s) {
+#pragma omp parallel for
+ for (int i = 0; i < N; i++) {
+ int offset = 0;
+ for (int j = 0; j < M; ++j) {
+ offset += strides[j] * static_cast<int>(indices[j*N + i]);
+ }
+ for (int j = 0; j < K; ++j) {
+#pragma omp atomic
+ out[offset + j] += data[i * K + j];
+ }
+ }
+}
+
+template<typename DType, typename IType>
+inline typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value, void>::type
+GatherNDBackwardImpl(int N, int M, int K,
+ const mshadow::Shape<10> strides,
+ DType* out,
+ const DType* data,
+ const IType* indices,
+ mshadow::Stream<cpu> *s) {
+ for (int i = 0; i < N; i++) {
+ int offset = 0;
+ for (int j = 0; j < M; ++j) {
+ offset += strides[j] * static_cast<int>(indices[j*N + i]);
+ }
+ for (int j = 0; j < K; ++j) {
+ out[offset + j] += data[i * K + j];
+ }
+ }
+}
+
DMLC_REGISTER_PARAMETER(EmbeddingParam);
DMLC_REGISTER_PARAMETER(TakeParam);
DMLC_REGISTER_PARAMETER(OneHotParam);
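Both CPU overloads added above share the same index arithmetic: each of the N
gathered rows of length K is added back at a flat offset computed from the M
leading indices and the precomputed strides. A direct NumPy transcription of
the loops (illustration only)::

    import numpy as np

    def gather_nd_backward_ref(N, M, K, strides, out_flat, data_flat, indices):
        # indices is flattened from shape (M, N), row-major, as in the C++ code.
        for i in range(N):
            offset = sum(strides[j] * int(indices[j * N + i]) for j in range(M))
            out_flat[offset:offset + K] += data_flat[i * K:(i + 1) * K]
        return out_flat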
@@ -443,8 +483,7 @@ Examples::
NNVM_REGISTER_OP(gather_nd)
.describe(R"code(Gather elements or slices from `data` and store to a tensor whose
-shape is defined by `indices`. `gather_nd` and `scatter_nd` are inverse functions
-to each other.
+shape is defined by `indices`.
Given `data` with shape `(X_0, X_1, ..., X_{N-1})` and indices with shape
`(M, Y_0, ..., Y_{K-1})`, the output will have shape `(Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1})`,
@@ -476,13 +515,14 @@ Examples::
.set_attr<nnvm::FGradient>("FGradient",
[](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
auto p = nnvm::Node::Create();
- p->attrs.op = nnvm::Op::Get("scatter_nd");
+ p->attrs.op = nnvm::Op::Get("_backward_gather_nd");
p->attrs.name = n->attrs.name + "_backward";
p->inputs.push_back(ograds[0]);
p->inputs.push_back(n->inputs[1]);
p->control_deps.emplace_back(n);
auto zero = MakeNode("zeros_like", n->attrs.name + "_backward_indices",
{n->inputs[1]}, nullptr, &n);
+
std::vector<nnvm::NodeEntry> ret;
ret.emplace_back(nnvm::NodeEntry{p, 0, 0});
ret.emplace_back(nnvm::NodeEntry{zero, 0, 0});
@@ -492,10 +532,8 @@ Examples::
.add_argument("data", "NDArray-or-Symbol", "data")
.add_argument("indices", "NDArray-or-Symbol", "indices");
-
NNVM_REGISTER_OP(scatter_nd)
.describe(R"code(Scatters data into a new tensor according to indices.
-`gather_nd` and `scatter_nd` are inverse functions to each other.
Given `data` with shape `(Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1})` and indices with shape
`(M, Y_0, ..., Y_{K-1})`, the output will have shape `(X_0, X_1, ..., X_{N-1})`,
@@ -510,6 +548,12 @@ The elements in output is defined as follows::
all other entries in output are 0.
+.. warning::
+
+ If the indices have duplicates, the result will be non-deterministic and
+ the gradient of `scatter_nd` will not be correct!
+
+
Examples::
data = [2, 3, 0]
@@ -548,11 +592,73 @@ Examples::
.add_argument("indices", "NDArray-or-Symbol", "indices")
.add_arguments(ScatterNDParam::__FIELDS__());
+NNVM_REGISTER_OP(_backward_gather_nd)
+.describe(R"code(Accumulates data according to indices and get the result. It's the backward of
+`gather_nd`.
+
+Given `data` with shape `(Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1})` and indices with shape
+`(M, Y_0, ..., Y_{K-1})`, the output will have shape `(X_0, X_1, ..., X_{N-1})`,
+where `M <= N`. If `M == N`, data shape should simply be `(Y_0, ..., Y_{K-1})`.
+
+The elements in the output are defined as follows::
+
+ output[indices[0, y_0, ..., y_{K-1}],
+ ...,
+ indices[M-1, y_0, ..., y_{K-1}],
+ x_M, ..., x_{N-1}] += data[y_0, ..., y_{K-1}, x_M, ..., x_{N-1}]
+
+all other entries in the output are 0, or their original value if AddTo is triggered.
+
+Examples::
+
+ data = [2, 3, 0]
+ indices = [[1, 1, 0], [0, 1, 0]]
+ shape = (2, 2)
+ _backward_gather_nd(data, indices, shape) = [[0, 0], [2, 3]] # Same as scatter_nd
+
+ # The difference between scatter_nd and _backward_gather_nd is that the latter
+ # accumulates the values that point to the same index.
+
+ data = [2, 3, 0]
+ indices = [[1, 1, 0], [1, 1, 0]]
+ shape = (2, 2)
+ _backward_gather_nd(data, indices, shape) = [[0, 0], [0, 5]]
+
+)code")
+.set_num_outputs(1)
+.set_num_inputs(2)
+.set_attr_parser(ParamParser<ScatterNDParam>)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+ [](const NodeAttrs& attrs) {
+ return std::vector<std::string>{"data", "indices"};
+ })
+.set_attr<nnvm::FInferShape>("FInferShape", ScatterNDShape)
+.set_attr<nnvm::FInferType>("FInferType", ScatterNDType)
+.set_attr<FCompute>("FCompute<cpu>", GatherNDBackward<cpu>)
+.set_attr<nnvm::FGradient>("FGradient",
+ [](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
+ auto p = nnvm::Node::Create();
+ p->attrs.op = nnvm::Op::Get("gather_nd");
+ p->attrs.name = n->attrs.name + "_backward";
+ p->inputs.push_back(ograds[0]);
+ p->inputs.push_back(n->inputs[1]);
+ p->control_deps.emplace_back(n);
+ auto zero = MakeNode("zeros_like", n->attrs.name + "_backward_indices",
+ {n->inputs[1]}, nullptr, &n);
+ std::vector<nnvm::NodeEntry> ret;
+ ret.emplace_back(nnvm::NodeEntry{p, 0, 0});
+ ret.emplace_back(nnvm::NodeEntry{zero, 0, 0});
+ return ret;
+ })
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.add_argument("data", "NDArray-or-Symbol", "data")
+.add_argument("indices", "NDArray-or-Symbol", "indices")
+.add_arguments(ScatterNDParam::__FIELDS__());
+
NNVM_REGISTER_OP(_scatter_set_nd)
.describe(R"code(This operator has the same functionality as scatter_nd
except that it does not reset the elements not indexed by the input
index `NDArray` in the input data `NDArray`.
-
.. note:: This operator is for internal use only.
Examples::
diff --git a/src/operator/tensor/indexing_op.cu b/src/operator/tensor/indexing_op.cu
index 4021f2b..762d8fd 100644
--- a/src/operator/tensor/indexing_op.cu
+++ b/src/operator/tensor/indexing_op.cu
@@ -179,6 +179,32 @@ inline void SparseEmbeddingOpBackwardRspImpl<gpu>(const OpContext& ctx,
});
}
+struct backward_gather_nd_gpu {
+ template<typename DType, typename IType>
+ MSHADOW_XINLINE static void Map(int i, int N, int M, int K,
+ const mshadow::Shape<10> strides,
+ DType* out, const DType* data,
+ const IType* indices) {
+ int offset = 0;
+ for (int j = 0; j < M; ++j) {
+ offset += strides[j] * static_cast<int>(indices[j*N + i]);
+ }
+ for (int j = 0; j < K; ++j) {
+ atomicAdd(out + (offset + j), data[i * K + j]);
+ }
+ }
+};
+
+template<typename DType, typename IType>
+inline void GatherNDBackwardImpl(int N, int M, int K,
+ const mshadow::Shape<10> strides,
+ DType* out,
+ const DType* data,
+ const IType* indices,
+ mshadow::Stream<gpu> *s) {
+ mxnet_op::Kernel<backward_gather_nd_gpu, gpu>::Launch(s, N, N, M, K, strides, out, data, indices);
+}
+
NNVM_REGISTER_OP(Embedding)
.set_attr<FCompute>("FCompute<gpu>", EmbeddingOpForward<gpu>);
@@ -209,6 +235,9 @@ NNVM_REGISTER_OP(gather_nd)
NNVM_REGISTER_OP(scatter_nd)
.set_attr<FCompute>("FCompute<gpu>", ScatterNDForward<gpu>);
+NNVM_REGISTER_OP(_backward_gather_nd)
+.set_attr<FCompute>("FCompute<gpu>", GatherNDBackward<gpu>);
+
NNVM_REGISTER_OP(_scatter_set_nd)
.set_attr<FCompute>("FCompute<gpu>", ScatterSetNDForward<gpu>);
} // namespace op
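Since the GPU path uses atomicAdd, duplicate indices are summed in whatever
order the hardware commits the atomics. That is exact for the integer types,
but float results can vary slightly between runs because floating-point
addition is not associative, e.g.::

    import numpy as np

    x, y, z = np.float32(1e8), np.float32(-1e8), np.float32(1.0)
    print((x + y) + z)   # 1.0
    print(x + (y + z))   # 0.0, the 1.0 is absorbed before it can survive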
diff --git a/src/operator/tensor/indexing_op.h b/src/operator/tensor/indexing_op.h
index 4043e76..7323f81 100644
--- a/src/operator/tensor/indexing_op.h
+++ b/src/operator/tensor/indexing_op.h
@@ -1131,10 +1131,10 @@ void ScatterNDForward(const nnvm::NodeAttrs& attrs,
int K = oshape.ProdShape(M, oshape.ndim());
mshadow::Shape<10> strides;
for (int i = M-1, stride = K; i >= 0; stride *= oshape[i], --i) strides[i] = stride;
+ if (kWriteTo == req[0]) {
+ Fill<true>(s, outputs[0], req[0], 0);
+ }
MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, { // output data type switch
- if (kWriteTo == req[0]) {
- Fill<true>(s, outputs[0], req[0], 0);
- }
MSHADOW_TYPE_SWITCH(inputs[1].type_flag_, IType, { // indices data type switch
mxnet_op::Kernel<scatter_nd, xpu>::Launch(
s, N, req[0], N, M, K, strides, outputs[0].dptr<DType>(),
@@ -1143,6 +1143,64 @@ void ScatterNDForward(const nnvm::NodeAttrs& attrs,
});
}
+template<typename DType, typename IType>
+inline typename std::enable_if<(!std::is_same<DType, mshadow::half::half_t>::value), void>::type
+GatherNDBackwardImpl(int N, int M, int K,
+ const mshadow::Shape<10> strides,
+ DType* out,
+ const DType* data,
+ const IType* indices,
+ mshadow::Stream<cpu> *s);
+
+template<typename DType, typename IType>
+inline typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value, void>::type
+GatherNDBackwardImpl(int N, int M, int K,
+ const mshadow::Shape<10> strides,
+ DType* out,
+ const DType* data,
+ const IType* indices,
+ mshadow::Stream<cpu> *s);
+
+template<typename DType, typename IType>
+inline void GatherNDBackwardImpl(int N, int M, int K,
+ const mshadow::Shape<10> strides,
+ DType* out,
+ const DType* data,
+ const IType* indices,
+ mshadow::Stream<gpu> *s);
+
+template<typename xpu>
+void GatherNDBackward(const nnvm::NodeAttrs& attrs,
+ const OpContext& ctx,
+ const std::vector<TBlob>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<TBlob>& outputs) {
+ using namespace mshadow;
+ CHECK_EQ(inputs.size(), 2U);
+ CHECK_EQ(outputs.size(), 1U);
+ if (req[0] == kNullOp) return;
+ mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
+ const TShape& oshape = outputs[0].shape_;
+ const TShape& ishape = inputs[1].shape_;
+ int M = ishape[0];
+ int N = ishape.Size() / M;
+ int K = oshape.ProdShape(M, oshape.ndim());
+ mshadow::Shape<10> strides;
+ for (int i = M-1, stride = K; i >= 0; stride *= oshape[i], --i) strides[i] = stride;
+ if (kWriteTo == req[0]) {
+ Fill<true>(s, outputs[0], req[0], 0);
+ }
+ MXNET_NO_INT8_TYPE_SWITCH(inputs[0].type_flag_, DType, { // output data type switch
+ MSHADOW_TYPE_SWITCH(inputs[1].type_flag_, IType, { // indices data type switch
+ GatherNDBackwardImpl(N, M, K, strides,
+ outputs[0].dptr<DType>(),
+ inputs[0].dptr<DType>(),
+ inputs[1].dptr<IType>(),
+ s);
+ });
+ });
+}
+
/*!
* This is for internal use only.
* DO NOT call this function unless you have to.
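The host-side function derives N, M, K, and the stride table purely from the
output and indices shapes. A small Python sketch of the same computation
(illustration only)::

    import numpy as np

    def backward_gather_nd_setup(oshape, ishape):
        # ishape = indices shape (M, Y_0, ..., Y_{K-1});
        # oshape = output gradient shape (X_0, ..., X_{N-1}).
        M = ishape[0]                    # number of leading dims being indexed
        N = int(np.prod(ishape)) // M    # number of index tuples
        K = int(np.prod(oshape[M:]))     # trailing row length copied per tuple
        strides, stride = [0] * M, K
        for i in range(M - 1, -1, -1):   # row-major strides over the first M dims
            strides[i] = stride
            stride *= oshape[i]
        return N, M, K, strides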
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 3fbf98b..56dc27c 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -4391,21 +4391,32 @@ def test_scatter_gather_nd():
npdata = np.zeros_like(data.asnumpy())
npdata[npidx] = y.asnumpy()
assert (npdata == data.grad.asnumpy()).all()
- assert (mx.nd.scatter_nd(y, idx, shape=data.shape).asnumpy() == data.grad.asnumpy()).all()
-
- data = mx.nd.arange(360, dtype='int32').reshape((3,4,5,6))
- idx = mx.nd.array([[1,1,2], [3, 3, 0], [3,2,1]], dtype='int32')
-
- check(data, idx)
-
- idx = mx.nd.array([[1,1,2], [3,3,0], [3,2,1], [5,2,4]], dtype='int32')
-
- check(data, idx)
-
- data = mx.nd.array([2, 3, 0])
- idx = mx.nd.array([[1, 1, 0], [0, 1, 0]])
-
- assert (mx.nd.scatter_nd(data, idx, shape=(2, 2)).asnumpy() == [[0, 0], [2, 3]]).all()
+ assert (mx.nd._internal._backward_gather_nd(y, idx, shape=data.shape).asnumpy() == data.grad.asnumpy()).all()
+ for dtype in ['int32', 'int64', 'float16', 'float32', 'float64']:
+ data = mx.nd.arange(360, dtype=dtype).reshape((3,4,5,6))
+ idx = mx.nd.array([[1,1,2], [3, 3, 0], [3,2,1]], dtype='int32')
+ check(data, idx)
+
+ idx = mx.nd.array([[1,1,2], [3,3,0], [3,2,1], [5,2,4]], dtype='int32')
+
+ check(data, idx)
+
+ data = mx.nd.array([2, 3, 0], dtype=dtype)
+ idx = mx.nd.array([[1, 1, 0], [0, 1, 0]], dtype='int32')
+ assert (mx.nd.scatter_nd(data, idx, shape=(2, 2)).asnumpy() == [[0, 0], [2, 3]]).all()
+
+ data = mx.nd.array([2, 3, 0], dtype=dtype)
+ idx = mx.nd.array([[1, 1, 0], [1, 1, 0]], dtype='int32')
+ assert (mx.nd._internal._backward_gather_nd(data, idx, shape=(2, 2)).asnumpy() == [[0, 0], [0, 5]]).all()
+ data_npy = np.random.randint(0, 10, (100,))
+ data = mx.nd.array(data_npy, dtype=dtype)
+ idx = mx.nd.zeros(shape=(1, 100), dtype='int32')
+ assert (mx.nd._internal._backward_gather_nd(data, idx, shape=(1,)).asscalar() == data_npy.sum())
+ if dtype == 'int64':
+ data = mx.nd.array([2123162361283621, -31231236374787,
+ -112372937128970, -1378278798172378], dtype=dtype)
+ idx = mx.nd.array([[0, 0, 0, 0]], dtype='int32')
+ assert (mx.nd._internal._backward_gather_nd(data, idx, shape=(1,)).asscalar() == data.asnumpy().sum())
def compare_forw_backw_unary_op(
name, forward_mxnet_call, forward_numpy_call,
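An end-to-end sketch of the fixed behavior through autograd (assumes the
mx.autograd API as of MXNet 1.0; not part of the commit)::

    import mxnet as mx

    data = mx.nd.array([[0., 1.], [2., 3.]])
    idx = mx.nd.array([[1, 1, 0], [1, 1, 0]])   # (1, 1) is gathered twice

    data.attach_grad()
    with mx.autograd.record():
        out = mx.nd.gather_nd(data, idx)        # [3., 3., 0.]
    out.backward(mx.nd.array([2., 3., 0.]))

    print(data.grad)   # [[0., 0.], [0., 5.]], duplicates now accumulate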