You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by me...@apache.org on 2014/10/28 03:43:47 UTC

git commit: [MLlib] SPARK-3987: add test case on objective value for NNLS

Repository: spark
Updated Branches:
  refs/heads/master bfa614b12 -> 7e3a1ada8


[MLlib] SPARK-3987: add test case on objective value for NNLS

Also lower the minimum step-size threshold in the NNLS stopping condition from 1e-6 to 1e-7 so that the proposed test passes.

Author: coderxiang <sh...@gmail.com>

Closes #2965 from coderxiang/nnls-test and squashes the following commits:

24b06f9 [coderxiang] add test case on objective value for NNLS; update step parameter to pass the test


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/7e3a1ada
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/7e3a1ada
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/7e3a1ada

Branch: refs/heads/master
Commit: 7e3a1ada86e6adf1ddd4d8a321824daf5f3b2c75
Parents: bfa614b
Author: coderxiang <sh...@gmail.com>
Authored: Mon Oct 27 19:43:39 2014 -0700
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Mon Oct 27 19:43:39 2014 -0700

----------------------------------------------------------------------
 .../apache/spark/mllib/optimization/NNLS.scala  |  2 +-
 .../spark/mllib/optimization/NNLSSuite.scala    | 30 ++++++++++++++++++++
 2 files changed, 31 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/7e3a1ada/mllib/src/main/scala/org/apache/spark/mllib/optimization/NNLS.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/NNLS.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/NNLS.scala
index e4b436b..fef062e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/NNLS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/NNLS.scala
@@ -79,7 +79,7 @@ private[mllib] object NNLS {
     // stopping condition
     def stop(step: Double, ndir: Double, nx: Double): Boolean = {
         ((step.isNaN) // NaN
-      || (step < 1e-6) // too small or negative
+      || (step < 1e-7) // too small or negative
       || (step > 1e40) // too small; almost certainly numerical problems
       || (ndir < 1e-12 * nx) // gradient relatively too small
       || (ndir < 1e-32) // gradient absolutely too small; numerical issues may lurk

http://git-wip-us.apache.org/repos/asf/spark/blob/7e3a1ada/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala
index b781a6a..82c327b 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala
@@ -37,6 +37,12 @@ class NNLSSuite extends FunSuite {
     (ata, atb)
   }
 
+  /** Evaluates the quadratic objective 0.5 * x^T * ata * x - atb^T * x at the point x. */
+  def computeObjectiveValue(ata: DoubleMatrix, atb: DoubleMatrix, x: DoubleMatrix): Double = {
+    val quadraticTerm = x.transpose().mmul(ata).mmul(x)
+    quadraticTerm.mul(0.5).sub(atb.dot(x)).get(0)
+  }
+
   test("NNLS: exact solution cases") {
     val n = 20
     val rand = new Random(12346)
@@ -79,4 +85,28 @@ class NNLSSuite extends FunSuite {
       assert(x(i) >= 0)
     }
   }
+
+  test("NNLS: objective value test") {
+    val n = 5
+    val ata = new DoubleMatrix(n, n,
+      517399.13534, 242529.67289, -153644.98976, 130802.84503, -798452.29283,
+      242529.67289, 126017.69765, -75944.21743, 81785.36128, -405290.60884,
+      -153644.98976, -75944.21743, 46986.44577, -45401.12659, 247059.51049,
+      130802.84503, 81785.36128, -45401.12659, 67457.31310, -253747.03819,
+      -798452.29283, -405290.60884, 247059.51049, -253747.03819, 1310939.40814
+    )
+    val atb = new DoubleMatrix(n, 1,
+      -31755.05710, 13047.14813, -20191.24443, 25993.77580, 11963.55017)
+
+    // Reference minimizer from MATLAB's quadprog (rounded to 5 decimal places).
+    val refx = new DoubleMatrix(Array(34.90751, 103.96254, 0.00000, 27.82094, 58.79627))
+    val refObj = computeObjectiveValue(ata, atb, refx)
+
+    val ws = NNLS.createWorkspace(n)
+    val x = new DoubleMatrix(NNLS.solve(ata, atb, ws))
+    val obj = computeObjectiveValue(ata, atb, x)
+
+    // NNLS must reach an objective no worse than the reference, up to a small tolerance.
+    assert(obj < refObj + 1E-5, s"NNLS objective $obj exceeds reference objective $refObj")
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org