Posted to commits@spark.apache.org by jk...@apache.org on 2015/05/12 23:24:30 UTC

spark git commit: [SPARK-7559] [MLLIB] Bucketizer should include the right most boundary in the last bucket.

Repository: spark
Updated Branches:
  refs/heads/master 2a41c0d71 -> 23b9863e2


[SPARK-7559] [MLLIB] Bucketizer should include the right most boundary in the last bucket.

We currently give special treatment to +inf in `Bucketizer`. This can be simplified by always including the largest split value in the last bucket. E.g., splits (x1, x2, x3) define buckets [x1, x2) and [x2, x3]. This shouldn't affect user code much, and there are applications that need to include the right-most value. For example, we can bucketize ratings from 0 to 10 into bad, neutral, and good with splits 0, 4, 6, 10. It would read oddly if users had to put 0, 4, 6, 10.1 (or 11).
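
For illustration, a minimal usage sketch of the new semantics (column names and sample data are invented for this example; it assumes a SQLContext named `sqlContext`, as in the test suite below):

    import org.apache.spark.ml.feature.Bucketizer

    val splits = Array(0.0, 4.0, 6.0, 10.0)  // 3 buckets: [0,4) bad, [4,6) neutral, [6,10] good
    val bucketizer = new Bucketizer()
      .setInputCol("rating")
      .setOutputCol("ratingBucket")
      .setSplits(splits)
    val ratings = sqlContext.createDataFrame(
      Seq(0.0, 3.9, 4.0, 6.5, 10.0).map(Tuple1.apply)).toDF("rating")
    // With this patch, rating 10.0 falls in the last bucket (index 2.0)
    // instead of raising an out-of-bounds error.
    bucketizer.transform(ratings).show()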

This also updates the implementation to use `Arrays.binarySearch`, and `withClue` in tests.

yinxusen jkbradley

Author: Xiangrui Meng <me...@databricks.com>

Closes #6075 from mengxr/SPARK-7559 and squashes the following commits:

e28f910 [Xiangrui Meng] update bucketizer impl


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/23b9863e
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/23b9863e
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/23b9863e

Branch: refs/heads/master
Commit: 23b9863e2aa7ecd0c4fa3aa8a59fdae09b4fe1d7
Parents: 2a41c0d
Author: Xiangrui Meng <me...@databricks.com>
Authored: Tue May 12 14:24:26 2015 -0700
Committer: Joseph K. Bradley <jo...@databricks.com>
Committed: Tue May 12 14:24:26 2015 -0700

----------------------------------------------------------------------
 .../apache/spark/ml/feature/Bucketizer.scala    | 55 ++++++++++----------
 .../spark/ml/feature/BucketizerSuite.scala      | 25 ++++-----
 2 files changed, 41 insertions(+), 39 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/23b9863e/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
index 7dba64b..b28c88a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
@@ -17,6 +17,9 @@
 
 package org.apache.spark.ml.feature
 
+import java.{util => ju}
+
+import org.apache.spark.SparkException
 import org.apache.spark.annotation.AlphaComponent
 import org.apache.spark.ml.attribute.NominalAttribute
 import org.apache.spark.ml.param._
@@ -38,18 +41,19 @@ final class Bucketizer private[ml] (override val parent: Estimator[Bucketizer])
   def this() = this(null)
 
   /**
-   * Parameter for mapping continuous features into buckets. With n splits, there are n+1 buckets.
-   * A bucket defined by splits x,y holds values in the range [x,y). Splits should be strictly
-   * increasing. Values at -inf, inf must be explicitly provided to cover all Double values;
+   * Parameter for mapping continuous features into buckets. With n+1 splits, there are n buckets.
+   * A bucket defined by splits x,y holds values in the range [x,y) except the last bucket, which
+   * also includes y. Splits should be strictly increasing.
+   * Values at -inf, inf must be explicitly provided to cover all Double values;
    * otherwise, values outside the splits specified will be treated as errors.
    * @group param
    */
   val splits: Param[Array[Double]] = new Param[Array[Double]](this, "splits",
-    "Split points for mapping continuous features into buckets. With n splits, there are n+1 " +
-      "buckets. A bucket defined by splits x,y holds values in the range [x,y). The splits " +
-      "should be strictly increasing. Values at -inf, inf must be explicitly provided to cover" +
-      " all Double values; otherwise, values outside the splits specified will be treated as" +
-      " errors.",
+    "Split points for mapping continuous features into buckets. With n+1 splits, there are n " +
+      "buckets. A bucket defined by splits x,y holds values in the range [x,y) except the last " +
+      "bucket, which also includes y. The splits should be strictly increasing. " +
+      "Values at -inf, inf must be explicitly provided to cover all Double values; " +
+      "otherwise, values outside the splits specified will be treated as errors.",
     Bucketizer.checkSplits)
 
   /** @group getParam */
@@ -104,28 +108,25 @@ private[feature] object Bucketizer {
 
   /**
    * Binary searching in several buckets to place each data point.
-   * @throws RuntimeException if a feature is < splits.head or >= splits.last
+   * @throws SparkException if a feature is < splits.head or > splits.last
    */
-  def binarySearchForBuckets(
-      splits: Array[Double],
-      feature: Double): Double = {
-    // Check bounds.  We make an exception for +inf so that it can exist in some bin.
-    if ((feature < splits.head) || (feature >= splits.last && feature != Double.PositiveInfinity)) {
-      throw new RuntimeException(s"Feature value $feature out of Bucketizer bounds" +
-        s" [${splits.head}, ${splits.last}).  Check your features, or loosen " +
-        s"the lower/upper bound constraints.")
-    }
-    var left = 0
-    var right = splits.length - 2
-    while (left < right) {
-      val mid = (left + right) / 2
-      val split = splits(mid + 1)
-      if (feature < split) {
-        right = mid
+  def binarySearchForBuckets(splits: Array[Double], feature: Double): Double = {
+    if (feature == splits.last) {
+      splits.length - 2
+    } else {
+      val idx = ju.Arrays.binarySearch(splits, feature)
+      if (idx >= 0) {
+        idx
       } else {
-        left = mid + 1
+        val insertPos = -idx - 1
+        if (insertPos == 0 || insertPos == splits.length) {
+          throw new SparkException(s"Feature value $feature out of Bucketizer bounds" +
+            s" [${splits.head}, ${splits.last}].  Check your features, or loosen " +
+            s"the lower/upper bound constraints.")
+        } else {
+          insertPos - 1
+        }
       }
     }
-    left
   }
 }
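
The new implementation leans on java.util.Arrays.binarySearch's return convention: for a key that is not in the array, it returns -(insertionPoint) - 1, so `-idx - 1` recovers the insertion point. A standalone sketch of the bucket lookup above (plain Scala, no Spark required):

    import java.{util => ju}

    val splits = Array(0.0, 4.0, 6.0, 10.0)
    ju.Arrays.binarySearch(splits, 4.0)            // 1: exact hit -> bucket 1
    val idx = ju.Arrays.binarySearch(splits, 5.0)  // -3: 5.0 not in the array
    val insertPos = -idx - 1                       // 2: 5.0 would be inserted at index 2
    val bucket = insertPos - 1                     // 1: 5.0 falls in bucket [4.0, 6.0)
    // feature == splits.last is special-cased before the search and maps to
    // splits.length - 2 (the last bucket). insertPos == 0 or
    // insertPos == splits.length means the feature lies outside
    // [splits.head, splits.last] and triggers the SparkException above.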

http://git-wip-us.apache.org/repos/asf/spark/blob/23b9863e/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
index acb46c0..1900820 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
@@ -57,16 +57,18 @@ class BucketizerSuite extends FunSuite with MLlibTestSparkContext {
 
     // Check for exceptions when using a set of invalid feature values.
     val invalidData1: Array[Double] = Array(-0.9) ++ validData
-    val invalidData2 = Array(0.5) ++ validData
+    val invalidData2 = Array(0.51) ++ validData
     val badDF1 = sqlContext.createDataFrame(invalidData1.zipWithIndex).toDF("feature", "idx")
-    intercept[RuntimeException]{
-      bucketizer.transform(badDF1).collect()
-      println("Invalid feature value -0.9 was not caught as an invalid feature!")
+    withClue("Invalid feature value -0.9 was not caught as an invalid feature!") {
+      intercept[SparkException] {
+        bucketizer.transform(badDF1).collect()
+      }
     }
     val badDF2 = sqlContext.createDataFrame(invalidData2.zipWithIndex).toDF("feature", "idx")
-    intercept[RuntimeException]{
-      bucketizer.transform(badDF2).collect()
-      println("Invalid feature value 0.5 was not caught as an invalid feature!")
+    withClue("Invalid feature value 0.51 was not caught as an invalid feature!") {
+      intercept[SparkException] {
+        bucketizer.transform(badDF2).collect()
+      }
     }
   }
 
@@ -137,12 +139,11 @@ private object BucketizerSuite extends FunSuite {
     }
     var i = 0
     while (i < splits.length - 1) {
-      testFeature(splits(i), i) // Split i should fall in bucket i.
-      testFeature((splits(i) + splits(i + 1)) / 2, i) // Value between splits i,i+1 should be in i.
+      // Split i should fall in bucket i.
+      testFeature(splits(i), i)
+      // Value between splits i,i+1 should be in i, which is also true if the (i+1)-th split is inf.
+      testFeature((splits(i) + splits(i + 1)) / 2, i)
       i += 1
     }
-    if (splits.last === Double.PositiveInfinity) {
-      testFeature(Double.PositiveInfinity, splits.length - 2)
-    }
   }
 }
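
For reference, a self-contained sketch of the ScalaTest withClue/intercept pattern adopted above (the exception is thrown directly here just to keep the example runnable without a SparkContext):

    import org.apache.spark.SparkException
    import org.scalatest.FunSuite

    class WithClueExampleSuite extends FunSuite {
      test("withClue labels a missing exception") {
        // If the intercept block completes without throwing SparkException,
        // intercept fails and withClue attaches this message to the failure.
        // The old println only ran when the test was already broken and never
        // reached the test report.
        withClue("Invalid feature value was not caught as an invalid feature!") {
          intercept[SparkException] {
            throw new SparkException("feature out of Bucketizer bounds")
          }
        }
      }
    }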

