Posted to commits@spark.apache.org by we...@apache.org on 2018/09/12 14:54:18 UTC
spark git commit: [SPARK-25352][SQL] Perform ordered global limit when limit number is bigger than topKSortFallbackThreshold
Repository: spark
Updated Branches:
refs/heads/master 79cc59718 -> 2f422398b
[SPARK-25352][SQL] Perform ordered global limit when limit number is bigger than topKSortFallbackThreshold
## What changes were proposed in this pull request?
We have an optimization on global limit that evenly distributes limit rows across all partitions. This optimization doesn't work for ordered results.
A query ending with sort + limit is, in most cases, executed by `TakeOrderedAndProjectExec`.
But if the limit number is bigger than `SQLConf.TOP_K_SORT_FALLBACK_THRESHOLD`, a global limit is used instead. In that case, the global limit must preserve the sort order, i.e., perform an ordered global limit.
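For example, here is a query shape that exercises the new path (a minimal sketch; the threshold is lowered only to make the fallback easy to trigger):

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").getOrCreate()

// Lower the top-K fallback threshold so a modest limit triggers the fallback path.
spark.conf.set("spark.sql.execution.topKSortFallbackThreshold", "100")

val df = spark.range(1000).toDF("id").repartition(3).sort("id")

df.limit(99).explain()   // limit < threshold: planned as TakeOrderedAndProjectExec
df.limit(100).explain()  // limit >= threshold: planned as an ordered global limit
```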
## How was this patch tested?
Unit tests.
Closes #22344 from viirya/SPARK-25352.
Authored-by: Liang-Chi Hsieh <vi...@gmail.com>
Signed-off-by: Wenchen Fan <we...@databricks.com>
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2f422398
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2f422398
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2f422398
Branch: refs/heads/master
Commit: 2f422398b524eacc89ab58e423bb134ae3ca3941
Parents: 79cc597
Author: Liang-Chi Hsieh <vi...@gmail.com>
Authored: Wed Sep 12 22:54:05 2018 +0800
Committer: Wenchen Fan <we...@databricks.com>
Committed: Wed Sep 12 22:54:05 2018 +0800
----------------------------------------------------------------------
.../spark/sql/execution/SparkStrategies.scala | 44 ++++++---
.../org/apache/spark/sql/execution/limit.scala | 7 +-
.../org/apache/spark/sql/DataFrameSuite.scala | 22 ++++-
.../apache/spark/sql/execution/LimitSuite.scala | 81 +++++++++++++++++
.../execution/TakeOrderedAndProjectSuite.scala | 94 +++++++++++---------
5 files changed, 192 insertions(+), 56 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/2f422398/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index dbc6db6..7c8ce31 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -68,22 +68,42 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
object SpecialLimits extends Strategy {
override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case ReturnAnswer(rootPlan) => rootPlan match {
- case Limit(IntegerLiteral(limit), Sort(order, true, child))
- if limit < conf.topKSortFallbackThreshold =>
- TakeOrderedAndProjectExec(limit, order, child.output, planLater(child)) :: Nil
- case Limit(IntegerLiteral(limit), Project(projectList, Sort(order, true, child)))
- if limit < conf.topKSortFallbackThreshold =>
- TakeOrderedAndProjectExec(limit, order, projectList, planLater(child)) :: Nil
+ case Limit(IntegerLiteral(limit), s@Sort(order, true, child)) =>
+ if (limit < conf.topKSortFallbackThreshold) {
+ TakeOrderedAndProjectExec(limit, order, child.output, planLater(child)) :: Nil
+ } else {
+ GlobalLimitExec(limit,
+ LocalLimitExec(limit, planLater(s)),
+ orderedLimit = true) :: Nil
+ }
+ case Limit(IntegerLiteral(limit), p@Project(projectList, Sort(order, true, child))) =>
+ if (limit < conf.topKSortFallbackThreshold) {
+ TakeOrderedAndProjectExec(limit, order, projectList, planLater(child)) :: Nil
+ } else {
+ GlobalLimitExec(limit,
+ LocalLimitExec(limit, planLater(p)),
+ orderedLimit = true) :: Nil
+ }
case Limit(IntegerLiteral(limit), child) =>
CollectLimitExec(limit, planLater(child)) :: Nil
case other => planLater(other) :: Nil
}
- case Limit(IntegerLiteral(limit), Sort(order, true, child))
- if limit < conf.topKSortFallbackThreshold =>
- TakeOrderedAndProjectExec(limit, order, child.output, planLater(child)) :: Nil
- case Limit(IntegerLiteral(limit), Project(projectList, Sort(order, true, child)))
- if limit < conf.topKSortFallbackThreshold =>
- TakeOrderedAndProjectExec(limit, order, projectList, planLater(child)) :: Nil
+ case Limit(IntegerLiteral(limit), s@Sort(order, true, child)) =>
+ if (limit < conf.topKSortFallbackThreshold) {
+ TakeOrderedAndProjectExec(limit, order, child.output, planLater(child)) :: Nil
+ } else {
+ GlobalLimitExec(limit,
+ LocalLimitExec(limit, planLater(s)),
+ orderedLimit = true) :: Nil
+ }
+ case Limit(IntegerLiteral(limit), p@Project(projectList, Sort(order, true, child))) =>
+ if (limit < conf.topKSortFallbackThreshold) {
+ TakeOrderedAndProjectExec(limit, order, projectList, planLater(child)) :: Nil
+ } else {
+ GlobalLimitExec(limit,
+ LocalLimitExec(limit, planLater(p)),
+ orderedLimit = true) :: Nil
+ }
case _ => Nil
}
}
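The effect of the strategy change above can be verified from a test or the shell. A hedged sketch, reusing the `df` and lowered threshold from the example in the description:

```scala
import org.apache.spark.sql.execution.{GlobalLimitExec, TakeOrderedAndProjectExec}

// limit >= threshold: the fallback branch now plans an ordered global limit.
val plan = df.limit(100).queryExecution.executedPlan
assert(plan.collect { case g: GlobalLimitExec if g.orderedLimit => g }.nonEmpty)

// limit < threshold: the top-K operator is still chosen.
val topK = df.limit(99).queryExecution.executedPlan
assert(topK.collect { case t: TakeOrderedAndProjectExec => t }.nonEmpty)
```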
http://git-wip-us.apache.org/repos/asf/spark/blob/2f422398/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
index fb46970..1a09632 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
@@ -98,7 +98,8 @@ case class LocalLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode wi
/**
* Take the `limit` elements of the child output.
*/
-case class GlobalLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode {
+case class GlobalLimitExec(limit: Int, child: SparkPlan,
+ orderedLimit: Boolean = false) extends UnaryExecNode {
override def output: Seq[Attribute] = child.output
@@ -126,7 +127,9 @@ case class GlobalLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode {
// When enabled, Spark takes rows from each partition repeatedly until it reaches the
// limit number. When disabled, Spark takes all rows from the first partition, then rows
// from the second partition, and so on, until it reaches the limit number.
- val flatGlobalLimit = sqlContext.conf.limitFlatGlobalLimit
+ // This optimization is disabled when we need to preserve the row order produced by a
+ // global sort, e.g., select * from table order by col limit 10.
+ val flatGlobalLimit = sqlContext.conf.limitFlatGlobalLimit && !orderedLimit
val shuffled = new ShuffledRowRDD(shuffleDependency)
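To see why `flatGlobalLimit` must be disabled for ordered limits, consider three range-partitioned (globally sorted) partitions and limit = 3. A minimal, illustrative sketch of the two strategies (not the actual `GlobalLimitExec` code):

```scala
// Globally sorted data, range-partitioned into three partitions.
val partitions = Seq(Seq(1, 2, 3), Seq(4, 5, 6), Seq(7, 8, 9))
val limit = 3

// Flat global limit: take rows from each partition evenly until the limit is
// reached. Efficient, but it breaks the global order: Seq(1, 4, 7).
val flat = partitions.flatMap(_.take(limit / partitions.size)).take(limit)

// Ordered global limit: consume partitions in order, yielding the true top 3
// of the sorted output: Seq(1, 2, 3).
val ordered = partitions.flatten.take(limit)
```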
http://git-wip-us.apache.org/repos/asf/spark/blob/2f422398/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 279b7b8..f001b13 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -31,7 +31,7 @@ import org.apache.spark.scheduler.{SparkListener, SparkListenerJobEnd}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.expressions.Uuid
import org.apache.spark.sql.catalyst.plans.logical.{Filter, OneRowRelation, Union}
-import org.apache.spark.sql.execution.{FilterExec, QueryExecution, WholeStageCodegenExec}
+import org.apache.spark.sql.execution.{FilterExec, QueryExecution, TakeOrderedAndProjectExec, WholeStageCodegenExec}
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
import org.apache.spark.sql.functions._
@@ -2552,6 +2552,26 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
}
}
+ test("SPARK-25352: Ordered global limit when more than topKSortFallbackThreshold ") {
+ withSQLConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT.key -> "true") {
+ val baseDf = spark.range(1000).toDF.repartition(3).sort("id")
+
+ withSQLConf(SQLConf.TOP_K_SORT_FALLBACK_THRESHOLD.key -> "100") {
+ val expected = baseDf.limit(99)
+ val takeOrderedNode1 = expected.queryExecution.executedPlan
+ .find(_.isInstanceOf[TakeOrderedAndProjectExec])
+ assert(takeOrderedNode1.isDefined)
+
+ val result = baseDf.limit(100)
+ val takeOrderedNode2 = result.queryExecution.executedPlan
+ .find(_.isInstanceOf[TakeOrderedAndProjectExec])
+ assert(takeOrderedNode2.isEmpty)
+
+ checkAnswer(expected, result.collect().take(99))
+ }
+ }
+ }
+
test("SPARK-25368 Incorrect predicate pushdown returns wrong result") {
def check(newCol: Column, filter: Column, result: Seq[Row]): Unit = {
val df1 = spark.createDataFrame(Seq(
http://git-wip-us.apache.org/repos/asf/spark/blob/2f422398/sql/core/src/test/scala/org/apache/spark/sql/execution/LimitSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/LimitSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/LimitSuite.scala
new file mode 100644
index 0000000..a7840a5
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/LimitSuite.scala
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import scala.util.Random
+
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SharedSQLContext
+
+
+class LimitSuite extends SparkPlanTest with SharedSQLContext {
+
+ private var rand: Random = _
+ private var seed: Long = 0
+
+ protected override def beforeAll(): Unit = {
+ super.beforeAll()
+ seed = System.currentTimeMillis()
+ rand = new Random(seed)
+ }
+
+ test("Produce ordered global limit if more than topKSortFallbackThreshold") {
+ withSQLConf(SQLConf.TOP_K_SORT_FALLBACK_THRESHOLD.key -> "100") {
+ val df = LimitTest.generateRandomInputData(spark, rand).sort("a")
+
+ val globalLimit = df.limit(99).queryExecution.executedPlan.collect {
+ case g: GlobalLimitExec => g
+ }
+ assert(globalLimit.size == 0)
+
+ val topKSort = df.limit(99).queryExecution.executedPlan.collect {
+ case t: TakeOrderedAndProjectExec => t
+ }
+ assert(topKSort.size == 1)
+
+ val orderedGlobalLimit = df.limit(100).queryExecution.executedPlan.collect {
+ case g: GlobalLimitExec => g
+ }
+ assert(orderedGlobalLimit.size == 1 && orderedGlobalLimit(0).orderedLimit)
+ }
+ }
+
+ test("Ordered global limit") {
+ val baseDf = LimitTest.generateRandomInputData(spark, rand)
+ .select("a").repartition(3).sort("a")
+
+ withSQLConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT.key -> "true") {
+ val orderedGlobalLimit = GlobalLimitExec(3, baseDf.queryExecution.sparkPlan,
+ orderedLimit = true)
+ val orderedGlobalLimitResult = SparkPlanTest.executePlan(orderedGlobalLimit, spark.sqlContext)
+ .map(_.getInt(0))
+
+ val globalLimit = GlobalLimitExec(3, baseDf.queryExecution.sparkPlan, orderedLimit = false)
+ val globalLimitResult = SparkPlanTest.executePlan(globalLimit, spark.sqlContext)
+ .map(_.getInt(0))
+
+ // With flat global limit enabled, a global limit without `orderedLimit` takes rows
+ // from each partition evenly rather than partition by partition. After a global sort,
+ // values in the second partition are larger than those in the first, so only the
+ // first value can match the ordered result.
+ assert(orderedGlobalLimitResult(0) == globalLimitResult(0))
+ assert(orderedGlobalLimitResult(1) < globalLimitResult(1))
+ assert(orderedGlobalLimitResult(2) < globalLimitResult(2))
+ }
+ }
+}
+
http://git-wip-us.apache.org/repos/asf/spark/blob/2f422398/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala
index f076959..9322204 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution
import scala.util.Random
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.internal.SQLConf
@@ -32,28 +32,10 @@ class TakeOrderedAndProjectSuite extends SparkPlanTest with SharedSQLContext {
private var rand: Random = _
private var seed: Long = 0
- private val originalLimitFlatGlobalLimit = SQLConf.get.getConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT)
-
protected override def beforeAll(): Unit = {
super.beforeAll()
seed = System.currentTimeMillis()
rand = new Random(seed)
-
- // Disable the optimization to make Sort-Limit match `TakeOrderedAndProject` semantics.
- SQLConf.get.setConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT, false)
- }
-
- protected override def afterAll() = {
- SQLConf.get.setConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT, originalLimitFlatGlobalLimit)
- super.afterAll()
- }
-
- private def generateRandomInputData(): DataFrame = {
- val schema = new StructType()
- .add("a", IntegerType, nullable = false)
- .add("b", IntegerType, nullable = false)
- val inputData = Seq.fill(10000)(Row(rand.nextInt(), rand.nextInt()))
- spark.createDataFrame(sparkContext.parallelize(Random.shuffle(inputData), 10), schema)
}
/**
@@ -66,32 +48,62 @@ class TakeOrderedAndProjectSuite extends SparkPlanTest with SharedSQLContext {
val sortOrder = 'a.desc :: 'b.desc :: Nil
test("TakeOrderedAndProject.doExecute without project") {
- withClue(s"seed = $seed") {
- checkThatPlansAgree(
- generateRandomInputData(),
- input =>
- noOpFilter(TakeOrderedAndProjectExec(limit, sortOrder, input.output, input)),
- input =>
- GlobalLimitExec(limit,
- LocalLimitExec(limit,
- SortExec(sortOrder, true, input))),
- sortAnswers = false)
+ withSQLConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT.key -> "false") {
+ withClue(s"seed = $seed") {
+ checkThatPlansAgree(
+ LimitTest.generateRandomInputData(spark, rand),
+ input =>
+ noOpFilter(TakeOrderedAndProjectExec(limit, sortOrder, input.output, input)),
+ input =>
+ GlobalLimitExec(limit,
+ LocalLimitExec(limit,
+ SortExec(sortOrder, true, input))),
+ sortAnswers = false)
+ }
}
}
test("TakeOrderedAndProject.doExecute with project") {
- withClue(s"seed = $seed") {
- checkThatPlansAgree(
- generateRandomInputData(),
- input =>
- noOpFilter(
- TakeOrderedAndProjectExec(limit, sortOrder, Seq(input.output.last), input)),
- input =>
- GlobalLimitExec(limit,
- LocalLimitExec(limit,
- ProjectExec(Seq(input.output.last),
- SortExec(sortOrder, true, input)))),
- sortAnswers = false)
+ withSQLConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT.key -> "false") {
+ withClue(s"seed = $seed") {
+ checkThatPlansAgree(
+ LimitTest.generateRandomInputData(spark, rand),
+ input =>
+ noOpFilter(
+ TakeOrderedAndProjectExec(limit, sortOrder, Seq(input.output.last), input)),
+ input =>
+ GlobalLimitExec(limit,
+ LocalLimitExec(limit,
+ ProjectExec(Seq(input.output.last),
+ SortExec(sortOrder, true, input)))),
+ sortAnswers = false)
+ }
}
}
+
+ test("TakeOrderedAndProject.doExecute equals to ordered global limit") {
+ withSQLConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT.key -> "true") {
+ withClue(s"seed = $seed") {
+ checkThatPlansAgree(
+ LimitTest.generateRandomInputData(spark, rand),
+ input =>
+ noOpFilter(TakeOrderedAndProjectExec(limit, sortOrder, input.output, input)),
+ input =>
+ GlobalLimitExec(limit,
+ LocalLimitExec(limit,
+ SortExec(sortOrder, true, input)), orderedLimit = true),
+ sortAnswers = false)
+ }
+ }
+ }
+}
+
+object LimitTest {
+ def generateRandomInputData(spark: SparkSession, rand: Random): DataFrame = {
+ val schema = new StructType()
+ .add("a", IntegerType, nullable = false)
+ .add("b", IntegerType, nullable = false)
+ val inputData = Seq.fill(10000)(Row(rand.nextInt(), rand.nextInt()))
+ spark.createDataFrame(spark.sparkContext.parallelize(Random.shuffle(inputData), 10), schema)
+ }
}