Posted to commits@mxnet.apache.org by la...@apache.org on 2021/03/30 18:43:25 UTC

[incubator-mxnet] branch master updated: [FFI] npx.pick, npx.convolution, npx.deconvolution (#20101)

This is an automated email from the ASF dual-hosted git repository.

lausen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new cb4df36  [FFI] npx.pick, npx.convolution, npx.deconvolution (#20101)
cb4df36 is described below

commit cb4df368488d6afee86d6917d0fb70d60684fcba
Author: barry-jin <69...@users.noreply.github.com>
AuthorDate: Tue Mar 30 11:41:08 2021 -0700

    [FFI] npx.pick, npx.convolution, npx.deconvolution (#20101)
---
 python/mxnet/base.py                               |   3 +-
 python/mxnet/ndarray/numpy_extension/_op.py        | 292 ++++++++++++++++++++-
 python/mxnet/numpy_extension/_op.py                | 273 ++++++++++++++++++-
 .../operator/numpy_extension/npx_convolution_op.cc | 188 +++++++++++++
 .../numpy_extension/npx_deconvolution_op.cc        | 214 +++++++++++++++
 src/api/operator/numpy_extension/npx_pick_op.cc    |  79 ++++++
 src/operator/nn/convolution-inl.h                  |  67 +++++
 src/operator/nn/deconvolution-inl.h                |  73 ++++++
 src/operator/tensor/broadcast_reduce_op.h          |  21 ++
 9 files changed, 1207 insertions(+), 3 deletions(-)

diff --git a/python/mxnet/base.py b/python/mxnet/base.py
index fa13020..5e3912b 100644
--- a/python/mxnet/base.py
+++ b/python/mxnet/base.py
@@ -796,7 +796,8 @@ _NP_EXT_OP_PREFIX = '_npx_'
 _NP_EXT_OP_SUBMODULE_LIST = ['_image_', '_random_']
 _NP_EXT_OP_IMPLEMENTED_SET = {'_npx_softmax', '_npx_log_softmax', '_npx_masked_softmax',
                               '_npx_masked_log_softmax', '_npx_activation',
-                              '_npx_batch_norm', '_npx_fully_connected'}
+                              '_npx_batch_norm', '_npx_fully_connected', '_npx_pick',
+                              '_npx_convolution', '_npx_deconvolution'}
 
 _NP_INTERNAL_OP_PREFIX = '_npi_'
 
diff --git a/python/mxnet/ndarray/numpy_extension/_op.py b/python/mxnet/ndarray/numpy_extension/_op.py
index 8ada24f..346b85d 100644
--- a/python/mxnet/ndarray/numpy_extension/_op.py
+++ b/python/mxnet/ndarray/numpy_extension/_op.py
@@ -25,7 +25,8 @@ from ...util import set_module
 
 
 __all__ = ['softmax', 'log_softmax', 'masked_softmax', 'masked_log_softmax',
-           'activation', 'batch_norm', 'fully_connected']
+           'activation', 'batch_norm', 'fully_connected', 'pick', 'convolution',
+           'deconvolution']
 
 
 # pylint: disable=too-many-arguments
@@ -418,3 +419,292 @@ def fully_connected(x, weight, bias=None, num_hidden=None,
         assert bias is not None, "Missing bias parameter"
         return _api_internal.fully_connected(x, weight, bias, num_hidden,
                                              no_bias, flatten)
+
+
+# pylint: disable=too-many-arguments
+@set_module('mxnet.ndarray.numpy_extension')
+def pick(data, index, axis=-1, mode='clip', keepdims=False):
+    r"""Picks elements from an input array according to the input indices along the given axis.
+
+    Given an input array of shape ``(d0, d1)`` and indices of shape ``(i0,)``, the result will be
+    an output array of shape ``(i0,)`` with::
+
+      output[i] = input[i, indices[i]]
+
+    By default, if any index mentioned is too large, it is replaced by the index that addresses
+    the last element along an axis (the `clip` mode).
+
+    This function supports n-dimensional input and (n-1)-dimensional indices arrays.
+
+    Parameters
+    ----------
+    data : NDArray
+        The input array
+    index : NDArray
+        The index array
+    axis : int or None, optional, default='-1'
+        The axis along which to pick the elements.
+        Negative values mean indexing from right to left.
+        If `None`, elements are picked from the flattened input according to the index.
+    keepdims : boolean, optional, default=0
+        If true, the axis where we pick the elements is
+        left in the result as a dimension with size one.
+    mode : {'clip', 'wrap'}, optional, default='clip'
+        Specify how out-of-bounds indices behave. Default is "clip".
+        "clip" means clip to the range, so indices that are too large
+        are replaced by the index that addresses the last element along the axis.
+        "wrap" means to wrap around.
+
+    out : NDArray, optional
+        The output NDArray to hold the result.
+
+    Returns
+    -------
+    out : NDArray or list of NDArrays
+        The output of this function.
+
+    Example
+    -------
+    >>> x = np.array([[1., 2.],[3., 4.],[5., 6.]])
+
+    picks elements with specified indices along axis 0
+
+    >>> npx.pick(x, np.array([0, 1]), 0)
+    array([1., 4.])
+
+    picks elements with specified indices along axis 1
+
+    >>> npx.pick(x, np.array([0, 1, 0]), 1)
+    array([1., 4., 5.])
+
+    picks elements with specified indices along axis 1 using 'wrap' mode
+    to wrap indices that would otherwise be out of bounds
+
+    >>> npx.pick(x, np.array([2, -1, -2]), 1, mode='wrap')
+    array([1., 4., 5.])
+
+    picks elements with specified indices along axis 1 and dims are maintained
+
+    >>> npx.pick(x, np.array([[1.], [0.], [2.]]), 1, keepdims=True)
+    array([[2.],
+           [3.],
+           [6.]])
+    """
+    return _api_internal.pick(data, index, axis, mode, keepdims)
+
+
+# pylint: disable=too-many-arguments
+@set_module('mxnet.ndarray.numpy_extension')
+def convolution(data=None, weight=None, bias=None, kernel=None, stride=None, dilate=None,
+                pad=None, num_filter=1, num_group=1, workspace=1024, no_bias=False,
+                cudnn_tune=None, cudnn_off=False, layout=None):
+    r"""Compute *N*-D convolution on *(N+2)*-D input.
+
+    In the 2-D convolution, given input data with shape *(batch_size,
+    channel, height, width)*, the output is computed by
+
+    .. math::
+
+       out[n,i,:,:] = bias[i] + \sum_{j=0}^{channel} data[n,j,:,:] \star
+       weight[i,j,:,:]
+
+    where :math:`\star` is the 2-D cross-correlation operator.
+
+    For general 2-D convolution, the shapes are
+
+    - **data**: *(batch_size, channel, height, width)*
+    - **weight**: *(num_filter, channel, kernel[0], kernel[1])*
+    - **bias**: *(num_filter,)*
+    - **out**: *(batch_size, num_filter, out_height, out_width)*.
+
+    Define::
+
+      f(x,k,p,s,d) = floor((x+2*p-d*(k-1)-1)/s)+1
+
+    then we have::
+
+      out_height=f(height, kernel[0], pad[0], stride[0], dilate[0])
+      out_width=f(width, kernel[1], pad[1], stride[1], dilate[1])
+
+    If ``no_bias`` is set to be true, then the ``bias`` term is ignored.
+
+    The default data ``layout`` is *NCHW*, namely *(batch_size, channel, height,
+    width)*. We can choose other layouts such as *NHWC*.
+
+    If ``num_group`` is larger than 1, denoted by *g*, then split the input ``data``
+    evenly into *g* parts along the channel axis, and also evenly split ``weight``
+    along the first dimension. Next compute the convolution on the *i*-th part of
+    the data with the *i*-th weight part. The output is obtained by concatenating all
+    the *g* results.
+
+    1-D convolution does not have *height* dimension but only *width* in space.
+
+    - **data**: *(batch_size, channel, width)*
+    - **weight**: *(num_filter, channel, kernel[0])*
+    - **bias**: *(num_filter,)*
+    - **out**: *(batch_size, num_filter, out_width)*.
+
+    3-D convolution adds an additional *depth* dimension besides *height* and
+    *width*. The shapes are
+
+    - **data**: *(batch_size, channel, depth, height, width)*
+    - **weight**: *(num_filter, channel, kernel[0], kernel[1], kernel[2])*
+    - **bias**: *(num_filter,)*
+    - **out**: *(batch_size, num_filter, out_depth, out_height, out_width)*.
+
+    Both ``weight`` and ``bias`` are learnable parameters.
+
+    There are other options to tune the performance.
+
+    - **cudnn_tune**: enabling this option leads to higher startup time but may give
+      faster speed. Options are
+
+      - **off**: no tuning
+      - **limited_workspace**: run test and pick the fastest algorithm that doesn't
+        exceed the workspace limit.
+      - **fastest**: pick the fastest algorithm and ignore the workspace limit.
+      - **None** (default): the behavior is determined by environment variable
+        ``MXNET_CUDNN_AUTOTUNE_DEFAULT``. 0 for off, 1 for limited workspace
+        (default), 2 for fastest.
+
+    - **workspace**: A large number leads to more (GPU) memory usage but may improve
+      the performance.
+
+    Parameters
+    ----------
+    data : NDArray
+        Input data to the ConvolutionOp.
+    weight : NDArray
+        Weight matrix.
+    bias : NDArray
+        Bias parameter.
+    kernel : Shape(tuple), required
+        Convolution kernel size: (w,), (h, w) or (d, h, w)
+    stride : Shape(tuple), optional, default=[]
+        Convolution stride: (w,), (h, w) or (d, h, w). Defaults to 1 for each dimension.
+    dilate : Shape(tuple), optional, default=[]
+        Convolution dilate: (w,), (h, w) or (d, h, w). Defaults to 1 for each dimension.
+    pad : Shape(tuple), optional, default=[]
+        Zero pad for convolution: (w,), (h, w) or (d, h, w). Defaults to no padding.
+    num_filter : int (non-negative), required
+        Number of convolution filters (output channels).
+    num_group : int (non-negative), optional, default=1
+        Number of group partitions.
+    workspace : long (non-negative), optional, default=1024
+        Maximum temporary workspace allowed (MB) in convolution. This parameter has two uses.
+        When CUDNN is not used, it determines the effective batch size of the convolution kernel.
+        When CUDNN is used, it controls the maximum temporary storage used for tuning the best
+        CUDNN kernel when `limited_workspace` strategy is used.
+    no_bias : boolean, optional, default=0
+        Whether to disable bias parameter.
+    cudnn_tune : {None, 'fastest', 'limited_workspace', 'off'}, optional, default='None'
+        Whether to pick the convolution algorithm by running a performance test.
+    cudnn_off : boolean, optional, default=0
+        Turn off cudnn for this layer.
+    layout : {None, 'NCDHW', 'NCHW', 'NCW', 'NDHWC', 'NHWC'}, optional, default='None'
+        Set layout for input, output and weight. Empty for
+        default layout: NCW for 1d, NCHW for 2d and NCDHW for 3d.
+        NHWC and NDHWC are only supported on GPU.
+
+    Returns
+    -------
+    out : NDArray or list of NDArrays
+        The output of this function.
+    """
+    assert data is not None and weight is not None and kernel is not None, \
+           "Missing input data, weight or kernel"
+    assert num_filter > 0, "Number of output filters should be greater than 0"
+    assert workspace > 0, "Maximum temporary workspace should be greater than 0"
+    if no_bias:
+        assert bias is None, "bias must be None when no_bias is True"
+        return _api_internal.convolution(data, weight, kernel, stride, dilate, pad,
+                                         num_filter, num_group, workspace, no_bias,
+                                         cudnn_tune, cudnn_off, layout)
+    else:
+        assert bias is not None, "Missing bias parameter"
+        return _api_internal.convolution(data, weight, bias, kernel, stride, dilate, pad,
+                                         num_filter, num_group, workspace, no_bias,
+                                         cudnn_tune, cudnn_off, layout)
+
+
+# pylint: disable=too-many-arguments
+@set_module('mxnet.ndarray.numpy_extension')
+def deconvolution(data=None, weight=None, bias=None, kernel=None, stride=None, dilate=None,
+                  pad=None, adj=None, target_shape=None, num_filter=1, num_group=1,
+                  workspace=512, no_bias=False, cudnn_tune=None,
+                  cudnn_off=False, layout=None):
+    r"""Computes 1D or 2D transposed convolution (aka fractionally strided convolution) of
+    the input tensor. This operation can be seen as the gradient of the Convolution operation
+    with respect to its input. Convolution usually reduces the size of the input.
+    Transposed convolution works the other way, going from a smaller input
+    to a larger output while preserving the connectivity pattern.
+
+    Parameters
+    ----------
+    data : NDArray
+        Input tensor to the deconvolution operation.
+    weight : NDArray
+        Weights representing the kernel.
+    bias : NDArray
+        Bias added to the result after the deconvolution operation.
+    kernel : Shape(tuple), required
+        Deconvolution kernel size: (w,), (h, w) or (d, h, w).
+        This is the same as the kernel size used for the corresponding convolution.
+    stride : Shape(tuple), optional, default=[]
+        The stride used for the corresponding convolution: (w,), (h, w) or (d, h, w).
+        Defaults to 1 for each dimension.
+    dilate : Shape(tuple), optional, default=[]
+        Dilation factor for each dimension of the input: (w,), (h, w) or (d, h, w).
+        Defaults to 1 for each dimension.
+    pad : Shape(tuple), optional, default=[]
+        The amount of implicit zero padding added during convolution for each dimension of
+        the input: (w,), (h, w) or (d, h, w). ``(kernel-1)/2`` is usually a good choice.
+        If `target_shape` is set, `pad` will be ignored and a padding that will generate
+        the target shape will be used. Defaults to no padding.
+    adj : Shape(tuple), optional, default=[]
+        Adjustment for output shape: (w,), (h, w) or (d, h, w).
+        If `target_shape` is set, `adj` will be ignored and computed accordingly.
+    target_shape : Shape(tuple), optional, default=[]
+        Shape of the output tensor: (w,), (h, w) or (d, h, w).
+    num_filter : int (non-negative), required
+        Number of output filters.
+    num_group : int (non-negative), optional, default=1
+        Number of group partitions.
+    workspace : long (non-negative), optional, default=512
+        Maximum temporary workspace allowed (MB) in deconvolution. This parameter has two uses.
+        When CUDNN is not used, it determines the effective batch size of the deconvolution kernel.
+        When CUDNN is used, it controls the maximum temporary storage used for tuning
+        the best CUDNN kernel when `limited_workspace` strategy is used.
+    no_bias : boolean, optional, default=0
+        Whether to disable bias parameter.
+    cudnn_tune : {None, 'fastest', 'limited_workspace', 'off'}, optional, default='None'
+        Whether to pick the convolution algorithm by running a performance test.
+    cudnn_off : boolean, optional, default=0
+        Turn off cudnn for this layer.
+    layout : {None, 'NCDHW', 'NCHW', 'NCW', 'NDHWC', 'NHWC'}, optional, default='None'
+        Set layout for input, output and weight. Empty for
+        default layout: NCW for 1d, NCHW for 2d and NCDHW for 3d.
+        NHWC and NDHWC are only supported on GPU.
+
+    out : NDArray, optional
+        The output NDArray to hold the result.
+
+    Returns
+    -------
+    out : NDArray or list of NDArrays
+        The output of this function.
+    """
+    assert data is not None and weight is not None and kernel is not None, \
+           "Missing input data, weight or kernel"
+    assert num_filter > 0, "Number of output filters should be greater than 0"
+    assert workspace > 0, "Maximum temporary workspace should be greater than 0"
+    if no_bias:
+        assert bias is None, "bias must be None when no_bias is True"
+        return _api_internal.deconvolution(data, weight, kernel, stride, dilate, pad,
+                                           adj, target_shape, num_filter, num_group,
+                                           workspace, no_bias, cudnn_tune, cudnn_off, layout)
+    else:
+        assert bias is not None, "Missing bias parameter"
+        return _api_internal.deconvolution(data, weight, bias, kernel, stride, dilate, pad,
+                                           adj, target_shape, num_filter, num_group,
+                                           workspace, no_bias, cudnn_tune, cudnn_off, layout)
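
For reference, a minimal usage sketch of the new wrappers (not part of the
commit; shapes and values are illustrative, and ``npx`` refers to
``mxnet.npx``)::

    from mxnet import np, npx

    # 2-D convolution: one 3-channel 32x32 input, 8 output filters
    data   = np.random.uniform(size=(1, 3, 32, 32))
    weight = np.random.uniform(size=(8, 3, 3, 3))
    bias   = np.zeros((8,))

    out = npx.convolution(data=data, weight=weight, bias=bias,
                          kernel=(3, 3), stride=(1, 1), pad=(1, 1),
                          num_filter=8)
    print(out.shape)  # (1, 8, 32, 32) for this kernel/stride/pad
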
diff --git a/python/mxnet/numpy_extension/_op.py b/python/mxnet/numpy_extension/_op.py
index d168af6..6ca2248 100644
--- a/python/mxnet/numpy_extension/_op.py
+++ b/python/mxnet/numpy_extension/_op.py
@@ -22,7 +22,8 @@ from ..util import set_module
 
 
 __all__ = ['softmax', 'log_softmax', 'masked_softmax', 'masked_log_softmax',
-           'activation', 'batch_norm', 'fully_connected']
+           'activation', 'batch_norm', 'fully_connected', 'pick', 'convolution',
+           'deconvolution']
 
 
 # pylint: disable=too-many-arguments
@@ -385,3 +386,273 @@ def fully_connected(x, weight, bias=None, num_hidden=None,
     """
     return _mx_nd_npx.fully_connected(x, weight, bias, num_hidden=num_hidden,
                                       no_bias=no_bias, flatten=flatten)
+
+
+# pylint: disable=too-many-arguments
+@set_module('mxnet.numpy_extension')
+def pick(data, index, axis=-1, mode='clip', keepdims=False):
+    r"""Picks elements from an input array according to the input indices along the given axis.
+
+    Given an input array of shape ``(d0, d1)`` and indices of shape ``(i0,)``, the result will be
+    an output array of shape ``(i0,)`` with::
+
+      output[i] = input[i, indices[i]]
+
+    By default, if any index mentioned is too large, it is replaced by the index that addresses
+    the last element along an axis (the `clip` mode).
+
+    This function supports n-dimensional input and (n-1)-dimensional indices arrays.
+
+    Parameters
+    ----------
+    data : NDArray
+        The input array
+    index : NDArray
+        The index array
+    axis : int or None, optional, default='-1'
+        The axis along which to pick the elements.
+        Negative values mean indexing from right to left.
+        If `None`, elements are picked from the flattened input according to the index.
+    keepdims : boolean, optional, default=0
+        If true, the axis where we pick the elements is
+        left in the result as a dimension with size one.
+    mode : {'clip', 'wrap'}, optional, default='clip'
+        Specify how out-of-bounds indices behave. Default is "clip".
+        "clip" means clip to the range, so indices that are too large
+        are replaced by the index that addresses the last element along the axis.
+        "wrap" means to wrap around.
+
+    out : NDArray, optional
+        The output NDArray to hold the result.
+
+    Returns
+    -------
+    out : NDArray or list of NDArrays
+        The output of this function.
+
+    Example
+    -------
+    >>> x = np.array([[1., 2.],[3., 4.],[5., 6.]])
+
+    picks elements with specified indices along axis 0
+
+    >>> npx.pick(x, np.array([0, 1]), 0)
+    array([1., 4.])
+
+    picks elements with specified indices along axis 1
+
+    >>> npx.pick(x, np.array([0, 1, 0]), 1)
+    array([1., 4., 5.])
+
+    picks elements with specified indices along axis 1 using 'wrap' mode
+    to wrap indices that would otherwise be out of bounds
+
+    >>> npx.pick(x, np.array([2, -1, -2]), 1, mode='wrap')
+    array([1., 4., 5.])
+
+    picks elements with specified indices along axis 1 and dims are maintained
+
+    >>> npx.pick(x, np.array([[1.], [0.], [2.]]), 1, keepdims=True)
+    array([[2.],
+           [3.],
+           [6.]])
+    """
+    return _mx_nd_npx.pick(data, index, axis, mode, keepdims)
+
+
+# pylint: disable=too-many-arguments
+@set_module('mxnet.numpy_extension')
+def convolution(data=None, weight=None, bias=None, kernel=None, stride=None, dilate=None,
+                pad=None, num_filter=1, num_group=1, workspace=1024, no_bias=False,
+                cudnn_tune=None, cudnn_off=False, layout=None):
+    r"""Compute *N*-D convolution on *(N+2)*-D input.
+
+    In the 2-D convolution, given input data with shape *(batch_size,
+    channel, height, width)*, the output is computed by
+
+    .. math::
+
+       out[n,i,:,:] = bias[i] + \sum_{j=0}^{channel} data[n,j,:,:] \star
+       weight[i,j,:,:]
+
+    where :math:`\star` is the 2-D cross-correlation operator.
+
+    For general 2-D convolution, the shapes are
+
+    - **data**: *(batch_size, channel, height, width)*
+    - **weight**: *(num_filter, channel, kernel[0], kernel[1])*
+    - **bias**: *(num_filter,)*
+    - **out**: *(batch_size, num_filter, out_height, out_width)*.
+
+    Define::
+
+      f(x,k,p,s,d) = floor((x+2*p-d*(k-1)-1)/s)+1
+
+    then we have::
+
+      out_height=f(height, kernel[0], pad[0], stride[0], dilate[0])
+      out_width=f(width, kernel[1], pad[1], stride[1], dilate[1])
+
+    If ``no_bias`` is set to be true, then the ``bias`` term is ignored.
+
+    The default data ``layout`` is *NCHW*, namely *(batch_size, channel, height,
+    width)*. We can choose other layouts such as *NHWC*.
+
+    If ``num_group`` is larger than 1, denoted by *g*, then split the input ``data``
+    evenly into *g* parts along the channel axis, and also evenly split ``weight``
+    along the first dimension. Next compute the convolution on the *i*-th part of
+    the data with the *i*-th weight part. The output is obtained by concatenating all
+    the *g* results.
+
+    1-D convolution does not have *height* dimension but only *width* in space.
+
+    - **data**: *(batch_size, channel, width)*
+    - **weight**: *(num_filter, channel, kernel[0])*
+    - **bias**: *(num_filter,)*
+    - **out**: *(batch_size, num_filter, out_width)*.
+
+    3-D convolution adds an additional *depth* dimension besides *height* and
+    *width*. The shapes are
+
+    - **data**: *(batch_size, channel, depth, height, width)*
+    - **weight**: *(num_filter, channel, kernel[0], kernel[1], kernel[2])*
+    - **bias**: *(num_filter,)*
+    - **out**: *(batch_size, num_filter, out_depth, out_height, out_width)*.
+
+    Both ``weight`` and ``bias`` are learnable parameters.
+
+    There are other options to tune the performance.
+
+    - **cudnn_tune**: enabling this option leads to higher startup time but may give
+      faster speed. Options are
+
+      - **off**: no tuning
+      - **limited_workspace**: run test and pick the fastest algorithm that doesn't
+        exceed the workspace limit.
+      - **fastest**: pick the fastest algorithm and ignore the workspace limit.
+      - **None** (default): the behavior is determined by environment variable
+        ``MXNET_CUDNN_AUTOTUNE_DEFAULT``. 0 for off, 1 for limited workspace
+        (default), 2 for fastest.
+
+    - **workspace**: A large number leads to more (GPU) memory usage but may improve
+      the performance.
+
+    Parameters
+    ----------
+    data : NDArray
+        Input data to the ConvolutionOp.
+    weight : NDArray
+        Weight matrix.
+    bias : NDArray
+        Bias parameter.
+    kernel : Shape(tuple), required
+        Convolution kernel size: (w,), (h, w) or (d, h, w)
+    stride : Shape(tuple), optional, default=[]
+        Convolution stride: (w,), (h, w) or (d, h, w). Defaults to 1 for each dimension.
+    dilate : Shape(tuple), optional, default=[]
+        Convolution dilate: (w,), (h, w) or (d, h, w). Defaults to 1 for each dimension.
+    pad : Shape(tuple), optional, default=[]
+        Zero pad for convolution: (w,), (h, w) or (d, h, w). Defaults to no padding.
+    num_filter : int (non-negative), required
+        Number of convolution filters (output channels).
+    num_group : int (non-negative), optional, default=1
+        Number of group partitions.
+    workspace : long (non-negative), optional, default=1024
+        Maximum temporary workspace allowed (MB) in convolution. This parameter has two uses.
+        When CUDNN is not used, it determines the effective batch size of the convolution kernel.
+        When CUDNN is used, it controls the maximum temporary storage used for tuning the best
+        CUDNN kernel when `limited_workspace` strategy is used.
+    no_bias : boolean, optional, default=0
+        Whether to disable bias parameter.
+    cudnn_tune : {None, 'fastest', 'limited_workspace', 'off'}, optional, default='None'
+        Whether to pick the convolution algorithm by running a performance test.
+    cudnn_off : boolean, optional, default=0
+        Turn off cudnn for this layer.
+    layout : {None, 'NCDHW', 'NCHW', 'NCW', 'NDHWC', 'NHWC'}, optional, default='None'
+        Set layout for input, output and weight. Empty for
+        default layout: NCW for 1d, NCHW for 2d and NCDHW for 3d.
+        NHWC and NDHWC are only supported on GPU.
+
+    Returns
+    -------
+    out : NDArray or list of NDArrays
+        The output of this function.
+    """
+    return _mx_nd_npx.convolution(data=data, weight=weight, bias=bias, kernel=kernel,
+                                  stride=stride, dilate=dilate, pad=pad, num_filter=num_filter,
+                                  num_group=num_group, workspace=workspace, no_bias=no_bias,
+                                  cudnn_tune=cudnn_tune, cudnn_off=cudnn_off, layout=layout)
+
+
+# pylint: disable=too-many-arguments
+@set_module('mxnet.numpy_extension')
+def deconvolution(data=None, weight=None, bias=None, kernel=None, stride=None, dilate=None,
+                  pad=None, adj=None, target_shape=None, num_filter=1, num_group=1,
+                  workspace=512, no_bias=False, cudnn_tune=None,
+                  cudnn_off=False, layout=None):
+    r"""Computes 1D or 2D transposed convolution (aka fractionally strided convolution) of
+    the input tensor. This operation can be seen as the gradient of the Convolution operation
+    with respect to its input. Convolution usually reduces the size of the input.
+    Transposed convolution works the other way, going from a smaller input
+    to a larger output while preserving the connectivity pattern.
+
+    Parameters
+    ----------
+    data : NDArray
+        Input tensor to the deconvolution operation.
+    weight : NDArray
+        Weights representing the kernel.
+    bias : NDArray
+        Bias added to the result after the deconvolution operation.
+    kernel : Shape(tuple), required
+        Deconvolution kernel size: (w,), (h, w) or (d, h, w).
+        This is the same as the kernel size used for the corresponding convolution.
+    stride : Shape(tuple), optional, default=[]
+        The stride used for the corresponding convolution: (w,), (h, w) or (d, h, w).
+        Defaults to 1 for each dimension.
+    dilate : Shape(tuple), optional, default=[]
+        Dilation factor for each dimension of the input: (w,), (h, w) or (d, h, w).
+        Defaults to 1 for each dimension.
+    pad : Shape(tuple), optional, default=[]
+        The amount of implicit zero padding added during convolution for each dimension of
+        the input: (w,), (h, w) or (d, h, w). ``(kernel-1)/2`` is usually a good choice.
+        If `target_shape` is set, `pad` will be ignored and a padding that will generate
+        the target shape will be used. Defaults to no padding.
+    adj : Shape(tuple), optional, default=[]
+        Adjustment for output shape: (w,), (h, w) or (d, h, w).
+        If `target_shape` is set, `adj` will be ignored and computed accordingly.
+    target_shape : Shape(tuple), optional, default=[]
+        Shape of the output tensor: (w,), (h, w) or (d, h, w).
+    num_filter : int (non-negative), required
+        Number of output filters.
+    num_group : int (non-negative), optional, default=1
+        Number of group partitions.
+    workspace : long (non-negative), optional, default=512
+        Maximum temporary workspace allowed (MB) in deconvolution. This parameter has two uses.
+        When CUDNN is not used, it determines the effective batch size of the deconvolution kernel.
+        When CUDNN is used, it controls the maximum temporary storage used for tuning
+        the best CUDNN kernel when `limited_workspace` strategy is used.
+    no_bias : boolean, optional, default=0
+        Whether to disable bias parameter.
+    cudnn_tune : {None, 'fastest', 'limited_workspace', 'off'}, optional, default='None'
+        Whether to pick the convolution algorithm by running a performance test.
+    cudnn_off : boolean, optional, default=0
+        Turn off cudnn for this layer.
+    layout : {None, 'NCDHW', 'NCHW', 'NCW', 'NDHWC', 'NHWC'}, optional, default='None'
+        Set layout for input, output and weight. Empty for
+        default layout: NCW for 1d, NCHW for 2d and NCDHW for 3d.
+        NHWC and NDHWC are only supported on GPU.
+
+    out : NDArray, optional
+        The output NDArray to hold the result.
+
+    Returns
+    -------
+    out : NDArray or list of NDArrays
+        The output of this function.
+    """
+    return _mx_nd_npx.deconvolution(data=data, weight=weight, bias=bias, kernel=kernel,
+                                    stride=stride, dilate=dilate, pad=pad, adj=adj,
+                                    target_shape=target_shape, num_filter=num_filter,
+                                    num_group=num_group, workspace=workspace, no_bias=no_bias,
+                                    cudnn_tune=cudnn_tune, cudnn_off=cudnn_off, layout=layout)
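
As a quick sanity check on the shape formula ``f(x,k,p,s,d)`` quoted in both
convolution docstrings, a pure-Python sketch (not part of the commit)::

    import math

    def conv_out_size(x, k, p, s, d):
        # f(x,k,p,s,d) = floor((x + 2*p - d*(k-1) - 1)/s) + 1
        return math.floor((x + 2 * p - d * (k - 1) - 1) / s) + 1

    # out_height/out_width for a 32x32 input, 3x3 kernel, pad 1, stride 2, dilate 1
    print(conv_out_size(32, 3, 1, 2, 1))  # 16
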
diff --git a/src/api/operator/numpy_extension/npx_convolution_op.cc b/src/api/operator/numpy_extension/npx_convolution_op.cc
new file mode 100644
index 0000000..adb1ec3
--- /dev/null
+++ b/src/api/operator/numpy_extension/npx_convolution_op.cc
@@ -0,0 +1,188 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file npx_convolution_op.cc
+ * \brief Implementation of the API of functions in src/operator/nn/convolution-inl.h
+ */
+#include <mxnet/api_registry.h>
+#include <mxnet/runtime/packed_func.h>
+#include "../utils.h"
+#include "../../../operator/nn/convolution-inl.h"
+
+namespace mxnet {
+
+inline int String2Layout(const std::string& s) {
+  using namespace op;
+  if (s == "NCW") {
+    return mshadow::kNCW;
+  } else if (s == "NCHW") {
+    return mshadow::kNCHW;
+  } else if (s == "NCDHW") {
+    return mshadow::kNCDHW;
+  } else if (s == "NHWC") {
+    return mshadow::kNHWC;
+  } else if (s == "NDHWC") {
+    return mshadow::kNDHWC;
+  } else {
+    LOG(FATAL) << "unknown layout type " << s;
+  }
+  LOG(FATAL) << "should not reach here ";
+  return 0;
+}
+
+inline int String2CudnnTune(const std::string& s) {
+  using namespace op;
+  if (s == "off") {
+    return conv::kOff;
+  } else if (s == "limited_workspace") {
+    return conv::kLimited;
+  } else if (s == "fastest") {
+    return conv::kFastest;
+  } else {
+    LOG(FATAL) << "unknown cudnn tune type " << s;
+  }
+  LOG(FATAL) << "should not reach here ";
+  return 0;
+}
+
+MXNET_REGISTER_API("_npx.convolution")
+.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
+  using namespace runtime;
+  nnvm::NodeAttrs attrs;
+  const nnvm::Op* op = Op::Get("_npx_convolution");
+  op::ConvolutionParam param;
+  int args_size = args.size();
+  // no_bias
+  if (args[args_size - 4].type_code() == kNull) {
+    param.no_bias = false;
+  } else {
+    param.no_bias = args[args_size - 4].operator bool();
+  }
+  // inputs
+  int num_inputs = param.no_bias ? 2 : 3;
+  std::vector<NDArray*> inputs;
+  inputs.reserve(num_inputs);
+  for (int i = 0; i < num_inputs; ++i) {
+    inputs.push_back(args[i].operator mxnet::NDArray*());
+  }
+  // kernel
+  if (args[num_inputs].type_code() == kDLInt) {
+    param.kernel = TShape(1, args[num_inputs].operator int64_t());
+  } else {
+    param.kernel = TShape(args[num_inputs].operator ObjectRef());
+  }
+  // layout
+  if (args[num_inputs + 10].type_code() == kNull) {
+    param.layout = dmlc::nullopt;
+  } else {
+    param.layout = String2Layout(args[num_inputs + 10]);
+  }
+  // Check
+  if (param.kernel.ndim() == 1) {
+    param.layout = param.layout ? param.layout.value() : mshadow::kNCW;
+  } else if (param.kernel.ndim() == 2) {
+    param.layout = param.layout ? param.layout.value() : mshadow::kNCHW;
+  } else {
+    CHECK_EQ(param.kernel.ndim(), 3U) << param.kernel.ndim() << "D convolution not supported";
+    param.layout = param.layout ? param.layout.value() : mshadow::kNCDHW;
+  }
+  // stride
+  if (args[num_inputs + 1].type_code() == kNull) {
+    if (param.kernel.ndim() == 1) {
+      param.stride = Shape1(1);
+    } else if (param.kernel.ndim() == 2) {
+      param.stride = Shape2(1, 1);
+    } else {
+      param.stride = Shape3(1, 1, 1);
+    }
+  } else if (args[num_inputs + 1].type_code() == kDLInt) {
+    param.stride = TShape(1, args[num_inputs + 1].operator int64_t());
+  } else {
+    param.stride = TShape(args[num_inputs + 1].operator ObjectRef());
+  }
+  // dilate
+  if (args[num_inputs + 2].type_code() == kNull) {
+    if (param.kernel.ndim() == 1) {
+      param.dilate = Shape1(1);
+    } else if (param.kernel.ndim() == 2) {
+      param.dilate = Shape2(1, 1);
+    } else {
+      param.dilate = Shape3(1, 1, 1);
+    }
+  } else if (args[num_inputs + 2].type_code() == kDLInt) {
+    param.dilate = TShape(1, args[num_inputs + 2].operator int64_t());
+  } else {
+    param.dilate = TShape(args[num_inputs + 2].operator ObjectRef());
+  }
+  // pad
+  if (args[num_inputs + 3].type_code() == kNull) {
+    if (param.kernel.ndim() == 1) {
+      param.pad = Shape1(0);
+    } else if (param.kernel.ndim() == 2) {
+      param.pad = Shape2(0, 0);
+    } else {
+      param.pad = Shape3(0, 0, 0);
+    }
+  } else if (args[num_inputs + 3].type_code() == kDLInt) {
+    param.pad = TShape(1, args[num_inputs + 3].operator int64_t());
+  } else {
+    param.pad = TShape(args[num_inputs + 3].operator ObjectRef());
+  }
+  // num_filter
+  param.num_filter = (uint32_t) (args[num_inputs + 4].operator int());
+  // num_group
+  param.num_group = (uint32_t) (args[num_inputs + 5].operator int());
+  // workspace
+  param.workspace = args[num_inputs + 6].operator uint64_t();
+  // cudnn_tune
+  if (args[num_inputs + 8].type_code() == kNull) {
+    param.cudnn_tune = dmlc::nullopt;
+  } else {
+    param.cudnn_tune = String2CudnnTune(args[num_inputs + 8]);
+  }
+  // cudnn_off
+  if (args[num_inputs + 9].type_code() == kNull) {
+    param.cudnn_off = false;
+  } else {
+    param.cudnn_off = args[num_inputs + 9].operator bool();
+  }
+
+  CHECK_EQ(param.kernel.ndim(), param.stride.ndim())
+    << "Stride must have the same number of dimensions as kernel_size, "
+    << "but kernel_size is set to " << param.kernel << " while stride is "
+    << param.stride;
+  CHECK_EQ(param.kernel.ndim(), param.dilate.ndim())
+    << "Dilate must have the same number of dimensions as kernel_size, "
+    << "but kernel_size is set to " << param.kernel << " while dilate is "
+    << param.dilate;
+  CHECK_EQ(param.kernel.ndim(), param.pad.ndim())
+    << "Padding must have the same number of dimensions as kernel_size, "
+    << "but kernel_size is set to " << param.kernel << " while padding is "
+    << param.pad;
+
+  attrs.parsed = param;
+  attrs.op = op;
+  SetAttrDict<op::ConvolutionParam>(&attrs);
+  int num_outputs = 0;
+  auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs.data(), &num_outputs, nullptr);
+  *ret = ndoutputs[0];
+});
+
+}  // namespace mxnet
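
For readers tracing the offsets above: the handler assumes the positional
layout used by the Python wrapper's ``_api_internal.convolution(...)`` call,
i.e. with ``num_inputs == 3`` when bias is passed (2 under ``no_bias``)::

    args[0 .. num_inputs-1] -> data, weight[, bias]
    args[num_inputs + 0]    -> kernel
    args[num_inputs + 1]    -> stride
    args[num_inputs + 2]    -> dilate
    args[num_inputs + 3]    -> pad
    args[num_inputs + 4]    -> num_filter
    args[num_inputs + 5]    -> num_group
    args[num_inputs + 6]    -> workspace
    args[num_inputs + 7]    -> no_bias     (also read as args[args_size - 4])
    args[num_inputs + 8]    -> cudnn_tune
    args[num_inputs + 9]    -> cudnn_off
    args[num_inputs + 10]   -> layout
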
diff --git a/src/api/operator/numpy_extension/npx_deconvolution_op.cc b/src/api/operator/numpy_extension/npx_deconvolution_op.cc
new file mode 100644
index 0000000..838f440
--- /dev/null
+++ b/src/api/operator/numpy_extension/npx_deconvolution_op.cc
@@ -0,0 +1,214 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file npx_deconvolution_op.cc
+ * \brief Implementation of the API of functions in src/operator/nn/deconvolution-inl.h
+ */
+#include <mxnet/api_registry.h>
+#include <mxnet/runtime/packed_func.h>
+#include "../utils.h"
+#include "../../../operator/nn/deconvolution-inl.h"
+
+namespace mxnet {
+
+inline int String2Layout(const std::string& s) {
+  using namespace op;
+  if (s == "NCW") {
+    return mshadow::kNCW;
+  } else if (s == "NCHW") {
+    return mshadow::kNCHW;
+  } else if (s == "NCDHW") {
+    return mshadow::kNCDHW;
+  } else if (s == "NHWC") {
+    return mshadow::kNHWC;
+  } else if (s == "NDHWC") {
+    return mshadow::kNDHWC;
+  } else {
+    LOG(FATAL) << "unknown layout type " << s;
+  }
+  LOG(FATAL) << "should not reach here ";
+  return 0;
+}
+
+inline int String2CudnnTune(const std::string& s) {
+  using namespace op;
+  if (s == "off") {
+    return deconv::kOff;
+  } else if (s == "limited_workspace") {
+    return deconv::kLimited;
+  } else if (s == "fastest") {
+    return deconv::kFastest;
+  } else {
+    LOG(FATAL) << "unknown cudnn tune type " << s;
+  }
+  LOG(FATAL) << "should not reach here ";
+  return 0;
+}
+
+MXNET_REGISTER_API("_npx.deconvolution")
+.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
+  using namespace runtime;
+  nnvm::NodeAttrs attrs;
+  const nnvm::Op* op = Op::Get("_npx_deconvolution");
+  op::DeconvolutionParam param;
+  int args_size = args.size();
+  // no_bias
+  if (args[args_size - 4].type_code() == kNull) {
+    param.no_bias = false;
+  } else {
+    param.no_bias = args[args_size - 4].operator bool();
+  }
+  // inputs
+  int num_inputs = param.no_bias ? 2 : 3;
+  std::vector<NDArray*> inputs;
+  inputs.reserve(num_inputs);
+  for (int i = 0; i < num_inputs; ++i) {
+    inputs.push_back(args[i].operator mxnet::NDArray*());
+  }
+  // kernel
+  if (args[num_inputs].type_code() == kDLInt) {
+    param.kernel = TShape(1, args[num_inputs].operator int64_t());
+  } else {
+    param.kernel = TShape(args[num_inputs].operator ObjectRef());
+  }
+  // layout
+  if (args[num_inputs + 12].type_code() == kNull) {
+    param.layout = dmlc::nullopt;
+  } else {
+    param.layout = String2Layout(args[num_inputs + 12]);
+  }
+  // Check
+  if (param.kernel.ndim() == 1) {
+    param.layout = param.layout ? param.layout.value() : mshadow::kNCW;
+  } else if (param.kernel.ndim() == 2) {
+    param.layout = param.layout ? param.layout.value() : mshadow::kNCHW;
+  } else {
+    CHECK_EQ(param.kernel.ndim(), 3U) << param.kernel.ndim() << "D deconvolution not supported";
+    param.layout = param.layout ? param.layout.value() : mshadow::kNCDHW;
+  }
+  // stride
+  if (args[num_inputs + 1].type_code() == kNull) {
+    if (param.kernel.ndim() == 1) {
+      param.stride = Shape1(1);
+    } else if (param.kernel.ndim() == 2) {
+      param.stride = Shape2(1, 1);
+    } else {
+      param.stride = Shape3(1, 1, 1);
+    }
+  } else if (args[num_inputs + 1].type_code() == kDLInt) {
+    param.stride = TShape(1, args[num_inputs + 1].operator int64_t());
+  } else {
+    param.stride = TShape(args[num_inputs + 1].operator ObjectRef());
+  }
+  // dilate
+  if (args[num_inputs + 2].type_code() == kNull) {
+    if (param.kernel.ndim() == 1) {
+      param.dilate = Shape1(1);
+    } else if (param.kernel.ndim() == 2) {
+      param.dilate = Shape2(1, 1);
+    } else {
+      param.dilate = Shape3(1, 1, 1);
+    }
+  } else if (args[num_inputs + 2].type_code() == kDLInt) {
+    param.dilate = TShape(1, args[num_inputs + 2].operator int64_t());
+  } else {
+    param.dilate = TShape(args[num_inputs + 2].operator ObjectRef());
+  }
+  // pad
+  if (args[num_inputs + 3].type_code() == kNull) {
+    if (param.kernel.ndim() == 1) {
+      param.pad = Shape1(0);
+    } else if (param.kernel.ndim() == 2) {
+      param.pad = Shape2(0, 0);
+    } else {
+      param.pad = Shape3(0, 0, 0);
+    }
+  } else if (args[num_inputs + 3].type_code() == kDLInt) {
+    param.pad = TShape(1, args[num_inputs + 3].operator int64_t());
+  } else {
+    param.pad = TShape(args[num_inputs + 3].operator ObjectRef());
+  }
+  // adj
+  if (args[num_inputs + 4].type_code() == kNull) {
+    if (param.kernel.ndim() == 1) {
+      param.adj = Shape1(0);
+    } else if (param.kernel.ndim() == 2) {
+      param.adj = Shape2(0, 0);
+    } else {
+      param.adj = Shape3(0, 0, 0);
+    }
+  } else if (args[num_inputs + 4].type_code() == kDLInt) {
+    param.adj = TShape(1, args[num_inputs + 4].operator int64_t());
+  } else {
+    param.adj = TShape(args[num_inputs + 4].operator ObjectRef());
+  }
+  // target_shape
+  if (args[num_inputs + 5].type_code() != kNull) {
+    if (args[num_inputs + 5].type_code() == kDLInt) {
+      param.target_shape = TShape(1, args[num_inputs + 5].operator int64_t());
+    } else {
+      param.target_shape = TShape(args[num_inputs + 5].operator ObjectRef());
+    }
+  }
+  // num_filter
+  param.num_filter = (uint32_t) (args[num_inputs + 6].operator int());
+  // num_group
+  param.num_group = (uint32_t) (args[num_inputs + 7].operator int());
+  // workspace
+  param.workspace = args[num_inputs + 8].operator uint64_t();
+  // cudnn_tune
+  if (args[num_inputs + 10].type_code() == kNull) {
+    param.cudnn_tune = dmlc::nullopt;
+  } else {
+    param.cudnn_tune = String2CudnnTune(args[num_inputs + 10]);
+  }
+  // cudnn_off
+  if (args[num_inputs + 11].type_code() == kNull) {
+    param.cudnn_off = false;
+  } else {
+    param.cudnn_off = args[num_inputs + 11].operator bool();
+  }
+
+  CHECK_EQ(param.kernel.ndim(), param.stride.ndim())
+    << "Stride must have the same number of dimensions as kernel_size, "
+    << "but kernel_size is set to " << param.kernel << " while stride is "
+    << param.stride;
+  CHECK_EQ(param.kernel.ndim(), param.dilate.ndim())
+    << "Dilate must have the same number of dimensions as kernel_size, "
+    << "but kernel_size is set to " << param.kernel << " while dilate is "
+    << param.dilate;
+  CHECK_EQ(param.kernel.ndim(), param.pad.ndim())
+    << "Padding must have the same number of dimensions as kernel_size, "
+    << "but kernel_size is set to " << param.kernel << " while padding is "
+    << param.pad;
+  CHECK_EQ(param.kernel.ndim(), param.adj.ndim())
+    << "Adjustment must have the same number of dimensions as kernel_size, "
+    << "but kernel_size is set to " << param.kernel << " while adjustment is "
+    << param.adj;
+
+  attrs.parsed = param;
+  attrs.op = op;
+  SetAttrDict<op::DeconvolutionParam>(&attrs);
+  int num_outputs = 0;
+  auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs.data(), &num_outputs, nullptr);
+  *ret = ndoutputs[0];
+});
+
+}  // namespace mxnet
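
The deconvolution handler reads the same positional layout as the convolution
handler, with ``adj`` and ``target_shape`` inserted after ``pad``, which
shifts every trailing scalar by two (``cudnn_tune`` at ``num_inputs + 10``,
``cudnn_off`` at ``num_inputs + 11``, ``layout`` at ``num_inputs + 12``);
``no_bias`` is still read as ``args[args_size - 4]``.
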
diff --git a/src/api/operator/numpy_extension/npx_pick_op.cc b/src/api/operator/numpy_extension/npx_pick_op.cc
new file mode 100644
index 0000000..423a91f
--- /dev/null
+++ b/src/api/operator/numpy_extension/npx_pick_op.cc
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file npx_pick_op.cc
+ * \brief Implementation of the API of functions in src/operator/tensor/broadcast_reduce_op.h
+ */
+#include <mxnet/api_registry.h>
+#include <mxnet/runtime/packed_func.h>
+#include "../utils.h"
+#include "../../../operator/tensor/broadcast_reduce_op.h"
+
+namespace mxnet {
+
+inline int String2PickMode(const std::string& s) {
+  using namespace op;
+  if (s == "wrap") {
+    return kWrap;
+  } else if (s == "clip") {
+    return kClip;
+  } else {
+    LOG(FATAL) << "unknown mode type " << s;
+  }
+  LOG(FATAL) << "should not reach here ";
+  return 0;
+}
+
+MXNET_REGISTER_API("_npx.pick")
+.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
+  using namespace runtime;
+  nnvm::NodeAttrs attrs;
+  const nnvm::Op* op = Op::Get("_npx_pick");
+  op::PickParam param;
+  // axis
+  if (args[2].type_code() == kNull) {
+    param.axis = dmlc::nullopt;
+  } else {
+    param.axis = args[2].operator int();
+  }
+  // mode
+  param.mode = String2PickMode(args[3].operator std::string());
+  // keepdims
+  if (args[4].type_code() == kNull) {
+    param.keepdims = false;
+  } else {
+    param.keepdims = args[4].operator bool();
+  }
+  attrs.parsed = param;
+  attrs.op = op;
+  SetAttrDict<op::PickParam>(&attrs);
+  // inputs
+  int num_inputs = 2;
+  std::vector<NDArray*> inputs;
+  inputs.reserve(num_inputs);
+  for (int i = 0; i < num_inputs; ++i) {
+    inputs.push_back(args[i].operator mxnet::NDArray*());
+  }
+  int num_outputs = 0;
+  auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs.data(), &num_outputs, nullptr);
+  *ret = ndoutputs[0];
+});
+
+}  // namespace mxnet
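
The pick handler's positional contract mirrors the wrapper call
``_api_internal.pick(data, index, axis, mode, keepdims)``. Note that ``axis``
and ``keepdims`` are null-checked while ``mode`` is converted unconditionally,
so callers must always pass a string::

    args[0] -> data      (NDArray)
    args[1] -> index     (NDArray)
    args[2] -> axis      (int, or None for picking from the flattened input)
    args[3] -> mode      ('clip' or 'wrap'; no null check)
    args[4] -> keepdims  (bool, or None meaning false)
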
diff --git a/src/operator/nn/convolution-inl.h b/src/operator/nn/convolution-inl.h
index 87c82c3..053bd5a 100644
--- a/src/operator/nn/convolution-inl.h
+++ b/src/operator/nn/convolution-inl.h
@@ -124,6 +124,73 @@ struct ConvolutionParam : public dmlc::Parameter<ConvolutionParam> {
            this->cudnn_off == other.cudnn_off &&
            this->layout == other.layout;
   }
+  std::string CudnnTune2String(int cudnn_tune) {
+    switch (cudnn_tune) {
+      case conv::kOff:
+        return "off";
+      case conv::kLimited:
+        return "limited_workspace";
+      case conv::kFastest:
+        return "fastest";
+      default:
+        LOG(FATAL) << "Unknown cudnn_tune enum " << cudnn_tune;
+    }
+    LOG(FATAL) << "should not reach here ";
+    return "";
+  }
+  std::string Layout2String(int layout) {
+    switch (layout) {
+      case mshadow::kNCW:
+        return "NCW";
+      case mshadow::kNCHW:
+        return "NCHW";
+      case mshadow::kNCDHW:
+        return "NCDHW";
+      case mshadow::kNHWC:
+        return "NHWC";
+      case mshadow::kNDHWC:
+        return "NDHWC";
+      default:
+        LOG(FATAL) << "Unknown layout enum " << layout;
+    }
+    LOG(FATAL) << "should not reach here ";
+    return "";
+  }
+  void SetAttrDict(std::unordered_map<std::string, std::string>* dict) {
+    std::ostringstream kernel_s, stride_s, dilate_s, pad_s,
+                       num_filter_s, num_group_s, workspace_s, no_bias_s,
+                       cudnn_tune_s, cudnn_off_s, layout_s;
+    kernel_s << kernel;
+    stride_s << stride;
+    dilate_s << dilate;
+    pad_s << pad;
+    num_filter_s << num_filter;
+    num_group_s << num_group;
+    workspace_s << workspace;
+    no_bias_s << no_bias;
+    cudnn_tune_s << cudnn_tune;
+    cudnn_off_s << cudnn_off;
+    layout_s << layout;
+    (*dict)["kernel"] = kernel_s.str();
+    (*dict)["stride"] = stride_s.str();
+    (*dict)["dilate"] = dilate_s.str();
+    (*dict)["pad"] = pad_s.str();
+    (*dict)["num_filter"] = num_filter_s.str();
+    (*dict)["num_group"] = num_group_s.str();
+    (*dict)["workspace"] = workspace_s.str();
+    (*dict)["no_bias"] = no_bias_s.str();
+    if (cudnn_tune.has_value()) {
+      (*dict)["cudnn_tune"] = CudnnTune2String(cudnn_tune.value());
+    } else {
+      (*dict)["cudnn_tune"] = cudnn_tune_s.str();
+    }
+    (*dict)["cudnn_off"] = cudnn_off_s.str();
+    if (layout.has_value()) {
+      (*dict)["layout"] = Layout2String(layout.value());
+    } else {
+      (*dict)["layout"] = layout_s.str();
+    }
+  }
 };
 
 void ConvolutionParamParser(nnvm::NodeAttrs* attrs);
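
Illustratively (values assumed, not taken from the commit), for a call with
``kernel=(3,3)``, ``stride=(1,1)``, ``pad=(1,1)``, ``num_filter=8``, defaults
elsewhere, and the layout resolved to NCHW, ``SetAttrDict`` would fill the
dict roughly as::

    kernel="[3,3]"     stride="[1,1]"    dilate="[1,1]"    pad="[1,1]"
    num_filter="8"     num_group="1"     workspace="1024"  no_bias="0"
    cudnn_tune="None"  cudnn_off="0"     layout="NCHW"

where ``cudnn_tune`` falls back to the streamed optional ("None") because it
is unset, and ``layout`` goes through ``Layout2String`` because it is set.
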
diff --git a/src/operator/nn/deconvolution-inl.h b/src/operator/nn/deconvolution-inl.h
index f1e684e..c5578f5 100644
--- a/src/operator/nn/deconvolution-inl.h
+++ b/src/operator/nn/deconvolution-inl.h
@@ -170,6 +170,79 @@ struct DeconvolutionParam : public dmlc::Parameter<DeconvolutionParam> {
            this->cudnn_off == other.cudnn_off &&
            this->layout == other.layout;
   }
+
+  std::string CudnnTune2String(int cudnn_tune) {
+    switch (cudnn_tune) {
+      case deconv::kOff:
+        return "off";
+      case deconv::kLimited:
+        return "limited_workspace";
+      case deconv::kFastest:
+        return "fastest";
+      default:
+        LOG(FATAL) << "Unknown cudnn_tune enum " << cudnn_tune;
+    }
+    LOG(FATAL) << "should not reach here ";
+    return "";
+  }
+  std::string Layout2String(int layout) {
+    switch (layout) {
+      case mshadow::kNCW:
+        return "NCW";
+      case mshadow::kNCHW:
+        return "NCHW";
+      case mshadow::kNCDHW:
+        return "NCDHW";
+      case mshadow::kNHWC:
+        return "NHWC";
+      case mshadow::kNDHWC:
+        return "NDHWC";
+      default:
+        LOG(FATAL) << "Unknown layout enum " << layout;
+    }
+    LOG(FATAL) << "should not reach here ";
+    return "";
+  }
+
+  void SetAttrDict(std::unordered_map<std::string, std::string>* dict) {
+    std::ostringstream kernel_s, stride_s, dilate_s, pad_s, adj_s,
+                       target_shape_s, num_filter_s, num_group_s, workspace_s,
+                       no_bias_s, cudnn_tune_s, cudnn_off_s, layout_s;
+    kernel_s << kernel;
+    stride_s << stride;
+    dilate_s << dilate;
+    pad_s << pad;
+    adj_s << adj;
+    target_shape_s << target_shape;
+    num_filter_s << num_filter;
+    num_group_s << num_group;
+    workspace_s << workspace;
+    no_bias_s << no_bias;
+    cudnn_tune_s << cudnn_tune;
+    cudnn_off_s << cudnn_off;
+    layout_s << layout;
+    (*dict)["kernel"] = kernel_s.str();
+    (*dict)["stride"] = stride_s.str();
+    (*dict)["dilate"] = dilate_s.str();
+    (*dict)["pad"] = pad_s.str();
+    (*dict)["adj"] = adj_s.str();
+    (*dict)["target_shape"] = target_shape_s.str();
+    (*dict)["num_filter"] = num_filter_s.str();
+    (*dict)["num_group"] = num_group_s.str();
+    (*dict)["workspace"] = workspace_s.str();
+    (*dict)["no_bias"] = no_bias_s.str();
+    if (cudnn_tune.has_value()) {
+      (*dict)["cudnn_tune"] = CudnnTune2String(cudnn_tune.value());
+    } else {
+      (*dict)["cudnn_tune"] = cudnn_tune_s.str();
+    }
+    (*dict)["cudnn_off"] = cudnn_off_s.str();
+    if (layout.has_value()) {
+      (*dict)["layout"] = Layout2String(layout.value());
+    } else {
+      (*dict)["layout"] = layout_s.str();
+    }
+  }
 };
 
 typedef ParamOpSign<DeconvolutionParam> DeconvSignature;
diff --git a/src/operator/tensor/broadcast_reduce_op.h b/src/operator/tensor/broadcast_reduce_op.h
index c4e3dae..9834a88 100644
--- a/src/operator/tensor/broadcast_reduce_op.h
+++ b/src/operator/tensor/broadcast_reduce_op.h
@@ -142,6 +142,27 @@ struct PickParam : public dmlc::Parameter<PickParam> {
               " they are replaced by the index that addresses the last element along an axis. "
               " \"wrap\" means to wrap around.");
   }
+  std::string PickMode2String(int mode) {
+    switch (mode) {
+      case kWrap:
+        return "wrap";
+      case kClip:
+        return "clip";
+      default:
+        LOG(FATAL) << "Unknown mode enum " << mode;
+    }
+    LOG(FATAL) << "should not reach here ";
+    return "";
+  }
+  void SetAttrDict(std::unordered_map<std::string, std::string>* dict) {
+    std::ostringstream axis_s, mode_s, keepdims_s;
+    axis_s << axis;
+    mode_s << mode;
+    keepdims_s << keepdims;
+    (*dict)["axis"] = axis_s.str();
+    (*dict)["mode"] = PickMode2String(mode);
+    (*dict)["keepdims"] = keepdims_s.str();
+  }
 };
 
 struct BroadcastAxesParam : public dmlc::Parameter<BroadcastAxesParam> {