You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by jk...@apache.org on 2017/08/28 20:31:08 UTC
spark git commit: [SPARK-17139][ML] Add model summary for
MultinomialLogisticRegression
Repository: spark
Updated Branches:
refs/heads/master 73e64f7d5 -> c7270a46f
[SPARK-17139][ML] Add model summary for MultinomialLogisticRegression
## What changes were proposed in this pull request?
Add 4 traits, using the following hierarchy:
LogisticRegressionSummary
LogisticRegressionTrainingSummary: LogisticRegressionSummary
BinaryLogisticRegressionSummary: LogisticRegressionSummary
BinaryLogisticRegressionTrainingSummary: LogisticRegressionTrainingSummary, BinaryLogisticRegressionSummary
Public methods such as `def summary` return only the trait types listed above.
Then implement 4 concrete classes:
LogisticRegressionSummaryImpl (multiclass case)
LogisticRegressionTrainingSummaryImpl (multiclass case)
BinaryLogisticRegressionSummaryImpl (binary case)
BinaryLogisticRegressionTrainingSummaryImpl (binary case).
## How was this patch tested?
Existing tests and newly added tests.
Author: WeichenXu <We...@outlook.com>
Closes #15435 from WeichenXu123/mlor_summary.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/c7270a46
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/c7270a46
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/c7270a46
Branch: refs/heads/master
Commit: c7270a46fc340db62c87ddfc6568603d0b832845
Parents: 73e64f7
Author: Weichen Xu <we...@databricks.com>
Authored: Mon Aug 28 13:31:01 2017 -0700
Committer: Joseph K. Bradley <jo...@databricks.com>
Committed: Mon Aug 28 13:31:01 2017 -0700
----------------------------------------------------------------------
.../ml/classification/LogisticRegression.scala | 340 +++++++++++++++----
.../LogisticRegressionSuite.scala | 160 +++++++--
.../ml/regression/LinearRegressionSuite.scala | 2 +-
project/MimaExcludes.scala | 21 +-
4 files changed, 412 insertions(+), 111 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/c7270a46/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 21957d9..ffe4b52 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -22,7 +22,7 @@ import java.util.Locale
import scala.collection.mutable
import breeze.linalg.{DenseVector => BDV}
-import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, LBFGSB => BreezeLBFGSB, OWLQN => BreezeOWLQN}
+import breeze.optimize.{CachedDiffFunction, LBFGS => BreezeLBFGS, LBFGSB => BreezeLBFGSB, OWLQN => BreezeOWLQN}
import org.apache.hadoop.fs.Path
import org.apache.spark.SparkException
@@ -35,7 +35,7 @@ import org.apache.spark.ml.optim.loss.{L2Regularization, RDDLossFunction}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util._
-import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
+import org.apache.spark.mllib.evaluation.{BinaryClassificationMetrics, MulticlassMetrics}
import org.apache.spark.mllib.linalg.VectorImplicits._
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
import org.apache.spark.mllib.util.MLUtils
@@ -882,21 +882,28 @@ class LogisticRegression @Since("1.2.0") (
val model = copyValues(new LogisticRegressionModel(uid, coefficientMatrix, interceptVector,
numClasses, isMultinomial))
- // TODO: implement summary model for multinomial case
- val m = if (!isMultinomial) {
- val (summaryModel, probabilityColName) = model.findSummaryModelAndProbabilityCol()
- val logRegSummary = new BinaryLogisticRegressionTrainingSummary(
+
+ val (summaryModel, probabilityColName, predictionColName) = model.findSummaryModel()
+ val logRegSummary = if (numClasses <= 2) {
+ new BinaryLogisticRegressionTrainingSummaryImpl(
summaryModel.transform(dataset),
probabilityColName,
+ predictionColName,
$(labelCol),
$(featuresCol),
objectiveHistory)
- model.setSummary(Some(logRegSummary))
} else {
- model
+ new LogisticRegressionTrainingSummaryImpl(
+ summaryModel.transform(dataset),
+ probabilityColName,
+ predictionColName,
+ $(labelCol),
+ $(featuresCol),
+ objectiveHistory)
}
- instr.logSuccess(m)
- m
+ model.setSummary(Some(logRegSummary))
+ instr.logSuccess(model)
+ model
}
@Since("1.4.0")
@@ -1010,8 +1017,8 @@ class LogisticRegressionModel private[spark] (
private var trainingSummary: Option[LogisticRegressionTrainingSummary] = None
/**
- * Gets summary of model on training set. An exception is
- * thrown if `trainingSummary == None`.
+ * Gets summary of model on training set. An exception is thrown
+ * if `trainingSummary == None`.
*/
@Since("1.5.0")
def summary: LogisticRegressionTrainingSummary = trainingSummary.getOrElse {
@@ -1019,18 +1026,36 @@ class LogisticRegressionModel private[spark] (
}
/**
- * If the probability column is set returns the current model and probability column,
- * otherwise generates a new column and sets it as the probability column on a new copy
- * of the current model.
+ * Gets summary of model on training set. An exception is thrown
+ * if `trainingSummary == None` or it is a multiclass model.
*/
- private[classification] def findSummaryModelAndProbabilityCol():
- (LogisticRegressionModel, String) = {
- $(probabilityCol) match {
- case "" =>
- val probabilityColName = "probability_" + java.util.UUID.randomUUID.toString
- (copy(ParamMap.empty).setProbabilityCol(probabilityColName), probabilityColName)
- case p => (this, p)
+ @Since("2.3.0")
+ def binarySummary: BinaryLogisticRegressionTrainingSummary = summary match {
+ // Training builds a binary summary only when numClasses <= 2; any other summary
+ // type therefore belongs to a multiclass model.
+ case b: BinaryLogisticRegressionTrainingSummary => b
+ case _ =>
+ throw new RuntimeException("Cannot create a binary summary for a non-binary model" +
+ s" (numClasses=${numClasses}), use summary instead.")
+ }
+
+ /**
+ * If the probability and prediction columns are set, this method returns the current model,
+ * otherwise it generates new columns for them and sets them as columns on a new copy of
+ * the current model
+ */
+ private[classification] def findSummaryModel():
+ (LogisticRegressionModel, String, String) = {
+ val model = if ($(probabilityCol).isEmpty && $(predictionCol).isEmpty) {
+ copy(ParamMap.empty)
+ .setProbabilityCol("probability_" + java.util.UUID.randomUUID.toString)
+ .setPredictionCol("prediction_" + java.util.UUID.randomUUID.toString)
+ } else if ($(probabilityCol).isEmpty) {
+ copy(ParamMap.empty).setProbabilityCol("probability_" + java.util.UUID.randomUUID.toString)
+ } else if ($(predictionCol).isEmpty) {
+ copy(ParamMap.empty).setPredictionCol("prediction_" + java.util.UUID.randomUUID.toString)
+ } else {
+ this
}
+ (model, model.getProbabilityCol, model.getPredictionCol)
}
private[classification]
@@ -1051,9 +1076,14 @@ class LogisticRegressionModel private[spark] (
@Since("2.0.0")
def evaluate(dataset: Dataset[_]): LogisticRegressionSummary = {
// Handle possible missing or invalid prediction columns
- val (summaryModel, probabilityColName) = findSummaryModelAndProbabilityCol()
- new BinaryLogisticRegressionSummary(summaryModel.transform(dataset),
- probabilityColName, $(labelCol), $(featuresCol))
+ val (summaryModel, probabilityColName, predictionColName) = findSummaryModel()
+ if (numClasses > 2) {
+ new LogisticRegressionSummaryImpl(summaryModel.transform(dataset),
+ probabilityColName, predictionColName, $(labelCol), $(featuresCol))
+ } else {
+ new BinaryLogisticRegressionSummaryImpl(summaryModel.transform(dataset),
+ probabilityColName, predictionColName, $(labelCol), $(featuresCol))
+ }
}
/**
@@ -1324,90 +1354,154 @@ private[ml] class MultiClassSummarizer extends Serializable {
}
/**
- * Abstraction for multinomial Logistic Regression Training results.
- * Currently, the training summary ignores the training weights except
- * for the objective trace.
- */
-sealed trait LogisticRegressionTrainingSummary extends LogisticRegressionSummary {
-
- /** objective function (scaled loss + regularization) at each iteration. */
- def objectiveHistory: Array[Double]
-
- /** Number of training iterations until termination */
- def totalIterations: Int = objectiveHistory.length
-
-}
-
-/**
- * Abstraction for Logistic Regression Results for a given model.
+ * :: Experimental ::
+ * Abstraction for logistic regression results for a given model.
*/
+@Experimental
sealed trait LogisticRegressionSummary extends Serializable {
/**
* Dataframe output by the model's `transform` method.
*/
+ @Since("1.5.0")
def predictions: DataFrame
/** Field in "predictions" which gives the probability of each class as a vector. */
+ @Since("1.5.0")
def probabilityCol: String
+ /** Field in "predictions" which gives the prediction (predicted label) of each instance. */
+ @Since("2.3.0")
+ def predictionCol: String
+
/** Field in "predictions" which gives the true label of each instance (if available). */
+ @Since("1.5.0")
def labelCol: String
/** Field in "predictions" which gives the features of each instance as a vector. */
+ @Since("1.6.0")
def featuresCol: String
+ // Built from (prediction, label) pairs of `predictions`. The label is cast to Double so
+ // label columns of other numeric types (e.g. LongType) can also be evaluated.
+ @transient private val multiclassMetrics = {
+ new MulticlassMetrics(
+ predictions.select(
+ col(predictionCol),
+ col(labelCol).cast(DoubleType))
+ .rdd.map { case Row(prediction: Double, label: Double) => (prediction, label) })
+ }
+
+ /**
+ * Returns the sequence of labels in ascending order. This order matches the order used
+ * in metrics which are specified as arrays over labels, e.g., truePositiveRateByLabel.
+ *
+ * Note: In most cases, it will be values {0.0, 1.0, ..., numClasses-1}. However, if the
+ * dataset used to compute this summary is missing a label, then all of the arrays over labels
+ * (e.g., from truePositiveRateByLabel) will be of length numClasses-1 instead of the
+ * expected numClasses.
+ */
+ @Since("2.3.0")
+ def labels: Array[Double] = multiclassMetrics.labels
+
+ /** Returns true positive rate for each label (category). */
+ @Since("2.3.0")
+ def truePositiveRateByLabel: Array[Double] = recallByLabel
+
+ /** Returns false positive rate for each label (category). */
+ @Since("2.3.0")
+ def falsePositiveRateByLabel: Array[Double] = {
+ multiclassMetrics.labels.map(label => multiclassMetrics.falsePositiveRate(label))
+ }
+
+ /** Returns precision for each label (category). */
+ @Since("2.3.0")
+ def precisionByLabel: Array[Double] = {
+ multiclassMetrics.labels.map(label => multiclassMetrics.precision(label))
+ }
+
+ /** Returns recall for each label (category). */
+ @Since("2.3.0")
+ def recallByLabel: Array[Double] = {
+ multiclassMetrics.labels.map(label => multiclassMetrics.recall(label))
+ }
+
+ /** Returns f-measure for each label (category). */
+ @Since("2.3.0")
+ def fMeasureByLabel(beta: Double): Array[Double] = {
+ multiclassMetrics.labels.map(label => multiclassMetrics.fMeasure(label, beta))
+ }
+
+ /** Returns f1-measure for each label (category). */
+ @Since("2.3.0")
+ def fMeasureByLabel: Array[Double] = fMeasureByLabel(1.0)
+
+ /**
+ * Returns accuracy.
+ * (equal to the number of correctly classified instances
+ * divided by the total number of instances.)
+ */
+ @Since("2.3.0")
+ def accuracy: Double = multiclassMetrics.accuracy
+
+ /**
+ * Returns weighted true positive rate.
+ * (equals to precision, recall and f-measure)
+ */
+ @Since("2.3.0")
+ def weightedTruePositiveRate: Double = weightedRecall
+
+ /** Returns weighted false positive rate. */
+ @Since("2.3.0")
+ def weightedFalsePositiveRate: Double = multiclassMetrics.weightedFalsePositiveRate
+
+ /**
+ * Returns weighted averaged recall.
+ * (equals to precision, recall and f-measure)
+ */
+ @Since("2.3.0")
+ def weightedRecall: Double = multiclassMetrics.weightedRecall
+
+ /** Returns weighted averaged precision. */
+ @Since("2.3.0")
+ def weightedPrecision: Double = multiclassMetrics.weightedPrecision
+
+ /** Returns weighted averaged f-measure. */
+ @Since("2.3.0")
+ def weightedFMeasure(beta: Double): Double = multiclassMetrics.weightedFMeasure(beta)
+
+ /** Returns weighted averaged f1-measure. */
+ @Since("2.3.0")
+ def weightedFMeasure: Double = multiclassMetrics.weightedFMeasure(1.0)
}
/**
* :: Experimental ::
- * Logistic regression training results.
- *
- * @param predictions dataframe output by the model's `transform` method.
- * @param probabilityCol field in "predictions" which gives the probability of
- * each class as a vector.
- * @param labelCol field in "predictions" which gives the true label of each instance.
- * @param featuresCol field in "predictions" which gives the features of each instance as a vector.
- * @param objectiveHistory objective function (scaled loss + regularization) at each iteration.
+ * Abstraction for multiclass logistic regression training results.
+ * Currently, the training summary ignores the training weights except
+ * for the objective trace.
*/
@Experimental
-@Since("1.5.0")
-class BinaryLogisticRegressionTrainingSummary private[classification] (
- predictions: DataFrame,
- probabilityCol: String,
- labelCol: String,
- featuresCol: String,
- @Since("1.5.0") val objectiveHistory: Array[Double])
- extends BinaryLogisticRegressionSummary(predictions, probabilityCol, labelCol, featuresCol)
- with LogisticRegressionTrainingSummary {
+sealed trait LogisticRegressionTrainingSummary extends LogisticRegressionSummary {
+
+ /** objective function (scaled loss + regularization) at each iteration. */
+ @Since("1.5.0")
+ def objectiveHistory: Array[Double]
+
+ /** Number of training iterations. */
+ @Since("1.5.0")
+ def totalIterations: Int = objectiveHistory.length
}
/**
* :: Experimental ::
- * Binary Logistic regression results for a given model.
- *
- * @param predictions dataframe output by the model's `transform` method.
- * @param probabilityCol field in "predictions" which gives the probability of
- * each class as a vector.
- * @param labelCol field in "predictions" which gives the true label of each instance.
- * @param featuresCol field in "predictions" which gives the features of each instance as a vector.
+ * Abstraction for binary logistic regression results for a given model.
*/
@Experimental
-@Since("1.5.0")
-class BinaryLogisticRegressionSummary private[classification] (
- @Since("1.5.0") @transient override val predictions: DataFrame,
- @Since("1.5.0") override val probabilityCol: String,
- @Since("1.5.0") override val labelCol: String,
- @Since("1.6.0") override val featuresCol: String) extends LogisticRegressionSummary {
-
+sealed trait BinaryLogisticRegressionSummary extends LogisticRegressionSummary {
private val sparkSession = predictions.sparkSession
import sparkSession.implicits._
- /**
- * Returns a BinaryClassificationMetrics object.
- */
// TODO: Allow the user to vary the number of bins using a setBins method in
// BinaryClassificationMetrics. For now the default is set to 100.
@transient private val binaryMetrics = new BinaryClassificationMetrics(
@@ -1484,3 +1578,99 @@ class BinaryLogisticRegressionSummary private[classification] (
binaryMetrics.recallByThreshold().toDF("threshold", "recall")
}
}
+
+/**
+ * :: Experimental ::
+ * Abstraction for binary logistic regression training results.
+ * Currently, the training summary ignores the training weights except
+ * for the objective trace.
+ */
+@Experimental
+sealed trait BinaryLogisticRegressionTrainingSummary extends BinaryLogisticRegressionSummary
+ with LogisticRegressionTrainingSummary
+
+/**
+ * Multiclass logistic regression training results.
+ *
+ * @param predictions dataframe output by the model's `transform` method.
+ * @param probabilityCol field in "predictions" which gives the probability of
+ * each class as a vector.
+ * @param predictionCol field in "predictions" which gives the prediction for a data instance as a
+ * double.
+ * @param labelCol field in "predictions" which gives the true label of each instance.
+ * @param featuresCol field in "predictions" which gives the features of each instance as a vector.
+ * @param objectiveHistory objective function (scaled loss + regularization) at each iteration.
+ */
+private class LogisticRegressionTrainingSummaryImpl(
+ predictions: DataFrame,
+ probabilityCol: String,
+ predictionCol: String,
+ labelCol: String,
+ featuresCol: String,
+ override val objectiveHistory: Array[Double])
+ extends LogisticRegressionSummaryImpl(
+ predictions, probabilityCol, predictionCol, labelCol, featuresCol)
+ with LogisticRegressionTrainingSummary
+
+/**
+ * Multiclass logistic regression results for a given model.
+ *
+ * @param predictions dataframe output by the model's `transform` method.
+ * @param probabilityCol field in "predictions" which gives the probability of
+ * each class as a vector.
+ * @param predictionCol field in "predictions" which gives the prediction for a data instance as a
+ * double.
+ * @param labelCol field in "predictions" which gives the true label of each instance.
+ * @param featuresCol field in "predictions" which gives the features of each instance as a vector.
+ */
+private class LogisticRegressionSummaryImpl(
+ @transient override val predictions: DataFrame,
+ override val probabilityCol: String,
+ override val predictionCol: String,
+ override val labelCol: String,
+ override val featuresCol: String)
+ extends LogisticRegressionSummary
+
+/**
+ * Binary logistic regression training results.
+ *
+ * @param predictions dataframe output by the model's `transform` method.
+ * @param probabilityCol field in "predictions" which gives the probability of
+ * each class as a vector.
+ * @param predictionCol field in "predictions" which gives the prediction for a data instance as a
+ * double.
+ * @param labelCol field in "predictions" which gives the true label of each instance.
+ * @param featuresCol field in "predictions" which gives the features of each instance as a vector.
+ * @param objectiveHistory objective function (scaled loss + regularization) at each iteration.
+ */
+private class BinaryLogisticRegressionTrainingSummaryImpl(
+ predictions: DataFrame,
+ probabilityCol: String,
+ predictionCol: String,
+ labelCol: String,
+ featuresCol: String,
+ override val objectiveHistory: Array[Double])
+ extends BinaryLogisticRegressionSummaryImpl(
+ predictions, probabilityCol, predictionCol, labelCol, featuresCol)
+ with BinaryLogisticRegressionTrainingSummary
+
+/**
+ * Binary logistic regression results for a given model.
+ *
+ * @param predictions dataframe output by the model's `transform` method.
+ * @param probabilityCol field in "predictions" which gives the probability of
+ * each class as a vector.
+ * @param predictionCol field in "predictions" which gives the prediction for a data instance as a
+ * double.
+ * @param labelCol field in "predictions" which gives the true label of each instance.
+ * @param featuresCol field in "predictions" which gives the features of each instance as a vector.
+ */
+private class BinaryLogisticRegressionSummaryImpl(
+ predictions: DataFrame,
+ probabilityCol: String,
+ predictionCol: String,
+ labelCol: String,
+ featuresCol: String)
+ extends LogisticRegressionSummaryImpl(
+ predictions, probabilityCol, predictionCol, labelCol, featuresCol)
+ with BinaryLogisticRegressionSummary
http://git-wip-us.apache.org/repos/asf/spark/blob/c7270a46/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index 542977a..6649fa4 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -222,15 +222,58 @@ class LogisticRegressionSuite
}
}
- test("empty probabilityCol") {
- val lr = new LogisticRegression().setProbabilityCol("")
- val model = lr.fit(smallBinaryDataset)
- assert(model.hasSummary)
- // Validate that we re-insert a probability column for evaluation
- val fieldNames = model.summary.predictions.schema.fieldNames
- assert(smallBinaryDataset.schema.fieldNames.toSet.subsetOf(
- fieldNames.toSet))
- assert(fieldNames.exists(s => s.startsWith("probability_")))
+ test("empty probabilityCol or predictionCol") {
+ val lr = new LogisticRegression().setMaxIter(1)
+ // Verify that the model has a summary whose predictions keep every input column and
+ // contain a generated column (name starting with one of `columns`) for each param set to "".
+ def checkSummarySchema(
+ model: LogisticRegressionModel,
+ datasetFieldNames: Set[String],
+ columns: Seq[String]): Unit = {
+ // Check hasSummary first: `model.summary` throws when no summary exists, which would
+ // otherwise pre-empt this clearer assertion.
+ assert(model.hasSummary)
+ val fieldNames = model.summary.predictions.schema.fieldNames
+ assert(datasetFieldNames.subsetOf(fieldNames.toSet))
+ columns.foreach { c => assert(fieldNames.exists(_.startsWith(c))) }
+ }
+ // check that the summary model adds the appropriate columns
+ Seq(("binomial", smallBinaryDataset), ("multinomial", smallMultinomialDataset)).foreach {
+ case (family, dataset) =>
+ // Compare against the columns of the dataset actually being fit.
+ val datasetFieldNames = dataset.schema.fieldNames.toSet
+ lr.setFamily(family)
+ lr.setProbabilityCol("").setPredictionCol("prediction")
+ val modelNoProb = lr.fit(dataset)
+ checkSummarySchema(modelNoProb, datasetFieldNames, Seq("probability_"))
+
+ lr.setProbabilityCol("probability").setPredictionCol("")
+ val modelNoPred = lr.fit(dataset)
+ checkSummarySchema(modelNoPred, datasetFieldNames, Seq("prediction_"))
+
+ lr.setProbabilityCol("").setPredictionCol("")
+ val modelNoPredNoProb = lr.fit(dataset)
+ checkSummarySchema(
+ modelNoPredNoProb, datasetFieldNames, Seq("prediction_", "probability_"))
+ }
+ }
+
+ test("check summary types for binary and multiclass") {
+ val lr = new LogisticRegression()
+ .setFamily("binomial")
+ .setMaxIter(1)
+
+ val blorModel = lr.fit(smallBinaryDataset)
+ assert(blorModel.summary.isInstanceOf[BinaryLogisticRegressionTrainingSummary])
+ assert(blorModel.binarySummary.isInstanceOf[BinaryLogisticRegressionTrainingSummary])
+
+ val mlorModel = lr.setFamily("multinomial").fit(smallMultinomialDataset)
+ assert(mlorModel.summary.isInstanceOf[LogisticRegressionTrainingSummary])
+ withClue("cannot get binary summary for multiclass model") {
+ intercept[RuntimeException] {
+ mlorModel.binarySummary
+ }
+ }
+
+ val mlorBinaryModel = lr.setFamily("multinomial").fit(smallBinaryDataset)
+ assert(mlorBinaryModel.summary.isInstanceOf[BinaryLogisticRegressionTrainingSummary])
+ assert(mlorBinaryModel.binarySummary.isInstanceOf[BinaryLogisticRegressionTrainingSummary])
+
+ val blorSummary = blorModel.evaluate(smallBinaryDataset)
+ val mlorSummary = mlorModel.evaluate(smallMultinomialDataset)
+ assert(blorSummary.isInstanceOf[BinaryLogisticRegressionSummary])
+ assert(mlorSummary.isInstanceOf[LogisticRegressionSummary])
}
test("setThreshold, getThreshold") {
@@ -2341,51 +2384,98 @@ class LogisticRegressionSuite
}
test("evaluate on test set") {
- // TODO: add for multiclass when model summary becomes available
// Evaluate on test set should be same as that of the transformed training data.
val lr = new LogisticRegression()
.setMaxIter(10)
.setRegParam(1.0)
.setThreshold(0.6)
- val model = lr.fit(smallBinaryDataset)
- val summary = model.summary.asInstanceOf[BinaryLogisticRegressionSummary]
-
- val sameSummary =
- model.evaluate(smallBinaryDataset).asInstanceOf[BinaryLogisticRegressionSummary]
- assert(summary.areaUnderROC === sameSummary.areaUnderROC)
- assert(summary.roc.collect() === sameSummary.roc.collect())
- assert(summary.pr.collect === sameSummary.pr.collect())
+ .setFamily("binomial")
+ val blorModel = lr.fit(smallBinaryDataset)
+ val blorSummary = blorModel.binarySummary
+
+ val sameBlorSummary =
+ blorModel.evaluate(smallBinaryDataset).asInstanceOf[BinaryLogisticRegressionSummary]
+ assert(blorSummary.areaUnderROC === sameBlorSummary.areaUnderROC)
+ assert(blorSummary.roc.collect() === sameBlorSummary.roc.collect())
+ assert(blorSummary.pr.collect === sameBlorSummary.pr.collect())
+ assert(
+ blorSummary.fMeasureByThreshold.collect() === sameBlorSummary.fMeasureByThreshold.collect())
assert(
- summary.fMeasureByThreshold.collect() === sameSummary.fMeasureByThreshold.collect())
- assert(summary.recallByThreshold.collect() === sameSummary.recallByThreshold.collect())
+ blorSummary.recallByThreshold.collect() === sameBlorSummary.recallByThreshold.collect())
assert(
- summary.precisionByThreshold.collect() === sameSummary.precisionByThreshold.collect())
+ blorSummary.precisionByThreshold.collect() === sameBlorSummary.precisionByThreshold.collect())
+
+ lr.setFamily("multinomial")
+ val mlorModel = lr.fit(smallMultinomialDataset)
+ val mlorSummary = mlorModel.summary
+
+ val mlorSameSummary = mlorModel.evaluate(smallMultinomialDataset)
+
+ assert(mlorSummary.truePositiveRateByLabel === mlorSameSummary.truePositiveRateByLabel)
+ assert(mlorSummary.falsePositiveRateByLabel === mlorSameSummary.falsePositiveRateByLabel)
+ assert(mlorSummary.precisionByLabel === mlorSameSummary.precisionByLabel)
+ assert(mlorSummary.recallByLabel === mlorSameSummary.recallByLabel)
+ assert(mlorSummary.fMeasureByLabel === mlorSameSummary.fMeasureByLabel)
+ assert(mlorSummary.accuracy === mlorSameSummary.accuracy)
+ assert(mlorSummary.weightedTruePositiveRate === mlorSameSummary.weightedTruePositiveRate)
+ assert(mlorSummary.weightedFalsePositiveRate === mlorSameSummary.weightedFalsePositiveRate)
+ assert(mlorSummary.weightedPrecision === mlorSameSummary.weightedPrecision)
+ assert(mlorSummary.weightedRecall === mlorSameSummary.weightedRecall)
+ assert(mlorSummary.weightedFMeasure === mlorSameSummary.weightedFMeasure)
}
test("evaluate with labels that are not doubles") {
// Evaluate a test set with Label that is a numeric type other than Double
- val lr = new LogisticRegression()
+ val blor = new LogisticRegression()
.setMaxIter(1)
.setRegParam(1.0)
- val model = lr.fit(smallBinaryDataset)
- val summary = model.evaluate(smallBinaryDataset).asInstanceOf[BinaryLogisticRegressionSummary]
+ .setFamily("binomial")
+ val blorModel = blor.fit(smallBinaryDataset)
+ val blorSummary = blorModel.evaluate(smallBinaryDataset)
+ .asInstanceOf[BinaryLogisticRegressionSummary]
+
+ val blorLongLabelData = smallBinaryDataset.select(col(blorModel.getLabelCol).cast(LongType),
+ col(blorModel.getFeaturesCol))
+ val blorLongSummary = blorModel.evaluate(blorLongLabelData)
+ .asInstanceOf[BinaryLogisticRegressionSummary]
+
+ assert(blorSummary.areaUnderROC === blorLongSummary.areaUnderROC)
- val longLabelData = smallBinaryDataset.select(col(model.getLabelCol).cast(LongType),
- col(model.getFeaturesCol))
- val longSummary = model.evaluate(longLabelData).asInstanceOf[BinaryLogisticRegressionSummary]
+ val mlor = new LogisticRegression()
+ .setMaxIter(1)
+ .setRegParam(1.0)
+ .setFamily("multinomial")
+ val mlorModel = mlor.fit(smallMultinomialDataset)
+ val mlorSummary = mlorModel.evaluate(smallMultinomialDataset)
+
+ val mlorLongLabelData = smallMultinomialDataset.select(
+ col(mlorModel.getLabelCol).cast(LongType),
+ col(mlorModel.getFeaturesCol))
+ val mlorLongSummary = mlorModel.evaluate(mlorLongLabelData)
- assert(summary.areaUnderROC === longSummary.areaUnderROC)
+ assert(mlorSummary.accuracy === mlorLongSummary.accuracy)
}
test("statistics on training data") {
// Test that loss is monotonically decreasing.
- val lr = new LogisticRegression()
+ val blor = new LogisticRegression()
.setMaxIter(10)
.setRegParam(1.0)
- .setThreshold(0.6)
- val model = lr.fit(smallBinaryDataset)
+ .setFamily("binomial")
+ val blorModel = blor.fit(smallBinaryDataset)
+ assert(
+ blorModel.summary
+ .objectiveHistory
+ .sliding(2)
+ .forall(x => x(0) >= x(1)))
+
+ val mlor = new LogisticRegression()
+ .setMaxIter(10)
+ .setRegParam(1.0)
+ .setFamily("multinomial")
+ val mlorModel = mlor.fit(smallMultinomialDataset)
assert(
- model.summary
+ mlorModel.summary
.objectiveHistory
.sliding(2)
.forall(x => x(0) >= x(1)))
@@ -2470,7 +2560,7 @@ class LogisticRegressionSuite
predictions3.zip(predictions4).foreach { case (Row(p1: Double), Row(p2: Double)) =>
assert(p1 === p2)
}
- // TODO: check that it converges in a single iteration when model summary is available
+ assert(model4.summary.totalIterations === 1)
}
test("binary logistic regression with all labels the same") {
@@ -2531,6 +2621,7 @@ class LogisticRegressionSuite
assert(prob === Vectors.dense(Array(0.0, 0.0, 0.0, 0.0, 1.0)))
assert(pred === 4.0)
}
+ assert(model.summary.totalIterations === 0)
// force the model to be trained with only one class
val constantZeroData = Seq(
@@ -2544,6 +2635,7 @@ class LogisticRegressionSuite
assert(prob === Vectors.dense(Array(1.0)))
assert(pred === 0.0)
}
+ assert(modelZeroLabel.summary.totalIterations > 0)
// ensure that the correct value is predicted when numClasses passed through metadata
val labelMeta = NominalAttribute.defaultAttr.withName("label").withNumValues(6).toMetadata()
@@ -2557,7 +2649,7 @@ class LogisticRegressionSuite
assert(prob === Vectors.dense(Array(0.0, 0.0, 0.0, 0.0, 1.0, 0.0)))
assert(pred === 4.0)
}
- // TODO: check num iters is zero when it become available in the model
+ require(modelWithMetadata.summary.totalIterations === 0)
}
test("compressed storage for constant label") {
http://git-wip-us.apache.org/repos/asf/spark/blob/c7270a46/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
index e7bd4eb..f470dca 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
@@ -715,7 +715,7 @@ class LinearRegressionSuite
assert(modelNoPredictionColFieldNames.exists(s => s.startsWith("prediction_")))
// Residuals in [[LinearRegressionResults]] should equal those manually computed
- val expectedResiduals = datasetWithDenseFeature.select("features", "label")
+ datasetWithDenseFeature.select("features", "label")
.rdd
.map { case Row(features: DenseVector, label: Double) =>
val prediction =
http://git-wip-us.apache.org/repos/asf/spark/blob/c7270a46/project/MimaExcludes.scala
----------------------------------------------------------------------
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 9bda917..eecda26 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -44,7 +44,26 @@ object MimaExcludes {
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.ShuffleReadMetricDistributions.this"),
// [SPARK-21276] Update lz4-java to the latest (v1.4.0)
- ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.io.LZ4BlockInputStream")
+ ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.io.LZ4BlockInputStream"),
+
+ // [SPARK-17139] Add model summary for MultinomialLogisticRegression
+ ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary"),
+ ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionSummary"),
+ ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.predictionCol"),
+ ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.labels"),
+ ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.truePositiveRateByLabel"),
+ ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.falsePositiveRateByLabel"),
+ ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.precisionByLabel"),
+ ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.recallByLabel"),
+ ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.fMeasureByLabel"),
+ ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.accuracy"),
+ ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.weightedTruePositiveRate"),
+ ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.weightedFalsePositiveRate"),
+ ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.weightedRecall"),
+ ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.weightedPrecision"),
+ ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.weightedFMeasure"),
+ ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.org$apache$spark$ml$classification$LogisticRegressionSummary$$multiclassMetrics"),
+ ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.org$apache$spark$ml$classification$LogisticRegressionSummary$_setter_$org$apache$spark$ml$classification$LogisticRegressionSummary$$multiclassMetrics_=")
)
// Exclude rules for 2.2.x
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org