Posted to commits@spark.apache.org by jk...@apache.org on 2015/10/08 20:16:26 UTC

spark git commit: [SPARK-9718] [ML] linear regression training summary all columns

Repository: spark
Updated Branches:
  refs/heads/master dcbd58a92 -> 0903c6489


[SPARK-9718] [ML] linear regression training summary all columns

LinearRegression training summary: the transformed dataset should hold all columns, not just selected ones such as prediction and label. There is no real need to drop any of them, and the user may find them useful.
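
As a rough illustration (not part of the patch), assuming a DataFrame named
"training" with "label" and "features" columns is already in scope:

    import org.apache.spark.ml.regression.LinearRegression

    val model = new LinearRegression().fit(training)
    // With this change, the summary's predictions DataFrame keeps every
    // column of the input dataset (e.g. "features", "label") in addition
    // to the prediction column, rather than only prediction and label.
    model.summary.predictions.printSchema()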

Author: Holden Karau <ho...@pigscanfly.ca>

Closes #8564 from holdenk/SPARK-9718-LinearRegressionTrainingSummary-all-columns.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/0903c648
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/0903c648
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/0903c648

Branch: refs/heads/master
Commit: 0903c6489e9fa39db9575dace22a64015b9cd4c5
Parents: dcbd58a
Author: Holden Karau <ho...@pigscanfly.ca>
Authored: Thu Oct 8 11:16:20 2015 -0700
Committer: Joseph K. Bradley <jo...@databricks.com>
Committed: Thu Oct 8 11:16:20 2015 -0700

----------------------------------------------------------------------
 .../spark/ml/regression/LinearRegression.scala  | 35 +++++++++++++++-----
 .../ml/regression/LinearRegressionSuite.scala   | 13 ++++++++
 2 files changed, 40 insertions(+), 8 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/0903c648/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index 0dc084f..dd09667 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -170,9 +170,12 @@ class LinearRegression(override val uid: String)
       val intercept = yMean
 
       val model = new LinearRegressionModel(uid, coefficients, intercept)
+      // Handle possible missing or invalid prediction columns
+      val (summaryModel, predictionColName) = model.findSummaryModelAndPredictionCol()
+
       val trainingSummary = new LinearRegressionTrainingSummary(
-        model.transform(dataset),
-        $(predictionCol),
+        summaryModel.transform(dataset),
+        predictionColName,
         $(labelCol),
         $(featuresCol),
         Array(0D))
@@ -262,9 +265,12 @@ class LinearRegression(override val uid: String)
     if (handlePersistence) instances.unpersist()
 
     val model = copyValues(new LinearRegressionModel(uid, coefficients, intercept))
+    // Handle possible missing or invalid prediction columns
+    val (summaryModel, predictionColName) = model.findSummaryModelAndPredictionCol()
+
     val trainingSummary = new LinearRegressionTrainingSummary(
-      model.transform(dataset),
-      $(predictionCol),
+      summaryModel.transform(dataset),
+      predictionColName,
       $(labelCol),
       $(featuresCol),
       objectiveHistory)
@@ -316,13 +322,26 @@ class LinearRegressionModel private[ml] (
    */
   // TODO: decide on a good name before exposing to public API
   private[regression] def evaluate(dataset: DataFrame): LinearRegressionSummary = {
-    val t = udf { features: Vector => predict(features) }
-    val predictionAndObservations = dataset
-      .select(col($(labelCol)), t(col($(featuresCol))).as($(predictionCol)))
+    // Handle possible missing or invalid prediction columns
+    val (summaryModel, predictionColName) = findSummaryModelAndPredictionCol()
+    new LinearRegressionSummary(summaryModel.transform(dataset), predictionColName, $(labelCol))
+  }
 
-    new LinearRegressionSummary(predictionAndObservations, $(predictionCol), $(labelCol))
+  /**
+   * If the prediction column is set returns the current model and prediction column,
+   * otherwise generates a new column and sets it as the prediction column on a new copy
+   * of the current model.
+   */
+  private[regression] def findSummaryModelAndPredictionCol(): (LinearRegressionModel, String) = {
+    $(predictionCol) match {
+      case "" =>
+        val predictionColName = "prediction_" + java.util.UUID.randomUUID.toString()
+        (copy(ParamMap.empty).setPredictionCol(predictionColName), predictionColName)
+      case p => (this, p)
+    }
   }
 
+
   override protected def predict(features: Vector): Double = {
     dot(features, weights) + intercept
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/0903c648/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
index 3272947..73a0a5c 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
@@ -462,9 +462,22 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
   test("linear regression model training summary") {
     val trainer = new LinearRegression
     val model = trainer.fit(dataset)
+    val trainerNoPredictionCol = trainer.setPredictionCol("")
+    val modelNoPredictionCol = trainerNoPredictionCol.fit(dataset)
+
 
     // Training results for the model should be available
     assert(model.hasSummary)
+    assert(modelNoPredictionCol.hasSummary)
+
+    // Schema should be a superset of the input dataset
+    assert((dataset.schema.fieldNames.toSet + "prediction").subsetOf(
+      model.summary.predictions.schema.fieldNames.toSet))
+    // Validate that we re-insert a prediction column for evaluation
+    val modelNoPredictionColFieldNames = modelNoPredictionCol.summary.predictions.schema.fieldNames
+    assert((dataset.schema.fieldNames.toSet).subsetOf(
+      modelNoPredictionColFieldNames.toSet))
+    assert(modelNoPredictionColFieldNames.exists(s => s.startsWith("prediction_")))
 
     // Residuals in [[LinearRegressionResults]] should equal those manually computed
     val expectedResiduals = dataset.select("features", "label")

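For reference, a similarly hedged sketch of the no-prediction-column path
exercised by the new test, under the same assumptions as above:

    import org.apache.spark.ml.regression.LinearRegression

    // With an empty prediction column, the model copies itself and generates
    // a unique column name of the form "prediction_<uuid>" so the training
    // summary can still be built without overwriting any user column.
    val modelNoPred = new LinearRegression().setPredictionCol("").fit(training)
    val generatedCol = modelNoPred.summary.predictions.schema.fieldNames
      .find(_.startsWith("prediction_"))
    println(generatedCol)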

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org