Posted to commits@spark.apache.org by ml...@apache.org on 2018/01/30 07:02:22 UTC
spark git commit: [SPARK-23138][ML][DOC] Multiclass logistic regression summary example and user guide
Repository: spark
Updated Branches:
refs/heads/master 8b983243e -> 5056877e8
[SPARK-23138][ML][DOC] Multiclass logistic regression summary example and user guide
## What changes were proposed in this pull request?
User guide and examples are updated to reflect the multiclass logistic regression summary, which was added in [SPARK-17139](https://issues.apache.org/jira/browse/SPARK-17139).
I did not create a separate summary example; instead, I added the summary code to the existing multiclass example, since a standalone example for the summary does not seem necessary.
## How was this patch tested?
Docs and examples only. Ran all examples locally using spark-submit.
Author: sethah <sh...@cloudera.com>
Closes #20332 from sethah/multiclass_summary_example.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/5056877e
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/5056877e
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/5056877e
Branch: refs/heads/master
Commit: 5056877e8bea56dd0f4dc9e3385669e1e78b2925
Parents: 8b98324
Author: sethah <sh...@cloudera.com>
Authored: Tue Jan 30 09:02:16 2018 +0200
Committer: Nick Pentreath <ni...@za.ibm.com>
Committed: Tue Jan 30 09:02:16 2018 +0200
----------------------------------------------------------------------
docs/ml-classification-regression.md | 22 +++----
.../JavaLogisticRegressionSummaryExample.java | 17 ++----
...LogisticRegressionWithElasticNetExample.java | 62 ++++++++++++++++++++
...lass_logistic_regression_with_elastic_net.py | 38 ++++++++++++
.../ml/LogisticRegressionSummaryExample.scala | 15 ++---
...ogisticRegressionWithElasticNetExample.scala | 43 ++++++++++++++
6 files changed, 164 insertions(+), 33 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/5056877e/docs/ml-classification-regression.md
----------------------------------------------------------------------
diff --git a/docs/ml-classification-regression.md b/docs/ml-classification-regression.md
index bf979f3..ddd2f4b 100644
--- a/docs/ml-classification-regression.md
+++ b/docs/ml-classification-regression.md
@@ -87,7 +87,7 @@ More details on parameters can be found in the [R API documentation](api/R/spark
The `spark.ml` implementation of logistic regression also supports
extracting a summary of the model over the training set. Note that the
predictions and metrics which are stored as `DataFrame` in
-`BinaryLogisticRegressionSummary` are annotated `@transient` and hence
+`LogisticRegressionSummary` are annotated `@transient` and hence
only available on the driver.
<div class="codetabs">
@@ -97,10 +97,9 @@ only available on the driver.
[`LogisticRegressionTrainingSummary`](api/scala/index.html#org.apache.spark.ml.classification.LogisticRegressionTrainingSummary)
provides a summary for a
[`LogisticRegressionModel`](api/scala/index.html#org.apache.spark.ml.classification.LogisticRegressionModel).
-Currently, only binary classification is supported and the
-summary must be explicitly cast to
-[`BinaryLogisticRegressionTrainingSummary`](api/scala/index.html#org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary).
-This will likely change when multiclass classification is supported.
+In the case of binary classification, certain additional metrics are
+available, e.g. ROC curve. The binary summary can be accessed via the
+`binarySummary` method. See [`BinaryLogisticRegressionTrainingSummary`](api/scala/index.html#org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary).
Continuing the earlier example:
@@ -111,10 +110,9 @@ Continuing the earlier example:
[`LogisticRegressionTrainingSummary`](api/java/org/apache/spark/ml/classification/LogisticRegressionTrainingSummary.html)
provides a summary for a
[`LogisticRegressionModel`](api/java/org/apache/spark/ml/classification/LogisticRegressionModel.html).
-Currently, only binary classification is supported and the
-summary must be explicitly cast to
-[`BinaryLogisticRegressionTrainingSummary`](api/java/org/apache/spark/ml/classification/BinaryLogisticRegressionTrainingSummary.html).
-Support for multiclass model summaries will be added in the future.
+In the case of binary classification, certain additional metrics are
+available, e.g. ROC curve. The binary summary can be accessed via the
+`binarySummary` method. See [`BinaryLogisticRegressionTrainingSummary`](api/java/org/apache/spark/ml/classification/BinaryLogisticRegressionTrainingSummary.html).
Continuing the earlier example:
@@ -125,7 +123,8 @@ Continuing the earlier example:
[`LogisticRegressionTrainingSummary`](api/python/pyspark.ml.html#pyspark.ml.classification.LogisticRegressionSummary)
provides a summary for a
[`LogisticRegressionModel`](api/python/pyspark.ml.html#pyspark.ml.classification.LogisticRegressionModel).
-Currently, only binary classification is supported. Support for multiclass model summaries will be added in the future.
+In the case of binary classification, certain additional metrics are
+available, e.g. ROC curve. See [`BinaryLogisticRegressionTrainingSummary`](api/python/pyspark.ml.html#pyspark.ml.classification.BinaryLogisticRegressionTrainingSummary).
Continuing the earlier example:
@@ -162,7 +161,8 @@ For a detailed derivation please see [here](https://en.wikipedia.org/wiki/Multin
**Examples**
The following example shows how to train a multiclass logistic regression
-model with elastic net regularization.
+model with elastic net regularization, as well as extract the multiclass
+training summary for evaluating the model.
<div class="codetabs">
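The per-label metrics the multiclass summary exposes (`falsePositiveRateByLabel`, `truePositiveRateByLabel`, `precisionByLabel`, `recallByLabel`) are standard one-vs-rest ratios computed from per-label confusion counts. A minimal pure-Python sketch of that computation, with made-up labels and predictions for illustration (no Spark required):

```python
def per_label_metrics(labels, predictions):
    """One-vs-rest TPR/FPR/precision per label, analogous to the
    per-label metrics on LogisticRegressionTrainingSummary."""
    n = len(labels)
    metrics = {}
    for lbl in sorted(set(labels)):
        tp = sum(1 for y, p in zip(labels, predictions) if y == lbl and p == lbl)
        fp = sum(1 for y, p in zip(labels, predictions) if y != lbl and p == lbl)
        fn = sum(1 for y, p in zip(labels, predictions) if y == lbl and p != lbl)
        tn = n - tp - fp - fn
        metrics[lbl] = {
            "tpr": tp / (tp + fn) if tp + fn else 0.0,        # recallByLabel
            "fpr": fp / (fp + tn) if fp + tn else 0.0,        # falsePositiveRateByLabel
            "precision": tp / (tp + fp) if tp + fp else 0.0,  # precisionByLabel
        }
    return metrics

# Illustrative data: three classes, six examples.
labels      = [0, 0, 1, 1, 2, 2]
predictions = [0, 1, 1, 1, 2, 0]
m = per_label_metrics(labels, predictions)
print(m[1]["precision"])  # label 1 predicted 3 times, correct twice -> 2/3
```

The names `per_label_metrics` and the sample data are invented for this sketch; only the metric definitions mirror the summary API.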
http://git-wip-us.apache.org/repos/asf/spark/blob/5056877e/examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionSummaryExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionSummaryExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionSummaryExample.java
index dee5679..1529da1 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionSummaryExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionSummaryExample.java
@@ -18,10 +18,9 @@
package org.apache.spark.examples.ml;
// $example on$
-import org.apache.spark.ml.classification.BinaryLogisticRegressionSummary;
+import org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary;
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.classification.LogisticRegressionModel;
-import org.apache.spark.ml.classification.LogisticRegressionTrainingSummary;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
@@ -50,7 +49,7 @@ public class JavaLogisticRegressionSummaryExample {
// $example on$
// Extract the summary from the returned LogisticRegressionModel instance trained in the earlier
// example
- LogisticRegressionTrainingSummary trainingSummary = lrModel.summary();
+ BinaryLogisticRegressionTrainingSummary trainingSummary = lrModel.binarySummary();
// Obtain the loss per iteration.
double[] objectiveHistory = trainingSummary.objectiveHistory();
@@ -58,21 +57,15 @@ public class JavaLogisticRegressionSummaryExample {
System.out.println(lossPerIteration);
}
- // Obtain the metrics useful to judge performance on test data.
- // We cast the summary to a BinaryLogisticRegressionSummary since the problem is a binary
- // classification problem.
- BinaryLogisticRegressionSummary binarySummary =
- (BinaryLogisticRegressionSummary) trainingSummary;
-
// Obtain the receiver-operating characteristic as a dataframe and areaUnderROC.
- Dataset<Row> roc = binarySummary.roc();
+ Dataset<Row> roc = trainingSummary.roc();
roc.show();
roc.select("FPR").show();
- System.out.println(binarySummary.areaUnderROC());
+ System.out.println(trainingSummary.areaUnderROC());
// Get the threshold corresponding to the maximum F-Measure and rerun LogisticRegression with
// this selected threshold.
- Dataset<Row> fMeasure = binarySummary.fMeasureByThreshold();
+ Dataset<Row> fMeasure = trainingSummary.fMeasureByThreshold();
double maxFMeasure = fMeasure.select(functions.max("F-Measure")).head().getDouble(0);
double bestThreshold = fMeasure.where(fMeasure.col("F-Measure").equalTo(maxFMeasure))
.select("threshold").head().getDouble(0);
http://git-wip-us.apache.org/repos/asf/spark/blob/5056877e/examples/src/main/java/org/apache/spark/examples/ml/JavaMulticlassLogisticRegressionWithElasticNetExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaMulticlassLogisticRegressionWithElasticNetExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaMulticlassLogisticRegressionWithElasticNetExample.java
index da410cb..801a82c 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaMulticlassLogisticRegressionWithElasticNetExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaMulticlassLogisticRegressionWithElasticNetExample.java
@@ -20,6 +20,7 @@ package org.apache.spark.examples.ml;
// $example on$
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.classification.LogisticRegressionModel;
+import org.apache.spark.ml.classification.LogisticRegressionTrainingSummary;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
@@ -48,6 +49,67 @@ public class JavaMulticlassLogisticRegressionWithElasticNetExample {
// Print the coefficients and intercept for multinomial logistic regression
System.out.println("Coefficients: \n"
+ lrModel.coefficientMatrix() + " \nIntercept: " + lrModel.interceptVector());
+ LogisticRegressionTrainingSummary trainingSummary = lrModel.summary();
+
+ // Obtain the loss per iteration.
+ double[] objectiveHistory = trainingSummary.objectiveHistory();
+ for (double lossPerIteration : objectiveHistory) {
+ System.out.println(lossPerIteration);
+ }
+
+ // for multiclass, we can inspect metrics on a per-label basis
+ System.out.println("False positive rate by label:");
+ int i = 0;
+ double[] fprLabel = trainingSummary.falsePositiveRateByLabel();
+ for (double fpr : fprLabel) {
+ System.out.println("label " + i + ": " + fpr);
+ i++;
+ }
+
+ System.out.println("True positive rate by label:");
+ i = 0;
+ double[] tprLabel = trainingSummary.truePositiveRateByLabel();
+ for (double tpr : tprLabel) {
+ System.out.println("label " + i + ": " + tpr);
+ i++;
+ }
+
+ System.out.println("Precision by label:");
+ i = 0;
+ double[] precLabel = trainingSummary.precisionByLabel();
+ for (double prec : precLabel) {
+ System.out.println("label " + i + ": " + prec);
+ i++;
+ }
+
+ System.out.println("Recall by label:");
+ i = 0;
+ double[] recLabel = trainingSummary.recallByLabel();
+ for (double rec : recLabel) {
+ System.out.println("label " + i + ": " + rec);
+ i++;
+ }
+
+ System.out.println("F-measure by label:");
+ i = 0;
+ double[] fLabel = trainingSummary.fMeasureByLabel();
+ for (double f : fLabel) {
+ System.out.println("label " + i + ": " + f);
+ i++;
+ }
+
+ double accuracy = trainingSummary.accuracy();
+ double falsePositiveRate = trainingSummary.weightedFalsePositiveRate();
+ double truePositiveRate = trainingSummary.weightedTruePositiveRate();
+ double fMeasure = trainingSummary.weightedFMeasure();
+ double precision = trainingSummary.weightedPrecision();
+ double recall = trainingSummary.weightedRecall();
+ System.out.println("Accuracy: " + accuracy);
+ System.out.println("FPR: " + falsePositiveRate);
+ System.out.println("TPR: " + truePositiveRate);
+ System.out.println("F-measure: " + fMeasure);
+ System.out.println("Precision: " + precision);
+ System.out.println("Recall: " + recall);
// $example off$
spark.stop();
http://git-wip-us.apache.org/repos/asf/spark/blob/5056877e/examples/src/main/python/ml/multiclass_logistic_regression_with_elastic_net.py
----------------------------------------------------------------------
diff --git a/examples/src/main/python/ml/multiclass_logistic_regression_with_elastic_net.py b/examples/src/main/python/ml/multiclass_logistic_regression_with_elastic_net.py
index bb9cd82..bec9860 100644
--- a/examples/src/main/python/ml/multiclass_logistic_regression_with_elastic_net.py
+++ b/examples/src/main/python/ml/multiclass_logistic_regression_with_elastic_net.py
@@ -43,6 +43,44 @@ if __name__ == "__main__":
# Print the coefficients and intercept for multinomial logistic regression
print("Coefficients: \n" + str(lrModel.coefficientMatrix))
print("Intercept: " + str(lrModel.interceptVector))
+
+ trainingSummary = lrModel.summary
+
+ # Obtain the objective per iteration
+ objectiveHistory = trainingSummary.objectiveHistory
+ print("objectiveHistory:")
+ for objective in objectiveHistory:
+ print(objective)
+
+ # for multiclass, we can inspect metrics on a per-label basis
+ print("False positive rate by label:")
+ for i, rate in enumerate(trainingSummary.falsePositiveRateByLabel):
+ print("label %d: %s" % (i, rate))
+
+ print("True positive rate by label:")
+ for i, rate in enumerate(trainingSummary.truePositiveRateByLabel):
+ print("label %d: %s" % (i, rate))
+
+ print("Precision by label:")
+ for i, prec in enumerate(trainingSummary.precisionByLabel):
+ print("label %d: %s" % (i, prec))
+
+ print("Recall by label:")
+ for i, rec in enumerate(trainingSummary.recallByLabel):
+ print("label %d: %s" % (i, rec))
+
+ print("F-measure by label:")
+ for i, f in enumerate(trainingSummary.fMeasureByLabel()):
+ print("label %d: %s" % (i, f))
+
+ accuracy = trainingSummary.accuracy
+ falsePositiveRate = trainingSummary.weightedFalsePositiveRate
+ truePositiveRate = trainingSummary.weightedTruePositiveRate
+ fMeasure = trainingSummary.weightedFMeasure()
+ precision = trainingSummary.weightedPrecision
+ recall = trainingSummary.weightedRecall
+ print("Accuracy: %s\nFPR: %s\nTPR: %s\nF-measure: %s\nPrecision: %s\nRecall: %s"
+ % (accuracy, falsePositiveRate, truePositiveRate, fMeasure, precision, recall))
# $example off$
spark.stop()
http://git-wip-us.apache.org/repos/asf/spark/blob/5056877e/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionSummaryExample.scala
----------------------------------------------------------------------
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionSummaryExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionSummaryExample.scala
index 1740a0d..0368dcb 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionSummaryExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionSummaryExample.scala
@@ -19,7 +19,7 @@
package org.apache.spark.examples.ml
// $example on$
-import org.apache.spark.ml.classification.{BinaryLogisticRegressionSummary, LogisticRegression}
+import org.apache.spark.ml.classification.LogisticRegression
// $example off$
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.max
@@ -47,25 +47,20 @@ object LogisticRegressionSummaryExample {
// $example on$
// Extract the summary from the returned LogisticRegressionModel instance trained in the earlier
// example
- val trainingSummary = lrModel.summary
+ val trainingSummary = lrModel.binarySummary
// Obtain the objective per iteration.
val objectiveHistory = trainingSummary.objectiveHistory
println("objectiveHistory:")
objectiveHistory.foreach(loss => println(loss))
- // Obtain the metrics useful to judge performance on test data.
- // We cast the summary to a BinaryLogisticRegressionSummary since the problem is a
- // binary classification problem.
- val binarySummary = trainingSummary.asInstanceOf[BinaryLogisticRegressionSummary]
-
// Obtain the receiver-operating characteristic as a dataframe and areaUnderROC.
- val roc = binarySummary.roc
+ val roc = trainingSummary.roc
roc.show()
- println(s"areaUnderROC: ${binarySummary.areaUnderROC}")
+ println(s"areaUnderROC: ${trainingSummary.areaUnderROC}")
// Set the model threshold to maximize F-Measure
- val fMeasure = binarySummary.fMeasureByThreshold
+ val fMeasure = trainingSummary.fMeasureByThreshold
val maxFMeasure = fMeasure.select(max("F-Measure")).head().getDouble(0)
val bestThreshold = fMeasure.where($"F-Measure" === maxFMeasure)
.select("threshold").head().getDouble(0)
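`areaUnderROC` in the binary summary example is the area under the (FPR, TPR) curve returned by `roc`; the trapezoidal rule recovers it from those points. A standalone sketch with invented ROC points:

```python
def area_under_roc(points):
    """Trapezoidal area under an ROC curve given as (FPR, TPR) points,
    analogous to the summary's areaUnderROC."""
    pts = sorted(points)  # order by FPR
    area = 0.0
    for (x0, y0), (x1, y1) in zip(pts, pts[1:]):
        area += (x1 - x0) * (y0 + y1) / 2.0
    return area

# Hypothetical ROC points from (0,0) to (1,1).
roc_points = [(0.0, 0.0), (0.2, 0.6), (0.5, 0.9), (1.0, 1.0)]
print(area_under_roc(roc_points))
```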
http://git-wip-us.apache.org/repos/asf/spark/blob/5056877e/examples/src/main/scala/org/apache/spark/examples/ml/MulticlassLogisticRegressionWithElasticNetExample.scala
----------------------------------------------------------------------
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/MulticlassLogisticRegressionWithElasticNetExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/MulticlassLogisticRegressionWithElasticNetExample.scala
index 3e61dbe..1f7dbdd 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/MulticlassLogisticRegressionWithElasticNetExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/MulticlassLogisticRegressionWithElasticNetExample.scala
@@ -49,6 +49,49 @@ object MulticlassLogisticRegressionWithElasticNetExample {
// Print the coefficients and intercept for multinomial logistic regression
println(s"Coefficients: \n${lrModel.coefficientMatrix}")
println(s"Intercepts: \n${lrModel.interceptVector}")
+
+ val trainingSummary = lrModel.summary
+
+ // Obtain the objective per iteration
+ val objectiveHistory = trainingSummary.objectiveHistory
+ println("objectiveHistory:")
+ objectiveHistory.foreach(println)
+
+ // for multiclass, we can inspect metrics on a per-label basis
+ println("False positive rate by label:")
+ trainingSummary.falsePositiveRateByLabel.zipWithIndex.foreach { case (rate, label) =>
+ println(s"label $label: $rate")
+ }
+
+ println("True positive rate by label:")
+ trainingSummary.truePositiveRateByLabel.zipWithIndex.foreach { case (rate, label) =>
+ println(s"label $label: $rate")
+ }
+
+ println("Precision by label:")
+ trainingSummary.precisionByLabel.zipWithIndex.foreach { case (prec, label) =>
+ println(s"label $label: $prec")
+ }
+
+ println("Recall by label:")
+ trainingSummary.recallByLabel.zipWithIndex.foreach { case (rec, label) =>
+ println(s"label $label: $rec")
+ }
+
+ println("F-measure by label:")
+ trainingSummary.fMeasureByLabel.zipWithIndex.foreach { case (f, label) =>
+ println(s"label $label: $f")
+ }
+
+ val accuracy = trainingSummary.accuracy
+ val falsePositiveRate = trainingSummary.weightedFalsePositiveRate
+ val truePositiveRate = trainingSummary.weightedTruePositiveRate
+ val fMeasure = trainingSummary.weightedFMeasure
+ val precision = trainingSummary.weightedPrecision
+ val recall = trainingSummary.weightedRecall
+ println(s"Accuracy: $accuracy\nFPR: $falsePositiveRate\nTPR: $truePositiveRate\n" +
+ s"F-measure: $fMeasure\nPrecision: $precision\nRecall: $recall")
// $example off$
spark.stop()