Posted to commits@spark.apache.org by ya...@apache.org on 2020/08/02 23:58:14 UTC

[spark] branch branch-2.4 updated: [SPARK-28818][SQL][2.4] Respect source column nullability in the arrays created by `freqItems()`

This is an automated email from the ASF dual-hosted git repository.

yamamuro pushed a commit to branch branch-2.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-2.4 by this push:
     new 91f2a25  [SPARK-28818][SQL][2.4] Respect source column nullability in the arrays created by `freqItems()`
91f2a25 is described below

commit 91f2a2548ad0f825fc4b5c67264e11abb76bbd9d
Author: Matt Hawes <mh...@palantir.com>
AuthorDate: Mon Aug 3 08:55:28 2020 +0900

    [SPARK-28818][SQL][2.4] Respect source column nullability in the arrays created by `freqItems()`
    
    ### What changes were proposed in this pull request?
    This PR replaces the hard-coded non-nullability of the array elements returned by `freqItems()` with a nullability that reflects the original schema. Essentially [the functional change](https://github.com/apache/spark/pull/25575/files#diff-bf59bb9f3dc351f5bf6624e5edd2dcf4R122) to the schema generation is:
    ```
    StructField(name + "_freqItems", ArrayType(dataType, false))
    ```
    Becomes:
    ```
    StructField(name + "_freqItems", ArrayType(dataType, originalField.nullable))
    ```
    
    Respecting the original nullability prevents issues when Spark depends on `ArrayType`'s `containsNull` being accurate. The example that uncovered this is calling `collect()` on the DataFrame (see the [ticket](https://issues.apache.org/jira/browse/SPARK-28818) for a full repro), though it is likely that there are several places where this could cause a problem.
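    
    A minimal repro sketch, assuming a `spark-shell` session where `spark` is the active `SparkSession` (the column name is illustrative, not part of this patch):
    ```
    import org.apache.spark.sql.Row
    import org.apache.spark.sql.types.{StringType, StructField, StructType}
    
    val rows = spark.sparkContext.parallelize(Seq(Row("a"), Row(null), Row("b")))
    val schema = StructType(Seq(StructField("nullable_col", StringType, nullable = true)))
    val df = spark.createDataFrame(rows, schema)
    
    // Before this fix, the result schema declared containsNull = false for the
    // output array even though it can hold null, so collect() failed here.
    df.stat.freqItems(df.columns).collect()
    ```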
    
    I've also refactored a small amount of the surrounding code to remove some unnecessary steps and group together related operations.
    
    Note: This is the backport PR of #25575, and the credit should go to MGHawes.
    
    ### Why are the changes needed?
    It fixes a bug that currently prevents users from calling `collect()` on the result of `freqItems()` and that potentially causes other, as yet unknown, issues.
    
    ### Does this PR introduce any user-facing change?
    Yes. After the change, the nullability of the source columns is respected in the arrays returned by `freqItems()`.
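    
    For example, the element nullability of the result arrays now follows the source columns (a sketch using the same illustrative setup as above):
    ```
    import org.apache.spark.sql.Row
    import org.apache.spark.sql.types.{ArrayType, StringType, StructField, StructType}
    
    val df = spark.createDataFrame(
      spark.sparkContext.parallelize(Seq(Row("a"), Row(null), Row("b"))),
      StructType(Seq(StructField("nullable_col", StringType, nullable = true))))
    
    val elementType = df.stat.freqItems(df.columns)
      .schema("nullable_col_freqItems").dataType.asInstanceOf[ArrayType]
    // Previously containsNull was hard-coded to false; it now mirrors the source column.
    assert(elementType.containsNull)
    ```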
    
    ### How was this patch tested?
    I added a test that specifically checks that nullability is carried through and that explicitly calls `collect()` to catch the exact regression that was observed. I also ran the test against the old version of the code, and it fails as expected.
    
    Closes #29327 from maropu/SPARK-28818-2.4.
    
    Lead-authored-by: Matt Hawes <mh...@palantir.com>
    Co-authored-by: Takeshi Yamamuro <ya...@apache.org>
    Signed-off-by: Takeshi Yamamuro <ya...@apache.org>
---
 .../spark/sql/execution/stat/FrequentItems.scala   | 19 ++++++++--------
 .../org/apache/spark/sql/DataFrameStatSuite.scala  | 26 +++++++++++++++++++++-
 2 files changed, 35 insertions(+), 10 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala
index 86f6307..f21efd4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala
@@ -89,11 +89,6 @@ object FrequentItems extends Logging {
     // number of max items to keep counts for
     val sizeOfMap = (1 / support).toInt
     val countMaps = Seq.tabulate(numCols)(i => new FreqItemCounter(sizeOfMap))
-    val originalSchema = df.schema
-    val colInfo: Array[(String, DataType)] = cols.map { name =>
-      val index = originalSchema.fieldIndex(name)
-      (name, originalSchema.fields(index).dataType)
-    }.toArray
 
     val freqItems = df.select(cols.map(Column(_)) : _*).rdd.treeAggregate(countMaps)(
       seqOp = (counts, row) => {
@@ -117,10 +112,16 @@ object FrequentItems extends Logging {
     )
     val justItems = freqItems.map(m => m.baseMap.keys.toArray)
     val resultRow = Row(justItems : _*)
-    // append frequent Items to the column name for easy debugging
-    val outputCols = colInfo.map { v =>
-      StructField(v._1 + "_freqItems", ArrayType(v._2, false))
-    }
+
+    val originalSchema = df.schema
+    val outputCols = cols.map { name =>
+      val index = originalSchema.fieldIndex(name)
+      val originalField = originalSchema.fields(index)
+
+      // append frequent Items to the column name for easy debugging
+      StructField(name + "_freqItems", ArrayType(originalField.dataType, originalField.nullable))
+    }.toArray
+
     val schema = StructType(outputCols).toAttributes
     Dataset.ofRows(df.sparkSession, LocalRelation.fromExternalRows(schema, Seq(resultRow)))
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
index 8eae353..23a1fc4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.execution.stat.StatFunctions
 import org.apache.spark.sql.functions.col
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
-import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
+import org.apache.spark.sql.types.{ArrayType, DoubleType, StringType, StructField, StructType}
 
 class DataFrameStatSuite extends QueryTest with SharedSQLContext {
   import testImplicits._
@@ -366,6 +366,30 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext {
     }
   }
 
+  test("SPARK-28818: Respect original column nullability in `freqItems`") {
+    val rows = spark.sparkContext.parallelize(
+      Seq(Row("1", "a"), Row("2", null), Row("3", "b"))
+    )
+    val schema = StructType(Seq(
+      StructField("non_null", StringType, false),
+      StructField("nullable", StringType, true)
+    ))
+    val df = spark.createDataFrame(rows, schema)
+
+    val result = df.stat.freqItems(df.columns)
+
+    val nonNullableDataType = result.schema("non_null_freqItems").dataType.asInstanceOf[ArrayType]
+    val nullableDataType = result.schema("nullable_freqItems").dataType.asInstanceOf[ArrayType]
+
+    assert(nonNullableDataType.containsNull == false)
+    assert(nullableDataType.containsNull == true)
+    // Original bug was a NullPointerException caused by calling collect(); test for this
+    val resultRow = result.collect()(0)
+
+    assert(resultRow.get(0).asInstanceOf[Seq[String]].toSet == Set("1", "2", "3"))
+    assert(resultRow.get(1).asInstanceOf[Seq[String]].toSet == Set("a", "b", null))
+  }
+
   test("sampleBy") {
     val df = spark.range(0, 100).select((col("id") % 3).as("key"))
     val sampled = df.stat.sampleBy("key", Map(0 -> 0.1, 1 -> 0.2), 0L)

