You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by yl...@apache.org on 2016/08/15 13:38:43 UTC

spark git commit: [SPARK-16934][ML][MLLIB] Update LogisticCostAggregator serialization code to make it consistent with LinearRegression

Repository: spark
Updated Branches:
  refs/heads/master ddf0d1e3f -> 3d8bfe7a3


[SPARK-16934][ML][MLLIB] Update LogisticCostAggregator serialization code to make it consistent with LinearRegression

## What changes were proposed in this pull request?

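Update the LogisticCostAggregator serialization code (i.e. the `LogisticAggregator` and `LogisticCostFun` classes) so that the coefficients and the feature standard deviations are passed to the aggregator as broadcast variables, making it consistent with the LinearRegression change in #14109.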

## How was this patch tested?
MLlib 2.0:
![image](https://cloud.githubusercontent.com/assets/19235986/17649601/5e2a79ac-61ee-11e6-833c-3bd8b5250470.png)

After this PR:
![image](https://cloud.githubusercontent.com/assets/19235986/17649599/52b002ae-61ee-11e6-9402-9feb3439880f.png)

Author: WeichenXu <We...@outlook.com>

Closes #14520 from WeichenXu123/improve_logistic_regression_costfun.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/3d8bfe7a
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/3d8bfe7a
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/3d8bfe7a

Branch: refs/heads/master
Commit: 3d8bfe7a39015c84cf95561fe17eb2808ce44084
Parents: ddf0d1e
Author: WeichenXu <We...@outlook.com>
Authored: Mon Aug 15 06:38:30 2016 -0700
Committer: Yanbo Liang <yb...@gmail.com>
Committed: Mon Aug 15 06:38:30 2016 -0700

----------------------------------------------------------------------
 .../ml/classification/LogisticRegression.scala  | 36 +++++++++++---------
 1 file changed, 20 insertions(+), 16 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/3d8bfe7a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 88d1b45..fce3935 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -25,6 +25,7 @@ import org.apache.hadoop.fs.Path
 
 import org.apache.spark.SparkException
 import org.apache.spark.annotation.{Experimental, Since}
+import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.internal.Logging
 import org.apache.spark.ml.feature.Instance
 import org.apache.spark.ml.linalg._
@@ -346,8 +347,9 @@ class LogisticRegression @Since("1.2.0") (
         val regParamL1 = $(elasticNetParam) * $(regParam)
         val regParamL2 = (1.0 - $(elasticNetParam)) * $(regParam)
 
+        val bcFeaturesStd = instances.context.broadcast(featuresStd)
         val costFun = new LogisticCostFun(instances, numClasses, $(fitIntercept),
-          $(standardization), featuresStd, featuresMean, regParamL2)
+          $(standardization), bcFeaturesStd, regParamL2)
 
         val optimizer = if ($(elasticNetParam) == 0.0 || $(regParam) == 0.0) {
           new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol))
@@ -442,6 +444,7 @@ class LogisticRegression @Since("1.2.0") (
           rawCoefficients(i) *= { if (featuresStd(i) != 0.0) 1.0 / featuresStd(i) else 0.0 }
           i += 1
         }
+        bcFeaturesStd.destroy(blocking = false)
 
         if ($(fitIntercept)) {
           (Vectors.dense(rawCoefficients.dropRight(1)).compressed, rawCoefficients.last,
@@ -938,11 +941,15 @@ class BinaryLogisticRegressionSummary private[classification] (
  * Two LogisticAggregator can be merged together to have a summary of loss and gradient of
  * the corresponding joint dataset.
  *
+ * @param bcCoefficients The broadcast coefficients corresponding to the features.
+ * @param bcFeaturesStd The broadcast standard deviation values of the features.
  * @param numClasses the number of possible outcomes for k classes classification problem in
  *                   Multinomial Logistic Regression.
  * @param fitIntercept Whether to fit an intercept term.
  */
 private class LogisticAggregator(
+    val bcCoefficients: Broadcast[Vector],
+    val bcFeaturesStd: Broadcast[Array[Double]],
     private val numFeatures: Int,
     numClasses: Int,
     fitIntercept: Boolean) extends Serializable {
@@ -958,14 +965,9 @@ private class LogisticAggregator(
    * of the objective function.
    *
    * @param instance The instance of data point to be added.
-   * @param coefficients The coefficients corresponding to the features.
-   * @param featuresStd The standard deviation values of the features.
    * @return This LogisticAggregator object.
    */
-  def add(
-      instance: Instance,
-      coefficients: Vector,
-      featuresStd: Array[Double]): this.type = {
+  def add(instance: Instance): this.type = {
     instance match { case Instance(label, weight, features) =>
       require(numFeatures == features.size, s"Dimensions mismatch when adding new instance." +
         s" Expecting $numFeatures but got ${features.size}.")
@@ -973,14 +975,16 @@ private class LogisticAggregator(
 
       if (weight == 0.0) return this
 
-      val coefficientsArray = coefficients match {
+      val coefficientsArray = bcCoefficients.value match {
         case dv: DenseVector => dv.values
         case _ =>
           throw new IllegalArgumentException(
-            s"coefficients only supports dense vector but got type ${coefficients.getClass}.")
+            "coefficients only supports dense vector" +
+              s"but got type ${bcCoefficients.value.getClass}.")
       }
       val localGradientSumArray = gradientSumArray
 
+      val featuresStd = bcFeaturesStd.value
       numClasses match {
         case 2 =>
           // For Binary Logistic Regression.
@@ -1077,24 +1081,23 @@ private class LogisticCostFun(
     numClasses: Int,
     fitIntercept: Boolean,
     standardization: Boolean,
-    featuresStd: Array[Double],
-    featuresMean: Array[Double],
+    bcFeaturesStd: Broadcast[Array[Double]],
     regParamL2: Double) extends DiffFunction[BDV[Double]] {
 
+  val featuresStd = bcFeaturesStd.value
+
   override def calculate(coefficients: BDV[Double]): (Double, BDV[Double]) = {
     val numFeatures = featuresStd.length
     val coeffs = Vectors.fromBreeze(coefficients)
+    val bcCoeffs = instances.context.broadcast(coeffs)
     val n = coeffs.size
-    val localFeaturesStd = featuresStd
-
 
     val logisticAggregator = {
-      val seqOp = (c: LogisticAggregator, instance: Instance) =>
-        c.add(instance, coeffs, localFeaturesStd)
+      val seqOp = (c: LogisticAggregator, instance: Instance) => c.add(instance)
       val combOp = (c1: LogisticAggregator, c2: LogisticAggregator) => c1.merge(c2)
 
       instances.treeAggregate(
-        new LogisticAggregator(numFeatures, numClasses, fitIntercept)
+        new LogisticAggregator(bcCoeffs, bcFeaturesStd, numFeatures, numClasses, fitIntercept)
       )(seqOp, combOp)
     }
 
@@ -1134,6 +1137,7 @@ private class LogisticCostFun(
       }
       0.5 * regParamL2 * sum
     }
+    bcCoeffs.destroy(blocking = false)
 
     (logisticAggregator.loss + regVal, new BDV(totalGradientArray))
   }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org