You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by sr...@apache.org on 2019/02/25 23:17:19 UTC

[spark] branch master updated: [SPARK-24103][ML][MLLIB] ML Evaluators should use weight column - added weight column for binary classification evaluator

This is an automated email from the ASF dual-hosted git repository.

srowen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new b66be0e  [SPARK-24103][ML][MLLIB] ML Evaluators should use weight column - added weight column for binary classification evaluator
b66be0e is described below

commit b66be0e49098b904d52d0e6412744828890d5354
Author: Ilya Matiach <il...@microsoft.com>
AuthorDate: Mon Feb 25 17:16:51 2019 -0600

    [SPARK-24103][ML][MLLIB] ML Evaluators should use weight column - added weight column for binary classification evaluator
    
    ## What changes were proposed in this pull request?
    
    The evaluators BinaryClassificationEvaluator, RegressionEvaluator, and MulticlassClassificationEvaluator and the corresponding metrics classes BinaryClassificationMetrics, RegressionMetrics and MulticlassMetrics should use sample weight data.
    
    I've closed the earlier PR https://github.com/apache/spark/pull/16557, as recommended,
    in favor of three separate pull requests, one per evaluator (binary/regression/multiclass), to make them easier to review and update.
    
    ## How was this patch tested?
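    I added tests to the metrics and evaluator classes (BinaryClassificationMetricsSuite and BinaryClassificationEvaluatorSuite) and doctests to the PySpark evaluation modules.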
    
    Closes #17084 from imatiach-msft/ilmat/binary-evalute.
    
    Authored-by: Ilya Matiach <il...@microsoft.com>
    Signed-off-by: Sean Owen <se...@databricks.com>
---
 .../evaluation/BinaryClassificationEvaluator.scala | 26 ++++++++++---
 .../evaluation/BinaryClassificationMetrics.scala   | 41 ++++++++++++++------
 .../spark/mllib/evaluation/MulticlassMetrics.scala | 19 ++++++---
 .../BinaryClassificationMetricComputers.scala      | 14 +++----
 .../evaluation/binary/BinaryConfusionMatrix.scala  | 38 +++++++++---------
 .../evaluation/binary/BinaryLabelCounter.scala     | 26 ++++++++-----
 .../BinaryClassificationEvaluatorSuite.scala       | 45 +++++++++++++++++-----
 .../BinaryClassificationMetricsSuite.scala         | 28 ++++++++++++++
 python/pyspark/ml/evaluation.py                    | 20 +++++++---
 python/pyspark/mllib/evaluation.py                 | 30 ++++++++++-----
 10 files changed, 207 insertions(+), 80 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
index bff72b2..c6b0433 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
@@ -36,7 +36,8 @@ import org.apache.spark.sql.types.DoubleType
 @Since("1.2.0")
 @Experimental
 class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override val uid: String)
-  extends Evaluator with HasRawPredictionCol with HasLabelCol with DefaultParamsWritable {
+  extends Evaluator with HasRawPredictionCol with HasLabelCol
+    with HasWeightCol with DefaultParamsWritable {
 
   @Since("1.2.0")
   def this() = this(Identifiable.randomUID("binEval"))
@@ -68,6 +69,10 @@ class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override va
   @Since("1.2.0")
   def setLabelCol(value: String): this.type = set(labelCol, value)
 
+  /** @group setParam */
+  @Since("3.0.0")
+  def setWeightCol(value: String): this.type = set(weightCol, value)
+
   setDefault(metricName -> "areaUnderROC")
 
   @Since("2.0.0")
@@ -75,14 +80,23 @@ class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override va
     val schema = dataset.schema
     SchemaUtils.checkColumnTypes(schema, $(rawPredictionCol), Seq(DoubleType, new VectorUDT))
     SchemaUtils.checkNumericType(schema, $(labelCol))
+    if (isDefined(weightCol)) {
+      SchemaUtils.checkNumericType(schema, $(weightCol))
+    }
 
     // TODO: When dataset metadata has been implemented, check rawPredictionCol vector length = 2.
-    val scoreAndLabels =
-      dataset.select(col($(rawPredictionCol)), col($(labelCol)).cast(DoubleType)).rdd.map {
-        case Row(rawPrediction: Vector, label: Double) => (rawPrediction(1), label)
-        case Row(rawPrediction: Double, label: Double) => (rawPrediction, label)
+    val scoreAndLabelsWithWeights =
+      dataset.select(
+        col($(rawPredictionCol)),
+        col($(labelCol)).cast(DoubleType),
+        if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0)
+        else col($(weightCol)).cast(DoubleType)).rdd.map {
+        case Row(rawPrediction: Vector, label: Double, weight: Double) =>
+          (rawPrediction(1), label, weight)
+        case Row(rawPrediction: Double, label: Double, weight: Double) =>
+          (rawPrediction, label, weight)
       }
-    val metrics = new BinaryClassificationMetrics(scoreAndLabels)
+    val metrics = new BinaryClassificationMetrics(scoreAndLabelsWithWeights)
     val metric = $(metricName) match {
       case "areaUnderROC" => metrics.areaUnderROC()
       case "areaUnderPR" => metrics.areaUnderPR()
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala
index 2cfcf38..cc89edc 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala
@@ -21,12 +21,12 @@ import org.apache.spark.annotation.Since
 import org.apache.spark.internal.Logging
 import org.apache.spark.mllib.evaluation.binary._
 import org.apache.spark.rdd.{RDD, UnionRDD}
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Row}
 
 /**
  * Evaluator for binary classification.
  *
- * @param scoreAndLabels an RDD of (score, label) pairs.
+ * @param scoreAndLabels an RDD of (score, label) or (score, label, weight) tuples.
  * @param numBins if greater than 0, then the curves (ROC curve, PR curve) computed internally
  *                will be down-sampled to this many "bins". If 0, no down-sampling will occur.
  *                This is useful because the curve contains a point for each distinct score
@@ -41,9 +41,19 @@ import org.apache.spark.sql.DataFrame
  *                partition boundaries.
  */
 @Since("1.0.0")
-class BinaryClassificationMetrics @Since("1.3.0") (
-    @Since("1.3.0") val scoreAndLabels: RDD[(Double, Double)],
-    @Since("1.3.0") val numBins: Int) extends Logging {
+class BinaryClassificationMetrics @Since("3.0.0") (
+    @Since("1.3.0") val scoreAndLabels: RDD[_ <: Product],
+    @Since("1.3.0") val numBins: Int = 1000)
+  extends Logging {
+  val scoreLabelsWeight: RDD[(Double, (Double, Double))] = scoreAndLabels.map {
+    case (prediction: Double, label: Double, weight: Double) =>
+      require(weight >= 0.0, s"instance weight, $weight has to be >= 0.0")
+      (prediction, (label, weight))
+    case (prediction: Double, label: Double) =>
+      (prediction, (label, 1.0))
+    case other =>
+      throw new IllegalArgumentException(s"Expected tuples, got $other")
+  }
 
   require(numBins >= 0, "numBins must be nonnegative")
 
@@ -58,7 +68,14 @@ class BinaryClassificationMetrics @Since("1.3.0") (
    * @param scoreAndLabels a DataFrame with two double columns: score and label
    */
   private[mllib] def this(scoreAndLabels: DataFrame) =
-    this(scoreAndLabels.rdd.map(r => (r.getDouble(0), r.getDouble(1))))
+    this(scoreAndLabels.rdd.map {
+      case Row(prediction: Double, label: Double, weight: Double) =>
+        (prediction, label, weight)
+      case Row(prediction: Double, label: Double) =>
+        (prediction, label, 1.0)
+      case other =>
+        throw new IllegalArgumentException(s"Expected Row of tuples, got $other")
+    })
 
   /**
    * Unpersist intermediate RDDs used in the computation.
@@ -146,11 +163,13 @@ class BinaryClassificationMetrics @Since("1.3.0") (
   private lazy val (
     cumulativeCounts: RDD[(Double, BinaryLabelCounter)],
     confusions: RDD[(Double, BinaryConfusionMatrix)]) = {
-    // Create a bin for each distinct score value, count positives and negatives within each bin,
-    // and then sort by score values in descending order.
-    val counts = scoreAndLabels.combineByKey(
-      createCombiner = (label: Double) => new BinaryLabelCounter(0L, 0L) += label,
-      mergeValue = (c: BinaryLabelCounter, label: Double) => c += label,
+    // Create a bin for each distinct score value, count weighted positives and
+    // negatives within each bin, and then sort by score values in descending order.
+    val counts = scoreLabelsWeight.combineByKey(
+      createCombiner = (labelAndWeight: (Double, Double)) =>
+        new BinaryLabelCounter(0.0, 0.0) += (labelAndWeight._1, labelAndWeight._2),
+      mergeValue = (c: BinaryLabelCounter, labelAndWeight: (Double, Double)) =>
+        c += (labelAndWeight._1, labelAndWeight._2),
       mergeCombiners = (c1: BinaryLabelCounter, c2: BinaryLabelCounter) => c1 += c2
     ).sortByKey(ascending = false)
 
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala
index ad83c24..a10f26ba 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala
@@ -22,17 +22,17 @@ import scala.collection.Map
 import org.apache.spark.annotation.Since
 import org.apache.spark.mllib.linalg.{Matrices, Matrix}
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Row}
 
 /**
  * Evaluator for multiclass classification.
  *
- * @param predAndLabelsWithOptWeight an RDD of (prediction, label, weight) or
- *                         (prediction, label) pairs.
+ * @param predictionAndLabels an RDD of (prediction, label, weight) or
+ *                         (prediction, label) tuples.
  */
 @Since("1.1.0")
-class MulticlassMetrics @Since("1.1.0") (predAndLabelsWithOptWeight: RDD[_ <: Product]) {
-  val predLabelsWeight: RDD[(Double, Double, Double)] = predAndLabelsWithOptWeight.map {
+class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[_ <: Product]) {
+  val predLabelsWeight: RDD[(Double, Double, Double)] = predictionAndLabels.map {
     case (prediction: Double, label: Double, weight: Double) =>
       (prediction, label, weight)
     case (prediction: Double, label: Double) =>
@@ -46,7 +46,14 @@ class MulticlassMetrics @Since("1.1.0") (predAndLabelsWithOptWeight: RDD[_ <: Pr
    * @param predictionAndLabels a DataFrame with two double columns: prediction and label
    */
   private[mllib] def this(predictionAndLabels: DataFrame) =
-    this(predictionAndLabels.rdd.map(r => (r.getDouble(0), r.getDouble(1))))
+    this(predictionAndLabels.rdd.map {
+      case Row(prediction: Double, label: Double, weight: Double) =>
+        (prediction, label, weight)
+      case Row(prediction: Double, label: Double) =>
+        (prediction, label, 1.0)
+      case other =>
+        throw new IllegalArgumentException(s"Expected Row of tuples, got $other")
+    })
 
   private lazy val labelCountByClass: Map[Double, Double] =
     predLabelsWeight.map {
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetricComputers.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetricComputers.scala
index 5a4c6ae..d98ca2b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetricComputers.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetricComputers.scala
@@ -27,11 +27,11 @@ private[evaluation] trait BinaryClassificationMetricComputer extends Serializabl
 /** Precision. Defined as 1.0 when there are no positive examples. */
 private[evaluation] object Precision extends BinaryClassificationMetricComputer {
   override def apply(c: BinaryConfusionMatrix): Double = {
-    val totalPositives = c.numTruePositives + c.numFalsePositives
-    if (totalPositives == 0) {
+    val totalPositives = c.weightedTruePositives + c.weightedFalsePositives
+    if (totalPositives == 0.0) {
       1.0
     } else {
-      c.numTruePositives.toDouble / totalPositives
+      c.weightedTruePositives / totalPositives
     }
   }
 }
@@ -39,10 +39,10 @@ private[evaluation] object Precision extends BinaryClassificationMetricComputer
 /** False positive rate. Defined as 0.0 when there are no negative examples. */
 private[evaluation] object FalsePositiveRate extends BinaryClassificationMetricComputer {
   override def apply(c: BinaryConfusionMatrix): Double = {
-    if (c.numNegatives == 0) {
+    if (c.weightedNegatives == 0.0) {
       0.0
     } else {
-      c.numFalsePositives.toDouble / c.numNegatives
+      c.weightedFalsePositives / c.weightedNegatives
     }
   }
 }
@@ -50,10 +50,10 @@ private[evaluation] object FalsePositiveRate extends BinaryClassificationMetricC
 /** Recall. Defined as 0.0 when there are no positive examples. */
 private[evaluation] object Recall extends BinaryClassificationMetricComputer {
   override def apply(c: BinaryConfusionMatrix): Double = {
-    if (c.numPositives == 0) {
+    if (c.weightedPositives == 0.0) {
       0.0
     } else {
-      c.numTruePositives.toDouble / c.numPositives
+      c.weightedTruePositives / c.weightedPositives
     }
   }
 }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryConfusionMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryConfusionMatrix.scala
index 559c6ef..192c9b1 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryConfusionMatrix.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryConfusionMatrix.scala
@@ -21,23 +21,23 @@ package org.apache.spark.mllib.evaluation.binary
  * Trait for a binary confusion matrix.
  */
 private[evaluation] trait BinaryConfusionMatrix {
-  /** number of true positives */
-  def numTruePositives: Long
+  /** weighted number of true positives */
+  def weightedTruePositives: Double
 
-  /** number of false positives */
-  def numFalsePositives: Long
+  /** weighted number of false positives */
+  def weightedFalsePositives: Double
 
-  /** number of false negatives */
-  def numFalseNegatives: Long
+  /** weighted number of false negatives */
+  def weightedFalseNegatives: Double
 
-  /** number of true negatives */
-  def numTrueNegatives: Long
+  /** weighted number of true negatives */
+  def weightedTrueNegatives: Double
 
-  /** number of positives */
-  def numPositives: Long = numTruePositives + numFalseNegatives
+  /** weighted number of positives */
+  def weightedPositives: Double = weightedTruePositives + weightedFalseNegatives
 
-  /** number of negatives */
-  def numNegatives: Long = numFalsePositives + numTrueNegatives
+  /** weighted number of negatives */
+  def weightedNegatives: Double = weightedFalsePositives + weightedTrueNegatives
 }
 
 /**
@@ -51,20 +51,22 @@ private[evaluation] case class BinaryConfusionMatrixImpl(
     totalCount: BinaryLabelCounter) extends BinaryConfusionMatrix {
 
   /** number of true positives */
-  override def numTruePositives: Long = count.numPositives
+  override def weightedTruePositives: Double = count.weightedNumPositives
 
   /** number of false positives */
-  override def numFalsePositives: Long = count.numNegatives
+  override def weightedFalsePositives: Double = count.weightedNumNegatives
 
   /** number of false negatives */
-  override def numFalseNegatives: Long = totalCount.numPositives - count.numPositives
+  override def weightedFalseNegatives: Double =
+    totalCount.weightedNumPositives - count.weightedNumPositives
 
   /** number of true negatives */
-  override def numTrueNegatives: Long = totalCount.numNegatives - count.numNegatives
+  override def weightedTrueNegatives: Double =
+    totalCount.weightedNumNegatives - count.weightedNumNegatives
 
   /** number of positives */
-  override def numPositives: Long = totalCount.numPositives
+  override def weightedPositives: Double = totalCount.weightedNumPositives
 
   /** number of negatives */
-  override def numNegatives: Long = totalCount.numNegatives
+  override def weightedNegatives: Double = totalCount.weightedNumNegatives
 }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryLabelCounter.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryLabelCounter.scala
index 1e610c2..1ad9196 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryLabelCounter.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryLabelCounter.scala
@@ -20,31 +20,39 @@ package org.apache.spark.mllib.evaluation.binary
 /**
  * A counter for positives and negatives.
  *
- * @param numPositives number of positive labels
- * @param numNegatives number of negative labels
+ * @param weightedNumPositives weighted number of positive labels
+ * @param weightedNumNegatives weighted number of negative labels
  */
 private[evaluation] class BinaryLabelCounter(
-    var numPositives: Long = 0L,
-    var numNegatives: Long = 0L) extends Serializable {
+    var weightedNumPositives: Double = 0.0,
+    var weightedNumNegatives: Double = 0.0) extends Serializable {
 
   /** Processes a label. */
   def +=(label: Double): BinaryLabelCounter = {
     // Though we assume 1.0 for positive and 0.0 for negative, the following check will handle
     // -1.0 for negative as well.
-    if (label > 0.5) numPositives += 1L else numNegatives += 1L
+    if (label > 0.5) weightedNumPositives += 1.0 else weightedNumNegatives += 1.0
+    this
+  }
+
+  /** Processes a label with a weight. */
+  def +=(label: Double, weight: Double): BinaryLabelCounter = {
+    // Though we assume 1.0 for positive and 0.0 for negative, the following check will handle
+    // -1.0 for negative as well.
+    if (label > 0.5) weightedNumPositives += weight else weightedNumNegatives += weight
     this
   }
 
   /** Merges another counter. */
   def +=(other: BinaryLabelCounter): BinaryLabelCounter = {
-    numPositives += other.numPositives
-    numNegatives += other.numNegatives
+    weightedNumPositives += other.weightedNumPositives
+    weightedNumNegatives += other.weightedNumNegatives
     this
   }
 
   override def clone: BinaryLabelCounter = {
-    new BinaryLabelCounter(numPositives, numNegatives)
+    new BinaryLabelCounter(weightedNumPositives, weightedNumNegatives)
   }
 
-  override def toString: String = s"{numPos: $numPositives, numNeg: $numNegatives}"
+  override def toString: String = s"{numPos: $weightedNumPositives, numNeg: $weightedNumNegatives}"
 }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala
index 2b0909a..83b213a 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala
@@ -45,23 +45,23 @@ class BinaryClassificationEvaluatorSuite
       .setMetricName("areaUnderPR")
 
     val vectorDF = Seq(
-      (0d, Vectors.dense(12, 2.5)),
-      (1d, Vectors.dense(1, 3)),
-      (0d, Vectors.dense(10, 2))
+      (0.0, Vectors.dense(12, 2.5)),
+      (1.0, Vectors.dense(1, 3)),
+      (0.0, Vectors.dense(10, 2))
     ).toDF("label", "rawPrediction")
     assert(evaluator.evaluate(vectorDF) === 1.0)
 
     val doubleDF = Seq(
-      (0d, 0d),
-      (1d, 1d),
-      (0d, 0d)
+      (0.0, 0.0),
+      (1.0, 1.0),
+      (0.0, 0.0)
     ).toDF("label", "rawPrediction")
     assert(evaluator.evaluate(doubleDF) === 1.0)
 
     val stringDF = Seq(
-      (0d, "0d"),
-      (1d, "1d"),
-      (0d, "0d")
+      (0.0, "0.0"),
+      (1.0, "1.0"),
+      (0.0, "0.0")
     ).toDF("label", "rawPrediction")
     val thrown = intercept[IllegalArgumentException] {
       evaluator.evaluate(stringDF)
@@ -71,6 +71,33 @@ class BinaryClassificationEvaluatorSuite
     assert(thrown.getMessage.replace("\n", "") contains "but was actually of type string.")
   }
 
+  test("should accept weight column") {
+    val weightCol = "weight"
+    // get metric with weight column
+    val evaluator = new BinaryClassificationEvaluator()
+      .setMetricName("areaUnderROC").setWeightCol(weightCol)
+    val vectorDF = Seq(
+      (0.0, Vectors.dense(2.5, 12), 1.0),
+      (1.0, Vectors.dense(1, 3), 1.0),
+      (0.0, Vectors.dense(10, 2), 1.0)
+    ).toDF("label", "rawPrediction", weightCol)
+    val result = evaluator.evaluate(vectorDF)
+    // without weight column
+    val evaluator2 = new BinaryClassificationEvaluator()
+      .setMetricName("areaUnderROC")
+    val result2 = evaluator2.evaluate(vectorDF)
+    assert(result === result2)
+    // use different weights, validate metrics change
+    val vectorDF2 = Seq(
+      (0.0, Vectors.dense(2.5, 12), 2.5),
+      (1.0, Vectors.dense(1, 3), 0.1),
+      (0.0, Vectors.dense(10, 2), 2.0)
+    ).toDF("label", "rawPrediction", weightCol)
+    val result3 = evaluator.evaluate(vectorDF2)
+    // Since wrong result weighted more heavily, expect the score to be lower
+    assert(result3 < result)
+  }
+
   test("should support all NumericType labels and not support other types") {
     val evaluator = new BinaryClassificationEvaluator().setRawPredictionCol("prediction")
     MLTestingUtils.checkNumericTypes(evaluator, spark)
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala
index a08917a..06a522f 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala
@@ -82,6 +82,34 @@ class BinaryClassificationMetricsSuite extends SparkFunSuite with MLlibTestSpark
     validateMetrics(metrics, thresholds, rocCurve, prCurve, f1, f2, precisions, recalls)
   }
 
+  test("binary evaluation metrics with weights") {
+    val w1 = 1.5
+    val w2 = 0.7
+    val w3 = 0.4
+    val scoreAndLabelsWithWeights = sc.parallelize(
+      Seq((0.1, 0.0, w1), (0.1, 1.0, w2), (0.4, 0.0, w1), (0.6, 0.0, w3),
+        (0.6, 1.0, w2), (0.6, 1.0, w2), (0.8, 1.0, w1)), 2)
+    val metrics = new BinaryClassificationMetrics(scoreAndLabelsWithWeights, 0)
+    val thresholds = Seq(0.8, 0.6, 0.4, 0.1)
+    val numTruePositives =
+      Seq(1 * w1, 1 * w1 + 2 * w2, 1 * w1 + 2 * w2, 3 * w2 + 1 * w1)
+    val numFalsePositives = Seq(0.0, 1.0 * w3, 1.0 * w1 + 1.0 * w3, 1.0 * w3 + 2.0 * w1)
+    val numPositives = 3 * w2 + 1 * w1
+    val numNegatives = 2 * w1 + w3
+    val precisions = numTruePositives.zip(numFalsePositives).map { case (t, f) =>
+      t.toDouble / (t + f)
+    }
+    val recalls = numTruePositives.map(_ / numPositives)
+    val fpr = numFalsePositives.map(_ / numNegatives)
+    val rocCurve = Seq((0.0, 0.0)) ++ fpr.zip(recalls) ++ Seq((1.0, 1.0))
+    val pr = recalls.zip(precisions)
+    val prCurve = Seq((0.0, 1.0)) ++ pr
+    val f1 = pr.map { case (r, p) => 2.0 * (p * r) / (p + r)}
+    val f2 = pr.map { case (r, p) => 5.0 * (p * r) / (4.0 * p + r)}
+
+    validateMetrics(metrics, thresholds, rocCurve, prCurve, f1, f2, precisions, recalls)
+  }
+
   test("binary evaluation metrics for RDD where all examples have positive label") {
     val scoreAndLabels = sc.parallelize(Seq((0.5, 1.0), (0.5, 1.0)), 2)
     val metrics = new BinaryClassificationMetrics(scoreAndLabels)
diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py
index f563a2d..0f70860 100644
--- a/python/pyspark/ml/evaluation.py
+++ b/python/pyspark/ml/evaluation.py
@@ -106,7 +106,7 @@ class JavaEvaluator(JavaParams, Evaluator):
 
 
 @inherit_doc
-class BinaryClassificationEvaluator(JavaEvaluator, HasLabelCol, HasRawPredictionCol,
+class BinaryClassificationEvaluator(JavaEvaluator, HasLabelCol, HasRawPredictionCol, HasWeightCol,
                                     JavaMLReadable, JavaMLWritable):
     """
     .. note:: Experimental
@@ -130,6 +130,16 @@ class BinaryClassificationEvaluator(JavaEvaluator, HasLabelCol, HasRawPrediction
     >>> evaluator2 = BinaryClassificationEvaluator.load(bce_path)
     >>> str(evaluator2.getRawPredictionCol())
     'raw'
+    >>> scoreAndLabelsAndWeight = map(lambda x: (Vectors.dense([1.0 - x[0], x[0]]), x[1], x[2]),
+    ...    [(0.1, 0.0, 1.0), (0.1, 1.0, 0.9), (0.4, 0.0, 0.7), (0.6, 0.0, 0.9),
+    ...     (0.6, 1.0, 1.0), (0.6, 1.0, 0.3), (0.8, 1.0, 1.0)])
+    >>> dataset = spark.createDataFrame(scoreAndLabelsAndWeight, ["raw", "label", "weight"])
+    ...
+    >>> evaluator = BinaryClassificationEvaluator(rawPredictionCol="raw", weightCol="weight")
+    >>> evaluator.evaluate(dataset)
+    0.70...
+    >>> evaluator.evaluate(dataset, {evaluator.metricName: "areaUnderPR"})
+    0.82...
 
     .. versionadded:: 1.4.0
     """
@@ -140,10 +150,10 @@ class BinaryClassificationEvaluator(JavaEvaluator, HasLabelCol, HasRawPrediction
 
     @keyword_only
     def __init__(self, rawPredictionCol="rawPrediction", labelCol="label",
-                 metricName="areaUnderROC"):
+                 metricName="areaUnderROC", weightCol=None):
         """
         __init__(self, rawPredictionCol="rawPrediction", labelCol="label", \
-                 metricName="areaUnderROC")
+                 metricName="areaUnderROC", weightCol=None)
         """
         super(BinaryClassificationEvaluator, self).__init__()
         self._java_obj = self._new_java_obj(
@@ -169,10 +179,10 @@ class BinaryClassificationEvaluator(JavaEvaluator, HasLabelCol, HasRawPrediction
     @keyword_only
     @since("1.4.0")
     def setParams(self, rawPredictionCol="rawPrediction", labelCol="label",
-                  metricName="areaUnderROC"):
+                  metricName="areaUnderROC", weightCol=None):
         """
         setParams(self, rawPredictionCol="rawPrediction", labelCol="label", \
-                  metricName="areaUnderROC")
+                  metricName="areaUnderROC", weightCol=None)
         Sets params for binary classification evaluator.
         """
         kwargs = self._input_kwargs
diff --git a/python/pyspark/mllib/evaluation.py b/python/pyspark/mllib/evaluation.py
index b028394..5d8d20d 100644
--- a/python/pyspark/mllib/evaluation.py
+++ b/python/pyspark/mllib/evaluation.py
@@ -30,7 +30,7 @@ class BinaryClassificationMetrics(JavaModelWrapper):
     """
     Evaluator for binary classification.
 
-    :param scoreAndLabels: an RDD of (score, label) pairs
+    :param scoreAndLabels: an RDD of score, label and optional weight.
 
     >>> scoreAndLabels = sc.parallelize([
     ...     (0.1, 0.0), (0.1, 1.0), (0.4, 0.0), (0.6, 0.0), (0.6, 1.0), (0.6, 1.0), (0.8, 1.0)], 2)
@@ -40,6 +40,14 @@ class BinaryClassificationMetrics(JavaModelWrapper):
     >>> metrics.areaUnderPR
     0.83...
     >>> metrics.unpersist()
+    >>> scoreAndLabelsWithOptWeight = sc.parallelize([
+    ...     (0.1, 0.0, 1.0), (0.1, 1.0, 0.4), (0.4, 0.0, 0.2), (0.6, 0.0, 0.6), (0.6, 1.0, 0.9),
+    ...     (0.6, 1.0, 0.5), (0.8, 1.0, 0.7)], 2)
+    >>> metrics = BinaryClassificationMetrics(scoreAndLabelsWithOptWeight)
+    >>> metrics.areaUnderROC
+    0.79...
+    >>> metrics.areaUnderPR
+    0.88...
 
     .. versionadded:: 1.4.0
     """
@@ -47,9 +55,13 @@ class BinaryClassificationMetrics(JavaModelWrapper):
     def __init__(self, scoreAndLabels):
         sc = scoreAndLabels.ctx
         sql_ctx = SQLContext.getOrCreate(sc)
-        df = sql_ctx.createDataFrame(scoreAndLabels, schema=StructType([
+        numCol = len(scoreAndLabels.first())
+        schema = StructType([
             StructField("score", DoubleType(), nullable=False),
-            StructField("label", DoubleType(), nullable=False)]))
+            StructField("label", DoubleType(), nullable=False)])
+        if numCol == 3:
+            schema.add("weight", DoubleType(), False)
+        df = sql_ctx.createDataFrame(scoreAndLabels, schema=schema)
         java_class = sc._jvm.org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
         java_model = java_class(df._jdf)
         super(BinaryClassificationMetrics, self).__init__(java_model)
@@ -162,7 +174,7 @@ class MulticlassMetrics(JavaModelWrapper):
     """
     Evaluator for multiclass classification.
 
-    :param predAndLabelsWithOptWeight: an RDD of prediction, label and optional weight.
+    :param predictionAndLabels: an RDD of prediction, label and optional weight.
 
     >>> predictionAndLabels = sc.parallelize([(0.0, 0.0), (0.0, 1.0), (0.0, 0.0),
     ...     (1.0, 0.0), (1.0, 1.0), (1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)])
@@ -223,16 +235,16 @@ class MulticlassMetrics(JavaModelWrapper):
     .. versionadded:: 1.4.0
     """
 
-    def __init__(self, predAndLabelsWithOptWeight):
-        sc = predAndLabelsWithOptWeight.ctx
+    def __init__(self, predictionAndLabels):
+        sc = predictionAndLabels.ctx
         sql_ctx = SQLContext.getOrCreate(sc)
-        numCol = len(predAndLabelsWithOptWeight.first())
+        numCol = len(predictionAndLabels.first())
         schema = StructType([
             StructField("prediction", DoubleType(), nullable=False),
             StructField("label", DoubleType(), nullable=False)])
-        if (numCol == 3):
+        if numCol == 3:
             schema.add("weight", DoubleType(), False)
-        df = sql_ctx.createDataFrame(predAndLabelsWithOptWeight, schema)
+        df = sql_ctx.createDataFrame(predictionAndLabels, schema)
         java_class = sc._jvm.org.apache.spark.mllib.evaluation.MulticlassMetrics
         java_model = java_class(df._jdf)
         super(MulticlassMetrics, self).__init__(java_model)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org