You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by db...@apache.org on 2015/12/11 08:36:00 UTC

spark git commit: [SPARK-10991][ML] logistic regression training summary handle empty prediction col

Repository: spark
Updated Branches:
  refs/heads/master b1b4ee7f3 -> 518ab5101


[SPARK-10991][ML] logistic regression training summary handle empty prediction col

LogisticRegression training summary should still function if the predictionCol is set to an empty string or otherwise unset (related to https://issues.apache.org/jira/browse/SPARK-9718 )

Author: Holden Karau <ho...@pigscanfly.ca>
Author: Holden Karau <ho...@us.ibm.com>

Closes #9037 from holdenk/SPARK-10991-LogisticRegressionTrainingSummary-handle-empty-prediction-col.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/518ab510
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/518ab510
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/518ab510

Branch: refs/heads/master
Commit: 518ab5101073ee35d62e33c8f7281a1e6342101e
Parents: b1b4ee7
Author: Holden Karau <ho...@pigscanfly.ca>
Authored: Fri Dec 11 02:35:53 2015 -0500
Committer: DB Tsai <db...@netflix.com>
Committed: Fri Dec 11 02:35:53 2015 -0500

----------------------------------------------------------------------
 .../ml/classification/LogisticRegression.scala  | 20 ++++++++++++++++++--
 .../LogisticRegressionSuite.scala               | 11 +++++++++++
 2 files changed, 29 insertions(+), 2 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/518ab510/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 19cc323..486043e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -389,9 +389,10 @@ class LogisticRegression @Since("1.2.0") (
     if (handlePersistence) instances.unpersist()
 
     val model = copyValues(new LogisticRegressionModel(uid, coefficients, intercept))
+    val (summaryModel, probabilityColName) = model.findSummaryModelAndProbabilityCol()
     val logRegSummary = new BinaryLogisticRegressionTrainingSummary(
-      model.transform(dataset),
-      $(probabilityCol),
+      summaryModel.transform(dataset),
+      probabilityColName,
       $(labelCol),
       $(featuresCol),
       objectiveHistory)
@@ -469,6 +470,21 @@ class LogisticRegressionModel private[ml] (
         new NullPointerException())
   }
 
+  /**
+   * If the probability column is set returns the current model and probability column,
+   * otherwise generates a new column and sets it as the probability column on a new copy
+   * of the current model.
+   */
+  private[classification] def findSummaryModelAndProbabilityCol():
+      (LogisticRegressionModel, String) = {
+    $(probabilityCol) match {
+      case "" =>
+        val probabilityColName = "probability_" + java.util.UUID.randomUUID.toString()
+        (copy(ParamMap.empty).setProbabilityCol(probabilityColName), probabilityColName)
+      case p => (this, p)
+    }
+  }
+
   private[classification] def setSummary(
       summary: LogisticRegressionTrainingSummary): this.type = {
     this.trainingSummary = Some(summary)

http://git-wip-us.apache.org/repos/asf/spark/blob/518ab510/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index a9a6ff8..1087afb 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -99,6 +99,17 @@ class LogisticRegressionSuite
     assert(model.hasParent)
   }
 
+  test("empty probabilityCol") {
+    val lr = new LogisticRegression().setProbabilityCol("")
+    val model = lr.fit(dataset)
+    assert(model.hasSummary)
+    // Validate that we re-insert a probability column for evaluation
+    val fieldNames = model.summary.predictions.schema.fieldNames
+    assert((dataset.schema.fieldNames.toSet).subsetOf(
+      fieldNames.toSet))
+    assert(fieldNames.exists(s => s.startsWith("probability_")))
+  }
+
   test("setThreshold, getThreshold") {
     val lr = new LogisticRegression
     // default


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org