You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by la...@apache.org on 2019/03/13 18:08:08 UTC

[incubator-mxnet] branch master updated: Fix relative difference scala (#14417)

This is an automated email from the ASF dual-hosted git repository.

lanking pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 82504ad  Fix relative difference scala (#14417)
82504ad is described below

commit 82504adc63b91f7b7be9075f8110f417944ba413
Author: Dang Trung Kien <ki...@pm.me>
AuthorDate: Thu Mar 14 02:07:41 2019 +0800

    Fix relative difference scala (#14417)
    
    * Fix relative difference scala
    
    * Increase number of cases for scala arange test
    
    * Add cases where arange produces an NDArray of [0]
    
    * Remove whitespace
---
 .../core/src/test/scala/org/apache/mxnet/CheckUtils.scala          | 4 ++--
 .../core/src/test/scala/org/apache/mxnet/NDArraySuite.scala        | 7 +++++++
 2 files changed, 9 insertions(+), 2 deletions(-)

diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/CheckUtils.scala b/scala-package/core/src/test/scala/org/apache/mxnet/CheckUtils.scala
index 1ddb292..7602b53 100644
--- a/scala-package/core/src/test/scala/org/apache/mxnet/CheckUtils.scala
+++ b/scala-package/core/src/test/scala/org/apache/mxnet/CheckUtils.scala
@@ -21,13 +21,13 @@ object CheckUtils {
   def reldiff(a: NDArray, b: NDArray): Float = {
     val diff = NDArray.sum(NDArray.abs(a - b)).toScalar
     val norm = NDArray.sum(NDArray.abs(a)).toScalar
-    diff / norm
+    if (diff < Float.MinPositiveValue) diff else diff / norm
   }
 
   def reldiff(a: Array[Float], b: Array[Float]): Float = {
     val diff =
       (a zip b).map { case (aElem, bElem) => Math.abs(aElem - bElem) }.sum
     val norm: Float = a.reduce(Math.abs(_) + Math.abs(_))
-    diff / norm
+    if (diff < Float.MinPositiveValue) diff else diff / norm
   }
 }
diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala
index 72a5974..206094c 100644
--- a/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala
+++ b/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala
@@ -355,6 +355,13 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
       val result3 = 0f to stop by 1f
       val range3 = NDArray.arange(stop)
       assert(CheckUtils.reldiff(result3.toArray, range3.toArray) <= 1e-4f)
+
+      val stop4 = Math.abs(stop)
+      val step4 = stop4 + Math.abs(scala.util.Random.nextFloat())
+      val result4 = (0.0 until stop4.toDouble by step4.toDouble)
+        .flatMap(x => Array.fill[Float](repeat)(x.toFloat))
+      val range4 = NDArray.arange(stop4, step = step4, repeat = repeat)
+      assert(CheckUtils.reldiff(result4.toArray, range4.toArray) <= 1e-4f)
     }
   }