Posted to commits@spark.apache.org by we...@apache.org on 2022/01/20 04:14:17 UTC
[spark] branch master updated: [SPARK-37839][SQL] DS V2 supports partial aggregate push-down `AVG`
This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new fcc5c34 [SPARK-37839][SQL] DS V2 supports partial aggregate push-down `AVG`
fcc5c34 is described below
commit fcc5c34c546a45a32ababdd41932c620de9bc969
Author: Jiaan Geng <be...@163.com>
AuthorDate: Thu Jan 20 12:13:00 2022 +0800
[SPARK-37839][SQL] DS V2 supports partial aggregate push-down `AVG`
### What changes were proposed in this pull request?
`max`, `min`, `count`, `sum` and `avg` are the most commonly used aggregate functions.
Currently, DS V2 supports complete aggregate push-down of `avg`, but supporting partial aggregate push-down of `avg` is also very useful.
The aggregate push-down algorithm is as follows (a sketch of the `AVG` split appears after these steps):
1. Spark translates the aggregate expressions and group expressions of `Aggregate` to a DS V2 `Aggregation`.
2. Spark calls `supportCompletePushDown` to check whether the aggregate can be completely pushed down.
3. If `supportCompletePushDown` returns true, we preserve the aggregate expressions as the final aggregate expressions. Otherwise, we split `AVG` into two functions: `SUM` and `COUNT`.
4. Spark translates the final aggregate expressions and group expressions of `Aggregate` to a DS V2 `Aggregation` again, and pushes the `Aggregation` down to the JDBC source.
5. Spark constructs the final aggregate.
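To make step 3 concrete, here is a minimal, self-contained Scala sketch of the `AVG` split (the case classes are illustrative stand-ins, not the actual Catalyst expression classes):
```scala
// Illustrative sketch only: stand-in case classes, not Catalyst expressions.
object AvgSplitSketch {
  sealed trait AggExpr
  case class Avg(column: String) extends AggExpr
  case class Sum(column: String) extends AggExpr
  case class Count(column: String) extends AggExpr
  case class Divide(left: AggExpr, right: AggExpr) extends AggExpr

  // Step 3: when complete push-down is unsupported, rewrite AVG(c) as
  // SUM(c) / COUNT(c) so each part can be partially aggregated at the source.
  def splitAvg(expr: AggExpr): AggExpr = expr match {
    case Avg(c) => Divide(Sum(c), Count(c))
    case other  => other
  }

  def main(args: Array[String]): Unit = {
    println(splitAvg(Avg("c1"))) // Divide(Sum(c1),Count(c1))
  }
}
```
Because `SUM` and `COUNT` both have mergeable partial states, the source can return per-partition partial results and Spark combines them into the final average.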
### Why are the changes needed?
Make DS V2 support partial aggregate push-down of `AVG`.
### Does this PR introduce _any_ user-facing change?
'Yes'. DS V2 can now partially push down `AVG`.
### How was this patch tested?
New tests.
Closes #35130 from beliefer/SPARK-37839.
Authored-by: Jiaan Geng <be...@163.com>
Signed-off-by: Wenchen Fan <we...@databricks.com>
---
.../sql/connector/expressions/aggregate/Avg.java | 49 ++++++++++
.../aggregate/GeneralAggregateFunc.java | 1 -
.../catalyst/expressions/aggregate/Average.scala | 2 +-
.../execution/datasources/DataSourceStrategy.scala | 29 +++++-
.../execution/datasources/v2/PushDownUtils.scala | 36 +------
.../datasources/v2/V2ScanRelationPushDown.scala | 106 ++++++++++++++++-----
.../org/apache/spark/sql/jdbc/JdbcDialects.scala | 11 ++-
.../org/apache/spark/sql/jdbc/JDBCV2Suite.scala | 82 +++++++++++++++-
8 files changed, 249 insertions(+), 67 deletions(-)
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Avg.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Avg.java
new file mode 100644
index 0000000..5e10ec9
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Avg.java
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.expressions.aggregate;
+
+import org.apache.spark.annotation.Evolving;
+import org.apache.spark.sql.connector.expressions.NamedReference;
+
+/**
+ * An aggregate function that returns the mean of all the values in a group.
+ *
+ * @since 3.3.0
+ */
+@Evolving
+public final class Avg implements AggregateFunc {
+ private final NamedReference column;
+ private final boolean isDistinct;
+
+ public Avg(NamedReference column, boolean isDistinct) {
+ this.column = column;
+ this.isDistinct = isDistinct;
+ }
+
+ public NamedReference column() { return column; }
+ public boolean isDistinct() { return isDistinct; }
+
+ @Override
+ public String toString() {
+ if (isDistinct) {
+ return "AVG(DISTINCT " + column.describe() + ")";
+ } else {
+ return "AVG(" + column.describe() + ")";
+ }
+ }
+}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java
index 32615e2..0ff26c8 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java
@@ -31,7 +31,6 @@ import org.apache.spark.sql.connector.expressions.NamedReference;
* <p>
* The currently supported SQL aggregate functions:
* <ol>
- * <li><pre>AVG(input1)</pre> Since 3.3.0</li>
* <li><pre>VAR_POP(input1)</pre> Since 3.3.0</li>
* <li><pre>VAR_SAMP(input1)</pre> Since 3.3.0</li>
* <li><pre>STDDEV_POP(input1)</pre> Since 3.3.0</li>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
index 9714a09..05f7eda 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
@@ -69,7 +69,7 @@ case class Average(
case _ => DoubleType
}
- private lazy val sumDataType = child.dataType match {
+ lazy val sumDataType = child.dataType match {
case _ @ DecimalType.Fixed(p, s) => DecimalType.bounded(p + 10, s)
case _: YearMonthIntervalType => YearMonthIntervalType()
case _: DayTimeIntervalType => DayTimeIntervalType()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index e734de3..ecde8a0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -41,7 +41,7 @@ import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2
import org.apache.spark.sql.connector.catalog.SupportsRead
import org.apache.spark.sql.connector.catalog.TableCapability._
import org.apache.spark.sql.connector.expressions.{FieldReference, NullOrdering, SortDirection, SortOrder => SortOrderV2, SortValue}
-import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Count, CountStar, GeneralAggregateFunc, Max, Min, Sum}
+import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation, Avg, Count, CountStar, GeneralAggregateFunc, Max, Min, Sum}
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.{InSubqueryExec, RowDataSourceScanExec, SparkPlan}
import org.apache.spark.sql.execution.command._
@@ -720,7 +720,7 @@ object DataSourceStrategy
case aggregate.Sum(PushableColumnWithoutNestedColumn(name), _) =>
Some(new Sum(FieldReference.column(name), agg.isDistinct))
case aggregate.Average(PushableColumnWithoutNestedColumn(name), _) =>
- Some(new GeneralAggregateFunc("AVG", agg.isDistinct, Array(FieldReference.column(name))))
+ Some(new Avg(FieldReference.column(name), agg.isDistinct))
case aggregate.VariancePop(PushableColumnWithoutNestedColumn(name), _) =>
Some(new GeneralAggregateFunc(
"VAR_POP", agg.isDistinct, Array(FieldReference.column(name))))
@@ -752,6 +752,31 @@ object DataSourceStrategy
}
}
+ /**
+ * Translate aggregate expressions and group by expressions.
+ *
+ * @return translated aggregation.
+ */
+ protected[sql] def translateAggregation(
+ aggregates: Seq[AggregateExpression], groupBy: Seq[Expression]): Option[Aggregation] = {
+
+ def columnAsString(e: Expression): Option[FieldReference] = e match {
+ case PushableColumnWithoutNestedColumn(name) =>
+ Some(FieldReference.column(name).asInstanceOf[FieldReference])
+ case _ => None
+ }
+
+ val translatedAggregates = aggregates.flatMap(translateAggregate)
+ val translatedGroupBys = groupBy.flatMap(columnAsString)
+
+ if (translatedAggregates.length != aggregates.length ||
+ translatedGroupBys.length != groupBy.length) {
+ return None
+ }
+
+ Some(new Aggregation(translatedAggregates.toArray, translatedGroupBys.toArray))
+ }
+
protected[sql] def translateSortOrders(sortOrders: Seq[SortOrder]): Seq[SortOrderV2] = {
def translateOortOrder(sortOrder: SortOrder): Option[SortOrderV2] = sortOrder match {
case SortOrder(PushableColumnWithoutNestedColumn(name), directionV1, nullOrderingV1, _) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala
index 29d86b6..9953658 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala
@@ -20,13 +20,11 @@ package org.apache.spark.sql.execution.datasources.v2
import scala.collection.mutable
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Expression, NamedExpression, PredicateHelper, SchemaPruning}
-import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
-import org.apache.spark.sql.connector.expressions.{FieldReference, SortOrder}
-import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
+import org.apache.spark.sql.connector.expressions.SortOrder
import org.apache.spark.sql.connector.expressions.filter.{Filter => V2Filter}
-import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownLimit, SupportsPushDownRequiredColumns, SupportsPushDownTableSample, SupportsPushDownTopN, SupportsPushDownV2Filters}
-import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, PushableColumnWithoutNestedColumn}
+import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownFilters, SupportsPushDownLimit, SupportsPushDownRequiredColumns, SupportsPushDownTableSample, SupportsPushDownTopN, SupportsPushDownV2Filters}
+import org.apache.spark.sql.execution.datasources.DataSourceStrategy
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources
import org.apache.spark.sql.types.StructType
@@ -107,34 +105,6 @@ object PushDownUtils extends PredicateHelper {
}
/**
- * Pushes down aggregates to the data source reader
- *
- * @return pushed aggregation.
- */
- def pushAggregates(
- scanBuilder: SupportsPushDownAggregates,
- aggregates: Seq[AggregateExpression],
- groupBy: Seq[Expression]): Option[Aggregation] = {
-
- def columnAsString(e: Expression): Option[FieldReference] = e match {
- case PushableColumnWithoutNestedColumn(name) =>
- Some(FieldReference.column(name).asInstanceOf[FieldReference])
- case _ => None
- }
-
- val translatedAggregates = aggregates.flatMap(DataSourceStrategy.translateAggregate)
- val translatedGroupBys = groupBy.flatMap(columnAsString)
-
- if (translatedAggregates.length != aggregates.length ||
- translatedGroupBys.length != groupBy.length) {
- return None
- }
-
- val agg = new Aggregation(translatedAggregates.toArray, translatedGroupBys.toArray)
- Some(agg).filter(scanBuilder.pushAggregation)
- }
-
- /**
* Pushes down TableSample to the data source Scan
*/
def pushTableSample(scanBuilder: ScanBuilder, sample: TableSampleInfo): Boolean = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
index 3437dcb..3ff9176 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
@@ -19,18 +19,18 @@ package org.apache.spark.sql.execution.datasources.v2
import scala.collection.mutable
-import org.apache.spark.sql.catalyst.expressions.{Alias, And, Attribute, AttributeReference, Cast, Expression, IntegerLiteral, NamedExpression, PredicateHelper, ProjectionOverSchema, SubqueryExpression}
+import org.apache.spark.sql.catalyst.expressions.{Alias, And, Attribute, AttributeReference, Cast, Divide, DivideDTInterval, DivideYMInterval, EqualTo, Expression, If, IntegerLiteral, Literal, NamedExpression, PredicateHelper, ProjectionOverSchema, SubqueryExpression}
import org.apache.spark.sql.catalyst.expressions.aggregate
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.planning.ScanOperation
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LeafNode, Limit, LocalLimit, LogicalPlan, Project, Sample, Sort}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.connector.expressions.SortOrder
-import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, GeneralAggregateFunc}
+import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, Avg, GeneralAggregateFunc}
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, V1Scan}
import org.apache.spark.sql.execution.datasources.DataSourceStrategy
import org.apache.spark.sql.sources
-import org.apache.spark.sql.types.{DataType, LongType, StructType}
+import org.apache.spark.sql.types.{DataType, DayTimeIntervalType, LongType, StructType, YearMonthIntervalType}
import org.apache.spark.sql.util.SchemaUtils._
object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
@@ -97,25 +97,66 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
sHolder.builder match {
case r: SupportsPushDownAggregates =>
val aggExprToOutputOrdinal = mutable.HashMap.empty[Expression, Int]
- var ordinal = 0
- val aggregates = resultExpressions.flatMap { expr =>
- expr.collect {
- // Do not push down duplicated aggregate expressions. For example,
- // `SELECT max(a) + 1, max(a) + 2 FROM ...`, we should only push down one
- // `max(a)` to the data source.
- case agg: AggregateExpression
- if !aggExprToOutputOrdinal.contains(agg.canonicalized) =>
- aggExprToOutputOrdinal(agg.canonicalized) = ordinal
- ordinal += 1
- agg
- }
- }
+ val aggregates = collectAggregates(resultExpressions, aggExprToOutputOrdinal)
val normalizedAggregates = DataSourceStrategy.normalizeExprs(
aggregates, sHolder.relation.output).asInstanceOf[Seq[AggregateExpression]]
val normalizedGroupingExpressions = DataSourceStrategy.normalizeExprs(
groupingExpressions, sHolder.relation.output)
- val pushedAggregates = PushDownUtils.pushAggregates(
- r, normalizedAggregates, normalizedGroupingExpressions)
+ val translatedAggregates = DataSourceStrategy.translateAggregation(
+ normalizedAggregates, normalizedGroupingExpressions)
+ val (finalResultExpressions, finalAggregates, finalTranslatedAggregates) = {
+ if (translatedAggregates.isEmpty ||
+ r.supportCompletePushDown(translatedAggregates.get) ||
+ translatedAggregates.get.aggregateExpressions().forall(!_.isInstanceOf[Avg])) {
+ (resultExpressions, aggregates, translatedAggregates)
+ } else {
+ // scalastyle:off
+ // The data source doesn't support the complete push-down of this aggregation.
+ // Here we translate `AVG` to `SUM / COUNT`, so that it's more likely to be
+ // pushed, completely or partially.
+ // e.g. TABLE t (c1 INT, c2 INT, c3 INT)
+ // SELECT avg(c1) FROM t GROUP BY c2;
+ // The original logical plan is
+ // Aggregate [c2#10],[avg(c1#9) AS avg(c1)#19]
+ // +- ScanOperation[...]
+ //
+ // After converting avg(c1#9) to sum(c1#9)/count(c1#9)
+ // we have the following
+ // Aggregate [c2#10],[sum(c1#9)/count(c1#9) AS avg(c1)#19]
+ // +- ScanOperation[...]
+ // scalastyle:on
+ val newResultExpressions = resultExpressions.map { expr =>
+ expr.transform {
+ case AggregateExpression(avg: aggregate.Average, _, isDistinct, _, _) =>
+ val sum = aggregate.Sum(avg.child).toAggregateExpression(isDistinct)
+ val count = aggregate.Count(avg.child).toAggregateExpression(isDistinct)
+ // Closely follow `Average.evaluateExpression`
+ avg.dataType match {
+ case _: YearMonthIntervalType =>
+ If(EqualTo(count, Literal(0L)),
+ Literal(null, YearMonthIntervalType()), DivideYMInterval(sum, count))
+ case _: DayTimeIntervalType =>
+ If(EqualTo(count, Literal(0L)),
+ Literal(null, DayTimeIntervalType()), DivideDTInterval(sum, count))
+ case _ =>
+ // TODO deal with the overflow issue
+ Divide(addCastIfNeeded(sum, avg.dataType),
+ addCastIfNeeded(count, avg.dataType), false)
+ }
+ }
+ }.asInstanceOf[Seq[NamedExpression]]
+ // Because aggregate expressions changed, translate them again.
+ aggExprToOutputOrdinal.clear()
+ val newAggregates =
+ collectAggregates(newResultExpressions, aggExprToOutputOrdinal)
+ val newNormalizedAggregates = DataSourceStrategy.normalizeExprs(
+ newAggregates, sHolder.relation.output).asInstanceOf[Seq[AggregateExpression]]
+ (newResultExpressions, newAggregates, DataSourceStrategy.translateAggregation(
+ newNormalizedAggregates, normalizedGroupingExpressions))
+ }
+ }
+
+ val pushedAggregates = finalTranslatedAggregates.filter(r.pushAggregation)
if (pushedAggregates.isEmpty) {
aggNode // return original plan node
} else if (!supportPartialAggPushDown(pushedAggregates.get) &&
@@ -138,7 +179,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
// +- RelationV2[c2#10, min(c1)#21, max(c1)#22]
// scalastyle:on
val newOutput = scan.readSchema().toAttributes
- assert(newOutput.length == groupingExpressions.length + aggregates.length)
+ assert(newOutput.length == groupingExpressions.length + finalAggregates.length)
val groupAttrs = normalizedGroupingExpressions.zip(newOutput).map {
case (a: Attribute, b: Attribute) => b.withExprId(a.exprId)
case (_, b) => b
@@ -173,7 +214,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
Project(projectExpressions, scanRelation)
} else {
val plan = Aggregate(
- output.take(groupingExpressions.length), resultExpressions, scanRelation)
+ output.take(groupingExpressions.length), finalResultExpressions, scanRelation)
// scalastyle:off
// Change the optimized logical plan to reflect the pushed down aggregate
@@ -219,16 +260,33 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
}
}
+ private def collectAggregates(resultExpressions: Seq[NamedExpression],
+ aggExprToOutputOrdinal: mutable.HashMap[Expression, Int]): Seq[AggregateExpression] = {
+ var ordinal = 0
+ resultExpressions.flatMap { expr =>
+ expr.collect {
+ // Do not push down duplicated aggregate expressions. For example,
+ // `SELECT max(a) + 1, max(a) + 2 FROM ...`, we should only push down one
+ // `max(a)` to the data source.
+ case agg: AggregateExpression
+ if !aggExprToOutputOrdinal.contains(agg.canonicalized) =>
+ aggExprToOutputOrdinal(agg.canonicalized) = ordinal
+ ordinal += 1
+ agg
+ }
+ }
+ }
+
private def supportPartialAggPushDown(agg: Aggregation): Boolean = {
// We don't know the agg buffer of `GeneralAggregateFunc`, so can't do partial agg push down.
agg.aggregateExpressions().forall(!_.isInstanceOf[GeneralAggregateFunc])
}
- private def addCastIfNeeded(aggAttribute: AttributeReference, aggDataType: DataType) =
- if (aggAttribute.dataType == aggDataType) {
- aggAttribute
+ private def addCastIfNeeded(expression: Expression, expectedDataType: DataType) =
+ if (expression.dataType == expectedDataType) {
+ expression
} else {
- Cast(aggAttribute, aggDataType)
+ Cast(expression, expectedDataType)
}
def pruneColumns(plan: LogicalPlan): LogicalPlan = plan.transform {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
index 344842d..7b8b362 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
@@ -33,7 +33,7 @@ import org.apache.spark.sql.connector.catalog.TableChange
import org.apache.spark.sql.connector.catalog.TableChange._
import org.apache.spark.sql.connector.catalog.index.TableIndex
import org.apache.spark.sql.connector.expressions.NamedReference
-import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Count, CountStar, GeneralAggregateFunc, Max, Min, Sum}
+import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Avg, Count, CountStar, Max, Min, Sum}
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils}
import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo
@@ -220,10 +220,11 @@ abstract class JdbcDialect extends Serializable with Logging{
Some(s"SUM($distinct$column)")
case _: CountStar =>
Some("COUNT(*)")
- case f: GeneralAggregateFunc if f.name() == "AVG" =>
- assert(f.inputs().length == 1)
- val distinct = if (f.isDistinct) "DISTINCT " else ""
- Some(s"AVG($distinct${f.inputs().head})")
+ case avg: Avg =>
+ if (avg.column.fieldNames.length != 1) return None
+ val distinct = if (avg.isDistinct) "DISTINCT " else ""
+ val column = quoteIdentifier(avg.column.fieldNames.head)
+ Some(s"AVG($distinct$column)")
case _ => None
}
}
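For context, a minimal standalone sketch of what the new `Avg` branch compiles to (H2-style double-quote identifiers assumed; `quote` stands in for the dialect's `quoteIdentifier`, and `fieldNames` for `Avg.column().fieldNames()`):
```scala
object CompileAvgSketch {
  // Mirrors the Avg branch above without depending on Spark classes.
  private def quote(name: String): String = "\"" + name + "\""

  def compileAvg(fieldNames: Seq[String], isDistinct: Boolean): Option[String] = {
    if (fieldNames.length != 1) return None // nested columns are not compiled
    val distinct = if (isDistinct) "DISTINCT " else ""
    Some(s"AVG($distinct${quote(fieldNames.head)})")
  }

  def main(args: Array[String]): Unit = {
    println(compileAvg(Seq("SALARY"), isDistinct = false)) // Some(AVG("SALARY"))
    println(compileAvg(Seq("a", "b"), isDistinct = false)) // None
  }
}
```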
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
index c5e1a6a..eadc2fb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, Sort}
import org.apache.spark.sql.connector.expressions.{FieldReference, NullOrdering, SortDirection, SortValue}
import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, V1ScanWrapper}
import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog
-import org.apache.spark.sql.functions.{lit, sum, udf}
+import org.apache.spark.sql.functions.{avg, count, lit, sum, udf}
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.util.Utils
@@ -874,4 +874,84 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
checkAnswer(df, Seq(Row(2)))
// scalastyle:on
}
+
+ test("scan with aggregate push-down: complete push-down SUM, AVG, COUNT") {
+ val df = spark.read
+ .option("partitionColumn", "dept")
+ .option("lowerBound", "0")
+ .option("upperBound", "2")
+ .option("numPartitions", "1")
+ .table("h2.test.employee")
+ .agg(sum($"SALARY").as("sum"), avg($"SALARY").as("avg"), count($"SALARY").as("count"))
+ checkAggregateRemoved(df)
+ df.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregates: [SUM(SALARY), AVG(SALARY), COUNT(SALARY)]"
+ checkKeywordsExistsInExplain(df, expected_plan_fragment)
+ }
+ checkAnswer(df, Seq(Row(53000.00, 10600.000000, 5)))
+
+ val df2 = spark.read
+ .option("partitionColumn", "dept")
+ .option("lowerBound", "0")
+ .option("upperBound", "2")
+ .option("numPartitions", "1")
+ .table("h2.test.employee")
+ .groupBy($"name")
+ .agg(sum($"SALARY").as("sum"), avg($"SALARY").as("avg"), count($"SALARY").as("count"))
+ checkAggregateRemoved(df2)
+ df2.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregates: [SUM(SALARY), AVG(SALARY), COUNT(SALARY)]"
+ checkKeywordsExistsInExplain(df2, expected_plan_fragment)
+ }
+ checkAnswer(df2, Seq(
+ Row("alex", 12000.00, 12000.000000, 1),
+ Row("amy", 10000.00, 10000.000000, 1),
+ Row("cathy", 9000.00, 9000.000000, 1),
+ Row("david", 10000.00, 10000.000000, 1),
+ Row("jen", 12000.00, 12000.000000, 1)))
+ }
+
+ test("scan with aggregate push-down: partial push-down SUM, AVG, COUNT") {
+ val df = spark.read
+ .option("partitionColumn", "dept")
+ .option("lowerBound", "0")
+ .option("upperBound", "2")
+ .option("numPartitions", "2")
+ .table("h2.test.employee")
+ .agg(sum($"SALARY").as("sum"), avg($"SALARY").as("avg"), count($"SALARY").as("count"))
+ checkAggregateRemoved(df, false)
+ df.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregates: [SUM(SALARY), COUNT(SALARY)]"
+ checkKeywordsExistsInExplain(df, expected_plan_fragment)
+ }
+ checkAnswer(df, Seq(Row(53000.00, 10600.000000, 5)))
+
+ val df2 = spark.read
+ .option("partitionColumn", "dept")
+ .option("lowerBound", "0")
+ .option("upperBound", "2")
+ .option("numPartitions", "2")
+ .table("h2.test.employee")
+ .groupBy($"name")
+ .agg(sum($"SALARY").as("sum"), avg($"SALARY").as("avg"), count($"SALARY").as("count"))
+ checkAggregateRemoved(df2, false)
+ df2.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregates: [SUM(SALARY), COUNT(SALARY)]"
+ checkKeywordsExistsInExplain(df2, expected_plan_fragment)
+ }
+ checkAnswer(df2, Seq(
+ Row("alex", 12000.00, 12000.000000, 1),
+ Row("amy", 10000.00, 10000.000000, 1),
+ Row("cathy", 9000.00, 9000.000000, 1),
+ Row("david", 10000.00, 10000.000000, 1),
+ Row("jen", 12000.00, 12000.000000, 1)))
+ }
}