You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by yl...@apache.org on 2016/10/26 16:28:35 UTC

spark git commit: [SPARK-17748][FOLLOW-UP][ML] Reorg variables of WeightedLeastSquares.

Repository: spark
Updated Branches:
  refs/heads/master 4bee95407 -> 312ea3f7f


[SPARK-17748][FOLLOW-UP][ML] Reorg variables of WeightedLeastSquares.

## What changes were proposed in this pull request?
This is a follow-up to #15394.
Reorganize some variables of `WeightedLeastSquares` and fix a minor issue in `WeightedLeastSquaresSuite`.

## How was this patch tested?
Existing tests.

Author: Yanbo Liang <yb...@gmail.com>

Closes #15621 from yanboliang/spark-17748.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/312ea3f7
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/312ea3f7
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/312ea3f7

Branch: refs/heads/master
Commit: 312ea3f7f65532818e11016d6d780ad47485175f
Parents: 4bee954
Author: Yanbo Liang <yb...@gmail.com>
Authored: Wed Oct 26 09:28:28 2016 -0700
Committer: Yanbo Liang <yb...@gmail.com>
Committed: Wed Oct 26 09:28:28 2016 -0700

----------------------------------------------------------------------
 .../spark/ml/optim/WeightedLeastSquares.scala   | 139 +++++++++++--------
 .../ml/optim/WeightedLeastSquaresSuite.scala    |  15 +-
 2 files changed, 86 insertions(+), 68 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/312ea3f7/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala
index 2223f12..90c24e1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala
@@ -101,23 +101,19 @@ private[ml] class WeightedLeastSquares(
     summary.validate()
     logInfo(s"Number of instances: ${summary.count}.")
     val k = if (fitIntercept) summary.k + 1 else summary.k
+    val numFeatures = summary.k
     val triK = summary.triK
     val wSum = summary.wSum
-    val bBar = summary.bBar
-    val bbBar = summary.bbBar
-    val aBar = summary.aBar
-    val aStd = summary.aStd
-    val abBar = summary.abBar
-    val aaBar = summary.aaBar
-    val numFeatures = abBar.size
+
     val rawBStd = summary.bStd
+    val rawBBar = summary.bBar
     // if b is constant (rawBStd is zero), then b cannot be scaled. In this case
-    // setting bStd=abs(bBar) ensures that b is not scaled anymore in l-bfgs algorithm.
-    val bStd = if (rawBStd == 0.0) math.abs(bBar) else rawBStd
+    // setting bStd=abs(rawBBar) ensures that b is not scaled anymore in l-bfgs algorithm.
+    val bStd = if (rawBStd == 0.0) math.abs(rawBBar) else rawBStd
 
     if (rawBStd == 0) {
-      if (fitIntercept || bBar == 0.0) {
-        if (bBar == 0.0) {
+      if (fitIntercept || rawBBar == 0.0) {
+        if (rawBBar == 0.0) {
           logWarning(s"Mean and standard deviation of the label are zero, so the coefficients " +
             s"and the intercept will all be zero; as a result, training is not needed.")
         } else {
@@ -126,7 +122,7 @@ private[ml] class WeightedLeastSquares(
             s"training is not needed.")
         }
         val coefficients = new DenseVector(Array.ofDim(numFeatures))
-        val intercept = bBar
+        val intercept = rawBBar
         val diagInvAtWA = new DenseVector(Array(0D))
         return new WeightedLeastSquaresModel(coefficients, intercept, diagInvAtWA, Array(0D))
       } else {
@@ -137,53 +133,70 @@ private[ml] class WeightedLeastSquares(
       }
     }
 
-    // scale aBar to standardized space in-place
-    val aBarValues = aBar.values
-    var j = 0
-    while (j < numFeatures) {
-      if (aStd(j) == 0.0) {
-        aBarValues(j) = 0.0
-      } else {
-        aBarValues(j) /= aStd(j)
-      }
-      j += 1
-    }
+    val bBar = summary.bBar / bStd
+    val bbBar = summary.bbBar / (bStd * bStd)
 
-    // scale abBar to standardized space in-place
-    val abBarValues = abBar.values
+    val aStd = summary.aStd
     val aStdValues = aStd.values
-    j = 0
-    while (j < numFeatures) {
-      if (aStdValues(j) == 0.0) {
-        abBarValues(j) = 0.0
-      } else {
-        abBarValues(j) /= (aStdValues(j) * bStd)
+
+    val aBar = {
+      val _aBar = summary.aBar
+      val _aBarValues = _aBar.values
+      var i = 0
+      // scale aBar to standardized space in-place
+      while (i < numFeatures) {
+        if (aStdValues(i) == 0.0) {
+          _aBarValues(i) = 0.0
+        } else {
+          _aBarValues(i) /= aStdValues(i)
+        }
+        i += 1
       }
-      j += 1
+      _aBar
     }
+    val aBarValues = aBar.values
 
-    // scale aaBar to standardized space in-place
-    val aaBarValues = aaBar.values
-    j = 0
-    var p = 0
-    while (j < numFeatures) {
-      val aStdJ = aStdValues(j)
+    val abBar = {
+      val _abBar = summary.abBar
+      val _abBarValues = _abBar.values
       var i = 0
-      while (i <= j) {
-        val aStdI = aStdValues(i)
-        if (aStdJ == 0.0 || aStdI == 0.0) {
-          aaBarValues(p) = 0.0
+      // scale abBar to standardized space in-place
+      while (i < numFeatures) {
+        if (aStdValues(i) == 0.0) {
+          _abBarValues(i) = 0.0
         } else {
-          aaBarValues(p) /= (aStdI * aStdJ)
+          _abBarValues(i) /= (aStdValues(i) * bStd)
         }
-        p += 1
         i += 1
       }
-      j += 1
+      _abBar
     }
+    val abBarValues = abBar.values
 
-    val bBarStd = bBar / bStd
-    val bbBarStd = bbBar / (bStd * bStd)
+    val aaBar = {
+      val _aaBar = summary.aaBar
+      val _aaBarValues = _aaBar.values
+      var j = 0
+      var p = 0
+      // scale aaBar to standardized space in-place
+      while (j < numFeatures) {
+        val aStdJ = aStdValues(j)
+        var i = 0
+        while (i <= j) {
+          val aStdI = aStdValues(i)
+          if (aStdJ == 0.0 || aStdI == 0.0) {
+            _aaBarValues(p) = 0.0
+          } else {
+            _aaBarValues(p) /= (aStdI * aStdJ)
+          }
+          p += 1
+          i += 1
+        }
+        j += 1
+      }
+      _aaBar
+    }
+    val aaBarValues = aaBar.values
 
     val effectiveRegParam = regParam / bStd
     val effectiveL1RegParam = elasticNetParam * effectiveRegParam
@@ -191,11 +204,11 @@ private[ml] class WeightedLeastSquares(
 
     // add L2 regularization to diagonals
     var i = 0
-    j = 2
+    var j = 2
     while (i < triK) {
       var lambda = effectiveL2RegParam
       if (!standardizeFeatures) {
-        val std = aStd(j - 2)
+        val std = aStdValues(j - 2)
         if (std != 0.0) {
           lambda /= (std * std)
         } else {
@@ -209,8 +222,9 @@ private[ml] class WeightedLeastSquares(
       i += j
       j += 1
     }
-    val aa = getAtA(aaBar.values, aBar.values)
-    val ab = getAtB(abBar.values, bBarStd)
+
+    val aa = getAtA(aaBarValues, aBarValues)
+    val ab = getAtB(abBarValues, bBar)
 
     val solver = if ((solverType == WeightedLeastSquares.Auto && elasticNetParam != 0.0 &&
       regParam != 0.0) || (solverType == WeightedLeastSquares.QuasiNewton)) {
@@ -237,22 +251,23 @@ private[ml] class WeightedLeastSquares(
     val solution = solver match {
       case cholesky: CholeskySolver =>
         try {
-          cholesky.solve(bBarStd, bbBarStd, ab, aa, aBar)
+          cholesky.solve(bBar, bbBar, ab, aa, aBar)
         } catch {
           // if Auto solver is used and Cholesky fails due to singular AtA, then fall back to
-          // quasi-newton solver
+          // Quasi-Newton solver.
           case _: SingularMatrixException if solverType == WeightedLeastSquares.Auto =>
             logWarning("Cholesky solver failed due to singular covariance matrix. " +
               "Retrying with Quasi-Newton solver.")
             // ab and aa were modified in place, so reconstruct them
-            val _aa = getAtA(aaBar.values, aBar.values)
-            val _ab = getAtB(abBar.values, bBarStd)
+            val _aa = getAtA(aaBarValues, aBarValues)
+            val _ab = getAtB(abBarValues, bBar)
             val newSolver = new QuasiNewtonSolver(fitIntercept, maxIter, tol, None)
-            newSolver.solve(bBarStd, bbBarStd, _ab, _aa, aBar)
+            newSolver.solve(bBar, bbBar, _ab, _aa, aBar)
         }
       case qn: QuasiNewtonSolver =>
-        qn.solve(bBarStd, bbBarStd, ab, aa, aBar)
+        qn.solve(bBar, bbBar, ab, aa, aBar)
     }
+
     val (coefficientArray, intercept) = if (fitIntercept) {
       (solution.coefficients.slice(0, solution.coefficients.length - 1),
         solution.coefficients.last * bStd)
@@ -271,7 +286,11 @@ private[ml] class WeightedLeastSquares(
     // aaInv is a packed upper triangular matrix, here we get all elements on diagonal
     val diagInvAtWA = solution.aaInv.map { inv =>
       new DenseVector((1 to k).map { i =>
-        val multiplier = if (i == k && fitIntercept) 1.0 else aStdValues(i - 1) * aStdValues(i - 1)
+        val multiplier = if (i == k && fitIntercept) {
+          1.0
+        } else {
+          aStdValues(i - 1) * aStdValues(i - 1)
+        }
         inv(i + (i - 1) * i / 2 - 1) / (wSum * multiplier)
       }.toArray)
     }.getOrElse(new DenseVector(Array(0D)))
@@ -280,7 +299,7 @@ private[ml] class WeightedLeastSquares(
       solution.objectiveHistory.getOrElse(Array(0D)))
   }
 
-  /** Construct A^T^ A from summary statistics. */
+  /** Construct A^T^ A (append bias if necessary). */
   private def getAtA(aaBar: Array[Double], aBar: Array[Double]): DenseVector = {
     if (fitIntercept) {
       new DenseVector(Array.concat(aaBar, aBar, Array(1.0)))
@@ -289,7 +308,7 @@ private[ml] class WeightedLeastSquares(
     }
   }
 
-  /** Construct A^T^ b from summary statistics. */
+  /** Construct A^T^ b (append bias if necessary). */
   private def getAtB(abBar: Array[Double], bBar: Double): DenseVector = {
     if (fitIntercept) {
       new DenseVector(Array.concat(abBar, Array(bBar)))

http://git-wip-us.apache.org/repos/asf/spark/blob/312ea3f7/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala
index 3cdab03..093d02e 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala
@@ -361,14 +361,13 @@ class WeightedLeastSquaresSuite extends SparkFunSuite with MLlibTestSparkContext
     for (fitIntercept <- Seq(false, true);
          standardization <- Seq(false, true);
          (lambda, alpha) <- Seq((0.0, 0.0), (0.5, 0.0), (0.5, 0.5), (0.5, 1.0))) {
-      for (solver <- Seq(WeightedLeastSquares.Auto, WeightedLeastSquares.Cholesky)) {
-        val wls = new WeightedLeastSquares(fitIntercept, regParam = lambda, elasticNetParam = alpha,
-          standardizeFeatures = standardization, standardizeLabel = true,
-          solverType = WeightedLeastSquares.QuasiNewton)
-        val model = wls.fit(constantFeaturesInstances)
-        val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1))
-        assert(actual ~== expectedQuasiNewton(idx) absTol 1e-6)
-      }
+      val wls = new WeightedLeastSquares(fitIntercept, regParam = lambda, elasticNetParam = alpha,
+        standardizeFeatures = standardization, standardizeLabel = true,
+        solverType = WeightedLeastSquares.QuasiNewton)
+      val model = wls.fit(constantFeaturesInstances)
+      val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1))
+      assert(actual ~== expectedQuasiNewton(idx) absTol 1e-6)
+
       idx += 1
     }
   }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org