You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by jk...@apache.org on 2015/05/12 03:41:26 UTC

spark git commit: [SPARK-5893] [ML] Add bucketizer

Repository: spark
Updated Branches:
  refs/heads/master 87229c95c -> 35fb42a0b


[SPARK-5893] [ML] Add bucketizer

JIRA issue [here](https://issues.apache.org/jira/browse/SPARK-5893).

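One thing to make clear: the `splits` parameter (named `buckets` in earlier revisions of this PR), which is an array of `Double`, holds the split points. For example,

```scala
splits = Array(-0.5, 0.0, 0.5)
```

defines 2 buckets, [-0.5, 0.0) and [0.0, 0.5), which are encoded as 0 and 1; values outside [-0.5, 0.5) are treated as errors. To cover all real numbers, -inf and +inf must be given explicitly: `Array(-inf, -0.5, 0.0, 0.5, +inf)` defines 4 buckets, [-inf, -0.5), [-0.5, 0.0), [0.0, 0.5), [0.5, +inf], which are encoded as 0, 1, 2, 3.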

Author: Xusen Yin <yi...@gmail.com>
Author: Joseph K. Bradley <jo...@databricks.com>

Closes #5980 from yinxusen/SPARK-5893 and squashes the following commits:

dc8c843 [Xusen Yin] Merge pull request #4 from jkbradley/yinxusen-SPARK-5893
1ca973a [Joseph K. Bradley] one more bucketizer test
34f124a [Joseph K. Bradley] Removed lowerInclusive, upperInclusive params from Bucketizer, and used splits instead.
eacfcfa [Xusen Yin] change ML attribute from splits into buckets
c3cc770 [Xusen Yin] add more unit test for binary search
3a16cc2 [Xusen Yin] refine comments and names
ac77859 [Xusen Yin] fix style error
fb30d79 [Xusen Yin] fix and test binary search
2466322 [Xusen Yin] refactor Bucketizer
11fb00a [Xusen Yin] change it into an Estimator
998bc87 [Xusen Yin] check buckets
4024cf1 [Xusen Yin] add test suite
5fe190e [Xusen Yin] add bucketizer


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/35fb42a0
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/35fb42a0
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/35fb42a0

Branch: refs/heads/master
Commit: 35fb42a0b01d3043b7d5e27256d1b45a08583aab
Parents: 87229c9
Author: Xusen Yin <yi...@gmail.com>
Authored: Mon May 11 18:41:22 2015 -0700
Committer: Joseph K. Bradley <jo...@databricks.com>
Committed: Mon May 11 18:41:22 2015 -0700

----------------------------------------------------------------------
 .../apache/spark/ml/feature/Bucketizer.scala    | 131 ++++++++++++++++
 .../org/apache/spark/ml/util/SchemaUtils.scala  |  11 ++
 .../spark/ml/feature/BucketizerSuite.scala      | 148 +++++++++++++++++++
 3 files changed, 290 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/35fb42a0/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
new file mode 100644
index 0000000..7dba64b
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.ml.attribute.NominalAttribute
+import org.apache.spark.ml.param._
+import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
+import org.apache.spark.ml.util.SchemaUtils
+import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.sql._
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
+
+/**
+ * :: AlphaComponent ::
+ * `Bucketizer` maps a column of continuous features to a column of feature buckets.
+ */
+@AlphaComponent
+final class Bucketizer private[ml] (override val parent: Estimator[Bucketizer])
+  extends Model[Bucketizer] with HasInputCol with HasOutputCol {
+
+  def this() = this(null)
+
+  /**
+   * Parameter for mapping continuous features into buckets. With n+1 splits, there are n buckets.
+   * A bucket defined by splits x,y holds values in the range [x,y). Splits should be strictly
+   * increasing. Values at -inf, inf must be explicitly provided to cover all Double values;
+   * otherwise, values outside the splits specified will be treated as errors.
+   * @group param
+   */
+  val splits: Param[Array[Double]] = new Param[Array[Double]](this, "splits",
+    "Split points for mapping continuous features into buckets. With n splits, there are n+1 " +
+      "buckets. A bucket defined by splits x,y holds values in the range [x,y). The splits " +
+      "should be strictly increasing. Values at -inf, inf must be explicitly provided to cover" +
+      " all Double values; otherwise, values outside the splits specified will be treated as" +
+      " errors.",
+    Bucketizer.checkSplits)
+
+  /** @group getParam */
+  def getSplits: Array[Double] = $(splits)
+
+  /** @group setParam */
+  def setSplits(value: Array[Double]): this.type = set(splits, value)
+
+  /** @group setParam */
+  def setInputCol(value: String): this.type = set(inputCol, value)
+
+  /** @group setParam */
+  def setOutputCol(value: String): this.type = set(outputCol, value)
+
+  override def transform(dataset: DataFrame): DataFrame = {
+    transformSchema(dataset.schema)
+    val bucketizer = udf { feature: Double =>
+      Bucketizer.binarySearchForBuckets($(splits), feature)
+    }
+    val newCol = bucketizer(dataset($(inputCol)))
+    val newField = prepOutputField(dataset.schema)
+    dataset.withColumn($(outputCol), newCol.as($(outputCol), newField.metadata))
+  }
+
+  private def prepOutputField(schema: StructType): StructField = {
+    val buckets = $(splits).sliding(2).map(bucket => bucket.mkString(", ")).toArray
+    val attr = new NominalAttribute(name = Some($(outputCol)), isOrdinal = Some(true),
+      values = Some(buckets))
+    attr.toStructField()
+  }
+
+  override def transformSchema(schema: StructType): StructType = {
+    SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType)
+    SchemaUtils.appendColumn(schema, prepOutputField(schema))
+  }
+}
+
+private[feature] object Bucketizer {
+  /** We require splits to be of length >= 3 and to be in strictly increasing order. */
+  def checkSplits(splits: Array[Double]): Boolean = {
+    if (splits.length < 3) {
+      false
+    } else {
+      var i = 0
+      while (i < splits.length - 1) {
+        if (splits(i) >= splits(i + 1)) return false
+        i += 1
+      }
+      true
+    }
+  }
+
+  /**
+   * Binary searching in several buckets to place each data point.
+   * @throws RuntimeException if a feature is < splits.head or >= splits.last (+inf excepted)
+   */
+  def binarySearchForBuckets(
+      splits: Array[Double],
+      feature: Double): Double = {
+    // Check bounds.  We make an exception for +inf so that it can exist in some bin.
+    if ((feature < splits.head) || (feature >= splits.last && feature != Double.PositiveInfinity)) {
+      throw new RuntimeException(s"Feature value $feature out of Bucketizer bounds" +
+        s" [${splits.head}, ${splits.last}).  Check your features, or loosen " +
+        s"the lower/upper bound constraints.")
+    }
+    var left = 0
+    var right = splits.length - 2
+    while (left < right) {
+      val mid = (left + right) / 2
+      val split = splits(mid + 1)
+      if (feature < split) {
+        right = mid
+      } else {
+        left = mid + 1
+      }
+    }
+    left
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/35fb42a0/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
index 0383bf0..11592b7 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
@@ -58,4 +58,15 @@ object SchemaUtils {
     val outputFields = schema.fields :+ StructField(colName, dataType, nullable = false)
     StructType(outputFields)
   }
+
+  /**
+   * Appends a new column to the input schema. This fails if the given output column already exists.
+   * @param schema input schema
+   * @param col New column schema
+   * @return new schema with the input column appended
+   */
+  def appendColumn(schema: StructType, col: StructField): StructType = {
+    require(!schema.fieldNames.contains(col.name), s"Column ${col.name} already exists.")
+    StructType(schema.fields :+ col)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/35fb42a0/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
new file mode 100644
index 0000000..acb46c0
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
@@ -0,0 +1,148 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import scala.util.Random
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.SparkException
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+
+class BucketizerSuite extends FunSuite with MLlibTestSparkContext {
+
+  @transient private var sqlContext: SQLContext = _
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    sqlContext = new SQLContext(sc)
+  }
+
+  test("Bucket continuous features, without -inf,inf") {
+    // Check a set of valid feature values.
+    val splits = Array(-0.5, 0.0, 0.5)
+    val validData = Array(-0.5, -0.3, 0.0, 0.2)
+    val expectedBuckets = Array(0.0, 0.0, 1.0, 1.0)
+    val dataFrame: DataFrame =
+      sqlContext.createDataFrame(validData.zip(expectedBuckets)).toDF("feature", "expected")
+
+    val bucketizer: Bucketizer = new Bucketizer()
+      .setInputCol("feature")
+      .setOutputCol("result")
+      .setSplits(splits)
+
+    bucketizer.transform(dataFrame).select("result", "expected").collect().foreach {
+      case Row(x: Double, y: Double) =>
+        assert(x === y,
+          s"The feature value is not correct after bucketing.  Expected $y but found $x")
+    }
+
+    // Check for exceptions when using a set of invalid feature values.
+    val invalidData1: Array[Double] = Array(-0.9) ++ validData
+    val invalidData2 = Array(0.5) ++ validData
+    val badDF1 = sqlContext.createDataFrame(invalidData1.zipWithIndex).toDF("feature", "idx")
+    intercept[RuntimeException]{
+      bucketizer.transform(badDF1).collect()
+      println("Invalid feature value -0.9 was not caught as an invalid feature!")
+    }
+    val badDF2 = sqlContext.createDataFrame(invalidData2.zipWithIndex).toDF("feature", "idx")
+    intercept[RuntimeException]{
+      bucketizer.transform(badDF2).collect()
+      println("Invalid feature value 0.5 was not caught as an invalid feature!")
+    }
+  }
+
+  test("Bucket continuous features, with -inf,inf") {
+    val splits = Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity)
+    val validData = Array(-0.9, -0.5, -0.3, 0.0, 0.2, 0.5, 0.9)
+    val expectedBuckets = Array(0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0)
+    val dataFrame: DataFrame =
+      sqlContext.createDataFrame(validData.zip(expectedBuckets)).toDF("feature", "expected")
+
+    val bucketizer: Bucketizer = new Bucketizer()
+      .setInputCol("feature")
+      .setOutputCol("result")
+      .setSplits(splits)
+
+    bucketizer.transform(dataFrame).select("result", "expected").collect().foreach {
+      case Row(x: Double, y: Double) =>
+        assert(x === y,
+          s"The feature value is not correct after bucketing.  Expected $y but found $x")
+    }
+  }
+
+  test("Binary search correctness on hand-picked examples") {
+    import BucketizerSuite.checkBinarySearch
+    // length 3, with -inf
+    checkBinarySearch(Array(Double.NegativeInfinity, 0.0, 1.0))
+    // length 4
+    checkBinarySearch(Array(-1.0, -0.5, 0.0, 1.0))
+    // length 5
+    checkBinarySearch(Array(-1.0, -0.5, 0.0, 1.0, 1.5))
+    // length 3, with inf
+    checkBinarySearch(Array(0.0, 1.0, Double.PositiveInfinity))
+    // length 3, with -inf and inf
+    checkBinarySearch(Array(Double.NegativeInfinity, 1.0, Double.PositiveInfinity))
+    // length 4, with -inf and inf
+    checkBinarySearch(Array(Double.NegativeInfinity, 0.0, 1.0, Double.PositiveInfinity))
+  }
+
+  test("Binary search correctness in contrast with linear search, on random data") {
+    val data = Array.fill(100)(Random.nextDouble())
+    val splits: Array[Double] = Double.NegativeInfinity +:
+      Array.fill(10)(Random.nextDouble()).sorted :+ Double.PositiveInfinity
+    val bsResult = Vectors.dense(data.map(x => Bucketizer.binarySearchForBuckets(splits, x)))
+    val lsResult = Vectors.dense(data.map(x => BucketizerSuite.linearSearchForBuckets(splits, x)))
+    assert(bsResult ~== lsResult absTol 1e-5)
+  }
+}
+
+private object BucketizerSuite extends FunSuite {
+  /** Brute force search for buckets.  Bucket i is defined by the range [split(i), split(i+1)). */
+  def linearSearchForBuckets(splits: Array[Double], feature: Double): Double = {
+    require(feature >= splits.head)
+    var i = 0
+    while (i < splits.length - 1) {
+      if (feature < splits(i + 1)) return i
+      i += 1
+    }
+    throw new RuntimeException(
+      s"linearSearchForBuckets failed to find bucket for feature value $feature")
+  }
+
+  /** Check all values in splits, plus values between all splits. */
+  def checkBinarySearch(splits: Array[Double]): Unit = {
+    def testFeature(feature: Double, expectedBucket: Double): Unit = {
+      assert(Bucketizer.binarySearchForBuckets(splits, feature) === expectedBucket,
+        s"Expected feature value $feature to be in bucket $expectedBucket with splits:" +
+          s" ${splits.mkString(", ")}")
+    }
+    var i = 0
+    while (i < splits.length - 1) {
+      testFeature(splits(i), i) // Split i should fall in bucket i.
+      testFeature((splits(i) + splits(i + 1)) / 2, i) // Value between splits i,i+1 should be in i.
+      i += 1
+    }
+    if (splits.last === Double.PositiveInfinity) {
+      testFeature(Double.PositiveInfinity, splits.length - 2)
+    }
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org