You are viewing a plain-text version of this content; the link to its canonical version is not preserved in this plain-text copy.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/06/01 19:35:24 UTC

[GitHub] piiswrong closed pull request #10970: [MXNET-424] dtype option for multinomial

piiswrong closed pull request #10970: [MXNET-424] dtype option for multinomial
URL: https://github.com/apache/incubator-mxnet/pull/10970
 
 
   

This is a pull request (PR) merged from a forked repository.
Because GitHub hides the original diff once such a PR is merged, the diff is
displayed below for the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below, because GitHub would not otherwise display it:

diff --git a/python/mxnet/ndarray/random.py b/python/mxnet/ndarray/random.py
index 7be854a3107..d0c83c10e6b 100644
--- a/python/mxnet/ndarray/random.py
+++ b/python/mxnet/ndarray/random.py
@@ -389,7 +389,7 @@ def generalized_negative_binomial(mu=1, alpha=1, shape=_Null, dtype=_Null, ctx=N
                           [mu, alpha], shape, dtype, ctx, out, kwargs)
 
 
-def multinomial(data, shape=_Null, get_prob=False, out=None, **kwargs):
+def multinomial(data, shape=_Null, get_prob=False, out=None, dtype='int32', **kwargs):
     """Concurrent sampling from multiple multinomial distributions.
 
     .. note:: The input distribution must be normalized, i.e. `data` must sum to
@@ -412,6 +412,9 @@ def multinomial(data, shape=_Null, get_prob=False, out=None, **kwargs):
         reward as head gradient w.r.t. this array to estimate gradient.
     out : NDArray
         Store output to an existing NDArray.
+    dtype : str or numpy.dtype
+        Data type of the sample output array. The default is int32.
+        Note that the data type of the log likelihood array is the same with that of `data`.
 
     Examples
     --------
@@ -429,7 +432,7 @@ def multinomial(data, shape=_Null, get_prob=False, out=None, **kwargs):
     [-1.20397282 -1.60943794]
     <NDArray 2 @cpu(0)>
     """
-    return _internal._sample_multinomial(data, shape, get_prob, out=out, **kwargs)
+    return _internal._sample_multinomial(data, shape, get_prob, out=out, dtype=dtype, **kwargs)
 
 
 def shuffle(data, **kwargs):
diff --git a/python/mxnet/symbol/random.py b/python/mxnet/symbol/random.py
index 24f1c5aa49e..e9abe9c4a18 100644
--- a/python/mxnet/symbol/random.py
+++ b/python/mxnet/symbol/random.py
@@ -224,7 +224,7 @@ def generalized_negative_binomial(mu=1, alpha=1, shape=_Null, dtype=_Null, **kwa
                           [mu, alpha], shape, dtype, kwargs)
 
 
-def multinomial(data, shape=_Null, get_prob=True, **kwargs):
+def multinomial(data, shape=_Null, get_prob=True, dtype='int32', **kwargs):
     """Concurrent sampling from multiple multinomial distributions.
 
     .. note:: The input distribution must be normalized, i.e. `data` must sum to
@@ -245,8 +245,11 @@ def multinomial(data, shape=_Null, get_prob=True, **kwargs):
         samples will also be returned.
         This is usually used for reinforcement learning, where you can provide
         reward as head gradient w.r.t. this array to estimate gradient.
+    dtype : str or numpy.dtype
+        Data type of the sample output array. The default is int32.
+        Note that the data type of the log likelihood array is the same with that of `data`.
     """
-    return _internal._sample_multinomial(data, shape, get_prob, **kwargs)
+    return _internal._sample_multinomial(data, shape, get_prob, dtype=dtype, **kwargs)
 
 
 def shuffle(data, **kwargs):
diff --git a/src/common/utils.h b/src/common/utils.h
index f0ef94097bb..be78bf42547 100644
--- a/src/common/utils.h
+++ b/src/common/utils.h
@@ -43,6 +43,7 @@
 #include <thread>
 #include <algorithm>
 #include <functional>
+#include <limits>
 
 #include "../operator/mxnet_op.h"
 
@@ -617,6 +618,21 @@ FCompType GetFCompute(const nnvm::Op* op, const std::string& name,
   }
 }
 
+/*!
+ * \brief Return the max integer value representable in the type `T` without loss of precision.
+ */
+template <typename T>
+constexpr size_t MaxIntegerValue() {
+  return std::is_integral<T>::value ?
+    std::numeric_limits<T>::max():
+    size_t(2) << (std::numeric_limits<T>::digits - 1);
+}
+
+template <>
+constexpr size_t MaxIntegerValue<mshadow::half::half_t>() {
+  return size_t(2) << 10;
+}
+
 }  // namespace common
 }  // namespace mxnet
 #endif  // MXNET_COMMON_UTILS_H_
diff --git a/src/operator/random/sample_multinomial_op.cc b/src/operator/random/sample_multinomial_op.cc
index a513f9866e5..1bacb023588 100644
--- a/src/operator/random/sample_multinomial_op.cc
+++ b/src/operator/random/sample_multinomial_op.cc
@@ -97,7 +97,8 @@ struct SampleMultinomialBackwardCPUKernel {
                                   DType* ograd, DType* dist, IType* out,
                                   DType* igrad) {
     for (index_t j = 0; j < M; ++j) {
-      igrad[i*K + out[i*M + j]] += ograd[i*M + j] / dist[i*K + out[i*M + j]];
+      igrad[i*K + static_cast<size_t>(out[i*M + j])] +=
+        ograd[i*M + j] / dist[i*K + static_cast<size_t>(out[i*M + j])];
     }
   }
 };
diff --git a/src/operator/random/sample_multinomial_op.cu b/src/operator/random/sample_multinomial_op.cu
index 27f288834a9..039b9d15597 100644
--- a/src/operator/random/sample_multinomial_op.cu
+++ b/src/operator/random/sample_multinomial_op.cu
@@ -37,7 +37,8 @@ struct SampleMultinomialBackwardGPUKernel {
                                   DType* ograd, DType* dist, IType* out,
                                   DType* igrad) {
     for (index_t j = 0; j < M; ++j) {
-      atomicAdd(&igrad[i*K + out[i*M + j]], ograd[i*M + j] / dist[i*K + out[i*M + j]]);
+      atomicAdd(&igrad[i*K + static_cast<size_t>(out[i*M + j])],
+        ograd[i*M + j] / dist[i*K + static_cast<size_t>(out[i*M + j])]);
     }
   }
 };
diff --git a/src/operator/random/sample_multinomial_op.h b/src/operator/random/sample_multinomial_op.h
index 898ca050891..e0f0d685c8c 100644
--- a/src/operator/random/sample_multinomial_op.h
+++ b/src/operator/random/sample_multinomial_op.h
@@ -49,10 +49,13 @@ struct SampleMultinomialParam : public dmlc::Parameter<SampleMultinomialParam> {
           "result. This is usually used for differentiating through "
           "stochastic variables, e.g. in reinforcement learning.");
     DMLC_DECLARE_FIELD(dtype)
+    .add_enum("uint8", mshadow::kUint8)
     .add_enum("int32", mshadow::kInt32)
+    .add_enum("float16", mshadow::kFloat16)
+    .add_enum("float32", mshadow::kFloat32)
+    .add_enum("float64", mshadow::kFloat64)
     .set_default(mshadow::kInt32)
-    .describe("DType of the output in case this can't be inferred. "
-              "Only support int32 for now.");
+    .describe("DType of the output in case this can't be inferred.");
   }
 };
 
@@ -67,6 +70,11 @@ inline bool SampleMultinomialOpShape(const nnvm::NodeAttrs& attrs,
   const TShape& ishape = (*in_attrs)[0];
   if (!ishape.ndim()) return false;
 
+  MSHADOW_TYPE_SWITCH(param.dtype, DType, {
+    CHECK_LE(ishape[ishape.ndim() - 1], mxnet::common::MaxIntegerValue<DType>())
+    << "'dtype' does not have a sufficient precision to represent the indices of the input array.";
+  });
+
   if (ishape.ndim() == 1) {
     if (param.shape.ndim()) {
       SHAPE_ASSIGN_CHECK(*out_attrs, 0, param.shape);
@@ -155,9 +163,11 @@ void SampleMultinomialForward(const nnvm::NodeAttrs& attrs,
     Tensor<xpu, 1, float> uniform =
       ctx.requested[1].get_space_typed<xpu, 1, float>(Shape1(N*M), s);
     prnd->SampleUniform(&uniform, 0, 1);
-    Kernel<SampleMultinomialKernel, xpu>::Launch(
-      s, N, K, M, inputs[0].dptr<DType>(), uniform.dptr_, outputs[0].dptr<int>(),
-      param.get_prob ? outputs[1].dptr<DType>() : nullptr);
+    MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, IType, {
+      Kernel<SampleMultinomialKernel, xpu>::Launch(
+        s, N, K, M, inputs[0].dptr<DType>(), uniform.dptr_, outputs[0].dptr<IType>(),
+        param.get_prob ? outputs[1].dptr<DType>() : nullptr);
+    });
   });
 }
 
@@ -182,9 +192,11 @@ void SampleMultinomialBackward(const nnvm::NodeAttrs& attrs,
       Tensor<xpu, 1, DType> out = outputs[0].FlatTo1D<xpu, DType>(s);
       out = 0;
     }
-    Kernel<kernel, xpu>::Launch(
-      s, N, K, M, inputs[0].dptr<DType>(), inputs[1].dptr<DType>(),
-      inputs[2].dptr<int>(), outputs[0].dptr<DType>());
+    MSHADOW_TYPE_SWITCH(inputs[2].type_flag_, IType, {
+      Kernel<kernel, xpu>::Launch(
+        s, N, K, M, inputs[0].dptr<DType>(), inputs[1].dptr<DType>(),
+        inputs[2].dptr<IType>(), outputs[0].dptr<DType>());
+    });
   });
 }
 
diff --git a/tests/python/unittest/test_random.py b/tests/python/unittest/test_random.py
index f7b13048c1d..40723b270a3 100644
--- a/tests/python/unittest/test_random.py
+++ b/tests/python/unittest/test_random.py
@@ -377,34 +377,45 @@ def test_parallel_random_seed_setting_for_context():
 
 @with_seed()
 def test_sample_multinomial():
-    for x in [mx.nd.array([[0,1,2,3,4],[4,3,2,1,0]])/10.0, mx.nd.array([0,1,2,3,4])/10.0]:
-        dx = mx.nd.ones_like(x)
-        mx.contrib.autograd.mark_variables([x], [dx])
-        # Adding rtol and increasing samples needed to pass with seed 2951820647
-        samples = 5000
-        with mx.autograd.record():
-            y, prob = mx.nd.random.multinomial(x, shape=samples, get_prob=True)
-            r = prob * 5
-            r.backward()
-
-        y = y.asnumpy()
-        x = x.asnumpy()
-        dx = dx.asnumpy()
-        if len(x.shape) is 1:
-            x = x.reshape((1, x.shape[0]))
-            dx = dx.reshape(1, dx.shape[0])
-            y = y.reshape((1, y.shape[0]))
-            prob = prob.reshape((1, prob.shape[0]))
-        for i in range(x.shape[0]):
-            freq = np.bincount(y[i,:], minlength=5)/np.float32(samples)*x[i,:].sum()
-            mx.test_utils.assert_almost_equal(freq, x[i], rtol=0.20)
-            rprob = x[i][y[i]]/x[i].sum()
-            mx.test_utils.assert_almost_equal(np.log(rprob), prob.asnumpy()[i], atol=1e-5)
-
-            real_dx = np.zeros((5,))
-            for j in range(samples):
-                real_dx[y[i][j]] += 5.0 / rprob[j]
-            mx.test_utils.assert_almost_equal(real_dx, dx[i, :], rtol=1e-4, atol=1e-5)
+    for dtype in ['uint8', 'int32', 'float16', 'float32', 'float64']: # output array types
+        for x in [mx.nd.array([[0,1,2,3,4],[4,3,2,1,0]])/10.0, mx.nd.array([0,1,2,3,4])/10.0]:
+            dx = mx.nd.ones_like(x)
+            mx.contrib.autograd.mark_variables([x], [dx])
+            # Adding rtol and increasing samples needed to pass with seed 2951820647
+            samples = 5000
+            with mx.autograd.record():
+                y, prob = mx.nd.random.multinomial(x, shape=samples, get_prob=True, dtype=dtype)
+                r = prob * 5
+                r.backward()
+
+            assert(np.dtype(dtype) == y.dtype)
+            y = y.asnumpy()
+            x = x.asnumpy()
+            dx = dx.asnumpy()
+            if len(x.shape) is 1:
+                x = x.reshape((1, x.shape[0]))
+                dx = dx.reshape(1, dx.shape[0])
+                y = y.reshape((1, y.shape[0]))
+                prob = prob.reshape((1, prob.shape[0]))
+            for i in range(x.shape[0]):
+                freq = np.bincount(y[i,:].astype('int32'), minlength=5)/np.float32(samples)*x[i,:].sum()
+                mx.test_utils.assert_almost_equal(freq, x[i], rtol=0.20)
+                rprob = x[i][y[i].astype('int32')]/x[i].sum()
+                mx.test_utils.assert_almost_equal(np.log(rprob), prob.asnumpy()[i], atol=1e-5)
+
+                real_dx = np.zeros((5,))
+                for j in range(samples):
+                    real_dx[int(y[i][j])] += 5.0 / rprob[j]
+                mx.test_utils.assert_almost_equal(real_dx, dx[i, :], rtol=1e-4, atol=1e-5)
+    for dtype in ['uint8', 'float16', 'float32']:
+        # Bound check for the output data types. 'int32' and 'float64' require large memory so are skipped.
+        x = mx.nd.zeros(2 ** 25)  # Larger than the max integer in float32 without precision loss.
+        bound_check = False
+        try:
+            y = mx.nd.random.multinomial(x, dtype=dtype)
+        except mx.MXNetError as e:
+            bound_check = True
+        assert bound_check
 
 # Test the generators with the chi-square testing
 @with_seed()


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services