Posted to commits@spark.apache.org by jk...@apache.org on 2018/04/02 23:41:32 UTC
spark git commit: [SPARK-23690][ML] Add handleInvalid to VectorAssembler
Repository: spark
Updated Branches:
refs/heads/master 28ea4e314 -> a1351828d
[SPARK-23690][ML] Add handleInvalid to VectorAssembler
## What changes were proposed in this pull request?
Introduce a `handleInvalid` parameter in `VectorAssembler` that accepts the options `"keep"`, `"skip"`, and `"error"`. "error" throws an error on seeing a row containing a `null`, "skip" filters out all such rows, and "keep" substitutes the relevant number of NaN values in the output. To know how many NaNs to substitute, "keep" requires each vector column's length, taken from ML attribute metadata (e.g. set by `VectorSizeHint`), and throws an error when a length cannot be determined; with "error" and "skip" the lengths may also be inferred from the first row of the data.
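For illustration, a minimal usage sketch (not part of this commit or its tests; it assumes Spark 2.4.0+ and an ambient `SparkSession` named `spark`, as in `spark-shell`, and the column names are made up):

```scala
import org.apache.spark.ml.feature.{VectorAssembler, VectorSizeHint}
import org.apache.spark.ml.linalg.{Vector, Vectors}

import spark.implicits._

// The second row has a null scalar; handleInvalid decides its fate.
val df = Seq[(java.lang.Double, Vector)](
  (1.0, Vectors.dense(2.0, 3.0)),
  (null, Vectors.dense(4.0, 5.0))
).toDF("x", "y")

val assembler = new VectorAssembler()
  .setInputCols(Array("x", "y"))
  .setOutputCol("features")

// "error" (the default) throws a SparkException on the null row;
// "skip" silently drops it:
assembler.setHandleInvalid("skip").transform(df).show(false)

// "keep" must know every vector column's length up front, so declare it with
// VectorSizeHint (or other ML attribute metadata) before assembling:
val sized = new VectorSizeHint().setInputCol("y").setSize(2).transform(df)
assembler.setHandleInvalid("keep").transform(sized).show(false)
// The null in "x" becomes NaN: the second row assembles to [NaN, 4.0, 5.0].
```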
## How was this patch tested?
Unit tests were added to check the behavior of `assemble` on specific rows, and the transformer is called on `DataFrame`s in several configurations to exercise the corner cases.
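For example, the suite pins down the new `assemble(lengths, keepInvalid)` behavior roughly as follows (`assemble` is `private[feature]`, so this sketch is only callable from code in the `org.apache.spark.ml.feature` package, as the suite is; the results shown are the ones the tests assert):

```scala
import org.apache.spark.ml.feature.VectorAssembler.assemble

// lengths = Array(1, 2): one scalar column followed by one vector column of size 2.
assemble(Array(1, 2), keepInvalid = true)(1.0, null)   // => [1.0, NaN, NaN]
assemble(Array(1, 2), keepInvalid = false)(1.0, null)  // throws SparkException
```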
Author: Yogesh Garg <yogesh(dot)garg()databricks(dot)com>
Author: Bago Amirbekian <ba...@databricks.com>
Author: Yogesh Garg <10...@users.noreply.github.com>
Closes #20829 from yogeshg/rformula_handleinvalid.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/a1351828
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/a1351828
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/a1351828
Branch: refs/heads/master
Commit: a1351828d376a01e5ee0959cf608f767d756dd86
Parents: 28ea4e3
Author: Yogesh Garg <yogesh(dot)garg()databricks(dot)com>
Authored: Mon Apr 2 16:41:26 2018 -0700
Committer: Joseph K. Bradley <jo...@databricks.com>
Committed: Mon Apr 2 16:41:26 2018 -0700
----------------------------------------------------------------------
.../apache/spark/ml/feature/StringIndexer.scala | 2 +-
.../spark/ml/feature/VectorAssembler.scala | 198 +++++++++++++++----
.../spark/ml/feature/VectorAssemblerSuite.scala | 131 ++++++++++--
3 files changed, 284 insertions(+), 47 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/a1351828/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index 1cdcdfc..67cdb09 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -234,7 +234,7 @@ class StringIndexerModel (
val metadata = NominalAttribute.defaultAttr
.withName($(outputCol)).withValues(filteredLabels).toMetadata()
// If we are skipping invalid records, filter them out.
- val (filteredDataset, keepInvalid) = getHandleInvalid match {
+ val (filteredDataset, keepInvalid) = $(handleInvalid) match {
case StringIndexer.SKIP_INVALID =>
val filterer = udf { label: String =>
labelToIndex.contains(label)
http://git-wip-us.apache.org/repos/asf/spark/blob/a1351828/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
index b373ae9..6bf4aa3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
@@ -17,14 +17,17 @@
package org.apache.spark.ml.feature
-import scala.collection.mutable.ArrayBuilder
+import java.util.NoSuchElementException
+
+import scala.collection.mutable
+import scala.language.existentials
import org.apache.spark.SparkException
import org.apache.spark.annotation.Since
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute, UnresolvedAttribute}
import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT}
-import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators}
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util._
import org.apache.spark.sql.{DataFrame, Dataset, Row}
@@ -33,10 +36,14 @@ import org.apache.spark.sql.types._
/**
* A feature transformer that merges multiple columns into a vector column.
+ *
+ * This requires one pass over the entire dataset. If column lengths need to be inferred from the
+ * data, an additional call to the 'first' Dataset method is required; see the 'handleInvalid'
+ * parameter.
*/
@Since("1.4.0")
class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String)
- extends Transformer with HasInputCols with HasOutputCol with DefaultParamsWritable {
+ extends Transformer with HasInputCols with HasOutputCol with HasHandleInvalid
+ with DefaultParamsWritable {
@Since("1.4.0")
def this() = this(Identifiable.randomUID("vecAssembler"))
@@ -49,32 +56,63 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String)
@Since("1.4.0")
def setOutputCol(value: String): this.type = set(outputCol, value)
+ /** @group setParam */
+ @Since("2.4.0")
+ def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
+
+ /**
+ * Param for how to handle invalid data (NULL values). Options are 'skip' (filter out rows with
+ * invalid data), 'error' (throw an error), or 'keep' (return the relevant number of NaN values
+ * in the output). Column lengths are taken from the size of the ML Attribute Group, which can be
+ * set using `VectorSizeHint` in a pipeline before `VectorAssembler`. Column lengths can also be
+ * inferred from the first row of the data, but only with 'error' or 'skip', where doing so is
+ * safe.
+ * Default: "error"
+ * @group param
+ */
+ @Since("2.4.0")
+ override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid",
+ """Param for how to handle invalid data (NULL values). Options are 'skip' (filter out rows with
+ |invalid data), 'error' (throw an error), or 'keep' (return relevant number of NaN in the
+ |output). Column lengths are taken from the size of ML Attribute Group, which can be set using
+ |`VectorSizeHint` in a pipeline before `VectorAssembler`. Column lengths can also be inferred
+ |from first rows of the data since it is safe to do so but only in case of 'error' or 'skip'.
+ |""".stripMargin.replaceAll("\n", " "),
+ ParamValidators.inArray(VectorAssembler.supportedHandleInvalids))
+
+ setDefault(handleInvalid, VectorAssembler.ERROR_INVALID)
+
@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
// Schema transformation.
val schema = dataset.schema
- lazy val first = dataset.toDF.first()
- val attrs = $(inputCols).flatMap { c =>
+
+ val vectorCols = $(inputCols).filter { c =>
+ schema(c).dataType match {
+ case _: VectorUDT => true
+ case _ => false
+ }
+ }
+ val vectorColsLengths = VectorAssembler.getLengths(dataset, vectorCols, $(handleInvalid))
+
+ val featureAttributesMap = $(inputCols).map { c =>
val field = schema(c)
- val index = schema.fieldIndex(c)
field.dataType match {
case DoubleType =>
- val attr = Attribute.fromStructField(field)
- // If the input column doesn't have ML attribute, assume numeric.
- if (attr == UnresolvedAttribute) {
- Some(NumericAttribute.defaultAttr.withName(c))
- } else {
- Some(attr.withName(c))
+ val attribute = Attribute.fromStructField(field)
+ attribute match {
+ case UnresolvedAttribute =>
+ Seq(NumericAttribute.defaultAttr.withName(c))
+ case _ =>
+ Seq(attribute.withName(c))
}
case _: NumericType | BooleanType =>
// If the input column type is a compatible scalar type, assume numeric.
- Some(NumericAttribute.defaultAttr.withName(c))
+ Seq(NumericAttribute.defaultAttr.withName(c))
case _: VectorUDT =>
- val group = AttributeGroup.fromStructField(field)
- if (group.attributes.isDefined) {
- // If attributes are defined, copy them with updated names.
- group.attributes.get.zipWithIndex.map { case (attr, i) =>
+ val attributeGroup = AttributeGroup.fromStructField(field)
+ if (attributeGroup.attributes.isDefined) {
+ attributeGroup.attributes.get.zipWithIndex.toSeq.map { case (attr, i) =>
if (attr.name.isDefined) {
// TODO: Define a rigorous naming scheme.
attr.withName(c + "_" + attr.name.get)
@@ -85,18 +123,25 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String)
} else {
// Otherwise, treat all attributes as numeric. If we cannot get the number of attributes
// from metadata, check the first row.
- val numAttrs = group.numAttributes.getOrElse(first.getAs[Vector](index).size)
- Array.tabulate(numAttrs)(i => NumericAttribute.defaultAttr.withName(c + "_" + i))
+ (0 until vectorColsLengths(c)).map { i =>
+ NumericAttribute.defaultAttr.withName(c + "_" + i)
+ }
}
case otherType =>
throw new SparkException(s"VectorAssembler does not support the $otherType type")
}
}
- val metadata = new AttributeGroup($(outputCol), attrs).toMetadata()
-
+ val featureAttributes = featureAttributesMap.flatten[Attribute].toArray
+ val lengths = featureAttributesMap.map(a => a.length).toArray
+ val metadata = new AttributeGroup($(outputCol), featureAttributes).toMetadata()
+ val (filteredDataset, keepInvalid) = $(handleInvalid) match {
+ case VectorAssembler.SKIP_INVALID => (dataset.na.drop($(inputCols)), false)
+ case VectorAssembler.KEEP_INVALID => (dataset, true)
+ case VectorAssembler.ERROR_INVALID => (dataset, false)
+ }
// Data transformation.
val assembleFunc = udf { r: Row =>
- VectorAssembler.assemble(r.toSeq: _*)
+ VectorAssembler.assemble(lengths, keepInvalid)(r.toSeq: _*)
}.asNondeterministic()
val args = $(inputCols).map { c =>
schema(c).dataType match {
@@ -106,7 +151,7 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String)
}
}
- dataset.select(col("*"), assembleFunc(struct(args: _*)).as($(outputCol), metadata))
+ filteredDataset.select(col("*"), assembleFunc(struct(args: _*)).as($(outputCol), metadata))
}
@Since("1.4.0")
@@ -136,34 +181,117 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String)
@Since("1.6.0")
object VectorAssembler extends DefaultParamsReadable[VectorAssembler] {
+ private[feature] val SKIP_INVALID: String = "skip"
+ private[feature] val ERROR_INVALID: String = "error"
+ private[feature] val KEEP_INVALID: String = "keep"
+ private[feature] val supportedHandleInvalids: Array[String] =
+ Array(SKIP_INVALID, ERROR_INVALID, KEEP_INVALID)
+
+ /**
+ * Infers the lengths of vector columns from the first row of the dataset.
+ * @param dataset the dataset
+ * @param columns names of the vector columns whose lengths need to be inferred
+ * @return a map of column names to lengths
+ */
+ private[feature] def getVectorLengthsFromFirstRow(
+ dataset: Dataset[_],
+ columns: Seq[String]): Map[String, Int] = {
+ try {
+ val first_row = dataset.toDF().select(columns.map(col): _*).first()
+ columns.zip(first_row.toSeq).map {
+ case (c, x) => c -> x.asInstanceOf[Vector].size
+ }.toMap
+ } catch {
+ case e: NullPointerException => throw new NullPointerException(
+ s"""Encountered null value while inferring lengths from the first row. Consider using
+ |VectorSizeHint to add metadata for columns: ${columns.mkString("[", ", ", "]")}. """
+ .stripMargin.replaceAll("\n", " ") + e.toString)
+ case e: NoSuchElementException => throw new NoSuchElementException(
+ s"""Encountered empty dataframe while inferring lengths from the first row. Consider using
+ |VectorSizeHint to add metadata for columns: ${columns.mkString("[", ", ", "]")}. """
+ .stripMargin.replaceAll("\n", " ") + e.toString)
+ }
+ }
+
+ private[feature] def getLengths(
+ dataset: Dataset[_],
+ columns: Seq[String],
+ handleInvalid: String): Map[String, Int] = {
+ val groupSizes = columns.map { c =>
+ c -> AttributeGroup.fromStructField(dataset.schema(c)).size
+ }.toMap
+ val missingColumns = groupSizes.filter(_._2 == -1).keys.toSeq
+ val firstSizes = (missingColumns.nonEmpty, handleInvalid) match {
+ case (true, VectorAssembler.ERROR_INVALID) =>
+ getVectorLengthsFromFirstRow(dataset, missingColumns)
+ case (true, VectorAssembler.SKIP_INVALID) =>
+ getVectorLengthsFromFirstRow(dataset.na.drop(missingColumns), missingColumns)
+ case (true, VectorAssembler.KEEP_INVALID) => throw new RuntimeException(
+ s"""Can not infer column lengths with handleInvalid = "keep". Consider using VectorSizeHint
+ |to add metadata for columns: ${columns.mkString("[", ", ", "]")}."""
+ .stripMargin.replaceAll("\n", " "))
+ case (_, _) => Map.empty
+ }
+ groupSizes ++ firstSizes
+ }
+
+
@Since("1.6.0")
override def load(path: String): VectorAssembler = super.load(path)
- private[feature] def assemble(vv: Any*): Vector = {
- val indices = ArrayBuilder.make[Int]
- val values = ArrayBuilder.make[Double]
- var cur = 0
+ /**
+ * Returns a function that has the required information to assemble each row.
+ * @param lengths an array of lengths of input columns, whose size should be equal to the number
+ * of cells in the row (vv)
+ * @param keepInvalid whether to replace nulls with NaNs (true) or throw an error on seeing a
+ * null in the row (false)
+ * @return a function that is wrapped in a udf and applied to each row
+ */
+ private[feature] def assemble(lengths: Array[Int], keepInvalid: Boolean)(vv: Any*): Vector = {
+ val indices = mutable.ArrayBuilder.make[Int]
+ val values = mutable.ArrayBuilder.make[Double]
+ var featureIndex = 0
+
+ var inputColumnIndex = 0
vv.foreach {
case v: Double =>
- if (v != 0.0) {
- indices += cur
+ if (v.isNaN && !keepInvalid) {
+ throw new SparkException(
+ s"""Encountered NaN while assembling a row with handleInvalid = "error". Consider
+ |removing NaNs from the dataset or using handleInvalid = "keep" or "skip"."""
+ .stripMargin)
+ } else if (v != 0.0) {
+ indices += featureIndex
values += v
}
- cur += 1
+ inputColumnIndex += 1
+ featureIndex += 1
case vec: Vector =>
vec.foreachActive { case (i, v) =>
if (v != 0.0) {
- indices += cur + i
+ indices += featureIndex + i
values += v
}
}
- cur += vec.size
+ inputColumnIndex += 1
+ featureIndex += vec.size
case null =>
- // TODO: output Double.NaN?
- throw new SparkException("Values to assemble cannot be null.")
+ if (keepInvalid) {
+ val length: Int = lengths(inputColumnIndex)
+ Array.range(0, length).foreach { i =>
+ indices += featureIndex + i
+ values += Double.NaN
+ }
+ inputColumnIndex += 1
+ featureIndex += length
+ } else {
+ throw new SparkException(
+ s"""Encountered null while assembling a row with handleInvalid = "keep". Consider
+ |removing nulls from dataset or using handleInvalid = "keep" or "skip"."""
+ .stripMargin)
+ }
case o =>
throw new SparkException(s"$o of type ${o.getClass.getName} is not supported.")
}
- Vectors.sparse(cur, indices.result(), values.result()).compressed
+ Vectors.sparse(featureIndex, indices.result(), values.result()).compressed
}
}
http://git-wip-us.apache.org/repos/asf/spark/blob/a1351828/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
index eca065f..91fb24a 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
@@ -18,12 +18,12 @@
package org.apache.spark.ml.feature
import org.apache.spark.{SparkException, SparkFunSuite}
-import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute}
+import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute, NumericAttribute}
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.Row
+import org.apache.spark.sql.{Dataset, Row}
import org.apache.spark.sql.functions.{col, udf}
class VectorAssemblerSuite
@@ -31,30 +31,49 @@ class VectorAssemblerSuite
import testImplicits._
+ @transient var dfWithNullsAndNaNs: Dataset[_] = _
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+ val sv = Vectors.sparse(2, Array(1), Array(3.0))
+ dfWithNullsAndNaNs = Seq[(Long, Long, java.lang.Double, Vector, String, Vector, Long, String)](
+ (1, 2, 0.0, Vectors.dense(1.0, 2.0), "a", sv, 7L, null),
+ (2, 1, 0.0, null, "a", sv, 6L, null),
+ (3, 3, null, Vectors.dense(1.0, 2.0), "a", sv, 8L, null),
+ (4, 4, null, null, "a", sv, 9L, null),
+ (5, 5, java.lang.Double.NaN, Vectors.dense(1.0, 2.0), "a", sv, 7L, null),
+ (6, 6, java.lang.Double.NaN, null, "a", sv, 8L, null))
+ .toDF("id1", "id2", "x", "y", "name", "z", "n", "nulls")
+ }
+
test("params") {
ParamsSuite.checkParams(new VectorAssembler)
}
test("assemble") {
import org.apache.spark.ml.feature.VectorAssembler.assemble
- assert(assemble(0.0) === Vectors.sparse(1, Array.empty, Array.empty))
- assert(assemble(0.0, 1.0) === Vectors.sparse(2, Array(1), Array(1.0)))
+ assert(assemble(Array(1), keepInvalid = true)(0.0)
+ === Vectors.sparse(1, Array.empty, Array.empty))
+ assert(assemble(Array(1, 1), keepInvalid = true)(0.0, 1.0)
+ === Vectors.sparse(2, Array(1), Array(1.0)))
val dv = Vectors.dense(2.0, 0.0)
- assert(assemble(0.0, dv, 1.0) === Vectors.sparse(4, Array(1, 3), Array(2.0, 1.0)))
+ assert(assemble(Array(1, 2, 1), keepInvalid = true)(0.0, dv, 1.0) ===
+ Vectors.sparse(4, Array(1, 3), Array(2.0, 1.0)))
val sv = Vectors.sparse(2, Array(0, 1), Array(3.0, 4.0))
- assert(assemble(0.0, dv, 1.0, sv) ===
+ assert(assemble(Array(1, 2, 1, 2), keepInvalid = true)(0.0, dv, 1.0, sv) ===
Vectors.sparse(6, Array(1, 3, 4, 5), Array(2.0, 1.0, 3.0, 4.0)))
- for (v <- Seq(1, "a", null)) {
- intercept[SparkException](assemble(v))
- intercept[SparkException](assemble(1.0, v))
+ for (v <- Seq(1, "a")) {
+ intercept[SparkException](assemble(Array(1), keepInvalid = true)(v))
+ intercept[SparkException](assemble(Array(1, 1), keepInvalid = true)(1.0, v))
}
}
test("assemble should compress vectors") {
import org.apache.spark.ml.feature.VectorAssembler.assemble
- val v1 = assemble(0.0, 0.0, 0.0, Vectors.dense(4.0))
+ val v1 = assemble(Array(1, 1, 1, 1), keepInvalid = true)(0.0, 0.0, 0.0, Vectors.dense(4.0))
assert(v1.isInstanceOf[SparseVector])
- val v2 = assemble(1.0, 2.0, 3.0, Vectors.sparse(1, Array(0), Array(4.0)))
+ val sv = Vectors.sparse(1, Array(0), Array(4.0))
+ val v2 = assemble(Array(1, 1, 1, 1), keepInvalid = true)(1.0, 2.0, 3.0, sv)
assert(v2.isInstanceOf[DenseVector])
}
@@ -147,4 +166,94 @@ class VectorAssemblerSuite
.filter(vectorUDF($"features") > 1)
.count() == 1)
}
+
+ test("assemble should keep nulls when keepInvalid is true") {
+ import org.apache.spark.ml.feature.VectorAssembler.assemble
+ assert(assemble(Array(1, 1), keepInvalid = true)(1.0, null) === Vectors.dense(1.0, Double.NaN))
+ assert(assemble(Array(1, 2), keepInvalid = true)(1.0, null)
+ === Vectors.dense(1.0, Double.NaN, Double.NaN))
+ assert(assemble(Array(1), keepInvalid = true)(null) === Vectors.dense(Double.NaN))
+ assert(assemble(Array(2), keepInvalid = true)(null) === Vectors.dense(Double.NaN, Double.NaN))
+ }
+
+ test("assemble should throw errors when keepInvalid is false") {
+ import org.apache.spark.ml.feature.VectorAssembler.assemble
+ intercept[SparkException](assemble(Array(1, 1), keepInvalid = false)(1.0, null))
+ intercept[SparkException](assemble(Array(1, 2), keepInvalid = false)(1.0, null))
+ intercept[SparkException](assemble(Array(1), keepInvalid = false)(null))
+ intercept[SparkException](assemble(Array(2), keepInvalid = false)(null))
+ }
+
+ test("get lengths functions") {
+ import org.apache.spark.ml.feature.VectorAssembler._
+ val df = dfWithNullsAndNaNs
+ assert(getVectorLengthsFromFirstRow(df, Seq("y")) === Map("y" -> 2))
+ assert(intercept[NullPointerException](getVectorLengthsFromFirstRow(df.sort("id2"), Seq("y")))
+ .getMessage.contains("VectorSizeHint"))
+ assert(intercept[NoSuchElementException](getVectorLengthsFromFirstRow(df.filter("id1 > 6"),
+ Seq("y"))).getMessage.contains("VectorSizeHint"))
+
+ assert(getLengths(df.sort("id2"), Seq("y"), SKIP_INVALID).exists(_ == "y" -> 2))
+ assert(intercept[NullPointerException](getLengths(df.sort("id2"), Seq("y"), ERROR_INVALID))
+ .getMessage.contains("VectorSizeHint"))
+ assert(intercept[RuntimeException](getLengths(df.sort("id2"), Seq("y"), KEEP_INVALID))
+ .getMessage.contains("VectorSizeHint"))
+ }
+
+ test("Handle Invalid should behave properly") {
+ val assembler = new VectorAssembler()
+ .setInputCols(Array("x", "y", "z", "n"))
+ .setOutputCol("features")
+
+ def runWithMetadata(mode: String, additional_filter: String = "true"): Dataset[_] = {
+ val attributeY = new AttributeGroup("y", 2)
+ val attributeZ = new AttributeGroup(
+ "z",
+ Array[Attribute](
+ NumericAttribute.defaultAttr.withName("foo"),
+ NumericAttribute.defaultAttr.withName("bar")))
+ val dfWithMetadata = dfWithNullsAndNaNs.withColumn("y", col("y"), attributeY.toMetadata())
+ .withColumn("z", col("z"), attributeZ.toMetadata()).filter(additional_filter)
+ val output = assembler.setHandleInvalid(mode).transform(dfWithMetadata)
+ output.collect()
+ output
+ }
+
+ def runWithFirstRow(mode: String): Dataset[_] = {
+ val output = assembler.setHandleInvalid(mode).transform(dfWithNullsAndNaNs)
+ output.collect()
+ output
+ }
+
+ def runWithAllNullVectors(mode: String): Dataset[_] = {
+ val output = assembler.setHandleInvalid(mode)
+ .transform(dfWithNullsAndNaNs.filter("0 == id1 % 2"))
+ output.collect()
+ output
+ }
+
+ // behavior when vector size hint is given
+ assert(runWithMetadata("keep").count() == 6, "should keep all rows")
+ assert(runWithMetadata("skip").count() == 1, "should skip rows with nulls")
+ // should throw error with nulls
+ intercept[SparkException](runWithMetadata("error"))
+ // should throw error with NaNs
+ intercept[SparkException](runWithMetadata("error", additional_filter = "id1 > 4"))
+
+ // behavior when first row has information
+ assert(intercept[RuntimeException](runWithFirstRow("keep").count())
+ .getMessage.contains("VectorSizeHint"), "should suggest to use metadata")
+ assert(runWithFirstRow("skip").count() == 1, "should infer size and skip rows with nulls")
+ intercept[SparkException](runWithFirstRow("error"))
+
+ // behavior when vector column is all null
+ assert(intercept[RuntimeException](runWithAllNullVectors("skip"))
+ .getMessage.contains("VectorSizeHint"), "should suggest to use metadata")
+ assert(intercept[NullPointerException](runWithAllNullVectors("error"))
+ .getMessage.contains("VectorSizeHint"), "should suggest to use metadata")
+
+ // behavior when scalar column is all null
+ assert(runWithMetadata("keep", additional_filter = "id1 > 2").count() == 4)
+ }
+
}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org