You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by jk...@apache.org on 2018/01/05 19:51:29 UTC
spark git commit: [SPARK-13030][ML] Follow-up cleanups for OneHotEncoderEstimator
Repository: spark
Updated Branches:
refs/heads/master c0b7424ec -> 930b90a84
[SPARK-13030][ML] Follow-up cleanups for OneHotEncoderEstimator
## What changes were proposed in this pull request?
Follow-up cleanups for the OneHotEncoderEstimator PR. See some discussion in the original PR: https://github.com/apache/spark/pull/19527 or read below for what this PR includes:
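* configedCategorySize (now getConfigedCategorySizes): I reverted this to return an Array. I realized the original setup (which I had recommended in the original PR) caused the whole model to be serialized in the UDF, because the UDF body called this instance method and read categorySizes directly from the model.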
* encoder: I reorganized the logic to show what I meant in the comment in the previous PR. I think it's simpler but am open to suggestions.
I also made some small style cleanups based on IntelliJ warnings.
## How was this patch tested?
Existing unit tests
Author: Joseph K. Bradley <jo...@databricks.com>
Closes #20132 from jkbradley/viirya-SPARK-13030.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/930b90a8
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/930b90a8
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/930b90a8
Branch: refs/heads/master
Commit: 930b90a84871e2504b57ed50efa7b8bb52d3ba44
Parents: c0b7424
Author: Joseph K. Bradley <jo...@databricks.com>
Authored: Fri Jan 5 11:51:25 2018 -0800
Committer: Joseph K. Bradley <jo...@databricks.com>
Committed: Fri Jan 5 11:51:25 2018 -0800
----------------------------------------------------------------------
.../ml/feature/OneHotEncoderEstimator.scala | 92 +++++++++++---------
1 file changed, 49 insertions(+), 43 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/930b90a8/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala
index 074622d..bd1e342 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala
@@ -30,24 +30,27 @@ import org.apache.spark.ml.util._
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions.{col, lit, udf}
-import org.apache.spark.sql.types.{DoubleType, NumericType, StructField, StructType}
+import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
/** Private trait for params and common methods for OneHotEncoderEstimator and OneHotEncoderModel */
private[ml] trait OneHotEncoderBase extends Params with HasHandleInvalid
with HasInputCols with HasOutputCols {
/**
- * Param for how to handle invalid data.
+ * Param for how to handle invalid data during transform().
* Options are 'keep' (invalid data presented as an extra categorical feature) or
* 'error' (throw an error).
+ * Note that this Param is only used during transform; during fitting, invalid data
+ * will result in an error.
* Default: "error"
* @group param
*/
@Since("2.3.0")
override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid",
- "How to handle invalid data " +
+ "How to handle invalid data during transform(). " +
"Options are 'keep' (invalid data presented as an extra categorical feature) " +
- "or error (throw an error).",
+ "or error (throw an error). Note that this Param is only used during transform; " +
+ "during fitting, invalid data will result in an error.",
ParamValidators.inArray(OneHotEncoderEstimator.supportedHandleInvalids))
setDefault(handleInvalid, OneHotEncoderEstimator.ERROR_INVALID)
@@ -66,10 +69,11 @@ private[ml] trait OneHotEncoderBase extends Params with HasHandleInvalid
def getDropLast: Boolean = $(dropLast)
protected def validateAndTransformSchema(
- schema: StructType, dropLast: Boolean, keepInvalid: Boolean): StructType = {
+ schema: StructType,
+ dropLast: Boolean,
+ keepInvalid: Boolean): StructType = {
val inputColNames = $(inputCols)
val outputColNames = $(outputCols)
- val existingFields = schema.fields
require(inputColNames.length == outputColNames.length,
s"The number of input columns ${inputColNames.length} must be the same as the number of " +
@@ -197,6 +201,10 @@ object OneHotEncoderEstimator extends DefaultParamsReadable[OneHotEncoderEstimat
override def load(path: String): OneHotEncoderEstimator = super.load(path)
}
+/**
+ * @param categorySizes Original number of categories for each feature being encoded.
+ * The array contains one value for each input column, in order.
+ */
@Since("2.3.0")
class OneHotEncoderModel private[ml] (
@Since("2.3.0") override val uid: String,
@@ -205,60 +213,58 @@ class OneHotEncoderModel private[ml] (
import OneHotEncoderModel._
- // Returns the category size for a given index with `dropLast` and `handleInvalid`
+ // Returns the category size for each index with `dropLast` and `handleInvalid`
// taken into account.
- private def configedCategorySize(orgCategorySize: Int, idx: Int): Int = {
+ private def getConfigedCategorySizes: Array[Int] = {
val dropLast = getDropLast
val keepInvalid = getHandleInvalid == OneHotEncoderEstimator.KEEP_INVALID
if (!dropLast && keepInvalid) {
// When `handleInvalid` is "keep", an extra category is added as last category
// for invalid data.
- orgCategorySize + 1
+ categorySizes.map(_ + 1)
} else if (dropLast && !keepInvalid) {
// When `dropLast` is true, the last category is removed.
- orgCategorySize - 1
+ categorySizes.map(_ - 1)
} else {
// When `dropLast` is true and `handleInvalid` is "keep", the extra category for invalid
// data is removed. Thus, it is the same as the plain number of categories.
- orgCategorySize
+ categorySizes
}
}
private def encoder: UserDefinedFunction = {
- val oneValue = Array(1.0)
- val emptyValues = Array.empty[Double]
- val emptyIndices = Array.empty[Int]
- val dropLast = getDropLast
- val handleInvalid = getHandleInvalid
- val keepInvalid = handleInvalid == OneHotEncoderEstimator.KEEP_INVALID
+ val keepInvalid = getHandleInvalid == OneHotEncoderEstimator.KEEP_INVALID
+ val configedSizes = getConfigedCategorySizes
+ val localCategorySizes = categorySizes
// The udf performed on input data. The first parameter is the input value. The second
- // parameter is the index of input.
- udf { (label: Double, idx: Int) =>
- val plainNumCategories = categorySizes(idx)
- val size = configedCategorySize(plainNumCategories, idx)
-
- if (label < 0) {
- throw new SparkException(s"Negative value: $label. Input can't be negative.")
- } else if (label == size && dropLast && !keepInvalid) {
- // When `dropLast` is true and `handleInvalid` is not "keep",
- // the last category is removed.
- Vectors.sparse(size, emptyIndices, emptyValues)
- } else if (label >= plainNumCategories && keepInvalid) {
- // When `handleInvalid` is "keep", encodes invalid data to last category (and removed
- // if `dropLast` is true)
- if (dropLast) {
- Vectors.sparse(size, emptyIndices, emptyValues)
+ // parameter is the index in inputCols of the column being encoded.
+ udf { (label: Double, colIdx: Int) =>
+ val origCategorySize = localCategorySizes(colIdx)
+ // idx: index in vector of the single 1-valued element
+ val idx = if (label >= 0 && label < origCategorySize) {
+ label
+ } else {
+ if (keepInvalid) {
+ origCategorySize
} else {
- Vectors.sparse(size, Array(size - 1), oneValue)
+ if (label < 0) {
+ throw new SparkException(s"Negative value: $label. Input can't be negative. " +
+ s"To handle invalid values, set Param handleInvalid to " +
+ s"${OneHotEncoderEstimator.KEEP_INVALID}")
+ } else {
+ throw new SparkException(s"Unseen value: $label. To handle unseen values, " +
+ s"set Param handleInvalid to ${OneHotEncoderEstimator.KEEP_INVALID}.")
+ }
}
- } else if (label < plainNumCategories) {
- Vectors.sparse(size, Array(label.toInt), oneValue)
+ }
+
+ val size = configedSizes(colIdx)
+ if (idx < size) {
+ Vectors.sparse(size, Array(idx.toInt), Array(1.0))
} else {
- assert(handleInvalid == OneHotEncoderEstimator.ERROR_INVALID)
- throw new SparkException(s"Unseen value: $label. To handle unseen values, " +
- s"set Param handleInvalid to ${OneHotEncoderEstimator.KEEP_INVALID}.")
+ Vectors.sparse(size, Array.empty[Int], Array.empty[Double])
}
}
}
@@ -282,7 +288,6 @@ class OneHotEncoderModel private[ml] (
@Since("2.3.0")
override def transformSchema(schema: StructType): StructType = {
val inputColNames = $(inputCols)
- val outputColNames = $(outputCols)
require(inputColNames.length == categorySizes.length,
s"The number of input columns ${inputColNames.length} must be the same as the number of " +
@@ -300,6 +305,7 @@ class OneHotEncoderModel private[ml] (
* account. Mismatched numbers will cause exception.
*/
private def verifyNumOfValues(schema: StructType): StructType = {
+ val configedSizes = getConfigedCategorySizes
$(outputCols).zipWithIndex.foreach { case (outputColName, idx) =>
val inputColName = $(inputCols)(idx)
val attrGroup = AttributeGroup.fromStructField(schema(outputColName))
@@ -308,9 +314,9 @@ class OneHotEncoderModel private[ml] (
// comparing with expected category number with `handleInvalid` and
// `dropLast` taken into account.
if (attrGroup.attributes.nonEmpty) {
- val numCategories = configedCategorySize(categorySizes(idx), idx)
+ val numCategories = configedSizes(idx)
require(attrGroup.size == numCategories, "OneHotEncoderModel expected " +
- s"$numCategories categorical values for input column ${inputColName}, " +
+ s"$numCategories categorical values for input column $inputColName, " +
s"but the input column had metadata specifying ${attrGroup.size} values.")
}
}
@@ -322,7 +328,7 @@ class OneHotEncoderModel private[ml] (
val transformedSchema = transformSchema(dataset.schema, logging = true)
val keepInvalid = $(handleInvalid) == OneHotEncoderEstimator.KEEP_INVALID
- val encodedColumns = (0 until $(inputCols).length).map { idx =>
+ val encodedColumns = $(inputCols).indices.map { idx =>
val inputColName = $(inputCols)(idx)
val outputColName = $(outputCols)(idx)
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org