You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by vi...@apache.org on 2021/04/26 06:40:46 UTC

[spark] branch branch-3.1 updated: [SPARK-35213][SQL] Keep the correct ordering of nested structs in chained withField operations

This is an automated email from the ASF dual-hosted git repository.

viirya pushed a commit to branch branch-3.1
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.1 by this push:
     new eaceb40  [SPARK-35213][SQL] Keep the correct ordering of nested structs in chained withField operations
eaceb40 is described below

commit eaceb4052ee359e5eaa2dd308ed2c9f9f71b0170
Author: Adam Binford <ad...@gmail.com>
AuthorDate: Sun Apr 25 23:39:56 2021 -0700

    [SPARK-35213][SQL] Keep the correct ordering of nested structs in chained withField operations
    
    ### What changes were proposed in this pull request?
    
    Modifies the UpdateFields optimizer rule to fix correctness issues with certain nested and chained withField operations. Examples that reproduce the issue can be found in the new unit tests as well as in the JIRA issue.
    
    ### Why are the changes needed?
    
    Certain withField patterns can cause exceptions or even incorrect results. This appears to be a result of the additional UpdateFields optimization added in https://github.com/apache/spark/pull/29812. That optimization traverses fieldOps in reverse order to take the last one per field, but this can cause nested structs to change order, which leads to mismatches between the schema and the actual data. This change updates the optimization to maintain the initial ordering of nested structs so that it matches the generated schema.
    
    ### Does this PR introduce _any_ user-facing change?
    
    It fixes exceptions and incorrect results for valid uses in the latest Spark release.
    
    ### How was this patch tested?
    
    Added new unit tests for these edge cases.
    
    Closes #32338 from Kimahriman/bug/optimize-with-fields.
    
    Authored-by: Adam Binford <ad...@gmail.com>
    Signed-off-by: Liang-Chi Hsieh <vi...@gmail.com>
    (cherry picked from commit 74afc68e2172cb0dc3567e12a8a2c304bb7ea138)
    Signed-off-by: Liang-Chi Hsieh <vi...@gmail.com>
---
 .../sql/catalyst/optimizer/UpdateFields.scala      | 30 +++++-------
 .../optimizer/OptimizeWithFieldsSuite.scala        | 21 +++++++++
 .../apache/spark/sql/ColumnExpressionSuite.scala   | 55 ++++++++++++++++++++++
 3 files changed, 88 insertions(+), 18 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UpdateFields.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UpdateFields.scala
index 465d2ef..d3127b5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UpdateFields.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UpdateFields.scala
@@ -50,28 +50,22 @@ object OptimizeUpdateFields extends Rule[LogicalPlan] {
       val values = withFields.map(_.valExpr)
 
       val newNames = mutable.ArrayBuffer.empty[String]
-      val newValues = mutable.ArrayBuffer.empty[Expression]
+      val newValues = mutable.HashMap.empty[String, Expression]
+      // Used to remember the casing of the last instance
+      val nameMap = mutable.HashMap.empty[String, String]
 
-      if (caseSensitive) {
-        names.zip(values).reverse.foreach { case (name, value) =>
-          if (!newNames.contains(name)) {
-            newNames += name
-            newValues += value
-          }
-        }
-      } else {
-        val nameSet = mutable.HashSet.empty[String]
-        names.zip(values).reverse.foreach { case (name, value) =>
-          val lowercaseName = name.toLowerCase(Locale.ROOT)
-          if (!nameSet.contains(lowercaseName)) {
-            newNames += name
-            newValues += value
-            nameSet += lowercaseName
-          }
+      names.zip(values).foreach { case (name, value) =>
+        val normalizedName = if (caseSensitive) name else name.toLowerCase(Locale.ROOT)
+        if (nameMap.contains(normalizedName)) {
+          newValues += normalizedName -> value
+        } else {
+          newNames += normalizedName
+          newValues += normalizedName -> value
         }
+        nameMap += normalizedName -> name
       }
 
-      val newWithFields = newNames.reverse.zip(newValues.reverse).map(p => WithField(p._1, p._2))
+      val newWithFields = newNames.map(n => WithField(nameMap(n), newValues(n)))
       UpdateFields(structExpr, newWithFields.toSeq)
 
     case UpdateFields(UpdateFields(struct, fieldOps1), fieldOps2) =>
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeWithFieldsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeWithFieldsSuite.scala
index b093b39..e63742a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeWithFieldsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeWithFieldsSuite.scala
@@ -126,4 +126,25 @@ class OptimizeWithFieldsSuite extends PlanTest {
       comparePlans(optimized, correctAnswer)
     }
   }
+
+  test("SPARK-35213: ensure optimize WithFields maintains correct WithField ordering") {
+    val originalQuery = testRelation
+      .select(
+        Alias(UpdateFields('a,
+          WithField("a1", Literal(3)) ::
+          WithField("b1", Literal(4)) ::
+          WithField("a1", Literal(5)) ::
+          Nil), "out")())
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+    val correctAnswer = testRelation
+      .select(
+        Alias(UpdateFields('a,
+          WithField("a1", Literal(5)) ::
+          WithField("b1", Literal(4)) ::
+          Nil), "out")())
+      .analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index 01b1508..186091d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -1660,6 +1660,61 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
       StructType(Seq(StructField("a", IntegerType, nullable = true))))
   }
 
+  test("SPARK-35213: chained withField operations should have correct schema for new columns") {
+    val df = spark.createDataFrame(
+      sparkContext.parallelize(Row(null) :: Nil),
+      StructType(Seq(StructField("data", NullType))))
+
+    checkAnswer(
+      df.withColumn("data", struct()
+        .withField("a", struct())
+        .withField("b", struct())
+        .withField("a.aa", lit("aa1"))
+        .withField("b.ba", lit("ba1"))
+        .withField("a.ab", lit("ab1"))),
+        Row(Row(Row("aa1", "ab1"), Row("ba1"))) :: Nil,
+        StructType(Seq(
+          StructField("data", StructType(Seq(
+            StructField("a", StructType(Seq(
+              StructField("aa", StringType, nullable = false),
+              StructField("ab", StringType, nullable = false)
+            )), nullable = false),
+            StructField("b", StructType(Seq(
+              StructField("ba", StringType, nullable = false)
+            )), nullable = false)
+          )), nullable = false)
+        ))
+    )
+  }
+
+  test("SPARK-35213: optimized withField operations should maintain correct nested struct " +
+    "ordering") {
+    val df = spark.createDataFrame(
+      sparkContext.parallelize(Row(null) :: Nil),
+      StructType(Seq(StructField("data", NullType))))
+
+    checkAnswer(
+      df.withColumn("data", struct()
+          .withField("a", struct().withField("aa", lit("aa1")))
+          .withField("b", struct().withField("ba", lit("ba1")))
+        )
+        .withColumn("data", col("data").withField("b.bb", lit("bb1")))
+        .withColumn("data", col("data").withField("a.ab", lit("ab1"))),
+        Row(Row(Row("aa1", "ab1"), Row("ba1", "bb1"))) :: Nil,
+        StructType(Seq(
+          StructField("data", StructType(Seq(
+            StructField("a", StructType(Seq(
+              StructField("aa", StringType, nullable = false),
+              StructField("ab", StringType, nullable = false)
+            )), nullable = false),
+            StructField("b", StructType(Seq(
+              StructField("ba", StringType, nullable = false),
+              StructField("bb", StringType, nullable = false)
+            )), nullable = false)
+          )), nullable = false)
+        ))
+    )
+  }
 
   test("dropFields should throw an exception if called on a non-StructType column") {
     intercept[AnalysisException] {

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org