Posted to commits@mxnet.apache.org by ha...@apache.org on 2019/10/12 23:10:01 UTC

[incubator-mxnet] branch numpy_pr_merge updated (753fbf1 -> 08a4db7)

This is an automated email from the ASF dual-hosted git repository.

haoj pushed a change to branch numpy_pr_merge
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git.


 discard 753fbf1  numpy eye op (#16132)
    omit c13806b  [Numpy] Numpy compatible dstack (#15871)
     add ca30ba8  Pseudo 2D transpose kernel (#16229)
     add d2d76dc  increase docker cache timeout (#16430)
     add 4dee4ee  Fix mkldnn reshape (#16455)
     new 0fcd57f  [Numpy] Numpy compatible dstack (#15871)
     new 08a4db7  numpy eye op (#16132)

This update added new revisions after undoing existing revisions.
That is to say, some revisions that were in the old version of the
branch are not in the new version.  This situation occurs
when a user force-pushes a change, rewriting history so that the
repository contains something like this:

 * -- * -- B -- O -- O -- O   (753fbf1)
            \
             N -- N -- N   refs/heads/numpy_pr_merge (08a4db7)

You should already have received notification emails for all of the O
revisions, and so the following emails describe only the N revisions
from the common base, B.

Any revisions marked "omit" are not gone; other references still
refer to them.  Any revisions marked "discard" are gone forever.

The 2 revisions listed above as "new" are entirely new to this
repository and will be described in separate emails.  The revisions
listed as "add" were already present in the repository and have only
been added to this reference.


Summary of changes:
 ci/docker_cache.py                                 |   2 +-
 src/operator/nn/mkldnn/mkldnn_base-inl.h           |   1 -
 src/operator/nn/mkldnn/mkldnn_base.cc              |   6 +-
 src/operator/nn/mkldnn/mkldnn_flatten-inl.h        |  48 ---
 src/operator/nn/mkldnn/mkldnn_flatten.cc           |  79 -----
 src/operator/nn/mkldnn/mkldnn_ops-inl.h            |   5 -
 src/operator/nn/mkldnn/mkldnn_reshape-inl.h        |  27 +-
 src/operator/nn/mkldnn/mkldnn_reshape.cc           | 118 +++----
 .../mkldnn/mkldnn_quantized_flatten.cc             |   4 +-
 src/operator/tensor/matrix_op-inl.h                |  16 +
 src/operator/tensor/matrix_op.cc                   |  42 +--
 src/operator/tensor/pseudo2DTranspose_op-inl.cuh   | 348 +++++++++++++++++++++
 tests/python/unittest/test_operator.py             |  39 +++
 13 files changed, 473 insertions(+), 262 deletions(-)
 delete mode 100644 src/operator/nn/mkldnn/mkldnn_flatten-inl.h
 delete mode 100644 src/operator/nn/mkldnn/mkldnn_flatten.cc
 create mode 100644 src/operator/tensor/pseudo2DTranspose_op-inl.cuh


[incubator-mxnet] 02/02: numpy eye op (#16132)

Posted by ha...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

haoj pushed a commit to branch numpy_pr_merge
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git

commit 08a4db767b67b6b03965d76514fff7638338cc0e
Author: Jake Lee <gs...@gmail.com>
AuthorDate: Fri Oct 11 21:58:42 2019 -0700

    numpy eye op (#16132)
---
 python/mxnet/ndarray/numpy/_op.py      | 35 ++++++++++++++++-
 python/mxnet/numpy/multiarray.py       | 29 +++++++++++++-
 python/mxnet/symbol/numpy/_symbol.py   | 35 ++++++++++++++++-
 src/operator/numpy/np_init_op.cc       | 30 ++++++--------
 src/operator/numpy/np_init_op.cu       |  3 ++
 src/operator/numpy/np_init_op.h        | 72 ++++++++++++++++++++++++++++++++++
 src/operator/tensor/init_op.h          | 40 +++++++++++--------
 tests/python/unittest/test_numpy_op.py | 68 ++++++++++++++++++++++++++++++++
 8 files changed, 274 insertions(+), 38 deletions(-)

diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py
index 2846d2b..0bf6232 100644
--- a/python/mxnet/ndarray/numpy/_op.py
+++ b/python/mxnet/ndarray/numpy/_op.py
@@ -22,7 +22,7 @@
 from __future__ import absolute_import
 import numpy as _np
 from ...base import numeric_types
-from ...util import set_module
+from ...util import _sanity_check_params, set_module
 from ...context import current_context
 from . import _internal as _npi
 from ..ndarray import NDArray
@@ -31,7 +31,7 @@ __all__ = ['zeros', 'ones', 'full', 'add', 'subtract', 'multiply', 'divide', 'mo
            'arctan2', 'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'log10', 'sqrt', 'cbrt', 'abs',
            'absolute', 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2',
            'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor',
-           'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot', 'histogram',
+           'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot', 'histogram', 'eye',
            'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'vstack', 'dstack',
            'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign',
            'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'hypot', 'rad2deg', 'deg2rad',
@@ -789,6 +789,37 @@ def histogram(a, bins=10, range=None, normed=None, weights=None, density=None):
 
 
 @set_module('mxnet.ndarray.numpy')
+def eye(N, M=None, k=0, dtype=_np.float32, **kwargs):
+    """
+    Return a 2-D array with ones on the diagonal and zeros elsewhere.
+
+    Parameters
+    ----------
+    N : int
+        Number of rows in the output.
+    M : int, optional
+        Number of columns in the output. If None, defaults to N.
+    k : int, optional
+        Index of the diagonal: 0 (the default) refers to the main diagonal,
+        a positive value refers to an upper diagonal,
+        and a negative value to a lower diagonal.
+    dtype : data-type, optional
+        Data-type of the returned array.
+
+    Returns
+    -------
+    I : ndarray of shape (N,M)
+        An array where all elements are equal to zero,
+        except for the k-th diagonal, whose values are equal to one.
+    """
+    _sanity_check_params('eye', ['order'], kwargs)
+    ctx = kwargs.pop('ctx', current_context())
+    if ctx is None:
+        ctx = current_context()
+    return _npi.eye(N, M, k, ctx, dtype)
+
+
+@set_module('mxnet.ndarray.numpy')
 def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0, ctx=None):  # pylint: disable=too-many-arguments
     r"""
     Return evenly spaced numbers over a specified interval.
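
A note for readers of the hunk above: mxnet.numpy.eye is documented to match
NumPy's eye semantics, with an extra optional ctx keyword handled through
kwargs. A minimal sketch of the expected behavior, shown against NumPy itself
(illustrative only, not part of the change):

    import numpy as np  # reference semantics that mxnet.numpy.eye aims to match

    # 3x3 identity: ones on the main diagonal (k=0)
    np.eye(3)
    # array([[1., 0., 0.],
    #        [0., 1., 0.],
    #        [0., 0., 1.]])

    # Rectangular output with the diagonal shifted up one column (k=1)
    np.eye(2, 4, k=1)
    # array([[0., 1., 0., 0.],
    #        [0., 0., 1., 0.]])
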
diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py
index 00a7709..76df87c 100644
--- a/python/mxnet/numpy/multiarray.py
+++ b/python/mxnet/numpy/multiarray.py
@@ -50,7 +50,7 @@ __all__ = ['ndarray', 'empty', 'array', 'zeros', 'ones', 'full', 'add', 'subtrac
            'sqrt', 'cbrt', 'abs', 'absolute', 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log',
            'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative',
            'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh',
-           'tensordot', 'histogram', 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate',
+           'tensordot', 'histogram', 'eye', 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate',
            'stack', 'vstack', 'dstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var',
            'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'arctan2', 'hypot',
            'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer',
@@ -3637,6 +3637,33 @@ def histogram(a, bins=10, range=None, normed=None, weights=None, density=None):
 
 
 @set_module('mxnet.numpy')
+def eye(N, M=None, k=0, dtype=_np.float32, **kwargs):
+    """
+    Return a 2-D array with ones on the diagonal and zeros elsewhere.
+
+    Parameters
+    ----------
+    N : int
+        Number of rows in the output.
+    M : int, optional
+        Number of columns in the output. If None, defaults to N.
+    k : int, optional
+        Index of the diagonal: 0 (the default) refers to the main diagonal,
+        a positive value refers to an upper diagonal,
+        and a negative value to a lower diagonal.
+    dtype : data-type, optional
+        Data-type of the returned array.
+
+    Returns
+    -------
+    I : ndarray of shape (N,M)
+        An array where all elements are equal to zero,
+        except for the k-th diagonal, whose values are equal to one.
+    """
+    return _mx_nd_np.eye(N, M, k, dtype, **kwargs)
+
+
+@set_module('mxnet.numpy')
 def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0, ctx=None):  # pylint: disable=too-many-arguments
     r"""
     Return evenly spaced numbers over a specified interval.
diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py
index de11cfb..cbd46f3 100644
--- a/python/mxnet/symbol/numpy/_symbol.py
+++ b/python/mxnet/symbol/numpy/_symbol.py
@@ -23,7 +23,7 @@ import ctypes
 import numpy as _np
 from . import _op as _mx_np_op
 from ...base import _LIB, SymbolHandle, numeric_types, mx_uint
-from ...util import check_call, set_module
+from ...util import check_call, _sanity_check_params, set_module
 from ...context import current_context
 from ..symbol import Symbol
 from .._internal import _set_np_symbol_class
@@ -33,7 +33,7 @@ __all__ = ['zeros', 'ones', 'add', 'subtract', 'multiply', 'divide', 'mod', 'rem
            'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'log10', 'sqrt', 'cbrt', 'abs', 'absolute', 'exp',
            'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2', 'log1p',
            'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor',
-           'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot', 'histogram',
+           'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot', 'histogram', 'eye',
            'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'vstack', 'dstack',
            'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign',
            'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'hypot', 'rad2deg', 'deg2rad',
@@ -1278,6 +1278,37 @@ def histogram(a, bins=10, range=None, normed=None, weights=None, density=None):
 
 
 @set_module('mxnet.symbol.numpy')
+def eye(N, M=None, k=0, dtype=_np.float32, **kwargs):
+    """
+    Return a 2-D array with ones on the diagonal and zeros elsewhere.
+
+    Parameters
+    ----------
+    N : int
+        Number of rows in the output.
+    M : int, optional
+        Number of columns in the output. If None, defaults to N.
+    k : int, optional
+        Index of the diagonal: 0 (the default) refers to the main diagonal,
+        a positive value refers to an upper diagonal,
+        and a negative value to a lower diagonal.
+    dtype : data-type, optional
+        Data-type of the returned array.
+
+    Returns
+    -------
+    I : ndarray of shape (N,M)
+        An array where all elements are equal to zero,
+        except for the k-th diagonal, whose values are equal to one.
+    """
+    _sanity_check_params('eye', ['order'], kwargs)
+    ctx = kwargs.pop('ctx', current_context())
+    if ctx is None:
+        ctx = current_context()
+    return _npi.eye(N, M, k, ctx, dtype)
+
+
+@set_module('mxnet.symbol.numpy')
 def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0, ctx=None): # pylint: disable=too-many-arguments
     r"""
     Return evenly spaced numbers over a specified interval.
diff --git a/src/operator/numpy/np_init_op.cc b/src/operator/numpy/np_init_op.cc
index 2477573..7e65d6c 100644
--- a/src/operator/numpy/np_init_op.cc
+++ b/src/operator/numpy/np_init_op.cc
@@ -22,6 +22,7 @@
  * \file np_init_op.cc
  * \brief CPU Implementation of numpy init op
  */
+
 #include "../tensor/init_op.h"
 #include "../tensor/elemwise_unary_op.h"
 #include "./np_init_op.h"
@@ -29,6 +30,8 @@
 namespace mxnet {
 namespace op {
 
+
+DMLC_REGISTER_PARAMETER(NumpyEyeParam);
 DMLC_REGISTER_PARAMETER(IndicesOpParam);
 
 inline bool NumpyIndicesShape(const nnvm::NodeAttrs& attrs,
@@ -117,23 +120,6 @@ NNVM_REGISTER_OP(_np_ones_like)
 .add_argument("a", "NDArray-or-Symbol",
               "The shape and data-type of a define these same attributes of the returned array.");
 
-bool NumpyRangeShape(const nnvm::NodeAttrs& attrs,
-                     mxnet::ShapeVector* in_shapes,
-                     mxnet::ShapeVector* out_shapes) {
-  const RangeParam& param = nnvm::get<RangeParam>(attrs.parsed);
-  CHECK_EQ(in_shapes->size(), 0U);
-  CHECK_EQ(out_shapes->size(), 1U);
-  CHECK_NE(param.step, 0) << "_npi_arange does not support step=0";
-  CHECK_EQ(param.repeat, 1) << "_npi_arange only supports repeat=1, received " << param.repeat;
-  CHECK(param.stop.has_value()) << "_npi_arange requires stop to have a value";
-  double out_size = std::ceil((param.stop.value() - param.start) / param.step);
-  if (out_size < 0) {
-    out_size = 0;
-  }
-  SHAPE_ASSIGN_CHECK(*out_shapes, 0, mxnet::TShape({static_cast<nnvm::dim_t>(out_size)}));
-  return true;
-}
-
 NNVM_REGISTER_OP(_npi_arange)
 .set_num_inputs(0)
 .set_num_outputs(1)
@@ -143,6 +129,16 @@ NNVM_REGISTER_OP(_npi_arange)
 .set_attr<FCompute>("FCompute<cpu>", RangeCompute<cpu, RangeParam>)
 .add_arguments(RangeParam::__FIELDS__());
 
+NNVM_REGISTER_OP(_npi_eye)
+.describe("Return a 2-D array with ones on the diagonal and zeros elsewhere.")
+.set_num_inputs(0)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<NumpyEyeParam>)
+.set_attr<mxnet::FInferShape>("FInferShape", NumpyEyeShape)
+.set_attr<nnvm::FInferType>("FInferType", InitType<NumpyEyeParam>)
+.set_attr<FCompute>("FCompute<cpu>", NumpyEyeFill<cpu>)
+.add_arguments(NumpyEyeParam::__FIELDS__());
+
 NNVM_REGISTER_OP(_npi_indices)
 .describe("Return an array representing the indices of a grid.")
 .set_num_inputs(0)
diff --git a/src/operator/numpy/np_init_op.cu b/src/operator/numpy/np_init_op.cu
index e68dd9a..9a2b236 100644
--- a/src/operator/numpy/np_init_op.cu
+++ b/src/operator/numpy/np_init_op.cu
@@ -47,6 +47,9 @@ NNVM_REGISTER_OP(_np_ones_like)
 NNVM_REGISTER_OP(_npi_arange)
 .set_attr<FCompute>("FCompute<gpu>", RangeCompute<gpu, RangeParam>);
 
+NNVM_REGISTER_OP(_npi_eye)
+.set_attr<FCompute>("FCompute<gpu>", NumpyEyeFill<gpu>);
+
 NNVM_REGISTER_OP(_npi_indices)
 .set_attr<FCompute>("FCompute<gpu>", IndicesCompute<gpu>);
 
diff --git a/src/operator/numpy/np_init_op.h b/src/operator/numpy/np_init_op.h
index 3e1c345..9eb83e8 100644
--- a/src/operator/numpy/np_init_op.h
+++ b/src/operator/numpy/np_init_op.h
@@ -35,6 +35,34 @@
 namespace mxnet {
 namespace op {
 
+struct NumpyEyeParam : public dmlc::Parameter<NumpyEyeParam> {
+  nnvm::dim_t N;
+  dmlc::optional<nnvm::dim_t> M;
+  nnvm::dim_t k;
+  std::string ctx;
+  int dtype;
+  DMLC_DECLARE_PARAMETER(NumpyEyeParam) {
+    DMLC_DECLARE_FIELD(N)
+    .describe("Number of rows in the output.");
+    DMLC_DECLARE_FIELD(M)
+    .set_default(dmlc::optional<nnvm::dim_t>())
+    .describe("Number of columns in the output. If None, defaults to N.");
+    DMLC_DECLARE_FIELD(k)
+    .set_default(0)
+    .describe("Index of the diagonal. 0 (the default) refers to the main diagonal,"
+              "a positive value refers to an upper diagonal."
+              "and a negative value to a lower diagonal.");
+    DMLC_DECLARE_FIELD(ctx)
+    .set_default("")
+    .describe("Context of output, in format [cpu|gpu|cpu_pinned](n)."
+              "Only used for imperative calls.");
+    DMLC_DECLARE_FIELD(dtype)
+    .set_default(mshadow::kFloat32)
+    MXNET_ADD_ALL_TYPES
+    .describe("Data-type of the returned array.");
+  }
+};
+
 struct IndicesOpParam : public dmlc::Parameter<IndicesOpParam> {
   mxnet::TShape dimensions;
   int dtype;
@@ -52,6 +80,50 @@ struct IndicesOpParam : public dmlc::Parameter<IndicesOpParam> {
   }
 };
 
+inline bool NumpyRangeShape(const nnvm::NodeAttrs& attrs,
+                            mxnet::ShapeVector* in_shapes,
+                            mxnet::ShapeVector* out_shapes) {
+  const RangeParam& param = nnvm::get<RangeParam>(attrs.parsed);
+  CHECK_EQ(in_shapes->size(), 0U);
+  CHECK_EQ(out_shapes->size(), 1U);
+  CHECK_NE(param.step, 0) << "_npi_arange does not support step=0";
+  CHECK_EQ(param.repeat, 1) << "_npi_arange only supports repeat=1, received " << param.repeat;
+  CHECK(param.stop.has_value()) << "_npi_arange requires stop to have a value";
+  double out_size = std::ceil((param.stop.value() - param.start) / param.step);
+  if (out_size < 0) {
+    out_size = 0;
+  }
+  SHAPE_ASSIGN_CHECK(*out_shapes, 0, mxnet::TShape({static_cast<nnvm::dim_t>(out_size)}));
+  return true;
+}
+
+inline bool NumpyEyeShape(const nnvm::NodeAttrs& attrs,
+                          mxnet::ShapeVector *in_attrs,
+                          mxnet::ShapeVector *out_attrs) {
+  const NumpyEyeParam& param = nnvm::get<NumpyEyeParam>(attrs.parsed);
+  CHECK_EQ(in_attrs->size(), 0U);
+  CHECK_EQ(out_attrs->size(), 1U);
+  nnvm::dim_t M = param.M.has_value() ? param.M.value() : param.N;
+  CHECK(param.N >= 0) << "negative dimensions are not allowed. N is " << param.N;
+  CHECK(M >= 0) << "negative dimensions are not allowed. M is " << M;
+  SHAPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::Shape2(param.N, M));
+
+  return out_attrs->at(0).ndim() != 0U;
+}
+template<typename xpu>
+void NumpyEyeFill(const nnvm::NodeAttrs& attrs,
+                  const OpContext& ctx,
+                  const std::vector<TBlob>& inputs,
+                  const std::vector<OpReqType>& req,
+                  const std::vector<TBlob>& outputs) {
+  CHECK_EQ(inputs.size(), 0U);
+  CHECK_EQ(outputs.size(), 1U);
+  if (outputs[0].shape_.Size() == 0) return;  // zero-size tensor
+  const NumpyEyeParam& param = nnvm::get<NumpyEyeParam>(attrs.parsed);
+  const nnvm::dim_t num_cols = param.M.has_value() ? param.M.value() : param.N;
+  EyeFillImpl<xpu>(outputs[0], ctx, req, num_cols, param.N, param.k);
+}
+
 template<int req>
 struct indices_fwd {
   template<typename DType>
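
NumpyEyeShape above is the entire shape contract for the new op: the output is
always (N, M or N), negative dimensions are rejected, and zero-size results
are legal. A quick NumPy check of those expectations (a sketch, not the MXNet
test suite):

    import numpy as np

    assert np.eye(5).shape == (5, 5)     # M defaults to N
    assert np.eye(4, 3).shape == (4, 3)  # explicit M
    assert np.eye(0, 3).shape == (0, 3)  # zero-size outputs are allowed
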
diff --git a/src/operator/tensor/init_op.h b/src/operator/tensor/init_op.h
index 8e8896e..d2107a1 100644
--- a/src/operator/tensor/init_op.h
+++ b/src/operator/tensor/init_op.h
@@ -483,6 +483,29 @@ void FillComputeZerosEx(const nnvm::NodeAttrs& attrs,
   }
 }
 
+template<typename xpu>
+inline void EyeFillImpl(const TBlob& out_data,
+                        const OpContext& ctx,
+                        const std::vector<OpReqType>& req,
+                        const nnvm::dim_t num_cols,
+                        const nnvm::dim_t N,
+                        const nnvm::dim_t k) {
+  using namespace mxnet_op;
+  const nnvm::dim_t cnnz = std::max(num_cols - std::abs(k), (nnvm::dim_t)0);
+  const nnvm::dim_t rnnz = std::max(N - std::abs(k), (nnvm::dim_t)0);
+  const nnvm::dim_t nnz = k > 0 ? std::min(cnnz, N) :
+                          std::min(rnnz, num_cols);
+  mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
+  MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, {
+      MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
+        Fill(s, out_data, req[0], static_cast<DType>(0));
+        if (nnz > 0) {
+          Kernel<eye_dns_fill<req_type>, xpu>::Launch(s, nnz, out_data.dptr<DType>(),
+            std::max(static_cast<nnvm::dim_t>(0), k), k, num_cols);
+        }
+      });
+  });
+}
 
 template<typename xpu>
 void EyeFill(const nnvm::NodeAttrs& attrs,
@@ -493,25 +516,10 @@ void EyeFill(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(inputs.size(), 0U);
   CHECK_EQ(outputs.size(), 1U);
   CHECK_EQ(req.size(), 1U);
-  mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
   const EyeParam& param = nnvm::get<EyeParam>(attrs.parsed);
   const TBlob& out_data = outputs[0];
   const nnvm::dim_t num_cols = param.M > 0 ? param.M : param.N;
-
-  const nnvm::dim_t cnnz = std::max(num_cols - std::abs(param.k), (nnvm::dim_t)0);
-  const nnvm::dim_t rnnz = std::max(param.N - std::abs(param.k), (nnvm::dim_t)0);
-  const nnvm::dim_t nnz = param.k > 0 ? std::min(cnnz, param.N) :
-                                        std::min(rnnz, num_cols);
-  using namespace mxnet_op;
-  MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, {
-    MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
-      Fill(s, out_data, req[0], static_cast<DType>(0));
-      if (nnz > 0) {
-        Kernel<eye_dns_fill<req_type>, xpu>::Launch(s, nnz, out_data.dptr<DType>(),
-          std::max(static_cast<nnvm::dim_t>(0), param.k), param.k, num_cols);
-      }
-    });
-  });
+  EyeFillImpl<xpu>(out_data, ctx, req, num_cols, param.N, param.k);
 }
 
 
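The refactoring above extracts EyeFillImpl so the classic eye op and the new
numpy eye op share one fill path. Its core is counting how many ones actually
fit on the k-th diagonal of an N x num_cols matrix; when that count is zero
the kernel launch is skipped and the output stays all zeros. A minimal Python
sketch of the count (function name is mine, for illustration only):

    def eye_nnz(N, num_cols, k):
        """Count the ones on the k-th diagonal of an N x num_cols matrix."""
        cnnz = max(num_cols - abs(k), 0)  # columns remaining after a |k| shift
        rnnz = max(N - abs(k), 0)         # rows remaining after a |k| shift
        # An upper diagonal (k > 0) is bounded by the row count,
        # a main or lower diagonal by the column count.
        return min(cnnz, N) if k > 0 else min(rnnz, num_cols)

    assert eye_nnz(3, 3, 0) == 3   # main diagonal of a square matrix
    assert eye_nnz(2, 4, 1) == 2   # only two ones fit in two rows
    assert eye_nnz(4, 3, -3) == 1  # a deep lower diagonal holds a single one
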
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index e9fa84e..0a0a1c8 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -2216,6 +2216,74 @@ def test_np_choice():
 
 @with_seed()
 @use_np
+def test_np_eye():
+    configs = [
+        4,
+        1000,
+        (4, 3),
+        (5, None),
+        (4, None, 1),
+        (2, 2, 1),
+        (4, 6, 1),
+        (7, 3, -3),
+        (3, 2, -2),
+        (4, 0),
+        (0, 0),
+        (0, 3),
+        (0, 0, -2)
+    ]
+    exception_configs = [
+        -1,
+        -1000,
+        (-2, None),
+        (1, -1)
+    ]
+    dtypes = ['int32', 'float16', 'float32', 'float64', None]
+    for config in configs:
+        for dtype in dtypes:
+            if isinstance(config, tuple):
+                mx_ret = np.eye(*config, dtype=dtype)
+                np_ret = _np.eye(*config, dtype=dtype)
+            else:
+                mx_ret = np.eye(config, dtype=dtype)
+                np_ret = _np.eye(config, dtype=dtype)
+            assert same(mx_ret.asnumpy(), np_ret)
+    # check for exception input
+    for config in exception_configs:
+        if isinstance(config, tuple):
+            assertRaises(MXNetError, np.eye, *config)
+        else:
+            assertRaises(MXNetError, np.eye, config)
+
+    class TestEye(HybridBlock):
+        def __init__(self, N, M=None, k=0, dtype=None):
+            super(TestEye, self).__init__()
+            self._N = N
+            self._M = M
+            self._k = k
+            self._dtype = dtype
+
+        def hybrid_forward(self, F, x):
+            return x + F.np.eye(self._N, self._M, self._k, dtype=self._dtype)
+
+    for dtype in dtypes:
+        x = np.zeros(shape=(), dtype=dtype)
+        for config in configs:
+            for hybridize in [False, True]:
+                if isinstance(config, tuple):
+                    net = TestEye(*config, dtype=dtype)
+                    np_out = _np.eye(*config, dtype=dtype)
+                else:
+                    net = TestEye(config, dtype=dtype)
+                    np_out = _np.eye(config, dtype=dtype)
+                if hybridize:
+                    net.hybridize()
+                mx_out = net(x)
+                assert same(mx_out.asnumpy(), np_out)
+
+
+@with_seed()
+@use_np
 def test_np_indices():
     dtypes = ['int32', 'int64', 'float16', 'float32', 'float64']
     shapes = [


[incubator-mxnet] 01/02: [Numpy] Numpy compatible dstack (#15871)

Posted by ha...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

haoj pushed a commit to branch numpy_pr_merge
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git

commit 0fcd57fb6af708707073d4cba2a939fabeac9966
Author: Mike <ma...@connect.hku.hk>
AuthorDate: Fri Oct 11 17:51:00 2019 -0400

    [Numpy] Numpy compatible dstack (#15871)
    
    * Add dstack that passes CPU test
    
    Register dstack on GPU
    
    Minor comment fix
    
    Minor syntax fix
    
    Syntax fix according to comments
    
    header fix
    
    * Fix sanity
---
 python/mxnet/ndarray/numpy/_op.py      |  46 ++++++++++++++-
 python/mxnet/numpy/multiarray.py       |  48 +++++++++++++++-
 python/mxnet/symbol/numpy/_symbol.py   |  33 ++++++++++-
 src/operator/nn/concat-inl.h           |  62 ++++++++++++++++++++
 src/operator/numpy/np_matrix_op.cc     | 100 ++++++++++++++++++++++++++++++++-
 src/operator/numpy/np_matrix_op.cu     |   7 +++
 tests/python/unittest/test_numpy_op.py |  61 ++++++++++++++++++++
 7 files changed, 350 insertions(+), 7 deletions(-)

diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py
index aea8b19..2846d2b 100644
--- a/python/mxnet/ndarray/numpy/_op.py
+++ b/python/mxnet/ndarray/numpy/_op.py
@@ -32,8 +32,8 @@ __all__ = ['zeros', 'ones', 'full', 'add', 'subtract', 'multiply', 'divide', 'mo
            'absolute', 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2',
            'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor',
            'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot', 'histogram',
-           'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'vstack', 'mean',
-           'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign',
+           'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'vstack', 'dstack',
+           'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign',
            'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'hypot', 'rad2deg', 'deg2rad',
            'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer',
            'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal']
@@ -2468,6 +2468,48 @@ def vstack(arrays, out=None):
 
 
 @set_module('mxnet.ndarray.numpy')
+def dstack(arrays):
+    """
+    Stack arrays in sequence depth wise (along third axis).
+    This is equivalent to concatenation along the third axis after 2-D arrays
+    of shape `(M,N)` have been reshaped to `(M,N,1)` and 1-D arrays of shape
+    `(N,)` have been reshaped to `(1,N,1)`. Rebuilds arrays divided by
+    `dsplit`.
+    This function makes most sense for arrays with up to 3 dimensions. For
+    instance, for pixel-data with a height (first axis), width (second axis),
+    and r/g/b channels (third axis). The functions `concatenate`, `stack` and
+    `block` provide more general stacking and concatenation operations.
+
+    Parameters
+    ----------
+    arrays : sequence of arrays
+        The arrays must have the same shape along all but the third axis.
+        1-D or 2-D arrays must have the same shape.
+
+    Returns
+    -------
+    stacked : ndarray
+        The array formed by stacking the given arrays, will be at least 3-D.
+
+    Examples
+    --------
+    >>> a = np.array((1,2,3))
+    >>> b = np.array((2,3,4))
+    >>> np.dstack((a,b))
+    array([[[1, 2],
+            [2, 3],
+            [3, 4]]])
+    >>> a = np.array([[1],[2],[3]])
+    >>> b = np.array([[2],[3],[4]])
+    >>> np.dstack((a,b))
+    array([[[1, 2]],
+           [[2, 3]],
+           [[3, 4]]])
+    """
+    return _npi.dstack(*arrays)
+
+
+@set_module('mxnet.ndarray.numpy')
 def maximum(x1, x2, out=None):
     """Returns element-wise maximum of the input arrays with broadcasting.
 
diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py
index d3ae4d1..00a7709 100644
--- a/python/mxnet/numpy/multiarray.py
+++ b/python/mxnet/numpy/multiarray.py
@@ -51,8 +51,8 @@ __all__ = ['ndarray', 'empty', 'array', 'zeros', 'ones', 'full', 'add', 'subtrac
            'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative',
            'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh',
            'tensordot', 'histogram', 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate',
-           'stack', 'vstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices',
-           'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'arctan2', 'hypot',
+           'stack', 'vstack', 'dstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var',
+           'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'arctan2', 'hypot',
            'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer',
            'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal']
 
@@ -4011,6 +4011,50 @@ def vstack(arrays, out=None):
 
 
 @set_module('mxnet.numpy')
+def dstack(arrays):
+    """
+    Stack arrays in sequence depth wise (along third axis).
+
+    This is equivalent to concatenation along the third axis after 2-D arrays
+    of shape `(M,N)` have been reshaped to `(M,N,1)` and 1-D arrays of shape
+    `(N,)` have been reshaped to `(1,N,1)`. Rebuilds arrays divided by
+    `dsplit`.
+
+    This function makes most sense for arrays with up to 3 dimensions. For
+    instance, for pixel-data with a height (first axis), width (second axis),
+    and r/g/b channels (third axis). The functions `concatenate`, `stack` and
+    `block` provide more general stacking and concatenation operations.
+
+    Parameters
+    ----------
+    arrays : sequence of arrays
+        The arrays must have the same shape along all but the third axis.
+        1-D or 2-D arrays must have the same shape.
+
+    Returns
+    -------
+    stacked : ndarray
+        The array formed by stacking the given arrays, will be at least 3-D.
+
+    Examples
+    --------
+    >>> a = np.array((1,2,3))
+    >>> b = np.array((2,3,4))
+    >>> np.dstack((a,b))
+    array([[[1, 2],
+            [2, 3],
+            [3, 4]]])
+    >>> a = np.array([[1],[2],[3]])
+    >>> b = np.array([[2],[3],[4]])
+    >>> np.dstack((a,b))
+    array([[[1, 2]],
+           [[2, 3]],
+           [[3, 4]]])
+    """
+    return _npi.dstack(*arrays)
+
+
+@set_module('mxnet.numpy')
 def maximum(x1, x2, out=None):
     """Returns element-wise maximum of the input arrays with broadcasting.
 
diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py
index 9a90942..de11cfb 100644
--- a/python/mxnet/symbol/numpy/_symbol.py
+++ b/python/mxnet/symbol/numpy/_symbol.py
@@ -34,8 +34,8 @@ __all__ = ['zeros', 'ones', 'add', 'subtract', 'multiply', 'divide', 'mod', 'rem
            'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2', 'log1p',
            'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor',
            'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot', 'histogram',
-           'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'vstack', 'mean',
-           'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign',
+           'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'vstack', 'dstack',
+           'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign',
            'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'hypot', 'rad2deg', 'deg2rad',
            'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer',
            'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal']
@@ -2662,6 +2662,35 @@ def vstack(arrays, out=None):
 
 
 @set_module('mxnet.symbol.numpy')
+def dstack(arrays):
+    """
+    Stack arrays in sequence depth wise (along third axis).
+
+    This is equivalent to concatenation along the third axis after 2-D arrays
+    of shape `(M,N)` have been reshaped to `(M,N,1)` and 1-D arrays of shape
+    `(N,)` have been reshaped to `(1,N,1)`. Rebuilds arrays divided by
+    `dsplit`.
+
+    This function makes most sense for arrays with up to 3 dimensions. For
+    instance, for pixel-data with a height (first axis), width (second axis),
+    and r/g/b channels (third axis). The functions `concatenate`, `stack` and
+    `block` provide more general stacking and concatenation operations.
+
+    Parameters
+    ----------
+    arrays : sequence of _Symbol
+        The arrays must have the same shape along all but the third axis.
+        1-D or 2-D arrays must have the same shape.
+
+    Returns
+    -------
+    stacked : _Symbol
+        The array formed by stacking the given arrays, will be at least 3-D.
+    """
+    return _npi.dstack(*arrays)
+
+
+@set_module('mxnet.symbol.numpy')
 def maximum(x1, x2, out=None):
     return _ufunc_helper(x1, x2, _npi.maximum, _np.maximum, _npi.maximum_scalar, None, out)
 
diff --git a/src/operator/nn/concat-inl.h b/src/operator/nn/concat-inl.h
index 7a58ae6..1fb20ac 100644
--- a/src/operator/nn/concat-inl.h
+++ b/src/operator/nn/concat-inl.h
@@ -142,6 +142,37 @@ void ConcatCompute(const nnvm::NodeAttrs& attrs, const OpContext& ctx,
 }
 
 template<typename xpu>
+void DStackCompute(const nnvm::NodeAttrs& attrs, const OpContext& ctx,
+                   const std::vector<TBlob>& inputs,
+                   const std::vector<OpReqType>& req,
+                   const std::vector<TBlob>& outputs) {
+  ConcatParam param = nnvm::get<ConcatParam>(attrs.parsed);
+  param.dim = 2;
+  std::vector<TBlob> modified_inputs(inputs.size());
+  for (int i = 0; i < param.num_args; ++i) {
+    if (inputs[i].shape_.ndim() == 0) {
+      modified_inputs[i] = inputs[i].reshape(TShape(3, 1));
+    } else if (inputs[i].shape_.ndim() == 1) {
+      TShape t = TShape(3, 1);
+      t[1] = inputs[i].shape_[0];
+      modified_inputs[i] = inputs[i].reshape(t);
+    } else if (inputs[i].shape_.ndim() == 2) {
+      TShape t = TShape(3, 1);
+      t[0] = inputs[i].shape_[0];
+      t[1] = inputs[i].shape_[1];
+      modified_inputs[i] = inputs[i].reshape(t);
+    } else {
+      modified_inputs[i] = inputs[i];
+    }
+  }
+  MSHADOW_TYPE_SWITCH(inputs[concat_enum::kData0].type_flag_, DType, {
+    ConcatOp<xpu, DType> op;
+    op.Init(param);
+    op.Forward(ctx, modified_inputs, req, outputs);
+  });
+}
+
+template<typename xpu>
 void ConcatGradCompute(const nnvm::NodeAttrs& attrs, const OpContext& ctx,
                        const std::vector<TBlob>& inputs,
                        const std::vector<OpReqType>& req,
@@ -154,6 +185,37 @@ void ConcatGradCompute(const nnvm::NodeAttrs& attrs, const OpContext& ctx,
   });
 }
 
+template<typename xpu>
+void DStackGradCompute(const nnvm::NodeAttrs& attrs, const OpContext& ctx,
+                       const std::vector<TBlob>& inputs,
+                       const std::vector<OpReqType>& req,
+                       const std::vector<TBlob>& outputs) {
+  ConcatParam param = nnvm::get<ConcatParam>(attrs.parsed);
+  param.dim = 2;
+  std::vector<TBlob> modified_outputs(outputs.size());
+  for (int i = 0; i < param.num_args; ++i) {
+    if (outputs[i].shape_.ndim() == 0) {
+      modified_outputs[i] = outputs[i].reshape(TShape(3, 1));
+    } else if (outputs[i].shape_.ndim() == 1) {
+      TShape t  = TShape(3, 1);
+      t[1] = outputs[i].shape_[0];
+      modified_outputs[i] = outputs[i].reshape(t);
+    } else if (outputs[i].shape_.ndim() == 2) {
+      TShape t  = TShape(3, 1);
+      t[0] = outputs[i].shape_[0];
+      t[1] = outputs[i].shape_[1];
+      modified_outputs[i] = outputs[i].reshape(t);
+    } else {
+      modified_outputs[i] = outputs[i];
+    }
+  }
+  MSHADOW_TYPE_SWITCH(inputs[concat_enum::kOut].type_flag_, DType, {
+    ConcatOp<xpu, DType> op;
+    op.Init(param);
+    op.Backward(ctx, inputs[concat_enum::kOut], req, modified_outputs);
+  });
+}
+
 /*!
  * \brief concat CSRNDArray on the first dimension.
  */
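
DStackCompute's reshape ladder above is the depth-wise analogue of NumPy's
atleast_3d: every input is viewed as 3-D before the concat on axis 2, and
DStackGradCompute applies the same promotion to the gradient outputs so the
backward concat sees shapes consistent with the forward pass. A Python sketch
of the promotion rule it implements (shapes only, not the actual C++):

    def promote_to_3d(shape):
        """Shape an input is viewed as before depth-wise concatenation."""
        if len(shape) == 0:        # scalar ()     -> (1, 1, 1)
            return (1, 1, 1)
        if len(shape) == 1:        # vector (N,)   -> (1, N, 1)
            return (1, shape[0], 1)
        if len(shape) == 2:        # matrix (M, N) -> (M, N, 1)
            return (shape[0], shape[1], 1)
        return shape               # 3-D and higher pass through unchanged

    assert promote_to_3d(()) == (1, 1, 1)
    assert promote_to_3d((3,)) == (1, 3, 1)
    assert promote_to_3d((2, 3)) == (2, 3, 1)
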
diff --git a/src/operator/numpy/np_matrix_op.cc b/src/operator/numpy/np_matrix_op.cc
index 96a1056..3804468 100644
--- a/src/operator/numpy/np_matrix_op.cc
+++ b/src/operator/numpy/np_matrix_op.cc
@@ -255,6 +255,67 @@ bool ConcatShape(const nnvm::NodeAttrs& attrs,
                  mxnet::ShapeVector *in_shape,
                  mxnet::ShapeVector *out_shape);
 
+bool DStackShape(const nnvm::NodeAttrs& attrs,
+                 mxnet::ShapeVector *in_shape,
+                 mxnet::ShapeVector *out_shape) {
+  using namespace mshadow;
+  ConcatParam param_ = nnvm::get<ConcatParam>(attrs.parsed);
+  CHECK_EQ(in_shape->size(), static_cast<size_t>(param_.num_args));
+  mxnet::TShape dshape;
+  dim_t size = 0;
+  bool has_unknown_dim_size = false;
+  int axis = 2;
+  param_.dim = axis;
+  for (int i = 0; i < param_.num_args; ++i) {
+    if ((*in_shape)[i].ndim() == 0) {
+      (*in_shape)[i] = mxnet::TShape(3, 1);
+    } else if ((*in_shape)[i].ndim() == 1) {
+      mxnet::TShape t = mxnet::TShape(3, 1);
+      t[1] = (*in_shape)[i][0];
+      (*in_shape)[i] = t;
+    } else if ((*in_shape)[i].ndim() == 2) {
+      mxnet::TShape t = mxnet::TShape(3, 1);
+      t[0] = (*in_shape)[i][0];
+      t[1] = (*in_shape)[i][1];
+      (*in_shape)[i] = t;
+    }
+    mxnet::TShape &tmp = (*in_shape)[i];
+    if (tmp.ndim() > 0) {
+      CheckAxis(axis, tmp.ndim());
+      if (!mxnet::dim_size_is_known(tmp, axis)) {
+        has_unknown_dim_size = true;
+      } else {
+        size += tmp[axis];
+      }
+      tmp[axis] = -1;
+      shape_assign(&dshape, tmp);
+    }
+  }
+
+  mxnet::TShape tmp = (*out_shape)[0];
+  if (tmp.ndim() > 0) {
+    axis = CheckAxis(param_.dim, tmp.ndim());
+    tmp[axis] = -1;
+    shape_assign(&dshape, tmp);
+  }
+
+  if (dshape.ndim() == -1) return false;
+  CHECK_NE(dshape.ndim(), 0) << "zero-dimensional arrays cannot be concatenated";
+
+  for (int i = 0; i < param_.num_args; ++i) {
+    CHECK(shape_assign(&(*in_shape)[i], dshape))
+      << "Incompatible input shape: expected " << dshape << ", got " << (*in_shape)[i];
+  }
+
+  if (!has_unknown_dim_size) {
+    dshape[axis] = size;
+  }
+  CHECK(shape_assign(&(*out_shape)[0], dshape))
+    << "Incompatible output shape: expected " << dshape << ", got " << (*out_shape)[0];
+
+  return shape_is_known(dshape);
+}
+
 bool ConcatType(const nnvm::NodeAttrs& attrs,
                 std::vector<int> *in_type,
                 std::vector<int> *out_type);
@@ -269,7 +330,6 @@ struct NumpyConcatGrad {
   }
 };
 
-
 NNVM_REGISTER_OP(_npi_concatenate)
 .describe(R"code(Join a sequence of arrays along an existing axis.)code" ADD_FILELINE)
 .set_num_inputs([](const NodeAttrs& attrs) {
@@ -490,6 +550,44 @@ NNVM_REGISTER_OP(_backward_np_vstack)
 .set_attr<nnvm::TIsBackward>("TIsBackward", true)
 .set_attr<FCompute>("FCompute<cpu>", NumpyVstackBackward<cpu>);
 
+NNVM_REGISTER_OP(_npi_dstack)
+.describe(R"code(Stack tensors in sequence depthwise (in third dimension))code" ADD_FILELINE)
+.set_num_inputs([](const NodeAttrs& attrs) {
+  const ConcatParam& params = nnvm::get<ConcatParam>(attrs.parsed);
+  return params.num_args;
+})
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<ConcatParam>)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const NodeAttrs& attrs) {
+    const ConcatParam& params = nnvm::get<ConcatParam>(attrs.parsed);
+    std::vector<std::string> ret;
+    for (int i = 0; i < params.num_args; ++i) {
+      ret.push_back(std::string("data") + std::to_string(i));
+    }
+    return ret;
+})
+.set_attr<nnvm::FListOutputNames>("FListOutputNames",
+  [](const NodeAttrs& attrs) {
+    return std::vector<std::string>{"out"};
+})
+.set_attr<std::string>("key_var_num_args", "num_args")
+.set_attr<nnvm::FInferType>("FInferType", ConcatType)
+.set_attr<mxnet::FInferShape>("FInferShape", DStackShape)
+.set_attr<FCompute>("FCompute<cpu>", DStackCompute<cpu>)
+.set_attr<nnvm::FGradient>("FGradient", NumpyConcatGrad{"_backward_np_dstack"})
+.add_argument("data", "NDArray-or-Symbol[]", "List of arrays to concatenate")
+.add_arguments(ConcatParam::__FIELDS__());
+
+NNVM_REGISTER_OP(_backward_np_dstack)
+.set_num_outputs([](const NodeAttrs& attrs) {
+  const ConcatParam& params = nnvm::get<ConcatParam>(attrs.parsed);
+  return params.num_args;
+})
+.set_attr_parser(ParamParser<ConcatParam>)
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_attr<FCompute>("FCompute<cpu>", DStackGradCompute<cpu>);
+
 inline bool NumpyRollShape(const nnvm::NodeAttrs& attrs,
                            mxnet::ShapeVector *in_attrs,
                            mxnet::ShapeVector *out_attrs) {
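
After promoting inputs exactly as in DStackCompute, DStackShape requires all
dimensions to agree except the third and sums the extents along axis 2; an
unknown depth on any input leaves the output depth unknown rather than
guessing. A small sanity check of that contract against NumPy (a sketch, not
MXNet code):

    import numpy as np

    a = np.zeros((2, 3, 4))
    b = np.zeros((2, 3, 5))
    # Dims 0 and 1 must match; the depths add: 4 + 5 = 9.
    assert np.dstack((a, b)).shape == (2, 3, 9)
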
diff --git a/src/operator/numpy/np_matrix_op.cu b/src/operator/numpy/np_matrix_op.cu
index caab410..125cd91 100644
--- a/src/operator/numpy/np_matrix_op.cu
+++ b/src/operator/numpy/np_matrix_op.cu
@@ -53,6 +53,12 @@ NNVM_REGISTER_OP(_npi_vstack)
 NNVM_REGISTER_OP(_backward_np_vstack)
 .set_attr<FCompute>("FCompute<gpu>", NumpyVstackBackward<gpu>);
 
+NNVM_REGISTER_OP(_npi_dstack)
+.set_attr<FCompute>("FCompute<gpu>", DStackCompute<gpu>);
+
+NNVM_REGISTER_OP(_backward_np_dstack)
+.set_attr<FCompute>("FCompute<gpu>", DStackGradCompute<gpu>);
+
 NNVM_REGISTER_OP(_np_roll)
 .set_attr<FCompute>("FCompute<gpu>", NumpyRollCompute<gpu>);
 
@@ -90,5 +96,6 @@ NNVM_REGISTER_OP(_npi_flip)
 
 NNVM_REGISTER_OP(_backward_npi_flip)
 .set_attr<FCompute>("FCompute<gpu>", NumpyFlipForward<gpu>);
+
 }  // namespace op
 }  // namespace mxnet
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index 1f90f30..e9fa84e 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -1687,6 +1687,67 @@ def test_np_stack():
 
 @with_seed()
 @use_np
+def test_np_dstack():
+    class TestDStack(HybridBlock):
+        def __init__(self):
+            super(TestDStack, self).__init__()
+
+        def hybrid_forward(self, F, a, *args):
+            return F.np.dstack([a] + list(args))
+
+    def get_new_shape(shape):
+        if len(shape) < 3:
+            return shape
+        axis = 2
+        shape_lst = list(shape)
+        shape_lst[axis] = random.randint(0, 5)
+        return tuple(shape_lst)
+
+    shapes = [
+        (),
+        (1,),
+        (2,1),
+        (2,2,4),
+        (2,0,0),
+        (0,1,3),
+        (2,0,3),
+        (2,3,4,5)
+    ]
+    for hybridize in [True, False]:
+        for shape in shapes:
+            test_dstack = TestDStack()
+            if hybridize:
+                test_dstack.hybridize()
+            # test symbolic forward
+            a = mx.nd.random.uniform(shape=get_new_shape(shape)).as_np_ndarray()
+            a.attach_grad()
+            b = mx.nd.random.uniform(shape=get_new_shape(shape)).as_np_ndarray()
+            b.attach_grad()
+            c = mx.nd.random.uniform(shape=get_new_shape(shape)).as_np_ndarray()
+            c.attach_grad()
+            d = mx.nd.random.uniform(shape=get_new_shape(shape)).as_np_ndarray()
+            d.attach_grad()
+            with mx.autograd.record():
+                mx_out = test_dstack(a, b, c, d)
+            np_out = _np.dstack((a.asnumpy(), b.asnumpy(), c.asnumpy(), d.asnumpy()))
+            assert mx_out.shape == np_out.shape
+            assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
+
+            # test symbolic backward
+            mx_out.backward()
+            assert_almost_equal(a.grad.asnumpy(), _np.ones(a.shape), rtol=1e-3, atol=1e-5)
+            assert_almost_equal(b.grad.asnumpy(), _np.ones(b.shape), rtol=1e-3, atol=1e-5)
+            assert_almost_equal(c.grad.asnumpy(), _np.ones(c.shape), rtol=1e-3, atol=1e-5)
+            assert_almost_equal(d.grad.asnumpy(), _np.ones(d.shape), rtol=1e-3, atol=1e-5)
+
+            # test imperative
+            mx_out = np.dstack((a, b, c, d))
+            np_out = _np.dstack((a.asnumpy(), b.asnumpy(), c.asnumpy(), d.asnumpy()))
+            assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
+
+
+@with_seed()
+@use_np
 def test_np_ravel():
     class TestRavel(HybridBlock):
         def __init__(self):