Posted to commits@spark.apache.org by vi...@apache.org on 2021/07/30 07:27:41 UTC
[spark] branch branch-3.2 updated: [SPARK-34952][SQL][FOLLOWUP] Simplify JDBC aggregate pushdown
This is an automated email from the ASF dual-hosted git repository.
viirya pushed a commit to branch branch-3.2
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.2 by this push:
new f6bb75b [SPARK-34952][SQL][FOLLOWUP] Simplify JDBC aggregate pushdown
f6bb75b is described below
commit f6bb75b0bcae8c0bccf361dfd3710ce5f17173d5
Author: Wenchen Fan <we...@databricks.com>
AuthorDate: Fri Jul 30 00:26:32 2021 -0700
[SPARK-34952][SQL][FOLLOWUP] Simplify JDBC aggregate pushdown
### What changes were proposed in this pull request?
This is a followup of https://github.com/apache/spark/pull/33352, to simplify the JDBC aggregate pushdown:
1. We should get the schema of the aggregate query by asking the JDBC server, instead of calculating it ourselves. This simplifies the code a lot and is also more robust: the data type of SUM may vary across databases, so it is fragile to assume it always matches Spark's (see the sketch after this list).
2. Because of 1, we can now remove the `dataType` property from the public `Sum` expression.
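For illustration, here is a minimal, self-contained Scala sketch of the idea in point 1: execute the aggregate query with a `WHERE 1=0` predicate and read the result metadata, so the reported type of `SUM` is whatever the database says rather than what Spark guesses. The URL, table name, and the `SUM(SALARY)` column are placeholder assumptions; the actual implementation goes through `JDBCRDD.getQueryOutputSchema` in the diff below.

```scala
import java.sql.DriverManager

// Sketch only: ask the database for the output schema of an aggregate query
// instead of computing it in Spark. Returns (column name, java.sql.Types code).
def describeAggregateQuery(url: String, table: String): Seq[(String, Int)] = {
  val conn = DriverManager.getConnection(url)
  try {
    // `WHERE 1=0` returns no rows but still carries ResultSetMetaData.
    val stmt = conn.prepareStatement(s"SELECT SUM(SALARY) FROM $table WHERE 1=0")
    try {
      val md = stmt.executeQuery().getMetaData
      (1 to md.getColumnCount).map(i => md.getColumnName(i) -> md.getColumnType(i))
    } finally stmt.close()
  } finally conn.close()
}
```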
This PR also contains some small improvements:
1. Spark should deduplicate the aggregate expressions before pushing them down (see the sketch after this list).
2. Improve the `toString` of the public aggregate expressions to make them more SQL-like.
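A minimal sketch of improvement 1: deduplicate aggregates by their canonical form and remember the output ordinal each one maps to, so `SELECT max(a) + 1, max(a) + 2 FROM ...` pushes a single `max(a)` down. `Agg` and `canonical` below are stand-ins for Catalyst's `AggregateExpression` and `canonicalized`; the real logic lives in `V2ScanRelationPushDown` in the diff below.

```scala
import scala.collection.mutable

// Placeholder for Catalyst's AggregateExpression; `canonical` plays the role
// of `canonicalized`, i.e. semantically equal aggregates compare equal.
final case class Agg(canonical: String)

// Keep the first occurrence of each aggregate and record its output ordinal.
def dedupAggregates(aggs: Seq[Agg]): (Seq[Agg], Map[String, Int]) = {
  val ordinalOf = mutable.LinkedHashMap.empty[String, Int]
  val kept = mutable.ArrayBuffer.empty[Agg]
  aggs.foreach { agg =>
    if (!ordinalOf.contains(agg.canonical)) {
      ordinalOf(agg.canonical) = kept.length
      kept += agg
    }
  }
  (kept.toSeq, ordinalOf.toMap)
}

// dedupAggregates(Seq(Agg("max(a)"), Agg("max(a)")))
//   => (Seq(Agg("max(a)")), Map("max(a)" -> 0))
```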
### Why are the changes needed?
code and API simplification
### Does this PR introduce _any_ user-facing change?
No. This API is not released yet.
### How was this patch tested?
existing tests
Closes #33579 from cloud-fan/dsv2.
Authored-by: Wenchen Fan <we...@databricks.com>
Signed-off-by: Liang-Chi Hsieh <vi...@gmail.com>
(cherry picked from commit 387a251a682a596ba4156b7d12e6025762ebac85)
Signed-off-by: Liang-Chi Hsieh <vi...@gmail.com>
---
.../spark/sql/connector/expressions/Count.java | 8 +-
.../spark/sql/connector/expressions/CountStar.java | 2 +-
.../spark/sql/connector/expressions/Max.java | 2 +-
.../spark/sql/connector/expressions/Min.java | 2 +-
.../spark/sql/connector/expressions/Sum.java | 12 +--
.../execution/datasources/DataSourceStrategy.scala | 3 +-
.../sql/execution/datasources/jdbc/JDBCRDD.scala | 45 +++++------
.../execution/datasources/jdbc/JDBCRelation.scala | 7 +-
.../execution/datasources/v2/PushDownUtils.scala | 2 +-
.../datasources/v2/V2ScanRelationPushDown.scala | 24 ++++--
.../execution/datasources/v2/jdbc/JDBCScan.scala | 12 ++-
.../datasources/v2/jdbc/JDBCScanBuilder.scala | 88 ++++++++++------------
.../org/apache/spark/sql/jdbc/JDBCV2Suite.scala | 30 ++++----
13 files changed, 121 insertions(+), 116 deletions(-)
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Count.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Count.java
index 0e28a93..fecde71 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Count.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Count.java
@@ -38,7 +38,13 @@ public final class Count implements AggregateFunc {
public boolean isDistinct() { return isDistinct; }
@Override
- public String toString() { return "Count(" + column.describe() + "," + isDistinct + ")"; }
+ public String toString() {
+ if (isDistinct) {
+ return "COUNT(DISTINCT " + column.describe() + ")";
+ } else {
+ return "COUNT(" + column.describe() + ")";
+ }
+ }
@Override
public String describe() { return this.toString(); }
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/CountStar.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/CountStar.java
index 21a3564..8e799cd 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/CountStar.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/CountStar.java
@@ -31,7 +31,7 @@ public final class CountStar implements AggregateFunc {
}
@Override
- public String toString() { return "CountStar()"; }
+ public String toString() { return "COUNT(*)"; }
@Override
public String describe() { return this.toString(); }
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Max.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Max.java
index d2ff6b2..3ce45ca 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Max.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Max.java
@@ -33,7 +33,7 @@ public final class Max implements AggregateFunc {
public FieldReference column() { return column; }
@Override
- public String toString() { return "Max(" + column.describe() + ")"; }
+ public String toString() { return "MAX(" + column.describe() + ")"; }
@Override
public String describe() { return this.toString(); }
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Min.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Min.java
index efa8036..2449358 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Min.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Min.java
@@ -33,7 +33,7 @@ public final class Min implements AggregateFunc {
public FieldReference column() { return column; }
@Override
- public String toString() { return "Min(" + column.describe() + ")"; }
+ public String toString() { return "MIN(" + column.describe() + ")"; }
@Override
public String describe() { return this.toString(); }
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Sum.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Sum.java
index e4e860e..345194f 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Sum.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Sum.java
@@ -18,7 +18,6 @@
package org.apache.spark.sql.connector.expressions;
import org.apache.spark.annotation.Evolving;
-import org.apache.spark.sql.types.DataType;
/**
* An aggregate function that returns the summation of all the values in a group.
@@ -28,22 +27,23 @@ import org.apache.spark.sql.types.DataType;
@Evolving
public final class Sum implements AggregateFunc {
private final FieldReference column;
- private final DataType dataType;
private final boolean isDistinct;
- public Sum(FieldReference column, DataType dataType, boolean isDistinct) {
+ public Sum(FieldReference column, boolean isDistinct) {
this.column = column;
- this.dataType = dataType;
this.isDistinct = isDistinct;
}
public FieldReference column() { return column; }
- public DataType dataType() { return dataType; }
public boolean isDistinct() { return isDistinct; }
@Override
public String toString() {
- return "Sum(" + column.describe() + "," + dataType + "," + isDistinct + ")";
+ if (isDistinct) {
+ return "SUM(DISTINCT " + column.describe() + ")";
+ } else {
+ return "SUM(" + column.describe() + ")";
+ }
}
@Override
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index 2f334de..81ecb2c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -714,8 +714,7 @@ object DataSourceStrategy
case _ => None
}
case sum @ aggregate.Sum(PushableColumnWithoutNestedColumn(name), _) =>
- Some(new Sum(FieldReference(name).asInstanceOf[FieldReference],
- sum.dataType, aggregates.isDistinct))
+ Some(new Sum(FieldReference(name).asInstanceOf[FieldReference], aggregates.isDistinct))
case _ => None
}
} else {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
index af6c407..c575e95 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
@@ -25,7 +25,7 @@ import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskCon
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.connector.expressions.{AggregateFunc, Count, CountStar, FieldReference, Max, Min, Sum}
+import org.apache.spark.sql.connector.expressions.{AggregateFunc, Count, CountStar, Max, Min, Sum}
import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects}
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
@@ -54,9 +54,14 @@ object JDBCRDD extends Logging {
val url = options.url
val table = options.tableOrQuery
val dialect = JdbcDialects.get(url)
+ getQueryOutputSchema(dialect.getSchemaQuery(table), options, dialect)
+ }
+
+ def getQueryOutputSchema(
+ query: String, options: JDBCOptions, dialect: JdbcDialect): StructType = {
val conn: Connection = JdbcUtils.createConnectionFactory(options)()
try {
- val statement = conn.prepareStatement(dialect.getSchemaQuery(table))
+ val statement = conn.prepareStatement(query)
try {
statement.setQueryTimeout(options.queryTimeout)
val rs = statement.executeQuery()
@@ -136,30 +141,30 @@ object JDBCRDD extends Logging {
def compileAggregates(
aggregates: Seq[AggregateFunc],
- dialect: JdbcDialect): Seq[String] = {
+ dialect: JdbcDialect): Option[Seq[String]] = {
def quote(colName: String): String = dialect.quoteIdentifier(colName)
- aggregates.map {
+ Some(aggregates.map {
case min: Min =>
- assert(min.column.fieldNames.length == 1)
+ if (min.column.fieldNames.length != 1) return None
s"MIN(${quote(min.column.fieldNames.head)})"
case max: Max =>
- assert(max.column.fieldNames.length == 1)
+ if (max.column.fieldNames.length != 1) return None
s"MAX(${quote(max.column.fieldNames.head)})"
case count: Count =>
- assert(count.column.fieldNames.length == 1)
- val distinct = if (count.isDistinct) "DISTINCT" else ""
+ if (count.column.fieldNames.length != 1) return None
+ val distinct = if (count.isDistinct) "DISTINCT " else ""
val column = quote(count.column.fieldNames.head)
- s"COUNT($distinct $column)"
+ s"COUNT($distinct$column)"
case sum: Sum =>
- assert(sum.column.fieldNames.length == 1)
- val distinct = if (sum.isDistinct) "DISTINCT" else ""
+ if (sum.column.fieldNames.length != 1) return None
+ val distinct = if (sum.isDistinct) "DISTINCT " else ""
val column = quote(sum.column.fieldNames.head)
- s"SUM($distinct $column)"
+ s"SUM($distinct$column)"
case _: CountStar =>
- s"COUNT(1)"
- case _ => ""
- }
+ s"COUNT(*)"
+ case _ => return None
+ })
}
/**
@@ -185,7 +190,7 @@ object JDBCRDD extends Logging {
parts: Array[Partition],
options: JDBCOptions,
outputSchema: Option[StructType] = None,
- groupByColumns: Option[Array[FieldReference]] = None): RDD[InternalRow] = {
+ groupByColumns: Option[Array[String]] = None): RDD[InternalRow] = {
val url = options.url
val dialect = JdbcDialects.get(url)
val quotedColumns = if (groupByColumns.isEmpty) {
@@ -221,7 +226,7 @@ private[jdbc] class JDBCRDD(
partitions: Array[Partition],
url: String,
options: JDBCOptions,
- groupByColumns: Option[Array[FieldReference]])
+ groupByColumns: Option[Array[String]])
extends RDD[InternalRow](sc, Nil) {
/**
@@ -266,10 +271,8 @@ private[jdbc] class JDBCRDD(
*/
private def getGroupByClause: String = {
if (groupByColumns.nonEmpty && groupByColumns.get.nonEmpty) {
- assert(groupByColumns.get.forall(_.fieldNames.length == 1))
- val dialect = JdbcDialects.get(url)
- val quotedColumns = groupByColumns.get.map(c => dialect.quoteIdentifier(c.fieldNames.head))
- s"GROUP BY ${quotedColumns.mkString(", ")}"
+ // The GROUP BY columns should already be quoted by the caller side.
+ s"GROUP BY ${groupByColumns.get.mkString(", ")}"
} else {
""
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
index 5fb26d2..60d88b6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
@@ -27,7 +27,6 @@ import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession, SQLContext}
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, TimestampFormatter}
import org.apache.spark.sql.catalyst.util.DateTimeUtils.{getZoneId, stringToDate, stringToTimestamp}
-import org.apache.spark.sql.connector.expressions.FieldReference
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.jdbc.JdbcDialects
@@ -291,9 +290,9 @@ private[sql] case class JDBCRelation(
def buildScan(
requiredColumns: Array[String],
- requireSchema: Option[StructType],
+ finalSchema: StructType,
filters: Array[Filter],
- groupByColumns: Option[Array[FieldReference]]): RDD[Row] = {
+ groupByColumns: Option[Array[String]]): RDD[Row] = {
// Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row]
JDBCRDD.scanTable(
sparkSession.sparkContext,
@@ -302,7 +301,7 @@ private[sql] case class JDBCRelation(
filters,
parts,
jdbcOptions,
- requireSchema,
+ Some(finalSchema),
groupByColumns).asInstanceOf[RDD[Row]]
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala
index 34b6431..6eedeba 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala
@@ -91,7 +91,7 @@ object PushDownUtils extends PredicateHelper {
}
scanBuilder match {
- case r: SupportsPushDownAggregates =>
+ case r: SupportsPushDownAggregates if aggregates.nonEmpty =>
val translatedAggregates = aggregates.flatMap(DataSourceStrategy.translateAggregate)
val translatedGroupBys = groupBy.flatMap(columnAsString)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
index a1fc981..d05519b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.execution.datasources.v2
+import scala.collection.mutable
+
import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeReference, Expression, NamedExpression, PredicateHelper, ProjectionOverSchema, SubqueryExpression}
import org.apache.spark.sql.catalyst.expressions.aggregate
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
@@ -76,9 +78,18 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
if filters.isEmpty && project.forall(_.isInstanceOf[AttributeReference]) =>
sHolder.builder match {
case _: SupportsPushDownAggregates =>
+ val aggExprToOutputOrdinal = mutable.HashMap.empty[Expression, Int]
+ var ordinal = 0
val aggregates = resultExpressions.flatMap { expr =>
expr.collect {
- case agg: AggregateExpression => agg
+ // Do not push down duplicated aggregate expressions. For example,
+ // `SELECT max(a) + 1, max(a) + 2 FROM ...`, we should only push down one
+ // `max(a)` to the data source.
+ case agg: AggregateExpression
+ if !aggExprToOutputOrdinal.contains(agg.canonicalized) =>
+ aggExprToOutputOrdinal(agg.canonicalized) = ordinal
+ ordinal += 1
+ agg
}
}
val pushedAggregates = PushDownUtils
@@ -144,19 +155,18 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
// Aggregate [c2#10], [min(min(c1)#21) AS min(c1)#17, max(max(c1)#22) AS max(c1)#18]
// +- RelationV2[c2#10, min(c1)#21, max(c1)#22] ...
// scalastyle:on
- var i = 0
val aggOutput = output.drop(groupAttrs.length)
plan.transformExpressions {
case agg: AggregateExpression =>
+ val ordinal = aggExprToOutputOrdinal(agg.canonicalized)
val aggFunction: aggregate.AggregateFunction =
agg.aggregateFunction match {
- case max: aggregate.Max => max.copy(child = aggOutput(i))
- case min: aggregate.Min => min.copy(child = aggOutput(i))
- case sum: aggregate.Sum => sum.copy(child = aggOutput(i))
- case _: aggregate.Count => aggregate.Sum(aggOutput(i))
+ case max: aggregate.Max => max.copy(child = aggOutput(ordinal))
+ case min: aggregate.Min => min.copy(child = aggOutput(ordinal))
+ case sum: aggregate.Sum => sum.copy(child = aggOutput(ordinal))
+ case _: aggregate.Count => aggregate.Sum(aggOutput(ordinal))
case other => other
}
- i += 1
agg.copy(aggregateFunction = aggFunction)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala
index d6ae7c8..ef42691 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala
@@ -18,7 +18,6 @@ package org.apache.spark.sql.execution.datasources.v2.jdbc
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext}
-import org.apache.spark.sql.connector.expressions.FieldReference
import org.apache.spark.sql.connector.read.V1Scan
import org.apache.spark.sql.execution.datasources.jdbc.JDBCRelation
import org.apache.spark.sql.sources.{BaseRelation, Filter, TableScan}
@@ -29,7 +28,7 @@ case class JDBCScan(
prunedSchema: StructType,
pushedFilters: Array[Filter],
pushedAggregateColumn: Array[String] = Array(),
- groupByColumns: Option[Array[FieldReference]]) extends V1Scan {
+ groupByColumns: Option[Array[String]]) extends V1Scan {
override def readSchema(): StructType = prunedSchema
@@ -39,13 +38,12 @@ case class JDBCScan(
override def schema: StructType = prunedSchema
override def needConversion: Boolean = relation.needConversion
override def buildScan(): RDD[Row] = {
- if (groupByColumns.isEmpty) {
- relation.buildScan(
- prunedSchema.map(_.name).toArray, Some(prunedSchema), pushedFilters, groupByColumns)
+ val columnList = if (groupByColumns.isEmpty) {
+ prunedSchema.map(_.name).toArray
} else {
- relation.buildScan(
- pushedAggregateColumn, Some(prunedSchema), pushedFilters, groupByColumns)
+ pushedAggregateColumn
}
+ relation.buildScan(columnList, prunedSchema, pushedFilters, groupByColumns)
}
}.asInstanceOf[T]
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala
index afdc822..89fa621 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala
@@ -16,27 +16,33 @@
*/
package org.apache.spark.sql.execution.datasources.v2.jdbc
+import scala.util.control.NonFatal
+
+import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.connector.expressions.{Aggregation, Count, CountStar, FieldReference, Max, Min, Sum}
+import org.apache.spark.sql.connector.expressions.Aggregation
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownRequiredColumns}
import org.apache.spark.sql.execution.datasources.PartitioningUtils
import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCRDD, JDBCRelation}
import org.apache.spark.sql.jdbc.JdbcDialects
import org.apache.spark.sql.sources.Filter
-import org.apache.spark.sql.types.{LongType, StructField, StructType}
+import org.apache.spark.sql.types.StructType
case class JDBCScanBuilder(
session: SparkSession,
schema: StructType,
jdbcOptions: JDBCOptions)
- extends ScanBuilder with SupportsPushDownFilters with SupportsPushDownRequiredColumns
- with SupportsPushDownAggregates{
+ extends ScanBuilder
+ with SupportsPushDownFilters
+ with SupportsPushDownRequiredColumns
+ with SupportsPushDownAggregates
+ with Logging {
private val isCaseSensitive = session.sessionState.conf.caseSensitiveAnalysis
private var pushedFilter = Array.empty[Filter]
- private var prunedSchema = schema
+ private var finalSchema = schema
override def pushFilters(filters: Array[Filter]): Array[Filter] = {
if (jdbcOptions.pushDownPredicate) {
@@ -51,56 +57,45 @@ case class JDBCScanBuilder(
override def pushedFilters(): Array[Filter] = pushedFilter
- private var pushedAggregations = Option.empty[Aggregation]
-
- private var pushedAggregateColumn: Array[String] = Array()
+ private var pushedAggregateList: Array[String] = Array()
- private def getStructFieldForCol(col: FieldReference): StructField =
- schema.fields(schema.fieldNames.toList.indexOf(col.fieldNames.head))
+ private var pushedGroupByCols: Option[Array[String]] = None
override def pushAggregation(aggregation: Aggregation): Boolean = {
if (!jdbcOptions.pushDownAggregate) return false
val dialect = JdbcDialects.get(jdbcOptions.url)
val compiledAgg = JDBCRDD.compileAggregates(aggregation.aggregateExpressions, dialect)
+ if (compiledAgg.isEmpty) return false
- var outputSchema = new StructType()
- aggregation.groupByColumns.foreach { col =>
- val structField = getStructFieldForCol(col)
- outputSchema = outputSchema.add(structField)
- pushedAggregateColumn = pushedAggregateColumn :+ dialect.quoteIdentifier(structField.name)
+ val groupByCols = aggregation.groupByColumns.map { col =>
+ if (col.fieldNames.length != 1) return false
+ dialect.quoteIdentifier(col.fieldNames.head)
}
// The column names here are already quoted and can be used to build sql string directly.
// e.g. "DEPT","NAME",MAX("SALARY"),MIN("BONUS") =>
// SELECT "DEPT","NAME",MAX("SALARY"),MIN("BONUS") FROM "test"."employee"
// GROUP BY "DEPT", "NAME"
- pushedAggregateColumn = pushedAggregateColumn ++ compiledAgg
-
- aggregation.aggregateExpressions.foreach {
- case max: Max =>
- val structField = getStructFieldForCol(max.column)
- outputSchema = outputSchema.add(structField.copy("max(" + structField.name + ")"))
- case min: Min =>
- val structField = getStructFieldForCol(min.column)
- outputSchema = outputSchema.add(structField.copy("min(" + structField.name + ")"))
- case count: Count =>
- val distinct = if (count.isDistinct) "DISTINCT " else ""
- val structField = getStructFieldForCol(count.column)
- outputSchema =
- outputSchema.add(StructField(s"count($distinct" + structField.name + ")", LongType))
- case _: CountStar =>
- outputSchema = outputSchema.add(StructField("count(*)", LongType))
- case sum: Sum =>
- val distinct = if (sum.isDistinct) "DISTINCT " else ""
- val structField = getStructFieldForCol(sum.column)
- outputSchema =
- outputSchema.add(StructField(s"sum($distinct" + structField.name + ")", sum.dataType))
- case _ => return false
+ val selectList = groupByCols ++ compiledAgg.get
+ val groupByClause = if (groupByCols.isEmpty) {
+ ""
+ } else {
+ "GROUP BY " + groupByCols.mkString(",")
+ }
+
+ val aggQuery = s"SELECT ${selectList.mkString(",")} FROM ${jdbcOptions.tableOrQuery} " +
+ s"WHERE 1=0 $groupByClause"
+ try {
+ finalSchema = JDBCRDD.getQueryOutputSchema(aggQuery, jdbcOptions, dialect)
+ pushedAggregateList = selectList
+ pushedGroupByCols = Some(groupByCols)
+ true
+ } catch {
+ case NonFatal(e) =>
+ logError("Failed to push down aggregation to JDBC", e)
+ false
}
- this.pushedAggregations = Some(aggregation)
- prunedSchema = outputSchema
- true
}
override def pruneColumns(requiredSchema: StructType): Unit = {
@@ -112,7 +107,7 @@ case class JDBCScanBuilder(
val colName = PartitioningUtils.getColName(field, isCaseSensitive)
requiredCols.contains(colName)
}
- prunedSchema = StructType(fields)
+ finalSchema = StructType(fields)
}
override def build(): Scan = {
@@ -120,19 +115,14 @@ case class JDBCScanBuilder(
val timeZoneId = session.sessionState.conf.sessionLocalTimeZone
val parts = JDBCRelation.columnPartition(schema, resolver, timeZoneId, jdbcOptions)
- // in prunedSchema, the schema is either pruned in pushAggregation (if aggregates are
+ // the `finalSchema` is either pruned in pushAggregation (if aggregates are
// pushed down), or pruned in pruneColumns (in regular column pruning). These
// two are mutual exclusive.
// For aggregate push down case, we want to pass down the quoted column lists such as
// "DEPT","NAME",MAX("SALARY"),MIN("BONUS"), instead of getting column names from
// prunedSchema and quote them (will become "MAX(SALARY)", "MIN(BONUS)" and can't
// be used in sql string.
- val groupByColumns = if (pushedAggregations.nonEmpty) {
- Some(pushedAggregations.get.groupByColumns)
- } else {
- Option.empty[Array[FieldReference]]
- }
- JDBCScan(JDBCRelation(schema, parts, jdbcOptions)(session), prunedSchema, pushedFilter,
- pushedAggregateColumn, groupByColumns)
+ JDBCScan(JDBCRelation(schema, parts, jdbcOptions)(session), finalSchema, pushedFilter,
+ pushedAggregateList, pushedGroupByCols)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
index 8dfb6de..37bc352 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
@@ -248,7 +248,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
df.queryExecution.optimizedPlan.collect {
case _: DataSourceV2ScanRelation =>
val expected_plan_fragment =
- "PushedAggregates: [Max(SALARY), Min(BONUS)], " +
+ "PushedAggregates: [MAX(SALARY), MIN(BONUS)], " +
"PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " +
"PushedGroupby: [DEPT]"
checkKeywordsExistsInExplain(df, expected_plan_fragment)
@@ -265,7 +265,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
df.queryExecution.optimizedPlan.collect {
case _: DataSourceV2ScanRelation =>
val expected_plan_fragment =
- "PushedAggregates: [Max(ID), Min(ID)], " +
+ "PushedAggregates: [MAX(ID), MIN(ID)], " +
"PushedFilters: [IsNotNull(ID), GreaterThan(ID,0)], " +
"PushedGroupby: []"
checkKeywordsExistsInExplain(df, expected_plan_fragment)
@@ -278,7 +278,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
df.queryExecution.optimizedPlan.collect {
case _: DataSourceV2ScanRelation =>
val expected_plan_fragment =
- "PushedAggregates: [Max(SALARY)]"
+ "PushedAggregates: [MAX(SALARY)]"
checkKeywordsExistsInExplain(df, expected_plan_fragment)
}
checkAnswer(df, Seq(Row(12001)))
@@ -289,7 +289,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
df.queryExecution.optimizedPlan.collect {
case _: DataSourceV2ScanRelation =>
val expected_plan_fragment =
- "PushedAggregates: [CountStar()]"
+ "PushedAggregates: [COUNT(*)]"
checkKeywordsExistsInExplain(df, expected_plan_fragment)
}
checkAnswer(df, Seq(Row(5)))
@@ -300,7 +300,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
df.queryExecution.optimizedPlan.collect {
case _: DataSourceV2ScanRelation =>
val expected_plan_fragment =
- "PushedAggregates: [Count(DEPT,false)]"
+ "PushedAggregates: [COUNT(DEPT)]"
checkKeywordsExistsInExplain(df, expected_plan_fragment)
}
checkAnswer(df, Seq(Row(5)))
@@ -311,7 +311,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
df.queryExecution.optimizedPlan.collect {
case _: DataSourceV2ScanRelation =>
val expected_plan_fragment =
- "PushedAggregates: [Count(DEPT,true)]"
+ "PushedAggregates: [COUNT(DISTINCT DEPT)]"
checkKeywordsExistsInExplain(df, expected_plan_fragment)
}
checkAnswer(df, Seq(Row(3)))
@@ -322,7 +322,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
df.queryExecution.optimizedPlan.collect {
case _: DataSourceV2ScanRelation =>
val expected_plan_fragment =
- "PushedAggregates: [Sum(SALARY,DecimalType(30,2),false)]"
+ "PushedAggregates: [SUM(SALARY)]"
checkKeywordsExistsInExplain(df, expected_plan_fragment)
}
checkAnswer(df, Seq(Row(53000)))
@@ -333,7 +333,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
df.queryExecution.optimizedPlan.collect {
case _: DataSourceV2ScanRelation =>
val expected_plan_fragment =
- "PushedAggregates: [Sum(SALARY,DecimalType(30,2),true)]"
+ "PushedAggregates: [SUM(DISTINCT SALARY)]"
checkKeywordsExistsInExplain(df, expected_plan_fragment)
}
checkAnswer(df, Seq(Row(31000)))
@@ -344,7 +344,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
df.queryExecution.optimizedPlan.collect {
case _: DataSourceV2ScanRelation =>
val expected_plan_fragment =
- "PushedAggregates: [Sum(SALARY,DecimalType(30,2),false)], " +
+ "PushedAggregates: [SUM(SALARY)], " +
"PushedFilters: [], " +
"PushedGroupby: [DEPT]"
checkKeywordsExistsInExplain(df, expected_plan_fragment)
@@ -357,7 +357,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
df.queryExecution.optimizedPlan.collect {
case _: DataSourceV2ScanRelation =>
val expected_plan_fragment =
- "PushedAggregates: [Sum(SALARY,DecimalType(30,2),true)], " +
+ "PushedAggregates: [SUM(DISTINCT SALARY)], " +
"PushedFilters: [], " +
"PushedGroupby: [DEPT]"
checkKeywordsExistsInExplain(df, expected_plan_fragment)
@@ -375,7 +375,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
df.queryExecution.optimizedPlan.collect {
case _: DataSourceV2ScanRelation =>
val expected_plan_fragment =
- "PushedAggregates: [Max(SALARY), Min(BONUS)], " +
+ "PushedAggregates: [MAX(SALARY), MIN(BONUS)], " +
"PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " +
"PushedGroupby: [DEPT, NAME]"
checkKeywordsExistsInExplain(df, expected_plan_fragment)
@@ -394,7 +394,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
df.queryExecution.optimizedPlan.collect {
case _: DataSourceV2ScanRelation =>
val expected_plan_fragment =
- "PushedAggregates: [Max(SALARY), Min(BONUS)], " +
+ "PushedAggregates: [MAX(SALARY), MIN(BONUS)], " +
"PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " +
"PushedGroupby: [DEPT]"
checkKeywordsExistsInExplain(df, expected_plan_fragment)
@@ -409,7 +409,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
df.queryExecution.optimizedPlan.collect {
case _: DataSourceV2ScanRelation =>
val expected_plan_fragment =
- "PushedAggregates: [Min(SALARY)], " +
+ "PushedAggregates: [MIN(SALARY)], " +
"PushedFilters: [], " +
"PushedGroupby: [DEPT]"
checkKeywordsExistsInExplain(df, expected_plan_fragment)
@@ -432,7 +432,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
query.queryExecution.optimizedPlan.collect {
case _: DataSourceV2ScanRelation =>
val expected_plan_fragment =
- "PushedAggregates: [Sum(SALARY,DecimalType(30,2),false)], " +
+ "PushedAggregates: [SUM(SALARY)], " +
"PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " +
"PushedGroupby: [DEPT]"
checkKeywordsExistsInExplain(query, expected_plan_fragment)
@@ -447,7 +447,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
query.queryExecution.optimizedPlan.collect {
case _: DataSourceV2ScanRelation =>
val expected_plan_fragment =
- "PushedAggregates: [Sum(SALARY,DecimalType(30,2),false), Sum(BONUS,DoubleType,false)"
+ "PushedAggregates: [SUM(SALARY), SUM(BONUS)"
checkKeywordsExistsInExplain(query, expected_plan_fragment)
}
checkAnswer(query, Seq(Row(47100.0)))