Posted to commits@spark.apache.org by do...@apache.org on 2019/03/25 22:37:16 UTC

[spark] branch master updated: [SPARK-26847][SQL] Pruning nested serializers from object serializers: MapType support

This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 8433ff6  [SPARK-26847][SQL] Pruning nested serializers from object serializers: MapType support
8433ff6 is described below

commit 8433ff6607a25f5e9c4e685f3b9521d232375265
Author: Liang-Chi Hsieh <vi...@gmail.com>
AuthorDate: Mon Mar 25 15:36:58 2019 -0700

    [SPARK-26847][SQL] Pruning nested serializers from object serializers: MapType support
    
    ## What changes were proposed in this pull request?
    
    In SPARK-26837, we prune nested fields from object serializers when they are not needed by the query. SPARK-26837 left MapType support as a TODO item. This PR adds that support: nested fields are now pruned from the value side of map serializers (map keys are skipped, because a field cannot be selected from a struct used as a map key).
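    
    For illustration (a minimal sketch, not part of this patch), a query that the
    optimization now covers, mirroring the new test below; it assumes
    `spark.implicits._` is in scope and the `SERIALIZER_NESTED_SCHEMA_PRUNING_ENABLED`
    flag is set to true:
    
        // The map value is a struct with two fields, but only `_1` is selected,
        // so the object serializer for the map value can be pruned to struct<_1:string>.
        val mapDs = Seq((Map("k" -> ("a_1", 11)), 1)).toDS().map(t => (t._1, t._2 + 1))
        mapDs.select("_1.k._1").show()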
    
    ## How was this patch tested?
    
    Added tests.
    
    Closes #24158 from viirya/SPARK-26847.
    
    Authored-by: Liang-Chi Hsieh <vi...@gmail.com>
    Signed-off-by: Dongjoon Hyun <dh...@apple.com>
---
 .../spark/sql/catalyst/optimizer/objects.scala     | 76 +++++++++++++---------
 .../optimizer/ObjectSerializerPruningSuite.scala   | 10 ++-
 .../spark/sql/DatasetOptimizationSuite.scala       | 25 ++++++-
 3 files changed, 78 insertions(+), 33 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala
index 96a172c..8e92421 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.objects._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{ArrayType, DataType, StructType}
+import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType, UserDefinedType}
 
 /*
  * This file defines optimization rules related to object manipulation (for the Dataset API).
@@ -121,9 +121,8 @@ object EliminateMapObjects extends Rule[LogicalPlan] {
 object ObjectSerializerPruning extends Rule[LogicalPlan] {
 
   /**
-   * Collects all struct types from given data type object, recursively. Supports struct and array
-   * types for now.
-   * TODO(SPARK-26847): support map type.
+   * Collects all struct types from a given data type object, recursively.
+   * Visible for testing.
    */
   def collectStructType(dt: DataType, structs: ArrayBuffer[StructType]): ArrayBuffer[StructType] = {
     dt match {
@@ -132,17 +131,45 @@ object ObjectSerializerPruning extends Rule[LogicalPlan] {
         fields.map(f => collectStructType(f.dataType, structs))
       case ArrayType(elementType, _) =>
         collectStructType(elementType, structs)
+      case MapType(_, valueType, _) =>
+      // We can't select a field from a struct used as a map key, so we skip the key type.
+        collectStructType(valueType, structs)
+      // We don't use UserDefinedType in those serializers.
+      case _: UserDefinedType[_] =>
       case _ =>
     }
     structs
   }
 
   /**
+   * This method returns a pruned `CreateNamedStruct` expression, given an original
+   * `CreateNamedStruct` and a pruned `StructType`.
+   */
+  private def pruneNamedStruct(struct: CreateNamedStruct, prunedType: StructType) = {
+    // Filters out the pruned fields.
+    val resolver = SQLConf.get.resolver
+    val prunedFields = struct.nameExprs.zip(struct.valExprs).filter { case (nameExpr, _) =>
+      val name = nameExpr.eval(EmptyRow).toString
+      prunedType.fieldNames.exists(resolver(_, name))
+    }.flatMap(pair => Seq(pair._1, pair._2))
+
+    CreateNamedStruct(prunedFields)
+  }
+
+  /**
+   * When we change a nested serializer's data type, the `If` expression becomes unresolved
+   * because its null literal's data type no longer matches. We need to align the literal
+   * with the new data type. Note: we must do `transformUp` explicitly to change data types.
+   */
+  private def alignNullTypeInIf(expr: Expression) = expr.transformUp {
+    case i @ If(_: IsNull, Literal(null, dt), ser) if !dt.sameType(ser.dataType) =>
+      i.copy(trueValue = Literal(null, ser.dataType))
+  }
+
+  /**
   * This method prunes a given serializer expression by a given pruned data type. For example,
   * given a serializer creating struct(a int, b int) and a pruned data type struct(a int),
-   * this method returns pruned serializer creating struct(a int). For now it supports to
-   * prune nested fields in struct and array of struct.
-   * TODO(SPARK-26847): support to prune nested fields in key and value of map type.
+   * this method returns a pruned serializer creating struct(a int).
    */
   def pruneSerializer(
       serializer: NamedExpression,
@@ -150,31 +177,22 @@ object ObjectSerializerPruning extends Rule[LogicalPlan] {
     val prunedStructTypes = collectStructType(prunedDataType, ArrayBuffer.empty[StructType])
     var structTypeIndex = 0
 
-    val prunedSerializer = serializer.transformDown {
+    val transformedSerializer = serializer.transformDown {
+      case m: ExternalMapToCatalyst =>
+        val prunedValueConverter = m.valueConverter.transformDown {
+          case s: CreateNamedStruct if structTypeIndex < prunedStructTypes.size =>
+            val prunedType = prunedStructTypes(structTypeIndex)
+            structTypeIndex += 1
+            pruneNamedStruct(s, prunedType)
+        }
+        m.copy(valueConverter = alignNullTypeInIf(prunedValueConverter))
       case s: CreateNamedStruct if structTypeIndex < prunedStructTypes.size =>
         val prunedType = prunedStructTypes(structTypeIndex)
-
-        // Filters out the pruned fields.
-        val prunedFields = s.nameExprs.zip(s.valExprs).filter { case (nameExpr, _) =>
-          val name = nameExpr.eval(EmptyRow).toString
-          prunedType.fieldNames.exists { fieldName =>
-            if (SQLConf.get.caseSensitiveAnalysis) {
-              fieldName.equals(name)
-            } else {
-              fieldName.equalsIgnoreCase(name)
-            }
-          }
-        }.flatMap(pair => Seq(pair._1, pair._2))
-
         structTypeIndex += 1
-        CreateNamedStruct(prunedFields)
-    }.transformUp {
-      // When we change nested serializer data type, `If` expression will be unresolved because
-      // literal null's data type doesn't match now. We need to align it with new data type.
-      // Note: we should do `transformUp` explicitly to change data types.
-      case i @ If(_: IsNull, Literal(null, dt), ser) if !dt.sameType(ser.dataType) =>
-        i.copy(trueValue = Literal(null, ser.dataType))
-    }.asInstanceOf[NamedExpression]
+        pruneNamedStruct(s, prunedType)
+    }
+
+    val prunedSerializer = alignNullTypeInIf(transformedSerializer).asInstanceOf[NamedExpression]
 
     if (prunedSerializer.dataType.sameType(prunedDataType)) {
       prunedSerializer
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ObjectSerializerPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ObjectSerializerPruningSuite.scala
index dee685a..fb0f3a3 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ObjectSerializerPruningSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ObjectSerializerPruningSuite.scala
@@ -46,7 +46,10 @@ class ObjectSerializerPruningSuite extends PlanTest {
       ArrayType(IntegerType),
       StructType.fromDDL("a int, b int"),
       ArrayType(StructType.fromDDL("a int, b int, c string")),
-      StructType.fromDDL("a struct<a:int, b:int>, b int")
+      StructType.fromDDL("a struct<a:int, b:int>, b int"),
+      MapType(IntegerType, StructType.fromDDL("a int, b int, c string")),
+      MapType(StructType.fromDDL("a struct<a:int, b:int>, b int"), IntegerType),
+      MapType(StructType.fromDDL("a int, b int"), StructType.fromDDL("c long, d string"))
     )
 
     val expectedTypes = Seq(
@@ -55,7 +58,10 @@ class ObjectSerializerPruningSuite extends PlanTest {
       Seq(StructType.fromDDL("a int, b int")),
       Seq(StructType.fromDDL("a int, b int, c string")),
       Seq(StructType.fromDDL("a struct<a:int, b:int>, b int"),
-        StructType.fromDDL("a int, b int"))
+        StructType.fromDDL("a int, b int")),
+      Seq(StructType.fromDDL("a int, b int, c string")),
+      Seq.empty[StructType],
+      Seq(StructType.fromDDL("c long, d string"))
     )
 
     dataTypes.zipWithIndex.foreach { case (dt, idx) =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetOptimizationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetOptimizationSuite.scala
index 2b1dbf0..69634f8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetOptimizationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetOptimizationSuite.scala
@@ -18,7 +18,9 @@
 package org.apache.spark.sql
 
 import org.apache.spark.sql.catalyst.expressions.CreateNamedStruct
+import org.apache.spark.sql.catalyst.expressions.objects.ExternalMapToCatalyst
 import org.apache.spark.sql.catalyst.plans.logical.SerializeFromObject
+import org.apache.spark.sql.functions.expr
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
 
@@ -47,8 +49,12 @@ class DatasetOptimizationSuite extends QueryTest with SharedSQLContext {
 
     serializer.serializer.zip(structFields).foreach { case (serializer, fields) =>
       val structs = serializer.collect {
-        case c: CreateNamedStruct => c
-      }
+        case c: CreateNamedStruct => Seq(c)
+        case m: ExternalMapToCatalyst =>
+          m.valueConverter.collect {
+            case c: CreateNamedStruct => c
+          }
+      }.flatten
       assert(structs.size == fields.size)
       structs.zip(fields).foreach { case (struct, fieldNames) =>
         assert(struct.names.map(_.toString) == fieldNames)
@@ -104,4 +110,19 @@ class DatasetOptimizationSuite extends QueryTest with SharedSQLContext {
         Row(Seq("c", "d"), Seq(33, 44), "bb")))
     }
   }
+
+  test("Prune nested serializers: map of struct") {
+    withSQLConf(SQLConf.SERIALIZER_NESTED_SCHEMA_PRUNING_ENABLED.key -> "true") {
+      val mapData = Seq((Map(("k", ("a_1", 11))), 1), (Map(("k", ("b_1", 22))), 2),
+        (Map(("k", ("c_1", 33))), 3))
+      val mapDs = mapData.toDS().map(t => (t._1, t._2 + 1))
+      val df1 = mapDs.select("_1.k._1")
+      testSerializer(df1, Seq(Seq("_1")))
+      checkAnswer(df1, Seq(Row("a_1"), Row("b_1"), Row("c_1")))
+
+      val df2 = mapDs.select("_1.k._2")
+      testSerializer(df2, Seq(Seq("_2")))
+      checkAnswer(df2, Seq(Row(11), Row(22), Row(33)))
+    }
+  }
 }
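
For reference, a minimal sketch (not part of the patch) of the new `collectStructType`
traversal on a `MapType`, matching the expectations in the updated suite above: struct
types are collected from the map value only, since map keys are skipped:

    import scala.collection.mutable.ArrayBuffer

    import org.apache.spark.sql.catalyst.optimizer.ObjectSerializerPruning.collectStructType
    import org.apache.spark.sql.types.{MapType, StructType}

    val mapOfStructs = MapType(
      StructType.fromDDL("a int, b int"),        // key struct: skipped
      StructType.fromDDL("c long, d string"))    // value struct: collected
    val collected = collectStructType(mapOfStructs, ArrayBuffer.empty[StructType])
    // collected == ArrayBuffer(StructType.fromDDL("c long, d string"))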


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org