You are viewing a plain-text version of this content. (The canonical link for it was a hyperlink on the word "here" and is not preserved in this text.)
Posted to commits@spark.apache.org by me...@apache.org on 2016/02/16 05:15:30 UTC

spark git commit: [SPARK-13097][ML] Binarizer allowing Double AND Vector input types

Repository: spark
Updated Branches:
  refs/heads/master adb548365 -> cbeb006f2


[SPARK-13097][ML] Binarizer allowing Double AND Vector input types

This enhancement extends the existing Spark ML Binarizer [SPARK-5891] to accept Vector input columns in addition to the existing Double input column type.

A use case for this enhancement is when a user wants to binarize many similar feature columns at once using the same threshold value (for example, a binary threshold applied to many pixels in an image).

This contribution is my original work and I license the work to the project under the project's open source license.

viirya mengxr

Author: seddonm1 <se...@gmail.com>

Closes #10976 from seddonm1/master.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/cbeb006f
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/cbeb006f
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/cbeb006f

Branch: refs/heads/master
Commit: cbeb006f23838b2f19e700e20b25003aeb3dfb01
Parents: adb5483
Author: seddonm1 <se...@gmail.com>
Authored: Mon Feb 15 20:15:27 2016 -0800
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Mon Feb 15 20:15:27 2016 -0800

----------------------------------------------------------------------
 .../org/apache/spark/ml/feature/Binarizer.scala | 62 ++++++++++++++------
 .../spark/ml/feature/BinarizerSuite.scala       | 36 ++++++++++++
 2 files changed, 81 insertions(+), 17 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/cbeb006f/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
index 544cf05..2f8e3a0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
@@ -17,15 +17,18 @@
 
 package org.apache.spark.ml.feature
 
+import scala.collection.mutable.ArrayBuilder
+
 import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.ml.Transformer
 import org.apache.spark.ml.attribute.BinaryAttribute
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
 import org.apache.spark.ml.util._
+import org.apache.spark.mllib.linalg._
 import org.apache.spark.sql._
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.{DoubleType, StructType}
+import org.apache.spark.sql.types._
 
 /**
  * :: Experimental ::
@@ -62,28 +65,53 @@ final class Binarizer(override val uid: String)
   def setOutputCol(value: String): this.type = set(outputCol, value)
 
   override def transform(dataset: DataFrame): DataFrame = {
-    transformSchema(dataset.schema, logging = true)
+    val outputSchema = transformSchema(dataset.schema, logging = true)
+    val schema = dataset.schema
+    val inputType = schema($(inputCol)).dataType
     val td = $(threshold)
-    val binarizer = udf { in: Double => if (in > td) 1.0 else 0.0 }
-    val outputColName = $(outputCol)
-    val metadata = BinaryAttribute.defaultAttr.withName(outputColName).toMetadata()
-    dataset.select(col("*"),
-      binarizer(col($(inputCol))).as(outputColName, metadata))
+    // Double input yields a scalar 0.0/1.0; Vector input is binarized element-wise.
+    val binarizerDouble = udf { in: Double => if (in > td) 1.0 else 0.0 }
+    val binarizerVector = udf { (data: Vector) =>
+      val indices = ArrayBuilder.make[Int]
+      val values = ArrayBuilder.make[Double]
+      // td < 0: densify first, since foreachActive skips implicit zeros and 0.0 > td.
+      val source = if (td < 0.0) data.toDense else data
+      source.foreachActive { (index, value) =>
+        if (value > td) {
+          indices += index
+          values += 1.0
+        }
+      }
+      Vectors.sparse(data.size, indices.result(), values.result()).compressed
+    }
+
+    val metadata = outputSchema($(outputCol)).metadata
+
+    inputType match {
+      case DoubleType =>
+        dataset.select(col("*"), binarizerDouble(col($(inputCol))).as($(outputCol), metadata))
+      case _: VectorUDT =>
+        dataset.select(col("*"), binarizerVector(col($(inputCol))).as($(outputCol), metadata))
+    }
   }
 
   override def transformSchema(schema: StructType): StructType = {
-    validateParams()
-    SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType)
-
-    val inputFields = schema.fields
+    val inputType = schema($(inputCol)).dataType
     val outputColName = $(outputCol)
 
-    require(inputFields.forall(_.name != outputColName),
-      s"Output column $outputColName already exists.")
-
-    val attr = BinaryAttribute.defaultAttr.withName(outputColName)
-    val outputFields = inputFields :+ attr.toStructField()
-    StructType(outputFields)
+    // Double input gets a BinaryAttribute field; Vector input gets a plain VectorUDT field.
+    val outputField: StructField = inputType match {
+      case DoubleType => BinaryAttribute.defaultAttr.withName(outputColName).toStructField()
+      case _: VectorUDT => StructField(outputColName, new VectorUDT, nullable = true)
+      case other => throw new IllegalArgumentException(s"Data type $other is not supported.")
+    }
+
+    val outputExists = schema.fieldNames.contains(outputColName)
+    if (outputExists) {
+      throw new IllegalArgumentException(s"Output column $outputColName already exists.")
+    }
+    val appended = schema.fields :+ outputField
+    StructType(appended)
   }
 
   override def copy(extra: ParamMap): Binarizer = defaultCopy(extra)

http://git-wip-us.apache.org/repos/asf/spark/blob/cbeb006f/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala
index 6d2d8fe..714b9db 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.ml.feature
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.param.ParamsSuite
 import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.{DataFrame, Row}
 
@@ -68,6 +69,41 @@ class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
     }
   }
 
+  test("Binarize vector of continuous features with default parameter") {
+    val defaultBinarized: Array[Double] = data.map(x => if (x > 0.0) 1.0 else 0.0)
+    val dataFrame: DataFrame = sqlContext.createDataFrame(Seq(
+      (Vectors.dense(data), Vectors.dense(defaultBinarized))
+    )).toDF("feature", "expected")
+
+    val binarizer: Binarizer = new Binarizer()
+      .setInputCol("feature")
+      .setOutputCol("binarized_feature")
+
+    binarizer.transform(dataFrame).select("binarized_feature", "expected").collect().foreach {
+      case Row(x: Vector, y: Vector) =>
+        assert(x == y, "The feature value is not correct after binarization.")
+    }
+  }
+
+  test("Binarize vector of continuous features with setter") {
+    val threshold: Double = 0.2
+    val thresholdBinarized: Array[Double] = data.map(x => if (x > threshold) 1.0 else 0.0)
+    val dataFrame: DataFrame = sqlContext.createDataFrame(Seq(
+      (Vectors.dense(data), Vectors.dense(thresholdBinarized)),
+      (Vectors.dense(data).toSparse, Vectors.dense(thresholdBinarized))
+    )).toDF("feature", "expected")
+
+    val binarizer: Binarizer = new Binarizer()
+      .setInputCol("feature")
+      .setOutputCol("binarized_feature")
+      .setThreshold(threshold)
+
+    binarizer.transform(dataFrame).select("binarized_feature", "expected").collect().foreach {
+      case Row(x: Vector, y: Vector) =>
+        assert(x == y, "The feature value is not correct after binarization.")
+    }
+  }
+
   test("read/write") {
     val t = new Binarizer()
       .setInputCol("myInputCol")


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org