You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by td...@apache.org on 2017/09/21 22:39:12 UTC

[1/2] spark git commit: [SPARK-22053][SS] Stream-stream inner join in Append Mode

Repository: spark
Updated Branches:
  refs/heads/master a8a5cd24e -> f32a84250


http://git-wip-us.apache.org/repos/asf/spark/blob/f32a8425/sql/core/src/test/scala/org/apache/spark/sql/streaming/StateStoreMetricsTest.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StateStoreMetricsTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StateStoreMetricsTest.scala
index 894786c..368c460 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StateStoreMetricsTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StateStoreMetricsTest.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.streaming
 trait StateStoreMetricsTest extends StreamTest {
 
   def assertNumStateRows(total: Seq[Long], updated: Seq[Long]): AssertOnQuery =
-    AssertOnQuery { q =>
+    AssertOnQuery(s"Check total state rows = $total, updated state rows = $updated") { q =>
       val progressWithData = q.recentProgress.filter(_.numInputRows > 0).lastOption.get
       assert(
         progressWithData.stateOperators.map(_.numRowsTotal) === total,

http://git-wip-us.apache.org/repos/asf/spark/blob/f32a8425/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
new file mode 100644
index 0000000..533e116
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
@@ -0,0 +1,472 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.streaming
+
+import java.util.UUID
+
+import scala.util.Random
+
+import org.scalatest.BeforeAndAfter
+
+import org.apache.spark.scheduler.ExecutorCacheTaskLocation
+import org.apache.spark.sql.{AnalysisException, SparkSession}
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet}
+import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, Filter}
+import org.apache.spark.sql.execution.LogicalRDD
+import org.apache.spark.sql.execution.streaming.{MemoryStream, StatefulOperatorStateInfo, StreamingSymmetricHashJoinHelper}
+import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreProviderId}
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types._
+import org.apache.spark.util.Utils
+
+/** Tests for stream-stream inner joins and [[StreamingSymmetricHashJoinHelper]] utilities. */
+class StreamingJoinSuite extends StreamTest with StateStoreMetricsTest with BeforeAndAfter {
+
+  before {
+    SparkSession.setActiveSession(spark)  // make 'spark' the active session for each test
+    spark.streams.stateStoreCoordinator   // initialize the lazy coordinator
+  }
+
+  after {
+    StateStore.stop()
+  }
+
+  import testImplicits._
+  test("stream stream inner join on non-time column") {
+    val input1 = MemoryStream[Int]
+    val input2 = MemoryStream[Int]
+
+    val df1 = input1.toDF.select('value as "key", ('value * 2) as "leftValue")
+    val df2 = input2.toDF.select('value as "key", ('value * 3) as "rightValue")
+    val joined = df1.join(df2, "key")
+
+    testStream(joined)(
+      AddData(input1, 1),
+      CheckAnswer(),
+      AddData(input2, 1, 10),       // 1 arrived on input1 first, then input2, should join
+      CheckLastBatch((1, 2, 3)),
+      AddData(input1, 10),          // 10 arrived on input2 first, then input1, should join
+      CheckLastBatch((10, 20, 30)),
+      AddData(input2, 1),           // another 1 in input2 should join with the 1 in input1
+      CheckLastBatch((1, 2, 3)),
+      StopStream,
+      StartStream(),
+      AddData(input1, 1), // multiple 1s should be kept in state causing multiple (1, 2, 3)
+      CheckLastBatch((1, 2, 3), (1, 2, 3)),
+      StopStream,
+      StartStream(),
+      AddData(input1, 100),
+      AddData(input2, 100),
+      CheckLastBatch((100, 200, 300))
+    )
+  }
+
+  test("stream stream inner join on windows - without watermark") {
+    val input1 = MemoryStream[Int]
+    val input2 = MemoryStream[Int]
+
+    val df1 = input1.toDF
+      .select('value as "key", 'value.cast("timestamp") as "timestamp", ('value * 2) as "leftValue")
+      .select('key, window('timestamp, "10 second"), 'leftValue)
+
+    val df2 = input2.toDF
+      .select('value as "key", 'value.cast("timestamp") as "timestamp",
+        ('value * 3) as "rightValue")
+      .select('key, window('timestamp, "10 second"), 'rightValue)
+
+    val joined = df1.join(df2, Seq("key", "window"))
+      .select('key, $"window.end".cast("long"), 'leftValue, 'rightValue)
+
+    testStream(joined)(
+      AddData(input1, 1),
+      CheckLastBatch(),
+      AddData(input2, 1),
+      CheckLastBatch((1, 10, 2, 3)),
+      StopStream,
+      StartStream(),
+      AddData(input1, 25),
+      CheckLastBatch(),
+      StopStream,
+      StartStream(),
+      AddData(input2, 25),
+      CheckLastBatch((25, 30, 50, 75)),
+      AddData(input1, 1),
+      CheckLastBatch((1, 10, 2, 3)),      // State for 1 still around as there is no watermark
+      StopStream,
+      StartStream(),
+      AddData(input1, 5),
+      CheckLastBatch(),
+      AddData(input2, 5),
+      CheckLastBatch((5, 10, 10, 15))     // No filter by any watermark
+    )
+  }
+
+  test("stream stream inner join on windows - with watermark") {
+    val input1 = MemoryStream[Int]
+    val input2 = MemoryStream[Int]
+
+    val df1 = input1.toDF
+      .select('value as "key", 'value.cast("timestamp") as "timestamp", ('value * 2) as "leftValue")
+      .withWatermark("timestamp", "10 seconds")
+      .select('key, window('timestamp, "10 second"), 'leftValue)
+
+    val df2 = input2.toDF
+      .select('value as "key", 'value.cast("timestamp") as "timestamp",
+        ('value * 3) as "rightValue")
+      .select('key, window('timestamp, "10 second"), 'rightValue)
+
+    val joined = df1.join(df2, Seq("key", "window"))
+      .select('key, $"window.end".cast("long"), 'leftValue, 'rightValue)
+
+    testStream(joined)(
+      AddData(input1, 1),
+      CheckAnswer(),
+      assertNumStateRows(total = 1, updated = 1),
+
+      AddData(input2, 1),
+      CheckLastBatch((1, 10, 2, 3)),
+      assertNumStateRows(total = 2, updated = 1),
+      StopStream,
+      StartStream(),
+
+      AddData(input1, 25),
+      CheckLastBatch(), // since there is only 1 watermark operator, the watermark should be 15
+      assertNumStateRows(total = 3, updated = 1),
+
+      AddData(input2, 25),
+      CheckLastBatch((25, 30, 50, 75)), // watermark = 15 should remove 2 rows having window=[0,10]
+      assertNumStateRows(total = 2, updated = 1),
+      StopStream,
+      StartStream(),
+
+      AddData(input2, 1),
+      CheckLastBatch(),       // Should not join as < 15 removed
+      assertNumStateRows(total = 2, updated = 0),  // row not added as 1 < state key watermark = 15
+
+      AddData(input1, 5),
+      CheckLastBatch(),       // Should not join or add to state as < 15 got filtered by watermark
+      assertNumStateRows(total = 2, updated = 0)
+    )
+  }
+
+  test("stream stream inner join with time range - with watermark - one side condition") {
+    import org.apache.spark.sql.functions._
+
+    val leftInput = MemoryStream[(Int, Int)]
+    val rightInput = MemoryStream[(Int, Int)]
+
+    val df1 = leftInput.toDF.toDF("leftKey", "time")
+      .select('leftKey, 'time.cast("timestamp") as "leftTime", ('leftKey * 2) as "leftValue")
+      .withWatermark("leftTime", "10 seconds")
+
+    val df2 = rightInput.toDF.toDF("rightKey", "time")
+      .select('rightKey, 'time.cast("timestamp") as "rightTime", ('rightKey * 3) as "rightValue")
+      .withWatermark("rightTime", "10 seconds")
+
+    val joined =
+      df1.join(df2, expr("leftKey = rightKey AND leftTime < rightTime - interval 5 seconds"))
+        .select('leftKey, 'leftTime.cast("int"), 'rightTime.cast("int"))
+
+    testStream(joined)(
+      AddData(leftInput, (1, 5)),
+      CheckAnswer(),
+      AddData(rightInput, (1, 11)),
+      CheckLastBatch((1, 5, 11)),
+      AddData(rightInput, (1, 10)),
+      CheckLastBatch(), // no match as leftTime 5 is not less than rightTime 10 - 5
+      assertNumStateRows(total = 3, updated = 1),
+
+      // Increase event time watermark to 20s by adding data with time = 30s on both inputs
+      AddData(leftInput, (1, 3), (1, 30)),
+      CheckLastBatch((1, 3, 10), (1, 3, 11)),
+      assertNumStateRows(total = 5, updated = 2),
+      AddData(rightInput, (0, 30)),
+      CheckLastBatch(),
+      assertNumStateRows(total = 6, updated = 1),
+
+      // event time watermark:    max event time - 10   ==>   30 - 10 = 20
+      // right side state constraint:    20 < leftTime < rightTime - 5   ==>   rightTime > 25
+
+      // Run another batch so that watermark = 20 clears right state where rightTime <= 25
+      AddData(rightInput, (0, 30)),
+      CheckLastBatch(),
+      assertNumStateRows(total = 5, updated = 1),  // removed (1, 11) and (1, 10), added (0, 30)
+
+      // New data to right input should match with left side (1, 3) and (1, 5), as left state should
+      // not be cleared. But rows rightTime <= 20 should be filtered due to event time watermark and
+      // state rows with rightTime <= 25 should be removed from state.
+      // (1, 20) ==> filtered by event time watermark = 20
+      // (1, 21) ==> passed filter, matched with left (1, 3) and (1, 5), not added to state
+      //             as state watermark = 25
+      // (1, 28) ==> passed filter, matched with left (1, 3) and (1, 5), added to state
+      AddData(rightInput, (1, 20), (1, 21), (1, 28)),
+      CheckLastBatch((1, 3, 21), (1, 5, 21), (1, 3, 28), (1, 5, 28)),
+      assertNumStateRows(total = 6, updated = 1),
+
+      // New data to left input with leftTime <= 20 should be filtered due to event time watermark
+      AddData(leftInput, (1, 20), (1, 21)),
+      CheckLastBatch((1, 21, 28)),
+      assertNumStateRows(total = 7, updated = 1)
+    )
+  }
+
+  test("stream stream inner join with time range - with watermark - two side conditions") {
+    import org.apache.spark.sql.functions._
+
+    val leftInput = MemoryStream[(Int, Int)]
+    val rightInput = MemoryStream[(Int, Int)]
+
+    val df1 = leftInput.toDF.toDF("leftKey", "time")
+      .select('leftKey, 'time.cast("timestamp") as "leftTime", ('leftKey * 2) as "leftValue")
+      .withWatermark("leftTime", "20 seconds")
+
+    val df2 = rightInput.toDF.toDF("rightKey", "time")
+      .select('rightKey, 'time.cast("timestamp") as "rightTime", ('rightKey * 3) as "rightValue")
+      .withWatermark("rightTime", "30 seconds")
+
+    val condition = expr(
+      "leftKey = rightKey AND " +
+        "leftTime BETWEEN rightTime - interval 10 seconds AND rightTime + interval 5 seconds")
+
+    // This translates to leftTime <= rightTime + 5 seconds AND leftTime >= rightTime - 10 seconds
+    // So given leftTime, rightTime has to be BETWEEN leftTime - 5 seconds AND leftTime + 10 seconds
+    //
+    //  =============== * ======================== * ============================== * ==> leftTime
+    //                  |                          |                                |
+    //     |<---- 5s -->|<------ 10s ------>|      |<------ 10s ------>|<---- 5s -->|
+    //     |                                |                          |
+    //  == * ============================== * =========>============== * ===============> rightTime
+    //
+    // E.g.
+    //      if rightTime = 60, then it matches only leftTime = [50, 65]
+    //      if leftTime = 20, then it matches only with rightTime = [15, 30]
+    //
+    // State value predicates
+    //   left side:
+    //     values allowed:  leftTime >= rightTime - 10s   ==>   leftTime > eventTimeWatermark - 10
+    //     drop state where leftTime < eventTimeWatermark - 10
+    //   right side:
+    //     values allowed:  rightTime >= leftTime - 5s   ==>   rightTime > eventTimeWatermark - 5
+    //     drop state where rightTime < eventTimeWatermark - 5
+
+    val joined =
+      df1.join(df2, condition).select('leftKey, 'leftTime.cast("int"), 'rightTime.cast("int"))
+
+    testStream(joined)(
+      // If leftTime = 20, then it matches only with rightTime = [15, 30]
+      AddData(leftInput, (1, 20)),
+      CheckAnswer(),
+      AddData(rightInput, (1, 14), (1, 15), (1, 25), (1, 26), (1, 30), (1, 31)),
+      CheckLastBatch((1, 20, 15), (1, 20, 25), (1, 20, 26), (1, 20, 30)),
+      assertNumStateRows(total = 7, updated = 6),
+
+      // If rightTime = 60, then it matches only leftTime = [50, 65]
+      AddData(rightInput, (1, 60)),
+      CheckLastBatch(),                // matches with nothing on the left
+      AddData(leftInput, (1, 49), (1, 50), (1, 65), (1, 66)),
+      CheckLastBatch((1, 50, 60), (1, 65, 60)),
+      assertNumStateRows(total = 12, updated = 4),
+
+      // Event time watermark = min(left: 66 - delay 20 = 46, right: 60 - delay 30 = 30) = 30
+      // Left state value watermark = 30 - 10 = slightly less than 20 (since condition has <=)
+      //    Should drop < 20 from left, i.e., none
+      // Right state value watermark = 30 - 5 = slightly less than 25 (since condition has <=)
+      //    Should drop < 25 from the right, i.e., 14 and 15
+      AddData(leftInput, (1, 30), (1, 31)),     // 30 should not be processed or added to state
+      CheckLastBatch((1, 31, 26), (1, 31, 30), (1, 31, 31)),
+      assertNumStateRows(total = 11, updated = 1),  // 12 - 2 removed + 1 added
+
+      // Advance the watermark
+      AddData(rightInput, (1, 80)),
+      CheckLastBatch(),
+      assertNumStateRows(total = 12, updated = 1),
+
+      // Event time watermark = min(left: 66 - delay 20 = 46, right: 80 - delay 30 = 50) = 46
+      // Left state value watermark = 46 - 10 = slightly less than 36 (since condition has <=)
+      //    Should drop < 36 from left, i.e., 20, 31 (30 was not added)
+      // Right state value watermark = 46 - 5 = slightly less than 41 (since condition has <=)
+      //    Should drop < 41 from the right, i.e., 25, 26, 30, 31
+      AddData(rightInput, (1, 50)),
+      CheckLastBatch((1, 49, 50), (1, 50, 50)),
+      assertNumStateRows(total = 7, updated = 1)  // 12 - 6 removed + 1 added
+    )
+  }
+
+  testQuietly("stream stream inner join without equality predicate") {
+    val input1 = MemoryStream[Int]
+    val input2 = MemoryStream[Int]
+
+    val df1 = input1.toDF.select('value as "leftKey", ('value * 2) as "leftValue")
+    val df2 = input2.toDF.select('value as "rightKey", ('value * 3) as "rightValue")
+    val joined = df1.join(df2, expr("leftKey < rightKey"))
+    val e = intercept[Exception] {
+      val q = joined.writeStream.format("memory").queryName("test").start()
+      input1.addData(1)
+      q.awaitTermination(10000)
+    }
+    assert(e.toString.contains("Stream stream joins without equality predicate is not supported"))
+  }
+
+  testQuietly("extract watermark from time condition") {
+    val attributesToFindConstraintFor = Seq(
+      AttributeReference("leftTime", TimestampType)(),
+      AttributeReference("leftOther", IntegerType)())
+    val metadataWithWatermark = new MetadataBuilder()
+      .putLong(EventTimeWatermark.delayKey, 1000)
+      .build()
+    val attributesWithWatermark = Seq(
+      AttributeReference("rightTime", TimestampType, metadata = metadataWithWatermark)(),
+      AttributeReference("rightOther", IntegerType)())
+
+    def watermarkFrom(  // state value watermark for the left attributes under conditionStr
+        conditionStr: String,
+        rightWatermark: Option[Long] = Some(10000)): Option[Long] = {
+      val conditionExpr = Some(conditionStr).map { str =>
+        val plan =
+          Filter(
+            spark.sessionState.sqlParser.parseExpression(str),
+            LogicalRDD(
+              attributesToFindConstraintFor ++ attributesWithWatermark,
+              spark.sparkContext.emptyRDD)(spark))
+        plan.queryExecution.optimizedPlan.asInstanceOf[Filter].condition
+      }
+      StreamingSymmetricHashJoinHelper.getStateValueWatermark(
+        AttributeSet(attributesToFindConstraintFor), AttributeSet(attributesWithWatermark),
+        conditionExpr, rightWatermark)
+    }
+
+    // Test comparison directionality. E.g. if leftTime < rightTime and rightTime > watermark,
+    // then cannot define constraint on leftTime.
+    assert(watermarkFrom("leftTime > rightTime") === Some(10000))
+    assert(watermarkFrom("leftTime >= rightTime") === Some(9999))
+    assert(watermarkFrom("leftTime < rightTime") === None)
+    assert(watermarkFrom("leftTime <= rightTime") === None)
+    assert(watermarkFrom("rightTime > leftTime") === None)
+    assert(watermarkFrom("rightTime >= leftTime") === None)
+    assert(watermarkFrom("rightTime < leftTime") === Some(10000))
+    assert(watermarkFrom("rightTime <= leftTime") === Some(9999))
+
+    // Test type conversions
+    assert(watermarkFrom("CAST(leftTime AS LONG) > CAST(rightTime AS LONG)") === Some(10000))
+    assert(watermarkFrom("CAST(leftTime AS LONG) < CAST(rightTime AS LONG)") === None)
+    assert(watermarkFrom("CAST(leftTime AS DOUBLE) > CAST(rightTime AS DOUBLE)") === Some(10000))
+    assert(watermarkFrom("CAST(leftTime AS LONG) > CAST(rightTime AS DOUBLE)") === Some(10000))
+    assert(watermarkFrom("CAST(leftTime AS LONG) > CAST(rightTime AS FLOAT)") === Some(10000))
+    assert(watermarkFrom("CAST(leftTime AS DOUBLE) > CAST(rightTime AS FLOAT)") === Some(10000))
+    assert(watermarkFrom("CAST(leftTime AS STRING) > CAST(rightTime AS STRING)") === None)
+
+    // Test with timestamp type + calendar interval on either side of equation
+    // Note: timestamptype and calendar interval don't commute, so fewer valid combinations to test.
+    assert(watermarkFrom("leftTime > rightTime + interval 1 second") === Some(11000))
+    assert(watermarkFrom("leftTime + interval 2 seconds > rightTime ") === Some(8000))
+    assert(watermarkFrom("leftTime > rightTime - interval 3 second") === Some(7000))
+    assert(watermarkFrom("rightTime < leftTime - interval 3 second") === Some(13000))
+    assert(watermarkFrom("rightTime - interval 1 second < leftTime - interval 3 second")
+      === Some(12000))
+
+    // Test with casted long type + constants on either side of equation
+    // Note: long type and constants commute, so more combinations to test.
+    // -- Constants on the right
+    assert(watermarkFrom("CAST(leftTime AS LONG) > CAST(rightTime AS LONG) + 1") === Some(11000))
+    assert(watermarkFrom("CAST(leftTime AS LONG) > CAST(rightTime AS LONG) - 1") === Some(9000))
+    assert(watermarkFrom("CAST(leftTime AS LONG) > CAST((rightTime + interval 1 second) AS LONG)")
+      === Some(11000))
+    assert(watermarkFrom("CAST(leftTime AS LONG) > 2 + CAST(rightTime AS LONG)") === Some(12000))
+    assert(watermarkFrom("CAST(leftTime AS LONG) > -0.5 + CAST(rightTime AS LONG)") === Some(9500))
+    assert(watermarkFrom("CAST(leftTime AS LONG) - CAST(rightTime AS LONG) > 2") === Some(12000))
+    assert(watermarkFrom("-CAST(rightTime AS DOUBLE) + CAST(leftTime AS LONG) > 0.1")
+      === Some(10100))
+    assert(watermarkFrom("0 > CAST(rightTime AS LONG) - CAST(leftTime AS LONG) + 0.2")
+      === Some(10200))
+    // -- Constants on the left
+    assert(watermarkFrom("CAST(leftTime AS LONG) + 2 > CAST(rightTime AS LONG)") === Some(8000))
+    assert(watermarkFrom("1 + CAST(leftTime AS LONG) > CAST(rightTime AS LONG)") === Some(9000))
+    assert(watermarkFrom("CAST((leftTime  + interval 3 second) AS LONG) > CAST(rightTime AS LONG)")
+      === Some(7000))
+    assert(watermarkFrom("CAST(leftTime AS LONG) - 2 > CAST(rightTime AS LONG)") === Some(12000))
+    assert(watermarkFrom("CAST(leftTime AS LONG) + 0.5 > CAST(rightTime AS LONG)") === Some(9500))
+    assert(watermarkFrom("CAST(leftTime AS LONG) - CAST(rightTime AS LONG) - 2 > 0")
+      === Some(12000))
+    assert(watermarkFrom("-CAST(rightTime AS LONG) + CAST(leftTime AS LONG) - 0.1 > 0")
+      === Some(10100))
+    // -- Constants on both sides, mixed types
+    assert(watermarkFrom("CAST(leftTime AS LONG) - 2.0 > CAST(rightTime AS LONG) + 1")
+      === Some(13000))
+
+    // Test multiple conditions, should return minimum watermark
+    assert(watermarkFrom(
+      "leftTime > rightTime - interval 3 second AND rightTime < leftTime + interval 2 seconds") ===
+      Some(7000))  // first condition wins
+    assert(watermarkFrom(
+      "leftTime > rightTime - interval 3 second AND rightTime < leftTime + interval 4 seconds") ===
+      Some(6000))  // second condition wins
+
+    // Test invalid comparisons
+    assert(watermarkFrom("cast(leftTime AS LONG) > leftOther") === None)      // non-time attributes
+    assert(watermarkFrom("leftOther > rightOther") === None)                  // non-time attributes
+    assert(watermarkFrom("leftOther > rightOther AND leftTime > rightTime") === Some(10000))
+    assert(watermarkFrom("cast(rightTime AS DOUBLE) < rightOther") === None)  // non-time attributes
+    assert(watermarkFrom("leftTime > rightTime + interval 1 month") === None) // month not allowed
+
+    // Test static comparisons
+    assert(watermarkFrom("cast(leftTime AS LONG) > 10") === Some(10000))
+  }
+
+  test("locality preferences of StateStoreAwareZippedRDD") {
+    import StreamingSymmetricHashJoinHelper._
+
+    withTempDir { tempDir =>
+      val queryId = UUID.randomUUID
+      val opId = 0
+      val path = Utils.createDirectory(tempDir.getAbsolutePath, Random.nextString(10)).toString
+      val stateInfo = StatefulOperatorStateInfo(path, queryId, opId, 0L)
+
+      implicit val sqlContext = spark.sqlContext
+      val coordinatorRef = sqlContext.streams.stateStoreCoordinator
+      val numPartitions = 5
+      val storeNames = Seq("name1", "name2")
+
+      val partitionAndStoreNameToLocation = {
+        for (partIndex <- 0 until numPartitions; storeName <- storeNames) yield {
+          (partIndex, storeName) -> s"host-$partIndex-$storeName"
+        }
+      }.toMap
+      partitionAndStoreNameToLocation.foreach { case ((partIndex, storeName), hostName) =>
+        val providerId = StateStoreProviderId(stateInfo, partIndex, storeName)
+        coordinatorRef.reportActiveInstance(providerId, hostName, s"exec-$hostName")
+        require(
+          coordinatorRef.getLocation(providerId) ===
+            Some(ExecutorCacheTaskLocation(hostName, s"exec-$hostName").toString))
+      }
+
+      val rdd1 = spark.sparkContext.makeRDD(1 to 10, numPartitions)
+      val rdd2 = spark.sparkContext.makeRDD((1 to 10).map(_.toString), numPartitions)
+      val rdd = rdd1.stateStoreAwareZipPartitions(rdd2, stateInfo, storeNames, coordinatorRef) {
+        (left, right) => left.zip(right)
+      }
+      require(rdd.partitions.length === numPartitions)
+      for (partIndex <- 0 until numPartitions) {
+        val expectedLocations = storeNames.map { storeName =>
+          val hostName = partitionAndStoreNameToLocation((partIndex, storeName))
+          ExecutorCacheTaskLocation(hostName, s"exec-$hostName").toString
+        }.toSet
+        assert(rdd.preferredLocations(rdd.partitions(partIndex)).toSet === expectedLocations)
+      }
+    }
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


[2/2] spark git commit: [SPARK-22053][SS] Stream-stream inner join in Append Mode

Posted by td...@apache.org.
[SPARK-22053][SS] Stream-stream inner join in Append Mode

## What changes were proposed in this pull request?

#### Architecture
This PR implements stream-stream inner join using a two-way symmetric hash join. At a high level, we want to do the following.

1. For each stream, we maintain the past rows as state in State Store.
  - For each joining key, there can be multiple rows that have been received.
  - So, we have to effectively maintain a key-to-list-of-values multimap as state for each stream.
2. In each batch, for each input row in each stream
  - Look up the other stream’s state to see if there are matching rows, and output them if they satisfy the joining condition.
  - Add the input row to the corresponding stream’s state.
  - If the data has a timestamp/window column with watermark, then we will use that to calculate the threshold for keys that are required to be buffered for future matches and drop the rest from the state.

Cleaning up old unnecessary state rows depends completely on whether a watermark has been defined and what the join conditions are. We definitely want to support state cleanup for two types of queries that are likely to be common.

- Queries with time range conditions - E.g. `SELECT * FROM leftTable, rightTable ON leftKey = rightKey AND leftTime > rightTime - INTERVAL 8 MINUTES AND leftTime < rightTime + INTERVAL 1 HOUR`
- Queries with windows as the matching key - E.g. `SELECT * FROM leftTable, rightTable ON leftKey = rightKey AND window(leftTime, "1 hour") = window(rightTime, "1 hour")` (pseudo-SQL)

#### Implementation
The stream-stream join is primarily implemented in three classes:
- `StreamingSymmetricHashJoinExec` implements the above symmetric join algorithm.
- `SymmetricHashJoinStateManager` manages the streaming state for the join. This essentially is a fault-tolerant key-to-list-of-values multimap built on the StateStore APIs. `StreamingSymmetricHashJoinExec` instantiates two such managers, one for each join side.
- `StreamingSymmetricHashJoinHelper` is a helper class to extract thresholds for the state based on the join conditions and the event-time watermark.

Refer to the class scaladocs for more implementation details.

Besides the implementation of the stream-stream inner join SparkPlan, some additional changes are:
- Allowed inner join in append mode in UnsupportedOperationChecker
- Prevented stream-stream joins on an empty batch DataFrame from being collapsed by the optimizer

## How was this patch tested?
- New tests in StreamingJoinSuite
- Updated tests in UnsupportedOperationsSuite

Author: Tathagata Das <ta...@gmail.com>

Closes #19271 from tdas/SPARK-22053.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f32a8425
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f32a8425
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f32a8425

Branch: refs/heads/master
Commit: f32a8425051eabdef2d69002cfc843c01d98df0d
Parents: a8a5cd2
Author: Tathagata Das <ta...@gmail.com>
Authored: Thu Sep 21 15:39:07 2017 -0700
Committer: Tathagata Das <ta...@gmail.com>
Committed: Thu Sep 21 15:39:07 2017 -0700

----------------------------------------------------------------------
 .../spark/sql/catalyst/analysis/Analyzer.scala  |   2 +-
 .../analysis/UnsupportedOperationChecker.scala  |   7 +-
 .../catalyst/expressions/namedExpressions.scala |   9 +-
 .../optimizer/PropagateEmptyRelation.scala      |  25 +-
 .../analysis/UnsupportedOperationsSuite.scala   |  22 +-
 .../spark/sql/execution/SparkStrategies.scala   |  21 +-
 .../streaming/IncrementalExecution.scala        |  11 +
 .../execution/streaming/StreamExecution.scala   |   2 +
 .../StreamingSymmetricHashJoinExec.scala        | 346 ++++++++++++++
 .../StreamingSymmetricHashJoinHelper.scala      | 415 ++++++++++++++++
 .../execution/streaming/state/StateStore.scala  |  21 +
 .../streaming/state/StateStoreCoordinator.scala |   6 +-
 .../state/SymmetricHashJoinStateManager.scala   | 395 ++++++++++++++++
 .../sql/execution/streaming/state/package.scala |   2 -
 .../execution/streaming/statefulOperators.scala |  55 ++-
 .../SymmetricHashJoinStateManagerSuite.scala    | 172 +++++++
 .../sql/streaming/StateStoreMetricsTest.scala   |   2 +-
 .../sql/streaming/StreamingJoinSuite.scala      | 472 +++++++++++++++++++
 18 files changed, 1940 insertions(+), 45 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/f32a8425/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 45ec204..8edf575 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -2390,7 +2390,7 @@ object TimeWindowing extends Rule[LogicalPlan] {
 
         if (window.windowDuration == window.slideDuration) {
           val windowStruct = Alias(getWindow(0, 1), WINDOW_COL_NAME)(
-            exprId = windowAttr.exprId)
+            exprId = windowAttr.exprId, explicitMetadata = Some(metadata))
 
           val replacedPlan = p transformExpressions {
             case t: TimeWindow => windowAttr

http://git-wip-us.apache.org/repos/asf/spark/blob/f32a8425/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
index 33ba086..d1d7056 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
@@ -222,8 +222,10 @@ object UnsupportedOperationChecker {
           joinType match {
 
             case _: InnerLike =>
-              if (left.isStreaming && right.isStreaming) {
-                throwError("Inner join between two streaming DataFrames/Datasets is not supported")
+              if (left.isStreaming && right.isStreaming &&
+                outputMode != InternalOutputModes.Append) {
+                throwError("Inner join between two streaming DataFrames/Datasets is not supported" +
+                  s" in ${outputMode} output mode, only in Append output mode")
               }
 
             case FullOuter =>
@@ -231,7 +233,6 @@ object UnsupportedOperationChecker {
                 throwError("Full outer joins with streaming DataFrames/Datasets are not supported")
               }
 
-
             case LeftOuter | LeftSemi | LeftAnti =>
               if (right.isStreaming) {
                 throwError("Left outer/semi/anti joins with a streaming DataFrame/Dataset " +

http://git-wip-us.apache.org/repos/asf/spark/blob/f32a8425/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index b898484..e518e73 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -164,7 +164,14 @@ case class Alias(child: Expression, name: String)(
     }
   }
 
-  override def toString: String = s"$child AS $name#${exprId.id}$typeSuffix"
+  /** Used to signal the column used to calculate an eventTime watermark (e.g. a#1-T{delayMs}ms) */
+  private def delaySuffix = if (metadata.contains(EventTimeWatermark.delayKey)) {
+    s"-T${metadata.getLong(EventTimeWatermark.delayKey)}ms"
+  } else {
+    ""
+  }
+
+  override def toString: String = s"$child AS $name#${exprId.id}$typeSuffix$delaySuffix"
 
   override protected final def otherCopyArgs: Seq[AnyRef] = {
     exprId :: qualifier :: explicitMetadata :: Nil

http://git-wip-us.apache.org/repos/asf/spark/blob/f32a8425/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala
index cfffa6b..52fbb4d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala
@@ -45,14 +45,19 @@ object PropagateEmptyRelation extends Rule[LogicalPlan] with PredicateHelper {
     case p: Union if p.children.forall(isEmptyLocalRelation) =>
       empty(p)
 
-    case p @ Join(_, _, joinType, _) if p.children.exists(isEmptyLocalRelation) => joinType match {
-      case _: InnerLike => empty(p)
-      // Intersect is handled as LeftSemi by `ReplaceIntersectWithSemiJoin` rule.
-      // Except is handled as LeftAnti by `ReplaceExceptWithAntiJoin` rule.
-      case LeftOuter | LeftSemi | LeftAnti if isEmptyLocalRelation(p.left) => empty(p)
-      case RightOuter if isEmptyLocalRelation(p.right) => empty(p)
-      case FullOuter if p.children.forall(isEmptyLocalRelation) => empty(p)
-      case _ => p
+    // Joins on empty LocalRelations generated from streaming sources are not eliminated
+    // as stateful streaming joins need to perform state management operations in addition to
+    // processing the input data.
+    case p @ Join(_, _, joinType, _)
+        if !p.children.exists(_.isStreaming) && p.children.exists(isEmptyLocalRelation) =>
+      joinType match {
+        case _: InnerLike => empty(p)
+        // Intersect is handled as LeftSemi by `ReplaceIntersectWithSemiJoin` rule.
+        // Except is handled as LeftAnti by `ReplaceExceptWithAntiJoin` rule.
+        case LeftOuter | LeftSemi | LeftAnti if isEmptyLocalRelation(p.left) => empty(p)
+        case RightOuter if isEmptyLocalRelation(p.right) => empty(p)
+        case FullOuter if p.children.forall(isEmptyLocalRelation) => empty(p)
+        case _ => p
     }
 
     case p: UnaryNode if p.children.nonEmpty && p.children.forall(isEmptyLocalRelation) => p match {
@@ -74,6 +79,10 @@ object PropagateEmptyRelation extends Rule[LogicalPlan] with PredicateHelper {
       //
       // If the grouping expressions are empty, however, then the aggregate will always produce a
       // single output row and thus we cannot propagate the EmptyRelation.
+      //
+      // Aggregation on empty LocalRelation generated from a streaming source is not eliminated
+      // as stateful streaming aggregation needs to perform state management operations in
+      // addition to processing the input data.
       case Aggregate(ge, _, _) if ge.nonEmpty && !p.isStreaming => empty(p)
       // Generators like Hive-style UDTF may return their records within `close`.
       case Generate(_: Explode, _, _, _, _, _) => empty(p)

http://git-wip-us.apache.org/repos/asf/spark/blob/f32a8425/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala
index 4de7586..11f48a3 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala
@@ -383,11 +383,27 @@ class UnsupportedOperationsSuite extends SparkFunSuite {
     outputMode = Append
   )
 
-  // Inner joins: Stream-stream not supported
+  // Inner joins: Multiple stream-stream joins supported only in append mode
   testBinaryOperationInStreamingPlan(
-    "inner join",
+    "single inner join in append mode",
     _.join(_, joinType = Inner),
-    streamStreamSupported = false)
+    outputMode = Append,
+    streamStreamSupported = true)
+
+  testBinaryOperationInStreamingPlan(
+    "multiple inner joins in append mode",
+    (x: LogicalPlan, y: LogicalPlan) => {
+      x.join(y, joinType = Inner).join(streamRelation, joinType = Inner)
+    },
+    outputMode = Append,
+    streamStreamSupported = true)
+
+  testBinaryOperationInStreamingPlan(
+    "inner join in update mode",
+    _.join(_, joinType = Inner),
+    outputMode = Update,
+    streamStreamSupported = false,
+    expectedMsg = "inner join")
 
   // Full outer joins: only batch-batch is allowed
   testBinaryOperationInStreamingPlan(

http://git-wip-us.apache.org/repos/asf/spark/blob/f32a8425/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 6b16408..4da7a73 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.execution
 
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.Strategy
+import org.apache.spark.sql.{execution, AnalysisException, Strategy}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.catalyst.expressions._
@@ -26,7 +26,6 @@ import org.apache.spark.sql.catalyst.planning._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.execution
 import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec}
 import org.apache.spark.sql.execution.command._
 import org.apache.spark.sql.execution.exchange.ShuffleExchange
@@ -257,6 +256,24 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
     }
   }
 
+  object StreamingJoinStrategy extends Strategy {
+    override def apply(plan: LogicalPlan): Seq[SparkPlan] = {
+      plan match {
+        case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)
+          if left.isStreaming && right.isStreaming =>
+
+          new StreamingSymmetricHashJoinExec(
+            leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil
+
+        case Join(left, right, _, _) if left.isStreaming && right.isStreaming =>
+          throw new AnalysisException(
+            "Stream stream joins without equality predicate is not supported", plan = Some(plan))
+
+        case _ => Nil
+      }
+    }
+  }
+
   /**
    * Used to plan the aggregate operator for expressions based on the AggregateFunction2 interface.
    */

http://git-wip-us.apache.org/repos/asf/spark/blob/f32a8425/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
index 027222e..8e0aae3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
@@ -54,6 +54,7 @@ class IncrementalExecution(
       sparkSession.sessionState.planner.strategies
 
     override def extraPlanningStrategies: Seq[Strategy] =
+      StreamingJoinStrategy ::
       StatefulAggregationStrategy ::
       FlatMapGroupsWithStateStrategy ::
       StreamingRelationStrategy ::
@@ -116,6 +117,16 @@ class IncrementalExecution(
           stateInfo = Some(nextStatefulOperationStateInfo),
           batchTimestampMs = Some(offsetSeqMetadata.batchTimestampMs),
           eventTimeWatermark = Some(offsetSeqMetadata.batchWatermarkMs))
+
+      case j: StreamingSymmetricHashJoinExec =>
+        j.copy(
+          stateInfo = Some(nextStatefulOperationStateInfo),
+          eventTimeWatermark = Some(offsetSeqMetadata.batchWatermarkMs),
+          stateWatermarkPredicates =
+            StreamingSymmetricHashJoinHelper.getStateWatermarkPredicates(
+              j.left.output, j.right.output, j.leftKeys, j.rightKeys, j.condition,
+              Some(offsetSeqMetadata.batchWatermarkMs))
+        )
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/f32a8425/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
index 18385f5..b2d6c60 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
@@ -297,6 +297,8 @@ class StreamExecution(
       val sparkSessionToRunBatches = sparkSession.cloneSession()
       // Adaptive execution can change num shuffle partitions, disallow
       sparkSessionToRunBatches.conf.set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "false")
+      // Disable cost-based join optimization as we do not want stateful operations to be rearranged
+      sparkSessionToRunBatches.conf.set(SQLConf.CBO_ENABLED.key, "false")
       offsetSeqMetadata = OffsetSeqMetadata(
         batchWatermarkMs = 0, batchTimestampMs = 0, sparkSessionToRunBatches.conf)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/f32a8425/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala
new file mode 100644
index 0000000..44f1fa5
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala
@@ -0,0 +1,346 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.util.concurrent.TimeUnit.NANOSECONDS
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Attribute, BindReferences, Expression, JoinedRow, Literal, NamedExpression, PreciseTimestampConversion, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark._
+import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan}
+import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper._
+import org.apache.spark.sql.execution.streaming.state._
+import org.apache.spark.sql.internal.SessionState
+import org.apache.spark.sql.types.{LongType, TimestampType}
+import org.apache.spark.util.{CompletionIterator, SerializableConfiguration}
+
+
+/**
+ * Performs stream-stream join using symmetric hash join algorithm. It works as follows.
+ *
+ *                             /-----------------------\
+ *   left side input --------->|    left side state    |------\
+ *                             \-----------------------/      |
+ *                                                            |--------> joined output
+ *                             /-----------------------\      |
+ *   right side input -------->|    right side state   |------/
+ *                             \-----------------------/
+ *
+ * Each join side buffers past input rows as streaming state so that the past input can be joined
+ * with future input on the other side. This buffer state is effectively a multi-map:
+ *    equi-join key -> list of past input rows received with the join key
+ *
+ * For each input row in each side, the following operations take place.
+ * - Calculate join key from the row.
+ * - Use the join key to append the row to the buffer state of the side that the row came from.
+ * - Find past buffered values for the key from the other side. For each such value, emit the
+ *   "joined row" (left-row, right-row)
+ * - Apply the optional condition to filter the joined rows as the final output.
+ *
+ * If a timestamp column with event time watermark is present in the join keys or in the input
+ * data, then the operator uses the watermark to figure out which rows in the buffer will not join
+ * with the new data and can therefore be discarded. Depending on the provided query conditions, we
+ * can define thresholds on both state key (i.e. joining keys) and state value (i.e. input rows).
+ * There are three kinds of queries possible regarding this as explained below.
+ * Assume that watermark has been defined on both `leftTime` and `rightTime` columns used below.
+ *
+ * 1. When timestamp/time-window + watermark is in the join keys. Example (pseudo-SQL):
+ *
+ *      SELECT * FROM leftTable, rightTable
+ *      ON
+ *        leftKey = rightKey AND
+ *        window(leftTime, "1 hour") = window(rightTime, "1 hour")    // 1hr tumbling windows
+ *
+ *    In this case, this operator will join rows newer than watermark which fall in the same
+ *    1 hour window. Say the event-time watermark is "12:34" (both left and right input).
+ *    Then input rows can only have time > 12:34. Hence, they can only join with buffered rows
+ *    where window >= 12:00 - 1:00 and all buffered rows with join window < 12:00 can be
+ *    discarded. In other words, the operator will discard all state where
+ *    window in state key (i.e. join key) < event time watermark. This threshold is called
+ *    State Key Watermark.
+ *
+ * 2. When timestamp range conditions are provided (no time/window + watermark in join keys). E.g.
+ *
+ *      SELECT * FROM leftTable, rightTable
+ *      ON
+ *        leftKey = rightKey AND
+ *        leftTime > rightTime - INTERVAL 8 MINUTES AND leftTime < rightTime + INTERVAL 1 HOUR
+ *
+ *   In this case, the event-time watermark and the range conditions can be used to calculate a
+ *   state watermark, i.e., time threshold for the state rows that can be discarded.
+ *   For example, say each join side has a time column, named "leftTime" and
+ *   "rightTime", and there is a join condition "leftTime > rightTime - 8 min".
+ *   While processing, say the watermark on right input is "12:34". This means that henceforth,
+ *   only right input rows with "rightTime > 12:34" will be processed, and any older rows will be
+ *   considered as "too late" and therefore dropped. Then, the left side buffer only needs
+ *   to keep rows where "leftTime > rightTime - 8 min > 12:34 - 8 min = 12:26".
+ *   That is, the left state watermark is 12:26, and any rows older than that can be dropped from
+ *   the state. In other words, the operator will discard all state where
+ *   timestamp in state value (input rows) < state watermark. This threshold is called
+ *   State Value Watermark (to distinguish from the state key watermark).
+ *
+ *   Note:
+ *   - The event watermark value of one side is used to calculate the
+ *     state watermark of the other side. That is, a condition ~ "leftTime > rightTime + X" with
+ *     right side event watermark is used to calculate the left side state watermark. Conversely,
+ *     a condition ~ "leftTime < rightTime + Y" with left side event watermark is used to calculate
+ *     right side state watermark.
+ *   - Depending on the conditions, the state watermark may be different for the left and right
+ *     side. In the above example, leftTime > 12:26 AND rightTime > 12:34 - 1 hour = 11:34.
+ *   - State can be dropped from BOTH sides only when there are conditions of the above forms that
+ *     define time bounds on timestamp in both directions.
+ *
+ * 3. When both window in join key and time range conditions are present, case 1 + 2.
+ *    In this case, since window equality is a stricter condition than the time range, we can
+ *    use the State Key Watermark = event time watermark to discard state (similar to case 1).
+ *
+ * @param leftKeys  Expression to generate key rows for joining from left input
+ * @param rightKeys Expression to generate key rows for joining from right input
+ * @param joinType  Type of join (inner, left outer, etc.)
+ * @param condition Optional, additional condition to filter output of the equi-join
+ * @param stateInfo Version information required to read join state (buffered rows)
+ * @param eventTimeWatermark Watermark of input events, same for both sides
+ * @param stateWatermarkPredicates Predicates for removal of state, see
+ *                                 [[JoinStateWatermarkPredicates]]
+ * @param left      Left child plan
+ * @param right     Right child plan
+ */
+case class StreamingSymmetricHashJoinExec(
+    leftKeys: Seq[Expression],
+    rightKeys: Seq[Expression],
+    joinType: JoinType,
+    condition: Option[Expression],
+    stateInfo: Option[StatefulOperatorStateInfo],
+    eventTimeWatermark: Option[Long],
+    stateWatermarkPredicates: JoinStateWatermarkPredicates,
+    left: SparkPlan,
+    right: SparkPlan) extends SparkPlan with BinaryExecNode with StateStoreWriter {
+
+  def this(
+      leftKeys: Seq[Expression],
+      rightKeys: Seq[Expression],
+      joinType: JoinType,
+      condition: Option[Expression],
+      left: SparkPlan,
+      right: SparkPlan) = {
+    this(
+      leftKeys, rightKeys, joinType, condition, stateInfo = None, eventTimeWatermark = None,
+      stateWatermarkPredicates = JoinStateWatermarkPredicates(), left, right)
+  }
+
+  require(joinType == Inner, s"${getClass.getSimpleName} should not take $joinType as the JoinType")
+  require(leftKeys.map(_.dataType) == rightKeys.map(_.dataType))
+
+  private val storeConf = new StateStoreConf(sqlContext.conf)
+  private val hadoopConfBcast = sparkContext.broadcast(
+    new SerializableConfiguration(SessionState.newHadoopConf(
+      sparkContext.hadoopConfiguration, sqlContext.conf)))
+
+  override def requiredChildDistribution: Seq[Distribution] =
+    ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
+
+  override def output: Seq[Attribute] = left.output ++ right.output
+
+  override def outputPartitioning: Partitioning = joinType match {
+    case _: InnerLike =>
+      PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning))
+    case x =>
+      throw new IllegalArgumentException(
+        s"${getClass.getSimpleName} should not take $x as the JoinType")
+  }
+
+  protected override def doExecute(): RDD[InternalRow] = {
+    val stateStoreCoord = sqlContext.sessionState.streamingQueryManager.stateStoreCoordinator
+    val stateStoreNames = SymmetricHashJoinStateManager.allStateStoreNames(LeftSide, RightSide)
+    left.execute().stateStoreAwareZipPartitions(
+      right.execute(), stateInfo.get, stateStoreNames, stateStoreCoord)(processPartitions)
+  }
+
+  private def processPartitions(
+      leftInputIter: Iterator[InternalRow],
+      rightInputIter: Iterator[InternalRow]): Iterator[InternalRow] = {
+    if (stateInfo.isEmpty) {
+      throw new IllegalStateException(s"Cannot execute join as state info was not specified\n$this")
+    }
+
+    val numOutputRows = longMetric("numOutputRows")
+    val numUpdatedStateRows = longMetric("numUpdatedStateRows")
+    val numTotalStateRows = longMetric("numTotalStateRows")
+    val allUpdatesTimeMs = longMetric("allUpdatesTimeMs")
+    val allRemovalsTimeMs = longMetric("allRemovalsTimeMs")
+    val commitTimeMs = longMetric("commitTimeMs")
+    val stateMemory = longMetric("stateMemory")
+
+    val updateStartTimeNs = System.nanoTime
+    val joinedRow = new JoinedRow
+
+    val leftSideJoiner = new OneSideHashJoiner(
+      LeftSide, left.output, leftKeys, leftInputIter, stateWatermarkPredicates.left)
+    val rightSideJoiner = new OneSideHashJoiner(
+      RightSide, right.output, rightKeys, rightInputIter, stateWatermarkPredicates.right)
+
+    //  Join one side input using the other side's buffered/state rows. Here is how it is done.
+    //
+    //  - `leftSideJoiner.storeAndJoinWithOtherSide(rightSideJoiner)` generates all rows from
+    //    matching new left input with stored right input, and also stores all the left input
+    //
+    //  - `rightSideJoiner.storeAndJoinWithOtherSide(leftSideJoiner)` generates all rows from
+    //    matching new right input with stored left input, and also stores all the right input.
+    //    It also generates all rows from matching new left input with new right input, since the
+    //    new left input has become stored by that point. This tiny asymmetry avoids duplication.
+    val leftOutputIter = leftSideJoiner.storeAndJoinWithOtherSide(rightSideJoiner) {
+      (inputRow: UnsafeRow, matchedRow: UnsafeRow) =>
+        joinedRow.withLeft(inputRow).withRight(matchedRow)
+    }
+    val rightOutputIter = rightSideJoiner.storeAndJoinWithOtherSide(leftSideJoiner) {
+      (inputRow: UnsafeRow, matchedRow: UnsafeRow) =>
+        joinedRow.withLeft(matchedRow).withRight(inputRow)
+    }
+
+    // Filter the joined rows based on the given condition.
+    val outputFilterFunction =
+      newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output).eval _
+    val filteredOutputIter =
+      (leftOutputIter ++ rightOutputIter).filter(outputFilterFunction).map { row =>
+        numOutputRows += 1
+        row
+      }
+
+    // Function to remove old state after all the input has been consumed and output generated
+    def onOutputCompletion = {
+      allUpdatesTimeMs += math.max(NANOSECONDS.toMillis(System.nanoTime - updateStartTimeNs), 0)
+
+      // Remove old state if needed
+      allRemovalsTimeMs += timeTakenMs {
+        leftSideJoiner.removeOldState()
+        rightSideJoiner.removeOldState()
+      }
+
+      // Commit all state changes and update state store metrics
+      commitTimeMs += timeTakenMs {
+        val leftSideMetrics = leftSideJoiner.commitStateAndGetMetrics()
+        val rightSideMetrics = rightSideJoiner.commitStateAndGetMetrics()
+        val combinedMetrics = StateStoreMetrics.combine(Seq(leftSideMetrics, rightSideMetrics))
+
+        // Update SQL metrics
+        numUpdatedStateRows +=
+          (leftSideJoiner.numUpdatedStateRows + rightSideJoiner.numUpdatedStateRows)
+        numTotalStateRows += combinedMetrics.numKeys
+        stateMemory += combinedMetrics.memoryUsedBytes
+        combinedMetrics.customMetrics.foreach { case (metric, value) =>
+          longMetric(metric.name) += value
+        }
+      }
+    }
+
+    CompletionIterator[InternalRow, Iterator[InternalRow]](filteredOutputIter, onOutputCompletion)
+  }
+
+  /**
+   * Internal helper class to consume input rows, generate join output rows using the other side's
+   * buffered state rows, and finally clean up this side's buffered state rows
+   */
+  private class OneSideHashJoiner(
+      joinSide: JoinSide,
+      inputAttributes: Seq[Attribute],
+      joinKeys: Seq[Expression],
+      inputIter: Iterator[InternalRow],
+      stateWatermarkPredicate: Option[JoinStateWatermarkPredicate]) {
+
+    private val joinStateManager = new SymmetricHashJoinStateManager(
+      joinSide, inputAttributes, joinKeys, stateInfo, storeConf, hadoopConfBcast.value.value)
+    private[this] val keyGenerator = UnsafeProjection.create(joinKeys, inputAttributes)
+
+    private[this] val stateKeyWatermarkPredicateFunc = stateWatermarkPredicate match {
+      case Some(JoinStateKeyWatermarkPredicate(expr)) =>
+        // inputSchema can be empty as expr should only have BoundReferences and does not require
+        // the schema to generate the predicate. See [[StreamingSymmetricHashJoinHelper]].
+        newPredicate(expr, Seq.empty).eval _
+      case _ =>
+        newPredicate(Literal(false), Seq.empty).eval _ // false = do not remove if no predicate
+    }
+
+    private[this] val stateValueWatermarkPredicateFunc = stateWatermarkPredicate match {
+      case Some(JoinStateValueWatermarkPredicate(expr)) =>
+        newPredicate(expr, inputAttributes).eval _
+      case _ =>
+        newPredicate(Literal(false), Seq.empty).eval _  // false = do not remove if no predicate
+    }
+
+    private[this] var updatedStateRowsCount = 0
+
+    /**
+     * Generate joined rows by consuming input from this side, and matching it with the buffered
+     * rows (i.e. state) of the other side.
+     * @param otherSideJoiner   Joiner of the other side
+     * @param generateJoinedRow Function to generate the joined row from the
+     *                          input row from this side and the matched row from the other side
+     */
+    def storeAndJoinWithOtherSide(
+        otherSideJoiner: OneSideHashJoiner)(
+        generateJoinedRow: (UnsafeRow, UnsafeRow) => JoinedRow): Iterator[InternalRow] = {
+
+      val watermarkAttribute = inputAttributes.find(_.metadata.contains(delayKey))
+      val nonLateRows =
+        WatermarkSupport.watermarkExpression(watermarkAttribute, eventTimeWatermark) match {
+          case Some(watermarkExpr) =>
+            val predicate = newPredicate(watermarkExpr, inputAttributes)
+            inputIter.filter { row => !predicate.eval(row) }
+          case None =>
+            inputIter
+        }
+
+      nonLateRows.flatMap { row =>
+        val thisRow = row.asInstanceOf[UnsafeRow]
+        val key = keyGenerator(thisRow)
+        val outputIter = otherSideJoiner.joinStateManager.get(key).map { thatRow =>
+          generateJoinedRow(thisRow, thatRow)
+        }
+        val shouldAddToState = // add only if both removal predicates do not match
+          !stateKeyWatermarkPredicateFunc(key) && !stateValueWatermarkPredicateFunc(thisRow)
+        if (shouldAddToState) {
+          joinStateManager.append(key, thisRow)
+          updatedStateRowsCount += 1
+        }
+        outputIter
+      }
+    }
+
+    /** Remove old buffered state rows using watermarks for state keys and values */
+    def removeOldState(): Unit = {
+      stateWatermarkPredicate match {
+        case Some(JoinStateKeyWatermarkPredicate(expr)) =>
+          joinStateManager.removeByKeyCondition(stateKeyWatermarkPredicateFunc)
+        case Some(JoinStateValueWatermarkPredicate(expr)) =>
+          joinStateManager.removeByValueCondition(stateValueWatermarkPredicateFunc)
+        case _ =>
+      }
+    }
+
+    /** Commit changes to the buffer state and return the state store metrics */
+    def commitStateAndGetMetrics(): StateStoreMetrics = {
+      joinStateManager.commit()
+      joinStateManager.metrics
+    }
+
+    def numUpdatedStateRows: Long = updatedStateRowsCount
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/f32a8425/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala
new file mode 100644
index 0000000..e50274a
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala
@@ -0,0 +1,415 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import scala.reflect.ClassTag
+import scala.util.control.NonFatal
+
+import org.apache.spark.{Partition, SparkContext}
+import org.apache.spark.internal.Logging
+import org.apache.spark.rdd.{RDD, ZippedPartitionsRDD2}
+import org.apache.spark.sql.catalyst.expressions.{Add, Attribute, AttributeReference, AttributeSet, BoundReference, Cast, CheckOverflow, Expression, ExpressionSet, GreaterThan, GreaterThanOrEqual, LessThan, LessThanOrEqual, Literal, Multiply, NamedExpression, PreciseTimestampConversion, PredicateHelper, Subtract, TimeAdd, TimeSub, UnaryMinus}
+import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark._
+import org.apache.spark.sql.execution.streaming.WatermarkSupport.watermarkExpression
+import org.apache.spark.sql.execution.streaming.state.{StateStoreCoordinatorRef, StateStoreProvider, StateStoreProviderId}
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.CalendarInterval
+
+
+/**
+ * Helper object for [[StreamingSymmetricHashJoinExec]]. It defines the join sides, the state
+ * watermark predicates used to clean up buffered join state, the logic that derives those
+ * predicates from join keys and join conditions, and a state-store-aware zipped-partitions RDD.
+ * See [[StreamingSymmetricHashJoinExec]] for more details.
+ */
+object StreamingSymmetricHashJoinHelper extends PredicateHelper with Logging {
+
+  /** Identifies one of the two inputs of the join. */
+  sealed trait JoinSide
+  case object LeftSide extends JoinSide { override def toString(): String = "left" }
+  case object RightSide extends JoinSide { override def toString(): String = "right" }
+
+  /** Predicate over buffered state; state matching it is old enough to be removed. */
+  sealed trait JoinStateWatermarkPredicate {
+    def expr: Expression
+    def desc: String
+    override def toString: String = s"$desc: $expr"
+  }
+  /** Predicate for watermark on state keys */
+  case class JoinStateKeyWatermarkPredicate(expr: Expression)
+    extends JoinStateWatermarkPredicate {
+    def desc: String = "key predicate"
+  }
+  /** Predicate for watermark on state values */
+  case class JoinStateValueWatermarkPredicate(expr: Expression)
+    extends JoinStateWatermarkPredicate {
+    def desc: String = "value predicate"
+  }
+
+  /** The optional state cleanup predicates of the left and right sides of the join. */
+  case class JoinStateWatermarkPredicates(
+    left: Option[JoinStateWatermarkPredicate] = None,
+    right: Option[JoinStateWatermarkPredicate] = None) {
+    override def toString(): String = {
+      s"state cleanup [ left ${left.map(_.toString).getOrElse("= null")}, " +
+        s"right ${right.map(_.toString).getOrElse("= null")} ]"
+    }
+  }
+
+  /**
+   * Get the predicates defining the state watermarks for both sides of the join.
+   *
+   * @param leftAttributes     output attributes of the left input
+   * @param rightAttributes    output attributes of the right input
+   * @param leftKeys           join key expressions of the left side
+   * @param rightKeys          join key expressions of the right side
+   * @param condition          optional join condition, used to derive state value watermarks
+   * @param eventTimeWatermark current event time watermark (milliseconds), if any
+   * @return the (optional) state watermark predicate of each side
+   */
+  def getStateWatermarkPredicates(
+      leftAttributes: Seq[Attribute],
+      rightAttributes: Seq[Attribute],
+      leftKeys: Seq[Expression],
+      rightKeys: Seq[Expression],
+      condition: Option[Expression],
+      eventTimeWatermark: Option[Long]): JoinStateWatermarkPredicates = {
+
+    // Join keys of both sides generate rows of the same fields, that is, same sequence of data
+    // types. If one side (say left side) has a column (say timestamp) that has a watermark on it,
+    // then it will never consider joining keys that are < state key watermark (i.e. event time
+    // watermark). On the other side (i.e. right side), even if there is no watermark defined,
+    // there has to be an equivalent column (i.e., timestamp). And any right side data that has
+    // the timestamp < watermark will not match with left side data, as the left side gets
+    // filtered with the explicitly defined watermark. So, the watermark in timestamp column in
+    // left side keys effectively causes the timestamp on the right side to have a watermark.
+    // We will use the ordinal of the left timestamp in the left keys to find the corresponding
+    // right timestamp in the right keys.
+    def ordinalOfWatermarkedKey(keys: Seq[Expression]): Option[Int] =
+      keys.zipWithIndex.collectFirst {
+        case (ne: NamedExpression, index) if ne.metadata.contains(delayKey) => index
+      }
+    val joinKeyOrdinalForWatermark: Option[Int] =
+      ordinalOfWatermarkedKey(leftKeys).orElse(ordinalOfWatermarkedKey(rightKeys))
+
+    def getOneSideStateWatermarkPredicate(
+        oneSideInputAttributes: Seq[Attribute],
+        oneSideJoinKeys: Seq[Expression],
+        otherSideInputAttributes: Seq[Attribute]): Option[JoinStateWatermarkPredicate] = {
+      val isWatermarkDefinedOnInput = oneSideInputAttributes.exists(_.metadata.contains(delayKey))
+
+      joinKeyOrdinalForWatermark match {
+        case Some(ordinal) =>
+          // case 1 and 3 in the StreamingSymmetricHashJoinExec docs
+          val keyWithWatermark = oneSideJoinKeys(ordinal)
+          val keyExprWithWatermark =
+            BoundReference(ordinal, keyWithWatermark.dataType, keyWithWatermark.nullable)
+          val expr = watermarkExpression(Some(keyExprWithWatermark), eventTimeWatermark)
+          expr.map(JoinStateKeyWatermarkPredicate.apply _)
+
+        case None if isWatermarkDefinedOnInput =>
+          // case 2 in the StreamingSymmetricHashJoinExec docs
+          val stateValueWatermark = getStateValueWatermark(
+            attributesToFindStateWatermarkFor = AttributeSet(oneSideInputAttributes),
+            attributesWithEventWatermark = AttributeSet(otherSideInputAttributes),
+            condition,
+            eventTimeWatermark)
+          val inputAttributeWithWatermark =
+            oneSideInputAttributes.find(_.metadata.contains(delayKey))
+          val expr = watermarkExpression(inputAttributeWithWatermark, stateValueWatermark)
+          expr.map(JoinStateValueWatermarkPredicate.apply _)
+
+        case None =>
+          None
+      }
+    }
+
+    val leftStateWatermarkPredicate =
+      getOneSideStateWatermarkPredicate(leftAttributes, leftKeys, rightAttributes)
+    val rightStateWatermarkPredicate =
+      getOneSideStateWatermarkPredicate(rightAttributes, rightKeys, leftAttributes)
+    JoinStateWatermarkPredicates(leftStateWatermarkPredicate, rightStateWatermarkPredicate)
+  }
+
+  /**
+   * Get state value watermark (see [[StreamingSymmetricHashJoinExec]] for context about it)
+   * given the join condition and the event time watermark. This is how it works.
+   * - The condition is split into conjunctive predicates, and we find the predicates of the
+   *   form `leftTime + c1 < rightTime + c2`   (or <=, >, >=).
+   * - We canonicalize each such predicate and solve it with the event time watermark value to
+   *   find the value of the state watermark.
+   * - If several predicates produce a state watermark, the smallest one is used, which removes
+   *   the least state.
+   * This function is supposed to make best-effort attempt to get the state watermark. If there is
+   * any error, it will return None.
+   *
+   * @param attributesToFindStateWatermarkFor attributes of the side whose state watermark
+   *                                         is to be calculated
+   * @param attributesWithEventWatermark  attributes of the other side which has a watermark column
+   * @param joinCondition                 join condition
+   * @param eventWatermark                watermark defined on the input event data
+   * @return state value watermark in milliseconds, if possible.
+   */
+  def getStateValueWatermark(
+      attributesToFindStateWatermarkFor: AttributeSet,
+      attributesWithEventWatermark: AttributeSet,
+      joinCondition: Option[Expression],
+      eventWatermark: Option[Long]): Option[Long] = {
+
+    // If condition or event time watermark is not provided, then cannot calculate state watermark
+    if (joinCondition.isEmpty || eventWatermark.isEmpty) return None
+
+    // If there is no watermark attribute, then cannot define state watermark
+    if (!attributesWithEventWatermark.exists(_.metadata.contains(delayKey))) return None
+
+    def getStateWatermarkSafely(l: Expression, r: Expression): Option[Long] = {
+      try {
+        getStateWatermarkFromLessThenPredicate(
+          l, r, attributesToFindStateWatermarkFor, attributesWithEventWatermark, eventWatermark)
+      } catch {
+        case NonFatal(e) =>
+          logWarning(s"Error trying to extract state constraint from condition $joinCondition", e)
+          None
+      }
+    }
+
+    val allStateWatermarks = splitConjunctivePredicates(joinCondition.get).flatMap { predicate =>
+
+      // The generated state watermark cleanup expression is inclusive of the state watermark.
+      // If state watermark is W, all state where timestamp <= W will be cleaned up.
+      // Now when the canonicalized join condition solves to leftTime >= W, we don't want to clean
+      // up leftTime <= W. Rather we should clean up leftTime <= W - 1. Hence the -1 below.
+      val stateWatermark = predicate match {
+        case LessThan(l, r) => getStateWatermarkSafely(l, r)
+        case LessThanOrEqual(l, r) => getStateWatermarkSafely(l, r).map(_ - 1)
+        case GreaterThan(l, r) => getStateWatermarkSafely(r, l)
+        case GreaterThanOrEqual(l, r) => getStateWatermarkSafely(r, l).map(_ - 1)
+        case _ => None
+      }
+      if (stateWatermark.nonEmpty) {
+        logInfo(s"Condition $joinCondition generated watermark constraint = ${stateWatermark.get}")
+      }
+      stateWatermark
+    }
+    allStateWatermarks.reduceOption((x, y) => Math.min(x, y))
+  }
+
+  /**
+   * Extract the state value watermark (milliseconds) from the condition
+   * `LessThan(leftExpr, rightExpr)`. For example, to find the constraint for leftTime using
+   * the watermark on rightTime:
+   *
+   * Input:                 rightTime-with-watermark + c1 < leftTime + c2
+   * Canonical form:        rightTime-with-watermark + c1 + (-c2) + (-leftTime) < 0
+   * Solving for leftTime:  rightTime-with-watermark + c1 + (-c2) < leftTime
+   * With watermark value:  watermark-value + c1 + (-c2) < leftTime
+   *
+   * Returns None when the condition cannot be solved in this form.
+   */
+  private def getStateWatermarkFromLessThenPredicate(
+      leftExpr: Expression,
+      rightExpr: Expression,
+      attributesToFindStateWatermarkFor: AttributeSet,
+      attributesWithEventWatermark: AttributeSet,
+      eventWatermark: Option[Long]): Option[Long] = {
+
+    val attributesInCondition = AttributeSet(
+      leftExpr.collect { case a: AttributeReference => a } ++
+      rightExpr.collect { case a: AttributeReference => a }
+    )
+    if (attributesInCondition.count(a => attributesToFindStateWatermarkFor.contains(a)) > 1 ||
+        attributesInCondition.count(a => attributesWithEventWatermark.contains(a)) > 1) {
+      // If more than one attribute of the same side is present in the condition,
+      // then it cannot be solved
+      return None
+    }
+
+    def containsAttributeToFindStateConstraintFor(e: Expression): Boolean = {
+      e.collectLeaves().collectFirst {
+        case a @ AttributeReference(_, TimestampType, _, _)
+          if attributesToFindStateWatermarkFor.contains(a) => a
+      }.nonEmpty
+    }
+
+    // Canonicalization step 1: convert to (rightTime-with-watermark + c1) - (leftTime + c2) < 0
+    val allOnLeftExpr = Subtract(leftExpr, rightExpr)
+    logDebug(s"All on Left:\n${allOnLeftExpr.treeString(true)}\n${allOnLeftExpr.asCode}")
+
+    // Canonicalization step 2: extract commutative terms
+    //    rightTime-with-watermark, c1, -leftTime, -c2
+    // The terms are deliberately kept as a Seq. A set of expressions would collapse semantically
+    // equal terms (e.g. both intervals in `t - interval 1 hour - interval 1 hour`) and so
+    // silently produce a wrong watermark.
+    val terms = collectTerms(allOnLeftExpr)
+    logDebug("Terms extracted from join condition:\n\t" + terms.mkString("\n\t"))
+
+    // Find the term that has leftTime (i.e. the one present in attributesToFindStateWatermarkFor)
+    val (constraintTerms, otherTerms) =
+      terms.partition(containsAttributeToFindStateConstraintFor)
+
+    // Verify there is only one correct constraint term and of the correct type
+    if (constraintTerms.size > 1) {
+      logWarning("Failed to extract state constraint terms: multiple time terms in condition\n\t" +
+        terms.mkString("\n\t"))
+      return None
+    }
+    if (constraintTerms.isEmpty) {
+      logDebug("Failed to extract state constraint terms: no time terms in condition\n\t" +
+        terms.mkString("\n\t"))
+      return None
+    }
+    val constraintTerm = constraintTerms.head
+    if (constraintTerm.collectFirst { case u: UnaryMinus => u }.isEmpty) {
+      // Incorrect condition. We want the constraint term in canonical form to be `-leftTime`
+      // so that resolve for it as `-leftTime + watermark + c < 0` ==> `watermark + c < leftTime`.
+      // Now, if the original conditions is `rightTime-with-watermark > leftTime` and watermark
+      // condition is `rightTime-with-watermark > watermarkValue`, then no constraint about
+      // `leftTime` can be inferred. In this case, after canonicalization and collection of terms,
+      // the constraintTerm would be `leftTime` and not `-leftTime`. Hence, we return None.
+      return None
+    }
+
+    // Replace watermark attribute with watermark value, and generate the resolved expression
+    // from the other terms. That is,
+    // rightTime-with-watermark, c1, -c2  =>  watermark, c1, -c2  =>  watermark + c1 + (-c2)
+    // The watermark is in milliseconds while the terms are in microseconds, hence the x1000.
+    logDebug(s"Constraint term from join condition:\t$constraintTerm")
+    val exprWithWatermarkSubstituted = otherTerms.map { term =>
+      term.transform {
+        case a @ AttributeReference(_, TimestampType, _, metadata)
+          if attributesWithEventWatermark.contains(a) && metadata.contains(delayKey) =>
+          Multiply(Literal(eventWatermark.get.toDouble), Literal(1000.0))
+      }
+    }.reduceLeftOption(Add)
+
+    // Calculate the constraint value and convert it from microseconds back to milliseconds.
+    // No remaining terms means there is nothing to solve against.
+    exprWithWatermarkSubstituted.map { expr =>
+      logInfo(s"Final expression to evaluate constraint:\t$expr")
+      val constraintValue = expr.eval().asInstanceOf[java.lang.Double]
+      (Double2double(constraintValue) / 1000.0).toLong
+    }
+  }
+
+  /**
+   * Collect all the terms present in an expression after converting it into the form
+   * a + b + c + d where each term is either an attribute or a literal cast to double,
+   * optionally wrapped in a unary minus. Returns an empty sequence if any part of the
+   * expression cannot be converted.
+   */
+  private def collectTerms(exprToCollectFrom: Expression): Seq[Expression] = {
+    var invalid = false
+
+    /** Wrap a term with UnaryMinus if it needs to be negated. */
+    def negateIfNeeded(expr: Expression, minus: Boolean): Expression = {
+      if (minus) UnaryMinus(expr) else expr
+    }
+
+    /**
+     * Recursively split the expression into its leaf terms containing attributes or literals.
+     * Returns terms only of the forms:
+     *    Cast(AttributeReference), UnaryMinus(Cast(AttributeReference)),
+     *    Cast(AttributeReference, Double), UnaryMinus(Cast(AttributeReference, Double))
+     *    Multiply(Literal), UnaryMinus(Multiply(Literal))
+     *    Multiply(Cast(Literal)), UnaryMinus(Multiply(Cast(Literal)))
+     *
+     * Note:
+     * - If term needs to be negated for making it a commutative term,
+     *   then it will be wrapped in UnaryMinus(...)
+     * - Each term represents a timestamp value or time interval in microseconds,
+     *   typed as doubles.
+     */
+    def collect(expr: Expression, negate: Boolean): Seq[Expression] = {
+      expr match {
+        case Add(left, right) =>
+          collect(left, negate) ++ collect(right, negate)
+        case Subtract(left, right) =>
+          collect(left, negate) ++ collect(right, !negate)
+        case TimeAdd(left, right, _) =>
+          collect(left, negate) ++ collect(right, negate)
+        case TimeSub(left, right, _) =>
+          collect(left, negate) ++ collect(right, !negate)
+        case UnaryMinus(child) =>
+          collect(child, !negate)
+        case CheckOverflow(child, _) =>
+          collect(child, negate)
+        case Cast(child, dataType, _) =>
+          dataType match {
+            case _: NumericType | _: TimestampType => collect(child, negate)
+            case _ =>
+              invalid = true
+              Seq.empty
+          }
+        case a: AttributeReference =>
+          val castedRef = if (a.dataType != DoubleType) Cast(a, DoubleType) else a
+          Seq(negateIfNeeded(castedRef, negate))
+        case lit: Literal =>
+          // If literal of type calendar interval, then explicitly convert to microseconds.
+          // Other number-like literals are treated as seconds and converted to doubles
+          // representing microseconds (by x1000000).
+          val castedLit = lit.dataType match {
+            case CalendarIntervalType =>
+              val calendarInterval = lit.value.asInstanceOf[CalendarInterval]
+              // Any non-zero month component, positive or negative, has no fixed length
+              if (calendarInterval.months != 0) {
+                invalid = true
+                logWarning(
+                  s"Failed to extract state value watermark from condition $exprToCollectFrom " +
+                    s"as imprecise intervals like months and years cannot be used for " +
+                    s"watermark calculation. Use interval in terms of day instead.")
+                Literal(0.0)
+              } else {
+                Literal(calendarInterval.microseconds.toDouble)
+              }
+            case DoubleType =>
+              Multiply(lit, Literal(1000000.0))
+            case _: NumericType =>
+              Multiply(Cast(lit, DoubleType), Literal(1000000.0))
+            case _: TimestampType =>
+              // NOTE(review): the converted timestamp value looks like it is already in
+              // microseconds, so the extra x1000000 seems inconsistent with other terms - confirm.
+              Multiply(PreciseTimestampConversion(lit, TimestampType, LongType), Literal(1000000.0))
+            case unsupportedType =>
+              // Other literal types cannot take part in a time constraint
+              invalid = true
+              logWarning(
+                s"Failed to extract state value watermark from condition $exprToCollectFrom " +
+                  s"due to literal $lit of unsupported type $unsupportedType")
+              Literal(0.0)
+          }
+          Seq(negateIfNeeded(castedLit, negate))
+        case a @ _ =>
+          logWarning(
+            s"Failed to extract state value watermark from condition $exprToCollectFrom due to $a")
+          invalid = true
+          Seq.empty
+      }
+    }
+
+    val terms = collect(exprToCollectFrom, negate = false)
+    if (!invalid) terms else Seq.empty
+  }
+
+  /**
+   * A custom RDD that allows partitions to be "zipped" together, while ensuring the tasks'
+   * preferred location is based on which executors have the required join state stores already
+   * loaded. This class is a modified version of [[ZippedPartitionsRDD2]].
+   *
+   * @param stateInfo        state info of the stateful operator owning the stores
+   * @param stateStoreNames  names of the stores whose locations drive the preferred locations
+   * @param storeCoordinator coordinator queried for the executor hosting each store, if any
+   */
+  class StateStoreAwareZipPartitionsRDD[A: ClassTag, B: ClassTag, V: ClassTag](
+      sc: SparkContext,
+      f: (Iterator[A], Iterator[B]) => Iterator[V],
+      rdd1: RDD[A],
+      rdd2: RDD[B],
+      stateInfo: StatefulOperatorStateInfo,
+      stateStoreNames: Seq[String],
+      @transient private val storeCoordinator: Option[StateStoreCoordinatorRef])
+      extends ZippedPartitionsRDD2[A, B, V](sc, f, rdd1, rdd2) {
+
+    /**
+     * Set the preferred location of each partition using the executor that has the related
+     * [[StateStoreProvider]] already loaded.
+     */
+    override def getPreferredLocations(partition: Partition): Seq[String] = {
+      stateStoreNames.flatMap { storeName =>
+        val stateStoreProviderId = StateStoreProviderId(stateInfo, partition.index, storeName)
+        storeCoordinator.flatMap(_.getLocation(stateStoreProviderId))
+      }.distinct
+    }
+  }
+
+  implicit class StateStoreAwareZipPartitionsHelper[T: ClassTag](dataRDD: RDD[T]) {
+    /**
+     * Function used by `StreamingSymmetricHashJoinExec` to zip together the partitions of two
+     * child RDDs for joining the data in corresponding partitions, while ensuring the tasks'
+     * preferred location is based on which executors have the required join state stores already
+     * loaded.
+     */
+    def stateStoreAwareZipPartitions[U: ClassTag, V: ClassTag](
+        dataRDD2: RDD[U],
+        stateInfo: StatefulOperatorStateInfo,
+        storeNames: Seq[String],
+        storeCoordinator: StateStoreCoordinatorRef
+      )(f: (Iterator[T], Iterator[U]) => Iterator[V]): RDD[V] = {
+      new StateStoreAwareZipPartitionsRDD(
+        dataRDD.sparkContext, f, dataRDD, dataRDD2, stateInfo, storeNames, Some(storeCoordinator))
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/f32a8425/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
index 182fc27..6fe632f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
@@ -30,6 +30,7 @@ import org.apache.hadoop.fs.Path
 import org.apache.spark.{SparkContext, SparkEnv}
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.execution.streaming.StatefulOperatorStateInfo
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.{ThreadUtils, Utils}
 
@@ -120,6 +121,15 @@ case class StateStoreMetrics(
     memoryUsedBytes: Long,
     customMetrics: Map[StateStoreCustomMetric, Long])
 
+object StateStoreMetrics {
+  /**
+   * Combine the metrics of several state stores into a single [[StateStoreMetrics]].
+   * Key counts and memory usage are summed. Custom metrics are summed per metric, so stores
+   * reporting the same custom metric no longer silently overwrite each other.
+   *
+   * @param allMetrics metrics of the individual stores (may be empty, giving all-zero metrics)
+   */
+  def combine(allMetrics: Seq[StateStoreMetrics]): StateStoreMetrics = {
+    val combinedCustomMetrics = allMetrics
+      .flatMap(_.customMetrics)
+      .groupBy { case (metric, _) => metric }
+      .map { case (metric, entries) => metric -> entries.map(_._2).sum }
+    StateStoreMetrics(
+      allMetrics.map(_.numKeys).sum,
+      allMetrics.map(_.memoryUsedBytes).sum,
+      combinedCustomMetrics)
+  }
+}
+
 /**
  * Name and description of custom implementation-specific metrics that a
  * state store may wish to expose.
@@ -227,6 +237,17 @@ object StateStoreProvider {
  */
 case class StateStoreProviderId(storeId: StateStoreId, queryRunId: UUID)
 
+object StateStoreProviderId {
+  /**
+   * Build the provider id of the state store that serves one partition of a stateful operator.
+   *
+   * @param stateInfo      operator state info supplying checkpoint location, operator id
+   *                       and query run id
+   * @param partitionIndex index of the partition whose store is identified
+   * @param storeName      name of the store within the operator
+   */
+  private[sql] def apply(
+      stateInfo: StatefulOperatorStateInfo,
+      partitionIndex: Int,
+      storeName: String): StateStoreProviderId =
+    StateStoreProviderId(
+      StateStoreId(stateInfo.checkpointLocation, stateInfo.operatorId, partitionIndex, storeName),
+      stateInfo.queryRunId)
+}
+
 /**
  * Unique identifier for a bunch of keyed state data.
  * @param checkpointRootLocation Root directory where all the state data of a query is stored

http://git-wip-us.apache.org/repos/asf/spark/blob/f32a8425/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala
index 3884f5e..2b14d37 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala
@@ -84,7 +84,7 @@ object StateStoreCoordinatorRef extends Logging {
  */
 class StateStoreCoordinatorRef private(rpcEndpointRef: RpcEndpointRef) {
 
-  private[state] def reportActiveInstance(
+  private[sql] def reportActiveInstance(
       stateStoreProviderId: StateStoreProviderId,
       host: String,
       executorId: String): Unit = {
@@ -92,14 +92,14 @@ class StateStoreCoordinatorRef private(rpcEndpointRef: RpcEndpointRef) {
   }
 
   /** Verify whether the given executor has the active instance of a state store */
-  private[state] def verifyIfInstanceActive(
+  private[sql] def verifyIfInstanceActive(
       stateStoreProviderId: StateStoreProviderId,
       executorId: String): Boolean = {
     rpcEndpointRef.askSync[Boolean](VerifyIfInstanceActive(stateStoreProviderId, executorId))
   }
 
   /** Get the location of the state store */
-  private[state] def getLocation(stateStoreProviderId: StateStoreProviderId): Option[String] = {
+  private[sql] def getLocation(stateStoreProviderId: StateStoreProviderId): Option[String] = {
     rpcEndpointRef.askSync[Option[String]](GetLocation(stateStoreProviderId))
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/f32a8425/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala
new file mode 100644
index 0000000..3764871
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala
@@ -0,0 +1,395 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.state
+
+import org.apache.hadoop.conf.Configuration
+
+import org.apache.spark.TaskContext
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, Literal, SpecificInternalRow, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.execution.streaming.{StatefulOperatorStateInfo, StreamingSymmetricHashJoinExec}
+import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper._
+import org.apache.spark.sql.types.{LongType, StructField, StructType}
+import org.apache.spark.util.NextIterator
+
+/**
+ * Helper class to manage state required by a single side of [[StreamingSymmetricHashJoinExec]].
+ * The interface of this class is basically that of a multi-map:
+ * - Get: Returns an iterator of multiple values for given key
+ * - Append: Append a new value to the given key
+ * - Remove Data by predicate: Drop any state using a predicate condition on keys or values
+ *
+ * @param joinSide              Defines the join side
+ * @param inputValueAttributes  Attributes of the input row which will be stored as value
+ * @param joinKeys              Expressions to generate rows that will be used to key the value rows
+ * @param stateInfo             Information about how to retrieve the correct version of state
+ * @param storeConf             Configuration for the state store.
+ * @param hadoopConf            Hadoop configuration for reading state data from storage
+ *
+ * Internally, the key -> multiple values is stored in two [[StateStore]]s.
+ * - Store 1 ([[KeyToNumValuesStore]]) maintains mapping between key -> number of values
+ * - Store 2 ([[KeyWithIndexToValueStore]]) maintains mapping between (key, index) -> value
+ * - Put:   update count in KeyToNumValuesStore,
+ *          insert new (key, count) -> value in KeyWithIndexToValueStore
+ * - Get:   read count from KeyToNumValuesStore,
+ *          read each of the n values in KeyWithIndexToValueStore
+ * - Remove state by predicate on keys:
+ *          scan all keys in KeyToNumValuesStore to find keys that match the predicate,
+ *          delete the key from KeyToNumValuesStore, delete its values in KeyWithIndexToValueStore
+ * - Remove state by condition on values:
+ *          for each key in KeyToNumValuesStore, scan its values (key, 0) ... (key, maxIndex) in
+ *          KeyWithIndexToValueStore to find values that match the predicate; delete each matching
+ *          (key, indexToDelete) by overwriting it with the value of (key, maxIndex) and removing
+ *          (key, maxIndex); then update the key's num values in KeyToNumValuesStore (removing
+ *          the key entirely if no values remain)
+ */
+class SymmetricHashJoinStateManager(
+    val joinSide: JoinSide,
+    inputValueAttributes: Seq[Attribute],
+    joinKeys: Seq[Expression],
+    stateInfo: Option[StatefulOperatorStateInfo],
+    storeConf: StateStoreConf,
+    hadoopConf: Configuration) extends Logging {
+
+  import SymmetricHashJoinStateManager._
+
+  /*
+  =====================================================
+                  Public methods
+  =====================================================
+   */
+
+  /**
+   * Get all the values of a key: looks up how many values are recorded for `key` and fetches
+   * that many values from the (key, index) -> value store.
+   */
+  def get(key: UnsafeRow): Iterator[UnsafeRow] =
+    keyWithIndexToValue.getAll(key, keyToNumValues.get(key))
+
+  /**
+   * Append a new value to the key. The value is stored at index = current number of values,
+   * after which the key's count is incremented by one.
+   */
+  def append(key: UnsafeRow, value: UnsafeRow): Unit = {
+    val nextIndex = keyToNumValues.get(key)
+    keyWithIndexToValue.put(key, nextIndex, value)
+    keyToNumValues.put(key, nextIndex + 1)
+  }
+
+  /**
+   * Remove using a predicate on keys: every key for which `condition` returns true is removed
+   * from the key -> count store together with all of its values.
+   */
+  def removeByKeyCondition(condition: UnsafeRow => Boolean): Unit = {
+    // NOTE(review): entries are removed while the key -> count store is being iterated;
+    // confirm the underlying store iterator tolerates concurrent removal.
+    keyToNumValues.iterator.foreach { entry =>
+      if (condition(entry.key)) {
+        keyToNumValues.remove(entry.key)
+        keyWithIndexToValue.removeAllValues(entry.key, entry.numValue)
+      }
+    }
+  }
+
+  /**
+   * Remove every stored value for which `condition` returns true.
+   *
+   * For each key, the values at indices 0 until numValues are checked in order. A matching value
+   * at `index` is removed by moving the value at the last index (numValues - 1) into `index` and
+   * deleting the last index, so the remaining indices stay contiguous; the moved value is then
+   * checked at the same `index`. If any value of a key was removed, the key's count is updated,
+   * or the key is removed entirely when no values remain.
+   */
+  def removeByValueCondition(condition: UnsafeRow => Boolean): Unit = {
+    // NOTE(review): keyToNumValues is modified (put/remove) while being iterated below;
+    // confirm the underlying store iterator tolerates this.
+    val allKeyToNumValues = keyToNumValues.iterator
+
+    while (allKeyToNumValues.hasNext) {
+      val keyToNumValue = allKeyToNumValues.next
+      val key = keyToNumValue.key
+
+      var numValues: Long = keyToNumValue.numValue
+      var index: Long = 0L
+      var valueRemoved: Boolean = false
+      // Value at `index` still to be checked; null means it must be (re)loaded from the store
+      var valueForIndex: UnsafeRow = null
+
+      while (index < numValues) {
+        if (valueForIndex == null) {
+          valueForIndex = keyWithIndexToValue.get(key, index)
+        }
+        if (condition(valueForIndex)) {
+          if (numValues > 1) {
+            // Fill the hole at `index` with the last value and drop the last index
+            val valueAtMaxIndex = keyWithIndexToValue.get(key, numValues - 1)
+            keyWithIndexToValue.put(key, index, valueAtMaxIndex)
+            keyWithIndexToValue.remove(key, numValues - 1)
+            // The moved value is checked at the same index next (if any indices remain)
+            valueForIndex = valueAtMaxIndex
+          } else {
+            // Only one value left, necessarily at index 0: just remove it
+            keyWithIndexToValue.remove(key, 0)
+            valueForIndex = null
+          }
+          numValues -= 1
+          valueRemoved = true
+        } else {
+          // Keep this value and move on to the next index
+          valueForIndex = null
+          index += 1
+        }
+      }
+      // Update the key's count only if something was removed for this key
+      if (valueRemoved) {
+        if (numValues >= 1) {
+          keyToNumValues.put(key, numValues)
+        } else {
+          keyToNumValues.remove(key)
+        }
+      }
+    }
+  }
+
+  /**
+   * Iterate over all buffered (key, value) rows of the (key, index) -> value store.
+   * A single [[UnsafeRowPair]] is allocated and `withRows` is applied to it for every element;
+   * presumably the same instance is returned each time, so callers should copy it if they keep
+   * it beyond one step (TODO confirm against UnsafeRowPair).
+   */
+  def iterator(): Iterator[UnsafeRowPair] = {
+    val reusedPair = new UnsafeRowPair()
+    keyWithIndexToValue.iterator.map(kv => reusedPair.withRows(kv.key, kv.value))
+  }
+
+  /**
+   * Commit all the changes to both state stores: first the key -> number of values store,
+   * then the (key, index) -> value store.
+   */
+  def commit(): Unit = {
+    keyToNumValues.commit()
+    keyWithIndexToValue.commit()
+  }
+
+  /**
+   * Abort any uncommitted changes to both state stores. Each store handler only aborts if its
+   * store has not been committed, so this is safe to call after a successful commit.
+   */
+  def abortIfNeeded(): Unit = {
+    keyToNumValues.abortIfNeeded()
+    keyWithIndexToValue.abortIfNeeded()
+  }
+
+  /**
+   * Get the combined metrics of both state stores:
+   *  - numKeys is taken from the (key, index) -> value store only, so each buffered row is
+   *    represented exactly once;
+   *  - memoryUsedBytes is the sum over both stores;
+   *  - custom metrics come from the (key, index) -> value store, with each description
+   *    prefixed by the upper-cased join side (e.g. "LEFT: ...").
+   */
+  def metrics: StateStoreMetrics = {
+    val keyToNumValuesMetrics = keyToNumValues.metrics
+    val keyWithIndexToValueMetrics = keyWithIndexToValue.metrics
+    // Locale.ROOT keeps the prefix independent of the JVM default locale; e.g. in a Turkish
+    // locale a plain toUpperCase would turn "right" into "RİGHT".
+    def newDesc(desc: String): String =
+      s"${joinSide.toString.toUpperCase(java.util.Locale.ROOT)}: $desc"
+
+    StateStoreMetrics(
+      keyWithIndexToValueMetrics.numKeys,       // represent each buffered row only once
+      keyToNumValuesMetrics.memoryUsedBytes + keyWithIndexToValueMetrics.memoryUsedBytes,
+      keyWithIndexToValueMetrics.customMetrics.map {
+        case (s @ StateStoreCustomSizeMetric(_, desc), value) =>
+          s.copy(desc = newDesc(desc)) -> value
+        case (s @ StateStoreCustomTimingMetric(_, desc), value) =>
+          s.copy(desc = newDesc(desc)) -> value
+      }
+    )
+  }
+
+  /*
+  =====================================================
+            Private methods and inner classes
+  =====================================================
+   */
+
+  // NOTE: class-body fields are initialized in declaration order. The two store wrappers below
+  // load their state stores when constructed and read keySchema / keyAttributes while doing so,
+  // so they must stay declared after those two fields.
+
+  // Schema of the join key row: one field per join key expression, named field0, field1, ...
+  private val keySchema = StructType(
+    joinKeys.zipWithIndex.map { case (k, i) => StructField(s"field$i", k.dataType, k.nullable) })
+  private val keyAttributes = keySchema.toAttributes
+  // key -> number of values buffered for that key
+  private val keyToNumValues = new KeyToNumValuesStore()
+  // (key, index) -> buffered value
+  private val keyWithIndexToValue = new KeyWithIndexToValueStore()
+
+  // Clean up any state store resources if necessary at the end of the task
+  // (abortIfNeeded skips stores that were already committed). TaskContext.get() is wrapped in
+  // Option since it may be null, presumably when not running inside a task (e.g. unit tests).
+  Option(TaskContext.get()).foreach { _.addTaskCompletionListener { _ => abortIfNeeded() } }
+
+  /**
+   * Base class for the state store wrappers of this manager. Each subclass owns exactly one
+   * [[StateStore]] of the given type and gets commit / abort / metrics for it, plus a helper
+   * that loads the store for the current task's partition.
+   */
+  private abstract class StateStoreHandler(stateStoreType: StateStoreType) extends Logging {
+
+    /** The StateStore that a concrete handler reads from and writes to. */
+    protected def stateStore: StateStore
+
+    /** Commit the store, then log its metrics at debug level. */
+    def commit(): Unit = {
+      stateStore.commit()
+      logDebug("Committed, metrics = " + stateStore.metrics)
+    }
+
+    /** Abort the store, unless it has already been committed. */
+    def abortIfNeeded(): Unit = {
+      val needsAbort = !stateStore.hasCommitted
+      if (needsAbort) {
+        logInfo(s"Aborted store ${stateStore.id}")
+        stateStore.abort()
+      }
+    }
+
+    /** Metrics reported by the underlying store. */
+    def metrics: StateStoreMetrics = stateStore.metrics
+
+    /**
+     * Load the store for the current task's partition with the given key and value schemas.
+     * The store is named after the join side and this handler's store type.
+     */
+    protected def getStateStore(keySchema: StructType, valueSchema: StructType): StateStore = {
+      val operatorInfo = stateInfo.get
+      val partitionId = TaskContext.getPartitionId()
+      val storeName = getStateStoreName(joinSide, stateStoreType)
+      val loadedStore = StateStore.get(
+        StateStoreProviderId(operatorInfo, partitionId, storeName),
+        keySchema,
+        valueSchema,
+        None,
+        operatorInfo.storeVersion,
+        storeConf,
+        hadoopConf)
+      logInfo(s"Loaded store ${loadedStore.id}")
+      loadedStore
+    }
+  }
+
+  /**
+   * Helper class for representing data returned by [[KeyToNumValuesStore]]: a key row and the
+   * number of values stored for it. Designed for object reuse: `withNew` overwrites both
+   * fields in place and returns this same instance.
+   */
+  private case class KeyAndNumValues(var key: UnsafeRow = null, var numValue: Long = 0) {
+    def withNew(newKey: UnsafeRow, newNumValues: Long): this.type = {
+      this.key = newKey
+      this.numValue = newNumValues
+      this
+    }
+  }
+
+
+  /** A wrapper around a [[StateStore]] that stores [key -> number of values]. */
+  private class KeyToNumValuesStore extends StateStoreHandler(KeyToNumValuesType) {
+    private val longValueSchema = new StructType().add("value", "long")
+    private val longToUnsafeRow = UnsafeProjection.create(longValueSchema)
+    // A single mutable value row, overwritten by every `put`.
+    // NOTE(review): this assumes stateStore.put copies the row it is given -- confirm.
+    private val valueRow = longToUnsafeRow(new SpecificInternalRow(longValueSchema))
+    // Must be declared after longValueSchema, which is used to load the store.
+    protected val stateStore: StateStore = getStateStore(keySchema, longValueSchema)
+
+    /** Get the number of values the key has, or 0 if the key is not present. */
+    def get(key: UnsafeRow): Long = {
+      val longValueRow = stateStore.get(key)
+      if (longValueRow != null) longValueRow.getLong(0) else 0L
+    }
+
+    /**
+     * Set the number of values the key has. `numValues` must be positive; use [[remove]] for
+     * a key that no longer has any values.
+     */
+    def put(key: UnsafeRow, numValues: Long): Unit = {
+      require(numValues > 0, s"Number of values must be positive, but got $numValues")
+      valueRow.setLong(0, numValues)
+      stateStore.put(key, valueRow)
+    }
+
+    /** Remove the key together with its value count. */
+    def remove(key: UnsafeRow): Unit = {
+      stateStore.remove(key)
+    }
+
+    /**
+     * Iterate over all (key, number of values) entries. The returned [[KeyAndNumValues]] is a
+     * single reused instance, so copy anything that must outlive the next call to `next()`.
+     */
+    def iterator: Iterator[KeyAndNumValues] = {
+      val keyAndNumValues = new KeyAndNumValues()
+      stateStore.getRange(None, None).map { pair =>
+        keyAndNumValues.withNew(pair.key, pair.value.getLong(0))
+      }
+    }
+  }
+
+  /**
+   * Mutable holder for one entry returned by [[KeyWithIndexToValueStore]]: the key row, the
+   * value's index under that key, and the value row. Meant to be reused: `withNew` overwrites
+   * all three fields and hands back the same instance.
+   */
+  private case class KeyWithIndexAndValue(
+      var key: UnsafeRow = null,
+      var valueIndex: Long = -1,
+      var value: UnsafeRow = null) {
+
+    def withNew(newKey: UnsafeRow, newIndex: Long, newValue: UnsafeRow): this.type = {
+      key = newKey
+      valueIndex = newIndex
+      value = newValue
+      this
+    }
+  }
+
+  /** A wrapper around a [[StateStore]] that stores [(key, index) -> value]. */
+  private class KeyWithIndexToValueStore extends StateStoreHandler(KeyWithIndexToValuesType) {
+    // The index is appended to the key as an extra trailing long column. Literal(1L) is only a
+    // placeholder; keyWithIndexRow overwrites it on every call.
+    private val keyWithIndexExprs = keyAttributes :+ Literal(1L)
+    private val keyWithIndexSchema = keySchema.add("index", LongType)
+    private val indexOrdinalInKeyWithIndexRow = keyAttributes.size
+
+    // Projection to generate (key + index) row from key row
+    private val keyWithIndexRowGenerator = UnsafeProjection.create(keyWithIndexExprs, keyAttributes)
+
+    // Projection to generate key row from (key + index) row
+    private val keyRowGenerator = UnsafeProjection.create(
+      keyAttributes, keyAttributes :+ AttributeReference("index", LongType)())
+
+    protected val stateStore: StateStore =
+      getStateStore(keyWithIndexSchema, inputValueAttributes.toStructType)
+
+    /** Get the value stored for the key at the given index (null if the store has none). */
+    def get(key: UnsafeRow, valueIndex: Long): UnsafeRow = {
+      stateStore.get(keyWithIndexRow(key, valueIndex))
+    }
+
+    /**
+     * Get all the values for key at indices [0, numValues), in index order. Each value is
+     * looked up lazily as the iterator advances.
+     */
+    def getAll(key: UnsafeRow, numValues: Long): Iterator[UnsafeRow] = {
+      // Long (not Int) to match the Long value count and avoid Int overflow.
+      var index = 0L
+      new NextIterator[UnsafeRow] {
+        override protected def getNext(): UnsafeRow = {
+          if (index >= numValues) {
+            finished = true
+            null
+          } else {
+            val keyWithIndex = keyWithIndexRow(key, index)
+            val value = stateStore.get(keyWithIndex)
+            index += 1
+            value
+          }
+        }
+
+        override protected def close(): Unit = {}
+      }
+    }
+
+    /** Put new value for key at the given index */
+    def put(key: UnsafeRow, valueIndex: Long, value: UnsafeRow): Unit = {
+      val keyWithIndex = keyWithIndexRow(key, valueIndex)
+      stateStore.put(keyWithIndex, value)
+    }
+
+    /**
+     * Remove key and value at given index. Note that this will create a hole in
+     * (key, index) and it is up to the caller to deal with it.
+     */
+    def remove(key: UnsafeRow, valueIndex: Long): Unit = {
+      stateStore.remove(keyWithIndexRow(key, valueIndex))
+    }
+
+    /** Remove all values (i.e. all the indices in [0, numValues)) for the given key. */
+    def removeAllValues(key: UnsafeRow, numValues: Long): Unit = {
+      // Long (not Int) to match the Long value count and avoid Int overflow.
+      var index = 0L
+      while (index < numValues) {
+        stateStore.remove(keyWithIndexRow(key, index))
+        index += 1
+      }
+    }
+
+    /**
+     * Iterate over all entries, splitting each stored (key + index) row back into key and
+     * index. The returned [[KeyWithIndexAndValue]] is a single reused instance, and its key row
+     * comes from a projection that presumably reuses its output row, so copy what you retain.
+     */
+    def iterator: Iterator[KeyWithIndexAndValue] = {
+      val keyWithIndexAndValue = new KeyWithIndexAndValue()
+      stateStore.getRange(None, None).map { pair =>
+        keyWithIndexAndValue.withNew(
+          keyRowGenerator(pair.key), pair.key.getLong(indexOrdinalInKeyWithIndexRow), pair.value)
+      }
+    }
+
+    /**
+     * Generate a (key + index) row from the key and index. The row is produced by the shared
+     * keyWithIndexRowGenerator projection and then mutated in place; presumably the projection
+     * reuses its output row, so use the result before the next call.
+     */
+    private def keyWithIndexRow(key: UnsafeRow, valueIndex: Long): UnsafeRow = {
+      val row = keyWithIndexRowGenerator(key)
+      row.setLong(indexOrdinalInKeyWithIndexRow, valueIndex)
+      row
+    }
+  }
+}
+
+object SymmetricHashJoinStateManager {
+
+  /** Names of all the state stores (one per store type) used by the given join sides. */
+  def allStateStoreNames(joinSides: JoinSide*): Seq[String] = {
+    val allStateStoreTypes: Seq[StateStoreType] = Seq(KeyToNumValuesType, KeyWithIndexToValuesType)
+    for (joinSide <- joinSides; stateStoreType <- allStateStoreTypes) yield {
+      getStateStoreName(joinSide, stateStoreType)
+    }
+  }
+
+  /** Kind of state store kept per join side; its `toString` is part of the store name. */
+  private sealed trait StateStoreType
+
+  /** Store of key -> number of values buffered for that key. */
+  private case object KeyToNumValuesType extends StateStoreType {
+    override def toString(): String = "keyToNumValues"
+  }
+
+  /** Store of (key, index) -> value. */
+  private case object KeyWithIndexToValuesType extends StateStoreType {
+    // Named after what the store actually holds. NOTE: this string becomes part of the state
+    // store name (see getStateStoreName), so changing it once checkpoints exist would
+    // presumably make their stored state unreachable.
+    override def toString(): String = "keyWithIndexToValue"
+  }
+
+  /** Name of the state store for the given join side and store type: "<side>-<storeType>". */
+  private def getStateStoreName(joinSide: JoinSide, storeType: StateStoreType): String = {
+    s"$joinSide-$storeType"
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/f32a8425/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
index a0086e2..0b32327 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
@@ -17,8 +17,6 @@
 
 package org.apache.spark.sql.execution.streaming
 
-import java.util.UUID
-
 import scala.reflect.ClassTag
 
 import org.apache.spark.TaskContext

http://git-wip-us.apache.org/repos/asf/spark/blob/f32a8425/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
index d6566b8..fb960fb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
@@ -43,7 +43,12 @@ case class StatefulOperatorStateInfo(
     checkpointLocation: String,
     queryRunId: UUID,
     operatorId: Long,
-    storeVersion: Long)
+    storeVersion: Long) {
+  /** Compact one-line description of the checkpoint location, run id, operator id and version. */
+  override def toString(): String =
+    s"state info [ checkpoint = $checkpointLocation, runId = $queryRunId, " +
+      s"opId = $operatorId, ver = $storeVersion]"
+}
 
 /**
  * An operator that reads or writes state from the [[StateStore]].
@@ -133,26 +138,9 @@ trait WatermarkSupport extends UnaryExecNode {
 
   /** Generate an expression that matches data older than the watermark */
   lazy val watermarkExpression: Option[Expression] = {
-    val optionalWatermarkAttribute =
-      child.output.find(_.metadata.contains(EventTimeWatermark.delayKey))
-
-    optionalWatermarkAttribute.map { watermarkAttribute =>
-      // If we are evicting based on a window, use the end of the window.  Otherwise just
-      // use the attribute itself.
-      val evictionExpression =
-        if (watermarkAttribute.dataType.isInstanceOf[StructType]) {
-          LessThanOrEqual(
-            GetStructField(watermarkAttribute, 1),
-            Literal(eventTimeWatermark.get * 1000))
-        } else {
-          LessThanOrEqual(
-            watermarkAttribute,
-            Literal(eventTimeWatermark.get * 1000))
-        }
-
-      logInfo(s"Filtering state store on: $evictionExpression")
-      evictionExpression
-    }
+    // Compare the first child output attribute tagged with the event-time watermark metadata
+    // against the current watermark; None when there is no such attribute or no watermark.
+    // NOTE(review): the "Filtering state store on: ..." log line from the old version is gone.
+    WatermarkSupport.watermarkExpression(
+      child.output.find(_.metadata.contains(EventTimeWatermark.delayKey)),
+      eventTimeWatermark)
   }
 
   /** Predicate based on keys that matches data older than the watermark */
@@ -179,6 +167,31 @@ trait WatermarkSupport extends UnaryExecNode {
   }
 }
 
+object WatermarkSupport {
+
+  /**
+   * Generate an expression on the given attribute that matches data older than the watermark.
+   *
+   * @param optionalWatermarkExpression the event-time column to compare; a struct-typed column
+   *                                    (a time window) is compared on its field 1, the window end
+   * @param optionalWatermarkMs the current event-time watermark in milliseconds
+   * @return `Some(column <= watermark)` when both inputs are defined, otherwise `None`
+   */
+  def watermarkExpression(
+      optionalWatermarkExpression: Option[Expression],
+      optionalWatermarkMs: Option[Long]): Option[Expression] = {
+    for {
+      watermarkAttribute <- optionalWatermarkExpression
+      watermarkMs <- optionalWatermarkMs
+    } yield {
+      // If we are evicting based on a window, use the end of the window.  Otherwise just
+      // use the attribute itself.
+      val evictionColumn: Expression = watermarkAttribute.dataType match {
+        case _: StructType => GetStructField(watermarkAttribute, 1)
+        case _ => watermarkAttribute
+      }
+      // ms -> presumably the microsecond resolution used for timestamp values
+      LessThanOrEqual(evictionColumn, Literal(watermarkMs * 1000))
+    }
+  }
+}
+
 /**
  * For each input tuple, the key is calculated and the value from the [[StateStore]] is added
  * to the stream (in addition to the input tuple) if present.

http://git-wip-us.apache.org/repos/asf/spark/blob/f32a8425/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala
new file mode 100644
index 0000000..ffa4c3c
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala
@@ -0,0 +1,172 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.state
+
+import java.util.UUID
+
+import org.apache.hadoop.conf.Configuration
+import org.scalatest.BeforeAndAfter
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, Expression, GenericInternalRow, LessThanOrEqual, Literal, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate
+import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark
+import org.apache.spark.sql.execution.streaming.StatefulOperatorStateInfo
+import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.LeftSide
+import org.apache.spark.sql.streaming.StreamTest
+import org.apache.spark.sql.types._
+
+class SymmetricHashJoinStateManagerSuite extends StreamTest with BeforeAndAfter {
+
+  before {
+    SparkSession.setActiveSession(spark) // make `spark` active before any state store is loaded
+    spark.streams.stateStoreCoordinator // initialize the lazy coordinator
+  }
+
+  test("SymmetricHashJoinStateManager - all operations") {
+    withJoinStateManager(inputValueAttribs, joinKeyExprs) { manager =>
+      implicit val mgr = manager
+
+      assert(get(20) === Seq.empty)     // initially empty
+      append(20, 2)
+      assert(get(20) === Seq(2))        // should get first value correctly
+      assert(numRows === 1)
+
+      append(20, 3)
+      assert(get(20) === Seq(2, 3))     // should append new values
+      append(20, 3)
+      assert(get(20) === Seq(2, 3, 3))  // should append another copy if same value added again
+      assert(numRows === 3)
+
+      assert(get(30) === Seq.empty)
+      append(30, 1)
+      assert(get(30) === Seq(1))
+      assert(get(20) === Seq(2, 3, 3))  // add another key-value should not affect existing ones
+      assert(numRows === 4)
+
+      removeByKey(25)
+      assert(get(20) === Seq.empty)
+      assert(get(30) === Seq(1))        // should remove 20, not 30
+      assert(numRows === 1)
+
+      removeByKey(30)
+      assert(get(30) === Seq.empty)     // should remove 30
+      assert(numRows === 0)
+
+      // `values` must be given in sorted order, since `get` returns sorted values
+      def appendAndTest(key: Int, values: Int*): Unit = {
+        values.foreach { value => append(key, value) }
+        assert(get(key) === values)
+      }
+
+      appendAndTest(40, 100, 200, 300)
+      appendAndTest(50, 125)
+      appendAndTest(60, 275)              // prepare for testing removeByValue
+      assert(numRows === 5)
+
+      removeByValue(125)
+      assert(get(40) === Seq(200, 300))
+      assert(get(50) === Seq.empty)
+      assert(get(60) === Seq(275))        // should remove only some values, not all
+      assert(numRows === 3)
+
+      append(40, 50)
+      assert(get(40) === Seq(50, 200, 300))
+      assert(numRows === 4)
+
+      removeByValue(200)
+      assert(get(40) === Seq(300))
+      assert(get(60) === Seq(275))        // should remove only some values, not all
+      assert(numRows === 2)
+
+      removeByValue(300)
+      assert(get(40) === Seq.empty)
+      assert(get(60) === Seq.empty)       // should remove all values now
+      assert(numRows === 0)
+    }
+  }
+
+  val watermarkMetadata = new MetadataBuilder().putLong(EventTimeWatermark.delayKey, 10).build()
+  val inputValueSchema = new StructType()
+    .add(StructField("time", IntegerType, metadata = watermarkMetadata))
+    .add(StructField("value", BooleanType))
+  val inputValueAttribs = inputValueSchema.toAttributes
+  val inputValueAttribWithWatermark = inputValueAttribs(0)
+  // The watermarked `time` column sits at ordinal 1 of the join key row
+  val joinKeyExprs = Seq[Expression](Literal(false), inputValueAttribWithWatermark, Literal(10.0))
+
+  val inputValueGen = UnsafeProjection.create(inputValueAttribs.map(_.dataType).toArray)
+  val joinKeyGen = UnsafeProjection.create(joinKeyExprs.map(_.dataType).toArray)
+
+  def toInputValue(i: Int): UnsafeRow = {
+    inputValueGen.apply(new GenericInternalRow(Array[Any](i, false)))
+  }
+
+  def toJoinKeyRow(i: Int): UnsafeRow = {
+    joinKeyGen.apply(new GenericInternalRow(Array[Any](false, i, 10.0)))
+  }
+
+  def toValueInt(inputValueRow: UnsafeRow): Int = inputValueRow.getInt(0)
+
+  def append(key: Int, value: Int)(implicit manager: SymmetricHashJoinStateManager): Unit = {
+    manager.append(toJoinKeyRow(key), toInputValue(value))
+  }
+
+  /** Values currently stored for `key`, sorted (the manager does not preserve order). */
+  def get(key: Int)(implicit manager: SymmetricHashJoinStateManager): Seq[Int] = {
+    manager.get(toJoinKeyRow(key)).map(toValueInt).toSeq.sorted
+  }
+
+  /** Remove keys (and corresponding values) where `time <= threshold` */
+  def removeByKey(threshold: Long)(implicit manager: SymmetricHashJoinStateManager): Unit = {
+    // Ordinal 1 of the join key row is the watermarked `time` column (see joinKeyExprs)
+    val expr =
+      LessThanOrEqual(
+        BoundReference(
+          1, inputValueAttribWithWatermark.dataType, inputValueAttribWithWatermark.nullable),
+        Literal(threshold))
+    manager.removeByKeyCondition(GeneratePredicate.generate(expr).eval _)
+  }
+
+  /** Remove values where `time <= watermark` */
+  def removeByValue(watermark: Long)(implicit manager: SymmetricHashJoinStateManager): Unit = {
+    val expr = LessThanOrEqual(inputValueAttribWithWatermark, Literal(watermark))
+    manager.removeByValueCondition(
+      GeneratePredicate.generate(expr, inputValueAttribs).eval _)
+  }
+
+  def numRows(implicit manager: SymmetricHashJoinStateManager): Long = {
+    manager.metrics.numKeys
+  }
+
+  /**
+   * Run `f` against a fresh left-side state manager backed by a temporary checkpoint directory.
+   * The manager is aborted afterwards if it was not committed, and StateStore.stop() is always
+   * called, even when `f` throws.
+   */
+  def withJoinStateManager(
+    inputValueAttribs: Seq[Attribute],
+    joinKeyExprs: Seq[Expression])(f: SymmetricHashJoinStateManager => Unit): Unit = {
+
+    try {
+      withTempDir { file =>
+        val storeConf = new StateStoreConf()
+        val stateInfo = StatefulOperatorStateInfo(file.getAbsolutePath, UUID.randomUUID, 0, 0)
+        val manager = new SymmetricHashJoinStateManager(
+          LeftSide, inputValueAttribs, joinKeyExprs, Some(stateInfo), storeConf, new Configuration)
+        try {
+          f(manager)
+        } finally {
+          manager.abortIfNeeded()
+        }
+      }
+    } finally {
+      StateStore.stop()
+    }
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org