Posted to commits@spark.apache.org by db...@apache.org on 2016/01/22 02:24:54 UTC

spark git commit: [SPARK-12908][ML] Add warning message for LogisticRegression for potential converge issue

Repository: spark
Updated Branches:
  refs/heads/master 85200c09a -> b4574e387


[SPARK-12908][ML] Add warning message for LogisticRegression for potential converge issue

When all labels are the same, training LogisticRegression without an intercept is risky: the optimization may not converge. GLMNET doesn't support this case and simply exits. GLM can train, but it emits a warning that the algorithm did not converge.
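To see why a missing intercept matters, here is a toy one-dimensional sketch in Scala. The object and method names are hypothetical, and it uses plain gradient descent rather than the optimizer Spark actually uses. Every label is 1, the single feature is always positive, and there is no intercept. Under those assumptions the loss log(1 + exp(-w * x)) keeps shrinking as w grows and has no finite minimizer, so the weight drifts upward indefinitely instead of converging.

```scala
object AllOnesNoIntercept {
  // Mean logistic loss when every label is 1 and there is no intercept:
  // log(1 + exp(-w * x)) averaged over the examples.
  def loss(w: Double, xs: Seq[Double]): Double =
    xs.map(x => math.log1p(math.exp(-w * x))).sum / xs.size

  // Derivative of the loss above with respect to w: -x / (1 + exp(w * x)), averaged.
  // It is strictly negative whenever every x is positive.
  def gradient(w: Double, xs: Seq[Double]): Double =
    xs.map(x => -x / (1.0 + math.exp(w * x))).sum / xs.size

  // Plain gradient descent from w = 0; returns the weight after `steps` iterations.
  def train(xs: Seq[Double], steps: Int, lr: Double = 1.0): Double = {
    var w = 0.0
    for (_ <- 1 to steps) w -= lr * gradient(w, xs)
    w
  }

  def main(args: Array[String]): Unit = {
    val xs = Seq(0.5, 1.0, 2.0) // all-positive feature values; every label is 1
    for (steps <- Seq(10, 100, 1000, 10000)) {
      val w = train(xs, steps)
      println(f"steps=$steps%5d  w=$w%.4f  loss=${loss(w, xs)}%.6f")
    }
  }
}
```

Each run with more steps yields a larger w and a smaller loss. If some feature values were negative, the loss would eventually rise again as w grows, so a finite optimum would exist; the trouble comes from data where some weight vector gives every example a positive margin.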

Author: DB Tsai <db...@netflix.com>

Closes #10862 from dbtsai/add-tests.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b4574e38
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b4574e38
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b4574e38

Branch: refs/heads/master
Commit: b4574e387d0124667bdbb35f8c7c3e2065b14ba9
Parents: 85200c0
Author: DB Tsai <db...@netflix.com>
Authored: Thu Jan 21 17:24:48 2016 -0800
Committer: DB Tsai <db...@netflix.com>
Committed: Thu Jan 21 17:24:48 2016 -0800

----------------------------------------------------------------------
 .../apache/spark/ml/classification/LogisticRegression.scala  | 8 ++++++++
 1 file changed, 8 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/b4574e38/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index dad8dfc..c98a78a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -300,6 +300,14 @@ class LogisticRegression @Since("1.2.0") (
           s"training is not needed.")
         (Vectors.sparse(numFeatures, Seq()), Double.NegativeInfinity, Array.empty[Double])
       } else {
+        if (!$(fitIntercept) && numClasses == 2 && histogram(0) == 0.0) {
+          logWarning(s"All labels are one and fitIntercept=false. It's a dangerous ground, " +
+            s"so the algorithm may not converge.")
+        } else if (!$(fitIntercept) && numClasses == 1) {
+          logWarning(s"All labels are zero and fitIntercept=false. It's a dangerous ground, " +
+            s"so the algorithm may not converge.")
+        }
+
         val featuresMean = summarizer.mean.toArray
         val featuresStd = summarizer.variance.toArray.map(math.sqrt)
 

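The condition added in the diff can be read as a standalone function. The sketch below is hypothetical: `LabelCheck`, `histogramOf`, and `convergenceWarning` are not Spark APIs. It assumes, as the `numClasses == 1` branch implies, that the number of classes is the largest label seen plus one, so a dataset containing only zeros produces a one-bucket histogram, while a dataset containing only ones produces a two-bucket histogram whose first bucket is empty.

```scala
object LabelCheck {
  // Hypothetical: build a label histogram whose length is (max label + 1).
  // This is an assumption about how numClasses is derived, consistent with
  // the numClasses == 1 check in the diff.
  def histogramOf(labels: Seq[Double]): Array[Double] = {
    val hist = Array.fill(labels.max.toInt + 1)(0.0)
    labels.foreach(l => hist(l.toInt) += 1.0)
    hist
  }

  // Mirrors the two branches added in the commit; returns the warning text, if any.
  def convergenceWarning(histogram: Array[Double], fitIntercept: Boolean): Option[String] = {
    val numClasses = histogram.length
    if (!fitIntercept && numClasses == 2 && histogram(0) == 0.0)
      Some("All labels are one and fitIntercept=false; the algorithm may not converge.")
    else if (!fitIntercept && numClasses == 1)
      Some("All labels are zero and fitIntercept=false; the algorithm may not converge.")
    else None
  }

  def main(args: Array[String]): Unit = {
    println(convergenceWarning(histogramOf(Seq(1.0, 1.0, 1.0)), fitIntercept = false)) // all ones
    println(convergenceWarning(histogramOf(Seq(0.0, 0.0)), fitIntercept = false))      // all zeros
    println(convergenceWarning(histogramOf(Seq(0.0, 1.0)), fitIntercept = false))      // mixed labels
    println(convergenceWarning(histogramOf(Seq(1.0, 1.0)), fitIntercept = true))       // intercept fitted
  }
}
```

Only the two single-label cases without an intercept produce a warning; mixed labels, or any case with fitIntercept=true, fall through to None in this sketch.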
