You are viewing a plain text version of this content. The canonical link for it is here.
Posted to reviews@spark.apache.org by GitBox <gi...@apache.org> on 2021/03/17 19:55:44 UTC

[GitHub] [spark] viirya commented on a change in pull request #29695: [SPARK-22390][SPARK-32833][SQL] [WIP]JDBC V2 Datasource aggregate push down

viirya commented on a change in pull request #29695:
URL: https://github.com/apache/spark/pull/29695#discussion_r596239451



##########
File path: sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java
##########
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.read;
+
+import org.apache.spark.annotation.Evolving;
+import org.apache.spark.sql.sources.Aggregation;
+
+/**
+ * A mix-in interface for {@link ScanBuilder}. Data sources can implement this interface to
+ * push down aggregates to the data source.
+ *
+ * @since 3.1.0

Review comment:
       This should probably be `@since 3.2.0` now.

##########
File path: sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java
##########
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.read;
+
+import org.apache.spark.annotation.Evolving;
+import org.apache.spark.sql.sources.Aggregation;
+
+/**
+ * A mix-in interface for {@link ScanBuilder}. Data sources can implement this interface to
+ * push down aggregates to the data source.
+ *
+ * @since 3.1.0
+ */
+@Evolving
+public interface SupportsPushDownAggregates extends ScanBuilder {
+
+  /**
+   * Pushes down Aggregation to datasource.
+   * The Aggregation can be pushed down only if all the Aggregate Functions can
+   * be pushed down.
+   */
+  void pushAggregation(Aggregation aggregation);
+
+  /**
+   * Returns the aggregates that are pushed to the data source via

Review comment:
       I think this returns `Aggregation`?

##########
File path: sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
##########
@@ -17,38 +17,122 @@
 
 package org.apache.spark.sql.execution.datasources.v2
 
-import org.apache.spark.sql.catalyst.expressions.{And, Expression, NamedExpression, ProjectionOverSchema, SubqueryExpression}
+import scala.collection.mutable.ArrayBuilder
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.planning.ScanOperation
-import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LogicalPlan, Project}
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.connector.read.{Scan, V1Scan}
+import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, V1Scan}
 import org.apache.spark.sql.execution.datasources.DataSourceStrategy
 import org.apache.spark.sql.sources
+import org.apache.spark.sql.sources.{AggregateFunc, Aggregation}
 import org.apache.spark.sql.types.StructType
 
-object V2ScanRelationPushDown extends Rule[LogicalPlan] {
+object V2ScanRelationPushDown extends Rule[LogicalPlan] with AliasHelper {
+
   import DataSourceV2Implicits._
 
   override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
-    case ScanOperation(project, filters, relation: DataSourceV2Relation) =>
-      val scanBuilder = relation.table.asReadable.newScanBuilder(relation.options)
+    case Aggregate(groupingExpressions, resultExpressions, child) =>
+      child match {
+        case ScanOperation(project, filters, relation: DataSourceV2Relation) =>
+          val scanBuilder = relation.table.asReadable.newScanBuilder(relation.options)
+          val (pushedFilters, postScanFilters) = pushDownFilter(scanBuilder, filters, relation)
+          if (postScanFilters.nonEmpty) {
+            Aggregate(groupingExpressions, resultExpressions, child)
+          } else { // only push down aggregate of all the filers can be push down
+            val aliasMap = getAliasMap(project)
+            var aggregates = resultExpressions.flatMap { expr =>
+              expr.collect {
+                case agg: AggregateExpression =>
+                  replaceAlias(agg, aliasMap).asInstanceOf[AggregateExpression]
+              }
+            }
+            aggregates = DataSourceStrategy.normalizeExprs(aggregates, relation.output)
+              .asInstanceOf[Seq[AggregateExpression]]
 
-      val normalizedFilters = DataSourceStrategy.normalizeExprs(filters, relation.output)
-      val (normalizedFiltersWithSubquery, normalizedFiltersWithoutSubquery) =
-        normalizedFilters.partition(SubqueryExpression.hasSubquery)
+            val groupingExpressionsWithoutAlias = groupingExpressions.flatMap{ expr =>
+              expr.collect {
+                case a: AttributeReference => replaceAlias(a, aliasMap)
+              }
+            }
+            val normalizedgroupingExpressions =

Review comment:
       Nit: rename to `normalizedGroupingExpressions` (capital `G`).

##########
File path: sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
##########
@@ -700,6 +704,41 @@ object DataSourceStrategy
     (nonconvertiblePredicates ++ unhandledPredicates, pushedFilters, handledFilters)
   }
 
+  private def columnAsString(e: Expression): String = e match {

Review comment:
       For predicate pushdown, it seems we simplify the cases we handle by only looking at column names.
   
   The approach here covers a lot of cases, but it also makes things easy to break. We can begin with the simplest case and add more support later.

##########
File path: sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
##########
@@ -17,38 +17,122 @@
 
 package org.apache.spark.sql.execution.datasources.v2
 
-import org.apache.spark.sql.catalyst.expressions.{And, Expression, NamedExpression, ProjectionOverSchema, SubqueryExpression}
+import scala.collection.mutable.ArrayBuilder
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.planning.ScanOperation
-import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LogicalPlan, Project}
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.connector.read.{Scan, V1Scan}
+import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, V1Scan}
 import org.apache.spark.sql.execution.datasources.DataSourceStrategy
 import org.apache.spark.sql.sources
+import org.apache.spark.sql.sources.{AggregateFunc, Aggregation}
 import org.apache.spark.sql.types.StructType
 
-object V2ScanRelationPushDown extends Rule[LogicalPlan] {
+object V2ScanRelationPushDown extends Rule[LogicalPlan] with AliasHelper {
+
   import DataSourceV2Implicits._
 
   override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
-    case ScanOperation(project, filters, relation: DataSourceV2Relation) =>
-      val scanBuilder = relation.table.asReadable.newScanBuilder(relation.options)
+    case Aggregate(groupingExpressions, resultExpressions, child) =>
+      child match {
+        case ScanOperation(project, filters, relation: DataSourceV2Relation) =>
+          val scanBuilder = relation.table.asReadable.newScanBuilder(relation.options)
+          val (pushedFilters, postScanFilters) = pushDownFilter(scanBuilder, filters, relation)

Review comment:
       I can see there is a dependency between filter pushdown and aggregate pushdown, as we need to check whether all filters are pushed down.
   
   
   I think an alternative approach is to leave filter pushdown untouched, and instead check whether filter pushdown has happened and whether there is still a `Filter` on top of the scan relation.
   
   I feel that would simplify the code here, and we wouldn't need to call `pushDownFilter` twice (once for aggregate pushdown and once for filter pushdown).

##########
File path: sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
##########
@@ -181,21 +243,49 @@ private[jdbc] class JDBCRDD(
     filters: Array[Filter],
     partitions: Array[Partition],
     url: String,
-    options: JDBCOptions)
+    options: JDBCOptions,
+    aggregation: Aggregation = Aggregation.empty)
   extends RDD[InternalRow](sc, Nil) {
 
   /**
    * Retrieve the list of partitions corresponding to this RDD.
    */
   override def getPartitions: Array[Partition] = partitions
 
+  private var updatedSchema: StructType = new StructType()
+
   /**
    * `columns`, but as a String suitable for injection into a SQL query.
    */
   private val columnList: String = {
+    val (compiledAgg, aggDataType) =
+      JDBCRDD.compileAggregates(aggregation.aggregateExpressions, JdbcDialects.get(url))
     val sb = new StringBuilder()
-    columns.foreach(x => sb.append(",").append(x))
-    if (sb.isEmpty) "1" else sb.substring(1)
+    if (compiledAgg.length == 0) {
+      updatedSchema = schema
+      columns.foreach(x => sb.append(",").append(x))
+    } else {
+      getAggregateColumnsList(sb, compiledAgg, aggDataType)
+    }
+    if (sb.length == 0) "1" else sb.substring(1)
+  }
+
+  private def getAggregateColumnsList(

Review comment:
       Shall we add a comment here to explain what `getAggregateColumnsList` does and why it is needed?

##########
File path: sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
##########
@@ -700,6 +704,41 @@ object DataSourceStrategy
     (nonconvertiblePredicates ++ unhandledPredicates, pushedFilters, handledFilters)
   }
 
+  private def columnAsString(e: Expression): String = e match {

Review comment:
       Let's wait for others and see if there are any other opinions.

##########
File path: sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
##########
@@ -296,13 +398,15 @@ private[jdbc] class JDBCRDD(
 
     val myWhereClause = getWhereClause(part)
 
-    val sqlText = s"SELECT $columnList FROM ${options.tableOrQuery} $myWhereClause"
+    val sqlText = s"SELECT $columnList FROM ${options.tableOrQuery} $myWhereClause" +
+      s" $getGroupByClause"
     stmt = conn.prepareStatement(sqlText,
         ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)
     stmt.setFetchSize(options.fetchSize)
     stmt.setQueryTimeout(options.queryTimeout)
     rs = stmt.executeQuery()
-    val rowsIterator = JdbcUtils.resultSetToSparkInternalRows(rs, schema, inputMetrics)
+

Review comment:
       unnecessary change?

##########
File path: sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
##########
@@ -181,21 +243,49 @@ private[jdbc] class JDBCRDD(
     filters: Array[Filter],
     partitions: Array[Partition],
     url: String,
-    options: JDBCOptions)
+    options: JDBCOptions,
+    aggregation: Aggregation = Aggregation.empty)
   extends RDD[InternalRow](sc, Nil) {
 
   /**
    * Retrieve the list of partitions corresponding to this RDD.
    */
   override def getPartitions: Array[Partition] = partitions
 
+  private var updatedSchema: StructType = new StructType()
+
   /**
    * `columns`, but as a String suitable for injection into a SQL query.
    */
   private val columnList: String = {
+    val (compiledAgg, aggDataType) =
+      JDBCRDD.compileAggregates(aggregation.aggregateExpressions, JdbcDialects.get(url))
     val sb = new StringBuilder()
-    columns.foreach(x => sb.append(",").append(x))
-    if (sb.isEmpty) "1" else sb.substring(1)
+    if (compiledAgg.length == 0) {
+      updatedSchema = schema
+      columns.foreach(x => sb.append(",").append(x))
+    } else {
+      getAggregateColumnsList(sb, compiledAgg, aggDataType)

Review comment:
       Is `columns` empty in this case?

##########
File path: sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
##########
@@ -17,38 +17,122 @@
 
 package org.apache.spark.sql.execution.datasources.v2
 
-import org.apache.spark.sql.catalyst.expressions.{And, Expression, NamedExpression, ProjectionOverSchema, SubqueryExpression}
+import scala.collection.mutable.ArrayBuilder
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.planning.ScanOperation
-import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LogicalPlan, Project}
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.connector.read.{Scan, V1Scan}
+import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, V1Scan}
 import org.apache.spark.sql.execution.datasources.DataSourceStrategy
 import org.apache.spark.sql.sources
+import org.apache.spark.sql.sources.{AggregateFunc, Aggregation}
 import org.apache.spark.sql.types.StructType
 
-object V2ScanRelationPushDown extends Rule[LogicalPlan] {
+object V2ScanRelationPushDown extends Rule[LogicalPlan] with AliasHelper {
+
   import DataSourceV2Implicits._
 
   override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
-    case ScanOperation(project, filters, relation: DataSourceV2Relation) =>
-      val scanBuilder = relation.table.asReadable.newScanBuilder(relation.options)
+    case Aggregate(groupingExpressions, resultExpressions, child) =>
+      child match {
+        case ScanOperation(project, filters, relation: DataSourceV2Relation) =>
+          val scanBuilder = relation.table.asReadable.newScanBuilder(relation.options)
+          val (pushedFilters, postScanFilters) = pushDownFilter(scanBuilder, filters, relation)
+          if (postScanFilters.nonEmpty) {
+            Aggregate(groupingExpressions, resultExpressions, child)
+          } else { // only push down aggregate of all the filers can be push down

Review comment:
       of -> if

##########
File path: sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
##########
@@ -17,38 +17,122 @@
 
 package org.apache.spark.sql.execution.datasources.v2
 
-import org.apache.spark.sql.catalyst.expressions.{And, Expression, NamedExpression, ProjectionOverSchema, SubqueryExpression}
+import scala.collection.mutable.ArrayBuilder
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.planning.ScanOperation
-import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LogicalPlan, Project}
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.connector.read.{Scan, V1Scan}
+import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, V1Scan}
 import org.apache.spark.sql.execution.datasources.DataSourceStrategy
 import org.apache.spark.sql.sources
+import org.apache.spark.sql.sources.{AggregateFunc, Aggregation}
 import org.apache.spark.sql.types.StructType
 
-object V2ScanRelationPushDown extends Rule[LogicalPlan] {
+object V2ScanRelationPushDown extends Rule[LogicalPlan] with AliasHelper {
+
   import DataSourceV2Implicits._
 
   override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
-    case ScanOperation(project, filters, relation: DataSourceV2Relation) =>
-      val scanBuilder = relation.table.asReadable.newScanBuilder(relation.options)
+    case Aggregate(groupingExpressions, resultExpressions, child) =>
+      child match {
+        case ScanOperation(project, filters, relation: DataSourceV2Relation) =>
+          val scanBuilder = relation.table.asReadable.newScanBuilder(relation.options)
+          val (pushedFilters, postScanFilters) = pushDownFilter(scanBuilder, filters, relation)
+          if (postScanFilters.nonEmpty) {
+            Aggregate(groupingExpressions, resultExpressions, child)
+          } else { // only push down aggregate of all the filers can be push down
+            val aliasMap = getAliasMap(project)
+            var aggregates = resultExpressions.flatMap { expr =>
+              expr.collect {
+                case agg: AggregateExpression =>
+                  replaceAlias(agg, aliasMap).asInstanceOf[AggregateExpression]
+              }
+            }
+            aggregates = DataSourceStrategy.normalizeExprs(aggregates, relation.output)
+              .asInstanceOf[Seq[AggregateExpression]]
 
-      val normalizedFilters = DataSourceStrategy.normalizeExprs(filters, relation.output)
-      val (normalizedFiltersWithSubquery, normalizedFiltersWithoutSubquery) =
-        normalizedFilters.partition(SubqueryExpression.hasSubquery)
+            val groupingExpressionsWithoutAlias = groupingExpressions.flatMap{ expr =>
+              expr.collect {
+                case a: AttributeReference => replaceAlias(a, aliasMap)
+              }
+            }
+            val normalizedgroupingExpressions =
+              DataSourceStrategy.normalizeExprs(groupingExpressionsWithoutAlias, relation.output)
 
-      // `pushedFilters` will be pushed down and evaluated in the underlying data sources.
-      // `postScanFilters` need to be evaluated after the scan.
-      // `postScanFilters` and `pushedFilters` can overlap, e.g. the parquet row group filter.
-      val (pushedFilters, postScanFiltersWithoutSubquery) = PushDownUtils.pushFilters(
-        scanBuilder, normalizedFiltersWithoutSubquery)
-      val postScanFilters = postScanFiltersWithoutSubquery ++ normalizedFiltersWithSubquery
+            val aggregation = PushDownUtils.pushAggregates(scanBuilder, aggregates,
+              normalizedgroupingExpressions)
+
+            val (scan, output, normalizedProjects) =
+              processFilerAndColumn(scanBuilder, project, postScanFilters, relation)
+
+            logInfo(
+              s"""
+                 |Pushing operators to ${relation.name}
+                 |Pushed Filters: ${pushedFilters.mkString(", ")}
+                 |Post-Scan Filters: ${postScanFilters.mkString(",")}
+                 |Pushed Aggregate Functions: ${aggregation.aggregateExpressions.mkString(", ")}
+                 |Pushed Groupby: ${aggregation.groupByExpressions.mkString(", ")}
+                 |Output: ${output.mkString(", ")}
+             """.stripMargin)
+
+            val wrappedScan = scan match {
+              case v1: V1Scan =>
+                val translated = filters.flatMap(DataSourceStrategy.translateFilter(_, true))
+                V1ScanWrapper(v1, translated, pushedFilters, aggregation)
+              case _ => scan
+            }
+            if (aggregation.aggregateExpressions.isEmpty) {
+              val plan = buildLogicalPlan(project, relation, wrappedScan, output,
+                normalizedProjects, postScanFilters)
+              Aggregate(groupingExpressions, resultExpressions, plan)
+            } else {
+              val aggOutputBuilder = ArrayBuilder.make[AttributeReference]
+              for (i <- 0 until aggregates.length) {
+                aggOutputBuilder += AttributeReference(
+                  aggregation.aggregateExpressions(i).toString, aggregates(i).dataType)()
+              }
+              for (groupBy <- groupingExpressions) {
+                aggOutputBuilder += groupBy.asInstanceOf[AttributeReference]
+              }

Review comment:
       `groupingExpressions` is a `Seq[Expression]`; are we sure they are all `AttributeReference`s?




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org