You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2021/04/01 06:59:56 UTC
[spark] branch master updated: [SPARK-34919][SQL] Change
partitioning to SinglePartition if partition number is 1
This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 89ae83d [SPARK-34919][SQL] Change partitioning to SinglePartition if partition number is 1
89ae83d is described below
commit 89ae83d19b9652348a685550c2c49920511160d5
Author: ulysses-you <ul...@gmail.com>
AuthorDate: Thu Apr 1 06:59:31 2021 +0000
[SPARK-34919][SQL] Change partitioning to SinglePartition if partition number is 1
### What changes were proposed in this pull request?
Change the partitioning to `SinglePartition` when the requested partition number is 1.
### Why are the changes needed?
For node `Repartition` and `RepartitionByExpression`, if partition number is 1 we can use `SinglePartition` instead of other `Partitioning`.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Added a unit test in `PlannerSuite` verifying that `REPARTITION(1)` and `REPARTITION(1, c)` both plan a `ShuffleExchangeExec` with `SinglePartition` output partitioning.
Closes #32012 from ulysses-you/SPARK-34919.
Authored-by: ulysses-you <ul...@gmail.com>
Signed-off-by: Wenchen Fan <we...@databricks.com>
---
.../catalyst/plans/logical/basicLogicalOperators.scala | 17 ++++++++++++++---
.../apache/spark/sql/execution/SparkStrategies.scala | 6 ++----
.../org/apache/spark/sql/execution/PlannerSuite.scala | 17 ++++++++++++++++-
.../spark/sql/execution/metric/SQLMetricsSuite.scala | 2 +-
4 files changed, 33 insertions(+), 9 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 24ccc61..9461dbf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTable.VIEW_STORING_ANALYZED_
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning, RoundRobinPartitioning}
+import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning, RoundRobinPartitioning, SinglePartition}
import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
@@ -1001,6 +1001,7 @@ abstract class RepartitionOperation extends UnaryNode {
def numPartitions: Int
override final def maxRows: Option[Long] = child.maxRows
override def output: Seq[Attribute] = child.output
+ def partitioning: Partitioning
}
/**
@@ -1012,6 +1013,14 @@ abstract class RepartitionOperation extends UnaryNode {
case class Repartition(numPartitions: Int, shuffle: Boolean, child: LogicalPlan)
extends RepartitionOperation {
require(numPartitions > 0, s"Number of partitions ($numPartitions) must be positive.")
+
+ override def partitioning: Partitioning = {
+ require(shuffle, "Partitioning can only be used in shuffle.")
+ numPartitions match {
+ case 1 => SinglePartition
+ case _ => RoundRobinPartitioning(numPartitions)
+ }
+ }
}
/**
@@ -1029,7 +1038,7 @@ case class RepartitionByExpression(
val numPartitions = optNumPartitions.getOrElse(conf.numShufflePartitions)
require(numPartitions > 0, s"Number of partitions ($numPartitions) must be positive.")
- val partitioning: Partitioning = {
+ override val partitioning: Partitioning = {
val (sortOrder, nonSortOrder) = partitionExpressions.partition(_.isInstanceOf[SortOrder])
require(sortOrder.isEmpty || nonSortOrder.isEmpty,
@@ -1041,7 +1050,9 @@ case class RepartitionByExpression(
|NonSortOrder: $nonSortOrder
""".stripMargin)
- if (sortOrder.nonEmpty) {
+ if (numPartitions == 1) {
+ SinglePartition
+ } else if (sortOrder.nonEmpty) {
RangePartitioning(sortOrder.map(_.asInstanceOf[SortOrder]), numPartitions)
} else if (nonSortOrder.nonEmpty) {
HashPartitioning(nonSortOrder, numPartitions)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 0e72050..ea83d59 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -27,7 +27,6 @@ import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, JoinSelec
import org.apache.spark.sql.catalyst.planning._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.catalyst.streaming.{InternalOutputModes, StreamingRelationV2}
import org.apache.spark.sql.execution.aggregate.AggUtils
import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec}
@@ -685,10 +684,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr,
planLater(left), planLater(right)) :: Nil
- case logical.Repartition(numPartitions, shuffle, child) =>
+ case r @ logical.Repartition(numPartitions, shuffle, child) =>
if (shuffle) {
- ShuffleExchangeExec(RoundRobinPartitioning(numPartitions),
- planLater(child), REPARTITION_WITH_NUM) :: Nil
+ ShuffleExchangeExec(r.partitioning, planLater(child), REPARTITION_WITH_NUM) :: Nil
} else {
execution.CoalesceExec(numPartitions, planLater(child)) :: Nil
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index e851722..50da33a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -22,7 +22,7 @@ import org.apache.spark.sql.{execution, DataFrame, Row}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Range, Repartition, Union}
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Range, Repartition, RepartitionOperation, Union}
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, DisableAdaptiveExecution}
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
@@ -1239,6 +1239,21 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
}
}
}
+
+ test("SPARK-34919: Change partitioning to SinglePartition if partition number is 1") {
+ def checkSinglePartitioning(df: DataFrame): Unit = {
+ assert(
+ df.queryExecution.analyzed.collect {
+ case r: RepartitionOperation => r
+ }.size == 1)
+ assert(
+ collect(df.queryExecution.executedPlan) {
+ case s: ShuffleExchangeExec if s.outputPartitioning == SinglePartition => s
+ }.size == 1)
+ }
+ checkSinglePartitioning(sql("SELECT /*+ REPARTITION(1) */ * FROM VALUES(1),(2),(3) AS t(c)"))
+ checkSinglePartitioning(sql("SELECT /*+ REPARTITION(1, c) */ * FROM VALUES(1),(2),(3) AS t(c)"))
+ }
}
// Used for unit-testing EnsureRequirements
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
index 50f9806..bbf58c6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
@@ -168,7 +168,7 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils
// Exchange(nodeId = 5)
// LocalTableScan(nodeId = 6)
Seq(true, false).foreach { enableWholeStage =>
- val df = generateRandomBytesDF().repartition(1).groupBy('a).count()
+ val df = generateRandomBytesDF().repartition(2).groupBy('a).count()
val nodeIds = if (enableWholeStage) {
Set(4L, 1L)
} else {
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org