You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ha...@apache.org on 2019/07/18 00:01:17 UTC

[incubator-mxnet] 41/42: Numpy Tensordot Operator (#15349)

This is an automated email from the ASF dual-hosted git repository.

haoj pushed a commit to branch numpy
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git

commit 64d49527ae5a1431e4eaf3ff7422ccb54d0dc6d9
Author: ckt624 <ck...@gmail.com>
AuthorDate: Wed Jul 17 10:49:23 2019 +0800

    Numpy Tensordot Operator  (#15349)
    
    * implements numpy tensordot
    
    * Fixed bugs and optimized backward operator
    
    * Rewrote tests
    
    * Debugging
    
    * Debugging 0-size input
    
    * Moved axis-reordering from frontend to backend
    
    * Added comments
    
    * integrated forward part
    
    * Add more tests
    
    * Fixed GPU bugs
    
    * Add comments to tensordot
    
    * Add empty lines.
    
    * Change tests
    
    * Remove redundant code
    
    * Change file names
    
    * Add numerical backward test
    
    * Change np.dot for case 5.
    
    * Remove spaces.
    
    * Remove more spaces.
    
    * Add header files.
    
    * Remove spaces in python interface
    
    * Refactored.
    
    * Changed interface.
    
    * changed GPU test.
    
    * Clean codes.
    
    * Change styles.
    
    * Remove blank lines.
    
    * Add blank lines
    
    * Recover lines.
    
    * Support python 2
    
    * Test Python 2
    
    * Add more tests
    
    * Add error msg
    
    * Change comments.
---
 python/mxnet/ndarray/numpy/_op.py        |  82 ++++-
 python/mxnet/numpy/multiarray.py         |  82 ++++-
 python/mxnet/symbol/numpy/_symbol.py     |  66 +++-
 src/operator/numpy/np_dot-inl.h          | 131 ++------
 src/operator/numpy/np_tensordot_op-inl.h | 556 +++++++++++++++++++++++++++++++
 src/operator/numpy/np_tensordot_op.cc    | 226 +++++++++++++
 src/operator/numpy/np_tensordot_op.cu    |  42 +++
 tests/python/unittest/test_numpy_op.py   | 161 ++++++++-
 8 files changed, 1227 insertions(+), 119 deletions(-)

diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py
index 76ed88c..1049bb1 100644
--- a/python/mxnet/ndarray/numpy/_op.py
+++ b/python/mxnet/ndarray/numpy/_op.py
@@ -33,7 +33,87 @@ __all__ = ['zeros', 'ones', 'maximum', 'minimum', 'stack', 'arange', 'argmax',
            'clip', 'split', 'swapaxes', 'expand_dims', 'tile', 'linspace', 'eye',
            'sin', 'cos', 'sinh', 'cosh', 'log10', 'sqrt', 'abs', 'exp', 'arctan', 'sign', 'log',
            'degrees', 'log2', 'rint', 'radians', 'mean', 'reciprocal', 'square', 'arcsin',
-           'argsort', 'hstack']
+           'argsort', 'hstack', 'tensordot']
+
+
+@set_module('mxnet.ndarray.numpy')
+def tensordot(a, b, axes=2):
+    r"""
+    tensordot(a, b, axes=2)
+
+    Compute tensor dot product along specified axes for arrays >= 1-D.
+
+    Given two tensors (arrays of dimension greater than or equal to one),
+    `a` and `b`, and a sequence containing two sequences of axes,
+    ``(a_axes, b_axes)``, sum the products of `a`'s and `b`'s
+    elements (components) over the axes specified by ``a_axes`` and
+    ``b_axes``. The third argument can be a single non-negative
+    integer_like scalar, ``N``; if it is such, then the last ``N``
+    dimensions of `a` and the first ``N`` dimensions of `b` are summed
+    over.
+
+    Parameters
+    ----------
+    a, b : ndarray, len(shape) >= 1
+        Tensors to "dot".
+
+    axes : int or (2,) sequence
+        * integer_like
+        If an int N, sum over the last N axes of `a` and the first N axes
+        of `b` in order. The sizes of the corresponding axes must match.
+        * (2,) sequence
+        Or, a pair of axis lists to be summed over, the first applying to `a`,
+        the second to `b` (an int means one axis); both must have the same length.
+
+    See Also
+    --------
+    dot, einsum
+
+    Notes
+    -----
+    Three common use cases are:
+        * ``axes = 0`` : tensor product :math:`a\otimes b`
+        * ``axes = 1`` : tensor dot product :math:`a\cdot b`
+        * ``axes = 2`` : (default) tensor double contraction :math:`a:b`
+
+    When `axes` is integer_like, the sequence for evaluation will be: first
+    the -Nth axis in `a` and 0th axis in `b`, and the -1th axis in `a` and
+    (N-1)th axis in `b` last.
+
+    When there is more than one axis to sum over - and they are not the last
+    (first) axes of `a` (`b`) - the argument `axes` should consist of
+    two sequences of the same length, with the first axis to sum over given
+    first in both sequences, the second axis second, and so forth.
+
+    Examples
+    --------
+    >>> a = np.arange(60.).reshape(3,4,5)
+    >>> b = np.arange(24.).reshape(4,3,2)
+    >>> c = np.tensordot(a,b, axes=([1,0],[0,1]))
+    >>> c.shape
+    (5, 2)
+    >>> c
+    array([[ 4400.,  4730.],
+           [ 4532.,  4874.],
+           [ 4664.,  5018.],
+           [ 4796.,  5162.],
+           [ 4928.,  5306.]])
+    """
+    if _np.isscalar(axes):
+        return _npi.tensordot_int_axes(a, b, axes)
+
+    if len(axes) != 2:
+        raise ValueError('Axes must consist of two arrays.')
+    a_axes_summed, b_axes_summed = axes
+    if _np.isscalar(a_axes_summed):
+        a_axes_summed = (a_axes_summed,)
+    if _np.isscalar(b_axes_summed):
+        b_axes_summed = (b_axes_summed,)
+
+    if len(a_axes_summed) != len(b_axes_summed):
+        raise ValueError('Axes length mismatch')
+
+    return _npi.tensordot(a, b, a_axes_summed, b_axes_summed)
 
 
 @set_module('mxnet.ndarray.numpy')
diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py
index d20db96..dd51431 100644
--- a/python/mxnet/numpy/multiarray.py
+++ b/python/mxnet/numpy/multiarray.py
@@ -48,7 +48,87 @@ __all__ = ['ndarray', 'empty', 'array', 'zeros', 'ones', 'maximum', 'minimum', '
            'clip', 'split', 'swapaxes', 'expand_dims', 'tile', 'linspace', 'eye', 'sin', 'cos',
            'sin', 'cos', 'sinh', 'cosh', 'log10', 'sqrt', 'abs', 'exp', 'arctan', 'sign', 'log',
            'degrees', 'log2', 'rint', 'radians', 'mean', 'reciprocal', 'square', 'arcsin',
-           'argsort', 'hstack']
+           'argsort', 'hstack', 'tensordot']
+
+
+@set_module('mxnet.numpy')
+def tensordot(a, b, axes=2):
+    r"""
+    tensordot(a, b, axes=2)
+
+    Compute tensor dot product along specified axes for arrays >= 1-D.
+
+    Given two tensors (arrays of dimension greater than or equal to one),
+    `a` and `b`, and a sequence containing two sequences of axes,
+    ``(a_axes, b_axes)``, sum the products of `a`'s and `b`'s
+    elements (components) over the axes specified by ``a_axes`` and
+    ``b_axes``. The third argument can be a single non-negative
+    integer_like scalar, ``N``; if it is such, then the last ``N``
+    dimensions of `a` and the first ``N`` dimensions of `b` are summed
+    over.
+
+    Parameters
+    ----------
+    a, b : ndarray, len(shape) >= 1
+        Tensors to "dot".
+
+    axes : int or (2,) sequence
+        * integer_like
+        If an int N, sum over the last N axes of `a` and the first N axes
+        of `b` in order. The sizes of the corresponding axes must match.
+        * (2,) sequence
+        Or, a pair of axis lists to be summed over, the first applying to `a`,
+        the second to `b` (an int means one axis); both must have the same length.
+
+    See Also
+    --------
+    dot, einsum
+
+    Notes
+    -----
+    Three common use cases are:
+        * ``axes = 0`` : tensor product :math:`a\otimes b`
+        * ``axes = 1`` : tensor dot product :math:`a\cdot b`
+        * ``axes = 2`` : (default) tensor double contraction :math:`a:b`
+
+    When `axes` is integer_like, the sequence for evaluation will be: first
+    the -Nth axis in `a` and 0th axis in `b`, and the -1th axis in `a` and
+    (N-1)th axis in `b` last.
+
+    When there is more than one axis to sum over - and they are not the last
+    (first) axes of `a` (`b`) - the argument `axes` should consist of
+    two sequences of the same length, with the first axis to sum over given
+    first in both sequences, the second axis second, and so forth.
+
+    Examples
+    --------
+    >>> a = np.arange(60.).reshape(3,4,5)
+    >>> b = np.arange(24.).reshape(4,3,2)
+    >>> c = np.tensordot(a,b, axes=([1,0],[0,1]))
+    >>> c.shape
+    (5, 2)
+    >>> c
+    array([[ 4400.,  4730.],
+           [ 4532.,  4874.],
+           [ 4664.,  5018.],
+           [ 4796.,  5162.],
+           [ 4928.,  5306.]])
+    """
+    if _np.isscalar(axes):
+        return _npi.tensordot_int_axes(a, b, axes)
+
+    if len(axes) != 2:
+        raise ValueError('Axes must consist of two arrays.')
+    a_axes_summed, b_axes_summed = axes
+    if _np.isscalar(a_axes_summed):
+        a_axes_summed = (a_axes_summed,)
+    if _np.isscalar(b_axes_summed):
+        b_axes_summed = (b_axes_summed,)
+
+    if len(a_axes_summed) != len(b_axes_summed):
+        raise ValueError('Axes length mismatch')
+
+    return _npi.tensordot(a, b, a_axes_summed, b_axes_summed)
 
 
 # This function is copied from ndarray.py since pylint
diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py
index 987ed61..742a10d 100644
--- a/python/mxnet/symbol/numpy/_symbol.py
+++ b/python/mxnet/symbol/numpy/_symbol.py
@@ -33,7 +33,7 @@ __all__ = ['zeros', 'ones', 'maximum', 'minimum', 'stack', 'concatenate', 'arang
            'clip', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'split', 'swapaxes',
            'expand_dims', 'tile', 'linspace', 'eye', 'sin', 'cos', 'sinh', 'cosh', 'log10', 'sqrt',
            'abs', 'exp', 'arctan', 'sign', 'log', 'degrees', 'log2', 'rint', 'radians', 'mean',
-           'reciprocal', 'square', 'arcsin', 'argsort', 'hstack']
+           'reciprocal', 'square', 'arcsin', 'argsort', 'hstack', 'tensordot']
 
 
 def _num_outputs(sym):
@@ -2294,4 +2294,68 @@ def arcsin(x, out=None, **kwargs):
     return _unary_func_helper(x, _npi.arcsin, _np.arcsin, out=out, **kwargs)
 
 
+@set_module('mxnet.symbol.numpy')
+def tensordot(a, b, axes=2):
+    r"""
+    tensordot(a, b, axes=2)
+
+    Compute tensor dot product along specified axes for arrays >= 1-D.
+
+    Given two tensors (arrays of dimension greater than or equal to one),
+    `a` and `b`, and a sequence containing two sequences of axes,
+    ``(a_axes, b_axes)``, sum the products of `a`'s and `b`'s
+    elements (components) over the axes specified by ``a_axes`` and
+    ``b_axes``. The third argument can be a single non-negative
+    integer_like scalar, ``N``; if it is such, then the last ``N``
+    dimensions of `a` and the first ``N`` dimensions of `b` are summed
+    over.
+
+    Parameters
+    ----------
+    a, b : _Symbol
+        Tensors to "dot".
+
+    axes : int or (2,) sequence
+        * integer_like
+        If an int N, sum over the last N axes of `a` and the first N axes
+        of `b` in order. The sizes of the corresponding axes must match.
+        * (2,) sequence
+        Or, a pair of axis lists to be summed over, the first applying to `a`,
+        the second to `b` (an int means one axis); both must have the same length.
+
+
+    Notes
+    -----
+    Three common use cases are:
+        * ``axes = 0`` : tensor product :math:`a\otimes b`
+        * ``axes = 1`` : tensor dot product :math:`a\cdot b`
+        * ``axes = 2`` : (default) tensor double contraction :math:`a:b`
+
+    When `axes` is integer_like, the sequence for evaluation will be: first
+    the -Nth axis in `a` and 0th axis in `b`, and the -1th axis in `a` and
+    (N-1)th axis in `b` last.
+
+    When there is more than one axis to sum over - and they are not the last
+    (first) axes of `a` (`b`) - the argument `axes` should consist of
+    two sequences of the same length, with the first axis to sum over given
+    first in both sequences, the second axis second, and so forth.
+
+    """
+    if _np.isscalar(axes):
+        return _npi.tensordot_int_axes(a, b, axes)
+
+    if len(axes) != 2:
+        raise ValueError('Axes must consist of two arrays.')
+    a_axes_summed, b_axes_summed = axes
+    if _np.isscalar(a_axes_summed):
+        a_axes_summed = (a_axes_summed,)
+    if _np.isscalar(b_axes_summed):
+        b_axes_summed = (b_axes_summed,)
+
+    if len(a_axes_summed) != len(b_axes_summed):
+        raise ValueError('Axes length mismatch')
+
+    return _npi.tensordot(a, b, a_axes_summed, b_axes_summed)
+
+
 _set_np_symbol_class(_Symbol)
diff --git a/src/operator/numpy/np_dot-inl.h b/src/operator/numpy/np_dot-inl.h
index fa67c07..8f60bae 100644
--- a/src/operator/numpy/np_dot-inl.h
+++ b/src/operator/numpy/np_dot-inl.h
@@ -30,47 +30,11 @@
 #include "../tensor/dot-inl.h"
 #include "../tensor/elemwise_binary_op.h"
 #include "../tensor/broadcast_reduce_op.h"
+#include "np_tensordot_op-inl.h"
 
 namespace mxnet {
 namespace op {
 
-template<typename xpu>
-inline void MMImpl(const OpContext& ctx,
-                   const TBlob& a,
-                   const TBlob& b,
-                   const TBlob& out,
-                   const OpReqType req,
-                   const bool trans_a = false,
-                   const bool trans_b = false) {
-  using namespace mshadow;
-  using namespace mshadow_op;
-
-  Stream<xpu> *s = ctx.get_stream<xpu>();
-  index_t ma, na, mb, nb;
-  na = a.size(a.ndim() - 1);
-  ma = a.Size() / na;
-  mb = b.size(0);
-  nb = b.Size() / mb;
-  MSHADOW_REAL_TYPE_SWITCH(out.type_flag_, DType, {
-    Tensor<xpu, 2, DType> input0 = a.get_with_shape<xpu, 2, DType>(Shape2(ma, na), s);
-    Tensor<xpu, 2, DType> input1 = b.get_with_shape<xpu, 2, DType>(Shape2(mb, nb), s);
-    Tensor<xpu, 2, DType> output0;
-    if (trans_a && trans_b) {
-      output0 = out.get_with_shape<xpu, 2, DType>(Shape2(na, mb), s);
-      ASSIGN_DISPATCH(output0, req, dot(input0.T(), input1.T()));
-    } else if (!trans_a && trans_b) {
-      output0 = out.get_with_shape<xpu, 2, DType>(Shape2(ma, mb), s);
-      ASSIGN_DISPATCH(output0, req, dot(input0, input1.T()));
-    } else if (trans_a && !trans_b) {
-      output0 = out.get_with_shape<xpu, 2, DType>(Shape2(na, nb), s);
-      ASSIGN_DISPATCH(output0, req, dot(input0.T(), input1));
-    } else {
-      output0 = out.get_with_shape<xpu, 2, DType>(Shape2(ma, nb), s);
-      ASSIGN_DISPATCH(output0, req, dot(input0, input1));
-    }
-  });
-}
-
 template<int req>
 struct scalar_mul_kernel {
   template<typename DType>
@@ -114,18 +78,6 @@ inline void NumpyDotForward(const nnvm::NodeAttrs& attrs,
             Shape1(out.shape_.Size()), s);
         out_data = static_cast<DType>(0);
       }
-    } else if (a_shape.ndim() == 1 && b_shape.ndim() == 1) {
-      // Case 1: both 1-D arrays, inner product of vectors
-      if (out.type_flag_ == kFloat16) {
-        MMImpl<xpu>(ctx, a, b, out, req[0]);
-      } else {
-        CHECK_NE(req[0], kAddTo) << "AddTo not yet supported";
-        Tensor<xpu, 1, DType> mock_1d = out.get_with_shape<xpu, 1, DType>(Shape1(1), s);
-        VectorDot(mock_1d, a.get<xpu, 1, DType>(s), b.get<xpu, 1, DType>(s));
-      }
-    } else if (a_shape.ndim() == 2 && b_shape.ndim() == 2) {
-      // Case 2: both 2-D arrays, matrix multiplication
-      MMImpl<xpu>(ctx, a, b, out, req[0]);
     } else if (a_shape.ndim() == 0 && b_shape.ndim() == 0) {
       // Case 3: both 0-D scalars, equivalent to multiply
       Tensor<xpu, 1, DType> a_data = a.get_with_shape<xpu, 1, DType>(Shape1(1), s);
@@ -140,17 +92,16 @@ inline void NumpyDotForward(const nnvm::NodeAttrs& attrs,
         Kernel<scalar_mul_kernel<Req>, xpu>::Launch(
           s, out.Size(), out.dptr<DType>(), tensor, scalar);
       });
-    } else if (a_shape.ndim() == 1 || b_shape.ndim() == 1) {
-      // Case 4: a is N-D array and b is 1-D array, sum product over the last axis
-      MMImpl<xpu>(ctx, a, b, out, req[0]);
+    } else if (b_shape.ndim() < 3) {
+      // Case 1, 2, 4, 5: a is N-D array (N >= 1) and b is vector or matrix, sum product
+      //        over the last axis of a and the first axis of b
+      TensordotIntAxesImpl<xpu>(1, ctx, a, b, out, req[0]);
     } else {
-      // Case 5: a is N-D array and b is M-D array, sum product over the last axis
+      // Case 5.5: a is N-D array and b is M-D array (M > 2), sum product over the last axis
       //         of a and the 2nd-to-last axis of b
-      // TODO(haojin2): To be implemented...
-      if (b_shape.ndim() != 2) {
-        LOG(FATAL) << "Only support case 5 when b.ndim = 2";
-      }
-      MMImpl<xpu>(ctx, a, b, out, req[0]);
+      const Tuple<int> a_axes_summed({a_shape.ndim() - 1});
+      const Tuple<int> b_axes_summed({b_shape.ndim() - 2});
+      TensordotImpl<xpu>(a_axes_summed, b_axes_summed, ctx, a, b, out, req);
     }
   });
 }
@@ -179,22 +130,7 @@ inline void NumpyDotBackward(const nnvm::NodeAttrs& attrs,
 
   Stream<xpu> *s = ctx.get_stream<xpu>();
   MSHADOW_REAL_TYPE_SWITCH(ograd.type_flag_, DType, {
-    if (a_shape.ndim() == 1 && b_shape.ndim() == 1) {
-      // Case 1: both 1-D arrays, inner product of vectors
-      Tensor<xpu, 1, DType> out_grad = ograd.get_with_shape<xpu, 1, DType>(Shape1(1), s);
-      Tensor<xpu, 1, DType> a_data = a.get<xpu, 1, DType>(s);
-      Tensor<xpu, 1, DType> b_data = b.get<xpu, 1, DType>(s);
-      Tensor<xpu, 1, DType> a_grad = grad_a.get<xpu, 1, DType>(s);
-      Tensor<xpu, 1, DType> b_grad = grad_b.get<xpu, 1, DType>(s);
-      ASSIGN_DISPATCH(b_grad, req[1],
-                      broadcast_scalar(out_grad, a_data.shape_) * a_data);
-      ASSIGN_DISPATCH(a_grad, req[0],
-                      broadcast_scalar(out_grad, a_data.shape_) * b_data);
-    } else if (a_shape.ndim() == 2 && b_shape.ndim() == 2) {
-      // Case 2: both 2-D arrays, matrix multiplication
-      MMImpl<xpu>(ctx, a, ograd, grad_b, req[1], true, false);
-      MMImpl<xpu>(ctx, ograd, b, grad_a, req[0], false, true);
-    } else if (a_shape.ndim() == 0 && b_shape.ndim() == 0) {
+    if (a_shape.ndim() == 0 && b_shape.ndim() == 0) {
       // Case 3: both 0-D scalars, equivalent to multiply
       Tensor<xpu, 1, DType> out_grad = ograd.get_with_shape<xpu, 1, DType>(Shape1(1), s);
       Tensor<xpu, 1, DType> a_data = a.get_with_shape<xpu, 1, DType>(Shape1(1), s);
@@ -225,46 +161,17 @@ inline void NumpyDotBackward(const nnvm::NodeAttrs& attrs,
 
       ReduceAxesComputeImpl<xpu, mshadow_op::sum, true>(
         ctx, {TBlob(temp_space)}, {scalar_req}, {TBlob(scalar_grad_)}, scalar_grad_.shape_);
-    } else if (b_shape.ndim() == 1) {
-      size_t na = a_shape[a_shape.ndim() - 1];
-      size_t ma = a_shape.Size() / na;
-      Tensor<xpu, 2, DType> a_ =
-        a.get_with_shape<xpu, 2, DType>(Shape2(ma, na), s);
-      Tensor<xpu, 2, DType> b_ =
-        b.get_with_shape<xpu, 2, DType>(Shape2(b_shape.Size(), 1), s);
-      Tensor<xpu, 2, DType> grad_a_ =
-        grad_a.get_with_shape<xpu, 2, DType>(Shape2(ma, na), s);
-      Tensor<xpu, 2, DType> grad_b_ =
-        grad_b.get_with_shape<xpu, 2, DType>(Shape2(b_shape.Size(), 1), s);
-      Tensor<xpu, 2, DType> ograd_ =
-        ograd.get_with_shape<xpu, 2, DType>(Shape2(ograd.shape_.Size(), 1), s);
-      // Case 4: a is N-D array and b is 1-D array, sum product over the last axis
-      MMImpl<xpu>(ctx, TBlob(a_), TBlob(ograd_), TBlob(grad_b_), req[1], true, false);
-      MMImpl<xpu>(ctx, TBlob(ograd_), TBlob(b_), TBlob(grad_a_), req[0], false, true);
+    } else if (b_shape.ndim() < 3) {
+      // Case 1, 2, 4, 5: a is N-D array (N >= 1) and b is vector or matrix, sum product
+      //        over the last axis of a and the first axis of b
+      TensordotIntAxesBackwardImpl<xpu>(1, ctx, ograd, a, b, grad_a, grad_b, req);
     } else {
-      // Case 5: a is N-D array and b is M-D array, sum product over the last axis
+      // Case 5.5: a is N-D array and b is M-D array (M > 2), sum product over the last axis
       //         of a and the 2nd-to-last axis of b
-      // TODO(haojin2): To be implemented...
-      if (b_shape.ndim() != 2) {
-        LOG(FATAL) << "Only support case 5 when b.ndim = 2";
-      } else {  // a is N-D, b is 2D
-        index_t na = a_shape[a_shape.ndim() - 1];
-        index_t ma = a_shape.Size() / na;
-        index_t nograd = ograd.shape_[ograd.shape_.ndim() - 1];
-        index_t mograd = ograd.shape_.Size() / nograd;
-
-        Tensor<xpu, 2, DType> a_2d =
-            a.get_with_shape<xpu, 2, DType>(Shape2(ma, na), s);
-        Tensor<xpu, 2, DType> grad_a_2d =
-            grad_a.get_with_shape<xpu, 2, DType>(Shape2(ma, na), s);
-        Tensor<xpu, 2, DType> b_2d = b.FlatTo2D<xpu, DType>(s);
-        Tensor<xpu, 2, DType> grad_b_2d = grad_b.FlatTo2D<xpu, DType>(s);
-        Tensor<xpu, 2, DType> ograd_2d =
-            ograd.get_with_shape<xpu, 2, DType>(Shape2(mograd, nograd), s);
-
-        MMImpl<xpu>(ctx, TBlob(a_2d), TBlob(ograd_2d), TBlob(grad_b_2d), req[1], true, false);
-        MMImpl<xpu>(ctx, TBlob(ograd_2d), TBlob(b_2d), TBlob(grad_a_2d), req[0], false, true);
-      }
+      const Tuple<int> a_axes_summed({a_shape.ndim() - 1});
+      const Tuple<int> b_axes_summed({b_shape.ndim() - 2});
+      TensordotBackwardImpl<xpu>(a_axes_summed, b_axes_summed, ctx, ograd, a, b, grad_a,
+          grad_b, req);
     }
   });
 }
diff --git a/src/operator/numpy/np_tensordot_op-inl.h b/src/operator/numpy/np_tensordot_op-inl.h
new file mode 100644
index 0000000..9b88b81
--- /dev/null
+++ b/src/operator/numpy/np_tensordot_op-inl.h
@@ -0,0 +1,556 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file np_tensordot_op-inl.h
+ * \brief Implementation of numpy-compatible tensordot (shared by CPU and GPU)
+ */
+#ifndef MXNET_OPERATOR_NUMPY_NP_TENSORDOT_OP_INL_H_
+#define MXNET_OPERATOR_NUMPY_NP_TENSORDOT_OP_INL_H_
+
+#include <vector>
+#include "np_matrix_op-inl.h"
+
+namespace mxnet {
+namespace op {
+
+using namespace mshadow;
+
+struct TensordotParam : public dmlc::Parameter<TensordotParam> {
+  mxnet::Tuple<int> a_axes_summed, b_axes_summed;  // summed axes of a and b, paired by index
+  DMLC_DECLARE_PARAMETER(TensordotParam) {
+    DMLC_DECLARE_FIELD(a_axes_summed);
+    DMLC_DECLARE_FIELD(b_axes_summed);
+  }
+};
+
+/**
+ * Products of axis sizes: a -> ad1 x ad2 (remained x summed), b -> bd1 x bd2 (summed x remained).
+ */
+inline void GetMatrixDimensions(int* ad1,
+                                int* ad2,
+                                int* bd1,
+                                int* bd2,
+                                const mxnet::Tuple<int>& a_axes_remained,
+                                const mxnet::Tuple<int>& a_axes_summed,
+                                const mxnet::Tuple<int>& b_axes_remained,
+                                const mxnet::Tuple<int>& b_axes_summed,
+                                const mxnet::TShape& a_shape,
+                                const mxnet::TShape& b_shape) {
+  *ad1 = 1;  // NOTE(review): products are accumulated in int and may overflow for huge tensors
+  *ad2 = 1;
+  *bd1 = 1;
+  *bd2 = 1;
+
+  for (int i = 0; i < a_axes_remained.ndim(); i++) {
+    *ad1 *= a_shape[a_axes_remained[i]];
+  }
+  for (int i = 0; i < a_axes_summed.ndim(); i++) {
+    *ad2 *= a_shape[a_axes_summed[i]];
+  }
+  for (int i = 0; i < b_axes_summed.ndim(); i++) {
+    *bd1 *= b_shape[b_axes_summed[i]];
+  }
+  for (int i = 0; i < b_axes_remained.ndim(); i++) {
+    *bd2 *= b_shape[b_axes_remained[i]];
+  }
+}
+
+/**
+ * Builds remained axes and perms a_axes = (remained, summed), b_axes = (summed, remained).
+ */
+inline void GetReorderedAxes(const mxnet::Tuple<int>& a_axes_summed,
+                             mxnet::Tuple<int>* a_axes_remained,
+                             mxnet::Tuple<int>* a_axes,
+                             const mxnet::Tuple<int>& b_axes_summed,
+                             mxnet::Tuple<int>* b_axes_remained,
+                             mxnet::Tuple<int>* b_axes,
+                             const mxnet::TShape& a_shape,
+                             const mxnet::TShape& b_shape) {
+  std::vector<int> a_axes_remained_vector;
+  for (int i = 0; i < a_shape.ndim(); i++) {
+    a_axes_remained_vector.push_back(i);
+  }
+  for (auto& i : a_axes_summed) {  // NOTE(review): each i must be a distinct axis in [0, ndim)
+    a_axes_remained_vector.erase(std::find(a_axes_remained_vector.begin(),
+      a_axes_remained_vector.end(), i));
+  }
+  *a_axes_remained = mxnet::Tuple<int>(a_axes_remained_vector);
+
+  std::vector<int> a_axes_vector(a_axes_remained_vector);
+  for (auto& i : a_axes_summed) {
+    a_axes_vector.push_back(i);
+  }
+  *a_axes = mxnet::Tuple<int>(a_axes_vector);
+
+  std::vector<int> b_axes_remained_vector;
+  for (int i = 0; i < b_shape.ndim(); i++) {
+    b_axes_remained_vector.push_back(i);
+  }
+  for (auto& i : b_axes_summed) {  // NOTE(review): each i must be a distinct axis in [0, ndim)
+    b_axes_remained_vector.erase(std::find(b_axes_remained_vector.begin(),
+                                           b_axes_remained_vector.end(), i));
+  }
+  *b_axes_remained = mxnet::Tuple<int>(b_axes_remained_vector);
+
+  std::vector<int> b_axes_vector;
+  for (auto& i : b_axes_summed) {
+    b_axes_vector.push_back(i);
+  }
+  for (auto& i : b_axes_remained_vector) {
+    b_axes_vector.push_back(i);
+  }
+  *b_axes = mxnet::Tuple<int>(b_axes_vector);
+}
+
+/**
+ * Returns `shape` permuted by `axes` (the shape after transpose): new_shape[i] = shape[axes[i]].
+ */
+inline mxnet::TShape GetReorderedShape(const mxnet::TShape& shape, const mxnet::Tuple<int>& axes) {
+  mxnet::TShape new_shape(shape);
+  for (int i = 0; i < axes.ndim(); i++) {
+    new_shape[i] = shape[axes[i]];
+  }
+  return new_shape;
+}
+
+/**
+ * Views a as an ad1-by-ad2 matrix and b as a bd1-by-bd2 matrix, then writes op(a) * op(b)
+ * into out according to req, where op transposes its operand when aT / bT is set.
+ */
+template<typename xpu>
+void MatrixDot(const OpContext& ctx,
+               const TBlob& a,
+               const TBlob& b,
+               const TBlob& out,
+               const OpReqType req,
+               const int ad1,
+               const int ad2,
+               const int bd1,
+               const int bd2,
+               const bool aT = false,
+               const bool bT = false) {
+  using namespace mshadow;
+  using namespace mshadow_op;
+
+  Stream<xpu> *s = ctx.get_stream<xpu>();
+
+  MSHADOW_REAL_TYPE_SWITCH(out.type_flag_, DType, {
+    Tensor<xpu, 2, DType> a_tensor = a.get_with_shape<xpu, 2, DType>(Shape2(ad1, ad2), s);
+    Tensor<xpu, 2, DType> b_tensor = b.get_with_shape<xpu, 2, DType>(Shape2(bd1, bd2), s);
+
+    if (aT && bT) {
+      CHECK_EQ(ad1, bd2);
+      Tensor<xpu, 2, DType> out_tensor = out.get_with_shape<xpu, 2, DType>(Shape2(ad2, bd1), s);
+      ASSIGN_DISPATCH(out_tensor, req, dot(a_tensor.T(), b_tensor.T()));
+    } else if (aT && !bT) {
+      CHECK_EQ(ad1, bd1);
+      Tensor<xpu, 2, DType> out_tensor = out.get_with_shape<xpu, 2, DType>(Shape2(ad2, bd2), s);
+      ASSIGN_DISPATCH(out_tensor, req, dot(a_tensor.T(), b_tensor));
+    } else if (!aT && bT) {
+      CHECK_EQ(ad2, bd2);
+      Tensor<xpu, 2, DType> out_tensor = out.get_with_shape<xpu, 2, DType>(Shape2(ad1, bd1), s);
+      ASSIGN_DISPATCH(out_tensor, req, dot(a_tensor, b_tensor.T()));
+    } else {
+      CHECK_EQ(ad2, bd1);
+      Tensor<xpu, 2, DType> out_tensor = out.get_with_shape<xpu, 2, DType>(Shape2(ad1, bd2), s);
+      ASSIGN_DISPATCH(out_tensor, req, dot(a_tensor, b_tensor));
+    }
+  });
+}
+
+/**
+ * Forward tensordot: transposes a and b into workspace copies, then one matrix product into out.
+ */
+template<typename xpu>
+void TensordotImpl(const Tuple<int>& a_axes_summed,
+                   const Tuple<int>& b_axes_summed,
+                   const OpContext& ctx,
+                   const TBlob& a,
+                   const TBlob& b,
+                   const TBlob& out,
+                   const std::vector<OpReqType>& req) {
+  if (req[0] == kNullOp) {
+    return;
+  }
+
+  if (out.shape_.Size() == 0U) {
+    return;  // zero-size output, no need to launch kernel
+  }
+
+  const mxnet::TShape& a_shape = a.shape_;
+  const mxnet::TShape& b_shape = b.shape_;
+
+  mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
+  CHECK_EQ(out.type_flag_, a.type_flag_)
+      << "Binary function only support input/output with the same type";
+  CHECK_EQ(out.type_flag_, b.type_flag_)
+      << "Binary function only support input/output with the same type";
+  CHECK(out.type_flag_ == kFloat32 || out.type_flag_ == kFloat64 ||
+       (out.type_flag_ == kFloat16 && ctx.run_ctx.ctx.dev_mask() == mshadow::gpu::kDevMask))
+      << "Tensordot only supports float32/float64 for CPU, and float16/float32/float64 for GPU";
+
+  Tuple<int> a_axes_remained;
+  Tuple<int> b_axes_remained;
+  Tuple<int> a_axes;
+  Tuple<int> b_axes;
+  GetReorderedAxes(a_axes_summed, &a_axes_remained, &a_axes, b_axes_summed, &b_axes_remained,
+                   &b_axes, a_shape, b_shape);
+
+  int ad1 = 1, ad2 = 1, bd1 = 1, bd2 = 1;
+  GetMatrixDimensions(&ad1, &ad2, &bd1, &bd2, a_axes_remained, a_axes_summed,
+                      b_axes_remained, b_axes_summed, a_shape, b_shape);
+
+  mxnet::TShape a_temp_shape = GetReorderedShape(a_shape, a_axes);
+  mxnet::TShape b_temp_shape = GetReorderedShape(b_shape, b_axes);
+
+  MSHADOW_REAL_TYPE_SWITCH(out.type_flag_, DType, {
+    if (a_shape.Size() == 0U || b_shape.Size() == 0U) {  // empty input: zero out unless kAddTo
+      if (req[0] != kAddTo) {
+        Tensor<xpu, 1, DType> out_data = out.get_with_shape<xpu, 1, DType>(
+            Shape1(out.shape_.Size()), s);
+        out_data = static_cast<DType>(0);
+      }
+      return;
+    }
+
+    Tensor<xpu, 1, DType> workspace = ctx.requested[0].get_space_typed<xpu, 1, DType>
+      (Shape1(a.Size() + b.Size()), s);
+    DType* a_ptr = reinterpret_cast<DType*>(workspace.dptr_);
+    DType* b_ptr = reinterpret_cast<DType*>(workspace.dptr_ + a.Size());
+    TBlob a_res = TBlob(a_ptr, a_temp_shape, xpu::kDevMask);
+    TBlob b_res = TBlob(b_ptr, b_temp_shape, xpu::kDevMask);
+
+    mxnet::op::TransposeImpl<xpu>(ctx.run_ctx, a, a_res,
+                                  mxnet::TShape(a_axes.begin(), a_axes.end()));
+    mxnet::op::TransposeImpl<xpu>(ctx.run_ctx, b, b_res,
+                                  mxnet::TShape(b_axes.begin(), b_axes.end()));
+
+    MatrixDot<xpu>(ctx, a_res, b_res, out, req[0], ad1, ad2, bd1, bd2);
+  });
+}
+
+/**
+ * Forward entry point: unpacks the TensordotParam axes and delegates to TensordotImpl.
+ */
+template<typename xpu>
+void TensordotOpForward(const nnvm::NodeAttrs& attrs,
+                        const OpContext& ctx,
+                        const std::vector<TBlob>& inputs,
+                        const std::vector<OpReqType>& req,
+                        const std::vector<TBlob>& outputs) {
+  CHECK_EQ(inputs.size(), 2U);
+  CHECK_EQ(outputs.size(), 1U);
+  CHECK_EQ(req.size(), 1U);
+
+  const TBlob& a = inputs[0];
+  const TBlob& b = inputs[1];
+  const TBlob& out = outputs[0];
+
+  const TensordotParam& param = nnvm::get<TensordotParam>(attrs.parsed);
+  const Tuple<int>& a_axes_summed = param.a_axes_summed;
+  const Tuple<int>& b_axes_summed = param.b_axes_summed;
+
+  TensordotImpl<xpu>(a_axes_summed, b_axes_summed, ctx, a, b, out, req);
+}
+
+/**
+ * Returns the inverse of the permutation `shape`: result[shape[i]] = i (undoes a transpose).
+ */
+inline mxnet::TShape GetReverseShape(const mxnet::Tuple<int>& shape) {
+  mxnet::TShape shape2(shape.begin(), shape.end());
+  for (int i = 0; i < shape.ndim(); i++) {
+    shape2[shape[i]] = i;
+  }
+  return shape2;
+}
+
+/**
+ * Transposes `src` by `axes` into `dst`, honoring `req`. TransposeImpl takes no OpReqType, so
+ * for kAddTo the transpose is written to `scratch` (which must hold dst.Size() elements) and
+ * then accumulated into `dst`; kNullOp leaves `dst` untouched.
+ */
+template<typename xpu, typename DType>
+inline void TensordotTransposeWithReq(const OpContext& ctx,
+                                      const TBlob& src,
+                                      const TBlob& dst,
+                                      const mxnet::TShape& axes,
+                                      const OpReqType req,
+                                      DType* scratch) {
+  if (req == kNullOp) {
+    return;
+  }
+  if (req != kAddTo) {  // kWriteTo / kWriteInplace: write the result directly.
+    mxnet::op::TransposeImpl<xpu>(ctx.run_ctx, src, dst, axes);
+    return;
+  }
+  mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
+  TBlob tmp(scratch, dst.shape_, xpu::kDevMask);
+  mxnet::op::TransposeImpl<xpu>(ctx.run_ctx, src, tmp, axes);
+  Tensor<xpu, 1, DType> acc = dst.get_with_shape<xpu, 1, DType>(Shape1(dst.Size()), s);
+  acc += tmp.get_with_shape<xpu, 1, DType>(Shape1(dst.Size()), s);
+}
+
+/**
+ * calculates tensordot derivative.
+ *
+ * With a^T = a reordered to (summed, remained) axes and b^T = b reordered to
+ * (remained, summed) axes, both viewed as matrices:
+ *   grad_b = a^T . out_grad,   grad_a = out_grad . b^T
+ * The products are formed in workspace buffers ordered by a_axes / b_axes and transposed back
+ * (GetReverseShape) into grad_a / grad_b, honoring req[0] / req[1] including kNullOp/kAddTo.
+ */
+template<typename xpu>
+void TensordotBackwardImpl(const Tuple<int>& a_axes_summed,
+                           const Tuple<int>& b_axes_summed,
+                           const OpContext& ctx,
+                           const TBlob& out_grad,
+                           const TBlob& a,
+                           const TBlob& b,
+                           const TBlob& grad_a,
+                           const TBlob& grad_b,
+                           const std::vector<OpReqType>& req) {
+  mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
+
+  const mxnet::TShape& a_shape = a.shape_;
+  const mxnet::TShape& b_shape = b.shape_;
+
+  Tuple<int> a_axes_remained;
+  Tuple<int> b_axes_remained;
+  Tuple<int> a_axes;
+  Tuple<int> b_axes;
+  GetReorderedAxes(a_axes_summed, &a_axes_remained, &a_axes, b_axes_summed, &b_axes_remained,
+                   &b_axes, a_shape, b_shape);
+
+  int ad1 = 1, ad2 = 1, bd1 = 1, bd2 = 1;
+  GetMatrixDimensions(&ad1, &ad2, &bd1, &bd2, a_axes_remained, a_axes_summed,
+                      b_axes_remained, b_axes_summed, a_shape, b_shape);
+
+  // a^T axes: summed then remained; b^T axes: remained then summed.
+  std::vector<int> a_T_axes(a_axes_summed.begin(), a_axes_summed.end());
+  a_T_axes.insert(a_T_axes.end(), a_axes_remained.begin(), a_axes_remained.end());
+  std::vector<int> b_T_axes(b_axes_remained.begin(), b_axes_remained.end());
+  b_T_axes.insert(b_T_axes.end(), b_axes_summed.begin(), b_axes_summed.end());
+  mxnet::TShape a_temp_shape(GetReorderedShape(a_shape, a_axes));
+  mxnet::TShape a_T_temp_shape(GetReorderedShape(a_shape, a_T_axes));
+  mxnet::TShape b_temp_shape(GetReorderedShape(b_shape, b_axes));
+  mxnet::TShape b_T_temp_shape(GetReorderedShape(b_shape, b_T_axes));
+
+  MSHADOW_REAL_TYPE_SWITCH(out_grad.type_flag_, DType, {
+    if (a.Size() == 0U || b.Size() == 0U || out_grad.Size() == 0U) {
+      // Mirrors the forward 0-size handling: with any operand empty, every gradient entry is
+      // an empty sum (or the gradient itself is empty), so write zeros where req asks for it.
+      const TBlob* grads[2] = {&grad_a, &grad_b};
+      for (int k = 0; k < 2; ++k) {
+        if (grads[k]->Size() == 0U || req[k] == kNullOp || req[k] == kAddTo) continue;
+        Tensor<xpu, 1, DType> g = grads[k]->get_with_shape<xpu, 1, DType>(
+            Shape1(grads[k]->Size()), s);
+        g = static_cast<DType>(0);
+      }
+      return;
+    }
+
+    // Workspace layout: [a_res | a_res2 | b_res | b_res2] with sizes a, a, b, b.
+    Tensor<xpu, 1, DType> workspace = ctx.requested[0].get_space_typed<xpu, 1, DType>
+      (Shape1((a.Size() + b.Size()) * 2), s);
+    DType* a_ptr = reinterpret_cast<DType*>(workspace.dptr_);
+    DType* a_ptr2 = reinterpret_cast<DType*>(workspace.dptr_ + a.Size());
+    DType* b_ptr = reinterpret_cast<DType*>(workspace.dptr_ + 2 * a.Size());
+    DType* b_ptr2 = reinterpret_cast<DType*>(workspace.dptr_ + 2 * a.Size() + b.Size());
+
+    TBlob a_res = TBlob(a_ptr, a_temp_shape, xpu::kDevMask);
+    TBlob b_res = TBlob(b_ptr, b_temp_shape, xpu::kDevMask);
+    TBlob a_res2 = TBlob(a_ptr2, a_T_temp_shape, xpu::kDevMask);
+    TBlob b_res2 = TBlob(b_ptr2, b_T_temp_shape, xpu::kDevMask);
+
+    mxnet::op::TransposeImpl<xpu>(ctx.run_ctx, a, a_res2,
+                                  mxnet::TShape(a_T_axes.begin(), a_T_axes.end()));
+    mxnet::op::TransposeImpl<xpu>(ctx.run_ctx, b, b_res2,
+                                  mxnet::TShape(b_T_axes.begin(), b_T_axes.end()));
+
+    // The products target private workspace buffers, so they always overwrite (kWriteTo);
+    // the caller's req is applied only when transposing back into grad_a / grad_b.
+    MatrixDot<xpu>(ctx, a_res2, out_grad, b_res, kWriteTo, ad2, ad1, ad1, bd2);
+    MatrixDot<xpu>(ctx, out_grad, b_res2, a_res, kWriteTo, ad1, bd2, bd2, bd1);
+
+    // a_res2 / b_res2 are dead after the products and hold exactly a.Size() / b.Size()
+    // elements, so they serve as the kAddTo scratch buffers.
+    TensordotTransposeWithReq<xpu, DType>(ctx, a_res, grad_a, GetReverseShape(a_axes),
+                                          req[0], a_ptr2);
+    TensordotTransposeWithReq<xpu, DType>(ctx, b_res, grad_b, GetReverseShape(b_axes),
+                                          req[1], b_ptr2);
+  });
+}
+
+/**
+ * FCompute backward for _npi_tensordot.
+ * inputs = {out_grad, a, b}; outputs = {grad_a, grad_b}; one req entry per output.
+ */
+template<typename xpu>
+void TensordotOpBackward(const nnvm::NodeAttrs& attrs,
+                         const OpContext& ctx,
+                         const std::vector<TBlob>& inputs,
+                         const std::vector<OpReqType>& req,
+                         const std::vector<TBlob>& outputs) {
+  CHECK_EQ(inputs.size(), 3U);
+  CHECK_EQ(outputs.size(), 2U);
+  CHECK_EQ(req.size(), 2U);
+
+  const TensordotParam& p = nnvm::get<TensordotParam>(attrs.parsed);
+  TensordotBackwardImpl<xpu>(p.a_axes_summed, p.b_axes_summed, ctx,
+                             /* out_grad */ inputs[0],
+                             /* a        */ inputs[1],
+                             /* b        */ inputs[2],
+                             /* grad_a   */ outputs[0],
+                             /* grad_b   */ outputs[1],
+                             req);
+}
+
+/**
+ * Parameter of _npi_tensordot_int_axes: how many trailing axes of a are summed against the
+ * same number of leading axes of b.
+ */
+struct TensordotIntAxesParam : public dmlc::Parameter<TensordotIntAxesParam> {
+  int axes;
+  DMLC_DECLARE_PARAMETER(TensordotIntAxesParam) {
+    DMLC_DECLARE_FIELD(axes)
+    .describe("Number of trailing axes of `a` summed against the same number of "
+              "leading axes of `b`.");
+  }
+};
+
+/**
+ * gets summed axes of a and b from parameter axes: the last `axes` axes of a are paired, in
+ * order, with the first `axes` axes of b. A non-positive `axes` yields no summed axes.
+ * NOTE(review): axes <= b.ndim() is not validated here (b's shape is not passed in).
+ */
+inline void GetSummedAxes(mxnet::Tuple<int>* a_axes_summed_ptr,
+                          mxnet::Tuple<int>* b_axes_summed_ptr,
+                          const int axes,
+                          const mxnet::TShape& a_shape) {
+  // Without this check, axes > a.ndim() would produce negative (invalid) axis indices for a.
+  CHECK_LE(axes, a_shape.ndim())
+      << "tensordot: axes (" << axes << ") exceeds the number of dimensions of the first "
+      << "input (" << a_shape.ndim() << ")";
+  std::vector<int> a_axes_summed_vector;
+  std::vector<int> b_axes_summed_vector;
+  for (int i = 0; i < axes; i++) {
+    a_axes_summed_vector.push_back(a_shape.ndim() - axes + i);
+    b_axes_summed_vector.push_back(i);
+  }
+  *a_axes_summed_ptr = mxnet::Tuple<int>(a_axes_summed_vector);
+  *b_axes_summed_ptr = mxnet::Tuple<int>(b_axes_summed_vector);
+}
+
+/**
+ * Forward tensordot for an integer `axes`: the last `axes` dims of a are contracted with the
+ * first `axes` dims of b. Those axes are already trailing in a and leading in b, so a and b
+ * are handed to MatrixDot as (ad1 x ad2) and (bd1 x bd2) matrices with no transpose pass.
+ */
+template<typename xpu>
+void TensordotIntAxesImpl(const int axes,
+                          const OpContext& ctx,
+                          const TBlob& a,
+                          const TBlob& b,
+                          const TBlob& out,
+                          const OpReqType req) {
+  // Nothing was requested, or the output has no elements: no work to do.
+  if (req == kNullOp || out.shape_.Size() == 0U) {
+    return;
+  }
+
+  mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
+  CHECK_EQ(out.type_flag_, a.type_flag_)
+      << "Binary function only support input/output with the same type";
+  CHECK_EQ(out.type_flag_, b.type_flag_)
+      << "Binary function only support input/output with the same type";
+  CHECK(out.type_flag_ == kFloat32 || out.type_flag_ == kFloat64 ||
+       (out.type_flag_ == kFloat16 && ctx.run_ctx.ctx.dev_mask() == mshadow::gpu::kDevMask))
+      << "Tensordot only supports float32/float64 for CPU, and float16/float32/float64 for GPU";
+
+  // View the operands as matrices: a -> (ad1 x ad2), b -> (bd1 x bd2).
+  const mxnet::TShape& a_shape = a.shape_;
+  const mxnet::TShape& b_shape = b.shape_;
+  Tuple<int> a_axes_summed, b_axes_summed;
+  GetSummedAxes(&a_axes_summed, &b_axes_summed, axes, a_shape);
+  Tuple<int> a_axes_remained, b_axes_remained, a_axes, b_axes;
+  GetReorderedAxes(a_axes_summed, &a_axes_remained, &a_axes, b_axes_summed, &b_axes_remained,
+                   &b_axes, a_shape, b_shape);
+  int ad1 = 1, ad2 = 1, bd1 = 1, bd2 = 1;
+  GetMatrixDimensions(&ad1, &ad2, &bd1, &bd2, a_axes_remained, a_axes_summed,
+                      b_axes_remained, b_axes_summed, a_shape, b_shape);
+
+  MSHADOW_REAL_TYPE_SWITCH(out.type_flag_, DType, {
+    const bool has_empty_input = a_shape.Size() == 0U || b_shape.Size() == 0U;
+    if (!has_empty_input) {
+      MatrixDot<xpu>(ctx, a, b, out, req, ad1, ad2, bd1, bd2);
+    } else if (req != kAddTo) {
+      // A non-empty output contracted over an empty operand is all zeros; under kAddTo,
+      // adding zeros leaves the output unchanged, so nothing is written then.
+      Tensor<xpu, 1, DType> out_data = out.get_with_shape<xpu, 1, DType>(
+          Shape1(out.shape_.Size()), s);
+      out_data = static_cast<DType>(0);
+    }
+  });
+}
+
+/**
+ * FCompute forward for _npi_tensordot_int_axes: outputs[0] = tensordot(inputs[0], inputs[1])
+ * with the integer `axes` taken from TensordotIntAxesParam.
+ */
+template<typename xpu>
+void TensordotIntAxesOpForward(const nnvm::NodeAttrs& attrs,
+                               const OpContext& ctx,
+                               const std::vector<TBlob>& inputs,
+                               const std::vector<OpReqType>& req,
+                               const std::vector<TBlob>& outputs) {
+  CHECK_EQ(inputs.size(), 2U);
+  CHECK_EQ(outputs.size(), 1U);
+  CHECK_EQ(req.size(), 1U);
+
+  const int axes = nnvm::get<TensordotIntAxesParam>(attrs.parsed).axes;
+  TensordotIntAxesImpl<xpu>(axes, ctx,
+                            /* a   */ inputs[0],
+                            /* b   */ inputs[1],
+                            /* out */ outputs[0],
+                            req[0]);
+}
+
+/**
+ * Gradients of tensordot with an integer `axes`. The contracted axes are already trailing in a
+ * and leading in b, so a is used as an (ad1 x ad2) matrix and b as (bd1 x bd2) directly:
+ *   grad_b = a^T . out_grad,   grad_a = out_grad . b^T
+ * (the trailing bool arguments of MatrixDot are taken to select transposition of its first /
+ * second operand — NOTE(review): confirm against MatrixDot). req[0]/req[1] go to grad_a/grad_b.
+ */
+template<typename xpu>
+void TensordotIntAxesBackwardImpl(const int axes,
+                                  const OpContext& ctx,
+                                  const TBlob& out_grad,
+                                  const TBlob& a,
+                                  const TBlob& b,
+                                  const TBlob& grad_a,
+                                  const TBlob& grad_b,
+                                  const std::vector<OpReqType>& req) {
+  mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
+  const mxnet::TShape& a_shape = a.shape_;
+  const mxnet::TShape& b_shape = b.shape_;
+
+  Tuple<int> a_axes_summed;
+  Tuple<int> b_axes_summed;
+  GetSummedAxes(&a_axes_summed, &b_axes_summed, axes, a_shape);
+
+  Tuple<int> a_axes_remained;
+  Tuple<int> b_axes_remained;
+  Tuple<int> a_axes;
+  Tuple<int> b_axes;
+  GetReorderedAxes(a_axes_summed, &a_axes_remained, &a_axes, b_axes_summed, &b_axes_remained,
+                   &b_axes, a_shape, b_shape);
+
+  int ad1 = 1, ad2 = 1, bd1 = 1, bd2 = 1;
+  GetMatrixDimensions(&ad1, &ad2, &bd1, &bd2, a_axes_remained, a_axes_summed,
+                      b_axes_remained, b_axes_summed, a_shape, b_shape);
+
+  MSHADOW_REAL_TYPE_SWITCH(out_grad.type_flag_, DType, {
+    if (a.Size() == 0U || b.Size() == 0U || out_grad.Size() == 0U) {
+      // Mirrors the forward 0-size handling: with any operand empty, every gradient entry is
+      // an empty sum (or the gradient itself is empty), so write zeros where req asks for it.
+      const TBlob* grads[2] = {&grad_a, &grad_b};
+      for (int k = 0; k < 2; ++k) {
+        if (grads[k]->Size() == 0U || req[k] == kNullOp || req[k] == kAddTo) continue;
+        Tensor<xpu, 1, DType> g = grads[k]->get_with_shape<xpu, 1, DType>(
+            Shape1(grads[k]->Size()), s);
+        g = static_cast<DType>(0);
+      }
+      return;
+    }
+    MatrixDot<xpu>(ctx, a, out_grad, grad_b, req[1], ad1, ad2, ad1, bd2, true, false);
+    MatrixDot<xpu>(ctx, out_grad, b, grad_a, req[0], ad1, bd2, bd1, bd2, false, true);
+  });
+}
+
+/**
+ * FCompute backward for _npi_tensordot_int_axes.
+ * inputs = {out_grad, a, b}; outputs = {grad_a, grad_b}; one req entry per output.
+ */
+template<typename xpu>
+void TensordotIntAxesOpBackward(const nnvm::NodeAttrs& attrs,
+                                const OpContext& ctx,
+                                const std::vector<TBlob>& inputs,
+                                const std::vector<OpReqType>& req,
+                                const std::vector<TBlob>& outputs) {
+  CHECK_EQ(inputs.size(), 3U);
+  CHECK_EQ(outputs.size(), 2U);
+  CHECK_EQ(req.size(), 2U);
+
+  const int axes = nnvm::get<TensordotIntAxesParam>(attrs.parsed).axes;
+  TensordotIntAxesBackwardImpl<xpu>(axes, ctx,
+                                    /* out_grad */ inputs[0],
+                                    /* a        */ inputs[1],
+                                    /* b        */ inputs[2],
+                                    /* grad_a   */ outputs[0],
+                                    /* grad_b   */ outputs[1],
+                                    req);
+}
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_OPERATOR_NUMPY_NP_TENSORDOT_OP_INL_H_
diff --git a/src/operator/numpy/np_tensordot_op.cc b/src/operator/numpy/np_tensordot_op.cc
new file mode 100644
index 0000000..6d6756e
--- /dev/null
+++ b/src/operator/numpy/np_tensordot_op.cc
@@ -0,0 +1,226 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file np_tensordot_op.cc
+ * \brief CPU Implementation of numpy-compatible tensordot
+ */
+
+#include <string>
+#include "np_tensordot_op-inl.h"
+
+namespace mxnet {
+namespace op {
+
+// Shape inference for _npi_tensordot: out = a's remaining dims followed by b's remaining
+// dims; paired summed axes must agree, and known dims are propagated back into a and b.
+bool TensordotOpShape(const nnvm::NodeAttrs& attrs,
+                      mxnet::ShapeVector *in_attrs,
+                      mxnet::ShapeVector *out_attrs) {
+  CHECK_EQ(in_attrs->size(), 2U);
+  CHECK_EQ(out_attrs->size(), 1U);
+  const mxnet::TShape& a_shape = in_attrs->at(0);
+  const mxnet::TShape& b_shape = in_attrs->at(1);
+  if (!ndim_is_known(a_shape) || !ndim_is_known(b_shape)) {
+    return false;
+  }
+  CHECK_GE(a_shape.ndim(), 1) << "First input tensor should be at least 1 dimension";
+  CHECK_GE(b_shape.ndim(), 1) << "Second input tensor should be at least 1 dimension";
+
+  const TensordotParam& param = nnvm::get<TensordotParam>(attrs.parsed);
+  const Tuple<int>& a_axes_summed = param.a_axes_summed;
+  const Tuple<int>& b_axes_summed = param.b_axes_summed;
+  // Validate up front: the loops below pair the two tuples index by index.
+  CHECK_EQ(a_axes_summed.ndim(), b_axes_summed.ndim())
+      << "tensordot: a_axes_summed and b_axes_summed must have the same length";
+
+  Tuple<int> a_axes_remained;
+  Tuple<int> b_axes_remained;
+  Tuple<int> a_axes;
+  Tuple<int> b_axes;
+  GetReorderedAxes(a_axes_summed, &a_axes_remained, &a_axes, b_axes_summed, &b_axes_remained,
+                   &b_axes, a_shape, b_shape);
+
+  // Output: remaining axes of a, then remaining axes of b.
+  mxnet::TShape out_shape(a_axes_remained.ndim() + b_axes_remained.ndim(), -1);
+  for (int i = 0; i < a_axes_remained.ndim(); i++) {
+    out_shape[i] = a_shape[a_axes_remained[i]];
+  }
+  for (int i = 0; i < b_axes_remained.ndim(); i++) {
+    out_shape[a_axes_remained.ndim() + i] = b_shape[b_axes_remained[i]];
+  }
+  SHAPE_ASSIGN_CHECK(*out_attrs, 0, out_shape);
+
+  // Back-infer a: remaining dims from out_shape, summed dims from the paired axes of b.
+  mxnet::TShape tem_shape1(a_axes.ndim(), -1);
+  for (int i = 0; i < a_axes_remained.ndim(); i++) {
+    tem_shape1[a_axes_remained[i]] = out_shape[i];
+  }
+  for (int i = 0; i < a_axes_summed.ndim(); i++) {
+    tem_shape1[a_axes_summed[i]] = b_shape[b_axes_summed[i]];
+  }
+  SHAPE_ASSIGN_CHECK(*in_attrs, 0, tem_shape1);
+
+  // Back-infer b the same way (a_shape aliases (*in_attrs)[0], so it includes a's update).
+  mxnet::TShape tem_shape2(b_axes.ndim(), -1);
+  for (int i = 0; i < b_axes_remained.ndim(); i++) {
+    tem_shape2[b_axes_remained[i]] = out_shape[a_axes_remained.ndim() + i];
+  }
+  for (int i = 0; i < b_axes_summed.ndim(); i++) {
+    tem_shape2[b_axes_summed[i]] = a_shape[a_axes_summed[i]];
+  }
+  SHAPE_ASSIGN_CHECK(*in_attrs, 1, tem_shape2);
+
+  return shape_is_known(*in_attrs) && shape_is_known(*out_attrs);
+}
+
+DMLC_REGISTER_PARAMETER(TensordotParam);
+
+NNVM_REGISTER_OP(_npi_tensordot)  // tensordot with explicit (a_axes_summed, b_axes_summed)
+.set_attr_parser(mxnet::op::ParamParser<TensordotParam>)
+.set_num_inputs(2)
+.set_num_outputs(1)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const NodeAttrs& attrs) {
+    return std::vector<std::string>{"a", "b"};
+  })
+.set_attr<mxnet::FInferShape>("FInferShape", TensordotOpShape)
+.set_attr<nnvm::FInferType>("FInferType", mxnet::op::ElemwiseType<2, 1>)
+.set_attr<FResourceRequest>("FResourceRequest",  // workspace for transposed copies of a, b
+  [](const NodeAttrs& attrs) {
+    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+  })
+.set_attr<FCompute>("FCompute<cpu>", TensordotOpForward<cpu>)
+.set_attr<nnvm::FGradient>("FGradient", mxnet::op::ElemwiseGradUseIn{"_backward_npi_tensordot"})
+.add_argument("a", "NDArray-or-Symbol", "First input")
+.add_argument("b", "NDArray-or-Symbol", "Second input")
+.add_arguments(TensordotParam::__FIELDS__());
+
+NNVM_REGISTER_OP(_backward_npi_tensordot)  // inputs (out_grad, a, b) -> (grad_a, grad_b)
+.set_attr_parser(mxnet::op::ParamParser<TensordotParam>)
+.set_num_inputs(3)
+.set_num_outputs(2)
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_attr<FResourceRequest>("FResourceRequest",  // workspace for transposed inputs/grads
+  [](const NodeAttrs& attrs) {
+    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+  })
+.set_attr<FCompute>("FCompute<cpu>", TensordotOpBackward<cpu>);
+
+/*!
+ * \brief Shape inference for _npi_tensordot_int_axes.
+ *
+ * The last `axes` dims of a are contracted with the first `axes` dims of b; the output is
+ * a's kept dims followed by b's kept dims. Known dims are also propagated back into a and b.
+ */
+bool TensordotIntAxesOpShape(const nnvm::NodeAttrs& attrs,
+                             mxnet::ShapeVector *in_attrs,
+                             mxnet::ShapeVector *out_attrs) {
+  CHECK_EQ(in_attrs->size(), 2U);
+  CHECK_EQ(out_attrs->size(), 1U);
+  const mxnet::TShape& a_shape = in_attrs->at(0);
+  const mxnet::TShape& b_shape = in_attrs->at(1);
+  if (!ndim_is_known(a_shape) || !ndim_is_known(b_shape)) {
+    return false;  // ranks not known yet, nothing can be inferred
+  }
+  CHECK_GE(a_shape.ndim(), 1) << "First input tensor should be at least 1 dimension";
+  CHECK_GE(b_shape.ndim(), 1) << "Second input tensor should be at least 1 dimension";
+
+  // Contracted axes are derived from the integer parameter alone.
+  const int axes = nnvm::get<TensordotIntAxesParam>(attrs.parsed).axes;
+  Tuple<int> a_axes_summed;
+  Tuple<int> b_axes_summed;
+  GetSummedAxes(&a_axes_summed, &b_axes_summed, axes, a_shape);
+
+  // Kept (non-contracted) axes and the full axis orderings of a and b.
+  Tuple<int> a_keep;
+  Tuple<int> b_keep;
+  Tuple<int> a_order;
+  Tuple<int> b_order;
+  GetReorderedAxes(a_axes_summed, &a_keep, &a_order, b_axes_summed, &b_keep, &b_order,
+                   a_shape, b_shape);
+  CHECK_EQ(a_axes_summed.ndim(), b_axes_summed.ndim());
+  const int n_a_keep = a_keep.ndim();
+
+  // Output layout: kept dims of a, then kept dims of b.
+  mxnet::TShape out(n_a_keep + b_keep.ndim(), -1);
+  for (int k = 0; k < out.ndim(); ++k) {
+    out[k] = (k < n_a_keep) ? a_shape[a_keep[k]] : b_shape[b_keep[k - n_a_keep]];
+  }
+  SHAPE_ASSIGN_CHECK(*out_attrs, 0, out);
+
+  // Rebuild a's full shape: kept dims come from the output, contracted dims from b.
+  mxnet::TShape a_full(a_order.ndim(), -1);
+  for (int k = 0; k < n_a_keep; ++k) {
+    a_full[a_keep[k]] = out[k];
+  }
+  for (int k = 0; k < a_axes_summed.ndim(); ++k) {
+    a_full[a_axes_summed[k]] = b_shape[b_axes_summed[k]];
+  }
+  SHAPE_ASSIGN_CHECK(*in_attrs, 0, a_full);
+
+  // Same for b. a_shape aliases (*in_attrs)[0], so it already reflects the update above.
+  mxnet::TShape b_full(b_order.ndim(), -1);
+  for (int k = 0; k < b_keep.ndim(); ++k) {
+    b_full[b_keep[k]] = out[n_a_keep + k];
+  }
+  for (int k = 0; k < b_axes_summed.ndim(); ++k) {
+    b_full[b_axes_summed[k]] = a_shape[a_axes_summed[k]];
+  }
+  SHAPE_ASSIGN_CHECK(*in_attrs, 1, b_full);
+
+  return shape_is_known(*in_attrs) && shape_is_known(*out_attrs);
+}
+
+DMLC_REGISTER_PARAMETER(TensordotIntAxesParam);
+
+NNVM_REGISTER_OP(_npi_tensordot_int_axes)  // last `axes` dims of a vs first `axes` of b
+.set_attr_parser(mxnet::op::ParamParser<TensordotIntAxesParam>)
+.set_num_inputs(2)
+.set_num_outputs(1)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const NodeAttrs& attrs) {
+    return std::vector<std::string>{"a", "b"};
+  })
+.set_attr<mxnet::FInferShape>("FInferShape", TensordotIntAxesOpShape)
+.set_attr<nnvm::FInferType>("FInferType", mxnet::op::ElemwiseType<2, 1>)
+.set_attr<FResourceRequest>("FResourceRequest",  // NOTE(review): temp space may be unused
+  [](const NodeAttrs& attrs) {
+    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+  })
+.set_attr<FCompute>("FCompute<cpu>", TensordotIntAxesOpForward<cpu>)
+.set_attr<nnvm::FGradient>("FGradient",
+  mxnet::op::ElemwiseGradUseIn{"_backward_npi_tensordot_int_axes"})
+.add_argument("a", "NDArray-or-Symbol", "First input")
+.add_argument("b", "NDArray-or-Symbol", "Second input")
+.add_arguments(TensordotIntAxesParam::__FIELDS__());
+
+NNVM_REGISTER_OP(_backward_npi_tensordot_int_axes)  // (out_grad, a, b) -> (grad_a, grad_b)
+.set_attr_parser(mxnet::op::ParamParser<TensordotIntAxesParam>)
+.set_num_inputs(3)
+.set_num_outputs(2)
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_attr<FResourceRequest>("FResourceRequest",  // NOTE(review): temp space may be unused
+  [](const NodeAttrs& attrs) {
+    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+  })
+.set_attr<FCompute>("FCompute<cpu>", TensordotIntAxesOpBackward<cpu>);
+
+}  // namespace op
+}  // namespace mxnet
diff --git a/src/operator/numpy/np_tensordot_op.cu b/src/operator/numpy/np_tensordot_op.cu
new file mode 100644
index 0000000..e1d8a0b
--- /dev/null
+++ b/src/operator/numpy/np_tensordot_op.cu
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file np_tensordot_op.cu
+ * \brief GPU Implementation of numpy-compatible tensordot
+ */
+
+#include "np_tensordot_op-inl.h"
+namespace mxnet {
+namespace op {
+
+NNVM_REGISTER_OP(_npi_tensordot)  // GPU kernels; CPU ones are registered in np_tensordot_op.cc
+.set_attr<FCompute>("FCompute<gpu>", TensordotOpForward<gpu>);
+
+NNVM_REGISTER_OP(_backward_npi_tensordot)
+.set_attr<FCompute>("FCompute<gpu>", TensordotOpBackward<gpu>);
+
+NNVM_REGISTER_OP(_npi_tensordot_int_axes)  // integer-`axes` variant
+.set_attr<FCompute>("FCompute<gpu>", TensordotIntAxesOpForward<gpu>);
+
+NNVM_REGISTER_OP(_backward_npi_tensordot_int_axes)
+.set_attr<FCompute>("FCompute<gpu>", TensordotIntAxesOpBackward<gpu>);
+
+}  // namespace op
+}  // namespace mxnet
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index 6e3ca16..cd323e2 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -27,6 +27,156 @@ from mxnet.test_utils import same, assert_almost_equal, rand_shape_nd, rand_ndar
 from mxnet.test_utils import check_numeric_gradient
 from common import assertRaises, with_seed
 import random
+import collections
+
+
+@with_seed()
+@npx.use_np_shape
+def test_np_tensordot():
+    class TestTensordot(HybridBlock):
+        def __init__(self, axes):
+            super(TestTensordot, self).__init__()
+            self._axes = axes
+
+        def hybrid_forward(self, F, a, b):
+            return F.np.tensordot(a, b, self._axes)
+
+    def tensordot_backward(a, b, axes=2):
+        """Reference numpy gradients of sum(tensordot(a, b, axes)) w.r.t. (a, b).
+
+        `axes` follows numpy.tensordot: an int, or a pair of ints / int sequences.
+        """
+        if (a.ndim < 1) or (b.ndim < 1):
+            raise ValueError('An input is zero-dim')
+
+        if _np.isscalar(axes):
+            a_axes_summed = [i + a.ndim - axes for i in range(axes)]
+            b_axes_summed = list(range(axes))
+        else:
+            if len(axes) != 2:
+                raise ValueError('Axes must consist of two arrays.')
+            a_axes_summed, b_axes_summed = axes
+            # Scalars and tuples become lists: the axis lists are concatenated below.
+            a_axes_summed = [a_axes_summed] if _np.isscalar(a_axes_summed) else list(a_axes_summed)
+            b_axes_summed = [b_axes_summed] if _np.isscalar(b_axes_summed) else list(b_axes_summed)
+
+        if len(a_axes_summed) != len(b_axes_summed):
+            raise ValueError('Axes length mismatch')
+
+        a_axes_remained = [i for i in range(a.ndim) if i not in a_axes_summed]
+        a_axes = a_axes_remained + a_axes_summed
+
+        b_axes_remained = [i for i in range(b.ndim) if i not in b_axes_summed]
+        b_axes = b_axes_summed + b_axes_remained
+
+        ad1 = _np.prod([a.shape[i] for i in a_axes_remained]) if len(a_axes_remained) > 0 else 1
+        ad2 = _np.prod([a.shape[i] for i in a_axes_summed]) if len(a_axes_summed) > 0 else 1
+        bd1 = _np.prod([b.shape[i] for i in b_axes_summed]) if len(b_axes_summed) > 0 else 1
+        bd2 = _np.prod([b.shape[i] for i in b_axes_remained]) if len(b_axes_remained) > 0 else 1
+
+        # Head gradient of ones, matching mx_out.backward() with its default head gradient.
+        out_grad = _np.ones((ad1, bd2))
+
+        new_a = _np.transpose(a, a_axes)
+        new_a_shape = new_a.shape[:]
+        new_a = new_a.reshape((ad1, ad2))
+        new_b = _np.transpose(b, b_axes)
+        new_b_shape = new_b.shape[:]
+        new_b = new_b.reshape((bd1, bd2))
+
+        reverse_a_axes = [0] * len(a_axes)
+        for i, axis in enumerate(a_axes):
+            reverse_a_axes[axis] = i
+
+        reverse_b_axes = [0] * len(b_axes)
+        for i, axis in enumerate(b_axes):
+            reverse_b_axes[axis] = i
+
+        grad_b = _np.dot(new_a.T, out_grad).reshape(new_b_shape)
+        grad_b = _np.transpose(grad_b, reverse_b_axes)
+        grad_a = _np.dot(out_grad, new_b.T).reshape(new_a_shape)
+        grad_a = _np.transpose(grad_a, reverse_a_axes)
+
+        return [grad_a, grad_b]
+
+    # test non zero size input
+    tensor_shapes = [
+        ((3, 5), (5, 4), 1),  # (a_shape, b_shape, axes)
+        ((3,), (3,), 1),
+        ((3, 4, 5, 6, 7), (5, 6, 7, 1, 2), 3),
+        ((3, 5, 4, 6, 7), (7, 6, 5, 1, 2), [[1, 3, 4], [2, 1, 0]]),
+        ((3, 5, 4), (5, 4, 3), [[1, 0, 2], [0, 2, 1]]),
+        ((3, 5, 4), (5, 3, 4), [[2, 0], [2, 1]]),
+        ((2, 2), (2, 2), 2),
+        ((3, 5, 4), (5, ), [[1], [0]]),
+        ((2,), (2, 3), 1),
+        ((3,), (3,), 0),
+        ((2,), (2, 3), 0),
+        ((3, 5, 4), (5, ), 0),
+        ((2, 3, 4), (4, 3, 2), [[], []]),
+        ((3, 5, 4), (5, 4, 3), [1, 0]),  # one scalar axis per input
+        ((3, 5, 4), (5, 3, 4), ((2, 0), (2, 1)))  # tuple axes
+    ]
+
+    for hybridize in [True, False]:
+        for a_shape, b_shape, axes in tensor_shapes:
+            for dtype in [_np.float32, _np.float64]:
+                test_tensordot = TestTensordot(axes)
+                if hybridize:
+                    test_tensordot.hybridize()
+                a = rand_ndarray(shape=a_shape, dtype=dtype).as_np_ndarray()
+                b = rand_ndarray(shape=b_shape, dtype=dtype).as_np_ndarray()
+                a.attach_grad()
+                b.attach_grad()
+
+                np_out = _np.tensordot(a.asnumpy(), b.asnumpy(), axes)
+                with mx.autograd.record():
+                    mx_out = test_tensordot(a, b)
+                assert mx_out.shape == np_out.shape
+                assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
+                mx_out.backward()
+                np_backward = tensordot_backward(a.asnumpy(), b.asnumpy(), axes)
+                assert_almost_equal(a.grad.asnumpy(), np_backward[0], rtol=1e-3, atol=1e-5)
+                assert_almost_equal(b.grad.asnumpy(), np_backward[1], rtol=1e-3, atol=1e-5)
+
+                # Test imperative once again
+                mx_out = np.tensordot(a, b, axes)
+                np_out = _np.tensordot(a.asnumpy(), b.asnumpy(), axes)
+                assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
+
+                # test numeric gradient
+                a_sym = mx.sym.Variable("a").as_np_ndarray()
+                b_sym = mx.sym.Variable("b").as_np_ndarray()
+                mx_sym = mx.sym.np.tensordot(a_sym, b_sym, axes).as_nd_ndarray()
+                check_numeric_gradient(mx_sym, [a.as_nd_ndarray(), b.as_nd_ndarray()],
+                                       rtol=1e-1, atol=1e-1, dtype=dtype)
+
+    # test zero size input
+    zero_shapes = [
+        ((3, 0), (0, 5), 1),
+        ((3, 0), (0, 4), [1, 0]),
+        ((0, 3), (3, 5), 1)
+    ]
+
+    for hybridize in [True, False]:
+        for a_shape, b_shape, axes in zero_shapes:
+            for dtype in [_np.float32, _np.float64]:
+                test_tensordot = TestTensordot(axes)
+                if hybridize:
+                    test_tensordot.hybridize()
+                a = rand_ndarray(shape=a_shape, dtype=dtype).as_np_ndarray()
+                b = rand_ndarray(shape=b_shape, dtype=dtype).as_np_ndarray()
+
+                np_out = _np.tensordot(a.asnumpy(), b.asnumpy(), axes)
+                with mx.autograd.record():
+                    mx_out = test_tensordot(a, b)
+                assert mx_out.shape == np_out.shape
+                assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
+
+                # Test imperative once again
+                mx_out = np.tensordot(a, b, axes)
+                np_out = _np.tensordot(a.asnumpy(), b.asnumpy(), axes)
+                assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
 
 
 @with_seed()
@@ -102,17 +252,20 @@ def test_np_dot():
         ((3, 4, 5), ()),     # Case 3.5.1
         ((), (3, 4, 5)),     # Case 3.5.2
         ((3, 4, 5), (5, )),  # Case 4
-        ((3, 4, 5), (5, 2)),
-        ((5,), (5, 2))
+        ((3, 4, 5), (5, 2)), # Case 5
+        ((5,), (5, 2)),
+        ((3, 5, 4), (5, 4, 3)),  
+        ((3, 4), (5, 4, 3)),
+        ((4,), (5, 4, 3))
     ]
 
     eps = 1e-3
 
     for shape_a, shape_b in shapes:
         np_a = _np.random.uniform(-1.0, 1.0, shape_a)
-        np_a[abs(np_a) < eps] = 2 * eps;
+        np_a[abs(np_a) < eps] = 2 * eps
         np_b = _np.random.uniform(-1.0, 1.0, shape_b)
-        np_b[abs(np_b) < eps] = 2 * eps;
+        np_b[abs(np_b) < eps] = 2 * eps
         a = mx.nd.array(np_a)
         b = mx.nd.array(np_b)
         np_res = _np.dot(np_a, np_b)