You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by sr...@apache.org on 2016/09/24 07:16:00 UTC

spark git commit: [SPARK-17057][ML] ProbabilisticClassifierModels' thresholds should have at most one 0

Repository: spark
Updated Branches:
  refs/heads/master f3fe55439 -> 248916f55


[SPARK-17057][ML] ProbabilisticClassifierModels' thresholds should have at most one 0

## What changes were proposed in this pull request?

Match ProbabilisticClassifier.thresholds requirements to R randomForest cutoff, requiring all values > 0, excepting that at most one value may be 0

## How was this patch tested?

Jenkins tests plus new test cases

Author: Sean Owen <so...@cloudera.com>

Closes #15149 from srowen/SPARK-17057.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/248916f5
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/248916f5
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/248916f5

Branch: refs/heads/master
Commit: 248916f5589155c0c3e93c3874781f17b08d598d
Parents: f3fe554
Author: Sean Owen <so...@cloudera.com>
Authored: Sat Sep 24 08:15:55 2016 +0100
Committer: Sean Owen <so...@cloudera.com>
Committed: Sat Sep 24 08:15:55 2016 +0100

----------------------------------------------------------------------
 .../ml/classification/LogisticRegression.scala  |  5 +--
 .../ProbabilisticClassifier.scala               | 20 +++++------
 .../ml/param/shared/SharedParamsCodeGen.scala   |  8 +++--
 .../spark/ml/param/shared/sharedParams.scala    |  4 +--
 .../ProbabilisticClassifierSuite.scala          | 35 ++++++++++++++++----
 .../pyspark/ml/param/_shared_params_code_gen.py |  5 +--
 python/pyspark/ml/param/shared.py               |  4 +--
 7 files changed, 52 insertions(+), 29 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/248916f5/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 343d50c..5ab63d1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -123,9 +123,10 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas
 
   /**
    * Set thresholds in multiclass (or binary) classification to adjust the probability of
-   * predicting each class. Array must have length equal to the number of classes, with values >= 0.
+   * predicting each class. Array must have length equal to the number of classes, with values > 0,
+   * excepting that at most one value may be 0.
    * The class with largest value p/t is predicted, where p is the original probability of that
-   * class and t is the class' threshold.
+   * class and t is the class's threshold.
    *
    * Note: When [[setThresholds()]] is called, any user-set value for [[threshold]] will be cleared.
    *       If both [[threshold]] and [[thresholds]] are set in a ParamMap, then they must be

http://git-wip-us.apache.org/repos/asf/spark/blob/248916f5/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
index 1b6e775..e89da6f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.ml.classification
 
 import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.ml.linalg.{DenseVector, Vector, Vectors, VectorUDT}
+import org.apache.spark.ml.linalg.{DenseVector, Vector, VectorUDT}
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util.SchemaUtils
 import org.apache.spark.sql.{DataFrame, Dataset}
@@ -200,22 +200,20 @@ abstract class ProbabilisticClassificationModel[
     if (!isDefined(thresholds)) {
       probability.argmax
     } else {
-      val thresholds: Array[Double] = getThresholds
-      val probabilities = probability.toArray
+      val thresholds = getThresholds
       var argMax = 0
       var max = Double.NegativeInfinity
       var i = 0
       val probabilitySize = probability.size
       while (i < probabilitySize) {
-        if (thresholds(i) == 0.0) {
-          max = Double.PositiveInfinity
+        // Thresholds are all > 0, excepting that at most one may be 0.
+        // The single class whose threshold is 0, if any, will always be predicted
+        // ('scaled' = +Infinity). However in the case that this class also has
+        // 0 probability, the class will not be selected ('scaled' is NaN).
+        val scaled = probability(i) / thresholds(i)
+        if (scaled > max) {
+          max = scaled
           argMax = i
-        } else {
-          val scaled = probabilities(i) / thresholds(i)
-          if (scaled > max) {
-            max = scaled
-            argMax = i
-          }
         }
         i += 1
       }

http://git-wip-us.apache.org/repos/asf/spark/blob/248916f5/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
index 480b03d..c94b8b4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
@@ -50,10 +50,12 @@ private[shared] object SharedParamsCodeGen {
         isValid = "ParamValidators.inRange(0, 1)", finalMethods = false),
       ParamDesc[Array[Double]]("thresholds", "Thresholds in multi-class classification" +
         " to adjust the probability of predicting each class." +
-        " Array must have length equal to the number of classes, with values >= 0." +
+        " Array must have length equal to the number of classes, with values > 0" +
+        " excepting that at most one value may be 0." +
         " The class with largest value p/t is predicted, where p is the original probability" +
-        " of that class and t is the class' threshold",
-        isValid = "(t: Array[Double]) => t.forall(_ >= 0)", finalMethods = false),
+        " of that class and t is the class's threshold",
+        isValid = "(t: Array[Double]) => t.forall(_ >= 0) && t.count(_ == 0) <= 1",
+        finalMethods = false),
       ParamDesc[String]("inputCol", "input column name"),
       ParamDesc[Array[String]]("inputCols", "input column names"),
       ParamDesc[String]("outputCol", "output column name", Some("uid + \"__output\"")),

http://git-wip-us.apache.org/repos/asf/spark/blob/248916f5/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
index 9125d9e..fa45309 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
@@ -176,10 +176,10 @@ private[ml] trait HasThreshold extends Params {
 private[ml] trait HasThresholds extends Params {
 
   /**
-   * Param for Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold.
+   * Param for Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values > 0 excepting that at most one value may be 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class's threshold.
    * @group param
    */
-  final val thresholds: DoubleArrayParam = new DoubleArrayParam(this, "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold", (t: Array[Double]) => t.forall(_ >= 0))
+  final val thresholds: DoubleArrayParam = new DoubleArrayParam(this, "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values > 0 excepting that at most one value may be 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class's threshold", (t: Array[Double]) => t.forall(_ >= 0) && t.count(_ == 0) <= 1)
 
   /** @group getParam */
   def getThresholds: Array[Double] = $(thresholds)

http://git-wip-us.apache.org/repos/asf/spark/blob/248916f5/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala
index b3bd2b3..172c64a 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala
@@ -36,8 +36,8 @@ final class TestProbabilisticClassificationModel(
     rawPrediction
   }
 
-  def friendlyPredict(input: Vector): Double = {
-    predict(input)
+  def friendlyPredict(values: Double*): Double = {
+    predict(Vectors.dense(values.toArray))
   }
 }
 
@@ -45,16 +45,37 @@ final class TestProbabilisticClassificationModel(
 class ProbabilisticClassifierSuite extends SparkFunSuite {
 
   test("test thresholding") {
-    val thresholds = Array(0.5, 0.2)
     val testModel = new TestProbabilisticClassificationModel("myuid", 2, 2)
-      .setThresholds(thresholds)
-    assert(testModel.friendlyPredict(Vectors.dense(Array(1.0, 1.0))) === 1.0)
-    assert(testModel.friendlyPredict(Vectors.dense(Array(1.0, 0.2))) === 0.0)
+      .setThresholds(Array(0.5, 0.2))
+    assert(testModel.friendlyPredict(1.0, 1.0) === 1.0)
+    assert(testModel.friendlyPredict(1.0, 0.2) === 0.0)
   }
 
   test("test thresholding not required") {
     val testModel = new TestProbabilisticClassificationModel("myuid", 2, 2)
-    assert(testModel.friendlyPredict(Vectors.dense(Array(1.0, 2.0))) === 1.0)
+    assert(testModel.friendlyPredict(1.0, 2.0) === 1.0)
+  }
+
+  test("test tiebreak") {
+    val testModel = new TestProbabilisticClassificationModel("myuid", 2, 2)
+      .setThresholds(Array(0.4, 0.4))
+    assert(testModel.friendlyPredict(0.6, 0.6) === 0.0)
+  }
+
+  test("test one zero threshold") {
+    val testModel = new TestProbabilisticClassificationModel("myuid", 2, 2)
+      .setThresholds(Array(0.0, 0.1))
+    assert(testModel.friendlyPredict(1.0, 10.0) === 0.0)
+    assert(testModel.friendlyPredict(0.0, 10.0) === 1.0)
+  }
+
+  test("bad thresholds") {
+    intercept[IllegalArgumentException] {
+      new TestProbabilisticClassificationModel("myuid", 2, 2).setThresholds(Array(0.0, 0.0))
+    }
+    intercept[IllegalArgumentException] {
+      new TestProbabilisticClassificationModel("myuid", 2, 2).setThresholds(Array(-0.1, 0.1))
+    }
   }
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/248916f5/python/pyspark/ml/param/_shared_params_code_gen.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py
index 4f4328b..9295912 100644
--- a/python/pyspark/ml/param/_shared_params_code_gen.py
+++ b/python/pyspark/ml/param/_shared_params_code_gen.py
@@ -139,8 +139,9 @@ if __name__ == "__main__":
          "model.", "True", "TypeConverters.toBoolean"),
         ("thresholds", "Thresholds in multi-class classification to adjust the probability of " +
          "predicting each class. Array must have length equal to the number of classes, with " +
-         "values >= 0. The class with largest value p/t is predicted, where p is the original " +
-         "probability of that class and t is the class' threshold.", None,
+         "values > 0, excepting that at most one value may be 0. " +
+         "The class with largest value p/t is predicted, where p is the original " +
+         "probability of that class and t is the class's threshold.", None,
          "TypeConverters.toListFloat"),
         ("weightCol", "weight column name. If this is not set or empty, we treat " +
          "all instance weights as 1.0.", None, "TypeConverters.toString"),

http://git-wip-us.apache.org/repos/asf/spark/blob/248916f5/python/pyspark/ml/param/shared.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py
index 24af07a..cc59693 100644
--- a/python/pyspark/ml/param/shared.py
+++ b/python/pyspark/ml/param/shared.py
@@ -469,10 +469,10 @@ class HasStandardization(Params):
 
 class HasThresholds(Params):
     """
-    Mixin for param thresholds: Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold.
+    Mixin for param thresholds: Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values > 0, excepting that at most one value may be 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class's threshold.
     """
 
-    thresholds = Param(Params._dummy(), "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold.", typeConverter=TypeConverters.toListFloat)
+    thresholds = Param(Params._dummy(), "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values > 0, excepting that at most one value may be 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class's threshold.", typeConverter=TypeConverters.toListFloat)
 
     def __init__(self):
         super(HasThresholds, self).__init__()


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org