Posted to commits@spark.apache.org by rx...@apache.org on 2014/09/09 00:59:27 UTC

[4/5] [SPARK-3019] Pluggable block transfer interface (BlockTransferService)

http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/network/ConnectionManagerTest.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManagerTest.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManagerTest.scala
deleted file mode 100644
index 4894ecd..0000000
--- a/core/src/main/scala/org/apache/spark/network/ConnectionManagerTest.scala
+++ /dev/null
@@ -1,103 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network
-
-import java.nio.ByteBuffer
-
-import scala.concurrent.Await
-import scala.concurrent.duration._
-import scala.io.Source
-
-import org.apache.spark._
-
-private[spark] object ConnectionManagerTest extends Logging{
-  def main(args: Array[String]) {
-    // <mesos cluster> - the master URL; <slaves file> - a list of slaves to run connectionTest on
-    // [num of tasks] - the number of parallel tasks to be initiated, default is the number of
-    //   slave hosts; [size of msg in MB (integer)] - the size of messages to be sent in each
-    //   task, default is 10; [count] - how many times to run, default is 3;
-    // [await time in seconds] - the await time (in seconds), default is 600
-    if (args.length < 2) {
-      println("Usage: ConnectionManagerTest <mesos cluster> <slaves file> [num of tasks] " +
-        "[size of msg in MB (integer)] [count] [await time in seconds)] ")
-      System.exit(1)
-    }
-
-    if (args(0).startsWith("local")) {
-      println("This runs only on a mesos cluster")
-    }
-
-    val sc = new SparkContext(args(0), "ConnectionManagerTest")
-    val slavesFile = Source.fromFile(args(1))
-    val slaves = slavesFile.mkString.split("\n")
-    slavesFile.close()
-
-    /* println("Slaves") */
-    /* slaves.foreach(println) */
-    val tasknum = if (args.length > 2) args(2).toInt else slaves.length
-    val size = ( if (args.length > 3) (args(3).toInt) else 10 ) * 1024 * 1024
-    val count = if (args.length > 4) args(4).toInt else 3
-    val awaitTime = (if (args.length > 5) args(5).toInt else 600 ).second
-    println("Running " + count + " rounds of test: " + "parallel tasks = " + tasknum + ", " +
-      "msg size = " + size/1024/1024 + " MB, awaitTime = " + awaitTime)
-    val slaveConnManagerIds = sc.parallelize(0 until tasknum, tasknum).map(
-        i => SparkEnv.get.connectionManager.id).collect()
-    println("\nSlave ConnectionManagerIds")
-    slaveConnManagerIds.foreach(println)
-    println
-
-    (0 until count).foreach(i => {
-      val resultStrs = sc.parallelize(0 until tasknum, tasknum).map(i => {
-        val connManager = SparkEnv.get.connectionManager
-        val thisConnManagerId = connManager.id
-        connManager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
-          logInfo("Received [" + msg + "] from [" + id + "]")
-          None
-        })
-
-        val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
-        buffer.flip
-
-        val startTime = System.currentTimeMillis
-        val futures = slaveConnManagerIds.filter(_ != thisConnManagerId).map{ slaveConnManagerId =>
-          {
-            val bufferMessage = Message.createBufferMessage(buffer.duplicate)
-            logInfo("Sending [" + bufferMessage + "] to [" + slaveConnManagerId + "]")
-            connManager.sendMessageReliably(slaveConnManagerId, bufferMessage)
-          }
-        }
-        val results = futures.map(f => Await.result(f, awaitTime))
-        val finishTime = System.currentTimeMillis
-        Thread.sleep(5000)
-
-        val mb = size * results.size / 1024.0 / 1024.0
-        val ms = finishTime - startTime
-        val resultStr = thisConnManagerId + " Sent " + mb + " MB in " + ms + " ms at " + (mb / ms *
-          1000.0) + " MB/s"
-        logInfo(resultStr)
-        resultStr
-      }).collect()
-
-      println("---------------------")
-      println("Run " + i)
-      resultStrs.foreach(println)
-      println("---------------------")
-    })
-  }
-}
-
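
(A note on the removed benchmark: each task registered its ConnectionManagerId, then sent one
size-byte message to every other manager and reported throughput as size * results.size / 1024^2
MB over the elapsed milliseconds. With the defaults above, 10 MB messages and, say, four slaves,
each task pushes 10 MB to 3 peers = 30 MB per round; if a round takes 600 ms the reported rate
is 50 MB/s.)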

http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala b/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala
new file mode 100644
index 0000000..dcecb6b
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network
+
+import java.io.{FileInputStream, RandomAccessFile, File, InputStream}
+import java.nio.ByteBuffer
+import java.nio.channels.FileChannel.MapMode
+
+import com.google.common.io.ByteStreams
+import io.netty.buffer.{ByteBufInputStream, ByteBuf}
+
+import org.apache.spark.util.ByteBufferInputStream
+
+
+/**
+ * This interface provides an immutable view for data in the form of bytes. The implementation
+ * should specify how the data is provided:
+ *
+ * - FileSegmentManagedBuffer: data backed by part of a file
+ * - NioByteBufferManagedBuffer: data backed by a NIO ByteBuffer
+ * - NettyByteBufManagedBuffer: data backed by a Netty ByteBuf
+ */
+sealed abstract class ManagedBuffer {
+  // Note that all the methods are defined with parentheses because their implementations can
+  // have side effects (I/O operations).
+
+  /** Number of bytes of the data. */
+  def size: Long
+
+  /**
+   * Exposes this buffer's data as an NIO ByteBuffer. Changing the position and limit of the
+   * returned ByteBuffer should not affect the content of this buffer.
+   */
+  def nioByteBuffer(): ByteBuffer
+
+  /**
+   * Exposes this buffer's data as an InputStream. The underlying implementation does not
+   * necessarily check for the length of bytes read, so the caller is responsible for making sure
+   * it does not go over the limit.
+   */
+  def inputStream(): InputStream
+}
+
+
+/**
+ * A [[ManagedBuffer]] backed by a segment in a file
+ */
+final class FileSegmentManagedBuffer(val file: File, val offset: Long, val length: Long)
+  extends ManagedBuffer {
+
+  override def size: Long = length
+
+  override def nioByteBuffer(): ByteBuffer = {
+    val channel = new RandomAccessFile(file, "r").getChannel
+    channel.map(MapMode.READ_ONLY, offset, length)
+  }
+
+  override def inputStream(): InputStream = {
+    val is = new FileInputStream(file)
+    is.skip(offset)
+    ByteStreams.limit(is, length)
+  }
+}
+
+
+/**
+ * A [[ManagedBuffer]] backed by [[java.nio.ByteBuffer]].
+ */
+final class NioByteBufferManagedBuffer(buf: ByteBuffer) extends ManagedBuffer {
+
+  override def size: Long = buf.remaining()
+
+  override def nioByteBuffer() = buf.duplicate()
+
+  override def inputStream() = new ByteBufferInputStream(buf)
+}
+
+
+/**
+ * A [[ManagedBuffer]] backed by a Netty [[ByteBuf]].
+ */
+final class NettyByteBufManagedBuffer(buf: ByteBuf) extends ManagedBuffer {
+
+  override def size: Long = buf.readableBytes()
+
+  override def nioByteBuffer() = buf.nioBuffer()
+
+  override def inputStream() = new ByteBufInputStream(buf)
+
+  // TODO(rxin): Promote this to top level ManagedBuffer interface and add documentation for it.
+  def release(): Unit = buf.release()
+}
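
A minimal caller-side sketch of the new interface (illustration only; it assumes the three
classes above compile exactly as shown in this diff):

    import java.nio.ByteBuffer

    val bytes = "some block data".getBytes("UTF-8")
    val managed = new NioByteBufferManagedBuffer(ByteBuffer.wrap(bytes))
    assert(managed.size == bytes.length)

    // nioByteBuffer() hands out a duplicate, so repositioning the view must not
    // disturb the underlying data (per the scaladoc above).
    val view = managed.nioByteBuffer()
    view.position(5)
    assert(managed.nioByteBuffer().position() == 0)

    // inputStream() streams the same bytes; note that reading advances the wrapped buffer.
    assert(managed.inputStream().read() == bytes(0))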

http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/network/Message.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/network/Message.scala b/core/src/main/scala/org/apache/spark/network/Message.scala
deleted file mode 100644
index 04ea50f..0000000
--- a/core/src/main/scala/org/apache/spark/network/Message.scala
+++ /dev/null
@@ -1,95 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network
-
-import java.net.InetSocketAddress
-import java.nio.ByteBuffer
-
-import scala.collection.mutable.ArrayBuffer
-
-private[spark] abstract class Message(val typ: Long, val id: Int) {
-  var senderAddress: InetSocketAddress = null
-  var started = false
-  var startTime = -1L
-  var finishTime = -1L
-  var isSecurityNeg = false
-  var hasError = false
-
-  def size: Int
-
-  def getChunkForSending(maxChunkSize: Int): Option[MessageChunk]
-
-  def getChunkForReceiving(chunkSize: Int): Option[MessageChunk]
-
-  def timeTaken(): String = (finishTime - startTime).toString + " ms"
-
-  override def toString = this.getClass.getSimpleName + "(id = " + id + ", size = " + size + ")"
-}
-
-
-private[spark] object Message {
-  val BUFFER_MESSAGE = 1111111111L
-
-  var lastId = 1
-
-  def getNewId() = synchronized {
-    lastId += 1
-    if (lastId == 0) {
-      lastId += 1
-    }
-    lastId
-  }
-
-  def createBufferMessage(dataBuffers: Seq[ByteBuffer], ackId: Int): BufferMessage = {
-    if (dataBuffers == null) {
-      return new BufferMessage(getNewId(), new ArrayBuffer[ByteBuffer], ackId)
-    }
-    if (dataBuffers.exists(_ == null)) {
-      throw new Exception("Attempting to create buffer message with null buffer")
-    }
-    new BufferMessage(getNewId(), new ArrayBuffer[ByteBuffer] ++= dataBuffers, ackId)
-  }
-
-  def createBufferMessage(dataBuffers: Seq[ByteBuffer]): BufferMessage =
-    createBufferMessage(dataBuffers, 0)
-
-  def createBufferMessage(dataBuffer: ByteBuffer, ackId: Int): BufferMessage = {
-    if (dataBuffer == null) {
-      createBufferMessage(Array(ByteBuffer.allocate(0)), ackId)
-    } else {
-      createBufferMessage(Array(dataBuffer), ackId)
-    }
-  }
-
-  def createBufferMessage(dataBuffer: ByteBuffer): BufferMessage =
-    createBufferMessage(dataBuffer, 0)
-
-  def createBufferMessage(ackId: Int): BufferMessage = {
-    createBufferMessage(new Array[ByteBuffer](0), ackId)
-  }
-
-  def create(header: MessageChunkHeader): Message = {
-    val newMessage: Message = header.typ match {
-      case BUFFER_MESSAGE => new BufferMessage(header.id,
-        ArrayBuffer(ByteBuffer.allocate(header.totalSize)), header.other)
-    }
-    newMessage.hasError = header.hasError
-    newMessage.senderAddress = header.address
-    newMessage
-  }
-}
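
One detail worth noting in the removed factory: getNewId() deliberately skips 0 when the counter
wraps around, because an ackId of 0 is how BufferMessage encodes "no ack expected" (see
hasAckId() in the new nio/BufferMessage.scala later in this patch).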

http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/network/MessageChunk.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/network/MessageChunk.scala b/core/src/main/scala/org/apache/spark/network/MessageChunk.scala
deleted file mode 100644
index d0f986a..0000000
--- a/core/src/main/scala/org/apache/spark/network/MessageChunk.scala
+++ /dev/null
@@ -1,41 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network
-
-import java.nio.ByteBuffer
-
-import scala.collection.mutable.ArrayBuffer
-
-private[network]
-class MessageChunk(val header: MessageChunkHeader, val buffer: ByteBuffer) {
-
-  val size = if (buffer == null) 0 else buffer.remaining
-
-  lazy val buffers = {
-    val ab = new ArrayBuffer[ByteBuffer]()
-    ab += header.buffer
-    if (buffer != null) {
-      ab += buffer
-    }
-    ab
-  }
-
-  override def toString = {
-    "" + this.getClass.getSimpleName + " (id = " + header.id + ", size = " + size + ")"
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala b/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala
deleted file mode 100644
index f3ecca5..0000000
--- a/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala
+++ /dev/null
@@ -1,82 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network
-
-import java.net.InetAddress
-import java.net.InetSocketAddress
-import java.nio.ByteBuffer
-
-private[spark] class MessageChunkHeader(
-    val typ: Long,
-    val id: Int,
-    val totalSize: Int,
-    val chunkSize: Int,
-    val other: Int,
-    val hasError: Boolean,
-    val securityNeg: Int,
-    val address: InetSocketAddress) {
-  lazy val buffer = {
-    // No need to change this, at 'use' time, we do a reverse lookup of the hostname.
-    // Refer to network.Connection
-    val ip = address.getAddress.getAddress()
-    val port = address.getPort()
-    ByteBuffer.
-      allocate(MessageChunkHeader.HEADER_SIZE).
-      putLong(typ).
-      putInt(id).
-      putInt(totalSize).
-      putInt(chunkSize).
-      putInt(other).
-      put(if (hasError) 1.asInstanceOf[Byte] else 0.asInstanceOf[Byte]).
-      putInt(securityNeg).
-      putInt(ip.size).
-      put(ip).
-      putInt(port).
-      position(MessageChunkHeader.HEADER_SIZE).
-      flip.asInstanceOf[ByteBuffer]
-  }
-
-  override def toString = "" + this.getClass.getSimpleName + ":" + id + " of type " + typ +
-      " and sizes " + totalSize + " / " + chunkSize + " bytes, securityNeg: " + securityNeg
-
-}
-
-
-private[spark] object MessageChunkHeader {
-  val HEADER_SIZE = 45
-
-  def create(buffer: ByteBuffer): MessageChunkHeader = {
-    if (buffer.remaining != HEADER_SIZE) {
-      throw new IllegalArgumentException("Cannot convert buffer data to Message")
-    }
-    val typ = buffer.getLong()
-    val id = buffer.getInt()
-    val totalSize = buffer.getInt()
-    val chunkSize = buffer.getInt()
-    val other = buffer.getInt()
-    val hasError = buffer.get() != 0
-    val securityNeg = buffer.getInt()
-    val ipSize = buffer.getInt()
-    val ipBytes = new Array[Byte](ipSize)
-    buffer.get(ipBytes)
-    val ip = InetAddress.getByAddress(ipBytes)
-    val port = buffer.getInt()
-    new MessageChunkHeader(typ, id, totalSize, chunkSize, other, hasError, securityNeg,
-      new InetSocketAddress(ip, port))
-  }
-}
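
For reference, the removed fixed-size header (HEADER_SIZE = 45) breaks down as:

    typ          8 bytes (long)
    id           4 bytes (int)
    totalSize    4 bytes
    chunkSize    4 bytes
    other        4 bytes
    hasError     1 byte
    securityNeg  4 bytes
    ip length    4 bytes
    ip bytes     4 bytes (IPv4)
    port         4 bytes
    total       41 bytes

The remaining 4 bytes are padding: buffer positions itself at HEADER_SIZE before flipping, so
exactly 45 bytes always go on the wire, and create() rejects anything whose remaining() is not
45. Note that a 16-byte IPv6 address would overflow the 45-byte allocation.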

http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala b/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala
deleted file mode 100644
index 53a6038..0000000
--- a/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala
+++ /dev/null
@@ -1,37 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network
-
-import java.nio.ByteBuffer
-import org.apache.spark.{SecurityManager, SparkConf}
-
-private[spark] object ReceiverTest {
-  def main(args: Array[String]) {
-    val conf = new SparkConf
-    val manager = new ConnectionManager(9999, conf, new SecurityManager(conf))
-    println("Started connection manager with id = " + manager.id)
-
-    manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
-      /* println("Received [" + msg + "] from [" + id + "] at " + System.currentTimeMillis) */
-      val buffer = ByteBuffer.wrap("response".getBytes("utf-8"))
-      Some(Message.createBufferMessage(buffer, msg.id))
-    })
-    Thread.currentThread.join()
-  }
-}
-

http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/network/SecurityMessage.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/network/SecurityMessage.scala b/core/src/main/scala/org/apache/spark/network/SecurityMessage.scala
deleted file mode 100644
index 9af9e2e..0000000
--- a/core/src/main/scala/org/apache/spark/network/SecurityMessage.scala
+++ /dev/null
@@ -1,162 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network
-
-import java.nio.ByteBuffer
-
-import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.StringBuilder
-
-import org.apache.spark._
-import org.apache.spark.network._
-
-/**
- * SecurityMessage is a class that contains the connectionId and SASL token
- * used in SASL negotiation. SecurityMessage has routines for converting
- * it to and from a BufferMessage so that it can be sent by the ConnectionManager
- * and easily consumed by users when received.
- * The API was modeled after BlockMessage.
- *
- * The connectionId is the connectionId of the client side. Since
- * message passing is asynchronous and it is possible for the server side (receiving)
- * to get multiple different types of messages on the same connection, the connectionId
- * is used to know which connection the security message is intended for.
- *
- * For instance, let's say we are node_0. We need to send data to node_1. The node_0 side
- * is acting as a client and connecting to node_1. SASL negotiation has to occur
- * between node_0 and node_1 before node_1 trusts node_0, so node_0 sends a security message.
- * node_1 receives the message from node_0 but before it can process it and send a response,
- * some thread on node_1 decides it needs to send data to node_0, so it connects to node_0
- * and sends a security message of its own to authenticate as a client. Now node_0 gets
- * the message and it needs to decide if this message is in response to it being a client
- * (from the first send) or if it's just node_1 trying to connect to it to send data.  This
- * is where the connectionId field is used. node_0 can look up the connectionId to see if
- * it is in response to it being a client or if it's in response to someone sending other data.
- *
- * The format of a SecurityMessage as it's sent is:
- *   - Length of the ConnectionId
- *   - ConnectionId
- *   - Length of the token
- *   - Token
- */
-private[spark] class SecurityMessage() extends Logging {
-
-  private var connectionId: String = null
-  private var token: Array[Byte] = null
-
-  def set(byteArr: Array[Byte], newconnectionId: String) {
-    if (byteArr == null) {
-      token = new Array[Byte](0)
-    } else {
-      token = byteArr
-    }
-    connectionId = newconnectionId
-  }
-
-  /**
-   * Read the given buffer and set the members of this class.
-   */
-  def set(buffer: ByteBuffer) {
-    val idLength = buffer.getInt()
-    val idBuilder = new StringBuilder(idLength)
-    for (i <- 1 to idLength) {
-        idBuilder += buffer.getChar()
-    }
-    connectionId  = idBuilder.toString()
-
-    val tokenLength = buffer.getInt()
-    token = new Array[Byte](tokenLength)
-    if (tokenLength > 0) {
-      buffer.get(token, 0, tokenLength)
-    }
-  }
-
-  def set(bufferMsg: BufferMessage) {
-    val buffer = bufferMsg.buffers.apply(0)
-    buffer.clear()
-    set(buffer)
-  }
-
-  def getConnectionId: String = {
-    return connectionId
-  }
-
-  def getToken: Array[Byte] = {
-    return token
-  }
-
-  /**
-   * Create a BufferMessage that can be sent by the ConnectionManager containing
-   * the security information from this class.
-   * @return BufferMessage
-   */
-  def toBufferMessage: BufferMessage = {
-    val buffers = new ArrayBuffer[ByteBuffer]()
-
-    // 4 bytes for the length of the connectionId
-    // connectionId is made of chars, so multiply the length by 2 to get the number of bytes
-    // 4 bytes for the length of token
-    // token is a byte buffer so just take the length
-    var buffer = ByteBuffer.allocate(4 + connectionId.length() * 2 + 4 + token.length)
-    buffer.putInt(connectionId.length())
-    connectionId.foreach((x: Char) => buffer.putChar(x))
-    buffer.putInt(token.length)
-
-    if (token.length > 0) {
-      buffer.put(token)
-    }
-    buffer.flip()
-    buffers += buffer
-
-    var message = Message.createBufferMessage(buffers)
-    logDebug("message total size is : " + message.size)
-    message.isSecurityNeg = true
-    return message
-  }
-
-  override def toString: String = {
-    "SecurityMessage [connId= " + connectionId + ", Token = " + token + "]"
-  }
-}
-
-private[spark] object SecurityMessage {
-
-  /**
-   * Convert the given BufferMessage to a SecurityMessage by parsing the contents
-   * of the BufferMessage and populating the SecurityMessage fields.
-   * @param bufferMessage is a BufferMessage that was received
-   * @return new SecurityMessage
-   */
-  def fromBufferMessage(bufferMessage: BufferMessage): SecurityMessage = {
-    val newSecurityMessage = new SecurityMessage()
-    newSecurityMessage.set(bufferMessage)
-    newSecurityMessage
-  }
-
-  /**
-   * Create a SecurityMessage to send from a given saslResponse.
-   * @param response is the response to a challenge from the SaslClient or SaslServer
-   * @param connectionId the client connectionId we are negotiating authentication for
-   * @return a new SecurityMessage
-   */
-  def fromResponse(response : Array[Byte], connectionId : String) : SecurityMessage = {
-    val newSecurityMessage = new SecurityMessage()
-    newSecurityMessage.set(response, connectionId)
-    newSecurityMessage
-  }
-}
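
Concretely, toBufferMessage lays the message out as [4-byte id length][id chars, 2 bytes each]
[4-byte token length][token bytes], and set(ByteBuffer) reads the same fields back in order.
For a hypothetical connectionId "conn_1" (6 chars) and a 16-byte token, the buffer is
4 + 12 + 4 + 16 = 36 bytes.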

http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/network/SenderTest.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/network/SenderTest.scala b/core/src/main/scala/org/apache/spark/network/SenderTest.scala
deleted file mode 100644
index ea2ad10..0000000
--- a/core/src/main/scala/org/apache/spark/network/SenderTest.scala
+++ /dev/null
@@ -1,76 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network
-
-import java.nio.ByteBuffer
-import org.apache.spark.{SecurityManager, SparkConf}
-
-import scala.concurrent.Await
-import scala.concurrent.duration.Duration
-import scala.util.Try
-
-private[spark] object SenderTest {
-  def main(args: Array[String]) {
-
-    if (args.length < 2) {
-      println("Usage: SenderTest <target host> <target port>")
-      System.exit(1)
-    }
-
-    val targetHost = args(0)
-    val targetPort = args(1).toInt
-    val targetConnectionManagerId = new ConnectionManagerId(targetHost, targetPort)
-    val conf = new SparkConf
-    val manager = new ConnectionManager(0, conf, new SecurityManager(conf))
-    println("Started connection manager with id = " + manager.id)
-
-    manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
-      println("Received [" + msg + "] from [" + id + "]")
-      None
-    })
-
-    val size =  100 * 1024  * 1024
-    val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
-    buffer.flip
-
-    val targetServer = args(0)
-
-    val count = 100
-    (0 until count).foreach(i => {
-      val dataMessage = Message.createBufferMessage(buffer.duplicate)
-      val startTime = System.currentTimeMillis
-      /* println("Started timer at " + startTime) */
-      val promise = manager.sendMessageReliably(targetConnectionManagerId, dataMessage)
-      val responseStr: String = Try(Await.result(promise, Duration.Inf))
-        .map { response =>
-          val buffer = response.asInstanceOf[BufferMessage].buffers(0)
-          new String(buffer.array, "utf-8")
-        }.getOrElse("none")
-
-      val finishTime = System.currentTimeMillis
-      val mb = size / 1024.0 / 1024.0
-      val ms = finishTime - startTime
-      // val resultStr = "Sent " + mb + " MB " + targetServer + " in " + ms + " ms at " + (mb / ms
-      //  * 1000.0) + " MB/s"
-      val resultStr = "Sent " + mb + " MB " + targetServer + " in " + ms + " ms (" +
-        (mb / ms * 1000.0).toInt + "MB/s) | Response = " + responseStr
-      println(resultStr)
-    })
-  }
-}
-
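
SenderTest and ReceiverTest (both removed here) were meant to be run as a pair: ReceiverTest
starts a ConnectionManager on port 9999 and answers every message with a "response" buffer, so
a typical session would start ReceiverTest on one host and then run SenderTest against it with
that host and 9999 as the two arguments.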

http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/network/nio/BlockMessage.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/network/nio/BlockMessage.scala b/core/src/main/scala/org/apache/spark/network/nio/BlockMessage.scala
new file mode 100644
index 0000000..b573f1a
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/nio/BlockMessage.scala
@@ -0,0 +1,197 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.nio
+
+import java.nio.ByteBuffer
+
+import org.apache.spark.storage.{BlockId, StorageLevel, TestBlockId}
+
+import scala.collection.mutable.{ArrayBuffer, StringBuilder}
+
+// private[spark] because we need to register them in Kryo
+private[spark] case class GetBlock(id: BlockId)
+private[spark] case class GotBlock(id: BlockId, data: ByteBuffer)
+private[spark] case class PutBlock(id: BlockId, data: ByteBuffer, level: StorageLevel)
+
+private[nio] class BlockMessage() {
+  // Un-initialized: typ = 0
+  // GetBlock: typ = 1
+  // GotBlock: typ = 2
+  // PutBlock: typ = 3
+  private var typ: Int = BlockMessage.TYPE_NON_INITIALIZED
+  private var id: BlockId = null
+  private var data: ByteBuffer = null
+  private var level: StorageLevel = null
+
+  def set(getBlock: GetBlock) {
+    typ = BlockMessage.TYPE_GET_BLOCK
+    id = getBlock.id
+  }
+
+  def set(gotBlock: GotBlock) {
+    typ = BlockMessage.TYPE_GOT_BLOCK
+    id = gotBlock.id
+    data = gotBlock.data
+  }
+
+  def set(putBlock: PutBlock) {
+    typ = BlockMessage.TYPE_PUT_BLOCK
+    id = putBlock.id
+    data = putBlock.data
+    level = putBlock.level
+  }
+
+  def set(buffer: ByteBuffer) {
+    /*
+    println()
+    println("BlockMessage: ")
+    while(buffer.remaining > 0) {
+      print(buffer.get())
+    }
+    buffer.rewind()
+    println()
+    println()
+    */
+    typ = buffer.getInt()
+    val idLength = buffer.getInt()
+    val idBuilder = new StringBuilder(idLength)
+    for (i <- 1 to idLength) {
+      idBuilder += buffer.getChar()
+    }
+    id = BlockId(idBuilder.toString)
+
+    if (typ == BlockMessage.TYPE_PUT_BLOCK) {
+
+      val booleanInt = buffer.getInt()
+      val replication = buffer.getInt()
+      level = StorageLevel(booleanInt, replication)
+
+      val dataLength = buffer.getInt()
+      data = ByteBuffer.allocate(dataLength)
+      if (dataLength != buffer.remaining) {
+        throw new Exception("Error parsing buffer")
+      }
+      data.put(buffer)
+      data.flip()
+    } else if (typ == BlockMessage.TYPE_GOT_BLOCK) {
+
+      val dataLength = buffer.getInt()
+      data = ByteBuffer.allocate(dataLength)
+      if (dataLength != buffer.remaining) {
+        throw new Exception("Error parsing buffer")
+      }
+      data.put(buffer)
+      data.flip()
+    }
+
+  }
+
+  def set(bufferMsg: BufferMessage) {
+    val buffer = bufferMsg.buffers.apply(0)
+    buffer.clear()
+    set(buffer)
+  }
+
+  def getType: Int = typ
+  def getId: BlockId = id
+  def getData: ByteBuffer = data
+  def getLevel: StorageLevel =  level
+
+  def toBufferMessage: BufferMessage = {
+    val buffers = new ArrayBuffer[ByteBuffer]()
+    var buffer = ByteBuffer.allocate(4 + 4 + id.name.length * 2)
+    buffer.putInt(typ).putInt(id.name.length)
+    id.name.foreach((x: Char) => buffer.putChar(x))
+    buffer.flip()
+    buffers += buffer
+
+    if (typ == BlockMessage.TYPE_PUT_BLOCK) {
+      buffer = ByteBuffer.allocate(8).putInt(level.toInt).putInt(level.replication)
+      buffer.flip()
+      buffers += buffer
+
+      buffer = ByteBuffer.allocate(4).putInt(data.remaining)
+      buffer.flip()
+      buffers += buffer
+
+      buffers += data
+    } else if (typ == BlockMessage.TYPE_GOT_BLOCK) {
+      buffer = ByteBuffer.allocate(4).putInt(data.remaining)
+      buffer.flip()
+      buffers += buffer
+
+      buffers += data
+    }
+
+    /*
+    println()
+    println("BlockMessage: ")
+    buffers.foreach(b => {
+      while(b.remaining > 0) {
+        print(b.get())
+      }
+      b.rewind()
+    })
+    println()
+    println()
+    */
+    Message.createBufferMessage(buffers)
+  }
+
+  override def toString: String = {
+    "BlockMessage [type = " + typ + ", id = " + id + ", level = " + level +
+    ", data = " + (if (data != null) data.remaining.toString  else "null") + "]"
+  }
+}
+
+private[nio] object BlockMessage {
+  val TYPE_NON_INITIALIZED: Int = 0
+  val TYPE_GET_BLOCK: Int = 1
+  val TYPE_GOT_BLOCK: Int = 2
+  val TYPE_PUT_BLOCK: Int = 3
+
+  def fromBufferMessage(bufferMessage: BufferMessage): BlockMessage = {
+    val newBlockMessage = new BlockMessage()
+    newBlockMessage.set(bufferMessage)
+    newBlockMessage
+  }
+
+  def fromByteBuffer(buffer: ByteBuffer): BlockMessage = {
+    val newBlockMessage = new BlockMessage()
+    newBlockMessage.set(buffer)
+    newBlockMessage
+  }
+
+  def fromGetBlock(getBlock: GetBlock): BlockMessage = {
+    val newBlockMessage = new BlockMessage()
+    newBlockMessage.set(getBlock)
+    newBlockMessage
+  }
+
+  def fromGotBlock(gotBlock: GotBlock): BlockMessage = {
+    val newBlockMessage = new BlockMessage()
+    newBlockMessage.set(gotBlock)
+    newBlockMessage
+  }
+
+  def fromPutBlock(putBlock: PutBlock): BlockMessage = {
+    val newBlockMessage = new BlockMessage()
+    newBlockMessage.set(putBlock)
+    newBlockMessage
+  }
+}
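
A round-trip sketch of the framing above (illustration only; it assumes TestBlockId's name is
"test_" + id, as in Spark's storage package):

    import org.apache.spark.storage.TestBlockId

    // Encode a GetBlock: [typ = 1][id length][id chars, 2 bytes each].
    val original = BlockMessage.fromGetBlock(GetBlock(TestBlockId("0")))
    val onTheWire = original.toBufferMessage  // "test_0" is 6 chars: 4 + 4 + 12 = 20 bytes

    // Decode it back through the BufferMessage path.
    val parsed = BlockMessage.fromBufferMessage(onTheWire)
    assert(parsed.getType == BlockMessage.TYPE_GET_BLOCK)
    assert(parsed.getId.name == "test_0")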

http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala b/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala
new file mode 100644
index 0000000..a1a2c00
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.nio
+
+import java.nio.ByteBuffer
+
+import org.apache.spark._
+import org.apache.spark.storage.{StorageLevel, TestBlockId}
+
+import scala.collection.mutable.ArrayBuffer
+
+private[nio]
+class BlockMessageArray(var blockMessages: Seq[BlockMessage])
+  extends Seq[BlockMessage] with Logging {
+
+  def this(bm: BlockMessage) = this(Array(bm))
+
+  def this() = this(null.asInstanceOf[Seq[BlockMessage]])
+
+  def apply(i: Int) = blockMessages(i)
+
+  def iterator = blockMessages.iterator
+
+  def length = blockMessages.length
+
+  def set(bufferMessage: BufferMessage) {
+    val startTime = System.currentTimeMillis
+    val newBlockMessages = new ArrayBuffer[BlockMessage]()
+    val buffer = bufferMessage.buffers(0)
+    buffer.clear()
+    /*
+    println()
+    println("BlockMessageArray: ")
+    while(buffer.remaining > 0) {
+      print(buffer.get())
+    }
+    buffer.rewind()
+    println()
+    println()
+    */
+    while (buffer.remaining() > 0) {
+      val size = buffer.getInt()
+      logDebug("Creating block message of size " + size + " bytes")
+      val newBuffer = buffer.slice()
+      newBuffer.clear()
+      newBuffer.limit(size)
+      logDebug("Trying to convert buffer " + newBuffer + " to block message")
+      val newBlockMessage = BlockMessage.fromByteBuffer(newBuffer)
+      logDebug("Created " + newBlockMessage)
+      newBlockMessages += newBlockMessage
+      buffer.position(buffer.position() + size)
+    }
+    val finishTime = System.currentTimeMillis
+    logDebug("Converted block message array from buffer message in " +
+      (finishTime - startTime) / 1000.0  + " s")
+    this.blockMessages = newBlockMessages
+  }
+
+  def toBufferMessage: BufferMessage = {
+    val buffers = new ArrayBuffer[ByteBuffer]()
+
+    blockMessages.foreach(blockMessage => {
+      val bufferMessage = blockMessage.toBufferMessage
+      logDebug("Adding " + blockMessage)
+      val sizeBuffer = ByteBuffer.allocate(4).putInt(bufferMessage.size)
+      sizeBuffer.flip
+      buffers += sizeBuffer
+      buffers ++= bufferMessage.buffers
+      logDebug("Added " + bufferMessage)
+    })
+
+    logDebug("Buffer list:")
+    buffers.foreach((x: ByteBuffer) => logDebug("" + x))
+    /*
+    println()
+    println("BlockMessageArray: ")
+    buffers.foreach(b => {
+      while(b.remaining > 0) {
+        print(b.get())
+      }
+      b.rewind()
+    })
+    println()
+    println()
+    */
+    Message.createBufferMessage(buffers)
+  }
+}
+
+private[nio] object BlockMessageArray {
+
+  def fromBufferMessage(bufferMessage: BufferMessage): BlockMessageArray = {
+    val newBlockMessageArray = new BlockMessageArray()
+    newBlockMessageArray.set(bufferMessage)
+    newBlockMessageArray
+  }
+
+  def main(args: Array[String]) {
+    val blockMessages =
+      (0 until 10).map { i =>
+        if (i % 2 == 0) {
+          val buffer =  ByteBuffer.allocate(100)
+          buffer.clear
+          BlockMessage.fromPutBlock(PutBlock(TestBlockId(i.toString), buffer,
+            StorageLevel.MEMORY_ONLY_SER))
+        } else {
+          BlockMessage.fromGetBlock(GetBlock(TestBlockId(i.toString)))
+        }
+      }
+    val blockMessageArray = new BlockMessageArray(blockMessages)
+    println("Block message array created")
+
+    val bufferMessage = blockMessageArray.toBufferMessage
+    println("Converted to buffer message")
+
+    val totalSize = bufferMessage.size
+    val newBuffer = ByteBuffer.allocate(totalSize)
+    newBuffer.clear()
+    bufferMessage.buffers.foreach(buffer => {
+      assert (0 == buffer.position())
+      newBuffer.put(buffer)
+      buffer.rewind()
+    })
+    newBuffer.flip
+    val newBufferMessage = Message.createBufferMessage(newBuffer)
+    println("Copied to new buffer message, size = " + newBufferMessage.size)
+
+    val newBlockMessageArray = BlockMessageArray.fromBufferMessage(newBufferMessage)
+    println("Converted back to block message array")
+    newBlockMessageArray.foreach(blockMessage => {
+      blockMessage.getType match {
+        case BlockMessage.TYPE_PUT_BLOCK => {
+          val pB = PutBlock(blockMessage.getId, blockMessage.getData, blockMessage.getLevel)
+          println(pB)
+        }
+        case BlockMessage.TYPE_GET_BLOCK => {
+          val gB = new GetBlock(blockMessage.getId)
+          println(gB)
+        }
+      }
+    })
+  }
+}
+
+
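
On the wire, then, a BlockMessageArray is just a concatenation of frames of the form
[4-byte size][BlockMessage payload]: set() walks the buffer reading each size, slices out that
many bytes for BlockMessage.fromByteBuffer, and advances past them, which is exactly the round
trip main() above exercises.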

http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/network/nio/BufferMessage.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/network/nio/BufferMessage.scala b/core/src/main/scala/org/apache/spark/network/nio/BufferMessage.scala
new file mode 100644
index 0000000..3b245c5
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/nio/BufferMessage.scala
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.nio
+
+import java.nio.ByteBuffer
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.storage.BlockManager
+
+
+private[nio]
+class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: Int)
+  extends Message(Message.BUFFER_MESSAGE, id_) {
+
+  val initialSize = currentSize()
+  var gotChunkForSendingOnce = false
+
+  def size = initialSize
+
+  def currentSize() = {
+    if (buffers == null || buffers.isEmpty) {
+      0
+    } else {
+      buffers.map(_.remaining).reduceLeft(_ + _)
+    }
+  }
+
+  def getChunkForSending(maxChunkSize: Int): Option[MessageChunk] = {
+    if (maxChunkSize <= 0) {
+      throw new Exception("Max chunk size is " + maxChunkSize)
+    }
+
+    val security = if (isSecurityNeg) 1 else 0
+    if (size == 0 && !gotChunkForSendingOnce) {
+      val newChunk = new MessageChunk(
+        new MessageChunkHeader(typ, id, 0, 0, ackId, hasError, security, senderAddress), null)
+      gotChunkForSendingOnce = true
+      return Some(newChunk)
+    }
+
+    while(!buffers.isEmpty) {
+      val buffer = buffers(0)
+      if (buffer.remaining == 0) {
+        BlockManager.dispose(buffer)
+        buffers -= buffer
+      } else {
+        val newBuffer = if (buffer.remaining <= maxChunkSize) {
+          buffer.duplicate()
+        } else {
+          buffer.slice().limit(maxChunkSize).asInstanceOf[ByteBuffer]
+        }
+        buffer.position(buffer.position + newBuffer.remaining)
+        val newChunk = new MessageChunk(new MessageChunkHeader(
+          typ, id, size, newBuffer.remaining, ackId,
+          hasError, security, senderAddress), newBuffer)
+        gotChunkForSendingOnce = true
+        return Some(newChunk)
+      }
+    }
+    None
+  }
+
+  def getChunkForReceiving(chunkSize: Int): Option[MessageChunk] = {
+    // STRONG ASSUMPTION: BufferMessage created when receiving data has ONLY ONE data buffer
+    if (buffers.size > 1) {
+      throw new Exception("Attempting to get chunk from message with multiple data buffers")
+    }
+    val buffer = buffers(0)
+    val security = if (isSecurityNeg) 1 else 0
+    if (buffer.remaining > 0) {
+      if (buffer.remaining < chunkSize) {
+        throw new Exception("Not enough space in data buffer for receiving chunk")
+      }
+      val newBuffer = buffer.slice().limit(chunkSize).asInstanceOf[ByteBuffer]
+      buffer.position(buffer.position + newBuffer.remaining)
+      val newChunk = new MessageChunk(new MessageChunkHeader(
+          typ, id, size, newBuffer.remaining, ackId, hasError, security, senderAddress), newBuffer)
+      return Some(newChunk)
+    }
+    None
+  }
+
+  def flip() {
+    buffers.foreach(_.flip)
+  }
+
+  def hasAckId() = (ackId != 0)
+
+  def isCompletelyReceived() = !buffers(0).hasRemaining
+
+  override def toString = {
+    if (hasAckId) {
+      "BufferAckMessage(aid = " + ackId + ", id = " + id + ", size = " + size + ")"
+    } else {
+      "BufferMessage(id = " + id + ", size = " + size + ")"
+    }
+  }
+}
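
Chunking arithmetic for getChunkForSending: each call slices off at most maxChunkSize bytes and
advances the source buffer, disposing of and dropping buffers once drained. With the 65536-byte
defaultChunkSize used by SendingConnection's Outbox (see Connection.scala below), a single
150000-byte buffer goes out as chunks of 65536, 65536 and 18928 bytes.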

http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/network/nio/Connection.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala
new file mode 100644
index 0000000..74074a8
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala
@@ -0,0 +1,587 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.nio
+
+import java.net._
+import java.nio._
+import java.nio.channels._
+
+import org.apache.spark._
+
+import scala.collection.mutable.{ArrayBuffer, HashMap, Queue}
+
+private[nio]
+abstract class Connection(val channel: SocketChannel, val selector: Selector,
+    val socketRemoteConnectionManagerId: ConnectionManagerId, val connectionId: ConnectionId)
+  extends Logging {
+
+  var sparkSaslServer: SparkSaslServer = null
+  var sparkSaslClient: SparkSaslClient = null
+
+  def this(channel_ : SocketChannel, selector_ : Selector, id_ : ConnectionId) = {
+    this(channel_, selector_,
+      ConnectionManagerId.fromSocketAddress(
+        channel_.socket.getRemoteSocketAddress.asInstanceOf[InetSocketAddress]), id_)
+  }
+
+  channel.configureBlocking(false)
+  channel.socket.setTcpNoDelay(true)
+  channel.socket.setReuseAddress(true)
+  channel.socket.setKeepAlive(true)
+  /* channel.socket.setReceiveBufferSize(32768) */
+
+  @volatile private var closed = false
+  var onCloseCallback: Connection => Unit = null
+  var onExceptionCallback: (Connection, Exception) => Unit = null
+  var onKeyInterestChangeCallback: (Connection, Int) => Unit = null
+
+  val remoteAddress = getRemoteAddress()
+
+  /**
+   * Used to synchronize client requests: client's work-related requests must
+   * wait until SASL authentication completes.
+   */
+  private val authenticated = new Object()
+
+  def getAuthenticated(): Object = authenticated
+
+  def isSaslComplete(): Boolean
+
+  def resetForceReregister(): Boolean
+
+  // Read channels typically do not register for write, and write channels do not register
+  // for read. Now, we do have write registering for read too (temporarily), but this is to
+  // detect channel close, NOT to actually read/consume data on it!
+  // How does this work if/when we move to SSL?
+
+  // What is the interest to register with selector for when we want this connection to be selected
+  def registerInterest()
+
+  // What is the interest to register with selector for when we want this connection to
+  // be de-selected
+  // Traditionally, 0 - but in our case, for example, for close-detection on SendingConnection hack,
+  // it will be SelectionKey.OP_READ (until we fix it properly)
+  def unregisterInterest()
+
+  // On receiving a read event, should we change the interest for this channel or not ?
+  // Will be true for ReceivingConnection, false for SendingConnection.
+  def changeInterestForRead(): Boolean
+
+  private def disposeSasl() {
+    if (sparkSaslServer != null) {
+      sparkSaslServer.dispose()
+    }
+
+    if (sparkSaslClient != null) {
+      sparkSaslClient.dispose()
+    }
+  }
+
+  // On receiving a write event, should we change the interest for this channel or not ?
+  // Will be false for ReceivingConnection, true for SendingConnection.
+  // Actually, for now, should not get triggered for ReceivingConnection
+  def changeInterestForWrite(): Boolean
+
+  def getRemoteConnectionManagerId(): ConnectionManagerId = {
+    socketRemoteConnectionManagerId
+  }
+
+  def key() = channel.keyFor(selector)
+
+  def getRemoteAddress() = channel.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress]
+
+  // Returns whether we have to register for further reads or not.
+  def read(): Boolean = {
+    throw new UnsupportedOperationException(
+      "Cannot read on connection of type " + this.getClass.toString)
+  }
+
+  // Returns whether we have to register for further writes or not.
+  def write(): Boolean = {
+    throw new UnsupportedOperationException(
+      "Cannot write on connection of type " + this.getClass.toString)
+  }
+
+  def close() {
+    closed = true
+    val k = key()
+    if (k != null) {
+      k.cancel()
+    }
+    channel.close()
+    disposeSasl()
+    callOnCloseCallback()
+  }
+
+  protected def isClosed: Boolean = closed
+
+  def onClose(callback: Connection => Unit) {
+    onCloseCallback = callback
+  }
+
+  def onException(callback: (Connection, Exception) => Unit) {
+    onExceptionCallback = callback
+  }
+
+  def onKeyInterestChange(callback: (Connection, Int) => Unit) {
+    onKeyInterestChangeCallback = callback
+  }
+
+  def callOnExceptionCallback(e: Exception) {
+    if (onExceptionCallback != null) {
+      onExceptionCallback(this, e)
+    } else {
+      logError("Error in connection to " + getRemoteConnectionManagerId() +
+        " and OnExceptionCallback not registered", e)
+    }
+  }
+
+  def callOnCloseCallback() {
+    if (onCloseCallback != null) {
+      onCloseCallback(this)
+    } else {
+      logWarning("Connection to " + getRemoteConnectionManagerId() +
+        " closed and OnExceptionCallback not registered")
+    }
+
+  }
+
+  def changeConnectionKeyInterest(ops: Int) {
+    if (onKeyInterestChangeCallback != null) {
+      onKeyInterestChangeCallback(this, ops)
+    } else {
+      throw new Exception("OnKeyInterestChangeCallback not registered")
+    }
+  }
+
+  def printRemainingBuffer(buffer: ByteBuffer) {
+    val bytes = new Array[Byte](buffer.remaining)
+    val curPosition = buffer.position
+    buffer.get(bytes)
+    bytes.foreach(x => print(x + " "))
+    buffer.position(curPosition)
+    print(" (" + bytes.size + ")")
+  }
+
+  def printBuffer(buffer: ByteBuffer, position: Int, length: Int) {
+    val bytes = new Array[Byte](length)
+    val curPosition = buffer.position
+    buffer.position(position)
+    buffer.get(bytes)
+    bytes.foreach(x => print(x + " "))
+    print(" (" + position + ", " + length + ")")
+    buffer.position(curPosition)
+  }
+}
+
+
+private[nio]
+class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
+    remoteId_ : ConnectionManagerId, id_ : ConnectionId)
+  extends Connection(SocketChannel.open, selector_, remoteId_, id_) {
+
+  def isSaslComplete(): Boolean = {
+    if (sparkSaslClient != null) sparkSaslClient.isComplete() else false
+  }
+
+  private class Outbox {
+    val messages = new Queue[Message]()
+    val defaultChunkSize = 65536
+    var nextMessageToBeUsed = 0
+
+    def addMessage(message: Message) {
+      messages.synchronized {
+        /* messages += message */
+        messages.enqueue(message)
+        logDebug("Added [" + message + "] to outbox for sending to " +
+          "[" + getRemoteConnectionManagerId() + "]")
+      }
+    }
+
+    def getChunk(): Option[MessageChunk] = {
+      messages.synchronized {
+        while (!messages.isEmpty) {
+          /* nextMessageToBeUsed = nextMessageToBeUsed % messages.size */
+          /* val message = messages(nextMessageToBeUsed) */
+          val message = messages.dequeue()
+          val chunk = message.getChunkForSending(defaultChunkSize)
+          if (chunk.isDefined) {
+            messages.enqueue(message)
+            nextMessageToBeUsed = nextMessageToBeUsed + 1
+            if (!message.started) {
+              logDebug(
+                "Starting to send [" + message + "] to [" + getRemoteConnectionManagerId() + "]")
+              message.started = true
+              message.startTime = System.currentTimeMillis
+            }
+            logTrace(
+              "Sending chunk from [" + message + "] to [" + getRemoteConnectionManagerId() + "]")
+            return chunk
+          } else {
+            message.finishTime = System.currentTimeMillis
+            logDebug("Finished sending [" + message + "] to [" + getRemoteConnectionManagerId() +
+              "] in "  + message.timeTaken )
+          }
+        }
+      }
+      None
+    }
+  }
+
+  // outbox is used as a lock - ensure that it is always used as a leaf (since methods which
+  // lock it are invoked in context of other locks)
+  private val outbox = new Outbox()
+  /*
+    This is orthogonal to whether we have pending bytes to write or not - and satisfies a slightly
+    different purpose. This flag is to see if we need to force reregister for write even when we
+    do not have any pending bytes to write to socket.
+    This can happen due to a race between adding pending buffers and checking for the existence
+    of data, as detailed in https://github.com/mesos/spark/pull/791
+   */
+  private var needForceReregister = false
+
+  val currentBuffers = new ArrayBuffer[ByteBuffer]()
+
+  /* channel.socket.setSendBufferSize(256 * 1024) */
+
+  override def getRemoteAddress() = address
+
+  val DEFAULT_INTEREST = SelectionKey.OP_READ
+
+  override def registerInterest() {
+    // Registering read too - does not really help in most cases, but for some
+    // it does - so let us keep it for now.
+    changeConnectionKeyInterest(SelectionKey.OP_WRITE | DEFAULT_INTEREST)
+  }
+
+  override def unregisterInterest() {
+    changeConnectionKeyInterest(DEFAULT_INTEREST)
+  }
+
+  def send(message: Message) {
+    outbox.synchronized {
+      outbox.addMessage(message)
+      needForceReregister = true
+    }
+    if (channel.isConnected) {
+      registerInterest()
+    }
+  }
+
+  // return previous value after resetting it.
+  def resetForceReregister(): Boolean = {
+    outbox.synchronized {
+      val result = needForceReregister
+      needForceReregister = false
+      result
+    }
+  }
+
+  // MUST be called within the selector loop
+  def connect() {
+    try{
+      channel.register(selector, SelectionKey.OP_CONNECT)
+      channel.connect(address)
+      logInfo("Initiating connection to [" + address + "]")
+    } catch {
+      case e: Exception => {
+        logError("Error connecting to " + address, e)
+        callOnExceptionCallback(e)
+      }
+    }
+  }
+
+  def finishConnect(force: Boolean): Boolean = {
+    try {
+      // Typically, this should finish immediately since it was triggered by a connect
+      // selection - though it need not always complete successfully.
+      val connected = channel.finishConnect
+      if (!force && !connected) {
+        logInfo(
+          "finish connect failed [" + address + "], " + outbox.messages.size + " messages pending")
+        return false
+      }
+
+      // Fall back to previous behavior - assume finishConnect completed.
+      // This happens only when finishConnect has failed some repeated number of times
+      // (10 or so), which is highly unlikely unless there was an unclean socket close, etc.
+      registerInterest()
+      logInfo("Connected to [" + address + "], " + outbox.messages.size + " messages pending")
+    } catch {
+      case e: Exception => {
+        logWarning("Error finishing connection to " + address, e)
+        callOnExceptionCallback(e)
+      }
+    }
+    true
+  }
+
+  override def write(): Boolean = {
+    try {
+      while (true) {
+        if (currentBuffers.size == 0) {
+          outbox.synchronized {
+            outbox.getChunk() match {
+              case Some(chunk) => {
+                val buffers = chunk.buffers
+                // If we have 'seen' pending messages, reset the flag - we handle this as a
+                // normal event registration (below).
+                if (needForceReregister && buffers.exists(_.remaining() > 0)) resetForceReregister()
+
+                currentBuffers ++= buffers
+              }
+              case None => {
+                // changeConnectionKeyInterest(0)
+                /* key.interestOps(0) */
+                return false
+              }
+            }
+          }
+        }
+
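+        // Write as much of the head buffer as the socket accepts; a short write means the
+        // socket send buffer is full, so return true to stay registered for OP_WRITE.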
+        if (currentBuffers.size > 0) {
+          val buffer = currentBuffers(0)
+          val remainingBytes = buffer.remaining
+          val writtenBytes = channel.write(buffer)
+          if (buffer.remaining == 0) {
+            currentBuffers -= buffer
+          }
+          if (writtenBytes < remainingBytes) {
+            // re-register for write.
+            return true
+          }
+        }
+      }
+    } catch {
+      case e: Exception => {
+        logWarning("Error writing in connection to " + getRemoteConnectionManagerId(), e)
+        callOnExceptionCallback(e)
+        close()
+        return false
+      }
+    }
+    // should not happen - to keep scala compiler happy
+    true
+  }
+
+  // This is a hack to determine whether the remote socket was closed.
+  // A SendingConnection does NOT expect to receive any data - if it does, it is an error.
+  // In most cases, read returns -1 when the remote socket has been closed, so we
+  // register for reads to detect that.
+  override def read(): Boolean = {
+    // We don't expect the other side to send anything; so, we just read to detect an error or EOF.
+    try {
+      val length = channel.read(ByteBuffer.allocate(1))
+      if (length == -1) { // EOF
+        close()
+      } else if (length > 0) {
+        logWarning(
+          "Unexpected data read from SendingConnection to " + getRemoteConnectionManagerId())
+      }
+    } catch {
+      case e: Exception =>
+        logError("Exception while reading SendingConnection to " + getRemoteConnectionManagerId(),
+          e)
+        callOnExceptionCallback(e)
+        close()
+    }
+
+    false
+  }
+
+  override def changeInterestForRead(): Boolean = false
+
+  override def changeInterestForWrite(): Boolean = !isClosed
+}
+
+
+// Must be created within selector loop - else deadlock
+private[spark] class ReceivingConnection(
+    channel_ : SocketChannel,
+    selector_ : Selector,
+    id_ : ConnectionId)
+    extends Connection(channel_, selector_, id_) {
+
+  def isSaslComplete(): Boolean = {
+    if (sparkSaslServer != null) sparkSaslServer.isComplete() else false
+  }
+
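+  // Reassembles incoming chunks into in-progress messages, keyed by the message id
+  // carried in each chunk header.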
+  class Inbox() {
+    val messages = new HashMap[Int, BufferMessage]()
+
+    def getChunk(header: MessageChunkHeader): Option[MessageChunk] = {
+
+      def createNewMessage: BufferMessage = {
+        val newMessage = Message.create(header).asInstanceOf[BufferMessage]
+        newMessage.started = true
+        newMessage.startTime = System.currentTimeMillis
+        newMessage.isSecurityNeg = header.securityNeg == 1
+        logDebug(
+          "Starting to receive [" + newMessage + "] from [" + getRemoteConnectionManagerId() + "]")
+        messages += ((newMessage.id, newMessage))
+        newMessage
+      }
+
+      val message = messages.getOrElseUpdate(header.id, createNewMessage)
+      logTrace(
+        "Receiving chunk of [" + message + "] from [" + getRemoteConnectionManagerId() + "]")
+      message.getChunkForReceiving(header.chunkSize)
+    }
+
+    def getMessageForChunk(chunk: MessageChunk): Option[BufferMessage] = {
+      messages.get(chunk.header.id)
+    }
+
+    def removeMessage(message: Message) {
+      messages -= message.id
+    }
+  }
+
+  @volatile private var inferredRemoteManagerId: ConnectionManagerId = null
+
+  override def getRemoteConnectionManagerId(): ConnectionManagerId = {
+    val currId = inferredRemoteManagerId
+    if (currId != null) currId else super.getRemoteConnectionManagerId()
+  }
+
+  // The remote address of this receiving connection is the sender's ephemeral outbound
+  // socket, which is NOT the sender's ConnectionManagerId.
+  // We infer that id from the headers of the messages received on this socket.
+  private def processConnectionManagerId(header: MessageChunkHeader) {
+    val currId = inferredRemoteManagerId
+    if (header.address == null || currId != null) return
+
+    val managerId = ConnectionManagerId.fromSocketAddress(header.address)
+
+    if (managerId != null) {
+      inferredRemoteManagerId = managerId
+    }
+  }
+
+  val inbox = new Inbox()
+  val headerBuffer: ByteBuffer = ByteBuffer.allocate(MessageChunkHeader.HEADER_SIZE)
+  var onReceiveCallback: (Connection, Message) => Unit = null
+  var currentChunk: MessageChunk = null
+
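+  // A receiving connection is driven entirely by incoming data, so register for reads
+  // immediately on construction.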
+  channel.register(selector, SelectionKey.OP_READ)
+
+  override def read(): Boolean = {
+    try {
+      while (true) {
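+        // Two-phase read: first fill the fixed-size header buffer, then read the chunk
+        // payload it describes into currentChunk.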
+        if (currentChunk == null) {
+          val headerBytesRead = channel.read(headerBuffer)
+          if (headerBytesRead == -1) {
+            close()
+            return false
+          }
+          if (headerBuffer.remaining > 0) {
+            // re-register for read event ...
+            return true
+          }
+          headerBuffer.flip
+          if (headerBuffer.remaining != MessageChunkHeader.HEADER_SIZE) {
+            throw new Exception(
+              "Unexpected number of bytes (" + headerBuffer.remaining + ") in the header")
+          }
+          val header = MessageChunkHeader.create(headerBuffer)
+          headerBuffer.clear()
+
+          processConnectionManagerId(header)
+
+          header.typ match {
+            case Message.BUFFER_MESSAGE => {
+              if (header.totalSize == 0) {
+                if (onReceiveCallback != null) {
+                  onReceiveCallback(this, Message.create(header))
+                }
+                currentChunk = null
+                // re-register for read event ...
+                return true
+              } else {
+                currentChunk = inbox.getChunk(header).orNull
+              }
+            }
+            case _ => throw new Exception("Message of unknown type received")
+          }
+        }
+
+        if (currentChunk == null) throw new Exception("No message chunk to receive data")
+
+        val bytesRead = channel.read(currentChunk.buffer)
+        if (bytesRead == 0) {
+          // re-register for read event ...
+          return true
+        } else if (bytesRead == -1) {
+          close()
+          return false
+        }
+
+        /* logDebug("Read " + bytesRead + " bytes for the buffer") */
+
+        if (currentChunk.buffer.remaining == 0) {
+          /* println("Filled buffer at " + System.currentTimeMillis) */
+          val bufferMessage = inbox.getMessageForChunk(currentChunk).get
+          if (bufferMessage.isCompletelyReceived) {
+            bufferMessage.flip()
+            bufferMessage.finishTime = System.currentTimeMillis
+            logDebug("Finished receiving [" + bufferMessage + "] from " +
+              "[" + getRemoteConnectionManagerId() + "] in " + bufferMessage.timeTaken)
+            if (onReceiveCallback != null) {
+              onReceiveCallback(this, bufferMessage)
+            }
+            inbox.removeMessage(bufferMessage)
+          }
+          currentChunk = null
+        }
+      }
+    } catch {
+      case e: Exception => {
+        logWarning("Error reading from connection to " + getRemoteConnectionManagerId(), e)
+        callOnExceptionCallback(e)
+        close()
+        return false
+      }
+    }
+    // should not happen - to keep scala compiler happy
+    true
+  }
+
+  def onReceive(callback: (Connection, Message) => Unit) { onReceiveCallback = callback }
+
+  // override def changeInterestForRead(): Boolean = ! isClosed
+  override def changeInterestForRead(): Boolean = true
+
+  override def changeInterestForWrite(): Boolean = {
+    throw new IllegalStateException("Unexpected invocation right now")
+  }
+
+  override def registerInterest() {
+    // A receiving connection is only ever interested in reads, so OP_READ is all that is
+    // registered here.
+    changeConnectionKeyInterest(SelectionKey.OP_READ)
+  }
+
+  override def unregisterInterest() {
+    changeConnectionKeyInterest(0)
+  }
+
+  // For read conn, always false.
+  override def resetForceReregister(): Boolean = false
+}

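A note on the Outbox above: each getChunk() call takes one chunk from the message at the
head of the queue and, if more chunks remain, re-enqueues that message at the back, so
chunks of concurrent messages are interleaved on the wire. A minimal sketch of the
resulting order, using two hypothetical messages A (chunks A1, A2) and B (chunk B1):

    outbox.addMessage(messageA)   // queue: [A]
    outbox.addMessage(messageB)   // queue: [A, B]
    outbox.getChunk()             // Some(A1); A re-enqueued -> [B, A]
    outbox.getChunk()             // Some(B1); B re-enqueued -> [A, B]
    outbox.getChunk()             // Some(A2); A re-enqueued -> [B, A]
    outbox.getChunk()             // None - both messages fully chunked and removed

messageA and messageB stand in for BufferMessages constructed elsewhere; the names are
illustrative, not part of the patch.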
http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/network/nio/ConnectionId.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionId.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionId.scala
new file mode 100644
index 0000000..764dc5e
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionId.scala
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.nio
+
+private[nio] case class ConnectionId(connectionManagerId: ConnectionManagerId, uniqId: Int) {
+  override def toString = connectionManagerId.host + "_" + connectionManagerId.port + "_" + uniqId
+}
+
+private[nio] object ConnectionId {
+
+  def createConnectionIdFromString(connectionIdString: String): ConnectionId = {
+    val res = connectionIdString.split("_").map(_.trim())
+    if (res.size != 3) {
+      throw new Exception("Error converting ConnectionId string: " + connectionIdString +
+        " to a ConnectionId Object")
+    }
+    new ConnectionId(new ConnectionManagerId(res(0), res(1).toInt), res(2).toInt)
+  }
+}

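An illustrative round-trip through the underscore-separated string form (not part of the
patch; equality holds assuming ConnectionManagerId also compares structurally, as a case
class):

    val id = ConnectionId(new ConnectionManagerId("host", 12345), uniqId = 7)
    id.toString                                             // "host_12345_7"
    ConnectionId.createConnectionIdFromString(id.toString)  // == id

Note that createConnectionIdFromString expects exactly two underscores, so a host name
containing '_' would make the split produce more than three parts and throw.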

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org