You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by me...@apache.org on 2015/05/01 17:31:02 UTC
spark git commit: [SPARK-5891] [ML] Add Binarizer ML Transformer
Repository: spark
Updated Branches:
refs/heads/master 3b514af8a -> 7630213ca
[SPARK-5891] [ML] Add Binarizer ML Transformer
JIRA: https://issues.apache.org/jira/browse/SPARK-5891
Author: Liang-Chi Hsieh <vi...@gmail.com>
Closes #5699 from viirya/add_binarizer and squashes the following commits:
1a0b9a4 [Liang-Chi Hsieh] For comments.
bc397f2 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into add_binarizer
cc4f03c [Liang-Chi Hsieh] Implement threshold param and use merged params map.
7564c63 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into add_binarizer
1682f8c [Liang-Chi Hsieh] Add Binarizer ML Transformer.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/7630213c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/7630213c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/7630213c
Branch: refs/heads/master
Commit: 7630213cab1f653212828f045cf1d7d1870abea0
Parents: 3b514af
Author: Liang-Chi Hsieh <vi...@gmail.com>
Authored: Fri May 1 08:31:01 2015 -0700
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Fri May 1 08:31:01 2015 -0700
----------------------------------------------------------------------
.../org/apache/spark/ml/feature/Binarizer.scala | 85 ++++++++++++++++++++
.../spark/ml/feature/BinarizerSuite.scala | 69 ++++++++++++++++
2 files changed, 154 insertions(+)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/7630213c/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
new file mode 100644
index 0000000..f3ce6df
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.ml.Transformer
+import org.apache.spark.ml.attribute.BinaryAttribute
+import org.apache.spark.ml.param._
+import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
+import org.apache.spark.ml.util.SchemaUtils
+import org.apache.spark.sql._
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.{DoubleType, StructType}
+
+/**
+ * :: AlphaComponent ::
+ * Binarize a column of continuous features given a threshold.
+ */
+@AlphaComponent
+final class Binarizer extends Transformer with HasInputCol with HasOutputCol {
+
+ /**
+ * Param for the threshold used to binarize continuous features.
+ * Features greater than the threshold are binarized to 1.0;
+ * features equal to or less than the threshold are binarized to 0.0.
+ * @group param
+ */
+ val threshold: DoubleParam =
+ new DoubleParam(this, "threshold", "threshold used to binarize continuous features")
+
+ /** @group getParam */
+ def getThreshold: Double = getOrDefault(threshold)
+
+ /** @group setParam */
+ def setThreshold(value: Double): this.type = set(threshold, value)
+
+ // Default threshold of 0.0: strictly positive values map to 1.0, all others to 0.0.
+ setDefault(threshold -> 0.0)
+
+ /** @group setParam */
+ def setInputCol(value: String): this.type = set(inputCol, value)
+
+ /** @group setParam */
+ def setOutputCol(value: String): this.type = set(outputCol, value)
+
+ /**
+ * Binarizes the input column and appends the result as a new output column.
+ * Values strictly greater than the effective threshold become 1.0, all others 0.0.
+ * The output column carries [[BinaryAttribute]] metadata.
+ */
+ override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
+ transformSchema(dataset.schema, paramMap, logging = true)
+ // Merge embedded params with the caller-supplied overrides so paramMap wins.
+ val map = extractParamMap(paramMap)
+ val td = map(threshold)
+ val binarizer = udf { in: Double => if (in > td) 1.0 else 0.0 }
+ val outputColName = map(outputCol)
+ val metadata = BinaryAttribute.defaultAttr.withName(outputColName).toMetadata()
+ dataset.select(col("*"),
+ binarizer(col(map(inputCol))).as(outputColName, metadata))
+ }
+
+ /**
+ * Checks that the input column is of [[DoubleType]] and that the output column
+ * does not already exist, then returns the schema with a binary output column appended.
+ */
+ override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+ val map = extractParamMap(paramMap)
+ SchemaUtils.checkColumnType(schema, map(inputCol), DoubleType)
+
+ val inputFields = schema.fields
+ val outputColName = map(outputCol)
+
+ // Fail fast rather than silently producing a duplicate column name.
+ require(inputFields.forall(_.name != outputColName),
+ s"Output column $outputColName already exists.")
+
+ val attr = BinaryAttribute.defaultAttr.withName(outputColName)
+ val outputFields = inputFields :+ attr.toStructField()
+ StructType(outputFields)
+ }
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/7630213c/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala
new file mode 100644
index 0000000..caf1b75
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+
+
+class BinarizerSuite extends FunSuite with MLlibTestSparkContext {
+
+ // Input feature values shared by both tests; initialized once in beforeAll.
+ @transient var data: Array[Double] = _
+ @transient var sqlContext: SQLContext = _
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+ sqlContext = new SQLContext(sc)
+ // Mix of positive and negative values around the thresholds under test (0.0 and 0.2).
+ data = Array(0.1, -0.5, 0.2, -0.3, 0.8, 0.7, -0.1, -0.4)
+ }
+
+ test("Binarize continuous features with default parameter") {
+ // Default threshold is 0.0: strictly positive values map to 1.0, all others to 0.0.
+ val defaultBinarized: Array[Double] = data.map(x => if (x > 0.0) 1.0 else 0.0)
+ val dataFrame: DataFrame = sqlContext.createDataFrame(
+ data.zip(defaultBinarized)).toDF("feature", "expected")
+
+ val binarizer: Binarizer = new Binarizer()
+ .setInputCol("feature")
+ .setOutputCol("binarized_feature")
+
+ // Compare each binarized value against the expected column computed above.
+ binarizer.transform(dataFrame).select("binarized_feature", "expected").collect().foreach {
+ case Row(x: Double, y: Double) =>
+ assert(x === y, "The feature value is not correct after binarization.")
+ }
+ }
+
+ test("Binarize continuous features with setter") {
+ // Custom threshold: note 0.2 itself must binarize to 0.0 (comparison is strict).
+ val threshold: Double = 0.2
+ val thresholdBinarized: Array[Double] = data.map(x => if (x > threshold) 1.0 else 0.0)
+ val dataFrame: DataFrame = sqlContext.createDataFrame(
+ data.zip(thresholdBinarized)).toDF("feature", "expected")
+
+ val binarizer: Binarizer = new Binarizer()
+ .setInputCol("feature")
+ .setOutputCol("binarized_feature")
+ .setThreshold(threshold)
+
+ binarizer.transform(dataFrame).select("binarized_feature", "expected").collect().foreach {
+ case Row(x: Double, y: Double) =>
+ assert(x === y, "The feature value is not correct after binarization.")
+ }
+ }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org