Posted to commits@spark.apache.org by jk...@apache.org on 2017/11/15 00:58:23 UTC
spark git commit: [SPARK-12375][ML] VectorIndexerModel supports handling unseen categories via handleInvalid
Repository: spark
Updated Branches:
refs/heads/master 774398045 -> 1e6f76059
[SPARK-12375][ML] VectorIndexerModel supports handling unseen categories via handleInvalid
## What changes were proposed in this pull request?
Support a skip/error/keep strategy for unseen categories, similar to `StringIndexer`.
Implemented via `try...catch` around the existing map lookup, so the common path for valid values does no extra work and a possible performance impact is avoided.
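A minimal usage sketch (the column names and the `training`/`test` DataFrames are illustrative, not part of this patch):

```scala
import org.apache.spark.ml.feature.VectorIndexer

val indexer = new VectorIndexer()
  .setInputCol("features")
  .setOutputCol("indexed")
  .setMaxCategories(10)
  .setHandleInvalid("keep") // "skip" and "error" (the default) are also supported

val model = indexer.fit(training)
// "error": transform throws a SparkException on an unseen categorical value.
// "skip":  rows containing unseen categorical values are filtered out.
// "keep":  unseen values map to the extra bucket at index numCategories.
val indexed = model.transform(test)
```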
## How was this patch tested?
Unit test added.
Author: WeichenXu <we...@databricks.com>
Closes #19588 from WeichenXu123/handle_invalid_for_vector_indexer.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/1e6f7605
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/1e6f7605
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/1e6f7605
Branch: refs/heads/master
Commit: 1e6f760593d81def059c514d34173bf2777d71ec
Parents: 7743980
Author: WeichenXu <we...@databricks.com>
Authored: Tue Nov 14 16:58:18 2017 -0800
Committer: Joseph K. Bradley <jo...@databricks.com>
Committed: Tue Nov 14 16:58:18 2017 -0800
----------------------------------------------------------------------
.../apache/spark/ml/feature/VectorIndexer.scala | 92 +++++++++++++++++---
.../spark/ml/feature/VectorIndexerSuite.scala | 39 +++++++++
2 files changed, 121 insertions(+), 10 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/1e6f7605/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
index d371da7..3403ec4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
@@ -18,12 +18,13 @@
package org.apache.spark.ml.feature
import java.lang.{Double => JDouble, Integer => JInt}
-import java.util.{Map => JMap}
+import java.util.{Map => JMap, NoSuchElementException}
import scala.collection.JavaConverters._
import org.apache.hadoop.fs.Path
+import org.apache.spark.SparkException
import org.apache.spark.annotation.Since
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.attribute._
@@ -37,7 +38,27 @@ import org.apache.spark.sql.types.{StructField, StructType}
import org.apache.spark.util.collection.OpenHashSet
/** Private trait for params for VectorIndexer and VectorIndexerModel */
-private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOutputCol {
+private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOutputCol
+ with HasHandleInvalid {
+
+ /**
+ * Param for how to handle invalid data (unseen categories or NULL values).
+ * Note: this param only applies to categorical features, not continuous ones.
+ * Options are:
+ * 'skip': filter out rows with invalid data.
+ * 'error': throw an error.
+ * 'keep': put invalid data in a special additional bucket, at index numCategories.
+ * Default value: "error"
+ * @group param
+ */
+ @Since("2.3.0")
+ override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid",
+ "How to handle invalid data (unseen labels or NULL values). " +
+ "Options are 'skip' (filter out rows with invalid data), 'error' (throw an error), " +
+ "or 'keep' (put invalid data in a special additional bucket, at index numLabels).",
+ ParamValidators.inArray(VectorIndexer.supportedHandleInvalids))
+
+ setDefault(handleInvalid, VectorIndexer.ERROR_INVALID)
/**
* Threshold for the number of values a categorical feature can take.
@@ -113,6 +134,10 @@ class VectorIndexer @Since("1.4.0") (
@Since("1.4.0")
def setOutputCol(value: String): this.type = set(outputCol, value)
+ /** @group setParam */
+ @Since("2.3.0")
+ def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
+
@Since("2.0.0")
override def fit(dataset: Dataset[_]): VectorIndexerModel = {
transformSchema(dataset.schema, logging = true)
@@ -148,6 +173,11 @@ class VectorIndexer @Since("1.4.0") (
@Since("1.6.0")
object VectorIndexer extends DefaultParamsReadable[VectorIndexer] {
+ private[feature] val SKIP_INVALID: String = "skip"
+ private[feature] val ERROR_INVALID: String = "error"
+ private[feature] val KEEP_INVALID: String = "keep"
+ private[feature] val supportedHandleInvalids: Array[String] =
+ Array(SKIP_INVALID, ERROR_INVALID, KEEP_INVALID)
@Since("1.6.0")
override def load(path: String): VectorIndexer = super.load(path)
@@ -287,9 +317,15 @@ class VectorIndexerModel private[ml] (
while (featureIndex < numFeatures) {
if (categoryMaps.contains(featureIndex)) {
// categorical feature
- val featureValues: Array[String] =
+ val rawFeatureValues: Array[String] =
categoryMaps(featureIndex).toArray.sortBy(_._1).map(_._1).map(_.toString)
- if (featureValues.length == 2) {
+
+ val featureValues = if (getHandleInvalid == VectorIndexer.KEEP_INVALID) {
+ (rawFeatureValues.toList :+ "__unknown").toArray
+ } else {
+ rawFeatureValues
+ }
+ if (featureValues.length == 2 && getHandleInvalid != VectorIndexer.KEEP_INVALID) {
attrs(featureIndex) = new BinaryAttribute(index = Some(featureIndex),
values = Some(featureValues))
} else {
@@ -311,22 +347,39 @@ class VectorIndexerModel private[ml] (
// TODO: Check more carefully about whether this whole class will be included in a closure.
/** Per-vector transform function */
- private val transformFunc: Vector => Vector = {
+ private lazy val transformFunc: Vector => Vector = {
val sortedCatFeatureIndices = categoryMaps.keys.toArray.sorted
val localVectorMap = categoryMaps
val localNumFeatures = numFeatures
+ val localHandleInvalid = getHandleInvalid
val f: Vector => Vector = { (v: Vector) =>
assert(v.size == localNumFeatures, "VectorIndexerModel expected vector of length" +
s" $numFeatures but found length ${v.size}")
v match {
case dv: DenseVector =>
+ var hasInvalid = false
val tmpv = dv.copy
localVectorMap.foreach { case (featureIndex: Int, categoryMap: Map[Double, Int]) =>
- tmpv.values(featureIndex) = categoryMap(tmpv(featureIndex))
+ try {
+ tmpv.values(featureIndex) = categoryMap(tmpv(featureIndex))
+ } catch {
+ case _: NoSuchElementException =>
+ localHandleInvalid match {
+ case VectorIndexer.ERROR_INVALID =>
+ throw new SparkException(s"VectorIndexer encountered invalid value " +
+ s"${tmpv(featureIndex)} on feature index ${featureIndex}. To handle " +
+ s"or skip invalid value, try setting VectorIndexer.handleInvalid.")
+ case VectorIndexer.KEEP_INVALID =>
+ tmpv.values(featureIndex) = categoryMap.size
+ case VectorIndexer.SKIP_INVALID =>
+ hasInvalid = true
+ }
+ }
}
- tmpv
+ if (hasInvalid) null else tmpv
case sv: SparseVector =>
// We use the fact that categorical value 0 is always mapped to index 0.
+ var hasInvalid = false
val tmpv = sv.copy
var catFeatureIdx = 0 // index into sortedCatFeatureIndices
var k = 0 // index into non-zero elements of sparse vector
@@ -337,12 +390,26 @@ class VectorIndexerModel private[ml] (
} else if (featureIndex > tmpv.indices(k)) {
k += 1
} else {
- tmpv.values(k) = localVectorMap(featureIndex)(tmpv.values(k))
+ try {
+ tmpv.values(k) = localVectorMap(featureIndex)(tmpv.values(k))
+ } catch {
+ case _: NoSuchElementException =>
+ localHandleInvalid match {
+ case VectorIndexer.ERROR_INVALID =>
+ throw new SparkException(s"VectorIndexer encountered invalid value " +
+ s"${tmpv.values(k)} on feature index ${featureIndex}. To handle " +
+ s"or skip invalid value, try setting VectorIndexer.handleInvalid.")
+ case VectorIndexer.KEEP_INVALID =>
+ tmpv.values(k) = localVectorMap(featureIndex).size
+ case VectorIndexer.SKIP_INVALID =>
+ hasInvalid = true
+ }
+ }
catFeatureIdx += 1
k += 1
}
}
- tmpv
+ if (hasInvalid) null else tmpv
}
}
f
@@ -362,7 +429,12 @@ class VectorIndexerModel private[ml] (
val newField = prepOutputField(dataset.schema)
val transformUDF = udf { (vector: Vector) => transformFunc(vector) }
val newCol = transformUDF(dataset($(inputCol)))
- dataset.withColumn($(outputCol), newCol, newField.metadata)
+ val ds = dataset.withColumn($(outputCol), newCol, newField.metadata)
+ if (getHandleInvalid == VectorIndexer.SKIP_INVALID) {
+ ds.na.drop(Array($(outputCol)))
+ } else {
+ ds
+ }
}
@Since("1.4.0")
http://git-wip-us.apache.org/repos/asf/spark/blob/1e6f7605/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
index f2cca8a..69a7b75 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
@@ -38,6 +38,8 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext
// identical, of length 3
@transient var densePoints1: DataFrame = _
@transient var sparsePoints1: DataFrame = _
+ @transient var densePoints1TestInvalid: DataFrame = _
+ @transient var sparsePoints1TestInvalid: DataFrame = _
@transient var point1maxes: Array[Double] = _
// identical, of length 2
@@ -55,11 +57,19 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext
Vectors.dense(0.0, 1.0, 2.0),
Vectors.dense(0.0, 0.0, -1.0),
Vectors.dense(1.0, 3.0, 2.0))
+ val densePoints1SeqTestInvalid = densePoints1Seq ++ Seq(
+ Vectors.dense(10.0, 2.0, 0.0),
+ Vectors.dense(0.0, 10.0, 2.0),
+ Vectors.dense(1.0, 3.0, 10.0))
val sparsePoints1Seq = Seq(
Vectors.sparse(3, Array(0, 1), Array(1.0, 2.0)),
Vectors.sparse(3, Array(1, 2), Array(1.0, 2.0)),
Vectors.sparse(3, Array(2), Array(-1.0)),
Vectors.sparse(3, Array(0, 1, 2), Array(1.0, 3.0, 2.0)))
+ val sparsePoints1SeqTestInvalid = sparsePoints1Seq ++ Seq(
+ Vectors.sparse(3, Array(0, 1), Array(10.0, 2.0)),
+ Vectors.sparse(3, Array(1, 2), Array(10.0, 2.0)),
+ Vectors.sparse(3, Array(0, 1, 2), Array(1.0, 3.0, 10.0)))
point1maxes = Array(1.0, 3.0, 2.0)
val densePoints2Seq = Seq(
@@ -88,6 +98,8 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext
densePoints1 = densePoints1Seq.map(FeatureData).toDF()
sparsePoints1 = sparsePoints1Seq.map(FeatureData).toDF()
+ densePoints1TestInvalid = densePoints1SeqTestInvalid.map(FeatureData).toDF()
+ sparsePoints1TestInvalid = sparsePoints1SeqTestInvalid.map(FeatureData).toDF()
densePoints2 = densePoints2Seq.map(FeatureData).toDF()
sparsePoints2 = sparsePoints2Seq.map(FeatureData).toDF()
badPoints = badPointsSeq.map(FeatureData).toDF()
@@ -219,6 +231,33 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext
checkCategoryMaps(densePoints2, maxCategories = 2, categoricalFeatures = Set(1, 3))
}
+ test("handle invalid") {
+ for ((points, pointsTestInvalid) <- Seq((densePoints1, densePoints1TestInvalid),
+ (sparsePoints1, sparsePoints1TestInvalid))) {
+ val vectorIndexer = getIndexer.setMaxCategories(4).setHandleInvalid("error")
+ val model = vectorIndexer.fit(points)
+ intercept[SparkException] {
+ model.transform(pointsTestInvalid).collect()
+ }
+ val vectorIndexer1 = getIndexer.setMaxCategories(4).setHandleInvalid("skip")
+ val model1 = vectorIndexer1.fit(points)
+ val invalidTransformed1 = model1.transform(pointsTestInvalid).select("indexed")
+ .collect().map(_(0))
+ val transformed1 = model1.transform(points).select("indexed").collect().map(_(0))
+ assert(transformed1 === invalidTransformed1)
+
+ val vectorIndexer2 = getIndexer.setMaxCategories(4).setHandleInvalid("keep")
+ val model2 = vectorIndexer2.fit(points)
+ val invalidTransformed2 = model2.transform(pointsTestInvalid).select("indexed")
+ .collect().map(_(0))
+ assert(invalidTransformed2 === transformed1 ++ Array(
+ Vectors.dense(2.0, 2.0, 0.0),
+ Vectors.dense(0.0, 4.0, 2.0),
+ Vectors.dense(1.0, 3.0, 3.0))
+ )
+ }
+ }
+
test("Maintain sparsity for sparse vectors") {
def checkSparsity(data: DataFrame, maxCategories: Int): Unit = {
val points = data.collect().map(_.getAs[Vector](0))
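A note on the expected "keep" vectors in the new test: under `maxCategories = 4` the fitted model treats features 0, 1, and 2 as categorical with 2, 4, and 3 known values respectively (counts inferred from the fixtures), so each unseen `10.0` lands in the extra bucket at index `numCategories` for its feature: `(10.0, 2.0, 0.0)` becomes `(2.0, 2.0, 0.0)`, `(0.0, 10.0, 2.0)` becomes `(0.0, 4.0, 2.0)`, and `(1.0, 3.0, 10.0)` becomes `(1.0, 3.0, 3.0)`.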