You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by do...@apache.org on 2022/06/10 22:49:39 UTC

[spark] branch branch-3.2 updated: [SPARK-38918][SQL][3.2] Nested column pruning should filter out attributes that do not belong to the current relation

This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch branch-3.2
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.2 by this push:
     new 7c0b9e6e6f6 [SPARK-38918][SQL][3.2] Nested column pruning should filter out attributes that do not belong to the current relation
7c0b9e6e6f6 is described below

commit 7c0b9e6e6f680db45c1e2602b85753d9b521bb58
Author: allisonwang-db <al...@databricks.com>
AuthorDate: Fri Jun 10 15:49:24 2022 -0700

    [SPARK-38918][SQL][3.2] Nested column pruning should filter out attributes that do not belong to the current relation
    
    ### What changes were proposed in this pull request?
    
    Backport #36216 to branch-3.2
    
    ### Why are the changes needed?
    
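    To fix a bug in `SchemaPruning`: `ProjectionOverSchema` matched attributes by name only, so nested column pruning did not filter out attributes that do not belong to the current relation.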
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Added a unit test to `SchemaPruningSuite` ("SPARK-38918: nested schema pruning with correlated subqueries").
    
    Closes #36386 from allisonwang-db/spark-38918-branch-3.2.
    
    Authored-by: allisonwang-db <al...@databricks.com>
    Signed-off-by: Dongjoon Hyun <do...@apache.org>
---
 .../expressions/ProjectionOverSchema.scala         |  8 +++-
 .../spark/sql/catalyst/optimizer/Optimizer.scala   |  1 +
 .../spark/sql/catalyst/optimizer/objects.scala     |  2 +-
 .../sql/execution/datasources/SchemaPruning.scala  |  2 +-
 .../datasources/v2/V2ScanRelationPushDown.scala    |  5 ++-
 .../execution/datasources/SchemaPruningSuite.scala | 45 +++++++++++++++++++++-
 6 files changed, 56 insertions(+), 7 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala
index a6be98c8a3a..69d30dd5048 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala
@@ -24,15 +24,19 @@ import org.apache.spark.sql.types._
  * field indexes and field counts of complex type extractors and attributes
  * are adjusted to fit the schema. All other expressions are left as-is. This
  * class is motivated by columnar nested schema pruning.
+ *
+ * @param schema nested column schema
+ * @param output output attributes of the data source relation. They are used to filter out
+ *               attributes in the schema that do not belong to the current relation.
  */
-case class ProjectionOverSchema(schema: StructType) {
+case class ProjectionOverSchema(schema: StructType, output: AttributeSet) {
   private val fieldNames = schema.fieldNames.toSet
 
   def unapply(expr: Expression): Option[Expression] = getProjection(expr)
 
   private def getProjection(expr: Expression): Option[Expression] =
     expr match {
-      case a: AttributeReference if fieldNames.contains(a.name) =>
+      case a: AttributeReference if fieldNames.contains(a.name) && output.contains(a) =>
         Some(a.copy(dataType = schema(a.name).dataType)(a.exprId, a.qualifier))
       case GetArrayItem(child, arrayItemOrdinal, failOnError) =>
         getProjection(child).map {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 4fe00ef0eed..80af6c03745 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -60,6 +60,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
   override protected val excludedOnceBatches: Set[String] =
     Set(
       "PartitionPruning",
+      "RewriteSubquery",
       "Extract Python UDFs")
 
   protected def fixedPoint =
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala
index 52544ff3e24..ec64895fc30 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala
@@ -229,7 +229,7 @@ object ObjectSerializerPruning extends Rule[LogicalPlan] {
         }
 
         // Builds new projection.
-        val projectionOverSchema = ProjectionOverSchema(prunedSchema)
+        val projectionOverSchema = ProjectionOverSchema(prunedSchema, AttributeSet(s.output))
         val newProjects = p.projectList.map(_.transformDown {
           case projectionOverSchema(expr) => expr
         }).map { case expr: NamedExpression => expr }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaPruning.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaPruning.scala
index 4f331c7bf48..bf3b54a297c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaPruning.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaPruning.scala
@@ -81,7 +81,7 @@ object SchemaPruning extends Rule[LogicalPlan] {
       // in dataSchema.
       if (countLeaves(dataSchema) > countLeaves(prunedDataSchema)) {
         val prunedRelation = leafNodeBuilder(prunedDataSchema)
-        val projectionOverSchema = ProjectionOverSchema(prunedDataSchema)
+        val projectionOverSchema = ProjectionOverSchema(prunedDataSchema, AttributeSet(output))
 
         Some(buildNewProjection(projects, normalizedProjects, normalizedFilters,
           prunedRelation, projectionOverSchema))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
index 046155b55cc..9178c840c20 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources.v2
 
 import scala.collection.mutable
 
-import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeReference, Expression, NamedExpression, PredicateHelper, ProjectionOverSchema, SubqueryExpression}
+import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeReference, AttributeSet, Expression, NamedExpression, PredicateHelper, ProjectionOverSchema, SubqueryExpression}
 import org.apache.spark.sql.catalyst.expressions.aggregate
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.planning.ScanOperation
@@ -199,7 +199,8 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
 
       val scanRelation = DataSourceV2ScanRelation(sHolder.relation, wrappedScan, output)
 
-      val projectionOverSchema = ProjectionOverSchema(output.toStructType)
+      val projectionOverSchema =
+        ProjectionOverSchema(output.toStructType, AttributeSet(output))
       val projectionFunc = (expr: Expression) => expr transformDown {
         case projectionOverSchema(newExpr) => newExpr
       }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
index 0b745f18768..3062c7e6480 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
@@ -57,11 +57,15 @@ abstract class SchemaPruningSuite
     contactId: Int,
     employer: Employer)
 
+  case class Employee(id: Int, name: FullName, employer: Company)
+
   val janeDoe = FullName("Jane", "X.", "Doe")
   val johnDoe = FullName("John", "Y.", "Doe")
   val susanSmith = FullName("Susan", "Z.", "Smith")
 
-  val employer = Employer(0, Company("abc", "123 Business Street"))
+  val company = Company("abc", "123 Business Street")
+
+  val employer = Employer(0, company)
   val employerWithNullCompany = Employer(1, null)
   val employerWithNullCompany2 = Employer(2, null)
 
@@ -77,6 +81,8 @@ abstract class SchemaPruningSuite
     Department(1, "Marketing", 1, employerWithNullCompany) ::
     Department(2, "Operation", 4, employerWithNullCompany2) :: Nil
 
+  val employees = Employee(0, janeDoe, company) :: Employee(1, johnDoe, company) :: Nil
+
   case class Name(first: String, last: String)
   case class BriefContact(id: Int, name: Name, address: String)
 
@@ -617,6 +623,26 @@ abstract class SchemaPruningSuite
     }
   }
 
+  testSchemaPruning("SPARK-38918: nested schema pruning with correlated subqueries") {
+    withContacts {
+      withEmployees {
+        val query = sql(
+          """
+            |select count(*)
+            |from contacts c
+            |where not exists (select null from employees e where e.name.first = c.name.first
+            |  and e.employer.name = c.employer.company.name)
+            |""".stripMargin)
+        checkScan(query,
+          "struct<name:struct<first:string,middle:string,last:string>," +
+            "employer:struct<id:int,company:struct<name:string,address:string>>>",
+          "struct<name:struct<first:string,middle:string,last:string>," +
+            "employer:struct<name:string,address:string>>")
+        checkAnswer(query, Row(3))
+      }
+    }
+  }
+
   protected def testSchemaPruning(testName: String)(testThunk: => Unit): Unit = {
     test(s"Spark vectorized reader - without partition data column - $testName") {
       withSQLConf(vectorizedReaderEnabledKey -> "true") {
@@ -697,6 +723,23 @@ abstract class SchemaPruningSuite
     }
   }
 
+  private def withEmployees(testThunk: => Unit): Unit = {
+    withTempPath { dir =>
+      val path = dir.getCanonicalPath
+
+      makeDataSourceFile(employees, new File(path + "/employees"))
+
+      // Providing user specified schema. Inferred schema from different data sources might
+      // be different.
+      val schema = "`id` INT,`name` STRUCT<`first`: STRING, `middle`: STRING, `last`: STRING>, " +
+        "`employer` STRUCT<`name`: STRING, `address`: STRING>"
+      spark.read.format(dataSourceName).schema(schema).load(path + "/employees")
+        .createOrReplaceTempView("employees")
+
+      testThunk
+    }
+  }
+
   case class MixedCaseColumn(a: String, B: Int)
   case class MixedCase(id: Int, CoL1: String, coL2: MixedCaseColumn)
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org