You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/03/03 10:49:48 UTC

[GitHub] mseeger closed pull request #9961: Unary ops norm_logcdf, norm_derivlogcdf for log CDF of standard normal

mseeger closed pull request #9961: Unary ops norm_logcdf, norm_derivlogcdf for log CDF of standard normal
URL: https://github.com/apache/incubator-mxnet/pull/9961
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/docs/api/python/ndarray/ndarray.md b/docs/api/python/ndarray/ndarray.md
index 59ca4a612e6..7d639a5e127 100644
--- a/docs/api/python/ndarray/ndarray.md
+++ b/docs/api/python/ndarray/ndarray.md
@@ -606,6 +606,8 @@ The `ndarray` package provides several classes:
     sign
     gamma
     gammaln
+    norm_logcdf
+    norm_derivlogcdf
 ```
 
 ## Neural network
diff --git a/docs/api/python/symbol/symbol.md b/docs/api/python/symbol/symbol.md
index e383597236d..92d8e4e7845 100644
--- a/docs/api/python/symbol/symbol.md
+++ b/docs/api/python/symbol/symbol.md
@@ -607,6 +607,8 @@ Composite multiple symbols into a new one by an operator.
     sign
     gamma
     gammaln
+    norm_logcdf
+    norm_derivlogcdf
 ```
 
 ## Neural network
diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h
index 1d4284e1ac2..f9776748095 100644
--- a/src/operator/mshadow_op.h
+++ b/src/operator/mshadow_op.h
@@ -621,6 +621,72 @@ MSHADOW_XINLINE double gammaln_grad::Map<double>(double a) {
   return special_functions::cephes::psi<double>(a);
 }
 
+/***** norm_logcdf: log Phi(a), Phi = CDF of the standard normal ******/
+
+struct norm_logcdf : public mxnet_op::tunable {
+  template<typename DType>
+  MSHADOW_XINLINE static DType Map(DType a) {
+    // Non-double DTypes are evaluated in float precision (double: see below)
+    float af(static_cast<float>(a));
+    return DType(special_functions::apbsint::logCdfNormal<float>(af));
+  }
+};
+
+template<>
+MSHADOW_XINLINE double norm_logcdf::Map<double>(double a) {
+  return special_functions::apbsint::logCdfNormal<double>(a);
+}
+
+struct norm_logcdf_grad : public mxnet_op::tunable {
+  template<typename DType>
+  MSHADOW_XINLINE static DType Map(DType a) {
+    // d/da log Phi(a) = N(a)/Phi(a); non-double DTypes are evaluated in float
+    float af(static_cast<float>(a));
+    return DType(special_functions::apbsint::derivLogCdfNormal<float>(af));
+  }
+};
+
+template<>
+MSHADOW_XINLINE double norm_logcdf_grad::Map<double>(double a) {
+  return special_functions::apbsint::derivLogCdfNormal<double>(a);
+}
+
+/***** norm_derivlogcdf: (d/da) log Phi(a) = N(a)/Phi(a) ******/
+
+struct norm_derivlogcdf : public mxnet_op::tunable {
+  template<typename DType>
+  MSHADOW_XINLINE static DType Map(DType a) {
+    // Non-double DTypes are evaluated in float precision (double: see below)
+    float af(static_cast<float>(a));
+    return DType(special_functions::apbsint::derivLogCdfNormal<float>(af));
+  }
+};
+
+template<>
+MSHADOW_XINLINE double norm_derivlogcdf::Map<double>(double a) {
+  return special_functions::apbsint::derivLogCdfNormal<double>(a);
+}
+
+// NOTE: With h = norm_derivlogcdf(a), the gradient is h'(a) = -h * (a + h). Ideally
+// this would use ElemwiseGradUseInOut (a and h as inputs). Here, we recompute h, because
+// ElemwiseGradUseInOut is not properly supported for basic unary functions.
+struct norm_derivlogcdf_grad : public mxnet_op::tunable {
+  template<typename DType>
+  MSHADOW_XINLINE static DType Map(DType a) {
+    // Non-double DTypes are evaluated in float precision (double: see below)
+    float af(static_cast<float>(a));
+    float daf(special_functions::apbsint::derivLogCdfNormal<float>(af));
+    return DType(-daf * (af + daf));
+  }
+};
+
+template<>
+MSHADOW_XINLINE double norm_derivlogcdf_grad::Map<double>(double a) {
+  double da(special_functions::apbsint::derivLogCdfNormal<double>(a));
+  return -da * (a + da);
+}
+
+
 /* Smooth L1 Loss is a loss specific for R-CNN franchise training
  * Smooth L1 Loss function:
  * f(x) = 0.5 * (sigma * x) ^ 2,     |x| < 1 / sigma^2
diff --git a/src/operator/operator_tune.cc b/src/operator/operator_tune.cc
index c13f1ac2fae..fe67793a9bf 100644
--- a/src/operator/operator_tune.cc
+++ b/src/operator/operator_tune.cc
@@ -277,6 +277,10 @@ IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::gamma);  // NOLINT()
 IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::gamma_grad);  // NOLINT()
 IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::gammaln);  // NOLINT()
 IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::gammaln_grad);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::norm_logcdf);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::norm_logcdf_grad);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::norm_derivlogcdf);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::norm_derivlogcdf_grad);  // NOLINT()
 IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::ceil);  // NOLINT()
 IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::degrees);  // NOLINT()
 IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::degrees_grad);  // NOLINT()
diff --git a/src/operator/special_functions-inl.h b/src/operator/special_functions-inl.h
index 743391e0fce..aab234c7bcc 100644
--- a/src/operator/special_functions-inl.h
+++ b/src/operator/special_functions-inl.h
@@ -9,6 +9,8 @@
 #ifndef MXNET_OPERATOR_SPECIAL_FUNCTIONS_INL_H_
 #define MXNET_OPERATOR_SPECIAL_FUNCTIONS_INL_H_
 
+#include "math_functions-inl.h"
+
 namespace mxnet {
 namespace op {
 
@@ -86,7 +88,7 @@ struct cephes {
 
   /*
    *
-   *	Psi (digamma) function
+   *    Psi (digamma) function
    *
    *
    * SYNOPSIS:
@@ -241,6 +243,213 @@ MSHADOW_XINLINE float cephes::psi_helper<float>(float s) {
     return 0.0;
   }
 }
+
+
+// This is code extracted from ApBsInT, available at
+//    https://github.com/mseeger/apbsint
+// The author of ApBsInT, Matthias Seeger (mseeger@gmail.com), is the same person
+// who ported the code to MXNet.
+//
+// NOTE: Instantiate this with DType in {float, double}, nothing else will work!
+
+// Constants used by struct apbsint below, each returned as a DType value
+template<typename DType>
+struct apbsint_const {
+  MSHADOW_XINLINE static DType m_ln2pi() {
+    return DType(1.83787706640934533908193770913);
+  }
+  MSHADOW_XINLINE static DType m_ln2() {
+    return DType(0.69314718055994530941723212146);
+  }
+  MSHADOW_XINLINE static DType m_sqrtpi() {
+    return DType(1.77245385090551602729816748334);
+  }
+  MSHADOW_XINLINE static DType m_sqrt2() {
+    return DType(1.41421356237309504880168872421);
+  }
+  MSHADOW_XINLINE static DType erf_cody_limit1() {
+    return DType(0.6629);
+  }
+  MSHADOW_XINLINE static DType erf_cody_limit2() {
+    return DType(5.6569);
+  }
+};
+
+struct apbsint {
+  // Internal helpers. Literals are written as DType(...) so float stays in float.
+
+  /**
+   * For x >= erf_cody_limit1(), define Q(x) by
+   *   1 - Phi(x) approx N(x) x^{-1} Q(x).
+   * We compute Q(x) according to
+   *   Cody
+   *   Rational Chebyshev approximation to the error function
+   * This is done differently for x >= erf_cody_limit2() and
+   * erf_cody_limit1() <= x < erf_cody_limit2().
+   * NOTE: Q(x) -> 1 for x->infty.
+   */
+  template<typename DType>
+  MSHADOW_XINLINE static DType erfRationalHelper(DType x) {
+    int i;
+    DType res, den, y;
+
+    // Precondition (not checked): x > 0
+    if (x >= apbsint_const<DType>::erf_cody_limit2()) {
+      // x/sqrt(2) >= 4
+      // Q(x)   = 1 + sqrt(pi) y R_1(y),
+      // R_1(y) = poly(p_j,y) / poly(q_j,y),   y = 2/x^2
+      // Ordering of arrays: 4,3,2,1,0,5 (only for numerator p_j; q_5=1)
+      // ATTENTION: The p_j are negative of the entries here
+      DType p[] = {3.05326634961232344e-1, 3.60344899949804439e-1,
+                   1.25781726111229246e-1, 1.60837851487422766e-2,
+                   6.58749161529837803e-4, 1.63153871373020978e-2};
+      DType q[] = {2.56852019228982242,    1.87295284992346047,
+                   5.27905102951428412e-1, 6.05183413124413191e-2,
+                   2.33520497626869185e-3};
+      y = DType(2)/x/x;
+      res = y*p[5];
+      den = y;
+      for (i = 0; i < 4; i++) {
+        res = (res + p[i])*y;
+        den = (den + q[i])*y;
+      }
+      // Minus, because p[j] values have to be negated
+      res = DType(1) - apbsint_const<DType>::m_sqrtpi()*y*(res + p[4])/(den + q[4]);
+    } else {
+      // x/sqrt(2) < 4, x/sqrt(2) >= 0.469
+      // Q(x)   = sqrt(pi) y R_2(y),
+      // R_2(y) = poly(p_j,y) / poly(q_j,y),   y = x/sqrt(2)
+      // Ordering of arrays: 7,6,5,4,3,2,1,0,8 (only p_8; q_8=1)
+      DType p[] = {5.64188496988670089e-1, 8.88314979438837594,
+                   6.61191906371416295e+1, 2.98635138197400131e+2,
+                   8.81952221241769090e+2, 1.71204761263407058e+3,
+                   2.05107837782607147e+3, 1.23033935479799725e+3,
+                   2.15311535474403846e-8};
+      DType q[] = {1.57449261107098347e+1, 1.17693950891312499e+2,
+                   5.37181101862009858e+2, 1.62138957456669019e+3,
+                   3.29079923573345963e+3, 4.36261909014324716e+3,
+                   3.43936767414372164e+3, 1.23033935480374942e+3};
+      y = x/apbsint_const<DType>::m_sqrt2();
+      res = y*p[8];
+      den = y;
+      for (i = 0; i < 7; i++) {
+        res = (res + p[i])*y;
+        den = (den + q[i])*y;
+      }
+      res = apbsint_const<DType>::m_sqrtpi()*y*(res + p[7])/(den + q[7]);
+    }
+
+    return res;
+  }
+
+  /**
+   * Implements rational function R_3(y),  y = x^2/2,
+   * which is used if |x| < erf_cody_limit1(). In this range:
+   *   Phi(x) approx (1 + (x/sqrt(2)) R_3(x^2/2))/2
+   * See
+   *   Cody
+   *   Rational Chebyshev approximation to the error function
+   */
+  template<typename DType>
+  MSHADOW_XINLINE static DType erfRationalHelperR3(DType y) {
+    int i;
+    DType nom, den;
+
+    // Precondition (not checked): y >= 0
+    // R_3(y) = poly(p_j,y) / poly(q_j,y)
+    // Ordering of arrays: 3,2,1,0,4 (only for p_4; q_4=1)
+    DType p[] = {3.16112374387056560,    1.13864154151050156e+2,
+                 3.77485237685302021e+2, 3.20937758913846947e+3,
+                 1.85777706184603153e-1};
+    DType q[] = {2.36012909523441209e+1, 2.44024637934444173e+2,
+                 1.28261652607737228e+3, 2.84423683343917062e+3};
+    nom = y*p[4];
+    den = y;
+    for (i = 0; i < 3; i++) {
+      nom = (nom + p[i])*y;
+      den = (den + q[i])*y;
+    }
+
+    return (nom + p[3])/(den + q[3]);
+  }
+
+  // Exported functions
+
+  /**
+   * @param z Argument
+   * @return  log N(z|0,1)
+   */
+  template<typename DType>
+  MSHADOW_XINLINE static DType logPdfNormal(DType z) {
+    return DType(-0.5) * (apbsint_const<DType>::m_ln2pi() + z*z);
+  }
+
+  /**
+   * If Phi(z) denotes the c.d.f. of N(0,1), this method computes
+   * log Phi(z).
+   *
+   * @param z Argument
+   * @return  log Phi(z)
+   */
+  template<typename DType>
+  MSHADOW_XINLINE static DType logCdfNormal(DType z) {
+    DType res;
+
+    if (math::fabs(z) < apbsint_const<DType>::erf_cody_limit1()) {
+      // Part 3 approximation:
+      // Phi(z) approx (1 + y R_3(y^2))/2, y = z/sqrt(2)
+      res = math::log1p((z/apbsint_const<DType>::m_sqrt2())*erfRationalHelperR3(DType(0.5)*z*z))
+        - apbsint_const<DType>::m_ln2();
+    } else {
+      // Part 1 or 2 approximation:
+      // Phi(z) approx N(z) Q(-z)/(-z), z < 0
+      // NOTE: The case z >= erf_cody_limit1() is uncritical, we could even use
+      // a cheaper approximation then
+      if (z < DType(0))
+        res = logPdfNormal(z) - math::log(-z) + math::log(erfRationalHelper(-z));
+      else
+        res = math::log1p(-math::exp(logPdfNormal(z))*erfRationalHelper(z)/z);
+    }
+
+    return res;
+  }
+
+  /**
+   * If Phi(z) denotes the c.d.f. of N(0,1), this method computes
+   *   f(z) = (d/dz) log Phi(z) = N(z)/Phi(z).
+   * NOTE: The technical report defines the hazard function
+   *   h(x) = N(x)/(1 - Phi(x)).
+   * This method computes h(-z).
+   *
+   * @param z Argument
+   * @return  (d/dz) log Phi(z)
+   */
+  template<typename DType>
+  MSHADOW_XINLINE static DType derivLogCdfNormal(DType z) {
+    DType res;
+
+    if (math::fabs(z) < apbsint_const<DType>::erf_cody_limit1()) {
+      // Part 3 approximation:
+      // Phi(z) approx (1 + y R_3(y^2))/2, y = z/sqrt(2)
+      res = DType(2) * math::exp(logPdfNormal(z)) /
+        (DType(1) + (z/apbsint_const<DType>::m_sqrt2())*erfRationalHelperR3(DType(0.5)*z*z));
+    } else {
+      // Part 1 or 2:
+      // Phi(z) approx N(z) Q(-z)/(-z), z<0
+      // NOTE: The case z >= erf_cody_limit1() is uncritical, we could even use
+      // a cheaper approximation then
+      if (z < DType(0)) {
+        res = -z/erfRationalHelper(-z);
+      } else {
+        DType temp = math::exp(logPdfNormal(z));
+        res = temp / (DType(1) - temp*erfRationalHelper(z)/z);
+      }
+    }
+
+    return res;
+  }
+};
+
 }  // namespace special_functions
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/tensor/elemwise_unary_op_basic.cc b/src/operator/tensor/elemwise_unary_op_basic.cc
index acd8f7b23ff..eb45f5f6b74 100644
--- a/src/operator/tensor/elemwise_unary_op_basic.cc
+++ b/src/operator/tensor/elemwise_unary_op_basic.cc
@@ -826,5 +826,42 @@ The storage type of ``gammaln`` output is always dense
 MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_DR(_backward_gammaln,
                                                   unary_bwd<mshadow_op::gammaln_grad>);
 
+// norm_logcdf
+MXNET_OPERATOR_REGISTER_UNARY_WITH_SPARSE_DR(norm_logcdf, cpu, mshadow_op::norm_logcdf)
+MXNET_ADD_SPARSE_OP_ALIAS(norm_logcdf)
+.describe(R"code(Returns ``log`` of cumulative distribution function of standard normal, \
+computed element-wise on the input array.
+
+The standard normal distribution has mean 0, variance 1.
+
+The storage type of ``norm_logcdf`` output is always dense.
+
+)code")
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_norm_logcdf"});
+
+MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_DR(_backward_norm_logcdf,
+                                                  unary_bwd<mshadow_op::norm_logcdf_grad>);
+
+// norm_derivlogcdf
+MXNET_OPERATOR_REGISTER_UNARY_WITH_SPARSE_DR(norm_derivlogcdf, cpu,
+                                             mshadow_op::norm_derivlogcdf)
+MXNET_ADD_SPARSE_OP_ALIAS(norm_derivlogcdf)
+.describe(R"code(Returns derivative of ``log`` of cumulative distribution function of \
+standard normal, computed element-wise on the input array.
+
+.. math::
+   y = \mathrm{norm\_pdf}(x) / \mathrm{norm\_cdf}(x)
+
+The standard normal distribution has mean 0, variance 1. ``norm_pdf`` denotes its PDF, \
+``norm_cdf`` its CDF. The expression is also the derivative of ``norm_logcdf(x)``.
+
+The storage type of ``norm_derivlogcdf`` output is always dense.
+
+)code")
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_norm_derivlogcdf"});
+
+MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_DR(_backward_norm_derivlogcdf,
+                                                  unary_bwd<mshadow_op::norm_derivlogcdf_grad>);
+
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/tensor/elemwise_unary_op_basic.cu b/src/operator/tensor/elemwise_unary_op_basic.cu
index 8dfa9af74ce..8431834a2cb 100644
--- a/src/operator/tensor/elemwise_unary_op_basic.cu
+++ b/src/operator/tensor/elemwise_unary_op_basic.cu
@@ -246,5 +246,21 @@ NNVM_REGISTER_OP(_backward_gammaln)
 .set_attr<FCompute>("FCompute<gpu>", ElemwiseBinaryOp::Compute<
   gpu, unary_bwd<mshadow_op::gammaln_grad> >);
 
+// norm_logcdf: GPU forward kernel and its gradient (_backward_norm_logcdf)
+NNVM_REGISTER_OP(norm_logcdf)
+.set_attr<FCompute>("FCompute<gpu>", UnaryOp::Compute<gpu, mshadow_op::norm_logcdf>);
+
+NNVM_REGISTER_OP(_backward_norm_logcdf)
+.set_attr<FCompute>("FCompute<gpu>", ElemwiseBinaryOp::Compute<
+  gpu, unary_bwd<mshadow_op::norm_logcdf_grad> >);
+
+// norm_derivlogcdf: GPU forward kernel and its gradient (_backward_norm_derivlogcdf)
+NNVM_REGISTER_OP(norm_derivlogcdf)
+.set_attr<FCompute>("FCompute<gpu>", UnaryOp::Compute<gpu, mshadow_op::norm_derivlogcdf>);
+
+NNVM_REGISTER_OP(_backward_norm_derivlogcdf)
+.set_attr<FCompute>("FCompute<gpu>", ElemwiseBinaryOp::Compute<
+  gpu, unary_bwd<mshadow_op::norm_derivlogcdf_grad> >);
+
 }  // namespace op
 }  // namespace mxnet
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 7889e084f74..7667d45a3de 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -4325,7 +4325,6 @@ def test_laop_2():
     # Tests for linalg.syrk
     mnalpha_lst = [(2, 3, 1.), (5, 3, -2.), (1, 6, 5.), (3, 3, 0.5), (4, 1, 10.), (1, 1, 1.)]
     for m, n, alpha in mnalpha_lst:
-        #print('syrk: m={}, n={}, alpha={}'.format(m, n, alpha))
         data_in1 = np.random.uniform(1, 10, (m, n))
         res_syrk1 = alpha * np.dot(data_in1, data_in1.T)
         test_syrk1 = mx.sym.linalg.syrk(data1, transpose=False, alpha=alpha)
@@ -4359,7 +4358,6 @@ def test_laop_2():
     test_gelqf_l = _gelqf_second_output(data1)  # Output L (Q is not dangling)
     mn_lst = [(4, 4), (1, 1), (5, 20), (1, 10), (15, 50)]
     for m, n in mn_lst:
-        #print('gelqf: m={}, n={}'.format(m, n))
         data_in1 = np.random.normal(0., 10., (m, n))
         res_eye = np.eye(m)
         res_a = data_in1
@@ -4463,7 +4461,6 @@ def test_laop_3():
     test_syevd_l_4 = _syevd_second_output(data1_s4)
     n_lst = [4, 1, 2, 10, 14]
     for n in n_lst:
-        #print('\n** syevd: n={}'.format(n))
         data_in1 = np.random.normal(0., 10., (n, n))
         data_in1 = 0.5 * (data_in1 + data_in1.T)
         res_eye = np.eye(n)
@@ -4514,10 +4511,8 @@ def test_laop_4():
     l_np = np.array([0., 5.])
     test_syevd = mx.sym.linalg.syevd(data1)
     # float64
-    #print('float64')
     check_fw(test_syevd, [a_np], [u_np, l_np], np.float64)
     # float32
-    #print('float32')
     check_fw(test_syevd, [a_np], [u_np, l_np], np.float32)
 
 
@@ -4676,6 +4671,7 @@ def check(data, idx):
             idx = mx.nd.array([[0, 0, 0, 0]], dtype='int32')
             assert (mx.nd._internal._backward_gather_nd(data, idx, shape=(1,)).asscalar() == data.asnumpy().sum())
 
+
 def compare_forw_backw_unary_op(
         name, forward_mxnet_call, forward_numpy_call,
         backward_numpy_call, shape, input_low, input_high, rtol, atol,
@@ -4716,6 +4712,54 @@ def finite_diff_unary_op(
         name=op_name)
     check_grad(op_ex, [data_np])
 
+# Here, we compare forward and backward results for unary ops called with
+# different dtype.
+def compare_forw_backw_unary_op_dtypes(
+        name, forward_mxnet_call, shape, input_low, input_high, rtols, atols,
+        dtype_ref=np.float64, dtypes_cmp=[np.float32, np.float16]):
+    ctx = default_context()
+    data_np = np.random.uniform(input_low, input_high, shape)
+    out_grad = np.random.uniform(-2.0, 2.0, shape)
+    # Compute results for dtype_ref
+    op_name = 'unary_op={}, dtype={}'.format(name, dtype_ref)
+    data_name = op_name + '_data'
+    data_ref = mx.symbol.Variable(data_name, dtype=dtype_ref)
+    sym_ref = mx.sym.broadcast_add(
+        forward_mxnet_call(data_ref), mx.sym.zeros_like(data_ref),
+        name=op_name)
+    args_ref = {
+        data_name: mx.nd.array(data_np.astype(dtype_ref), ctx=ctx, dtype=dtype_ref)
+    }
+    args_grad_ref = {
+        data_name: mx.nd.empty(shape, ctx=ctx, dtype=dtype_ref)
+    }
+    ex_ref = sym_ref.bind(
+        ctx=ctx, grad_req='write', args=args_ref, args_grad=args_grad_ref)
+    ex_ref.forward(is_train=True)
+    res_forw_ref = [x.asnumpy() for x in ex_ref.outputs]
+    ex_ref.backward(mx.nd.array(
+        out_grad.astype(dtype_ref), ctx=ctx, dtype=dtype_ref))
+    res_grad_ref = [x.asnumpy() for x in args_grad_ref.values()]
+    # Loop over dtypes_cmp
+    for ind in range(len(dtypes_cmp)):
+        dtype = dtypes_cmp[ind]
+        _data_np = data_np.astype(dtype)
+        _out_grad = out_grad.astype(dtype)
+        # Compare forward
+        _res_forw_ref = [x.astype(dtype) for x in res_forw_ref]
+        op_name = 'unary_op={}, dtype={}'.format(name, dtype)
+        data = mx.symbol.Variable(op_name + '_data', dtype=dtype)
+        sym = mx.sym.broadcast_add(
+            forward_mxnet_call(data), mx.sym.zeros_like(data), name=op_name)
+        check_symbolic_forward(
+            sym, [_data_np], _res_forw_ref, rtol=rtols[ind], atol=atols[ind],
+            dtype=dtype)
+        # Compare backward
+        _res_grad_ref = [x.astype(dtype) for x in res_grad_ref]
+        check_symbolic_backward(
+            sym, [_data_np], [_out_grad], _res_grad_ref, rtol=rtols[ind],
+            atol=atols[ind], dtype=dtype)
+
 def np_smooth_l1(x, sigma):
     issq = 1. / sigma / sigma
     absx = np.abs(x)
@@ -4726,25 +4770,40 @@ def np_smooth_l1_grad(x, sigma):
     ssq = sigma * sigma
     return np.where(np.abs(x) < 1. / ssq, x * ssq, np.sign(x))
 
+# Reference for norm_derivlogcdf: N(x)/Phi(x), evaluated in log space (needs scipy)
+def np_norm_derivlogcdf(x):
+    from scipy.stats import norm
+    temp = np.square(x) + np.log(2.0 * np.pi)
+    return np.exp(-0.5 * temp - norm.logcdf(x))
+
+# Reference gradient of norm_derivlogcdf: -h * (x + h), h = N(x)/Phi(x) (needs scipy)
+def np_norm_derivlogcdf_grad(x):
+    y = np_norm_derivlogcdf(x)
+    return -y * (x + y)
+
 # Tests for unary operators (basic mathematical functions):
-# - Forward: Comparison to NumPy (several dtype)
-# - Backward: Comparison to NumPy (several dtype)
-# - Finite difference tests (only dtype = float64)
+# - Forward, backward: Comparison to NumPy (dtype = float64)
+# - Forward, backward: Comparison float64 against {float32, float16}
+# - Finite difference tests (dtype = float64)
 # Seed set because the test is not robust enough to operate on random data
 @with_seed(192837465)
 def test_unary_math_operators():
     have_scipy = True
     try:
         from scipy import special as scipy_special
+        from scipy import stats as scipy_stats
     except:
         print("Could not import scipy. Skipping unit tests for special functions")
         have_scipy = False
     shape=(9, 10)
-    dtype_l = [np.float64, np.float32, np.float16]
-    rtol_l = [1e-7, 1e-6, 1e-2]
-    rtol_less_l = [1e-6, 1e-5, 1e-2]
-    atol_l = [1e-7, 1e-6, 1e-2]
-    atol_less_l = [1e-6, 1e-5, 1e-2]
+    dtype_ref = np.float64
+    dtypes_cmp = [np.float32, np.float16]
+    rtol_np = 1e-7
+    atol_np = 1e-7
+    rtol_np_less = 1e-6
+    atol_np_less = 1e-6
+    rtol_l = [1e-5, 1e-2]
+    atol_l = [1e-5, 1e-2]
     rtol_fd = 1e-5
     atol_fd = 1e-6
     num_eps = 1e-6
@@ -4871,20 +4930,31 @@ def test_unary_math_operators():
                                 lambda x: scipy_special.gammaln(x),
                                 lambda x: scipy_special.psi(x),
                                 0.01, 20.0]
+        unary_ops['norm_logcdf'] = [lambda x: mx.sym.norm_logcdf(x),
+                                    lambda x: scipy_stats.norm.logcdf(x),
+                                    lambda x: np_norm_derivlogcdf(x),
+                                    -10.0, 5.0]
+        unary_ops['norm_derivlogcdf'] = [lambda x: mx.sym.norm_derivlogcdf(x),
+                                         lambda x: np_norm_derivlogcdf(x),
+                                         lambda x: np_norm_derivlogcdf_grad(x),
+                                         -10.0, 5.0]
     # Loop over operators
     for name, op in unary_ops.items():
-        # Loop over dtype's
-        for ind in range(len(dtype_l)):
-            dtype = dtype_l[ind]
-            if name == 'gammaln' or name == 'gamma':
-                rtol = rtol_less_l[ind]
-                atol = atol_less_l[ind]
-            else:
-                rtol = rtol_l[ind]
-                atol = atol_l[ind]
-            compare_forw_backw_unary_op(
-                name, op[0], op[1], op[2], shape, op[3], op[4], rtol, atol,
-                dtype)
+        # Compare to NumPy (float64)
+        if name == 'gammaln' or name == 'gamma' or name == 'norm_logcdf' \
+                or name == 'norm_derivlogcdf':
+            rtol = rtol_np_less
+            atol = atol_np_less
+        else:
+            rtol = rtol_np
+            atol = atol_np
+        compare_forw_backw_unary_op(
+            name, op[0], op[1], op[2], shape, op[3], op[4], rtol, atol,
+            dtype=np.float64)
+        # Compare float64 (reference) to other dtypes
+        compare_forw_backw_unary_op_dtypes(
+            name, op[0], shape, op[3], op[4], rtol_l, atol_l, dtype_ref,
+            dtypes_cmp)
         # Finite difference testing
         finite_diff_unary_op(
             name, op[0], shape, op[3], op[4], rtol_fd, atol_fd, num_eps)
@@ -4934,13 +5004,110 @@ def finite_diff_binary_op(
         name=op_name)
     check_grad(op_ex, [data1_np, data2_np])
 
-# Tests for unary operators (basic mathematical functions):
-# - Forward: Comparison to NumPy (several dtype)
-# - Backward: Comparison to NumPy (several dtype)
-# - Finite difference tests (only dtype = float64)
+# Here, we compare forward and backward results for binary ops called with
+# different dtype.
+def compare_forw_backw_binary_op_dtypes(
+        name, forward_mxnet_call, shape, input1_low, input1_high, input2_low,
+        input2_high, rtols, atols, dtype_ref=np.float64,
+        dtypes_cmp=[np.float32, np.float16]):
+    ctx = default_context()
+    data1_np = np.random.uniform(input1_low, input1_high, shape)
+    data2_np = np.random.uniform(input2_low, input2_high, shape)
+    out_grad = np.random.uniform(-2.0, 2.0, shape)
+    # Compute results for dtype_ref
+    op_name = 'binary_op={}, dtype={}'.format(name, dtype_ref)
+    data1_name = op_name + '_data1'
+    data1_ref = mx.symbol.Variable(data1_name, dtype=dtype_ref)
+    data2_name = op_name + '_data2'
+    data2_ref = mx.symbol.Variable(data2_name, dtype=dtype_ref)
+    sym_ref = mx.sym.broadcast_add(
+        forward_mxnet_call(data1_ref, data2_ref), mx.sym.zeros_like(data1_ref),
+        name=op_name)
+    args_ref = {
+        data1_name: mx.nd.array(data1_np.astype(dtype_ref), ctx=ctx, dtype=dtype_ref),
+        data2_name: mx.nd.array(data2_np.astype(dtype_ref), ctx=ctx, dtype=dtype_ref)
+    }
+    args_grad_ref = {
+        data1_name: mx.nd.empty(shape, ctx=ctx, dtype=dtype_ref),
+        data2_name: mx.nd.empty(shape, ctx=ctx, dtype=dtype_ref)
+    }
+    ex_ref = sym_ref.bind(
+        ctx=ctx, grad_req='write', args=args_ref, args_grad=args_grad_ref)
+    ex_ref.forward(is_train=True)
+    res_forw_ref = [x.asnumpy() for x in ex_ref.outputs]
+    ex_ref.backward(mx.nd.array(
+        out_grad.astype(dtype_ref), ctx=ctx, dtype=dtype_ref))
+    res_grad_ref = [args_grad_ref[data1_name].asnumpy(),
+                    args_grad_ref[data2_name].asnumpy()]
+    # Loop over dtypes_cmp
+    for ind in range(len(dtypes_cmp)):
+        dtype = dtypes_cmp[ind]
+        _data1_np = data1_np.astype(dtype)
+        _data2_np = data2_np.astype(dtype)
+        _out_grad = out_grad.astype(dtype)
+        # Compare forward
+        _res_forw_ref = [x.astype(dtype) for x in res_forw_ref]
+        op_name = 'binary_op={}, dtype={}'.format(name, dtype)
+        data1 = mx.symbol.Variable(op_name + '_data1', dtype=dtype)
+        data2 = mx.symbol.Variable(op_name + '_data2', dtype=dtype)
+        sym = mx.sym.broadcast_add(
+            forward_mxnet_call(data1, data2), mx.sym.zeros_like(data1),
+            name=op_name)
+        check_symbolic_forward(
+            sym, [_data1_np, _data2_np], _res_forw_ref, rtol=rtols[ind],
+            atol=atols[ind], dtype=dtype)
+        # Compare backward
+        _res_grad_ref = [x.astype(dtype) for x in res_grad_ref]
+        check_symbolic_backward(
+            sym, [_data1_np, _data2_np], [_out_grad], _res_grad_ref,
+            rtol=rtols[ind], atol=atols[ind], dtype=dtype)
+
+# Tests for binary operators (basic mathematical functions):
+# - Forward, backward: Comparison to NumPy (dtype = float64)
+# - Forward, backward: Comparison float64 against {float32, float16}
+# - Finite difference tests (dtype = float64)
 # Seed set because the test is not robust enough to operate on random data
 @with_seed(192837465)
 def test_binary_math_operators():
+    shape=(9, 10)
+    dtype_ref = np.float64
+    dtypes_cmp = [np.float32, np.float16]
+    rtol_np = 1e-7
+    atol_np = 1e-7
+    rtol_l = [1e-5, 1e-2]
+    atol_l = [1e-5, 1e-2]
+    rtol_fd = 1e-5
+    atol_fd = 1e-6
+    num_eps = 1e-6
+    binary_ops = {
+        'hypot' : [lambda x, y: mx.sym.hypot(x, y),
+                   lambda x, y: np.hypot(x, y),
+                   lambda x, y: x / np.hypot(x, y),
+                   lambda x, y: y / np.hypot(x, y),
+                    -5.0, 5.0, -5.0, 5.0],
+        'pow': [lambda x, y: mx.sym.pow(x, y),
+                lambda x, y: np.power(x, y),
+                lambda x, y: np.power(x, y - 1.) * y,
+                lambda x, y: np.power(x, y) * np.log(x),
+                0.2, 5.0, -4.0, 4.0]
+    }
+    # Loop over operators
+    for name, op in binary_ops.items():
+        # Compare to NumPy (float64)
+        compare_forw_backw_binary_op(
+            name, op[0], op[1], op[2], op[3], shape, op[4], op[5], op[6],
+            op[7], rtol_np, atol_np, dtype=np.float64)
+        # Compare float64 (reference) to other dtypes
+        compare_forw_backw_binary_op_dtypes(
+            name, op[0], shape, op[4], op[5], op[6], op[7], rtol_l, atol_l,
+            dtype_ref, dtypes_cmp)
+        # Finite difference testing
+        finite_diff_binary_op(
+            name, op[0], shape, op[4], op[5], op[6], op[7], rtol_fd, atol_fd,
+            num_eps)
+
+@with_seed(192837465)
+def test_binary_math_operators_2():
     shape=(9, 10)
     dtype_l = [np.float64, np.float32, np.float16]
     rtol_l = [1e-7, 1e-6, 1e-2]
diff --git a/tests/python/unittest/test_sparse_operator.py b/tests/python/unittest/test_sparse_operator.py
index 4e5d3546280..cccfb094be2 100644
--- a/tests/python/unittest/test_sparse_operator.py
+++ b/tests/python/unittest/test_sparse_operator.py
@@ -746,6 +746,17 @@ def check_binary_op_with_scalar(stype,
                                        force_overlap=force_overlap,
                                        verbose=False)
 
+    # Reference for norm_derivlogcdf: N(x)/Phi(x), evaluated in log space (needs scipy)
+    def np_norm_derivlogcdf(x):
+        from scipy.stats import norm
+        temp = np.square(x) + np.log(2.0 * np.pi)
+        return np.exp(-0.5 * temp - norm.logcdf(x))
+
+    # Reference gradient of norm_derivlogcdf: -h * (x + h), h = N(x)/Phi(x) (needs scipy)
+    def np_norm_derivlogcdf_grad(x):
+        y = np_norm_derivlogcdf(x)
+        return -y * (x + y)
+
     # Check many basic unary operators
     def check_mathematical_core(stype, output_grad_stype=None,
                                 input_grad_stype=None, force_overlap=False,
@@ -1035,6 +1046,7 @@ def check_mathematical_core(stype, output_grad_stype=None,
 
             try:
                 from scipy import special as scipy_special
+                from scipy import stats as scipy_stats
                 import_succeeded = True
                 # gamma
                 check_sparse_mathematical_core("gamma", stype,
@@ -1054,6 +1066,24 @@ def check_mathematical_core(stype, output_grad_stype=None,
                                                input_grad_stype=input_grad_stype,
                                                force_overlap=force_overlap,
                                                density=density, ograd_density=ograd_density)
+                # norm_logcdf
+                check_sparse_mathematical_core("norm_logcdf", stype,
+                                               lambda x: mx.sym.norm_logcdf(x),
+                                               lambda x: scipy_stats.norm.logcdf(x),
+                                               lambda x: np_norm_derivlogcdf(x),
+                                               output_grad_stype=output_grad_stype,
+                                               input_grad_stype=input_grad_stype,
+                                               force_overlap=force_overlap,
+                                               density=density, ograd_density=ograd_density)
+                # norm_derivlogcdf
+                check_sparse_mathematical_core("norm_derivlogcdf", stype,
+                                               lambda x: mx.sym.norm_derivlogcdf(x),
+                                               lambda x: np_norm_derivlogcdf(x),
+                                               lambda x: np_norm_derivlogcdf_grad(x),
+                                               output_grad_stype=output_grad_stype,
+                                               input_grad_stype=input_grad_stype,
+                                               force_overlap=force_overlap,
+                                               density=density, ograd_density=ograd_density)
 
             except:
                 if import_succeeded == False:


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services