You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by jx...@apache.org on 2018/06/25 23:45:38 UTC

[incubator-mxnet] branch master updated: [MXNET-349] Histogram Operator (#10931)

This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new ed7e360  [MXNET-349] Histogram Operator (#10931)
ed7e360 is described below

commit ed7e3602a8046646582c0c681b70d9556f5fa0a4
Author: Hao Jin <ha...@users.noreply.github.com>
AuthorDate: Mon Jun 25 16:45:32 2018 -0700

    [MXNET-349] Histogram Operator (#10931)
    
    * implementation of histogram operator
    
    * address code reviews and code re-design
    
    * add exception for invalid inputs
    
    * address code reviews
    
    * add symbol and symbolic forward check for histogram
---
 python/mxnet/ndarray/ndarray.py              |  35 +++++-
 python/mxnet/symbol/symbol.py                |  30 ++++-
 src/common/cuda_utils.h                      |  30 +++++
 src/operator/tensor/histogram-inl.h          | 172 +++++++++++++++++++++++++++
 src/operator/tensor/histogram.cc             | 159 +++++++++++++++++++++++++
 src/operator/tensor/histogram.cu             | 111 +++++++++++++++++
 src/operator/tensor/util/tensor_util-inl.cuh |   4 +-
 tests/python/unittest/test_operator.py       |  34 ++++++
 8 files changed, 571 insertions(+), 4 deletions(-)

diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py
index f017d7e..002ce3e 100644
--- a/python/mxnet/ndarray/ndarray.py
+++ b/python/mxnet/ndarray/ndarray.py
@@ -46,7 +46,7 @@ __all__ = ["NDArray", "concatenate", "_DTYPE_NP_TO_MX", "_DTYPE_MX_TO_NP", "_GRA
            "ones", "add", "arange", "eye", "divide", "equal", "full", "greater", "greater_equal",
            "imdecode", "lesser", "lesser_equal", "logical_and", "logical_or", "logical_xor",
            "maximum", "minimum", "moveaxis", "modulo", "multiply", "not_equal", "onehot_encode",
-           "power", "subtract", "true_divide", "waitall", "_new_empty_handle"]
+           "power", "subtract", "true_divide", "waitall", "_new_empty_handle", "histogram"]
 
 _STORAGE_TYPE_UNDEFINED = -1
 _STORAGE_TYPE_DEFAULT = 0
@@ -3740,3 +3740,36 @@ def empty(shape, ctx=None, dtype=None):
     if dtype is None:
         dtype = mx_real_t
     return NDArray(handle=_new_alloc_handle(shape, ctx, False, dtype))
+
+
+# pylint: disable= redefined-builtin
+def histogram(a, bins=10, range=None):
+    """Compute the histogram of the input data.
+
+    Parameters
+    ----------
+    a : NDArray
+        Input data. The histogram is computed over the flattened array.
+    bins : int, NDArray or sequence of scalars
+        If bins is an int, it defines the number of equal-width bins in the
+        given range (10, by default). If bins is an NDArray or a list/tuple of
+        scalars, it defines the bin edges, including the rightmost edge, allowing
+        for non-uniform bin widths. A list/tuple is converted to an NDArray with
+        the same context and dtype as `a`.
+    range : (float, float), optional
+        The lower and upper range of the bins when bins is an int. If not provided,
+        the result is computed by ``numpy.histogram`` on ``a.asnumpy()`` (so the
+        range is (a.min(), a.max())) and a warning is emitted. Values outside the
+        range are ignored. The first element of the range must be less than or
+        equal to the second; the range is divided into `bins` equal-width bins.
+
+    Returns
+    -------
+    (NDArray, NDArray)
+        The per-bin counts and the bin edges.
+
+    Raises
+    ------
+    ValueError
+        If bins is not an integer, an NDArray, or a list/tuple of scalars.
+    """
+
+    # pylint: disable= no-member, protected-access
+    if isinstance(bins, (list, tuple)):
+        # The operator reads the edges with the data's element type, so match dtype/context.
+        bins = array(bins, ctx=a.context, dtype=a.dtype)
+    if isinstance(bins, NDArray):
+        return _internal._histogram(data=a, bins=bins)
+    elif isinstance(bins, integer_types):
+        if range is None:
+            warnings.warn("range is not specified, using numpy's result "
+                          "to ensure consistency with numpy")
+            res, bin_bounds = np.histogram(a.asnumpy(), bins=bins)
+            return array(res), array(bin_bounds)
+        return _internal._histogram(data=a, bin_cnt=bins, range=range)
+    raise ValueError("bins argument should be an integer, an NDArray or a sequence of scalars")
+    # pylint: enable= no-member, protected-access, redefined-builtin
diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py
index 7e5b527..c5e2f5c 100644
--- a/python/mxnet/symbol/symbol.py
+++ b/python/mxnet/symbol/symbol.py
@@ -34,7 +34,7 @@ import numpy as _numpy
 
 from ..attribute import AttrScope
 from ..base import _LIB, numeric_types, c_array, c_array_buf, c_str, c_str_array, c_handle_array
-from ..base import mx_uint, py_str, string_types
+from ..base import mx_uint, py_str, string_types, integer_types
 from ..base import NDArrayHandle, ExecutorHandle, SymbolHandle
 from ..base import check_call, MXNetError, NotImplementedForSymbol
 from ..context import Context, current_context
@@ -47,7 +47,8 @@ from . import op
 from ._internal import SymbolBase, _set_symbol_class
 
 __all__ = ["Symbol", "var", "Variable", "Group", "load", "load_json",
-           "pow", "maximum", "minimum", "hypot", "eye", "zeros", "ones", "full", "arange"]
+           "pow", "maximum", "minimum", "hypot", "eye", "zeros", "ones", "full", "arange",
+           "histogram"]
 
 
 class Symbol(SymbolBase):
@@ -2864,4 +2865,29 @@ def arange(start, stop=None, step=1.0, repeat=1, name=None, dtype=None):
     return _internal._arange(start=start, stop=stop, step=step, repeat=repeat,
                              name=name, dtype=dtype)
 
+def histogram(a, bins=10, range=None, **kwargs):
+    """Compute the histogram of the input data.
+
+    Parameters
+    ----------
+    a : Symbol
+        Input data. The histogram is computed over the flattened array.
+    bins : int or Symbol
+        If bins is an int, it defines the number of equal-width bins in the
+        given range (10, by default). If bins is a Symbol, it holds the bin edges,
+        including the rightmost edge, allowing for non-uniform bin widths.
+    range : (float, float), required if bins is an integer
+        The lower and upper range of the bins. Unlike the NDArray version, it cannot
+        be omitted in symbol mode. Values outside the range are ignored. The first
+        element of the range must be less than or equal to the second; the range is
+        divided into `bins` equal-width bins.
+    **kwargs
+        Extra keyword arguments, forwarded unchanged to ``_internal._histogram``.
+
+    Returns
+    -------
+    Symbol
+        A symbol with two outputs: the per-bin counts and the bin edges.
+
+    Raises
+    ------
+    ValueError
+        If bins is an integer and range is None, or if bins is neither an
+        integer nor a Symbol.
+    """
+    if isinstance(bins, Symbol):
+        return _internal._histogram(data=a, bins=bins, **kwargs)
+    elif isinstance(bins, integer_types):
+        if range is None:
+            raise ValueError("null range is not supported in symbol mode")
+        return _internal._histogram(data=a, bin_cnt=bins, range=range, **kwargs)
+    raise ValueError("bins argument should be either an integer or a Symbol")
+
 _set_symbol_class(Symbol)
diff --git a/src/common/cuda_utils.h b/src/common/cuda_utils.h
index 962fe5a..b4b10c2 100644
--- a/src/common/cuda_utils.h
+++ b/src/common/cuda_utils.h
@@ -494,6 +494,36 @@ static inline __device__ void atomicAdd(mshadow::half::half_t *address,
   } while (assumed != old);
 }
 
+static inline __device__ void atomicAdd(uint8_t *address, uint8_t val) {
+  // Emulates an 8-bit atomic add by updating the aligned 32-bit word that contains
+  // the byte with an atomicCAS retry loop. The byte at offset k is taken to occupy
+  // bits [8k, 8k+8) of that word (little-endian layout).
+  unsigned int * address_as_ui = (unsigned int *) (address - ((size_t)address & 0x3));
+  unsigned int old = *address_as_ui;
+  unsigned int shift = (((size_t)address & 0x3) << 3);
+  unsigned int sum;
+  unsigned int assumed;
+
+  do {
+    assumed = old;
+    // Keep only the low 8 bits so a wrapped sum never carries into the next byte.
+    sum = (val + static_cast<uint8_t>((old >> shift) & 0xff)) & 0xffu;
+    // Unsigned mask: shifting a signed 0xff by 24 would overflow int.
+    old = (old & ~(0xffu << shift)) | (sum << shift);
+    old = atomicCAS(address_as_ui, assumed, old);
+  } while (assumed != old);
+}
+
+static inline __device__ void atomicAdd(int8_t *address, int8_t val) {
+  // Emulates an 8-bit atomic add by updating the aligned 32-bit word that contains
+  // the byte with an atomicCAS retry loop. The byte at offset k is taken to occupy
+  // bits [8k, 8k+8) of that word (little-endian layout).
+  unsigned int * address_as_ui = (unsigned int *) (address - ((size_t)address & 0x3));
+  unsigned int old = *address_as_ui;
+  unsigned int shift = (((size_t)address & 0x3) << 3);
+  unsigned int sum;
+  unsigned int assumed;
+
+  do {
+    assumed = old;
+    // A negative sum is sign-extended when converted to unsigned; keep only its
+    // low 8 bits (two's complement) so the other three bytes are not overwritten.
+    sum = (val + static_cast<int8_t>((old >> shift) & 0xff)) & 0xffu;
+    // Unsigned mask: shifting a signed 0xff by 24 would overflow int.
+    old = (old & ~(0xffu << shift)) | (sum << shift);
+    old = atomicCAS(address_as_ui, assumed, old);
+  } while (assumed != old);
+}
+
 // Overload atomicAdd to work for signed int64 on all architectures
 static inline  __device__  void atomicAdd(int64_t *address, int64_t val) {
   atomicAdd(reinterpret_cast<unsigned long long*>(address), static_cast<unsigned long long>(val)); // NOLINT
diff --git a/src/operator/tensor/histogram-inl.h b/src/operator/tensor/histogram-inl.h
new file mode 100644
index 0000000..08620e8
--- /dev/null
+++ b/src/operator/tensor/histogram-inl.h
@@ -0,0 +1,172 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef MXNET_OPERATOR_TENSOR_HISTOGRAM_INL_H_
+#define MXNET_OPERATOR_TENSOR_HISTOGRAM_INL_H_
+
+#include <dmlc/logging.h>
+#include <dmlc/parameter.h>
+#include <mxnet/operator.h>
+#include <mxnet/operator_util.h>
+#include <dmlc/optional.h>
+#include <mshadow/tensor.h>
+#include <nnvm/op.h>
+#include <nnvm/node.h>
+#include <nnvm/op_attr_types.h>
+#include <vector>
+#include <type_traits>
+#include "./util/tensor_util-inl.h"
+#include "../elemwise_op_common.h"
+#include "../mshadow_op.h"
+#include "../mxnet_op.h"
+#include "../operator_common.h"
+
+namespace mxnet {
+namespace op {
+
+// Parameters of the _histogram operator. Either both bin_cnt and range are given
+// (uniform bins), or neither is and the bin edges come from the "bins" input.
+struct HistogramParam : public dmlc::Parameter<HistogramParam> {
+    dmlc::optional<int> bin_cnt;
+    dmlc::optional<nnvm::Tuple<double>> range;
+    DMLC_DECLARE_PARAMETER(HistogramParam) {
+      DMLC_DECLARE_FIELD(bin_cnt)
+        .set_default(dmlc::optional<int>())
+        .describe("Number of equal-width bins. Must be specified together with range; "
+                  "if omitted, the bin edges are taken from the bins input.");
+      DMLC_DECLARE_FIELD(range)
+        .set_default(dmlc::optional<nnvm::Tuple<double>>())
+        .describe("The lower and upper range of the bins, as a tuple of 2 elements. "
+                  "Must be specified together with bin_cnt. Values outside the range "
+                  "are ignored. The first element of the range must be less than or "
+                  "equal to the second; the range is divided into bin_cnt equal-width "
+                  "bins.");
+    }
+};
+
+// Element-wise kernel producing edge i of bin_cnt + 1 evenly spaced bin edges over
+// [min, max]; launch indices beyond bin_cnt write nothing.
+struct FillBinBoundsKernel {
+  template<typename DType>
+  static MSHADOW_XINLINE void Map(int i, DType* bin_bounds, int bin_cnt, double min, double max) {
+    if (i > bin_cnt) return;
+    // Linear interpolation between min (i == 0) and max (i == bin_cnt).
+    bin_bounds[i] = DType((max * i + (bin_cnt - i) * min) / bin_cnt);
+  }
+};
+
+// Shape inference for _histogram. Output 0 holds the per-bin counts, output 1 the
+// bin edges, which always have one more element than there are bins.
+inline bool HistogramOpShape(const nnvm::NodeAttrs& attrs,
+                             std::vector<TShape>* in_attrs,
+                             std::vector<TShape>* out_attrs) {
+  const HistogramParam& param = nnvm::get<HistogramParam>(attrs.parsed);
+  const bool has_cnt = param.bin_cnt.has_value();
+  const bool has_range = param.range.has_value();
+  // Validate the parameter combination first, so a bad combination is reported as
+  // such instead of as an input-count mismatch.
+  CHECK(has_cnt == has_range) << "cnt and range should both or neither specified";
+  CHECK_EQ(in_attrs->size(), has_cnt ? 1U : 2U);
+  CHECK_EQ(out_attrs->size(), 2U);
+
+  if (has_cnt) {
+    // if cnt is specified, the output histogram has shape (cnt,)
+    // while output bins has shape (cnt+1,)
+    const int bin_cnt = param.bin_cnt.value();
+    CHECK_GT(bin_cnt, 0) << "bin_cnt should be a positive integer";
+    SHAPE_ASSIGN_CHECK(*out_attrs, 0, TShape({bin_cnt}));
+    SHAPE_ASSIGN_CHECK(*out_attrs, 1, TShape({bin_cnt + 1}));
+  } else {
+    // if cnt is not specified, the output histogram has shape (bins.Size() - 1)
+    // while output bins has same shape as input bins
+    const TShape& bins_shape = (*in_attrs)[1];
+    // The bins shape may still be unknown during partial inference; defer rather
+    // than fail the ndim check below.
+    if (shape_is_none(bins_shape)) return false;
+
+    CHECK_EQ(bins_shape.ndim(), 1U) << "bins argument should be an 1D vector";
+    CHECK_GE(bins_shape.Size(), 2U) << "number of bounds should be >= 2";
+
+    SHAPE_ASSIGN_CHECK(*out_attrs, 0, TShape({(bins_shape[0] - 1)}));
+    SHAPE_ASSIGN_CHECK(*out_attrs, 1, bins_shape);
+  }
+
+  return !shape_is_none(out_attrs->at(0)) && !shape_is_none(out_attrs->at(1)) &&
+         out_attrs->at(0).Size() == out_attrs->at(1).Size() - 1;
+}
+
+// Type inference for _histogram: counts are int64, edges share the data dtype.
+inline bool HistogramOpType(const nnvm::NodeAttrs& attrs,
+                            std::vector<int>* in_attrs,
+                            std::vector<int>* out_attrs) {
+  CHECK_EQ(out_attrs->size(), 2U);
+  CHECK(in_attrs->size() == 1U || in_attrs->size() == 2U);
+
+  // The forward kernels read the bins input with the data's element type, so an
+  // explicit bins input must share the data dtype; propagate in both directions.
+  if (in_attrs->size() == 2U) {
+    TYPE_ASSIGN_CHECK(*in_attrs, 1, in_attrs->at(0));
+    TYPE_ASSIGN_CHECK(*in_attrs, 0, in_attrs->at(1));
+  }
+  TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kInt64);
+  TYPE_ASSIGN_CHECK(*out_attrs, 1, in_attrs->at(0));
+  return !type_is_none(out_attrs->at(0)) && !type_is_none(out_attrs->at(1));
+}
+
+// Histogram against caller-supplied bin edges (the "bins" input); out_bins receives
+// a copy of those edges. Specialized for cpu in histogram.cc and gpu in histogram.cu.
+template<typename xpu>
+void HistogramForwardImpl(const OpContext& ctx,
+                          const TBlob& in_data,
+                          const TBlob& bin_bounds,
+                          const TBlob& out_data,
+                          const TBlob& out_bins);
+
+// Histogram with bin_cnt equal-width bins spanning [min, max]; the generated
+// bin_cnt + 1 edges are written to out_bins. Specialized for cpu and gpu like the
+// overload above.
+template<typename xpu>
+void HistogramForwardImpl(const OpContext& ctx,
+                          const TBlob& in_data,
+                          const TBlob& out_data,
+                          const TBlob& out_bins,
+                          const int bin_cnt,
+                          const double min,
+                          const double max);
+
+// FCompute entry point for _histogram. Dispatches to the uniform-bins
+// implementation when bin_cnt/range are set, otherwise to the explicit-edges one.
+template<typename xpu>
+void HistogramOpForward(const nnvm::NodeAttrs& attrs,
+                        const OpContext& ctx,
+                        const std::vector<TBlob>& inputs,
+                        const std::vector<OpReqType>& req,
+                        const std::vector<TBlob>& outputs) {
+  CHECK_EQ(req.size(), 2U);
+  CHECK_EQ(req[0], kWriteTo);
+  CHECK_EQ(req[1], kWriteTo);
+  const HistogramParam& param = nnvm::get<HistogramParam>(attrs.parsed);
+  const bool has_cnt = param.bin_cnt.has_value();
+  const bool has_range = param.range.has_value();
+  const bool legal_params = (has_cnt && has_range) || (!has_cnt && !has_range);
+  CHECK(legal_params) << "bin_cnt and range should both or neither be specified";
+
+  const TBlob& in_data = inputs[0];
+  const TBlob& out_data = outputs[0];
+  const TBlob& out_bins = outputs[1];
+
+  if (has_cnt) {
+    const nnvm::Tuple<double>& bounds = param.range.value();
+    CHECK_EQ(bounds.ndim(), 2U) << "range should be a tuple with only 2 elements";
+    CHECK(bounds[0] <= bounds[1])
+      << "left hand side of range(" << bounds[0]
+      << ")should be less than or equal to right hand side(" << bounds[1] << ")";
+    double min = bounds[0];
+    double max = bounds[1];
+    const int bin_cnt = param.bin_cnt.value();
+    if (min == max) {
+      // Degenerate range: widen it by 0.5 on each side so the bins have non-zero
+      // width (the kernels divide by max - min).
+      min -= 0.5;
+      max += 0.5;
+    }
+    HistogramForwardImpl<xpu>(ctx, in_data, out_data, out_bins, bin_cnt, min, max);
+  } else {
+    const TBlob& bin_bounds = inputs[1];
+    HistogramForwardImpl<xpu>(ctx, in_data, bin_bounds, out_data, out_bins);
+  }
+}
+
+}   // namespace op
+}   // namespace mxnet
+
+#endif  // MXNET_OPERATOR_TENSOR_HISTOGRAM_INL_H_
diff --git a/src/operator/tensor/histogram.cc b/src/operator/tensor/histogram.cc
new file mode 100644
index 0000000..ac28606
--- /dev/null
+++ b/src/operator/tensor/histogram.cc
@@ -0,0 +1,159 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "./histogram-inl.h"
+
+namespace mxnet {
+namespace op {
+
+// Element-wise kernel that writes the bin index of in_data[i] into bin_indices[i],
+// or -1 when the sample lies outside every bin.
+struct ComputeBinKernel {
+  // Uniform bins over [min, max]: compute the bin directly from the value, then move
+  // it by at most one position so it agrees with the precomputed edges in bin_bounds
+  // despite floating-point rounding.
+  template<typename DType>
+  MSHADOW_XINLINE static void Map(int i, const DType* in_data, const DType* bin_bounds,
+                                  int* bin_indices, int bin_cnt, double min, double max) {
+    DType data = in_data[i];
+    int target = -1;
+    if (data >= min && data <= max) {
+      target = (data - min) * bin_cnt / (max - min);
+      target = mshadow_op::minimum::Map(bin_cnt - 1, target);
+      target -= (data < bin_bounds[target]) ? 1 : 0;
+      target += ((data >= bin_bounds[target + 1]) && (target != bin_cnt - 1)) ? 1 : 0;
+    }
+    bin_indices[i] = target;
+  }
+
+  // Explicit edges bin_bounds[0..num_bins]: linear scan for the bin holding the value.
+  template<typename DType>
+  MSHADOW_XINLINE static void Map(int i, const DType* in_data, int* bin_indices,
+                                   const DType* bin_bounds, int num_bins) {
+    DType data = in_data[i];
+    int target = -1;
+    if (data >= bin_bounds[0] && data <= bin_bounds[num_bins]) {
+      // Stop at the last bin so the scan never reads past bin_bounds[num_bins]; a
+      // value equal to the rightmost edge therefore lands in the last (closed) bin.
+      target = 0;
+      while (target < num_bins - 1 && data >= bin_bounds[target + 1]) {
+        ++target;
+      }
+    }
+    bin_indices[i] = target;
+  }
+};
+
+// Tallies precomputed bin indices into out_data, which the callers zero beforehand.
+// Negative indices mark samples that fell outside every bin and are skipped.
+template<typename CType>
+void ComputeHistogram(const int* bin_indices, CType* out_data, size_t input_size) {
+  const int* const end = bin_indices + input_size;
+  for (const int* it = bin_indices; it != end; ++it) {
+    const int bin = *it;
+    if (bin < 0) continue;
+    out_data[bin] += 1;
+  }
+}
+
+template<>
+void HistogramForwardImpl<cpu>(const OpContext& ctx,
+                               const TBlob& in_data,
+                               const TBlob& bin_bounds,
+                               const TBlob& out_data,
+                               const TBlob& out_bins) {
+  // Explicit-edges path: bin every sample into a temporary index buffer, echo the
+  // edges into out_bins, then tally the indices into out_data.
+  using namespace mshadow;
+  using namespace mxnet_op;
+  Stream<cpu>* stream = ctx.get_stream<cpu>();
+  const size_t num_samples = in_data.Size();
+  const int num_bins = out_data.Size();
+  // Scratch space: one bin index per input element.
+  int* sample_bins = ctx.requested[0]
+    .get_space_typed<cpu, 1, int>(Shape1(num_samples), stream).dptr_;
+
+  MSHADOW_TYPE_SWITCH(in_data.type_flag_, DType, {
+    const DType* edges = bin_bounds.dptr<DType>();
+    Kernel<ComputeBinKernel, cpu>::Launch(
+      stream, num_samples, in_data.dptr<DType>(), sample_bins, edges, num_bins);
+    Kernel<op_with_req<mshadow_op::identity, kWriteTo>, cpu>::Launch(
+      stream, bin_bounds.Size(), out_bins.dptr<DType>(), edges);
+  });
+  MSHADOW_TYPE_SWITCH(out_data.type_flag_, CType, {
+    CType* counts = out_data.dptr<CType>();
+    Kernel<set_zero, cpu>::Launch(stream, num_bins, counts);
+    ComputeHistogram(sample_bins, counts, num_samples);
+  });
+}
+
+template<>
+void HistogramForwardImpl<cpu>(const OpContext& ctx,
+                               const TBlob& in_data,
+                               const TBlob& out_data,
+                               const TBlob& out_bins,
+                               const int bin_cnt,
+                               const double min,
+                               const double max) {
+  // Uniform-bins path: write the bin_cnt + 1 evenly spaced edges, bin every sample
+  // against them into a temporary index buffer, then tally the indices.
+  using namespace mshadow;
+  using namespace mxnet_op;
+  Stream<cpu>* stream = ctx.get_stream<cpu>();
+  const size_t num_samples = in_data.Size();
+  // Scratch space: one bin index per input element.
+  int* sample_bins = ctx.requested[0]
+    .get_space_typed<cpu, 1, int>(Shape1(num_samples), stream).dptr_;
+
+  MSHADOW_TYPE_SWITCH(in_data.type_flag_, DType, {
+    DType* edges = out_bins.dptr<DType>();
+    Kernel<FillBinBoundsKernel, cpu>::Launch(stream, bin_cnt + 1, edges, bin_cnt, min, max);
+    Kernel<ComputeBinKernel, cpu>::Launch(
+      stream, num_samples, in_data.dptr<DType>(), edges, sample_bins, bin_cnt, min, max);
+  });
+  MSHADOW_TYPE_SWITCH(out_data.type_flag_, CType, {
+    CType* counts = out_data.dptr<CType>();
+    Kernel<set_zero, cpu>::Launch(stream, bin_cnt, counts);
+    ComputeHistogram(sample_bins, counts, num_samples);
+  });
+}
+
+DMLC_REGISTER_PARAMETER(HistogramParam);
+
+NNVM_REGISTER_OP(_histogram)
+.describe(R"code(This operator implements the histogram function.
+
+Example::
+  x = [[0, 1], [2, 2], [3, 4]]
+  histo, bin_edges = histogram(data=x, bin_cnt=5, range=(0,5))
+  histo = [1, 1, 2, 1, 1]
+  bin_edges = [0., 1., 2., 3., 4., 5.]
+  histo, bin_edges = histogram(data=x, bins=[0., 2.1, 3.])
+  histo = [4, 1]
+  bin_edges = [0., 2.1, 3.]
+
+)code" ADD_FILELINE)
+.set_attr_parser(ParamParser<HistogramParam>)
+// One input (data) when bin_cnt/range describe uniform bins, otherwise two (data, bins).
+.set_num_inputs([](const NodeAttrs& attrs) {
+    const HistogramParam& params = nnvm::get<HistogramParam>(attrs.parsed);
+    return params.bin_cnt.has_value() ? 1 : 2;
+})
+// Outputs: the per-bin counts and the bin edges.
+.set_num_outputs(2)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const NodeAttrs& attrs) {
+    const HistogramParam& params = nnvm::get<HistogramParam>(attrs.parsed);
+    return params.bin_cnt.has_value() ?
+           std::vector<std::string>{"data"} :
+           std::vector<std::string>{"data", "bins"};
+  })
+// Temporary space holds the per-element bin indices in the CPU implementation.
+.set_attr<FResourceRequest>("FResourceRequest",
+  [](const NodeAttrs& attrs) {
+    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+})
+.set_attr<nnvm::FInferShape>("FInferShape", HistogramOpShape)
+.set_attr<nnvm::FInferType>("FInferType", HistogramOpType)
+.set_attr<FCompute>("FCompute<cpu>", HistogramOpForward<cpu>)
+.add_argument("data", "NDArray-or-Symbol", "Input ndarray")
+.add_argument("bins", "NDArray-or-Symbol", "Bin edges, including the rightmost edge; "
+              "only used when bin_cnt and range are not specified")
+.add_arguments(HistogramParam::__FIELDS__());
+
+}  // namespace op
+}  // namespace mxnet
+
diff --git a/src/operator/tensor/histogram.cu b/src/operator/tensor/histogram.cu
new file mode 100644
index 0000000..c3c836a
--- /dev/null
+++ b/src/operator/tensor/histogram.cu
@@ -0,0 +1,111 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "./histogram-inl.h"
+#include "./util/tensor_util-inl.cuh"
+
+namespace mxnet {
+namespace op {
+
+// Element-wise GPU kernel that locates the bin of in_data[i] and atomically
+// increments its count; samples outside every bin are ignored.
+struct HistogramFusedKernel {
+  // Uniform bins over [min, max]: compute the bin directly from the value, then move
+  // it by at most one position so it agrees with the precomputed edges in bin_bounds
+  // despite floating-point rounding.
+  template<typename DType, typename CType>
+  static MSHADOW_XINLINE void Map(int i, const DType* in_data, const DType* bin_bounds, CType* bins,
+                                  const int bin_cnt, const double min, const double max) {
+    DType data = in_data[i];
+    int target = -1;
+    if (data >= min && data <= max) {
+      target = mshadow_op::floor::Map((data - min) * bin_cnt / (max - min));
+      target = mshadow_op::minimum::Map(bin_cnt - 1, target);
+      target -= (data < bin_bounds[target]) ? 1 : 0;
+      target += ((data >= bin_bounds[target + 1]) && (target != bin_cnt - 1)) ? 1 : 0;
+    }
+    if (target >= 0) {
+      atomicAdd(&bins[target], CType(1));
+    }
+  }
+
+  // Explicit edges bin_bounds[0..bin_cnt]: linear scan for the bin holding the value.
+  template<typename DType, typename CType>
+  static MSHADOW_XINLINE void Map(int i, const DType* in_data, const DType* bin_bounds, CType* bins,
+                                  const int bin_cnt) {
+    DType data = in_data[i];
+    if (data >= bin_bounds[0] && data <= bin_bounds[bin_cnt]) {
+      // Stop at the last bin so the scan never reads past bin_bounds[bin_cnt]; a
+      // value equal to the rightmost edge therefore lands in the last (closed) bin.
+      int target = 0;
+      while (target < bin_cnt - 1 && data >= bin_bounds[target + 1]) {
+        ++target;
+      }
+      atomicAdd(&bins[target], CType(1));
+    }
+  }
+};
+
+template<>
+void HistogramForwardImpl<gpu>(const OpContext& ctx,
+                               const TBlob& in_data,
+                               const TBlob& bin_bounds,
+                               const TBlob& out_data,
+                               const TBlob& out_bins) {
+  // Explicit-edges path on GPU: clear the counts, tally every sample in one fused
+  // kernel, then copy the edges through to out_bins. All three kernels are enqueued
+  // on the same stream, in this order.
+  using namespace mshadow;
+  using namespace mxnet_op;
+  Stream<gpu>* stream = ctx.get_stream<gpu>();
+  const int num_bins = out_bins.Size() - 1;
+  MSHADOW_TYPE_SWITCH(in_data.type_flag_, DType, {
+    MSHADOW_IDX_TYPE_SWITCH(out_data.type_flag_, CType, {
+      CType* counts = out_data.dptr<CType>();
+      const DType* edges = bin_bounds.dptr<DType>();
+      Kernel<set_zero, gpu>::Launch(stream, num_bins, counts);
+      Kernel<HistogramFusedKernel, gpu>::Launch(
+        stream, in_data.Size(), in_data.dptr<DType>(), edges, counts, num_bins);
+      Kernel<op_with_req<mshadow_op::identity, kWriteTo>, gpu>::Launch(
+        stream, bin_bounds.Size(), out_bins.dptr<DType>(), edges);
+    });
+  });
+}
+
+template<>
+void HistogramForwardImpl<gpu>(const OpContext& ctx,
+                               const TBlob& in_data,
+                               const TBlob& out_data,
+                               const TBlob& out_bins,
+                               const int bin_cnt,
+                               const double min,
+                               const double max) {
+  // Uniform-bins path on GPU: clear the counts, generate the bin_cnt + 1 edges into
+  // out_bins, then tally every sample in one fused kernel on the same stream.
+  using namespace mshadow;
+  using namespace mxnet_op;
+  Stream<gpu>* stream = ctx.get_stream<gpu>();
+  MSHADOW_TYPE_SWITCH(in_data.type_flag_, DType, {
+    MSHADOW_IDX_TYPE_SWITCH(out_data.type_flag_, CType, {
+      CType* counts = out_data.dptr<CType>();
+      DType* edges = out_bins.dptr<DType>();
+      Kernel<set_zero, gpu>::Launch(stream, bin_cnt, counts);
+      // The edges are generated before the fused kernel because it reads them back.
+      Kernel<FillBinBoundsKernel, gpu>::Launch(stream, bin_cnt + 1, edges, bin_cnt, min, max);
+      Kernel<HistogramFusedKernel, gpu>::Launch(
+        stream, in_data.Size(), in_data.dptr<DType>(), edges, counts, bin_cnt, min, max);
+    });
+  });
+}
+
+// Attach the GPU forward implementation to the _histogram operator registered in
+// histogram.cc.
+NNVM_REGISTER_OP(_histogram)
+.set_attr<FCompute>("FCompute<gpu>", HistogramOpForward<gpu>);
+
+}   // namespace op
+}   // namespace mxnet
+
diff --git a/src/operator/tensor/util/tensor_util-inl.cuh b/src/operator/tensor/util/tensor_util-inl.cuh
index c9ee625..a58f7db 100644
--- a/src/operator/tensor/util/tensor_util-inl.cuh
+++ b/src/operator/tensor/util/tensor_util-inl.cuh
@@ -242,7 +242,9 @@ struct HistogramKernel {
                                              const CType* source,
                                              const nnvm::dim_t num_elems) {
     if (tid < num_elems) {
-      atomicAdd(&target[source[tid]], 1);
+      if (source[tid] >= 0) {
+        atomicAdd(&target[source[tid]], IType(1));
+      }
     }
   }
 };
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 287d830..3de30f2 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -6025,6 +6025,40 @@ def test_quadratic_function():
         check_numeric_gradient(quad_sym, [data_np], atol=0.001)
 
 
+@with_seed()
+def test_histogram():
+    def f(x, bins=10, range=None):
+        return np.histogram(x, bins, range=range)
+
+    # Edges and range are loop-invariant; only the data shape and bin count vary.
+    mx_bins = mx.nd.array([-1.0, 0.5, 2.0, 4.5, 50.0], dtype=np.float64)
+    np_bins = mx_bins.asnumpy()
+    bin_range = (-2.5, 2.5)
+    for ndim in range(1, 6):
+        shape = rand_shape_nd(ndim)
+        x = rand_ndarray(shape, stype='default', dtype=np.float64)
+        bin_cnt = random.randint(2, 10)
+        mx_histo1, mx_bins1 = mx.nd.histogram(x, bins=bin_cnt, range=bin_range)
+        np_histo1, np_bins1 = f(x.asnumpy(), bins=bin_cnt, range=bin_range)
+        assert_almost_equal(mx_bins1.asnumpy(), np_bins1)
+        assert_almost_equal(mx_histo1.asnumpy(), np_histo1, rtol=1e-3, atol=1e-5)
+        mx_histo2, mx_bins2 = mx.nd.histogram(x, bins=mx_bins)
+        np_histo2, np_bins2 = f(x.asnumpy(), bins=np_bins)
+        assert_almost_equal(mx_histo2.asnumpy(), np_histo2, rtol=1e-3, atol=1e-5)
+        assert_almost_equal(mx_bins2.asnumpy(), np_bins2, rtol=1e-3, atol=1e-5)
+
+        data = mx.sym.Variable("data")
+        bins = mx.sym.Variable("bins")
+        histo1 = mx.sym.histogram(a=data, bins=bin_cnt, range=bin_range)
+        histo2 = mx.sym.histogram(a=data, bins=bins)
+        executor1 = histo1.bind(ctx=default_context(), args={"data" : x})
+        executor1.forward(is_train=False)
+        assert_almost_equal(np_histo1, executor1.outputs[0].asnumpy(), 0, 0, ("EXPECTED_histo1", "FORWARD_histo1"), equal_nan=False)
+        executor2 = histo2.bind(ctx=default_context(), args={"data" : x, "bins" : mx_bins})
+        executor2.forward(is_train=False)
+        assert_almost_equal(np_histo2, executor2.outputs[0].asnumpy(), 0, 0, ("EXPECTED_histo2", "FORWARD_histo2"), equal_nan=False)
+
+    # Regression: a value equal to the rightmost edge belongs to the last bin and
+    # values outside the edges are dropped, as in numpy.histogram.
+    edge_x = mx.nd.array([-2.0, -1.0, 0.5, 4.5, 50.0, 60.0], dtype=np.float64)
+    mx_histo3, mx_bins3 = mx.nd.histogram(edge_x, bins=mx_bins)
+    np_histo3, np_bins3 = f(edge_x.asnumpy(), bins=np_bins)
+    assert_almost_equal(mx_histo3.asnumpy(), np_histo3)
+    assert_almost_equal(mx_bins3.asnumpy(), np_bins3)
+    uniform_x = mx.nd.array([-3.0, -2.5, 0.0, 2.5, 3.0], dtype=np.float64)
+    mx_histo4, mx_bins4 = mx.nd.histogram(uniform_x, bins=5, range=bin_range)
+    np_histo4, np_bins4 = f(uniform_x.asnumpy(), bins=5, range=bin_range)
+    assert_almost_equal(mx_histo4.asnumpy(), np_histo4)
+    assert_almost_equal(mx_bins4.asnumpy(), np_bins4)
+
+
 def test_op_output_names_monitor():
     def check_name(op_sym, expected_names):
         output_names = []