Posted to commits@spark.apache.org by we...@apache.org on 2022/01/18 05:48:57 UTC

[spark] branch master updated: [SPARK-37768][SQL][FOLLOWUP] Schema pruning for the metadata struct

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 54f91d3  [SPARK-37768][SQL][FOLLOWUP] Schema pruning for the metadata struct
54f91d3 is described below

commit 54f91d391acd2995defc1b5666dc0bb95100a575
Author: yaohua <ya...@databricks.com>
AuthorDate: Tue Jan 18 13:48:06 2022 +0800

    [SPARK-37768][SQL][FOLLOWUP] Schema pruning for the metadata struct
    
    ### What changes were proposed in this pull request?
    A follow-up to PR #34575. This change supports schema pruning of the metadata struct for all file formats, so a file scan only produces the metadata fields a query actually references.
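    
    A minimal sketch of the kind of query that benefits (hypothetical path, data column, and variable names; the hidden `_metadata` struct and its `file_name`/`file_path`/`file_size` fields come from SPARK-37273):
    
        import org.apache.spark.sql.functions.col
    
        // `spark` as in spark-shell; the parquet path and the `name` column are made up.
        val df = spark.read.format("parquet").load("/tmp/users")
        // Only `_metadata.file_name` is referenced, so with this change the scan's
        // metadata struct is pruned down to the `file_name` field.
        df.select(col("name"), col("_metadata.file_name")).show()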
    
    ### Why are the changes needed?
    Performance improvements: metadata struct fields that a query does not reference are no longer produced by the file scan.
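    
    A rough way to observe the effect in the physical plan, reusing the hypothetical `df` from the sketch above (`FileSourceScanExec.metadataColumns` is the same field the new tests inspect):
    
        import org.apache.spark.sql.execution.FileSourceScanExec
    
        val prunedDF = df.select(col("name"), col("_metadata.file_name"))
        val metaCols = prunedDF.queryExecution.sparkPlan.collectFirst {
          case scan: FileSourceScanExec => scan.metadataColumns
        }.get
        // Previously this listed every metadata field; with this change only `file_name` remains.
        println(metaCols.map(_.name).mkString(", "))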
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Existing UTs and new UTs covering metadata schema pruning in projects and filters.
    
    Closes #35147 from Yaohua628/spark-37768.
    
    Authored-by: yaohua <ya...@databricks.com>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 .../sql/catalyst/expressions/SchemaPruning.scala   |  8 +-
 .../catalyst/expressions/namedExpressions.scala    |  2 +-
 .../spark/sql/catalyst/optimizer/objects.scala     |  2 +-
 .../catalyst/expressions/SchemaPruningSuite.scala  |  4 +-
 .../execution/datasources/FileSourceStrategy.scala |  3 +-
 .../sql/execution/datasources/SchemaPruning.scala  | 93 ++++++++++++----------
 .../execution/datasources/v2/PushDownUtils.scala   |  2 +-
 .../datasources/FileMetadataStructSuite.scala      | 48 +++++++++++
 8 files changed, 107 insertions(+), 55 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala
index 2a182b6..fd5b2db 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala
@@ -33,8 +33,8 @@ object SchemaPruning extends SQLConfHelper {
    *   1. The schema field ordering at original schema is still preserved in pruned schema.
    *   2. The top-level fields are not pruned here.
    */
-  def pruneDataSchema(
-      dataSchema: StructType,
+  def pruneSchema(
+      schema: StructType,
       requestedRootFields: Seq[RootField]): StructType = {
     val resolver = conf.resolver
     // Merge the requested root fields into a single schema. Note the ordering of the fields
@@ -44,10 +44,10 @@ object SchemaPruning extends SQLConfHelper {
       .map { root: RootField => StructType(Array(root.field)) }
       .reduceLeft(_ merge _)
     val mergedDataSchema =
-      StructType(dataSchema.map(d => mergedSchema.find(m => resolver(m.name, d.name)).getOrElse(d)))
+      StructType(schema.map(d => mergedSchema.find(m => resolver(m.name, d.name)).getOrElse(d)))
     // Sort the fields of mergedDataSchema according to their order in dataSchema,
     // recursively. This makes mergedDataSchema a pruned schema of dataSchema
-    sortLeftFieldsByRight(mergedDataSchema, dataSchema).asInstanceOf[StructType]
+    sortLeftFieldsByRight(mergedDataSchema, schema).asInstanceOf[StructType]
   }
 
   /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index c51030f..a099fad 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -342,7 +342,7 @@ case class AttributeReference(
     AttributeReference(name, dataType, nullable, newMetadata)(exprId, qualifier)
   }
 
-  override def withDataType(newType: DataType): Attribute = {
+  override def withDataType(newType: DataType): AttributeReference = {
     AttributeReference(name, newType, nullable, metadata)(exprId, qualifier)
   }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala
index 52544ff..c347a2e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala
@@ -222,7 +222,7 @@ object ObjectSerializerPruning extends Rule[LogicalPlan] {
 
       if (conf.serializerNestedSchemaPruningEnabled && rootFields.nonEmpty) {
         // Prunes nested fields in serializers.
-        val prunedSchema = SchemaPruning.pruneDataSchema(
+        val prunedSchema = SchemaPruning.pruneSchema(
           StructType.fromAttributes(prunedSerializer.map(_.toAttribute)), rootFields)
         val nestedPrunedSerializer = prunedSerializer.zipWithIndex.map { case (serializer, idx) =>
           pruneSerializer(serializer, prunedSchema(idx).dataType)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruningSuite.scala
index c67a962..b64bc49 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruningSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruningSuite.scala
@@ -31,7 +31,7 @@ class SchemaPruningSuite extends SparkFunSuite with SQLHelper {
       // `derivedFromAtt` doesn't affect the result of pruned schema.
       SchemaPruning.RootField(field = f, derivedFromAtt = true)
     }
-    val prunedSchema = SchemaPruning.pruneDataSchema(schema, requestedRootFields)
+    val prunedSchema = SchemaPruning.pruneSchema(schema, requestedRootFields)
     assert(prunedSchema === expectedSchema)
   }
 
@@ -140,7 +140,7 @@ class SchemaPruningSuite extends SparkFunSuite with SQLHelper {
     assert(field.metadata.getString("foo") == "bar")
 
     val schema = StructType(Seq(field))
-    val prunedSchema = SchemaPruning.pruneDataSchema(schema, rootFields)
+    val prunedSchema = SchemaPruning.pruneSchema(schema, rootFields)
     assert(prunedSchema.head.metadata.getString("foo") == "bar")
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
index c1282fa..5df8057 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
@@ -213,11 +213,10 @@ object FileSourceStrategy extends Strategy with PredicateHelper with Logging {
       val outputSchema = readDataColumns.toStructType
       logInfo(s"Output Data Schema: ${outputSchema.simpleString(5)}")
 
-      val metadataStructOpt = requiredAttributes.collectFirst {
+      val metadataStructOpt = l.output.collectFirst {
         case MetadataAttribute(attr) => attr
       }
 
-      // TODO (yaohua): should be able to prune the metadata struct only containing what needed
       val metadataColumns = metadataStructOpt.map { metadataStruct =>
         metadataStruct.dataType.asInstanceOf[StructType].fields.map { field =>
           MetadataAttribute(field.name, field.dataType)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaPruning.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaPruning.scala
index 93bd1ac..9dd2f40 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaPruning.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaPruning.scala
@@ -31,58 +31,68 @@ import org.apache.spark.sql.util.SchemaUtils._
  * By "physical column", we mean a column as defined in the data source format like Parquet format
  * or ORC format. For example, in Spark SQL, a root-level Parquet column corresponds to a SQL
  * column, and a nested Parquet column corresponds to a [[StructField]].
+ *
+ * Also prunes the unnecessary metadata columns if any for all file formats.
  */
 object SchemaPruning extends Rule[LogicalPlan] {
   import org.apache.spark.sql.catalyst.expressions.SchemaPruning._
 
   override def apply(plan: LogicalPlan): LogicalPlan =
-    if (conf.nestedSchemaPruningEnabled) {
-      apply0(plan)
-    } else {
-      plan
-    }
-
-  private def apply0(plan: LogicalPlan): LogicalPlan =
     plan transformDown {
       case op @ PhysicalOperation(projects, filters,
-          l @ LogicalRelation(hadoopFsRelation: HadoopFsRelation, _, _, _))
-        if canPruneRelation(hadoopFsRelation) =>
-
-        prunePhysicalColumns(l.output, projects, filters, hadoopFsRelation.dataSchema,
-          prunedDataSchema => {
+      l @ LogicalRelation(hadoopFsRelation: HadoopFsRelation, _, _, _)) =>
+        prunePhysicalColumns(l, projects, filters, hadoopFsRelation,
+          (prunedDataSchema, prunedMetadataSchema) => {
             val prunedHadoopRelation =
               hadoopFsRelation.copy(dataSchema = prunedDataSchema)(hadoopFsRelation.sparkSession)
-            buildPrunedRelation(l, prunedHadoopRelation)
+            buildPrunedRelation(l, prunedHadoopRelation, prunedMetadataSchema)
           }).getOrElse(op)
     }
 
   /**
    * This method returns optional logical plan. `None` is returned if no nested field is required or
    * all nested fields are required.
+   *
+   * This method will prune both the data schema and the metadata schema
    */
   private def prunePhysicalColumns(
-      output: Seq[AttributeReference],
+      relation: LogicalRelation,
       projects: Seq[NamedExpression],
       filters: Seq[Expression],
-      dataSchema: StructType,
-      leafNodeBuilder: StructType => LeafNode): Option[LogicalPlan] = {
+      hadoopFsRelation: HadoopFsRelation,
+      leafNodeBuilder: (StructType, StructType) => LeafNode): Option[LogicalPlan] = {
+
     val (normalizedProjects, normalizedFilters) =
-      normalizeAttributeRefNames(output, projects, filters)
+      normalizeAttributeRefNames(relation.output, projects, filters)
     val requestedRootFields = identifyRootFields(normalizedProjects, normalizedFilters)
 
     // If requestedRootFields includes a nested field, continue. Otherwise,
     // return op
     if (requestedRootFields.exists { root: RootField => !root.derivedFromAtt }) {
-      val prunedDataSchema = pruneDataSchema(dataSchema, requestedRootFields)
 
-      // If the data schema is different from the pruned data schema, continue. Otherwise,
-      // return op. We effect this comparison by counting the number of "leaf" fields in
-      // each schemata, assuming the fields in prunedDataSchema are a subset of the fields
-      // in dataSchema.
-      if (countLeaves(dataSchema) > countLeaves(prunedDataSchema)) {
-        val prunedRelation = leafNodeBuilder(prunedDataSchema)
-        val projectionOverSchema = ProjectionOverSchema(prunedDataSchema)
+      val prunedDataSchema = if (canPruneDataSchema(hadoopFsRelation)) {
+        pruneSchema(hadoopFsRelation.dataSchema, requestedRootFields)
+      } else {
+        hadoopFsRelation.dataSchema
+      }
+
+      val metadataSchema =
+        relation.output.collect { case MetadataAttribute(attr) => attr }.toStructType
+      val prunedMetadataSchema = if (metadataSchema.nonEmpty) {
+        pruneSchema(metadataSchema, requestedRootFields)
+      } else {
+        metadataSchema
+      }
 
+      // If the data schema is different from the pruned data schema
+      // OR
+      // the metadata schema is different from the pruned metadata schema, continue.
+      // Otherwise, return None.
+      if (countLeaves(hadoopFsRelation.dataSchema) > countLeaves(prunedDataSchema) ||
+        countLeaves(metadataSchema) > countLeaves(prunedMetadataSchema)) {
+        val prunedRelation = leafNodeBuilder(prunedDataSchema, prunedMetadataSchema)
+        val projectionOverSchema =
+          ProjectionOverSchema(prunedDataSchema.merge(prunedMetadataSchema))
         Some(buildNewProjection(projects, normalizedProjects, normalizedFilters,
           prunedRelation, projectionOverSchema))
       } else {
@@ -96,9 +106,10 @@ object SchemaPruning extends Rule[LogicalPlan] {
   /**
    * Checks to see if the given relation can be pruned. Currently we support Parquet and ORC v1.
    */
-  private def canPruneRelation(fsRelation: HadoopFsRelation) =
-    fsRelation.fileFormat.isInstanceOf[ParquetFileFormat] ||
-      fsRelation.fileFormat.isInstanceOf[OrcFileFormat]
+  private def canPruneDataSchema(fsRelation: HadoopFsRelation): Boolean =
+    conf.nestedSchemaPruningEnabled && (
+      fsRelation.fileFormat.isInstanceOf[ParquetFileFormat] ||
+        fsRelation.fileFormat.isInstanceOf[OrcFileFormat])
 
   /**
    * Normalizes the names of the attribute references in the given projects and filters to reflect
@@ -162,29 +173,25 @@ object SchemaPruning extends Rule[LogicalPlan] {
    */
   private def buildPrunedRelation(
       outputRelation: LogicalRelation,
-      prunedBaseRelation: HadoopFsRelation) = {
-    val prunedOutput = getPrunedOutput(outputRelation.output, prunedBaseRelation.schema)
-    // also add the metadata output if any
-    // TODO: should be able to prune the metadata schema
-    val metaOutput = outputRelation.output.collect {
-      case MetadataAttribute(attr) => attr
-    }
-    outputRelation.copy(relation = prunedBaseRelation, output = prunedOutput ++ metaOutput)
+      prunedBaseRelation: HadoopFsRelation,
+      prunedMetadataSchema: StructType) = {
+    val finalSchema = prunedBaseRelation.schema.merge(prunedMetadataSchema)
+    val prunedOutput = getPrunedOutput(outputRelation.output, finalSchema)
+    outputRelation.copy(relation = prunedBaseRelation, output = prunedOutput)
   }
 
   // Prune the given output to make it consistent with `requiredSchema`.
   private def getPrunedOutput(
       output: Seq[AttributeReference],
       requiredSchema: StructType): Seq[AttributeReference] = {
-    // We need to replace the expression ids of the pruned relation output attributes
-    // with the expression ids of the original relation output attributes so that
-    // references to the original relation's output are not broken
-    val outputIdMap = output.map(att => (att.name, att.exprId)).toMap
+    // We need to update the data type of the output attributes to use the pruned ones.
+    // so that references to the original relation's output are not broken
+    val nameAttributeMap = output.map(att => (att.name, att)).toMap
     requiredSchema
       .toAttributes
       .map {
-        case att if outputIdMap.contains(att.name) =>
-          att.withExprId(outputIdMap(att.name))
+        case att if nameAttributeMap.contains(att.name) =>
+          nameAttributeMap(att.name).withDataType(att.dataType)
         case att => att
       }
   }
@@ -203,6 +210,4 @@ object SchemaPruning extends Rule[LogicalPlan] {
       case _ => 1
     }
   }
-
-
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala
index db7b3dc..29d86b6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala
@@ -187,7 +187,7 @@ object PushDownUtils extends PredicateHelper {
       case r: SupportsPushDownRequiredColumns if SQLConf.get.nestedSchemaPruningEnabled =>
         val rootFields = SchemaPruning.identifyRootFields(projects, filters)
         val prunedSchema = if (rootFields.nonEmpty) {
-          SchemaPruning.pruneDataSchema(relation.schema, rootFields)
+          SchemaPruning.pruneSchema(relation.schema, rootFields)
         } else {
           new StructType()
         }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileMetadataStructSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileMetadataStructSuite.scala
index fffac88..8bf5d61 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileMetadataStructSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileMetadataStructSuite.scala
@@ -22,6 +22,7 @@ import java.sql.Timestamp
 import java.text.SimpleDateFormat
 
 import org.apache.spark.sql.{AnalysisException, Column, DataFrame, QueryTest, Row}
+import org.apache.spark.sql.execution.FileSourceScanExec
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
@@ -384,4 +385,51 @@ class FileMetadataStructSuite extends QueryTest with SharedSparkSession {
       }
     }
   }
+
+  metadataColumnsTest("prune metadata schema in projects", schema) { (df, f0, f1) =>
+    val prunedDF = df.select("name", "age", "info.id", METADATA_FILE_NAME)
+    val fileSourceScanMetaCols = prunedDF.queryExecution.sparkPlan.collectFirst {
+      case p: FileSourceScanExec => p.metadataColumns
+    }.get
+    assert(fileSourceScanMetaCols.size == 1)
+    assert(fileSourceScanMetaCols.head.name == "file_name")
+
+    checkAnswer(
+      prunedDF,
+      Seq(Row("jack", 24, 12345L, f0(METADATA_FILE_NAME)),
+        Row("lily", 31, 54321L, f1(METADATA_FILE_NAME)))
+    )
+  }
+
+  metadataColumnsTest("prune metadata schema in filters", schema) { (df, f0, f1) =>
+    val prunedDF = df.select("name", "age", "info.id")
+      .where(col(METADATA_FILE_PATH).contains("data/f0"))
+
+    val fileSourceScanMetaCols = prunedDF.queryExecution.sparkPlan.collectFirst {
+      case p: FileSourceScanExec => p.metadataColumns
+    }.get
+    assert(fileSourceScanMetaCols.size == 1)
+    assert(fileSourceScanMetaCols.head.name == "file_path")
+
+    checkAnswer(
+      prunedDF,
+      Seq(Row("jack", 24, 12345L))
+    )
+  }
+
+  metadataColumnsTest("prune metadata schema in projects and filters", schema) { (df, f0, f1) =>
+    val prunedDF = df.select("name", "age", "info.id", METADATA_FILE_SIZE)
+      .where(col(METADATA_FILE_PATH).contains("data/f0"))
+
+    val fileSourceScanMetaCols = prunedDF.queryExecution.sparkPlan.collectFirst {
+      case p: FileSourceScanExec => p.metadataColumns
+    }.get
+    assert(fileSourceScanMetaCols.size == 2)
+    assert(fileSourceScanMetaCols.map(_.name).toSet == Set("file_size", "file_path"))
+
+    checkAnswer(
+      prunedDF,
+      Seq(Row("jack", 24, 12345L, f0(METADATA_FILE_SIZE)))
+    )
+  }
 }

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org