Posted to commits@spark.apache.org by ma...@apache.org on 2013/09/01 23:59:22 UTC
[38/69] [abbrv] [partial] Initial work to rename package to org.apache.spark
http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/46eecd11/core/src/main/scala/org/apache/spark/metrics/source/Source.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/metrics/source/Source.scala b/core/src/main/scala/org/apache/spark/metrics/source/Source.scala
new file mode 100644
index 0000000..3fee55c
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/metrics/source/Source.scala
@@ -0,0 +1,25 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.metrics.source
+
+import com.codahale.metrics.MetricRegistry
+
+trait Source {
+ def sourceName: String
+ def metricRegistry: MetricRegistry
+}
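
For context: Source is the hook the metrics system uses to discover a named MetricRegistry. A minimal sketch of an implementation (the source name and gauge are illustrative only, not part of this commit):

    import com.codahale.metrics.{Gauge, MetricRegistry}
    import org.apache.spark.metrics.source.Source

    class MySource extends Source {
      override val sourceName = "mySource"
      override val metricRegistry = new MetricRegistry()

      // Register a trivial gauge that reports a constant value.
      metricRegistry.register(MetricRegistry.name("answer"), new Gauge[Int] {
        override def getValue: Int = 42
      })
    }
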
http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/46eecd11/core/src/main/scala/org/apache/spark/network/BufferMessage.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/network/BufferMessage.scala b/core/src/main/scala/org/apache/spark/network/BufferMessage.scala
new file mode 100644
index 0000000..f736bb3
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/BufferMessage.scala
@@ -0,0 +1,111 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network
+
+import java.nio.ByteBuffer
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.storage.BlockManager
+
+
+private[spark]
+class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: Int)
+ extends Message(Message.BUFFER_MESSAGE, id_) {
+
+ val initialSize = currentSize()
+ var gotChunkForSendingOnce = false
+
+ def size = initialSize
+
+ def currentSize() = {
+ if (buffers == null || buffers.isEmpty) {
+ 0
+ } else {
+ buffers.map(_.remaining).reduceLeft(_ + _)
+ }
+ }
+
+ def getChunkForSending(maxChunkSize: Int): Option[MessageChunk] = {
+ if (maxChunkSize <= 0) {
+ throw new Exception("Max chunk size is " + maxChunkSize)
+ }
+
+ if (size == 0 && !gotChunkForSendingOnce) {
+ val newChunk = new MessageChunk(
+ new MessageChunkHeader(typ, id, 0, 0, ackId, senderAddress), null)
+ gotChunkForSendingOnce = true
+ return Some(newChunk)
+ }
+
+ while (!buffers.isEmpty) {
+ val buffer = buffers(0)
+ if (buffer.remaining == 0) {
+ BlockManager.dispose(buffer)
+ buffers -= buffer
+ } else {
+ val newBuffer = if (buffer.remaining <= maxChunkSize) {
+ buffer.duplicate()
+ } else {
+ buffer.slice().limit(maxChunkSize).asInstanceOf[ByteBuffer]
+ }
+ buffer.position(buffer.position + newBuffer.remaining)
+ val newChunk = new MessageChunk(new MessageChunkHeader(
+ typ, id, size, newBuffer.remaining, ackId, senderAddress), newBuffer)
+ gotChunkForSendingOnce = true
+ return Some(newChunk)
+ }
+ }
+ None
+ }
+
+ def getChunkForReceiving(chunkSize: Int): Option[MessageChunk] = {
+ // STRONG ASSUMPTION: BufferMessage created when receiving data has ONLY ONE data buffer
+ if (buffers.size > 1) {
+ throw new Exception("Attempting to get chunk from message with multiple data buffers")
+ }
+ val buffer = buffers(0)
+ if (buffer.remaining > 0) {
+ if (buffer.remaining < chunkSize) {
+ throw new Exception("Not enough space in data buffer for receiving chunk")
+ }
+ val newBuffer = buffer.slice().limit(chunkSize).asInstanceOf[ByteBuffer]
+ buffer.position(buffer.position + newBuffer.remaining)
+ val newChunk = new MessageChunk(new MessageChunkHeader(
+ typ, id, size, newBuffer.remaining, ackId, senderAddress), newBuffer)
+ return Some(newChunk)
+ }
+ None
+ }
+
+ def flip() {
+ buffers.foreach(_.flip)
+ }
+
+ def hasAckId() = (ackId != 0)
+
+ def isCompletelyReceived() = !buffers(0).hasRemaining
+
+ override def toString = {
+ if (hasAckId) {
+ "BufferAckMessage(aid = " + ackId + ", id = " + id + ", size = " + size + ")"
+ } else {
+ "BufferMessage(id = " + id + ", size = " + size + ")"
+ }
+ }
+}
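
For context, the chunking contract above is easiest to see with a worked example. A minimal sketch (sizes are arbitrary; it assumes access to the private[spark] network classes and uses the Message.createBufferMessage factory from Message.scala further below):

    import java.nio.ByteBuffer
    import org.apache.spark.network.Message

    // A 100-byte payload chunked at 40 bytes yields chunks of 40, 40 and 20 bytes;
    // the next call drains the exhausted buffer and returns None.
    val msg = Message.createBufferMessage(ByteBuffer.allocate(100))
    val sizes = Iterator.continually(msg.getChunkForSending(40))
      .takeWhile(_.isDefined).map(_.get.size).toList
    assert(sizes == List(40, 40, 20))
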
http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/46eecd11/core/src/main/scala/org/apache/spark/network/Connection.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/network/Connection.scala b/core/src/main/scala/org/apache/spark/network/Connection.scala
new file mode 100644
index 0000000..95cb020
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/Connection.scala
@@ -0,0 +1,586 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network
+
+import org.apache.spark._
+
+import scala.collection.mutable.{HashMap, Queue, ArrayBuffer}
+
+import java.io._
+import java.nio._
+import java.nio.channels._
+import java.nio.channels.spi._
+import java.net._
+
+
+private[spark]
+abstract class Connection(val channel: SocketChannel, val selector: Selector,
+ val socketRemoteConnectionManagerId: ConnectionManagerId)
+ extends Logging {
+
+ def this(channel_ : SocketChannel, selector_ : Selector) = {
+ this(channel_, selector_,
+ ConnectionManagerId.fromSocketAddress(
+ channel_.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress]))
+ }
+
+ channel.configureBlocking(false)
+ channel.socket.setTcpNoDelay(true)
+ channel.socket.setReuseAddress(true)
+ channel.socket.setKeepAlive(true)
+ /*channel.socket.setReceiveBufferSize(32768) */
+
+ @volatile private var closed = false
+ var onCloseCallback: Connection => Unit = null
+ var onExceptionCallback: (Connection, Exception) => Unit = null
+ var onKeyInterestChangeCallback: (Connection, Int) => Unit = null
+
+ val remoteAddress = getRemoteAddress()
+
+ def resetForceReregister(): Boolean
+
+ // Read connections typically do not register for write, and write connections do not
+ // register for read. Currently, write connections do register for read too (temporarily),
+ // but that is to detect channel close - NOT to actually read/consume data on the channel!
+ // How does this work if/when we move to SSL?
+
+ // What is the interest to register with selector for when we want this connection to be selected
+ def registerInterest()
+
+ // What is the interest to register with selector for when we want this connection to
+ // be de-selected
+ // Traditionally 0, but in our case, because of the close-detection hack on
+ // SendingConnection, it will be SelectionKey.OP_READ (until we fix it properly).
+ def unregisterInterest()
+
+ // On receiving a read event, should we change the interest for this channel or not ?
+ // Will be true for ReceivingConnection, false for SendingConnection.
+ def changeInterestForRead(): Boolean
+
+ // On receiving a write event, should we change the interest for this channel or not ?
+ // Will be false for ReceivingConnection, true for SendingConnection.
+ // Actually, for now, should not get triggered for ReceivingConnection
+ def changeInterestForWrite(): Boolean
+
+ def getRemoteConnectionManagerId(): ConnectionManagerId = {
+ socketRemoteConnectionManagerId
+ }
+
+ def key() = channel.keyFor(selector)
+
+ def getRemoteAddress() = channel.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress]
+
+ // Returns whether we have to register for further reads or not.
+ def read(): Boolean = {
+ throw new UnsupportedOperationException(
+ "Cannot read on connection of type " + this.getClass.toString)
+ }
+
+ // Returns whether we have to register for further writes or not.
+ def write(): Boolean = {
+ throw new UnsupportedOperationException(
+ "Cannot write on connection of type " + this.getClass.toString)
+ }
+
+ def close() {
+ closed = true
+ val k = key()
+ if (k != null) {
+ k.cancel()
+ }
+ channel.close()
+ callOnCloseCallback()
+ }
+
+ protected def isClosed: Boolean = closed
+
+ def onClose(callback: Connection => Unit) {
+ onCloseCallback = callback
+ }
+
+ def onException(callback: (Connection, Exception) => Unit) {
+ onExceptionCallback = callback
+ }
+
+ def onKeyInterestChange(callback: (Connection, Int) => Unit) {
+ onKeyInterestChangeCallback = callback
+ }
+
+ def callOnExceptionCallback(e: Exception) {
+ if (onExceptionCallback != null) {
+ onExceptionCallback(this, e)
+ } else {
+ logError("Error in connection to " + getRemoteConnectionManagerId() +
+ " and OnExceptionCallback not registered", e)
+ }
+ }
+
+ def callOnCloseCallback() {
+ if (onCloseCallback != null) {
+ onCloseCallback(this)
+ } else {
+ logWarning("Connection to " + getRemoteConnectionManagerId() +
+ " closed and OnExceptionCallback not registered")
+ }
+
+ }
+
+ def changeConnectionKeyInterest(ops: Int) {
+ if (onKeyInterestChangeCallback != null) {
+ onKeyInterestChangeCallback(this, ops)
+ } else {
+ throw new Exception("OnKeyInterestChangeCallback not registered")
+ }
+ }
+
+ def printRemainingBuffer(buffer: ByteBuffer) {
+ val bytes = new Array[Byte](buffer.remaining)
+ val curPosition = buffer.position
+ buffer.get(bytes)
+ bytes.foreach(x => print(x + " "))
+ buffer.position(curPosition)
+ print(" (" + bytes.size + ")")
+ }
+
+ def printBuffer(buffer: ByteBuffer, position: Int, length: Int) {
+ val bytes = new Array[Byte](length)
+ val curPosition = buffer.position
+ buffer.position(position)
+ buffer.get(bytes)
+ bytes.foreach(x => print(x + " "))
+ print(" (" + position + ", " + length + ")")
+ buffer.position(curPosition)
+ }
+}
+
+
+private[spark]
+class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
+ remoteId_ : ConnectionManagerId)
+ extends Connection(SocketChannel.open, selector_, remoteId_) {
+
+ private class Outbox(fair: Int = 0) {
+ val messages = new Queue[Message]()
+ val defaultChunkSize = 65536 //32768 //16384
+ var nextMessageToBeUsed = 0
+
+ def addMessage(message: Message) {
+ messages.synchronized{
+ /*messages += message*/
+ messages.enqueue(message)
+ logDebug("Added [" + message + "] to outbox for sending to " +
+ "[" + getRemoteConnectionManagerId() + "]")
+ }
+ }
+
+ def getChunk(): Option[MessageChunk] = {
+ fair match {
+ case 0 => getChunkFIFO()
+ case 1 => getChunkRR()
+ case _ => throw new Exception("Unexpected fairness policy in outbox")
+ }
+ }
+
+ private def getChunkFIFO(): Option[MessageChunk] = {
+ /*logInfo("Using FIFO")*/
+ messages.synchronized {
+ while (!messages.isEmpty) {
+ val message = messages(0)
+ val chunk = message.getChunkForSending(defaultChunkSize)
+ if (chunk.isDefined) {
+ messages += message // this is probably incorrect, it won't work as FIFO
+ if (!message.started) {
+ logDebug("Starting to send [" + message + "]")
+ message.started = true
+ message.startTime = System.currentTimeMillis
+ }
+ return chunk
+ } else {
+ /*logInfo("Finished sending [" + message + "] to [" + getRemoteConnectionManagerId() + "]")*/
+ message.finishTime = System.currentTimeMillis
+ logDebug("Finished sending [" + message + "] to [" + getRemoteConnectionManagerId() +
+ "] in " + message.timeTaken )
+ }
+ }
+ }
+ None
+ }
+
+ private def getChunkRR(): Option[MessageChunk] = {
+ messages.synchronized {
+ while (!messages.isEmpty) {
+ /*nextMessageToBeUsed = nextMessageToBeUsed % messages.size */
+ /*val message = messages(nextMessageToBeUsed)*/
+ val message = messages.dequeue
+ val chunk = message.getChunkForSending(defaultChunkSize)
+ if (chunk.isDefined) {
+ messages.enqueue(message)
+ nextMessageToBeUsed = nextMessageToBeUsed + 1
+ if (!message.started) {
+ logDebug(
+ "Starting to send [" + message + "] to [" + getRemoteConnectionManagerId() + "]")
+ message.started = true
+ message.startTime = System.currentTimeMillis
+ }
+ logTrace(
+ "Sending chunk from [" + message + "] to [" + getRemoteConnectionManagerId() + "]")
+ return chunk
+ } else {
+ message.finishTime = System.currentTimeMillis
+ logDebug("Finished sending [" + message + "] to [" + getRemoteConnectionManagerId() +
+ "] in " + message.timeTaken )
+ }
+ }
+ }
+ None
+ }
+ }
+
+ // outbox is used as a lock - ensure that it is always used as a leaf (since methods which
+ // lock it are invoked in context of other locks)
+ private val outbox = new Outbox(1)
+ /*
+ This is orthogonal to whether we have pending bytes to write or not - it serves a slightly
+ different purpose. This flag tracks whether we need to force a reregister for write even when
+ we do not have any pending bytes to write to the socket.
+ This can happen due to a race between adding pending buffers and checking for the existence
+ of data, as detailed in https://github.com/mesos/spark/pull/791
+ */
+ private var needForceReregister = false
+ val currentBuffers = new ArrayBuffer[ByteBuffer]()
+
+ /*channel.socket.setSendBufferSize(256 * 1024)*/
+
+ override def getRemoteAddress() = address
+
+ val DEFAULT_INTEREST = SelectionKey.OP_READ
+
+ override def registerInterest() {
+ // Registering read too - does not really help in most cases, but for some
+ // it does - so let us keep it for now.
+ changeConnectionKeyInterest(SelectionKey.OP_WRITE | DEFAULT_INTEREST)
+ }
+
+ override def unregisterInterest() {
+ changeConnectionKeyInterest(DEFAULT_INTEREST)
+ }
+
+ def send(message: Message) {
+ outbox.synchronized {
+ outbox.addMessage(message)
+ needForceReregister = true
+ }
+ if (channel.isConnected) {
+ registerInterest()
+ }
+ }
+
+ // return previous value after resetting it.
+ def resetForceReregister(): Boolean = {
+ outbox.synchronized {
+ val result = needForceReregister
+ needForceReregister = false
+ result
+ }
+ }
+
+ // MUST be called within the selector loop
+ def connect() {
+ try {
+ channel.register(selector, SelectionKey.OP_CONNECT)
+ channel.connect(address)
+ logInfo("Initiating connection to [" + address + "]")
+ } catch {
+ case e: Exception => {
+ logError("Error connecting to " + address, e)
+ callOnExceptionCallback(e)
+ }
+ }
+ }
+
+ def finishConnect(force: Boolean): Boolean = {
+ try {
+ // Typically, this should finish immediately since it was triggered by a connect
+ // selection - though it need not always complete successfully.
+ val connected = channel.finishConnect
+ if (!force && !connected) {
+ logInfo(
+ "finish connect failed [" + address + "], " + outbox.messages.size + " messages pending")
+ return false
+ }
+
+ // Fall back to previous behavior - assume finishConnect completed.
+ // This will happen only after finishConnect has failed some number of times in a row
+ // (10 or so).
+ // That is highly unlikely unless there was an unclean close of the socket, etc.
+ registerInterest()
+ logInfo("Connected to [" + address + "], " + outbox.messages.size + " messages pending")
+ return true
+ } catch {
+ case e: Exception => {
+ logWarning("Error finishing connection to " + address, e)
+ callOnExceptionCallback(e)
+ // ignore
+ return true
+ }
+ }
+ }
+
+ override def write(): Boolean = {
+ try {
+ while (true) {
+ if (currentBuffers.size == 0) {
+ outbox.synchronized {
+ outbox.getChunk() match {
+ case Some(chunk) => {
+ val buffers = chunk.buffers
+ // If we have 'seen' pending messages, reset the flag, since we handle that as the
+ // normal registering of the event (below).
+ if (needForceReregister && buffers.exists(_.remaining() > 0)) resetForceReregister()
+ currentBuffers ++= buffers
+ }
+ case None => {
+ // changeConnectionKeyInterest(0)
+ /*key.interestOps(0)*/
+ return false
+ }
+ }
+ }
+ }
+
+ if (currentBuffers.size > 0) {
+ val buffer = currentBuffers(0)
+ val remainingBytes = buffer.remaining
+ val writtenBytes = channel.write(buffer)
+ if (buffer.remaining == 0) {
+ currentBuffers -= buffer
+ }
+ if (writtenBytes < remainingBytes) {
+ // re-register for write.
+ return true
+ }
+ }
+ }
+ } catch {
+ case e: Exception => {
+ logWarning("Error writing in connection to " + getRemoteConnectionManagerId(), e)
+ callOnExceptionCallback(e)
+ close()
+ return false
+ }
+ }
+ // should not happen - to keep scala compiler happy
+ return true
+ }
+
+ // This is a hack to determine whether the remote socket was closed or not.
+ // SendingConnection DOES NOT expect to receive any data - if it does, it is an error.
+ // In many cases, read will return -1 when the remote socket is closed, hence we
+ // register for reads to detect that.
+ override def read(): Boolean = {
+ // We don't expect the other side to send anything; so, we just read to detect an error or EOF.
+ try {
+ val length = channel.read(ByteBuffer.allocate(1))
+ if (length == -1) { // EOF
+ close()
+ } else if (length > 0) {
+ logWarning(
+ "Unexpected data read from SendingConnection to " + getRemoteConnectionManagerId())
+ }
+ } catch {
+ case e: Exception =>
+ logError("Exception while reading SendingConnection to " + getRemoteConnectionManagerId(), e)
+ callOnExceptionCallback(e)
+ close()
+ }
+
+ false
+ }
+
+ override def changeInterestForRead(): Boolean = false
+
+ override def changeInterestForWrite(): Boolean = ! isClosed
+}
+
+
+// Must be created within selector loop - else deadlock
+private[spark] class ReceivingConnection(channel_ : SocketChannel, selector_ : Selector)
+ extends Connection(channel_, selector_) {
+
+ class Inbox() {
+ val messages = new HashMap[Int, BufferMessage]()
+
+ def getChunk(header: MessageChunkHeader): Option[MessageChunk] = {
+
+ def createNewMessage: BufferMessage = {
+ val newMessage = Message.create(header).asInstanceOf[BufferMessage]
+ newMessage.started = true
+ newMessage.startTime = System.currentTimeMillis
+ logDebug(
+ "Starting to receive [" + newMessage + "] from [" + getRemoteConnectionManagerId() + "]")
+ messages += ((newMessage.id, newMessage))
+ newMessage
+ }
+
+ val message = messages.getOrElseUpdate(header.id, createNewMessage)
+ logTrace(
+ "Receiving chunk of [" + message + "] from [" + getRemoteConnectionManagerId() + "]")
+ message.getChunkForReceiving(header.chunkSize)
+ }
+
+ def getMessageForChunk(chunk: MessageChunk): Option[BufferMessage] = {
+ messages.get(chunk.header.id)
+ }
+
+ def removeMessage(message: Message) {
+ messages -= message.id
+ }
+ }
+
+ @volatile private var inferredRemoteManagerId: ConnectionManagerId = null
+
+ override def getRemoteConnectionManagerId(): ConnectionManagerId = {
+ val currId = inferredRemoteManagerId
+ if (currId != null) currId else super.getRemoteConnectionManagerId()
+ }
+
+ // The receiver's remote address is the local socket on the remote side, which is NOT
+ // the connection manager id of the receiver.
+ // We infer that from the messages we receive on the receiver socket.
+ private def processConnectionManagerId(header: MessageChunkHeader) {
+ val currId = inferredRemoteManagerId
+ if (header.address == null || currId != null) return
+
+ val managerId = ConnectionManagerId.fromSocketAddress(header.address)
+
+ if (managerId != null) {
+ inferredRemoteManagerId = managerId
+ }
+ }
+
+
+ val inbox = new Inbox()
+ val headerBuffer: ByteBuffer = ByteBuffer.allocate(MessageChunkHeader.HEADER_SIZE)
+ var onReceiveCallback: (Connection , Message) => Unit = null
+ var currentChunk: MessageChunk = null
+
+ channel.register(selector, SelectionKey.OP_READ)
+
+ override def read(): Boolean = {
+ try {
+ while (true) {
+ if (currentChunk == null) {
+ val headerBytesRead = channel.read(headerBuffer)
+ if (headerBytesRead == -1) {
+ close()
+ return false
+ }
+ if (headerBuffer.remaining > 0) {
+ // re-register for read event ...
+ return true
+ }
+ headerBuffer.flip
+ if (headerBuffer.remaining != MessageChunkHeader.HEADER_SIZE) {
+ throw new Exception(
+ "Unexpected number of bytes (" + headerBuffer.remaining + ") in the header")
+ }
+ val header = MessageChunkHeader.create(headerBuffer)
+ headerBuffer.clear()
+
+ processConnectionManagerId(header)
+
+ header.typ match {
+ case Message.BUFFER_MESSAGE => {
+ if (header.totalSize == 0) {
+ if (onReceiveCallback != null) {
+ onReceiveCallback(this, Message.create(header))
+ }
+ currentChunk = null
+ // re-register for read event ...
+ return true
+ } else {
+ currentChunk = inbox.getChunk(header).orNull
+ }
+ }
+ case _ => throw new Exception("Message of unknown type received")
+ }
+ }
+
+ if (currentChunk == null) throw new Exception("No message chunk to receive data")
+
+ val bytesRead = channel.read(currentChunk.buffer)
+ if (bytesRead == 0) {
+ // re-register for read event ...
+ return true
+ } else if (bytesRead == -1) {
+ close()
+ return false
+ }
+
+ /*logDebug("Read " + bytesRead + " bytes for the buffer")*/
+
+ if (currentChunk.buffer.remaining == 0) {
+ /*println("Filled buffer at " + System.currentTimeMillis)*/
+ val bufferMessage = inbox.getMessageForChunk(currentChunk).get
+ if (bufferMessage.isCompletelyReceived) {
+ bufferMessage.flip
+ bufferMessage.finishTime = System.currentTimeMillis
+ logDebug("Finished receiving [" + bufferMessage + "] from " +
+ "[" + getRemoteConnectionManagerId() + "] in " + bufferMessage.timeTaken)
+ if (onReceiveCallback != null) {
+ onReceiveCallback(this, bufferMessage)
+ }
+ inbox.removeMessage(bufferMessage)
+ }
+ currentChunk = null
+ }
+ }
+ } catch {
+ case e: Exception => {
+ logWarning("Error reading from connection to " + getRemoteConnectionManagerId(), e)
+ callOnExceptionCallback(e)
+ close()
+ return false
+ }
+ }
+ // should not happen - to keep scala compiler happy
+ return true
+ }
+
+ def onReceive(callback: (Connection, Message) => Unit) {onReceiveCallback = callback}
+
+ // override def changeInterestForRead(): Boolean = ! isClosed
+ override def changeInterestForRead(): Boolean = true
+
+ override def changeInterestForWrite(): Boolean = {
+ throw new IllegalStateException("Unexpected invocation right now")
+ }
+
+ override def registerInterest() {
+ // Registering read too - does not really help in most cases, but for some
+ // it does - so let us keep it for now.
+ changeConnectionKeyInterest(SelectionKey.OP_READ)
+ }
+
+ override def unregisterInterest() {
+ changeConnectionKeyInterest(0)
+ }
+
+ // For read conn, always false.
+ override def resetForceReregister(): Boolean = false
+}
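
For context, the interest-registration scheme above is plain java.nio: a connection's current interest set is just a bitmask on its SelectionKey. A standalone sketch of what registerInterest()/unregisterInterest() amount to for a SendingConnection (the channel is never actually connected here; it only illustrates the bitmask):

    import java.nio.channels.{SelectionKey, Selector, SocketChannel}

    val selector = Selector.open()
    val channel = SocketChannel.open()
    channel.configureBlocking(false)
    val key = channel.register(selector, 0)

    // registerInterest(): write, plus read kept on purely to detect a remote close.
    key.interestOps(SelectionKey.OP_WRITE | SelectionKey.OP_READ)
    // unregisterInterest(): back to DEFAULT_INTEREST, i.e. OP_READ only.
    key.interestOps(SelectionKey.OP_READ)

    channel.close()
    selector.close()
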
http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/46eecd11/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala
new file mode 100644
index 0000000..9e2233c
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala
@@ -0,0 +1,720 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network
+
+import org.apache.spark._
+
+import java.nio._
+import java.nio.channels._
+import java.nio.channels.spi._
+import java.net._
+import java.util.concurrent.{LinkedBlockingDeque, TimeUnit, ThreadPoolExecutor}
+
+import scala.collection.mutable.HashSet
+import scala.collection.mutable.HashMap
+import scala.collection.mutable.SynchronizedMap
+import scala.collection.mutable.SynchronizedQueue
+import scala.collection.mutable.ArrayBuffer
+
+import akka.dispatch.{Await, Promise, ExecutionContext, Future}
+import akka.util.Duration
+import akka.util.duration._
+
+
+private[spark] class ConnectionManager(port: Int) extends Logging {
+
+ class MessageStatus(
+ val message: Message,
+ val connectionManagerId: ConnectionManagerId,
+ completionHandler: MessageStatus => Unit) {
+
+ var ackMessage: Option[Message] = None
+ var attempted = false
+ var acked = false
+
+ def markDone() { completionHandler(this) }
+ }
+
+ private val selector = SelectorProvider.provider.openSelector()
+
+ private val handleMessageExecutor = new ThreadPoolExecutor(
+ System.getProperty("spark.core.connection.handler.threads.min","20").toInt,
+ System.getProperty("spark.core.connection.handler.threads.max","60").toInt,
+ System.getProperty("spark.core.connection.handler.threads.keepalive","60").toInt, TimeUnit.SECONDS,
+ new LinkedBlockingDeque[Runnable]())
+
+ private val handleReadWriteExecutor = new ThreadPoolExecutor(
+ System.getProperty("spark.core.connection.io.threads.min","4").toInt,
+ System.getProperty("spark.core.connection.io.threads.max","32").toInt,
+ System.getProperty("spark.core.connection.io.threads.keepalive","60").toInt, TimeUnit.SECONDS,
+ new LinkedBlockingDeque[Runnable]())
+
+ // Use a different, smaller thread pool - infrequently used, with very short-lived tasks that should be executed asap
+ private val handleConnectExecutor = new ThreadPoolExecutor(
+ System.getProperty("spark.core.connection.connect.threads.min","1").toInt,
+ System.getProperty("spark.core.connection.connect.threads.max","8").toInt,
+ System.getProperty("spark.core.connection.connect.threads.keepalive","60").toInt, TimeUnit.SECONDS,
+ new LinkedBlockingDeque[Runnable]())
+
+ private val serverChannel = ServerSocketChannel.open()
+ private val connectionsByKey = new HashMap[SelectionKey, Connection] with SynchronizedMap[SelectionKey, Connection]
+ private val connectionsById = new HashMap[ConnectionManagerId, SendingConnection] with SynchronizedMap[ConnectionManagerId, SendingConnection]
+ private val messageStatuses = new HashMap[Int, MessageStatus]
+ private val keyInterestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)]
+ private val registerRequests = new SynchronizedQueue[SendingConnection]
+
+ implicit val futureExecContext = ExecutionContext.fromExecutor(Utils.newDaemonCachedThreadPool())
+
+ private var onReceiveCallback: (BufferMessage, ConnectionManagerId) => Option[Message]= null
+
+ serverChannel.configureBlocking(false)
+ serverChannel.socket.setReuseAddress(true)
+ serverChannel.socket.setReceiveBufferSize(256 * 1024)
+
+ serverChannel.socket.bind(new InetSocketAddress(port))
+ serverChannel.register(selector, SelectionKey.OP_ACCEPT)
+
+ val id = new ConnectionManagerId(Utils.localHostName, serverChannel.socket.getLocalPort)
+ logInfo("Bound socket to port " + serverChannel.socket.getLocalPort() + " with id = " + id)
+
+ private val selectorThread = new Thread("connection-manager-thread") {
+ override def run() = ConnectionManager.this.run()
+ }
+ selectorThread.setDaemon(true)
+ selectorThread.start()
+
+ private val writeRunnableStarted: HashSet[SelectionKey] = new HashSet[SelectionKey]()
+
+ private def triggerWrite(key: SelectionKey) {
+ val conn = connectionsByKey.getOrElse(key, null)
+ if (conn == null) return
+
+ writeRunnableStarted.synchronized {
+ // So that we do not trigger more write events while processing this one.
+ // The write method will re-register when done.
+ if (conn.changeInterestForWrite()) conn.unregisterInterest()
+ if (writeRunnableStarted.contains(key)) {
+ // key.interestOps(key.interestOps() & ~ SelectionKey.OP_WRITE)
+ return
+ }
+
+ writeRunnableStarted += key
+ }
+ handleReadWriteExecutor.execute(new Runnable {
+ override def run() {
+ var register: Boolean = false
+ try {
+ register = conn.write()
+ } finally {
+ writeRunnableStarted.synchronized {
+ writeRunnableStarted -= key
+ val needReregister = register || conn.resetForceReregister()
+ if (needReregister && conn.changeInterestForWrite()) {
+ conn.registerInterest()
+ }
+ }
+ }
+ }
+ } )
+ }
+
+ private val readRunnableStarted: HashSet[SelectionKey] = new HashSet[SelectionKey]()
+
+ private def triggerRead(key: SelectionKey) {
+ val conn = connectionsByKey.getOrElse(key, null)
+ if (conn == null) return
+
+ readRunnableStarted.synchronized {
+ // So that we do not trigger more read events while processing this one.
+ // The read method will re-register when done.
+ if (conn.changeInterestForRead()) conn.unregisterInterest()
+ if (readRunnableStarted.contains(key)) {
+ return
+ }
+
+ readRunnableStarted += key
+ }
+ handleReadWriteExecutor.execute(new Runnable {
+ override def run() {
+ var register: Boolean = false
+ try {
+ register = conn.read()
+ } finally {
+ readRunnableStarted.synchronized {
+ readRunnableStarted -= key
+ if (register && conn.changeInterestForRead()) {
+ conn.registerInterest()
+ }
+ }
+ }
+ }
+ } )
+ }
+
+ private def triggerConnect(key: SelectionKey) {
+ val conn = connectionsByKey.getOrElse(key, null).asInstanceOf[SendingConnection]
+ if (conn == null) return
+
+ // prevent other events from being triggered
+ // Since we are still trying to connect, we do not need to do the additional steps in triggerWrite
+ conn.changeConnectionKeyInterest(0)
+
+ handleConnectExecutor.execute(new Runnable {
+ override def run() {
+
+ var tries: Int = 10
+ while (tries >= 0) {
+ if (conn.finishConnect(false)) return
+ // Sleep ?
+ Thread.sleep(1)
+ tries -= 1
+ }
+
+ // Fall back to previous behavior: we should not really get here, since this method is
+ // triggered only when the channel becomes connectable. But at times the first
+ // finishConnect does not succeed, hence the loop above retrying a few times.
+ conn.finishConnect(true)
+ }
+ } )
+ }
+
+ // MUST be called within selector loop - else deadlock.
+ private def triggerForceCloseByException(key: SelectionKey, e: Exception) {
+ try {
+ key.interestOps(0)
+ } catch {
+ // ignore exceptions
+ case e: Exception => logDebug("Ignoring exception", e)
+ }
+
+ val conn = connectionsByKey.getOrElse(key, null)
+ if (conn == null) return
+
+ // Pushing to connect threadpool
+ handleConnectExecutor.execute(new Runnable {
+ override def run() {
+ try {
+ conn.callOnExceptionCallback(e)
+ } catch {
+ // ignore exceptions
+ case e: Exception => logDebug("Ignoring exception", e)
+ }
+ try {
+ conn.close()
+ } catch {
+ // ignore exceptions
+ case e: Exception => logDebug("Ignoring exception", e)
+ }
+ }
+ })
+ }
+
+
+ def run() {
+ try {
+ while (!selectorThread.isInterrupted) {
+ while (!registerRequests.isEmpty) {
+ val conn: SendingConnection = registerRequests.dequeue
+ addListeners(conn)
+ conn.connect()
+ addConnection(conn)
+ }
+
+ while (!keyInterestChangeRequests.isEmpty) {
+ val (key, ops) = keyInterestChangeRequests.dequeue
+
+ try {
+ if (key.isValid) {
+ val connection = connectionsByKey.getOrElse(key, null)
+ if (connection != null) {
+ val lastOps = key.interestOps()
+ key.interestOps(ops)
+
+ // hot loop - prevent materialization of string if trace not enabled.
+ if (isTraceEnabled()) {
+ def intToOpStr(op: Int): String = {
+ val opStrs = ArrayBuffer[String]()
+ if ((op & SelectionKey.OP_READ) != 0) opStrs += "READ"
+ if ((op & SelectionKey.OP_WRITE) != 0) opStrs += "WRITE"
+ if ((op & SelectionKey.OP_CONNECT) != 0) opStrs += "CONNECT"
+ if ((op & SelectionKey.OP_ACCEPT) != 0) opStrs += "ACCEPT"
+ if (opStrs.size > 0) opStrs.reduceLeft(_ + " | " + _) else " "
+ }
+
+ logTrace("Changed key for connection to [" + connection.getRemoteConnectionManagerId() +
+ "] changed from [" + intToOpStr(lastOps) + "] to [" + intToOpStr(ops) + "]")
+ }
+ }
+ } else {
+ logInfo("Key not valid ? " + key)
+ throw new CancelledKeyException()
+ }
+ } catch {
+ case e: CancelledKeyException => {
+ logInfo("key already cancelled ? " + key, e)
+ triggerForceCloseByException(key, e)
+ }
+ case e: Exception => {
+ logError("Exception processing key " + key, e)
+ triggerForceCloseByException(key, e)
+ }
+ }
+ }
+
+ val selectedKeysCount =
+ try {
+ selector.select()
+ } catch {
+ // Explicitly only dealing with CancelledKeyException here since other exceptions should be dealt with differently.
+ case e: CancelledKeyException => {
+ // Some keys within the selectors list are invalid/closed. clear them.
+ val allKeys = selector.keys().iterator()
+
+ while (allKeys.hasNext()) {
+ val key = allKeys.next()
+ try {
+ if (! key.isValid) {
+ logInfo("Key not valid ? " + key)
+ throw new CancelledKeyException()
+ }
+ } catch {
+ case e: CancelledKeyException => {
+ logInfo("key already cancelled ? " + key, e)
+ triggerForceCloseByException(key, e)
+ }
+ case e: Exception => {
+ logError("Exception processing key " + key, e)
+ triggerForceCloseByException(key, e)
+ }
+ }
+ }
+ }
+ 0
+ }
+
+ if (selectedKeysCount == 0) {
+ logDebug("Selector selected " + selectedKeysCount + " of " + selector.keys.size + " keys")
+ }
+ if (selectorThread.isInterrupted) {
+ logInfo("Selector thread was interrupted!")
+ return
+ }
+
+ if (0 != selectedKeysCount) {
+ val selectedKeys = selector.selectedKeys().iterator()
+ while (selectedKeys.hasNext()) {
+ val key = selectedKeys.next
+ selectedKeys.remove()
+ try {
+ if (key.isValid) {
+ if (key.isAcceptable) {
+ acceptConnection(key)
+ } else
+ if (key.isConnectable) {
+ triggerConnect(key)
+ } else
+ if (key.isReadable) {
+ triggerRead(key)
+ } else
+ if (key.isWritable) {
+ triggerWrite(key)
+ }
+ } else {
+ logInfo("Key not valid ? " + key)
+ throw new CancelledKeyException()
+ }
+ } catch {
+ // weird, but we saw this happening - even though key.isValid was true, key.isAcceptable would throw CancelledKeyException.
+ case e: CancelledKeyException => {
+ logInfo("key already cancelled ? " + key, e)
+ triggerForceCloseByException(key, e)
+ }
+ case e: Exception => {
+ logError("Exception processing key " + key, e)
+ triggerForceCloseByException(key, e)
+ }
+ }
+ }
+ }
+ }
+ } catch {
+ case e: Exception => logError("Error in select loop", e)
+ }
+ }
+
+ def acceptConnection(key: SelectionKey) {
+ val serverChannel = key.channel.asInstanceOf[ServerSocketChannel]
+
+ var newChannel = serverChannel.accept()
+
+ // Accept them all in a tight loop. Non-blocking accept with no processing should be fine.
+ while (newChannel != null) {
+ try {
+ val newConnection = new ReceivingConnection(newChannel, selector)
+ newConnection.onReceive(receiveMessage)
+ addListeners(newConnection)
+ addConnection(newConnection)
+ logInfo("Accepted connection from [" + newConnection.remoteAddress.getAddress + "]")
+ } catch {
+ // might happen in case of issues with registering with selector
+ case e: Exception => logError("Error in accept loop", e)
+ }
+
+ newChannel = serverChannel.accept()
+ }
+ }
+
+ private def addListeners(connection: Connection) {
+ connection.onKeyInterestChange(changeConnectionKeyInterest)
+ connection.onException(handleConnectionError)
+ connection.onClose(removeConnection)
+ }
+
+ def addConnection(connection: Connection) {
+ connectionsByKey += ((connection.key, connection))
+ }
+
+ def removeConnection(connection: Connection) {
+ connectionsByKey -= connection.key
+
+ try {
+ if (connection.isInstanceOf[SendingConnection]) {
+ val sendingConnection = connection.asInstanceOf[SendingConnection]
+ val sendingConnectionManagerId = sendingConnection.getRemoteConnectionManagerId()
+ logInfo("Removing SendingConnection to " + sendingConnectionManagerId)
+
+ connectionsById -= sendingConnectionManagerId
+
+ messageStatuses.synchronized {
+ messageStatuses
+ .values.filter(_.connectionManagerId == sendingConnectionManagerId).foreach(status => {
+ logInfo("Notifying " + status)
+ status.synchronized {
+ status.attempted = true
+ status.acked = false
+ status.markDone()
+ }
+ })
+
+ messageStatuses.retain((i, status) => {
+ status.connectionManagerId != sendingConnectionManagerId
+ })
+ }
+ } else if (connection.isInstanceOf[ReceivingConnection]) {
+ val receivingConnection = connection.asInstanceOf[ReceivingConnection]
+ val remoteConnectionManagerId = receivingConnection.getRemoteConnectionManagerId()
+ logInfo("Removing ReceivingConnection to " + remoteConnectionManagerId)
+
+ val sendingConnectionOpt = connectionsById.get(remoteConnectionManagerId)
+ if (! sendingConnectionOpt.isDefined) {
+ logError("Corresponding SendingConnectionManagerId not found")
+ return
+ }
+
+ val sendingConnection = sendingConnectionOpt.get
+ connectionsById -= remoteConnectionManagerId
+ sendingConnection.close()
+
+ val sendingConnectionManagerId = sendingConnection.getRemoteConnectionManagerId()
+
+ assert (sendingConnectionManagerId == remoteConnectionManagerId)
+
+ messageStatuses.synchronized {
+ for (s <- messageStatuses.values if s.connectionManagerId == sendingConnectionManagerId) {
+ logInfo("Notifying " + s)
+ s.synchronized {
+ s.attempted = true
+ s.acked = false
+ s.markDone()
+ }
+ }
+
+ messageStatuses.retain((i, status) => {
+ status.connectionManagerId != sendingConnectionManagerId
+ })
+ }
+ }
+ } finally {
+ // So that the selection keys can be removed.
+ wakeupSelector()
+ }
+ }
+
+ def handleConnectionError(connection: Connection, e: Exception) {
+ logInfo("Handling connection error on connection to " + connection.getRemoteConnectionManagerId())
+ removeConnection(connection)
+ }
+
+ def changeConnectionKeyInterest(connection: Connection, ops: Int) {
+ keyInterestChangeRequests += ((connection.key, ops))
+ // so that registrations happen!
+ wakeupSelector()
+ }
+
+ def receiveMessage(connection: Connection, message: Message) {
+ val connectionManagerId = ConnectionManagerId.fromSocketAddress(message.senderAddress)
+ logDebug("Received [" + message + "] from [" + connectionManagerId + "]")
+ val runnable = new Runnable() {
+ val creationTime = System.currentTimeMillis
+ def run() {
+ logDebug("Handler thread delay is " + (System.currentTimeMillis - creationTime) + " ms")
+ handleMessage(connectionManagerId, message)
+ logDebug("Handling delay is " + (System.currentTimeMillis - creationTime) + " ms")
+ }
+ }
+ handleMessageExecutor.execute(runnable)
+ /*handleMessage(connection, message)*/
+ }
+
+ private def handleMessage(connectionManagerId: ConnectionManagerId, message: Message) {
+ logDebug("Handling [" + message + "] from [" + connectionManagerId + "]")
+ message match {
+ case bufferMessage: BufferMessage => {
+ if (bufferMessage.hasAckId) {
+ val sentMessageStatus = messageStatuses.synchronized {
+ messageStatuses.get(bufferMessage.ackId) match {
+ case Some(status) => {
+ messageStatuses -= bufferMessage.ackId
+ status
+ }
+ case None => {
+ throw new Exception("Could not find reference for received ack message " + message.id)
+ null
+ }
+ }
+ }
+ sentMessageStatus.synchronized {
+ sentMessageStatus.ackMessage = Some(message)
+ sentMessageStatus.attempted = true
+ sentMessageStatus.acked = true
+ sentMessageStatus.markDone()
+ }
+ } else {
+ val ackMessage = if (onReceiveCallback != null) {
+ logDebug("Calling back")
+ onReceiveCallback(bufferMessage, connectionManagerId)
+ } else {
+ logDebug("Not calling back as callback is null")
+ None
+ }
+
+ if (ackMessage.isDefined) {
+ if (!ackMessage.get.isInstanceOf[BufferMessage]) {
+ logDebug("Response to " + bufferMessage + " is not a buffer message, it is of type " + ackMessage.get.getClass())
+ } else if (!ackMessage.get.asInstanceOf[BufferMessage].hasAckId) {
+ logDebug("Response to " + bufferMessage + " does not have ack id set")
+ ackMessage.get.asInstanceOf[BufferMessage].ackId = bufferMessage.id
+ }
+ }
+
+ sendMessage(connectionManagerId, ackMessage.getOrElse {
+ Message.createBufferMessage(bufferMessage.id)
+ })
+ }
+ }
+ case _ => throw new Exception("Unknown type message received")
+ }
+ }
+
+ private def sendMessage(connectionManagerId: ConnectionManagerId, message: Message) {
+ def startNewConnection(): SendingConnection = {
+ val inetSocketAddress = new InetSocketAddress(connectionManagerId.host, connectionManagerId.port)
+ val newConnection = new SendingConnection(inetSocketAddress, selector, connectionManagerId)
+ registerRequests.enqueue(newConnection)
+
+ newConnection
+ }
+ // I removed the lookupKey stuff as part of merge ... should I re-add it ? We did not find it useful in our test-env ...
+ // If we do re-add it, we should consistently use it everywhere I guess ?
+ val connection = connectionsById.getOrElseUpdate(connectionManagerId, startNewConnection())
+ message.senderAddress = id.toSocketAddress()
+ logDebug("Sending [" + message + "] to [" + connectionManagerId + "]")
+ connection.send(message)
+
+ wakeupSelector()
+ }
+
+ private def wakeupSelector() {
+ selector.wakeup()
+ }
+
+ def sendMessageReliably(connectionManagerId: ConnectionManagerId, message: Message)
+ : Future[Option[Message]] = {
+ val promise = Promise[Option[Message]]
+ val status = new MessageStatus(message, connectionManagerId, s => promise.success(s.ackMessage))
+ messageStatuses.synchronized {
+ messageStatuses += ((message.id, status))
+ }
+ sendMessage(connectionManagerId, message)
+ promise.future
+ }
+
+ def sendMessageReliablySync(connectionManagerId: ConnectionManagerId, message: Message): Option[Message] = {
+ Await.result(sendMessageReliably(connectionManagerId, message), Duration.Inf)
+ }
+
+ def onReceiveMessage(callback: (Message, ConnectionManagerId) => Option[Message]) {
+ onReceiveCallback = callback
+ }
+
+ def stop() {
+ selectorThread.interrupt()
+ selectorThread.join()
+ selector.close()
+ val connections = connectionsByKey.values
+ connections.foreach(_.close())
+ if (connectionsByKey.size != 0) {
+ logWarning("Not all connections were cleaned up")
+ }
+ handleMessageExecutor.shutdown()
+ handleReadWriteExecutor.shutdown()
+ handleConnectExecutor.shutdown()
+ logInfo("ConnectionManager stopped")
+ }
+}
+
+
+private[spark] object ConnectionManager {
+
+ def main(args: Array[String]) {
+ val manager = new ConnectionManager(9999)
+ manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
+ println("Received [" + msg + "] from [" + id + "]")
+ None
+ })
+
+ /*testSequentialSending(manager)*/
+ /*System.gc()*/
+
+ /*testParallelSending(manager)*/
+ /*System.gc()*/
+
+ /*testParallelDecreasingSending(manager)*/
+ /*System.gc()*/
+
+ testContinuousSending(manager)
+ System.gc()
+ }
+
+ def testSequentialSending(manager: ConnectionManager) {
+ println("--------------------------")
+ println("Sequential Sending")
+ println("--------------------------")
+ val size = 10 * 1024 * 1024
+ val count = 10
+
+ val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
+ buffer.flip
+
+ (0 until count).map(i => {
+ val bufferMessage = Message.createBufferMessage(buffer.duplicate)
+ manager.sendMessageReliablySync(manager.id, bufferMessage)
+ })
+ println("--------------------------")
+ println()
+ }
+
+ def testParallelSending(manager: ConnectionManager) {
+ println("--------------------------")
+ println("Parallel Sending")
+ println("--------------------------")
+ val size = 10 * 1024 * 1024
+ val count = 10
+
+ val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
+ buffer.flip
+
+ val startTime = System.currentTimeMillis
+ (0 until count).map(i => {
+ val bufferMessage = Message.createBufferMessage(buffer.duplicate)
+ manager.sendMessageReliably(manager.id, bufferMessage)
+ }).foreach(f => {
+ val g = Await.result(f, 1 second)
+ if (!g.isDefined) println("Failed")
+ })
+ val finishTime = System.currentTimeMillis
+
+ val mb = size * count / 1024.0 / 1024.0
+ val ms = finishTime - startTime
+ val tput = mb * 1000.0 / ms
+ println("--------------------------")
+ println("Started at " + startTime + ", finished at " + finishTime)
+ println("Sent " + count + " messages of size " + size + " in " + ms + " ms (" + tput + " MB/s)")
+ println("--------------------------")
+ println()
+ }
+
+ def testParallelDecreasingSending(manager: ConnectionManager) {
+ println("--------------------------")
+ println("Parallel Decreasing Sending")
+ println("--------------------------")
+ val size = 10 * 1024 * 1024
+ val count = 10
+ val buffers = Array.tabulate(count)(i => ByteBuffer.allocate(size * (i + 1)).put(Array.tabulate[Byte](size * (i + 1))(x => x.toByte)))
+ buffers.foreach(_.flip)
+ val mb = buffers.map(_.remaining).reduceLeft(_ + _) / 1024.0 / 1024.0
+
+ val startTime = System.currentTimeMillis
+ (0 until count).map(i => {
+ val bufferMessage = Message.createBufferMessage(buffers(count - 1 - i).duplicate)
+ manager.sendMessageReliably(manager.id, bufferMessage)
+ }).foreach(f => {
+ val g = Await.result(f, 1 second)
+ if (!g.isDefined) println("Failed")
+ })
+ val finishTime = System.currentTimeMillis
+
+ val ms = finishTime - startTime
+ val tput = mb * 1000.0 / ms
+ println("--------------------------")
+ /*println("Started at " + startTime + ", finished at " + finishTime) */
+ println("Sent " + mb + " MB in " + ms + " ms (" + tput + " MB/s)")
+ println("--------------------------")
+ println()
+ }
+
+ def testContinuousSending(manager: ConnectionManager) {
+ println("--------------------------")
+ println("Continuous Sending")
+ println("--------------------------")
+ val size = 10 * 1024 * 1024
+ val count = 10
+
+ val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
+ buffer.flip
+
+ val startTime = System.currentTimeMillis
+ while(true) {
+ (0 until count).map(i => {
+ val bufferMessage = Message.createBufferMessage(buffer.duplicate)
+ manager.sendMessageReliably(manager.id, bufferMessage)
+ }).foreach(f => {
+ val g = Await.result(f, 1 second)
+ if (!g.isDefined) println("Failed")
+ })
+ val finishTime = System.currentTimeMillis
+ Thread.sleep(1000)
+ val mb = size * count / 1024.0 / 1024.0
+ val ms = finishTime - startTime
+ val tput = mb * 1000.0 / ms
+ println("Sent " + mb + " MB in " + ms + " ms (" + tput + " MB/s)")
+ println("--------------------------")
+ println()
+ }
+ }
+}
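
For context, the request/ack flow implemented above can be exercised end to end with two managers on one host. A rough sketch (ports and payload size are arbitrary; when the receive callback returns None, handleMessage replies with an empty ack automatically):

    import java.nio.ByteBuffer
    import org.apache.spark.network._

    val server = new ConnectionManager(9999)
    val client = new ConnectionManager(0)   // port 0 binds an ephemeral port

    server.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
      println("server received [" + msg + "] from [" + id + "]")
      None                                  // no explicit reply; an empty ack is sent
    })

    val ack = client.sendMessageReliablySync(server.id,
      Message.createBufferMessage(ByteBuffer.allocate(16)))
    println("acked: " + ack.isDefined)

    client.stop()
    server.stop()
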
http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/46eecd11/core/src/main/scala/org/apache/spark/network/ConnectionManagerId.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManagerId.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManagerId.scala
new file mode 100644
index 0000000..0839c01
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/ConnectionManagerId.scala
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network
+
+import java.net.InetSocketAddress
+
+import org.apache.spark.Utils
+
+
+private[spark] case class ConnectionManagerId(host: String, port: Int) {
+ // DEBUG code
+ Utils.checkHost(host)
+ assert (port > 0)
+
+ def toSocketAddress() = new InetSocketAddress(host, port)
+}
+
+
+private[spark] object ConnectionManagerId {
+ def fromSocketAddress(socketAddress: InetSocketAddress): ConnectionManagerId = {
+ new ConnectionManagerId(socketAddress.getHostName(), socketAddress.getPort())
+ }
+}
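
One subtlety in the class above: fromSocketAddress goes through getHostName, so an id is not guaranteed to round-trip textually (e.g. a "127.0.0.1" host may come back as "localhost"); the port, however, is preserved. A small sketch:

    import org.apache.spark.network.ConnectionManagerId

    val id = ConnectionManagerId("localhost", 9999)
    val roundTripped = ConnectionManagerId.fromSocketAddress(id.toSocketAddress())
    assert(roundTripped.port == 9999)   // host string may differ after resolution
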
http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/46eecd11/core/src/main/scala/org/apache/spark/network/ConnectionManagerTest.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManagerTest.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManagerTest.scala
new file mode 100644
index 0000000..8d9ad96
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/ConnectionManagerTest.scala
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network
+
+import org.apache.spark._
+import org.apache.spark.SparkContext._
+
+import scala.io.Source
+
+import java.nio.ByteBuffer
+import java.net.InetAddress
+
+import akka.dispatch.Await
+import akka.util.duration._
+
+private[spark] object ConnectionManagerTest extends Logging {
+ def main(args: Array[String]) {
+ // <mesos cluster> - the master URL
+ // <slaves file> - a file listing the slaves to run connectionTest on
+ // [num of tasks] - the number of parallel tasks to initiate; default is the number of slave hosts
+ // [size of msg in MB (integer)] - the size of messages to be sent in each task; default is 10
+ // [count] - how many times to run the test; default is 3
+ // [await time in seconds] - how long to await each result; default is 600
+ if (args.length < 2) {
+ println("Usage: ConnectionManagerTest <mesos cluster> <slaves file> [num of tasks] [size of msg in MB (integer)] [count] [await time in seconds]")
+ System.exit(1)
+ }
+
+ if (args(0).startsWith("local")) {
+ println("This runs only on a mesos cluster")
+ }
+
+ val sc = new SparkContext(args(0), "ConnectionManagerTest")
+ val slavesFile = Source.fromFile(args(1))
+ val slaves = slavesFile.mkString.split("\n")
+ slavesFile.close()
+
+ /*println("Slaves")*/
+ /*slaves.foreach(println)*/
+ val tasknum = if (args.length > 2) args(2).toInt else slaves.length
+ val size = (if (args.length > 3) args(3).toInt else 10) * 1024 * 1024
+ val count = if (args.length > 4) args(4).toInt else 3
+ val awaitTime = (if (args.length > 5) args(5).toInt else 600).second
+ println("Running " + count + " rounds of test: parallel tasks = " + tasknum + ", msg size = " + size / 1024 / 1024 + " MB, awaitTime = " + awaitTime)
+ val slaveConnManagerIds = sc.parallelize(0 until tasknum, tasknum).map(
+ i => SparkEnv.get.connectionManager.id).collect()
+ println("\nSlave ConnectionManagerIds")
+ slaveConnManagerIds.foreach(println)
+ println
+
+ (0 until count).foreach(i => {
+ val resultStrs = sc.parallelize(0 until tasknum, tasknum).map(i => {
+ val connManager = SparkEnv.get.connectionManager
+ val thisConnManagerId = connManager.id
+ connManager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
+ logInfo("Received [" + msg + "] from [" + id + "]")
+ None
+ })
+
+ val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
+ buffer.flip
+
+ val startTime = System.currentTimeMillis
+ val futures = slaveConnManagerIds.filter(_ != thisConnManagerId).map(slaveConnManagerId => {
+ val bufferMessage = Message.createBufferMessage(buffer.duplicate)
+ logInfo("Sending [" + bufferMessage + "] to [" + slaveConnManagerId + "]")
+ connManager.sendMessageReliably(slaveConnManagerId, bufferMessage)
+ })
+ val results = futures.map(f => Await.result(f, awaitTime))
+ val finishTime = System.currentTimeMillis
+ Thread.sleep(5000)
+
+ val mb = size * results.size / 1024.0 / 1024.0
+ val ms = finishTime - startTime
+ val resultStr = thisConnManagerId + " Sent " + mb + " MB in " + ms + " ms at " + (mb / ms * 1000.0) + " MB/s"
+ logInfo(resultStr)
+ resultStr
+ }).collect()
+
+ println("---------------------")
+ println("Run " + i)
+ resultStrs.foreach(println)
+ println("---------------------")
+ })
+ }
+}
+
http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/46eecd11/core/src/main/scala/org/apache/spark/network/Message.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/network/Message.scala b/core/src/main/scala/org/apache/spark/network/Message.scala
new file mode 100644
index 0000000..f2ecc6d
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/Message.scala
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network
+
+import java.nio.ByteBuffer
+import java.net.InetSocketAddress
+
+import scala.collection.mutable.ArrayBuffer
+
+
+private[spark] abstract class Message(val typ: Long, val id: Int) {
+ var senderAddress: InetSocketAddress = null
+ var started = false
+ var startTime = -1L
+ var finishTime = -1L
+
+ def size: Int
+
+ def getChunkForSending(maxChunkSize: Int): Option[MessageChunk]
+
+ def getChunkForReceiving(chunkSize: Int): Option[MessageChunk]
+
+ def timeTaken(): String = (finishTime - startTime).toString + " ms"
+
+ override def toString = this.getClass.getSimpleName + "(id = " + id + ", size = " + size + ")"
+}
+
+
+private[spark] object Message {
+ val BUFFER_MESSAGE = 1111111111L
+
+ var lastId = 1
+
+ def getNewId() = synchronized {
+ lastId += 1
+ if (lastId == 0) {
+ lastId += 1
+ }
+ lastId
+ }
+
+ def createBufferMessage(dataBuffers: Seq[ByteBuffer], ackId: Int): BufferMessage = {
+ if (dataBuffers == null) {
+ new BufferMessage(getNewId(), new ArrayBuffer[ByteBuffer], ackId)
+ } else if (dataBuffers.exists(_ == null)) {
+ throw new Exception("Attempting to create buffer message with null buffer")
+ } else {
+ new BufferMessage(getNewId(), new ArrayBuffer[ByteBuffer] ++= dataBuffers, ackId)
+ }
+ }
+
+ def createBufferMessage(dataBuffers: Seq[ByteBuffer]): BufferMessage =
+ createBufferMessage(dataBuffers, 0)
+
+ def createBufferMessage(dataBuffer: ByteBuffer, ackId: Int): BufferMessage = {
+ if (dataBuffer == null) {
+ createBufferMessage(Array(ByteBuffer.allocate(0)), ackId)
+ } else {
+ createBufferMessage(Array(dataBuffer), ackId)
+ }
+ }
+
+ def createBufferMessage(dataBuffer: ByteBuffer): BufferMessage =
+ createBufferMessage(dataBuffer, 0)
+
+ def createBufferMessage(ackId: Int): BufferMessage = {
+ createBufferMessage(new Array[ByteBuffer](0), ackId)
+ }
+
+ def create(header: MessageChunkHeader): Message = {
+ val newMessage: Message = header.typ match {
+ case BUFFER_MESSAGE => new BufferMessage(header.id,
+ ArrayBuffer(ByteBuffer.allocate(header.totalSize)), header.other)
+ case _ => throw new Exception("Unknown message type " + header.typ)
+ }
+ newMessage.senderAddress = header.address
+ newMessage
+ }
+}
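
The factory methods above are the intended entry points. A short usage sketch (illustrative only; it would need to compile inside org.apache.spark.network, since Message is private[spark]):

import java.nio.ByteBuffer

object MessageExample {
  def main(args: Array[String]) {
    // Wrap a payload in a BufferMessage; ackId = 0 means "not an ack".
    val payload = ByteBuffer.wrap("hello".getBytes)
    val msg = Message.createBufferMessage(payload)
    // A reply echoes the original message's id as its ackId:
    val reply = Message.createBufferMessage(ByteBuffer.wrap("ok".getBytes), msg.id)
    println(msg + " -> " + reply)
  }
}
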
http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/46eecd11/core/src/main/scala/org/apache/spark/network/MessageChunk.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/network/MessageChunk.scala b/core/src/main/scala/org/apache/spark/network/MessageChunk.scala
new file mode 100644
index 0000000..e0fe57b
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/MessageChunk.scala
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network
+
+import java.nio.ByteBuffer
+
+import scala.collection.mutable.ArrayBuffer
+
+
+private[network]
+class MessageChunk(val header: MessageChunkHeader, val buffer: ByteBuffer) {
+
+ val size = if (buffer == null) 0 else buffer.remaining
+
+ lazy val buffers = {
+ val ab = new ArrayBuffer[ByteBuffer]()
+ ab += header.buffer
+ if (buffer != null) {
+ ab += buffer
+ }
+ ab
+ }
+
+ override def toString = {
+ this.getClass.getSimpleName + " (id = " + header.id + ", size = " + size + ")"
+ }
+}
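
A chunk's `buffers` list puts the header buffer first and the payload second, which suits a gathering write. A sketch of draining one chunk to a channel (not part of this patch; assumes it lives in the same package, since MessageChunk is private[network]):

import java.nio.channels.SocketChannel

object ChunkWriter {
  // Drain one chunk with a gathering write: header buffer first, then payload.
  def writeChunk(channel: SocketChannel, chunk: MessageChunk) {
    val bufs = chunk.buffers.toArray
    while (bufs.exists(_.hasRemaining)) {
      channel.write(bufs)
    }
  }
}
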
http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/46eecd11/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala b/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala
new file mode 100644
index 0000000..235fbc3
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network
+
+import java.net.InetAddress
+import java.net.InetSocketAddress
+import java.nio.ByteBuffer
+
+
+private[spark] class MessageChunkHeader(
+ val typ: Long,
+ val id: Int,
+ val totalSize: Int,
+ val chunkSize: Int,
+ val other: Int,
+ val address: InetSocketAddress) {
+ lazy val buffer = {
+ // Only the raw IP is written; the receiving side does a reverse lookup of the
+ // hostname at 'use' time. Refer to network.Connection.
+ val ip = address.getAddress.getAddress()
+ val port = address.getPort()
+ ByteBuffer.
+ allocate(MessageChunkHeader.HEADER_SIZE).
+ putLong(typ).
+ putInt(id).
+ putInt(totalSize).
+ putInt(chunkSize).
+ putInt(other).
+ putInt(ip.size).
+ put(ip).
+ putInt(port).
+ position(MessageChunkHeader.HEADER_SIZE).
+ flip.asInstanceOf[ByteBuffer]
+ }
+
+ override def toString = "" + this.getClass.getSimpleName + ":" + id + " of type " + typ +
+ " and sizes " + totalSize + " / " + chunkSize + " bytes"
+}
+
+
+private[spark] object MessageChunkHeader {
+ val HEADER_SIZE = 40
+
+ def create(buffer: ByteBuffer): MessageChunkHeader = {
+ if (buffer.remaining != HEADER_SIZE) {
+ throw new IllegalArgumentException("Cannot convert buffer data to MessageChunkHeader: " +
+ "expected " + HEADER_SIZE + " bytes, got " + buffer.remaining)
+ }
+ val typ = buffer.getLong()
+ val id = buffer.getInt()
+ val totalSize = buffer.getInt()
+ val chunkSize = buffer.getInt()
+ val other = buffer.getInt()
+ val ipSize = buffer.getInt()
+ val ipBytes = new Array[Byte](ipSize)
+ buffer.get(ipBytes)
+ val ip = InetAddress.getByAddress(ipBytes)
+ val port = buffer.getInt()
+ new MessageChunkHeader(typ, id, totalSize, chunkSize, other, new InetSocketAddress(ip, port))
+ }
+}
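
The layout is fixed: 8 (typ) + 4 (id) + 4 (totalSize) + 4 (chunkSize) + 4 (other) + 4 (ip length) + 4 (IPv4 address) + 4 (port) = 36 bytes, with the position advanced to HEADER_SIZE before the flip. A round-trip sketch (illustrative; must sit in org.apache.spark.network):

import java.net.InetSocketAddress

object HeaderRoundTrip {
  def main(args: Array[String]) {
    val addr = new InetSocketAddress("127.0.0.1", 9999)
    val header = new MessageChunkHeader(Message.BUFFER_MESSAGE, 1, 1024, 512, 0, addr)
    // Serialize, then parse the 40-byte buffer back into a header.
    val parsed = MessageChunkHeader.create(header.buffer)
    assert(parsed.totalSize == 1024 && parsed.chunkSize == 512)
    // Note: a 16-byte IPv6 address would need 48 bytes and overflow HEADER_SIZE = 40.
    println(parsed)
  }
}
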
http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/46eecd11/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala b/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala
new file mode 100644
index 0000000..7817151
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network
+
+import java.nio.ByteBuffer
+
+private[spark] object ReceiverTest {
+
+ def main(args: Array[String]) {
+ val manager = new ConnectionManager(9999)
+ println("Started connection manager with id = " + manager.id)
+
+ manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
+ /*println("Received [" + msg + "] from [" + id + "] at " + System.currentTimeMillis)*/
+ val buffer = ByteBuffer.wrap("response".getBytes())
+ Some(Message.createBufferMessage(buffer, msg.id))
+ })
+ Thread.currentThread.join()
+ }
+}
+
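
The `onReceiveMessage` callback returns `Some(reply)` to respond or `None` to stay quiet. A drop-in alternative for the registration above, replying with the payload size instead of a fixed string (sketch):

manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
  val reply = "got " + msg.size + " bytes from " + id
  Some(Message.createBufferMessage(ByteBuffer.wrap(reply.getBytes), msg.id))
})
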
http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/46eecd11/core/src/main/scala/org/apache/spark/network/SenderTest.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/network/SenderTest.scala b/core/src/main/scala/org/apache/spark/network/SenderTest.scala
new file mode 100644
index 0000000..7775749
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/SenderTest.scala
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network
+
+import java.nio.ByteBuffer
+
+private[spark] object SenderTest {
+
+ def main(args: Array[String]) {
+
+ if (args.length < 2) {
+ System.err.println("Usage: SenderTest <target host> <target port>")
+ System.exit(1)
+ }
+
+ val targetHost = args(0)
+ val targetPort = args(1).toInt
+ val targetConnectionManagerId = new ConnectionManagerId(targetHost, targetPort)
+
+ val manager = new ConnectionManager(0)
+ println("Started connection manager with id = " + manager.id)
+
+ manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
+ println("Received [" + msg + "] from [" + id + "]")
+ None
+ })
+
+ val size = 100 * 1024 * 1024
+ val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
+ buffer.flip
+
+
+ val count = 100
+ (0 until count).foreach(i => {
+ val dataMessage = Message.createBufferMessage(buffer.duplicate)
+ val startTime = System.currentTimeMillis
+ /*println("Started timer at " + startTime)*/
+ val responseStr = manager.sendMessageReliablySync(targetConnectionManagerId, dataMessage) match {
+ case Some(response) =>
+ val buffer = response.asInstanceOf[BufferMessage].buffers(0)
+ new String(buffer.array)
+ case None => "none"
+ }
+ val finishTime = System.currentTimeMillis
+ val mb = size / 1024.0 / 1024.0
+ val ms = finishTime - startTime
+ val resultStr = "Sent " + mb + " MB to " + targetHost + " in " + ms + " ms (" +
+ (mb / ms * 1000.0).toInt + " MB/s) | Response = " + responseStr
+ println(resultStr)
+ })
+ }
+}
+
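
`sendMessageReliablySync` blocks and returns `Option[Message]`. A slightly more defensive decode of a text reply than the match above (sketch; assumes an array-backed reply buffer, as `ByteBuffer.wrap` produces):

object ReplyDecoder {
  def decodeReply(response: Option[Message]): String = response match {
    case Some(m: BufferMessage) if m.buffers.nonEmpty =>
      val buf = m.buffers(0)
      // Respect the buffer's offset and remaining length rather than dumping the whole array.
      new String(buf.array, buf.arrayOffset + buf.position, buf.remaining)
    case _ => "none"
  }
}
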
http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/46eecd11/core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala b/core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala
new file mode 100644
index 0000000..3c29700
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.netty
+
+import io.netty.buffer._
+
+import org.apache.spark.Logging
+
+private[spark] class FileHeader (
+ val fileLen: Int,
+ val blockId: String) extends Logging {
+
+ lazy val buffer = {
+ val buf = Unpooled.buffer()
+ buf.capacity(FileHeader.HEADER_SIZE)
+ buf.writeInt(fileLen)
+ buf.writeInt(blockId.length)
+ blockId.foreach((x: Char) => buf.writeByte(x))
+ // Pad the rest of the header with zeros (a block id longer than 32 bytes cannot fit)
+ if (FileHeader.HEADER_SIZE - buf.readableBytes >= 0) {
+ buf.writeZero(FileHeader.HEADER_SIZE - buf.readableBytes)
+ } else {
+ logInfo("too long header")
+ throw new Exception("too long header " + buf.readableBytes)
+ }
+ buf
+ }
+
+}
+
+private[spark] object FileHeader {
+
+ val HEADER_SIZE = 40
+
+ def getFileLenOffset = 0
+ def getFileLenSize = Integer.SIZE / 8
+
+ def create(buf: ByteBuf): FileHeader = {
+ val length = buf.readInt
+ val idLength = buf.readInt
+ val idBuilder = new StringBuilder(idLength)
+ for (i <- 1 to idLength) {
+ idBuilder += buf.readByte().asInstanceOf[Char]
+ }
+ val blockId = idBuilder.toString()
+ new FileHeader(length, blockId)
+ }
+
+
+ def main(args: Array[String]) {
+ val header = new FileHeader(25, "block_0")
+ val buf = header.buffer
+ val newHeader = FileHeader.create(buf)
+ println("id = " + newHeader.blockId + ", size = " + newHeader.fileLen)
+ }
+}
+
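
Given the two 4-byte ints up front, a block id can occupy at most HEADER_SIZE - 8 = 32 bytes before the `buffer` builder throws. A quick check (sketch; FileHeader is private[spark], so this belongs in the netty package):

object HeaderBudget {
  def main(args: Array[String]) {
    // 40-byte header = 4 (fileLen) + 4 (id length) + id bytes + zero padding.
    val maxIdLen = FileHeader.HEADER_SIZE - 8 // 32
    val buf = new FileHeader(100, "a" * maxIdLen).buffer // exactly fills the header
    println("readable = " + buf.readableBytes)
    // new FileHeader(100, "a" * (maxIdLen + 1)).buffer would throw "too long header"
  }
}
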
http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/46eecd11/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala b/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala
new file mode 100644
index 0000000..9493ccf
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.netty
+
+import java.util.concurrent.Executors
+
+import io.netty.buffer.ByteBuf
+import io.netty.channel.ChannelHandlerContext
+import io.netty.util.CharsetUtil
+
+import org.apache.spark.Logging
+import org.apache.spark.network.ConnectionManagerId
+
+import scala.collection.JavaConverters._
+
+
+private[spark] class ShuffleCopier extends Logging {
+
+ def getBlock(host: String, port: Int, blockId: String,
+ resultCollectCallback: (String, Long, ByteBuf) => Unit) {
+
+ val handler = new ShuffleCopier.ShuffleClientHandler(resultCollectCallback)
+ val connectTimeout = System.getProperty("spark.shuffle.netty.connect.timeout", "60000").toInt
+ val fc = new FileClient(handler, connectTimeout)
+
+ try {
+ fc.init()
+ fc.connect(host, port)
+ fc.sendRequest(blockId)
+ fc.waitForClose()
+ fc.close()
+ } catch {
+ // Handle any socket-related exceptions in FileClient
+ case e: Exception => {
+ logError("Shuffle copy of block " + blockId + " from " + host + ":" + port + " failed", e)
+ handler.handleError(blockId)
+ }
+ }
+ }
+
+ def getBlock(cmId: ConnectionManagerId, blockId: String,
+ resultCollectCallback: (String, Long, ByteBuf) => Unit) {
+ getBlock(cmId.host, cmId.port, blockId, resultCollectCallback)
+ }
+
+ def getBlocks(cmId: ConnectionManagerId,
+ blocks: Seq[(String, Long)],
+ resultCollectCallback: (String, Long, ByteBuf) => Unit) {
+
+ for ((blockId, size) <- blocks) {
+ getBlock(cmId, blockId, resultCollectCallback)
+ }
+ }
+}
+
+
+private[spark] object ShuffleCopier extends Logging {
+
+ private class ShuffleClientHandler(resultCollectCallBack: (String, Long, ByteBuf) => Unit)
+ extends FileClientHandler with Logging {
+
+ override def handle(ctx: ChannelHandlerContext, in: ByteBuf, header: FileHeader) {
+ logDebug("Received Block: " + header.blockId + " (" + header.fileLen + "B)");
+ resultCollectCallBack(header.blockId, header.fileLen.toLong, in.readBytes(header.fileLen))
+ }
+
+ override def handleError(blockId: String) {
+ if (!isComplete) {
+ resultCollectCallBack(blockId, -1, null)
+ }
+ }
+ }
+
+ def echoResultCollectCallBack(blockId: String, size: Long, content: ByteBuf) {
+ if (size != -1) {
+ logInfo("File: " + blockId + " content is : \" " + content.toString(CharsetUtil.UTF_8) + "\"")
+ }
+ }
+
+ def main(args: Array[String]) {
+ if (args.length < 3) {
+ System.err.println("Usage: ShuffleCopier <host> <port> <shuffle_block_id> <threads>")
+ System.exit(1)
+ }
+ val host = args(0)
+ val port = args(1).toInt
+ val file = args(2)
+ val threads = if (args.length > 3) args(3).toInt else 10
+
+ val copiers = Executors.newFixedThreadPool(threads) // size the pool to the requested thread count
+ val tasks = (for (i <- Range(0, threads)) yield {
+ Executors.callable(new Runnable() {
+ def run() {
+ val copier = new ShuffleCopier()
+ copier.getBlock(host, port, file, echoResultCollectCallBack)
+ }
+ })
+ }).asJava
+ copiers.invokeAll(tasks)
+ copiers.shutdown()
+ System.exit(0)
+ }
+}
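
A `resultCollectCallback` is a plain (blockId, size, buffer) consumer, with size == -1 signalling a failed fetch (see `handleError` above). A sketch that gathers fetched blocks into a map, not part of this patch:

import scala.collection.mutable
import io.netty.buffer.ByteBuf

object BlockCollector {
  val results = new mutable.HashMap[String, Array[Byte]]

  def collect(blockId: String, size: Long, content: ByteBuf) {
    if (size != -1) {
      val bytes = new Array[Byte](size.toInt)
      content.readBytes(bytes)
      results.synchronized { results(blockId) = bytes }
    }
  }
}
// usage: new ShuffleCopier().getBlock("host", 9999, "shuffle_0_0_0", BlockCollector.collect)
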
http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/46eecd11/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala b/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala
new file mode 100644
index 0000000..537f225
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.netty
+
+import java.io.File
+
+import org.apache.spark.Logging
+
+
+private[spark] class ShuffleSender(portIn: Int, val pResolver: PathResolver) extends Logging {
+
+ val server = new FileServer(pResolver, portIn)
+ server.start()
+
+ def stop() {
+ server.stop()
+ }
+
+ def port: Int = server.getPort()
+}
+
+
+/**
+ * An application for testing the shuffle sender as a standalone program.
+ */
+private[spark] object ShuffleSender {
+
+ def main(args: Array[String]) {
+ if (args.length < 3) {
+ System.err.println(
+ "Usage: ShuffleSender <port> <subDirsPerLocalDir> <list of shuffle_block_directories>")
+ System.exit(1)
+ }
+
+ val port = args(0).toInt
+ val subDirsPerLocalDir = args(1).toInt
+ val localDirs = args.drop(2).map(new File(_))
+
+ val pResolver = new PathResolver {
+ override def getAbsolutePath(blockId: String): String = {
+ if (!blockId.startsWith("shuffle_")) {
+ throw new Exception("Block " + blockId + " is not a shuffle block")
+ }
+ // Figure out which local directory it hashes to, and which subdirectory in that
+ val hash = math.abs(blockId.hashCode)
+ val dirId = hash % localDirs.length
+ val subDirId = (hash / localDirs.length) % subDirsPerLocalDir
+ val subDir = new File(localDirs(dirId), "%02x".format(subDirId))
+ val file = new File(subDir, blockId)
+ file.getAbsolutePath
+ }
+ }
+ val sender = new ShuffleSender(port, pResolver)
+ }
+}
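
The resolver hashes a block id to a local dir and then to a "%02x" subdirectory, presumably to mirror how the block manager lays shuffle files out on disk. The arithmetic in isolation (illustrative values):

object PathArithmetic {
  def main(args: Array[String]) {
    val localDirCount = 2 // illustrative values
    val subDirsPerLocalDir = 64
    val blockId = "shuffle_0_0_0"
    val hash = math.abs(blockId.hashCode)
    val dirId = hash % localDirCount // which local dir
    val subDirId = (hash / localDirCount) % subDirsPerLocalDir // which "%02x" subdir
    println("dir " + dirId + ", subdir " + "%02x".format(subDirId))
  }
}
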
http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/46eecd11/core/src/main/scala/org/apache/spark/package.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/package.scala b/core/src/main/scala/org/apache/spark/package.scala
new file mode 100644
index 0000000..1126480
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/package.scala
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Core Spark functionality. [[org.apache.spark.SparkContext]] serves as the main entry point to Spark, while
+ * [[org.apache.spark.RDD]] is the data type representing a distributed collection, and provides most
+ * parallel operations.
+ *
+ * In addition, [[org.apache.spark.PairRDDFunctions]] contains operations available only on RDDs of key-value
+ * pairs, such as `groupByKey` and `join`; [[org.apache.spark.DoubleRDDFunctions]] contains operations
+ * available only on RDDs of Doubles; and [[org.apache.spark.SequenceFileRDDFunctions]] contains operations
+ * available on RDDs that can be saved as SequenceFiles. These operations are automatically
+ * available on any RDD of the right type (e.g. RDD[(Int, Int)]) through implicit conversions
+ * when you `import org.apache.spark.SparkContext._`.
+ */
+package object spark {
+ // For package docs only
+}
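
As the doc comment says, the extra operations arrive via implicit conversions. A minimal sketch, assuming a local-mode SparkContext:

import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._ // brings PairRDDFunctions et al. into scope

object ImplicitsExample {
  def main(args: Array[String]) {
    val sc = new SparkContext("local", "package-doc-example")
    val pairs = sc.parallelize(Seq((1, "a"), (1, "b"), (2, "c")))
    val grouped = pairs.groupByKey() // only compiles thanks to the implicit conversion
    println(grouped.collect().mkString(", "))
    sc.stop()
  }
}
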