You are viewing a plain text version of this content. The canonical link for it is here.
Posted to by on 2015/05/29 09:51:20 UTC

spark git commit: [SPARK-7912] [SPARK-7921] [MLLIB] Update OneHotEncoder to handle ML attributes and change includeFirst to dropLast

Repository: spark
Updated Branches:
  refs/heads/master 97a60cf75 -> 23452be94

[SPARK-7912] [SPARK-7921] [MLLIB] Update OneHotEncoder to handle ML attributes and change includeFirst to dropLast

This PR contains two major changes to `OneHotEncoder`:

1. More robust handling of ML attributes. If the input attribute is unknown, we look at the values to get the max category index.
2. Change `includeFirst` to `dropLast` and keep the default as `true`. There are a couple of benefits:

    a. consistent with other tutorials on one-hot encoding (or dummy coding) (the example link was stripped from this plain-text archive).
    b. keep the indices unmodified in the output vector. If we drop the first, all indices will be shifted by 1.
    c. If users use `StringIndexer`, the last element is the least frequent one.

Sorry for including two changes in one PR! I'll update the user guide in another PR.

jkbradley sryza

Author: Xiangrui Meng <>

Closes #6466 from mengxr/SPARK-7912 and squashes the following commits:

a280dca [Xiangrui Meng] fix tests
d8f234d [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-7912
171b276 [Xiangrui Meng] mention the difference between our impl vs sklearn's
00dfd96 [Xiangrui Meng] update OneHotEncoder in Python
208ddad [Xiangrui Meng] update OneHotEncoder to handle ML attributes and change includeFirst to dropLast


Branch: refs/heads/master
Commit: 23452be944463dae72a35b58551040556dd3aeb5
Parents: 97a60cf
Author: Xiangrui Meng <>
Authored: Fri May 29 00:51:12 2015 -0700
Committer: Xiangrui Meng <>
Committed: Fri May 29 00:51:12 2015 -0700

 .../apache/spark/ml/feature/OneHotEncoder.scala | 160 +++++++++++++------
 .../spark/ml/feature/OneHotEncoderSuite.scala   |  42 ++++-
 python/pyspark/ml/                    |  58 ++++---
 3 files changed, 176 insertions(+), 84 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
index eb6ec49..8f34878 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
@@ -17,94 +17,152 @@
-import org.apache.spark.SparkException
 import org.apache.spark.annotation.Experimental
-import{Attribute, BinaryAttribute, NominalAttribute}
 import{HasInputCol, HasOutputCol}
 import{Identifiable, SchemaUtils}
-import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors}
-import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.functions.{col, udf}
+import org.apache.spark.sql.types.{DoubleType, StructType}
  * :: Experimental ::
- * A one-hot encoder that maps a column of label indices to a column of binary vectors, with
- * at most a single one-value. By default, the binary vector has an element for each category, so
- * with 5 categories, an input value of 2.0 would map to an output vector of
- * (0.0, 0.0, 1.0, 0.0, 0.0). If includeFirst is set to false, the first category is omitted, so the
- * output vector for the previous example would be (0.0, 1.0, 0.0, 0.0) and an input value
- * of 0.0 would map to a vector of all zeros. Including the first category makes the vector columns
- * linearly dependent because they sum up to one.
+ * A one-hot encoder that maps a column of category indices to a column of binary vectors, with
+ * at most a single one-value per row that indicates the input category index.
+ * For example with 5 categories, an input value of 2.0 would map to an output vector of
+ * `[0.0, 0.0, 1.0, 0.0]`.
+ * The last category is not included by default (configurable via [[OneHotEncoder!.dropLast]])
+ * because it makes the vector entries sum up to one, and hence linearly dependent.
+ * So an input value of 4.0 maps to `[0.0, 0.0, 0.0, 0.0]`.
+ * Note that this is different from scikit-learn's OneHotEncoder, which keeps all categories.
+ * The output vectors are sparse.
+ *
+ * @see [[StringIndexer]] for converting categorical values into category indices
-class OneHotEncoder(override val uid: String)
-  extends UnaryTransformer[Double, Vector, OneHotEncoder] with HasInputCol with HasOutputCol {
+class OneHotEncoder(override val uid: String) extends Transformer
+  with HasInputCol with HasOutputCol {
   def this() = this(Identifiable.randomUID("oneHot"))
-   * Whether to include a component in the encoded vectors for the first category, defaults to true.
+   * Whether to drop the last category in the encoded vector (default: true)
    * @group param
-  final val includeFirst: BooleanParam =
-    new BooleanParam(this, "includeFirst", "include first category")
-  setDefault(includeFirst -> true)
-  private var categories: Array[String] = _
+  final val dropLast: BooleanParam =
+    new BooleanParam(this, "dropLast", "whether to drop the last category")
+  setDefault(dropLast -> true)
   /** @group setParam */
-  def setIncludeFirst(value: Boolean): this.type = set(includeFirst, value)
+  def setDropLast(value: Boolean): this.type = set(dropLast, value)
   /** @group setParam */
-  override def setInputCol(value: String): this.type = set(inputCol, value)
+  def setInputCol(value: String): this.type = set(inputCol, value)
   /** @group setParam */
-  override def setOutputCol(value: String): this.type = set(outputCol, value)
+  def setOutputCol(value: String): this.type = set(outputCol, value)
   override def transformSchema(schema: StructType): StructType = {
-    SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType)
-    val inputFields = schema.fields
+    val is = "_is_"
+    val inputColName = $(inputCol)
     val outputColName = $(outputCol)
-    require(inputFields.forall( != $(outputCol)),
-      s"Output column ${$(outputCol)} already exists.")
-    val inputColAttr = Attribute.fromStructField(schema($(inputCol)))
-    categories = inputColAttr match {
+    SchemaUtils.checkColumnType(schema, inputColName, DoubleType)
+    val inputFields = schema.fields
+    require(!inputFields.exists( == outputColName),
+      s"Output column $outputColName already exists.")
+    val inputAttr = Attribute.fromStructField(schema(inputColName))
+    val outputAttrNames: Option[Array[String]] = inputAttr match {
       case nominal: NominalAttribute =>
-        nominal.values.getOrElse((0 until nominal.numValues.get).map(_.toString).toArray)
-      case binary: BinaryAttribute => binary.values.getOrElse(Array("0", "1"))
+        if (nominal.values.isDefined) {
+ => inputColName + is + v))
+        } else if (nominal.numValues.isDefined) {
+ => Array.tabulate(n)(i => inputColName + is + i))
+        } else {
+          None
+        }
+      case binary: BinaryAttribute =>
+        if (binary.values.isDefined) {
+ => inputColName + is + v))
+        } else {
+          Some(Array.tabulate(2)(i => inputColName + is + i))
+        }
+      case _: NumericAttribute =>
+        throw new RuntimeException(
+          s"The input column $inputColName cannot be numeric.")
       case _ =>
-        throw new SparkException(s"OneHotEncoder input column ${$(inputCol)} is not nominal")
+        None // optimistic about unknown attributes
-    val attrValues = (if ($(includeFirst)) categories else categories.drop(1)).toArray
-    val attr = NominalAttribute.defaultAttr.withName(outputColName).withValues(attrValues)
-    val outputFields = inputFields :+ attr.toStructField()
+    val filteredOutputAttrNames = { names =>
+      if ($(dropLast)) {
+        require(names.length > 1,
+          s"The input column $inputColName should have at least two distinct values.")
+        names.dropRight(1)
+      } else {
+        names
+      }
+    }
+    val outputAttrGroup = if (filteredOutputAttrNames.isDefined) {
+      val attrs: Array[Attribute] = { name =>
+        BinaryAttribute.defaultAttr.withName(name)
+      }
+      new AttributeGroup($(outputCol), attrs)
+    } else {
+      new AttributeGroup($(outputCol))
+    }
+    val outputFields = inputFields :+ outputAttrGroup.toStructField()
-  protected override def createTransformFunc(): (Double) => Vector = {
-    val first = $(includeFirst)
-    val vecLen = if (first) categories.length else categories.length - 1
+  override def transform(dataset: DataFrame): DataFrame = {
+    // schema transformation
+    val is = "_is_"
+    val inputColName = $(inputCol)
+    val outputColName = $(outputCol)
+    val shouldDropLast = $(dropLast)
+    var outputAttrGroup = AttributeGroup.fromStructField(
+      transformSchema(dataset.schema)(outputColName))
+    if (outputAttrGroup.size < 0) {
+      // If the number of attributes is unknown, we check the values from the input column.
+      val numAttrs =
+        .aggregate(0.0)(
+          (m, x) => {
+            assert(x >=0.0 && x == x.toInt,
+              s"Values from column $inputColName must be indices, but got $x.")
+            math.max(m, x)
+          },
+          (m0, m1) => {
+            math.max(m0, m1)
+          }
+        ).toInt + 1
+      val outputAttrNames = Array.tabulate(numAttrs)(i => inputColName + is + i)
+      val filtered = if (shouldDropLast) outputAttrNames.dropRight(1) else outputAttrNames
+      val outputAttrs: Array[Attribute] =
+ => BinaryAttribute.defaultAttr.withName(name))
+      outputAttrGroup = new AttributeGroup(outputColName, outputAttrs)
+    }
+    val metadata = outputAttrGroup.toMetadata()
+    // data transformation
+    val size = outputAttrGroup.size
     val oneValue = Array(1.0)
     val emptyValues = Array[Double]()
     val emptyIndices = Array[Int]()
-    label: Double => {
-      val values = if (first || label != 0.0) oneValue else emptyValues
-      val indices = if (first) {
-        Array(label.toInt)
-      } else if (label != 0.0) {
-        Array(label.toInt - 1)
+    val encode = udf { label: Double =>
+      if (label < size) {
+        Vectors.sparse(size, Array(label.toInt), oneValue)
       } else {
-        emptyIndices
+        Vectors.sparse(size, emptyIndices, emptyValues)
-      Vectors.sparse(vecLen, indices, values)
-  }
-  /**
-   * Returns the data type of the output column.
-   */
-  protected def outputDataType: DataType = new VectorUDT
+"*"), encode(col(inputColName).cast(DoubleType)).as(outputColName, metadata))
+  }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
index 056b9ed..9018d00 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
@@ -19,10 +19,11 @@ package
 import org.scalatest.FunSuite
+import{AttributeGroup, BinaryAttribute, NominalAttribute}
 import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.functions.col
 class OneHotEncoderSuite extends FunSuite with MLlibTestSparkContext {
@@ -36,15 +37,16 @@ class OneHotEncoderSuite extends FunSuite with MLlibTestSparkContext {
-  test("OneHotEncoder includeFirst = true") {
+  test("OneHotEncoder dropLast = false") {
     val transformed = stringIndexed()
     val encoder = new OneHotEncoder()
+      .setDropLast(false)
     val encoded = encoder.transform(transformed)
     val output ="id", "labelVec").map { r =>
-      val vec = r.get(1).asInstanceOf[Vector]
+      val vec = r.getAs[Vector](1)
       (r.getInt(0), vec(0), vec(1), vec(2))
     // a -> 0, b -> 2, c -> 1
@@ -53,22 +55,46 @@ class OneHotEncoderSuite extends FunSuite with MLlibTestSparkContext {
     assert(output === expected)
-  test("OneHotEncoder includeFirst = false") {
+  test("OneHotEncoder dropLast = true") {
     val transformed = stringIndexed()
     val encoder = new OneHotEncoder()
-      .setIncludeFirst(false)
     val encoded = encoder.transform(transformed)
     val output ="id", "labelVec").map { r =>
-      val vec = r.get(1).asInstanceOf[Vector]
+      val vec = r.getAs[Vector](1)
       (r.getInt(0), vec(0), vec(1))
     // a -> 0, b -> 2, c -> 1
-    val expected = Set((0, 0.0, 0.0), (1, 0.0, 1.0), (2, 1.0, 0.0),
-      (3, 0.0, 0.0), (4, 0.0, 0.0), (5, 1.0, 0.0))
+    val expected = Set((0, 1.0, 0.0), (1, 0.0, 0.0), (2, 0.0, 1.0),
+      (3, 1.0, 0.0), (4, 1.0, 0.0), (5, 0.0, 1.0))
     assert(output === expected)
+  test("input column with ML attribute") {
+    val attr = NominalAttribute.defaultAttr.withValues("small", "medium", "large")
+    val df = sqlContext.createDataFrame(Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply)).toDF("size")
+      .select(col("size").as("size", attr.toMetadata()))
+    val encoder = new OneHotEncoder()
+      .setInputCol("size")
+      .setOutputCol("encoded")
+    val output = encoder.transform(df)
+    val group = AttributeGroup.fromStructField(output.schema("encoded"))
+    assert(group.size === 2)
+    assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("size_is_small").withIndex(0))
+    assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("size_is_medium").withIndex(1))
+  }
+  test("input column without ML attribute") {
+    val df = sqlContext.createDataFrame(Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply)).toDF("index")
+    val encoder = new OneHotEncoder()
+      .setInputCol("index")
+      .setOutputCol("encoded")
+    val output = encoder.transform(df)
+    val group = AttributeGroup.fromStructField(output.schema("encoded"))
+    assert(group.size === 2)
+    assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("index_is_0").withIndex(0))
+    assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("index_is_1").withIndex(1))
+  }
diff --git a/python/pyspark/ml/ b/python/pyspark/ml/
index b0479d9..ddb33f4 100644
--- a/python/pyspark/ml/
+++ b/python/pyspark/ml/
@@ -324,65 +324,73 @@ class Normalizer(JavaTransformer, HasInputCol, HasOutputCol):
 class OneHotEncoder(JavaTransformer, HasInputCol, HasOutputCol):
-    A one-hot encoder that maps a column of label indices to a column of binary vectors, with
-    at most a single one-value. By default, the binary vector has an element for each category, so
-    with 5 categories, an input value of 2.0 would map to an output vector of
-    (0.0, 0.0, 1.0, 0.0, 0.0). If includeFirst is set to false, the first category is omitted, so
-    the output vector for the previous example would be (0.0, 1.0, 0.0, 0.0) and an input value
-    of 0.0 would map to a vector of all zeros. Including the first category makes the vector columns
-    linearly dependent because they sum up to one.
-    TODO: This method requires the use of StringIndexer first. Decouple them.
+    A one-hot encoder that maps a column of category indices to a
+    column of binary vectors, with at most a single one-value per row
+    that indicates the input category index.
+    For example with 5 categories, an input value of 2.0 would map to
+    an output vector of `[0.0, 0.0, 1.0, 0.0]`.
+    The last category is not included by default (configurable via
+    :py:attr:`dropLast`) because it makes the vector entries sum up to
+    one, and hence linearly dependent.
+    So an input value of 4.0 maps to `[0.0, 0.0, 0.0, 0.0]`.
+    Note that this is different from scikit-learn's OneHotEncoder,
+    which keeps all categories.
+    The output vectors are sparse.
+    .. seealso::
+       :py:class:`StringIndexer` for converting categorical values into
+       category indices
     >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed")
     >>> model =
     >>> td = model.transform(stringIndDf)
-    >>> encoder = OneHotEncoder(includeFirst=False, inputCol="indexed", outputCol="features")
+    >>> encoder = OneHotEncoder(inputCol="indexed", outputCol="features")
     >>> encoder.transform(td).head().features
-    SparseVector(2, {})
+    SparseVector(2, {0: 1.0})
     >>> encoder.setParams(outputCol="freqs").transform(td).head().freqs
-    SparseVector(2, {})
-    >>> params = {encoder.includeFirst: True, encoder.outputCol: "test"}
+    SparseVector(2, {0: 1.0})
+    >>> params = {encoder.dropLast: False, encoder.outputCol: "test"}
     >>> encoder.transform(td, params).head().test
     SparseVector(3, {0: 1.0})
     # a placeholder to make it appear in the generated doc
-    includeFirst = Param(Params._dummy(), "includeFirst", "include first category")
+    dropLast = Param(Params._dummy(), "dropLast", "whether to drop the last category")
-    def __init__(self, includeFirst=True, inputCol=None, outputCol=None):
+    def __init__(self, dropLast=True, inputCol=None, outputCol=None):
         __init__(self, dropLast=True, inputCol=None, outputCol=None)
         super(OneHotEncoder, self).__init__()
         self._java_obj = self._new_java_obj("", self.uid)
-        self.includeFirst = Param(self, "includeFirst", "include first category")
-        self._setDefault(includeFirst=True)
+        self.dropLast = Param(self, "dropLast", "whether to drop the last category")
+        self._setDefault(dropLast=True)
         kwargs = self.__init__._input_kwargs
-    def setParams(self, includeFirst=True, inputCol=None, outputCol=None):
+    def setParams(self, dropLast=True, inputCol=None, outputCol=None):
-        setParams(self, includeFirst=True, inputCol=None, outputCol=None)
+        setParams(self, dropLast=True, inputCol=None, outputCol=None)
         Sets params for this OneHotEncoder.
         kwargs = self.setParams._input_kwargs
         return self._set(**kwargs)
-    def setIncludeFirst(self, value):
+    def setDropLast(self, value):
-        Sets the value of :py:attr:`includeFirst`.
+        Sets the value of :py:attr:`dropLast`.
-        self._paramMap[self.includeFirst] = value
+        self._paramMap[self.dropLast] = value
         return self
-    def getIncludeFirst(self):
+    def getDropLast(self):
-        Gets the value of includeFirst or its default value.
+        Gets the value of dropLast or its default value.
-        return self.getOrDefault(self.includeFirst)
+        return self.getOrDefault(self.dropLast)

To unsubscribe, e-mail:
For additional commands, e-mail: