Posted to commits@spark.apache.org by do...@apache.org on 2019/04/01 20:54:26 UTC

[spark] branch master updated: [SPARK-27329][SQL] Pruning nested field in map of map key and value from object serializers

This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new eaf008a  [SPARK-27329][SQL] Pruning nested field in map of map key and value from object serializers
eaf008a is described below

commit eaf008ad0e6e246665127c283b139a16424f3139
Author: Liang-Chi Hsieh <vi...@gmail.com>
AuthorDate: Mon Apr 1 13:53:55 2019 -0700

    [SPARK-27329][SQL] Pruning nested field in map of map key and value from object serializers
    
    ## What changes were proposed in this pull request?
    
    If an object serializer contains a map whose key or value is itself a map, pruning of nested fields should still work.
    
    Previously, the object serializer pruner did not recursively prune nested fields that sit deep inside a map key or value. This patch addresses that by slightly refactoring the pruning logic: the struct-pruning case is extracted into a named transformer that re-applies itself to the key and value converters of `ExternalMapToCatalyst`, so arbitrarily nested maps are handled.
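    
    At the user level, the fix matters for queries like the following minimal sketch (not part of the patch). It assumes a `SparkSession` named `spark` with `spark.implicits._` imported and `SQLConf.SERIALIZER_NESTED_SCHEMA_PRUNING_ENABLED` set to `true`, mirroring the tests added below:
    
    ```scala
    import spark.implicits._
    
    // A map whose value is itself a map of structs (tuples), two levels deep.
    val mapData = Seq(
      (Map("k" -> Map("k2" -> ("a_1", 11))), 1),
      (Map("k" -> Map("k2" -> ("b_1", 22))), 2))
    val mapDs = mapData.toDS().map(t => (t._1, t._2 + 1))
    
    // Only `_1` of the innermost struct is selected; with this change the
    // serializer for the unused Int field is pruned even though it sits
    // inside the value of a nested map.
    val df = mapDs.select("_1.k.k2._1")
    df.explain(true)  // SerializeFromObject now carries the pruned struct
    ```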
    
    ## How was this patch tested?
    
    Added tests.
    
    Closes #24260 from viirya/SPARK-27329.
    
    Authored-by: Liang-Chi Hsieh <vi...@gmail.com>
    Signed-off-by: Dongjoon Hyun <dh...@apple.com>
---
 .../spark/sql/catalyst/optimizer/objects.scala     | 26 +++++--------
 .../spark/sql/DatasetOptimizationSuite.scala       | 44 +++++++++++++++++-----
 2 files changed, 43 insertions(+), 27 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala
index c48bd8f..216c125 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala
@@ -175,30 +175,22 @@ object ObjectSerializerPruning extends Rule[LogicalPlan] {
       serializer: NamedExpression,
       prunedDataType: DataType): NamedExpression = {
     val prunedStructTypes = collectStructType(prunedDataType, ArrayBuffer.empty[StructType])
-    var structTypeIndex = 0
+      .toIterator
 
-    val transformedSerializer = serializer.transformDown {
+    def transformer: PartialFunction[Expression, Expression] = {
       case m: ExternalMapToCatalyst =>
-        val prunedKeyConverter = m.keyConverter.transformDown {
-          case s: CreateNamedStruct if structTypeIndex < prunedStructTypes.size =>
-            val prunedType = prunedStructTypes(structTypeIndex)
-            structTypeIndex += 1
-            pruneNamedStruct(s, prunedType)
-        }
-        val prunedValueConverter = m.valueConverter.transformDown {
-          case s: CreateNamedStruct if structTypeIndex < prunedStructTypes.size =>
-            val prunedType = prunedStructTypes(structTypeIndex)
-            structTypeIndex += 1
-            pruneNamedStruct(s, prunedType)
-        }
+        val prunedKeyConverter = m.keyConverter.transformDown(transformer)
+        val prunedValueConverter = m.valueConverter.transformDown(transformer)
+
         m.copy(keyConverter = alignNullTypeInIf(prunedKeyConverter),
           valueConverter = alignNullTypeInIf(prunedValueConverter))
-      case s: CreateNamedStruct if structTypeIndex < prunedStructTypes.size =>
-        val prunedType = prunedStructTypes(structTypeIndex)
-        structTypeIndex += 1
+
+      case s: CreateNamedStruct if prunedStructTypes.hasNext =>
+        val prunedType = prunedStructTypes.next()
         pruneNamedStruct(s, prunedType)
     }
 
+    val transformedSerializer = serializer.transformDown(transformer)
     val prunedSerializer = alignNullTypeInIf(transformedSerializer).asInstanceOf[NamedExpression]
 
     if (prunedSerializer.dataType.sameType(prunedDataType)) {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetOptimizationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetOptimizationSuite.scala
index cfbb343..b924254 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetOptimizationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetOptimizationSuite.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql
 
-import org.apache.spark.sql.catalyst.expressions.CreateNamedStruct
+import org.apache.spark.sql.catalyst.expressions.{CreateNamedStruct, Expression}
 import org.apache.spark.sql.catalyst.expressions.objects.ExternalMapToCatalyst
 import org.apache.spark.sql.catalyst.plans.logical.SerializeFromObject
 import org.apache.spark.sql.functions.expr
@@ -47,16 +47,15 @@ class DatasetOptimizationSuite extends QueryTest with SharedSQLContext {
       case s: SerializeFromObject => s
     }.head
 
+    def collectNamedStruct: PartialFunction[Expression, Seq[CreateNamedStruct]] = {
+      case c: CreateNamedStruct => Seq(c)
+      case m: ExternalMapToCatalyst =>
+        m.keyConverter.collect(collectNamedStruct).flatten ++
+          m.valueConverter.collect(collectNamedStruct).flatten
+    }
+
     serializer.serializer.zip(structFields).foreach { case (serializer, fields) =>
-      val structs = serializer.collect {
-        case c: CreateNamedStruct => Seq(c)
-        case m: ExternalMapToCatalyst =>
-          m.keyConverter.collect {
-            case c: CreateNamedStruct => c
-          } ++ m.valueConverter.collect {
-            case c: CreateNamedStruct => c
-          }
-      }.flatten
+      val structs: Seq[CreateNamedStruct] = serializer.collect(collectNamedStruct).flatten
       assert(structs.size == fields.size)
       structs.zip(fields).foreach { case (struct, fieldNames) =>
         assert(struct.names.map(_.toString) == fieldNames)
@@ -142,4 +141,29 @@ class DatasetOptimizationSuite extends QueryTest with SharedSQLContext {
       checkAnswer(df1, Seq(Row("1"), Row("2"), Row("3")))
     }
   }
+
+  test("Pruned nested serializers: map of map value") {
+    withSQLConf(SQLConf.SERIALIZER_NESTED_SCHEMA_PRUNING_ENABLED.key -> "true") {
+      val mapData = Seq(
+        (Map(("k", Map(("k2", ("a_1", 11))))), 1),
+        (Map(("k", Map(("k2", ("b_1", 22))))), 2),
+        (Map(("k", Map(("k2", ("c_1", 33))))), 3))
+      val mapDs = mapData.toDS().map(t => (t._1, t._2 + 1))
+      val df = mapDs.select("_1.k.k2._1")
+      testSerializer(df, Seq(Seq("_1")))
+    }
+  }
+
+  test("Pruned nested serializers: map of map key") {
+    withSQLConf(SQLConf.SERIALIZER_NESTED_SCHEMA_PRUNING_ENABLED.key -> "true") {
+      val mapData = Seq(
+        (Map((Map((("1", 1), "val1")), "a_1")), 1),
+        (Map((Map((("2", 2), "val2")), "b_1")), 2),
+        (Map((Map((("3", 3), "val3")), "c_1")), 3))
+      val mapDs = mapData.toDS().map(t => (t._1, t._2 + 1))
+      val df = mapDs.select(expr("map_keys(map_keys(_1)[0])._1[0]"))
+      testSerializer(df, Seq(Seq("_1")))
+      checkAnswer(df, Seq(Row("1"), Row("2"), Row("3")))
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org