You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by jk...@apache.org on 2017/08/28 20:31:08 UTC

spark git commit: [SPARK-17139][ML] Add model summary for MultinomialLogisticRegression

Repository: spark
Updated Branches:
  refs/heads/master 73e64f7d5 -> c7270a46f


[SPARK-17139][ML] Add model summary for MultinomialLogisticRegression

## What changes were proposed in this pull request?

Add 4 traits, using the following hierarchy:
LogisticRegressionSummary
LogisticRegressionTrainingSummary: LogisticRegressionSummary
BinaryLogisticRegressionSummary: LogisticRegressionSummary
BinaryLogisticRegressionTrainingSummary: LogisticRegressionTrainingSummary, BinaryLogisticRegressionSummary

Public methods such as `def summary` return only the trait types listed above.

and then implement 4 concrete classes:
LogisticRegressionSummaryImpl (multiclass case)
LogisticRegressionTrainingSummaryImpl (multiclass case)
BinaryLogisticRegressionSummaryImpl (binary case).
BinaryLogisticRegressionTrainingSummaryImpl (binary case).

## How was this patch tested?

Existing tests & added tests.

Author: WeichenXu <We...@outlook.com>

Closes #15435 from WeichenXu123/mlor_summary.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/c7270a46
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/c7270a46
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/c7270a46

Branch: refs/heads/master
Commit: c7270a46fc340db62c87ddfc6568603d0b832845
Parents: 73e64f7
Author: Weichen Xu <we...@databricks.com>
Authored: Mon Aug 28 13:31:01 2017 -0700
Committer: Joseph K. Bradley <jo...@databricks.com>
Committed: Mon Aug 28 13:31:01 2017 -0700

----------------------------------------------------------------------
 .../ml/classification/LogisticRegression.scala  | 340 +++++++++++++++----
 .../LogisticRegressionSuite.scala               | 160 +++++++--
 .../ml/regression/LinearRegressionSuite.scala   |   2 +-
 project/MimaExcludes.scala                      |  21 +-
 4 files changed, 412 insertions(+), 111 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/c7270a46/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 21957d9..ffe4b52 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -22,7 +22,7 @@ import java.util.Locale
 import scala.collection.mutable
 
 import breeze.linalg.{DenseVector => BDV}
-import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, LBFGSB => BreezeLBFGSB, OWLQN => BreezeOWLQN}
+import breeze.optimize.{CachedDiffFunction, LBFGS => BreezeLBFGS, LBFGSB => BreezeLBFGSB, OWLQN => BreezeOWLQN}
 import org.apache.hadoop.fs.Path
 
 import org.apache.spark.SparkException
@@ -35,7 +35,7 @@ import org.apache.spark.ml.optim.loss.{L2Regularization, RDDLossFunction}
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util._
-import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
+import org.apache.spark.mllib.evaluation.{BinaryClassificationMetrics, MulticlassMetrics}
 import org.apache.spark.mllib.linalg.VectorImplicits._
 import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
 import org.apache.spark.mllib.util.MLUtils
@@ -882,21 +882,28 @@ class LogisticRegression @Since("1.2.0") (
 
     val model = copyValues(new LogisticRegressionModel(uid, coefficientMatrix, interceptVector,
       numClasses, isMultinomial))
-    // TODO: implement summary model for multinomial case
-    val m = if (!isMultinomial) {
-      val (summaryModel, probabilityColName) = model.findSummaryModelAndProbabilityCol()
-      val logRegSummary = new BinaryLogisticRegressionTrainingSummary(
+
+    val (summaryModel, probabilityColName, predictionColName) = model.findSummaryModel()
+    val logRegSummary = if (numClasses <= 2) {
+      new BinaryLogisticRegressionTrainingSummaryImpl(
         summaryModel.transform(dataset),
         probabilityColName,
+        predictionColName,
         $(labelCol),
         $(featuresCol),
         objectiveHistory)
-      model.setSummary(Some(logRegSummary))
     } else {
-      model
+      new LogisticRegressionTrainingSummaryImpl(
+        summaryModel.transform(dataset),
+        probabilityColName,
+        predictionColName,
+        $(labelCol),
+        $(featuresCol),
+        objectiveHistory)
     }
-    instr.logSuccess(m)
-    m
+    model.setSummary(Some(logRegSummary))
+    instr.logSuccess(model)
+    model
   }
 
   @Since("1.4.0")
@@ -1010,8 +1017,8 @@ class LogisticRegressionModel private[spark] (
   private var trainingSummary: Option[LogisticRegressionTrainingSummary] = None
 
   /**
-   * Gets summary of model on training set. An exception is
-   * thrown if `trainingSummary == None`.
+   * Gets summary of model on training set. An exception is thrown
+   * if `trainingSummary == None`.
    */
   @Since("1.5.0")
   def summary: LogisticRegressionTrainingSummary = trainingSummary.getOrElse {
@@ -1019,18 +1026,36 @@ class LogisticRegressionModel private[spark] (
   }
 
   /**
-   * If the probability column is set returns the current model and probability column,
-   * otherwise generates a new column and sets it as the probability column on a new copy
-   * of the current model.
+   * Gets summary of model on training set. An exception is thrown
+   * if `trainingSummary == None` or it is a multiclass model.
    */
-  private[classification] def findSummaryModelAndProbabilityCol():
-      (LogisticRegressionModel, String) = {
-    $(probabilityCol) match {
-      case "" =>
-        val probabilityColName = "probability_" + java.util.UUID.randomUUID.toString
-        (copy(ParamMap.empty).setProbabilityCol(probabilityColName), probabilityColName)
-      case p => (this, p)
+  @Since("2.3.0")
+  def binarySummary: BinaryLogisticRegressionTrainingSummary = summary match {
+    case b: BinaryLogisticRegressionTrainingSummary => b
+    case _ => // the training summary is not a binary one
+      throw new RuntimeException("Cannot create a binary summary for a non-binary model " +
+        s"(numClasses=${numClasses}), use summary instead.")
+  }
+
+  /**
+   * Returns a model whose probability and prediction columns are both set, together with
+   * those two column names. The current model is returned when both are already set;
+   * otherwise a copy is made and a generated name is assigned to each empty column.
+   */
+  private[classification] def findSummaryModel():
+      (LogisticRegressionModel, String, String) = {
+    // Copy the model only when at least one output column name has to be generated.
+    def fresh(prefix: String): String = prefix + java.util.UUID.randomUUID.toString
+    val summaryModel = ($(probabilityCol).isEmpty, $(predictionCol).isEmpty) match {
+      case (true, true) =>
+        copy(ParamMap.empty)
+          .setProbabilityCol(fresh("probability_"))
+          .setPredictionCol(fresh("prediction_"))
+      case (true, false) => copy(ParamMap.empty).setProbabilityCol(fresh("probability_"))
+      case (false, true) => copy(ParamMap.empty).setPredictionCol(fresh("prediction_"))
+      case (false, false) => this
     }
+    (summaryModel, summaryModel.getProbabilityCol, summaryModel.getPredictionCol)
   }
 
   private[classification]
@@ -1051,9 +1076,14 @@ class LogisticRegressionModel private[spark] (
   @Since("2.0.0")
   def evaluate(dataset: Dataset[_]): LogisticRegressionSummary = {
     // Handle possible missing or invalid prediction columns
-    val (summaryModel, probabilityColName) = findSummaryModelAndProbabilityCol()
-    new BinaryLogisticRegressionSummary(summaryModel.transform(dataset),
-      probabilityColName, $(labelCol), $(featuresCol))
+    val (summaryModel, probabilityColName, predictionColName) = findSummaryModel()
+    if (numClasses > 2) {
+      new LogisticRegressionSummaryImpl(summaryModel.transform(dataset),
+        probabilityColName, predictionColName, $(labelCol), $(featuresCol))
+    } else {
+      new BinaryLogisticRegressionSummaryImpl(summaryModel.transform(dataset),
+        probabilityColName, predictionColName, $(labelCol), $(featuresCol))
+    }
   }
 
   /**
@@ -1324,90 +1354,154 @@ private[ml] class MultiClassSummarizer extends Serializable {
 }
 
 /**
- * Abstraction for multinomial Logistic Regression Training results.
- * Currently, the training summary ignores the training weights except
- * for the objective trace.
- */
-sealed trait LogisticRegressionTrainingSummary extends LogisticRegressionSummary {
-
-  /** objective function (scaled loss + regularization) at each iteration. */
-  def objectiveHistory: Array[Double]
-
-  /** Number of training iterations until termination */
-  def totalIterations: Int = objectiveHistory.length
-
-}
-
-/**
- * Abstraction for Logistic Regression Results for a given model.
+ * :: Experimental ::
+ * Abstraction for logistic regression results for a given model.
  */
+@Experimental
 sealed trait LogisticRegressionSummary extends Serializable {
 
   /**
    * Dataframe output by the model's `transform` method.
    */
+  @Since("1.5.0")
   def predictions: DataFrame
 
   /** Field in "predictions" which gives the probability of each class as a vector. */
+  @Since("1.5.0")
   def probabilityCol: String
 
+  /** Field in "predictions" which gives the prediction of each instance. */
+  @Since("2.3.0")
+  def predictionCol: String
+
   /** Field in "predictions" which gives the true label of each instance (if available). */
+  @Since("1.5.0")
   def labelCol: String
 
   /** Field in "predictions" which gives the features of each instance as a vector. */
+  @Since("1.6.0")
   def featuresCol: String
 
+  @transient private val multiclassMetrics = {
+    new MulticlassMetrics(
+      predictions.select(
+        col(predictionCol),
+        col(labelCol).cast(DoubleType))
+        .rdd.map { case Row(prediction: Double, label: Double) => (prediction, label) })
+  }
+
+  /**
+   * Returns the sequence of labels in ascending order. This order matches the order used
+   * in metrics which are specified as arrays over labels, e.g., truePositiveRateByLabel.
+   *
+   * Note: In most cases, it will be values {0.0, 1.0, ..., numClasses-1}. However, if the
+   * training set is missing a label, then all of the arrays over labels
+   * (e.g., from truePositiveRateByLabel) will be of length numClasses-1 instead of the
+   * expected numClasses.
+   */
+  @Since("2.3.0")
+  def labels: Array[Double] = multiclassMetrics.labels
+
+  /** Returns true positive rate for each label (category). */
+  @Since("2.3.0")
+  def truePositiveRateByLabel: Array[Double] = recallByLabel
+
+  /** Returns false positive rate for each label (category). */
+  @Since("2.3.0")
+  def falsePositiveRateByLabel: Array[Double] = {
+    multiclassMetrics.labels.map(label => multiclassMetrics.falsePositiveRate(label))
+  }
+
+  /** Returns precision for each label (category). */
+  @Since("2.3.0")
+  def precisionByLabel: Array[Double] = {
+    multiclassMetrics.labels.map(label => multiclassMetrics.precision(label))
+  }
+
+  /** Returns recall for each label (category). */
+  @Since("2.3.0")
+  def recallByLabel: Array[Double] = {
+    multiclassMetrics.labels.map(label => multiclassMetrics.recall(label))
+  }
+
+  /** Returns f-measure for each label (category). */
+  @Since("2.3.0")
+  def fMeasureByLabel(beta: Double): Array[Double] = {
+    multiclassMetrics.labels.map(label => multiclassMetrics.fMeasure(label, beta))
+  }
+
+  /** Returns f1-measure for each label (category). */
+  @Since("2.3.0")
+  def fMeasureByLabel: Array[Double] = fMeasureByLabel(1.0)
+
+  /**
+   * Returns accuracy.
+   * (equals to the total number of correctly classified instances
+   * out of the total number of instances.)
+   */
+  @Since("2.3.0")
+  def accuracy: Double = multiclassMetrics.accuracy
+
+  /**
+   * Returns weighted true positive rate.
+   * (equals to precision, recall and f-measure)
+   */
+  @Since("2.3.0")
+  def weightedTruePositiveRate: Double = weightedRecall
+
+  /** Returns weighted false positive rate. */
+  @Since("2.3.0")
+  def weightedFalsePositiveRate: Double = multiclassMetrics.weightedFalsePositiveRate
+
+  /**
+   * Returns weighted averaged recall.
+   * (equals to precision, recall and f-measure)
+   */
+  @Since("2.3.0")
+  def weightedRecall: Double = multiclassMetrics.weightedRecall
+
+  /** Returns weighted averaged precision. */
+  @Since("2.3.0")
+  def weightedPrecision: Double = multiclassMetrics.weightedPrecision
+
+  /** Returns weighted averaged f-measure. */
+  @Since("2.3.0")
+  def weightedFMeasure(beta: Double): Double = multiclassMetrics.weightedFMeasure(beta)
+
+  /** Returns weighted averaged f1-measure. */
+  @Since("2.3.0")
+  def weightedFMeasure: Double = multiclassMetrics.weightedFMeasure(1.0)
 }
 
 /**
  * :: Experimental ::
- * Logistic regression training results.
- *
- * @param predictions dataframe output by the model's `transform` method.
- * @param probabilityCol field in "predictions" which gives the probability of
- *                       each class as a vector.
- * @param labelCol field in "predictions" which gives the true label of each instance.
- * @param featuresCol field in "predictions" which gives the features of each instance as a vector.
- * @param objectiveHistory objective function (scaled loss + regularization) at each iteration.
+ * Abstraction for multiclass logistic regression training results.
+ * Currently, the training summary ignores the training weights except
+ * for the objective trace.
  */
 @Experimental
-@Since("1.5.0")
-class BinaryLogisticRegressionTrainingSummary private[classification] (
-    predictions: DataFrame,
-    probabilityCol: String,
-    labelCol: String,
-    featuresCol: String,
-    @Since("1.5.0") val objectiveHistory: Array[Double])
-  extends BinaryLogisticRegressionSummary(predictions, probabilityCol, labelCol, featuresCol)
-  with LogisticRegressionTrainingSummary {
+sealed trait LogisticRegressionTrainingSummary extends LogisticRegressionSummary {
+
+  /** objective function (scaled loss + regularization) at each iteration. */
+  @Since("1.5.0")
+  def objectiveHistory: Array[Double]
+
+  /** Number of training iterations. */
+  @Since("1.5.0")
+  def totalIterations: Int = objectiveHistory.length
 
 }
 
 /**
  * :: Experimental ::
- * Binary Logistic regression results for a given model.
- *
- * @param predictions dataframe output by the model's `transform` method.
- * @param probabilityCol field in "predictions" which gives the probability of
- *                       each class as a vector.
- * @param labelCol field in "predictions" which gives the true label of each instance.
- * @param featuresCol field in "predictions" which gives the features of each instance as a vector.
+ * Abstraction for binary logistic regression results for a given model.
  */
 @Experimental
-@Since("1.5.0")
-class BinaryLogisticRegressionSummary private[classification] (
-    @Since("1.5.0") @transient override val predictions: DataFrame,
-    @Since("1.5.0") override val probabilityCol: String,
-    @Since("1.5.0") override val labelCol: String,
-    @Since("1.6.0") override val featuresCol: String) extends LogisticRegressionSummary {
-
+sealed trait BinaryLogisticRegressionSummary extends LogisticRegressionSummary {
 
   private val sparkSession = predictions.sparkSession
   import sparkSession.implicits._
 
-  /**
-   * Returns a BinaryClassificationMetrics object.
-   */
   // TODO: Allow the user to vary the number of bins using a setBins method in
   // BinaryClassificationMetrics. For now the default is set to 100.
   @transient private val binaryMetrics = new BinaryClassificationMetrics(
@@ -1484,3 +1578,99 @@ class BinaryLogisticRegressionSummary private[classification] (
     binaryMetrics.recallByThreshold().toDF("threshold", "recall")
   }
 }
+
+/**
+ * :: Experimental ::
+ * Abstraction for binary logistic regression training results.
+ * Currently, the training summary ignores the training weights except
+ * for the objective trace.
+ */
+@Experimental
+sealed trait BinaryLogisticRegressionTrainingSummary extends BinaryLogisticRegressionSummary
+  with LogisticRegressionTrainingSummary
+
+/**
+ * Multiclass logistic regression training results.
+ *
+ * @param predictions dataframe output by the model's `transform` method.
+ * @param probabilityCol field in "predictions" which gives the probability of
+ *                       each class as a vector.
+ * @param predictionCol field in "predictions" which gives the prediction for a data instance as a
+ *                      double.
+ * @param labelCol field in "predictions" which gives the true label of each instance.
+ * @param featuresCol field in "predictions" which gives the features of each instance as a vector.
+ * @param objectiveHistory objective function (scaled loss + regularization) at each iteration.
+ */
+private class LogisticRegressionTrainingSummaryImpl(
+    predictions: DataFrame,
+    probabilityCol: String,
+    predictionCol: String,
+    labelCol: String,
+    featuresCol: String,
+    override val objectiveHistory: Array[Double])
+  extends LogisticRegressionSummaryImpl(
+    predictions, probabilityCol, predictionCol, labelCol, featuresCol)
+  with LogisticRegressionTrainingSummary
+
+/**
+ * Multiclass logistic regression results for a given model.
+ *
+ * @param predictions dataframe output by the model's `transform` method.
+ * @param probabilityCol field in "predictions" which gives the probability of
+ *                       each class as a vector.
+ * @param predictionCol field in "predictions" which gives the prediction for a data instance as a
+ *                      double.
+ * @param labelCol field in "predictions" which gives the true label of each instance.
+ * @param featuresCol field in "predictions" which gives the features of each instance as a vector.
+ */
+private class LogisticRegressionSummaryImpl(
+    @transient override val predictions: DataFrame,
+    override val probabilityCol: String,
+    override val predictionCol: String,
+    override val labelCol: String,
+    override val featuresCol: String)
+  extends LogisticRegressionSummary
+
+/**
+ * Binary logistic regression training results.
+ *
+ * @param predictions dataframe output by the model's `transform` method.
+ * @param probabilityCol field in "predictions" which gives the probability of
+ *                       each class as a vector.
+ * @param predictionCol field in "predictions" which gives the prediction for a data instance as a
+ *                      double.
+ * @param labelCol field in "predictions" which gives the true label of each instance.
+ * @param featuresCol field in "predictions" which gives the features of each instance as a vector.
+ * @param objectiveHistory objective function (scaled loss + regularization) at each iteration.
+ */
+private class BinaryLogisticRegressionTrainingSummaryImpl(
+    predictions: DataFrame,
+    probabilityCol: String,
+    predictionCol: String,
+    labelCol: String,
+    featuresCol: String,
+    override val objectiveHistory: Array[Double])
+  extends BinaryLogisticRegressionSummaryImpl(
+    predictions, probabilityCol, predictionCol, labelCol, featuresCol)
+  with BinaryLogisticRegressionTrainingSummary
+
+/**
+ * Binary logistic regression results for a given model.
+ *
+ * @param predictions dataframe output by the model's `transform` method.
+ * @param probabilityCol field in "predictions" which gives the probability of
+ *                       each class as a vector.
+ * @param predictionCol field in "predictions" which gives the prediction of
+ *                      each instance as a double.
+ * @param labelCol field in "predictions" which gives the true label of each instance.
+ * @param featuresCol field in "predictions" which gives the features of each instance as a vector.
+ */
+private class BinaryLogisticRegressionSummaryImpl(
+    predictions: DataFrame,
+    probabilityCol: String,
+    predictionCol: String,
+    labelCol: String,
+    featuresCol: String)
+  extends LogisticRegressionSummaryImpl(
+    predictions, probabilityCol, predictionCol, labelCol, featuresCol)
+  with BinaryLogisticRegressionSummary

http://git-wip-us.apache.org/repos/asf/spark/blob/c7270a46/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index 542977a..6649fa4 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -222,15 +222,58 @@ class LogisticRegressionSuite
     }
   }
 
-  test("empty probabilityCol") {
-    val lr = new LogisticRegression().setProbabilityCol("")
-    val model = lr.fit(smallBinaryDataset)
-    assert(model.hasSummary)
-    // Validate that we re-insert a probability column for evaluation
-    val fieldNames = model.summary.predictions.schema.fieldNames
-    assert(smallBinaryDataset.schema.fieldNames.toSet.subsetOf(
-      fieldNames.toSet))
-    assert(fieldNames.exists(s => s.startsWith("probability_")))
+  test("empty probabilityCol or predictionCol") {
+    val lr = new LogisticRegression().setMaxIter(1)
+    val datasetFieldNames = smallBinaryDataset.schema.fieldNames.toSet
+    def checkSummarySchema(model: LogisticRegressionModel, columns: Seq[String]): Unit = {
+      assert(model.hasSummary) // check first: model.summary throws when no summary is set
+      val fieldNames = model.summary.predictions.schema.fieldNames
+      assert(datasetFieldNames.subsetOf(fieldNames.toSet))
+      columns.foreach { c => assert(fieldNames.exists(_.startsWith(c))) }
+    }
+    // check that the summary model adds the appropriate columns
+    Seq(("binomial", smallBinaryDataset), ("multinomial", smallMultinomialDataset)).foreach {
+      case (family, dataset) =>
+        lr.setFamily(family)
+        lr.setProbabilityCol("").setPredictionCol("prediction")
+        val modelNoProb = lr.fit(dataset)
+        checkSummarySchema(modelNoProb, Seq("probability_"))
+
+        lr.setProbabilityCol("probability").setPredictionCol("")
+        val modelNoPred = lr.fit(dataset)
+        checkSummarySchema(modelNoPred, Seq("prediction_"))
+
+        lr.setProbabilityCol("").setPredictionCol("")
+        val modelNoPredNoProb = lr.fit(dataset)
+        checkSummarySchema(modelNoPredNoProb, Seq("prediction_", "probability_"))
+    }
+  }
+
+  test("check summary types for binary and multiclass") {
+    val lr = new LogisticRegression()
+      .setFamily("binomial")
+      .setMaxIter(1)
+
+    val blorModel = lr.fit(smallBinaryDataset)
+    assert(blorModel.summary.isInstanceOf[BinaryLogisticRegressionTrainingSummary])
+    assert(blorModel.binarySummary.isInstanceOf[BinaryLogisticRegressionTrainingSummary])
+
+    val mlorModel = lr.setFamily("multinomial").fit(smallMultinomialDataset)
+    assert(mlorModel.summary.isInstanceOf[LogisticRegressionTrainingSummary])
+    withClue("cannot get binary summary for multiclass model") {
+      intercept[RuntimeException] {
+        mlorModel.binarySummary
+      }
+    }
+
+    val mlorBinaryModel = lr.setFamily("multinomial").fit(smallBinaryDataset)
+    assert(mlorBinaryModel.summary.isInstanceOf[BinaryLogisticRegressionTrainingSummary])
+    assert(mlorBinaryModel.binarySummary.isInstanceOf[BinaryLogisticRegressionTrainingSummary])
+
+    val blorSummary = blorModel.evaluate(smallBinaryDataset)
+    val mlorSummary = mlorModel.evaluate(smallMultinomialDataset)
+    assert(blorSummary.isInstanceOf[BinaryLogisticRegressionSummary])
+    assert(mlorSummary.isInstanceOf[LogisticRegressionSummary])
   }
 
   test("setThreshold, getThreshold") {
@@ -2341,51 +2384,98 @@ class LogisticRegressionSuite
   }
 
   test("evaluate on test set") {
-    // TODO: add for multiclass when model summary becomes available
     // Evaluate on test set should be same as that of the transformed training data.
     val lr = new LogisticRegression()
       .setMaxIter(10)
       .setRegParam(1.0)
       .setThreshold(0.6)
-    val model = lr.fit(smallBinaryDataset)
-    val summary = model.summary.asInstanceOf[BinaryLogisticRegressionSummary]
-
-    val sameSummary =
-      model.evaluate(smallBinaryDataset).asInstanceOf[BinaryLogisticRegressionSummary]
-    assert(summary.areaUnderROC === sameSummary.areaUnderROC)
-    assert(summary.roc.collect() === sameSummary.roc.collect())
-    assert(summary.pr.collect === sameSummary.pr.collect())
+      .setFamily("binomial")
+    val blorModel = lr.fit(smallBinaryDataset)
+    val blorSummary = blorModel.binarySummary
+
+    val sameBlorSummary =
+      blorModel.evaluate(smallBinaryDataset).asInstanceOf[BinaryLogisticRegressionSummary]
+    assert(blorSummary.areaUnderROC === sameBlorSummary.areaUnderROC)
+    assert(blorSummary.roc.collect() === sameBlorSummary.roc.collect())
+    assert(blorSummary.pr.collect === sameBlorSummary.pr.collect())
+    assert(
+      blorSummary.fMeasureByThreshold.collect() === sameBlorSummary.fMeasureByThreshold.collect())
     assert(
-      summary.fMeasureByThreshold.collect() === sameSummary.fMeasureByThreshold.collect())
-    assert(summary.recallByThreshold.collect() === sameSummary.recallByThreshold.collect())
+      blorSummary.recallByThreshold.collect() === sameBlorSummary.recallByThreshold.collect())
     assert(
-      summary.precisionByThreshold.collect() === sameSummary.precisionByThreshold.collect())
+      blorSummary.precisionByThreshold.collect() === sameBlorSummary.precisionByThreshold.collect())
+
+    lr.setFamily("multinomial")
+    val mlorModel = lr.fit(smallMultinomialDataset)
+    val mlorSummary = mlorModel.summary
+
+    val mlorSameSummary = mlorModel.evaluate(smallMultinomialDataset)
+
+    assert(mlorSummary.truePositiveRateByLabel === mlorSameSummary.truePositiveRateByLabel)
+    assert(mlorSummary.falsePositiveRateByLabel === mlorSameSummary.falsePositiveRateByLabel)
+    assert(mlorSummary.precisionByLabel === mlorSameSummary.precisionByLabel)
+    assert(mlorSummary.recallByLabel === mlorSameSummary.recallByLabel)
+    assert(mlorSummary.fMeasureByLabel === mlorSameSummary.fMeasureByLabel)
+    assert(mlorSummary.accuracy === mlorSameSummary.accuracy)
+    assert(mlorSummary.weightedTruePositiveRate === mlorSameSummary.weightedTruePositiveRate)
+    assert(mlorSummary.weightedFalsePositiveRate === mlorSameSummary.weightedFalsePositiveRate)
+    assert(mlorSummary.weightedPrecision === mlorSameSummary.weightedPrecision)
+    assert(mlorSummary.weightedRecall === mlorSameSummary.weightedRecall)
+    assert(mlorSummary.weightedFMeasure === mlorSameSummary.weightedFMeasure)
   }
 
   test("evaluate with labels that are not doubles") {
     // Evaluate a test set with Label that is a numeric type other than Double
-    val lr = new LogisticRegression()
+    val blor = new LogisticRegression()
       .setMaxIter(1)
       .setRegParam(1.0)
-    val model = lr.fit(smallBinaryDataset)
-    val summary = model.evaluate(smallBinaryDataset).asInstanceOf[BinaryLogisticRegressionSummary]
+      .setFamily("binomial")
+    val blorModel = blor.fit(smallBinaryDataset)
+    val blorSummary = blorModel.evaluate(smallBinaryDataset)
+      .asInstanceOf[BinaryLogisticRegressionSummary]
+
+    val blorLongLabelData = smallBinaryDataset.select(col(blorModel.getLabelCol).cast(LongType),
+      col(blorModel.getFeaturesCol))
+    val blorLongSummary = blorModel.evaluate(blorLongLabelData)
+      .asInstanceOf[BinaryLogisticRegressionSummary]
+
+    assert(blorSummary.areaUnderROC === blorLongSummary.areaUnderROC)
 
-    val longLabelData = smallBinaryDataset.select(col(model.getLabelCol).cast(LongType),
-      col(model.getFeaturesCol))
-    val longSummary = model.evaluate(longLabelData).asInstanceOf[BinaryLogisticRegressionSummary]
+    val mlor = new LogisticRegression()
+      .setMaxIter(1)
+      .setRegParam(1.0)
+      .setFamily("multinomial")
+    val mlorModel = mlor.fit(smallMultinomialDataset)
+    val mlorSummary = mlorModel.evaluate(smallMultinomialDataset)
+
+    val mlorLongLabelData = smallMultinomialDataset.select(
+      col(mlorModel.getLabelCol).cast(LongType),
+      col(mlorModel.getFeaturesCol))
+    val mlorLongSummary = mlorModel.evaluate(mlorLongLabelData)
 
-    assert(summary.areaUnderROC === longSummary.areaUnderROC)
+    assert(mlorSummary.accuracy === mlorLongSummary.accuracy)
   }
 
   test("statistics on training data") {
     // Test that loss is monotonically decreasing.
-    val lr = new LogisticRegression()
+    val blor = new LogisticRegression()
       .setMaxIter(10)
       .setRegParam(1.0)
-      .setThreshold(0.6)
-    val model = lr.fit(smallBinaryDataset)
+      .setFamily("binomial")
+    val blorModel = blor.fit(smallBinaryDataset)
+    assert(
+      blorModel.summary
+        .objectiveHistory
+        .sliding(2)
+        .forall(x => x(0) >= x(1)))
+
+    val mlor = new LogisticRegression()
+      .setMaxIter(10)
+      .setRegParam(1.0)
+      .setFamily("multinomial")
+    val mlorModel = mlor.fit(smallMultinomialDataset)
     assert(
-      model.summary
+      mlorModel.summary
         .objectiveHistory
         .sliding(2)
         .forall(x => x(0) >= x(1)))
@@ -2470,7 +2560,7 @@ class LogisticRegressionSuite
     predictions3.zip(predictions4).foreach { case (Row(p1: Double), Row(p2: Double)) =>
       assert(p1 === p2)
     }
-    // TODO: check that it converges in a single iteration when model summary is available
+    assert(model4.summary.totalIterations === 1)
   }
 
   test("binary logistic regression with all labels the same") {
@@ -2531,6 +2621,7 @@ class LogisticRegressionSuite
         assert(prob === Vectors.dense(Array(0.0, 0.0, 0.0, 0.0, 1.0)))
         assert(pred === 4.0)
     }
+    assert(model.summary.totalIterations === 0)
 
     // force the model to be trained with only one class
     val constantZeroData = Seq(
@@ -2544,6 +2635,7 @@ class LogisticRegressionSuite
         assert(prob === Vectors.dense(Array(1.0)))
         assert(pred === 0.0)
     }
+    assert(modelZeroLabel.summary.totalIterations > 0)
 
     // ensure that the correct value is predicted when numClasses passed through metadata
     val labelMeta = NominalAttribute.defaultAttr.withName("label").withNumValues(6).toMetadata()
@@ -2557,7 +2649,7 @@ class LogisticRegressionSuite
         assert(prob === Vectors.dense(Array(0.0, 0.0, 0.0, 0.0, 1.0, 0.0)))
         assert(pred === 4.0)
     }
-    // TODO: check num iters is zero when it become available in the model
+    assert(modelWithMetadata.summary.totalIterations === 0)
   }
 
   test("compressed storage for constant label") {

http://git-wip-us.apache.org/repos/asf/spark/blob/c7270a46/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
index e7bd4eb..f470dca 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
@@ -715,7 +715,7 @@ class LinearRegressionSuite
       assert(modelNoPredictionColFieldNames.exists(s => s.startsWith("prediction_")))
 
       // Residuals in [[LinearRegressionResults]] should equal those manually computed
-      val expectedResiduals = datasetWithDenseFeature.select("features", "label")
+      datasetWithDenseFeature.select("features", "label")
         .rdd
         .map { case Row(features: DenseVector, label: Double) =>
           val prediction =

http://git-wip-us.apache.org/repos/asf/spark/blob/c7270a46/project/MimaExcludes.scala
----------------------------------------------------------------------
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 9bda917..eecda26 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -44,7 +44,26 @@ object MimaExcludes {
     ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.ShuffleReadMetricDistributions.this"),
 
     // [SPARK-21276] Update lz4-java to the latest (v1.4.0)
-    ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.io.LZ4BlockInputStream")
+    ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.io.LZ4BlockInputStream"),
+
+    // [SPARK-17139] Add model summary for MultinomialLogisticRegression
+    ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary"),
+    ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionSummary"),
+    ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.predictionCol"),
+    ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.labels"),
+    ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.truePositiveRateByLabel"),
+    ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.falsePositiveRateByLabel"),
+    ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.precisionByLabel"),
+    ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.recallByLabel"),
+    ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.fMeasureByLabel"),
+    ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.accuracy"),
+    ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.weightedTruePositiveRate"),
+    ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.weightedFalsePositiveRate"),
+    ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.weightedRecall"),
+    ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.weightedPrecision"),
+    ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.weightedFMeasure"),
+    ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.org$apache$spark$ml$classification$LogisticRegressionSummary$$multiclassMetrics"),
+    ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.org$apache$spark$ml$classification$LogisticRegressionSummary$_setter_$org$apache$spark$ml$classification$LogisticRegressionSummary$$multiclassMetrics_=")
   )
 
   // Exclude rules for 2.2.x


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org