You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by gu...@apache.org on 2020/09/10 07:10:38 UTC

[spark] branch branch-3.0 updated: [SPARK-32819][SQL][3.0] ignoreNullability parameter should be effective recursively

This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.0 by this push:
     new 5708045  [SPARK-32819][SQL][3.0] ignoreNullability parameter should be effective recursively
5708045 is described below

commit 5708045fbf3721e97cad65d9aef9d859fcc2da11
Author: Liang-Chi Hsieh <vi...@gmail.com>
AuthorDate: Thu Sep 10 16:06:54 2020 +0900

    [SPARK-32819][SQL][3.0] ignoreNullability parameter should be effective recursively
    
    ### What changes were proposed in this pull request?
    
    This patch proposes to honor the `ignoreNullability` parameter recursively in the `equalsStructurally` method, so that it also applies to nested types. This backports #29698 to branch-3.0.
    
    ### Why are the changes needed?
    
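    `equalsStructurally` is used to check type equality, and callers can optionally ask it to ignore nullability. However, the `ignoreNullability` parameter is not passed down recursively to nested types, so nested nullability differences still cause a mismatch. This produces a confusing error in which the required type and the actual type are printed identically: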
    
    ```
    data type mismatch: argument 3 requires array<array<string>> type, however ... is of array<array<string>> type.
    ```
    
    when running the query `select aggregate(split('abcdefgh',''), array(array('')), (acc, x) -> array(array( x ) ) )`.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, this fixes a bug that occurs when running user queries such as the one above.
    
    ### How was this patch tested?
    
    Unit tests.
    
    Closes #29705 from viirya/SPARK-32819-3.0.
    
    Authored-by: Liang-Chi Hsieh <vi...@gmail.com>
    Signed-off-by: HyukjinKwon <gu...@apache.org>
---
 .../org/apache/spark/sql/types/DataType.scala      |   8 +-
 .../org/apache/spark/sql/types/DataTypeSuite.scala | 110 ++++++++++++++++++++-
 .../sql-tests/inputs/higher-order-functions.sql    |   3 +
 .../results/ansi/higher-order-functions.sql.out    |  10 +-
 .../results/higher-order-functions.sql.out         |  10 +-
 5 files changed, 132 insertions(+), 9 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
index fe8d7ef..3f70b76 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
@@ -328,19 +328,19 @@ object DataType {
       ignoreNullability: Boolean = false): Boolean = {
     (from, to) match {
       case (left: ArrayType, right: ArrayType) =>
-        equalsStructurally(left.elementType, right.elementType) &&
+        equalsStructurally(left.elementType, right.elementType, ignoreNullability) &&
           (ignoreNullability || left.containsNull == right.containsNull)
 
       case (left: MapType, right: MapType) =>
-        equalsStructurally(left.keyType, right.keyType) &&
-          equalsStructurally(left.valueType, right.valueType) &&
+        equalsStructurally(left.keyType, right.keyType, ignoreNullability) &&
+          equalsStructurally(left.valueType, right.valueType, ignoreNullability) &&
           (ignoreNullability || left.valueContainsNull == right.valueContainsNull)
 
       case (StructType(fromFields), StructType(toFields)) =>
         fromFields.length == toFields.length &&
           fromFields.zip(toFields)
             .forall { case (l, r) =>
-              equalsStructurally(l.dataType, r.dataType) &&
+              equalsStructurally(l.dataType, r.dataType, ignoreNullability) &&
                 (ignoreNullability || l.nullable == r.nullable)
             }
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
index b71dc91..52f80b1 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
@@ -423,10 +423,15 @@ class DataTypeSuite extends SparkFunSuite {
   checkCatalogString(MapType(IntegerType, StringType))
   checkCatalogString(MapType(IntegerType, createStruct(40)))
 
-  def checkEqualsStructurally(from: DataType, to: DataType, expected: Boolean): Unit = {
-    val testName = s"equalsStructurally: (from: $from, to: $to)"
+  def checkEqualsStructurally(
+      from: DataType,
+      to: DataType,
+      expected: Boolean,
+      ignoreNullability: Boolean = false): Unit = {
+    val testName = s"equalsStructurally: (from: $from, to: $to, " +
+      s"ignoreNullability: $ignoreNullability)"
     test(testName) {
-      assert(DataType.equalsStructurally(from, to) === expected)
+      assert(DataType.equalsStructurally(from, to, ignoreNullability) === expected)
     }
   }
 
@@ -453,6 +458,105 @@ class DataTypeSuite extends SparkFunSuite {
     new StructType().add("f1", IntegerType).add("f", new StructType().add("f2", StringType, false)),
     new StructType().add("f2", IntegerType).add("g", new StructType().add("f1", StringType)),
     false)
+  checkEqualsStructurally(
+    new StructType().add("f1", IntegerType).add("f", new StructType().add("f2", StringType, false)),
+    new StructType().add("f2", IntegerType).add("g", new StructType().add("f1", StringType)),
+    true,
+    ignoreNullability = true)
+  checkEqualsStructurally(
+    new StructType().add("f1", IntegerType).add("f", new StructType().add("f2", StringType)),
+    new StructType().add("f2", IntegerType, nullable = false)
+      .add("g", new StructType().add("f1", StringType)),
+    true,
+    ignoreNullability = true)
+
+  checkEqualsStructurally(
+    ArrayType(
+      ArrayType(IntegerType, true), true),
+    ArrayType(
+      ArrayType(IntegerType, true), true),
+    true,
+    ignoreNullability = false)
+
+  checkEqualsStructurally(
+    ArrayType(
+      ArrayType(IntegerType, true), false),
+    ArrayType(
+      ArrayType(IntegerType, true), true),
+    false,
+    ignoreNullability = false)
+
+  checkEqualsStructurally(
+    ArrayType(
+      ArrayType(IntegerType, true), true),
+    ArrayType(
+      ArrayType(IntegerType, false), true),
+    false,
+    ignoreNullability = false)
+
+  checkEqualsStructurally(
+    ArrayType(
+      ArrayType(IntegerType, true), false),
+    ArrayType(
+      ArrayType(IntegerType, true), true),
+    true,
+    ignoreNullability = true)
+
+  checkEqualsStructurally(
+    ArrayType(
+      ArrayType(IntegerType, true), false),
+    ArrayType(
+      ArrayType(IntegerType, false), true),
+    true,
+    ignoreNullability = true)
+
+  checkEqualsStructurally(
+    MapType(
+      ArrayType(IntegerType, true), ArrayType(IntegerType, true), true),
+    MapType(
+      ArrayType(IntegerType, true), ArrayType(IntegerType, true), true),
+    true,
+    ignoreNullability = false)
+
+  checkEqualsStructurally(
+    MapType(
+      ArrayType(IntegerType, true), ArrayType(IntegerType, true), true),
+    MapType(
+      ArrayType(IntegerType, true), ArrayType(IntegerType, true), false),
+    false,
+    ignoreNullability = false)
+
+  checkEqualsStructurally(
+    MapType(
+      ArrayType(IntegerType, true), ArrayType(IntegerType, true), true),
+    MapType(
+      ArrayType(IntegerType, true), ArrayType(IntegerType, false), true),
+    false,
+    ignoreNullability = false)
+
+  checkEqualsStructurally(
+    MapType(
+      ArrayType(IntegerType, true), ArrayType(IntegerType, true), true),
+    MapType(
+      ArrayType(IntegerType, true), ArrayType(IntegerType, true), false),
+    true,
+    ignoreNullability = true)
+
+  checkEqualsStructurally(
+    MapType(
+      ArrayType(IntegerType, true), ArrayType(IntegerType, true), true),
+    MapType(
+      ArrayType(IntegerType, true), ArrayType(IntegerType, false), true),
+    true,
+    ignoreNullability = true)
+
+  checkEqualsStructurally(
+    MapType(
+      ArrayType(IntegerType, false), ArrayType(IntegerType, true), true),
+    MapType(
+      ArrayType(IntegerType, true), ArrayType(IntegerType, true), true),
+    true,
+    ignoreNullability = true)
 
   test("SPARK-25031: MapType should produce current formatted string for complex types") {
     val keyType: DataType = StructType(Seq(
diff --git a/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql
index cfa06ae..73dfa91 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql
@@ -92,3 +92,6 @@ select transform_values(ys, (k, v) -> k + v) as v from nested;
 -- use non reversed keywords: all is non reversed only if !ansi
 select transform(ys, all -> all * all) as v from values (array(32, 97)) as t(ys);
 select transform(ys, (all, i) -> all + i) as v from values (array(32, 97)) as t(ys);
+
+-- SPARK-32819: Aggregate on nested string arrays
+select aggregate(split('abcdefgh',''), array(array('')), (acc, x) -> array(array(x)));
diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/higher-order-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/higher-order-functions.sql.out
index 7bef1ba..6d26fae 100644
--- a/sql/core/src/test/resources/sql-tests/results/ansi/higher-order-functions.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/ansi/higher-order-functions.sql.out
@@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 29
+-- Number of queries: 30
 
 
 -- !query
@@ -282,3 +282,11 @@ no viable alternative at input 'all'(line 1, pos 22)
 == SQL ==
 select transform(ys, (all, i) -> all + i) as v from values (array(32, 97)) as t(ys)
 ----------------------^^^
+
+
+-- !query
+select aggregate(split('abcdefgh',''), array(array('')), (acc, x) -> array(array(x)))
+-- !query schema
+struct<aggregate(split(abcdefgh, , -1), array(array()), lambdafunction(array(array(namedlambdavariable())), namedlambdavariable(), namedlambdavariable()), lambdafunction(namedlambdavariable(), namedlambdavariable())):array<array<string>>>
+-- !query output
+[[""]]
diff --git a/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out
index d35d0d5..7b31b56 100644
--- a/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out
@@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 29
+-- Number of queries: 30
 
 
 -- !query
@@ -270,3 +270,11 @@ select transform(ys, (all, i) -> all + i) as v from values (array(32, 97)) as t(
 struct<v:array<int>>
 -- !query output
 [32,98]
+
+
+-- !query
+select aggregate(split('abcdefgh',''), array(array('')), (acc, x) -> array(array(x)))
+-- !query schema
+struct<aggregate(split(abcdefgh, , -1), array(array()), lambdafunction(array(array(namedlambdavariable())), namedlambdavariable(), namedlambdavariable()), lambdafunction(namedlambdavariable(), namedlambdavariable())):array<array<string>>>
+-- !query output
+[[""]]


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org