You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by td...@apache.org on 2020/01/31 04:22:19 UTC

[spark] branch master updated: [SPARK-29438][SS] Use partition ID of StateStoreAwareZipPartitionsRDD for determining partition ID of state store in stream-stream join

This is an automated email from the ASF dual-hosted git repository.

tdas pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new cbb714f  [SPARK-29438][SS] Use partition ID of StateStoreAwareZipPartitionsRDD for determining partition ID of state store in stream-stream join
cbb714f is described below

commit cbb714f67e96dfd9678c60586af48febfba878ca
Author: Jungtaek Lim (HeartSaVioR) <ka...@gmail.com>
AuthorDate: Thu Jan 30 20:21:43 2020 -0800

    [SPARK-29438][SS] Use partition ID of StateStoreAwareZipPartitionsRDD for determining partition ID of state store in stream-stream join
    
    ### What changes were proposed in this pull request?
    
    Credit to uncleGen for discovering the problem and providing a simple reproducer as a UT. The new UT in this patch is borrowed from #26156, and I'm retaining a commit from #26156 (except the part that is unnecessary for this patch) to properly give credit.
    
    This patch fixes the issue that the partition ID could be mis-assigned when the query contains UNION and the stream-stream join is placed on the right side. We assume the range of partition IDs to be `(0 ~ number of shuffle partitions - 1)` for stateful operators, but when a stream-stream join is used on the right side of UNION, the range of the task's partition IDs becomes `(number of partitions in left side ~ number of partitions in left side + number of shuffle partitions - 1)`, which `number of part [...]
    
    The root cause of the bug is that stream-stream join picks the partition ID from TaskContext, which is not the same as the partition ID from the source if union is being used. Fortunately, we can pick the right partition ID from the source in StateStoreAwareZipPartitionsRDD - this patch leverages that partition ID.
    
    ### Why are the changes needed?
    
    This patch restores the broken assumption about the partition ID range of stateful operators, and fixes the issue reported in JIRA issue SPARK-29438.
    
    ### Does this PR introduce any user-facing change?
    
    Yes, if their query uses UNION and a stream-stream join is placed on the right side. They may encounter a problem reading state from the checkpoint and may need to discard the checkpoint to continue.
    
    ### How was this patch tested?
    
    Added a UT which fails on the current master branch and passes with this patch.
    
    Closes #26162 from HeartSaVioR/SPARK-29438.
    
    Lead-authored-by: Jungtaek Lim (HeartSaVioR) <ka...@gmail.com>
    Co-authored-by: uncleGen <hu...@gmail.com>
    Signed-off-by: Tathagata Das <ta...@gmail.com>
---
 .../streaming/StreamingSymmetricHashJoinExec.scala | 11 ++--
 .../StreamingSymmetricHashJoinHelper.scala         | 36 +++++++++----
 .../spark/sql/execution/streaming/memory.scala     | 32 ++++++++++--
 .../state/SymmetricHashJoinStateManager.scala      |  5 +-
 .../state/SymmetricHashJoinStateManagerSuite.scala |  2 +-
 .../apache/spark/sql/streaming/StreamTest.scala    |  6 ++-
 .../spark/sql/streaming/StreamingJoinSuite.scala   | 61 +++++++++++++++++++++-
 7 files changed, 129 insertions(+), 24 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala
index 3c45f22..198e17d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala
@@ -215,6 +215,7 @@ case class StreamingSymmetricHashJoinExec(
   }
 
   private def processPartitions(
+      partitionId: Int,
       leftInputIter: Iterator[InternalRow],
       rightInputIter: Iterator[InternalRow]): Iterator[InternalRow] = {
     if (stateInfo.isEmpty) {
@@ -238,10 +239,10 @@ case class StreamingSymmetricHashJoinExec(
       Predicate.create(condition.bothSides.getOrElse(Literal(true)), inputSchema).eval _
     val leftSideJoiner = new OneSideHashJoiner(
       LeftSide, left.output, leftKeys, leftInputIter,
-      condition.leftSideOnly, postJoinFilter, stateWatermarkPredicates.left)
+      condition.leftSideOnly, postJoinFilter, stateWatermarkPredicates.left, partitionId)
     val rightSideJoiner = new OneSideHashJoiner(
       RightSide, right.output, rightKeys, rightInputIter,
-      condition.rightSideOnly, postJoinFilter, stateWatermarkPredicates.right)
+      condition.rightSideOnly, postJoinFilter, stateWatermarkPredicates.right, partitionId)
 
     //  Join one side input using the other side's buffered/state rows. Here is how it is done.
     //
@@ -406,6 +407,7 @@ case class StreamingSymmetricHashJoinExec(
    * @param stateWatermarkPredicate The state watermark predicate. See
    *                                [[StreamingSymmetricHashJoinExec]] for further description of
    *                                state watermarks.
+   * @param partitionId A partition ID of source RDD.
    */
   private class OneSideHashJoiner(
       joinSide: JoinSide,
@@ -414,7 +416,8 @@ case class StreamingSymmetricHashJoinExec(
       inputIter: Iterator[InternalRow],
       preJoinFilterExpr: Option[Expression],
       postJoinFilter: (InternalRow) => Boolean,
-      stateWatermarkPredicate: Option[JoinStateWatermarkPredicate]) {
+      stateWatermarkPredicate: Option[JoinStateWatermarkPredicate],
+      partitionId: Int) {
 
     // Filter the joined rows based on the given condition.
     val preJoinFilter =
@@ -422,7 +425,7 @@ case class StreamingSymmetricHashJoinExec(
 
     private val joinStateManager = new SymmetricHashJoinStateManager(
       joinSide, inputAttributes, joinKeys, stateInfo, storeConf, hadoopConfBcast.value.value,
-      stateFormatVersion)
+      partitionId, stateFormatVersion)
     private[this] val keyGenerator = UnsafeProjection.create(joinKeys, inputAttributes)
 
     private[this] val stateKeyWatermarkPredicateFunc = stateWatermarkPredicate match {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala
index 2d4c3c10..cdd3a85 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala
@@ -18,11 +18,10 @@
 package org.apache.spark.sql.execution.streaming
 
 import scala.reflect.ClassTag
-import scala.util.control.NonFatal
 
-import org.apache.spark.{Partition, SparkContext}
+import org.apache.spark.{Partition, SparkContext, TaskContext}
 import org.apache.spark.internal.Logging
-import org.apache.spark.rdd.{RDD, ZippedPartitionsRDD2}
+import org.apache.spark.rdd.{RDD, ZippedPartitionsBaseRDD, ZippedPartitionsPartition, ZippedPartitionsRDD2}
 import org.apache.spark.sql.catalyst.analysis.StreamingJoinHelper
 import org.apache.spark.sql.catalyst.expressions.{Add, And, Attribute, AttributeReference, AttributeSet, BoundReference, Cast, CheckOverflow, Expression, ExpressionSet, GreaterThan, GreaterThanOrEqual, LessThan, LessThanOrEqual, Literal, Multiply, NamedExpression, PreciseTimestampConversion, PredicateHelper, Subtract, TimeAdd, TimeSub, UnaryMinus}
 import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark._
@@ -203,17 +202,18 @@ object StreamingSymmetricHashJoinHelper extends Logging {
   /**
    * A custom RDD that allows partitions to be "zipped" together, while ensuring the tasks'
    * preferred location is based on which executors have the required join state stores already
-   * loaded. This is class is a modified version of [[ZippedPartitionsRDD2]].
+   * loaded. This class is a variant of [[ZippedPartitionsRDD2]] which only changes signature
+   * of `f`.
    */
   class StateStoreAwareZipPartitionsRDD[A: ClassTag, B: ClassTag, V: ClassTag](
       sc: SparkContext,
-      f: (Iterator[A], Iterator[B]) => Iterator[V],
-      rdd1: RDD[A],
-      rdd2: RDD[B],
+      var f: (Int, Iterator[A], Iterator[B]) => Iterator[V],
+      var rdd1: RDD[A],
+      var rdd2: RDD[B],
       stateInfo: StatefulOperatorStateInfo,
       stateStoreNames: Seq[String],
       @transient private val storeCoordinator: Option[StateStoreCoordinatorRef])
-      extends ZippedPartitionsRDD2[A, B, V](sc, f, rdd1, rdd2) {
+      extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2)) {
 
     /**
      * Set the preferred location of each partition using the executor that has the related
@@ -225,6 +225,24 @@ object StreamingSymmetricHashJoinHelper extends Logging {
         storeCoordinator.flatMap(_.getLocation(stateStoreProviderId))
       }.distinct
     }
+
+    override def compute(s: Partition, context: TaskContext): Iterator[V] = {
+      val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions
+      if (partitions(0).index != partitions(1).index) {
+        throw new IllegalStateException(s"Partition ID should be same in both side: " +
+          s"left ${partitions(0).index} , right ${partitions(1).index}")
+      }
+
+      val partitionId = partitions(0).index
+      f(partitionId, rdd1.iterator(partitions(0), context), rdd2.iterator(partitions(1), context))
+    }
+
+    override def clearDependencies(): Unit = {
+      super.clearDependencies()
+      rdd1 = null
+      rdd2 = null
+      f = null
+    }
   }
 
   implicit class StateStoreAwareZipPartitionsHelper[T: ClassTag](dataRDD: RDD[T]) {
@@ -239,7 +257,7 @@ object StreamingSymmetricHashJoinHelper extends Logging {
         stateInfo: StatefulOperatorStateInfo,
         storeNames: Seq[String],
         storeCoordinator: StateStoreCoordinatorRef
-      )(f: (Iterator[T], Iterator[U]) => Iterator[V]): RDD[V] = {
+      )(f: (Int, Iterator[T], Iterator[U]) => Iterator[V]): RDD[V] = {
       new StateStoreAwareZipPartitionsRDD(
         dataRDD.sparkContext, f, dataRDD, dataRDD2, stateInfo, storeNames, Some(storeCoordinator))
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
index 911a526..34ed0da 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
@@ -44,6 +44,9 @@ object MemoryStream {
 
   def apply[A : Encoder](implicit sqlContext: SQLContext): MemoryStream[A] =
     new MemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext)
+
+  def apply[A : Encoder](numPartitions: Int)(implicit sqlContext: SQLContext): MemoryStream[A] =
+    new MemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext, Some(numPartitions))
 }
 
 /**
@@ -136,9 +139,14 @@ class MemoryStreamScanBuilder(stream: MemoryStreamBase[_]) extends ScanBuilder w
  * A [[Source]] that produces value stored in memory as they are added by the user.  This [[Source]]
  * is intended for use in unit tests as it can only replay data when the object is still
  * available.
+ *
+ * If numPartitions is provided, the rows will be redistributed to the given number of partitions.
  */
-case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
-    extends MemoryStreamBase[A](sqlContext) with MicroBatchStream with Logging {
+case class MemoryStream[A : Encoder](
+    id: Int,
+    sqlContext: SQLContext,
+    numPartitions: Option[Int] = None)
+  extends MemoryStreamBase[A](sqlContext) with MicroBatchStream with Logging {
 
   protected val output = logicalPlan.output
 
@@ -206,9 +214,23 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
 
       logDebug(generateDebugString(newBlocks.flatten, startOrdinal, endOrdinal))
 
-      newBlocks.map { block =>
-        new MemoryStreamInputPartition(block)
-      }.toArray
+      numPartitions match {
+        case Some(numParts) =>
+          // When the number of partition is provided, we redistribute the rows into
+          // the given number of partition, via round-robin manner.
+          val inputRows = newBlocks.flatten.toArray
+          (0 until numParts).map { newPartIdx =>
+            val records = inputRows.zipWithIndex.filter { case (_, idx) =>
+              idx % numParts == newPartIdx
+            }.map(_._1)
+            new MemoryStreamInputPartition(records)
+          }.toArray
+
+        case _ =>
+          newBlocks.map { block =>
+            new MemoryStreamInputPartition(block)
+          }.toArray
+      }
     }
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala
index c107137..1a0a43c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala
@@ -44,6 +44,7 @@ import org.apache.spark.util.NextIterator
  * @param stateInfo             Information about how to retrieve the correct version of state
  * @param storeConf             Configuration for the state store.
  * @param hadoopConf            Hadoop configuration for reading state data from storage
+ * @param partitionId           A partition ID of source RDD.
  * @param stateFormatVersion    The version of format for state.
  *
  * Internally, the key -> multiple values is stored in two [[StateStore]]s.
@@ -72,8 +73,8 @@ class SymmetricHashJoinStateManager(
     stateInfo: Option[StatefulOperatorStateInfo],
     storeConf: StateStoreConf,
     hadoopConf: Configuration,
+    partitionId: Int,
     stateFormatVersion: Int) extends Logging {
-
   import SymmetricHashJoinStateManager._
 
   /*
@@ -356,7 +357,7 @@ class SymmetricHashJoinStateManager(
     /** Get the StateStore with the given schema */
     protected def getStateStore(keySchema: StructType, valueSchema: StructType): StateStore = {
       val storeProviderId = StateStoreProviderId(
-        stateInfo.get, TaskContext.getPartitionId(), getStateStoreName(joinSide, stateStoreType))
+        stateInfo.get, partitionId, getStateStoreName(joinSide, stateStoreType))
       val store = StateStore.get(
         storeProviderId, keySchema, valueSchema, None,
         stateInfo.get.storeVersion, storeConf, hadoopConf)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala
index b40f8df..ce1eabe 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala
@@ -170,7 +170,7 @@ class SymmetricHashJoinStateManagerSuite extends StreamTest with BeforeAndAfter
       val stateInfo = StatefulOperatorStateInfo(file.getAbsolutePath, UUID.randomUUID, 0, 0, 5)
       val manager = new SymmetricHashJoinStateManager(
         LeftSide, inputValueAttribs, joinKeyExprs, Some(stateInfo), storeConf, new Configuration,
-        stateFormatVersion)
+        partitionId = 0, stateFormatVersion)
       try {
         f(manager)
       } finally {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
index e54a537..6d5ad87 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
@@ -112,7 +112,11 @@ trait StreamTest extends QueryTest with SharedSparkSession with TimeLimits with
   object MultiAddData {
     def apply[A]
       (source1: MemoryStream[A], data1: A*)(source2: MemoryStream[A], data2: A*): StreamAction = {
-      val actions = Seq(AddDataMemory(source1, data1), AddDataMemory(source2, data2))
+      apply((source1, data1), (source2, data2))
+    }
+
+    def apply[A](inputs: (MemoryStream[A], Seq[A])*): StreamAction = {
+      val actions = inputs.map { case (source, data) => AddDataMemory(source, data) }
       StreamProgressLockedActions(actions, desc = actions.mkString("[ ", " | ", " ]"))
     }
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
index ae6a4ec..3f218c9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
@@ -25,6 +25,7 @@ import scala.util.Random
 import org.apache.commons.io.FileUtils
 import org.scalatest.BeforeAndAfter
 
+import org.apache.spark.SparkContext
 import org.apache.spark.scheduler.ExecutorCacheTaskLocation
 import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SparkSession}
 import org.apache.spark.sql.catalyst.analysis.StreamingJoinHelper
@@ -376,7 +377,7 @@ class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with
       val rdd1 = spark.sparkContext.makeRDD(1 to 10, numPartitions)
       val rdd2 = spark.sparkContext.makeRDD((1 to 10).map(_.toString), numPartitions)
       val rdd = rdd1.stateStoreAwareZipPartitions(rdd2, stateInfo, storeNames, coordinatorRef) {
-        (left, right) => left.zip(right)
+        (_, left, right) => left.zip(right)
       }
       require(rdd.partitions.length === numPartitions)
       for (partIndex <- 0 until numPartitions) {
@@ -933,5 +934,61 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with
     assert(e.getMessage.toLowerCase(Locale.ROOT)
       .contains("the query is using stream-stream outer join with state format version 1"))
   }
-}
 
+  test("SPARK-29438: ensure UNION doesn't lead stream-stream join to use shifted partition IDs") {
+    def constructUnionDf(desiredPartitionsForInput1: Int)
+        : (MemoryStream[Int], MemoryStream[Int], MemoryStream[Int], DataFrame) = {
+      val input1 = MemoryStream[Int](desiredPartitionsForInput1)
+      val df1 = input1.toDF
+        .select(
+          'value as "key",
+          'value as "leftValue",
+          'value as "rightValue")
+      val (input2, df2) = setupStream("left", 2)
+      val (input3, df3) = setupStream("right", 3)
+
+      val joined = df2
+        .join(df3,
+          df2("key") === df3("key") && df2("leftTime") === df3("rightTime"),
+          "inner")
+        .select(df2("key"), 'leftValue, 'rightValue)
+
+      (input1, input2, input3, df1.union(joined))
+    }
+
+    withTempDir { tempDir =>
+      val (input1, input2, input3, unionDf) = constructUnionDf(2)
+
+      testStream(unionDf)(
+        StartStream(checkpointLocation = tempDir.getAbsolutePath),
+        MultiAddData(
+          (input1, Seq(11, 12, 13)),
+          (input2, Seq(11, 12, 13, 14, 15)),
+          (input3, Seq(13, 14, 15, 16, 17))),
+        CheckNewAnswer(Row(11, 11, 11), Row(12, 12, 12), Row(13, 13, 13), Row(13, 26, 39),
+          Row(14, 28, 42), Row(15, 30, 45)),
+        StopStream
+      )
+
+      // We're restoring the query with different number of partitions in left side of UNION,
+      // which leads right side of union to have mismatched partition IDs if it relies on
+      // TaskContext.partitionId(). SPARK-29438 fixes this issue to not rely on it.
+
+      val (newInput1, newInput2, newInput3, newUnionDf) = constructUnionDf(3)
+
+      newInput1.addData(11, 12, 13)
+      newInput2.addData(11, 12, 13, 14, 15)
+      newInput3.addData(13, 14, 15, 16, 17)
+
+      testStream(newUnionDf)(
+        StartStream(checkpointLocation = tempDir.getAbsolutePath),
+        MultiAddData(
+          (newInput1, Seq(21, 22, 23)),
+          (newInput2, Seq(21, 22, 23, 24, 25)),
+          (newInput3, Seq(23, 24, 25, 26, 27))),
+        CheckNewAnswer(Row(21, 21, 21), Row(22, 22, 22), Row(23, 23, 23), Row(23, 46, 69),
+          Row(24, 48, 72), Row(25, 50, 75))
+      )
+    }
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org