You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by jx...@apache.org on 2018/01/04 19:30:00 UTC

[incubator-mxnet] branch master updated: Fix the gradient of gather_nd (#9200)

This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new d918868  Fix the gradient of gather_nd (#9200)
d918868 is described below

commit d918868006a02824ac536e65e0d03d677f9b5af7
Author: Xingjian Shi <xs...@ust.hk>
AuthorDate: Thu Jan 4 11:29:57 2018 -0800

    Fix the gradient of gather_nd (#9200)
    
    * try to implement scatter_nd_acc
    
    fix
    
    fix
    
    fix
    
    update
    
    only support real_type
    
    update
    
    update
    
    try to fix
    
    update
    
    fix
    
    update
    
    revise test
    
    fix lint
    
    * fix
    
    * mark line as no lint
    
    * fix test
    
    * revise test
    
    * fix test case
    
    * revise
    
    * remove openmp
    
    * update
    
    * update
    
    * update
    
    * update test
    
    * Revert "update test"
    
    This reverts commit 3eb3ac6b2757ba8facb9387cd8b0080e0d496f46.
    
    * Revert "update"
    
    This reverts commit a28fa53a61e13bcffd0dc4503804d8704ea200a0.
    
    * Revert "update"
    
    This reverts commit e99ffd075832881348ff6cf7d1524fca9e614a2d.
    
    * Revert "update"
    
    This reverts commit 399ba0216bc21f279d46c688282fbbd37b0126c8.
    
    * add atomic and specialize the behavior of half_t
    
    * use "!" instead of not
    
    * add test
    
    * fix test
    
    * fix test
    
    * fix test
    
    * rename to backward_gather_nd
    
    * fix
    
    * fix
    
    * fix doc
---
 src/common/cuda_utils.h                |   5 ++
 src/operator/mxnet_op.h                |  44 ++++++++++++
 src/operator/tensor/indexing_op.cc     | 118 +++++++++++++++++++++++++++++++--
 src/operator/tensor/indexing_op.cu     |  29 ++++++++
 src/operator/tensor/indexing_op.h      |  64 +++++++++++++++++-
 tests/python/unittest/test_operator.py |  41 +++++++-----
 6 files changed, 277 insertions(+), 24 deletions(-)

diff --git a/src/common/cuda_utils.h b/src/common/cuda_utils.h
index a1c37a9..9d3388b 100644
--- a/src/common/cuda_utils.h
+++ b/src/common/cuda_utils.h
@@ -479,6 +479,11 @@ static inline __device__ void atomicAdd(mshadow::half::half_t *address,
   } while (assumed != old);
 }
 
+// Signed int64 atomicAdd for all architectures, built on the unsigned long long atomicAdd
+static inline  __device__  void atomicAdd(int64_t *address, int64_t val) {
+  atomicAdd(reinterpret_cast<unsigned long long*>(address), static_cast<unsigned long long>(val)); // NOLINT
+}
+
 template <typename DType>
 __device__ inline DType ldg(const DType* address) {
 #if __CUDA_ARCH__ >= 350
diff --git a/src/operator/mxnet_op.h b/src/operator/mxnet_op.h
index 15ad59f..081e40a 100644
--- a/src/operator/mxnet_op.h
+++ b/src/operator/mxnet_op.h
@@ -132,6 +132,50 @@ inline int get_num_threads<cpu>(const int N) {
     LOG(FATAL) << "ndim=" << NDim << "too large "; \
   }
 
+#define MXNET_NO_INT8_TYPE_SWITCH(type, DType, ...)        \
+  switch (type) {  /* binds DType, runs __VA_ARGS__ */     \
+  case mshadow::kFloat32:                                  \
+    {                                                      \
+      typedef float DType;                                 \
+      {__VA_ARGS__}                                        \
+    }                                                      \
+    break;                                                 \
+  case mshadow::kFloat64:                                  \
+    {                                                      \
+      typedef double DType;                                \
+      {__VA_ARGS__}                                        \
+    }                                                      \
+    break;                                                 \
+  case mshadow::kFloat16:                                  \
+    {                                                      \
+      typedef mshadow::half::half_t DType;                 \
+      {__VA_ARGS__}                                        \
+    }                                                      \
+    break;                                                 \
+  case mshadow::kUint8:  /* rejected: body is not run */   \
+    LOG(FATAL) << "This operation does not "               \
+                  "support int8 or uint8";                 \
+    break;                                                 \
+  case mshadow::kInt8:  /* rejected: body is not run */    \
+    LOG(FATAL) << "This operation does not "               \
+                  "support int8 or uint8";                 \
+    break;                                                 \
+  case mshadow::kInt32:                                    \
+    {                                                      \
+      typedef int32_t DType;                               \
+      {__VA_ARGS__}                                        \
+    }                                                      \
+    break;                                                 \
+  case mshadow::kInt64:                                    \
+    {                                                      \
+      typedef int64_t DType;                               \
+      {__VA_ARGS__}                                        \
+    }                                                      \
+    break;                                                 \
+  default:  /* any other type enum value */                \
+    LOG(FATAL) << "Unknown type enum " << type;            \
+  }
+
 
 /*!
  * \brief assign the val to out according
diff --git a/src/operator/tensor/indexing_op.cc b/src/operator/tensor/indexing_op.cc
index 735da31..10905b5 100644
--- a/src/operator/tensor/indexing_op.cc
+++ b/src/operator/tensor/indexing_op.cc
@@ -137,6 +137,46 @@ inline void SparseEmbeddingOpBackwardRspImpl<cpu>(const OpContext& ctx,
 }
 
 
+template<typename DType, typename IType>  // CPU overload for every DType except half_t
+inline typename std::enable_if<(!std::is_same<DType, mshadow::half::half_t>::value), void>::type
+GatherNDBackwardImpl(int N, int M, int K,
+                     const mshadow::Shape<10> strides,
+                     DType* out,
+                     const DType* data,
+                     const IType* indices,
+                     mshadow::Stream<cpu> *s) {
+#pragma omp parallel for
+  for (int i = 0; i < N; i++) {  // i walks the N index tuples
+    int offset = 0;  // flat position in out of the slice addressed by tuple i
+    for (int j = 0; j < M; ++j) {
+      offset += strides[j] * static_cast<int>(indices[j*N + i]);  // component j of tuple i
+    }
+    for (int j = 0; j < K; ++j) {  // add the K-element slice data[i*K ..] into out
+#pragma omp atomic
+      out[offset + j] += data[i * K + j];  // atomic update: loop iterations run in parallel
+    }
+  }
+}
+
+template<typename DType, typename IType>  // CPU overload selected only when DType is half_t
+inline typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value, void>::type
+GatherNDBackwardImpl(int N, int M, int K,
+                     const mshadow::Shape<10> strides,
+                     DType* out,
+                     const DType* data,
+                     const IType* indices,
+                     mshadow::Stream<cpu> *s) {
+  for (int i = 0; i < N; i++) {  // plain loop: this overload carries no OpenMP pragmas
+    int offset = 0;  // flat position in out of the slice addressed by tuple i
+    for (int j = 0; j < M; ++j) {
+      offset += strides[j] * static_cast<int>(indices[j*N + i]);  // component j of tuple i
+    }
+    for (int j = 0; j < K; ++j) {  // add the K-element slice data[i*K ..] into out
+      out[offset + j] += data[i * K + j];
+    }
+  }
+}
+
 DMLC_REGISTER_PARAMETER(EmbeddingParam);
 DMLC_REGISTER_PARAMETER(TakeParam);
 DMLC_REGISTER_PARAMETER(OneHotParam);
@@ -443,8 +483,7 @@ Examples::
 
 NNVM_REGISTER_OP(gather_nd)
 .describe(R"code(Gather elements or slices from `data` and store to a tensor whose
-shape is defined by `indices`. `gather_nd` and `scatter_nd` are inverse functions
-to each other.
+shape is defined by `indices`.
 
 Given `data` with shape `(X_0, X_1, ..., X_{N-1})` and indices with shape
 `(M, Y_0, ..., Y_{K-1})`, the output will have shape `(Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1})`,
@@ -476,13 +515,14 @@ Examples::
 .set_attr<nnvm::FGradient>("FGradient",
   [](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
     auto p = nnvm::Node::Create();
-    p->attrs.op = nnvm::Op::Get("scatter_nd");
+    p->attrs.op = nnvm::Op::Get("_backward_gather_nd");
     p->attrs.name = n->attrs.name + "_backward";
     p->inputs.push_back(ograds[0]);
     p->inputs.push_back(n->inputs[1]);
     p->control_deps.emplace_back(n);
     auto zero = MakeNode("zeros_like", n->attrs.name + "_backward_indices",
                          {n->inputs[1]}, nullptr, &n);
+
     std::vector<nnvm::NodeEntry> ret;
     ret.emplace_back(nnvm::NodeEntry{p, 0, 0});
     ret.emplace_back(nnvm::NodeEntry{zero, 0, 0});
@@ -492,10 +532,8 @@ Examples::
 .add_argument("data", "NDArray-or-Symbol", "data")
 .add_argument("indices", "NDArray-or-Symbol", "indices");
 
-
 NNVM_REGISTER_OP(scatter_nd)
 .describe(R"code(Scatters data into a new tensor according to indices.
-`gather_nd` and `scatter_nd` are inverse functions to each other.
 
 Given `data` with shape `(Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1})` and indices with shape
 `(M, Y_0, ..., Y_{K-1})`, the output will have shape `(X_0, X_1, ..., X_{N-1})`,
@@ -510,6 +548,12 @@ The elements in output is defined as follows::
 
 all other entries in output are 0.
 
+.. warning::
+
+    If the indices have duplicates, the result will be non-deterministic and
+    the gradient of `scatter_nd` will not be correct!!
+
+
 Examples::
 
   data = [2, 3, 0]
@@ -548,11 +592,73 @@ Examples::
 .add_argument("indices", "NDArray-or-Symbol", "indices")
 .add_arguments(ScatterNDParam::__FIELDS__());
 
+NNVM_REGISTER_OP(_backward_gather_nd)
+.describe(R"code(Accumulates data according to indices and returns the result. It's the backward of
+`gather_nd`.
+
+Given `data` with shape `(Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1})` and indices with shape
+`(M, Y_0, ..., Y_{K-1})`, the output will have shape `(X_0, X_1, ..., X_{N-1})`,
+where `M <= N`. If `M == N`, data shape should simply be `(Y_0, ..., Y_{K-1})`.
+
+The elements in output are defined as follows::
+
+  output[indices[0, y_0, ..., y_{K-1}],
+         ...,
+         indices[M-1, y_0, ..., y_{K-1}],
+         x_M, ..., x_{N-1}] += data[y_0, ..., y_{K-1}, x_M, ..., x_{N-1}]
+
+all other entries in output are 0 or the original value if AddTo is triggered.
+
+Examples::
+
+  data = [2, 3, 0]
+  indices = [[1, 1, 0], [0, 1, 0]]
+  shape = (2, 2)
+  _backward_gather_nd(data, indices, shape) = [[0, 0], [2, 3]] # Same as scatter_nd
+
+  # The difference between scatter_nd and _backward_gather_nd is the latter will accumulate
+  #  the values that point to the same index.
+
+  data = [2, 3, 0]
+  indices = [[1, 1, 0], [1, 1, 0]]
+  shape = (2, 2)
+  _backward_gather_nd(data, indices, shape) = [[0, 0], [0, 5]]
+
+)code")
+.set_num_outputs(1)
+.set_num_inputs(2)
+.set_attr_parser(ParamParser<ScatterNDParam>)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const NodeAttrs& attrs) {
+    return std::vector<std::string>{"data", "indices"};
+  })
+.set_attr<nnvm::FInferShape>("FInferShape", ScatterNDShape)
+.set_attr<nnvm::FInferType>("FInferType", ScatterNDType)
+.set_attr<FCompute>("FCompute<cpu>", GatherNDBackward<cpu>)
+.set_attr<nnvm::FGradient>("FGradient",
+  [](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
+    auto p = nnvm::Node::Create();
+    p->attrs.op = nnvm::Op::Get("gather_nd");  // gradient w.r.t. data is a gather_nd node
+    p->attrs.name = n->attrs.name + "_backward";
+    p->inputs.push_back(ograds[0]);  // gather from the incoming output gradient ...
+    p->inputs.push_back(n->inputs[1]);  // ... using this node's own indices input
+    p->control_deps.emplace_back(n);
+    auto zero = MakeNode("zeros_like", n->attrs.name + "_backward_indices",
+                         {n->inputs[1]}, nullptr, &n);  // zeros_like(indices)
+    std::vector<nnvm::NodeEntry> ret;
+    ret.emplace_back(nnvm::NodeEntry{p, 0, 0});  // entry for input 0 (data)
+    ret.emplace_back(nnvm::NodeEntry{zero, 0, 0});  // entry for input 1 (indices)
+    return ret;
+  })
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.add_argument("data", "NDArray-or-Symbol", "data")
+.add_argument("indices", "NDArray-or-Symbol", "indices")
+.add_arguments(ScatterNDParam::__FIELDS__());
+
 NNVM_REGISTER_OP(_scatter_set_nd)
 .describe(R"code(This operator has the same functionality as scatter_nd
 except that it does not reset the elements not indexed by the input
 index `NDArray` in the input data `NDArray`.
-
 .. note:: This operator is for internal use only.
 
 Examples::
diff --git a/src/operator/tensor/indexing_op.cu b/src/operator/tensor/indexing_op.cu
index 4021f2b..762d8fd 100644
--- a/src/operator/tensor/indexing_op.cu
+++ b/src/operator/tensor/indexing_op.cu
@@ -179,6 +179,32 @@ inline void SparseEmbeddingOpBackwardRspImpl<gpu>(const OpContext& ctx,
   });
 }
 
+struct backward_gather_nd_gpu {  // kernel functor handed to mxnet_op::Kernel<..., gpu>::Launch
+  template<typename DType, typename IType>
+  MSHADOW_XINLINE static void Map(int i, int N, int M, int K,
+                                  const mshadow::Shape<10> strides,
+                                  DType* out, const DType* data,
+                                  const IType* indices) {
+    int offset = 0;  // flat position in out of the slice addressed by index tuple i
+    for (int j = 0; j < M; ++j) {
+      offset += strides[j] * static_cast<int>(indices[j*N + i]);  // component j of tuple i
+    }
+    for (int j = 0; j < K; ++j) {  // add the K-element slice data[i*K ..] into out
+      atomicAdd(out + (offset + j), data[i * K + j]);  // atomic accumulation into out
+    }
+  }
+};
+
+template<typename DType, typename IType>  // GPU overload, chosen by the Stream<gpu> argument
+inline void GatherNDBackwardImpl(int N, int M, int K,
+                                 const mshadow::Shape<10> strides,
+                                 DType* out,
+                                 const DType* data,
+                                 const IType* indices,
+                                 mshadow::Stream<gpu> *s) {
+  mxnet_op::Kernel<backward_gather_nd_gpu, gpu>::Launch(s, N, N, M, K, strides, out, data, indices);  // NOLINT presumably first N = work-item count, second N = Map's N; confirm against Launch
+}
+
 NNVM_REGISTER_OP(Embedding)
 .set_attr<FCompute>("FCompute<gpu>", EmbeddingOpForward<gpu>);
 
@@ -209,6 +235,9 @@ NNVM_REGISTER_OP(gather_nd)
 NNVM_REGISTER_OP(scatter_nd)
 .set_attr<FCompute>("FCompute<gpu>", ScatterNDForward<gpu>);
 
+NNVM_REGISTER_OP(_backward_gather_nd)
+.set_attr<FCompute>("FCompute<gpu>", GatherNDBackward<gpu>);
+
 NNVM_REGISTER_OP(_scatter_set_nd)
 .set_attr<FCompute>("FCompute<gpu>", ScatterSetNDForward<gpu>);
 }  // namespace op
diff --git a/src/operator/tensor/indexing_op.h b/src/operator/tensor/indexing_op.h
index 4043e76..7323f81 100644
--- a/src/operator/tensor/indexing_op.h
+++ b/src/operator/tensor/indexing_op.h
@@ -1131,10 +1131,10 @@ void ScatterNDForward(const nnvm::NodeAttrs& attrs,
   int K = oshape.ProdShape(M, oshape.ndim());
   mshadow::Shape<10> strides;
   for (int i = M-1, stride = K; i >= 0; stride *= oshape[i], --i) strides[i] = stride;
+  if (kWriteTo == req[0]) {
+    Fill<true>(s, outputs[0], req[0], 0);
+  }
   MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {  // output data type switch
-    if (kWriteTo == req[0]) {
-      Fill<true>(s, outputs[0], req[0], 0);
-    }
     MSHADOW_TYPE_SWITCH(inputs[1].type_flag_, IType, {  // indices data type switch
       mxnet_op::Kernel<scatter_nd, xpu>::Launch(
         s, N, req[0], N, M, K, strides, outputs[0].dptr<DType>(),
@@ -1143,6 +1143,64 @@ void ScatterNDForward(const nnvm::NodeAttrs& attrs,
   });
 }
 
+template<typename DType, typename IType>  // declaration: CPU overload for non-half_t DType
+inline typename std::enable_if<(!std::is_same<DType, mshadow::half::half_t>::value), void>::type
+GatherNDBackwardImpl(int N, int M, int K,
+                     const mshadow::Shape<10> strides,
+                     DType* out,
+                     const DType* data,
+                     const IType* indices,
+                     mshadow::Stream<cpu> *s);
+
+template<typename DType, typename IType>  // declaration: CPU overload for half_t DType
+inline typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value, void>::type
+GatherNDBackwardImpl(int N, int M, int K,
+                     const mshadow::Shape<10> strides,
+                     DType* out,
+                     const DType* data,
+                     const IType* indices,
+                     mshadow::Stream<cpu> *s);
+
+template<typename DType, typename IType>  // declaration: GPU overload (takes Stream<gpu>)
+inline void GatherNDBackwardImpl(int N, int M, int K,
+                                 const mshadow::Shape<10> strides,
+                                 DType* out,
+                                 const DType* data,
+                                 const IType* indices,
+                                 mshadow::Stream<gpu> *s);
+
+template<typename xpu>  // compute function registered for _backward_gather_nd on cpu and gpu
+void GatherNDBackward(const nnvm::NodeAttrs& attrs,
+                      const OpContext& ctx,
+                      const std::vector<TBlob>& inputs,
+                      const std::vector<OpReqType>& req,
+                      const std::vector<TBlob>& outputs) {
+  using namespace mshadow;
+  CHECK_EQ(inputs.size(), 2U);  // inputs[0] = data, inputs[1] = indices
+  CHECK_EQ(outputs.size(), 1U);
+  if (req[0] == kNullOp) return;  // nothing requested for the output
+  mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
+  const TShape& oshape = outputs[0].shape_;
+  const TShape& ishape = inputs[1].shape_;
+  int M = ishape[0];  // leading dimension of indices: components per index tuple
+  int N = ishape.Size() / M;  // number of index tuples
+  int K = oshape.ProdShape(M, oshape.ndim());  // presumably the product of output dims M..ndim-1
+  mshadow::Shape<10> strides;  // only entries 0..M-1 are filled below
+  for (int i = M-1, stride = K; i >= 0; stride *= oshape[i], --i) strides[i] = stride;
+  if (kWriteTo == req[0]) {  // zero the output first only when overwriting; other req keep it
+    Fill<true>(s, outputs[0], req[0], 0);
+  }
+  MXNET_NO_INT8_TYPE_SWITCH(inputs[0].type_flag_, DType, {  // output data type switch
+    MSHADOW_TYPE_SWITCH(inputs[1].type_flag_, IType, {  // indices data type switch
+      GatherNDBackwardImpl(N, M, K, strides,
+                           outputs[0].dptr<DType>(),
+                           inputs[0].dptr<DType>(),
+                           inputs[1].dptr<IType>(),
+                           s);  // overload is picked by DType and by the stream's device type
+    });
+  });
+}
+
 /*!
  * This is for internal use only.
  * DO NOT call this function unless you have to.
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 3fbf98b..56dc27c 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -4391,21 +4391,32 @@ def test_scatter_gather_nd():
         npdata = np.zeros_like(data.asnumpy())
         npdata[npidx] = y.asnumpy()
         assert (npdata == data.grad.asnumpy()).all()
-        assert (mx.nd.scatter_nd(y, idx, shape=data.shape).asnumpy() == data.grad.asnumpy()).all()
-
-    data = mx.nd.arange(360, dtype='int32').reshape((3,4,5,6))
-    idx = mx.nd.array([[1,1,2], [3, 3, 0], [3,2,1]], dtype='int32')
-
-    check(data, idx)
-
-    idx = mx.nd.array([[1,1,2], [3,3,0], [3,2,1], [5,2,4]], dtype='int32')
-
-    check(data, idx)
-
-    data = mx.nd.array([2, 3, 0])
-    idx = mx.nd.array([[1, 1, 0], [0, 1, 0]])
-
-    assert (mx.nd.scatter_nd(data, idx, shape=(2, 2)).asnumpy() == [[0, 0], [2, 3]]).all()
+        assert (mx.nd._internal._backward_gather_nd(y, idx, shape=data.shape).asnumpy() == data.grad.asnumpy()).all()
+    for dtype in ['int32', 'int64', 'float16', 'float32', 'float64']:
+        data = mx.nd.arange(360, dtype=dtype).reshape((3,4,5,6))
+        idx = mx.nd.array([[1,1,2], [3, 3, 0], [3,2,1]], dtype='int32')
+        check(data, idx)
+
+        idx = mx.nd.array([[1,1,2], [3,3,0], [3,2,1], [5,2,4]], dtype='int32')
+
+        check(data, idx)
+
+        data = mx.nd.array([2, 3, 0], dtype=dtype)
+        idx = mx.nd.array([[1, 1, 0], [0, 1, 0]], dtype='int32')
+        assert (mx.nd.scatter_nd(data, idx, shape=(2, 2)).asnumpy() == [[0, 0], [2, 3]]).all()
+
+        data = mx.nd.array([2, 3, 0], dtype=dtype)
+        idx = mx.nd.array([[1, 1, 0], [1, 1, 0]], dtype='int32')
+        assert (mx.nd._internal._backward_gather_nd(data, idx, shape=(2, 2)).asnumpy() == [[0, 0], [0, 5]]).all()
+        data_npy = np.random.randint(0, 10, (100,))
+        data = mx.nd.array(data_npy, dtype=dtype)
+        idx = mx.nd.zeros(shape=(1, 100), dtype='int32')
+        assert (mx.nd._internal._backward_gather_nd(data, idx, shape=(1,)).asscalar() == data_npy.sum())
+        if dtype == 'int64':
+            data = mx.nd.array([2123162361283621, -31231236374787,
+                                -112372937128970, -1378278798172378], dtype=dtype)
+            idx = mx.nd.array([[0, 0, 0, 0]], dtype='int32')
+            assert (mx.nd._internal._backward_gather_nd(data, idx, shape=(1,)).asscalar() == data.asnumpy().sum())
 
 def compare_forw_backw_unary_op(
         name, forward_mxnet_call, forward_numpy_call,

-- 
To stop receiving notification emails like this one, please contact
['"commits@mxnet.apache.org" <co...@mxnet.apache.org>'].