You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by me...@apache.org on 2015/04/29 23:53:41 UTC

spark git commit: [SPARK-7222] [ML] Added mathematical derivation in comment and compressed the model, removed the correction terms in LinearRegression with ElasticNet

Repository: spark
Updated Branches:
  refs/heads/master 3a180c19a -> 15995c883


[SPARK-7222] [ML] Added mathematical derivation in comment and compressed the model, removed the correction terms in LinearRegression with ElasticNet

Added detailed mathematical derivation of how scaling and LeastSquaresAggregator work. Refactored the code so the model is compressed based on the storage. We may try compression based on the prediction time.

Also, I found that diffSum will always be zero mathematically, so no corrections are required.

Author: DB Tsai <db...@netflix.com>

Closes #5767 from dbtsai/lir-doc and squashes the following commits:

5e346c9 [DB Tsai] refactoring
fc9f582 [DB Tsai] doc
58456d8 [DB Tsai] address feedback
69757b8 [DB Tsai] actually diffSum is mathematically zero! No correction is needed.
5929e49 [DB Tsai] typo
63f7d1e [DB Tsai] Added compression to the model based on storage
203a295 [DB Tsai] Add more documentation to LinearRegression in new ML framework.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/15995c88
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/15995c88
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/15995c88

Branch: refs/heads/master
Commit: 15995c883aa248235fdebf0cbeeaa3ef12c97e9c
Parents: 3a180c1
Author: DB Tsai <db...@netflix.com>
Authored: Wed Apr 29 14:53:37 2015 -0700
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Wed Apr 29 14:53:37 2015 -0700

----------------------------------------------------------------------
 .../spark/ml/regression/LinearRegression.scala  | 115 +++++++++++++++----
 .../regression/GeneralizedLinearAlgorithm.scala |   2 +-
 .../spark/mllib/util/LinearDataGenerator.scala  |  21 ++--
 3 files changed, 101 insertions(+), 37 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/15995c88/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index f92c681..cc9ad22 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -34,6 +34,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.StatCounter
+import org.apache.spark.Logging
 
 /**
  * Params for linear regression.
@@ -48,7 +49,7 @@ private[regression] trait LinearRegressionParams extends RegressorParams
  */
 @AlphaComponent
 class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegressionModel]
-  with LinearRegressionParams {
+  with LinearRegressionParams with Logging {
 
   /**
    * Set the regularization parameter.
@@ -110,6 +111,15 @@ class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegress
     val yMean = statCounter.mean
     val yStd = math.sqrt(statCounter.variance)
 
+    // If the yStd is zero, then the intercept is yMean with zero weights;
+    // as a result, training is not needed.
+    if (yStd == 0.0) {
+      logWarning(s"The standard deviation of the label is zero, so the weights will be zeros " +
+        s"and the intercept will be the mean of the label; as a result, training is not needed.")
+      if (handlePersistence) instances.unpersist()
+      return new LinearRegressionModel(this, paramMap, Vectors.sparse(numFeatures, Seq()), yMean)
+    }
+
     val featuresMean = summarizer.mean.toArray
     val featuresStd = summarizer.variance.toArray.map(math.sqrt)
 
@@ -141,7 +151,6 @@ class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegress
     }
     lossHistory += state.value
 
-    // TODO: Based on the sparsity of weights, we may convert the weights to the sparse vector.
     // The weights are trained in the scaled space; we're converting them back to
     // the original space.
     val weights = {
@@ -158,11 +167,10 @@ class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegress
     // converged. See the following discussion for detail.
     // http://stats.stackexchange.com/questions/13617/how-is-the-intercept-computed-in-glmnet
     val intercept = yMean - dot(weights, Vectors.dense(featuresMean))
+    if (handlePersistence) instances.unpersist()
 
-    if (handlePersistence) {
-      instances.unpersist()
-    }
-    new LinearRegressionModel(this, paramMap, weights, intercept)
+    // TODO: The format is chosen based on storage size; we may choose it based on scoring speed.
+    new LinearRegressionModel(this, paramMap, weights.compressed, intercept)
   }
 }
 
@@ -198,15 +206,84 @@ class LinearRegressionModel private[ml] (
  * Two LeastSquaresAggregator can be merged together to have a summary of loss and gradient of
  * the corresponding joint dataset.
  *
-
- *  * Compute gradient and loss for a Least-squared loss function, as used in linear regression.
- * This is correct for the averaged least squares loss function (mean squared error)
- *              L = 1/2n ||A weights-y||^2
- * See also the documentation for the precise formulation.
+ * To improve the convergence rate during the optimization process, and to prevent features
+ * with very large variances from exerting an overly large influence during model training,
+ * packages like R's GLMNET scale the features to unit variance and remove the mean to reduce
+ * the condition number, and then train the model in the scaled space but return the weights in
+ * the original scale. See page 9 in http://cran.r-project.org/web/packages/glmnet/glmnet.pdf
+ *
+ * However, we don't want to apply the `StandardScaler` on the training dataset, and then cache
+ * the standardized dataset since it will create a lot of overhead. As a result, we perform the
+ * scaling implicitly when we compute the objective function. The following is the mathematical
+ * derivation.
+ *
+ * Note that we don't deal with intercept by adding bias here, because the intercept
+ * can be computed using closed form after the coefficients are converged.
+ * See this discussion for detail.
+ * http://stats.stackexchange.com/questions/13617/how-is-the-intercept-computed-in-glmnet
+ *
+ * The objective function in the scaled space is given by
+ * {{{
+ * L = 1/2n ||\sum_i w_i(x_i - \bar{x_i}) / \hat{x_i} - (y - \bar{y}) / \hat{y}||^2,
+ * }}}
+ * where \bar{x_i} is the mean of x_i, \hat{x_i} is the standard deviation of x_i,
+ * \bar{y} is the mean of label, and \hat{y} is the standard deviation of label.
+ *
+ * This can be rewritten as
+ * {{{
+ * L = 1/2n ||\sum_i (w_i/\hat{x_i})x_i - \sum_i (w_i/\hat{x_i})\bar{x_i} - y / \hat{y}
+ *     + \bar{y} / \hat{y}||^2
+ *   = 1/2n ||\sum_i w_i^\prime x_i - y / \hat{y} + offset||^2 = 1/2n diff^2
+ * }}}
+ * where w_i^\prime is the effective weights defined by w_i/\hat{x_i}, offset is
+ * {{{
+ * - \sum_i (w_i/\hat{x_i})\bar{x_i} + \bar{y} / \hat{y}.
+ * }}}, and diff is
+ * {{{
+ * \sum_i w_i^\prime x_i - y / \hat{y} + offset
+ * }}}
  *
- * @param weights weights/coefficients corresponding to features
+ * Note that the effective weights and offset don't depend on training dataset,
+ * so they can be precomputed.
  *
- * @param updater Updater to be used to update weights after every iteration.
+ * Now, the first derivative of the objective function in scaled space is
+ * {{{
+ * \frac{\partial L}{\partial w_i} = diff/N (x_i - \bar{x_i}) / \hat{x_i}
+ * }}}
+ * However, ($x_i - \bar{x_i}$) will densify the computation, so it's not
+ * an ideal formula when the training dataset is in sparse format.
+ *
+ * This can be addressed by adding the dense \bar{x_i} / \hat{x_i} terms
+ * at the end by keeping the sum of diff. The first derivative of the total
+ * objective function from all the samples is
+ * {{{
+ * \frac{\partial L}{\partial w_i} =
+ *     1/N \sum_j diff_j (x_{ij} - \bar{x_i}) / \hat{x_i}
+ *   = 1/N ((\sum_j diff_j x_{ij} / \hat{x_i}) - diffSum \bar{x_i} / \hat{x_i})
+ *   = 1/N ((\sum_j diff_j x_{ij} / \hat{x_i}) + correction_i)
+ * }}},
+ * where correction_i = - diffSum \bar{x_i} / \hat{x_i}
+ *
+ * Simple math shows that diffSum is actually zero, so we don't even
+ * need to add the correction terms at the end. From the definition of diff,
+ * {{{
+ * diffSum = \sum_j (\sum_i w_i(x_{ij} - \bar{x_i}) / \hat{x_i} - (y_j - \bar{y}) / \hat{y})
+ *         = N * (\sum_i w_i(\bar{x_i} - \bar{x_i}) / \hat{x_i} - (\bar{y} - \bar{y}) / \hat{y})
+ *         = 0
+ * }}}
+ *
+ * As a result, the first derivative of the total objective function only depends on
+ * the training dataset, which can be easily computed in distributed fashion, and is
+ * sparse format friendly.
+ * {{{
+ * \frac{\partial L}{\partial w_i} = 1/N \sum_j diff_j x_{ij} / \hat{x_i}
+ * }}}
+ *
+ * @param weights The weights/coefficients corresponding to the features.
+ * @param labelStd The standard deviation value of the label.
+ * @param labelMean The mean value of the label.
+ * @param featuresStd The standard deviation values of the features.
+ * @param featuresMean The mean values of the features.
  */
 private class LeastSquaresAggregator(
     weights: Vector,
@@ -302,18 +379,6 @@ private class LeastSquaresAggregator(
 
   def gradient: Vector = {
     val result = Vectors.dense(gradientSumArray.clone())
-
-    val correction = {
-      val temp = effectiveWeightsArray.clone()
-      var i = 0
-      while (i < temp.length) {
-        temp(i) *= featuresMean(i)
-        i += 1
-      }
-      Vectors.dense(temp)
-    }
-
-    axpy(-diffSum, correction, result)
     scal(1.0 / totalCnt, result)
     result
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/15995c88/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
index 9fd60ff..26be30f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
@@ -225,7 +225,7 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
       throw new SparkException("Input validation failed.")
     }
 
-    /*
+    /**
      * Scaling columns to unit variance as a heuristic to reduce the condition number:
      *
      * During the optimization process, the convergence (rate) depends on the condition number of

http://git-wip-us.apache.org/repos/asf/spark/blob/15995c88/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala
index d7bb943..91c2f4c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala
@@ -103,17 +103,16 @@ object LinearDataGenerator {
 
     val rnd = new Random(seed)
     val x = Array.fill[Array[Double]](nPoints)(
-      Array.fill[Double](weights.length)(rnd.nextDouble))
-
-    x.map(vector => {
-      // This doesn't work if `vector` is a sparse vector.
-      val vectorArray = vector.toArray
-      var i = 0
-      while (i < vectorArray.size) {
-        vectorArray(i) = (vectorArray(i) - 0.5) * math.sqrt(12.0 * xVariance(i)) + xMean(i)
-        i += 1
-      }
-    })
+      Array.fill[Double](weights.length)(rnd.nextDouble()))
+
+    x.foreach {
+      case v =>
+        var i = 0
+        while (i < v.length) {
+          v(i) = (v(i) - 0.5) * math.sqrt(12.0 * xVariance(i)) + xMean(i)
+          i += 1
+        }
+    }
 
     val y = x.map { xi =>
       blas.ddot(weights.length, xi, 1, weights, 1) + intercept + eps * rnd.nextGaussian()


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org