You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by gu...@apache.org on 2020/09/10 07:10:38 UTC
[spark] branch branch-3.0 updated: [SPARK-32819][SQL][3.0]
ignoreNullability parameter should be effective recursively
This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.0 by this push:
new 5708045 [SPARK-32819][SQL][3.0] ignoreNullability parameter should be effective recursively
5708045 is described below
commit 5708045fbf3721e97cad65d9aef9d859fcc2da11
Author: Liang-Chi Hsieh <vi...@gmail.com>
AuthorDate: Thu Sep 10 16:06:54 2020 +0900
[SPARK-32819][SQL][3.0] ignoreNullability parameter should be effective recursively
### What changes were proposed in this pull request?
This patch makes the `equalsStructurally` method honor the `ignoreNullability` parameter recursively for nested types. It backports #29698 to branch-3.0.
### Why are the changes needed?
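`equalsStructurally` is used to check type equality and can optionally ignore nullability. However, the `ignoreNullability` parameter is not passed down recursively to nested types. As a result, two nested types that differ only in nullability are reported as unequal. The catalog strings in the error hide that difference, which produces a confusing error like: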
```
data type mismatch: argument 3 requires array<array<string>> type, however ... is of array<array<string>> type.
```
when running the query `select aggregate(split('abcdefgh',''), array(array('')), (acc, x) -> array(array( x ) ) )`.
### Does this PR introduce _any_ user-facing change?
Yes. It fixes a bug that caused user queries such as the one above to fail with a spurious data type mismatch error.
### How was this patch tested?
Unit tests.
Closes #29705 from viirya/SPARK-32819-3.0.
Authored-by: Liang-Chi Hsieh <vi...@gmail.com>
Signed-off-by: HyukjinKwon <gu...@apache.org>
---
.../org/apache/spark/sql/types/DataType.scala | 8 +-
.../org/apache/spark/sql/types/DataTypeSuite.scala | 110 ++++++++++++++++++++-
.../sql-tests/inputs/higher-order-functions.sql | 3 +
.../results/ansi/higher-order-functions.sql.out | 10 +-
.../results/higher-order-functions.sql.out | 10 +-
5 files changed, 132 insertions(+), 9 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
index fe8d7ef..3f70b76 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
@@ -328,19 +328,19 @@ object DataType {
ignoreNullability: Boolean = false): Boolean = {
(from, to) match {
case (left: ArrayType, right: ArrayType) =>
- equalsStructurally(left.elementType, right.elementType) &&
+ equalsStructurally(left.elementType, right.elementType, ignoreNullability) &&
(ignoreNullability || left.containsNull == right.containsNull)
case (left: MapType, right: MapType) =>
- equalsStructurally(left.keyType, right.keyType) &&
- equalsStructurally(left.valueType, right.valueType) &&
+ equalsStructurally(left.keyType, right.keyType, ignoreNullability) &&
+ equalsStructurally(left.valueType, right.valueType, ignoreNullability) &&
(ignoreNullability || left.valueContainsNull == right.valueContainsNull)
case (StructType(fromFields), StructType(toFields)) =>
fromFields.length == toFields.length &&
fromFields.zip(toFields)
.forall { case (l, r) =>
- equalsStructurally(l.dataType, r.dataType) &&
+ equalsStructurally(l.dataType, r.dataType, ignoreNullability) &&
(ignoreNullability || l.nullable == r.nullable)
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
index b71dc91..52f80b1 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
@@ -423,10 +423,15 @@ class DataTypeSuite extends SparkFunSuite {
checkCatalogString(MapType(IntegerType, StringType))
checkCatalogString(MapType(IntegerType, createStruct(40)))
- def checkEqualsStructurally(from: DataType, to: DataType, expected: Boolean): Unit = {
- val testName = s"equalsStructurally: (from: $from, to: $to)"
+ def checkEqualsStructurally(
+ from: DataType,
+ to: DataType,
+ expected: Boolean,
+ ignoreNullability: Boolean = false): Unit = {
+ val testName = s"equalsStructurally: (from: $from, to: $to, " +
+ s"ignoreNullability: $ignoreNullability)"
test(testName) {
- assert(DataType.equalsStructurally(from, to) === expected)
+ assert(DataType.equalsStructurally(from, to, ignoreNullability) === expected)
}
}
@@ -453,6 +458,105 @@ class DataTypeSuite extends SparkFunSuite {
new StructType().add("f1", IntegerType).add("f", new StructType().add("f2", StringType, false)),
new StructType().add("f2", IntegerType).add("g", new StructType().add("f1", StringType)),
false)
+ checkEqualsStructurally(
+ new StructType().add("f1", IntegerType).add("f", new StructType().add("f2", StringType, false)),
+ new StructType().add("f2", IntegerType).add("g", new StructType().add("f1", StringType)),
+ true,
+ ignoreNullability = true)
+ checkEqualsStructurally(
+ new StructType().add("f1", IntegerType).add("f", new StructType().add("f2", StringType)),
+ new StructType().add("f2", IntegerType, nullable = false)
+ .add("g", new StructType().add("f1", StringType)),
+ true,
+ ignoreNullability = true)
+
+ checkEqualsStructurally(
+ ArrayType(
+ ArrayType(IntegerType, true), true),
+ ArrayType(
+ ArrayType(IntegerType, true), true),
+ true,
+ ignoreNullability = false)
+
+ checkEqualsStructurally(
+ ArrayType(
+ ArrayType(IntegerType, true), false),
+ ArrayType(
+ ArrayType(IntegerType, true), true),
+ false,
+ ignoreNullability = false)
+
+ checkEqualsStructurally(
+ ArrayType(
+ ArrayType(IntegerType, true), true),
+ ArrayType(
+ ArrayType(IntegerType, false), true),
+ false,
+ ignoreNullability = false)
+
+ checkEqualsStructurally(
+ ArrayType(
+ ArrayType(IntegerType, true), false),
+ ArrayType(
+ ArrayType(IntegerType, true), true),
+ true,
+ ignoreNullability = true)
+
+ checkEqualsStructurally(
+ ArrayType(
+ ArrayType(IntegerType, true), false),
+ ArrayType(
+ ArrayType(IntegerType, false), true),
+ true,
+ ignoreNullability = true)
+
+ checkEqualsStructurally(
+ MapType(
+ ArrayType(IntegerType, true), ArrayType(IntegerType, true), true),
+ MapType(
+ ArrayType(IntegerType, true), ArrayType(IntegerType, true), true),
+ true,
+ ignoreNullability = false)
+
+ checkEqualsStructurally(
+ MapType(
+ ArrayType(IntegerType, true), ArrayType(IntegerType, true), true),
+ MapType(
+ ArrayType(IntegerType, true), ArrayType(IntegerType, true), false),
+ false,
+ ignoreNullability = false)
+
+ checkEqualsStructurally(
+ MapType(
+ ArrayType(IntegerType, true), ArrayType(IntegerType, true), true),
+ MapType(
+ ArrayType(IntegerType, true), ArrayType(IntegerType, false), true),
+ false,
+ ignoreNullability = false)
+
+ checkEqualsStructurally(
+ MapType(
+ ArrayType(IntegerType, true), ArrayType(IntegerType, true), true),
+ MapType(
+ ArrayType(IntegerType, true), ArrayType(IntegerType, true), false),
+ true,
+ ignoreNullability = true)
+
+ checkEqualsStructurally(
+ MapType(
+ ArrayType(IntegerType, true), ArrayType(IntegerType, true), true),
+ MapType(
+ ArrayType(IntegerType, true), ArrayType(IntegerType, false), true),
+ true,
+ ignoreNullability = true)
+
+ checkEqualsStructurally(
+ MapType(
+ ArrayType(IntegerType, false), ArrayType(IntegerType, true), true),
+ MapType(
+ ArrayType(IntegerType, true), ArrayType(IntegerType, true), true),
+ true,
+ ignoreNullability = true)
test("SPARK-25031: MapType should produce current formatted string for complex types") {
val keyType: DataType = StructType(Seq(
diff --git a/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql
index cfa06ae..73dfa91 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql
@@ -92,3 +92,6 @@ select transform_values(ys, (k, v) -> k + v) as v from nested;
-- use non reversed keywords: all is non reversed only if !ansi
select transform(ys, all -> all * all) as v from values (array(32, 97)) as t(ys);
select transform(ys, (all, i) -> all + i) as v from values (array(32, 97)) as t(ys);
+
+-- SPARK-32819: Aggregate on nested string arrays
+select aggregate(split('abcdefgh',''), array(array('')), (acc, x) -> array(array(x)));
diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/higher-order-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/higher-order-functions.sql.out
index 7bef1ba..6d26fae 100644
--- a/sql/core/src/test/resources/sql-tests/results/ansi/higher-order-functions.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/ansi/higher-order-functions.sql.out
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
--- Number of queries: 29
+-- Number of queries: 30
-- !query
@@ -282,3 +282,11 @@ no viable alternative at input 'all'(line 1, pos 22)
== SQL ==
select transform(ys, (all, i) -> all + i) as v from values (array(32, 97)) as t(ys)
----------------------^^^
+
+
+-- !query
+select aggregate(split('abcdefgh',''), array(array('')), (acc, x) -> array(array(x)))
+-- !query schema
+struct<aggregate(split(abcdefgh, , -1), array(array()), lambdafunction(array(array(namedlambdavariable())), namedlambdavariable(), namedlambdavariable()), lambdafunction(namedlambdavariable(), namedlambdavariable())):array<array<string>>>
+-- !query output
+[[""]]
diff --git a/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out
index d35d0d5..7b31b56 100644
--- a/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
--- Number of queries: 29
+-- Number of queries: 30
-- !query
@@ -270,3 +270,11 @@ select transform(ys, (all, i) -> all + i) as v from values (array(32, 97)) as t(
struct<v:array<int>>
-- !query output
[32,98]
+
+
+-- !query
+select aggregate(split('abcdefgh',''), array(array('')), (acc, x) -> array(array(x)))
+-- !query schema
+struct<aggregate(split(abcdefgh, , -1), array(array()), lambdafunction(array(array(namedlambdavariable())), namedlambdavariable(), namedlambdavariable()), lambdafunction(namedlambdavariable(), namedlambdavariable())):array<array<string>>>
+-- !query output
+[[""]]
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org