You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by db...@apache.org on 2015/06/23 21:42:21 UTC

spark git commit: [SPARK-7888] Be able to disable intercept in linear regression in ml package

Repository: spark
Updated Branches:
  refs/heads/master 6f4cadf5e -> 2b1111dd0


[SPARK-7888] Be able to disable intercept in linear regression in ml package

Author: Holden Karau <ho...@pigscanfly.ca>

Closes #6927 from holdenk/SPARK-7888-Be-able-to-disable-intercept-in-Linear-Regression-in-ML-package and squashes the following commits:

0ad384c [Holden Karau] Add MiMa excludes
4016fac [Holden Karau] Switch to wild card import, remove extra blank lines
ae5baa8 [Holden Karau] CR feedback, move the fitIntercept down rather than changing ymean and etc above
f34971c [Holden Karau] Fix some more long lines
319bd3f [Holden Karau] Fix long lines
3bb9ee1 [Holden Karau] Update the regression suite tests
7015b9f [Holden Karau] Our code performs the same with R, except we need more than one data point but that seems reasonable
0b0c8c0 [Holden Karau] fix the issue with the sample R code
e2140ba [Holden Karau] Add a test, it fails!
5e84a0b [Holden Karau] Write out thoughts and use the correct trait
91ffc0a [Holden Karau] more murh
006246c [Holden Karau] murp?


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2b1111dd
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2b1111dd
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2b1111dd

Branch: refs/heads/master
Commit: 2b1111dd0b8deb9ad8d43fec792e60e3d0c4de75
Parents: 6f4cadf
Author: Holden Karau <ho...@pigscanfly.ca>
Authored: Tue Jun 23 12:42:17 2015 -0700
Committer: DB Tsai <db...@netflix.com>
Committed: Tue Jun 23 12:42:17 2015 -0700

----------------------------------------------------------------------
 .../spark/ml/regression/LinearRegression.scala  |  30 +++-
 .../ml/regression/LinearRegressionSuite.scala   | 149 ++++++++++++++++++-
 project/MimaExcludes.scala                      |   5 +
 3 files changed, 172 insertions(+), 12 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/2b1111dd/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index 0130654..1b1d729 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -26,7 +26,7 @@ import org.apache.spark.Logging
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.ml.PredictorParams
 import org.apache.spark.ml.param.ParamMap
-import org.apache.spark.ml.param.shared.{HasElasticNetParam, HasMaxIter, HasRegParam, HasTol}
+import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util.Identifiable
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.linalg.BLAS._
@@ -41,7 +41,8 @@ import org.apache.spark.util.StatCounter
  * Params for linear regression.
  */
 private[regression] trait LinearRegressionParams extends PredictorParams
-  with HasRegParam with HasElasticNetParam with HasMaxIter with HasTol
+    with HasRegParam with HasElasticNetParam with HasMaxIter with HasTol
+    with HasFitIntercept
 
 /**
  * :: Experimental ::
@@ -73,6 +74,14 @@ class LinearRegression(override val uid: String)
   setDefault(regParam -> 0.0)
 
   /**
+   * Set whether to fit the intercept.
+   * Default is true.
+   * @group setParam
+   */
+  def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value)
+  setDefault(fitIntercept -> true)
+
+  /**
    * Set the ElasticNet mixing parameter.
    * For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.
    * For 0 < alpha < 1, the penalty is a combination of L1 and L2.
@@ -123,6 +132,7 @@ class LinearRegression(override val uid: String)
     val numFeatures = summarizer.mean.size
     val yMean = statCounter.mean
     val yStd = math.sqrt(statCounter.variance)
+    // TODO: look at glmnet5.m L761; it may have relevant information.
 
     // If the yStd is zero, then the intercept is yMean with zero weights;
     // as a result, training is not needed.
@@ -142,7 +152,7 @@ class LinearRegression(override val uid: String)
     val effectiveL1RegParam = $(elasticNetParam) * effectiveRegParam
     val effectiveL2RegParam = (1.0 - $(elasticNetParam)) * effectiveRegParam
 
-    val costFun = new LeastSquaresCostFun(instances, yStd, yMean,
+    val costFun = new LeastSquaresCostFun(instances, yStd, yMean, $(fitIntercept),
       featuresStd, featuresMean, effectiveL2RegParam)
 
     val optimizer = if ($(elasticNetParam) == 0.0 || effectiveRegParam == 0.0) {
@@ -180,7 +190,7 @@ class LinearRegression(override val uid: String)
     // The intercept in R's GLMNET is computed using closed form after the coefficients are
     // converged. See the following discussion for detail.
     // http://stats.stackexchange.com/questions/13617/how-is-the-intercept-computed-in-glmnet
-    val intercept = yMean - dot(weights, Vectors.dense(featuresMean))
+    val intercept = if ($(fitIntercept)) yMean - dot(weights, Vectors.dense(featuresMean)) else 0.0
     if (handlePersistence) instances.unpersist()
 
     // TODO: Converts to sparse format based on the storage, but may base on the scoring speed.
@@ -234,6 +244,7 @@ class LinearRegressionModel private[ml] (
  * See this discussion for detail.
  * http://stats.stackexchange.com/questions/13617/how-is-the-intercept-computed-in-glmnet
  *
+ * When training with intercept enabled,
  * The objective function in the scaled space is given by
  * {{{
  * L = 1/2n ||\sum_i w_i(x_i - \bar{x_i}) / \hat{x_i} - (y - \bar{y}) / \hat{y}||^2,
@@ -241,6 +252,10 @@ class LinearRegressionModel private[ml] (
  * where \bar{x_i} is the mean of x_i, \hat{x_i} is the standard deviation of x_i,
  * \bar{y} is the mean of label, and \hat{y} is the standard deviation of label.
  *
+ * If fitting the intercept is disabled (that is, the intercept is forced to 0.0),
+ * we can use the same equation except we set \bar{y} and \bar{x_i} to 0 instead
+ * of the respective means.
+ *
  * This can be rewritten as
  * {{{
  * L = 1/2n ||\sum_i (w_i/\hat{x_i})x_i - \sum_i (w_i/\hat{x_i})\bar{x_i} - y / \hat{y}
@@ -255,6 +270,7 @@ class LinearRegressionModel private[ml] (
  * \sum_i w_i^\prime x_i - y / \hat{y} + offset
  * }}}
  *
+ *
  * Note that the effective weights and offset don't depend on training dataset,
  * so they can be precomputed.
  *
@@ -301,6 +317,7 @@ private class LeastSquaresAggregator(
     weights: Vector,
     labelStd: Double,
     labelMean: Double,
+    fitIntercept: Boolean,
     featuresStd: Array[Double],
     featuresMean: Array[Double]) extends Serializable {
 
@@ -321,7 +338,7 @@ private class LeastSquaresAggregator(
       }
       i += 1
     }
-    (weightsArray, -sum + labelMean / labelStd, weightsArray.length)
+    (weightsArray, if (fitIntercept) labelMean / labelStd - sum else 0.0, weightsArray.length)
   }
 
   private val effectiveWeightsVector = Vectors.dense(effectiveWeightsArray)
@@ -404,6 +421,7 @@ private class LeastSquaresCostFun(
     data: RDD[(Double, Vector)],
     labelStd: Double,
     labelMean: Double,
+    fitIntercept: Boolean,
     featuresStd: Array[Double],
     featuresMean: Array[Double],
     effectiveL2regParam: Double) extends DiffFunction[BDV[Double]] {
@@ -412,7 +430,7 @@ private class LeastSquaresCostFun(
     val w = Vectors.fromBreeze(weights)
 
     val leastSquaresAggregator = data.treeAggregate(new LeastSquaresAggregator(w, labelStd,
-      labelMean, featuresStd, featuresMean))(
+      labelMean, fitIntercept, featuresStd, featuresMean))(
         seqOp = (c, v) => (c, v) match {
           case (aggregator, (label, features)) => aggregator.add(label, features)
         },

http://git-wip-us.apache.org/repos/asf/spark/blob/2b1111dd/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
index 732e2c4..ad1e9da 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
@@ -26,6 +26,7 @@ import org.apache.spark.sql.{DataFrame, Row}
 class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
 
   @transient var dataset: DataFrame = _
+  @transient var datasetWithoutIntercept: DataFrame = _
 
   /**
    * In `LinearRegressionSuite`, we will make sure that the model trained by SparkML
@@ -34,14 +35,24 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
    *
    * import org.apache.spark.mllib.util.LinearDataGenerator
    * val data =
-   *   sc.parallelize(LinearDataGenerator.generateLinearInput(6.3, Array(4.7, 7.2), 10000, 42), 2)
-   * data.map(x=> x.label + ", " + x.features(0) + ", " + x.features(1)).saveAsTextFile("path")
+   *   sc.parallelize(LinearDataGenerator.generateLinearInput(6.3, Array(4.7, 7.2),
+   *     Array(0.9, -1.3), Array(0.7, 1.2), 10000, 42, 0.1), 2)
+   * data.map(x=> x.label + ", " + x.features(0) + ", " + x.features(1)).coalesce(1)
+   *   .saveAsTextFile("path")
    */
   override def beforeAll(): Unit = {
     super.beforeAll()
     dataset = sqlContext.createDataFrame(
       sc.parallelize(LinearDataGenerator.generateLinearInput(
         6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 10000, 42, 0.1), 2))
+    /**
+     * datasetWithoutIntercept is not needed for correctness testing but is useful for illustrating
+     * training a model without an intercept.
+     */
+    datasetWithoutIntercept = sqlContext.createDataFrame(
+      sc.parallelize(LinearDataGenerator.generateLinearInput(
+        0.0, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 10000, 42, 0.1), 2))
+
   }
 
   test("linear regression with intercept without regularization") {
@@ -78,6 +89,42 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
     }
   }
 
+  test("linear regression without intercept without regularization") {
+    val trainer = (new LinearRegression).setFitIntercept(false)
+    val model = trainer.fit(dataset)
+    val modelWithoutIntercept = trainer.fit(datasetWithoutIntercept)
+
+    /**
+     * weights <- coef(glmnet(features, label, family="gaussian", alpha = 0, lambda = 0,
+     *   intercept = FALSE))
+     * > weights
+     *  3 x 1 sparse Matrix of class "dgCMatrix"
+     *                           s0
+     * (Intercept)         .
+     * as.numeric.data.V2. 6.995908
+     * as.numeric.data.V3. 5.275131
+     */
+    val weightsR = Array(6.995908, 5.275131)
+
+    assert(model.intercept ~== 0 relTol 1E-3)
+    assert(model.weights(0) ~== weightsR(0) relTol 1E-3)
+    assert(model.weights(1) ~== weightsR(1) relTol 1E-3)
+    /**
+     * Then again with the data with no intercept:
+     * > weightsWithoutIntercept
+     * 3 x 1 sparse Matrix of class "dgCMatrix"
+     *                             s0
+     * (Intercept)           .
+     * as.numeric.data3.V2. 4.70011
+     * as.numeric.data3.V3. 7.19943
+     */
+    val weightsWithoutInterceptR = Array(4.70011, 7.19943)
+
+    assert(modelWithoutIntercept.intercept ~== 0 relTol 1E-3)
+    assert(modelWithoutIntercept.weights(0) ~== weightsWithoutInterceptR(0) relTol 1E-3)
+    assert(modelWithoutIntercept.weights(1) ~== weightsWithoutInterceptR(1) relTol 1E-3)
+  }
+
   test("linear regression with intercept with L1 regularization") {
     val trainer = (new LinearRegression).setElasticNetParam(1.0).setRegParam(0.57)
     val model = trainer.fit(dataset)
@@ -87,11 +134,11 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
      * > weights
      *  3 x 1 sparse Matrix of class "dgCMatrix"
      *                           s0
-     * (Intercept)         6.311546
-     * as.numeric.data.V2. 2.123522
-     * as.numeric.data.V3. 4.605651
+     * (Intercept)         6.24300
+     * as.numeric.data.V2. 4.024821
+     * as.numeric.data.V3. 6.679841
      */
-    val interceptR = 6.243000
+    val interceptR = 6.24300
     val weightsR = Array(4.024821, 6.679841)
 
     assert(model.intercept ~== interceptR relTol 1E-3)
@@ -106,6 +153,36 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
     }
   }
 
+  test("linear regression without intercept with L1 regularization") {
+    val trainer = (new LinearRegression).setElasticNetParam(1.0).setRegParam(0.57)
+      .setFitIntercept(false)
+    val model = trainer.fit(dataset)
+
+    /**
+     * weights <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, lambda = 0.57,
+     *   intercept=FALSE))
+     * > weights
+     *  3 x 1 sparse Matrix of class "dgCMatrix"
+     *                           s0
+     * (Intercept)          .
+     * as.numeric.data.V2. 6.299752
+     * as.numeric.data.V3. 4.772913
+     */
+    val interceptR = 0.0
+    val weightsR = Array(6.299752, 4.772913)
+
+    assert(model.intercept ~== interceptR relTol 1E-3)
+    assert(model.weights(0) ~== weightsR(0) relTol 1E-3)
+    assert(model.weights(1) ~== weightsR(1) relTol 1E-3)
+
+    model.transform(dataset).select("features", "prediction").collect().foreach {
+      case Row(features: DenseVector, prediction1: Double) =>
+        val prediction2 =
+          features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept
+        assert(prediction1 ~== prediction2 relTol 1E-5)
+    }
+  }
+
   test("linear regression with intercept with L2 regularization") {
     val trainer = (new LinearRegression).setElasticNetParam(0.0).setRegParam(2.3)
     val model = trainer.fit(dataset)
@@ -134,6 +211,36 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
     }
   }
 
+  test("linear regression without intercept with L2 regularization") {
+    val trainer = (new LinearRegression).setElasticNetParam(0.0).setRegParam(2.3)
+      .setFitIntercept(false)
+    val model = trainer.fit(dataset)
+
+    /**
+     * weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3,
+     *   intercept = FALSE))
+     * > weights
+     *  3 x 1 sparse Matrix of class "dgCMatrix"
+     *                           s0
+     * (Intercept)         .
+     * as.numeric.data.V2. 5.522875
+     * as.numeric.data.V3. 4.214502
+     */
+    val interceptR = 0.0
+    val weightsR = Array(5.522875, 4.214502)
+
+    assert(model.intercept ~== interceptR relTol 1E-3)
+    assert(model.weights(0) ~== weightsR(0) relTol 1E-3)
+    assert(model.weights(1) ~== weightsR(1) relTol 1E-3)
+
+    model.transform(dataset).select("features", "prediction").collect().foreach {
+      case Row(features: DenseVector, prediction1: Double) =>
+        val prediction2 =
+          features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept
+        assert(prediction1 ~== prediction2 relTol 1E-5)
+    }
+  }
+
   test("linear regression with intercept with ElasticNet regularization") {
     val trainer = (new LinearRegression).setElasticNetParam(0.3).setRegParam(1.6)
     val model = trainer.fit(dataset)
@@ -161,4 +268,34 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
         assert(prediction1 ~== prediction2 relTol 1E-5)
     }
   }
+
+  test("linear regression without intercept with ElasticNet regularization") {
+    val trainer = (new LinearRegression).setElasticNetParam(0.3).setRegParam(1.6)
+      .setFitIntercept(false)
+    val model = trainer.fit(dataset)
+
+    /**
+     * weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6,
+     *   intercept=FALSE))
+     * > weights
+     * 3 x 1 sparse Matrix of class "dgCMatrix"
+     * s0
+     * (Intercept)         .
+     * as.numeric.dataM.V2. 5.673348
+     * as.numeric.dataM.V3. 4.322251
+     */
+    val interceptR = 0.0
+    val weightsR = Array(5.673348, 4.322251)
+
+    assert(model.intercept ~== interceptR relTol 1E-3)
+    assert(model.weights(0) ~== weightsR(0) relTol 1E-3)
+    assert(model.weights(1) ~== weightsR(1) relTol 1E-3)
+
+    model.transform(dataset).select("features", "prediction").collect().foreach {
+      case Row(features: DenseVector, prediction1: Double) =>
+        val prediction2 =
+          features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept
+        assert(prediction1 ~== prediction2 relTol 1E-5)
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/2b1111dd/project/MimaExcludes.scala
----------------------------------------------------------------------
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 7a748fb..f678c69 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -53,6 +53,11 @@ object MimaExcludes {
             // Removing a testing method from a private class
             ProblemFilters.exclude[MissingMethodProblem](
               "org.apache.spark.streaming.kafka.KafkaTestUtils.waitUntilLeaderOffset"),
+            // These classes are private, but MiMa still reports their constructor changes.
+            ProblemFilters.exclude[MissingMethodProblem](
+              "org.apache.spark.ml.regression.LeastSquaresAggregator.this"),
+            ProblemFilters.exclude[MissingMethodProblem](
+              "org.apache.spark.ml.regression.LeastSquaresCostFun.this"),
             // SQL execution is considered private.
             excludePackage("org.apache.spark.sql.execution"),
             // NanoTime and CatalystTimestampConverter is only used inside catalyst,


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org