You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by rx...@apache.org on 2016/02/29 02:25:28 UTC
[12/14] spark git commit: [SPARK-13529][BUILD] Move network/* modules
into common/network-*
http://git-wip-us.apache.org/repos/asf/spark/blob/9e01dcc6/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java
----------------------------------------------------------------------
diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java
new file mode 100644
index 0000000..127335e
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java
@@ -0,0 +1,291 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.sasl;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.channels.WritableByteChannel;
+import java.util.List;
+
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Preconditions;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+import io.netty.channel.Channel;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelOutboundHandlerAdapter;
+import io.netty.channel.ChannelPromise;
+import io.netty.channel.FileRegion;
+import io.netty.handler.codec.MessageToMessageDecoder;
+import io.netty.util.AbstractReferenceCounted;
+import io.netty.util.ReferenceCountUtil;
+
+import org.apache.spark.network.util.ByteArrayWritableChannel;
+import org.apache.spark.network.util.NettyUtils;
+
+/**
+ * Provides SASL-based encryption for transport channels. The single method exposed by this
+ * class installs the needed channel handlers on a connected channel.
+ */
+class SaslEncryption {
+
+  @VisibleForTesting
+  static final String ENCRYPTION_HANDLER_NAME = "saslEncryption";
+
+  /**
+   * Adds channel handlers that perform encryption / decryption of data using SASL.
+   *
+   * Handlers are inserted with addFirst(), so afterwards the head of the pipeline is
+   * "saslFrameDecoder", then "saslDecryption", then the encryption handler, followed by the
+   * handlers that were already installed.
+   *
+   * @param channel The channel.
+   * @param backend The SASL backend.
+   * @param maxOutboundBlockSize Max size in bytes of outgoing encrypted blocks, to control
+   *                             memory usage.
+   */
+  static void addToChannel(
+      Channel channel,
+      SaslEncryptionBackend backend,
+      int maxOutboundBlockSize) {
+    channel.pipeline()
+      .addFirst(ENCRYPTION_HANDLER_NAME, new EncryptionHandler(backend, maxOutboundBlockSize))
+      .addFirst("saslDecryption", new DecryptionHandler(backend))
+      .addFirst("saslFrameDecoder", NettyUtils.createFrameDecoder());
+  }
+
+  /** Outbound handler that wraps every outgoing message so it is encrypted while written. */
+  private static class EncryptionHandler extends ChannelOutboundHandlerAdapter {
+
+    private final int maxOutboundBlockSize;
+    private final SaslEncryptionBackend backend;
+
+    EncryptionHandler(SaslEncryptionBackend backend, int maxOutboundBlockSize) {
+      this.backend = backend;
+      this.maxOutboundBlockSize = maxOutboundBlockSize;
+    }
+
+    /**
+     * Wrap the incoming message in an implementation that will perform encryption lazily. This is
+     * needed to guarantee ordering of the outgoing encrypted packets - they need to be decrypted in
+     * the same order, and netty doesn't have an atomic ChannelHandlerContext.write() API, so it
+     * does not guarantee any ordering.
+     */
+    @Override
+    public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise)
+      throws Exception {
+
+      ctx.write(new EncryptedMessage(backend, msg, maxOutboundBlockSize), promise);
+    }
+
+    /** Disposes of the SASL backend when this handler is removed from the pipeline. */
+    @Override
+    public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
+      try {
+        backend.dispose();
+      } finally {
+        super.handlerRemoved(ctx);
+      }
+    }
+
+  }
+
+  /** Inbound handler that decrypts each frame emitted by the preceding frame decoder. */
+  private static class DecryptionHandler extends MessageToMessageDecoder<ByteBuf> {
+
+    private final SaslEncryptionBackend backend;
+
+    DecryptionHandler(SaslEncryptionBackend backend) {
+      this.backend = backend;
+    }
+
+    @Override
+    protected void decode(ChannelHandlerContext ctx, ByteBuf msg, List<Object> out)
+      throws Exception {
+
+      byte[] data;
+      int offset;
+      int length = msg.readableBytes();
+      if (msg.hasArray()) {
+        // arrayOffset() is the position of buffer index 0 inside the backing array; the readable
+        // bytes start readerIndex() positions after it. Read it before skipBytes() moves it.
+        data = msg.array();
+        offset = msg.arrayOffset() + msg.readerIndex();
+        msg.skipBytes(length);
+      } else {
+        data = new byte[length];
+        msg.readBytes(data);
+        offset = 0;
+      }
+
+      out.add(Unpooled.wrappedBuffer(backend.unwrap(data, offset, length)));
+    }
+
+  }
+
+  @VisibleForTesting
+  static class EncryptedMessage extends AbstractReferenceCounted implements FileRegion {
+
+    private final SaslEncryptionBackend backend;
+    private final boolean isByteBuf;
+    private final ByteBuf buf;
+    private final FileRegion region;
+
+    /**
+     * Size of the original (unencrypted) message, captured once at construction. It must not be
+     * recomputed from {@code buf.readableBytes()}: nextChunk() consumes bytes from the buffer,
+     * which would make count() shrink while transfered() grows, so the transfer loop would stop
+     * before all of the data had been sent.
+     */
+    private final long count;
+
+    /**
+     * A channel used to buffer input data for encryption. The channel has an upper size bound
+     * so that if the input is larger than the allowed buffer, it will be broken into multiple
+     * chunks.
+     */
+    private final ByteArrayWritableChannel byteChannel;
+
+    /** 8-byte frame length header for the chunk being written; null between chunks. */
+    private ByteBuf currentHeader;
+    /** Encrypted bytes of the chunk being written; null between chunks. */
+    private ByteBuffer currentChunk;
+    /** Bytes of the current chunk already reported through the "return 1" path below. */
+    private long currentReportedBytes;
+    /** Plaintext size of the current chunk; this is what it contributes to transferred. */
+    private long unencryptedChunkSize;
+    private long transferred;
+
+    EncryptedMessage(SaslEncryptionBackend backend, Object msg, int maxOutboundBlockSize) {
+      Preconditions.checkArgument(msg instanceof ByteBuf || msg instanceof FileRegion,
+        "Unrecognized message type: %s", msg.getClass().getName());
+      this.backend = backend;
+      this.isByteBuf = msg instanceof ByteBuf;
+      this.buf = isByteBuf ? (ByteBuf) msg : null;
+      this.region = isByteBuf ? null : (FileRegion) msg;
+      this.count = isByteBuf ? buf.readableBytes() : region.count();
+      this.byteChannel = new ByteArrayWritableChannel(maxOutboundBlockSize);
+    }
+
+    /**
+     * Returns the size of the original (unencrypted) message.
+     *
+     * This makes assumptions about how netty treats FileRegion instances, because there's no way
+     * to know beforehand what will be the size of the encrypted message. Namely, it assumes
+     * that netty will try to transfer data from this message while
+     * {@code transfered() < count()}. So these two methods return, technically, wrong data,
+     * but netty doesn't know better.
+     */
+    @Override
+    public long count() {
+      return count;
+    }
+
+    /** Always 0: this message is transferred from its start. */
+    @Override
+    public long position() {
+      return 0;
+    }
+
+    /**
+     * Returns an approximation of the amount of data transferred. See {@link #count()}.
+     */
+    @Override
+    public long transfered() {
+      return transferred;
+    }
+
+    /**
+     * Transfers data from the original message to the channel, encrypting it in the process.
+     *
+     * This method also breaks down the original message into smaller chunks when needed. This
+     * is done to keep memory usage under control. This avoids having to copy the whole message
+     * data into memory at once, and can avoid ballooning memory usage when transferring large
+     * messages such as shuffle blocks.
+     *
+     * The {@link #transfered()} counter also behaves a little funny, in that it won't go forward
+     * until a whole chunk has been written. This is done because the code can't use the actual
+     * number of bytes written to the channel as the transferred count (see {@link #count()}).
+     * Instead, once an encrypted chunk is written to the output (including its header), the
+     * size of the original block will be added to the {@link #transfered()} amount.
+     */
+    @Override
+    public long transferTo(final WritableByteChannel target, final long position)
+      throws IOException {
+
+      Preconditions.checkArgument(position == transfered(), "Invalid position.");
+
+      long reportedWritten = 0L;
+      long actuallyWritten = 0L;
+      do {
+        if (currentChunk == null) {
+          nextChunk();
+        }
+
+        if (currentHeader.readableBytes() > 0) {
+          int bytesWritten = target.write(currentHeader.nioBuffer());
+          currentHeader.skipBytes(bytesWritten);
+          actuallyWritten += bytesWritten;
+          if (currentHeader.readableBytes() > 0) {
+            // Break out of loop if there are still header bytes left to write.
+            break;
+          }
+        }
+
+        actuallyWritten += target.write(currentChunk);
+        if (!currentChunk.hasRemaining()) {
+          // Only update the count of written bytes once a full chunk has been written.
+          // See method javadoc.
+          long chunkBytesRemaining = unencryptedChunkSize - currentReportedBytes;
+          reportedWritten += chunkBytesRemaining;
+          transferred += chunkBytesRemaining;
+          currentHeader.release();
+          currentHeader = null;
+          currentChunk = null;
+          currentReportedBytes = 0;
+        }
+        // "transferred" already includes everything reported during this call, so it must not
+        // be combined with reportedWritten again when checking whether data remains.
+      } while (currentChunk == null && transfered() < count());
+
+      // Returning 0 triggers a backoff mechanism in netty which may harm performance. Instead,
+      // we return 1 until we can (i.e. until the reported count would actually match the
+      // unencrypted size of the current chunk), at which point we resort to returning 0 so that
+      // the counts still match, at the cost of some performance. That situation should be rare,
+      // though.
+      if (reportedWritten != 0L) {
+        return reportedWritten;
+      }
+
+      // Bound the early reports by the *unencrypted* chunk size, since that is what the chunk
+      // adds to "transferred" once complete. Bounding by the (larger) encrypted size could let
+      // transfered() reach count() before the chunk is fully written, or make
+      // chunkBytesRemaining negative.
+      if (actuallyWritten > 0 && currentReportedBytes < unencryptedChunkSize - 1) {
+        transferred += 1L;
+        currentReportedBytes += 1L;
+        return 1L;
+      }
+
+      return 0L;
+    }
+
+    /**
+     * Copies the next block of plaintext (bounded by the byteChannel capacity) out of the
+     * original message, encrypts it, and prepares the frame header for it.
+     */
+    private void nextChunk() throws IOException {
+      byteChannel.reset();
+      if (isByteBuf) {
+        int copied = byteChannel.write(buf.nioBuffer());
+        buf.skipBytes(copied);
+      } else {
+        region.transferTo(byteChannel, region.transfered());
+      }
+
+      byte[] encrypted = backend.wrap(byteChannel.getData(), 0, byteChannel.length());
+      this.currentChunk = ByteBuffer.wrap(encrypted);
+      // The frame length written in the header counts the 8 header bytes plus the payload.
+      this.currentHeader = Unpooled.copyLong(8L + encrypted.length);
+      this.unencryptedChunkSize = byteChannel.length();
+    }
+
+    /** Releases the pending header (if any) and the wrapped original message. */
+    @Override
+    protected void deallocate() {
+      if (currentHeader != null) {
+        currentHeader.release();
+      }
+      if (buf != null) {
+        buf.release();
+      }
+      if (region != null) {
+        region.release();
+      }
+    }
+
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/9e01dcc6/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryptionBackend.java
----------------------------------------------------------------------
diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryptionBackend.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryptionBackend.java
new file mode 100644
index 0000000..89b78bc
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryptionBackend.java
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.sasl;
+
+import javax.security.sasl.SaslException;
+
+/**
+ * Encryption / decryption operations of a negotiated SASL session, plus cleanup. Implemented by
+ * both SparkSaslClient and SparkSaslServer so the same channel handlers work on either side.
+ */
+interface SaslEncryptionBackend {
+
+  /**
+   * Encrypts a region of a byte array.
+   *
+   * @param data source array.
+   * @param offset index of the first byte to encrypt.
+   * @param len number of bytes to encrypt.
+   * @return the encrypted bytes.
+   * @throws SaslException if the data cannot be encrypted.
+   */
+  byte[] wrap(byte[] data, int offset, int len) throws SaslException;
+
+  /**
+   * Decrypts a region of a byte array.
+   *
+   * @param data source array.
+   * @param offset index of the first byte to decrypt.
+   * @param len number of bytes to decrypt.
+   * @return the decrypted bytes.
+   * @throws SaslException if the data cannot be decrypted.
+   */
+  byte[] unwrap(byte[] data, int offset, int len) throws SaslException;
+
+  /** Releases whatever resources the backend holds. */
+  void dispose();
+
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/9e01dcc6/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java
----------------------------------------------------------------------
diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java
new file mode 100644
index 0000000..e52b526
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.sasl;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+
+import org.apache.spark.network.buffer.NettyManagedBuffer;
+import org.apache.spark.network.protocol.Encoders;
+import org.apache.spark.network.protocol.AbstractMessage;
+
+/**
+ * Encodes a Sasl-related message which is attempting to authenticate using some credentials tagged
+ * with the given appId. This appId allows a single SaslRpcHandler to multiplex different
+ * applications which may be using different sets of credentials.
+ */
+class SaslMessage extends AbstractMessage {
+
+ /** Serialization tag used to catch incorrect payloads. */
+ private static final byte TAG_BYTE = (byte) 0xEA;
+
+ public final String appId;
+
+ public SaslMessage(String appId, byte[] message) {
+ this(appId, Unpooled.wrappedBuffer(message));
+ }
+
+ public SaslMessage(String appId, ByteBuf message) {
+ super(new NettyManagedBuffer(message), true);
+ this.appId = appId;
+ }
+
+ @Override
+ public Type type() { return Type.User; }
+
+ @Override
+ public int encodedLength() {
+ // The integer (a.k.a. the body size) is not really used, since that information is already
+ // encoded in the frame length. But this maintains backwards compatibility with versions of
+ // RpcRequest that use Encoders.ByteArrays.
+ return 1 + Encoders.Strings.encodedLength(appId) + 4;
+ }
+
+ @Override
+ public void encode(ByteBuf buf) {
+ buf.writeByte(TAG_BYTE);
+ Encoders.Strings.encode(buf, appId);
+ // See comment in encodedLength().
+ buf.writeInt((int) body().size());
+ }
+
+ public static SaslMessage decode(ByteBuf buf) {
+ if (buf.readByte() != TAG_BYTE) {
+ throw new IllegalStateException("Expected SaslMessage, received something else"
+ + " (maybe your client does not have SASL enabled?)");
+ }
+
+ String appId = Encoders.Strings.decode(buf);
+ // See comment in encodedLength().
+ buf.readInt();
+ return new SaslMessage(appId, buf.retain());
+ }
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/9e01dcc6/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java
----------------------------------------------------------------------
diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java
new file mode 100644
index 0000000..c41f5b6
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java
@@ -0,0 +1,158 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.sasl;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import javax.security.sasl.Sasl;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+import io.netty.channel.Channel;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.server.RpcHandler;
+import org.apache.spark.network.server.StreamManager;
+import org.apache.spark.network.util.JavaUtils;
+import org.apache.spark.network.util.TransportConf;
+
+/**
+ * RPC Handler which performs SASL authentication before delegating to a child RPC handler.
+ * The delegate will only receive messages if the given connection has been successfully
+ * authenticated. A connection may be authenticated at most once.
+ *
+ * Note that the authentication process consists of multiple challenge-response pairs, each of
+ * which are individual RPCs.
+ */
+class SaslRpcHandler extends RpcHandler {
+  private static final Logger logger = LoggerFactory.getLogger(SaslRpcHandler.class);
+
+  /** Transport configuration. */
+  private final TransportConf conf;
+
+  /** The client channel. */
+  private final Channel channel;
+
+  /** RpcHandler we will delegate to for authenticated connections. */
+  private final RpcHandler delegate;
+
+  /** Class which provides secret keys which are shared by server and client on a per-app basis. */
+  private final SecretKeyHolder secretKeyHolder;
+
+  /** Server side of the handshake; non-null only while a negotiation is in progress. */
+  private SparkSaslServer saslServer;
+
+  /** Whether this connection has completed SASL authentication. */
+  private boolean isComplete;
+
+  SaslRpcHandler(
+      TransportConf conf,
+      Channel channel,
+      RpcHandler delegate,
+      SecretKeyHolder secretKeyHolder) {
+    this.conf = conf;
+    this.channel = channel;
+    this.delegate = delegate;
+    this.secretKeyHolder = secretKeyHolder;
+    this.saslServer = null;
+    this.isComplete = false;
+  }
+
+  /**
+   * Before authentication completes, treats each RPC as a step of the SASL handshake and
+   * answers it; afterwards, forwards RPCs to the delegate.
+   */
+  @Override
+  public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) {
+    if (isComplete) {
+      // Authentication complete, delegate to base handler.
+      delegate.receive(client, message, callback);
+      return;
+    }
+
+    ByteBuf nettyBuf = Unpooled.wrappedBuffer(message);
+    SaslMessage saslMessage;
+    try {
+      saslMessage = SaslMessage.decode(nettyBuf);
+    } finally {
+      nettyBuf.release();
+    }
+
+    if (saslServer == null) {
+      // First message in the handshake, set up the necessary state.
+      client.setClientId(saslMessage.appId);
+      saslServer = new SparkSaslServer(saslMessage.appId, secretKeyHolder,
+        conf.saslServerAlwaysEncrypt());
+    }
+
+    byte[] response;
+    try {
+      response = saslServer.response(JavaUtils.bufferToArray(
+        saslMessage.body().nioByteBuffer()));
+    } catch (IOException ioe) {
+      throw new RuntimeException(ioe);
+    }
+    callback.onSuccess(ByteBuffer.wrap(response));
+
+    // Setup encryption after the SASL response is sent, otherwise the client can't parse the
+    // response. It's ok to change the channel pipeline here since we are processing an incoming
+    // message, so the pipeline is busy and no new incoming messages will be fed to it before this
+    // method returns. This assumes that the code ensures, through other means, that no outbound
+    // messages are being written to the channel while negotiation is still going on.
+    if (saslServer.isComplete()) {
+      logger.debug("SASL authentication successful for channel {}", client);
+      isComplete = true;
+      if (SparkSaslServer.QOP_AUTH_CONF.equals(saslServer.getNegotiatedProperty(Sasl.QOP))) {
+        logger.debug("Enabling encryption for channel {}", client);
+        // The encryption handler disposes of its backend when it is removed from the pipeline,
+        // so saslServer is handed over to it instead of being disposed here.
+        SaslEncryption.addToChannel(channel, saslServer, conf.maxSaslEncryptedBlockSize());
+      } else {
+        saslServer.dispose();
+      }
+      saslServer = null;
+    }
+  }
+
+  /**
+   * Forwards one-way messages to the delegate, but only on authenticated connections. One-way
+   * messages take no part in the handshake, so letting them through earlier would allow an
+   * unauthenticated client to reach the delegate and bypass SASL entirely.
+   *
+   * @throws SecurityException if the connection has not completed authentication.
+   */
+  @Override
+  public void receive(TransportClient client, ByteBuffer message) {
+    if (!isComplete) {
+      throw new SecurityException("Unauthenticated call to receive().");
+    }
+    delegate.receive(client, message);
+  }
+
+  @Override
+  public StreamManager getStreamManager() {
+    return delegate.getStreamManager();
+  }
+
+  @Override
+  public void channelActive(TransportClient client) {
+    delegate.channelActive(client);
+  }
+
+  /** Notifies the delegate and releases any handshake still in progress. */
+  @Override
+  public void channelInactive(TransportClient client) {
+    try {
+      delegate.channelInactive(client);
+    } finally {
+      if (saslServer != null) {
+        saslServer.dispose();
+        saslServer = null;
+      }
+    }
+  }
+
+  @Override
+  public void exceptionCaught(Throwable cause, TransportClient client) {
+    delegate.exceptionCaught(cause, client);
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/9e01dcc6/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslServerBootstrap.java
----------------------------------------------------------------------
diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslServerBootstrap.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslServerBootstrap.java
new file mode 100644
index 0000000..f2f9838
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslServerBootstrap.java
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.sasl;
+
+import io.netty.channel.Channel;
+
+import org.apache.spark.network.server.RpcHandler;
+import org.apache.spark.network.server.TransportServerBootstrap;
+import org.apache.spark.network.util.TransportConf;
+
+/**
+ * A bootstrap which is executed on a TransportServer's client channel once a client connects
+ * to the server. This allows customizing the client channel to allow for things such as SASL
+ * authentication.
+ */
+public class SaslServerBootstrap implements TransportServerBootstrap {
+
+  private final TransportConf conf;
+  private final SecretKeyHolder secretKeyHolder;
+
+  /**
+   * @param conf transport configuration handed to every SaslRpcHandler this bootstrap creates.
+   * @param secretKeyHolder source of the per-application SASL credentials.
+   */
+  public SaslServerBootstrap(TransportConf conf, SecretKeyHolder secretKeyHolder) {
+    this.conf = conf;
+    this.secretKeyHolder = secretKeyHolder;
+  }
+
+  /**
+   * Wrap the given application handler in a SaslRpcHandler that will handle the initial SASL
+   * negotiation.
+   *
+   * @param channel the newly connected client channel.
+   * @param rpcHandler the application handler that receives traffic once authenticated.
+   * @return the SASL-guarding handler wrapping {@code rpcHandler}.
+   */
+  @Override
+  public RpcHandler doBootstrap(Channel channel, RpcHandler rpcHandler) {
+    return new SaslRpcHandler(conf, channel, rpcHandler, secretKeyHolder);
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/9e01dcc6/common/network-common/src/main/java/org/apache/spark/network/sasl/SecretKeyHolder.java
----------------------------------------------------------------------
diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SecretKeyHolder.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SecretKeyHolder.java
new file mode 100644
index 0000000..81d5766
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SecretKeyHolder.java
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.sasl;
+
+/**
+ * Source of the SASL credentials (user name and secret key) to use for each application.
+ */
+public interface SecretKeyHolder {
+
+  /**
+   * Looks up the SASL user to authenticate as for an application.
+   *
+   * @param appId the application identifier.
+   * @return the SASL user associated with {@code appId}.
+   * @throws IllegalArgumentException if the given appId is not associated with a SASL user.
+   */
+  String getSaslUser(String appId);
+
+  /**
+   * Looks up the shared SASL secret for an application.
+   *
+   * @param appId the application identifier.
+   * @return the SASL secret key associated with {@code appId}.
+   * @throws IllegalArgumentException if the given appId is not associated with a SASL secret key.
+   */
+  String getSecretKey(String appId);
+
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/9e01dcc6/common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java
----------------------------------------------------------------------
diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java
new file mode 100644
index 0000000..94685e9
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java
@@ -0,0 +1,162 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.sasl;
+
+import java.io.IOException;
+import java.util.Map;
+import javax.security.auth.callback.Callback;
+import javax.security.auth.callback.CallbackHandler;
+import javax.security.auth.callback.NameCallback;
+import javax.security.auth.callback.PasswordCallback;
+import javax.security.auth.callback.UnsupportedCallbackException;
+import javax.security.sasl.RealmCallback;
+import javax.security.sasl.RealmChoiceCallback;
+import javax.security.sasl.Sasl;
+import javax.security.sasl.SaslClient;
+import javax.security.sasl.SaslException;
+
+import com.google.common.base.Throwables;
+import com.google.common.collect.ImmutableMap;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import static org.apache.spark.network.sasl.SparkSaslServer.*;
+
+/**
+ * A SASL Client for Spark which simply keeps track of the state of a single SASL session, from the
+ * initial state to the "authenticated" state. This client initializes the protocol via a
+ * firstToken, which is then followed by a set of challenges and responses.
+ */
+public class SparkSaslClient implements SaslEncryptionBackend {
+  private static final Logger logger = LoggerFactory.getLogger(SparkSaslClient.class);
+
+  private final String secretKeyId;
+  private final SecretKeyHolder secretKeyHolder;
+  private final String expectedQop;
+  /** The JDK SASL client; set to null by dispose(). */
+  private SaslClient saslClient;
+
+  /**
+   * @param secretKeyId key passed to {@code secretKeyHolder} to look up the user and secret.
+   * @param secretKeyHolder provides the SASL user name and secret key.
+   * @param encrypt if true, request the QOP_AUTH_CONF quality of protection (the one for which
+   *                SaslRpcHandler enables encryption); otherwise request QOP_AUTH.
+   */
+  public SparkSaslClient(String secretKeyId, SecretKeyHolder secretKeyHolder, boolean encrypt) {
+    this.secretKeyId = secretKeyId;
+    this.secretKeyHolder = secretKeyHolder;
+    this.expectedQop = encrypt ? QOP_AUTH_CONF : QOP_AUTH;
+
+    Map<String, String> saslProps = ImmutableMap.<String, String>builder()
+      .put(Sasl.QOP, expectedQop)
+      .build();
+    try {
+      this.saslClient = Sasl.createSaslClient(new String[] { DIGEST }, null, null, DEFAULT_REALM,
+        saslProps, new ClientCallbackHandler());
+    } catch (SaslException e) {
+      throw Throwables.propagate(e);
+    }
+  }
+
+  /** Used to initiate SASL handshake with server. */
+  public synchronized byte[] firstToken() {
+    if (saslClient != null && saslClient.hasInitialResponse()) {
+      try {
+        return saslClient.evaluateChallenge(new byte[0]);
+      } catch (SaslException e) {
+        throw Throwables.propagate(e);
+      }
+    } else {
+      return new byte[0];
+    }
+  }
+
+  /** Determines whether the authentication exchange has completed. */
+  public synchronized boolean isComplete() {
+    return saslClient != null && saslClient.isComplete();
+  }
+
+  /** Returns the value of a negotiated property. */
+  public Object getNegotiatedProperty(String name) {
+    return saslClient.getNegotiatedProperty(name);
+  }
+
+  /**
+   * Respond to server's SASL token.
+   * @param token contains server's SASL token
+   * @return client's response SASL token, or an empty array if this client has been disposed
+   */
+  public synchronized byte[] response(byte[] token) {
+    try {
+      return saslClient != null ? saslClient.evaluateChallenge(token) : new byte[0];
+    } catch (SaslException e) {
+      throw Throwables.propagate(e);
+    }
+  }
+
+  /**
+   * Disposes of any system resources or security-sensitive information the
+   * SaslClient might be using.
+   */
+  @Override
+  public synchronized void dispose() {
+    if (saslClient != null) {
+      try {
+        saslClient.dispose();
+      } catch (SaslException e) {
+        // Best effort: nothing useful can be done about a failed dispose, and the reference is
+        // dropped below regardless.
+      } finally {
+        saslClient = null;
+      }
+    }
+  }
+
+  /**
+   * Implementation of javax.security.auth.callback.CallbackHandler
+   * that works with share secrets.
+   */
+  private class ClientCallbackHandler implements CallbackHandler {
+    @Override
+    public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException {
+
+      for (Callback callback : callbacks) {
+        if (callback instanceof NameCallback) {
+          logger.trace("SASL client callback: setting username");
+          NameCallback nc = (NameCallback) callback;
+          nc.setName(encodeIdentifier(secretKeyHolder.getSaslUser(secretKeyId)));
+        } else if (callback instanceof PasswordCallback) {
+          logger.trace("SASL client callback: setting password");
+          PasswordCallback pc = (PasswordCallback) callback;
+          pc.setPassword(encodePassword(secretKeyHolder.getSecretKey(secretKeyId)));
+        } else if (callback instanceof RealmCallback) {
+          logger.trace("SASL client callback: setting realm");
+          RealmCallback rc = (RealmCallback) callback;
+          rc.setText(rc.getDefaultText());
+        } else if (callback instanceof RealmChoiceCallback) {
+          // Intentionally ignored: this client makes no realm choice.
+          // NOTE(review): confirm the default realm selection is acceptable here.
+        } else {
+          throw new UnsupportedCallbackException(callback, "Unrecognized SASL DIGEST-MD5 Callback");
+        }
+      }
+    }
+  }
+
+  /** Encrypts data with the underlying SaslClient. */
+  @Override
+  public byte[] wrap(byte[] data, int offset, int len) throws SaslException {
+    return saslClient.wrap(data, offset, len);
+  }
+
+  /** Decrypts data with the underlying SaslClient. */
+  @Override
+  public byte[] unwrap(byte[] data, int offset, int len) throws SaslException {
+    return saslClient.unwrap(data, offset, len);
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/9e01dcc6/common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java
----------------------------------------------------------------------
diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java
new file mode 100644
index 0000000..431cb67
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java
@@ -0,0 +1,200 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.sasl;
+
+import java.io.IOException;
+import java.util.Map;
+import javax.security.auth.callback.Callback;
+import javax.security.auth.callback.CallbackHandler;
+import javax.security.auth.callback.NameCallback;
+import javax.security.auth.callback.PasswordCallback;
+import javax.security.auth.callback.UnsupportedCallbackException;
+import javax.security.sasl.AuthorizeCallback;
+import javax.security.sasl.RealmCallback;
+import javax.security.sasl.Sasl;
+import javax.security.sasl.SaslException;
+import javax.security.sasl.SaslServer;
+
+import com.google.common.base.Charsets;
+import com.google.common.base.Preconditions;
+import com.google.common.base.Throwables;
+import com.google.common.collect.ImmutableMap;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+import io.netty.handler.codec.base64.Base64;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * A SASL Server for Spark which simply keeps track of the state of a single SASL session, from the
+ * initial state to the "authenticated" state. (It is not a server in the sense of accepting
+ * connections on some socket.)
+ */
+public class SparkSaslServer implements SaslEncryptionBackend {
+ private final Logger logger = LoggerFactory.getLogger(SparkSaslServer.class);
+
+ /**
+ * This is passed as the server name when creating the sasl client/server.
+ * This could be changed to be configurable in the future.
+ */
+ static final String DEFAULT_REALM = "default";
+
+ /**
+ * The authentication mechanism used here is DIGEST-MD5. This could be changed to be
+ * configurable in the future.
+ */
+ static final String DIGEST = "DIGEST-MD5";
+
+ /**
+ * Quality of protection value that includes encryption.
+ */
+ static final String QOP_AUTH_CONF = "auth-conf";
+
+ /**
+ * Quality of protection value that does not include encryption.
+ */
+ static final String QOP_AUTH = "auth";
+
+ /** Identifier for a certain secret key within the secretKeyHolder. */
+ private final String secretKeyId;
+ private final SecretKeyHolder secretKeyHolder;
+ private SaslServer saslServer;
+
+ public SparkSaslServer(
+ String secretKeyId,
+ SecretKeyHolder secretKeyHolder,
+ boolean alwaysEncrypt) {
+ this.secretKeyId = secretKeyId;
+ this.secretKeyHolder = secretKeyHolder;
+
+ // Sasl.QOP is a comma-separated list of supported values. The value that allows encryption
+ // is listed first since it's preferred over the non-encrypted one (if the client also
+ // lists both in the request).
+ String qop = alwaysEncrypt ? QOP_AUTH_CONF : String.format("%s,%s", QOP_AUTH_CONF, QOP_AUTH);
+ Map<String, String> saslProps = ImmutableMap.<String, String>builder()
+ .put(Sasl.SERVER_AUTH, "true")
+ .put(Sasl.QOP, qop)
+ .build();
+ try {
+ this.saslServer = Sasl.createSaslServer(DIGEST, null, DEFAULT_REALM, saslProps,
+ new DigestCallbackHandler());
+ } catch (SaslException e) {
+ throw Throwables.propagate(e);
+ }
+ }
+
+ /**
+ * Determines whether the authentication exchange has completed successfully.
+ */
+ public synchronized boolean isComplete() {
+ return saslServer != null && saslServer.isComplete();
+ }
+
+ /** Returns the value of a negotiated property. */
+ public Object getNegotiatedProperty(String name) {
+ return saslServer.getNegotiatedProperty(name);
+ }
+
+ /**
+ * Used to respond to server SASL tokens.
+ * @param token Server's SASL token
+ * @return response to send back to the server.
+ */
+ public synchronized byte[] response(byte[] token) {
+ try {
+ return saslServer != null ? saslServer.evaluateResponse(token) : new byte[0];
+ } catch (SaslException e) {
+ throw Throwables.propagate(e);
+ }
+ }
+
+ /**
+ * Disposes of any system resources or security-sensitive information the
+ * SaslServer might be using.
+ */
+ @Override
+ public synchronized void dispose() {
+ if (saslServer != null) {
+ try {
+ saslServer.dispose();
+ } catch (SaslException e) {
+ // ignore
+ } finally {
+ saslServer = null;
+ }
+ }
+ }
+
+ @Override
+ public byte[] wrap(byte[] data, int offset, int len) throws SaslException {
+ return saslServer.wrap(data, offset, len);
+ }
+
+ @Override
+ public byte[] unwrap(byte[] data, int offset, int len) throws SaslException {
+ return saslServer.unwrap(data, offset, len);
+ }
+
+ /**
+ * Implementation of javax.security.auth.callback.CallbackHandler for SASL DIGEST-MD5 mechanism.
+ */
+ private class DigestCallbackHandler implements CallbackHandler {
+ @Override
+ public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException {
+ for (Callback callback : callbacks) {
+ if (callback instanceof NameCallback) {
+ logger.trace("SASL server callback: setting username");
+ NameCallback nc = (NameCallback) callback;
+ nc.setName(encodeIdentifier(secretKeyHolder.getSaslUser(secretKeyId)));
+ } else if (callback instanceof PasswordCallback) {
+ logger.trace("SASL server callback: setting password");
+ PasswordCallback pc = (PasswordCallback) callback;
+ pc.setPassword(encodePassword(secretKeyHolder.getSecretKey(secretKeyId)));
+ } else if (callback instanceof RealmCallback) {
+ logger.trace("SASL server callback: setting realm");
+ RealmCallback rc = (RealmCallback) callback;
+ rc.setText(rc.getDefaultText());
+ } else if (callback instanceof AuthorizeCallback) {
+ AuthorizeCallback ac = (AuthorizeCallback) callback;
+ String authId = ac.getAuthenticationID();
+ String authzId = ac.getAuthorizationID();
+ ac.setAuthorized(authId.equals(authzId));
+ if (ac.isAuthorized()) {
+ ac.setAuthorizedID(authzId);
+ }
+ logger.debug("SASL Authorization complete, authorized set to {}", ac.isAuthorized());
+ } else {
+ throw new UnsupportedCallbackException(callback, "Unrecognized SASL DIGEST-MD5 Callback");
+ }
+ }
+ }
+ }
+
+ /* Encode a byte[] identifier as a Base64-encoded string. */
+ public static String encodeIdentifier(String identifier) {
+ Preconditions.checkNotNull(identifier, "User cannot be null if SASL is enabled");
+ return Base64.encode(Unpooled.wrappedBuffer(identifier.getBytes(Charsets.UTF_8)))
+ .toString(Charsets.UTF_8);
+ }
+
+ /** Encode a password as a base64-encoded char[] array. */
+ public static char[] encodePassword(String password) {
+ Preconditions.checkNotNull(password, "Password cannot be null if SASL is enabled");
+ return Base64.encode(Unpooled.wrappedBuffer(password.getBytes(Charsets.UTF_8)))
+ .toString(Charsets.UTF_8).toCharArray();
+ }
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/9e01dcc6/common/network-common/src/main/java/org/apache/spark/network/server/MessageHandler.java
----------------------------------------------------------------------
diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/MessageHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/MessageHandler.java
new file mode 100644
index 0000000..4a1f28e
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/server/MessageHandler.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.server;
+
+import org.apache.spark.network.protocol.Message;
+
+/**
+ * Consumes the messages (requests or responses) that Netty reads from a Channel, together with
+ * that Channel's lifecycle notifications. Each instance belongs to a single Netty Channel, even
+ * though several clients may share that Channel.
+ *
+ * @param <T> the kind of {@link Message} this handler consumes
+ */
+public abstract class MessageHandler<T extends Message> {
+
+  /** Called when the channel this handler belongs to becomes active. */
+  public abstract void channelActive();
+
+  /**
+   * Processes one incoming message.
+   *
+   * @param message the message that was received
+   * @throws Exception if processing fails
+   */
+  public abstract void handle(T message) throws Exception;
+
+  /** Called when an exception has been caught on the channel. */
+  public abstract void exceptionCaught(Throwable cause);
+
+  /** Called when the channel this handler belongs to becomes inactive. */
+  public abstract void channelInactive();
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/9e01dcc6/common/network-common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java
----------------------------------------------------------------------
diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java
new file mode 100644
index 0000000..6ed61da
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.server;
+
+import java.nio.ByteBuffer;
+
+import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.TransportClient;
+
+/**
+ * RpcHandler for a client-only TransportContext. Incoming RPCs are rejected; streams are served
+ * by a private {@link OneForOneStreamManager}.
+ */
+public class NoOpRpcHandler extends RpcHandler {
+  private final StreamManager streamManager = new OneForOneStreamManager();
+
+  /** Creates a handler backed by a freshly constructed {@link OneForOneStreamManager}. */
+  public NoOpRpcHandler() {
+    // The stream manager is created by the field initializer above.
+  }
+
+  /**
+   * Always fails: a client-only context does not accept RPCs.
+   *
+   * @throws UnsupportedOperationException unconditionally
+   */
+  @Override
+  public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) {
+    throw new UnsupportedOperationException("Cannot handle messages");
+  }
+
+  @Override
+  public StreamManager getStreamManager() {
+    return streamManager;
+  }
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/9e01dcc6/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java
----------------------------------------------------------------------
diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java b/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java
new file mode 100644
index 0000000..ea9e735
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java
@@ -0,0 +1,143 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.server;
+
+import java.util.Iterator;
+import java.util.Map;
+import java.util.Random;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.atomic.AtomicLong;
+
+import com.google.common.base.Preconditions;
+import io.netty.channel.Channel;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.client.TransportClient;
+
+/**
+ * StreamManager which allows registration of an Iterator<ManagedBuffer>, which are individually
+ * fetched as chunks by the client. Each registered buffer is one chunk.
+ */
+public class OneForOneStreamManager extends StreamManager {
+ private final Logger logger = LoggerFactory.getLogger(OneForOneStreamManager.class);
+
+ private final AtomicLong nextStreamId;
+ private final ConcurrentHashMap<Long, StreamState> streams;
+
+ /** State of a single stream. */
+ private static class StreamState {
+ final String appId;
+ final Iterator<ManagedBuffer> buffers;
+
+ // The channel associated to the stream
+ Channel associatedChannel = null;
+
+ // Used to keep track of the index of the buffer that the user has retrieved, just to ensure
+ // that the caller only requests each chunk one at a time, in order.
+ int curChunk = 0;
+
+ StreamState(String appId, Iterator<ManagedBuffer> buffers) {
+ this.appId = appId;
+ this.buffers = Preconditions.checkNotNull(buffers);
+ }
+ }
+
+ public OneForOneStreamManager() {
+ // For debugging purposes, start with a random stream id to help identifying different streams.
+ // This does not need to be globally unique, only unique to this class.
+ nextStreamId = new AtomicLong((long) new Random().nextInt(Integer.MAX_VALUE) * 1000);
+ streams = new ConcurrentHashMap<Long, StreamState>();
+ }
+
+ @Override
+ public void registerChannel(Channel channel, long streamId) {
+ if (streams.containsKey(streamId)) {
+ streams.get(streamId).associatedChannel = channel;
+ }
+ }
+
+ @Override
+ public ManagedBuffer getChunk(long streamId, int chunkIndex) {
+ StreamState state = streams.get(streamId);
+ if (chunkIndex != state.curChunk) {
+ throw new IllegalStateException(String.format(
+ "Received out-of-order chunk index %s (expected %s)", chunkIndex, state.curChunk));
+ } else if (!state.buffers.hasNext()) {
+ throw new IllegalStateException(String.format(
+ "Requested chunk index beyond end %s", chunkIndex));
+ }
+ state.curChunk += 1;
+ ManagedBuffer nextChunk = state.buffers.next();
+
+ if (!state.buffers.hasNext()) {
+ logger.trace("Removing stream id {}", streamId);
+ streams.remove(streamId);
+ }
+
+ return nextChunk;
+ }
+
+ @Override
+ public void connectionTerminated(Channel channel) {
+ // Close all streams which have been associated with the channel.
+ for (Map.Entry<Long, StreamState> entry: streams.entrySet()) {
+ StreamState state = entry.getValue();
+ if (state.associatedChannel == channel) {
+ streams.remove(entry.getKey());
+
+ // Release all remaining buffers.
+ while (state.buffers.hasNext()) {
+ state.buffers.next().release();
+ }
+ }
+ }
+ }
+
+ @Override
+ public void checkAuthorization(TransportClient client, long streamId) {
+ if (client.getClientId() != null) {
+ StreamState state = streams.get(streamId);
+ Preconditions.checkArgument(state != null, "Unknown stream ID.");
+ if (!client.getClientId().equals(state.appId)) {
+ throw new SecurityException(String.format(
+ "Client %s not authorized to read stream %d (app %s).",
+ client.getClientId(),
+ streamId,
+ state.appId));
+ }
+ }
+ }
+
+ /**
+ * Registers a stream of ManagedBuffers which are served as individual chunks one at a time to
+ * callers. Each ManagedBuffer will be release()'d after it is transferred on the wire. If a
+ * client connection is closed before the iterator is fully drained, then the remaining buffers
+ * will all be release()'d.
+ *
+ * If an app ID is provided, only callers who've authenticated with the given app ID will be
+ * allowed to fetch from this stream.
+ */
+ public long registerStream(String appId, Iterator<ManagedBuffer> buffers) {
+ long myStreamId = nextStreamId.getAndIncrement();
+ streams.put(myStreamId, new StreamState(appId, buffers));
+ return myStreamId;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/9e01dcc6/common/network-common/src/main/java/org/apache/spark/network/server/RpcHandler.java
----------------------------------------------------------------------
diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/RpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/RpcHandler.java
new file mode 100644
index 0000000..a99c301
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/server/RpcHandler.java
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.server;
+
+import java.nio.ByteBuffer;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.TransportClient;
+
+/**
+ * Handler for sendRPC() messages sent by {@link org.apache.spark.network.client.TransportClient}s.
+ */
+public abstract class RpcHandler {
+
+ private static final RpcResponseCallback ONE_WAY_CALLBACK = new OneWayRpcCallback();
+
+ /**
+ * Receive a single RPC message. Any exception thrown while in this method will be sent back to
+ * the client in string form as a standard RPC failure.
+ *
+ * This method will not be called in parallel for a single TransportClient (i.e., channel).
+ *
+ * @param client A channel client which enables the handler to make requests back to the sender
+ * of this RPC. This will always be the exact same object for a particular channel.
+ * @param message The serialized bytes of the RPC.
+ * @param callback Callback which should be invoked exactly once upon success or failure of the
+ * RPC.
+ */
+ public abstract void receive(
+ TransportClient client,
+ ByteBuffer message,
+ RpcResponseCallback callback);
+
+ /**
+ * Returns the StreamManager which contains the state about which streams are currently being
+ * fetched by a TransportClient.
+ */
+ public abstract StreamManager getStreamManager();
+
+ /**
+ * Receives an RPC message that does not expect a reply. The default implementation will
+ * call "{@link #receive(TransportClient, ByteBuffer, RpcResponseCallback)}" and log a warning if
+ * any of the callback methods are called.
+ *
+ * @param client A channel client which enables the handler to make requests back to the sender
+ * of this RPC. This will always be the exact same object for a particular channel.
+ * @param message The serialized bytes of the RPC.
+ */
+ public void receive(TransportClient client, ByteBuffer message) {
+ receive(client, message, ONE_WAY_CALLBACK);
+ }
+
+ /**
+ * Invoked when the channel associated with the given client is active.
+ */
+ public void channelActive(TransportClient client) { }
+
+ /**
+ * Invoked when the channel associated with the given client is inactive.
+ * No further requests will come from this client.
+ */
+ public void channelInactive(TransportClient client) { }
+
+ public void exceptionCaught(Throwable cause, TransportClient client) { }
+
+ private static class OneWayRpcCallback implements RpcResponseCallback {
+
+ private final Logger logger = LoggerFactory.getLogger(OneWayRpcCallback.class);
+
+ @Override
+ public void onSuccess(ByteBuffer response) {
+ logger.warn("Response provided for one-way RPC.");
+ }
+
+ @Override
+ public void onFailure(Throwable e) {
+ logger.error("Error response provided for one-way RPC.", e);
+ }
+
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/9e01dcc6/common/network-common/src/main/java/org/apache/spark/network/server/StreamManager.java
----------------------------------------------------------------------
diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/StreamManager.java b/common/network-common/src/main/java/org/apache/spark/network/server/StreamManager.java
new file mode 100644
index 0000000..07f161a
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/server/StreamManager.java
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.server;
+
+import io.netty.channel.Channel;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.client.TransportClient;
+
+/**
+ * Serves individual chunks of streams; {@link TransportRequestHandler} uses it to answer
+ * fetchChunk() requests. How streams are created is outside the transport layer. Each stream is
+ * read by exactly one client connection, so getChunk() for a given stream is called serially,
+ * and a stream is never used again once its connection has closed.
+ */
+public abstract class StreamManager {
+  /**
+   * Answers a fetchChunk() request; the returned buffer is sent to the client unchanged. Since a
+   * stream belongs to a single TCP connection, this is never called concurrently for the same
+   * stream.
+   *
+   * Implementations may, but are not required to, support out-of-order or repeated requests.
+   *
+   * The returned ManagedBuffer is release()'d once it has been written to the network.
+   *
+   * @param streamId id of a stream previously registered with this StreamManager.
+   * @param chunkIndex 0-based index of the requested chunk.
+   */
+  public abstract ManagedBuffer getChunk(long streamId, int chunkIndex);
+
+  /**
+   * Answers a stream() request; the returned data is streamed to the client over a single TCP
+   * connection.
+   *
+   * The <code>streamId</code> here is unrelated to the identically named argument of
+   * {@link #getChunk(long, int)}.
+   *
+   * @param streamId id of a stream previously registered with this StreamManager.
+   * @return A managed buffer for the stream, or null if no such stream exists.
+   * @throws UnsupportedOperationException in this default implementation; subclasses that
+   *         support stream() requests must override it.
+   */
+  public ManagedBuffer openStream(String streamId) {
+    throw new UnsupportedOperationException();
+  }
+
+  /**
+   * Binds a stream to the one client connection allowed to read it. getChunk() is then called
+   * serially on that connection, and the stream can be cleaned up once the connection closes.
+   *
+   * Must be called before the first getChunk() on the stream; calling it repeatedly with the same
+   * channel and stream id is allowed. Does nothing by default.
+   */
+  public void registerChannel(Channel channel, long streamId) {
+  }
+
+  /**
+   * Signals that {@code channel} has terminated. Its streams will not be read again, so any state
+   * they hold may be cleaned up. Does nothing by default.
+   */
+  public void connectionTerminated(Channel channel) {
+  }
+
+  /**
+   * Checks that {@code client} may read the given stream. Allows everything by default.
+   *
+   * @throws SecurityException If client is not authorized.
+   */
+  public void checkAuthorization(TransportClient client, long streamId) {
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/9e01dcc6/common/network-common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java
----------------------------------------------------------------------
diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java
new file mode 100644
index 0000000..18a9b78
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java
@@ -0,0 +1,163 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.server;
+
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.SimpleChannelInboundHandler;
+import io.netty.handler.timeout.IdleState;
+import io.netty.handler.timeout.IdleStateEvent;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.client.TransportResponseHandler;
+import org.apache.spark.network.protocol.Message;
+import org.apache.spark.network.protocol.RequestMessage;
+import org.apache.spark.network.protocol.ResponseMessage;
+import org.apache.spark.network.util.NettyUtils;
+
+/**
+ * The single Transport-level Channel handler which is used for delegating requests to the
+ * {@link TransportRequestHandler} and responses to the {@link TransportResponseHandler}.
+ *
+ * All channels created in the transport layer are bidirectional. When the Client initiates a Netty
+ * Channel with a RequestMessage (which gets handled by the Server's RequestHandler), the Server
+ * will produce a ResponseMessage (handled by the Client's ResponseHandler). However, the Server
+ * also gets a handle on the same Channel, so it may then begin to send RequestMessages to the
+ * Client.
+ * This means that the Client also needs a RequestHandler and the Server needs a ResponseHandler,
+ * for the Client's responses to the Server's requests.
+ *
+ * This class also handles timeouts from a {@link io.netty.handler.timeout.IdleStateHandler}.
+ * We consider a connection timed out if there are outstanding fetch or RPC requests but no traffic
+ * on the channel for at least `requestTimeoutMs`. Note that this is duplex traffic; we will not
+ * timeout if the client is continuously sending but getting no responses, for simplicity.
+ */
+public class TransportChannelHandler extends SimpleChannelInboundHandler<Message> {
+ private final Logger logger = LoggerFactory.getLogger(TransportChannelHandler.class);
+
+ private final TransportClient client;
+ private final TransportResponseHandler responseHandler;
+ private final TransportRequestHandler requestHandler;
+ private final long requestTimeoutNs;
+ private final boolean closeIdleConnections;
+
+ public TransportChannelHandler(
+ TransportClient client,
+ TransportResponseHandler responseHandler,
+ TransportRequestHandler requestHandler,
+ long requestTimeoutMs,
+ boolean closeIdleConnections) {
+ this.client = client;
+ this.responseHandler = responseHandler;
+ this.requestHandler = requestHandler;
+ this.requestTimeoutNs = requestTimeoutMs * 1000L * 1000;
+ this.closeIdleConnections = closeIdleConnections;
+ }
+
+ public TransportClient getClient() {
+ return client;
+ }
+
+ @Override
+ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
+ logger.warn("Exception in connection from " + NettyUtils.getRemoteAddress(ctx.channel()),
+ cause);
+ requestHandler.exceptionCaught(cause);
+ responseHandler.exceptionCaught(cause);
+ ctx.close();
+ }
+
+ @Override
+ public void channelActive(ChannelHandlerContext ctx) throws Exception {
+ try {
+ requestHandler.channelActive();
+ } catch (RuntimeException e) {
+ logger.error("Exception from request handler while registering channel", e);
+ }
+ try {
+ responseHandler.channelActive();
+ } catch (RuntimeException e) {
+ logger.error("Exception from response handler while registering channel", e);
+ }
+ super.channelRegistered(ctx);
+ }
+
+ @Override
+ public void channelInactive(ChannelHandlerContext ctx) throws Exception {
+ try {
+ requestHandler.channelInactive();
+ } catch (RuntimeException e) {
+ logger.error("Exception from request handler while unregistering channel", e);
+ }
+ try {
+ responseHandler.channelInactive();
+ } catch (RuntimeException e) {
+ logger.error("Exception from response handler while unregistering channel", e);
+ }
+ super.channelUnregistered(ctx);
+ }
+
+ @Override
+ public void channelRead0(ChannelHandlerContext ctx, Message request) throws Exception {
+ if (request instanceof RequestMessage) {
+ requestHandler.handle((RequestMessage) request);
+ } else {
+ responseHandler.handle((ResponseMessage) request);
+ }
+ }
+
+ /** Triggered based on events from an {@link io.netty.handler.timeout.IdleStateHandler}. */
+ @Override
+ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
+ if (evt instanceof IdleStateEvent) {
+ IdleStateEvent e = (IdleStateEvent) evt;
+ // See class comment for timeout semantics. In addition to ensuring we only timeout while
+ // there are outstanding requests, we also do a secondary consistency check to ensure
+ // there's no race between the idle timeout and incrementing the numOutstandingRequests
+ // (see SPARK-7003).
+ //
+ // To avoid a race between TransportClientFactory.createClient() and this code which could
+ // result in an inactive client being returned, this needs to run in a synchronized block.
+ synchronized (this) {
+ boolean isActuallyOverdue =
+ System.nanoTime() - responseHandler.getTimeOfLastRequestNs() > requestTimeoutNs;
+ if (e.state() == IdleState.ALL_IDLE && isActuallyOverdue) {
+ if (responseHandler.numOutstandingRequests() > 0) {
+ String address = NettyUtils.getRemoteAddress(ctx.channel());
+ logger.error("Connection to {} has been quiet for {} ms while there are outstanding " +
+ "requests. Assuming connection is dead; please adjust spark.network.timeout if this " +
+ "is wrong.", address, requestTimeoutNs / 1000 / 1000);
+ client.timeOut();
+ ctx.close();
+ } else if (closeIdleConnections) {
+ // While CloseIdleConnections is enable, we also close idle connection
+ client.timeOut();
+ ctx.close();
+ }
+ }
+ }
+ }
+ ctx.fireUserEventTriggered(evt);
+ }
+
+ public TransportResponseHandler getResponseHandler() {
+ return responseHandler;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/9e01dcc6/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
----------------------------------------------------------------------
diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
new file mode 100644
index 0000000..296ced3
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
@@ -0,0 +1,209 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.server;
+
+import java.nio.ByteBuffer;
+
+import com.google.common.base.Preconditions;
+import com.google.common.base.Throwables;
+import io.netty.channel.Channel;
+import io.netty.channel.ChannelFuture;
+import io.netty.channel.ChannelFutureListener;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.buffer.NioManagedBuffer;
+import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.protocol.ChunkFetchRequest;
+import org.apache.spark.network.protocol.ChunkFetchFailure;
+import org.apache.spark.network.protocol.ChunkFetchSuccess;
+import org.apache.spark.network.protocol.Encodable;
+import org.apache.spark.network.protocol.OneWayMessage;
+import org.apache.spark.network.protocol.RequestMessage;
+import org.apache.spark.network.protocol.RpcFailure;
+import org.apache.spark.network.protocol.RpcRequest;
+import org.apache.spark.network.protocol.RpcResponse;
+import org.apache.spark.network.protocol.StreamFailure;
+import org.apache.spark.network.protocol.StreamRequest;
+import org.apache.spark.network.protocol.StreamResponse;
+import org.apache.spark.network.util.NettyUtils;
+
+/**
+ * A handler that processes requests from clients and writes chunk data back. Each handler is
+ * attached to a single Netty channel, and keeps track of which streams have been fetched via this
+ * channel, in order to clean them up if the channel is terminated (see #channelUnregistered).
+ *
+ * The messages should have been processed by the pipeline setup by {@link TransportServer}.
+ */
+public class TransportRequestHandler extends MessageHandler<RequestMessage> {
+ private final Logger logger = LoggerFactory.getLogger(TransportRequestHandler.class);
+
+ /** The Netty channel that this handler is associated with. */
+ private final Channel channel;
+
+ /** Client on the same channel allowing us to talk back to the requester. */
+ private final TransportClient reverseClient;
+
+ /** Handles all RPC messages. */
+ private final RpcHandler rpcHandler;
+
+ /** Returns each chunk part of a stream. */
+ private final StreamManager streamManager;
+
+ public TransportRequestHandler(
+ Channel channel,
+ TransportClient reverseClient,
+ RpcHandler rpcHandler) {
+ this.channel = channel;
+ this.reverseClient = reverseClient;
+ this.rpcHandler = rpcHandler;
+ this.streamManager = rpcHandler.getStreamManager();
+ }
+
+ @Override
+ public void exceptionCaught(Throwable cause) {
+ rpcHandler.exceptionCaught(cause, reverseClient);
+ }
+
+ @Override
+ public void channelActive() {
+ rpcHandler.channelActive(reverseClient);
+ }
+
+ @Override
+ public void channelInactive() {
+ if (streamManager != null) {
+ try {
+ streamManager.connectionTerminated(channel);
+ } catch (RuntimeException e) {
+ logger.error("StreamManager connectionTerminated() callback failed.", e);
+ }
+ }
+ rpcHandler.channelInactive(reverseClient);
+ }
+
+ @Override
+ public void handle(RequestMessage request) {
+ if (request instanceof ChunkFetchRequest) {
+ processFetchRequest((ChunkFetchRequest) request);
+ } else if (request instanceof RpcRequest) {
+ processRpcRequest((RpcRequest) request);
+ } else if (request instanceof OneWayMessage) {
+ processOneWayMessage((OneWayMessage) request);
+ } else if (request instanceof StreamRequest) {
+ processStreamRequest((StreamRequest) request);
+ } else {
+ throw new IllegalArgumentException("Unknown request type: " + request);
+ }
+ }
+
+ private void processFetchRequest(final ChunkFetchRequest req) {
+ final String client = NettyUtils.getRemoteAddress(channel);
+
+ logger.trace("Received req from {} to fetch block {}", client, req.streamChunkId);
+
+ ManagedBuffer buf;
+ try {
+ streamManager.checkAuthorization(reverseClient, req.streamChunkId.streamId);
+ streamManager.registerChannel(channel, req.streamChunkId.streamId);
+ buf = streamManager.getChunk(req.streamChunkId.streamId, req.streamChunkId.chunkIndex);
+ } catch (Exception e) {
+ logger.error(String.format(
+ "Error opening block %s for request from %s", req.streamChunkId, client), e);
+ respond(new ChunkFetchFailure(req.streamChunkId, Throwables.getStackTraceAsString(e)));
+ return;
+ }
+
+ respond(new ChunkFetchSuccess(req.streamChunkId, buf));
+ }
+
+ private void processStreamRequest(final StreamRequest req) {
+ final String client = NettyUtils.getRemoteAddress(channel);
+ ManagedBuffer buf;
+ try {
+ buf = streamManager.openStream(req.streamId);
+ } catch (Exception e) {
+ logger.error(String.format(
+ "Error opening stream %s for request from %s", req.streamId, client), e);
+ respond(new StreamFailure(req.streamId, Throwables.getStackTraceAsString(e)));
+ return;
+ }
+
+ if (buf != null) {
+ respond(new StreamResponse(req.streamId, buf.size(), buf));
+ } else {
+ respond(new StreamFailure(req.streamId, String.format(
+ "Stream '%s' was not found.", req.streamId)));
+ }
+ }
+
+ private void processRpcRequest(final RpcRequest req) {
+ try {
+ rpcHandler.receive(reverseClient, req.body().nioByteBuffer(), new RpcResponseCallback() {
+ @Override
+ public void onSuccess(ByteBuffer response) {
+ respond(new RpcResponse(req.requestId, new NioManagedBuffer(response)));
+ }
+
+ @Override
+ public void onFailure(Throwable e) {
+ respond(new RpcFailure(req.requestId, Throwables.getStackTraceAsString(e)));
+ }
+ });
+ } catch (Exception e) {
+ logger.error("Error while invoking RpcHandler#receive() on RPC id " + req.requestId, e);
+ respond(new RpcFailure(req.requestId, Throwables.getStackTraceAsString(e)));
+ } finally {
+ req.body().release();
+ }
+ }
+
+ private void processOneWayMessage(OneWayMessage req) {
+ try {
+ rpcHandler.receive(reverseClient, req.body().nioByteBuffer());
+ } catch (Exception e) {
+ logger.error("Error while invoking RpcHandler#receive() for one-way message.", e);
+ } finally {
+ req.body().release();
+ }
+ }
+
+ /**
+ * Responds to a single message with some Encodable object. If a failure occurs while sending,
+ * it will be logged and the channel closed.
+ */
+ private void respond(final Encodable result) {
+ final String remoteAddress = channel.remoteAddress().toString();
+ channel.writeAndFlush(result).addListener(
+ new ChannelFutureListener() {
+ @Override
+ public void operationComplete(ChannelFuture future) throws Exception {
+ if (future.isSuccess()) {
+ logger.trace(String.format("Sent result %s to client %s", result, remoteAddress));
+ } else {
+ logger.error(String.format("Error sending result %s to %s; closing connection",
+ result, remoteAddress), future.cause());
+ channel.close();
+ }
+ }
+ }
+ );
+ }
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/9e01dcc6/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java
----------------------------------------------------------------------
diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java
new file mode 100644
index 0000000..baae235
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java
@@ -0,0 +1,151 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.server;
+
+import java.io.Closeable;
+import java.net.InetSocketAddress;
+import java.util.List;
+import java.util.concurrent.TimeUnit;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Lists;
+import io.netty.bootstrap.ServerBootstrap;
+import io.netty.buffer.PooledByteBufAllocator;
+import io.netty.channel.ChannelFuture;
+import io.netty.channel.ChannelInitializer;
+import io.netty.channel.ChannelOption;
+import io.netty.channel.EventLoopGroup;
+import io.netty.channel.socket.SocketChannel;
+import org.apache.spark.network.util.JavaUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.TransportContext;
+import org.apache.spark.network.util.IOMode;
+import org.apache.spark.network.util.NettyUtils;
+import org.apache.spark.network.util.TransportConf;
+
+/**
+ * Server for the efficient, low-level streaming service.
+ */
+public class TransportServer implements Closeable {
+ private final Logger logger = LoggerFactory.getLogger(TransportServer.class);
+
+ private final TransportContext context;
+ private final TransportConf conf;
+ private final RpcHandler appRpcHandler;
+ private final List<TransportServerBootstrap> bootstraps;
+
+ private ServerBootstrap bootstrap;
+ private ChannelFuture channelFuture;
+ private int port = -1;
+
+ /**
+ * Creates a TransportServer that binds to the given host and the given port, or to any available
+ * if 0. If you don't want to bind to any special host, set "hostToBind" to null.
+ * */
+ public TransportServer(
+ TransportContext context,
+ String hostToBind,
+ int portToBind,
+ RpcHandler appRpcHandler,
+ List<TransportServerBootstrap> bootstraps) {
+ this.context = context;
+ this.conf = context.getConf();
+ this.appRpcHandler = appRpcHandler;
+ this.bootstraps = Lists.newArrayList(Preconditions.checkNotNull(bootstraps));
+
+ try {
+ init(hostToBind, portToBind);
+ } catch (RuntimeException e) {
+ JavaUtils.closeQuietly(this);
+ throw e;
+ }
+ }
+
+ public int getPort() {
+ if (port == -1) {
+ throw new IllegalStateException("Server not initialized");
+ }
+ return port;
+ }
+
+ private void init(String hostToBind, int portToBind) {
+
+ IOMode ioMode = IOMode.valueOf(conf.ioMode());
+ EventLoopGroup bossGroup =
+ NettyUtils.createEventLoop(ioMode, conf.serverThreads(), "shuffle-server");
+ EventLoopGroup workerGroup = bossGroup;
+
+ PooledByteBufAllocator allocator = NettyUtils.createPooledByteBufAllocator(
+ conf.preferDirectBufs(), true /* allowCache */, conf.serverThreads());
+
+ bootstrap = new ServerBootstrap()
+ .group(bossGroup, workerGroup)
+ .channel(NettyUtils.getServerChannelClass(ioMode))
+ .option(ChannelOption.ALLOCATOR, allocator)
+ .childOption(ChannelOption.ALLOCATOR, allocator);
+
+ if (conf.backLog() > 0) {
+ bootstrap.option(ChannelOption.SO_BACKLOG, conf.backLog());
+ }
+
+ if (conf.receiveBuf() > 0) {
+ bootstrap.childOption(ChannelOption.SO_RCVBUF, conf.receiveBuf());
+ }
+
+ if (conf.sendBuf() > 0) {
+ bootstrap.childOption(ChannelOption.SO_SNDBUF, conf.sendBuf());
+ }
+
+ bootstrap.childHandler(new ChannelInitializer<SocketChannel>() {
+ @Override
+ protected void initChannel(SocketChannel ch) throws Exception {
+ RpcHandler rpcHandler = appRpcHandler;
+ for (TransportServerBootstrap bootstrap : bootstraps) {
+ rpcHandler = bootstrap.doBootstrap(ch, rpcHandler);
+ }
+ context.initializePipeline(ch, rpcHandler);
+ }
+ });
+
+ InetSocketAddress address = hostToBind == null ?
+ new InetSocketAddress(portToBind): new InetSocketAddress(hostToBind, portToBind);
+ channelFuture = bootstrap.bind(address);
+ channelFuture.syncUninterruptibly();
+
+ port = ((InetSocketAddress) channelFuture.channel().localAddress()).getPort();
+ logger.debug("Shuffle server started on port :" + port);
+ }
+
+ @Override
+ public void close() {
+ if (channelFuture != null) {
+ // close is a local operation and should finish within milliseconds; timeout just to be safe
+ channelFuture.channel().close().awaitUninterruptibly(10, TimeUnit.SECONDS);
+ channelFuture = null;
+ }
+ if (bootstrap != null && bootstrap.group() != null) {
+ bootstrap.group().shutdownGracefully();
+ }
+ if (bootstrap != null && bootstrap.childGroup() != null) {
+ bootstrap.childGroup().shutdownGracefully();
+ }
+ bootstrap = null;
+ }
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/9e01dcc6/common/network-common/src/main/java/org/apache/spark/network/server/TransportServerBootstrap.java
----------------------------------------------------------------------
diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/TransportServerBootstrap.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportServerBootstrap.java
new file mode 100644
index 0000000..05803ab
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/server/TransportServerBootstrap.java
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.server;
+
+import io.netty.channel.Channel;
+
+/**
+ * A bootstrap which is executed on a TransportServer's client channel once a client connects
+ * to the server. This allows customizing the client channel to allow for things such as SASL
+ * authentication.
+ */
+public interface TransportServerBootstrap {
+ /**
+ * Customizes the channel to include new features, if needed.
+ *
+ * @param channel The connected channel opened by the client.
+ * @param rpcHandler The RPC handler for the server.
+ * @return The RPC handler to use for the channel.
+ */
+ RpcHandler doBootstrap(Channel channel, RpcHandler rpcHandler);
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/9e01dcc6/common/network-common/src/main/java/org/apache/spark/network/util/ByteArrayWritableChannel.java
----------------------------------------------------------------------
diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/ByteArrayWritableChannel.java b/common/network-common/src/main/java/org/apache/spark/network/util/ByteArrayWritableChannel.java
new file mode 100644
index 0000000..b141572
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/util/ByteArrayWritableChannel.java
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.util;
+
+import java.nio.ByteBuffer;
+import java.nio.channels.WritableByteChannel;
+
/**
 * An in-memory {@link WritableByteChannel} that copies whatever is written to it into a
 * fixed-capacity byte array. Once the array is full, further writes copy nothing.
 */
public class ByteArrayWritableChannel implements WritableByteChannel {

  /** Backing storage; its length is the channel's capacity. */
  private final byte[] buffer;

  /** Number of bytes written since construction or the last {@link #reset()}. */
  private int position;

  /**
   * @param size capacity of the channel, in bytes
   */
  public ByteArrayWritableChannel(int size) {
    buffer = new byte[size];
  }

  /**
   * Returns the backing array itself, not a copy. Only the first {@link #length()} bytes hold
   * data written since the last reset.
   */
  public byte[] getData() {
    return buffer;
  }

  /** Returns how many bytes have been written since construction or the last reset. */
  public int length() {
    return position;
  }

  /** Rewinds the write position so that subsequent writes overwrite the array from the start. */
  public void reset() {
    position = 0;
  }

  /**
   * Copies as many bytes from {@code src} as fit in the remaining capacity, advancing
   * {@code src}'s position by the same amount.
   *
   * @return the number of bytes copied, which is zero once the array is full
   */
  @Override
  public int write(ByteBuffer src) {
    int free = buffer.length - position;
    int count = src.remaining() < free ? src.remaining() : free;
    src.get(buffer, position, count);
    position += count;
    return count;
  }

  /** Does nothing: an array-backed channel holds nothing that needs releasing. */
  @Override
  public void close() {
    // Intentionally empty.
  }

  /** Always returns {@code true}; {@link #close()} does not change this. */
  @Override
  public boolean isOpen() {
    return true;
  }

}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org