Posted to commits@spark.apache.org by ml...@apache.org on 2018/01/30 07:02:22 UTC

spark git commit: [SPARK-23138][ML][DOC] Multiclass logistic regression summary example and user guide

Repository: spark
Updated Branches:
  refs/heads/master 8b983243e -> 5056877e8


[SPARK-23138][ML][DOC] Multiclass logistic regression summary example and user guide

## What changes were proposed in this pull request?

The user guide and examples are updated to reflect the multiclass logistic regression summary, which was added in [SPARK-17139](https://issues.apache.org/jira/browse/SPARK-17139).

I did not write a separate summary example. Instead, I added the summary code to the existing multiclass example, since a standalone summary example seems unnecessary.
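The per-label metrics used in the updated multiclass examples (`falsePositiveRateByLabel`, `truePositiveRateByLabel`, `precisionByLabel`, `recallByLabel`, `fMeasureByLabel`) treat each label one-vs-rest. The plain-Python sketch below recomputes them from `(prediction, label)` pairs to illustrate the definitions. It is not Spark's implementation: `per_label_metrics` is a hypothetical helper, instances are assumed unweighted, labels are taken from the true labels and sorted ascending, and a zero denominator yields 0.0 here by choice of this sketch.

```python
from collections import Counter


def per_label_metrics(pairs, beta=1.0):
    """Hypothetical helper: one-vs-rest metrics for each true label.

    `pairs` is a list of (prediction, label) tuples. Returns a dict that maps
    each label (sorted ascending) to its fpr, tpr, precision, recall and
    f_measure.
    """
    label_counts = Counter(label for _, label in pairs)
    total = len(pairs)
    # true positives: correct predictions, keyed by the true label
    tp = Counter(label for pred, label in pairs if pred == label)
    # false positives: wrong predictions, keyed by the predicted label
    fp = Counter(pred for pred, label in pairs if pred != label)

    metrics = {}
    for label in sorted(label_counts):
        positives = label_counts[label]
        negatives = total - positives
        # zero denominators are mapped to 0.0 (a choice of this sketch)
        fpr = fp[label] / negatives if negatives else 0.0
        recall = tp[label] / positives  # recall is the true positive rate
        predicted = tp[label] + fp[label]
        precision = tp[label] / predicted if predicted else 0.0
        denom = beta * beta * precision + recall
        f = (1 + beta * beta) * precision * recall / denom if denom else 0.0
        metrics[label] = {"fpr": fpr, "tpr": recall, "precision": precision,
                          "recall": recall, "f_measure": f}
    return metrics


if __name__ == "__main__":
    # toy (prediction, label) pairs for a three-class problem
    pairs = [(0.0, 0.0), (0.0, 0.0), (1.0, 0.0),
             (1.0, 1.0), (2.0, 1.0),
             (2.0, 2.0), (2.0, 2.0), (0.0, 2.0)]
    for label, m in per_label_metrics(pairs).items():
        print("label %s: %s" % (label, m))
```

Keying false positives by the predicted label is what makes the metric one-vs-rest: a wrong prediction of class 1 counts against class 1's precision and false positive rate, not against the true class.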

## How was this patch tested?

Docs and examples only. Ran all examples locally using spark-submit.

Author: sethah <sh...@cloudera.com>

Closes #20332 from sethah/multiclass_summary_example.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/5056877e
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/5056877e
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/5056877e

Branch: refs/heads/master
Commit: 5056877e8bea56dd0f4dc9e3385669e1e78b2925
Parents: 8b98324
Author: sethah <sh...@cloudera.com>
Authored: Tue Jan 30 09:02:16 2018 +0200
Committer: Nick Pentreath <ni...@za.ibm.com>
Committed: Tue Jan 30 09:02:16 2018 +0200

----------------------------------------------------------------------
 docs/ml-classification-regression.md            | 22 +++----
 .../JavaLogisticRegressionSummaryExample.java   | 17 ++----
 ...LogisticRegressionWithElasticNetExample.java | 62 ++++++++++++++++++++
 ...lass_logistic_regression_with_elastic_net.py | 38 ++++++++++++
 .../ml/LogisticRegressionSummaryExample.scala   | 15 ++---
 ...ogisticRegressionWithElasticNetExample.scala | 43 ++++++++++++++
 6 files changed, 164 insertions(+), 33 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/5056877e/docs/ml-classification-regression.md
----------------------------------------------------------------------
diff --git a/docs/ml-classification-regression.md b/docs/ml-classification-regression.md
index bf979f3..ddd2f4b 100644
--- a/docs/ml-classification-regression.md
+++ b/docs/ml-classification-regression.md
@@ -87,7 +87,7 @@ More details on parameters can be found in the [R API documentation](api/R/spark
 The `spark.ml` implementation of logistic regression also supports
 extracting a summary of the model over the training set. Note that the
 predictions and metrics which are stored as `DataFrame` in
-`BinaryLogisticRegressionSummary` are annotated `@transient` and hence
+`LogisticRegressionSummary` are annotated `@transient` and hence
 only available on the driver.
 
 <div class="codetabs">
@@ -97,10 +97,9 @@ only available on the driver.
 [`LogisticRegressionTrainingSummary`](api/scala/index.html#org.apache.spark.ml.classification.LogisticRegressionTrainingSummary)
 provides a summary for a
 [`LogisticRegressionModel`](api/scala/index.html#org.apache.spark.ml.classification.LogisticRegressionModel).
-Currently, only binary classification is supported and the
-summary must be explicitly cast to
-[`BinaryLogisticRegressionTrainingSummary`](api/scala/index.html#org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary).
-This will likely change when multiclass classification is supported.
+In the case of binary classification, certain additional metrics are
+available, e.g. ROC curve. The binary summary can be accessed via the
+`binarySummary` method. See [`BinaryLogisticRegressionTrainingSummary`](api/scala/index.html#org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary).
 
 Continuing the earlier example:
 
@@ -111,10 +110,9 @@ Continuing the earlier example:
 [`LogisticRegressionTrainingSummary`](api/java/org/apache/spark/ml/classification/LogisticRegressionTrainingSummary.html)
 provides a summary for a
 [`LogisticRegressionModel`](api/java/org/apache/spark/ml/classification/LogisticRegressionModel.html).
-Currently, only binary classification is supported and the
-summary must be explicitly cast to
-[`BinaryLogisticRegressionTrainingSummary`](api/java/org/apache/spark/ml/classification/BinaryLogisticRegressionTrainingSummary.html). 
-Support for multiclass model summaries will be added in the future.
+In the case of binary classification, certain additional metrics are
+available, e.g. ROC curve. The binary summary can be accessed via the
+`binarySummary` method. See [`BinaryLogisticRegressionTrainingSummary`](api/java/org/apache/spark/ml/classification/BinaryLogisticRegressionTrainingSummary.html).
 
 Continuing the earlier example:
 
@@ -125,7 +123,8 @@ Continuing the earlier example:
 [`LogisticRegressionTrainingSummary`](api/python/pyspark.ml.html#pyspark.ml.classification.LogisticRegressionSummary)
 provides a summary for a
 [`LogisticRegressionModel`](api/python/pyspark.ml.html#pyspark.ml.classification.LogisticRegressionModel).
-Currently, only binary classification is supported. Support for multiclass model summaries will be added in the future.
+In the case of binary classification, certain additional metrics are
+available, e.g. ROC curve. See [`BinaryLogisticRegressionTrainingSummary`](api/python/pyspark.ml.html#pyspark.ml.classification.BinaryLogisticRegressionTrainingSummary).
 
 Continuing the earlier example:
 
@@ -162,7 +161,8 @@ For a detailed derivation please see [here](https://en.wikipedia.org/wiki/Multin
 **Examples**
 
 The following example shows how to train a multiclass logistic regression 
-model with elastic net regularization.
+model with elastic net regularization, as well as extract the multiclass
+training summary for evaluating the model.
 
 <div class="codetabs">
 

http://git-wip-us.apache.org/repos/asf/spark/blob/5056877e/examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionSummaryExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionSummaryExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionSummaryExample.java
index dee5679..1529da1 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionSummaryExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionSummaryExample.java
@@ -18,10 +18,9 @@
 package org.apache.spark.examples.ml;
 
 // $example on$
-import org.apache.spark.ml.classification.BinaryLogisticRegressionSummary;
+import org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary;
 import org.apache.spark.ml.classification.LogisticRegression;
 import org.apache.spark.ml.classification.LogisticRegressionModel;
-import org.apache.spark.ml.classification.LogisticRegressionTrainingSummary;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.SparkSession;
@@ -50,7 +49,7 @@ public class JavaLogisticRegressionSummaryExample {
     // $example on$
     // Extract the summary from the returned LogisticRegressionModel instance trained in the earlier
     // example
-    LogisticRegressionTrainingSummary trainingSummary = lrModel.summary();
+    BinaryLogisticRegressionTrainingSummary trainingSummary = lrModel.binarySummary();
 
     // Obtain the loss per iteration.
     double[] objectiveHistory = trainingSummary.objectiveHistory();
@@ -58,21 +57,15 @@ public class JavaLogisticRegressionSummaryExample {
       System.out.println(lossPerIteration);
     }
 
-    // Obtain the metrics useful to judge performance on test data.
-    // We cast the summary to a BinaryLogisticRegressionSummary since the problem is a binary
-    // classification problem.
-    BinaryLogisticRegressionSummary binarySummary =
-      (BinaryLogisticRegressionSummary) trainingSummary;
-
     // Obtain the receiver-operating characteristic as a dataframe and areaUnderROC.
-    Dataset<Row> roc = binarySummary.roc();
+    Dataset<Row> roc = trainingSummary.roc();
     roc.show();
     roc.select("FPR").show();
-    System.out.println(binarySummary.areaUnderROC());
+    System.out.println(trainingSummary.areaUnderROC());
 
     // Get the threshold corresponding to the maximum F-Measure and rerun LogisticRegression with
     // this selected threshold.
-    Dataset<Row> fMeasure = binarySummary.fMeasureByThreshold();
+    Dataset<Row> fMeasure = trainingSummary.fMeasureByThreshold();
     double maxFMeasure = fMeasure.select(functions.max("F-Measure")).head().getDouble(0);
     double bestThreshold = fMeasure.where(fMeasure.col("F-Measure").equalTo(maxFMeasure))
       .select("threshold").head().getDouble(0);

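The binary summary examples pick the threshold whose F-measure is largest and then set that threshold on the model. The same selection over a plain list of `(threshold, F-measure)` rows can be sketched in Python as follows. The helper name `best_threshold` and the sample rows are made up for illustration; the list stands in for the rows of `fMeasureByThreshold`.

```python
def best_threshold(f_measure_by_threshold):
    """Hypothetical helper: return the threshold with the highest F-measure.

    `f_measure_by_threshold` is a non-empty list of (threshold, f_measure)
    pairs. Like the examples' where(...).head(), the first row that attains
    the maximum wins.
    """
    max_f = max(f for _, f in f_measure_by_threshold)
    for threshold, f in f_measure_by_threshold:
        if f == max_f:
            return threshold


if __name__ == "__main__":
    # made-up rows, not output from any Spark run
    rows = [(0.9, 0.40), (0.7, 0.55), (0.5, 0.62), (0.3, 0.58)]
    print(best_threshold(rows))
```

Computing the maximum first and then filtering mirrors the two-step `max` / `where` query in the Java and Scala examples.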
http://git-wip-us.apache.org/repos/asf/spark/blob/5056877e/examples/src/main/java/org/apache/spark/examples/ml/JavaMulticlassLogisticRegressionWithElasticNetExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaMulticlassLogisticRegressionWithElasticNetExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaMulticlassLogisticRegressionWithElasticNetExample.java
index da410cb..801a82c 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaMulticlassLogisticRegressionWithElasticNetExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaMulticlassLogisticRegressionWithElasticNetExample.java
@@ -20,6 +20,7 @@ package org.apache.spark.examples.ml;
 // $example on$
 import org.apache.spark.ml.classification.LogisticRegression;
 import org.apache.spark.ml.classification.LogisticRegressionModel;
+import org.apache.spark.ml.classification.LogisticRegressionTrainingSummary;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.SparkSession;
@@ -48,6 +49,67 @@ public class JavaMulticlassLogisticRegressionWithElasticNetExample {
         // Print the coefficients and intercept for multinomial logistic regression
         System.out.println("Coefficients: \n"
                 + lrModel.coefficientMatrix() + " \nIntercept: " + lrModel.interceptVector());
+        LogisticRegressionTrainingSummary trainingSummary = lrModel.summary();
+
+        // Obtain the loss per iteration.
+        double[] objectiveHistory = trainingSummary.objectiveHistory();
+        for (double lossPerIteration : objectiveHistory) {
+            System.out.println(lossPerIteration);
+        }
+
+        // for multiclass, we can inspect metrics on a per-label basis
+        System.out.println("False positive rate by label:");
+        int i = 0;
+        double[] fprLabel = trainingSummary.falsePositiveRateByLabel();
+        for (double fpr : fprLabel) {
+            System.out.println("label " + i + ": " + fpr);
+            i++;
+        }
+
+        System.out.println("True positive rate by label:");
+        i = 0;
+        double[] tprLabel = trainingSummary.truePositiveRateByLabel();
+        for (double tpr : tprLabel) {
+            System.out.println("label " + i + ": " + tpr);
+            i++;
+        }
+
+        System.out.println("Precision by label:");
+        i = 0;
+        double[] precLabel = trainingSummary.precisionByLabel();
+        for (double prec : precLabel) {
+            System.out.println("label " + i + ": " + prec);
+            i++;
+        }
+
+        System.out.println("Recall by label:");
+        i = 0;
+        double[] recLabel = trainingSummary.recallByLabel();
+        for (double rec : recLabel) {
+            System.out.println("label " + i + ": " + rec);
+            i++;
+        }
+
+        System.out.println("F-measure by label:");
+        i = 0;
+        double[] fLabel = trainingSummary.fMeasureByLabel();
+        for (double f : fLabel) {
+            System.out.println("label " + i + ": " + f);
+            i++;
+        }
+
+        double accuracy = trainingSummary.accuracy();
+        double falsePositiveRate = trainingSummary.weightedFalsePositiveRate();
+        double truePositiveRate = trainingSummary.weightedTruePositiveRate();
+        double fMeasure = trainingSummary.weightedFMeasure();
+        double precision = trainingSummary.weightedPrecision();
+        double recall = trainingSummary.weightedRecall();
+        System.out.println("Accuracy: " + accuracy);
+        System.out.println("FPR: " + falsePositiveRate);
+        System.out.println("TPR: " + truePositiveRate);
+        System.out.println("F-measure: " + fMeasure);
+        System.out.println("Precision: " + precision);
+        System.out.println("Recall: " + recall);
         // $example off$
 
         spark.stop();

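The `weighted*` values printed at the end of the multiclass examples collapse the per-label metrics into single numbers. Assuming the usual `MulticlassMetrics` convention, where each label is weighted by its share of the true labels, the averaging step can be sketched in plain Python. `weighted_average` is a hypothetical name, and the toy numbers are illustrative only.

```python
def weighted_average(per_label, label_counts):
    """Hypothetical helper: average per-label values, weighting each label
    by the fraction of instances whose true label it is.

    `per_label` and `label_counts` are dicts keyed by label.
    """
    total = sum(label_counts.values())
    return sum(value * label_counts[label] / total
               for label, value in per_label.items())


if __name__ == "__main__":
    # per-label recall and true-label counts for a toy three-class problem
    recall = {0.0: 2 / 3, 1.0: 1 / 2, 2.0: 2 / 3}
    counts = {0.0: 3, 1.0: 2, 2.0: 3}
    # 5 of the 8 toy instances are correct, so weighted recall is 5/8
    print(weighted_average(recall, counts))
```

Under this weighting, weighted recall (and so weighted true positive rate) reduces to the number of correct predictions divided by the total, which is why it matches accuracy for unweighted data.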
http://git-wip-us.apache.org/repos/asf/spark/blob/5056877e/examples/src/main/python/ml/multiclass_logistic_regression_with_elastic_net.py
----------------------------------------------------------------------
diff --git a/examples/src/main/python/ml/multiclass_logistic_regression_with_elastic_net.py b/examples/src/main/python/ml/multiclass_logistic_regression_with_elastic_net.py
index bb9cd82..bec9860 100644
--- a/examples/src/main/python/ml/multiclass_logistic_regression_with_elastic_net.py
+++ b/examples/src/main/python/ml/multiclass_logistic_regression_with_elastic_net.py
@@ -43,6 +43,44 @@ if __name__ == "__main__":
     # Print the coefficients and intercept for multinomial logistic regression
     print("Coefficients: \n" + str(lrModel.coefficientMatrix))
     print("Intercept: " + str(lrModel.interceptVector))
+
+    trainingSummary = lrModel.summary
+
+    # Obtain the objective per iteration
+    objectiveHistory = trainingSummary.objectiveHistory
+    print("objectiveHistory:")
+    for objective in objectiveHistory:
+        print(objective)
+
+    # for multiclass, we can inspect metrics on a per-label basis
+    print("False positive rate by label:")
+    for i, rate in enumerate(trainingSummary.falsePositiveRateByLabel):
+        print("label %d: %s" % (i, rate))
+
+    print("True positive rate by label:")
+    for i, rate in enumerate(trainingSummary.truePositiveRateByLabel):
+        print("label %d: %s" % (i, rate))
+
+    print("Precision by label:")
+    for i, prec in enumerate(trainingSummary.precisionByLabel):
+        print("label %d: %s" % (i, prec))
+
+    print("Recall by label:")
+    for i, rec in enumerate(trainingSummary.recallByLabel):
+        print("label %d: %s" % (i, rec))
+
+    print("F-measure by label:")
+    for i, f in enumerate(trainingSummary.fMeasureByLabel()):
+        print("label %d: %s" % (i, f))
+
+    accuracy = trainingSummary.accuracy
+    falsePositiveRate = trainingSummary.weightedFalsePositiveRate
+    truePositiveRate = trainingSummary.weightedTruePositiveRate
+    fMeasure = trainingSummary.weightedFMeasure()
+    precision = trainingSummary.weightedPrecision
+    recall = trainingSummary.weightedRecall
+    print("Accuracy: %s\nFPR: %s\nTPR: %s\nF-measure: %s\nPrecision: %s\nRecall: %s"
+          % (accuracy, falsePositiveRate, truePositiveRate, fMeasure, precision, recall))
     # $example off$
 
     spark.stop()

http://git-wip-us.apache.org/repos/asf/spark/blob/5056877e/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionSummaryExample.scala
----------------------------------------------------------------------
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionSummaryExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionSummaryExample.scala
index 1740a0d..0368dcb 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionSummaryExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionSummaryExample.scala
@@ -19,7 +19,7 @@
 package org.apache.spark.examples.ml
 
 // $example on$
-import org.apache.spark.ml.classification.{BinaryLogisticRegressionSummary, LogisticRegression}
+import org.apache.spark.ml.classification.LogisticRegression
 // $example off$
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.functions.max
@@ -47,25 +47,20 @@ object LogisticRegressionSummaryExample {
     // $example on$
     // Extract the summary from the returned LogisticRegressionModel instance trained in the earlier
     // example
-    val trainingSummary = lrModel.summary
+    val trainingSummary = lrModel.binarySummary
 
     // Obtain the objective per iteration.
     val objectiveHistory = trainingSummary.objectiveHistory
     println("objectiveHistory:")
     objectiveHistory.foreach(loss => println(loss))
 
-    // Obtain the metrics useful to judge performance on test data.
-    // We cast the summary to a BinaryLogisticRegressionSummary since the problem is a
-    // binary classification problem.
-    val binarySummary = trainingSummary.asInstanceOf[BinaryLogisticRegressionSummary]
-
     // Obtain the receiver-operating characteristic as a dataframe and areaUnderROC.
-    val roc = binarySummary.roc
+    val roc = trainingSummary.roc
     roc.show()
-    println(s"areaUnderROC: ${binarySummary.areaUnderROC}")
+    println(s"areaUnderROC: ${trainingSummary.areaUnderROC}")
 
     // Set the model threshold to maximize F-Measure
-    val fMeasure = binarySummary.fMeasureByThreshold
+    val fMeasure = trainingSummary.fMeasureByThreshold
     val maxFMeasure = fMeasure.select(max("F-Measure")).head().getDouble(0)
     val bestThreshold = fMeasure.where($"F-Measure" === maxFMeasure)
       .select("threshold").head().getDouble(0)

http://git-wip-us.apache.org/repos/asf/spark/blob/5056877e/examples/src/main/scala/org/apache/spark/examples/ml/MulticlassLogisticRegressionWithElasticNetExample.scala
----------------------------------------------------------------------
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/MulticlassLogisticRegressionWithElasticNetExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/MulticlassLogisticRegressionWithElasticNetExample.scala
index 3e61dbe..1f7dbdd 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/MulticlassLogisticRegressionWithElasticNetExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/MulticlassLogisticRegressionWithElasticNetExample.scala
@@ -49,6 +49,49 @@ object MulticlassLogisticRegressionWithElasticNetExample {
     // Print the coefficients and intercept for multinomial logistic regression
     println(s"Coefficients: \n${lrModel.coefficientMatrix}")
     println(s"Intercepts: \n${lrModel.interceptVector}")
+
+    val trainingSummary = lrModel.summary
+
+    // Obtain the objective per iteration
+    val objectiveHistory = trainingSummary.objectiveHistory
+    println("objectiveHistory:")
+    objectiveHistory.foreach(println)
+
+    // for multiclass, we can inspect metrics on a per-label basis
+    println("False positive rate by label:")
+    trainingSummary.falsePositiveRateByLabel.zipWithIndex.foreach { case (rate, label) =>
+      println(s"label $label: $rate")
+    }
+
+    println("True positive rate by label:")
+    trainingSummary.truePositiveRateByLabel.zipWithIndex.foreach { case (rate, label) =>
+      println(s"label $label: $rate")
+    }
+
+    println("Precision by label:")
+    trainingSummary.precisionByLabel.zipWithIndex.foreach { case (prec, label) =>
+      println(s"label $label: $prec")
+    }
+
+    println("Recall by label:")
+    trainingSummary.recallByLabel.zipWithIndex.foreach { case (rec, label) =>
+      println(s"label $label: $rec")
+    }
+
+
+    println("F-measure by label:")
+    trainingSummary.fMeasureByLabel.zipWithIndex.foreach { case (f, label) =>
+      println(s"label $label: $f")
+    }
+
+    val accuracy = trainingSummary.accuracy
+    val falsePositiveRate = trainingSummary.weightedFalsePositiveRate
+    val truePositiveRate = trainingSummary.weightedTruePositiveRate
+    val fMeasure = trainingSummary.weightedFMeasure
+    val precision = trainingSummary.weightedPrecision
+    val recall = trainingSummary.weightedRecall
+    println(s"Accuracy: $accuracy\nFPR: $falsePositiveRate\nTPR: $truePositiveRate\n" +
+      s"F-measure: $fMeasure\nPrecision: $precision\nRecall: $recall")
     // $example off$
 
     spark.stop()

