Posted to commits@spark.apache.org by we...@apache.org on 2024/01/22 07:31:12 UTC
(spark) branch master updated: Revert "[SPARK-46219][SQL] Unwrap cast in join predicates"
This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 498519ee6bb4 Revert "[SPARK-46219][SQL] Unwrap cast in join predicates"
498519ee6bb4 is described below
commit 498519ee6bb4b0295d1df005175e4cbcbcb051e3
Author: Wenchen Fan <we...@databricks.com>
AuthorDate: Mon Jan 22 15:30:31 2024 +0800
Revert "[SPARK-46219][SQL] Unwrap cast in join predicates"
This reverts commit 8235f1d56bf232bb713fe24ff6f2ffdaf49d2fcc.
---
.../org/apache/spark/sql/internal/SQLConf.scala | 10 --
.../bucketing/CoalesceBucketsInJoin.scala | 22 +++-
.../execution/exchange/EnsureRequirements.scala | 25 +----
...ractJoinWithUnwrappedCastInJoinPredicates.scala | 114 ---------------------
.../spark/sql/execution/joins/ShuffledJoin.scala | 14 +--
.../apache/spark/sql/execution/PlannerSuite.scala | 74 -------------
.../spark/sql/sources/BucketedReadSuite.scala | 65 ------------
7 files changed, 23 insertions(+), 301 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 3dd7cf884cbe..bc4734775c77 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -564,14 +564,6 @@ object SQLConf {
.booleanConf
.createWithDefault(false)
- val UNWRAP_CAST_IN_JOIN_CONDITION_ENABLED =
- buildConf("spark.sql.unwrapCastInJoinCondition.enabled")
- .doc("When true, unwrap the cast in the join condition to reduce shuffle if they are " +
- "integral types.")
- .version("4.0.0")
- .booleanConf
- .createWithDefault(true)
-
val MAX_SINGLE_PARTITION_BYTES = buildConf("spark.sql.maxSinglePartitionBytes")
.doc("The maximum number of bytes allowed for a single partition. Otherwise, The planner " +
"will introduce shuffle to improve parallelism.")
@@ -5126,8 +5118,6 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
def preferSortMergeJoin: Boolean = getConf(PREFER_SORTMERGEJOIN)
- def unwrapCastInJoinConditionEnabled: Boolean = getConf(UNWRAP_CAST_IN_JOIN_CONDITION_ENABLED)
-
def enableRadixSort: Boolean = getConf(RADIX_SORT_ENABLED)
def isParquetSchemaMergingEnabled: Boolean = getConf(PARQUET_SCHEMA_MERGING_ENABLED)
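The hunk above removes the feature flag itself. Before this revert, the
optimization could be toggled per session like any other boolean SQLConf
entry. A minimal sketch, assuming a local session (the session setup is
illustrative; the conf key is copied from the removed lines):

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("unwrap-cast-flag-demo")
      .getOrCreate()

    // Before the revert this key existed and defaulted to true; after the
    // revert, setting it has no effect because the conf entry is gone.
    spark.conf.set("spark.sql.unwrapCastInJoinCondition.enabled", "false")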
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/bucketing/CoalesceBucketsInJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/bucketing/CoalesceBucketsInJoin.scala
index ab0eaa044dea..d1464b4ac4ee 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/bucketing/CoalesceBucketsInJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/bucketing/CoalesceBucketsInJoin.scala
@@ -20,7 +20,9 @@ package org.apache.spark.sql.execution.bucketing
import scala.annotation.tailrec
import org.apache.spark.sql.catalyst.catalog.BucketSpec
+import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight}
+import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, PartitioningCollection}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.{FileSourceScanExec, FilterExec, ProjectExec, SparkPlan}
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, ShuffledHashJoinExec, ShuffledJoin, SortMergeJoinExec}
@@ -129,11 +131,27 @@ object ExtractJoinWithBuckets {
}
}
+ /**
+ * The join keys should match the expressions of the output partitioning. Note that
+ * their ordering does not matter because it is handled in `EnsureRequirements`.
+ */
+ private def satisfiesOutputPartitioning(
+ keys: Seq[Expression],
+ partitioning: Partitioning): Boolean = {
+ partitioning match {
+ case HashPartitioning(exprs, _) if exprs.length == keys.length =>
+ exprs.forall(e => keys.exists(_.semanticEquals(e)))
+ case PartitioningCollection(partitionings) =>
+ partitionings.exists(satisfiesOutputPartitioning(keys, _))
+ case _ => false
+ }
+ }
+
private def isApplicable(j: ShuffledJoin): Boolean = {
hasScanOperation(j.left) &&
hasScanOperation(j.right) &&
- j.satisfiesOutputPartitioning(j.leftKeys, j.left.outputPartitioning) &&
- j.satisfiesOutputPartitioning(j.rightKeys, j.right.outputPartitioning)
+ satisfiesOutputPartitioning(j.leftKeys, j.left.outputPartitioning) &&
+ satisfiesOutputPartitioning(j.rightKeys, j.right.outputPartitioning)
}
private def isDivisible(numBuckets1: Int, numBuckets2: Int): Boolean = {
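The helper moved into ExtractJoinWithBuckets above matches join keys against
HashPartitioning expressions via semanticEquals, so key order and cosmetic
differences are ignored. A rough standalone illustration of that check, using
only catalyst expression types (the attributes are hypothetical stand-ins for
real join keys):

    import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression}
    import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
    import org.apache.spark.sql.types.LongType

    val a = AttributeReference("a", LongType)()
    val b = AttributeReference("b", LongType)()

    val keys: Seq[Expression] = Seq(b, a)       // join keys, order reversed
    val part = HashPartitioning(Seq(a, b), 8)   // child output partitioning

    // The same test the helper performs: arities match, and every
    // partitioning expression semantically equals some join key.
    val satisfied = part.expressions.length == keys.length &&
      part.expressions.forall(e => keys.exists(_.semanticEquals(e)))
    // satisfied == true despite the different ordering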
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
index 38a8b5db2695..2a7c1206bb41 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
@@ -337,28 +337,6 @@ case class EnsureRequirements(
}
}
- /**
- * Unwrap the cast in join predicates to reduce shuffle.
- */
- private def unwrapCastInJoinPredicates(plan: SparkPlan): SparkPlan = {
- if (conf.unwrapCastInJoinConditionEnabled) {
- plan match {
- case ExtractJoinWithUnwrappedCastInJoinPredicates(join, joinKeys) =>
- val (leftKeys, rightKeys) = joinKeys.unzip
- join match {
- case j: SortMergeJoinExec =>
- j.copy(leftKeys = leftKeys, rightKeys = rightKeys)
- case j: ShuffledHashJoinExec =>
- j.copy(leftKeys = leftKeys, rightKeys = rightKeys)
- case other => other
- }
- case _ => plan
- }
- } else {
- plan
- }
- }
-
/**
* Checks whether two children, `left` and `right`, of a join operator have compatible
* `KeyGroupedPartitioning`, and can benefit from storage-partitioned join.
@@ -627,8 +605,7 @@ case class EnsureRequirements(
}
case operator: SparkPlan =>
- val unwrapped = unwrapCastInJoinPredicates(operator)
- val reordered = reorderJoinPredicates(unwrapped)
+ val reordered = reorderJoinPredicates(operator)
val newChildren = ensureDistributionAndOrdering(
Some(reordered),
reordered.children,
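The call site removed here relied on ExtractJoinWithUnwrappedCastInJoinPredicates
(deleted in full below) to pattern-match the joins that qualified for the rewrite.
That follows the usual Scala extractor idiom: an object whose unapply returns
Some(...) only when the rewrite applies. A toy version of the idiom, independent
of Spark internals (all names are made up for illustration):

    // Matches even integers and exposes their half.
    object Halved {
      def unapply(n: Int): Option[Int] =
        if (n % 2 == 0) Some(n / 2) else None
    }

    42 match {
      case Halved(h) => println(s"half is $h")   // prints: half is 21
      case other     => println(s"$other is odd")
    }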
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExtractJoinWithUnwrappedCastInJoinPredicates.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExtractJoinWithUnwrappedCastInJoinPredicates.scala
deleted file mode 100644
index 5d46fac90985..000000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExtractJoinWithUnwrappedCastInJoinPredicates.scala
+++ /dev/null
@@ -1,114 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.exchange
-
-import org.apache.spark.sql.catalyst.analysis.AnsiTypeCoercion.findWiderTypeForTwo
-import org.apache.spark.sql.catalyst.expressions.{Cast, EvalMode, Expression}
-import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, PartitioningCollection}
-import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.execution.joins.ShuffledJoin
-import org.apache.spark.sql.types.{DataType, DecimalType, IntegralType}
-
-/**
- * An extractor that extracts `SortMergeJoinExec` and `ShuffledHashJoin`,
- * where one sides can do bucketed read after unwrap cast in join keys.
- */
-object ExtractJoinWithUnwrappedCastInJoinPredicates {
- private def isIntegralType(dt: DataType): Boolean = dt match {
- case _: IntegralType => true
- case DecimalType.Fixed(_, 0) => true
- case _ => false
- }
-
- private def unwrapCastInJoinKeys(joinKeys: Seq[Expression]): Seq[Expression] = {
- joinKeys.map {
- case c: Cast if isIntegralType(c.child.dataType) => c.child
- case e => e
- }
- }
-
- // Casts the left or right side of join keys to the same data type.
- private def coerceJoinKeyType(
- unwrapLeftKeys: Seq[Expression],
- unwrapRightKeys: Seq[Expression],
- isAddCastToLeftSide: Boolean): Seq[(Expression, Expression)] = {
- unwrapLeftKeys.zip(unwrapRightKeys).map {
- case (l, r) if l.dataType != r.dataType =>
- // Use TRY mode to avoid runtime exception in ANSI mode or data issue in non-ANSI mode.
- if (isAddCastToLeftSide) {
- Cast(l, r.dataType, evalMode = EvalMode.TRY) -> r
- } else {
- l -> Cast(r, l.dataType, evalMode = EvalMode.TRY)
- }
- case (l, r) => l -> r
- }
- }
-
- private def unwrapCastInJoinPredicates(j: ShuffledJoin): Option[Seq[(Expression, Expression)]] = {
- val leftKeys = unwrapCastInJoinKeys(j.leftKeys)
- val rightKeys = unwrapCastInJoinKeys(j.rightKeys)
- // Make sure the casts are to a wider type.
- // For example, we do not support: cast(longCol as int) = cast(decimalCol as int).
- val isCastToWiderType = leftKeys.zip(rightKeys).zipWithIndex.forall {
- case ((e1, e2), i) =>
- findWiderTypeForTwo(e1.dataType, e2.dataType).contains(j.leftKeys(i).dataType)
- }
- if (isCastToWiderType) {
- val leftSatisfies = j.satisfiesOutputPartitioning(leftKeys, j.left.outputPartitioning)
- val rightSatisfies = j.satisfiesOutputPartitioning(rightKeys, j.right.outputPartitioning)
- if (leftSatisfies && rightSatisfies) {
- // If there are bucketed reads, the two sides' numbers of partitions may differ.
- // If the number of partitions on the left side is less than the number of partitions
- // on the right side, cast the left side keys to the data type of the right side keys.
- // Otherwise, cast the right side keys to the data type of the left side keys.
- Some(coerceJoinKeyType(leftKeys, rightKeys,
- j.left.outputPartitioning.numPartitions < j.right.outputPartitioning.numPartitions))
- } else if (leftSatisfies) {
- Some(coerceJoinKeyType(leftKeys, rightKeys, false))
- } else if (rightSatisfies) {
- Some(coerceJoinKeyType(leftKeys, rightKeys, true))
- } else {
- None
- }
- } else {
- None
- }
- }
-
- private def isTryToUnwrapCastInJoinPredicates(j: ShuffledJoin): Boolean = {
- (j.leftKeys.exists(_.isInstanceOf[Cast]) || j.rightKeys.exists(_.isInstanceOf[Cast])) &&
- !j.satisfiesOutputPartitioning(j.leftKeys, j.left.outputPartitioning) &&
- !j.satisfiesOutputPartitioning(j.rightKeys, j.right.outputPartitioning) &&
- j.children.map(_.outputPartitioning).exists { _ match {
- case _: PartitioningCollection => true
- case _: HashPartitioning => true
- case _ => false
- }}
- }
-
- def unapply(plan: SparkPlan): Option[(ShuffledJoin, Seq[(Expression, Expression)])] = {
- plan match {
- case j: ShuffledJoin if isTryToUnwrapCastInJoinPredicates(j) =>
- unwrapCastInJoinPredicates(j) match {
- case Some(joinKeys) => Some(j, joinKeys)
- case _ => None
- }
- case _ => None
- }
- }
-}
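One detail worth noting in the deleted extractor: coerceJoinKeyType re-adds
casts in TRY eval mode, so narrowing a key (e.g. bigint to int) turns
non-representable values into NULL (which simply fails to match in an
equi-join) instead of throwing under ANSI mode. The same semantics is visible
from plain SQL via TRY_CAST; a quick sketch, assuming an active SparkSession:

    spark.sql("SET spark.sql.ansi.enabled=true")

    // A plain CAST overflow throws at runtime under ANSI mode:
    // spark.sql(s"SELECT CAST(${Long.MaxValue}L AS INT)").show()   // ArithmeticException
    // TRY_CAST yields NULL instead of an error:
    spark.sql(s"SELECT TRY_CAST(${Long.MaxValue}L AS INT)").show()  // NULL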
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledJoin.scala
index 9591218b099b..7c4628c8576c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledJoin.scala
@@ -17,9 +17,9 @@
package org.apache.spark.sql.execution.joins
-import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
+import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, FullOuter, InnerLike, LeftExistence, LeftOuter, RightOuter}
-import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, HashPartitioning, Partitioning, PartitioningCollection, UnknownPartitioning, UnspecifiedDistribution}
+import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, Partitioning, PartitioningCollection, UnknownPartitioning, UnspecifiedDistribution}
/**
* Holds common logic for join operators by shuffling two child relations
@@ -56,16 +56,6 @@ trait ShuffledJoin extends JoinCodegenSupport {
s"ShuffledJoin should not take $x as the JoinType")
}
- def satisfiesOutputPartitioning(keys: Seq[Expression], partitioning: Partitioning): Boolean = {
- partitioning match {
- case HashPartitioning(exprs, _) if exprs.length == keys.length =>
- exprs.forall(e => keys.exists(_.semanticEquals(e)))
- case PartitioningCollection(partitionings) =>
- partitionings.exists(satisfiesOutputPartitioning(keys, _))
- case _ => false
- }
- }
-
override def output: Seq[Attribute] = {
joinType match {
case _: InnerLike =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 71a86d599c0c..be532ed9097c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -1373,80 +1373,6 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
}
}
- test("SPARK-46219: Unwrap cast in join condition") {
- val intExpr = Literal(1)
- val longExpr = Literal(1L)
- val smjExec = SortMergeJoinExec(
- leftKeys = Cast(intExpr, LongType) :: Nil,
- rightKeys = longExpr :: Nil,
- joinType = Inner,
- condition = None,
- left = DummySparkPlan(outputPartitioning = HashPartitioning(intExpr :: Nil, 5)),
- right = DummySparkPlan())
-
- Seq(true, false).foreach { unwrapCast =>
- withSQLConf(SQLConf.UNWRAP_CAST_IN_JOIN_CONDITION_ENABLED.key -> unwrapCast.toString) {
- val outputPlan = EnsureRequirements.apply(smjExec)
- if (unwrapCast) {
- outputPlan match {
- case SortMergeJoinExec(leftKeys, rightKeys, _, _,
- SortExec(_, _, _: DummySparkPlan, _),
- SortExec(_, _, ShuffleExchangeExec(_, _: DummySparkPlan, _, _), _), _) =>
- assert(leftKeys === Seq(intExpr))
- assert(rightKeys === Seq(Cast(longExpr, IntegerType, evalMode = EvalMode.TRY)))
- case _ => fail(outputPlan.toString)
- }
- } else {
- outputPlan match {
- case SortMergeJoinExec(leftKeys, rightKeys, _, _,
- SortExec(_, _, ShuffleExchangeExec(_, _: DummySparkPlan, _, _), _),
- SortExec(_, _, ShuffleExchangeExec(_, _: DummySparkPlan, _, _), _), _) =>
- assert(leftKeys === smjExec.leftKeys)
- assert(rightKeys === smjExec.rightKeys)
- case _ => fail(outputPlan.toString)
- }
- }
- }
- }
- }
-
- test("SPARK-46219: Number of partitions may be inconsistent") {
- val longExpr = Literal(1L)
- val decimalExpr = Literal(Decimal(1L, 18, 0))
- val smjExec = SortMergeJoinExec(
- leftKeys = Cast(longExpr, DecimalType(20, 0)) :: Nil,
- rightKeys = Cast(decimalExpr, DecimalType(20, 0)) :: Nil,
- joinType = Inner,
- condition = None,
- left = DummySparkPlan(outputPartitioning = HashPartitioning(longExpr :: Nil, 10)),
- right = DummySparkPlan(outputPartitioning = HashPartitioning(decimalExpr :: Nil, 5)))
-
- Seq(true, false).foreach { unwrapCast =>
- withSQLConf(SQLConf.UNWRAP_CAST_IN_JOIN_CONDITION_ENABLED.key -> unwrapCast.toString) {
- val outputPlan = EnsureRequirements.apply(smjExec)
- if (unwrapCast) {
- outputPlan match {
- case SortMergeJoinExec(leftKeys, rightKeys, _, _,
- SortExec(_, _, _: DummySparkPlan, _),
- SortExec(_, _, ShuffleExchangeExec(_, _: DummySparkPlan, _, _), _), _) =>
- assert(leftKeys === Seq(longExpr))
- assert(rightKeys === Seq(Cast(decimalExpr, LongType, evalMode = EvalMode.TRY)))
- case _ => fail(outputPlan.toString)
- }
- } else {
- outputPlan match {
- case SortMergeJoinExec(leftKeys, rightKeys, _, _,
- SortExec(_, _, ShuffleExchangeExec(_, _: DummySparkPlan, _, _), _),
- SortExec(_, _, ShuffleExchangeExec(_, _: DummySparkPlan, _, _), _), _) =>
- assert(leftKeys === smjExec.leftKeys)
- assert(rightKeys === smjExec.rightKeys)
- case _ => fail(outputPlan.toString)
- }
- }
- }
- }
- }
-
test("Limit and offset should not drop LocalLimitExec operator") {
val df = sql("SELECT * FROM (SELECT * FROM RANGE(100) LIMIT 25 OFFSET 3) WHERE id > 10")
val planned = df.queryExecution.sparkPlan
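The two deleted tests above pattern-match the physical tree to check where
EnsureRequirements placed the exchanges. Outside the test harness, the same
kind of inspection works on any DataFrame's executed plan; a small sketch
(the query is illustrative):

    import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec

    // Disable AQE so the final plan is visible directly on executedPlan.
    spark.conf.set("spark.sql.adaptive.enabled", "false")

    val joined = spark.range(100).join(spark.range(100), "id")
    val shuffles = joined.queryExecution.executedPlan.collect {
      case s: ShuffleExchangeExec => s
    }
    // Each ShuffleExchangeExec is one shuffle the planner inserted.
    println(s"shuffles: ${shuffles.length}")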
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
index 52a316e63a81..3573bafe482c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
@@ -34,7 +34,6 @@ import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION
import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtils}
-import org.apache.spark.sql.types.{DataType, DecimalType, IntegerType, LongType}
import org.apache.spark.tags.SlowSQLTest
import org.apache.spark.util.collection.BitSet
@@ -1089,68 +1088,4 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti
}
}
}
-
- test("SPARK-46219: Unwrap cast in join condition") {
- def verify(
- query: String,
- expectedNumShuffles: Int,
- numPartitions: Option[Int] = None,
- partitioningKeyTypes: Option[Seq[DataType]] = None): Unit = {
- Seq(true, false).foreach { ansiEnabled =>
- Seq(true, false).foreach { aqeEnabled =>
- withSQLConf(
- SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString,
- SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> aqeEnabled.toString) {
- val df = sql(query)
- val plan = df.queryExecution.executedPlan
- val shuffles = collect(plan) {
- case s: ShuffleExchangeExec => s
- }
- assert(shuffles.size === expectedNumShuffles)
- if (shuffles.size == 1) {
- val outputPartitioning = shuffles.head.outputPartitioning
- assert(outputPartitioning.numPartitions === numPartitions.get)
- assert(outputPartitioning.asInstanceOf[HashPartitioning]
- .expressions.map(_.dataType) === partitioningKeyTypes.get)
-
- collect(plan) { case s: SortMergeJoinExec => s }.flatMap(_.expressions).foreach {
- case c: Cast => assert(c.evalMode === EvalMode.TRY) // The EvalMode should be try.
- case _ =>
- }
-
- checkAnswer(df, Row(1, 1) :: Nil)
- }
- }
- }
- }
- }
-
- withTable("t1", "t2", "t3", "t4") {
- sql(
- s"""
- |CREATE TABLE t1 USING parquet CLUSTERED BY (i) INTO 8 buckets AS
- |SELECT CAST(v AS int) AS i FROM values(1), (${Int.MaxValue}) AS data(v)
- |""".stripMargin)
- sql(
- s"""
- |CREATE TABLE t2 USING parquet CLUSTERED BY (i) INTO 8 buckets AS
- |SELECT CAST(v AS bigint) AS i FROM values(1), (${Long.MaxValue}) AS data(v)
- |""".stripMargin)
- sql(
- s"""
- |CREATE TABLE t3 USING parquet CLUSTERED BY (i) INTO 4 buckets AS
- |SELECT CAST(v AS decimal(18, 0)) AS i FROM values(1), (${"9" * 18}) AS data(v)
- |""".stripMargin)
- spark.table("t2").write.saveAsTable("t4")
-
- withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0",
- SQLConf.UNWRAP_CAST_IN_JOIN_CONDITION_ENABLED.key -> "true") {
- verify("SELECT * FROM t2 JOIN t3 ON t2.i = t3.i", 1, Some(8), Some(Seq(LongType)))
- verify("SELECT * FROM t1 JOIN t4 ON t1.i = t4.i", 1, Some(8), Some(Seq(IntegerType)))
- verify("SELECT * FROM t3 JOIN t4 ON t3.i = t4.i", 1, Some(4), Some(Seq(DecimalType(18, 0))))
- // Do not unwrap the cast if it was added by the user.
- verify("SELECT * FROM t2 JOIN t3 ON cast(t2.i as int) = cast(t3.i as int)", 2)
- }
- }
- }
}
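The deleted suite drives the optimization through bucketed tables, where one
side's output partitioning can satisfy the join without a shuffle. For
reference, a DataFrame-API counterpart of the CLUSTERED BY ... INTO 8 BUCKETS
DDL used in the deleted test (the table name is hypothetical):

    spark.range(10).selectExpr("CAST(id AS BIGINT) AS i")
      .write
      .bucketBy(8, "i")   // 8 buckets on column i, like t2 in the deleted test
      .sortBy("i")
      .saveAsTable("t2_demo")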
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org