You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by do...@apache.org on 2020/07/07 03:13:05 UTC

[spark] branch branch-3.0 updated: [SPARK-32167][SQL] Fix GetArrayStructFields to respect inner field's nullability together

This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.0 by this push:
     new a9ca4ab  [SPARK-32167][SQL] Fix GetArrayStructFields to respect inner field's nullability together
a9ca4ab is described below

commit a9ca4abb092bb5d293fc00842e98839fd2d632cf
Author: Wenchen Fan <we...@databricks.com>
AuthorDate: Mon Jul 6 20:07:33 2020 -0700

    [SPARK-32167][SQL] Fix GetArrayStructFields to respect inner field's nullability together
    
    ### What changes were proposed in this pull request?
    
    Fix nullability of `GetArrayStructFields`. It should consider both the original array's `containsNull` and the inner field's nullability.
    
    ### Why are the changes needed?
    
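    Fix a correctness issue: previously only the array's `containsNull` was passed to `GetArrayStructFields`, so extracting a nullable field from an array with `containsNull = false` reported `containsNull = false` even though the extracted values can be null.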
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. See the added test.
    
    ### How was this patch tested?
    
    A new unit test (UT) and an end-to-end test.
    
    Closes #28992 from cloud-fan/bug.
    
    Authored-by: Wenchen Fan <we...@databricks.com>
    Signed-off-by: Dongjoon Hyun <do...@apache.org>
    (cherry picked from commit 5d296ed39e3dd79ddb10c68657e773adba40a5e0)
    Signed-off-by: Dongjoon Hyun <do...@apache.org>
---
 .../expressions/complexTypeExtractors.scala        |  2 +-
 .../catalyst/expressions/ComplexTypeSuite.scala    | 26 ++++++++++++++++++++++
 .../catalyst/expressions/SelectedFieldSuite.scala  |  8 +++----
 .../org/apache/spark/sql/ComplexTypesSuite.scala   | 11 +++++++++
 4 files changed, 42 insertions(+), 5 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
index 9c600c9d3..89ff4fa 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
@@ -57,7 +57,7 @@ object ExtractValue {
         val fieldName = v.toString
         val ordinal = findField(fields, fieldName, resolver)
         GetArrayStructFields(child, fields(ordinal).copy(name = fieldName),
-          ordinal, fields.length, containsNull)
+          ordinal, fields.length, containsNull || fields(ordinal).nullable)
 
       case (_: ArrayType, _) => GetArrayItem(child, extraction)
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
index 3df7d02..dbe4370 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedExtractValue}
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
@@ -159,6 +160,31 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(getArrayStructFields(nullArrayStruct, "a"), null)
   }
 
+  test("SPARK-32167: nullability of GetArrayStructFields") {
+    val resolver = SQLConf.get.resolver
+
+    val array1 = ArrayType(
+      new StructType().add("a", "int", nullable = true),
+      containsNull = false)
+    val data1 = Literal.create(Seq(Row(null)), array1)
+    val get1 = ExtractValue(data1, Literal("a"), resolver).asInstanceOf[GetArrayStructFields]
+    assert(get1.containsNull)
+
+    val array2 = ArrayType(
+      new StructType().add("a", "int", nullable = false),
+      containsNull = true)
+    val data2 = Literal.create(Seq(null), array2)
+    val get2 = ExtractValue(data2, Literal("a"), resolver).asInstanceOf[GetArrayStructFields]
+    assert(get2.containsNull)
+
+    val array3 = ArrayType(
+      new StructType().add("a", "int", nullable = false),
+      containsNull = false)
+    val data3 = Literal.create(Seq(Row(1)), array3)
+    val get3 = ExtractValue(data3, Literal("a"), resolver).asInstanceOf[GetArrayStructFields]
+    assert(!get3.containsNull)
+  }
+
   test("CreateArray") {
     val intSeq = Seq(5, 10, 15, 20, 25)
     val longSeq = intSeq.map(_.toLong)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SelectedFieldSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SelectedFieldSuite.scala
index 3c826e8..76d6890c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SelectedFieldSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SelectedFieldSuite.scala
@@ -254,13 +254,13 @@ class SelectedFieldSuite extends AnalysisTest {
     StructField("col3", ArrayType(StructType(
       StructField("field1", StructType(
         StructField("subfield1", IntegerType, nullable = false) :: Nil))
-        :: Nil), containsNull = false), nullable = false)
+        :: Nil), containsNull = true), nullable = false)
   }
 
   testSelect(arrayWithStructAndMap, "col3.field2['foo'] as foo") {
     StructField("col3", ArrayType(StructType(
       StructField("field2", MapType(StringType, IntegerType, valueContainsNull = false))
-        :: Nil), containsNull = false), nullable = false)
+        :: Nil), containsNull = true), nullable = false)
   }
 
   //  |-- col1: string (nullable = false)
@@ -471,7 +471,7 @@ class SelectedFieldSuite extends AnalysisTest {
   testSelect(mapWithArrayOfStructKey, "map_keys(col2)[0].field1 as foo") {
     StructField("col2", MapType(
       ArrayType(StructType(
-        StructField("field1", StringType) :: Nil), containsNull = false),
+        StructField("field1", StringType) :: Nil), containsNull = true),
       ArrayType(StructType(
         StructField("field3", StructType(
           StructField("subfield3", IntegerType) ::
@@ -482,7 +482,7 @@ class SelectedFieldSuite extends AnalysisTest {
     StructField("col2", MapType(
       ArrayType(StructType(
         StructField("field2", StructType(
-          StructField("subfield1", IntegerType) :: Nil)) :: Nil), containsNull = false),
+          StructField("subfield1", IntegerType) :: Nil)) :: Nil), containsNull = true),
       ArrayType(StructType(
         StructField("field3", StructType(
           StructField("subfield3", IntegerType) ::
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ComplexTypesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ComplexTypesSuite.scala
index 6b50333..bdcf723 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ComplexTypesSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ComplexTypesSuite.scala
@@ -17,11 +17,15 @@
 
 package org.apache.spark.sql
 
+import scala.collection.JavaConverters._
+
 import org.apache.spark.sql.catalyst.expressions.CreateNamedStruct
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types.{ArrayType, StructType}
 
 class ComplexTypesSuite extends QueryTest with SharedSparkSession {
+  import testImplicits._
 
   override def beforeAll(): Unit = {
     super.beforeAll()
@@ -106,4 +110,11 @@ class ComplexTypesSuite extends QueryTest with SharedSparkSession {
     checkAnswer(df1, Row(10, 12) :: Row(11, 13) :: Nil)
     checkNamedStruct(df.queryExecution.optimizedPlan, expectedCount = 0)
   }
+
+  test("SPARK-32167: get field from an array of struct") {
+    val innerStruct = new StructType().add("i", "int", nullable = true)
+    val schema = new StructType().add("arr", ArrayType(innerStruct, containsNull = false))
+    val df = spark.createDataFrame(List(Row(Seq(Row(1), Row(null)))).asJava, schema)
+    checkAnswer(df.select($"arr".getField("i")), Row(Seq(1, null)))
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org