You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ml...@apache.org on 2016/03/10 12:21:31 UTC

spark git commit: [SPARK-11108][ML] OneHotEncoder should support other numeric types

Repository: spark
Updated Branches:
  refs/heads/master 9525c563d -> 9fe38aba1


[SPARK-11108][ML] OneHotEncoder should support other numeric types

Adds support for input columns of other numeric types (previously only Double was accepted):

* Integer
* Short
* Long
* Float
* Decimal

Author: sethah <se...@gmail.com>

Closes #9777 from sethah/SPARK-11108.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/9fe38aba
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/9fe38aba
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/9fe38aba

Branch: refs/heads/master
Commit: 9fe38aba1f70a4cb19ec1f9df4814fce0b267b54
Parents: 9525c56
Author: sethah <se...@gmail.com>
Authored: Thu Mar 10 13:17:41 2016 +0200
Committer: Nick Pentreath <ni...@gmail.com>
Committed: Thu Mar 10 13:17:41 2016 +0200

----------------------------------------------------------------------
 .../apache/spark/ml/feature/OneHotEncoder.scala |  9 ++++--
 .../spark/ml/feature/OneHotEncoderSuite.scala   | 29 ++++++++++++++++++++
 2 files changed, 35 insertions(+), 3 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/9fe38aba/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
index e9df161..fa5013d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
@@ -26,7 +26,7 @@ import org.apache.spark.ml.util._
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.functions.{col, udf}
-import org.apache.spark.sql.types.{DoubleType, StructType}
+import org.apache.spark.sql.types.{DoubleType, NumericType, StructType}
 
 /**
  * :: Experimental ::
@@ -70,7 +70,8 @@ class OneHotEncoder(override val uid: String) extends Transformer
     val inputColName = $(inputCol)
     val outputColName = $(outputCol)
 
-    SchemaUtils.checkColumnType(schema, inputColName, DoubleType)
+    require(schema(inputColName).dataType.isInstanceOf[NumericType],
+      s"Input column must be of type NumericType but got ${schema(inputColName).dataType}")
     val inputFields = schema.fields
     require(!inputFields.exists(_.name == outputColName),
       s"Output column $outputColName already exists.")
@@ -133,7 +134,9 @@ class OneHotEncoder(override val uid: String) extends Transformer
       val numAttrs = dataset.select(col(inputColName).cast(DoubleType)).rdd.map(_.getDouble(0))
         .aggregate(0.0)(
           (m, x) => {
-            assert(x >=0.0 && x == x.toInt,
+            assert(x <= Int.MaxValue,
+              s"OneHotEncoder only supports up to ${Int.MaxValue} indices, but got $x")
+            assert(x >= 0.0 && x == x.toInt,
               s"Values from column $inputColName must be indices, but got $x.")
             math.max(m, x)
           },

http://git-wip-us.apache.org/repos/asf/spark/blob/9fe38aba/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
index e238b33..49803ae 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
@@ -25,6 +25,7 @@ import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.functions.col
+import org.apache.spark.sql.types._
 
 class OneHotEncoderSuite
   extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
@@ -111,4 +112,32 @@ class OneHotEncoderSuite
       .setDropLast(false)
     testDefaultReadWrite(t)
   }
+
+  test("OneHotEncoder with varying types") {
+    val df = stringIndexed()
+    val dfWithTypes = df
+      .withColumn("shortLabel", df("labelIndex").cast(ShortType))
+      .withColumn("longLabel", df("labelIndex").cast(LongType))
+      .withColumn("intLabel", df("labelIndex").cast(IntegerType))
+      .withColumn("floatLabel", df("labelIndex").cast(FloatType))
+      .withColumn("decimalLabel", df("labelIndex").cast(DecimalType(10, 0)))
+    val cols = Array("labelIndex", "shortLabel", "longLabel", "intLabel",
+      "floatLabel", "decimalLabel")
+    for (col <- cols) {
+      val encoder = new OneHotEncoder()
+        .setInputCol(col)
+        .setOutputCol("labelVec")
+        .setDropLast(false)
+      val encoded = encoder.transform(dfWithTypes)
+
+      val output = encoded.select("id", "labelVec").rdd.map { r =>
+        val vec = r.getAs[Vector](1)
+        (r.getInt(0), vec(0), vec(1), vec(2))
+      }.collect().toSet
+      // a -> 0, b -> 2, c -> 1
+      val expected = Set((0, 1.0, 0.0, 0.0), (1, 0.0, 0.0, 1.0), (2, 0.0, 1.0, 0.0),
+        (3, 1.0, 0.0, 0.0), (4, 1.0, 0.0, 0.0), (5, 0.0, 1.0, 0.0))
+      assert(output === expected)
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org