Posted to commits@spark.apache.org by hu...@apache.org on 2021/10/29 00:00:04 UTC

[spark] branch master updated: [SPARK-37020][SQL] DS V2 LIMIT push down

This is an automated email from the ASF dual-hosted git repository.

huaxingao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 9821a28  [SPARK-37020][SQL] DS V2 LIMIT push down
9821a28 is described below

commit 9821a286c7d5ee5e0668c49c893de158809ec38f
Author: Huaxin Gao <hu...@apple.com>
AuthorDate: Thu Oct 28 16:59:12 2021 -0700

    [SPARK-37020][SQL] DS V2 LIMIT push down
    
    ### What changes were proposed in this pull request?
    Push down LIMIT to the data source for better performance
    
    ### Why are the changes needed?
    For LIMIT, e.g. `SELECT * FROM table LIMIT 10`, Spark retrieves all the data from the table and then returns 10 rows. If we can push LIMIT down to the data source side, the data transferred to Spark will be dramatically reduced.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, a new interface `SupportsPushDownLimit` is added.
    
    ### How was this patch tested?
    A new test is added.
    
    Closes #34291 from huaxingao/pushdownLimit.
    
    Authored-by: Huaxin Gao <hu...@apple.com>
    Signed-off-by: Huaxin Gao <hu...@apple.com>
---
 docs/sql-data-sources-jdbc.md                      |  9 +++
 .../spark/sql/connector/read/ScanBuilder.java      |  6 +-
 ...ScanBuilder.java => SupportsPushDownLimit.java} | 17 +++---
 .../spark/sql/execution/DataSourceScanExec.scala   |  4 +-
 .../execution/datasources/DataSourceStrategy.scala |  3 +
 .../execution/datasources/jdbc/JDBCOptions.scala   |  4 ++
 .../sql/execution/datasources/jdbc/JDBCRDD.scala   | 15 +++--
 .../execution/datasources/jdbc/JDBCRelation.scala  |  6 +-
 .../datasources/v2/DataSourceV2Strategy.scala      |  4 +-
 .../execution/datasources/v2/PushDownUtils.scala   | 13 +++-
 .../datasources/v2/V2ScanRelationPushDown.scala    | 30 ++++++++--
 .../execution/datasources/v2/jdbc/JDBCScan.scala   |  5 +-
 .../datasources/v2/jdbc/JDBCScanBuilder.scala      | 15 ++++-
 .../org/apache/spark/sql/jdbc/DerbyDialect.scala   |  4 ++
 .../org/apache/spark/sql/jdbc/JdbcDialects.scala   | 12 ++++
 .../apache/spark/sql/jdbc/MsSqlServerDialect.scala |  3 +
 .../apache/spark/sql/jdbc/TeradataDialect.scala    |  3 +
 .../org/apache/spark/sql/jdbc/JDBCV2Suite.scala    | 69 +++++++++++++++++++++-
 18 files changed, 191 insertions(+), 31 deletions(-)
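
For orientation, a minimal Scala usage sketch (mirroring the SparkConf settings in the new JDBCV2Suite at the bottom of this patch; the H2 URL, the catalog name "h2" and the table name are placeholders, not a recommended setup): with pushDownLimit enabled on a DS V2 JDBC catalog, a trailing .limit(n) can be evaluated by the database instead of only by Spark.

    import org.apache.spark.sql.SparkSession

    // Sketch only: catalog/URL/table values are placeholders taken from the test below.
    val spark = SparkSession.builder()
      .master("local[*]")
      .config("spark.sql.catalog.h2",
        "org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog")
      .config("spark.sql.catalog.h2.url", "jdbc:h2:mem:testdb0")
      .config("spark.sql.catalog.h2.driver", "org.h2.Driver")
      .config("spark.sql.catalog.h2.pushDownLimit", "true")
      .getOrCreate()

    // LIMIT 10 can now be pushed into the SQL sent to H2; Spark still applies the
    // limit again on the rows it receives (see the doc change below).
    val df = spark.read.table("h2.test.employee").limit(10)
    df.show()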

diff --git a/docs/sql-data-sources-jdbc.md b/docs/sql-data-sources-jdbc.md
index 16d525e..361b92be 100644
--- a/docs/sql-data-sources-jdbc.md
+++ b/docs/sql-data-sources-jdbc.md
@@ -247,6 +247,15 @@ logging into the data sources.
   </tr>
 
   <tr>
+    <td><code>pushDownLimit</code></td>
+    <td><code>false</code></td>
+    <td>
+     The option to enable or disable LIMIT push-down into the JDBC data source. The default value is false, in which case Spark does not push LIMIT down to the JDBC data source. If set to true, LIMIT is pushed down to the JDBC data source. Spark still applies the LIMIT on the result returned from the data source even when LIMIT is pushed down.
+    </td>
+    <td>read</td>
+  </tr>
+
+  <tr>
     <td><code>keytab</code></td>
     <td>(none)</td>
     <td>
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/ScanBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/ScanBuilder.java
index b46f620..20c9d2e 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/ScanBuilder.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/ScanBuilder.java
@@ -21,9 +21,9 @@ import org.apache.spark.annotation.Evolving;
 
 /**
  * An interface for building the {@link Scan}. Implementations can mixin SupportsPushDownXYZ
- * interfaces to do operator pushdown, and keep the operator pushdown result in the returned
- * {@link Scan}. When pushing down operators, Spark pushes down filters first, then pushes down
- * aggregates or applies column pruning.
+ * interfaces to do operator push down, and keep the operator push down result in the returned
+ * {@link Scan}. When pushing down operators, the push down order is:
+ * filter -&gt; aggregate -&gt; limit -&gt; column pruning.
  *
  * @since 3.0.0
  */
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/ScanBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownLimit.java
similarity index 68%
copy from sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/ScanBuilder.java
copy to sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownLimit.java
index b46f620..7e50bf1 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/ScanBuilder.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownLimit.java
@@ -20,14 +20,17 @@ package org.apache.spark.sql.connector.read;
 import org.apache.spark.annotation.Evolving;
 
 /**
- * An interface for building the {@link Scan}. Implementations can mixin SupportsPushDownXYZ
- * interfaces to do operator pushdown, and keep the operator pushdown result in the returned
- * {@link Scan}. When pushing down operators, Spark pushes down filters first, then pushes down
- * aggregates or applies column pruning.
+ * A mix-in interface for {@link ScanBuilder}. Data sources can implement this interface to
+ * push down LIMIT. Please note that the combination of LIMIT with other operations
+ * such as AGGREGATE, GROUP BY, SORT BY, CLUSTER BY, DISTRIBUTE BY, etc. is NOT pushed down.
  *
- * @since 3.0.0
+ * @since 3.3.0
  */
 @Evolving
-public interface ScanBuilder {
-  Scan build();
+public interface SupportsPushDownLimit extends ScanBuilder {
+
+  /**
+   * Pushes down LIMIT to the data source.
+   */
+  boolean pushLimit(int limit);
 }
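
To make the new mix-in concrete, here is a hedged Scala sketch of a third-party ScanBuilder that accepts the pushed limit; MyScan, MyScanBuilder and the stored field are hypothetical, only the pushLimit contract comes from the interface above.

    import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownLimit}
    import org.apache.spark.sql.types.StructType

    // Hypothetical scan that would only produce `limit` rows (0 = no limit pushed).
    class MyScan(limit: Int) extends Scan {
      override def readSchema(): StructType = new StructType()
    }

    // Hypothetical builder: remember the limit and return true so Spark knows the
    // source will honor it (Spark still applies a final LIMIT on the result).
    class MyScanBuilder extends ScanBuilder with SupportsPushDownLimit {
      private var pushedLimit: Int = 0

      override def pushLimit(limit: Int): Boolean = {
        pushedLimit = limit
        true
      }

      override def build(): Scan = new MyScan(pushedLimit)
    }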
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
index 4f282ed..86b29d4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
@@ -104,6 +104,7 @@ case class RowDataSourceScanExec(
     filters: Set[Filter],
     handledFilters: Set[Filter],
     aggregation: Option[Aggregation],
+    limit: Option[Int],
     rdd: RDD[InternalRow],
     @transient relation: BaseRelation,
     tableIdentifier: Option[TableIdentifier])
@@ -153,7 +154,8 @@ case class RowDataSourceScanExec(
       "ReadSchema" -> requiredSchema.catalogString,
       "PushedFilters" -> seqToString(markedFilters.toSeq),
       "PushedAggregates" -> aggString,
-      "PushedGroupby" -> groupByString)
+      "PushedGroupby" -> groupByString) ++
+      limit.map(value => "PushedLimit" -> s"LIMIT $value")
   }
 
   // Don't care about `rdd` and `tableIdentifier` when canonicalizing.
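
The new metadata entry makes a pushed limit visible on the physical scan node. A small sketch of inspecting it, assuming a DataFrame `df` whose scan went through the V1 fallback path shown later in this patch:

    import org.apache.spark.sql.execution.RowDataSourceScanExec

    // Sketch: read back the "PushedLimit" entry added above, e.g. Some("LIMIT 10").
    val pushedLimitInfo = df.queryExecution.executedPlan.collectFirst {
      case scan: RowDataSourceScanExec => scan.metadata.get("PushedLimit")
    }.flatten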
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index 2c2dac1..81cd37f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -336,6 +336,7 @@ object DataSourceStrategy
         Set.empty,
         Set.empty,
         None,
+        None,
         toCatalystRDD(l, baseRelation.buildScan()),
         baseRelation,
         None) :: Nil
@@ -410,6 +411,7 @@ object DataSourceStrategy
         pushedFilters.toSet,
         handledFilters,
         None,
+        None,
         scanBuilder(requestedColumns, candidatePredicates, pushedFilters),
         relation.relation,
         relation.catalogTable.map(_.identifier))
@@ -433,6 +435,7 @@ object DataSourceStrategy
         pushedFilters.toSet,
         handledFilters,
         None,
+        None,
         scanBuilder(requestedColumns, candidatePredicates, pushedFilters),
         relation.relation,
         relation.catalogTable.map(_.identifier))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
index 510a22c..e0730f3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
@@ -191,6 +191,9 @@ class JDBCOptions(
   // An option to allow/disallow pushing down aggregate into JDBC data source
   val pushDownAggregate = parameters.getOrElse(JDBC_PUSHDOWN_AGGREGATE, "false").toBoolean
 
+  // An option to allow/disallow pushing down LIMIT into JDBC data source
+  val pushDownLimit = parameters.getOrElse(JDBC_PUSHDOWN_LIMIT, "false").toBoolean
+
   // The local path of user's keytab file, which is assumed to be pre-uploaded to all nodes either
   // by --files option of spark-submit or manually
   val keytab = {
@@ -266,6 +269,7 @@ object JDBCOptions {
   val JDBC_SESSION_INIT_STATEMENT = newOption("sessionInitStatement")
   val JDBC_PUSHDOWN_PREDICATE = newOption("pushDownPredicate")
   val JDBC_PUSHDOWN_AGGREGATE = newOption("pushDownAggregate")
+  val JDBC_PUSHDOWN_LIMIT = newOption("pushDownLimit")
   val JDBC_KEYTAB = newOption("keytab")
   val JDBC_PRINCIPAL = newOption("principal")
   val JDBC_TABLE_COMMENT = newOption("tableComment")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
index e024e4b..7973850 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
@@ -179,6 +179,8 @@ object JDBCRDD extends Logging {
    * @param options - JDBC options that contains url, table and other information.
    * @param outputSchema - The schema of the columns or aggregate columns to SELECT.
    * @param groupByColumns - The pushed down group by columns.
+   * @param limit - The pushed down limit. A value of 0 means the query has no LIMIT
+   *                or the LIMIT is not pushed down.
    *
    * @return An RDD representing "SELECT requiredColumns FROM fqTable".
    */
@@ -190,7 +192,8 @@ object JDBCRDD extends Logging {
       parts: Array[Partition],
       options: JDBCOptions,
       outputSchema: Option[StructType] = None,
-      groupByColumns: Option[Array[String]] = None): RDD[InternalRow] = {
+      groupByColumns: Option[Array[String]] = None,
+      limit: Int = 0): RDD[InternalRow] = {
     val url = options.url
     val dialect = JdbcDialects.get(url)
     val quotedColumns = if (groupByColumns.isEmpty) {
@@ -208,7 +211,8 @@ object JDBCRDD extends Logging {
       parts,
       url,
       options,
-      groupByColumns)
+      groupByColumns,
+      limit)
   }
 }
 
@@ -226,7 +230,8 @@ private[jdbc] class JDBCRDD(
     partitions: Array[Partition],
     url: String,
     options: JDBCOptions,
-    groupByColumns: Option[Array[String]])
+    groupByColumns: Option[Array[String]],
+    limit: Int)
   extends RDD[InternalRow](sc, Nil) {
 
   /**
@@ -349,8 +354,10 @@ private[jdbc] class JDBCRDD(
 
     val myWhereClause = getWhereClause(part)
 
+    val myLimitClause: String = dialect.getLimitClause(limit)
+
     val sqlText = s"SELECT $columnList FROM ${options.tableOrQuery} $myWhereClause" +
-      s" $getGroupByClause"
+      s" $getGroupByClause $myLimitClause"
     stmt = conn.prepareStatement(sqlText,
         ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)
     stmt.setFetchSize(options.fetchSize)
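
For illustration, a simplified Scala sketch of how the sqlText above is assembled once a limit has been pushed; the column list, table and WHERE clause are placeholder values, and the default dialect's getLimitClause is inlined.

    // Placeholder inputs standing in for the values computed in JDBCRDD.compute.
    val columnList = "NAME,SALARY"
    val myWhereClause = "WHERE DEPT > 1"
    val getGroupByClause = ""
    val limit = 10

    // Default JdbcDialect behaviour (see JdbcDialects.scala below): 0 means no clause.
    val myLimitClause = if (limit > 0) s"LIMIT $limit" else ""

    val sqlText = s"SELECT $columnList FROM test.employee $myWhereClause" +
      s" $getGroupByClause $myLimitClause"
    // sqlText: "SELECT NAME,SALARY FROM test.employee WHERE DEPT > 1  LIMIT 10"
    // (the doubled space from the empty GROUP BY clause is harmless to the database)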
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
index 8098fa0..ff9fcd4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
@@ -298,7 +298,8 @@ private[sql] case class JDBCRelation(
       requiredColumns: Array[String],
       finalSchema: StructType,
       filters: Array[Filter],
-      groupByColumns: Option[Array[String]]): RDD[Row] = {
+      groupByColumns: Option[Array[String]],
+      limit: Int): RDD[Row] = {
     // Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row]
     JDBCRDD.scanTable(
       sparkSession.sparkContext,
@@ -308,7 +309,8 @@ private[sql] case class JDBCRelation(
       parts,
       jdbcOptions,
       Some(finalSchema),
-      groupByColumns).asInstanceOf[RDD[Row]]
+      groupByColumns,
+      limit).asInstanceOf[RDD[Row]]
   }
 
   override def insert(data: DataFrame, overwrite: Boolean): Unit = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
index 66ee431..b688c32 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
@@ -94,7 +94,7 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
 
   override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
     case PhysicalOperation(project, filters,
-        DataSourceV2ScanRelation(_, V1ScanWrapper(scan, pushed, aggregate), output)) =>
+        DataSourceV2ScanRelation(_, V1ScanWrapper(scan, pushed, aggregate, limit), output)) =>
       val v1Relation = scan.toV1TableScan[BaseRelation with TableScan](session.sqlContext)
       if (v1Relation.schema != scan.readSchema()) {
         throw QueryExecutionErrors.fallbackV1RelationReportsInconsistentSchemaError(
@@ -102,12 +102,14 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
       }
       val rdd = v1Relation.buildScan()
       val unsafeRowRDD = DataSourceStrategy.toCatalystRDD(v1Relation, output, rdd)
+
       val dsScan = RowDataSourceScanExec(
         output,
         output.toStructType,
         Set.empty,
         pushed.toSet,
         aggregate,
+        limit,
         unsafeRowRDD,
         v1Relation,
         tableIdentifier = None)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala
index 335038a..a8c251a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.util.CharVarcharUtils
 import org.apache.spark.sql.connector.expressions.FieldReference
 import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
 import org.apache.spark.sql.connector.expressions.filter.{Filter => V2Filter}
-import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownRequiredColumns, SupportsPushDownV2Filters}
+import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownLimit, SupportsPushDownRequiredColumns, SupportsPushDownV2Filters}
 import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, PushableColumnWithoutNestedColumn}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources
@@ -139,6 +139,17 @@ object PushDownUtils extends PredicateHelper {
   }
 
   /**
+   * Pushes down LIMIT to the data source Scan.
+   */
+  def pushLimit(scanBuilder: ScanBuilder, limit: Int): Boolean = {
+    scanBuilder match {
+      case s: SupportsPushDownLimit =>
+        s.pushLimit(limit)
+      case _ => false
+    }
+  }
+
+  /**
    * Applies column pruning to the data source, w.r.t. the references of the given expressions.
    *
    * @return the `Scan` instance (since column pruning is the last step of operator pushdown),
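
A tiny sketch of the dispatch in pushLimit above (the anonymous builder is purely illustrative, and this assumes PushDownUtils is reachable from the caller's package): any ScanBuilder without the SupportsPushDownLimit mix-in falls through to false, so the optimizer keeps the Limit node.

    import org.apache.spark.sql.connector.read.{Scan, ScanBuilder}
    import org.apache.spark.sql.execution.datasources.v2.PushDownUtils

    // Hypothetical builder with no SupportsPushDownLimit mix-in.
    val plainBuilder = new ScanBuilder {
      override def build(): Scan = throw new UnsupportedOperationException("not needed")
    }

    // Hits the `case _ => false` branch above.
    assert(!PushDownUtils.pushLimit(plainBuilder, 10))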
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
index ec45a5d..960a1ea6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
@@ -19,11 +19,11 @@ package org.apache.spark.sql.execution.datasources.v2
 
 import scala.collection.mutable
 
-import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeReference, Expression, NamedExpression, PredicateHelper, ProjectionOverSchema, SubqueryExpression}
+import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeReference, Expression, IntegerLiteral, NamedExpression, PredicateHelper, ProjectionOverSchema, SubqueryExpression}
 import org.apache.spark.sql.catalyst.expressions.aggregate
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.planning.ScanOperation
-import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LeafNode, LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LeafNode, Limit, LogicalPlan, Project}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
 import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, V1Scan}
@@ -36,7 +36,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
   import DataSourceV2Implicits._
 
   def apply(plan: LogicalPlan): LogicalPlan = {
-    applyColumnPruning(pushDownAggregates(pushDownFilters(createScanBuilder(plan))))
+    applyColumnPruning(applyLimit(pushDownAggregates(pushDownFilters(createScanBuilder(plan)))))
   }
 
   private def createScanBuilder(plan: LogicalPlan) = plan.transform {
@@ -225,6 +225,19 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
       withProjection
   }
 
+  def applyLimit(plan: LogicalPlan): LogicalPlan = plan.transform {
+    case globalLimit @ Limit(IntegerLiteral(limitValue), child) =>
+      child match {
+        case ScanOperation(_, filter, sHolder: ScanBuilderHolder) if filter.length == 0 =>
+          val limitPushed = PushDownUtils.pushLimit(sHolder.builder, limitValue)
+          if (limitPushed) {
+            sHolder.setLimit(Some(limitValue))
+          }
+          globalLimit
+        case _ => globalLimit
+      }
+  }
+
   private def getWrappedScan(
       scan: Scan,
       sHolder: ScanBuilderHolder,
@@ -236,7 +249,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
             f.pushedFilters()
           case _ => Array.empty[sources.Filter]
         }
-        V1ScanWrapper(v1, pushedFilters, aggregation)
+        V1ScanWrapper(v1, pushedFilters, aggregation, sHolder.pushedLimit)
       case _ => scan
     }
   }
@@ -245,13 +258,18 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
 case class ScanBuilderHolder(
     output: Seq[AttributeReference],
     relation: DataSourceV2Relation,
-    builder: ScanBuilder) extends LeafNode
+    builder: ScanBuilder) extends LeafNode {
+  var pushedLimit: Option[Int] = None
+  private[sql] def setLimit(limit: Option[Int]): Unit = pushedLimit = limit
+}
+
 
 // A wrapper for v1 scan to carry the translated filters and the handled ones. This is required by
 // the physical v1 scan node.
 case class V1ScanWrapper(
     v1Scan: V1Scan,
     handledFilters: Seq[sources.Filter],
-    pushedAggregate: Option[Aggregation]) extends Scan {
+    pushedAggregate: Option[Aggregation],
+    pushedLimit: Option[Int]) extends Scan {
   override def readSchema(): StructType = v1Scan.readSchema()
 }
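
The applyLimit rule above only pushes the limit when no filters remain between the Limit and the scan holder (the filter.length == 0 guard). A hedged illustration against the H2 catalog used by the new test suite, assuming an active SparkSession `spark` configured as in the earlier sketch; the table name and UDF are placeholders.

    import org.apache.spark.sql.functions.udf
    import spark.implicits._

    // Pushed: the dept predicate is fully handled by the JDBC source, so no Filter
    // remains above the scan and the guard holds.
    val pushed = spark.read.table("h2.test.employee")
      .where($"dept" === 1)
      .limit(1)

    // Not pushed: the UDF predicate cannot be pushed down, a Filter stays in the
    // plan, and applyLimit leaves the Limit for Spark to execute.
    val isCat = udf { (name: String) => name.startsWith("cat") }
    val notPushed = spark.read.table("h2.test.employee")
      .filter(isCat($"name"))
      .limit(1)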
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala
index ef42691..94d9d14 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala
@@ -28,7 +28,8 @@ case class JDBCScan(
     prunedSchema: StructType,
     pushedFilters: Array[Filter],
     pushedAggregateColumn: Array[String] = Array(),
-    groupByColumns: Option[Array[String]]) extends V1Scan {
+    groupByColumns: Option[Array[String]],
+    pushedLimit: Int) extends V1Scan {
 
   override def readSchema(): StructType = prunedSchema
 
@@ -43,7 +44,7 @@ case class JDBCScan(
         } else {
           pushedAggregateColumn
         }
-        relation.buildScan(columnList, prunedSchema, pushedFilters, groupByColumns)
+        relation.buildScan(columnList, prunedSchema, pushedFilters, groupByColumns, pushedLimit)
       }
     }.asInstanceOf[T]
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala
index b0de7c0..1482674 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala
@@ -21,7 +21,7 @@ import scala.util.control.NonFatal
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
-import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownRequiredColumns}
+import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownLimit, SupportsPushDownRequiredColumns}
 import org.apache.spark.sql.execution.datasources.PartitioningUtils
 import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCRDD, JDBCRelation}
 import org.apache.spark.sql.jdbc.JdbcDialects
@@ -36,6 +36,7 @@ case class JDBCScanBuilder(
     with SupportsPushDownFilters
     with SupportsPushDownRequiredColumns
     with SupportsPushDownAggregates
+    with SupportsPushDownLimit
     with Logging {
 
   private val isCaseSensitive = session.sessionState.conf.caseSensitiveAnalysis
@@ -44,6 +45,16 @@ case class JDBCScanBuilder(
 
   private var finalSchema = schema
 
+  private var pushedLimit = 0
+
+  override def pushLimit(limit: Int): Boolean = {
+    if (jdbcOptions.pushDownLimit && JdbcDialects.get(jdbcOptions.url).supportsLimit) {
+      pushedLimit = limit
+      return true
+    }
+    false
+  }
+
   override def pushFilters(filters: Array[Filter]): Array[Filter] = {
     if (jdbcOptions.pushDownPredicate) {
       val dialect = JdbcDialects.get(jdbcOptions.url)
@@ -123,6 +134,6 @@ case class JDBCScanBuilder(
     // prunedSchema and quote them (will become "MAX(SALARY)", "MIN(BONUS)" and can't
     // be used in sql string.
     JDBCScan(JDBCRelation(schema, parts, jdbcOptions)(session), finalSchema, pushedFilter,
-      pushedAggregateList, pushedGroupByCols)
+      pushedAggregateList, pushedGroupByCols, pushedLimit)
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala
index 020733a..ecb514a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala
@@ -57,4 +57,8 @@ private object DerbyDialect extends JdbcDialect {
   override def getTableCommentQuery(table: String, comment: String): String = {
     throw QueryExecutionErrors.commentOnTableUnsupportedError()
   }
+
+  // ToDo: use fetch first n rows only for limit, e.g.
+  //  select * from employee fetch first 10 rows only;
+  override def supportsLimit(): Boolean = false
 }
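
One possible shape of the follow-up that the ToDo above describes, sketched here only and not part of this commit: Derby's FETCH FIRST syntax sits in the same trailing position as the default LIMIT clause, so the dialect could override the new hooks (defined in JdbcDialects.scala just below) along these lines.

    import org.apache.spark.sql.jdbc.JdbcDialect

    // Hypothetical dialect showing the shape of the ToDo; not part of this patch.
    object DerbyLimitSketch extends JdbcDialect {
      override def canHandle(url: String): Boolean = url.startsWith("jdbc:derby")

      // Derby does not accept "LIMIT n"; it uses FETCH FIRST in the same position.
      override def getLimitClause(limit: Integer): String = {
        if (limit > 0) s"FETCH FIRST $limit ROWS ONLY" else ""
      }

      override def supportsLimit(): Boolean = true
    }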
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
index 9e54ba7..ac6fd2f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
@@ -358,6 +358,18 @@ abstract class JdbcDialect extends Serializable with Logging{
   def classifyException(message: String, e: Throwable): AnalysisException = {
     new AnalysisException(message, cause = Some(e))
   }
+
+  /**
+   * Returns the LIMIT clause for the SELECT statement.
+   */
+  def getLimitClause(limit: Integer): String = {
+    if (limit > 0) s"LIMIT $limit" else ""
+  }
+
+  /**
+   * Returns whether the dialect supports LIMIT or not.
+   */
+  def supportsLimit(): Boolean = true
 }
 
 /**
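
These two methods are the extension point that external dialects would tune. A minimal sketch of registering a custom dialect that simply inherits the defaults; the URL prefix and object name are hypothetical.

    import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects}

    // Hypothetical dialect: inherits getLimitClause ("LIMIT n") and supportsLimit() = true.
    object MyDialect extends JdbcDialect {
      override def canHandle(url: String): Boolean = url.startsWith("jdbc:mydb")
    }

    JdbcDialects.registerDialect(MyDialect)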
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala
index ea98348..8dad5ef 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala
@@ -118,4 +118,7 @@ private object MsSqlServerDialect extends JdbcDialect {
   override def getTableCommentQuery(table: String, comment: String): String = {
     throw QueryExecutionErrors.commentOnTableUnsupportedError()
   }
+
+  // ToDo: use top n to get limit, e.g. select top 100 * from employee;
+  override def supportsLimit(): Boolean = false
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala
index 58fe62c..2a776bd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala
@@ -55,4 +55,7 @@ private case object TeradataDialect extends JdbcDialect {
   override def renameTable(oldTable: String, newTable: String): String = {
     s"RENAME TABLE $oldTable TO $newTable"
   }
+
+  // ToDo: use top n to get limit, e.g. select top 100 * from employee;
+  override def supportsLimit(): Boolean = false
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
index 02f10aa..d5b8ea9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
@@ -21,10 +21,10 @@ import java.sql.{Connection, DriverManager}
 import java.util.Properties
 
 import org.apache.spark.SparkConf
-import org.apache.spark.sql.{ExplainSuiteHelper, QueryTest, Row}
+import org.apache.spark.sql.{DataFrame, ExplainSuiteHelper, QueryTest, Row}
 import org.apache.spark.sql.catalyst.analysis.CannotReplaceMissingTableException
 import org.apache.spark.sql.catalyst.plans.logical.Filter
-import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation
+import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, V1ScanWrapper}
 import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog
 import org.apache.spark.sql.functions.{lit, sum, udf}
 import org.apache.spark.sql.test.SharedSparkSession
@@ -42,6 +42,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     .set("spark.sql.catalog.h2.url", url)
     .set("spark.sql.catalog.h2.driver", "org.h2.Driver")
     .set("spark.sql.catalog.h2.pushDownAggregate", "true")
+    .set("spark.sql.catalog.h2.pushDownLimit", "true")
 
   private def withConnection[T](f: Connection => T): T = {
     val conn = DriverManager.getConnection(url, new Properties())
@@ -92,6 +93,70 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     checkAnswer(sql("SELECT name, id FROM h2.test.people"), Seq(Row("fred", 1), Row("mary", 2)))
   }
 
+  test("simple scan with LIMIT") {
+    val df1 = spark.read.table("h2.test.employee")
+      .where($"dept" === 1).limit(1)
+    checkPushedLimit(df1, true, 1)
+    checkAnswer(df1, Seq(Row(1, "amy", 10000.00, 1000.0)))
+
+    val df2 = spark.read
+      .option("partitionColumn", "dept")
+      .option("lowerBound", "0")
+      .option("upperBound", "2")
+      .option("numPartitions", "2")
+      .table("h2.test.employee")
+      .filter($"dept" > 1)
+      .limit(1)
+    checkPushedLimit(df2, true, 1)
+    checkAnswer(df2, Seq(Row(2, "alex", 12000.00, 1200.0)))
+
+    val df3 = sql("SELECT name FROM h2.test.employee WHERE dept > 1 LIMIT 1")
+    val scan = df3.queryExecution.optimizedPlan.collectFirst {
+      case s: DataSourceV2ScanRelation => s
+    }.get
+    assert(scan.schema.names.sameElements(Seq("NAME")))
+    checkPushedLimit(df3, true, 1)
+    checkAnswer(df3, Seq(Row("alex")))
+
+    val df4 = spark.read
+      .table("h2.test.employee")
+      .groupBy("DEPT").sum("SALARY")
+      .limit(1)
+    checkPushedLimit(df4, false, 0)
+    checkAnswer(df4, Seq(Row(1, 19000.00)))
+
+    val df5 = spark.read
+      .table("h2.test.employee")
+      .sort("SALARY")
+      .limit(1)
+    checkPushedLimit(df5, false, 0)
+    checkAnswer(df5, Seq(Row(1, "cathy", 9000.00, 1200.0)))
+
+    val name = udf { (x: String) => x.matches("cat|dav|amy") }
+    val sub = udf { (x: String) => x.substring(0, 3) }
+    val df6 = spark.read
+      .table("h2.test.employee")
+      .select($"SALARY", $"BONUS", sub($"NAME").as("shortName"))
+      .filter(name($"shortName"))
+      .limit(1)
+    // LIMIT is pushed down only if all the filters are pushed down
+    checkPushedLimit(df6, false, 0)
+    checkAnswer(df6, Seq(Row(10000.00, 1000.0, "amy")))
+  }
+
+  private def checkPushedLimit(df: DataFrame, pushed: Boolean, limit: Int): Unit = {
+    df.queryExecution.optimizedPlan.collect {
+      case DataSourceV2ScanRelation(_, scan, _) => scan match {
+        case v1: V1ScanWrapper =>
+          if (pushed) {
+            assert(v1.pushedLimit.nonEmpty && v1.pushedLimit.get === limit)
+          } else {
+            assert(v1.pushedLimit.isEmpty)
+          }
+      }
+    }
+  }
+
   test("scan with filter push-down") {
     val df = spark.table("h2.test.people").filter($"id" > 1)
     val filters = df.queryExecution.optimizedPlan.collect {

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org