Posted to commits@spark.apache.org by td...@apache.org on 2018/05/21 19:58:11 UTC
spark git commit: [SPARK-24234][SS] Reader for continuous processing shuffle
Repository: spark
Updated Branches:
refs/heads/master 03e90f65b -> a33dcf4a0
[SPARK-24234][SS] Reader for continuous processing shuffle
## What changes were proposed in this pull request?
Adds the read RDD for continuous processing shuffle, as well as the initial RPC-based row receiver.
https://docs.google.com/document/d/1IL4kJoKrZWeyIhklKUJqsW-yEN7V7aL05MmM65AYOfE/edit#heading=h.8t3ci57f7uii
## How was this patch tested?
new unit tests
Author: Jose Torres <to...@gmail.com>
Closes #21337 from jose-torres/readerRddMaster.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/a33dcf4a
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/a33dcf4a
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/a33dcf4a
Branch: refs/heads/master
Commit: a33dcf4a0bbe20dce6f1e1e6c2e1c3828291fb3d
Parents: 03e90f6
Author: Jose Torres <to...@gmail.com>
Authored: Mon May 21 12:58:05 2018 -0700
Committer: Tathagata Das <ta...@gmail.com>
Committed: Mon May 21 12:58:05 2018 -0700
----------------------------------------------------------------------
.../shuffle/ContinuousShuffleReadRDD.scala | 61 ++++++
.../shuffle/ContinuousShuffleReader.scala | 32 ++++
.../continuous/shuffle/UnsafeRowReceiver.scala | 75 ++++++++
.../shuffle/ContinuousShuffleReadSuite.scala | 184 +++++++++++++++++++
4 files changed, 352 insertions(+)
----------------------------------------------------------------------
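The protocol underneath this patch is simple: a writer enqueues rows, then an epoch marker; the reader drains the queue until it sees the marker. Before the diffs, a minimal, self-contained model of that flow (a plain JDK queue standing in for Spark RPC; all names here are illustrative, not from the patch):

    import java.util.concurrent.ArrayBlockingQueue

    sealed trait Msg
    case class Row(value: Int) extends Msg
    case object EpochMarker extends Msg

    object ProtocolSketch {
      def main(args: Array[String]): Unit = {
        val queue = new ArrayBlockingQueue[Msg](1024)
        // Writer side: two rows, then the end-of-epoch marker.
        Seq(Row(111), Row(222), EpochMarker).foreach(queue.put)
        // Reader side: drain until the marker, as read() below does via NextIterator.
        Iterator.continually(queue.take())
          .takeWhile(_ != EpochMarker)
          .foreach { case Row(v) => println(v); case _ => }
      }
    }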
http://git-wip-us.apache.org/repos/asf/spark/blob/a33dcf4a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala
new file mode 100644
index 0000000..270b1a5
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.continuous.shuffle
+
+import java.util.UUID
+
+import org.apache.spark.{Partition, SparkContext, SparkEnv, TaskContext}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.util.NextIterator
+
+case class ContinuousShuffleReadPartition(index: Int, queueSize: Int) extends Partition {
+ // Initialized only on the executor, and only once, even if compute() is called multiple times.
+ lazy val (reader: ContinuousShuffleReader, endpoint) = {
+ val env = SparkEnv.get.rpcEnv
+ val receiver = new UnsafeRowReceiver(queueSize, env)
+ val endpoint = env.setupEndpoint(s"UnsafeRowReceiver-${UUID.randomUUID()}", receiver)
+ TaskContext.get().addTaskCompletionListener { ctx =>
+ env.stop(endpoint)
+ }
+ (receiver, endpoint)
+ }
+}
+
+/**
+ * RDD at the reduce side of each continuous processing shuffle task. Upstream tasks send
+ * their shuffle output to the wrapped receivers in partitions of this RDD; each of the
+ * RDD's tasks polls from its receiver until an epoch marker is sent.
+ */
+class ContinuousShuffleReadRDD(
+ sc: SparkContext,
+ numPartitions: Int,
+ queueSize: Int = 1024)
+ extends RDD[UnsafeRow](sc, Nil) {
+
+ override protected def getPartitions: Array[Partition] = {
+ (0 until numPartitions).map { partIndex =>
+ ContinuousShuffleReadPartition(partIndex, queueSize)
+ }.toArray
+ }
+
+ override def compute(split: Partition, context: TaskContext): Iterator[UnsafeRow] = {
+ split.asInstanceOf[ContinuousShuffleReadPartition].reader.read()
+ }
+}
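A usage sketch for this RDD, mirroring the test suite at the end of this commit (it assumes a live SparkContext and a TaskContext already set on the calling thread, which the suite's beforeEach arranges; unsafeRow is the suite's helper):

    // Drive one partition end to end: push rows plus a marker, then compute the epoch.
    val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1)
    val part = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition]
    part.endpoint.askSync[Unit](ReceiverRow(unsafeRow(42)))
    part.endpoint.askSync[Unit](ReceiverEpochMarker())
    val epoch = rdd.compute(rdd.partitions(0), TaskContext.get())
    assert(epoch.map(_.getInt(0)).toSeq == Seq(42))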
http://git-wip-us.apache.org/repos/asf/spark/blob/a33dcf4a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReader.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReader.scala
new file mode 100644
index 0000000..42631c9
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReader.scala
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.continuous.shuffle
+
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+
+/**
+ * Trait for reading from a continuous processing shuffle.
+ */
+trait ContinuousShuffleReader {
+ /**
+ * Returns an iterator over the incoming rows in an epoch. Implementations should block waiting
+ * for new rows to arrive, and end the iterator once they've received epoch markers from all
+ * shuffle writers.
+ */
+ def read(): Iterator[UnsafeRow]
+}
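The trait also admits non-RPC implementations. A hypothetical in-memory reader (not part of this patch) that replays fixed epochs, e.g. for unit-testing consumers of the trait without standing up an endpoint:

    import scala.collection.mutable
    import org.apache.spark.sql.catalyst.expressions.UnsafeRow

    // Replays a fixed sequence of epochs; each read() call returns the next one.
    class FixedEpochsReader(epochs: Seq[Seq[UnsafeRow]]) extends ContinuousShuffleReader {
      private val remaining = mutable.Queue(epochs: _*)
      override def read(): Iterator[UnsafeRow] =
        if (remaining.isEmpty) Iterator.empty else remaining.dequeue().iterator
    }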
http://git-wip-us.apache.org/repos/asf/spark/blob/a33dcf4a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala
new file mode 100644
index 0000000..b8adbb7
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.continuous.shuffle
+
+import java.util.concurrent.{ArrayBlockingQueue, BlockingQueue}
+import java.util.concurrent.atomic.AtomicBoolean
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint}
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.util.NextIterator
+
+/**
+ * Messages for the UnsafeRowReceiver endpoint. Either an incoming row or an epoch marker.
+ */
+private[shuffle] sealed trait UnsafeRowReceiverMessage extends Serializable
+private[shuffle] case class ReceiverRow(row: UnsafeRow) extends UnsafeRowReceiverMessage
+private[shuffle] case class ReceiverEpochMarker() extends UnsafeRowReceiverMessage
+
+/**
+ * RPC endpoint for receiving rows into a continuous processing shuffle task. Continuous shuffle
+ * writers will send rows here, with continuous shuffle readers polling for new rows as needed.
+ *
+ * TODO: Support multiple source tasks. We need to output a single epoch marker once all
+ * source tasks have sent one.
+ */
+private[shuffle] class UnsafeRowReceiver(
+ queueSize: Int,
+ override val rpcEnv: RpcEnv)
+ extends ThreadSafeRpcEndpoint with ContinuousShuffleReader with Logging {
+ // Note that this queue will be drained from the main task thread and populated in the RPC
+ // response thread.
+ private val queue = new ArrayBlockingQueue[UnsafeRowReceiverMessage](queueSize)
+
+ // Exposed for testing to determine if the endpoint gets stopped on task end.
+ private[shuffle] val stopped = new AtomicBoolean(false)
+
+ override def onStop(): Unit = {
+ stopped.set(true)
+ }
+
+ override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
+ case r: UnsafeRowReceiverMessage =>
+ queue.put(r)
+ context.reply(())
+ }
+
+ override def read(): Iterator[UnsafeRow] = {
+ new NextIterator[UnsafeRow] {
+ override def getNext(): UnsafeRow = queue.take() match {
+ case ReceiverRow(r) => r
+ case ReceiverEpochMarker() =>
+ finished = true
+ null
+ }
+
+ override def close(): Unit = {}
+ }
+ }
+}
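One possible shape for the multi-writer TODO above, sketched against this read(): count down epoch markers and finish only when every writer has reported. writerCount is a hypothetical constructor parameter, not part of this patch:

    // Illustrative only: pass rows through unchanged, and emit a single
    // end-of-epoch once markers have arrived from all writers.
    override def read(): Iterator[UnsafeRow] = {
      new NextIterator[UnsafeRow] {
        private var markersOutstanding = writerCount  // hypothetical field
        override def getNext(): UnsafeRow = queue.take() match {
          case ReceiverRow(r) => r
          case ReceiverEpochMarker() =>
            markersOutstanding -= 1
            if (markersOutstanding == 0) {
              finished = true
              null
            } else {
              getNext()  // swallow intermediate markers until every writer has reported
            }
        }
        override def close(): Unit = {}
      }
    }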
http://git-wip-us.apache.org/repos/asf/spark/blob/a33dcf4a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala
new file mode 100644
index 0000000..b25e75b
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala
@@ -0,0 +1,184 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.continuous.shuffle
+
+import org.apache.spark.{TaskContext, TaskContextImpl}
+import org.apache.spark.rpc.RpcEndpointRef
+import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection}
+import org.apache.spark.sql.streaming.StreamTest
+import org.apache.spark.sql.types.{DataType, IntegerType}
+
+class ContinuousShuffleReadSuite extends StreamTest {
+
+ private def unsafeRow(value: Int) = {
+ UnsafeProjection.create(Array(IntegerType : DataType))(
+ new GenericInternalRow(Array(value: Any)))
+ }
+
+ private def send(endpoint: RpcEndpointRef, messages: UnsafeRowReceiverMessage*) = {
+ messages.foreach(endpoint.askSync[Unit](_))
+ }
+
+ // In these tests, we emulate the task thread on which
+ // ContinuousShuffleReadRDD.compute() is evaluated. This requires the TaskContext
+ // thread local to be set.
+ var ctx: TaskContextImpl = _
+
+ override def beforeEach(): Unit = {
+ super.beforeEach()
+ ctx = TaskContext.empty()
+ TaskContext.setTaskContext(ctx)
+ }
+
+ override def afterEach(): Unit = {
+ ctx.markTaskCompleted(None)
+ TaskContext.unset()
+ ctx = null
+ super.afterEach()
+ }
+
+ test("receiver stopped with row last") {
+ val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1)
+ val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint
+ send(
+ endpoint,
+ ReceiverEpochMarker(),
+ ReceiverRow(unsafeRow(111))
+ )
+
+ ctx.markTaskCompleted(None)
+ val receiver = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].reader
+ eventually(timeout(streamingTimeout)) {
+ assert(receiver.asInstanceOf[UnsafeRowReceiver].stopped.get())
+ }
+ }
+
+ test("receiver stopped with marker last") {
+ val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1)
+ val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint
+ endpoint.askSync[Unit](ReceiverRow(unsafeRow(111)))
+ endpoint.askSync[Unit](ReceiverEpochMarker())
+
+ ctx.markTaskCompleted(None)
+ val receiver = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].reader
+ eventually(timeout(streamingTimeout)) {
+ assert(receiver.asInstanceOf[UnsafeRowReceiver].stopped.get())
+ }
+ }
+
+ test("one epoch") {
+ val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1)
+ val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint
+ send(
+ endpoint,
+ ReceiverRow(unsafeRow(111)),
+ ReceiverRow(unsafeRow(222)),
+ ReceiverRow(unsafeRow(333)),
+ ReceiverEpochMarker()
+ )
+
+ val iter = rdd.compute(rdd.partitions(0), ctx)
+ assert(iter.toSeq.map(_.getInt(0)) == Seq(111, 222, 333))
+ }
+
+ test("multiple epochs") {
+ val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1)
+ val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint
+ send(
+ endpoint,
+ ReceiverRow(unsafeRow(111)),
+ ReceiverEpochMarker(),
+ ReceiverRow(unsafeRow(222)),
+ ReceiverRow(unsafeRow(333)),
+ ReceiverEpochMarker()
+ )
+
+ val firstEpoch = rdd.compute(rdd.partitions(0), ctx)
+ assert(firstEpoch.toSeq.map(_.getInt(0)) == Seq(111))
+
+ val secondEpoch = rdd.compute(rdd.partitions(0), ctx)
+ assert(secondEpoch.toSeq.map(_.getInt(0)) == Seq(222, 333))
+ }
+
+ test("empty epochs") {
+ val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1)
+ val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint
+ send(
+ endpoint,
+ ReceiverEpochMarker(),
+ ReceiverEpochMarker(),
+ ReceiverRow(unsafeRow(111)),
+ ReceiverEpochMarker(),
+ ReceiverEpochMarker(),
+ ReceiverEpochMarker()
+ )
+
+ assert(rdd.compute(rdd.partitions(0), ctx).isEmpty)
+ assert(rdd.compute(rdd.partitions(0), ctx).isEmpty)
+
+ val thirdEpoch = rdd.compute(rdd.partitions(0), ctx)
+ assert(thirdEpoch.toSeq.map(_.getInt(0)) == Seq(111))
+
+ assert(rdd.compute(rdd.partitions(0), ctx).isEmpty)
+ assert(rdd.compute(rdd.partitions(0), ctx).isEmpty)
+ }
+
+ test("multiple partitions") {
+ val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 5)
+ // Send all data before processing to ensure there's no crossover.
+ for (p <- rdd.partitions) {
+ val part = p.asInstanceOf[ContinuousShuffleReadPartition]
+ // Send index for identification.
+ send(
+ part.endpoint,
+ ReceiverRow(unsafeRow(part.index)),
+ ReceiverEpochMarker()
+ )
+ }
+
+ for (p <- rdd.partitions) {
+ val part = p.asInstanceOf[ContinuousShuffleReadPartition]
+ val iter = rdd.compute(part, ctx)
+ assert(iter.next().getInt(0) == part.index)
+ assert(!iter.hasNext)
+ }
+ }
+
+ test("blocks waiting for new rows") {
+ val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1)
+
+ val readRowThread = new Thread {
+ override def run(): Unit = {
+ // Set the non-inheritable TaskContext thread local on this new thread.
+ TaskContext.setTaskContext(ctx)
+ val epoch = rdd.compute(rdd.partitions(0), ctx)
+ epoch.next().getInt(0)
+ }
+ }
+
+ try {
+ readRowThread.start()
+ eventually(timeout(streamingTimeout)) {
+ assert(readRowThread.getState == Thread.State.WAITING)
+ }
+ } finally {
+ readRowThread.interrupt()
+ readRowThread.join()
+ }
+ }
+}
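The write path is left to a follow-up, but given the receiveAndReply contract in UnsafeRowReceiver, a sender could look roughly like the sketch below (illustrative only; writeEpoch and its shape are assumptions, not part of this patch):

    import org.apache.spark.rpc.RpcEndpointRef
    import org.apache.spark.sql.catalyst.expressions.UnsafeRow

    // Sketch of a writer-side helper. askSync blocks until the receiver replies,
    // which also gives natural backpressure against the receiver's bounded queue.
    def writeEpoch(endpoint: RpcEndpointRef, rows: Iterator[UnsafeRow]): Unit = {
      rows.foreach { r => endpoint.askSync[Unit](ReceiverRow(r)) }
      endpoint.askSync[Unit](ReceiverEpochMarker())
    }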