You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by rx...@apache.org on 2014/05/14 02:31:33 UTC

git commit: SPARK-1791 - SVM implementation does not use threshold parameter

Repository: spark
Updated Branches:
  refs/heads/master 16ffadcc4 -> d1e487473


SPARK-1791 - SVM implementation does not use threshold parameter

Summary:
https://issues.apache.org/jira/browse/SPARK-1791

This is a simple, backward-compatible fix, since:

- anyone who set the threshold was getting completely wrong answers.
- anyone who did not set the threshold had the default 0.0 value for the threshold anyway.

Test Plan:
Unit test added that is verified to fail under the old implementation,
and pass under the new implementation.

Reviewers:

CC:

Author: Andrew Tulloch <an...@tullo.ch>

Closes #725 from ajtulloch/SPARK-1791-SVM and squashes the following commits:

770f55d [Andrew Tulloch] SPARK-1791 - SVM implementation does not use threshold parameter


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d1e48747
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d1e48747
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d1e48747

Branch: refs/heads/master
Commit: d1e487473fd509f28daf28dcda856f3c2f1194ec
Parents: 16ffadc
Author: Andrew Tulloch <an...@tullo.ch>
Authored: Tue May 13 17:31:27 2014 -0700
Committer: Reynold Xin <rx...@apache.org>
Committed: Tue May 13 17:31:27 2014 -0700

----------------------------------------------------------------------
 .../apache/spark/mllib/classification/SVM.scala |  2 +-
 .../spark/mllib/classification/SVMSuite.scala   | 37 ++++++++++++++++++++
 2 files changed, 38 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/d1e48747/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
index e052135..316ecd7 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
@@ -65,7 +65,7 @@ class SVMModel private[mllib] (
       intercept: Double) = {
     val margin = weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept
     threshold match {
-      case Some(t) => if (margin < 0) 0.0 else 1.0
+      case Some(t) => if (margin < t) 0.0 else 1.0
       case None => margin
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/d1e48747/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala
index 77d6f04..886c71d 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala
@@ -69,6 +69,43 @@ class SVMSuite extends FunSuite with LocalSparkContext {
     assert(numOffPredictions < input.length / 5)
   }
 
+  test("SVM with threshold") {
+    val nPoints = 10000
+
+    // NOTE: Intercept should be small for generating equal 0s and 1s
+    val A = 0.01
+    val B = -1.5
+    val C = 1.0
+
+    val testData = SVMSuite.generateSVMInput(A, Array[Double](B, C), nPoints, 42)
+
+    val testRDD = sc.parallelize(testData, 2)
+    testRDD.cache()
+
+    val svm = new SVMWithSGD().setIntercept(true)
+    svm.optimizer.setStepSize(1.0).setRegParam(1.0).setNumIterations(100)
+
+    val model = svm.run(testRDD)
+
+    val validationData = SVMSuite.generateSVMInput(A, Array[Double](B, C), nPoints, 17)
+    val validationRDD  = sc.parallelize(validationData, 2)
+
+    // Test prediction on RDD.
+
+    var predictions = model.predict(validationRDD.map(_.features)).collect()
+    assert(predictions.count(_ == 0.0) != predictions.length)
+
+    // High threshold makes all the predictions 0.0
+    model.setThreshold(10000.0)
+    predictions = model.predict(validationRDD.map(_.features)).collect()
+    assert(predictions.count(_ == 0.0) == predictions.length)
+
+    // Low threshold makes all the predictions 1.0
+    model.setThreshold(-10000.0)
+    predictions = model.predict(validationRDD.map(_.features)).collect()
+    assert(predictions.count(_ == 1.0) == predictions.length)
+  }
+
   test("SVM using local random SGD") {
     val nPoints = 10000