You are viewing a plain text version of this content. The canonical link for it is here.
Posted to reviews@spark.apache.org by GitBox <gi...@apache.org> on 2021/04/26 04:59:55 UTC

[GitHub] [spark] maropu commented on a change in pull request #32210: [SPARK-32634][SQL] Introduce sort-based fallback for shuffled hash join (non-code-gen path)

maropu commented on a change in pull request #32210:
URL: https://github.com/apache/spark/pull/32210#discussion_r619912231



##########
File path: sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
##########
@@ -1394,4 +1394,32 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
       checkAnswer(fullJoinDF, Row(100))
     }
   }
+
+  test("SPARK-32634: Sort-based fallback for shuffled hash join") {
+    val df1 = spark.range(300).map(_.toString).select($"value".as("k1"))
+    val df2 = spark.range(100).map(_.toString).select($"value".as("k2"))
+
+    val smjDF = df1.join(df2.hint("SHUFFLE_MERGE"), $"k1" === $"k2")
+    assert(collect(smjDF.queryExecution.executedPlan) {
+      case _: SortMergeJoinExec => true }.size === 1)
+    val smjResult = smjDF.collect()
+
+    Seq(
+      // All tasks fall back
+      0,
+      // Some tasks fall back
+      10,
+      // No task falls back
+      1000
+    ).foreach(fallbackStartsAt =>
+      withSQLConf(SQLConf.SHUFFLEDHASHJOIN_FALLBACK_ENABLED.key -> "true",
+        "spark.sql.ShuffledHashJoin.testFallbackStartsAt" -> fallbackStartsAt.toString) {
+        val shjDF = df1.join(df2.hint("SHUFFLE_HASH"), $"k1" === $"k2")
+        assert(collect(shjDF.queryExecution.executedPlan) {
+          case _: ShuffledHashJoinExec => true }.size === 1)
+        // Same result between shuffled hash join and sort merge join
+        checkAnswer(shjDF, smjResult)

Review comment:
       Is this a test for the non-codegen path?

##########
File path: sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
##########
@@ -81,11 +83,22 @@ case class ShuffledHashJoinExec(
 
   protected override def doExecute(): RDD[InternalRow] = {
     val numOutputRows = longMetric("numOutputRows")
+    val spillThreshold = getSpillThreshold
+    val inMemoryThreshold = getInMemoryThreshold
+    val streamSortPlan = getStreamSortPlan
+    val buildSortPlan = getBuildSortPlan
+    val fallbackSMJPlan = SortMergeJoinExec(leftKeys, rightKeys, joinType, condition, left, right)
+
     streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter, buildIter) =>
-      val hashed = buildHashedRelation(buildIter)
-      joinType match {
-        case FullOuter => fullOuterJoin(streamIter, hashed, numOutputRows)
-        case _ => join(streamIter, hashed, numOutputRows)
+      buildHashedRelation(buildIter) match {
+        case r: UnfinishedUnsafeHashedRelation =>

Review comment:
       How about adding a new SQL metric for the number of fallbacks and then checking it in the test?

##########
File path: sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
##########
@@ -475,18 +501,89 @@ private[joins] object UnsafeHashedRelation {
           key.getBaseObject, key.getBaseOffset, key.getSizeInBytes,
           row.getBaseObject, row.getBaseOffset, row.getSizeInBytes)
         if (!success) {
-          binaryMap.free()
-          throw QueryExecutionErrors.cannotAcquireMemoryToBuildUnsafeHashedRelationError()
+          if (allowsFallbackWithNoMemory) {
+            return new UnfinishedUnsafeHashedRelation(numFields, binaryMap, row)
+          } else {
+            // Clean up map and throw exception
+            binaryMap.free()
+            throw QueryExecutionErrors.cannotAcquireMemoryToBuildUnsafeHashedRelationError()
+          }
         }
       } else if (isNullAware) {
         return HashedRelationWithAllNullKeys
       }
+      i += 1
     }
 
     new UnsafeHashedRelation(key.size, numFields, binaryMap)
   }
 }
 
+/**
+ * An unfinished version of [[UnsafeHashedRelation]].
+ * This is intended to be used in the sort-based fallback of [[ShuffledHashJoinExec]],
+ * when there is not enough memory to build [[UnsafeHashedRelation]].
+ *
+ * @param numFields Number of fields in each row.
+ * @param binaryMap Backed [[BytesToBytesMap]] to hold keys and rows.
+ * @param pendingRow The row which cannot be added to `binaryMap` due to memory limit.
+ */
+private[joins] class UnfinishedUnsafeHashedRelation(

Review comment:
       Needs tests for this new class.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org