Posted to commits@spark.apache.org by rx...@apache.org on 2014/09/09 00:59:24 UTC
[1/5] [SPARK-3019] Pluggable block transfer interface (BlockTransferService)
Repository: spark
Updated Branches:
refs/heads/master 939a322c8 -> 08ce18881
http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
index c200654..e251660 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
@@ -21,15 +21,19 @@ import java.nio.{ByteBuffer, MappedByteBuffer}
import java.util.Arrays
import java.util.concurrent.TimeUnit
+import org.apache.spark.network.nio.NioBlockTransferService
+
+import scala.collection.mutable.ArrayBuffer
+import scala.concurrent.Await
+import scala.concurrent.duration._
+import scala.language.implicitConversions
+import scala.language.postfixOps
+
import akka.actor._
import akka.pattern.ask
import akka.util.Timeout
-import org.apache.spark.shuffle.hash.HashShuffleManager
-import org.mockito.invocation.InvocationOnMock
-import org.mockito.Matchers.any
-import org.mockito.Mockito.{doAnswer, mock, spy, when}
-import org.mockito.stubbing.Answer
+import org.mockito.Mockito.{mock, when}
import org.scalatest.{BeforeAndAfter, FunSuite, PrivateMethodTester}
import org.scalatest.concurrent.Eventually._
@@ -38,18 +42,12 @@ import org.scalatest.Matchers
import org.apache.spark.{MapOutputTrackerMaster, SecurityManager, SparkConf}
import org.apache.spark.executor.DataReadMethod
-import org.apache.spark.network.{Message, ConnectionManagerId}
import org.apache.spark.scheduler.LiveListenerBus
import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}
+import org.apache.spark.shuffle.hash.HashShuffleManager
import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat
import org.apache.spark.util.{AkkaUtils, ByteBufferInputStream, SizeEstimator, Utils}
-import scala.collection.mutable.ArrayBuffer
-import scala.concurrent.Await
-import scala.concurrent.duration._
-import scala.language.implicitConversions
-import scala.language.postfixOps
-import org.apache.spark.shuffle.ShuffleBlockManager
class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter
with PrivateMethodTester {
@@ -74,8 +72,9 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter
def rdd(rddId: Int, splitId: Int) = RDDBlockId(rddId, splitId)
private def makeBlockManager(maxMem: Long, name: String = "<driver>"): BlockManager = {
- new BlockManager(name, actorSystem, master, serializer, maxMem, conf, securityMgr,
- mapOutputTracker, shuffleManager)
+ val transfer = new NioBlockTransferService(conf, securityMgr)
+ new BlockManager(name, actorSystem, master, serializer, maxMem, conf,
+ mapOutputTracker, shuffleManager, transfer)
}
before {
@@ -793,8 +792,9 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter
test("block store put failure") {
// Use Java serializer so we can create an unserializable error.
+ val transfer = new NioBlockTransferService(conf, securityMgr)
store = new BlockManager("<driver>", actorSystem, master, new JavaSerializer(conf), 1200, conf,
- securityMgr, mapOutputTracker, shuffleManager)
+ mapOutputTracker, shuffleManager, transfer)
// The put should fail since a1 is not serializable.
class UnserializableClass
@@ -1005,109 +1005,6 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter
assert(!store.memoryStore.contains(rdd(1, 0)), "rdd_1_0 was in store")
}
- test("return error message when error occurred in BlockManagerWorker#onBlockMessageReceive") {
- store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf,
- securityMgr, mapOutputTracker, shuffleManager)
-
- val worker = spy(new BlockManagerWorker(store))
- val connManagerId = mock(classOf[ConnectionManagerId])
-
- // setup request block messages
- val reqBlId1 = ShuffleBlockId(0,0,0)
- val reqBlId2 = ShuffleBlockId(0,1,0)
- val reqBlockMessage1 = BlockMessage.fromGetBlock(GetBlock(reqBlId1))
- val reqBlockMessage2 = BlockMessage.fromGetBlock(GetBlock(reqBlId2))
- val reqBlockMessages = new BlockMessageArray(
- Seq(reqBlockMessage1, reqBlockMessage2))
- val reqBufferMessage = reqBlockMessages.toBufferMessage
-
- val answer = new Answer[Option[BlockMessage]] {
- override def answer(invocation: InvocationOnMock)
- :Option[BlockMessage]= {
- throw new Exception
- }
- }
-
- doAnswer(answer).when(worker).processBlockMessage(any())
-
- // Test when exception was thrown during processing block messages
- var ackMessage = worker.onBlockMessageReceive(reqBufferMessage, connManagerId)
-
- assert(ackMessage.isDefined, "When Exception was thrown in " +
- "BlockManagerWorker#processBlockMessage, " +
- "ackMessage should be defined")
- assert(ackMessage.get.hasError, "When Exception was thrown in " +
- "BlockManagerWorker#processBlockMessage, " +
- "ackMessage should have an error")
-
- val notBufferMessage = mock(classOf[Message])
-
- // Test when a non-BufferMessage was received
- ackMessage = worker.onBlockMessageReceive(notBufferMessage, connManagerId)
- assert(ackMessage.isDefined, "When a non-BufferMessage was passed to " +
- "BlockManagerWorker#onBlockMessageReceive, " +
- "ackMessage should be defined")
- assert(ackMessage.get.hasError, "When a non-BufferMessage was passed to " +
- "BlockManagerWorker#onBlockMessageReceive, " +
- "ackMessage should have an error")
- }
-
- test("return ack message when no error occurred in BlocManagerWorker#onBlockMessageReceive") {
- store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf,
- securityMgr, mapOutputTracker, shuffleManager)
-
- val worker = spy(new BlockManagerWorker(store))
- val connManagerId = mock(classOf[ConnectionManagerId])
-
- // setup request block messages
- val reqBlId1 = ShuffleBlockId(0,0,0)
- val reqBlId2 = ShuffleBlockId(0,1,0)
- val reqBlockMessage1 = BlockMessage.fromGetBlock(GetBlock(reqBlId1))
- val reqBlockMessage2 = BlockMessage.fromGetBlock(GetBlock(reqBlId2))
- val reqBlockMessages = new BlockMessageArray(
- Seq(reqBlockMessage1, reqBlockMessage2))
-
- val tmpBufferMessage = reqBlockMessages.toBufferMessage
- val buffer = ByteBuffer.allocate(tmpBufferMessage.size)
- val arrayBuffer = new ArrayBuffer[ByteBuffer]
- tmpBufferMessage.buffers.foreach{ b =>
- buffer.put(b)
- }
- buffer.flip()
- arrayBuffer += buffer
- val reqBufferMessage = Message.createBufferMessage(arrayBuffer)
-
- // setup ack block messages
- val buf1 = ByteBuffer.allocate(4)
- val buf2 = ByteBuffer.allocate(4)
- buf1.putInt(1)
- buf1.flip()
- buf2.putInt(1)
- buf2.flip()
- val ackBlockMessage1 = BlockMessage.fromGotBlock(GotBlock(reqBlId1, buf1))
- val ackBlockMessage2 = BlockMessage.fromGotBlock(GotBlock(reqBlId2, buf2))
-
- val answer = new Answer[Option[BlockMessage]] {
- override def answer(invocation: InvocationOnMock)
- :Option[BlockMessage]= {
- if (invocation.getArguments()(0).asInstanceOf[BlockMessage].eq(
- reqBlockMessage1)) {
- return Some(ackBlockMessage1)
- } else {
- return Some(ackBlockMessage2)
- }
- }
- }
-
- doAnswer(answer).when(worker).processBlockMessage(any())
-
- val ackMessage = worker.onBlockMessageReceive(reqBufferMessage, connManagerId)
- assert(ackMessage.isDefined, "When BlockManagerWorker#onBlockMessageReceive " +
- "was executed successfully, ackMessage should be defined")
- assert(!ackMessage.get.hasError, "When BlockManagerWorker#onBlockMessageReceive " +
- "was executed successfully, ackMessage should not have error")
- }
-
test("reserve/release unroll memory") {
store = makeBlockManager(12000)
val memoryStore = store.memoryStore
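
The hunks above capture the core API change in this patch: BlockManager no longer takes a
SecurityManager for block transport; callers now pass a pluggable BlockTransferService, with
NioBlockTransferService as the NIO-backed implementation. A minimal sketch of the new wiring,
assuming the suite's actorSystem, master, serializer, conf, securityMgr, mapOutputTracker and
shuffleManager fields are in scope as set up in before():

    // securityMgr now feeds the transfer service instead of the BlockManager itself,
    // and the transfer service is appended as the last constructor argument.
    val transfer = new NioBlockTransferService(conf, securityMgr)
    val store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf,
      mapOutputTracker, shuffleManager, transfer)
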
http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala
index 26082de..e4522e0 100644
--- a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.storage
import java.io.{File, FileWriter}
+import org.apache.spark.network.nio.NioBlockTransferService
import org.apache.spark.shuffle.hash.HashShuffleManager
import scala.collection.mutable
@@ -52,7 +53,6 @@ class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach with Before
rootDir1 = Files.createTempDir()
rootDir1.deleteOnExit()
rootDirs = rootDir0.getAbsolutePath + "," + rootDir1.getAbsolutePath
- println("Created root dirs: " + rootDirs)
}
override def afterAll() {
http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
new file mode 100644
index 0000000..809bd70
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
@@ -0,0 +1,183 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import org.apache.spark.TaskContext
+import org.apache.spark.network.{BlockFetchingListener, BlockTransferService}
+
+import org.mockito.Mockito._
+import org.mockito.Matchers.{any, eq => meq}
+import org.mockito.invocation.InvocationOnMock
+import org.mockito.stubbing.Answer
+
+import org.scalatest.FunSuite
+
+
+class ShuffleBlockFetcherIteratorSuite extends FunSuite {
+
+ test("handle local read failures in BlockManager") {
+ val transfer = mock(classOf[BlockTransferService])
+ val blockManager = mock(classOf[BlockManager])
+ doReturn(BlockManagerId("test-client", "test-client", 1)).when(blockManager).blockManagerId
+
+ val blIds = Array[BlockId](
+ ShuffleBlockId(0,0,0),
+ ShuffleBlockId(0,1,0),
+ ShuffleBlockId(0,2,0),
+ ShuffleBlockId(0,3,0),
+ ShuffleBlockId(0,4,0))
+
+ val optItr = mock(classOf[Option[Iterator[Any]]])
+ val answer = new Answer[Option[Iterator[Any]]] {
+ override def answer(invocation: InvocationOnMock) = Option[Iterator[Any]] {
+ throw new Exception
+ }
+ }
+
+ // 3rd block is going to fail
+ doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(0)), any())
+ doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(1)), any())
+ doAnswer(answer).when(blockManager).getLocalShuffleFromDisk(meq(blIds(2)), any())
+ doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(3)), any())
+ doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(4)), any())
+
+ val bmId = BlockManagerId("test-client", "test-client", 1)
+ val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
+ (bmId, blIds.map(blId => (blId, 1.asInstanceOf[Long])).toSeq)
+ )
+
+ val iterator = new ShuffleBlockFetcherIterator(
+ new TaskContext(0, 0, 0),
+ transfer,
+ blockManager,
+ blocksByAddress,
+ null,
+ 48 * 1024 * 1024)
+
+ // Without exhausting the iterator, the iterator should be lazy and not call
+ // getLocalShuffleFromDisk.
+ verify(blockManager, times(0)).getLocalShuffleFromDisk(any(), any())
+
+ assert(iterator.hasNext, "iterator should have 5 elements but actually has no elements")
+ // The 2nd element of the tuple returned by iterator.next should be defined when
+ // the fetch succeeds.
+ assert(iterator.next()._2.isDefined,
+ "1st element should be defined but is not actually defined")
+ verify(blockManager, times(1)).getLocalShuffleFromDisk(any(), any())
+
+ assert(iterator.hasNext, "iterator should have 5 elements but actually has 1 element")
+ assert(iterator.next()._2.isDefined,
+ "2nd element should be defined but is not actually defined")
+ verify(blockManager, times(2)).getLocalShuffleFromDisk(any(), any())
+
+ assert(iterator.hasNext, "iterator should have 5 elements but actually has 2 elements")
+ // The 3rd fetch should fail
+ intercept[Exception] {
+ iterator.next()
+ }
+ verify(blockManager, times(3)).getLocalShuffleFromDisk(any(), any())
+ }
+
+ test("handle local read successes") {
+ val transfer = mock(classOf[BlockTransferService])
+ val blockManager = mock(classOf[BlockManager])
+ doReturn(BlockManagerId("test-client", "test-client", 1)).when(blockManager).blockManagerId
+
+ val blIds = Array[BlockId](
+ ShuffleBlockId(0,0,0),
+ ShuffleBlockId(0,1,0),
+ ShuffleBlockId(0,2,0),
+ ShuffleBlockId(0,3,0),
+ ShuffleBlockId(0,4,0))
+
+ val optItr = mock(classOf[Option[Iterator[Any]]])
+
+ // All blocks should be fetched successfully
+ doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(0)), any())
+ doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(1)), any())
+ doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(2)), any())
+ doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(3)), any())
+ doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(4)), any())
+
+ val bmId = BlockManagerId("test-client", "test-client", 1)
+ val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
+ (bmId, blIds.map(blId => (blId, 1.asInstanceOf[Long])).toSeq)
+ )
+
+ val iterator = new ShuffleBlockFetcherIterator(
+ new TaskContext(0, 0, 0),
+ transfer,
+ blockManager,
+ blocksByAddress,
+ null,
+ 48 * 1024 * 1024)
+
+ // Without exhausting the iterator, the iterator should be lazy and not call
+ // getLocalShuffleFromDisk.
+ verify(blockManager, times(0)).getLocalShuffleFromDisk(any(), any())
+
+ assert(iterator.hasNext, "iterator should have 5 elements but actually has no elements")
+ assert(iterator.next()._2.isDefined,
+ "All elements should be defined but 1st element is not actually defined")
+ assert(iterator.hasNext, "iterator should have 5 elements but actually has 1 element")
+ assert(iterator.next()._2.isDefined,
+ "All elements should be defined but 2nd element is not actually defined")
+ assert(iterator.hasNext, "iterator should have 5 elements but actually has 2 elements")
+ assert(iterator.next()._2.isDefined,
+ "All elements should be defined but 3rd element is not actually defined")
+ assert(iterator.hasNext, "iterator should have 5 elements but actually has 3 elements")
+ assert(iterator.next()._2.isDefined,
+ "All elements should be defined but 4th element is not actually defined")
+ assert(iterator.hasNext, "iterator should have 5 elements but actually has 4 elements")
+ assert(iterator.next()._2.isDefined,
+ "All elements should be defined but 5th element is not actually defined")
+
+ verify(blockManager, times(5)).getLocalShuffleFromDisk(any(), any())
+ }
+
+ test("handle remote fetch failures in BlockTransferService") {
+ val transfer = mock(classOf[BlockTransferService])
+ when(transfer.fetchBlocks(any(), any(), any(), any())).thenAnswer(new Answer[Unit] {
+ override def answer(invocation: InvocationOnMock): Unit = {
+ val listener = invocation.getArguments()(3).asInstanceOf[BlockFetchingListener]
+ listener.onBlockFetchFailure(new Exception("blah"))
+ }
+ })
+
+ val blockManager = mock(classOf[BlockManager])
+
+ when(blockManager.blockManagerId).thenReturn(BlockManagerId("test-client", "test-client", 1))
+
+ val blId1 = ShuffleBlockId(0, 0, 0)
+ val blId2 = ShuffleBlockId(0, 1, 0)
+ val bmId = BlockManagerId("test-server", "test-server", 1)
+ val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
+ (bmId, Seq((blId1, 1L), (blId2, 1L))))
+
+ val iterator = new ShuffleBlockFetcherIterator(
+ new TaskContext(0, 0, 0),
+ transfer,
+ blockManager,
+ blocksByAddress,
+ null,
+ 48 * 1024 * 1024)
+
+ iterator.foreach { case (_, iterOption) =>
+ assert(!iterOption.isDefined)
+ }
+ }
+}
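
Taken together, these tests pin down the iterator's contract: construction is lazy (no
getLocalShuffleFromDisk calls until next()), each element is a (BlockId, Option[Iterator[Any]])
pair, a failed remote fetch surfaces as None, and a failed local read throws from next(). A
sketch of consuming it under that contract (construction elided; see the tests above):

    // Hypothetical consumer; `iterator` is a ShuffleBlockFetcherIterator built as above.
    iterator.foreach { case (blockId, dataOption) =>
      dataOption match {
        case Some(records) => records.foreach(r => ()) // process each record
        case None => // the fetch for blockId failed; the caller decides how to recover
      }
    }
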
[4/5] [SPARK-3019] Pluggable block transfer interface (BlockTransferService)
Posted by rx...@apache.org.
http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/network/ConnectionManagerTest.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManagerTest.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManagerTest.scala
deleted file mode 100644
index 4894ecd..0000000
--- a/core/src/main/scala/org/apache/spark/network/ConnectionManagerTest.scala
+++ /dev/null
@@ -1,103 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network
-
-import java.nio.ByteBuffer
-
-import scala.concurrent.Await
-import scala.concurrent.duration._
-import scala.io.Source
-
-import org.apache.spark._
-
-private[spark] object ConnectionManagerTest extends Logging{
- def main(args: Array[String]) {
- // <mesos cluster> - the master URL <slaves file> - a list slaves to run connectionTest on
- // [num of tasks] - the number of parallel tasks to be initiated default is number of slave
- // hosts [size of msg in MB (integer)] - the size of messages to be sent in each task,
- // default is 10 [count] - how many times to run, default is 3 [await time in seconds] :
- // await time (in seconds), default is 600
- if (args.length < 2) {
- println("Usage: ConnectionManagerTest <mesos cluster> <slaves file> [num of tasks] " +
- "[size of msg in MB (integer)] [count] [await time in seconds)] ")
- System.exit(1)
- }
-
- if (args(0).startsWith("local")) {
- println("This runs only on a mesos cluster")
- }
-
- val sc = new SparkContext(args(0), "ConnectionManagerTest")
- val slavesFile = Source.fromFile(args(1))
- val slaves = slavesFile.mkString.split("\n")
- slavesFile.close()
-
- /* println("Slaves") */
- /* slaves.foreach(println) */
- val tasknum = if (args.length > 2) args(2).toInt else slaves.length
- val size = ( if (args.length > 3) (args(3).toInt) else 10 ) * 1024 * 1024
- val count = if (args.length > 4) args(4).toInt else 3
- val awaitTime = (if (args.length > 5) args(5).toInt else 600 ).second
- println("Running " + count + " rounds of test: " + "parallel tasks = " + tasknum + ", " +
- "msg size = " + size/1024/1024 + " MB, awaitTime = " + awaitTime)
- val slaveConnManagerIds = sc.parallelize(0 until tasknum, tasknum).map(
- i => SparkEnv.get.connectionManager.id).collect()
- println("\nSlave ConnectionManagerIds")
- slaveConnManagerIds.foreach(println)
- println
-
- (0 until count).foreach(i => {
- val resultStrs = sc.parallelize(0 until tasknum, tasknum).map(i => {
- val connManager = SparkEnv.get.connectionManager
- val thisConnManagerId = connManager.id
- connManager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
- logInfo("Received [" + msg + "] from [" + id + "]")
- None
- })
-
- val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
- buffer.flip
-
- val startTime = System.currentTimeMillis
- val futures = slaveConnManagerIds.filter(_ != thisConnManagerId).map{ slaveConnManagerId =>
- {
- val bufferMessage = Message.createBufferMessage(buffer.duplicate)
- logInfo("Sending [" + bufferMessage + "] to [" + slaveConnManagerId + "]")
- connManager.sendMessageReliably(slaveConnManagerId, bufferMessage)
- }
- }
- val results = futures.map(f => Await.result(f, awaitTime))
- val finishTime = System.currentTimeMillis
- Thread.sleep(5000)
-
- val mb = size * results.size / 1024.0 / 1024.0
- val ms = finishTime - startTime
- val resultStr = thisConnManagerId + " Sent " + mb + " MB in " + ms + " ms at " + (mb / ms *
- 1000.0) + " MB/s"
- logInfo(resultStr)
- resultStr
- }).collect()
-
- println("---------------------")
- println("Run " + i)
- resultStrs.foreach(println)
- println("---------------------")
- })
- }
-}
-
http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala b/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala
new file mode 100644
index 0000000..dcecb6b
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network
+
+import java.io.{FileInputStream, RandomAccessFile, File, InputStream}
+import java.nio.ByteBuffer
+import java.nio.channels.FileChannel.MapMode
+
+import com.google.common.io.ByteStreams
+import io.netty.buffer.{ByteBufInputStream, ByteBuf}
+
+import org.apache.spark.util.ByteBufferInputStream
+
+
+/**
+ * This interface provides an immutable view for data in the form of bytes. The implementation
+ * should specify how the data is provided:
+ *
+ * - FileSegmentManagedBuffer: data backed by part of a file
+ * - NioByteBufferManagedBuffer: data backed by a NIO ByteBuffer
+ * - NettyByteBufManagedBuffer: data backed by a Netty ByteBuf
+ */
+sealed abstract class ManagedBuffer {
+ // Note that all the methods are defined with parentheses because their implementations can
+ // have side effects (I/O operations).
+
+ /** Number of bytes of the data. */
+ def size: Long
+
+ /**
+ * Exposes this buffer's data as an NIO ByteBuffer. Changing the position and limit of the
+ * returned ByteBuffer should not affect the content of this buffer.
+ */
+ def nioByteBuffer(): ByteBuffer
+
+ /**
+ * Exposes this buffer's data as an InputStream. The underlying implementation does not
+ * necessarily check for the length of bytes read, so the caller is responsible for making sure
+ * it does not go over the limit.
+ */
+ def inputStream(): InputStream
+}
+
+
+/**
+ * A [[ManagedBuffer]] backed by a segment in a file.
+ */
+final class FileSegmentManagedBuffer(val file: File, val offset: Long, val length: Long)
+ extends ManagedBuffer {
+
+ override def size: Long = length
+
+ override def nioByteBuffer(): ByteBuffer = {
+ val channel = new RandomAccessFile(file, "r").getChannel
+ channel.map(MapMode.READ_ONLY, offset, length)
+ }
+
+ override def inputStream(): InputStream = {
+ val is = new FileInputStream(file)
+ is.skip(offset)
+ ByteStreams.limit(is, length)
+ }
+}
+
+
+/**
+ * A [[ManagedBuffer]] backed by [[java.nio.ByteBuffer]].
+ */
+final class NioByteBufferManagedBuffer(buf: ByteBuffer) extends ManagedBuffer {
+
+ override def size: Long = buf.remaining()
+
+ override def nioByteBuffer() = buf.duplicate()
+
+ override def inputStream() = new ByteBufferInputStream(buf)
+}
+
+
+/**
+ * A [[ManagedBuffer]] backed by a Netty [[ByteBuf]].
+ */
+final class NettyByteBufManagedBuffer(buf: ByteBuf) extends ManagedBuffer {
+
+ override def size: Long = buf.readableBytes()
+
+ override def nioByteBuffer() = buf.nioBuffer()
+
+ override def inputStream() = new ByteBufInputStream(buf)
+
+ // TODO(rxin): Promote this to top level ManagedBuffer interface and add documentation for it.
+ def release(): Unit = buf.release()
+}
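
The three implementations only differ in where the bytes live; the ManagedBuffer interface
itself is what the rest of the transfer layer programs against. A small self-contained sketch,
assuming this commit's classes are on the classpath (the object name is illustrative):

    import java.nio.ByteBuffer
    import org.apache.spark.network.NioByteBufferManagedBuffer

    object ManagedBufferExample {
      def main(args: Array[String]): Unit = {
        val bytes = "hello".getBytes("UTF-8")
        val buf = new NioByteBufferManagedBuffer(ByteBuffer.wrap(bytes))
        println(buf.size) // 5
        // nioByteBuffer() returns a duplicate, so draining it leaves the
        // underlying buffer's position untouched.
        val dup = buf.nioByteBuffer()
        val out = new Array[Byte](dup.remaining())
        dup.get(out)
        println(new String(out, "UTF-8")) // hello
      }
    }
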
http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/network/Message.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/network/Message.scala b/core/src/main/scala/org/apache/spark/network/Message.scala
deleted file mode 100644
index 04ea50f..0000000
--- a/core/src/main/scala/org/apache/spark/network/Message.scala
+++ /dev/null
@@ -1,95 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network
-
-import java.net.InetSocketAddress
-import java.nio.ByteBuffer
-
-import scala.collection.mutable.ArrayBuffer
-
-private[spark] abstract class Message(val typ: Long, val id: Int) {
- var senderAddress: InetSocketAddress = null
- var started = false
- var startTime = -1L
- var finishTime = -1L
- var isSecurityNeg = false
- var hasError = false
-
- def size: Int
-
- def getChunkForSending(maxChunkSize: Int): Option[MessageChunk]
-
- def getChunkForReceiving(chunkSize: Int): Option[MessageChunk]
-
- def timeTaken(): String = (finishTime - startTime).toString + " ms"
-
- override def toString = this.getClass.getSimpleName + "(id = " + id + ", size = " + size + ")"
-}
-
-
-private[spark] object Message {
- val BUFFER_MESSAGE = 1111111111L
-
- var lastId = 1
-
- def getNewId() = synchronized {
- lastId += 1
- if (lastId == 0) {
- lastId += 1
- }
- lastId
- }
-
- def createBufferMessage(dataBuffers: Seq[ByteBuffer], ackId: Int): BufferMessage = {
- if (dataBuffers == null) {
- return new BufferMessage(getNewId(), new ArrayBuffer[ByteBuffer], ackId)
- }
- if (dataBuffers.exists(_ == null)) {
- throw new Exception("Attempting to create buffer message with null buffer")
- }
- new BufferMessage(getNewId(), new ArrayBuffer[ByteBuffer] ++= dataBuffers, ackId)
- }
-
- def createBufferMessage(dataBuffers: Seq[ByteBuffer]): BufferMessage =
- createBufferMessage(dataBuffers, 0)
-
- def createBufferMessage(dataBuffer: ByteBuffer, ackId: Int): BufferMessage = {
- if (dataBuffer == null) {
- createBufferMessage(Array(ByteBuffer.allocate(0)), ackId)
- } else {
- createBufferMessage(Array(dataBuffer), ackId)
- }
- }
-
- def createBufferMessage(dataBuffer: ByteBuffer): BufferMessage =
- createBufferMessage(dataBuffer, 0)
-
- def createBufferMessage(ackId: Int): BufferMessage = {
- createBufferMessage(new Array[ByteBuffer](0), ackId)
- }
-
- def create(header: MessageChunkHeader): Message = {
- val newMessage: Message = header.typ match {
- case BUFFER_MESSAGE => new BufferMessage(header.id,
- ArrayBuffer(ByteBuffer.allocate(header.totalSize)), header.other)
- }
- newMessage.hasError = header.hasError
- newMessage.senderAddress = header.address
- newMessage
- }
-}
http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/network/MessageChunk.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/network/MessageChunk.scala b/core/src/main/scala/org/apache/spark/network/MessageChunk.scala
deleted file mode 100644
index d0f986a..0000000
--- a/core/src/main/scala/org/apache/spark/network/MessageChunk.scala
+++ /dev/null
@@ -1,41 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network
-
-import java.nio.ByteBuffer
-
-import scala.collection.mutable.ArrayBuffer
-
-private[network]
-class MessageChunk(val header: MessageChunkHeader, val buffer: ByteBuffer) {
-
- val size = if (buffer == null) 0 else buffer.remaining
-
- lazy val buffers = {
- val ab = new ArrayBuffer[ByteBuffer]()
- ab += header.buffer
- if (buffer != null) {
- ab += buffer
- }
- ab
- }
-
- override def toString = {
- "" + this.getClass.getSimpleName + " (id = " + header.id + ", size = " + size + ")"
- }
-}
http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala b/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala
deleted file mode 100644
index f3ecca5..0000000
--- a/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala
+++ /dev/null
@@ -1,82 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network
-
-import java.net.InetAddress
-import java.net.InetSocketAddress
-import java.nio.ByteBuffer
-
-private[spark] class MessageChunkHeader(
- val typ: Long,
- val id: Int,
- val totalSize: Int,
- val chunkSize: Int,
- val other: Int,
- val hasError: Boolean,
- val securityNeg: Int,
- val address: InetSocketAddress) {
- lazy val buffer = {
- // No need to change this, at 'use' time, we do a reverse lookup of the hostname.
- // Refer to network.Connection
- val ip = address.getAddress.getAddress()
- val port = address.getPort()
- ByteBuffer.
- allocate(MessageChunkHeader.HEADER_SIZE).
- putLong(typ).
- putInt(id).
- putInt(totalSize).
- putInt(chunkSize).
- putInt(other).
- put(if (hasError) 1.asInstanceOf[Byte] else 0.asInstanceOf[Byte]).
- putInt(securityNeg).
- putInt(ip.size).
- put(ip).
- putInt(port).
- position(MessageChunkHeader.HEADER_SIZE).
- flip.asInstanceOf[ByteBuffer]
- }
-
- override def toString = "" + this.getClass.getSimpleName + ":" + id + " of type " + typ +
- " and sizes " + totalSize + " / " + chunkSize + " bytes, securityNeg: " + securityNeg
-
-}
-
-
-private[spark] object MessageChunkHeader {
- val HEADER_SIZE = 45
-
- def create(buffer: ByteBuffer): MessageChunkHeader = {
- if (buffer.remaining != HEADER_SIZE) {
- throw new IllegalArgumentException("Cannot convert buffer data to Message")
- }
- val typ = buffer.getLong()
- val id = buffer.getInt()
- val totalSize = buffer.getInt()
- val chunkSize = buffer.getInt()
- val other = buffer.getInt()
- val hasError = buffer.get() != 0
- val securityNeg = buffer.getInt()
- val ipSize = buffer.getInt()
- val ipBytes = new Array[Byte](ipSize)
- buffer.get(ipBytes)
- val ip = InetAddress.getByAddress(ipBytes)
- val port = buffer.getInt()
- new MessageChunkHeader(typ, id, totalSize, chunkSize, other, hasError, securityNeg,
- new InetSocketAddress(ip, port))
- }
-}
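
For reference, the fixed 45-byte header that create() parses back is laid out exactly as the
lazy buffer above writes it:

    // typ: Long (8) | id: Int (4) | totalSize: Int (4) | chunkSize: Int (4)
    // other: Int (4) | hasError: Byte (1) | securityNeg: Int (4)
    // ipSize: Int (4) | ip: Byte* (ipSize; 4 for IPv4) | port: Int (4)
    // position(HEADER_SIZE).flip pads the buffer out to the fixed 45 bytes.
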
http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala b/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala
deleted file mode 100644
index 53a6038..0000000
--- a/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala
+++ /dev/null
@@ -1,37 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network
-
-import java.nio.ByteBuffer
-import org.apache.spark.{SecurityManager, SparkConf}
-
-private[spark] object ReceiverTest {
- def main(args: Array[String]) {
- val conf = new SparkConf
- val manager = new ConnectionManager(9999, conf, new SecurityManager(conf))
- println("Started connection manager with id = " + manager.id)
-
- manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
- /* println("Received [" + msg + "] from [" + id + "] at " + System.currentTimeMillis) */
- val buffer = ByteBuffer.wrap("response".getBytes("utf-8"))
- Some(Message.createBufferMessage(buffer, msg.id))
- })
- Thread.currentThread.join()
- }
-}
-
http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/network/SecurityMessage.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/network/SecurityMessage.scala b/core/src/main/scala/org/apache/spark/network/SecurityMessage.scala
deleted file mode 100644
index 9af9e2e..0000000
--- a/core/src/main/scala/org/apache/spark/network/SecurityMessage.scala
+++ /dev/null
@@ -1,162 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network
-
-import java.nio.ByteBuffer
-
-import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.StringBuilder
-
-import org.apache.spark._
-import org.apache.spark.network._
-
-/**
- * SecurityMessage is a class that contains the connectionId and SASL token
- * used in SASL negotiation. SecurityMessage has routines for converting
- * it to and from a BufferMessage so that it can be sent by the ConnectionManager
- * and easily consumed by users when received.
- * The API was modeled after BlockMessage.
- *
- * The connectionId is the connectionId of the client side. Since
- * message passing is asynchronous and it's possible for the server side (receiving)
- * to get multiple different types of messages on the same connection, the connectionId
- * is used to know which connection the security message is intended for.
- *
- * For instance, let's say we are node_0. We need to send data to node_1. The node_0 side
- * is acting as a client and connecting to node_1. SASL negotiation has to occur
- * between node_0 and node_1 before node_1 trusts node_0 so node_0 sends a security message.
- * node_1 receives the message from node_0 but before it can process it and send a response,
- * some thread on node_1 decides it needs to send data to node_0 so it connects to node_0
- * and sends a security message of its own to authenticate as a client. Now node_0 gets
- * the message and it needs to decide if this message is in response to it being a client
- * (from the first send) or if it's just node_1 trying to connect to it to send data. This
- * is where the connectionId field is used. node_0 can look up the connectionId to see if
- * it is in response to it being a client or if it's in response to someone sending other data.
- *
- * The format of a SecurityMessage as it's sent is:
- * - Length of the ConnectionId
- * - ConnectionId
- * - Length of the token
- * - Token
- */
-private[spark] class SecurityMessage() extends Logging {
-
- private var connectionId: String = null
- private var token: Array[Byte] = null
-
- def set(byteArr: Array[Byte], newconnectionId: String) {
- if (byteArr == null) {
- token = new Array[Byte](0)
- } else {
- token = byteArr
- }
- connectionId = newconnectionId
- }
-
- /**
- * Read the given buffer and set the members of this class.
- */
- def set(buffer: ByteBuffer) {
- val idLength = buffer.getInt()
- val idBuilder = new StringBuilder(idLength)
- for (i <- 1 to idLength) {
- idBuilder += buffer.getChar()
- }
- connectionId = idBuilder.toString()
-
- val tokenLength = buffer.getInt()
- token = new Array[Byte](tokenLength)
- if (tokenLength > 0) {
- buffer.get(token, 0, tokenLength)
- }
- }
-
- def set(bufferMsg: BufferMessage) {
- val buffer = bufferMsg.buffers.apply(0)
- buffer.clear()
- set(buffer)
- }
-
- def getConnectionId: String = {
- return connectionId
- }
-
- def getToken: Array[Byte] = {
- return token
- }
-
- /**
- * Create a BufferMessage that can be sent by the ConnectionManager containing
- * the security information from this class.
- * @return BufferMessage
- */
- def toBufferMessage: BufferMessage = {
- val buffers = new ArrayBuffer[ByteBuffer]()
-
- // 4 bytes for the length of the connectionId
- // connectionId is of type char so multiply the length by 2 to get the number of bytes
- // 4 bytes for the length of token
- // token is a byte buffer so just take the length
- var buffer = ByteBuffer.allocate(4 + connectionId.length() * 2 + 4 + token.length)
- buffer.putInt(connectionId.length())
- connectionId.foreach((x: Char) => buffer.putChar(x))
- buffer.putInt(token.length)
-
- if (token.length > 0) {
- buffer.put(token)
- }
- buffer.flip()
- buffers += buffer
-
- var message = Message.createBufferMessage(buffers)
- logDebug("message total size is : " + message.size)
- message.isSecurityNeg = true
- return message
- }
-
- override def toString: String = {
- "SecurityMessage [connId= " + connectionId + ", Token = " + token + "]"
- }
-}
-
-private[spark] object SecurityMessage {
-
- /**
- * Convert the given BufferMessage to a SecurityMessage by parsing the contents
- * of the BufferMessage and populating the SecurityMessage fields.
- * @param bufferMessage is a BufferMessage that was received
- * @return new SecurityMessage
- */
- def fromBufferMessage(bufferMessage: BufferMessage): SecurityMessage = {
- val newSecurityMessage = new SecurityMessage()
- newSecurityMessage.set(bufferMessage)
- newSecurityMessage
- }
-
- /**
- * Create a SecurityMessage to send from a given saslResponse.
- * @param response is the response to a challenge from the SaslClient or SaslServer
- * @param connectionId the client connectionId we are negotiating authentication for
- * @return a new SecurityMessage
- */
- def fromResponse(response : Array[Byte], connectionId : String) : SecurityMessage = {
- val newSecurityMessage = new SecurityMessage()
- newSecurityMessage.set(response, connectionId)
- newSecurityMessage
- }
-}
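
For reference, the wire layout spelled out in the scaladoc above can be reproduced with plain
NIO. A minimal sketch of the encoding side (the helper name encode is illustrative):

    import java.nio.ByteBuffer

    // [idLength: Int][connectionId: Char*][tokenLength: Int][token: Byte*]
    def encode(connectionId: String, token: Array[Byte]): ByteBuffer = {
      // Chars are two bytes each, hence length * 2 for the id portion.
      val buf = ByteBuffer.allocate(4 + connectionId.length * 2 + 4 + token.length)
      buf.putInt(connectionId.length)
      connectionId.foreach(buf.putChar)
      buf.putInt(token.length)
      buf.put(token)
      buf.flip()
      buf
    }
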
http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/network/SenderTest.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/network/SenderTest.scala b/core/src/main/scala/org/apache/spark/network/SenderTest.scala
deleted file mode 100644
index ea2ad10..0000000
--- a/core/src/main/scala/org/apache/spark/network/SenderTest.scala
+++ /dev/null
@@ -1,76 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network
-
-import java.nio.ByteBuffer
-import org.apache.spark.{SecurityManager, SparkConf}
-
-import scala.concurrent.Await
-import scala.concurrent.duration.Duration
-import scala.util.Try
-
-private[spark] object SenderTest {
- def main(args: Array[String]) {
-
- if (args.length < 2) {
- println("Usage: SenderTest <target host> <target port>")
- System.exit(1)
- }
-
- val targetHost = args(0)
- val targetPort = args(1).toInt
- val targetConnectionManagerId = new ConnectionManagerId(targetHost, targetPort)
- val conf = new SparkConf
- val manager = new ConnectionManager(0, conf, new SecurityManager(conf))
- println("Started connection manager with id = " + manager.id)
-
- manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
- println("Received [" + msg + "] from [" + id + "]")
- None
- })
-
- val size = 100 * 1024 * 1024
- val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
- buffer.flip
-
- val targetServer = args(0)
-
- val count = 100
- (0 until count).foreach(i => {
- val dataMessage = Message.createBufferMessage(buffer.duplicate)
- val startTime = System.currentTimeMillis
- /* println("Started timer at " + startTime) */
- val promise = manager.sendMessageReliably(targetConnectionManagerId, dataMessage)
- val responseStr: String = Try(Await.result(promise, Duration.Inf))
- .map { response =>
- val buffer = response.asInstanceOf[BufferMessage].buffers(0)
- new String(buffer.array, "utf-8")
- }.getOrElse("none")
-
- val finishTime = System.currentTimeMillis
- val mb = size / 1024.0 / 1024.0
- val ms = finishTime - startTime
- // val resultStr = "Sent " + mb + " MB " + targetServer + " in " + ms + " ms at " + (mb / ms
- // * 1000.0) + " MB/s"
- val resultStr = "Sent " + mb + " MB " + targetServer + " in " + ms + " ms (" +
- (mb / ms * 1000.0).toInt + "MB/s) | Response = " + responseStr
- println(resultStr)
- })
- }
-}
-
http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/network/nio/BlockMessage.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/network/nio/BlockMessage.scala b/core/src/main/scala/org/apache/spark/network/nio/BlockMessage.scala
new file mode 100644
index 0000000..b573f1a
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/nio/BlockMessage.scala
@@ -0,0 +1,197 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.nio
+
+import java.nio.ByteBuffer
+
+import org.apache.spark.storage.{BlockId, StorageLevel, TestBlockId}
+
+import scala.collection.mutable.{ArrayBuffer, StringBuilder}
+
+// private[spark] because we need to register them in Kryo
+private[spark] case class GetBlock(id: BlockId)
+private[spark] case class GotBlock(id: BlockId, data: ByteBuffer)
+private[spark] case class PutBlock(id: BlockId, data: ByteBuffer, level: StorageLevel)
+
+private[nio] class BlockMessage() {
+ // Un-initialized: typ = 0
+ // GetBlock: typ = 1
+ // GotBlock: typ = 2
+ // PutBlock: typ = 3
+ private var typ: Int = BlockMessage.TYPE_NON_INITIALIZED
+ private var id: BlockId = null
+ private var data: ByteBuffer = null
+ private var level: StorageLevel = null
+
+ def set(getBlock: GetBlock) {
+ typ = BlockMessage.TYPE_GET_BLOCK
+ id = getBlock.id
+ }
+
+ def set(gotBlock: GotBlock) {
+ typ = BlockMessage.TYPE_GOT_BLOCK
+ id = gotBlock.id
+ data = gotBlock.data
+ }
+
+ def set(putBlock: PutBlock) {
+ typ = BlockMessage.TYPE_PUT_BLOCK
+ id = putBlock.id
+ data = putBlock.data
+ level = putBlock.level
+ }
+
+ def set(buffer: ByteBuffer) {
+ /*
+ println()
+ println("BlockMessage: ")
+ while(buffer.remaining > 0) {
+ print(buffer.get())
+ }
+ buffer.rewind()
+ println()
+ println()
+ */
+ typ = buffer.getInt()
+ val idLength = buffer.getInt()
+ val idBuilder = new StringBuilder(idLength)
+ for (i <- 1 to idLength) {
+ idBuilder += buffer.getChar()
+ }
+ id = BlockId(idBuilder.toString)
+
+ if (typ == BlockMessage.TYPE_PUT_BLOCK) {
+
+ val booleanInt = buffer.getInt()
+ val replication = buffer.getInt()
+ level = StorageLevel(booleanInt, replication)
+
+ val dataLength = buffer.getInt()
+ data = ByteBuffer.allocate(dataLength)
+ if (dataLength != buffer.remaining) {
+ throw new Exception("Error parsing buffer")
+ }
+ data.put(buffer)
+ data.flip()
+ } else if (typ == BlockMessage.TYPE_GOT_BLOCK) {
+
+ val dataLength = buffer.getInt()
+ data = ByteBuffer.allocate(dataLength)
+ if (dataLength != buffer.remaining) {
+ throw new Exception("Error parsing buffer")
+ }
+ data.put(buffer)
+ data.flip()
+ }
+
+ }
+
+ def set(bufferMsg: BufferMessage) {
+ val buffer = bufferMsg.buffers.apply(0)
+ buffer.clear()
+ set(buffer)
+ }
+
+ def getType: Int = typ
+ def getId: BlockId = id
+ def getData: ByteBuffer = data
+ def getLevel: StorageLevel = level
+
+ def toBufferMessage: BufferMessage = {
+ val buffers = new ArrayBuffer[ByteBuffer]()
+ var buffer = ByteBuffer.allocate(4 + 4 + id.name.length * 2)
+ buffer.putInt(typ).putInt(id.name.length)
+ id.name.foreach((x: Char) => buffer.putChar(x))
+ buffer.flip()
+ buffers += buffer
+
+ if (typ == BlockMessage.TYPE_PUT_BLOCK) {
+ buffer = ByteBuffer.allocate(8).putInt(level.toInt).putInt(level.replication)
+ buffer.flip()
+ buffers += buffer
+
+ buffer = ByteBuffer.allocate(4).putInt(data.remaining)
+ buffer.flip()
+ buffers += buffer
+
+ buffers += data
+ } else if (typ == BlockMessage.TYPE_GOT_BLOCK) {
+ buffer = ByteBuffer.allocate(4).putInt(data.remaining)
+ buffer.flip()
+ buffers += buffer
+
+ buffers += data
+ }
+
+ /*
+ println()
+ println("BlockMessage: ")
+ buffers.foreach(b => {
+ while(b.remaining > 0) {
+ print(b.get())
+ }
+ b.rewind()
+ })
+ println()
+ println()
+ */
+ Message.createBufferMessage(buffers)
+ }
+
+ override def toString: String = {
+ "BlockMessage [type = " + typ + ", id = " + id + ", level = " + level +
+ ", data = " + (if (data != null) data.remaining.toString else "null") + "]"
+ }
+}
+
+private[nio] object BlockMessage {
+ val TYPE_NON_INITIALIZED: Int = 0
+ val TYPE_GET_BLOCK: Int = 1
+ val TYPE_GOT_BLOCK: Int = 2
+ val TYPE_PUT_BLOCK: Int = 3
+
+ def fromBufferMessage(bufferMessage: BufferMessage): BlockMessage = {
+ val newBlockMessage = new BlockMessage()
+ newBlockMessage.set(bufferMessage)
+ newBlockMessage
+ }
+
+ def fromByteBuffer(buffer: ByteBuffer): BlockMessage = {
+ val newBlockMessage = new BlockMessage()
+ newBlockMessage.set(buffer)
+ newBlockMessage
+ }
+
+ def fromGetBlock(getBlock: GetBlock): BlockMessage = {
+ val newBlockMessage = new BlockMessage()
+ newBlockMessage.set(getBlock)
+ newBlockMessage
+ }
+
+ def fromGotBlock(gotBlock: GotBlock): BlockMessage = {
+ val newBlockMessage = new BlockMessage()
+ newBlockMessage.set(gotBlock)
+ newBlockMessage
+ }
+
+ def fromPutBlock(putBlock: PutBlock): BlockMessage = {
+ val newBlockMessage = new BlockMessage()
+ newBlockMessage.set(putBlock)
+ newBlockMessage
+ }
+}
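
Since BlockMessage's factories and toBufferMessage are symmetric, the encoding can be checked
with a quick round trip. A sketch, assuming it lives in package org.apache.spark.network.nio
(the class is private[nio]):

    import org.apache.spark.storage.TestBlockId

    // Encode a GetBlock request and decode it back from the wire form.
    val original = BlockMessage.fromGetBlock(GetBlock(TestBlockId("0")))
    val onTheWire = original.toBufferMessage // [typ: Int][idLength: Int][id: Char*]
    val decoded = BlockMessage.fromBufferMessage(onTheWire)
    assert(decoded.getType == BlockMessage.TYPE_GET_BLOCK)
    assert(decoded.getId == TestBlockId("0"))
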
http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala b/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala
new file mode 100644
index 0000000..a1a2c00
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.nio
+
+import java.nio.ByteBuffer
+
+import org.apache.spark._
+import org.apache.spark.storage.{StorageLevel, TestBlockId}
+
+import scala.collection.mutable.ArrayBuffer
+
+private[nio]
+class BlockMessageArray(var blockMessages: Seq[BlockMessage])
+ extends Seq[BlockMessage] with Logging {
+
+ def this(bm: BlockMessage) = this(Array(bm))
+
+ def this() = this(null.asInstanceOf[Seq[BlockMessage]])
+
+ def apply(i: Int) = blockMessages(i)
+
+ def iterator = blockMessages.iterator
+
+ def length = blockMessages.length
+
+ def set(bufferMessage: BufferMessage) {
+ val startTime = System.currentTimeMillis
+ val newBlockMessages = new ArrayBuffer[BlockMessage]()
+ val buffer = bufferMessage.buffers(0)
+ buffer.clear()
+ /*
+ println()
+ println("BlockMessageArray: ")
+ while(buffer.remaining > 0) {
+ print(buffer.get())
+ }
+ buffer.rewind()
+ println()
+ println()
+ */
+ while (buffer.remaining() > 0) {
+ val size = buffer.getInt()
+ logDebug("Creating block message of size " + size + " bytes")
+ val newBuffer = buffer.slice()
+ newBuffer.clear()
+ newBuffer.limit(size)
+ logDebug("Trying to convert buffer " + newBuffer + " to block message")
+ val newBlockMessage = BlockMessage.fromByteBuffer(newBuffer)
+ logDebug("Created " + newBlockMessage)
+ newBlockMessages += newBlockMessage
+ buffer.position(buffer.position() + size)
+ }
+ val finishTime = System.currentTimeMillis
+ logDebug("Converted block message array from buffer message in " +
+ (finishTime - startTime) / 1000.0 + " s")
+ this.blockMessages = newBlockMessages
+ }
+
+ def toBufferMessage: BufferMessage = {
+ val buffers = new ArrayBuffer[ByteBuffer]()
+
+ blockMessages.foreach(blockMessage => {
+ val bufferMessage = blockMessage.toBufferMessage
+ logDebug("Adding " + blockMessage)
+ val sizeBuffer = ByteBuffer.allocate(4).putInt(bufferMessage.size)
+ sizeBuffer.flip
+ buffers += sizeBuffer
+ buffers ++= bufferMessage.buffers
+ logDebug("Added " + bufferMessage)
+ })
+
+ logDebug("Buffer list:")
+ buffers.foreach((x: ByteBuffer) => logDebug("" + x))
+ /*
+ println()
+ println("BlockMessageArray: ")
+ buffers.foreach(b => {
+ while(b.remaining > 0) {
+ print(b.get())
+ }
+ b.rewind()
+ })
+ println()
+ println()
+ */
+ Message.createBufferMessage(buffers)
+ }
+}
+
+private[nio] object BlockMessageArray {
+
+ def fromBufferMessage(bufferMessage: BufferMessage): BlockMessageArray = {
+ val newBlockMessageArray = new BlockMessageArray()
+ newBlockMessageArray.set(bufferMessage)
+ newBlockMessageArray
+ }
+
+ def main(args: Array[String]) {
+ val blockMessages =
+ (0 until 10).map { i =>
+ if (i % 2 == 0) {
+ val buffer = ByteBuffer.allocate(100)
+ buffer.clear
+ BlockMessage.fromPutBlock(PutBlock(TestBlockId(i.toString), buffer,
+ StorageLevel.MEMORY_ONLY_SER))
+ } else {
+ BlockMessage.fromGetBlock(GetBlock(TestBlockId(i.toString)))
+ }
+ }
+ val blockMessageArray = new BlockMessageArray(blockMessages)
+ println("Block message array created")
+
+ val bufferMessage = blockMessageArray.toBufferMessage
+ println("Converted to buffer message")
+
+ val totalSize = bufferMessage.size
+ val newBuffer = ByteBuffer.allocate(totalSize)
+ newBuffer.clear()
+ bufferMessage.buffers.foreach(buffer => {
+ assert (0 == buffer.position())
+ newBuffer.put(buffer)
+ buffer.rewind()
+ })
+ newBuffer.flip
+ val newBufferMessage = Message.createBufferMessage(newBuffer)
+ println("Copied to new buffer message, size = " + newBufferMessage.size)
+
+ val newBlockMessageArray = BlockMessageArray.fromBufferMessage(newBufferMessage)
+ println("Converted back to block message array")
+ newBlockMessageArray.foreach(blockMessage => {
+ blockMessage.getType match {
+ case BlockMessage.TYPE_PUT_BLOCK => {
+ val pB = PutBlock(blockMessage.getId, blockMessage.getData, blockMessage.getLevel)
+ println(pB)
+ }
+ case BlockMessage.TYPE_GET_BLOCK => {
+ val gB = new GetBlock(blockMessage.getId)
+ println(gB)
+ }
+ }
+ })
+ }
+}
+
+
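
The framing implemented by set() and toBufferMessage above is a plain length-prefixed
concatenation: each serialized BlockMessage is preceded by a 4-byte Int holding its size. A
standalone sketch of the decode loop (the helper name frames is illustrative):

    import java.nio.ByteBuffer

    // Splits [size: Int][payload bytes] ... into one ByteBuffer per payload.
    def frames(buffer: ByteBuffer): Seq[ByteBuffer] = {
      val out = scala.collection.mutable.ArrayBuffer[ByteBuffer]()
      while (buffer.remaining() > 0) {
        val size = buffer.getInt()
        val slice = buffer.slice() // shares content, independent position/limit
        slice.limit(size)
        out += slice
        buffer.position(buffer.position() + size)
      }
      out
    }
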
http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/network/nio/BufferMessage.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/network/nio/BufferMessage.scala b/core/src/main/scala/org/apache/spark/network/nio/BufferMessage.scala
new file mode 100644
index 0000000..3b245c5
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/nio/BufferMessage.scala
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.nio
+
+import java.nio.ByteBuffer
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.storage.BlockManager
+
+
+private[nio]
+class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: Int)
+ extends Message(Message.BUFFER_MESSAGE, id_) {
+
+ val initialSize = currentSize()
+ var gotChunkForSendingOnce = false
+
+ def size = initialSize
+
+ def currentSize() = {
+ if (buffers == null || buffers.isEmpty) {
+ 0
+ } else {
+ buffers.map(_.remaining).reduceLeft(_ + _)
+ }
+ }
+
+ def getChunkForSending(maxChunkSize: Int): Option[MessageChunk] = {
+ if (maxChunkSize <= 0) {
+ throw new Exception("Max chunk size is " + maxChunkSize)
+ }
+
+ val security = if (isSecurityNeg) 1 else 0
+ if (size == 0 && !gotChunkForSendingOnce) {
+ val newChunk = new MessageChunk(
+ new MessageChunkHeader(typ, id, 0, 0, ackId, hasError, security, senderAddress), null)
+ gotChunkForSendingOnce = true
+ return Some(newChunk)
+ }
+
+ while (!buffers.isEmpty) {
+ val buffer = buffers(0)
+ if (buffer.remaining == 0) {
+ BlockManager.dispose(buffer)
+ buffers -= buffer
+ } else {
+ val newBuffer = if (buffer.remaining <= maxChunkSize) {
+ buffer.duplicate()
+ } else {
+ buffer.slice().limit(maxChunkSize).asInstanceOf[ByteBuffer]
+ }
+ buffer.position(buffer.position + newBuffer.remaining)
+ val newChunk = new MessageChunk(new MessageChunkHeader(
+ typ, id, size, newBuffer.remaining, ackId,
+ hasError, security, senderAddress), newBuffer)
+ gotChunkForSendingOnce = true
+ return Some(newChunk)
+ }
+ }
+ None
+ }
+
+ def getChunkForReceiving(chunkSize: Int): Option[MessageChunk] = {
+ // STRONG ASSUMPTION: BufferMessage created when receiving data has ONLY ONE data buffer
+ if (buffers.size > 1) {
+ throw new Exception("Attempting to get chunk from message with multiple data buffers")
+ }
+ val buffer = buffers(0)
+ val security = if (isSecurityNeg) 1 else 0
+ if (buffer.remaining > 0) {
+ if (buffer.remaining < chunkSize) {
+ throw new Exception("Not enough space in data buffer for receiving chunk")
+ }
+ val newBuffer = buffer.slice().limit(chunkSize).asInstanceOf[ByteBuffer]
+ buffer.position(buffer.position + newBuffer.remaining)
+ val newChunk = new MessageChunk(new MessageChunkHeader(
+ typ, id, size, newBuffer.remaining, ackId, hasError, security, senderAddress), newBuffer)
+ return Some(newChunk)
+ }
+ None
+ }
+
+ def flip() {
+ buffers.foreach(_.flip)
+ }
+
+ def hasAckId() = (ackId != 0)
+
+ def isCompletelyReceived() = !buffers(0).hasRemaining
+
+ override def toString = {
+ if (hasAckId) {
+ "BufferAckMessage(aid = " + ackId + ", id = " + id + ", size = " + size + ")"
+ } else {
+ "BufferMessage(id = " + id + ", size = " + size + ")"
+ }
+ }
+}
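getChunkForSending above carves at most maxChunkSize bytes off the head buffer by
combining duplicate(), slice(), limit(), and a manual position() advance, so chunks
share the underlying bytes instead of copying them. A standalone sketch of that
slicing idiom (a hypothetical helper, not part of the patch):

    import java.nio.ByteBuffer

    object ChunkSketch {
      // Take up to maxChunkSize bytes from buffer without copying the data.
      def nextChunk(buffer: ByteBuffer, maxChunkSize: Int): ByteBuffer = {
        val chunk =
          if (buffer.remaining <= maxChunkSize) {
            buffer.duplicate()         // the whole remainder fits in one chunk
          } else {
            val view = buffer.slice()  // shares content, independent position/limit
            view.limit(maxChunkSize)   // cap the view at the chunk size
            view
          }
        buffer.position(buffer.position() + chunk.remaining)  // consume from source
        chunk
      }
    }
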
http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/network/nio/Connection.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala
new file mode 100644
index 0000000..74074a8
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala
@@ -0,0 +1,587 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.nio
+
+import java.net._
+import java.nio._
+import java.nio.channels._
+
+import org.apache.spark._
+
+import scala.collection.mutable.{ArrayBuffer, HashMap, Queue}
+
+private[nio]
+abstract class Connection(val channel: SocketChannel, val selector: Selector,
+ val socketRemoteConnectionManagerId: ConnectionManagerId, val connectionId: ConnectionId)
+ extends Logging {
+
+ var sparkSaslServer: SparkSaslServer = null
+ var sparkSaslClient: SparkSaslClient = null
+
+ def this(channel_ : SocketChannel, selector_ : Selector, id_ : ConnectionId) = {
+ this(channel_, selector_,
+ ConnectionManagerId.fromSocketAddress(
+ channel_.socket.getRemoteSocketAddress.asInstanceOf[InetSocketAddress]), id_)
+ }
+
+ channel.configureBlocking(false)
+ channel.socket.setTcpNoDelay(true)
+ channel.socket.setReuseAddress(true)
+ channel.socket.setKeepAlive(true)
+ /* channel.socket.setReceiveBufferSize(32768) */
+
+ @volatile private var closed = false
+ var onCloseCallback: Connection => Unit = null
+ var onExceptionCallback: (Connection, Exception) => Unit = null
+ var onKeyInterestChangeCallback: (Connection, Int) => Unit = null
+
+ val remoteAddress = getRemoteAddress()
+
+ /**
+ * Used to synchronize client requests: client's work-related requests must
+ * wait until SASL authentication completes.
+ */
+ private val authenticated = new Object()
+
+ def getAuthenticated(): Object = authenticated
+
+ def isSaslComplete(): Boolean
+
+ def resetForceReregister(): Boolean
+
+ // Read connections typically do not register for write, and write connections do not
+ // register for read. Currently, write connections also register for read (temporarily),
+ // but only to detect channel close - NOT to actually read/consume data on the channel!
+ // How will this work if/when we move to SSL?
+
+ // The interest to register with the selector when we want this connection to be selected
+ def registerInterest()
+
+ // The interest to register with the selector when we want this connection to be
+ // de-selected. Traditionally this is 0, but in our case it will be SelectionKey.OP_READ
+ // for the close-detection hack on SendingConnection (until we fix it properly).
+ def unregisterInterest()
+
+ // On receiving a read event, should we change the interest for this channel or not?
+ // Will be true for ReceivingConnection, false for SendingConnection.
+ def changeInterestForRead(): Boolean
+
+ private def disposeSasl() {
+ if (sparkSaslServer != null) {
+ sparkSaslServer.dispose()
+ }
+
+ if (sparkSaslClient != null) {
+ sparkSaslClient.dispose()
+ }
+ }
+
+ // On receiving a write event, should we change the interest for this channel or not?
+ // Will be false for ReceivingConnection, true for SendingConnection.
+ // Actually, for now, this should not get triggered for ReceivingConnection at all.
+ def changeInterestForWrite(): Boolean
+
+ def getRemoteConnectionManagerId(): ConnectionManagerId = {
+ socketRemoteConnectionManagerId
+ }
+
+ def key() = channel.keyFor(selector)
+
+ def getRemoteAddress() = channel.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress]
+
+ // Returns whether we have to register for further reads or not.
+ def read(): Boolean = {
+ throw new UnsupportedOperationException(
+ "Cannot read on connection of type " + this.getClass.toString)
+ }
+
+ // Returns whether we have to register for further writes or not.
+ def write(): Boolean = {
+ throw new UnsupportedOperationException(
+ "Cannot write on connection of type " + this.getClass.toString)
+ }
+
+ def close() {
+ closed = true
+ val k = key()
+ if (k != null) {
+ k.cancel()
+ }
+ channel.close()
+ disposeSasl()
+ callOnCloseCallback()
+ }
+
+ protected def isClosed: Boolean = closed
+
+ def onClose(callback: Connection => Unit) {
+ onCloseCallback = callback
+ }
+
+ def onException(callback: (Connection, Exception) => Unit) {
+ onExceptionCallback = callback
+ }
+
+ def onKeyInterestChange(callback: (Connection, Int) => Unit) {
+ onKeyInterestChangeCallback = callback
+ }
+
+ def callOnExceptionCallback(e: Exception) {
+ if (onExceptionCallback != null) {
+ onExceptionCallback(this, e)
+ } else {
+ logError("Error in connection to " + getRemoteConnectionManagerId() +
+ " and OnExceptionCallback not registered", e)
+ }
+ }
+
+ def callOnCloseCallback() {
+ if (onCloseCallback != null) {
+ onCloseCallback(this)
+ } else {
+ logWarning("Connection to " + getRemoteConnectionManagerId() +
+ " closed and OnExceptionCallback not registered")
+ }
+
+ }
+
+ def changeConnectionKeyInterest(ops: Int) {
+ if (onKeyInterestChangeCallback != null) {
+ onKeyInterestChangeCallback(this, ops)
+ } else {
+ throw new Exception("OnKeyInterestChangeCallback not registered")
+ }
+ }
+
+ def printRemainingBuffer(buffer: ByteBuffer) {
+ val bytes = new Array[Byte](buffer.remaining)
+ val curPosition = buffer.position
+ buffer.get(bytes)
+ bytes.foreach(x => print(x + " "))
+ buffer.position(curPosition)
+ print(" (" + bytes.size + ")")
+ }
+
+ def printBuffer(buffer: ByteBuffer, position: Int, length: Int) {
+ val bytes = new Array[Byte](length)
+ val curPosition = buffer.position
+ buffer.position(position)
+ buffer.get(bytes)
+ bytes.foreach(x => print(x + " "))
+ print(" (" + position + ", " + length + ")")
+ buffer.position(curPosition)
+ }
+}
+
+
+private[nio]
+class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
+ remoteId_ : ConnectionManagerId, id_ : ConnectionId)
+ extends Connection(SocketChannel.open, selector_, remoteId_, id_) {
+
+ def isSaslComplete(): Boolean = {
+ if (sparkSaslClient != null) sparkSaslClient.isComplete() else false
+ }
+
+ private class Outbox {
+ val messages = new Queue[Message]()
+ val defaultChunkSize = 65536
+ var nextMessageToBeUsed = 0
+
+ def addMessage(message: Message) {
+ messages.synchronized {
+ /* messages += message */
+ messages.enqueue(message)
+ logDebug("Added [" + message + "] to outbox for sending to " +
+ "[" + getRemoteConnectionManagerId() + "]")
+ }
+ }
+
+ def getChunk(): Option[MessageChunk] = {
+ messages.synchronized {
+ while (!messages.isEmpty) {
+ /* nextMessageToBeUsed = nextMessageToBeUsed % messages.size */
+ /* val message = messages(nextMessageToBeUsed) */
+ val message = messages.dequeue()
+ val chunk = message.getChunkForSending(defaultChunkSize)
+ if (chunk.isDefined) {
+ messages.enqueue(message)
+ nextMessageToBeUsed = nextMessageToBeUsed + 1
+ if (!message.started) {
+ logDebug(
+ "Starting to send [" + message + "] to [" + getRemoteConnectionManagerId() + "]")
+ message.started = true
+ message.startTime = System.currentTimeMillis
+ }
+ logTrace(
+ "Sending chunk from [" + message + "] to [" + getRemoteConnectionManagerId() + "]")
+ return chunk
+ } else {
+ message.finishTime = System.currentTimeMillis
+ logDebug("Finished sending [" + message + "] to [" + getRemoteConnectionManagerId() +
+ "] in " + message.timeTaken )
+ }
+ }
+ }
+ None
+ }
+ }
+
+ // outbox is used as a lock - ensure that it is always used as a leaf (since methods which
+ // lock it are invoked in context of other locks)
+ private val outbox = new Outbox()
+ /*
+ This is orthogonal to whether we have pending bytes to write or not, and serves a slightly
+ different purpose. This flag tracks whether we need to force re-registration for write even
+ when we do not have any pending bytes to write to the socket.
+ This can happen due to a race between adding pending buffers and checking for the
+ existence of data, as detailed in https://github.com/mesos/spark/pull/791
+ */
+ private var needForceReregister = false
+
+ val currentBuffers = new ArrayBuffer[ByteBuffer]()
+
+ /* channel.socket.setSendBufferSize(256 * 1024) */
+
+ override def getRemoteAddress() = address
+
+ val DEFAULT_INTEREST = SelectionKey.OP_READ
+
+ override def registerInterest() {
+ // Registering read too - does not really help in most cases, but for some
+ // it does - so let us keep it for now.
+ changeConnectionKeyInterest(SelectionKey.OP_WRITE | DEFAULT_INTEREST)
+ }
+
+ override def unregisterInterest() {
+ changeConnectionKeyInterest(DEFAULT_INTEREST)
+ }
+
+ def send(message: Message) {
+ outbox.synchronized {
+ outbox.addMessage(message)
+ needForceReregister = true
+ }
+ if (channel.isConnected) {
+ registerInterest()
+ }
+ }
+
+ // return previous value after resetting it.
+ def resetForceReregister(): Boolean = {
+ outbox.synchronized {
+ val result = needForceReregister
+ needForceReregister = false
+ result
+ }
+ }
+
+ // MUST be called within the selector loop
+ def connect() {
+ try {
+ channel.register(selector, SelectionKey.OP_CONNECT)
+ channel.connect(address)
+ logInfo("Initiating connection to [" + address + "]")
+ } catch {
+ case e: Exception => {
+ logError("Error connecting to " + address, e)
+ callOnExceptionCallback(e)
+ }
+ }
+ }
+
+ def finishConnect(force: Boolean): Boolean = {
+ try {
+ // Typically, this should finish immediately since it was triggered by a connect
+ // selection - though need not necessarily always complete successfully.
+ val connected = channel.finishConnect
+ if (!force && !connected) {
+ logInfo(
+ "finish connect failed [" + address + "], " + outbox.messages.size + " messages pending")
+ return false
+ }
+
+ // Fall back to previous behavior and assume finishConnect completed. This happens
+ // only when finishConnect has already failed some number of times (10 or so), which
+ // is highly unlikely unless there was an unclean close of the socket, etc.
+ registerInterest()
+ logInfo("Connected to [" + address + "], " + outbox.messages.size + " messages pending")
+ } catch {
+ case e: Exception => {
+ logWarning("Error finishing connection to " + address, e)
+ callOnExceptionCallback(e)
+ }
+ }
+ true
+ }
+
+ override def write(): Boolean = {
+ try {
+ while (true) {
+ if (currentBuffers.size == 0) {
+ outbox.synchronized {
+ outbox.getChunk() match {
+ case Some(chunk) => {
+ val buffers = chunk.buffers
+ // If we have 'seen' pending messages, then reset flag - since we handle that as
+ // normal registering of event (below)
+ if (needForceReregister && buffers.exists(_.remaining() > 0)) resetForceReregister()
+
+ currentBuffers ++= buffers
+ }
+ case None => {
+ // changeConnectionKeyInterest(0)
+ /* key.interestOps(0) */
+ return false
+ }
+ }
+ }
+ }
+
+ if (currentBuffers.size > 0) {
+ val buffer = currentBuffers(0)
+ val remainingBytes = buffer.remaining
+ val writtenBytes = channel.write(buffer)
+ if (buffer.remaining == 0) {
+ currentBuffers -= buffer
+ }
+ if (writtenBytes < remainingBytes) {
+ // re-register for write.
+ return true
+ }
+ }
+ }
+ } catch {
+ case e: Exception => {
+ logWarning("Error writing in connection to " + getRemoteConnectionManagerId(), e)
+ callOnExceptionCallback(e)
+ close()
+ return false
+ }
+ }
+ // should not happen - to keep scala compiler happy
+ true
+ }
+
+ // This is a hack to determine whether the remote socket was closed.
+ // SendingConnection DOES NOT expect to receive any data - if it does, it is an error.
+ // In most cases, read returns -1 when the remote socket is closed, so we register
+ // for reads to detect that.
+ override def read(): Boolean = {
+ // We don't expect the other side to send anything; so, we just read to detect an error or EOF.
+ try {
+ val length = channel.read(ByteBuffer.allocate(1))
+ if (length == -1) { // EOF
+ close()
+ } else if (length > 0) {
+ logWarning(
+ "Unexpected data read from SendingConnection to " + getRemoteConnectionManagerId())
+ }
+ } catch {
+ case e: Exception =>
+ logError("Exception while reading SendingConnection to " + getRemoteConnectionManagerId(),
+ e)
+ callOnExceptionCallback(e)
+ close()
+ }
+
+ false
+ }
+
+ override def changeInterestForRead(): Boolean = false
+
+ override def changeInterestForWrite(): Boolean = !isClosed
+}
+
+
+// Must be created within selector loop - else deadlock
+private[spark] class ReceivingConnection(
+ channel_ : SocketChannel,
+ selector_ : Selector,
+ id_ : ConnectionId)
+ extends Connection(channel_, selector_, id_) {
+
+ def isSaslComplete(): Boolean = {
+ if (sparkSaslServer != null) sparkSaslServer.isComplete() else false
+ }
+
+ class Inbox() {
+ val messages = new HashMap[Int, BufferMessage]()
+
+ def getChunk(header: MessageChunkHeader): Option[MessageChunk] = {
+
+ def createNewMessage: BufferMessage = {
+ val newMessage = Message.create(header).asInstanceOf[BufferMessage]
+ newMessage.started = true
+ newMessage.startTime = System.currentTimeMillis
+ newMessage.isSecurityNeg = header.securityNeg == 1
+ logDebug(
+ "Starting to receive [" + newMessage + "] from [" + getRemoteConnectionManagerId() + "]")
+ messages += ((newMessage.id, newMessage))
+ newMessage
+ }
+
+ val message = messages.getOrElseUpdate(header.id, createNewMessage)
+ logTrace(
+ "Receiving chunk of [" + message + "] from [" + getRemoteConnectionManagerId() + "]")
+ message.getChunkForReceiving(header.chunkSize)
+ }
+
+ def getMessageForChunk(chunk: MessageChunk): Option[BufferMessage] = {
+ messages.get(chunk.header.id)
+ }
+
+ def removeMessage(message: Message) {
+ messages -= message.id
+ }
+ }
+
+ @volatile private var inferredRemoteManagerId: ConnectionManagerId = null
+
+ override def getRemoteConnectionManagerId(): ConnectionManagerId = {
+ val currId = inferredRemoteManagerId
+ if (currId != null) currId else super.getRemoteConnectionManagerId()
+ }
+
+ // The receiver's remote address is the local socket on the remote side, which is NOT
+ // the connection manager id of the receiver.
+ // We infer that from the messages we receive on the receiver socket.
+ private def processConnectionManagerId(header: MessageChunkHeader) {
+ val currId = inferredRemoteManagerId
+ if (header.address == null || currId != null) return
+
+ val managerId = ConnectionManagerId.fromSocketAddress(header.address)
+
+ if (managerId != null) {
+ inferredRemoteManagerId = managerId
+ }
+ }
+
+
+ val inbox = new Inbox()
+ val headerBuffer: ByteBuffer = ByteBuffer.allocate(MessageChunkHeader.HEADER_SIZE)
+ var onReceiveCallback: (Connection, Message) => Unit = null
+ var currentChunk: MessageChunk = null
+
+ channel.register(selector, SelectionKey.OP_READ)
+
+ override def read(): Boolean = {
+ try {
+ while (true) {
+ if (currentChunk == null) {
+ val headerBytesRead = channel.read(headerBuffer)
+ if (headerBytesRead == -1) {
+ close()
+ return false
+ }
+ if (headerBuffer.remaining > 0) {
+ // re-register for read event ...
+ return true
+ }
+ headerBuffer.flip
+ if (headerBuffer.remaining != MessageChunkHeader.HEADER_SIZE) {
+ throw new Exception(
+ "Unexpected number of bytes (" + headerBuffer.remaining + ") in the header")
+ }
+ val header = MessageChunkHeader.create(headerBuffer)
+ headerBuffer.clear()
+
+ processConnectionManagerId(header)
+
+ header.typ match {
+ case Message.BUFFER_MESSAGE => {
+ if (header.totalSize == 0) {
+ if (onReceiveCallback != null) {
+ onReceiveCallback(this, Message.create(header))
+ }
+ currentChunk = null
+ // re-register for read event ...
+ return true
+ } else {
+ currentChunk = inbox.getChunk(header).orNull
+ }
+ }
+ case _ => throw new Exception("Message of unknown type received")
+ }
+ }
+
+ if (currentChunk == null) throw new Exception("No message chunk to receive data")
+
+ val bytesRead = channel.read(currentChunk.buffer)
+ if (bytesRead == 0) {
+ // re-register for read event ...
+ return true
+ } else if (bytesRead == -1) {
+ close()
+ return false
+ }
+
+ /* logDebug("Read " + bytesRead + " bytes for the buffer") */
+
+ if (currentChunk.buffer.remaining == 0) {
+ /* println("Filled buffer at " + System.currentTimeMillis) */
+ val bufferMessage = inbox.getMessageForChunk(currentChunk).get
+ if (bufferMessage.isCompletelyReceived) {
+ bufferMessage.flip()
+ bufferMessage.finishTime = System.currentTimeMillis
+ logDebug("Finished receiving [" + bufferMessage + "] from " +
+ "[" + getRemoteConnectionManagerId() + "] in " + bufferMessage.timeTaken)
+ if (onReceiveCallback != null) {
+ onReceiveCallback(this, bufferMessage)
+ }
+ inbox.removeMessage(bufferMessage)
+ }
+ currentChunk = null
+ }
+ }
+ } catch {
+ case e: Exception => {
+ logWarning("Error reading from connection to " + getRemoteConnectionManagerId(), e)
+ callOnExceptionCallback(e)
+ close()
+ return false
+ }
+ }
+ // should not happen - to keep scala compiler happy
+ true
+ }
+
+ def onReceive(callback: (Connection, Message) => Unit) { onReceiveCallback = callback }
+
+ // override def changeInterestForRead(): Boolean = ! isClosed
+ override def changeInterestForRead(): Boolean = true
+
+ override def changeInterestForWrite(): Boolean = {
+ throw new IllegalStateException("Unexpected invocation right now")
+ }
+
+ override def registerInterest() {
+ // Registering read too - does not really help in most cases, but for some
+ // it does - so let us keep it for now.
+ changeConnectionKeyInterest(SelectionKey.OP_READ)
+ }
+
+ override def unregisterInterest() {
+ changeConnectionKeyInterest(0)
+ }
+
+ // For read conn, always false.
+ override def resetForceReregister(): Boolean = false
+}
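A key invariant in Connection and its manager is that SelectionKey.interestOps is
only mutated from the selector thread; connections request changes through the
onKeyInterestChange callback, and the manager applies them before the next select().
A minimal sketch of that hand-off (illustrative names, not part of the patch):

    import java.nio.channels.{SelectionKey, Selector}
    import scala.collection.mutable.Queue

    class InterestOpsQueue(selector: Selector) {
      private val pending = new Queue[(SelectionKey, Int)]()

      // Called from any thread, e.g. by a connection that wants OP_WRITE.
      def request(key: SelectionKey, ops: Int): Unit = pending.synchronized {
        pending.enqueue((key, ops))
        selector.wakeup()  // unblock select() so the change is applied promptly
      }

      // Called on the selector thread, before select().
      def applyPending(): Unit = pending.synchronized {
        while (pending.nonEmpty) {
          val (key, ops) = pending.dequeue()
          if (key.isValid) key.interestOps(ops)
        }
      }
    }
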
http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/network/nio/ConnectionId.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionId.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionId.scala
new file mode 100644
index 0000000..764dc5e
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionId.scala
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.nio
+
+private[nio] case class ConnectionId(connectionManagerId: ConnectionManagerId, uniqId: Int) {
+ override def toString = connectionManagerId.host + "_" + connectionManagerId.port + "_" + uniqId
+}
+
+private[nio] object ConnectionId {
+
+ def createConnectionIdFromString(connectionIdString: String): ConnectionId = {
+ val res = connectionIdString.split("_").map(_.trim())
+ if (res.size != 3) {
+ throw new Exception("Error converting ConnectionId string: " + connectionIdString +
+ " to a ConnectionId Object")
+ }
+ new ConnectionId(new ConnectionManagerId(res(0), res(1).toInt), res(2).toInt)
+ }
+}
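Since ConnectionId.toString emits "host_port_uniqId" and createConnectionIdFromString
splits on the same underscores, the two round-trip. A quick usage sketch (both classes
are private[nio], so this only compiles inside that package; it also assumes
ConnectionManagerId is the case class defined in its own file below):

    val id = ConnectionId(new ConnectionManagerId("localhost", 9999), 42)
    assert(id.toString == "localhost_9999_42")
    assert(ConnectionId.createConnectionIdFromString(id.toString) == id)
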
http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala
new file mode 100644
index 0000000..09d3ea3
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala
@@ -0,0 +1,1042 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.nio
+
+import java.io.IOException
+import java.net._
+import java.nio._
+import java.nio.channels._
+import java.nio.channels.spi._
+import java.util.concurrent.atomic.AtomicInteger
+import java.util.concurrent.{LinkedBlockingDeque, ThreadPoolExecutor, TimeUnit}
+import java.util.{Timer, TimerTask}
+
+import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, SynchronizedMap, SynchronizedQueue}
+import scala.concurrent.duration._
+import scala.concurrent.{Await, ExecutionContext, Future, Promise}
+import scala.language.postfixOps
+
+import org.apache.spark._
+import org.apache.spark.util.{SystemClock, Utils}
+
+
+private[nio] class ConnectionManager(
+ port: Int,
+ conf: SparkConf,
+ securityManager: SecurityManager,
+ name: String = "Connection manager")
+ extends Logging {
+
+ /**
+ * Used by sendMessageReliably to track messages being sent.
+ * @param message the message that was sent
+ * @param connectionManagerId the connection manager this message was sent to
+ * @param completionHandler callback that's invoked when the send has completed or failed
+ */
+ class MessageStatus(
+ val message: Message,
+ val connectionManagerId: ConnectionManagerId,
+ completionHandler: MessageStatus => Unit) {
+
+ /** This is non-None if message has been ack'd */
+ var ackMessage: Option[Message] = None
+
+ def markDone(ackMessage: Option[Message]) {
+ this.ackMessage = ackMessage
+ completionHandler(this)
+ }
+ }
+
+ private val selector = SelectorProvider.provider.openSelector()
+ private val ackTimeoutMonitor = new Timer("AckTimeoutMonitor", true)
+
+ // default to 30 second timeout waiting for authentication
+ private val authTimeout = conf.getInt("spark.core.connection.auth.wait.timeout", 30)
+ private val ackTimeout = conf.getInt("spark.core.connection.ack.wait.timeout", 60)
+
+ private val handleMessageExecutor = new ThreadPoolExecutor(
+ conf.getInt("spark.core.connection.handler.threads.min", 20),
+ conf.getInt("spark.core.connection.handler.threads.max", 60),
+ conf.getInt("spark.core.connection.handler.threads.keepalive", 60), TimeUnit.SECONDS,
+ new LinkedBlockingDeque[Runnable](),
+ Utils.namedThreadFactory("handle-message-executor"))
+
+ private val handleReadWriteExecutor = new ThreadPoolExecutor(
+ conf.getInt("spark.core.connection.io.threads.min", 4),
+ conf.getInt("spark.core.connection.io.threads.max", 32),
+ conf.getInt("spark.core.connection.io.threads.keepalive", 60), TimeUnit.SECONDS,
+ new LinkedBlockingDeque[Runnable](),
+ Utils.namedThreadFactory("handle-read-write-executor"))
+
+ // Use a different, smaller thread pool - it is used infrequently, with very short-lived
+ // tasks that should be executed ASAP.
+ private val handleConnectExecutor = new ThreadPoolExecutor(
+ conf.getInt("spark.core.connection.connect.threads.min", 1),
+ conf.getInt("spark.core.connection.connect.threads.max", 8),
+ conf.getInt("spark.core.connection.connect.threads.keepalive", 60), TimeUnit.SECONDS,
+ new LinkedBlockingDeque[Runnable](),
+ Utils.namedThreadFactory("handle-connect-executor"))
+
+ private val serverChannel = ServerSocketChannel.open()
+ // used to track the SendingConnections waiting to do SASL negotiation
+ private val connectionsAwaitingSasl = new HashMap[ConnectionId, SendingConnection]
+ with SynchronizedMap[ConnectionId, SendingConnection]
+ private val connectionsByKey =
+ new HashMap[SelectionKey, Connection] with SynchronizedMap[SelectionKey, Connection]
+ private val connectionsById = new HashMap[ConnectionManagerId, SendingConnection]
+ with SynchronizedMap[ConnectionManagerId, SendingConnection]
+ private val messageStatuses = new HashMap[Int, MessageStatus]
+ private val keyInterestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)]
+ private val registerRequests = new SynchronizedQueue[SendingConnection]
+
+ implicit val futureExecContext = ExecutionContext.fromExecutor(
+ Utils.newDaemonCachedThreadPool("Connection manager future execution context"))
+
+ @volatile
+ private var onReceiveCallback: (BufferMessage, ConnectionManagerId) => Option[Message] = null
+
+ private val authEnabled = securityManager.isAuthenticationEnabled()
+
+ serverChannel.configureBlocking(false)
+ serverChannel.socket.setReuseAddress(true)
+ serverChannel.socket.setReceiveBufferSize(256 * 1024)
+
+ private def startService(port: Int): (ServerSocketChannel, Int) = {
+ serverChannel.socket.bind(new InetSocketAddress(port))
+ (serverChannel, serverChannel.socket.getLocalPort)
+ }
+ Utils.startServiceOnPort[ServerSocketChannel](port, startService, name)
+ serverChannel.register(selector, SelectionKey.OP_ACCEPT)
+
+ val id = new ConnectionManagerId(Utils.localHostName, serverChannel.socket.getLocalPort)
+ logInfo("Bound socket to port " + serverChannel.socket.getLocalPort() + " with id = " + id)
+
+ // used in combination with the ConnectionManagerId to create unique Connection ids
+ // to be able to track asynchronous messages
+ private val idCount: AtomicInteger = new AtomicInteger(1)
+
+ private val selectorThread = new Thread("connection-manager-thread") {
+ override def run() = ConnectionManager.this.run()
+ }
+ selectorThread.setDaemon(true)
+ selectorThread.start()
+
+ private val writeRunnableStarted: HashSet[SelectionKey] = new HashSet[SelectionKey]()
+
+ private def triggerWrite(key: SelectionKey) {
+ val conn = connectionsByKey.getOrElse(key, null)
+ if (conn == null) return
+
+ writeRunnableStarted.synchronized {
+ // So that we do not trigger more write events while processing this one.
+ // The write method will re-register when done.
+ if (conn.changeInterestForWrite()) conn.unregisterInterest()
+ if (writeRunnableStarted.contains(key)) {
+ // key.interestOps(key.interestOps() & ~ SelectionKey.OP_WRITE)
+ return
+ }
+
+ writeRunnableStarted += key
+ }
+ handleReadWriteExecutor.execute(new Runnable {
+ override def run() {
+ var register: Boolean = false
+ try {
+ register = conn.write()
+ } finally {
+ writeRunnableStarted.synchronized {
+ writeRunnableStarted -= key
+ val needReregister = register || conn.resetForceReregister()
+ if (needReregister && conn.changeInterestForWrite()) {
+ conn.registerInterest()
+ }
+ }
+ }
+ }
+ } )
+ }
+
+ private val readRunnableStarted: HashSet[SelectionKey] = new HashSet[SelectionKey]()
+
+ private def triggerRead(key: SelectionKey) {
+ val conn = connectionsByKey.getOrElse(key, null)
+ if (conn == null) return
+
+ readRunnableStarted.synchronized {
+ // So that we do not trigger more read events while processing this one.
+ // The read method will re-register when done.
+ if (conn.changeInterestForRead()) conn.unregisterInterest()
+ if (readRunnableStarted.contains(key)) {
+ return
+ }
+
+ readRunnableStarted += key
+ }
+ handleReadWriteExecutor.execute(new Runnable {
+ override def run() {
+ var register: Boolean = false
+ try {
+ register = conn.read()
+ } finally {
+ readRunnableStarted.synchronized {
+ readRunnableStarted -= key
+ if (register && conn.changeInterestForRead()) {
+ conn.registerInterest()
+ }
+ }
+ }
+ }
+ } )
+ }
+
+ private def triggerConnect(key: SelectionKey) {
+ val conn = connectionsByKey.getOrElse(key, null).asInstanceOf[SendingConnection]
+ if (conn == null) return
+
+ // prevent other events from being triggered
+ // Since we are still trying to connect, we do not need to do the additional steps in
+ // triggerWrite
+ conn.changeConnectionKeyInterest(0)
+
+ handleConnectExecutor.execute(new Runnable {
+ override def run() {
+
+ var tries: Int = 10
+ while (tries >= 0) {
+ if (conn.finishConnect(false)) return
+ // Sleep ?
+ Thread.sleep(1)
+ tries -= 1
+ }
+
+ // Fall back to previous behavior. We should not really get here, since this method is
+ // triggered when the channel becomes connectable, but at times the first finishConnect
+ // does not succeed - hence the loop to retry a few times.
+ conn.finishConnect(true)
+ }
+ } )
+ }
+
+ // MUST be called within selector loop - else deadlock.
+ private def triggerForceCloseByException(key: SelectionKey, e: Exception) {
+ try {
+ key.interestOps(0)
+ } catch {
+ // ignore exceptions
+ case e: Exception => logDebug("Ignoring exception", e)
+ }
+
+ val conn = connectionsByKey.getOrElse(key, null)
+ if (conn == null) return
+
+ // Pushing to connect threadpool
+ handleConnectExecutor.execute(new Runnable {
+ override def run() {
+ try {
+ conn.callOnExceptionCallback(e)
+ } catch {
+ // ignore exceptions
+ case e: Exception => logDebug("Ignoring exception", e)
+ }
+ try {
+ conn.close()
+ } catch {
+ // ignore exceptions
+ case e: Exception => logDebug("Ignoring exception", e)
+ }
+ }
+ })
+ }
+
+
+ def run() {
+ try {
+ while (!selectorThread.isInterrupted) {
+ while (!registerRequests.isEmpty) {
+ val conn: SendingConnection = registerRequests.dequeue()
+ addListeners(conn)
+ conn.connect()
+ addConnection(conn)
+ }
+
+ while (!keyInterestChangeRequests.isEmpty) {
+ val (key, ops) = keyInterestChangeRequests.dequeue()
+
+ try {
+ if (key.isValid) {
+ val connection = connectionsByKey.getOrElse(key, null)
+ if (connection != null) {
+ val lastOps = key.interestOps()
+ key.interestOps(ops)
+
+ // hot loop - prevent materialization of string if trace not enabled.
+ if (isTraceEnabled()) {
+ def intToOpStr(op: Int): String = {
+ val opStrs = ArrayBuffer[String]()
+ if ((op & SelectionKey.OP_READ) != 0) opStrs += "READ"
+ if ((op & SelectionKey.OP_WRITE) != 0) opStrs += "WRITE"
+ if ((op & SelectionKey.OP_CONNECT) != 0) opStrs += "CONNECT"
+ if ((op & SelectionKey.OP_ACCEPT) != 0) opStrs += "ACCEPT"
+ if (opStrs.size > 0) opStrs.reduceLeft(_ + " | " + _) else " "
+ }
+
+ logTrace("Changed key for connection to [" +
+ connection.getRemoteConnectionManagerId() + "] changed from [" +
+ intToOpStr(lastOps) + "] to [" + intToOpStr(ops) + "]")
+ }
+ }
+ } else {
+ logInfo("Key not valid ? " + key)
+ throw new CancelledKeyException()
+ }
+ } catch {
+ case e: CancelledKeyException => {
+ logInfo("key already cancelled ? " + key, e)
+ triggerForceCloseByException(key, e)
+ }
+ case e: Exception => {
+ logError("Exception processing key " + key, e)
+ triggerForceCloseByException(key, e)
+ }
+ }
+ }
+
+ val selectedKeysCount =
+ try {
+ selector.select()
+ } catch {
+ // Explicitly only dealing with CancelledKeyException here since other exceptions
+ // should be dealt with differently.
+ case e: CancelledKeyException => {
+ // Some keys within the selectors list are invalid/closed. clear them.
+ val allKeys = selector.keys().iterator()
+
+ while (allKeys.hasNext) {
+ val key = allKeys.next()
+ try {
+ if (! key.isValid) {
+ logInfo("Key not valid ? " + key)
+ throw new CancelledKeyException()
+ }
+ } catch {
+ case e: CancelledKeyException => {
+ logInfo("key already cancelled ? " + key, e)
+ triggerForceCloseByException(key, e)
+ }
+ case e: Exception => {
+ logError("Exception processing key " + key, e)
+ triggerForceCloseByException(key, e)
+ }
+ }
+ }
+ }
+ 0
+ }
+
+ if (selectedKeysCount == 0) {
+ logDebug("Selector selected " + selectedKeysCount + " of " + selector.keys.size +
+ " keys")
+ }
+ if (selectorThread.isInterrupted) {
+ logInfo("Selector thread was interrupted!")
+ return
+ }
+
+ if (0 != selectedKeysCount) {
+ val selectedKeys = selector.selectedKeys().iterator()
+ while (selectedKeys.hasNext) {
+ val key = selectedKeys.next
+ selectedKeys.remove()
+ try {
+ if (key.isValid) {
+ if (key.isAcceptable) {
+ acceptConnection(key)
+ } else if (key.isConnectable) {
+ triggerConnect(key)
+ } else if (key.isReadable) {
+ triggerRead(key)
+ } else if (key.isWritable) {
+ triggerWrite(key)
+ }
+ } else {
+ logInfo("Key not valid ? " + key)
+ throw new CancelledKeyException()
+ }
+ } catch {
+ // weird, but we saw this happening - even though key.isValid was true,
+ // key.isAcceptable would throw CancelledKeyException.
+ case e: CancelledKeyException => {
+ logInfo("key already cancelled ? " + key, e)
+ triggerForceCloseByException(key, e)
+ }
+ case e: Exception => {
+ logError("Exception processing key " + key, e)
+ triggerForceCloseByException(key, e)
+ }
+ }
+ }
+ }
+ }
+ } catch {
+ case e: Exception => logError("Error in select loop", e)
+ }
+ }
+
+ def acceptConnection(key: SelectionKey) {
+ val serverChannel = key.channel.asInstanceOf[ServerSocketChannel]
+
+ var newChannel = serverChannel.accept()
+
+ // Accept them all in a tight loop. A non-blocking accept with no processing should be fine.
+ while (newChannel != null) {
+ try {
+ val newConnectionId = new ConnectionId(id, idCount.getAndIncrement.intValue)
+ val newConnection = new ReceivingConnection(newChannel, selector, newConnectionId)
+ newConnection.onReceive(receiveMessage)
+ addListeners(newConnection)
+ addConnection(newConnection)
+ logInfo("Accepted connection from [" + newConnection.remoteAddress + "]")
+ } catch {
+ // might happen in case of issues with registering with selector
+ case e: Exception => logError("Error in accept loop", e)
+ }
+
+ newChannel = serverChannel.accept()
+ }
+ }
+
+ private def addListeners(connection: Connection) {
+ connection.onKeyInterestChange(changeConnectionKeyInterest)
+ connection.onException(handleConnectionError)
+ connection.onClose(removeConnection)
+ }
+
+ def addConnection(connection: Connection) {
+ connectionsByKey += ((connection.key, connection))
+ }
+
+ def removeConnection(connection: Connection) {
+ connectionsByKey -= connection.key
+
+ try {
+ connection match {
+ case sendingConnection: SendingConnection =>
+ val sendingConnectionManagerId = sendingConnection.getRemoteConnectionManagerId()
+ logInfo("Removing SendingConnection to " + sendingConnectionManagerId)
+
+ connectionsById -= sendingConnectionManagerId
+ connectionsAwaitingSasl -= connection.connectionId
+
+ messageStatuses.synchronized {
+ messageStatuses.values.filter(_.connectionManagerId == sendingConnectionManagerId)
+ .foreach(status => {
+ logInfo("Notifying " + status)
+ status.markDone(None)
+ })
+
+ messageStatuses.retain((i, status) => {
+ status.connectionManagerId != sendingConnectionManagerId
+ })
+ }
+ case receivingConnection: ReceivingConnection =>
+ val remoteConnectionManagerId = receivingConnection.getRemoteConnectionManagerId()
+ logInfo("Removing ReceivingConnection to " + remoteConnectionManagerId)
+
+ val sendingConnectionOpt = connectionsById.get(remoteConnectionManagerId)
+ if (!sendingConnectionOpt.isDefined) {
+ logError(s"Corresponding SendingConnection to ${remoteConnectionManagerId} not found")
+ return
+ }
+
+ val sendingConnection = sendingConnectionOpt.get
+ connectionsById -= remoteConnectionManagerId
+ sendingConnection.close()
+
+ val sendingConnectionManagerId = sendingConnection.getRemoteConnectionManagerId()
+
+ assert(sendingConnectionManagerId == remoteConnectionManagerId)
+
+ messageStatuses.synchronized {
+ for (s <- messageStatuses.values
+ if s.connectionManagerId == sendingConnectionManagerId) {
+ logInfo("Notifying " + s)
+ s.markDone(None)
+ }
+
+ messageStatuses.retain((i, status) => {
+ status.connectionManagerId != sendingConnectionManagerId
+ })
+ }
+ case _ => logError("Unsupported type of connection.")
+ }
+ } finally {
+ // So that the selection keys can be removed.
+ wakeupSelector()
+ }
+ }
+
+ def handleConnectionError(connection: Connection, e: Exception) {
+ logInfo("Handling connection error on connection to " +
+ connection.getRemoteConnectionManagerId())
+ removeConnection(connection)
+ }
+
+ def changeConnectionKeyInterest(connection: Connection, ops: Int) {
+ keyInterestChangeRequests += ((connection.key, ops))
+ // so that registrations happen!
+ wakeupSelector()
+ }
+
+ def receiveMessage(connection: Connection, message: Message) {
+ val connectionManagerId = ConnectionManagerId.fromSocketAddress(message.senderAddress)
+ logDebug("Received [" + message + "] from [" + connectionManagerId + "]")
+ val runnable = new Runnable() {
+ val creationTime = System.currentTimeMillis
+ def run() {
+ logDebug("Handler thread delay is " + (System.currentTimeMillis - creationTime) + " ms")
+ handleMessage(connectionManagerId, message, connection)
+ logDebug("Handling delay is " + (System.currentTimeMillis - creationTime) + " ms")
+ }
+ }
+ handleMessageExecutor.execute(runnable)
+ /* handleMessage(connection, message) */
+ }
+
+ private def handleClientAuthentication(
+ waitingConn: SendingConnection,
+ securityMsg: SecurityMessage,
+ connectionId : ConnectionId) {
+ if (waitingConn.isSaslComplete()) {
+ logDebug("Client sasl completed for id: " + waitingConn.connectionId)
+ connectionsAwaitingSasl -= waitingConn.connectionId
+ waitingConn.getAuthenticated().synchronized {
+ waitingConn.getAuthenticated().notifyAll()
+ }
+ return
+ } else {
+ var replyToken : Array[Byte] = null
+ try {
+ replyToken = waitingConn.sparkSaslClient.saslResponse(securityMsg.getToken)
+ if (waitingConn.isSaslComplete()) {
+ logDebug("Client sasl completed after evaluate for id: " + waitingConn.connectionId)
+ connectionsAwaitingSasl -= waitingConn.connectionId
+ waitingConn.getAuthenticated().synchronized {
+ waitingConn.getAuthenticated().notifyAll()
+ }
+ return
+ }
+ val securityMsgResp = SecurityMessage.fromResponse(replyToken,
+ securityMsg.getConnectionId.toString)
+ val message = securityMsgResp.toBufferMessage
+ if (message == null) throw new IOException("Error creating security message")
+ sendSecurityMessage(waitingConn.getRemoteConnectionManagerId(), message)
+ } catch {
+ case e: Exception => {
+ logError("Error handling sasl client authentication", e)
+ waitingConn.close()
+ throw new IOException("Error evaluating sasl response: ", e)
+ }
+ }
+ }
+ }
+
+ private def handleServerAuthentication(
+ connection: Connection,
+ securityMsg: SecurityMessage,
+ connectionId: ConnectionId) {
+ if (!connection.isSaslComplete()) {
+ logDebug("saslContext not established")
+ var replyToken : Array[Byte] = null
+ try {
+ connection.synchronized {
+ if (connection.sparkSaslServer == null) {
+ logDebug("Creating sasl Server")
+ connection.sparkSaslServer = new SparkSaslServer(securityManager)
+ }
+ }
+ replyToken = connection.sparkSaslServer.response(securityMsg.getToken)
+ if (connection.isSaslComplete()) {
+ logDebug("Server sasl completed: " + connection.connectionId)
+ } else {
+ logDebug("Server sasl not completed: " + connection.connectionId)
+ }
+ if (replyToken != null) {
+ val securityMsgResp = SecurityMessage.fromResponse(replyToken,
+ securityMsg.getConnectionId)
+ val message = securityMsgResp.toBufferMessage
+ if (message == null) throw new Exception("Error creating security Message")
+ sendSecurityMessage(connection.getRemoteConnectionManagerId(), message)
+ }
+ } catch {
+ case e: Exception => {
+ logError("Error in server auth negotiation: " + e)
+ // It would probably be better to send an error message telling other side auth failed
+ // but for now just close
+ connection.close()
+ }
+ }
+ } else {
+ logDebug("connection already established for this connection id: " + connection.connectionId)
+ }
+ }
+
+
+ private def handleAuthentication(conn: Connection, bufferMessage: BufferMessage): Boolean = {
+ if (bufferMessage.isSecurityNeg) {
+ logDebug("This is security neg message")
+
+ // parse as SecurityMessage
+ val securityMsg = SecurityMessage.fromBufferMessage(bufferMessage)
+ val connectionId = ConnectionId.createConnectionIdFromString(securityMsg.getConnectionId)
+
+ connectionsAwaitingSasl.get(connectionId) match {
+ case Some(waitingConn) => {
+ // Client - this must be in response to us doing Send
+ logDebug("Client handleAuth for id: " + waitingConn.connectionId)
+ handleClientAuthentication(waitingConn, securityMsg, connectionId)
+ }
+ case None => {
+ // Server - someone sent us something and we haven't authenticated yet
+ logDebug("Server handleAuth for id: " + connectionId)
+ handleServerAuthentication(conn, securityMsg, connectionId)
+ }
+ }
+ return true
+ } else {
+ if (!conn.isSaslComplete()) {
+ // We could handle this better and tell the client we need to do authentication
+ // negotiation, but for now just ignore them.
+ logError("message sent that is not security negotiation message on connection " +
+ "not authenticated yet, ignoring it!!")
+ return true
+ }
+ }
+ false
+ }
+
+ private def handleMessage(
+ connectionManagerId: ConnectionManagerId,
+ message: Message,
+ connection: Connection) {
+ logDebug("Handling [" + message + "] from [" + connectionManagerId + "]")
+ message match {
+ case bufferMessage: BufferMessage => {
+ if (authEnabled) {
+ val res = handleAuthentication(connection, bufferMessage)
+ if (res) {
+ // message was security negotiation so skip the rest
+ logDebug("After handleAuth result was true, returning")
+ return
+ }
+ }
+ if (bufferMessage.hasAckId()) {
+ messageStatuses.synchronized {
+ messageStatuses.get(bufferMessage.ackId) match {
+ case Some(status) => {
+ messageStatuses -= bufferMessage.ackId
+ status.markDone(Some(message))
+ }
+ case None => {
+ /**
+ * We can reach this code in two cases:
+ *
+ * (1) An invalid ack was sent due to buggy code.
+ *
+ * (2) A late-arriving ack for a MessageStatus that has already timed out.
+ * To avoid late-arriving acks caused by long pauses such as GC, set
+ * spark.core.connection.ack.wait.timeout to a value larger than the default.
+ */
+ logWarning(s"Could not find reference for received ack Message ${message.id}")
+ }
+ }
+ }
+ } else {
+ var ackMessage : Option[Message] = None
+ try {
+ ackMessage = if (onReceiveCallback != null) {
+ logDebug("Calling back")
+ onReceiveCallback(bufferMessage, connectionManagerId)
+ } else {
+ logDebug("Not calling back as callback is null")
+ None
+ }
+
+ if (ackMessage.isDefined) {
+ if (!ackMessage.get.isInstanceOf[BufferMessage]) {
+ logDebug("Response to " + bufferMessage + " is not a buffer message, it is of type "
+ + ackMessage.get.getClass)
+ } else if (!ackMessage.get.asInstanceOf[BufferMessage].hasAckId) {
+ logDebug("Response to " + bufferMessage + " does not have ack id set")
+ ackMessage.get.asInstanceOf[BufferMessage].ackId = bufferMessage.id
+ }
+ }
+ } catch {
+ case e: Exception => {
+ logError(s"Exception was thrown while processing message", e)
+ val m = Message.createBufferMessage(bufferMessage.id)
+ m.hasError = true
+ ackMessage = Some(m)
+ }
+ } finally {
+ sendMessage(connectionManagerId, ackMessage.getOrElse {
+ Message.createBufferMessage(bufferMessage.id)
+ })
+ }
+ }
+ }
+ case _ => throw new Exception("Unknown type message received")
+ }
+ }
+
+ private def checkSendAuthFirst(connManagerId: ConnectionManagerId, conn: SendingConnection) {
+ // See if we need to do SASL negotiation before writing.
+ // This should only happen on the first negotiation as the client.
+ if (!conn.isSaslComplete()) {
+ conn.synchronized {
+ if (conn.sparkSaslClient == null) {
+ conn.sparkSaslClient = new SparkSaslClient(securityManager)
+ var firstResponse: Array[Byte] = null
+ try {
+ firstResponse = conn.sparkSaslClient.firstToken()
+ val securityMsg = SecurityMessage.fromResponse(firstResponse,
+ conn.connectionId.toString())
+ val message = securityMsg.toBufferMessage
+ if (message == null) throw new Exception("Error creating security message")
+ connectionsAwaitingSasl += ((conn.connectionId, conn))
+ sendSecurityMessage(connManagerId, message)
+ logDebug("adding connectionsAwaitingSasl id: " + conn.connectionId)
+ } catch {
+ case e: Exception => {
+ logError("Error getting first response from the SaslClient.", e)
+ conn.close()
+ throw new Exception("Error getting first response from the SaslClient")
+ }
+ }
+ }
+ }
+ } else {
+ logDebug("Sasl already established ")
+ }
+ }
+
+ // Allow us to add messages to the outbox for SASL negotiation.
+ private def sendSecurityMessage(connManagerId: ConnectionManagerId, message: Message) {
+ def startNewConnection(): SendingConnection = {
+ val inetSocketAddress = new InetSocketAddress(connManagerId.host, connManagerId.port)
+ val newConnectionId = new ConnectionId(id, idCount.getAndIncrement.intValue)
+ val newConnection = new SendingConnection(inetSocketAddress, selector, connManagerId,
+ newConnectionId)
+ logInfo("creating new sending connection for security! " + newConnectionId )
+ registerRequests.enqueue(newConnection)
+
+ newConnection
+ }
+ // I removed the lookupKey stuff as part of merge ... should I re-add it?
+ // We did not find it useful in our test-env ...
+ // If we do re-add it, we should consistently use it everywhere, I guess.
+ message.senderAddress = id.toSocketAddress()
+ logTrace("Sending Security [" + message + "] to [" + connManagerId + "]")
+ val connection = connectionsById.getOrElseUpdate(connManagerId, startNewConnection())
+
+ // Security messages may be sent even before the outgoing connection is authenticated.
+ connection.send(message)
+
+ wakeupSelector()
+ }
+
+ private def sendMessage(connectionManagerId: ConnectionManagerId, message: Message) {
+ def startNewConnection(): SendingConnection = {
+ val inetSocketAddress = new InetSocketAddress(connectionManagerId.host,
+ connectionManagerId.port)
+ val newConnectionId = new ConnectionId(id, idCount.getAndIncrement.intValue)
+ val newConnection = new SendingConnection(inetSocketAddress, selector, connectionManagerId,
+ newConnectionId)
+ logTrace("creating new sending connection: " + newConnectionId)
+ registerRequests.enqueue(newConnection)
+
+ newConnection
+ }
+ val connection = connectionsById.getOrElseUpdate(connectionManagerId, startNewConnection())
+ if (authEnabled) {
+ checkSendAuthFirst(connectionManagerId, connection)
+ }
+ message.senderAddress = id.toSocketAddress()
+ logDebug("Before Sending [" + message + "] to [" + connectionManagerId + "]" + " " +
+ "connectionid: " + connection.connectionId)
+
+ if (authEnabled) {
+ // If we aren't authenticated yet, block the senders until authentication completes.
+ try {
+ connection.getAuthenticated().synchronized {
+ val clock = SystemClock
+ val startTime = clock.getTime()
+
+ while (!connection.isSaslComplete()) {
+ logDebug("getAuthenticated wait connectionid: " + connection.connectionId)
+ // have timeout in case remote side never responds
+ connection.getAuthenticated().wait(500)
+ if (((clock.getTime() - startTime) >= (authTimeout * 1000))
+ && (!connection.isSaslComplete())) {
+ // Took too long to authenticate the connection; something probably went wrong.
+ throw new Exception("Took too long for authentication to " + connectionManagerId +
+ ", waited " + authTimeout + " seconds, failing.")
+ }
+ }
+ }
+ } catch {
+ case e: Exception => logError("Exception while waiting for authentication.", e)
+
+ // need to tell sender it failed
+ messageStatuses.synchronized {
+ val s = messageStatuses.get(message.id)
+ s match {
+ case Some(msgStatus) => {
+ messageStatuses -= message.id
+ logInfo("Notifying " + msgStatus.connectionManagerId)
+ msgStatus.markDone(None)
+ }
+ case None => {
+ logError("no messageStatus for failed message id: " + message.id)
+ }
+ }
+ }
+ }
+ }
+ logDebug("Sending [" + message + "] to [" + connectionManagerId + "]")
+ connection.send(message)
+
+ wakeupSelector()
+ }
+
+ private def wakeupSelector() {
+ selector.wakeup()
+ }
+
+ /**
+ * Send a message; the returned Future completes when an acknowledgment is received
+ * or an error occurs.
+ * @param connectionManagerId the message's destination
+ * @param message the message being sent
+ * @return a Future that either returns the acknowledgment message or captures an exception.
+ */
+ def sendMessageReliably(connectionManagerId: ConnectionManagerId, message: Message)
+ : Future[Message] = {
+ val promise = Promise[Message]()
+
+ val timeoutTask = new TimerTask {
+ override def run(): Unit = {
+ messageStatuses.synchronized {
+ messageStatuses.remove(message.id).foreach ( s => {
+ promise.failure(
+ new IOException("sendMessageReliably failed because ack " +
+ s"was not received within $ackTimeout sec"))
+ })
+ }
+ }
+ }
+
+ val status = new MessageStatus(message, connectionManagerId, s => {
+ timeoutTask.cancel()
+ s.ackMessage match {
+ case None => // Indicates a failure where we either never sent or never got ACK'd
+ promise.failure(new IOException("sendMessageReliably failed without being ACK'd"))
+ case Some(ackMessage) =>
+ if (ackMessage.hasError) {
+ promise.failure(
+ new IOException("sendMessageReliably failed with ACK that signalled a remote error"))
+ } else {
+ promise.success(ackMessage)
+ }
+ }
+ })
+ messageStatuses.synchronized {
+ messageStatuses += ((message.id, status))
+ }
+
+ ackTimeoutMonitor.schedule(timeoutTask, ackTimeout * 1000)
+ sendMessage(connectionManagerId, message)
+ promise.future
+ }
+
+ def onReceiveMessage(callback: (Message, ConnectionManagerId) => Option[Message]) {
+ onReceiveCallback = callback
+ }
+
+ def stop() {
+ ackTimeoutMonitor.cancel()
+ selectorThread.interrupt()
+ selectorThread.join()
+ selector.close()
+ val connections = connectionsByKey.values
+ connections.foreach(_.close())
+ if (connectionsByKey.size != 0) {
+ logWarning("All connections not cleaned up")
+ }
+ handleMessageExecutor.shutdown()
+ handleReadWriteExecutor.shutdown()
+ handleConnectExecutor.shutdown()
+ logInfo("ConnectionManager stopped")
+ }
+}
+
+
+private[spark] object ConnectionManager {
+ import scala.concurrent.ExecutionContext.Implicits.global
+
+ def main(args: Array[String]) {
+ val conf = new SparkConf
+ val manager = new ConnectionManager(9999, conf, new SecurityManager(conf))
+ manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
+ println("Received [" + msg + "] from [" + id + "]")
+ None
+ })
+
+ /* testSequentialSending(manager) */
+ /* System.gc() */
+
+ /* testParallelSending(manager) */
+ /* System.gc() */
+
+ /* testParallelDecreasingSending(manager) */
+ /* System.gc() */
+
+ testContinuousSending(manager)
+ System.gc()
+ }
+
+ def testSequentialSending(manager: ConnectionManager) {
+ println("--------------------------")
+ println("Sequential Sending")
+ println("--------------------------")
+ val size = 10 * 1024 * 1024
+ val count = 10
+
+ val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
+ buffer.flip
+
+ (0 until count).map(i => {
+ val bufferMessage = Message.createBufferMessage(buffer.duplicate)
+ Await.result(manager.sendMessageReliably(manager.id, bufferMessage), Duration.Inf)
+ })
+ println("--------------------------")
+ println()
+ }
+
+ def testParallelSending(manager: ConnectionManager) {
+ println("--------------------------")
+ println("Parallel Sending")
+ println("--------------------------")
+ val size = 10 * 1024 * 1024
+ val count = 10
+
+ val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
+ buffer.flip
+
+ val startTime = System.currentTimeMillis
+ (0 until count).map(i => {
+ val bufferMessage = Message.createBufferMessage(buffer.duplicate)
+ manager.sendMessageReliably(manager.id, bufferMessage)
+ }).foreach(f => {
+ f.onFailure {
+ case e => println("Failed due to " + e)
+ }
+ Await.ready(f, 1 second)
+ })
+ val finishTime = System.currentTimeMillis
+
+ val mb = size * count / 1024.0 / 1024.0
+ val ms = finishTime - startTime
+ val tput = mb * 1000.0 / ms
+ println("--------------------------")
+ println("Started at " + startTime + ", finished at " + finishTime)
+ println("Sent " + count + " messages of size " + size + " in " + ms + " ms " +
+ "(" + tput + " MB/s)")
+ println("--------------------------")
+ println()
+ }
+
+ def testParallelDecreasingSending(manager: ConnectionManager) {
+ println("--------------------------")
+ println("Parallel Decreasing Sending")
+ println("--------------------------")
+ val size = 10 * 1024 * 1024
+ val count = 10
+ val buffers = Array.tabulate(count) { i =>
+ val bufferLen = size * (i + 1)
+ val bufferContent = Array.tabulate[Byte](bufferLen)(x => x.toByte)
+ ByteBuffer.allocate(bufferLen).put(bufferContent)
+ }
+ buffers.foreach(_.flip)
+ val mb = buffers.map(_.remaining).reduceLeft(_ + _) / 1024.0 / 1024.0
+
+ val startTime = System.currentTimeMillis
+ (0 until count).map(i => {
+ val bufferMessage = Message.createBufferMessage(buffers(count - 1 - i).duplicate)
+ manager.sendMessageReliably(manager.id, bufferMessage)
+ }).foreach(f => {
+ f.onFailure {
+ case e => println("Failed due to " + e)
+ }
+ Await.ready(f, 1 second)
+ })
+ val finishTime = System.currentTimeMillis
+
+ val ms = finishTime - startTime
+ val tput = mb * 1000.0 / ms
+ println("--------------------------")
+ /* println("Started at " + startTime + ", finished at " + finishTime) */
+ println("Sent " + mb + " MB in " + ms + " ms (" + tput + " MB/s)")
+ println("--------------------------")
+ println()
+ }
+
+ def testContinuousSending(manager: ConnectionManager) {
+ println("--------------------------")
+ println("Continuous Sending")
+ println("--------------------------")
+ val size = 10 * 1024 * 1024
+ val count = 10
+
+ val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
+ buffer.flip
+
+ val startTime = System.currentTimeMillis
+ while (true) {
+ (0 until count).map(i => {
+ val bufferMessage = Message.createBufferMessage(buffer.duplicate)
+ manager.sendMessageReliably(manager.id, bufferMessage)
+ }).foreach(f => {
+ f.onFailure {
+ case e => println("Failed due to " + e)
+ }
+ Await.ready(f, 1 second)
+ })
+ val finishTime = System.currentTimeMillis
+ Thread.sleep(1000)
+ val mb = size * count / 1024.0 / 1024.0
+ val ms = finishTime - startTime
+ val tput = mb * 1000.0 / ms
+ println("Sent " + mb + " MB in " + ms + " ms (" + tput + " MB/s)")
+ println("--------------------------")
+ println()
+ }
+ }
+}
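
A note on the ack-timeout pattern in sendMessageReliably above: a Promise is raced between the ack callback and a TimerTask, and the shared messageStatuses map (mutated under a lock) guarantees only one side completes it. The following is a minimal standalone sketch of the same idea, independent of ConnectionManager; the register hook is a stand-in for the real ack delivery path:

    import java.io.IOException
    import java.util.{Timer, TimerTask}
    import scala.concurrent.{Future, Promise}

    // Race an ack callback against a timeout; trySuccess/tryFailure make
    // double completion harmless, playing the role of messageStatuses.remove.
    def withAckTimeout[T](timeoutMs: Long, timer: Timer)
        (register: (T => Unit) => Unit): Future[T] = {
      val promise = Promise[T]()
      val timeoutTask = new TimerTask {
        override def run(): Unit =
          promise.tryFailure(new IOException(s"ack not received within $timeoutMs ms"))
      }
      // The caller invokes this callback when the ack arrives.
      register { ack =>
        timeoutTask.cancel()
        promise.trySuccess(ack)
      }
      timer.schedule(timeoutTask, timeoutMs)
      promise.future
    }
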
http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/network/nio/ConnectionManagerId.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManagerId.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManagerId.scala
new file mode 100644
index 0000000..cbb37ec
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManagerId.scala
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.nio
+
+import java.net.InetSocketAddress
+
+import org.apache.spark.util.Utils
+
+private[nio] case class ConnectionManagerId(host: String, port: Int) {
+ // DEBUG code
+ Utils.checkHost(host)
+ assert (port > 0)
+
+ def toSocketAddress() = new InetSocketAddress(host, port)
+}
+
+
+private[nio] object ConnectionManagerId {
+ def fromSocketAddress(socketAddress: InetSocketAddress): ConnectionManagerId = {
+ new ConnectionManagerId(socketAddress.getHostName, socketAddress.getPort)
+ }
+}
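
ConnectionManagerId is just a (host, port) value type that round-trips through InetSocketAddress. A usage sketch (the host and port are placeholders; note that fromSocketAddress calls getHostName, so the round trip assumes DNS resolves the name back consistently):

    val id = ConnectionManagerId("spark-worker-1", 50010)  // hypothetical endpoint
    val addr = id.toSocketAddress()
    // Holds as long as getHostName returns the same string we started with.
    assert(ConnectionManagerId.fromSocketAddress(addr) == id)
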
http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/network/nio/Message.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/network/nio/Message.scala b/core/src/main/scala/org/apache/spark/network/nio/Message.scala
new file mode 100644
index 0000000..0b874c2
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/nio/Message.scala
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.nio
+
+import java.net.InetSocketAddress
+import java.nio.ByteBuffer
+
+import scala.collection.mutable.ArrayBuffer
+
+
+private[nio] abstract class Message(val typ: Long, val id: Int) {
+ var senderAddress: InetSocketAddress = null
+ var started = false
+ var startTime = -1L
+ var finishTime = -1L
+ var isSecurityNeg = false
+ var hasError = false
+
+ def size: Int
+
+ def getChunkForSending(maxChunkSize: Int): Option[MessageChunk]
+
+ def getChunkForReceiving(chunkSize: Int): Option[MessageChunk]
+
+ def timeTaken(): String = (finishTime - startTime).toString + " ms"
+
+ override def toString = this.getClass.getSimpleName + "(id = " + id + ", size = " + size + ")"
+}
+
+
+private[nio] object Message {
+ val BUFFER_MESSAGE = 1111111111L
+
+ var lastId = 1
+
+ def getNewId() = synchronized {
+ lastId += 1
+ if (lastId == 0) {
+ lastId += 1
+ }
+ lastId
+ }
+
+ def createBufferMessage(dataBuffers: Seq[ByteBuffer], ackId: Int): BufferMessage = {
+ if (dataBuffers == null) {
+ return new BufferMessage(getNewId(), new ArrayBuffer[ByteBuffer], ackId)
+ }
+ if (dataBuffers.exists(_ == null)) {
+ throw new Exception("Attempting to create buffer message with null buffer")
+ }
+ new BufferMessage(getNewId(), new ArrayBuffer[ByteBuffer] ++= dataBuffers, ackId)
+ }
+
+ def createBufferMessage(dataBuffers: Seq[ByteBuffer]): BufferMessage =
+ createBufferMessage(dataBuffers, 0)
+
+ def createBufferMessage(dataBuffer: ByteBuffer, ackId: Int): BufferMessage = {
+ if (dataBuffer == null) {
+ createBufferMessage(Array(ByteBuffer.allocate(0)), ackId)
+ } else {
+ createBufferMessage(Array(dataBuffer), ackId)
+ }
+ }
+
+ def createBufferMessage(dataBuffer: ByteBuffer): BufferMessage =
+ createBufferMessage(dataBuffer, 0)
+
+ def createBufferMessage(ackId: Int): BufferMessage = {
+ createBufferMessage(new Array[ByteBuffer](0), ackId)
+ }
+
+ def create(header: MessageChunkHeader): Message = {
+ val newMessage: Message = header.typ match {
+ case BUFFER_MESSAGE => new BufferMessage(header.id,
+ ArrayBuffer(ByteBuffer.allocate(header.totalSize)), header.other)
+ }
+ newMessage.hasError = header.hasError
+ newMessage.senderAddress = header.address
+ newMessage
+ }
+}
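
For orientation, the factory overloads above differ mainly in the ackId they carry: 0 marks an ordinary request, while a non-zero ackId marks the message as a reply to that id (the error paths elsewhere in this patch build such a reply with createBufferMessage(msg.id) and set hasError on it). A usage sketch:

    import java.nio.ByteBuffer

    val payload = ByteBuffer.wrap("hello".getBytes("UTF-8"))
    val request = Message.createBufferMessage(payload)    // ackId = 0: a fresh request
    val reply   = Message.createBufferMessage(request.id) // empty-bodied ack for request
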
http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/network/nio/MessageChunk.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/network/nio/MessageChunk.scala b/core/src/main/scala/org/apache/spark/network/nio/MessageChunk.scala
new file mode 100644
index 0000000..278c5ac
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/nio/MessageChunk.scala
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.nio
+
+import java.nio.ByteBuffer
+
+import scala.collection.mutable.ArrayBuffer
+
+private[nio]
+class MessageChunk(val header: MessageChunkHeader, val buffer: ByteBuffer) {
+
+ val size = if (buffer == null) 0 else buffer.remaining
+
+ lazy val buffers = {
+ val ab = new ArrayBuffer[ByteBuffer]()
+ ab += header.buffer
+ if (buffer != null) {
+ ab += buffer
+ }
+ ab
+ }
+
+ override def toString = {
+ "" + this.getClass.getSimpleName + " (id = " + header.id + ", size = " + size + ")"
+ }
+}
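
The buffers member above deliberately places the header buffer before the payload, which lines up with NIO's gathering writes: both can be flushed with a single vectored write. A sketch of how a connection might use it (the channel handling here is illustrative; the real Connection classes live elsewhere in this patch):

    import java.nio.ByteBuffer
    import java.nio.channels.SocketChannel

    def writeChunk(channel: SocketChannel, chunk: MessageChunk): Unit = {
      // Gathering write: header and payload can leave in one syscall.
      val bufs: Array[ByteBuffer] = chunk.buffers.toArray
      while (bufs.exists(_.hasRemaining)) {
        channel.write(bufs)
      }
    }
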
http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/network/nio/MessageChunkHeader.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/network/nio/MessageChunkHeader.scala b/core/src/main/scala/org/apache/spark/network/nio/MessageChunkHeader.scala
new file mode 100644
index 0000000..6e20f29
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/nio/MessageChunkHeader.scala
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.nio
+
+import java.net.{InetAddress, InetSocketAddress}
+import java.nio.ByteBuffer
+
+private[nio] class MessageChunkHeader(
+ val typ: Long,
+ val id: Int,
+ val totalSize: Int,
+ val chunkSize: Int,
+ val other: Int,
+ val hasError: Boolean,
+ val securityNeg: Int,
+ val address: InetSocketAddress) {
+ lazy val buffer = {
+ // No need to change this; at 'use' time we do a reverse lookup of the hostname.
+ // Refer to network.Connection.
+ val ip = address.getAddress.getAddress()
+ val port = address.getPort()
+ ByteBuffer.
+ allocate(MessageChunkHeader.HEADER_SIZE).
+ putLong(typ).
+ putInt(id).
+ putInt(totalSize).
+ putInt(chunkSize).
+ putInt(other).
+ put(if (hasError) 1.toByte else 0.toByte).
+ putInt(securityNeg).
+ putInt(ip.size).
+ put(ip).
+ putInt(port).
+ position(MessageChunkHeader.HEADER_SIZE).
+ flip.asInstanceOf[ByteBuffer]
+ }
+
+ override def toString = "" + this.getClass.getSimpleName + ":" + id + " of type " + typ +
+ " and sizes " + totalSize + " / " + chunkSize + " bytes, securityNeg: " + securityNeg
+
+}
+
+
+private[nio] object MessageChunkHeader {
+ val HEADER_SIZE = 45
+
+ def create(buffer: ByteBuffer): MessageChunkHeader = {
+ if (buffer.remaining != HEADER_SIZE) {
+ throw new IllegalArgumentException("Cannot convert buffer data to Message")
+ }
+ val typ = buffer.getLong()
+ val id = buffer.getInt()
+ val totalSize = buffer.getInt()
+ val chunkSize = buffer.getInt()
+ val other = buffer.getInt()
+ val hasError = buffer.get() != 0
+ val securityNeg = buffer.getInt()
+ val ipSize = buffer.getInt()
+ val ipBytes = new Array[Byte](ipSize)
+ buffer.get(ipBytes)
+ val ip = InetAddress.getByAddress(ipBytes)
+ val port = buffer.getInt()
+ new MessageChunkHeader(typ, id, totalSize, chunkSize, other, hasError, securityNeg,
+ new InetSocketAddress(ip, port))
+ }
+}
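
For reference, the fixed wire layout that buffer writes and create() reads back, by this count: typ (8) + id (4) + totalSize (4) + chunkSize (4) + other (4) + hasError (1) + securityNeg (4) + ipSize (4) + ip bytes (4 for an IPv4 address) + port (4) = 41 bytes, padded to the fixed HEADER_SIZE of 45 by the position()/flip calls; a 16-byte IPv6 address would not fit this layout. A quick layout check:

    // Field widths in bytes, assuming an IPv4 address.
    val fieldBytes = Seq(
      8,  // typ: Long
      4,  // id: Int
      4,  // totalSize: Int
      4,  // chunkSize: Int
      4,  // other: Int
      1,  // hasError: Byte
      4,  // securityNeg: Int
      4,  // ipSize: Int
      4,  // ip bytes (IPv4)
      4)  // port: Int
    assert(fieldBytes.sum == 41)  // < MessageChunkHeader.HEADER_SIZE (45)
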
http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala
new file mode 100644
index 0000000..59958ee
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala
@@ -0,0 +1,205 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.nio
+
+import java.nio.ByteBuffer
+
+import scala.concurrent.Future
+
+import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException}
+import org.apache.spark.network._
+import org.apache.spark.storage.{BlockId, StorageLevel}
+import org.apache.spark.util.Utils
+
+
+/**
+ * A [[BlockTransferService]] implementation based on [[ConnectionManager]], a custom
+ * networking layer built directly on Java NIO.
+ */
+final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityManager)
+ extends BlockTransferService with Logging {
+
+ private var cm: ConnectionManager = _
+
+ private var blockDataManager: BlockDataManager = _
+
+ /**
+ * Port number the service is listening on, available only after [[init]] is invoked.
+ */
+ override def port: Int = {
+ checkInit()
+ cm.id.port
+ }
+
+ /**
+ * Host name the service is listening on, available only after [[init]] is invoked.
+ */
+ override def hostName: String = {
+ checkInit()
+ cm.id.host
+ }
+
+ /**
+ * Initialize the transfer service by giving it the BlockDataManager that can be used to fetch
+ * local blocks or put local blocks.
+ */
+ override def init(blockDataManager: BlockDataManager): Unit = {
+ this.blockDataManager = blockDataManager
+ cm = new ConnectionManager(
+ conf.getInt("spark.blockManager.port", 0),
+ conf,
+ securityManager,
+ "Connection manager for block manager")
+ cm.onReceiveMessage(onBlockMessageReceive)
+ }
+
+ /**
+ * Tear down the transfer service.
+ */
+ override def stop(): Unit = {
+ if (cm != null) {
+ cm.stop()
+ }
+ }
+
+ override def fetchBlocks(
+ hostName: String,
+ port: Int,
+ blockIds: Seq[String],
+ listener: BlockFetchingListener): Unit = {
+ checkInit()
+
+ val cmId = new ConnectionManagerId(hostName, port)
+ val blockMessageArray = new BlockMessageArray(blockIds.map { blockId =>
+ BlockMessage.fromGetBlock(GetBlock(BlockId(blockId)))
+ })
+
+ val future = cm.sendMessageReliably(cmId, blockMessageArray.toBufferMessage)
+
+ // Register the listener on success/failure future callback.
+ future.onSuccess { case message =>
+ val bufferMessage = message.asInstanceOf[BufferMessage]
+ val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage)
+
+ for (blockMessage <- blockMessageArray) {
+ if (blockMessage.getType != BlockMessage.TYPE_GOT_BLOCK) {
+ listener.onBlockFetchFailure(
+ new SparkException(s"Unexpected message ${blockMessage.getType} received from $cmId"))
+ } else {
+ val blockId = blockMessage.getId
+ listener.onBlockFetchSuccess(
+ blockId.toString, new NioByteBufferManagedBuffer(blockMessage.getData))
+ }
+ }
+ }(cm.futureExecContext)
+
+ future.onFailure { case exception =>
+ listener.onBlockFetchFailure(exception)
+ }(cm.futureExecContext)
+ }
+
+ /**
+ * Upload a single block to a remote node, available only after [[init]] is invoked.
+ *
+ * The call is asynchronous: the returned future completes once the upload finishes,
+ * and is failed with an exception if the upload fails.
+ */
+ override def uploadBlock(
+ hostname: String,
+ port: Int,
+ blockId: String,
+ blockData: ManagedBuffer,
+ level: StorageLevel)
+ : Future[Unit] = {
+ checkInit()
+ val msg = PutBlock(BlockId(blockId), blockData.nioByteBuffer(), level)
+ val blockMessageArray = new BlockMessageArray(BlockMessage.fromPutBlock(msg))
+ // Use the hostname parameter (the remote node), not this service's own hostName.
+ val remoteCmId = new ConnectionManagerId(hostname, port)
+ val reply = cm.sendMessageReliably(remoteCmId, blockMessageArray.toBufferMessage)
+ reply.map(x => ())(cm.futureExecContext)
+ }
+
+ private def checkInit(): Unit = if (cm == null) {
+ throw new IllegalStateException(getClass.getName + " has not been initialized")
+ }
+
+ private def onBlockMessageReceive(msg: Message, id: ConnectionManagerId): Option[Message] = {
+ logDebug("Handling message " + msg)
+ msg match {
+ case bufferMessage: BufferMessage =>
+ try {
+ logDebug("Handling as a buffer message " + bufferMessage)
+ val blockMessages = BlockMessageArray.fromBufferMessage(bufferMessage)
+ logDebug("Parsed as a block message array")
+ val responseMessages = blockMessages.flatMap(processBlockMessage)
+ Some(new BlockMessageArray(responseMessages).toBufferMessage)
+ } catch {
+ case e: Exception => {
+ logError("Exception handling buffer message", e)
+ val errorMessage = Message.createBufferMessage(msg.id)
+ errorMessage.hasError = true
+ Some(errorMessage)
+ }
+ }
+
+ case otherMessage: Any =>
+ logError("Unknown type message received: " + otherMessage)
+ val errorMessage = Message.createBufferMessage(msg.id)
+ errorMessage.hasError = true
+ Some(errorMessage)
+ }
+ }
+
+ private def processBlockMessage(blockMessage: BlockMessage): Option[BlockMessage] = {
+ blockMessage.getType match {
+ case BlockMessage.TYPE_PUT_BLOCK =>
+ val msg = PutBlock(blockMessage.getId, blockMessage.getData, blockMessage.getLevel)
+ logDebug("Received [" + msg + "]")
+ putBlock(msg.id.toString, msg.data, msg.level)
+ None
+
+ case BlockMessage.TYPE_GET_BLOCK =>
+ val msg = new GetBlock(blockMessage.getId)
+ logDebug("Received [" + msg + "]")
+ val buffer = getBlock(msg.id.toString)
+ if (buffer == null) {
+ return None
+ }
+ Some(BlockMessage.fromGotBlock(GotBlock(msg.id, buffer)))
+
+ case _ => None
+ }
+ }
+
+ private def putBlock(blockId: String, bytes: ByteBuffer, level: StorageLevel) {
+ val startTimeMs = System.currentTimeMillis()
+ logDebug("PutBlock " + blockId + " started from " + startTimeMs + " with data: " + bytes)
+ blockDataManager.putBlockData(blockId, new NioByteBufferManagedBuffer(bytes), level)
+ logDebug("PutBlock " + blockId + " used " + Utils.getUsedTimeMs(startTimeMs)
+ + " with data size: " + bytes.limit)
+ }
+
+ private def getBlock(blockId: String): ByteBuffer = {
+ val startTimeMs = System.currentTimeMillis()
+ logDebug("GetBlock " + blockId + " started from " + startTimeMs)
+ val bufferOpt = blockDataManager.getBlockData(blockId)
+ logDebug("GetBlock " + blockId + " used " + Utils.getUsedTimeMs(startTimeMs)
+ + " and got buffer " + bufferOpt)
+ // Return null for a missing block so processBlockMessage can answer with None.
+ bufferOpt.map(_.nioByteBuffer()).orNull
+ }
+}
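
fetchBlocks above is purely callback-driven; results are delivered per block on the ConnectionManager's execution context. A usage sketch (the host, port, and block id are placeholders, and data.size assumes the ManagedBuffer size accessor introduced by this patch):

    import org.apache.spark.network.{BlockFetchingListener, ManagedBuffer}

    val listener = new BlockFetchingListener {
      override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit =
        println(s"fetched $blockId (${data.size} bytes)")
      override def onBlockFetchFailure(exception: Throwable): Unit =
        println(s"fetch failed: $exception")
    }
    transferService.fetchBlocks("remote-host", 50010, Seq("shuffle_0_1_2"), listener)
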
http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/network/nio/SecurityMessage.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/network/nio/SecurityMessage.scala b/core/src/main/scala/org/apache/spark/network/nio/SecurityMessage.scala
new file mode 100644
index 0000000..747a208
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/nio/SecurityMessage.scala
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.nio
+
+import java.nio.ByteBuffer
+
+import scala.collection.mutable.{ArrayBuffer, StringBuilder}
+
+import org.apache.spark._
+
+/**
+ * SecurityMessage is a class that contains the connectionId and SASL token
+ * used in SASL negotiation. SecurityMessage has routines for converting
+ * it to and from a BufferMessage so that it can be sent by the ConnectionManager
+ * and easily consumed by users when received.
+ * The API was modeled after BlockMessage.
+ *
+ * The connectionId is the connectionId of the client side. Since
+ * message passing is asynchronous and it's possible for the server side (receiving)
+ * to get multiple different types of messages on the same connection, the connectionId
+ * is used to know which connection the security message is intended for.
+ *
+ * For instance, let's say we are node_0. We need to send data to node_1. The node_0 side
+ * is acting as a client and connecting to node_1. SASL negotiation has to occur
+ * between node_0 and node_1 before node_1 trusts node_0, so node_0 sends a security message.
+ * node_1 receives the message from node_0, but before it can process it and send a response,
+ * some thread on node_1 decides it needs to send data to node_0, so it connects to node_0
+ * and sends a security message of its own to authenticate as a client. Now node_0 gets
+ * the message and it needs to decide if this message is in response to it being a client
+ * (from the first send) or if it's just node_1 trying to connect to it to send data. This
+ * is where the connectionId field is used: node_0 can look up the connectionId to see if
+ * it is in response to it being a client or in response to someone sending other data.
+ *
+ * The format of a SecurityMessage as it's sent is:
+ * - Length of the ConnectionId
+ * - ConnectionId
+ * - Length of the token
+ * - Token
+ */
+private[nio] class SecurityMessage extends Logging {
+
+ private var connectionId: String = null
+ private var token: Array[Byte] = null
+
+ def set(byteArr: Array[Byte], newConnectionId: String) {
+ if (byteArr == null) {
+ token = new Array[Byte](0)
+ } else {
+ token = byteArr
+ }
+ connectionId = newConnectionId
+ }
+
+ /**
+ * Read the given buffer and set the members of this class.
+ */
+ def set(buffer: ByteBuffer) {
+ val idLength = buffer.getInt()
+ val idBuilder = new StringBuilder(idLength)
+ for (i <- 1 to idLength) {
+ idBuilder += buffer.getChar()
+ }
+ connectionId = idBuilder.toString()
+
+ val tokenLength = buffer.getInt()
+ token = new Array[Byte](tokenLength)
+ if (tokenLength > 0) {
+ buffer.get(token, 0, tokenLength)
+ }
+ }
+
+ def set(bufferMsg: BufferMessage) {
+ val buffer = bufferMsg.buffers(0)
+ buffer.clear()
+ set(buffer)
+ }
+
+ def getConnectionId: String = connectionId
+
+ def getToken: Array[Byte] = token
+
+ /**
+ * Create a BufferMessage that can be sent by the ConnectionManager containing
+ * the security information from this class.
+ * @return BufferMessage
+ */
+ def toBufferMessage: BufferMessage = {
+ val buffers = new ArrayBuffer[ByteBuffer]()
+
+ // 4 bytes for the length of the connectionId
+ // connectionId is made of chars, so multiply the length by 2 to get the number of bytes
+ // 4 bytes for the length of the token
+ // token is a byte array, so just take its length
+ val buffer = ByteBuffer.allocate(4 + connectionId.length() * 2 + 4 + token.length)
+ buffer.putInt(connectionId.length())
+ connectionId.foreach((x: Char) => buffer.putChar(x))
+ buffer.putInt(token.length)
+
+ if (token.length > 0) {
+ buffer.put(token)
+ }
+ buffer.flip()
+ buffers += buffer
+
+ val message = Message.createBufferMessage(buffers)
+ logDebug("message total size is : " + message.size)
+ message.isSecurityNeg = true
+ message
+ }
+
+ override def toString: String = {
+ "SecurityMessage [connId= " + connectionId + ", Token = " + token + "]"
+ }
+}
+
+private[nio] object SecurityMessage {
+
+ /**
+ * Convert the given BufferMessage to a SecurityMessage by parsing the contents
+ * of the BufferMessage and populating the SecurityMessage fields.
+ * @param bufferMessage is a BufferMessage that was received
+ * @return new SecurityMessage
+ */
+ def fromBufferMessage(bufferMessage: BufferMessage): SecurityMessage = {
+ val newSecurityMessage = new SecurityMessage()
+ newSecurityMessage.set(bufferMessage)
+ newSecurityMessage
+ }
+
+ /**
+ * Create a SecurityMessage to send from a given saslResponse.
+ * @param response is the response to a challenge from the SaslClient or SaslServer
+ * @param connectionId the client connectionId we are negotiating authentication for
+ * @return a new SecurityMessage
+ */
+ def fromResponse(response: Array[Byte], connectionId: String): SecurityMessage = {
+ val newSecurityMessage = new SecurityMessage()
+ newSecurityMessage.set(response, connectionId)
+ newSecurityMessage
+ }
+}
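
A round-trip sketch for the wire format described in the class comment (the token bytes and connection id are placeholders):

    val out  = SecurityMessage.fromResponse(Array[Byte](1, 2, 3), "conn-42")
    val wire = out.toBufferMessage            // [idLen][id chars][tokenLen][token]
    val in   = SecurityMessage.fromBufferMessage(wire)
    assert(in.getConnectionId == "conn-42")
    assert(in.getToken.toSeq == Seq[Byte](1, 2, 3))
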
http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
index 87ef9bb..d6386f8 100644
--- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
@@ -27,9 +27,9 @@ import com.twitter.chill.{AllScalaRegistrar, EmptyScalaKryoInstantiator}
import org.apache.spark._
import org.apache.spark.broadcast.HttpBroadcast
+import org.apache.spark.network.nio.{PutBlock, GotBlock, GetBlock}
import org.apache.spark.scheduler.MapStatus
import org.apache.spark.storage._
-import org.apache.spark.storage.{GetBlock, GotBlock, PutBlock}
import org.apache.spark.util.BoundedPriorityQueue
import org.apache.spark.util.collection.CompactBuffer
http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala
index 96faccc..439981d 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala
@@ -26,6 +26,7 @@ import scala.collection.JavaConversions._
import org.apache.spark.{SparkEnv, SparkConf, Logging}
import org.apache.spark.executor.ShuffleWriteMetrics
+import org.apache.spark.network.{FileSegmentManagedBuffer, ManagedBuffer}
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.FileShuffleBlockManager.ShuffleFileGroup
import org.apache.spark.storage._
@@ -166,34 +167,30 @@ class FileShuffleBlockManager(conf: SparkConf)
}
}
- /**
- * Returns the physical file segment in which the given BlockId is located.
- */
- private def getBlockLocation(id: ShuffleBlockId): FileSegment = {
+ override def getBytes(blockId: ShuffleBlockId): Option[ByteBuffer] = {
+ val segment = getBlockData(blockId)
+ Some(segment.nioByteBuffer())
+ }
+
+ override def getBlockData(blockId: ShuffleBlockId): ManagedBuffer = {
if (consolidateShuffleFiles) {
// Search all file groups associated with this shuffle.
- val shuffleState = shuffleStates(id.shuffleId)
+ val shuffleState = shuffleStates(blockId.shuffleId)
val iter = shuffleState.allFileGroups.iterator
while (iter.hasNext) {
- val segment = iter.next.getFileSegmentFor(id.mapId, id.reduceId)
- if (segment.isDefined) { return segment.get }
+ val segmentOpt = iter.next.getFileSegmentFor(blockId.mapId, blockId.reduceId)
+ if (segmentOpt.isDefined) {
+ val segment = segmentOpt.get
+ return new FileSegmentManagedBuffer(segment.file, segment.offset, segment.length)
+ }
}
- throw new IllegalStateException("Failed to find shuffle block: " + id)
+ throw new IllegalStateException("Failed to find shuffle block: " + blockId)
} else {
- val file = blockManager.diskBlockManager.getFile(id)
- new FileSegment(file, 0, file.length())
+ val file = blockManager.diskBlockManager.getFile(blockId)
+ new FileSegmentManagedBuffer(file, 0, file.length)
}
}
- override def getBytes(blockId: ShuffleBlockId): Option[ByteBuffer] = {
- val segment = getBlockLocation(blockId)
- blockManager.diskStore.getBytes(segment)
- }
-
- override def getBlockData(blockId: ShuffleBlockId): Either[FileSegment, ByteBuffer] = {
- Left(getBlockLocation(blockId.asInstanceOf[ShuffleBlockId]))
- }
-
/** Remove all the blocks / files and metadata related to a particular shuffle. */
def removeShuffle(shuffleId: ShuffleId): Boolean = {
// Do not change the ordering of this, if shuffleStates should be removed only
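
getBlockData can now return without touching the disk because FileSegmentManagedBuffer only records (file, offset, length) and materializes bytes on demand. A rough sketch of that idea (illustrative only, not the actual class):

    import java.io.{File, RandomAccessFile}
    import java.nio.ByteBuffer
    import java.nio.channels.FileChannel.MapMode

    class LazyFileSegment(file: File, offset: Long, length: Long) {
      // Map the segment only when bytes are actually requested.
      def nioByteBuffer(): ByteBuffer = {
        val raf = new RandomAccessFile(file, "r")
        try raf.getChannel.map(MapMode.READ_ONLY, offset, length)
        finally raf.close()  // the mapping stays valid after close
      }
    }
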
http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala
index 8bb9efc..4ab3433 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala
@@ -21,6 +21,7 @@ import java.io._
import java.nio.ByteBuffer
import org.apache.spark.SparkEnv
+import org.apache.spark.network.{ManagedBuffer, FileSegmentManagedBuffer}
import org.apache.spark.storage._
/**
@@ -89,10 +90,11 @@ class IndexShuffleBlockManager extends ShuffleBlockManager {
}
}
- /**
- * Get the location of a block in a map output file. Uses the index file we create for it.
- * */
- private def getBlockLocation(blockId: ShuffleBlockId): FileSegment = {
+ override def getBytes(blockId: ShuffleBlockId): Option[ByteBuffer] = {
+ Some(getBlockData(blockId).nioByteBuffer())
+ }
+
+ override def getBlockData(blockId: ShuffleBlockId): ManagedBuffer = {
// The block is actually going to be a range of a single map output file for this map, so
// find out the consolidated file, then the offset within that from our index
val indexFile = getIndexFile(blockId.shuffleId, blockId.mapId)
@@ -102,20 +104,14 @@ class IndexShuffleBlockManager extends ShuffleBlockManager {
in.skip(blockId.reduceId * 8)
val offset = in.readLong()
val nextOffset = in.readLong()
- new FileSegment(getDataFile(blockId.shuffleId, blockId.mapId), offset, nextOffset - offset)
+ new FileSegmentManagedBuffer(
+ getDataFile(blockId.shuffleId, blockId.mapId),
+ offset,
+ nextOffset - offset)
} finally {
in.close()
}
}
- override def getBytes(blockId: ShuffleBlockId): Option[ByteBuffer] = {
- val segment = getBlockLocation(blockId)
- blockManager.diskStore.getBytes(segment)
- }
-
- override def getBlockData(blockId: ShuffleBlockId): Either[FileSegment, ByteBuffer] = {
- Left(getBlockLocation(blockId.asInstanceOf[ShuffleBlockId]))
- }
-
override def stop() = {}
}
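
The index file read above relies on a simple layout: one 8-byte long per partition boundary, so the block for reduceId occupies the byte range [index(reduceId), index(reduceId + 1)) of the data file. A worked sketch of just the lookup, assuming that layout:

    import java.io.{DataInputStream, File, FileInputStream}

    def segmentFor(indexFile: File, reduceId: Int): (Long, Long) = {
      val in = new DataInputStream(new FileInputStream(indexFile))
      try {
        in.skip(reduceId * 8L)        // each boundary is one long
        val offset = in.readLong()
        val nextOffset = in.readLong()
        (offset, nextOffset - offset) // (offset, length) in the data file
      } finally {
        in.close()
      }
    }
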
http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala
index 4240580..63863cc 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala
@@ -19,7 +19,8 @@ package org.apache.spark.shuffle
import java.nio.ByteBuffer
-import org.apache.spark.storage.{FileSegment, ShuffleBlockId}
+import org.apache.spark.network.ManagedBuffer
+import org.apache.spark.storage.ShuffleBlockId
private[spark]
trait ShuffleBlockManager {
@@ -31,8 +32,7 @@ trait ShuffleBlockManager {
*/
def getBytes(blockId: ShuffleBlockId): Option[ByteBuffer]
- def getBlockData(blockId: ShuffleBlockId): Either[FileSegment, ByteBuffer]
+ def getBlockData(blockId: ShuffleBlockId): ManagedBuffer
def stop(): Unit
}
-
http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
index 12b4756..6cf9305 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
@@ -21,10 +21,9 @@ import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
import org.apache.spark._
-import org.apache.spark.executor.ShuffleReadMetrics
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.FetchFailedException
-import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId}
+import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockFetcherIterator, ShuffleBlockId}
import org.apache.spark.util.CompletionIterator
private[hash] object BlockStoreShuffleFetcher extends Logging {
@@ -32,8 +31,7 @@ private[hash] object BlockStoreShuffleFetcher extends Logging {
shuffleId: Int,
reduceId: Int,
context: TaskContext,
- serializer: Serializer,
- shuffleMetrics: ShuffleReadMetrics)
+ serializer: Serializer)
: Iterator[T] =
{
logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId))
@@ -74,7 +72,13 @@ private[hash] object BlockStoreShuffleFetcher extends Logging {
}
}
- val blockFetcherItr = blockManager.getMultiple(blocksByAddress, serializer, shuffleMetrics)
+ val blockFetcherItr = new ShuffleBlockFetcherIterator(
+ context,
+ SparkEnv.get.blockTransferService,
+ blockManager,
+ blocksByAddress,
+ serializer,
+ SparkEnv.get.conf.getLong("spark.reducer.maxMbInFlight", 48) * 1024 * 1024)
val itr = blockFetcherItr.flatMap(unpackBlock)
val completionIter = CompletionIterator[T, Iterator[T]](itr, {
http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
index 7bed97a..88a5f1e 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
@@ -36,10 +36,8 @@ private[spark] class HashShuffleReader[K, C](
/** Read the combined key-values for this reduce task */
override def read(): Iterator[Product2[K, C]] = {
- val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency()
val ser = Serializer.getSerializer(dep.serializer)
- val iter = BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context, ser,
- readMetrics)
+ val iter = BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context, ser)
val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
if (dep.mapSideCombine) {
http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala
deleted file mode 100644
index e35b7fe..0000000
--- a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala
+++ /dev/null
@@ -1,254 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.storage
-
-import java.util.concurrent.LinkedBlockingQueue
-import org.apache.spark.network.netty.client.{BlockClientListener, LazyInitIterator, ReferenceCountedBuffer}
-
-import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.HashSet
-import scala.collection.mutable.Queue
-import scala.util.{Failure, Success}
-
-import org.apache.spark.{Logging, SparkException}
-import org.apache.spark.executor.ShuffleReadMetrics
-import org.apache.spark.network.BufferMessage
-import org.apache.spark.network.ConnectionManagerId
-import org.apache.spark.serializer.Serializer
-import org.apache.spark.util.Utils
-
-/**
- * A block fetcher iterator interface for fetching shuffle blocks.
- */
-private[storage]
-trait BlockFetcherIterator extends Iterator[(BlockId, Option[Iterator[Any]])] with Logging {
- def initialize()
-}
-
-
-private[storage]
-object BlockFetcherIterator {
-
- /**
- * A request to fetch blocks from a remote BlockManager.
- * @param address remote BlockManager to fetch from.
- * @param blocks Sequence of tuple, where the first element is the block id,
- * and the second element is the estimated size, used to calculate bytesInFlight.
- */
- class FetchRequest(val address: BlockManagerId, val blocks: Seq[(BlockId, Long)]) {
- val size = blocks.map(_._2).sum
- }
-
- /**
- * Result of a fetch from a remote block. A failure is represented as size == -1.
- * @param blockId block id
- * @param size estimated size of the block, used to calculate bytesInFlight.
- * Note that this is NOT the exact bytes.
- * @param deserialize closure to return the result in the form of an Iterator.
- */
- class FetchResult(val blockId: BlockId, val size: Long, val deserialize: () => Iterator[Any]) {
- def failed: Boolean = size == -1
- }
-
- // TODO: Refactor this whole thing to make code more reusable.
- class BasicBlockFetcherIterator(
- private val blockManager: BlockManager,
- val blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])],
- serializer: Serializer,
- readMetrics: ShuffleReadMetrics)
- extends BlockFetcherIterator {
-
- import blockManager._
-
- if (blocksByAddress == null) {
- throw new IllegalArgumentException("BlocksByAddress is null")
- }
-
- // Total number blocks fetched (local + remote). Also number of FetchResults expected
- protected var _numBlocksToFetch = 0
-
- protected var startTime = System.currentTimeMillis
-
- // BlockIds for local blocks that need to be fetched. Excludes zero-sized blocks
- protected val localBlocksToFetch = new ArrayBuffer[BlockId]()
-
- // BlockIds for remote blocks that need to be fetched. Excludes zero-sized blocks
- protected val remoteBlocksToFetch = new HashSet[BlockId]()
-
- // A queue to hold our results.
- protected val results = new LinkedBlockingQueue[FetchResult]
-
- // Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that
- // the number of bytes in flight is limited to maxBytesInFlight
- protected val fetchRequests = new Queue[FetchRequest]
-
- // Current bytes in flight from our requests
- protected var bytesInFlight = 0L
-
- protected def sendRequest(req: FetchRequest) {
- logDebug("Sending request for %d blocks (%s) from %s".format(
- req.blocks.size, Utils.bytesToString(req.size), req.address.hostPort))
- val cmId = new ConnectionManagerId(req.address.host, req.address.port)
- val blockMessageArray = new BlockMessageArray(req.blocks.map {
- case (blockId, size) => BlockMessage.fromGetBlock(GetBlock(blockId))
- })
- bytesInFlight += req.size
- val sizeMap = req.blocks.toMap // so we can look up the size of each blockID
- val future = connectionManager.sendMessageReliably(cmId, blockMessageArray.toBufferMessage)
- future.onComplete {
- case Success(message) => {
- val bufferMessage = message.asInstanceOf[BufferMessage]
- val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage)
- for (blockMessage <- blockMessageArray) {
- if (blockMessage.getType != BlockMessage.TYPE_GOT_BLOCK) {
- throw new SparkException(
- "Unexpected message " + blockMessage.getType + " received from " + cmId)
- }
- val blockId = blockMessage.getId
- val networkSize = blockMessage.getData.limit()
- results.put(new FetchResult(blockId, sizeMap(blockId),
- () => dataDeserialize(blockId, blockMessage.getData, serializer)))
- // TODO: NettyBlockFetcherIterator has some race conditions where multiple threads can
- // be incrementing bytes read at the same time (SPARK-2625).
- readMetrics.remoteBytesRead += networkSize
- readMetrics.remoteBlocksFetched += 1
- logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
- }
- }
- case Failure(exception) => {
- logError("Could not get block(s) from " + cmId, exception)
- for ((blockId, size) <- req.blocks) {
- results.put(new FetchResult(blockId, -1, null))
- }
- }
- }
- }
-
- protected def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = {
- // Make remote requests at most maxBytesInFlight / 5 in length; the reason to keep them
- // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5
- // nodes, rather than blocking on reading output from one node.
- val targetRequestSize = math.max(maxBytesInFlight / 5, 1L)
- logInfo("maxBytesInFlight: " + maxBytesInFlight + ", targetRequestSize: " + targetRequestSize)
-
- // Split local and remote blocks. Remote blocks are further split into FetchRequests of size
- // at most maxBytesInFlight in order to limit the amount of data in flight.
- val remoteRequests = new ArrayBuffer[FetchRequest]
- var totalBlocks = 0
- for ((address, blockInfos) <- blocksByAddress) {
- totalBlocks += blockInfos.size
- if (address == blockManagerId) {
- // Filter out zero-sized blocks
- localBlocksToFetch ++= blockInfos.filter(_._2 != 0).map(_._1)
- _numBlocksToFetch += localBlocksToFetch.size
- } else {
- val iterator = blockInfos.iterator
- var curRequestSize = 0L
- var curBlocks = new ArrayBuffer[(BlockId, Long)]
- while (iterator.hasNext) {
- val (blockId, size) = iterator.next()
- // Skip empty blocks
- if (size > 0) {
- curBlocks += ((blockId, size))
- remoteBlocksToFetch += blockId
- _numBlocksToFetch += 1
- curRequestSize += size
- } else if (size < 0) {
- throw new BlockException(blockId, "Negative block size " + size)
- }
- if (curRequestSize >= targetRequestSize) {
- // Add this FetchRequest
- remoteRequests += new FetchRequest(address, curBlocks)
- curBlocks = new ArrayBuffer[(BlockId, Long)]
- logDebug(s"Creating fetch request of $curRequestSize at $address")
- curRequestSize = 0
- }
- }
- // Add in the final request
- if (!curBlocks.isEmpty) {
- remoteRequests += new FetchRequest(address, curBlocks)
- }
- }
- }
- logInfo("Getting " + _numBlocksToFetch + " non-empty blocks out of " +
- totalBlocks + " blocks")
- remoteRequests
- }
-
- protected def getLocalBlocks() {
- // Get the local blocks while remote blocks are being fetched. Note that it's okay to do
- // these all at once because they will just memory-map some files, so they won't consume
- // any memory that might exceed our maxBytesInFlight
- for (id <- localBlocksToFetch) {
- try {
- readMetrics.localBlocksFetched += 1
- results.put(new FetchResult(id, 0, () => getLocalShuffleFromDisk(id, serializer).get))
- logDebug("Got local block " + id)
- } catch {
- case e: Exception => {
- logError(s"Error occurred while fetching local blocks", e)
- results.put(new FetchResult(id, -1, null))
- return
- }
- }
- }
- }
-
- override def initialize() {
- // Split local and remote blocks.
- val remoteRequests = splitLocalRemoteBlocks()
- // Add the remote requests into our queue in a random order
- fetchRequests ++= Utils.randomize(remoteRequests)
-
- // Send out initial requests for blocks, up to our maxBytesInFlight
- while (!fetchRequests.isEmpty &&
- (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
- sendRequest(fetchRequests.dequeue())
- }
-
- val numFetches = remoteRequests.size - fetchRequests.size
- logInfo("Started " + numFetches + " remote fetches in" + Utils.getUsedTimeMs(startTime))
-
- // Get Local Blocks
- startTime = System.currentTimeMillis
- getLocalBlocks()
- logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms")
- }
-
- // Implementing the Iterator methods with an iterator that reads fetched blocks off the queue
- // as they arrive.
- @volatile protected var resultsGotten = 0
-
- override def hasNext: Boolean = resultsGotten < _numBlocksToFetch
-
- override def next(): (BlockId, Option[Iterator[Any]]) = {
- resultsGotten += 1
- val startFetchWait = System.currentTimeMillis()
- val result = results.take()
- val stopFetchWait = System.currentTimeMillis()
- readMetrics.fetchWaitTime += (stopFetchWait - startFetchWait)
- if (! result.failed) bytesInFlight -= result.size
- while (!fetchRequests.isEmpty &&
- (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
- sendRequest(fetchRequests.dequeue())
- }
- (result.blockId, if (result.failed) None else Some(result.deserialize()))
- }
- }
- // End of BasicBlockFetcherIterator
-}
[2/5] [SPARK-3019] Pluggable block transfer interface
(BlockTransferService)
Posted by rx...@apache.org.
http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index a714142..d1bee3d 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -20,6 +20,8 @@ package org.apache.spark.storage
import java.io.{File, InputStream, OutputStream, BufferedOutputStream, ByteArrayOutputStream}
import java.nio.{ByteBuffer, MappedByteBuffer}
+import scala.concurrent.ExecutionContext.Implicits.global
+
import scala.collection.mutable.{ArrayBuffer, HashMap}
import scala.concurrent.{Await, Future}
import scala.concurrent.duration._
@@ -58,18 +60,14 @@ private[spark] class BlockManager(
defaultSerializer: Serializer,
maxMemory: Long,
val conf: SparkConf,
- securityManager: SecurityManager,
mapOutputTracker: MapOutputTracker,
- shuffleManager: ShuffleManager)
- extends BlockDataProvider with Logging {
+ shuffleManager: ShuffleManager,
+ blockTransferService: BlockTransferService)
+ extends BlockDataManager with Logging {
- private val port = conf.getInt("spark.blockManager.port", 0)
+ blockTransferService.init(this)
val diskBlockManager = new DiskBlockManager(this, conf)
- val connectionManager =
- new ConnectionManager(port, conf, securityManager, "Connection manager for block manager")
-
- implicit val futureExecContext = connectionManager.futureExecContext
private val blockInfo = new TimeStampedHashMap[BlockId, BlockInfo]
@@ -89,11 +87,7 @@ private[spark] class BlockManager(
}
val blockManagerId = BlockManagerId(
- executorId, connectionManager.id.host, connectionManager.id.port)
-
- // Max megabytes of data to keep in flight per reducer (to avoid over-allocating memory
- // for receiving shuffle outputs)
- val maxBytesInFlight = conf.getLong("spark.reducer.maxMbInFlight", 48) * 1024 * 1024
+ executorId, blockTransferService.hostName, blockTransferService.port)
// Whether to compress broadcast variables that are stored
private val compressBroadcast = conf.getBoolean("spark.broadcast.compress", true)
@@ -136,11 +130,11 @@ private[spark] class BlockManager(
master: BlockManagerMaster,
serializer: Serializer,
conf: SparkConf,
- securityManager: SecurityManager,
mapOutputTracker: MapOutputTracker,
- shuffleManager: ShuffleManager) = {
+ shuffleManager: ShuffleManager,
+ blockTransferService: BlockTransferService) = {
this(execId, actorSystem, master, serializer, BlockManager.getMaxMemory(conf),
- conf, securityManager, mapOutputTracker, shuffleManager)
+ conf, mapOutputTracker, shuffleManager, blockTransferService)
}
/**
@@ -149,7 +143,6 @@ private[spark] class BlockManager(
*/
private def initialize(): Unit = {
master.registerBlockManager(blockManagerId, maxMemory, slaveActor)
- BlockManagerWorker.startBlockManagerWorker(this)
}
/**
@@ -212,21 +205,34 @@ private[spark] class BlockManager(
}
}
- override def getBlockData(blockId: String): Either[FileSegment, ByteBuffer] = {
+ /**
+ * Interface to get local block data.
+ *
+ * @return Some(buffer) if the block exists locally, and None if it doesn't.
+ */
+ override def getBlockData(blockId: String): Option[ManagedBuffer] = {
val bid = BlockId(blockId)
if (bid.isShuffle) {
- shuffleManager.shuffleBlockManager.getBlockData(bid.asInstanceOf[ShuffleBlockId])
+ Some(shuffleManager.shuffleBlockManager.getBlockData(bid.asInstanceOf[ShuffleBlockId]))
} else {
val blockBytesOpt = doGetLocal(bid, asBlockResult = false).asInstanceOf[Option[ByteBuffer]]
if (blockBytesOpt.isDefined) {
- Right(blockBytesOpt.get)
+ val buffer = blockBytesOpt.get
+ Some(new NioByteBufferManagedBuffer(buffer))
} else {
- throw new BlockNotFoundException(blockId)
+ None
}
}
}
/**
+ * Put the block locally, using the given storage level.
+ */
+ override def putBlockData(blockId: String, data: ManagedBuffer, level: StorageLevel): Unit = {
+ putBytes(BlockId(blockId), data.nioByteBuffer(), level)
+ }
+
+ /**
* Get the BlockStatus for the block identified by the given ID, if it exists.
* NOTE: This is mainly for testing, and it doesn't fetch information from Tachyon.
*/
@@ -333,16 +339,10 @@ private[spark] class BlockManager(
* shuffle blocks. It is safe to do so without a lock on block info since disk store
* never deletes (recent) items.
*/
- def getLocalShuffleFromDisk(
- blockId: BlockId, serializer: Serializer): Option[Iterator[Any]] = {
-
- val shuffleBlockManager = shuffleManager.shuffleBlockManager
- val values = shuffleBlockManager.getBytes(blockId.asInstanceOf[ShuffleBlockId]).map(
- bytes => this.dataDeserialize(blockId, bytes, serializer))
-
- values.orElse {
- throw new BlockException(blockId, s"Block $blockId not found on disk, though it should be")
- }
+ def getLocalShuffleFromDisk(blockId: BlockId, serializer: Serializer): Option[Iterator[Any]] = {
+ val buf = shuffleManager.shuffleBlockManager.getBlockData(blockId.asInstanceOf[ShuffleBlockId])
+ val is = wrapForCompression(blockId, buf.inputStream())
+ Some(serializer.newInstance().deserializeStream(is).asIterator)
}
/**
@@ -513,8 +513,9 @@ private[spark] class BlockManager(
val locations = Random.shuffle(master.getLocations(blockId))
for (loc <- locations) {
logDebug(s"Getting remote block $blockId from $loc")
- val data = BlockManagerWorker.syncGetBlock(
- GetBlock(blockId), ConnectionManagerId(loc.host, loc.port))
+ val data = blockTransferService.fetchBlockSync(
+ loc.host, loc.port, blockId.toString).nioByteBuffer()
+
if (data != null) {
if (asBlockResult) {
return Some(new BlockResult(
@@ -548,22 +549,6 @@ private[spark] class BlockManager(
None
}
- /**
- * Get multiple blocks from local and remote block manager using their BlockManagerIds. Returns
- * an Iterator of (block ID, value) pairs so that clients may handle blocks in a pipelined
- * fashion as they're received. Expects a size in bytes to be provided for each block fetched,
- * so that we can control the maxMegabytesInFlight for the fetch.
- */
- def getMultiple(
- blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])],
- serializer: Serializer,
- readMetrics: ShuffleReadMetrics): BlockFetcherIterator = {
- val iter = new BlockFetcherIterator.BasicBlockFetcherIterator(this, blocksByAddress, serializer,
- readMetrics)
- iter.initialize()
- iter
- }
-
def putIterator(
blockId: BlockId,
values: Iterator[Any],
@@ -816,12 +801,15 @@ private[spark] class BlockManager(
data.rewind()
logDebug(s"Try to replicate $blockId once; The size of the data is ${data.limit()} Bytes. " +
s"To node: $peer")
- val putBlock = PutBlock(blockId, data, tLevel)
- val cmId = new ConnectionManagerId(peer.host, peer.port)
- val syncPutBlockSuccess = BlockManagerWorker.syncPutBlock(putBlock, cmId)
- if (!syncPutBlockSuccess) {
- logError(s"Failed to call syncPutBlock to $peer")
+
+ try {
+ blockTransferService.uploadBlockSync(
+ peer.host, peer.port, blockId.toString, new NioByteBufferManagedBuffer(data), tLevel)
+ } catch {
+ case e: Exception =>
+ logError(s"Failed to replicate block to $peer", e)
}
+
logDebug("Replicating BlockId %s once used %fs; The size of the data is %d bytes."
.format(blockId, (System.nanoTime - start) / 1e6, data.limit()))
}
@@ -1051,7 +1039,7 @@ private[spark] class BlockManager(
}
def stop(): Unit = {
- connectionManager.stop()
+ blockTransferService.stop()
diskBlockManager.stop()
actorSystem.stop(slaveActor)
blockInfo.clear()
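
doGetRemote now goes through blockTransferService.fetchBlockSync, which is not shown in this part of the diff; presumably it blocks on the asynchronous fetchBlocks API. A sketch under that assumption (the names here are illustrative, not the actual implementation):

    import java.util.concurrent.CountDownLatch
    import java.util.concurrent.atomic.AtomicReference
    import org.apache.spark.network.{BlockFetchingListener, BlockTransferService, ManagedBuffer}

    def fetchBlockSync(service: BlockTransferService, host: String, port: Int,
        blockId: String): ManagedBuffer = {
      val latch = new CountDownLatch(1)
      val result = new AtomicReference[Either[Throwable, ManagedBuffer]]()
      service.fetchBlocks(host, port, Seq(blockId), new BlockFetchingListener {
        override def onBlockFetchSuccess(id: String, data: ManagedBuffer): Unit = {
          result.set(Right(data)); latch.countDown()
        }
        override def onBlockFetchFailure(exception: Throwable): Unit = {
          result.set(Left(exception)); latch.countDown()
        }
      })
      latch.await()                      // block until one callback fires
      result.get().fold(e => throw e, identity)
    }
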
http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala
index b7bcb2d..d4487fc 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala
@@ -36,8 +36,8 @@ import org.apache.spark.util.Utils
class BlockManagerId private (
private var executorId_ : String,
private var host_ : String,
- private var port_ : Int
- ) extends Externalizable {
+ private var port_ : Int)
+ extends Externalizable {
private def this() = this(null, null, 0) // For deserialization only
http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala
deleted file mode 100644
index bf002a4..0000000
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala
+++ /dev/null
@@ -1,147 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.storage
-
-import java.nio.ByteBuffer
-
-import org.apache.spark.Logging
-import org.apache.spark.network._
-import org.apache.spark.util.Utils
-
-import scala.concurrent.Await
-import scala.concurrent.duration.Duration
-import scala.util.{Try, Failure, Success}
-
-/**
- * A network interface for BlockManager. Each slave should have one
- * BlockManagerWorker.
- *
- * TODO: Use event model.
- */
-private[spark] class BlockManagerWorker(val blockManager: BlockManager) extends Logging {
-
- blockManager.connectionManager.onReceiveMessage(onBlockMessageReceive)
-
- def onBlockMessageReceive(msg: Message, id: ConnectionManagerId): Option[Message] = {
- logDebug("Handling message " + msg)
- msg match {
- case bufferMessage: BufferMessage => {
- try {
- logDebug("Handling as a buffer message " + bufferMessage)
- val blockMessages = BlockMessageArray.fromBufferMessage(bufferMessage)
- logDebug("Parsed as a block message array")
- val responseMessages = blockMessages.map(processBlockMessage).filter(_ != None).map(_.get)
- Some(new BlockMessageArray(responseMessages).toBufferMessage)
- } catch {
- case e: Exception => {
- logError("Exception handling buffer message", e)
- val errorMessage = Message.createBufferMessage(msg.id)
- errorMessage.hasError = true
- Some(errorMessage)
- }
- }
- }
- case otherMessage: Any => {
- logError("Unknown type message received: " + otherMessage)
- val errorMessage = Message.createBufferMessage(msg.id)
- errorMessage.hasError = true
- Some(errorMessage)
- }
- }
- }
-
- def processBlockMessage(blockMessage: BlockMessage): Option[BlockMessage] = {
- blockMessage.getType match {
- case BlockMessage.TYPE_PUT_BLOCK => {
- val pB = PutBlock(blockMessage.getId, blockMessage.getData, blockMessage.getLevel)
- logDebug("Received [" + pB + "]")
- putBlock(pB.id, pB.data, pB.level)
- None
- }
- case BlockMessage.TYPE_GET_BLOCK => {
- val gB = new GetBlock(blockMessage.getId)
- logDebug("Received [" + gB + "]")
- val buffer = getBlock(gB.id)
- if (buffer == null) {
- return None
- }
- Some(BlockMessage.fromGotBlock(GotBlock(gB.id, buffer)))
- }
- case _ => None
- }
- }
-
- private def putBlock(id: BlockId, bytes: ByteBuffer, level: StorageLevel) {
- val startTimeMs = System.currentTimeMillis()
- logDebug("PutBlock " + id + " started from " + startTimeMs + " with data: " + bytes)
- blockManager.putBytes(id, bytes, level)
- logDebug("PutBlock " + id + " used " + Utils.getUsedTimeMs(startTimeMs)
- + " with data size: " + bytes.limit)
- }
-
- private def getBlock(id: BlockId): ByteBuffer = {
- val startTimeMs = System.currentTimeMillis()
- logDebug("GetBlock " + id + " started from " + startTimeMs)
- val buffer = blockManager.getLocalBytes(id) match {
- case Some(bytes) => bytes
- case None => null
- }
- logDebug("GetBlock " + id + " used " + Utils.getUsedTimeMs(startTimeMs)
- + " and got buffer " + buffer)
- buffer
- }
-}
-
-private[spark] object BlockManagerWorker extends Logging {
- private var blockManagerWorker: BlockManagerWorker = null
-
- def startBlockManagerWorker(manager: BlockManager) {
- blockManagerWorker = new BlockManagerWorker(manager)
- }
-
- def syncPutBlock(msg: PutBlock, toConnManagerId: ConnectionManagerId): Boolean = {
- val blockManager = blockManagerWorker.blockManager
- val connectionManager = blockManager.connectionManager
- val blockMessage = BlockMessage.fromPutBlock(msg)
- val blockMessageArray = new BlockMessageArray(blockMessage)
- val resultMessage = Try(Await.result(connectionManager.sendMessageReliably(
- toConnManagerId, blockMessageArray.toBufferMessage), Duration.Inf))
- resultMessage.isSuccess
- }
-
- def syncGetBlock(msg: GetBlock, toConnManagerId: ConnectionManagerId): ByteBuffer = {
- val blockManager = blockManagerWorker.blockManager
- val connectionManager = blockManager.connectionManager
- val blockMessage = BlockMessage.fromGetBlock(msg)
- val blockMessageArray = new BlockMessageArray(blockMessage)
- val responseMessage = Try(Await.result(connectionManager.sendMessageReliably(
- toConnManagerId, blockMessageArray.toBufferMessage), Duration.Inf))
- responseMessage match {
- case Success(message) => {
- val bufferMessage = message.asInstanceOf[BufferMessage]
- logDebug("Response message received " + bufferMessage)
- BlockMessageArray.fromBufferMessage(bufferMessage).foreach(blockMessage => {
- logDebug("Found " + blockMessage)
- return blockMessage.getData
- })
- }
- case Failure(exception) => logDebug("No response message received")
- }
- null
- }
-}
http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/storage/BlockMessage.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockMessage.scala b/core/src/main/scala/org/apache/spark/storage/BlockMessage.scala
deleted file mode 100644
index a2bfce7..0000000
--- a/core/src/main/scala/org/apache/spark/storage/BlockMessage.scala
+++ /dev/null
@@ -1,209 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.storage
-
-import java.nio.ByteBuffer
-
-import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.StringBuilder
-
-import org.apache.spark.network._
-
-private[spark] case class GetBlock(id: BlockId)
-private[spark] case class GotBlock(id: BlockId, data: ByteBuffer)
-private[spark] case class PutBlock(id: BlockId, data: ByteBuffer, level: StorageLevel)
-
-private[spark] class BlockMessage() {
- // Un-initialized: typ = 0
- // GetBlock: typ = 1
- // GotBlock: typ = 2
- // PutBlock: typ = 3
- private var typ: Int = BlockMessage.TYPE_NON_INITIALIZED
- private var id: BlockId = null
- private var data: ByteBuffer = null
- private var level: StorageLevel = null
-
- def set(getBlock: GetBlock) {
- typ = BlockMessage.TYPE_GET_BLOCK
- id = getBlock.id
- }
-
- def set(gotBlock: GotBlock) {
- typ = BlockMessage.TYPE_GOT_BLOCK
- id = gotBlock.id
- data = gotBlock.data
- }
-
- def set(putBlock: PutBlock) {
- typ = BlockMessage.TYPE_PUT_BLOCK
- id = putBlock.id
- data = putBlock.data
- level = putBlock.level
- }
-
- def set(buffer: ByteBuffer) {
- /*
- println()
- println("BlockMessage: ")
- while(buffer.remaining > 0) {
- print(buffer.get())
- }
- buffer.rewind()
- println()
- println()
- */
- typ = buffer.getInt()
- val idLength = buffer.getInt()
- val idBuilder = new StringBuilder(idLength)
- for (i <- 1 to idLength) {
- idBuilder += buffer.getChar()
- }
- id = BlockId(idBuilder.toString)
-
- if (typ == BlockMessage.TYPE_PUT_BLOCK) {
-
- val booleanInt = buffer.getInt()
- val replication = buffer.getInt()
- level = StorageLevel(booleanInt, replication)
-
- val dataLength = buffer.getInt()
- data = ByteBuffer.allocate(dataLength)
- if (dataLength != buffer.remaining) {
- throw new Exception("Error parsing buffer")
- }
- data.put(buffer)
- data.flip()
- } else if (typ == BlockMessage.TYPE_GOT_BLOCK) {
-
- val dataLength = buffer.getInt()
- data = ByteBuffer.allocate(dataLength)
- if (dataLength != buffer.remaining) {
- throw new Exception("Error parsing buffer")
- }
- data.put(buffer)
- data.flip()
- }
-
- }
-
- def set(bufferMsg: BufferMessage) {
- val buffer = bufferMsg.buffers.apply(0)
- buffer.clear()
- set(buffer)
- }
-
- def getType: Int = typ
- def getId: BlockId = id
- def getData: ByteBuffer = data
- def getLevel: StorageLevel = level
-
- def toBufferMessage: BufferMessage = {
- val buffers = new ArrayBuffer[ByteBuffer]()
- var buffer = ByteBuffer.allocate(4 + 4 + id.name.length * 2)
- buffer.putInt(typ).putInt(id.name.length)
- id.name.foreach((x: Char) => buffer.putChar(x))
- buffer.flip()
- buffers += buffer
-
- if (typ == BlockMessage.TYPE_PUT_BLOCK) {
- buffer = ByteBuffer.allocate(8).putInt(level.toInt).putInt(level.replication)
- buffer.flip()
- buffers += buffer
-
- buffer = ByteBuffer.allocate(4).putInt(data.remaining)
- buffer.flip()
- buffers += buffer
-
- buffers += data
- } else if (typ == BlockMessage.TYPE_GOT_BLOCK) {
- buffer = ByteBuffer.allocate(4).putInt(data.remaining)
- buffer.flip()
- buffers += buffer
-
- buffers += data
- }
-
- /*
- println()
- println("BlockMessage: ")
- buffers.foreach(b => {
- while(b.remaining > 0) {
- print(b.get())
- }
- b.rewind()
- })
- println()
- println()
- */
- Message.createBufferMessage(buffers)
- }
-
- override def toString: String = {
- "BlockMessage [type = " + typ + ", id = " + id + ", level = " + level +
- ", data = " + (if (data != null) data.remaining.toString else "null") + "]"
- }
-}
-
-private[spark] object BlockMessage {
- val TYPE_NON_INITIALIZED: Int = 0
- val TYPE_GET_BLOCK: Int = 1
- val TYPE_GOT_BLOCK: Int = 2
- val TYPE_PUT_BLOCK: Int = 3
-
- def fromBufferMessage(bufferMessage: BufferMessage): BlockMessage = {
- val newBlockMessage = new BlockMessage()
- newBlockMessage.set(bufferMessage)
- newBlockMessage
- }
-
- def fromByteBuffer(buffer: ByteBuffer): BlockMessage = {
- val newBlockMessage = new BlockMessage()
- newBlockMessage.set(buffer)
- newBlockMessage
- }
-
- def fromGetBlock(getBlock: GetBlock): BlockMessage = {
- val newBlockMessage = new BlockMessage()
- newBlockMessage.set(getBlock)
- newBlockMessage
- }
-
- def fromGotBlock(gotBlock: GotBlock): BlockMessage = {
- val newBlockMessage = new BlockMessage()
- newBlockMessage.set(gotBlock)
- newBlockMessage
- }
-
- def fromPutBlock(putBlock: PutBlock): BlockMessage = {
- val newBlockMessage = new BlockMessage()
- newBlockMessage.set(putBlock)
- newBlockMessage
- }
-
- def main(args: Array[String]) {
- val B = new BlockMessage()
- val blockId = TestBlockId("ABC")
- B.set(new PutBlock(blockId, ByteBuffer.allocate(10), StorageLevel.MEMORY_AND_DISK_SER_2))
- val bMsg = B.toBufferMessage
- val C = new BlockMessage()
- C.set(bMsg)
-
- println(B.getId + " " + B.getLevel)
- println(C.getId + " " + C.getLevel)
- }
-}
http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala b/core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala
deleted file mode 100644
index 973d85c..0000000
--- a/core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala
+++ /dev/null
@@ -1,160 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.storage
-
-import java.nio.ByteBuffer
-
-import scala.collection.mutable.ArrayBuffer
-
-import org.apache.spark._
-import org.apache.spark.network._
-
-private[spark]
-class BlockMessageArray(var blockMessages: Seq[BlockMessage])
- extends Seq[BlockMessage] with Logging {
-
- def this(bm: BlockMessage) = this(Array(bm))
-
- def this() = this(null.asInstanceOf[Seq[BlockMessage]])
-
- def apply(i: Int) = blockMessages(i)
-
- def iterator = blockMessages.iterator
-
- def length = blockMessages.length
-
- def set(bufferMessage: BufferMessage) {
- val startTime = System.currentTimeMillis
- val newBlockMessages = new ArrayBuffer[BlockMessage]()
- val buffer = bufferMessage.buffers(0)
- buffer.clear()
- /*
- println()
- println("BlockMessageArray: ")
- while(buffer.remaining > 0) {
- print(buffer.get())
- }
- buffer.rewind()
- println()
- println()
- */
- while (buffer.remaining() > 0) {
- val size = buffer.getInt()
- logDebug("Creating block message of size " + size + " bytes")
- val newBuffer = buffer.slice()
- newBuffer.clear()
- newBuffer.limit(size)
- logDebug("Trying to convert buffer " + newBuffer + " to block message")
- val newBlockMessage = BlockMessage.fromByteBuffer(newBuffer)
- logDebug("Created " + newBlockMessage)
- newBlockMessages += newBlockMessage
- buffer.position(buffer.position() + size)
- }
- val finishTime = System.currentTimeMillis
- logDebug("Converted block message array from buffer message in " +
- (finishTime - startTime) / 1000.0 + " s")
- this.blockMessages = newBlockMessages
- }
-
- def toBufferMessage: BufferMessage = {
- val buffers = new ArrayBuffer[ByteBuffer]()
-
- blockMessages.foreach(blockMessage => {
- val bufferMessage = blockMessage.toBufferMessage
- logDebug("Adding " + blockMessage)
- val sizeBuffer = ByteBuffer.allocate(4).putInt(bufferMessage.size)
- sizeBuffer.flip
- buffers += sizeBuffer
- buffers ++= bufferMessage.buffers
- logDebug("Added " + bufferMessage)
- })
-
- logDebug("Buffer list:")
- buffers.foreach((x: ByteBuffer) => logDebug("" + x))
- /*
- println()
- println("BlockMessageArray: ")
- buffers.foreach(b => {
- while(b.remaining > 0) {
- print(b.get())
- }
- b.rewind()
- })
- println()
- println()
- */
- Message.createBufferMessage(buffers)
- }
-}
-
-private[spark] object BlockMessageArray {
-
- def fromBufferMessage(bufferMessage: BufferMessage): BlockMessageArray = {
- val newBlockMessageArray = new BlockMessageArray()
- newBlockMessageArray.set(bufferMessage)
- newBlockMessageArray
- }
-
- def main(args: Array[String]) {
- val blockMessages =
- (0 until 10).map { i =>
- if (i % 2 == 0) {
- val buffer = ByteBuffer.allocate(100)
- buffer.clear
- BlockMessage.fromPutBlock(PutBlock(TestBlockId(i.toString), buffer,
- StorageLevel.MEMORY_ONLY_SER))
- } else {
- BlockMessage.fromGetBlock(GetBlock(TestBlockId(i.toString)))
- }
- }
- val blockMessageArray = new BlockMessageArray(blockMessages)
- println("Block message array created")
-
- val bufferMessage = blockMessageArray.toBufferMessage
- println("Converted to buffer message")
-
- val totalSize = bufferMessage.size
- val newBuffer = ByteBuffer.allocate(totalSize)
- newBuffer.clear()
- bufferMessage.buffers.foreach(buffer => {
- assert (0 == buffer.position())
- newBuffer.put(buffer)
- buffer.rewind()
- })
- newBuffer.flip
- val newBufferMessage = Message.createBufferMessage(newBuffer)
- println("Copied to new buffer message, size = " + newBufferMessage.size)
-
- val newBlockMessageArray = BlockMessageArray.fromBufferMessage(newBufferMessage)
- println("Converted back to block message array")
- newBlockMessageArray.foreach(blockMessage => {
- blockMessage.getType match {
- case BlockMessage.TYPE_PUT_BLOCK => {
- val pB = PutBlock(blockMessage.getId, blockMessage.getData, blockMessage.getLevel)
- println(pB)
- }
- case BlockMessage.TYPE_GET_BLOCK => {
- val gB = new GetBlock(blockMessage.getId)
- println(gB)
- }
- }
- })
- }
-}
-
-
http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
new file mode 100644
index 0000000..c8e708a
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
@@ -0,0 +1,271 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.util.concurrent.LinkedBlockingQueue
+
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.HashSet
+import scala.collection.mutable.Queue
+
+import org.apache.spark.{TaskContext, Logging, SparkException}
+import org.apache.spark.network.{ManagedBuffer, BlockFetchingListener, BlockTransferService}
+import org.apache.spark.serializer.Serializer
+import org.apache.spark.util.Utils
+
+
+/**
+ * An iterator that fetches multiple blocks. For local blocks, it fetches from the local block
+ * manager. For remote blocks, it fetches them using the provided BlockTransferService.
+ *
+ * This creates an iterator of (BlockId, values) tuples so the caller can handle blocks in a
+ * pipelined fashion as they are received.
+ *
+ * The implementation throttles the remote fetches so they don't exceed maxBytesInFlight to avoid
+ * using too much memory.
+ *
+ * @param context [[TaskContext]], used for metrics update
+ * @param blockTransferService [[BlockTransferService]] for fetching remote blocks
+ * @param blockManager [[BlockManager]] for reading local blocks
+ * @param blocksByAddress list of blocks to fetch grouped by the [[BlockManagerId]].
+ * For each block we also require the size (in bytes as a long field) in
+ * order to throttle the memory usage.
+ * @param serializer serializer used to deserialize the data.
+ * @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at any given point.
+ */
+private[spark]
+final class ShuffleBlockFetcherIterator(
+ context: TaskContext,
+ blockTransferService: BlockTransferService,
+ blockManager: BlockManager,
+ blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])],
+ serializer: Serializer,
+ maxBytesInFlight: Long)
+ extends Iterator[(BlockId, Option[Iterator[Any]])] with Logging {
+
+ import ShuffleBlockFetcherIterator._
+
+ /**
+ * Total number of blocks to fetch. This can be smaller than the total number of blocks
+ * in [[blocksByAddress]] because we filter out zero-sized blocks in [[initialize]].
+ *
+ * This should equal localBlocks.size + remoteBlocks.size.
+ */
+ private[this] var numBlocksToFetch = 0
+
+ /**
+   * The number of blocks processed by the caller. The iterator is exhausted when
+ * [[numBlocksProcessed]] == [[numBlocksToFetch]].
+ */
+ private[this] var numBlocksProcessed = 0
+
+ private[this] val startTime = System.currentTimeMillis
+
+ /** Local blocks to fetch, excluding zero-sized blocks. */
+ private[this] val localBlocks = new ArrayBuffer[BlockId]()
+
+ /** Remote blocks to fetch, excluding zero-sized blocks. */
+ private[this] val remoteBlocks = new HashSet[BlockId]()
+
+ /**
+ * A queue to hold our results. This turns the asynchronous model provided by
+ * [[BlockTransferService]] into a synchronous model (iterator).
+ */
+ private[this] val results = new LinkedBlockingQueue[FetchResult]
+
+ // Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that
+ // the number of bytes in flight is limited to maxBytesInFlight
+ private[this] val fetchRequests = new Queue[FetchRequest]
+
+ // Current bytes in flight from our requests
+ private[this] var bytesInFlight = 0L
+
+ private[this] val shuffleMetrics = context.taskMetrics.createShuffleReadMetricsForDependency()
+
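+  // Eagerly start sending fetch requests (up to maxBytesInFlight) and reading local blocks.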
+ initialize()
+
+ private[this] def sendRequest(req: FetchRequest) {
+ logDebug("Sending request for %d blocks (%s) from %s".format(
+ req.blocks.size, Utils.bytesToString(req.size), req.address.hostPort))
+ bytesInFlight += req.size
+
+    // So we can look up the size of each blockId.
+ val sizeMap = req.blocks.map { case (blockId, size) => (blockId.toString, size) }.toMap
+ val blockIds = req.blocks.map(_._1.toString)
+
+ blockTransferService.fetchBlocks(req.address.host, req.address.port, blockIds,
+ new BlockFetchingListener {
+ override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = {
+ results.put(new FetchResult(BlockId(blockId), sizeMap(blockId),
+ () => serializer.newInstance().deserializeStream(
+ blockManager.wrapForCompression(BlockId(blockId), data.inputStream())).asIterator
+ ))
+ shuffleMetrics.remoteBytesRead += data.size
+ shuffleMetrics.remoteBlocksFetched += 1
+ logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
+ }
+
+ override def onBlockFetchFailure(e: Throwable): Unit = {
+ logError("Failed to get block(s) from ${req.address.host}:${req.address.port}", e)
+ // Note that there is a chance that some blocks have been fetched successfully, but we
+        // still add them to the failed queue. This is fine because when the caller sees a
+ // FetchFailedException, it is going to fail the entire task anyway.
+ for ((blockId, size) <- req.blocks) {
+ results.put(new FetchResult(blockId, -1, null))
+ }
+ }
+ }
+ )
+ }
+
+ private[this] def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = {
+ // Make remote requests at most maxBytesInFlight / 5 in length; the reason to keep them
+ // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5
+ // nodes, rather than blocking on reading output from one node.
+ val targetRequestSize = math.max(maxBytesInFlight / 5, 1L)
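+    // E.g. with the default 48 MB maxBytesInFlight, requests are capped at ~9.6 MB each.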
+ logInfo("maxBytesInFlight: " + maxBytesInFlight + ", targetRequestSize: " + targetRequestSize)
+
+ // Split local and remote blocks. Remote blocks are further split into FetchRequests of size
+ // at most maxBytesInFlight in order to limit the amount of data in flight.
+ val remoteRequests = new ArrayBuffer[FetchRequest]
+
+ // Tracks total number of blocks (including zero sized blocks)
+ var totalBlocks = 0
+ for ((address, blockInfos) <- blocksByAddress) {
+ totalBlocks += blockInfos.size
+ if (address == blockManager.blockManagerId) {
+ // Filter out zero-sized blocks
+ localBlocks ++= blockInfos.filter(_._2 != 0).map(_._1)
+ numBlocksToFetch += localBlocks.size
+ } else {
+ val iterator = blockInfos.iterator
+ var curRequestSize = 0L
+ var curBlocks = new ArrayBuffer[(BlockId, Long)]
+ while (iterator.hasNext) {
+ val (blockId, size) = iterator.next()
+ // Skip empty blocks
+ if (size > 0) {
+ curBlocks += ((blockId, size))
+ remoteBlocks += blockId
+ numBlocksToFetch += 1
+ curRequestSize += size
+ } else if (size < 0) {
+ throw new BlockException(blockId, "Negative block size " + size)
+ }
+ if (curRequestSize >= targetRequestSize) {
+ // Add this FetchRequest
+ remoteRequests += new FetchRequest(address, curBlocks)
+ curBlocks = new ArrayBuffer[(BlockId, Long)]
+ logDebug(s"Creating fetch request of $curRequestSize at $address")
+ curRequestSize = 0
+ }
+ }
+ // Add in the final request
+ if (curBlocks.nonEmpty) {
+ remoteRequests += new FetchRequest(address, curBlocks)
+ }
+ }
+ }
+ logInfo(s"Getting $numBlocksToFetch non-empty blocks out of $totalBlocks blocks")
+ remoteRequests
+ }
+
+ private[this] def fetchLocalBlocks() {
+ // Get the local blocks while remote blocks are being fetched. Note that it's okay to do
+ // these all at once because they will just memory-map some files, so they won't consume
+ // any memory that might exceed our maxBytesInFlight
+ for (id <- localBlocks) {
+ try {
+ shuffleMetrics.localBlocksFetched += 1
+ results.put(new FetchResult(
+ id, 0, () => blockManager.getLocalShuffleFromDisk(id, serializer).get))
+ logDebug("Got local block " + id)
+ } catch {
+ case e: Exception =>
+ logError(s"Error occurred while fetching local blocks", e)
+ results.put(new FetchResult(id, -1, null))
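+          // Stop at the first failure; the caller fails the whole task on a failed result.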
+ return
+ }
+ }
+ }
+
+ private[this] def initialize(): Unit = {
+ // Split local and remote blocks.
+ val remoteRequests = splitLocalRemoteBlocks()
+ // Add the remote requests into our queue in a random order
+ fetchRequests ++= Utils.randomize(remoteRequests)
+
+ // Send out initial requests for blocks, up to our maxBytesInFlight
+ while (fetchRequests.nonEmpty &&
+ (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
+ sendRequest(fetchRequests.dequeue())
+ }
+
+ val numFetches = remoteRequests.size - fetchRequests.size
+ logInfo("Started " + numFetches + " remote fetches in" + Utils.getUsedTimeMs(startTime))
+
+ // Get Local Blocks
+ fetchLocalBlocks()
+ logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms")
+ }
+
+ override def hasNext: Boolean = numBlocksProcessed < numBlocksToFetch
+
+ override def next(): (BlockId, Option[Iterator[Any]]) = {
+ numBlocksProcessed += 1
+ val startFetchWait = System.currentTimeMillis()
+ val result = results.take()
+ val stopFetchWait = System.currentTimeMillis()
+ shuffleMetrics.fetchWaitTime += (stopFetchWait - startFetchWait)
+ if (!result.failed) {
+ bytesInFlight -= result.size
+ }
+ // Send fetch requests up to maxBytesInFlight
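+    // (the bytesInFlight == 0 disjunct guarantees progress even if one request exceeds the cap)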
+ while (fetchRequests.nonEmpty &&
+ (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
+ sendRequest(fetchRequests.dequeue())
+ }
+ (result.blockId, if (result.failed) None else Some(result.deserialize()))
+ }
+}
+
+
+private[storage]
+object ShuffleBlockFetcherIterator {
+
+ /**
+ * A request to fetch blocks from a remote BlockManager.
+ * @param address remote BlockManager to fetch from.
+ * @param blocks Sequence of tuple, where the first element is the block id,
+ * and the second element is the estimated size, used to calculate bytesInFlight.
+ */
+ class FetchRequest(val address: BlockManagerId, val blocks: Seq[(BlockId, Long)]) {
+ val size = blocks.map(_._2).sum
+ }
+
+ /**
+   * Result of a block fetch. A failure is represented as size == -1.
+ * @param blockId block id
+ * @param size estimated size of the block, used to calculate bytesInFlight.
+   *             Note that this is NOT the exact byte count.
+ * @param deserialize closure to return the result in the form of an Iterator.
+ */
+ class FetchResult(val blockId: BlockId, val size: Long, val deserialize: () => Iterator[Any]) {
+ def failed: Boolean = size == -1
+ }
+}
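
Taken together, the iterator above turns the asynchronous BlockTransferService
callbacks into a blocking pull model. A sketch of how a reducer might consume it,
assuming the constructor arguments shown above are already in scope (the failure
handling here is illustrative, not the shuffle reader's actual code):

    import org.apache.spark.SparkException

    // Construction eagerly starts the fetches (see initialize() above).
    val iter = new ShuffleBlockFetcherIterator(
      context, blockTransferService, blockManager, blocksByAddress,
      serializer, maxBytesInFlight = 48 * 1024 * 1024)

    // Each element is (BlockId, Option[Iterator[Any]]); None marks a failed fetch.
    val records: Iterator[Any] = iter.flatMap {
      case (_, Some(values)) => values
      case (blockId, None) =>
        throw new SparkException(s"Failed to fetch shuffle block $blockId")
    }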
http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala
deleted file mode 100644
index 7540f0d..0000000
--- a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala
+++ /dev/null
@@ -1,120 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.storage
-
-import java.util.concurrent.ArrayBlockingQueue
-
-import akka.actor._
-import org.apache.spark.shuffle.hash.HashShuffleManager
-import util.Random
-
-import org.apache.spark.{MapOutputTrackerMaster, SecurityManager, SparkConf}
-import org.apache.spark.scheduler.LiveListenerBus
-import org.apache.spark.serializer.KryoSerializer
-
-/**
- * This class tests the BlockManager and MemoryStore for thread safety and
- * deadlocks. It spawns a number of producer and consumer threads. Producer
- * threads continuously pushes blocks into the BlockManager and consumer
- * threads continuously retrieves the blocks form the BlockManager and tests
- * whether the block is correct or not.
- */
-private[spark] object ThreadingTest {
-
- val numProducers = 5
- val numBlocksPerProducer = 20000
-
- private[spark] class ProducerThread(manager: BlockManager, id: Int) extends Thread {
- val queue = new ArrayBlockingQueue[(BlockId, Seq[Int])](100)
-
- override def run() {
- for (i <- 1 to numBlocksPerProducer) {
- val blockId = TestBlockId("b-" + id + "-" + i)
- val blockSize = Random.nextInt(1000)
- val block = (1 to blockSize).map(_ => Random.nextInt())
- val level = randomLevel()
- val startTime = System.currentTimeMillis()
- manager.putIterator(blockId, block.iterator, level, tellMaster = true)
- println("Pushed block " + blockId + " in " + (System.currentTimeMillis - startTime) + " ms")
- queue.add((blockId, block))
- }
- println("Producer thread " + id + " terminated")
- }
-
- def randomLevel(): StorageLevel = {
- math.abs(Random.nextInt()) % 4 match {
- case 0 => StorageLevel.MEMORY_ONLY
- case 1 => StorageLevel.MEMORY_ONLY_SER
- case 2 => StorageLevel.MEMORY_AND_DISK
- case 3 => StorageLevel.MEMORY_AND_DISK_SER
- }
- }
- }
-
- private[spark] class ConsumerThread(
- manager: BlockManager,
- queue: ArrayBlockingQueue[(BlockId, Seq[Int])]
- ) extends Thread {
- var numBlockConsumed = 0
-
- override def run() {
- println("Consumer thread started")
- while(numBlockConsumed < numBlocksPerProducer) {
- val (blockId, block) = queue.take()
- val startTime = System.currentTimeMillis()
- manager.get(blockId) match {
- case Some(retrievedBlock) =>
- assert(retrievedBlock.data.toList.asInstanceOf[List[Int]] == block.toList,
- "Block " + blockId + " did not match")
- println("Got block " + blockId + " in " +
- (System.currentTimeMillis - startTime) + " ms")
- case None =>
- assert(false, "Block " + blockId + " could not be retrieved")
- }
- numBlockConsumed += 1
- }
- println("Consumer thread terminated")
- }
- }
-
- def main(args: Array[String]) {
- System.setProperty("spark.kryoserializer.buffer.mb", "1")
- val actorSystem = ActorSystem("test")
- val conf = new SparkConf()
- val serializer = new KryoSerializer(conf)
- val blockManagerMaster = new BlockManagerMaster(
- actorSystem.actorOf(Props(new BlockManagerMasterActor(true, conf, new LiveListenerBus))),
- conf, true)
- val blockManager = new BlockManager(
- "<driver>", actorSystem, blockManagerMaster, serializer, 1024 * 1024, conf,
- new SecurityManager(conf), new MapOutputTrackerMaster(conf), new HashShuffleManager(conf))
- val producers = (1 to numProducers).map(i => new ProducerThread(blockManager, i))
- val consumers = producers.map(p => new ConsumerThread(blockManager, p.queue))
- producers.foreach(_.start)
- consumers.foreach(_.start)
- producers.foreach(_.join)
- consumers.foreach(_.join)
- blockManager.stop()
- blockManagerMaster.stop()
- actorSystem.shutdown()
- actorSystem.awaitTermination()
- println("Everything stopped.")
- println(
- "It will take sometime for the JVM to clean all temporary files and shutdown. Sit tight.")
- }
-}
http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/test/scala/org/apache/spark/DistributedSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala
index 41c294f..81b64c3 100644
--- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala
+++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala
@@ -24,8 +24,7 @@ import org.scalatest.Matchers
import org.scalatest.time.{Millis, Span}
import org.apache.spark.SparkContext._
-import org.apache.spark.network.ConnectionManagerId
-import org.apache.spark.storage.{BlockManagerWorker, GetBlock, RDDBlockId, StorageLevel}
+import org.apache.spark.storage.{RDDBlockId, StorageLevel}
class NotSerializableClass
class NotSerializableExn(val notSer: NotSerializableClass) extends Throwable() {}
@@ -136,7 +135,6 @@ class DistributedSuite extends FunSuite with Matchers with BeforeAndAfter
sc.parallelize(1 to 10, 2).foreach { x => if (x == 1) System.exit(42) }
}
assert(thrown.getClass === classOf[SparkException])
- System.out.println(thrown.getMessage)
assert(thrown.getMessage.contains("failed 4 times"))
}
}
@@ -202,12 +200,13 @@ class DistributedSuite extends FunSuite with Matchers with BeforeAndAfter
val blockIds = data.partitions.indices.map(index => RDDBlockId(data.id, index)).toArray
val blockId = blockIds(0)
val blockManager = SparkEnv.get.blockManager
- blockManager.master.getLocations(blockId).foreach(id => {
- val bytes = BlockManagerWorker.syncGetBlock(
- GetBlock(blockId), ConnectionManagerId(id.host, id.port))
- val deserialized = blockManager.dataDeserialize(blockId, bytes).asInstanceOf[Iterator[Int]].toList
+ val blockTransfer = SparkEnv.get.blockTransferService
+ blockManager.master.getLocations(blockId).foreach { cmId =>
+ val bytes = blockTransfer.fetchBlockSync(cmId.host, cmId.port, blockId.toString)
+ val deserialized = blockManager.dataDeserialize(blockId, bytes.nioByteBuffer())
+ .asInstanceOf[Iterator[Int]].toList
assert(deserialized === (1 to 100).toList)
- })
+ }
}
test("compute without caching when no partitions fit in memory") {
http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala b/core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala
deleted file mode 100644
index e2f4d4c..0000000
--- a/core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala
+++ /dev/null
@@ -1,301 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network
-
-import java.io.IOException
-import java.nio._
-import java.util.concurrent.TimeoutException
-
-import org.apache.spark.{SecurityManager, SparkConf}
-import org.scalatest.FunSuite
-
-import org.mockito.Mockito._
-import org.mockito.Matchers._
-
-import scala.concurrent.TimeoutException
-import scala.concurrent.{Await, TimeoutException}
-import scala.concurrent.duration._
-import scala.language.postfixOps
-import scala.util.{Failure, Success, Try}
-
-/**
- * Test the ConnectionManager with various security settings.
- */
-class ConnectionManagerSuite extends FunSuite {
-
- test("security default off") {
- val conf = new SparkConf
- val securityManager = new SecurityManager(conf)
- val manager = new ConnectionManager(0, conf, securityManager)
- var receivedMessage = false
- manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
- receivedMessage = true
- None
- })
-
- val size = 10 * 1024 * 1024
- val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
- buffer.flip
-
- val bufferMessage = Message.createBufferMessage(buffer.duplicate)
- Await.result(manager.sendMessageReliably(manager.id, bufferMessage), 10 seconds)
-
- assert(receivedMessage == true)
-
- manager.stop()
- }
-
- test("security on same password") {
- val conf = new SparkConf
- conf.set("spark.authenticate", "true")
- conf.set("spark.authenticate.secret", "good")
- val securityManager = new SecurityManager(conf)
- val manager = new ConnectionManager(0, conf, securityManager)
- var numReceivedMessages = 0
-
- manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
- numReceivedMessages += 1
- None
- })
- val managerServer = new ConnectionManager(0, conf, securityManager)
- var numReceivedServerMessages = 0
- managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
- numReceivedServerMessages += 1
- None
- })
-
- val size = 10 * 1024 * 1024
- val count = 10
- val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
- buffer.flip
-
- (0 until count).map(i => {
- val bufferMessage = Message.createBufferMessage(buffer.duplicate)
- Await.result(manager.sendMessageReliably(managerServer.id, bufferMessage), 10 seconds)
- })
-
- assert(numReceivedServerMessages == 10)
- assert(numReceivedMessages == 0)
-
- manager.stop()
- managerServer.stop()
- }
-
- test("security mismatch password") {
- val conf = new SparkConf
- conf.set("spark.authenticate", "true")
- conf.set("spark.authenticate.secret", "good")
- val securityManager = new SecurityManager(conf)
- val manager = new ConnectionManager(0, conf, securityManager)
- var numReceivedMessages = 0
-
- manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
- numReceivedMessages += 1
- None
- })
-
- val badconf = new SparkConf
- badconf.set("spark.authenticate", "true")
- badconf.set("spark.authenticate.secret", "bad")
- val badsecurityManager = new SecurityManager(badconf)
- val managerServer = new ConnectionManager(0, badconf, badsecurityManager)
- var numReceivedServerMessages = 0
-
- managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
- numReceivedServerMessages += 1
- None
- })
-
- val size = 10 * 1024 * 1024
- val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
- buffer.flip
- val bufferMessage = Message.createBufferMessage(buffer.duplicate)
- // Expect managerServer to close connection, which we'll report as an error:
- intercept[IOException] {
- Await.result(manager.sendMessageReliably(managerServer.id, bufferMessage), 10 seconds)
- }
-
- assert(numReceivedServerMessages == 0)
- assert(numReceivedMessages == 0)
-
- manager.stop()
- managerServer.stop()
- }
-
- test("security mismatch auth off") {
- val conf = new SparkConf
- conf.set("spark.authenticate", "false")
- conf.set("spark.authenticate.secret", "good")
- val securityManager = new SecurityManager(conf)
- val manager = new ConnectionManager(0, conf, securityManager)
- var numReceivedMessages = 0
-
- manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
- numReceivedMessages += 1
- None
- })
-
- val badconf = new SparkConf
- badconf.set("spark.authenticate", "true")
- badconf.set("spark.authenticate.secret", "good")
- val badsecurityManager = new SecurityManager(badconf)
- val managerServer = new ConnectionManager(0, badconf, badsecurityManager)
- var numReceivedServerMessages = 0
- managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
- numReceivedServerMessages += 1
- None
- })
-
- val size = 10 * 1024 * 1024
- val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
- buffer.flip
- val bufferMessage = Message.createBufferMessage(buffer.duplicate)
- (0 until 1).map(i => {
- val bufferMessage = Message.createBufferMessage(buffer.duplicate)
- manager.sendMessageReliably(managerServer.id, bufferMessage)
- }).foreach(f => {
- try {
- val g = Await.result(f, 1 second)
- assert(false)
- } catch {
- case i: IOException =>
- assert(true)
- case e: TimeoutException => {
- // we should timeout here since the client can't do the negotiation
- assert(true)
- }
- }
- })
-
- assert(numReceivedServerMessages == 0)
- assert(numReceivedMessages == 0)
- manager.stop()
- managerServer.stop()
- }
-
- test("security auth off") {
- val conf = new SparkConf
- conf.set("spark.authenticate", "false")
- val securityManager = new SecurityManager(conf)
- val manager = new ConnectionManager(0, conf, securityManager)
- var numReceivedMessages = 0
-
- manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
- numReceivedMessages += 1
- None
- })
-
- val badconf = new SparkConf
- badconf.set("spark.authenticate", "false")
- val badsecurityManager = new SecurityManager(badconf)
- val managerServer = new ConnectionManager(0, badconf, badsecurityManager)
- var numReceivedServerMessages = 0
-
- managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
- numReceivedServerMessages += 1
- None
- })
-
- val size = 10 * 1024 * 1024
- val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
- buffer.flip
- val bufferMessage = Message.createBufferMessage(buffer.duplicate)
- (0 until 10).map(i => {
- val bufferMessage = Message.createBufferMessage(buffer.duplicate)
- manager.sendMessageReliably(managerServer.id, bufferMessage)
- }).foreach(f => {
- try {
- val g = Await.result(f, 1 second)
- } catch {
- case e: Exception => {
- assert(false)
- }
- }
- })
- assert(numReceivedServerMessages == 10)
- assert(numReceivedMessages == 0)
-
- manager.stop()
- managerServer.stop()
- }
-
- test("Ack error message") {
- val conf = new SparkConf
- conf.set("spark.authenticate", "false")
- val securityManager = new SecurityManager(conf)
- val manager = new ConnectionManager(0, conf, securityManager)
- val managerServer = new ConnectionManager(0, conf, securityManager)
- managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
- throw new Exception
- })
-
- val size = 10 * 1024 * 1024
- val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
- buffer.flip
- val bufferMessage = Message.createBufferMessage(buffer)
-
- val future = manager.sendMessageReliably(managerServer.id, bufferMessage)
-
- intercept[IOException] {
- Await.result(future, 1 second)
- }
-
- manager.stop()
- managerServer.stop()
-
- }
-
- test("sendMessageReliably timeout") {
- val clientConf = new SparkConf
- clientConf.set("spark.authenticate", "false")
- val ackTimeout = 30
- clientConf.set("spark.core.connection.ack.wait.timeout", s"${ackTimeout}")
-
- val clientSecurityManager = new SecurityManager(clientConf)
- val manager = new ConnectionManager(0, clientConf, clientSecurityManager)
-
- val serverConf = new SparkConf
- serverConf.set("spark.authenticate", "false")
- val serverSecurityManager = new SecurityManager(serverConf)
- val managerServer = new ConnectionManager(0, serverConf, serverSecurityManager)
- managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
- // sleep 60 sec > ack timeout for simulating server slow down or hang up
- Thread.sleep(ackTimeout * 3 * 1000)
- None
- })
-
- val size = 10 * 1024 * 1024
- val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
- buffer.flip
- val bufferMessage = Message.createBufferMessage(buffer.duplicate)
-
- val future = manager.sendMessageReliably(managerServer.id, bufferMessage)
-
- // Future should throw IOException in 30 sec.
- // Otherwise TimeoutExcepton is thrown from Await.result.
- // We expect TimeoutException is not thrown.
- intercept[IOException] {
- Await.result(future, (ackTimeout * 2) second)
- }
-
- manager.stop()
- managerServer.stop()
- }
-
-}
-
http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala b/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala
new file mode 100644
index 0000000..9f49587
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala
@@ -0,0 +1,296 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.nio
+
+import java.io.IOException
+import java.nio._
+
+import scala.concurrent.duration._
+import scala.concurrent.{Await, TimeoutException}
+import scala.language.postfixOps
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.{SecurityManager, SparkConf}
+
+/**
+ * Test the ConnectionManager with various security settings.
+ */
+class ConnectionManagerSuite extends FunSuite {
+
+ test("security default off") {
+ val conf = new SparkConf
+ val securityManager = new SecurityManager(conf)
+ val manager = new ConnectionManager(0, conf, securityManager)
+ var receivedMessage = false
+ manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
+ receivedMessage = true
+ None
+ })
+
+ val size = 10 * 1024 * 1024
+ val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
+ buffer.flip
+
+ val bufferMessage = Message.createBufferMessage(buffer.duplicate)
+ Await.result(manager.sendMessageReliably(manager.id, bufferMessage), 10 seconds)
+
+ assert(receivedMessage == true)
+
+ manager.stop()
+ }
+
+ test("security on same password") {
+ val conf = new SparkConf
+ conf.set("spark.authenticate", "true")
+ conf.set("spark.authenticate.secret", "good")
+ val securityManager = new SecurityManager(conf)
+ val manager = new ConnectionManager(0, conf, securityManager)
+ var numReceivedMessages = 0
+
+ manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
+ numReceivedMessages += 1
+ None
+ })
+ val managerServer = new ConnectionManager(0, conf, securityManager)
+ var numReceivedServerMessages = 0
+ managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
+ numReceivedServerMessages += 1
+ None
+ })
+
+ val size = 10 * 1024 * 1024
+ val count = 10
+ val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
+ buffer.flip
+
+ (0 until count).map(i => {
+ val bufferMessage = Message.createBufferMessage(buffer.duplicate)
+ Await.result(manager.sendMessageReliably(managerServer.id, bufferMessage), 10 seconds)
+ })
+
+ assert(numReceivedServerMessages == 10)
+ assert(numReceivedMessages == 0)
+
+ manager.stop()
+ managerServer.stop()
+ }
+
+ test("security mismatch password") {
+ val conf = new SparkConf
+ conf.set("spark.authenticate", "true")
+ conf.set("spark.authenticate.secret", "good")
+ val securityManager = new SecurityManager(conf)
+ val manager = new ConnectionManager(0, conf, securityManager)
+ var numReceivedMessages = 0
+
+ manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
+ numReceivedMessages += 1
+ None
+ })
+
+ val badconf = new SparkConf
+ badconf.set("spark.authenticate", "true")
+ badconf.set("spark.authenticate.secret", "bad")
+ val badsecurityManager = new SecurityManager(badconf)
+ val managerServer = new ConnectionManager(0, badconf, badsecurityManager)
+ var numReceivedServerMessages = 0
+
+ managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
+ numReceivedServerMessages += 1
+ None
+ })
+
+ val size = 10 * 1024 * 1024
+ val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
+ buffer.flip
+ val bufferMessage = Message.createBufferMessage(buffer.duplicate)
+ // Expect managerServer to close connection, which we'll report as an error:
+ intercept[IOException] {
+ Await.result(manager.sendMessageReliably(managerServer.id, bufferMessage), 10 seconds)
+ }
+
+ assert(numReceivedServerMessages == 0)
+ assert(numReceivedMessages == 0)
+
+ manager.stop()
+ managerServer.stop()
+ }
+
+ test("security mismatch auth off") {
+ val conf = new SparkConf
+ conf.set("spark.authenticate", "false")
+ conf.set("spark.authenticate.secret", "good")
+ val securityManager = new SecurityManager(conf)
+ val manager = new ConnectionManager(0, conf, securityManager)
+ var numReceivedMessages = 0
+
+ manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
+ numReceivedMessages += 1
+ None
+ })
+
+ val badconf = new SparkConf
+ badconf.set("spark.authenticate", "true")
+ badconf.set("spark.authenticate.secret", "good")
+ val badsecurityManager = new SecurityManager(badconf)
+ val managerServer = new ConnectionManager(0, badconf, badsecurityManager)
+ var numReceivedServerMessages = 0
+ managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
+ numReceivedServerMessages += 1
+ None
+ })
+
+ val size = 10 * 1024 * 1024
+ val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
+ buffer.flip
+ val bufferMessage = Message.createBufferMessage(buffer.duplicate)
+ (0 until 1).map(i => {
+ val bufferMessage = Message.createBufferMessage(buffer.duplicate)
+ manager.sendMessageReliably(managerServer.id, bufferMessage)
+ }).foreach(f => {
+ try {
+ val g = Await.result(f, 1 second)
+ assert(false)
+ } catch {
+ case i: IOException =>
+ assert(true)
+ case e: TimeoutException => {
+ // we should timeout here since the client can't do the negotiation
+ assert(true)
+ }
+ }
+ })
+
+ assert(numReceivedServerMessages == 0)
+ assert(numReceivedMessages == 0)
+ manager.stop()
+ managerServer.stop()
+ }
+
+ test("security auth off") {
+ val conf = new SparkConf
+ conf.set("spark.authenticate", "false")
+ val securityManager = new SecurityManager(conf)
+ val manager = new ConnectionManager(0, conf, securityManager)
+ var numReceivedMessages = 0
+
+ manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
+ numReceivedMessages += 1
+ None
+ })
+
+ val badconf = new SparkConf
+ badconf.set("spark.authenticate", "false")
+ val badsecurityManager = new SecurityManager(badconf)
+ val managerServer = new ConnectionManager(0, badconf, badsecurityManager)
+ var numReceivedServerMessages = 0
+
+ managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
+ numReceivedServerMessages += 1
+ None
+ })
+
+ val size = 10 * 1024 * 1024
+ val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
+ buffer.flip
+ val bufferMessage = Message.createBufferMessage(buffer.duplicate)
+ (0 until 10).map(i => {
+ val bufferMessage = Message.createBufferMessage(buffer.duplicate)
+ manager.sendMessageReliably(managerServer.id, bufferMessage)
+ }).foreach(f => {
+ try {
+ val g = Await.result(f, 1 second)
+ } catch {
+ case e: Exception => {
+ assert(false)
+ }
+ }
+ })
+ assert(numReceivedServerMessages == 10)
+ assert(numReceivedMessages == 0)
+
+ manager.stop()
+ managerServer.stop()
+ }
+
+ test("Ack error message") {
+ val conf = new SparkConf
+ conf.set("spark.authenticate", "false")
+ val securityManager = new SecurityManager(conf)
+ val manager = new ConnectionManager(0, conf, securityManager)
+ val managerServer = new ConnectionManager(0, conf, securityManager)
+ managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
+ throw new Exception
+ })
+
+ val size = 10 * 1024 * 1024
+ val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
+ buffer.flip
+ val bufferMessage = Message.createBufferMessage(buffer)
+
+ val future = manager.sendMessageReliably(managerServer.id, bufferMessage)
+
+ intercept[IOException] {
+ Await.result(future, 1 second)
+ }
+
+ manager.stop()
+ managerServer.stop()
+
+ }
+
+ test("sendMessageReliably timeout") {
+ val clientConf = new SparkConf
+ clientConf.set("spark.authenticate", "false")
+ val ackTimeout = 30
+ clientConf.set("spark.core.connection.ack.wait.timeout", s"${ackTimeout}")
+
+ val clientSecurityManager = new SecurityManager(clientConf)
+ val manager = new ConnectionManager(0, clientConf, clientSecurityManager)
+
+ val serverConf = new SparkConf
+ serverConf.set("spark.authenticate", "false")
+ val serverSecurityManager = new SecurityManager(serverConf)
+ val managerServer = new ConnectionManager(0, serverConf, serverSecurityManager)
+ managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
+    // Sleep for 3x the ack timeout (90 seconds) to simulate a slow or hung server.
+ Thread.sleep(ackTimeout * 3 * 1000)
+ None
+ })
+
+ val size = 10 * 1024 * 1024
+ val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
+ buffer.flip
+ val bufferMessage = Message.createBufferMessage(buffer.duplicate)
+
+ val future = manager.sendMessageReliably(managerServer.id, bufferMessage)
+
+    // The future should fail with an IOException within the 30 second ack timeout;
+    // otherwise Await.result throws a TimeoutException, which we do not expect here.
+ intercept[IOException] {
+ Await.result(future, (ackTimeout * 2) second)
+ }
+
+ manager.stop()
+ managerServer.stop()
+ }
+
+}
+
http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala
index 6061e54..ba47fe5 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala
@@ -25,6 +25,7 @@ import org.scalatest.FunSuite
import org.apache.spark.{SparkEnv, SparkContext, LocalSparkContext, SparkConf}
import org.apache.spark.executor.ShuffleWriteMetrics
+import org.apache.spark.network.{FileSegmentManagedBuffer, ManagedBuffer}
import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.shuffle.FileShuffleBlockManager
import org.apache.spark.storage.{ShuffleBlockId, FileSegment}
@@ -32,10 +33,12 @@ import org.apache.spark.storage.{ShuffleBlockId, FileSegment}
class HashShuffleManagerSuite extends FunSuite with LocalSparkContext {
private val testConf = new SparkConf(false)
- private def checkSegments(segment1: FileSegment, segment2: FileSegment) {
- assert (segment1.file.getCanonicalPath === segment2.file.getCanonicalPath)
- assert (segment1.offset === segment2.offset)
- assert (segment1.length === segment2.length)
+ private def checkSegments(expected: FileSegment, buffer: ManagedBuffer) {
+ assert(buffer.isInstanceOf[FileSegmentManagedBuffer])
+ val segment = buffer.asInstanceOf[FileSegmentManagedBuffer]
+ assert(expected.file.getCanonicalPath === segment.file.getCanonicalPath)
+ assert(expected.offset === segment.offset)
+ assert(expected.length === segment.length)
}
test("consolidated shuffle can write to shuffle group without messing existing offsets/lengths") {
@@ -95,14 +98,12 @@ class HashShuffleManagerSuite extends FunSuite with LocalSparkContext {
writer.commitAndClose()
}
// check before we register.
- checkSegments(shuffle2Segment, shuffleBlockManager.getBlockData(ShuffleBlockId(1, 2, 0)).left.get)
+ checkSegments(shuffle2Segment, shuffleBlockManager.getBlockData(ShuffleBlockId(1, 2, 0)))
shuffle3.releaseWriters(success = true)
- checkSegments(shuffle2Segment, shuffleBlockManager.getBlockData(ShuffleBlockId(1, 2, 0)).left.get)
+ checkSegments(shuffle2Segment, shuffleBlockManager.getBlockData(ShuffleBlockId(1, 2, 0)))
shuffleBlockManager.removeShuffle(1)
-
}
-
def writeToFile(file: File, numBytes: Int) {
val writer = new FileWriter(file, true)
for (i <- 0 until numBytes) writer.write(i)
http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala
deleted file mode 100644
index 3c86f6b..0000000
--- a/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala
+++ /dev/null
@@ -1,237 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.storage
-
-import java.io.IOException
-import java.nio.ByteBuffer
-
-import scala.collection.mutable.ArrayBuffer
-import scala.concurrent.future
-import scala.concurrent.ExecutionContext.Implicits.global
-
-import org.scalatest.{FunSuite, Matchers}
-
-import org.mockito.Mockito._
-import org.mockito.Matchers.{any, eq => meq}
-import org.mockito.stubbing.Answer
-import org.mockito.invocation.InvocationOnMock
-
-import org.apache.spark.storage.BlockFetcherIterator._
-import org.apache.spark.network.{ConnectionManager, Message}
-import org.apache.spark.executor.ShuffleReadMetrics
-
-class BlockFetcherIteratorSuite extends FunSuite with Matchers {
-
- test("block fetch from local fails using BasicBlockFetcherIterator") {
- val blockManager = mock(classOf[BlockManager])
- val connManager = mock(classOf[ConnectionManager])
- doReturn(connManager).when(blockManager).connectionManager
- doReturn(BlockManagerId("test-client", "test-client", 1)).when(blockManager).blockManagerId
-
- doReturn((48 * 1024 * 1024).asInstanceOf[Long]).when(blockManager).maxBytesInFlight
-
- val blIds = Array[BlockId](
- ShuffleBlockId(0,0,0),
- ShuffleBlockId(0,1,0),
- ShuffleBlockId(0,2,0),
- ShuffleBlockId(0,3,0),
- ShuffleBlockId(0,4,0))
-
- val optItr = mock(classOf[Option[Iterator[Any]]])
- val answer = new Answer[Option[Iterator[Any]]] {
- override def answer(invocation: InvocationOnMock) = Option[Iterator[Any]] {
- throw new Exception
- }
- }
-
- // 3rd block is going to fail
- doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(0)), any())
- doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(1)), any())
- doAnswer(answer).when(blockManager).getLocalShuffleFromDisk(meq(blIds(2)), any())
- doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(3)), any())
- doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(4)), any())
-
- val bmId = BlockManagerId("test-client", "test-client", 1)
- val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
- (bmId, blIds.map(blId => (blId, 1.asInstanceOf[Long])).toSeq)
- )
-
- val iterator = new BasicBlockFetcherIterator(blockManager, blocksByAddress, null,
- new ShuffleReadMetrics())
-
- iterator.initialize()
-
- // Without exhausting the iterator, the iterator should be lazy and not call getLocalShuffleFromDisk.
- verify(blockManager, times(0)).getLocalShuffleFromDisk(any(), any())
-
- assert(iterator.hasNext, "iterator should have 5 elements but actually has no elements")
- // the 2nd element of the tuple returned by iterator.next should be defined when fetching successfully
- assert(iterator.next()._2.isDefined, "1st element should be defined but is not actually defined")
- verify(blockManager, times(1)).getLocalShuffleFromDisk(any(), any())
-
- assert(iterator.hasNext, "iterator should have 5 elements but actually has 1 element")
- assert(iterator.next()._2.isDefined, "2nd element should be defined but is not actually defined")
- verify(blockManager, times(2)).getLocalShuffleFromDisk(any(), any())
-
- assert(iterator.hasNext, "iterator should have 5 elements but actually has 2 elements")
- // 3rd fetch should fail
- intercept[Exception] {
- iterator.next()
- }
- verify(blockManager, times(3)).getLocalShuffleFromDisk(any(), any())
- }
-
-
- test("block fetch from local succeed using BasicBlockFetcherIterator") {
- val blockManager = mock(classOf[BlockManager])
- val connManager = mock(classOf[ConnectionManager])
- doReturn(connManager).when(blockManager).connectionManager
- doReturn(BlockManagerId("test-client", "test-client", 1)).when(blockManager).blockManagerId
-
- doReturn((48 * 1024 * 1024).asInstanceOf[Long]).when(blockManager).maxBytesInFlight
-
- val blIds = Array[BlockId](
- ShuffleBlockId(0,0,0),
- ShuffleBlockId(0,1,0),
- ShuffleBlockId(0,2,0),
- ShuffleBlockId(0,3,0),
- ShuffleBlockId(0,4,0))
-
- val optItr = mock(classOf[Option[Iterator[Any]]])
-
- // All blocks should be fetched successfully
- doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(0)), any())
- doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(1)), any())
- doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(2)), any())
- doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(3)), any())
- doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(4)), any())
-
- val bmId = BlockManagerId("test-client", "test-client", 1)
- val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
- (bmId, blIds.map(blId => (blId, 1.asInstanceOf[Long])).toSeq)
- )
-
- val iterator = new BasicBlockFetcherIterator(blockManager, blocksByAddress, null,
- new ShuffleReadMetrics())
-
- iterator.initialize()
-
- // Without exhausting the iterator, the iterator should be lazy and not call getLocalShuffleFromDisk.
- verify(blockManager, times(0)).getLocalShuffleFromDisk(any(), any())
-
- assert(iterator.hasNext, "iterator should have 5 elements but actually has no elements")
- assert(iterator.next._2.isDefined, "All elements should be defined but 1st element is not actually defined")
- assert(iterator.hasNext, "iterator should have 5 elements but actually has 1 element")
- assert(iterator.next._2.isDefined, "All elements should be defined but 2nd element is not actually defined")
- assert(iterator.hasNext, "iterator should have 5 elements but actually has 2 elements")
- assert(iterator.next._2.isDefined, "All elements should be defined but 3rd element is not actually defined")
- assert(iterator.hasNext, "iterator should have 5 elements but actually has 3 elements")
- assert(iterator.next._2.isDefined, "All elements should be defined but 4th element is not actually defined")
- assert(iterator.hasNext, "iterator should have 5 elements but actually has 4 elements")
- assert(iterator.next._2.isDefined, "All elements should be defined but 5th element is not actually defined")
-
- verify(blockManager, times(5)).getLocalShuffleFromDisk(any(), any())
- }
-
- test("block fetch from remote fails using BasicBlockFetcherIterator") {
- val blockManager = mock(classOf[BlockManager])
- val connManager = mock(classOf[ConnectionManager])
- when(blockManager.connectionManager).thenReturn(connManager)
-
- val f = future {
- throw new IOException("Send failed or we received an error ACK")
- }
- when(connManager.sendMessageReliably(any(),
- any())).thenReturn(f)
- when(blockManager.futureExecContext).thenReturn(global)
-
- when(blockManager.blockManagerId).thenReturn(
- BlockManagerId("test-client", "test-client", 1))
- when(blockManager.maxBytesInFlight).thenReturn(48 * 1024 * 1024)
-
- val blId1 = ShuffleBlockId(0,0,0)
- val blId2 = ShuffleBlockId(0,1,0)
- val bmId = BlockManagerId("test-server", "test-server", 1)
- val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
- (bmId, Seq((blId1, 1L), (blId2, 1L)))
- )
-
- val iterator = new BasicBlockFetcherIterator(blockManager,
- blocksByAddress, null, new ShuffleReadMetrics())
-
- iterator.initialize()
- iterator.foreach{
- case (_, r) => {
- (!r.isDefined) should be(true)
- }
- }
- }
-
- test("block fetch from remote succeed using BasicBlockFetcherIterator") {
- val blockManager = mock(classOf[BlockManager])
- val connManager = mock(classOf[ConnectionManager])
- when(blockManager.connectionManager).thenReturn(connManager)
-
- val blId1 = ShuffleBlockId(0,0,0)
- val blId2 = ShuffleBlockId(0,1,0)
- val buf1 = ByteBuffer.allocate(4)
- val buf2 = ByteBuffer.allocate(4)
- buf1.putInt(1)
- buf1.flip()
- buf2.putInt(1)
- buf2.flip()
- val blockMessage1 = BlockMessage.fromGotBlock(GotBlock(blId1, buf1))
- val blockMessage2 = BlockMessage.fromGotBlock(GotBlock(blId2, buf2))
- val blockMessageArray = new BlockMessageArray(
- Seq(blockMessage1, blockMessage2))
-
- val bufferMessage = blockMessageArray.toBufferMessage
- val buffer = ByteBuffer.allocate(bufferMessage.size)
- val arrayBuffer = new ArrayBuffer[ByteBuffer]
- bufferMessage.buffers.foreach{ b =>
- buffer.put(b)
- }
- buffer.flip()
- arrayBuffer += buffer
-
- val f = future {
- Message.createBufferMessage(arrayBuffer)
- }
- when(connManager.sendMessageReliably(any(),
- any())).thenReturn(f)
- when(blockManager.futureExecContext).thenReturn(global)
-
- when(blockManager.blockManagerId).thenReturn(
- BlockManagerId("test-client", "test-client", 1))
- when(blockManager.maxBytesInFlight).thenReturn(48 * 1024 * 1024)
-
- val bmId = BlockManagerId("test-server", "test-server", 1)
- val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
- (bmId, Seq((blId1, 1L), (blId2, 1L)))
- )
-
- val iterator = new BasicBlockFetcherIterator(blockManager,
- blocksByAddress, null, new ShuffleReadMetrics())
- iterator.initialize()
- iterator.foreach{
- case (_, r) => {
- (r.isDefined) should be(true)
- }
- }
- }
-}
[5/5] git commit: [SPARK-3019] Pluggable block transfer interface
(BlockTransferService)
Posted by rx...@apache.org.
[SPARK-3019] Pluggable block transfer interface (BlockTransferService)
This pull request creates a new BlockTransferService interface for block fetch/upload and refactors the existing ConnectionManager to implement BlockTransferService (NioBlockTransferService).
Most of the changes are simply moving code around. The main class to inspect is ShuffleBlockFetcherIterator.
Review guide:
- Most of the ConnectionManager code is now in the network.nio package
- ManagedBuffer is a new buffer abstraction backed by several different implementations (file segment, nio ByteBuffer, Netty ByteBuf)
- BlockTransferService is the main internal interface introduced in this PR (see the usage sketch after this list)
- NioBlockTransferService implements BlockTransferService and replaces the old BlockManagerWorker
- ShuffleBlockFetcherIterator replaces the old BlockFetcherIterator to use the new interface
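A minimal caller-side sketch of the new interface, pieced together from the fetchBlocks and BlockFetchingListener signatures added in this PR; the host, port, and block ids below are illustrative placeholders, not values from the patch:

  import org.apache.spark.network.{BlockFetchingListener, BlockTransferService, ManagedBuffer}

  def fetchShuffleBlocks(transfer: BlockTransferService): Unit = {
    transfer.fetchBlocks("remote-host", 12345, Seq("shuffle_0_0_0", "shuffle_0_1_0"),
      new BlockFetchingListener {
        // Called once per block, as soon as that block's data arrives.
        override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit =
          println("fetched block " + blockId)
        // Called once per failure, not once per block.
        override def onBlockFetchFailure(exception: Throwable): Unit =
          exception.printStackTrace()
      })
  }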
TODOs that should be separate PRs:
- Implement NettyBlockTransferService
- Finalize the API/semantics for ManagedBuffer.release()
Author: Reynold Xin <rx...@apache.org>
Closes #2240 from rxin/blockTransferService and squashes the following commits:
64cd9d7 [Reynold Xin] Merge branch 'master' into blockTransferService
1dfd3d7 [Reynold Xin] Limit the length of the FileInputStream.
1332156 [Reynold Xin] Fixed style violation from refactoring.
2960c93 [Reynold Xin] Added ShuffleBlockFetcherIteratorSuite.
e29c721 [Reynold Xin] Updated comment for ShuffleBlockFetcherIterator.
8a1046e [Reynold Xin] Code review feedback:
2c6b1e1 [Reynold Xin] Removed println in test cases.
2a907e4 [Reynold Xin] Merge branch 'master' into blockTransferService-merge
07ccf0d [Reynold Xin] Added init check to CMBlockTransferService.
98c668a [Reynold Xin] Added failure handling and fixed unit tests.
ae05fcd [Reynold Xin] Updated tests, although DistributedSuite is hanging.
d8d595c [Reynold Xin] Merge branch 'master' of github.com:apache/spark into blockTransferService
9ef279c [Reynold Xin] Initial refactoring to move ConnectionManager to use the BlockTransferService.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/08ce1888
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/08ce1888
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/08ce1888
Branch: refs/heads/master
Commit: 08ce18881e09c6e91db9c410d1d9ce1e5ae63a62
Parents: 939a322
Author: Reynold Xin <rx...@apache.org>
Authored: Mon Sep 8 15:59:20 2014 -0700
Committer: Reynold Xin <rx...@apache.org>
Committed: Mon Sep 8 15:59:20 2014 -0700
----------------------------------------------------------------------
.../main/scala/org/apache/spark/SparkEnv.scala | 15 +-
.../apache/spark/network/BlockDataManager.scala | 36 +
.../spark/network/BlockFetchingListener.scala | 37 +
.../spark/network/BlockTransferService.scala | 131 +++
.../apache/spark/network/BufferMessage.scala | 113 --
.../org/apache/spark/network/Connection.scala | 587 ----------
.../org/apache/spark/network/ConnectionId.scala | 34 -
.../spark/network/ConnectionManager.scala | 1047 ------------------
.../spark/network/ConnectionManagerId.scala | 37 -
.../spark/network/ConnectionManagerTest.scala | 103 --
.../apache/spark/network/ManagedBuffer.scala | 107 ++
.../org/apache/spark/network/Message.scala | 95 --
.../org/apache/spark/network/MessageChunk.scala | 41 -
.../spark/network/MessageChunkHeader.scala | 82 --
.../org/apache/spark/network/ReceiverTest.scala | 37 -
.../apache/spark/network/SecurityMessage.scala | 162 ---
.../org/apache/spark/network/SenderTest.scala | 76 --
.../apache/spark/network/nio/BlockMessage.scala | 197 ++++
.../spark/network/nio/BlockMessageArray.scala | 160 +++
.../spark/network/nio/BufferMessage.scala | 114 ++
.../apache/spark/network/nio/Connection.scala | 587 ++++++++++
.../apache/spark/network/nio/ConnectionId.scala | 34 +
.../spark/network/nio/ConnectionManager.scala | 1042 +++++++++++++++++
.../spark/network/nio/ConnectionManagerId.scala | 37 +
.../org/apache/spark/network/nio/Message.scala | 96 ++
.../apache/spark/network/nio/MessageChunk.scala | 41 +
.../spark/network/nio/MessageChunkHeader.scala | 81 ++
.../network/nio/NioBlockTransferService.scala | 205 ++++
.../spark/network/nio/SecurityMessage.scala | 160 +++
.../spark/serializer/KryoSerializer.scala | 2 +-
.../spark/shuffle/FileShuffleBlockManager.scala | 35 +-
.../shuffle/IndexShuffleBlockManager.scala | 24 +-
.../spark/shuffle/ShuffleBlockManager.scala | 6 +-
.../shuffle/hash/BlockStoreShuffleFetcher.scala | 14 +-
.../spark/shuffle/hash/HashShuffleReader.scala | 4 +-
.../spark/storage/BlockFetcherIterator.scala | 254 -----
.../org/apache/spark/storage/BlockManager.scala | 98 +-
.../apache/spark/storage/BlockManagerId.scala | 4 +-
.../spark/storage/BlockManagerWorker.scala | 147 ---
.../org/apache/spark/storage/BlockMessage.scala | 209 ----
.../spark/storage/BlockMessageArray.scala | 160 ---
.../storage/ShuffleBlockFetcherIterator.scala | 271 +++++
.../apache/spark/storage/ThreadingTest.scala | 120 --
.../org/apache/spark/DistributedSuite.scala | 15 +-
.../spark/network/ConnectionManagerSuite.scala | 301 -----
.../network/nio/ConnectionManagerSuite.scala | 296 +++++
.../shuffle/hash/HashShuffleManagerSuite.scala | 17 +-
.../storage/BlockFetcherIteratorSuite.scala | 237 ----
.../spark/storage/BlockManagerSuite.scala | 133 +--
.../spark/storage/DiskBlockManagerSuite.scala | 2 +-
.../ShuffleBlockFetcherIteratorSuite.scala | 183 +++
51 files changed, 3941 insertions(+), 4085 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/SparkEnv.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index 20a7444..dd95e40 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -31,7 +31,8 @@ import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.api.python.PythonWorkerFactory
import org.apache.spark.broadcast.BroadcastManager
import org.apache.spark.metrics.MetricsSystem
-import org.apache.spark.network.ConnectionManager
+import org.apache.spark.network.BlockTransferService
+import org.apache.spark.network.nio.NioBlockTransferService
import org.apache.spark.scheduler.LiveListenerBus
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.{ShuffleMemoryManager, ShuffleManager}
@@ -59,8 +60,8 @@ class SparkEnv (
val mapOutputTracker: MapOutputTracker,
val shuffleManager: ShuffleManager,
val broadcastManager: BroadcastManager,
+ val blockTransferService: BlockTransferService,
val blockManager: BlockManager,
- val connectionManager: ConnectionManager,
val securityManager: SecurityManager,
val httpFileServer: HttpFileServer,
val sparkFilesDir: String,
@@ -88,6 +89,8 @@ class SparkEnv (
// down, but let's call it anyway in case it gets fixed in a later release
// UPDATE: In Akka 2.1.x, this hangs if there are remote actors, so we can't call it.
// actorSystem.awaitTermination()
+
+ // Note that blockTransferService is stopped by BlockManager, since BlockManager is what starts it.
}
private[spark]
@@ -223,14 +226,14 @@ object SparkEnv extends Logging {
val shuffleMemoryManager = new ShuffleMemoryManager(conf)
+ val blockTransferService = new NioBlockTransferService(conf, securityManager)
+
val blockManagerMaster = new BlockManagerMaster(registerOrLookup(
"BlockManagerMaster",
new BlockManagerMasterActor(isLocal, conf, listenerBus)), conf, isDriver)
val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster,
- serializer, conf, securityManager, mapOutputTracker, shuffleManager)
-
- val connectionManager = blockManager.connectionManager
+ serializer, conf, mapOutputTracker, shuffleManager, blockTransferService)
val broadcastManager = new BroadcastManager(isDriver, conf, securityManager)
@@ -278,8 +281,8 @@ object SparkEnv extends Logging {
mapOutputTracker,
shuffleManager,
broadcastManager,
+ blockTransferService,
blockManager,
- connectionManager,
securityManager,
httpFileServer,
sparkFilesDir,
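Because BlockManager now receives the transfer service as a constructor argument (see the SparkEnv hunk above), an alternative implementation can be swapped in without touching BlockManager itself. A hypothetical sketch of what that pluggability enables; the config key and the NettyBlockTransferService class below do not exist in this commit (the latter is only listed under the TODOs):

  // Hypothetical wiring; only NioBlockTransferService exists at this point.
  val blockTransferService: BlockTransferService =
    conf.get("spark.shuffle.blockTransferService", "nio") match {
      case "netty" => new NettyBlockTransferService(conf)  // hypothetical future implementation
      case _ => new NioBlockTransferService(conf, securityManager)
    }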
http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala
new file mode 100644
index 0000000..e0e9172
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network
+
+import org.apache.spark.storage.StorageLevel
+
+
+trait BlockDataManager {
+
+ /**
+ * Interface to get local block data.
+ *
+ * @return Some(buffer) if the block exists locally, and None if it doesn't.
+ */
+ def getBlockData(blockId: String): Option[ManagedBuffer]
+
+ /**
+ * Put the block locally, using the given storage level.
+ */
+ def putBlockData(blockId: String, data: ManagedBuffer, level: StorageLevel): Unit
+}
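For illustration, a toy in-memory BlockDataManager (not part of this commit) that satisfies the trait above, e.g. as a test double; synchronizing on the map is an assumption about how a caller might want thread safety:

  import scala.collection.mutable
  import org.apache.spark.network.{BlockDataManager, ManagedBuffer}
  import org.apache.spark.storage.StorageLevel

  class InMemoryBlockDataManager extends BlockDataManager {
    private val blocks = mutable.Map.empty[String, ManagedBuffer]

    // Returns Some(buffer) if the block exists locally, None otherwise.
    override def getBlockData(blockId: String): Option[ManagedBuffer] =
      blocks.synchronized { blocks.get(blockId) }

    // The storage level is ignored in this toy sketch; everything lives in memory.
    override def putBlockData(blockId: String, data: ManagedBuffer, level: StorageLevel): Unit =
      blocks.synchronized { blocks(blockId) = data }
  }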
http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala b/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala
new file mode 100644
index 0000000..34acaa5
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network
+
+import java.util.EventListener
+
+
+/**
+ * Listener callback interface for [[BlockTransferService.fetchBlocks]].
+ */
+trait BlockFetchingListener extends EventListener {
+
+ /**
+ * Called once per successfully fetched block.
+ */
+ def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit
+
+ /**
+ * Called upon failures. For each failure, this is called only once (i.e. not once per block).
+ */
+ def onBlockFetchFailure(exception: Throwable): Unit
+}
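As a sketch of the listener contract above (one success callback per block, one failure callback per failure), here is a hypothetical listener, not part of this commit, that lets a caller block until a known number of blocks has arrived or any failure occurs:

  import java.util.concurrent.CountDownLatch
  import org.apache.spark.network.{BlockFetchingListener, ManagedBuffer}

  class BlockingFetchListener(numBlocks: Int) extends BlockFetchingListener {
    private val latch = new CountDownLatch(numBlocks)
    @volatile private var error: Throwable = null

    override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit =
      latch.countDown()

    // Fires once per failure (not per block), so release all remaining waiters.
    override def onBlockFetchFailure(exception: Throwable): Unit = {
      error = exception
      while (latch.getCount > 0) latch.countDown()
    }

    def await(): Unit = {
      latch.await()
      if (error != null) throw error
    }
  }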
http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala
new file mode 100644
index 0000000..84d991f
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network
+
+import scala.concurrent.{Await, Future}
+import scala.concurrent.duration.Duration
+
+import org.apache.spark.storage.StorageLevel
+
+
+abstract class BlockTransferService {
+
+ /**
+ * Initialize the transfer service by giving it the BlockDataManager that can be used to fetch
+ * local blocks or put local blocks.
+ */
+ def init(blockDataManager: BlockDataManager)
+
+ /**
+ * Tear down the transfer service.
+ */
+ def stop(): Unit
+
+ /**
+ * Port number the service is listening on, available only after [[init]] is invoked.
+ */
+ def port: Int
+
+ /**
+ * Host name the service is listening on, available only after [[init]] is invoked.
+ */
+ def hostName: String
+
+ /**
+ * Fetch a sequence of blocks from a remote node asynchronously,
+ * available only after [[init]] is invoked.
+ *
+ * Note that [[BlockFetchingListener.onBlockFetchSuccess]] is called once per block,
+ * while [[BlockFetchingListener.onBlockFetchFailure]] is called once per failure (not per block).
+ *
+ * Note that this API takes a sequence so the implementation can batch requests, and does not
+ * return a future so the underlying implementation can invoke onBlockFetchSuccess as soon as
+ * the data of a block is fetched, rather than waiting for all blocks to be fetched.
+ */
+ def fetchBlocks(
+ hostName: String,
+ port: Int,
+ blockIds: Seq[String],
+ listener: BlockFetchingListener): Unit
+
+ /**
+ * Upload a single block to a remote node, available only after [[init]] is invoked.
+ */
+ def uploadBlock(
+ hostname: String,
+ port: Int,
+ blockId: String,
+ blockData: ManagedBuffer,
+ level: StorageLevel): Future[Unit]
+
+ /**
+ * A special case of [[fetchBlocks]], as it fetches only one block and is blocking.
+ *
+ * It is also only available after [[init]] is invoked.
+ */
+ def fetchBlockSync(hostName: String, port: Int, blockId: String): ManagedBuffer = {
+ // A monitor for the thread to wait on.
+ val lock = new Object
+ @volatile var result: Either[ManagedBuffer, Throwable] = null
+ fetchBlocks(hostName, port, Seq(blockId), new BlockFetchingListener {
+ override def onBlockFetchFailure(exception: Throwable): Unit = {
+ lock.synchronized {
+ result = Right(exception)
+ lock.notify()
+ }
+ }
+ override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = {
+ lock.synchronized {
+ result = Left(data)
+ lock.notify()
+ }
+ }
+ })
+
+ // Wait on the monitor until one of the callbacks above sets result.
+ lock.synchronized {
+ while (result == null) {
+ try {
+ lock.wait()
+ } catch {
+ case e: InterruptedException => // ignore interrupts; the loop re-checks result
+ }
+ }
+ }
+
+ result match {
+ case Left(data) => data
+ case Right(e) => throw e
+ }
+ }
+
+ /**
+ * Upload a single block to a remote node, available only after [[init]] is invoked.
+ *
+ * This method is similar to [[uploadBlock]], except this one blocks the thread
+ * until the upload finishes.
+ */
+ def uploadBlockSync(
+ hostname: String,
+ port: Int,
+ blockId: String,
+ blockData: ManagedBuffer,
+ level: StorageLevel): Unit = {
+ Await.result(uploadBlock(hostname, port, blockId, blockData, level), Duration.Inf)
+ }
+}
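The blocking fetchBlockSync above bridges the asynchronous fetchBlocks with a plain object monitor. The same sync-over-async pattern can be expressed with scala.concurrent.Promise; the following is only an illustrative reformulation, not code from this commit:

  import scala.concurrent.{Await, Promise}
  import scala.concurrent.duration.Duration
  import org.apache.spark.network.{BlockFetchingListener, BlockTransferService, ManagedBuffer}

  def fetchBlockSyncViaPromise(
      transfer: BlockTransferService,
      host: String,
      port: Int,
      blockId: String): ManagedBuffer = {
    val promise = Promise[ManagedBuffer]()
    transfer.fetchBlocks(host, port, Seq(blockId), new BlockFetchingListener {
      override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit =
        promise.success(data)
      override def onBlockFetchFailure(exception: Throwable): Unit =
        promise.failure(exception)
    })
    // Rethrows the failure on the calling thread, just as fetchBlockSync does.
    Await.result(promise.future, Duration.Inf)
  }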
http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/network/BufferMessage.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/network/BufferMessage.scala b/core/src/main/scala/org/apache/spark/network/BufferMessage.scala
deleted file mode 100644
index af35f1f..0000000
--- a/core/src/main/scala/org/apache/spark/network/BufferMessage.scala
+++ /dev/null
@@ -1,113 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network
-
-import java.nio.ByteBuffer
-
-import scala.collection.mutable.ArrayBuffer
-
-import org.apache.spark.storage.BlockManager
-
-private[spark]
-class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: Int)
- extends Message(Message.BUFFER_MESSAGE, id_) {
-
- val initialSize = currentSize()
- var gotChunkForSendingOnce = false
-
- def size = initialSize
-
- def currentSize() = {
- if (buffers == null || buffers.isEmpty) {
- 0
- } else {
- buffers.map(_.remaining).reduceLeft(_ + _)
- }
- }
-
- def getChunkForSending(maxChunkSize: Int): Option[MessageChunk] = {
- if (maxChunkSize <= 0) {
- throw new Exception("Max chunk size is " + maxChunkSize)
- }
-
- val security = if (isSecurityNeg) 1 else 0
- if (size == 0 && !gotChunkForSendingOnce) {
- val newChunk = new MessageChunk(
- new MessageChunkHeader(typ, id, 0, 0, ackId, hasError, security, senderAddress), null)
- gotChunkForSendingOnce = true
- return Some(newChunk)
- }
-
- while(!buffers.isEmpty) {
- val buffer = buffers(0)
- if (buffer.remaining == 0) {
- BlockManager.dispose(buffer)
- buffers -= buffer
- } else {
- val newBuffer = if (buffer.remaining <= maxChunkSize) {
- buffer.duplicate()
- } else {
- buffer.slice().limit(maxChunkSize).asInstanceOf[ByteBuffer]
- }
- buffer.position(buffer.position + newBuffer.remaining)
- val newChunk = new MessageChunk(new MessageChunkHeader(
- typ, id, size, newBuffer.remaining, ackId,
- hasError, security, senderAddress), newBuffer)
- gotChunkForSendingOnce = true
- return Some(newChunk)
- }
- }
- None
- }
-
- def getChunkForReceiving(chunkSize: Int): Option[MessageChunk] = {
- // STRONG ASSUMPTION: BufferMessage created when receiving data has ONLY ONE data buffer
- if (buffers.size > 1) {
- throw new Exception("Attempting to get chunk from message with multiple data buffers")
- }
- val buffer = buffers(0)
- val security = if (isSecurityNeg) 1 else 0
- if (buffer.remaining > 0) {
- if (buffer.remaining < chunkSize) {
- throw new Exception("Not enough space in data buffer for receiving chunk")
- }
- val newBuffer = buffer.slice().limit(chunkSize).asInstanceOf[ByteBuffer]
- buffer.position(buffer.position + newBuffer.remaining)
- val newChunk = new MessageChunk(new MessageChunkHeader(
- typ, id, size, newBuffer.remaining, ackId, hasError, security, senderAddress), newBuffer)
- return Some(newChunk)
- }
- None
- }
-
- def flip() {
- buffers.foreach(_.flip)
- }
-
- def hasAckId() = (ackId != 0)
-
- def isCompletelyReceived() = !buffers(0).hasRemaining
-
- override def toString = {
- if (hasAckId) {
- "BufferAckMessage(aid = " + ackId + ", id = " + id + ", size = " + size + ")"
- } else {
- "BufferMessage(id = " + id + ", size = " + size + ")"
- }
- }
-}
http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/network/Connection.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/network/Connection.scala b/core/src/main/scala/org/apache/spark/network/Connection.scala
deleted file mode 100644
index 5285ec8..0000000
--- a/core/src/main/scala/org/apache/spark/network/Connection.scala
+++ /dev/null
@@ -1,587 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network
-
-import java.net._
-import java.nio._
-import java.nio.channels._
-
-import scala.collection.mutable.{ArrayBuffer, HashMap, Queue}
-
-import org.apache.spark._
-
-private[spark]
-abstract class Connection(val channel: SocketChannel, val selector: Selector,
- val socketRemoteConnectionManagerId: ConnectionManagerId, val connectionId: ConnectionId)
- extends Logging {
-
- var sparkSaslServer: SparkSaslServer = null
- var sparkSaslClient: SparkSaslClient = null
-
- def this(channel_ : SocketChannel, selector_ : Selector, id_ : ConnectionId) = {
- this(channel_, selector_,
- ConnectionManagerId.fromSocketAddress(
- channel_.socket.getRemoteSocketAddress.asInstanceOf[InetSocketAddress]), id_)
- }
-
- channel.configureBlocking(false)
- channel.socket.setTcpNoDelay(true)
- channel.socket.setReuseAddress(true)
- channel.socket.setKeepAlive(true)
- /* channel.socket.setReceiveBufferSize(32768) */
-
- @volatile private var closed = false
- var onCloseCallback: Connection => Unit = null
- var onExceptionCallback: (Connection, Exception) => Unit = null
- var onKeyInterestChangeCallback: (Connection, Int) => Unit = null
-
- val remoteAddress = getRemoteAddress()
-
- /**
- * Used to synchronize client requests: client's work-related requests must
- * wait until SASL authentication completes.
- */
- private val authenticated = new Object()
-
- def getAuthenticated(): Object = authenticated
-
- def isSaslComplete(): Boolean
-
- def resetForceReregister(): Boolean
-
- // Read channels typically do not register for write and write does not for read
- // Now, we do have write registering for read too (temporarily), but this is to detect
- // channel close NOT to actually read/consume data on it !
- // How does this work if/when we move to SSL ?
-
- // What is the interest to register with selector for when we want this connection to be selected
- def registerInterest()
-
- // What is the interest to register with selector for when we want this connection to
- // be de-selected
- // Traditionally, 0 - but in our case, for example, for close-detection on SendingConnection hack,
- // it will be SelectionKey.OP_READ (until we fix it properly)
- def unregisterInterest()
-
- // On receiving a read event, should we change the interest for this channel or not ?
- // Will be true for ReceivingConnection, false for SendingConnection.
- def changeInterestForRead(): Boolean
-
- private def disposeSasl() {
- if (sparkSaslServer != null) {
- sparkSaslServer.dispose()
- }
-
- if (sparkSaslClient != null) {
- sparkSaslClient.dispose()
- }
- }
-
- // On receiving a write event, should we change the interest for this channel or not ?
- // Will be false for ReceivingConnection, true for SendingConnection.
- // Actually, for now, should not get triggered for ReceivingConnection
- def changeInterestForWrite(): Boolean
-
- def getRemoteConnectionManagerId(): ConnectionManagerId = {
- socketRemoteConnectionManagerId
- }
-
- def key() = channel.keyFor(selector)
-
- def getRemoteAddress() = channel.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress]
-
- // Returns whether we have to register for further reads or not.
- def read(): Boolean = {
- throw new UnsupportedOperationException(
- "Cannot read on connection of type " + this.getClass.toString)
- }
-
- // Returns whether we have to register for further writes or not.
- def write(): Boolean = {
- throw new UnsupportedOperationException(
- "Cannot write on connection of type " + this.getClass.toString)
- }
-
- def close() {
- closed = true
- val k = key()
- if (k != null) {
- k.cancel()
- }
- channel.close()
- disposeSasl()
- callOnCloseCallback()
- }
-
- protected def isClosed: Boolean = closed
-
- def onClose(callback: Connection => Unit) {
- onCloseCallback = callback
- }
-
- def onException(callback: (Connection, Exception) => Unit) {
- onExceptionCallback = callback
- }
-
- def onKeyInterestChange(callback: (Connection, Int) => Unit) {
- onKeyInterestChangeCallback = callback
- }
-
- def callOnExceptionCallback(e: Exception) {
- if (onExceptionCallback != null) {
- onExceptionCallback(this, e)
- } else {
- logError("Error in connection to " + getRemoteConnectionManagerId() +
- " and OnExceptionCallback not registered", e)
- }
- }
-
- def callOnCloseCallback() {
- if (onCloseCallback != null) {
- onCloseCallback(this)
- } else {
- logWarning("Connection to " + getRemoteConnectionManagerId() +
- " closed and OnExceptionCallback not registered")
- }
-
- }
-
- def changeConnectionKeyInterest(ops: Int) {
- if (onKeyInterestChangeCallback != null) {
- onKeyInterestChangeCallback(this, ops)
- } else {
- throw new Exception("OnKeyInterestChangeCallback not registered")
- }
- }
-
- def printRemainingBuffer(buffer: ByteBuffer) {
- val bytes = new Array[Byte](buffer.remaining)
- val curPosition = buffer.position
- buffer.get(bytes)
- bytes.foreach(x => print(x + " "))
- buffer.position(curPosition)
- print(" (" + bytes.size + ")")
- }
-
- def printBuffer(buffer: ByteBuffer, position: Int, length: Int) {
- val bytes = new Array[Byte](length)
- val curPosition = buffer.position
- buffer.position(position)
- buffer.get(bytes)
- bytes.foreach(x => print(x + " "))
- print(" (" + position + ", " + length + ")")
- buffer.position(curPosition)
- }
-}
-
-
-private[spark]
-class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
- remoteId_ : ConnectionManagerId, id_ : ConnectionId)
- extends Connection(SocketChannel.open, selector_, remoteId_, id_) {
-
- def isSaslComplete(): Boolean = {
- if (sparkSaslClient != null) sparkSaslClient.isComplete() else false
- }
-
- private class Outbox {
- val messages = new Queue[Message]()
- val defaultChunkSize = 65536
- var nextMessageToBeUsed = 0
-
- def addMessage(message: Message) {
- messages.synchronized {
- /* messages += message */
- messages.enqueue(message)
- logDebug("Added [" + message + "] to outbox for sending to " +
- "[" + getRemoteConnectionManagerId() + "]")
- }
- }
-
- def getChunk(): Option[MessageChunk] = {
- messages.synchronized {
- while (!messages.isEmpty) {
- /* nextMessageToBeUsed = nextMessageToBeUsed % messages.size */
- /* val message = messages(nextMessageToBeUsed) */
- val message = messages.dequeue()
- val chunk = message.getChunkForSending(defaultChunkSize)
- if (chunk.isDefined) {
- messages.enqueue(message)
- nextMessageToBeUsed = nextMessageToBeUsed + 1
- if (!message.started) {
- logDebug(
- "Starting to send [" + message + "] to [" + getRemoteConnectionManagerId() + "]")
- message.started = true
- message.startTime = System.currentTimeMillis
- }
- logTrace(
- "Sending chunk from [" + message + "] to [" + getRemoteConnectionManagerId() + "]")
- return chunk
- } else {
- message.finishTime = System.currentTimeMillis
- logDebug("Finished sending [" + message + "] to [" + getRemoteConnectionManagerId() +
- "] in " + message.timeTaken )
- }
- }
- }
- None
- }
- }
-
- // outbox is used as a lock - ensure that it is always used as a leaf (since methods which
- // lock it are invoked in context of other locks)
- private val outbox = new Outbox()
- /*
- This is orthogonal to whether we have pending bytes to write or not - and satisfies a slightly
- different purpose. This flag is to see if we need to force reregister for write even when we
- do not have any pending bytes to write to socket.
- This can happen due to a race between adding pending buffers, and checking for existing of
- data as detailed in https://github.com/mesos/spark/pull/791
- */
- private var needForceReregister = false
-
- val currentBuffers = new ArrayBuffer[ByteBuffer]()
-
- /* channel.socket.setSendBufferSize(256 * 1024) */
-
- override def getRemoteAddress() = address
-
- val DEFAULT_INTEREST = SelectionKey.OP_READ
-
- override def registerInterest() {
- // Registering read too - does not really help in most cases, but for some
- // it does - so let us keep it for now.
- changeConnectionKeyInterest(SelectionKey.OP_WRITE | DEFAULT_INTEREST)
- }
-
- override def unregisterInterest() {
- changeConnectionKeyInterest(DEFAULT_INTEREST)
- }
-
- def send(message: Message) {
- outbox.synchronized {
- outbox.addMessage(message)
- needForceReregister = true
- }
- if (channel.isConnected) {
- registerInterest()
- }
- }
-
- // return previous value after resetting it.
- def resetForceReregister(): Boolean = {
- outbox.synchronized {
- val result = needForceReregister
- needForceReregister = false
- result
- }
- }
-
- // MUST be called within the selector loop
- def connect() {
- try{
- channel.register(selector, SelectionKey.OP_CONNECT)
- channel.connect(address)
- logInfo("Initiating connection to [" + address + "]")
- } catch {
- case e: Exception => {
- logError("Error connecting to " + address, e)
- callOnExceptionCallback(e)
- }
- }
- }
-
- def finishConnect(force: Boolean): Boolean = {
- try {
- // Typically, this should finish immediately since it was triggered by a connect
- // selection - though need not necessarily always complete successfully.
- val connected = channel.finishConnect
- if (!force && !connected) {
- logInfo(
- "finish connect failed [" + address + "], " + outbox.messages.size + " messages pending")
- return false
- }
-
- // Fallback to previous behavior - assume finishConnect completed
- // This will happen only when finishConnect failed for some repeated number of times
- // (10 or so)
- // Is highly unlikely unless there was an unclean close of socket, etc
- registerInterest()
- logInfo("Connected to [" + address + "], " + outbox.messages.size + " messages pending")
- } catch {
- case e: Exception => {
- logWarning("Error finishing connection to " + address, e)
- callOnExceptionCallback(e)
- }
- }
- true
- }
-
- override def write(): Boolean = {
- try {
- while (true) {
- if (currentBuffers.size == 0) {
- outbox.synchronized {
- outbox.getChunk() match {
- case Some(chunk) => {
- val buffers = chunk.buffers
- // If we have 'seen' pending messages, then reset flag - since we handle that as
- // normal registering of event (below)
- if (needForceReregister && buffers.exists(_.remaining() > 0)) resetForceReregister()
-
- currentBuffers ++= buffers
- }
- case None => {
- // changeConnectionKeyInterest(0)
- /* key.interestOps(0) */
- return false
- }
- }
- }
- }
-
- if (currentBuffers.size > 0) {
- val buffer = currentBuffers(0)
- val remainingBytes = buffer.remaining
- val writtenBytes = channel.write(buffer)
- if (buffer.remaining == 0) {
- currentBuffers -= buffer
- }
- if (writtenBytes < remainingBytes) {
- // re-register for write.
- return true
- }
- }
- }
- } catch {
- case e: Exception => {
- logWarning("Error writing in connection to " + getRemoteConnectionManagerId(), e)
- callOnExceptionCallback(e)
- close()
- return false
- }
- }
- // should not happen - to keep scala compiler happy
- true
- }
-
- // This is a hack to determine if remote socket was closed or not.
- // SendingConnection DOES NOT expect to receive any data - if it does, it is an error
- // For a bunch of cases, read will return -1 in case remote socket is closed : hence we
- // register for reads to determine that.
- override def read(): Boolean = {
- // We don't expect the other side to send anything; so, we just read to detect an error or EOF.
- try {
- val length = channel.read(ByteBuffer.allocate(1))
- if (length == -1) { // EOF
- close()
- } else if (length > 0) {
- logWarning(
- "Unexpected data read from SendingConnection to " + getRemoteConnectionManagerId())
- }
- } catch {
- case e: Exception =>
- logError("Exception while reading SendingConnection to " + getRemoteConnectionManagerId(),
- e)
- callOnExceptionCallback(e)
- close()
- }
-
- false
- }
-
- override def changeInterestForRead(): Boolean = false
-
- override def changeInterestForWrite(): Boolean = ! isClosed
-}
-
-
-// Must be created within selector loop - else deadlock
-private[spark] class ReceivingConnection(
- channel_ : SocketChannel,
- selector_ : Selector,
- id_ : ConnectionId)
- extends Connection(channel_, selector_, id_) {
-
- def isSaslComplete(): Boolean = {
- if (sparkSaslServer != null) sparkSaslServer.isComplete() else false
- }
-
- class Inbox() {
- val messages = new HashMap[Int, BufferMessage]()
-
- def getChunk(header: MessageChunkHeader): Option[MessageChunk] = {
-
- def createNewMessage: BufferMessage = {
- val newMessage = Message.create(header).asInstanceOf[BufferMessage]
- newMessage.started = true
- newMessage.startTime = System.currentTimeMillis
- newMessage.isSecurityNeg = header.securityNeg == 1
- logDebug(
- "Starting to receive [" + newMessage + "] from [" + getRemoteConnectionManagerId() + "]")
- messages += ((newMessage.id, newMessage))
- newMessage
- }
-
- val message = messages.getOrElseUpdate(header.id, createNewMessage)
- logTrace(
- "Receiving chunk of [" + message + "] from [" + getRemoteConnectionManagerId() + "]")
- message.getChunkForReceiving(header.chunkSize)
- }
-
- def getMessageForChunk(chunk: MessageChunk): Option[BufferMessage] = {
- messages.get(chunk.header.id)
- }
-
- def removeMessage(message: Message) {
- messages -= message.id
- }
- }
-
- @volatile private var inferredRemoteManagerId: ConnectionManagerId = null
-
- override def getRemoteConnectionManagerId(): ConnectionManagerId = {
- val currId = inferredRemoteManagerId
- if (currId != null) currId else super.getRemoteConnectionManagerId()
- }
-
- // The receiver's remote address is the local socket on remote side : which is NOT
- // the connection manager id of the receiver.
- // We infer that from the messages we receive on the receiver socket.
- private def processConnectionManagerId(header: MessageChunkHeader) {
- val currId = inferredRemoteManagerId
- if (header.address == null || currId != null) return
-
- val managerId = ConnectionManagerId.fromSocketAddress(header.address)
-
- if (managerId != null) {
- inferredRemoteManagerId = managerId
- }
- }
-
-
- val inbox = new Inbox()
- val headerBuffer: ByteBuffer = ByteBuffer.allocate(MessageChunkHeader.HEADER_SIZE)
- var onReceiveCallback: (Connection, Message) => Unit = null
- var currentChunk: MessageChunk = null
-
- channel.register(selector, SelectionKey.OP_READ)
-
- override def read(): Boolean = {
- try {
- while (true) {
- if (currentChunk == null) {
- val headerBytesRead = channel.read(headerBuffer)
- if (headerBytesRead == -1) {
- close()
- return false
- }
- if (headerBuffer.remaining > 0) {
- // re-register for read event ...
- return true
- }
- headerBuffer.flip
- if (headerBuffer.remaining != MessageChunkHeader.HEADER_SIZE) {
- throw new Exception(
- "Unexpected number of bytes (" + headerBuffer.remaining + ") in the header")
- }
- val header = MessageChunkHeader.create(headerBuffer)
- headerBuffer.clear()
-
- processConnectionManagerId(header)
-
- header.typ match {
- case Message.BUFFER_MESSAGE => {
- if (header.totalSize == 0) {
- if (onReceiveCallback != null) {
- onReceiveCallback(this, Message.create(header))
- }
- currentChunk = null
- // re-register for read event ...
- return true
- } else {
- currentChunk = inbox.getChunk(header).orNull
- }
- }
- case _ => throw new Exception("Message of unknown type received")
- }
- }
-
- if (currentChunk == null) throw new Exception("No message chunk to receive data")
-
- val bytesRead = channel.read(currentChunk.buffer)
- if (bytesRead == 0) {
- // re-register for read event ...
- return true
- } else if (bytesRead == -1) {
- close()
- return false
- }
-
- /* logDebug("Read " + bytesRead + " bytes for the buffer") */
-
- if (currentChunk.buffer.remaining == 0) {
- /* println("Filled buffer at " + System.currentTimeMillis) */
- val bufferMessage = inbox.getMessageForChunk(currentChunk).get
- if (bufferMessage.isCompletelyReceived) {
- bufferMessage.flip()
- bufferMessage.finishTime = System.currentTimeMillis
- logDebug("Finished receiving [" + bufferMessage + "] from " +
- "[" + getRemoteConnectionManagerId() + "] in " + bufferMessage.timeTaken)
- if (onReceiveCallback != null) {
- onReceiveCallback(this, bufferMessage)
- }
- inbox.removeMessage(bufferMessage)
- }
- currentChunk = null
- }
- }
- } catch {
- case e: Exception => {
- logWarning("Error reading from connection to " + getRemoteConnectionManagerId(), e)
- callOnExceptionCallback(e)
- close()
- return false
- }
- }
- // should not happen - to keep scala compiler happy
- true
- }
-
- def onReceive(callback: (Connection, Message) => Unit) {onReceiveCallback = callback}
-
- // override def changeInterestForRead(): Boolean = ! isClosed
- override def changeInterestForRead(): Boolean = true
-
- override def changeInterestForWrite(): Boolean = {
- throw new IllegalStateException("Unexpected invocation right now")
- }
-
- override def registerInterest() {
- // Registering read too - does not really help in most cases, but for some
- // it does - so let us keep it for now.
- changeConnectionKeyInterest(SelectionKey.OP_READ)
- }
-
- override def unregisterInterest() {
- changeConnectionKeyInterest(0)
- }
-
- // For read conn, always false.
- override def resetForceReregister(): Boolean = false
-}
http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/network/ConnectionId.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionId.scala b/core/src/main/scala/org/apache/spark/network/ConnectionId.scala
deleted file mode 100644
index d579c16..0000000
--- a/core/src/main/scala/org/apache/spark/network/ConnectionId.scala
+++ /dev/null
@@ -1,34 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network
-
-private[spark] case class ConnectionId(connectionManagerId: ConnectionManagerId, uniqId: Int) {
- override def toString = connectionManagerId.host + "_" + connectionManagerId.port + "_" + uniqId
-}
-
-private[spark] object ConnectionId {
-
- def createConnectionIdFromString(connectionIdString: String): ConnectionId = {
- val res = connectionIdString.split("_").map(_.trim())
- if (res.size != 3) {
- throw new Exception("Error converting ConnectionId string: " + connectionIdString +
- " to a ConnectionId Object")
- }
- new ConnectionId(new ConnectionManagerId(res(0), res(1).toInt), res(2).toInt)
- }
-}
http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala
deleted file mode 100644
index 578d806..0000000
--- a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala
+++ /dev/null
@@ -1,1047 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network
-
-import java.io.IOException
-import java.nio._
-import java.nio.channels._
-import java.nio.channels.spi._
-import java.net._
-import java.util.{Timer, TimerTask}
-import java.util.concurrent.atomic.AtomicInteger
-
-import java.util.concurrent.{LinkedBlockingDeque, TimeUnit, ThreadPoolExecutor}
-
-import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.HashMap
-import scala.collection.mutable.HashSet
-import scala.collection.mutable.SynchronizedMap
-import scala.collection.mutable.SynchronizedQueue
-
-import scala.concurrent.{Await, ExecutionContext, Future, Promise}
-import scala.concurrent.duration._
-import scala.language.postfixOps
-
-import org.apache.spark._
-import org.apache.spark.util.{SystemClock, Utils}
-
-private[spark] class ConnectionManager(
- port: Int,
- conf: SparkConf,
- securityManager: SecurityManager,
- name: String = "Connection manager")
- extends Logging {
-
- /**
- * Used by sendMessageReliably to track messages being sent.
- * @param message the message that was sent
- * @param connectionManagerId the connection manager that sent this message
- * @param completionHandler callback that's invoked when the send has completed or failed
- */
- class MessageStatus(
- val message: Message,
- val connectionManagerId: ConnectionManagerId,
- completionHandler: MessageStatus => Unit) {
-
- /** This is non-None if message has been ack'd */
- var ackMessage: Option[Message] = None
-
- def markDone(ackMessage: Option[Message]) {
- this.ackMessage = ackMessage
- completionHandler(this)
- }
- }
-
- private val selector = SelectorProvider.provider.openSelector()
- private val ackTimeoutMonitor = new Timer("AckTimeoutMonitor", true)
-
- // default to 30 second timeout waiting for authentication
- private val authTimeout = conf.getInt("spark.core.connection.auth.wait.timeout", 30)
- private val ackTimeout = conf.getInt("spark.core.connection.ack.wait.timeout", 60)
-
- private val handleMessageExecutor = new ThreadPoolExecutor(
- conf.getInt("spark.core.connection.handler.threads.min", 20),
- conf.getInt("spark.core.connection.handler.threads.max", 60),
- conf.getInt("spark.core.connection.handler.threads.keepalive", 60), TimeUnit.SECONDS,
- new LinkedBlockingDeque[Runnable](),
- Utils.namedThreadFactory("handle-message-executor"))
-
- private val handleReadWriteExecutor = new ThreadPoolExecutor(
- conf.getInt("spark.core.connection.io.threads.min", 4),
- conf.getInt("spark.core.connection.io.threads.max", 32),
- conf.getInt("spark.core.connection.io.threads.keepalive", 60), TimeUnit.SECONDS,
- new LinkedBlockingDeque[Runnable](),
- Utils.namedThreadFactory("handle-read-write-executor"))
-
- // Use a different, yet smaller, thread pool - infrequently used with very short lived tasks :
- // which should be executed asap
- private val handleConnectExecutor = new ThreadPoolExecutor(
- conf.getInt("spark.core.connection.connect.threads.min", 1),
- conf.getInt("spark.core.connection.connect.threads.max", 8),
- conf.getInt("spark.core.connection.connect.threads.keepalive", 60), TimeUnit.SECONDS,
- new LinkedBlockingDeque[Runnable](),
- Utils.namedThreadFactory("handle-connect-executor"))
-
- private val serverChannel = ServerSocketChannel.open()
- // used to track the SendingConnections waiting to do SASL negotiation
- private val connectionsAwaitingSasl = new HashMap[ConnectionId, SendingConnection]
- with SynchronizedMap[ConnectionId, SendingConnection]
- private val connectionsByKey =
- new HashMap[SelectionKey, Connection] with SynchronizedMap[SelectionKey, Connection]
- private val connectionsById = new HashMap[ConnectionManagerId, SendingConnection]
- with SynchronizedMap[ConnectionManagerId, SendingConnection]
- private val messageStatuses = new HashMap[Int, MessageStatus]
- private val keyInterestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)]
- private val registerRequests = new SynchronizedQueue[SendingConnection]
-
- implicit val futureExecContext = ExecutionContext.fromExecutor(
- Utils.newDaemonCachedThreadPool("Connection manager future execution context"))
-
- @volatile
- private var onReceiveCallback: (BufferMessage, ConnectionManagerId) => Option[Message] = null
-
- private val authEnabled = securityManager.isAuthenticationEnabled()
-
- serverChannel.configureBlocking(false)
- serverChannel.socket.setReuseAddress(true)
- serverChannel.socket.setReceiveBufferSize(256 * 1024)
-
- private def startService(port: Int): (ServerSocketChannel, Int) = {
- serverChannel.socket.bind(new InetSocketAddress(port))
- (serverChannel, serverChannel.socket.getLocalPort)
- }
- Utils.startServiceOnPort[ServerSocketChannel](port, startService, name)
- serverChannel.register(selector, SelectionKey.OP_ACCEPT)
-
- val id = new ConnectionManagerId(Utils.localHostName, serverChannel.socket.getLocalPort)
- logInfo("Bound socket to port " + serverChannel.socket.getLocalPort() + " with id = " + id)
-
- // used in combination with the ConnectionManagerId to create unique Connection ids
- // to be able to track asynchronous messages
- private val idCount: AtomicInteger = new AtomicInteger(1)
-
- private val selectorThread = new Thread("connection-manager-thread") {
- override def run() = ConnectionManager.this.run()
- }
- selectorThread.setDaemon(true)
- selectorThread.start()
-
- private val writeRunnableStarted: HashSet[SelectionKey] = new HashSet[SelectionKey]()
-
- private def triggerWrite(key: SelectionKey) {
- val conn = connectionsByKey.getOrElse(key, null)
- if (conn == null) return
-
- writeRunnableStarted.synchronized {
- // So that we do not trigger more write events while processing this one.
- // The write method will re-register when done.
- if (conn.changeInterestForWrite()) conn.unregisterInterest()
- if (writeRunnableStarted.contains(key)) {
- // key.interestOps(key.interestOps() & ~ SelectionKey.OP_WRITE)
- return
- }
-
- writeRunnableStarted += key
- }
- handleReadWriteExecutor.execute(new Runnable {
- override def run() {
- var register: Boolean = false
- try {
- register = conn.write()
- } finally {
- writeRunnableStarted.synchronized {
- writeRunnableStarted -= key
- val needReregister = register || conn.resetForceReregister()
- if (needReregister && conn.changeInterestForWrite()) {
- conn.registerInterest()
- }
- }
- }
- }
- } )
- }
-
- private val readRunnableStarted: HashSet[SelectionKey] = new HashSet[SelectionKey]()
-
- private def triggerRead(key: SelectionKey) {
- val conn = connectionsByKey.getOrElse(key, null)
- if (conn == null) return
-
- readRunnableStarted.synchronized {
- // So that we do not trigger more read events while processing this one.
- // The read method will re-register when done.
- if (conn.changeInterestForRead()) conn.unregisterInterest()
- if (readRunnableStarted.contains(key)) {
- return
- }
-
- readRunnableStarted += key
- }
- handleReadWriteExecutor.execute(new Runnable {
- override def run() {
- var register: Boolean = false
- try {
- register = conn.read()
- } finally {
- readRunnableStarted.synchronized {
- readRunnableStarted -= key
- if (register && conn.changeInterestForRead()) {
- conn.registerInterest()
- }
- }
- }
- }
- } )
- }
-
- private def triggerConnect(key: SelectionKey) {
- val conn = connectionsByKey.getOrElse(key, null).asInstanceOf[SendingConnection]
- if (conn == null) return
-
- // prevent other events from being triggered
- // Since we are still trying to connect, we do not need to do the additional steps in
- // triggerWrite
- conn.changeConnectionKeyInterest(0)
-
- handleConnectExecutor.execute(new Runnable {
- override def run() {
-
- var tries: Int = 10
- while (tries >= 0) {
- if (conn.finishConnect(false)) return
- // Sleep briefly before retrying.
- Thread.sleep(1)
- tries -= 1
- }
-
- // Fall back to previous behavior: we should not normally get here, since this method is
- // only triggered once the channel becomes connectable. At times, though, the first
- // finishConnect does not succeed - hence the retry loop above.
- conn.finishConnect(true)
- }
- } )
- }
-
- // MUST be called within selector loop - else deadlock.
- private def triggerForceCloseByException(key: SelectionKey, e: Exception) {
- try {
- key.interestOps(0)
- } catch {
- // ignore exceptions
- case e: Exception => logDebug("Ignoring exception", e)
- }
-
- val conn = connectionsByKey.getOrElse(key, null)
- if (conn == null) return
-
- // Pushing to connect threadpool
- handleConnectExecutor.execute(new Runnable {
- override def run() {
- try {
- conn.callOnExceptionCallback(e)
- } catch {
- // ignore exceptions
- case e: Exception => logDebug("Ignoring exception", e)
- }
- try {
- conn.close()
- } catch {
- // ignore exceptions
- case e: Exception => logDebug("Ignoring exception", e)
- }
- }
- })
- }
-
-
- def run() {
- try {
- while (!selectorThread.isInterrupted) {
- while (!registerRequests.isEmpty) {
- val conn: SendingConnection = registerRequests.dequeue()
- addListeners(conn)
- conn.connect()
- addConnection(conn)
- }
-
- while (!keyInterestChangeRequests.isEmpty) {
- val (key, ops) = keyInterestChangeRequests.dequeue()
-
- try {
- if (key.isValid) {
- val connection = connectionsByKey.getOrElse(key, null)
- if (connection != null) {
- val lastOps = key.interestOps()
- key.interestOps(ops)
-
- // hot loop - prevent materialization of string if trace not enabled.
- if (isTraceEnabled()) {
- def intToOpStr(op: Int): String = {
- val opStrs = ArrayBuffer[String]()
- if ((op & SelectionKey.OP_READ) != 0) opStrs += "READ"
- if ((op & SelectionKey.OP_WRITE) != 0) opStrs += "WRITE"
- if ((op & SelectionKey.OP_CONNECT) != 0) opStrs += "CONNECT"
- if ((op & SelectionKey.OP_ACCEPT) != 0) opStrs += "ACCEPT"
- if (opStrs.size > 0) opStrs.reduceLeft(_ + " | " + _) else " "
- }
-
- logTrace("Changed key for connection to [" +
- connection.getRemoteConnectionManagerId() + "] changed from [" +
- intToOpStr(lastOps) + "] to [" + intToOpStr(ops) + "]")
- }
- }
- } else {
- logInfo("Key not valid ? " + key)
- throw new CancelledKeyException()
- }
- } catch {
- case e: CancelledKeyException => {
- logInfo("key already cancelled ? " + key, e)
- triggerForceCloseByException(key, e)
- }
- case e: Exception => {
- logError("Exception processing key " + key, e)
- triggerForceCloseByException(key, e)
- }
- }
- }
-
- val selectedKeysCount =
- try {
- selector.select()
- } catch {
- // Explicitly only dealing with CancelledKeyException here since other exceptions
- // should be dealt with differently.
- case e: CancelledKeyException => {
- // Some keys within the selectors list are invalid/closed. clear them.
- val allKeys = selector.keys().iterator()
-
- while (allKeys.hasNext) {
- val key = allKeys.next()
- try {
- if (! key.isValid) {
- logInfo("Key not valid ? " + key)
- throw new CancelledKeyException()
- }
- } catch {
- case e: CancelledKeyException => {
- logInfo("key already cancelled ? " + key, e)
- triggerForceCloseByException(key, e)
- }
- case e: Exception => {
- logError("Exception processing key " + key, e)
- triggerForceCloseByException(key, e)
- }
- }
- }
- }
- 0
- }
-
- if (selectedKeysCount == 0) {
- logDebug("Selector selected " + selectedKeysCount + " of " + selector.keys.size +
- " keys")
- }
- if (selectorThread.isInterrupted) {
- logInfo("Selector thread was interrupted!")
- return
- }
-
- if (selectedKeysCount != 0) {
- val selectedKeys = selector.selectedKeys().iterator()
- while (selectedKeys.hasNext) {
- val key = selectedKeys.next
- selectedKeys.remove()
- try {
- if (key.isValid) {
- if (key.isAcceptable) {
- acceptConnection(key)
- } else if (key.isConnectable) {
- triggerConnect(key)
- } else if (key.isReadable) {
- triggerRead(key)
- } else if (key.isWritable) {
- triggerWrite(key)
- }
- } else {
- logInfo("Key not valid ? " + key)
- throw new CancelledKeyException()
- }
- } catch {
- // weird, but we saw this happening - even though key.isValid was true,
- // key.isAcceptable would throw CancelledKeyException.
- case e: CancelledKeyException => {
- logInfo("key already cancelled ? " + key, e)
- triggerForceCloseByException(key, e)
- }
- case e: Exception => {
- logError("Exception processing key " + key, e)
- triggerForceCloseByException(key, e)
- }
- }
- }
- }
- }
- } catch {
- case e: Exception => logError("Error in select loop", e)
- }
- }
-
- def acceptConnection(key: SelectionKey) {
- val serverChannel = key.channel.asInstanceOf[ServerSocketChannel]
-
- var newChannel = serverChannel.accept()
-
- // Accept them all in a tight loop. Non-blocking accept with no processing should be fine.
- while (newChannel != null) {
- try {
- val newConnectionId = new ConnectionId(id, idCount.getAndIncrement.intValue)
- val newConnection = new ReceivingConnection(newChannel, selector, newConnectionId)
- newConnection.onReceive(receiveMessage)
- addListeners(newConnection)
- addConnection(newConnection)
- logInfo("Accepted connection from [" + newConnection.remoteAddress + "]")
- } catch {
- // might happen in case of issues with registering with selector
- case e: Exception => logError("Error in accept loop", e)
- }
-
- newChannel = serverChannel.accept()
- }
- }
-
- private def addListeners(connection: Connection) {
- connection.onKeyInterestChange(changeConnectionKeyInterest)
- connection.onException(handleConnectionError)
- connection.onClose(removeConnection)
- }
-
- def addConnection(connection: Connection) {
- connectionsByKey += ((connection.key, connection))
- }
-
- def removeConnection(connection: Connection) {
- connectionsByKey -= connection.key
-
- try {
- connection match {
- case sendingConnection: SendingConnection =>
- val sendingConnectionManagerId = sendingConnection.getRemoteConnectionManagerId()
- logInfo("Removing SendingConnection to " + sendingConnectionManagerId)
-
- connectionsById -= sendingConnectionManagerId
- connectionsAwaitingSasl -= connection.connectionId
-
- messageStatuses.synchronized {
- messageStatuses.values.filter(_.connectionManagerId == sendingConnectionManagerId)
- .foreach(status => {
- logInfo("Notifying " + status)
- status.markDone(None)
- })
-
- messageStatuses.retain((i, status) => {
- status.connectionManagerId != sendingConnectionManagerId
- })
- }
- case receivingConnection: ReceivingConnection =>
- val remoteConnectionManagerId = receivingConnection.getRemoteConnectionManagerId()
- logInfo("Removing ReceivingConnection to " + remoteConnectionManagerId)
-
- val sendingConnectionOpt = connectionsById.get(remoteConnectionManagerId)
- if (sendingConnectionOpt.isEmpty) {
- logError(s"Corresponding SendingConnection to ${remoteConnectionManagerId} not found")
- return
- }
-
- val sendingConnection = sendingConnectionOpt.get
- connectionsById -= remoteConnectionManagerId
- sendingConnection.close()
-
- val sendingConnectionManagerId = sendingConnection.getRemoteConnectionManagerId()
-
- assert(sendingConnectionManagerId == remoteConnectionManagerId)
-
- messageStatuses.synchronized {
- for (s <- messageStatuses.values
- if s.connectionManagerId == sendingConnectionManagerId) {
- logInfo("Notifying " + s)
- s.markDone(None)
- }
-
- messageStatuses.retain((i, status) => {
- status.connectionManagerId != sendingConnectionManagerId
- })
- }
- case _ => logError("Unsupported type of connection.")
- }
- } finally {
- // So that the selection keys can be removed.
- wakeupSelector()
- }
- }
-
- def handleConnectionError(connection: Connection, e: Exception) {
- logInfo("Handling connection error on connection to " +
- connection.getRemoteConnectionManagerId())
- removeConnection(connection)
- }
-
- def changeConnectionKeyInterest(connection: Connection, ops: Int) {
- keyInterestChangeRequests += ((connection.key, ops))
- // so that registrations happen!
- wakeupSelector()
- }
-
- def receiveMessage(connection: Connection, message: Message) {
- val connectionManagerId = ConnectionManagerId.fromSocketAddress(message.senderAddress)
- logDebug("Received [" + message + "] from [" + connectionManagerId + "]")
- val runnable = new Runnable() {
- val creationTime = System.currentTimeMillis
- def run() {
- logDebug("Handler thread delay is " + (System.currentTimeMillis - creationTime) + " ms")
- handleMessage(connectionManagerId, message, connection)
- logDebug("Handling delay is " + (System.currentTimeMillis - creationTime) + " ms")
- }
- }
- handleMessageExecutor.execute(runnable)
- /* handleMessage(connection, message) */
- }
-
- private def handleClientAuthentication(
- waitingConn: SendingConnection,
- securityMsg: SecurityMessage,
- connectionId : ConnectionId) {
- if (waitingConn.isSaslComplete()) {
- logDebug("Client sasl completed for id: " + waitingConn.connectionId)
- connectionsAwaitingSasl -= waitingConn.connectionId
- waitingConn.getAuthenticated().synchronized {
- waitingConn.getAuthenticated().notifyAll()
- }
- return
- } else {
- var replyToken : Array[Byte] = null
- try {
- replyToken = waitingConn.sparkSaslClient.saslResponse(securityMsg.getToken)
- if (waitingConn.isSaslComplete()) {
- logDebug("Client sasl completed after evaluate for id: " + waitingConn.connectionId)
- connectionsAwaitingSasl -= waitingConn.connectionId
- waitingConn.getAuthenticated().synchronized {
- waitingConn.getAuthenticated().notifyAll()
- }
- return
- }
- val securityMsgResp = SecurityMessage.fromResponse(replyToken,
- securityMsg.getConnectionId.toString)
- val message = securityMsgResp.toBufferMessage
- if (message == null) throw new IOException("Error creating security message")
- sendSecurityMessage(waitingConn.getRemoteConnectionManagerId(), message)
- } catch {
- case e: Exception => {
- logError("Error handling sasl client authentication", e)
- waitingConn.close()
- throw new IOException("Error evaluating sasl response: ", e)
- }
- }
- }
- }
-
- private def handleServerAuthentication(
- connection: Connection,
- securityMsg: SecurityMessage,
- connectionId: ConnectionId) {
- if (!connection.isSaslComplete()) {
- logDebug("saslContext not established")
- var replyToken : Array[Byte] = null
- try {
- connection.synchronized {
- if (connection.sparkSaslServer == null) {
- logDebug("Creating sasl Server")
- connection.sparkSaslServer = new SparkSaslServer(securityManager)
- }
- }
- replyToken = connection.sparkSaslServer.response(securityMsg.getToken)
- if (connection.isSaslComplete()) {
- logDebug("Server sasl completed: " + connection.connectionId)
- } else {
- logDebug("Server sasl not completed: " + connection.connectionId)
- }
- if (replyToken != null) {
- val securityMsgResp = SecurityMessage.fromResponse(replyToken,
- securityMsg.getConnectionId)
- val message = securityMsgResp.toBufferMessage
- if (message == null) throw new Exception("Error creating security Message")
- sendSecurityMessage(connection.getRemoteConnectionManagerId(), message)
- }
- } catch {
- case e: Exception => {
- logError("Error in server auth negotiation: " + e)
- // It would probably be better to send an error message telling other side auth failed
- // but for now just close
- connection.close()
- }
- }
- } else {
- logDebug("connection already established for this connection id: " + connection.connectionId)
- }
- }
-
-
- private def handleAuthentication(conn: Connection, bufferMessage: BufferMessage): Boolean = {
- if (bufferMessage.isSecurityNeg) {
- logDebug("This is security neg message")
-
- // parse as SecurityMessage
- val securityMsg = SecurityMessage.fromBufferMessage(bufferMessage)
- val connectionId = ConnectionId.createConnectionIdFromString(securityMsg.getConnectionId)
-
- connectionsAwaitingSasl.get(connectionId) match {
- case Some(waitingConn) => {
- // Client - this must be in response to us doing Send
- logDebug("Client handleAuth for id: " + waitingConn.connectionId)
- handleClientAuthentication(waitingConn, securityMsg, connectionId)
- }
- case None => {
- // Server - someone sent us something and we haven't authenticated yet
- logDebug("Server handleAuth for id: " + connectionId)
- handleServerAuthentication(conn, securityMsg, connectionId)
- }
- }
- return true
- } else {
- if (!conn.isSaslComplete()) {
- // We could handle this better and tell the client we need to do authentication
- // negotiation, but for now just ignore them.
- logError("message sent that is not security negotiation message on connection " +
- "not authenticated yet, ignoring it!!")
- return true
- }
- }
- false
- }
-
- private def handleMessage(
- connectionManagerId: ConnectionManagerId,
- message: Message,
- connection: Connection) {
- logDebug("Handling [" + message + "] from [" + connectionManagerId + "]")
- message match {
- case bufferMessage: BufferMessage => {
- if (authEnabled) {
- val res = handleAuthentication(connection, bufferMessage)
- if (res) {
- // message was security negotiation so skip the rest
- logDebug("After handleAuth result was true, returning")
- return
- }
- }
- if (bufferMessage.hasAckId()) {
- messageStatuses.synchronized {
- messageStatuses.get(bufferMessage.ackId) match {
- case Some(status) => {
- messageStatuses -= bufferMessage.ackId
- status.markDone(Some(message))
- }
- case None => {
- /**
- * We can reach this code in two cases:
- *
- * (1) An invalid ack was sent due to buggy code.
- *
- * (2) A late-arriving ack for a MessageStatus that has already
- *     timed out. To avoid late-arriving acks caused by long
- *     pauses such as GC, set spark.core.connection.ack.wait.timeout
- *     to a value larger than the default.
- */
- logWarning(s"Could not find reference for received ack message ${message.id}")
- }
- }
- }
- } else {
- var ackMessage : Option[Message] = None
- try {
- ackMessage = if (onReceiveCallback != null) {
- logDebug("Calling back")
- onReceiveCallback(bufferMessage, connectionManagerId)
- } else {
- logDebug("Not calling back as callback is null")
- None
- }
-
- if (ackMessage.isDefined) {
- if (!ackMessage.get.isInstanceOf[BufferMessage]) {
- logDebug("Response to " + bufferMessage + " is not a buffer message, it is of type "
- + ackMessage.get.getClass)
- } else if (!ackMessage.get.asInstanceOf[BufferMessage].hasAckId) {
- logDebug("Response to " + bufferMessage + " does not have ack id set")
- ackMessage.get.asInstanceOf[BufferMessage].ackId = bufferMessage.id
- }
- }
- } catch {
- case e: Exception => {
- logError(s"Exception was thrown while processing message", e)
- val m = Message.createBufferMessage(bufferMessage.id)
- m.hasError = true
- ackMessage = Some(m)
- }
- } finally {
- sendMessage(connectionManagerId, ackMessage.getOrElse {
- Message.createBufferMessage(bufferMessage.id)
- })
- }
- }
- }
- case _ => throw new Exception("Unknown type message received")
- }
- }
-
- private def checkSendAuthFirst(connManagerId: ConnectionManagerId, conn: SendingConnection) {
- // See if we need to do SASL negotiation before writing.
- // This should only happen on the first negotiation, as the client.
- if (!conn.isSaslComplete()) {
- conn.synchronized {
- if (conn.sparkSaslClient == null) {
- conn.sparkSaslClient = new SparkSaslClient(securityManager)
- var firstResponse: Array[Byte] = null
- try {
- firstResponse = conn.sparkSaslClient.firstToken()
- val securityMsg = SecurityMessage.fromResponse(firstResponse,
- conn.connectionId.toString())
- val message = securityMsg.toBufferMessage
- if (message == null) throw new Exception("Error creating security message")
- connectionsAwaitingSasl += ((conn.connectionId, conn))
- sendSecurityMessage(connManagerId, message)
- logDebug("adding connectionsAwaitingSasl id: " + conn.connectionId)
- } catch {
- case e: Exception => {
- logError("Error getting first response from the SaslClient.", e)
- conn.close()
- throw new Exception("Error getting first response from the SaslClient")
- }
- }
- }
- }
- } else {
- logDebug("Sasl already established ")
- }
- }
-
- // Allows us to add messages to the inbox for SASL negotiation.
- private def sendSecurityMessage(connManagerId: ConnectionManagerId, message: Message) {
- def startNewConnection(): SendingConnection = {
- val inetSocketAddress = new InetSocketAddress(connManagerId.host, connManagerId.port)
- val newConnectionId = new ConnectionId(id, idCount.getAndIncrement.intValue)
- val newConnection = new SendingConnection(inetSocketAddress, selector, connManagerId,
- newConnectionId)
- logInfo("creating new sending connection for security! " + newConnectionId )
- registerRequests.enqueue(newConnection)
-
- newConnection
- }
- // The lookupKey logic was removed as part of a merge; should it be re-added?
- // It did not prove useful in our test environment. If it is re-added, it
- // should probably be used consistently everywhere.
- message.senderAddress = id.toSocketAddress()
- logTrace("Sending Security [" + message + "] to [" + connManagerId + "]")
- val connection = connectionsById.getOrElseUpdate(connManagerId, startNewConnection())
-
- // Send the security message even though the outgoing connection has not been authenticated yet.
- connection.send(message)
-
- wakeupSelector()
- }
-
- private def sendMessage(connectionManagerId: ConnectionManagerId, message: Message) {
- def startNewConnection(): SendingConnection = {
- val inetSocketAddress = new InetSocketAddress(connectionManagerId.host,
- connectionManagerId.port)
- val newConnectionId = new ConnectionId(id, idCount.getAndIncrement.intValue)
- val newConnection = new SendingConnection(inetSocketAddress, selector, connectionManagerId,
- newConnectionId)
- logTrace("creating new sending connection: " + newConnectionId)
- registerRequests.enqueue(newConnection)
-
- newConnection
- }
- val connection = connectionsById.getOrElseUpdate(connectionManagerId, startNewConnection())
- if (authEnabled) {
- checkSendAuthFirst(connectionManagerId, connection)
- }
- message.senderAddress = id.toSocketAddress()
- logDebug("Before Sending [" + message + "] to [" + connectionManagerId + "]" + " " +
- "connectionid: " + connection.connectionId)
-
- if (authEnabled) {
- // If we aren't authenticated yet, block the senders until authentication completes.
- try {
- connection.getAuthenticated().synchronized {
- val clock = SystemClock
- val startTime = clock.getTime()
-
- while (!connection.isSaslComplete()) {
- logDebug("getAuthenticated wait connectionid: " + connection.connectionId)
- // have timeout in case remote side never responds
- connection.getAuthenticated().wait(500)
- if (((clock.getTime() - startTime) >= (authTimeout * 1000))
- && (!connection.isSaslComplete())) {
- // Took too long to authenticate the connection; something probably went wrong.
- throw new Exception("Took too long for authentication to " + connectionManagerId +
- ", waited " + authTimeout + " seconds, failing.")
- }
- }
- }
- } catch {
- case e: Exception => logError("Exception while waiting for authentication.", e)
-
- // need to tell sender it failed
- messageStatuses.synchronized {
- val s = messageStatuses.get(message.id)
- s match {
- case Some(msgStatus) => {
- messageStatuses -= message.id
- logInfo("Notifying " + msgStatus.connectionManagerId)
- msgStatus.markDone(None)
- }
- case None => {
- logError("no messageStatus for failed message id: " + message.id)
- }
- }
- }
- }
- }
- logDebug("Sending [" + message + "] to [" + connectionManagerId + "]")
- connection.send(message)
-
- wakeupSelector()
- }
-
- private def wakeupSelector() {
- selector.wakeup()
- }
-
- /**
- * Send a message and return a Future that completes when an acknowledgment
- * is received or an error occurs. Note that this method does not block.
- * @param connectionManagerId the message's destination
- * @param message the message being sent
- * @return a Future that either returns the acknowledgment message or captures an exception.
- */
- def sendMessageReliably(connectionManagerId: ConnectionManagerId, message: Message)
- : Future[Message] = {
- val promise = Promise[Message]()
-
- val timeoutTask = new TimerTask {
- override def run(): Unit = {
- messageStatuses.synchronized {
- messageStatuses.remove(message.id).foreach ( s => {
- promise.failure(
- new IOException("sendMessageReliably failed because ack " +
- s"was not received within $ackTimeout sec"))
- })
- }
- }
- }
-
- val status = new MessageStatus(message, connectionManagerId, s => {
- timeoutTask.cancel()
- s.ackMessage match {
- case None => // Indicates a failure where we either never sent or never got ACK'd
- promise.failure(new IOException("sendMessageReliably failed without being ACK'd"))
- case Some(ackMessage) =>
- if (ackMessage.hasError) {
- promise.failure(
- new IOException("sendMessageReliably failed with ACK that signalled a remote error"))
- } else {
- promise.success(ackMessage)
- }
- }
- })
- messageStatuses.synchronized {
- messageStatuses += ((message.id, status))
- }
-
- ackTimeoutMonitor.schedule(timeoutTask, ackTimeout * 1000)
- sendMessage(connectionManagerId, message)
- promise.future
- }
-
- def onReceiveMessage(callback: (Message, ConnectionManagerId) => Option[Message]) {
- onReceiveCallback = callback
- }
-
- def stop() {
- ackTimeoutMonitor.cancel()
- selectorThread.interrupt()
- selectorThread.join()
- selector.close()
- val connections = connectionsByKey.values
- connections.foreach(_.close())
- if (connectionsByKey.nonEmpty) {
- logWarning("Not all connections were cleaned up")
- }
- handleMessageExecutor.shutdown()
- handleReadWriteExecutor.shutdown()
- handleConnectExecutor.shutdown()
- logInfo("ConnectionManager stopped")
- }
-}
-
-
-private[spark] object ConnectionManager {
- import ExecutionContext.Implicits.global
-
- def main(args: Array[String]) {
- val conf = new SparkConf
- val manager = new ConnectionManager(9999, conf, new SecurityManager(conf))
- manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
- println("Received [" + msg + "] from [" + id + "]")
- None
- })
-
- /* testSequentialSending(manager) */
- /* System.gc() */
-
- /* testParallelSending(manager) */
- /* System.gc() */
-
- /* testParallelDecreasingSending(manager) */
- /* System.gc() */
-
- testContinuousSending(manager)
- System.gc()
- }
-
- def testSequentialSending(manager: ConnectionManager) {
- println("--------------------------")
- println("Sequential Sending")
- println("--------------------------")
- val size = 10 * 1024 * 1024
- val count = 10
-
- val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
- buffer.flip
-
- (0 until count).map(i => {
- val bufferMessage = Message.createBufferMessage(buffer.duplicate)
- Await.result(manager.sendMessageReliably(manager.id, bufferMessage), Duration.Inf)
- })
- println("--------------------------")
- println()
- }
-
- def testParallelSending(manager: ConnectionManager) {
- println("--------------------------")
- println("Parallel Sending")
- println("--------------------------")
- val size = 10 * 1024 * 1024
- val count = 10
-
- val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
- buffer.flip
-
- val startTime = System.currentTimeMillis
- (0 until count).map(i => {
- val bufferMessage = Message.createBufferMessage(buffer.duplicate)
- manager.sendMessageReliably(manager.id, bufferMessage)
- }).foreach(f => {
- f.onFailure {
- case e => println("Failed due to " + e)
- }
- Await.ready(f, 1 second)
- })
- val finishTime = System.currentTimeMillis
-
- val mb = size * count / 1024.0 / 1024.0
- val ms = finishTime - startTime
- val tput = mb * 1000.0 / ms
- println("--------------------------")
- println("Started at " + startTime + ", finished at " + finishTime)
- println("Sent " + count + " messages of size " + size + " in " + ms + " ms " +
- "(" + tput + " MB/s)")
- println("--------------------------")
- println()
- }
-
- def testParallelDecreasingSending(manager: ConnectionManager) {
- println("--------------------------")
- println("Parallel Decreasing Sending")
- println("--------------------------")
- val size = 10 * 1024 * 1024
- val count = 10
- val buffers = Array.tabulate(count) { i =>
- val bufferLen = size * (i + 1)
- val bufferContent = Array.tabulate[Byte](bufferLen)(x => x.toByte)
- ByteBuffer.allocate(bufferLen).put(bufferContent)
- }
- buffers.foreach(_.flip)
- val mb = buffers.map(_.remaining).reduceLeft(_ + _) / 1024.0 / 1024.0
-
- val startTime = System.currentTimeMillis
- (0 until count).map(i => {
- val bufferMessage = Message.createBufferMessage(buffers(count - 1 - i).duplicate)
- manager.sendMessageReliably(manager.id, bufferMessage)
- }).foreach(f => {
- f.onFailure {
- case e => println("Failed due to " + e)
- }
- Await.ready(f, 1 second)
- })
- val finishTime = System.currentTimeMillis
-
- val ms = finishTime - startTime
- val tput = mb * 1000.0 / ms
- println("--------------------------")
- /* println("Started at " + startTime + ", finished at " + finishTime) */
- println("Sent " + mb + " MB in " + ms + " ms (" + tput + " MB/s)")
- println("--------------------------")
- println()
- }
-
- def testContinuousSending(manager: ConnectionManager) {
- println("--------------------------")
- println("Continuous Sending")
- println("--------------------------")
- val size = 10 * 1024 * 1024
- val count = 10
-
- val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
- buffer.flip
-
- val startTime = System.currentTimeMillis
- while (true) {
- (0 until count).map(i => {
- val bufferMessage = Message.createBufferMessage(buffer.duplicate)
- manager.sendMessageReliably(manager.id, bufferMessage)
- }).foreach(f => {
- f.onFailure {
- case e => println("Failed due to " + e)
- }
- Await.ready(f, 1 second)
- })
- val finishTime = System.currentTimeMillis
- Thread.sleep(1000)
- val mb = size * count / 1024.0 / 1024.0
- val ms = finishTime - startTime
- val tput = mb * 1000.0 / ms
- println("Sent " + mb + " MB in " + ms + " ms (" + tput + " MB/s)")
- println("--------------------------")
- println()
- }
- }
-}
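
For context on what is being removed: sendMessageReliably was the main
client-facing API of this class, and the test harness above already shows
the intended usage pattern. A minimal standalone sketch, modeled directly
on testSequentialSending and main() above (the port, the 1 MB payload, and
the self-addressed send are illustrative assumptions, and the snippet
assumes the org.apache.spark.network package is in scope):

    import java.nio.ByteBuffer

    import scala.concurrent.Await
    import scala.concurrent.duration._

    import org.apache.spark.{SecurityManager, SparkConf}

    val conf = new SparkConf
    val manager = new ConnectionManager(9999, conf, new SecurityManager(conf))

    // Build a 1 MB payload and wrap it in a BufferMessage.
    val size = 1024 * 1024
    val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(_.toByte))
    buffer.flip()
    val bufferMessage = Message.createBufferMessage(buffer.duplicate)

    // sendMessageReliably returns a Future that completes with the ack, or
    // fails if no ack arrives within spark.core.connection.ack.wait.timeout.
    val ackFuture = manager.sendMessageReliably(manager.id, bufferMessage)
    val ack = Await.result(ackFuture, Duration.Inf)

    manager.stop()
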
http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/network/ConnectionManagerId.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManagerId.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManagerId.scala
deleted file mode 100644
index 57f7586..0000000
--- a/core/src/main/scala/org/apache/spark/network/ConnectionManagerId.scala
+++ /dev/null
@@ -1,37 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network
-
-import java.net.InetSocketAddress
-
-import org.apache.spark.util.Utils
-
-private[spark] case class ConnectionManagerId(host: String, port: Int) {
- // DEBUG code
- Utils.checkHost(host)
- assert (port > 0)
-
- def toSocketAddress() = new InetSocketAddress(host, port)
-}
-
-
-private[spark] object ConnectionManagerId {
- def fromSocketAddress(socketAddress: InetSocketAddress): ConnectionManagerId = {
- new ConnectionManagerId(socketAddress.getHostName, socketAddress.getPort)
- }
-}
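
ConnectionManagerId itself was a thin value type. As a quick reference, a
sketch of its two construction paths, both visible in the file above (the
host and port are illustrative):

    import java.net.InetSocketAddress

    // Construct directly from a host and port; toSocketAddress() converts back.
    val id = ConnectionManagerId("localhost", 9999)
    val addr: InetSocketAddress = id.toSocketAddress()

    // Or recover an id from an InetSocketAddress, e.g. a message's senderAddress.
    val sameId = ConnectionManagerId.fromSocketAddress(addr)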