You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2021/04/30 14:51:31 UTC

[incubator-mxnet] branch master updated: [BUGFIX] fix #18936, #18937 (#19878)

This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 6f4ac54  [BUGFIX] fix #18936, #18937 (#19878)
6f4ac54 is described below

commit 6f4ac54b3d5727301ba6bd08bc585254c7f2d925
Author: Nikolay Ulmasov <ul...@hotmail.com>
AuthorDate: Fri Apr 30 15:47:58 2021 +0100

    [BUGFIX] fix #18936, #18937 (#19878)
    
    * fix #18938
    
    * fix #18939, #18940
    
    * fix #18936 and #18937
    
    Co-authored-by: r3stl355 <ul...@amazon.com>
---
 src/operator/random/pdf_op.h         |  6 +++++
 tests/python/unittest/test_random.py | 51 ++++++++++++++++++++++++++++++++++++
 2 files changed, 57 insertions(+)

diff --git a/src/operator/random/pdf_op.h b/src/operator/random/pdf_op.h
index 57bddfc..f6dc777 100644
--- a/src/operator/random/pdf_op.h
+++ b/src/operator/random/pdf_op.h
@@ -514,6 +514,12 @@ void PdfOpForward(const nnvm::NodeAttrs& attrs,
   CHECK_NE(req[0], kAddTo);
   CHECK_EQ(inputs.size(), pnum + 1);
   CHECK_EQ(outputs.size(), 1);
+
+  // Skip kernel launch for zero-size tensors
+  if (inputs[1].shape_.Size() == 0U || outputs[0].Size() == 0U) {
+    return;
+  }
+
   mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
   const PdfParam& param = nnvm::get<PdfParam>(attrs.parsed);
   MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
diff --git a/tests/python/unittest/test_random.py b/tests/python/unittest/test_random.py
index a260f63..9cd935d 100644
--- a/tests/python/unittest/test_random.py
+++ b/tests/python/unittest/test_random.py
@@ -27,6 +27,8 @@ import scipy.stats as ss
 import unittest
 import pytest
 from mxnet.test_utils import *
+from mxnet.base import MXNetError
+from common import assertRaises
 
 def same(a, b):
     return np.sum(a != b) == 0
@@ -1029,3 +1031,52 @@ def test_sample_multinomial_num_outputs():
     assert isinstance(out, list)
     assert len(out) == 2
 
+
+@use_np
+def test_dirichlet_zero_size_dim():
+    """Tests random_pdf_dirichlet on zero-size arrays: valid shapes must not raise, mismatched shapes must raise MXNetError.
+    Issue: https://github.com/apache/incubator-mxnet/issues/18936
+    """
+
+    def test_valid_zero_dim():  # alpha of size 0 and a sample whose last axis has size 0
+        alpha = mx.nd.array(np.random.rand(0))
+        sample = mx.nd.array(np.random.rand(4, 0))
+        res = mx.nd.op.random_pdf_dirichlet(sample=sample, alpha=alpha)
+        assert res.shape == sample.shape[:-1]  # the last axis of sample is dropped in the output
+
+    def test_valid_zero_multi_dim():  # multi-dimensional alpha and sample, last axis of size 0
+        alpha = mx.nd.array(np.random.rand(4, 0))
+        sample = mx.nd.array(np.random.rand(4, 3, 0))
+        res = mx.nd.op.random_pdf_dirichlet(sample=sample, alpha=alpha)
+        assert res.shape == sample.shape[:-1]
+
+    def test_invalid_zero_dim():
+        """The shape of *alpha* must match the leftmost part of the *sample* shape"""
+        alpha = mx.nd.array(np.random.rand(1))
+        sample = mx.nd.array(np.random.rand(4, 0))
+        assertRaises(MXNetError, mx.nd.op.random_pdf_dirichlet, sample, alpha)
+
+    test_valid_zero_dim()
+    test_valid_zero_multi_dim()
+    test_invalid_zero_dim()
+
+@use_np
+def test_poisson_zero_size_dim():
+    """Tests random_pdf_poisson on zero-size arrays: valid shapes must not raise, mismatched shapes must raise MXNetError.
+    Issue: https://github.com/apache/incubator-mxnet/issues/18937
+    """
+
+    def test_valid_zero_dim():  # zero-size lam; sample's leading axis (size 0) agrees with lam's shape
+        lam = mx.nd.array(np.random.rand(0))
+        sample = mx.nd.array(np.random.rand(0, 2))
+        res = mx.nd.op.random_pdf_poisson(sample=sample, lam=lam)
+        assert res.shape == sample.shape  # output has the same shape as sample
+
+    def test_invalid_zero_dim():
+        """The shape of *lam* must match the leftmost part of the *sample* shape"""
+        lam = mx.nd.array(np.random.rand(0))
+        sample = mx.nd.array(np.random.rand(1, 2))  # leading axis of size 1 does not match lam's size 0
+        assertRaises(MXNetError, mx.nd.op.random_pdf_poisson, sample, lam)
+
+    test_valid_zero_dim()
+    test_invalid_zero_dim()
\ No newline at end of file