You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by me...@apache.org on 2015/04/02 02:19:42 UTC

spark git commit: [SPARK-6580] [MLLIB] Optimize LogisticRegressionModel.predictPoint

Repository: spark
Updated Branches:
  refs/heads/master 2fa3b47db -> 86b439935


[SPARK-6580] [MLLIB] Optimize LogisticRegressionModel.predictPoint

https://issues.apache.org/jira/browse/SPARK-6580

Author: Yanbo Liang <yb...@gmail.com>

Closes #5249 from yanboliang/spark-6580 and squashes the following commits:

6f47f21 [Yanbo Liang] address comments
4e0bd0f [Yanbo Liang] fix typos
04e2e2a [Yanbo Liang] trigger jenkins
cad5bcd [Yanbo Liang] Optimize LogisticRegressionModel.predictPoint


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/86b43993
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/86b43993
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/86b43993

Branch: refs/heads/master
Commit: 86b43993517104e6d5ad0785704ceec6db8acc20
Parents: 2fa3b47
Author: Yanbo Liang <yb...@gmail.com>
Authored: Wed Apr 1 17:19:36 2015 -0700
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Wed Apr 1 17:19:36 2015 -0700

----------------------------------------------------------------------
 .../classification/LogisticRegression.scala     | 55 +++++++++-----------
 1 file changed, 26 insertions(+), 29 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/86b43993/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
index e7c3599..057b628 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
@@ -62,6 +62,15 @@ class LogisticRegressionModel (
       s" but was given weights of length ${weights.size}")
   }
 
+  private val dataWithBiasSize: Int = weights.size / (numClasses - 1)
+
+  private val weightsArray: Array[Double] = weights match {
+    case dv: DenseVector => dv.values
+    case _ =>
+      throw new IllegalArgumentException(
+        s"weights only supports dense vector but got type ${weights.getClass}.")
+  }
+
   /**
    * Constructs a [[LogisticRegressionModel]] with weights and intercept for binary classification.
    */
@@ -74,6 +83,7 @@ class LogisticRegressionModel (
    * Sets the threshold that separates positive predictions from negative predictions
    * in Binary Logistic Regression. An example with prediction score greater than or equal to
    * this threshold is identified as a positive, and negative otherwise. The default value is 0.5.
+   * It is only used for binary classification.
    */
   @Experimental
   def setThreshold(threshold: Double): this.type = {
@@ -84,6 +94,7 @@ class LogisticRegressionModel (
   /**
    * :: Experimental ::
    * Returns the threshold (if any) used for converting raw prediction scores into 0/1 predictions.
+   * It is only used for binary classification.
    */
   @Experimental
   def getThreshold: Option[Double] = threshold
@@ -91,6 +102,7 @@ class LogisticRegressionModel (
   /**
    * :: Experimental ::
    * Clears the threshold so that `predict` will output raw prediction scores.
+   * It is only used for binary classification.
    */
   @Experimental
   def clearThreshold(): this.type = {
@@ -106,7 +118,6 @@ class LogisticRegressionModel (
 
     // If dataMatrix and weightMatrix have the same dimension, it's binary logistic regression.
     if (numClasses == 2) {
-      require(numFeatures == weightMatrix.size)
       val margin = dot(weightMatrix, dataMatrix) + intercept
       val score = 1.0 / (1.0 + math.exp(-margin))
       threshold match {
@@ -114,30 +125,9 @@ class LogisticRegressionModel (
         case None => score
       }
     } else {
-      val dataWithBiasSize = weightMatrix.size / (numClasses - 1)
-
-      val weightsArray = weightMatrix match {
-        case dv: DenseVector => dv.values
-        case _ =>
-          throw new IllegalArgumentException(
-            s"weights only supports dense vector but got type ${weightMatrix.getClass}.")
-      }
-
-      val margins = (0 until numClasses - 1).map { i =>
-        var margin = 0.0
-        dataMatrix.foreachActive { (index, value) =>
-          if (value != 0.0) margin += value * weightsArray((i * dataWithBiasSize) + index)
-        }
-        // Intercept is required to be added into margin.
-        if (dataMatrix.size + 1 == dataWithBiasSize) {
-          margin += weightsArray((i * dataWithBiasSize) + dataMatrix.size)
-        }
-        margin
-      }
-
       /**
-       * Find the one with maximum margins. If the maxMargin is negative, then the prediction
-       * result will be the first class.
+       * Compute and find the one with maximum margins. If the maxMargin is negative, then the
+       * prediction result will be the first class.
        *
        * PS, if you want to compute the probabilities for each outcome instead of the outcome
        * with maximum probability, remember to subtract the maxMargin from margins if maxMargin
@@ -145,13 +135,20 @@ class LogisticRegressionModel (
        */
       var bestClass = 0
       var maxMargin = 0.0
-      var i = 0
-      while(i < margins.size) {
-        if (margins(i) > maxMargin) {
-          maxMargin = margins(i)
+      val withBias = dataMatrix.size + 1 == dataWithBiasSize
+      (0 until numClasses - 1).foreach { i =>
+        var margin = 0.0
+        dataMatrix.foreachActive { (index, value) =>
+          if (value != 0.0) margin += value * weightsArray((i * dataWithBiasSize) + index)
+        }
+        // Intercept is required to be added into margin.
+        if (withBias) {
+          margin += weightsArray((i * dataWithBiasSize) + dataMatrix.size)
+        }
+        if (margin > maxMargin) {
+          maxMargin = margin
           bestClass = i + 1
         }
-        i += 1
       }
       bestClass.toDouble
     }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org