Posted to commits@spark.apache.org by vi...@apache.org on 2021/07/30 07:27:41 UTC

[spark] branch branch-3.2 updated: [SPARK-34952][SQL][FOLLOWUP] Simplify JDBC aggregate pushdown

This is an automated email from the ASF dual-hosted git repository.

viirya pushed a commit to branch branch-3.2
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.2 by this push:
     new f6bb75b  [SPARK-34952][SQL][FOLLOWUP] Simplify JDBC aggregate pushdown
f6bb75b is described below

commit f6bb75b0bcae8c0bccf361dfd3710ce5f17173d5
Author: Wenchen Fan <we...@databricks.com>
AuthorDate: Fri Jul 30 00:26:32 2021 -0700

    [SPARK-34952][SQL][FOLLOWUP] Simplify JDBC aggregate pushdown
    
    ### What changes were proposed in this pull request?
    
    This is a follow-up of https://github.com/apache/spark/pull/33352 to simplify the JDBC aggregate pushdown:
    1. We should get the schema of the aggregate query by asking the JDBC server, instead of calculating it ourselves. This simplifies the code a lot and is also more robust: the data type of SUM may vary across databases, so it is fragile to assume it always matches Spark's. A sketch of this schema-probing idea follows the list.
    2. Because of 1, we can now remove the `dataType` property from the public `Sum` expression.
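    
    The schema probing in point 1 boils down to running the pushed-down SELECT list with an always-false predicate and reading the driver's `ResultSetMetaData`, which is what `JDBCRDD.getQueryOutputSchema` does in the diff below. A minimal, self-contained sketch of the idea (the helper names here are hypothetical, and H2 is used only as an example database):
    
    ```scala
    import java.sql.{Connection, DriverManager}
    
    // Hypothetical helper (not the actual Spark code): ask the database itself for
    // the result schema of an aggregate query by running it with WHERE 1=0, so no
    // rows are fetched but the ResultSetMetaData is still populated.
    object AggregateSchemaProbe {
      final case class ColumnInfo(name: String, jdbcType: Int, precision: Int, scale: Int)
    
      def probe(
          conn: Connection,
          selectList: Seq[String],
          table: String,
          groupByClause: String): Seq[ColumnInfo] = {
        val sql = s"SELECT ${selectList.mkString(",")} FROM $table WHERE 1=0 $groupByClause"
        val stmt = conn.prepareStatement(sql)
        try {
          val meta = stmt.executeQuery().getMetaData
          (1 to meta.getColumnCount).map { i =>
            ColumnInfo(meta.getColumnLabel(i), meta.getColumnType(i),
              meta.getPrecision(i), meta.getScale(i))
          }
        } finally {
          stmt.close()
        }
      }
    
      def main(args: Array[String]): Unit = {
        // Assumes the H2 driver is on the classpath; any JDBC database works the same way.
        val conn = DriverManager.getConnection("jdbc:h2:mem:probe;DB_CLOSE_DELAY=-1")
        conn.createStatement().execute("CREATE TABLE employee (dept INT, salary DECIMAL(20, 2))")
        val cols = probe(conn, Seq("\"DEPT\"", "SUM(\"SALARY\")"), "employee", "GROUP BY \"DEPT\"")
        cols.foreach(println)  // the SUM column's type comes from the database, not from a guess
        conn.close()
      }
    }
    ```
    
    Whatever type the database picks for `SUM("SALARY")` (a DECIMAL with its own precision, a BIGINT, etc.) is what the scan reports, so Spark no longer has to second-guess it.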
    
    This PR also contains some small improvements:
    1. Spark should deduplicate the aggregate expressions before pushing them down (a sketch of the deduplication follows this list).
    2. Improve the `toString` of the public aggregate expressions so that they read more like SQL.
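    
    For the deduplication in point 1, the rule keeps a map from each aggregate's canonicalized form to the ordinal of the column it will read back, so `SELECT max(a) + 1, max(a) + 2` pushes a single `max(a)`. A minimal stand-alone sketch of that bookkeeping (toy classes, not Spark's Catalyst `AggregateExpression`):
    
    ```scala
    import scala.collection.mutable
    
    // Toy stand-in for AggregateExpression; `canonical` plays the role of
    // Catalyst's `canonicalized` form, and `toSQL` mirrors the new SQL-like toString.
    final case class Agg(func: String, column: String) {
      def canonical: (String, String) = (func.toUpperCase, column)
      def toSQL: String = s"${func.toUpperCase}($column)"
    }
    
    object DedupAggregates {
      /** Returns the deduplicated aggregates to push down, plus each input's output ordinal. */
      def dedup(aggs: Seq[Agg]): (Seq[Agg], Seq[Int]) = {
        val ordinalOf = mutable.LinkedHashMap.empty[(String, String), Int]
        val ordinals = aggs.map(agg => ordinalOf.getOrElseUpdate(agg.canonical, ordinalOf.size))
        val pushed = ordinalOf.keys.toSeq.map { case (f, c) => Agg(f, c) }
        (pushed, ordinals)
      }
    
      def main(args: Array[String]): Unit = {
        val (pushed, ordinals) = dedup(Seq(Agg("max", "a"), Agg("max", "a"), Agg("min", "b")))
        println(pushed.map(_.toSQL).mkString(", "))  // MAX(a), MIN(b)
        println(ordinals)                            // List(0, 0, 1)
      }
    }
    ```
    
    Duplicated occurrences all resolve to the same ordinal, which is how the plan rewrite wires both `max(a) + 1` and `max(a) + 2` to the single pushed-down `MAX(a)` column.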
    
    ### Why are the changes needed?
    
    Code and API simplification.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No. This API is not released yet.
    
    ### How was this patch tested?
    
    Existing tests.
    
    Closes #33579 from cloud-fan/dsv2.
    
    Authored-by: Wenchen Fan <we...@databricks.com>
    Signed-off-by: Liang-Chi Hsieh <vi...@gmail.com>
    (cherry picked from commit 387a251a682a596ba4156b7d12e6025762ebac85)
    Signed-off-by: Liang-Chi Hsieh <vi...@gmail.com>
---
 .../spark/sql/connector/expressions/Count.java     |  8 +-
 .../spark/sql/connector/expressions/CountStar.java |  2 +-
 .../spark/sql/connector/expressions/Max.java       |  2 +-
 .../spark/sql/connector/expressions/Min.java       |  2 +-
 .../spark/sql/connector/expressions/Sum.java       | 12 +--
 .../execution/datasources/DataSourceStrategy.scala |  3 +-
 .../sql/execution/datasources/jdbc/JDBCRDD.scala   | 45 +++++------
 .../execution/datasources/jdbc/JDBCRelation.scala  |  7 +-
 .../execution/datasources/v2/PushDownUtils.scala   |  2 +-
 .../datasources/v2/V2ScanRelationPushDown.scala    | 24 ++++--
 .../execution/datasources/v2/jdbc/JDBCScan.scala   | 12 ++-
 .../datasources/v2/jdbc/JDBCScanBuilder.scala      | 88 ++++++++++------------
 .../org/apache/spark/sql/jdbc/JDBCV2Suite.scala    | 30 ++++----
 13 files changed, 121 insertions(+), 116 deletions(-)

diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Count.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Count.java
index 0e28a93..fecde71 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Count.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Count.java
@@ -38,7 +38,13 @@ public final class Count implements AggregateFunc {
   public boolean isDistinct() { return isDistinct; }
 
   @Override
-  public String toString() { return "Count(" + column.describe() + "," + isDistinct + ")"; }
+  public String toString() {
+    if (isDistinct) {
+      return "COUNT(DISTINCT " + column.describe() + ")";
+    } else {
+      return "COUNT(" + column.describe() + ")";
+    }
+  }
 
   @Override
   public String describe() { return this.toString(); }
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/CountStar.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/CountStar.java
index 21a3564..8e799cd 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/CountStar.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/CountStar.java
@@ -31,7 +31,7 @@ public final class CountStar implements AggregateFunc {
   }
 
   @Override
-  public String toString() { return "CountStar()"; }
+  public String toString() { return "COUNT(*)"; }
 
   @Override
   public String describe() { return this.toString(); }
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Max.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Max.java
index d2ff6b2..3ce45ca 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Max.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Max.java
@@ -33,7 +33,7 @@ public final class Max implements AggregateFunc {
   public FieldReference column() { return column; }
 
   @Override
-  public String toString() { return "Max(" + column.describe() + ")"; }
+  public String toString() { return "MAX(" + column.describe() + ")"; }
 
   @Override
   public String describe() { return this.toString(); }
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Min.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Min.java
index efa8036..2449358 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Min.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Min.java
@@ -33,7 +33,7 @@ public final class Min implements AggregateFunc {
   public FieldReference column() { return column; }
 
   @Override
-  public String toString() { return "Min(" + column.describe() + ")"; }
+  public String toString() { return "MIN(" + column.describe() + ")"; }
 
   @Override
   public String describe() { return this.toString(); }
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Sum.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Sum.java
index e4e860e..345194f 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Sum.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Sum.java
@@ -18,7 +18,6 @@
 package org.apache.spark.sql.connector.expressions;
 
 import org.apache.spark.annotation.Evolving;
-import org.apache.spark.sql.types.DataType;
 
 /**
  * An aggregate function that returns the summation of all the values in a group.
@@ -28,22 +27,23 @@ import org.apache.spark.sql.types.DataType;
 @Evolving
 public final class Sum implements AggregateFunc {
   private final FieldReference column;
-  private final DataType dataType;
   private final boolean isDistinct;
 
-  public Sum(FieldReference column, DataType dataType, boolean isDistinct) {
+  public Sum(FieldReference column, boolean isDistinct) {
     this.column = column;
-    this.dataType = dataType;
     this.isDistinct = isDistinct;
   }
 
   public FieldReference column() { return column; }
-  public DataType dataType() { return dataType; }
   public boolean isDistinct() { return isDistinct; }
 
   @Override
   public String toString() {
-    return "Sum(" + column.describe() + "," + dataType + "," + isDistinct + ")";
+    if (isDistinct) {
+      return "SUM(DISTINCT " + column.describe() + ")";
+    } else {
+      return "SUM(" + column.describe() + ")";
+    }
   }
 
   @Override
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index 2f334de..81ecb2c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -714,8 +714,7 @@ object DataSourceStrategy
             case _ => None
           }
         case sum @ aggregate.Sum(PushableColumnWithoutNestedColumn(name), _) =>
-          Some(new Sum(FieldReference(name).asInstanceOf[FieldReference],
-            sum.dataType, aggregates.isDistinct))
+          Some(new Sum(FieldReference(name).asInstanceOf[FieldReference], aggregates.isDistinct))
         case _ => None
       }
     } else {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
index af6c407..c575e95 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
@@ -25,7 +25,7 @@ import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskCon
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.connector.expressions.{AggregateFunc, Count, CountStar, FieldReference, Max, Min, Sum}
+import org.apache.spark.sql.connector.expressions.{AggregateFunc, Count, CountStar, Max, Min, Sum}
 import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects}
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types._
@@ -54,9 +54,14 @@ object JDBCRDD extends Logging {
     val url = options.url
     val table = options.tableOrQuery
     val dialect = JdbcDialects.get(url)
+    getQueryOutputSchema(dialect.getSchemaQuery(table), options, dialect)
+  }
+
+  def getQueryOutputSchema(
+      query: String, options: JDBCOptions, dialect: JdbcDialect): StructType = {
     val conn: Connection = JdbcUtils.createConnectionFactory(options)()
     try {
-      val statement = conn.prepareStatement(dialect.getSchemaQuery(table))
+      val statement = conn.prepareStatement(query)
       try {
         statement.setQueryTimeout(options.queryTimeout)
         val rs = statement.executeQuery()
@@ -136,30 +141,30 @@ object JDBCRDD extends Logging {
 
   def compileAggregates(
       aggregates: Seq[AggregateFunc],
-      dialect: JdbcDialect): Seq[String] = {
+      dialect: JdbcDialect): Option[Seq[String]] = {
     def quote(colName: String): String = dialect.quoteIdentifier(colName)
 
-    aggregates.map {
+    Some(aggregates.map {
       case min: Min =>
-        assert(min.column.fieldNames.length == 1)
+        if (min.column.fieldNames.length != 1) return None
         s"MIN(${quote(min.column.fieldNames.head)})"
       case max: Max =>
-        assert(max.column.fieldNames.length == 1)
+        if (max.column.fieldNames.length != 1) return None
         s"MAX(${quote(max.column.fieldNames.head)})"
       case count: Count =>
-        assert(count.column.fieldNames.length == 1)
-        val distinct = if (count.isDistinct) "DISTINCT" else ""
+        if (count.column.fieldNames.length != 1) return None
+        val distinct = if (count.isDistinct) "DISTINCT " else ""
         val column = quote(count.column.fieldNames.head)
-        s"COUNT($distinct $column)"
+        s"COUNT($distinct$column)"
       case sum: Sum =>
-        assert(sum.column.fieldNames.length == 1)
-        val distinct = if (sum.isDistinct) "DISTINCT" else ""
+        if (sum.column.fieldNames.length != 1) return None
+        val distinct = if (sum.isDistinct) "DISTINCT " else ""
         val column = quote(sum.column.fieldNames.head)
-        s"SUM($distinct $column)"
+        s"SUM($distinct$column)"
       case _: CountStar =>
-        s"COUNT(1)"
-      case _ => ""
-    }
+        s"COUNT(*)"
+      case _ => return None
+    })
   }
 
   /**
@@ -185,7 +190,7 @@ object JDBCRDD extends Logging {
       parts: Array[Partition],
       options: JDBCOptions,
       outputSchema: Option[StructType] = None,
-      groupByColumns: Option[Array[FieldReference]] = None): RDD[InternalRow] = {
+      groupByColumns: Option[Array[String]] = None): RDD[InternalRow] = {
     val url = options.url
     val dialect = JdbcDialects.get(url)
     val quotedColumns = if (groupByColumns.isEmpty) {
@@ -221,7 +226,7 @@ private[jdbc] class JDBCRDD(
     partitions: Array[Partition],
     url: String,
     options: JDBCOptions,
-    groupByColumns: Option[Array[FieldReference]])
+    groupByColumns: Option[Array[String]])
   extends RDD[InternalRow](sc, Nil) {
 
   /**
@@ -266,10 +271,8 @@ private[jdbc] class JDBCRDD(
    */
   private def getGroupByClause: String = {
     if (groupByColumns.nonEmpty && groupByColumns.get.nonEmpty) {
-      assert(groupByColumns.get.forall(_.fieldNames.length == 1))
-      val dialect = JdbcDialects.get(url)
-      val quotedColumns = groupByColumns.get.map(c => dialect.quoteIdentifier(c.fieldNames.head))
-      s"GROUP BY ${quotedColumns.mkString(", ")}"
+      // The GROUP BY columns should already be quoted by the caller side.
+      s"GROUP BY ${groupByColumns.get.mkString(", ")}"
     } else {
       ""
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
index 5fb26d2..60d88b6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
@@ -27,7 +27,6 @@ import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession, SQLContext}
 import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, TimestampFormatter}
 import org.apache.spark.sql.catalyst.util.DateTimeUtils.{getZoneId, stringToDate, stringToTimestamp}
-import org.apache.spark.sql.connector.expressions.FieldReference
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.jdbc.JdbcDialects
@@ -291,9 +290,9 @@ private[sql] case class JDBCRelation(
 
   def buildScan(
       requiredColumns: Array[String],
-      requireSchema: Option[StructType],
+      finalSchema: StructType,
       filters: Array[Filter],
-      groupByColumns: Option[Array[FieldReference]]): RDD[Row] = {
+      groupByColumns: Option[Array[String]]): RDD[Row] = {
     // Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row]
     JDBCRDD.scanTable(
       sparkSession.sparkContext,
@@ -302,7 +301,7 @@ private[sql] case class JDBCRelation(
       filters,
       parts,
       jdbcOptions,
-      requireSchema,
+      Some(finalSchema),
       groupByColumns).asInstanceOf[RDD[Row]]
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala
index 34b6431..6eedeba 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala
@@ -91,7 +91,7 @@ object PushDownUtils extends PredicateHelper {
     }
 
     scanBuilder match {
-      case r: SupportsPushDownAggregates =>
+      case r: SupportsPushDownAggregates if aggregates.nonEmpty =>
         val translatedAggregates = aggregates.flatMap(DataSourceStrategy.translateAggregate)
         val translatedGroupBys = groupBy.flatMap(columnAsString)
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
index a1fc981..d05519b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.execution.datasources.v2
 
+import scala.collection.mutable
+
 import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeReference, Expression, NamedExpression, PredicateHelper, ProjectionOverSchema, SubqueryExpression}
 import org.apache.spark.sql.catalyst.expressions.aggregate
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
@@ -76,9 +78,18 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
           if filters.isEmpty && project.forall(_.isInstanceOf[AttributeReference]) =>
           sHolder.builder match {
             case _: SupportsPushDownAggregates =>
+              val aggExprToOutputOrdinal = mutable.HashMap.empty[Expression, Int]
+              var ordinal = 0
               val aggregates = resultExpressions.flatMap { expr =>
                 expr.collect {
-                  case agg: AggregateExpression => agg
+                  // Do not push down duplicated aggregate expressions. For example,
+                  // `SELECT max(a) + 1, max(a) + 2 FROM ...`, we should only push down one
+                  // `max(a)` to the data source.
+                  case agg: AggregateExpression
+                      if !aggExprToOutputOrdinal.contains(agg.canonicalized) =>
+                    aggExprToOutputOrdinal(agg.canonicalized) = ordinal
+                    ordinal += 1
+                    agg
                 }
               }
               val pushedAggregates = PushDownUtils
@@ -144,19 +155,18 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
                 // Aggregate [c2#10], [min(min(c1)#21) AS min(c1)#17, max(max(c1)#22) AS max(c1)#18]
                 // +- RelationV2[c2#10, min(c1)#21, max(c1)#22] ...
                 // scalastyle:on
-                var i = 0
                 val aggOutput = output.drop(groupAttrs.length)
                 plan.transformExpressions {
                   case agg: AggregateExpression =>
+                    val ordinal = aggExprToOutputOrdinal(agg.canonicalized)
                     val aggFunction: aggregate.AggregateFunction =
                       agg.aggregateFunction match {
-                        case max: aggregate.Max => max.copy(child = aggOutput(i))
-                        case min: aggregate.Min => min.copy(child = aggOutput(i))
-                        case sum: aggregate.Sum => sum.copy(child = aggOutput(i))
-                        case _: aggregate.Count => aggregate.Sum(aggOutput(i))
+                        case max: aggregate.Max => max.copy(child = aggOutput(ordinal))
+                        case min: aggregate.Min => min.copy(child = aggOutput(ordinal))
+                        case sum: aggregate.Sum => sum.copy(child = aggOutput(ordinal))
+                        case _: aggregate.Count => aggregate.Sum(aggOutput(ordinal))
                         case other => other
                       }
-                    i += 1
                     agg.copy(aggregateFunction = aggFunction)
                 }
               }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala
index d6ae7c8..ef42691 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala
@@ -18,7 +18,6 @@ package org.apache.spark.sql.execution.datasources.v2.jdbc
 
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{Row, SQLContext}
-import org.apache.spark.sql.connector.expressions.FieldReference
 import org.apache.spark.sql.connector.read.V1Scan
 import org.apache.spark.sql.execution.datasources.jdbc.JDBCRelation
 import org.apache.spark.sql.sources.{BaseRelation, Filter, TableScan}
@@ -29,7 +28,7 @@ case class JDBCScan(
     prunedSchema: StructType,
     pushedFilters: Array[Filter],
     pushedAggregateColumn: Array[String] = Array(),
-    groupByColumns: Option[Array[FieldReference]]) extends V1Scan {
+    groupByColumns: Option[Array[String]]) extends V1Scan {
 
   override def readSchema(): StructType = prunedSchema
 
@@ -39,13 +38,12 @@ case class JDBCScan(
       override def schema: StructType = prunedSchema
       override def needConversion: Boolean = relation.needConversion
       override def buildScan(): RDD[Row] = {
-        if (groupByColumns.isEmpty) {
-          relation.buildScan(
-            prunedSchema.map(_.name).toArray, Some(prunedSchema), pushedFilters, groupByColumns)
+        val columnList = if (groupByColumns.isEmpty) {
+          prunedSchema.map(_.name).toArray
         } else {
-          relation.buildScan(
-            pushedAggregateColumn, Some(prunedSchema), pushedFilters, groupByColumns)
+          pushedAggregateColumn
         }
+        relation.buildScan(columnList, prunedSchema, pushedFilters, groupByColumns)
       }
     }.asInstanceOf[T]
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala
index afdc822..89fa621 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala
@@ -16,27 +16,33 @@
  */
 package org.apache.spark.sql.execution.datasources.v2.jdbc
 
+import scala.util.control.NonFatal
+
+import org.apache.spark.internal.Logging
 import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.connector.expressions.{Aggregation, Count, CountStar, FieldReference, Max, Min, Sum}
+import org.apache.spark.sql.connector.expressions.Aggregation
 import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownRequiredColumns}
 import org.apache.spark.sql.execution.datasources.PartitioningUtils
 import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCRDD, JDBCRelation}
 import org.apache.spark.sql.jdbc.JdbcDialects
 import org.apache.spark.sql.sources.Filter
-import org.apache.spark.sql.types.{LongType, StructField, StructType}
+import org.apache.spark.sql.types.StructType
 
 case class JDBCScanBuilder(
     session: SparkSession,
     schema: StructType,
     jdbcOptions: JDBCOptions)
-  extends ScanBuilder with SupportsPushDownFilters with SupportsPushDownRequiredColumns
-    with SupportsPushDownAggregates{
+  extends ScanBuilder
+    with SupportsPushDownFilters
+    with SupportsPushDownRequiredColumns
+    with SupportsPushDownAggregates
+    with Logging {
 
   private val isCaseSensitive = session.sessionState.conf.caseSensitiveAnalysis
 
   private var pushedFilter = Array.empty[Filter]
 
-  private var prunedSchema = schema
+  private var finalSchema = schema
 
   override def pushFilters(filters: Array[Filter]): Array[Filter] = {
     if (jdbcOptions.pushDownPredicate) {
@@ -51,56 +57,45 @@ case class JDBCScanBuilder(
 
   override def pushedFilters(): Array[Filter] = pushedFilter
 
-  private var pushedAggregations = Option.empty[Aggregation]
-
-  private var pushedAggregateColumn: Array[String] = Array()
+  private var pushedAggregateList: Array[String] = Array()
 
-  private def getStructFieldForCol(col: FieldReference): StructField =
-    schema.fields(schema.fieldNames.toList.indexOf(col.fieldNames.head))
+  private var pushedGroupByCols: Option[Array[String]] = None
 
   override def pushAggregation(aggregation: Aggregation): Boolean = {
     if (!jdbcOptions.pushDownAggregate) return false
 
     val dialect = JdbcDialects.get(jdbcOptions.url)
     val compiledAgg = JDBCRDD.compileAggregates(aggregation.aggregateExpressions, dialect)
+    if (compiledAgg.isEmpty) return false
 
-    var outputSchema = new StructType()
-    aggregation.groupByColumns.foreach { col =>
-      val structField = getStructFieldForCol(col)
-      outputSchema = outputSchema.add(structField)
-      pushedAggregateColumn = pushedAggregateColumn :+ dialect.quoteIdentifier(structField.name)
+    val groupByCols = aggregation.groupByColumns.map { col =>
+      if (col.fieldNames.length != 1) return false
+      dialect.quoteIdentifier(col.fieldNames.head)
     }
 
     // The column names here are already quoted and can be used to build sql string directly.
     // e.g. "DEPT","NAME",MAX("SALARY"),MIN("BONUS") =>
     // SELECT "DEPT","NAME",MAX("SALARY"),MIN("BONUS") FROM "test"."employee"
     //   GROUP BY "DEPT", "NAME"
-    pushedAggregateColumn = pushedAggregateColumn ++ compiledAgg
-
-    aggregation.aggregateExpressions.foreach {
-      case max: Max =>
-        val structField = getStructFieldForCol(max.column)
-        outputSchema = outputSchema.add(structField.copy("max(" + structField.name + ")"))
-      case min: Min =>
-        val structField = getStructFieldForCol(min.column)
-        outputSchema = outputSchema.add(structField.copy("min(" + structField.name + ")"))
-      case count: Count =>
-        val distinct = if (count.isDistinct) "DISTINCT " else ""
-        val structField = getStructFieldForCol(count.column)
-        outputSchema =
-          outputSchema.add(StructField(s"count($distinct" + structField.name + ")", LongType))
-      case _: CountStar =>
-        outputSchema = outputSchema.add(StructField("count(*)", LongType))
-      case sum: Sum =>
-        val distinct = if (sum.isDistinct) "DISTINCT " else ""
-        val structField = getStructFieldForCol(sum.column)
-        outputSchema =
-          outputSchema.add(StructField(s"sum($distinct" + structField.name + ")", sum.dataType))
-      case _ => return false
+    val selectList = groupByCols ++ compiledAgg.get
+    val groupByClause = if (groupByCols.isEmpty) {
+      ""
+    } else {
+      "GROUP BY " + groupByCols.mkString(",")
+    }
+
+    val aggQuery = s"SELECT ${selectList.mkString(",")} FROM ${jdbcOptions.tableOrQuery} " +
+      s"WHERE 1=0 $groupByClause"
+    try {
+      finalSchema = JDBCRDD.getQueryOutputSchema(aggQuery, jdbcOptions, dialect)
+      pushedAggregateList = selectList
+      pushedGroupByCols = Some(groupByCols)
+      true
+    } catch {
+      case NonFatal(e) =>
+        logError("Failed to push down aggregation to JDBC", e)
+        false
     }
-    this.pushedAggregations = Some(aggregation)
-    prunedSchema = outputSchema
-    true
   }
 
   override def pruneColumns(requiredSchema: StructType): Unit = {
@@ -112,7 +107,7 @@ case class JDBCScanBuilder(
       val colName = PartitioningUtils.getColName(field, isCaseSensitive)
       requiredCols.contains(colName)
     }
-    prunedSchema = StructType(fields)
+    finalSchema = StructType(fields)
   }
 
   override def build(): Scan = {
@@ -120,19 +115,14 @@ case class JDBCScanBuilder(
     val timeZoneId = session.sessionState.conf.sessionLocalTimeZone
     val parts = JDBCRelation.columnPartition(schema, resolver, timeZoneId, jdbcOptions)
 
-    // in prunedSchema, the schema is either pruned in pushAggregation (if aggregates are
+    // the `finalSchema` is either pruned in pushAggregation (if aggregates are
     // pushed down), or pruned in pruneColumns (in regular column pruning). These
     // two are mutual exclusive.
     // For aggregate push down case, we want to pass down the quoted column lists such as
     // "DEPT","NAME",MAX("SALARY"),MIN("BONUS"), instead of getting column names from
     // prunedSchema and quote them (will become "MAX(SALARY)", "MIN(BONUS)" and can't
     // be used in sql string.
-    val groupByColumns = if (pushedAggregations.nonEmpty) {
-      Some(pushedAggregations.get.groupByColumns)
-    } else {
-      Option.empty[Array[FieldReference]]
-    }
-    JDBCScan(JDBCRelation(schema, parts, jdbcOptions)(session), prunedSchema, pushedFilter,
-      pushedAggregateColumn, groupByColumns)
+    JDBCScan(JDBCRelation(schema, parts, jdbcOptions)(session), finalSchema, pushedFilter,
+      pushedAggregateList, pushedGroupByCols)
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
index 8dfb6de..37bc352 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
@@ -248,7 +248,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     df.queryExecution.optimizedPlan.collect {
       case _: DataSourceV2ScanRelation =>
         val expected_plan_fragment =
-          "PushedAggregates: [Max(SALARY), Min(BONUS)], " +
+          "PushedAggregates: [MAX(SALARY), MIN(BONUS)], " +
             "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " +
             "PushedGroupby: [DEPT]"
         checkKeywordsExistsInExplain(df, expected_plan_fragment)
@@ -265,7 +265,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     df.queryExecution.optimizedPlan.collect {
       case _: DataSourceV2ScanRelation =>
         val expected_plan_fragment =
-          "PushedAggregates: [Max(ID), Min(ID)], " +
+          "PushedAggregates: [MAX(ID), MIN(ID)], " +
             "PushedFilters: [IsNotNull(ID), GreaterThan(ID,0)], " +
             "PushedGroupby: []"
         checkKeywordsExistsInExplain(df, expected_plan_fragment)
@@ -278,7 +278,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     df.queryExecution.optimizedPlan.collect {
       case _: DataSourceV2ScanRelation =>
         val expected_plan_fragment =
-          "PushedAggregates: [Max(SALARY)]"
+          "PushedAggregates: [MAX(SALARY)]"
         checkKeywordsExistsInExplain(df, expected_plan_fragment)
     }
     checkAnswer(df, Seq(Row(12001)))
@@ -289,7 +289,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     df.queryExecution.optimizedPlan.collect {
       case _: DataSourceV2ScanRelation =>
         val expected_plan_fragment =
-          "PushedAggregates: [CountStar()]"
+          "PushedAggregates: [COUNT(*)]"
         checkKeywordsExistsInExplain(df, expected_plan_fragment)
     }
     checkAnswer(df, Seq(Row(5)))
@@ -300,7 +300,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     df.queryExecution.optimizedPlan.collect {
       case _: DataSourceV2ScanRelation =>
         val expected_plan_fragment =
-          "PushedAggregates: [Count(DEPT,false)]"
+          "PushedAggregates: [COUNT(DEPT)]"
         checkKeywordsExistsInExplain(df, expected_plan_fragment)
     }
     checkAnswer(df, Seq(Row(5)))
@@ -311,7 +311,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     df.queryExecution.optimizedPlan.collect {
       case _: DataSourceV2ScanRelation =>
         val expected_plan_fragment =
-          "PushedAggregates: [Count(DEPT,true)]"
+          "PushedAggregates: [COUNT(DISTINCT DEPT)]"
         checkKeywordsExistsInExplain(df, expected_plan_fragment)
     }
     checkAnswer(df, Seq(Row(3)))
@@ -322,7 +322,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     df.queryExecution.optimizedPlan.collect {
       case _: DataSourceV2ScanRelation =>
         val expected_plan_fragment =
-          "PushedAggregates: [Sum(SALARY,DecimalType(30,2),false)]"
+          "PushedAggregates: [SUM(SALARY)]"
         checkKeywordsExistsInExplain(df, expected_plan_fragment)
     }
     checkAnswer(df, Seq(Row(53000)))
@@ -333,7 +333,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     df.queryExecution.optimizedPlan.collect {
       case _: DataSourceV2ScanRelation =>
         val expected_plan_fragment =
-          "PushedAggregates: [Sum(SALARY,DecimalType(30,2),true)]"
+          "PushedAggregates: [SUM(DISTINCT SALARY)]"
         checkKeywordsExistsInExplain(df, expected_plan_fragment)
     }
     checkAnswer(df, Seq(Row(31000)))
@@ -344,7 +344,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     df.queryExecution.optimizedPlan.collect {
       case _: DataSourceV2ScanRelation =>
         val expected_plan_fragment =
-          "PushedAggregates: [Sum(SALARY,DecimalType(30,2),false)], " +
+          "PushedAggregates: [SUM(SALARY)], " +
             "PushedFilters: [], " +
             "PushedGroupby: [DEPT]"
         checkKeywordsExistsInExplain(df, expected_plan_fragment)
@@ -357,7 +357,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     df.queryExecution.optimizedPlan.collect {
       case _: DataSourceV2ScanRelation =>
         val expected_plan_fragment =
-          "PushedAggregates: [Sum(SALARY,DecimalType(30,2),true)], " +
+          "PushedAggregates: [SUM(DISTINCT SALARY)], " +
             "PushedFilters: [], " +
             "PushedGroupby: [DEPT]"
         checkKeywordsExistsInExplain(df, expected_plan_fragment)
@@ -375,7 +375,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     df.queryExecution.optimizedPlan.collect {
       case _: DataSourceV2ScanRelation =>
         val expected_plan_fragment =
-          "PushedAggregates: [Max(SALARY), Min(BONUS)], " +
+          "PushedAggregates: [MAX(SALARY), MIN(BONUS)], " +
             "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " +
             "PushedGroupby: [DEPT, NAME]"
         checkKeywordsExistsInExplain(df, expected_plan_fragment)
@@ -394,7 +394,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     df.queryExecution.optimizedPlan.collect {
       case _: DataSourceV2ScanRelation =>
         val expected_plan_fragment =
-          "PushedAggregates: [Max(SALARY), Min(BONUS)], " +
+          "PushedAggregates: [MAX(SALARY), MIN(BONUS)], " +
           "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " +
           "PushedGroupby: [DEPT]"
         checkKeywordsExistsInExplain(df, expected_plan_fragment)
@@ -409,7 +409,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     df.queryExecution.optimizedPlan.collect {
       case _: DataSourceV2ScanRelation =>
         val expected_plan_fragment =
-          "PushedAggregates: [Min(SALARY)], " +
+          "PushedAggregates: [MIN(SALARY)], " +
             "PushedFilters: [], " +
             "PushedGroupby: [DEPT]"
         checkKeywordsExistsInExplain(df, expected_plan_fragment)
@@ -432,7 +432,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     query.queryExecution.optimizedPlan.collect {
       case _: DataSourceV2ScanRelation =>
         val expected_plan_fragment =
-          "PushedAggregates: [Sum(SALARY,DecimalType(30,2),false)], " +
+          "PushedAggregates: [SUM(SALARY)], " +
             "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " +
             "PushedGroupby: [DEPT]"
         checkKeywordsExistsInExplain(query, expected_plan_fragment)
@@ -447,7 +447,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     query.queryExecution.optimizedPlan.collect {
       case _: DataSourceV2ScanRelation =>
         val expected_plan_fragment =
-          "PushedAggregates: [Sum(SALARY,DecimalType(30,2),false), Sum(BONUS,DoubleType,false)"
+          "PushedAggregates: [SUM(SALARY), SUM(BONUS)"
         checkKeywordsExistsInExplain(query, expected_plan_fragment)
     }
     checkAnswer(query, Seq(Row(47100.0)))

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org