You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2022/03/23 07:24:22 UTC

[spark] branch branch-3.3 updated: [SPARK-38533][SQL] DS V2 aggregate push-down supports project with alias

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.3
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.3 by this push:
     new 6e5f181  [SPARK-38533][SQL] DS V2 aggregate push-down supports project with alias
6e5f181 is described below

commit 6e5f1811b180868303ea0ee2f44309c3a5ef914c
Author: Jiaan Geng <be...@163.com>
AuthorDate: Wed Mar 23 15:22:48 2022 +0800

    [SPARK-38533][SQL] DS V2 aggregate push-down supports project with alias
    
    ### What changes were proposed in this pull request?
    Currently, Spark DS V2 aggregate push-down doesn't support projects that contain aliases.
    
    See https://github.com/apache/spark/blob/c91c2e9afec0d5d5bbbd2e155057fe409c5bb928/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala#L96, where push-down is only attempted when every projected expression is a plain `AttributeReference`.
    
    This PR makes aggregate push-down work when the project contains aliases.
    
    **The first example:**
    The original plan is shown below:
    ```
    Aggregate [DEPT#0], [DEPT#0, sum(mySalary#8) AS total#14]
    +- Project [DEPT#0, SALARY#2 AS mySalary#8]
       +- ScanBuilderHolder [DEPT#0, NAME#1, SALARY#2, BONUS#3], RelationV2[DEPT#0, NAME#1, SALARY#2, BONUS#3] test.employee, JDBCScanBuilder(org.apache.spark.sql.test.TestSparkSession77978658,StructType(StructField(DEPT,IntegerType,true),StructField(NAME,StringType,true),StructField(SALARY,DecimalType(20,2),true),StructField(BONUS,DoubleType,true)),org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions5f8da82)
    ```
    If the aggregate can be completely pushed down, the plan will be:
    ```
    Project [DEPT#0, SUM(SALARY)#18 AS sum(SALARY#2)#13 AS total#14]
    +- RelationV2[DEPT#0, SUM(SALARY)#18] test.employee
    ```
    If the aggregate can only be partially pushed down, the plan will be:
    ```
    Aggregate [DEPT#0], [DEPT#0, sum(cast(SUM(SALARY)#18 as decimal(20,2))) AS total#14]
    +- RelationV2[DEPT#0, SUM(SALARY)#18] test.employee
    ```
    
    **The second example:**
    the origin plan show below:
    ```
    Aggregate [myDept#33], [myDept#33, sum(mySalary#34) AS total#40]
    +- Project [DEPT#25 AS myDept#33, SALARY#27 AS mySalary#34]
       +- ScanBuilderHolder [DEPT#25, NAME#26, SALARY#27, BONUS#28], RelationV2[DEPT#25, NAME#26, SALARY#27, BONUS#28] test.employee, JDBCScanBuilder(org.apache.spark.sql.test.TestSparkSession25c4f621,StructType(StructField(DEPT,IntegerType,true),StructField(NAME,StringType,true),StructField(SALARY,DecimalType(20,2),true),StructField(BONUS,DoubleType,true)),org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions345d641e)
    ```
    If the aggregate can be completely pushed down, the plan will be:
    ```
    Project [DEPT#25 AS myDept#33, SUM(SALARY)#44 AS sum(SALARY#27)#39 AS total#40]
    +- RelationV2[DEPT#25, SUM(SALARY)#44] test.employee
    ```
    If the aggregate can only be partially pushed down, the plan will be:
    ```
    Aggregate [myDept#33], [DEPT#25 AS myDept#33, sum(cast(SUM(SALARY)#56 as decimal(20,2))) AS total#52]
    +- RelationV2[DEPT#25, SUM(SALARY)#56] test.employee
    ```
    
    ### Why are the changes needed?
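    Aliases are common in real queries. Supporting them lets more aggregates be pushed down to the data source instead of being computed in Spark.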
    
    ### Does this PR introduce _any_ user-facing change?
    Yes.
    DS V2 aggregate push-down now supports projects that contain aliases.
    
    ### How was this patch tested?
    New tests.
    
    Closes #35932 from beliefer/SPARK-38533_new.
    
    Authored-by: Jiaan Geng <be...@163.com>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
    (cherry picked from commit f327dade9cdb466574b4698c2b9da4bdaac300e0)
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 .../datasources/v2/V2ScanRelationPushDown.scala    | 22 ++++--
 .../FileSourceAggregatePushDownSuite.scala         |  4 +-
 .../org/apache/spark/sql/jdbc/JDBCV2Suite.scala    | 86 ++++++++++++++++++++--
 3 files changed, 97 insertions(+), 15 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
index 44cdff1..c699e92 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
@@ -19,9 +19,10 @@ package org.apache.spark.sql.execution.datasources.v2
 
 import scala.collection.mutable
 
-import org.apache.spark.sql.catalyst.expressions.{Alias, And, Attribute, AttributeReference, Cast, Divide, DivideDTInterval, DivideYMInterval, EqualTo, Expression, If, IntegerLiteral, Literal, NamedExpression, PredicateHelper, ProjectionOverSchema, SubqueryExpression}
+import org.apache.spark.sql.catalyst.expressions.{Alias, AliasHelper, And, Attribute, AttributeReference, Cast, Divide, DivideDTInterval, DivideYMInterval, EqualTo, Expression, If, IntegerLiteral, Literal, NamedExpression, PredicateHelper, ProjectionOverSchema, SubqueryExpression}
 import org.apache.spark.sql.catalyst.expressions.aggregate
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
+import org.apache.spark.sql.catalyst.optimizer.CollapseProject
 import org.apache.spark.sql.catalyst.planning.ScanOperation
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LeafNode, Limit, LocalLimit, LogicalPlan, Project, Sample, Sort}
 import org.apache.spark.sql.catalyst.rules.Rule
@@ -34,7 +35,7 @@ import org.apache.spark.sql.sources
 import org.apache.spark.sql.types.{DataType, DayTimeIntervalType, LongType, StructType, YearMonthIntervalType}
 import org.apache.spark.sql.util.SchemaUtils._
 
-object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
+object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper with AliasHelper {
   import DataSourceV2Implicits._
 
   def apply(plan: LogicalPlan): LogicalPlan = {
@@ -95,22 +96,27 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
     case aggNode @ Aggregate(groupingExpressions, resultExpressions, child) =>
       child match {
         case ScanOperation(project, filters, sHolder: ScanBuilderHolder)
-          if filters.isEmpty && project.forall(_.isInstanceOf[AttributeReference]) =>
+          if filters.isEmpty && CollapseProject.canCollapseExpressions(
+            resultExpressions, project, alwaysInline = true) =>
           sHolder.builder match {
             case r: SupportsPushDownAggregates =>
+              val aliasMap = getAliasMap(project)
+              val actualResultExprs = resultExpressions.map(replaceAliasButKeepName(_, aliasMap))
+              val actualGroupExprs = groupingExpressions.map(replaceAlias(_, aliasMap))
+
               val aggExprToOutputOrdinal = mutable.HashMap.empty[Expression, Int]
-              val aggregates = collectAggregates(resultExpressions, aggExprToOutputOrdinal)
+              val aggregates = collectAggregates(actualResultExprs, aggExprToOutputOrdinal)
               val normalizedAggregates = DataSourceStrategy.normalizeExprs(
                 aggregates, sHolder.relation.output).asInstanceOf[Seq[AggregateExpression]]
               val normalizedGroupingExpressions = DataSourceStrategy.normalizeExprs(
-                groupingExpressions, sHolder.relation.output)
+                actualGroupExprs, sHolder.relation.output)
               val translatedAggregates = DataSourceStrategy.translateAggregation(
                 normalizedAggregates, normalizedGroupingExpressions)
               val (finalResultExpressions, finalAggregates, finalTranslatedAggregates) = {
                 if (translatedAggregates.isEmpty ||
                   r.supportCompletePushDown(translatedAggregates.get) ||
                   translatedAggregates.get.aggregateExpressions().forall(!_.isInstanceOf[Avg])) {
-                  (resultExpressions, aggregates, translatedAggregates)
+                  (actualResultExprs, aggregates, translatedAggregates)
                 } else {
                   // scalastyle:off
                   // The data source doesn't support the complete push-down of this aggregation.
@@ -127,7 +133,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
                   // Aggregate [c2#10],[sum(c1#9)/count(c1#9) AS avg(c1)#19]
                   // +- ScanOperation[...]
                   // scalastyle:on
-                  val newResultExpressions = resultExpressions.map { expr =>
+                  val newResultExpressions = actualResultExprs.map { expr =>
                     expr.transform {
                       case AggregateExpression(avg: aggregate.Average, _, isDistinct, _, _) =>
                         val sum = aggregate.Sum(avg.child).toAggregateExpression(isDistinct)
@@ -206,7 +212,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
                   val scanRelation =
                     DataSourceV2ScanRelation(sHolder.relation, wrappedScan, output)
                   if (r.supportCompletePushDown(pushedAggregates.get)) {
-                    val projectExpressions = resultExpressions.map { expr =>
+                    val projectExpressions = finalResultExpressions.map { expr =>
                       // TODO At present, only push down group by attribute is supported.
                       // In future, more attribute conversion is extended here. e.g. GetStructField
                       expr.transform {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceAggregatePushDownSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceAggregatePushDownSuite.scala
index 47740c5..26dfe1a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceAggregatePushDownSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceAggregatePushDownSuite.scala
@@ -184,7 +184,7 @@ trait FileSourceAggregatePushDownSuite
     }
   }
 
-  test("aggregate over alias not push down") {
+  test("aggregate over alias push down") {
     val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19),
       (9, "mno", 7), (2, null, 6))
     withDataSourceTable(data, "t") {
@@ -194,7 +194,7 @@ trait FileSourceAggregatePushDownSuite
         query.queryExecution.optimizedPlan.collect {
           case _: DataSourceV2ScanRelation =>
             val expected_plan_fragment =
-              "PushedAggregation: []"  // aggregate alias not pushed down
+              "PushedAggregation: [MIN(_1)]"
             checkKeywordsExistsInExplain(query, expected_plan_fragment)
         }
         checkAnswer(query, Seq(Row(-2)))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
index d6f098f..31fdb02 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
@@ -974,15 +974,19 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     checkAnswer(df, Seq(Row(1d), Row(1d), Row(null)))
   }
 
-  test("scan with aggregate push-down: aggregate over alias NOT push down") {
+  test("scan with aggregate push-down: aggregate over alias push down") {
     val cols = Seq("a", "b", "c", "d", "e")
     val df1 = sql("select * from h2.test.employee").toDF(cols: _*)
     val df2 = df1.groupBy().sum("c")
-    checkAggregateRemoved(df2, false)
+    checkAggregateRemoved(df2)
     df2.queryExecution.optimizedPlan.collect {
-      case relation: DataSourceV2ScanRelation => relation.scan match {
-        case v1: V1ScanWrapper =>
-          assert(v1.pushedDownOperators.aggregation.isEmpty)
+      case relation: DataSourceV2ScanRelation =>
+        val expectedPlanFragment =
+          "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByColumns: []"
+        checkKeywordsExistsInExplain(df2, expectedPlanFragment)
+        relation.scan match {
+          case v1: V1ScanWrapper =>
+            assert(v1.pushedDownOperators.aggregation.nonEmpty)
       }
     }
     checkAnswer(df2, Seq(Row(53000.00)))
@@ -1228,4 +1232,76 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
         |ON h2.test.view1.`|col1` = h2.test.view2.`|col1`""".stripMargin)
     checkAnswer(df, Seq.empty[Row])
   }
+
+  test("scan with aggregate push-down: complete push-down aggregate with alias") {
+    val df = spark.table("h2.test.employee")
+      .select($"DEPT", $"SALARY".as("mySalary"))
+      .groupBy($"DEPT")
+      .agg(sum($"mySalary").as("total"))
+      .filter($"total" > 1000)
+    checkAggregateRemoved(df)
+    df.queryExecution.optimizedPlan.collect {
+      case _: DataSourceV2ScanRelation =>
+        val expectedPlanFragment =
+          "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByColumns: [DEPT]"
+        checkKeywordsExistsInExplain(df, expectedPlanFragment)
+    }
+    checkAnswer(df, Seq(Row(1, 19000.00), Row(2, 22000.00), Row(6, 12000.00)))
+
+    val df2 = spark.table("h2.test.employee")
+      .select($"DEPT".as("myDept"), $"SALARY".as("mySalary"))
+      .groupBy($"myDept")
+      .agg(sum($"mySalary").as("total"))
+      .filter($"total" > 1000)
+    checkAggregateRemoved(df2)
+    df2.queryExecution.optimizedPlan.collect {
+      case _: DataSourceV2ScanRelation =>
+        val expectedPlanFragment =
+          "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByColumns: [DEPT]"
+        checkKeywordsExistsInExplain(df2, expectedPlanFragment)
+    }
+    checkAnswer(df2, Seq(Row(1, 19000.00), Row(2, 22000.00), Row(6, 12000.00)))
+  }
+
+  test("scan with aggregate push-down: partial push-down aggregate with alias") {
+    val df = spark.read
+      .option("partitionColumn", "DEPT")
+      .option("lowerBound", "0")
+      .option("upperBound", "2")
+      .option("numPartitions", "2")
+      .table("h2.test.employee")
+      .select($"NAME", $"SALARY".as("mySalary"))
+      .groupBy($"NAME")
+      .agg(sum($"mySalary").as("total"))
+      .filter($"total" > 1000)
+    checkAggregateRemoved(df, false)
+    df.queryExecution.optimizedPlan.collect {
+      case _: DataSourceV2ScanRelation =>
+        val expectedPlanFragment =
+          "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByColumns: [NAME]"
+        checkKeywordsExistsInExplain(df, expectedPlanFragment)
+    }
+    checkAnswer(df, Seq(Row("alex", 12000.00), Row("amy", 10000.00),
+      Row("cathy", 9000.00), Row("david", 10000.00), Row("jen", 12000.00)))
+
+    val df2 = spark.read
+      .option("partitionColumn", "DEPT")
+      .option("lowerBound", "0")
+      .option("upperBound", "2")
+      .option("numPartitions", "2")
+      .table("h2.test.employee")
+      .select($"NAME".as("myName"), $"SALARY".as("mySalary"))
+      .groupBy($"myName")
+      .agg(sum($"mySalary").as("total"))
+      .filter($"total" > 1000)
+    checkAggregateRemoved(df2, false)
+    df2.queryExecution.optimizedPlan.collect {
+      case _: DataSourceV2ScanRelation =>
+        val expectedPlanFragment =
+          "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByColumns: [NAME]"
+        checkKeywordsExistsInExplain(df2, expectedPlanFragment)
+    }
+    checkAnswer(df2, Seq(Row("alex", 12000.00), Row("amy", 10000.00),
+      Row("cathy", 9000.00), Row("david", 10000.00), Row("jen", 12000.00)))
+  }
 }

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org