You are viewing a plain-text version of this content; the hyperlink to its canonical version was not preserved in this copy.
Posted to commits@spark.apache.org by td...@apache.org on 2017/10/15 00:39:20 UTC

spark git commit: [SPARK-22238] Fix plan resolution bug caused by EnsureStatefulOpPartitioning

Repository: spark
Updated Branches:
  refs/heads/master 014dc8471 -> e8547ffb4


[SPARK-22238] Fix plan resolution bug caused by EnsureStatefulOpPartitioning

## What changes were proposed in this pull request?

In EnsureStatefulOpPartitioning, we check that the input RDD of a SparkPlan has the partitioning that streaming stateful operators expect. The problem is that this information must not be accessed during planning: obtaining the RDD's number of partitions requires calling `execute()` on the child plan before planning has finished.
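We added that check because CoalesceExec could create RDDs with 0 partitions. This patch instead fixes CoalesceExec: when it reports `SinglePartition` as its output partitioning, it now produces an RDD with 1 (empty) partition instead of 0 partitions. It also lets `ClusteredDistribution` carry an optional required number of partitions (the number of state stores), so that EnsureRequirements inserts the shuffle stateful operators need and the EnsureStatefulOpPartitioning rule can be removed.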

## How was this patch tested?

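Regression test in StreamingQuerySuite, plus new tests in PlannerSuite (EnsureRequirements honors ClusteredDistribution's number of partitions) and DataFrameSuite (coalesce(1) on an empty DataFrame yields 1 partition), and child-partitioning checks in the stateful-operator suites.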

Author: Burak Yavuz <br...@gmail.com>

Closes #19467 from brkyvz/stateful-op.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/e8547ffb
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/e8547ffb
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/e8547ffb

Branch: refs/heads/master
Commit: e8547ffb49071525c06876c856cecc0d4731b918
Parents: 014dc84
Author: Burak Yavuz <br...@gmail.com>
Authored: Sat Oct 14 17:39:15 2017 -0700
Committer: Tathagata Das <ta...@gmail.com>
Committed: Sat Oct 14 17:39:15 2017 -0700

----------------------------------------------------------------------
 .../catalyst/plans/physical/partitioning.scala  |  15 +-
 .../sql/execution/basicPhysicalOperators.scala  |  27 +++-
 .../execution/exchange/EnsureRequirements.scala |   5 +-
 .../streaming/FlatMapGroupsWithStateExec.scala  |   2 +-
 .../streaming/IncrementalExecution.scala        |  39 ++----
 .../execution/streaming/statefulOperators.scala |  11 +-
 .../org/apache/spark/sql/DataFrameSuite.scala   |   2 +
 .../spark/sql/execution/PlannerSuite.scala      |  17 +++
 .../streaming/state/StateStoreRDDSuite.scala    |   2 +-
 .../SymmetricHashJoinStateManagerSuite.scala    |   2 +-
 .../spark/sql/streaming/DeduplicateSuite.scala  |  11 +-
 .../EnsureStatefulOpPartitioningSuite.scala     | 138 -------------------
 .../streaming/FlatMapGroupsWithStateSuite.scala |   6 +-
 .../sql/streaming/StatefulOperatorTest.scala    |  49 +++++++
 .../streaming/StreamingAggregationSuite.scala   |   8 +-
 .../sql/streaming/StreamingJoinSuite.scala      |   2 +-
 .../sql/streaming/StreamingQuerySuite.scala     |  13 ++
 17 files changed, 160 insertions(+), 189 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/e8547ffb/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index 51d78dd..e57c842 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -49,7 +49,9 @@ case object AllTuples extends Distribution
  * can mean such tuples are either co-located in the same partition or they will be contiguous
  * within a single partition.
  */
-case class ClusteredDistribution(clustering: Seq[Expression]) extends Distribution {
+case class ClusteredDistribution(
+    clustering: Seq[Expression],
+    numPartitions: Option[Int] = None) extends Distribution {
   require(
     clustering != Nil,
     "The clustering expressions of a ClusteredDistribution should not be Nil. " +
@@ -221,6 +223,7 @@ case object SinglePartition extends Partitioning {
 
   override def satisfies(required: Distribution): Boolean = required match {
     case _: BroadcastDistribution => false
+    case ClusteredDistribution(_, desiredPartitions) => desiredPartitions.forall(_ == 1)
     case _ => true
   }
 
@@ -243,8 +246,9 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
 
   override def satisfies(required: Distribution): Boolean = required match {
     case UnspecifiedDistribution => true
-    case ClusteredDistribution(requiredClustering) =>
-      expressions.forall(x => requiredClustering.exists(_.semanticEquals(x)))
+    case ClusteredDistribution(requiredClustering, desiredPartitions) =>
+      expressions.forall(x => requiredClustering.exists(_.semanticEquals(x))) &&
+        desiredPartitions.forall(_ == numPartitions) // if desiredPartitions = None, returns true
     case _ => false
   }
 
@@ -289,8 +293,9 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int)
     case OrderedDistribution(requiredOrdering) =>
       val minSize = Seq(requiredOrdering.size, ordering.size).min
       requiredOrdering.take(minSize) == ordering.take(minSize)
-    case ClusteredDistribution(requiredClustering) =>
-      ordering.map(_.child).forall(x => requiredClustering.exists(_.semanticEquals(x)))
+    case ClusteredDistribution(requiredClustering, desiredPartitions) =>
+      ordering.map(_.child).forall(x => requiredClustering.exists(_.semanticEquals(x))) &&
+        desiredPartitions.forall(_ == numPartitions) // if desiredPartitions = None, returns true
     case _ => false
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/e8547ffb/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index 63cd169..d15ece3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution
 import scala.concurrent.{ExecutionContext, Future}
 import scala.concurrent.duration.Duration
 
-import org.apache.spark.{InterruptibleIterator, TaskContext}
+import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext}
 import org.apache.spark.rdd.{EmptyRDD, PartitionwiseSampledRDD, RDD}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
@@ -590,10 +590,33 @@ case class CoalesceExec(numPartitions: Int, child: SparkPlan) extends UnaryExecN
   }
 
   protected override def doExecute(): RDD[InternalRow] = {
-    child.execute().coalesce(numPartitions, shuffle = false)
+    if (numPartitions == 1 && child.execute().getNumPartitions < 1) {
+      // Make sure we don't output an RDD with 0 partitions, when claiming that we have a
+      // `SinglePartition`.
+      new CoalesceExec.EmptyRDDWithPartitions(sparkContext, numPartitions)
+    } else {
+      child.execute().coalesce(numPartitions, shuffle = false)
+    }
   }
 }
 
+object CoalesceExec {
+  /** A simple RDD with no data, but with the given number of partitions. */
+  class EmptyRDDWithPartitions(
+      @transient private val sc: SparkContext,
+      numPartitions: Int) extends RDD[InternalRow](sc, Nil) {
+
+    override def getPartitions: Array[Partition] =
+      Array.tabulate(numPartitions)(i => EmptyPartition(i))
+
+    override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = {
+      Iterator.empty
+    }
+  }
+
+  case class EmptyPartition(index: Int) extends Partition
+}
+
 /**
  * Physical plan for a subquery.
  */

http://git-wip-us.apache.org/repos/asf/spark/blob/e8547ffb/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
index d28ce60..4e2ca37 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
@@ -44,13 +44,16 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
 
   /**
    * Given a required distribution, returns a partitioning that satisfies that distribution.
+   * @param requiredDistribution The distribution that is required by the operator
+   * @param numPartitions Used when the distribution doesn't require a specific number of partitions
    */
   private def createPartitioning(
       requiredDistribution: Distribution,
       numPartitions: Int): Partitioning = {
     requiredDistribution match {
       case AllTuples => SinglePartition
-      case ClusteredDistribution(clustering) => HashPartitioning(clustering, numPartitions)
+      case ClusteredDistribution(clustering, desiredPartitions) =>
+        HashPartitioning(clustering, desiredPartitions.getOrElse(numPartitions))
       case OrderedDistribution(ordering) => RangePartitioning(ordering, numPartitions)
       case dist => sys.error(s"Do not know how to satisfy distribution $dist")
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/e8547ffb/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
index aab06d6..c81f1a8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
@@ -64,7 +64,7 @@ case class FlatMapGroupsWithStateExec(
 
   /** Distribute by grouping attributes */
   override def requiredChildDistribution: Seq[Distribution] =
-    ClusteredDistribution(groupingAttributes) :: Nil
+    ClusteredDistribution(groupingAttributes, stateInfo.map(_.numPartitions)) :: Nil
 
   /** Ordering needed for using GroupingIterator */
   override def requiredChildOrdering: Seq[Seq[SortOrder]] =

http://git-wip-us.apache.org/repos/asf/spark/blob/e8547ffb/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
index 82f879c..2e37863 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
@@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistrib
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.{QueryExecution, SparkPlan, SparkPlanner, UnaryExecNode}
 import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.OutputMode
 
 /**
@@ -61,6 +62,10 @@ class IncrementalExecution(
       StreamingDeduplicationStrategy :: Nil
   }
 
+  private val numStateStores = offsetSeqMetadata.conf.get(SQLConf.SHUFFLE_PARTITIONS.key)
+    .map(SQLConf.SHUFFLE_PARTITIONS.valueConverter)
+    .getOrElse(sparkSession.sessionState.conf.numShufflePartitions)
+
   /**
    * See [SPARK-18339]
    * Walk the optimized logical plan and replace CurrentBatchTimestamp
@@ -83,7 +88,11 @@ class IncrementalExecution(
   /** Get the state info of the next stateful operator */
   private def nextStatefulOperationStateInfo(): StatefulOperatorStateInfo = {
     StatefulOperatorStateInfo(
-      checkpointLocation, runId, statefulOperatorId.getAndIncrement(), currentBatchId)
+      checkpointLocation,
+      runId,
+      statefulOperatorId.getAndIncrement(),
+      currentBatchId,
+      numStateStores)
   }
 
   /** Locates save/restore pairs surrounding aggregation. */
@@ -130,34 +139,8 @@ class IncrementalExecution(
     }
   }
 
-  override def preparations: Seq[Rule[SparkPlan]] =
-    Seq(state, EnsureStatefulOpPartitioning) ++ super.preparations
+  override def preparations: Seq[Rule[SparkPlan]] = state +: super.preparations
 
   /** No need assert supported, as this check has already been done */
   override def assertSupported(): Unit = { }
 }
-
-object EnsureStatefulOpPartitioning extends Rule[SparkPlan] {
-  // Needs to be transformUp to avoid extra shuffles
-  override def apply(plan: SparkPlan): SparkPlan = plan transformUp {
-    case so: StatefulOperator =>
-      val numPartitions = plan.sqlContext.sessionState.conf.numShufflePartitions
-      val distributions = so.requiredChildDistribution
-      val children = so.children.zip(distributions).map { case (child, reqDistribution) =>
-        val expectedPartitioning = reqDistribution match {
-          case AllTuples => SinglePartition
-          case ClusteredDistribution(keys) => HashPartitioning(keys, numPartitions)
-          case _ => throw new AnalysisException("Unexpected distribution expected for " +
-            s"Stateful Operator: $so. Expect AllTuples or ClusteredDistribution but got " +
-            s"$reqDistribution.")
-        }
-        if (child.outputPartitioning.guarantees(expectedPartitioning) &&
-            child.execute().getNumPartitions == expectedPartitioning.numPartitions) {
-          child
-        } else {
-          ShuffleExchangeExec(expectedPartitioning, child)
-        }
-      }
-      so.withNewChildren(children)
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/e8547ffb/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
index 0d85542..b9b07a2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
@@ -43,10 +43,11 @@ case class StatefulOperatorStateInfo(
     checkpointLocation: String,
     queryRunId: UUID,
     operatorId: Long,
-    storeVersion: Long) {
+    storeVersion: Long,
+    numPartitions: Int) {
   override def toString(): String = {
     s"state info [ checkpoint = $checkpointLocation, runId = $queryRunId, " +
-      s"opId = $operatorId, ver = $storeVersion]"
+      s"opId = $operatorId, ver = $storeVersion, numPartitions = $numPartitions]"
   }
 }
 
@@ -239,7 +240,7 @@ case class StateStoreRestoreExec(
     if (keyExpressions.isEmpty) {
       AllTuples :: Nil
     } else {
-      ClusteredDistribution(keyExpressions) :: Nil
+      ClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil
     }
   }
 }
@@ -386,7 +387,7 @@ case class StateStoreSaveExec(
     if (keyExpressions.isEmpty) {
       AllTuples :: Nil
     } else {
-      ClusteredDistribution(keyExpressions) :: Nil
+      ClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil
     }
   }
 }
@@ -401,7 +402,7 @@ case class StreamingDeduplicateExec(
 
   /** Distribute by grouping attributes */
   override def requiredChildDistribution: Seq[Distribution] =
-    ClusteredDistribution(keyExpressions) :: Nil
+    ClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil
 
   override protected def doExecute(): RDD[InternalRow] = {
     metrics // force lazy init at driver

http://git-wip-us.apache.org/repos/asf/spark/blob/e8547ffb/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index ad461fa..50de2fd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -368,6 +368,8 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     checkAnswer(
       testData.select('key).coalesce(1).select('key),
       testData.select('key).collect().toSeq)
+
+    assert(spark.emptyDataFrame.coalesce(1).rdd.partitions.size === 1)
   }
 
   test("convert $\"attribute name\" into unresolved attribute") {

http://git-wip-us.apache.org/repos/asf/spark/blob/e8547ffb/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 8606636..c25c90d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -425,6 +425,23 @@ class PlannerSuite extends SharedSQLContext {
     }
   }
 
+  test("EnsureRequirements should respect ClusteredDistribution's num partitioning") {
+    val distribution = ClusteredDistribution(Literal(1) :: Nil, Some(13))
+    // Number of partitions differ
+    val finalPartitioning = HashPartitioning(Literal(1) :: Nil, 13)
+    val childPartitioning = HashPartitioning(Literal(1) :: Nil, 5)
+    assert(!childPartitioning.satisfies(distribution))
+    val inputPlan = DummySparkPlan(
+        children = DummySparkPlan(outputPartitioning = childPartitioning) :: Nil,
+        requiredChildDistribution = Seq(distribution),
+        requiredChildOrdering = Seq(Seq.empty))
+
+    val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan)
+    val shuffle = outputPlan.collect { case e: ShuffleExchangeExec => e }
+    assert(shuffle.size === 1)
+    assert(shuffle.head.newPartitioning === finalPartitioning)
+  }
+
   test("Reuse exchanges") {
     val distribution = ClusteredDistribution(Literal(1) :: Nil)
     val finalPartitioning = HashPartitioning(Literal(1) :: Nil, 5)

http://git-wip-us.apache.org/repos/asf/spark/blob/e8547ffb/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala
index defb9ed..65b39f0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala
@@ -214,7 +214,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
       path: String,
       queryRunId: UUID = UUID.randomUUID,
       version: Int = 0): StatefulOperatorStateInfo = {
-    StatefulOperatorStateInfo(path, queryRunId, operatorId = 0, version)
+    StatefulOperatorStateInfo(path, queryRunId, operatorId = 0, version, numPartitions = 5)
   }
 
   private val increment = (store: StateStore, iter: Iterator[String]) => {

http://git-wip-us.apache.org/repos/asf/spark/blob/e8547ffb/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala
index d44af1d..c0216a2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala
@@ -160,7 +160,7 @@ class SymmetricHashJoinStateManagerSuite extends StreamTest with BeforeAndAfter
 
     withTempDir { file =>
       val storeConf = new StateStoreConf()
-      val stateInfo = StatefulOperatorStateInfo(file.getAbsolutePath, UUID.randomUUID, 0, 0)
+      val stateInfo = StatefulOperatorStateInfo(file.getAbsolutePath, UUID.randomUUID, 0, 0, 5)
       val manager = new SymmetricHashJoinStateManager(
         LeftSide, inputValueAttribs, joinKeyExprs, Some(stateInfo), storeConf, new Configuration)
       try {

http://git-wip-us.apache.org/repos/asf/spark/blob/e8547ffb/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala
index e858b7d..caf2bab 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala
@@ -19,12 +19,15 @@ package org.apache.spark.sql.streaming
 
 import org.scalatest.BeforeAndAfterAll
 
+import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, HashPartitioning, SinglePartition}
 import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
-import org.apache.spark.sql.execution.streaming.MemoryStream
+import org.apache.spark.sql.execution.streaming.{MemoryStream, StreamingDeduplicateExec}
 import org.apache.spark.sql.execution.streaming.state.StateStore
 import org.apache.spark.sql.functions._
 
-class DeduplicateSuite extends StateStoreMetricsTest with BeforeAndAfterAll {
+class DeduplicateSuite extends StateStoreMetricsTest
+    with BeforeAndAfterAll
+    with StatefulOperatorTest {
 
   import testImplicits._
 
@@ -41,6 +44,8 @@ class DeduplicateSuite extends StateStoreMetricsTest with BeforeAndAfterAll {
       AddData(inputData, "a"),
       CheckLastBatch("a"),
       assertNumStateRows(total = 1, updated = 1),
+      AssertOnQuery(sq =>
+        checkChildOutputHashPartitioning[StreamingDeduplicateExec](sq, Seq("value"))),
       AddData(inputData, "a"),
       CheckLastBatch(),
       assertNumStateRows(total = 1, updated = 0),
@@ -58,6 +63,8 @@ class DeduplicateSuite extends StateStoreMetricsTest with BeforeAndAfterAll {
       AddData(inputData, "a" -> 1),
       CheckLastBatch("a" -> 1),
       assertNumStateRows(total = 1, updated = 1),
+      AssertOnQuery(sq =>
+        checkChildOutputHashPartitioning[StreamingDeduplicateExec](sq, Seq("_1"))),
       AddData(inputData, "a" -> 2), // Dropped
       CheckLastBatch(),
       assertNumStateRows(total = 1, updated = 0),

http://git-wip-us.apache.org/repos/asf/spark/blob/e8547ffb/sql/core/src/test/scala/org/apache/spark/sql/streaming/EnsureStatefulOpPartitioningSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EnsureStatefulOpPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EnsureStatefulOpPartitioningSuite.scala
deleted file mode 100644
index ed9823f..0000000
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EnsureStatefulOpPartitioningSuite.scala
+++ /dev/null
@@ -1,138 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.streaming
-
-import java.util.UUID
-
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.DataFrame
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
-import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest, UnaryExecNode}
-import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchangeExec}
-import org.apache.spark.sql.execution.streaming.{IncrementalExecution, OffsetSeqMetadata, StatefulOperator, StatefulOperatorStateInfo}
-import org.apache.spark.sql.test.SharedSQLContext
-
-class EnsureStatefulOpPartitioningSuite extends SparkPlanTest with SharedSQLContext {
-
-  import testImplicits._
-
-  private var baseDf: DataFrame = null
-
-  override def beforeAll(): Unit = {
-    super.beforeAll()
-    baseDf = Seq((1, "A"), (2, "b")).toDF("num", "char")
-  }
-
-  test("ClusteredDistribution generates Exchange with HashPartitioning") {
-    testEnsureStatefulOpPartitioning(
-      baseDf.queryExecution.sparkPlan,
-      requiredDistribution = keys => ClusteredDistribution(keys),
-      expectedPartitioning =
-        keys => HashPartitioning(keys, spark.sessionState.conf.numShufflePartitions),
-      expectShuffle = true)
-  }
-
-  test("ClusteredDistribution with coalesce(1) generates Exchange with HashPartitioning") {
-    testEnsureStatefulOpPartitioning(
-      baseDf.coalesce(1).queryExecution.sparkPlan,
-      requiredDistribution = keys => ClusteredDistribution(keys),
-      expectedPartitioning =
-        keys => HashPartitioning(keys, spark.sessionState.conf.numShufflePartitions),
-      expectShuffle = true)
-  }
-
-  test("AllTuples generates Exchange with SinglePartition") {
-    testEnsureStatefulOpPartitioning(
-      baseDf.queryExecution.sparkPlan,
-      requiredDistribution = _ => AllTuples,
-      expectedPartitioning = _ => SinglePartition,
-      expectShuffle = true)
-  }
-
-  test("AllTuples with coalesce(1) doesn't need Exchange") {
-    testEnsureStatefulOpPartitioning(
-      baseDf.coalesce(1).queryExecution.sparkPlan,
-      requiredDistribution = _ => AllTuples,
-      expectedPartitioning = _ => SinglePartition,
-      expectShuffle = false)
-  }
-
-  /**
-   * For `StatefulOperator` with the given `requiredChildDistribution`, and child SparkPlan
-   * `inputPlan`, ensures that the incremental planner adds exchanges, if required, in order to
-   * ensure the expected partitioning.
-   */
-  private def testEnsureStatefulOpPartitioning(
-      inputPlan: SparkPlan,
-      requiredDistribution: Seq[Attribute] => Distribution,
-      expectedPartitioning: Seq[Attribute] => Partitioning,
-      expectShuffle: Boolean): Unit = {
-    val operator = TestStatefulOperator(inputPlan, requiredDistribution(inputPlan.output.take(1)))
-    val executed = executePlan(operator, OutputMode.Complete())
-    if (expectShuffle) {
-      val exchange = executed.children.find(_.isInstanceOf[Exchange])
-      if (exchange.isEmpty) {
-        fail(s"Was expecting an exchange but didn't get one in:\n$executed")
-      }
-      assert(exchange.get ===
-        ShuffleExchangeExec(expectedPartitioning(inputPlan.output.take(1)), inputPlan),
-        s"Exchange didn't have expected properties:\n${exchange.get}")
-    } else {
-      assert(!executed.children.exists(_.isInstanceOf[Exchange]),
-        s"Unexpected exchange found in:\n$executed")
-    }
-  }
-
-  /** Executes a SparkPlan using the IncrementalPlanner used for Structured Streaming. */
-  private def executePlan(
-      p: SparkPlan,
-      outputMode: OutputMode = OutputMode.Append()): SparkPlan = {
-    val execution = new IncrementalExecution(
-      spark,
-      null,
-      OutputMode.Complete(),
-      "chk",
-      UUID.randomUUID(),
-      0L,
-      OffsetSeqMetadata()) {
-      override lazy val sparkPlan: SparkPlan = p transform {
-        case plan: SparkPlan =>
-          val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap
-          plan transformExpressions {
-            case UnresolvedAttribute(Seq(u)) =>
-              inputMap.getOrElse(u,
-                sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap"))
-          }
-      }
-    }
-    execution.executedPlan
-  }
-}
-
-/** Used to emulate a `StatefulOperator` with the given requiredDistribution. */
-case class TestStatefulOperator(
-    child: SparkPlan,
-    requiredDist: Distribution) extends UnaryExecNode with StatefulOperator {
-  override def output: Seq[Attribute] = child.output
-  override def doExecute(): RDD[InternalRow] = child.execute()
-  override def requiredChildDistribution: Seq[Distribution] = requiredDist :: Nil
-  override def stateInfo: Option[StatefulOperatorStateInfo] = None
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/e8547ffb/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
index d2e8beb..aeb8383 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
@@ -41,7 +41,9 @@ case class RunningCount(count: Long)
 
 case class Result(key: Long, count: Int)
 
-class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterAll {
+class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest
+    with BeforeAndAfterAll
+    with StatefulOperatorTest {
 
   import testImplicits._
   import GroupStateImpl._
@@ -544,6 +546,8 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
       AddData(inputData, "a"),
       CheckLastBatch(("a", "1")),
       assertNumStateRows(total = 1, updated = 1),
+      AssertOnQuery(sq => checkChildOutputHashPartitioning[FlatMapGroupsWithStateExec](
+        sq, Seq("value"))),
       AddData(inputData, "a", "b"),
       CheckLastBatch(("a", "2"), ("b", "1")),
       assertNumStateRows(total = 2, updated = 2),

http://git-wip-us.apache.org/repos/asf/spark/blob/e8547ffb/sql/core/src/test/scala/org/apache/spark/sql/streaming/StatefulOperatorTest.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StatefulOperatorTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StatefulOperatorTest.scala
new file mode 100644
index 0000000..4514227
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StatefulOperatorTest.scala
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.streaming
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.execution.streaming._
+
+/**
+ * Test helpers that inspect how the children of a stateful streaming operator are partitioned
+ * in the last executed micro-batch of a query.
+ */
+trait StatefulOperatorTest {
+  /**
+   * Check that the children of the first operator of type `T` in the last executed plan have
+   * as many output partitions as a hash partitioning over `colNames` into the session's
+   * `numShufflePartitions` partitions.
+   *
+   * @param sq the query under test; must be backed by a [[StreamExecution]]
+   * @param colNames column names, looked up in the analyzed plan's output, to hash on
+   * @tparam T the stateful operator class to locate in the executed plan
+   * @return false if the executed plan contains no operator of type `T`
+   */
+  protected def checkChildOutputHashPartitioning[T <: StatefulOperator : ClassTag](
+      sq: StreamingQuery,
+      colNames: Seq[String]): Boolean = {
+    val attr = sq.asInstanceOf[StreamExecution].lastExecution.analyzed.output
+    val partitions = sq.sparkSession.sessionState.conf.numShufflePartitions
+    val groupingAttr = attr.filter(a => colNames.contains(a.name))
+    // Pass T explicitly; otherwise the helper's type parameter would be inferred as Nothing.
+    checkChildOutputPartitioning[T](sq, HashPartitioning(groupingAttr, partitions))
+  }
+
+  /**
+   * Check that every child of the first operator of type `T` in the last executed plan reports
+   * the same number of output partitions as `expectedPartitioning`.
+   *
+   * Only `numPartitions` is compared; the partitioning expressions themselves are not checked.
+   *
+   * @param sq the query under test; must be backed by a [[StreamExecution]]
+   * @param expectedPartitioning partitioning whose partition count the children must match
+   * @tparam T the stateful operator class to locate in the executed plan
+   * @return false if the executed plan contains no operator of type `T`
+   */
+  protected def checkChildOutputPartitioning[T <: StatefulOperator : ClassTag](
+      sq: StreamingQuery,
+      expectedPartitioning: Partitioning): Boolean = {
+    // The ClassTag turns `case p: T` into a real runtime type test. With an erased T the
+    // pattern matches every node, and the first hit would be the root of the plan.
+    val operators = sq.asInstanceOf[StreamExecution].lastExecution
+      .executedPlan.collect { case p: T => p }
+    operators.headOption.exists(_.children.forall(
+      _.outputPartitioning.numPartitions == expectedPartitioning.numPartitions))
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/e8547ffb/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
index fe7efa6..1b4d855 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
@@ -44,7 +44,7 @@ object FailureSingleton {
 }
 
 class StreamingAggregationSuite extends StateStoreMetricsTest
-    with BeforeAndAfterAll with Assertions {
+    with BeforeAndAfterAll with Assertions with StatefulOperatorTest {
 
   override def afterAll(): Unit = {
     super.afterAll()
@@ -281,6 +281,8 @@ class StreamingAggregationSuite extends StateStoreMetricsTest
       AddData(inputData, 0L, 5L, 5L, 10L),
       AdvanceManualClock(10 * 1000),
       CheckLastBatch((0L, 1), (5L, 2), (10L, 1)),
+      AssertOnQuery(sq =>
+        checkChildOutputHashPartitioning[StateStoreRestoreExec](sq, Seq("value"))),
 
       // advance clock to 20 seconds, should retain keys >= 10
       AddData(inputData, 15L, 15L, 20L),
@@ -455,8 +457,8 @@ class StreamingAggregationSuite extends StateStoreMetricsTest
         },
         AddBlockData(inputSource), // create an empty trigger
         CheckLastBatch(1),
-        AssertOnQuery("Verify addition of exchange operator") { se =>
-          checkAggregationChain(se, expectShuffling = true, 1)
+        AssertOnQuery("Verify that no exchange is required") { se =>
+          checkAggregationChain(se, expectShuffling = false, 1)
         },
         AddBlockData(inputSource, Seq(2, 3)),
         CheckLastBatch(3),

http://git-wip-us.apache.org/repos/asf/spark/blob/e8547ffb/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
index a6593b7..d326172 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
@@ -330,7 +330,7 @@ class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with
       val queryId = UUID.randomUUID
       val opId = 0
       val path = Utils.createDirectory(tempDir.getAbsolutePath, Random.nextString(10)).toString
-      val stateInfo = StatefulOperatorStateInfo(path, queryId, opId, 0L)
+      val stateInfo = StatefulOperatorStateInfo(path, queryId, opId, 0L, 5)
 
       implicit val sqlContext = spark.sqlContext
       val coordinatorRef = sqlContext.streams.stateStoreCoordinator

http://git-wip-us.apache.org/repos/asf/spark/blob/e8547ffb/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
index ab35079..c53889b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
@@ -652,6 +652,19 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi
     }
   }
 
+  test("SPARK-22238: don't check for RDD partitions during streaming aggregation preparation") {
+    val input = MemoryStream[(Int, Int)]
+    // Only the (1, "A") row passes the filter; this side is broadcast into the join.
+    val staticSide = Seq((1, "A"), (2, "b")).toDF("num", "char").where("char = 'A'")
+    val joined = input.toDF().toDF("num", "numSq").join(broadcast(staticSide), "num")
+    val aggregated = joined.groupBy('char).agg(sum('numSq))
+
+    testStream(aggregated, OutputMode.Complete())(
+      AddData(input, (1, 1), (2, 4)),
+      // Key 2 has no match on the static side, so only "A" (sum of numSq = 1) remains.
+      CheckLastBatch(("A", 1)))
+  }
+
   /** Create a streaming DF that only execute one batch in which it returns the given static DF */
   private def createSingleTriggerStreamingDF(triggerDF: DataFrame): DataFrame = {
     require(!triggerDF.isStreaming)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org