You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by la...@apache.org on 2019/03/13 18:08:08 UTC
[incubator-mxnet] branch master updated: Fix relative difference scala (#14417)
This is an automated email from the ASF dual-hosted git repository.
lanking pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 82504ad Fix relative difference scala (#14417)
82504ad is described below
commit 82504adc63b91f7b7be9075f8110f417944ba413
Author: Dang Trung Kien <ki...@pm.me>
AuthorDate: Thu Mar 14 02:07:41 2019 +0800
Fix relative difference scala (#14417)
* Fix relative difference scala
* Increase number of cases for scala arange test
* Add cases where arange produces NDArray of [0]
* Remove whitespace
---
.../core/src/test/scala/org/apache/mxnet/CheckUtils.scala | 4 ++--
.../core/src/test/scala/org/apache/mxnet/NDArraySuite.scala | 7 +++++++
2 files changed, 9 insertions(+), 2 deletions(-)
diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/CheckUtils.scala b/scala-package/core/src/test/scala/org/apache/mxnet/CheckUtils.scala
index 1ddb292..7602b53 100644
--- a/scala-package/core/src/test/scala/org/apache/mxnet/CheckUtils.scala
+++ b/scala-package/core/src/test/scala/org/apache/mxnet/CheckUtils.scala
@@ -21,13 +21,13 @@ object CheckUtils {
def reldiff(a: NDArray, b: NDArray): Float = {
val diff = NDArray.sum(NDArray.abs(a - b)).toScalar
val norm = NDArray.sum(NDArray.abs(a)).toScalar
- diff / norm
+ if (diff < Float.MinPositiveValue) diff else diff / norm
}
def reldiff(a: Array[Float], b: Array[Float]): Float = {
val diff =
(a zip b).map { case (aElem, bElem) => Math.abs(aElem - bElem) }.sum
val norm: Float = a.reduce(Math.abs(_) + Math.abs(_))
- diff / norm
+ if (diff < Float.MinPositiveValue) diff else diff / norm
}
}
diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala
index 72a5974..206094c 100644
--- a/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala
+++ b/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala
@@ -355,6 +355,13 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
val result3 = 0f to stop by 1f
val range3 = NDArray.arange(stop)
assert(CheckUtils.reldiff(result3.toArray, range3.toArray) <= 1e-4f)
+
+ val stop4 = Math.abs(stop)
+ val step4 = stop4 + Math.abs(scala.util.Random.nextFloat())
+ val result4 = (0.0 until stop4.toDouble by step4.toDouble)
+ .flatMap(x => Array.fill[Float](repeat)(x.toFloat))
+ val range4 = NDArray.arange(stop4, step = step4, repeat = repeat)
+ assert(CheckUtils.reldiff(result4.toArray, range4.toArray) <= 1e-4f)
}
}