You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by jk...@apache.org on 2015/07/10 01:21:25 UTC

spark git commit: [SPARK-8538] [SPARK-8539] [ML] Linear Regression Training and Testing Results

Repository: spark
Updated Branches:
  refs/heads/master e29ce319f -> a0cc3e5aa


[SPARK-8538] [SPARK-8539] [ML] Linear Regression Training and Testing Results

Adds results (e.g. objective value at each iteration, residuals) on training and user-specified test sets for LinearRegressionModel.

Notes to Reviewers:
 * Are the `*TrainingResults` and `Results` classes too specialized for `LinearRegressionModel`? What would be an appropriate level of abstraction?
 * Please check that the `transient` annotations are correct; the datasets should not be copied or retained during serialization.
 * Any thoughts on `RDD`s versus `DataFrame`s? If we use `DataFrame`s, what schemas would you suggest for each intermediate step? Also, how can a "local DataFrame" be created without a `sqlContext`?

Author: Feynman Liang <fl...@databricks.com>

Closes #7099 from feynmanliang/SPARK-8538 and squashes the following commits:

d219fa4 [Feynman Liang] Update docs
4a42680 [Feynman Liang] Change Summary to hold values, move transient annotations down to metrics and predictions DF
6300031 [Feynman Liang] Code review changes
0a5e762 [Feynman Liang] Fix build error
e71102d [Feynman Liang] Merge branch 'master' into SPARK-8538
3367489 [Feynman Liang] Merge branch 'master' into SPARK-8538
70f267c [Feynman Liang] Make TrainingSummary transient and remove Serializable from *Summary and RegressionMetrics
1d9ea42 [Feynman Liang] Fix failing Java test
a65dfda [Feynman Liang] Make TrainingSummary and metrics serializable, prediction dataframe transient
0a605d8 [Feynman Liang] Replace Params from LinearRegression*Summary with private constructor vals
c2fe835 [Feynman Liang] Optimize imports
02d8a70 [Feynman Liang] Add Params to LinearModel*Summary, refactor tests and add test for evaluate()
8f999f4 [Feynman Liang] Refactor from jkbradley code review
072e948 [Feynman Liang] Style
509ae36 [Feynman Liang] Use DFs and localize serialization to LinearRegressionModel
9509c79 [Feynman Liang] Fix imports
b2bbaa3 [Feynman Liang] Refactored LinearRegressionResults API to be more private
ffceaec [Feynman Liang] Merge branch 'master' into SPARK-8538
1cedb2b [Feynman Liang] Add test for decreasing objective trace
dab0aff [Feynman Liang] Add LinearRegressionTrainingResults tests, make test suite code copy+pasteable
97b0a81 [Feynman Liang] Add LinearRegressionModel.evaluate() to get results on test sets
dc51bce [Feynman Liang] Style guide fixes
521f397 [Feynman Liang] Use RDD[(Double, Double)] instead of DF
2ff5710 [Feynman Liang] Add training results and model summary to ML LinearRegression


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/a0cc3e5a
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/a0cc3e5a
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/a0cc3e5a

Branch: refs/heads/master
Commit: a0cc3e5aa3fcfd0fce6813c520152657d327aaf2
Parents: e29ce31
Author: Feynman Liang <fl...@databricks.com>
Authored: Thu Jul 9 16:21:21 2015 -0700
Committer: Joseph K. Bradley <jo...@databricks.com>
Committed: Thu Jul 9 16:21:21 2015 -0700

----------------------------------------------------------------------
 .../spark/ml/regression/LinearRegression.scala  | 139 ++++++++++++++++++-
 .../ml/regression/LinearRegressionSuite.scala   |  59 ++++++++
 2 files changed, 192 insertions(+), 6 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/a0cc3e5a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index f672c96..8fc9860 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -22,18 +22,20 @@ import scala.collection.mutable
 import breeze.linalg.{DenseVector => BDV, norm => brzNorm}
 import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN}
 
-import org.apache.spark.{SparkException, Logging}
+import org.apache.spark.{Logging, SparkException}
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.ml.PredictorParams
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.mllib.evaluation.RegressionMetrics
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.linalg.BLAS._
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.functions.{col, udf}
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.StatCounter
 
@@ -139,7 +141,16 @@ class LinearRegression(override val uid: String)
       logWarning(s"The standard deviation of the label is zero, so the weights will be zeros " +
         s"and the intercept will be the mean of the label; as a result, training is not needed.")
       if (handlePersistence) instances.unpersist()
-      return new LinearRegressionModel(uid, Vectors.sparse(numFeatures, Seq()), yMean)
+      val weights = Vectors.sparse(numFeatures, Seq())
+      val intercept = yMean
+
+      val model = new LinearRegressionModel(uid, weights, intercept)
+      val trainingSummary = new LinearRegressionTrainingSummary(
+        model.transform(dataset).select($(predictionCol), $(labelCol)),
+        $(predictionCol),
+        $(labelCol),
+        Array(0D))
+      return copyValues(model.setSummary(trainingSummary))
     }
 
     val featuresMean = summarizer.mean.toArray
@@ -178,7 +189,6 @@ class LinearRegression(override val uid: String)
         state = states.next()
         arrayBuilder += state.adjustedValue
       }
-
       if (state == null) {
         val msg = s"${optimizer.getClass.getName} failed."
         logError(msg)
@@ -209,7 +219,13 @@ class LinearRegression(override val uid: String)
 
     if (handlePersistence) instances.unpersist()
 
-    copyValues(new LinearRegressionModel(uid, weights, intercept))
+    val model = copyValues(new LinearRegressionModel(uid, weights, intercept))
+    val trainingSummary = new LinearRegressionTrainingSummary(
+      model.transform(dataset).select($(predictionCol), $(labelCol)),
+      $(predictionCol),
+      $(labelCol),
+      objectiveHistory)
+    model.setSummary(trainingSummary)
   }
 
   override def copy(extra: ParamMap): LinearRegression = defaultCopy(extra)
@@ -227,13 +243,124 @@ class LinearRegressionModel private[ml] (
   extends RegressionModel[Vector, LinearRegressionModel]
   with LinearRegressionParams {
 
+  private var trainingSummary: Option[LinearRegressionTrainingSummary] = None
+
+  /**
+   * Gets the summary (e.g. residuals, MSE, r-squared) of the model on the training set. An
+   * exception is thrown if `trainingSummary == None`.
+   */
+  def summary: LinearRegressionTrainingSummary = trainingSummary match {
+    case Some(summ) => summ
+    case None =>
+      throw new SparkException(
+        "No training summary available for this LinearRegressionModel",
+        new NullPointerException())
+  }
+
+  private[regression] def setSummary(summary: LinearRegressionTrainingSummary): this.type = {
+    this.trainingSummary = Some(summary)
+    this
+  }
+
+  /** Indicates whether a training summary exists for this model instance. */
+  def hasSummary: Boolean = trainingSummary.isDefined
+
+  /**
+   * Evaluates the model on a test set.
+   * @param dataset Test dataset to evaluate model on.
+   */
+  // TODO: decide on a good name before exposing to public API
+  private[regression] def evaluate(dataset: DataFrame): LinearRegressionSummary = {
+    val t = udf { features: Vector => predict(features) }
+    val predictionAndObservations = dataset
+      .select(col($(labelCol)), t(col($(featuresCol))).as($(predictionCol)))
+
+    new LinearRegressionSummary(predictionAndObservations, $(predictionCol), $(labelCol))
+  }
+
   override protected def predict(features: Vector): Double = {
     dot(features, weights) + intercept
   }
 
   override def copy(extra: ParamMap): LinearRegressionModel = {
-    copyValues(new LinearRegressionModel(uid, weights, intercept), extra)
+    val newModel = copyValues(new LinearRegressionModel(uid, weights, intercept))
+    if (trainingSummary.isDefined) newModel.setSummary(trainingSummary.get)
+    newModel
+  }
+}
+
+/**
+ * :: Experimental ::
+ * Linear regression training results.
+ * @param predictions predictions outputted by the model's `transform` method.
+ * @param objectiveHistory objective function (scaled loss + regularization) at each iteration.
+ */
+@Experimental
+class LinearRegressionTrainingSummary private[regression] (
+    predictions: DataFrame,
+    predictionCol: String,
+    labelCol: String,
+    val objectiveHistory: Array[Double])
+  extends LinearRegressionSummary(predictions, predictionCol, labelCol) {
+
+  /** Number of training iterations until termination */
+  val totalIterations = objectiveHistory.length
+
+}
+
+/**
+ * :: Experimental ::
+ * Linear regression results evaluated on a dataset.
+ * @param predictions predictions outputted by the model's `transform` method.
+ */
+@Experimental
+class LinearRegressionSummary private[regression] (
+    @transient val predictions: DataFrame,
+    val predictionCol: String,
+    val labelCol: String) extends Serializable {
+
+  @transient private val metrics = new RegressionMetrics(
+    predictions
+      .select(predictionCol, labelCol)
+      .map { case Row(pred: Double, label: Double) => (pred, label) } )
+
+  /**
+   * Returns the explained variance regression score.
+   * explainedVariance = 1 - variance(y - \hat{y}) / variance(y)
+   * Reference: [[http://en.wikipedia.org/wiki/Explained_variation]]
+   */
+  val explainedVariance: Double = metrics.explainedVariance
+
+  /**
+   * Returns the mean absolute error, which is a risk function corresponding to the
+   * expected value of the absolute error loss or l1-norm loss.
+   */
+  val meanAbsoluteError: Double = metrics.meanAbsoluteError
+
+  /**
+   * Returns the mean squared error, which is a risk function corresponding to the
+   * expected value of the squared error loss or quadratic loss.
+   */
+  val meanSquaredError: Double = metrics.meanSquaredError
+
+  /**
+   * Returns the root mean squared error, which is defined as the square root of
+   * the mean squared error.
+   */
+  val rootMeanSquaredError: Double = metrics.rootMeanSquaredError
+
+  /**
+   * Returns R^2^, the coefficient of determination.
+   * Reference: [[http://en.wikipedia.org/wiki/Coefficient_of_determination]]
+   */
+  val r2: Double = metrics.r2
+
+  /** Residuals (predicted value - label value) */
+  @transient lazy val residuals: DataFrame = {
+    val t = udf { (pred: Double, label: Double) => pred - label}
+    predictions.select(t(col(predictionCol), col(labelCol)).as("residuals"))
   }
+
 }
 
 /**

http://git-wip-us.apache.org/repos/asf/spark/blob/a0cc3e5a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
index 4f6a577..cf120cf 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
@@ -289,4 +289,63 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
         assert(prediction1 ~== prediction2 relTol 1E-5)
     }
   }
+
+  test("linear regression model training summary") {
+    val trainer = new LinearRegression
+    val model = trainer.fit(dataset)
+
+    // Training results for the model should be available
+    assert(model.hasSummary)
+
+    // Residuals in [[LinearRegressionSummary]] should equal those manually computed
+    val expectedResiduals = dataset.select("features", "label")
+      .map { case Row(features: DenseVector, label: Double) =>
+      val prediction =
+        features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept
+      prediction - label
+    }
+      .zip(model.summary.residuals.map(_.getDouble(0)))
+      .collect()
+      .foreach { case (manualResidual: Double, resultResidual: Double) =>
+      assert(manualResidual ~== resultResidual relTol 1E-5)
+    }
+
+    /*
+       Use the following R code to generate model training results.
+
+       predictions <- predict(fit, newx=features)
+       residuals <- predictions - label
+       > mean(residuals^2) # MSE
+       [1] 0.009720325
+       > mean(abs(residuals)) # MAD
+       [1] 0.07863206
+       > cor(predictions, label)^2# r^2
+               [,1]
+       s0 0.9998749
+     */
+    assert(model.summary.meanSquaredError ~== 0.00972035 relTol 1E-5)
+    assert(model.summary.meanAbsoluteError ~== 0.07863206  relTol 1E-5)
+    assert(model.summary.r2 ~== 0.9998749 relTol 1E-5)
+
+    // Objective function should be monotonically decreasing for linear regression
+    assert(
+      model.summary
+        .objectiveHistory
+        .sliding(2)
+        .forall(x => x(0) >= x(1)))
+  }
+
+  test("linear regression model testset evaluation summary") {
+    val trainer = new LinearRegression
+    val model = trainer.fit(dataset)
+
+    // Evaluating on training dataset should yield results summary equal to training summary
+    val testSummary = model.evaluate(dataset)
+    assert(model.summary.meanSquaredError ~== testSummary.meanSquaredError relTol 1E-5)
+    assert(model.summary.r2 ~== testSummary.r2 relTol 1E-5)
+    model.summary.residuals.select("residuals").collect()
+      .zip(testSummary.residuals.select("residuals").collect())
+      .forall { case (Row(r1: Double), Row(r2: Double)) => r1 ~== r2 relTol 1E-5 }
+  }
+
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org