You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by me...@apache.org on 2015/05/01 17:31:02 UTC

spark git commit: [SPARK-5891] [ML] Add Binarizer ML Transformer

Repository: spark
Updated Branches:
  refs/heads/master 3b514af8a -> 7630213ca


[SPARK-5891] [ML] Add Binarizer ML Transformer

JIRA: https://issues.apache.org/jira/browse/SPARK-5891

Author: Liang-Chi Hsieh <vi...@gmail.com>

Closes #5699 from viirya/add_binarizer and squashes the following commits:

1a0b9a4 [Liang-Chi Hsieh] For comments.
bc397f2 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into add_binarizer
cc4f03c [Liang-Chi Hsieh] Implement threshold param and use merged params map.
7564c63 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into add_binarizer
1682f8c [Liang-Chi Hsieh] Add Binarizer ML Transformer.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/7630213c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/7630213c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/7630213c

Branch: refs/heads/master
Commit: 7630213cab1f653212828f045cf1d7d1870abea0
Parents: 3b514af
Author: Liang-Chi Hsieh <vi...@gmail.com>
Authored: Fri May 1 08:31:01 2015 -0700
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Fri May 1 08:31:01 2015 -0700

----------------------------------------------------------------------
 .../org/apache/spark/ml/feature/Binarizer.scala | 85 ++++++++++++++++++++
 .../spark/ml/feature/BinarizerSuite.scala       | 69 ++++++++++++++++
 2 files changed, 154 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/7630213c/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
new file mode 100644
index 0000000..f3ce6df
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.ml.Transformer
+import org.apache.spark.ml.attribute.BinaryAttribute
+import org.apache.spark.ml.param._
+import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
+import org.apache.spark.ml.util.SchemaUtils
+import org.apache.spark.sql._
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.{DoubleType, StructType}
+
+/**
+ * :: AlphaComponent ::
+ * Binarizes a continuous (Double) column: 1.0 where value > threshold, 0.0 otherwise.
+ */
+@AlphaComponent
+final class Binarizer extends Transformer with HasInputCol with HasOutputCol {
+
+  /**
+   * Param for the threshold used to binarize continuous features (default: 0.0).
+   * Features strictly greater than the threshold are binarized to 1.0.
+   * Features less than or equal to the threshold are binarized to 0.0.
+   * @group param
+   */
+  val threshold: DoubleParam =
+    new DoubleParam(this, "threshold", "threshold used to binarize continuous features")
+
+  /** @group getParam */
+  def getThreshold: Double = getOrDefault(threshold)
+
+  /** @group setParam */
+  def setThreshold(value: Double): this.type = set(threshold, value)
+
+  setDefault(threshold -> 0.0) // with this default: positive -> 1.0, zero/negative -> 0.0
+
+  /** @group setParam */
+  def setInputCol(value: String): this.type = set(inputCol, value)
+
+  /** @group setParam */
+  def setOutputCol(value: String): this.type = set(outputCol, value)
+
+  override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
+    transformSchema(dataset.schema, paramMap, logging = true) // schema check; result unused
+    val map = extractParamMap(paramMap) // single param map used for every lookup below
+    val td = map(threshold) // captured by the UDF closure below
+    val binarizer = udf { in: Double => if (in > td) 1.0 else 0.0 } // strict >: ties -> 0.0
+    val outputColName = map(outputCol)
+    val metadata = BinaryAttribute.defaultAttr.withName(outputColName).toMetadata()
+    dataset.select(col("*"), // keep every input column, then append the binarized one
+      binarizer(col(map(inputCol))).as(outputColName, metadata))
+  }
+
+  override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+    val map = extractParamMap(paramMap)
+    SchemaUtils.checkColumnType(schema, map(inputCol), DoubleType) // checks type vs DoubleType
+
+    val inputFields = schema.fields
+    val outputColName = map(outputCol)
+
+    require(inputFields.forall(_.name != outputColName), // refuse to shadow an existing column
+      s"Output column $outputColName already exists.")
+
+    val attr = BinaryAttribute.defaultAttr.withName(outputColName)
+    val outputFields = inputFields :+ attr.toStructField() // new binary field goes last
+    StructType(outputFields)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/7630213c/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala
new file mode 100644
index 0000000..caf1b75
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+
+
+class BinarizerSuite extends FunSuite with MLlibTestSparkContext { // unit tests for Binarizer
+
+  @transient var data: Array[Double] = _ // shared input features, assigned in beforeAll
+  @transient var sqlContext: SQLContext = _ // created from `sc` in beforeAll
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    sqlContext = new SQLContext(sc)
+    data = Array(0.1, -0.5, 0.2, -0.3, 0.8, 0.7, -0.1, -0.4) // 0.2 == second test's threshold
+  }
+
+  test("Binarize continuous features with default parameter") {
+    val defaultBinarized: Array[Double] = data.map(x => if (x > 0.0) 1.0 else 0.0) // expected
+    val dataFrame: DataFrame = sqlContext.createDataFrame(
+      data.zip(defaultBinarized)).toDF("feature", "expected") // (input, expected) rows
+
+    val binarizer: Binarizer = new Binarizer()
+      .setInputCol("feature")
+      .setOutputCol("binarized_feature") // no setThreshold call: default threshold applies
+
+    binarizer.transform(dataFrame).select("binarized_feature", "expected").collect().foreach {
+      case Row(x: Double, y: Double) => // x = actual output, y = expected
+        assert(x === y, "The feature value is not correct after binarization.")
+    }
+  }
+
+  test("Binarize continuous features with setter") {
+    val threshold: Double = 0.2 // equals one of the inputs, so the tie case is exercised
+    val thresholdBinarized: Array[Double] = data.map(x => if (x > threshold) 1.0 else 0.0) 
+    val dataFrame: DataFrame = sqlContext.createDataFrame(
+        data.zip(thresholdBinarized)).toDF("feature", "expected")
+
+    val binarizer: Binarizer = new Binarizer()
+      .setInputCol("feature")
+      .setOutputCol("binarized_feature")
+      .setThreshold(threshold)
+
+    binarizer.transform(dataFrame).select("binarized_feature", "expected").collect().foreach {
+      case Row(x: Double, y: Double) => // rows of any other shape raise a MatchError
+        assert(x === y, "The feature value is not correct after binarization.")
+    }
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org