Posted to commits@spark.apache.org by we...@apache.org on 2021/07/28 04:54:00 UTC

[spark] branch branch-3.2 updated: [SPARK-34952][SQL][FOLLOW-UP] DSv2 aggregate push down follow-up

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.2
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.2 by this push:
     new 33ef52e  [SPARK-34952][SQL][FOLLOW-UP] DSv2 aggregate push down follow-up
33ef52e is described below

commit 33ef52e2c0856c0188d868a6cfb5f38b3d922f2f
Author: Huaxin Gao <hu...@apple.com>
AuthorDate: Wed Jul 28 12:52:42 2021 +0800

    [SPARK-34952][SQL][FOLLOW-UP] DSv2 aggregate push down follow-up
    
    ### What changes were proposed in this pull request?
    Update the Java docs and the JDBC data source doc, and address follow-up review comments.
    
    ### Why are the changes needed?
    Update the docs and address the follow-up review comments.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, the new JDBC option `pushDownAggregate` is now documented in the JDBC data source doc.
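
    A minimal usage sketch (the connection URL, table, and column names are hypothetical; only the `pushDownAggregate` option itself comes from this work):

    ```scala
    val df = spark.read
      .format("jdbc")
      .option("url", "jdbc:postgresql://db-host:5432/sales")  // hypothetical database
      .option("dbtable", "employee")                          // hypothetical table
      .option("pushDownAggregate", "true")                    // default is false
      .load()

    // With push-down enabled, MAX/MIN/COUNT/SUM/COUNT(*) in a query like this can be
    // evaluated by the database; Spark still runs a final aggregate over the result.
    df.groupBy("dept").sum("salary").show()
    ```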
    
    ### How was this patch tested?
    manually checked
    
    Closes #33526 from huaxingao/aggPD_followup.
    
    Authored-by: Huaxin Gao <hu...@apple.com>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
    (cherry picked from commit c8dd97d4566e4cd6865437c2640467c9c16080d4)
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 docs/sql-data-sources-jdbc.md                      |  9 +++++
 .../sql/connector/expressions/Aggregation.java     | 12 +++----
 .../spark/sql/connector/expressions/Count.java     | 28 +++++++--------
 .../spark/sql/connector/expressions/CountStar.java | 14 ++++----
 .../spark/sql/connector/expressions/Max.java       | 18 ++++------
 .../spark/sql/connector/expressions/Min.java       | 20 ++++-------
 .../spark/sql/connector/expressions/Sum.java       | 40 +++++++++-------------
 .../connector/read/SupportsPushDownAggregates.java |  8 ++---
 .../sql/execution/datasources/jdbc/JDBCRDD.scala   |  8 ++---
 .../execution/datasources/v2/PushDownUtils.scala   | 10 ++----
 .../datasources/v2/V2ScanRelationPushDown.scala    |  4 +--
 .../datasources/v2/jdbc/JDBCScanBuilder.scala      |  4 +--
 .../org/apache/spark/sql/jdbc/JDBCV2Suite.scala    |  2 +-
 13 files changed, 78 insertions(+), 99 deletions(-)

diff --git a/docs/sql-data-sources-jdbc.md b/docs/sql-data-sources-jdbc.md
index c973e8a..315f476 100644
--- a/docs/sql-data-sources-jdbc.md
+++ b/docs/sql-data-sources-jdbc.md
@@ -238,6 +238,15 @@ logging into the data sources.
   </tr>
 
   <tr>
+    <td><code>pushDownAggregate</code></td>
+    <td><code>false</code></td>
+    <td>
+     The option to enable or disable aggregate push-down into the JDBC data source. The default value is false, in which case Spark will not push down aggregates to the JDBC data source. Otherwise, if set to true, aggregates will be pushed down to the JDBC data source. Aggregate push-down is usually turned off when the aggregate is performed faster by Spark than by the JDBC data source. Please note that aggregates can be pushed down if and only if all the aggregate functions and the rel [...]
+    </td>
+    <td>read</td>
+  </tr>
+
+  <tr>
     <td><code>keytab</code></td>
     <td>(none)</td>
     <td>
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Aggregation.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Aggregation.java
index fdf3031..8eb3491 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Aggregation.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Aggregation.java
@@ -28,19 +28,15 @@ import java.io.Serializable;
  */
 @Evolving
 public final class Aggregation implements Serializable {
-  private AggregateFunc[] aggregateExpressions;
-  private FieldReference[] groupByColumns;
+  private final AggregateFunc[] aggregateExpressions;
+  private final FieldReference[] groupByColumns;
 
   public Aggregation(AggregateFunc[] aggregateExpressions, FieldReference[] groupByColumns) {
     this.aggregateExpressions = aggregateExpressions;
     this.groupByColumns = groupByColumns;
   }
 
-  public AggregateFunc[] aggregateExpressions() {
-    return aggregateExpressions;
-  }
+  public AggregateFunc[] aggregateExpressions() { return aggregateExpressions; }
 
-  public FieldReference[] groupByColumns() {
-    return groupByColumns;
-  }
+  public FieldReference[] groupByColumns() { return groupByColumns; }
 }
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Count.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Count.java
index 17562a1..0e28a93 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Count.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Count.java
@@ -26,24 +26,20 @@ import org.apache.spark.annotation.Evolving;
  */
 @Evolving
 public final class Count implements AggregateFunc {
-    private FieldReference column;
-    private boolean isDistinct;
+  private final FieldReference column;
+  private final boolean isDistinct;
 
-    public Count(FieldReference column, boolean isDistinct) {
-        this.column = column;
-        this.isDistinct = isDistinct;
-    }
+  public Count(FieldReference column, boolean isDistinct) {
+    this.column = column;
+    this.isDistinct = isDistinct;
+  }
 
-    public FieldReference column() {
-        return column;
-    }
-    public boolean isDinstinct() {
-        return isDistinct;
-    }
+  public FieldReference column() { return column; }
+  public boolean isDistinct() { return isDistinct; }
 
-    @Override
-    public String toString() { return "Count(" + column.describe() + "," + isDistinct + ")"; }
+  @Override
+  public String toString() { return "Count(" + column.describe() + "," + isDistinct + ")"; }
 
-    @Override
-    public String describe() { return this.toString(); }
+  @Override
+  public String describe() { return this.toString(); }
 }
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/CountStar.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/CountStar.java
index 777a99d..21a3564 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/CountStar.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/CountStar.java
@@ -27,14 +27,12 @@ import org.apache.spark.annotation.Evolving;
 @Evolving
 public final class CountStar implements AggregateFunc {
 
-    public CountStar() {
-    }
+  public CountStar() {
+  }
 
-    @Override
-    public String toString() {
-        return "CountStar()";
-    }
+  @Override
+  public String toString() { return "CountStar()"; }
 
-    @Override
-    public String describe() { return this.toString(); }
+  @Override
+  public String describe() { return this.toString(); }
 }
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Max.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Max.java
index fe7689c..d2ff6b2 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Max.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Max.java
@@ -26,19 +26,15 @@ import org.apache.spark.annotation.Evolving;
  */
 @Evolving
 public final class Max implements AggregateFunc {
-    private FieldReference column;
+  private final FieldReference column;
 
-    public Max(FieldReference column) {
-        this.column = column;
-    }
+  public Max(FieldReference column) { this.column = column; }
 
-    public FieldReference column() { return column; }
+  public FieldReference column() { return column; }
 
-    @Override
-    public String toString() {
-        return "Max(" + column.describe() + ")";
-    }
+  @Override
+  public String toString() { return "Max(" + column.describe() + ")"; }
 
-    @Override
-    public String describe() { return this.toString(); }
+  @Override
+  public String describe() { return this.toString(); }
 }
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Min.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Min.java
index f528b0b..efa8036 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Min.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Min.java
@@ -26,21 +26,15 @@ import org.apache.spark.annotation.Evolving;
  */
 @Evolving
 public final class Min implements AggregateFunc {
-    private FieldReference column;
+  private final FieldReference column;
 
-    public Min(FieldReference column) {
-        this.column = column;
-    }
+  public Min(FieldReference column) { this.column = column; }
 
-    public FieldReference column() {
-        return column;
-    }
+  public FieldReference column() { return column; }
 
-    @Override
-    public String toString() {
-        return "Min(" + column.describe() + ")";
-    }
+  @Override
+  public String toString() { return "Min(" + column.describe() + ")"; }
 
-    @Override
-    public String describe() { return this.toString(); }
+  @Override
+  public String describe() { return this.toString(); }
 }
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Sum.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Sum.java
index 4cb34be..e4e860e 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Sum.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Sum.java
@@ -27,31 +27,25 @@ import org.apache.spark.sql.types.DataType;
  */
 @Evolving
 public final class Sum implements AggregateFunc {
-    private FieldReference column;
-    private DataType dataType;
-    private boolean isDistinct;
+  private final FieldReference column;
+  private final DataType dataType;
+  private final boolean isDistinct;
 
-    public Sum(FieldReference column, DataType dataType, boolean isDistinct) {
-        this.column = column;
-        this.dataType = dataType;
-        this.isDistinct = isDistinct;
-    }
+  public Sum(FieldReference column, DataType dataType, boolean isDistinct) {
+    this.column = column;
+    this.dataType = dataType;
+    this.isDistinct = isDistinct;
+  }
 
-    public FieldReference column() {
-        return column;
-    }
-    public DataType dataType() {
-        return dataType;
-    }
-    public boolean isDinstinct() {
-        return isDistinct;
-    }
+  public FieldReference column() { return column; }
+  public DataType dataType() { return dataType; }
+  public boolean isDistinct() { return isDistinct; }
 
-    @Override
-    public String toString() {
-        return "Sum(" + column.describe() + "," + dataType + "," + isDistinct + ")";
-    }
+  @Override
+  public String toString() {
+    return "Sum(" + column.describe() + "," + dataType + "," + isDistinct + ")";
+  }
 
-    @Override
-    public String describe() { return this.toString(); }
+  @Override
+  public String describe() { return this.toString(); }
 }
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java
index 7efa333..8ec9a25 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java
@@ -27,12 +27,10 @@ import org.apache.spark.sql.connector.expressions.Aggregation;
  * "SELECT min(value) AS m FROM t GROUP BY key", after pushing down the aggregate
  * to the data source, the data source can still output data with duplicated keys, which is OK
  * as Spark will do GROUP BY key again. The final query plan can be something like this:
- * {{{
+ * <pre>
  *   Aggregate [key#1], [min(min(value)#2) AS m#3]
  *     +- RelationV2[key#1, min(value)#2]
- * }}}
- *
- * <p>
+ * </pre>
  * Similarly, if there is no grouping expression, the data source can still output more than one
  * rows.
  *
@@ -51,6 +49,8 @@ public interface SupportsPushDownAggregates extends ScanBuilder {
    * Pushes down Aggregation to datasource. The order of the datasource scan output columns should
    * be: grouping columns, aggregate columns (in the same order as the aggregate functions in
    * the given Aggregation).
+   *
+   * @return true if the aggregation can be pushed down to datasource, false otherwise.
    */
   boolean pushAggregation(Aggregation aggregation);
 }
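
A rough connector-side sketch (Scala, with hypothetical names such as `MyScanBuilder`; the in-tree reference implementation is `JDBCScanBuilder` further below) of how a scan builder might honor this contract:

```scala
import org.apache.spark.sql.connector.expressions.{Aggregation, Min}
import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownAggregates}

// Hypothetical builder that only knows how to evaluate MIN on the source side.
class MyScanBuilder extends SupportsPushDownAggregates {
  private var pushedAggregation: Option[Aggregation] = None

  override def pushAggregation(aggregation: Aggregation): Boolean = {
    val supported = aggregation.aggregateExpressions.forall {
      case _: Min => true
      case _      => false
    }
    if (supported) pushedAggregation = Some(aggregation)
    // Returning false tells Spark to keep the whole aggregate in its own plan.
    supported
  }

  // The produced Scan must output the group-by columns first, then one column
  // per aggregate function, in the order given by the pushed Aggregation.
  override def build(): Scan = ???
}
```
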
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
index c22ca15..af6c407 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
@@ -148,12 +148,12 @@ object JDBCRDD extends Logging {
         s"MAX(${quote(max.column.fieldNames.head)})"
       case count: Count =>
         assert(count.column.fieldNames.length == 1)
-        val distinct = if (count.isDinstinct) "DISTINCT" else ""
+        val distinct = if (count.isDistinct) "DISTINCT" else ""
         val column = quote(count.column.fieldNames.head)
         s"COUNT($distinct $column)"
       case sum: Sum =>
         assert(sum.column.fieldNames.length == 1)
-        val distinct = if (sum.isDinstinct) "DISTINCT" else ""
+        val distinct = if (sum.isDistinct) "DISTINCT" else ""
         val column = quote(sum.column.fieldNames.head)
         s"SUM($distinct $column)"
       case _: CountStar =>
@@ -172,8 +172,8 @@ object JDBCRDD extends Logging {
    * @param parts - An array of JDBCPartitions specifying partition ids and
    *    per-partition WHERE clauses.
    * @param options - JDBC options that contains url, table and other information.
-   * @param requiredSchema - The schema of the columns to SELECT.
-   * @param aggregation - The pushed down aggregation
+   * @param outputSchema - The schema of the columns to SELECT.
+   * @param groupByColumns - The pushed down group by columns.
    *
    * @return An RDD representing "SELECT requiredColumns FROM fqTable".
    */
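
To make the aggregate-to-SQL translation above concrete, a hedged illustration of the query shape the JDBC source ends up issuing (table and column names are hypothetical; this is not the exact string building in `JDBCRDD`):

```scala
// Suppose the pushed Aggregation groups by DEPT and carries SUM(SALARY)
// and COUNT(DISTINCT NAME), compiled as in the hunk above.
val groupByColumns     = Seq("DEPT")
val compiledAggregates = Seq("SUM(SALARY)", "COUNT(DISTINCT NAME)")

val selectList    = (groupByColumns ++ compiledAggregates).mkString(", ")
val groupByClause = if (groupByColumns.nonEmpty) s" GROUP BY ${groupByColumns.mkString(", ")}" else ""
val pushedQuery   = s"SELECT $selectList FROM employee$groupByClause"
// => SELECT DEPT, SUM(SALARY), COUNT(DISTINCT NAME) FROM employee GROUP BY DEPT
// Spark still runs a final aggregate over this result, since different partitions
// of the JDBC scan may each return a row for the same DEPT.
```
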
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala
index ab5c5da..34b6431 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala
@@ -92,8 +92,8 @@ object PushDownUtils extends PredicateHelper {
 
     scanBuilder match {
       case r: SupportsPushDownAggregates =>
-        val translatedAggregates = aggregates.map(DataSourceStrategy.translateAggregate).flatten
-        val translatedGroupBys = groupBy.map(columnAsString).flatten
+        val translatedAggregates = aggregates.flatMap(DataSourceStrategy.translateAggregate)
+        val translatedGroupBys = groupBy.flatMap(columnAsString)
 
         if (translatedAggregates.length != aggregates.length ||
           translatedGroupBys.length != groupBy.length) {
@@ -101,11 +101,7 @@ object PushDownUtils extends PredicateHelper {
         }
 
         val agg = new Aggregation(translatedAggregates.toArray, translatedGroupBys.toArray)
-        if (r.pushAggregation(agg)) {
-          Some(agg)
-        } else {
-          None
-        }
+        Some(agg).filter(r.pushAggregation)
       case _ => None
     }
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
index 445ff03..a1fc981 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
@@ -33,7 +33,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
   import DataSourceV2Implicits._
 
   def apply(plan: LogicalPlan): LogicalPlan = {
-    applyColumnPruning(pushdownAggregate(pushDownFilters(createScanBuilder(plan))))
+    applyColumnPruning(pushDownAggregates(pushDownFilters(createScanBuilder(plan))))
   }
 
   private def createScanBuilder(plan: LogicalPlan) = plan.transform {
@@ -68,7 +68,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
       filterCondition.map(Filter(_, sHolder)).getOrElse(sHolder)
   }
 
-  def pushdownAggregate(plan: LogicalPlan): LogicalPlan = plan.transform {
+  def pushDownAggregates(plan: LogicalPlan): LogicalPlan = plan.transform {
     // update the scan builder with agg pushdown and return a new plan with agg pushed
     case aggNode @ Aggregate(groupingExpressions, resultExpressions, child) =>
       child match {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala
index 7442eda..afdc822 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala
@@ -85,14 +85,14 @@ case class JDBCScanBuilder(
         val structField = getStructFieldForCol(min.column)
         outputSchema = outputSchema.add(structField.copy("min(" + structField.name + ")"))
       case count: Count =>
-        val distinct = if (count.isDinstinct) "DISTINCT " else ""
+        val distinct = if (count.isDistinct) "DISTINCT " else ""
         val structField = getStructFieldForCol(count.column)
         outputSchema =
           outputSchema.add(StructField(s"count($distinct" + structField.name + ")", LongType))
       case _: CountStar =>
         outputSchema = outputSchema.add(StructField("count(*)", LongType))
       case sum: Sum =>
-        val distinct = if (sum.isDinstinct) "DISTINCT " else ""
+        val distinct = if (sum.isDistinct) "DISTINCT " else ""
         val structField = getStructFieldForCol(sum.column)
         outputSchema =
           outputSchema.add(StructField(s"sum($distinct" + structField.name + ")", sum.dataType))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
index c1f8f5f..8dfb6de 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
@@ -453,7 +453,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     checkAnswer(query, Seq(Row(47100.0)))
   }
 
-  test("scan with aggregate push-down: aggregate over alias") {
+  test("scan with aggregate push-down: aggregate over alias NOT push down") {
     val cols = Seq("a", "b", "c", "d")
     val df1 = sql("select * from h2.test.employee").toDF(cols: _*)
     val df2 = df1.groupBy().sum("c")
