You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by jk...@apache.org on 2016/04/28 20:22:18 UTC

spark git commit: [SPARK-14852][ML] refactored GLM summary into training, non-training summaries

Repository: spark
Updated Branches:
  refs/heads/master 12c360c05 -> 5ee72454d


[SPARK-14852][ML] refactored GLM summary into training, non-training summaries

## What changes were proposed in this pull request?

This splits GeneralizedLinearRegressionSummary into 2 summary types:
* GeneralizedLinearRegressionSummary, which does not store info from fitting (diagInvAtWA)
* GeneralizedLinearRegressionTrainingSummary, which is a subclass of GeneralizedLinearRegressionSummary and stores info from fitting

This also adds a method, evaluate(), which produces a GeneralizedLinearRegressionSummary for a new dataset.

The summary no longer provides the model itself as a public val.

Also:
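* Fixes a bug where the training summary (previously a plain GeneralizedLinearRegressionSummary) was created with model, not summaryModel.
* Adds a hasSummary method.
* Removes findSummaryModelAndPredictionCol; its logic now lives in the summary class, which chooses the prediction column itself and works on a private copy of the model.
* In the summary class, values are extracted from the model immediately, in case the user later changes them on the original model (e.g., predictionCol).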
* Pardon the unrelated style changes (e.g., blank lines added inside Scaladoc comments); they were introduced automatically by IntelliJ.

## How was this patch tested?

Existing unit tests, plus an updated test covering evaluate() and hasSummary.

Author: Joseph K. Bradley <jo...@databricks.com>

Closes #12624 from jkbradley/model-summary-api.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/5ee72454
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/5ee72454
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/5ee72454

Branch: refs/heads/master
Commit: 5ee72454df21ef4668c855134627d0cdf5d35132
Parents: 12c360c
Author: Joseph K. Bradley <jo...@databricks.com>
Authored: Thu Apr 28 11:22:13 2016 -0700
Committer: Joseph K. Bradley <jo...@databricks.com>
Committed: Thu Apr 28 11:22:13 2016 -0700

----------------------------------------------------------------------
 .../GeneralizedLinearRegression.scala           | 156 ++++++++++++-------
 .../GeneralizedLinearRegressionSuite.scala      |  14 ++
 2 files changed, 115 insertions(+), 55 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/5ee72454/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
index dcf69af..bf9d3ff 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
@@ -35,6 +35,7 @@ import org.apache.spark.sql.{DataFrame, Dataset, Row}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
 
+
 /**
  * Params for Generalized Linear Regression.
  */
@@ -81,6 +82,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
   /**
    * Param for link prediction (linear predictor) column name.
    * Default is empty, which means we do not output link prediction.
+   *
    * @group param
    */
   @Since("2.0.0")
@@ -144,6 +146,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
   /**
    * Sets the value of param [[family]].
    * Default is "gaussian".
+   *
    * @group setParam
    */
   @Since("2.0.0")
@@ -152,6 +155,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
 
   /**
    * Sets the value of param [[link]].
+   *
    * @group setParam
    */
   @Since("2.0.0")
@@ -160,6 +164,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
   /**
    * Sets if we should fit the intercept.
    * Default is true.
+   *
    * @group setParam
    */
   @Since("2.0.0")
@@ -168,6 +173,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
   /**
    * Sets the maximum number of iterations.
    * Default is 25 if the solver algorithm is "irls".
+   *
    * @group setParam
    */
   @Since("2.0.0")
@@ -177,6 +183,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
    * Sets the convergence tolerance of iterations.
    * Smaller value will lead to higher accuracy with the cost of more iterations.
    * Default is 1E-6.
+   *
    * @group setParam
    */
   @Since("2.0.0")
@@ -190,6 +197,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
    *   0.5 * regParam * L2norm(coefficients)^2
    * }}}
    * Default is 0.0.
+   *
    * @group setParam
    */
   @Since("2.0.0")
@@ -200,6 +208,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
    * Sets the value of param [[weightCol]].
    * If this is not set or empty, we treat all instance weights as 1.0.
    * Default is empty, so all instances have weight one.
+   *
    * @group setParam
    */
   @Since("2.0.0")
@@ -209,6 +218,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
   /**
    * Sets the solver algorithm used for optimization.
    * Currently only support "irls" which is also the default solver.
+   *
    * @group setParam
    */
   @Since("2.0.0")
@@ -217,6 +227,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
 
   /**
    * Sets the link prediction (linear predictor) column name.
+   *
    * @group setParam
    */
   @Since("2.0.0")
@@ -256,15 +267,8 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
       val model = copyValues(
         new GeneralizedLinearRegressionModel(uid, wlsModel.coefficients, wlsModel.intercept)
           .setParent(this))
-      // Handle possible missing or invalid prediction columns
-      val (summaryModel, predictionColName) = model.findSummaryModelAndPredictionCol()
-      val trainingSummary = new GeneralizedLinearRegressionSummary(
-        summaryModel.transform(dataset),
-        predictionColName,
-        model,
-        wlsModel.diagInvAtWA.toArray,
-        1,
-        getSolver)
+      val trainingSummary = new GeneralizedLinearRegressionTrainingSummary(dataset, model,
+        wlsModel.diagInvAtWA.toArray, 1, getSolver)
       return model.setSummary(trainingSummary)
     }
 
@@ -277,16 +281,8 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
     val model = copyValues(
       new GeneralizedLinearRegressionModel(uid, irlsModel.coefficients, irlsModel.intercept)
         .setParent(this))
-    // Handle possible missing or invalid prediction columns
-    val (summaryModel, predictionColName) = model.findSummaryModelAndPredictionCol()
-    val trainingSummary = new GeneralizedLinearRegressionSummary(
-      summaryModel.transform(dataset),
-      predictionColName,
-      model,
-      irlsModel.diagInvAtWA.toArray,
-      irlsModel.numIterations,
-      getSolver)
-
+    val trainingSummary = new GeneralizedLinearRegressionTrainingSummary(dataset, model,
+      irlsModel.diagInvAtWA.toArray, irlsModel.numIterations, getSolver)
     model.setSummary(trainingSummary)
   }
 
@@ -363,6 +359,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
 
   /**
    * A description of the error distribution to be used in the model.
+   *
    * @param name the name of the family.
    */
   private[ml] abstract class Family(val name: String) extends Serializable {
@@ -381,6 +378,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
 
     /**
      * Akaike's 'An Information Criterion'(AIC) value of the family for a given dataset.
+     *
      * @param predictions an RDD of (y, mu, weight) of instances in evaluation dataset
      * @param deviance the deviance for the fitted model in evaluation dataset
      * @param numInstances number of instances in evaluation dataset
@@ -400,6 +398,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
 
     /**
      * Gets the [[Family]] object from its name.
+     *
      * @param name family name: "gaussian", "binomial", "poisson" or "gamma".
      */
     def fromName(name: String): Family = {
@@ -579,6 +578,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
    * A description of the link function to be used in the model.
    * The link function provides the relationship between the linear predictor
    * and the mean of the distribution function.
+   *
    * @param name the name of link function.
    */
   private[ml] abstract class Link(val name: String) extends Serializable {
@@ -597,6 +597,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
 
     /**
      * Gets the [[Link]] object from its name.
+     *
      * @param name link name: "identity", "logit", "log",
      *             "inverse", "probit", "cloglog" or "sqrt".
      */
@@ -694,6 +695,7 @@ class GeneralizedLinearRegressionModel private[ml] (
 
   /**
    * Sets the link prediction (linear predictor) column name.
+   *
    * @group setParam
    */
   @Since("2.0.0")
@@ -736,39 +738,39 @@ class GeneralizedLinearRegressionModel private[ml] (
     if ($(linkPredictionCol).nonEmpty) {
       output = output.withColumn($(linkPredictionCol), predictLinkUDF(col($(featuresCol))))
     }
-    output.toDF
+    output.toDF()
   }
 
-  private var trainingSummary: Option[GeneralizedLinearRegressionSummary] = None
+  private var trainingSummary: Option[GeneralizedLinearRegressionTrainingSummary] = None
 
   /**
    * Gets R-like summary of model on training set. An exception is
-   * thrown if `trainingSummary == None`.
+   * thrown if there is no summary available.
    */
   @Since("2.0.0")
-  def summary: GeneralizedLinearRegressionSummary = trainingSummary.getOrElse {
+  def summary: GeneralizedLinearRegressionTrainingSummary = trainingSummary.getOrElse {
     throw new SparkException(
       "No training summary available for this GeneralizedLinearRegressionModel")
   }
 
-  private[regression] def setSummary(summary: GeneralizedLinearRegressionSummary): this.type = {
+  /**
+   * Indicates if [[summary]] is available.
+   */
+  @Since("2.0.0")
+  def hasSummary: Boolean = trainingSummary.nonEmpty
+
+  private[regression]
+  def setSummary(summary: GeneralizedLinearRegressionTrainingSummary): this.type = {
     this.trainingSummary = Some(summary)
     this
   }
 
   /**
-   * If the prediction column is set returns the current model and prediction column,
-   * otherwise generates a new column and sets it as the prediction column on a new copy
-   * of the current model.
+   * Evaluate the model on the given dataset, returning a summary of the results.
    */
-  private[regression] def findSummaryModelAndPredictionCol()
-    : (GeneralizedLinearRegressionModel, String) = {
-    $(predictionCol) match {
-      case "" =>
-        val predictionColName = "prediction_" + java.util.UUID.randomUUID.toString
-        (copy(ParamMap.empty).setPredictionCol(predictionColName), predictionColName)
-      case p => (this, p)
-    }
+  @Since("2.0.0")
+  def evaluate(dataset: Dataset[_]): GeneralizedLinearRegressionSummary = {
+    new GeneralizedLinearRegressionSummary(dataset, this)
   }
 
   @Since("2.0.0")
@@ -834,36 +836,55 @@ object GeneralizedLinearRegressionModel extends MLReadable[GeneralizedLinearRegr
 
 /**
  * :: Experimental ::
- * Summarizing Generalized Linear regression Fits.
+ * Summary of [[GeneralizedLinearRegression]] model and predictions.
  *
- * @param predictions predictions output by the model's `transform` method
- * @param predictionCol field in "predictions" which gives the prediction value of each instance
- * @param model the model that should be summarized
- * @param diagInvAtWA diagonal of matrix (A^T * W * A)^-1 in the last iteration
- * @param numIterations number of iterations
- * @param solver the solver algorithm used for model training
+ * @param dataset Dataset to be summarized.
+ * @param origModel Model to be summarized.  This is copied to create an internal
+ *                  model which cannot be modified from outside.
  */
 @Since("2.0.0")
 @Experimental
 class GeneralizedLinearRegressionSummary private[regression] (
-    @Since("2.0.0") @transient val predictions: DataFrame,
-    @Since("2.0.0") val predictionCol: String,
-    @Since("2.0.0") val model: GeneralizedLinearRegressionModel,
-    private val diagInvAtWA: Array[Double],
-    @Since("2.0.0") val numIterations: Int,
-    @Since("2.0.0") val solver: String) extends Serializable {
+    dataset: Dataset[_],
+    origModel: GeneralizedLinearRegressionModel) extends Serializable {
 
   import GeneralizedLinearRegression._
 
-  private lazy val family = Family.fromName(model.getFamily)
-  private lazy val link = if (model.isDefined(model.getParam("link"))) {
+  /**
+   * Field in "predictions" which gives the prediction value of each instance.
+   * This is set to a new column name if the original model's `predictionCol` is not set.
+   */
+  @Since("2.0.0")
+  val predictionCol: String = {
+    if (origModel.isDefined(origModel.predictionCol) && origModel.getPredictionCol != "") {
+      origModel.getPredictionCol
+    } else {
+      "prediction_" + java.util.UUID.randomUUID.toString
+    }
+  }
+
+  /**
+   * Private copy of model to ensure Params are not modified outside this class.
+   * Coefficients is not a deep copy, but that is acceptable.
+   *
+   * NOTE: [[predictionCol]] must be set correctly before the value of [[model]] is set,
+   *       and [[model]] must be set before [[predictions]] is set!
+   */
+  protected val model: GeneralizedLinearRegressionModel =
+    origModel.copy(ParamMap.empty).setPredictionCol(predictionCol)
+
+  /** predictions output by the model's `transform` method */
+  @Since("2.0.0") @transient val predictions: DataFrame = model.transform(dataset)
+
+  private[regression] lazy val family: Family = Family.fromName(model.getFamily)
+  private[regression] lazy val link: Link = if (model.isDefined(model.link)) {
     Link.fromName(model.getLink)
   } else {
     family.defaultLink
   }
 
   /** Number of instances in DataFrame predictions */
-  private lazy val numInstances: Long = predictions.count()
+  private[regression] lazy val numInstances: Long = predictions.count()
 
   /** The numeric rank of the fitted linear model */
   @Since("2.0.0")
@@ -891,7 +912,7 @@ class GeneralizedLinearRegressionSummary private[regression] (
     numInstances
   }
 
-  private lazy val devianceResiduals: DataFrame = {
+  private[regression] lazy val devianceResiduals: DataFrame = {
     val drUDF = udf { (y: Double, mu: Double, weight: Double) =>
       val r = math.sqrt(math.max(family.deviance(y, mu, weight), 0.0))
       if (y > mu) r else -1.0 * r
@@ -901,19 +922,19 @@ class GeneralizedLinearRegressionSummary private[regression] (
       drUDF(col(model.getLabelCol), col(predictionCol), w).as("devianceResiduals"))
   }
 
-  private lazy val pearsonResiduals: DataFrame = {
+  private[regression] lazy val pearsonResiduals: DataFrame = {
     val prUDF = udf { mu: Double => family.variance(mu) }
     val w = if (model.getWeightCol.isEmpty) lit(1.0) else col(model.getWeightCol)
     predictions.select(col(model.getLabelCol).minus(col(predictionCol))
       .multiply(sqrt(w)).divide(sqrt(prUDF(col(predictionCol)))).as("pearsonResiduals"))
   }
 
-  private lazy val workingResiduals: DataFrame = {
+  private[regression] lazy val workingResiduals: DataFrame = {
     val wrUDF = udf { (y: Double, mu: Double) => (y - mu) * link.deriv(mu) }
     predictions.select(wrUDF(col(model.getLabelCol), col(predictionCol)).as("workingResiduals"))
   }
 
-  private lazy val responseResiduals: DataFrame = {
+  private[regression] lazy val responseResiduals: DataFrame = {
     predictions.select(col(model.getLabelCol).minus(col(predictionCol)).as("responseResiduals"))
   }
 
@@ -925,6 +946,7 @@ class GeneralizedLinearRegressionSummary private[regression] (
 
   /**
    * Get the residuals of the fitted model by type.
+   *
    * @param residualsType The type of residuals which should be returned.
    *                      Supported options: deviance, pearson, working and response.
    */
@@ -996,6 +1018,30 @@ class GeneralizedLinearRegressionSummary private[regression] (
     }
     family.aic(t, deviance, numInstances, weightSum) + 2 * rank
   }
+}
+
+/**
+ * :: Experimental ::
+ * Summary of [[GeneralizedLinearRegression]] fitting and model.
+ *
+ * @param dataset Dataset to be summarized.
+ * @param origModel Model to be summarized.  This is copied to create an internal
+ *                  model which cannot be modified from outside.
+ * @param diagInvAtWA diagonal of matrix (A^T * W * A)^-1 in the last iteration
+ * @param numIterations number of iterations
+ * @param solver the solver algorithm used for model training
+ */
+@Since("2.0.0")
+@Experimental
+class GeneralizedLinearRegressionTrainingSummary private[regression] (
+    dataset: Dataset[_],
+    origModel: GeneralizedLinearRegressionModel,
+    private val diagInvAtWA: Array[Double],
+    @Since("2.0.0") val numIterations: Int,
+    @Since("2.0.0") val solver: String)
+  extends GeneralizedLinearRegressionSummary(dataset, origModel) with Serializable {
+
+  import GeneralizedLinearRegression._
 
   /**
    * Standard error of estimated coefficients and intercept.

http://git-wip-us.apache.org/repos/asf/spark/blob/5ee72454/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
index 0b5e77a..e4c9a3b 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
@@ -603,7 +603,9 @@ class GeneralizedLinearRegressionSuite
     val residualDegreeOfFreedomR = 1
     val aicR = 18.783
 
+    assert(model.hasSummary)
     val summary = model.summary
+    assert(summary.isInstanceOf[GeneralizedLinearRegressionTrainingSummary])
 
     val devianceResiduals = summary.residuals()
       .select(col("devianceResiduals"))
@@ -643,6 +645,18 @@ class GeneralizedLinearRegressionSuite
     assert(summary.residualDegreeOfFreedomNull === residualDegreeOfFreedomNullR)
     assert(summary.aic ~== aicR absTol 1E-3)
     assert(summary.solver === "irls")
+
+    val summary2: GeneralizedLinearRegressionSummary = model.evaluate(datasetWithWeight)
+    assert(summary.predictions.columns.toSet === summary2.predictions.columns.toSet)
+    assert(summary.predictionCol === summary2.predictionCol)
+    assert(summary.rank === summary2.rank)
+    assert(summary.degreesOfFreedom === summary2.degreesOfFreedom)
+    assert(summary.residualDegreeOfFreedom === summary2.residualDegreeOfFreedom)
+    assert(summary.residualDegreeOfFreedomNull === summary2.residualDegreeOfFreedomNull)
+    assert(summary.nullDeviance === summary2.nullDeviance)
+    assert(summary.deviance === summary2.deviance)
+    assert(summary.dispersion === summary2.dispersion)
+    assert(summary.aic === summary2.aic)
   }
 
   test("glm summary: binomial family with weight") {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org