You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by gu...@apache.org on 2022/09/16 13:05:44 UTC

[spark] branch branch-3.2 updated: [SPARK-40470][SQL] Handle GetArrayStructFields and GetMapValue in "arrays_zip" function

This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch branch-3.2
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.2 by this push:
     new 1b84e44fce7 [SPARK-40470][SQL] Handle GetArrayStructFields and GetMapValue in "arrays_zip" function
1b84e44fce7 is described below

commit 1b84e44fce7da84382c4874fe3875d55c6647ddf
Author: Ivan Sadikov <iv...@databricks.com>
AuthorDate: Fri Sep 16 22:05:03 2022 +0900

    [SPARK-40470][SQL] Handle GetArrayStructFields and GetMapValue in "arrays_zip" function
    
    ### What changes were proposed in this pull request?
    
    This is a follow-up for https://github.com/apache/spark/pull/37833.
    
    This PR fixes the field names produced by the `arrays_zip` function when its arguments are `GetArrayStructFields` or `GetMapValue` expressions (see the unit tests for more details).
    
    Before the patch, the field names would be positional indexes, or, as in the `GetArrayStructFields` example, an AnalysisException would be thrown.
    
    ### Why are the changes needed?
    
    Fixes an inconsistency issue in Spark 3.2 and onwards where the fields would be labeled as indexes instead of column names.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    I added unit tests that reproduce the issue and confirmed that the patch fixes them.
    
    Closes #37911 from sadikovi/SPARK-40470.
    
    Authored-by: Ivan Sadikov <iv...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gu...@apache.org>
    (cherry picked from commit 9b0f979141ba2c4124d96bc5da69ea5cac51df0d)
    Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
 .../expressions/collectionOperations.scala         |  4 +-
 .../apache/spark/sql/DataFrameFunctionsSuite.scala | 45 ++++++++++++++++++++++
 2 files changed, 48 insertions(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 6f66450844f..9919b31ca12 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -191,7 +191,9 @@ case class ArraysZip(children: Seq[Expression], names: Seq[Expression])
         case (u: UnresolvedAttribute, _) => Literal(u.nameParts.last)
         case (e: NamedExpression, _) if e.resolved => Literal(e.name)
         case (e: NamedExpression, _) => NamePlaceholder
-        case (e: GetStructField, _) => Literal(e.extractFieldName)
+        case (g: GetStructField, _) => Literal(g.extractFieldName)
+        case (g: GetArrayStructFields, _) => Literal(g.field.name)
+        case (g: GetMapValue, _) => Literal(g.key)
         case (_, idx) => Literal(idx.toString)
       })
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 820d760bd72..8e43de53c34 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -641,6 +641,51 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
     assert(fieldNames.toSeq === Seq("arr_1", "arr_2", "arr_3"))
   }
 
+  test("SPARK-40470: arrays_zip should return field names in GetArrayStructFields") {
+    val df = spark.read.json(Seq(
+      """
+      {
+        "arr": [
+          {
+            "obj": {
+              "nested": {
+                "field1": [1],
+                "field2": [2]
+              }
+            }
+          }
+        ]
+      }
+      """).toDS())
+    // Zip two fields reached through `arr` (an array of structs), then select them by name.
+    val res = df
+      .selectExpr("arrays_zip(arr.obj.nested.field1, arr.obj.nested.field2) as arr")
+      .select(col("arr.field1"), col("arr.field2"))
+    // The selected columns are expected to be named after the original struct fields.
+    val fieldNames = res.schema.fieldNames
+    assert(fieldNames.toSeq === Seq("field1", "field2"))
+    // Each selected column holds one entry per zipped element: [[1]] and [[2]].
+    checkAnswer(res, Row(Seq(Seq(1)), Seq(Seq(2))) :: Nil)
+  }
+
+  test("SPARK-40470: arrays_zip should return field names in GetMapValue") {
+    val df = spark.sql("""
+      select
+        map(
+          'arr_1', array(1, 2),
+          'arr_2', array(3, 4)
+        ) as map_obj
+      """)
+    // Zip the two arrays looked up from the map by key (the GetMapValue case).
+    val res = df.selectExpr("arrays_zip(map_obj.arr_1, map_obj.arr_2) as arr")
+    // The element struct of the zipped array is expected to use the map keys as field names.
+    val fieldNames = res.schema.head.dataType.asInstanceOf[ArrayType]
+      .elementType.asInstanceOf[StructType].fieldNames
+    assert(fieldNames.toSeq === Seq("arr_1", "arr_2"))
+    // Elements are paired positionally: (1, 3) and (2, 4).
+    checkAnswer(res, Row(Seq(Row(1, 3), Row(2, 4))))
+  }
+
   def testSizeOfMap(sizeOfNull: Any): Unit = {
     val df = Seq(
       (Map[Int, Int](1 -> 1, 2 -> 2), "x"),


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org