You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by db...@apache.org on 2015/12/11 08:36:00 UTC
spark git commit: [SPARK-10991][ML] logistic regression training summary handle empty prediction col
Repository: spark
Updated Branches:
refs/heads/master b1b4ee7f3 -> 518ab5101
[SPARK-10991][ML] logistic regression training summary handle empty prediction col
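LogisticRegression training summary should still function if the predictionCol is set to an empty string or otherwise unset (related to https://issues.apache.org/jira/browse/SPARK-9718). The column the summary actually depends on is probabilityCol: when it is set to an empty string, a temporary "probability_<UUID>" column is generated for building the summary.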
Author: Holden Karau <ho...@pigscanfly.ca>
Author: Holden Karau <ho...@us.ibm.com>
Closes #9037 from holdenk/SPARK-10991-LogisticRegressionTrainingSummary-handle-empty-prediction-col.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/518ab510
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/518ab510
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/518ab510
Branch: refs/heads/master
Commit: 518ab5101073ee35d62e33c8f7281a1e6342101e
Parents: b1b4ee7
Author: Holden Karau <ho...@pigscanfly.ca>
Authored: Fri Dec 11 02:35:53 2015 -0500
Committer: DB Tsai <db...@netflix.com>
Committed: Fri Dec 11 02:35:53 2015 -0500
----------------------------------------------------------------------
.../ml/classification/LogisticRegression.scala | 20 ++++++++++++++++++--
.../LogisticRegressionSuite.scala | 11 +++++++++++
2 files changed, 29 insertions(+), 2 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/518ab510/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 19cc323..486043e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -389,9 +389,10 @@ class LogisticRegression @Since("1.2.0") (
if (handlePersistence) instances.unpersist()
val model = copyValues(new LogisticRegressionModel(uid, coefficients, intercept))
+ val (summaryModel, probabilityColName) = model.findSummaryModelAndProbabilityCol()
val logRegSummary = new BinaryLogisticRegressionTrainingSummary(
- model.transform(dataset),
- $(probabilityCol),
+ summaryModel.transform(dataset),
+ probabilityColName,
$(labelCol),
$(featuresCol),
objectiveHistory)
@@ -469,6 +470,21 @@ class LogisticRegressionModel private[ml] (
new NullPointerException())
}
+ /**
+ * Picks the model/column pair used to build the training summary. If probabilityCol is
+ * non-empty, returns (this, probabilityCol). If it is "", returns a new copy of this model
+ * with probabilityCol set to a generated "probability_<random UUID>" name, plus that name.
+ */
+ private[classification] def findSummaryModelAndProbabilityCol():
+ (LogisticRegressionModel, String) = {
+ $(probabilityCol) match {
+ case "" =>
+ val probabilityColName = "probability_" + java.util.UUID.randomUUID.toString()
+ (copy(ParamMap.empty).setProbabilityCol(probabilityColName), probabilityColName)
+ case p => (this, p)
+ }
+ }
+
private[classification] def setSummary(
summary: LogisticRegressionTrainingSummary): this.type = {
this.trainingSummary = Some(summary)
http://git-wip-us.apache.org/repos/asf/spark/blob/518ab510/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index a9a6ff8..1087afb 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -99,6 +99,17 @@ class LogisticRegressionSuite
assert(model.hasParent)
}
+ test("empty probabilityCol") {
+ val lr = new LogisticRegression().setProbabilityCol("") // summary must generate its own column
+ val model = lr.fit(dataset)
+ assert(model.hasSummary)
+ // Summary predictions must keep every input column and add a generated probability_* column
+ val fieldNames = model.summary.predictions.schema.fieldNames
+ assert((dataset.schema.fieldNames.toSet).subsetOf(
+ fieldNames.toSet))
+ assert(fieldNames.exists(s => s.startsWith("probability_")))
+ }
+
test("setThreshold, getThreshold") {
val lr = new LogisticRegression
// default
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org