Posted to commits@spark.apache.org by jk...@apache.org on 2018/04/02 23:41:32 UTC

spark git commit: [SPARK-23690][ML] Add handleinvalid to VectorAssembler

Repository: spark
Updated Branches:
  refs/heads/master 28ea4e314 -> a1351828d


[SPARK-23690][ML] Add handleinvalid to VectorAssembler

## What changes were proposed in this pull request?

Introduce a `handleInvalid` parameter in `VectorAssembler` that accepts the options `"keep"`, `"skip"`, and `"error"`. With "error", the transformer throws an error when it sees a row containing a `null`; with "skip", it filters out all such rows; with "keep", it substitutes the relevant number of NaN values. For "keep", the number of NaNs to emit for a null vector column is read from the column's ML attribute metadata (for example, as set by `VectorSizeHint`), and an error is thrown when no such length can be found.
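
For illustration, the snippet below is a minimal, hypothetical sketch of how the new parameter is meant to be used together with `VectorSizeHint` (the column names and toy data are made up for this example and are not part of the patch):

```scala
import org.apache.spark.ml.feature.{VectorAssembler, VectorSizeHint}
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").appName("handleInvalidSketch").getOrCreate()
import spark.implicits._

// Toy data: one scalar column and one nullable vector column.
val df = Seq[(Double, Vector)](
  (0.0, Vectors.dense(1.0, 2.0)),
  (1.0, null)
).toDF("hour", "userFeatures")

// "keep" needs to know each vector column's length in order to emit the right number of NaNs,
// so a VectorSizeHint upstream attaches that length as ML attribute metadata.
// ("optimistic" means the hint itself does not reject the null vector.)
val sizeHint = new VectorSizeHint()
  .setInputCol("userFeatures")
  .setSize(2)
  .setHandleInvalid("optimistic")

val assembler = new VectorAssembler()
  .setInputCols(Array("hour", "userFeatures"))
  .setOutputCol("features")
  .setHandleInvalid("keep")  // alternatives: "skip", "error" (the default)

assembler.transform(sizeHint.transform(df)).show(false)
// With "keep", the row with the null vector assembles to [1.0, NaN, NaN];
// "skip" would drop that row, and "error" would throw a SparkException.
```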

## How was this patch tested?

Unit tests are added to check the behavior of `assemble` on specific rows, and the transformer is called on `DataFrame`s with different configurations to exercise the corner cases.
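
For example, the per-row behavior of `assemble` under the new parameter is pinned down by assertions of roughly this shape (adapted from the suite in this patch; `assemble` is package-private, so this only compiles inside a ScalaTest suite in the `org.apache.spark.ml.feature` package):

```scala
import org.apache.spark.SparkException
import org.apache.spark.ml.feature.VectorAssembler.assemble
import org.apache.spark.ml.linalg.Vectors

// lengths = Array(1, 2): a scalar column followed by a vector column of length 2.
// A null in the vector slot becomes two NaNs when keepInvalid = true ...
assert(assemble(Array(1, 2), keepInvalid = true)(1.0, null)
  === Vectors.dense(1.0, Double.NaN, Double.NaN))
// ... and throws a SparkException when keepInvalid = false (ScalaTest's intercept).
intercept[SparkException](assemble(Array(1, 2), keepInvalid = false)(1.0, null))
```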

Author: Yogesh Garg <yogesh(dot)garg()databricks(dot)com>
Author: Bago Amirbekian <ba...@databricks.com>
Author: Yogesh Garg <10...@users.noreply.github.com>

Closes #20829 from yogeshg/rformula_handleinvalid.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/a1351828
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/a1351828
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/a1351828

Branch: refs/heads/master
Commit: a1351828d376a01e5ee0959cf608f767d756dd86
Parents: 28ea4e3
Author: Yogesh Garg <yogesh(dot)garg()databricks(dot)com>
Authored: Mon Apr 2 16:41:26 2018 -0700
Committer: Joseph K. Bradley <jo...@databricks.com>
Committed: Mon Apr 2 16:41:26 2018 -0700

----------------------------------------------------------------------
 .../apache/spark/ml/feature/StringIndexer.scala |   2 +-
 .../spark/ml/feature/VectorAssembler.scala      | 198 +++++++++++++++----
 .../spark/ml/feature/VectorAssemblerSuite.scala | 131 ++++++++++--
 3 files changed, 284 insertions(+), 47 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/a1351828/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index 1cdcdfc..67cdb09 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -234,7 +234,7 @@ class StringIndexerModel (
     val metadata = NominalAttribute.defaultAttr
       .withName($(outputCol)).withValues(filteredLabels).toMetadata()
     // If we are skipping invalid records, filter them out.
-    val (filteredDataset, keepInvalid) = getHandleInvalid match {
+    val (filteredDataset, keepInvalid) = $(handleInvalid) match {
       case StringIndexer.SKIP_INVALID =>
         val filterer = udf { label: String =>
           labelToIndex.contains(label)

http://git-wip-us.apache.org/repos/asf/spark/blob/a1351828/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
index b373ae9..6bf4aa3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
@@ -17,14 +17,17 @@
 
 package org.apache.spark.ml.feature
 
-import scala.collection.mutable.ArrayBuilder
+import java.util.NoSuchElementException
+
+import scala.collection.mutable
+import scala.language.existentials
 
 import org.apache.spark.SparkException
 import org.apache.spark.annotation.Since
 import org.apache.spark.ml.Transformer
 import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute, UnresolvedAttribute}
 import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT}
-import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators}
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util._
 import org.apache.spark.sql.{DataFrame, Dataset, Row}
@@ -33,10 +36,14 @@ import org.apache.spark.sql.types._
 
 /**
  * A feature transformer that merges multiple columns into a vector column.
+ *
+ * This requires one pass over the entire dataset. When column lengths have to be inferred from
+ * the data, an additional call to the 'first' Dataset method is made; see 'handleInvalid'.
  */
 @Since("1.4.0")
 class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String)
-  extends Transformer with HasInputCols with HasOutputCol with DefaultParamsWritable {
+  extends Transformer with HasInputCols with HasOutputCol with HasHandleInvalid
+    with DefaultParamsWritable {
 
   @Since("1.4.0")
   def this() = this(Identifiable.randomUID("vecAssembler"))
@@ -49,32 +56,63 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String)
   @Since("1.4.0")
   def setOutputCol(value: String): this.type = set(outputCol, value)
 
+  /** @group setParam */
+  @Since("2.4.0")
+  def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
+
+  /**
+   * Param for how to handle invalid data (NULL values). Options are 'skip' (filter out rows with
+   * invalid data), 'error' (throw an error), or 'keep' (return relevant number of NaN in the
+   * output). Column lengths are taken from the size of ML Attribute Group, which can be set using
+   * `VectorSizeHint` in a pipeline before `VectorAssembler`. Column lengths can also be inferred
+   * from the first rows of the data, but only for 'error' or 'skip', where inference is safe.
+   * Default: "error"
+   * @group param
+   */
+  @Since("2.4.0")
+  override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid",
+    """Param for how to handle invalid data (NULL values). Options are 'skip' (filter out rows with
+      |invalid data), 'error' (throw an error), or 'keep' (return relevant number of NaN in the
+      |output). Column lengths are taken from the size of ML Attribute Group, which can be set using
+      |`VectorSizeHint` in a pipeline before `VectorAssembler`. Column lengths can also be inferred
+      |from the first rows of the data, but only for 'error' or 'skip', where inference is safe.
+      |""".stripMargin.replaceAll("\n", " "),
+    ParamValidators.inArray(VectorAssembler.supportedHandleInvalids))
+
+  setDefault(handleInvalid, VectorAssembler.ERROR_INVALID)
+
   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema, logging = true)
     // Schema transformation.
     val schema = dataset.schema
-    lazy val first = dataset.toDF.first()
-    val attrs = $(inputCols).flatMap { c =>
+
+    val vectorCols = $(inputCols).filter { c =>
+      schema(c).dataType match {
+        case _: VectorUDT => true
+        case _ => false
+      }
+    }
+    val vectorColsLengths = VectorAssembler.getLengths(dataset, vectorCols, $(handleInvalid))
+
+    val featureAttributesMap = $(inputCols).map { c =>
       val field = schema(c)
-      val index = schema.fieldIndex(c)
       field.dataType match {
         case DoubleType =>
-          val attr = Attribute.fromStructField(field)
-          // If the input column doesn't have ML attribute, assume numeric.
-          if (attr == UnresolvedAttribute) {
-            Some(NumericAttribute.defaultAttr.withName(c))
-          } else {
-            Some(attr.withName(c))
+          val attribute = Attribute.fromStructField(field)
+          attribute match {
+            case UnresolvedAttribute =>
+              Seq(NumericAttribute.defaultAttr.withName(c))
+            case _ =>
+              Seq(attribute.withName(c))
           }
         case _: NumericType | BooleanType =>
           // If the input column type is a compatible scalar type, assume numeric.
-          Some(NumericAttribute.defaultAttr.withName(c))
+          Seq(NumericAttribute.defaultAttr.withName(c))
         case _: VectorUDT =>
-          val group = AttributeGroup.fromStructField(field)
-          if (group.attributes.isDefined) {
-            // If attributes are defined, copy them with updated names.
-            group.attributes.get.zipWithIndex.map { case (attr, i) =>
+          val attributeGroup = AttributeGroup.fromStructField(field)
+          if (attributeGroup.attributes.isDefined) {
+            attributeGroup.attributes.get.zipWithIndex.toSeq.map { case (attr, i) =>
               if (attr.name.isDefined) {
                 // TODO: Define a rigorous naming scheme.
                 attr.withName(c + "_" + attr.name.get)
@@ -85,18 +123,25 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String)
           } else {
             // Otherwise, treat all attributes as numeric. If we cannot get the number of attributes
             // from metadata, check the first row.
-            val numAttrs = group.numAttributes.getOrElse(first.getAs[Vector](index).size)
-            Array.tabulate(numAttrs)(i => NumericAttribute.defaultAttr.withName(c + "_" + i))
+            (0 until vectorColsLengths(c)).map { i =>
+              NumericAttribute.defaultAttr.withName(c + "_" + i)
+            }
           }
         case otherType =>
           throw new SparkException(s"VectorAssembler does not support the $otherType type")
       }
     }
-    val metadata = new AttributeGroup($(outputCol), attrs).toMetadata()
-
+    val featureAttributes = featureAttributesMap.flatten[Attribute].toArray
+    val lengths = featureAttributesMap.map(a => a.length).toArray
+    val metadata = new AttributeGroup($(outputCol), featureAttributes).toMetadata()
+    val (filteredDataset, keepInvalid) = $(handleInvalid) match {
+      case VectorAssembler.SKIP_INVALID => (dataset.na.drop($(inputCols)), false)
+      case VectorAssembler.KEEP_INVALID => (dataset, true)
+      case VectorAssembler.ERROR_INVALID => (dataset, false)
+    }
     // Data transformation.
     val assembleFunc = udf { r: Row =>
-      VectorAssembler.assemble(r.toSeq: _*)
+      VectorAssembler.assemble(lengths, keepInvalid)(r.toSeq: _*)
     }.asNondeterministic()
     val args = $(inputCols).map { c =>
       schema(c).dataType match {
@@ -106,7 +151,7 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String)
       }
     }
 
-    dataset.select(col("*"), assembleFunc(struct(args: _*)).as($(outputCol), metadata))
+    filteredDataset.select(col("*"), assembleFunc(struct(args: _*)).as($(outputCol), metadata))
   }
 
   @Since("1.4.0")
@@ -136,34 +181,117 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String)
 @Since("1.6.0")
 object VectorAssembler extends DefaultParamsReadable[VectorAssembler] {
 
+  private[feature] val SKIP_INVALID: String = "skip"
+  private[feature] val ERROR_INVALID: String = "error"
+  private[feature] val KEEP_INVALID: String = "keep"
+  private[feature] val supportedHandleInvalids: Array[String] =
+    Array(SKIP_INVALID, ERROR_INVALID, KEEP_INVALID)
+
+  /**
+   * Infers lengths of vector columns from the first row of the dataset
+   * @param dataset the dataset
+   * @param columns names of the vector columns whose lengths need to be inferred
+   * @return map of column names to lengths
+   */
+  private[feature] def getVectorLengthsFromFirstRow(
+      dataset: Dataset[_],
+      columns: Seq[String]): Map[String, Int] = {
+    try {
+      val first_row = dataset.toDF().select(columns.map(col): _*).first()
+      columns.zip(first_row.toSeq).map {
+        case (c, x) => c -> x.asInstanceOf[Vector].size
+      }.toMap
+    } catch {
+      case e: NullPointerException => throw new NullPointerException(
+        s"""Encountered null value while inferring lengths from the first row. Consider using
+           |VectorSizeHint to add metadata for columns: ${columns.mkString("[", ", ", "]")}. """
+          .stripMargin.replaceAll("\n", " ") + e.toString)
+      case e: NoSuchElementException => throw new NoSuchElementException(
+        s"""Encountered empty dataframe while inferring lengths from the first row. Consider using
+           |VectorSizeHint to add metadata for columns: ${columns.mkString("[", ", ", "]")}. """
+          .stripMargin.replaceAll("\n", " ") + e.toString)
+    }
+  }
+
+  private[feature] def getLengths(
+      dataset: Dataset[_],
+      columns: Seq[String],
+      handleInvalid: String): Map[String, Int] = {
+    val groupSizes = columns.map { c =>
+      c -> AttributeGroup.fromStructField(dataset.schema(c)).size
+    }.toMap
+    val missingColumns = groupSizes.filter(_._2 == -1).keys.toSeq
+    val firstSizes = (missingColumns.nonEmpty, handleInvalid) match {
+      case (true, VectorAssembler.ERROR_INVALID) =>
+        getVectorLengthsFromFirstRow(dataset, missingColumns)
+      case (true, VectorAssembler.SKIP_INVALID) =>
+        getVectorLengthsFromFirstRow(dataset.na.drop(missingColumns), missingColumns)
+      case (true, VectorAssembler.KEEP_INVALID) => throw new RuntimeException(
+        s"""Can not infer column lengths with handleInvalid = "keep". Consider using VectorSizeHint
+           |to add metadata for columns: ${columns.mkString("[", ", ", "]")}."""
+          .stripMargin.replaceAll("\n", " "))
+      case (_, _) => Map.empty
+    }
+    groupSizes ++ firstSizes
+  }
+
+
   @Since("1.6.0")
   override def load(path: String): VectorAssembler = super.load(path)
 
-  private[feature] def assemble(vv: Any*): Vector = {
-    val indices = ArrayBuilder.make[Int]
-    val values = ArrayBuilder.make[Double]
-    var cur = 0
+  /**
+   * Returns a function that has the required information to assemble each row.
+   * @param lengths an array of lengths of input columns, whose size should be equal to the number
+   *                of cells in the row (vv)
+   * @param keepInvalid whether to substitute NaN values for nulls (true) or throw an error (false)
+   * @return a function that can be applied to each row of the dataset
+   */
+  private[feature] def assemble(lengths: Array[Int], keepInvalid: Boolean)(vv: Any*): Vector = {
+    val indices = mutable.ArrayBuilder.make[Int]
+    val values = mutable.ArrayBuilder.make[Double]
+    var featureIndex = 0
+
+    var inputColumnIndex = 0
     vv.foreach {
       case v: Double =>
-        if (v != 0.0) {
-          indices += cur
+        if (v.isNaN && !keepInvalid) {
+          throw new SparkException(
+            s"""Encountered NaN while assembling a row with handleInvalid = "error". Consider
+               |removing NaNs from dataset or using handleInvalid = "keep" or "skip"."""
+              .stripMargin)
+        } else if (v != 0.0) {
+          indices += featureIndex
           values += v
         }
-        cur += 1
+        inputColumnIndex += 1
+        featureIndex += 1
       case vec: Vector =>
         vec.foreachActive { case (i, v) =>
           if (v != 0.0) {
-            indices += cur + i
+            indices += featureIndex + i
             values += v
           }
         }
-        cur += vec.size
+        inputColumnIndex += 1
+        featureIndex += vec.size
       case null =>
-        // TODO: output Double.NaN?
-        throw new SparkException("Values to assemble cannot be null.")
+        if (keepInvalid) {
+          val length: Int = lengths(inputColumnIndex)
+          Array.range(0, length).foreach { i =>
+            indices += featureIndex + i
+            values += Double.NaN
+          }
+          inputColumnIndex += 1
+          featureIndex += length
+        } else {
+          throw new SparkException(
+            s"""Encountered null while assembling a row with handleInvalid = "error". Consider
+               |removing nulls from dataset or using handleInvalid = "keep" or "skip"."""
+              .stripMargin)
+        }
       case o =>
         throw new SparkException(s"$o of type ${o.getClass.getName} is not supported.")
     }
-    Vectors.sparse(cur, indices.result(), values.result()).compressed
+    Vectors.sparse(featureIndex, indices.result(), values.result()).compressed
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/a1351828/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
index eca065f..91fb24a 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
@@ -18,12 +18,12 @@
 package org.apache.spark.ml.feature
 
 import org.apache.spark.{SparkException, SparkFunSuite}
-import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute}
+import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute, NumericAttribute}
 import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
 import org.apache.spark.ml.param.ParamsSuite
 import org.apache.spark.ml.util.DefaultReadWriteTest
 import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.Row
+import org.apache.spark.sql.{Dataset, Row}
 import org.apache.spark.sql.functions.{col, udf}
 
 class VectorAssemblerSuite
@@ -31,30 +31,49 @@ class VectorAssemblerSuite
 
   import testImplicits._
 
+  @transient var dfWithNullsAndNaNs: Dataset[_] = _
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    val sv = Vectors.sparse(2, Array(1), Array(3.0))
+    dfWithNullsAndNaNs = Seq[(Long, Long, java.lang.Double, Vector, String, Vector, Long, String)](
+      (1, 2, 0.0, Vectors.dense(1.0, 2.0), "a", sv, 7L, null),
+      (2, 1, 0.0, null, "a", sv, 6L, null),
+      (3, 3, null, Vectors.dense(1.0, 2.0), "a", sv, 8L, null),
+      (4, 4, null, null, "a", sv, 9L, null),
+      (5, 5, java.lang.Double.NaN, Vectors.dense(1.0, 2.0), "a", sv, 7L, null),
+      (6, 6, java.lang.Double.NaN, null, "a", sv, 8L, null))
+      .toDF("id1", "id2", "x", "y", "name", "z", "n", "nulls")
+  }
+
   test("params") {
     ParamsSuite.checkParams(new VectorAssembler)
   }
 
   test("assemble") {
     import org.apache.spark.ml.feature.VectorAssembler.assemble
-    assert(assemble(0.0) === Vectors.sparse(1, Array.empty, Array.empty))
-    assert(assemble(0.0, 1.0) === Vectors.sparse(2, Array(1), Array(1.0)))
+    assert(assemble(Array(1), keepInvalid = true)(0.0)
+      === Vectors.sparse(1, Array.empty, Array.empty))
+    assert(assemble(Array(1, 1), keepInvalid = true)(0.0, 1.0)
+      === Vectors.sparse(2, Array(1), Array(1.0)))
     val dv = Vectors.dense(2.0, 0.0)
-    assert(assemble(0.0, dv, 1.0) === Vectors.sparse(4, Array(1, 3), Array(2.0, 1.0)))
+    assert(assemble(Array(1, 2, 1), keepInvalid = true)(0.0, dv, 1.0) ===
+      Vectors.sparse(4, Array(1, 3), Array(2.0, 1.0)))
     val sv = Vectors.sparse(2, Array(0, 1), Array(3.0, 4.0))
-    assert(assemble(0.0, dv, 1.0, sv) ===
+    assert(assemble(Array(1, 2, 1, 2), keepInvalid = true)(0.0, dv, 1.0, sv) ===
       Vectors.sparse(6, Array(1, 3, 4, 5), Array(2.0, 1.0, 3.0, 4.0)))
-    for (v <- Seq(1, "a", null)) {
-      intercept[SparkException](assemble(v))
-      intercept[SparkException](assemble(1.0, v))
+    for (v <- Seq(1, "a")) {
+      intercept[SparkException](assemble(Array(1), keepInvalid = true)(v))
+      intercept[SparkException](assemble(Array(1, 1), keepInvalid = true)(1.0, v))
     }
   }
 
   test("assemble should compress vectors") {
     import org.apache.spark.ml.feature.VectorAssembler.assemble
-    val v1 = assemble(0.0, 0.0, 0.0, Vectors.dense(4.0))
+    val v1 = assemble(Array(1, 1, 1, 1), keepInvalid = true)(0.0, 0.0, 0.0, Vectors.dense(4.0))
     assert(v1.isInstanceOf[SparseVector])
-    val v2 = assemble(1.0, 2.0, 3.0, Vectors.sparse(1, Array(0), Array(4.0)))
+    val sv = Vectors.sparse(1, Array(0), Array(4.0))
+    val v2 = assemble(Array(1, 1, 1, 1), keepInvalid = true)(1.0, 2.0, 3.0, sv)
     assert(v2.isInstanceOf[DenseVector])
   }
 
@@ -147,4 +166,94 @@ class VectorAssemblerSuite
       .filter(vectorUDF($"features") > 1)
       .count() == 1)
   }
+
+  test("assemble should keep nulls when keepInvalid is true") {
+    import org.apache.spark.ml.feature.VectorAssembler.assemble
+    assert(assemble(Array(1, 1), keepInvalid = true)(1.0, null) === Vectors.dense(1.0, Double.NaN))
+    assert(assemble(Array(1, 2), keepInvalid = true)(1.0, null)
+      === Vectors.dense(1.0, Double.NaN, Double.NaN))
+    assert(assemble(Array(1), keepInvalid = true)(null) === Vectors.dense(Double.NaN))
+    assert(assemble(Array(2), keepInvalid = true)(null) === Vectors.dense(Double.NaN, Double.NaN))
+  }
+
+  test("assemble should throw errors when keepInvalid is false") {
+    import org.apache.spark.ml.feature.VectorAssembler.assemble
+    intercept[SparkException](assemble(Array(1, 1), keepInvalid = false)(1.0, null))
+    intercept[SparkException](assemble(Array(1, 2), keepInvalid = false)(1.0, null))
+    intercept[SparkException](assemble(Array(1), keepInvalid = false)(null))
+    intercept[SparkException](assemble(Array(2), keepInvalid = false)(null))
+  }
+
+  test("get lengths functions") {
+    import org.apache.spark.ml.feature.VectorAssembler._
+    val df = dfWithNullsAndNaNs
+    assert(getVectorLengthsFromFirstRow(df, Seq("y")) === Map("y" -> 2))
+    assert(intercept[NullPointerException](getVectorLengthsFromFirstRow(df.sort("id2"), Seq("y")))
+      .getMessage.contains("VectorSizeHint"))
+    assert(intercept[NoSuchElementException](getVectorLengthsFromFirstRow(df.filter("id1 > 6"),
+      Seq("y"))).getMessage.contains("VectorSizeHint"))
+
+    assert(getLengths(df.sort("id2"), Seq("y"), SKIP_INVALID).exists(_ == "y" -> 2))
+    assert(intercept[NullPointerException](getLengths(df.sort("id2"), Seq("y"), ERROR_INVALID))
+      .getMessage.contains("VectorSizeHint"))
+    assert(intercept[RuntimeException](getLengths(df.sort("id2"), Seq("y"), KEEP_INVALID))
+      .getMessage.contains("VectorSizeHint"))
+  }
+
+  test("Handle Invalid should behave properly") {
+    val assembler = new VectorAssembler()
+      .setInputCols(Array("x", "y", "z", "n"))
+      .setOutputCol("features")
+
+    def runWithMetadata(mode: String, additional_filter: String = "true"): Dataset[_] = {
+      val attributeY = new AttributeGroup("y", 2)
+      val attributeZ = new AttributeGroup(
+        "z",
+        Array[Attribute](
+          NumericAttribute.defaultAttr.withName("foo"),
+          NumericAttribute.defaultAttr.withName("bar")))
+      val dfWithMetadata = dfWithNullsAndNaNs.withColumn("y", col("y"), attributeY.toMetadata())
+        .withColumn("z", col("z"), attributeZ.toMetadata()).filter(additional_filter)
+      val output = assembler.setHandleInvalid(mode).transform(dfWithMetadata)
+      output.collect()
+      output
+    }
+
+    def runWithFirstRow(mode: String): Dataset[_] = {
+      val output = assembler.setHandleInvalid(mode).transform(dfWithNullsAndNaNs)
+      output.collect()
+      output
+    }
+
+    def runWithAllNullVectors(mode: String): Dataset[_] = {
+      val output = assembler.setHandleInvalid(mode)
+        .transform(dfWithNullsAndNaNs.filter("0 == id1 % 2"))
+      output.collect()
+      output
+    }
+
+    // behavior when vector size hint is given
+    assert(runWithMetadata("keep").count() == 6, "should keep all rows")
+    assert(runWithMetadata("skip").count() == 1, "should skip rows with nulls")
+    // should throw error with nulls
+    intercept[SparkException](runWithMetadata("error"))
+    // should throw error with NaNs
+    intercept[SparkException](runWithMetadata("error", additional_filter = "id1 > 4"))
+
+    // behavior when first row has information
+    assert(intercept[RuntimeException](runWithFirstRow("keep").count())
+      .getMessage.contains("VectorSizeHint"), "should suggest to use metadata")
+    assert(runWithFirstRow("skip").count() == 1, "should infer size and skip rows with nulls")
+    intercept[SparkException](runWithFirstRow("error"))
+
+    // behavior when vector column is all null
+    assert(intercept[RuntimeException](runWithAllNullVectors("skip"))
+      .getMessage.contains("VectorSizeHint"), "should suggest to use metadata")
+    assert(intercept[NullPointerException](runWithAllNullVectors("error"))
+      .getMessage.contains("VectorSizeHint"), "should suggest to use metadata")
+
+    // behavior when scalar column is all null
+    assert(runWithMetadata("keep", additional_filter = "id1 > 2").count() == 4)
+  }
+
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org