You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2022/01/20 04:14:17 UTC

[spark] branch master updated: [SPARK-37839][SQL] DS V2 supports partial aggregate push-down `AVG`

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new fcc5c34  [SPARK-37839][SQL] DS V2 supports partial aggregate push-down `AVG`
fcc5c34 is described below

commit fcc5c34c546a45a32ababdd41932c620de9bc969
Author: Jiaan Geng <be...@163.com>
AuthorDate: Thu Jan 20 12:13:00 2022 +0800

    [SPARK-37839][SQL] DS V2 supports partial aggregate push-down `AVG`
    
    ### What changes were proposed in this pull request?
    `max`,`min`,`count`,`sum`,`avg` are the most commonly used aggregation functions.
    Currently, DS V2 supports only complete aggregate push-down of `avg`, but supporting partial aggregate push-down of `avg` is also very useful.
    
    The aggregate push-down algorithm is:
    
    1. Spark translates group expressions of `Aggregate` to DS V2 `Aggregation`.
    2. Spark calls `supportCompletePushDown` to check if it can completely push down aggregate.
    3. If `supportCompletePushDown` returns true, we preserve the aggregate expressions as the final aggregate expressions. Otherwise, we split `AVG` into 2 functions: `SUM` and `COUNT`.
    4. Spark translates final aggregate expressions and group expressions of `Aggregate` to DS V2 `Aggregation` again, and pushes the `Aggregation` to JDBC source.
    5. Spark constructs the final aggregate.
    
    ### Why are the changes needed?
    DS V2 supports partial aggregate push-down `AVG`
    
    ### Does this PR introduce _any_ user-facing change?
    'Yes'. DS V2 can now partially push down the `AVG` aggregate.
    
    ### How was this patch tested?
    New tests.
    
    Closes #35130 from beliefer/SPARK-37839.
    
    Authored-by: Jiaan Geng <be...@163.com>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 .../sql/connector/expressions/aggregate/Avg.java   |  49 ++++++++++
 .../aggregate/GeneralAggregateFunc.java            |   1 -
 .../catalyst/expressions/aggregate/Average.scala   |   2 +-
 .../execution/datasources/DataSourceStrategy.scala |  29 +++++-
 .../execution/datasources/v2/PushDownUtils.scala   |  36 +------
 .../datasources/v2/V2ScanRelationPushDown.scala    | 106 ++++++++++++++++-----
 .../org/apache/spark/sql/jdbc/JdbcDialects.scala   |  11 ++-
 .../org/apache/spark/sql/jdbc/JDBCV2Suite.scala    |  82 +++++++++++++++-
 8 files changed, 249 insertions(+), 67 deletions(-)

diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Avg.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Avg.java
new file mode 100644
index 0000000..5e10ec9
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Avg.java
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.expressions.aggregate;
+
+import org.apache.spark.annotation.Evolving;
+import org.apache.spark.sql.connector.expressions.NamedReference;
+
+/**
+ * An aggregate function that returns the mean of all the values in a group.
+ *
+ * @since 3.3.0
+ */
+@Evolving
+public final class Avg implements AggregateFunc {
+  private final NamedReference column;
+  private final boolean isDistinct;
+
+  public Avg(NamedReference column, boolean isDistinct) {
+    this.column = column;
+    this.isDistinct = isDistinct;
+  }
+
+  public NamedReference column() { return column; }
+  public boolean isDistinct() { return isDistinct; }
+
+  @Override
+  public String toString() {
+    if (isDistinct) {
+      return "AVG(DISTINCT " + column.describe() + ")";
+    } else {
+      return "AVG(" + column.describe() + ")";
+    }
+  }
+}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java
index 32615e2..0ff26c8 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java
@@ -31,7 +31,6 @@ import org.apache.spark.sql.connector.expressions.NamedReference;
  * <p>
  * The currently supported SQL aggregate functions:
  * <ol>
- *  <li><pre>AVG(input1)</pre> Since 3.3.0</li>
  *  <li><pre>VAR_POP(input1)</pre> Since 3.3.0</li>
  *  <li><pre>VAR_SAMP(input1)</pre> Since 3.3.0</li>
  *  <li><pre>STDDEV_POP(input1)</pre> Since 3.3.0</li>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
index 9714a09..05f7eda 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
@@ -69,7 +69,7 @@ case class Average(
     case _ => DoubleType
   }
 
-  private lazy val sumDataType = child.dataType match {
+  lazy val sumDataType = child.dataType match {
     case _ @ DecimalType.Fixed(p, s) => DecimalType.bounded(p + 10, s)
     case _: YearMonthIntervalType => YearMonthIntervalType()
     case _: DayTimeIntervalType => DayTimeIntervalType()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index e734de3..ecde8a0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -41,7 +41,7 @@ import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2
 import org.apache.spark.sql.connector.catalog.SupportsRead
 import org.apache.spark.sql.connector.catalog.TableCapability._
 import org.apache.spark.sql.connector.expressions.{FieldReference, NullOrdering, SortDirection, SortOrder => SortOrderV2, SortValue}
-import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Count, CountStar, GeneralAggregateFunc, Max, Min, Sum}
+import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation, Avg, Count, CountStar, GeneralAggregateFunc, Max, Min, Sum}
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.execution.{InSubqueryExec, RowDataSourceScanExec, SparkPlan}
 import org.apache.spark.sql.execution.command._
@@ -720,7 +720,7 @@ object DataSourceStrategy
         case aggregate.Sum(PushableColumnWithoutNestedColumn(name), _) =>
           Some(new Sum(FieldReference.column(name), agg.isDistinct))
         case aggregate.Average(PushableColumnWithoutNestedColumn(name), _) =>
-          Some(new GeneralAggregateFunc("AVG", agg.isDistinct, Array(FieldReference.column(name))))
+          Some(new Avg(FieldReference.column(name), agg.isDistinct))
         case aggregate.VariancePop(PushableColumnWithoutNestedColumn(name), _) =>
           Some(new GeneralAggregateFunc(
             "VAR_POP", agg.isDistinct, Array(FieldReference.column(name))))
@@ -752,6 +752,31 @@ object DataSourceStrategy
     }
   }
 
+  /**
+   * Translate aggregate expressions and group by expressions.
+   *
+   * @return translated aggregation.
+   */
+  protected[sql] def translateAggregation(
+      aggregates: Seq[AggregateExpression], groupBy: Seq[Expression]): Option[Aggregation] = {
+
+    def columnAsString(e: Expression): Option[FieldReference] = e match {
+      case PushableColumnWithoutNestedColumn(name) =>
+        Some(FieldReference.column(name).asInstanceOf[FieldReference])
+      case _ => None
+    }
+
+    val translatedAggregates = aggregates.flatMap(translateAggregate)
+    val translatedGroupBys = groupBy.flatMap(columnAsString)
+
+    if (translatedAggregates.length != aggregates.length ||
+      translatedGroupBys.length != groupBy.length) {
+      return None
+    }
+
+    Some(new Aggregation(translatedAggregates.toArray, translatedGroupBys.toArray))
+  }
+
   protected[sql] def translateSortOrders(sortOrders: Seq[SortOrder]): Seq[SortOrderV2] = {
     def translateOortOrder(sortOrder: SortOrder): Option[SortOrderV2] = sortOrder match {
       case SortOrder(PushableColumnWithoutNestedColumn(name), directionV1, nullOrderingV1, _) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala
index 29d86b6..9953658 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala
@@ -20,13 +20,11 @@ package org.apache.spark.sql.execution.datasources.v2
 import scala.collection.mutable
 
 import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Expression, NamedExpression, PredicateHelper, SchemaPruning}
-import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.util.CharVarcharUtils
-import org.apache.spark.sql.connector.expressions.{FieldReference, SortOrder}
-import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
+import org.apache.spark.sql.connector.expressions.SortOrder
 import org.apache.spark.sql.connector.expressions.filter.{Filter => V2Filter}
-import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownLimit, SupportsPushDownRequiredColumns, SupportsPushDownTableSample, SupportsPushDownTopN, SupportsPushDownV2Filters}
-import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, PushableColumnWithoutNestedColumn}
+import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownFilters, SupportsPushDownLimit, SupportsPushDownRequiredColumns, SupportsPushDownTableSample, SupportsPushDownTopN, SupportsPushDownV2Filters}
+import org.apache.spark.sql.execution.datasources.DataSourceStrategy
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources
 import org.apache.spark.sql.types.StructType
@@ -107,34 +105,6 @@ object PushDownUtils extends PredicateHelper {
   }
 
   /**
-   * Pushes down aggregates to the data source reader
-   *
-   * @return pushed aggregation.
-   */
-  def pushAggregates(
-      scanBuilder: SupportsPushDownAggregates,
-      aggregates: Seq[AggregateExpression],
-      groupBy: Seq[Expression]): Option[Aggregation] = {
-
-    def columnAsString(e: Expression): Option[FieldReference] = e match {
-      case PushableColumnWithoutNestedColumn(name) =>
-        Some(FieldReference.column(name).asInstanceOf[FieldReference])
-      case _ => None
-    }
-
-    val translatedAggregates = aggregates.flatMap(DataSourceStrategy.translateAggregate)
-    val translatedGroupBys = groupBy.flatMap(columnAsString)
-
-    if (translatedAggregates.length != aggregates.length ||
-      translatedGroupBys.length != groupBy.length) {
-      return None
-    }
-
-    val agg = new Aggregation(translatedAggregates.toArray, translatedGroupBys.toArray)
-    Some(agg).filter(scanBuilder.pushAggregation)
-  }
-
-  /**
    * Pushes down TableSample to the data source Scan
    */
   def pushTableSample(scanBuilder: ScanBuilder, sample: TableSampleInfo): Boolean = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
index 3437dcb..3ff9176 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
@@ -19,18 +19,18 @@ package org.apache.spark.sql.execution.datasources.v2
 
 import scala.collection.mutable
 
-import org.apache.spark.sql.catalyst.expressions.{Alias, And, Attribute, AttributeReference, Cast, Expression, IntegerLiteral, NamedExpression, PredicateHelper, ProjectionOverSchema, SubqueryExpression}
+import org.apache.spark.sql.catalyst.expressions.{Alias, And, Attribute, AttributeReference, Cast, Divide, DivideDTInterval, DivideYMInterval, EqualTo, Expression, If, IntegerLiteral, Literal, NamedExpression, PredicateHelper, ProjectionOverSchema, SubqueryExpression}
 import org.apache.spark.sql.catalyst.expressions.aggregate
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.planning.ScanOperation
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LeafNode, Limit, LocalLimit, LogicalPlan, Project, Sample, Sort}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.connector.expressions.SortOrder
-import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, GeneralAggregateFunc}
+import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, Avg, GeneralAggregateFunc}
 import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, V1Scan}
 import org.apache.spark.sql.execution.datasources.DataSourceStrategy
 import org.apache.spark.sql.sources
-import org.apache.spark.sql.types.{DataType, LongType, StructType}
+import org.apache.spark.sql.types.{DataType, DayTimeIntervalType, LongType, StructType, YearMonthIntervalType}
 import org.apache.spark.sql.util.SchemaUtils._
 
 object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
@@ -97,25 +97,66 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
           sHolder.builder match {
             case r: SupportsPushDownAggregates =>
               val aggExprToOutputOrdinal = mutable.HashMap.empty[Expression, Int]
-              var ordinal = 0
-              val aggregates = resultExpressions.flatMap { expr =>
-                expr.collect {
-                  // Do not push down duplicated aggregate expressions. For example,
-                  // `SELECT max(a) + 1, max(a) + 2 FROM ...`, we should only push down one
-                  // `max(a)` to the data source.
-                  case agg: AggregateExpression
-                      if !aggExprToOutputOrdinal.contains(agg.canonicalized) =>
-                    aggExprToOutputOrdinal(agg.canonicalized) = ordinal
-                    ordinal += 1
-                    agg
-                }
-              }
+              val aggregates = collectAggregates(resultExpressions, aggExprToOutputOrdinal)
               val normalizedAggregates = DataSourceStrategy.normalizeExprs(
                 aggregates, sHolder.relation.output).asInstanceOf[Seq[AggregateExpression]]
               val normalizedGroupingExpressions = DataSourceStrategy.normalizeExprs(
                 groupingExpressions, sHolder.relation.output)
-              val pushedAggregates = PushDownUtils.pushAggregates(
-                r, normalizedAggregates, normalizedGroupingExpressions)
+              val translatedAggregates = DataSourceStrategy.translateAggregation(
+                normalizedAggregates, normalizedGroupingExpressions)
+              val (finalResultExpressions, finalAggregates, finalTranslatedAggregates) = {
+                if (translatedAggregates.isEmpty ||
+                  r.supportCompletePushDown(translatedAggregates.get) ||
+                  translatedAggregates.get.aggregateExpressions().forall(!_.isInstanceOf[Avg])) {
+                  (resultExpressions, aggregates, translatedAggregates)
+                } else {
+                  // scalastyle:off
+                  // The data source doesn't support the complete push-down of this aggregation.
+                  // Here we translate `AVG` to `SUM / COUNT`, so that it's more likely to be
+                  // pushed, completely or partially.
+                  // e.g. TABLE t (c1 INT, c2 INT, c3 INT)
+                  // SELECT avg(c1) FROM t GROUP BY c2;
+                  // The original logical plan is
+                  // Aggregate [c2#10],[avg(c1#9) AS avg(c1)#19]
+                  // +- ScanOperation[...]
+                  //
+                  // After convert avg(c1#9) to sum(c1#9)/count(c1#9)
+                  // we have the following
+                  // Aggregate [c2#10],[sum(c1#9)/count(c1#9) AS avg(c1)#19]
+                  // +- ScanOperation[...]
+                  // scalastyle:on
+                  val newResultExpressions = resultExpressions.map { expr =>
+                    expr.transform {
+                      case AggregateExpression(avg: aggregate.Average, _, isDistinct, _, _) =>
+                        val sum = aggregate.Sum(avg.child).toAggregateExpression(isDistinct)
+                        val count = aggregate.Count(avg.child).toAggregateExpression(isDistinct)
+                        // Closely follow `Average.evaluateExpression`
+                        avg.dataType match {
+                          case _: YearMonthIntervalType =>
+                            If(EqualTo(count, Literal(0L)),
+                              Literal(null, YearMonthIntervalType()), DivideYMInterval(sum, count))
+                          case _: DayTimeIntervalType =>
+                            If(EqualTo(count, Literal(0L)),
+                              Literal(null, DayTimeIntervalType()), DivideDTInterval(sum, count))
+                          case _ =>
+                            // TODO deal with the overflow issue
+                            Divide(addCastIfNeeded(sum, avg.dataType),
+                              addCastIfNeeded(count, avg.dataType), false)
+                        }
+                    }
+                  }.asInstanceOf[Seq[NamedExpression]]
+                  // Because aggregate expressions changed, translate them again.
+                  aggExprToOutputOrdinal.clear()
+                  val newAggregates =
+                    collectAggregates(newResultExpressions, aggExprToOutputOrdinal)
+                  val newNormalizedAggregates = DataSourceStrategy.normalizeExprs(
+                    newAggregates, sHolder.relation.output).asInstanceOf[Seq[AggregateExpression]]
+                  (newResultExpressions, newAggregates, DataSourceStrategy.translateAggregation(
+                    newNormalizedAggregates, normalizedGroupingExpressions))
+                }
+              }
+
+              val pushedAggregates = finalTranslatedAggregates.filter(r.pushAggregation)
               if (pushedAggregates.isEmpty) {
                 aggNode // return original plan node
               } else if (!supportPartialAggPushDown(pushedAggregates.get) &&
@@ -138,7 +179,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
                 // +- RelationV2[c2#10, min(c1)#21, max(c1)#22]
                 // scalastyle:on
                 val newOutput = scan.readSchema().toAttributes
-                assert(newOutput.length == groupingExpressions.length + aggregates.length)
+                assert(newOutput.length == groupingExpressions.length + finalAggregates.length)
                 val groupAttrs = normalizedGroupingExpressions.zip(newOutput).map {
                   case (a: Attribute, b: Attribute) => b.withExprId(a.exprId)
                   case (_, b) => b
@@ -173,7 +214,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
                   Project(projectExpressions, scanRelation)
                 } else {
                   val plan = Aggregate(
-                    output.take(groupingExpressions.length), resultExpressions, scanRelation)
+                    output.take(groupingExpressions.length), finalResultExpressions, scanRelation)
 
                   // scalastyle:off
                   // Change the optimized logical plan to reflect the pushed down aggregate
@@ -219,16 +260,33 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
       }
   }
 
+  private def collectAggregates(resultExpressions: Seq[NamedExpression],
+      aggExprToOutputOrdinal: mutable.HashMap[Expression, Int]): Seq[AggregateExpression] = {
+    var ordinal = 0
+    resultExpressions.flatMap { expr =>
+      expr.collect {
+        // Do not push down duplicated aggregate expressions. For example,
+        // `SELECT max(a) + 1, max(a) + 2 FROM ...`, we should only push down one
+        // `max(a)` to the data source.
+        case agg: AggregateExpression
+          if !aggExprToOutputOrdinal.contains(agg.canonicalized) =>
+          aggExprToOutputOrdinal(agg.canonicalized) = ordinal
+          ordinal += 1
+          agg
+      }
+    }
+  }
+
   private def supportPartialAggPushDown(agg: Aggregation): Boolean = {
     // We don't know the agg buffer of `GeneralAggregateFunc`, so can't do partial agg push down.
     agg.aggregateExpressions().forall(!_.isInstanceOf[GeneralAggregateFunc])
   }
 
-  private def addCastIfNeeded(aggAttribute: AttributeReference, aggDataType: DataType) =
-    if (aggAttribute.dataType == aggDataType) {
-      aggAttribute
+  private def addCastIfNeeded(expression: Expression, expectedDataType: DataType) =
+    if (expression.dataType == expectedDataType) {
+      expression
     } else {
-      Cast(aggAttribute, aggDataType)
+      Cast(expression, expectedDataType)
     }
 
   def pruneColumns(plan: LogicalPlan): LogicalPlan = plan.transform {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
index 344842d..7b8b362 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
@@ -33,7 +33,7 @@ import org.apache.spark.sql.connector.catalog.TableChange
 import org.apache.spark.sql.connector.catalog.TableChange._
 import org.apache.spark.sql.connector.catalog.index.TableIndex
 import org.apache.spark.sql.connector.expressions.NamedReference
-import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Count, CountStar, GeneralAggregateFunc, Max, Min, Sum}
+import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Avg, Count, CountStar, Max, Min, Sum}
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils}
 import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo
@@ -220,10 +220,11 @@ abstract class JdbcDialect extends Serializable with Logging{
         Some(s"SUM($distinct$column)")
       case _: CountStar =>
         Some("COUNT(*)")
-      case f: GeneralAggregateFunc if f.name() == "AVG" =>
-        assert(f.inputs().length == 1)
-        val distinct = if (f.isDistinct) "DISTINCT " else ""
-        Some(s"AVG($distinct${f.inputs().head})")
+      case avg: Avg =>
+        if (avg.column.fieldNames.length != 1) return None
+        val distinct = if (avg.isDistinct) "DISTINCT " else ""
+        val column = quoteIdentifier(avg.column.fieldNames.head)
+        Some(s"AVG($distinct$column)")
       case _ => None
     }
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
index c5e1a6a..eadc2fb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, Sort}
 import org.apache.spark.sql.connector.expressions.{FieldReference, NullOrdering, SortDirection, SortValue}
 import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, V1ScanWrapper}
 import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog
-import org.apache.spark.sql.functions.{lit, sum, udf}
+import org.apache.spark.sql.functions.{avg, count, lit, sum, udf}
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.util.Utils
 
@@ -874,4 +874,84 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     checkAnswer(df, Seq(Row(2)))
     // scalastyle:on
   }
+
+  test("scan with aggregate push-down: complete push-down SUM, AVG, COUNT") {
+    val df = spark.read
+      .option("partitionColumn", "dept")
+      .option("lowerBound", "0")
+      .option("upperBound", "2")
+      .option("numPartitions", "1")
+      .table("h2.test.employee")
+      .agg(sum($"SALARY").as("sum"), avg($"SALARY").as("avg"), count($"SALARY").as("count"))
+    checkAggregateRemoved(df)
+    df.queryExecution.optimizedPlan.collect {
+      case _: DataSourceV2ScanRelation =>
+        val expected_plan_fragment =
+          "PushedAggregates: [SUM(SALARY), AVG(SALARY), COUNT(SALARY)]"
+        checkKeywordsExistsInExplain(df, expected_plan_fragment)
+    }
+    checkAnswer(df, Seq(Row(53000.00, 10600.000000, 5)))
+
+    val df2 = spark.read
+      .option("partitionColumn", "dept")
+      .option("lowerBound", "0")
+      .option("upperBound", "2")
+      .option("numPartitions", "1")
+      .table("h2.test.employee")
+      .groupBy($"name")
+      .agg(sum($"SALARY").as("sum"), avg($"SALARY").as("avg"), count($"SALARY").as("count"))
+    checkAggregateRemoved(df2) // check the grouped query, not the ungrouped `df`
+    df2.queryExecution.optimizedPlan.collect {
+      case _: DataSourceV2ScanRelation =>
+        val expected_plan_fragment =
+          "PushedAggregates: [SUM(SALARY), AVG(SALARY), COUNT(SALARY)]"
+        checkKeywordsExistsInExplain(df2, expected_plan_fragment)
+    }
+    checkAnswer(df2, Seq(
+      Row("alex", 12000.00, 12000.000000, 1),
+      Row("amy", 10000.00, 10000.000000, 1),
+      Row("cathy", 9000.00, 9000.000000, 1),
+      Row("david", 10000.00, 10000.000000, 1),
+      Row("jen", 12000.00, 12000.000000, 1)))
+  }
+
+  test("scan with aggregate push-down: partial push-down SUM, AVG, COUNT") {
+    val df = spark.read
+      .option("partitionColumn", "dept")
+      .option("lowerBound", "0")
+      .option("upperBound", "2")
+      .option("numPartitions", "2")
+      .table("h2.test.employee")
+      .agg(sum($"SALARY").as("sum"), avg($"SALARY").as("avg"), count($"SALARY").as("count"))
+    checkAggregateRemoved(df, false)
+    df.queryExecution.optimizedPlan.collect {
+      case _: DataSourceV2ScanRelation =>
+        val expected_plan_fragment =
+          "PushedAggregates: [SUM(SALARY), COUNT(SALARY)]"
+        checkKeywordsExistsInExplain(df, expected_plan_fragment)
+    }
+    checkAnswer(df, Seq(Row(53000.00, 10600.000000, 5)))
+
+    val df2 = spark.read
+      .option("partitionColumn", "dept")
+      .option("lowerBound", "0")
+      .option("upperBound", "2")
+      .option("numPartitions", "2")
+      .table("h2.test.employee")
+      .groupBy($"name")
+      .agg(sum($"SALARY").as("sum"), avg($"SALARY").as("avg"), count($"SALARY").as("count"))
+    checkAggregateRemoved(df2, false) // check the grouped query, not the ungrouped `df`
+    df2.queryExecution.optimizedPlan.collect {
+      case _: DataSourceV2ScanRelation =>
+        val expected_plan_fragment =
+          "PushedAggregates: [SUM(SALARY), COUNT(SALARY)]"
+        checkKeywordsExistsInExplain(df2, expected_plan_fragment)
+    }
+    checkAnswer(df2, Seq(
+      Row("alex", 12000.00, 12000.000000, 1),
+      Row("amy", 10000.00, 10000.000000, 1),
+      Row("cathy", 9000.00, 9000.000000, 1),
+      Row("david", 10000.00, 10000.000000, 1),
+      Row("jen", 12000.00, 12000.000000, 1)))
+  }
 }

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org