Posted to commits@mxnet.apache.org by la...@apache.org on 2021/03/26 18:31:22 UTC

[incubator-mxnet] branch master updated: [FFI] npx.softmax, npx.activation, npx.batch_norm, npx.fully_connected (#20087)

This is an automated email from the ASF dual-hosted git repository.

lausen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 9645e63  [FFI] npx.softmax, npx.activation, npx.batch_norm, npx.fully_connected (#20087)
9645e63 is described below

commit 9645e63de706911a14d09dae52f3de573c74d633
Author: barry-jin <69...@users.noreply.github.com>
AuthorDate: Fri Mar 26 11:27:32 2021 -0700

    [FFI] npx.softmax, npx.activation, npx.batch_norm, npx.fully_connected (#20087)
---
 python/mxnet/base.py                               |   5 +-
 .../numpy_extension/{_op.py => _api_internal.py}   |   7 +-
 python/mxnet/ndarray/numpy_extension/_op.py        | 401 ++++++++++++++++++++-
 python/mxnet/numpy_extension/_op.py                | 369 ++++++++++++++++++-
 .../operator/numpy_extension/npx_activation_op.cc  |  69 ++++
 .../operator/numpy_extension/npx_batch_norm_op.cc  |  87 +++++
 .../numpy_extension/npx_fully_connected_op.cc      |  66 ++++
 src/api/operator/numpy_extension/npx_softmax_op.cc | 136 +++++++
 src/operator/nn/activation-inl.h                   |  23 ++
 src/operator/nn/batch_norm-inl.h                   |  22 ++
 src/operator/nn/fully_connected-inl.h              |   9 +
 src/operator/nn/softmax-inl.h                      |  15 +
 12 files changed, 1204 insertions(+), 5 deletions(-)

diff --git a/python/mxnet/base.py b/python/mxnet/base.py
index 15db63e..fa13020 100644
--- a/python/mxnet/base.py
+++ b/python/mxnet/base.py
@@ -794,6 +794,9 @@ _NP_OP_IMPLEMENTED_SET = {'_np_reshape'}
 
 _NP_EXT_OP_PREFIX = '_npx_'
 _NP_EXT_OP_SUBMODULE_LIST = ['_image_', '_random_']
+_NP_EXT_OP_IMPLEMENTED_SET = {'_npx_softmax', '_npx_log_softmax', '_npx_masked_softmax',
+                              '_npx_masked_log_softmax', '_npx_activation',
+                              '_npx_batch_norm', '_npx_fully_connected'}
 
 _NP_INTERNAL_OP_PREFIX = '_npi_'
 
@@ -855,7 +858,7 @@ def _init_np_op_module(root_module_name, np_module_name, mx_module_name, make_op
     elif np_module_name == 'numpy_extension':
         op_name_prefix = _NP_EXT_OP_PREFIX
         submodule_name_list = _NP_EXT_OP_SUBMODULE_LIST
-        op_implemented_set = set()
+        op_implemented_set = _NP_EXT_OP_IMPLEMENTED_SET
     elif np_module_name == 'numpy._internal':
         op_name_prefix = _NP_INTERNAL_OP_PREFIX
         submodule_name_list = []
diff --git a/python/mxnet/ndarray/numpy_extension/_op.py b/python/mxnet/ndarray/numpy_extension/_api_internal.py
similarity index 84%
copy from python/mxnet/ndarray/numpy_extension/_op.py
copy to python/mxnet/ndarray/numpy_extension/_api_internal.py
index 22738a0..b7b2216 100644
--- a/python/mxnet/ndarray/numpy_extension/_op.py
+++ b/python/mxnet/ndarray/numpy_extension/_api_internal.py
@@ -15,7 +15,10 @@
 # specific language governing permissions and limitations
 # under the License.
 
-"""Namespace for the operators not belonging to the official numpy package
-used in Gluon dispatched by F=ndarray module."""
+"""Namespace for numpy_extension api."""
+
+from ..._ffi.function import _init_api
 
 __all__ = []
+
+_init_api("_npx", "mxnet.ndarray.numpy_extension._api_internal")
diff --git a/python/mxnet/ndarray/numpy_extension/_op.py b/python/mxnet/ndarray/numpy_extension/_op.py
index 22738a0..8ada24f 100644
--- a/python/mxnet/ndarray/numpy_extension/_op.py
+++ b/python/mxnet/ndarray/numpy_extension/_op.py
@@ -18,4 +18,403 @@
 """Namespace for the operators not belonging to the official numpy package
 used in Gluon dispatched by F=ndarray module."""
 
-__all__ = []
+import numpy as _np
+from .. import numpy as np  # pylint: disable=reimported
+from . import _api_internal
+from ...util import set_module
+
+
+__all__ = ['softmax', 'log_softmax', 'masked_softmax', 'masked_log_softmax',
+           'activation', 'batch_norm', 'fully_connected']
+
+
+# pylint: disable=too-many-arguments
+@set_module('mxnet.ndarray.numpy_extension')
+def softmax(data, axis=-1, length=None, temperature=None, use_length=False, dtype=None):
+    r"""Applies the softmax function.
+
+    The resulting array contains elements in the range (0,1) and the elements along the given axis sum up to 1.
+
+    .. math::
+       softmax(\mathbf{z/t})_j = \frac{e^{z_j/t}}{\sum_{k=1}^K e^{z_k/t}}
+
+    for :math:`j = 1, ..., K`
+
+    t is the temperature parameter in the softmax function. By default, t equals 1.0.
+
+    Parameters
+    ----------
+    data : NDArray
+        The input array.
+    axis : int, optional, default='-1'
+        The axis along which to compute softmax.
+    length : NDArray
+        The length array.
+    temperature : double or None, optional, default=None
+        Temperature parameter in softmax
+    dtype : {None, 'float16', 'float32', 'float64'}, optional, default='None'
+        DType of the output in case this can't be inferred. Defaults to
+        the same as input's dtype if not defined (dtype=None).
+    use_length : boolean or None, optional, default=0
+        Whether to use the length input as a mask over the data input.
+
+    Returns
+    -------
+    out : NDArray or list of NDArrays
+        The output of this function.
+
+    Example
+    -------
+    >>> data = np.ones((2, 3))
+    >>> npx.softmax(data, axis=0)
+    array([[0.5, 0.5, 0.5],
+        [0.5, 0.5, 0.5]])
+    >>> npx.softmax(data, axis=1)
+    array([[0.33333334, 0.33333334, 0.33333334],
+        [0.33333334, 0.33333334, 0.33333334]])
+    """
+    if dtype and not isinstance(dtype, str):
+        dtype = _np.dtype(dtype).name
+    if use_length:
+        assert length is not None, "Missing length input"
+        return _api_internal.softmax(data, length, axis, temperature, True, dtype)
+    else:
+        assert length is None, "Length input is not used"
+        return _api_internal.softmax(data, axis, temperature, False, dtype)
+
+
+# pylint: disable=too-many-arguments
+@set_module('mxnet.ndarray.numpy_extension')
+def log_softmax(data, axis=-1, length=None, temperature=None, use_length=False, dtype=None):
+    r"""Computes the log softmax of the input.
+    This is equivalent to computing softmax followed by log.
+
+    Parameters
+    ----------
+    data : NDArray
+        The input array.
+    axis : int, optional, default='-1'
+        The axis along which to compute softmax.
+    length : NDArray
+        The length array.
+    temperature : double or None, optional, default=None
+        Temperature parameter in softmax
+    dtype : {None, 'float16', 'float32', 'float64'}, optional, default='None'
+        DType of the output in case this can't be inferred. Defaults to
+        the same as input's dtype if not defined (dtype=None).
+    use_length : boolean or None, optional, default=0
+        Whether to use the length input as a mask over the data input.
+
+    Returns
+    -------
+    out : NDArray or list of NDArrays
+        The output of this function.
+
+    Examples
+    --------
+    >>> data = np.array([1, 2, .1])
+    >>> npx.log_softmax(data)
+    array([-1.4170278, -0.4170278, -2.3170278])
+    >>> data = np.array([[1, 2, .1],[.1, 2, 1]])
+    >>> npx.log_softmax(data, axis=0)
+    array([[-0.34115386, -0.6931472 , -1.2411538 ],
+        [-1.2411538 , -0.6931472 , -0.34115386]])
+    """
+    if dtype and not isinstance(dtype, str):
+        dtype = _np.dtype(dtype).name
+    if use_length:
+        assert length is not None, "Missing length input"
+        return _api_internal.log_softmax(data, length, axis, temperature, True, dtype)
+    else:
+        assert length is None, "Length input is not used"
+        return _api_internal.log_softmax(data, axis, temperature, False, dtype)
+
+
+# pylint: disable=too-many-arguments
+@set_module('mxnet.ndarray.numpy_extension')
+def masked_softmax(data, mask, axis=-1, temperature=1.0, dtype=None):
+    r"""Applies the softmax function masking elements according to the mask provided
+
+    Parameters
+    ----------
+    data : NDArray
+        The input array.
+    mask : NDArray
+        Mask to apply.
+    axis : int, optional, default='-1'
+        The axis along which to compute softmax.
+    temperature : double, optional, default=1.0
+        Temperature parameter in softmax
+    dtype : {None, 'float16', 'float32', 'float64'}, optional, default='None'
+        DType of the output in case this can't be inferred. Defaults to
+        the same as input's dtype if not defined (dtype=None).
+
+    Returns
+    -------
+    out : NDArray or list of NDArrays
+        The output of this function.
+
+    Examples
+    --------
+    >>> data = np.arange(5)
+    >>> mask = np.array([1, 0, 1, 0, 1])
+    >>> npx.masked_softmax(data, mask)
+    array([0.01587624, 0.        , 0.11731042, 0.        , 0.8668133 ])
+    >>> data = np.arange(10).reshape((2, 5))
+    >>> npx.masked_softmax(data, mask, axis=0)
+    array([[0.00669285, 0.        , 0.00669285, 0.        , 0.00669285],
+        [0.9933072 , 0.        , 0.9933072 , 0.        , 0.9933072 ]])
+    """
+    if mask is not None:
+        neg = -1e18
+        if _np.dtype(dtype) == _np.float16:
+            neg = -1e4
+        data = np.where(mask, data, neg)
+        logits = (softmax(data, axis=axis) / temperature) * mask
+    else:
+        logits = softmax(data, axis=axis) / temperature
+    return logits
+
+
+# pylint: disable=too-many-arguments
+@set_module('mxnet.ndarray.numpy_extension')
+def masked_log_softmax(data, mask, axis=-1, temperature=1.0, dtype=None):
+    r"""Computes the masked log softmax of the input.
+    This is equivalent to computing masked softmax followed by log.
+
+    Parameters
+    ----------
+    data : NDArray
+        The input array.
+    mask : NDArray
+        Mask to apply.
+    axis : int, optional, default='-1'
+        The axis along which to compute softmax.
+    temperature : double, optional, default=1.0
+        Temperature parameter in softmax
+    dtype : {None, 'float16', 'float32', 'float64'}, optional, default='None'
+        DType of the output in case this can't be inferred. Defaults to
+        the same as input's dtype if not defined (dtype=None).
+
+    Returns
+    -------
+    out : NDArray or list of NDArrays
+        The output of this function.
+
+    Examples
+    --------
+    >>> data = np.arange(5)
+    >>> mask = np.array([1, 0, 1, 0, 1])
+    >>> npx.masked_log_softmax(data, mask)
+    array([-4.1429286 ,        -inf, -2.1429286 ,        -inf, -0.14292854])
+    >>> data = np.arange(10).reshape((2, 5))
+    >>> npx.masked_log_softmax(data, mask, axis=0)
+    array([[-5.0067153 ,        -inf, -5.0067153 ,        -inf, -5.0067153 ],
+       [-0.00671535,        -inf, -0.00671535,        -inf, -0.00671535]])
+    """
+    if mask is not None:
+        neg = -1e18
+        inf = -_np.inf
+        if _np.dtype(dtype) == _np.float16:
+            neg = -1e4
+        data = np.where(mask, data, neg)
+        logits = np.where(mask, log_softmax(data, axis=axis) / temperature, inf)
+    else:
+        logits = log_softmax(data, axis=axis) / temperature
+    return logits
+
+
+# pylint: disable=too-many-arguments, unused-argument
+@set_module('mxnet.ndarray.numpy_extension')
+def activation(data, act_type='relu', **kwargs):
+    r"""Applies an activation function element-wise to the input.
+
+    The following activation functions are supported:
+
+    - `relu`: Rectified Linear Unit, :math:`y = max(x, 0)`
+    - `sigmoid`: :math:`y = \frac{1}{1 + exp(-x)}`
+    - `tanh`: Hyperbolic tangent, :math:`y = \frac{exp(x) - exp(-x)}{exp(x) + exp(-x)}`
+    - `softrelu`: Soft ReLU, or SoftPlus, :math:`y = log(1 + exp(x))`
+    - `softsign`: :math:`y = \frac{x}{1 + abs(x)}`
+
+    Parameters
+    ----------
+    data : NDArray
+        The input array.
+    act_type : {'relu', 'sigmoid', 'softrelu', 'softsign', 'tanh'}, required
+        Activation function to be applied.
+
+    Returns
+    -------
+    out : NDArray or list of NDArrays
+        The output of this function.
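+
+    Example
+    -------
+    >>> # A minimal illustrative sketch; input values chosen arbitrarily.
+    >>> data = np.array([[-1., 4.], [2., -3.]])
+    >>> npx.activation(data, act_type='relu')
+    array([[0., 4.],
+        [2., 0.]])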
+    """
+    return _api_internal.activation(data, act_type)
+
+
+# pylint: disable=too-many-arguments, unused-argument
+@set_module('mxnet.ndarray.numpy_extension')
+def batch_norm(x, gamma, beta, running_mean, running_var, eps=1e-3, momentum=0.9,
+               fix_gamma=True, use_global_stats=False, output_mean_var=False, axis=1,
+               cudnn_off=False, min_calib_range=None, max_calib_range=None, **kwargs):
+    r"""Batch normalization.
+
+    Normalizes a data batch by mean and variance, and applies a scale ``gamma`` as
+    well as offset ``beta``.
+
+    Assume the input has more than one dimension and we normalize along axis 1.
+    We first compute the mean and variance along this axis:
+
+    .. math::
+
+      data\_mean[i] = mean(data[:,i,:,...]) \\
+      data\_var[i] = var(data[:,i,:,...])
+
+    Then compute the normalized output, which has the same shape as the input, as follows:
+
+    .. math::
+
+      out[:,i,:,...] = \frac{data[:,i,:,...] - data\_mean[i]}{\sqrt{data\_var[i]+\epsilon}} * gamma[i] + beta[i]
+
+    Both *mean* and *var* return a scalar by treating the input as a vector.
+
+    Assuming the input has size *k* on axis 1, both ``gamma`` and ``beta``
+    have shape *(k,)*. If ``output_mean_var`` is set to true, the operator also outputs ``data_mean``
+    and the inverse of ``data_var``, which are needed for the backward pass. Note that the gradients
+    of these two outputs are blocked.
+
+    Besides the inputs and the outputs, this operator accepts two auxiliary
+    states, ``moving_mean`` and ``moving_var``, which are *k*-length
+    vectors. They are global statistics for the whole dataset, which are updated
+    by::
+
+      moving_mean = moving_mean * momentum + data_mean * (1 - momentum)
+      moving_var = moving_var * momentum + data_var * (1 - momentum)
+
+    If ``use_global_stats`` is set to be true, then ``moving_mean`` and
+    ``moving_var`` are used instead of ``data_mean`` and ``data_var`` to compute
+    the output. It is often used during inference.
+
+    The parameter ``axis`` specifies which axis of the input shape denotes
+    the 'channel' (separately normalized groups).  The default is 1.  Specifying -1 sets the channel
+    axis to be the last item in the input shape.
+
+    Both ``gamma`` and ``beta`` are learnable parameters. However, if ``fix_gamma`` is true,
+    ``gamma`` is set to 1 and its gradient to 0.
+
+    .. Note::
+      When ``fix_gamma`` is set to True, no sparse support is provided. If ``fix_gamma`` is set to False,
+      the sparse tensors will fall back.
+
+    Parameters
+    ----------
+    x : NDArray
+        Input data to batch normalization
+    gamma : NDArray
+        gamma array
+    beta : NDArray
+        beta array
+    running_mean : NDArray
+        running mean of input
+    running_var : NDArray
+        running variance of input
+    eps : double, optional, default=1e-3
+        Epsilon to prevent division by 0. Must be no less than CUDNN_BN_MIN_EPSILON
+        defined in cudnn.h when using cudnn (usually 1e-5)
+    momentum : float, optional, default=0.9
+        Momentum for moving average
+    fix_gamma : boolean, optional, default=1
+        Fix gamma while training
+    use_global_stats : boolean, optional, default=0
+        Whether to use global moving statistics instead of local batch statistics.
+        This will force batch norm into a scale-shift operator.
+    output_mean_var : boolean, optional, default=0
+        Output the mean and inverse std
+    axis : int, optional, default='1'
+        The axis of the input shape that denotes the channel
+    cudnn_off : boolean, optional, default=0
+        Do not select CUDNN operator, if available
+    min_calib_range : float or None, optional, default=None
+        The minimum scalar value in the form of float32 obtained through calibration.
+        If present, it will be used by the quantized batch norm op to calculate the primitive scale.
+        Note: this calib_range is used to calibrate the bn output.
+    max_calib_range : float or None, optional, default=None
+        The maximum scalar value in the form of float32 obtained through calibration.
+        If present, it will be used by the quantized batch norm op to calculate the primitive scale.
+        Note: this calib_range is used to calibrate the bn output.
+
+    Returns
+    -------
+    out : NDArray or list of NDArrays
+        The output of this function.
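+
+    Example
+    -------
+    >>> # A minimal illustrative sketch; shapes and values chosen arbitrarily.
+    >>> x = np.random.uniform(size=(2, 3, 4, 4))
+    >>> gamma = np.ones((3,))
+    >>> beta = np.zeros((3,))
+    >>> running_mean = np.zeros((3,))
+    >>> running_var = np.ones((3,))
+    >>> out = npx.batch_norm(x, gamma, beta, running_mean, running_var)
+    >>> # the normalized output keeps the input shape, here (2, 3, 4, 4)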
+    """
+    return _api_internal.batch_norm(x, gamma, beta, running_mean, running_var, eps, momentum,
+                                    fix_gamma, use_global_stats, output_mean_var, axis,
+                                    cudnn_off, min_calib_range, max_calib_range)
+
+
+# pylint: disable=too-many-arguments, unused-argument
+@set_module('mxnet.ndarray.numpy_extension')
+def fully_connected(x, weight, bias=None, num_hidden=None,
+                    no_bias=True, flatten=True, **kwargs):
+    r"""Applies a linear transformation: :math:`Y = XW^T + b`.
+
+    If ``flatten`` is set to be true, then the shapes are:
+
+    - **data**: `(batch_size, x1, x2, ..., xn)`
+    - **weight**: `(num_hidden, x1 * x2 * ... * xn)`
+    - **bias**: `(num_hidden,)`
+    - **out**: `(batch_size, num_hidden)`
+
+    If ``flatten`` is set to be false, then the shapes are:
+
+    - **data**: `(x1, x2, ..., xn, input_dim)`
+    - **weight**: `(num_hidden, input_dim)`
+    - **bias**: `(num_hidden,)`
+    - **out**: `(x1, x2, ..., xn, num_hidden)`
+
+    The learnable parameters include both ``weight`` and ``bias``.
+
+    If ``no_bias`` is set to be true, then the ``bias`` term is ignored.
+
+    .. Note::
+
+        The sparse support for FullyConnected is limited to forward evaluation with `row_sparse`
+        weight and bias, where the length of `weight.indices` and `bias.indices` must be equal
+        to `num_hidden`. This could be useful for model inference with `row_sparse` weights
+        trained with importance sampling or noise contrastive estimation.
+
+        To compute linear transformation with 'csr' sparse data, sparse.dot is recommended instead
+        of sparse.FullyConnected.
+
+    Parameters
+    ----------
+    x : NDArray
+        Input data.
+    weight : NDArray
+        Weight matrix.
+    bias : NDArray
+        Bias parameter.
+    num_hidden : int, required
+        Number of hidden nodes of the output.
+    no_bias : boolean, optional, default=True
+        Whether to disable the bias parameter.
+    flatten : boolean, optional, default=True
+        Whether to collapse all but the first axis of the input data tensor.
+
+    Returns
+    -------
+    out : NDArray or list of NDArrays
+        The output of this function.
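+
+    Example
+    -------
+    >>> # A minimal illustrative sketch; shapes chosen arbitrarily.
+    >>> x = np.ones((2, 3))
+    >>> weight = np.ones((4, 3))
+    >>> out = npx.fully_connected(x, weight, num_hidden=4)
+    >>> out.shape
+    (2, 4)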
+    """
+    assert num_hidden is not None, "Please provide number of hidden nodes"
+    if no_bias:
+        return _api_internal.fully_connected(x, weight, num_hidden, no_bias, flatten)
+    else:
+        assert bias is not None, "Missing bias parameter"
+        return _api_internal.fully_connected(x, weight, bias, num_hidden,
+                                             no_bias, flatten)
diff --git a/python/mxnet/numpy_extension/_op.py b/python/mxnet/numpy_extension/_op.py
index a995e48..d168af6 100644
--- a/python/mxnet/numpy_extension/_op.py
+++ b/python/mxnet/numpy_extension/_op.py
@@ -17,4 +17,371 @@
 
 """Namespace for registering numpy_extension ops for imperative programming."""
 
-__all__ = []
+from ..ndarray import numpy_extension as _mx_nd_npx
+from ..util import set_module
+
+
+__all__ = ['softmax', 'log_softmax', 'masked_softmax', 'masked_log_softmax',
+           'activation', 'batch_norm', 'fully_connected']
+
+
+# pylint: disable=too-many-arguments
+@set_module('mxnet.numpy_extension')
+def softmax(data, axis=-1, length=None, temperature=None, use_length=False, dtype=None):
+    r"""Applies the softmax function.
+
+    The resulting array contains elements in the range (0,1) and the elements along the given axis sum up to 1.
+
+    .. math::
+       softmax(\mathbf{z/t})_j = \frac{e^{z_j/t}}{\sum_{k=1}^K e^{z_k/t}}
+
+    for :math:`j = 1, ..., K`
+
+    t is the temperature parameter in the softmax function. By default, t equals 1.0.
+
+    Parameters
+    ----------
+    data : NDArray
+        The input array.
+    axis : int, optional, default='-1'
+        The axis along which to compute softmax.
+    length : NDArray
+        The length array.
+    temperature : double or None, optional, default=None
+        Temperature parameter in softmax
+    dtype : {None, 'float16', 'float32', 'float64'}, optional, default='None'
+        DType of the output in case this can't be inferred. Defaults to
+        the same as input's dtype if not defined (dtype=None).
+    use_length : boolean or None, optional, default=0
+        Whether to use the length input as a mask over the data input.
+
+    Returns
+    -------
+    out : NDArray or list of NDArrays
+        The output of this function.
+
+    Example
+    -------
+    >>> data = np.ones((2, 3))
+    >>> npx.softmax(data, axis=0)
+    array([[0.5, 0.5, 0.5],
+        [0.5, 0.5, 0.5]])
+    >>> npx.softmax(data, axis=1)
+    array([[0.33333334, 0.33333334, 0.33333334],
+        [0.33333334, 0.33333334, 0.33333334]])
+    """
+    return _mx_nd_npx.softmax(data, axis=axis, length=length, temperature=temperature,
+                              use_length=use_length, dtype=dtype)
+
+
+# pylint: disable=too-many-arguments
+@set_module('mxnet.numpy_extension')
+def log_softmax(data, axis=-1, length=None, temperature=None, use_length=False, dtype=None):
+    r"""Computes the log softmax of the input.
+    This is equivalent to computing softmax followed by log.
+
+    Parameters
+    ----------
+    data : NDArray
+        The input array.
+    axis : int, optional, default='-1'
+        The axis along which to compute softmax.
+    length : NDArray
+        The length array.
+    temperature : double or None, optional, default=None
+        Temperature parameter in softmax
+    dtype : {None, 'float16', 'float32', 'float64'}, optional, default='None'
+        DType of the output in case this can't be inferred. Defaults to
+        the same as input's dtype if not defined (dtype=None).
+    use_length : boolean or None, optional, default=0
+        Whether to use the length input as a mask over the data input.
+
+    Returns
+    -------
+    out : NDArray or list of NDArrays
+        The output of this function.
+
+    Examples
+    --------
+    >>> data = np.array([1, 2, .1])
+    >>> npx.log_softmax(data)
+    array([-1.4170278, -0.4170278, -2.3170278])
+    >>> data = np.array([[1, 2, .1],[.1, 2, 1]])
+    >>> npx.log_softmax(data, axis=0)
+    array([[-0.34115386, -0.6931472 , -1.2411538 ],
+        [-1.2411538 , -0.6931472 , -0.34115386]])
+    """
+    return _mx_nd_npx.log_softmax(data, axis=axis, length=length, temperature=temperature,
+                                  use_length=use_length, dtype=dtype)
+
+
+# pylint: disable=too-many-arguments
+@set_module('mxnet.numpy_extension')
+def masked_softmax(data, mask, axis=-1, temperature=1.0, dtype=None):
+    r"""Applies the softmax function masking elements according to the mask provided
+
+    Parameters
+    ----------
+    data : NDArray
+        The input array.
+    mask : NDArray
+        Mask to apply.
+    axis : int, optional, default='-1'
+        The axis along which to compute softmax.
+    temperature : double, optional, default=1.0
+        Temperature parameter in softmax
+    dtype : {None, 'float16', 'float32', 'float64'}, optional, default='None'
+        DType of the output in case this can't be inferred. Defaults to
+        the same as input's dtype if not defined (dtype=None).
+
+    Returns
+    -------
+    out : NDArray or list of NDArrays
+        The output of this function.
+
+    Examples
+    --------
+    >>> data = np.arange(5)
+    >>> mask = np.array([1, 0, 1, 0, 1])
+    >>> npx.masked_softmax(data, mask)
+    array([0.01587624, 0.        , 0.11731042, 0.        , 0.8668133 ])
+    >>> data = np.arange(10).reshape((2, 5))
+    >>> npx.masked_softmax(data, mask, axis=0)
+    array([[0.00669285, 0.        , 0.00669285, 0.        , 0.00669285],
+        [0.9933072 , 0.        , 0.9933072 , 0.        , 0.9933072 ]])
+    """
+    return _mx_nd_npx.masked_softmax(data, mask, axis=axis, temperature=temperature,
+                                     dtype=dtype)
+
+
+# pylint: disable=too-many-arguments
+@set_module('mxnet.numpy_extension')
+def masked_log_softmax(data, mask, axis=-1, temperature=1.0, dtype=None):
+    r"""Computes the masked log softmax of the input.
+    This is equivalent to computing masked softmax followed by log.
+
+    Parameters
+    ----------
+    data : NDArray
+        The input array.
+    mask : NDArray
+        Mask to apply.
+    axis : int, optional, default='-1'
+        The axis along which to compute softmax.
+    temperature : double, optional, default=1.0
+        Temperature parameter in softmax
+    dtype : {None, 'float16', 'float32', 'float64'}, optional, default='None'
+        DType of the output in case this can't be inferred. Defaults to
+        the same as input's dtype if not defined (dtype=None).
+
+    Returns
+    -------
+    out : NDArray or list of NDArrays
+        The output of this function.
+
+    Examples
+    --------
+    >>> data = np.arange(5)
+    >>> mask = np.array([1, 0, 1, 0, 1])
+    >>> npx.masked_log_softmax(data, mask)
+    array([-4.1429286 ,        -inf, -2.1429286 ,        -inf, -0.14292854])
+    >>> data = np.arange(10).reshape((2, 5))
+    >>> npx.masked_log_softmax(data, mask, axis=0)
+    array([[-5.0067153 ,        -inf, -5.0067153 ,        -inf, -5.0067153 ],
+       [-0.00671535,        -inf, -0.00671535,        -inf, -0.00671535]])
+    """
+    return _mx_nd_npx.masked_log_softmax(data, mask, axis=axis, temperature=temperature,
+                                         dtype=dtype)
+
+
+# pylint: disable=too-many-arguments, unused-argument
+@set_module('mxnet.numpy_extension')
+def activation(data, act_type='relu', **kwargs):
+    r"""Applies an activation function element-wise to the input.
+
+    The following activation functions are supported:
+
+    - `relu`: Rectified Linear Unit, :math:`y = max(x, 0)`
+    - `sigmoid`: :math:`y = \frac{1}{1 + exp(-x)}`
+    - `tanh`: Hyperbolic tangent, :math:`y = \frac{exp(x) - exp(-x)}{exp(x) + exp(-x)}`
+    - `softrelu`: Soft ReLU, or SoftPlus, :math:`y = log(1 + exp(x))`
+    - `softsign`: :math:`y = \frac{x}{1 + abs(x)}`
+
+    Parameters
+    ----------
+    data : NDArray
+        The input array.
+    act_type : {'relu', 'sigmoid', 'softrelu', 'softsign', 'tanh'}, required
+        Activation function to be applied.
+
+    Returns
+    -------
+    out : NDArray or list of NDArrays
+        The output of this function.
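+
+    Example
+    -------
+    >>> # A minimal illustrative sketch; input values chosen arbitrarily.
+    >>> data = np.array([[-1., 4.], [2., -3.]])
+    >>> npx.activation(data, act_type='relu')
+    array([[0., 4.],
+        [2., 0.]])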
+    """
+    return _mx_nd_npx.activation(data, act_type=act_type)
+
+
+# pylint: disable=too-many-arguments, unused-argument
+@set_module('mxnet.numpy_extension')
+def batch_norm(x, gamma, beta, running_mean, running_var, eps=1e-3, momentum=0.9,
+               fix_gamma=True, use_global_stats=False, output_mean_var=False, axis=1,
+               cudnn_off=False, min_calib_range=None, max_calib_range=None, **kwargs):
+    r"""Batch normalization.
+
+    Normalizes a data batch by mean and variance, and applies a scale ``gamma`` as
+    well as offset ``beta``.
+
+    Assume the input has more than one dimension and we normalize along axis 1.
+    We first compute the mean and variance along this axis:
+
+    .. math::
+
+      data\_mean[i] = mean(data[:,i,:,...]) \\
+      data\_var[i] = var(data[:,i,:,...])
+
+    Then compute the normalized output, which has the same shape as the input, as follows:
+
+    .. math::
+
+      out[:,i,:,...] = \frac{data[:,i,:,...] - data\_mean[i]}{\sqrt{data\_var[i]+\epsilon}} * gamma[i] + beta[i]
+
+    Both *mean* and *var* return a scalar by treating the input as a vector.
+
+    Assuming the input has size *k* on axis 1, both ``gamma`` and ``beta``
+    have shape *(k,)*. If ``output_mean_var`` is set to true, the operator also outputs ``data_mean``
+    and the inverse of ``data_var``, which are needed for the backward pass. Note that the gradients
+    of these two outputs are blocked.
+
+    Besides the inputs and the outputs, this operator accepts two auxiliary
+    states, ``moving_mean`` and ``moving_var``, which are *k*-length
+    vectors. They are global statistics for the whole dataset, which are updated
+    by::
+
+      moving_mean = moving_mean * momentum + data_mean * (1 - momentum)
+      moving_var = moving_var * momentum + data_var * (1 - momentum)
+
+    If ``use_global_stats`` is set to be true, then ``moving_mean`` and
+    ``moving_var`` are used instead of ``data_mean`` and ``data_var`` to compute
+    the output. It is often used during inference.
+
+    The parameter ``axis`` specifies which axis of the input shape denotes
+    the 'channel' (separately normalized groups).  The default is 1.  Specifying -1 sets the channel
+    axis to be the last item in the input shape.
+
+    Both ``gamma`` and ``beta`` are learnable parameters. However, if ``fix_gamma`` is true,
+    ``gamma`` is set to 1 and its gradient to 0.
+
+    .. Note::
+      When ``fix_gamma`` is set to True, no sparse support is provided. If ``fix_gamma`` is set to False,
+      the sparse tensors will fall back.
+
+    Parameters
+    ----------
+    x : NDArray
+        Input data to batch normalization
+    gamma : NDArray
+        gamma array
+    beta : NDArray
+        beta array
+    running_mean : NDArray
+        running mean of input
+    running_var : NDArray
+        running variance of input
+    eps : double, optional, default=1e-3
+        Epsilon to prevent division by 0. Must be no less than CUDNN_BN_MIN_EPSILON
+        defined in cudnn.h when using cudnn (usually 1e-5)
+    momentum : float, optional, default=0.9
+        Momentum for moving average
+    fix_gamma : boolean, optional, default=1
+        Fix gamma while training
+    use_global_stats : boolean, optional, default=0
+        Whether to use global moving statistics instead of local batch statistics.
+        This will force batch norm into a scale-shift operator.
+    output_mean_var : boolean, optional, default=0
+        Output the mean and inverse std
+    axis : int, optional, default='1'
+        The axis of the input shape that denotes the channel
+    cudnn_off : boolean, optional, default=0
+        Do not select CUDNN operator, if available
+    min_calib_range : float or None, optional, default=None
+        The minimum scalar value in the form of float32 obtained through calibration.
+        If present, it will be used by the quantized batch norm op to calculate the primitive scale.
+        Note: this calib_range is used to calibrate the bn output.
+    max_calib_range : float or None, optional, default=None
+        The maximum scalar value in the form of float32 obtained through calibration.
+        If present, it will be used by the quantized batch norm op to calculate the primitive scale.
+        Note: this calib_range is used to calibrate the bn output.
+
+    Returns
+    -------
+    out : NDArray or list of NDArrays
+        The output of this function.
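+
+    Example
+    -------
+    >>> # A minimal illustrative sketch; shapes and values chosen arbitrarily.
+    >>> x = np.random.uniform(size=(2, 3, 4, 4))
+    >>> gamma = np.ones((3,))
+    >>> beta = np.zeros((3,))
+    >>> running_mean = np.zeros((3,))
+    >>> running_var = np.ones((3,))
+    >>> out = npx.batch_norm(x, gamma, beta, running_mean, running_var)
+    >>> # the normalized output keeps the input shape, here (2, 3, 4, 4)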
+    """
+    return _mx_nd_npx.batch_norm(x, gamma, beta, running_mean, running_var, eps=eps,
+                                 momentum=momentum, fix_gamma=fix_gamma,
+                                 use_global_stats=use_global_stats,
+                                 output_mean_var=output_mean_var, axis=axis, cudnn_off=cudnn_off,
+                                 min_calib_range=min_calib_range, max_calib_range=max_calib_range)
+
+
+# pylint: disable=too-many-arguments, unused-argument
+@set_module('mxnet.numpy_extension')
+def fully_connected(x, weight, bias=None, num_hidden=None,
+                    no_bias=True, flatten=True, **kwargs):
+    r"""Applies a linear transformation: :math:`Y = XW^T + b`.
+
+    If ``flatten`` is set to be true, then the shapes are:
+
+    - **data**: `(batch_size, x1, x2, ..., xn)`
+    - **weight**: `(num_hidden, x1 * x2 * ... * xn)`
+    - **bias**: `(num_hidden,)`
+    - **out**: `(batch_size, num_hidden)`
+
+    If ``flatten`` is set to be false, then the shapes are:
+
+    - **data**: `(x1, x2, ..., xn, input_dim)`
+    - **weight**: `(num_hidden, input_dim)`
+    - **bias**: `(num_hidden,)`
+    - **out**: `(x1, x2, ..., xn, num_hidden)`
+
+    The learnable parameters include both ``weight`` and ``bias``.
+
+    If ``no_bias`` is set to be true, then the ``bias`` term is ignored.
+
+    .. Note::
+
+        The sparse support for FullyConnected is limited to forward evaluation with `row_sparse`
+        weight and bias, where the length of `weight.indices` and `bias.indices` must be equal
+        to `num_hidden`. This could be useful for model inference with `row_sparse` weights
+        trained with importance sampling or noise contrastive estimation.
+
+        To compute linear transformation with 'csr' sparse data, sparse.dot is recommended instead
+        of sparse.FullyConnected.
+
+    Parameters
+    ----------
+    x : NDArray
+        Input data.
+    weight : NDArray
+        Weight matrix.
+    bias : NDArray
+        Bias parameter.
+    num_hidden : int, required
+        Number of hidden nodes of the output.
+    no_bias : boolean, optional, default=True
+        Whether to disable the bias parameter.
+    flatten : boolean, optional, default=True
+        Whether to collapse all but the first axis of the input data tensor.
+
+    Returns
+    -------
+    out : NDArray or list of NDArrays
+        The output of this function.
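+
+    Example
+    -------
+    >>> # A minimal illustrative sketch; shapes chosen arbitrarily.
+    >>> x = np.ones((2, 3))
+    >>> weight = np.ones((4, 3))
+    >>> out = npx.fully_connected(x, weight, num_hidden=4)
+    >>> out.shape
+    (2, 4)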
+    """
+    return _mx_nd_npx.fully_connected(x, weight, bias, num_hidden=num_hidden,
+                                      no_bias=no_bias, flatten=flatten)
diff --git a/src/api/operator/numpy_extension/npx_activation_op.cc b/src/api/operator/numpy_extension/npx_activation_op.cc
new file mode 100644
index 0000000..c072f6e
--- /dev/null
+++ b/src/api/operator/numpy_extension/npx_activation_op.cc
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file npx_activation_op.cc
+ * \brief Implementation of the API of functions in src/operator/numpy_extension/npx_activation_op.cc
+ */
+#include <mxnet/api_registry.h>
+#include <mxnet/runtime/packed_func.h>
+#include "../utils.h"
+#include "../../../operator/nn/activation-inl.h"
+
+namespace mxnet {
+
+inline int String2MXNetActType(const std::string& s) {
+  using namespace op;
+  if (s == "relu") {
+    return activation::kReLU;
+  } else if (s == "sigmoid") {
+    return activation::kSigmoid;
+  } else if (s == "tanh") {
+    return activation::kTanh;
+  } else if (s == "softrelu") {
+    return activation::kSoftReLU;
+  } else if (s == "softsign") {
+    return activation::kSoftSign;
+  } else {
+    LOG(FATAL) << "unknown activation type " << s;
+  }
+  LOG(FATAL) << "should not reach here ";
+  return 0;
+}
+
+MXNET_REGISTER_API("_npx.activation")
+.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
+  using namespace runtime;
+  nnvm::NodeAttrs attrs;
+  const nnvm::Op* op = Op::Get("_npx_activation");
+  op::ActivationParam param;
+  // act_type
+  param.act_type = String2MXNetActType(args[1].operator std::string());
+  attrs.parsed = param;
+  attrs.op = op;
+  SetAttrDict<op::ActivationParam>(&attrs);
+  // inputs
+  NDArray* inputs[] = {args[0].operator NDArray*()};
+  int num_inputs = 1;
+  int num_outputs = 0;
+  auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr);
+  *ret = ndoutputs[0];
+});
+
+}  // namespace mxnet
diff --git a/src/api/operator/numpy_extension/npx_batch_norm_op.cc b/src/api/operator/numpy_extension/npx_batch_norm_op.cc
new file mode 100644
index 0000000..dcf3ac4
--- /dev/null
+++ b/src/api/operator/numpy_extension/npx_batch_norm_op.cc
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file npx_batch_norm_op.cc
+ * \brief Implementation of the API of functions in src/operator/numpy_extension/npx_batch_norm_op.cc
+ */
+#include <mxnet/api_registry.h>
+#include <mxnet/runtime/packed_func.h>
+#include "../utils.h"
+#include "../../../operator/nn/batch_norm-inl.h"
+
+namespace mxnet {
+
+MXNET_REGISTER_API("_npx.batch_norm")
+.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
+  using namespace runtime;
+  nnvm::NodeAttrs attrs;
+  const nnvm::Op* op = Op::Get("_npx_batch_norm");
+  op::BatchNormParam param;
+  // eps
+  param.eps = args[5].operator double();
+  // momentum
+  param.momentum = args[6].operator double();
+  // fix_gamma
+  param.fix_gamma = args[7].operator bool();
+  // use_global_stats
+  param.use_global_stats = args[8].operator bool();
+  // output_mean_var
+  param.output_mean_var = args[9].operator bool();
+  // axis
+  param.axis = args[10].operator int();
+  // cudnn_off
+  param.cudnn_off = args[11].operator bool();
+  // min_calib_range
+  if (args[12].type_code() == kDLFloat || args[12].type_code() == kDLInt) {
+    param.min_calib_range = args[12].operator double();
+  } else {
+    param.min_calib_range = dmlc::nullopt;
+  }
+  // max_calib_range
+  if (args[13].type_code() == kDLFloat || args[13].type_code() == kDLInt) {
+    param.max_calib_range = args[13].operator double();
+  } else {
+    param.max_calib_range = dmlc::nullopt;
+  }
+  attrs.parsed = param;
+  attrs.op = op;
+  SetAttrDict<op::BatchNormParam>(&attrs);
+  // inputs
+  int num_inputs = 5;
+  std::vector<NDArray*> inputs;
+  inputs.reserve(num_inputs);
+  for (int i = 0; i < num_inputs; ++i) {
+    inputs.push_back(args[i].operator mxnet::NDArray*());
+  }
+  int num_outputs = 0;
+  auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs.data(), &num_outputs, nullptr);
+  if (num_outputs == 1) {
+    *ret = ndoutputs[0];
+  } else {
+    std::vector<NDArrayHandle> ndarray_handles;
+    ndarray_handles.reserve(num_outputs);
+    for (int i = 0; i < num_outputs; ++i) {
+      ndarray_handles.emplace_back(ndoutputs[i]);
+    }
+    *ret = ADT(0, ndarray_handles.begin(), ndarray_handles.end());
+  }
+});
+
+}  // namespace mxnet
diff --git a/src/api/operator/numpy_extension/npx_fully_connected_op.cc b/src/api/operator/numpy_extension/npx_fully_connected_op.cc
new file mode 100644
index 0000000..d9ab3c0
--- /dev/null
+++ b/src/api/operator/numpy_extension/npx_fully_connected_op.cc
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file npx_fully_connected_op.cc
+ * \brief Implementation of the API of functions in src/operator/numpy_extension/npx_fully_connected_op.cc
+ */
+#include <mxnet/api_registry.h>
+#include <mxnet/runtime/packed_func.h>
+#include "../utils.h"
+#include "../../../operator/nn/fully_connected-inl.h"
+
+namespace mxnet {
+
+MXNET_REGISTER_API("_npx.fully_connected")
+.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
+  using namespace runtime;
+  int args_size = args.size();
+  nnvm::NodeAttrs attrs;
+  const nnvm::Op* op = Op::Get("_npx_fully_connected");
+  op::FullyConnectedParam param;
+  // no_bias
+  param.no_bias = args[args_size - 2].operator bool();
+  // inputs
+  int num_inputs = 2;
+  if (param.no_bias) {
+    num_inputs = 2;
+  } else {
+    num_inputs = 3;
+  }
+  std::vector<NDArray*> inputs;
+  inputs.reserve(num_inputs);
+  for (int i = 0; i < num_inputs; ++i) {
+    inputs.push_back(args[i].operator mxnet::NDArray*());
+  }
+  // num_hidden
+  param.num_hidden = args[args_size - 3].operator int();
+  // flatten
+  param.flatten = args[args_size - 1].operator bool();
+
+  attrs.parsed = param;
+  attrs.op = op;
+  SetAttrDict<op::FullyConnectedParam>(&attrs);
+
+  int num_outputs = 0;
+  auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs.data(), &num_outputs, nullptr);
+  *ret = ndoutputs[0];
+});
+
+}  // namespace mxnet
diff --git a/src/api/operator/numpy_extension/npx_softmax_op.cc b/src/api/operator/numpy_extension/npx_softmax_op.cc
new file mode 100644
index 0000000..641129e
--- /dev/null
+++ b/src/api/operator/numpy_extension/npx_softmax_op.cc
@@ -0,0 +1,136 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file npx_softmax_op.cc
+ * \brief Implementation of the API of functions in src/operator/numpy_extension/npx_softmax_op.cc
+ */
+#include <mxnet/api_registry.h>
+#include <mxnet/runtime/packed_func.h>
+#include "../utils.h"
+#include "../../../operator/nn/softmax-inl.h"
+
+namespace mxnet {
+
+MXNET_REGISTER_API("_npx.softmax")
+.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
+  using namespace runtime;
+  nnvm::NodeAttrs attrs;
+  static const nnvm::Op* op = Op::Get("_npx_softmax");
+  op::SoftmaxParam param;
+  int args_size = args.size();
+  // inputs
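+  // The Python wrapper always passes the four scalar arguments (axis, temperature,
+  // use_length, dtype) after the NDArray inputs (data, or data and length), so the
+  // number of NDArray inputs is args_size - 4.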
+  int num_inputs = args_size - 4;
+  std::vector<NDArray*> inputs;
+  inputs.reserve(num_inputs);
+  for (int i = 0; i < num_inputs; ++i) {
+    inputs.push_back(args[i].operator mxnet::NDArray*());
+  }
+
+  // parse use_length
+  if (args[args_size - 2].type_code() == kNull) {
+    param.use_length = false;
+  } else {
+    param.use_length = args[args_size - 2].operator bool();
+  }
+
+  // parse axis
+  if (args[args_size - 4].type_code() == kDLInt) {
+    param.axis = args[args_size - 4].operator int();
+  } else {
+    param.axis = static_cast<int>(args[args_size - 4].operator double());
+  }
+
+  // parse temperature
+  if (args[args_size - 3].type_code() == kNull) {
+    param.temperature = dmlc::nullopt;
+  } else {
+    param.temperature = args[args_size - 3].operator int64_t();
+  }
+
+  // parse dtype
+  if (args[args_size - 1].type_code() == kNull) {
+    param.dtype = dmlc::nullopt;
+  } else {
+    param.dtype = String2MXNetTypeWithBool(args[args_size - 1].operator std::string());
+  }
+
+  attrs.parsed = param;
+  attrs.op = op;
+  SetAttrDict<op::SoftmaxParam>(&attrs);
+
+  int num_outputs = 0;
+  auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs.data(), &num_outputs, nullptr);
+  *ret = ndoutputs[0];
+});
+
+MXNET_REGISTER_API("_npx.log_softmax")
+.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
+  using namespace runtime;
+  nnvm::NodeAttrs attrs;
+  static const nnvm::Op* op = Op::Get("_npx_log_softmax");
+  op::SoftmaxParam param;
+
+  int args_size = args.size();
+  // inputs
+  int num_inputs = args_size - 4;
+  std::vector<NDArray*> inputs;
+  inputs.reserve(num_inputs);
+  for (int i = 0; i < num_inputs; ++i) {
+    inputs.push_back(args[i].operator mxnet::NDArray*());
+  }
+
+  // parse use_length
+  if (args[args_size - 2].type_code() == kNull) {
+    param.use_length = false;
+  } else {
+    param.use_length = args[args_size - 2].operator bool();
+  }
+
+  // parse axis
+  if (args[args_size - 4].type_code() == kDLInt) {
+    param.axis = args[args_size - 4].operator int();
+  } else {
+    param.axis = static_cast<int>(args[args_size - 4].operator double());
+  }
+
+  // parse temperature
+  if (args[args_size - 3].type_code() == kNull) {
+    param.temperature = dmlc::nullopt;
+  } else {
+    param.temperature = args[args_size - 3].operator int64_t();
+  }
+
+  // parse dtype
+  if (args[args_size - 1].type_code() == kNull) {
+    param.dtype = dmlc::nullopt;
+  } else {
+    param.dtype = String2MXNetTypeWithBool(args[args_size - 1].operator std::string());
+  }
+
+  attrs.parsed = param;
+  attrs.op = op;
+  SetAttrDict<op::SoftmaxParam>(&attrs);
+
+  int num_outputs = 0;
+  auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs.data(), &num_outputs, nullptr);
+  *ret = ndoutputs[0];
+});
+
+}  // namespace mxnet
diff --git a/src/operator/nn/activation-inl.h b/src/operator/nn/activation-inl.h
index 06ff1fe..1111464 100644
--- a/src/operator/nn/activation-inl.h
+++ b/src/operator/nn/activation-inl.h
@@ -69,6 +69,29 @@ struct ActivationParam : public dmlc::Parameter<ActivationParam> {
   bool operator==(const ActivationParam& other) const {
     return this->act_type == other.act_type;
   }
+  std::string MXNetActType2String(int act_type) {
+    switch (act_type) {
+      case activation::kReLU:
+        return "relu";
+      case activation::kSigmoid:
+        return "sigmoid";
+      case activation::kTanh:
+        return "tanh";
+      case activation::kSoftReLU:
+        return "softrelu";
+      case activation::kSoftSign:
+        return "softsign";
+      default:
+        LOG(FATAL) << "Unknown act_type enum " << act_type;
+    }
+    LOG(FATAL) << "should not reach here ";
+    return "";
+  }
+  void SetAttrDict(std::unordered_map<std::string, std::string>* dict) {
+    std::ostringstream act_type_s;
+    act_type_s << act_type;
+    (*dict)["act_type"] = MXNetActType2String(act_type);
+  }
 };
 
 }  // namespace op
diff --git a/src/operator/nn/batch_norm-inl.h b/src/operator/nn/batch_norm-inl.h
index 485b3b3..bb8313d 100644
--- a/src/operator/nn/batch_norm-inl.h
+++ b/src/operator/nn/batch_norm-inl.h
@@ -125,6 +125,28 @@ struct BatchNormParam : public dmlc::Parameter<BatchNormParam> {
     }
     return flag;
   }
+  void SetAttrDict(std::unordered_map<std::string, std::string>* dict) {
+    std::ostringstream eps_s, momentum_s, fix_gamma_s, use_global_stats_s, output_mean_var_s,
+                       axis_s, cudnn_off_s, min_calib_range_s, max_calib_range_s;
+    eps_s << eps;
+    momentum_s << momentum;
+    fix_gamma_s << fix_gamma;
+    use_global_stats_s << use_global_stats;
+    output_mean_var_s << output_mean_var;
+    axis_s << axis;
+    cudnn_off_s << cudnn_off;
+    min_calib_range_s << min_calib_range;
+    max_calib_range_s << max_calib_range;
+    (*dict)["eps"] = eps_s.str();
+    (*dict)["momentum"] = momentum_s.str();
+    (*dict)["fix_gamma"] = fix_gamma_s.str();
+    (*dict)["use_global_stats"] = use_global_stats_s.str();
+    (*dict)["output_mean_var"] = output_mean_var_s.str();
+    (*dict)["axis"] = axis_s.str();
+    (*dict)["cudnn_off"] = cudnn_off_s.str();
+    (*dict)["min_calib_range"] = min_calib_range_s.str();
+    (*dict)["max_calib_range"] = max_calib_range_s.str();
+  }
 };
 
 }  // namespace op
diff --git a/src/operator/nn/fully_connected-inl.h b/src/operator/nn/fully_connected-inl.h
index c90e8ce..51d6f5c 100644
--- a/src/operator/nn/fully_connected-inl.h
+++ b/src/operator/nn/fully_connected-inl.h
@@ -80,6 +80,15 @@ struct FullyConnectedParam : public dmlc::Parameter<FullyConnectedParam> {
            this->no_bias == other.no_bias &&
            this->flatten == other.flatten;
   }
+  void SetAttrDict(std::unordered_map<std::string, std::string>* dict) {
+    std::ostringstream num_hidden_s, no_bias_s, flatten_s;
+    num_hidden_s << num_hidden;
+    no_bias_s << no_bias;
+    flatten_s << flatten;
+    (*dict)["num_hidden"] = num_hidden_s.str();
+    (*dict)["no_bias"] = no_bias_s.str();
+    (*dict)["flatten"] = flatten_s.str();
+  }
 };
 
 /**
diff --git a/src/operator/nn/softmax-inl.h b/src/operator/nn/softmax-inl.h
index 512d8d2..7f64b74 100644
--- a/src/operator/nn/softmax-inl.h
+++ b/src/operator/nn/softmax-inl.h
@@ -1179,6 +1179,21 @@ struct SoftmaxParam : public dmlc::Parameter<SoftmaxParam> {
            this->dtype == other.dtype &&
            this->use_length == other.use_length;
   }
+  void SetAttrDict(std::unordered_map<std::string, std::string>* dict) {
+    std::ostringstream axis_s, temperature_s, dtype_s, use_length_s;
+    axis_s << axis;
+    temperature_s << temperature;
+    dtype_s << dtype;
+    use_length_s << use_length;
+    (*dict)["axis"] = axis_s.str();
+    (*dict)["temperature"] = temperature_s.str();
+    if (dtype.has_value()) {
+      (*dict)["dtype"] = MXNetTypeWithBool2String(dtype.value());
+    } else {
+      (*dict)["dtype"] = dtype_s.str();
+    }
+    (*dict)["use_length"] = use_length_s.str();
+  }
 };
 
 struct MaskedSoftmaxParam : public dmlc::Parameter<MaskedSoftmaxParam> {