You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by jk...@apache.org on 2017/06/20 06:03:18 UTC

spark git commit: [SPARK-20929][ML] LinearSVC should use its own threshold param

Repository: spark
Updated Branches:
  refs/heads/master 8965fe764 -> cc67bd573


[SPARK-20929][ML] LinearSVC should use its own threshold param

## What changes were proposed in this pull request?

LinearSVC should use its own threshold param rather than the shared one, since its threshold is applied to rawPrediction instead of to a probability, so the previously documented range of [0, 1] does not apply. This PR changes the param in the Scala, Python and R APIs.

## How was this patch tested?

Added a new unit test to make sure the threshold can be set to any Double value, including positive and negative infinity.

Author: Joseph K. Bradley <jo...@databricks.com>

Closes #18151 from jkbradley/ml-2.2-linearsvc-cleanup.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/cc67bd57
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/cc67bd57
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/cc67bd57

Branch: refs/heads/master
Commit: cc67bd573264c9046c4a034927ed8deb2a732110
Parents: 8965fe7
Author: Joseph K. Bradley <jo...@databricks.com>
Authored: Mon Jun 19 23:04:17 2017 -0700
Committer: Joseph K. Bradley <jo...@databricks.com>
Committed: Mon Jun 19 23:04:17 2017 -0700

----------------------------------------------------------------------
 R/pkg/R/mllib_classification.R                  |  4 ++-
 .../spark/ml/classification/LinearSVC.scala     | 25 ++++++++++++--
 .../ml/classification/LinearSVCSuite.scala      | 35 +++++++++++++++++++-
 python/pyspark/ml/classification.py             | 20 ++++++++++-
 4 files changed, 79 insertions(+), 5 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/cc67bd57/R/pkg/R/mllib_classification.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/mllib_classification.R b/R/pkg/R/mllib_classification.R
index 306a9b8..bdcc081 100644
--- a/R/pkg/R/mllib_classification.R
+++ b/R/pkg/R/mllib_classification.R
@@ -62,7 +62,9 @@ setClass("NaiveBayesModel", representation(jobj = "jobj"))
 #'                        of models will be always returned on the original scale, so it will be transparent for
 #'                        users. Note that with/without standardization, the models should be always converged
 #'                        to the same solution when no regularization is applied.
-#' @param threshold The threshold in binary classification, in range [0, 1].
+#' @param threshold The threshold in binary classification applied to the linear model prediction.
+#'                  This threshold can be any real number, where Inf will make all predictions 0.0
+#'                  and -Inf will make all predictions 1.0.
 #' @param weightCol The weight column name.
 #' @param aggregationDepth The depth for treeAggregate (greater than or equal to 2). If the dimensions of features
 #'                         or the number of partitions are large, this param could be adjusted to a larger size.

http://git-wip-us.apache.org/repos/asf/spark/blob/cc67bd57/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
index 9900fbc..d6ed6a4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
@@ -42,7 +42,23 @@ import org.apache.spark.sql.functions.{col, lit}
 /** Params for linear SVM Classifier. */
 private[classification] trait LinearSVCParams extends ClassifierParams with HasRegParam
   with HasMaxIter with HasFitIntercept with HasTol with HasStandardization with HasWeightCol
-  with HasThreshold with HasAggregationDepth
+  with HasAggregationDepth {
+
+  /**
+   * Param for threshold in binary classification prediction.
+   * For LinearSVC, this threshold is applied to the rawPrediction, rather than a probability.
+   * This threshold can be any real number, where Inf will make all predictions 0.0
+   * and -Inf will make all predictions 1.0.
+   * Default: 0.0
+   *
+   * @group param
+   */
+  final val threshold: DoubleParam = new DoubleParam(this, "threshold",
+    "threshold in binary classification prediction applied to rawPrediction")
+
+  /** @group getParam */
+  def getThreshold: Double = $(threshold)
+}
 
 /**
  * :: Experimental ::
@@ -126,7 +142,7 @@ class LinearSVC @Since("2.2.0") (
   def setWeightCol(value: String): this.type = set(weightCol, value)
 
   /**
-   * Set threshold in binary classification, in range [0, 1].
+   * Set threshold in binary classification.
    *
    * @group setParam
    */
@@ -284,6 +300,7 @@ class LinearSVCModel private[classification] (
 
   @Since("2.2.0")
   def setThreshold(value: Double): this.type = set(threshold, value)
+  setDefault(threshold, 0.0)
 
   @Since("2.2.0")
   def setWeightCol(value: Double): this.type = set(threshold, value)
@@ -301,6 +318,10 @@ class LinearSVCModel private[classification] (
     Vectors.dense(-m, m)
   }
 
+  override protected def raw2prediction(rawPrediction: Vector): Double = {
+    if (rawPrediction(1) > $(threshold)) 1.0 else 0.0
+  }
+
   @Since("2.2.0")
   override def copy(extra: ParamMap): LinearSVCModel = {
     copyValues(new LinearSVCModel(uid, coefficients, intercept), extra).setParent(parent)

http://git-wip-us.apache.org/repos/asf/spark/blob/cc67bd57/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala
index 2f87afc..f2b00d0 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala
@@ -25,7 +25,7 @@ import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.classification.LinearSVCSuite._
 import org.apache.spark.ml.feature.{Instance, LabeledPoint}
 import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
-import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.param.{ParamMap, ParamsSuite}
 import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
 import org.apache.spark.ml.util.TestingUtils._
 import org.apache.spark.mllib.util.MLlibTestSparkContext
@@ -127,6 +127,39 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
     MLTestingUtils.checkCopyAndUids(lsvc, model)
   }
 
+  test("LinearSVC threshold acts on rawPrediction") {
+    val lsvc =
+      new LinearSVCModel(uid = "myLSVCM", coefficients = Vectors.dense(1.0), intercept = 0.0)
+    val df = spark.createDataFrame(Seq(
+      (1, Vectors.dense(1e-7)),
+      (0, Vectors.dense(0.0)),
+      (-1, Vectors.dense(-1e-7)))).toDF("id", "features")
+
+    def checkOneResult(
+        model: LinearSVCModel,
+        threshold: Double,
+        expected: Set[(Int, Double)]): Unit = {
+      model.setThreshold(threshold)
+      val results = model.transform(df).select("id", "prediction").collect()
+        .map(r => (r.getInt(0), r.getDouble(1)))
+        .toSet
+      assert(results === expected, s"Failed for threshold = $threshold")
+    }
+
+    def checkResults(threshold: Double, expected: Set[(Int, Double)]): Unit = {
+      // Check via code path using Classifier.raw2prediction
+      lsvc.setRawPredictionCol("rawPrediction")
+      checkOneResult(lsvc, threshold, expected)
+      // Check via code path using Classifier.predict
+      lsvc.setRawPredictionCol("")
+      checkOneResult(lsvc, threshold, expected)
+    }
+
+    checkResults(0.0, Set((1, 1.0), (0, 0.0), (-1, 0.0)))
+    checkResults(Double.PositiveInfinity, Set((1, 0.0), (0, 0.0), (-1, 0.0)))
+    checkResults(Double.NegativeInfinity, Set((1, 1.0), (0, 1.0), (-1, 1.0)))
+  }
+
   test("linear svc doesn't fit intercept when fitIntercept is off") {
     val lsvc = new LinearSVC().setFitIntercept(false).setMaxIter(5)
     val model = lsvc.fit(smallBinaryDataset)

http://git-wip-us.apache.org/repos/asf/spark/blob/cc67bd57/python/pyspark/ml/classification.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 60bdeed..9b345ac 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -63,7 +63,7 @@ class JavaClassificationModel(JavaPredictionModel):
 @inherit_doc
 class LinearSVC(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
                 HasRegParam, HasTol, HasRawPredictionCol, HasFitIntercept, HasStandardization,
-                HasThreshold, HasWeightCol, HasAggregationDepth, JavaMLWritable, JavaMLReadable):
+                HasWeightCol, HasAggregationDepth, JavaMLWritable, JavaMLReadable):
     """
     .. note:: Experimental
 
@@ -109,6 +109,12 @@ class LinearSVC(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, Ha
     .. versionadded:: 2.2.0
     """
 
+    threshold = Param(Params._dummy(), "threshold",
+                      "The threshold in binary classification applied to the linear model"
+                      " prediction.  This threshold can be any real number, where Inf will make"
+                      " all predictions 0.0 and -Inf will make all predictions 1.0.",
+                      typeConverter=TypeConverters.toFloat)
+
     @keyword_only
     def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                  maxIter=100, regParam=0.0, tol=1e-6, rawPredictionCol="rawPrediction",
@@ -147,6 +153,18 @@ class LinearSVC(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, Ha
     def _create_model(self, java_model):
         return LinearSVCModel(java_model)
 
+    def setThreshold(self, value):
+        """
+        Sets the value of :py:attr:`threshold`.
+        """
+        return self._set(threshold=value)
+
+    def getThreshold(self):
+        """
+        Gets the value of threshold or its default value.
+        """
+        return self.getOrDefault(self.threshold)
+
 
 class LinearSVCModel(JavaModel, JavaClassificationModel, JavaMLWritable, JavaMLReadable):
     """


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org