You are viewing a plain text version of this content. The canonical link for it was not preserved in this plain text version.
Posted to commits@mxnet.apache.org by la...@apache.org on 2020/07/10 18:58:12 UTC

[incubator-mxnet] 01/01: Fix scipy dependency in probability module

This is an automated email from the ASF dual-hosted git repository.

lausen pushed a commit to branch leezu-patch-3
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git

commit 35acfdeaab23a5eff61929ef5825be236b8dc81f
Author: Leonard Lausen <la...@amazon.com>
AuthorDate: Fri Jul 10 11:57:06 2020 -0700

    Fix scipy dependency in probability module
---
 .../mxnet/gluon/probability/distributions/utils.py | 25 +++++++++++++++++-----
 1 file changed, 20 insertions(+), 5 deletions(-)

diff --git a/python/mxnet/gluon/probability/distributions/utils.py b/python/mxnet/gluon/probability/distributions/utils.py
index f8a03c4..75837c2 100644
--- a/python/mxnet/gluon/probability/distributions/utils.py
+++ b/python/mxnet/gluon/probability/distributions/utils.py
@@ -24,7 +24,10 @@ __all__ = ['getF', 'prob2logit', 'logit2prob', 'cached_property', 'sample_n_shap
 from functools import update_wrapper
 from numbers import Number
 import numpy as onp
-import scipy.special as sc
+try:
+    import scipy.special as sc
+except ImportError:
+    sc = None
 from .... import symbol as sym
 from .... import ndarray as nd
 
@@ -48,7 +51,10 @@ def digamma(F):
         """Return digamma(value)
         """
         if isinstance(value, Number):
-            return sc.digamma(value, dtype='float32')
+            if sc is not None:
+                return sc.digamma(value, dtype='float32')
+            else:
+                raise ValueError('Numbers are not supported as input if scipy is not installed')
         return F.npx.digamma(value)
     return compute
 
@@ -60,7 +66,10 @@ def gammaln(F):
         """Return log(gamma(value))
         """
         if isinstance(value, Number):
-            return sc.gammaln(value, dtype='float32')
+            if sc is not None:
+                return sc.gammaln(value, dtype='float32')
+            else:
+                raise ValueError('Numbers are not supported as input if scipy is not installed')
         return F.npx.gammaln(value)
     return compute
 
@@ -70,7 +79,10 @@ def erf(F):
     """
     def compute(value):
         if isinstance(value, Number):
-            return sc.erf(value)
+            if sc is not None:
+                return sc.erf(value, dtype='float32')
+            else:
+                raise ValueError('Numbers are not supported as input if scipy is not installed')
         return F.npx.erf(value)
     return compute
 
@@ -80,7 +92,10 @@ def erfinv(F):
     """
     def compute(value):
         if isinstance(value, Number):
-            return sc.erfinv(value)
+            if sc is not None:
+                return sc.erfinv(value, dtype='float32')
+            else:
+                raise ValueError('Numbers are not supported as input if scipy is not installed')
         return F.npx.erfinv(value)
     return compute