Posted to commits@mxnet.apache.org by ha...@apache.org on 2019/07/18 00:00:37 UTC

[incubator-mxnet] 01/42: [Do not review] [Do not merge] New numpy-compatible sum (#14739)

This is an automated email from the ASF dual-hosted git repository.

haoj pushed a commit to branch numpy
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git

commit 6dc91b03f5546ccaef152080d48d753de763f439
Author: Hao Jin <hj...@gmail.com>
AuthorDate: Sun Apr 21 13:23:18 2019 -0700

    [Do not review] [Do not merge] New numpy-compatible sum (#14739)
    
    * Add numpy namespace and initial impl of np.sum (not complete)
    
    * Clean up
    
    * Fix import error
    
    * numpy sum
    
    * add test and backward data type support
    
    * add license to test_numpy_op.py
    
    * improve test to reduce flakiness
    
    * fix sanity build
    
    * extra numeric test and imperative test
    
    * add error message for initial argument
---
 python/mxnet/__init__.py                           |   1 +
 python/mxnet/base.py                               |  21 +-
 python/mxnet/ndarray/__init__.py                   |   2 +-
 .../mxnet/{symbol/__init__.py => ndarray/numpy.py} |  15 +-
 python/mxnet/{symbol => numpy}/__init__.py         |  17 +-
 python/mxnet/symbol/__init__.py                    |   2 +-
 python/mxnet/symbol/{__init__.py => numpy.py}      |  15 +-
 src/operator/numpy/np_broadcast_reduce_op.h        | 218 +++++++++++++++++++++
 src/operator/numpy/np_broadcast_reduce_op_value.cc |  78 ++++++++
 src/operator/numpy/np_broadcast_reduce_op_value.cu |  36 ++++
 src/operator/tensor/broadcast_reduce_op.h          |  74 +++++--
 tests/python/gpu/test_operator_gpu.py              |   1 +
 tests/python/unittest/test_numpy_op.py             |  92 +++++++++
 13 files changed, 512 insertions(+), 60 deletions(-)
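
For orientation before the diff: this commit wires a new mxnet.numpy namespace into the Python package and registers the backend op _numpy_sum under it for imperative use. A minimal usage sketch, based on the test added in tests/python/unittest/test_numpy_op.py below (assumes a build of this branch; shapes and dtypes here are illustrative, not from the commit):

    import mxnet as mx
    from mxnet import numpy as np  # new namespace introduced by this commit

    x = mx.nd.random.uniform(-1.0, 1.0, shape=(2, 3, 4), dtype='float32')
    # imperative numpy-compatible sum; tuple axes, dtype and keepdims follow numpy semantics
    out = np.sum(x, axis=(0, 2), dtype='float64', keepdims=True)
    print(out.shape)  # (1, 3, 1)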

diff --git a/python/mxnet/__init__.py b/python/mxnet/__init__.py
index ab4bffd..a850b38 100644
--- a/python/mxnet/__init__.py
+++ b/python/mxnet/__init__.py
@@ -26,6 +26,7 @@ from . import engine
 from .base import MXNetError
 from .util import is_np_shape, set_np_shape, np_shape, use_np_shape
 from . import base
+from . import numpy
 from . import contrib
 from . import ndarray
 from . import ndarray as nd
diff --git a/python/mxnet/base.py b/python/mxnet/base.py
index 73fae48..c435317 100644
--- a/python/mxnet/base.py
+++ b/python/mxnet/base.py
@@ -561,7 +561,7 @@ def _as_list(obj):
         return [obj]
 
 
-_OP_NAME_PREFIX_LIST = ['_contrib_', '_linalg_', '_sparse_', '_image_', '_random_']
+_OP_NAME_PREFIX_LIST = ['_contrib_', '_linalg_', '_sparse_', '_image_', '_random_', '_numpy_']
 
 
 def _get_op_name_prefix(op_name):
@@ -607,6 +607,15 @@ def _init_op_module(root_namespace, module_name, make_op_func):
     # use mx.nd.contrib or mx.sym.contrib from now on
     contrib_module_name_old = "%s.contrib.%s" % (root_namespace, module_name)
     contrib_module_old = sys.modules[contrib_module_name_old]
+    # special handling of registering numpy ops
+    # only expose mxnet.numpy.op_name to users for imperative mode.
+    # Symbolic mode should be used in Gluon.
+    if module_name == 'ndarray':
+        numpy_module_name = "%s.numpy" % root_namespace
+        numpy_module = sys.modules[numpy_module_name]
+    else:
+        numpy_module_name = None
+        numpy_module = None
     submodule_dict = {}
     for op_name_prefix in _OP_NAME_PREFIX_LIST:
         submodule_dict[op_name_prefix] =\
@@ -645,6 +654,16 @@ def _init_op_module(root_namespace, module_name, make_op_func):
             function.__module__ = contrib_module_name_old
             setattr(contrib_module_old, function.__name__, function)
             contrib_module_old.__all__.append(function.__name__)
+        elif op_name_prefix == '_numpy_' and numpy_module_name is not None:
+            # only register numpy ops under mxnet.numpy in imperative mode
+            hdl = OpHandle()
+            check_call(_LIB.NNGetOpHandle(c_str(name), ctypes.byref(hdl)))
+            # TODO(reminisce): Didn't consider third level module here, e.g. mxnet.numpy.random.
+            func_name = name[len(op_name_prefix):]
+            function = make_op_func(hdl, name, func_name)
+            function.__module__ = numpy_module_name
+            setattr(numpy_module, function.__name__, function)
+            numpy_module.__all__.append(function.__name__)
 
 
 def _generate_op_module_signature(root_namespace, module_name, op_code_gen_func):
diff --git a/python/mxnet/ndarray/__init__.py b/python/mxnet/ndarray/__init__.py
index f09908e..a102399 100644
--- a/python/mxnet/ndarray/__init__.py
+++ b/python/mxnet/ndarray/__init__.py
@@ -17,7 +17,7 @@
 
 """NDArray API of MXNet."""
 
-from . import _internal, contrib, linalg, op, random, sparse, utils, image, ndarray
+from . import _internal, contrib, linalg, op, random, sparse, utils, image, ndarray, numpy
 # pylint: disable=wildcard-import, redefined-builtin
 try:
     from .gen_op import * # pylint: disable=unused-wildcard-import
diff --git a/python/mxnet/symbol/__init__.py b/python/mxnet/ndarray/numpy.py
similarity index 63%
copy from python/mxnet/symbol/__init__.py
copy to python/mxnet/ndarray/numpy.py
index f438e49..0826ac8 100644
--- a/python/mxnet/symbol/__init__.py
+++ b/python/mxnet/ndarray/numpy.py
@@ -15,17 +15,4 @@
 # specific language governing permissions and limitations
 # under the License.
 
-"""Symbol API of MXNet."""
-
-from . import _internal, contrib, linalg, op, random, sparse, image, symbol
-# pylint: disable=wildcard-import, redefined-builtin
-try:
-    from .gen_op import * # pylint: disable=unused-wildcard-import
-except ImportError:
-    pass
-from . import register
-from .op import *
-from .symbol import *
-# pylint: enable=wildcard-import
-
-__all__ = op.__all__ + symbol.__all__ + ['contrib', 'linalg', 'random', 'sparse', 'image']
+__all__ = []
diff --git a/python/mxnet/symbol/__init__.py b/python/mxnet/numpy/__init__.py
similarity index 63%
copy from python/mxnet/symbol/__init__.py
copy to python/mxnet/numpy/__init__.py
index f438e49..b1139a0 100644
--- a/python/mxnet/symbol/__init__.py
+++ b/python/mxnet/numpy/__init__.py
@@ -1,3 +1,5 @@
+#!/usr/bin/env python
+
 # Licensed to the Apache Software Foundation (ASF) under one
 # or more contributor license agreements.  See the NOTICE file
 # distributed with this work for additional information
@@ -15,17 +17,4 @@
 # specific language governing permissions and limitations
 # under the License.
 
-"""Symbol API of MXNet."""
-
-from . import _internal, contrib, linalg, op, random, sparse, image, symbol
-# pylint: disable=wildcard-import, redefined-builtin
-try:
-    from .gen_op import * # pylint: disable=unused-wildcard-import
-except ImportError:
-    pass
-from . import register
-from .op import *
-from .symbol import *
-# pylint: enable=wildcard-import
-
-__all__ = op.__all__ + symbol.__all__ + ['contrib', 'linalg', 'random', 'sparse', 'image']
+__all__ = []
diff --git a/python/mxnet/symbol/__init__.py b/python/mxnet/symbol/__init__.py
index f438e49..326e4f5 100644
--- a/python/mxnet/symbol/__init__.py
+++ b/python/mxnet/symbol/__init__.py
@@ -17,7 +17,7 @@
 
 """Symbol API of MXNet."""
 
-from . import _internal, contrib, linalg, op, random, sparse, image, symbol
+from . import _internal, contrib, linalg, op, random, sparse, image, symbol, numpy
 # pylint: disable=wildcard-import, redefined-builtin
 try:
     from .gen_op import * # pylint: disable=unused-wildcard-import
diff --git a/python/mxnet/symbol/__init__.py b/python/mxnet/symbol/numpy.py
similarity index 63%
copy from python/mxnet/symbol/__init__.py
copy to python/mxnet/symbol/numpy.py
index f438e49..0826ac8 100644
--- a/python/mxnet/symbol/__init__.py
+++ b/python/mxnet/symbol/numpy.py
@@ -15,17 +15,4 @@
 # specific language governing permissions and limitations
 # under the License.
 
-"""Symbol API of MXNet."""
-
-from . import _internal, contrib, linalg, op, random, sparse, image, symbol
-# pylint: disable=wildcard-import, redefined-builtin
-try:
-    from .gen_op import * # pylint: disable=unused-wildcard-import
-except ImportError:
-    pass
-from . import register
-from .op import *
-from .symbol import *
-# pylint: enable=wildcard-import
-
-__all__ = op.__all__ + symbol.__all__ + ['contrib', 'linalg', 'random', 'sparse', 'image']
+__all__ = []
diff --git a/src/operator/numpy/np_broadcast_reduce_op.h b/src/operator/numpy/np_broadcast_reduce_op.h
new file mode 100644
index 0000000..c516e6b
--- /dev/null
+++ b/src/operator/numpy/np_broadcast_reduce_op.h
@@ -0,0 +1,218 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *  Copyright (c) 2015 by Contributors
+ * \file broadcast_reduce_op.h
+ * \brief Function definition of broadcast and reduce operators
+ */
+#ifndef MXNET_OPERATOR_NUMPY_NP_BROADCAST_REDUCE_OP_H_
+#define MXNET_OPERATOR_NUMPY_NP_BROADCAST_REDUCE_OP_H_
+
+#include <algorithm>
+#include <vector>
+#include "../tensor/broadcast_reduce_op.h"
+
+namespace mxnet {
+namespace op {
+
+struct NumpyReduceAxesParam : public dmlc::Parameter<NumpyReduceAxesParam> {
+  dmlc::optional<mxnet::Tuple<int>> axis;
+  dmlc::optional<int> dtype;
+  bool keepdims;
+  dmlc::optional<double> initial;
+  DMLC_DECLARE_PARAMETER(NumpyReduceAxesParam) {
+    DMLC_DECLARE_FIELD(axis)
+      .set_default(dmlc::optional<mxnet::Tuple<int>>())
+      .describe("Axis or axes along which a sum is performed. The default, axis=None, will sum "
+                "all of the elements of the input array. If axis is negative it counts from the "
+                "last to the first axis.");
+    DMLC_DECLARE_FIELD(dtype)
+      .add_enum("float16", mshadow::kFloat16)
+      .add_enum("float32", mshadow::kFloat32)
+      .add_enum("float64", mshadow::kFloat64)
+      .add_enum("int8", mshadow::kInt8)
+      .add_enum("int32", mshadow::kInt32)
+      .add_enum("int64", mshadow::kInt64)
+      .set_default(dmlc::optional<int>())
+      .describe("The type of the returned array and of the accumulator in which the elements are "
+                "summed. The dtype of a is used by default unless a has an integer dtype of less "
+                "precision than the default platform integer. In that case, if a is signed then "
+                "the platform integer is used while if a is unsigned then an unsigned integer of "
+                "the same precision as the platform integer is used.");
+    DMLC_DECLARE_FIELD(keepdims).set_default(false)
+      .describe("If this is set to `True`, the reduced axes are left "
+                "in the result as dimension with size one.");
+    DMLC_DECLARE_FIELD(initial).set_default(dmlc::optional<double>())
+      .describe("Starting value for the sum.");
+  }
+};
+
+inline TShape NumpyReduceAxesShapeImpl(const TShape& ishape,
+                                       const dmlc::optional<mxnet::Tuple<int>>& axis,
+                                       bool keepdims) {
+  // TODO(junwu): improve the logic
+  // If input is a scalar, output should be a scalar too
+  if (ishape.ndim() == 0) {
+    if (axis.has_value()) {
+      const mxnet::Tuple<int>& axes = axis.value();
+      if (axes.ndim() > 0) {
+        CHECK_EQ(axes.ndim(), 1);
+        CHECK(axes[0] == 0 || axes[0] == -1);
+      }
+    }
+    return TShape(0, -1);
+  }
+
+  // axis=None, do global reduction
+  if (!axis.has_value()) {
+    if (keepdims) {
+      return TShape(ishape.ndim(), 1);
+    } else {
+      return TShape(0, -1);
+    }
+  }
+
+  // axis = (), will return identity(input)
+  if (axis.value().ndim() == 0) {
+    return ishape;
+  }
+
+  // axis has value
+  mxnet::Tuple<int> axes(axis.value());
+  for (index_t i = 0; i < axes.ndim(); i++) {
+    if (axes[i] < 0) {
+      axes[i] += ishape.ndim();
+    }
+  }
+  std::sort(axes.begin(), axes.end());
+
+  for (index_t i = 1; i < axes.ndim(); i++) {
+    CHECK_LT(axes[i-1], axes[i])
+        << "Reduction axes have duplicates "
+        << axes;
+  }
+  CHECK_LT(axes[axes.ndim()-1], ishape.ndim())
+      << "Reduction axis " << axes[axes.ndim()-1]
+      << " Exceeds input dimensions " << ishape;
+  CHECK_GE(axes[0], 0)
+      << "Reduction axis " << axis.value()
+      << " Exceeds input dimensions " << ishape;
+
+  TShape oshape;
+  if (keepdims) {
+    oshape = TShape(ishape);
+  } else {
+    oshape = TShape(ishape.ndim() - axes.ndim(), -1);
+  }
+
+  if (keepdims) {
+    for (index_t i = 0; i < axes.ndim(); ++i) {
+      oshape[axes[i]] = 1;
+    }
+  } else {
+    for (index_t i = 0, j = 0, k = 0; i < ishape.ndim(); ++i) {
+      if (j < axes.ndim() && i == axes[j]) {
+        ++j;
+        continue;
+      }
+      oshape[k++] = ishape[i];
+    }
+  }
+  return oshape;
+}
+
+inline bool NumpyReduceAxesShape(const nnvm::NodeAttrs& attrs,
+                                 std::vector<TShape> *in_attrs,
+                                 std::vector<TShape> *out_attrs) {
+  CHECK_EQ(in_attrs->size(), 1U);
+  CHECK_EQ(out_attrs->size(), 1U);
+  if (!shape_is_known(in_attrs->at(0))) {
+    return false;
+  }
+  const NumpyReduceAxesParam& param = nnvm::get<NumpyReduceAxesParam>(attrs.parsed);
+  SHAPE_ASSIGN_CHECK(*out_attrs, 0,
+                     NumpyReduceAxesShapeImpl((*in_attrs)[0], param.axis, param.keepdims));
+  return shape_is_known(out_attrs->at(0));
+}
+
+template<bool safe_acc_hint = false>
+inline bool NeedSafeAcc(int itype, int otype) {
+  bool rule = (itype != otype) || (itype != mshadow::kFloat32 && itype != mshadow::kFloat64);
+  return safe_acc_hint && rule;
+}
+
+template<typename xpu, typename reducer, bool safe_acc_hint = false, bool normalize = false,
+         typename OP = op::mshadow_op::identity>
+void NumpyReduceAxesCompute(const nnvm::NodeAttrs& attrs,
+                            const OpContext& ctx,
+                            const std::vector<TBlob>& inputs,
+                            const std::vector<OpReqType>& req,
+                            const std::vector<TBlob>& outputs) {
+  const NumpyReduceAxesParam& param = nnvm::get<NumpyReduceAxesParam>(attrs.parsed);
+  if (param.initial.has_value()) {
+    LOG(FATAL) << "initial is not supported yet";
+  }
+  if (param.axis.has_value() && param.axis.value().ndim() == 0) {
+    UnaryOp::IdentityCompute<xpu>(attrs, ctx, inputs, req, outputs);
+  }
+  TShape small;
+  if (param.keepdims) {
+    small = outputs[0].shape_;
+  } else {
+    small = NumpyReduceAxesShapeImpl(inputs[0].shape_, param.axis, true);
+  }
+
+  if (NeedSafeAcc<safe_acc_hint>(inputs[0].type_flag_, outputs[0].type_flag_)) {
+    ReduceAxesComputeImpl<xpu, reducer, true, normalize, OP>(ctx, inputs, req, outputs, small);
+  } else {
+    ReduceAxesComputeImpl<xpu, reducer, false, normalize, OP>(ctx, inputs, req, outputs, small);
+  }
+}
+
+template<typename xpu, bool normalize = false>
+inline void NumpyReduceAxesBackwardUseNone(const nnvm::NodeAttrs& attrs,
+                                           const OpContext& ctx,
+                                           const std::vector<TBlob>& inputs,
+                                           const std::vector<OpReqType>& req,
+                                           const std::vector<TBlob>& outputs) {
+  using namespace mshadow;
+  using namespace mshadow::expr;
+  const NumpyReduceAxesParam& param = nnvm::get<NumpyReduceAxesParam>(attrs.parsed);
+  TShape small;
+  if (param.keepdims) {
+    small = inputs[0].shape_;
+  } else {
+    small = NumpyReduceAxesShapeImpl(outputs[0].shape_, param.axis, true);
+  }
+
+  BroadcastComputeImpl<xpu>(attrs, ctx, inputs, req, outputs, small);
+  if (normalize) {
+    Stream<xpu> *s = ctx.get_stream<xpu>();
+    MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, IType, {
+      Tensor<xpu, 1, IType> igrad = outputs[0].FlatTo1D<xpu, IType>(s);
+      printf("output size: %lu input_size: %lu\n", outputs[0].Size(), inputs[0].Size());
+      igrad /= scalar<IType>(outputs[0].Size()/inputs[0].Size());
+    });
+  }
+}
+
+}  // namespace op
+}  // namespace mxnet
+#endif  // MXNET_OPERATOR_NUMPY_NP_BROADCAST_REDUCE_OP_H_
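
The shape inference above (NumpyReduceAxesShapeImpl) is intended to follow NumPy's reduction rules: axis=None collapses everything, axis=() is an identity, negative axes are normalized and duplicates rejected, and keepdims retains reduced dimensions with size one. As a rough reference only, the same cases expressed with plain NumPy (not the new operator):

    import numpy as np

    a = np.zeros((2, 3, 4))
    np.sum(a).shape                             # ()        axis=None, global reduction
    np.sum(a, axis=None, keepdims=True).shape   # (1, 1, 1)
    np.sum(a, axis=()).shape                    # (2, 3, 4) empty axis tuple -> identity
    np.sum(a, axis=(-1, 0)).shape               # (3,)      negative axes normalized
    np.sum(a, axis=1, keepdims=True).shape      # (2, 1, 4) reduced dim kept with size 1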
diff --git a/src/operator/numpy/np_broadcast_reduce_op_value.cc b/src/operator/numpy/np_broadcast_reduce_op_value.cc
new file mode 100644
index 0000000..6c81bf6
--- /dev/null
+++ b/src/operator/numpy/np_broadcast_reduce_op_value.cc
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *  Copyright (c) 2019 by Contributors
+ * \file np_reduce_op_value.cc
+ * \brief CPU Implementation of broadcast and reduce functions based on value.
+ */
+
+#include "np_broadcast_reduce_op.h"
+
+namespace mxnet {
+namespace op {
+
+DMLC_REGISTER_PARAMETER(NumpyReduceAxesParam);
+
+inline bool NumpySumType(const nnvm::NodeAttrs& attrs,
+                         std::vector<int> *in_attrs,
+                         std::vector<int> *out_attrs) {
+  CHECK_EQ(in_attrs->size(), 1U);
+  CHECK_EQ(out_attrs->size(), 1U);
+  const NumpyReduceAxesParam &param = nnvm::get<NumpyReduceAxesParam>(attrs.parsed);
+
+  if (param.dtype.has_value()) {
+    TYPE_ASSIGN_CHECK(*out_attrs, 0, param.dtype.value());
+  } else {
+    TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
+    TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0));
+  }
+
+  return out_attrs->at(0) != -1 && in_attrs->at(0) != -1;
+}
+
+NNVM_REGISTER_OP(_numpy_sum)
+.describe(R"code()code" ADD_FILELINE)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<NumpyReduceAxesParam>)
+.set_attr<mxnet::FInferShape>("FInferShape", NumpyReduceAxesShape)
+.set_attr<nnvm::FInferType>("FInferType", NumpySumType)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const NodeAttrs& attrs) {
+    return std::vector<std::string>{"a"};
+  })
+.add_argument("a", "NDArray-or-Symbol", "The input")
+.add_arguments(NumpyReduceAxesParam::__FIELDS__())
+.set_attr<FCompute>("FCompute<cpu>", NumpyReduceAxesCompute<cpu, mshadow_op::sum, true>)
+.set_attr<FResourceRequest>("FResourceRequest",
+  [](const NodeAttrs& attrs) {
+    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+  })
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_numpy_sum"});
+
+NNVM_REGISTER_OP(_backward_numpy_sum)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<NumpyReduceAxesParam>)
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_num_inputs(1)
+.set_attr<FCompute>("FCompute<cpu>", NumpyReduceAxesBackwardUseNone<cpu>);
+
+}  // namespace op
+}  // namespace mxnet
diff --git a/src/operator/numpy/np_broadcast_reduce_op_value.cu b/src/operator/numpy/np_broadcast_reduce_op_value.cu
new file mode 100644
index 0000000..aa6bed4
--- /dev/null
+++ b/src/operator/numpy/np_broadcast_reduce_op_value.cu
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *  Copyright (c) 2019 by Contributors
+ * \file np_reduce_op_value.cu
+ * \brief GPU Implementation of reduce functions based on value.
+ */
+#include "np_broadcast_reduce_op.h"
+
+namespace mxnet {
+namespace op {
+NNVM_REGISTER_OP(_numpy_sum)
+.set_attr<FCompute>("FCompute<gpu>", NumpyReduceAxesCompute<gpu, mshadow_op::sum, true>);
+
+NNVM_REGISTER_OP(_backward_numpy_sum)
+.set_attr<FCompute>("FCompute<gpu>", NumpyReduceAxesBackwardUseNone<gpu>);
+
+}  // namespace op
+}  // namespace mxnet
diff --git a/src/operator/tensor/broadcast_reduce_op.h b/src/operator/tensor/broadcast_reduce_op.h
index c7c4993..a6ee242 100644
--- a/src/operator/tensor/broadcast_reduce_op.h
+++ b/src/operator/tensor/broadcast_reduce_op.h
@@ -968,6 +968,34 @@ void ReduceAxesBackwardUseInOut(const nnvm::NodeAttrs& attrs,
   ReduceAxesBackwardUseInOutImpl<xpu, OP, normalize>(ctx, small, inputs, req, outputs);
 }
 
+template<typename OP>
+struct broadcast_kernel {
+  template<typename IType, typename OType>
+  MSHADOW_XINLINE static void Map(index_t i,
+                                  IType *input,
+                                  OType *output,
+                                  mshadow::Shape<5> in_shape,
+                                  mshadow::Shape<5> out_shape,
+                                  const OpReqType req,
+                                  const uint32_t ndim) {
+    size_t in_stride = 1;
+    size_t out_stride = 1;
+    index_t idx = i;
+    index_t in_idx = i;
+    for (int iter = ndim - 1; iter >= 0; --iter) {
+      size_t dim_idx = idx % out_shape[iter];
+      in_idx -= dim_idx * out_stride;
+      if (in_shape[iter] != 1) {
+        in_idx += dim_idx * in_stride;
+      }
+      idx /= out_shape[iter];
+      in_stride *= in_shape[iter];
+      out_stride *= out_shape[iter];
+    }
+    KERNEL_ASSIGN(output[i], req, OP::Map(input[in_idx]));
+  }
+};
+
 template<typename xpu>
 inline void BroadcastComputeImpl(const nnvm::NodeAttrs& attrs,
                                  const OpContext& ctx,
@@ -977,24 +1005,40 @@ inline void BroadcastComputeImpl(const nnvm::NodeAttrs& attrs,
                                  const mxnet::TShape& small) {
   using namespace mshadow;
   using namespace mshadow::expr;
+  using namespace mxnet_op;
   mxnet::TShape src_shape, dst_shape;
   BroadcastReduceShapeCompact(outputs[0].shape_, small, &dst_shape, &src_shape);
   Stream<xpu> *s = ctx.get_stream<xpu>();
-  MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
-    if (dst_shape.ndim() == 2) {
-      Tensor<xpu, 2, DType> out =
-        outputs[0].get_with_shape<xpu, 2, DType>(dst_shape.get<2>(), s);
-      Tensor<xpu, 2, DType> data =
-        inputs[0].get_with_shape<xpu, 2, DType>(src_shape.get<2>(), s);
-      ASSIGN_DISPATCH(out, req[0], broadcast_to(data, dst_shape));
-    } else {
-      const int ndim = MXNET_SPECIAL_MAX_NDIM;
-      Tensor<xpu, ndim, DType> out =
-        outputs[0].get_with_shape<xpu, ndim, DType>(dst_shape.get<ndim>(), s);
-      Tensor<xpu, ndim, DType> data =
-        inputs[0].get_with_shape<xpu, ndim, DType>(src_shape.get<ndim>(), s);
-      ASSIGN_DISPATCH(out, req[0], broadcast_to(data, dst_shape));
-    }
+  MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, IType, {
+    MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, OType, {
+      mshadow::Shape<5> in_shape;
+      mshadow::Shape<5> out_shape;
+      for (int i = 0; i < 5; ++i) {
+        if (i < dst_shape.ndim()) {
+          in_shape[i] = src_shape[i];
+          out_shape[i] = dst_shape[i];
+        } else {
+          in_shape[i] = 1;
+          out_shape[i] = 1;
+        }
+      }
+      if (dst_shape.ndim() == 2) {
+        Tensor<xpu, 2, OType> out =
+          outputs[0].get_with_shape<xpu, 2, OType>(dst_shape.get<2>(), s);
+        Tensor<xpu, 2, IType> data =
+          inputs[0].get_with_shape<xpu, 2, IType>(src_shape.get<2>(), s);
+        Kernel<broadcast_kernel<mshadow_op::identity>, xpu>::Launch(
+          s, out.shape_.Size(), data.dptr_, out.dptr_, in_shape, out_shape, req[0], 2);
+      } else {
+        const int ndim = MXNET_SPECIAL_MAX_NDIM;
+        Tensor<xpu, ndim, OType> out =
+          outputs[0].get_with_shape<xpu, ndim, OType>(dst_shape.get<ndim>(), s);
+        Tensor<xpu, ndim, IType> data =
+          inputs[0].get_with_shape<xpu, ndim, IType>(src_shape.get<ndim>(), s);
+        Kernel<broadcast_kernel<mshadow_op::identity>, xpu>::Launch(
+          s, out.shape_.Size(), data.dptr_, out.dptr_, in_shape, out_shape, req[0], ndim);
+      }
+    });
   });
 }
 
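The new broadcast_kernel above replaces the expression-template broadcast_to path with an explicit index computation: for each flat output index it peels off per-dimension coordinates against out_shape and accumulates the corresponding input offset, skipping the contribution of size-1 (broadcast) input dimensions. A rough Python transcription of that index math, for illustration only:

    # illustrative transcription of broadcast_kernel's index computation (row-major shapes)
    def broadcast_src_index(i, in_shape, out_shape):
        in_stride, out_stride = 1, 1
        idx, in_idx = i, i
        for d in reversed(range(len(out_shape))):
            dim_idx = idx % out_shape[d]
            in_idx -= dim_idx * out_stride   # drop this dim's output contribution
            if in_shape[d] != 1:             # broadcast dims map to input index 0
                in_idx += dim_idx * in_stride
            idx //= out_shape[d]
            in_stride *= in_shape[d]
            out_stride *= out_shape[d]
        return in_idx

    # e.g. broadcasting (1, 3) -> (2, 3): output element 4 reads input element 1
    assert broadcast_src_index(4, (1, 3), (2, 3)) == 1
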
diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py
index 5b4f81d..105b5aa 100644
--- a/tests/python/gpu/test_operator_gpu.py
+++ b/tests/python/gpu/test_operator_gpu.py
@@ -35,6 +35,7 @@ sys.path.insert(0, os.path.join(curr_path, '../unittest'))
 from common import setup_module, with_seed, teardown, assert_raises_cudnn_not_satisfied
 from common import run_in_spawned_process
 from test_operator import *
+from test_numpy_op import *
 from test_optimizer import *
 from test_random import *
 from test_exc_handling import *
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
new file mode 100644
index 0000000..75e3428
--- /dev/null
+++ b/tests/python/unittest/test_numpy_op.py
@@ -0,0 +1,92 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# pylint: skip-file
+from __future__ import absolute_import
+import numpy as _np
+import mxnet as mx
+from mxnet import numpy as np
+from mxnet.gluon import HybridBlock
+from mxnet.test_utils import same, assert_almost_equal, rand_shape_nd, rand_ndarray
+from mxnet.test_utils import check_numeric_gradient
+from common import with_seed
+import random
+
+
+@mx.use_np_compat
+@with_seed()
+def test_np_sum():
+    class TestSum(HybridBlock):
+        def __init__(self, axis=None, dtype=None, keepdims=False):# , initial=None):
+            super(TestSum, self).__init__()
+            self._axis = axis
+            self._dtype = dtype
+            self._keepdims = keepdims
+
+        def hybrid_forward(self, F, a, *args, **kwargs):
+            return F.numpy.sum(a, axis=self._axis, dtype=self._dtype, keepdims=self._keepdims)
+
+    def is_int(dtype):
+        return 'int' in dtype
+
+    in_data_dim = random.choice([2, 3, 4])
+    shape = rand_shape_nd(in_data_dim, dim=3)
+    acc_type = {'float16': 'float32', 'float32': 'float64', 'float64': 'float64',
+                'int8': 'int32', 'int32': 'int64', 'int64': 'int64'}
+    for hybridize in [False, True]:
+        for keepdims in [True, False]:
+            for axis in ([i for i in range(in_data_dim)] + [(), None]):
+                for itype in ['float16', 'float32', 'float64', 'int8', 'int32', 'int64']:
+                    for dtype in ['float16', 'float32', 'float64', 'int8', 'int32', 'int64']:
+                        if is_int(dtype) and not is_int(itype):
+                            continue
+                        # test gluon
+                        test_sum = TestSum(axis=axis, dtype=dtype, keepdims=keepdims)
+                        if hybridize:
+                            test_sum.hybridize()
+                        if is_int(itype):
+                            x = _np.random.randint(-128, 128, shape, dtype=itype)
+                            x = mx.nd.array(x)
+                        else:
+                            x = mx.nd.random.uniform(-1.0, 1.0, shape=shape, dtype=itype)
+                        x.attach_grad()
+                        expected_ret = _np.sum(x.asnumpy(), axis=axis, dtype=acc_type[itype], keepdims=keepdims)
+                        expected_ret = expected_ret.astype(dtype)
+                        with mx.autograd.record():
+                            y = test_sum(x)
+                        assert y.shape == expected_ret.shape
+                        assert_almost_equal(y.asnumpy(), expected_ret, rtol=1e-3 if dtype == 'float16' else 1e-3,
+                                            atol=1e-5 if dtype == 'float16' else 1e-5)
+
+                        y.backward()
+                        assert same(x.grad.asnumpy(), _np.ones(shape=x.shape, dtype=x.dtype))
+
+                        # test numeric
+                        if itype == 'float32' and dtype == 'float32':
+                            x_sym = mx.sym.Variable("x")
+                            mx_sym = mx.sym.numpy.sum(x_sym, axis=axis, dtype=dtype, keepdims=keepdims)
+                            check_numeric_gradient(mx_sym, [x], numeric_eps=1e-3, rtol=1e-3, atol=1e-4, dtype=_np.float32)
+
+                        # test imperative
+                        mx_out = np.sum(x, axis=axis, dtype=dtype, keepdims=keepdims)
+                        np_out = _np.sum(x.asnumpy(), axis=axis, dtype=acc_type[itype], keepdims=keepdims).astype(dtype)
+                        assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
+
+
+if __name__ == '__main__':
+    import nose
+    nose.runmodule()