You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by vi...@apache.org on 2021/04/26 16:33:12 UTC

[spark] branch master updated: [SPARK-34638][SQL] Single field nested column prune on generator output

This is an automated email from the ASF dual-hosted git repository.

viirya pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new c59988a  [SPARK-34638][SQL] Single field nested column prune on generator output
c59988a is described below

commit c59988aa799a2aab1cd231ae2b1cf90f1a2b1af8
Author: Liang-Chi Hsieh <vi...@gmail.com>
AuthorDate: Mon Apr 26 09:32:22 2021 -0700

    [SPARK-34638][SQL] Single field nested column prune on generator output
    
    ### What changes were proposed in this pull request?
    
    This patch proposes an improvement to nested column pruning when the pruning target is a generator's output. Previously we disallowed such a case. This patch allows pruning on it if only a single nested column is accessed after `Generate`.
    
    E.g., `df.select(explode($"items").as('item)).select($"item.itemId")`. As we only need `itemId` from `item`, we can prune other fields out and only keep `itemId`.
    
    In this patch, we only address explode-like generators. We will address other generators in followups.
    
    ### Why are the changes needed?
    
    This helps to extend the availability of nested column pruning.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Unit test
    
    Closes #31966 from viirya/SPARK-34638.
    
    Authored-by: Liang-Chi Hsieh <vi...@gmail.com>
    Signed-off-by: Liang-Chi Hsieh <vi...@gmail.com>
---
 .../expressions/ProjectionOverSchema.scala         |  3 +-
 .../catalyst/optimizer/NestedColumnAliasing.scala  | 98 +++++++++++++++++++++-
 .../optimizer/NestedColumnAliasingSuite.scala      | 30 ++++++-
 .../execution/datasources/SchemaPruningSuite.scala | 54 ++++++++++++
 4 files changed, 177 insertions(+), 8 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala
index 03b5517..a6be98c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala
@@ -42,7 +42,8 @@ case class ProjectionOverSchema(schema: StructType) {
         getProjection(a.child).map(p => (p, p.dataType)).map {
           case (projection, ArrayType(projSchema @ StructType(_), _)) =>
             // For case-sensitivity aware field resolution, we should take `ordinal` which
-            // points to correct struct field.
+            // points to correct struct field, because `ExtractValue` actually does column
+            // name resolving correctly.
             val selectedField = a.child.dataType.asInstanceOf[ArrayType]
               .elementType.asInstanceOf[StructType](a.ordinal)
             val prunedField = projSchema(selectedField.name)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala
index 0be2792..5b12667 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala
@@ -231,6 +231,27 @@ object NestedColumnAliasing {
  * of it.
  */
 object GeneratorNestedColumnAliasing {
+  // Partitions `attrToAliases` based on whether the attribute is in Generator's output.
+  private def aliasesOnGeneratorOutput(
+      attrToAliases: Map[ExprId, Seq[Alias]],
+      generatorOutput: Seq[Attribute]) = {
+    val generatorOutputExprId = generatorOutput.map(_.exprId)
+    attrToAliases.partition { k =>
+      generatorOutputExprId.contains(k._1)
+    }
+  }
+
+  // Partitions `nestedFieldToAlias` based on whether the attribute of nested field extractor
+  // is in Generator's output.
+  private def nestedFieldOnGeneratorOutput(
+      nestedFieldToAlias: Map[ExtractValue, Alias],
+      generatorOutput: Seq[Attribute]) = {
+    val generatorOutputSet = AttributeSet(generatorOutput)
+    nestedFieldToAlias.partition { pair =>
+      pair._1.references.subsetOf(generatorOutputSet)
+    }
+  }
+
   def unapply(plan: LogicalPlan): Option[LogicalPlan] = plan match {
     // Either `nestedPruningOnExpressions` or `nestedSchemaPruningEnabled` is enabled, we
     // need to prune nested columns through Project and under Generate. The difference is
@@ -241,12 +262,81 @@ object GeneratorNestedColumnAliasing {
       // On top on `Generate`, a `Project` that might have nested column accessors.
       // We try to get alias maps for both project list and generator's children expressions.
       val exprsToPrune = projectList ++ g.generator.children
-      NestedColumnAliasing.getAliasSubMap(exprsToPrune, g.qualifiedGeneratorOutput).map {
+      NestedColumnAliasing.getAliasSubMap(exprsToPrune).map {
         case (nestedFieldToAlias, attrToAliases) =>
+          val (nestedFieldsOnGenerator, nestedFieldsNotOnGenerator) =
+            nestedFieldOnGeneratorOutput(nestedFieldToAlias, g.qualifiedGeneratorOutput)
+          val (attrToAliasesOnGenerator, attrToAliasesNotOnGenerator) =
+            aliasesOnGeneratorOutput(attrToAliases, g.qualifiedGeneratorOutput)
+
+          // Push nested column accessors through `Generator`.
           // Defer updating `Generate.unrequiredChildIndex` to next round of `ColumnPruning`.
-          val newChild =
-            NestedColumnAliasing.replaceWithAliases(g, nestedFieldToAlias, attrToAliases)
-          Project(NestedColumnAliasing.getNewProjectList(projectList, nestedFieldToAlias), newChild)
+          val newChild = NestedColumnAliasing.replaceWithAliases(g,
+            nestedFieldsNotOnGenerator, attrToAliasesNotOnGenerator)
+          val pushedThrough = Project(NestedColumnAliasing
+            .getNewProjectList(projectList, nestedFieldsNotOnGenerator), newChild)
+
+          // If the generator output is `ArrayType`, we cannot push through the extractor.
+          // It is because we don't allow field extractor on two-level array,
+          // i.e., attr.field when attr is an ArrayType(ArrayType(...)).
+          // Similarly, we also cannot push through if the child of generator is `MapType`.
+          g.generator.children.head.dataType match {
+            case _: MapType => return Some(pushedThrough)
+            case ArrayType(_: ArrayType, _) => return Some(pushedThrough)
+            case _ =>
+          }
+
+          // Pruning on `Generator`'s output. We only process single field case.
+          // For multiple field case, we cannot directly move field extractor into
+          // the generator expression. A workaround is to re-construct array of struct
+          // from multiple fields. But it will be more complicated and may not be worth it.
+          // TODO(SPARK-34956): support multiple fields.
+          if (nestedFieldsOnGenerator.size > 1 || nestedFieldsOnGenerator.isEmpty) {
+            pushedThrough
+          } else {
+            // Only one nested column accessor.
+            // E.g., df.select(explode($"items").as("item")).select($"item.a")
+            pushedThrough match {
+              case p @ Project(_, newG: Generate) =>
+                // Replace the child expression of `ExplodeBase` generator with
+                // nested column accessor.
+                // E.g., df.select(explode($"items").as("item")).select($"item.a") =>
+                //       df.select(explode($"items.a").as("item.a"))
+                val rewrittenG = newG.transformExpressions {
+                  case e: ExplodeBase =>
+                    val extractor = nestedFieldsOnGenerator.head._1.transformUp {
+                      case _: Attribute =>
+                        e.child
+                      case g: GetStructField =>
+                        ExtractValue(g.child, Literal(g.extractFieldName), SQLConf.get.resolver)
+                    }
+                    e.withNewChildren(Seq(extractor))
+                }
+
+                // As we change the child of the generator, its output data type must be updated.
+                val updatedGeneratorOutput = rewrittenG.generatorOutput
+                  .zip(rewrittenG.generator.elementSchema.toAttributes)
+                  .map { case (oldAttr, newAttr) =>
+                  newAttr.withExprId(oldAttr.exprId).withName(oldAttr.name)
+                }
+                assert(updatedGeneratorOutput.length == rewrittenG.generatorOutput.length,
+                  "Updated generator output must have the same length " +
+                    "with original generator output.")
+                val updatedGenerate = rewrittenG.copy(generatorOutput = updatedGeneratorOutput)
+
+                // Replace nested column accessor with generator output.
+                p.withNewChildren(Seq(updatedGenerate)).transformExpressions {
+                  case f: ExtractValue if nestedFieldsOnGenerator.contains(f) =>
+                    updatedGenerate.output
+                      .find(a => attrToAliasesOnGenerator.contains(a.exprId))
+                      .getOrElse(f)
+                }
+
+              case other =>
+                // We should not reach here.
+                throw new IllegalStateException(s"Unreasonable plan after optimization: $other")
+            }
+          }
       }
 
     case g: Generate if SQLConf.get.nestedSchemaPruningEnabled &&
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasingSuite.scala
index 0ae4d3f..a856caa 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasingSuite.scala
@@ -329,14 +329,14 @@ class NestedColumnAliasingSuite extends SchemaPruningTest {
     comparePlans(optimized, expected)
   }
 
-  test("Nested field pruning for Project and Generate: not prune on generator output") {
+  test("Nested field pruning for Project and Generate: multiple-field case is not supported") {
     val companies = LocalRelation(
       'id.int,
       'employers.array(employer))
 
     val query = companies
       .generate(Explode('employers.getField("company")), outputNames = Seq("company"))
-      .select('company.getField("name"))
+      .select('company.getField("name"), 'company.getField("address"))
       .analyze
     val optimized = Optimize.execute(query)
 
@@ -347,7 +347,8 @@ class NestedColumnAliasingSuite extends SchemaPruningTest {
       .generate(Explode($"${aliases(0)}"),
         unrequiredChildIndex = Seq(0),
         outputNames = Seq("company"))
-      .select('company.getField("name").as("company.name"))
+      .select('company.getField("name").as("company.name"),
+        'company.getField("address").as("company.address"))
       .analyze
     comparePlans(optimized, expected)
   }
@@ -684,6 +685,29 @@ class NestedColumnAliasingSuite extends SchemaPruningTest {
     ).analyze
     comparePlans(optimized2, expected2)
   }
+
+  test("SPARK-34638: nested column prune on generator output for one field") {
+    val companies = LocalRelation(
+      'id.int,
+      'employers.array(employer))
+
+    val query = companies
+      .generate(Explode('employers.getField("company")), outputNames = Seq("company"))
+      .select('company.getField("name"))
+      .analyze
+    val optimized = Optimize.execute(query)
+
+    val aliases = collectGeneratedAliases(optimized)
+
+    val expected = companies
+      .select('employers.getField("company").getField("name").as(aliases(0)))
+      .generate(Explode($"${aliases(0)}"),
+        unrequiredChildIndex = Seq(0),
+        outputNames = Seq("company"))
+      .select('company.as("company.name"))
+      .analyze
+    comparePlans(optimized, expected)
+  }
 }
 
 object NestedColumnAliasingSuite {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
index 765d2fc..ac5c289 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
@@ -351,6 +351,43 @@ abstract class SchemaPruningSuite
     }
   }
 
+  testSchemaPruning("SPARK-34638: nested column prune on generator output") {
+    val query1 = spark.table("contacts")
+      .select(explode(col("friends")).as("friend"))
+      .select("friend.first")
+    checkScan(query1, "struct<friends:array<struct<first:string>>>")
+    checkAnswer(query1, Row("Susan") :: Nil)
+
+    // Currently we don't prune multiple field case.
+    val query2 = spark.table("contacts")
+      .select(explode(col("friends")).as("friend"))
+      .select("friend.first", "friend.middle")
+    checkScan(query2, "struct<friends:array<struct<first:string,middle:string,last:string>>>")
+    checkAnswer(query2, Row("Susan", "Z.") :: Nil)
+
+    val query3 = spark.table("contacts")
+      .select(explode(col("friends")).as("friend"))
+      .select("friend.first", "friend.middle", "friend")
+    checkScan(query3, "struct<friends:array<struct<first:string,middle:string,last:string>>>")
+    checkAnswer(query3, Row("Susan", "Z.", Row("Susan", "Z.", "Smith")) :: Nil)
+  }
+
+  testSchemaPruning("SPARK-34638: nested column prune on generator output - case-sensitivity") {
+    withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") {
+      val query1 = spark.table("contacts")
+        .select(explode(col("friends")).as("friend"))
+        .select("friend.First")
+      checkScan(query1, "struct<friends:array<struct<first:string>>>")
+      checkAnswer(query1, Row("Susan") :: Nil)
+
+      val query2 = spark.table("contacts")
+        .select(explode(col("friends")).as("friend"))
+        .select("friend.MIDDLE")
+      checkScan(query2, "struct<friends:array<struct<middle:string>>>")
+      checkAnswer(query2, Row("Z.") :: Nil)
+    }
+  }
+
   testSchemaPruning("select one deep nested complex field after repartition") {
     val query = sql("select * from contacts")
       .repartition(100)
@@ -816,4 +853,21 @@ abstract class SchemaPruningSuite
           Row("John", "Y.") :: Nil)
     }
   }
+
+  test("SPARK-34638: queries should not fail on unsupported cases") {
+    withTable("nested_array") {
+      sql("select * from values array(array(named_struct('a', 1, 'b', 3), " +
+        "named_struct('a', 2, 'b', 4))) T(items)").write.saveAsTable("nested_array")
+      val query = sql("select d.a from (select explode(c) d from " +
+        "(select explode(items) c from nested_array))")
+      checkAnswer(query, Row(1) :: Row(2) :: Nil)
+    }
+
+    withTable("map") {
+      sql("select * from values map(1, named_struct('a', 1, 'b', 3), " +
+        "2, named_struct('a', 2, 'b', 4)) T(items)").write.saveAsTable("map")
+      val query = sql("select d.a from (select explode(items) (c, d) from map)")
+      checkAnswer(query, Row(1) :: Row(2) :: Nil)
+    }
+  }
 }

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org