You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by gu...@apache.org on 2022/09/16 13:05:36 UTC
[spark] branch branch-3.3 updated: [SPARK-40470][SQL] Handle GetArrayStructFields and GetMapValue in "arrays_zip" function
This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch branch-3.3
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.3 by this push:
new b9a514ea051 [SPARK-40470][SQL] Handle GetArrayStructFields and GetMapValue in "arrays_zip" function
b9a514ea051 is described below
commit b9a514ea0519e2da21efe2201c7f888be2640458
Author: Ivan Sadikov <iv...@databricks.com>
AuthorDate: Fri Sep 16 22:05:03 2022 +0900
[SPARK-40470][SQL] Handle GetArrayStructFields and GetMapValue in "arrays_zip" function
### What changes were proposed in this pull request?
This is a follow-up for https://github.com/apache/spark/pull/37833.
The PR fixes column names in `arrays_zip` function for the cases when `GetArrayStructFields` and `GetMapValue` expressions are used (see unit tests for more details).
Before the patch, the column names would be indexes or an AnalysisException would be thrown in the case of `GetArrayStructFields` example.
### Why are the changes needed?
Fixes an inconsistency issue in Spark 3.2 and onwards where the fields would be labeled as indexes instead of column names.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
I added unit tests that reproduce the issue and confirmed that the patch fixes them.
Closes #37911 from sadikovi/SPARK-40470.
Authored-by: Ivan Sadikov <iv...@databricks.com>
Signed-off-by: Hyukjin Kwon <gu...@apache.org>
(cherry picked from commit 9b0f979141ba2c4124d96bc5da69ea5cac51df0d)
Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
.../expressions/collectionOperations.scala | 4 +-
.../apache/spark/sql/DataFrameFunctionsSuite.scala | 45 ++++++++++++++++++++++
2 files changed, 48 insertions(+), 1 deletion(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 05a273763b9..c4bf65bb8ab 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -267,7 +267,9 @@ case class ArraysZip(children: Seq[Expression], names: Seq[Expression])
case (u: UnresolvedAttribute, _) => Literal(u.nameParts.last)
case (e: NamedExpression, _) if e.resolved => Literal(e.name)
case (e: NamedExpression, _) => NamePlaceholder
- case (e: GetStructField, _) => Literal(e.extractFieldName)
+ case (g: GetStructField, _) => Literal(g.extractFieldName)
+ case (g: GetArrayStructFields, _) => Literal(g.field.name)
+ case (g: GetMapValue, _) => Literal(g.key)
case (_, idx) => Literal(idx.toString)
})
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index a9c17045812..697cce9b50d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -740,6 +740,51 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
assert(fieldNames.toSeq === Seq("arr_1", "arr_2", "arr_3"))
}
+ test("SPARK-40470: arrays_zip should return field names in GetArrayStructFields") {
+ val df = spark.read.json(Seq(
+ """
+ {
+ "arr": [
+ {
+ "obj": {
+ "nested": {
+ "field1": [1],
+ "field2": [2]
+ }
+ }
+ }
+ ]
+ }
+ """).toDS())
+
+ val res = df
+ .selectExpr("arrays_zip(arr.obj.nested.field1, arr.obj.nested.field2) as arr")
+ .select(col("arr.field1"), col("arr.field2"))
+
+ val fieldNames = res.schema.fieldNames
+ assert(fieldNames.toSeq === Seq("field1", "field2"))
+
+ checkAnswer(res, Row(Seq(Seq(1)), Seq(Seq(2))) :: Nil)
+ }
+
+ test("SPARK-40470: arrays_zip should return field names in GetMapValue") {
+ val df = spark.sql("""
+ select
+ map(
+ 'arr_1', array(1, 2),
+ 'arr_2', array(3, 4)
+ ) as map_obj
+ """)
+
+ val res = df.selectExpr("arrays_zip(map_obj.arr_1, map_obj.arr_2) as arr")
+
+ val fieldNames = res.schema.head.dataType.asInstanceOf[ArrayType]
+ .elementType.asInstanceOf[StructType].fieldNames
+ assert(fieldNames.toSeq === Seq("arr_1", "arr_2"))
+
+ checkAnswer(res, Row(Seq(Row(1, 3), Row(2, 4))))
+ }
+
def testSizeOfMap(sizeOfNull: Any): Unit = {
val df = Seq(
(Map[Int, Int](1 -> 1, 2 -> 2), "x"),
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org