Posted to commits@spark.apache.org by sr...@apache.org on 2018/12/12 16:07:03 UTC

[spark] branch master updated: [SPARK-24102][ML][MLLIB] ML Evaluators should use weight column - added weight column for regression evaluator

This is an automated email from the ASF dual-hosted git repository.

srowen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 570b8f3  [SPARK-24102][ML][MLLIB] ML Evaluators should use weight column - added weight column for regression evaluator
570b8f3 is described below

commit 570b8f3d45ad8d6649ed633251a8194d910f1ab5
Author: Ilya Matiach <il...@microsoft.com>
AuthorDate: Wed Dec 12 10:06:41 2018 -0600

    [SPARK-24102][ML][MLLIB] ML Evaluators should use weight column - added weight column for regression evaluator
    
    ## What changes were proposed in this pull request?
    
    The evaluators BinaryClassificationEvaluator, RegressionEvaluator, and MulticlassClassificationEvaluator, and the corresponding metrics classes BinaryClassificationMetrics, RegressionMetrics, and MulticlassMetrics, should support sample weight data.
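    
    A minimal usage sketch of the new parameter (the column names and toy values are illustrative, borrowed from the new test below; they are not part of the patch):
    
    ```scala
    import org.apache.spark.ml.evaluation.RegressionEvaluator
    import org.apache.spark.sql.SparkSession
    
    val spark = SparkSession.builder().master("local[2]").getOrCreate()
    import spark.implicits._
    
    // toy (prediction, label, weight) rows
    val df = Seq(
      (2.25, 3.0, 0.1), (-0.25, -0.5, 0.2), (1.75, 2.0, 0.15), (7.75, 7.0, 0.05)
    ).toDF("prediction", "label", "weight")
    
    val evaluator = new RegressionEvaluator()
      .setPredictionCol("prediction")
      .setLabelCol("label")
      .setWeightCol("weight") // if unset or empty, each row gets weight 1.0
      .setMetricName("rmse")
    
    println(evaluator.evaluate(df)) // weighted RMSE
    ```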
    
    As recommended, I've closed the original PR (https://github.com/apache/spark/pull/16557) in favor of three separate pull requests, one for each evaluator (binary/regression/multiclass), so they are easier to review and update.
    
    The updates to the regression metrics are based on https://issues.apache.org/jira/browse/SPARK-11520 ("RegressionMetrics should support instance weights"), updated with new changes from review comments; the pull request attached to that issue was closed and its changes were never checked in.
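    
    In effect, the patch generalizes the unweighted metrics by replacing the sample count with the total weight, as in the change from `summary.count` to `summary.weightSum` below. In my notation (not from the patch), with weights $w_i$, labels $y_i$, predictions $\hat{y}_i$, and $W = \sum_i w_i$:
    
    ```latex
    \mathrm{MAE}  = \frac{1}{W} \sum_i w_i \, |y_i - \hat{y}_i|
    \mathrm{MSE}  = \frac{1}{W} \sum_i w_i \, (y_i - \hat{y}_i)^2
    \mathrm{RMSE} = \sqrt{\mathrm{MSE}}
    \mathrm{explainedVariance} = \frac{1}{W} \sum_i w_i \, (\hat{y}_i - \bar{y}_w)^2
    ```
    
    where $\bar{y}_w$ is the weighted mean of the labels.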
    
    ## How was this patch tested?
    
    I added tests to the metrics class.
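    
    At the RDD level, the metrics class now accepts weighted triples directly; a sketch (assumes an existing SparkContext `sc`; (prediction, observation) pairs are still accepted and imply weight 1.0):
    
    ```scala
    import org.apache.spark.mllib.evaluation.RegressionMetrics
    
    // (prediction, observation, weight) triples
    val rdd = sc.parallelize(Seq(
      (2.25, 3.0, 0.1), (-0.25, -0.5, 0.2), (1.75, 2.0, 0.15), (7.75, 7.0, 0.05)))
    
    val metrics = new RegressionMetrics(rdd)
    println(metrics.meanAbsoluteError) // normL1(1) / weightSum
    println(metrics.meanSquaredError)  // SSerr / weightSum
    ```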
    
    Closes #17085 from imatiach-msft/ilmat/regression-evaluate.
    
    Authored-by: Ilya Matiach <il...@microsoft.com>
    Signed-off-by: Sean Owen <se...@databricks.com>
---
 .../spark/ml/evaluation/RegressionEvaluator.scala  | 19 +++++---
 .../spark/mllib/evaluation/RegressionMetrics.scala | 30 +++++++------
 .../mllib/stat/MultivariateOnlineSummarizer.scala  | 25 ++++++-----
 .../stat/MultivariateStatisticalSummary.scala      |  6 +++
 .../mllib/evaluation/RegressionMetricsSuite.scala  | 50 ++++++++++++++++++++++
 project/MimaExcludes.scala                         |  5 ++-
 6 files changed, 106 insertions(+), 29 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
index 031cd0d..616569b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
@@ -19,7 +19,7 @@ package org.apache.spark.ml.evaluation
 
 import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators}
-import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol}
+import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol, HasWeightCol}
 import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils}
 import org.apache.spark.mllib.evaluation.RegressionMetrics
 import org.apache.spark.sql.{Dataset, Row}
@@ -33,7 +33,8 @@ import org.apache.spark.sql.types.{DoubleType, FloatType}
 @Since("1.4.0")
 @Experimental
 final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val uid: String)
-  extends Evaluator with HasPredictionCol with HasLabelCol with DefaultParamsWritable {
+  extends Evaluator with HasPredictionCol with HasLabelCol
+    with HasWeightCol with DefaultParamsWritable {
 
   @Since("1.4.0")
   def this() = this(Identifiable.randomUID("regEval"))
@@ -69,6 +70,10 @@ final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val ui
   @Since("1.4.0")
   def setLabelCol(value: String): this.type = set(labelCol, value)
 
+  /** @group setParam */
+  @Since("3.0.0")
+  def setWeightCol(value: String): this.type = set(weightCol, value)
+
   setDefault(metricName -> "rmse")
 
   @Since("2.0.0")
@@ -77,11 +82,13 @@ final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val ui
     SchemaUtils.checkColumnTypes(schema, $(predictionCol), Seq(DoubleType, FloatType))
     SchemaUtils.checkNumericType(schema, $(labelCol))
 
-    val predictionAndLabels = dataset
-      .select(col($(predictionCol)).cast(DoubleType), col($(labelCol)).cast(DoubleType))
+    val predictionAndLabelsWithWeights = dataset
+      .select(col($(predictionCol)).cast(DoubleType), col($(labelCol)).cast(DoubleType),
+        if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol)))
       .rdd
-      .map { case Row(prediction: Double, label: Double) => (prediction, label) }
-    val metrics = new RegressionMetrics(predictionAndLabels)
+      .map { case Row(prediction: Double, label: Double, weight: Double) =>
+        (prediction, label, weight) }
+    val metrics = new RegressionMetrics(predictionAndLabelsWithWeights)
     val metric = $(metricName) match {
       case "rmse" => metrics.rootMeanSquaredError
       case "mse" => metrics.meanSquaredError
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala
index 020676c..5250479 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala
@@ -27,17 +27,18 @@ import org.apache.spark.sql.DataFrame
 /**
  * Evaluator for regression.
  *
- * @param predictionAndObservations an RDD of (prediction, observation) pairs
+ * @param predAndObsWithOptWeight an RDD of (prediction, observation, weight) triples
+ *                                or (prediction, observation) pairs
  * @param throughOrigin True if the regression is through the origin. For example, in linear
  *                      regression, it will be true without fitting intercept.
  */
 @Since("1.2.0")
 class RegressionMetrics @Since("2.0.0") (
-    predictionAndObservations: RDD[(Double, Double)], throughOrigin: Boolean)
+    predAndObsWithOptWeight: RDD[_ <: Product], throughOrigin: Boolean)
     extends Logging {
 
   @Since("1.2.0")
-  def this(predictionAndObservations: RDD[(Double, Double)]) =
+  def this(predictionAndObservations: RDD[_ <: Product]) =
     this(predictionAndObservations, false)
 
   /**
@@ -52,10 +53,13 @@ class RegressionMetrics @Since("2.0.0") (
    * Use MultivariateOnlineSummarizer to calculate summary statistics of observations and errors.
    */
   private lazy val summary: MultivariateStatisticalSummary = {
-    val summary: MultivariateStatisticalSummary = predictionAndObservations.map {
-      case (prediction, observation) => Vectors.dense(observation, observation - prediction)
+    val summary: MultivariateStatisticalSummary = predAndObsWithOptWeight.map {
+      case (prediction: Double, observation: Double, weight: Double) =>
+        (Vectors.dense(observation, observation - prediction), weight)
+      case (prediction: Double, observation: Double) =>
+        (Vectors.dense(observation, observation - prediction), 1.0)
     }.treeAggregate(new MultivariateOnlineSummarizer())(
-        (summary, v) => summary.add(v),
+        (summary, sample) => summary.add(sample._1, sample._2),
         (sum1, sum2) => sum1.merge(sum2)
       )
     summary
@@ -63,11 +67,13 @@ class RegressionMetrics @Since("2.0.0") (
 
   private lazy val SSy = math.pow(summary.normL2(0), 2)
   private lazy val SSerr = math.pow(summary.normL2(1), 2)
-  private lazy val SStot = summary.variance(0) * (summary.count - 1)
+  private lazy val SStot = summary.variance(0) * (summary.weightSum - 1)
   private lazy val SSreg = {
     val yMean = summary.mean(0)
-    predictionAndObservations.map {
-      case (prediction, _) => math.pow(prediction - yMean, 2)
+    predAndObsWithOptWeight.map {
+      case (prediction: Double, _: Double, weight: Double) =>
+        math.pow(prediction - yMean, 2) * weight
+      case (prediction: Double, _: Double) => math.pow(prediction - yMean, 2)
     }.sum()
   }
 
@@ -79,7 +85,7 @@ class RegressionMetrics @Since("2.0.0") (
    */
   @Since("1.2.0")
   def explainedVariance: Double = {
-    SSreg / summary.count
+    SSreg / summary.weightSum
   }
 
   /**
@@ -88,7 +94,7 @@ class RegressionMetrics @Since("2.0.0") (
    */
   @Since("1.2.0")
   def meanAbsoluteError: Double = {
-    summary.normL1(1) / summary.count
+    summary.normL1(1) / summary.weightSum
   }
 
   /**
@@ -97,7 +103,7 @@ class RegressionMetrics @Since("2.0.0") (
    */
   @Since("1.2.0")
   def meanSquaredError: Double = {
-    SSerr / summary.count
+    SSerr / summary.weightSum
   }
 
   /**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
index 0554b6d..6d510e1 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
@@ -52,7 +52,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
   private var totalCnt: Long = 0
   private var totalWeightSum: Double = 0.0
   private var weightSquareSum: Double = 0.0
-  private var weightSum: Array[Double] = _
+  private var currWeightSum: Array[Double] = _
   private var nnz: Array[Long] = _
   private var currMax: Array[Double] = _
   private var currMin: Array[Double] = _
@@ -78,7 +78,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
       currM2n = Array.ofDim[Double](n)
       currM2 = Array.ofDim[Double](n)
       currL1 = Array.ofDim[Double](n)
-      weightSum = Array.ofDim[Double](n)
+      currWeightSum = Array.ofDim[Double](n)
       nnz = Array.ofDim[Long](n)
       currMax = Array.fill[Double](n)(Double.MinValue)
       currMin = Array.fill[Double](n)(Double.MaxValue)
@@ -91,7 +91,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
     val localCurrM2n = currM2n
     val localCurrM2 = currM2
     val localCurrL1 = currL1
-    val localWeightSum = weightSum
+    val localWeightSum = currWeightSum
     val localNumNonzeros = nnz
     val localCurrMax = currMax
     val localCurrMin = currMin
@@ -139,8 +139,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
       weightSquareSum += other.weightSquareSum
       var i = 0
       while (i < n) {
-        val thisNnz = weightSum(i)
-        val otherNnz = other.weightSum(i)
+        val thisNnz = currWeightSum(i)
+        val otherNnz = other.currWeightSum(i)
         val totalNnz = thisNnz + otherNnz
         val totalCnnz = nnz(i) + other.nnz(i)
         if (totalNnz != 0.0) {
@@ -157,7 +157,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
           currMax(i) = math.max(currMax(i), other.currMax(i))
           currMin(i) = math.min(currMin(i), other.currMin(i))
         }
-        weightSum(i) = totalNnz
+        currWeightSum(i) = totalNnz
         nnz(i) = totalCnnz
         i += 1
       }
@@ -170,7 +170,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
       this.totalCnt = other.totalCnt
       this.totalWeightSum = other.totalWeightSum
       this.weightSquareSum = other.weightSquareSum
-      this.weightSum = other.weightSum.clone()
+      this.currWeightSum = other.currWeightSum.clone()
       this.nnz = other.nnz.clone()
       this.currMax = other.currMax.clone()
       this.currMin = other.currMin.clone()
@@ -189,7 +189,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
     val realMean = Array.ofDim[Double](n)
     var i = 0
     while (i < n) {
-      realMean(i) = currMean(i) * (weightSum(i) / totalWeightSum)
+      realMean(i) = currMean(i) * (currWeightSum(i) / totalWeightSum)
       i += 1
     }
     Vectors.dense(realMean)
@@ -214,8 +214,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
       val len = currM2n.length
       while (i < len) {
         // We prevent variance from negative value caused by numerical error.
-        realVariance(i) = math.max((currM2n(i) + deltaMean(i) * deltaMean(i) * weightSum(i) *
-          (totalWeightSum - weightSum(i)) / totalWeightSum) / denominator, 0.0)
+        realVariance(i) = math.max((currM2n(i) + deltaMean(i) * deltaMean(i) * currWeightSum(i) *
+          (totalWeightSum - currWeightSum(i)) / totalWeightSum) / denominator, 0.0)
         i += 1
       }
     }
@@ -230,6 +230,11 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
   override def count: Long = totalCnt
 
   /**
+   * Sum of weights.
+   */
+  override def weightSum: Double = totalWeightSum
+
+  /**
    * Number of nonzero elements in each dimension.
    *
    */
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala
index 39a16fb7..a438103 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala
@@ -45,6 +45,12 @@ trait MultivariateStatisticalSummary {
   def count: Long
 
   /**
+   * Sum of weights.
+   */
+  @Since("3.0.0")
+  def weightSum: Double
+
+  /**
    * Number of nonzero elements (including explicitly presented zero values) in each column.
    */
   @Since("1.0.0")
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala
index f1d5173..2380977 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala
@@ -133,4 +133,54 @@ class RegressionMetricsSuite extends SparkFunSuite with MLlibTestSparkContext {
       "root mean squared error mismatch")
     assert(metrics.r2 ~== 1.0 absTol eps, "r2 score mismatch")
   }
+
+  test("regression metrics with same (1.0) weight samples") {
+    val predictionAndObservationWithWeight = sc.parallelize(
+      Seq((2.25, 3.0, 1.0), (-0.25, -0.5, 1.0), (1.75, 2.0, 1.0), (7.75, 7.0, 1.0)), 2)
+    val metrics = new RegressionMetrics(predictionAndObservationWithWeight, false)
+    assert(metrics.explainedVariance ~== 8.79687 absTol eps,
+      "explained variance regression score mismatch")
+    assert(metrics.meanAbsoluteError ~== 0.5 absTol eps, "mean absolute error mismatch")
+    assert(metrics.meanSquaredError ~== 0.3125 absTol eps, "mean squared error mismatch")
+    assert(metrics.rootMeanSquaredError ~== 0.55901 absTol eps,
+      "root mean squared error mismatch")
+    assert(metrics.r2 ~== 0.95717 absTol eps, "r2 score mismatch")
+  }
+
+  /**
+   * The following values are hand calculated using the formula:
+   * [[https://en.wikipedia.org/wiki/Weighted_arithmetic_mean#Reliability_weights]]
+   * preds = c(2.25, -0.25, 1.75, 7.75)
+   * obs = c(3.0, -0.5, 2.0, 7.0)
+   * weights = c(0.1, 0.2, 0.15, 0.05)
+   * count = 4
+   *
+   * Weighted metrics can be calculated with MultivariateStatisticalSummary.
+   *             (observations, observations - predictions)
+   * mean        (1.7, 0.05)
+   * variance    (7.3, 0.3)
+   * numNonZeros (0.5, 0.5)
+   * max         (7.0, 0.75)
+   * min         (-0.5, -0.75)
+   * normL2      (2.0, 0.32596)
+   * normL1      (1.05, 0.2)
+   *
+   * explainedVariance: sum(pow((preds - 1.7),2)*weight) / weightedCount = 5.2425
+   * meanAbsoluteError: normL1(1) / weightedCount = 0.4
+   * meanSquaredError: pow(normL2(1),2) / weightedCount = 0.2125
+   * rootMeanSquaredError: sqrt(meanSquaredError) = 0.46098
+   * r2: 1 - pow(normL2(1),2) / (variance(0) * (weightedCount - 1)) = 1.02910
+   */
+  test("regression metrics with weighted samples") {
+    val predictionAndObservationWithWeight = sc.parallelize(
+      Seq((2.25, 3.0, 0.1), (-0.25, -0.5, 0.2), (1.75, 2.0, 0.15), (7.75, 7.0, 0.05)), 2)
+    val metrics = new RegressionMetrics(predictionAndObservationWithWeight, false)
+    assert(metrics.explainedVariance ~== 5.2425 absTol eps,
+      "explained variance regression score mismatch")
+    assert(metrics.meanAbsoluteError ~== 0.4 absTol eps, "mean absolute error mismatch")
+    assert(metrics.meanSquaredError ~== 0.2125 absTol eps, "mean squared error mismatch")
+    assert(metrics.rootMeanSquaredError ~== 0.46098 absTol eps,
+      "root mean squared error mismatch")
+    assert(metrics.r2 ~== 1.02910 absTol eps, "r2 score mismatch")
+  }
 }
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index b3252d7..8839133 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -531,7 +531,10 @@ object MimaExcludes {
     ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.toDenseColMajor"),
     ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.toDenseMatrix"),
     ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.toSparseMatrix"),
-    ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.getSizeInBytes")
+    ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.getSizeInBytes"),
+
+    // [SPARK-18693] Added weightSum to trait MultivariateStatisticalSummary
+    ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.stat.MultivariateStatisticalSummary.weightSum")
   ) ++ Seq(
       // [SPARK-17019] Expose on-heap and off-heap memory usage in various places
       ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.SparkListenerBlockManagerAdded.copy"),

