You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2022/03/17 08:55:43 UTC

[spark] branch master updated: [SPARK-38560][SQL] If `Sum`, `Count`, `Any` accompany with distinct, cannot do partial agg push down

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new f0b836b  [SPARK-38560][SQL] If `Sum`, `Count`, `Any` accompany with distinct, cannot do partial agg push down
f0b836b is described below

commit f0b836b1a6d5926dd09018b67a27461aca5ce739
Author: Jiaan Geng <be...@163.com>
AuthorDate: Thu Mar 17 16:53:40 2022 +0800

    [SPARK-38560][SQL] If `Sum`, `Count` or `Avg` is used with DISTINCT, partial agg push-down cannot be done
    
    ### What changes were proposed in this pull request?
    Spark could partially push down sum(distinct col) and count(distinct col) when the data source has multiple partitions, and Spark then sums the per-partition values again.
    As a result, the query may return an incorrect result.
    
    ### Why are the changes needed?
    Fix the bug where Spark pushes sum(distinct col) and count(distinct col) down to a multi-partition data source and returns an incorrect result.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes.
    Users will see the correct behavior.
    
    ### How was this patch tested?
    New tests.
    
    Closes #35873 from beliefer/SPARK-38560.
    
    Authored-by: Jiaan Geng <be...@163.com>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 .../datasources/v2/V2ScanRelationPushDown.scala    | 184 +++++++++++----------
 .../org/apache/spark/sql/jdbc/JDBCV2Suite.scala    |  14 +-
 2 files changed, 111 insertions(+), 87 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
index 3ff9176..b4bd027 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.planning.ScanOperation
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LeafNode, Limit, LocalLimit, LogicalPlan, Project, Sample, Sort}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.connector.expressions.SortOrder
-import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, Avg, GeneralAggregateFunc}
+import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, Avg, Count, GeneralAggregateFunc, Sum}
 import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, V1Scan}
 import org.apache.spark.sql.execution.datasources.DataSourceStrategy
 import org.apache.spark.sql.sources
@@ -156,101 +156,106 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
                 }
               }
 
-              val pushedAggregates = finalTranslatedAggregates.filter(r.pushAggregation)
-              if (pushedAggregates.isEmpty) {
+              if (finalTranslatedAggregates.isEmpty) {
                 aggNode // return original plan node
-              } else if (!supportPartialAggPushDown(pushedAggregates.get) &&
-                !r.supportCompletePushDown(pushedAggregates.get)) {
+              } else if (!r.supportCompletePushDown(finalTranslatedAggregates.get) &&
+                !supportPartialAggPushDown(finalTranslatedAggregates.get)) {
                 aggNode // return original plan node
               } else {
-                // No need to do column pruning because only the aggregate columns are used as
-                // DataSourceV2ScanRelation output columns. All the other columns are not
-                // included in the output.
-                val scan = sHolder.builder.build()
-
-                // scalastyle:off
-                // use the group by columns and aggregate columns as the output columns
-                // e.g. TABLE t (c1 INT, c2 INT, c3 INT)
-                // SELECT min(c1), max(c1) FROM t GROUP BY c2;
-                // Use c2, min(c1), max(c1) as output for DataSourceV2ScanRelation
-                // We want to have the following logical plan:
-                // == Optimized Logical Plan ==
-                // Aggregate [c2#10], [min(min(c1)#21) AS min(c1)#17, max(max(c1)#22) AS max(c1)#18]
-                // +- RelationV2[c2#10, min(c1)#21, max(c1)#22]
-                // scalastyle:on
-                val newOutput = scan.readSchema().toAttributes
-                assert(newOutput.length == groupingExpressions.length + finalAggregates.length)
-                val groupAttrs = normalizedGroupingExpressions.zip(newOutput).map {
-                  case (a: Attribute, b: Attribute) => b.withExprId(a.exprId)
-                  case (_, b) => b
-                }
-                val aggOutput = newOutput.drop(groupAttrs.length)
-                val output = groupAttrs ++ aggOutput
-
-                logInfo(
-                  s"""
-                     |Pushing operators to ${sHolder.relation.name}
-                     |Pushed Aggregate Functions:
-                     | ${pushedAggregates.get.aggregateExpressions.mkString(", ")}
-                     |Pushed Group by:
-                     | ${pushedAggregates.get.groupByColumns.mkString(", ")}
-                     |Output: ${output.mkString(", ")}
-                      """.stripMargin)
-
-                val wrappedScan = getWrappedScan(scan, sHolder, pushedAggregates)
-                val scanRelation = DataSourceV2ScanRelation(sHolder.relation, wrappedScan, output)
-                if (r.supportCompletePushDown(pushedAggregates.get)) {
-                  val projectExpressions = resultExpressions.map { expr =>
-                    // TODO At present, only push down group by attribute is supported.
-                    // In future, more attribute conversion is extended here. e.g. GetStructField
-                    expr.transform {
-                      case agg: AggregateExpression =>
-                        val ordinal = aggExprToOutputOrdinal(agg.canonicalized)
-                        val child =
-                          addCastIfNeeded(aggOutput(ordinal), agg.resultAttribute.dataType)
-                        Alias(child, agg.resultAttribute.name)(agg.resultAttribute.exprId)
-                    }
-                  }.asInstanceOf[Seq[NamedExpression]]
-                  Project(projectExpressions, scanRelation)
+                val pushedAggregates = finalTranslatedAggregates.filter(r.pushAggregation)
+                if (pushedAggregates.isEmpty) {
+                  aggNode // return original plan node
                 } else {
-                  val plan = Aggregate(
-                    output.take(groupingExpressions.length), finalResultExpressions, scanRelation)
+                  // No need to do column pruning because only the aggregate columns are used as
+                  // DataSourceV2ScanRelation output columns. All the other columns are not
+                  // included in the output.
+                  val scan = sHolder.builder.build()
 
                   // scalastyle:off
-                  // Change the optimized logical plan to reflect the pushed down aggregate
+                  // use the group by columns and aggregate columns as the output columns
                   // e.g. TABLE t (c1 INT, c2 INT, c3 INT)
                   // SELECT min(c1), max(c1) FROM t GROUP BY c2;
-                  // The original logical plan is
-                  // Aggregate [c2#10],[min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18]
-                  // +- RelationV2[c1#9, c2#10] ...
-                  //
-                  // After change the V2ScanRelation output to [c2#10, min(c1)#21, max(c1)#22]
-                  // we have the following
-                  // !Aggregate [c2#10], [min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18]
-                  // +- RelationV2[c2#10, min(c1)#21, max(c1)#22] ...
-                  //
-                  // We want to change it to
+                  // Use c2, min(c1), max(c1) as output for DataSourceV2ScanRelation
+                  // We want to have the following logical plan:
                   // == Optimized Logical Plan ==
                   // Aggregate [c2#10], [min(min(c1)#21) AS min(c1)#17, max(max(c1)#22) AS max(c1)#18]
-                  // +- RelationV2[c2#10, min(c1)#21, max(c1)#22] ...
+                  // +- RelationV2[c2#10, min(c1)#21, max(c1)#22]
                   // scalastyle:on
-                  plan.transformExpressions {
-                    case agg: AggregateExpression =>
-                      val ordinal = aggExprToOutputOrdinal(agg.canonicalized)
-                      val aggAttribute = aggOutput(ordinal)
-                      val aggFunction: aggregate.AggregateFunction =
-                        agg.aggregateFunction match {
-                          case max: aggregate.Max =>
-                            max.copy(child = addCastIfNeeded(aggAttribute, max.child.dataType))
-                          case min: aggregate.Min =>
-                            min.copy(child = addCastIfNeeded(aggAttribute, min.child.dataType))
-                          case sum: aggregate.Sum =>
-                            sum.copy(child = addCastIfNeeded(aggAttribute, sum.child.dataType))
-                          case _: aggregate.Count =>
-                            aggregate.Sum(addCastIfNeeded(aggAttribute, LongType))
-                          case other => other
-                        }
-                      agg.copy(aggregateFunction = aggFunction)
+                  val newOutput = scan.readSchema().toAttributes
+                  assert(newOutput.length == groupingExpressions.length + finalAggregates.length)
+                  val groupAttrs = normalizedGroupingExpressions.zip(newOutput).map {
+                    case (a: Attribute, b: Attribute) => b.withExprId(a.exprId)
+                    case (_, b) => b
+                  }
+                  val aggOutput = newOutput.drop(groupAttrs.length)
+                  val output = groupAttrs ++ aggOutput
+
+                  logInfo(
+                    s"""
+                       |Pushing operators to ${sHolder.relation.name}
+                       |Pushed Aggregate Functions:
+                       | ${pushedAggregates.get.aggregateExpressions.mkString(", ")}
+                       |Pushed Group by:
+                       | ${pushedAggregates.get.groupByColumns.mkString(", ")}
+                       |Output: ${output.mkString(", ")}
+                      """.stripMargin)
+
+                  val wrappedScan = getWrappedScan(scan, sHolder, pushedAggregates)
+                  val scanRelation =
+                    DataSourceV2ScanRelation(sHolder.relation, wrappedScan, output)
+                  if (r.supportCompletePushDown(pushedAggregates.get)) {
+                    val projectExpressions = resultExpressions.map { expr =>
+                      // TODO At present, only push down group by attribute is supported.
+                      // In future, more attribute conversion is extended here. e.g. GetStructField
+                      expr.transform {
+                        case agg: AggregateExpression =>
+                          val ordinal = aggExprToOutputOrdinal(agg.canonicalized)
+                          val child =
+                            addCastIfNeeded(aggOutput(ordinal), agg.resultAttribute.dataType)
+                          Alias(child, agg.resultAttribute.name)(agg.resultAttribute.exprId)
+                      }
+                    }.asInstanceOf[Seq[NamedExpression]]
+                    Project(projectExpressions, scanRelation)
+                  } else {
+                    val plan = Aggregate(output.take(groupingExpressions.length),
+                      finalResultExpressions, scanRelation)
+
+                    // scalastyle:off
+                    // Change the optimized logical plan to reflect the pushed down aggregate
+                    // e.g. TABLE t (c1 INT, c2 INT, c3 INT)
+                    // SELECT min(c1), max(c1) FROM t GROUP BY c2;
+                    // The original logical plan is
+                    // Aggregate [c2#10],[min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18]
+                    // +- RelationV2[c1#9, c2#10] ...
+                    //
+                    // After change the V2ScanRelation output to [c2#10, min(c1)#21, max(c1)#22]
+                    // we have the following
+                    // !Aggregate [c2#10], [min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18]
+                    // +- RelationV2[c2#10, min(c1)#21, max(c1)#22] ...
+                    //
+                    // We want to change it to
+                    // == Optimized Logical Plan ==
+                    // Aggregate [c2#10], [min(min(c1)#21) AS min(c1)#17, max(max(c1)#22) AS max(c1)#18]
+                    // +- RelationV2[c2#10, min(c1)#21, max(c1)#22] ...
+                    // scalastyle:on
+                    plan.transformExpressions {
+                      case agg: AggregateExpression =>
+                        val ordinal = aggExprToOutputOrdinal(agg.canonicalized)
+                        val aggAttribute = aggOutput(ordinal)
+                        val aggFunction: aggregate.AggregateFunction =
+                          agg.aggregateFunction match {
+                            case max: aggregate.Max =>
+                              max.copy(child = addCastIfNeeded(aggAttribute, max.child.dataType))
+                            case min: aggregate.Min =>
+                              min.copy(child = addCastIfNeeded(aggAttribute, min.child.dataType))
+                            case sum: aggregate.Sum =>
+                              sum.copy(child = addCastIfNeeded(aggAttribute, sum.child.dataType))
+                            case _: aggregate.Count =>
+                              aggregate.Sum(addCastIfNeeded(aggAttribute, LongType))
+                            case other => other
+                          }
+                        agg.copy(aggregateFunction = aggFunction)
+                    }
                   }
                 }
               }
@@ -279,7 +284,14 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
 
   private def supportPartialAggPushDown(agg: Aggregation): Boolean = {
     // We don't know the agg buffer of `GeneralAggregateFunc`, so can't do partial agg push down.
-    agg.aggregateExpressions().forall(!_.isInstanceOf[GeneralAggregateFunc])
+    // Every aggregate must allow it (hence `forall`); DISTINCT `Sum`/`Count`/`Avg` cannot.
+    agg.aggregateExpressions().forall {
+      case sum: Sum => !sum.isDistinct
+      case count: Count => !count.isDistinct
+      case avg: Avg => !avg.isDistinct
+      case _: GeneralAggregateFunc => false
+      case _ => true
+    }
   }
 
   private def addCastIfNeeded(expression: Expression, expectedDataType: DataType) =
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
index 85ccf82..17bd7f7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, Sort}
 import org.apache.spark.sql.connector.expressions.{FieldReference, NullOrdering, SortDirection, SortValue}
 import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, V1ScanWrapper}
 import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog
-import org.apache.spark.sql.functions.{avg, count, lit, sum, udf}
+import org.apache.spark.sql.functions.{avg, count, count_distinct, lit, sum, udf}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.util.Utils
@@ -506,6 +506,18 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     checkAnswer(df, Seq(Row(3)))
   }
 
+  test("scan with aggregate push-down: cannot partial push down COUNT(DISTINCT col)") {
+    val distinctDeptCount = spark.read
+      .option("numPartitions", "2")
+      .option("partitionColumn", "dept")
+      .option("lowerBound", "0")
+      .option("upperBound", "2")
+      .table("h2.test.employee")
+      .agg(count_distinct($"DEPT"))
+    checkAggregateRemoved(distinctDeptCount, false)
+    checkAnswer(distinctDeptCount, Seq(Row(3)))
+  }
+
   test("scan with aggregate push-down: SUM without filer and group by") {
     val df = sql("SELECT SUM(SALARY) FROM h2.test.employee")
     checkAggregateRemoved(df)

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org