Posted to commits@spark.apache.org by db...@apache.org on 2016/01/19 20:08:55 UTC

spark git commit: [SPARK-12804][ML] Fix LogisticRegression with FitIntercept on all same label training data

Repository: spark
Updated Branches:
  refs/heads/master b122c861c -> 2388de519


[SPARK-12804][ML] Fix LogisticRegression with FitIntercept on all same label training data

CC jkbradley mengxr dbtsai

Author: Feynman Liang <fe...@gmail.com>

Closes #10743 from feynmanliang/SPARK-12804.
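
For context, this is the corner case the patch handles: calling fit with
fitIntercept=true on a dataset whose labels are all identical. A minimal
reproduction sketch (illustrative only, not part of this commit; it assumes
a Spark 1.6-era SQLContext named `sqlContext` and made-up column names):

    // Hypothetical reproduction of SPARK-12804; names are illustrative.
    import org.apache.spark.ml.classification.LogisticRegression
    import org.apache.spark.mllib.linalg.Vectors
    import sqlContext.implicits._

    // Every row carries the same label, 1.0.
    val allOnes = Seq(
      (1.0, Vectors.dense(0.1)),
      (1.0, Vectors.dense(0.2)),
      (1.0, Vectors.dense(0.3))
    ).toDF("label", "features")

    val model = new LogisticRegression().setFitIntercept(true).fit(allOnes)
    // With this patch: all-zero coefficients, an intercept of +Infinity,
    // and zero optimizer iterations.
    println(s"${model.coefficients} ${model.intercept}")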


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2388de51
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2388de51
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2388de51

Branch: refs/heads/master
Commit: 2388de51912efccaceeb663ac56fc500a79d2ceb
Parents: b122c86
Author: Feynman Liang <fe...@gmail.com>
Authored: Tue Jan 19 11:08:52 2016 -0800
Committer: DB Tsai <db...@netflix.com>
Committed: Tue Jan 19 11:08:52 2016 -0800

----------------------------------------------------------------------
 .../ml/classification/LogisticRegression.scala  | 200 ++++++++++---------
 .../LogisticRegressionSuite.scala               |  43 ++++
 2 files changed, 148 insertions(+), 95 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/2388de51/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 486043e..dad8dfc 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -276,113 +276,123 @@ class LogisticRegression @Since("1.2.0") (
     val numClasses = histogram.length
     val numFeatures = summarizer.mean.size
 
-    if (numInvalid != 0) {
-      val msg = s"Classification labels should be in {0 to ${numClasses - 1} " +
-        s"Found $numInvalid invalid labels."
-      logError(msg)
-      throw new SparkException(msg)
-    }
-
-    if (numClasses > 2) {
-      val msg = s"Currently, LogisticRegression with ElasticNet in ML package only supports " +
-        s"binary classification. Found $numClasses in the input dataset."
-      logError(msg)
-      throw new SparkException(msg)
-    }
+    val (coefficients, intercept, objectiveHistory) = {
+      if (numInvalid != 0) {
+        val msg = s"Classification labels should be in {0 to ${numClasses - 1} " +
+          s"Found $numInvalid invalid labels."
+        logError(msg)
+        throw new SparkException(msg)
+      }
 
-    val featuresMean = summarizer.mean.toArray
-    val featuresStd = summarizer.variance.toArray.map(math.sqrt)
+      if (numClasses > 2) {
+        val msg = s"Currently, LogisticRegression with ElasticNet in ML package only supports " +
+          s"binary classification. Found $numClasses in the input dataset."
+        logError(msg)
+        throw new SparkException(msg)
+      } else if ($(fitIntercept) && numClasses == 2 && histogram(0) == 0.0) {
+        logWarning(s"All labels are one and fitIntercept=true, so the coefficients will be " +
+          s"zeros and the intercept will be positive infinity; as a result, " +
+          s"training is not needed.")
+        (Vectors.sparse(numFeatures, Seq()), Double.PositiveInfinity, Array.empty[Double])
+      } else if ($(fitIntercept) && numClasses == 1) {
+        logWarning(s"All labels are zero and fitIntercept=true, so the coefficients will be " +
+          s"zeros and the intercept will be negative infinity; as a result, " +
+          s"training is not needed.")
+        (Vectors.sparse(numFeatures, Seq()), Double.NegativeInfinity, Array.empty[Double])
+      } else {
+        val featuresMean = summarizer.mean.toArray
+        val featuresStd = summarizer.variance.toArray.map(math.sqrt)
 
-    val regParamL1 = $(elasticNetParam) * $(regParam)
-    val regParamL2 = (1.0 - $(elasticNetParam)) * $(regParam)
+        val regParamL1 = $(elasticNetParam) * $(regParam)
+        val regParamL2 = (1.0 - $(elasticNetParam)) * $(regParam)
 
-    val costFun = new LogisticCostFun(instances, numClasses, $(fitIntercept), $(standardization),
-      featuresStd, featuresMean, regParamL2)
+        val costFun = new LogisticCostFun(instances, numClasses, $(fitIntercept),
+          $(standardization), featuresStd, featuresMean, regParamL2)
 
-    val optimizer = if ($(elasticNetParam) == 0.0 || $(regParam) == 0.0) {
-      new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol))
-    } else {
-      def regParamL1Fun = (index: Int) => {
-        // Remove the L1 penalization on the intercept
-        if (index == numFeatures) {
-          0.0
+        val optimizer = if ($(elasticNetParam) == 0.0 || $(regParam) == 0.0) {
+          new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol))
         } else {
-          if ($(standardization)) {
-            regParamL1
-          } else {
-            // If `standardization` is false, we still standardize the data
-            // to improve the rate of convergence; as a result, we have to
-            // perform this reverse standardization by penalizing each component
-            // differently to get effectively the same objective function when
-            // the training dataset is not standardized.
-            if (featuresStd(index) != 0.0) regParamL1 / featuresStd(index) else 0.0
+          def regParamL1Fun = (index: Int) => {
+            // Remove the L1 penalization on the intercept
+            if (index == numFeatures) {
+              0.0
+            } else {
+              if ($(standardization)) {
+                regParamL1
+              } else {
+                // If `standardization` is false, we still standardize the data
+                // to improve the rate of convergence; as a result, we have to
+                // perform this reverse standardization by penalizing each component
+                // differently to get effectively the same objective function when
+                // the training dataset is not standardized.
+                if (featuresStd(index) != 0.0) regParamL1 / featuresStd(index) else 0.0
+              }
+            }
           }
+          new BreezeOWLQN[Int, BDV[Double]]($(maxIter), 10, regParamL1Fun, $(tol))
         }
-      }
-      new BreezeOWLQN[Int, BDV[Double]]($(maxIter), 10, regParamL1Fun, $(tol))
-    }
-
-    val initialCoefficientsWithIntercept =
-      Vectors.zeros(if ($(fitIntercept)) numFeatures + 1 else numFeatures)
-
-    if ($(fitIntercept)) {
-      /*
-         For binary logistic regression, when we initialize the coefficients as zeros,
-         it will converge faster if we initialize the intercept such that
-         it follows the distribution of the labels.
-
-         {{{
-         P(0) = 1 / (1 + \exp(b)), and
-         P(1) = \exp(b) / (1 + \exp(b))
-         }}}, hence
-         {{{
-         b = \log{P(1) / P(0)} = \log{count_1 / count_0}
-         }}}
-       */
-      initialCoefficientsWithIntercept.toArray(numFeatures)
-        = math.log(histogram(1) / histogram(0))
-    }
 
-    val states = optimizer.iterations(new CachedDiffFunction(costFun),
-      initialCoefficientsWithIntercept.toBreeze.toDenseVector)
+        val initialCoefficientsWithIntercept =
+          Vectors.zeros(if ($(fitIntercept)) numFeatures + 1 else numFeatures)
+
+        if ($(fitIntercept)) {
+          /*
+             For binary logistic regression, when we initialize the coefficients as zeros,
+             it will converge faster if we initialize the intercept such that
+             it follows the distribution of the labels.
+
+             {{{
+               P(0) = 1 / (1 + \exp(b)), and
+               P(1) = \exp(b) / (1 + \exp(b))
+             }}}, hence
+             {{{
+               b = \log{P(1) / P(0)} = \log{count_1 / count_0}
+             }}}
+           */
+          initialCoefficientsWithIntercept.toArray(numFeatures) = math.log(
+            histogram(1) / histogram(0))
+        }
 
-    val (coefficients, intercept, objectiveHistory) = {
-      /*
-         Note that in Logistic Regression, the objective history (loss + regularization)
-         is the log-likelihood, which is invariant under feature standardization. As a result,
-         the objective history from the optimizer is the same as the one in the original space.
-       */
-      val arrayBuilder = mutable.ArrayBuilder.make[Double]
-      var state: optimizer.State = null
-      while (states.hasNext) {
-        state = states.next()
-        arrayBuilder += state.adjustedValue
-      }
+        val states = optimizer.iterations(new CachedDiffFunction(costFun),
+          initialCoefficientsWithIntercept.toBreeze.toDenseVector)
+
+        /*
+           Note that in Logistic Regression, the objective history (loss + regularization)
+           is the log-likelihood, which is invariant under feature standardization. As a result,
+           the objective history from the optimizer is the same as the one in the original space.
+         */
+        val arrayBuilder = mutable.ArrayBuilder.make[Double]
+        var state: optimizer.State = null
+        while (states.hasNext) {
+          state = states.next()
+          arrayBuilder += state.adjustedValue
+        }
 
-      if (state == null) {
-        val msg = s"${optimizer.getClass.getName} failed."
-        logError(msg)
-        throw new SparkException(msg)
-      }
+        if (state == null) {
+          val msg = s"${optimizer.getClass.getName} failed."
+          logError(msg)
+          throw new SparkException(msg)
+        }
 
-      /*
-         The coefficients are trained in the scaled space; we're converting them back to
-         the original space.
-         Note that the intercept in scaled space and original space is the same;
-         as a result, no scaling is needed.
-       */
-      val rawCoefficients = state.x.toArray.clone()
-      var i = 0
-      while (i < numFeatures) {
-        rawCoefficients(i) *= { if (featuresStd(i) != 0.0) 1.0 / featuresStd(i) else 0.0 }
-        i += 1
-      }
+        /*
+           The coefficients are trained in the scaled space; we're converting them back to
+           the original space.
+           Note that the intercept in scaled space and original space is the same;
+           as a result, no scaling is needed.
+         */
+        val rawCoefficients = state.x.toArray.clone()
+        var i = 0
+        while (i < numFeatures) {
+          rawCoefficients(i) *= { if (featuresStd(i) != 0.0) 1.0 / featuresStd(i) else 0.0 }
+          i += 1
+        }
 
-      if ($(fitIntercept)) {
-        (Vectors.dense(rawCoefficients.dropRight(1)).compressed, rawCoefficients.last,
-          arrayBuilder.result())
-      } else {
-        (Vectors.dense(rawCoefficients).compressed, 0.0, arrayBuilder.result())
+        if ($(fitIntercept)) {
+          (Vectors.dense(rawCoefficients.dropRight(1)).compressed, rawCoefficients.last,
+            arrayBuilder.result())
+        } else {
+          (Vectors.dense(rawCoefficients).compressed, 0.0, arrayBuilder.result())
+        }
       }
     }
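
The two short-circuit branches added above follow directly from the
intercept-initialization comment in this hunk: with all-zero coefficients,
the maximum-likelihood intercept is b = log(count_1 / count_0), so if either
class count is zero the optimal intercept is infinite and no iteration of
the optimizer can improve on it. A worked sketch of the arithmetic (plain
Scala, with an illustrative label histogram):

    // Worked check of b = log(count_1 / count_0) when one class is absent.
    // The histogram values are illustrative, not from any real dataset.
    val histogram = Array(0.0, 100.0)              // count_0 = 0, count_1 = 100
    println(math.log(histogram(1) / histogram(0))) // Infinity: all labels one
    println(math.log(histogram(0) / histogram(1))) // -Infinity: all labels zero
    // Hence the patch returns an intercept of Double.PositiveInfinity or
    // Double.NegativeInfinity, with empty sparse coefficients, and skips
    // training entirely.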
 

http://git-wip-us.apache.org/repos/asf/spark/blob/2388de51/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index ff0d0ff..972c086 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -30,6 +30,7 @@ import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.mllib.util.TestingUtils._
 import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.functions.lit
 
 class LogisticRegressionSuite
   extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
@@ -883,6 +884,48 @@ class LogisticRegressionSuite
     assert(model1a0.intercept ~== model1b.intercept absTol 1E-3)
   }
 
+  test("logistic regression with all labels the same") {
+    val sameLabels = dataset
+      .withColumn("zeroLabel", lit(0.0))
+      .withColumn("oneLabel", lit(1.0))
+
+    // fitIntercept=true
+    val lrIntercept = new LogisticRegression()
+      .setFitIntercept(true)
+      .setMaxIter(3)
+
+    val allZeroInterceptModel = lrIntercept
+      .setLabelCol("zeroLabel")
+      .fit(sameLabels)
+    assert(allZeroInterceptModel.coefficients ~== Vectors.dense(0.0) absTol 1E-3)
+    assert(allZeroInterceptModel.intercept === Double.NegativeInfinity)
+    assert(allZeroInterceptModel.summary.totalIterations === 0)
+
+    val allOneInterceptModel = lrIntercept
+      .setLabelCol("oneLabel")
+      .fit(sameLabels)
+    assert(allOneInterceptModel.coefficients ~== Vectors.dense(0.0) absTol 1E-3)
+    assert(allOneInterceptModel.intercept === Double.PositiveInfinity)
+    assert(allOneInterceptModel.summary.totalIterations === 0)
+
+    // fitIntercept=false
+    val lrNoIntercept = new LogisticRegression()
+      .setFitIntercept(false)
+      .setMaxIter(3)
+
+    val allZeroNoInterceptModel = lrNoIntercept
+      .setLabelCol("zeroLabel")
+      .fit(sameLabels)
+    assert(allZeroNoInterceptModel.intercept === 0.0)
+    assert(allZeroNoInterceptModel.summary.totalIterations > 0)
+
+    val allOneNoInterceptModel = lrNoIntercept
+      .setLabelCol("oneLabel")
+      .fit(sameLabels)
+    assert(allOneNoInterceptModel.intercept === 0.0)
+    assert(allOneNoInterceptModel.summary.totalIterations > 0)
+  }
+
   test("read/write") {
     def checkModelData(model: LogisticRegressionModel, model2: LogisticRegressionModel): Unit = {
       assert(model.intercept === model2.intercept)
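
A final note on why the new test's expectations are sensible: an infinite
intercept is still safe for prediction, because the logistic link saturates
and the degenerate model deterministically predicts the single label it was
shown, whereas with fitIntercept=false there is no intercept term to diverge,
so training runs as usual (totalIterations > 0) and the intercept stays 0.0.
A quick plain-Scala check of the saturation (illustrative only):

    // The logistic function saturates cleanly at infinite margins.
    def sigmoid(z: Double): Double = 1.0 / (1.0 + math.exp(-z))
    println(sigmoid(Double.PositiveInfinity)) // 1.0 -> always predict label 1
    println(sigmoid(Double.NegativeInfinity)) // 0.0 -> always predict label 0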

