You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by jk...@apache.org on 2018/01/05 19:51:29 UTC

spark git commit: [SPARK-13030][ML] Follow-up cleanups for OneHotEncoderEstimator

Repository: spark
Updated Branches:
  refs/heads/master c0b7424ec -> 930b90a84


[SPARK-13030][ML] Follow-up cleanups for OneHotEncoderEstimator

## What changes were proposed in this pull request?

Follow-up cleanups for the OneHotEncoderEstimator PR. See the discussion in the original PR (https://github.com/apache/spark/pull/19527), or read below for what this PR includes:
* configedCategorySize: I reverted this to return an Array. I realized the original setup (which I had recommended in the original PR) caused the whole model to be serialized in the UDF.
* encoder: I reorganized the logic to show what I meant in the comment in the previous PR.  I think it's simpler but am open to suggestions.

I also made some small style cleanups based on IntelliJ warnings.

## How was this patch tested?

Existing unit tests

Author: Joseph K. Bradley <jo...@databricks.com>

Closes #20132 from jkbradley/viirya-SPARK-13030.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/930b90a8
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/930b90a8
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/930b90a8

Branch: refs/heads/master
Commit: 930b90a84871e2504b57ed50efa7b8bb52d3ba44
Parents: c0b7424
Author: Joseph K. Bradley <jo...@databricks.com>
Authored: Fri Jan 5 11:51:25 2018 -0800
Committer: Joseph K. Bradley <jo...@databricks.com>
Committed: Fri Jan 5 11:51:25 2018 -0800

----------------------------------------------------------------------
 .../ml/feature/OneHotEncoderEstimator.scala     | 92 +++++++++++---------
 1 file changed, 49 insertions(+), 43 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/930b90a8/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala
index 074622d..bd1e342 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala
@@ -30,24 +30,27 @@ import org.apache.spark.ml.util._
 import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.expressions.UserDefinedFunction
 import org.apache.spark.sql.functions.{col, lit, udf}
-import org.apache.spark.sql.types.{DoubleType, NumericType, StructField, StructType}
+import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
 
 /** Private trait for params and common methods for OneHotEncoderEstimator and OneHotEncoderModel */
 private[ml] trait OneHotEncoderBase extends Params with HasHandleInvalid
     with HasInputCols with HasOutputCols {
 
   /**
-   * Param for how to handle invalid data.
+   * Param for how to handle invalid data during transform().
    * Options are 'keep' (invalid data presented as an extra categorical feature) or
    * 'error' (throw an error).
+   * Note that this Param is only used during transform; during fitting, invalid data
+   * will result in an error.
    * Default: "error"
    * @group param
    */
   @Since("2.3.0")
   override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid",
-    "How to handle invalid data " +
+    "How to handle invalid data during transform(). " +
     "Options are 'keep' (invalid data presented as an extra categorical feature) " +
-    "or error (throw an error).",
+    "or error (throw an error). Note that this Param is only used during transform; " +
+    "during fitting, invalid data will result in an error.",
     ParamValidators.inArray(OneHotEncoderEstimator.supportedHandleInvalids))
 
   setDefault(handleInvalid, OneHotEncoderEstimator.ERROR_INVALID)
@@ -66,10 +69,11 @@ private[ml] trait OneHotEncoderBase extends Params with HasHandleInvalid
   def getDropLast: Boolean = $(dropLast)
 
   protected def validateAndTransformSchema(
-      schema: StructType, dropLast: Boolean, keepInvalid: Boolean): StructType = {
+      schema: StructType,
+      dropLast: Boolean,
+      keepInvalid: Boolean): StructType = {
     val inputColNames = $(inputCols)
     val outputColNames = $(outputCols)
-    val existingFields = schema.fields
 
     require(inputColNames.length == outputColNames.length,
       s"The number of input columns ${inputColNames.length} must be the same as the number of " +
@@ -197,6 +201,10 @@ object OneHotEncoderEstimator extends DefaultParamsReadable[OneHotEncoderEstimat
   override def load(path: String): OneHotEncoderEstimator = super.load(path)
 }
 
+/**
+ * @param categorySizes  Original number of categories for each feature being encoded.
+ *                       The array contains one value for each input column, in order.
+ */
 @Since("2.3.0")
 class OneHotEncoderModel private[ml] (
     @Since("2.3.0") override val uid: String,
@@ -205,60 +213,58 @@ class OneHotEncoderModel private[ml] (
 
   import OneHotEncoderModel._
 
-  // Returns the category size for a given index with `dropLast` and `handleInvalid`
+  // Returns the category size for each index with `dropLast` and `handleInvalid`
   // taken into account.
-  private def configedCategorySize(orgCategorySize: Int, idx: Int): Int = {
+  private def getConfigedCategorySizes: Array[Int] = {
     val dropLast = getDropLast
     val keepInvalid = getHandleInvalid == OneHotEncoderEstimator.KEEP_INVALID
 
     if (!dropLast && keepInvalid) {
       // When `handleInvalid` is "keep", an extra category is added as last category
       // for invalid data.
-      orgCategorySize + 1
+      categorySizes.map(_ + 1)
     } else if (dropLast && !keepInvalid) {
       // When `dropLast` is true, the last category is removed.
-      orgCategorySize - 1
+      categorySizes.map(_ - 1)
     } else {
       // When `dropLast` is true and `handleInvalid` is "keep", the extra category for invalid
       // data is removed. Thus, it is the same as the plain number of categories.
-      orgCategorySize
+      categorySizes
     }
   }
 
   private def encoder: UserDefinedFunction = {
-    val oneValue = Array(1.0)
-    val emptyValues = Array.empty[Double]
-    val emptyIndices = Array.empty[Int]
-    val dropLast = getDropLast
-    val handleInvalid = getHandleInvalid
-    val keepInvalid = handleInvalid == OneHotEncoderEstimator.KEEP_INVALID
+    val keepInvalid = getHandleInvalid == OneHotEncoderEstimator.KEEP_INVALID
+    val configedSizes = getConfigedCategorySizes
+    val localCategorySizes = categorySizes
 
     // The udf performed on input data. The first parameter is the input value. The second
-    // parameter is the index of input.
-    udf { (label: Double, idx: Int) =>
-      val plainNumCategories = categorySizes(idx)
-      val size = configedCategorySize(plainNumCategories, idx)
-
-      if (label < 0) {
-        throw new SparkException(s"Negative value: $label. Input can't be negative.")
-      } else if (label == size && dropLast && !keepInvalid) {
-        // When `dropLast` is true and `handleInvalid` is not "keep",
-        // the last category is removed.
-        Vectors.sparse(size, emptyIndices, emptyValues)
-      } else if (label >= plainNumCategories && keepInvalid) {
-        // When `handleInvalid` is "keep", encodes invalid data to last category (and removed
-        // if `dropLast` is true)
-        if (dropLast) {
-          Vectors.sparse(size, emptyIndices, emptyValues)
+    // parameter is the index in inputCols of the column being encoded.
+    udf { (label: Double, colIdx: Int) =>
+      val origCategorySize = localCategorySizes(colIdx)
+      // idx: index in vector of the single 1-valued element
+      val idx = if (label >= 0 && label < origCategorySize) {
+        label
+      } else {
+        if (keepInvalid) {
+          origCategorySize
         } else {
-          Vectors.sparse(size, Array(size - 1), oneValue)
+          if (label < 0) {
+            throw new SparkException(s"Negative value: $label. Input can't be negative. " +
+              s"To handle invalid values, set Param handleInvalid to " +
+              s"${OneHotEncoderEstimator.KEEP_INVALID}")
+          } else {
+            throw new SparkException(s"Unseen value: $label. To handle unseen values, " +
+              s"set Param handleInvalid to ${OneHotEncoderEstimator.KEEP_INVALID}.")
+          }
         }
-      } else if (label < plainNumCategories) {
-        Vectors.sparse(size, Array(label.toInt), oneValue)
+      }
+
+      val size = configedSizes(colIdx)
+      if (idx < size) {
+        Vectors.sparse(size, Array(idx.toInt), Array(1.0))
       } else {
-        assert(handleInvalid == OneHotEncoderEstimator.ERROR_INVALID)
-        throw new SparkException(s"Unseen value: $label. To handle unseen values, " +
-          s"set Param handleInvalid to ${OneHotEncoderEstimator.KEEP_INVALID}.")
+        Vectors.sparse(size, Array.empty[Int], Array.empty[Double])
       }
     }
   }
@@ -282,7 +288,6 @@ class OneHotEncoderModel private[ml] (
   @Since("2.3.0")
   override def transformSchema(schema: StructType): StructType = {
     val inputColNames = $(inputCols)
-    val outputColNames = $(outputCols)
 
     require(inputColNames.length == categorySizes.length,
       s"The number of input columns ${inputColNames.length} must be the same as the number of " +
@@ -300,6 +305,7 @@ class OneHotEncoderModel private[ml] (
    * account. Mismatched numbers will cause exception.
    */
   private def verifyNumOfValues(schema: StructType): StructType = {
+    val configedSizes = getConfigedCategorySizes
     $(outputCols).zipWithIndex.foreach { case (outputColName, idx) =>
       val inputColName = $(inputCols)(idx)
       val attrGroup = AttributeGroup.fromStructField(schema(outputColName))
@@ -308,9 +314,9 @@ class OneHotEncoderModel private[ml] (
       // comparing with expected category number with `handleInvalid` and
       // `dropLast` taken into account.
       if (attrGroup.attributes.nonEmpty) {
-        val numCategories = configedCategorySize(categorySizes(idx), idx)
+        val numCategories = configedSizes(idx)
         require(attrGroup.size == numCategories, "OneHotEncoderModel expected " +
-          s"$numCategories categorical values for input column ${inputColName}, " +
+          s"$numCategories categorical values for input column $inputColName, " +
             s"but the input column had metadata specifying ${attrGroup.size} values.")
       }
     }
@@ -322,7 +328,7 @@ class OneHotEncoderModel private[ml] (
     val transformedSchema = transformSchema(dataset.schema, logging = true)
     val keepInvalid = $(handleInvalid) == OneHotEncoderEstimator.KEEP_INVALID
 
-    val encodedColumns = (0 until $(inputCols).length).map { idx =>
+    val encodedColumns = $(inputCols).indices.map { idx =>
       val inputColName = $(inputCols)(idx)
       val outputColName = $(outputCols)(idx)
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org