Posted to commits@spark.apache.org by td...@apache.org on 2017/10/17 19:26:57 UTC

spark git commit: [SPARK-22136][SS] Evaluate one-sided conditions early in stream-stream joins.

Repository: spark
Updated Branches:
  refs/heads/master e1960c3d6 -> 75d666b95


[SPARK-22136][SS] Evaluate one-sided conditions early in stream-stream joins.

## What changes were proposed in this pull request?

Evaluate one-sided conditions early in stream-stream joins.

This is in addition to normal filter pushdown, because integrating the evaluation with the join logic allows it to take place even in outer join scenarios, where plain pushdown cannot. As a result, rows that can never satisfy the join condition won't clog up the state.
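
For illustration, here is a minimal, self-contained sketch of the condition split this patch performs. The real implementation is `JoinConditionSplitPredicates` in the diff below; the tiny expression ADT here is purely hypothetical and stands in for Catalyst expressions:

```scala
// Standalone sketch: split a conjunctive join condition into conjuncts that
// reference only the left side, only the right side, or both. Deterministic
// one-sided conjuncts can then be evaluated before rows enter the join state.
object JoinConditionSplitSketch {
  sealed trait Expr { def refs: Set[String]; def deterministic: Boolean = true }
  case class Attr(name: String) extends Expr { def refs = Set(name) }
  case class Lit(v: Int) extends Expr { def refs = Set.empty[String] }
  case class GreaterThan(l: Expr, r: Expr) extends Expr { def refs = l.refs ++ r.refs }
  case class And(l: Expr, r: Expr) extends Expr { def refs = l.refs ++ r.refs }
  case object Rand extends Expr {
    def refs = Set.empty[String]; override def deterministic = false
  }

  def conjuncts(e: Expr): Seq[Expr] = e match {
    case And(l, r) => conjuncts(l) ++ conjuncts(r)
    case other     => Seq(other)
  }

  def split(cond: Expr, leftOut: Set[String], rightOut: Set[String])
      : (Option[Expr], Option[Expr], Option[Expr]) = {
    // span rather than partition: conjuncts after a nondeterministic one must
    // stay with it, because nondeterministic expressions don't commute across AND.
    val (det, nonDet) = conjuncts(cond).span(_.deterministic)
    val (leftOnly, notLeft)   = det.partition(_.refs.subsetOf(leftOut))
    val (rightOnly, notRight) = det.partition(_.refs.subsetOf(rightOut))
    val both = notLeft.intersect(notRight) ++ nonDet
    (leftOnly.reduceOption(And(_, _)),
     rightOnly.reduceOption(And(_, _)),
     both.reduceOption(And(_, _)))
  }

  def main(args: Array[String]): Unit = {
    // leftValue > 4 AND rightValue > 7 AND leftKey > rightKey
    val cond = And(GreaterThan(Attr("leftValue"), Lit(4)),
      And(GreaterThan(Attr("rightValue"), Lit(7)),
        GreaterThan(Attr("leftKey"), Attr("rightKey"))))
    val (l, r, b) = split(cond, Set("leftKey", "leftValue"), Set("rightKey", "rightValue"))
    println(s"leftOnly = $l")   // the leftValue > 4 conjunct
    println(s"rightOnly = $r")  // the rightValue > 7 conjunct
    println(s"bothSides = $b")  // the leftKey > rightKey conjunct
  }
}
```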

## How was this patch tested?
New unit tests.

Author: Jose Torres <jo...@databricks.com>

Closes #19452 from joseph-torres/SPARK-22136.
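
As a usage-level illustration, loosely modeled on the new tests and the standard structured streaming join examples (the rate source, column names, and time bounds are placeholder assumptions, not part of this patch), the kind of query that benefits is a stream-stream outer join whose condition contains a one-sided conjunct such as `leftValue > 4`:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

object OneSidedConditionJoinExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("one-sided-stream-join").getOrCreate()

    // Two streaming inputs with watermarks, as required for stream-stream outer joins.
    val left = spark.readStream.format("rate").load()
      .selectExpr("timestamp AS leftTime", "value % 10 AS leftKey", "value * 2 AS leftValue")
      .withWatermark("leftTime", "10 seconds")
    val right = spark.readStream.format("rate").load()
      .selectExpr("timestamp AS rightTime", "value % 10 AS rightKey", "value * 3 AS rightValue")
      .withWatermark("rightTime", "10 seconds")

    // The `leftValue > 4` conjunct references only the left side. With this patch it is
    // evaluated before left rows are buffered: rows failing it emit their outer-join
    // result immediately instead of sitting in the state store until the watermark.
    val joined = left.join(
      right,
      expr("""
        leftKey = rightKey AND
        leftTime >= rightTime - interval 5 seconds AND
        leftTime <= rightTime + interval 5 seconds AND
        leftValue > 4
      """),
      "leftOuter")

    joined.writeStream.format("console").start().awaitTermination()
  }
}
```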


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/75d666b9
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/75d666b9
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/75d666b9

Branch: refs/heads/master
Commit: 75d666b95a711787355ca3895057dabadd429023
Parents: e1960c3
Author: Jose Torres <jo...@databricks.com>
Authored: Tue Oct 17 12:26:53 2017 -0700
Committer: Tathagata Das <ta...@gmail.com>
Committed: Tue Oct 17 12:26:53 2017 -0700

----------------------------------------------------------------------
 .../streaming/IncrementalExecution.scala        |   2 +-
 .../StreamingSymmetricHashJoinExec.scala        | 134 +++++++++++------
 .../StreamingSymmetricHashJoinHelper.scala      |  70 ++++++++-
 .../sql/streaming/StreamingJoinSuite.scala      | 150 ++++++++++++++++++-
 .../StreamingSymmetricHashJoinHelperSuite.scala | 130 ++++++++++++++++
 5 files changed, 433 insertions(+), 53 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/75d666b9/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
index 2e37863..a10ed5f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
@@ -133,7 +133,7 @@ class IncrementalExecution(
           eventTimeWatermark = Some(offsetSeqMetadata.batchWatermarkMs),
           stateWatermarkPredicates =
             StreamingSymmetricHashJoinHelper.getStateWatermarkPredicates(
-              j.left.output, j.right.output, j.leftKeys, j.rightKeys, j.condition,
+              j.left.output, j.right.output, j.leftKeys, j.rightKeys, j.condition.full,
               Some(offsetSeqMetadata.batchWatermarkMs))
         )
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/75d666b9/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala
index 9bd2127..c351f65 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala
@@ -21,7 +21,7 @@ import java.util.concurrent.TimeUnit.NANOSECONDS
 
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute, BindReferences, Expression, GenericInternalRow, JoinedRow, Literal, NamedExpression, PreciseTimestampConversion, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, GenericInternalRow, JoinedRow, Literal, UnsafeProjection, UnsafeRow}
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark._
 import org.apache.spark.sql.catalyst.plans.physical._
@@ -29,7 +29,6 @@ import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan}
 import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper._
 import org.apache.spark.sql.execution.streaming.state._
 import org.apache.spark.sql.internal.SessionState
-import org.apache.spark.sql.types.{LongType, TimestampType}
 import org.apache.spark.util.{CompletionIterator, SerializableConfiguration}
 
 
@@ -115,7 +114,8 @@ import org.apache.spark.util.{CompletionIterator, SerializableConfiguration}
  * @param leftKeys  Expression to generate key rows for joining from left input
  * @param rightKeys Expression to generate key rows for joining from right input
  * @param joinType  Type of join (inner, left outer, etc.)
- * @param condition Optional, additional condition to filter output of the equi-join
+ * @param condition Conditions to filter rows, split by left, right, and joined. See
+ *                  [[JoinConditionSplitPredicates]]
  * @param stateInfo Version information required to read join state (buffered rows)
  * @param eventTimeWatermark Watermark of input event, same for both sides
  * @param stateWatermarkPredicates Predicates for removal of state, see
@@ -127,7 +127,7 @@ case class StreamingSymmetricHashJoinExec(
     leftKeys: Seq[Expression],
     rightKeys: Seq[Expression],
     joinType: JoinType,
-    condition: Option[Expression],
+    condition: JoinConditionSplitPredicates,
     stateInfo: Option[StatefulOperatorStateInfo],
     eventTimeWatermark: Option[Long],
     stateWatermarkPredicates: JoinStateWatermarkPredicates,
@@ -141,8 +141,10 @@ case class StreamingSymmetricHashJoinExec(
       condition: Option[Expression],
       left: SparkPlan,
       right: SparkPlan) = {
+
     this(
-      leftKeys, rightKeys, joinType, condition, stateInfo = None, eventTimeWatermark = None,
+      leftKeys, rightKeys, joinType, JoinConditionSplitPredicates(condition, left, right),
+      stateInfo = None, eventTimeWatermark = None,
       stateWatermarkPredicates = JoinStateWatermarkPredicates(), left, right)
   }
 
@@ -161,6 +163,9 @@ case class StreamingSymmetricHashJoinExec(
     new SerializableConfiguration(SessionState.newHadoopConf(
       sparkContext.hadoopConfiguration, sqlContext.conf)))
 
+  val nullLeft = new GenericInternalRow(left.output.map(_.withNullability(true)).length)
+  val nullRight = new GenericInternalRow(right.output.map(_.withNullability(true)).length)
+
   override def requiredChildDistribution: Seq[Distribution] =
     ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
 
@@ -206,10 +211,15 @@ case class StreamingSymmetricHashJoinExec(
     val updateStartTimeNs = System.nanoTime
     val joinedRow = new JoinedRow
 
+
+    val postJoinFilter =
+      newPredicate(condition.bothSides.getOrElse(Literal(true)), left.output ++ right.output).eval _
     val leftSideJoiner = new OneSideHashJoiner(
-      LeftSide, left.output, leftKeys, leftInputIter, stateWatermarkPredicates.left)
+      LeftSide, left.output, leftKeys, leftInputIter,
+      condition.leftSideOnly, postJoinFilter, stateWatermarkPredicates.left)
     val rightSideJoiner = new OneSideHashJoiner(
-      RightSide, right.output, rightKeys, rightInputIter, stateWatermarkPredicates.right)
+      RightSide, right.output, rightKeys, rightInputIter,
+      condition.rightSideOnly, postJoinFilter, stateWatermarkPredicates.right)
 
     //  Join one side input using the other side's buffered/state rows. Here is how it is done.
     //
@@ -221,43 +231,28 @@ case class StreamingSymmetricHashJoinExec(
     //    matching new left input with new right input, since the new left input has become stored
     //    by that point. This tiny asymmetry is necessary to avoid duplication.
     val leftOutputIter = leftSideJoiner.storeAndJoinWithOtherSide(rightSideJoiner) {
-      (input: UnsafeRow, matched: UnsafeRow) => joinedRow.withLeft(input).withRight(matched)
+      (input: InternalRow, matched: InternalRow) => joinedRow.withLeft(input).withRight(matched)
     }
     val rightOutputIter = rightSideJoiner.storeAndJoinWithOtherSide(leftSideJoiner) {
-      (input: UnsafeRow, matched: UnsafeRow) => joinedRow.withLeft(matched).withRight(input)
+      (input: InternalRow, matched: InternalRow) => joinedRow.withLeft(matched).withRight(input)
     }
 
-    // Filter the joined rows based on the given condition.
-    val outputFilterFunction = newPredicate(condition.getOrElse(Literal(true)), output).eval _
-
     // We need to save the time that the inner join output iterator completes, since outer join
     // output counts as both update and removal time.
     var innerOutputCompletionTimeNs: Long = 0
     def onInnerOutputCompletion = {
       innerOutputCompletionTimeNs = System.nanoTime
     }
-    val filteredInnerOutputIter = CompletionIterator[InternalRow, Iterator[InternalRow]](
-      (leftOutputIter ++ rightOutputIter).filter(outputFilterFunction), onInnerOutputCompletion)
-
-    def matchesWithRightSideState(leftKeyValue: UnsafeRowPair) = {
-      rightSideJoiner.get(leftKeyValue.key).exists(
-        rightValue => {
-          outputFilterFunction(
-            joinedRow.withLeft(leftKeyValue.value).withRight(rightValue))
-        })
-    }
+    // This is the iterator which produces the inner join rows. For outer joins, this will be
+    // prepended to a second iterator producing outer join rows; for inner joins, this is the full
+    // output.
+    val innerOutputIter = CompletionIterator[InternalRow, Iterator[InternalRow]](
+      (leftOutputIter ++ rightOutputIter), onInnerOutputCompletion)
 
-    def matchesWithLeftSideState(rightKeyValue: UnsafeRowPair) = {
-      leftSideJoiner.get(rightKeyValue.key).exists(
-        leftValue => {
-          outputFilterFunction(
-            joinedRow.withLeft(leftValue).withRight(rightKeyValue.value))
-        })
-    }
 
     val outputIter: Iterator[InternalRow] = joinType match {
       case Inner =>
-        filteredInnerOutputIter
+        innerOutputIter
       case LeftOuter =>
         // We generate the outer join input by:
         // * Getting an iterator over the rows that have aged out on the left side. These rows are
@@ -268,28 +263,37 @@ case class StreamingSymmetricHashJoinExec(
         //   we know we can join with null, since there was never (including this batch) a match
         //   within the watermark period. If it does, there must have been a match at some point, so
         //   we know we can't join with null.
-        val nullRight = new GenericInternalRow(right.output.map(_.withNullability(true)).length)
+        def matchesWithRightSideState(leftKeyValue: UnsafeRowPair) = {
+          rightSideJoiner.get(leftKeyValue.key).exists { rightValue =>
+            postJoinFilter(joinedRow.withLeft(leftKeyValue.value).withRight(rightValue))
+          }
+        }
         val removedRowIter = leftSideJoiner.removeOldState()
         val outerOutputIter = removedRowIter
           .filterNot(pair => matchesWithRightSideState(pair))
           .map(pair => joinedRow.withLeft(pair.value).withRight(nullRight))
 
-        filteredInnerOutputIter ++ outerOutputIter
+        innerOutputIter ++ outerOutputIter
       case RightOuter =>
         // See comments for left outer case.
-        val nullLeft = new GenericInternalRow(left.output.map(_.withNullability(true)).length)
+        def matchesWithLeftSideState(rightKeyValue: UnsafeRowPair) = {
+          leftSideJoiner.get(rightKeyValue.key).exists { leftValue =>
+            postJoinFilter(joinedRow.withLeft(leftValue).withRight(rightKeyValue.value))
+          }
+        }
         val removedRowIter = rightSideJoiner.removeOldState()
         val outerOutputIter = removedRowIter
           .filterNot(pair => matchesWithLeftSideState(pair))
           .map(pair => joinedRow.withLeft(nullLeft).withRight(pair.value))
 
-        filteredInnerOutputIter ++ outerOutputIter
+        innerOutputIter ++ outerOutputIter
       case _ => throwBadJoinTypeException()
     }
 
+    val outputProjection = UnsafeProjection.create(left.output ++ right.output, output)
     val outputIterWithMetrics = outputIter.map { row =>
       numOutputRows += 1
-      row
+      outputProjection(row)
     }
 
     // Function to remove old state after all the input has been consumed and output generated
@@ -349,14 +353,36 @@ case class StreamingSymmetricHashJoinExec(
   /**
    * Internal helper class to consume input rows, generate join output rows using other sides
    * buffered state rows, and finally clean up this sides buffered state rows
+   *
+   * @param joinSide The JoinSide - either left or right.
+   * @param inputAttributes The input attributes for this side of the join.
+   * @param joinKeys The join keys.
+   * @param inputIter The iterator of input rows on this side to be joined.
+   * @param preJoinFilterExpr A filter over rows on this side. This filter rejects rows that could
+   *                          never pass the overall join condition no matter what other side row
+   *                          they're joined with.
+   * @param postJoinFilter A filter over joined rows. This filter completes the application of
+   *                       the overall join condition, assuming that preJoinFilter on both sides
+   *                       of the join has already been passed.
+   *                       Passed as a function rather than expression to avoid creating the
+   *                       predicate twice; we also need this filter later on in the parent exec.
+   * @param stateWatermarkPredicate The state watermark predicate. See
+   *                                [[StreamingSymmetricHashJoinExec]] for further description of
+   *                                state watermarks.
    */
   private class OneSideHashJoiner(
       joinSide: JoinSide,
       inputAttributes: Seq[Attribute],
       joinKeys: Seq[Expression],
       inputIter: Iterator[InternalRow],
+      preJoinFilterExpr: Option[Expression],
+      postJoinFilter: (InternalRow) => Boolean,
       stateWatermarkPredicate: Option[JoinStateWatermarkPredicate]) {
 
+    // Filter the joined rows based on the given condition.
+    val preJoinFilter =
+      newPredicate(preJoinFilterExpr.getOrElse(Literal(true)), inputAttributes).eval _
+
     private val joinStateManager = new SymmetricHashJoinStateManager(
       joinSide, inputAttributes, joinKeys, stateInfo, storeConf, hadoopConfBcast.value.value)
     private[this] val keyGenerator = UnsafeProjection.create(joinKeys, inputAttributes)
@@ -388,8 +414,8 @@ case class StreamingSymmetricHashJoinExec(
      */
     def storeAndJoinWithOtherSide(
         otherSideJoiner: OneSideHashJoiner)(
-        generateJoinedRow: (UnsafeRow, UnsafeRow) => JoinedRow): Iterator[InternalRow] = {
-
+        generateJoinedRow: (InternalRow, InternalRow) => JoinedRow):
+    Iterator[InternalRow] = {
       val watermarkAttribute = inputAttributes.find(_.metadata.contains(delayKey))
       val nonLateRows =
         WatermarkSupport.watermarkExpression(watermarkAttribute, eventTimeWatermark) match {
@@ -402,17 +428,31 @@ case class StreamingSymmetricHashJoinExec(
 
       nonLateRows.flatMap { row =>
         val thisRow = row.asInstanceOf[UnsafeRow]
-        val key = keyGenerator(thisRow)
-        val outputIter = otherSideJoiner.joinStateManager.get(key).map { thatRow =>
-          generateJoinedRow(thisRow, thatRow)
-        }
-        val shouldAddToState = // add only if both removal predicates do not match
-          !stateKeyWatermarkPredicateFunc(key) && !stateValueWatermarkPredicateFunc(thisRow)
-        if (shouldAddToState) {
-          joinStateManager.append(key, thisRow)
-          updatedStateRowsCount += 1
+        // If this row fails the pre join filter, that means it can never satisfy the full join
+        // condition no matter what other side row it's matched with. This allows us to avoid
+        // adding it to the state, and generate an outer join row immediately (or do nothing in
+        // the case of inner join).
+        if (preJoinFilter(thisRow)) {
+          val key = keyGenerator(thisRow)
+          val outputIter = otherSideJoiner.joinStateManager.get(key).map { thatRow =>
+            generateJoinedRow(thisRow, thatRow)
+          }.filter(postJoinFilter)
+          val shouldAddToState = // add only if both removal predicates do not match
+            !stateKeyWatermarkPredicateFunc(key) && !stateValueWatermarkPredicateFunc(thisRow)
+          if (shouldAddToState) {
+            joinStateManager.append(key, thisRow)
+            updatedStateRowsCount += 1
+          }
+          outputIter
+        } else {
+          joinSide match {
+            case LeftSide if joinType == LeftOuter =>
+              Iterator(generateJoinedRow(thisRow, nullRight))
+            case RightSide if joinType == RightOuter =>
+              Iterator(generateJoinedRow(thisRow, nullLeft))
+            case _ => Iterator()
+          }
         }
-        outputIter
       }
     }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/75d666b9/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala
index 64c7189..167e991 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala
@@ -24,8 +24,9 @@ import org.apache.spark.{Partition, SparkContext}
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.{RDD, ZippedPartitionsRDD2}
 import org.apache.spark.sql.catalyst.analysis.StreamingJoinHelper
-import org.apache.spark.sql.catalyst.expressions.{Add, Attribute, AttributeReference, AttributeSet, BoundReference, Cast, CheckOverflow, Expression, ExpressionSet, GreaterThan, GreaterThanOrEqual, LessThan, LessThanOrEqual, Literal, Multiply, NamedExpression, PreciseTimestampConversion, PredicateHelper, Subtract, TimeAdd, TimeSub, UnaryMinus}
+import org.apache.spark.sql.catalyst.expressions.{Add, And, Attribute, AttributeReference, AttributeSet, BoundReference, Cast, CheckOverflow, Expression, ExpressionSet, GreaterThan, GreaterThanOrEqual, LessThan, LessThanOrEqual, Literal, Multiply, NamedExpression, PreciseTimestampConversion, PredicateHelper, Subtract, TimeAdd, TimeSub, UnaryMinus}
 import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark._
+import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.streaming.WatermarkSupport.watermarkExpression
 import org.apache.spark.sql.execution.streaming.state.{StateStoreCoordinatorRef, StateStoreProvider, StateStoreProviderId}
 import org.apache.spark.sql.types._
@@ -66,6 +67,73 @@ object StreamingSymmetricHashJoinHelper extends Logging {
     }
   }
 
+  /**
+   * Wrapper around various useful splits of the join condition.
+   * left AND right AND joined is equivalent to full.
+   *
+   * Note that left and right do not necessarily contain *all* conjuncts which satisfy
+   * their condition. Any conjuncts after the first nondeterministic one are treated as
+   * nondeterministic for purposes of the split.
+   *
+   * @param leftSideOnly Deterministic conjuncts which reference only the left side of the join.
+   * @param rightSideOnly Deterministic conjuncts which reference only the right side of the join.
+   * @param bothSides Conjuncts which are nondeterministic, occur after a nondeterministic conjunct,
+   *                  or reference both left and right sides of the join.
+   * @param full The full join condition.
+   */
+  case class JoinConditionSplitPredicates(
+      leftSideOnly: Option[Expression],
+      rightSideOnly: Option[Expression],
+      bothSides: Option[Expression],
+      full: Option[Expression]) {
+    override def toString(): String = {
+      s"condition = [ leftOnly = ${leftSideOnly.map(_.toString).getOrElse("null")}, " +
+        s"rightOnly = ${rightSideOnly.map(_.toString).getOrElse("null")}, " +
+        s"both = ${bothSides.map(_.toString).getOrElse("null")}, " +
+        s"full = ${full.map(_.toString).getOrElse("null")} ]"
+    }
+  }
+
+  object JoinConditionSplitPredicates extends PredicateHelper {
+    def apply(condition: Option[Expression], left: SparkPlan, right: SparkPlan):
+        JoinConditionSplitPredicates = {
+      // Split the condition into 3 parts:
+      // * Conjuncts that can be evaluated on only the left input.
+      // * Conjuncts that can be evaluated on only the right input.
+      // * Conjuncts that require both left and right input.
+      //
+      // Note that we treat nondeterministic conjuncts as though they require both left and right
+      // input. To maintain their semantics, they need to be evaluated exactly once per joined row.
+      val (leftCondition, rightCondition, joinedCondition) = {
+        if (condition.isEmpty) {
+          (None, None, None)
+        } else {
+          // Span rather than partition, because nondeterministic expressions don't commute
+          // across AND.
+          val (deterministicConjuncts, nonDeterministicConjuncts) =
+            splitConjunctivePredicates(condition.get).span(_.deterministic)
+
+          val (leftConjuncts, nonLeftConjuncts) = deterministicConjuncts.partition { cond =>
+            cond.references.subsetOf(left.outputSet)
+          }
+
+          val (rightConjuncts, nonRightConjuncts) = deterministicConjuncts.partition { cond =>
+            cond.references.subsetOf(right.outputSet)
+          }
+
+          (
+            leftConjuncts.reduceOption(And),
+            rightConjuncts.reduceOption(And),
+            (nonLeftConjuncts.intersect(nonRightConjuncts) ++ nonDeterministicConjuncts)
+              .reduceOption(And)
+          )
+        }
+      }
+
+      JoinConditionSplitPredicates(leftCondition, rightCondition, joinedCondition, condition)
+    }
+  }
+
   /** Get the predicates defining the state watermarks for both sides of the join */
   def getStateWatermarkPredicates(
       leftAttributes: Seq[Attribute],

http://git-wip-us.apache.org/repos/asf/spark/blob/75d666b9/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
index d326172..54eb863 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
@@ -365,6 +365,24 @@ class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with
       }
     }
   }
+
+  test("join between three streams") {
+    val input1 = MemoryStream[Int]
+    val input2 = MemoryStream[Int]
+    val input3 = MemoryStream[Int]
+
+    val df1 = input1.toDF.select('value as "leftKey", ('value * 2) as "leftValue")
+    val df2 = input2.toDF.select('value as "middleKey", ('value * 3) as "middleValue")
+    val df3 = input3.toDF.select('value as "rightKey", ('value * 5) as "rightValue")
+
+    val joined = df1.join(df2, expr("leftKey = middleKey")).join(df3, expr("rightKey = middleKey"))
+
+    testStream(joined)(
+      AddData(input1, 1, 5),
+      AddData(input2, 1, 5, 10),
+      AddData(input3, 5, 10),
+      CheckLastBatch((5, 10, 5, 15, 5, 25)))
+  }
 }
 
 class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with BeforeAndAfter {
@@ -405,6 +423,130 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with
     (input1, input2, joined)
   }
 
+  test("left outer early state exclusion on left") {
+    val (leftInput, df1) = setupStream("left", 2)
+    val (rightInput, df2) = setupStream("right", 3)
+    // Use different schemas to ensure the null row is being generated from the correct side.
+    val left = df1.select('key, window('leftTime, "10 second"), 'leftValue)
+    val right = df2.select('key, window('rightTime, "10 second"), 'rightValue.cast("string"))
+
+    val joined = left.join(
+        right,
+        left("key") === right("key")
+          && left("window") === right("window")
+          && 'leftValue > 4,
+        "left_outer")
+        .select(left("key"), left("window.end").cast("long"), 'leftValue, 'rightValue)
+
+    testStream(joined)(
+      AddData(leftInput, 1, 2, 3),
+      AddData(rightInput, 3, 4, 5),
+      // The left rows with leftValue <= 4 should generate their outer join row now and
+      // not get added to the state.
+      CheckLastBatch(Row(3, 10, 6, "9"), Row(1, 10, 2, null), Row(2, 10, 4, null)),
+      assertNumStateRows(total = 4, updated = 4),
+      // We shouldn't get more outer join rows when the watermark advances.
+      AddData(leftInput, 20),
+      AddData(rightInput, 21),
+      CheckLastBatch(),
+      AddData(rightInput, 20),
+      CheckLastBatch((20, 30, 40, "60"))
+    )
+  }
+
+  test("left outer early state exclusion on right") {
+    val (leftInput, df1) = setupStream("left", 2)
+    val (rightInput, df2) = setupStream("right", 3)
+    // Use different schemas to ensure the null row is being generated from the correct side.
+    val left = df1.select('key, window('leftTime, "10 second"), 'leftValue)
+    val right = df2.select('key, window('rightTime, "10 second"), 'rightValue.cast("string"))
+
+    val joined = left.join(
+      right,
+      left("key") === right("key")
+        && left("window") === right("window")
+        && 'rightValue.cast("int") > 7,
+      "left_outer")
+      .select(left("key"), left("window.end").cast("long"), 'leftValue, 'rightValue)
+
+    testStream(joined)(
+      AddData(leftInput, 3, 4, 5),
+      AddData(rightInput, 1, 2, 3),
+      // The right rows with value <= 7 should never be added to the state.
+      CheckLastBatch(Row(3, 10, 6, "9")),
+      assertNumStateRows(total = 4, updated = 4),
+      // When the watermark advances, we get the outer join rows just as we would if they
+      // were added but didn't match the full join condition.
+      AddData(leftInput, 20),
+      AddData(rightInput, 21),
+      CheckLastBatch(),
+      AddData(rightInput, 20),
+      CheckLastBatch(Row(20, 30, 40, "60"), Row(4, 10, 8, null), Row(5, 10, 10, null))
+    )
+  }
+
+  test("right outer early state exclusion on left") {
+    val (leftInput, df1) = setupStream("left", 2)
+    val (rightInput, df2) = setupStream("right", 3)
+    // Use different schemas to ensure the null row is being generated from the correct side.
+    val left = df1.select('key, window('leftTime, "10 second"), 'leftValue)
+    val right = df2.select('key, window('rightTime, "10 second"), 'rightValue.cast("string"))
+
+    val joined = left.join(
+      right,
+      left("key") === right("key")
+        && left("window") === right("window")
+        && 'leftValue > 4,
+      "right_outer")
+      .select(right("key"), right("window.end").cast("long"), 'leftValue, 'rightValue)
+
+    testStream(joined)(
+      AddData(leftInput, 1, 2, 3),
+      AddData(rightInput, 3, 4, 5),
+      // The left rows with value <= 4 should never be added to the state.
+      CheckLastBatch(Row(3, 10, 6, "9")),
+      assertNumStateRows(total = 4, updated = 4),
+      // When the watermark advances, we get the outer join rows just as we would if they
+      // were added but didn't match the full join condition.
+      AddData(leftInput, 20),
+      AddData(rightInput, 21),
+      CheckLastBatch(),
+      AddData(rightInput, 20),
+      CheckLastBatch(Row(20, 30, 40, "60"), Row(4, 10, null, "12"), Row(5, 10, null, "15"))
+    )
+  }
+
+  test("right outer early state exclusion on right") {
+    val (leftInput, df1) = setupStream("left", 2)
+    val (rightInput, df2) = setupStream("right", 3)
+    // Use different schemas to ensure the null row is being generated from the correct side.
+    val left = df1.select('key, window('leftTime, "10 second"), 'leftValue)
+    val right = df2.select('key, window('rightTime, "10 second"), 'rightValue.cast("string"))
+
+    val joined = left.join(
+      right,
+      left("key") === right("key")
+        && left("window") === right("window")
+        && 'rightValue.cast("int") > 7,
+      "right_outer")
+      .select(right("key"), right("window.end").cast("long"), 'leftValue, 'rightValue)
+
+    testStream(joined)(
+      AddData(leftInput, 3, 4, 5),
+      AddData(rightInput, 1, 2, 3),
+      // The right rows with rightValue <= 7 should generate their outer join row now and
+      // not get added to the state.
+      CheckLastBatch(Row(3, 10, 6, "9"), Row(1, 10, null, "3"), Row(2, 10, null, "6")),
+      assertNumStateRows(total = 4, updated = 4),
+      // We shouldn't get more outer join rows when the watermark advances.
+      AddData(leftInput, 20),
+      AddData(rightInput, 21),
+      CheckLastBatch(),
+      AddData(rightInput, 20),
+      CheckLastBatch((20, 30, 40, "60"))
+    )
+  }
+
   test("windowed left outer join") {
     val (leftInput, rightInput, joined) = setupWindowedJoin("left_outer")
 
@@ -495,7 +637,7 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with
 
   // When the join condition isn't true, the outer null rows must be generated, even if the join
   // keys themselves have a match.
-  test("left outer join with non-key condition violated on left") {
+  test("left outer join with non-key condition violated") {
     val (leftInput, simpleLeftDf) = setupStream("left", 2)
     val (rightInput, simpleRightDf) = setupStream("right", 3)
 
@@ -513,14 +655,14 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with
       // leftValue <= 10 should generate outer join rows even though it matches right keys
       AddData(leftInput, 1, 2, 3),
       AddData(rightInput, 1, 2, 3),
-      CheckLastBatch(),
+      CheckLastBatch(Row(1, 10, 2, null), Row(2, 10, 4, null), Row(3, 10, 6, null)),
       AddData(leftInput, 20),
       AddData(rightInput, 21),
       CheckLastBatch(),
-      assertNumStateRows(total = 8, updated = 2),
+      assertNumStateRows(total = 5, updated = 2),
       AddData(rightInput, 20),
       CheckLastBatch(
-        Row(20, 30, 40, 60), Row(1, 10, 2, null), Row(2, 10, 4, null), Row(3, 10, 6, null)),
+        Row(20, 30, 40, 60)),
       assertNumStateRows(total = 3, updated = 1),
       // leftValue and rightValue both satisfying condition should not generate outer join rows
       AddData(leftInput, 40, 41),

http://git-wip-us.apache.org/repos/asf/spark/blob/75d666b9/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSymmetricHashJoinHelperSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSymmetricHashJoinHelperSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSymmetricHashJoinHelperSuite.scala
new file mode 100644
index 0000000..2a854e3
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSymmetricHashJoinHelperSuite.scala
@@ -0,0 +1,130 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.streaming
+
+import org.apache.spark.sql.Column
+import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer
+import org.apache.spark.sql.catalyst.expressions.AttributeReference
+import org.apache.spark.sql.execution.{LeafExecNode, LocalTableScanExec, SparkPlan}
+import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
+import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.JoinConditionSplitPredicates
+import org.apache.spark.sql.types._
+
+class StreamingSymmetricHashJoinHelperSuite extends StreamTest {
+  import org.apache.spark.sql.functions._
+
+  val leftAttributeA = AttributeReference("a", IntegerType)()
+  val leftAttributeB = AttributeReference("b", IntegerType)()
+  val rightAttributeC = AttributeReference("c", IntegerType)()
+  val rightAttributeD = AttributeReference("d", IntegerType)()
+  val leftColA = new Column(leftAttributeA)
+  val leftColB = new Column(leftAttributeB)
+  val rightColC = new Column(rightAttributeC)
+  val rightColD = new Column(rightAttributeD)
+
+  val left = new LocalTableScanExec(Seq(leftAttributeA, leftAttributeB), Seq())
+  val right = new LocalTableScanExec(Seq(rightAttributeC, rightAttributeD), Seq())
+
+  test("empty") {
+    val split = JoinConditionSplitPredicates(None, left, right)
+    assert(split.leftSideOnly.isEmpty)
+    assert(split.rightSideOnly.isEmpty)
+    assert(split.bothSides.isEmpty)
+    assert(split.full.isEmpty)
+  }
+
+  test("only literals") {
+    // Literal-only conjuncts end up on the left side because that's the first bucket they fit in.
+    // There's no semantic reason they couldn't be in any bucket.
+    val predicate = (lit(1) < lit(5) && lit(6) < lit(7) && lit(0) === lit(-1)).expr
+    val split = JoinConditionSplitPredicates(Some(predicate), left, right)
+
+    assert(split.leftSideOnly.contains(predicate))
+    assert(split.rightSideOnly.contains(predicate))
+    assert(split.bothSides.isEmpty)
+    assert(split.full.contains(predicate))
+  }
+
+  test("only left") {
+    val predicate = (leftColA > lit(1) && leftColB > lit(5) && leftColA < leftColB).expr
+    val split = JoinConditionSplitPredicates(Some(predicate), left, right)
+
+    assert(split.leftSideOnly.contains(predicate))
+    assert(split.rightSideOnly.isEmpty)
+    assert(split.bothSides.isEmpty)
+    assert(split.full.contains(predicate))
+  }
+
+  test("only right") {
+    val predicate = (rightColC > lit(1) && rightColD > lit(5) && rightColD < rightColC).expr
+    val split = JoinConditionSplitPredicates(Some(predicate), left, right)
+
+    assert(split.leftSideOnly.isEmpty)
+    assert(split.rightSideOnly.contains(predicate))
+    assert(split.bothSides.isEmpty)
+    assert(split.full.contains(predicate))
+  }
+
+  test("mixed conjuncts") {
+    val predicate =
+      (leftColA > leftColB
+        && rightColC > rightColD
+        && leftColA === rightColC
+        && lit(1) === lit(1)).expr
+    val split = JoinConditionSplitPredicates(Some(predicate), left, right)
+
+    assert(split.leftSideOnly.contains((leftColA > leftColB && lit(1) === lit(1)).expr))
+    assert(split.rightSideOnly.contains((rightColC > rightColD && lit(1) === lit(1)).expr))
+    assert(split.bothSides.contains((leftColA === rightColC).expr))
+    assert(split.full.contains(predicate))
+  }
+
+  test("conjuncts after nondeterministic") {
+    // All conjuncts after a nondeterministic conjunct shouldn't be split because they don't
+    // commute across it.
+    val predicate =
+      (rand() > lit(0)
+        && leftColA > leftColB
+        && rightColC > rightColD
+        && leftColA === rightColC
+        && lit(1) === lit(1)).expr
+    val split = JoinConditionSplitPredicates(Some(predicate), left, right)
+
+    assert(split.leftSideOnly.isEmpty)
+    assert(split.rightSideOnly.isEmpty)
+    assert(split.bothSides.contains(predicate))
+    assert(split.full.contains(predicate))
+  }
+
+
+  test("conjuncts before nondeterministic") {
+    val randCol = rand()
+    val predicate =
+      (leftColA > leftColB
+        && rightColC > rightColD
+        && leftColA === rightColC
+        && lit(1) === lit(1)
+        && randCol > lit(0)).expr
+    val split = JoinConditionSplitPredicates(Some(predicate), left, right)
+
+    assert(split.leftSideOnly.contains((leftColA > leftColB && lit(1) === lit(1)).expr))
+    assert(split.rightSideOnly.contains((rightColC > rightColD && lit(1) === lit(1)).expr))
+    assert(split.bothSides.contains((leftColA === rightColC && randCol > lit(0)).expr))
+    assert(split.full.contains(predicate))
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org