Posted to commits@spark.apache.org by ya...@apache.org on 2020/08/02 23:58:14 UTC
[spark] branch branch-2.4 updated: [SPARK-28818][SQL][2.4] Respect source column nullability in the arrays created by `freqItems()`
This is an automated email from the ASF dual-hosted git repository.
yamamuro pushed a commit to branch branch-2.4
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-2.4 by this push:
new 91f2a25 [SPARK-28818][SQL][2.4] Respect source column nullability in the arrays created by `freqItems()`
91f2a25 is described below
commit 91f2a2548ad0f825fc4b5c67264e11abb76bbd9d
Author: Matt Hawes <mh...@palantir.com>
AuthorDate: Mon Aug 3 08:55:28 2020 +0900
[SPARK-28818][SQL][2.4] Respect source column nullability in the arrays created by `freqItems()`
### What changes were proposed in this pull request?
This PR replaces the hard-coded non-nullability of the array elements returned by `freqItems()` with a nullability that reflects the original schema. Essentially [the functional change](https://github.com/apache/spark/pull/25575/files#diff-bf59bb9f3dc351f5bf6624e5edd2dcf4R122) to the schema generation is:
```
StructField(name + "_freqItems", ArrayType(dataType, false))
```
Becomes:
```
StructField(name + "_freqItems", ArrayType(dataType, originalField.nullable))
```
Respecting the original nullability prevents issues when Spark depends on `ArrayType`'s `containsNull` flag being accurate. The example that uncovered this is calling `collect()` on the DataFrame (see the [ticket](https://issues.apache.org/jira/browse/SPARK-28818) for a full repro), though it is likely that there are several other places where this could cause a problem.
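For reference, here is a minimal sketch of the failure mode from the ticket (illustrative only, not part of this patch; it assumes a local `SparkSession` and uses made-up column names):
```
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._

// A nullable string column whose frequent items include null.
val df = Seq(Some("a"), None, Some("a"), Some("b")).toDF("letters")

// Before this fix, the result schema hard-coded
// ArrayType(..., containsNull = false) even though the frequent-items
// array can contain null, so collect() failed with a NullPointerException.
val freq = df.stat.freqItems(Seq("letters"))
freq.collect().foreach(println)
```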
I've also refactored a small amount of the surrounding code to remove some unnecessary steps and group together related operations.
Note: This is the backport PR of #25575, and the credit should go to MGHawes.
### Why are the changes needed?
It fixes a bug that currently prevents users from calling `df.stat.freqItems(...).collect()`, and the incorrect nullability could cause other, as yet unknown, issues.
### Does this PR introduce any user-facing change?
Yes. The nullability of the source columns is now respected in the arrays returned by `freqItems()`.
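A quick, illustrative way to observe the new behaviour (reusing the `freq` result from the sketch above):
```
import org.apache.spark.sql.types.ArrayType

// The element nullability of the output array now mirrors the source column;
// before the fix it was always false.
val arrayType = freq.schema("letters_freqItems").dataType.asInstanceOf[ArrayType]
println(arrayType.containsNull)  // true, because `letters` was nullable
```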
### How was this patch tested?
I added a test that specifically checks that the nullability is carried through, and that explicitly calls `collect()` to catch the exact regression that was observed. I also ran the test against the old version of the code, and it fails as expected.
Closes #29327 from maropu/SPARK-28818-2.4.
Lead-authored-by: Matt Hawes <mh...@palantir.com>
Co-authored-by: Takeshi Yamamuro <ya...@apache.org>
Signed-off-by: Takeshi Yamamuro <ya...@apache.org>
---
.../spark/sql/execution/stat/FrequentItems.scala | 19 ++++++++--------
.../org/apache/spark/sql/DataFrameStatSuite.scala | 26 +++++++++++++++++++++-
2 files changed, 35 insertions(+), 10 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala
index 86f6307..f21efd4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala
@@ -89,11 +89,6 @@ object FrequentItems extends Logging {
// number of max items to keep counts for
val sizeOfMap = (1 / support).toInt
val countMaps = Seq.tabulate(numCols)(i => new FreqItemCounter(sizeOfMap))
- val originalSchema = df.schema
- val colInfo: Array[(String, DataType)] = cols.map { name =>
- val index = originalSchema.fieldIndex(name)
- (name, originalSchema.fields(index).dataType)
- }.toArray
val freqItems = df.select(cols.map(Column(_)) : _*).rdd.treeAggregate(countMaps)(
seqOp = (counts, row) => {
@@ -117,10 +112,16 @@ object FrequentItems extends Logging {
)
val justItems = freqItems.map(m => m.baseMap.keys.toArray)
val resultRow = Row(justItems : _*)
- // append frequent Items to the column name for easy debugging
- val outputCols = colInfo.map { v =>
- StructField(v._1 + "_freqItems", ArrayType(v._2, false))
- }
+
+ val originalSchema = df.schema
+ val outputCols = cols.map { name =>
+ val index = originalSchema.fieldIndex(name)
+ val originalField = originalSchema.fields(index)
+
+ // append frequent Items to the column name for easy debugging
+ StructField(name + "_freqItems", ArrayType(originalField.dataType, originalField.nullable))
+ }.toArray
+
val schema = StructType(outputCols).toAttributes
Dataset.ofRows(df.sparkSession, LocalRelation.fromExternalRows(schema, Seq(resultRow)))
}
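As a standalone illustration of the schema rule this hunk implements (illustrative code with made-up field names, not the actual `FrequentItems` internals), each output column copies the source field's `dataType` and `nullable` flag into the element type of the result array:
```
import org.apache.spark.sql.types._

val originalSchema = StructType(Seq(
  StructField("non_null", StringType, nullable = false),
  StructField("nullable", StringType, nullable = true)
))

val outputCols = originalSchema.fields.map { f =>
  StructField(f.name + "_freqItems", ArrayType(f.dataType, f.nullable))
}
// non_null_freqItems -> ArrayType(StringType, containsNull = false)
// nullable_freqItems -> ArrayType(StringType, containsNull = true)
outputCols.foreach(f => println(s"${f.name}: ${f.dataType}"))
```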
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
index 8eae353..23a1fc4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.execution.stat.StatFunctions
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext
-import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
+import org.apache.spark.sql.types.{ArrayType, DoubleType, StringType, StructField, StructType}
class DataFrameStatSuite extends QueryTest with SharedSQLContext {
import testImplicits._
@@ -366,6 +366,30 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext {
}
}
+ test("SPARK-28818: Respect original column nullability in `freqItems`") {
+ val rows = spark.sparkContext.parallelize(
+ Seq(Row("1", "a"), Row("2", null), Row("3", "b"))
+ )
+ val schema = StructType(Seq(
+ StructField("non_null", StringType, false),
+ StructField("nullable", StringType, true)
+ ))
+ val df = spark.createDataFrame(rows, schema)
+
+ val result = df.stat.freqItems(df.columns)
+
+ val nonNullableDataType = result.schema("non_null_freqItems").dataType.asInstanceOf[ArrayType]
+ val nullableDataType = result.schema("nullable_freqItems").dataType.asInstanceOf[ArrayType]
+
+ assert(nonNullableDataType.containsNull == false)
+ assert(nullableDataType.containsNull == true)
+ // The original bug was a NullPointerException caused by calling collect(); test for this
+ val resultRow = result.collect()(0)
+
+ assert(resultRow.get(0).asInstanceOf[Seq[String]].toSet == Set("1", "2", "3"))
+ assert(resultRow.get(1).asInstanceOf[Seq[String]].toSet == Set("a", "b", null))
+ }
+
test("sampleBy") {
val df = spark.range(0, 100).select((col("id") % 3).as("key"))
val sampled = df.stat.sampleBy("key", Map(0 -> 0.1, 1 -> 0.2), 0L)